Commit a7fe4564 by xiaotong

mask of the encoder

parent 80cfa480
...@@ -46,13 +46,15 @@ AttEncoder::~AttEncoder() ...@@ -46,13 +46,15 @@ AttEncoder::~AttEncoder()
initialize the model initialize the model
>> argc - number of arguments >> argc - number of arguments
>> argv - list of pointers to the arguments >> argv - list of pointers to the arguments
>> myIsMasked - indicates whether the masked attention is employed
>> myDevID - device id >> myDevID - device id
>> myMem - the memory pool >> myMem - the memory pool
*/ */
void AttEncoder::InitModel(int argc, const char ** argv, int myDevID, XMem * myMem) void AttEncoder::InitModel(int argc, const char ** argv, bool myIsMasked, int myDevID, XMem * myMem)
{ {
devID = myDevID; devID = myDevID;
mem = myMem; mem = myMem;
isMasked = myIsMasked;
LoadParamInt(argc, argv, "nlayer", &nlayer, 6); LoadParamInt(argc, argv, "nlayer", &nlayer, 6);
LoadParamInt(argc, argv, "hsize", &hSize, DEFAULT_EMBEDDING_SIZE); LoadParamInt(argc, argv, "hsize", &hSize, DEFAULT_EMBEDDING_SIZE);
...@@ -72,7 +74,7 @@ void AttEncoder::InitModel(int argc, const char ** argv, int myDevID, XMem * myM ...@@ -72,7 +74,7 @@ void AttEncoder::InitModel(int argc, const char ** argv, int myDevID, XMem * myM
/* initialize the stacked layers */ /* initialize the stacked layers */
for(int i = 0; i < nlayer; i++){ for(int i = 0; i < nlayer; i++){
attentions[i].InitModel(argc, argv, false, myDevID, myMem); attentions[i].InitModel(argc, argv, isMasked, myDevID, myMem);
fnns[i].InitModel(argc, argv, myDevID, myMem); fnns[i].InitModel(argc, argv, myDevID, myMem);
attLayerNorms[i].InitModel(argc, argv, myDevID, myMem); attLayerNorms[i].InitModel(argc, argv, myDevID, myMem);
fnnLayerNorms[i].InitModel(argc, argv, myDevID, myMem); fnnLayerNorms[i].InitModel(argc, argv, myDevID, myMem);
......
...@@ -76,6 +76,9 @@ public: ...@@ -76,6 +76,9 @@ public:
/* vocabulary size */ /* vocabulary size */
int vSize; int vSize;
/* indicates whether masked attention is employed */
int isMasked;
/* embedding of word at each position */ /* embedding of word at each position */
T2TEmbedder embedder; T2TEmbedder embedder;
...@@ -106,7 +109,7 @@ public: ...@@ -106,7 +109,7 @@ public:
~AttEncoder(); ~AttEncoder();
/* initialize the model */ /* initialize the model */
void InitModel(int argc, const char ** argv, int myDevID = -1, XMem * myMem = NULL); void InitModel(int argc, const char ** argv, bool myIsMasked, int myDevID = -1, XMem * myMem = NULL);
/* make the encoding network */ /* make the encoding network */
XTensor Make(XTensor &input, XTensor &mask); XTensor Make(XTensor &input, XTensor &mask);
......
...@@ -19,6 +19,7 @@ ...@@ -19,6 +19,7 @@
* $Created by: XIAO Tong (xiaotong@mail.neu.edu.cn) 2018-07-31 * $Created by: XIAO Tong (xiaotong@mail.neu.edu.cn) 2018-07-31
*/ */
#include <math.h>
#include "T2TLayerNormal.h" #include "T2TLayerNormal.h"
#include "T2TUtility.h" #include "T2TUtility.h"
#include "T2TEmbedding.h" #include "T2TEmbedding.h"
......
...@@ -61,7 +61,7 @@ void T2TModel::InitModel(int argc, const char ** argv) ...@@ -61,7 +61,7 @@ void T2TModel::InitModel(int argc, const char ** argv)
mem = new XMem(devID); mem = new XMem(devID);
} }
encoder.InitModel(argc, argv, devID, mem); encoder.InitModel(argc, argv, isLM, devID, mem);
outputLayer.InitModel(argc, argv, devID, mem); outputLayer.InitModel(argc, argv, devID, mem);
} }
...@@ -93,6 +93,10 @@ void T2TModel::Make(XTensor &input, XTensor &output) ...@@ -93,6 +93,10 @@ void T2TModel::Make(XTensor &input, XTensor &output)
dims[i] = input.GetDim(i); dims[i] = input.GetDim(i);
dims[input.order - 1] = len; dims[input.order - 1] = len;
XTensor mask(input.order, dims, X_FLOAT, 1.0F, input.devID, input.mem); XTensor mask(input.order, dims, X_FLOAT, 1.0F, input.devID, input.mem);
/* an upper triangular matrix where the cells of the upper triangle are set to -1e-9 */
_SetDataLowTri(&mask, 1e-9, -1);
_ScaleAndShiftMe(&mask, 1.0F, -1e-9);
encoding = MakeEncoding(input, mask); encoding = MakeEncoding(input, mask);
outputLayer.Make(encoding, output); outputLayer.Make(encoding, output);
......
...@@ -250,7 +250,7 @@ void _SetDataLowTri(XTensor * tensor, DTYPE p, int shift) ...@@ -250,7 +250,7 @@ void _SetDataLowTri(XTensor * tensor, DTYPE p, int shift)
for(int col = 0; col < row + shift; col++){ for(int col = 0; col < row + shift; col++){
d[row * l + col] = row; d[row * l + col] = row;
} }
for(int col = row + shift; col < l; col++){ for(int col = MAX(0, row + shift); col < l; col++){
d[row * l + col] = 0; d[row * l + col] = 0;
} }
} }
......
Markdown 格式
0%
您添加了 0 人到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 登录 后发表评论