Commit 6ea64b51 by xiaotong

add length-based sorting

parent ad3fc86f
...@@ -70,6 +70,14 @@ void T2TModel::InitModel(int argc, const char ** argv) ...@@ -70,6 +70,14 @@ void T2TModel::InitModel(int argc, const char ** argv)
encoder.InitModel(argc, argv, isLM, 0, devID, mem); encoder.InitModel(argc, argv, isLM, 0, devID, mem);
outputLayer.InitModel(argc, argv, devID, mem); outputLayer.InitModel(argc, argv, devID, mem);
XList params(10);
GetParams(params);
for(int i = 0; i < params.count; i++){
XTensor * param = (XTensor*)params.Get(i);
param->SetVarFlag();
}
} }
/* /*
......
...@@ -41,7 +41,9 @@ T2TTrainer::T2TTrainer() ...@@ -41,7 +41,9 @@ T2TTrainer::T2TTrainer()
T2TTrainer::~T2TTrainer() T2TTrainer::~T2TTrainer()
{ {
delete[] buf; delete[] buf;
delete[] buf2;
delete[] seqLen; delete[] seqLen;
delete[] seqLen2;
delete[] seqOffset; delete[] seqOffset;
for(int i = 0; i < moments.count; i++){ for(int i = 0; i < moments.count; i++){
...@@ -82,7 +84,9 @@ void T2TTrainer::Init(int argc, const char ** argv) ...@@ -82,7 +84,9 @@ void T2TTrainer::Init(int argc, const char ** argv)
LoadParamFloat(argc, argv, "adamdelta", &adamDelta, 1e-8F); LoadParamFloat(argc, argv, "adamdelta", &adamDelta, 1e-8F);
buf = new int[bufSize]; buf = new int[bufSize];
buf2 = new int[bufSize];
seqLen = new int[bufSize]; seqLen = new int[bufSize];
seqLen2 = new int[bufSize];
seqOffset = new int[bufSize]; seqOffset = new int[bufSize];
adamBeta1T = 1.0F; adamBeta1T = 1.0F;
...@@ -328,9 +332,16 @@ char line[MAX_SEQUENCE_LENGTH]; ...@@ -328,9 +332,16 @@ char line[MAX_SEQUENCE_LENGTH];
struct SampleNode struct SampleNode
{ {
int id; int id;
int * p;
int size; int size;
int value;
}; };
int CompareSampleNode(const void * a, const void * b)
{
return ((SampleNode*)b)->value - ((SampleNode*)a)->value;
}
/* /*
load data to buffer load data to buffer
>> file - where to load data >> file - where to load data
...@@ -403,14 +414,46 @@ int T2TTrainer::LoadBuf(FILE * file, bool isSorted, int step) ...@@ -403,14 +414,46 @@ int T2TTrainer::LoadBuf(FILE * file, bool isSorted, int step)
nseqBuf = seqCount; nseqBuf = seqCount;
nextSeq = 0; nextSeq = 0;
/* sort the sequences by length */
if (isSorted) { if (isSorted) {
SampleNode * nodes = new SampleNode[seqCount]; SampleNode * nodes = new SampleNode[seqCount];
int count = 0; int count = 0;
int offset = 0;
for (int i = 0; i < seqCount; i += step) { for (int i = 0; i < seqCount; i += step) {
nodes[count].id = count; SampleNode &node = nodes[count];
nodes[count].size = seqLen[i]; node.id = count;
node.p = buf + offset;
node.size = 0;
for(int j = 0; j < step; j++)
node.size += seqLen[i + j];
node.value = seqLen[i];
count++; count++;
offset += node.size;
}
qsort(nodes, seqCount, sizeof(SampleNode), CompareSampleNode);
count = 0;
offset = 0;
for(int i = 0; i < seqCount; i++){
SampleNode &node = nodes[count];
//fprintf(stderr, "%d %d %d\n", node.size, node.id, node.value);
memcpy(buf2 + offset, node.p, sizeof(int) * node.size);
for(int j = 0; j < step; j++){
seqLen2[count + j] = seqLen[node.id + j];
seqOffset[count + j] = offset + (j > 0 ? seqLen[node.id + j - 1] : 0);
}
count += step;
offset += node.size;
} }
int * tmp = buf;
buf = buf2;
buf2 = tmp;
tmp = seqLen;
seqLen = seqLen2;
seqLen2 = tmp;
delete[] nodes; delete[] nodes;
} }
......
...@@ -40,12 +40,18 @@ public: ...@@ -40,12 +40,18 @@ public:
/* buffer for loading words */ /* buffer for loading words */
int * buf; int * buf;
/* another buffer */
int * buf2;
/* buffer size */ /* buffer size */
int bufSize; int bufSize;
/* length of each sequence */ /* length of each sequence */
int * seqLen; int * seqLen;
/* another array */
int * seqLen2;
/* offset of the first word for each sequence */ /* offset of the first word for each sequence */
int * seqOffset; int * seqOffset;
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论