Commit 4040dde0 by liyinqiao

Clean the local memory codes in Transformer.

parent fdce3ca3
......@@ -51,14 +51,12 @@ initialize the model
>> myIgnored - number of position ignored in attention (from the begining)
>> myIsMasked - indicates whether the attention is with a mask
>> myDevID - device id
>> myMem - the memory pool
*/
void T2TAttention::InitModel(int argc, char ** argv,
bool myIsMasked, int myIgnored,
int myDevID, XMem * myMem)
int myDevID)
{
devID = myDevID;
mem = myMem;
isMasked = myIsMasked;
ignored = myIgnored;
......
......@@ -42,9 +42,6 @@ public:
/* device id */
int devID;
/* memory pool */
XMem * mem;
/* head number */
int nhead;
......@@ -94,7 +91,7 @@ public:
/* initialize the model */
void InitModel(int argc, char ** argv,
bool myIsMasked, int myIgnored,
int myDevID = -1, XMem * myMem = NULL);
int myDevID = -1);
/* make the network */
XTensor Make(XTensor &k, XTensor &q, XTensor &v, XTensor &mask, bool isTraining);
......
......@@ -280,7 +280,6 @@ load a batch of sequences
>> isSorted - indicates whether the sequences are sorted by length
>> wCount - word count
>> devID - device id
>> mem - memory pool
>> isTraining - indicates whether we are training the model
*/
int T2TBatchLoader::LoadBatch(FILE * file, bool isLM,
......@@ -290,18 +289,17 @@ int T2TBatchLoader::LoadBatch(FILE * file, bool isLM,
int * seqs,
int vsEnc, int vsDec, int sBatch, int wBatch,
bool isSorted, int &ws, int &wCount,
int devID, XMem * mem,
bool isTraining)
int devID, bool isTraining)
{
if(isLM){
return LoadBatchLM(file, batchEnc, paddingEnc, batchDec, paddingDec, gold, label,
seqs, vsEnc, sBatch, wBatch,
isSorted, wCount, devID, mem, isTraining);
isSorted, wCount, devID, isTraining);
}
else{
return LoadBatchMT(file, batchEnc, paddingEnc, batchDec, paddingDec, gold, label,
seqs, vsEnc, vsDec, sBatch, wBatch,
isSorted, ws, wCount, devID, mem, isTraining);
isSorted, ws, wCount, devID, isTraining);
}
}
......@@ -322,7 +320,6 @@ load a batch of sequences (for LM)
>> isSorted - indicates whether the sequences are sorted by length
>> wCount - word count
>> devID - device id
>> mem - memory pool
>> isTraining - indicates whether we are training the model
*/
int T2TBatchLoader::LoadBatchLM(FILE * file,
......@@ -332,8 +329,7 @@ int T2TBatchLoader::LoadBatchLM(FILE * file,
int * seqs,
int vSize, int sBatch, int wBatch,
bool isSorted, int &wCount,
int devID, XMem * mem,
bool isTraining)
int devID, bool isTraining)
{
if(nextSeq < 0 || nextSeq >= nseqBuf)
LoadBuf(file, isSorted, 1);
......@@ -481,7 +477,6 @@ load a batch of sequences (for MT)
>> isSorted - indicates whether the sequences are sorted by length
>> wCount - word count
>> devID - device id
>> mem - memory pool
>> isTraining - indicates whether we are training the model
*/
int T2TBatchLoader::LoadBatchMT(FILE * file,
......@@ -491,8 +486,7 @@ int T2TBatchLoader::LoadBatchMT(FILE * file,
int * seqs,
int vSizeEnc, int vSizeDec, int sBatch, int wBatch,
bool isSorted, int &ws, int &wCount,
int devID, XMem * mem,
bool isTraining)
int devID, bool isTraining)
{
if (nextBatch < 0 || nextBatch >= bufBatchSize) {
LoadBuf(file, isSorted, 2);
......
......@@ -131,8 +131,7 @@ public:
int * seqs,
int vsEnc, int vsDec, int sBatch, int wBatch,
bool isSorted, int &ws, int &wCount,
int devID, XMem * mem,
bool isTraining);
int devID, bool isTraining);
/* load a batch of sequences (for language modeling) */
int LoadBatchLM(FILE * file,
......@@ -141,8 +140,7 @@ public:
XTensor * gold, XTensor * label,
int * seqs, int vs, int sBatch, int wBatch,
bool isSorted, int &wCount,
int devID, XMem * mem,
bool isTraining);
int devID, bool isTraining);
/* load a batch of sequences (for machine translation) */
int LoadBatchMT(FILE * file,
......@@ -151,8 +149,7 @@ public:
XTensor * gold, XTensor * label,
int * seqs, int vsEnc, int vsDec, int sBatch, int wBatch,
bool isSorted, int &ws, int &wCount,
int devID, XMem * mem,
bool isTraining);
int devID, bool isTraining);
/* shuffle the data file */
void Shuffle(const char * srcFile, const char * tgtFile);
......
......@@ -57,16 +57,14 @@ initialize the model
>> myIsMasked - indicates whether the masked attention is employed
>> myIgnored - number of positions ignored in attention (from the start)
>> myDevID - device id
>> myMem - the memory pool
*/
void AttDecoder::InitModel(int argc, char ** argv,
bool myIsMasked, int myIgnored,
int myDevID, XMem * myMem)
int myDevID)
{
//AttEncoder::InitModel(argc, argv, myIsMasked, myIgnored, myDevID, myMem);
//AttEncoder::InitModel(argc, argv, myIsMasked, myIgnored, myDevID);
devID = myDevID;
mem = myMem;
ignored = myIgnored;
LoadParamInt(argc, argv, "nlayer", &nlayer, 6);
......@@ -79,7 +77,7 @@ void AttDecoder::InitModel(int argc, char ** argv,
CheckNTErrors(vSize > 1, "set vocabulary size by \"-vsizetgt\"");
/* embedding model */
embedder.InitModel(argc, argv, devID, mem, false);
embedder.InitModel(argc, argv, devID, false);
attentions = new T2TAttention[nlayer];
fnns = new T2TFNN[nlayer];
......@@ -90,12 +88,12 @@ void AttDecoder::InitModel(int argc, char ** argv,
/* initialize the stacked layers */
for (int i = 0; i < nlayer; i++) {
attentions[i].InitModel(argc, argv, myIsMasked, myIgnored, myDevID, myMem);
fnns[i].InitModel(argc, argv, myDevID, myMem);
attLayerNorms[i].InitModel(argc, argv, myDevID, myMem);
fnnLayerNorms[i].InitModel(argc, argv, myDevID, myMem);
attentionsEnde[i].InitModel(argc, argv, true, myIgnored, myDevID, myMem);
attEndeLayerNorms[i].InitModel(argc, argv, myDevID, myMem);
attentions[i].InitModel(argc, argv, myIsMasked, myIgnored, myDevID);
fnns[i].InitModel(argc, argv, myDevID);
attLayerNorms[i].InitModel(argc, argv, myDevID);
fnnLayerNorms[i].InitModel(argc, argv, myDevID);
attentionsEnde[i].InitModel(argc, argv, true, myIgnored, myDevID);
attEndeLayerNorms[i].InitModel(argc, argv, myDevID);
}
}
......
......@@ -37,9 +37,6 @@ public:
/* device id */
int devID;
/* memory pool */
XMem * mem;
/* layer number */
int nlayer;
......@@ -95,7 +92,7 @@ public:
/* initialize the model */
void InitModel(int argc, char ** argv,
bool myIsMasked, int myIgnored,
int myDevID = -1, XMem * myMem = NULL);
int myDevID = -1);
/* make the decoding network */
XTensor Make(XTensor &inputDec, XTensor &outputEnc, XTensor &mask, XTensor &maskEncDec, bool isTraining);
......
......@@ -31,7 +31,6 @@ namespace transformer
T2TEmbedder::T2TEmbedder()
{
devID = -1;
mem = NULL;
vSize = -1;
maxLength = -1;
}
......@@ -46,12 +45,10 @@ initialize the model
>> argc - number of arguments
>> argv - list of pointers to the arguments
>> myDevID - device id
>> myMem - the memory pool
*/
void T2TEmbedder::InitModel(int argc, char ** argv, int myDevID, XMem * myMem, bool isEnc)
void T2TEmbedder::InitModel(int argc, char ** argv, int myDevID, bool isEnc)
{
devID = myDevID;
mem = myMem;
if(isEnc){
LoadParamInt(argc, argv, "vsize", &vSize, -1);
......
......@@ -41,9 +41,6 @@ public:
/* device id */
int devID;
/* memory pool */
XMem * mem;
/* vocabulary size */
int vSize;
......@@ -71,7 +68,7 @@ public:
~T2TEmbedder();
/* initialize the model */
void InitModel(int argc, char ** argv, int myDevID = -1, XMem * myMem = NULL, bool isEnc = true);
void InitModel(int argc, char ** argv, int myDevID = -1, bool isEnc = true);
/* make positional embeddings */
void MakePosEmbedding(int eSize, int d, int length);
......
......@@ -52,15 +52,12 @@ initialize the model
>> argv - list of pointers to the arguments
>> myIsMasked - indicates whether the masked attention is employed
>> myIgnored - number of positions ignored in attention (from the start)
>> myDevID - device id
>> myMem - the memory pool
*/
>> myDevID - device id*/
void AttEncoder::InitModel(int argc, char ** argv,
bool myIsMasked, int myIgnored,
int myDevID, XMem * myMem)
int myDevID)
{
devID = myDevID;
mem = myMem;
ignored = myIgnored;
LoadParamInt(argc, argv, "nlayer", &nlayer, 6);
......@@ -73,7 +70,7 @@ void AttEncoder::InitModel(int argc, char ** argv,
CheckNTErrors(vSize > 1, "set vocabulary size by \"-vsize\"");
/* embedding model */
embedder.InitModel(argc, argv, devID, mem);
embedder.InitModel(argc, argv, devID);
attentions = new T2TAttention[nlayer];
fnns = new T2TFNN[nlayer];
......@@ -82,10 +79,10 @@ void AttEncoder::InitModel(int argc, char ** argv,
/* initialize the stacked layers */
for(int i = 0; i < nlayer; i++){
attentions[i].InitModel(argc, argv, myIsMasked, myIgnored, myDevID, myMem);
fnns[i].InitModel(argc, argv, myDevID, myMem);
attLayerNorms[i].InitModel(argc, argv, myDevID, myMem);
fnnLayerNorms[i].InitModel(argc, argv, myDevID, myMem);
attentions[i].InitModel(argc, argv, myIsMasked, myIgnored, myDevID);
fnns[i].InitModel(argc, argv, myDevID);
attLayerNorms[i].InitModel(argc, argv, myDevID);
fnnLayerNorms[i].InitModel(argc, argv, myDevID);
}
}
......
......@@ -65,9 +65,6 @@ public:
/* device id */
int devID;
/* memory pool */
XMem * mem;
/* layer number */
int nlayer;
......@@ -118,7 +115,7 @@ public:
/* initialize the model */
void InitModel(int argc, char ** argv,
bool myIsMasked, int myIgnored,
int myDevID = -1, XMem * myMem = NULL);
int myDevID = -1);
/* make the encoding network */
XTensor Make(XTensor &input, XTensor &mask, XTensor &maskEncDec, bool isTraining);
......
......@@ -47,12 +47,10 @@ initialize the model
>> argc - number of arguments
>> argv - list of pointers to the arguments
>> myDevID - device id
>> myMem - the memory pool
*/
void T2TFNN::InitModel(int argc, char ** argv, int myDevID, XMem * myMem)
void T2TFNN::InitModel(int argc, char ** argv, int myDevID)
{
devID = myDevID;
mem = myMem;
float minmax = 0;
......@@ -63,10 +61,10 @@ void T2TFNN::InitModel(int argc, char ** argv, int myDevID, XMem * myMem)
LoadParamFloat(argc, argv, "dropoutfnn", &dropoutP, 0);
InitTensor2DV2(&w1, inSize, hSize, X_FLOAT, devID);
InitTensor1D(&b1, hSize, X_FLOAT, devID, mem);
InitTensor1DV2(&b1, hSize, X_FLOAT, devID);
InitTensor2DV2(&w2, hSize, outSize, X_FLOAT, devID);
InitTensor1D(&b2, outSize, X_FLOAT, devID, mem);
InitTensor1DV2(&b2, outSize, X_FLOAT, devID);
float scale = 1.0F;
float finfout1 = (float)sqrt(6.0F * scale/(inSize + hSize));
......
......@@ -36,9 +36,6 @@ public:
/* device id */
int devID;
/* memory pool */
XMem * mem;
/* size of input vector */
int inSize;
......@@ -72,7 +69,7 @@ public:
~T2TFNN();
/* initialize the model */
void InitModel(int argc, char ** argv, int myDevID = -1, XMem * myMem = NULL);
void InitModel(int argc, char ** argv, int myDevID = -1);
/* make the network */
XTensor Make(XTensor &input, bool isTraining);
......
......@@ -32,7 +32,6 @@ namespace transformer
T2TLN::T2TLN()
{
devID = -1;
mem = NULL;
d = 0;
}
......@@ -46,12 +45,10 @@ initialize the model
>> argc - number of arguments
>> argv - list of pointers to the arguments
>> myDevID - device id
>> myMem - the memory pool
*/
void T2TLN::InitModel(int argc, char ** argv, int myDevID, XMem * myMem)
void T2TLN::InitModel(int argc, char ** argv, int myDevID)
{
devID = myDevID;
mem = myMem;
d = 0;
LoadParamInt(argc, argv, "d", &d, DEFAULT_EMBEDDING_SIZE);
......
......@@ -37,9 +37,6 @@ public:
/* device id */
int devID;
/* memory pool */
XMem * mem;
/* the transformation matrix w */
XTensor w;
......@@ -57,7 +54,7 @@ public:
~T2TLN();
/* initialize the model */
void InitModel(int argc, char ** argv, int myDevID = -1, XMem * myMem = NULL);
void InitModel(int argc, char ** argv, int myDevID = -1);
/* make the network */
XTensor Make(XTensor &input);
......
......@@ -32,7 +32,6 @@ namespace transformer
T2TModel::T2TModel()
{
devID = -1;
mem = NULL;
isLM = false;
isMT = false;
nhead = 1;
......@@ -48,10 +47,6 @@ T2TModel::~T2TModel()
delete encoder;
delete decoder;
delete outputLayer;
/* we delete "mem" at the end because other members are using it and we must
remove the memory space before all tensors are destroyed. */
delete mem;
}
/*
......@@ -61,29 +56,16 @@ initialize the model
*/
void T2TModel::InitModel(int argc, char ** argv)
{
bool useMem = false;
int memSize = 0;
bool isMemFreeOTF = false;
LoadParamInt(argc, argv, "dev", &devID, -1);
LoadParamBool(argc, argv, "mem", &useMem, useMem);
LoadParamInt(argc, argv, "memsize", &memSize, 1024);
LoadParamBool(argc, argv, "mt", &isMT, false);
LoadParamBool(argc, argv, "lm", &isLM, !isMT);
LoadParamInt(argc, argv, "nhead", &nhead, 8);
LoadParamBool(argc, argv, "freeotf", &isMemFreeOTF, false);
if(useMem){
delete mem;
mem = new XMem(devID, FREE_ON_THE_FLY, (MTYPE)MILLION * 256, 1024, MILLION * 128);
mem->SetDesiredSize(devID, 0, (MTYPE)memSize * MILLION);
}
encoder->InitModel(argc, argv, true, 0, devID, mem);
outputLayer->InitModel(argc, argv, devID, mem);
encoder->InitModel(argc, argv, true, 0, devID);
outputLayer->InitModel(argc, argv, devID);
if(isMT)
decoder->InitModel(argc, argv, true, 0, devID, mem);
decoder->InitModel(argc, argv, true, 0, devID);
TensorList params(10);
GetParams(params);
......@@ -149,7 +131,8 @@ void T2TModel::MakeLM(XTensor &input, XTensor &output, XTensor &padding, bool is
dims[i + 1] = input.GetDim(i);
dims[0] = nhead;
dims[input.order + 1] = len;
XTensor mask(input.order + 2, dims, X_FLOAT, 1.0F, padding.devID, padding.mem);
XTensor mask;
InitTensorV2(&mask, input.order + 2, dims, X_FLOAT, padding.devID);
/* a upper triangular matrix where the cells of the upper triangular are set to -1e-9.
this matrix can be used to prevent the attention to current or following words in
......
......@@ -40,9 +40,6 @@ public:
/* device id */
int devID;
/* memory pool */
XMem * mem;
/* the encoder */
AttEncoder * encoder;
......
......@@ -31,7 +31,6 @@ namespace transformer
T2TOutput::T2TOutput()
{
devID = -1;
mem = NULL;
vSize = -1;
inSize = -1;
hSize = -1;
......@@ -47,12 +46,10 @@ initialize the model
>> argc - number of arguments
>> argv - list of pointers to the arguments
>> myDevID - device id
>> myMem - the memory pool
*/
void T2TOutput::InitModel(int argc, char ** argv, int myDevID, XMem * myMem)
void T2TOutput::InitModel(int argc, char ** argv, int myDevID)
{
devID = myDevID;
mem = myMem;
float minmax = 0;
......
......@@ -38,9 +38,6 @@ public:
/* device id */
int devID;
/* memory pool */
XMem * mem;
/* vocabulary size */
int vSize;
......@@ -61,7 +58,7 @@ public:
~T2TOutput();
/* initialize the model */
void InitModel(int argc, char ** argv, int myDevID = -1, XMem * myMem = NULL);
void InitModel(int argc, char ** argv, int myDevID = -1);
/* make the network */
XTensor Make(XTensor &input);
......
......@@ -179,7 +179,7 @@ void T2TPredictor::Predict(T2TStateBundle * next, XTensor * encoding,
}
else{
inputDec = GeneratePaths(s);
inputDec.SetDevice(inputEnc->devID, inputEnc->mem);
inputDec.SetDevice(inputEnc->devID);
inputDec = Concatenate(first, inputDec, inputDec.order - 1);
}
......@@ -219,8 +219,8 @@ void T2TPredictor::Predict(T2TStateBundle * next, XTensor * encoding,
selectSrc.SetInt(stride - 1, 0);
selectTgt.SetInt(0, 0);
selectSrc.SetDevice(decoding.devID, decoding.mem);
selectTgt.SetDevice(decoding.devID, decoding.mem);
selectSrc.SetDevice(decoding.devID);
selectTgt.SetDevice(decoding.devID);
/* the decoder output of the last position */
decodingStep = CopyIndexed(decoding, decoding.order - 2, selectSrc, selectTgt);
......
......@@ -595,7 +595,7 @@ XTensor T2TSearch::MakeFirstMask(T2TStateBundle * beam)
mask.Set(-1e9, i);
}
mask.SetDevice(prob.devID, prob.mem);
mask.SetDevice(prob.devID);
return mask;
}
......
......@@ -75,7 +75,6 @@ void T2TTester::Test(const char * fn, const char * ofn, T2TModel * model)
CheckNTErrors(ofile, "Cannot open the output file");
int devID = model->devID;
XMem * mem = model->mem;
XNet net;
......@@ -106,7 +105,7 @@ void T2TTester::Test(const char * fn, const char * ofn, T2TModel * model)
while(batchLoader.LoadBatch(file, model->isLM,
&batchEnc, &paddingEnc, &paddingDec, &paddingDec, &gold, &label,
seqs, vSize, vSizeTgt,
1, 1, false, ws, wc, devID, mem, false))
1, 1, false, ws, wc, devID, false))
{
CheckNTErrors(batchEnc.order == 2, "wrong tensor order of the sequence batch!");
CheckNTErrors(!model->isLM, "Only MT model is supported!");
......
......@@ -75,9 +75,6 @@ void T2TTrainer::Init(int argc, char ** argv)
strcpy(argArray[i], argv[i]);
}
bool useMem = false;
LoadParamBool(argc, argv, "mem", &useMem, useMem);
LoadParamFloat(argc, argv, "lrate", &lrate, 1.0F);
LoadParamFloat(argc, argv, "lrbias", &lrbias, 0);
LoadParamInt(argc, argv, "sbatch", &sBatchSize, 1);
......@@ -142,7 +139,6 @@ void T2TTrainer::Train(const char * fn, const char * validFN, const char * model
#endif
int devID = model->devID;
XMem * mem = model->mem;
XNet net;
if(isDebugged)
......@@ -184,7 +180,7 @@ void T2TTrainer::Train(const char * fn, const char * validFN, const char * model
while (batchLoader.LoadBatch(file, model->isLM,
&batchEnc, &paddingEnc, &batchDec, &paddingDec, &gold, &label,
NULL, vSize, vSizeTgt,
sBatchSize, wBatchSize, isLenSorted, ws, wc, devID, mem, true))
sBatchSize, wBatchSize, isLenSorted, ws, wc, devID, true))
{
CheckNTErrors(batchEnc.order == 2, "wrong tensor order of the sequence batch");
......@@ -321,7 +317,6 @@ void T2TTrainer::Test(const char * fn, const char * ofn, T2TModel * model)
CheckNTErrors(ofile, "Cannot open the output file");
int devID = model->devID;
XMem * mem = model->mem;
XNet net;
......@@ -351,7 +346,7 @@ void T2TTrainer::Test(const char * fn, const char * ofn, T2TModel * model)
while(batchLoader.LoadBatch(file, model->isLM,
&batchEnc, &paddingEnc, &batchDec, &paddingDec, &gold, &label,
seqs, vSize, vSizeTgt,
1, 1, false, ws, wc, devID, mem, false))
1, 1, false, ws, wc, devID, false))
{
CheckNTErrors(batchEnc.order == 2, "wrong tensor order of the sequence batch");
......
......@@ -517,7 +517,7 @@ void XTensor::SetDevice(int myDevId, XMem * myMem)
isInGlobalMem = false;
}
else {
ShowNTErrors("TODO!");
myMem = GMems.GetMem(myDevId);
}
}
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论