Commit 1e3c89a0 by xiaotong

normalize the loss

parent 7cf73b1e
...@@ -215,6 +215,10 @@ void T2TTrainer::Train(const char * fn, const char * validFN, const char * model ...@@ -215,6 +215,10 @@ void T2TTrainer::Train(const char * fn, const char * validFN, const char * model
XTensor &g = labelSmoothingP > 0 ? goldSmoothed : gold; XTensor &g = labelSmoothingP > 0 ? goldSmoothed : gold;
if (doUpdate) { if (doUpdate) {
/* recale the output for normalized loss */
RescaleOutput(&output, &g, &padding);
/* back-propagation */ /* back-propagation */
net.Backward(output, g, CROSSENTROPY); net.Backward(output, g, CROSSENTROPY);
...@@ -870,6 +874,30 @@ void T2TTrainer::PadOutput(XTensor * output, XTensor * gold, XTensor * padding) ...@@ -870,6 +874,30 @@ void T2TTrainer::PadOutput(XTensor * output, XTensor * gold, XTensor * padding)
delete[] dimso; delete[] dimso;
DelTensorBuf(padding2); DelTensorBuf(padding2);
} }
/*
recale the output and gold tensors for normalized loss
>> output - output tensor of the network
>> gold - gold standard
>> padding - padding of a batch of sentences
*/
void T2TTrainer::RescaleOutput(XTensor * output, XTensor * gold, XTensor * padding)
{
CheckNTErrors(output->order == 3, "Wrong dimension number!");
CheckNTErrors(gold->order == 3, "Wrong dimension number!");
int num = padding->GetDim(0);
XTensor * factor = NewTensorBuf(1, &num, padding->dataType, 1.0F, padding->devID, padding->mem);
_ReduceSum(padding, factor, padding->order - 1);
_ExpMe(output);
_DivDim(output, factor, output, 0);
_LogMe(output);
_DivDim(gold, factor, gold, 0);
DelTensorBuf(factor);
}
/* /*
perform label smoothing perform label smoothing
......
...@@ -179,6 +179,9 @@ public: ...@@ -179,6 +179,9 @@ public:
/* do padding on the output */ /* do padding on the output */
void PadOutput(XTensor * output, XTensor * gold, XTensor * padding); void PadOutput(XTensor * output, XTensor * gold, XTensor * padding);
/* recale the output and gold tensors for normalized loss */
void RescaleOutput(XTensor * output, XTensor * gold, XTensor * padding);
/* perform label smoothing */ /* perform label smoothing */
void LabelSmooth(XTensor * gold, XTensor * smoothed, DTYPE p); void LabelSmooth(XTensor * gold, XTensor * smoothed, DTYPE p);
}; };
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论