Commit c6f50a22 by xiaotong

load a batch of sequences on both language sides

parent 430f0dfc
......@@ -67,17 +67,17 @@ void AttDecoder::InitModel(int argc, char ** argv,
/*
make the decoding network
>> input - the input tensor of the decoder
>> encoderOutput - the output tensor of the encoder
>> inputDec - the input tensor of the decoder
>> outputEnc - the output tensor of the encoder
>> mask - the mask that indicates whether each position is valid
>> isTraining - indicates whether the model is used for training
<< return - the output tensor of the decoder
*/
XTensor AttDecoder::Make(XTensor &input, XTensor &encoderOutput, XTensor &mask, bool isTraining)
XTensor AttDecoder::Make(XTensor &inputDec, XTensor &outputEnc, XTensor &mask, bool isTraining)
{
XTensor x;
x = embedder.Make(input);
x = embedder.Make(inputDec);
/* dropout */
if(isTraining && dropoutP > 0)
......@@ -106,7 +106,7 @@ XTensor AttDecoder::Make(XTensor &input, XTensor &encoderOutput, XTensor &mask,
/*****************************/
/* encoder-decoder attention */
ende = attentionsEnde[i].Make(encoderOutput, x, encoderOutput, mask, isTraining);
ende = attentionsEnde[i].Make(outputEnc, x, outputEnc, mask, isTraining);
/* dropout */
if(isTraining && dropoutP > 0)
......
......@@ -48,7 +48,7 @@ public:
int myDevID = -1, XMem * myMem = NULL);
/* make the decoding network */
XTensor Make(XTensor &input, XTensor &encoderOutput, XTensor &mask, bool isTraining);
XTensor Make(XTensor &inputDec, XTensor &outputEnc, XTensor &mask, bool isTraining);
};
}
......
......@@ -90,13 +90,27 @@ make the encoding network
>> isTraining - indicates whether we are training the model
<< return - encoding result
*/
/*
run the encoder stack: map the (masked) source input to its hidden representation.
NOTE: the stale pre-rename signature line "MakeEncoding" that preceded this
definition has been removed — only the renamed MakeEncoder remains.
>> input - input tensor of the encoder
>> mask - the mask for positions that are/are not involved in computation
>> isTraining - indicates whether we are training the model
<< return - encoding result (the encoder's output tensor)
*/
XTensor T2TModel::MakeEncoder(XTensor &input, XTensor &mask, bool isTraining)
{
    return encoder.Make(input, mask, isTraining);
}
/*
make the decoding network
>> inputDec - input tensor of the decoder (target-side sequence)
>> outputEnc - output tensor of the encoder, attended to by the
   encoder-decoder attention inside the decoder stack
>> mask - the mask for positions that are/are not involved in computation
>> isTraining - indicates whether we are training the model
<< return - decoding result (the decoder's output tensor)
*/
XTensor T2TModel::MakeDecoder(XTensor &inputDec, XTensor &outputEnc, XTensor &mask, bool isTraining)
{
/* thin wrapper: all of the work happens in the decoder stack */
return decoder.Make(inputDec, outputEnc, mask, isTraining);
}
/*
make the network for language modeling (with the output softmax layer)
>> input - input tensor
>> output - output tensor (distribution)
>> padding - padding of the sequences
......@@ -145,7 +159,7 @@ void T2TModel::MakeLM(XTensor &input, XTensor &output, XTensor &padding, bool is
//_Sum(&mask, padding3, &mask);
encoding = MakeEncoding(input, mask, isTraining);
encoding = MakeEncoder(input, mask, isTraining);
outputLayer.Make(encoding, output);
delete[] dims;
......@@ -156,6 +170,43 @@ void T2TModel::MakeLM(XTensor &input, XTensor &output, XTensor &padding, bool is
}
/*
make the network for machine translation (with the output softmax layer)
>> inputEnc - input tensor of the encoder (source-side sequence)
>> inputDec - input tensor of the decoder (target-side sequence)
>> output - output tensor (distribution over the target vocabulary)
>> padding - padding of the sequences
>> isTraining - indicates whether the model is for training
*/
void T2TModel::MakeMT(XTensor &inputEnc, XTensor &inputDec, XTensor &output, XTensor &padding, bool isTraining)
{
XTensor encoding;
XTensor decoding;
XTensor maskEnc;
XTensor maskDec;
/* generate mask to see "previous" words on the decoder side.
   the mask shape is the decoder input's shape with a leading nhead dimension
   prepended and the last dimension replaced by len, so it ends in a
   (len x len) square per head */
int len = inputDec.GetDim(inputDec.order - 2);
int * dims = new int[inputDec.order + 1];
for(int i = 0; i < inputDec.order; i++)
dims[i + 1] = inputDec.GetDim(i);
dims[0] = nhead;
/* overwrite the last copied dimension with the sequence length */
dims[inputDec.order] = len;
InitTensor(&maskDec, inputDec.order + 1, dims, X_FLOAT, 1.0F, inputDec.devID, inputDec.mem);
/* build an additive attention mask: the lower-triangular cells are first set
   to 1e9, then the whole tensor is shifted by -1e9, so lower-triangular cells
   become 0 and all remaining cells become -1e9 (the earlier comment's "-1e-9"
   was a typo), which blocks attention to following words in a given sequence.
   NOTE(review): whether the diagonal (the current word) is kept visible or
   masked depends on _SetDataLowTri's semantics for shift 0 — confirm against
   its implementation. */
_SetDataLowTri(&maskDec, 1e9F, 0);
_ScaleAndShiftMe(&maskDec, 1.0F, -1e9F);
/* NOTE(review): maskEnc is passed to MakeEncoder without ever being
   initialized, and the padding argument is never used to build an
   encoder-side mask — this looks like a bug; maskEnc should presumably be
   derived from padding. TODO confirm and fix. */
encoding = MakeEncoder(inputEnc, maskEnc, isTraining);
decoding = MakeDecoder(inputDec, encoding, maskDec, isTraining);
outputLayer.Make(decoding, output);
/* release the scratch dimension array (raw new[]/delete[] kept as-is) */
delete[] dims;
}
/*
get parameter matrices
>> list - the list that keeps the parameter matrices
*/
......
......@@ -69,11 +69,17 @@ public:
void InitModel(int argc, char ** argv);
/* make the encoding network */
XTensor MakeEncoding(XTensor &input, XTensor &mask, bool isTraining);
XTensor MakeEncoder(XTensor &input, XTensor &mask, bool isTraining);
/* make the entire network for language modeling (with the output softmax layer) */
/* make the decoding network */
XTensor MakeDecoder(XTensor &inputEnc, XTensor &inputDec, XTensor &mask, bool isTraining);
/* make the network for language modeling (with the output softmax layer) */
void MakeLM(XTensor &input, XTensor &output, XTensor &padding, bool isTraining);
/* make the network for machine translation (with the output softmax layer) */
void MakeMT(XTensor &inputEnc, XTensor &inputDec, XTensor &output, XTensor &padding, bool isTraining);
/* get parameter matrices */
void GetParams(XList &list);
......
......@@ -79,6 +79,9 @@ public:
/* vocabulary size of the source side */
int vSize;
/* vocabulary size of the target side */
int vSizeTgt;
/* learning rate */
float lrate;
......@@ -160,10 +163,24 @@ public:
int LoadBatch(FILE * file, bool isLM,
XTensor * batch, XTensor * padding, XTensor * output,
int * seqs,
int step, int vs, int sBatch, int wBatch,
int vsEnc, int vsDec, int sBatch, int wBatch,
bool isSorted, int &wCount,
int devID, XMem * mem);
/* load a batch of sequences (for language modeling) */
int LoadBatchLM(FILE * file,
XTensor * batch, XTensor * padding, XTensor * output,
int * seqs, int vs, int sBatch, int wBatch,
bool isSorted, int &wCount,
int devID, XMem * mem);
/* load a batch of sequences (for machine translation) */
int LoadBatchMT(FILE * file,
XTensor * batch, XTensor * padding, XTensor * output,
int * seqs, int vsEnc, int vsDec, int sBatch, int wBatch,
bool isSorted, int &wCount,
int devID, XMem * mem);
/* shuffle the data file */
void Shuffle(const char * srcFile, const char * tgtFile);
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论