Commit 3e7f7645 by xiaotong

generate masks for t2t mt models

parent 21892dbf
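This commit threads attention masks through the Transformer (t2t) MT model: the decoder self-attention mask and a new encoder-decoder mask built from the source/target padding. All of the masks below follow the same additive convention: 0 for positions that may be attended to and -1e9 for positions that must be ignored, added to the attention scores before the softmax. A minimal, self-contained illustration of why that works (plain C++, not NiuTensor code; all names and values are invented):

#include <cmath>
#include <cstdio>

/* Illustrative only: shows why adding -1e9 to a score removes that position
   after the softmax. Independent of the XTensor API used in the diff below. */
int main()
{
    const float NEG_INF = -1e9F;
    float score[3] = {1.2F, 0.7F, 2.0F};    /* raw attention scores      */
    float mask[3]  = {0.0F, 0.0F, NEG_INF}; /* last position is padding  */

    float prob[3];
    float expsum = 0;
    for(int j = 0; j < 3; j++){
        prob[j] = std::exp(score[j] + mask[j]); /* masked score -> exp(-1e9) ~ 0 */
        expsum += prob[j];
    }
    for(int j = 0; j < 3; j++)
        printf("p[%d] = %.4f\n", j, prob[j] / expsum); /* p[2] is effectively 0 */
    return 0;
}

With the -1e9 offset, the masked position's exponential underflows to zero, so it receives no probability mass.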
@@ -60,7 +60,7 @@ void AttDecoder::InitModel(int argc, char ** argv,
     /* initialize the stacked layers */
     for(int i = 0; i < nlayer; i++){
-        attentionsEnde[i].InitModel(argc, argv, false, myIgnored, myDevID, myMem);
+        attentionsEnde[i].InitModel(argc, argv, myIsMasked, myIgnored, myDevID, myMem);
         attEndeLayerNorms[i].InitModel(argc, argv, myDevID, myMem);
     }
 }
@@ -69,11 +69,12 @@ void AttDecoder::InitModel(int argc, char ** argv,
 make the decoding network
 >> inputDec - the input tensor of the decoder
 >> outputEnc - the output tensor of the encoder
->> mask - the mask that indicate each position is valid
+>> mask - mask that indicates which position is valid
+>> maskEncDec - mask for the encoder-decoder attention
 >> isTraining - indicates whether the model is used for training
 << return - the output tensor of the decoder
 */
-XTensor AttDecoder::Make(XTensor &inputDec, XTensor &outputEnc, XTensor &mask, bool isTraining)
+XTensor AttDecoder::Make(XTensor &inputDec, XTensor &outputEnc, XTensor &mask, XTensor &maskEncDec, bool isTraining)
 {
     XTensor x;
@@ -89,7 +90,6 @@ XTensor AttDecoder::Make(XTensor &inputDec, XTensor &outputEnc, XTensor &mask, b
         XTensor ln;
         XTensor fnn;
         XTensor res;
-        XTensor nothing;
         /******************/
         /* self attention */
@@ -107,7 +107,7 @@ XTensor AttDecoder::Make(XTensor &inputDec, XTensor &outputEnc, XTensor &mask, b
         /*****************************/
         /* encoder-decoder attention */
-        ende = attentionsEnde[i].Make(outputEnc, x, outputEnc, nothing, isTraining);
+        ende = attentionsEnde[i].Make(outputEnc, x, outputEnc, maskEncDec, isTraining);
         /* dropout */
         if(isTraining && dropoutP > 0)
......
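Note on the hunk above: the encoder-decoder attention now receives the real maskEncDec instead of the removed placeholder `nothing`. Assuming the attention layer's Make follows the (key, query, value, mask, isTraining) order suggested by this call, outputEnc supplies keys and values while the decoder state x supplies the queries, so the mask has to score every decoder position against every encoder position. A toy, self-contained sketch of how such an additive mask enters the score computation (plain C++; q, k, padEnc and the sizes are invented for illustration):

#include <cmath>
#include <cstdio>

/* Toy masked cross-attention scores: 2 decoder positions, 3 encoder positions. */
int main()
{
    const int lenDec = 2, lenEnc = 3, d = 2;
    float q[lenDec][d] = {{1, 0}, {0, 1}};          /* decoder queries (from x)      */
    float k[lenEnc][d] = {{1, 1}, {0, 1}, {1, 0}};  /* encoder keys (from outputEnc) */
    float padEnc[lenEnc] = {1, 1, 0};               /* last source position is padding */

    for(int i = 0; i < lenDec; i++){
        for(int j = 0; j < lenEnc; j++){
            float score = 0;
            for(int a = 0; a < d; a++)
                score += q[i][a] * k[j][a];
            score /= std::sqrt((float)d);
            /* additive mask: 0 keeps the position, -1e9 removes it in the softmax */
            float mask = padEnc[j] * 1e9F - 1e9F;
            printf("score[%d][%d] = %g\n", i, j, score + mask);
        }
    }
    return 0;
}

Source positions marked as padding end up with scores near -1e9 and are effectively dropped by the subsequent softmax; the commit builds exactly this kind of mask in T2TModel::MakeMT further below.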
@@ -48,7 +48,7 @@ public:
                    int myDevID = -1, XMem * myMem = NULL);
     /* make the decoding network */
-    XTensor Make(XTensor &inputDec, XTensor &outputEnc, XTensor &mask, bool isTraining);
+    XTensor Make(XTensor &inputDec, XTensor &outputEnc, XTensor &mask, XTensor &maskEncDec, bool isTraining);
 };
 }
......
@@ -93,10 +93,11 @@ void AttEncoder::InitModel(int argc, char ** argv,
 make the encoding network
 >> input - the input tensor of the encoder
 >> mask - the mask that indicates which position is valid
+>> maskEncDec - not used by the encoder
 >> isTraining - indicates whether the model is used for training
 << return - the output tensor of the encoder
 */
-XTensor AttEncoder::Make(XTensor &input, XTensor &mask, bool isTraining)
+XTensor AttEncoder::Make(XTensor &input, XTensor &mask, XTensor &maskEncDec, bool isTraining)
 {
     XTensor x;
@@ -144,4 +145,18 @@ XTensor AttEncoder::Make(XTensor &input, XTensor &mask, bool isTraining)
     return x;
 }
+/*
+make the encoding network (wrapper)
+>> input - the input tensor of the encoder
+>> mask - the mask that indicates which position is valid
+>> isTraining - indicates whether the model is used for training
+<< return - the output tensor of the encoder
+*/
+XTensor AttEncoder::Make(XTensor &input, XTensor &mask, bool isTraining)
+{
+    XTensor nothing;
+    return Make(input, mask, nothing, isTraining);
+}
 }
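The added three-argument AttEncoder::Make keeps the existing language-model call sites working: it forwards to the four-argument version with an empty XTensor standing in for the unused encoder-decoder mask, mirroring what T2TModel::MakeEncoder does later in this commit. A standalone analogue of this overload-and-forward pattern (plain C++, not the NiuTensor classes; Tensor and Encoder are stand-ins):

#include <cstdio>

struct Tensor { bool defined = false; };   /* stand-in for XTensor */

struct Encoder {
    /* full interface: the extra mask may simply be ignored by the encoder */
    Tensor Make(Tensor &input, Tensor &mask, Tensor &maskEncDec, bool isTraining)
    {
        printf("maskEncDec %s\n", maskEncDec.defined ? "supplied" : "empty (LM path)");
        return input;
    }
    /* wrapper: old 3-argument call sites keep compiling and behave as before */
    Tensor Make(Tensor &input, Tensor &mask, bool isTraining)
    {
        Tensor nothing;                    /* plays the role of the dummy XTensor */
        return Make(input, mask, nothing, isTraining);
    }
};

int main()
{
    Encoder enc;
    Tensor input, mask, maskEncDec;
    maskEncDec.defined = true;
    enc.Make(input, mask, true);               /* LM / old call               */
    enc.Make(input, mask, maskEncDec, true);   /* MT call with the real mask  */
    return 0;
}

Old callers compile unchanged, while the MT path can pass the real mask explicitly.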
@@ -40,7 +40,7 @@ class T2TEncoder
 {
 public:
     virtual
-    XTensor Make(XTensor &input, XTensor &mask, bool isTraining) = 0;
+    XTensor Make(XTensor &input, XTensor &mask, XTensor &mask2, bool isTraining) = 0;
 };
 /*
@@ -49,7 +49,7 @@ the encoder based on RNN
 class RNNEncoder : T2TEncoder
 {
 public:
-    XTensor Make(XTensor &input, XTensor &mask, bool isTraining);
+    XTensor Make(XTensor &input, XTensor &mask, XTensor &mask2, bool isTraining);
 };
@@ -118,6 +118,9 @@ public:
                    int myDevID = -1, XMem * myMem = NULL);
     /* make the encoding network */
+    XTensor Make(XTensor &input, XTensor &mask, XTensor &maskEncDec, bool isTraining);
+
+    /* make the encoding network (wrapper) */
     XTensor Make(XTensor &input, XTensor &mask, bool isTraining);
 };
......
@@ -75,7 +75,7 @@ void T2TModel::InitModel(int argc, char ** argv)
         mem->SetDesiredSize(devID, 0, (MTYPE)memSize * MILLION);
     }
-    encoder->InitModel(argc, argv, isLM, 0, devID, mem);
+    encoder->InitModel(argc, argv, true, 0, devID, mem);
     outputLayer->InitModel(argc, argv, devID, mem);
     if(isMT)
@@ -99,7 +99,9 @@ make the encoding network
 */
 XTensor T2TModel::MakeEncoder(XTensor &input, XTensor &mask, bool isTraining)
 {
-    return encoder->Make(input, mask, isTraining);
+    XTensor nothing;
+
+    return encoder->Make(input, mask, nothing, isTraining);
 }
 /*
@@ -107,13 +109,14 @@ make the decoding network
 >> inputDec - input tensor of the decoder
 >> outputEnc - output tensor of the encoder
 >> output - output tensor (distribution)
->> mask - the mask for positions that are/not involved in computation
+>> mask - mask for positions that are/not involved in computation
+>> maskEncDec - mask for the encoder-decoder attention
 >> isTraining - indicates whether we are training the model
 << return - decoding result
 */
-XTensor T2TModel::MakeDecoder(XTensor &inputDec, XTensor &outputEnc, XTensor &mask, bool isTraining)
+XTensor T2TModel::MakeDecoder(XTensor &inputDec, XTensor &outputEnc, XTensor &mask, XTensor &maskEncDec, bool isTraining)
 {
-    return decoder->Make(inputDec, outputEnc, mask, isTraining);
+    return decoder->Make(inputDec, outputEnc, mask, maskEncDec, isTraining);
 }
 /*
@@ -190,14 +193,16 @@ make the network for machine translation (with the output softmax layer)
 >> inputDec - input tensor of the decoder
 >> output - output tensor (distribution)
 >> paddingEnc - padding of the sequences (on the encoder side)
+>> paddingDec - padding of the sequences (on the decoder side)
 >> isTraining - indicates whether the model is for training
 */
-void T2TModel::MakeMT(XTensor &inputEnc, XTensor &inputDec, XTensor &output, XTensor &paddingEnc, bool isTraining)
+void T2TModel::MakeMT(XTensor &inputEnc, XTensor &inputDec, XTensor &output, XTensor &paddingEnc, XTensor &paddingDec, bool isTraining)
 {
     XTensor encoding;
     XTensor decoding;
     XTensor maskEnc;
     XTensor maskDec;
+    XTensor maskEncDec;
     /* generate mask to see "previous" words on the decoder side */
     //int len = inputDec.GetDim(inputDec.order - 2);
@@ -222,6 +227,23 @@ void T2TModel::MakeMT(XTensor &inputEnc, XTensor &inputDec, XTensor &output, XTe
     _SetDataLowTri(&maskDec, 1e9F, 0);
     _ScaleAndShiftMe(&maskDec, 1.0F, -1e9F);
 
+    /* encoder-decoder mask that prevents attention to padding (dummy) words */
+    dims[inputDec.order + 1] = inputEnc.GetDim(inputEnc.order - 1);
+    InitTensor(&maskEncDec, inputDec.order + 2, dims, X_FLOAT, 1.0F, paddingEnc.devID, paddingEnc.mem);
+
+    XTensor * maskEncDecTMPEnc = NewTensorBuf(paddingEnc.order + 1, dims + 1, paddingEnc.dataType,
+                                              paddingEnc.denseRatio, paddingEnc.devID, paddingEnc.mem);
+    XTensor * maskEncDecTMPDec = NewTensorBuf(maskEncDecTMPEnc, paddingEnc.devID, paddingEnc.mem);
+
+    _Unsqueeze(&paddingEnc, maskEncDecTMPEnc, paddingEnc.order - 1, paddingDec.GetDim(-1));
+    _Unsqueeze(&paddingDec, maskEncDecTMPDec, paddingEnc.order, paddingEnc.GetDim(-1));
+    _Multiply(maskEncDecTMPDec, maskEncDecTMPEnc, maskEncDecTMPDec);
+    _ScaleAndShiftMe(maskEncDecTMPDec, 1e9F, -1e9F);
+    _Unsqueeze(maskEncDecTMPDec, &maskEncDec, 0, dims[0]);
+
+    DelTensorBuf(maskEncDecTMPDec);
+    DelTensorBuf(maskEncDecTMPEnc);
+
     /* padding on the source side */
     int * dimsPadding = new int[paddingEnc.order + 2];
     for (int i = 0; i < paddingEnc.order - 1; i++)
@@ -252,7 +274,7 @@ void T2TModel::MakeMT(XTensor &inputEnc, XTensor &inputDec, XTensor &output, XTe
     _Sum(&maskEnc, padding3, &maskEnc);
 
     encoding = MakeEncoder(inputEnc, maskEnc, isTraining);
-    decoding = MakeDecoder(inputDec, encoding, maskDec, isTraining);
+    decoding = MakeDecoder(inputDec, encoding, maskDec, maskEncDec, isTraining);
     outputLayer->Make(decoding, output);
 
     delete[] dims;
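The new block above is the core of the commit: maskEncDec is built from the two padding tensors and then consumed by MakeDecoder. Assuming paddingEnc and paddingDec hold 1 for real tokens and 0 for padded slots, the entry for decoder position i and encoder position j becomes padDec[i] * padEnc[j] * 1e9 - 1e9, i.e. 0 when both positions are real and -1e9 otherwise, and the final _Unsqueeze replicates this [batch, lenDec, lenEnc] mask along the leading dimension dims[0] (presumably the number of attention heads). The same convention explains maskDec two lines earlier, assuming _SetDataLowTri(&maskDec, 1e9F, 0) writes 1e9 on and below the diagonal: the follow-up _ScaleAndShiftMe(1.0F, -1e9F) turns that into 0 for visible positions and -1e9 for future ones. A self-contained numeric check of both (plain C++; the lengths and padding patterns are invented):

#include <cstdio>

int main()
{
    const int lenDec = 3, lenEnc = 4;
    float padDec[lenDec] = {1, 1, 0};       /* last target position is padding */
    float padEnc[lenEnc] = {1, 1, 1, 0};    /* last source position is padding */

    /* encoder-decoder mask: padDec[i] * padEnc[j] * 1e9 - 1e9 -> 0 or -1e9 */
    printf("maskEncDec:\n");
    for(int i = 0; i < lenDec; i++){
        for(int j = 0; j < lenEnc; j++)
            printf("%6.0e ", padDec[i] * padEnc[j] * 1e9F - 1e9F);
        printf("\n");
    }

    /* decoder self-attention mask: lower triangle set to 1e9, then * 1.0 - 1e9 */
    printf("maskDec:\n");
    for(int i = 0; i < lenDec; i++){
        for(int j = 0; j < lenDec; j++){
            float v = (j <= i) ? 1e9F : 0.0F;   /* _SetDataLowTri(&maskDec, 1e9F, 0)       */
            printf("%6.0e ", v * 1.0F - 1e9F);  /* _ScaleAndShiftMe(&maskDec, 1.0F, -1e9F) */
        }
        printf("\n");
    }
    return 0;
}

Running it prints 0 where attention is allowed and -1e+09 where it is blocked: the last source column, the last target row, and everything above the diagonal of maskDec.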
......
@@ -72,13 +72,13 @@ public:
     XTensor MakeEncoder(XTensor &input, XTensor &mask, bool isTraining);
 
     /* make the decoding network */
-    XTensor MakeDecoder(XTensor &inputEnc, XTensor &inputDec, XTensor &mask, bool isTraining);
+    XTensor MakeDecoder(XTensor &inputEnc, XTensor &inputDec, XTensor &mask, XTensor &MaskEncDec, bool isTraining);
 
     /* make the network for language modeling (with the output softmax layer) */
     void MakeLM(XTensor &input, XTensor &output, XTensor &padding, bool isTraining);
 
     /* make the network for machine translation (with the output softmax layer) */
-    void MakeMT(XTensor &inputEnc, XTensor &inputDec, XTensor &output, XTensor &paddingEnc, bool isTraining);
+    void MakeMT(XTensor &inputEnc, XTensor &inputDec, XTensor &output, XTensor &paddingEnc, XTensor &paddingDec, bool isTraining);
 
     /* get parameter matrices */
     void GetParams(XList &list);
......
@@ -208,7 +208,7 @@ void T2TTrainer::Train(const char * fn, const char * validFN, const char * model
             if(model->isLM)
                 model->MakeLM(batchEnc, output, paddingEnc, true);
             else if(model->isMT)
-                model->MakeMT(batchEnc, batchDec, output, paddingEnc, true);
+                model->MakeMT(batchEnc, batchDec, output, paddingEnc, paddingDec, true);
             else{
                 ShowNTErrors("Illegal model type!");
             }
@@ -358,7 +358,7 @@ void T2TTrainer::Test(const char * fn, const char * ofn, T2TModel * model)
             if(model->isLM)
                 model->MakeLM(batchEnc, output, paddingEnc, false);
             else if(model->isMT)
-                model->MakeMT(batchEnc, batchDec, output, paddingEnc, false);
+                model->MakeMT(batchEnc, batchDec, output, paddingEnc, paddingDec, false);
            else{
                 ShowNTErrors("Illegal model type!");
             }
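Both the training and the test loop now pass paddingDec through to MakeMT. The sketch below only illustrates what such a decoder-side padding tensor is assumed to contain for a batch of variable-length target sentences: 1 for real tokens and 0 for padded slots, matching the 0/1 convention the mask arithmetic above relies on (names and the batch layout are invented; the real tensors come from the trainer's batch loader):

#include <cstdio>

int main()
{
    const int batchSize = 2, maxLen = 5;
    int len[batchSize] = {5, 3};            /* true target lengths in the batch */
    float paddingDec[batchSize][maxLen];

    for(int b = 0; b < batchSize; b++)
        for(int t = 0; t < maxLen; t++)
            paddingDec[b][t] = (t < len[b]) ? 1.0F : 0.0F;  /* 1 = token, 0 = pad */

    for(int b = 0; b < batchSize; b++){
        for(int t = 0; t < maxLen; t++)
            printf("%.0f ", paddingDec[b][t]);
        printf("\n");
    }
    return 0;
}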
......