Commit c177e3c5 by huchi

fix some bugs in beam search

parent e155205c
...@@ -43,6 +43,21 @@ int main( int argc, const char ** argv ) ...@@ -43,6 +43,21 @@ int main( int argc, const char ** argv )
_CrtSetBreakAlloc(2708);*/ _CrtSetBreakAlloc(2708);*/
TransformerMain(argc - 1, argv + 1); TransformerMain(argc - 1, argv + 1);
//XTensor singleScore, singleIdx, score;
//InitTensor3DV2(&score, 2, 1, 136160);
////score.SetDataRand(0, 1);
//InitTensor1DV2(&singleIdx, 1, X_INT);
//singleIdx.Set1DInt(1, 0);
//singleIdx.Dump(stderr);
//singleScore = Select(score, singleIdx, 0);
//XTensor s, i;
//InitTensor3DV2(&s, 2, 1, 4);
//InitTensor3DV2(&i, 2, 1, 4, X_INT);
//TopK(score, s, i, -1, 4);
//i.Dump(stderr, "single score:\n");
//_CrtDumpMemoryLeaks(); //_CrtDumpMemoryLeaks();
return 0; return 0;
......
...@@ -196,7 +196,7 @@ void T2TModel::MakeMT(XTensor& inputEnc, XTensor& inputDec, XTensor& output, XTe ...@@ -196,7 +196,7 @@ void T2TModel::MakeMT(XTensor& inputEnc, XTensor& inputDec, XTensor& output, XTe
MakeMTMaskEnc(inputEnc, paddingEnc, maskEnc); MakeMTMaskEnc(inputEnc, paddingEnc, maskEnc);
/* decoder mask */ /* decoder mask */
MakeMTMaskDec(inputEnc, inputDec, paddingEnc, paddingDec, maskDec, maskEncDec, 0); MakeMTMaskDec(inputEnc, inputDec, paddingEnc, paddingDec, maskDec, maskEncDec);
encoding = MakeEncoder(inputEnc, &maskEnc, isTraining); encoding = MakeEncoder(inputEnc, &maskEnc, isTraining);
...@@ -225,8 +225,8 @@ void T2TModel::MakeMTMask(XTensor& inputEnc, XTensor& inputDec, ...@@ -225,8 +225,8 @@ void T2TModel::MakeMTMask(XTensor& inputEnc, XTensor& inputDec,
dims[i + 1] = inputDec.GetDim(i); dims[i + 1] = inputDec.GetDim(i);
dims[0] = nhead; dims[0] = nhead;
dims[inputDec.order + 1] = len; dims[inputDec.order + 1] = len;
InitTensorV2(&maskDec, inputDec.order + 2, dims, X_FLOAT, 1.0F, paddingDec.devID); InitTensor(&maskDec, inputDec.order + 2, dims, X_FLOAT, paddingDec.devID);
/* an upper triangular matrix where the cells of the upper triangular are set to -1e-9. /* an upper triangular matrix where the cells of the upper triangular are set to -1e-9.
this matrix can be used to prevent the attention to current or following words in this matrix can be used to prevent the attention to current or following words in
a given sequence. */ a given sequence. */
...@@ -235,10 +235,10 @@ void T2TModel::MakeMTMask(XTensor& inputEnc, XTensor& inputDec, ...@@ -235,10 +235,10 @@ void T2TModel::MakeMTMask(XTensor& inputEnc, XTensor& inputDec,
/* encoder-decoder mask that prevents the attention to padding dummy words */ /* encoder-decoder mask that prevents the attention to padding dummy words */
dims[inputDec.order + 1] = inputEnc.GetDim(inputEnc.order - 1); dims[inputDec.order + 1] = inputEnc.GetDim(inputEnc.order - 1);
InitTensorV2(&maskEncDec, inputDec.order + 2, dims, X_FLOAT, 1.0F, paddingEnc.devID); InitTensor(&maskEncDec, inputDec.order + 2, dims, X_FLOAT, 1.0F, paddingEnc.devID);
XTensor* maskEncDecTMPEnc = NewTensorBufV2(paddingEnc.order + 1, dims + 1, paddingEnc.dataType, paddingEnc.devID); XTensor * maskEncDecTMPEnc = NewTensorBuf(paddingEnc.order + 1, dims + 1, paddingEnc.dataType, paddingEnc.devID);
XTensor* maskEncDecTMPDec = NewTensorBufV2(paddingEnc.order + 1, dims + 1, paddingEnc.dataType, paddingEnc.devID); XTensor * maskEncDecTMPDec = NewTensorBuf(maskEncDecTMPEnc, paddingEnc.devID);
_Unsqueeze(&paddingEnc, maskEncDecTMPEnc, paddingEnc.order - 1, paddingDec.GetDim(-1)); _Unsqueeze(&paddingEnc, maskEncDecTMPEnc, paddingEnc.order - 1, paddingDec.GetDim(-1));
_ScaleAndShiftMe(maskEncDecTMPEnc, 1e9F, -1e9F); _ScaleAndShiftMe(maskEncDecTMPEnc, 1e9F, -1e9F);
...@@ -254,15 +254,13 @@ void T2TModel::MakeMTMask(XTensor& inputEnc, XTensor& inputDec, ...@@ -254,15 +254,13 @@ void T2TModel::MakeMTMask(XTensor& inputEnc, XTensor& inputDec,
dimsPadding[paddingEnc.order - 1] = paddingEnc.GetDim(-1); dimsPadding[paddingEnc.order - 1] = paddingEnc.GetDim(-1);
dimsPadding[paddingEnc.order] = paddingEnc.GetDim(-1); dimsPadding[paddingEnc.order] = paddingEnc.GetDim(-1);
XTensor* padding2 = NewTensorBufV2(paddingEnc.order + 1, dimsPadding, paddingEnc.dataType, XTensor * padding2 = NewTensorBuf(paddingEnc.order + 1, dimsPadding, paddingEnc.dataType, paddingEnc.devID);
paddingEnc.devID);
for (int i = 0; i < padding2->order; i++) for (int i = 0; i < padding2->order; i++)
dimsPadding[i + 1] = padding2->GetDim(i); dimsPadding[i + 1] = padding2->GetDim(i);
dimsPadding[0] = nhead; dimsPadding[0] = nhead;
XTensor* padding3 = NewTensorBufV2(paddingEnc.order + 2, dimsPadding, paddingEnc.dataType, XTensor * padding3 = NewTensorBuf(paddingEnc.order + 2, dimsPadding, paddingEnc.dataType, paddingEnc.devID);
paddingEnc.devID);
/* mask of the padding */ /* mask of the padding */
_Unsqueeze(&paddingEnc, padding2, paddingEnc.order - 1, paddingEnc.GetDim(-1)); _Unsqueeze(&paddingEnc, padding2, paddingEnc.order - 1, paddingEnc.GetDim(-1));
...@@ -270,7 +268,7 @@ void T2TModel::MakeMTMask(XTensor& inputEnc, XTensor& inputDec, ...@@ -270,7 +268,7 @@ void T2TModel::MakeMTMask(XTensor& inputEnc, XTensor& inputDec,
_ScaleAndShiftMe(padding3, 1e9F, -1e9F); _ScaleAndShiftMe(padding3, 1e9F, -1e9F);
InitTensorV2(&maskEnc, padding3); InitTensor(&maskEnc, padding3);
maskEnc.SetZeroAll(); maskEnc.SetZeroAll();
/* generate the mask on the source language side (for padding) */ /* generate the mask on the source language side (for padding) */
...@@ -297,22 +295,22 @@ void T2TModel::MakeMTMaskEnc(XTensor& inputEnc, XTensor& paddingEnc, XTensor& ma ...@@ -297,22 +295,22 @@ void T2TModel::MakeMTMaskEnc(XTensor& inputEnc, XTensor& paddingEnc, XTensor& ma
dimsPadding[i] = paddingEnc.GetDim(i); dimsPadding[i] = paddingEnc.GetDim(i);
dimsPadding[paddingEnc.order - 1] = paddingEnc.GetDim(-1); dimsPadding[paddingEnc.order - 1] = paddingEnc.GetDim(-1);
dimsPadding[paddingEnc.order] = paddingEnc.GetDim(-1); dimsPadding[paddingEnc.order] = paddingEnc.GetDim(-1);
XTensor* padding2 = NewTensorBufV2(paddingEnc.order + 1, dimsPadding, paddingEnc.dataType, paddingEnc.devID); XTensor * padding2 = NewTensorBuf(paddingEnc.order + 1, dimsPadding, paddingEnc.dataType, paddingEnc.devID);
for (int i = 0; i < padding2->order; i++) for (int i = 0; i < padding2->order; i++)
dimsPadding[i + 1] = padding2->GetDim(i); dimsPadding[i + 1] = padding2->GetDim(i);
dimsPadding[0] = nhead; dimsPadding[0] = nhead;
XTensor* padding3 = NewTensorBufV2(paddingEnc.order + 2, dimsPadding, paddingEnc.dataType, paddingEnc.devID); XTensor* padding3 = NewTensorBuf(paddingEnc.order + 2, dimsPadding, paddingEnc.dataType, paddingEnc.devID);
/* mask of the padding */ /* mask of the padding */
_Unsqueeze(&paddingEnc, padding2, paddingEnc.order - 1, paddingEnc.GetDim(-1)); _Unsqueeze(&paddingEnc, padding2, paddingEnc.order - 1, paddingEnc.GetDim(-1));
_Unsqueeze(padding2, padding3, 0, nhead); _Unsqueeze(padding2, padding3, 0, nhead);
_ScaleAndShiftMe(padding3, 1e9F, -1e9F); _ScaleAndShiftMe(padding3, 1e9F, -1e9F);
InitTensorV2(&maskEnc, padding3); InitTensor(&maskEnc, padding3);
maskEnc.SetZeroAll(); maskEnc.SetZeroAll();
/* generate the mask on the source language side (for padding) */ /* generate the mask on the source language side (for padding) */
...@@ -332,33 +330,37 @@ make the mask of the decoder ...@@ -332,33 +330,37 @@ make the mask of the decoder
>> maksDec - mask of the decoder self-attention >> maksDec - mask of the decoder self-attention
>> maksEncDec - mask of the decoder enc-dec attention >> maksEncDec - mask of the decoder enc-dec attention
*/ */
void T2TModel::MakeMTMaskDec(XTensor& inputEnc, XTensor& inputDec, void T2TModel::MakeMTMaskDec(XTensor &inputEnc, XTensor &inputDec,
XTensor& paddingEnc, XTensor& paddingDec, XTensor &paddingEnc, XTensor &paddingDec,
XTensor& maskDec, XTensor& maskEncDec, int incDim) XTensor &maskDec, XTensor &maskEncDec)
{ {
int len = inputDec.GetDim(inputDec.order - 1); int len = inputDec.GetDim(inputDec.order - 1);
int* dims = new int[inputDec.order + 2]; int * dims = new int[inputDec.order + 2];
for (int i = 0; i < inputDec.order; i++) for(int i = 0; i < inputDec.order; i++)
dims[i + 1] = inputDec.GetDim(i); dims[i + 1] = inputDec.GetDim(i);
//dims[inputDec.order] += incDim;
dims[0] = nhead; dims[0] = nhead;
dims[inputDec.order + 1] = len; dims[inputDec.order + 1] = len;
//InitTensor(&maskDec, inputDec.order + 2, dims, X_FLOAT, 1.0F, paddingDec.devID, paddingDec); InitTensor(&maskDec, inputDec.order + 2, dims, X_FLOAT, paddingDec.devID);
/* An upper triangular matrix where the cells of the upper triangular are set to -1e-9. /* An upper triangular matrix where the cells of the upper triangular are set to -1e-9.
This matrix can be used to block the attention to current or following words in This matrix can be used to block the attention to current or following words in
a given sequence. */ a given sequence. */
//_SetDataLowTri(&maskDec, 1e9F, 0); _SetDataLowTri(&maskDec, 1e9F, 0);
//_ScaleAndShiftMe(&maskDec, 1.0F, -1e9F);
/* encoder-decoder mask that prevents the attention to padding dummy words */ //maskDec.Dump(stderr, "mask: ");
dims[inputDec.order + 1] = inputEnc.GetDim(inputEnc.order - 1);
InitTensorV2(&maskEncDec, inputDec.order + 2, dims, X_FLOAT, 1.0F, paddingEnc.devID);
XTensor* maskEncDecTMPEnc = NewTensorBufV2(paddingEnc.order + 1, dims + 1, paddingEnc.dataType, paddingEnc.devID); _ScaleAndShiftMe(&maskDec, 1.0F, -1e9F);
XTensor* maskEncDecTMPDec = NewTensorBufV2(paddingEnc.order + 1, dims + 1, paddingEnc.dataType, paddingEnc.devID);
//maskDec.Dump(stderr, "mask: ");
/* encoder-decoder mask that prevents the attention to padding dummy words */
dims[inputDec.order + 1] = inputEnc.GetDim(inputEnc.order - 1);
InitTensor(&maskEncDec, inputDec.order + 2, dims, X_FLOAT, paddingEnc.devID);
XTensor * maskEncDecTMPEnc = NewTensorBuf(paddingEnc.order + 1, dims + 1, paddingEnc.dataType,
paddingEnc.devID);
XTensor * maskEncDecTMPDec = NewTensorBuf(maskEncDecTMPEnc, paddingEnc.devID);
_Unsqueeze(&paddingEnc, maskEncDecTMPEnc, paddingEnc.order - 1, paddingDec.GetDim(-1)); _Unsqueeze(&paddingEnc, maskEncDecTMPEnc, paddingEnc.order - 1, paddingDec.GetDim(-1));
//paddingEnc.Dump(stderr, "paddingenc:"); //paddingEnc.Dump(stderr, "paddingenc:");
......
...@@ -90,9 +90,9 @@ public: ...@@ -90,9 +90,9 @@ public:
void MakeMTMaskEnc(XTensor &inputEnc, XTensor &paddingEnc, XTensor &maskEnc); void MakeMTMaskEnc(XTensor &inputEnc, XTensor &paddingEnc, XTensor &maskEnc);
/* make the mask of the decoder */ /* make the mask of the decoder */
void MakeMTMaskDec(XTensor &inputEnc, XTensor &inputDec, void MakeMTMaskDec(XTensor& inputEnc, XTensor& inputDec,
XTensor &paddingEnc, XTensor &paddingDec, XTensor& paddingEnc, XTensor& paddingDec,
XTensor &maskDec, XTensor &maskEncDec, int incDim); XTensor& maskDec, XTensor& maskEncDec);
/* get parameter matrics */ /* get parameter matrics */
void GetParams(TensorList &list); void GetParams(TensorList &list);
......
...@@ -166,7 +166,6 @@ void T2TPredictor::Predict(T2TStateBundle* next, XTensor* encoding, XTensor* inp ...@@ -166,7 +166,6 @@ void T2TPredictor::Predict(T2TStateBundle* next, XTensor* encoding, XTensor* inp
inputDec = GetLastPrediction(s); inputDec = GetLastPrediction(s);
inputDec.SetDevice(inputEnc->devID); inputDec.SetDevice(inputEnc->devID);
} }
inputDec.Dump(stderr, "inputDec");
/* prediction probabilities */ /* prediction probabilities */
XTensor& output = next->prob; XTensor& output = next->prob;
...@@ -184,10 +183,10 @@ void T2TPredictor::Predict(T2TStateBundle* next, XTensor* encoding, XTensor* inp ...@@ -184,10 +183,10 @@ void T2TPredictor::Predict(T2TStateBundle* next, XTensor* encoding, XTensor* inp
XTensor maskEncDec; XTensor maskEncDec;
/* decoder mask */ /* decoder mask */
//m->MakeMTMaskDec(*inputEnc, inputDec, *paddingEnc, paddingDec, maskDec, maskEncDec, 0); m->MakeMTMaskDec(*inputEnc, inputDec, *paddingEnc, paddingDec, maskDec, maskEncDec);
/* make the decoding network */ /* make the decoding network */
decoding = m->decoder->Make(inputDec, *encoding, NULL, maskEncDec, false); decoding = m->decoder->Make(inputDec, *encoding, &maskDec, maskEncDec, false);
CheckNTErrors(decoding.order >= 2, "The tensor must be of order 2 or larger!"); CheckNTErrors(decoding.order >= 2, "The tensor must be of order 2 or larger!");
......
...@@ -62,6 +62,12 @@ private: ...@@ -62,6 +62,12 @@ private:
/* start symbol */ /* start symbol */
int startSymbol; int startSymbol;
/* scalar of the input sequence (for max number of search steps) */
float scalarMaxLength;
/* indicate whether the early stop strategy is used */
bool isEarlyStop;
public: public:
/* constructor */ /* constructor */
T2TSearch(); T2TSearch();
...@@ -73,7 +79,7 @@ public: ...@@ -73,7 +79,7 @@ public:
void Init(int argc, char** argv); void Init(int argc, char** argv);
/* search for the most promising states */ /* search for the most promising states */
void Search(T2TModel* model, XTensor* input, XTensor* padding, XTensor* output); void Search(T2TModel* model, XTensor* input, XTensor* padding, XTensor* output, XTensor* score);
/* preparation */ /* preparation */
void Prepare(int myBatchSize, int myBeamSize); void Prepare(int myBatchSize, int myBeamSize);
...@@ -93,12 +99,15 @@ public: ...@@ -93,12 +99,15 @@ public:
/* fill the hypotheis heap with incomplete hypothses */ /* fill the hypotheis heap with incomplete hypothses */
void FillHeap(T2TStateBundle* beam); void FillHeap(T2TStateBundle* beam);
/* save the output sequences in a tensor */ /* save the output sequences and score */
void Dump(XTensor* output); void Dump(XTensor* output, XTensor* score);
/* check if the token is an end symbol */ /* check if the token is an end symbol */
bool IsEnd(int token); bool IsEnd(int token);
/*check whether all hypotheses are completed*/
bool IsAllCompleted(T2TStateBundle* beam);
/* set end symbols for search */ /* set end symbols for search */
void SetEnd(const int* tokens, const int tokenNum); void SetEnd(const int* tokens, const int tokenNum);
......
...@@ -101,9 +101,10 @@ void T2TTester::Test(const char* fn, const char* ofn, T2TModel* model) ...@@ -101,9 +101,10 @@ void T2TTester::Test(const char* fn, const char* ofn, T2TModel* model)
vector<int> indices = batchLoader.LoadBatch(&batchEnc, &paddingEnc, sentBatch, devID); vector<int> indices = batchLoader.LoadBatch(&batchEnc, &paddingEnc, sentBatch, devID);
XTensor output; XTensor output;
XTensor score;
seacher.Search(model, &batchEnc, &paddingEnc, &output);
output.Dump(stderr); seacher.Search(model, &batchEnc, &paddingEnc, &output, &score);
for (int i = 0; i < indices.size(); ++i) { for (int i = 0; i < indices.size(); ++i) {
Result res; Result res;
XTensor sent, srcIdx, tgtIdx; XTensor sent, srcIdx, tgtIdx;
...@@ -127,9 +128,7 @@ void T2TTester::Test(const char* fn, const char* ofn, T2TModel* model) ...@@ -127,9 +128,7 @@ void T2TTester::Test(const char* fn, const char* ofn, T2TModel* model)
if (batchCount % 1 == 0) { if (batchCount % 1 == 0) {
double elapsed = GetClockSec() - startT; double elapsed = GetClockSec() - startT;
XPRINT3(0, stderr, XPRINT3(0, stderr, "[INFO] elapsed=%.1fs, sentence=%d, sword=%d\n", elapsed, sentCount, wordCount);
"[INFO] elapsed=%.1fs, sentence=%d, sword=%d\n",
elapsed, sentCount, wordCount);
} }
} }
...@@ -160,9 +159,10 @@ void T2TTester::Dump(FILE* file, XTensor* output) ...@@ -160,9 +159,10 @@ void T2TTester::Dump(FILE* file, XTensor* output)
for (int i = 0; i < output->unitNum; i += seqLength) { for (int i = 0; i < output->unitNum; i += seqLength) {
for (int j = 0; j < seqLength; j++) { for (int j = 0; j < seqLength; j++) {
int w = output->GetInt(i + j); int w = output->GetInt(i + j);
fprintf(file, "%d ", w); if (w < 0 || w == 1)
if (w < 0)
break; break;
fprintf(file, "%d ", w);
} }
fprintf(file, "\n"); fprintf(file, "\n");
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论