Commit 73f32216 by xiaotong

initialize the member tensors of the first state in search

parent 7294cb66
......@@ -134,7 +134,7 @@ public:
~T2TPredictor();
/* create an initial state */
void Create(T2TModel * model, XTensor * top, T2TStateBundle * state);
void Create(T2TModel * model, XTensor * top, const XTensor * input, int beamSize, T2TStateBundle * state);
/* read a state */
void Read(T2TModel * model, T2TStateBundle * state);
......
......@@ -83,7 +83,7 @@ void T2TSearch::Search(T2TModel * model, XTensor * input, XTensor * padding, XTe
T2TStateBundle * first = states;
/* create the first state */
predictor.Create(model, &encoding, first);
predictor.Create(model, &encoding, input, beamSize, first);
/* generate the sequence from left to right */
for(int i = 0 ; i < maxLength; i++){
......@@ -136,7 +136,7 @@ void T2TSearch::Score(T2TStateBundle * prev, T2TStateBundle * beam)
{
XTensor &score = beam->modelScore;
XTensor &prob = beam->prob;
XTensor &probPath = beam->probPath;
XTensor &probPathPrev = prev->probPath;
XTensor &lenPrev = prev->nstep;
XTensor &len = beam->nstep;
XTensor lp;
......@@ -145,7 +145,7 @@ void T2TSearch::Score(T2TStateBundle * prev, T2TStateBundle * beam)
InitTensor(&score, &prob);
/* the log-scale probability of the entire sequence */
_Sum(&prob, &probPath, &score);
_Sum(&prob, &probPathPrev, &score);
InitTensor(&len, &lenPrev);
InitTensor(&lp, &lenPrev);
......@@ -181,6 +181,7 @@ void T2TSearch::Generate(T2TStateBundle * beam)
XTensor &score = beam->modelScore;
XTensor &index = beam->prediction;
XTensor &preID = beam->preID;
XTensor &probPath = beam->probPath;
int order = score.order;
CheckNTErrors(order >= 2, "The tensor must be of order 2 or larger.");
......@@ -230,6 +231,10 @@ void T2TSearch::Generate(T2TStateBundle * beam)
/* we keep the top-k scores */
InitTensor(&score, &scoreTopK);
CopyValues(scoreTopK, score);
/* sequence probability of top-k candidates */
InitTensor(&probPath, &scoreTopK);
_Gather(&beam->prob, &probPath, probPath.order - 1, (int*)index.data, index.unitNum);
}
/*
......
......@@ -34,8 +34,8 @@ gather indexed sub-tensors
>> s - the source tensor
>> t - the target tensor
>> dim - the leading dimension to define "sub-tensors"
e.g., for a tensor of size (3, 2, 4) and dim = 2,
we have 4 sub-tensors of size (3, 2)
e.g., for a tensor of size (3, 2, 4) and dim = 0,
we have 3 sub-tensors of size (2, 4)
>> srcIndex - index of the source sub-tensors
>> indexSize - length of srcIndex (and tgtIndex)
*/
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论