Commit 47d01713 by xiaotong

new code of mem

parent e1630c28
@@ -31,8 +31,6 @@ namespace transformer
 /* constructor */
 T2TTrainer::T2TTrainer()
 {
-    devID = -1;
-    mem = NULL;
     seqLen = NULL;
     nseqBuf = 0;
     nextSeq = -1;
@@ -56,7 +54,6 @@ void T2TTrainer::Init(int argc, const char ** argv)
     bool useMem = false;
     LoadParamBool(argc, argv, "mem", &useMem, useMem);
-    LoadParamInt(argc, argv, "dev", &devID, -1);
     LoadParamFloat(argc, argv, "lrate", &lrate, 0.001F);
     LoadParamInt(argc, argv, "sbatch", &sBatchSize, 1);
     LoadParamInt(argc, argv, "wbatch", &wBatchSize, 1);
@@ -72,10 +69,6 @@ void T2TTrainer::Init(int argc, const char ** argv)
     seqLen = new int[bufSize];
     seqOffset = new int[bufSize];
-
-    if(useMem){
-        delete mem;
-        mem = new XMem(devID, UNI_FREE, MILLION * 64, 1024, MILLION * 64);
-    }
 }
 
 /*
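The net effect of the Init() hunks above: T2TTrainer no longer owns a device id or a memory pool; both now come from T2TModel at training time. A minimal before/after sketch of the ownership change (the XMem constructor arguments are taken verbatim from the removed code):

```cpp
/* before: the trainer allocated its own pool from its own "dev" option */
/* (devID and mem were T2TTrainer members; both are removed here)       */
if(useMem){
    delete mem;
    mem = new XMem(devID, UNI_FREE, MILLION * 64, 1024, MILLION * 64);
}

/* after: the trainer borrows the model's device and pool, so model
   parameters and training tensors live in one pool on one device */
int devID = model->devID;
XMem * mem = model->mem;
```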
@@ -94,7 +87,8 @@ void T2TTrainer::Train(const char * fn, T2TModel * model)
     float loss = 0;
     float lr = 0;
-    model->mem->SetPin();
+    int devID = model->devID;
+    XMem * mem = model->mem;
     mem->SetPin();
 
     XNet net;
@@ -108,7 +102,6 @@ void T2TTrainer::Train(const char * fn, T2TModel * model)
     wordCount = 0;
-    model->mem->BackToPin();
     mem->BackToPin();
 
     /* batch of input sequences */
@@ -117,7 +110,7 @@ void T2TTrainer::Train(const char * fn, T2TModel * model)
     /* padding */
     XTensor padding;
-    while(LoadBatch(file, &batch, &padding, 1, vSize, sBatchSize, wBatchSize, isLenSorted, wc)){
+    while(LoadBatch(file, &batch, &padding, 1, vSize, sBatchSize, wBatchSize, isLenSorted, wc, devID, mem)){
         /* output probabilities */
         XTensor output;
@@ -157,7 +150,6 @@ void T2TTrainer::Train(const char * fn, T2TModel * model)
                lr, elapsed, step, epoch + 1, wordCountTotal, exp(loss / wordCount));
     }
-    model->mem->BackToPin();
     mem->BackToPin();
 }
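The SetPin()/BackToPin() pairs in Train() implement a rollback pattern. Assuming SetPin() bookmarks the pool's current allocation position and BackToPin() releases everything allocated after that bookmark (an assumption based on the names and the call sites above), each pass over the data can reuse one pool without leaking per-batch tensors. A sketch under that assumption (`nBatches` is a hypothetical loop bound):

```cpp
XMem * mem = model->mem;       /* single pool shared with the model   */
mem->SetPin();                 /* bookmark the pool's current position */

for(int i = 0; i < nBatches; i++){
    /* batch, padding, output, ... are allocated from `mem` here */

    mem->BackToPin();          /* free everything allocated after the
                                  bookmark; tensors created before it
                                  (e.g. model parameters) survive */
}
```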
@@ -261,10 +253,13 @@ load a batch of sequences
 >> wBatch - batch size of words
 >> isSorted - indicates whether the sequences are sorted by length
 >> wCount - word count
+>> devID - device id
+>> mem - memory pool
 */
 int T2TTrainer::LoadBatch(FILE * file, XTensor * batch, XTensor * padding,
                           int step, int vs, int sBatch, int wBatch,
-                          bool isSorted, int &wCount)
+                          bool isSorted, int &wCount,
+                          int devID, XMem * mem)
 {
     if(nextSeq < 0 || nextSeq >= nseqBuf)
         LoadBuf(file);
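A hypothetical call with the extended signature, as it would appear inside Train(), assuming `file`, the batch-size settings, and an initialized model are in scope. The last two arguments are the new ones; they are what LoadBatch forwards to InitTensor/InitTensor2D below:

```cpp
int wc = 0;            /* filled with the number of words loaded */
XTensor batch;         /* one-hot input sequences                */
XTensor padding;       /* padding mask, shape (sc, max)          */

/* create the batch tensors on the model's device and memory pool */
while(LoadBatch(file, &batch, &padding, 1, vSize,
                sBatchSize, wBatchSize, isLenSorted, wc,
                model->devID, model->mem))
{
    /* ... run the network on `batch`, masked by `padding` ... */
}
```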
@@ -295,15 +290,15 @@ int T2TTrainer::LoadBatch(FILE * file, XTensor * batch, XTensor * padding,
     dims[1] = max;
     dims[2] = vs;
 
-    if(batch->order != 3 || batch->GetDim(0) != dims[0] ||
-       batch->GetDim(1) != dims[1] || batch->GetDim(2) != dims[2]){
+    //if(batch->order != 3 || batch->GetDim(0) != dims[0] ||
+    //   batch->GetDim(1) != dims[1] || batch->GetDim(2) != dims[2]){
         InitTensor(batch, 3, dims, X_FLOAT, 1.0F, devID, mem);
-    }
+    //}
 
-    if(padding->order != 2 || padding->GetDim(0) != sc ||
-       padding->GetDim(1) != max){
+    //if(padding->order != 2 || padding->GetDim(0) != sc ||
+    //   padding->GetDim(1) != max){
         InitTensor2D(padding, sc, max, X_FLOAT, devID, mem);
-    }
+    //}
 
     batch->SetZeroAll();
     padding->SetZeroAll();
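Commenting out the two guards changes behaviour: previously the tensors were only reinitialized when their shape no longer matched the incoming batch, so consecutive same-shape batches reused the existing storage; now InitTensor/InitTensor2D run for every batch. The disabled test, rewritten as a hypothetical helper for clarity:

```cpp
/* hypothetical helper: does a 3-D tensor already match the wanted shape? */
bool Matches3D(XTensor * t, const int dims[3])
{
    return t->order == 3 &&
           t->GetDim(0) == dims[0] &&
           t->GetDim(1) == dims[1] &&
           t->GetDim(2) == dims[2];
}

/* old: reuse storage when possible                              */
/*   if(!Matches3D(batch, dims))                                 */
/*       InitTensor(batch, 3, dims, X_FLOAT, 1.0F, devID, mem);  */
/* new: always reinitialize; correct, but never reuses storage   */
```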
......
@@ -37,12 +37,6 @@ namespace transformer
 class T2TTrainer
 {
 public:
-    /* device id */
-    int devID;
-
-    /* memory pool */
-    XMem * mem;
-
     /* buffer for loading words */
     int * buf;
@@ -107,7 +101,8 @@ public:
     /* load a batch of sequences */
     int LoadBatch(FILE * file, XTensor * batch, XTensor * padding,
                   int step, int vs, int sBatch, int wBatch,
-                  bool isSorted, int &wCount);
+                  bool isSorted, int &wCount,
+                  int devID, XMem * mem);
 
     /* get word probabilities for a batch of sequences */
     float GetProb(XTensor * output, XTensor * gold, XTensor * wordProbs);
......
@@ -394,7 +394,8 @@ void * XMem::AllocDynamic(int myDevID, MTYPE mySize)
     CheckNTErrors(cudaMemset(mem, 0, b->size + 2 * CUDA_PITCH) == cudaSuccess, "Cannot update the memory.");
     SetDevice(devIDBackup);
 #else
-    ShowNTErrors("Please specify USE_CUDA for compiling this program.");
+    ShowNTErrors("Please specify USE_CUDA for compiling this program.");> NiuTrans.Network.exe!nts::XMem::AllocDynamic(int myDevID, unsigned __int64 mySize) 387 C++
 #endif
 }
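One caution on this last hunk: the text appended after the ShowNTErrors(...) call (`NiuTrans.Network.exe!nts::XMem::AllocDynamic(...) 387 C++`) looks like Visual Studio call-stack output pasted into the source by accident. It is not valid C++, so a build without USE_CUDA would now fail to compile. For reference, a sketch of how the guard presumably should read:

```cpp
#ifdef USE_CUDA
    /* CUDA path: zero the freshly allocated block (plus pitch padding) */
    CheckNTErrors(cudaMemset(mem, 0, b->size + 2 * CUDA_PITCH) == cudaSuccess,
                  "Cannot update the memory.");
    SetDevice(devIDBackup);
#else
    /* CPU-only build: dynamic pool allocation is not supported */
    ShowNTErrors("Please specify USE_CUDA for compiling this program.");
#endif
```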
......