Commit 49869ae5 by xuchen

recover the Main.cpp file

parent 126c581f
......@@ -35,8 +35,6 @@
void BackwardTest();
void TransposeTest();
void SumDimTest();
void SplitBackwardTest();
void MemTest();
using namespace nts;
using namespace fnnlm;
......@@ -44,10 +42,6 @@ using namespace transformer;
int main( int argc, const char ** argv )
{
//MemTest();
//return 0;
//SplitBackwardTest();
//return 0;
//_CrtSetBreakAlloc(896);
//BackwardTest();
//return 0;
......@@ -215,67 +209,3 @@ void SumDimTest()
delete[] data;
}
\ No newline at end of file
void SplitBackwardTest()
{
int * dimSize = new int[2];
dimSize[0] = 2;
dimSize[1] = 4;
XTensor t1;
InitTensor2D(&t1, 2, 4, X_FLOAT, 0, NULL);
XTensor t2;
InitTensor2D(&t2, 2, 4, X_FLOAT, 0, NULL);
XTensor tensor;
//_SetDataFixedFloat(&t1, 1.0F);
//_SetDataFixedFloat(&t2, 2.0F);
t1.SetDataRand();
t2.SetDataRand();
tensor = t1 + t2;
XList smalls;
XTensor first;
XTensor second;
InitTensor2D(&first, 2, 2, X_FLOAT, 0, NULL);
InitTensor2D(&second, 2, 2, X_FLOAT, 0, NULL);
smalls.Add(&first);
smalls.Add(&second);
Split(tensor, smalls, 1, 2);
XTensor mul;
mul = Sum(first, second);
XNet net;
net.Backward(mul);
net.Dump(stderr);
printf("Done!");
}
void MemTest()
{
XMem * mem;
mem = new XMem(0, FREE_ON_THE_FLY, (MTYPE)MILLION, 1024, MILLION);
XTensor tensor;
InitTensor2D(&tensor, 2, 4, X_FLOAT, 0, mem);
tensor.SetZeroAll();
tensor.Dump(stderr);
delete mem;
if (tensor.mem != NULL) {
printf("It isn't null!\n");
printf("%d\n", (int)tensor.mem->signature);
}
else {
printf("It's null\n");
}
tensor.Dump(stderr);
}
\ No newline at end of file
......@@ -873,9 +873,9 @@ int T2TTrainer::LoadBatchMT(FILE * file,
int dimsEnc[3] = {sCount, maxEnc, vsEnc};
int dimsDec[3] = {sCount, maxDec, vsDec};
InitTensor(batchEnc, 2, dimsEnc, X_INT, 1.0F, -1);
InitTensor2D(batchEnc, sCount, maxEnc, X_INT, devID, mem);
InitTensor2D(paddingEnc, sCount, maxEnc, X_FLOAT, devID, mem);
InitTensor(batchDec, 2, dimsDec, X_INT, 1.0F, -1);
InitTensor2D(batchDec, sCount, maxDec, X_INT, devID, mem);
InitTensor2D(paddingDec, sCount, maxDec, X_FLOAT, devID, mem);
InitTensor(gold, 3, dimsDec, X_FLOAT, 1.0F, devID, mem);
......@@ -887,6 +887,10 @@ int T2TTrainer::LoadBatchMT(FILE * file,
wCount = 0;
MTYPE * batchEncOffsets = new MTYPE[batchEnc->unitNum];
int * batchEncValues = new int[batchEnc->unitNum];
MTYPE * batchDecOffsets = new MTYPE[batchDec->unitNum];
int * batchDecValues = new int[batchDec->unitNum];
MTYPE * paddingEncOffsets = new MTYPE[sc * maxEnc / 2];
MTYPE * paddingDecOffsets = new MTYPE[sc * maxDec / 2];
MTYPE * goldOffsets = new MTYPE[sc * maxDec / 2];
......@@ -896,13 +900,18 @@ int T2TTrainer::LoadBatchMT(FILE * file,
int len = seqLen[s];
int sent = (s - seq)/2;
for(int w = 0; w < len; w++){
batchEnc->Set2DInt(buf[seqOffset[s] + w], sent, w);
//batchEnc->Set2DInt(buf[seqOffset[s] + w], sent, w);
//paddingEnc->Set2D(1.0F, sent, w);
int num = buf[seqOffset[s] + w];
batchEncOffsets[wCount] = batchEnc->GetOffset2D(sent, w);
batchEncValues[wCount] = num;
paddingEncOffsets[wCount] = paddingEnc->GetOffset2D(sent, w);
wCount++;
}
}
batchEnc->SetDataBatched(batchEncOffsets, batchEncValues, wCount);
paddingEnc->SetDataBatched(paddingEncOffsets, 1.0F, wCount);
int wCountDec = 0;
......@@ -914,9 +923,14 @@ int T2TTrainer::LoadBatchMT(FILE * file,
CheckNTErrors(len <= maxDec, "Something is wrong!");
int sent = (s - seq - 1)/2;
for(int w = 0; w < len; w++){
batchDec->Set2DInt(buf[seqOffset[s] + w], sent, w);
//batchDec->Set2DInt(buf[seqOffset[s] + w], sent, w);
//paddingDec->Set2D(1.0F, sent, w);
paddingDecOffsets[wCountDec++] = paddingDec->GetOffset2D(sent, w);
int num = buf[seqOffset[s] + w];
batchDecOffsets[wCountDec] = batchDec->GetOffset2D(sent, w);
batchDecValues[wCountDec] = num;
paddingDecOffsets[wCountDec] = paddingDec->GetOffset2D(sent, w);
if (w > 0) {
//gold->Set3D(1.0F, sent, w - 1, buf[seqOffset[s] + w]);
goldOffsets[wGold++] = gold->GetOffset3D(sent, w - 1, buf[seqOffset[s] + w]);
......@@ -932,7 +946,7 @@ int T2TTrainer::LoadBatchMT(FILE * file,
}
}
wCount++;
wCountDec++;
if(seqs != NULL)
seqs[seqSize++] = buf[seqOffset[s] + w];
}
......@@ -943,9 +957,14 @@ int T2TTrainer::LoadBatchMT(FILE * file,
}
}
batchDec->SetDataBatched(batchDecOffsets, batchDecValues, wCountDec);
paddingDec->SetDataBatched(paddingDecOffsets, 1.0F, wCountDec);
gold->SetDataBatched(goldOffsets, 1.0F, wGold);
delete[] batchEncOffsets;
delete[] batchEncValues;
delete[] batchDecOffsets;
delete[] batchDecValues;
delete[] paddingEncOffsets;
delete[] paddingDecOffsets;
delete[] goldOffsets;
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论