Commit ecc4041d by xiaotong

Merge branch 'xiaotong-working' of 47.105.50.196:NiuTrans/NiuTrans.Tensor into xiaotong-working

parents 4a609624 03064396
...@@ -61,7 +61,7 @@ void T2TEmbedder::InitModel(int argc, char ** argv, int myDevID, XMem * myMem) ...@@ -61,7 +61,7 @@ void T2TEmbedder::InitModel(int argc, char ** argv, int myDevID, XMem * myMem)
InitTensor2D(&w, vSize, eSize, X_FLOAT, devID, mem); InitTensor2D(&w, vSize, eSize, X_FLOAT, devID, mem);
DTYPE v = 1.0F/(float)sqrt((float)eSize); DTYPE v = 1.0F/(float)sqrt((float)eSize);
w.SetDataRand(-v, v); w.SetDataRandn(0, v);
/* create the positional embedding matrix */ /* create the positional embedding matrix */
MakePosEmbedding(eSize, d, maxLength); MakePosEmbedding(eSize, d, maxLength);
......
...@@ -59,10 +59,7 @@ void T2TLN::InitModel(int argc, char ** argv, int myDevID, XMem * myMem) ...@@ -59,10 +59,7 @@ void T2TLN::InitModel(int argc, char ** argv, int myDevID, XMem * myMem)
InitTensor1D(&w, d, X_FLOAT, devID, mem); InitTensor1D(&w, d, X_FLOAT, devID, mem);
InitTensor1D(&b, d, X_FLOAT, devID, mem); InitTensor1D(&b, d, X_FLOAT, devID, mem);
float scale = 1.0F; w.SetDataRand(1.0F, 1.0F);
float finfout = (float)sqrt(6.0F * scale / d);
w.SetDataRand(-finfout, finfout);
b.SetZeroAll(); b.SetZeroAll();
} }
......
...@@ -66,6 +66,9 @@ void T2TOutput::InitModel(int argc, char ** argv, int myDevID, XMem * myMem) ...@@ -66,6 +66,9 @@ void T2TOutput::InitModel(int argc, char ** argv, int myDevID, XMem * myMem)
float scale = 1.0F; float scale = 1.0F;
float finfout = (float)sqrt(6.0F * scale/(hSize + vSize)); float finfout = (float)sqrt(6.0F * scale/(hSize + vSize));
w.SetDataRand(-finfout, finfout); w.SetDataRand(-finfout, finfout);
DTYPE v = 1.0F/(float)sqrt((float)hSize);
w.SetDataRandn(0, v);
} }
/* /*
......
...@@ -217,6 +217,9 @@ void T2TTrainer::Train(const char * fn, const char * validFN, const char * model ...@@ -217,6 +217,9 @@ void T2TTrainer::Train(const char * fn, const char * validFN, const char * model
if (output.GetDim(0) > 1) if (output.GetDim(0) > 1)
PadOutput(&output, &gold, &padding); PadOutput(&output, &gold, &padding);
//output.Dump(tmpFILE, "output: ");
//fflush(tmpFILE);
/* get probabilities */ /* get probabilities */
float prob = GetProb(&output, &gold, NULL); float prob = GetProb(&output, &gold, NULL);
DTYPE lossLocal = -prob / wc; DTYPE lossLocal = -prob / wc;
......
Markdown 格式
0%
您添加了 0 人到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 登录 后发表评论