Commit fe90b454 by xiaotong

fix the bug of an invalid memory pointer when releasing a memory pool

parent 48661938
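The change below turns the `encoder`, `decoder`, and `outputLayer` members of `T2TModel` from by-value objects into heap-allocated pointers that the constructor creates and the destructor deletes. As a rough illustration of why by-value members interact badly with an owned memory pool, here is a minimal sketch; `Pool` and `Module` are hypothetical stand-ins rather than NiuTrans.Tensor classes, and how `XMem`-backed tensors actually behave at teardown is an assumption here, not something the commit states.

```cpp
// Hypothetical illustration of the destruction-order hazard; Pool and Module
// are stand-ins, not NiuTrans.Tensor classes.
#include <cstdio>

static bool poolAlive = false;          // tracks whether the pool still exists

struct Pool {
    Pool()  { poolAlive = true; }
    ~Pool() { poolAlive = false; }
};

struct Module {
    // A sub-module that hands its buffers back to the pool when destroyed.
    ~Module() { std::printf(poolAlive ? "release ok\n"
                                      : "release after the pool was freed!\n"); }
};

struct ModelByValue {                   // old layout: `AttEncoder encoder;`
    Pool * mem = new Pool();
    Module sub;                         // by-value member
    ~ModelByValue() { delete mem; }     // body runs first, then ~Module()
};                                      // -> sub releases into a dead pool

struct ModelByPointer {                 // new layout: `AttEncoder * encoder;`
    Pool * mem = new Pool();
    Module * sub = new Module();        // heap member, lifetime is explicit
    ~ModelByPointer() { delete sub; delete mem; }  // order chosen for this sketch
};

int main() {
    { ModelByValue m; }                 // prints "release after the pool was freed!"
    { ModelByPointer m; }               // prints "release ok"
    return 0;
}
```

Note that the commit itself keeps `delete mem;` first and adds the sub-module deletes after it, so the actual fix may also depend on how the pool and the sub-modules' tensors interact; the sketch only shows the language-level ordering problem that by-value members create.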
@@ -35,12 +35,19 @@ T2TModel::T2TModel()
     isLM = false;
     isMT = false;
     nhead = 1;
+
+    encoder = new AttEncoder();
+    decoder = new AttDecoder();
+    outputLayer = new T2TOutput();
 }
 
 /* de-constructor */
 T2TModel::~T2TModel()
 {
     delete mem;
+    delete encoder;
+    delete decoder;
+    delete outputLayer;
 }
 
 /*
@@ -68,11 +75,11 @@ void T2TModel::InitModel(int argc, char ** argv)
         mem->SetDesiredSize(devID, 0, (MTYPE)memSize * MILLION);
     }
 
-    encoder.InitModel(argc, argv, isLM, 0, devID, mem);
-    outputLayer.InitModel(argc, argv, devID, mem);
+    encoder->InitModel(argc, argv, isLM, 0, devID, mem);
+    outputLayer->InitModel(argc, argv, devID, mem);
 
     if(isMT)
-        decoder.InitModel(argc, argv, true, 0, devID, mem);
+        decoder->InitModel(argc, argv, true, 0, devID, mem);
 
     XList params(10);
     GetParams(params);
@@ -92,7 +99,7 @@ make the encoding network
 */
 XTensor T2TModel::MakeEncoder(XTensor &input, XTensor &mask, bool isTraining)
 {
-    return encoder.Make(input, mask, isTraining);
+    return encoder->Make(input, mask, isTraining);
 }
 
 /*
@@ -106,7 +113,7 @@ make the decoding network
 */
 XTensor T2TModel::MakeDecoder(XTensor &inputDec, XTensor &outputEnc, XTensor &mask, bool isTraining)
 {
-    return decoder.Make(inputDec, outputEnc, mask, isTraining);
+    return decoder->Make(inputDec, outputEnc, mask, isTraining);
 }
 
 /*
@@ -168,7 +175,7 @@ void T2TModel::MakeLM(XTensor &input, XTensor &output, XTensor &padding, bool is
     ////_Sum(&mask, padding3, &mask);
 
     encoding = MakeEncoder(input, mask, isTraining);
-    outputLayer.Make(encoding, output);
+    outputLayer->Make(encoding, output);
 
     delete[] dims;
     delete[] dimsPadding;
@@ -246,7 +253,7 @@ void T2TModel::MakeMT(XTensor &inputEnc, XTensor &inputDec, XTensor &output, XTe
     encoding = MakeEncoder(inputEnc, maskEnc, isTraining);
     decoding = MakeDecoder(inputDec, encoding, maskDec, isTraining);
-    outputLayer.Make(decoding, output);
+    outputLayer->Make(decoding, output);
 
     delete[] dims;
     delete[] dimsPadding;
@@ -262,48 +269,48 @@ get parameter matrics
 void T2TModel::GetParams(XList &list)
 {
     list.Clear();
-    list.Add(&outputLayer.w);
+    list.Add(&outputLayer->w);
 
-    for(int i = 0; i < encoder.nlayer; i++){
-        list.Add(&encoder.fnns[i].w1);
-        list.Add(&encoder.fnns[i].b1);
-        list.Add(&encoder.fnns[i].w2);
-        list.Add(&encoder.fnns[i].b2);
-        list.Add(&encoder.attentions[i].wk);
-        list.Add(&encoder.attentions[i].wq);
-        list.Add(&encoder.attentions[i].wv);
-        list.Add(&encoder.attentions[i].wa);
-        list.Add(&encoder.fnnLayerNorms[i].w);
-        list.Add(&encoder.fnnLayerNorms[i].b);
-        list.Add(&encoder.attLayerNorms[i].w);
-        list.Add(&encoder.attLayerNorms[i].b);
+    for(int i = 0; i < encoder->nlayer; i++){
+        list.Add(&encoder->fnns[i].w1);
+        list.Add(&encoder->fnns[i].b1);
+        list.Add(&encoder->fnns[i].w2);
+        list.Add(&encoder->fnns[i].b2);
+        list.Add(&encoder->attentions[i].wk);
+        list.Add(&encoder->attentions[i].wq);
+        list.Add(&encoder->attentions[i].wv);
+        list.Add(&encoder->attentions[i].wa);
+        list.Add(&encoder->fnnLayerNorms[i].w);
+        list.Add(&encoder->fnnLayerNorms[i].b);
+        list.Add(&encoder->attLayerNorms[i].w);
+        list.Add(&encoder->attLayerNorms[i].b);
     }
 
-    list.Add(&encoder.embedder.w);
+    list.Add(&encoder->embedder.w);
 
     if(isMT){
-        for(int i = 0; i < decoder.nlayer; i++){
-            list.Add(&decoder.fnns[i].w1);
-            list.Add(&decoder.fnns[i].b1);
-            list.Add(&decoder.fnns[i].w2);
-            list.Add(&decoder.fnns[i].b2);
-            list.Add(&decoder.attentionsEnde[i].wk);
-            list.Add(&decoder.attentionsEnde[i].wq);
-            list.Add(&decoder.attentionsEnde[i].wv);
-            list.Add(&decoder.attentionsEnde[i].wa);
-            list.Add(&decoder.attEndeLayerNorms[i].w);
-            list.Add(&decoder.attEndeLayerNorms[i].b);
-            list.Add(&decoder.attentions[i].wk);
-            list.Add(&decoder.attentions[i].wq);
-            list.Add(&decoder.attentions[i].wv);
-            list.Add(&decoder.attentions[i].wa);
-            list.Add(&decoder.fnnLayerNorms[i].w);
-            list.Add(&decoder.fnnLayerNorms[i].b);
-            list.Add(&decoder.attLayerNorms[i].w);
-            list.Add(&decoder.attLayerNorms[i].b);
+        for(int i = 0; i < decoder->nlayer; i++){
+            list.Add(&decoder->fnns[i].w1);
+            list.Add(&decoder->fnns[i].b1);
+            list.Add(&decoder->fnns[i].w2);
+            list.Add(&decoder->fnns[i].b2);
+            list.Add(&decoder->attentionsEnde[i].wk);
+            list.Add(&decoder->attentionsEnde[i].wq);
+            list.Add(&decoder->attentionsEnde[i].wv);
+            list.Add(&decoder->attentionsEnde[i].wa);
+            list.Add(&decoder->attEndeLayerNorms[i].w);
+            list.Add(&decoder->attEndeLayerNorms[i].b);
+            list.Add(&decoder->attentions[i].wk);
+            list.Add(&decoder->attentions[i].wq);
+            list.Add(&decoder->attentions[i].wv);
+            list.Add(&decoder->attentions[i].wa);
+            list.Add(&decoder->fnnLayerNorms[i].w);
+            list.Add(&decoder->fnnLayerNorms[i].b);
+            list.Add(&decoder->attLayerNorms[i].w);
+            list.Add(&decoder->attLayerNorms[i].b);
         }
 
-        list.Add(&decoder.embedder.w);
+        list.Add(&decoder->embedder.w);
     }
 }
...
@@ -41,13 +41,13 @@ public:
     XMem * mem;
 
     /* the encoder */
-    AttEncoder encoder;
+    AttEncoder * encoder;
 
     /* the decoder */
-    AttDecoder decoder;
+    AttDecoder * decoder;
 
     /* output layer */
-    T2TOutput outputLayer;
+    T2TOutput * outputLayer;
 
     /* indicates whether the model is running for language modeling */
     bool isLM;
...