Commit 6ea64b51 by xiaotong

add length-based sorting

parent ad3fc86f
......@@ -70,6 +70,14 @@ void T2TModel::InitModel(int argc, const char ** argv)
encoder.InitModel(argc, argv, isLM, 0, devID, mem);
outputLayer.InitModel(argc, argv, devID, mem);
XList params(10);
GetParams(params);
for(int i = 0; i < params.count; i++){
XTensor * param = (XTensor*)params.Get(i);
param->SetVarFlag();
}
}
/*
......
......@@ -41,7 +41,9 @@ T2TTrainer::T2TTrainer()
T2TTrainer::~T2TTrainer()
{
delete[] buf;
delete[] buf2;
delete[] seqLen;
delete[] seqLen2;
delete[] seqOffset;
for(int i = 0; i < moments.count; i++){
......@@ -81,8 +83,10 @@ void T2TTrainer::Init(int argc, const char ** argv)
LoadParamFloat(argc, argv, "adambeta2", &adamBeta2, 0.999F);
LoadParamFloat(argc, argv, "adamdelta", &adamDelta, 1e-8F);
buf = new int[bufSize];
seqLen = new int[bufSize];
buf = new int[bufSize];
buf2 = new int[bufSize];
seqLen = new int[bufSize];
seqLen2 = new int[bufSize];
seqOffset = new int[bufSize];
adamBeta1T = 1.0F;
......@@ -328,9 +332,16 @@ char line[MAX_SEQUENCE_LENGTH];
struct SampleNode
{
int id;
int * p;
int size;
int value;
};
int CompareSampleNode(const void * a, const void * b)
{
return ((SampleNode*)b)->value - ((SampleNode*)a)->value;
}
/*
load data to buffer
>> file - where to load data
......@@ -403,14 +414,46 @@ int T2TTrainer::LoadBuf(FILE * file, bool isSorted, int step)
nseqBuf = seqCount;
nextSeq = 0;
/* sort the sequences by length */
if (isSorted) {
SampleNode * nodes = new SampleNode[seqCount];
int count = 0;
int offset = 0;
for (int i = 0; i < seqCount; i += step) {
nodes[count].id = count;
nodes[count].size = seqLen[i];
SampleNode &node = nodes[count];
node.id = count;
node.p = buf + offset;
node.size = 0;
for(int j = 0; j < step; j++)
node.size += seqLen[i + j];
node.value = seqLen[i];
count++;
offset += node.size;
}
qsort(nodes, seqCount, sizeof(SampleNode), CompareSampleNode);
count = 0;
offset = 0;
for(int i = 0; i < seqCount; i++){
SampleNode &node = nodes[count];
//fprintf(stderr, "%d %d %d\n", node.size, node.id, node.value);
memcpy(buf2 + offset, node.p, sizeof(int) * node.size);
for(int j = 0; j < step; j++){
seqLen2[count + j] = seqLen[node.id + j];
seqOffset[count + j] = offset + (j > 0 ? seqLen[node.id + j - 1] : 0);
}
count += step;
offset += node.size;
}
int * tmp = buf;
buf = buf2;
buf2 = tmp;
tmp = seqLen;
seqLen = seqLen2;
seqLen2 = tmp;
delete[] nodes;
}
......
......@@ -40,12 +40,18 @@ public:
/* buffer for loading words */
int * buf;
/* another buffer */
int * buf2;
/* buffer size */
int bufSize;
/* length of each sequence */
int * seqLen;
/* another array */
int * seqLen2;
/* offset of the first word for each sequence */
int * seqOffset;
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论