Commit d38bed87 by xiaotong

debugging

parent b155f4c1
...@@ -23,6 +23,7 @@ ...@@ -23,6 +23,7 @@
#include "T2TModel.h" #include "T2TModel.h"
#include "T2TUtility.h" #include "T2TUtility.h"
#include "../../tensor/core/CHeader.h" #include "../../tensor/core/CHeader.h"
#include "../../tensor/XUtility.h"
namespace transformer namespace transformer
{ {
...@@ -366,7 +367,12 @@ void T2TModel::MakeMTMaskDec(XTensor &inputEnc, XTensor &inputDec, ...@@ -366,7 +367,12 @@ void T2TModel::MakeMTMaskDec(XTensor &inputEnc, XTensor &inputDec,
This matrix can be used to block the attention to current or following words in This matrix can be used to block the attention to current or following words in
a given sequence. */ a given sequence. */
_SetDataLowTri(&maskDec, 1e9F, 0); _SetDataLowTri(&maskDec, 1e9F, 0);
maskDec.Dump(stderr, "mask: ");
_ScaleAndShiftMe(&maskDec, 1.0F, -1e9F); _ScaleAndShiftMe(&maskDec, 1.0F, -1e9F);
maskDec.Dump(stderr, "mask: ");
/* encoder-decoder mask that prevents the attention to padding dummy words */ /* encoder-decoder mask that prevents the attention to padding dummy words */
dims[inputDec.order + 1] = inputEnc.GetDim(inputEnc.order - 1); dims[inputDec.order + 1] = inputEnc.GetDim(inputEnc.order - 1);
...@@ -445,6 +451,8 @@ dump the parameters ...@@ -445,6 +451,8 @@ dump the parameters
*/ */
void T2TModel::Dump(const char * fn) void T2TModel::Dump(const char * fn)
{ {
double startT = GetClockSec();
FILE * file = fopen(fn, "wb"); FILE * file = fopen(fn, "wb");
CheckNTErrors(file, "Cannot open the model file"); CheckNTErrors(file, "Cannot open the model file");
...@@ -459,12 +467,16 @@ void T2TModel::Dump(const char * fn) ...@@ -459,12 +467,16 @@ void T2TModel::Dump(const char * fn)
fclose(file); fclose(file);
XPRINT(0, stderr, "[INFO] model saved\n"); double elapsed = GetClockSec() - startT;
XPRINT1(0, stderr, "[INFO] model saved (took %.1fs)\n", elapsed);
} }
/* read the parameters */ /* read the parameters */
void T2TModel::Read(const char * fn) void T2TModel::Read(const char * fn)
{ {
double startT = GetClockSec();
FILE * file = fopen(fn, "rb"); FILE * file = fopen(fn, "rb");
CheckNTErrors(file, "Cannot open the model file"); CheckNTErrors(file, "Cannot open the model file");
...@@ -479,7 +491,9 @@ void T2TModel::Read(const char * fn) ...@@ -479,7 +491,9 @@ void T2TModel::Read(const char * fn)
fclose(file); fclose(file);
XPRINT(0, stderr, "[INFO] model loaded\n"); double elapsed = GetClockSec() - startT;
XPRINT1(0, stderr, "[INFO] model loaded (took %.1fs)\n", elapsed);
} }
} }
...@@ -193,6 +193,11 @@ void T2TPredictor::Predict(T2TStateBundle * next, XTensor * encoding, XTensor * ...@@ -193,6 +193,11 @@ void T2TPredictor::Predict(T2TStateBundle * next, XTensor * encoding, XTensor *
/* decoder mask */ /* decoder mask */
m->MakeMTMaskDec(*inputEnc, inputDec, *paddingEnc, paddingDec, maskDec, maskEncDec); m->MakeMTMaskDec(*inputEnc, inputDec, *paddingEnc, paddingDec, maskDec, maskEncDec);
inputDec.Dump(stderr, "inputdec: ");
//encoding->Dump(stderr, "encoding: ");
maskDec.Dump(stderr, "maskdec: ");
maskEncDec.Dump(stderr, "mask-enc-dec: ");
/* make the decoding network */ /* make the decoding network */
decoding = decoder.Make(inputDec, *encoding, maskDec, maskEncDec, false); decoding = decoder.Make(inputDec, *encoding, maskDec, maskEncDec, false);
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论