Commit 6f90577d by xiaotong

fix the bug in sorting sequences

parent a037d802
...@@ -194,8 +194,8 @@ void T2TModel::MakeMT(XTensor &inputEnc, XTensor &inputDec, XTensor &output, XTe ...@@ -194,8 +194,8 @@ void T2TModel::MakeMT(XTensor &inputEnc, XTensor &inputDec, XTensor &output, XTe
InitTensor(&maskDec, inputDec.order + 1, dims, X_FLOAT, 1.0F, inputDec.devID, inputDec.mem); InitTensor(&maskDec, inputDec.order + 1, dims, X_FLOAT, 1.0F, inputDec.devID, inputDec.mem);
/* an upper triangular matrix where the cells of the upper triangular are set to -1e9. /* an upper triangular matrix where the cells of the upper triangular are set to -1e9.
this matrix can be used to prevent the attention to current or following words in this matrix can be used to prevent the attention to current or following words in
a given sequence. */ a given sequence. */
_SetDataLowTri(&maskDec, 1e9F, 0); _SetDataLowTri(&maskDec, 1e9F, 0);
_ScaleAndShiftMe(&maskDec, 1.0F, -1e9F); _ScaleAndShiftMe(&maskDec, 1.0F, -1e9F);
......
...@@ -525,6 +525,7 @@ int T2TTrainer::LoadBuf(FILE * file, bool isSorted, int step) ...@@ -525,6 +525,7 @@ int T2TTrainer::LoadBuf(FILE * file, bool isSorted, int step)
/* sort the sequences by length */ /* sort the sequences by length */
if (isSorted) { if (isSorted) {
CheckNTErrors(seqCount % step == 0, "Wrong number of sequences!");
SampleNode * nodes = new SampleNode[seqCount]; SampleNode * nodes = new SampleNode[seqCount];
int count = 0; int count = 0;
int offset = 0; int offset = 0;
...@@ -540,18 +541,18 @@ int T2TTrainer::LoadBuf(FILE * file, bool isSorted, int step) ...@@ -540,18 +541,18 @@ int T2TTrainer::LoadBuf(FILE * file, bool isSorted, int step)
offset += node.size; offset += node.size;
} }
qsort(nodes, seqCount, sizeof(SampleNode), CompareSampleNode); qsort(nodes, count, sizeof(SampleNode), CompareSampleNode);
count = 0; count = 0;
offset = 0; offset = 0;
for(int i = 0; i < seqCount; i++){ for(int i = 0; i < seqCount; i += step){
SampleNode &node = nodes[count]; SampleNode &node = nodes[count];
memcpy(buf2 + offset, node.p, sizeof(int) * node.size); memcpy(buf2 + offset, node.p, sizeof(int) * node.size);
for(int j = 0; j < step; j++){ for(int j = 0; j < step; j++){
seqLen2[count + j] = seqLen[node.id + j]; seqLen2[i + j] = seqLen[node.id + j];
seqOffset[count + j] = offset + (j > 0 ? seqLen[node.id + j - 1] : 0); seqOffset[i + j] = offset + (j > 0 ? seqLen[node.id + j - 1] : 0);
} }
count += step; count += 1;
offset += node.size; offset += node.size;
} }
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论