Commit fdce3ca3 by liyinqiao

Clean up the local memory-pool code in FNNLM.
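In short: the XMem memory-pool handle is removed from the FNNLM helper signatures and from the model setup, so batch tensors are now created directly on the target device. A minimal before/after sketch of the call shape, using only names that appear in the hunks below:

    /* before: callers threaded the model's memory pool through every helper */
    MakeWordBatch(inputs[i], ngrams, ngramNum, i, model.vSize, model.devID, model.mem);

    /* after: the XMem * argument is gone; only the device id is passed */
    MakeWordBatch(inputs[i], ngrams, ngramNum, i, model.vSize, model.devID);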

parent 2921f79d
@@ -68,8 +68,8 @@ void Read(const char * fn, FNNModel &model);
 void Test(const char * test, const char * result, FNNModel &model);
 int LoadNGrams(FILE * file, int n, NGram * ngrams, int sentNum, int wordNum);
 void InitZeroOneTensor2D(XTensor &tensor, int rowNum, int colNum, int * rows, int * cols,
-                         int itemNum, int devID, XMem * mem);
-void MakeWordBatch(XTensor &batch, NGram * ngrams, int ngramNum, int n, int vSize, int devID, XMem * mem);
+                         int itemNum, int devID);
+void MakeWordBatch(XTensor &batch, NGram * ngrams, int ngramNum, int n, int vSize, int devID);
 void Forward(XTensor inputs[], XTensor &output, FNNModel &model, FNNNet &net);
 void Backward(XTensor inputs[], XTensor &output, XTensor &gold, LOSS_FUNCTION_NAME loss,
               FNNModel &model, FNNModel &grad, FNNNet &net);
@@ -229,11 +229,6 @@ void LoadArgs(int argc, const char ** argv, FNNModel &model)
             fprintf(stderr, " -dev=%d\n", model.devID);
         }
     }
-
-    for(int i = 0; i < argc; i++){
-        if (!strcmp(argv[i], "-mempool"))
-            model.mem = new XMem(model.devID);
-    }
 }
 
 /* check model settings */
@@ -262,11 +257,6 @@ void Copy(FNNModel &tgt, FNNModel &src)
     tgt.vSize = src.vSize;
     tgt.devID = src.devID;
     tgt.useMemPool = src.useMemPool;
-    if(src.mem != NULL){
-        tgt.mem = new XMem(src.mem->devID, src.mem->mode,
-                           src.mem->maxBlockSize, src.mem->blockNum,
-                           src.mem->bufSize);
-    }
 }
 
 /*
@@ -459,10 +449,10 @@ void Train(const char * train, bool isShuffled, FNNModel &model)
         /* make the input tensor for position i */
         for(int i = 0; i < model.n - 1; i++)
-            MakeWordBatch(inputs[i], ngrams, ngramNum, i, model.vSize, model.devID, model.mem);
+            MakeWordBatch(inputs[i], ngrams, ngramNum, i, model.vSize, model.devID);
 
         /* make the gold tensor */
-        MakeWordBatch(gold, ngrams, ngramNum, model.n - 1, model.vSize, model.devID, model.mem);
+        MakeWordBatch(gold, ngrams, ngramNum, model.n - 1, model.vSize, model.devID);
 
         if(!autoDiff){
             /* prepare an empty network for building the fnn */
@@ -474,8 +464,6 @@ void Train(const char * train, bool isShuffled, FNNModel &model)
             /* forward computation */
            Forward(inputs, output, model, net);
 
             /* backward computation to obtain gradients */
             Backward(inputs, output, gold, CROSSENTROPY, model, grad, net);
@@ -727,10 +715,9 @@ The indexed cell is set to 1, and 0 otherwise.
 >> cols - column index
 >> itemNum - number of non-zero items
 >> devID - device id
->> mem - memory pool
 */
 void InitZeroOneTensor2D(XTensor &tensor, int rowNum, int colNum, int * rows, int * cols,
-                         int itemNum, int devID, XMem * mem)
+                         int itemNum, int devID)
 {
     InitTensor2DV2(&tensor, rowNum, colNum, X_FLOAT, devID);
@@ -749,9 +736,8 @@ make a tensor that encodes a batch of words
 >> n - indicates which word is encoded for each ngram
 >> vSize - vocabulary size
 >> devID - device id
->> mem - memory pool
 */
-void MakeWordBatch(XTensor &batch, NGram * ngrams, int ngramNum, int n, int vSize, int devID, XMem * mem)
+void MakeWordBatch(XTensor &batch, NGram * ngrams, int ngramNum, int n, int vSize, int devID)
 {
     int * rows = new int[ngramNum];
     int * cols = new int[ngramNum];
@@ -761,7 +747,7 @@ void MakeWordBatch(XTensor &batch, NGram * ngrams, int ngramNum, int n, int vSize, int devID, XMem * mem)
         cols[i] = ngrams[i].words[n];
     }
 
-    InitZeroOneTensor2D(batch, ngramNum, vSize, rows, cols, ngramNum, devID, mem);
+    InitZeroOneTensor2D(batch, ngramNum, vSize, rows, cols, ngramNum, devID);
 
     delete[] rows;
     delete[] cols;
@@ -1170,10 +1156,10 @@ void Test(const char * test, const char * result, FNNModel &model)
         /* make the input tensor for position i */
         for (int i = 0; i < model.n - 1; i++)
-            MakeWordBatch(inputs[i], ngrams, ngramNum, i, model.vSize, model.devID, model.mem);
+            MakeWordBatch(inputs[i], ngrams, ngramNum, i, model.vSize, model.devID);
 
         /* make the gold tensor */
-        MakeWordBatch(gold, ngrams, ngramNum, model.n - 1, model.vSize, model.devID, model.mem);
+        MakeWordBatch(gold, ngrams, ngramNum, model.n - 1, model.vSize, model.devID);
 
         if (!autoDiff) {
             /* prepare an empty network for building the fnn */
......
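For reference, the pool-free InitZeroOneTensor2D ends up shaped roughly as below. The InitTensor2DV2 call and the one-hot semantics ("the indexed cell is set to 1, and 0 otherwise") are taken from the diff; the SetZeroAll/Set2D body is a sketch of the usual XTensor idiom, not a verbatim copy of the rest of the function:

    void InitZeroOneTensor2D(XTensor &tensor, int rowNum, int colNum,
                             int * rows, int * cols, int itemNum, int devID)
    {
        /* allocate the tensor directly on the target device; no XMem pool */
        InitTensor2DV2(&tensor, rowNum, colNum, X_FLOAT, devID);

        /* zero everything, then set each indexed (row, col) cell to 1 */
        tensor.SetZeroAll();
        for (int i = 0; i < itemNum; i++)
            tensor.Set2D(1.0F, rows[i], cols[i]);
    }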