Commit 80ab70a2 by xuchen

Merge branch 'xuchen' into xiaotong-working

parents 411cff4c b83a6798
@@ -375,7 +375,7 @@ void XShapeGrad::GradSplitList(XTensor * node, bool isEfficient)
     XTensor * input = income.tails[0];
     CheckNTErrors(income.tailNum == 1, "Wrong input tensor number for SPLIT!");
-    CheckNTErrors(node->order == input->order + 1, "Wrong tensor orders!");
+    //CheckNTErrors(node->order == input->order + 1, "Wrong tensor orders!");
     node->visitMark = NODE_DOING;
 }
...
@@ -96,7 +96,7 @@ void XNet::Backward(XTensor &root, XTensor &gold, LOSS_FUNCTION_NAME loss)
 backward propagation to obtain gradient wrt. the loss/error function
 >> root - root node (output) of the network
 >> gold - gold standard for the output
->> padding - specify a target value that is ignored and does not contribute to the loss computation
+>> padding - specify a target value that is ignored and does not contribute to the gradient computation
 >> loss - name of loss function
 */
 void XNet::Backward(XTensor &root, XTensor &gold, XTensor &padding, LOSS_FUNCTION_NAME loss)
@@ -135,9 +135,9 @@ void XNet::Backward(XTensor &root, LOSS_FUNCTION_NAME loss)
 /*
 backward propagation to obtain gradient wrt. the loss/error function
 with a number of root nodes
->> root - a list of root nodes (output) of the network
->> gold - a list of gold standard for the output
->> padding - specify a target value that is ignored
+>> roots - a list of root nodes (output) of the network
+>> golds - a list of gold standard for the output
+>> paddings - specify a target value that is ignored
 >> loss - name of loss function
 */
 void XNet::Backward(XList &roots, XList &golds, XList &paddings, LOSS_FUNCTION_NAME loss)
...
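Note: a minimal usage sketch of this list-based overload, assuming a network with two output tensors; the tensor names below are hypothetical, and CROSSENTROPY is one of the LOSS_FUNCTION_NAME values:

    XNet net;
    XList roots, golds, paddings;
    roots.Add(&output1);    roots.Add(&output2);    // root nodes (network outputs)
    golds.Add(&gold1);      golds.Add(&gold2);      // gold standards, one per root
    paddings.Add(&pad1);    paddings.Add(&pad2);    // padding masks to ignore
    net.Backward(roots, golds, paddings, CROSSENTROPY);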
@@ -125,9 +125,6 @@ void T2TTrainer::Init(int argc, char ** argv)
     adamBeta1T = 1.0F;
     adamBeta2T = 1.0F;
-    validStep = 0;
-    curEpoch = 0;
 }

 int tc = 0;
@@ -139,10 +136,8 @@ train the model
 >> modelFN - where we keep the model
 >> model - model to train
 */
-bool T2TTrainer::Train(const char * fn, const char * validFN, const char * modelFN, T2TModel * model)
+void T2TTrainer::Train(const char * fn, const char * validFN, const char * modelFN, T2TModel * model)
 {
-    curEpoch += 1;
     int step = 0;
     int wc = 0;
     int wordCount = 0;
@@ -154,7 +149,8 @@ bool T2TTrainer::Train(const char * fn, const char * validFN, const char * model
     int nCheckpoint = 0;
     int nSkipped = 0;
     int gradStep = 0;
-    //int validStep = 0;
+    int validStep = 0;
+    int epoch = 0;
     char * trainFN = new char[(int)strlen(fn) + 10];
     strcpy(trainFN, fn);
@@ -172,7 +168,7 @@ bool T2TTrainer::Train(const char * fn, const char * validFN, const char * model
     double startT = GetClockSec();
-    //for(epoch = 1; epoch <= nepoch; epoch++){
+    for(epoch = 1; epoch <= nepoch; epoch++){
 #ifndef WIN32
         if(isShuffled)
             Shuffle(fn, trainFN);
@@ -204,7 +200,6 @@ bool T2TTrainer::Train(const char * fn, const char * validFN, const char * model
         {
             CheckNTErrors(batchEnc.order == 2, "wrong tensor order of the sequence batch");
-            //CheckNTErrors(batchEnc.order == 3, "wrong tensor order of the sequence batch");
             /* output probabilities */
             XTensor output;
@@ -271,25 +266,27 @@ bool T2TTrainer::Train(const char * fn, const char * validFN, const char * model
                 break;
             }
-            if (step % 1 == 0) {
+            if (step % 100 == 0) {
                 double elapsed = GetClockSec() - startT;
                 XPRINT8(0, stderr, "[INFO] lr=%.2e, elapsed=%.1fs, step=%d, epoch=%d, word=%d, loss=%.3f, ppl=%.3f, sppl=%.3f",
-                        lr, elapsed, step, curEpoch, wordCountTotal, loss/wordCount, exp(loss/wordCount), exp(-prob/wc));
+                        lr, elapsed, step, epoch, wordCountTotal, loss/wordCount, exp(loss/wordCount), exp(-prob/wc));
                 if (!doUpdate)
                     XPRINT(0, stderr, " (no update)");
                 XPRINT(0, stderr, "\n");
             }
-            XMem * mem = model->mem;
-            MTYPE used = 0;
-            MTYPE total = 0;
-            for(int i = 0; i < mem->blockNum; i++){
-                if(mem->blocks[i].mem != NULL){
-                    used += mem->blocks[i].used;
-                    total += mem->blocks[i].size;
-                }
-            }
-            fprintf(stderr, "%d %d %d %d mem: %lld %lld\n", paddingEnc.GetDim(0), paddingEnc.GetDim(1), paddingDec.GetDim(0), paddingDec.GetDim(1), used, total);
+            //XMem * mem = model->mem;
+            //MTYPE used = 0;
+            //MTYPE total = 0;
+            //for(int i = 0; i < mem->blockNum; i++){
+            //    if(mem->blocks[i].mem != NULL){
+            //        used += mem->blocks[i].used;
+            //        total += mem->blocks[i].size;
+            //    }
+            //}
+            //fprintf(stderr, "%d %d %d %d mem: %lld %lld\n", paddingEnc.GetDim(0), paddingEnc.GetDim(1),
+            //        paddingDec.GetDim(0), paddingDec.GetDim(1), used, total);
             if(nStepCheckpoint > 0 && ++nStepCheck >= nStepCheckpoint){
                 MakeCheckpoint(model, validFN, modelFN, "step", step);
@@ -301,20 +298,20 @@ bool T2TTrainer::Train(const char * fn, const char * validFN, const char * model
         fclose(file);
         if (isEnd)
-            return false;
-        return true;
-        //if(useEpochCheckpoint)
-        //    MakeCheckpoint(model, validFN, modelFN, "epoch", epoch);
-    //}
-    //double elapsed = GetClockSec() - startT;
-    //
-    //epoch = MIN(epoch, nepoch);
-    //
-    //XPRINT7(0, stderr, "[INFO] lr=%.2e, elapsed=%.1fs, step=%d, epoch=%d, word=%d, loss=%.3f, ppl=%.3f\n",
-    //        lr, elapsed, step, epoch, wordCountTotal, loss/wordCount, exp(loss/wordCount));
-    //XPRINT4(0, stderr, "[INFO] training finished (took %.1fs, step=%d, skipped=%d and epoch=%d)\n",
-    //        elapsed, step, nSkipped, epoch);
+            break;
+        if(useEpochCheckpoint)
+            MakeCheckpoint(model, validFN, modelFN, "epoch", epoch);
+    }
+    double elapsed = GetClockSec() - startT;
+    epoch = MIN(epoch, nepoch);
+    XPRINT7(0, stderr, "[INFO] lr=%.2e, elapsed=%.1fs, step=%d, epoch=%d, word=%d, loss=%.3f, ppl=%.3f\n",
+            lr, elapsed, step, epoch, wordCountTotal, loss/wordCount, exp(loss/wordCount));
+    XPRINT4(0, stderr, "[INFO] training finished (took %.1fs, step=%d, skipped=%d and epoch=%d)\n",
+            elapsed, step, nSkipped, epoch);
     delete[] trainFN;
 }
@@ -368,8 +365,6 @@ void T2TTrainer::Test(const char * fn, const char * ofn, T2TModel * model)
                            seqs, vSize, vSizeTgt,
                            1, 1, false, wc, devID, mem, false))
         {
-            //CheckNTErrors(batchEnc.order == 3, "wrong tensor order of the sequence batch");
             CheckNTErrors(batchEnc.order == 2, "wrong tensor order of the sequence batch");
             /* output probabilities */
...
@@ -103,10 +103,6 @@ public:
     /* indicates whether we use adam */
     bool useAdam;
-    int validStep;
-    int curEpoch;
     /* hyper parameters of adam*/
     float adamBeta1;
     float adamBeta2;
@@ -157,7 +153,7 @@ public:
     void Init(int argc, char ** argv);
     /* train the model */
-    bool Train(const char * fn, const char * validFN, const char * modelFN, T2TModel * model);
+    void Train(const char * fn, const char * validFN, const char * modelFN, T2TModel * model);
     /* test the model */
     void Test(const char * fn, const char * ofn, T2TModel * model);
...
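Note: with this change the epoch loop moves back inside Train() and the bool return value goes away, so the caller no longer drives epochs itself. A sketch of the new call pattern, mirroring the TransformerMain change below (variable names follow that function):

    T2TTrainer trainer;
    trainer.Init(argc, args);          // reads nepoch, lr, etc. from the arguments
    T2TModel model;
    model.InitModel(argc, args);
    trainer.Train(trainFN, testFN, modelFN, &model);   // runs all nepoch epochs internally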
@@ -58,74 +58,19 @@ int TransformerMain(int argc, const char ** argv)
     LoadParamString(argc, args, "test", testFN, "");
     LoadParamString(argc, args, "output", outputFN, "");
-    /* learn model parameters */
-    if(strcmp(trainFN, "")) {
-        double startT = GetClockSec();
     T2TTrainer trainer;
     trainer.Init(argc, args);
-        char * fn = new char[MAX_LINE_LENGTH];
-        char * fn1 = new char[MAX_LINE_LENGTH];
-        char * fn2 = new char[MAX_LINE_LENGTH];
-        //modelFN = strcmp(modelFN, "") ? modelFN : (char *)"checkpoint.model";
-        int epoch;
-        bool isTrain;
-        for(epoch = 1; epoch <= trainer.nepoch; epoch++) {
-            sprintf(fn, "%s.%s.%03d", modelFN, "epoch", epoch - 1);
-            sprintf(fn1, "%s.%s.%03d", modelFN, "epoch", epoch);
-            sprintf(fn2, "%s.%s.%03d.output", modelFN, "epoch", epoch);
-            if(epoch == 1) {
-                T2TModel model;
-                model.InitModel(argc, args);
-                isTrain = trainer.Train(trainFN, testFN, modelFN, &model);
-                //model.Dump(fn1);
-            }
-            else {
-                T2TModel model;
-                model.InitModel(argc, args);
-                model.Read(fn);
-                isTrain = trainer.Train(trainFN, testFN, modelFN, &model);
-                //model.Dump(fn1);
-            }
-            if(trainer.useEpochCheckpoint && strcmp(testFN, "")) {
-                T2TTrainer tester;
-                tester.Init(argc, args);
     T2TModel model;
     model.InitModel(argc, args);
-                //model.Read(fn1);
-                //tester.Test(testFN, fn2, &model);
-            }
-            if(!isTrain)
-                break;
-        }
-        double elapsed = GetClockSec() - startT;
-        epoch = MIN(epoch, trainer.nepoch);
-        XPRINT2(0, stderr, "[INFO] training finished (took %.1fs and epoch=%d)\n", elapsed, epoch);
-        delete[] fn;
-        delete[] fn1;
-        delete[] fn2;
-    }
-    /* don't dump the final model */
+    /* learn model parameters */
+    if(strcmp(trainFN, ""))
+        trainer.Train(trainFN, testFN, modelFN, &model);
     /* save the final model */
-    //if(strcmp(modelFN, "") && strcmp(trainFN, ""))
-    //    model.Dump(modelFN);
-    T2TModel model;
-    model.InitModel(argc, args);
+    if(strcmp(modelFN, "") && strcmp(trainFN, ""))
+        model.Dump(modelFN);
     /* load the model if neccessary */
     if(strcmp(modelFN, ""))
...
@@ -39,9 +39,20 @@ using namespace nts;
 void SmallTest();
 void TransposeTest();
+void LittleTest();
+void T2TTest();
+void T2TTest2();
+void PowerTest();
 int main( int argc, const char ** argv )
 {
+    //PowerTest();
+    //LittleTest();
+    //T2TTest();
+    //T2TTest2();
+    //return 0;
     //_CrtSetBreakAlloc(123);
     /* a tiny test */
@@ -63,6 +74,34 @@ int main( int argc, const char ** argv )
     return 0;
 }
+void myRead(XTensor * tensor, const char * filename, const char * label)
+{
+    FILE * file = fopen(filename, "rb");
+    if(file == NULL)
+        printf("%s\n", filename);
+    tensor->Read(file, label);
+}
+
+void myDump(XTensor * tensor, const char * filename, const char * label)
+{
+    FILE * file = fopen(filename, "wb");
+    if(file == NULL)
+        printf("%s\n", filename);
+    tensor->Dump(file, label);
+}
+
+void PowerTest()
+{
+    XTensor input;
+    XTensor output;
+    InitTensor2D(&input, 256, 10000, X_FLOAT, 0);
+    InitTensor2D(&output, 256, 10000, X_FLOAT, 0);
+    myRead(&input, "1.txt", "");
+    _Power(&input, &output, 2);
+    output.Dump(stderr, "", 200);
+}
 void SmallTest()
 {
     XTensor a;
@@ -126,3 +165,128 @@ void TransposeTest()
     delete[] data;
 }
+void LittleTest()
+{
+    int a = 5000;
+    int b = 100000;
+    int c = a*b;
+    printf("%d\n", c);
+    exit(1);
+}
+
+void T2TTest()
+{
+    XTensor * input;
+    XTensor * weight;
+    XTensor * output;
+    XTensor * gold;
+    XTensor * dedy;
+    XTensor * dedx;
+    XTensor * dedxTmp;
+    XTensor * dedw;
+    XTensor * padding;
+
+    DTYPE loss;
+
+    int * dimSize = new int[2];
+    dimSize[0] = 256;
+    dimSize[1] = 10001;
+
+    int * dimSize2 = new int[3];
+    dimSize2[0] = 2;
+    dimSize2[1] = 31;
+    dimSize2[2] = 256;
+
+    int * dimSize3 = new int[3];
+    dimSize3[0] = 2;
+    dimSize3[1] = 31;
+    dimSize3[2] = 10001;
+
+    int * dimSize4 = new int[2];
+    dimSize4[0] = 2;
+    dimSize4[1] = 31;
+
+    input = NewTensor(3, dimSize2, X_FLOAT, 1.0F, 0);
+    weight = NewTensor(2, dimSize, X_FLOAT, 1.0F, 0);
+    dedw = NewTensor(2, dimSize, X_FLOAT, 1.0F, 0);
+    gold = NewTensor(3, dimSize3, X_FLOAT, 1.0F, 0);
+    output = NewTensor(3, dimSize3, X_FLOAT, 1.0F, 0);
+    dedy = NewTensor(3, dimSize3, X_FLOAT, 1.0F, 0);
+    dedx = NewTensor(3, dimSize3, X_FLOAT, 1.0F, 0);
+    dedxTmp = NewTensor(3, dimSize3, X_FLOAT, 1.0F, 0);
+    padding = NewTensor(2, dimSize4, X_FLOAT, 1.0F, 0);
+
+    //weight = NewTensor(2, dimSize);
+    //dedw = NewTensor(2, dimSize);
+    //input = NewTensor(3, dimSize2);
+    //gold = NewTensor(3, dimSize3);
+    //output = NewTensor(3, dimSize3);
+    //dedy = NewTensor(3, dimSize3);
+    //dedx = NewTensor(3, dimSize3);
+    //dedxTmp = NewTensor(3, dimSize3);
+    //padding = NewTensor(2, dimSize4);
+
+    myRead(input, "x.txt", "x");
+    myRead(weight, "w.txt", "w");
+    myRead(gold, "gold.txt", "gold");
+    myRead(padding, "padding.txt", "padding");
+
+    XTensor inter;
+    inter = MMul(*input, *weight);
+    _Softmax(&inter, output, 2);
+
+    //_LogMe(output);
+    loss = _CrossEntropyFast(output, gold, REDUCE_MEAN, NULL, padding);
+    printf("loss: %f\n", loss);
+
+    _CrossEntropyBackward(dedy, output, gold, NULL);
+    //_CrossEntropyBackward(dedy, output, gold, NULL, padding);
+    myDump(dedy, "dedy.txt", "dedy");
+
+    _SoftmaxBackward(NULL, output, input, dedy, dedx, NULL, -1, NOLOSS);
+    _Sub(output, gold, dedxTmp);
+
+    myDump(dedx, "dedx.txt", "dedx");
+    dedx->Dump(stderr, "dedx", 200);
+    dedxTmp->Dump(stderr, "dedxTmp", 200);
+
+    input->Reshape(input->unitNum/input->GetDim(-1), input->GetDim(-1));
+    dedx->Reshape(dedx->unitNum/dedx->GetDim(-1), dedx->GetDim(-1));
+    _MatrixMulBatched(input, X_TRANS, dedx, X_NOTRANS, dedw);
+    myDump(dedw, "dedw.txt", "dedw");
+}
+
+void T2TTest2()
+{
+    int dimSize[3];
+    dimSize[0] = 161;
+    dimSize[1] = 47;
+    dimSize[2] = 10001;
+    XTensor * probs = NewTensor(3, dimSize, X_FLOAT, 1.0F, 0);
+    //XTensor * probs = NewTensor(3, dimSize, X_FLOAT, 1.0F, -1);
+
+    //myRead(probs, "probs.txt", " ");
+    _SetDataFixedFloat(probs, 1.0F);
+
+    probs->Reshape(1, probs->unitNum);
+    DTYPE sum = _ReduceSumAll(probs);
+    printf("%e\n", sum);
+
+    //XTensor tmp;
+    //tmp = IsNonZero(*probs);
+    //DTYPE nonZeroNum = ReduceSumAll(tmp);
+    //printf("%f\n", nonZeroNum);
+    //
+    //DTYPE gpu = ReduceSum(*probs, 1).Get2D(0, 0);
+    //printf("%e\n", gpu);
+}
@@ -1121,7 +1121,7 @@ bool XTensor::Set3D(DTYPE value, int d0, int d1, int d2)
     CheckNTErrors(order == 3, "Cannot get a 2d cell for a tensor whose order is not 2!");
     CheckNTErrors(d0 >= 0 && d0 < dimSize[0], "dimension 0 is out of range!");
     CheckNTErrors(d1 >= 0 && d1 < dimSize[1], "dimension 1 is out of range!");
-    CheckNTErrors(d2 >= 0 && d2 < dimSize[2], "dimension 1 is out of range!");
+    CheckNTErrors(d2 >= 0 && d2 < dimSize[2], "dimension 2 is out of range!");
     CheckNTErrors(dataType == DEFAULT_DTYPE, "The tensor is not in default type.");
     int dims[3] = {d0, d1, d2};
...
@@ -217,7 +217,6 @@ void _Merge(const XList * smalls, XTensor * big, int whereToMerge)
     XTensor* smallsItem0 = (XTensor*)(smalls->GetItem(0));
     int itemSize = smallsItem0->unitNum * smallsItem0->unitSize;
-
     for (int i = 0; i < smalls->count; i++) {
         XTensor* smallsItem = (XTensor*)smalls->GetItem(i);
         CheckNTErrors((big->unitNum == smallsItem->unitNum * mergeNum), "Unmatched tensors!");
...
@@ -342,6 +342,24 @@ split a big tensor into small tensors
 */
 void Split(const XTensor &big, XList &smalls, int whereToSplit, int splitNum)
 {
+    CheckNTErrors(big.GetDim(whereToSplit) % splitNum == 0, "Wrong splitNum!");
+
+    int order = big.order;
+    int * dimSize = new int[order];
+    for (int i = 0; i < big.order; i++) {
+        if (i != whereToSplit)
+            dimSize[i] = big.dimSize[i];
+        else
+            dimSize[i] = big.dimSize[whereToSplit] / splitNum;
+    }
+
+    float dr = (!big.isSparse) ? 1.0F : big.denseRatio;
+    for (int i = 0; i < splitNum; i++) {
+        XTensor * item = NewTensor(order, dimSize, big.dataType, dr, big.devID, big.mem);
+        smalls.Add(item);
+    }
+
+    delete[] dimSize;
     /* call _Split function */
     _Split(&big, &smalls, whereToSplit, splitNum);
...
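Note: after this change Split() allocates the output tensors itself (one NewTensor per piece), so the caller passes in an empty XList and reads the pieces back afterwards, as the updated TestSplit3 below does. A small sketch, assuming a hypothetical 2 x 4 CPU tensor split along dimension 1:

    XTensor big;
    InitTensor2D(&big, 2, 4, X_FLOAT, -1);   // hypothetical 2 x 4 tensor on CPU
    big.SetZeroAll();
    XList smalls;                            // empty; Split() fills it
    Split(big, smalls, 1, 2);                // dimension 1: 4 -> two pieces of size 2
    XTensor * piece0 = (XTensor*)smalls.Get(0);
    XTensor * piece1 = (XTensor*)smalls.Get(1);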
@@ -272,29 +272,23 @@ bool TestSplit3()
     XTensor * s = NewTensor(sOrder, sDimSize);
     XTensor * t1 = NewTensor(tOrder1, tDimSize1);
     XTensor * t2 = NewTensor(tOrder2, tDimSize2);
-    XTensor * tUser1 = NewTensor(tOrder1, tDimSize1);
-    XTensor * tUser2 = NewTensor(tOrder2, tDimSize2);
     /* initialize variables */
     s->SetData(sData, sUnitNum);
     t1->SetZeroAll();
     t2->SetZeroAll();
-    tUser1->SetZeroAll();
-    tUser2->SetZeroAll();
     /* add tensors to list */
     tList->Add(t1);
     tList->Add(t2);
-    tUserList.Add(tUser1);
-    tUserList.Add(tUser2);
     /* call split function */
     _Split(s, tList, 1, 2);
     Split(*s, tUserList, 1, 2);
     /* check results */
-    cpuTest = t1->CheckData(answer1, tUnitNum1) && tUser1->CheckData(answer1, tUnitNum1)
-              && t2->CheckData(answer2, tUnitNum2) && tUser2->CheckData(answer2, tUnitNum2);
+    cpuTest = t1->CheckData(answer1, tUnitNum1) && ((XTensor *)tUserList.Get(0))->CheckData(answer1, tUnitNum1) &&
+              t2->CheckData(answer2, tUnitNum2) && ((XTensor *)tUserList.Get(1))->CheckData(answer2, tUnitNum2);
 #ifdef USE_CUDA
     /* GPU test */
@@ -308,42 +302,31 @@ bool TestSplit3()
     XTensor * sGPU = NewTensor(sOrder, sDimSize, X_FLOAT, 1.0F, 0);
     XTensor * tGPU1 = NewTensor(tOrder1, tDimSize1, X_FLOAT, 1.0F, 0);
     XTensor * tGPU2 = NewTensor(tOrder2, tDimSize2, X_FLOAT, 1.0F, 0);
-    XTensor * tUserGPU1 = NewTensor(tOrder1, tDimSize1, X_FLOAT, 1.0F, 0);
-    XTensor * tUserGPU2 = NewTensor(tOrder2, tDimSize2, X_FLOAT, 1.0F, 0);
     /* Initialize variables */
     sGPU->SetData(sData, sUnitNum);
     tGPU1->SetZeroAll();
     tGPU2->SetZeroAll();
-    tUserGPU1->SetZeroAll();
-    tUserGPU2->SetZeroAll();
     /* add tensors to list */
     tList->Add(tGPU1);
     tList->Add(tGPU2);
-    tUserList.Add(tUserGPU1);
-    tUserList.Add(tUserGPU2);
     /* call Split function */
     _Split(sGPU, tList, 1, 2);
     Split(*sGPU, tUserList, 1, 2);
     /* check results */
-    gpuTest = tGPU1->CheckData(answer1, tUnitNum1) && tUserGPU1->CheckData(answer1, tUnitNum1)
-              && tGPU2->CheckData(answer2, tUnitNum2) && tUserGPU2->CheckData(answer2, tUnitNum2);
+    gpuTest = tGPU1->CheckData(answer1, tUnitNum1) && ((XTensor *)tUserList.Get(0))->CheckData(answer1, tUnitNum1) &&
+              tGPU2->CheckData(answer2, tUnitNum2) && ((XTensor *)tUserList.Get(1))->CheckData(answer2, tUnitNum2);
     /* destroy variables */
     delete s;
     delete t1;
     delete t2;
-    delete tUser1;
-    delete tUser2;
     delete sGPU;
     delete tGPU1;
    delete tGPU2;
-    delete tUserGPU1;
-    delete tUserGPU2;
     delete[] sDimSize;
     delete[] tDimSize1;
     delete[] tDimSize2;
...
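Note: the updated test no longer deletes the user-list tensors, while the new Split() above allocates them with NewTensor(). Assuming no memory pool takes ownership (an assumption, not confirmed by this diff), the caller would free them along these lines:

    // hedged cleanup sketch: release the pieces Split() allocated into the list
    for (int i = 0; i < tUserList.count; i++)
        delete (XTensor*)tUserList.Get(i);
    tUserList.Clear();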