Commit d5abb971 by xiaotong

reshape tensors to fit into SumDim and etc.

parent 75385ebe
......@@ -190,18 +190,13 @@ void T2TPredictor::Predict(T2TStateBundle * next, XTensor * encoding, XTensor *
CheckNTErrors(decoding.order >= 2, "The tensor must be of order 2 or larger!");
for(int i = 0; i < decoding.order - 1; i++)
dims[i] = decoding.GetDim(i);
dims[decoding.order - 2] = 1;
int stride = decoding.GetDim(decoding.order - 2);
InitTensor(&selectSrc, decoding.order - 1, dims, X_INT);
InitTensor(&selectTgt, decoding.order - 1, dims, X_INT);
InitTensor1D(&selectSrc, 1, X_INT);
InitTensor1D(&selectTgt, 1, X_INT);
int stride = decoding.GetDim(decoding.order - 2);
for(int i = 0; i < selectSrc.unitNum; i++){
selectSrc.SetInt(i * stride + stride - 1, i);
selectTgt.SetInt(i, i);
}
selectSrc.SetInt(stride - 1, 0);
selectTgt.SetInt(0, 0);
selectSrc.SetDevice(decoding.devID, decoding.mem);
selectTgt.SetDevice(decoding.devID, decoding.mem);
......
......@@ -142,10 +142,20 @@ void T2TSearch::Score(T2TStateBundle * prev, T2TStateBundle * beam)
XTensor lp;
XTensor mask;
int order = prob.order;
int outputSize = prob.GetDim(-1);
int dims[MAX_TENSOR_DIM_NUM];
for(int i = 0; i < order; i++)
dims[i] = prob.GetDim(i);
InitTensor(&score, &prob);
prob.Reshape(prob.unitNum/outputSize, outputSize);
score.Reshape(score.unitNum/outputSize, outputSize);
probPathPrev.Reshape(probPathPrev.unitNum);
/* the log-scale probability of the entire sequence */
_Sum(&prob, &probPathPrev, &score);
_SumDim(&prob, &probPathPrev, &score, 0);
InitTensor(&len, &lenPrev);
InitTensor(&lp, &lenPrev);
......@@ -155,8 +165,15 @@ void T2TSearch::Score(T2TStateBundle * prev, T2TStateBundle * beam)
/* the GNMT-like length penalty */
lp = T2TLengthPenalizer::GNMT(len, alpha);
lp.Reshape(lp.unitNum);
/* score = log-prob/lp */
_Div(&score, &lp, &score);
_DivDim(&score, &lp, &score, 0);
prob.Reshape(order, dims);
score.Reshape(order, dims);
probPathPrev.Reshape(order - 1, dims);
lp.Reshape(order - 1, dims);
InitTensor(&mask, &prev->endMark);
CopyValues(prev->endMark, mask);
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论