Commit 543320bc by xiaotong

remove </s> for transformer lm

parent 98ac897b
...@@ -112,6 +112,7 @@ void T2TTrainer::Init(int argc, char ** argv) ...@@ -112,6 +112,7 @@ void T2TTrainer::Init(int argc, char ** argv)
LoadParamInt(argc, argv, "nstepcheckpoint", &nStepCheckpoint, -1); LoadParamInt(argc, argv, "nstepcheckpoint", &nStepCheckpoint, -1);
LoadParamBool(argc, argv, "epochcheckpoint", &useEpochCheckpoint, false); LoadParamBool(argc, argv, "epochcheckpoint", &useEpochCheckpoint, false);
LoadParamInt(argc, argv, "updatestep", &updateStep, 1); LoadParamInt(argc, argv, "updatestep", &updateStep, 1);
LoadParamBool(argc, argv, "doubledend", &isDoubledEnd, false);
buf = new int[bufSize]; buf = new int[bufSize];
buf2 = new int[bufSize]; buf2 = new int[bufSize];
...@@ -590,7 +591,9 @@ int T2TTrainer::LoadBatch(FILE * file, bool isLM, ...@@ -590,7 +591,9 @@ int T2TTrainer::LoadBatch(FILE * file, bool isLM,
int sc = 0; int sc = 0;
int max = 0; int max = 0;
while(seq + sc < nseqBuf){ while(seq + sc < nseqBuf){
wn = seqLen[seq + sc]; int len = isDoubledEnd ? seqLen[seq + sc] : seqLen[seq + sc] - 1;
CheckNTErrors(len > 0, "Empty sequence!");
wn = len;
wc += wn; wc += wn;
sc += 1; sc += 1;
...@@ -645,13 +648,19 @@ int T2TTrainer::LoadBatch(FILE * file, bool isLM, ...@@ -645,13 +648,19 @@ int T2TTrainer::LoadBatch(FILE * file, bool isLM,
/* this might be slow on GPUs :( */ /* this might be slow on GPUs :( */
for(int s = seq; s < seq + sc; s++){ for(int s = seq; s < seq + sc; s++){
for(int w = 0; w < seqLen[s]; w++){ int len = isDoubledEnd ? seqLen[s] : seqLen[s] - 1;
CheckNTErrors(len <= max, "Something is wrong!");
for(int w = 0; w < len; w++){
batch->Set3D(1.0F, s - seq, w, buf[seqOffset[s] + w]); batch->Set3D(1.0F, s - seq, w, buf[seqOffset[s] + w]);
padding->Set2D(1.0F, s - seq, w); padding->Set2D(1.0F, s - seq, w);
if(w > 0) if(w > 0)
output->Set3D(1.0F, s - seq, w - 1, buf[seqOffset[s] + w]); output->Set3D(1.0F, s - seq, w - 1, buf[seqOffset[s] + w]);
if(w == seqLen[s] - 1) if(w == len - 1){
output->Set3D(1.0F, s - seq, w, buf[seqOffset[s] + w]); if(isDoubledEnd)
output->Set3D(1.0F, s - seq, w, buf[seqOffset[s] + w]);
else
output->Set3D(1.0F, s - seq, w, buf[seqOffset[s] + w + 1]);
}
wCount++; wCount++;
/*fprintf(tf, "%d", buf[seqOffset[s] + w]); /*fprintf(tf, "%d", buf[seqOffset[s] + w]);
if(w < seqLen[s] - 1) if(w < seqLen[s] - 1)
...@@ -663,7 +672,7 @@ int T2TTrainer::LoadBatch(FILE * file, bool isLM, ...@@ -663,7 +672,7 @@ int T2TTrainer::LoadBatch(FILE * file, bool isLM,
} }
if(seqs != NULL){ if(seqs != NULL){
for(int w = seqLen[s]; w < max; w++) for(int w = len; w < max; w++)
seqs[seqSize++] = -1; seqs[seqSize++] = -1;
} }
} }
......
...@@ -127,6 +127,9 @@ public: ...@@ -127,6 +127,9 @@ public:
/* number of batches on which we do model update */ /* number of batches on which we do model update */
int updateStep; int updateStep;
/* indicates whether we double the </s> symble for the output of lms */
bool isDoubledEnd;
public: public:
/* constructor */ /* constructor */
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论