Commit 21892dbf by xiaotong

softmax output -> logsoftmax output

parent 8f2d16f5
...@@ -93,8 +93,8 @@ void T2TOutput::Make(XTensor &input, XTensor &output)
 {
     XTensor &x = input;

-    //output = LogSoftmax(MMul(x, w), -1);
-    output = Softmax(MMul(x, w), -1);
+    output = LogSoftmax(MMul(x, w), -1);
+    //output = Softmax(MMul(x, w), -1);
 }

 }
...@@ -218,8 +218,8 @@ void T2TTrainer::Train(const char * fn, const char * validFN, const char * model
     LabelSmooth(&gold, &goldSmoothed, labelSmoothingP);

     /* make paddings for the output */
-    //if (output.GetDim(0) > 1)
-    //    PadOutput(&output, &gold, &paddingDec);
+    if (output.GetDim(0) > 1)
+        PadOutput(&output, &gold, &paddingDec);

     /* get probabilities */
     float prob = GetProb(&output, &gold, NULL);
...@@ -232,7 +232,7 @@ void T2TTrainer::Train(const char * fn, const char * validFN, const char * model
     if (doUpdate) {

         /* recale the output for normalized loss */
-        //RescaleOutput(&output, &g, &paddingDec);
+        RescaleOutput(&output, &g, &paddingDec);

         /* back-propagation */
         net.Backward(output, g, paddingDec, CROSSENTROPY);
...@@ -977,12 +977,12 @@ float T2TTrainer::GetProb(XTensor * output, XTensor * gold, XTensor * wordProbs)
     XTensor probs;
     InitTensor(&probs, output);

-    XTensor logOutput;
-    InitTensor(&logOutput, output);
-    _Log(output, &logOutput);
+    /*XTensor logOutput;
+    InitTensor(&logOutput, output);
+    _Log(output, &logOutput);*/

     /* probs[i,j] = output[i,j] * gold[i,j] */
-    _Multiply(&logOutput, gold, &probs);
+    _Multiply(output, gold, &probs);

     /* probability of each word */
     XTensor wprobs;
...
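For readers following the change: LogSoftmax yields log-probabilities directly, since logsoftmax(x)_i = x_i - log(sum_j exp(x_j)) = log(softmax(x)_i). That is why GetProb can now multiply the output by the gold distribution as-is and drop the separate _Log pass that the old Softmax output required. A minimal stand-alone sketch of this equivalence follows (plain C++, not the NiuTensor API; LogSoftmaxRow and the toy data are ours for illustration):

    #include <algorithm>
    #include <cmath>
    #include <cstdio>
    #include <vector>

    /* Numerically stable log-softmax over one row of logits:
       logsoftmax(x)_i = (x_i - m) - log(sum_j exp(x_j - m)), with m = max(x).
       Mathematically equal to log(softmax(x)_i), but computed in one pass. */
    static std::vector<double> LogSoftmaxRow(const std::vector<double> &x)
    {
        double m = x[0];
        for (double v : x)
            m = std::max(m, v);

        double sum = 0.0;
        for (double v : x)
            sum += std::exp(v - m);
        double logSum = std::log(sum);

        std::vector<double> y(x.size());
        for (size_t i = 0; i < x.size(); ++i)
            y[i] = (x[i] - m) - logSum;
        return y;
    }

    int main()
    {
        std::vector<double> logits = {2.0, 1.0, 0.1};   /* toy output scores */
        std::vector<double> gold   = {1.0, 0.0, 0.0};   /* one-hot gold label */

        std::vector<double> logProbs = LogSoftmaxRow(logits);

        /* sum_i logOutput[i] * gold[i], i.e. the log-probability of the
           gold word -- the same product GetProb now forms directly, since
           the output already lives in log space */
        double logProb = 0.0;
        for (size_t i = 0; i < logits.size(); ++i)
            logProb += logProbs[i] * gold[i];

        std::printf("log P(gold) = %f\n", logProb);
        return 0;
    }

Beyond saving one temporary tensor, fusing the log into the softmax avoids taking the log of probabilities that have underflowed to zero, which is the usual reason cross-entropy training works with log-softmax output.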