Commit 11cd04a3 by xiaotong

predict the next state given the current state

parent 898ee241
......@@ -20,6 +20,9 @@
*/
#include "T2TPredictor.h"
#include "../../tensor/core/CHeader.h"
using namespace nts;
namespace transformer
{
......@@ -54,20 +57,20 @@ predict the next state
*/
void T2TPredictor::Predict(T2TStateBundle * next)
{
next->decoderLayers.Clear();
next->encoderLayers.Clear();
AttDecoder &decoder = *m->decoder;
/* word indices of previous positions */
XTensor &inputLast = *(XTensor*)cur->decoderLayers.GetItem(0);
/* word indices of positions up to next state */
XTensor input;
InitTensor2D(&input, inputLast.GetDim(0), inputLast.GetDim(1) + 1,
inputLast.dataType, inputLast.devID, inputLast.mem);
/* concatenate the input tensors */
XTensor &input = *NewTensor();
input = Concatenate(inputLast, cur->prediction, inputLast.GetDim(-1));
/* prediction probabilities */
XTensor output;
XTensor &output = next->prediction;
/* encoder output */
XTensor &outputEnc = *(XTensor*)cur->encoderLayers.GetItem(-1);
......@@ -75,8 +78,12 @@ void T2TPredictor::Predict(T2TStateBundle * next)
/* empty tensors (for masking?) */
XTensor nullMask;
/* make the decoding network */
/* make the decoding network and generate the output probabilities */
output = decoder.Make(cur->prediction, outputEnc, nullMask, nullMask, false);
next->encoderLayers.AddList(&cur->encoderLayers);
next->decoderLayers.Add(&input);
next->decoderLayers.Add(&output);
}
}
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论