Commit 45a5a936 by xiaotong

add mask

parent c767cce8
......
@@ -35,6 +35,7 @@ T2TAttention::T2TAttention()
     dk = -1;
     dv = -1;
     d = -1;
+    isMasked = false;
 }
 
 /* destructor */
......
@@ -46,10 +47,11 @@ T2TAttention::~T2TAttention()
 initialize the model
 >> argc - number of arguments
 >> argv - list of pointers to the arguments
+>> myIsMasked - indicates whether the attention uses a mask
 >> myDevID - device id
 >> myMem - the memory pool
 */
-void T2TAttention::InitModel(int argc, const char ** argv, int myDevID, XMem * myMem)
+void T2TAttention::InitModel(int argc, const char ** argv, bool myIsMasked, int myDevID, XMem * myMem)
 {
     devID = myDevID;
     mem = myMem;
......
@@ -82,9 +84,10 @@ make the network
 and H = vector size of each position
 >> q - queries
 >> v - values
+>> mask - the mask that is added to the attention scores
 << return - multi-attention result
 */
-XTensor T2TAttention::Make(XTensor &k, XTensor &q, XTensor &v)
+XTensor T2TAttention::Make(XTensor &k, XTensor &q, XTensor &v, XTensor &mask)
 {
     XTensor k2;
     XTensor q2;
......
@@ -105,10 +108,14 @@ XTensor T2TAttention::Make(XTensor &k, XTensor &q, XTensor &v)
     vheads = Split(v2, v2.order - 1, nhead);
 
     XTensor att;
+    XTensor dot;
     XTensor scalar;
 
     /* scalar = softmax(Q * K^T / sqrt(dk)) * V */
-    scalar = Softmax(Linear(BMMul(qheads, X_NOTRANS, kheads, X_TRANS), 1/(float)sqrt((float)dk)), -1);
+    dot = BMMul(qheads, X_NOTRANS, kheads, X_TRANS);
+    if(isMasked)
+        dot = dot + mask;
+    scalar = Softmax(Linear(dot, 1/(float)sqrt((float)dk)), -1);
     att = BMMul(scalar, vheads);
 
     /* concatenate the heads */
......
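In equation form, the rewritten body of Make now computes the following (M is the additive mask, applied only when isMasked is set):

    dot    = Q * K^T
    scalar = softmax((dot + M) / sqrt(dk))
    att    = scalar * V

Note that, as written, the mask is added before the 1/sqrt(dk) scaling, so its values are scaled together with the scores; a mask that blocks positions with large negative values still suppresses them after the softmax.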
......
@@ -66,6 +66,9 @@ public:
     /* size of input Q, K and V */
     int d;
 
+    /* indicates whether the attention is masked */
+    bool isMasked;
+
 public:
     /* constructor */
     T2TAttention();
......
@@ -74,10 +77,10 @@ public:
     ~T2TAttention();
 
     /* initialize the model */
-    void InitModel(int argc, const char ** argv, int myDevID = -1, XMem * myMem = NULL);
+    void InitModel(int argc, const char ** argv, bool myIsMasked, int myDevID = -1, XMem * myMem = NULL);
 
     /* make the network */
-    XTensor Make(XTensor &k, XTensor &q, XTensor &v);
+    XTensor Make(XTensor &k, XTensor &q, XTensor &v, XTensor &mask);
 };
 
 }
......
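Taken together, the new header means call sites choose masking at initialization time and supply the mask tensor on every call. A minimal usage sketch, assuming k, q, v, mask, devID and mem are already set up (none of these variables are part of this diff):

    T2TAttention att;
    att.InitModel(argc, argv, true, devID, mem);   /* myIsMasked = true */
    XTensor result = att.Make(k, q, v, mask);      /* mask is added to Q * K^T because isMasked is set */

AttEncoder (below) passes false instead, so the mask argument is ignored inside its self-attention layers.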
......
@@ -72,7 +72,7 @@ void AttEncoder::InitModel(int argc, const char ** argv, int myDevID, XMem * myM
     /* initialize the stacked layers */
     for(int i = 0; i < nlayer; i++){
-        attentions[i].InitModel(argc, argv, myDevID, myMem);
+        attentions[i].InitModel(argc, argv, false, myDevID, myMem);
         fnns[i].InitModel(argc, argv, myDevID, myMem);
         attLayerNorms[i].InitModel(argc, argv, myDevID, myMem);
         fnnLayerNorms[i].InitModel(argc, argv, myDevID, myMem);
......
@@ -82,9 +82,10 @@ void AttEncoder::InitModel(int argc, const char ** argv, int myDevID, XMem * myM
 /*
 make the encoding network
 >> input - the input tensor of the encoder
+>> mask - the mask that indicates whether each position is valid
 << return - the output tensor of the encoder
 */
-XTensor AttEncoder::Make(XTensor &input)
+XTensor AttEncoder::Make(XTensor &input, XTensor &mask)
 {
     XTensor x;
......
@@ -97,7 +98,7 @@ XTensor AttEncoder::Make(XTensor &input)
         XTensor res;
 
         /* self attention */
-        att = attentions[i].Make(x, x, x);
+        att = attentions[i].Make(x, x, x, mask);
 
         /* residual connection */
         res = Sum(att, x);
......
......
@@ -40,7 +40,7 @@ class T2TEncoder
 {
 public:
     virtual
-    XTensor Make(XTensor &input) = 0;
+    XTensor Make(XTensor &input, XTensor &mask) = 0;
 };
 
 /*
......
@@ -49,7 +49,7 @@ the encoder based on RNN
 class RNNEncoder : T2TEncoder
 {
 public:
-    XTensor Make(XTensor &input);
+    XTensor Make(XTensor &input, XTensor &mask);
 };
......
@@ -109,7 +109,7 @@ public:
     void InitModel(int argc, const char ** argv, int myDevID = -1, XMem * myMem = NULL);
 
     /* make the encoding network */
-    XTensor Make(XTensor &input);
+    XTensor Make(XTensor &input, XTensor &mask);
 };
......
......
@@ -96,7 +96,7 @@ XTensor T2TLN::Make(XTensor &input)
     standardFilled = Unsqueeze(standard, x.order - 1, x.GetDim(-1));
 
     /* x' = (x - \mu)/standard */
-    xn = (x - meanFilled)/standardFilled ;
+    xn = (x - meanFilled)/standardFilled;
 
     /* result = x' * w + b */
     return MMul(xn, w) + b;
......
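For context, T2TLN::Make around this hunk is a standard layer normalization. In the notation of its own comments, where \mu and standard are the mean and standard deviation over the last dimension of x, broadcast to x's shape via Unsqueeze (meanFilled and standardFilled):

    x'     = (x - \mu) / standard
    result = x' * w + b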
......
@@ -68,11 +68,12 @@ void T2TModel::InitModel(int argc, const char ** argv)
 /*
 make the encoding network
 >> input - input tensor
+>> mask - the mask that indicates which positions are involved in the computation
 << return - encoding result
 */
-XTensor T2TModel::MakeEncoding(XTensor &input)
+XTensor T2TModel::MakeEncoding(XTensor &input, XTensor &mask)
 {
-    return encoder.Make(input);
+    return encoder.Make(input, mask);
 }
 
 /*
/*
......
@@ -85,7 +86,15 @@ void T2TModel::Make(XTensor &input, XTensor &output)
     XTensor encoding;
 
     if(isLM){
-        encoding = MakeEncoding(input);
+        /* generate mask to see "previous" words only */
+        int len = input.GetDim(input.order - 2);
+        int dims[MAX_TENSOR_DIM_NUM];
+        for(int i = 0; i < input.order; i++)
+            dims[i] = input.GetDim(i);
+        dims[input.order - 1] = len;
+        XTensor mask(input.order, dims, X_FLOAT, 1.0F, input.devID, input.mem);
+
+        encoding = MakeEncoding(input, mask);
         outputLayer.Make(encoding, output);
     }
     else{
......
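Note that the hunk above only allocates the len x len mask tensor (the 1.0F argument is XTensor's dense-ratio parameter rather than a fill value); filling it is not part of this commit. For reference, a self-contained sketch of the usual "previous words only" fill on a plain buffer, where MakeCausalMask and the -1e9F constant are illustrative, not code from this repository:

    #include <vector>

    /* fill a len x len mask so that position i may attend to positions
       j <= i only; blocked entries get a large negative value that
       vanishes after softmax */
    std::vector<float> MakeCausalMask(int len, float negBig = -1e9F)
    {
        std::vector<float> mask(len * len, 0.0F);
        for(int i = 0; i < len; i++)
            for(int j = i + 1; j < len; j++)
                mask[i * len + j] = negBig;  /* block "future" positions */
        return mask;
    }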
......
@@ -66,7 +66,7 @@ public:
     void InitModel(int argc, const char ** argv);
 
     /* make the encoding network */
-    XTensor MakeEncoding(XTensor &input);
+    XTensor MakeEncoding(XTensor &input, XTensor &mask);
 
     /* make the entire network (with the output softmax layer) */
     void Make(XTensor &input, XTensor &output);
......
......
@@ -100,7 +100,9 @@ void ShowParams(int argc, const char ** argv)
 {
     fprintf(stderr, "args:\n");
     for(int i = 0; i < argc; i++){
-        if(argv[i][0] == '-'){
+        if(argv[i][1] == 0)
+            continue;
+        if(argv[i][0] == '-' && (argv[i][1] < '1' || argv[i][1] > '9')){
             if(i + 1 < argc && argv[i + 1][0] != '-')
                 fprintf(stderr, " %s=%s\n", argv[i], argv[i + 1]);
             else
......
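The two added checks make the argument dump more robust: one-character tokens are skipped outright, and a token such as "-8" is no longer mistaken for an option name, since its second character is a digit in '1'..'9'. A standalone sketch of the updated filter (the else branch is truncated in the diff above, so this sketch simply prints the bare flag there):

    #include <cstdio>

    void ShowParamsSketch(int argc, const char ** argv)
    {
        fprintf(stderr, "args:\n");
        for(int i = 0; i < argc; i++){
            if(argv[i][1] == 0)
                continue;               /* skip one-character tokens */
            if(argv[i][0] == '-' && (argv[i][1] < '1' || argv[i][1] > '9')){
                if(i + 1 < argc && argv[i + 1][0] != '-')
                    fprintf(stderr, " %s=%s\n", argv[i], argv[i + 1]);
                else
                    fprintf(stderr, " %s\n", argv[i]);  /* sketch only: print the bare flag */
            }
        }
    }

    int main()
    {
        /* prints "-nlayer=6" and "-shuffled"; skips "6" and "-8" */
        const char * args[] = { "t2t", "-nlayer", "6", "-shuffled", "-8" };
        ShowParamsSketch(5, args);
        return 0;
    }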