Commit fe90b454 by xiaotong

fix the bug of invalid memory pointer in releasing a memory pool

parent 48661938
......@@ -35,12 +35,19 @@ T2TModel::T2TModel()
isLM = false;
isMT = false;
nhead = 1;
encoder = new AttEncoder();
decoder = new AttDecoder();
outputLayer = new T2TOutput();
}
/* de-constructor */
T2TModel::~T2TModel()
{
delete mem;
delete encoder;
delete decoder;
delete outputLayer;
}
/*
......@@ -68,11 +75,11 @@ void T2TModel::InitModel(int argc, char ** argv)
mem->SetDesiredSize(devID, 0, (MTYPE)memSize * MILLION);
}
encoder.InitModel(argc, argv, isLM, 0, devID, mem);
outputLayer.InitModel(argc, argv, devID, mem);
encoder->InitModel(argc, argv, isLM, 0, devID, mem);
outputLayer->InitModel(argc, argv, devID, mem);
if(isMT)
decoder.InitModel(argc, argv, true, 0, devID, mem);
decoder->InitModel(argc, argv, true, 0, devID, mem);
XList params(10);
GetParams(params);
......@@ -92,7 +99,7 @@ make the encoding network
*/
XTensor T2TModel::MakeEncoder(XTensor &input, XTensor &mask, bool isTraining)
{
return encoder.Make(input, mask, isTraining);
return encoder->Make(input, mask, isTraining);
}
/*
......@@ -106,7 +113,7 @@ make the decoding network
*/
XTensor T2TModel::MakeDecoder(XTensor &inputDec, XTensor &outputEnc, XTensor &mask, bool isTraining)
{
return decoder.Make(inputDec, outputEnc, mask, isTraining);
return decoder->Make(inputDec, outputEnc, mask, isTraining);
}
/*
......@@ -168,7 +175,7 @@ void T2TModel::MakeLM(XTensor &input, XTensor &output, XTensor &padding, bool is
////_Sum(&mask, padding3, &mask);
encoding = MakeEncoder(input, mask, isTraining);
outputLayer.Make(encoding, output);
outputLayer->Make(encoding, output);
delete[] dims;
delete[] dimsPadding;
......@@ -246,7 +253,7 @@ void T2TModel::MakeMT(XTensor &inputEnc, XTensor &inputDec, XTensor &output, XTe
encoding = MakeEncoder(inputEnc, maskEnc, isTraining);
decoding = MakeDecoder(inputDec, encoding, maskDec, isTraining);
outputLayer.Make(decoding, output);
outputLayer->Make(decoding, output);
delete[] dims;
delete[] dimsPadding;
......@@ -262,48 +269,48 @@ get parameter matrics
void T2TModel::GetParams(XList &list)
{
list.Clear();
list.Add(&outputLayer.w);
for(int i = 0; i < encoder.nlayer; i++){
list.Add(&encoder.fnns[i].w1);
list.Add(&encoder.fnns[i].b1);
list.Add(&encoder.fnns[i].w2);
list.Add(&encoder.fnns[i].b2);
list.Add(&encoder.attentions[i].wk);
list.Add(&encoder.attentions[i].wq);
list.Add(&encoder.attentions[i].wv);
list.Add(&encoder.attentions[i].wa);
list.Add(&encoder.fnnLayerNorms[i].w);
list.Add(&encoder.fnnLayerNorms[i].b);
list.Add(&encoder.attLayerNorms[i].w);
list.Add(&encoder.attLayerNorms[i].b);
list.Add(&outputLayer->w);
for(int i = 0; i < encoder->nlayer; i++){
list.Add(&encoder->fnns[i].w1);
list.Add(&encoder->fnns[i].b1);
list.Add(&encoder->fnns[i].w2);
list.Add(&encoder->fnns[i].b2);
list.Add(&encoder->attentions[i].wk);
list.Add(&encoder->attentions[i].wq);
list.Add(&encoder->attentions[i].wv);
list.Add(&encoder->attentions[i].wa);
list.Add(&encoder->fnnLayerNorms[i].w);
list.Add(&encoder->fnnLayerNorms[i].b);
list.Add(&encoder->attLayerNorms[i].w);
list.Add(&encoder->attLayerNorms[i].b);
}
list.Add(&encoder.embedder.w);
list.Add(&encoder->embedder.w);
if(isMT){
for(int i = 0; i < decoder.nlayer; i++){
list.Add(&decoder.fnns[i].w1);
list.Add(&decoder.fnns[i].b1);
list.Add(&decoder.fnns[i].w2);
list.Add(&decoder.fnns[i].b2);
list.Add(&decoder.attentionsEnde[i].wk);
list.Add(&decoder.attentionsEnde[i].wq);
list.Add(&decoder.attentionsEnde[i].wv);
list.Add(&decoder.attentionsEnde[i].wa);
list.Add(&decoder.attEndeLayerNorms[i].w);
list.Add(&decoder.attEndeLayerNorms[i].b);
list.Add(&decoder.attentions[i].wk);
list.Add(&decoder.attentions[i].wq);
list.Add(&decoder.attentions[i].wv);
list.Add(&decoder.attentions[i].wa);
list.Add(&decoder.fnnLayerNorms[i].w);
list.Add(&decoder.fnnLayerNorms[i].b);
list.Add(&decoder.attLayerNorms[i].w);
list.Add(&decoder.attLayerNorms[i].b);
for(int i = 0; i < decoder->nlayer; i++){
list.Add(&decoder->fnns[i].w1);
list.Add(&decoder->fnns[i].b1);
list.Add(&decoder->fnns[i].w2);
list.Add(&decoder->fnns[i].b2);
list.Add(&decoder->attentionsEnde[i].wk);
list.Add(&decoder->attentionsEnde[i].wq);
list.Add(&decoder->attentionsEnde[i].wv);
list.Add(&decoder->attentionsEnde[i].wa);
list.Add(&decoder->attEndeLayerNorms[i].w);
list.Add(&decoder->attEndeLayerNorms[i].b);
list.Add(&decoder->attentions[i].wk);
list.Add(&decoder->attentions[i].wq);
list.Add(&decoder->attentions[i].wv);
list.Add(&decoder->attentions[i].wa);
list.Add(&decoder->fnnLayerNorms[i].w);
list.Add(&decoder->fnnLayerNorms[i].b);
list.Add(&decoder->attLayerNorms[i].w);
list.Add(&decoder->attLayerNorms[i].b);
}
list.Add(&decoder.embedder.w);
list.Add(&decoder->embedder.w);
}
}
......
......@@ -41,13 +41,13 @@ public:
XMem * mem;
/* the encoder */
AttEncoder encoder;
AttEncoder * encoder;
/* the decoder */
AttDecoder decoder;
AttDecoder * decoder;
/* output layer */
T2TOutput outputLayer;
T2TOutput * outputLayer;
/* indicates whether the model is running for language modeling */
bool isLM;
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论