Commit b6c077a1 by xiaotong

new code of inputDec tensor

parent 36c80fc7
......@@ -362,9 +362,9 @@ void T2TModel::MakeMTMaskDec(XTensor &inputEnc, XTensor &inputDec,
dims[inputDec.order + 1] = len;
InitTensor(&maskDec, inputDec.order + 2, dims, X_FLOAT, 1.0F, paddingDec.devID, paddingDec.mem);
/* an upper triangular matrix where the cells of the upper triangular are set to -1e-9.
this matrix can be used to prevent the attention to current or following words in
a given sequence. */
/* An upper triangular matrix where the cells of the upper triangular are set to -1e-9.
This matrix can be used to block the attention to current or following words in
a given sequence. */
_SetDataLowTri(&maskDec, 1e9F, 0);
_ScaleAndShiftMe(&maskDec, 1.0F, -1e9F);
......
......@@ -118,6 +118,8 @@ predict the next state
*/
void T2TPredictor::Predict(T2TStateBundle * next, XTensor * encoding, XTensor * inputEnc, XTensor * paddingEnc)
{
int dims[MAX_TENSOR_DIM_NUM];
next->layersEnc.Clear();
next->layersDec.Clear();
......@@ -128,10 +130,25 @@ void T2TPredictor::Predict(T2TStateBundle * next, XTensor * encoding, XTensor *
/* word indices of positions up to next state */
XTensor &inputDec = *NewTensor();
/* a dummy word that used to as a placeholder when we process the next work */
XTensor dummy;
for(int i = 0; i < inputEnc->order - 1; i++)
dims[i] = inputEnc->GetDim(i);
dims[inputEnc->order - 1] = 1;
InitTensor(&dummy, inputEnc->order, dims, X_INT, 1.0F, inputEnc->devID, inputEnc->mem);
dummy.SetZeroAll();
/* add a new word into the input sequence of the decoder side */
if(inputLast == NULL)
inputDec = s->prediction;
else
inputDec = Concatenate(*inputLast, s->prediction, inputLast->GetDim(-1));
inputDec = Identity(dummy);
else{
XTensor inputDecSlide = SelectRange(*inputLast, inputLast->order - 1, 0, inputLast->GetDim(-1) - 2);
inputDec = Concatenate(inputDecSlide, dummy, inputDecSlide.order - 1);
}
/* prediction probabilities */
XTensor &output = next->prob;
......@@ -156,7 +173,6 @@ void T2TPredictor::Predict(T2TStateBundle * next, XTensor * encoding, XTensor *
CheckNTErrors(decoding.order >= 2, "The tensor must be of order 2 or larger!");
int dims[MAX_TENSOR_DIM_NUM];
for(int i = 0; i < decoding.order - 1; i++)
dims[i] = decoding.GetDim(i);
dims[decoding.order - 2] = 1;
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论