Commit a0a38702 by xiaotong

add label smoothing

parent c8cb9219
......@@ -108,6 +108,7 @@ void T2TTrainer::Init(int argc, char ** argv)
LoadParamFloat(argc, argv, "adambeta2", &adamBeta2, 0.999F);
LoadParamFloat(argc, argv, "adamdelta", &adamDelta, 1e-8F);
LoadParamBool(argc, argv, "shuffled", &isShuffled, false);
LoadParamFloat(argc, argv, "labelsmoothing", &labelSmoothingP, 0);
LoadParamInt(argc, argv, "nstepcheckpoint", &nStepCheckpoint, -1);
LoadParamBool(argc, argv, "epochcheckpoint", &useEpochCheckpoint, false);
......@@ -180,6 +181,9 @@ void T2TTrainer::Train(const char * fn, const char * validFN, const char * model
/* gold standard */
XTensor gold;
/* label smoothed gold standard (if needed) */
XTensor goldSmoothed;
while(LoadBatch(file, true, &batch, &padding, &gold, NULL, 1, vSize, sBatchSize, wBatchSize, isLenSorted, wc, devID, mem)){
CheckNTErrors(batch.order == 3, "wrong tensor order of the sequence batch");
......@@ -190,12 +194,17 @@ void T2TTrainer::Train(const char * fn, const char * validFN, const char * model
/* make the network */
model->Make(batch, output, padding, true);
/* back-propagation for obtaining gradients */
if(labelSmoothingP > 0)
LabelSmooth(&gold, &goldSmoothed, labelSmoothingP);
/* make paddings for the output */
if(output.GetDim(0) > 1)
PadOutput(&output, &padding);
PadOutput(&output, &gold, &padding);
/* back-propagation for obtaining gradients */
net.Backward(output, gold, CROSSENTROPY);
XTensor &g = labelSmoothingP > 0 ? goldSmoothed : gold;
net.Backward(output, g, CROSSENTROPY);
/* learning rate */
lr = lrate * (1.0F / (float)sqrt((float)d)) * (float)MIN(pow((float)step + 1, -0.5F - lrbias), ((float)step + 1) * pow((float)nwarmup, -1.5F - lrbias));
......@@ -789,9 +798,11 @@ void T2TTrainer::PrepareModel(T2TModel * model)
/*
do padding on the output
>> output - output tensor of the network
>> gold - gold standard
>> padding - padding of a batch of sentences
>> lsP - smoothing factor
*/
void T2TTrainer::PadOutput(XTensor * output, XTensor * padding)
void T2TTrainer::PadOutput(XTensor * output, XTensor * gold, XTensor * padding)
{
if(output == NULL || padding == NULL)
return;
......@@ -807,13 +818,45 @@ void T2TTrainer::PadOutput(XTensor * output, XTensor * padding)
_CopyValues(padding, padding2);
_ScaleAndShiftMe(padding2, 1e9F, -1e9F);
_SumDim(output, padding2, output, 0);
output->Reshape(on, dimso);
if(gold != NULL){
gold->Reshape(gold->unitNum/dimso[output->order - 1], dimso[output->order - 1]);
_CopyValues(padding, padding2);
_MultiplyDim(gold, padding2, gold, 0);
gold->Reshape(on, dimso);
}
delete[] dimso;
DelTensorBuf(padding2);
}
/*
perform label smoothing
>> gold - gold standard
>> smoothed - result of label smoothing
>> lsP - smoothing factor
*/
void T2TTrainer::LabelSmooth(XTensor * gold, XTensor * smoothed, DTYPE lsP)
{
DTYPE p = lsP;
CheckNTErrors(p >= 0 && p <= 1.0F, "Smoothing factor must be in range [0,1]");
int n = gold->GetDim(-1);
DTYPE q = 1.0F - p;
DTYPE gift = p / (n - 1);
InitTensor(smoothed, gold);
_CopyValues(gold, smoothed);
if(p == 0)
return;
_ScaleAndShiftMe(smoothed, gift/q, -gift/q);
_Sum(smoothed, gold, smoothed);
_ScaleAndShiftMe(smoothed, q);
}
}
......@@ -116,6 +116,9 @@ public:
/* indicates whether the data file is shuffled for training */
bool isShuffled;
/* the factor of label smoothing */
DTYPE labelSmoothingP;
/* number of steps after which we make a checkpoint */
int nStepCheckpoint;
......@@ -168,7 +171,10 @@ public:
void PrepareModel(T2TModel * model);
/* do padding on the output */
void PadOutput(XTensor * output, XTensor * padding);
void PadOutput(XTensor * output, XTensor * gold, XTensor * padding);
/* perform label smoothing */
void LabelSmooth(XTensor * gold, XTensor * smoothed, DTYPE lsP);
};
......
......@@ -19,6 +19,7 @@
* $Created by: XIAO Tong (xiaotong@mail.neu.edu.cn) 2018-07-31
*/
#include <math.h>
#include "Transformer.h"
#include "T2TModel.h"
#include "T2TUtility.h"
......@@ -33,6 +34,8 @@ int TransformerMain(int argc, const char ** argv)
if(argc == 0)
return 1;
fprintf(stderr, "%e\n", log(1e-45F));
char ** args = new char*[argc];
for(int i = 0; i < argc; i++){
args[i] = new char[strlen(argv[i]) + 1];
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论