Commit c177e3c5 by huchi

fix some bugs in beam search

parent e155205c
...@@ -43,6 +43,21 @@ int main( int argc, const char ** argv ) ...@@ -43,6 +43,21 @@ int main( int argc, const char ** argv )
_CrtSetBreakAlloc(2708);*/ _CrtSetBreakAlloc(2708);*/
TransformerMain(argc - 1, argv + 1); TransformerMain(argc - 1, argv + 1);
//XTensor singleScore, singleIdx, score;
//InitTensor3DV2(&score, 2, 1, 136160);
////score.SetDataRand(0, 1);
//InitTensor1DV2(&singleIdx, 1, X_INT);
//singleIdx.Set1DInt(1, 0);
//singleIdx.Dump(stderr);
//singleScore = Select(score, singleIdx, 0);
//XTensor s, i;
//InitTensor3DV2(&s, 2, 1, 4);
//InitTensor3DV2(&i, 2, 1, 4, X_INT);
//TopK(score, s, i, -1, 4);
//i.Dump(stderr, "single score:\n");
//_CrtDumpMemoryLeaks(); //_CrtDumpMemoryLeaks();
return 0; return 0;
......
...@@ -196,7 +196,7 @@ void T2TModel::MakeMT(XTensor& inputEnc, XTensor& inputDec, XTensor& output, XTe ...@@ -196,7 +196,7 @@ void T2TModel::MakeMT(XTensor& inputEnc, XTensor& inputDec, XTensor& output, XTe
MakeMTMaskEnc(inputEnc, paddingEnc, maskEnc); MakeMTMaskEnc(inputEnc, paddingEnc, maskEnc);
/* decoder mask */ /* decoder mask */
MakeMTMaskDec(inputEnc, inputDec, paddingEnc, paddingDec, maskDec, maskEncDec, 0); MakeMTMaskDec(inputEnc, inputDec, paddingEnc, paddingDec, maskDec, maskEncDec);
encoding = MakeEncoder(inputEnc, &maskEnc, isTraining); encoding = MakeEncoder(inputEnc, &maskEnc, isTraining);
...@@ -225,7 +225,7 @@ void T2TModel::MakeMTMask(XTensor& inputEnc, XTensor& inputDec, ...@@ -225,7 +225,7 @@ void T2TModel::MakeMTMask(XTensor& inputEnc, XTensor& inputDec,
dims[i + 1] = inputDec.GetDim(i); dims[i + 1] = inputDec.GetDim(i);
dims[0] = nhead; dims[0] = nhead;
dims[inputDec.order + 1] = len; dims[inputDec.order + 1] = len;
InitTensorV2(&maskDec, inputDec.order + 2, dims, X_FLOAT, 1.0F, paddingDec.devID); InitTensor(&maskDec, inputDec.order + 2, dims, X_FLOAT, paddingDec.devID);
/* an upper triangular matrix where the cells of the upper triangular are set to -1e-9. /* an upper triangular matrix where the cells of the upper triangular are set to -1e-9.
this matrix can be used to prevent the attention to current or following words in this matrix can be used to prevent the attention to current or following words in
...@@ -235,10 +235,10 @@ void T2TModel::MakeMTMask(XTensor& inputEnc, XTensor& inputDec, ...@@ -235,10 +235,10 @@ void T2TModel::MakeMTMask(XTensor& inputEnc, XTensor& inputDec,
/* encoder-decoder mask that prevents the attention to padding dummy words */ /* encoder-decoder mask that prevents the attention to padding dummy words */
dims[inputDec.order + 1] = inputEnc.GetDim(inputEnc.order - 1); dims[inputDec.order + 1] = inputEnc.GetDim(inputEnc.order - 1);
InitTensorV2(&maskEncDec, inputDec.order + 2, dims, X_FLOAT, 1.0F, paddingEnc.devID); InitTensor(&maskEncDec, inputDec.order + 2, dims, X_FLOAT, 1.0F, paddingEnc.devID);
XTensor* maskEncDecTMPEnc = NewTensorBufV2(paddingEnc.order + 1, dims + 1, paddingEnc.dataType, paddingEnc.devID); XTensor * maskEncDecTMPEnc = NewTensorBuf(paddingEnc.order + 1, dims + 1, paddingEnc.dataType, paddingEnc.devID);
XTensor* maskEncDecTMPDec = NewTensorBufV2(paddingEnc.order + 1, dims + 1, paddingEnc.dataType, paddingEnc.devID); XTensor * maskEncDecTMPDec = NewTensorBuf(maskEncDecTMPEnc, paddingEnc.devID);
_Unsqueeze(&paddingEnc, maskEncDecTMPEnc, paddingEnc.order - 1, paddingDec.GetDim(-1)); _Unsqueeze(&paddingEnc, maskEncDecTMPEnc, paddingEnc.order - 1, paddingDec.GetDim(-1));
_ScaleAndShiftMe(maskEncDecTMPEnc, 1e9F, -1e9F); _ScaleAndShiftMe(maskEncDecTMPEnc, 1e9F, -1e9F);
...@@ -254,15 +254,13 @@ void T2TModel::MakeMTMask(XTensor& inputEnc, XTensor& inputDec, ...@@ -254,15 +254,13 @@ void T2TModel::MakeMTMask(XTensor& inputEnc, XTensor& inputDec,
dimsPadding[paddingEnc.order - 1] = paddingEnc.GetDim(-1); dimsPadding[paddingEnc.order - 1] = paddingEnc.GetDim(-1);
dimsPadding[paddingEnc.order] = paddingEnc.GetDim(-1); dimsPadding[paddingEnc.order] = paddingEnc.GetDim(-1);
XTensor* padding2 = NewTensorBufV2(paddingEnc.order + 1, dimsPadding, paddingEnc.dataType, XTensor * padding2 = NewTensorBuf(paddingEnc.order + 1, dimsPadding, paddingEnc.dataType, paddingEnc.devID);
paddingEnc.devID);
for (int i = 0; i < padding2->order; i++) for (int i = 0; i < padding2->order; i++)
dimsPadding[i + 1] = padding2->GetDim(i); dimsPadding[i + 1] = padding2->GetDim(i);
dimsPadding[0] = nhead; dimsPadding[0] = nhead;
XTensor* padding3 = NewTensorBufV2(paddingEnc.order + 2, dimsPadding, paddingEnc.dataType, XTensor * padding3 = NewTensorBuf(paddingEnc.order + 2, dimsPadding, paddingEnc.dataType, paddingEnc.devID);
paddingEnc.devID);
/* mask of the padding */ /* mask of the padding */
_Unsqueeze(&paddingEnc, padding2, paddingEnc.order - 1, paddingEnc.GetDim(-1)); _Unsqueeze(&paddingEnc, padding2, paddingEnc.order - 1, paddingEnc.GetDim(-1));
...@@ -270,7 +268,7 @@ void T2TModel::MakeMTMask(XTensor& inputEnc, XTensor& inputDec, ...@@ -270,7 +268,7 @@ void T2TModel::MakeMTMask(XTensor& inputEnc, XTensor& inputDec,
_ScaleAndShiftMe(padding3, 1e9F, -1e9F); _ScaleAndShiftMe(padding3, 1e9F, -1e9F);
InitTensorV2(&maskEnc, padding3); InitTensor(&maskEnc, padding3);
maskEnc.SetZeroAll(); maskEnc.SetZeroAll();
/* generate the mask on the source language side (for padding) */ /* generate the mask on the source language side (for padding) */
...@@ -298,13 +296,13 @@ void T2TModel::MakeMTMaskEnc(XTensor& inputEnc, XTensor& paddingEnc, XTensor& ma ...@@ -298,13 +296,13 @@ void T2TModel::MakeMTMaskEnc(XTensor& inputEnc, XTensor& paddingEnc, XTensor& ma
dimsPadding[paddingEnc.order - 1] = paddingEnc.GetDim(-1); dimsPadding[paddingEnc.order - 1] = paddingEnc.GetDim(-1);
dimsPadding[paddingEnc.order] = paddingEnc.GetDim(-1); dimsPadding[paddingEnc.order] = paddingEnc.GetDim(-1);
XTensor* padding2 = NewTensorBufV2(paddingEnc.order + 1, dimsPadding, paddingEnc.dataType, paddingEnc.devID); XTensor * padding2 = NewTensorBuf(paddingEnc.order + 1, dimsPadding, paddingEnc.dataType, paddingEnc.devID);
for (int i = 0; i < padding2->order; i++) for (int i = 0; i < padding2->order; i++)
dimsPadding[i + 1] = padding2->GetDim(i); dimsPadding[i + 1] = padding2->GetDim(i);
dimsPadding[0] = nhead; dimsPadding[0] = nhead;
XTensor* padding3 = NewTensorBufV2(paddingEnc.order + 2, dimsPadding, paddingEnc.dataType, paddingEnc.devID); XTensor* padding3 = NewTensorBuf(paddingEnc.order + 2, dimsPadding, paddingEnc.dataType, paddingEnc.devID);
/* mask of the padding */ /* mask of the padding */
_Unsqueeze(&paddingEnc, padding2, paddingEnc.order - 1, paddingEnc.GetDim(-1)); _Unsqueeze(&paddingEnc, padding2, paddingEnc.order - 1, paddingEnc.GetDim(-1));
...@@ -312,7 +310,7 @@ void T2TModel::MakeMTMaskEnc(XTensor& inputEnc, XTensor& paddingEnc, XTensor& ma ...@@ -312,7 +310,7 @@ void T2TModel::MakeMTMaskEnc(XTensor& inputEnc, XTensor& paddingEnc, XTensor& ma
_ScaleAndShiftMe(padding3, 1e9F, -1e9F); _ScaleAndShiftMe(padding3, 1e9F, -1e9F);
InitTensorV2(&maskEnc, padding3); InitTensor(&maskEnc, padding3);
maskEnc.SetZeroAll(); maskEnc.SetZeroAll();
/* generate the mask on the source language side (for padding) */ /* generate the mask on the source language side (for padding) */
...@@ -332,32 +330,36 @@ make the mask of the decoder ...@@ -332,32 +330,36 @@ make the mask of the decoder
>> maskDec - mask of the decoder self-attention >> maskDec - mask of the decoder self-attention
>> maskEncDec - mask of the decoder enc-dec attention >> maskEncDec - mask of the decoder enc-dec attention
*/ */
void T2TModel::MakeMTMaskDec(XTensor& inputEnc, XTensor& inputDec, void T2TModel::MakeMTMaskDec(XTensor &inputEnc, XTensor &inputDec,
XTensor& paddingEnc, XTensor& paddingDec, XTensor &paddingEnc, XTensor &paddingDec,
XTensor& maskDec, XTensor& maskEncDec, int incDim) XTensor &maskDec, XTensor &maskEncDec)
{ {
int len = inputDec.GetDim(inputDec.order - 1); int len = inputDec.GetDim(inputDec.order - 1);
int* dims = new int[inputDec.order + 2]; int * dims = new int[inputDec.order + 2];
for (int i = 0; i < inputDec.order; i++) for(int i = 0; i < inputDec.order; i++)
dims[i + 1] = inputDec.GetDim(i); dims[i + 1] = inputDec.GetDim(i);
//dims[inputDec.order] += incDim;
dims[0] = nhead; dims[0] = nhead;
dims[inputDec.order + 1] = len; dims[inputDec.order + 1] = len;
//InitTensor(&maskDec, inputDec.order + 2, dims, X_FLOAT, 1.0F, paddingDec.devID, paddingDec); InitTensor(&maskDec, inputDec.order + 2, dims, X_FLOAT, paddingDec.devID);
/* An upper triangular matrix where the cells of the upper triangular are set to -1e-9. /* An upper triangular matrix where the cells of the upper triangular are set to -1e-9.
This matrix can be used to block the attention to current or following words in This matrix can be used to block the attention to current or following words in
a given sequence. */ a given sequence. */
//_SetDataLowTri(&maskDec, 1e9F, 0); _SetDataLowTri(&maskDec, 1e9F, 0);
//maskDec.Dump(stderr, "mask: ");
//_ScaleAndShiftMe(&maskDec, 1.0F, -1e9F); _ScaleAndShiftMe(&maskDec, 1.0F, -1e9F);
//maskDec.Dump(stderr, "mask: ");
/* encoder-decoder mask that prevents the attention to padding dummy words */ /* encoder-decoder mask that prevents the attention to padding dummy words */
dims[inputDec.order + 1] = inputEnc.GetDim(inputEnc.order - 1); dims[inputDec.order + 1] = inputEnc.GetDim(inputEnc.order - 1);
InitTensorV2(&maskEncDec, inputDec.order + 2, dims, X_FLOAT, 1.0F, paddingEnc.devID); InitTensor(&maskEncDec, inputDec.order + 2, dims, X_FLOAT, paddingEnc.devID);
XTensor* maskEncDecTMPEnc = NewTensorBufV2(paddingEnc.order + 1, dims + 1, paddingEnc.dataType, paddingEnc.devID); XTensor * maskEncDecTMPEnc = NewTensorBuf(paddingEnc.order + 1, dims + 1, paddingEnc.dataType,
XTensor* maskEncDecTMPDec = NewTensorBufV2(paddingEnc.order + 1, dims + 1, paddingEnc.dataType, paddingEnc.devID); paddingEnc.devID);
XTensor * maskEncDecTMPDec = NewTensorBuf(maskEncDecTMPEnc, paddingEnc.devID);
_Unsqueeze(&paddingEnc, maskEncDecTMPEnc, paddingEnc.order - 1, paddingDec.GetDim(-1)); _Unsqueeze(&paddingEnc, maskEncDecTMPEnc, paddingEnc.order - 1, paddingDec.GetDim(-1));
......
...@@ -90,9 +90,9 @@ public: ...@@ -90,9 +90,9 @@ public:
void MakeMTMaskEnc(XTensor &inputEnc, XTensor &paddingEnc, XTensor &maskEnc); void MakeMTMaskEnc(XTensor &inputEnc, XTensor &paddingEnc, XTensor &maskEnc);
/* make the mask of the decoder */ /* make the mask of the decoder */
void MakeMTMaskDec(XTensor &inputEnc, XTensor &inputDec, void MakeMTMaskDec(XTensor& inputEnc, XTensor& inputDec,
XTensor &paddingEnc, XTensor &paddingDec, XTensor& paddingEnc, XTensor& paddingDec,
XTensor &maskDec, XTensor &maskEncDec, int incDim); XTensor& maskDec, XTensor& maskEncDec);
/* get parameter matrices */ /* get parameter matrices */
void GetParams(TensorList &list); void GetParams(TensorList &list);
......
...@@ -166,7 +166,6 @@ void T2TPredictor::Predict(T2TStateBundle* next, XTensor* encoding, XTensor* inp ...@@ -166,7 +166,6 @@ void T2TPredictor::Predict(T2TStateBundle* next, XTensor* encoding, XTensor* inp
inputDec = GetLastPrediction(s); inputDec = GetLastPrediction(s);
inputDec.SetDevice(inputEnc->devID); inputDec.SetDevice(inputEnc->devID);
} }
inputDec.Dump(stderr, "inputDec");
/* prediction probabilities */ /* prediction probabilities */
XTensor& output = next->prob; XTensor& output = next->prob;
...@@ -184,10 +183,10 @@ void T2TPredictor::Predict(T2TStateBundle* next, XTensor* encoding, XTensor* inp ...@@ -184,10 +183,10 @@ void T2TPredictor::Predict(T2TStateBundle* next, XTensor* encoding, XTensor* inp
XTensor maskEncDec; XTensor maskEncDec;
/* decoder mask */ /* decoder mask */
//m->MakeMTMaskDec(*inputEnc, inputDec, *paddingEnc, paddingDec, maskDec, maskEncDec, 0); m->MakeMTMaskDec(*inputEnc, inputDec, *paddingEnc, paddingDec, maskDec, maskEncDec);
/* make the decoding network */ /* make the decoding network */
decoding = m->decoder->Make(inputDec, *encoding, NULL, maskEncDec, false); decoding = m->decoder->Make(inputDec, *encoding, &maskDec, maskEncDec, false);
CheckNTErrors(decoding.order >= 2, "The tensor must be of order 2 or larger!"); CheckNTErrors(decoding.order >= 2, "The tensor must be of order 2 or larger!");
......
...@@ -38,7 +38,7 @@ T2TSearch::T2TSearch() ...@@ -38,7 +38,7 @@ T2TSearch::T2TSearch()
endSymbolNum = 0; endSymbolNum = 0;
fullHypos = NULL; fullHypos = NULL;
endSymbols = new int[32]; endSymbols = new int[32];
startSymbol = 2; startSymbol = -1;
} }
/* de-constructor */ /* de-constructor */
...@@ -60,8 +60,10 @@ void T2TSearch::Init(int argc, char** argv) ...@@ -60,8 +60,10 @@ void T2TSearch::Init(int argc, char** argv)
LoadParamInt(argc, argv, "beamsize", &beamSize, 1); LoadParamInt(argc, argv, "beamsize", &beamSize, 1);
LoadParamInt(argc, argv, "batchsize", &batchSize, 1); LoadParamInt(argc, argv, "batchsize", &batchSize, 1);
LoadParamFloat(argc, argv, "lenalpha", &alpha, 1.0F); LoadParamFloat(argc, argv, "lenalpha", &alpha, 1.0F);
LoadParamInt(argc, argv, "endid", endSymbols, 2); LoadParamInt(argc, argv, "endid", endSymbols, -1);
LoadParamInt(argc, argv, "startid", &startSymbol, 2); LoadParamInt(argc, argv, "startid", &startSymbol, -1);
LoadParamFloat(argc, argv, "maxlenalpha", &scalarMaxLength, 2.0F);
LoadParamBool(argc, argv, "earlystop", &isEarlyStop, false);
if (endSymbols[0] >= 0) if (endSymbols[0] >= 0)
endSymbolNum = 1; endSymbolNum = 1;
...@@ -73,8 +75,10 @@ search for the most promising states ...@@ -73,8 +75,10 @@ search for the most promising states
>> input - input of the model >> input - input of the model
>> padding - padding of the input >> padding - padding of the input
>> output - output that represents the sequences as rows >> output - output that represents the sequences as rows
>> score - score of the sequences
*/ */
void T2TSearch::Search(T2TModel* model, XTensor* input, XTensor* padding, XTensor* output) void T2TSearch::Search(T2TModel * model, XTensor * input, XTensor * padding,
XTensor * output, XTensor * score)
{ {
T2TPredictor predictor; T2TPredictor predictor;
XTensor maskEnc; XTensor maskEnc;
...@@ -89,7 +93,7 @@ void T2TSearch::Search(T2TModel* model, XTensor* input, XTensor* padding, XTenso ...@@ -89,7 +93,7 @@ void T2TSearch::Search(T2TModel* model, XTensor* input, XTensor* padding, XTenso
Prepare(input->unitNum / input->GetDim(-1), beamSize); Prepare(input->unitNum / input->GetDim(-1), beamSize);
/* encoder mask */ /* encoder mask */
//model->MakeMTMaskEnc(*input, *padding, maskEnc); model->MakeMTMaskEnc(*input, *padding, maskEnc);
/* make the encoding network */ /* make the encoding network */
encoding = model->MakeEncoder(*input, &maskEnc, false); encoding = model->MakeEncoder(*input, &maskEnc, false);
...@@ -102,14 +106,15 @@ void T2TSearch::Search(T2TModel* model, XTensor* input, XTensor* padding, XTenso ...@@ -102,14 +106,15 @@ void T2TSearch::Search(T2TModel* model, XTensor* input, XTensor* padding, XTenso
inputBeam.ReshapeMerged(inputBeam.order - 3); inputBeam.ReshapeMerged(inputBeam.order - 3);
paddingBeam.ReshapeMerged(paddingBeam.order - 3); paddingBeam.ReshapeMerged(paddingBeam.order - 3);
/* max output-length = 2 * source-length */ /* max output-length = scalar * source-length */
maxLength = input->GetDim(-1) * 2; int lengthLimit = (int)(input->GetDim(-1) * scalarMaxLength);
CheckNTErrors(maxLength > 0, "no max length specified!"); CheckNTErrors(lengthLimit > 0, "no max length specified!");
maxLength = lengthLimit;
T2TStateBundle* states = new T2TStateBundle[maxLength + 1]; T2TStateBundle * states = new T2TStateBundle[lengthLimit + 1];
T2TStateBundle* first = states; T2TStateBundle * first = states;
T2TStateBundle* cur; T2TStateBundle * cur = NULL;
T2TStateBundle* next; T2TStateBundle * next = NULL;
/* create the first state */ /* create the first state */
predictor.Create(model, &encodingBeam, input, beamSize, first); predictor.Create(model, &encodingBeam, input, beamSize, first);
...@@ -118,15 +123,15 @@ void T2TSearch::Search(T2TModel* model, XTensor* input, XTensor* padding, XTenso ...@@ -118,15 +123,15 @@ void T2TSearch::Search(T2TModel* model, XTensor* input, XTensor* padding, XTenso
first->isStart = true; first->isStart = true;
/* generate the sequence from left to right */ /* generate the sequence from left to right */
for (int i = 0; i < maxLength; i++) { for(int l = 0 ; l < lengthLimit; l++){
cur = states + i; cur = states + l;
next = states + i + 1; next = states + l + 1;
/* read the current state */ /* read the current state */
predictor.Read(model, cur); predictor.Read(model, cur);
/* predict the next state */ /* predict the next state */
predictor.Predict(next, &encodingBeam, &inputBeam, &paddingBeam, i == 0); predictor.Predict(next, &encodingBeam, &inputBeam, &paddingBeam, l == 0);
/* compute the model score (given the prediction probability) */ /* compute the model score (given the prediction probability) */
Score(cur, next); Score(cur, next);
...@@ -139,12 +144,18 @@ void T2TSearch::Search(T2TModel* model, XTensor* input, XTensor* padding, XTenso ...@@ -139,12 +144,18 @@ void T2TSearch::Search(T2TModel* model, XTensor* input, XTensor* padding, XTenso
/* push complete hypotheses into the heap */ /* push complete hypotheses into the heap */
Collect(next); Collect(next);
/* stop searching when all hypotheses are completed */
if(IsAllCompleted(next)){
maxLength = l + 1;
break;
}
} }
/* fill the heap with incomplete hypotheses if necessary */ /* fill the heap with incomplete hypotheses if necessary */
FillHeap(next); FillHeap(next);
Dump(output); Dump(output, score);
delete[] states; delete[] states;
} }
...@@ -214,8 +225,7 @@ void T2TSearch::Score(T2TStateBundle* prev, T2TStateBundle* beam) ...@@ -214,8 +225,7 @@ void T2TSearch::Score(T2TStateBundle* prev, T2TStateBundle* beam)
_DivDim(&probPath, &lp, &score, 0); _DivDim(&probPath, &lp, &score, 0);
if (prev->isStart) { if (prev->isStart) {
XTensor firstMask; XTensor firstMask = MakeFirstMask(beam);
firstMask = MakeFirstMask(beam);
firstMask.Reshape(firstMask.unitNum); firstMask.Reshape(firstMask.unitNum);
/* mask the hypotheses in the beam except the first one */ /* mask the hypotheses in the beam except the first one */
...@@ -252,15 +262,14 @@ void T2TSearch::Generate(T2TStateBundle* beam) ...@@ -252,15 +262,14 @@ void T2TSearch::Generate(T2TStateBundle* beam)
int dimsTopK[MAX_TENSOR_DIM_NUM]; int dimsTopK[MAX_TENSOR_DIM_NUM];
XTensor scoreTopK; XTensor scoreTopK;
XTensor& score = beam->modelScore; XTensor indexCPU;
XTensor& index = beam->prediction; XTensor &score = beam->modelScore;
XTensor& preID = beam->preID; XTensor &index = beam->prediction;
XTensor& probPath = beam->probPath; XTensor &preID = beam->preID;
XTensor& prob = beam->prob; XTensor &probPath = beam->probPath;
int order = score.order; XTensor &prob = beam->prob;
CheckNTErrors(order >= 3, "The tensor must be of order 2 or larger."); int order = score.order;
CheckNTErrors(dimsBeam[order - 3] % beamSize == 0, "Wrong dimension size!");
for (int i = 0; i < order; i++) { for (int i = 0; i < order; i++) {
dims[i] = score.GetDim(i); dims[i] = score.GetDim(i);
...@@ -268,6 +277,9 @@ void T2TSearch::Generate(T2TStateBundle* beam) ...@@ -268,6 +277,9 @@ void T2TSearch::Generate(T2TStateBundle* beam)
dimsTopK[i] = score.GetDim(i); dimsTopK[i] = score.GetDim(i);
} }
CheckNTErrors(order >= 3, "The tensor must be of order 2 or larger.");
CheckNTErrors(dimsBeam[order - 3] % beamSize == 0, "Wrong dimension size!");
int sizeVocab = score.GetDim(-1); int sizeVocab = score.GetDim(-1);
int stride = score.GetDim(-1); int stride = score.GetDim(-1);
...@@ -279,23 +291,23 @@ void T2TSearch::Generate(T2TStateBundle* beam) ...@@ -279,23 +291,23 @@ void T2TSearch::Generate(T2TStateBundle* beam)
InitTensorV2(&scoreTopK, order, dimsTopK, score.dataType, 1.0F, score.devID); InitTensorV2(&scoreTopK, order, dimsTopK, score.dataType, 1.0F, score.devID);
InitTensorV2(&index, order, dimsTopK, X_INT, 1.0F, score.devID); InitTensorV2(&index, order, dimsTopK, X_INT, 1.0F, score.devID);
InitTensorV2(&preID, order, dimsTopK, X_INT, 1.0F, -1); InitTensorV2(&preID, order, dimsTopK, X_INT, 1.0F, -1);
InitTensorV2(&indexCPU, order, dimsTopK, X_INT, 1.0F, -1);
/* mask the first and the padding id */ /* TODO: check the mask - mask the first and the padding id */
int dimMask[]{ score.GetDim(-1) }; /*int dimMask[]{ score.GetDim(-1) };
XTensor mask; XTensor mask;
InitTensorV2(&mask, 1, dimMask, X_FLOAT, 1.0F, -1); InitTensorV2(&mask, 1, dimMask, X_FLOAT, 1.0F, -1);
mask.SetZeroAll(); mask.SetZeroAll();
mask.Set1D(-1e9F, 0); mask.Set1D(-1e9F, 0);
mask.Set1D(-1e9F, 1); mask.Set1D(-1e9F, 1);
mask.SetDevice(score.devID, score.mem); mask.SetDevice(score.devID);
_SumDim(&score, &mask, 2);*/
_SumDim(&score, &mask, 2);
score.Reshape(order, dimsBeam); score.Reshape(order, dimsBeam);
/* keep the most promising candidates in the beam */ /* keep the most promising candidates in the beam */
/* TODO: check this line */
TopK(score, scoreTopK, index, -1, beamSize); TopK(score, scoreTopK, index, -1, beamSize);
CopyValues(index, indexCPU);
CopyValues(index, preID); CopyValues(index, preID);
/* "preID" represents the id (or the offset) of the previous state used to make the current /* "preID" represents the id (or the offset) of the previous state used to make the current
...@@ -317,12 +329,10 @@ void T2TSearch::Generate(T2TStateBundle* beam) ...@@ -317,12 +329,10 @@ void T2TSearch::Generate(T2TStateBundle* beam)
CopyValues(scoreTopK, score); CopyValues(scoreTopK, score);
/* CPU data (TODO: remove GPU->CPU data copy!!!) */ /* CPU data (TODO: remove GPU->CPU data copy!!!) */
XTensor indexGPU; for (int i = 0; i < indexCPU.unitNum; i += beamSize){
indexGPU = CopyValues(index); for (int j = 0; j < beamSize; j++) {
indexCPU.SetInt(i * stride + indexCPU.GetInt(i + j), i + j);
for (int i = 0; i < indexGPU.unitNum; i += beamSize) { }
for (int j = 0; j < beamSize; j++)
indexGPU.SetInt(i * stride + indexGPU.GetInt(i + j), i + j);
} }
CheckNTErrors(IsSameShaped(prob, probPath), "Wrong tensor shape!"); CheckNTErrors(IsSameShaped(prob, probPath), "Wrong tensor shape!");
...@@ -339,13 +349,13 @@ void T2TSearch::Generate(T2TStateBundle* beam) ...@@ -339,13 +349,13 @@ void T2TSearch::Generate(T2TStateBundle* beam)
} }
order = probPath.order; order = probPath.order;
probPath.Reshape(1, probPath.unitNum);
probPathTopK.Reshape(1, probPathTopK.unitNum);
prob.Reshape(1, prob.unitNum);
probTopK.Reshape(1, probTopK.unitNum);
_CopyIndexed(&probPath, &probPathTopK, probPathTopK.order - 1, &indexGPU); prob.Reshape(prob.unitNum, 1);
_CopyIndexed(&prob, &probTopK, probTopK.order - 1, &indexGPU); probPath.Reshape(probPath.unitNum, 1);
indexCPU.Reshape(indexCPU.GetDim(0), indexCPU.GetDim(-1));
probTopK = Gather(prob, indexCPU);
probPathTopK = Gather(probPath, indexCPU);
probPath.Reshape(order, dims); probPath.Reshape(order, dims);
probPathTopK.Reshape(order, dimsTopK); probPathTopK.Reshape(order, dimsTopK);
...@@ -440,7 +450,9 @@ void T2TSearch::Expand(T2TStateBundle* prev, T2TStateBundle* beam) ...@@ -440,7 +450,9 @@ void T2TSearch::Expand(T2TStateBundle* prev, T2TStateBundle* beam)
CheckNTErrors(state.prediction >= 0, "Illegal prediction!"); CheckNTErrors(state.prediction >= 0, "Illegal prediction!");
/* check if it is the end of the sequence */ /* check if it is the end of the sequence */
state.isEnd = IsEnd(state.prediction); state.isEnd = IsEnd(state.prediction);
state.isCompleted = (state.isCompleted || state.isEnd); state.isCompleted = (state.isCompleted || state.isEnd);
/* set the ending mark */ /* set the ending mark */
...@@ -471,9 +483,10 @@ void T2TSearch::Collect(T2TStateBundle* beam) ...@@ -471,9 +483,10 @@ void T2TSearch::Collect(T2TStateBundle* beam)
bool isCompleted = state.isCompleted && (state.last == NULL || !state.last->isCompleted); bool isCompleted = state.isCompleted && (state.last == NULL || !state.last->isCompleted);
/* we push the hypothesis into the heap when it is completed */ /* we push the hypothesis into the heap when it is completed */
if (state.isEnd != 0) if (state.isEnd && isCompleted) {
fullHypos[state.pid].Push(HeapNode<float>(&state, state.modelScore)); fullHypos[state.pid].Push(HeapNode<float>(&state, state.modelScore));
} }
}
} }
/* /*
...@@ -494,8 +507,12 @@ void T2TSearch::FillHeap(T2TStateBundle* beam) ...@@ -494,8 +507,12 @@ void T2TSearch::FillHeap(T2TStateBundle* beam)
CheckNTErrors(state.pid >= 0 && state.pid < batchSize, CheckNTErrors(state.pid >= 0 && state.pid < batchSize,
"Invalid sample id!"); "Invalid sample id!");
/* check if this is the first end symbol. It is false
if there have been end symbols in previously generated words. */
bool isCompleted = state.isCompleted && (state.last == NULL || !state.last->isCompleted);
/* we push the imcomplete hypothesis into the heap */ /* we push the imcomplete hypothesis into the heap */
if (emptyFlags[state.pid] && state.isEnd == 0) if (emptyFlags[state.pid] || state.isEnd || isCompleted)
fullHypos[state.pid].Push(HeapNode<float>(&state, state.modelScore)); fullHypos[state.pid].Push(HeapNode<float>(&state, state.modelScore));
} }
...@@ -505,23 +522,28 @@ void T2TSearch::FillHeap(T2TStateBundle* beam) ...@@ -505,23 +522,28 @@ void T2TSearch::FillHeap(T2TStateBundle* beam)
/* /*
save the output sequences in a tensor save the output sequences in a tensor
>> output - output sequences (for return) >> output - output sequences (for return)
>> score - score of the sequences
*/ */
void T2TSearch::Dump(XTensor* output) void T2TSearch::Dump(XTensor * output, XTensor * score)
{ {
int dims[3] = { batchSize, beamSize, maxLength }; int dims[3] = { batchSize, beamSize, maxLength };
int* words = new int[maxLength]; int* words = new int[maxLength];
InitTensorV2(output, 3, dims, X_INT); InitTensorV2(output, 3, dims, X_INT);
InitTensorV2(score, 2, dims, X_FLOAT);
SetDataFixedInt(*output, -1); SetDataFixedInt(*output, -1);
score->SetZeroAll();
/* heap for an input sentence in the batch */ /* heap for an input sentence in the batch */
for (int h = 0; h < batchSize; h++) { for (int h = 0; h < batchSize; h++) {
XHeap<MIN_HEAP, float>& heap = fullHypos[h]; XHeap<MIN_HEAP, float> &heap = fullHypos[h];
int c = heap.Count();
/* for each output in the beam */ /* for each output in the beam */
for (int i = 0; i < beamSize && heap.Count() > 0; i++) { for(int i = 0; i < beamSize && heap.Count() > 0; i++){
T2TState* state = (T2TState*)heap.Pop().index; HeapNode<float> node = heap.Pop();
T2TState * state = (T2TState *)node.index;
int count = 0; int count = 0;
bool isCompleted = true; bool isCompleted = true;
...@@ -531,15 +553,17 @@ void T2TSearch::Dump(XTensor* output) ...@@ -531,15 +553,17 @@ void T2TSearch::Dump(XTensor* output)
if (!state->isCompleted) if (!state->isCompleted)
isCompleted = false; isCompleted = false;
if (isCompleted) if (isCompleted)
words[count++] = -1; words[count++] = 2;
else else
words[count++] = state->prediction; words[count++] = state->prediction;
state = state->last; state = state->last;
} }
/* dump the sentence to the output tensor */ /* dump the sentence to the output tensor */
for (int w = 0; w < count; w++) for(int w = 0; w < count; w++)
output->Set3DInt(words[count - w - 1], h, beamSize - i - 1, w); output->Set3DInt(words[count - w - 1], h, c - i - 1, w);
score->Set2D(node.value, h, c - i - 1);
} }
} }
...@@ -583,6 +607,23 @@ void T2TSearch::SetEnd(const int* tokens, const int tokenNum) ...@@ -583,6 +607,23 @@ void T2TSearch::SetEnd(const int* tokens, const int tokenNum)
} }
/*
check whether every hypothesis in the beam has finished generating
>> beam - bundle of search states to inspect
<< true iff no state in the bundle is still incomplete
*/
bool T2TSearch::IsAllCompleted(T2TStateBundle * beam)
{
    const T2TState * stateArray = beam->states;
    const int num = beam->stateNum;

    /* scan for any state that is still running; one is enough to keep searching */
    for (int k = 0; k < num; k++) {
        if (!stateArray[k].isCompleted)
            return false;
    }

    return true;
}
/*
make a mask to prevent duplicated entries in beam expansion for the first position make a mask to prevent duplicated entries in beam expansion for the first position
>> beam - the beam that keeps the searching states >> beam - the beam that keeps the searching states
*/ */
......
...@@ -62,6 +62,12 @@ private: ...@@ -62,6 +62,12 @@ private:
/* start symbol */ /* start symbol */
int startSymbol; int startSymbol;
/* scalar of the input sequence (for max number of search steps) */
float scalarMaxLength;
/* indicate whether the early stop strategy is used */
bool isEarlyStop;
public: public:
/* constructor */ /* constructor */
T2TSearch(); T2TSearch();
...@@ -73,7 +79,7 @@ public: ...@@ -73,7 +79,7 @@ public:
void Init(int argc, char** argv); void Init(int argc, char** argv);
/* search for the most promising states */ /* search for the most promising states */
void Search(T2TModel* model, XTensor* input, XTensor* padding, XTensor* output); void Search(T2TModel* model, XTensor* input, XTensor* padding, XTensor* output, XTensor* score);
/* preparation */ /* preparation */
void Prepare(int myBatchSize, int myBeamSize); void Prepare(int myBatchSize, int myBeamSize);
...@@ -93,12 +99,15 @@ public: ...@@ -93,12 +99,15 @@ public:
/* fill the hypothesis heap with incomplete hypotheses */ /* fill the hypothesis heap with incomplete hypotheses */
void FillHeap(T2TStateBundle* beam); void FillHeap(T2TStateBundle* beam);
/* save the output sequences in a tensor */ /* save the output sequences and score */
void Dump(XTensor* output); void Dump(XTensor* output, XTensor* score);
/* check if the token is an end symbol */ /* check if the token is an end symbol */
bool IsEnd(int token); bool IsEnd(int token);
/* check whether all hypotheses are completed */ /* check whether all hypotheses are completed */
bool IsAllCompleted(T2TStateBundle* beam);
/* set end symbols for search */ /* set end symbols for search */
void SetEnd(const int* tokens, const int tokenNum); void SetEnd(const int* tokens, const int tokenNum);
......
...@@ -101,9 +101,10 @@ void T2TTester::Test(const char* fn, const char* ofn, T2TModel* model) ...@@ -101,9 +101,10 @@ void T2TTester::Test(const char* fn, const char* ofn, T2TModel* model)
vector<int> indices = batchLoader.LoadBatch(&batchEnc, &paddingEnc, sentBatch, devID); vector<int> indices = batchLoader.LoadBatch(&batchEnc, &paddingEnc, sentBatch, devID);
XTensor output; XTensor output;
XTensor score;
seacher.Search(model, &batchEnc, &paddingEnc, &output, &score);
seacher.Search(model, &batchEnc, &paddingEnc, &output);
output.Dump(stderr);
for (int i = 0; i < indices.size(); ++i) { for (int i = 0; i < indices.size(); ++i) {
Result res; Result res;
XTensor sent, srcIdx, tgtIdx; XTensor sent, srcIdx, tgtIdx;
...@@ -127,9 +128,7 @@ void T2TTester::Test(const char* fn, const char* ofn, T2TModel* model) ...@@ -127,9 +128,7 @@ void T2TTester::Test(const char* fn, const char* ofn, T2TModel* model)
if (batchCount % 1 == 0) { if (batchCount % 1 == 0) {
double elapsed = GetClockSec() - startT; double elapsed = GetClockSec() - startT;
XPRINT3(0, stderr, XPRINT3(0, stderr, "[INFO] elapsed=%.1fs, sentence=%d, sword=%d\n", elapsed, sentCount, wordCount);
"[INFO] elapsed=%.1fs, sentence=%d, sword=%d\n",
elapsed, sentCount, wordCount);
} }
} }
...@@ -160,9 +159,10 @@ void T2TTester::Dump(FILE* file, XTensor* output) ...@@ -160,9 +159,10 @@ void T2TTester::Dump(FILE* file, XTensor* output)
for (int i = 0; i < output->unitNum; i += seqLength) { for (int i = 0; i < output->unitNum; i += seqLength) {
for (int j = 0; j < seqLength; j++) { for (int j = 0; j < seqLength; j++) {
int w = output->GetInt(i + j); int w = output->GetInt(i + j);
fprintf(file, "%d ", w); if (w < 0 || w == 1)
if (w < 0)
break; break;
fprintf(file, "%d ", w);
} }
fprintf(file, "\n"); fprintf(file, "\n");
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论