Commit 35e084b0 by xiaotong

add padding on the encoder side for t2t MT

parent 6f90577d
......@@ -174,10 +174,10 @@ make the network for machine translation (with the output softmax layer)
>> inputEnc - input tensor of the encoder
>> inputDec - input tensor of the decoder
>> output - output tensor (distribution)
>> padding - padding of the sequences
>> paddingEnc - padding of the sequences (on the encoder side)
>> isTraining - indicates whether the model is for training
*/
void T2TModel::MakeMT(XTensor &inputEnc, XTensor &inputDec, XTensor &output, XTensor &padding, bool isTraining)
void T2TModel::MakeMT(XTensor &inputEnc, XTensor &inputDec, XTensor &output, XTensor &paddingEnc, bool isTraining)
{
XTensor encoding;
XTensor decoding;
......@@ -199,11 +199,44 @@ void T2TModel::MakeMT(XTensor &inputEnc, XTensor &inputDec, XTensor &output, XTe
_SetDataLowTri(&maskDec, 1e9F, 0);
_ScaleAndShiftMe(&maskDec, 1.0F, -1e9F);
/* padding on the source side */
int * dimsPadding = new int[paddingEnc.order + 2];
for (int i = 0; i < paddingEnc.order - 1; i++)
dimsPadding[i] = paddingEnc.GetDim(i);
dimsPadding[paddingEnc.order - 1] = paddingEnc.GetDim(-1);
dimsPadding[paddingEnc.order] = paddingEnc.GetDim(-1);
XTensor * padding2 = NewTensorBuf(paddingEnc.order + 1, dimsPadding, paddingEnc.dataType,
paddingEnc.denseRatio, paddingEnc.devID, paddingEnc.mem);
for (int i = 0; i < padding2->order; i++)
dimsPadding[i + 1] = padding2->GetDim(i);
dimsPadding[0] = nhead;
XTensor * padding3 = NewTensorBuf(paddingEnc.order + 2, dimsPadding, paddingEnc.dataType,
paddingEnc.denseRatio, paddingEnc.devID, paddingEnc.mem);
/* mask of the padding */
_Unsqueeze(&paddingEnc, padding2, paddingEnc.order - 1, paddingEnc.GetDim(-1));
_Unsqueeze(padding2, padding3, 0, nhead);
_ScaleAndShiftMe(padding3, 1e9F, -1e9F);
InitTensor(&maskEnc, padding3);
maskEnc.SetZeroAll();
/* generate the mask on the source language side (for padding) */
_Sum(&maskEnc, padding3, &maskEnc);
encoding = MakeEncoder(inputEnc, maskEnc, isTraining);
decoding = MakeDecoder(inputDec, encoding, maskDec, isTraining);
outputLayer.Make(decoding, output);
delete[] dims;
delete[] dimsPadding;
DelTensorBuf(padding2);
DelTensorBuf(padding3);
}
/*
......
......@@ -78,7 +78,7 @@ public:
void MakeLM(XTensor &input, XTensor &output, XTensor &padding, bool isTraining);
/* make the network for machine translation (with the output softmax layer) */
void MakeMT(XTensor &inputEnc, XTensor &inputDec, XTensor &output, XTensor &padding, bool isTraining);
void MakeMT(XTensor &inputEnc, XTensor &inputDec, XTensor &output, XTensor &paddingEnc, bool isTraining);
/* get parameter matrices */
void GetParams(XList &list);
......
......@@ -183,7 +183,8 @@ void T2TTrainer::Train(const char * fn, const char * validFN, const char * model
XTensor batch;
/* padding */
XTensor padding;
XTensor paddingEnc;
XTensor paddingDec;
/* gold standard */
XTensor gold;
......@@ -191,7 +192,8 @@ void T2TTrainer::Train(const char * fn, const char * validFN, const char * model
/* label smoothed gold standard (if needed) */
XTensor goldSmoothed;
while (LoadBatch(file, model->isLM, &batch, &padding, &gold, NULL, vSize, vSizeTgt,
while (LoadBatch(file, model->isLM, &batch, &paddingEnc, &gold, &paddingDec,
NULL, vSize, vSizeTgt,
sBatchSize, wBatchSize, isLenSorted, wc, devID, mem))
{
......@@ -202,9 +204,9 @@ void T2TTrainer::Train(const char * fn, const char * validFN, const char * model
/* make the network */
if(model->isLM)
model->MakeLM(batch, output, padding, true);
model->MakeLM(batch, output, paddingEnc, true);
else if(model->isMT)
model->MakeMT(batch, gold, output, padding, true);
model->MakeMT(batch, gold, output, paddingEnc, true);
else{
ShowNTErrors("Illegal model type!");
}
......@@ -215,7 +217,7 @@ void T2TTrainer::Train(const char * fn, const char * validFN, const char * model
/* make paddings for the output */
if (output.GetDim(0) > 1)
PadOutput(&output, &gold, &padding);
PadOutput(&output, &gold, &paddingDec);
//output.Dump(tmpFILE, "output: ");
//fflush(tmpFILE);
......@@ -230,7 +232,7 @@ void T2TTrainer::Train(const char * fn, const char * validFN, const char * model
if (doUpdate) {
/* recale the output for normalized loss */
RescaleOutput(&output, &g, &padding);
RescaleOutput(&output, &g, &paddingDec);
/* back-propagation */
net.Backward(output, g, CROSSENTROPY);
......@@ -331,7 +333,8 @@ void T2TTrainer::Test(const char * fn, const char * ofn, T2TModel * model)
XTensor batch;
/* padding */
XTensor padding;
XTensor paddingEnc;
XTensor paddingDec;
/* gold standard */
XTensor gold;
......@@ -341,7 +344,8 @@ void T2TTrainer::Test(const char * fn, const char * ofn, T2TModel * model)
ClearBuf();
while(LoadBatch(file, model->isLM, &batch, &padding, &gold, seqs, vSize, vSizeTgt,
while(LoadBatch(file, model->isLM, &batch, &paddingEnc, &gold, &paddingDec,
seqs, vSize, vSizeTgt,
1, 1, false, wc, devID, mem))
{
......@@ -352,9 +356,9 @@ void T2TTrainer::Test(const char * fn, const char * ofn, T2TModel * model)
/* make the network */
if(model->isLM)
model->MakeLM(batch, output, padding, false);
model->MakeLM(batch, output, paddingEnc, false);
else if(model->isMT)
model->MakeMT(batch, gold, output, padding, false);
model->MakeMT(batch, gold, output, paddingEnc, false);
else{
ShowNTErrors("Illegal model type!");
}
......@@ -560,6 +564,7 @@ int T2TTrainer::LoadBuf(FILE * file, bool isSorted, int step)
buf = buf2;
buf2 = tmp;
tmp = seqLen;
seqLen = seqLen2;
seqLen2 = tmp;
......@@ -580,9 +585,10 @@ void T2TTrainer::ClearBuf()
load a batch of sequences
>> file - the handle to the data file
>> isLM - indicates whether the data is used for training language models (LMs)
>> batch - the batch of the input sequences
>> padding - padding of the input sequences
>> output - the batch of the output sequences
>> batchEnc - the batch of the input sequences
>> paddingEnc - padding of the input sequences
>> batchDec - the batch of the output sequences
>> paddingDec - padding of the output sequences
>> seqs - keep the sequences in an array
>> vsEnc - size of the encoder vocabulary
>> vsDec - size of the decoder vocabulary
......@@ -594,19 +600,20 @@ load a batch of sequences
>> mem - memory pool
*/
int T2TTrainer::LoadBatch(FILE * file, bool isLM,
XTensor * batch, XTensor * padding, XTensor * output,
XTensor * batchEnc, XTensor * paddingEnc,
XTensor * batchDec, XTensor * paddingDec,
int * seqs,
int vsEnc, int vsDec, int sBatch, int wBatch,
bool isSorted, int &wCount,
int devID, XMem * mem)
{
if(isLM){
return LoadBatchLM(file, batch, padding, output, seqs,
return LoadBatchLM(file, batchEnc, paddingEnc, batchDec, paddingDec, seqs,
vsEnc, sBatch, wBatch,
isSorted, wCount, devID, mem);
}
else{
return LoadBatchMT(file, batch, padding, output, seqs,
return LoadBatchMT(file, batchEnc, paddingEnc, batchDec, paddingDec, seqs,
vsEnc, vsDec, sBatch, wBatch,
isSorted, wCount, devID, mem);
}
......@@ -616,9 +623,10 @@ int T2TTrainer::LoadBatch(FILE * file, bool isLM,
load a batch of sequences (for LM)
>> file - the handle to the data file
>> batch - the batch of the input sequences
>> padding - padding of the input sequences
>> output - the batch of the output sequences
>> batchEnc - the batch of the input sequences
>> paddingEnc - padding of the input sequences
>> batchDec - the batch of the output sequences
>> paddingDec - padding of the output sequences
>> seqs - keep the sequences in an array
>> vs - vocabulary size
>> sBatch - batch size of sequences
......@@ -629,7 +637,8 @@ load a batch of sequences (for LM)
>> mem - memory pool
*/
int T2TTrainer::LoadBatchLM(FILE * file,
XTensor * batch, XTensor * padding, XTensor * output,
XTensor * batchEnc, XTensor * paddingEnc,
XTensor * batchDec, XTensor * paddingDec,
int * seqs,
int vs, int sBatch, int wBatch,
bool isSorted, int &wCount,
......@@ -669,20 +678,24 @@ int T2TTrainer::LoadBatchLM(FILE * file,
dims[1] = max;
dims[2] = vs;
InitTensor(batch, 3, dims, X_FLOAT, 1.0F, devID, mem);
InitTensor2D(padding, sc, max, X_FLOAT, devID, mem);
InitTensor(output, 3, dims, X_FLOAT, 1.0F, devID, mem);
XNoder::MakeGrad(batch);
XNoder::MakeGrad(padding);
XNoder::MakeGrad(output);
batch->SetZeroAll();
padding->SetZeroAll();
output->SetZeroAll();
batch->grad->SetZeroAll();
padding->grad->SetZeroAll();
output->grad->SetZeroAll();
InitTensor(batchEnc, 3, dims, X_FLOAT, 1.0F, devID, mem);
InitTensor2D(paddingEnc, sc, max, X_FLOAT, devID, mem);
InitTensor(batchDec, 3, dims, X_FLOAT, 1.0F, devID, mem);
InitTensor2D(paddingDec, sc, max, X_FLOAT, devID, mem);
XNoder::MakeGrad(batchEnc);
XNoder::MakeGrad(paddingEnc);
XNoder::MakeGrad(batchDec);
XNoder::MakeGrad(paddingDec);
batchEnc->SetZeroAll();
paddingEnc->SetZeroAll();
batchDec->SetZeroAll();
paddingDec->SetZeroAll();
batchEnc->grad->SetZeroAll();
paddingEnc->grad->SetZeroAll();
batchDec->grad->SetZeroAll();
paddingDec->grad->SetZeroAll();
int seqSize = 0;
......@@ -693,15 +706,18 @@ int T2TTrainer::LoadBatchLM(FILE * file,
int len = isDoubledEnd ? seqLen[s] : seqLen[s] - 1;
CheckNTErrors(len <= max, "Something is wrong!");
for(int w = 0; w < len; w++){
batch->Set3D(1.0F, s - seq, w, buf[seqOffset[s] + w]);
padding->Set2D(1.0F, s - seq, w);
if(w > 0)
output->Set3D(1.0F, s - seq, w - 1, buf[seqOffset[s] + w]);
batchEnc->Set3D(1.0F, s - seq, w, buf[seqOffset[s] + w]);
paddingEnc->Set2D(1.0F, s - seq, w);
if (w > 0) {
batchDec->Set3D(1.0F, s - seq, w - 1, buf[seqOffset[s] + w]);
paddingDec->Set2D(1.0F, s - seq, w - 1);
}
if(w == len - 1){
if(isDoubledEnd)
output->Set3D(1.0F, s - seq, w, buf[seqOffset[s] + w]);
batchDec->Set3D(1.0F, s - seq, w, buf[seqOffset[s] + w]);
else
output->Set3D(1.0F, s - seq, w, buf[seqOffset[s] + w + 1]);
batchDec->Set3D(1.0F, s - seq, w, buf[seqOffset[s] + w + 1]);
paddingDec->Set2D(1.0F, s - seq, w);
}
wCount++;
/*fprintf(tf, "%d", buf[seqOffset[s] + w]);
......@@ -727,9 +743,10 @@ int T2TTrainer::LoadBatchLM(FILE * file,
/*
load a batch of sequences (for MT)
>> file - the handle to the data file
>> batch - the batch of the input sequences
>> padding - padding of the input sequences
>> output - the batch of the output sequences
>> batchEnc - the batch of the input sequences
>> paddingEnc - padding of the input sequences
>> batchDec - the batch of the output sequences
>> paddingDec - padding of the output sequences
>> seqs - keep the sequences in an array
>> vsEnc - size of the encoder vocabulary
>> vsDec - size of the decoder vocabulary
......@@ -741,7 +758,8 @@ load a batch of sequences (for MT)
>> mem - memory pool
*/
int T2TTrainer::LoadBatchMT(FILE * file,
XTensor * batch, XTensor * padding, XTensor * output,
XTensor * batchEnc, XTensor * paddingEnc,
XTensor * batchDec, XTensor * paddingDec,
int * seqs,
int vsEnc, int vsDec, int sBatch, int wBatch,
bool isSorted, int &wCount,
......@@ -794,13 +812,15 @@ int T2TTrainer::LoadBatchMT(FILE * file,
int dimsEnc[3] = {sCount, maxEnc, vsEnc};
int dimsDec[3] = {sCount, maxDec, vsDec};
InitTensor(batch, 3, dimsEnc, X_FLOAT, 1.0F, devID, mem);
InitTensor2D(padding, sCount, maxDec, X_FLOAT, devID, mem);
InitTensor(output, 3, dimsDec, X_FLOAT, 1.0F, devID, mem);
InitTensor(batchEnc, 3, dimsEnc, X_FLOAT, 1.0F, devID, mem);
InitTensor2D(paddingEnc, sCount, maxEnc, X_FLOAT, devID, mem);
InitTensor(batchDec, 3, dimsDec, X_FLOAT, 1.0F, devID, mem);
InitTensor2D(paddingDec, sCount, maxDec, X_FLOAT, devID, mem);
batch->SetZeroAll();
padding->SetZeroAll();
output->SetZeroAll();
batchEnc->SetZeroAll();
paddingEnc->SetZeroAll();
batchDec->SetZeroAll();
paddingDec->SetZeroAll();
wCount = 0;
......@@ -809,7 +829,8 @@ int T2TTrainer::LoadBatchMT(FILE * file,
int len = seqLen[s];
int sent = (s - seq)/2;
for(int w = 0; w < len; w++){
batch->Set3D(1.0F, sent, w, buf[seqOffset[s] + w]);
batchEnc->Set3D(1.0F, sent, w, buf[seqOffset[s] + w]);
paddingEnc->Set2D(1.0F, sent, w);
wCount++;
}
}
......@@ -819,11 +840,11 @@ int T2TTrainer::LoadBatchMT(FILE * file,
int len = seqLen[s];
int sent = (s - seq - 1)/2;
for(int w = 0; w < len; w++){
padding->Set2D(1.0F, sent, w);
paddingDec->Set2D(1.0F, sent, w);
if(w > 0)
output->Set3D(1.0F, sent, w - 1, buf[seqOffset[s] + w]);
batchDec->Set3D(1.0F, sent, w - 1, buf[seqOffset[s] + w]);
if(w == len - 1)
output->Set3D(1.0F, sent, w, buf[seqOffset[s] + w]);
batchDec->Set3D(1.0F, sent, w, buf[seqOffset[s] + w]);
wCount++;
if(seqs != NULL)
......
......@@ -166,7 +166,8 @@ public:
/* load a batch of sequences */
int LoadBatch(FILE * file, bool isLM,
XTensor * batch, XTensor * padding, XTensor * output,
XTensor * batchEnc, XTensor * paddingEnc,
XTensor * batchDec, XTensor * paddingDec,
int * seqs,
int vsEnc, int vsDec, int sBatch, int wBatch,
bool isSorted, int &wCount,
......@@ -174,14 +175,16 @@ public:
/* load a batch of sequences (for language modeling) */
int LoadBatchLM(FILE * file,
XTensor * batch, XTensor * padding, XTensor * output,
XTensor * batchEnc, XTensor * paddingEnc,
XTensor * batchDec, XTensor * paddingDec,
int * seqs, int vs, int sBatch, int wBatch,
bool isSorted, int &wCount,
int devID, XMem * mem);
/* load a batch of sequences (for machine translation) */
int LoadBatchMT(FILE * file,
XTensor * batch, XTensor * padding, XTensor * output,
XTensor * batchEnc, XTensor * paddingEnc,
XTensor * batchDec, XTensor * paddingDec,
int * seqs, int vsEnc, int vsDec, int sBatch, int wBatch,
bool isSorted, int &wCount,
int devID, XMem * mem);
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论