Commit 9449403d by xiaotong

save the encoding output layer for inference

parent b3c2cf56
...@@ -211,6 +211,11 @@ void T2TModel::MakeMT(XTensor &inputEnc, XTensor &inputDec, XTensor &output, XTe ...@@ -211,6 +211,11 @@ void T2TModel::MakeMT(XTensor &inputEnc, XTensor &inputDec, XTensor &output, XTe
decoding = MakeDecoder(inputDec, encoding, maskDec, maskEncDec, isTraining); decoding = MakeDecoder(inputDec, encoding, maskDec, maskEncDec, isTraining);
outputLayer->Make(decoding, output); outputLayer->Make(decoding, output);
encoding.SetName(ENCODING_NAME);
decoding.SetName(DECODING_NAME);
output.SetName(OUTPUT_NAME);
inputEnc.SetName(ENCODING_INPUT_NAME);
} }
/* /*
......
...@@ -31,6 +31,12 @@ ...@@ -31,6 +31,12 @@
namespace transformer namespace transformer
{ {
#define ENCODING_NAME "encoding"
#define DECODING_NAME "decoding"
#define ENCODING_INPUT_NAME "encoding_input"
#define DECODING_INPUT_NAME "decoding_input"
#define OUTPUT_NAME "output"
/* a transformer model that keeps parameters of the encoder, /* a transformer model that keeps parameters of the encoder,
the decoder and the output layer (softmax). Also, it creates the decoder and the output layer (softmax). Also, it creates
the network used in transformer. */ the network used in transformer. */
......
...@@ -40,11 +40,19 @@ T2TPredictor::~T2TPredictor() ...@@ -40,11 +40,19 @@ T2TPredictor::~T2TPredictor()
/* /*
create an initial state create an initial state
>> model - the t2t model >> model - the t2t model
>> top - the top-most layer of the network
>> state - the state to be initialized >> state - the state to be initialized
*/ */
void T2TPredictor::Init(T2TModel * model, T2TStateBundle * state) void T2TPredictor::Init(T2TModel * model, XTensor * top, T2TStateBundle * state)
{ {
state->layersEncoding.Clear();
state->layersDecoding.Clear();
XTensor * encoding = XLink::SearchNode(top, ENCODING_NAME);
CheckNTErrors(encoding != NULL, "No encoding layers found!");
state->layersEncoding.Add(encoding);
state->layersDecoding.Add(NULL);
} }
/* /*
...@@ -53,7 +61,7 @@ read a state ...@@ -53,7 +61,7 @@ read a state
>> state - a set of states. It keeps >> state - a set of states. It keeps
1) hypotheses (states) 1) hypotheses (states)
2) probabilities of hypotheses 2) probabilities of hypotheses
3) parts of the network for expanding to the next state 3) parts of the network for expanding toward the next state
*/ */
void T2TPredictor::Read(T2TModel * model, T2TStateBundle * state) void T2TPredictor::Read(T2TModel * model, T2TStateBundle * state)
{ {
...@@ -67,23 +75,26 @@ predict the next state ...@@ -67,23 +75,26 @@ predict the next state
*/ */
void T2TPredictor::Predict(T2TStateBundle * next) void T2TPredictor::Predict(T2TStateBundle * next)
{ {
next->decoderLayers.Clear(); next->layersEncoding.Clear();
next->encoderLayers.Clear(); next->layersDecoding.Clear();
AttDecoder &decoder = *m->decoder; AttDecoder &decoder = *m->decoder;
/* word indices of previous positions */ /* word indices of previous positions */
XTensor &inputLast = *(XTensor*)s->decoderLayers.GetItem(0); XTensor * inputLast = (XTensor*)s->layersDecoding.GetItem(0);
/* word indices of positions up to next state */ /* word indices of positions up to next state */
XTensor &input = *NewTensor(); XTensor &input = *NewTensor();
input = Concatenate(inputLast, s->prediction, inputLast.GetDim(-1)); if(inputLast == NULL)
input = s->prediction;
else
input = Concatenate(*inputLast, s->prediction, inputLast->GetDim(-1));
/* prediction probabilities */ /* prediction probabilities */
XTensor &output = next->prediction; XTensor &output = next->prediction;
/* encoder output */ /* encoder output */
XTensor &outputEnc = *(XTensor*)s->encoderLayers.GetItem(-1); XTensor &outputEnc = *(XTensor*)s->layersEncoding.GetItem(-1);
/* empty tensors (for masking?) */ /* empty tensors (for masking?) */
XTensor nullMask; XTensor nullMask;
...@@ -91,9 +102,9 @@ void T2TPredictor::Predict(T2TStateBundle * next) ...@@ -91,9 +102,9 @@ void T2TPredictor::Predict(T2TStateBundle * next)
/* make the decoding network and generate the output probabilities */ /* make the decoding network and generate the output probabilities */
output = decoder.Make(s->prediction, outputEnc, nullMask, nullMask, false); output = decoder.Make(s->prediction, outputEnc, nullMask, nullMask, false);
next->encoderLayers.AddList(&s->encoderLayers); next->layersEncoding.AddList(&s->layersEncoding);
next->decoderLayers.Add(&input); next->layersDecoding.Add(&input);
next->decoderLayers.Add(&output); next->layersDecoding.Add(&output);
} }
} }
......
...@@ -61,10 +61,10 @@ public: ...@@ -61,10 +61,10 @@ public:
/* layers on the encoder side. We actually use the encoder output instead /* layers on the encoder side. We actually use the encoder output instead
of all hidden layers. */ of all hidden layers. */
XList encoderLayers; XList layersEncoding;
/* layers on the decoder side */ /* layers on the decoder side */
XList decoderLayers; XList layersDecoding;
}; };
/* The predictor reads the current state and then predicts the next. /* The predictor reads the current state and then predicts the next.
...@@ -88,7 +88,7 @@ public: ...@@ -88,7 +88,7 @@ public:
~T2TPredictor(); ~T2TPredictor();
/* create an initial state */ /* create an initial state */
void Init(T2TModel * model, T2TStateBundle * state); void Init(T2TModel * model, XTensor * top, T2TStateBundle * state);
/* read a state */ /* read a state */
void Read(T2TModel * model, T2TStateBundle * state); void Read(T2TModel * model, T2TStateBundle * state);
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论