Commit c177e3c5 by huchi

fix some bugs in beam search

parent e155205c
...@@ -43,6 +43,21 @@ int main( int argc, const char ** argv ) ...@@ -43,6 +43,21 @@ int main( int argc, const char ** argv )
_CrtSetBreakAlloc(2708);*/ _CrtSetBreakAlloc(2708);*/
TransformerMain(argc - 1, argv + 1); TransformerMain(argc - 1, argv + 1);
//XTensor singleScore, singleIdx, score;
//InitTensor3DV2(&score, 2, 1, 136160);
////score.SetDataRand(0, 1);
//InitTensor1DV2(&singleIdx, 1, X_INT);
//singleIdx.Set1DInt(1, 0);
//singleIdx.Dump(stderr);
//singleScore = Select(score, singleIdx, 0);
//XTensor s, i;
//InitTensor3DV2(&s, 2, 1, 4);
//InitTensor3DV2(&i, 2, 1, 4, X_INT);
//TopK(score, s, i, -1, 4);
//i.Dump(stderr, "single score:\n");
//_CrtDumpMemoryLeaks(); //_CrtDumpMemoryLeaks();
return 0; return 0;
......
...@@ -196,7 +196,7 @@ void T2TModel::MakeMT(XTensor& inputEnc, XTensor& inputDec, XTensor& output, XTe ...@@ -196,7 +196,7 @@ void T2TModel::MakeMT(XTensor& inputEnc, XTensor& inputDec, XTensor& output, XTe
MakeMTMaskEnc(inputEnc, paddingEnc, maskEnc); MakeMTMaskEnc(inputEnc, paddingEnc, maskEnc);
/* decoder mask */ /* decoder mask */
MakeMTMaskDec(inputEnc, inputDec, paddingEnc, paddingDec, maskDec, maskEncDec, 0); MakeMTMaskDec(inputEnc, inputDec, paddingEnc, paddingDec, maskDec, maskEncDec);
encoding = MakeEncoder(inputEnc, &maskEnc, isTraining); encoding = MakeEncoder(inputEnc, &maskEnc, isTraining);
...@@ -225,7 +225,7 @@ void T2TModel::MakeMTMask(XTensor& inputEnc, XTensor& inputDec, ...@@ -225,7 +225,7 @@ void T2TModel::MakeMTMask(XTensor& inputEnc, XTensor& inputDec,
dims[i + 1] = inputDec.GetDim(i); dims[i + 1] = inputDec.GetDim(i);
dims[0] = nhead; dims[0] = nhead;
dims[inputDec.order + 1] = len; dims[inputDec.order + 1] = len;
InitTensorV2(&maskDec, inputDec.order + 2, dims, X_FLOAT, 1.0F, paddingDec.devID); InitTensor(&maskDec, inputDec.order + 2, dims, X_FLOAT, paddingDec.devID);
/* an upper triangular matrix where the cells of the upper triangular are set to -1e-9. /* an upper triangular matrix where the cells of the upper triangular are set to -1e-9.
this matrix can be used to prevent the attention to current or following words in this matrix can be used to prevent the attention to current or following words in
...@@ -235,10 +235,10 @@ void T2TModel::MakeMTMask(XTensor& inputEnc, XTensor& inputDec, ...@@ -235,10 +235,10 @@ void T2TModel::MakeMTMask(XTensor& inputEnc, XTensor& inputDec,
/* encoder-decoder mask that prevents the attention to padding dummy words */ /* encoder-decoder mask that prevents the attention to padding dummy words */
dims[inputDec.order + 1] = inputEnc.GetDim(inputEnc.order - 1); dims[inputDec.order + 1] = inputEnc.GetDim(inputEnc.order - 1);
InitTensorV2(&maskEncDec, inputDec.order + 2, dims, X_FLOAT, 1.0F, paddingEnc.devID); InitTensor(&maskEncDec, inputDec.order + 2, dims, X_FLOAT, 1.0F, paddingEnc.devID);
XTensor* maskEncDecTMPEnc = NewTensorBufV2(paddingEnc.order + 1, dims + 1, paddingEnc.dataType, paddingEnc.devID); XTensor * maskEncDecTMPEnc = NewTensorBuf(paddingEnc.order + 1, dims + 1, paddingEnc.dataType, paddingEnc.devID);
XTensor* maskEncDecTMPDec = NewTensorBufV2(paddingEnc.order + 1, dims + 1, paddingEnc.dataType, paddingEnc.devID); XTensor * maskEncDecTMPDec = NewTensorBuf(maskEncDecTMPEnc, paddingEnc.devID);
_Unsqueeze(&paddingEnc, maskEncDecTMPEnc, paddingEnc.order - 1, paddingDec.GetDim(-1)); _Unsqueeze(&paddingEnc, maskEncDecTMPEnc, paddingEnc.order - 1, paddingDec.GetDim(-1));
_ScaleAndShiftMe(maskEncDecTMPEnc, 1e9F, -1e9F); _ScaleAndShiftMe(maskEncDecTMPEnc, 1e9F, -1e9F);
...@@ -254,15 +254,13 @@ void T2TModel::MakeMTMask(XTensor& inputEnc, XTensor& inputDec, ...@@ -254,15 +254,13 @@ void T2TModel::MakeMTMask(XTensor& inputEnc, XTensor& inputDec,
dimsPadding[paddingEnc.order - 1] = paddingEnc.GetDim(-1); dimsPadding[paddingEnc.order - 1] = paddingEnc.GetDim(-1);
dimsPadding[paddingEnc.order] = paddingEnc.GetDim(-1); dimsPadding[paddingEnc.order] = paddingEnc.GetDim(-1);
XTensor* padding2 = NewTensorBufV2(paddingEnc.order + 1, dimsPadding, paddingEnc.dataType, XTensor * padding2 = NewTensorBuf(paddingEnc.order + 1, dimsPadding, paddingEnc.dataType, paddingEnc.devID);
paddingEnc.devID);
for (int i = 0; i < padding2->order; i++) for (int i = 0; i < padding2->order; i++)
dimsPadding[i + 1] = padding2->GetDim(i); dimsPadding[i + 1] = padding2->GetDim(i);
dimsPadding[0] = nhead; dimsPadding[0] = nhead;
XTensor* padding3 = NewTensorBufV2(paddingEnc.order + 2, dimsPadding, paddingEnc.dataType, XTensor * padding3 = NewTensorBuf(paddingEnc.order + 2, dimsPadding, paddingEnc.dataType, paddingEnc.devID);
paddingEnc.devID);
/* mask of the padding */ /* mask of the padding */
_Unsqueeze(&paddingEnc, padding2, paddingEnc.order - 1, paddingEnc.GetDim(-1)); _Unsqueeze(&paddingEnc, padding2, paddingEnc.order - 1, paddingEnc.GetDim(-1));
...@@ -270,7 +268,7 @@ void T2TModel::MakeMTMask(XTensor& inputEnc, XTensor& inputDec, ...@@ -270,7 +268,7 @@ void T2TModel::MakeMTMask(XTensor& inputEnc, XTensor& inputDec,
_ScaleAndShiftMe(padding3, 1e9F, -1e9F); _ScaleAndShiftMe(padding3, 1e9F, -1e9F);
InitTensorV2(&maskEnc, padding3); InitTensor(&maskEnc, padding3);
maskEnc.SetZeroAll(); maskEnc.SetZeroAll();
/* generate the mask on the source language side (for padding) */ /* generate the mask on the source language side (for padding) */
...@@ -298,13 +296,13 @@ void T2TModel::MakeMTMaskEnc(XTensor& inputEnc, XTensor& paddingEnc, XTensor& ma ...@@ -298,13 +296,13 @@ void T2TModel::MakeMTMaskEnc(XTensor& inputEnc, XTensor& paddingEnc, XTensor& ma
dimsPadding[paddingEnc.order - 1] = paddingEnc.GetDim(-1); dimsPadding[paddingEnc.order - 1] = paddingEnc.GetDim(-1);
dimsPadding[paddingEnc.order] = paddingEnc.GetDim(-1); dimsPadding[paddingEnc.order] = paddingEnc.GetDim(-1);
XTensor* padding2 = NewTensorBufV2(paddingEnc.order + 1, dimsPadding, paddingEnc.dataType, paddingEnc.devID); XTensor * padding2 = NewTensorBuf(paddingEnc.order + 1, dimsPadding, paddingEnc.dataType, paddingEnc.devID);
for (int i = 0; i < padding2->order; i++) for (int i = 0; i < padding2->order; i++)
dimsPadding[i + 1] = padding2->GetDim(i); dimsPadding[i + 1] = padding2->GetDim(i);
dimsPadding[0] = nhead; dimsPadding[0] = nhead;
XTensor* padding3 = NewTensorBufV2(paddingEnc.order + 2, dimsPadding, paddingEnc.dataType, paddingEnc.devID); XTensor* padding3 = NewTensorBuf(paddingEnc.order + 2, dimsPadding, paddingEnc.dataType, paddingEnc.devID);
/* mask of the padding */ /* mask of the padding */
_Unsqueeze(&paddingEnc, padding2, paddingEnc.order - 1, paddingEnc.GetDim(-1)); _Unsqueeze(&paddingEnc, padding2, paddingEnc.order - 1, paddingEnc.GetDim(-1));
...@@ -312,7 +310,7 @@ void T2TModel::MakeMTMaskEnc(XTensor& inputEnc, XTensor& paddingEnc, XTensor& ma ...@@ -312,7 +310,7 @@ void T2TModel::MakeMTMaskEnc(XTensor& inputEnc, XTensor& paddingEnc, XTensor& ma
_ScaleAndShiftMe(padding3, 1e9F, -1e9F); _ScaleAndShiftMe(padding3, 1e9F, -1e9F);
InitTensorV2(&maskEnc, padding3); InitTensor(&maskEnc, padding3);
maskEnc.SetZeroAll(); maskEnc.SetZeroAll();
/* generate the mask on the source language side (for padding) */ /* generate the mask on the source language side (for padding) */
...@@ -332,32 +330,36 @@ make the mask of the decoder ...@@ -332,32 +330,36 @@ make the mask of the decoder
>> maksDec - mask of the decoder self-attention >> maksDec - mask of the decoder self-attention
>> maksEncDec - mask of the decoder enc-dec attention >> maksEncDec - mask of the decoder enc-dec attention
*/ */
void T2TModel::MakeMTMaskDec(XTensor& inputEnc, XTensor& inputDec, void T2TModel::MakeMTMaskDec(XTensor &inputEnc, XTensor &inputDec,
XTensor& paddingEnc, XTensor& paddingDec, XTensor &paddingEnc, XTensor &paddingDec,
XTensor& maskDec, XTensor& maskEncDec, int incDim) XTensor &maskDec, XTensor &maskEncDec)
{ {
int len = inputDec.GetDim(inputDec.order - 1); int len = inputDec.GetDim(inputDec.order - 1);
int* dims = new int[inputDec.order + 2]; int * dims = new int[inputDec.order + 2];
for (int i = 0; i < inputDec.order; i++) for(int i = 0; i < inputDec.order; i++)
dims[i + 1] = inputDec.GetDim(i); dims[i + 1] = inputDec.GetDim(i);
//dims[inputDec.order] += incDim;
dims[0] = nhead; dims[0] = nhead;
dims[inputDec.order + 1] = len; dims[inputDec.order + 1] = len;
//InitTensor(&maskDec, inputDec.order + 2, dims, X_FLOAT, 1.0F, paddingDec.devID, paddingDec); InitTensor(&maskDec, inputDec.order + 2, dims, X_FLOAT, paddingDec.devID);
/* An upper triangular matrix where the cells of the upper triangular are set to -1e-9. /* An upper triangular matrix where the cells of the upper triangular are set to -1e-9.
This matrix can be used to block the attention to current or following words in This matrix can be used to block the attention to current or following words in
a given sequence. */ a given sequence. */
//_SetDataLowTri(&maskDec, 1e9F, 0); _SetDataLowTri(&maskDec, 1e9F, 0);
//maskDec.Dump(stderr, "mask: ");
//_ScaleAndShiftMe(&maskDec, 1.0F, -1e9F); _ScaleAndShiftMe(&maskDec, 1.0F, -1e9F);
//maskDec.Dump(stderr, "mask: ");
/* encoder-decoder mask that prevents the attention to padding dummy words */ /* encoder-decoder mask that prevents the attention to padding dummy words */
dims[inputDec.order + 1] = inputEnc.GetDim(inputEnc.order - 1); dims[inputDec.order + 1] = inputEnc.GetDim(inputEnc.order - 1);
InitTensorV2(&maskEncDec, inputDec.order + 2, dims, X_FLOAT, 1.0F, paddingEnc.devID); InitTensor(&maskEncDec, inputDec.order + 2, dims, X_FLOAT, paddingEnc.devID);
XTensor* maskEncDecTMPEnc = NewTensorBufV2(paddingEnc.order + 1, dims + 1, paddingEnc.dataType, paddingEnc.devID); XTensor * maskEncDecTMPEnc = NewTensorBuf(paddingEnc.order + 1, dims + 1, paddingEnc.dataType,
XTensor* maskEncDecTMPDec = NewTensorBufV2(paddingEnc.order + 1, dims + 1, paddingEnc.dataType, paddingEnc.devID); paddingEnc.devID);
XTensor * maskEncDecTMPDec = NewTensorBuf(maskEncDecTMPEnc, paddingEnc.devID);
_Unsqueeze(&paddingEnc, maskEncDecTMPEnc, paddingEnc.order - 1, paddingDec.GetDim(-1)); _Unsqueeze(&paddingEnc, maskEncDecTMPEnc, paddingEnc.order - 1, paddingDec.GetDim(-1));
......
...@@ -90,9 +90,9 @@ public: ...@@ -90,9 +90,9 @@ public:
void MakeMTMaskEnc(XTensor &inputEnc, XTensor &paddingEnc, XTensor &maskEnc); void MakeMTMaskEnc(XTensor &inputEnc, XTensor &paddingEnc, XTensor &maskEnc);
/* make the mask of the decoder */ /* make the mask of the decoder */
void MakeMTMaskDec(XTensor &inputEnc, XTensor &inputDec, void MakeMTMaskDec(XTensor& inputEnc, XTensor& inputDec,
XTensor &paddingEnc, XTensor &paddingDec, XTensor& paddingEnc, XTensor& paddingDec,
XTensor &maskDec, XTensor &maskEncDec, int incDim); XTensor& maskDec, XTensor& maskEncDec);
/* get parameter matrics */ /* get parameter matrics */
void GetParams(TensorList &list); void GetParams(TensorList &list);
......
...@@ -166,7 +166,6 @@ void T2TPredictor::Predict(T2TStateBundle* next, XTensor* encoding, XTensor* inp ...@@ -166,7 +166,6 @@ void T2TPredictor::Predict(T2TStateBundle* next, XTensor* encoding, XTensor* inp
inputDec = GetLastPrediction(s); inputDec = GetLastPrediction(s);
inputDec.SetDevice(inputEnc->devID); inputDec.SetDevice(inputEnc->devID);
} }
inputDec.Dump(stderr, "inputDec");
/* prediction probabilities */ /* prediction probabilities */
XTensor& output = next->prob; XTensor& output = next->prob;
...@@ -184,10 +183,10 @@ void T2TPredictor::Predict(T2TStateBundle* next, XTensor* encoding, XTensor* inp ...@@ -184,10 +183,10 @@ void T2TPredictor::Predict(T2TStateBundle* next, XTensor* encoding, XTensor* inp
XTensor maskEncDec; XTensor maskEncDec;
/* decoder mask */ /* decoder mask */
//m->MakeMTMaskDec(*inputEnc, inputDec, *paddingEnc, paddingDec, maskDec, maskEncDec, 0); m->MakeMTMaskDec(*inputEnc, inputDec, *paddingEnc, paddingDec, maskDec, maskEncDec);
/* make the decoding network */ /* make the decoding network */
decoding = m->decoder->Make(inputDec, *encoding, NULL, maskEncDec, false); decoding = m->decoder->Make(inputDec, *encoding, &maskDec, maskEncDec, false);
CheckNTErrors(decoding.order >= 2, "The tensor must be of order 2 or larger!"); CheckNTErrors(decoding.order >= 2, "The tensor must be of order 2 or larger!");
......
...@@ -62,6 +62,12 @@ private: ...@@ -62,6 +62,12 @@ private:
/* start symbol */ /* start symbol */
int startSymbol; int startSymbol;
/* scalar of the input sequence (for max number of search steps) */
float scalarMaxLength;
/* indicate whether the early stop strategy is used */
bool isEarlyStop;
public: public:
/* constructor */ /* constructor */
T2TSearch(); T2TSearch();
...@@ -73,7 +79,7 @@ public: ...@@ -73,7 +79,7 @@ public:
void Init(int argc, char** argv); void Init(int argc, char** argv);
/* search for the most promising states */ /* search for the most promising states */
void Search(T2TModel* model, XTensor* input, XTensor* padding, XTensor* output); void Search(T2TModel* model, XTensor* input, XTensor* padding, XTensor* output, XTensor* score);
/* preparation */ /* preparation */
void Prepare(int myBatchSize, int myBeamSize); void Prepare(int myBatchSize, int myBeamSize);
...@@ -93,12 +99,15 @@ public: ...@@ -93,12 +99,15 @@ public:
/* fill the hypotheis heap with incomplete hypothses */ /* fill the hypotheis heap with incomplete hypothses */
void FillHeap(T2TStateBundle* beam); void FillHeap(T2TStateBundle* beam);
/* save the output sequences in a tensor */ /* save the output sequences and score */
void Dump(XTensor* output); void Dump(XTensor* output, XTensor* score);
/* check if the token is an end symbol */ /* check if the token is an end symbol */
bool IsEnd(int token); bool IsEnd(int token);
/*check whether all hypotheses are completed*/
bool IsAllCompleted(T2TStateBundle* beam);
/* set end symbols for search */ /* set end symbols for search */
void SetEnd(const int* tokens, const int tokenNum); void SetEnd(const int* tokens, const int tokenNum);
......
...@@ -101,9 +101,10 @@ void T2TTester::Test(const char* fn, const char* ofn, T2TModel* model) ...@@ -101,9 +101,10 @@ void T2TTester::Test(const char* fn, const char* ofn, T2TModel* model)
vector<int> indices = batchLoader.LoadBatch(&batchEnc, &paddingEnc, sentBatch, devID); vector<int> indices = batchLoader.LoadBatch(&batchEnc, &paddingEnc, sentBatch, devID);
XTensor output; XTensor output;
XTensor score;
seacher.Search(model, &batchEnc, &paddingEnc, &output, &score);
seacher.Search(model, &batchEnc, &paddingEnc, &output);
output.Dump(stderr);
for (int i = 0; i < indices.size(); ++i) { for (int i = 0; i < indices.size(); ++i) {
Result res; Result res;
XTensor sent, srcIdx, tgtIdx; XTensor sent, srcIdx, tgtIdx;
...@@ -127,9 +128,7 @@ void T2TTester::Test(const char* fn, const char* ofn, T2TModel* model) ...@@ -127,9 +128,7 @@ void T2TTester::Test(const char* fn, const char* ofn, T2TModel* model)
if (batchCount % 1 == 0) { if (batchCount % 1 == 0) {
double elapsed = GetClockSec() - startT; double elapsed = GetClockSec() - startT;
XPRINT3(0, stderr, XPRINT3(0, stderr, "[INFO] elapsed=%.1fs, sentence=%d, sword=%d\n", elapsed, sentCount, wordCount);
"[INFO] elapsed=%.1fs, sentence=%d, sword=%d\n",
elapsed, sentCount, wordCount);
} }
} }
...@@ -160,9 +159,10 @@ void T2TTester::Dump(FILE* file, XTensor* output) ...@@ -160,9 +159,10 @@ void T2TTester::Dump(FILE* file, XTensor* output)
for (int i = 0; i < output->unitNum; i += seqLength) { for (int i = 0; i < output->unitNum; i += seqLength) {
for (int j = 0; j < seqLength; j++) { for (int j = 0; j < seqLength; j++) {
int w = output->GetInt(i + j); int w = output->GetInt(i + j);
fprintf(file, "%d ", w); if (w < 0 || w == 1)
if (w < 0)
break; break;
fprintf(file, "%d ", w);
} }
fprintf(file, "\n"); fprintf(file, "\n");
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论