Commit c177e3c5 by huchi

fix some bugs in beam search

parent e155205c
......@@ -43,6 +43,21 @@ int main( int argc, const char ** argv )
_CrtSetBreakAlloc(2708);*/
TransformerMain(argc - 1, argv + 1);
//XTensor singleScore, singleIdx, score;
//InitTensor3DV2(&score, 2, 1, 136160);
////score.SetDataRand(0, 1);
//InitTensor1DV2(&singleIdx, 1, X_INT);
//singleIdx.Set1DInt(1, 0);
//singleIdx.Dump(stderr);
//singleScore = Select(score, singleIdx, 0);
//XTensor s, i;
//InitTensor3DV2(&s, 2, 1, 4);
//InitTensor3DV2(&i, 2, 1, 4, X_INT);
//TopK(score, s, i, -1, 4);
//i.Dump(stderr, "single score:\n");
//_CrtDumpMemoryLeaks();
return 0;
......
......@@ -196,7 +196,7 @@ void T2TModel::MakeMT(XTensor& inputEnc, XTensor& inputDec, XTensor& output, XTe
MakeMTMaskEnc(inputEnc, paddingEnc, maskEnc);
/* decoder mask */
MakeMTMaskDec(inputEnc, inputDec, paddingEnc, paddingDec, maskDec, maskEncDec, 0);
MakeMTMaskDec(inputEnc, inputDec, paddingEnc, paddingDec, maskDec, maskEncDec);
encoding = MakeEncoder(inputEnc, &maskEnc, isTraining);
......@@ -225,8 +225,8 @@ void T2TModel::MakeMTMask(XTensor& inputEnc, XTensor& inputDec,
dims[i + 1] = inputDec.GetDim(i);
dims[0] = nhead;
dims[inputDec.order + 1] = len;
InitTensorV2(&maskDec, inputDec.order + 2, dims, X_FLOAT, 1.0F, paddingDec.devID);
InitTensor(&maskDec, inputDec.order + 2, dims, X_FLOAT, paddingDec.devID);
/* an upper triangular matrix where the cells of the upper triangle are set to -1e9.
this matrix can be used to prevent the attention to current or following words in
a given sequence. */
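For intuition, a minimal standalone C++ sketch (not the NiuTensor API) of the mask this comment describes; the decoder-mask code later in this commit builds the same thing with _SetDataLowTri(&maskDec, 1e9F, 0) followed by _ScaleAndShiftMe(&maskDec, 1.0F, -1e9F), which leaves 0 where attention is allowed and -1e9 where it is blocked:

#include <vector>

// Additive triangular attention mask: positions a query may attend to
// (itself and earlier tokens) get 0, future positions get -1e9 so the
// softmax drives their weights to zero.
std::vector<std::vector<float>> MakeTriangularMask(int len)
{
    std::vector<std::vector<float>> mask(len, std::vector<float>(len, -1e9F));
    for (int i = 0; i < len; i++)       // query position
        for (int j = 0; j <= i; j++)    // keys up to and including the query
            mask[i][j] = 0.0F;          // visible: no additive penalty
    return mask;
}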
......@@ -235,10 +235,10 @@ void T2TModel::MakeMTMask(XTensor& inputEnc, XTensor& inputDec,
/* encoder-decoder mask that prevents the attention to padding dummy words */
dims[inputDec.order + 1] = inputEnc.GetDim(inputEnc.order - 1);
InitTensorV2(&maskEncDec, inputDec.order + 2, dims, X_FLOAT, 1.0F, paddingEnc.devID);
InitTensor(&maskEncDec, inputDec.order + 2, dims, X_FLOAT, 1.0F, paddingEnc.devID);
XTensor* maskEncDecTMPEnc = NewTensorBufV2(paddingEnc.order + 1, dims + 1, paddingEnc.dataType, paddingEnc.devID);
XTensor* maskEncDecTMPDec = NewTensorBufV2(paddingEnc.order + 1, dims + 1, paddingEnc.dataType, paddingEnc.devID);
XTensor * maskEncDecTMPEnc = NewTensorBuf(paddingEnc.order + 1, dims + 1, paddingEnc.dataType, paddingEnc.devID);
XTensor * maskEncDecTMPDec = NewTensorBuf(maskEncDecTMPEnc, paddingEnc.devID);
_Unsqueeze(&paddingEnc, maskEncDecTMPEnc, paddingEnc.order - 1, paddingDec.GetDim(-1));
_ScaleAndShiftMe(maskEncDecTMPEnc, 1e9F, -1e9F);
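_ScaleAndShiftMe(x, a, b) computes x = x * a + b in place, so the call above maps a {0,1} padding indicator to {-1e9, 0}: real tokens keep a zero penalty while padded positions are pushed to -1e9 before the softmax. The same mapping on a plain float buffer, as a sketch outside the NiuTensor API:

#include <cstddef>

// Turn a 0/1 padding indicator into an additive attention mask in place:
// 1 (real token) -> 0, 0 (padding) -> -1e9, i.e. x = x * 1e9 - 1e9.
void PaddingToAdditiveMask(float* x, std::size_t n)
{
    for (std::size_t i = 0; i < n; i++)
        x[i] = x[i] * 1e9F - 1e9F;
}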
......@@ -254,15 +254,13 @@ void T2TModel::MakeMTMask(XTensor& inputEnc, XTensor& inputDec,
dimsPadding[paddingEnc.order - 1] = paddingEnc.GetDim(-1);
dimsPadding[paddingEnc.order] = paddingEnc.GetDim(-1);
XTensor* padding2 = NewTensorBufV2(paddingEnc.order + 1, dimsPadding, paddingEnc.dataType,
paddingEnc.devID);
XTensor * padding2 = NewTensorBuf(paddingEnc.order + 1, dimsPadding, paddingEnc.dataType, paddingEnc.devID);
for (int i = 0; i < padding2->order; i++)
dimsPadding[i + 1] = padding2->GetDim(i);
dimsPadding[0] = nhead;
XTensor* padding3 = NewTensorBufV2(paddingEnc.order + 2, dimsPadding, paddingEnc.dataType,
paddingEnc.devID);
XTensor * padding3 = NewTensorBuf(paddingEnc.order + 2, dimsPadding, paddingEnc.dataType, paddingEnc.devID);
/* mask of the padding */
_Unsqueeze(&paddingEnc, padding2, paddingEnc.order - 1, paddingEnc.GetDim(-1));
......@@ -270,7 +268,7 @@ void T2TModel::MakeMTMask(XTensor& inputEnc, XTensor& inputDec,
_ScaleAndShiftMe(padding3, 1e9F, -1e9F);
InitTensorV2(&maskEnc, padding3);
InitTensor(&maskEnc, padding3);
maskEnc.SetZeroAll();
/* generate the mask on the source language side (for padding) */
......@@ -297,22 +295,22 @@ void T2TModel::MakeMTMaskEnc(XTensor& inputEnc, XTensor& paddingEnc, XTensor& ma
dimsPadding[i] = paddingEnc.GetDim(i);
dimsPadding[paddingEnc.order - 1] = paddingEnc.GetDim(-1);
dimsPadding[paddingEnc.order] = paddingEnc.GetDim(-1);
XTensor* padding2 = NewTensorBufV2(paddingEnc.order + 1, dimsPadding, paddingEnc.dataType, paddingEnc.devID);
XTensor * padding2 = NewTensorBuf(paddingEnc.order + 1, dimsPadding, paddingEnc.dataType, paddingEnc.devID);
for (int i = 0; i < padding2->order; i++)
dimsPadding[i + 1] = padding2->GetDim(i);
dimsPadding[0] = nhead;
XTensor* padding3 = NewTensorBufV2(paddingEnc.order + 2, dimsPadding, paddingEnc.dataType, paddingEnc.devID);
XTensor* padding3 = NewTensorBuf(paddingEnc.order + 2, dimsPadding, paddingEnc.dataType, paddingEnc.devID);
/* mask of the padding */
_Unsqueeze(&paddingEnc, padding2, paddingEnc.order - 1, paddingEnc.GetDim(-1));
_Unsqueeze(padding2, padding3, 0, nhead);
_ScaleAndShiftMe(padding3, 1e9F, -1e9F);
InitTensorV2(&maskEnc, padding3);
InitTensor(&maskEnc, padding3);
maskEnc.SetZeroAll();
/* generate the mask on the source language side (for padding) */
......@@ -332,33 +330,37 @@ make the mask of the decoder
>> maskDec - mask of the decoder self-attention
>> maskEncDec - mask of the decoder enc-dec attention
*/
void T2TModel::MakeMTMaskDec(XTensor& inputEnc, XTensor& inputDec,
XTensor& paddingEnc, XTensor& paddingDec,
XTensor& maskDec, XTensor& maskEncDec, int incDim)
void T2TModel::MakeMTMaskDec(XTensor &inputEnc, XTensor &inputDec,
XTensor &paddingEnc, XTensor &paddingDec,
XTensor &maskDec, XTensor &maskEncDec)
{
int len = inputDec.GetDim(inputDec.order - 1);
int* dims = new int[inputDec.order + 2];
for (int i = 0; i < inputDec.order; i++)
int * dims = new int[inputDec.order + 2];
for(int i = 0; i < inputDec.order; i++)
dims[i + 1] = inputDec.GetDim(i);
//dims[inputDec.order] += incDim;
dims[0] = nhead;
dims[inputDec.order + 1] = len;
//InitTensor(&maskDec, inputDec.order + 2, dims, X_FLOAT, 1.0F, paddingDec.devID, paddingDec);
InitTensor(&maskDec, inputDec.order + 2, dims, X_FLOAT, paddingDec.devID);
/* An upper triangular matrix where the cells of the upper triangle are set to -1e9.
This matrix can be used to block the attention to current or following words in
a given sequence. */
//_SetDataLowTri(&maskDec, 1e9F, 0);
//_ScaleAndShiftMe(&maskDec, 1.0F, -1e9F);
_SetDataLowTri(&maskDec, 1e9F, 0);
/* encoder-decoder mask that prevents the attention to padding dummy words */
dims[inputDec.order + 1] = inputEnc.GetDim(inputEnc.order - 1);
InitTensorV2(&maskEncDec, inputDec.order + 2, dims, X_FLOAT, 1.0F, paddingEnc.devID);
//maskDec.Dump(stderr, "mask: ");
XTensor* maskEncDecTMPEnc = NewTensorBufV2(paddingEnc.order + 1, dims + 1, paddingEnc.dataType, paddingEnc.devID);
XTensor* maskEncDecTMPDec = NewTensorBufV2(paddingEnc.order + 1, dims + 1, paddingEnc.dataType, paddingEnc.devID);
_ScaleAndShiftMe(&maskDec, 1.0F, -1e9F);
//maskDec.Dump(stderr, "mask: ");
/* encoder-decoder mask that prevents the attention to padding dummy words */
dims[inputDec.order + 1] = inputEnc.GetDim(inputEnc.order - 1);
InitTensor(&maskEncDec, inputDec.order + 2, dims, X_FLOAT, paddingEnc.devID);
XTensor * maskEncDecTMPEnc = NewTensorBuf(paddingEnc.order + 1, dims + 1, paddingEnc.dataType,
paddingEnc.devID);
XTensor * maskEncDecTMPDec = NewTensorBuf(maskEncDecTMPEnc, paddingEnc.devID);
_Unsqueeze(&paddingEnc, maskEncDecTMPEnc, paddingEnc.order - 1, paddingDec.GetDim(-1));
//paddingEnc.Dump(stderr, "paddingenc:");
......
......@@ -90,9 +90,9 @@ public:
void MakeMTMaskEnc(XTensor &inputEnc, XTensor &paddingEnc, XTensor &maskEnc);
/* make the mask of the decoder */
void MakeMTMaskDec(XTensor &inputEnc, XTensor &inputDec,
XTensor &paddingEnc, XTensor &paddingDec,
XTensor &maskDec, XTensor &maskEncDec, int incDim);
void MakeMTMaskDec(XTensor& inputEnc, XTensor& inputDec,
XTensor& paddingEnc, XTensor& paddingDec,
XTensor& maskDec, XTensor& maskEncDec);
/* get parameter matrices */
void GetParams(TensorList &list);
......
......@@ -166,7 +166,6 @@ void T2TPredictor::Predict(T2TStateBundle* next, XTensor* encoding, XTensor* inp
inputDec = GetLastPrediction(s);
inputDec.SetDevice(inputEnc->devID);
}
inputDec.Dump(stderr, "inputDec");
/* prediction probabilities */
XTensor& output = next->prob;
......@@ -184,10 +183,10 @@ void T2TPredictor::Predict(T2TStateBundle* next, XTensor* encoding, XTensor* inp
XTensor maskEncDec;
/* decoder mask */
//m->MakeMTMaskDec(*inputEnc, inputDec, *paddingEnc, paddingDec, maskDec, maskEncDec, 0);
m->MakeMTMaskDec(*inputEnc, inputDec, *paddingEnc, paddingDec, maskDec, maskEncDec);
/* make the decoding network */
decoding = m->decoder->Make(inputDec, *encoding, NULL, maskEncDec, false);
decoding = m->decoder->Make(inputDec, *encoding, &maskDec, maskEncDec, false);
CheckNTErrors(decoding.order >= 2, "The tensor must be of order 2 or larger!");
......
......@@ -38,7 +38,7 @@ T2TSearch::T2TSearch()
endSymbolNum = 0;
fullHypos = NULL;
endSymbols = new int[32];
startSymbol = 2;
startSymbol = -1;
}
/* destructor */
......@@ -60,8 +60,10 @@ void T2TSearch::Init(int argc, char** argv)
LoadParamInt(argc, argv, "beamsize", &beamSize, 1);
LoadParamInt(argc, argv, "batchsize", &batchSize, 1);
LoadParamFloat(argc, argv, "lenalpha", &alpha, 1.0F);
LoadParamInt(argc, argv, "endid", endSymbols, 2);
LoadParamInt(argc, argv, "startid", &startSymbol, 2);
LoadParamInt(argc, argv, "endid", endSymbols, -1);
LoadParamInt(argc, argv, "startid", &startSymbol, -1);
LoadParamFloat(argc, argv, "maxlenalpha", &scalarMaxLength, 2.0F);
LoadParamBool(argc, argv, "earlystop", &isEarlyStop, false);
if (endSymbols[0] >= 0)
endSymbolNum = 1;
......@@ -73,8 +75,10 @@ search for the most promising states
>> input - input of the model
>> padding - padding of the input
>> output - output that represents the sequences as rows
>> score - score of the sequences
*/
void T2TSearch::Search(T2TModel* model, XTensor* input, XTensor* padding, XTensor* output)
void T2TSearch::Search(T2TModel * model, XTensor * input, XTensor * padding,
XTensor * output, XTensor * score)
{
T2TPredictor predictor;
XTensor maskEnc;
......@@ -89,7 +93,7 @@ void T2TSearch::Search(T2TModel* model, XTensor* input, XTensor* padding, XTenso
Prepare(input->unitNum / input->GetDim(-1), beamSize);
/* encoder mask */
//model->MakeMTMaskEnc(*input, *padding, maskEnc);
model->MakeMTMaskEnc(*input, *padding, maskEnc);
/* make the encoding network */
encoding = model->MakeEncoder(*input, &maskEnc, false);
......@@ -101,16 +105,17 @@ void T2TSearch::Search(T2TModel* model, XTensor* input, XTensor* padding, XTenso
encodingBeam.ReshapeMerged(encodingBeam.order - 4);
inputBeam.ReshapeMerged(inputBeam.order - 3);
paddingBeam.ReshapeMerged(paddingBeam.order - 3);
/* max output-length = 2 * source-length */
maxLength = input->GetDim(-1) * 2;
CheckNTErrors(maxLength > 0, "no max length specified!");
T2TStateBundle* states = new T2TStateBundle[maxLength + 1];
T2TStateBundle* first = states;
T2TStateBundle* cur;
T2TStateBundle* next;
/* max output-length = scalar * source-length */
int lengthLimit = (int)(input->GetDim(-1) * scalarMaxLength);
CheckNTErrors(lengthLimit > 0, "no max length specified!");
maxLength = lengthLimit;
T2TStateBundle * states = new T2TStateBundle[lengthLimit + 1];
T2TStateBundle * first = states;
T2TStateBundle * cur = NULL;
T2TStateBundle * next = NULL;
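The hard-coded factor of 2 is replaced by the configurable scalarMaxLength (the "maxlenalpha" option, default 2.0F), so the step budget is simple arithmetic, sketched here with hypothetical numbers:

// Step budget = source length scaled by "maxlenalpha", truncated to int.
int MaxDecodingSteps(int srcLen, float scalarMaxLength)
{
    return (int)(srcLen * scalarMaxLength);   // e.g. 20 * 2.0F -> 40 steps
}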
/* create the first state */
predictor.Create(model, &encodingBeam, input, beamSize, first);
predictor.SetStartSymbol(startSymbol);
......@@ -118,15 +123,15 @@ void T2TSearch::Search(T2TModel* model, XTensor* input, XTensor* padding, XTenso
first->isStart = true;
/* generate the sequence from left to right */
for (int i = 0; i < maxLength; i++) {
cur = states + i;
next = states + i + 1;
for(int l = 0 ; l < lengthLimit; l++){
cur = states + l;
next = states + l + 1;
/* read the current state */
predictor.Read(model, cur);
/* predict the next state */
predictor.Predict(next, &encodingBeam, &inputBeam, &paddingBeam, i == 0);
predictor.Predict(next, &encodingBeam, &inputBeam, &paddingBeam, l == 0);
/* compute the model score (given the prediction probability) */
Score(cur, next);
......@@ -139,12 +144,18 @@ void T2TSearch::Search(T2TModel* model, XTensor* input, XTensor* padding, XTenso
/* push complete hypotheses into the heap */
Collect(next);
/* stop searching when all hypotheses are completed */
if(IsAllCompleted(next)){
maxLength = l + 1;
break;
}
}
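The early-stop check is the heart of this fix: once every hypothesis in the newest bundle is completed, maxLength is trimmed to the number of steps actually taken, so the later dump pass only walks positions that were really generated. The control flow, condensed into a self-contained sketch:

#include <vector>

// Run at most lengthLimit steps, but stop as soon as every hypothesis
// reports completion; the return value is the trimmed maxLength.
// completedAtStep[l] stands in for IsAllCompleted(next) at step l.
int RunWithEarlyStop(const std::vector<bool>& completedAtStep, int lengthLimit)
{
    for (int l = 0; l < lengthLimit; l++) {
        // ... Read / Predict / Score / Generate / Expand / Collect ...
        if (completedAtStep[l])
            return l + 1;
    }
    return lengthLimit;
}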
/* fill the heap with incomplete hypotheses if necessary */
FillHeap(next);
Dump(output);
Dump(output, score);
delete[] states;
}
......@@ -214,8 +225,7 @@ void T2TSearch::Score(T2TStateBundle* prev, T2TStateBundle* beam)
_DivDim(&probPath, &lp, &score, 0);
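The lp tensor divided out here is a length penalty computed earlier in Score (elided from this hunk). A common choice, and a plausible reading of the "lenalpha" option, is the GNMT-style penalty; sketched as an assumption, not as the repository's confirmed formula:

#include <cmath>

// GNMT-style length penalty (an assumption about lp's definition):
// lp = ((5 + len) / 6)^alpha, with alpha from the "lenalpha" option.
float LengthPenalty(int len, float alpha)
{
    return std::pow((5.0F + (float)len) / 6.0F, alpha);
}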
if (prev->isStart) {
XTensor firstMask;
firstMask = MakeFirstMask(beam);
XTensor firstMask = MakeFirstMask(beam);
firstMask.Reshape(firstMask.unitNum);
/* mask the hypotheses in the beam except the first one */
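At the first step all beamSize entries of a beam are copies of the same start state, so without this mask TopK would fill the beam with identical hypotheses. A plausible reading of MakeFirstMask's effect, sketched outside the NiuTensor API:

#include <vector>

// Additive mask over beam slots at the first step (an assumption about
// what MakeFirstMask produces, not code from the repository): slot 0
// passes through unchanged, every other slot is suppressed by -1e9.
std::vector<float> MakeFirstStepMask(int beamSize)
{
    std::vector<float> mask(beamSize, -1e9F);
    if (beamSize > 0)
        mask[0] = 0.0F;
    return mask;
}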
......@@ -252,22 +262,24 @@ void T2TSearch::Generate(T2TStateBundle* beam)
int dimsTopK[MAX_TENSOR_DIM_NUM];
XTensor scoreTopK;
XTensor& score = beam->modelScore;
XTensor& index = beam->prediction;
XTensor& preID = beam->preID;
XTensor& probPath = beam->probPath;
XTensor& prob = beam->prob;
XTensor indexCPU;
XTensor &score = beam->modelScore;
XTensor &index = beam->prediction;
XTensor &preID = beam->preID;
XTensor &probPath = beam->probPath;
XTensor &prob = beam->prob;
int order = score.order;
CheckNTErrors(order >= 3, "The tensor must be of order 3 or larger.");
CheckNTErrors(dimsBeam[order - 3] % beamSize == 0, "Wrong dimension size!");
for (int i = 0; i < order; i++) {
dims[i] = score.GetDim(i);
dimsBeam[i] = score.GetDim(i);
dimsTopK[i] = score.GetDim(i);
}
CheckNTErrors(order >= 3, "The tensor must be of order 3 or larger.");
CheckNTErrors(dimsBeam[order - 3] % beamSize == 0, "Wrong dimension size!");
int sizeVocab = score.GetDim(-1);
int stride = score.GetDim(-1);
......@@ -279,23 +291,23 @@ void T2TSearch::Generate(T2TStateBundle* beam)
InitTensorV2(&scoreTopK, order, dimsTopK, score.dataType, 1.0F, score.devID);
InitTensorV2(&index, order, dimsTopK, X_INT, 1.0F, score.devID);
InitTensorV2(&preID, order, dimsTopK, X_INT, 1.0F, -1);
InitTensorV2(&indexCPU, order, dimsTopK, X_INT, 1.0F, -1);
/* mask the first and the padding id */
int dimMask[]{ score.GetDim(-1) };
/* TODO: check the mask - mask the first and the padding id */
/*int dimMask[]{ score.GetDim(-1) };
XTensor mask;
InitTensorV2(&mask, 1, dimMask, X_FLOAT, 1.0F, -1);
mask.SetZeroAll();
mask.Set1D(-1e9F, 0);
mask.Set1D(-1e9F, 1);
mask.SetDevice(score.devID, score.mem);
mask.SetDevice(score.devID);
_SumDim(&score, &mask, 2);*/
_SumDim(&score, &mask, 2);
score.Reshape(order, dimsBeam);
/* keep the most promising candidates in the beam */
/* TODO: check this line */
TopK(score, scoreTopK, index, -1, beamSize);
CopyValues(index, indexCPU);
CopyValues(index, preID);
/* "preID" represents the id (or the offset) of the previous state used to make the current
......@@ -317,12 +329,10 @@ void T2TSearch::Generate(T2TStateBundle* beam)
CopyValues(scoreTopK, score);
/* CPU data (TODO: remove GPU->CPU data copy!!!) */
XTensor indexGPU;
indexGPU = CopyValues(index);
for (int i = 0; i < indexGPU.unitNum; i += beamSize) {
for (int j = 0; j < beamSize; j++)
indexGPU.SetInt(i * stride + indexGPU.GetInt(i + j), i + j);
for (int i = 0; i < indexCPU.unitNum; i += beamSize){
for (int j = 0; j < beamSize; j++) {
indexCPU.SetInt(i * stride + indexCPU.GetInt(i + j), i + j);
}
}
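The ids returned by TopK are local to one block of the reshaped score matrix, which spans beamSize * stride entries (stride was captured as the vocabulary size before the reshape). Since i advances by beamSize per block, i * stride is exactly that block's base offset in the flattened tensor, producing the global offsets Gather consumes below. The same conversion on a plain buffer:

#include <vector>

// Convert block-local TopK ids to offsets into the flattened score
// tensor: each group of beamSize picks covers beamSize * stride scores,
// and i * stride (i stepping by beamSize) is that group's base offset.
void LocalToFlatIndices(std::vector<int>& idx, int beamSize, int stride)
{
    for (int i = 0; i + beamSize <= (int)idx.size(); i += beamSize)
        for (int j = 0; j < beamSize; j++)
            idx[i + j] = i * stride + idx[i + j];
}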
CheckNTErrors(IsSameShaped(prob, probPath), "Wrong tensor shape!");
......@@ -339,13 +349,13 @@ void T2TSearch::Generate(T2TStateBundle* beam)
}
order = probPath.order;
probPath.Reshape(1, probPath.unitNum);
probPathTopK.Reshape(1, probPathTopK.unitNum);
prob.Reshape(1, prob.unitNum);
probTopK.Reshape(1, probTopK.unitNum);
_CopyIndexed(&probPath, &probPathTopK, probPathTopK.order - 1, &indexGPU);
_CopyIndexed(&prob, &probTopK, probTopK.order - 1, &indexGPU);
prob.Reshape(prob.unitNum, 1);
probPath.Reshape(probPath.unitNum, 1);
indexCPU.Reshape(indexCPU.GetDim(0), indexCPU.GetDim(-1));
probTopK = Gather(prob, indexCPU);
probPathTopK = Gather(probPath, indexCPU);
probPath.Reshape(order, dims);
probPathTopK.Reshape(order, dimsTopK);
......@@ -440,7 +450,9 @@ void T2TSearch::Expand(T2TStateBundle* prev, T2TStateBundle* beam)
CheckNTErrors(state.prediction >= 0, "Illegal prediction!");
/* check if it is the end of the sequence */
state.isEnd = IsEnd(state.prediction);
state.isCompleted = (state.isCompleted || state.isEnd);
/* set the ending mark */
......@@ -471,8 +483,9 @@ void T2TSearch::Collect(T2TStateBundle* beam)
bool isCompleted = state.isCompleted && (state.last == NULL || !state.last->isCompleted);
/* we push the hypothesis into the heap when it is completed */
if (state.isEnd != 0)
if (state.isEnd && isCompleted) {
fullHypos[state.pid].Push(HeapNode<float>(&state, state.modelScore));
}
}
}
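fullHypos keeps one score-ordered heap per sentence in the batch; a hypothesis is pushed exactly once, at the step where it first completes, and Dump later pops the entries back out. A standard-library stand-in for that container, assuming XHeap<MIN_HEAP, float> behaves as a plain min-heap:

#include <queue>
#include <vector>

// One finished hypothesis: its model score plus an opaque state pointer
// (HeapNode<float> carries the same value/index pair in the real code).
struct Hypo { float score; const void* state; };

// Order with the smallest score on top (a min-heap), so the best
// hypothesis is the last one popped.
struct MinByScore {
    bool operator()(const Hypo& a, const Hypo& b) const { return a.score > b.score; }
};
using HypoHeap = std::priority_queue<Hypo, std::vector<Hypo>, MinByScore>;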
......@@ -492,10 +505,14 @@ void T2TSearch::FillHeap(T2TStateBundle* beam)
T2TState& state = states[i];
CheckNTErrors(state.pid >= 0 && state.pid < batchSize,
"Invalid sample id!");
"Invalid sample id!");
/* check if this is the first end symbol. It is false
if an end symbol already appeared among previously generated words. */
bool isCompleted = state.isCompleted && (state.last == NULL || !state.last->isCompleted);
/* we push the incomplete hypothesis into the heap */
if (emptyFlags[state.pid] && state.isEnd == 0)
if (emptyFlags[state.pid] || state.isEnd || isCompleted)
fullHypos[state.pid].Push(HeapNode<float>(&state, state.modelScore));
}
......@@ -505,24 +522,29 @@ void T2TSearch::FillHeap(T2TStateBundle* beam)
/*
save the output sequences in a tensor
>> output - output sequences (for return)
>> score - score of the sequences
*/
void T2TSearch::Dump(XTensor* output)
void T2TSearch::Dump(XTensor * output, XTensor * score)
{
int dims[3] = { batchSize, beamSize, maxLength };
int* words = new int[maxLength];
InitTensorV2(output, 3, dims, X_INT);
InitTensorV2(score, 2, dims, X_FLOAT);
SetDataFixedInt(*output, -1);
score->SetZeroAll();
/* heap for an input sentence in the batch */
for (int h = 0; h < batchSize; h++) {
XHeap<MIN_HEAP, float>& heap = fullHypos[h];
XHeap<MIN_HEAP, float> &heap = fullHypos[h];
int c = heap.Count();
/* for each output in the beam */
for (int i = 0; i < beamSize && heap.Count() > 0; i++) {
T2TState* state = (T2TState*)heap.Pop().index;
for(int i = 0; i < beamSize && heap.Count() > 0; i++){
HeapNode<float> node = heap.Pop();
T2TState * state = (T2TState *)node.index;
int count = 0;
bool isCompleted = true;
......@@ -531,15 +553,17 @@ void T2TSearch::Dump(XTensor* output)
if (!state->isCompleted)
isCompleted = false;
if (isCompleted)
words[count++] = -1;
words[count++] = 2;
else
words[count++] = state->prediction;
state = state->last;
}
/* dump the sentence to the output tensor */
for (int w = 0; w < count; w++)
output->Set3DInt(words[count - w - 1], h, beamSize - i - 1, w);
for(int w = 0; w < count; w++)
output->Set3DInt(words[count - w - 1], h, c - i - 1, w);
score->Set2D(node.value, h, c - i - 1);
}
}
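Each T2TState stores only its own prediction plus a pointer to its predecessor, so a sentence is reconstructed by walking the last-pointer chain from the final state and reversing, which is what the words[] buffer above implements. A self-contained sketch with a minimal stand-in for T2TState:

#include <vector>
#include <algorithm>

// Minimal stand-in for T2TState: one predicted token and a link to the
// previous state in the hypothesis.
struct State {
    int prediction;
    const State* last;
};

// Walk the chain from the final state back to the start, then reverse
// so the tokens come out in left-to-right order.
std::vector<int> BackTrace(const State* s)
{
    std::vector<int> words;
    for (; s != nullptr; s = s->last)
        words.push_back(s->prediction);
    std::reverse(words.begin(), words.end());
    return words;
}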
......@@ -582,6 +606,23 @@ void T2TSearch::SetEnd(const int* tokens, const int tokenNum)
endSymbolNum = tokenNum;
}
/*
check whether all hypotheses are completed
>> beam - the beam that keeps the searching states
*/
bool T2TSearch::IsAllCompleted(T2TStateBundle * beam)
{
T2TState * states = beam->states;
for (int i = 0; i < beam->stateNum; i++) {
T2TState & state = states[i];
if(!state.isCompleted)
return false;
}
return true;
}
/*
make a mask to prevent duplicated entries in beam expansion for the first position
>> beam - the beam that keeps the searching states
......
......@@ -62,6 +62,12 @@ private:
/* start symbol */
int startSymbol;
/* scalar applied to the input-sequence length (to obtain the max number of search steps) */
float scalarMaxLength;
/* indicate whether the early stop strategy is used */
bool isEarlyStop;
public:
/* constructor */
T2TSearch();
......@@ -73,7 +79,7 @@ public:
void Init(int argc, char** argv);
/* search for the most promising states */
void Search(T2TModel* model, XTensor* input, XTensor* padding, XTensor* output);
void Search(T2TModel* model, XTensor* input, XTensor* padding, XTensor* output, XTensor* score);
/* preparation */
void Prepare(int myBatchSize, int myBeamSize);
......@@ -93,12 +99,15 @@ public:
/* fill the hypothesis heap with incomplete hypotheses */
void FillHeap(T2TStateBundle* beam);
/* save the output sequences in a tensor */
void Dump(XTensor* output);
/* save the output sequences and score */
void Dump(XTensor* output, XTensor* score);
/* check if the token is an end symbol */
bool IsEnd(int token);
/* check whether all hypotheses are completed */
bool IsAllCompleted(T2TStateBundle* beam);
/* set end symbols for search */
void SetEnd(const int* tokens, const int tokenNum);
......
......@@ -101,9 +101,10 @@ void T2TTester::Test(const char* fn, const char* ofn, T2TModel* model)
vector<int> indices = batchLoader.LoadBatch(&batchEnc, &paddingEnc, sentBatch, devID);
XTensor output;
seacher.Search(model, &batchEnc, &paddingEnc, &output);
output.Dump(stderr);
XTensor score;
seacher.Search(model, &batchEnc, &paddingEnc, &output, &score);
for (int i = 0; i < indices.size(); ++i) {
Result res;
XTensor sent, srcIdx, tgtIdx;
......@@ -127,9 +128,7 @@ void T2TTester::Test(const char* fn, const char* ofn, T2TModel* model)
if (batchCount % 1 == 0) {
double elapsed = GetClockSec() - startT;
XPRINT3(0, stderr,
"[INFO] elapsed=%.1fs, sentence=%d, sword=%d\n",
elapsed, sentCount, wordCount);
XPRINT3(0, stderr, "[INFO] elapsed=%.1fs, sentence=%d, sword=%d\n", elapsed, sentCount, wordCount);
}
}
......@@ -160,9 +159,10 @@ void T2TTester::Dump(FILE* file, XTensor* output)
for (int i = 0; i < output->unitNum; i += seqLength) {
for (int j = 0; j < seqLength; j++) {
int w = output->GetInt(i + j);
fprintf(file, "%d ", w);
if (w < 0)
if (w < 0 || w == 1)
break;
fprintf(file, "%d ", w);
}
fprintf(file, "\n");
......