Commit d954656b by huchi

fix a bug for gpu version

parent 6bc6a058
......@@ -354,6 +354,7 @@ void T2TSearch::Generate(T2TStateBundle* beam)
probPath.Reshape(probPath.unitNum, 1);
indexCPU.Reshape(indexCPU.GetDim(0), indexCPU.GetDim(-1));
indexCPU.SetDevice(prob.devID);
probTopK = Gather(prob, indexCPU);
probPathTopK = Gather(probPath, indexCPU);
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论