Commit 6bc6a058 by huchi

update xtensor function call

parent c177e3c5
......@@ -225,7 +225,7 @@ void T2TModel::MakeMTMask(XTensor& inputEnc, XTensor& inputDec,
dims[i + 1] = inputDec.GetDim(i);
dims[0] = nhead;
dims[inputDec.order + 1] = len;
InitTensor(&maskDec, inputDec.order + 2, dims, X_FLOAT, paddingDec.devID);
InitTensorV2(&maskDec, inputDec.order + 2, dims, X_FLOAT, 1.0F, paddingDec.devID);
/* an upper triangular matrix where the cells of the upper triangular are set to -1e-9.
this matrix can be used to prevent the attention to current or following words in
......@@ -235,10 +235,10 @@ void T2TModel::MakeMTMask(XTensor& inputEnc, XTensor& inputDec,
/* encoder-decoder mask that prevents the attention to padding dummy words */
dims[inputDec.order + 1] = inputEnc.GetDim(inputEnc.order - 1);
InitTensor(&maskEncDec, inputDec.order + 2, dims, X_FLOAT, 1.0F, paddingEnc.devID);
InitTensorV2(&maskEncDec, inputDec.order + 2, dims, X_FLOAT, 1.0F, paddingEnc.devID);
XTensor * maskEncDecTMPEnc = NewTensorBuf(paddingEnc.order + 1, dims + 1, paddingEnc.dataType, paddingEnc.devID);
XTensor * maskEncDecTMPDec = NewTensorBuf(maskEncDecTMPEnc, paddingEnc.devID);
XTensor * maskEncDecTMPEnc = NewTensorBufV2(paddingEnc.order + 1, dims + 1, paddingEnc.dataType, 1.0F, paddingEnc.devID);
XTensor * maskEncDecTMPDec = NewTensorBufV2(paddingEnc.order + 1, dims + 1, paddingEnc.dataType, 1.0F, paddingEnc.devID);
_Unsqueeze(&paddingEnc, maskEncDecTMPEnc, paddingEnc.order - 1, paddingDec.GetDim(-1));
_ScaleAndShiftMe(maskEncDecTMPEnc, 1e9F, -1e9F);
......@@ -254,13 +254,13 @@ void T2TModel::MakeMTMask(XTensor& inputEnc, XTensor& inputDec,
dimsPadding[paddingEnc.order - 1] = paddingEnc.GetDim(-1);
dimsPadding[paddingEnc.order] = paddingEnc.GetDim(-1);
XTensor * padding2 = NewTensorBuf(paddingEnc.order + 1, dimsPadding, paddingEnc.dataType, paddingEnc.devID);
XTensor * padding2 = NewTensorBufV2(paddingEnc.order + 1, dimsPadding, paddingEnc.dataType, 1.0F, paddingEnc.devID);
for (int i = 0; i < padding2->order; i++)
dimsPadding[i + 1] = padding2->GetDim(i);
dimsPadding[0] = nhead;
XTensor * padding3 = NewTensorBuf(paddingEnc.order + 2, dimsPadding, paddingEnc.dataType, paddingEnc.devID);
XTensor * padding3 = NewTensorBufV2(paddingEnc.order + 2, dimsPadding, paddingEnc.dataType, 1.0F, paddingEnc.devID);
/* mask of the padding */
_Unsqueeze(&paddingEnc, padding2, paddingEnc.order - 1, paddingEnc.GetDim(-1));
......@@ -296,13 +296,13 @@ void T2TModel::MakeMTMaskEnc(XTensor& inputEnc, XTensor& paddingEnc, XTensor& ma
dimsPadding[paddingEnc.order - 1] = paddingEnc.GetDim(-1);
dimsPadding[paddingEnc.order] = paddingEnc.GetDim(-1);
XTensor * padding2 = NewTensorBuf(paddingEnc.order + 1, dimsPadding, paddingEnc.dataType, paddingEnc.devID);
XTensor * padding2 = NewTensorBufV2(paddingEnc.order + 1, dimsPadding, paddingEnc.dataType, 1.0F, paddingEnc.devID);
for (int i = 0; i < padding2->order; i++)
dimsPadding[i + 1] = padding2->GetDim(i);
dimsPadding[0] = nhead;
XTensor* padding3 = NewTensorBuf(paddingEnc.order + 2, dimsPadding, paddingEnc.dataType, paddingEnc.devID);
XTensor* padding3 = NewTensorBufV2(paddingEnc.order + 2, dimsPadding, paddingEnc.dataType, 1.0F, paddingEnc.devID);
/* mask of the padding */
_Unsqueeze(&paddingEnc, padding2, paddingEnc.order - 1, paddingEnc.GetDim(-1));
......@@ -340,7 +340,7 @@ void T2TModel::MakeMTMaskDec(XTensor &inputEnc, XTensor &inputDec,
dims[i + 1] = inputDec.GetDim(i);
dims[0] = nhead;
dims[inputDec.order + 1] = len;
InitTensor(&maskDec, inputDec.order + 2, dims, X_FLOAT, paddingDec.devID);
InitTensorV2(&maskDec, inputDec.order + 2, dims, X_FLOAT, 1.0F, paddingDec.devID);
/* An upper triangular matrix where the cells of the upper triangular are set to -1e-9.
This matrix can be used to block the attention to current or following words in
......@@ -355,11 +355,12 @@ void T2TModel::MakeMTMaskDec(XTensor &inputEnc, XTensor &inputDec,
/* encoder-decoder mask that prevents the attention to padding dummy words */
dims[inputDec.order + 1] = inputEnc.GetDim(inputEnc.order - 1);
InitTensor(&maskEncDec, inputDec.order + 2, dims, X_FLOAT, paddingEnc.devID);
InitTensorV2(&maskEncDec, inputDec.order + 2, dims, X_FLOAT, 1.0F, paddingEnc.devID);
XTensor * maskEncDecTMPEnc = NewTensorBuf(paddingEnc.order + 1, dims + 1, paddingEnc.dataType,
paddingEnc.devID);
XTensor * maskEncDecTMPDec = NewTensorBuf(maskEncDecTMPEnc, paddingEnc.devID);
XTensor * maskEncDecTMPEnc = NewTensorBufV2(paddingEnc.order + 1, dims + 1,
paddingEnc.dataType, 1.0F, paddingEnc.devID);
XTensor * maskEncDecTMPDec = NewTensorBufV2(paddingEnc.order + 1, dims + 1,
paddingEnc.dataType, 1.0F, paddingEnc.devID);
_Unsqueeze(&paddingEnc, maskEncDecTMPEnc, paddingEnc.order - 1, paddingDec.GetDim(-1));
......
......@@ -294,14 +294,14 @@ void T2TSearch::Generate(T2TStateBundle* beam)
InitTensorV2(&indexCPU, order, dimsTopK, X_INT, 1.0F, -1);
/* TODO: check the mask - mask the first and the padding id */
/*int dimMask[]{ score.GetDim(-1) };
int dimMask[]{ score.GetDim(-1) };
XTensor mask;
InitTensorV2(&mask, 1, dimMask, X_FLOAT, 1.0F, -1);
mask.SetZeroAll();
mask.Set1D(-1e9F, 0);
mask.Set1D(-1e9F, 1);
mask.SetDevice(score.devID);
_SumDim(&score, &mask, 2);*/
_SumDim(&score, &mask, 2);
score.Reshape(order, dimsBeam);
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论