Commit 3e7f7645 by xiaotong

generate masks for t2t mt models

parent 21892dbf
@@ -60,7 +60,7 @@ void AttDecoder::InitModel(int argc, char ** argv,

     /* initialize the stacked layers */
     for(int i = 0; i < nlayer; i++){
-        attentionsEnde[i].InitModel(argc, argv, false, myIgnored, myDevID, myMem);
+        attentionsEnde[i].InitModel(argc, argv, myIsMasked, myIgnored, myDevID, myMem);
         attEndeLayerNorms[i].InitModel(argc, argv, myDevID, myMem);
     }
 }

@@ -69,11 +69,12 @@ void AttDecoder::InitModel(int argc, char ** argv,
 make the decoding network
 >> inputDec - the input tensor of the decoder
 >> outputEnc - the output tensor of the encoder
->> mask - the mask that indicate each position is valid
+>> mask - mask that indicates which position is valid
+>> mask - mask for the encoder-decoder attention
 >> isTraining - indicates whether the model is used for training
 << return - the output tensor of the encoder
 */
-XTensor AttDecoder::Make(XTensor &inputDec, XTensor &outputEnc, XTensor &mask, bool isTraining)
+XTensor AttDecoder::Make(XTensor &inputDec, XTensor &outputEnc, XTensor &mask, XTensor &maskEncDec, bool isTraining)
 {
     XTensor x;

@@ -89,7 +90,6 @@ XTensor AttDecoder::Make(XTensor &inputDec, XTensor &outputEnc, XTensor &mask, b
     XTensor ln;
     XTensor fnn;
     XTensor res;
-    XTensor nothing;

     /******************/
     /* self attention */

@@ -107,7 +107,7 @@ XTensor AttDecoder::Make(XTensor &inputDec, XTensor &outputEnc, XTensor &mask, b
         /*****************************/
         /* encoder-decoder attention */
-        ende = attentionsEnde[i].Make(outputEnc, x, outputEnc, nothing, isTraining);
+        ende = attentionsEnde[i].Make(outputEnc, x, outputEnc, maskEncDec, isTraining);

         /* dropout */
         if(isTraining && dropoutP > 0)
......
@@ -48,7 +48,7 @@ public:
                    int myDevID = -1, XMem * myMem = NULL);

     /* make the decoding network */
-    XTensor Make(XTensor &inputDec, XTensor &outputEnc, XTensor &mask, bool isTraining);
+    XTensor Make(XTensor &inputDec, XTensor &outputEnc, XTensor &mask, XTensor &maskEncDec, bool isTraining);
 };

 }
......
@@ -93,10 +93,11 @@ void AttEncoder::InitModel(int argc, char ** argv,
 make the encoding network
 >> input - the input tensor of the encoder
 >> mask - the mask that indicate each position is valid
+>> maskEncDec - no use
 >> isTraining - indicates whether the model is used for training
 << return - the output tensor of the encoder
 */
-XTensor AttEncoder::Make(XTensor &input, XTensor &mask, bool isTraining)
+XTensor AttEncoder::Make(XTensor &input, XTensor &mask, XTensor &maskEncDec, bool isTraining)
 {
     XTensor x;

@@ -144,4 +145,18 @@ XTensor AttEncoder::Make(XTensor &input, XTensor &mask, bool isTraining)
     return x;
 }

+/*
+make the encoding network (wrapper)
+>> input - the input tensor of the encoder
+>> mask - the mask that indicate each position is valid
+>> isTraining - indicates whether the model is used for training
+<< return - the output tensor of the encoder
+*/
+XTensor AttEncoder::Make(XTensor &input, XTensor &mask, bool isTraining)
+{
+    XTensor nothing;
+
+    return Make(input, mask, nothing, isTraining);
+}
+
 }
@@ -40,7 +40,7 @@ class T2TEncoder
 {
 public:
     virtual
-    XTensor Make(XTensor &input, XTensor &mask, bool isTraining) = 0;
+    XTensor Make(XTensor &input, XTensor &mask, XTensor &mask2, bool isTraining) = 0;
 };

 /*

@@ -49,7 +49,7 @@ the encoder based on RNN
 class RNNEncoder : T2TEncoder
 {
 public:
-    XTensor Make(XTensor &input, XTensor &mask, bool isTraining);
+    XTensor Make(XTensor &input, XTensor &mask, XTensor &mask2, bool isTraining);
 };

@@ -118,6 +118,9 @@ public:
                    int myDevID = -1, XMem * myMem = NULL);

     /* make the encoding network */
+    XTensor Make(XTensor &input, XTensor &mask, XTensor &maskEncDec, bool isTraining);
+
+    /* make the encoding network (wrapper) */
     XTensor Make(XTensor &input, XTensor &mask, bool isTraining);
 };
......
@@ -75,7 +75,7 @@ void T2TModel::InitModel(int argc, char ** argv)
         mem->SetDesiredSize(devID, 0, (MTYPE)memSize * MILLION);
     }

-    encoder->InitModel(argc, argv, isLM, 0, devID, mem);
+    encoder->InitModel(argc, argv, true, 0, devID, mem);
     outputLayer->InitModel(argc, argv, devID, mem);

     if(isMT)

@@ -99,7 +99,9 @@ make the encoding network
 */
 XTensor T2TModel::MakeEncoder(XTensor &input, XTensor &mask, bool isTraining)
 {
-    return encoder->Make(input, mask, isTraining);
+    XTensor nothing;
+
+    return encoder->Make(input, mask, nothing, isTraining);
 }

 /*

@@ -107,13 +109,14 @@ make the decoding network
 >> inputDec - input tensor of the decoder
 >> outputEnc - output tensor of the encoder
 >> output - output tensor (distribution)
->> mask - the mask for positions that are/not involved in computation
+>> mask - mask for positions that are/not involved in computation
+>> maskEncDec - mask for the encoder-decoder attention
 >> isTraining - indicates whether we are training the model
 << return - encoding result
 */
-XTensor T2TModel::MakeDecoder(XTensor &inputDec, XTensor &outputEnc, XTensor &mask, bool isTraining)
+XTensor T2TModel::MakeDecoder(XTensor &inputDec, XTensor &outputEnc, XTensor &mask, XTensor &maskEncDec, bool isTraining)
 {
-    return decoder->Make(inputDec, outputEnc, mask, isTraining);
+    return decoder->Make(inputDec, outputEnc, mask, maskEncDec, isTraining);
 }

 /*

@@ -190,14 +193,16 @@ make the network for machine translation (with the output softmax layer)
 >> inputDec - input tensor of the decoder
 >> output - output tensor (distribution)
 >> paddingEnc - padding of the sequences (on the encoder side)
+>> paddingDec - padding of the sequences (on the decoder side)
 >> isTraining - indicates whether the model is for training
 */
-void T2TModel::MakeMT(XTensor &inputEnc, XTensor &inputDec, XTensor &output, XTensor &paddingEnc, bool isTraining)
+void T2TModel::MakeMT(XTensor &inputEnc, XTensor &inputDec, XTensor &output, XTensor &paddingEnc, XTensor &paddingDec, bool isTraining)
 {
     XTensor encoding;
     XTensor decoding;
     XTensor maskEnc;
     XTensor maskDec;
+    XTensor maskEncDec;

     /* generate mask to see "previous" words on the decoder side */
     //int len = inputDec.GetDim(inputDec.order - 2);

@@ -222,6 +227,23 @@ void T2TModel::MakeMT(XTensor &inputEnc, XTensor &inputDec, XTensor &output, XTe
     _SetDataLowTri(&maskDec, 1e9F, 0);
     _ScaleAndShiftMe(&maskDec, 1.0F, -1e9F);

+    /* encoder-decoder mask that prevent the attention to padding dummy words */
+    dims[inputDec.order + 1] = inputEnc.GetDim(inputEnc.order - 1);
+    InitTensor(&maskEncDec, inputDec.order + 2, dims, X_FLOAT, 1.0F, paddingEnc.devID, paddingEnc.mem);
+
+    XTensor * maskEncDecTMPEnc = NewTensorBuf(paddingEnc.order + 1, dims + 1, paddingEnc.dataType,
+                                              paddingEnc.denseRatio, paddingEnc.devID, paddingEnc.mem);
+    XTensor * maskEncDecTMPDec = NewTensorBuf(maskEncDecTMPEnc, paddingEnc.devID, paddingEnc.mem);
+
+    _Unsqueeze(&paddingEnc, maskEncDecTMPEnc, paddingEnc.order - 1, paddingDec.GetDim(-1));
+    _Unsqueeze(&paddingDec, maskEncDecTMPDec, paddingEnc.order, paddingEnc.GetDim(-1));
+    _Multiply(maskEncDecTMPDec, maskEncDecTMPEnc, maskEncDecTMPDec);
+    _ScaleAndShiftMe(maskEncDecTMPDec, 1e9F, -1e9F);
+    _Unsqueeze(maskEncDecTMPDec, &maskEncDec, 0, dims[0]);
+
+    DelTensorBuf(maskEncDecTMPDec);
+    DelTensorBuf(maskEncDecTMPEnc);
+
     /* padding on the source side */
     int * dimsPadding = new int[paddingEnc.order + 2];
     for (int i = 0; i < paddingEnc.order - 1; i++)

@@ -252,7 +274,7 @@ void T2TModel::MakeMT(XTensor &inputEnc, XTensor &inputDec, XTensor &output, XTe
     _Sum(&maskEnc, padding3, &maskEnc);

     encoding = MakeEncoder(inputEnc, maskEnc, isTraining);
-    decoding = MakeDecoder(inputDec, encoding, maskDec, isTraining);
+    decoding = MakeDecoder(inputDec, encoding, maskDec, maskEncDec, isTraining);
     outputLayer->Make(decoding, output);

     delete[] dims;
......
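Note on the hunks above: the core of this commit is that T2TModel::MakeMT now builds an extra tensor, maskEncDec, from the source and target padding tensors and hands it down to the encoder-decoder attention. As a reading aid only, the following is a minimal, self-contained C++ sketch (plain arrays and made-up toy padding vectors, not the NiuTensor API) of the value pattern produced by the _Unsqueeze / _Multiply / _ScaleAndShiftMe sequence: 0 where both the decoder and the encoder positions hold real tokens, and -1e9 where either side is padding, so that adding the mask to the attention scores before the softmax effectively hides padded source words.

#include <cstdio>
#include <vector>

int main()
{
    /* 1 = real token, 0 = padding (assumed toy sequences) */
    std::vector<float> paddingEnc = {1, 1, 1, 0, 0};   /* source length 5 */
    std::vector<float> paddingDec = {1, 1, 0};         /* target length 3 */

    const float big = 1e9F;
    std::vector<std::vector<float>> maskEncDec(paddingDec.size(),
                                               std::vector<float>(paddingEnc.size()));

    /* maskEncDec[i][j] = paddingDec[i] * paddingEnc[j] * 1e9 - 1e9:
       0 when both positions hold real tokens, -1e9 otherwise */
    for (size_t i = 0; i < paddingDec.size(); i++)
        for (size_t j = 0; j < paddingEnc.size(); j++)
            maskEncDec[i][j] = paddingDec[i] * paddingEnc[j] * big - big;

    /* print the mask for one sentence pair */
    for (size_t i = 0; i < paddingDec.size(); i++) {
        for (size_t j = 0; j < paddingEnc.size(); j++)
            printf("%12.1f ", maskEncDec[i][j]);
        printf("\n");
    }
    return 0;
}

In the actual code the last _Unsqueeze additionally broadcasts this per-sentence matrix over the attention heads (dimension 0), which the sketch omits.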
@@ -72,13 +72,13 @@ public:
     XTensor MakeEncoder(XTensor &input, XTensor &mask, bool isTraining);

     /* make the encoding network */
-    XTensor MakeDecoder(XTensor &inputEnc, XTensor &inputDec, XTensor &mask, bool isTraining);
+    XTensor MakeDecoder(XTensor &inputEnc, XTensor &inputDec, XTensor &mask, XTensor &MaskEncDec, bool isTraining);

     /* make the network for langauge modeling (with the output softmax layer) */
     void MakeLM(XTensor &input, XTensor &output, XTensor &padding, bool isTraining);

     /* make the network for machine translation (with the output softmax layer) */
-    void MakeMT(XTensor &inputEnc, XTensor &inputDec, XTensor &output, XTensor &paddingEnc, bool isTraining);
+    void MakeMT(XTensor &inputEnc, XTensor &inputDec, XTensor &output, XTensor &paddingEnc, XTensor &paddingDec, bool isTraining);

     /* get parameter matrics */
     void GetParams(XList &list);
......
@@ -208,7 +208,7 @@ void T2TTrainer::Train(const char * fn, const char * validFN, const char * model
             if(model->isLM)
                 model->MakeLM(batchEnc, output, paddingEnc, true);
             else if(model->isMT)
-                model->MakeMT(batchEnc, batchDec, output, paddingEnc, true);
+                model->MakeMT(batchEnc, batchDec, output, paddingEnc, paddingDec, true);
             else{
                 ShowNTErrors("Illegal model type!");
             }

@@ -358,7 +358,7 @@ void T2TTrainer::Test(const char * fn, const char * ofn, T2TModel * model)
             if(model->isLM)
                 model->MakeLM(batchEnc, output, paddingEnc, false);
             else if(model->isMT)
-                model->MakeMT(batchEnc, batchDec, output, paddingEnc, false);
+                model->MakeMT(batchEnc, batchDec, output, paddingEnc, paddingDec, false);
             else{
                 ShowNTErrors("Illegal model type!");
             }
......