Commit 66855df5 by xiaotong

refine some bad code

parent 2046cd23
...@@ -45,7 +45,7 @@ int main( int argc, const char ** argv ) ...@@ -45,7 +45,7 @@ int main( int argc, const char ** argv )
//_CrtSetBreakAlloc(123); //_CrtSetBreakAlloc(123);
/* a tiny test */ /* a tiny test */
if(true) if(false)
SmallTest(); SmallTest();
//_CrtDumpMemoryLeaks(); //_CrtDumpMemoryLeaks();
...@@ -53,8 +53,8 @@ int main( int argc, const char ** argv ) ...@@ -53,8 +53,8 @@ int main( int argc, const char ** argv )
if(argc > 1 && !strcmp(argv[1], "-test")) if(argc > 1 && !strcmp(argv[1], "-test"))
Test(); Test();
//else if(argc > 1 && !strcmp(argv[1], "-fnnlm")) else if(argc > 1 && !strcmp(argv[1], "-fnnlm"))
// FNNLMMain(argc - 1, argv + 1); FNNLMMain(argc - 1, argv + 1);
else{ else{
fprintf(stderr, "Thanks for using NiuTrans.Tensor! This is a library that eases the\n"); fprintf(stderr, "Thanks for using NiuTrans.Tensor! This is a library that eases the\n");
fprintf(stderr, "use of tensors. All you need is to ... \n\n"); fprintf(stderr, "use of tensors. All you need is to ... \n\n");
......
...@@ -210,7 +210,7 @@ Obviously C = A * B performs normal matrix multiplication if A = y * z and B = x ...@@ -210,7 +210,7 @@ Obviously C = A * B performs normal matrix multiplication if A = y * z and B = x
XTensor MatrixMul(const XTensor &a, MATRIX_TRANS_TYPE transposedA, const XTensor &b, MATRIX_TRANS_TYPE transposedB, XTensor MatrixMul(const XTensor &a, MATRIX_TRANS_TYPE transposedA, const XTensor &b, MATRIX_TRANS_TYPE transposedB,
DTYPE alpha, DTYPE beta, XPRunner * parallelRunner) DTYPE alpha, DTYPE beta, XPRunner * parallelRunner)
{ {
CheckNTErrors((&a && &b), "Empty input tensors!"); CheckNTErrors((&a != &NULLTensor && &b != &NULLTensor), "Empty input tensors!");
CheckNTErrors((a.dataType == b.dataType), "Input tensors should have the same data type!"); CheckNTErrors((a.dataType == b.dataType), "Input tensors should have the same data type!");
CheckNTErrors((a.order >= 2 && b.order >= 2), "Input tensors must have a order >= 2!"); CheckNTErrors((a.order >= 2 && b.order >= 2), "Input tensors must have a order >= 2!");
...@@ -246,7 +246,7 @@ XTensor MatrixMul(const XTensor &a, MATRIX_TRANS_TYPE transposedA, const XTensor ...@@ -246,7 +246,7 @@ XTensor MatrixMul(const XTensor &a, MATRIX_TRANS_TYPE transposedA, const XTensor
XLink::AddParamToHead(&c, beta); XLink::AddParamToHead(&c, beta);
/* destroy variables */ /* destroy variables */
delete dimSize; delete[] dimSize;
return c; return c;
} }
......
...@@ -175,9 +175,9 @@ where trans() returns the transposed matrix if the flag is fired. ...@@ -175,9 +175,9 @@ where trans() returns the transposed matrix if the flag is fired.
XTensor MatrixMulBatched(const XTensor &a, MATRIX_TRANS_TYPE transposedA, const XTensor &b, MATRIX_TRANS_TYPE transposedB, XTensor MatrixMulBatched(const XTensor &a, MATRIX_TRANS_TYPE transposedA, const XTensor &b, MATRIX_TRANS_TYPE transposedB,
DTYPE alpha, DTYPE beta, XPRunner * parallelRunner) DTYPE alpha, DTYPE beta, XPRunner * parallelRunner)
{ {
CheckNTErrors((&a && &b), "Empty input tensors!"); CheckNTErrors(&a != &NULLTensor && &b != &NULLTensor, "Empty input tensors!");
CheckNTErrors(a.dataType == b.dataType, "Input tensors should have the same data type!"); CheckNTErrors(a.dataType == b.dataType, "Input tensors should have the same data type!");
CheckNTErrors((a.order >= 2 && b.order >= 2), "Input tensors must have a order >= 2!"); CheckNTErrors(a.order >= 2 && b.order >= 2, "Input tensors must have a order >= 2!");
CheckNTErrors(a.order == b.order, "Input tensor and output tensor must have same order!"); CheckNTErrors(a.order == b.order, "Input tensor and output tensor must have same order!");
int an = transposedA == X_TRANS ? a.dimSizeRDI[0] : a.dimSizeRDI[1]; int an = transposedA == X_TRANS ? a.dimSizeRDI[0] : a.dimSizeRDI[1];
...@@ -210,7 +210,7 @@ XTensor MatrixMulBatched(const XTensor &a, MATRIX_TRANS_TYPE transposedA, const ...@@ -210,7 +210,7 @@ XTensor MatrixMulBatched(const XTensor &a, MATRIX_TRANS_TYPE transposedA, const
XLink::AddParamToHead(&c, beta); XLink::AddParamToHead(&c, beta);
/* destroy variables */ /* destroy variables */
delete dimSize; delete[] dimSize;
return c; return c;
} }
......
...@@ -123,7 +123,7 @@ XTensor SelectRange(const XTensor &a, int dim, int low, int high) ...@@ -123,7 +123,7 @@ XTensor SelectRange(const XTensor &a, int dim, int low, int high)
XLink::AddParamToHead(&c, high); XLink::AddParamToHead(&c, high);
/* destroy variables */ /* destroy variables */
delete dimSize; delete[] dimSize;
return c; return c;
} }
......
...@@ -131,7 +131,7 @@ XTensor CopyIndexed(const XTensor &s, int dim, int * srcIndex, int indexSize, in ...@@ -131,7 +131,7 @@ XTensor CopyIndexed(const XTensor &s, int dim, int * srcIndex, int indexSize, in
_CopyIndexed(&s, &t, dim, srcIndex, indexSize, tgtIndex, copyNum); _CopyIndexed(&s, &t, dim, srcIndex, indexSize, tgtIndex, copyNum);
/* destroy variables */ /* destroy variables */
delete dimSize; delete[] dimSize;
/* tensor connection */ /* tensor connection */
XLink::MakeLink(&s, NULL, &t, MOVEMENT_COPYINDEXED); XLink::MakeLink(&s, NULL, &t, MOVEMENT_COPYINDEXED);
......
...@@ -121,7 +121,7 @@ XTensor ReduceMax(const XTensor &input, int dim) ...@@ -121,7 +121,7 @@ XTensor ReduceMax(const XTensor &input, int dim)
_ReduceMax(&input, &output, dim); _ReduceMax(&input, &output, dim);
/* destroy variables */ /* destroy variables */
delete dimSize; delete[] dimSize;
/* tensor connection */ /* tensor connection */
XLink::MakeLink(&input, NULL, &output, REDUCE_REDUCEMAX); XLink::MakeLink(&input, NULL, &output, REDUCE_REDUCEMAX);
......
...@@ -82,9 +82,9 @@ XTensor ReduceMean(const XTensor &input, int dim) ...@@ -82,9 +82,9 @@ XTensor ReduceMean(const XTensor &input, int dim)
XLink::AddParamToHead(&output, dim); XLink::AddParamToHead(&output, dim);
/* destroy variables */ /* destroy variables */
delete dimSize; delete[] dimSize;
return output; return output;
} }
} // namespace nts(NiuTrans.Tensor) } // namespace nts(NiuTrans.Tensor)
\ No newline at end of file
...@@ -239,7 +239,7 @@ XTensor ReduceSum(const XTensor &input, int dim, const XTensor &shift, DTYPE pow ...@@ -239,7 +239,7 @@ XTensor ReduceSum(const XTensor &input, int dim, const XTensor &shift, DTYPE pow
XLink::AddParamToHead(&output, power); XLink::AddParamToHead(&output, power);
/* destroy variables */ /* destroy variables */
delete dimSize; delete[] dimSize;
return output; return output;
} }
......
...@@ -78,9 +78,9 @@ XTensor ReduceSumSquared(const XTensor &input, int dim, const XTensor &shift) ...@@ -78,9 +78,9 @@ XTensor ReduceSumSquared(const XTensor &input, int dim, const XTensor &shift)
XLink::AddParamToHead(&output, dim); XLink::AddParamToHead(&output, dim);
/* destroy variables */ /* destroy variables */
delete dimSize; delete[] dimSize;
return output; return output;
} }
} // namespace nts(NiuTrans.Tensor) } // namespace nts(NiuTrans.Tensor)
\ No newline at end of file
...@@ -76,7 +76,7 @@ XTensor ReduceVariance(const XTensor &input, int dim, const XTensor &mean) ...@@ -76,7 +76,7 @@ XTensor ReduceVariance(const XTensor &input, int dim, const XTensor &mean)
_ReduceVariance(&input, &output, dim, &mean); _ReduceVariance(&input, &output, dim, &mean);
/* destroy variables */ /* destroy variables */
delete dimSize; delete[] dimSize;
return output; return output;
} }
......
...@@ -188,7 +188,7 @@ XTensor Merge(const XTensor &s, int whereToMerge, int leadingDim) ...@@ -188,7 +188,7 @@ XTensor Merge(const XTensor &s, int whereToMerge, int leadingDim)
_Merge(&s, &t, whereToMerge, leadingDim); _Merge(&s, &t, whereToMerge, leadingDim);
/* destroy variables */ /* destroy variables */
delete dimSize; delete[] dimSize;
return t; return t;
} }
...@@ -335,7 +335,7 @@ XTensor Merge(const XList &smalls, int whereToMerge) ...@@ -335,7 +335,7 @@ XTensor Merge(const XList &smalls, int whereToMerge)
_Merge(&smalls, &big, whereToMerge); _Merge(&smalls, &big, whereToMerge);
/* destroy variables */ /* destroy variables */
delete dimSize; delete[] dimSize;
return big; return big;
} }
......
...@@ -162,7 +162,7 @@ XTensor Split(const XTensor &s, int whereToSplit, int splitNum) ...@@ -162,7 +162,7 @@ XTensor Split(const XTensor &s, int whereToSplit, int splitNum)
_Split(&s, &t, whereToSplit, splitNum); _Split(&s, &t, whereToSplit, splitNum);
/* destroy variables */ /* destroy variables */
delete dimSize; delete[] dimSize;
return t; return t;
} }
...@@ -308,7 +308,7 @@ XList SplitList(const XTensor &big, int whereToSplit, int splitNum) ...@@ -308,7 +308,7 @@ XList SplitList(const XTensor &big, int whereToSplit, int splitNum)
_Split(&big, &smalls, whereToSplit, splitNum); _Split(&big, &smalls, whereToSplit, splitNum);
/* destroy variables */ /* destroy variables */
delete dimSize; delete[] dimSize;
return smalls; return smalls;
} }
......
...@@ -130,7 +130,7 @@ XTensor Unsqueeze(const XTensor &a, int dim, int dSize) ...@@ -130,7 +130,7 @@ XTensor Unsqueeze(const XTensor &a, int dim, int dSize)
_Unsqueeze(&a, &b, dim, dSize); _Unsqueeze(&a, &b, dim, dSize);
/* destroy variables */ /* destroy variables */
delete dimSize; delete[] dimSize;
return b; return b;
} }
......
...@@ -178,7 +178,11 @@ DTYPE LossCompute(XTensor * gold, XTensor * output, LOSS_FUNCTION_NAME LFName, ...@@ -178,7 +178,11 @@ DTYPE LossCompute(XTensor * gold, XTensor * output, LOSS_FUNCTION_NAME LFName,
} }
} }
else { else {
#ifdef USE_CUDA
error = CudaLossCompute(gold, output, LFName, isLogOutput, leadDim, gBeg, gLen, oBeg); error = CudaLossCompute(gold, output, LFName, isLogOutput, leadDim, gBeg, gLen, oBeg);
#else
ShowNTErrors("Please specify USE_CUDA and recompile the code!");
#endif
} }
return error; return error;
...@@ -476,7 +480,11 @@ void LossBackward(XTensor * dedy, XTensor * t, XTensor * y, ...@@ -476,7 +480,11 @@ void LossBackward(XTensor * dedy, XTensor * t, XTensor * y,
} }
} }
else { else {
#ifdef USE_CUDA
CudaLossBackward(dedy, t, y, LFName, leadDim, tBeg, tLen, yBeg); CudaLossBackward(dedy, t, y, LFName, leadDim, tBeg, tLen, yBeg);
#else
ShowNTErrors("Please specify USE_CUDA and recompile the code!");
#endif
} }
} }
......
...@@ -490,9 +490,9 @@ float GetProb(XTensor &output, XTensor &gold, XTensor * wordProbs) ...@@ -490,9 +490,9 @@ float GetProb(XTensor &output, XTensor &gold, XTensor * wordProbs)
/* probability of each word */ /* probability of each word */
XTensor wprobs; XTensor wprobs;
InitTensor1D(&wprobs, output.GetDim(0), output.dataType, output.devID, output.mem); InitTensor1D(&wprobs, output.GetDim(0), output.dataType, output.devID, output.mem);
ReduceSum(&probs, &wprobs, 1); _ReduceSum(&probs, &wprobs, 1);
if(wordProbs != NULL) if(wordProbs != NULL)
CopyValues(&wprobs, wordProbs); _CopyValues(&wprobs, wordProbs);
/* reshape the tensor to fit it into the reduce procedure /* reshape the tensor to fit it into the reduce procedure
TODO: XTensor supports scalars */ TODO: XTensor supports scalars */
...@@ -504,7 +504,7 @@ float GetProb(XTensor &output, XTensor &gold, XTensor * wordProbs) ...@@ -504,7 +504,7 @@ float GetProb(XTensor &output, XTensor &gold, XTensor * wordProbs)
/* probability for the batch */ /* probability for the batch */
XTensor result; XTensor result;
InitTensor1D(&result, 1, X_FLOAT, output.devID, output.mem); InitTensor1D(&result, 1, X_FLOAT, output.devID, output.mem);
ReduceSum(&probs, &result, 1); _ReduceSum(&probs, &result, 1);
return result.Get1D(0); return result.Get1D(0);
} }
...@@ -673,7 +673,7 @@ void Forward(XTensor inputs[], XTensor &output, FNNModel &model, FNNNet &net) ...@@ -673,7 +673,7 @@ void Forward(XTensor inputs[], XTensor &output, FNNModel &model, FNNNet &net)
/* generate word embedding of position i: /* generate word embedding of position i:
embedding = input * w */ embedding = input * w */
MatrixMul(&input, X_NOTRANS, &w, X_NOTRANS, &embedding); _MatrixMul(&input, X_NOTRANS, &w, X_NOTRANS, &embedding);
eList.Add(&net.embeddings[i]); eList.Add(&net.embeddings[i]);
} }
...@@ -681,7 +681,7 @@ void Forward(XTensor inputs[], XTensor &output, FNNModel &model, FNNNet &net) ...@@ -681,7 +681,7 @@ void Forward(XTensor inputs[], XTensor &output, FNNModel &model, FNNNet &net)
/* concatenate word embeddings /* concatenate word embeddings
embeddingcat = cat(embedding_0...embedding_{n-1}) */ embeddingcat = cat(embedding_0...embedding_{n-1}) */
InitModelTensor2D(net.embeddingCat, batchSize, (n - 1) * model.eSize, model); InitModelTensor2D(net.embeddingCat, batchSize, (n - 1) * model.eSize, model);
Concatenate(&eList, &net.embeddingCat, 1); _Concatenate(&eList, &net.embeddingCat, 1);
/* go over each hidden layer */ /* go over each hidden layer */
for(int i = 0; i < depth; i++){ for(int i = 0; i < depth; i++){
...@@ -696,12 +696,12 @@ void Forward(XTensor inputs[], XTensor &output, FNNModel &model, FNNNet &net) ...@@ -696,12 +696,12 @@ void Forward(XTensor inputs[], XTensor &output, FNNModel &model, FNNNet &net)
/* generate hidden states of layer i: /* generate hidden states of layer i:
s = h_pre * w */ s = h_pre * w */
MatrixMul(&h_pre, X_NOTRANS, &w, X_NOTRANS, &s); _MatrixMul(&h_pre, X_NOTRANS, &w, X_NOTRANS, &s);
/* make a 2d tensor for the bias term */ /* make a 2d tensor for the bias term */
XTensor b2D; XTensor b2D;
InitTensor(&b2D, &s); InitTensor(&b2D, &s);
Unsqueeze(&b, &b2D, 0, batchSize); _Unsqueeze(&b, &b2D, 0, batchSize);
/* introduce bias term: /* introduce bias term:
s = s + b s = s + b
...@@ -711,7 +711,7 @@ void Forward(XTensor inputs[], XTensor &output, FNNModel &model, FNNNet &net) ...@@ -711,7 +711,7 @@ void Forward(XTensor inputs[], XTensor &output, FNNModel &model, FNNNet &net)
/* pass the state through the hard tanh function: /* pass the state through the hard tanh function:
h = tanh(s) */ h = tanh(s) */
HardTanH(&s, &h); _HardTanH(&s, &h);
} }
/* generate the output Pr(w_{n-1}|w_0...w_{n-2}): /* generate the output Pr(w_{n-1}|w_0...w_{n-2}):
...@@ -729,16 +729,16 @@ void Forward(XTensor inputs[], XTensor &output, FNNModel &model, FNNNet &net) ...@@ -729,16 +729,16 @@ void Forward(XTensor inputs[], XTensor &output, FNNModel &model, FNNNet &net)
InitModelTensor2D(y, batchSize, model.vSize, model); InitModelTensor2D(y, batchSize, model.vSize, model);
/* s = h_last * w */ /* s = h_last * w */
MatrixMul(&h_last, X_NOTRANS, &w, X_NOTRANS, &s); _MatrixMul(&h_last, X_NOTRANS, &w, X_NOTRANS, &s);
XTensor b2D; XTensor b2D;
InitTensor(&b2D, &s); InitTensor(&b2D, &s);
Unsqueeze(&b, &b2D, 0, batchSize); _Unsqueeze(&b, &b2D, 0, batchSize);
_Sum(&s, &b2D, &s); _Sum(&s, &b2D, &s);
/* y = softmax(s) */ /* y = softmax(s) */
LogSoftmax(&s, &y, 1); _LogSoftmax(&s, &y, 1);
} }
...@@ -782,18 +782,18 @@ void Backward(XTensor inputs[], XTensor &output, XTensor &gold, LOSS_FUNCTION_NA ...@@ -782,18 +782,18 @@ void Backward(XTensor inputs[], XTensor &output, XTensor &gold, LOSS_FUNCTION_NA
x is the top most hidden layer) x is the top most hidden layer)
so we know so we know
dE/dw = x^T * dE/ds */ dE/dw = x^T * dE/ds */
MatrixMul(&x, X_TRANS, &deds, X_NOTRANS, &dedw); _MatrixMul(&x, X_TRANS, &deds, X_NOTRANS, &dedw);
/* gradient of the bias: dE/db = dE/ds * 1 = dE/ds /* gradient of the bias: dE/db = dE/ds * 1 = dE/ds
specifically dE/db_{j} = \sum_{i} dE/ds_{i,j} */ specifically dE/db_{j} = \sum_{i} dE/ds_{i,j} */
ReduceSum(&deds, &dedb, 0); _ReduceSum(&deds, &dedb, 0);
/* then, we compute /* then, we compute
dE/dx_{j} = \sum_j' (dE/ds_{j'} * ds_{j'}/dx_j) dE/dx_{j} = \sum_j' (dE/ds_{j'} * ds_{j'}/dx_j)
= \sum_j' (dE/ds_{j'} * w_{j, j'}) = \sum_j' (dE/ds_{j'} * w_{j, j'})
i.e., i.e.,
dE/dx = dE/ds * w^T */ dE/dx = dE/ds * w^T */
MatrixMul(&deds, X_NOTRANS, &w, X_TRANS, &dedx); _MatrixMul(&deds, X_NOTRANS, &w, X_TRANS, &dedx);
XTensor &gradPassed = dedx; XTensor &gradPassed = dedx;
XTensor dedsHidden; XTensor dedsHidden;
...@@ -821,17 +821,17 @@ void Backward(XTensor inputs[], XTensor &output, XTensor &gold, LOSS_FUNCTION_NA ...@@ -821,17 +821,17 @@ void Backward(XTensor inputs[], XTensor &output, XTensor &gold, LOSS_FUNCTION_NA
HardTanHBackward(NULL, &h, &s, &dedh, &deds, NOLOSS); HardTanHBackward(NULL, &h, &s, &dedh, &deds, NOLOSS);
/* gradient of the weight: dE/dw = x^T * dE/ds */ /* gradient of the weight: dE/dw = x^T * dE/ds */
MatrixMul(&x, X_TRANS, &deds, X_NOTRANS, &dedw); _MatrixMul(&x, X_TRANS, &deds, X_NOTRANS, &dedw);
/* gradient of the bias: dE/db = dE/ds * 1 = dE/ds /* gradient of the bias: dE/db = dE/ds * 1 = dE/ds
specifically dE/db_{j} = \sum_{i} dE/ds_{i,j} */ specifically dE/db_{j} = \sum_{i} dE/ds_{i,j} */
ReduceSum(&deds, &dedb, 0); _ReduceSum(&deds, &dedb, 0);
/* gradient of the input: dE/dx = dE/ds * w^T */ /* gradient of the input: dE/dx = dE/ds * w^T */
MatrixMul(&deds, X_NOTRANS, &w, X_TRANS, &dedx); _MatrixMul(&deds, X_NOTRANS, &w, X_TRANS, &dedx);
if (i > 0) if (i > 0)
CopyValues(&dedx, &gradPassed); _CopyValues(&dedx, &gradPassed);
} }
XList eList(n - 1); XList eList(n - 1);
...@@ -846,7 +846,7 @@ void Backward(XTensor inputs[], XTensor &output, XTensor &gold, LOSS_FUNCTION_NA ...@@ -846,7 +846,7 @@ void Backward(XTensor inputs[], XTensor &output, XTensor &gold, LOSS_FUNCTION_NA
XTensor &dedyCat = depth > 0 ? dedxBottom : dedx; XTensor &dedyCat = depth > 0 ? dedxBottom : dedx;
/* split the concatenation of gradients of the embeddings */ /* split the concatenation of gradients of the embeddings */
Split(&dedyCat, &eList, 1, n - 1); _Split(&dedyCat, &eList, 1, n - 1);
/* go over for each word */ /* go over for each word */
for (int i = 0; i < n - 1; i++) { for (int i = 0; i < n - 1; i++) {
...@@ -857,7 +857,7 @@ void Backward(XTensor inputs[], XTensor &output, XTensor &gold, LOSS_FUNCTION_NA ...@@ -857,7 +857,7 @@ void Backward(XTensor inputs[], XTensor &output, XTensor &gold, LOSS_FUNCTION_NA
/* gradient of the embedding weight: dE/dw += x^T * dE/dy /* gradient of the embedding weight: dE/dw += x^T * dE/dy
NOTE that we accumulate dE/dw here because the matrix w NOTE that we accumulate dE/dw here because the matrix w
is shared by several layers (or words) */ is shared by several layers (or words) */
MatrixMul(&x, X_TRANS, dedy, X_NOTRANS, &dedw, 1.0F, 1.0F); _MatrixMul(&x, X_TRANS, dedy, X_NOTRANS, &dedw, 1.0F, 1.0F);
delete dedy; delete dedy;
} }
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论