Commit 47d01713 by xiaotong

new code of mem

parent e1630c28
@@ -31,8 +31,6 @@ namespace transformer
 /* constructor */
 T2TTrainer::T2TTrainer()
 {
-    devID = -1;
-    mem = NULL;
     seqLen = NULL;
     nseqBuf = 0;
     nextSeq = -1;
@@ -56,7 +54,6 @@ void T2TTrainer::Init(int argc, const char ** argv)
     bool useMem = false;
     LoadParamBool(argc, argv, "mem", &useMem, useMem);
-    LoadParamInt(argc, argv, "dev", &devID, -1);
     LoadParamFloat(argc, argv, "lrate", &lrate, 0.001F);
     LoadParamInt(argc, argv, "sbatch", &sBatchSize, 1);
     LoadParamInt(argc, argv, "wbatch", &wBatchSize, 1);
@@ -72,10 +69,6 @@ void T2TTrainer::Init(int argc, const char ** argv)
     seqLen = new int[bufSize];
     seqOffset = new int[bufSize];
-    if(useMem){
-        delete mem;
-        mem = new XMem(devID, UNI_FREE, MILLION * 64, 1024, MILLION * 64);
-    }
 }

 /*
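Note: with the hunks above, the trainer no longer builds its own memory pool in Init(); Train() below borrows the pool owned by T2TModel instead. For orientation, a minimal sketch of what the model-side setup is assumed to look like, reusing the exact XMem constructor arguments deleted above (the helper name is hypothetical, not part of this commit):

```cpp
#include "XMem.h"   // nts::XMem, UNI_FREE, MILLION

using namespace nts;

/* Hypothetical sketch, not the actual T2TModel code: the pool that the
   trainer now borrows is assumed to be created on the model side with the
   same arguments that were removed from T2TTrainer::Init() above
   (read here as: 64M block size, 1024 blocks, 64M buffer). */
XMem * CreateModelPool(int devID)
{
    return new XMem(devID, UNI_FREE, MILLION * 64, 1024, MILLION * 64);
}
```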
@@ -94,7 +87,8 @@ void T2TTrainer::Train(const char * fn, T2TModel * model)
     float loss = 0;
     float lr = 0;
-    model->mem->SetPin();
+    int devID = model->devID;
+    XMem * mem = model->mem;
+    mem->SetPin();
     XNet net;
@@ -108,7 +102,6 @@ void T2TTrainer::Train(const char * fn, T2TModel * model)
     wordCount = 0;
-    model->mem->BackToPin();
+    mem->BackToPin();
     /* batch of input sequences */
@@ -117,7 +110,7 @@ void T2TTrainer::Train(const char * fn, T2TModel * model)
     /* padding */
     XTensor padding;
-    while(LoadBatch(file, &batch, &padding, 1, vSize, sBatchSize, wBatchSize, isLenSorted, wc)){
+    while(LoadBatch(file, &batch, &padding, 1, vSize, sBatchSize, wBatchSize, isLenSorted, wc, devID, mem)){
         /* output probabilities */
         XTensor output;
@@ -157,7 +150,6 @@ void T2TTrainer::Train(const char * fn, T2TModel * model)
            lr, elapsed, step, epoch + 1, wordCountTotal, exp(loss / wordCount));
     }
-    model->mem->BackToPin();
+    mem->BackToPin();
 }
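Taken together, the Train() hunks implement a pin/rollback discipline on the pool: pin once after the model's parameters are in place, then rewind to the pin after every batch so all per-batch tensors are reclaimed at once. A minimal sketch of the pattern, assuming SetPin()/BackToPin() have the semantics the diff suggests (RunOneBatch is a hypothetical stand-in for the forward/backward pass):

```cpp
#include "XMem.h"   // nts::XMem

using namespace nts;

/* Sketch of the pin/rollback pattern in Train() above. Assumption: SetPin()
   records the pool's current allocation point and BackToPin() rewinds to it,
   freeing everything allocated in between. */
void TrainSketch(XMem * mem, int nBatch)
{
    /* model parameters were allocated before this point and must survive */
    mem->SetPin();

    for(int i = 0; i < nBatch; i++){
        /* batch, padding and all activations are drawn from "mem" here,
           e.g. RunOneBatch(mem) -- hypothetical */
        mem->BackToPin();   /* drop every per-batch allocation at once */
    }
}
```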
@@ -261,10 +253,13 @@ load a batch of sequences
 >> wBatch - batch size of words
 >> isSorted - indicates whether the sequences are sorted by length
 >> wCount - word count
+>> devID - device id
+>> mem - memory pool
 */
 int T2TTrainer::LoadBatch(FILE * file, XTensor * batch, XTensor * padding,
                           int step, int vs, int sBatch, int wBatch,
-                          bool isSorted, int &wCount)
+                          bool isSorted, int &wCount,
+                          int devID, XMem * mem)
 {
     if(nextSeq < 0 || nextSeq >= nseqBuf)
         LoadBuf(file);
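For reference, a caller-side sketch of the new signature; the loop mirrors the call in Train() above, and the surrounding variables (trainer, file, vSize, sBatchSize, wBatchSize, isLenSorted, devID, mem) are assumed to be set up the same way as there:

```cpp
/* Usage sketch only -- mirrors the while-loop in Train() above. */
XTensor batch;    /* (sequence, position, vocab) one-hot input */
XTensor padding;  /* (sequence, position) mask of valid tokens */
int wc = 0;       /* word count, filled in by LoadBatch */

while(trainer.LoadBatch(file, &batch, &padding,
                        1, vSize, sBatchSize, wBatchSize,
                        isLenSorted, wc, devID, mem))
{
    /* forward/backward pass on this batch; both tensors live in "mem",
       so the next BackToPin() reclaims them */
}
```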
@@ -295,15 +290,15 @@ int T2TTrainer::LoadBatch(FILE * file, XTensor * batch, XTensor * padding,
     dims[1] = max;
     dims[2] = vs;
-    if(batch->order != 3 || batch->GetDim(0) != dims[0] ||
-       batch->GetDim(1) != dims[1] || batch->GetDim(2) != dims[2]){
+    //if(batch->order != 3 || batch->GetDim(0) != dims[0] ||
+    //   batch->GetDim(1) != dims[1] || batch->GetDim(2) != dims[2]){
         InitTensor(batch, 3, dims, X_FLOAT, 1.0F, devID, mem);
-    }
+    //}
-    if(padding->order != 2 || padding->GetDim(0) != sc ||
-       padding->GetDim(1) != max){
+    //if(padding->order != 2 || padding->GetDim(0) != sc ||
+    //   padding->GetDim(1) != max){
         InitTensor2D(padding, sc, max, X_FLOAT, devID, mem);
-    }
+    //}
     batch->SetZeroAll();
     padding->SetZeroAll();
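The shape checks are commented out rather than deleted, which is consistent with the pin/rollback scheme: BackToPin() in Train() reclaims the memory behind batch and padding after every iteration, so both tensors must be re-initialized on every call even when the new batch happens to have the same shape. A sketch of the hazard, assuming the rewind semantics described above:

```cpp
/* Sketch of why the "same shape, skip init" fast path is disabled. Assumes
   BackToPin() rewinds the pool, invalidating earlier per-batch allocations. */
mem->SetPin();
InitTensor2D(&padding, sc, max, X_FLOAT, devID, mem);   /* iteration 1 */
mem->BackToPin();                /* padding's buffer is reclaimed here */
InitTensor2D(&padding, sc, max, X_FLOAT, devID, mem);   /* iteration 2: must
                                 re-init even though the shape is unchanged */
```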
@@ -37,12 +37,6 @@ namespace transformer
 class T2TTrainer
 {
 public:
-    /* device id */
-    int devID;
-
-    /* memory pool */
-    XMem * mem;
-
     /* buffer for loading words */
     int * buf;
@@ -107,7 +101,8 @@ public:
     /* load a batch of sequences */
     int LoadBatch(FILE * file, XTensor * batch, XTensor * padding,
                   int step, int vs, int sBatch, int wBatch,
-                  bool isSorted, int &wCount);
+                  bool isSorted, int &wCount,
+                  int devID, XMem * mem);

     /* get word probabilities for a batch of sequences */
     float GetProb(XTensor * output, XTensor * gold, XTensor * wordProbs);
@@ -394,7 +394,8 @@ void * XMem::AllocDynamic(int myDevID, MTYPE mySize)
     CheckNTErrors(cudaMemset(mem, 0, b->size + 2 * CUDA_PITCH) == cudaSuccess, "Cannot update the memory.");
     SetDevice(devIDBackup);
 #else
     ShowNTErrors("Please specify USE_CUDA for compiling this program.");
 #endif
 }