Commit 434335e0 by xiaotong

generate top-k probability for each position

parent a4e6ab64
...@@ -260,6 +260,7 @@ void T2TSearch::Generate(T2TStateBundle * beam) ...@@ -260,6 +260,7 @@ void T2TSearch::Generate(T2TStateBundle * beam)
XTensor &index = beam->prediction; XTensor &index = beam->prediction;
XTensor &preID = beam->preID; XTensor &preID = beam->preID;
XTensor &probPath = beam->probPath; XTensor &probPath = beam->probPath;
XTensor &prob = beam->prob;
int order = score.order; int order = score.order;
CheckNTErrors(order >= 3, "The tensor must be of order 2 or larger."); CheckNTErrors(order >= 3, "The tensor must be of order 2 or larger.");
...@@ -323,9 +324,13 @@ void T2TSearch::Generate(T2TStateBundle * beam) ...@@ -323,9 +324,13 @@ void T2TSearch::Generate(T2TStateBundle * beam)
for(int i = 0; i < indexCPU.unitNum; i++) for(int i = 0; i < indexCPU.unitNum; i++)
indexCPU.SetInt(i * stride + indexCPU.GetInt(i), i); indexCPU.SetInt(i * stride + indexCPU.GetInt(i), i);
/* sequence probability of top-k candidates */ CheckNTErrors(XTensor::IsSameShaped(&prob, &probPath), "Wrong tensor shape!");
/* sequence probability and prediction probability of top-k candidates */
XTensor probPathTopK; XTensor probPathTopK;
InitTensor(&probPathTopK, &scoreTopK); InitTensor(&probPathTopK, &scoreTopK);
XTensor probTopK;
InitTensor(&probTopK, &scoreTopK);
for(int i = 0; i < probPath.order; i++){ for(int i = 0; i < probPath.order; i++){
dims[i] = probPath.GetDim(i); dims[i] = probPath.GetDim(i);
...@@ -335,19 +340,27 @@ void T2TSearch::Generate(T2TStateBundle * beam) ...@@ -335,19 +340,27 @@ void T2TSearch::Generate(T2TStateBundle * beam)
order = probPath.order; order = probPath.order;
probPath.Reshape(1, probPath.unitNum); probPath.Reshape(1, probPath.unitNum);
probPathTopK.Reshape(1, probPathTopK.unitNum); probPathTopK.Reshape(1, probPathTopK.unitNum);
prob.Reshape(1, prob.unitNum);
probTopK.Reshape(1, probTopK.unitNum);
indexCPU.Dump(stderr, "indexCPU:"); indexCPU.Dump(stderr, "indexCPU:");
_Gather(&probPath, &probPathTopK, probPathTopK.order - 1, (int*)indexCPU.data, indexCPU.unitNum); _Gather(&probPath, &probPathTopK, probPathTopK.order - 1, (int*)indexCPU.data, indexCPU.unitNum);
_Gather(&prob, &probTopK, probTopK.order - 1, (int*)indexCPU.data, indexCPU.unitNum);
probPath.Reshape(order, dims); probPath.Reshape(order, dims);
probPathTopK.Reshape(order, dimsTopK); probPathTopK.Reshape(order, dimsTopK);
prob.Reshape(order, dims);
probTopK.Reshape(order, dimsTopK);
indexCPU.Dump(stderr, "indexcpu:"); indexCPU.Dump(stderr, "indexcpu:");
scoreTopK.Dump(stderr, "scoretopk:"); scoreTopK.Dump(stderr, "scoretopk:");
probPathTopK.Dump(stderr, "probpathtopk:"); probPathTopK.Dump(stderr, "probpathtopk:");
probTopK.Dump(stderr, "probtopk:");
probPath = probPathTopK; probPath = probPathTopK;
prob = probTopK;
} }
/* /*
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论