Commit c44d1a79 by xuchen

gather implementation in mt

parent a400b619
...@@ -193,13 +193,21 @@ void T2TModel::MakeMT(XTensor &inputEnc, XTensor &inputDec, XTensor &output, XTe ...@@ -193,13 +193,21 @@ void T2TModel::MakeMT(XTensor &inputEnc, XTensor &inputDec, XTensor &output, XTe
XTensor maskDec; XTensor maskDec;
/* generate mask to see "previous" words on the decoder side */ /* generate mask to see "previous" words on the decoder side */
int len = inputDec.GetDim(inputDec.order - 2); //int len = inputDec.GetDim(inputDec.order - 2);
int * dims = new int[inputDec.order + 1]; //int * dims = new int[inputDec.order + 1];
//for(int i = 0; i < inputDec.order; i++)
// dims[i + 1] = inputDec.GetDim(i);
//dims[0] = nhead;
//dims[inputDec.order] = len;
//InitTensor(&maskDec, inputDec.order + 1, dims, X_FLOAT, 1.0F, inputDec.devID, inputDec.mem);
int len = inputDec.GetDim(inputDec.order - 1);
int * dims = new int[inputDec.order + 2];
for(int i = 0; i < inputDec.order; i++) for(int i = 0; i < inputDec.order; i++)
dims[i + 1] = inputDec.GetDim(i); dims[i + 1] = inputDec.GetDim(i);
dims[0] = nhead; dims[0] = nhead;
dims[inputDec.order] = len; dims[inputDec.order + 1] = len;
InitTensor(&maskDec, inputDec.order + 1, dims, X_FLOAT, 1.0F, inputDec.devID, inputDec.mem); InitTensor(&maskDec, inputDec.order + 2, dims, X_FLOAT, 1.0F, paddingEnc.devID, paddingEnc.mem);
/* a upper triangular matrix where the cells of the upper triangular are set to -1e-9. /* a upper triangular matrix where the cells of the upper triangular are set to -1e-9.
this matrix can be used to prevent the attention to current or following words in this matrix can be used to prevent the attention to current or following words in
......
...@@ -199,24 +199,6 @@ public: ...@@ -199,24 +199,6 @@ public:
int devID, XMem * mem, int devID, XMem * mem,
bool isTraining); bool isTraining);
/* load a batch of sequences (for language modeling) */
int LoadBatchLM(FILE * file,
XTensor * batchEnc, XTensor * paddingEnc,
XTensor * batchDec, XTensor * paddingDec,
XTensor * gold,
int * seqs, int vs, int sBatch, int wBatch,
bool isSorted, int &wCount,
int devID, XMem * mem);
/* load a batch of sequences (for machine translation) */
int LoadBatchMT(FILE * file,
XTensor * batchEnc, XTensor * paddingEnc,
XTensor * batchDec, XTensor * paddingDec,
XTensor * gold,
int * seqs, int vsEnc, int vsDec, int sBatch, int wBatch,
bool isSorted, int &wCount,
int devID, XMem * mem);
/* shuffle the data file */ /* shuffle the data file */
void Shuffle(const char * srcFile, const char * tgtFile); void Shuffle(const char * srcFile, const char * tgtFile);
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论