Commit 9449403d by xiaotong

save the encoding output layer for inference

parent b3c2cf56
......@@ -211,6 +211,11 @@ void T2TModel::MakeMT(XTensor &inputEnc, XTensor &inputDec, XTensor &output, XTe
decoding = MakeDecoder(inputDec, encoding, maskDec, maskEncDec, isTraining);
outputLayer->Make(decoding, output);
+ encoding.SetName(ENCODING_NAME);
+ decoding.SetName(DECODING_NAME);
+ output.SetName(OUTPUT_NAME);
+ inputEnc.SetName(ENCODING_INPUT_NAME);
}
/*
......
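For context: the hunk above tags nodes of the network with well-known names (ENCODING_NAME and friends) so that inference code can later fetch them by name, which is exactly what the predictor below does with XLink::SearchNode. Below is a minimal stand-alone sketch of that name-and-look-up pattern, using hypothetical Node/Graph stand-ins rather than the real XTensor/XLink classes:

// Illustrative only: hypothetical Node/Graph types standing in for
// XTensor/XLink; SetName/SearchNode mirror the calls in the diff.
#include <cassert>
#include <string>
#include <unordered_map>

struct Node {
    std::string name;
    void SetName(const std::string & n) { name = n; }
};

struct Graph {
    std::unordered_map<std::string, Node *> byName;

    void Register(Node * node) { byName[node->name] = node; }

    /* analogous to XLink::SearchNode(top, ENCODING_NAME) */
    Node * SearchNode(const std::string & name) {
        auto it = byName.find(name);
        return it == byName.end() ? nullptr : it->second;
    }
};

int main() {
    Graph graph;
    Node encoding;
    encoding.SetName("encoding");      /* mirrors encoding.SetName(ENCODING_NAME) */
    graph.Register(&encoding);

    /* at inference time the encoder output is reachable by name alone */
    assert(graph.SearchNode("encoding") == &encoding);
    return 0;
}

In the real code the names are attached directly to graph nodes and the search walks the graph from the top-most tensor, so no separate registry is needed; the sketch only illustrates the contract.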
......@@ -31,6 +31,12 @@
namespace transformer
{
+ #define ENCODING_NAME "encoding"
+ #define DECODING_NAME "decoding"
+ #define ENCODING_INPUT_NAME "encoding_input"
+ #define DECODING_INPUT_NAME "decoding_input"
+ #define OUTPUT_NAME "output"
/* a transformer model that keeps parameters of the encoder,
the decoder and the output layer (softmax). Also, it creates
the network used in transformer. */
......
......@@ -40,11 +40,19 @@ T2TPredictor::~T2TPredictor()
/*
create an initial state
>> model - the t2t model
+ >> top - the top-most layer of the network
>> state - the state to be initialized
*/
- void T2TPredictor::Init(T2TModel * model, T2TStateBundle * state)
+ void T2TPredictor::Init(T2TModel * model, XTensor * top, T2TStateBundle * state)
{
+ state->layersEncoding.Clear();
+ state->layersDecoding.Clear();
+ XTensor * encoding = XLink::SearchNode(top, ENCODING_NAME);
+ CheckNTErrors(encoding != NULL, "No encoding layers found!");
+ state->layersEncoding.Add(encoding);
+ state->layersDecoding.Add(NULL);
}
/*
......@@ -53,7 +61,7 @@ read a state
>> state - a set of states. It keeps
1) hypotheses (states)
2) probabilities of hypotheses
- 3) parts of the network for expanding to the next state
+ 3) parts of the network for expanding toward the next state
*/
void T2TPredictor::Read(T2TModel * model, T2TStateBundle * state)
{
......@@ -67,23 +75,26 @@ predict the next state
*/
void T2TPredictor::Predict(T2TStateBundle * next)
{
- next->decoderLayers.Clear();
- next->encoderLayers.Clear();
+ next->layersEncoding.Clear();
+ next->layersDecoding.Clear();
AttDecoder &decoder = *m->decoder;
/* word indices of previous positions */
- XTensor &inputLast = *(XTensor*)s->decoderLayers.GetItem(0);
+ XTensor * inputLast = (XTensor*)s->layersDecoding.GetItem(0);
/* word indices of positions up to next state */
- XTensor &input = *NewTensor();
- input = Concatenate(inputLast, s->prediction, inputLast.GetDim(-1));
+ if(inputLast == NULL)
+     input = s->prediction;
+ else
+     input = Concatenate(*inputLast, s->prediction, inputLast->GetDim(-1));
/* prediction probabilities */
XTensor &output = next->prediction;
/* encoder output */
- XTensor &outputEnc = *(XTensor*)s->encoderLayers.GetItem(-1);
+ XTensor &outputEnc = *(XTensor*)s->layersEncoding.GetItem(-1);
/* empty tensors used as placeholder masks */
XTensor nullMask;
......@@ -91,9 +102,9 @@ void T2TPredictor::Predict(T2TStateBundle * next)
/* make the decoding network and generate the output probabilities */
output = decoder.Make(s->prediction, outputEnc, nullMask, nullMask, false);
- next->encoderLayers.AddList(&s->encoderLayers);
- next->decoderLayers.Add(&input);
- next->decoderLayers.Add(&output);
+ next->layersEncoding.AddList(&s->layersEncoding);
+ next->layersDecoding.Add(&input);
+ next->layersDecoding.Add(&output);
}
}
......
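The new branch in Predict is the core of incremental decoding: Init seeded layersDecoding with NULL, so at the first step there is no input history and the decoder input is the prediction alone; at every later step the prediction is concatenated onto the stored history along the last dimension. Below is a stand-alone sketch of that bookkeeping, with a std::vector of word indices standing in for XTensor:

// Illustrative only: std::vector<int> stands in for an XTensor of word
// indices; Concatenate mirrors the append-along-last-dimension call.
#include <cstdio>
#include <vector>

using Ids = std::vector<int>;

Ids Concatenate(const Ids & history, const Ids & prediction) {
    Ids out(history);
    out.insert(out.end(), prediction.begin(), prediction.end());
    return out;
}

int main() {
    const Ids * inputLast = nullptr;      /* NULL marks "no previous input" */
    Ids input;
    const int predicted[] = {5, 17, 2};   /* hypothetical word indices */

    for (int step = 0; step < 3; ++step) {
        Ids prediction = { predicted[step] };
        if (inputLast == nullptr)
            input = prediction;                       /* first step: prediction only */
        else
            input = Concatenate(*inputLast, prediction);
        inputLast = &input;                           /* becomes the next step's history */
        std::printf("step %d: decoder input length %zu\n", step, input.size());
    }
    return 0;
}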
......@@ -61,10 +61,10 @@ public:
/* layers on the encoder side. We actually use the encoder output instead
of all hidden layers. */
- XList encoderLayers;
+ XList layersEncoding;
/* layers on the decoder side */
- XList decoderLayers;
+ XList layersDecoding;
};
/* The predictor reads the current state and then predicts the next.
......@@ -88,7 +88,7 @@ public:
~T2TPredictor();
/* create an initial state */
- void Init(T2TModel * model, T2TStateBundle * state);
+ void Init(T2TModel * model, XTensor * top, T2TStateBundle * state);
/* read a state */
void Read(T2TModel * model, T2TStateBundle * state);
......
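The header change above also implies a calling convention that is not part of this diff: the caller builds the network once (MakeMT now names its nodes) and hands the top-most tensor to Init, which clears both lists, looks up the node tagged ENCODING_NAME, and seeds layersDecoding with NULL. Below is a self-contained sketch of that initialization contract, again with simplified stand-in types rather than the NiuTensor classes:

// Illustrative only: simplified stand-ins for XTensor/T2TStateBundle.
#include <cassert>
#include <vector>

struct Tensor { const char * name; };

struct StateBundle {
    std::vector<Tensor *> layersEncoding;   /* encoder output only */
    std::vector<Tensor *> layersDecoding;   /* per-step decoder inputs/outputs */
};

/* mirrors T2TPredictor::Init: cache the named encoder output and mark
   the decoding history as empty */
void Init(Tensor * encoding, StateBundle * state) {
    state->layersEncoding.clear();
    state->layersDecoding.clear();
    assert(encoding != nullptr && "No encoding layers found!");
    state->layersEncoding.push_back(encoding);
    state->layersDecoding.push_back(nullptr);   /* first Predict sees no history */
}

int main() {
    Tensor encoderOutput{"encoding"};
    StateBundle state;
    Init(&encoderOutput, &state);
    assert(state.layersEncoding.size() == 1 && state.layersDecoding[0] == nullptr);
    return 0;
}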