Commit 599ed44c by xiaotong

better implementation of mask

parent d061d183
......@@ -204,7 +204,11 @@ void T2TModel::MakeMT(XTensor &inputEnc, XTensor &inputDec, XTensor &output, XTe
XTensor maskDec;
XTensor maskEncDec;
MakeMTMask(inputEnc, inputDec, paddingEnc, paddingDec, maskEnc, maskDec, maskEncDec);
/* encoder mask */
MakeMTMaskEnc(inputEnc, paddingEnc, maskEnc);
/* decoder mask */
MakeMTMaskDec(inputEnc, inputDec, paddingEnc, paddingDec, maskDec, maskEncDec);
encoding = MakeEncoder(inputEnc, maskEnc, isTraining);
......@@ -240,13 +244,13 @@ void T2TModel::MakeMTMask(XTensor &inputEnc, XTensor &inputDec,
dims[inputDec.order + 1] = len;
InitTensor(&maskDec, inputDec.order + 2, dims, X_FLOAT, 1.0F, paddingDec.devID, paddingDec.mem);
/* a upper triangular matrix where the cells of the upper triangular are set to -1e-9.
/* an upper triangular matrix where the cells of the upper triangular are set to -1e-9.
this matrix can be used to prevent the attention to current or following words in
a given sequence. */
_SetDataLowTri(&maskDec, 1e9F, 0);
_ScaleAndShiftMe(&maskDec, 1.0F, -1e9F);
/* encoder-decoder mask that prevent the attention to padding dummy words */
/* encoder-decoder mask that prevents the attention to padding dummy words */
dims[inputDec.order + 1] = inputEnc.GetDim(inputEnc.order - 1);
InitTensor(&maskEncDec, inputDec.order + 2, dims, X_FLOAT, 1.0F, paddingEnc.devID, paddingEnc.mem);
......@@ -296,7 +300,92 @@ void T2TModel::MakeMTMask(XTensor &inputEnc, XTensor &inputDec,
DelTensorBuf(padding3);
DelTensorBuf(padding2);
}
/*
make the mask of the encoder
>> inputEnc - input of the encoder
>> paddingEnc - padding of the encoder input
>> maskEnc - mask of the encoder self-attention
*/
void T2TModel::MakeMTMaskEnc(XTensor &inputEnc, XTensor &paddingEnc, XTensor &maskEnc)
{
    /* dimension array reused for both intermediate padding tensors */
    int * dimsMask = new int[paddingEnc.order + 2];

    /* keep the leading dimensions and duplicate the last one, so the padding
       vector can be expanded into a (len x len) attention matrix */
    for(int i = 0; i < paddingEnc.order - 1; i++)
        dimsMask[i] = paddingEnc.GetDim(i);
    dimsMask[paddingEnc.order - 1] = paddingEnc.GetDim(-1);
    dimsMask[paddingEnc.order] = paddingEnc.GetDim(-1);

    XTensor * paddingSquare = NewTensorBuf(paddingEnc.order + 1, dimsMask, paddingEnc.dataType,
                                           paddingEnc.denseRatio, paddingEnc.devID, paddingEnc.mem);

    /* prepend a head dimension of size nhead */
    for(int i = 0; i < paddingSquare->order; i++)
        dimsMask[i + 1] = paddingSquare->GetDim(i);
    dimsMask[0] = nhead;

    XTensor * paddingPerHead = NewTensorBuf(paddingEnc.order + 2, dimsMask, paddingEnc.dataType,
                                            paddingEnc.denseRatio, paddingEnc.devID, paddingEnc.mem);

    /* mask of the padding: expand over the sequence, replicate over heads,
       then map {0,1} -> {-1e9,0} so padded positions are suppressed */
    _Unsqueeze(&paddingEnc, paddingSquare, paddingEnc.order - 1, paddingEnc.GetDim(-1));
    _Unsqueeze(paddingSquare, paddingPerHead, 0, nhead);
    _ScaleAndShiftMe(paddingPerHead, 1e9F, -1e9F);

    InitTensor(&maskEnc, paddingPerHead);
    maskEnc.SetZeroAll();

    /* generate the mask on the source language side (for padding) */
    _Sum(&maskEnc, paddingPerHead, &maskEnc);

    /* release buffers in reverse order of allocation */
    DelTensorBuf(paddingPerHead);
    DelTensorBuf(paddingSquare);

    delete[] dimsMask;
}
/*
make the mask of the decoder
>> inputEnc - input of the encoder
>> inputDec - input of the decoder
>> paddingEnc - padding of the encoder input
>> paddingDec - padding of the decoder input
>> maskDec - mask of the decoder self-attention
>> maskEncDec - mask of the decoder enc-dec attention
*/
void T2TModel::MakeMTMaskDec(XTensor &inputEnc, XTensor &inputDec,
                             XTensor &paddingEnc, XTensor &paddingDec,
                             XTensor &maskDec, XTensor &maskEncDec)
{
    int len = inputDec.GetDim(inputDec.order - 1);
    int * dims = new int[inputDec.order + 2];

    /* shape (nhead, ..., len, len) for the decoder self-attention mask */
    for(int i = 0; i < inputDec.order; i++)
        dims[i + 1] = inputDec.GetDim(i);
    dims[0] = nhead;
    dims[inputDec.order + 1] = len;
    InitTensor(&maskDec, inputDec.order + 2, dims, X_FLOAT, 1.0F, paddingDec.devID, paddingDec.mem);

    /* an upper triangular matrix where the cells of the upper triangular are set to -1e9.
       this matrix can be used to prevent the attention to current or following words in
       a given sequence. */
    _SetDataLowTri(&maskDec, 1e9F, 0);
    _ScaleAndShiftMe(&maskDec, 1.0F, -1e9F);

    /* encoder-decoder mask that prevents the attention to padding dummy words */
    dims[inputDec.order + 1] = inputEnc.GetDim(inputEnc.order - 1);
    InitTensor(&maskEncDec, inputDec.order + 2, dims, X_FLOAT, 1.0F, paddingEnc.devID, paddingEnc.mem);

    XTensor * maskEncDecTMPEnc = NewTensorBuf(paddingEnc.order + 1, dims + 1, paddingEnc.dataType,
                                              paddingEnc.denseRatio, paddingEnc.devID, paddingEnc.mem);

    /* expand the source-side padding to the target length, map {0,1} -> {-1e9,0},
       then replicate the result over the attention heads.
       NOTE: the former maskEncDecTMPDec buffer was allocated and freed without
       ever being used, so it has been removed. */
    _Unsqueeze(&paddingEnc, maskEncDecTMPEnc, paddingEnc.order - 1, paddingDec.GetDim(-1));
    _ScaleAndShiftMe(maskEncDecTMPEnc, 1e9F, -1e9F);
    _Unsqueeze(maskEncDecTMPEnc, &maskEncDec, 0, dims[0]);

    DelTensorBuf(maskEncDecTMPEnc);

    delete[] dims;
}
/*
get parameter matrices
>> list - the list that keeps the parameter matrices
......
......@@ -94,6 +94,14 @@ public:
void MakeMTMask(XTensor &inputEnc, XTensor &inputDec,
XTensor &paddingEnc, XTensor &paddingDec,
XTensor &maskEnc, XTensor &maskDec, XTensor &maskEncDec);
/* make the mask of the encoder */
void MakeMTMaskEnc(XTensor &inputEnc, XTensor &paddingEnc, XTensor &maskEnc);
/* make the mask of the decoder */
void MakeMTMaskDec(XTensor &inputEnc, XTensor &inputDec,
XTensor &paddingEnc, XTensor &paddingDec,
XTensor &maskDec, XTensor &maskEncDec);
/* get parameter matrices */
void GetParams(XList &list);
......
......@@ -30,15 +30,19 @@ namespace transformer
search for the most promising states
>> model - the transformer model
>> input - input of the model
>> padding - padding of the input
>> output - output that represents the sequences as rows
*/
void T2TSearch::Search(T2TModel * model, XTensor * input, XTensor * output)
void T2TSearch::Search(T2TModel * model, XTensor * input, XTensor * padding, XTensor * output)
{
XTensor maskNULL;
XTensor maskEnc;
XTensor encoding;
/* encoder mask */
model->MakeMTMaskEnc(*input, *padding, maskEnc);
/* make the encoding network */
encoding = model->MakeEncoder(*input, maskNULL, false);
encoding = model->MakeEncoder(*input, maskEnc, false);
encoding.SetName(ENCODING_NAME);
T2TPredictor predictor;
......@@ -46,7 +50,6 @@ void T2TSearch::Search(T2TModel * model, XTensor * input, XTensor * output)
T2TStateBundle * cur = &state1;
T2TStateBundle * next = &state2;
/* initialize the predictor */
predictor.Init(model, &encoding, cur);
......@@ -84,4 +87,4 @@ void T2TSearch::DumpOutput(T2TStateBundle * beam, XTensor * output)
{
}
}
\ No newline at end of file
}
......@@ -28,7 +28,7 @@
namespace transformer
{
/* The class orgnizes the search process. It calls predictors to generate
/* The class organizes the search process. It calls "predictors" to generate
distributions of the predictions and prunes the search space by beam pruning.
It results in a graph where each path represents a translation hypothesis.
The output can be the path with the highest model score. */
......@@ -46,7 +46,7 @@ public:
~T2TSearch() {};
/* search for the most promising states */
void Search(T2TModel * model, XTensor * input, XTensor * output);
void Search(T2TModel * model, XTensor * input, XTensor * padding, XTensor * output);
/* beam pruning */
void Prune(T2TStateBundle * beam);
......@@ -57,4 +57,4 @@ public:
}
#endif
\ No newline at end of file
#endif
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论