Commit 12d11fab by xiaotong

distribute samples into buckets wrt length

parent c260abfb
...@@ -122,6 +122,7 @@ void T2TTrainer::Init(int argc, char ** argv) ...@@ -122,6 +122,7 @@ void T2TTrainer::Init(int argc, char ** argv)
LoadParamBool(argc, argv, "bigbatch", &isBigBatch, false); LoadParamBool(argc, argv, "bigbatch", &isBigBatch, false);
LoadParamBool(argc, argv, "debug", &isDebugged, false); LoadParamBool(argc, argv, "debug", &isDebugged, false);
LoadParamBool(argc, argv, "randbatch", &isRandomBatch, false); LoadParamBool(argc, argv, "randbatch", &isRandomBatch, false);
LoadParamInt(argc, argv, "bucketsize", &bucketSize, 0);
buf = new int[bufSize]; buf = new int[bufSize];
buf2 = new int[bufSize]; buf2 = new int[bufSize];
...@@ -459,6 +460,7 @@ struct SampleNode ...@@ -459,6 +460,7 @@ struct SampleNode
int * p; int * p;
int size; int size;
int value; int value;
int key;
}; };
int CompareSampleNode(const void * a, const void * b) int CompareSampleNode(const void * a, const void * b)
...@@ -466,6 +468,11 @@ int CompareSampleNode(const void * a, const void * b) ...@@ -466,6 +468,11 @@ int CompareSampleNode(const void * a, const void * b)
return ((SampleNode*)b)->value - ((SampleNode*)a)->value; return ((SampleNode*)b)->value - ((SampleNode*)a)->value;
} }
int CompareSampleNodeV2(const void * a, const void * b)
{
return ((SampleNode*)b)->key - ((SampleNode*)a)->key;
}
/* /*
load data to buffer load data to buffer
>> file - where to load data >> file - where to load data
...@@ -553,12 +560,34 @@ int T2TTrainer::LoadBuf(FILE * file, bool isSorted, int step) ...@@ -553,12 +560,34 @@ int T2TTrainer::LoadBuf(FILE * file, bool isSorted, int step)
for(int j = 0; j < step; j++) for(int j = 0; j < step; j++)
node.size += seqLen[i + j]; node.size += seqLen[i + j];
node.value = seqLen[i]; node.value = seqLen[i];
node.key = rand();
count++; count++;
offset += node.size; offset += node.size;
} }
qsort(nodes, count, sizeof(SampleNode), CompareSampleNode); qsort(nodes, count, sizeof(SampleNode), CompareSampleNode);
if (bucketSize > 0) {
int bucketCount = 0;
int low = 0;
int high = low + bucketSize;
int n = count - 1;
int m = n;
int num = 0;
while (num < count) {
for (m = n; m >= 0; m--) {
if (nodes[m].value > high)
break;
}
qsort(nodes + m + 1, n - m, sizeof(SampleNode), CompareSampleNodeV2);
num += (n - m);
n = m;
low += bucketSize;
high = low + bucketSize;
}
}
count = 0; count = 0;
offset = 0; offset = 0;
for(int i = 0; i < seqCount; i += step){ for(int i = 0; i < seqCount; i += step){
......
...@@ -176,6 +176,9 @@ public: ...@@ -176,6 +176,9 @@ public:
/* indicates whether we intend to debug the net */ /* indicates whether we intend to debug the net */
bool isDebugged; bool isDebugged;
/* bucket size */
int bucketSize;
public: public:
/* constructor */ /* constructor */
T2TTrainer(); T2TTrainer();
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论