Commit a0a38702 by xiaotong

add label smoothing

parent c8cb9219
...@@ -108,6 +108,7 @@ void T2TTrainer::Init(int argc, char ** argv) ...@@ -108,6 +108,7 @@ void T2TTrainer::Init(int argc, char ** argv)
LoadParamFloat(argc, argv, "adambeta2", &adamBeta2, 0.999F); LoadParamFloat(argc, argv, "adambeta2", &adamBeta2, 0.999F);
LoadParamFloat(argc, argv, "adamdelta", &adamDelta, 1e-8F); LoadParamFloat(argc, argv, "adamdelta", &adamDelta, 1e-8F);
LoadParamBool(argc, argv, "shuffled", &isShuffled, false); LoadParamBool(argc, argv, "shuffled", &isShuffled, false);
LoadParamFloat(argc, argv, "labelsmoothing", &labelSmoothingP, 0);
LoadParamInt(argc, argv, "nstepcheckpoint", &nStepCheckpoint, -1); LoadParamInt(argc, argv, "nstepcheckpoint", &nStepCheckpoint, -1);
LoadParamBool(argc, argv, "epochcheckpoint", &useEpochCheckpoint, false); LoadParamBool(argc, argv, "epochcheckpoint", &useEpochCheckpoint, false);
...@@ -180,6 +181,9 @@ void T2TTrainer::Train(const char * fn, const char * validFN, const char * model ...@@ -180,6 +181,9 @@ void T2TTrainer::Train(const char * fn, const char * validFN, const char * model
/* gold standard */ /* gold standard */
XTensor gold; XTensor gold;
/* label smoothed gold standard (if needed) */
XTensor goldSmoothed;
while(LoadBatch(file, true, &batch, &padding, &gold, NULL, 1, vSize, sBatchSize, wBatchSize, isLenSorted, wc, devID, mem)){ while(LoadBatch(file, true, &batch, &padding, &gold, NULL, 1, vSize, sBatchSize, wBatchSize, isLenSorted, wc, devID, mem)){
CheckNTErrors(batch.order == 3, "wrong tensor order of the sequence batch"); CheckNTErrors(batch.order == 3, "wrong tensor order of the sequence batch");
...@@ -190,12 +194,17 @@ void T2TTrainer::Train(const char * fn, const char * validFN, const char * model ...@@ -190,12 +194,17 @@ void T2TTrainer::Train(const char * fn, const char * validFN, const char * model
/* make the network */ /* make the network */
model->Make(batch, output, padding, true); model->Make(batch, output, padding, true);
/* back-propagation for obtaining gradients */
if(labelSmoothingP > 0)
LabelSmooth(&gold, &goldSmoothed, labelSmoothingP);
/* make paddings for the output */ /* make paddings for the output */
if(output.GetDim(0) > 1) if(output.GetDim(0) > 1)
PadOutput(&output, &padding); PadOutput(&output, &gold, &padding);
/* back-propagation for obtaining gradients */ XTensor &g = labelSmoothingP > 0 ? goldSmoothed : gold;
net.Backward(output, gold, CROSSENTROPY);
net.Backward(output, g, CROSSENTROPY);
/* learning rate */ /* learning rate */
lr = lrate * (1.0F / (float)sqrt((float)d)) * (float)MIN(pow((float)step + 1, -0.5F - lrbias), ((float)step + 1) * pow((float)nwarmup, -1.5F - lrbias)); lr = lrate * (1.0F / (float)sqrt((float)d)) * (float)MIN(pow((float)step + 1, -0.5F - lrbias), ((float)step + 1) * pow((float)nwarmup, -1.5F - lrbias));
...@@ -789,9 +798,11 @@ void T2TTrainer::PrepareModel(T2TModel * model) ...@@ -789,9 +798,11 @@ void T2TTrainer::PrepareModel(T2TModel * model)
/* /*
do padding on the output do padding on the output
>> output - output tensor of the network >> output - output tensor of the network
>> gold - gold standard
>> padding - padding of a batch of sentences >> padding - padding of a batch of sentences
>> lsP - smoothing factor
*/ */
void T2TTrainer::PadOutput(XTensor * output, XTensor * padding) void T2TTrainer::PadOutput(XTensor * output, XTensor * gold, XTensor * padding)
{ {
if(output == NULL || padding == NULL) if(output == NULL || padding == NULL)
return; return;
...@@ -807,13 +818,45 @@ void T2TTrainer::PadOutput(XTensor * output, XTensor * padding) ...@@ -807,13 +818,45 @@ void T2TTrainer::PadOutput(XTensor * output, XTensor * padding)
_CopyValues(padding, padding2); _CopyValues(padding, padding2);
_ScaleAndShiftMe(padding2, 1e9F, -1e9F); _ScaleAndShiftMe(padding2, 1e9F, -1e9F);
_SumDim(output, padding2, output, 0); _SumDim(output, padding2, output, 0);
output->Reshape(on, dimso); output->Reshape(on, dimso);
if(gold != NULL){
gold->Reshape(gold->unitNum/dimso[output->order - 1], dimso[output->order - 1]);
_CopyValues(padding, padding2);
_MultiplyDim(gold, padding2, gold, 0);
gold->Reshape(on, dimso);
}
delete[] dimso; delete[] dimso;
DelTensorBuf(padding2); DelTensorBuf(padding2);
} }
/*
perform label smoothing: map a (one-hot) gold distribution to a smoothed
distribution in which the gold label keeps probability 1-p and the remaining
mass p is spread uniformly over the other n-1 labels
>> gold - gold standard (assumed one-hot over the last dimension — TODO confirm with LoadBatch)
>> smoothed - result of label smoothing
>> lsP - smoothing factor p, must be in [0, 1]
*/
void T2TTrainer::LabelSmooth(XTensor * gold, XTensor * smoothed, DTYPE lsP)
{
    DTYPE p = lsP;
    CheckNTErrors(p >= 0 && p <= 1.0F, "Smoothing factor must be in range [0,1]");

    /* number of labels = size of the last dimension */
    int n = gold->GetDim(-1);

    InitTensor(smoothed, gold);
    _CopyValues(gold, smoothed);

    /* p == 0 means no smoothing; n <= 1 would make the n-1 division below
       ill-defined (a single label cannot be smoothed) */
    if(p == 0 || n <= 1)
        return;

    DTYPE q = 1.0F - p;
    DTYPE gift = p / (n - 1);

    /* smoothed = gold * (q - gift) + gift, i.e. 1-p for the gold label and
       p/(n-1) for every other label. the single scale-and-shift replaces the
       old gift/q formulation, which divided by zero at p == 1 and produced
       negative weights (-p/(n-1)) for the non-gold labels */
    _ScaleAndShiftMe(smoothed, q - gift, gift);
}
} }
...@@ -115,6 +115,9 @@ public: ...@@ -115,6 +115,9 @@ public:
/* indicates whether the data file is shuffled for training */ /* indicates whether the data file is shuffled for training */
bool isShuffled; bool isShuffled;
/* the factor of label smoothing */
DTYPE labelSmoothingP;
/* number of steps after which we make a checkpoint */ /* number of steps after which we make a checkpoint */
int nStepCheckpoint; int nStepCheckpoint;
...@@ -168,7 +171,10 @@ public: ...@@ -168,7 +171,10 @@ public:
void PrepareModel(T2TModel * model); void PrepareModel(T2TModel * model);
/* do padding on the output */ /* do padding on the output */
void PadOutput(XTensor * output, XTensor * padding); void PadOutput(XTensor * output, XTensor * gold, XTensor * padding);
/* perform label smoothing */
void LabelSmooth(XTensor * gold, XTensor * smoothed, DTYPE lsP);
}; };
......
...@@ -19,6 +19,7 @@ ...@@ -19,6 +19,7 @@
* $Created by: XIAO Tong (xiaotong@mail.neu.edu.cn) 2018-07-31 * $Created by: XIAO Tong (xiaotong@mail.neu.edu.cn) 2018-07-31
*/ */
#include <math.h>
#include "Transformer.h" #include "Transformer.h"
#include "T2TModel.h" #include "T2TModel.h"
#include "T2TUtility.h" #include "T2TUtility.h"
...@@ -32,6 +33,8 @@ int TransformerMain(int argc, const char ** argv) ...@@ -32,6 +33,8 @@ int TransformerMain(int argc, const char ** argv)
{ {
if(argc == 0) if(argc == 0)
return 1; return 1;
fprintf(stderr, "%e\n", log(1e-45F));
char ** args = new char*[argc]; char ** args = new char*[argc];
for(int i = 0; i < argc; i++){ for(int i = 0; i < argc; i++){
...@@ -93,4 +96,4 @@ int TransformerMain(int argc, const char ** argv) ...@@ -93,4 +96,4 @@ int TransformerMain(int argc, const char ** argv)
return 0; return 0;
} }
} }
\ No newline at end of file
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论