Commit 4040dde0 by liyinqiao

Clean the local memory codes in Transformer.

parent fdce3ca3
...@@ -51,14 +51,12 @@ initialize the model ...@@ -51,14 +51,12 @@ initialize the model
>> myIgnored - number of position ignored in attention (from the begining) >> myIgnored - number of position ignored in attention (from the begining)
>> myIsMasked - indicates whether the attention is with a mask >> myIsMasked - indicates whether the attention is with a mask
>> myDevID - device id >> myDevID - device id
>> myMem - the memory pool
*/ */
void T2TAttention::InitModel(int argc, char ** argv, void T2TAttention::InitModel(int argc, char ** argv,
bool myIsMasked, int myIgnored, bool myIsMasked, int myIgnored,
int myDevID, XMem * myMem) int myDevID)
{ {
devID = myDevID; devID = myDevID;
mem = myMem;
isMasked = myIsMasked; isMasked = myIsMasked;
ignored = myIgnored; ignored = myIgnored;
......
...@@ -42,9 +42,6 @@ public: ...@@ -42,9 +42,6 @@ public:
/* device id */ /* device id */
int devID; int devID;
/* memory pool */
XMem * mem;
/* head number */ /* head number */
int nhead; int nhead;
...@@ -94,7 +91,7 @@ public: ...@@ -94,7 +91,7 @@ public:
/* initialize the model */ /* initialize the model */
void InitModel(int argc, char ** argv, void InitModel(int argc, char ** argv,
bool myIsMasked, int myIgnored, bool myIsMasked, int myIgnored,
int myDevID = -1, XMem * myMem = NULL); int myDevID = -1);
/* make the network */ /* make the network */
XTensor Make(XTensor &k, XTensor &q, XTensor &v, XTensor &mask, bool isTraining); XTensor Make(XTensor &k, XTensor &q, XTensor &v, XTensor &mask, bool isTraining);
......
...@@ -280,7 +280,6 @@ load a batch of sequences ...@@ -280,7 +280,6 @@ load a batch of sequences
>> isSorted - indicates whether the sequences are sorted by length >> isSorted - indicates whether the sequences are sorted by length
>> wCount - word count >> wCount - word count
>> devID - device id >> devID - device id
>> mem - memory pool
>> isTraining - indicates whether we are training the model >> isTraining - indicates whether we are training the model
*/ */
int T2TBatchLoader::LoadBatch(FILE * file, bool isLM, int T2TBatchLoader::LoadBatch(FILE * file, bool isLM,
...@@ -290,18 +289,17 @@ int T2TBatchLoader::LoadBatch(FILE * file, bool isLM, ...@@ -290,18 +289,17 @@ int T2TBatchLoader::LoadBatch(FILE * file, bool isLM,
int * seqs, int * seqs,
int vsEnc, int vsDec, int sBatch, int wBatch, int vsEnc, int vsDec, int sBatch, int wBatch,
bool isSorted, int &ws, int &wCount, bool isSorted, int &ws, int &wCount,
int devID, XMem * mem, int devID, bool isTraining)
bool isTraining)
{ {
if(isLM){ if(isLM){
return LoadBatchLM(file, batchEnc, paddingEnc, batchDec, paddingDec, gold, label, return LoadBatchLM(file, batchEnc, paddingEnc, batchDec, paddingDec, gold, label,
seqs, vsEnc, sBatch, wBatch, seqs, vsEnc, sBatch, wBatch,
isSorted, wCount, devID, mem, isTraining); isSorted, wCount, devID, isTraining);
} }
else{ else{
return LoadBatchMT(file, batchEnc, paddingEnc, batchDec, paddingDec, gold, label, return LoadBatchMT(file, batchEnc, paddingEnc, batchDec, paddingDec, gold, label,
seqs, vsEnc, vsDec, sBatch, wBatch, seqs, vsEnc, vsDec, sBatch, wBatch,
isSorted, ws, wCount, devID, mem, isTraining); isSorted, ws, wCount, devID, isTraining);
} }
} }
...@@ -322,7 +320,6 @@ load a batch of sequences (for LM) ...@@ -322,7 +320,6 @@ load a batch of sequences (for LM)
>> isSorted - indicates whether the sequences are sorted by length >> isSorted - indicates whether the sequences are sorted by length
>> wCount - word count >> wCount - word count
>> devID - device id >> devID - device id
>> mem - memory pool
>> isTraining - indicates whether we are training the model >> isTraining - indicates whether we are training the model
*/ */
int T2TBatchLoader::LoadBatchLM(FILE * file, int T2TBatchLoader::LoadBatchLM(FILE * file,
...@@ -332,8 +329,7 @@ int T2TBatchLoader::LoadBatchLM(FILE * file, ...@@ -332,8 +329,7 @@ int T2TBatchLoader::LoadBatchLM(FILE * file,
int * seqs, int * seqs,
int vSize, int sBatch, int wBatch, int vSize, int sBatch, int wBatch,
bool isSorted, int &wCount, bool isSorted, int &wCount,
int devID, XMem * mem, int devID, bool isTraining)
bool isTraining)
{ {
if(nextSeq < 0 || nextSeq >= nseqBuf) if(nextSeq < 0 || nextSeq >= nseqBuf)
LoadBuf(file, isSorted, 1); LoadBuf(file, isSorted, 1);
...@@ -481,7 +477,6 @@ load a batch of sequences (for MT) ...@@ -481,7 +477,6 @@ load a batch of sequences (for MT)
>> isSorted - indicates whether the sequences are sorted by length >> isSorted - indicates whether the sequences are sorted by length
>> wCount - word count >> wCount - word count
>> devID - device id >> devID - device id
>> mem - memory pool
>> isTraining - indicates whether we are training the model >> isTraining - indicates whether we are training the model
*/ */
int T2TBatchLoader::LoadBatchMT(FILE * file, int T2TBatchLoader::LoadBatchMT(FILE * file,
...@@ -491,8 +486,7 @@ int T2TBatchLoader::LoadBatchMT(FILE * file, ...@@ -491,8 +486,7 @@ int T2TBatchLoader::LoadBatchMT(FILE * file,
int * seqs, int * seqs,
int vSizeEnc, int vSizeDec, int sBatch, int wBatch, int vSizeEnc, int vSizeDec, int sBatch, int wBatch,
bool isSorted, int &ws, int &wCount, bool isSorted, int &ws, int &wCount,
int devID, XMem * mem, int devID, bool isTraining)
bool isTraining)
{ {
if (nextBatch < 0 || nextBatch >= bufBatchSize) { if (nextBatch < 0 || nextBatch >= bufBatchSize) {
LoadBuf(file, isSorted, 2); LoadBuf(file, isSorted, 2);
......
...@@ -131,8 +131,7 @@ public: ...@@ -131,8 +131,7 @@ public:
int * seqs, int * seqs,
int vsEnc, int vsDec, int sBatch, int wBatch, int vsEnc, int vsDec, int sBatch, int wBatch,
bool isSorted, int &ws, int &wCount, bool isSorted, int &ws, int &wCount,
int devID, XMem * mem, int devID, bool isTraining);
bool isTraining);
/* load a batch of sequences (for language modeling) */ /* load a batch of sequences (for language modeling) */
int LoadBatchLM(FILE * file, int LoadBatchLM(FILE * file,
...@@ -141,8 +140,7 @@ public: ...@@ -141,8 +140,7 @@ public:
XTensor * gold, XTensor * label, XTensor * gold, XTensor * label,
int * seqs, int vs, int sBatch, int wBatch, int * seqs, int vs, int sBatch, int wBatch,
bool isSorted, int &wCount, bool isSorted, int &wCount,
int devID, XMem * mem, int devID, bool isTraining);
bool isTraining);
/* load a batch of sequences (for machine translation) */ /* load a batch of sequences (for machine translation) */
int LoadBatchMT(FILE * file, int LoadBatchMT(FILE * file,
...@@ -151,8 +149,7 @@ public: ...@@ -151,8 +149,7 @@ public:
XTensor * gold, XTensor * label, XTensor * gold, XTensor * label,
int * seqs, int vsEnc, int vsDec, int sBatch, int wBatch, int * seqs, int vsEnc, int vsDec, int sBatch, int wBatch,
bool isSorted, int &ws, int &wCount, bool isSorted, int &ws, int &wCount,
int devID, XMem * mem, int devID, bool isTraining);
bool isTraining);
/* shuffle the data file */ /* shuffle the data file */
void Shuffle(const char * srcFile, const char * tgtFile); void Shuffle(const char * srcFile, const char * tgtFile);
......
...@@ -57,16 +57,14 @@ initialize the model ...@@ -57,16 +57,14 @@ initialize the model
>> myIsMasked - indicates whether the masked attention is employed >> myIsMasked - indicates whether the masked attention is employed
>> myIgnored - number of positions ignored in attention (from the start) >> myIgnored - number of positions ignored in attention (from the start)
>> myDevID - device id >> myDevID - device id
>> myMem - the memory pool
*/ */
void AttDecoder::InitModel(int argc, char ** argv, void AttDecoder::InitModel(int argc, char ** argv,
bool myIsMasked, int myIgnored, bool myIsMasked, int myIgnored,
int myDevID, XMem * myMem) int myDevID)
{ {
//AttEncoder::InitModel(argc, argv, myIsMasked, myIgnored, myDevID, myMem); //AttEncoder::InitModel(argc, argv, myIsMasked, myIgnored, myDevID);
devID = myDevID; devID = myDevID;
mem = myMem;
ignored = myIgnored; ignored = myIgnored;
LoadParamInt(argc, argv, "nlayer", &nlayer, 6); LoadParamInt(argc, argv, "nlayer", &nlayer, 6);
...@@ -79,7 +77,7 @@ void AttDecoder::InitModel(int argc, char ** argv, ...@@ -79,7 +77,7 @@ void AttDecoder::InitModel(int argc, char ** argv,
CheckNTErrors(vSize > 1, "set vocabulary size by \"-vsizetgt\""); CheckNTErrors(vSize > 1, "set vocabulary size by \"-vsizetgt\"");
/* embedding model */ /* embedding model */
embedder.InitModel(argc, argv, devID, mem, false); embedder.InitModel(argc, argv, devID, false);
attentions = new T2TAttention[nlayer]; attentions = new T2TAttention[nlayer];
fnns = new T2TFNN[nlayer]; fnns = new T2TFNN[nlayer];
...@@ -90,12 +88,12 @@ void AttDecoder::InitModel(int argc, char ** argv, ...@@ -90,12 +88,12 @@ void AttDecoder::InitModel(int argc, char ** argv,
/* initialize the stacked layers */ /* initialize the stacked layers */
for (int i = 0; i < nlayer; i++) { for (int i = 0; i < nlayer; i++) {
attentions[i].InitModel(argc, argv, myIsMasked, myIgnored, myDevID, myMem); attentions[i].InitModel(argc, argv, myIsMasked, myIgnored, myDevID);
fnns[i].InitModel(argc, argv, myDevID, myMem); fnns[i].InitModel(argc, argv, myDevID);
attLayerNorms[i].InitModel(argc, argv, myDevID, myMem); attLayerNorms[i].InitModel(argc, argv, myDevID);
fnnLayerNorms[i].InitModel(argc, argv, myDevID, myMem); fnnLayerNorms[i].InitModel(argc, argv, myDevID);
attentionsEnde[i].InitModel(argc, argv, true, myIgnored, myDevID, myMem); attentionsEnde[i].InitModel(argc, argv, true, myIgnored, myDevID);
attEndeLayerNorms[i].InitModel(argc, argv, myDevID, myMem); attEndeLayerNorms[i].InitModel(argc, argv, myDevID);
} }
} }
......
...@@ -37,9 +37,6 @@ public: ...@@ -37,9 +37,6 @@ public:
/* device id */ /* device id */
int devID; int devID;
/* memory pool */
XMem * mem;
/* layer number */ /* layer number */
int nlayer; int nlayer;
...@@ -95,7 +92,7 @@ public: ...@@ -95,7 +92,7 @@ public:
/* initialize the model */ /* initialize the model */
void InitModel(int argc, char ** argv, void InitModel(int argc, char ** argv,
bool myIsMasked, int myIgnored, bool myIsMasked, int myIgnored,
int myDevID = -1, XMem * myMem = NULL); int myDevID = -1);
/* make the decoding network */ /* make the decoding network */
XTensor Make(XTensor &inputDec, XTensor &outputEnc, XTensor &mask, XTensor &maskEncDec, bool isTraining); XTensor Make(XTensor &inputDec, XTensor &outputEnc, XTensor &mask, XTensor &maskEncDec, bool isTraining);
......
...@@ -31,7 +31,6 @@ namespace transformer ...@@ -31,7 +31,6 @@ namespace transformer
T2TEmbedder::T2TEmbedder() T2TEmbedder::T2TEmbedder()
{ {
devID = -1; devID = -1;
mem = NULL;
vSize = -1; vSize = -1;
maxLength = -1; maxLength = -1;
} }
...@@ -46,12 +45,10 @@ initialize the model ...@@ -46,12 +45,10 @@ initialize the model
>> argc - number of arguments >> argc - number of arguments
>> argv - list of pointers to the arguments >> argv - list of pointers to the arguments
>> myDevID - device id >> myDevID - device id
>> myMem - the memory pool
*/ */
void T2TEmbedder::InitModel(int argc, char ** argv, int myDevID, XMem * myMem, bool isEnc) void T2TEmbedder::InitModel(int argc, char ** argv, int myDevID, bool isEnc)
{ {
devID = myDevID; devID = myDevID;
mem = myMem;
if(isEnc){ if(isEnc){
LoadParamInt(argc, argv, "vsize", &vSize, -1); LoadParamInt(argc, argv, "vsize", &vSize, -1);
......
...@@ -41,9 +41,6 @@ public: ...@@ -41,9 +41,6 @@ public:
/* device id */ /* device id */
int devID; int devID;
/* memory pool */
XMem * mem;
/* vocabulary size */ /* vocabulary size */
int vSize; int vSize;
...@@ -71,7 +68,7 @@ public: ...@@ -71,7 +68,7 @@ public:
~T2TEmbedder(); ~T2TEmbedder();
/* initialize the model */ /* initialize the model */
void InitModel(int argc, char ** argv, int myDevID = -1, XMem * myMem = NULL, bool isEnc = true); void InitModel(int argc, char ** argv, int myDevID = -1, bool isEnc = true);
/* make positional embeddings */ /* make positional embeddings */
void MakePosEmbedding(int eSize, int d, int length); void MakePosEmbedding(int eSize, int d, int length);
......
...@@ -52,15 +52,12 @@ initialize the model ...@@ -52,15 +52,12 @@ initialize the model
>> argv - list of pointers to the arguments >> argv - list of pointers to the arguments
>> myIsMasked - indicates whether the masked attention is employed >> myIsMasked - indicates whether the masked attention is employed
>> myIgnored - number of positions ignored in attention (from the start) >> myIgnored - number of positions ignored in attention (from the start)
>> myDevID - device id >> myDevID - device id*/
>> myMem - the memory pool
*/
void AttEncoder::InitModel(int argc, char ** argv, void AttEncoder::InitModel(int argc, char ** argv,
bool myIsMasked, int myIgnored, bool myIsMasked, int myIgnored,
int myDevID, XMem * myMem) int myDevID)
{ {
devID = myDevID; devID = myDevID;
mem = myMem;
ignored = myIgnored; ignored = myIgnored;
LoadParamInt(argc, argv, "nlayer", &nlayer, 6); LoadParamInt(argc, argv, "nlayer", &nlayer, 6);
...@@ -73,7 +70,7 @@ void AttEncoder::InitModel(int argc, char ** argv, ...@@ -73,7 +70,7 @@ void AttEncoder::InitModel(int argc, char ** argv,
CheckNTErrors(vSize > 1, "set vocabulary size by \"-vsize\""); CheckNTErrors(vSize > 1, "set vocabulary size by \"-vsize\"");
/* embedding model */ /* embedding model */
embedder.InitModel(argc, argv, devID, mem); embedder.InitModel(argc, argv, devID);
attentions = new T2TAttention[nlayer]; attentions = new T2TAttention[nlayer];
fnns = new T2TFNN[nlayer]; fnns = new T2TFNN[nlayer];
...@@ -82,10 +79,10 @@ void AttEncoder::InitModel(int argc, char ** argv, ...@@ -82,10 +79,10 @@ void AttEncoder::InitModel(int argc, char ** argv,
/* initialize the stacked layers */ /* initialize the stacked layers */
for(int i = 0; i < nlayer; i++){ for(int i = 0; i < nlayer; i++){
attentions[i].InitModel(argc, argv, myIsMasked, myIgnored, myDevID, myMem); attentions[i].InitModel(argc, argv, myIsMasked, myIgnored, myDevID);
fnns[i].InitModel(argc, argv, myDevID, myMem); fnns[i].InitModel(argc, argv, myDevID);
attLayerNorms[i].InitModel(argc, argv, myDevID, myMem); attLayerNorms[i].InitModel(argc, argv, myDevID);
fnnLayerNorms[i].InitModel(argc, argv, myDevID, myMem); fnnLayerNorms[i].InitModel(argc, argv, myDevID);
} }
} }
......
...@@ -65,9 +65,6 @@ public: ...@@ -65,9 +65,6 @@ public:
/* device id */ /* device id */
int devID; int devID;
/* memory pool */
XMem * mem;
/* layer number */ /* layer number */
int nlayer; int nlayer;
...@@ -118,7 +115,7 @@ public: ...@@ -118,7 +115,7 @@ public:
/* initialize the model */ /* initialize the model */
void InitModel(int argc, char ** argv, void InitModel(int argc, char ** argv,
bool myIsMasked, int myIgnored, bool myIsMasked, int myIgnored,
int myDevID = -1, XMem * myMem = NULL); int myDevID = -1);
/* make the encoding network */ /* make the encoding network */
XTensor Make(XTensor &input, XTensor &mask, XTensor &maskEncDec, bool isTraining); XTensor Make(XTensor &input, XTensor &mask, XTensor &maskEncDec, bool isTraining);
......
...@@ -47,12 +47,10 @@ initialize the model ...@@ -47,12 +47,10 @@ initialize the model
>> argc - number of arguments >> argc - number of arguments
>> argv - list of pointers to the arguments >> argv - list of pointers to the arguments
>> myDevID - device id >> myDevID - device id
>> myMem - the memory pool
*/ */
void T2TFNN::InitModel(int argc, char ** argv, int myDevID, XMem * myMem) void T2TFNN::InitModel(int argc, char ** argv, int myDevID)
{ {
devID = myDevID; devID = myDevID;
mem = myMem;
float minmax = 0; float minmax = 0;
...@@ -63,10 +61,10 @@ void T2TFNN::InitModel(int argc, char ** argv, int myDevID, XMem * myMem) ...@@ -63,10 +61,10 @@ void T2TFNN::InitModel(int argc, char ** argv, int myDevID, XMem * myMem)
LoadParamFloat(argc, argv, "dropoutfnn", &dropoutP, 0); LoadParamFloat(argc, argv, "dropoutfnn", &dropoutP, 0);
InitTensor2DV2(&w1, inSize, hSize, X_FLOAT, devID); InitTensor2DV2(&w1, inSize, hSize, X_FLOAT, devID);
InitTensor1D(&b1, hSize, X_FLOAT, devID, mem); InitTensor1DV2(&b1, hSize, X_FLOAT, devID);
InitTensor2DV2(&w2, hSize, outSize, X_FLOAT, devID); InitTensor2DV2(&w2, hSize, outSize, X_FLOAT, devID);
InitTensor1D(&b2, outSize, X_FLOAT, devID, mem); InitTensor1DV2(&b2, outSize, X_FLOAT, devID);
float scale = 1.0F; float scale = 1.0F;
float finfout1 = (float)sqrt(6.0F * scale/(inSize + hSize)); float finfout1 = (float)sqrt(6.0F * scale/(inSize + hSize));
......
...@@ -36,9 +36,6 @@ public: ...@@ -36,9 +36,6 @@ public:
/* device id */ /* device id */
int devID; int devID;
/* memory pool */
XMem * mem;
/* size of input vector */ /* size of input vector */
int inSize; int inSize;
...@@ -72,7 +69,7 @@ public: ...@@ -72,7 +69,7 @@ public:
~T2TFNN(); ~T2TFNN();
/* initialize the model */ /* initialize the model */
void InitModel(int argc, char ** argv, int myDevID = -1, XMem * myMem = NULL); void InitModel(int argc, char ** argv, int myDevID = -1);
/* make the network */ /* make the network */
XTensor Make(XTensor &input, bool isTraining); XTensor Make(XTensor &input, bool isTraining);
......
...@@ -32,7 +32,6 @@ namespace transformer ...@@ -32,7 +32,6 @@ namespace transformer
T2TLN::T2TLN() T2TLN::T2TLN()
{ {
devID = -1; devID = -1;
mem = NULL;
d = 0; d = 0;
} }
...@@ -46,12 +45,10 @@ initialize the model ...@@ -46,12 +45,10 @@ initialize the model
>> argc - number of arguments >> argc - number of arguments
>> argv - list of pointers to the arguments >> argv - list of pointers to the arguments
>> myDevID - device id >> myDevID - device id
>> myMem - the memory pool
*/ */
void T2TLN::InitModel(int argc, char ** argv, int myDevID, XMem * myMem) void T2TLN::InitModel(int argc, char ** argv, int myDevID)
{ {
devID = myDevID; devID = myDevID;
mem = myMem;
d = 0; d = 0;
LoadParamInt(argc, argv, "d", &d, DEFAULT_EMBEDDING_SIZE); LoadParamInt(argc, argv, "d", &d, DEFAULT_EMBEDDING_SIZE);
......
...@@ -36,9 +36,6 @@ class T2TLN ...@@ -36,9 +36,6 @@ class T2TLN
public: public:
/* device id */ /* device id */
int devID; int devID;
/* memory pool */
XMem * mem;
/* the transformation matrix w */ /* the transformation matrix w */
XTensor w; XTensor w;
...@@ -57,7 +54,7 @@ public: ...@@ -57,7 +54,7 @@ public:
~T2TLN(); ~T2TLN();
/* initialize the model */ /* initialize the model */
void InitModel(int argc, char ** argv, int myDevID = -1, XMem * myMem = NULL); void InitModel(int argc, char ** argv, int myDevID = -1);
/* make the network */ /* make the network */
XTensor Make(XTensor &input); XTensor Make(XTensor &input);
......
...@@ -32,7 +32,6 @@ namespace transformer ...@@ -32,7 +32,6 @@ namespace transformer
T2TModel::T2TModel() T2TModel::T2TModel()
{ {
devID = -1; devID = -1;
mem = NULL;
isLM = false; isLM = false;
isMT = false; isMT = false;
nhead = 1; nhead = 1;
...@@ -48,10 +47,6 @@ T2TModel::~T2TModel() ...@@ -48,10 +47,6 @@ T2TModel::~T2TModel()
delete encoder; delete encoder;
delete decoder; delete decoder;
delete outputLayer; delete outputLayer;
/* we delete "mem" at the end because other members are using it and we must
remove the memory space before all tensors are destroyed. */
delete mem;
} }
/* /*
...@@ -61,29 +56,16 @@ initialize the model ...@@ -61,29 +56,16 @@ initialize the model
*/ */
void T2TModel::InitModel(int argc, char ** argv) void T2TModel::InitModel(int argc, char ** argv)
{ {
bool useMem = false;
int memSize = 0;
bool isMemFreeOTF = false;
LoadParamInt(argc, argv, "dev", &devID, -1); LoadParamInt(argc, argv, "dev", &devID, -1);
LoadParamBool(argc, argv, "mem", &useMem, useMem);
LoadParamInt(argc, argv, "memsize", &memSize, 1024);
LoadParamBool(argc, argv, "mt", &isMT, false); LoadParamBool(argc, argv, "mt", &isMT, false);
LoadParamBool(argc, argv, "lm", &isLM, !isMT); LoadParamBool(argc, argv, "lm", &isLM, !isMT);
LoadParamInt(argc, argv, "nhead", &nhead, 8); LoadParamInt(argc, argv, "nhead", &nhead, 8);
LoadParamBool(argc, argv, "freeotf", &isMemFreeOTF, false);
if(useMem){
delete mem;
mem = new XMem(devID, FREE_ON_THE_FLY, (MTYPE)MILLION * 256, 1024, MILLION * 128);
mem->SetDesiredSize(devID, 0, (MTYPE)memSize * MILLION);
}
encoder->InitModel(argc, argv, true, 0, devID, mem); encoder->InitModel(argc, argv, true, 0, devID);
outputLayer->InitModel(argc, argv, devID, mem); outputLayer->InitModel(argc, argv, devID);
if(isMT) if(isMT)
decoder->InitModel(argc, argv, true, 0, devID, mem); decoder->InitModel(argc, argv, true, 0, devID);
TensorList params(10); TensorList params(10);
GetParams(params); GetParams(params);
...@@ -149,7 +131,8 @@ void T2TModel::MakeLM(XTensor &input, XTensor &output, XTensor &padding, bool is ...@@ -149,7 +131,8 @@ void T2TModel::MakeLM(XTensor &input, XTensor &output, XTensor &padding, bool is
dims[i + 1] = input.GetDim(i); dims[i + 1] = input.GetDim(i);
dims[0] = nhead; dims[0] = nhead;
dims[input.order + 1] = len; dims[input.order + 1] = len;
XTensor mask(input.order + 2, dims, X_FLOAT, 1.0F, padding.devID, padding.mem); XTensor mask;
InitTensorV2(&mask, input.order + 2, dims, X_FLOAT, padding.devID);
/* a upper triangular matrix where the cells of the upper triangular are set to -1e-9. /* a upper triangular matrix where the cells of the upper triangular are set to -1e-9.
this matrix can be used to prevent the attention to current or following words in this matrix can be used to prevent the attention to current or following words in
......
...@@ -40,9 +40,6 @@ public: ...@@ -40,9 +40,6 @@ public:
/* device id */ /* device id */
int devID; int devID;
/* memory pool */
XMem * mem;
/* the encoder */ /* the encoder */
AttEncoder * encoder; AttEncoder * encoder;
......
...@@ -31,7 +31,6 @@ namespace transformer ...@@ -31,7 +31,6 @@ namespace transformer
T2TOutput::T2TOutput() T2TOutput::T2TOutput()
{ {
devID = -1; devID = -1;
mem = NULL;
vSize = -1; vSize = -1;
inSize = -1; inSize = -1;
hSize = -1; hSize = -1;
...@@ -47,12 +46,10 @@ initialize the model ...@@ -47,12 +46,10 @@ initialize the model
>> argc - number of arguments >> argc - number of arguments
>> argv - list of pointers to the arguments >> argv - list of pointers to the arguments
>> myDevID - device id >> myDevID - device id
>> myMem - the memory pool
*/ */
void T2TOutput::InitModel(int argc, char ** argv, int myDevID, XMem * myMem) void T2TOutput::InitModel(int argc, char ** argv, int myDevID)
{ {
devID = myDevID; devID = myDevID;
mem = myMem;
float minmax = 0; float minmax = 0;
......
...@@ -38,9 +38,6 @@ public: ...@@ -38,9 +38,6 @@ public:
/* device id */ /* device id */
int devID; int devID;
/* memory pool */
XMem * mem;
/* vocabulary size */ /* vocabulary size */
int vSize; int vSize;
...@@ -61,7 +58,7 @@ public: ...@@ -61,7 +58,7 @@ public:
~T2TOutput(); ~T2TOutput();
/* initialize the model */ /* initialize the model */
void InitModel(int argc, char ** argv, int myDevID = -1, XMem * myMem = NULL); void InitModel(int argc, char ** argv, int myDevID = -1);
/* make the network */ /* make the network */
XTensor Make(XTensor &input); XTensor Make(XTensor &input);
......
...@@ -179,7 +179,7 @@ void T2TPredictor::Predict(T2TStateBundle * next, XTensor * encoding, ...@@ -179,7 +179,7 @@ void T2TPredictor::Predict(T2TStateBundle * next, XTensor * encoding,
} }
else{ else{
inputDec = GeneratePaths(s); inputDec = GeneratePaths(s);
inputDec.SetDevice(inputEnc->devID, inputEnc->mem); inputDec.SetDevice(inputEnc->devID);
inputDec = Concatenate(first, inputDec, inputDec.order - 1); inputDec = Concatenate(first, inputDec, inputDec.order - 1);
} }
...@@ -219,8 +219,8 @@ void T2TPredictor::Predict(T2TStateBundle * next, XTensor * encoding, ...@@ -219,8 +219,8 @@ void T2TPredictor::Predict(T2TStateBundle * next, XTensor * encoding,
selectSrc.SetInt(stride - 1, 0); selectSrc.SetInt(stride - 1, 0);
selectTgt.SetInt(0, 0); selectTgt.SetInt(0, 0);
selectSrc.SetDevice(decoding.devID, decoding.mem); selectSrc.SetDevice(decoding.devID);
selectTgt.SetDevice(decoding.devID, decoding.mem); selectTgt.SetDevice(decoding.devID);
/* the decoder output of the last position */ /* the decoder output of the last position */
decodingStep = CopyIndexed(decoding, decoding.order - 2, selectSrc, selectTgt); decodingStep = CopyIndexed(decoding, decoding.order - 2, selectSrc, selectTgt);
......
...@@ -595,7 +595,7 @@ XTensor T2TSearch::MakeFirstMask(T2TStateBundle * beam) ...@@ -595,7 +595,7 @@ XTensor T2TSearch::MakeFirstMask(T2TStateBundle * beam)
mask.Set(-1e9, i); mask.Set(-1e9, i);
} }
mask.SetDevice(prob.devID, prob.mem); mask.SetDevice(prob.devID);
return mask; return mask;
} }
......
...@@ -75,7 +75,6 @@ void T2TTester::Test(const char * fn, const char * ofn, T2TModel * model) ...@@ -75,7 +75,6 @@ void T2TTester::Test(const char * fn, const char * ofn, T2TModel * model)
CheckNTErrors(ofile, "Cannot open the output file"); CheckNTErrors(ofile, "Cannot open the output file");
int devID = model->devID; int devID = model->devID;
XMem * mem = model->mem;
XNet net; XNet net;
...@@ -106,7 +105,7 @@ void T2TTester::Test(const char * fn, const char * ofn, T2TModel * model) ...@@ -106,7 +105,7 @@ void T2TTester::Test(const char * fn, const char * ofn, T2TModel * model)
while(batchLoader.LoadBatch(file, model->isLM, while(batchLoader.LoadBatch(file, model->isLM,
&batchEnc, &paddingEnc, &paddingDec, &paddingDec, &gold, &label, &batchEnc, &paddingEnc, &paddingDec, &paddingDec, &gold, &label,
seqs, vSize, vSizeTgt, seqs, vSize, vSizeTgt,
1, 1, false, ws, wc, devID, mem, false)) 1, 1, false, ws, wc, devID, false))
{ {
CheckNTErrors(batchEnc.order == 2, "wrong tensor order of the sequence batch!"); CheckNTErrors(batchEnc.order == 2, "wrong tensor order of the sequence batch!");
CheckNTErrors(!model->isLM, "Only MT model is supported!"); CheckNTErrors(!model->isLM, "Only MT model is supported!");
......
...@@ -75,9 +75,6 @@ void T2TTrainer::Init(int argc, char ** argv) ...@@ -75,9 +75,6 @@ void T2TTrainer::Init(int argc, char ** argv)
strcpy(argArray[i], argv[i]); strcpy(argArray[i], argv[i]);
} }
bool useMem = false;
LoadParamBool(argc, argv, "mem", &useMem, useMem);
LoadParamFloat(argc, argv, "lrate", &lrate, 1.0F); LoadParamFloat(argc, argv, "lrate", &lrate, 1.0F);
LoadParamFloat(argc, argv, "lrbias", &lrbias, 0); LoadParamFloat(argc, argv, "lrbias", &lrbias, 0);
LoadParamInt(argc, argv, "sbatch", &sBatchSize, 1); LoadParamInt(argc, argv, "sbatch", &sBatchSize, 1);
...@@ -142,7 +139,6 @@ void T2TTrainer::Train(const char * fn, const char * validFN, const char * model ...@@ -142,7 +139,6 @@ void T2TTrainer::Train(const char * fn, const char * validFN, const char * model
#endif #endif
int devID = model->devID; int devID = model->devID;
XMem * mem = model->mem;
XNet net; XNet net;
if(isDebugged) if(isDebugged)
...@@ -184,7 +180,7 @@ void T2TTrainer::Train(const char * fn, const char * validFN, const char * model ...@@ -184,7 +180,7 @@ void T2TTrainer::Train(const char * fn, const char * validFN, const char * model
while (batchLoader.LoadBatch(file, model->isLM, while (batchLoader.LoadBatch(file, model->isLM,
&batchEnc, &paddingEnc, &batchDec, &paddingDec, &gold, &label, &batchEnc, &paddingEnc, &batchDec, &paddingDec, &gold, &label,
NULL, vSize, vSizeTgt, NULL, vSize, vSizeTgt,
sBatchSize, wBatchSize, isLenSorted, ws, wc, devID, mem, true)) sBatchSize, wBatchSize, isLenSorted, ws, wc, devID, true))
{ {
CheckNTErrors(batchEnc.order == 2, "wrong tensor order of the sequence batch"); CheckNTErrors(batchEnc.order == 2, "wrong tensor order of the sequence batch");
...@@ -321,7 +317,6 @@ void T2TTrainer::Test(const char * fn, const char * ofn, T2TModel * model) ...@@ -321,7 +317,6 @@ void T2TTrainer::Test(const char * fn, const char * ofn, T2TModel * model)
CheckNTErrors(ofile, "Cannot open the output file"); CheckNTErrors(ofile, "Cannot open the output file");
int devID = model->devID; int devID = model->devID;
XMem * mem = model->mem;
XNet net; XNet net;
...@@ -351,7 +346,7 @@ void T2TTrainer::Test(const char * fn, const char * ofn, T2TModel * model) ...@@ -351,7 +346,7 @@ void T2TTrainer::Test(const char * fn, const char * ofn, T2TModel * model)
while(batchLoader.LoadBatch(file, model->isLM, while(batchLoader.LoadBatch(file, model->isLM,
&batchEnc, &paddingEnc, &batchDec, &paddingDec, &gold, &label, &batchEnc, &paddingEnc, &batchDec, &paddingDec, &gold, &label,
seqs, vSize, vSizeTgt, seqs, vSize, vSizeTgt,
1, 1, false, ws, wc, devID, mem, false)) 1, 1, false, ws, wc, devID, false))
{ {
CheckNTErrors(batchEnc.order == 2, "wrong tensor order of the sequence batch"); CheckNTErrors(batchEnc.order == 2, "wrong tensor order of the sequence batch");
......
...@@ -517,7 +517,7 @@ void XTensor::SetDevice(int myDevId, XMem * myMem) ...@@ -517,7 +517,7 @@ void XTensor::SetDevice(int myDevId, XMem * myMem)
isInGlobalMem = false; isInGlobalMem = false;
} }
else { else {
ShowNTErrors("TODO!"); myMem = GMems.GetMem(myDevId);
} }
} }
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论