Commit b801df51 by liyinqiao

Update the codes of Transformer sample and XList class.

1. Update the codes of machine translation sample. The current version is the same with NiuTrans.NMT.
2. Update the XList class.
3. Bugs fix.
parent 8178ba40
......@@ -128,8 +128,10 @@ int FNNLMMain(int argc, const char ** argv)
Init(model);
/* learn model parameters */
if(strcmp(trainFN, ""))
if(strcmp(trainFN, "")) {
ENABLE_GRAD;
Train(trainFN, shuffled, model);
}
/* save the final model */
if(strcmp(modelFN, "") && strcmp(trainFN, ""))
......
/* NiuTrans.Tensor - an open-source tensor library
* Copyright (C) 2018, Natural Language Processing Lab, Northeastern University.
* All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
/*
* $Created by: XIAO Tong (xiaotong@mail.neu.edu.cn) 2018-07-31
*/
#include <math.h>
#include "T2TAttention.h"
#include "T2TUtility.h"
#include "T2TEmbedding.h"
#include "../../tensor/core/CHeader.h"
namespace transformer
{
/* constructor */
T2TAttention::T2TAttention()
{
nhead = -1;
dk = -1;
dv = -1;
d = -1;
isMasked = false;
ignored = 0;
}
/* deconstructor */
T2TAttention::~T2TAttention()
{
}
/*
initialize the model
>> argc - number of arguments
>> argv - list of pointers to the arguments
>> myIgnored - number of position ignored in attention (from the begining)
>> myIsMasked - indicates whether the attention is with a mask
>> myDevID - device id
*/
void T2TAttention::InitModel(int argc, char ** argv,
bool myIsMasked, int myIgnored,
int myDevID)
{
devID = myDevID;
isMasked = myIsMasked;
ignored = myIgnored;
float minmax = 0;
LoadParamInt(argc, argv, "nhead", &nhead, 8);
LoadParamInt(argc, argv, "d", &dk, DEFAULT_EMBEDDING_SIZE);
LoadParamInt(argc, argv, "d", &dv, DEFAULT_EMBEDDING_SIZE);
LoadParamInt(argc, argv, "d", &d, DEFAULT_EMBEDDING_SIZE);
LoadParamFloat(argc, argv, "attminmax", &minmax, 0.1F);
LoadParamFloat(argc, argv, "dropoutatt", &dropoutP, 0);
InitTensor2D(&wk, d, dk, X_FLOAT, devID);
InitTensor2D(&wq, d, dk, X_FLOAT, devID);
InitTensor2D(&wv, d, dv, X_FLOAT, devID);
InitTensor2D(&wa, d, d, X_FLOAT, devID);
InitTensor2D(&wbig, d, 3 * d, X_FLOAT, devID);
float scale = 1.0F;
_SetDataFanInOut(&wk, scale);
_SetDataFanInOut(&wq, scale);
_SetDataFanInOut(&wv, scale);
_SetDataFanInOut(&wa, scale);
_SetDataFanInOut(&wbig, scale);
}
/*
make the network
>> k - keys. It might be of size B * L * H
where B = batch size, L = sequence length,
and H = vector size of each position
>> q - queries
>> v - values
>> mask - as it is
>> isTraining - indicates whether the model is used for training
<< return - multi-attention result
*/
XTensor T2TAttention::Make(XTensor &k, XTensor &q, XTensor &v, XTensor &mask, bool isTraining)
{
XTensor k2;
XTensor q2;
XTensor v2;
/* linear transformation before self-attention */
k2 = MMul(k, wk);
q2 = MMul(q, wq);
v2 = MMul(v, wv);
return MakeAttention(k2, q2, v2, mask, isTraining);
}
/*
make the network given a big tensor that keeps keys, queries and values
>> kqv - the big tensor
>> mask - as it is
>> isTraining - indicates whether the model is used for training
*/
XTensor T2TAttention::MakeBig(XTensor &kqv, XTensor &mask, bool isTraining)
{
XTensor k2;
XTensor q2;
XTensor v2;
XTensor kqv2;
TensorList split;
kqv2 = MMul(kqv, wbig);
int d1 = kqv2.GetDim(0);
int d2 = kqv2.GetDim(1);
int d3 = kqv2.GetDim(2) / 3;
InitTensor3D(&k2, d1, d2, d3, X_FLOAT, devID);
InitTensor3D(&q2, d1, d2, d3, X_FLOAT, devID);
InitTensor3D(&v2, d1, d2, d3, X_FLOAT, devID);
split.Add(&q2);
split.Add(&k2);
split.Add(&v2);
Split(kqv2, split, 2, 3);
return MakeAttention(k2, q2, v2, mask, isTraining);
}
/*
make the attention network given keys, queries and values (after linear transformation)
>> k - keys. It might be of size B * L * H
where B = batch size, L = sequence length,
and H = vector size of each position
>> q - queries
>> v - values
>> mask - as it is
>> isTraining - indicates whether the model is used for training
*/
XTensor T2TAttention::MakeAttention(XTensor &k, XTensor &q, XTensor &v, XTensor &mask, bool isTraining)
{
XTensor kheads;
XTensor qheads;
XTensor vheads;
/* multi head */
kheads = Split(k, k.order - 1, nhead);
qheads = Split(q, q.order - 1, nhead);
vheads = Split(v, v.order - 1, nhead);
XTensor att;
XTensor dot;
XTensor scalar;
/* scalar = softmax(Q * K^T / sqrt(dk)) * V */
dot = BMMul(qheads, X_NOTRANS, kheads, X_TRANS);
if(isMasked)
dot = dot + mask;
dot = Linear(dot, 1.0F/(float)sqrt((float)dk/nhead));
scalar = Softmax(dot, -1);
if(isTraining && dropoutP > 0)
scalar = Dropout(scalar, dropoutP);
att = BMMul(scalar, vheads);
/* concatenate the heads */
return MMul(Merge(att, att.order - 1), wa);
}
}
/* NiuTrans.Tensor - an open-source tensor library
* Copyright (C) 2018, Natural Language Processing Lab, Northeastern University.
* Copyright (C) 2020, Natural Language Processing Lab, Northeastern University.
* All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
......@@ -17,12 +17,15 @@
/*
* $Created by: XIAO Tong (xiaotong@mail.neu.edu.cn) 2018-10-09
* $Modified by: HU Chi (huchinlp@gmail.com) 2020-04
*/
#include <math.h>
#include <cmath>
#include "T2TDecoder.h"
#include "T2TUtility.h"
#include "T2TLayerNormal.h"
#include "module/T2TUtility.h"
#include "module/T2TLayerNormal.h"
#include "module/T2TCommonModules.h"
#include "../../tensor/core/CHeader.h"
namespace transformer
......@@ -31,145 +34,162 @@ namespace transformer
/* constructor */
AttDecoder::AttDecoder()
{
attentions = NULL;
selfAtt = NULL;
fnns = NULL;
attLayerNorms = NULL;
selfAttLayerNorms = NULL;
fnnLayerNorms = NULL;
attentionsEnde = NULL;
attEndeLayerNorms = NULL;
enDeAtt = NULL;
enDeAttLayerNorms = NULL;
decoderLayerNorm = NULL;
selfAttCache = NULL;
enDeAttCache = NULL;
}
/* de-constructor */
AttDecoder::~AttDecoder()
{
delete[] attentions;
delete[] selfAttCache;
delete[] enDeAttCache;
delete[] selfAtt;
delete[] fnns;
delete[] attLayerNorms;
delete[] selfAttLayerNorms;
delete[] fnnLayerNorms;
delete[] attentionsEnde;
delete[] attEndeLayerNorms;
delete[] enDeAtt;
delete[] enDeAttLayerNorms;
if (preNorm)
delete decoderLayerNorm;
}
/*
initialize the model
>> argc - number of arguments
>> argv - list of pointers to the arguments
>> myIsMasked - indicates whether the masked attention is employed
>> myIgnored - number of positions ignored in attention (from the start)
>> myDevID - device id
/*
initialize the model
>> config - configurations of the model
*/
void AttDecoder::InitModel(int argc, char ** argv,
bool myIsMasked, int myIgnored,
int myDevID)
void AttDecoder::InitModel(T2TConfig& config)
{
//AttEncoder::InitModel(argc, argv, myIsMasked, myIgnored, myDevID);
devID = myDevID;
ignored = myIgnored;
LoadParamInt(argc, argv, "nlayer", &nlayer, 6);
LoadParamInt(argc, argv, "hsize", &hSize, DEFAULT_EMBEDDING_SIZE);
LoadParamInt(argc, argv, "esize", &eSize, DEFAULT_EMBEDDING_SIZE);
LoadParamInt(argc, argv, "vsizetgt", &vSize, -1);
LoadParamFloat(argc, argv, "dropout", &dropoutP, 0);
devID = config.devID;
nlayer = config.nDecLayer;
hSize = config.modelSize;
eSize = config.embSize;
vSize = config.tgtVocabSize;
dropoutP = config.dropout;
preNorm = config.preNorm;
CheckNTErrors(nlayer >= 1, "We have one encoding layer at least!");
CheckNTErrors(vSize > 1, "set vocabulary size by \"-vsizetgt\"");
/* embedding model */
embedder.InitModel(argc, argv, devID, false);
embedder.InitModel(config, false);
attentions = new T2TAttention[nlayer];
selfAtt = new T2TAttention[nlayer];
fnns = new T2TFNN[nlayer];
attLayerNorms = new T2TLN[nlayer];
selfAttLayerNorms = new T2TLN[nlayer];
enDeAtt = new T2TAttention[nlayer];
enDeAttLayerNorms = new T2TLN[nlayer];
fnnLayerNorms = new T2TLN[nlayer];
attentionsEnde = new T2TAttention[nlayer];
attEndeLayerNorms = new T2TLN[nlayer];
selfAttCache = new Cache[nlayer];
enDeAttCache = new Cache[nlayer];
if (preNorm)
decoderLayerNorm = new T2TLN;
/* initialize the stacked layers */
for (int i = 0; i < nlayer; i++) {
attentions[i].InitModel(argc, argv, myIsMasked, myIgnored, myDevID);
fnns[i].InitModel(argc, argv, myDevID);
attLayerNorms[i].InitModel(argc, argv, myDevID);
fnnLayerNorms[i].InitModel(argc, argv, myDevID);
attentionsEnde[i].InitModel(argc, argv, true, myIgnored, myDevID);
attEndeLayerNorms[i].InitModel(argc, argv, myDevID);
selfAtt[i].InitModel(config);
fnns[i].InitModel(config);
selfAttLayerNorms[i].InitModel(config);
fnnLayerNorms[i].InitModel(config);
enDeAtt[i].InitModel(config);
enDeAttLayerNorms[i].InitModel(config);
}
if (preNorm)
decoderLayerNorm->InitModel(config);
}
/*
/*
make the decoding network
>> inputDec - the input tensor of the decoder
>> outputEnc - the output tensor of the encoder
>> mask - mask that indicates which position is valid
>> maskEncDec - mask for the encoder-decoder attention
>> nstep - the current length of the decoder input
>> isTraining - indicates whether the model is used for training
<< return - the output tensor of the encoder
<< return - the output tensor of the decoder
*/
XTensor AttDecoder::Make(XTensor &inputDec, XTensor &outputEnc, XTensor &mask, XTensor &maskEncDec, bool isTraining)
XTensor AttDecoder::Make(XTensor& inputDec, XTensor& outputEnc, XTensor* mask,
XTensor* maskEncDec, int nstep, bool isTraining)
{
XTensor x;
x = embedder.Make(inputDec);
x = embedder.Make(inputDec, true, isTraining, nstep);
/* dropout */
if(isTraining && dropoutP > 0)
if (isTraining && dropoutP > 0)
x = Dropout(x, dropoutP);
for(int i = 0; i < nlayer; i++){
for (int i = 0; i < nlayer; i++) {
XTensor att;
XTensor ende;
XTensor ln;
XTensor fnn;
XTensor res;
XTensor selfAttnBefore;
XTensor selfAttnAfter;
XTensor endeAttnBefore;
XTensor endeAttnAfter;
XTensor fnnBefore;
/* layer normalization with pre-norm for self-attn */
selfAttnBefore = LayerNorm(x, selfAttLayerNorms[i], preNorm, true, false);
/******************/
/* self attention */
att = attentions[i].MakeBig(x, mask, isTraining);
att = selfAtt[i].Make(selfAttnBefore, selfAttnBefore, selfAttnBefore,
mask, isTraining, &selfAttCache[i], SELF_ATT);
/* dropout */
if(isTraining && dropoutP > 0)
if (isTraining && dropoutP > 0)
att = Dropout(att, dropoutP);
/* residual connection */
res = Sum(att, x);
/* layer normalization */
x = attLayerNorms[i].Make(res);
/* layer normalization with post-norm for self-attention */
selfAttnAfter = LayerNorm(res, selfAttLayerNorms[i], preNorm, false, true);
/* layer normalization with pre-norm for encoder-decoder attention */
endeAttnBefore = LayerNorm(selfAttnAfter, enDeAttLayerNorms[i], preNorm, true, false);
/*****************************/
/* encoder-decoder attention */
ende = attentionsEnde[i].Make(outputEnc, x, outputEnc, maskEncDec, isTraining);
ende = enDeAtt[i].Make(outputEnc, endeAttnBefore, outputEnc, maskEncDec,
isTraining, &enDeAttCache[i], EN_DE_ATT);
/* dropout */
if(isTraining && dropoutP > 0)
if (isTraining && dropoutP > 0)
ende = Dropout(ende, dropoutP);
/* residual connection */
res = Sum(ende, x);
res = Sum(ende, selfAttnAfter);
/* layer normalization */
x = attEndeLayerNorms[i].Make(res);
/* layer normalization with post-norm for encoder-decoder attention */
endeAttnAfter = LayerNorm(res, enDeAttLayerNorms[i], preNorm, false, true);
/* layer normalization with pre-norm for fnn */
fnnBefore = LayerNorm(endeAttnAfter, fnnLayerNorms[i], preNorm, true, false);
/*******/
/* fnn */
fnn = fnns[i].Make(x, isTraining);
fnn = fnns[i].Make(fnnBefore, isTraining);
/* dropout */
if(isTraining && dropoutP > 0)
if (isTraining && dropoutP > 0)
fnn = Dropout(fnn, dropoutP);
/* residual connection */
res = Sum(fnn, x);
res = Sum(fnn, endeAttnAfter);
/* layer normalization */
x = fnnLayerNorms[i].Make(res);
/* layer normalization with post-norm for fnn */
x = LayerNorm(res, fnnLayerNorms[i], preNorm, false, true);
}
x.SetName(DECODING_NAME);
return x;
}
if (preNorm)
x = decoderLayerNorm->Make(x);
return x;
}
}
\ No newline at end of file
/* NiuTrans.Tensor - an open-source tensor library
* Copyright (C) 2018, Natural Language Processing Lab, Northeastern University.
* Copyright (C) 2020, Natural Language Processing Lab, Northeastern University.
* All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
......@@ -17,18 +17,17 @@
/*
* $Created by: XIAO Tong (xiaotong@mail.neu.edu.cn) 2018-07-31
* $Modified by: HU Chi (huchinlp@gmail.com) 2020-04
*/
#ifndef __T2TDECODER_H__
#define __T2TDECODER_H__
#include "T2TEncoder.h"
#include "module/T2TUtility.h"
namespace transformer
{
#define DECODING_NAME "decoding"
#define DECODING_INPUT_NAME "decoding_input"
class AttDecoder
{
......@@ -52,50 +51,52 @@ public:
/* dropout probability */
DTYPE dropoutP;
/* some positions can be ignored in attention. this is useful in lm where the first position needs
* special design for the attention model. */
int ignored;
/* embedding of word at each position */
T2TEmbedder embedder;
/* FNN model of each layer */
T2TFNN * fnns;
T2TFNN* fnns;
/* attention model of each layer */
T2TAttention * attentions;
/* layer normalization for fnn */
T2TLN * fnnLayerNorms;
T2TAttention* selfAtt;
/* layer normalization for attention */
T2TLN * attLayerNorms;
T2TLN* selfAttLayerNorms;
/* input tensor of the encoder */
XTensor * input;
/* layer normalization for fnn */
T2TLN* fnnLayerNorms;
/* output tensor of the encoder */
XTensor * output;
/* layer normalization for decoder */
T2TLN* decoderLayerNorm;
/* encoder-decoder attention model of each layer */
T2TAttention * attentionsEnde;
T2TAttention* enDeAtt;
/* layer normalization for encoder-decoder attention */
T2TLN * attEndeLayerNorms;
T2TLN* enDeAttLayerNorms;
/* layer cache list */
Cache* selfAttCache;
/* layer cache list */
Cache* enDeAttCache;
/* the location of layer normalization */
bool preNorm;
public:
/* constructor */
AttDecoder();
/* deconstructor */
/* de-constructor */
~AttDecoder();
/* initialize the model */
void InitModel(int argc, char ** argv,
bool myIsMasked, int myIgnored,
int myDevID = -1);
void InitModel(T2TConfig& config);
/* make the decoding network */
XTensor Make(XTensor &inputDec, XTensor &outputEnc, XTensor &mask, XTensor &maskEncDec, bool isTraining);
XTensor Make(XTensor& inputDec, XTensor& outputEnc, XTensor* mask,
XTensor* maskEncDec, int nstep, bool isTraining);
};
}
......
/* NiuTrans.Tensor - an open-source tensor library
* Copyright (C) 2018, Natural Language Processing Lab, Northeastern University.
* All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
/*
* $Created by: XIAO Tong (xiaotong@mail.neu.edu.cn) 2018-08-01
*/
#include <math.h>
#include "T2TEmbedding.h"
#include "T2TUtility.h"
#include "../../tensor/core/CHeader.h"
namespace transformer
{
/* constructor */
T2TEmbedder::T2TEmbedder()
{
devID = -1;
vSize = -1;
maxLength = -1;
}
/* deconstructor */
T2TEmbedder::~T2TEmbedder()
{
}
/*
initialize the model
>> argc - number of arguments
>> argv - list of pointers to the arguments
>> myDevID - device id
*/
void T2TEmbedder::InitModel(int argc, char ** argv, int myDevID, bool isEnc)
{
devID = myDevID;
if(isEnc){
LoadParamInt(argc, argv, "vsize", &vSize, -1);
}
else{
LoadParamInt(argc, argv, "vsizetgt", &vSize, -1);
}
//LoadParamInt(argc, argv, "vsize", &vSize, -1);
LoadParamInt(argc, argv, "maxlen", &maxLength, 512);
LoadParamInt(argc, argv, "d", &eSize, DEFAULT_EMBEDDING_SIZE);
LoadParamInt(argc, argv, "d", &d, DEFAULT_EMBEDDING_SIZE);
InitTensor2D(&w, vSize, eSize, X_FLOAT, devID);
DTYPE v = 1.0F/(float)sqrt((float)eSize);
w.SetDataRandn(0, v);
/* create the positional embedding matrix */
MakePosEmbedding(eSize, d, maxLength);
}
/*
make positional embeddings (of size eSize * length)
>> eSize - embedding size
>> d - dimension size of the hidden layers
>> length - length of the sequence
*/
void T2TEmbedder::MakePosEmbedding(int eSize, int d, int length)
{
InitTensor2D(&posEmbeddingBase, length, eSize, X_FLOAT, devID);
float * data = new float[posEmbeddingBase.unitNum];
for(int pos = 0; pos < length; pos++){
float * dp = data + pos * eSize;
int channelSize = eSize / 2;
int offset = 0;
for(int i = 0; i < channelSize; i++){
dp[offset++] = (float)sin(pos/pow(10000.0F, 2.0F*i/(d - 2)));
}
for(int i = 0; i < channelSize; i++){
dp[offset++] = (float)cos(pos/pow(10000.0F, 2.0F*i/(d - 2)));
}
/*
for(int k = 0; k < eSize; k++){
if(k % 2 == 0){
int i = k/2;
dp[k] = (float)sin(pos/pow(10000.0F, 2.0F*i/d));
}
else{
int i = (k - 1)/2;
dp[k] = (float)cos(pos/pow(10000.0F, 2.0F*i/d));
}
}
*/
}
posEmbeddingBase.SetData(data, posEmbeddingBase.unitNum);
delete[] data;
}
/*
make the network
*/
XTensor T2TEmbedder::Make(XTensor &input)
{
//CheckNTErrors(input.GetDim(-1) == vSize, "Wrong vocabulary size!");
CheckNTErrors(input.order > 1, "Wrong input tensor size!");
CheckNTErrors(input.dimSize[input.order - 1] < maxLength, "The sequence is too long!");
CheckNTErrors(vSize > 0, "set vocabulary size by \"-vsize\"");
CheckNTErrors(eSize > 0, "set embedding size by \"-esize\"");
int dims[MAX_TENSOR_DIM_NUM];
memcpy(dims, input.dimSize, input.order * sizeof(int));
dims[input.order] = eSize;
XTensor wordEmbedding;
XTensor posEmbedding;
/* make positional embeddings */
XTensor position;
XTensor embTMP;
InitTensor1D(&position, input.GetDim(-1), X_INT, devID);
position.Range(0, position.unitNum, 1);
embTMP = Gather(posEmbeddingBase, position);
posEmbedding = Unsqueeze(embTMP, 0, dims[0]);
/* make word embeddings */
wordEmbedding = Gather(w, input);
wordEmbedding = Linear(wordEmbedding, (float)sqrt((float)eSize));
/* sum over the two embeddings */
return wordEmbedding + posEmbedding;
}
}
/* NiuTrans.Tensor - an open-source tensor library
* Copyright (C) 2018, Natural Language Processing Lab, Northeastern University.
* Copyright (C) 2020, Natural Language Processing Lab, Northeastern University.
* All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
......@@ -17,12 +17,15 @@
/*
* $Created by: XIAO Tong (xiaotong@mail.neu.edu.cn) 2018-07-31
* $Modified by: HU Chi (huchinlp@gmail.com) 2020-04
*/
#include <math.h>
#include <cmath>
#include "T2TEncoder.h"
#include "T2TLayerNormal.h"
#include "T2TUtility.h"
#include "module/T2TUtility.h"
#include "module/T2TLayerNormal.h"
#include "module/T2TCommonModules.h"
#include "../../tensor/core/CHeader.h"
namespace transformer
......@@ -31,62 +34,65 @@ namespace transformer
/* constructor */
AttEncoder::AttEncoder()
{
attentions = NULL;
selfAtt = NULL;
fnns = NULL;
attLayerNorms = NULL;
fnnLayerNorms = NULL;
encoderLayerNorm = NULL;
}
/* de-constructor */
AttEncoder::~AttEncoder()
{
delete[] attentions;
delete[] selfAtt;
delete[] fnns;
delete[] attLayerNorms;
delete[] fnnLayerNorms;
if (preNorm)
delete encoderLayerNorm;
}
/*
initialize the model
>> argc - number of arguments
>> argv - list of pointers to the arguments
>> myIsMasked - indicates whether the masked attention is employed
>> myIgnored - number of positions ignored in attention (from the start)
>> myDevID - device id*/
void AttEncoder::InitModel(int argc, char ** argv,
bool myIsMasked, int myIgnored,
int myDevID)
/*
initialize the model
>> config - configurations for the model
*/
void AttEncoder::InitModel(T2TConfig& config)
{
devID = myDevID;
ignored = myIgnored;
LoadParamInt(argc, argv, "nlayer", &nlayer, 6);
LoadParamInt(argc, argv, "hsize", &hSize, DEFAULT_EMBEDDING_SIZE);
LoadParamInt(argc, argv, "esize", &eSize, DEFAULT_EMBEDDING_SIZE);
LoadParamInt(argc, argv, "vsize", &vSize, -1);
LoadParamFloat(argc, argv, "dropout", &dropoutP, 0);
devID = config.devID;
nlayer = config.nEncLayer;
eSize = config.embSize;
hSize = config.modelSize;
vSize = config.srcVocabSize;
preNorm = config.preNorm;
dropoutP = config.dropout;
CheckNTErrors(nlayer >= 1, "We have one encoding layer at least!");
CheckNTErrors(vSize > 1, "set vocabulary size by \"-vsize\"");
/* embedding model */
embedder.InitModel(argc, argv, devID);
embedder.InitModel(config);
attentions = new T2TAttention[nlayer];
selfAtt = new T2TAttention[nlayer];
fnns = new T2TFNN[nlayer];
attLayerNorms = new T2TLN[nlayer];
fnnLayerNorms = new T2TLN[nlayer];
if (preNorm)
encoderLayerNorm = new T2TLN;
/* initialize the stacked layers */
for(int i = 0; i < nlayer; i++){
attentions[i].InitModel(argc, argv, myIsMasked, myIgnored, myDevID);
fnns[i].InitModel(argc, argv, myDevID);
attLayerNorms[i].InitModel(argc, argv, myDevID);
fnnLayerNorms[i].InitModel(argc, argv, myDevID);
for (int i = 0; i < nlayer; i++) {
selfAtt[i].InitModel(config);
fnns[i].InitModel(config);
attLayerNorms[i].InitModel(config);
fnnLayerNorms[i].InitModel(config);
}
if (preNorm)
encoderLayerNorm->InitModel(config);
}
/*
/*
make the encoding network
>> input - the input tensor of the encoder
>> mask - the mask that indicate each position is valid
......@@ -94,67 +100,74 @@ make the encoding network
>> isTraining - indicates whether the model is used for training
<< return - the output tensor of the encoder
*/
XTensor AttEncoder::Make(XTensor &input, XTensor &mask, XTensor &maskEncDec, bool isTraining)
XTensor AttEncoder::Make(XTensor& input, XTensor* mask, XTensor& maskEncDec, bool isTraining)
{
XTensor x;
x = embedder.Make(input);
x = embedder.Make(input, false, isTraining);
/* dropout */
if(isTraining && dropoutP > 0)
if (isTraining && dropoutP > 0)
x = Dropout(x, dropoutP);
for(int i = 0; i < nlayer; i++){
for (int i = 0; i < nlayer; i++) {
XTensor att;
XTensor ln;
XTensor fnn;
XTensor res;
XTensor attnBefore;
XTensor attnAfter;
XTensor fnnBefore;
/* layer normalization with pre-norm for self-attn */
attnBefore = LayerNorm(x, attLayerNorms[i], preNorm, true, false);
/* self attention */
att = attentions[i].MakeBig(x, mask, isTraining);
att = selfAtt[i].Make(attnBefore, attnBefore, attnBefore, mask, isTraining, NULL, 0);
/* dropout */
if(isTraining && dropoutP > 0)
if (isTraining && dropoutP > 0)
att = Dropout(att, dropoutP);
/* residual connection */
res = Sum(att, x);
/* layer normalization */
x = attLayerNorms[i].Make(res);
/* layer normalization with post-norm for self-attn */
attnAfter = LayerNorm(res, attLayerNorms[i], preNorm, false, true);
/* layer normalization with pre-norm for fnn */
fnnBefore = LayerNorm(attnAfter, fnnLayerNorms[i], preNorm, true, false);
/* fnn */
fnn = fnns[i].Make(x, isTraining);
fnn = fnns[i].Make(fnnBefore, isTraining);
/* dropout */
if(isTraining && dropoutP > 0)
if (isTraining && dropoutP > 0)
fnn = Dropout(fnn, dropoutP);
/* residual connection */
res = Sum(fnn, x);
res = Sum(fnn, attnAfter);
/* layer normalization */
x = fnnLayerNorms[i].Make(res);
/* layer normalization with post-norm for fnn */
x = LayerNorm(res, fnnLayerNorms[i], preNorm, false, true);
}
x.SetName(ENCODING_NAME);
input.SetName(ENCODING_INPUT_NAME);
if (preNorm)
x = encoderLayerNorm->Make(x);
return x;
}
/*
make the encoding network (wrapper)
make the encoding network (wrapper)
>> input - the input tensor of the encoder
>> mask - the mask that indicate each position is valid
>> isTraining - indicates whether the model is used for training
<< return - the output tensor of the encoder
*/
XTensor AttEncoder::Make(XTensor &input, XTensor &mask, bool isTraining)
XTensor AttEncoder::Make(XTensor& input, XTensor* mask, bool isTraining)
{
XTensor nothing;
return Make(input, mask, nothing, isTraining);
}
}
}
\ No newline at end of file
/* NiuTrans.Tensor - an open-source tensor library
* Copyright (C) 2018, Natural Language Processing Lab, Northeastern University.
* Copyright (C) 2020, Natural Language Processing Lab, Northeastern University.
* All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
......@@ -17,47 +17,35 @@
/*
* $Created by: XIAO Tong (xiaotong@mail.neu.edu.cn) 2018-07-31
* $Modified by: HU Chi (huchinlp@gmail.com) 2020-04
*/
#ifndef __T2TENCODER_H__
#define __T2TENCODER_H__
#include "T2TFNN.h"
#include "T2TAttention.h"
#include "T2TEmbedding.h"
#include "T2TLayerNormal.h"
#include "module/T2TFNN.h"
#include "module/T2TUtility.h"
#include "module/T2TAttention.h"
#include "module/T2TEmbedding.h"
#include "module/T2TLayerNormal.h"
#include "../../network/XNet.h"
using namespace nts;
namespace transformer
{
#define ENCODING_NAME "encoding"
#define ENCODING_INPUT_NAME "encoding_input"
/*
base class of the encoder
/*
base class of the encoder
*/
class T2TEncoder
{
public:
virtual
XTensor Make(XTensor &input, XTensor &mask, XTensor &mask2, bool isTraining) = 0;
};
/*
the encoder based on RNN
*/
class RNNEncoder : T2TEncoder
{
public:
XTensor Make(XTensor &input, XTensor &mask, XTensor &mask2, bool isTraining);
virtual XTensor Make(XTensor& input, XTensor* mask, XTensor& mask2, bool isTraining) = 0;
};
/*
the encoder based on self-attention
/*
the encoder based on self-attention
*/
class AttEncoder : T2TEncoder
{
......@@ -88,23 +76,23 @@ public:
T2TEmbedder embedder;
/* FNN model of each layer */
T2TFNN * fnns;
T2TFNN* fnns;
/* attention model of each layer */
T2TAttention * attentions;
T2TAttention* selfAtt;
/* layer normalizations for attention */
T2TLN* attLayerNorms;
/* layer normalization for fnn */
T2TLN * fnnLayerNorms;
T2TLN* fnnLayerNorms;
/* layer normalization for attention */
T2TLN * attLayerNorms;
/* layer normalization for encoder */
T2TLN* encoderLayerNorm;
/* input tensor of the encoder */
XTensor * input;
/* the location of layer normalization */
bool preNorm;
/* output tensor of the encoder */
XTensor * output;
public:
/* constructor */
AttEncoder();
......@@ -113,18 +101,15 @@ public:
~AttEncoder();
/* initialize the model */
void InitModel(int argc, char ** argv,
bool myIsMasked, int myIgnored,
int myDevID = -1);
void InitModel(T2TConfig& config);
/* make the encoding network */
XTensor Make(XTensor &input, XTensor &mask, XTensor &maskEncDec, bool isTraining);
XTensor Make(XTensor& input, XTensor* mask, XTensor& maskEncDec, bool isTraining);
/* make the encoding network (wrapper) */
XTensor Make(XTensor &input, XTensor &mask, bool isTraining);
XTensor Make(XTensor& input, XTensor* mask, bool isTraining);
};
}
#endif
/* NiuTrans.Tensor - an open-source tensor library
* Copyright (C) 2018, Natural Language Processing Lab, Northeastern University.
* Copyright (C) 2020, Natural Language Processing Lab, Northeastern University.
* All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
......@@ -17,16 +17,18 @@
/*
* $Created by: XIAO Tong (xiaotong@mail.neu.edu.cn) 2018-07-31
* $Modified by: HU Chi (huchinlp@gmail.com) 2020-04
*/
#ifndef __T2TMODEL_H__
#define __T2TMODEL_H__
#include "T2TFNN.h"
#include "T2TAttention.h"
#include "T2TEncoder.h"
#include "T2TDecoder.h"
#include "T2TOutput.h"
#include "module/T2TFNN.h"
#include "module/T2TOutput.h"
#include "module/T2TUtility.h"
#include "module/T2TAttention.h"
namespace transformer
{
......@@ -41,13 +43,13 @@ public:
int devID;
/* the encoder */
AttEncoder * encoder;
AttEncoder* encoder;
/* the decoder */
AttDecoder * decoder;
AttDecoder* decoder;
/* output layer */
T2TOutput * outputLayer;
T2TOutput* outputLayer;
/* indicates whether the model is running for language modeling */
bool isLM;
......@@ -55,9 +57,18 @@ public:
/* indicates whether the model is running for machine translation */
bool isMT;
/* indicates whether the model is running with FP16 data type */
bool useFP16;
/* number of heads in the attention model */
int nhead;
/* indicates whether share encoders embeddings with decoders */
int shareAllEmbeddings;
/* indicates whether share decoder embeddings with output weights */
int shareDecInputOutputWeight;
public:
/* constructor */
T2TModel();
......@@ -66,41 +77,42 @@ public:
~T2TModel();
/* initialize the model */
void InitModel(int argc, char ** argv);
void InitModel(T2TConfig& config);
/* make the encoding network */
XTensor MakeEncoder(XTensor &input, XTensor &mask, bool isTraining);
XTensor MakeEncoder(XTensor& input, XTensor* mask, bool isTraining);
/* make the encoding network */
XTensor MakeDecoder(XTensor &inputEnc, XTensor &inputDec, XTensor &mask, XTensor &MaskEncDec, bool isTraining);
XTensor MakeDecoder(XTensor& inputEnc, XTensor& inputDec, XTensor* mask,
XTensor& MaskEncDec, bool isTraining);
/* make the network for langauge modeling (with the output softmax layer) */
void MakeLM(XTensor &input, XTensor &output, XTensor &padding, bool isTraining);
/* make the network for language modeling (with the output softmax layer) */
void MakeLM(XTensor& input, XTensor& output, XTensor& padding, bool isTraining);
/* make the network for machine translation (with the output softmax layer) */
void MakeMT(XTensor &inputEnc, XTensor &inputDec, XTensor &output,
XTensor &paddingEnc, XTensor &paddingDec, bool isTraining);
void MakeMT(XTensor& inputEnc, XTensor& inputDec, XTensor& output,
XTensor& paddingEnc, XTensor& paddingDec, bool isTraining);
/* make the mask for training MT models */
void MakeMTMask(XTensor &inputEnc, XTensor &inputDec,
XTensor &paddingEnc, XTensor &paddingDec,
XTensor &maskEnc, XTensor &maskDec, XTensor &maskEncDec);
void MakeMTMask(XTensor& inputEnc, XTensor& inputDec,
XTensor& paddingEnc, XTensor& paddingDec,
XTensor& maskEnc, XTensor& maskDec, XTensor& maskEncDec);
/* make the mask of the encoder */
void MakeMTMaskEnc(XTensor &paddingEnc, XTensor &maskEnc);
void MakeMTMaskEnc(XTensor& paddingEnc, XTensor& maskEnc);
/* make the mask of the decoder */
void MakeMTMaskDec(XTensor &paddingEnc, XTensor &paddingDec,
XTensor &maskDec, XTensor &maskEncDec);
void MakeMTMaskDec(XTensor& paddingEnc, XTensor& paddingDec,
XTensor& maskDec, XTensor& maskEncDec);
/* get parameter matrics */
void GetParams(TensorList &list);
/* get parameter matrices */
void GetParams(TensorList& list);
/* dump the parameters */
void Dump(const char * fn);
/* dump the model to a file */
void Dump(const char* fn);
/* read the parameters */
void Read(const char * fn);
void Read(FILE* file);
};
}
......
/* NiuTrans.Tensor - an open-source tensor library
* Copyright (C) 2019, Natural Language Processing Lab, Northeastern University.
* All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
/*
* $Created by: XIAO Tong (xiaotong@mail.neu.edu.cn) 2019-03-27
*/
#include <math.h>
#include "T2TUtility.h"
#include "T2TTester.h"
#include "T2TSearch.h"
#include "../../tensor/XUtility.h"
#include "../../tensor/core/CHeader.h"
#include "../../network/XNoder.h"
using namespace nts;
namespace transformer
{
/* constructor */
T2TTester::T2TTester()
{
}
/* de-constructor */
T2TTester::~T2TTester()
{
}
/* initialize the model */
void T2TTester::Init(int argc, char ** argv)
{
LoadParamInt(argc, argv, "vsize", &vSize, 1);
LoadParamInt(argc, argv, "vsizetgt", &vSizeTgt, vSize);
batchLoader.Init(argc, argv);
seacher.Init(argc, argv);
}
/*
test the model
>> fn - test data file
>> ofn - output data file
>> model - model that is trained
*/
void T2TTester::Test(const char * fn, const char * ofn, T2TModel * model)
{
int wc = 0;
int ws = 0;
int wordCount = 0;
int wordCountTotal = 0;
int sentCount = 0;
int batchCount = 0;
float loss = 0;
/* data files */
FILE * file = fopen(fn, "rb");
CheckNTErrors(file, "Cannot read the test file");
FILE * ofile = fopen(ofn, "wb");
CheckNTErrors(ofile, "Cannot open the output file");
int devID = model->devID;
XNet net;
double startT = GetClockSec();
wordCount = 0;
/* batch of input sequences */
XTensor batchEnc;
XTensor batchDec;
/* label */
XTensor label;
/* padding */
XTensor paddingEnc;
XTensor paddingDec;
/* gold standard */
XTensor gold;
/* an array that keeps the sequences */
int * seqs = new int[MILLION];
batchLoader.SetRandomBatch(false);
batchLoader.ClearBuf();
while(batchLoader.LoadBatch(file, model->isLM,
&batchEnc, &paddingEnc, &paddingDec, &paddingDec, &gold, &label,
seqs, vSize, vSizeTgt,
1, 1, false, ws, wc, devID, false))
{
CheckNTErrors(batchEnc.order == 2, "wrong tensor order of the sequence batch!");
CheckNTErrors(!model->isLM, "Only MT model is supported!");
XTensor output;
seacher.Search(model, &batchEnc, &paddingEnc, &output);
Dump(ofile, &output);
float prob = 0;
loss += -prob;
wc = batchEnc.GetDim(-1);
wordCount += wc;
wordCountTotal += wc;
sentCount += batchEnc.GetDim(-2);
batchCount += 1;
if (batchCount % 1 == 0) {
double elapsed = GetClockSec() - startT;
XPRINT3(0, stderr,
"[INFO] elapsed=%.1fs, sentence=%d, sword=%d\n",
elapsed, sentCount, wordCount);
}
}
fclose(file);
fclose(ofile);
delete[] seqs;
double elapsed = GetClockSec() - startT;
XPRINT3(0, stderr, "[INFO] test finished (took %.1fs, word=%d, and ppl=%.3f)\n",
elapsed,wordCountTotal, exp(loss/wordCount));
}
/*
dump the result into the file
>> file - data file
>> output - output tensor
*/
void T2TTester::Dump(FILE * file, XTensor * output)
{
int seqLength = output->GetDim(-1);
for (int i = 0; i < output->unitNum; i += seqLength) {
for (int j = 0; j < seqLength; j++) {
int w = output->GetInt(i + j);
fprintf(file, "%d ", w);
if (w < 0)
break;
}
fprintf(file, "\n");
}
}
}
/* NiuTrans.Tensor - an open-source tensor library
* Copyright (C) 2018, Natural Language Processing Lab, Northeastern University.
* All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
/*
* $Created by: XIAO Tong (xiaotong@mail.neu.edu.cn) 2018-07-31
*/
#include <stdio.h>
#include <stdlib.h>
#include <string.h>
namespace transformer
{
FILE * tmpFILE;
int llnum = 0;
FILE * tf = NULL;
void LoadParamString(int argc, char ** argv, const char * name, char * p, const char * defaultP)
{
char vname[128];
vname[0] = '-';
strcpy(vname + 1, name);
bool hit = false;
for(int i = 0; i < argc; i++){
if(!strcmp(argv[i], vname) && i + 1 < argc){
strcpy(p, argv[i + 1]);
//fprintf(stderr, " %s=%s\n", name, argv[i + 1]);
hit = true;
}
}
if(!hit)
strcpy(p, defaultP);
}
void LoadParamInt(int argc, char ** argv, const char * name, int * p, int defaultP)
{
char vname[128];
vname[0] = '-';
strcpy(vname + 1, name);
bool hit = false;
for(int i = 0; i < argc; i++){
if(!strcmp(argv[i], vname) && i + 1 < argc){
*(int*)p = atoi(argv[i + 1]);
//fprintf(stderr, " %s=%s\n", name, argv[i + 1]);
hit = true;
}
}
if(!hit)
*p = defaultP;
}
void LoadParamBool(int argc, char ** argv, const char * name, bool * p, bool defaultP)
{
char vname[128];
vname[0] = '-';
strcpy(vname + 1, name);
bool hit = false;
for(int i = 0; i < argc; i++){
if(!strcmp(argv[i], vname)){
*(bool*)p = true;
//fprintf(stderr, " %s=%s\n", name, "true");
hit = true;
}
}
if(!hit)
*p = defaultP;
}
void LoadParamFloat(int argc, char ** argv, const char * name, float * p, float defaultP)
{
char vname[128];
vname[0] = '-';
strcpy(vname + 1, name);
bool hit = false;
for(int i = 0; i < argc; i++){
if(!strcmp(argv[i], vname) && i + 1 < argc){
*p = (float)atof(argv[i + 1]);
//fprintf(stderr, " %s=%s\n", name, argv[i + 1]);
hit = true;
}
}
if(!hit)
*p = defaultP;
}
void ShowParams(int argc, char ** argv)
{
fprintf(stderr, "args:\n");
for(int i = 0; i < argc; i++){
if(argv[i][1] == 0)
continue;
if(argv[i][0] == '-' && (argv[i][1] < '1' || argv[i][1] > '9')){
if(i + 1 < argc && argv[i + 1][0] != '-')
fprintf(stderr, " %s=%s\n", argv[i], argv[i + 1]);
else
fprintf(stderr, " %s=yes\n", argv[i]);
}
}
fprintf(stderr, "\n");
}
}
......@@ -17,99 +17,55 @@
/*
* $Created by: XIAO Tong (xiaotong@mail.neu.edu.cn) 2018-07-31
* $Modified by: HU Chi (huchinlp@gmail.com) 2020-06
*/
#include <math.h>
#include <time.h>
#include <cmath>
#include <ctime>
#include "Transformer.h"
#include "T2TModel.h"
#include "T2TUtility.h"
#include "T2TTrainer.h"
#include "T2TPredictor.h"
#include "T2TTester.h"
#include "train/T2TTrainer.h"
#include "module/T2TUtility.h"
#include "translate/T2TTranslator.h"
#include "../../tensor/XDevice.h"
#include "../../tensor/XUtility.h"
#include "../../tensor/XGlobal.h"
#include "../../tensor/XUtility.h"
namespace transformer
{
int TransformerMain(int argc, const char ** argv)
int TransformerMain(int argc, const char** argv)
{
if(argc == 0)
if (argc == 0)
return 1;
char ** args = new char*[argc];
for(int i = 0; i < argc; i++){
args[i] = new char[strlen(argv[i]) + 1];
strcpy(args[i], argv[i]);
}
tmpFILE = fopen("tmp.txt", "wb");
ShowParams(argc, args);
bool isBeamSearch = false;
char * trainFN = new char[MAX_LINE_LENGTH];
char * modelFN = new char[MAX_LINE_LENGTH];
char * testFN = new char[MAX_LINE_LENGTH];
char * outputFN = new char[MAX_LINE_LENGTH];
LoadParamString(argc, args, "train", trainFN, "");
LoadParamString(argc, args, "model", modelFN, "");
LoadParamString(argc, args, "test", testFN, "");
LoadParamString(argc, args, "output", outputFN, "");
LoadParamBool(argc, args, "beamsearch", &isBeamSearch, false);
/* load configurations */
T2TConfig config(argc, argv);
srand((unsigned int)time(NULL));
T2TTrainer trainer;
trainer.Init(argc, args);
T2TModel model;
model.InitModel(argc, args);
/* learn model parameters */
if(strcmp(trainFN, ""))
trainer.Train(trainFN, testFN, strcmp(modelFN, "") ? modelFN : "checkpoint.model", &model);
/* save the final model */
if(strcmp(modelFN, "") && strcmp(trainFN, ""))
model.Dump(modelFN);
/* load the model if neccessary */
if(strcmp(modelFN, ""))
model.Read(modelFN);
/* test the model on the new data */
if(strcmp(testFN, "") && strcmp(outputFN, "")){
/* beam search */
if(isBeamSearch){
T2TTester searcher;
searcher.Init(argc, args);
searcher.Test(testFN, outputFN, &model);
}
/* forced decoding */
else{
T2TTrainer tester;
tester.Init(argc, args);
tester.Validate(testFN, outputFN, &model);
}
/* train the model */
if (strcmp(config.trainFN, "") != 0) {
ENABLE_GRAD;
T2TModel model;
model.InitModel(config);
T2TTrainer trainer;
trainer.Init(config);
trainer.Train(config.trainFN, config.validFN, config.modelFN, &model);
}
delete[] trainFN;
delete[] modelFN;
delete[] testFN;
delete[] outputFN;
for(int i = 0; i < argc; i++)
delete[] args[i];
delete[] args;
fclose(tmpFILE);
/* translate the test file */
if (strcmp(config.testFN, "") != 0 && strcmp(config.outputFN, "") != 0) {
DISABLE_GRAD;
T2TModel model;
model.InitModel(config);
T2TTranslator translator;
translator.Init(config);
translator.Translate(config.testFN, config.srcVocabFN,
config.tgtVocabFN, config.outputFN, &model);
}
return 0;
}
}
}
\ No newline at end of file
/* NiuTrans.Tensor - an open-source tensor library
* Copyright (C) 2018, Natural Language Processing Lab, Northeastern University.
* Copyright (C) 2020, Natural Language Processing Lab, Northeastern University.
* All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
......@@ -17,13 +17,13 @@
/*
*
* An impelementation of the transformer system. See more details
* about FNNLM in
* An implementation of the transformer system. See more details
* about FNNLM in
* "Attention Is All You Need" by Vaswani et al.
* https://arxiv.org/pdf/1706.03762.pdf
*
* $Created by: XIAO Tong (xiaotong@mail.neu.edu.cn) 2018-07-31
* I start writing the code related to NMT - a long time since my last coding
* I start writing the code related to NMT - a long time since my last coding
* work on MT
*/
......@@ -38,7 +38,7 @@ namespace transformer
{
/* entrance of the program */
int TransformerMain(int argc, const char ** argv);
int TransformerMain(int argc, const char** argv);
}
......
/* NiuTrans.Tensor - an open-source tensor library
* Copyright (C) 2018, Natural Language Processing Lab, Northeastern University.
* Copyright (C) 2020, Natural Language Processing Lab, Northeastern University.
* All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
......@@ -17,48 +17,93 @@
/*
* $Created by: XIAO Tong (xiaotong@mail.neu.edu.cn) 2018-07-31
* $Modified by: HU Chi (huchinlp@gmail.com) 2020-04, 2020-06
*/
#ifndef __T2TATTENTION_H__
#define __T2TATTENTION_H__
#include "../../network/XNet.h"
#include "T2TNNUtil.h"
#include "T2TUtility.h"
#include "../../../network/XNet.h"
#include "../../../tensor/core/CHeader.h"
using namespace nts;
namespace transformer
{
/* attention type */
enum { NONE, SELF_ATT, EN_DE_ATT };
/*
multi-head attention
y(Q, K, V) = cat(head_1, head_2, ..., head_n)
where head_i = Attention(Q * w_i^Q, K * w_i^K, V * w_i^V)
attention(Q, K, V) = softmax(Q * K^T/d_k^0.5) V
d_k = dimension size of K
*/
/* layer cache for keys and values */
class Cache
{
public:
/* cache for keys, (B, L, H) */
XTensor key;
/* cache for values, (B, L, H) */
XTensor value;
public:
/* indicates cache miss if 'true' */
bool miss;
/* constructor */
Cache();
/* update the states cache */
void Update(XTensor&& k, XTensor&& v);
/* keep alive states */
void KeepAlive(XTensor& aliveIdx);
/* reorder alive states */
void Reorder(XTensor& reorder);
};
/* multi-head attention */
class T2TAttention
{
public:
/* device id */
int devID;
/* head number */
int nhead;
/* transformation matrix for Q */
XTensor wq;
/* bias for Q */
XTensor bq;
/* transformation matrix for K */
XTensor wk;
/* transformation matrix for Q */
XTensor wq;
/* bias for K */
XTensor bk;
/* transformation matrix for V */
XTensor wv;
/* bias for V */
XTensor bv;
XTensor wBig;
XTensor bBig;
/* RPR emb */
XTensor RPEmbK;
/* transformation after dot-product attention */
XTensor wa;
XTensor wbig;
XTensor wo;
/* bias after dot-product attention */
XTensor bo;
/* size of transformed Q and K */
int dk;
......@@ -68,19 +113,15 @@ public:
/* size of input Q, K and V */
int d;
/* indicates whether the attention is masked */
bool isMasked;
/* indicates whether we use the RPR attention */
bool useRPR;
/* some positions can be ignored in attention. this is useful in lm where the first position needs
special design for the attention model. */
int ignored;
/* indicates whether the model is used for training */
bool isTraining;
/* dropout probability */
DTYPE dropoutP;
/* the maximum relative window size */
int maxRP;
public:
/* constructor */
T2TAttention();
......@@ -89,20 +130,25 @@ public:
~T2TAttention();
/* initialize the model */
void InitModel(int argc, char ** argv,
bool myIsMasked, int myIgnored,
int myDevID = -1);
void InitModel(T2TConfig& config);
/* make the network */
XTensor Make(XTensor &k, XTensor &q, XTensor &v, XTensor &mask, bool isTraining);
/* make the network given a big tensor that keeps keys, queries and values */
XTensor MakeBig(XTensor &kqv, XTensor &mask, bool isTraining);
XTensor Make(XTensor& k, XTensor& q, XTensor& v,
XTensor* mask, bool isTraining,
Cache* cache, int cacheType);
/* make the attention network given keys, queries and values (after linear transformation) */
XTensor MakeAttention(XTensor &k, XTensor &q, XTensor &v, XTensor &mask, bool isTraining);
};
XTensor MakeAttention(XTensor& k, XTensor& q, XTensor& v,
XTensor* mask, bool isTraining);
/* make the attention network given keys, queries and values (after linear transformation) */
XTensor MakeRPRAttention(XTensor& k, XTensor& q, XTensor& v,
XTensor* mask, bool isTraining, bool isEnc);
XTensor GetRPEmbedding(const int lenQ, const int lenKV, const int maxRelativeLen, const bool isEnc);
XTensor RPDotProduct(XTensor& x, XTensor& y, XTensor& z, const bool is_key);
};
}
#endif
/* NiuTrans.Tensor - an open-source tensor library
* Copyright (C) 2020, Natural Language Processing Lab, Northestern University.
* All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
/*
* $Created by: Bei Li (libei_neu@outlook.com) 2020-02-05
* This file includes some common modules of the Transformer model
*/
#include <cmath>
#include "T2TCommonModules.h"
#include "../../../tensor/core/CHeader.h"
#include "../../../tensor/function/FHeader.h"
namespace transformer
{
/*
flexible layer normalization for the Transformer
>> input - input tensor
>> ln - the layernorm network
>> prenorm - whether we use prenorm or not
>> before - whether we use layernorm before attention/fnn
>> after - whether we use layernorm after attention/fnn
*/
XTensor LayerNorm(XTensor& input, T2TLN& ln, bool prenorm, bool before, bool after)
{
if (after ^ prenorm)
return ln.Make(input);
else
return input;
}
}
\ No newline at end of file
/* NiuTrans.Tensor - an open-source tensor library
* Copyright (C) 2020, Natural Language Processing Lab, Northestern University.
* All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
/*
* $Created by: Bei Li (libei_neu@outlook.com) 2020-02-03
*/
#ifndef __COMMONMODULE_H__
#define __COMMONMODULE_H__
#include "T2TLayerNormal.h"
#include "T2TCommonModules.h"
using namespace nts;
namespace transformer
{
/* the layer normalization module to control pre-norm or post-norm*/
XTensor LayerNorm(XTensor& input, T2TLN& ln, bool prenorm, bool before, bool after);
}
#endif
\ No newline at end of file
/* NiuTrans.Tensor - an open-source tensor library
* Copyright (C) 2020, Natural Language Processing Lab, Northeastern University.
* All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
/*
* $Created by: XIAO Tong (xiaotong@mail.neu.edu.cn) 2018-08-01
* $Modified by: HU Chi (huchinlp@gmail.com) 2020-07
*/
#include <cmath>
#include "T2TUtility.h"
#include "T2TEmbedding.h"
#include "../../../tensor/core/CHeader.h"
namespace transformer
{
/* constructor */
T2TEmbedder::T2TEmbedder()
{
devID = -1;
vSize = -1;
maxLength = -1;
}
/* de-constructor */
T2TEmbedder::~T2TEmbedder()
{
}
/*
initialize the model
>> config - configurations of the model
>> isEnc - indicates if it is used for the encoder
*/
void T2TEmbedder::InitModel(T2TConfig& config, bool isEnc)
{
devID = config.devID;
d = config.modelSize;
padIdx = config.padID;
eSize = config.embSize;
maxLength = config.maxPosLen;
vSize = (isEnc) ? config.srcVocabSize : config.tgtVocabSize;
InitTensor2D(&w, vSize, eSize, X_FLOAT, devID);
maxLength = maxLength + 1 + 1;
DTYPE v = 1.0F / (float)sqrt((float)eSize);
w.SetDataRandn(0, v);
/* create the positional embedding matrix */
MakePosEmbedding(maxLength);
}
/*
make positional embeddings (of size eSize * length)
>> length - length of the sequence
*/
void T2TEmbedder::MakePosEmbedding(int length)
{
InitTensor2D(&posEmbeddingBase, length, eSize, X_FLOAT, devID);
float* data = new float[posEmbeddingBase.unitNum];
for (int pos = 0; pos < length; pos++) {
float* dp = data + pos * eSize;
int channelSize = eSize / 2;
int offset = 0;
for (int i = 0; i < channelSize; i++) {
dp[offset++] = (float)sin(pos * exp(-i * log(10000.0F) / (channelSize - 1)));
}
for (int i = 0; i < channelSize; i++) {
dp[offset++] = (float)cos(pos * exp(-i * log(10000.0F) / (channelSize - 1)));
}
}
/* padding zeros */
int padStart = padIdx * eSize;
for (int i = padStart; i < padStart + eSize; i++)
data[i] = 0.F;
posEmbeddingBase.SetData(data, posEmbeddingBase.unitNum);
if (w.dataType != posEmbeddingBase.dataType)
posEmbeddingBase = ConvertDataType(posEmbeddingBase, w.dataType);
delete[] data;
}
/*
make the network
>> input - the word indices
>> nstep - the length of current sequence
>> isDec - indicates whether it is decoder
>> isTraining - indicates whether it is training
<< return - word & position embeddings of the input
*/
XTensor T2TEmbedder::Make(XTensor& input, bool isDec, bool isTraining, int nstep)
{
/* make sure the padding index is 1 */
CheckNTErrors(input.order > 1, "Wrong input tensor size!");
CheckNTErrors(input.dimSize[input.order - 1] < maxLength, "The sequence is too long!");
CheckNTErrors(vSize > 0, "set vocabulary size by \"-vsize\"");
CheckNTErrors(eSize > 0, "set embedding size by \"-esize\"");
XTensor wordEmbedding, position, posEmbedding;
InitTensor(&position, &input);
int* posData = new int[input.unitNum];
XTensor inputCPU;
InitTensorOnCPU(&inputCPU, &input);
_CopyValues(&input, &inputCPU);
if (!isDec)
{
/* encoder embeddings */
for (int i = 0; i < inputCPU.dimSize[0]; i++) {
int startNoPad = 1 + 1;
int* p = ((int*)inputCPU.data) + i * inputCPU.dimSize[1];
for (int j = 0; j < inputCPU.dimSize[1]; j++) {
if (p[j] == 1) {
posData[i * inputCPU.dimSize[1] + j] = 1;
}
else {
posData[i * inputCPU.dimSize[1] + j] = startNoPad++;
}
}
}
position.SetData(posData, position.unitNum);
}
else
{
/* decoder embeddings */
position.SetDataFixed(nstep + 2);
}
delete[] posData;
/* we make positional embeddings first */
posEmbedding = Gather(posEmbeddingBase, position);
/* then we make word embeddings */
wordEmbedding = Gather(w, input);
wordEmbedding = Linear(wordEmbedding, (float)sqrt((float)eSize));
/* we sum over the two embeddings */
return wordEmbedding + posEmbedding;
}
}
\ No newline at end of file
/* NiuTrans.Tensor - an open-source tensor library
* Copyright (C) 2018, Natural Language Processing Lab, Northeastern University.
* Copyright (C) 2020, Natural Language Processing Lab, Northeastern University.
* All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
......@@ -17,12 +17,14 @@
/*
* $Created by: XIAO Tong (xiaotong@mail.neu.edu.cn) 2018-08-01
* $Modified by: HU Chi (huchinlp@gmail.com) 2020-07
*/
#ifndef __T2TEMBEDDING_H__
#define __T2TEMBEDDING_H__
#include "../../network/XNet.h"
#include "T2TUtility.h"
#include "../../../network/XNet.h"
using namespace nts;
......@@ -31,7 +33,7 @@ namespace transformer
#define DEFAULT_EMBEDDING_SIZE 512
/*
/*
embedding (of word at position i):
word embedding + positional embedding
*/
......@@ -40,7 +42,7 @@ class T2TEmbedder
public:
/* device id */
int devID;
/* vocabulary size */
int vSize;
......@@ -53,10 +55,13 @@ public:
/* dimension size of the hidden layers in the t2t model */
int d;
/* padding index */
int padIdx;
/* word embedding matrix */
XTensor w;
/* predefined positional embeddings. It can speeds up
/* predefined positional embeddings. It can speeds up
the embedding processing by re-loading. */
XTensor posEmbeddingBase;
......@@ -68,13 +73,13 @@ public:
~T2TEmbedder();
/* initialize the model */
void InitModel(int argc, char ** argv, int myDevID = -1, bool isEnc = true);
void InitModel(T2TConfig& config, bool isEnc = true);
/* make positional embeddings */
void MakePosEmbedding(int eSize, int d, int length);
void MakePosEmbedding(int length);
/* make the network */
XTensor Make(XTensor &input);
XTensor Make(XTensor& input, bool isDec, bool isTraining, int nstep = 0);
};
}
......
/* NiuTrans.Tensor - an open-source tensor library
* Copyright (C) 2018, Natural Language Processing Lab, Northeastern University.
* Copyright (C) 2020, Natural Language Processing Lab, Northeastern University.
* All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
......@@ -17,14 +17,16 @@
/*
* $Created by: XIAO Tong (xiaotong@mail.neu.edu.cn) 2018-07-31
* $Modified by: HU Chi (huchinlp@gmail.com) 2020-04
*/
#include <math.h>
#include <cmath>
#include "T2TFNN.h"
#include "T2TUtility.h"
#include "T2TEmbedding.h"
#include "../../tensor/core/CHeader.h"
#include "../../tensor/function/FHeader.h"
#include "../../../tensor/core/CHeader.h"
#include "../../../tensor/function/FHeader.h"
namespace transformer
{
......@@ -32,33 +34,30 @@ namespace transformer
/* constructor */
T2TFNN::T2TFNN()
{
inSize = -1;
inSize = -1;
outSize = -1;
hSize = -1;
hSize = -1;
}
/* deconstructor */
/* de-constructor */
T2TFNN::~T2TFNN()
{
}
/*
initialize the model
/*
initialize the model
>> argc - number of arguments
>> argv - list of pointers to the arguments
>> myDevID - device id
>> config - configurations of the model
*/
void T2TFNN::InitModel(int argc, char ** argv, int myDevID)
void T2TFNN::InitModel(T2TConfig& config)
{
devID = myDevID;
float minmax = 0;
devID = config.devID;
LoadParamInt(argc, argv, "d", &inSize, DEFAULT_EMBEDDING_SIZE);
LoadParamInt(argc, argv, "d", &outSize, DEFAULT_EMBEDDING_SIZE);
LoadParamInt(argc, argv, "fnnh", &hSize, outSize * 4);
LoadParamFloat(argc, argv, "fnnminmax", &minmax, 0.1F);
LoadParamFloat(argc, argv, "dropoutfnn", &dropoutP, 0);
inSize = config.modelSize;
outSize = config.modelSize;
hSize = config.fnnHiddenSize;
dropoutP = config.fnnDropout;
InitTensor2D(&w1, inSize, hSize, X_FLOAT, devID);
InitTensor1D(&b1, hSize, X_FLOAT, devID);
......@@ -74,27 +73,24 @@ void T2TFNN::InitModel(int argc, char ** argv, int myDevID)
b2.SetZeroAll();
}
/*
make the network
/*
make the network
y = max(0, x * w1 + b1) * w2 + b2
>> input - the input tensor
>> return - the output tensor
>> return - the output tensor
*/
XTensor T2TFNN::Make(XTensor &input, bool isTraining)
XTensor T2TFNN::Make(XTensor& input, bool isTraining)
{
XTensor t1;
/* t1 = max(0, x * w1 + b1) */
//t1 = Rectify(MMul(input, w1) + b1);
t1 = Rectify(MulAndShift(input, w1, b1));
if(isTraining && dropoutP > 0)
if (isTraining && dropoutP > 0)
t1 = Dropout(t1, dropoutP);
/* result = t1 * w2 + b2 */
//return MMul(t1, w2) + b2;
return MulAndShift(t1, w2, b2);
}
}
}
\ No newline at end of file
/* NiuTrans.Tensor - an open-source tensor library
* Copyright (C) 2018, Natural Language Processing Lab, Northeastern University.
* Copyright (C) 2020, Natural Language Processing Lab, Northeastern University.
* All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
......@@ -17,12 +17,15 @@
/*
* $Created by: XIAO Tong (xiaotong@mail.neu.edu.cn) 2018-07-31
* $Modified by: HU Chi (huchinlp@gmail.com) 2020-04
*/
#ifndef __T2TFNN_H__
#define __T2TFNN_H__
#include "../../tensor/XTensor.h"
#include "T2TUtility.h"
#include "T2TLayerNormal.h"
#include "../../../tensor/XTensor.h"
using namespace nts;
......@@ -56,7 +59,7 @@ public:
/* bias of transformation 2 */
XTensor b2;
/* dropout probability */
DTYPE dropoutP;
......@@ -65,15 +68,14 @@ public:
/* constructor */
T2TFNN();
/* deconstructor */
/* de-constructor */
~T2TFNN();
/* initialize the model */
void InitModel(int argc, char ** argv, int myDevID = -1);
void InitModel(T2TConfig& config);
/* make the network */
XTensor Make(XTensor &input, bool isTraining);
XTensor Make(XTensor& input, bool isTraining);
};
}
......
/* NiuTrans.Tensor - an open-source tensor library
* Copyright (C) 2020, Natural Language Processing Lab, Northeastern University.
* All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
/*
* $Created by: Bei Li (libei_neu@outlook.com) 2020-02-03
*/
#include <cmath>
#include "T2TUtility.h"
#include "T2TEmbedding.h"
#include "T2TGatedLinearUnit.h"
#include "../../../tensor/core/CHeader.h"
#include "../../../tensor/function/FHeader.h"
namespace transformer
{
/* constructor */
GLU::GLU()
{
inSize = -1;
outSize = -1;
hSize = -1;
}
/* de-constructor */
GLU::~GLU()
{
}
/*
initialize the model
>> config - configurations of the model
*/
void GLU::InitModel(T2TConfig& config)
{
devID = config.devID;
float minmax = 0;
inSize = config.modelSize;
outSize = config.modelSize;
InitTensor2D(&w1, hSize, outSize, X_FLOAT, devID);
InitTensor1D(&b1, outSize, X_FLOAT, devID);
InitTensor2D(&w2, hSize, outSize, X_FLOAT, devID);
InitTensor1D(&b2, outSize, X_FLOAT, devID);
}
/*
make the network
y = W1 * x + b1 * sigmod(W2 * x + b2)
>> input - the input tensor, size = 2 * hSize
>> return - the output tensor, size = hSize
*/
XTensor GLU::Make(XTensor& input)
{
XTensor t1;
XTensor t2;
TensorList input_list;
/* split the input into two vectors with the dim hSize */
Split(input, input_list, -1, 2);
/* t1 = W1 * x + b1 */
t1 = MulAndShift(input_list.GetItem(0), w1, b1);
/* t2 = W2 * x + b2 */
t2 = MulAndShift(input_list.GetItem(1), w2, b2);
return t1 * Sigmoid(t2);
}
}
\ No newline at end of file
/* NiuTrans.Tensor - an open-source tensor library
* Copyright (C) 2020, Natural Language Processing Lab, Northeastern University.
* All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
/*
* $Created by: Bei Li (libei_neu@outlook.com) 2020-02-03
*/
#ifndef __GLU_H__
#define __GLU_H__
#include "T2TLayerNormal.h"
#include "T2TGatedLinearUnit.h"
using namespace nts;
namespace transformer
{
/* a fnn: y = max(0, x * w1 + b1) * w2 + b2 */
class GLU
{
public:
/* device id */
int devID;
/* size of input vector */
int inSize;
/* size of output vector */
int outSize;
/* size of hidden layers */
int hSize;
/* matrix of transformation 1 */
XTensor w1;
/* bias of transformation 1 */
XTensor b1;
/* matrix of transformation 2 */
XTensor w2;
/* bias of transformation 2 */
XTensor b2;
public:
/* constructor */
GLU();
/* de-constructor */
~GLU();
/* initialize the model */
void InitModel(T2TConfig& config);
/* make the network */
XTensor Make(XTensor& input);
};
}
#endif
\ No newline at end of file
/* NiuTrans.Tensor - an open-source tensor library
* Copyright (C) 2020, Natural Language Processing Lab, Northeastern University.
* All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
/*
* $Created by: Bei Li (libei_neu@outlook.com) 2020-02-03
*/
#include <cmath>
#include "T2TUtility.h"
#include "T2TEmbedding.h"
#include "T2TLayerNormal.h"
#include "T2TLayerHistory.h"
#include "../../../tensor/core/CHeader.h"
#define SAFE_DELETE(x) do{ if((x) != NULL){delete (x); (x) = NULL;} } while(false)
#define SAFE_DELETE_ARRAY(x) do{ if((x) != NULL) {delete [] (x); (x)=NULL;} } while(false)
namespace transformer
{
/* constructor */
LayerHistory::LayerHistory()
{
d = -1;
count = -1;
weight = NULL;
layerNorms = NULL;
}
/* de-constructor */
LayerHistory::~LayerHistory()
{
history.Clear();
delete[] layerNorms;
}
/*
initialize the model
>> config - configurations of the model
*/
void LayerHistory::InitModel(T2TConfig& config)
{
devID = config.devID;
d = config.modelSize;
nlayer = config.nEncLayer;
InitTensor2D(&weight, nlayer + 1, nlayer + 1, X_FLOAT, devID);
layerNorms = new T2TLN[nlayer];
/* initialize the layer normalization of each layer */
for (int i = 0; i < nlayer; i++) {
layerNorms[i].InitModel(config);
}
}
/*
the Add operation
>> tensor - the previous layer output. It might be of size B * L * H
where B = batch size, L = sequence length,
and H = vector size of each position
*/
void LayerHistory::Add(XTensor& tensor)
{
/* the embedding is not normed */
count += 1;
if (history.Size() == 0) {
//sample_ = tensor;
history.Add(&tensor);
return;
}
XTensor ln = layerNorms[count - 2].Make(tensor);
history.Add(&ln);
}
/*
generate the weight sum vector of all previous layer output in the history as the layer input
*/
XTensor LayerHistory::Pop()
{
/* the number of layer output in the history */
size_t size = history.Size();
TensorList historyList;
for (size_t i = 0; i < size; i++)
historyList.Add(history[i]);
/* we need stack the tensor along the first dim*/
XTensor stackTensor = Stack(historyList, 0);
XTensor interWeight;
InitTensor2D(&interWeight, 1, weight.dimSize[1], DEFAULT_DTYPE, devID);
XTensor layerWeight;
InitTensor1D(&layerWeight, size, DEFAULT_DTYPE, devID);
_SelectRange(&weight, &interWeight, 0, size - 1, size);
interWeight.Reshape(interWeight.unitNum);
_SelectRange(&interWeight, &layerWeight, 0, 0, size);
MultiplyDimMe(stackTensor, layerWeight, 0);
XTensor result;
ReduceSum(stackTensor, result, 0);
return result;
}
void LayerHistory::ClearHistory()
{
history.Clear();
}
}
\ No newline at end of file
/* NiuTrans.Tensor - an open-source tensor library
* Copyright (C) 2020, Natural Language Processing Lab, Northeastern University.
* All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
/*
* $Created by: Bei Li (libei_neu@outlook.com) 2020-02-03
*/
#ifndef __LAYERHISTORY_H__
#define __LAYERHISTORY_H__
#include "T2TLayerNormal.h"
#include "T2TLayerHistory.h"
#include "../../../tensor/function/FHeader.h"
using namespace nts;
namespace transformer
{
/*
multi-head attention
y(Q, K, V) = cat(head_1, head_2, ..., head_n)
where head_i = Attention(Q * w_i^Q, K * w_i^K, V * w_i^V)
attention(Q, K, V) = softmax(Q * K^T/d_k^0.5) V
d_k = dimension size of K
*/
class LayerHistory
{
public:
/* device id */
int devID;
/* the triangle weight matrix for dlcl */
XTensor weight;
/* hidden size */
int d;
/* layer number */
int nlayer;
/* current layer number */
int count;
/* a history to store the value of intimidate layers */
TensorList history;
/* layer normalization for each intimidate layer */
T2TLN* layerNorms;
public:
/* constructor */
LayerHistory();
/* de-constructor */
~LayerHistory();
/* initialize the model */
void InitModel(T2TConfig& config);
/* add the layer output to the history */
void Add(XTensor& tensor);
/* compute the layer input for the current layer, the weight sum of all previous layer output after normed in the history */
XTensor Pop();
/* clean the history*/
void ClearHistory();
};
}
#endif
/* NiuTrans.Tensor - an open-source tensor library
* Copyright (C) 2018, Natural Language Processing Lab, Northeastern University.
* Copyright (C) 2020, Natural Language Processing Lab, Northeastern University.
* All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
......@@ -17,13 +17,14 @@
/*
* $Created by: XIAO Tong (xiaotong@mail.neu.edu.cn) 2018-07-31
* $Modified by: HU Chi (huchinlp@gmail.com) 2020-04
*/
#include <math.h>
#include "T2TLayerNormal.h"
#include <cmath>
#include "T2TUtility.h"
#include "T2TEmbedding.h"
#include "../../tensor/core/CHeader.h"
#include "T2TLayerNormal.h"
#include "../../../tensor/core/CHeader.h"
namespace transformer
{
......@@ -44,32 +45,28 @@ T2TLN::~T2TLN()
initialize the model
>> argc - number of arguments
>> argv - list of pointers to the arguments
>> myDevID - device id
>> config - configurations of the model
*/
void T2TLN::InitModel(int argc, char ** argv, int myDevID)
void T2TLN::InitModel(T2TConfig& config)
{
devID = myDevID;
devID = config.devID;
d = 0;
LoadParamInt(argc, argv, "d", &d, DEFAULT_EMBEDDING_SIZE);
d = config.modelSize;
InitTensor1D(&w, d, X_FLOAT, devID);
InitTensor1D(&b, d, X_FLOAT, devID);
w.SetDataRand(1.0F, 1.0F);
b.SetZeroAll();
}
/*
make the network
for each layer representation x, we have
y =
>> input - the input tensor
>> return - layer normalization output
*/
XTensor T2TLN::Make(XTensor &input)
XTensor T2TLN::Make(XTensor& input)
{
XTensor &x = input;
XTensor& x = input;
XTensor xn;
XTensor mean;
XTensor variance;
......@@ -77,6 +74,13 @@ XTensor T2TLN::Make(XTensor &input)
XTensor meanFilled;
XTensor standardFilled;
TENSOR_DATA_TYPE dataType = input.dataType;
if (dataType == X_FLOAT16) {
/* reduce functions can only run with FP32 */
x = ConvertDataType(input, X_FLOAT);
}
/* \mu = (sum_i x_i)/m */
mean = ReduceMean(x, x.order - 1);
......@@ -94,8 +98,13 @@ XTensor T2TLN::Make(XTensor &input)
/* x' = (x - \mu)/standard */
xn = (x - meanFilled) / standardFilled;
if (dataType != mean.dataType) {
x = ConvertDataType(x, dataType);
xn = ConvertDataType(xn, dataType);
}
/* result = x' * w + b */
return xn * w + b;
}
}
}
\ No newline at end of file
/* NiuTrans.Tensor - an open-source tensor library
* Copyright (C) 2018, Natural Language Processing Lab, Northeastern University.
* Copyright (C) 2020, Natural Language Processing Lab, Northeastern University.
* All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
......@@ -17,19 +17,21 @@
/*
* $Created by: XIAO Tong (xiaotong@mail.neu.edu.cn) 2018-07-31
* $Modified by: HU Chi (huchinlp@gmail.com) 2020-04
*/
#ifndef __T2TLAYERNORMAL_H__
#define __T2TLAYERNORMAL_H__
#include "../../network/XNet.h"
#include "T2TUtility.h"
#include "../../../network/XNet.h"
using namespace nts;
namespace transformer
{
/* layer normalization: y = norm(x) * w + b
/* layer normalization: y = norm(x) * w + b
where norm(x) = (x - mean)/standardDeviation */
class T2TLN
{
......@@ -45,19 +47,19 @@ public:
/* dimension size of the model */
int d;
public:
/* constructor */
T2TLN();
/* de-constructor */
~T2TLN();
/* initialize the model */
void InitModel(int argc, char ** argv, int myDevID = -1);
void InitModel(T2TConfig& config);
/* make the network */
XTensor Make(XTensor &input);
XTensor Make(XTensor& input);
};
}
......
/* NiuTrans.Tensor - an open-source tensor library
* Copyright (C) 2020, Natural Language Processing Lab, Northeastern University.
* All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
/*
* $Created by: Chi (huchinlp@foxmail.com) 2020-03-21
*/
#include "T2TNNUtil.h"
namespace transformer
{
/*
a wrapper for the gather function
>> src - the input tensor
>> index - the index tensor
<< res - the output tensor
*/
XTensor AutoGather(XTensor& src, XTensor& index)
{
if (src.order == 2)
return Gather(src, index);
else {
CheckNTErrors(src.order == 3, "the source must be 3d");
int order = src.order;
int dimSize[MAX_TENSOR_DIM_NUM];
for (int i = 0; i < src.order; i++) {
dimSize[i] = src.dimSize[i];
}
src.Reshape(src.dimSize[0], src.dimSize[1] * src.dimSize[2]);
XTensor res = Gather(src, index);
src.Reshape(order, dimSize);
dimSize[0] = index.dimSize[0];
dimSize[1] = res.unitNum / (dimSize[0] * dimSize[2]);
res.Reshape(order, dimSize);
return res;
}
}
}
\ No newline at end of file
/* NiuTrans.Tensor - an open-source tensor library
* Copyright (C) 2018, Natural Language Processing Lab, Northeastern University.
* Copyright (C) 2020, Natural Language Processing Lab, Northeastern University.
* All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
......@@ -16,31 +16,24 @@
*/
/*
* $Created by: XIAO Tong (xiaotong@mail.neu.edu.cn) 2018-07-31
* $Created by: Chi (huchinlp@foxmail.com) 2020-03-21
*/
#ifndef __T2TUTILITY_H__
#define __T2TUTILITY_H__
#ifndef __T2TNNUTIL_H__
#define __T2TNNUTIL_H__
#include <stdio.h>
#include "../../../tensor/XGlobal.h"
#include "../../../tensor/core/CHeader.h"
#include "../../../tensor/function/FHeader.h"
using namespace nts;
namespace transformer
{
extern FILE * tmpFILE;
/* load arguments */
void LoadParamString(int argc, char ** argv, const char * name, char * p, const char * defaultP);
void LoadParamInt(int argc, char ** argv, const char * name, int * p, int defaultP);
void LoadParamBool(int argc, char ** argv, const char * name, bool * p, bool defaultP);
void LoadParamFloat(int argc, char ** argv, const char * name, float * p, float defaultP);
/* show arguments */
void ShowParams(int argc, char ** argv);
extern int llnum;
extern FILE * tf;
/* the gather function for tensor with any dimension */
XTensor AutoGather(XTensor& src, XTensor& index);
}
#endif
#endif
\ No newline at end of file
/* NiuTrans.Tensor - an open-source tensor library
* Copyright (C) 2018, Natural Language Processing Lab, Northeastern University.
* Copyright (C) 2020, Natural Language Processing Lab, Northeastern University.
* All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
......@@ -17,22 +17,24 @@
/*
* $Created by: XIAO Tong (xiaotong@mail.neu.edu.cn) 2018-07-31
* $Modified by: HU Chi (huchinlp@gmail.com) 2020-04
*/
#include <math.h>
#include <cmath>
#include "T2TOutput.h"
#include "T2TUtility.h"
#include "T2TEmbedding.h"
#include "../../tensor/core/CHeader.h"
#include "../../../tensor/core/CHeader.h"
namespace transformer
{
/* constructor */
T2TOutput::T2TOutput()
{
devID = -1;
vSize = -1;
inSize = -1;
hSize = -1;
}
......@@ -42,57 +44,51 @@ T2TOutput::~T2TOutput()
}
/*
initialize the model
>> argc - number of arguments
>> argv - list of pointers to the arguments
>> myDevID - device id
initialize the model
>> config - configurations of the model
*/
void T2TOutput::InitModel(int argc, char ** argv, int myDevID)
void T2TOutput::InitModel(T2TConfig& config)
{
devID = myDevID;
float minmax = 0;
devID = config.devID;
hSize = config.modelSize;
vSize = config.tgtVocabSize;
LoadParamInt(argc, argv, "vsizetgt", &vSize, -1);
LoadParamInt(argc, argv, "d", &inSize, DEFAULT_EMBEDDING_SIZE);
LoadParamInt(argc, argv, "d", &hSize, DEFAULT_EMBEDDING_SIZE);
LoadParamFloat(argc, argv, "outputminmax", &minmax, 0.08F);
InitTensor2D(&w, vSize, hSize, X_FLOAT, devID);
InitTensor2D(&w, hSize, vSize, X_FLOAT, devID);
float scale = 1.0F;
float finfout = (float)sqrt(6.0F * scale/(hSize + vSize));
w.SetDataRand(-finfout, finfout);
DTYPE v = 1.0F/(float)sqrt((float)hSize);
DTYPE v = 1.0F / (float)sqrt((float)hSize);
w.SetDataRandn(0, v);
}
/*
make the network
y = softmax(x * w)
/*
make the network (redefined output tensor)
>> input - input tensor
<< return - output tensor
>> output - output tensor
>> isTraining - whether it is used for training
>> normalized - whether ignore the log-softmax
*/
XTensor T2TOutput::Make(XTensor &input)
void T2TOutput::Make(XTensor& input, XTensor& output, bool isTraining, bool normalized)
{
XTensor &x = input;
XTensor& x = input;
return LogSoftmax(MMul(x, w), -1);
}
output = MMul(x, X_NOTRANS, w, X_TRANS);
/*
make the network (redefined output tensor)
>> input - input tensor
>> output - output tensor
*/
void T2TOutput::Make(XTensor &input, XTensor &output)
{
XTensor &x = input;
/* use softmax for training */
if (isTraining) {
output = Softmax(output, -1);
return;
}
//output = LogSoftmax(MMul(x, w), -1);
output = Softmax(MMul(x, w), -1);
output.SetName(OUTPUT_NAME);
}
/* normalize the output for beam search */
if (normalized) {
auto dataType = output.dataType;
if (dataType == X_FLOAT16)
output = ConvertDataType(output, X_FLOAT);
output = LogSoftmax(output, -1);
if (output.dataType != dataType)
output = ConvertDataType(output, dataType);
}
}
}
\ No newline at end of file
/* NiuTrans.Tensor - an open-source tensor library
* Copyright (C) 2018, Natural Language Processing Lab, Northeastern University.
* Copyright (C) 2020, Natural Language Processing Lab, Northeastern University.
* All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
......@@ -17,19 +17,19 @@
/*
* $Created by: XIAO Tong (xiaotong@mail.neu.edu.cn) 2018-07-31
* $Modified by: HU Chi (huchinlp@gmail.com) 2020-04
*/
#ifndef __T2TOUTPUT_H__
#define __T2TOUTPUT_H__
#include "../../tensor/function/FHeader.h"
#include "T2TUtility.h"
#include "../../../tensor/function/FHeader.h"
using namespace nts;
namespace transformer
{
#define OUTPUT_NAME "output"
/* output layer */
class T2TOutput
......@@ -41,9 +41,6 @@ public:
/* vocabulary size */
int vSize;
/* input vector size */
int inSize;
/* vector size of the linear transformation */
int hSize;
......@@ -58,16 +55,12 @@ public:
~T2TOutput();
/* initialize the model */
void InitModel(int argc, char ** argv, int myDevID = -1);
/* make the network */
XTensor Make(XTensor &input);
void InitModel(T2TConfig& config);
/* make the network (redefined output tensor) */
void Make(XTensor &input, XTensor &output);
void Make(XTensor& input, XTensor& output, bool isTraining, bool normalized);
};
}
#endif
\ No newline at end of file
/* NiuTrans.Tensor - an open-source tensor library
* Copyright (C) 2020, Natural Language Processing Lab, Northeastern University.
* All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
/*
* $Created by: XIAO Tong (xiaotong@mail.neu.edu.cn) 2018-07-31
* $Modified by: HU Chi (huchinlp@gmail.com) 2020-04, 2020-06
*/
#include <cstdio>
#include <cstdlib>
#include <cstring>
#include <string>
#include <fstream>
#include <sstream>
#include "T2TUtility.h"
#include "../../../tensor/XGlobal.h"
using namespace nts;
using namespace std;
namespace transformer
{
/*
load configurations from the command
>> argc - number of arguments
>> argv - the list of arguments
*/
T2TConfig::T2TConfig(int argc, const char** argv)
{
char** args = new char* [MAX_PARAM_NUM];
for (int i = 0; i < argc; i++) {
args[i] = new char[strlen(argv[i]) + 1];
strcpy(args[i], argv[i]);
}
char* configFN = new char[1024];
LoadParamString(argc, args, "config", configFN, "");
int argsNum = argc;
/* load configurations from a file */
if (strcmp(configFN, "") != 0)
argsNum = LoadFromFile(configFN, args);
ShowParams(argsNum, args);
/* options for the model */
LoadParamInt(argsNum, args, "nhead", &nhead, 8);
LoadParamInt(argsNum, args, "enclayer", &nEncLayer, 1);
LoadParamInt(argsNum, args, "declayer", &nDecLayer, 1);
LoadParamInt(argsNum, args, "maxrp", &maxRP, 8);
LoadParamInt(argsNum, args, "embsize", &embSize, 256);
LoadParamInt(argsNum, args, "modelsize", &modelSize, 256);
LoadParamInt(argsNum, args, "maxpos", &maxPosLen, 1024);
LoadParamInt(argsNum, args, "fnnhidden", &fnnHiddenSize, modelSize * 4);
LoadParamInt(argsNum, args, "vsize", &srcVocabSize, 10000);
LoadParamInt(argsNum, args, "vsizetgt", &tgtVocabSize, 10000);
LoadParamInt(argsNum, args, "padid", &padID, 1);
LoadParamInt(argsNum, args, "startid", &startID, 2);
LoadParamInt(argsNum, args, "endid", &endID, 2);
LoadParamBool(argsNum, args, "rpr", &useRPR, false);
LoadParamBool(argsNum, args, "prenorm", &preNorm, false);
LoadParamString(argsNum, args, "model", modelFN, "model.bin");
LoadParamString(argsNum, args, "srcvocab", srcVocabFN, "vocab.src");
LoadParamString(argsNum, args, "tgtvocab", tgtVocabFN, "vocab.tgt");
/* options for training */
LoadParamString(argsNum, args, "train", trainFN, "");
LoadParamString(argsNum, args, "valid", validFN, "");
LoadParamInt(argsNum, args, "dev", &devID, 0);
LoadParamInt(argsNum, args, "wbatch", &wBatchSize, 2048);
LoadParamInt(argsNum, args, "sbatch", &sBatchSize, 1);
isTraining = (strcmp(trainFN, "") == 0) ? false : true;
LoadParamBool(argsNum, args, "mt", &isMT, true);
LoadParamFloat(argsNum, args, "dropout", &dropout, 0.1);
LoadParamFloat(argsNum, args, "fnndrop", &fnnDropout, 0.0);
LoadParamFloat(argsNum, args, "attdrop", &attDropout, 0.0);
LoadParamFloat(argc, args, "lrate", &lrate, 1.0F);
LoadParamFloat(argc, args, "lrbias", &lrbias, 0);
LoadParamInt(argc, args, "nepoch", &nepoch, 20);
LoadParamInt(argc, args, "nstep", &nstep, 100000);
LoadParamInt(argc, args, "nwarmup", &nwarmup, 3000);
LoadParamBool(argc, args, "adam", &useAdam, true);
LoadParamFloat(argc, args, "adambeta1", &adamBeta1, 0.9F);
LoadParamFloat(argc, args, "adambeta2", &adamBeta2, 0.98F);
LoadParamFloat(argc, args, "adamdelta", &adamDelta, 1e-9F);
LoadParamBool(argc, args, "shuffled", &isShuffled, true);
LoadParamFloat(argc, args, "labelsmoothing", &labelSmoothingP, 0.1);
LoadParamInt(argc, args, "nstepcheckpoint", &nStepCheckpoint, -1);
LoadParamBool(argc, args, "epochcheckpoint", &useEpochCheckpoint, false);
LoadParamInt(argc, args, "updatestep", &updateStep, 1);
LoadParamBool(argc, args, "debug", &isDebugged, false);
LoadParamBool(argc, args, "sorted", &isLenSorted, false);
LoadParamInt(argc, args, "bufsize", &bufSize, 50000);
LoadParamBool(argc, args, "doubledend", &isDoubledEnd, false);
LoadParamBool(argc, args, "smallbatch", &isSmallBatch, true);
LoadParamBool(argc, args, "bigbatch", &isBigBatch, false);
LoadParamBool(argc, args, "randbatch", &isRandomBatch, false);
LoadParamInt(argc, args, "bucketsize", &bucketSize, 0);
/* options for translating */
LoadParamString(argsNum, args, "test", testFN, "");
LoadParamString(argsNum, args, "output", outputFN, "");
LoadParamInt(argsNum, args, "beamsize", &beamSize, 1);
LoadParamBool(argsNum, args, "fp16", &useFP16, false);
LoadParamFloat(argsNum, args, "lenalpha", &lenAlpha, 0.6);
LoadParamFloat(argsNum, args, "maxlenalpha", &maxLenAlpha, 2.0);
for (int i = 0; i < argc; i++)
delete[] args[i];
delete[] args;
delete[] configFN;
}
/*
load configurations from a file
>> configFN - path to the configuration file
>> args - the list to store the configurations
format: one option per line, separated by a blank or a tab
*/
int T2TConfig::LoadFromFile(const char* configFN, char** args) {
ifstream f(configFN, ios::in);
CheckNTErrors(f.is_open(), "unable to open the config file");
int argsNum = 0;
/* parse arguments */
string key, value;
while (f >> key >> value) {
key += '-';
strcpy(args[argsNum++], key.c_str());
strcpy(args[argsNum++], value.c_str());
}
/* record the number of arguments */
return argsNum;
}
void LoadParamString(int argc, char** argv, const char* name, char* p, const char* defaultP)
{
char vname[128];
vname[0] = '-';
strcpy(vname + 1, name);
bool hit = false;
for (int i = 0; i < argc; i++) {
if (!strcmp(argv[i], vname) && i + 1 < argc) {
strcpy(p, argv[i + 1]);
hit = true;
break;
}
}
if (!hit)
strcpy(p, defaultP);
}
void LoadParamInt(int argc, char** argv, const char* name, int* p, int defaultP)
{
char vname[128];
vname[0] = '-';
strcpy(vname + 1, name);
bool hit = false;
for (int i = 0; i < argc; i++) {
if (!strcmp(argv[i], vname) && i + 1 < argc) {
*(int*)p = atoi(argv[i + 1]);
hit = true;
break;
}
}
if (!hit)
*p = defaultP;
}
void LoadParamBool(int argc, char** argv, const char* name, bool* p, bool defaultP)
{
char vname[128];
vname[0] = '-';
strcpy(vname + 1, name);
bool hit = false;
for (int i = 0; i < argc; i++) {
if (!strcmp(argv[i], vname)) {
*(bool*)p = true;
hit = true;
break;
}
}
if (!hit)
*p = defaultP;
}
void LoadParamFloat(int argc, char** argv, const char* name, float* p, float defaultP)
{
char vname[128];
vname[0] = '-';
strcpy(vname + 1, name);
bool hit = false;
for (int i = 0; i < argc; i++) {
if (!strcmp(argv[i], vname) && i + 1 < argc) {
*p = (float)atof(argv[i + 1]);
hit = true;
break;
}
}
if (!hit)
*p = defaultP;
}
void ShowParams(int argc, char** argv)
{
fprintf(stderr, "args:\n");
for (int i = 0; i < argc; i++) {
if (argv[i][1] == 0)
continue;
if (argv[i][0] == '-' && (argv[i][1] < '1' || argv[i][1] > '9')) {
if (i + 1 < argc && argv[i + 1][0] != '-')
fprintf(stderr, " %s=%s\n", argv[i], argv[i + 1]);
else
fprintf(stderr, " %s=yes\n", argv[i]);
}
}
fprintf(stderr, "\n");
}
#define MAX_WORD_NUM 120
/*
split string by delimiter, this will return indices of all sub-strings
>> s - the original string
>> delimiter - as it is
<< indices - indices of all sub-strings
*/
UInt64List SplitToPos(const string& s, const string& delimiter)
{
UInt64List indices;
if (delimiter.length() == 0) {
indices.Add(0);
}
size_t pos = 0;
uint64_t start = 0;
while ((pos = s.find(delimiter, start)) != string::npos) {
if (pos != start) {
indices.Add(start);
}
start = pos + delimiter.length();
}
if (start != s.length()) {
indices.Add(start);
}
return indices;
}
/* split a string to a int64_t list */
IntList SplitInt(const string& s, const string& delimiter)
{
IntList values;
auto indices = SplitToPos(s, delimiter);
for (int i = 0; i < indices.Size(); i++) {
values.Add(strtol(s.data() + indices[i], nullptr, 10));
}
return values;
}
/* split a string to a float list */
FloatList SplitFloat(const string& s, const string& delimiter)
{
FloatList values;
auto indices = SplitToPos(s, delimiter);
for (int i = 0; i < indices.Size(); i++) {
values.Add(strtof(s.data() + indices[i], nullptr));
}
return values;
}
}
\ No newline at end of file
/* NiuTrans.Tensor - an open-source tensor library
* Copyright (C) 2018, Natural Language Processing Lab, Northeastern University.
* All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
/*
* $Created by: XIAO Tong (xiaotong@mail.neu.edu.cn) 2018-07-31
* $Modified by: HU Chi (huchinlp@gmail.com) 2020-06
*/
#ifndef __T2TUTILITY_H__
#define __T2TUTILITY_H__
#include <string>
#include <cstdio>
#include "../../../tensor/XList.h"
using namespace std;
using namespace nts;
namespace transformer
{
#define MAX_PARAM_NUM 100
/* load arguments */
void LoadParamInt(int argc, char** argv, const char* name, int* p, int defaultP);
void LoadParamBool(int argc, char** argv, const char* name, bool* p, bool defaultP);
void LoadParamFloat(int argc, char** argv, const char* name, float* p, float defaultP);
void LoadParamString(int argc, char** argv, const char* name, char* p, const char* defaultP);
/* show arguments */
void ShowParams(int argc, char** argv);
/* split string */
IntList SplitInt(const string& s, const string& delimiter);
FloatList SplitFloat(const string& s, const string& delimiter);
UInt64List SplitToPos(const string& s, const string& delimiter);
/* configurations for t2t */
class T2TConfig {
public:
/* path to the model */
char modelFN[1024];
/* path to the source vocab */
char srcVocabFN[1024];
/* path to the target vocab */
char tgtVocabFN[1024];
/* path to the input file (for inference) */
char testFN[1024];
/* path to the output file (for inference) */
char outputFN[1024];
/* path to the training file */
char trainFN[1024];
/* path to the validation file */
char validFN[1024];
/* device id */
int devID;
/* beam size */
int beamSize;
/* word batch size */
int wBatchSize;
/* sentence batch size */
int sBatchSize;
/* number of heads in attention */
int nhead;
/* number of encoder layers */
int nEncLayer;
/* number of decoder layers */
int nDecLayer;
/* the maximum relative position in RPR attentions */
int maxRP;
/* the dimension of embeddings */
int embSize;
/* the dimension of hidden layer */
int modelSize;
/* the maximum length in positional embedding */
int maxPosLen;
/* the dimension of fnn hidden layer */
int fnnHiddenSize;
/* the vocab size of source sequence */
int srcVocabSize;
/* the vocab size of target sequence */
int tgtVocabSize;
/* the padding id */
int padID;
/* start symbol */
int startID;
/* end symbol */
int endID;
/* indicates whether the model uses pre-norm */
bool preNorm;
/* indicates whether the model is running for machine translation */
bool isMT;
/* indicates whether the model is running with FP16 data type */
bool useFP16;
/* indicates whether we use the RPR attention */
bool useRPR;
/* indicates whether we train the model */
bool isTraining;
/* dropout rate for the model */
float dropout;
/* dropout rate for fnn layers */
float fnnDropout;
/* dropout rate for attention layers */
float attDropout;
/* the alpha parameter controls the length preference */
float lenAlpha;
/* scalar of the input sequence (for max number of search steps) */
float maxLenAlpha;
/* learning rate */
float lrate;
/* the parameter that controls the maximum learning rate in training */
float lrbias;
/* training epoch number */
int nepoch;
/* traing step number */
int nstep;
/* indicates whether we use Adam */
bool useAdam;
/* hyper parameters of Adam */
float adamBeta1;
float adamBeta2;
float adamDelta;
/* step number of warm-up for training */
int nwarmup;
/* indicates whether the data file is shuffled for training */
bool isShuffled;
/* the factor of label smoothing */
float labelSmoothingP;
/* number of steps after which we make a checkpoint */
int nStepCheckpoint;
/* indicates whether we make a checkpoint after each training epoch */
bool useEpochCheckpoint;
/* number of batches on which we do model update */
int updateStep;
/* indicates whether we intend to debug the net */
bool isDebugged;
/* indicates whether the sequence is sorted by length */
bool isLenSorted;
/* buffer size */
int bufSize;
/* indicates whether we double the </s> symbol for the output of LM */
bool isDoubledEnd;
/* indicates whether we use batchsize = max * sc
rather rather than batchsize = word-number, where max is the maximum
length and sc is the sentence number */
bool isSmallBatch;
/* counterpart of "isSmallBatch" */
bool isBigBatch;
/* randomize batches */
bool isRandomBatch;
/* bucket size */
int bucketSize;
public:
/* load configurations from the command */
T2TConfig(int argc, const char** argv);
/* load configurations from a file */
int LoadFromFile(const char* configFN, char** args);
};
}
#endif
/* NiuTrans.Tensor - an open-source tensor library
* Copyright (C) 2018, Natural Language Processing Lab, Northeastern University.
* Copyright (C) 2020, Natural Language Processing Lab, Northeastern University.
* All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
......@@ -17,13 +17,14 @@
/*
* $Created by: XIAO Tong (xiaotong@mail.neu.edu.cn) 2019-04-25
* it is cold today but i'll move to a warm place tomorrow :)
* it is cold today but I'll move to a warm place tomorrow :)
*/
#ifndef __T2TBATCHLOADER_H__
#define __T2TBATCHLOADER_H__
#include "../../network/XNet.h"
#include "../module/T2TUtility.h"
#include "../../../network/XNet.h"
using namespace nts;
......@@ -35,7 +36,7 @@ namespace transformer
/* node to keep batch information */
struct BatchNode
{
/* begining position */
/* beginning position */
int beg;
/* end position */
......@@ -55,13 +56,13 @@ class T2TBatchLoader
{
public:
/* buffer for loading words */
int * buf;
int* buf;
/* another buffer */
int * buf2;
int* buf2;
/* batch buf */
BatchNode * bufBatch;
BatchNode* bufBatch;
/* buffer size */
int bufSize;
......@@ -70,13 +71,13 @@ public:
int bufBatchSize;
/* length of each sequence */
int * seqLen;
int* seqLen;
/* another array */
int * seqLen2;
int* seqLen2;
/* offset of the first word for each sequence */
int * seqOffset;
int* seqOffset;
/* number of sequences in the buffer */
int nseqBuf;
......@@ -87,9 +88,9 @@ public:
/* offset for next batch */
int nextBatch;
/* indicates whether we double the </s> symbol for the output of lms */
/* indicates whether we double the </s> symbol for the output of LM */
bool isDoubledEnd;
/* indicates whether we use batchsize = max * sc
rather rather than batchsize = word-number, where max is the maximum
length and sc is the sentence number */
......@@ -112,10 +113,10 @@ public:
~T2TBatchLoader();
/* initialization */
void Init(int argc, char ** argv);
void Init(T2TConfig& config);
/* load data to buffer */
int LoadBuf(FILE * file, bool isSorted, int step);
int LoadBuf(FILE* file, bool isSorted, int step);
/* clear data buffer */
void ClearBuf();
......@@ -124,36 +125,37 @@ public:
void SetRandomBatch(bool flag = true);
/* load a batch of sequences */
int LoadBatch(FILE * file, bool isLM,
XTensor * batchEnc, XTensor * paddingEnc,
XTensor * batchDec, XTensor * paddingDec,
XTensor * gold, XTensor * label,
int * seqs,
int vsEnc, int vsDec, int sBatch, int wBatch,
bool isSorted, int &ws, int &wCount,
int devID, bool isTraining);
int LoadBatch(FILE* file, bool isLM,
XTensor* batchEnc, XTensor* paddingEnc,
XTensor* batchDec, XTensor* paddingDec,
XTensor* gold, XTensor* label,
int* seqs,
int vsEnc, int vsDec, int sBatch, int wBatch,
bool isSorted, int& ws, int& wCount,
int devID, bool isTraining);
/* load a batch of sequences (for language modeling) */
int LoadBatchLM(FILE * file,
XTensor * batchEnc, XTensor * paddingEnc,
XTensor * batchDec, XTensor * paddingDec,
XTensor * gold, XTensor * label,
int * seqs, int vs, int sBatch, int wBatch,
bool isSorted, int &wCount,
int devID, bool isTraining);
int LoadBatchLM(FILE* file,
XTensor* batchEnc, XTensor* paddingEnc,
XTensor* batchDec, XTensor* paddingDec,
XTensor* gold, XTensor* label,
int* seqs, int vs, int sBatch, int wBatch,
bool isSorted, int& wCount,
int devID, bool isTraining);
/* load a batch of sequences (for machine translation) */
int LoadBatchMT(FILE * file,
XTensor * batchEnc, XTensor * paddingEnc,
XTensor * batchDec, XTensor * paddingDec,
XTensor * gold, XTensor * label,
int * seqs, int vsEnc, int vsDec, int sBatch, int wBatch,
bool isSorted, int &ws, int &wCount,
int devID, bool isTraining);
int LoadBatchMT(FILE* file,
XTensor* batchEnc, XTensor* paddingEnc,
XTensor* batchDec, XTensor* paddingDec,
XTensor* gold, XTensor* label,
int* seqs, int vsEnc, int vsDec, int sBatch, int wBatch,
bool isSorted, int& ws, int& wCount,
int devID, bool isTraining);
/* shuffle the data file */
void Shuffle(const char * srcFile, const char * tgtFile);
void Shuffle(const char* srcFile, const char* tgtFile);
};
}
#endif
\ No newline at end of file
/* NiuTrans.Tensor - an open-source tensor library
* Copyright (C) 2018, Natural Language Processing Lab, Northeastern University.
* Copyright (C) 2020, Natural Language Processing Lab, Northeastern University.
* All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
......@@ -22,9 +22,9 @@
#ifndef __T2TTRAINER_H__
#define __T2TTRAINER_H__
#include "T2TModel.h"
#include "../T2TModel.h"
#include "T2TBatchLoader.h"
#include "../../tensor/function/FHeader.h"
#include "../../../tensor/function/FHeader.h"
using namespace nts;
......@@ -35,15 +35,13 @@ namespace transformer
class T2TTrainer
{
public:
/* paramter number */
int argNum;
/* parameter array */
char ** argArray;
/* configurations */
T2TConfig* cfg;
/* dimension size of each inner layer */
int d;
/* step number of warm-up for training */
int nwarmup;
......@@ -55,7 +53,7 @@ public:
/* learning rate */
float lrate;
/* the parameter that controls the maximum learning rate in training */
float lrbias;
......@@ -81,24 +79,24 @@ public:
float adamBeta1T;
float adamBeta2T;
/* list of the moment of the parameter matrics */
/* list of the moment of the parameter matrices */
TensorList moments;
/* list of the 2nd order moment of the parameter matrics */
/* list of the 2nd order moment of the parameter matrices */
TensorList moments2nd;
/* indicates whether the data file is shuffled for training */
bool isShuffled;
/* the factor of label smoothing */
DTYPE labelSmoothingP;
/* number of steps after which we make a checkpoint */
int nStepCheckpoint;
/* indicates whether we make a checkpoint after each traing epoch */
/* indicates whether we make a checkpoint after each training epoch */
bool useEpochCheckpoint;
/* number of batches on which we do model update */
int updateStep;
......@@ -119,25 +117,24 @@ public:
~T2TTrainer();
/* initialize the trainer */
void Init(int argc, char ** argv);
void Init(T2TConfig& config);
/* train the model */
void Train(const char * fn, const char * validFN, const char * modelFN, T2TModel * model);
void Train(const char* fn, const char* validFN, const char* modelFN, T2TModel* model);
/* test the model */
void Validate(const char * fn, const char * ofn, T2TModel * model);
void Validate(const char* fn, const char* ofn, T2TModel* model);
/* make a checkpoint */
void MakeCheckpoint(T2TModel * model, const char * validFN, const char * modelFN, const char * label, int id);
void MakeCheckpoint(T2TModel* model, const char* validFN, const char* modelFN, const char* label, int id);
/* update the model by delta rule */
void Update(T2TModel * model, const float lr);
void Update(T2TModel* model, const float lr);
/* prepare model for training */
void PrepareModel(T2TModel * model);
void PrepareModel(T2TModel* model);
};
}
#endif
/* NiuTrans.Tensor - an open-source tensor library
* Copyright (C) 2020, Natural Language Processing Lab, Northeastern University.
* All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
/*
* $Created by: HU Chi (huchinlp@foxmail.com) 2019-04-03
* $Modified by: HU Chi (huchinlp@gmail.com) 2020-06
*/
#include <string>
#include <vector>
#include <cstdlib>
#include <fstream>
#include <algorithm>
#include "T2TDataSet.h"
#include "../module/T2TUtility.h"
using namespace transformer;
namespace nts {
/* sort the output by id (in ascending order) */
void DataSet::SortInput() {
sort(inputBuffer.items, inputBuffer.items + inputBuffer.count, [](Example* a, Example* b) {
return a->values.count > b->values.count;
});
}
/* sort the input by length (in descending order) */
void DataSet::SortOutput() {
sort(outputBuffer.items, outputBuffer.items + outputBuffer.count, [](Result* a, Result* b) {
return a->id < b->id;
});
}
/*
load data from the file to the buffer
*/
void DataSet::LoadDataToBuffer()
{
string line;
inputBuffer.Clear();
bufferUsed = 0;
int id = 0;
const string tokenDelimiter = " ";
while (getline(*fp, line)) {
IntList values;
/* load words and transform them to ids */
auto indices = SplitToPos(line, tokenDelimiter);
/* reserve the first 120 words if the input is too long */
size_t maxLen = indices.Size() > MAX_WORD_NUM ? MAX_WORD_NUM : indices.Size();
for (size_t i = 0; i < maxLen; i++) {
auto offset = (i != (indices.Size() - 1)) ?
indices[i + 1] - indices[i] - tokenDelimiter.size()
: line.size() - indices[i];
string word = line.substr(indices[i], offset);
if (srcVocab.word2id.find(word) == srcVocab.word2id.end())
values.Add(3);
else
values.Add(srcVocab.word2id.at(word));
}
/* make sure that the sequence ends with EOS */
if (values.Size() != 0 && values[-1] != EOS)
values.Add(EOS);
Example* example = new Example;
example->id = id;
example->values = values;
if (values.Size() != 0)
inputBuffer.Add(example);
else
emptyLines.Add(id);
id++;
}
fp->close();
SortInput();
XPRINT1(0, stderr, "[INFO] loaded %d sentences\n", id);
}
/*
load a mini-batch to the device
>> batchEnc - a tensor to store the batch of input
>> paddingEnc - a tensor to store the batch of paddings
>> minSentBatch - the minimum number of sentence batch
>> batchSize - the maxium number of words in a batch
>> devID - the device id, -1 for the CPU
<< indices of the sentences
*/
UInt64List DataSet::LoadBatch(XTensor* batchEnc, XTensor* paddingEnc,
size_t minSentBatch, size_t batchSize, int devID)
{
size_t realBatchSize = minSentBatch;
/* get the maximum sentence length in a mini-batch */
size_t maxLen = inputBuffer[bufferUsed]->values.Size();
/* dynamic batching for sentences */
while ((realBatchSize < (inputBuffer.Size() - bufferUsed))
&& (realBatchSize * maxLen < batchSize)) {
realBatchSize++;
}
/* real batch size */
if ((inputBuffer.Size() - bufferUsed) < realBatchSize) {
realBatchSize = inputBuffer.Size() - bufferUsed;
}
CheckNTErrors(maxLen != 0, "invalid length");
int* batchValues = new int[realBatchSize * maxLen];
float* paddingValues = new float[realBatchSize * maxLen];
for (int i = 0; i < realBatchSize * maxLen; i++) {
batchValues[i] = 1;
paddingValues[i] = 0.0F;
}
size_t cur = 0;
/* left padding */
UInt64List infos;
size_t totalLength = 0;
for (int i = 0; i < realBatchSize; ++i) {
infos.Add(inputBuffer[bufferUsed + i]->id);
totalLength += inputBuffer[bufferUsed + i]->values.Size();
cur = maxLen * (i + 1) - inputBuffer[bufferUsed + i]->values.Size();
for (int j = 0; j < inputBuffer[bufferUsed + i]->values.Size(); j++) {
batchValues[cur] = inputBuffer[bufferUsed + i]->values[j];
paddingValues[cur++] = 1.0F;
}
}
infos.Add(totalLength);
InitTensor2D(batchEnc, realBatchSize, maxLen, X_INT, devID);
InitTensor2D(paddingEnc, realBatchSize, maxLen, X_FLOAT, devID);
bufferUsed += realBatchSize;
batchEnc->SetData(batchValues, batchEnc->unitNum);
paddingEnc->SetData(paddingValues, paddingEnc->unitNum);
delete[] batchValues;
delete[] paddingValues;
return infos;
}
/*
the constructor of DataSet
>> dataFile - path of the data file
>> srcVocabFN - path of the source vocab file
>> tgtVocabFN - path of the target vocab file
*/
void DataSet::Init(const char* dataFile, const char* srcVocabFN, const char* tgtVocabFN)
{
fp = new ifstream(dataFile);
CheckNTErrors(fp->is_open(), "can not open the file");
bufferUsed = 0;
CheckNTErrors(strcmp(srcVocabFN, "") != 0, "missing source vocab file");
CheckNTErrors(strcmp(tgtVocabFN, "") != 0, "missing target vocab file");
srcVocab.Load(srcVocabFN);
/* share source and target vocabs */
if (strcmp(srcVocabFN, tgtVocabFN) == 0) {
XPRINT(0, stderr, "[INFO] share source and target vocabs \n");
tgtVocab.CopyFrom(srcVocab);
}
else {
tgtVocab.Load(tgtVocabFN);
}
LoadDataToBuffer();
}
/* check if the buffer is empty */
bool DataSet::IsEmpty() {
if (bufferUsed < inputBuffer.Size())
return false;
return true;
}
/* dump the translation to a file */
void DataSet::DumpRes(const char* ofn)
{
ofstream ofile(ofn, ios::out);
for (int t = 0; t < outputBuffer.Size(); t++) {
auto res = outputBuffer[t];
for (int i = 0; i < res->res.Size(); i++) {
if (res->res[i] < 4)
break;
ofile << tgtVocab.id2word[res->res[i]] << " ";
}
ofile << "\n";
}
ofile.close();
}
/* de-constructor */
DataSet::~DataSet()
{
/* release the file */
delete fp;
/* release the input buffer */
for (int i = 0; i < inputBuffer.Size(); i++)
delete inputBuffer[i];
/* release the output buffer */
for (int i = 0; i < outputBuffer.Size(); i++)
delete outputBuffer[i];
}
}
\ No newline at end of file
/* NiuTrans.Tensor - an open-source tensor library
* Copyright (C) 2020, Natural Language Processing Lab, Northeastern University.
* All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
/*
* $Created by: HU Chi (huchinlp@foxmail.com) 2019-04-03
* $Modified by: HU Chi (huchinlp@gmail.com) 2020-06
*/
#ifndef __DATASET_H__
#define __DATASET_H__
#include <cstdio>
#include <vector>
#include <fstream>
#include "T2TVocab.h"
#include "../../../tensor/XList.h"
#include "../../../tensor/XTensor.h"
#include "../../../tensor/XGlobal.h"
#define MAX_WORD_NUM 120
using namespace std;
namespace nts {
/* the struct of tokenized input */
struct Example {
int id;
IntList values;
};
/* the struct of tokenized output */
struct Result {
int id;
IntList res;
};
/* A `DataSet` is associated with a file which contains variable length data.*/
struct DataSet {
public:
/* the data buffer */
InputBufferType inputBuffer;
/* a list of empty line number */
IntList emptyLines;
/* the result buffer */
OutputBufferType outputBuffer;
/* the pointer to file stream */
ifstream* fp;
/* size of used data in buffer */
size_t bufferUsed;
/* the source vocabulary */
Vocab srcVocab;
/* the target vocabulary */
Vocab tgtVocab;
public:
/* sort the input by length */
void SortInput();
/* reorder the output by ids */
void SortOutput();
/* load data from a file to the buffer */
void LoadDataToBuffer();
/* generate a mini-batch */
UInt64List LoadBatch(XTensor* batchEnc, XTensor* paddingEnc,
size_t sBatch, size_t wBatch, int devID);
/* initialization function */
void Init(const char* dataFile, const char* srcVocabFN, const char* tgtVocabFN);
/* check if the buffer is empty */
bool IsEmpty();
/* dump the translations to a file */
void DumpRes(const char* ofn);
/* de-constructor */
~DataSet();
};
}
#endif // __DATASET_H__
\ No newline at end of file
/* NiuTrans.Tensor - an open-source tensor library
* Copyright (C) 2019, Natural Language Processing Lab, Northeastern University.
* Copyright (C) 2020, Natural Language Processing Lab, Northeastern University.
* All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
......@@ -15,7 +15,13 @@
* limitations under the License.
*/
#include "../../tensor/core/CHeader.h"
/*
* $Created by: XIAO Tong (xiaotong@mail.neu.edu.cn) 2019-04-08
* Start of a new week - I just finished several documents.
* Writing document is harder than writing code :)
* $Modified by: HU Chi (huchinlp@gmail.com) 2020-04
*/
#include "T2TLengthPenalty.h"
using namespace nts;
......@@ -23,24 +29,23 @@ using namespace nts;
namespace transformer
{
/*
GNMT-like length penalty: pl = ((5 + n)/(5 + 1))^\alpha
where n = length of the sequence
>> length - length of the sequence (for each entry)
/*
GNMT-like length penalty: pl = ((5 + n)/(5 + 1))^\alpha
where n = length of the sequence
>> length - length of the sequence
>> alpha - the parameter controls the length preference
<< return - length penaltyof the sequence (for each entry)
<< return - length penalty of the sequence
*/
XTensor T2TLengthPenalizer::GNMT(const XTensor & length, float alpha)
float T2TLengthPenalizer::GNMT(float length, float alpha)
{
XTensor base;
XTensor lp;
float base;
float lp;
//base = ScaleAndShift(ScaleAndShift(length, 0, 5.0F), 1.0F/(5 + 1));
base = (length + 5)/(1 + 5);
base = (length + 5.0F) / (1.0F + 5.0F);
lp = pow(base, alpha);
lp = Power(base, alpha);
return lp;
}
}
}
\ No newline at end of file
/* NiuTrans.Tensor - an open-source tensor library
* Copyright (C) 2019, Natural Language Processing Lab, Northeastern University.
* Copyright (C) 2020, Natural Language Processing Lab, Northeastern University.
* All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
......@@ -19,12 +19,14 @@
* $Created by: XIAO Tong (xiaotong@mail.neu.edu.cn) 2019-04-08
* Start of a new week - I just finished several documents.
* Writing document is harder than writing code :)
* $Modified by: HU Chi (huchinlp@gmail.com) 2020-04
*/
#ifndef __T2TLENGTHPENALTY_H__
#define __T2TLENGTHPENALTY_H__
#include "../../tensor/XTensor.h"
#include "../module/T2TUtility.h"
#include "../../../tensor/XTensor.h"
using namespace nts;
......@@ -37,10 +39,9 @@ namespace transformer
class T2TLengthPenalizer
{
public:
/* GNMT-like length penalty: pl = ((5 + n)/(5 + 1))^\alpha
/* GNMT-like length penalty: pl = ((5 + n)/(5 + 1))^\alpha
where n = length of the sequence */
static
XTensor GNMT(const XTensor & length, float alpha);
static float GNMT(float length, float alpha);
};
}
......
/* NiuTrans.Tensor - an open-source tensor library
* Copyright (C) 2019, Natural Language Processing Lab, Northeastern University.
* Copyright (C) 2020, Natural Language Processing Lab, Northeastern University.
* All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
......@@ -18,29 +18,32 @@
/*
* $Created by: XIAO Tong (xiaotong@mail.neu.edu.cn) 2019-03-13
* This is the first source file I create in 2019 - new start!
* $Modified by: HU Chi (huchinlp@gmail.com) 2020-04
*/
#ifndef __T2TPREDICTOR_H__
#define __T2TPREDICTOR_H__
#include "T2TModel.h"
#include "../T2TModel.h"
#include "T2TLengthPenalty.h"
using namespace std;
namespace transformer
{
#define T2T_PID_EMPTY -1
/* state for search. It keeps the path (back-pointer), prediction distribution,
and etc. It can be regarded as a hypothsis in translation. */
and etc. It can be regarded as a hypotheses in translation. */
class T2TState
{
public:
/* we assume that the prediction is an integer */
int prediction;
/* id of the problem. One can regard it as the sentence id when we
translate a number of sentences in the batched manner. The hypothesis
/* id of the problem. One can regard it as the sentence id when we
translate a number of sentences in the batched manner. The hypotheses
is empty if id = -1 */
int pid;
......@@ -62,11 +65,11 @@ public:
/* model score of every path. A model score = path probability + some other stuff */
float modelScore;
/* nubmer of steps we go over so far */
/* number of steps we go over so far */
int nstep;
/* pointer to the previous state */
T2TState * last;
T2TState* last;
};
/* a bundle of states */
......@@ -75,11 +78,11 @@ class T2TStateBundle
public:
/* predictions */
XTensor prediction;
/* id of the previous state that generates the current one */
XTensor preID;
/* mark that indicates whether each hypothesis is completed */
/* mark that indicates whether each hypotheses is completed */
XTensor endMark;
/* probability of every prediction (last state of the path) */
......@@ -91,18 +94,11 @@ public:
/* model score of every path */
XTensor modelScore;
/* step number of each hypothesis */
XTensor nstep;
/* layers on the encoder side. We actually use the encoder output instead
of all hidden layers. */
TensorList layersEnc;
/* layers on the decoder side */
TensorList layersDec;
/* step number of each hypotheses */
float nstep;
/* list of states */
T2TState * states;
T2TState* states;
/* number of states */
int stateNum;
......@@ -121,23 +117,26 @@ public:
void MakeStates(int num);
};
/* The predictor reads the current state and then predicts the next.
/* The predictor reads the current state and then predicts the next.
It is exactly the same procedure of MT inference -
we get the state of previous words and then generate the next word.
Here, a state can be regared as the representation of words (word
Here, a state can be regarded as the representation of words (word
indices, hidden states, embeddings and etc.). */
class T2TPredictor
{
private:
/* pointer to the transformer model */
T2TModel * m;
T2TModel* m;
/* current state */
T2TStateBundle * s;
T2TStateBundle* s;
/* start symbol */
int startSymbol;
/* end symbol */
int endSymbol;
public:
/* constructor */
T2TPredictor();
......@@ -146,19 +145,24 @@ public:
~T2TPredictor();
/* create an initial state */
void Create(T2TModel * model, XTensor * top, const XTensor * input, int beamSize, T2TStateBundle * state);
void Create(T2TModel* model, XTensor* top, const XTensor* input, int beamSize, T2TStateBundle* state);
/* set the start symbol */
void SetStartSymbol(int symbol);
/* read a state */
void Read(T2TModel * model, T2TStateBundle * state);
void Read(T2TModel* model, T2TStateBundle* state);
/* predict the next state */
void Predict(T2TStateBundle * next, XTensor * encoding, XTensor * inputEnc, XTensor * paddingEnc);
void Predict(T2TStateBundle* next, XTensor& aliveIndices, XTensor& encoding,
XTensor& inputEnc, XTensor& paddingEnc, int rawBatchSize,
bool isStart, XTensor& reorderState, bool needReorder, int nstep);
/* generate paths up to the states of the current step */
XTensor GeneratePaths(T2TStateBundle * state);
XTensor GeneratePaths(T2TStateBundle* state);
/* get the predictions of the previous step */
XTensor GetLastPrediction(T2TStateBundle* state, int devID);
};
}
......
/* NiuTrans.Tensor - an open-source tensor library
* Copyright (C) 2019, Natural Language Processing Lab, Northeastern University.
* Copyright (C) 2020, Natural Language Processing Lab, Northeastern University.
* All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
......@@ -17,22 +17,25 @@
/*
* $Created by: XIAO Tong (xiaotong@mail.neu.edu.cn) 2019-03-27
* $Modified by: HU Chi (huchinlp@gmail.com) 2020-04, 2020-06
*/
#ifndef __T2TSEARCH_H__
#define __T2TSEARCH_H__
#include "T2TModel.h"
#include "../T2TModel.h"
#include "T2TPredictor.h"
using namespace std;
namespace transformer
{
/* The class orgnizes the search process. It calls "predictors" to generate
/* The class organizes the search process. It calls "predictors" to generate
distributions of the predictions and prunes the search space by beam pruning.
This makes a graph where each path respresents a translation hypothsis.
This makes a graph where each path represents a translation hypotheses.
The output can be the path with the highest model score. */
class T2TSearch
class BeamSearch
{
private:
/* the alpha parameter controls the length preference */
......@@ -40,10 +43,10 @@ private:
/* predictor */
T2TPredictor predictor;
/* max length of the generated sequence */
int maxLength;
/* beam size */
int beamSize;
......@@ -51,10 +54,10 @@ private:
int batchSize;
/* we keep the final hypotheses in a heap for each sentence in the batch. */
XHeap<MIN_HEAP, float> * fullHypos;
XHeap<MIN_HEAP, float>* fullHypos;
/* array of the end symbols */
int * endSymbols;
int* endSymbols;
/* number of the end symbols */
int endSymbolNum;
......@@ -62,48 +65,118 @@ private:
/* start symbol */
int startSymbol;
/* scalar of the input sequence (for max number of search steps) */
float scalarMaxLength;
/* indicate whether the early stop strategy is used */
bool isEarlyStop;
/* pids for alive states */
IntList aliveStatePids;
/* alive sentences */
IntList aliveSentList;
/* whether we need to reorder the states */
bool needReorder;
public:
/* constructor */
T2TSearch();
BeamSearch();
/* de-constructor */
~T2TSearch();
~BeamSearch();
/* initialize the model */
void Init(int argc, char ** argv);
void Init(T2TConfig& config);
/* search for the most promising states */
void Search(T2TModel * model, XTensor * input, XTensor * padding, XTensor * output);
void Search(T2TModel* model, XTensor& input, XTensor& padding, IntList* output, XTensor& score);
/* preparation */
void Prepare(int myBatchSize,int myBeamSize);
void Prepare(int myBatchSize, int myBeamSize);
/* compute the model score for each hypothesis */
void Score(T2TStateBundle * prev, T2TStateBundle * beam);
/* compute the model score for each hypotheses */
void Score(T2TStateBundle* prev, T2TStateBundle* beam);
/* generate token indices via beam pruning */
void Generate(T2TStateBundle * beam);
void Generate(T2TStateBundle* prev, T2TStateBundle* beam);
/* expand the search graph */
void Expand(T2TStateBundle * prev, T2TStateBundle * beam);
void Expand(T2TStateBundle* prev, T2TStateBundle* beam, XTensor& reorderState);
/* collect hypotheses with ending symbol */
void Collect(T2TStateBundle * beam);
void Collect(T2TStateBundle* beam);
/* fill the hypotheis heap with incomplete hypothses */
void FillHeap(T2TStateBundle * beam);
/* fill the hypotheses heap with incomplete hypotheses */
void FillHeap(T2TStateBundle* beam);
/* save the output sequences in a tensor */
void Dump(XTensor * output);
/* save the output sequences and score */
void Dump(IntList* output, XTensor* score);
/* check if the token is an end symbol */
bool IsEnd(int token);
/* check whether all hypotheses are completed */
bool IsAllCompleted(T2TStateBundle* beam);
/* update the beam by pruning finished states */
void RemoveFinishedStates(T2TStateBundle* beam, XTensor& aliveEncoding,
XTensor& aliveInput, XTensor& alivePadding, XTensor& aliveIdx);
/* set end symbols for search */
void SetEnd(const int * tokens, const int tokenNum);
void SetEnd(const int* tokens, const int tokenNum);
/* make a mask to prevent duplicated entries in beam expansion for the first position */
XTensor MakeFirstMask(T2TStateBundle * beam);
XTensor MakeFirstMask(T2TStateBundle* beam);
};
class GreedySearch
{
private:
/* predictor */
T2TPredictor predictor;
/* max length of the generated sequence */
int maxLength;
/* batch size */
int batchSize;
/* array of the end symbols */
int* endSymbols;
/* number of the end symbols */
int endSymbolNum;
/* start symbol */
int startSymbol;
/* scalar of the input sequence (for max number of search steps) */
float scalarMaxLength;
public:
/* constructor */
GreedySearch();
/* de-constructor */
~GreedySearch();
/* initialize the model */
void Init(T2TConfig& config);
/* search for the most promising states */
void Search(T2TModel* model, XTensor& input, XTensor& padding, IntList* output);
/* preparation */
void Prepare(int myBatchSize);
/* check if the token is an end symbol */
bool IsEnd(int token);
/* set end symbols for search */
void SetEnd(const int* tokens, const int tokenNum);
};
}
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论