Commit 30c3a629 by xiaotong

better code of t2t mask for lm and add SetDataDim

parent a7fe4564
......@@ -36,6 +36,7 @@ T2TAttention::T2TAttention()
dv = -1;
d = -1;
isMasked = false;
ignored = 0;
}
/* destructor */
......@@ -47,14 +48,19 @@ T2TAttention::~T2TAttention()
initialize the model
>> argc - number of arguments
>> argv - list of pointers to the arguments
>> myIsMasked - indicates whether the attention is with a mask
>> myIgnored - number of positions ignored in attention (from the beginning)
>> myDevID - device id
>> myMem - the memory pool
*/
void T2TAttention::InitModel(int argc, const char ** argv, bool myIsMasked, int myDevID, XMem * myMem)
void T2TAttention::InitModel(int argc, const char ** argv,
bool myIsMasked, int myIgnored,
int myDevID, XMem * myMem)
{
devID = myDevID;
mem = myMem;
isMasked = myIsMasked;
ignored = myIgnored;
float minmax = 0;
......@@ -116,6 +122,10 @@ XTensor T2TAttention::Make(XTensor &k, XTensor &q, XTensor &v, XTensor &mask)
if(isMasked)
dot = dot + mask;
scalar = Softmax(Linear(dot, 1/(float)sqrt((float)dk)), -1);
if(ignored > 0)
_SetDataDim(&scalar, 0, ignored, scalar.order - 2, 1e-9F);
att = BMMul(scalar, vheads);
/* concatenate the heads */
......
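A note on the `ignored` handling above: after the softmax, the first `ignored` rows of the attention distribution (along dimension order - 2, i.e. the query positions) are overwritten with 1e-9, so those positions contribute essentially nothing to BMMul(scalar, vheads). A minimal standalone C++ sketch of that effect, using plain arrays and made-up values rather than the XTensor API:

#include <cstdio>

int main()
{
    const int lenQ = 4, lenK = 4, ignored = 1;

    /* made-up attention weights after softmax: one (lenQ x lenK) slice */
    float scalar[lenQ][lenK] = {
        {0.25F, 0.25F, 0.25F, 0.25F},
        {0.60F, 0.40F, 0.00F, 0.00F},
        {0.30F, 0.30F, 0.40F, 0.00F},
        {0.10F, 0.20F, 0.30F, 0.40F}
    };

    /* what _SetDataDim(&scalar, 0, ignored, scalar.order - 2, 1e-9F) amounts to:
       overwrite the first `ignored` rows (query positions) with a tiny constant */
    for(int row = 0; row < ignored; row++)
        for(int col = 0; col < lenK; col++)
            scalar[row][col] = 1e-9F;

    for(int row = 0; row < lenQ; row++){
        for(int col = 0; col < lenK; col++)
            printf("%g ", scalar[row][col]);
        printf("\n");
    }
    return 0;
}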
......@@ -69,6 +69,10 @@ public:
/* indicates whether the attention is masked */
bool isMasked;
/* some positions can be ignored in attention. This is useful in LM where the first position needs
a special design for the attention model. */
int ignored;
public:
/* constructor */
T2TAttention();
......@@ -77,7 +81,9 @@ public:
~T2TAttention();
/* initialize the model */
void InitModel(int argc, const char ** argv, bool myIsMasked, int myDevID = -1, XMem * myMem = NULL);
void InitModel(int argc, const char ** argv,
bool myIsMasked, int myIgnored,
int myDevID = -1, XMem * myMem = NULL);
/* make the network */
XTensor Make(XTensor &k, XTensor &q, XTensor &v, XTensor &mask);
......
......@@ -136,7 +136,7 @@ XTensor T2TEmbedder::Make(XTensor &input)
wordEmbedding = Linear(MMul(input, w), (float)sqrt((float)d));
/* we sum over the two embeddings */
return wordEmbedding +posEmbedding;
return wordEmbedding + posEmbedding;
}
}
......@@ -47,14 +47,17 @@ initialize the model
>> argc - number of arguments
>> argv - list of pointers to the arguments
>> myIsMasked - indicates whether the masked attention is employed
>> myIgnored - number of positions ignored in attention (from the start)
>> myDevID - device id
>> myMem - the memory pool
*/
void AttEncoder::InitModel(int argc, const char ** argv, bool myIsMasked, int myDevID, XMem * myMem)
void AttEncoder::InitModel(int argc, const char ** argv,
bool myIsMasked, int myIgnored,
int myDevID, XMem * myMem)
{
devID = myDevID;
mem = myMem;
isMasked = myIsMasked;
ignored = myIgnored;
LoadParamInt(argc, argv, "nlayer", &nlayer, 6);
LoadParamInt(argc, argv, "hsize", &hSize, DEFAULT_EMBEDDING_SIZE);
......@@ -74,7 +77,7 @@ void AttEncoder::InitModel(int argc, const char ** argv, bool myIsMasked, int my
/* initialize the stacked layers */
for(int i = 0; i < nlayer; i++){
attentions[i].InitModel(argc, argv, isMasked, myDevID, myMem);
attentions[i].InitModel(argc, argv, myIsMasked, myIgnored, myDevID, myMem);
fnns[i].InitModel(argc, argv, myDevID, myMem);
attLayerNorms[i].InitModel(argc, argv, myDevID, myMem);
fnnLayerNorms[i].InitModel(argc, argv, myDevID, myMem);
......@@ -85,9 +88,10 @@ void AttEncoder::InitModel(int argc, const char ** argv, bool myIsMasked, int my
make the encoding network
>> input - the input tensor of the encoder
>> mask - the mask that indicates whether each position is valid
>> skipInputRes - indicates whether we skip the residual connection of the first layer
<< return - the output tensor of the encoder
*/
XTensor AttEncoder::Make(XTensor &input, XTensor &mask)
XTensor AttEncoder::Make(XTensor &input, XTensor &mask, bool skipInputRes)
{
XTensor x;
......@@ -99,16 +103,27 @@ XTensor AttEncoder::Make(XTensor &input, XTensor &mask)
XTensor fnn;
XTensor res;
/* self attention */
att = attentions[i].Make(x, x, x, mask);
/* residual connection */
res = Sum(att, x);
/* TODO: dropout */
/* layer normalization */
x = attLayerNorms[i].Make(res);
if(skipInputRes && i == 0){
/* self attention */
att = attentions[i].Make(x, x, x, mask);
/* TODO: dropout */
/* layer normalization */
x = attLayerNorms[i].Make(att);
}
else{
/* self attention */
att = attentions[i].Make(x, x, x, mask);
/* residual connection */
res = Sum(att, x);
/* TODO: dropout */
/* layer normalization */
x = attLayerNorms[i].Make(res);
}
/* fnn */
fnn = fnns[i].Make(x);
......
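The control flow introduced by skipInputRes can be summarized as: for the first layer the residual connection around self-attention is skipped and layer normalization is applied to the attention output directly, while the remaining layers keep the usual residual + layer-norm pattern. A schematic C++ sketch with scalar placeholders (SelfAtt and LayerNorm are hypothetical stand-ins, not the real tensor ops):

#include <cstdio>

/* placeholders standing in for the real tensor operations (hypothetical) */
float SelfAtt(float x)   { return 0.5F * x; }
float LayerNorm(float x) { return x; }

int main()
{
    const int nlayer = 2;
    const bool skipInputRes = true;
    float x = 1.0F;                        /* stands in for the layer input */

    for(int i = 0; i < nlayer; i++){
        float att = SelfAtt(x);

        if(skipInputRes && i == 0)
            x = LayerNorm(att);            /* first layer: no residual connection */
        else
            x = LayerNorm(att + x);        /* other layers: residual, then layer norm */

        /* the fnn sub-layer would follow here, as in AttEncoder::Make */
        printf("layer %d: x = %g\n", i, x);
    }
    return 0;
}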
......@@ -40,7 +40,7 @@ class T2TEncoder
{
public:
virtual
XTensor Make(XTensor &input, XTensor &mask) = 0;
XTensor Make(XTensor &input, XTensor &mask, bool skipInputRes) = 0;
};
/*
......@@ -49,7 +49,7 @@ the encoder based on RNN
class RNNEncoder : T2TEncoder
{
public:
XTensor Make(XTensor &input, XTensor &mask);
XTensor Make(XTensor &input, XTensor &mask, bool skipInputRes);
};
......@@ -76,9 +76,10 @@ public:
/* vocabulary size */
int vSize;
/* indicates whether masked attention is employed */
int isMasked;
/* some positions can be ignored in attention. This is useful in LM where the first position needs
a special design for the attention model. */
int ignored;
/* embedding of word at each position */
T2TEmbedder embedder;
......@@ -109,10 +110,12 @@ public:
~AttEncoder();
/* initialize the model */
void InitModel(int argc, const char ** argv, bool myIsMasked, int myDevID = -1, XMem * myMem = NULL);
void InitModel(int argc, const char ** argv,
bool myIsMasked, int myIgnored,
int myDevID = -1, XMem * myMem = NULL);
/* make the encoding network */
XTensor Make(XTensor &input, XTensor &mask);
XTensor Make(XTensor &input, XTensor &mask, bool skipInputRes);
};
......
......@@ -90,7 +90,6 @@ XTensor T2TLN::Make(XTensor &input)
/* standard = sqrt(variance) */
standard = Power(variance, 0.5F);
/* unsqueeze mean and standard deviation to fit them into
the same shape as x */
meanFilled = Unsqueeze(mean, x.order - 1, x.GetDim(-1));
......
......@@ -34,6 +34,7 @@ T2TModel::T2TModel()
mem = NULL;
isLM = false;
isMT = false;
nhead = 1;
}
/* destructor */
......@@ -55,13 +56,14 @@ void T2TModel::InitModel(int argc, const char ** argv)
LoadParamBool(argc, argv, "mem", &useMem, useMem);
LoadParamBool(argc, argv, "lm", &isLM, true);
LoadParamBool(argc, argv, "mt", &isMT, false);
LoadParamInt(argc, argv, "nhead", &nhead, 8);
if(useMem){
delete mem;
mem = new XMem(devID);
}
encoder.InitModel(argc, argv, isLM, devID, mem);
encoder.InitModel(argc, argv, isLM, isLM ? 1 : 0, devID, mem);
outputLayer.InitModel(argc, argv, devID, mem);
}
......@@ -69,11 +71,12 @@ void T2TModel::InitModel(int argc, const char ** argv)
make the encoding network
>> input - input tensor
>> mask - the mask for positions that are (or are not) involved in the computation
>> skipInputRes - indicates whether we skip the residual connection of the first layer
<< return - encoding result
*/
XTensor T2TModel::MakeEncoding(XTensor &input, XTensor &mask)
XTensor T2TModel::MakeEncoding(XTensor &input, XTensor &mask, bool skipInputRes)
{
return encoder.Make(input, mask);
return encoder.Make(input, mask, skipInputRes);
}
/*
......@@ -88,18 +91,21 @@ void T2TModel::Make(XTensor &input, XTensor &output)
if(isLM){
/* generate mask to see "previous" words only */
int len = input.GetDim(input.order - 2);
int dims[MAX_TENSOR_DIM_NUM];
int * dims = new int[input.order + 1];
for(int i = 0; i < input.order; i++)
dims[i] = input.GetDim(i);
dims[input.order - 1] = len;
XTensor mask(input.order, dims, X_FLOAT, 1.0F, input.devID, input.mem);
dims[i + 1] = input.GetDim(i);
dims[0] = nhead;
dims[input.order] = len;
XTensor mask(input.order + 1, dims, X_FLOAT, 1.0F, input.devID, input.mem);
/* an upper triangular matrix where the cells on and above the main diagonal are set to -1e9 */
_SetDataLowTri(&mask, 1e-9, -1);
_ScaleAndShiftMe(&mask, 1.0F, -1e-9);
_SetDataLowTri(&mask, 1e9F, -1);
_ScaleAndShiftMe(&mask, 1.0F, -1e9F);
encoding = MakeEncoding(input, mask);
encoding = MakeEncoding(input, mask, true);
outputLayer.Make(encoding, output);
delete[] dims;
}
else{
ShowNTErrors("TODO!");
......
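For reference, the mask built above can be reproduced with plain loops. The sketch below mirrors the fixed semantics of _SetDataLowTri(&mask, 1e9F, -1) followed by _ScaleAndShiftMe(&mask, 1.0F, -1e9F) on a single len x len slice: strictly earlier positions end up with 0 and the current/future positions with -1e9, so adding the mask before the softmax suppresses them; row 0 is fully masked, which is presumably why the first position is handled via the `ignored` parameter. This is a standalone illustration, not the XTensor code path:

#include <cstdio>
#include <algorithm>

int main()
{
    const int len = 4;
    const int shift = -1;
    float mask[len][len];

    /* _SetDataLowTri(&mask, 1e9F, shift): p below the (shifted) diagonal, 0 elsewhere */
    for(int row = 0; row < len; row++){
        for(int col = 0; col <= row + shift; col++)
            mask[row][col] = 1e9F;
        for(int col = std::max(0, row + shift + 1); col < len; col++)
            mask[row][col] = 0.0F;
    }

    /* _ScaleAndShiftMe(&mask, 1.0F, -1e9F): every entry is shifted by -1e9 */
    for(int row = 0; row < len; row++)
        for(int col = 0; col < len; col++)
            mask[row][col] = 1.0F * mask[row][col] - 1e9F;

    /* result: 0 where a position may attend to a strictly earlier position,
       -1e9 on and above the diagonal; row 0 is fully masked */
    for(int row = 0; row < len; row++){
        for(int col = 0; col < len; col++)
            printf("% .0e ", mask[row][col]);
        printf("\n");
    }
    return 0;
}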
......@@ -55,6 +55,9 @@ public:
/* indicates whether the model is running for machine translation */
bool isMT;
/* number of heads in the attention model */
int nhead;
public:
/* constructor */
T2TModel();
......@@ -66,7 +69,7 @@ public:
void InitModel(int argc, const char ** argv);
/* make the encoding network */
XTensor MakeEncoding(XTensor &input, XTensor &mask);
XTensor MakeEncoding(XTensor &input, XTensor &mask, bool skipInputRes);
/* make the entire network (with the output softmax layer) */
void Make(XTensor &input, XTensor &output);
......
......@@ -214,6 +214,56 @@ void _SetDataFixedDouble(XTensor * tensor, double p)
}
/*
set data items along with a given dimension (and keep the remaining items unchanged)
>> tensor - the tensor whose data array would be initialized
>> beg - the beginning position
>> len - length along with the given dimension
>> dim - the dimension along which we set the data
>> p - the value used to set the data items
e.g., given a 3 * 3 tensor
1 2 3
4 5 6
7 8 9
when beg = 1, len = 1, dim = 0 and p = 0, we have
1 2 3
0 0 0
7 8 9
i.e., we set all entries of row 1 to 0
*/
void _SetDataDim(XTensor * tensor, int beg, int len, int dim, DTYPE p)
{
int n = tensor->order;
CheckNTErrors(tensor->dataType == DEFAULT_DTYPE, "TODO!");
CheckNTErrors(dim < n && dim >= 0, "Illegal dimension!");
CheckNTErrors(beg >= 0 && beg < tensor->GetDim(dim), "Illegal beginning position!");
CheckNTErrors(beg + len >= 0 && beg + len <= tensor->GetDim(dim), "Illegal length!");
if(tensor->devID < 0){
int stride = 1;
int blockSize = 1;
int blockNum = 1;
for(int i = n - 1; i > dim; i--){
stride *= tensor->GetDim(i);
}
blockSize = stride * tensor->GetDim(dim);
blockNum = tensor->unitNum / blockSize;
int l = len * stride;
for(int i = 0; i < blockNum; i++){
DTYPE * d = (DTYPE*)tensor->data + blockSize * i + beg * stride;
for(int j = 0; j < l; j++)
d[j] = p;
}
}
else{
#ifdef USE_CUDA
_CudaSetDataDim(tensor, beg, len, dim, p);
#endif
}
}
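The CPU branch splits the tensor into blockNum blocks of blockSize elements, where the chosen dimension varies inside a block with step `stride`; setting positions [beg, beg + len) along that dimension then means writing len * stride consecutive elements starting at offset beg * stride inside every block. A small standalone check of that arithmetic, using a hypothetical shape {2, 3, 4} with dim = 1:

#include <cstdio>

int main()
{
    const int order = 3;
    int dims[order] = {2, 3, 4};           /* hypothetical tensor shape */
    int dim = 1, beg = 1, len = 1;

    /* same decomposition as the CPU branch of _SetDataDim */
    int stride = 1;
    for(int i = order - 1; i > dim; i--)
        stride *= dims[i];                  /* elements below the chosen dimension */
    int blockSize = stride * dims[dim];     /* one slice over the chosen dimension */
    int unitNum = 1;
    for(int i = 0; i < order; i++)
        unitNum *= dims[i];
    int blockNum = unitNum / blockSize;     /* slices above the chosen dimension */

    printf("stride=%d blockSize=%d blockNum=%d\n", stride, blockSize, blockNum);

    /* flat offsets that would be overwritten with p */
    for(int i = 0; i < blockNum; i++)
        for(int j = 0; j < len * stride; j++)
            printf("offset %d\n", blockSize * i + beg * stride + j);
    return 0;
}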
/*
generate data as lower triangular matrices for the last two dimensions
>> tensor - the tensor whose data to be set
>> p - the value for each entry of the lower triangular matrices
......@@ -247,10 +297,10 @@ void _SetDataLowTri(XTensor * tensor, DTYPE p, int shift)
for(int i = 0; i < blockNum; i++){
DTYPE * d = (DTYPE*)tensor->data + i * blockSize;
for(int row = 0; row < l; row++){
for(int col = 0; col < row + shift; col++){
d[row * l + col] = row;
for(int col = 0; col <= row + shift; col++){
d[row * l + col] = p;
}
for(int col = MAX(0, row + shift); col < l; col++){
for(int col = MAX(0, row + shift + 1); col < l; col++){
d[row * l + col] = 0;
}
}
......
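The change from `col < row + shift` to `col <= row + shift` (together with writing p instead of row) makes the loop match the documented behaviour: with shift = 0 the main diagonal now receives p, and with shift = -1 it stays 0. A minimal standalone check on a 3 x 3 block with p = 2:

#include <cstdio>
#include <algorithm>

void FillLowTri(float * d, int l, float p, int shift)
{
    /* same loop structure as the fixed CPU code */
    for(int row = 0; row < l; row++){
        for(int col = 0; col <= row + shift; col++)
            d[row * l + col] = p;
        for(int col = std::max(0, row + shift + 1); col < l; col++)
            d[row * l + col] = 0.0F;
    }
}

int main()
{
    float m[9];
    for(int shift = 0; shift >= -1; shift--){
        FillLowTri(m, 3, 2.0F, shift);
        printf("shift = %d:\n", shift);
        for(int row = 0; row < 3; row++)
            printf("%g %g %g\n", m[row * 3], m[row * 3 + 1], m[row * 3 + 2]);
    }
    return 0;
}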
......@@ -185,6 +185,82 @@ void KernelSetDataRandDouble(double * d, int size, DTYPE lower, DTYPE variance)
}
/*
set data items along with a given dimension (and keep the remaining items unchanged) - kernel version
>> d - pointer to the data array
>> beg - the beginning position
>> len - length of the segment to be set
>> blockSize - size of a data block
>> blockNum - number of data blocks
>> p - the value used to set the data items
*/
__global__
void KernelSetDataDim(DTYPE * d, int beg, int len, int blockSize, int blockNum, DTYPE p)
{
/* offset in each block */
int i = blockDim.x * blockIdx.x + threadIdx.x;
/* block id */
int j = blockDim.y * blockIdx.y + threadIdx.y;
if(i >= blockSize || j >= blockNum)
return;
if(i < beg || i >= beg + len)
return;
d[blockSize * j + i] = p;
}
/*
set data items along with a given dimension (and keep the remaining items unchanged) - cuda version
>> tensor - the tensor whose data array would be initialized
>> beg - the beginning position
>> len - length along with the given dimension
>> dim - the dimension along which we set the data
>> p - the value used to set the data items
e.g., given a 3 * 3 tensor
1 2 3
4 5 6
7 8 9
when beg = 1, len = 1, dim = 0 and p = 0, we have
1 2 3
0 0 0
7 8 9
i.e., we set all entries of row 1 to 0
*/
void _CudaSetDataDim(XTensor * tensor, int beg, int len, int dim, DTYPE p)
{
int n = tensor->order;
CheckNTErrors(tensor->dataType == DEFAULT_DTYPE, "TODO!");
CheckNTErrors(dim < n && dim >= 0, "Illegal dimension!");
CheckNTErrors(beg >= 0 && beg < tensor->GetDim(dim), "Illegal beginning position!");
CheckNTErrors(beg + len >= 0 && beg + len <= tensor->GetDim(dim), "Illegal length!");
int stride = 1;
int blockSize = 1;
int blockNum = 1;
for(int i = n - 1; i > dim; i--){
stride *= tensor->GetDim(i);
}
blockSize = stride * tensor->GetDim(dim);
blockNum = tensor->unitNum / blockSize;
int cudaGrids[3];
int cudaBlocks[3];
GDevs.GetCudaThread2D(tensor->devID, blockSize, blockNum, MAX_INT, cudaGrids, cudaBlocks);
dim3 blocks(cudaGrids[0], cudaGrids[1]);
dim3 threads(cudaBlocks[0], cudaBlocks[1]);
int devIDBackup;
ProtectCudaDev(tensor->devID, devIDBackup);
KernelSetDataDim<<<blocks, threads >>>((DTYPE*)tensor->data, beg * stride, len * stride, blockSize, blockNum, p);
BacktoCudaDev(tensor->devID, devIDBackup);
}
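In the kernel, thread index i addresses the offset inside a block and thread index j addresses the block id, with beg and len already multiplied by stride at the launch site. A host-side C++ emulation of that per-thread predicate (hypothetical sizes) produces the same write pattern as the CPU branch:

#include <cstdio>

int main()
{
    /* hypothetical sizes matching the decomposition in _CudaSetDataDim */
    int blockSize = 12, blockNum = 2;
    int beg = 4, len = 4;                   /* already multiplied by stride, as at the launch site */
    float p = 0.0F;
    float d[24];
    for(int k = 0; k < 24; k++)
        d[k] = 1.0F;

    /* emulate one CUDA thread per (i, j) pair */
    for(int j = 0; j < blockNum; j++){      /* block id */
        for(int i = 0; i < blockSize; i++){ /* offset inside the block */
            if(i < beg || i >= beg + len)
                continue;                   /* outside the segment to be set */
            d[blockSize * j + i] = p;
        }
    }

    for(int k = 0; k < 24; k++)
        printf("%g ", d[k]);
    printf("\n");
    return 0;
}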
/*
set lower triangular matrices for each block
>> d - pointer to the data array
>> l - row number (or column number) of each block, i.e.,
......@@ -204,7 +280,7 @@ e.g., for a 3* 3 tensor,
2 2 0
*/
__global__
void _KernalSetDataLowTri(DTYPE * d, int l, int blockSize, int blockNum, DTYPE p, int shift)
void _KernelSetDataLowTri(DTYPE * d, int l, int blockSize, int blockNum, DTYPE p, int shift)
{
/* offset in each block */
int i = blockDim.x * blockIdx.x + threadIdx.x;
......@@ -219,7 +295,7 @@ void _KernalSetDataLowTri(DTYPE * d, int l, int blockSize, int blockNum, DTYPE p
int col = i % l;
DTYPE * d2 = d + blockSize * j + row * l + col;
if(col < row + shift)
if(col <= row + shift)
*d2 = p;
else
*d2 = 0;
......@@ -266,7 +342,7 @@ void _CudaSetDataLowTri(XTensor * tensor, DTYPE p, int shift)
int devIDBackup;
ProtectCudaDev(tensor->devID, devIDBackup);
_KernalSetDataLowTri<<<blocks, threads >>>((DTYPE*)tensor->data, l, blockSize, blockNum, p, shift);
_KernelSetDataLowTri<<<blocks, threads >>>((DTYPE*)tensor->data, l, blockSize, blockNum, p, shift);
BacktoCudaDev(tensor->devID, devIDBackup);
}
......
......@@ -37,6 +37,9 @@ void _CudaSetDataFixedFloat(XTensor * tensor, float p);
/* generate data items with a fixed value p (in double) */
void _CudaSetDataFixedDouble(XTensor * tensor, double p);
/* set data items along with a given dimension (and keep the remaining items unchanged) */
void _CudaSetDataDim(XTensor * tensor, int beg, int len, int dim, DTYPE p);
/* generate data as lower triangular matrices for the last two dimensions (cuda version) */
void _CudaSetDataLowTri(XTensor * tensor, DTYPE p, int shift);
......
......@@ -45,6 +45,9 @@ void _SetDataFixedFloat(XTensor * tensor, float p);
/* generate data items with a fixed value p (in double) */
void _SetDataFixedDouble(XTensor * tensor, double p);
/* set data items along with a given dimension (and keep the remaining items unchanged) */
void _SetDataDim(XTensor * tensor, int beg, int len, int dim, DTYPE p);
/* generate data as lower triangular matrices for the last two dimensions */
void _SetDataLowTri(XTensor * tensor, DTYPE p, int shift);
......