Commit c177e3c5 by huchi

fix some bugs in beam search

parent e155205c
...@@ -43,6 +43,21 @@ int main( int argc, const char ** argv ) ...@@ -43,6 +43,21 @@ int main( int argc, const char ** argv )
_CrtSetBreakAlloc(2708);*/ _CrtSetBreakAlloc(2708);*/
TransformerMain(argc - 1, argv + 1); TransformerMain(argc - 1, argv + 1);
//XTensor singleScore, singleIdx, score;
//InitTensor3DV2(&score, 2, 1, 136160);
////score.SetDataRand(0, 1);
//InitTensor1DV2(&singleIdx, 1, X_INT);
//singleIdx.Set1DInt(1, 0);
//singleIdx.Dump(stderr);
//singleScore = Select(score, singleIdx, 0);
//XTensor s, i;
//InitTensor3DV2(&s, 2, 1, 4);
//InitTensor3DV2(&i, 2, 1, 4, X_INT);
//TopK(score, s, i, -1, 4);
//i.Dump(stderr, "single score:\n");
//_CrtDumpMemoryLeaks(); //_CrtDumpMemoryLeaks();
return 0; return 0;
......
...@@ -196,7 +196,7 @@ void T2TModel::MakeMT(XTensor& inputEnc, XTensor& inputDec, XTensor& output, XTe ...@@ -196,7 +196,7 @@ void T2TModel::MakeMT(XTensor& inputEnc, XTensor& inputDec, XTensor& output, XTe
MakeMTMaskEnc(inputEnc, paddingEnc, maskEnc); MakeMTMaskEnc(inputEnc, paddingEnc, maskEnc);
/* decoder mask */ /* decoder mask */
MakeMTMaskDec(inputEnc, inputDec, paddingEnc, paddingDec, maskDec, maskEncDec, 0); MakeMTMaskDec(inputEnc, inputDec, paddingEnc, paddingDec, maskDec, maskEncDec);
encoding = MakeEncoder(inputEnc, &maskEnc, isTraining); encoding = MakeEncoder(inputEnc, &maskEnc, isTraining);
...@@ -225,8 +225,8 @@ void T2TModel::MakeMTMask(XTensor& inputEnc, XTensor& inputDec, ...@@ -225,8 +225,8 @@ void T2TModel::MakeMTMask(XTensor& inputEnc, XTensor& inputDec,
dims[i + 1] = inputDec.GetDim(i); dims[i + 1] = inputDec.GetDim(i);
dims[0] = nhead; dims[0] = nhead;
dims[inputDec.order + 1] = len; dims[inputDec.order + 1] = len;
InitTensorV2(&maskDec, inputDec.order + 2, dims, X_FLOAT, 1.0F, paddingDec.devID); InitTensor(&maskDec, inputDec.order + 2, dims, X_FLOAT, paddingDec.devID);
/* an upper triangular matrix where the cells of the upper triangular are set to -1e-9. /* an upper triangular matrix where the cells of the upper triangular are set to -1e-9.
this matrix can be used to prevent the attention to current or following words in this matrix can be used to prevent the attention to current or following words in
a given sequence. */ a given sequence. */
...@@ -235,10 +235,10 @@ void T2TModel::MakeMTMask(XTensor& inputEnc, XTensor& inputDec, ...@@ -235,10 +235,10 @@ void T2TModel::MakeMTMask(XTensor& inputEnc, XTensor& inputDec,
/* encoder-decoder mask that prevents the attention to padding dummy words */ /* encoder-decoder mask that prevents the attention to padding dummy words */
dims[inputDec.order + 1] = inputEnc.GetDim(inputEnc.order - 1); dims[inputDec.order + 1] = inputEnc.GetDim(inputEnc.order - 1);
InitTensorV2(&maskEncDec, inputDec.order + 2, dims, X_FLOAT, 1.0F, paddingEnc.devID); InitTensor(&maskEncDec, inputDec.order + 2, dims, X_FLOAT, 1.0F, paddingEnc.devID);
XTensor* maskEncDecTMPEnc = NewTensorBufV2(paddingEnc.order + 1, dims + 1, paddingEnc.dataType, paddingEnc.devID); XTensor * maskEncDecTMPEnc = NewTensorBuf(paddingEnc.order + 1, dims + 1, paddingEnc.dataType, paddingEnc.devID);
XTensor* maskEncDecTMPDec = NewTensorBufV2(paddingEnc.order + 1, dims + 1, paddingEnc.dataType, paddingEnc.devID); XTensor * maskEncDecTMPDec = NewTensorBuf(maskEncDecTMPEnc, paddingEnc.devID);
_Unsqueeze(&paddingEnc, maskEncDecTMPEnc, paddingEnc.order - 1, paddingDec.GetDim(-1)); _Unsqueeze(&paddingEnc, maskEncDecTMPEnc, paddingEnc.order - 1, paddingDec.GetDim(-1));
_ScaleAndShiftMe(maskEncDecTMPEnc, 1e9F, -1e9F); _ScaleAndShiftMe(maskEncDecTMPEnc, 1e9F, -1e9F);
...@@ -254,15 +254,13 @@ void T2TModel::MakeMTMask(XTensor& inputEnc, XTensor& inputDec, ...@@ -254,15 +254,13 @@ void T2TModel::MakeMTMask(XTensor& inputEnc, XTensor& inputDec,
dimsPadding[paddingEnc.order - 1] = paddingEnc.GetDim(-1); dimsPadding[paddingEnc.order - 1] = paddingEnc.GetDim(-1);
dimsPadding[paddingEnc.order] = paddingEnc.GetDim(-1); dimsPadding[paddingEnc.order] = paddingEnc.GetDim(-1);
XTensor* padding2 = NewTensorBufV2(paddingEnc.order + 1, dimsPadding, paddingEnc.dataType, XTensor * padding2 = NewTensorBuf(paddingEnc.order + 1, dimsPadding, paddingEnc.dataType, paddingEnc.devID);
paddingEnc.devID);
for (int i = 0; i < padding2->order; i++) for (int i = 0; i < padding2->order; i++)
dimsPadding[i + 1] = padding2->GetDim(i); dimsPadding[i + 1] = padding2->GetDim(i);
dimsPadding[0] = nhead; dimsPadding[0] = nhead;
XTensor* padding3 = NewTensorBufV2(paddingEnc.order + 2, dimsPadding, paddingEnc.dataType, XTensor * padding3 = NewTensorBuf(paddingEnc.order + 2, dimsPadding, paddingEnc.dataType, paddingEnc.devID);
paddingEnc.devID);
/* mask of the padding */ /* mask of the padding */
_Unsqueeze(&paddingEnc, padding2, paddingEnc.order - 1, paddingEnc.GetDim(-1)); _Unsqueeze(&paddingEnc, padding2, paddingEnc.order - 1, paddingEnc.GetDim(-1));
...@@ -270,7 +268,7 @@ void T2TModel::MakeMTMask(XTensor& inputEnc, XTensor& inputDec, ...@@ -270,7 +268,7 @@ void T2TModel::MakeMTMask(XTensor& inputEnc, XTensor& inputDec,
_ScaleAndShiftMe(padding3, 1e9F, -1e9F); _ScaleAndShiftMe(padding3, 1e9F, -1e9F);
InitTensorV2(&maskEnc, padding3); InitTensor(&maskEnc, padding3);
maskEnc.SetZeroAll(); maskEnc.SetZeroAll();
/* generate the mask on the source language side (for padding) */ /* generate the mask on the source language side (for padding) */
...@@ -297,22 +295,22 @@ void T2TModel::MakeMTMaskEnc(XTensor& inputEnc, XTensor& paddingEnc, XTensor& ma ...@@ -297,22 +295,22 @@ void T2TModel::MakeMTMaskEnc(XTensor& inputEnc, XTensor& paddingEnc, XTensor& ma
dimsPadding[i] = paddingEnc.GetDim(i); dimsPadding[i] = paddingEnc.GetDim(i);
dimsPadding[paddingEnc.order - 1] = paddingEnc.GetDim(-1); dimsPadding[paddingEnc.order - 1] = paddingEnc.GetDim(-1);
dimsPadding[paddingEnc.order] = paddingEnc.GetDim(-1); dimsPadding[paddingEnc.order] = paddingEnc.GetDim(-1);
XTensor* padding2 = NewTensorBufV2(paddingEnc.order + 1, dimsPadding, paddingEnc.dataType, paddingEnc.devID); XTensor * padding2 = NewTensorBuf(paddingEnc.order + 1, dimsPadding, paddingEnc.dataType, paddingEnc.devID);
for (int i = 0; i < padding2->order; i++) for (int i = 0; i < padding2->order; i++)
dimsPadding[i + 1] = padding2->GetDim(i); dimsPadding[i + 1] = padding2->GetDim(i);
dimsPadding[0] = nhead; dimsPadding[0] = nhead;
XTensor* padding3 = NewTensorBufV2(paddingEnc.order + 2, dimsPadding, paddingEnc.dataType, paddingEnc.devID); XTensor* padding3 = NewTensorBuf(paddingEnc.order + 2, dimsPadding, paddingEnc.dataType, paddingEnc.devID);
/* mask of the padding */ /* mask of the padding */
_Unsqueeze(&paddingEnc, padding2, paddingEnc.order - 1, paddingEnc.GetDim(-1)); _Unsqueeze(&paddingEnc, padding2, paddingEnc.order - 1, paddingEnc.GetDim(-1));
_Unsqueeze(padding2, padding3, 0, nhead); _Unsqueeze(padding2, padding3, 0, nhead);
_ScaleAndShiftMe(padding3, 1e9F, -1e9F); _ScaleAndShiftMe(padding3, 1e9F, -1e9F);
InitTensorV2(&maskEnc, padding3); InitTensor(&maskEnc, padding3);
maskEnc.SetZeroAll(); maskEnc.SetZeroAll();
/* generate the mask on the source language side (for padding) */ /* generate the mask on the source language side (for padding) */
...@@ -332,33 +330,37 @@ make the mask of the decoder ...@@ -332,33 +330,37 @@ make the mask of the decoder
>> maksDec - mask of the decoder self-attention >> maksDec - mask of the decoder self-attention
>> maksEncDec - mask of the decoder enc-dec attention >> maksEncDec - mask of the decoder enc-dec attention
*/ */
void T2TModel::MakeMTMaskDec(XTensor& inputEnc, XTensor& inputDec, void T2TModel::MakeMTMaskDec(XTensor &inputEnc, XTensor &inputDec,
XTensor& paddingEnc, XTensor& paddingDec, XTensor &paddingEnc, XTensor &paddingDec,
XTensor& maskDec, XTensor& maskEncDec, int incDim) XTensor &maskDec, XTensor &maskEncDec)
{ {
int len = inputDec.GetDim(inputDec.order - 1); int len = inputDec.GetDim(inputDec.order - 1);
int* dims = new int[inputDec.order + 2]; int * dims = new int[inputDec.order + 2];
for (int i = 0; i < inputDec.order; i++) for(int i = 0; i < inputDec.order; i++)
dims[i + 1] = inputDec.GetDim(i); dims[i + 1] = inputDec.GetDim(i);
//dims[inputDec.order] += incDim;
dims[0] = nhead; dims[0] = nhead;
dims[inputDec.order + 1] = len; dims[inputDec.order + 1] = len;
//InitTensor(&maskDec, inputDec.order + 2, dims, X_FLOAT, 1.0F, paddingDec.devID, paddingDec); InitTensor(&maskDec, inputDec.order + 2, dims, X_FLOAT, paddingDec.devID);
/* An upper triangular matrix where the cells of the upper triangular are set to -1e-9. /* An upper triangular matrix where the cells of the upper triangular are set to -1e-9.
This matrix can be used to block the attention to current or following words in This matrix can be used to block the attention to current or following words in
a given sequence. */ a given sequence. */
//_SetDataLowTri(&maskDec, 1e9F, 0); _SetDataLowTri(&maskDec, 1e9F, 0);
//_ScaleAndShiftMe(&maskDec, 1.0F, -1e9F);
/* encoder-decoder mask that prevents the attention to padding dummy words */ //maskDec.Dump(stderr, "mask: ");
dims[inputDec.order + 1] = inputEnc.GetDim(inputEnc.order - 1);
InitTensorV2(&maskEncDec, inputDec.order + 2, dims, X_FLOAT, 1.0F, paddingEnc.devID);
XTensor* maskEncDecTMPEnc = NewTensorBufV2(paddingEnc.order + 1, dims + 1, paddingEnc.dataType, paddingEnc.devID); _ScaleAndShiftMe(&maskDec, 1.0F, -1e9F);
XTensor* maskEncDecTMPDec = NewTensorBufV2(paddingEnc.order + 1, dims + 1, paddingEnc.dataType, paddingEnc.devID);
//maskDec.Dump(stderr, "mask: ");
/* encoder-decoder mask that prevents the attention to padding dummy words */
dims[inputDec.order + 1] = inputEnc.GetDim(inputEnc.order - 1);
InitTensor(&maskEncDec, inputDec.order + 2, dims, X_FLOAT, paddingEnc.devID);
XTensor * maskEncDecTMPEnc = NewTensorBuf(paddingEnc.order + 1, dims + 1, paddingEnc.dataType,
paddingEnc.devID);
XTensor * maskEncDecTMPDec = NewTensorBuf(maskEncDecTMPEnc, paddingEnc.devID);
_Unsqueeze(&paddingEnc, maskEncDecTMPEnc, paddingEnc.order - 1, paddingDec.GetDim(-1)); _Unsqueeze(&paddingEnc, maskEncDecTMPEnc, paddingEnc.order - 1, paddingDec.GetDim(-1));
//paddingEnc.Dump(stderr, "paddingenc:"); //paddingEnc.Dump(stderr, "paddingenc:");
......
...@@ -90,9 +90,9 @@ public: ...@@ -90,9 +90,9 @@ public:
void MakeMTMaskEnc(XTensor &inputEnc, XTensor &paddingEnc, XTensor &maskEnc); void MakeMTMaskEnc(XTensor &inputEnc, XTensor &paddingEnc, XTensor &maskEnc);
/* make the mask of the decoder */ /* make the mask of the decoder */
void MakeMTMaskDec(XTensor &inputEnc, XTensor &inputDec, void MakeMTMaskDec(XTensor& inputEnc, XTensor& inputDec,
XTensor &paddingEnc, XTensor &paddingDec, XTensor& paddingEnc, XTensor& paddingDec,
XTensor &maskDec, XTensor &maskEncDec, int incDim); XTensor& maskDec, XTensor& maskEncDec);
/* get parameter matrics */ /* get parameter matrics */
void GetParams(TensorList &list); void GetParams(TensorList &list);
......
...@@ -166,7 +166,6 @@ void T2TPredictor::Predict(T2TStateBundle* next, XTensor* encoding, XTensor* inp ...@@ -166,7 +166,6 @@ void T2TPredictor::Predict(T2TStateBundle* next, XTensor* encoding, XTensor* inp
inputDec = GetLastPrediction(s); inputDec = GetLastPrediction(s);
inputDec.SetDevice(inputEnc->devID); inputDec.SetDevice(inputEnc->devID);
} }
inputDec.Dump(stderr, "inputDec");
/* prediction probabilities */ /* prediction probabilities */
XTensor& output = next->prob; XTensor& output = next->prob;
...@@ -184,10 +183,10 @@ void T2TPredictor::Predict(T2TStateBundle* next, XTensor* encoding, XTensor* inp ...@@ -184,10 +183,10 @@ void T2TPredictor::Predict(T2TStateBundle* next, XTensor* encoding, XTensor* inp
XTensor maskEncDec; XTensor maskEncDec;
/* decoder mask */ /* decoder mask */
//m->MakeMTMaskDec(*inputEnc, inputDec, *paddingEnc, paddingDec, maskDec, maskEncDec, 0); m->MakeMTMaskDec(*inputEnc, inputDec, *paddingEnc, paddingDec, maskDec, maskEncDec);
/* make the decoding network */ /* make the decoding network */
decoding = m->decoder->Make(inputDec, *encoding, NULL, maskEncDec, false); decoding = m->decoder->Make(inputDec, *encoding, &maskDec, maskEncDec, false);
CheckNTErrors(decoding.order >= 2, "The tensor must be of order 2 or larger!"); CheckNTErrors(decoding.order >= 2, "The tensor must be of order 2 or larger!");
......
...@@ -38,7 +38,7 @@ T2TSearch::T2TSearch() ...@@ -38,7 +38,7 @@ T2TSearch::T2TSearch()
endSymbolNum = 0; endSymbolNum = 0;
fullHypos = NULL; fullHypos = NULL;
endSymbols = new int[32]; endSymbols = new int[32];
startSymbol = 2; startSymbol = -1;
} }
/* de-constructor */ /* de-constructor */
...@@ -60,8 +60,10 @@ void T2TSearch::Init(int argc, char** argv) ...@@ -60,8 +60,10 @@ void T2TSearch::Init(int argc, char** argv)
LoadParamInt(argc, argv, "beamsize", &beamSize, 1); LoadParamInt(argc, argv, "beamsize", &beamSize, 1);
LoadParamInt(argc, argv, "batchsize", &batchSize, 1); LoadParamInt(argc, argv, "batchsize", &batchSize, 1);
LoadParamFloat(argc, argv, "lenalpha", &alpha, 1.0F); LoadParamFloat(argc, argv, "lenalpha", &alpha, 1.0F);
LoadParamInt(argc, argv, "endid", endSymbols, 2); LoadParamInt(argc, argv, "endid", endSymbols, -1);
LoadParamInt(argc, argv, "startid", &startSymbol, 2); LoadParamInt(argc, argv, "startid", &startSymbol, -1);
LoadParamFloat(argc, argv, "maxlenalpha", &scalarMaxLength, 2.0F);
LoadParamBool(argc, argv, "earlystop", &isEarlyStop, false);
if (endSymbols[0] >= 0) if (endSymbols[0] >= 0)
endSymbolNum = 1; endSymbolNum = 1;
...@@ -73,8 +75,10 @@ search for the most promising states ...@@ -73,8 +75,10 @@ search for the most promising states
>> input - input of the model >> input - input of the model
>> padding - padding of the input >> padding - padding of the input
>> output - output that represents the sequences as rows >> output - output that represents the sequences as rows
>> score - score of the sequences
*/ */
void T2TSearch::Search(T2TModel* model, XTensor* input, XTensor* padding, XTensor* output) void T2TSearch::Search(T2TModel * model, XTensor * input, XTensor * padding,
XTensor * output, XTensor * score)
{ {
T2TPredictor predictor; T2TPredictor predictor;
XTensor maskEnc; XTensor maskEnc;
...@@ -89,7 +93,7 @@ void T2TSearch::Search(T2TModel* model, XTensor* input, XTensor* padding, XTenso ...@@ -89,7 +93,7 @@ void T2TSearch::Search(T2TModel* model, XTensor* input, XTensor* padding, XTenso
Prepare(input->unitNum / input->GetDim(-1), beamSize); Prepare(input->unitNum / input->GetDim(-1), beamSize);
/* encoder mask */ /* encoder mask */
//model->MakeMTMaskEnc(*input, *padding, maskEnc); model->MakeMTMaskEnc(*input, *padding, maskEnc);
/* make the encoding network */ /* make the encoding network */
encoding = model->MakeEncoder(*input, &maskEnc, false); encoding = model->MakeEncoder(*input, &maskEnc, false);
...@@ -101,16 +105,17 @@ void T2TSearch::Search(T2TModel* model, XTensor* input, XTensor* padding, XTenso ...@@ -101,16 +105,17 @@ void T2TSearch::Search(T2TModel* model, XTensor* input, XTensor* padding, XTenso
encodingBeam.ReshapeMerged(encodingBeam.order - 4); encodingBeam.ReshapeMerged(encodingBeam.order - 4);
inputBeam.ReshapeMerged(inputBeam.order - 3); inputBeam.ReshapeMerged(inputBeam.order - 3);
paddingBeam.ReshapeMerged(paddingBeam.order - 3); paddingBeam.ReshapeMerged(paddingBeam.order - 3);
/* max output-length = 2 * source-length */ /* max output-length = scalar * source-length */
maxLength = input->GetDim(-1) * 2; int lengthLimit = (int)(input->GetDim(-1) * scalarMaxLength);
CheckNTErrors(maxLength > 0, "no max length specified!"); CheckNTErrors(lengthLimit > 0, "no max length specified!");
maxLength = lengthLimit;
T2TStateBundle* states = new T2TStateBundle[maxLength + 1];
T2TStateBundle* first = states; T2TStateBundle * states = new T2TStateBundle[lengthLimit + 1];
T2TStateBundle* cur; T2TStateBundle * first = states;
T2TStateBundle* next; T2TStateBundle * cur = NULL;
T2TStateBundle * next = NULL;
/* create the first state */ /* create the first state */
predictor.Create(model, &encodingBeam, input, beamSize, first); predictor.Create(model, &encodingBeam, input, beamSize, first);
predictor.SetStartSymbol(startSymbol); predictor.SetStartSymbol(startSymbol);
...@@ -118,15 +123,15 @@ void T2TSearch::Search(T2TModel* model, XTensor* input, XTensor* padding, XTenso ...@@ -118,15 +123,15 @@ void T2TSearch::Search(T2TModel* model, XTensor* input, XTensor* padding, XTenso
first->isStart = true; first->isStart = true;
/* generate the sequence from left to right */ /* generate the sequence from left to right */
for (int i = 0; i < maxLength; i++) { for(int l = 0 ; l < lengthLimit; l++){
cur = states + i; cur = states + l;
next = states + i + 1; next = states + l + 1;
/* read the current state */ /* read the current state */
predictor.Read(model, cur); predictor.Read(model, cur);
/* predict the next state */ /* predict the next state */
predictor.Predict(next, &encodingBeam, &inputBeam, &paddingBeam, i == 0); predictor.Predict(next, &encodingBeam, &inputBeam, &paddingBeam, l == 0);
/* compute the model score (given the prediction probability) */ /* compute the model score (given the prediction probability) */
Score(cur, next); Score(cur, next);
...@@ -139,12 +144,18 @@ void T2TSearch::Search(T2TModel* model, XTensor* input, XTensor* padding, XTenso ...@@ -139,12 +144,18 @@ void T2TSearch::Search(T2TModel* model, XTensor* input, XTensor* padding, XTenso
/* push complete hypotheses into the heap */ /* push complete hypotheses into the heap */
Collect(next); Collect(next);
/* stop searching when all hypotheses are completed */
if(IsAllCompleted(next)){
maxLength = l + 1;
break;
}
} }
/* fill the heap with imcomplete hypotheses if neccesary */ /* fill the heap with imcomplete hypotheses if neccesary */
FillHeap(next); FillHeap(next);
Dump(output); Dump(output, score);
delete[] states; delete[] states;
} }
...@@ -214,8 +225,7 @@ void T2TSearch::Score(T2TStateBundle* prev, T2TStateBundle* beam) ...@@ -214,8 +225,7 @@ void T2TSearch::Score(T2TStateBundle* prev, T2TStateBundle* beam)
_DivDim(&probPath, &lp, &score, 0); _DivDim(&probPath, &lp, &score, 0);
if (prev->isStart) { if (prev->isStart) {
XTensor firstMask; XTensor firstMask = MakeFirstMask(beam);
firstMask = MakeFirstMask(beam);
firstMask.Reshape(firstMask.unitNum); firstMask.Reshape(firstMask.unitNum);
/* mask the hypotheses in the beam except the first one */ /* mask the hypotheses in the beam except the first one */
...@@ -252,22 +262,24 @@ void T2TSearch::Generate(T2TStateBundle* beam) ...@@ -252,22 +262,24 @@ void T2TSearch::Generate(T2TStateBundle* beam)
int dimsTopK[MAX_TENSOR_DIM_NUM]; int dimsTopK[MAX_TENSOR_DIM_NUM];
XTensor scoreTopK; XTensor scoreTopK;
XTensor& score = beam->modelScore; XTensor indexCPU;
XTensor& index = beam->prediction; XTensor &score = beam->modelScore;
XTensor& preID = beam->preID; XTensor &index = beam->prediction;
XTensor& probPath = beam->probPath; XTensor &preID = beam->preID;
XTensor& prob = beam->prob; XTensor &probPath = beam->probPath;
XTensor &prob = beam->prob;
int order = score.order; int order = score.order;
CheckNTErrors(order >= 3, "The tensor must be of order 2 or larger.");
CheckNTErrors(dimsBeam[order - 3] % beamSize == 0, "Wrong dimension size!");
for (int i = 0; i < order; i++) { for (int i = 0; i < order; i++) {
dims[i] = score.GetDim(i); dims[i] = score.GetDim(i);
dimsBeam[i] = score.GetDim(i); dimsBeam[i] = score.GetDim(i);
dimsTopK[i] = score.GetDim(i); dimsTopK[i] = score.GetDim(i);
} }
CheckNTErrors(order >= 3, "The tensor must be of order 2 or larger.");
CheckNTErrors(dimsBeam[order - 3] % beamSize == 0, "Wrong dimension size!");
int sizeVocab = score.GetDim(-1); int sizeVocab = score.GetDim(-1);
int stride = score.GetDim(-1); int stride = score.GetDim(-1);
...@@ -279,23 +291,23 @@ void T2TSearch::Generate(T2TStateBundle* beam) ...@@ -279,23 +291,23 @@ void T2TSearch::Generate(T2TStateBundle* beam)
InitTensorV2(&scoreTopK, order, dimsTopK, score.dataType, 1.0F, score.devID); InitTensorV2(&scoreTopK, order, dimsTopK, score.dataType, 1.0F, score.devID);
InitTensorV2(&index, order, dimsTopK, X_INT, 1.0F, score.devID); InitTensorV2(&index, order, dimsTopK, X_INT, 1.0F, score.devID);
InitTensorV2(&preID, order, dimsTopK, X_INT, 1.0F, -1); InitTensorV2(&preID, order, dimsTopK, X_INT, 1.0F, -1);
InitTensorV2(&indexCPU, order, dimsTopK, X_INT, 1.0F, -1);
/* mask the first and the padding id */ /* TODO: check the mask - mask the first and the padding id */
int dimMask[]{ score.GetDim(-1) }; /*int dimMask[]{ score.GetDim(-1) };
XTensor mask; XTensor mask;
InitTensorV2(&mask, 1, dimMask, X_FLOAT, 1.0F, -1); InitTensorV2(&mask, 1, dimMask, X_FLOAT, 1.0F, -1);
mask.SetZeroAll(); mask.SetZeroAll();
mask.Set1D(-1e9F, 0); mask.Set1D(-1e9F, 0);
mask.Set1D(-1e9F, 1); mask.Set1D(-1e9F, 1);
mask.SetDevice(score.devID, score.mem); mask.SetDevice(score.devID);
_SumDim(&score, &mask, 2);*/
_SumDim(&score, &mask, 2);
score.Reshape(order, dimsBeam); score.Reshape(order, dimsBeam);
/* keep the most promissing candidates in the beam */ /* keep the most promissing candidates in the beam */
/* TODO: check this line */
TopK(score, scoreTopK, index, -1, beamSize); TopK(score, scoreTopK, index, -1, beamSize);
CopyValues(index, indexCPU);
CopyValues(index, preID); CopyValues(index, preID);
/* "preID" represents the id (or the offset) of the previous state used to make the current /* "preID" represents the id (or the offset) of the previous state used to make the current
...@@ -317,12 +329,10 @@ void T2TSearch::Generate(T2TStateBundle* beam) ...@@ -317,12 +329,10 @@ void T2TSearch::Generate(T2TStateBundle* beam)
CopyValues(scoreTopK, score); CopyValues(scoreTopK, score);
/* CPU data (TODO: remove GPU->CPU data copy!!!) */ /* CPU data (TODO: remove GPU->CPU data copy!!!) */
XTensor indexGPU; for (int i = 0; i < indexCPU.unitNum; i += beamSize){
indexGPU = CopyValues(index); for (int j = 0; j < beamSize; j++) {
indexCPU.SetInt(i * stride + indexCPU.GetInt(i + j), i + j);
for (int i = 0; i < indexGPU.unitNum; i += beamSize) { }
for (int j = 0; j < beamSize; j++)
indexGPU.SetInt(i * stride + indexGPU.GetInt(i + j), i + j);
} }
CheckNTErrors(IsSameShaped(prob, probPath), "Wrong tensor shape!"); CheckNTErrors(IsSameShaped(prob, probPath), "Wrong tensor shape!");
...@@ -339,13 +349,13 @@ void T2TSearch::Generate(T2TStateBundle* beam) ...@@ -339,13 +349,13 @@ void T2TSearch::Generate(T2TStateBundle* beam)
} }
order = probPath.order; order = probPath.order;
probPath.Reshape(1, probPath.unitNum);
probPathTopK.Reshape(1, probPathTopK.unitNum);
prob.Reshape(1, prob.unitNum);
probTopK.Reshape(1, probTopK.unitNum);
_CopyIndexed(&probPath, &probPathTopK, probPathTopK.order - 1, &indexGPU); prob.Reshape(prob.unitNum, 1);
_CopyIndexed(&prob, &probTopK, probTopK.order - 1, &indexGPU); probPath.Reshape(probPath.unitNum, 1);
indexCPU.Reshape(indexCPU.GetDim(0), indexCPU.GetDim(-1));
probTopK = Gather(prob, indexCPU);
probPathTopK = Gather(probPath, indexCPU);
probPath.Reshape(order, dims); probPath.Reshape(order, dims);
probPathTopK.Reshape(order, dimsTopK); probPathTopK.Reshape(order, dimsTopK);
...@@ -440,7 +450,9 @@ void T2TSearch::Expand(T2TStateBundle* prev, T2TStateBundle* beam) ...@@ -440,7 +450,9 @@ void T2TSearch::Expand(T2TStateBundle* prev, T2TStateBundle* beam)
CheckNTErrors(state.prediction >= 0, "Illegal prediction!"); CheckNTErrors(state.prediction >= 0, "Illegal prediction!");
/* check if it is the end of the sequence */ /* check if it is the end of the sequence */
state.isEnd = IsEnd(state.prediction); state.isEnd = IsEnd(state.prediction);
state.isCompleted = (state.isCompleted || state.isEnd); state.isCompleted = (state.isCompleted || state.isEnd);
/* set the ending mark */ /* set the ending mark */
...@@ -471,8 +483,9 @@ void T2TSearch::Collect(T2TStateBundle* beam) ...@@ -471,8 +483,9 @@ void T2TSearch::Collect(T2TStateBundle* beam)
bool isCompleted = state.isCompleted && (state.last == NULL || !state.last->isCompleted); bool isCompleted = state.isCompleted && (state.last == NULL || !state.last->isCompleted);
/* we push the hypothesis into the heap when it is completed */ /* we push the hypothesis into the heap when it is completed */
if (state.isEnd != 0) if (state.isEnd && isCompleted) {
fullHypos[state.pid].Push(HeapNode<float>(&state, state.modelScore)); fullHypos[state.pid].Push(HeapNode<float>(&state, state.modelScore));
}
} }
} }
...@@ -492,10 +505,14 @@ void T2TSearch::FillHeap(T2TStateBundle* beam) ...@@ -492,10 +505,14 @@ void T2TSearch::FillHeap(T2TStateBundle* beam)
T2TState& state = states[i]; T2TState& state = states[i];
CheckNTErrors(state.pid >= 0 && state.pid < batchSize, CheckNTErrors(state.pid >= 0 && state.pid < batchSize,
"Invalid sample id!"); "Invalid sample id!");
/* check if this is the first end symbol. It is false
if there have been end symbols in previously generated words. */
bool isCompleted = state.isCompleted && (state.last == NULL || !state.last->isCompleted);
/* we push the imcomplete hypothesis into the heap */ /* we push the imcomplete hypothesis into the heap */
if (emptyFlags[state.pid] && state.isEnd == 0) if (emptyFlags[state.pid] || state.isEnd || isCompleted)
fullHypos[state.pid].Push(HeapNode<float>(&state, state.modelScore)); fullHypos[state.pid].Push(HeapNode<float>(&state, state.modelScore));
} }
...@@ -505,24 +522,29 @@ void T2TSearch::FillHeap(T2TStateBundle* beam) ...@@ -505,24 +522,29 @@ void T2TSearch::FillHeap(T2TStateBundle* beam)
/* /*
save the output sequences in a tensor save the output sequences in a tensor
>> output - output sequences (for return) >> output - output sequences (for return)
>> score - score of thes sequences
*/ */
void T2TSearch::Dump(XTensor* output) void T2TSearch::Dump(XTensor * output, XTensor * score)
{ {
int dims[3] = { batchSize, beamSize, maxLength }; int dims[3] = { batchSize, beamSize, maxLength };
int* words = new int[maxLength]; int* words = new int[maxLength];
InitTensorV2(output, 3, dims, X_INT); InitTensorV2(output, 3, dims, X_INT);
InitTensorV2(score, 2, dims, X_FLOAT);
SetDataFixedInt(*output, -1); SetDataFixedInt(*output, -1);
score->SetZeroAll();
/* heap for an input sentence in the batch */ /* heap for an input sentence in the batch */
for (int h = 0; h < batchSize; h++) { for (int h = 0; h < batchSize; h++) {
XHeap<MIN_HEAP, float>& heap = fullHypos[h]; XHeap<MIN_HEAP, float> &heap = fullHypos[h];
int c = heap.Count();
/* for each output in the beam */ /* for each output in the beam */
for (int i = 0; i < beamSize && heap.Count() > 0; i++) { for(int i = 0; i < beamSize && heap.Count() > 0; i++){
T2TState* state = (T2TState*)heap.Pop().index; HeapNode<float> node = heap.Pop();
T2TState * state = (T2TState *)node.index;
int count = 0; int count = 0;
bool isCompleted = true; bool isCompleted = true;
...@@ -531,15 +553,17 @@ void T2TSearch::Dump(XTensor* output) ...@@ -531,15 +553,17 @@ void T2TSearch::Dump(XTensor* output)
if (!state->isCompleted) if (!state->isCompleted)
isCompleted = false; isCompleted = false;
if (isCompleted) if (isCompleted)
words[count++] = -1; words[count++] = 2;
else else
words[count++] = state->prediction; words[count++] = state->prediction;
state = state->last; state = state->last;
} }
/* dump the sentence to the output tensor */ /* dump the sentence to the output tensor */
for (int w = 0; w < count; w++) for(int w = 0; w < count; w++)
output->Set3DInt(words[count - w - 1], h, beamSize - i - 1, w); output->Set3DInt(words[count - w - 1], h, c - i - 1, w);
score->Set2D(node.value, h, c - i - 1);
} }
} }
...@@ -582,6 +606,23 @@ void T2TSearch::SetEnd(const int* tokens, const int tokenNum) ...@@ -582,6 +606,23 @@ void T2TSearch::SetEnd(const int* tokens, const int tokenNum)
endSymbolNum = tokenNum; endSymbolNum = tokenNum;
} }
/*
check whether all hypotheses are completed
>> beam - the beam that keeps the searching states
*/
bool T2TSearch::IsAllCompleted(T2TStateBundle * beam)
{
T2TState * states = beam->states;
for (int i = 0; i < beam->stateNum; i++) {
T2TState & state = states[i];
if(!state.isCompleted)
return false;
}
return true;
}
/* /*
make a mask to prevent duplicated entries in beam expansion for the first position make a mask to prevent duplicated entries in beam expansion for the first position
>> beam - the beam that keeps the searching states >> beam - the beam that keeps the searching states
......
...@@ -62,6 +62,12 @@ private: ...@@ -62,6 +62,12 @@ private:
/* start symbol */ /* start symbol */
int startSymbol; int startSymbol;
/* scalar of the input sequence (for max number of search steps) */
float scalarMaxLength;
/* indicate whether the early stop strategy is used */
bool isEarlyStop;
public: public:
/* constructor */ /* constructor */
T2TSearch(); T2TSearch();
...@@ -73,7 +79,7 @@ public: ...@@ -73,7 +79,7 @@ public:
void Init(int argc, char** argv); void Init(int argc, char** argv);
/* search for the most promising states */ /* search for the most promising states */
void Search(T2TModel* model, XTensor* input, XTensor* padding, XTensor* output); void Search(T2TModel* model, XTensor* input, XTensor* padding, XTensor* output, XTensor* score);
/* preparation */ /* preparation */
void Prepare(int myBatchSize, int myBeamSize); void Prepare(int myBatchSize, int myBeamSize);
...@@ -93,12 +99,15 @@ public: ...@@ -93,12 +99,15 @@ public:
/* fill the hypotheis heap with incomplete hypothses */ /* fill the hypotheis heap with incomplete hypothses */
void FillHeap(T2TStateBundle* beam); void FillHeap(T2TStateBundle* beam);
/* save the output sequences in a tensor */ /* save the output sequences and score */
void Dump(XTensor* output); void Dump(XTensor* output, XTensor* score);
/* check if the token is an end symbol */ /* check if the token is an end symbol */
bool IsEnd(int token); bool IsEnd(int token);
/*check whether all hypotheses are completed*/
bool IsAllCompleted(T2TStateBundle* beam);
/* set end symbols for search */ /* set end symbols for search */
void SetEnd(const int* tokens, const int tokenNum); void SetEnd(const int* tokens, const int tokenNum);
......
...@@ -101,9 +101,10 @@ void T2TTester::Test(const char* fn, const char* ofn, T2TModel* model) ...@@ -101,9 +101,10 @@ void T2TTester::Test(const char* fn, const char* ofn, T2TModel* model)
vector<int> indices = batchLoader.LoadBatch(&batchEnc, &paddingEnc, sentBatch, devID); vector<int> indices = batchLoader.LoadBatch(&batchEnc, &paddingEnc, sentBatch, devID);
XTensor output; XTensor output;
XTensor score;
seacher.Search(model, &batchEnc, &paddingEnc, &output);
output.Dump(stderr); seacher.Search(model, &batchEnc, &paddingEnc, &output, &score);
for (int i = 0; i < indices.size(); ++i) { for (int i = 0; i < indices.size(); ++i) {
Result res; Result res;
XTensor sent, srcIdx, tgtIdx; XTensor sent, srcIdx, tgtIdx;
...@@ -127,9 +128,7 @@ void T2TTester::Test(const char* fn, const char* ofn, T2TModel* model) ...@@ -127,9 +128,7 @@ void T2TTester::Test(const char* fn, const char* ofn, T2TModel* model)
if (batchCount % 1 == 0) { if (batchCount % 1 == 0) {
double elapsed = GetClockSec() - startT; double elapsed = GetClockSec() - startT;
XPRINT3(0, stderr, XPRINT3(0, stderr, "[INFO] elapsed=%.1fs, sentence=%d, sword=%d\n", elapsed, sentCount, wordCount);
"[INFO] elapsed=%.1fs, sentence=%d, sword=%d\n",
elapsed, sentCount, wordCount);
} }
} }
...@@ -160,9 +159,10 @@ void T2TTester::Dump(FILE* file, XTensor* output) ...@@ -160,9 +159,10 @@ void T2TTester::Dump(FILE* file, XTensor* output)
for (int i = 0; i < output->unitNum; i += seqLength) { for (int i = 0; i < output->unitNum; i += seqLength) {
for (int j = 0; j < seqLength; j++) { for (int j = 0; j < seqLength; j++) {
int w = output->GetInt(i + j); int w = output->GetInt(i + j);
fprintf(file, "%d ", w); if (w < 0 || w == 1)
if (w < 0)
break; break;
fprintf(file, "%d ", w);
} }
fprintf(file, "\n"); fprintf(file, "\n");
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论