Commit 599ed44c by xiaotong

better implementation of mask

parent d061d183
@@ -204,7 +204,11 @@ void T2TModel::MakeMT(XTensor &inputEnc, XTensor &inputDec, XTensor &output, XTe
XTensor maskDec;
XTensor maskEncDec;
-MakeMTMask(inputEnc, inputDec, paddingEnc, paddingDec, maskEnc, maskDec, maskEncDec);
/* encoder mask */
MakeMTMaskEnc(inputEnc, paddingEnc, maskEnc);
/* decoder mask */
MakeMTMaskDec(inputEnc, inputDec, paddingEnc, paddingDec, maskDec, maskEncDec);
encoding = MakeEncoder(inputEnc, maskEnc, isTraining);
@@ -240,13 +244,13 @@ void T2TModel::MakeMTMask(XTensor &inputEnc, XTensor &inputDec,
dims[inputDec.order + 1] = len;
InitTensor(&maskDec, inputDec.order + 2, dims, X_FLOAT, 1.0F, paddingDec.devID, paddingDec.mem);
-/* a upper triangular matrix where the cells of the upper triangular are set to -1e-9.
/* an upper triangular matrix where the cells of the upper triangular are set to -1e-9.
this matrix can be used to prevent the attention to current or following words in
a given sequence. */
_SetDataLowTri(&maskDec, 1e9F, 0);
_ScaleAndShiftMe(&maskDec, 1.0F, -1e9F);
-/* encoder-decoder mask that prevent the attention to padding dummy words */
/* encoder-decoder mask that prevents the attention to padding dummy words */
dims[inputDec.order + 1] = inputEnc.GetDim(inputEnc.order - 1);
InitTensor(&maskEncDec, inputDec.order + 2, dims, X_FLOAT, 1.0F, paddingEnc.devID, paddingEnc.mem);
@@ -296,7 +300,92 @@ void T2TModel::MakeMTMask(XTensor &inputEnc, XTensor &inputDec,
DelTensorBuf(padding3);
DelTensorBuf(padding2);
}
/*
make the mask of the encoder
>> inputEnc - input of the encoder
>> paddingEnc - padding of the encoder input
>> maskEnc - mask of the encoder self-attention
*/
void T2TModel::MakeMTMaskEnc(XTensor &inputEnc, XTensor &paddingEnc, XTensor &maskEnc)
{
/* padding on the source side */
int * dimsPadding = new int[paddingEnc.order + 2];
for (int i = 0; i < paddingEnc.order - 1; i++)
dimsPadding[i] = paddingEnc.GetDim(i);
dimsPadding[paddingEnc.order - 1] = paddingEnc.GetDim(-1);
dimsPadding[paddingEnc.order] = paddingEnc.GetDim(-1);
XTensor * padding2 = NewTensorBuf(paddingEnc.order + 1, dimsPadding, paddingEnc.dataType,
paddingEnc.denseRatio, paddingEnc.devID, paddingEnc.mem);
for (int i = 0; i < padding2->order; i++)
dimsPadding[i + 1] = padding2->GetDim(i);
dimsPadding[0] = nhead;
XTensor * padding3 = NewTensorBuf(paddingEnc.order + 2, dimsPadding, paddingEnc.dataType,
paddingEnc.denseRatio, paddingEnc.devID, paddingEnc.mem);
/* mask of the padding */
_Unsqueeze(&paddingEnc, padding2, paddingEnc.order - 1, paddingEnc.GetDim(-1));
_Unsqueeze(padding2, padding3, 0, nhead);
_ScaleAndShiftMe(padding3, 1e9F, -1e9F);
InitTensor(&maskEnc, padding3);
maskEnc.SetZeroAll();
/* generate the mask on the source language side (for padding) */
_Sum(&maskEnc, padding3, &maskEnc);
DelTensorBuf(padding3);
DelTensorBuf(padding2);
delete[] dimsPadding;
}
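For intuition, the mask that MakeMTMaskEnc produces is, for every head and every sentence in the batch, a len-by-len matrix whose column j is 0 when source position j holds a real token and -1e9 when it is padding, so padded keys contribute almost nothing after the softmax. Below is a minimal standalone sketch of that construction on plain arrays; it does not use XTensor, and the 0/1 padding vector is a made-up example.

#include <cstdio>
#include <vector>

/* Minimal sketch of the additive padding mask used for encoder self-attention.
   padding[j] == 1 marks a real token, 0 marks a padded position. The mask is
   identical for every query row: columns that correspond to padded keys get
   -1e9, mirroring _ScaleAndShiftMe(padding3, 1e9F, -1e9F) above. */
int main()
{
    std::vector<float> padding = {1.0F, 1.0F, 1.0F, 0.0F, 0.0F};   /* made-up source with 2 pads */
    int len = (int)padding.size();

    std::vector<std::vector<float>> mask(len, std::vector<float>(len));
    for (int i = 0; i < len; i++)          /* query position */
        for (int j = 0; j < len; j++)      /* key position */
            mask[i][j] = padding[j] * 1e9F - 1e9F;

    for (int i = 0; i < len; i++) {
        for (int j = 0; j < len; j++)
            printf("%12.1f", mask[i][j]);
        printf("\n");
    }
    return 0;
}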
/*
make the mask of the decoder
>> inputEnc - input of the encoder
>> inputDec - input of the decoder
>> paddingEnc - padding of the encoder input
>> paddingDec - padding of the decoder input
>> maskDec - mask of the decoder self-attention
>> maskEncDec - mask of the decoder enc-dec attention
*/
void T2TModel::MakeMTMaskDec(XTensor &inputEnc, XTensor &inputDec,
XTensor &paddingEnc, XTensor &paddingDec,
XTensor &maskDec, XTensor &maskEncDec)
{
int len = inputDec.GetDim(inputDec.order - 1);
int * dims = new int[inputDec.order + 2];
for(int i = 0; i < inputDec.order; i++)
dims[i + 1] = inputDec.GetDim(i);
dims[0] = nhead;
dims[inputDec.order + 1] = len;
InitTensor(&maskDec, inputDec.order + 2, dims, X_FLOAT, 1.0F, paddingDec.devID, paddingDec.mem);
/* an upper triangular matrix where the cells above the diagonal are set to -1e9.
this matrix can be used to prevent a position from attending to the words that
follow it in a given sequence. */
_SetDataLowTri(&maskDec, 1e9F, 0);
_ScaleAndShiftMe(&maskDec, 1.0F, -1e9F);
/* encoder-decoder mask that prevents the attention to padding dummy words */
dims[inputDec.order + 1] = inputEnc.GetDim(inputEnc.order - 1);
InitTensor(&maskEncDec, inputDec.order + 2, dims, X_FLOAT, 1.0F, paddingEnc.devID, paddingEnc.mem);
XTensor * maskEncDecTMPEnc = NewTensorBuf(paddingEnc.order + 1, dims + 1, paddingEnc.dataType,
paddingEnc.denseRatio, paddingEnc.devID, paddingEnc.mem);
XTensor * maskEncDecTMPDec = NewTensorBuf(maskEncDecTMPEnc, paddingEnc.devID, paddingEnc.mem);
_Unsqueeze(&paddingEnc, maskEncDecTMPEnc, paddingEnc.order - 1, paddingDec.GetDim(-1));
_ScaleAndShiftMe(maskEncDecTMPEnc, 1e9F, -1e9F);
_Unsqueeze(maskEncDecTMPEnc, &maskEncDec, 0, dims[0]);
DelTensorBuf(maskEncDecTMPDec);
DelTensorBuf(maskEncDecTMPEnc);
delete[] dims;
}
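The decoder self-attention mask built above is the standard causal mask: after _SetDataLowTri and _ScaleAndShiftMe, cells on and below the diagonal hold 0 and cells above it hold -1e9, so position i can only attend to positions j <= i; the enc-dec mask is again the source-side padding broadcast over the target positions. A minimal standalone sketch of the causal part on plain arrays (illustrative only, not using XTensor):

#include <cstdio>

/* Sketch of the causal mask for decoder self-attention: set the lower triangle
   (diagonal included) to 1e9, then subtract 1e9 everywhere, which leaves 0 for
   visible positions (j <= i) and -1e9 for future positions (j > i). */
int main()
{
    const int len = 5;                     /* made-up target length */
    float mask[len][len];

    for (int i = 0; i < len; i++)
        for (int j = 0; j < len; j++)
            mask[i][j] = (j <= i ? 1e9F : 0.0F) - 1e9F;

    for (int i = 0; i < len; i++) {
        for (int j = 0; j < len; j++)
            printf("%12.1f", mask[i][j]);
        printf("\n");
    }
    return 0;
}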
/*
get parameter matrices
>> list - the list that keeps the parameter matrices
...
@@ -94,6 +94,14 @@ public:
void MakeMTMask(XTensor &inputEnc, XTensor &inputDec,
XTensor &paddingEnc, XTensor &paddingDec,
XTensor &maskEnc, XTensor &maskDec, XTensor &maskEncDec);
/* make the mask of the encoder */
void MakeMTMaskEnc(XTensor &inputEnc, XTensor &paddingEnc, XTensor &maskEnc);
/* make the mask of the decoder */
void MakeMTMaskDec(XTensor &inputEnc, XTensor &inputDec,
XTensor &paddingEnc, XTensor &paddingDec,
XTensor &maskDec, XTensor &maskEncDec);
/* get parameter matrices */
void GetParams(XList &list);
...
@@ -30,15 +30,19 @@ namespace transformer
search for the most promising states
>> model - the transformer model
>> input - input of the model
>> padding - padding of the input
>> output - output that represents the sequences as rows
*/
-void T2TSearch::Search(T2TModel * model, XTensor * input, XTensor * output)
void T2TSearch::Search(T2TModel * model, XTensor * input, XTensor * padding, XTensor * output)
{
-XTensor maskNULL;
XTensor maskEnc;
XTensor encoding;
/* encoder mask */
model->MakeMTMaskEnc(*input, *padding, maskEnc);
/* make the encoding network */
-encoding = model->MakeEncoder(*input, maskNULL, false);
encoding = model->MakeEncoder(*input, maskEnc, false);
encoding.SetName(ENCODING_NAME);
T2TPredictor predictor;
@@ -46,7 +50,6 @@ void T2TSearch::Search(T2TModel * model, XTensor * input, XTensor * output)
T2TStateBundle * cur = &state1;
T2TStateBundle * next = &state2;
/* initialize the predictor */
predictor.Init(model, &encoding, cur);
@@ -84,4 +87,4 @@ void T2TSearch::DumpOutput(T2TStateBundle * beam, XTensor * output)
{
}
}
\ No newline at end of file
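Callers of T2TSearch::Search now have to pass the source-side padding so that the encoder mask can be built at decoding time as well, instead of the former empty maskNULL. A hypothetical call-site fragment, assuming the transformer headers are included and the tensors have been prepared elsewhere (the function and variable names here are illustrative, not from the repository):

/* Hypothetical usage of the new interface: translate one batch.
   Assumes T2TSearch is default-constructible and that model, input,
   padding and output have been built by the surrounding code. */
void TranslateBatch(T2TModel &model, XTensor &input, XTensor &padding, XTensor &output)
{
    T2TSearch searcher;

    /* padding is now required so that MakeMTMaskEnc can mask
       padded source positions during beam search */
    searcher.Search(&model, &input, &padding, &output);
}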
@@ -28,7 +28,7 @@
namespace transformer
{
-/* The class orgnizes the search process. It calls predictors to generate
/* The class organizes the search process. It calls "predictors" to generate
distributions of the predictions and prunes the search space by beam pruning.
It results in a graph where each path represents a translation hypothesis.
The output can be the path with the highest model score. */
@@ -46,7 +46,7 @@ public:
~T2TSearch() {};
/* search for the most promising states */
-void Search(T2TModel * model, XTensor * input, XTensor * output);
void Search(T2TModel * model, XTensor * input, XTensor * padding, XTensor * output);
/* beam pruning */
void Prune(T2TStateBundle * beam);
@@ -57,4 +57,4 @@ public:
}
#endif
\ No newline at end of file
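The beam pruning mentioned in the class comment is, in its usual form, just keeping the beamSize highest-scoring hypotheses at each step. The body of T2TSearch::Prune is not part of this diff, so the following is only a generic, self-contained sketch of that selection step, not the class's actual implementation:

#include <algorithm>
#include <cstdio>
#include <vector>

/* Generic beam-pruning step: keep the beamSize hypotheses with the highest
   model scores and drop the rest. Purely illustrative; not the T2TSearch code. */
struct Hyp { float score; int id; };

std::vector<Hyp> PruneBeam(std::vector<Hyp> beam, size_t beamSize)
{
    std::sort(beam.begin(), beam.end(),
              [](const Hyp &a, const Hyp &b) { return a.score > b.score; });
    if (beam.size() > beamSize)
        beam.resize(beamSize);
    return beam;
}

int main()
{
    std::vector<Hyp> beam = {{-1.2F, 0}, {-0.3F, 1}, {-2.5F, 2}, {-0.9F, 3}};
    for (const Hyp &h : PruneBeam(beam, 2))
        printf("kept hypothesis %d with score %.2f\n", h.id, h.score);
    return 0;
}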