Commit f14850ad by xiaotong

more code

parent f701be0e
...@@ -38,17 +38,27 @@ T2TPredictor::~T2TPredictor() ...@@ -38,17 +38,27 @@ T2TPredictor::~T2TPredictor()
} }
/* /*
create an initial state
>> model - the t2t model
>> state - the state to be initialized
*/
void T2TPredictor::Init(T2TModel * model, T2TStateBundle * state)
{
}
/*
read a state read a state
>> model - the t2t model that keeps the network created so far >> model - the t2t model that keeps the network created so far
>> current - a set of states. It keeps >> state - a set of states. It keeps
1) hypotheses (states) 1) hypotheses (states)
2) probablities of hypotheses 2) probablities of hypotheses
3) parts of the network for expanding to the next state 3) parts of the network for expanding to the next state
*/ */
void T2TPredictor::Read(T2TModel * model, T2TStateBundle * current) void T2TPredictor::Read(T2TModel * model, T2TStateBundle * state)
{ {
m = model; m = model;
cur = current; s = state;
} }
/* /*
...@@ -63,25 +73,25 @@ void T2TPredictor::Predict(T2TStateBundle * next) ...@@ -63,25 +73,25 @@ void T2TPredictor::Predict(T2TStateBundle * next)
AttDecoder &decoder = *m->decoder; AttDecoder &decoder = *m->decoder;
/* word indices of previous positions */ /* word indices of previous positions */
XTensor &inputLast = *(XTensor*)cur->decoderLayers.GetItem(0); XTensor &inputLast = *(XTensor*)s->decoderLayers.GetItem(0);
/* word indices of positions up to next state */ /* word indices of positions up to next state */
XTensor &input = *NewTensor(); XTensor &input = *NewTensor();
input = Concatenate(inputLast, cur->prediction, inputLast.GetDim(-1)); input = Concatenate(inputLast, s->prediction, inputLast.GetDim(-1));
/* prediction probabilities */ /* prediction probabilities */
XTensor &output = next->prediction; XTensor &output = next->prediction;
/* encoder output */ /* encoder output */
XTensor &outputEnc = *(XTensor*)cur->encoderLayers.GetItem(-1); XTensor &outputEnc = *(XTensor*)s->encoderLayers.GetItem(-1);
/* empty tensors (for masking?) */ /* empty tensors (for masking?) */
XTensor nullMask; XTensor nullMask;
/* make the decoding network and generate the output probabilities */ /* make the decoding network and generate the output probabilities */
output = decoder.Make(cur->prediction, outputEnc, nullMask, nullMask, false); output = decoder.Make(s->prediction, outputEnc, nullMask, nullMask, false);
next->encoderLayers.AddList(&cur->encoderLayers); next->encoderLayers.AddList(&s->encoderLayers);
next->decoderLayers.Add(&input); next->decoderLayers.Add(&input);
next->decoderLayers.Add(&output); next->decoderLayers.Add(&output);
} }
......
...@@ -78,7 +78,7 @@ class T2TPredictor ...@@ -78,7 +78,7 @@ class T2TPredictor
T2TModel * m; T2TModel * m;
/* current state */ /* current state */
T2TStateBundle * cur; T2TStateBundle * s;
public: public:
/* constructor */ /* constructor */
...@@ -87,8 +87,11 @@ public: ...@@ -87,8 +87,11 @@ public:
/* de-constructor */ /* de-constructor */
~T2TPredictor(); ~T2TPredictor();
/* create an initial state */
void Init(T2TModel * model, T2TStateBundle * state);
/* read a state */ /* read a state */
void Read(T2TModel * model, T2TStateBundle * current); void Read(T2TModel * model, T2TStateBundle * state);
/* predict the next state */ /* predict the next state */
void Predict(T2TStateBundle * next); void Predict(T2TStateBundle * next);
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论