Commit f701be0e by xiaotong

improve the code of masking

parent 11cd04a3
...@@ -203,16 +203,30 @@ void T2TModel::MakeMT(XTensor &inputEnc, XTensor &inputDec, XTensor &output, XTe ...@@ -203,16 +203,30 @@ void T2TModel::MakeMT(XTensor &inputEnc, XTensor &inputDec, XTensor &output, XTe
XTensor maskEnc; XTensor maskEnc;
XTensor maskDec; XTensor maskDec;
XTensor maskEncDec; XTensor maskEncDec;
/* generate mask to see "previous" words on the decoder side */
//int len = inputDec.GetDim(inputDec.order - 2);
//int * dims = new int[inputDec.order + 1];
//for(int i = 0; i < inputDec.order; i++)
// dims[i + 1] = inputDec.GetDim(i);
//dims[0] = nhead;
//dims[inputDec.order] = len;
//InitTensor(&maskDec, inputDec.order + 1, dims, X_FLOAT, 1.0F, inputDec.devID, inputDec.mem);
MakeMTMask(inputEnc, inputDec, paddingEnc, paddingDec, maskEnc, maskDec, maskEncDec);
encoding = MakeEncoder(inputEnc, maskEnc, isTraining);
decoding = MakeDecoder(inputDec, encoding, maskDec, maskEncDec, isTraining);
outputLayer->Make(decoding, output);
}
/*
make the mask for training MT models
>> inputEnc - input of the encoder
>> inputDec - input of the decoder
>> paddingEnc - padding of the encoder input
>> paddingDec - padding of the decoder input
>> maskEnc - mask of the encoder self-attention
>> maksDec - mask of the decoder self-attention
>> maksEncDec - mask of the decoder enc-dec attention
*/
void T2TModel::MakeMTMask(XTensor &inputEnc, XTensor &inputDec,
XTensor &paddingEnc, XTensor &paddingDec,
XTensor &maskEnc, XTensor &maskDec, XTensor &maskEncDec)
{
int len = inputDec.GetDim(inputDec.order - 1); int len = inputDec.GetDim(inputDec.order - 1);
int * dims = new int[inputDec.order + 2]; int * dims = new int[inputDec.order + 2];
for(int i = 0; i < inputDec.order; i++) for(int i = 0; i < inputDec.order; i++)
...@@ -236,8 +250,6 @@ void T2TModel::MakeMT(XTensor &inputEnc, XTensor &inputDec, XTensor &output, XTe ...@@ -236,8 +250,6 @@ void T2TModel::MakeMT(XTensor &inputEnc, XTensor &inputDec, XTensor &output, XTe
XTensor * maskEncDecTMPDec = NewTensorBuf(maskEncDecTMPEnc, paddingEnc.devID, paddingEnc.mem); XTensor * maskEncDecTMPDec = NewTensorBuf(maskEncDecTMPEnc, paddingEnc.devID, paddingEnc.mem);
_Unsqueeze(&paddingEnc, maskEncDecTMPEnc, paddingEnc.order - 1, paddingDec.GetDim(-1)); _Unsqueeze(&paddingEnc, maskEncDecTMPEnc, paddingEnc.order - 1, paddingDec.GetDim(-1));
//_Unsqueeze(&paddingDec, maskEncDecTMPDec, paddingEnc.order, paddingEnc.GetDim(-1));
//_Multiply(maskEncDecTMPDec, maskEncDecTMPEnc, maskEncDecTMPDec);
_ScaleAndShiftMe(maskEncDecTMPEnc, 1e9F, -1e9F); _ScaleAndShiftMe(maskEncDecTMPEnc, 1e9F, -1e9F);
_Unsqueeze(maskEncDecTMPEnc, &maskEncDec, 0, dims[0]); _Unsqueeze(maskEncDecTMPEnc, &maskEncDec, 0, dims[0]);
...@@ -273,12 +285,6 @@ void T2TModel::MakeMT(XTensor &inputEnc, XTensor &inputDec, XTensor &output, XTe ...@@ -273,12 +285,6 @@ void T2TModel::MakeMT(XTensor &inputEnc, XTensor &inputDec, XTensor &output, XTe
/* generate the mask on the source language side (for padding) */ /* generate the mask on the source language side (for padding) */
_Sum(&maskEnc, padding3, &maskEnc); _Sum(&maskEnc, padding3, &maskEnc);
encoding = MakeEncoder(inputEnc, maskEnc, isTraining);
decoding = MakeDecoder(inputDec, encoding, maskDec, maskEncDec, isTraining);
outputLayer->Make(decoding, output);
delete[] dims; delete[] dims;
delete[] dimsPadding; delete[] dimsPadding;
......
...@@ -81,7 +81,13 @@ public: ...@@ -81,7 +81,13 @@ public:
void MakeLM(XTensor &input, XTensor &output, XTensor &padding, bool isTraining); void MakeLM(XTensor &input, XTensor &output, XTensor &padding, bool isTraining);
/* make the network for machine translation (with the output softmax layer) */ /* make the network for machine translation (with the output softmax layer) */
void MakeMT(XTensor &inputEnc, XTensor &inputDec, XTensor &output, XTensor &paddingEnc, XTensor &paddingDec, bool isTraining); void MakeMT(XTensor &inputEnc, XTensor &inputDec, XTensor &output,
XTensor &paddingEnc, XTensor &paddingDec, bool isTraining);
/* make the mask for training MT models */
void MakeMTMask(XTensor &inputEnc, XTensor &inputDec,
XTensor &paddingEnc, XTensor &paddingDec,
XTensor &maskEnc, XTensor &maskDec, XTensor &maskEncDec);
/* get parameter matrics */ /* get parameter matrics */
void GetParams(XList &list); void GetParams(XList &list);
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论