Commit 434335e0 by xiaotong

generate top-k probability for each position

parent a4e6ab64
......@@ -62,7 +62,7 @@ void T2TSearch::Init(int argc, char ** argv)
LoadParamFloat(argc, argv, "lenalpha", &alpha, 0.2F);
LoadParamInt(argc, argv, "endid", endSymbols, -1);
LoadParamInt(argc, argv, "startid", &startSymbol, -1);
if(endSymbols[0] >= 0)
endSymbolNum = 1;
}
......@@ -260,6 +260,7 @@ void T2TSearch::Generate(T2TStateBundle * beam)
XTensor &index = beam->prediction;
XTensor &preID = beam->preID;
XTensor &probPath = beam->probPath;
XTensor &prob = beam->prob;
int order = score.order;
CheckNTErrors(order >= 3, "The tensor must be of order 2 or larger.");
......@@ -323,9 +324,13 @@ void T2TSearch::Generate(T2TStateBundle * beam)
for(int i = 0; i < indexCPU.unitNum; i++)
indexCPU.SetInt(i * stride + indexCPU.GetInt(i), i);
/* sequence probability of top-k candidates */
CheckNTErrors(XTensor::IsSameShaped(&prob, &probPath), "Wrong tensor shape!");
/* sequence probability and prediction probability of top-k candidates */
XTensor probPathTopK;
InitTensor(&probPathTopK, &scoreTopK);
XTensor probTopK;
InitTensor(&probTopK, &scoreTopK);
for(int i = 0; i < probPath.order; i++){
dims[i] = probPath.GetDim(i);
......@@ -335,19 +340,27 @@ void T2TSearch::Generate(T2TStateBundle * beam)
order = probPath.order;
probPath.Reshape(1, probPath.unitNum);
probPathTopK.Reshape(1, probPathTopK.unitNum);
prob.Reshape(1, prob.unitNum);
probTopK.Reshape(1, probTopK.unitNum);
indexCPU.Dump(stderr, "indexCPU:");
_Gather(&probPath, &probPathTopK, probPathTopK.order - 1, (int*)indexCPU.data, indexCPU.unitNum);
_Gather(&prob, &probTopK, probTopK.order - 1, (int*)indexCPU.data, indexCPU.unitNum);
probPath.Reshape(order, dims);
probPathTopK.Reshape(order, dimsTopK);
prob.Reshape(order, dims);
probTopK.Reshape(order, dimsTopK);
indexCPU.Dump(stderr, "indexcpu:");
scoreTopK.Dump(stderr, "scoretopk:");
probPathTopK.Dump(stderr, "probpathtopk:");
probTopK.Dump(stderr, "probtopk:");
probPath = probPathTopK;
prob = probTopK;
}
/*
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论