Commit 12d11fab by xiaotong

distribute samples into buckets wrt length

parent c260abfb
......@@ -122,6 +122,7 @@ void T2TTrainer::Init(int argc, char ** argv)
LoadParamBool(argc, argv, "bigbatch", &isBigBatch, false);
LoadParamBool(argc, argv, "debug", &isDebugged, false);
LoadParamBool(argc, argv, "randbatch", &isRandomBatch, false);
LoadParamInt(argc, argv, "bucketsize", &bucketSize, 0);
buf = new int[bufSize];
buf2 = new int[bufSize];
......@@ -459,6 +460,7 @@ struct SampleNode
int * p;
int size;
int value;
int key;
};
int CompareSampleNode(const void * a, const void * b)
......@@ -466,6 +468,11 @@ int CompareSampleNode(const void * a, const void * b)
return ((SampleNode*)b)->value - ((SampleNode*)a)->value;
}
int CompareSampleNodeV2(const void * a, const void * b)
{
return ((SampleNode*)b)->key - ((SampleNode*)a)->key;
}
/*
load data to buffer
>> file - where to load data
......@@ -553,12 +560,34 @@ int T2TTrainer::LoadBuf(FILE * file, bool isSorted, int step)
for(int j = 0; j < step; j++)
node.size += seqLen[i + j];
node.value = seqLen[i];
node.key = rand();
count++;
offset += node.size;
}
qsort(nodes, count, sizeof(SampleNode), CompareSampleNode);
if (bucketSize > 0) {
int bucketCount = 0;
int low = 0;
int high = low + bucketSize;
int n = count - 1;
int m = n;
int num = 0;
while (num < count) {
for (m = n; m >= 0; m--) {
if (nodes[m].value > high)
break;
}
qsort(nodes + m + 1, n - m, sizeof(SampleNode), CompareSampleNodeV2);
num += (n - m);
n = m;
low += bucketSize;
high = low + bucketSize;
}
}
count = 0;
offset = 0;
for(int i = 0; i < seqCount; i += step){
......
......@@ -176,6 +176,9 @@ public:
/* indicates whether we intend to debug the net */
bool isDebugged;
/* bucket size */
int bucketSize;
public:
/* constructor */
T2TTrainer();
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论