Commit 04412ff1 by xiaotong

read and predict

parent f9cfdf9a
......@@ -23,5 +23,61 @@
namespace transformer
{
/* constructor: the predictor starts with no model and no state bundle
   attached. Both pointers are set explicitly so that they never hold an
   indeterminate value before Read() is called. */
T2TPredictor::T2TPredictor()
{
    m = NULL;
    cur = NULL;
}
/* destructor: the body is empty, so nothing is released here and the
   pointers stored by Read() are left untouched */
T2TPredictor::~T2TPredictor()
{
}
/*
attach the predictor to a model and to the current bundle of states
>> model - the t2t model that keeps the network created so far
>> current - a set of states. It keeps
             1) hypotheses (states)
             2) probabilities of hypotheses
             3) parts of the network for expanding to the next state
Only the two pointers are stored; nothing is copied.
*/
void T2TPredictor::Read(T2TModel * model, T2TStateBundle * current)
{
    this->cur = current;
    this->m = model;
}
/*
predict the next state
>> next - next states (assuming that the current state has been read)
NOTE(review): this looks unfinished - "next" is never written, the widened
index tensor is allocated but neither filled nor used, and the decoder output
is dropped when the function returns. Confirm against the follow-up commits.
*/
void T2TPredictor::Predict(T2TStateBundle * next)
{
    AttDecoder * dec = m->decoder;

    /* item 0 of the decoder layers: word indices of previous positions */
    XTensor * prevWords = (XTensor*)cur->decoderLayers.GetItem(0);

    /* word indices of positions up to the next state: same size along the
       first dimension, one more entry along the second */
    XTensor input;
    const int rowNum = prevWords->GetDim(0);
    const int colNum = prevWords->GetDim(1) + 1;
    InitTensor2D(&input, rowNum, colNum,
                 prevWords->dataType, prevWords->devID, prevWords->mem);

    /* TODO: concatenate the input tensors (not implemented yet) */

    /* prediction probabilities */
    XTensor output;

    /* encoder output, fetched with index -1 from the encoder layers */
    XTensor * encOut = (XTensor*)cur->encoderLayers.GetItem(-1);

    /* empty tensor passed for both mask arguments of the decoder */
    XTensor noMask;

    /* make the decoding network */
    output = dec->Make(cur->prediction, *encOut, noMask, noMask, false);
}
}
......@@ -29,10 +29,11 @@ namespace transformer
{
/* state for search. It keeps the path (back-pointer), prediction distribution,
and etc. */
and etc. It can be regarded as a hypothesis in translation. */
class T2TState
{
/* we assume that the prediction is an integer number */
public:
/* we assume that the prediction is an integer */
int prediction;
/* probability of the prediction */
......@@ -41,16 +42,17 @@ class T2TState
/* probability of the path */
float pathProb;
/* pointer to the last state */
/* pointer to the previous state */
T2TState * last;
/* pointers to the following states */
XList * followings;
};
/* a bundle of states */
class T2TStateBundle
{
public:
/* predictions */
XTensor prediction;
/* distribution of every prediction (last state of the path) */
XTensor probs;
......@@ -73,7 +75,10 @@ class T2TStateBundle
class T2TPredictor
{
/* pointer to the transformer model */
T2TModel * model;
T2TModel * m;
/* current state */
T2TStateBundle * cur;
public:
/* constructor */
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论