Commit a037d802 by xiaotong

bug fixes for t2t mt

parent ecc4041d
......@@ -60,7 +60,7 @@ void AttDecoder::InitModel(int argc, char ** argv,
/* initialize the stacked layers */
for(int i = 0; i < nlayer; i++){
attentionsEnde[i].InitModel(argc, argv, myIsMasked, myIgnored, myDevID, myMem);
attentionsEnde[i].InitModel(argc, argv, false, myIgnored, myDevID, myMem);
attEndeLayerNorms[i].InitModel(argc, argv, myDevID, myMem);
}
}
......@@ -89,6 +89,7 @@ XTensor AttDecoder::Make(XTensor &inputDec, XTensor &outputEnc, XTensor &mask, b
XTensor ln;
XTensor fnn;
XTensor res;
XTensor nothing;
/******************/
/* self attention */
......@@ -106,7 +107,7 @@ XTensor AttDecoder::Make(XTensor &inputDec, XTensor &outputEnc, XTensor &mask, b
/*****************************/
/* encoder-decoder attention */
ende = attentionsEnde[i].Make(outputEnc, x, outputEnc, mask, isTraining);
ende = attentionsEnde[i].Make(outputEnc, x, outputEnc, nothing, isTraining);
/* dropout */
if(isTraining && dropoutP > 0)
......@@ -137,4 +138,4 @@ XTensor AttDecoder::Make(XTensor &inputDec, XTensor &outputEnc, XTensor &mask, b
}
}
\ No newline at end of file
}
......@@ -57,8 +57,8 @@ void T2TModel::InitModel(int argc, char ** argv)
LoadParamInt(argc, argv, "dev", &devID, -1);
LoadParamBool(argc, argv, "mem", &useMem, useMem);
LoadParamInt(argc, argv, "memsize", &memSize, 1024);
LoadParamBool(argc, argv, "lm", &isLM, true);
LoadParamBool(argc, argv, "mt", &isMT, false);
LoadParamBool(argc, argv, "lm", &isLM, !isMT);
LoadParamInt(argc, argv, "nhead", &nhead, 8);
LoadParamBool(argc, argv, "freeotf", &isMemFreeOTF, false);
......@@ -229,6 +229,8 @@ void T2TModel::GetParams(XList &list)
list.Add(&encoder.attLayerNorms[i].w);
list.Add(&encoder.attLayerNorms[i].b);
}
list.Add(&encoder.embedder.w);
if(isMT){
for(int i = 0; i < decoder.nlayer; i++){
......@@ -251,9 +253,9 @@ void T2TModel::GetParams(XList &list)
list.Add(&decoder.attLayerNorms[i].w);
list.Add(&decoder.attLayerNorms[i].b);
}
list.Add(&decoder.embedder.w);
}
list.Add(&encoder.embedder.w);
}
/*
......
......@@ -191,7 +191,7 @@ void T2TTrainer::Train(const char * fn, const char * validFN, const char * model
/* label smoothed gold standard (if needed) */
XTensor goldSmoothed;
while (LoadBatch(file, true, &batch, &padding, &gold, NULL, vSize, vSizeTgt,
while (LoadBatch(file, model->isLM, &batch, &padding, &gold, NULL, vSize, vSizeTgt,
sBatchSize, wBatchSize, isLenSorted, wc, devID, mem))
{
......@@ -341,7 +341,7 @@ void T2TTrainer::Test(const char * fn, const char * ofn, T2TModel * model)
ClearBuf();
while(LoadBatch(file, true, &batch, &padding, &gold, seqs, vSize, vSizeTgt,
while(LoadBatch(file, model->isLM, &batch, &padding, &gold, seqs, vSize, vSizeTgt,
1, 1, false, wc, devID, mem))
{
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论