Commit 7375289d by xiaotong

fix a bug in generating path probabilities (path-probs)

parent ace64052
......@@ -72,6 +72,7 @@ void T2TStateBundle::MakeStates(int num)
/* constructor: the start symbol is left unset (-1) until SetStartSymbol() is called */
T2TPredictor::T2TPredictor()
{
    /* -1 acts as the "not specified" sentinel for the start symbol id */
    this->startSymbol = -1;
}
/* de-constructor */
......@@ -113,6 +114,15 @@ void T2TPredictor::Create(T2TModel * model, XTensor * top, const XTensor * input
state->stateNum = 0;
}
/*
set start symbol
>> symbol - the symbol (in integer)
*/
void T2TPredictor::SetStartSymbol(int symbol)
{
startSymbol = symbol;
}
/*
read a state
......@@ -150,28 +160,24 @@ void T2TPredictor::Predict(T2TStateBundle * next, XTensor * encoding, XTensor *
/* word indices of positions up to next state */
XTensor &inputDec = *NewTensor();
/* a dummy word that used to as a placeholder when we process the next work */
XTensor dummy;
/* the first token */
XTensor first;
for(int i = 0; i < inputEnc->order - 1; i++)
dims[i] = inputEnc->GetDim(i);
dims[inputEnc->order - 1] = 1;
InitTensor(&dummy, inputEnc->order, dims, X_INT, 1.0F, inputEnc->devID, inputEnc->mem);
dummy.SetZeroAll();
InitTensor(&first, inputEnc->order, dims, X_INT, 1.0F, inputEnc->devID, inputEnc->mem);
_SetDataFixedInt(&first, startSymbol);
/* add a new word into the input sequence of the decoder side */
if(inputLast == NULL)
inputDec = Identity(dummy);
if(inputLast == NULL){
inputDec = Identity(first);
}
else{
inputDec = GeneratePaths(s);
for(int i = 0; i < inputEnc->order - 1; i++)
dims[i] = inputEnc->GetDim(i);
dims[inputEnc->order - 1] = inputDec.GetDim(-1);
inputDec.Resize(inputEnc->order, dims, X_INT);
inputDec.SetDevice(inputEnc->devID, inputEnc->mem);
inputDec = Concatenate(inputDec, dummy, inputDec.order - 1);
inputDec = Concatenate(inputDec, first, inputDec.order - 1);
}
/* prediction probabilities */
......@@ -193,10 +199,11 @@ void T2TPredictor::Predict(T2TStateBundle * next, XTensor * encoding, XTensor *
/* decoder mask */
m->MakeMTMaskDec(*inputEnc, inputDec, *paddingEnc, paddingDec, maskDec, maskEncDec);
//inputEnc->Dump(stderr, "inputenc:");
//paddingEnc->Dump(stderr, "paddingenc:");
inputDec.Dump(stderr, "inputdec: ");
//encoding->Dump(stderr, "encoding: ");
maskDec.Dump(stderr, "maskdec: ");
maskEncDec.Dump(stderr, "mask-enc-dec: ");
//maskDec.Dump(stderr, "maskdec: ");
//maskEncDec.Dump(stderr, "mask-enc-dec: ");
/* make the decoding network */
decoding = decoder.Make(inputDec, *encoding, maskDec, maskEncDec, false);
......
......@@ -131,6 +131,9 @@ private:
/* current state */
T2TStateBundle * s;
/* start symbol */
int startSymbol;
public:
/* constructor */
......@@ -141,6 +144,9 @@ public:
/* create an initial state */
void Create(T2TModel * model, XTensor * top, const XTensor * input, int beamSize, T2TStateBundle * state);
/* set the start symbol */
void SetStartSymbol(int symbol);
/* read a state */
void Read(T2TModel * model, T2TStateBundle * state);
......
......@@ -38,6 +38,7 @@ T2TSearch::T2TSearch()
endSymbolNum = 0;
fullHypos = NULL;
endSymbols = new int[32];
startSymbol = -1;
}
/* de-constructor */
......@@ -60,6 +61,7 @@ void T2TSearch::Init(int argc, char ** argv)
LoadParamInt(argc, argv, "batchsize", &batchSize, 1);
LoadParamFloat(argc, argv, "lenalpha", &alpha, 0.2F);
LoadParamInt(argc, argv, "endid", endSymbols, -1);
LoadParamInt(argc, argv, "startid", &startSymbol, -1);
if(endSymbols[0] >= 0)
endSymbolNum = 1;
......@@ -79,12 +81,16 @@ void T2TSearch::Search(T2TModel * model, XTensor * input, XTensor * padding, XTe
T2TPredictor predictor;
CheckNTErrors(endSymbolNum > 0, "The search class is not initialized!");
CheckNTErrors(startSymbol >= 0, "The search class is not initialized!");
Prepare(input->unitNum/input->GetDim(-1), beamSize);
/* encoder mask */
model->MakeMTMaskEnc(*input, *padding, maskEnc);
//input->Dump(stderr, "input:");
//maskEnc.Dump(stderr, "maskenc:");
/* make the encoding network */
encoding = model->MakeEncoder(*input, maskEnc, false);
encoding.SetName(ENCODING_NAME);
......@@ -98,6 +104,7 @@ void T2TSearch::Search(T2TModel * model, XTensor * input, XTensor * padding, XTe
/* create the first state */
predictor.Create(model, &encoding, input, beamSize, first);
predictor.SetStartSymbol(startSymbol);
first->isStart = true;
......@@ -161,6 +168,7 @@ void T2TSearch::Score(T2TStateBundle * prev, T2TStateBundle * beam)
{
XTensor &score = beam->modelScore;
XTensor &prob = beam->prob;
XTensor &probPath = beam->probPath;
XTensor &probPathPrev = prev->probPath;
XTensor &lenPrev = prev->nstep;
XTensor &len = beam->nstep;
......@@ -174,13 +182,15 @@ void T2TSearch::Score(T2TStateBundle * prev, T2TStateBundle * beam)
dims[i] = prob.GetDim(i);
InitTensor(&score, &prob);
InitTensor(&probPath, &prob);
prob.Reshape(prob.unitNum/outputSize, outputSize);
score.Reshape(score.unitNum/outputSize, outputSize);
probPath.Reshape(score.unitNum/outputSize, outputSize);
probPathPrev.Reshape(probPathPrev.unitNum);
/* the log-scale probability of the entire sequence */
_SumDim(&prob, &probPathPrev, &score, 0);
_SumDim(&prob, &probPathPrev, &probPath, 0);
InitTensor(&len, &lenPrev);
InitTensor(&lp, &lenPrev);
......@@ -191,9 +201,11 @@ void T2TSearch::Score(T2TStateBundle * prev, T2TStateBundle * beam)
lp = T2TLengthPenalizer::GNMT(len, alpha);
lp.Reshape(lp.unitNum);
lp.Dump(stderr, "lp:");
/* score = log-prob/lp */
_DivDim(&score, &lp, &score, 0);
_DivDim(&probPath, &lp, &score, 0);
InitTensor(&mask,
prev->endMark.order, prev->endMark.dimSize, X_FLOAT, 1.0F,
......@@ -208,6 +220,7 @@ void T2TSearch::Score(T2TStateBundle * prev, T2TStateBundle * beam)
prob.Reshape(order, dims);
score.Reshape(order, dims);
probPath.Reshape(order, dims);
probPathPrev.Reshape(order - 1, dims);
lp.Reshape(order - 1, dims);
mask.Reshape(order -1 , dims);
......@@ -263,6 +276,7 @@ void T2TSearch::Generate(T2TStateBundle * beam)
CopyValues(index, preID);
int sizeVocab = score.GetDim(-1);
int stride = score.GetDim(-1);
/* "preID" represents the id (or the offset) of previous state used to make the current
hypothesis. Note that we reshape the "score" tensor into a matrix where each
......@@ -286,10 +300,33 @@ void T2TSearch::Generate(T2TStateBundle * beam)
XTensor indexCPU;
InitTensor(&indexCPU, index.order, index.dimSize, index.dataType, index.denseRatio, -1);
CopyValues(index, indexCPU);
for(int i = 0; i < indexCPU.unitNum; i++)
indexCPU.SetInt(i * stride + indexCPU.GetInt(i), i);
/* sequence probability of top-k candidates */
InitTensor(&probPath, &scoreTopK);
_Gather(&beam->prob, &probPath, probPath.order - 1, (int*)indexCPU.data, indexCPU.unitNum);
XTensor probPathTopK;
InitTensor(&probPathTopK, &scoreTopK);
for(int i = 0; i < probPath.order; i++){
dims[i] = probPath.GetDim(i);
dimsTopK[i] = probPathTopK.GetDim(i);
}
order = probPath.order;
probPath.Reshape(1, probPath.unitNum);
probPathTopK.Reshape(1, probPathTopK.unitNum);
_Gather(&probPath, &probPathTopK, probPathTopK.order - 1, (int*)indexCPU.data, indexCPU.unitNum);
probPath.Reshape(order, dims);
probPathTopK.Reshape(order, dimsTopK);
indexCPU.Dump(stderr, "indexcpu:");
scoreTopK.Dump(stderr, "scoretopk:");
probPathTopK.Dump(stderr, "probpathtopk:");
probPath = probPathTopK;
}
/*
......@@ -350,10 +387,12 @@ void T2TSearch::Expand(T2TStateBundle * prev, T2TStateBundle * beam)
if (prev->isStart) {
state.last = NULL;
state.pid = i;
state.nstep = 0;
}
else{
state.last = last;
state.pid = state.last->pid;
state.nstep = last->nstep + 1;
CheckNTErrors(offset < prev->stateNum, "Wrong state index!");
}
......
......@@ -58,6 +58,9 @@ private:
/* number of the end symbols */
int endSymbolNum;
/* start symbol */
int startSymbol;
public:
/* constructor */
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论