Commit de548dd3 by xiaotong

bug fixes

parent 51b4da42
......@@ -40,17 +40,12 @@ using namespace transformer;
int main( int argc, const char ** argv )
{
//TransposeTest();
//return 0;
//SumDimTest();
//return 0;
//_CrtSetBreakAlloc(896);
if(argc > 1 && !strcmp(argv[1], "-fnnlm"))
FNNLMMain(argc - 1, argv + 1);
else if(argc > 1 && !strcmp(argv[1], "-t2t"))
TransformerMain(argc - 1, argv + 1);
TransformerMain(argc - 1, argv + 1);
else{
fprintf(stderr, "Thanks for using NiuTrans.Network! This is a library for building\n");
fprintf(stderr, "neural networks in an easy way. \n\n");
......@@ -58,37 +53,6 @@ int main( int argc, const char ** argv )
fprintf(stderr, "Or run this program with \"-fnnlm\" for sample FNNLM!\n");
}
return 0;
XNet net;
XTensor a;
XTensor b;
XTensor c;
InitTensor2D(&a, 2, 2);
InitTensor2D(&b, 2, 4);
InitTensor2D(&c, 2, 4);
a.SetZeroAll();
b.SetZeroAll();
c.SetZeroAll();
SetDataFixed(a, 0.1F);
a.Set2D(0.3F, 1, 0);
a.Set2D(0.4F, 1, 1);
b = Merge(a, a, 1);
c = HTanH(MMul(a, b));
a.Dump(stderr, "a:");
b.Dump(stderr, "b:");
c.Dump(stderr, "c:");
XLink::ShowNetwork(stderr, &c);
net.Backward(c);
net.Dump(stderr);
//_CrtDumpMemoryLeaks();
return 0;
......
......@@ -46,6 +46,11 @@ unsigned int MakeNetID()
return id;
}
void XNetClearAll()
{
MUTEX_DELE(netMutex);
}
/* constructor */
XNet::XNet()
{
......
......@@ -95,6 +95,7 @@ struct XNet
extern unsigned int netIDGlobal;
extern MUTEX_HANDLE netMutex;
extern unsigned int MakeNetID();
extern void XNetClearAll();
}
......
......@@ -240,6 +240,7 @@ void Check(FNNModel &model)
{
CheckErrors(model.n > 0 && model.n <= MAX_N_GRAM, "The LM order is out of range (use -n)!");
CheckErrors(model.vSize > 0, "no vocabulary size found (use -vsize)!");
CheckErrors(model.eSize > 0, "no embedding size found (use -esize)!");
}
/* make a hard copy of the fnn model */
......@@ -632,8 +633,10 @@ int LoadNGrams(FILE * file, int n, NGram * ngrams, int sentNum, int wordNum)
if(pin <= 0){
int len = (int)strlen(lineBuf);
if(lineBuf[len - 1] == '\r')
while(lineBuf[len - 1] == '\r' || lineBuf[len - 1] == '\n'){
lineBuf[len - 1] = 0;
len--;
}
len = (int)strlen(lineBuf);
if(len == 0)
......@@ -644,10 +647,11 @@ int LoadNGrams(FILE * file, int n, NGram * ngrams, int sentNum, int wordNum)
/* how many words are in the sentence */
int wNum = 0;
int i = 0;
for(int i = pin; i < len; i++){
for(i = pin; i < len; i++){
/* load word (id) separated by space or tab */
if((lineBuf[i] == ' ' || lineBuf[i] == '\t' || i == len - 1) && wSize > 0){
if((lineBuf[i] == ' ' || lineBuf[i] == '\t') && wSize > 0){
lineBuf[i] = 0;
wordBuf[wNum++] = atoi(lineBuf + i - wSize);
wSize = 0;
......@@ -656,6 +660,9 @@ int LoadNGrams(FILE * file, int n, NGram * ngrams, int sentNum, int wordNum)
wSize++;
}
if(wSize > 0)
wordBuf[wNum++] = atoi(lineBuf + i - wSize);
wordBufCount = wNum;
lineNum++;
}
......
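The LoadNGrams hunks above change how a line of word ids is tokenized: trailing '\r'/'\n' characters are stripped in a loop, the loop index is hoisted out of the for statement, and the last word is flushed after the loop instead of being special-cased with `i == len - 1` inside it. A minimal sketch of the resulting pattern, using a hypothetical helper name (ParseIds is not part of the repository):

#include <cstdlib>

/* illustrative only: scan for delimiters, then flush the trailing token after the loop */
int ParseIds(char * line, int len, int * out)
{
    int wNum = 0;
    int wSize = 0;
    int i = 0;
    for(i = 0; i < len; i++){
        if((line[i] == ' ' || line[i] == '\t') && wSize > 0){
            line[i] = 0;
            out[wNum++] = atoi(line + i - wSize);
            wSize = 0;
        }
        else
            wSize++;
    }
    /* the last word has no delimiter after it, so flush it here */
    if(wSize > 0)
        out[wNum++] = atoi(line + i - wSize);
    return wNum;
}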
......@@ -80,16 +80,16 @@ make the network
>> v - values
<< return - multi-attention result
*/
XTensor * T2TAttention::Make(XTensor * k, XTensor * q, XTensor * v)
XTensor T2TAttention::Make(XTensor &k, XTensor &q, XTensor &v)
{
XTensor k2;
XTensor q2;
XTensor v2;
/* linear transformation before self-attention */
k2 = MMul(*k, wk);
q2 = MMul(*q, wq);
v2 = MMul(*v, wv);
k2 = MMul(k, wk);
q2 = MMul(q, wq);
v2 = MMul(v, wv);
XTensor kheads;
XTensor qheads;
......@@ -107,12 +107,8 @@ XTensor * T2TAttention::Make(XTensor * k, XTensor * q, XTensor * v)
scalar = Softmax(Linear(BMMul(qheads, X_NOTRANS, kheads, X_TRANS), 1/sqrt((float)dk)), -1);
att = BMMul(scalar, vheads);
XTensor * result = new XTensor();
/* concatenate the heads */
*result = Merge(att, att.order - 1);
return result;
return Merge(att, att.order - 1);
}
}
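The change above (mirrored in every T2T module that follows) replaces the pointer-based interface `XTensor * Make(XTensor*, ...)` with `XTensor Make(XTensor&, ...)`: results are returned by value instead of as heap-allocated tensors, so callers no longer own an object they have to remember to delete. A hedged caller-side sketch (variable names are illustrative, not from the repository):

/* before: the caller owns the result and has to free it */
XTensor * att = attention.Make(&x, &x, &x);
/* ... use *att ... */
delete att;

/* after: the result is a value and is released when it goes out of scope */
XTensor att = attention.Make(x, x, x);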
......@@ -77,7 +77,7 @@ public:
void InitModel(int argc, const char ** argv, int myDevID = -1, XMem * myMem = NULL);
/* make the network */
XTensor * Make(XTensor * k, XTensor * q, XTensor * v);
XTensor Make(XTensor &k, XTensor &q, XTensor &v);
};
}
......
......@@ -101,21 +101,21 @@ void T2TEmbedder::MakePosEmbedding(int eSize, int d, int length)
/*
make the network
*/
XTensor * T2TEmbedder::Make(XTensor * input)
XTensor T2TEmbedder::Make(XTensor &input)
{
CheckNTErrors(input->GetDim(-1) == vSize, "Wrong vocabulary size!");
CheckNTErrors(input->order > 1, "Wrong input tensor size!");
CheckNTErrors(input->dimSize[input->order - 2] < maxLength, "The sequence is too long!");
CheckNTErrors(input.GetDim(-1) == vSize, "Wrong vocabulary size!");
CheckNTErrors(input.order > 1, "Wrong input tensor size!");
CheckNTErrors(input.dimSize[input.order - 2] < maxLength, "The sequence is too long!");
CheckNTErrors(vSize > 0, "set vocabulary size by \"-vsize\"");
CheckNTErrors(eSize > 0, "set embedding size by \"-esize\"");
int dims[MAX_TENSOR_DIM_NUM];
memcpy(dims, input->dimSize, input->order * sizeof(int));
dims[input->order - 1] = eSize;
memcpy(dims, input.dimSize, input.order * sizeof(int));
dims[input.order - 1] = eSize;
bool match = (posEmbedding.order == input->order);
bool match = (posEmbedding.order == input.order);
if(match){
for(int i = 0; i < input->order; i++){
for(int i = 0; i < input.order; i++){
if(dims[i] != posEmbedding.GetDim(i))
match = false;
}
......@@ -123,7 +123,7 @@ XTensor * T2TEmbedder::Make(XTensor * input)
/* we make positional embeddings first */
if(!match){
InitTensor(&posEmbedding, input->order, dims, X_FLOAT, 1.0F, devID, mem);
InitTensor(&posEmbedding, input.order, dims, X_FLOAT, 1.0F, devID, mem);
XTensor * posTMP = NewTensorBuf(2, dims + 1, X_FLOAT, 1.0F, devID, mem);
_CopyValues(&posEmbeddingBase, 0, posTMP->unitNum, posTMP, 0);
......@@ -135,14 +135,10 @@ XTensor * T2TEmbedder::Make(XTensor * input)
XTensor wordEmbedding;
/* then we make word embeddings */
wordEmbedding = MMul(*input, w);
XTensor * result = new XTensor();
wordEmbedding = MMul(&input, w);
/* we sum over the two embeddings */
*result = wordEmbedding + posEmbedding;
return result;
return wordEmbedding + posEmbedding;
}
}
......@@ -77,7 +77,7 @@ public:
void MakePosEmbedding(int eSize, int d, int length);
/* make the network */
XTensor * Make(XTensor * input);
XTensor Make(XTensor &input);
};
}
......
......@@ -82,26 +82,28 @@ make the encoding network
>> input - the input tensor of the encoder
<< return - the output tensor of the encoder
*/
XTensor * AttEncoder::Make(XTensor * input)
XTensor AttEncoder::Make(XTensor &input)
{
XTensor * x = embedder.Make(input);
XTensor x;
x = embedder.Make(input);
for(int i = 0; i < nlayer; i++){
XTensor * att;
XTensor * ln;
XTensor * fnn;
XTensor att;
XTensor ln;
XTensor fnn;
XTensor res;
/* self attention */
att = attentions[i].Make(x, x, x);
/* residual connection */
res = Sum(*att, *x);
res = Sum(att, x);
/* TODO: dropout */
/* layer normalization */
ln = layerNorms[i].Make(&res);
ln = layerNorms[i].Make(res);
/* input of next layer */
x = ln;
......@@ -110,12 +112,12 @@ XTensor * AttEncoder::Make(XTensor * input)
fnn = fnns[i].Make(x);
/* residual connection */
res = Sum(*fnn, *x);
res = Sum(fnn, x);
/* TODO: dropout */
/* layer normalization */
ln = layerNorms[i].Make(&res);
ln = layerNorms[i].Make(res);
/* input of next layer */
x = ln;
......
......@@ -40,7 +40,7 @@ class T2TEncoder
{
public:
virtual
XTensor * Make(XTensor * input) = 0;
XTensor Make(XTensor &input) = 0;
};
/*
......@@ -49,7 +49,7 @@ the encoder based on RNN
class RNNEncoder : T2TEncoder
{
public:
XTensor * Make(XTensor * input);
XTensor Make(XTensor &input);
};
......@@ -106,7 +106,7 @@ public:
void InitModel(int argc, const char ** argv, int myDevID = -1, XMem * myMem = NULL);
/* make the encoding network */
XTensor * Make(XTensor * input);
XTensor Make(XTensor &input);
};
......
......@@ -78,18 +78,15 @@ y = max(0, x * w1 + b1) * w2 + b2
>> input - the input tensor
<< return - the output tensor
*/
XTensor * T2TFNN::Make(XTensor * input)
XTensor T2TFNN::Make(XTensor &input)
{
XTensor t1;
XTensor * result = new XTensor();
/* t1 = max(0, x * w1 + b1) */
t1 = Rectify(MMul(*input, X_NOTRANS, w1, X_NOTRANS) + b1);
t1 = Rectify(MMul(input, X_NOTRANS, w1, X_NOTRANS) + b1);
/* result = t1 * w2 + b2 */
*result = MMul(t1, X_NOTRANS, w2, X_NOTRANS) + b2;
return result;
return MMul(t1, X_NOTRANS, w2, X_NOTRANS) + b2;
}
......
......@@ -72,7 +72,7 @@ public:
void InitModel(int argc, const char ** argv, int myDevID = -1, XMem * myMem = NULL);
/* make the network */
XTensor * Make(XTensor * input);
XTensor Make(XTensor &input);
};
......
......@@ -57,15 +57,14 @@ y =
>> input - the input tensor
<< return - layer normalization output
*/
XTensor * T2TLN::Make(XTensor * input)
XTensor T2TLN::Make(XTensor &input)
{
XTensor &x = *input;
XTensor &x = input;
XTensor mean;
XTensor variance;
XTensor standard;
XTensor meanFilled;
XTensor standardFilled;
XTensor * result = new XTensor();
/* \mu = (sum_i x_i)/m */
mean = ReduceSum(x, x.order - 1);
......@@ -82,9 +81,7 @@ XTensor * T2TLN::Make(XTensor * input)
standardFilled = Unsqueeze(standard, x.order - 1, x.GetDim(-1));
/* x' = (x - \mu)/standard */
*result = (x - meanFilled)/standardFilled;
return result;
return (x - meanFilled)/standardFilled;
}
}
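In the notation of the in-code comments, the layer normalization computed by T2TLN::Make over the last dimension (of size m) is, as a LaTeX sketch,

\mu = \frac{1}{m}\sum_i x_i, \qquad \sigma = \sqrt{\mathrm{Var}(x)}, \qquad y = \frac{x - \mu}{\sigma}

where meanFilled and standardFilled are \mu and \sigma broadcast back to the shape of x via Unsqueeze.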
......@@ -49,7 +49,7 @@ public:
void InitModel(int argc, const char ** argv, int myDevID = -1, XMem * myMem = NULL);
/* make the network */
XTensor * Make(XTensor * input);
XTensor Make(XTensor &input);
};
}
......
......@@ -69,7 +69,7 @@ make the encoding network
>> input - input tensor
<< return - encoding result
*/
XTensor * T2TModel::MakeEncoding(XTensor * input)
XTensor T2TModel::MakeEncoding(XTensor &input)
{
return encoder.Make(input);
}
......@@ -79,10 +79,12 @@ make the entire network (with the output softmax layer)
>> input - input tensor
>> output - output tensor (distribution)
*/
void T2TModel::Make(XTensor * input, XTensor * output)
void T2TModel::Make(XTensor &input, XTensor &output)
{
if(isLM){
XTensor * encoding = MakeEncoding(input);
XTensor encoding;
encoding = MakeEncoding(input);
outputLayer.Make(encoding, output);
}
else{
......
......@@ -66,10 +66,10 @@ public:
void InitModel(int argc, const char ** argv);
/* make the encoding network */
XTensor * MakeEncoding(XTensor * input);
XTensor MakeEncoding(XTensor &input);
/* make the entire network (with the output softmax layer) */
void Make(XTensor * input, XTensor * output);
void Make(XTensor &input, XTensor &output);
};
}
......
......@@ -53,11 +53,15 @@ void T2TOutput::InitModel(int argc, const char ** argv, int myDevID, XMem * myMe
devID = myDevID;
mem = myMem;
float minmax = 0;
LoadParamInt(argc, argv, "vsize", &vSize, -1);
LoadParamInt(argc, argv, "d", &inSize, DEFAULT_BEDDING_SIZE);
LoadParamInt(argc, argv, "d", &hSize, DEFAULT_BEDDING_SIZE);
LoadParamFloat(argc, argv, "outputminmax", &minmax, 0.08F);
InitTensor2D(&w, hSize, vSize, X_FLOAT, devID, mem);
w.SetDataRand(-minmax, minmax);
}
/*
......@@ -66,14 +70,11 @@ y = softmax(x * w)
>> input - input tensor
<< return - output tensor
*/
XTensor * T2TOutput::Make(XTensor * input)
XTensor T2TOutput::Make(XTensor &input)
{
XTensor &x = *input;
XTensor * result = new XTensor();
*result = LogSoftmax(MMul(x, w), -1);
XTensor &x = input;
return result;
return LogSoftmax(MMul(x, w), -1);
}
/*
......@@ -81,11 +82,11 @@ make the network (redefined output tensor)
>> input - input tensor
>> output - output tensor
*/
void T2TOutput::Make(XTensor * input, XTensor * output)
void T2TOutput::Make(XTensor &input, XTensor &output)
{
XTensor &x = *input;
XTensor &x = input;
*output = LogSoftmax(MMul(x, w), -1);
output = LogSoftmax(MMul(x, w), -1);
}
}
\ No newline at end of file
......@@ -62,10 +62,10 @@ public:
void InitModel(int argc, const char ** argv, int myDevID = -1, XMem * myMem = NULL);
/* make the network */
XTensor * Make(XTensor * input);
XTensor Make(XTensor &input);
/* make the network (redefined output tensor) */
void Make(XTensor * input, XTensor * output);
void Make(XTensor &input, XTensor &output);
};
......
......@@ -43,6 +43,7 @@ T2TTrainer::~T2TTrainer()
{
delete[] buf;
delete[] seqLen;
delete[] seqOffset;
}
/*
......@@ -96,18 +97,19 @@ void T2TTrainer::Train(const char * fn, T2TModel * model)
/* batch of input sequences */
XTensor batch;
/* output probabilities */
XTensor output;
while(LoadBatch(file, &batch, 1, vSize, sBatchSize, wBatchSize, isLenSorted, wc)){
/* output probabilities */
XTensor output;
/* make the network */
model->Make(&batch, &output);
model->Make(batch, output);
/* back-propagation for obtaining gradients */
net.Backward(output, batch, CROSSENTROPY);
/* TODO: update the model!!!! */
/* update the parameters */
Update(model);
/* get probabilities */
float prob = GetProb(&output, &batch, NULL);
......@@ -121,7 +123,7 @@ void T2TTrainer::Train(const char * fn, T2TModel * model)
break;
}
if (step % 100 == 0) {
if (step % 1 == 0) {
double elapsed = GetClockSec() - startT;
XPRINT5(0, stderr, "[INFO] elapsed=%.1fs, step=%d, epoch=%d, ngram=%d, ppl=%.3f\n",
elapsed, step, epoch + 1, wordCountTotal, exp(loss / wordCount));
......@@ -153,8 +155,10 @@ int T2TTrainer::LoadBuf(FILE * file)
while(fgets(line, MAX_SEQUENCE_LENGTH - 1, file)){
int len = (int)strlen(line);
if(line[len - 1] == '\r')
while(line[len - 1] == '\r' || line[len - 1] == '\n'){
line[len - 1] = 0;
len--;
}
len = (int)strlen(line);
if(len == 0)
......@@ -166,10 +170,11 @@ int T2TTrainer::LoadBuf(FILE * file)
/* how many words are in the sentence */
int wNum = 0;
int wNumLocal = 0;
int i = 0;
for(int i = 0; i < len; i++){
for(i = 0; i < len; i++){
/* load word (id) separated by space or tab */
if((line[i] == ' ' || line[i] == '\t' || i == len - 1) && wSize > 0){
if((line[i] == ' ' || line[i] == '\t') && wSize > 0){
line[i] = 0;
if(wSize == 3 && line[i - 1] == '|' && line[i - 2] == '|' && line[i - 3] == '|'){
......@@ -179,7 +184,7 @@ int T2TTrainer::LoadBuf(FILE * file)
wNumLocal = 0;
}
else{
buf[wNum++] = atoi(line + i - wSize);
buf[wordCount + wNum++] = atoi(line + i - wSize);
wNumLocal++;
}
......@@ -189,6 +194,11 @@ int T2TTrainer::LoadBuf(FILE * file)
wSize++;
}
if(wSize > 0){
buf[wordCount + wNum++] = atoi(line + i - wSize);
wNumLocal++;
}
seqLen[seqCount] = wNumLocal;
seqOffset[seqCount] = wordCount + wNum - wNumLocal;
seqCount++;
......@@ -305,4 +315,35 @@ float T2TTrainer::GetProb(XTensor * output, XTensor * gold, XTensor * wordProbs)
return result.Get1D(0);
}
/*
update the model by delta rule
>> model - the t2t model
*/
void T2TTrainer::Update(T2TModel * model)
{
XList ws(100);
ws.Add(&model->outputLayer.w);
for(int i = 0; i < model->encoder.nlayer; i++){
ws.Add(&model->encoder.fnns[i].w1);
ws.Add(&model->encoder.fnns[i].b1);
ws.Add(&model->encoder.fnns[i].w2);
ws.Add(&model->encoder.fnns[i].b2);
}
ws.Add(&model->encoder.embedder.w);
for(int i = 0; i < ws.count; i++){
XTensor * para = (XTensor*)ws.Get(i);
XTensor * paraGrad = para->grad;
CheckNTErrors(para != NULL, "NULL parameter tensor!");
CheckNTErrors(paraGrad != NULL, "NULL gradient tensor!");
/* the delta rule */
_Sum(para, paraGrad, para, -lrate);
}
}
}
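The new Update method applies plain gradient descent to every parameter tensor collected in ws. Since _Sum(a, b, c, beta) computes c = a + beta * b (see the Sum hunk further down), the call `_Sum(para, paraGrad, para, -lrate)` is exactly the delta rule named in the comment:

\theta \leftarrow \theta - \eta \, \frac{\partial E}{\partial \theta}

with learning rate \eta = lrate.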
......@@ -103,6 +103,9 @@ public:
/* get word probabilities for a batch of sequences */
float GetProb(XTensor * output, XTensor * gold, XTensor * wordProbs);
/* update the model by delta rule */
void Update(T2TModel * model);
};
......
......@@ -23,6 +23,7 @@
#include "T2TModel.h"
#include "T2TUtility.h"
#include "T2TTrainer.h"
#include "../../tensor/XDevice.h"
namespace transformer
{
......
......@@ -1042,11 +1042,11 @@ set the value of a cell in a 3d tensor in default type
*/
bool XTensor::Set3D(DTYPE value, int d0, int d1, int d2)
{
CheckNTErrors((order == 3), "Cannot get a 2d cell for a tensor whose order is not 2!");
CheckNTErrors((d0 >= 0 && d0 < dimSize[0]), "dimension 0 is out of range!");
CheckNTErrors((d2 >= 0 && d1 < dimSize[1]), "dimension 1 is out of range!");
CheckNTErrors((d2 >= 0 && d2 < dimSize[2]), "dimension 1 is out of range!");
CheckNTErrors((dataType == DEFAULT_DTYPE), "The tensor is not in default type.");
CheckNTErrors(order == 3, "Cannot set a 3d cell for a tensor whose order is not 3!");
CheckNTErrors(d0 >= 0 && d0 < dimSize[0], "dimension 0 is out of range!");
CheckNTErrors(d1 >= 0 && d1 < dimSize[1], "dimension 1 is out of range!");
CheckNTErrors(d2 >= 0 && d2 < dimSize[2], "dimension 2 is out of range!");
CheckNTErrors(dataType == DEFAULT_DTYPE, "The tensor is not in default type.");
int dims[3] = {d0, d1, d2};
......
......@@ -162,10 +162,10 @@ void _MatrixMul(const XTensor * a, MATRIX_TRANS_TYPE transposedA,
cublasHandle_t * handle = a->mem != NULL ? a->mem->GetCublasHandle() : GDevs.GetCudaHandle(a->devID);
_CudaBLASMatrixMULList(handle,
aList, transposedA,
bList, transposedB,
cList, aList->count,
alpha, beta);
aList, transposedA,
bList, transposedB,
cList, aList->count,
alpha, beta);
BacktoCudaDev(a->devID, devIDBackup);
#else
......
......@@ -117,14 +117,19 @@ void _MatrixMulBatchedGPU(const XTensor * a, MATRIX_TRANS_TYPE transposedA,
blockNum *= a->dimSizeRDI[i];
}
int devIDBackup = 0;
ProtectCudaDev(a->devID, devIDBackup);
cublasHandle_t * handle = a->mem != NULL ? a->mem->GetCublasHandle() : GDevs.GetCudaHandle(a->devID);
_CudaBLASMatrixMULBatchedStrided(handle,
a->data, transposedA, a->dataType, aBlockSize,
b->data, transposedB, b->dataType, bBlockSize,
c->data, c->dataType, cBlockSize, blockNum,
a->dimSizeRDI[1], a->dimSizeRDI[0],
b->dimSizeRDI[1], b->dimSizeRDI[0],
c->dimSizeRDI[1], c->dimSizeRDI[0], alpha, beta);
a->data, transposedA, a->dataType, aBlockSize,
b->data, transposedB, b->dataType, bBlockSize,
c->data, c->dataType, cBlockSize, blockNum,
a->dimSizeRDI[1], a->dimSizeRDI[0],
b->dimSizeRDI[1], b->dimSizeRDI[0],
c->dimSizeRDI[1], c->dimSizeRDI[0], alpha, beta);
BacktoCudaDev(a->devID, devIDBackup);
#endif
}
......
......@@ -22,6 +22,7 @@
#include "../../XTensor.h"
#include "../../XName.h"
#include "../../XUtility.h"
#include "../movement/CopyValues.h"
#include "Sum.h"
#include "Sum.cuh"
#include "SumDim.h"
......@@ -44,8 +45,12 @@ void _Sum(const XTensor * a, const XTensor * b, XTensor * c, DTYPE beta)
CheckNTErrors(a->dataType == b->dataType && a->dataType == c->dataType,
"Unmatched tensors in addition!");
if (a->devID >= 0 || b->devID >= 0 || c->devID >= 0) {
if(beta == 0){
_CopyValues(a, c);
return;
}
if (a->devID >= 0 || b->devID >= 0 || c->devID >= 0) {
#ifdef USE_CUDA
if (a == c) {
int P2PAccesible = 0;
......
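The early return added to _Sum follows directly from its definition (and is why CopyValues.h is now included): 

c_i = a_i + \beta\, b_i \;\;\Rightarrow\;\; c_i = a_i \text{ when } \beta = 0

so when beta is 0 the result is just a copy of a, and _CopyValues(a, c) gives the answer without launching an addition kernel.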
......@@ -110,7 +110,7 @@ void _CudaNormalize(const XTensor * input, XTensor * output, int dim,
int cudaBlockSize[3];
GDevs.GetCudaThread2D(input->devID, strideNum, stride * blockNum,
MAX_INT, cudaGridSize, cudaBlockSize);
MAX_INT, cudaGridSize, cudaBlockSize);
dim3 blocks(cudaGridSize[1], cudaGridSize[0]);
dim3 threads(cudaBlockSize[1], cudaBlockSize[0]);
......@@ -119,9 +119,9 @@ void _CudaNormalize(const XTensor * input, XTensor * output, int dim,
ProtectCudaDev(a->devID, devIDBackup);
KernelNormalize << <blocks, threads >> >((DTYPE*)input->data, (DTYPE*)output->data,
(DTYPE*)mean->data, (DTYPE*)var->data,
(DTYPE*)a->data, (DTYPE*)b->data, epsilon,
stride, strideNum, blockNum);
(DTYPE*)mean->data, (DTYPE*)var->data,
(DTYPE*)a->data, (DTYPE*)b->data, epsilon,
stride, strideNum, blockNum);
BacktoCudaDev(a->devID, devIDBackup);
}
......
......@@ -109,6 +109,9 @@ void _CudaMergeBlockLists(const XList * sourceList, int * blockSizes, int blockN
CheckNTErrors((maxBlockSize % sizeof(DTYPE) == 0), "Unsupported block size!");
realMaxBlockSize = maxBlockSize / sizeof(DTYPE);
int devIDBackup;
ProtectCudaDev(myMem->devID, devIDBackup);
int cudaGridSizes[3];
int cudaBlockSizes[3];
......@@ -135,6 +138,8 @@ void _CudaMergeBlockLists(const XList * sourceList, int * blockSizes, int blockN
delete[] targetArrays;
delete[] sizes;
delete[] offsets;
BacktoCudaDev(myMem->devID, devIDBackup);
}
#endif // USE_CUDA
......
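Several of the CUDA hunks in this commit (batched matrix multiplication, block-list merging, and the softmax backward below) wrap their work in a ProtectCudaDev/BacktoCudaDev pair so that kernel launches and cuBLAS calls run on the tensor's device and the caller's current device is restored afterwards. A hedged sketch of the guard pattern as used in these hunks (SomeCudaOp and its body are illustrative):

/* illustrative only: bracket device-specific work with the guard macros used above */
void SomeCudaOp(const XTensor * a)
{
    int devIDBackup = 0;
    ProtectCudaDev(a->devID, devIDBackup);   /* switch to a's device, remembering the current one */

    /* ... kernel launches and cuBLAS calls that must run on a->devID ... */

    BacktoCudaDev(a->devID, devIDBackup);    /* restore the caller's device */
}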
......@@ -150,11 +150,10 @@ void _LogSoftmax(const XTensor * x, XTensor * y, int leadDim)
}
}
if (x->devID < 0) {
DelTensorBuf(max);
DelTensorBuf(sum);
}
else {
DelTensorBuf(max);
DelTensorBuf(sum);
if (x->devID >= 0) {
delete blockx;
delete blocky;
delete blockMax;
......
......@@ -239,6 +239,9 @@ void _CudaSoftmaxBackward(XTensor * gold, XTensor * y, XTensor * x,
CheckNTErrors((x->devID == y->devID), "Matrices used in log softmax are not on the same GPU.");
CheckNTErrors((y->order >= 1), "Empty tensor!");
int devIDBackup;
ProtectCudaDev(x->devID, devIDBackup);
if(x->dataType == DEFAULT_DTYPE && y->dataType == DEFAULT_DTYPE){
CheckNTErrors((lossName == CROSSENTROPY ||
......@@ -284,8 +287,14 @@ void _CudaSoftmaxBackward(XTensor * gold, XTensor * y, XTensor * x,
/* make a matrix to keep \beta */
XTensor * beta = new XTensor(y->order - 1, dimSize, y->dataType, y->denseRatio, y->devID, mem);
ytmp->data = mem->AllocBuf(mem->devID, y->unitNum * y->unitSize);
beta->data = mem->AllocBuf(mem->devID, beta->unitNum * beta->unitSize);
if(mem != NULL){
ytmp->data = mem->AllocBuf(mem->devID, y->unitNum * y->unitSize);
beta->data = mem->AllocBuf(mem->devID, beta->unitNum * beta->unitSize);
}
else{
ytmp->data = XMemAlloc(y->devID, y->unitNum * y->unitSize);
beta->data = XMemAlloc(y->devID, beta->unitNum * beta->unitSize);
}
/* \beta = \sum_i (dE/dy_i * y_i) */
_Multiply(dedy, y, ytmp, 0, 0);
......@@ -298,8 +307,18 @@ void _CudaSoftmaxBackward(XTensor * gold, XTensor * y, XTensor * x,
/* dE/ds_j = y_j * ytmp = y_j * (dE/dy_j - \beta) */
_Multiply(y, ytmp, dedx, 0, 0);
mem->ReleaseBuf(mem->devID, y->unitNum * y->unitSize);
mem->ReleaseBuf(mem->devID, beta->unitNum * beta->unitSize);
if(mem != NULL){
mem->ReleaseBuf(mem->devID, y->unitNum * y->unitSize);
mem->ReleaseBuf(mem->devID, beta->unitNum * beta->unitSize);
}
else{
XMemFree(y->devID, ytmp->data);
XMemFree(y->devID, beta->data);
}
ytmp->data = NULL;
beta->data = NULL;
delete[] dimSize;
delete ytmp;
......@@ -311,6 +330,8 @@ void _CudaSoftmaxBackward(XTensor * gold, XTensor * y, XTensor * x,
}
else
ShowNTErrors("TODO!");
BacktoCudaDev(x->devID, devIDBackup);
}
#endif
......
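In the notation of the in-code comments, the quantities held by the temporary buffers ytmp and beta in _CudaSoftmaxBackward are, as a LaTeX sketch,

\beta = \sum_i \frac{\partial E}{\partial y_i}\, y_i, \qquad \frac{\partial E}{\partial s_j} = y_j\!\left(\frac{\partial E}{\partial y_j} - \beta\right)

The hunk does not change this computation, only where those buffers come from: the memory pool when mem is set, XMemAlloc/XMemFree otherwise, with data reset to NULL before the temporary XTensor wrappers are deleted so their destructors do not free memory they do not own.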