Commit 35e084b0 by xiaotong

add padding on the encoder side for t2t MT

parent 6f90577d
...@@ -174,10 +174,10 @@ make the network for machine translation (with the output softmax layer) ...@@ -174,10 +174,10 @@ make the network for machine translation (with the output softmax layer)
>> inputEnc - input tensor of the encoder >> inputEnc - input tensor of the encoder
>> inputDec - input tensor of the decoder >> inputDec - input tensor of the decoder
>> output - output tensor (distribution) >> output - output tensor (distribution)
>> padding - padding of the sequences >> paddingEnc - padding of the sequences (on the encoder side)
>> isTraining - indicates whether the model is for training >> isTraining - indicates whether the model is for training
*/ */
void T2TModel::MakeMT(XTensor &inputEnc, XTensor &inputDec, XTensor &output, XTensor &padding, bool isTraining) void T2TModel::MakeMT(XTensor &inputEnc, XTensor &inputDec, XTensor &output, XTensor &paddingEnc, bool isTraining)
{ {
XTensor encoding; XTensor encoding;
XTensor decoding; XTensor decoding;
...@@ -199,11 +199,44 @@ void T2TModel::MakeMT(XTensor &inputEnc, XTensor &inputDec, XTensor &output, XTe ...@@ -199,11 +199,44 @@ void T2TModel::MakeMT(XTensor &inputEnc, XTensor &inputDec, XTensor &output, XTe
_SetDataLowTri(&maskDec, 1e9F, 0); _SetDataLowTri(&maskDec, 1e9F, 0);
_ScaleAndShiftMe(&maskDec, 1.0F, -1e9F); _ScaleAndShiftMe(&maskDec, 1.0F, -1e9F);
/* padding on the source side */
int * dimsPadding = new int[paddingEnc.order + 2];
for (int i = 0; i < paddingEnc.order - 1; i++)
dimsPadding[i] = paddingEnc.GetDim(i);
dimsPadding[paddingEnc.order - 1] = paddingEnc.GetDim(-1);
dimsPadding[paddingEnc.order] = paddingEnc.GetDim(-1);
XTensor * padding2 = NewTensorBuf(paddingEnc.order + 1, dimsPadding, paddingEnc.dataType,
paddingEnc.denseRatio, paddingEnc.devID, paddingEnc.mem);
for (int i = 0; i < padding2->order; i++)
dimsPadding[i + 1] = padding2->GetDim(i);
dimsPadding[0] = nhead;
XTensor * padding3 = NewTensorBuf(paddingEnc.order + 2, dimsPadding, paddingEnc.dataType,
paddingEnc.denseRatio, paddingEnc.devID, paddingEnc.mem);
/* mask of the padding */
_Unsqueeze(&paddingEnc, padding2, paddingEnc.order - 1, paddingEnc.GetDim(-1));
_Unsqueeze(padding2, padding3, 0, nhead);
_ScaleAndShiftMe(padding3, 1e9F, -1e9F);
InitTensor(&maskEnc, padding3);
maskEnc.SetZeroAll();
/* generate the mask on the source language side (for padding) */
_Sum(&maskEnc, padding3, &maskEnc);
encoding = MakeEncoder(inputEnc, maskEnc, isTraining); encoding = MakeEncoder(inputEnc, maskEnc, isTraining);
decoding = MakeDecoder(inputDec, encoding, maskDec, isTraining); decoding = MakeDecoder(inputDec, encoding, maskDec, isTraining);
outputLayer.Make(decoding, output); outputLayer.Make(decoding, output);
delete[] dims; delete[] dims;
delete[] dimsPadding;
DelTensorBuf(padding2);
DelTensorBuf(padding3);
} }
/* /*
......
...@@ -78,7 +78,7 @@ public: ...@@ -78,7 +78,7 @@ public:
void MakeLM(XTensor &input, XTensor &output, XTensor &padding, bool isTraining); void MakeLM(XTensor &input, XTensor &output, XTensor &padding, bool isTraining);
/* make the network for machine translation (with the output softmax layer) */ /* make the network for machine translation (with the output softmax layer) */
void MakeMT(XTensor &inputEnc, XTensor &inputDec, XTensor &output, XTensor &padding, bool isTraining); void MakeMT(XTensor &inputEnc, XTensor &inputDec, XTensor &output, XTensor &paddingEnc, bool isTraining);
/* get parameter matrics */ /* get parameter matrics */
void GetParams(XList &list); void GetParams(XList &list);
......
...@@ -183,7 +183,8 @@ void T2TTrainer::Train(const char * fn, const char * validFN, const char * model ...@@ -183,7 +183,8 @@ void T2TTrainer::Train(const char * fn, const char * validFN, const char * model
XTensor batch; XTensor batch;
/* padding */ /* padding */
XTensor padding; XTensor paddingEnc;
XTensor paddingDec;
/* gold standard */ /* gold standard */
XTensor gold; XTensor gold;
...@@ -191,7 +192,8 @@ void T2TTrainer::Train(const char * fn, const char * validFN, const char * model ...@@ -191,7 +192,8 @@ void T2TTrainer::Train(const char * fn, const char * validFN, const char * model
/* label smoothed gold standard (if needed) */ /* label smoothed gold standard (if needed) */
XTensor goldSmoothed; XTensor goldSmoothed;
while (LoadBatch(file, model->isLM, &batch, &padding, &gold, NULL, vSize, vSizeTgt, while (LoadBatch(file, model->isLM, &batch, &paddingEnc, &gold, &paddingDec,
NULL, vSize, vSizeTgt,
sBatchSize, wBatchSize, isLenSorted, wc, devID, mem)) sBatchSize, wBatchSize, isLenSorted, wc, devID, mem))
{ {
...@@ -202,9 +204,9 @@ void T2TTrainer::Train(const char * fn, const char * validFN, const char * model ...@@ -202,9 +204,9 @@ void T2TTrainer::Train(const char * fn, const char * validFN, const char * model
/* make the network */ /* make the network */
if(model->isLM) if(model->isLM)
model->MakeLM(batch, output, padding, true); model->MakeLM(batch, output, paddingEnc, true);
else if(model->isMT) else if(model->isMT)
model->MakeMT(batch, gold, output, padding, true); model->MakeMT(batch, gold, output, paddingEnc, true);
else{ else{
ShowNTErrors("Illegal model type!"); ShowNTErrors("Illegal model type!");
} }
...@@ -215,7 +217,7 @@ void T2TTrainer::Train(const char * fn, const char * validFN, const char * model ...@@ -215,7 +217,7 @@ void T2TTrainer::Train(const char * fn, const char * validFN, const char * model
/* make paddings for the output */ /* make paddings for the output */
if (output.GetDim(0) > 1) if (output.GetDim(0) > 1)
PadOutput(&output, &gold, &padding); PadOutput(&output, &gold, &paddingDec);
//output.Dump(tmpFILE, "output: "); //output.Dump(tmpFILE, "output: ");
//fflush(tmpFILE); //fflush(tmpFILE);
...@@ -230,7 +232,7 @@ void T2TTrainer::Train(const char * fn, const char * validFN, const char * model ...@@ -230,7 +232,7 @@ void T2TTrainer::Train(const char * fn, const char * validFN, const char * model
if (doUpdate) { if (doUpdate) {
/* recale the output for normalized loss */ /* recale the output for normalized loss */
RescaleOutput(&output, &g, &padding); RescaleOutput(&output, &g, &paddingDec);
/* back-propagation */ /* back-propagation */
net.Backward(output, g, CROSSENTROPY); net.Backward(output, g, CROSSENTROPY);
...@@ -331,7 +333,8 @@ void T2TTrainer::Test(const char * fn, const char * ofn, T2TModel * model) ...@@ -331,7 +333,8 @@ void T2TTrainer::Test(const char * fn, const char * ofn, T2TModel * model)
XTensor batch; XTensor batch;
/* padding */ /* padding */
XTensor padding; XTensor paddingEnc;
XTensor paddingDec;
/* gold standard */ /* gold standard */
XTensor gold; XTensor gold;
...@@ -341,7 +344,8 @@ void T2TTrainer::Test(const char * fn, const char * ofn, T2TModel * model) ...@@ -341,7 +344,8 @@ void T2TTrainer::Test(const char * fn, const char * ofn, T2TModel * model)
ClearBuf(); ClearBuf();
while(LoadBatch(file, model->isLM, &batch, &padding, &gold, seqs, vSize, vSizeTgt, while(LoadBatch(file, model->isLM, &batch, &paddingEnc, &gold, &paddingDec,
seqs, vSize, vSizeTgt,
1, 1, false, wc, devID, mem)) 1, 1, false, wc, devID, mem))
{ {
...@@ -352,9 +356,9 @@ void T2TTrainer::Test(const char * fn, const char * ofn, T2TModel * model) ...@@ -352,9 +356,9 @@ void T2TTrainer::Test(const char * fn, const char * ofn, T2TModel * model)
/* make the network */ /* make the network */
if(model->isLM) if(model->isLM)
model->MakeLM(batch, output, padding, false); model->MakeLM(batch, output, paddingEnc, false);
else if(model->isMT) else if(model->isMT)
model->MakeMT(batch, gold, output, padding, false); model->MakeMT(batch, gold, output, paddingEnc, false);
else{ else{
ShowNTErrors("Illegal model type!"); ShowNTErrors("Illegal model type!");
} }
...@@ -560,6 +564,7 @@ int T2TTrainer::LoadBuf(FILE * file, bool isSorted, int step) ...@@ -560,6 +564,7 @@ int T2TTrainer::LoadBuf(FILE * file, bool isSorted, int step)
buf = buf2; buf = buf2;
buf2 = tmp; buf2 = tmp;
tmp = seqLen; tmp = seqLen;
seqLen = seqLen2; seqLen = seqLen2;
seqLen2 = tmp; seqLen2 = tmp;
...@@ -580,9 +585,10 @@ void T2TTrainer::ClearBuf() ...@@ -580,9 +585,10 @@ void T2TTrainer::ClearBuf()
load a batch of sequences load a batch of sequences
>> file - the handle to the data file >> file - the handle to the data file
>> isLM - indicates whether the data is used for training lms >> isLM - indicates whether the data is used for training lms
>> batch - the batch of the input sequences >> batchEnc - the batch of the input sequences
>> padding - padding of the input sequences >> paddingEnc - padding of the input sequences
>> output - the batch of the output sequences >> batchDec - the batch of the output sequences
>> paddingDec - padding of the output sequences
>> seqs - keep the sequences in an array >> seqs - keep the sequences in an array
>> vsEnc - size of the encoder vocabulary >> vsEnc - size of the encoder vocabulary
>> vsDec - size of the decoder vocabulary >> vsDec - size of the decoder vocabulary
...@@ -594,19 +600,20 @@ load a batch of sequences ...@@ -594,19 +600,20 @@ load a batch of sequences
>> mem - memory pool >> mem - memory pool
*/ */
int T2TTrainer::LoadBatch(FILE * file, bool isLM, int T2TTrainer::LoadBatch(FILE * file, bool isLM,
XTensor * batch, XTensor * padding, XTensor * output, XTensor * batchEnc, XTensor * paddingEnc,
XTensor * batchDec, XTensor * paddingDec,
int * seqs, int * seqs,
int vsEnc, int vsDec, int sBatch, int wBatch, int vsEnc, int vsDec, int sBatch, int wBatch,
bool isSorted, int &wCount, bool isSorted, int &wCount,
int devID, XMem * mem) int devID, XMem * mem)
{ {
if(isLM){ if(isLM){
return LoadBatchLM(file, batch, padding, output, seqs, return LoadBatchLM(file, batchEnc, paddingEnc, batchDec, paddingDec, seqs,
vsEnc, sBatch, wBatch, vsEnc, sBatch, wBatch,
isSorted, wCount, devID, mem); isSorted, wCount, devID, mem);
} }
else{ else{
return LoadBatchMT(file, batch, padding, output, seqs, return LoadBatchMT(file, batchEnc, paddingEnc, batchDec, paddingDec, seqs,
vsEnc, vsDec, sBatch, wBatch, vsEnc, vsDec, sBatch, wBatch,
isSorted, wCount, devID, mem); isSorted, wCount, devID, mem);
} }
...@@ -616,9 +623,10 @@ int T2TTrainer::LoadBatch(FILE * file, bool isLM, ...@@ -616,9 +623,10 @@ int T2TTrainer::LoadBatch(FILE * file, bool isLM,
load a batch of sequences (for LM) load a batch of sequences (for LM)
>> file - the handle to the data file >> file - the handle to the data file
>> isLM - indicates whether the data is used for training lms >> isLM - indicates whether the data is used for training lms
>> batch - the batch of the input sequences >> batchEnc - the batch of the input sequences
>> padding - padding of the input sequences >> paddingEnc - padding of the input sequences
>> output - the batch of the output sequences >> batchDec - the batch of the output sequences
>> paddingDec - padding of the output sequences
>> seqs - keep the sequences in an array >> seqs - keep the sequences in an array
>> vs - vocabulary size >> vs - vocabulary size
>> sBatch - batch size of sequences >> sBatch - batch size of sequences
...@@ -629,7 +637,8 @@ load a batch of sequences (for LM) ...@@ -629,7 +637,8 @@ load a batch of sequences (for LM)
>> mem - memory pool >> mem - memory pool
*/ */
int T2TTrainer::LoadBatchLM(FILE * file, int T2TTrainer::LoadBatchLM(FILE * file,
XTensor * batch, XTensor * padding, XTensor * output, XTensor * batchEnc, XTensor * paddingEnc,
XTensor * batchDec, XTensor * paddingDec,
int * seqs, int * seqs,
int vs, int sBatch, int wBatch, int vs, int sBatch, int wBatch,
bool isSorted, int &wCount, bool isSorted, int &wCount,
...@@ -669,20 +678,24 @@ int T2TTrainer::LoadBatchLM(FILE * file, ...@@ -669,20 +678,24 @@ int T2TTrainer::LoadBatchLM(FILE * file,
dims[1] = max; dims[1] = max;
dims[2] = vs; dims[2] = vs;
InitTensor(batch, 3, dims, X_FLOAT, 1.0F, devID, mem); InitTensor(batchEnc, 3, dims, X_FLOAT, 1.0F, devID, mem);
InitTensor2D(padding, sc, max, X_FLOAT, devID, mem); InitTensor2D(paddingEnc, sc, max, X_FLOAT, devID, mem);
InitTensor(output, 3, dims, X_FLOAT, 1.0F, devID, mem); InitTensor(batchDec, 3, dims, X_FLOAT, 1.0F, devID, mem);
InitTensor2D(paddingDec, sc, max, X_FLOAT, devID, mem);
XNoder::MakeGrad(batch); XNoder::MakeGrad(batchEnc);
XNoder::MakeGrad(padding); XNoder::MakeGrad(paddingEnc);
XNoder::MakeGrad(output); XNoder::MakeGrad(batchDec);
XNoder::MakeGrad(paddingDec);
batch->SetZeroAll(); batchEnc->SetZeroAll();
padding->SetZeroAll(); paddingEnc->SetZeroAll();
output->SetZeroAll(); batchDec->SetZeroAll();
batch->grad->SetZeroAll(); paddingDec->SetZeroAll();
padding->grad->SetZeroAll(); batchEnc->grad->SetZeroAll();
output->grad->SetZeroAll(); paddingEnc->grad->SetZeroAll();
batchDec->grad->SetZeroAll();
paddingDec->grad->SetZeroAll();
int seqSize = 0; int seqSize = 0;
...@@ -693,15 +706,18 @@ int T2TTrainer::LoadBatchLM(FILE * file, ...@@ -693,15 +706,18 @@ int T2TTrainer::LoadBatchLM(FILE * file,
int len = isDoubledEnd ? seqLen[s] : seqLen[s] - 1; int len = isDoubledEnd ? seqLen[s] : seqLen[s] - 1;
CheckNTErrors(len <= max, "Something is wrong!"); CheckNTErrors(len <= max, "Something is wrong!");
for(int w = 0; w < len; w++){ for(int w = 0; w < len; w++){
batch->Set3D(1.0F, s - seq, w, buf[seqOffset[s] + w]); batchEnc->Set3D(1.0F, s - seq, w, buf[seqOffset[s] + w]);
padding->Set2D(1.0F, s - seq, w); paddingEnc->Set2D(1.0F, s - seq, w);
if(w > 0) if (w > 0) {
output->Set3D(1.0F, s - seq, w - 1, buf[seqOffset[s] + w]); batchDec->Set3D(1.0F, s - seq, w - 1, buf[seqOffset[s] + w]);
paddingDec->Set2D(1.0F, s - seq, w - 1);
}
if(w == len - 1){ if(w == len - 1){
if(isDoubledEnd) if(isDoubledEnd)
output->Set3D(1.0F, s - seq, w, buf[seqOffset[s] + w]); batchDec->Set3D(1.0F, s - seq, w, buf[seqOffset[s] + w]);
else else
output->Set3D(1.0F, s - seq, w, buf[seqOffset[s] + w + 1]); batchDec->Set3D(1.0F, s - seq, w, buf[seqOffset[s] + w + 1]);
paddingDec->Set2D(1.0F, s - seq, w);
} }
wCount++; wCount++;
/*fprintf(tf, "%d", buf[seqOffset[s] + w]); /*fprintf(tf, "%d", buf[seqOffset[s] + w]);
...@@ -727,9 +743,10 @@ int T2TTrainer::LoadBatchLM(FILE * file, ...@@ -727,9 +743,10 @@ int T2TTrainer::LoadBatchLM(FILE * file,
/* /*
load a batch of sequences (for MT) load a batch of sequences (for MT)
>> file - the handle to the data file >> file - the handle to the data file
>> batch - the batch of the input sequences >> batchEnc - the batch of the input sequences
>> padding - padding of the input sequences >> paddingEnc - padding of the input sequences
>> output - the batch of the output sequences >> batchDec - the batch of the output sequences
>> paddingDec - padding of the output sequences
>> seqs - keep the sequences in an array >> seqs - keep the sequences in an array
>> vsEnc - size of the encoder vocabulary >> vsEnc - size of the encoder vocabulary
>> vsDec - size of the decoder vocabulary >> vsDec - size of the decoder vocabulary
...@@ -741,7 +758,8 @@ load a batch of sequences (for MT) ...@@ -741,7 +758,8 @@ load a batch of sequences (for MT)
>> mem - memory pool >> mem - memory pool
*/ */
int T2TTrainer::LoadBatchMT(FILE * file, int T2TTrainer::LoadBatchMT(FILE * file,
XTensor * batch, XTensor * padding, XTensor * output, XTensor * batchEnc, XTensor * paddingEnc,
XTensor * batchDec, XTensor * paddingDec,
int * seqs, int * seqs,
int vsEnc, int vsDec, int sBatch, int wBatch, int vsEnc, int vsDec, int sBatch, int wBatch,
bool isSorted, int &wCount, bool isSorted, int &wCount,
...@@ -794,13 +812,15 @@ int T2TTrainer::LoadBatchMT(FILE * file, ...@@ -794,13 +812,15 @@ int T2TTrainer::LoadBatchMT(FILE * file,
int dimsEnc[3] = {sCount, maxEnc, vsEnc}; int dimsEnc[3] = {sCount, maxEnc, vsEnc};
int dimsDec[3] = {sCount, maxDec, vsDec}; int dimsDec[3] = {sCount, maxDec, vsDec};
InitTensor(batch, 3, dimsEnc, X_FLOAT, 1.0F, devID, mem); InitTensor(batchEnc, 3, dimsEnc, X_FLOAT, 1.0F, devID, mem);
InitTensor2D(padding, sCount, maxDec, X_FLOAT, devID, mem); InitTensor2D(paddingEnc, sCount, maxEnc, X_FLOAT, devID, mem);
InitTensor(output, 3, dimsDec, X_FLOAT, 1.0F, devID, mem); InitTensor(batchDec, 3, dimsDec, X_FLOAT, 1.0F, devID, mem);
InitTensor2D(paddingDec, sCount, maxDec, X_FLOAT, devID, mem);
batch->SetZeroAll(); batchEnc->SetZeroAll();
padding->SetZeroAll(); paddingEnc->SetZeroAll();
output->SetZeroAll(); batchDec->SetZeroAll();
paddingDec->SetZeroAll();
wCount = 0; wCount = 0;
...@@ -809,7 +829,8 @@ int T2TTrainer::LoadBatchMT(FILE * file, ...@@ -809,7 +829,8 @@ int T2TTrainer::LoadBatchMT(FILE * file,
int len = seqLen[s]; int len = seqLen[s];
int sent = (s - seq)/2; int sent = (s - seq)/2;
for(int w = 0; w < len; w++){ for(int w = 0; w < len; w++){
batch->Set3D(1.0F, sent, w, buf[seqOffset[s] + w]); batchEnc->Set3D(1.0F, sent, w, buf[seqOffset[s] + w]);
paddingEnc->Set2D(1.0F, sent, w);
wCount++; wCount++;
} }
} }
...@@ -819,11 +840,11 @@ int T2TTrainer::LoadBatchMT(FILE * file, ...@@ -819,11 +840,11 @@ int T2TTrainer::LoadBatchMT(FILE * file,
int len = seqLen[s]; int len = seqLen[s];
int sent = (s - seq - 1)/2; int sent = (s - seq - 1)/2;
for(int w = 0; w < len; w++){ for(int w = 0; w < len; w++){
padding->Set2D(1.0F, sent, w); paddingDec->Set2D(1.0F, sent, w);
if(w > 0) if(w > 0)
output->Set3D(1.0F, sent, w - 1, buf[seqOffset[s] + w]); batchDec->Set3D(1.0F, sent, w - 1, buf[seqOffset[s] + w]);
if(w == len - 1) if(w == len - 1)
output->Set3D(1.0F, sent, w, buf[seqOffset[s] + w]); batchDec->Set3D(1.0F, sent, w, buf[seqOffset[s] + w]);
wCount++; wCount++;
if(seqs != NULL) if(seqs != NULL)
......
...@@ -166,7 +166,8 @@ public: ...@@ -166,7 +166,8 @@ public:
/* load a batch of sequences */ /* load a batch of sequences */
int LoadBatch(FILE * file, bool isLM, int LoadBatch(FILE * file, bool isLM,
XTensor * batch, XTensor * padding, XTensor * output, XTensor * batchEnc, XTensor * paddingEnc,
XTensor * batchDec, XTensor * paddingDec,
int * seqs, int * seqs,
int vsEnc, int vsDec, int sBatch, int wBatch, int vsEnc, int vsDec, int sBatch, int wBatch,
bool isSorted, int &wCount, bool isSorted, int &wCount,
...@@ -174,14 +175,16 @@ public: ...@@ -174,14 +175,16 @@ public:
/* load a batch of sequences (for language modeling) */ /* load a batch of sequences (for language modeling) */
int LoadBatchLM(FILE * file, int LoadBatchLM(FILE * file,
XTensor * batch, XTensor * padding, XTensor * output, XTensor * batchEnc, XTensor * paddingEnc,
XTensor * batchDec, XTensor * paddingDec,
int * seqs, int vs, int sBatch, int wBatch, int * seqs, int vs, int sBatch, int wBatch,
bool isSorted, int &wCount, bool isSorted, int &wCount,
int devID, XMem * mem); int devID, XMem * mem);
/* load a batch of sequences (for machine translation) */ /* load a batch of sequences (for machine translation) */
int LoadBatchMT(FILE * file, int LoadBatchMT(FILE * file,
XTensor * batch, XTensor * padding, XTensor * output, XTensor * batchEnc, XTensor * paddingEnc,
XTensor * batchDec, XTensor * paddingDec,
int * seqs, int vsEnc, int vsDec, int sBatch, int wBatch, int * seqs, int vsEnc, int vsDec, int sBatch, int wBatch,
bool isSorted, int &wCount, bool isSorted, int &wCount,
int devID, XMem * mem); int devID, XMem * mem);
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论