Commit b83a6798 by xuchen

restore the previous implementation

parent 03a9836e
@@ -375,7 +375,7 @@ void XShapeGrad::GradSplitList(XTensor * node, bool isEfficient)
     XTensor * input = income.tails[0];
 
     CheckNTErrors(income.tailNum == 1, "Wrong input tensor number for SPLIT!");
-    CheckNTErrors(node->order == input->order + 1, "Wrong tensor orders!");
+    //CheckNTErrors(node->order == input->order + 1, "Wrong tensor orders!");
 
     node->visitMark = NODE_DOING;
 }
...
@@ -96,7 +96,7 @@ void XNet::Backward(XTensor &root, XTensor &gold, LOSS_FUNCTION_NAME loss)
 backward propagation to obtain gradient wrt. the loss/error function
 >> root - root node (output) of the network
 >> gold - gold standard for the output
->> padding - specify a target value that is ignored and does not contribute to the loss computation
+>> padding - specify a target value that is ignored and does not contribute to the gradient computation
 >> loss - name of loss function
 */
 void XNet::Backward(XTensor &root, XTensor &gold, XTensor &padding, LOSS_FUNCTION_NAME loss)
@@ -135,9 +135,9 @@ void XNet::Backward(XTensor &root, LOSS_FUNCTION_NAME loss)
 /*
 backward propagation to obtain gradient wrt. the loss/error function
 with a number of root nodes
->> root - a list of root nodes (output) of the network
->> gold - a list of gold standard for the output
->> padding - specify a target value that is ignored
+>> roots - a list of root nodes (output) of the network
+>> golds - a list of gold standard for the output
+>> paddings - specify a target value that is ignored
 >> loss - name of loss function
 */
 void XNet::Backward(XList &roots, XList &golds, XList &paddings, LOSS_FUNCTION_NAME loss)
...
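Note: the padding tensor masks out target positions so they contribute no gradient. A minimal, hypothetical call sketch using the signatures above (tensor setup elided; not part of this commit):

    // hypothetical sketch: run backward with a padding mask
    // output, gold and padding are assumed to be prepared elsewhere
    XNet net;
    XTensor output;   // network output (root of the graph)
    XTensor gold;     // gold-standard target
    XTensor padding;  // 1.0F at real positions, 0.0F at padded ones
    net.Backward(output, gold, padding, CROSSENTROPY);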
@@ -193,13 +193,21 @@ void T2TModel::MakeMT(XTensor &inputEnc, XTensor &inputDec, XTensor &output, XTe
     XTensor maskDec;
 
     /* generate mask to see "previous" words on the decoder side */
-    int len = inputDec.GetDim(inputDec.order - 2);
-    int * dims = new int[inputDec.order + 1];
+    //int len = inputDec.GetDim(inputDec.order - 2);
+    //int * dims = new int[inputDec.order + 1];
+    //for(int i = 0; i < inputDec.order; i++)
+    //    dims[i + 1] = inputDec.GetDim(i);
+    //dims[0] = nhead;
+    //dims[inputDec.order] = len;
+    //InitTensor(&maskDec, inputDec.order + 1, dims, X_FLOAT, 1.0F, inputDec.devID, inputDec.mem);
+    int len = inputDec.GetDim(inputDec.order - 1);
+    int * dims = new int[inputDec.order + 2];
     for(int i = 0; i < inputDec.order; i++)
         dims[i + 1] = inputDec.GetDim(i);
     dims[0] = nhead;
-    dims[inputDec.order] = len;
-    InitTensor(&maskDec, inputDec.order + 1, dims, X_FLOAT, 1.0F, inputDec.devID, inputDec.mem);
+    dims[inputDec.order + 1] = len;
+    InitTensor(&maskDec, inputDec.order + 2, dims, X_FLOAT, 1.0F, paddingEnc.devID, paddingEnc.mem);
 
     /* a upper triangular matrix where the cells of the upper triangular are set to -1e-9.
        this matrix can be used to prevent the attention to current or following words in
...
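Note: the restored shape prepends a head dimension and appends a second length dimension, so each head gets an len x len causal plane. A standalone sketch of filling one such plane (simplified; not the NiuTensor call):

    // standalone sketch: an len x len causal mask plane; cells above the
    // diagonal get a large negative value so softmax ignores future words
    #include <vector>

    std::vector<float> MakeCausalMask(int len)
    {
        std::vector<float> mask(len * len, 0.0f);
        for (int i = 0; i < len; i++)
            for (int j = i + 1; j < len; j++)
                mask[i * len + j] = -1e9f;
        return mask;
    }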
@@ -114,7 +114,8 @@ void T2TTrainer::Init(int argc, char ** argv)
     LoadParamBool(argc, argv, "epochcheckpoint", &useEpochCheckpoint, false);
     LoadParamInt(argc, argv, "updatestep", &updateStep, 1);
     LoadParamBool(argc, argv, "doubledend", &isDoubledEnd, false);
-    LoadParamBool(argc, argv, "smallbatch", &isSmallBatch, false);
+    LoadParamBool(argc, argv, "smallbatch", &isSmallBatch, true);
+    LoadParamBool(argc, argv, "bigbatch", &isBigBatch, false);
 
     buf = new int[bufSize];
     buf2 = new int[bufSize];
@@ -124,9 +125,6 @@ void T2TTrainer::Init(int argc, char ** argv)
     adamBeta1T = 1.0F;
     adamBeta2T = 1.0F;
-
-    validStep = 0;
-    curEpoch = 0;
 }
 
 int tc = 0;
@@ -138,10 +136,8 @@ train the model
 >> modelFN - where we keep the model
 >> model - model to train
 */
-bool T2TTrainer::Train(const char * fn, const char * validFN, const char * modelFN, T2TModel * model)
+void T2TTrainer::Train(const char * fn, const char * validFN, const char * modelFN, T2TModel * model)
 {
-    curEpoch += 1;
     int step = 0;
     int wc = 0;
     int wordCount = 0;
@@ -153,7 +149,8 @@ bool T2TTrainer::Train(const char * fn, const char * validFN, const char * model
     int nCheckpoint = 0;
     int nSkipped = 0;
     int gradStep = 0;
-    //int validStep = 0;
+    int validStep = 0;
+    int epoch = 0;
 
     char * trainFN = new char[(int)strlen(fn) + 10];
     strcpy(trainFN, fn);
@@ -171,7 +168,7 @@ bool T2TTrainer::Train(const char * fn, const char * validFN, const char * model
     double startT = GetClockSec();
 
-    //for(epoch = 1; epoch <= nepoch; epoch++){
+    for(epoch = 1; epoch <= nepoch; epoch++){
 #ifndef WIN32
         if(isShuffled)
             Shuffle(fn, trainFN);
@@ -203,7 +200,6 @@ bool T2TTrainer::Train(const char * fn, const char * validFN, const char * model
         {
             CheckNTErrors(batchEnc.order == 2, "wrong tensor order of the sequence batch");
-            //CheckNTErrors(batchEnc.order == 3, "wrong tensor order of the sequence batch");
 
             /* output probabilities */
             XTensor output;
@@ -273,12 +269,25 @@ bool T2TTrainer::Train(const char * fn, const char * validFN, const char * model
             if (step % 100 == 0) {
                 double elapsed = GetClockSec() - startT;
                 XPRINT8(0, stderr, "[INFO] lr=%.2e, elapsed=%.1fs, step=%d, epoch=%d, word=%d, loss=%.3f, ppl=%.3f, sppl=%.3f",
-                        lr, elapsed, step, curEpoch, wordCountTotal, loss/wordCount, exp(loss/wordCount), exp(-prob/wc));
+                        lr, elapsed, step, epoch, wordCountTotal, loss/wordCount, exp(loss/wordCount), exp(-prob/wc));
                 if (!doUpdate)
                     XPRINT(0, stderr, " (no update)");
                 XPRINT(0, stderr, "\n");
             }
 
+            //XMem * mem = model->mem;
+            //MTYPE used = 0;
+            //MTYPE total = 0;
+            //for(int i = 0; i < mem->blockNum; i++){
+            //    if(mem->blocks[i].mem != NULL){
+            //        used += mem->blocks[i].used;
+            //        total += mem->blocks[i].size;
+            //    }
+            //}
+            //fprintf(stderr, "%d %d %d %d mem: %lld %lld\n", paddingEnc.GetDim(0), paddingEnc.GetDim(1),
+            //        paddingDec.GetDim(0), paddingDec.GetDim(1), used, total);
+
             if(nStepCheckpoint > 0 && ++nStepCheck >= nStepCheckpoint){
                 MakeCheckpoint(model, validFN, modelFN, "step", step);
                 nStepCheck = 0;
@@ -289,20 +298,20 @@ bool T2TTrainer::Train(const char * fn, const char * validFN, const char * model
         fclose(file);
 
         if (isEnd)
-            return false;
-    return true;
+            break;
 
-    //if(useEpochCheckpoint)
-    //    MakeCheckpoint(model, validFN, modelFN, "epoch", epoch);
-    //}
+        if(useEpochCheckpoint)
+            MakeCheckpoint(model, validFN, modelFN, "epoch", epoch);
+    }
 
-    //double elapsed = GetClockSec() - startT;
-    //
-    //epoch = MIN(epoch, nepoch);
-    //
-    //XPRINT7(0, stderr, "[INFO] lr=%.2e, elapsed=%.1fs, step=%d, epoch=%d, word=%d, loss=%.3f, ppl=%.3f\n",
-    //        lr, elapsed, step, epoch, wordCountTotal, loss/wordCount, exp(loss/wordCount));
-    //XPRINT4(0, stderr, "[INFO] training finished (took %.1fs, step=%d, skipped=%d and epoch=%d)\n",
-    //        elapsed, step, nSkipped, epoch);
+    double elapsed = GetClockSec() - startT;
+
+    epoch = MIN(epoch, nepoch);
+
+    XPRINT7(0, stderr, "[INFO] lr=%.2e, elapsed=%.1fs, step=%d, epoch=%d, word=%d, loss=%.3f, ppl=%.3f\n",
+            lr, elapsed, step, epoch, wordCountTotal, loss/wordCount, exp(loss/wordCount));
+    XPRINT4(0, stderr, "[INFO] training finished (took %.1fs, step=%d, skipped=%d and epoch=%d)\n",
+            elapsed, step, nSkipped, epoch);
 
     delete[] trainFN;
 }
@@ -356,8 +365,6 @@ void T2TTrainer::Test(const char * fn, const char * ofn, T2TModel * model)
                 seqs, vSize, vSizeTgt,
                 1, 1, false, wc, devID, mem, false))
         {
-            //CheckNTErrors(batchEnc.order == 3, "wrong tensor order of the sequence batch");
             CheckNTErrors(batchEnc.order == 2, "wrong tensor order of the sequence batch");
 
             /* output probabilities */
@@ -454,6 +461,7 @@ char line[MAX_SEQUENCE_LENGTH];
 struct SampleNode
 {
     int id;
+    int offset;
     int * p;
     int size;
     int value;
@@ -545,6 +553,7 @@ int T2TTrainer::LoadBuf(FILE * file, bool isSorted, int step)
     for (int i = 0; i < seqCount; i += step) {
         SampleNode &node = nodes[count];
         node.id = count;
+        node.offset = i;
         node.p = buf + offset;
         node.size = 0;
         for(int j = 0; j < step; j++)
@@ -562,8 +571,8 @@ int T2TTrainer::LoadBuf(FILE * file, bool isSorted, int step)
         SampleNode &node = nodes[count];
         memcpy(buf2 + offset, node.p, sizeof(int) * node.size);
         for(int j = 0; j < step; j++){
-            seqLen2[i + j] = seqLen[node.id + j];
-            seqOffset[i + j] = offset + (j > 0 ? seqLen[node.id + j - 1] : 0);
+            seqLen2[i + j] = seqLen[node.offset + j];
+            seqOffset[i + j] = offset + (j > 0 ? seqLen[node.offset + j - 1] : 0);
         }
         count += 1;
         offset += node.size;
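Note: the fix matters once the nodes are sorted by length: node.id is the node's rank after sorting, while seqLen is still indexed by the original sequence position, which node.offset preserves. A hedged illustration with hypothetical values:

    // hypothetical: step = 2, and the node built from original index 2
    // ends up with rank 0 after sorting
    // node.id     = 0   (rank after sorting)
    // node.offset = 2   (original position in seqLen)
    // seqLen2[i + j] = seqLen[node.offset + j];  // reads lengths 2, 3: correct
    // seqLen2[i + j] = seqLen[node.id + j];      // read lengths 0, 1: the old bug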
@@ -679,7 +688,7 @@ int T2TTrainer::LoadBatchLM(FILE * file,
         if(max < wn)
             max = wn;
 
-        int tc = isSmallBatch ? max * sc : wc;
+        int tc = isBigBatch ? wc : max * sc;
         if(sc >= sBatch && tc >= wBatch)
             break;
     }
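Note: with isSmallBatch now defaulting to true, the stop test counts padded tokens (max * sc), which bounds the real tensor size; -bigbatch switches to the raw word count wc and packs more short sentences per batch at the cost of a larger padded tensor. A hedged sketch with hypothetical numbers:

    // hypothetical: 16 sentences collected, longest is 40 words,
    // 400 real (unpadded) words in total
    int sc = 16, max = 40, wc = 400;
    bool isBigBatch = false;
    int tc = isBigBatch ? wc : max * sc;  // 640 padded tokens; 400 with -bigbatch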
@@ -823,8 +832,9 @@ int T2TTrainer::LoadBatchMT(FILE * file,
         if(maxDec < wnDec)
             maxDec = wnDec;
 
-        int tc = isSmallBatch ? maxEnc * sc / 2 : wcEnc;
-        if(sc >= sBatch * 2 && tc >= wBatch)
+        int tcEnc = isBigBatch ? wcEnc : maxEnc * sc / 2;
+        int tcDec = isBigBatch ? wcDec : maxDec * sc / 2;
+        if(sc >= sBatch * 2 && (tcEnc >= wBatch || tcDec >= wBatch))
             break;
     }
@@ -838,9 +848,9 @@ int T2TTrainer::LoadBatchMT(FILE * file,
     int dimsEnc[3] = {sCount, maxEnc, vsEnc};
     int dimsDec[3] = {sCount, maxDec, vsDec};
 
-    InitTensor(batchEnc, 3, dimsEnc, X_FLOAT, 1.0F, devID, mem);
+    InitTensor(batchEnc, 2, dimsEnc, X_INT, 1.0F, -1);
     InitTensor2D(paddingEnc, sCount, maxEnc, X_FLOAT, devID, mem);
-    InitTensor(batchDec, 3, dimsDec, X_FLOAT, 1.0F, devID, mem);
+    InitTensor(batchDec, 2, dimsDec, X_INT, 1.0F, -1);
     InitTensor2D(paddingDec, sCount, maxDec, X_FLOAT, devID, mem);
     InitTensor(gold, 3, dimsDec, X_FLOAT, 1.0F, devID, mem);
@@ -857,7 +867,8 @@ int T2TTrainer::LoadBatchMT(FILE * file,
         int len = seqLen[s];
         int sent = (s - seq)/2;
         for(int w = 0; w < len; w++){
-            batchEnc->Set3D(1.0F, sent, w, buf[seqOffset[s] + w]);
+            batchEnc->Set2DInt(buf[seqOffset[s] + w], sent, w);
+            //batchEnc->Set3D(1.0F, sent, w, buf[seqOffset[s] + w]);
             paddingEnc->Set2D(1.0F, sent, w);
             wCount++;
         }
@@ -869,8 +880,9 @@ int T2TTrainer::LoadBatchMT(FILE * file,
         CheckNTErrors(len <= maxDec, "Something is wrong!");
         int sent = (s - seq - 1)/2;
         for(int w = 0; w < len; w++){
+            batchDec->Set2DInt(buf[seqOffset[s] + w], sent, w);
+            //batchDec->Set3D(1.0F, sent, w, buf[seqOffset[s] + w]);
             paddingDec->Set2D(1.0F, sent, w);
-            batchDec->Set3D(1.0F, sent, w, buf[seqOffset[s] + w]);
             if(w > 0)
                 gold->Set3D(1.0F, sent, w - 1, buf[seqOffset[s] + w]);
             if (w == len - 1) {
...
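Note: batchEnc/batchDec drop the one-hot layout (sCount x maxLen x vocab floats) for plain word ids (sCount x maxLen ints), cutting batch memory by roughly a factor of the vocabulary size. A hedged sketch of the same token in both layouts:

    // the same token (word id 17, sentence 0, position 3) in both layouts
    // old one-hot form: a 1.0F in a sCount x maxLen x vocabSize float tensor
    //   batchEnc->Set3D(1.0F, 0, 3, 17);
    // new index form: the id itself in a sCount x maxLen int tensor
    //   batchEnc->Set2DInt(17, 0, 3);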
@@ -103,10 +103,6 @@ public:
     /* indicates whether we use adam */
     bool useAdam;
 
-    int validStep;
-    int curEpoch;
-
     /* hyper parameters of adam*/
     float adamBeta1;
     float adamBeta2;
@@ -143,6 +139,9 @@ public:
        length and sc is the sentence number */
     bool isSmallBatch;
 
+    /* counterpart of "isSmallBatch" */
+    bool isBigBatch;
+
 public:
     /* constructor */
     T2TTrainer();
@@ -154,7 +153,7 @@ public:
     void Init(int argc, char ** argv);
 
     /* train the model */
-    bool Train(const char * fn, const char * validFN, const char * modelFN, T2TModel * model);
+    void Train(const char * fn, const char * validFN, const char * modelFN, T2TModel * model);
 
     /* test the model */
     void Test(const char * fn, const char * ofn, T2TModel * model);
...
@@ -58,74 +58,19 @@ int TransformerMain(int argc, const char ** argv)
     LoadParamString(argc, args, "test", testFN, "");
     LoadParamString(argc, args, "output", outputFN, "");
 
-    /* learn model parameters */
-    if(strcmp(trainFN, "")) {
-        double startT = GetClockSec();
     T2TTrainer trainer;
     trainer.Init(argc, args);
-        char * fn = new char[MAX_LINE_LENGTH];
-        char * fn1 = new char[MAX_LINE_LENGTH];
-        char * fn2 = new char[MAX_LINE_LENGTH];
-        modelFN = strcmp(modelFN, "") ? modelFN : (char *)"checkpoint.model";
-        int epoch;
-        bool isTrain;
-        for(epoch = 1; epoch <= trainer.nepoch; epoch++) {
-            sprintf(fn, "%s.%s.%03d", modelFN, "epoch", epoch - 1);
-            sprintf(fn1, "%s.%s.%03d", modelFN, "epoch", epoch);
-            sprintf(fn2, "%s.%s.%03d.output", modelFN, "epoch", epoch);
-            if(epoch == 1) {
-                T2TModel model;
-                model.InitModel(argc, args);
-                isTrain = trainer.Train(trainFN, testFN, modelFN, &model);
-                model.Dump(fn1);
-            }
-            else {
-                T2TModel model;
-                model.InitModel(argc, args);
-                model.Read(fn);
-                isTrain = trainer.Train(trainFN, testFN, modelFN, &model);
-                model.Dump(fn1);
-            }
-            if(trainer.useEpochCheckpoint && strcmp(testFN, "")) {
-                T2TTrainer tester;
-                tester.Init(argc, args);
     T2TModel model;
     model.InitModel(argc, args);
-                model.Read(fn1);
-                tester.Test(testFN, fn2, &model);
-            }
-            if(!isTrain)
-                break;
-        }
-        double elapsed = GetClockSec() - startT;
-        epoch = MIN(epoch, trainer.nepoch);
-        XPRINT2(0, stderr, "[INFO] training finished (took %.1fs and epoch=%d)\n", elapsed, epoch);
-        delete[] fn;
-        delete[] fn1;
-        delete[] fn2;
-    }
-    /* don't dump the final model */
+    /* learn model parameters */
+    if(strcmp(trainFN, ""))
+        trainer.Train(trainFN, testFN, modelFN, &model);
 
     /* save the final model */
-    //if(strcmp(modelFN, "") && strcmp(trainFN, ""))
-    //    model.Dump(modelFN);
-    T2TModel model;
-    model.InitModel(argc, args);
+    if(strcmp(modelFN, "") && strcmp(trainFN, ""))
+        model.Dump(modelFN);
 
     /* load the model if neccessary */
     if(strcmp(modelFN, ""))
...
@@ -39,9 +39,20 @@ using namespace nts;
 void SmallTest();
 void TransposeTest();
+void LittleTest();
+void T2TTest();
+void T2TTest2();
+void PowerTest();
 
 int main( int argc, const char ** argv )
 {
+    //PowerTest();
+    //LittleTest();
+    //T2TTest();
+    //T2TTest2();
+    //return 0;
+
     //_CrtSetBreakAlloc(123);
 
     /* a tiny test */
@@ -63,6 +74,34 @@ int main( int argc, const char ** argv )
     return 0;
 }
 
+void myRead(XTensor * tensor, const char * filename, const char * label)
+{
+    FILE * file = fopen(filename, "rb");
+    if(file == NULL)
+        printf("%s\n", filename);
+    tensor->Read(file, label);
+}
+
+void myDump(XTensor * tensor, const char * filename, const char * label)
+{
+    FILE * file = fopen(filename, "wb");
+    if(file == NULL)
+        printf("%s\n", filename);
+    tensor->Dump(file, label);
+}
+
+void PowerTest()
+{
+    XTensor input;
+    XTensor output;
+    InitTensor2D(&input, 256, 10000, X_FLOAT, 0);
+    InitTensor2D(&output, 256, 10000, X_FLOAT, 0);
+    myRead(&input, "1.txt", "");
+
+    _Power(&input, &output, 2);
+    output.Dump(stderr, "", 200);
+}
+
 void SmallTest()
 {
     XTensor a;
@@ -126,3 +165,128 @@ void TransposeTest()
     delete[] data;
 }
 
+void LittleTest()
+{
+    int a = 5000;
+    int b = 100000;
+    int c = a*b;
+    printf("%d\n", c);
+
+    exit(1);
+}
+
+void T2TTest()
+{
+    XTensor * input;
+    XTensor * weight;
+    XTensor * output;
+    XTensor * gold;
+    XTensor * dedy;
+    XTensor * dedx;
+    XTensor * dedxTmp;
+    XTensor * dedw;
+    XTensor * padding;
+
+    DTYPE loss;
+
+    int * dimSize = new int[2];
+    dimSize[0] = 256;
+    dimSize[1] = 10001;
+
+    int * dimSize2 = new int[3];
+    dimSize2[0] = 2;
+    dimSize2[1] = 31;
+    dimSize2[2] = 256;
+
+    int * dimSize3 = new int[3];
+    dimSize3[0] = 2;
+    dimSize3[1] = 31;
+    dimSize3[2] = 10001;
+
+    int * dimSize4 = new int[2];
+    dimSize4[0] = 2;
+    dimSize4[1] = 31;
+
+    input = NewTensor(3, dimSize2, X_FLOAT, 1.0F, 0);
+    weight = NewTensor(2, dimSize, X_FLOAT, 1.0F, 0);
+    dedw = NewTensor(2, dimSize, X_FLOAT, 1.0F, 0);
+    gold = NewTensor(3, dimSize3, X_FLOAT, 1.0F, 0);
+    output = NewTensor(3, dimSize3, X_FLOAT, 1.0F, 0);
+    dedy = NewTensor(3, dimSize3, X_FLOAT, 1.0F, 0);
+    dedx = NewTensor(3, dimSize3, X_FLOAT, 1.0F, 0);
+    dedxTmp = NewTensor(3, dimSize3, X_FLOAT, 1.0F, 0);
+    padding = NewTensor(2, dimSize4, X_FLOAT, 1.0F, 0);
+
+    //weight = NewTensor(2, dimSize);
+    //dedw = NewTensor(2, dimSize);
+    //input = NewTensor(3, dimSize2);
+    //gold = NewTensor(3, dimSize3);
+    //output = NewTensor(3, dimSize3);
+    //dedy = NewTensor(3, dimSize3);
+    //dedx = NewTensor(3, dimSize3);
+    //dedxTmp = NewTensor(3, dimSize3);
+    //padding = NewTensor(2, dimSize4);
+
+    myRead(input, "x.txt", "x");
+    myRead(weight, "w.txt", "w");
+    myRead(gold, "gold.txt", "gold");
+    myRead(padding, "padding.txt", "padding");
+
+    XTensor inter;
+    inter = MMul(*input, *weight);
+
+    _Softmax(&inter, output, 2);
+
+    //_LogMe(output);
+    loss = _CrossEntropyFast(output, gold, REDUCE_MEAN, NULL, padding);
+
+    printf("loss: %f\n", loss);
+
+    _CrossEntropyBackward(dedy, output, gold, NULL);
+    //_CrossEntropyBackward(dedy, output, gold, NULL, padding);
+
+    myDump(dedy, "dedy.txt", "dedy");
+
+    _SoftmaxBackward(NULL, output, input, dedy, dedx, NULL, -1, NOLOSS);
+    _Sub(output, gold, dedxTmp);
+
+    myDump(dedx, "dedx.txt", "dedx");
+    dedx->Dump(stderr, "dedx", 200);
+    dedxTmp->Dump(stderr, "dedxTmp", 200);
+
+    input->Reshape(input->unitNum/input->GetDim(-1), input->GetDim(-1));
+    dedx->Reshape(dedx->unitNum/dedx->GetDim(-1), dedx->GetDim(-1));
+
+    _MatrixMulBatched(input, X_TRANS, dedx, X_NOTRANS, dedw);
+
+    myDump(dedw, "dedw.txt", "dedw");
+}
+
+void T2TTest2()
+{
+    int dimSize[3];
+    dimSize[0] = 161;
+    dimSize[1] = 47;
+    dimSize[2] = 10001;
+    XTensor * probs = NewTensor(3, dimSize, X_FLOAT, 1.0F, 0);
+    //XTensor * probs = NewTensor(3, dimSize, X_FLOAT, 1.0F, -1);
+
+    //myRead(probs, "probs.txt", " ");
+    _SetDataFixedFloat(probs, 1.0F);
+
+    probs->Reshape(1, probs->unitNum);
+
+    DTYPE sum = _ReduceSumAll(probs);
+    printf("%e\n", sum);
+
+    //XTensor tmp;
+    //tmp = IsNonZero(*probs);
+    //DTYPE nonZeroNum = ReduceSumAll(tmp);
+    //printf("%f\n", nonZeroNum);
+    //
+    //DTYPE gpu = ReduceSum(*probs, 1).Get2D(0, 0);
+    //printf("%e\n", gpu);
+}
...
@@ -47,6 +47,7 @@ XMem::XMem()
     name = new char[64];
     strcpy(name, "xmem");
     signature = 0;
+    mergeFreeOTF = true;
 }
 
 /*
@@ -69,6 +70,7 @@ XMem::XMem(int myDevID, MEMPOOL_MODE myMode, MTYPE myBlockSize, int myBlockNum,
     name = new char[64];
     strcpy(name, "xmem");
     signature = 0;
+    mergeFreeOTF = true;
     Initialize(myDevID, myMode, myBlockSize, myBlockNum, myBufSize);
 }
@@ -153,8 +155,13 @@ void XMem::Initialize(int myDevID, MEMPOOL_MODE myMode, MTYPE myBlockSize, int m
     bufSize = myBufSize;
 
+#ifdef SMALL_DEVICE
+    if (myMode == FREE_ON_THE_FLY)
+        SetIndex(50000);
+#else
     if (myMode == FREE_ON_THE_FLY)
         SetIndex(MILLION);
+#endif
 
     signature++;
 }
@@ -613,7 +620,7 @@ void * XMem::AllocStandard(int myDevID, MTYPE mySize, bool myIsRebuiltIndex)
     while(node != NULL){
         if(node->size == 0){
             MPieceNode * next = node->next;
-            RemoveFreeIndexNode(node, entry);
+            RemoveIndexNode(node, entry);
             node = next;
         }
         else{
@@ -655,11 +662,11 @@ void * XMem::AllocStandard(int myDevID, MTYPE mySize, bool myIsRebuiltIndex)
     next.pre = &cur;
     next.next = cur.next;
     cur.next = &next;
-    if(cur.next != NULL)
-        cur.next->pre = &next;
     cur.size = needed;
+    if(next.next != NULL)
+        next.next->pre = &next;
 
     next.state = 1;
     next.size = remaining;
     next.blockID = cur.blockID;
@@ -670,7 +677,7 @@ void * XMem::AllocStandard(int myDevID, MTYPE mySize, bool myIsRebuiltIndex)
     hit->pReal = beg;
     blocks[hit->head.blockID].used += head->size;
 
-    RemoveFreeIndexNode(hit);
+    RemoveIndexNode(hit);
     AddAllocIndexNode(hit);
 
     result = beg;
@@ -684,6 +691,65 @@ void * XMem::AllocStandard(int myDevID, MTYPE mySize, bool myIsRebuiltIndex)
     }
     /* if there is still no available memory piece, we have to obtain a new block of memory. */
     else{
+        /*MTYPE used = 0;
+        MTYPE total = 0;
+        MTYPE free = 0;
+        for(int i = 0; i < blockNum; i++){
+            if(blocks[i].mem != NULL){
+                used += blocks[i].used;
+                total += blocks[i].size;
+            }
+        }
+
+        MPieceNode * bufNodes = new MPieceNode[MILLION];
+        int bufNodeCount = 0;
+        for(int i = 0; i <= indexEntryNum; i++){
+            entry = memIndex + i;
+            node = entry->next;
+            while(node != NULL){
+                bufNodes[bufNodeCount++] = *node;
+                if(node->size == 0){
+                    MPieceNode * next = node->next;
+                    node = next;
+                }
+                else{
+                    if(node->head.state == 1 && node->size >= mySize){
+                        fprintf(stderr, "hit!!!!!!!!!!!\n");
+                    }
+                    //fprintf(stderr, "%d %lld %lld %lld\n", node->head.blockID, free, node->size, mySize);
+                    free += node->size;
+                    node = node->next;
+                }
+            }
+        }
+
+        MTYPE headSize = 0;
+        MTYPE headSizeUsed = 0;
+        for(int i = 0, j = 0; i < blockNum; i++){
+            XMemBlock * block = blocks + i;
+            if(block->mem != NULL){
+                MHeader * head = block->head;
+                while(head != NULL){
+                    if(head->state == 1){
+                        headSize += head->size;
+                        //fprintf(stderr, "%d head %lld\n", j++, head->size);
+                    }
+                    else{
+                        headSizeUsed += head->size;
+                    }
+                    head = head->next;
+                }
+            }
+        }
+
+        delete[] bufNodes;
+
+        fprintf(stderr, "%lld %lld\n", headSize, headSizeUsed);
+        fprintf(stderr, "mem: %lld %lld %lld %lld\n", used, total, free, mySize);*/
+
         int bi;
         for(bi = 0; bi < blockNum; bi++){
             XMemBlock * block = blocks + bi;
@@ -856,15 +922,16 @@ int XMem::FindIndexEntry(MTYPE mySize)
 }
 
 /*
-remove an index node for available memory pieces
+remove an index node
 >> node - node to remove
 >> entry - the entry of the list that keeps the node
 */
-void XMem::RemoveFreeIndexNode(MPieceNode * node, MPieceNode * entry)
+void XMem::RemoveIndexNode(MPieceNode * node, MPieceNode * entry)
 {
     MPieceNode * pre = node->pre;
     MPieceNode * next = node->next;
 
     CheckNTErrors(pre != NULL, "cannot free the entry node!");
 
     pre->next = next;
@@ -885,6 +952,19 @@ void XMem::AddFreeIndexNode(MPieceNode * node, MPieceNode * entry)
     MPieceNode * entryForMe = entry != NULL ? entry :
                               memIndex + FindIndexEntry(node->size);
 
+    /*MPieceNode * backup = entryForMe->next;
+    while(backup != NULL && backup->head.size < node->head.size){
+        backup = backup->next;
+        entryForMe = entryForMe->next;
+    }
+
+    entryForMe->next = node;
+    node->pre = entryForMe;
+    node->next = backup;
+
+    if(backup != NULL)
+        backup->pre = node;*/
+
     MPieceNode * backup = entryForMe->next;
 
     entryForMe->next = node;
     node->pre = entryForMe;
@@ -903,7 +983,7 @@ remove an index node for memory pieces in use
 */
 void XMem::RemoveAllocIndexNode(MPieceNode * node, MPieceNode * entry)
 {
-    RemoveFreeIndexNode(node, entry);
+    RemoveIndexNode(node, entry);
 }
 
 /*
@@ -959,7 +1039,7 @@ void XMem::ReleaseStandard(int myDevID, void * p, MTYPE size)
     if(node->size == 0){
         MPieceNode * next = node->next;
-        RemoveFreeIndexNode(node, entry);
+        RemoveIndexNode(node, entry);
         node = next;
         ShowNTErrors("Something is wrong!");
     }
@@ -979,10 +1059,53 @@ void XMem::ReleaseStandard(int myDevID, void * p, MTYPE size)
     RemoveAllocIndexNode(hit);
 
+    MTYPE usedSize = (char*)hit->p + hit->head.size - (char*)GetPitchedAddress((char*)hit->p, MY_PITCH);
+    blocks[hit->head.blockID].used -= usedSize;
+
+    if(mergeFreeOTF){
+        MHeader * head = &hit->head;
+        MHeader * pre = head->pre;
+        MHeader * next = head->next;
+        bool mergeLeft = false;
+        bool mergeRight = false;
+
+        CheckNTErrors(head != pre, "wrong list of memory headers");
+        CheckNTErrors(head != next, "wrong list of memory headers");
+
+        if(pre != NULL && pre->state == 1 && pre->blockID == head->blockID){
+            mergeLeft = true;
+            head->pre = pre->pre;
+            if(head->pre != NULL)
+                head->pre->next = head;
+            hit->p = pre->indexNode->p;
+            hit->head.size += pre->size;
+            RemoveAllocIndexNode(pre->indexNode);
+            if(pre == blocks[head->blockID].head)
+                blocks[head->blockID].head = head;
+        }
+
+        if(next != NULL && next->state == 1 && next->blockID == head->blockID){
+            mergeRight = true;
+            head->next = next->next;
+            if(head->next != NULL)
+                head->next->pre = head;
+            hit->head.size += next->size;
+            RemoveAllocIndexNode(next->indexNode);
+        }
+
+        if(!mergeLeft && !mergeRight){
+            hit->size = usedSize;
+        }
+        else{
             hit->size = (char*)hit->p + hit->head.size - (char*)GetPitchedAddress((char*)hit->p, MY_PITCH);
-    AddFreeIndexNode(hit);
+        }
+    }
+    else{
+        hit->size = usedSize;
+    }
 
-    blocks[hit->head.blockID].used -= hit->head.size;
+    AddFreeIndexNode(hit);
 }
 
 /* rebuild index to merge small fragments of memory and free the block with no use */
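Note: the mergeFreeOTF block above is boundary coalescing for a free-list allocator: a freed piece absorbs free neighbours in the same block so the pool fragments less between RebuildIndex calls. A simplified standalone sketch of the idea (plain structs, not the actual XMem MHeader/MPieceNode types):

    #include <cstddef>

    // simplified free-list header; XMem's headers carry more state
    struct Header {
        Header * pre;
        Header * next;
        size_t size;
        bool free;
    };

    // merge a freed header with free neighbours on either side
    void CoalesceOnFree(Header * h)
    {
        h->free = true;
        if (h->pre != NULL && h->pre->free) {   // absorb into left neighbour
            h->pre->size += h->size;
            h->pre->next = h->next;
            if (h->next != NULL)
                h->next->pre = h->pre;
            h = h->pre;
        }
        if (h->next != NULL && h->next->free) { // absorb right neighbour
            Header * n = h->next;
            h->size += n->size;
            h->next = n->next;
            if (n->next != NULL)
                n->next->pre = h;
        }
    }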
@@ -1046,7 +1169,7 @@ void XMem::RebuildIndex()
     if(head->state == 1){
         newNode->size = (char*)p + head->size -
-                        ( head->state == 1 ? (char*)GetPitchedAddress((char*)p, MY_PITCH) : (char*)head->indexNode->pReal);
+                        (head->state == 1 ? (char*)GetPitchedAddress((char*)p, MY_PITCH) : (char*)head->indexNode->pReal);
     }
     else
         newNode->size = node->size;
@@ -1067,7 +1190,6 @@ void XMem::RebuildIndex()
     if(newLast != NULL)
         newLast->next = newHeader;
-    newHeader->pre = newLast;
 
     if(head->state == 1){
         newNode->pReal = NULL;
...
@@ -231,6 +231,9 @@ public:
     /* index offset */
     int indexOffset;
 
+    /* indicates whether we merge free memory pieces on the fly */
+    bool mergeFreeOTF;
+
 public:
     /* constructor */
@@ -326,7 +329,7 @@ public:
     int FindIndexEntry(MTYPE mySize);
 
     /* remove an index node for available memory pieces */
-    void RemoveFreeIndexNode(MPieceNode * node, MPieceNode * entry = NULL);
+    void RemoveIndexNode(MPieceNode * node, MPieceNode * entry = NULL);
 
     /* add an index node for available memory pieces */
     void AddFreeIndexNode(MPieceNode * node, MPieceNode * entry = NULL);
...
@@ -260,6 +260,7 @@ void XTensor::DestroyData()
         FreeData(this, mem);
     else if(data != NULL)
         mem->Release(data, GetDataSizeInChar(), signature);
+
     data = NULL;
 
     if(dataHost != NULL)
@@ -1117,7 +1118,7 @@ bool XTensor::Set3D(DTYPE value, int d0, int d1, int d2)
     CheckNTErrors(order == 3, "Cannot get a 2d cell for a tensor whose order is not 2!");
     CheckNTErrors(d0 >= 0 && d0 < dimSize[0], "dimension 0 is out of range!");
     CheckNTErrors(d1 >= 0 && d1 < dimSize[1], "dimension 1 is out of range!");
-    CheckNTErrors(d2 >= 0 && d2 < dimSize[2], "dimension 1 is out of range!");
+    CheckNTErrors(d2 >= 0 && d2 < dimSize[2], "dimension 2 is out of range!");
     CheckNTErrors(dataType == DEFAULT_DTYPE, "The tensor is not in default type.");
 
     int dims[3] = {d0, d1, d2};
@@ -1364,8 +1365,9 @@ bool XTensor::Resize(const int myOrder, const int * myDimSize,
         d = new int[size];
         memset(d, 0, size);
     }
-    else
+    else{
         d = (int*)mem->Alloc(mem->devID, size);
+    }
 
     if(d == NULL)
         return false;
...
@@ -217,7 +217,6 @@ void _Merge(const XList * smalls, XTensor * big, int whereToMerge)
     XTensor* smallsItem0 = (XTensor*)(smalls->GetItem(0));
     int itemSize = smallsItem0->unitNum * smallsItem0->unitSize;
-
     for (int i = 0; i < smalls->count; i++) {
         XTensor* smallsItem = (XTensor*)smalls->GetItem(i);
         CheckNTErrors((big->unitNum == smallsItem->unitNum * mergeNum), "Unmatched tensors!");
...
@@ -342,6 +342,24 @@ split a big tensor into small tensors
 */
 void Split(const XTensor &big, XList &smalls, int whereToSplit, int splitNum)
 {
+    CheckNTErrors(big.GetDim(whereToSplit) % splitNum == 0, "Wrong splitNum!");
+
+    int order = big.order;
+    int * dimSize = new int[order];
+    for (int i = 0; i < big.order; i++) {
+        if (i != whereToSplit)
+            dimSize[i] = big.dimSize[i];
+        else
+            dimSize[i] = big.dimSize[whereToSplit] / splitNum;
+    }
+
+    float dr = (!big.isSparse) ? 1.0F : big.denseRatio;
+    for (int i = 0; i < splitNum; i++) {
+        XTensor * item = NewTensor(order, dimSize, big.dataType, dr, big.devID, big.mem);
+        smalls.Add(item);
+    }
+
+    delete[] dimSize;
+
     /* call _Split function */
     _Split(&big, &smalls, whereToSplit, splitNum);
...
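Note: the list-based Split now allocates the sub-tensors itself, so callers pass an empty XList and own the results. A hedged usage sketch (assumes the API shown above; default memory-pool argument and error handling elided):

    // split a (4, 6) tensor into two (4, 3) tensors along dimension 1;
    // Split allocates the items in smalls, so the caller must delete them
    XTensor big;
    InitTensor2D(&big, 4, 6, X_FLOAT, -1);
    big.SetZeroAll();

    XList smalls;
    Split(big, smalls, 1, 2);

    for (int i = 0; i < smalls.count; i++)
        delete (XTensor*)smalls.Get(i);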
@@ -207,7 +207,7 @@ void _SoftmaxBackward(XTensor * gold, XTensor * y, XTensor * x,
     DTYPE * gp = gold != NULL ? (DTYPE*)gold->data : NULL;
     DTYPE * op = (DTYPE*)y->data;
     DTYPE * sp = (DTYPE*)dedx->data;
-    DTYPE * yp = (DTYPE*)dedy->data;
+    DTYPE * yp = NULL;
 
     if(lossName == CROSSENTROPY){
         if(gold->isSparse){
...
@@ -272,29 +272,23 @@ bool TestSplit3()
     XTensor * s = NewTensor(sOrder, sDimSize);
     XTensor * t1 = NewTensor(tOrder1, tDimSize1);
     XTensor * t2 = NewTensor(tOrder2, tDimSize2);
-    XTensor * tUser1 = NewTensor(tOrder1, tDimSize1);
-    XTensor * tUser2 = NewTensor(tOrder2, tDimSize2);
 
     /* initialize variables */
     s->SetData(sData, sUnitNum);
     t1->SetZeroAll();
     t2->SetZeroAll();
-    tUser1->SetZeroAll();
-    tUser2->SetZeroAll();
 
     /* add tensors to list */
     tList->Add(t1);
     tList->Add(t2);
-    tUserList.Add(tUser1);
-    tUserList.Add(tUser2);
 
     /* call split function */
     _Split(s, tList, 1, 2);
     Split(*s, tUserList, 1, 2);
 
     /* check results */
-    cpuTest = t1->CheckData(answer1, tUnitNum1) && tUser1->CheckData(answer1, tUnitNum1)
-              && t2->CheckData(answer2, tUnitNum2) && tUser2->CheckData(answer2, tUnitNum2);
+    cpuTest = t1->CheckData(answer1, tUnitNum1) && ((XTensor *)tUserList.Get(0))->CheckData(answer1, tUnitNum1) &&
+              t2->CheckData(answer2, tUnitNum2) && ((XTensor *)tUserList.Get(1))->CheckData(answer2, tUnitNum2);
 
 #ifdef USE_CUDA
     /* GPU test */
@@ -308,42 +302,31 @@ bool TestSplit3()
     XTensor * sGPU = NewTensor(sOrder, sDimSize, X_FLOAT, 1.0F, 0);
     XTensor * tGPU1 = NewTensor(tOrder1, tDimSize1, X_FLOAT, 1.0F, 0);
     XTensor * tGPU2 = NewTensor(tOrder2, tDimSize2, X_FLOAT, 1.0F, 0);
-    XTensor * tUserGPU1 = NewTensor(tOrder1, tDimSize1, X_FLOAT, 1.0F, 0);
-    XTensor * tUserGPU2 = NewTensor(tOrder2, tDimSize2, X_FLOAT, 1.0F, 0);
 
     /* Initialize variables */
     sGPU->SetData(sData, sUnitNum);
     tGPU1->SetZeroAll();
     tGPU2->SetZeroAll();
-    tUserGPU1->SetZeroAll();
-    tUserGPU2->SetZeroAll();
 
     /* add tensors to list */
     tList->Add(tGPU1);
     tList->Add(tGPU2);
-    tUserList.Add(tUserGPU1);
-    tUserList.Add(tUserGPU2);
 
     /* call Split function */
     _Split(sGPU, tList, 1, 2);
     Split(*sGPU, tUserList, 1, 2);
 
     /* check results */
-    gpuTest = tGPU1->CheckData(answer1, tUnitNum1) && tUserGPU1->CheckData(answer1, tUnitNum1)
-              && tGPU2->CheckData(answer2, tUnitNum2) && tUserGPU2->CheckData(answer2, tUnitNum2);
+    gpuTest = tGPU1->CheckData(answer1, tUnitNum1) && ((XTensor *)tUserList.Get(0))->CheckData(answer1, tUnitNum1) &&
+              tGPU2->CheckData(answer2, tUnitNum2) && ((XTensor *)tUserList.Get(1))->CheckData(answer2, tUnitNum2);
 
     /* destroy variables */
     delete s;
     delete t1;
     delete t2;
-    delete tUser1;
-    delete tUser2;
     delete sGPU;
     delete tGPU1;
     delete tGPU2;
-    delete tUserGPU1;
-    delete tUserGPU2;
     delete[] sDimSize;
     delete[] tDimSize1;
     delete[] tDimSize2;
...
@@ -68,6 +68,7 @@ bool TestXMemCase1()
         int j = rand() % caseNum;
 
         //fprintf(stderr, "%d %d %d\n", testxmemid, j, ok);
+        //fprintf(stderr, "iter %d %d %d\n", iter, i, j);
 
         if (p[j] == NULL) {
             p[j] = (int*)mem.AllocStandard(mem.devID, size[j] * sizeof(int));
...