Commit 45a5a936 by xiaotong

add mask

parent c767cce8
...@@ -35,6 +35,7 @@ T2TAttention::T2TAttention() ...@@ -35,6 +35,7 @@ T2TAttention::T2TAttention()
dk = -1; dk = -1;
dv = -1; dv = -1;
d = -1; d = -1;
isMasked = false;
} }
/* destructor */ /* destructor */
...@@ -46,10 +47,11 @@ T2TAttention::~T2TAttention() ...@@ -46,10 +47,11 @@ T2TAttention::~T2TAttention()
initialize the model initialize the model
>> argc - number of arguments >> argc - number of arguments
>> argv - list of pointers to the arguments >> argv - list of pointers to the arguments
>> myIsMasked - indicates whether the attention uses a mask
>> myDevID - device id >> myDevID - device id
>> myMem - the memory pool >> myMem - the memory pool
*/ */
void T2TAttention::InitModel(int argc, const char ** argv, int myDevID, XMem * myMem) void T2TAttention::InitModel(int argc, const char ** argv, bool myIsMasked, int myDevID, XMem * myMem)
{ {
devID = myDevID; devID = myDevID;
mem = myMem; mem = myMem;
...@@ -82,9 +84,10 @@ make the network ...@@ -82,9 +84,10 @@ make the network
and H = vector size of each position and H = vector size of each position
>> q - queries >> q - queries
>> v - values >> v - values
>> mask - the mask added to the attention scores (used only when isMasked is true)
<< return - multi-attention result << return - multi-attention result
*/ */
XTensor T2TAttention::Make(XTensor &k, XTensor &q, XTensor &v) XTensor T2TAttention::Make(XTensor &k, XTensor &q, XTensor &v, XTensor &mask)
{ {
XTensor k2; XTensor k2;
XTensor q2; XTensor q2;
...@@ -105,10 +108,14 @@ XTensor T2TAttention::Make(XTensor &k, XTensor &q, XTensor &v) ...@@ -105,10 +108,14 @@ XTensor T2TAttention::Make(XTensor &k, XTensor &q, XTensor &v)
vheads = Split(v2, v2.order - 1, nhead); vheads = Split(v2, v2.order - 1, nhead);
XTensor att; XTensor att;
XTensor dot;
XTensor scalar; XTensor scalar;
/* scalar = softmax(Q * K^T / sqrt(dk)) * V */ /* scalar = softmax(Q * K^T / sqrt(dk)) * V */
scalar = Softmax(Linear(BMMul(qheads, X_NOTRANS, kheads, X_TRANS), 1/(float)sqrt((float)dk)), -1); dot = BMMul(qheads, X_NOTRANS, kheads, X_TRANS);
if(isMasked)
dot = dot + mask;
scalar = Softmax(Linear(dot, 1/(float)sqrt((float)dk)), -1);
att = BMMul(scalar, vheads); att = BMMul(scalar, vheads);
/* concatenate the heads */ /* concatenate the heads */
......
...@@ -66,6 +66,9 @@ public: ...@@ -66,6 +66,9 @@ public:
/* size of input Q, K and V */ /* size of input Q, K and V */
int d; int d;
/* indicates whether the attention is masked */
bool isMasked;
public: public:
/* constructor */ /* constructor */
T2TAttention(); T2TAttention();
...@@ -74,10 +77,10 @@ public: ...@@ -74,10 +77,10 @@ public:
~T2TAttention(); ~T2TAttention();
/* initialize the model */ /* initialize the model */
void InitModel(int argc, const char ** argv, int myDevID = -1, XMem * myMem = NULL); void InitModel(int argc, const char ** argv, bool myIsMasked, int myDevID = -1, XMem * myMem = NULL);
/* make the network */ /* make the network */
XTensor Make(XTensor &k, XTensor &q, XTensor &v); XTensor Make(XTensor &k, XTensor &q, XTensor &v, XTensor &mask);
}; };
} }
......
...@@ -72,7 +72,7 @@ void AttEncoder::InitModel(int argc, const char ** argv, int myDevID, XMem * myM ...@@ -72,7 +72,7 @@ void AttEncoder::InitModel(int argc, const char ** argv, int myDevID, XMem * myM
/* initialize the stacked layers */ /* initialize the stacked layers */
for(int i = 0; i < nlayer; i++){ for(int i = 0; i < nlayer; i++){
attentions[i].InitModel(argc, argv, myDevID, myMem); attentions[i].InitModel(argc, argv, false, myDevID, myMem);
fnns[i].InitModel(argc, argv, myDevID, myMem); fnns[i].InitModel(argc, argv, myDevID, myMem);
attLayerNorms[i].InitModel(argc, argv, myDevID, myMem); attLayerNorms[i].InitModel(argc, argv, myDevID, myMem);
fnnLayerNorms[i].InitModel(argc, argv, myDevID, myMem); fnnLayerNorms[i].InitModel(argc, argv, myDevID, myMem);
...@@ -82,9 +82,10 @@ void AttEncoder::InitModel(int argc, const char ** argv, int myDevID, XMem * myM ...@@ -82,9 +82,10 @@ void AttEncoder::InitModel(int argc, const char ** argv, int myDevID, XMem * myM
/* /*
make the encoding network make the encoding network
>> input - the input tensor of the encoder >> input - the input tensor of the encoder
>> mask - the mask that indicate each position is valid
<< return - the output tensor of the encoder << return - the output tensor of the encoder
*/ */
XTensor AttEncoder::Make(XTensor &input) XTensor AttEncoder::Make(XTensor &input, XTensor &mask)
{ {
XTensor x; XTensor x;
...@@ -97,7 +98,7 @@ XTensor AttEncoder::Make(XTensor &input) ...@@ -97,7 +98,7 @@ XTensor AttEncoder::Make(XTensor &input)
XTensor res; XTensor res;
/* self attention */ /* self attention */
att = attentions[i].Make(x, x, x); att = attentions[i].Make(x, x, x, mask);
/* residual connection */ /* residual connection */
res = Sum(att, x); res = Sum(att, x);
......
...@@ -40,7 +40,7 @@ class T2TEncoder ...@@ -40,7 +40,7 @@ class T2TEncoder
{ {
public: public:
virtual virtual
XTensor Make(XTensor &input) = 0; XTensor Make(XTensor &input, XTensor &mask) = 0;
}; };
/* /*
...@@ -49,7 +49,7 @@ the encoder based on RNN ...@@ -49,7 +49,7 @@ the encoder based on RNN
class RNNEncoder : T2TEncoder class RNNEncoder : T2TEncoder
{ {
public: public:
XTensor Make(XTensor &input); XTensor Make(XTensor &input, XTensor &mask);
}; };
...@@ -109,7 +109,7 @@ public: ...@@ -109,7 +109,7 @@ public:
void InitModel(int argc, const char ** argv, int myDevID = -1, XMem * myMem = NULL); void InitModel(int argc, const char ** argv, int myDevID = -1, XMem * myMem = NULL);
/* make the encoding network */ /* make the encoding network */
XTensor Make(XTensor &input); XTensor Make(XTensor &input, XTensor &mask);
}; };
......
...@@ -96,7 +96,7 @@ XTensor T2TLN::Make(XTensor &input) ...@@ -96,7 +96,7 @@ XTensor T2TLN::Make(XTensor &input)
standardFilled = Unsqueeze(standard, x.order - 1, x.GetDim(-1)); standardFilled = Unsqueeze(standard, x.order - 1, x.GetDim(-1));
/* x' = (x - \mu)/standard */ /* x' = (x - \mu)/standard */
xn = (x - meanFilled)/standardFilled ; xn = (x - meanFilled)/standardFilled;
/* result = x' * w + b */ /* result = x' * w + b */
return MMul(xn, w) + b; return MMul(xn, w) + b;
......
...@@ -68,11 +68,12 @@ void T2TModel::InitModel(int argc, const char ** argv) ...@@ -68,11 +68,12 @@ void T2TModel::InitModel(int argc, const char ** argv)
/* /*
make the encoding network make the encoding network
>> input - input tensor >> input - input tensor
>> mask - the mask indicating which positions are (or are not) involved in the computation
<< return - encoding result << return - encoding result
*/ */
XTensor T2TModel::MakeEncoding(XTensor &input) XTensor T2TModel::MakeEncoding(XTensor &input, XTensor &mask)
{ {
return encoder.Make(input); return encoder.Make(input, mask);
} }
/* /*
...@@ -85,7 +86,15 @@ void T2TModel::Make(XTensor &input, XTensor &output) ...@@ -85,7 +86,15 @@ void T2TModel::Make(XTensor &input, XTensor &output)
XTensor encoding; XTensor encoding;
if(isLM){ if(isLM){
encoding = MakeEncoding(input); /* generate mask to see "previous" words only */
int len = input.GetDim(input.order - 2);
int dims[MAX_TENSOR_DIM_NUM];
for(int i = 0; i < input.order; i++)
dims[i] = input.GetDim(i);
dims[input.order - 1] = len;
XTensor mask(input.order, dims, X_FLOAT, 1.0F, input.devID, input.mem);
encoding = MakeEncoding(input, mask);
outputLayer.Make(encoding, output); outputLayer.Make(encoding, output);
} }
else{ else{
......
...@@ -66,7 +66,7 @@ public: ...@@ -66,7 +66,7 @@ public:
void InitModel(int argc, const char ** argv); void InitModel(int argc, const char ** argv);
/* make the encoding network */ /* make the encoding network */
XTensor MakeEncoding(XTensor &input); XTensor MakeEncoding(XTensor &input, XTensor &mask);
/* make the entire network (with the output softmax layer) */ /* make the entire network (with the output softmax layer) */
void Make(XTensor &input, XTensor &output); void Make(XTensor &input, XTensor &output);
......
...@@ -100,7 +100,9 @@ void ShowParams(int argc, const char ** argv) ...@@ -100,7 +100,9 @@ void ShowParams(int argc, const char ** argv)
{ {
fprintf(stderr, "args:\n"); fprintf(stderr, "args:\n");
for(int i = 0; i < argc; i++){ for(int i = 0; i < argc; i++){
if(argv[i][0] == '-'){ if(argv[i][1] == 0)
continue;
if(argv[i][0] == '-' && (argv[i][1] < '1' || argv[i][1] > '9')){
if(i + 1 < argc && argv[i + 1][0] != '-') if(i + 1 < argc && argv[i + 1][0] != '-')
fprintf(stderr, " %s=%s\n", argv[i], argv[i + 1]); fprintf(stderr, " %s=%s\n", argv[i], argv[i + 1]);
else else
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论