Commit 11cd04a3 by xiaotong

predict the next state given the current state

parent 898ee241
...@@ -20,6 +20,9 @@ ...@@ -20,6 +20,9 @@
*/ */
#include "T2TPredictor.h" #include "T2TPredictor.h"
#include "../../tensor/core/CHeader.h"
using namespace nts;
namespace transformer namespace transformer
{ {
...@@ -54,20 +57,20 @@ predict the next state ...@@ -54,20 +57,20 @@ predict the next state
*/ */
void T2TPredictor::Predict(T2TStateBundle * next) void T2TPredictor::Predict(T2TStateBundle * next)
{ {
next->decoderLayers.Clear();
next->encoderLayers.Clear();
AttDecoder &decoder = *m->decoder; AttDecoder &decoder = *m->decoder;
/* word indices of previous positions */ /* word indices of previous positions */
XTensor &inputLast = *(XTensor*)cur->decoderLayers.GetItem(0); XTensor &inputLast = *(XTensor*)cur->decoderLayers.GetItem(0);
/* word indices of positions up to next state */ /* word indices of positions up to next state */
XTensor input; XTensor &input = *NewTensor();
InitTensor2D(&input, inputLast.GetDim(0), inputLast.GetDim(1) + 1, input = Concatenate(inputLast, cur->prediction, inputLast.GetDim(-1));
inputLast.dataType, inputLast.devID, inputLast.mem);
/* concatenate the input tensors */
/* prediction probabilities */ /* prediction probabilities */
XTensor output; XTensor &output = next->prediction;
/* encoder output */ /* encoder output */
XTensor &outputEnc = *(XTensor*)cur->encoderLayers.GetItem(-1); XTensor &outputEnc = *(XTensor*)cur->encoderLayers.GetItem(-1);
...@@ -75,8 +78,12 @@ void T2TPredictor::Predict(T2TStateBundle * next) ...@@ -75,8 +78,12 @@ void T2TPredictor::Predict(T2TStateBundle * next)
/* empty tensors (for masking?) */ /* empty tensors (for masking?) */
XTensor nullMask; XTensor nullMask;
/* make the decoding network */ /* make the decoding network and generate the output probabilities */
output = decoder.Make(cur->prediction, outputEnc, nullMask, nullMask, false); output = decoder.Make(cur->prediction, outputEnc, nullMask, nullMask, false);
next->encoderLayers.AddList(&cur->encoderLayers);
next->decoderLayers.Add(&input);
next->decoderLayers.Add(&output);
} }
} }
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论