Commit cf43c58c by xiaotong

new code

parent 56326405
......@@ -109,8 +109,6 @@ XTensor AttEncoder::Make(XTensor &input, XTensor &mask, bool skipInputRes, bool
XTensor fnn;
XTensor res;
llnum = -1;
/* we skip the residual connection for the first layer if
the encoder is used in language modeling. */
if(skipInputRes && i == 0){
......
......@@ -58,7 +58,7 @@ void T2TFNN::InitModel(int argc, const char ** argv, int myDevID, XMem * myMem)
LoadParamInt(argc, argv, "d", &inSize, DEFAULT_EMBEDDING_SIZE);
LoadParamInt(argc, argv, "d", &outSize, DEFAULT_EMBEDDING_SIZE);
LoadParamInt(argc, argv, "fnnh", &hSize, DEFAULT_EMBEDDING_SIZE);
LoadParamInt(argc, argv, "fnnh", &hSize, DEFAULT_EMBEDDING_SIZE * 4);
LoadParamFloat(argc, argv, "fnnminmax", &minmax, 0.1F);
InitTensor2D(&w1, inSize, hSize, X_FLOAT, devID, mem);
......
......@@ -325,11 +325,19 @@ void T2TTrainer::Test(const char * fn, const char * ofn, T2TModel * model)
char line[MAX_SEQUENCE_LENGTH];
/* a (sample index, sequence length) record; an array of these is
   filled in LoadBuf when the samples are flagged as sorted */
struct SampleNode {
    /* index assigned to the sample (its position in the node array) */
    int id;

    /* length of the sample's sequence, as read from the length buffer */
    int size;
};
/*
load data to buffer
>> file - where to load data
>> isSorted - indicates whether the samples are sorted by length
>> step - the number of sequences we advance by when moving to the next sample
*/
int T2TTrainer::LoadBuf(FILE * file)
int T2TTrainer::LoadBuf(FILE * file, bool isSorted, int step)
{
int lineCount = 0;
int seqCount = 0;
......@@ -395,6 +403,17 @@ int T2TTrainer::LoadBuf(FILE * file)
nseqBuf = seqCount;
nextSeq = 0;
if (isSorted) {
SampleNode * nodes = new SampleNode[seqCount];
int count = 0;
for (int i = 0; i < seqCount; i += step) {
nodes[count].id = count;
nodes[count].size = seqLen[i];
count++;
}
delete[] nodes;
}
return lineCount;
}
......@@ -422,7 +441,7 @@ load a batch of sequences
>> devID - device id
>> mem - memory pool
*/
int T2TTrainer::LoadBatch(FILE * file, bool isLM,
int T2TTrainer::LoadBatch(FILE * file, bool isLM,
XTensor * batch, XTensor * padding, XTensor * output,
int * seqs,
int step, int vs, int sBatch, int wBatch,
......@@ -430,7 +449,7 @@ int T2TTrainer::LoadBatch(FILE * file, bool isLM,
int devID, XMem * mem)
{
if(nextSeq < 0 || nextSeq >= nseqBuf)
LoadBuf(file);
LoadBuf(file, isSorted);
int seq = MAX(nextSeq, 0);
int wc = 0;
......
......@@ -118,7 +118,7 @@ public:
void Test(const char * fn, const char * ofn, T2TModel * model);
/* load data to buffer */
int LoadBuf(FILE * file);
int LoadBuf(FILE * file, bool isSorted, int step);
/* clear data buffer */
void ClearBuf();
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
请 注册 或者 登录 后发表评论