Commit b83a6798 by xuchen

restore the previous implementation

parent 03a9836e
......@@ -375,7 +375,7 @@ void XShapeGrad::GradSplitList(XTensor * node, bool isEfficient)
XTensor * input = income.tails[0];
CheckNTErrors(income.tailNum == 1, "Wrong input tensor number for SPLIT!");
CheckNTErrors(node->order == input->order + 1, "Wrong tensor orders!");
//CheckNTErrors(node->order == input->order + 1, "Wrong tensor orders!");
node->visitMark = NODE_DOING;
}
......
......@@ -96,7 +96,7 @@ void XNet::Backward(XTensor &root, XTensor &gold, LOSS_FUNCTION_NAME loss)
backward propagation to obtain gradient wrt. the loss/error function
>> root - root node (output) of the network
>> gold - gold standard for the output
>> padding - specify a target value that is ignored and does not contribute to the loss computation
>> padding - specify a target value that is ignored and does not contribute to the gradient computation
>> loss - name of loss function
*/
void XNet::Backward(XTensor &root, XTensor &gold, XTensor &padding, LOSS_FUNCTION_NAME loss)
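A minimal sketch of how this overload is typically invoked (the tensor shapes are assumptions; CROSSENTROPY is the loss name used elsewhere in this commit):

XTensor output;   /* network output, e.g. (batch, length, vocab) */
XTensor gold;     /* gold answers, same shape as output */
XTensor padding;  /* (batch, length): 1 for real tokens, 0 for padded slots */
/* ... forward pass fills output ... */
XNet net;
net.Backward(output, gold, padding, CROSSENTROPY); /* padded positions add nothing to the gradient */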
......@@ -135,9 +135,9 @@ void XNet::Backward(XTensor &root, LOSS_FUNCTION_NAME loss)
/*
backward propagation to obtain gradient wrt. the loss/error function
with a number of root nodes
>> root - a list of root nodes (output) of the network
>> gold - a list of gold standard for the output
>> padding - specify a target value that is ignored
>> roots - a list of root nodes (output) of the network
>> golds - a list of gold standard for the output
>> paddings - specify a target value that is ignored
>> loss - name of loss function
*/
void XNet::Backward(XList &roots, XList &golds, XList &paddings, LOSS_FUNCTION_NAME loss)
......
......@@ -193,13 +193,21 @@ void T2TModel::MakeMT(XTensor &inputEnc, XTensor &inputDec, XTensor &output, XTe
XTensor maskDec;
/* generate mask to see "previous" words on the decoder side */
int len = inputDec.GetDim(inputDec.order - 2);
int * dims = new int[inputDec.order + 1];
//int len = inputDec.GetDim(inputDec.order - 2);
//int * dims = new int[inputDec.order + 1];
//for(int i = 0; i < inputDec.order; i++)
// dims[i + 1] = inputDec.GetDim(i);
//dims[0] = nhead;
//dims[inputDec.order] = len;
//InitTensor(&maskDec, inputDec.order + 1, dims, X_FLOAT, 1.0F, inputDec.devID, inputDec.mem);
int len = inputDec.GetDim(inputDec.order - 1);
int * dims = new int[inputDec.order + 2];
for(int i = 0; i < inputDec.order; i++)
dims[i + 1] = inputDec.GetDim(i);
dims[0] = nhead;
dims[inputDec.order] = len;
InitTensor(&maskDec, inputDec.order + 1, dims, X_FLOAT, 1.0F, inputDec.devID, inputDec.mem);
dims[inputDec.order + 1] = len;
InitTensor(&maskDec, inputDec.order + 2, dims, X_FLOAT, 1.0F, paddingEnc.devID, paddingEnc.mem);
/* an upper triangular matrix where the cells of the upper triangle are set to -1e-9.
this matrix can be used to prevent the attention to current or following words in
......
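As an aside on the mask semantics: the sketch below, independent of the NiuTrans API, shows the causal mask the comment describes. The mask only suppresses attention if the constant is a large negative number, so -1e-9 in the comment is presumably shorthand for something like -1e9F.

#include <vector>

/* illustrative only: a host-side (len x len) causal mask. cell (i, j)
   gets a large negative value when j > i, so that adding the mask to
   the attention scores before softmax prevents position i from
   attending to later positions. */
std::vector<float> MakeCausalMask(int len)
{
    std::vector<float> mask(len * len, 0.0F);
    for (int i = 0; i < len; i++)
        for (int j = i + 1; j < len; j++)
            mask[i * len + j] = -1e9F;
    return mask;
}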
......@@ -114,7 +114,8 @@ void T2TTrainer::Init(int argc, char ** argv)
LoadParamBool(argc, argv, "epochcheckpoint", &useEpochCheckpoint, false);
LoadParamInt(argc, argv, "updatestep", &updateStep, 1);
LoadParamBool(argc, argv, "doubledend", &isDoubledEnd, false);
LoadParamBool(argc, argv, "smallbatch", &isSmallBatch, false);
LoadParamBool(argc, argv, "smallbatch", &isSmallBatch, true);
LoadParamBool(argc, argv, "bigbatch", &isBigBatch, false);
buf = new int[bufSize];
buf2 = new int[bufSize];
......@@ -124,9 +125,6 @@ void T2TTrainer::Init(int argc, char ** argv)
adamBeta1T = 1.0F;
adamBeta2T = 1.0F;
validStep = 0;
curEpoch = 0;
}
int tc = 0;
......@@ -138,10 +136,8 @@ train the model
>> modelFN - where we keep the model
>> model - model to train
*/
bool T2TTrainer::Train(const char * fn, const char * validFN, const char * modelFN, T2TModel * model)
void T2TTrainer::Train(const char * fn, const char * validFN, const char * modelFN, T2TModel * model)
{
curEpoch += 1;
int step = 0;
int wc = 0;
int wordCount = 0;
......@@ -153,7 +149,8 @@ bool T2TTrainer::Train(const char * fn, const char * validFN, const char * model
int nCheckpoint = 0;
int nSkipped = 0;
int gradStep = 0;
//int validStep = 0;
int validStep = 0;
int epoch = 0;
char * trainFN = new char[(int)strlen(fn) + 10];
strcpy(trainFN, fn);
......@@ -171,10 +168,10 @@ bool T2TTrainer::Train(const char * fn, const char * validFN, const char * model
double startT = GetClockSec();
//for(epoch = 1; epoch <= nepoch; epoch++){
for(epoch = 1; epoch <= nepoch; epoch++){
#ifndef WIN32
if(isShuffled)
Shuffle(fn, trainFN);
if(isShuffled)
Shuffle(fn, trainFN);
#endif
FILE * file = fopen(trainFN, "rb");
......@@ -203,7 +200,6 @@ bool T2TTrainer::Train(const char * fn, const char * validFN, const char * model
{
CheckNTErrors(batchEnc.order == 2, "wrong tensor order of the sequence batch");
//CheckNTErrors(batchEnc.order == 3, "wrong tensor order of the sequence batch");
/* output probabilities */
XTensor output;
......@@ -273,12 +269,25 @@ bool T2TTrainer::Train(const char * fn, const char * validFN, const char * model
if (step % 100 == 0) {
double elapsed = GetClockSec() - startT;
XPRINT8(0, stderr, "[INFO] lr=%.2e, elapsed=%.1fs, step=%d, epoch=%d, word=%d, loss=%.3f, ppl=%.3f, sppl=%.3f",
lr, elapsed, step, curEpoch, wordCountTotal, loss/wordCount, exp(loss/wordCount), exp(-prob/wc));
lr, elapsed, step, epoch, wordCountTotal, loss/wordCount, exp(loss/wordCount), exp(-prob/wc));
if (!doUpdate)
XPRINT(0, stderr, " (no update)");
XPRINT(0, stderr, "\n");
}
//XMem * mem = model->mem;
//MTYPE used = 0;
//MTYPE total = 0;
//for(int i = 0; i < mem->blockNum; i++){
// if(mem->blocks[i].mem != NULL){
// used += mem->blocks[i].used;
// total += mem->blocks[i].size;
// }
//}
//fprintf(stderr, "%d %d %d %d mem: %lld %lld\n", paddingEnc.GetDim(0), paddingEnc.GetDim(1),
// paddingDec.GetDim(0), paddingDec.GetDim(1), used, total);
if(nStepCheckpoint > 0 && ++nStepCheck >= nStepCheckpoint){
MakeCheckpoint(model, validFN, modelFN, "step", step);
nStepCheck = 0;
......@@ -287,22 +296,22 @@ bool T2TTrainer::Train(const char * fn, const char * validFN, const char * model
}
fclose(file);
if (isEnd)
return false;
return true;
//if(useEpochCheckpoint)
// MakeCheckpoint(model, validFN, modelFN, "epoch", epoch);
//}
//double elapsed = GetClockSec() - startT;
//
//epoch = MIN(epoch, nepoch);
//
//XPRINT7(0, stderr, "[INFO] lr=%.2e, elapsed=%.1fs, step=%d, epoch=%d, word=%d, loss=%.3f, ppl=%.3f\n",
// lr, elapsed, step, epoch, wordCountTotal, loss/wordCount, exp(loss/wordCount));
//XPRINT4(0, stderr, "[INFO] training finished (took %.1fs, step=%d, skipped=%d and epoch=%d)\n",
// elapsed, step, nSkipped, epoch);
break;
if(useEpochCheckpoint)
MakeCheckpoint(model, validFN, modelFN, "epoch", epoch);
}
double elapsed = GetClockSec() - startT;
epoch = MIN(epoch, nepoch);
XPRINT7(0, stderr, "[INFO] lr=%.2e, elapsed=%.1fs, step=%d, epoch=%d, word=%d, loss=%.3f, ppl=%.3f\n",
lr, elapsed, step, epoch, wordCountTotal, loss/wordCount, exp(loss/wordCount));
XPRINT4(0, stderr, "[INFO] training finished (took %.1fs, step=%d, skipped=%d and epoch=%d)\n",
elapsed, step, nSkipped, epoch);
delete[] trainFN;
}
......@@ -356,8 +365,6 @@ void T2TTrainer::Test(const char * fn, const char * ofn, T2TModel * model)
seqs, vSize, vSizeTgt,
1, 1, false, wc, devID, mem, false))
{
//CheckNTErrors(batchEnc.order == 3, "wrong tensor order of the sequence batch");
CheckNTErrors(batchEnc.order == 2, "wrong tensor order of the sequence batch");
/* output probabilities */
......@@ -454,6 +461,7 @@ char line[MAX_SEQUENCE_LENGTH];
struct SampleNode
{
int id;
int offset;
int * p;
int size;
int value;
......@@ -545,6 +553,7 @@ int T2TTrainer::LoadBuf(FILE * file, bool isSorted, int step)
for (int i = 0; i < seqCount; i += step) {
SampleNode &node = nodes[count];
node.id = count;
node.offset = i;
node.p = buf + offset;
node.size = 0;
for(int j = 0; j < step; j++)
......@@ -562,8 +571,8 @@ int T2TTrainer::LoadBuf(FILE * file, bool isSorted, int step)
SampleNode &node = nodes[count];
memcpy(buf2 + offset, node.p, sizeof(int) * node.size);
for(int j = 0; j < step; j++){
seqLen2[i + j] = seqLen[node.id + j];
seqOffset[i + j] = offset + (j > 0 ? seqLen[node.id + j - 1] : 0);
seqLen2[i + j] = seqLen[node.offset + j];
seqOffset[i + j] = offset + (j > 0 ? seqLen[node.offset + j - 1] : 0);
}
count += 1;
offset += node.size;
......@@ -679,7 +688,7 @@ int T2TTrainer::LoadBatchLM(FILE * file,
if(max < wn)
max = wn;
int tc = isSmallBatch ? max * sc : wc;
int tc = isBigBatch ? wc : max * sc;
if(sc >= sBatch && tc >= wBatch)
break;
}
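To make the stopping rule explicit: with sc sentences read and max the length of the longest one, max * sc is the number of slots the padded (sc x max) batch tensor will occupy, while wc counts real words. A self-contained restatement (the wrapper function is hypothetical; the logic mirrors the line above):

bool BatchIsFull(bool isBigBatch, int sc, int wc, int max,
                 int sBatch, int wBatch)
{
    /* "big batch" counts real words; the default counts padded slots,
       which keeps batches small when one long sentence inflates max */
    int tc = isBigBatch ? wc : max * sc;
    return sc >= sBatch && tc >= wBatch;
}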
......@@ -823,8 +832,9 @@ int T2TTrainer::LoadBatchMT(FILE * file,
if(maxDec < wnDec)
maxDec = wnDec;
int tc = isSmallBatch ? maxEnc * sc / 2 : wcEnc;
if(sc >= sBatch * 2 && tc >= wBatch)
int tcEnc = isBigBatch ? wcEnc : maxEnc * sc / 2;
int tcDec = isBigBatch ? wcDec : maxDec * sc / 2;
if(sc >= sBatch * 2 && (tcEnc >= wBatch || tcDec >= wBatch))
break;
}
......@@ -838,9 +848,9 @@ int T2TTrainer::LoadBatchMT(FILE * file,
int dimsEnc[3] = {sCount, maxEnc, vsEnc};
int dimsDec[3] = {sCount, maxDec, vsDec};
InitTensor(batchEnc, 3, dimsEnc, X_FLOAT, 1.0F, devID, mem);
InitTensor(batchEnc, 2, dimsEnc, X_INT, 1.0F, -1);
InitTensor2D(paddingEnc, sCount, maxEnc, X_FLOAT, devID, mem);
InitTensor(batchDec, 3, dimsDec, X_FLOAT, 1.0F, devID, mem);
InitTensor(batchDec, 2, dimsDec, X_INT, 1.0F, -1);
InitTensor2D(paddingDec, sCount, maxDec, X_FLOAT, devID, mem);
InitTensor(gold, 3, dimsDec, X_FLOAT, 1.0F, devID, mem);
......@@ -857,7 +867,8 @@ int T2TTrainer::LoadBatchMT(FILE * file,
int len = seqLen[s];
int sent = (s - seq)/2;
for(int w = 0; w < len; w++){
batchEnc->Set3D(1.0F, sent, w, buf[seqOffset[s] + w]);
batchEnc->Set2DInt(buf[seqOffset[s] + w], sent, w);
//batchEnc->Set3D(1.0F, sent, w, buf[seqOffset[s] + w]);
paddingEnc->Set2D(1.0F, sent, w);
wCount++;
}
......@@ -869,8 +880,9 @@ int T2TTrainer::LoadBatchMT(FILE * file,
CheckNTErrors(len <= maxDec, "Something is wrong!");
int sent = (s - seq - 1)/2;
for(int w = 0; w < len; w++){
batchDec->Set2DInt(buf[seqOffset[s] + w], sent, w);
//batchDec->Set3D(1.0F, sent, w, buf[seqOffset[s] + w]);
paddingDec->Set2D(1.0F, sent, w);
batchDec->Set3D(1.0F, sent, w, buf[seqOffset[s] + w]);
if(w > 0)
gold->Set3D(1.0F, sent, w - 1, buf[seqOffset[s] + w]);
if (w == len - 1) {
......
......@@ -103,10 +103,6 @@ public:
/* indicates whether we use adam */
bool useAdam;
int validStep;
int curEpoch;
/* hyper parameters of adam*/
float adamBeta1;
float adamBeta2;
......@@ -143,6 +139,9 @@ public:
length and sc is the sentence number */
bool isSmallBatch;
/* counterpart of "isSmallBatch" */
bool isBigBatch;
public:
/* constructor */
T2TTrainer();
......@@ -154,7 +153,7 @@ public:
void Init(int argc, char ** argv);
/* train the model */
bool Train(const char * fn, const char * validFN, const char * modelFN, T2TModel * model);
void Train(const char * fn, const char * validFN, const char * modelFN, T2TModel * model);
/* test the model */
void Test(const char * fn, const char * ofn, T2TModel * model);
......
......@@ -58,75 +58,20 @@ int TransformerMain(int argc, const char ** argv)
LoadParamString(argc, args, "test", testFN, "");
LoadParamString(argc, args, "output", outputFN, "");
T2TTrainer trainer;
trainer.Init(argc, args);
T2TModel model;
model.InitModel(argc, args);
/* learn model parameters */
if(strcmp(trainFN, "")) {
double startT = GetClockSec();
T2TTrainer trainer;
trainer.Init(argc, args);
char * fn = new char[MAX_LINE_LENGTH];
char * fn1 = new char[MAX_LINE_LENGTH];
char * fn2 = new char[MAX_LINE_LENGTH];
modelFN = strcmp(modelFN, "") ? modelFN : (char *)"checkpoint.model";
int epoch;
bool isTrain;
for(epoch = 1; epoch <= trainer.nepoch; epoch++) {
sprintf(fn, "%s.%s.%03d", modelFN, "epoch", epoch - 1);
sprintf(fn1, "%s.%s.%03d", modelFN, "epoch", epoch);
sprintf(fn2, "%s.%s.%03d.output", modelFN, "epoch", epoch);
if(epoch == 1) {
T2TModel model;
model.InitModel(argc, args);
isTrain = trainer.Train(trainFN, testFN, modelFN, &model);
model.Dump(fn1);
}
else {
T2TModel model;
model.InitModel(argc, args);
model.Read(fn);
isTrain = trainer.Train(trainFN, testFN, modelFN, &model);
model.Dump(fn1);
}
if(trainer.useEpochCheckpoint && strcmp(testFN, "")) {
T2TTrainer tester;
tester.Init(argc, args);
T2TModel model;
model.InitModel(argc, args);
model.Read(fn1);
tester.Test(testFN, fn2, &model);
}
if(!isTrain)
break;
}
double elapsed = GetClockSec() - startT;
epoch = MIN(epoch, trainer.nepoch);
if(strcmp(trainFN, ""))
trainer.Train(trainFN, testFN, modelFN, &model);
XPRINT2(0, stderr, "[INFO] training finished (took %.1fs and epoch=%d)\n", elapsed, epoch);
delete[] fn;
delete[] fn1;
delete[] fn2;
}
/* don't dump the final model */
/* save the final model */
//if(strcmp(modelFN, "") && strcmp(trainFN, ""))
// model.Dump(modelFN);
if(strcmp(modelFN, "") && strcmp(trainFN, ""))
model.Dump(modelFN);
T2TModel model;
model.InitModel(argc, args);
/* load the model if necessary */
if(strcmp(modelFN, ""))
model.Read(modelFN);
......
......@@ -14,7 +14,7 @@
* See the License for the specific language governing permissions and
* limitations under the License.
*/
/*
*
* This is the entrance of the low-level tensor library : NiuTrans.Tensor
......@@ -39,9 +39,20 @@ using namespace nts;
void SmallTest();
void TransposeTest();
void LittleTest();
void T2TTest();
void T2TTest2();
void PowerTest();
int main( int argc, const char ** argv )
{
//PowerTest();
//LittleTest();
//T2TTest();
//T2TTest2();
//return 0;
//_CrtSetBreakAlloc(123);
/* a tiny test */
......@@ -63,6 +74,34 @@ int main( int argc, const char ** argv )
return 0;
}
void myRead(XTensor * tensor, const char * filename, const char * label)
{
    FILE * file = fopen(filename, "rb");
    if(file == NULL){
        printf("cannot open %s\n", filename);
        return;
    }
    tensor->Read(file, label);
    fclose(file);
}
void myDump(XTensor * tensor, const char * filename, const char * label)
{
    FILE * file = fopen(filename, "wb");
    if(file == NULL){
        printf("cannot open %s\n", filename);
        return;
    }
    tensor->Dump(file, label);
    fclose(file);
}
void PowerTest()
{
XTensor input;
XTensor output;
InitTensor2D(&input, 256, 10000, X_FLOAT, 0);
InitTensor2D(&output, 256, 10000, X_FLOAT, 0);
myRead(&input, "1.txt", "");
_Power(&input, &output, 2);
output.Dump(stderr, "", 200);
}
void SmallTest()
{
XTensor a;
......@@ -126,3 +165,128 @@ void TransposeTest()
delete[] data;
}
void LittleTest()
{
int a = 5000;
int b = 100000;
int c = a*b;
printf("%d\n", c);
exit(1);
}
void T2TTest()
{
XTensor * input;
XTensor * weight;
XTensor * output;
XTensor * gold;
XTensor * dedy;
XTensor * dedx;
XTensor * dedxTmp;
XTensor * dedw;
XTensor * padding;
DTYPE loss;
int * dimSize = new int[2];
dimSize[0] = 256;
dimSize[1] = 10001;
int * dimSize2 = new int[3];
dimSize2[0] = 2;
dimSize2[1] = 31;
dimSize2[2] = 256;
int * dimSize3 = new int[3];
dimSize3[0] = 2;
dimSize3[1] = 31;
dimSize3[2] = 10001;
int * dimSize4 = new int[2];
dimSize4[0] = 2;
dimSize4[1] = 31;
input = NewTensor(3, dimSize2, X_FLOAT, 1.0F, 0);
weight = NewTensor(2, dimSize, X_FLOAT, 1.0F, 0);
dedw = NewTensor(2, dimSize, X_FLOAT, 1.0F, 0);
gold = NewTensor(3, dimSize3, X_FLOAT, 1.0F, 0);
output = NewTensor(3, dimSize3, X_FLOAT, 1.0F, 0);
dedy = NewTensor(3, dimSize3, X_FLOAT, 1.0F, 0);
dedx = NewTensor(3, dimSize3, X_FLOAT, 1.0F, 0);
dedxTmp = NewTensor(3, dimSize3, X_FLOAT, 1.0F, 0);
padding = NewTensor(2, dimSize4, X_FLOAT, 1.0F, 0);
//weight = NewTensor(2, dimSize);
//dedw = NewTensor(2, dimSize);
//input = NewTensor(3, dimSize2);
//gold = NewTensor(3, dimSize3);
//output = NewTensor(3, dimSize3);
//dedy = NewTensor(3, dimSize3);
//dedx = NewTensor(3, dimSize3);
//dedxTmp = NewTensor(3, dimSize3);
//padding = NewTensor(2, dimSize4);
myRead(input, "x.txt", "x");
myRead(weight, "w.txt", "w");
myRead(gold, "gold.txt", "gold");
myRead(padding, "padding.txt", "padding");
XTensor inter;
inter = MMul(*input, *weight);
_Softmax(&inter, output, 2);
//_LogMe(output);
loss = _CrossEntropyFast(output, gold, REDUCE_MEAN, NULL, padding);
printf("loss: %f\n", loss);
_CrossEntropyBackward(dedy, output, gold, NULL);
//_CrossEntropyBackward(dedy, output, gold, NULL, padding);
myDump(dedy, "dedy.txt", "dedy");
_SoftmaxBackward(NULL, output, input, dedy, dedx, NULL, -1, NOLOSS);
_Sub(output, gold, dedxTmp);
myDump(dedx, "dedx.txt", "dedx");
dedx->Dump(stderr, "dedx", 200);
dedxTmp->Dump(stderr, "dedxTmp", 200);
input->Reshape(input->unitNum/input->GetDim(-1), input->GetDim(-1));
dedx->Reshape(dedx->unitNum/dedx->GetDim(-1), dedx->GetDim(-1));
_MatrixMulBatched(input, X_TRANS, dedx, X_NOTRANS, dedw);
myDump(dedw, "dedw.txt", "dedw");
}
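The final comparison of dedx against (output - gold) rests on the standard identity for softmax followed by cross entropy: with $y = \mathrm{softmax}(x)$ and $E = -\sum_k t_k \log y_k$, the gradient is $\partial E / \partial x_i = y_i - t_i$, so _SoftmaxBackward's result should match the plain subtraction computed into dedxTmp.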
void T2TTest2()
{
int dimSize[3];
dimSize[0] = 161;
dimSize[1] = 47;
dimSize[2] = 10001;
XTensor * probs = NewTensor(3, dimSize, X_FLOAT, 1.0F, 0);
//XTensor * probs = NewTensor(3, dimSize, X_FLOAT, 1.0F, -1);
//myRead(probs, "probs.txt", " ");
_SetDataFixedFloat(probs, 1.0F);
probs->Reshape(1, probs->unitNum);
DTYPE sum = _ReduceSumAll(probs);
printf("%e\n", sum);
//XTensor tmp;
//tmp = IsNonZero(*probs);
//DTYPE nonZeroNum = ReduceSumAll(tmp);
//printf("%f\n", nonZeroNum);
//
//DTYPE gpu = ReduceSum(*probs, 1).Get2D(0, 0);
//printf("%e\n", gpu);
}
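Since every cell is set to 1.0F, the printed sum should be 161 * 47 * 10001 = 75,677,567 up to float accumulation error, which makes this a quick sanity check on _ReduceSumAll.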
......@@ -47,6 +47,7 @@ XMem::XMem()
name = new char[64];
strcpy(name, "xmem");
signature = 0;
mergeFreeOTF = true;
}
/*
......@@ -69,6 +70,7 @@ XMem::XMem(int myDevID, MEMPOOL_MODE myMode, MTYPE myBlockSize, int myBlockNum,
name = new char[64];
strcpy(name, "xmem");
signature = 0;
mergeFreeOTF = true;
Initialize(myDevID, myMode, myBlockSize, myBlockNum, myBufSize);
}
......@@ -153,8 +155,13 @@ void XMem::Initialize(int myDevID, MEMPOOL_MODE myMode, MTYPE myBlockSize, int m
bufSize = myBufSize;
#ifdef SMALL_DEVICE
if (myMode == FREE_ON_THE_FLY)
SetIndex(50000);
#else
if (myMode == FREE_ON_THE_FLY)
SetIndex(MILLION);
#endif
signature++;
}
......@@ -613,7 +620,7 @@ void * XMem::AllocStandard(int myDevID, MTYPE mySize, bool myIsRebuiltIndex)
while(node != NULL){
if(node->size == 0){
MPieceNode * next = node->next;
RemoveFreeIndexNode(node, entry);
RemoveIndexNode(node, entry);
node = next;
}
else{
......@@ -655,11 +662,11 @@ void * XMem::AllocStandard(int myDevID, MTYPE mySize, bool myIsRebuiltIndex)
next.pre = &cur;
next.next = cur.next;
cur.next = &next;
if(cur.next != NULL)
cur.next->pre = &next;
cur.size = needed;
if(next.next != NULL)
next.next->pre = &next;
next.state = 1;
next.size = remaining;
next.blockID = cur.blockID;
......@@ -670,7 +677,7 @@ void * XMem::AllocStandard(int myDevID, MTYPE mySize, bool myIsRebuiltIndex)
hit->pReal = beg;
blocks[hit->head.blockID].used += head->size;
RemoveFreeIndexNode(hit);
RemoveIndexNode(hit);
AddAllocIndexNode(hit);
result = beg;
......@@ -684,6 +691,65 @@ void * XMem::AllocStandard(int myDevID, MTYPE mySize, bool myIsRebuiltIndex)
}
/* if there is still no available memory piece, we have to obtain a new block of memory. */
else{
/*MTYPE used = 0;
MTYPE total = 0;
MTYPE free = 0;
for(int i = 0; i < blockNum; i++){
if(blocks[i].mem != NULL){
used += blocks[i].used;
total += blocks[i].size;
}
}
MPieceNode * bufNodes = new MPieceNode[MILLION];
int bufNodeCount = 0;
for(int i = 0; i <= indexEntryNum; i++){
entry = memIndex + i;
node = entry->next;
while(node != NULL){
bufNodes[bufNodeCount++] = *node;
if(node->size == 0){
MPieceNode * next = node->next;
node = next;
}
else{
if(node->head.state == 1 && node->size >= mySize){
fprintf(stderr, "hit!!!!!!!!!!!\n");
}
//fprintf(stderr, "%d %lld %lld %lld\n", node->head.blockID, free, node->size, mySize);
free += node->size;
node = node->next;
}
}
}
MTYPE headSize = 0;
MTYPE headSizeUsed = 0;
for(int i = 0, j = 0; i < blockNum; i++){
XMemBlock * block = blocks + i;
if(block->mem != NULL){
MHeader * head = block->head;
while(head != NULL){
if(head->state == 1){
headSize += head->size;
//fprintf(stderr, "%d head %lld\n", j++, head->size);
}
else{
headSizeUsed += head->size;
}
head = head->next;
}
}
}
delete[] bufNodes;
fprintf(stderr, "%lld %lld\n", headSize, headSizeUsed);
fprintf(stderr, "mem: %lld %lld %lld %lld\n", used, total, free, mySize);*/
int bi;
for(bi = 0; bi < blockNum; bi++){
XMemBlock * block = blocks + bi;
......@@ -856,15 +922,16 @@ int XMem::FindIndexEntry(MTYPE mySize)
}
/*
remove an index node for available memory pieces
remove an index node
>> node - node to remove
>> entry - the entry of the list that keeps the node
*/
void XMem::RemoveFreeIndexNode(MPieceNode * node, MPieceNode * entry)
void XMem::RemoveIndexNode(MPieceNode * node, MPieceNode * entry)
{
MPieceNode * pre = node->pre;
MPieceNode * next = node->next;
CheckNTErrors(pre != NULL, "cannot free the entry node!");
pre->next = next;
......@@ -884,6 +951,19 @@ void XMem::AddFreeIndexNode(MPieceNode * node, MPieceNode * entry)
{
MPieceNode * entryForMe = entry != NULL ? entry :
memIndex + FindIndexEntry(node->size);
/*MPieceNode * backup = entryForMe->next;
while(backup != NULL && backup->head.size < node->head.size){
backup = backup->next;
entryForMe = entryForMe->next;
}
entryForMe->next = node;
node->pre = entryForMe;
node->next = backup;
if(backup != NULL)
backup->pre = node;*/
MPieceNode * backup = entryForMe->next;
entryForMe->next = node;
......@@ -903,7 +983,7 @@ remove an index node for memory pieces in use
*/
void XMem::RemoveAllocIndexNode(MPieceNode * node, MPieceNode * entry)
{
RemoveFreeIndexNode(node, entry);
RemoveIndexNode(node, entry);
}
/*
......@@ -959,7 +1039,7 @@ void XMem::ReleaseStandard(int myDevID, void * p, MTYPE size)
if(node->size == 0){
MPieceNode * next = node->next;
RemoveFreeIndexNode(node, entry);
RemoveIndexNode(node, entry);
node = next;
ShowNTErrors("Something is wrong!");
}
......@@ -979,10 +1059,53 @@ void XMem::ReleaseStandard(int myDevID, void * p, MTYPE size)
RemoveAllocIndexNode(hit);
hit->size = (char*)hit->p + hit->head.size - (char*)GetPitchedAddress((char*)hit->p, MY_PITCH);
AddFreeIndexNode(hit);
MTYPE usedSize = (char*)hit->p + hit->head.size - (char*)GetPitchedAddress((char*)hit->p, MY_PITCH);
blocks[hit->head.blockID].used -= usedSize;
if(mergeFreeOTF){
MHeader * head = &hit->head;
MHeader * pre = head->pre;
MHeader * next = head->next;
bool mergeLeft = false;
bool mergeRight = false;
CheckNTErrors(head != pre, "wrong list of memory headers");
CheckNTErrors(head != next, "wrong list of memory headers");
if(pre != NULL && pre->state == 1 && pre->blockID == head->blockID){
mergeLeft = true;
head->pre = pre->pre;
if(head->pre != NULL)
head->pre->next = head;
hit->p = pre->indexNode->p;
hit->head.size += pre->size;
RemoveAllocIndexNode(pre->indexNode);
if(pre == blocks[head->blockID].head)
blocks[head->blockID].head = head;
}
if(next != NULL && next->state == 1 && next->blockID == head->blockID){
mergeRight = true;
head->next = next->next;
if(head->next != NULL)
head->next->pre = head;
hit->head.size += next->size;
RemoveAllocIndexNode(next->indexNode);
}
if(!mergeLeft && !mergeRight){
hit->size = usedSize;
}
else{
hit->size = (char*)hit->p + hit->head.size - (char*)GetPitchedAddress((char*)hit->p, MY_PITCH);
}
}
else{
hit->size = usedSize;
}
blocks[hit->head.blockID].used -= hit->head.size;
AddFreeIndexNode(hit);
}
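The branch above is standard free-list coalescing. A simplified model with a plain doubly linked list (the Piece struct is hypothetical; the real code additionally checks that neighbours belong to the same block, maintains a size-bucketed index, and accounts for pitched addresses):

struct Piece {
    Piece * pre;
    Piece * next;
    size_t  size;
    bool    free;
};

/* merge a newly freed piece with free neighbours */
void Coalesce(Piece * p)
{
    if (p->next != NULL && p->next->free) {  /* absorb the right neighbour */
        Piece * n = p->next;
        p->size += n->size;
        p->next  = n->next;
        if (p->next != NULL)
            p->next->pre = p;
    }
    if (p->pre != NULL && p->pre->free) {    /* fold into the left neighbour */
        Piece * l = p->pre;
        l->size += p->size;
        l->next  = p->next;
        if (l->next != NULL)
            l->next->pre = l;
    }
}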
/* rebuild the index to merge small memory fragments and free blocks that are no longer in use */
......@@ -1046,7 +1169,7 @@ void XMem::RebuildIndex()
if(head->state == 1){
newNode->size = (char*)p + head->size -
( head->state == 1 ? (char*)GetPitchedAddress((char*)p, MY_PITCH) : (char*)head->indexNode->pReal);
(head->state == 1 ? (char*)GetPitchedAddress((char*)p, MY_PITCH) : (char*)head->indexNode->pReal);
}
else
newNode->size = node->size;
......@@ -1067,7 +1190,6 @@ void XMem::RebuildIndex()
if(newLast != NULL)
newLast->next = newHeader;
newHeader->pre = newLast;
if(head->state == 1){
newNode->pReal = NULL;
......
......@@ -231,6 +231,9 @@ public:
/* index offset */
int indexOffset;
/* indicates whether we merge free memory pieces on the fly */
bool mergeFreeOTF;
public:
/* constructor */
......@@ -326,7 +329,7 @@ public:
int FindIndexEntry(MTYPE mySize);
/* remove an index node for available memory pieces */
void RemoveFreeIndexNode(MPieceNode * node, MPieceNode * entry = NULL);
void RemoveIndexNode(MPieceNode * node, MPieceNode * entry = NULL);
/* add an index node for available memory pieces */
void AddFreeIndexNode(MPieceNode * node, MPieceNode * entry = NULL);
......
......@@ -260,6 +260,7 @@ void XTensor::DestroyData()
FreeData(this, mem);
else if(data != NULL)
mem->Release(data, GetDataSizeInChar(), signature);
data = NULL;
if(dataHost != NULL)
......@@ -1117,7 +1118,7 @@ bool XTensor::Set3D(DTYPE value, int d0, int d1, int d2)
CheckNTErrors(order == 3, "Cannot set a 3d cell for a tensor whose order is not 3!");
CheckNTErrors(d0 >= 0 && d0 < dimSize[0], "dimension 0 is out of range!");
CheckNTErrors(d1 >= 0 && d1 < dimSize[1], "dimension 1 is out of range!");
CheckNTErrors(d2 >= 0 && d2 < dimSize[2], "dimension 1 is out of range!");
CheckNTErrors(d2 >= 0 && d2 < dimSize[2], "dimension 2 is out of range!");
CheckNTErrors(dataType == DEFAULT_DTYPE, "The tensor is not in default type.");
int dims[3] = {d0, d1, d2};
......@@ -1364,8 +1365,9 @@ bool XTensor::Resize(const int myOrder, const int * myDimSize,
d = new int[size];
memset(d, 0, size);
}
else
else{
d = (int*)mem->Alloc(mem->devID, size);
}
if(d == NULL)
return false;
......
......@@ -217,7 +217,6 @@ void _Merge(const XList * smalls, XTensor * big, int whereToMerge)
XTensor* smallsItem0 = (XTensor*)(smalls->GetItem(0));
int itemSize = smallsItem0->unitNum * smallsItem0->unitSize;
for (int i = 0; i < smalls->count; i++) {
XTensor* smallsItem = (XTensor*)smalls->GetItem(i);
CheckNTErrors((big->unitNum == smallsItem->unitNum * mergeNum), "Unmatched tensors!");
......
......@@ -342,6 +342,24 @@ split a big tensor into small tensors
*/
void Split(const XTensor &big, XList &smalls, int whereToSplit, int splitNum)
{
CheckNTErrors(big.GetDim(whereToSplit) % splitNum == 0, "Wrong splitNum!");
int order = big.order;
int * dimSize = new int[order];
for (int i = 0; i < big.order; i++) {
if (i != whereToSplit)
dimSize[i] = big.dimSize[i];
else
dimSize[i] = big.dimSize[whereToSplit] / splitNum;
}
float dr = (!big.isSparse) ? 1.0F : big.denseRatio;
for (int i = 0; i < splitNum; i++) {
XTensor * item = NewTensor(order, dimSize, big.dataType, dr, big.devID, big.mem);
smalls.Add(item);
}
delete[] dimSize;
/* call _Split function */
_Split(&big, &smalls, whereToSplit, splitNum);
......
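A short usage sketch of this user-level Split (the 4 x 6 shape is made up; NewTensor and XList follow the patterns in TestSplit3 below). Unlike _Split, the caller passes an empty list and Split allocates the outputs itself:

int dims[2] = {4, 6};
XTensor * big = NewTensor(2, dims);
XList smalls;
Split(*big, smalls, 1, 2);                   /* dim 1 (size 6) into two parts */
XTensor * half0 = (XTensor*)smalls.Get(0);   /* a (4, 3) tensor */
XTensor * half1 = (XTensor*)smalls.Get(1);   /* a (4, 3) tensor */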
......@@ -207,7 +207,7 @@ void _SoftmaxBackward(XTensor * gold, XTensor * y, XTensor * x,
DTYPE * gp = gold != NULL ? (DTYPE*)gold->data : NULL;
DTYPE * op = (DTYPE*)y->data;
DTYPE * sp = (DTYPE*)dedx->data;
DTYPE * yp = (DTYPE*)dedy->data;
DTYPE * yp = NULL;
if(lossName == CROSSENTROPY){
if(gold->isSparse){
......
......@@ -272,29 +272,23 @@ bool TestSplit3()
XTensor * s = NewTensor(sOrder, sDimSize);
XTensor * t1 = NewTensor(tOrder1, tDimSize1);
XTensor * t2 = NewTensor(tOrder2, tDimSize2);
XTensor * tUser1 = NewTensor(tOrder1, tDimSize1);
XTensor * tUser2 = NewTensor(tOrder2, tDimSize2);
/* initialize variables */
s->SetData(sData, sUnitNum);
t1->SetZeroAll();
t2->SetZeroAll();
tUser1->SetZeroAll();
tUser2->SetZeroAll();
/* add tensors to list */
tList->Add(t1);
tList->Add(t2);
tUserList.Add(tUser1);
tUserList.Add(tUser2);
/* call split function */
_Split(s, tList, 1, 2);
Split(*s, tUserList, 1, 2);
/* check results */
cpuTest = t1->CheckData(answer1, tUnitNum1) && tUser1->CheckData(answer1, tUnitNum1)
&& t2->CheckData(answer2, tUnitNum2) && tUser2->CheckData(answer2, tUnitNum2);
cpuTest = t1->CheckData(answer1, tUnitNum1) && ((XTensor *)tUserList.Get(0))->CheckData(answer1, tUnitNum1) &&
t2->CheckData(answer2, tUnitNum2) && ((XTensor *)tUserList.Get(1))->CheckData(answer2, tUnitNum2);
#ifdef USE_CUDA
/* GPU test */
......@@ -308,42 +302,31 @@ bool TestSplit3()
XTensor * sGPU = NewTensor(sOrder, sDimSize, X_FLOAT, 1.0F, 0);
XTensor * tGPU1 = NewTensor(tOrder1, tDimSize1, X_FLOAT, 1.0F, 0);
XTensor * tGPU2 = NewTensor(tOrder2, tDimSize2, X_FLOAT, 1.0F, 0);
XTensor * tUserGPU1 = NewTensor(tOrder1, tDimSize1, X_FLOAT, 1.0F, 0);
XTensor * tUserGPU2 = NewTensor(tOrder2, tDimSize2, X_FLOAT, 1.0F, 0);
/* Initialize variables */
sGPU->SetData(sData, sUnitNum);
tGPU1->SetZeroAll();
tGPU2->SetZeroAll();
tUserGPU1->SetZeroAll();
tUserGPU2->SetZeroAll();
/* add tensors to list */
tList->Add(tGPU1);
tList->Add(tGPU2);
tUserList.Add(tUserGPU1);
tUserList.Add(tUserGPU2);
/* call Split function */
_Split(sGPU, tList, 1, 2);
Split(*sGPU, tUserList, 1, 2);
/* check results */
gpuTest = tGPU1->CheckData(answer1, tUnitNum1) && tUserGPU1->CheckData(answer1, tUnitNum1)
&& tGPU2->CheckData(answer2, tUnitNum2) && tUserGPU2->CheckData(answer2, tUnitNum2);
gpuTest = tGPU1->CheckData(answer1, tUnitNum1) && ((XTensor *)tUserList.Get(0))->CheckData(answer1, tUnitNum1) &&
tGPU2->CheckData(answer2, tUnitNum2) && ((XTensor *)tUserList.Get(1))->CheckData(answer2, tUnitNum2);
/* destroy variables */
delete s;
delete t1;
delete t2;
delete tUser1;
delete tUser2;
delete sGPU;
delete tGPU1;
delete tGPU2;
delete tUserGPU1;
delete tUserGPU2;
delete[] sDimSize;
delete[] tDimSize1;
delete[] tDimSize2;
......
......@@ -68,6 +68,7 @@ bool TestXMemCase1()
int j = rand() % caseNum;
//fprintf(stderr, "%d %d %d\n", testxmemid, j, ok);
//fprintf(stderr, "iter %d %d %d\n", iter, i, j);
if (p[j] == NULL) {
p[j] = (int*)mem.AllocStandard(mem.devID, size[j] * sizeof(int));
......