Commit ad40df03 by xiaotong

FREE_ON_THE_FLY mode for transformer

parent c13a8c63
......@@ -52,6 +52,7 @@ void T2TModel::InitModel(int argc, const char ** argv)
{
bool useMem = false;
int memSize = 0;
bool isMemFreeOTF = false;
LoadParamInt(argc, argv, "dev", &devID, -1);
LoadParamBool(argc, argv, "mem", &useMem, useMem);
......@@ -59,10 +60,11 @@ void T2TModel::InitModel(int argc, const char ** argv)
LoadParamBool(argc, argv, "lm", &isLM, true);
LoadParamBool(argc, argv, "mt", &isMT, false);
LoadParamInt(argc, argv, "nhead", &nhead, 8);
LoadParamBool(argc, argv, "freeotf", &isMemFreeOTF, false);
if(useMem){
delete mem;
mem = new XMem(devID, UNI_FREE, (MTYPE)MILLION * 256, 1024, MILLION * 128);
mem = new XMem(devID, isMemFreeOTF ? FREE_ON_THE_FLY : UNI_FREE, (MTYPE)MILLION * 256, 1024, MILLION * 128);
mem->SetDesiredSize(devID, 0, (MTYPE)memSize * MILLION);
}
......
......@@ -113,7 +113,7 @@ void T2TTrainer::Train(const char * fn, T2TModel * model)
int devID = model->devID;
XMem * mem = model->mem;
if(mem != NULL)
if(mem != NULL && mem->mode == UNI_FREE)
mem->SetPin();
XNet net;
......@@ -182,7 +182,7 @@ void T2TTrainer::Train(const char * fn, T2TModel * model)
lr, elapsed, step, epoch + 1, wordCountTotal, exp(loss / wordCount), exp(-prob/wc));
}
if(mem != NULL)
if(mem != NULL && mem->mode == UNI_FREE)
mem->BackToPin();
}
......@@ -192,7 +192,7 @@ void T2TTrainer::Train(const char * fn, T2TModel * model)
break;
}
if(mem != NULL)
if(mem != NULL && mem->mode == UNI_FREE)
mem->BackToPin();
double elapsed = GetClockSec() - startT;
......@@ -227,9 +227,6 @@ void T2TTrainer::Test(const char * fn, const char * ofn, T2TModel * model)
int devID = model->devID;
XMem * mem = model->mem;
if(mem != NULL)
mem->SetPin();
XNet net;
tf = fopen("tmp.xx.txt", "wb");
......@@ -239,7 +236,7 @@ void T2TTrainer::Test(const char * fn, const char * ofn, T2TModel * model)
wordCount = 0;
if(mem != NULL)
if(mem != NULL && mem->mode == UNI_FREE)
mem->BackToPin();
/* batch of input sequences */
......@@ -304,11 +301,11 @@ void T2TTrainer::Test(const char * fn, const char * ofn, T2TModel * model)
wordCount += wc;
wordCountTotal += wc;
if(mem != NULL)
if(mem != NULL && mem->mode == UNI_FREE)
mem->BackToPin();
}
if(mem != NULL)
if(mem != NULL && mem->mode == UNI_FREE)
mem->BackToPin();
fclose(file);
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论