Commit 4bbd6a27 by hello

Merge with NiuTrans.NMT

parent 4a3a47f1
...@@ -5,3 +5,6 @@ vc140.pdb ...@@ -5,3 +5,6 @@ vc140.pdb
NiuTrans.Tensor.vcxproj.user NiuTrans.Tensor.vcxproj.user
NiuTrans.Tensor.aps NiuTrans.Tensor.aps
data/ data/
build/
xxx/
bin/
/* NiuTrans.NMT - an open-source neural machine translation system. /* NiuTrans.Tensor - an open-source tensor library
* Copyright (C) 2020 NiuTrans Research. All rights reserved. * Copyright (C) 2020 NiuTrans Research. All rights reserved.
* *
* Licensed under the Apache License, Version 2.0 (the "License"); * Licensed under the Apache License, Version 2.0 (the "License");
...@@ -40,6 +40,16 @@ AttDecoder::AttDecoder() ...@@ -40,6 +40,16 @@ AttDecoder::AttDecoder()
decoderLayerNorm = NULL; decoderLayerNorm = NULL;
selfAttCache = NULL; selfAttCache = NULL;
enDeAttCache = NULL; enDeAttCache = NULL;
history = NULL;
preNorm = true;
useHistory = false;
finalNorm = false;
devID = -1;
eSize = -1;
hSize = -1;
nlayer = -1;
vSize = -1;
dropoutP = 0.0F;
} }
/* de-constructor */ /* de-constructor */
...@@ -53,8 +63,10 @@ AttDecoder::~AttDecoder() ...@@ -53,8 +63,10 @@ AttDecoder::~AttDecoder()
delete[] fnnLayerNorms; delete[] fnnLayerNorms;
delete[] enDeAtt; delete[] enDeAtt;
delete[] enDeAttLayerNorms; delete[] enDeAttLayerNorms;
if (preNorm) if (finalNorm)
delete decoderLayerNorm; delete decoderLayerNorm;
if (useHistory)
delete history;
} }
/* /*
...@@ -70,13 +82,12 @@ void AttDecoder::InitModel(Config& config) ...@@ -70,13 +82,12 @@ void AttDecoder::InitModel(Config& config)
vSize = config.tgtVocabSize; vSize = config.tgtVocabSize;
dropoutP = config.dropout; dropoutP = config.dropout;
preNorm = config.preNorm; preNorm = config.preNorm;
finalNorm = config.finalNorm;
useHistory = config.useHistory;
CheckNTErrors(nlayer >= 1, "We have one encoding layer at least!"); CheckNTErrors(nlayer >= 1, "We have one encoding layer at least!");
CheckNTErrors(vSize > 1, "set vocabulary size by \"-vsizetgt\""); CheckNTErrors(vSize > 1, "set vocabulary size by \"-vsizetgt\"");
/* embedding model */
embedder.InitModel(config, false);
selfAtt = new Attention[nlayer]; selfAtt = new Attention[nlayer];
fnns = new FNN[nlayer]; fnns = new FNN[nlayer];
selfAttLayerNorms = new LN[nlayer]; selfAttLayerNorms = new LN[nlayer];
...@@ -86,10 +97,15 @@ void AttDecoder::InitModel(Config& config) ...@@ -86,10 +97,15 @@ void AttDecoder::InitModel(Config& config)
selfAttCache = new Cache[nlayer]; selfAttCache = new Cache[nlayer];
enDeAttCache = new Cache[nlayer]; enDeAttCache = new Cache[nlayer];
if (preNorm)
if (finalNorm)
decoderLayerNorm = new LN; decoderLayerNorm = new LN;
if (useHistory)
history = new LayerHistory;
/* initialize the stacked layers */ /* initialize the stacked layers */
embedder.InitModel(config, false);
for (int i = 0; i < nlayer; i++) { for (int i = 0; i < nlayer; i++) {
selfAtt[i].InitModel(config); selfAtt[i].InitModel(config);
fnns[i].InitModel(config); fnns[i].InitModel(config);
...@@ -100,8 +116,10 @@ void AttDecoder::InitModel(Config& config) ...@@ -100,8 +116,10 @@ void AttDecoder::InitModel(Config& config)
selfAttCache[i].enable = true; selfAttCache[i].enable = true;
enDeAttCache[i].enable = true; enDeAttCache[i].enable = true;
} }
if (preNorm) if (finalNorm)
decoderLayerNorm->InitModel(config); decoderLayerNorm->InitModel(config);
if (useHistory)
history->InitModel(config);
} }
/* /*
...@@ -117,15 +135,26 @@ make the decoding network ...@@ -117,15 +135,26 @@ make the decoding network
XTensor AttDecoder::Make(XTensor& inputDec, XTensor& outputEnc, XTensor* mask, XTensor AttDecoder::Make(XTensor& inputDec, XTensor& outputEnc, XTensor* mask,
XTensor* maskEncDec, int nstep, bool isTraining) XTensor* maskEncDec, int nstep, bool isTraining)
{ {
/* clear the history */
if (useHistory)
history->ClearHistory();
XTensor x; XTensor x;
x = embedder.Make(inputDec, true, isTraining, nstep); x = embedder.Make(inputDec, true, isTraining, nstep);
/* dropout */ /* dropout */
if (isTraining && dropoutP > 0) if (isTraining && dropoutP > 0)
x = Dropout(x, dropoutP); x = Dropout(x, dropoutP, /*inplace=*/true);
if (useHistory)
history->Add(x);
for (int i = 0; i < nlayer; i++) { for (int i = 0; i < nlayer; i++) {
if (useHistory)
x = history->Pop();
XTensor att; XTensor att;
XTensor ende; XTensor ende;
XTensor fnn; XTensor fnn;
...@@ -146,10 +175,10 @@ XTensor AttDecoder::Make(XTensor& inputDec, XTensor& outputEnc, XTensor* mask, ...@@ -146,10 +175,10 @@ XTensor AttDecoder::Make(XTensor& inputDec, XTensor& outputEnc, XTensor* mask,
/* dropout */ /* dropout */
if (isTraining && dropoutP > 0) if (isTraining && dropoutP > 0)
att = Dropout(att, dropoutP); att = Dropout(att, dropoutP, /*inplace=*/true);
/* residual connection */ /* residual connection */
res = Sum(att, x); res = Sum(att, x, /*inplace=*/true);
/* layer normalization with post-norm for self-attention */ /* layer normalization with post-norm for self-attention */
selfAttnAfter = LayerNorm(res, selfAttLayerNorms[i], preNorm, false, true); selfAttnAfter = LayerNorm(res, selfAttLayerNorms[i], preNorm, false, true);
...@@ -163,10 +192,10 @@ XTensor AttDecoder::Make(XTensor& inputDec, XTensor& outputEnc, XTensor* mask, ...@@ -163,10 +192,10 @@ XTensor AttDecoder::Make(XTensor& inputDec, XTensor& outputEnc, XTensor* mask,
/* dropout */ /* dropout */
if (isTraining && dropoutP > 0) if (isTraining && dropoutP > 0)
ende = Dropout(ende, dropoutP); ende = Dropout(ende, dropoutP, /*inplace=*/true);
/* residual connection */ /* residual connection */
res = Sum(ende, selfAttnAfter); res = Sum(ende, selfAttnAfter, /*inplace=*/true);
/* layer normalization with post-norm for encoder-decoder attention */ /* layer normalization with post-norm for encoder-decoder attention */
endeAttnAfter = LayerNorm(res, enDeAttLayerNorms[i], preNorm, false, true); endeAttnAfter = LayerNorm(res, enDeAttLayerNorms[i], preNorm, false, true);
...@@ -179,94 +208,27 @@ XTensor AttDecoder::Make(XTensor& inputDec, XTensor& outputEnc, XTensor* mask, ...@@ -179,94 +208,27 @@ XTensor AttDecoder::Make(XTensor& inputDec, XTensor& outputEnc, XTensor* mask,
/* dropout */ /* dropout */
if (isTraining && dropoutP > 0) if (isTraining && dropoutP > 0)
fnn = Dropout(fnn, dropoutP); fnn = Dropout(fnn, dropoutP, /*inplace=*/true);
/* residual connection */ /* residual connection */
res = Sum(fnn, endeAttnAfter); res = Sum(fnn, endeAttnAfter, /*inplace=*/true);
/* layer normalization with post-norm for fnn */ /* layer normalization with post-norm for fnn */
x = LayerNorm(res, fnnLayerNorms[i], preNorm, false, true); x = LayerNorm(res, fnnLayerNorms[i], preNorm, false, true);
}
if (preNorm)
return decoderLayerNorm->Make(x);
return x; if (useHistory)
} history->Add(x);
}
/*
make the decoding network
>> inputDec - the input tensor of the decoder
>> outputEnc - the output tensor of the encoder
>> mask - mask that indicates which position is valid
>> maskEncDec - mask for the encoder-decoder attention
>> nstep - the current length of the decoder input
>> isTraining - indicates whether the model is used for training
<< return - the output tensor of the decoder
*/
XTensor AttDecoder::MakeFast(XTensor& inputDec, XTensor& outputEnc, XTensor* mask,
XTensor* maskEncDec, int nstep, bool isTraining)
{
XTensor x;
x = embedder.Make(inputDec, true, isTraining, nstep);
/* dropout */
if (isTraining && dropoutP > 0)
x = Dropout(x, dropoutP);
for (int i = 0; i < nlayer; i++) {
XTensor res;
res = x;
/* layer normalization with pre-norm for self-attn */
x = selfAttLayerNorms[i].Make(x);
/******************/
/* self attention */
x = selfAtt[i].Make(x, x, x, mask, isTraining, &selfAttCache[i], SELF_ATT);
/* dropout */
if (isTraining && dropoutP > 0)
x = Dropout(x, dropoutP);
/* residual connection */
x = Sum(res, x);
res = x;
/* layer normalization with pre-norm for encoder-decoder attention */
x = enDeAttLayerNorms[i].Make(x);
/* encoder-decoder attention */
x = enDeAtt[i].Make(outputEnc, x, outputEnc, maskEncDec,
isTraining, &enDeAttCache[i], EN_DE_ATT);
/* dropout */
if (isTraining && dropoutP > 0)
x = Dropout(x, dropoutP);
/* residual connection */
x = Sum(res, x);
res = x;
/* layer normalization with pre-norm for fnn */
x = fnnLayerNorms[i].Make(x);
/* fnn */
x = fnns[i].Make(x, isTraining);
/* dropout */ if (useHistory)
if (isTraining && dropoutP > 0) x = history->Pop();
x = Dropout(x, dropoutP);
/* residual connection */ /* clear the history while not training */
x = Sum(res, x); if (useHistory && !isTraining)
} history->ClearHistory();
x = decoderLayerNorm->Make(x); if (finalNorm)
return decoderLayerNorm->Make(x);
return x; return x;
} }
......
/* NiuTrans.NMT - an open-source neural machine translation system. /* NiuTrans.Tensor - an open-source tensor library
* Copyright (C) 2020 NiuTrans Research. All rights reserved. * Copyright (C) 2020 NiuTrans Research. All rights reserved.
* *
* Licensed under the Apache License, Version 2.0 (the "License"); * Licensed under the Apache License, Version 2.0 (the "License");
...@@ -74,6 +74,9 @@ public: ...@@ -74,6 +74,9 @@ public:
/* layer normalization for encoder-decoder attention */ /* layer normalization for encoder-decoder attention */
LN* enDeAttLayerNorms; LN* enDeAttLayerNorms;
/* dynamic layer history */
LayerHistory* history;
/* layer cache list */ /* layer cache list */
Cache* selfAttCache; Cache* selfAttCache;
...@@ -83,6 +86,12 @@ public: ...@@ -83,6 +86,12 @@ public:
/* the location of layer normalization */ /* the location of layer normalization */
bool preNorm; bool preNorm;
/* add LN to the decoder output or not */
bool finalNorm;
/* reserve history for layers or not */
bool useHistory;
public: public:
/* constructor */ /* constructor */
AttDecoder(); AttDecoder();
...@@ -96,10 +105,6 @@ public: ...@@ -96,10 +105,6 @@ public:
/* make the decoding network */ /* make the decoding network */
XTensor Make(XTensor& inputDec, XTensor& outputEnc, XTensor* mask, XTensor Make(XTensor& inputDec, XTensor& outputEnc, XTensor* mask,
XTensor* maskEncDec, int nstep, bool isTraining); XTensor* maskEncDec, int nstep, bool isTraining);
/* make the decoding network (pre norm) */
XTensor MakeFast(XTensor& inputDec, XTensor& outputEnc, XTensor* mask,
XTensor* maskEncDec, int nstep, bool isTraining);
}; };
} }
......
/* NiuTrans.NMT - an open-source neural machine translation system. /* NiuTrans.Tensor - an open-source tensor library
* Copyright (C) 2020 NiuTrans Research. All rights reserved. * Copyright (C) 2020 NiuTrans Research. All rights reserved.
* *
* Licensed under the Apache License, Version 2.0 (the "License"); * Licensed under the Apache License, Version 2.0 (the "License");
...@@ -31,11 +31,22 @@ namespace nmt ...@@ -31,11 +31,22 @@ namespace nmt
/* constructor */ /* constructor */
AttEncoder::AttEncoder() AttEncoder::AttEncoder()
{ {
devID = -1;
selfAtt = NULL; selfAtt = NULL;
fnns = NULL; fnns = NULL;
attLayerNorms = NULL; attLayerNorms = NULL;
fnnLayerNorms = NULL; fnnLayerNorms = NULL;
encoderLayerNorm = NULL; encoderLayerNorm = NULL;
useHistory = false;
history = NULL;
dropoutP = 0.0;
eSize = -1;
finalNorm = false;
hSize = -1;
ignored = -1;
nlayer = -1;
preNorm = false;
vSize = -1;
} }
/* de-constructor */ /* de-constructor */
...@@ -45,8 +56,10 @@ AttEncoder::~AttEncoder() ...@@ -45,8 +56,10 @@ AttEncoder::~AttEncoder()
delete[] fnns; delete[] fnns;
delete[] attLayerNorms; delete[] attLayerNorms;
delete[] fnnLayerNorms; delete[] fnnLayerNorms;
if (preNorm) if (finalNorm)
delete encoderLayerNorm; delete encoderLayerNorm;
if (useHistory)
delete history;
} }
/* /*
...@@ -62,31 +75,36 @@ void AttEncoder::InitModel(Config& config) ...@@ -62,31 +75,36 @@ void AttEncoder::InitModel(Config& config)
hSize = config.modelSize; hSize = config.modelSize;
vSize = config.srcVocabSize; vSize = config.srcVocabSize;
preNorm = config.preNorm; preNorm = config.preNorm;
finalNorm = config.finalNorm;
useHistory = config.useHistory;
dropoutP = config.dropout; dropoutP = config.dropout;
CheckNTErrors(nlayer >= 1, "We have one encoding layer at least!"); CheckNTErrors(nlayer >= 1, "We have one encoding layer at least!");
CheckNTErrors(vSize > 1, "Set vocabulary size by \"-vsize\""); CheckNTErrors(vSize > 1, "Set vocabulary size by \"-vsize\"");
/* embedding model */
embedder.InitModel(config);
selfAtt = new Attention[nlayer]; selfAtt = new Attention[nlayer];
fnns = new FNN[nlayer]; fnns = new FNN[nlayer];
attLayerNorms = new LN[nlayer]; attLayerNorms = new LN[nlayer];
fnnLayerNorms = new LN[nlayer]; fnnLayerNorms = new LN[nlayer];
if (preNorm) if (finalNorm)
encoderLayerNorm = new LN; encoderLayerNorm = new LN;
if (useHistory)
history = new LayerHistory;
/* initialize the stacked layers */ /* initialize the stacked layers */
embedder.InitModel(config);
for (int i = 0; i < nlayer; i++) { for (int i = 0; i < nlayer; i++) {
selfAtt[i].InitModel(config); selfAtt[i].InitModel(config);
fnns[i].InitModel(config); fnns[i].InitModel(config);
attLayerNorms[i].InitModel(config); attLayerNorms[i].InitModel(config);
fnnLayerNorms[i].InitModel(config); fnnLayerNorms[i].InitModel(config);
} }
if (preNorm) if (finalNorm)
encoderLayerNorm->InitModel(config); encoderLayerNorm->InitModel(config);
if (useHistory)
history->InitModel(config);
} }
/* /*
...@@ -99,15 +117,25 @@ make the encoding network ...@@ -99,15 +117,25 @@ make the encoding network
*/ */
XTensor AttEncoder::Make(XTensor& input, XTensor* mask, XTensor& maskEncDec, bool isTraining) XTensor AttEncoder::Make(XTensor& input, XTensor* mask, XTensor& maskEncDec, bool isTraining)
{ {
XTensor x; /* clear the history */
if (useHistory)
history->ClearHistory();
XTensor x;
x = embedder.Make(input, false, isTraining); x = embedder.Make(input, false, isTraining);
/* dropout */ /* dropout */
if (isTraining && dropoutP > 0) if (isTraining && dropoutP > 0)
x = Dropout(x, dropoutP); x = Dropout(x, dropoutP, /*inplace=*/true);
if (useHistory)
history->Add(x);
for (int i = 0; i < nlayer; i++) { for (int i = 0; i < nlayer; i++) {
if (useHistory)
x = history->Pop();
XTensor att; XTensor att;
XTensor fnn; XTensor fnn;
XTensor res; XTensor res;
...@@ -123,10 +151,10 @@ XTensor AttEncoder::Make(XTensor& input, XTensor* mask, XTensor& maskEncDec, boo ...@@ -123,10 +151,10 @@ XTensor AttEncoder::Make(XTensor& input, XTensor* mask, XTensor& maskEncDec, boo
/* dropout */ /* dropout */
if (isTraining && dropoutP > 0) if (isTraining && dropoutP > 0)
att = Dropout(att, dropoutP); att = Dropout(att, dropoutP, /*inplace=*/true);
/* residual connection */ /* residual connection */
res = Sum(att, x); res = Sum(att, x, /*inplace=*/true);
/* layer normalization with post-norm for self-attn */ /* layer normalization with post-norm for self-attn */
attnAfter = LayerNorm(res, attLayerNorms[i], preNorm, false, true); attnAfter = LayerNorm(res, attLayerNorms[i], preNorm, false, true);
...@@ -139,72 +167,27 @@ XTensor AttEncoder::Make(XTensor& input, XTensor* mask, XTensor& maskEncDec, boo ...@@ -139,72 +167,27 @@ XTensor AttEncoder::Make(XTensor& input, XTensor* mask, XTensor& maskEncDec, boo
/* dropout */ /* dropout */
if (isTraining && dropoutP > 0) if (isTraining && dropoutP > 0)
fnn = Dropout(fnn, dropoutP); fnn = Dropout(fnn, dropoutP, /*inplace=*/true);
/* residual connection */ /* residual connection */
res = Sum(fnn, attnAfter); res = Sum(fnn, attnAfter, /*inplace=*/true);
/* layer normalization with post-norm for fnn */ /* layer normalization with post-norm for fnn */
x = LayerNorm(res, fnnLayerNorms[i], preNorm, false, true); x = LayerNorm(res, fnnLayerNorms[i], preNorm, false, true);
}
if (preNorm)
return encoderLayerNorm->Make(x);
return x;
}
/*
make the encoding network
>> input - the input tensor of the encoder
>> mask - the mask that indicate each position is valid
>> maskEncDec - no use
>> isTraining - indicates whether the model is used for training
<< return - the output tensor of the encoder
*/
XTensor AttEncoder::MakeFast(XTensor& input, XTensor* mask, XTensor& maskEncDec, bool isTraining)
{
XTensor x;
x = embedder.Make(input, false, isTraining);
/* dropout */
if (isTraining && dropoutP > 0)
x = Dropout(x, dropoutP);
for (int i = 0; i < nlayer; i++) {
XTensor res;
res = x;
/* layer normalization with pre-norm for self-attn */
x = attLayerNorms[i].Make(x);
/* self attention */
x = selfAtt[i].Make(x, x, x, mask, isTraining, NULL, SELF_ATT);
/* dropout */
if (isTraining && dropoutP > 0)
x = Dropout(x, dropoutP);
/* residual connection */
x = Sum(res, x);
res = x; if (useHistory)
history->Add(x);
}
/* layer normalization with pre-norm for fnn */ if (useHistory)
x = fnnLayerNorms[i].Make(x); x = history->Pop();
/* fnn */ /* clear the history while not training */
x = fnns[i].Make(x, isTraining); if (useHistory && !isTraining)
history->ClearHistory();
/* dropout */ if (finalNorm)
if (isTraining && dropoutP > 0) return encoderLayerNorm->Make(x);
x = Dropout(x, dropoutP);
/* residual connection */
x = Sum(res, x);
}
x = encoderLayerNorm->Make(x);
return x; return x;
} }
......
/* NiuTrans.NMT - an open-source neural machine translation system. /* NiuTrans.Tensor - an open-source tensor library
* Copyright (C) 2020 NiuTrans Research. All rights reserved. * Copyright (C) 2020 NiuTrans Research. All rights reserved.
* *
* Licensed under the Apache License, Version 2.0 (the "License"); * Licensed under the Apache License, Version 2.0 (the "License");
...@@ -27,6 +27,7 @@ ...@@ -27,6 +27,7 @@
#include "submodel/Attention.h" #include "submodel/Attention.h"
#include "submodel/Embedding.h" #include "submodel/Embedding.h"
#include "submodel/LayerNorm.h" #include "submodel/LayerNorm.h"
#include "submodel/LayerHistory.h"
#include "../../network/XNet.h" #include "../../network/XNet.h"
using namespace nts; using namespace nts;
...@@ -89,9 +90,18 @@ public: ...@@ -89,9 +90,18 @@ public:
/* layer normalization for encoder */ /* layer normalization for encoder */
LN* encoderLayerNorm; LN* encoderLayerNorm;
/* dynamic layer history */
LayerHistory* history;
/* the location of layer normalization */ /* the location of layer normalization */
bool preNorm; bool preNorm;
/* add LN to the encoder output or not */
bool finalNorm;
/* reserve history for layers or not */
bool useHistory;
public: public:
/* constructor */ /* constructor */
AttEncoder(); AttEncoder();
...@@ -105,9 +115,6 @@ public: ...@@ -105,9 +115,6 @@ public:
/* make the encoding network */ /* make the encoding network */
XTensor Make(XTensor& input, XTensor* mask, XTensor& maskEncDec, bool isTraining); XTensor Make(XTensor& input, XTensor* mask, XTensor& maskEncDec, bool isTraining);
/* make the encoding network */
XTensor MakeFast(XTensor& input, XTensor* mask, XTensor& maskEncDec, bool isTraining);
/* make the encoding network (wrapper) */ /* make the encoding network (wrapper) */
XTensor Make(XTensor& input, XTensor* mask, bool isTraining); XTensor Make(XTensor& input, XTensor* mask, bool isTraining);
}; };
......
/* NiuTrans.NMT - an open-source neural machine translation system. /* NiuTrans.Tensor - an open-source tensor library
* Copyright (C) 2020 NiuTrans Research. All rights reserved. * Copyright (C) 2020 NiuTrans Research. All rights reserved.
* *
* Licensed under the Apache License, Version 2.0 (the "License"); * Licensed under the Apache License, Version 2.0 (the "License");
...@@ -72,20 +72,23 @@ void Model::InitModel(Config& config) ...@@ -72,20 +72,23 @@ void Model::InitModel(Config& config)
&config.tgtVocabSize, &config.nhead, &config.tgtVocabSize, &config.nhead,
&config.maxRP, &config.shareAllEmbeddings, &config.maxRP, &config.shareAllEmbeddings,
&config.shareDecInputOutputWeight, &config.shareDecInputOutputWeight,
&config.maxPosLen &config.maxPosition
}; };
FILE* modelFile = NULL; FILE* modelFile = NULL;
/* read model configurations */ /* read model configurations */
if (!config.isTraining) { if (!config.isTraining || strcmp(config.pretrainedModel, "") != 0) {
if (strcmp(config.pretrainedModel, "") != 0)
modelFile = fopen(config.pretrainedModel, "rb");
else
modelFile = fopen(config.modelFN, "rb"); modelFile = fopen(config.modelFN, "rb");
CheckNTErrors(modelFile, "Failed to open the model file"); CheckNTErrors(modelFile, "Failed to open the model file");
for (auto& meta : metaInfo) { for (auto& meta : metaInfo) {
fread(meta, sizeof(int), 1, modelFile); fread(meta, sizeof(int), 1, modelFile);
} }
} }
else { if (config.isTraining) {
/* read the source and target vocab size */ /* read the source and target vocab size */
FILE* trainF = fopen(config.trainFN, "rb"); FILE* trainF = fopen(config.trainFN, "rb");
CheckNTErrors(trainF, "Failed to open the training file"); CheckNTErrors(trainF, "Failed to open the training file");
...@@ -110,9 +113,10 @@ void Model::InitModel(Config& config) ...@@ -110,9 +113,10 @@ void Model::InitModel(Config& config)
decoder->InitModel(config); decoder->InitModel(config);
/* load parameters */ /* load parameters */
if (!config.isTraining) if (!config.isTraining || strcmp(config.pretrainedModel, "") != 0)
Read(modelFile); Read(modelFile);
else {
if (config.isTraining) {
TensorList params; TensorList params;
GetParams(params); GetParams(params);
for (int i = 0; i < params.Size(); i++) for (int i = 0; i < params.Size(); i++)
...@@ -220,6 +224,8 @@ void Model::MakeMT(XTensor& inputEnc, XTensor& inputDec, XTensor& output, ...@@ -220,6 +224,8 @@ void Model::MakeMT(XTensor& inputEnc, XTensor& inputDec, XTensor& output,
XTensor maskDec; XTensor maskDec;
XTensor maskEncDec; XTensor maskEncDec;
bool debug(false);
/* encoder mask */ /* encoder mask */
MakeMTMaskEnc(paddingEnc, maskEnc); MakeMTMaskEnc(paddingEnc, maskEnc);
...@@ -228,9 +234,25 @@ void Model::MakeMT(XTensor& inputEnc, XTensor& inputDec, XTensor& output, ...@@ -228,9 +234,25 @@ void Model::MakeMT(XTensor& inputEnc, XTensor& inputDec, XTensor& output,
encoding = MakeEncoder(inputEnc, &maskEnc, isTraining); encoding = MakeEncoder(inputEnc, &maskEnc, isTraining);
if (debug) {
LOG("after encoding:");
encoding.mem->ShowMemUsage(stderr);
}
decoding = MakeDecoder(inputDec, encoding, &maskDec, maskEncDec, isTraining); decoding = MakeDecoder(inputDec, encoding, &maskDec, maskEncDec, isTraining);
if (debug) {
LOG("after decoding:");
encoding.mem->ShowMemUsage(stderr);
}
outputLayer->Make(decoding, output, true, true); outputLayer->Make(decoding, output, true, true);
if (debug) {
LOG("after outputing:");
encoding.mem->ShowMemUsage(stderr);
exit(0);
}
} }
/* /*
...@@ -265,9 +287,9 @@ void Model::MakeMTMask(XTensor& inputEnc, XTensor& inputDec, ...@@ -265,9 +287,9 @@ void Model::MakeMTMask(XTensor& inputEnc, XTensor& inputDec,
dims[inputDec.order + 1] = inputEnc.GetDim(inputEnc.order - 1); dims[inputDec.order + 1] = inputEnc.GetDim(inputEnc.order - 1);
InitTensor(&maskEncDec, inputDec.order + 2, dims, X_FLOAT, paddingEnc.devID); InitTensor(&maskEncDec, inputDec.order + 2, dims, X_FLOAT, paddingEnc.devID);
XTensor* maskEncDecTMPEnc = NewTensorBuf(paddingEnc.order + 1, dims + 1, XTensor* maskEncDecTMPEnc = NewTensorBufV2(paddingEnc.order + 1, dims + 1,
paddingEnc.dataType, paddingEnc.devID); paddingEnc.dataType, 1.0F, paddingEnc.devID, paddingEnc.mem);
XTensor* maskEncDecTMPDec = NewTensorBuf(maskEncDecTMPEnc, paddingEnc.devID); XTensor* maskEncDecTMPDec = NewTensorBufV2(maskEncDecTMPEnc, paddingEnc.devID, paddingEnc.mem);
_Unsqueeze(&paddingEnc, maskEncDecTMPEnc, paddingEnc.order - 1, paddingDec.GetDim(-1)); _Unsqueeze(&paddingEnc, maskEncDecTMPEnc, paddingEnc.order - 1, paddingDec.GetDim(-1));
_ScaleAndShiftMe(maskEncDecTMPEnc, 1e9F, -1e9F); _ScaleAndShiftMe(maskEncDecTMPEnc, 1e9F, -1e9F);
...@@ -283,14 +305,14 @@ void Model::MakeMTMask(XTensor& inputEnc, XTensor& inputDec, ...@@ -283,14 +305,14 @@ void Model::MakeMTMask(XTensor& inputEnc, XTensor& inputDec,
dimsPadding[paddingEnc.order - 1] = paddingEnc.GetDim(-1); dimsPadding[paddingEnc.order - 1] = paddingEnc.GetDim(-1);
dimsPadding[paddingEnc.order] = paddingEnc.GetDim(-1); dimsPadding[paddingEnc.order] = paddingEnc.GetDim(-1);
XTensor* padding2 = NewTensorBuf(paddingEnc.order + 1, dimsPadding, paddingEnc.dataType, XTensor* padding2 = NewTensorBufV2(paddingEnc.order + 1, dimsPadding, paddingEnc.dataType, 1.0F,
paddingEnc.devID); paddingEnc.devID, paddingEnc.mem);
for (int i = 0; i < padding2->order; i++) for (int i = 0; i < padding2->order; i++)
dimsPadding[i + 1] = padding2->GetDim(i); dimsPadding[i + 1] = padding2->GetDim(i);
dimsPadding[0] = nhead; dimsPadding[0] = nhead;
XTensor* padding3 = NewTensorBuf(paddingEnc.order + 2, dimsPadding, paddingEnc.dataType, paddingEnc.devID); XTensor* padding3 = NewTensorBufV2(paddingEnc.order + 2, dimsPadding, paddingEnc.dataType, 1.0F, paddingEnc.devID, paddingEnc.mem);
/* mask of the padding */ /* mask of the padding */
_Unsqueeze(&paddingEnc, padding2, paddingEnc.order - 1, paddingEnc.GetDim(-1)); _Unsqueeze(&paddingEnc, padding2, paddingEnc.order - 1, paddingEnc.GetDim(-1));
...@@ -322,6 +344,7 @@ void Model::MakeMTMaskEnc(XTensor& paddingEnc, XTensor& maskEnc) ...@@ -322,6 +344,7 @@ void Model::MakeMTMaskEnc(XTensor& paddingEnc, XTensor& maskEnc)
/* mask of the padding */ /* mask of the padding */
Unsqueeze(paddingEnc, padding2, paddingEnc.order - 1, paddingEnc.GetDim(-1)); Unsqueeze(paddingEnc, padding2, paddingEnc.order - 1, paddingEnc.GetDim(-1));
Unsqueeze(padding2, maskEnc, 0, nhead); Unsqueeze(padding2, maskEnc, 0, nhead);
ScaleAndShiftMe(maskEnc, 1e9F, -1e9F); ScaleAndShiftMe(maskEnc, 1e9F, -1e9F);
} }
...@@ -355,6 +378,7 @@ void Model::MakeMTMaskDec(XTensor& paddingEnc, XTensor& paddingDec, ...@@ -355,6 +378,7 @@ void Model::MakeMTMaskDec(XTensor& paddingEnc, XTensor& paddingDec,
Unsqueeze(paddingEnc, maskEncDecTMP, paddingEnc.order - 1, paddingDec.GetDim(-1)); Unsqueeze(paddingEnc, maskEncDecTMP, paddingEnc.order - 1, paddingDec.GetDim(-1));
ScaleAndShiftMe(maskEncDecTMP, 1e9F, -1e9F); ScaleAndShiftMe(maskEncDecTMP, 1e9F, -1e9F);
Unsqueeze(maskEncDecTMP, maskEncDec, 0, dims[0]); Unsqueeze(maskEncDecTMP, maskEncDec, 0, dims[0]);
delete[] dims; delete[] dims;
...@@ -369,6 +393,14 @@ void Model::GetParams(TensorList& list) ...@@ -369,6 +393,14 @@ void Model::GetParams(TensorList& list)
list.Clear(); list.Clear();
/* encoder parameters */ /* encoder parameters */
if (encoder->useHistory) {
for (int i = 0; i < encoder->nlayer + 1; i++)
list.Add(&encoder->history->weights[i]);
for (int i = 0; i < encoder->nlayer; i++) {
list.Add(&encoder->history->layerNorms[i].weight);
list.Add(&encoder->history->layerNorms[i].bias);
}
}
for (int i = 0; i < encoder->nlayer; i++) { for (int i = 0; i < encoder->nlayer; i++) {
list.Add(&encoder->selfAtt[i].weightQ); list.Add(&encoder->selfAtt[i].weightQ);
list.Add(&encoder->selfAtt[i].weightK); list.Add(&encoder->selfAtt[i].weightK);
...@@ -384,18 +416,27 @@ void Model::GetParams(TensorList& list) ...@@ -384,18 +416,27 @@ void Model::GetParams(TensorList& list)
list.Add(&encoder->fnns[i].b1); list.Add(&encoder->fnns[i].b1);
list.Add(&encoder->fnns[i].w2); list.Add(&encoder->fnns[i].w2);
list.Add(&encoder->fnns[i].b2); list.Add(&encoder->fnns[i].b2);
list.Add(&encoder->attLayerNorms[i].w); list.Add(&encoder->attLayerNorms[i].weight);
list.Add(&encoder->attLayerNorms[i].b); list.Add(&encoder->attLayerNorms[i].bias);
list.Add(&encoder->fnnLayerNorms[i].w); list.Add(&encoder->fnnLayerNorms[i].weight);
list.Add(&encoder->fnnLayerNorms[i].b); list.Add(&encoder->fnnLayerNorms[i].bias);
} }
if (encoder->preNorm) { if (encoder->finalNorm) {
list.Add(&encoder->encoderLayerNorm->w); list.Add(&encoder->encoderLayerNorm->weight);
list.Add(&encoder->encoderLayerNorm->b); list.Add(&encoder->encoderLayerNorm->bias);
} }
if (isMT) { if (isMT) {
/* decoder parameters */ /* decoder parameters */
if (decoder->useHistory) {
for (int i = 0; i < decoder->nlayer + 1; i++)
list.Add(&decoder->history->weights[i]);
for (int i = 0; i < decoder->nlayer; i++) {
list.Add(&decoder->history->layerNorms[i].weight);
list.Add(&decoder->history->layerNorms[i].bias);
}
}
for (int i = 0; i < decoder->nlayer; i++) { for (int i = 0; i < decoder->nlayer; i++) {
list.Add(&decoder->selfAtt[i].weightQ); list.Add(&decoder->selfAtt[i].weightQ);
list.Add(&decoder->selfAtt[i].weightK); list.Add(&decoder->selfAtt[i].weightK);
...@@ -407,8 +448,8 @@ void Model::GetParams(TensorList& list) ...@@ -407,8 +448,8 @@ void Model::GetParams(TensorList& list)
list.Add(&decoder->selfAtt[i].RPEmbK); list.Add(&decoder->selfAtt[i].RPEmbK);
list.Add(&decoder->selfAtt[i].weightO); list.Add(&decoder->selfAtt[i].weightO);
list.Add(&decoder->selfAtt[i].biasO); list.Add(&decoder->selfAtt[i].biasO);
list.Add(&decoder->selfAttLayerNorms[i].w); list.Add(&decoder->selfAttLayerNorms[i].weight);
list.Add(&decoder->selfAttLayerNorms[i].b); list.Add(&decoder->selfAttLayerNorms[i].bias);
list.Add(&decoder->enDeAtt[i].weightQ); list.Add(&decoder->enDeAtt[i].weightQ);
list.Add(&decoder->enDeAtt[i].weightK); list.Add(&decoder->enDeAtt[i].weightK);
list.Add(&decoder->enDeAtt[i].weightV); list.Add(&decoder->enDeAtt[i].weightV);
...@@ -417,18 +458,18 @@ void Model::GetParams(TensorList& list) ...@@ -417,18 +458,18 @@ void Model::GetParams(TensorList& list)
list.Add(&decoder->enDeAtt[i].biasV); list.Add(&decoder->enDeAtt[i].biasV);
list.Add(&decoder->enDeAtt[i].weightO); list.Add(&decoder->enDeAtt[i].weightO);
list.Add(&decoder->enDeAtt[i].biasO); list.Add(&decoder->enDeAtt[i].biasO);
list.Add(&decoder->enDeAttLayerNorms[i].w); list.Add(&decoder->enDeAttLayerNorms[i].weight);
list.Add(&decoder->enDeAttLayerNorms[i].b); list.Add(&decoder->enDeAttLayerNorms[i].bias);
list.Add(&decoder->fnns[i].w1); list.Add(&decoder->fnns[i].w1);
list.Add(&decoder->fnns[i].b1); list.Add(&decoder->fnns[i].b1);
list.Add(&decoder->fnns[i].w2); list.Add(&decoder->fnns[i].w2);
list.Add(&decoder->fnns[i].b2); list.Add(&decoder->fnns[i].b2);
list.Add(&decoder->fnnLayerNorms[i].w); list.Add(&decoder->fnnLayerNorms[i].weight);
list.Add(&decoder->fnnLayerNorms[i].b); list.Add(&decoder->fnnLayerNorms[i].bias);
} }
if (decoder->preNorm) { if (decoder->finalNorm) {
list.Add(&decoder->decoderLayerNorm->w); list.Add(&decoder->decoderLayerNorm->weight);
list.Add(&decoder->decoderLayerNorm->b); list.Add(&decoder->decoderLayerNorm->bias);
} }
} }
...@@ -490,7 +531,7 @@ void Model::Read(FILE* file) ...@@ -490,7 +531,7 @@ void Model::Read(FILE* file)
TensorList params; TensorList params;
GetParams(params); GetParams(params);
LOG("params count: %lu", (unsigned long)params.Size()); LOG("params count: %zd", params.Size());
int size = 0; int size = 0;
for (int i = 0; i < params.Size(); i++) { for (int i = 0; i < params.Size(); i++) {
size += params[i]->unitNum; size += params[i]->unitNum;
......
/* NiuTrans.NMT - an open-source neural machine translation system. /* NiuTrans.Tensor - an open-source tensor library
* Copyright (C) 2020 NiuTrans Research. All rights reserved. * Copyright (C) 2020 NiuTrans Research. All rights reserved.
* *
* Licensed under the Apache License, Version 2.0 (the "License"); * Licensed under the Apache License, Version 2.0 (the "License");
......
/* NiuTrans.NMT - an open-source neural machine translation system. /* NiuTrans.Tensor - an open-source tensor library
* Copyright (C) 2020 NiuTrans Research. All rights reserved. * Copyright (C) 2020 NiuTrans Research. All rights reserved.
* *
* Licensed under the Apache License, Version 2.0 (the "License"); * Licensed under the Apache License, Version 2.0 (the "License");
...@@ -36,22 +36,31 @@ int NMTMain(int argc, const char** argv) ...@@ -36,22 +36,31 @@ int NMTMain(int argc, const char** argv)
/* load configurations */ /* load configurations */
Config config(argc, argv); Config config(argc, argv);
srand(1); srand(config.seed);
/* training */ /* training */
if (strcmp(config.trainFN, "") != 0) { if (strcmp(config.trainFN, "") != 0) {
Model model; Model model;
model.InitModel(config); model.InitModel(config);
TensorList params;
model.GetParams(params);
int count = 0;
for (int i = 0; i < params.count; i++){
count += params[i]->unitNum;
}
LOG("number of parameters: %d", count);
Trainer trainer; Trainer trainer;
trainer.Init(config); trainer.Init(config);
trainer.Train(config.trainFN, config.validFN, config.modelFN, &model); trainer.Train(config.trainFN, config.validFN, config.modelFN, &model);
} }
/* translating */ /* translating */
if (strcmp(config.testFN, "") != 0 && strcmp(config.outputFN, "") != 0) { else if (strcmp(config.testFN, "") != 0 && strcmp(config.outputFN, "") != 0) {
/* disable grad flow */ /* disable gradient flow */
DISABLE_GRAD; DISABLE_GRAD;
Model model; Model model;
......
/* NiuTrans.NMT - an open-source neural machine translation system. /* NiuTrans.Tensor - an open-source tensor library
* Copyright (C) 2020 NiuTrans Research. All rights reserved. * Copyright (C) 2020 NiuTrans Research. All rights reserved.
* *
* Licensed under the Apache License, Version 2.0 (the "License"); * Licensed under the Apache License, Version 2.0 (the "License");
......
/* NiuTrans.NMT - an open-source neural machine translation system. /* NiuTrans.Tensor - an open-source tensor library
* Copyright (C) 2020 NiuTrans Research. All rights reserved. * Copyright (C) 2020 NiuTrans Research. All rights reserved.
* *
* Licensed under the Apache License, Version 2.0 (the "License"); * Licensed under the Apache License, Version 2.0 (the "License");
...@@ -63,18 +63,23 @@ Config::Config(int argc, const char** argv) ...@@ -63,18 +63,23 @@ Config::Config(int argc, const char** argv)
LoadParamInt(argsNum, args, "nhead", &nhead, 4); LoadParamInt(argsNum, args, "nhead", &nhead, 4);
LoadParamInt(argsNum, args, "enclayer", &nEncLayer, 6); LoadParamInt(argsNum, args, "enclayer", &nEncLayer, 6);
LoadParamInt(argsNum, args, "declayer", &nDecLayer, 6); LoadParamInt(argsNum, args, "declayer", &nDecLayer, 6);
LoadParamInt(argsNum, args, "maxrp", &maxRP, 8); LoadParamInt(argsNum, args, "maxrp", &maxRP, -1);
LoadParamInt(argsNum, args, "embsize", &embSize, 512); LoadParamInt(argsNum, args, "embsize", &embSize, 512);
LoadParamInt(argsNum, args, "modelsize", &modelSize, 512); LoadParamInt(argsNum, args, "modelsize", &modelSize, 512);
LoadParamInt(argsNum, args, "maxpos", &maxPosLen, 1024); LoadParamInt(argsNum, args, "maxpos", &maxPosition, 1024);
LoadParamInt(argsNum, args, "maxsrclen", &maxSrcLen, 1024);
LoadParamInt(argsNum, args, "maxtgtlen", &maxTgtLen, 1024);
LoadParamInt(argsNum, args, "fnnhidden", &fnnHiddenSize, modelSize * 2); LoadParamInt(argsNum, args, "fnnhidden", &fnnHiddenSize, modelSize * 2);
LoadParamInt(argsNum, args, "vsize", &srcVocabSize, 10152); LoadParamInt(argsNum, args, "vsize", &srcVocabSize, 10152);
LoadParamInt(argsNum, args, "vsizetgt", &tgtVocabSize, 10152); LoadParamInt(argsNum, args, "vsizetgt", &tgtVocabSize, 10152);
LoadParamInt(argsNum, args, "padid", &padID, 1); LoadParamInt(argsNum, args, "padid", &padID, 1);
LoadParamInt(argsNum, args, "startid", &startID, 2); LoadParamInt(argsNum, args, "startid", &startID, 2);
LoadParamInt(argsNum, args, "endid", &endID, 2); LoadParamInt(argsNum, args, "endid", &endID, 2);
LoadParamBool(argsNum, args, "rpr", &useRPR, false); LoadParamInt(argsNum, args, "unkid", &unkID, 3);
LoadParamBool(argsNum, args, "rpr", &useRPR, maxRP > 0);
LoadParamBool(argsNum, args, "prenorm", &preNorm, true); LoadParamBool(argsNum, args, "prenorm", &preNorm, true);
LoadParamBool(argsNum, args, "finalnorm", &finalNorm, true);
LoadParamBool(argsNum, args, "dlcl", &useHistory, false);
// TODO: refactor the parameters type to support weight sharing during training // TODO: refactor the parameters type to support weight sharing during training
LoadParamInt(argsNum, args, "shareemb", &shareAllEmbeddings, 0); LoadParamInt(argsNum, args, "shareemb", &shareAllEmbeddings, 0);
...@@ -86,9 +91,12 @@ Config::Config(int argc, const char** argv) ...@@ -86,9 +91,12 @@ Config::Config(int argc, const char** argv)
/* options for training */ /* options for training */
LoadParamString(argsNum, args, "train", trainFN, ""); LoadParamString(argsNum, args, "train", trainFN, "");
LoadParamString(argsNum, args, "valid", validFN, ""); LoadParamString(argsNum, args, "valid", validFN, "");
LoadParamString(argsNum, args, "pretrain", pretrainedModel, "");
LoadParamInt(argsNum, args, "dev", &devID, 0); LoadParamInt(argsNum, args, "dev", &devID, 0);
LoadParamInt(argsNum, args, "seed", &seed, 1);
LoadParamInt(argsNum, args, "log", &logInterval, 100);
LoadParamInt(argsNum, args, "wbatch", &wBatchSize, 4096); LoadParamInt(argsNum, args, "wbatch", &wBatchSize, 4096);
LoadParamInt(argsNum, args, "sbatch", &sBatchSize, 8); LoadParamInt(argsNum, args, "sbatch", &sBatchSize, 16);
isTraining = (strcmp(trainFN, "") == 0) ? false : true; isTraining = (strcmp(trainFN, "") == 0) ? false : true;
LoadParamBool(argsNum, args, "mt", &isMT, true); LoadParamBool(argsNum, args, "mt", &isMT, true);
LoadParamFloat(argsNum, args, "dropout", &dropout, 0.3F); LoadParamFloat(argsNum, args, "dropout", &dropout, 0.3F);
...@@ -117,7 +125,7 @@ Config::Config(int argc, const char** argv) ...@@ -117,7 +125,7 @@ Config::Config(int argc, const char** argv)
LoadParamBool(argc, args, "smallbatch", &isSmallBatch, true); LoadParamBool(argc, args, "smallbatch", &isSmallBatch, true);
LoadParamBool(argc, args, "bigbatch", &isBigBatch, false); LoadParamBool(argc, args, "bigbatch", &isBigBatch, false);
LoadParamBool(argc, args, "randbatch", &isRandomBatch, false); LoadParamBool(argc, args, "randbatch", &isRandomBatch, false);
LoadParamInt(argc, args, "bucketsize", &bucketSize, wBatchSize * 10); LoadParamInt(argc, args, "bucketsize", &bucketSize, wBatchSize * 1);
/* options for translating */ /* options for translating */
LoadParamString(argsNum, args, "test", testFN, ""); LoadParamString(argsNum, args, "test", testFN, "");
...@@ -241,8 +249,6 @@ void ShowParams(int argc, char** argv) ...@@ -241,8 +249,6 @@ void ShowParams(int argc, char** argv)
fprintf(stderr, "\n"); fprintf(stderr, "\n");
} }
#define MAX_WORD_NUM 120
/* /*
split string by delimiter, this will return indices of all sub-strings split string by delimiter, this will return indices of all sub-strings
>> s - the original string >> s - the original string
......
/* NiuTrans.NMT - an open-source neural machine translation system. /* NiuTrans.Tensor - an open-source tensor library
* Copyright (C) 2020 NiuTrans Research. All rights reserved. * Copyright (C) 2020 NiuTrans Research. All rights reserved.
* *
* Licensed under the Apache License, Version 2.0 (the "License"); * Licensed under the Apache License, Version 2.0 (the "License");
...@@ -73,9 +73,18 @@ public: ...@@ -73,9 +73,18 @@ public:
/* path to the validation file */ /* path to the validation file */
char validFN[1024]; char validFN[1024];
/* path to the pre-trained model */
char pretrainedModel[1024];
/* device id */ /* device id */
int devID; int devID;
/* random seed */
int seed;
/* interval step for logging */
int logInterval;
/* beam size */ /* beam size */
int beamSize; int beamSize;
...@@ -104,7 +113,13 @@ public: ...@@ -104,7 +113,13 @@ public:
int modelSize; int modelSize;
/* the maximum length in positional embedding */ /* the maximum length in positional embedding */
int maxPosLen; int maxPosition;
/* the maximum length for the source sequence */
int maxSrcLen;
/* the maximum length for the target sequence */
int maxTgtLen;
/* the dimension of fnn hidden layer */ /* the dimension of fnn hidden layer */
int fnnHiddenSize; int fnnHiddenSize;
...@@ -118,6 +133,9 @@ public: ...@@ -118,6 +133,9 @@ public:
/* the padding id */ /* the padding id */
int padID; int padID;
/* the unk id */
int unkID;
/* start symbol */ /* start symbol */
int startID; int startID;
...@@ -127,6 +145,12 @@ public: ...@@ -127,6 +145,12 @@ public:
/* indicates whether the model uses pre-norm */ /* indicates whether the model uses pre-norm */
bool preNorm; bool preNorm;
/* add LN to the encoder/decoder output or not */
bool finalNorm;
/* reserve history for encoder/decoder layers or not */
bool useHistory;
/* indicates whether the model is running for machine translation */ /* indicates whether the model is running for machine translation */
bool isMT; bool isMT;
...@@ -139,10 +163,10 @@ public: ...@@ -139,10 +163,10 @@ public:
/* indicates whether the model is running with FP16 data type */ /* indicates whether the model is running with FP16 data type */
bool useFP16; bool useFP16;
/* indicates whether we use the RPR attention */ /* use the RPR attention or not */
bool useRPR; bool useRPR;
/* indicates whether we train the model */ /* train the model or not */
bool isTraining; bool isTraining;
/* dropout rate for the model */ /* dropout rate for the model */
......
/* NiuTrans.NMT - an open-source neural machine translation system. /* NiuTrans.Tensor - an open-source tensor library
* Copyright (C) 2020 NiuTrans Research. All rights reserved. * Copyright (C) 2020 NiuTrans Research. All rights reserved.
* *
* Licensed under the Apache License, Version 2.0 (the "License"); * Licensed under the Apache License, Version 2.0 (the "License");
...@@ -29,10 +29,14 @@ namespace nmt ...@@ -29,10 +29,14 @@ namespace nmt
/* constructor */ /* constructor */
Attention::Attention() Attention::Attention()
{ {
devID = -1;
nhead = -1; nhead = -1;
dk = -1; dk = -1;
dv = -1; dv = -1;
d = -1; d = -1;
dropoutP = 0.0;
maxRP = -1;
useRPR = false;
} }
/* de-constructor */ /* de-constructor */
...@@ -82,17 +86,17 @@ void Attention::InitModel(Config& config) ...@@ -82,17 +86,17 @@ void Attention::InitModel(Config& config)
biasQ.SetZeroAll(); biasQ.SetZeroAll();
biasO.SetZeroAll(); biasO.SetZeroAll();
biasK.SetDataRand(-(DTYPE)sqrt(6.0F / d), (DTYPE)sqrt(6.0F / d)); biasK.SetDataRandn(-(DTYPE)sqrt(6.0F / d), (DTYPE)sqrt(6.0F / d));
biasV.SetDataRand(-(DTYPE)sqrt(6.0F / d), (DTYPE)sqrt(6.0F / d)); biasV.SetDataRandn(-(DTYPE)sqrt(6.0F / d), (DTYPE)sqrt(6.0F / d));
} }
/* /*
make the network make the network
>> k - keys, B * L * H for encoders, B * 1 * H for decoders >> k - keys, B * L * H
where B = batch size, L = sequence length, where B = batch size, L = sequence length,
and H = vector size of each position and H = vector size of each position
>> q - queries, B * L * H >> q - queries, B * L * H for encoders, B * 1 * H for decoders during inference
>> v - values, B * L * H for encoders, B * 1 * H for decoders >> v - values, B * L * H
>> mask - as it is >> mask - as it is
>> isTraining - indicates whether the model is used for training >> isTraining - indicates whether the model is used for training
>> cache - decoder cache >> cache - decoder cache
...@@ -188,9 +192,9 @@ XTensor Attention::MakeAttention(XTensor& k, XTensor& q, XTensor& v, ...@@ -188,9 +192,9 @@ XTensor Attention::MakeAttention(XTensor& k, XTensor& q, XTensor& v,
dot = BMMul(qheads, X_NOTRANS, kheads, X_TRANS); dot = BMMul(qheads, X_NOTRANS, kheads, X_TRANS);
if (mask) if (mask)
dot = dot + *mask; dot = Sum(dot, *mask, /*inplace=*/true);
dot = Linear(dot, 1.0F / (float)sqrt((float)dk / nhead)); dot = Linear(dot, 1.0F / (float)sqrt((float)dk / nhead), 0.0F, true);
scalar = Softmax(dot, -1); scalar = Softmax(dot, -1);
...@@ -244,7 +248,7 @@ XTensor Attention::MakeRPRAttention(XTensor& k, XTensor& q, XTensor& v, ...@@ -244,7 +248,7 @@ XTensor Attention::MakeRPRAttention(XTensor& k, XTensor& q, XTensor& v,
XTensor embMatrix, relativeKey; XTensor embMatrix, relativeKey;
/* generate the relative emb index (L_q, L_kv) */ /* generate the relative emb index (L_q, L_kv) */
embMatrix = GetRPEmbedding(lenQ, lenKV, maxRP, isEnc || isTraining); embMatrix = GetRPEmbedding(lenQ, lenKV, maxRP, isEnc || isTraining, isTraining);
/* generate the relative key from the RPEmbK (L_q, L_kv, H/K) */ /* generate the relative key from the RPEmbK (L_q, L_kv, H/K) */
relativeKey = Gather(RPEmbK, embMatrix); relativeKey = Gather(RPEmbK, embMatrix);
...@@ -255,13 +259,13 @@ XTensor Attention::MakeRPRAttention(XTensor& k, XTensor& q, XTensor& v, ...@@ -255,13 +259,13 @@ XTensor Attention::MakeRPRAttention(XTensor& k, XTensor& q, XTensor& v,
relativeKey = ConvertDataType(relativeKey, X_FLOAT); relativeKey = ConvertDataType(relativeKey, X_FLOAT);
} }
float scaling = (float)sqrt(d / nhead); float scaling = float(sqrt(d / nhead));
qheads = ScaleAndShift(qheads, 1.0F / scaling); qheads = ScaleAndShift(qheads, 1.0F / scaling);
dot = RPDotProduct(qheads, kheads, relativeKey, true); dot = RPDotProduct(qheads, kheads, relativeKey, true);
if (mask) if (mask)
dot = dot + *mask; dot = Sum(dot, *mask, /*inplace=*/true);
/* softmax */ /* softmax */
scalar = Softmax(dot, -1); scalar = Softmax(dot, -1);
...@@ -287,12 +291,14 @@ generate relative position embeddings ...@@ -287,12 +291,14 @@ generate relative position embeddings
>> lenQ - the length of query >> lenQ - the length of query
>> lenKV - the length of key and value >> lenKV - the length of key and value
>> maxRelativeLen - the maximum length of relative position >> maxRelativeLen - the maximum length of relative position
>> isEnc - indicates whether it is in the encoder
*/ */
XTensor Attention::GetRPEmbedding(const int lenQ, const int lenKV, XTensor Attention::GetRPEmbedding(int lenQ, int lenKV,
const int maxRelativeLen, const bool isEnc) int maxRelativeLen, bool isEnc, bool isTraining)
{ {
XTensor range; XTensor range;
XTensor embMatrix; XTensor embMatrix;
InitTensor1D(&range, lenKV, X_INT, devID); InitTensor1D(&range, lenKV, X_INT, devID);
int* index = new int[lenKV]; int* index = new int[lenKV];
...@@ -313,11 +319,19 @@ XTensor Attention::GetRPEmbedding(const int lenQ, const int lenKV, ...@@ -313,11 +319,19 @@ XTensor Attention::GetRPEmbedding(const int lenQ, const int lenKV,
embMatrix = Unsqueeze(range, 0, lenQ); embMatrix = Unsqueeze(range, 0, lenQ);
} }
//ClipMe(embMatrix, -float(maxRelativeLen), float(maxRelativeLen)); ClipMe(embMatrix, -float(maxRelativeLen), float(maxRelativeLen));
embMatrix = Clip(embMatrix, -float(maxRelativeLen), float(maxRelativeLen)); ScaleAndShiftMe(embMatrix, 1.0F, float(maxRelativeLen));
embMatrix = ScaleAndShift(embMatrix, 1.0F, float(maxRelativeLen));
delete[] index; delete[] index;
/* disable gradient flow */
if (isTraining) {
XTensor copyEmbMatrix;
InitTensor(&copyEmbMatrix, &embMatrix);
_CopyValues(&embMatrix, &copyEmbMatrix);
return copyEmbMatrix;
}
return embMatrix; return embMatrix;
} }
...@@ -351,6 +365,7 @@ XTensor Attention::RPDotProduct(XTensor& x, XTensor& y, XTensor& z, const bool i ...@@ -351,6 +365,7 @@ XTensor Attention::RPDotProduct(XTensor& x, XTensor& y, XTensor& z, const bool i
XTensor context; XTensor context;
context = BMMul(x, y); context = BMMul(x, y);
int newDims[]{ headNum, batchSize, context.GetDim(1), context.GetDim(2) }; int newDims[]{ headNum, batchSize, context.GetDim(1), context.GetDim(2) };
context = Reshape(context, 4, newDims); context = Reshape(context, 4, newDims);
...@@ -358,7 +373,7 @@ XTensor Attention::RPDotProduct(XTensor& x, XTensor& y, XTensor& z, const bool i ...@@ -358,7 +373,7 @@ XTensor Attention::RPDotProduct(XTensor& x, XTensor& y, XTensor& z, const bool i
xTrans = Transpose(x, 0, 1); xTrans = Transpose(x, 0, 1);
XTensor relative; XTensor relative;
relative = MatrixMulBatched(xTrans, X_NOTRANS, z, transposeFlag); relative = BMMul(xTrans, X_NOTRANS, z, transposeFlag);
XTensor relativeTrans; XTensor relativeTrans;
relativeTrans = Transpose(relative, 0, 1); relativeTrans = Transpose(relative, 0, 1);
...@@ -367,7 +382,7 @@ XTensor Attention::RPDotProduct(XTensor& x, XTensor& y, XTensor& z, const bool i ...@@ -367,7 +382,7 @@ XTensor Attention::RPDotProduct(XTensor& x, XTensor& y, XTensor& z, const bool i
relativeTrans = Reshape(relativeTrans, 4, splitDims); relativeTrans = Reshape(relativeTrans, 4, splitDims);
return context + relativeTrans; return Sum(context, relativeTrans);
} }
/* constructor */ /* constructor */
......
/* NiuTrans.NMT - an open-source neural machine translation system. /* NiuTrans.Tensor - an open-source tensor library
* Copyright (C) 2020 NiuTrans Research. All rights reserved. * Copyright (C) 2020 NiuTrans Research. All rights reserved.
* *
* Licensed under the Apache License, Version 2.0 (the "License"); * Licensed under the Apache License, Version 2.0 (the "License");
...@@ -93,10 +93,6 @@ public: ...@@ -93,10 +93,6 @@ public:
/* bias for V */ /* bias for V */
XTensor biasV; XTensor biasV;
XTensor wBig;
XTensor bBig;
/* RPR emb */ /* RPR emb */
XTensor RPEmbK; XTensor RPEmbK;
...@@ -148,7 +144,7 @@ public: ...@@ -148,7 +144,7 @@ public:
XTensor* mask, bool isTraining, bool isEnc); XTensor* mask, bool isTraining, bool isEnc);
/* generate relative position embeddings */ /* generate relative position embeddings */
XTensor GetRPEmbedding(const int lenQ, const int lenKV, const int maxRelativeLen, const bool isEnc); XTensor GetRPEmbedding(int lenQ, int lenKV, int maxRelativeLen, bool isEnc, bool isTraining);
/* relative position-aware dot-product attention inner calculation */ /* relative position-aware dot-product attention inner calculation */
XTensor RPDotProduct(XTensor& x, XTensor& y, XTensor& z, const bool is_key); XTensor RPDotProduct(XTensor& x, XTensor& y, XTensor& z, const bool is_key);
......
/* NiuTrans.NMT - an open-source neural machine translation system. /* NiuTrans.Tensor - an open-source tensor library
* Copyright (C) 2020 NiuTrans Research. All rights reserved. * Copyright (C) 2020 NiuTrans Research. All rights reserved.
* *
* Licensed under the Apache License, Version 2.0 (the "License"); * Licensed under the Apache License, Version 2.0 (the "License");
......
/* NiuTrans.NMT - an open-source neural machine translation system. /* NiuTrans.Tensor - an open-source tensor library
* Copyright (C) 2020 NiuTrans Research. All rights reserved. * Copyright (C) 2020 NiuTrans Research. All rights reserved.
* *
* Licensed under the Apache License, Version 2.0 (the "License"); * Licensed under the Apache License, Version 2.0 (the "License");
......
/* NiuTrans.NMT - an open-source neural machine translation system. /* NiuTrans.Tensor - an open-source tensor library
* Copyright (C) 2020 NiuTrans Research. All rights reserved. * Copyright (C) 2020 NiuTrans Research. All rights reserved.
* *
* Licensed under the Apache License, Version 2.0 (the "License"); * Licensed under the Apache License, Version 2.0 (the "License");
...@@ -29,8 +29,10 @@ namespace nmt ...@@ -29,8 +29,10 @@ namespace nmt
/* constructor */ /* constructor */
Embedder::Embedder() Embedder::Embedder()
{ {
d = -1;
devID = -1; devID = -1;
vSize = -1; vSize = -1;
eSize = -1;
maxLength = -1; maxLength = -1;
} }
...@@ -50,7 +52,7 @@ void Embedder::InitModel(Config& config, bool isEnc) ...@@ -50,7 +52,7 @@ void Embedder::InitModel(Config& config, bool isEnc)
d = config.modelSize; d = config.modelSize;
padIdx = config.padID; padIdx = config.padID;
eSize = config.embSize; eSize = config.embSize;
maxLength = config.maxPosLen; maxLength = config.maxPosition;
vSize = (isEnc) ? config.srcVocabSize : config.tgtVocabSize; vSize = (isEnc) ? config.srcVocabSize : config.tgtVocabSize;
InitTensor2D(&w, vSize, eSize, X_FLOAT, devID); InitTensor2D(&w, vSize, eSize, X_FLOAT, devID);
...@@ -59,6 +61,10 @@ void Embedder::InitModel(Config& config, bool isEnc) ...@@ -59,6 +61,10 @@ void Embedder::InitModel(Config& config, bool isEnc)
DTYPE v = 1.0F / (float)sqrt((float)eSize); DTYPE v = 1.0F / (float)sqrt((float)eSize);
w.SetDataRandn(0, v); w.SetDataRandn(0, v);
for (int i = 0; i < eSize; i++) {
w.Set2D(0.0F, padIdx, i);
}
/* create the positional embedding matrix */ /* create the positional embedding matrix */
MakePosEmbedding(maxLength); MakePosEmbedding(maxLength);
} }
...@@ -138,13 +144,13 @@ XTensor Embedder::Make(XTensor& input, bool isDec, bool isTraining, int nstep) ...@@ -138,13 +144,13 @@ XTensor Embedder::Make(XTensor& input, bool isDec, bool isTraining, int nstep)
posEmbedding = Unsqueeze(embTMP, 0, input.GetDim(0)); posEmbedding = Unsqueeze(embTMP, 0, input.GetDim(0));
/* then we make word embeddings */ /* then we make word embeddings */
//w.enableGrad = false;
wordEmbedding = Gather(w, input); wordEmbedding = Gather(w, input);
wordEmbedding = Linear(wordEmbedding, (float)sqrt((float)eSize)); wordEmbedding = Linear(wordEmbedding, (float)sqrt((float)eSize), 0.0F, true);
/* we sum over the two embeddings */ /* we sum over the two embeddings */
SumMe(wordEmbedding, posEmbedding); SumMe(wordEmbedding, posEmbedding);
return wordEmbedding; return wordEmbedding;
} }
......
/* NiuTrans.NMT - an open-source neural machine translation system. /* NiuTrans.Tensor - an open-source tensor library
* Copyright (C) 2020 NiuTrans Research. All rights reserved. * Copyright (C) 2020 NiuTrans Research. All rights reserved.
* *
* Licensed under the Apache License, Version 2.0 (the "License"); * Licensed under the Apache License, Version 2.0 (the "License");
......
/* NiuTrans.NMT - an open-source neural machine translation system. /* NiuTrans.Tensor - an open-source tensor library
* Copyright (C) 2020 NiuTrans Research. All rights reserved. * Copyright (C) 2020 NiuTrans Research. All rights reserved.
* *
* Licensed under the Apache License, Version 2.0 (the "License"); * Licensed under the Apache License, Version 2.0 (the "License");
...@@ -31,9 +31,11 @@ namespace nmt ...@@ -31,9 +31,11 @@ namespace nmt
/* constructor */ /* constructor */
FNN::FNN() FNN::FNN()
{ {
dropoutP = 0.0F;
inSize = -1; inSize = -1;
outSize = -1; outSize = -1;
hSize = -1; hSize = -1;
devID = -1;
} }
/* de-constructor */ /* de-constructor */
...@@ -66,8 +68,8 @@ void FNN::InitModel(Config& config) ...@@ -66,8 +68,8 @@ void FNN::InitModel(Config& config)
_SetDataFanInOut(&w1, scale); _SetDataFanInOut(&w1, scale);
_SetDataFanInOut(&w2, scale); _SetDataFanInOut(&w2, scale);
w1.SetDataRand(-(DTYPE)sqrt(6.0F / inSize), (DTYPE)sqrt(6.0F / inSize)); //w1.SetDataRand(-(DTYPE)sqrt(6.0F / inSize), (DTYPE)sqrt(6.0F / inSize));
w2.SetDataRand(-(DTYPE)sqrt(6.0F / hSize), (DTYPE)sqrt(6.0F / hSize)); //w2.SetDataRand(-(DTYPE)sqrt(6.0F / hSize), (DTYPE)sqrt(6.0F / hSize));
b1.SetZeroAll(); b1.SetZeroAll();
b2.SetZeroAll(); b2.SetZeroAll();
...@@ -87,7 +89,7 @@ XTensor FNN::Make(XTensor& input, bool isTraining) ...@@ -87,7 +89,7 @@ XTensor FNN::Make(XTensor& input, bool isTraining)
t1 = Rectify(MulAndShift(input, w1, b1)); t1 = Rectify(MulAndShift(input, w1, b1));
if (isTraining && dropoutP > 0) if (isTraining && dropoutP > 0)
t1 = Dropout(t1, dropoutP); t1 = Dropout(t1, dropoutP, /*inplace=*/true);
/* result = t1 * w2 + b2 */ /* result = t1 * w2 + b2 */
return MulAndShift(t1, w2, b2); return MulAndShift(t1, w2, b2);
......
/* NiuTrans.NMT - an open-source neural machine translation system. /* NiuTrans.Tensor - an open-source tensor library
* Copyright (C) 2020 NiuTrans Research. All rights reserved. * Copyright (C) 2020 NiuTrans Research. All rights reserved.
* *
* Licensed under the Apache License, Version 2.0 (the "License"); * Licensed under the Apache License, Version 2.0 (the "License");
......
/* NiuTrans.NMT - an open-source neural machine translation system. /* NiuTrans.Tensor - an open-source tensor library
* Copyright (C) 2020 NiuTrans Research. All rights reserved. * Copyright (C) 2020 NiuTrans Research. All rights reserved.
* *
* Licensed under the Apache License, Version 2.0 (the "License"); * Licensed under the Apache License, Version 2.0 (the "License");
...@@ -30,6 +30,7 @@ namespace nmt ...@@ -30,6 +30,7 @@ namespace nmt
/* constructor */ /* constructor */
GLU::GLU() GLU::GLU()
{ {
devID = -1;
inSize = -1; inSize = -1;
outSize = -1; outSize = -1;
hSize = -1; hSize = -1;
......
/* NiuTrans.NMT - an open-source neural machine translation system. /* NiuTrans.Tensor - an open-source tensor library
* Copyright (C) 2020 NiuTrans Research. All rights reserved. * Copyright (C) 2020 NiuTrans Research. All rights reserved.
* *
* Licensed under the Apache License, Version 2.0 (the "License"); * Licensed under the Apache License, Version 2.0 (the "License");
......
/* NiuTrans.NMT - an open-source neural machine translation system. /* NiuTrans.Tensor - an open-source tensor library
* Copyright (C) 2020 NiuTrans Research. All rights reserved. * Copyright (C) 2020 NiuTrans Research. All rights reserved.
* *
* Licensed under the Apache License, Version 2.0 (the "License"); * Licensed under the Apache License, Version 2.0 (the "License");
...@@ -16,6 +16,7 @@ ...@@ -16,6 +16,7 @@
/* /*
* $Created by: Bei Li (libei_neu@outlook.com) 2020-02-03 * $Created by: Bei Li (libei_neu@outlook.com) 2020-02-03
* $Modified by: Chi Hu (huchinlp@gmail.com) 2020-12-10
*/ */
#include "Embedding.h" #include "Embedding.h"
...@@ -23,6 +24,7 @@ ...@@ -23,6 +24,7 @@
#include "LayerHistory.h" #include "LayerHistory.h"
#include "../Utility.h" #include "../Utility.h"
#include "../../../tensor/core/CHeader.h" #include "../../../tensor/core/CHeader.h"
#include "../../../tensor/XName.h"
#define SAFE_DELETE(x) do{ if((x) != NULL){delete (x); (x) = NULL;} } while(false) #define SAFE_DELETE(x) do{ if((x) != NULL){delete (x); (x) = NULL;} } while(false)
#define SAFE_DELETE_ARRAY(x) do{ if((x) != NULL) {delete [] (x); (x)=NULL;} } while(false) #define SAFE_DELETE_ARRAY(x) do{ if((x) != NULL) {delete [] (x); (x)=NULL;} } while(false)
...@@ -34,16 +36,20 @@ namespace nmt ...@@ -34,16 +36,20 @@ namespace nmt
LayerHistory::LayerHistory() LayerHistory::LayerHistory()
{ {
d = -1; d = -1;
devID = -1;
count = -1; count = -1;
weight = NULL; nlayer = -1;
weights = NULL;
history = NULL;
layerNorms = NULL; layerNorms = NULL;
} }
/* de-constructor */ /* de-constructor */
LayerHistory::~LayerHistory() LayerHistory::~LayerHistory()
{ {
history.Clear(); delete history;
delete[] layerNorms; delete[] layerNorms;
delete[] weights;
} }
/* /*
...@@ -56,7 +62,20 @@ void LayerHistory::InitModel(Config& config) ...@@ -56,7 +62,20 @@ void LayerHistory::InitModel(Config& config)
d = config.modelSize; d = config.modelSize;
nlayer = config.nEncLayer; nlayer = config.nEncLayer;
InitTensor2D(&weight, nlayer + 1, nlayer + 1, X_FLOAT, devID); /* the triangle weight matrices for dlcl
layer 0: [1, 0, ..., 0]
layer 1: [0.5, 0.5, ..., 0]
layer 2: [0.33, 0.33, 0.33, ..., 0] */
weights = new XTensor[nlayer + 1];
for (int i = 0; i < nlayer + 1; i++) {
InitTensor1D(&(weights[i]), i + 1, X_FLOAT, devID);
float* data = new float[i + 1];
for (int j = 0; j < i + 1; j++) {
data[j] = 1.0F / float(i + 1);
}
weights[i].SetData(data, i + 1);
delete[] data;
}
layerNorms = new LN[nlayer]; layerNorms = new LN[nlayer];
...@@ -68,59 +87,88 @@ void LayerHistory::InitModel(Config& config) ...@@ -68,59 +87,88 @@ void LayerHistory::InitModel(Config& config)
/* /*
the Add operation the Add operation
>> tensor - the previous layer output. It might be of size B * L * H >> layer - the previous layer output. It might be of size B * L * H
where B = batch size, L = sequence length, where B = batch size, L = sequence length,
and H = vector size of each position and H = vector size of each position
*/ */
void LayerHistory::Add(XTensor& tensor) void LayerHistory::Add(XTensor& layer)
{ {
/* the embedding is not normed */ /* the embedding is not normed */
count += 1; count += 1;
if (history.Size() == 0) { if (history->count == 0) {
history->Add(layer);
//sample_ = tensor;
history.Add(&tensor);
return; return;
} }
XTensor ln = layerNorms[count - 2].Make(tensor); layer = layerNorms[count - 2].Make(layer);
history.Add(&ln); history->Add(layer);
} }
/* /*
generate the weight sum vector of all previous layer output in the history as the layer input calculate the weighted sum of previous layers
the result for the i-th layer is:
result = sum(layers[0...i] * weight[i][0...i])
shape of the result: B * L * H
*/ */
XTensor LayerHistory::Pop() XTensor LayerHistory::Pop()
{ {
/* the number of layer output in the history */ TensorList list;
int size = (int)history.Size(); for (int i = 0; i < history->count; i++) {
list.Add(&(history->list[i]));
TensorList historyList; }
for (int i = 0; i < size; i++) XTensor stack;
historyList.Add(history[i]); stack = Merge(list, 0);
//Stack(list, 0);
/* we need stack the tensor along the first dim*/ int dimSize[MAX_TENSOR_DIM_NUM];
XTensor stackTensor = Stack(historyList, 0); for (int i = 0; i < stack.order + 1; i++)
dimSize[i + 1] = stack.dimSize[i];
dimSize[0] = int(list.Size());
dimSize[1] /= dimSize[0];
stack = Reshape(stack, stack.order + 1, dimSize);
XTensor interWeight; XTensor res;
res = MultiplyDim(stack, weights[list.Size() - 1], 0);
InitTensor2D(&interWeight, 1, weight.dimSize[1], DEFAULT_DTYPE, devID); return ReduceSum(res, 0);
XTensor layerWeight; }
InitTensor1D(&layerWeight, size, DEFAULT_DTYPE, devID);
_SelectRange(&weight, &interWeight, 0, size - 1, size); /* clear the history */
interWeight.Reshape(interWeight.unitNum); void LayerHistory::ClearHistory(bool reset)
_SelectRange(&interWeight, &layerWeight, 0, 0, size); {
MultiplyDimMe(stackTensor, layerWeight, 0); if(history != NULL)
delete history;
if(reset)
history = new History;
else
history = NULL;
count = 0;
}
XTensor result; /* initialize the history */
ReduceSum(stackTensor, result, 0); History::History()
{
count = 0;
}
return result; /* delete the history */
History::~History()
{
for (int i = 0; i < MAX_LAYER_NUM; i++) {
list[i].DestroyData();
XLink::ClearOutgoing(&list[i]);
XLink::ClearIncoming(&list[i]);
if (list[i].grad != NULL)
delete list[i].grad;
}
} }
void LayerHistory::ClearHistory() /* append a layer to the history */
void History::Add(XTensor& layer)
{ {
history.Clear(); list[count] = std::move(layer);
XLink::ClearOutgoing(&layer);
XLink::ClearIncoming(&layer);
count++;
} }
} }
\ No newline at end of file
/* NiuTrans.NMT - an open-source neural machine translation system. /* NiuTrans.Tensor - an open-source tensor library
* Copyright (C) 2020 NiuTrans Research. All rights reserved. * Copyright (C) 2020 NiuTrans Research. All rights reserved.
* *
* Licensed under the Apache License, Version 2.0 (the "License"); * Licensed under the Apache License, Version 2.0 (the "License");
...@@ -27,16 +27,42 @@ ...@@ -27,16 +27,42 @@
#include "../../../tensor/function/FHeader.h" #include "../../../tensor/function/FHeader.h"
using namespace nts; using namespace nts;
using namespace std;
namespace nmt namespace nmt
{ {
#define MAX_LAYER_NUM 50
/*
the class of history list
*/
class History {
public:
/* number of elements in the list */
int count;
/* the history list */
XTensor list[MAX_LAYER_NUM];
public:
/* constructor */
History();
/* de-constructor */
~History();
/* append a layer to the list */
void Add(XTensor& layer);
};
/* /*
multi-head attention the class of layer history
y(Q, K, V) = cat(head_1, head_2, ..., head_n) it generates the weighted sum of previous layers
where head_i = Attention(Q * w_i^Q, K * w_i^K, V * w_i^V) the result for the i-th layer is:
attention(Q, K, V) = softmax(Q * K^T/d_k^0.5) V res = sum(layers[0...i] * weight[i][0...i])
d_k = dimension size of K
*/ */
class LayerHistory class LayerHistory
{ {
...@@ -44,8 +70,8 @@ public: ...@@ -44,8 +70,8 @@ public:
/* device id */ /* device id */
int devID; int devID;
/* the triangle weight matrix for dlcl */ /* the triangle weight matrices for dlcl */
XTensor weight; XTensor* weights;
/* hidden size */ /* hidden size */
int d; int d;
...@@ -57,7 +83,7 @@ public: ...@@ -57,7 +83,7 @@ public:
int count; int count;
/* a history to store the value of intermediate layers */ /* a history to store the value of intermediate layers */
TensorList history; History* history;
/* layer normalization for each intermediate layer */ /* layer normalization for each intermediate layer */
LN* layerNorms; LN* layerNorms;
...@@ -79,7 +105,7 @@ public: ...@@ -79,7 +105,7 @@ public:
XTensor Pop(); XTensor Pop();
/* clean the history*/ /* clean the history*/
void ClearHistory(); void ClearHistory(bool reset=true);
}; };
} }
......
/* NiuTrans.NMT - an open-source neural machine translation system. /* NiuTrans.Tensor - an open-source tensor library
* Copyright (C) 2020 NiuTrans Research. All rights reserved. * Copyright (C) 2020 NiuTrans Research. All rights reserved.
* *
* Licensed under the Apache License, Version 2.0 (the "License"); * Licensed under the Apache License, Version 2.0 (the "License");
...@@ -51,12 +51,12 @@ void LN::InitModel(Config& config) ...@@ -51,12 +51,12 @@ void LN::InitModel(Config& config)
d = config.modelSize; d = config.modelSize;
InitTensor1D(&w, d, X_FLOAT, devID); InitTensor1D(&weight, d, X_FLOAT, devID);
InitTensor1D(&b, d, X_FLOAT, devID); InitTensor1D(&bias, d, X_FLOAT, devID);
w.SetDataRand(1.0F, 1.0F); weight.SetDataRand(1.0F, 1.0F);
b.SetZeroAll(); bias.SetZeroAll();
w.SetDataFixed(1); weight.SetDataFixed(1);
} }
/* /*
...@@ -104,7 +104,11 @@ XTensor LN::Make(XTensor& input) ...@@ -104,7 +104,11 @@ XTensor LN::Make(XTensor& input)
} }
/* result = x' * w + b */ /* result = x' * w + b */
return xn * w + b; xn = xn * weight;
xn = Sum(xn, bias, true);
return xn;
} }
} }
\ No newline at end of file
/* NiuTrans.NMT - an open-source neural machine translation system. /* NiuTrans.Tensor - an open-source tensor library
* Copyright (C) 2020 NiuTrans Research. All rights reserved. * Copyright (C) 2020 NiuTrans Research. All rights reserved.
* *
* Licensed under the Apache License, Version 2.0 (the "License"); * Licensed under the Apache License, Version 2.0 (the "License");
...@@ -39,10 +39,10 @@ public: ...@@ -39,10 +39,10 @@ public:
int devID; int devID;
/* the transformation matrix w */ /* the transformation matrix w */
XTensor w; XTensor weight;
/* the bias term b */ /* the bias term b */
XTensor b; XTensor bias;
/* dimension size of the model */ /* dimension size of the model */
int d; int d;
......
/* NiuTrans.NMT - an open-source neural machine translation system. /* NiuTrans.Tensor - an open-source tensor library
* Copyright (C) 2020 NiuTrans Research. All rights reserved. * Copyright (C) 2020 NiuTrans Research. All rights reserved.
* *
* Licensed under the Apache License, Version 2.0 (the "License"); * Licensed under the Apache License, Version 2.0 (the "License");
......
/* NiuTrans.NMT - an open-source neural machine translation system. /* NiuTrans.Tensor - an open-source tensor library
* Copyright (C) 2020 NiuTrans Research. All rights reserved. * Copyright (C) 2020 NiuTrans Research. All rights reserved.
* *
* Licensed under the Apache License, Version 2.0 (the "License"); * Licensed under the Apache License, Version 2.0 (the "License");
......
/* NiuTrans.NMT - an open-source neural machine translation system. /* NiuTrans.Tensor - an open-source tensor library
* Copyright (C) 2020 NiuTrans Research. All rights reserved. * Copyright (C) 2020 NiuTrans Research. All rights reserved.
* *
* Licensed under the Apache License, Version 2.0 (the "License"); * Licensed under the Apache License, Version 2.0 (the "License");
...@@ -33,6 +33,7 @@ Output::Output() ...@@ -33,6 +33,7 @@ Output::Output()
devID = -1; devID = -1;
vSize = -1; vSize = -1;
hSize = -1; hSize = -1;
padIdx = -1;
} }
/* de-constructor */ /* de-constructor */
...@@ -49,11 +50,15 @@ void Output::InitModel(Config& config) ...@@ -49,11 +50,15 @@ void Output::InitModel(Config& config)
devID = config.devID; devID = config.devID;
hSize = config.modelSize; hSize = config.modelSize;
vSize = config.tgtVocabSize; vSize = config.tgtVocabSize;
padIdx = config.padID;
InitTensor2D(&w, vSize, hSize, X_FLOAT, devID); InitTensor2D(&w, vSize, hSize, X_FLOAT, devID);
DTYPE v = 1.0F / (float)sqrt((float)hSize); DTYPE v = 1.0F / (float)sqrt((float)hSize);
w.SetDataRandn(0, v); w.SetDataRandn(0, v);
for (int i = 0; i < hSize; i++) {
w.Set2D(0.0F, padIdx, i);
}
} }
/* /*
......
/* NiuTrans.NMT - an open-source neural machine translation system. /* NiuTrans.Tensor - an open-source tensor library
* Copyright (C) 2020 NiuTrans Research. All rights reserved. * Copyright (C) 2020 NiuTrans Research. All rights reserved.
* *
* Licensed under the Apache License, Version 2.0 (the "License"); * Licensed under the Apache License, Version 2.0 (the "License");
...@@ -43,6 +43,9 @@ public: ...@@ -43,6 +43,9 @@ public:
/* vector size of the linear transformation */ /* vector size of the linear transformation */
int hSize; int hSize;
/* the padding index */
int padIdx;
/* transformation matrix */ /* transformation matrix */
XTensor w; XTensor w;
......
/* NiuTrans.NMT - an open-source neural machine translation system. /* NiuTrans.Tensor - an open-source tensor library
* Copyright (C) 2020 NiuTrans Research. All rights reserved. * Copyright (C) 2020 NiuTrans Research. All rights reserved.
* *
* Licensed under the Apache License, Version 2.0 (the "License"); * Licensed under the Apache License, Version 2.0 (the "License");
...@@ -32,19 +32,19 @@ using namespace nmt; ...@@ -32,19 +32,19 @@ using namespace nmt;
namespace nts { namespace nts {
/* sort the output by id (in ascending order) */ /* sort the input by length (in descending order) */
void DataSet::SortInput() { void DataSet::SortInput() {
sort(inputBuffer.items, inputBuffer.items + inputBuffer.count, sort(inputBuffer.begin(), inputBuffer.end(),
[](Example* a, Example* b) { [](const Example& a, const Example& b) {
return a->values.count > b->values.count; return a.values.size() > b.values.size();
}); });
} }
/* sort the input by length (in descending order) */ /* sort the output by id (in ascending order) */
void DataSet::SortOutput() { void DataSet::SortOutput() {
sort(outputBuffer.items, outputBuffer.items + outputBuffer.count, sort(outputBuffer.begin(), outputBuffer.end(),
[](Result* a, Result* b) { [](const Example& a, const Example& b) {
return a->id < b->id; return a.id < b.id;
}); });
} }
...@@ -54,43 +54,43 @@ load data from the file to the buffer ...@@ -54,43 +54,43 @@ load data from the file to the buffer
void DataSet::LoadDataToBuffer() void DataSet::LoadDataToBuffer()
{ {
string line; string line;
inputBuffer.Clear(); inputBuffer.clear();
bufferUsed = 0; bufferUsed = 0;
int id = 0; int id = 0;
const string tokenDelimiter = " "; const string tokenDelimiter = " ";
while (getline(*fp, line)) { while (getline(*fp, line)) {
IntList values; vector<int> values;
/* load words and transform them to ids */ /* load words and transform them to ids */
auto indices = SplitToPos(line, tokenDelimiter); UInt64List indices = SplitToPos(line, tokenDelimiter);
/* reserve the first 120 words if the input is too long */ /* reserve the first maxSrcLen words if the input is too long */
size_t maxLen = indices.Size() > MAX_WORD_NUM ? MAX_WORD_NUM : indices.Size(); int maxLen = int(indices.Size()) > maxSrcLen ? maxSrcLen : int(indices.Size());
for (size_t i = 0; i < maxLen; i++) { for (int i = 0; i < maxLen; i++) {
size_t offset = (i != (indices.Size() - 1)) ? auto offset = (i != (int(indices.Size()) - 1)) ?
(size_t)indices[(int)i + 1] - (size_t)indices[(int)i] - tokenDelimiter.size() indices[i + 1] - indices[i] - tokenDelimiter.size()
: line.size() - (size_t)indices[(int)i]; : line.size() - indices[i];
string word = line.substr((size_t)indices[(int)i], offset); string word = line.substr(indices[i], offset);
if (srcVocab.word2id.find(word) == srcVocab.word2id.end()) if (srcVocab.word2id.find(word) == srcVocab.word2id.end())
values.Add(UNK); values.emplace_back(unkID);
else else
values.Add(srcVocab.word2id.at(word)); values.emplace_back(srcVocab.word2id.at(word));
} }
/* make sure that the sequence ends with EOS */ /* make sure that the sequence ends with EOS */
if (values.Size() != 0 && values[-1] != EOS) if (values.size() != 0 && values.back() != endID)
values.Add(EOS); values.emplace_back(endID);
Example* example = new Example; Example example;
example->id = id; example.id = id;
example->values = values; example.values = values;
if (values.Size() != 0) if (values.size() != 0)
inputBuffer.Add(example); inputBuffer.emplace_back(example);
else else
emptyLines.Add(id); emptyLines.emplace_back(id);
id++; id++;
} }
fp->close(); fp->close();
...@@ -109,23 +109,16 @@ load a mini-batch to the device (for translating) ...@@ -109,23 +109,16 @@ load a mini-batch to the device (for translating)
>> devID - the device id, -1 for the CPU >> devID - the device id, -1 for the CPU
<< indices of the sentences << indices of the sentences
*/ */
UInt64List DataSet::LoadBatch(XTensor* batchEnc, XTensor* paddingEnc, UInt64List DataSet::LoadBatch(XTensor* batchEnc, XTensor* paddingEnc, int minSentBatch, int batchSize, int devID)
int minSentBatch, int batchSize, int devID)
{ {
int realBatchSize = minSentBatch; int realBatchSize = minSentBatch;
/* get the maximum sentence length in a mini-batch */ /* get the maximum sentence length in a mini-batch */
int maxLen = (int)inputBuffer[(int)bufferUsed]->values.Size(); int maxLen = int(inputBuffer[bufferUsed].values.size());
/* dynamic batching for sentences */
//while ((realBatchSize < (inputBuffer.Size() - bufferUsed))
// && (realBatchSize * maxLen < batchSize)) {
// realBatchSize++;
//}
/* real batch size */ /* real batch size */
if ((inputBuffer.Size() - bufferUsed) < realBatchSize) { if ((int(inputBuffer.size()) - bufferUsed) < realBatchSize) {
realBatchSize = (int)(inputBuffer.Size() - bufferUsed); realBatchSize = int(inputBuffer.size()) - bufferUsed;
} }
CheckNTErrors(maxLen != 0, "invalid length"); CheckNTErrors(maxLen != 0, "invalid length");
...@@ -134,25 +127,25 @@ UInt64List DataSet::LoadBatch(XTensor* batchEnc, XTensor* paddingEnc, ...@@ -134,25 +127,25 @@ UInt64List DataSet::LoadBatch(XTensor* batchEnc, XTensor* paddingEnc,
float* paddingValues = new float[realBatchSize * maxLen]; float* paddingValues = new float[realBatchSize * maxLen];
for (int i = 0; i < realBatchSize * maxLen; i++) { for (int i = 0; i < realBatchSize * maxLen; i++) {
batchValues[i] = PAD; batchValues[i] = padID;
paddingValues[i] = 1.0F; paddingValues[i] = 1.0F;
} }
size_t curSrc = 0; int curSrc = 0;
/* right padding */ /* right padding */
UInt64List infos; UInt64List infos;
size_t totalLength = 0; int totalLength = 0;
for (size_t i = 0; i < (size_t)realBatchSize; ++i) { for (int i = 0; i < realBatchSize; ++i) {
infos.Add(inputBuffer[(int)(bufferUsed + i)]->id); infos.Add(inputBuffer[bufferUsed + i].id);
totalLength += inputBuffer[(int)(bufferUsed + i)]->values.Size(); totalLength += int(inputBuffer[bufferUsed + i].values.size());
curSrc = maxLen * i; curSrc = maxLen * i;
for (size_t j = 0; j < inputBuffer[(int)(bufferUsed + i)]->values.Size(); j++) for (int j = 0; j < int(inputBuffer[bufferUsed + i].values.size()); j++)
batchValues[(int)(curSrc++)] = (int)inputBuffer[(int)(bufferUsed + i)]->values[(int)j]; batchValues[curSrc++] = inputBuffer[bufferUsed + i].values[j];
while (curSrc < maxLen * (i + 1)) while (curSrc < maxLen * (i + 1))
paddingValues[(int)(curSrc++)] = 0; paddingValues[curSrc++] = 0;
} }
infos.Add(totalLength); infos.Add(totalLength);
...@@ -201,7 +194,7 @@ void DataSet::Init(const char* dataFile, const char* srcVocabFN, const char* tgt ...@@ -201,7 +194,7 @@ void DataSet::Init(const char* dataFile, const char* srcVocabFN, const char* tgt
/* check if the buffer is empty */ /* check if the buffer is empty */
bool DataSet::IsEmpty() { bool DataSet::IsEmpty() {
if (bufferUsed < inputBuffer.Size()) if (bufferUsed < inputBuffer.size())
return false; return false;
return true; return true;
} }
...@@ -211,12 +204,11 @@ void DataSet::DumpRes(const char* ofn) ...@@ -211,12 +204,11 @@ void DataSet::DumpRes(const char* ofn)
{ {
ofstream ofile(ofn, ios::out); ofstream ofile(ofn, ios::out);
for (int t = 0; t < outputBuffer.Size(); t++) { for (const auto& tgtSent : outputBuffer) {
auto res = outputBuffer[t]; for (const auto& w : tgtSent.values) {
for (int i = 0; i < res->res.Size(); i++) { if (w < 4)
if (res->res[i] < 4)
break; break;
ofile << tgtVocab.id2word[res->res[i]] << " "; ofile << tgtVocab.id2word[w] << " ";
} }
ofile << "\n"; ofile << "\n";
} }
...@@ -229,14 +221,6 @@ DataSet::~DataSet() ...@@ -229,14 +221,6 @@ DataSet::~DataSet()
{ {
/* release the file */ /* release the file */
delete fp; delete fp;
/* release the input buffer */
for (int i = 0; i < inputBuffer.Size(); i++)
delete inputBuffer[i];
/* release the output buffer */
for (int i = 0; i < outputBuffer.Size(); i++)
delete outputBuffer[i];
} }
} }
\ No newline at end of file
/* NiuTrans.NMT - an open-source neural machine translation system. /* NiuTrans.Tensor - an open-source tensor library
* Copyright (C) 2020 NiuTrans Research. All rights reserved. * Copyright (C) 2020 NiuTrans Research. All rights reserved.
* *
* Licensed under the Apache License, Version 2.0 (the "License"); * Licensed under the Apache License, Version 2.0 (the "License");
...@@ -31,40 +31,41 @@ ...@@ -31,40 +31,41 @@
#include "../../../tensor/XTensor.h" #include "../../../tensor/XTensor.h"
#include "../../../tensor/XGlobal.h" #include "../../../tensor/XGlobal.h"
#define MAX_WORD_NUM 120
using namespace std; using namespace std;
namespace nts { namespace nts {
/* the struct of tokenized input */ /* the struct of tokenized input */
struct Example { struct Example {
int id;
IntList values;
};
/* the struct of tokenized output */
struct Result {
int id; int id;
IntList res;
vector<int> values;
public:
Example() {
id = 0;
}
}; };
/* A `DataSet` is associated with a file which contains variable length data.*/ /* A `DataSet` is associated with a file which contains variable length data.*/
struct DataSet { struct DataSet {
public: public:
/* the data buffer */ /* the data buffer */
InputBufferType inputBuffer; vector<Example> inputBuffer;
/* a list of empty line number */ /* a list of empty line number */
IntList emptyLines; vector<int> emptyLines;
/* the result buffer */ /* the result buffer */
OutputBufferType outputBuffer; vector<Example> outputBuffer;
/* the pointer to file stream */ /* the pointer to file stream */
ifstream* fp; ifstream* fp;
/* size of used data in buffer */ /* size of used data in buffer */
size_t bufferUsed; int bufferUsed;
/* the source vocabulary */ /* the source vocabulary */
Vocab srcVocab; Vocab srcVocab;
...@@ -72,6 +73,21 @@ public: ...@@ -72,6 +73,21 @@ public:
/* the target vocabulary */ /* the target vocabulary */
Vocab tgtVocab; Vocab tgtVocab;
/* the maximum length of an input sequence */
int maxSrcLen;
/* the padding id */
int padID;
/* the unk id */
int unkID;
/* start symbol */
int startID;
/* end symbol */
int endID;
public: public:
/* sort the input by length */ /* sort the input by length */
...@@ -84,8 +100,7 @@ public: ...@@ -84,8 +100,7 @@ public:
void LoadDataToBuffer(); void LoadDataToBuffer();
/* generate a mini-batch */ /* generate a mini-batch */
UInt64List LoadBatch(XTensor* batchEnc, XTensor* paddingEnc, UInt64List LoadBatch(XTensor* batchEnc, XTensor* paddingEnc, int minSentBatch, int batchSize, int devID);
int sBatch, int wBatch, int devID);
/* initialization function */ /* initialization function */
void Init(const char* dataFile, const char* srcVocabFN, const char* tgtVocabFN); void Init(const char* dataFile, const char* srcVocabFN, const char* tgtVocabFN);
......
/* NiuTrans.NMT - an open-source neural machine translation system. /* NiuTrans.Tensor - an open-source tensor library
* Copyright (C) 2020 NiuTrans Research. All rights reserved. * Copyright (C) 2020 NiuTrans Research. All rights reserved.
* *
* Licensed under the Apache License, Version 2.0 (the "License"); * Licensed under the Apache License, Version 2.0 (the "License");
...@@ -42,7 +42,7 @@ float LengthPenalizer::GNMT(float length, float alpha) ...@@ -42,7 +42,7 @@ float LengthPenalizer::GNMT(float length, float alpha)
base = (length + 5.0F) / (1.0F + 5.0F); base = (length + 5.0F) / (1.0F + 5.0F);
lp = pow(base, alpha); lp = float(pow(base, alpha));
return lp; return lp;
} }
......
/* NiuTrans.NMT - an open-source neural machine translation system. /* NiuTrans.Tensor - an open-source tensor library
* Copyright (C) 2020 NiuTrans Research. All rights reserved. * Copyright (C) 2020 NiuTrans Research. All rights reserved.
* *
* Licensed under the Apache License, Version 2.0 (the "License"); * Licensed under the Apache License, Version 2.0 (the "License");
......
/* NiuTrans.NMT - an open-source neural machine translation system. /* NiuTrans.Tensor - an open-source tensor library
* Copyright (C) 2020 NiuTrans Research. All rights reserved. * Copyright (C) 2020 NiuTrans Research. All rights reserved.
* *
* Licensed under the Apache License, Version 2.0 (the "License"); * Licensed under the Apache License, Version 2.0 (the "License");
......
/* NiuTrans.NMT - an open-source neural machine translation system. /* NiuTrans.Tensor - an open-source tensor library
* Copyright (C) 2020 NiuTrans Research. All rights reserved. * Copyright (C) 2020 NiuTrans Research. All rights reserved.
* *
* Licensed under the Apache License, Version 2.0 (the "License"); * Licensed under the Apache License, Version 2.0 (the "License");
......
/* NiuTrans.NMT - an open-source neural machine translation system. /* NiuTrans.Tensor - an open-source tensor library
* Copyright (C) 2020 NiuTrans Research. All rights reserved. * Copyright (C) 2020 NiuTrans Research. All rights reserved.
* *
* Licensed under the Apache License, Version 2.0 (the "License"); * Licensed under the Apache License, Version 2.0 (the "License");
...@@ -727,6 +727,7 @@ GreedySearch::GreedySearch() ...@@ -727,6 +727,7 @@ GreedySearch::GreedySearch()
endSymbolNum = 0; endSymbolNum = 0;
endSymbols = new int[32]; endSymbols = new int[32];
startSymbol = -1; startSymbol = -1;
scalarMaxLength = -1;
} }
/* de-constructor */ /* de-constructor */
......
/* NiuTrans.NMT - an open-source neural machine translation system. /* NiuTrans.Tensor - an open-source tensor library
* Copyright (C) 2020 NiuTrans Research. All rights reserved. * Copyright (C) 2020 NiuTrans Research. All rights reserved.
* *
* Licensed under the Apache License, Version 2.0 (the "License"); * Licensed under the Apache License, Version 2.0 (the "License");
......
/* NiuTrans.NMT - an open-source neural machine translation system. /* NiuTrans.Tensor - an open-source tensor library
* Copyright (C) 2020 NiuTrans Research. All rights reserved. * Copyright (C) 2020 NiuTrans Research. All rights reserved.
* *
* Licensed under the Apache License, Version 2.0 (the "License"); * Licensed under the Apache License, Version 2.0 (the "License");
...@@ -54,6 +54,12 @@ void Translator::Init(Config& config) ...@@ -54,6 +54,12 @@ void Translator::Init(Config& config)
sentBatch = config.sBatchSize; sentBatch = config.sBatchSize;
wordBatch = config.wBatchSize; wordBatch = config.wBatchSize;
batchLoader.maxSrcLen = config.maxSrcLen;
batchLoader.unkID = config.unkID;
batchLoader.padID = config.padID;
batchLoader.startID = config.startID;
batchLoader.endID = config.endID;
if (beamSize > 1) { if (beamSize > 1) {
LOG("translating with beam search (%d)", beamSize); LOG("translating with beam search (%d)", beamSize);
seacher = new BeamSearch(); seacher = new BeamSearch();
...@@ -123,12 +129,12 @@ void Translator::Translate(const char* ifn, const char* sfn, ...@@ -123,12 +129,12 @@ void Translator::Translate(const char* ifn, const char* sfn,
XTensor score; XTensor score;
((BeamSearch*)seacher)->Search(model, batchEnc, paddingEnc, output, score); ((BeamSearch*)seacher)->Search(model, batchEnc, paddingEnc, output, score);
} }
for (int i = 0; i < indices.Size() - 1; ++i) { for (int i = 0; i < indices.Size() - 1; ++i) {
Result* res = new Result; Example res;
res->id = int(indices[i]); res.id = int(indices[i]);
res->res = output[i]; for (int j = 0; j < output[i].Size(); j++)
batchLoader.outputBuffer.Add(res); res.values.emplace_back(output[i][j]);
batchLoader.outputBuffer.emplace_back(std::move(res));
} }
delete[] output; delete[] output;
...@@ -142,17 +148,17 @@ void Translator::Translate(const char* ifn, const char* sfn, ...@@ -142,17 +148,17 @@ void Translator::Translate(const char* ifn, const char* sfn,
double elapsed = GetClockSec() - batchStart; double elapsed = GetClockSec() - batchStart;
batchStart = GetClockSec(); batchStart = GetClockSec();
LOG("elapsed=%.1fs, sentence=%f, sword=%.1fw/s", LOG("elapsed=%.1fs, sentence=%f, sword=%.1fw/s",
elapsed, float(sentCount) / float(batchLoader.inputBuffer.Size()), elapsed, float(sentCount) / float(batchLoader.inputBuffer.size()),
double(wc) / elapsed); double(wc) / elapsed);
wc = 0; wc = 0;
} }
} }
/* append empty lines to the result */ /* append empty lines to the result */
for (int i = 0; i < batchLoader.emptyLines.Size(); i++) { for (const auto& empty: batchLoader.emptyLines) {
Result* emptyRes = new Result; Example emptyRes;
emptyRes->id = batchLoader.emptyLines[i]; emptyRes.id = empty;
batchLoader.outputBuffer.Add(emptyRes); batchLoader.outputBuffer.emplace_back(emptyRes);
} }
double startDump = GetClockSec(); double startDump = GetClockSec();
...@@ -166,7 +172,7 @@ void Translator::Translate(const char* ifn, const char* sfn, ...@@ -166,7 +172,7 @@ void Translator::Translate(const char* ifn, const char* sfn,
double elapsed = GetClockSec() - startDump; double elapsed = GetClockSec() - startDump;
LOG("translation completed (word=%d, sent=%zu)", LOG("translation completed (word=%d, sent=%zu)",
wordCountTotal, batchLoader.inputBuffer.Size() + batchLoader.emptyLines.Size()); wordCountTotal, batchLoader.outputBuffer.size() + batchLoader.emptyLines.size());
} }
/* /*
......
/* NiuTrans.NMT - an open-source neural machine translation system. /* NiuTrans.Tensor - an open-source tensor library
* Copyright (C) 2020 NiuTrans Research. All rights reserved. * Copyright (C) 2020 NiuTrans Research. All rights reserved.
* *
* Licensed under the Apache License, Version 2.0 (the "License"); * Licensed under the Apache License, Version 2.0 (the "License");
......
/* NiuTrans.NMT - an open-source neural machine translation system. /* NiuTrans.Tensor - an open-source tensor library
* Copyright (C) 2020 NiuTrans Research. All rights reserved. * Copyright (C) 2020 NiuTrans Research. All rights reserved.
* *
* Licensed under the Apache License, Version 2.0 (the "License"); * Licensed under the Apache License, Version 2.0 (the "License");
......
/* NiuTrans.NMT - an open-source neural machine translation system. /* NiuTrans.Tensor - an open-source tensor library
* Copyright (C) 2020 NiuTrans Research. All rights reserved. * Copyright (C) 2020 NiuTrans Research. All rights reserved.
* *
* Licensed under the Apache License, Version 2.0 (the "License"); * Licensed under the Apache License, Version 2.0 (the "License");
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论