Commit a1835883 by xiaotong

fixing bug of wrong tensor orders

parent 4dbd1f23
......@@ -96,12 +96,12 @@ void T2TSearch::Search(T2TModel * model, XTensor * input, XTensor * padding, XTe
encoding.SetName(ENCODING_NAME);
XTensor encodingBeam = Unsqueeze(encoding, encoding.order - 2, beamSize);
XTensor inputBeam = Unsqueeze(*input, input->order - 2, beamSize);
XTensor paddingBeam = Unsqueeze(*padding, padding->order - 2, beamSize);
XTensor inputBeam = Unsqueeze(*input, input->order - 1, beamSize);
XTensor paddingBeam = Unsqueeze(*padding, padding->order - 1, beamSize);
encodingBeam.ReshapeMerged(encodingBeam.order - 4);
inputBeam.ReshapeMerged(inputBeam.order - 4);
paddingBeam.ReshapeMerged(paddingBeam.order - 4);
inputBeam.ReshapeMerged(inputBeam.order - 3);
paddingBeam.ReshapeMerged(paddingBeam.order - 3);
/* max output-length = 2 * source-length */
maxLength = input->GetDim(-1) * 2;
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论