Commit ad40df03 by xiaotong

FREE_ON_THE_FLY mode for transformer

parent c13a8c63
...@@ -52,6 +52,7 @@ void T2TModel::InitModel(int argc, const char ** argv) ...@@ -52,6 +52,7 @@ void T2TModel::InitModel(int argc, const char ** argv)
{ {
bool useMem = false; bool useMem = false;
int memSize = 0; int memSize = 0;
bool isMemFreeOTF = false;
LoadParamInt(argc, argv, "dev", &devID, -1); LoadParamInt(argc, argv, "dev", &devID, -1);
LoadParamBool(argc, argv, "mem", &useMem, useMem); LoadParamBool(argc, argv, "mem", &useMem, useMem);
...@@ -59,10 +60,11 @@ void T2TModel::InitModel(int argc, const char ** argv) ...@@ -59,10 +60,11 @@ void T2TModel::InitModel(int argc, const char ** argv)
LoadParamBool(argc, argv, "lm", &isLM, true); LoadParamBool(argc, argv, "lm", &isLM, true);
LoadParamBool(argc, argv, "mt", &isMT, false); LoadParamBool(argc, argv, "mt", &isMT, false);
LoadParamInt(argc, argv, "nhead", &nhead, 8); LoadParamInt(argc, argv, "nhead", &nhead, 8);
LoadParamBool(argc, argv, "freeotf", &isMemFreeOTF, false);
if(useMem){ if(useMem){
delete mem; delete mem;
mem = new XMem(devID, UNI_FREE, (MTYPE)MILLION * 256, 1024, MILLION * 128); mem = new XMem(devID, isMemFreeOTF ? FREE_ON_THE_FLY : UNI_FREE, (MTYPE)MILLION * 256, 1024, MILLION * 128);
mem->SetDesiredSize(devID, 0, (MTYPE)memSize * MILLION); mem->SetDesiredSize(devID, 0, (MTYPE)memSize * MILLION);
} }
......
...@@ -113,7 +113,7 @@ void T2TTrainer::Train(const char * fn, T2TModel * model) ...@@ -113,7 +113,7 @@ void T2TTrainer::Train(const char * fn, T2TModel * model)
int devID = model->devID; int devID = model->devID;
XMem * mem = model->mem; XMem * mem = model->mem;
if(mem != NULL) if(mem != NULL && mem->mode == UNI_FREE)
mem->SetPin(); mem->SetPin();
XNet net; XNet net;
...@@ -182,7 +182,7 @@ void T2TTrainer::Train(const char * fn, T2TModel * model) ...@@ -182,7 +182,7 @@ void T2TTrainer::Train(const char * fn, T2TModel * model)
lr, elapsed, step, epoch + 1, wordCountTotal, exp(loss / wordCount), exp(-prob/wc)); lr, elapsed, step, epoch + 1, wordCountTotal, exp(loss / wordCount), exp(-prob/wc));
} }
if(mem != NULL) if(mem != NULL && mem->mode == UNI_FREE)
mem->BackToPin(); mem->BackToPin();
} }
...@@ -192,7 +192,7 @@ void T2TTrainer::Train(const char * fn, T2TModel * model) ...@@ -192,7 +192,7 @@ void T2TTrainer::Train(const char * fn, T2TModel * model)
break; break;
} }
if(mem != NULL) if(mem != NULL && mem->mode == UNI_FREE)
mem->BackToPin(); mem->BackToPin();
double elapsed = GetClockSec() - startT; double elapsed = GetClockSec() - startT;
...@@ -227,9 +227,6 @@ void T2TTrainer::Test(const char * fn, const char * ofn, T2TModel * model) ...@@ -227,9 +227,6 @@ void T2TTrainer::Test(const char * fn, const char * ofn, T2TModel * model)
int devID = model->devID; int devID = model->devID;
XMem * mem = model->mem; XMem * mem = model->mem;
if(mem != NULL)
mem->SetPin();
XNet net; XNet net;
tf = fopen("tmp.xx.txt", "wb"); tf = fopen("tmp.xx.txt", "wb");
...@@ -239,7 +236,7 @@ void T2TTrainer::Test(const char * fn, const char * ofn, T2TModel * model) ...@@ -239,7 +236,7 @@ void T2TTrainer::Test(const char * fn, const char * ofn, T2TModel * model)
wordCount = 0; wordCount = 0;
if(mem != NULL) if(mem != NULL && mem->mode == UNI_FREE)
mem->BackToPin(); mem->BackToPin();
/* batch of input sequences */ /* batch of input sequences */
...@@ -304,11 +301,11 @@ void T2TTrainer::Test(const char * fn, const char * ofn, T2TModel * model) ...@@ -304,11 +301,11 @@ void T2TTrainer::Test(const char * fn, const char * ofn, T2TModel * model)
wordCount += wc; wordCount += wc;
wordCountTotal += wc; wordCountTotal += wc;
if(mem != NULL) if(mem != NULL && mem->mode == UNI_FREE)
mem->BackToPin(); mem->BackToPin();
} }
if(mem != NULL) if(mem != NULL && mem->mode == UNI_FREE)
mem->BackToPin(); mem->BackToPin();
fclose(file); fclose(file);
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论