Commit 7375289d by xiaotong

fix the bug of generating path-probs

parent ace64052
...@@ -72,6 +72,7 @@ void T2TStateBundle::MakeStates(int num) ...@@ -72,6 +72,7 @@ void T2TStateBundle::MakeStates(int num)
/* constructor */ /* constructor */
T2TPredictor::T2TPredictor() T2TPredictor::T2TPredictor()
{ {
startSymbol = -1;
} }
/* de-constructor */ /* de-constructor */
...@@ -115,6 +116,15 @@ void T2TPredictor::Create(T2TModel * model, XTensor * top, const XTensor * input ...@@ -115,6 +116,15 @@ void T2TPredictor::Create(T2TModel * model, XTensor * top, const XTensor * input
} }
/* /*
set start symbol
>> symbol - the symbol (in integer)
*/
void T2TPredictor::SetStartSymbol(int symbol)
{
startSymbol = symbol;
}
/*
read a state read a state
>> model - the t2t model that keeps the network created so far >> model - the t2t model that keeps the network created so far
>> state - a set of states. It keeps >> state - a set of states. It keeps
...@@ -150,28 +160,24 @@ void T2TPredictor::Predict(T2TStateBundle * next, XTensor * encoding, XTensor * ...@@ -150,28 +160,24 @@ void T2TPredictor::Predict(T2TStateBundle * next, XTensor * encoding, XTensor *
/* word indices of positions up to next state */ /* word indices of positions up to next state */
XTensor &inputDec = *NewTensor(); XTensor &inputDec = *NewTensor();
/* a dummy word that used to as a placeholder when we process the next work */ /* the first token */
XTensor dummy; XTensor first;
for(int i = 0; i < inputEnc->order - 1; i++) for(int i = 0; i < inputEnc->order - 1; i++)
dims[i] = inputEnc->GetDim(i); dims[i] = inputEnc->GetDim(i);
dims[inputEnc->order - 1] = 1; dims[inputEnc->order - 1] = 1;
InitTensor(&dummy, inputEnc->order, dims, X_INT, 1.0F, inputEnc->devID, inputEnc->mem); InitTensor(&first, inputEnc->order, dims, X_INT, 1.0F, inputEnc->devID, inputEnc->mem);
dummy.SetZeroAll(); _SetDataFixedInt(&first, startSymbol);
/* add a new word into the input sequence of the decoder side */ /* add a new word into the input sequence of the decoder side */
if(inputLast == NULL) if(inputLast == NULL){
inputDec = Identity(dummy); inputDec = Identity(first);
}
else{ else{
inputDec = GeneratePaths(s); inputDec = GeneratePaths(s);
for(int i = 0; i < inputEnc->order - 1; i++)
dims[i] = inputEnc->GetDim(i);
dims[inputEnc->order - 1] = inputDec.GetDim(-1);
inputDec.Resize(inputEnc->order, dims, X_INT);
inputDec.SetDevice(inputEnc->devID, inputEnc->mem); inputDec.SetDevice(inputEnc->devID, inputEnc->mem);
inputDec = Concatenate(inputDec, first, inputDec.order - 1);
inputDec = Concatenate(inputDec, dummy, inputDec.order - 1);
} }
/* prediction probabilities */ /* prediction probabilities */
...@@ -193,10 +199,11 @@ void T2TPredictor::Predict(T2TStateBundle * next, XTensor * encoding, XTensor * ...@@ -193,10 +199,11 @@ void T2TPredictor::Predict(T2TStateBundle * next, XTensor * encoding, XTensor *
/* decoder mask */ /* decoder mask */
m->MakeMTMaskDec(*inputEnc, inputDec, *paddingEnc, paddingDec, maskDec, maskEncDec); m->MakeMTMaskDec(*inputEnc, inputDec, *paddingEnc, paddingDec, maskDec, maskEncDec);
//inputEnc->Dump(stderr, "inputenc:");
//paddingEnc->Dump(stderr, "paddingenc:");
inputDec.Dump(stderr, "inputdec: "); inputDec.Dump(stderr, "inputdec: ");
//encoding->Dump(stderr, "encoding: "); //maskDec.Dump(stderr, "maskdec: ");
maskDec.Dump(stderr, "maskdec: "); //maskEncDec.Dump(stderr, "mask-enc-dec: ");
maskEncDec.Dump(stderr, "mask-enc-dec: ");
/* make the decoding network */ /* make the decoding network */
decoding = decoder.Make(inputDec, *encoding, maskDec, maskEncDec, false); decoding = decoder.Make(inputDec, *encoding, maskDec, maskEncDec, false);
......
...@@ -132,6 +132,9 @@ private: ...@@ -132,6 +132,9 @@ private:
/* current state */ /* current state */
T2TStateBundle * s; T2TStateBundle * s;
/* start symbol */
int startSymbol;
public: public:
/* constructor */ /* constructor */
T2TPredictor(); T2TPredictor();
...@@ -142,6 +145,9 @@ public: ...@@ -142,6 +145,9 @@ public:
/* create an initial state */ /* create an initial state */
void Create(T2TModel * model, XTensor * top, const XTensor * input, int beamSize, T2TStateBundle * state); void Create(T2TModel * model, XTensor * top, const XTensor * input, int beamSize, T2TStateBundle * state);
/* set the start symbol */
void SetStartSymbol(int symbol);
/* read a state */ /* read a state */
void Read(T2TModel * model, T2TStateBundle * state); void Read(T2TModel * model, T2TStateBundle * state);
......
...@@ -38,6 +38,7 @@ T2TSearch::T2TSearch() ...@@ -38,6 +38,7 @@ T2TSearch::T2TSearch()
endSymbolNum = 0; endSymbolNum = 0;
fullHypos = NULL; fullHypos = NULL;
endSymbols = new int[32]; endSymbols = new int[32];
startSymbol = -1;
} }
/* de-constructor */ /* de-constructor */
...@@ -60,6 +61,7 @@ void T2TSearch::Init(int argc, char ** argv) ...@@ -60,6 +61,7 @@ void T2TSearch::Init(int argc, char ** argv)
LoadParamInt(argc, argv, "batchsize", &batchSize, 1); LoadParamInt(argc, argv, "batchsize", &batchSize, 1);
LoadParamFloat(argc, argv, "lenalpha", &alpha, 0.2F); LoadParamFloat(argc, argv, "lenalpha", &alpha, 0.2F);
LoadParamInt(argc, argv, "endid", endSymbols, -1); LoadParamInt(argc, argv, "endid", endSymbols, -1);
LoadParamInt(argc, argv, "startid", &startSymbol, -1);
if(endSymbols[0] >= 0) if(endSymbols[0] >= 0)
endSymbolNum = 1; endSymbolNum = 1;
...@@ -79,12 +81,16 @@ void T2TSearch::Search(T2TModel * model, XTensor * input, XTensor * padding, XTe ...@@ -79,12 +81,16 @@ void T2TSearch::Search(T2TModel * model, XTensor * input, XTensor * padding, XTe
T2TPredictor predictor; T2TPredictor predictor;
CheckNTErrors(endSymbolNum > 0, "The search class is not initialized!"); CheckNTErrors(endSymbolNum > 0, "The search class is not initialized!");
CheckNTErrors(startSymbol >= 0, "The search class is not initialized!");
Prepare(input->unitNum/input->GetDim(-1), beamSize); Prepare(input->unitNum/input->GetDim(-1), beamSize);
/* encoder mask */ /* encoder mask */
model->MakeMTMaskEnc(*input, *padding, maskEnc); model->MakeMTMaskEnc(*input, *padding, maskEnc);
//input->Dump(stderr, "input:");
//maskEnc.Dump(stderr, "maskenc:");
/* make the encoding network */ /* make the encoding network */
encoding = model->MakeEncoder(*input, maskEnc, false); encoding = model->MakeEncoder(*input, maskEnc, false);
encoding.SetName(ENCODING_NAME); encoding.SetName(ENCODING_NAME);
...@@ -98,6 +104,7 @@ void T2TSearch::Search(T2TModel * model, XTensor * input, XTensor * padding, XTe ...@@ -98,6 +104,7 @@ void T2TSearch::Search(T2TModel * model, XTensor * input, XTensor * padding, XTe
/* create the first state */ /* create the first state */
predictor.Create(model, &encoding, input, beamSize, first); predictor.Create(model, &encoding, input, beamSize, first);
predictor.SetStartSymbol(startSymbol);
first->isStart = true; first->isStart = true;
...@@ -161,6 +168,7 @@ void T2TSearch::Score(T2TStateBundle * prev, T2TStateBundle * beam) ...@@ -161,6 +168,7 @@ void T2TSearch::Score(T2TStateBundle * prev, T2TStateBundle * beam)
{ {
XTensor &score = beam->modelScore; XTensor &score = beam->modelScore;
XTensor &prob = beam->prob; XTensor &prob = beam->prob;
XTensor &probPath = beam->probPath;
XTensor &probPathPrev = prev->probPath; XTensor &probPathPrev = prev->probPath;
XTensor &lenPrev = prev->nstep; XTensor &lenPrev = prev->nstep;
XTensor &len = beam->nstep; XTensor &len = beam->nstep;
...@@ -174,13 +182,15 @@ void T2TSearch::Score(T2TStateBundle * prev, T2TStateBundle * beam) ...@@ -174,13 +182,15 @@ void T2TSearch::Score(T2TStateBundle * prev, T2TStateBundle * beam)
dims[i] = prob.GetDim(i); dims[i] = prob.GetDim(i);
InitTensor(&score, &prob); InitTensor(&score, &prob);
InitTensor(&probPath, &prob);
prob.Reshape(prob.unitNum/outputSize, outputSize); prob.Reshape(prob.unitNum/outputSize, outputSize);
score.Reshape(score.unitNum/outputSize, outputSize); score.Reshape(score.unitNum/outputSize, outputSize);
probPath.Reshape(score.unitNum/outputSize, outputSize);
probPathPrev.Reshape(probPathPrev.unitNum); probPathPrev.Reshape(probPathPrev.unitNum);
/* the log-scale probability of the entire sequence */ /* the log-scale probability of the entire sequence */
_SumDim(&prob, &probPathPrev, &score, 0); _SumDim(&prob, &probPathPrev, &probPath, 0);
InitTensor(&len, &lenPrev); InitTensor(&len, &lenPrev);
InitTensor(&lp, &lenPrev); InitTensor(&lp, &lenPrev);
...@@ -192,8 +202,10 @@ void T2TSearch::Score(T2TStateBundle * prev, T2TStateBundle * beam) ...@@ -192,8 +202,10 @@ void T2TSearch::Score(T2TStateBundle * prev, T2TStateBundle * beam)
lp.Reshape(lp.unitNum); lp.Reshape(lp.unitNum);
lp.Dump(stderr, "lp:");
/* score = log-prob/lp */ /* score = log-prob/lp */
_DivDim(&score, &lp, &score, 0); _DivDim(&probPath, &lp, &score, 0);
InitTensor(&mask, InitTensor(&mask,
prev->endMark.order, prev->endMark.dimSize, X_FLOAT, 1.0F, prev->endMark.order, prev->endMark.dimSize, X_FLOAT, 1.0F,
...@@ -208,6 +220,7 @@ void T2TSearch::Score(T2TStateBundle * prev, T2TStateBundle * beam) ...@@ -208,6 +220,7 @@ void T2TSearch::Score(T2TStateBundle * prev, T2TStateBundle * beam)
prob.Reshape(order, dims); prob.Reshape(order, dims);
score.Reshape(order, dims); score.Reshape(order, dims);
probPath.Reshape(order, dims);
probPathPrev.Reshape(order - 1, dims); probPathPrev.Reshape(order - 1, dims);
lp.Reshape(order - 1, dims); lp.Reshape(order - 1, dims);
mask.Reshape(order -1 , dims); mask.Reshape(order -1 , dims);
...@@ -263,6 +276,7 @@ void T2TSearch::Generate(T2TStateBundle * beam) ...@@ -263,6 +276,7 @@ void T2TSearch::Generate(T2TStateBundle * beam)
CopyValues(index, preID); CopyValues(index, preID);
int sizeVocab = score.GetDim(-1); int sizeVocab = score.GetDim(-1);
int stride = score.GetDim(-1);
/* "preID" represents the id (or the offset) of previous state used to make the current /* "preID" represents the id (or the offset) of previous state used to make the current
hypothesis. Note that we reshape the "score" tensor into a matrix where each hypothesis. Note that we reshape the "score" tensor into a matrix where each
...@@ -287,9 +301,32 @@ void T2TSearch::Generate(T2TStateBundle * beam) ...@@ -287,9 +301,32 @@ void T2TSearch::Generate(T2TStateBundle * beam)
InitTensor(&indexCPU, index.order, index.dimSize, index.dataType, index.denseRatio, -1); InitTensor(&indexCPU, index.order, index.dimSize, index.dataType, index.denseRatio, -1);
CopyValues(index, indexCPU); CopyValues(index, indexCPU);
for(int i = 0; i < indexCPU.unitNum; i++)
indexCPU.SetInt(i * stride + indexCPU.GetInt(i), i);
/* sequence probability of top-k candidates */ /* sequence probability of top-k candidates */
InitTensor(&probPath, &scoreTopK); XTensor probPathTopK;
_Gather(&beam->prob, &probPath, probPath.order - 1, (int*)indexCPU.data, indexCPU.unitNum); InitTensor(&probPathTopK, &scoreTopK);
for(int i = 0; i < probPath.order; i++){
dims[i] = probPath.GetDim(i);
dimsTopK[i] = probPathTopK.GetDim(i);
}
order = probPath.order;
probPath.Reshape(1, probPath.unitNum);
probPathTopK.Reshape(1, probPathTopK.unitNum);
_Gather(&probPath, &probPathTopK, probPathTopK.order - 1, (int*)indexCPU.data, indexCPU.unitNum);
probPath.Reshape(order, dims);
probPathTopK.Reshape(order, dimsTopK);
indexCPU.Dump(stderr, "indexcpu:");
scoreTopK.Dump(stderr, "scoretopk:");
probPathTopK.Dump(stderr, "probpathtopk:");
probPath = probPathTopK;
} }
/* /*
...@@ -350,10 +387,12 @@ void T2TSearch::Expand(T2TStateBundle * prev, T2TStateBundle * beam) ...@@ -350,10 +387,12 @@ void T2TSearch::Expand(T2TStateBundle * prev, T2TStateBundle * beam)
if (prev->isStart) { if (prev->isStart) {
state.last = NULL; state.last = NULL;
state.pid = i; state.pid = i;
state.nstep = 0;
} }
else{ else{
state.last = last; state.last = last;
state.pid = state.last->pid; state.pid = state.last->pid;
state.nstep = last->nstep + 1;
CheckNTErrors(offset < prev->stateNum, "Wrong state index!"); CheckNTErrors(offset < prev->stateNum, "Wrong state index!");
} }
......
...@@ -59,6 +59,9 @@ private: ...@@ -59,6 +59,9 @@ private:
/* number of the end symbols */ /* number of the end symbols */
int endSymbolNum; int endSymbolNum;
/* start symbol */
int startSymbol;
public: public:
/* constructor */ /* constructor */
T2TSearch(); T2TSearch();
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论