Commit f14850ad by xiaotong

more code

parent f701be0e
......@@ -38,17 +38,27 @@ T2TPredictor::~T2TPredictor()
}
/*
create an initial state
>> model - the t2t model
>> state - the state to be initialized
*/
void T2TPredictor::Init(T2TModel * model, T2TStateBundle * state)
{
}
/*
read a state
>> model - the t2t model that keeps the network created so far
>> current - a set of states. It keeps
>> state - a set of states. It keeps
1) hypotheses (states)
2) probablities of hypotheses
3) parts of the network for expanding to the next state
*/
void T2TPredictor::Read(T2TModel * model, T2TStateBundle * current)
void T2TPredictor::Read(T2TModel * model, T2TStateBundle * state)
{
m = model;
cur = current;
s = state;
}
/*
......@@ -63,25 +73,25 @@ void T2TPredictor::Predict(T2TStateBundle * next)
AttDecoder &decoder = *m->decoder;
/* word indices of previous positions */
XTensor &inputLast = *(XTensor*)cur->decoderLayers.GetItem(0);
XTensor &inputLast = *(XTensor*)s->decoderLayers.GetItem(0);
/* word indices of positions up to next state */
XTensor &input = *NewTensor();
input = Concatenate(inputLast, cur->prediction, inputLast.GetDim(-1));
input = Concatenate(inputLast, s->prediction, inputLast.GetDim(-1));
/* prediction probabilities */
XTensor &output = next->prediction;
/* encoder output */
XTensor &outputEnc = *(XTensor*)cur->encoderLayers.GetItem(-1);
XTensor &outputEnc = *(XTensor*)s->encoderLayers.GetItem(-1);
/* empty tensors (for masking?) */
XTensor nullMask;
/* make the decoding network and generate the output probabilities */
output = decoder.Make(cur->prediction, outputEnc, nullMask, nullMask, false);
output = decoder.Make(s->prediction, outputEnc, nullMask, nullMask, false);
next->encoderLayers.AddList(&cur->encoderLayers);
next->encoderLayers.AddList(&s->encoderLayers);
next->decoderLayers.Add(&input);
next->decoderLayers.Add(&output);
}
......
......@@ -78,7 +78,7 @@ class T2TPredictor
T2TModel * m;
/* current state */
T2TStateBundle * cur;
T2TStateBundle * s;
public:
/* constructor */
......@@ -87,8 +87,11 @@ public:
/* de-constructor */
~T2TPredictor();
/* create an initial state */
void Init(T2TModel * model, T2TStateBundle * state);
/* read a state */
void Read(T2TModel * model, T2TStateBundle * current);
void Read(T2TModel * model, T2TStateBundle * state);
/* predict the next state */
void Predict(T2TStateBundle * next);
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论