Commit 0aac9d31 by xiaotong

code of searcher and predictor

parent d68b19b6
......@@ -43,16 +43,16 @@ create an initial state
>> top - the top-most layer of the network
>> state - the state to be initialized
*/
void T2TPredictor::Init(T2TModel * model, XTensor * top, T2TStateBundle * state)
void T2TPredictor::Create(T2TModel * model, XTensor * top, T2TStateBundle * state)
{
state->layersEncoding.Clear();
state->layersDecoding.Clear();
state->layersEnc.Clear();
state->layersDec.Clear();
XTensor * encoding = XLink::SearchNode(top, ENCODING_NAME);
CheckNTErrors(encoding != NULL, "No encoding layers found!");
state->layersEncoding.Add(encoding);
state->layersDecoding.Add(NULL);
state->layersEnc.Add(encoding);
state->layersDec.Add(NULL);
}
/*
......@@ -72,39 +72,46 @@ void T2TPredictor::Read(T2TModel * model, T2TStateBundle * state)
/*
predict the next state
>> next - next states (assuming that the current state has been read)
>> encoding - encoder output
>> inputEnc - input of the encoder
>> paddingEnc - padding of the encoder
*/
void T2TPredictor::Predict(T2TStateBundle * next)
void T2TPredictor::Predict(T2TStateBundle * next, XTensor * encoding, XTensor * inputEnc, XTensor * paddingEnc)
{
next->layersEncoding.Clear();
next->layersDecoding.Clear();
next->layersEnc.Clear();
next->layersDec.Clear();
AttDecoder &decoder = *m->decoder;
/* word indices of previous positions */
XTensor * inputLast = (XTensor*)s->layersDecoding.GetItem(0);
XTensor * inputLast = (XTensor*)s->layersDec.GetItem(0);
/* word indices of positions up to next state */
XTensor &input = *NewTensor();
XTensor &inputDec = *NewTensor();
if(inputLast == NULL)
input = s->prediction;
inputDec = s->prediction;
else
input = Concatenate(*inputLast, s->prediction, inputLast->GetDim(-1));
inputDec = Concatenate(*inputLast, s->prediction, inputLast->GetDim(-1));
/* prediction probabilities */
XTensor &output = next->prediction;
XTensor &output = next->score;
/* encoder output */
XTensor &outputEnc = *(XTensor*)s->layersEncoding.GetItem(-1);
XTensor paddingDec;
InitTensor3D(&paddingDec, inputDec.GetDim(0), inputDec.GetDim(1), m->outputLayer->vSize, X_INT);
SetDataFixedInt(paddingDec, 1);
/* empty tensors (for masking?) */
XTensor nullMask;
XTensor maskDec;
XTensor maskEncDec;
/* decoder mask */
m->MakeMTMaskDec(*inputEnc, inputDec, *paddingEnc, paddingDec, maskDec, maskEncDec);
/* make the decoding network and generate the output probabilities */
output = decoder.Make(s->prediction, outputEnc, nullMask, nullMask, false);
output = decoder.Make(inputDec, *encoding, maskDec, maskEncDec, false);
next->layersEncoding.AddList(&s->layersEncoding);
next->layersDecoding.Add(&input);
next->layersDecoding.Add(&output);
next->layersEnc.AddList(&s->layersEnc);
next->layersDec.Add(&inputDec);
next->layersDec.Add(&output);
}
}
......
......@@ -36,11 +36,11 @@ public:
/* we assume that the prediction is an integer */
int prediction;
/* probability of the prediction */
float prob;
/* score of the prediction */
float score;
/* probability of the path */
float pathProb;
/* score of the path */
float scorePath;
/* pointer to the previous state */
T2TState * last;
......@@ -53,18 +53,18 @@ public:
/* predictions */
XTensor prediction;
/* distribution of every prediction (last state of the path) */
XTensor probs;
/* score of every prediction (last state of the path) */
XTensor score;
/* distribution of every path */
XTensor pathProbs;
/* score of every path */
XTensor scorePath;
/* layers on the encoder side. We actually use the encoder output instead
of all hidden layers. */
XList layersEncoding;
XList layersEnc;
/* layers on the decoder side */
XList layersDecoding;
XList layersDec;
};
/* The predictor reads the current state and then predicts the next.
......@@ -74,6 +74,7 @@ public:
indices, hidden states, embeddings and etc.). */
class T2TPredictor
{
private:
/* pointer to the transformer model */
T2TModel * m;
......@@ -88,13 +89,13 @@ public:
~T2TPredictor();
/* create an initial state */
void Init(T2TModel * model, XTensor * top, T2TStateBundle * state);
void Create(T2TModel * model, XTensor * top, T2TStateBundle * state);
/* read a state */
void Read(T2TModel * model, T2TStateBundle * state);
/* predict the next state */
void Predict(T2TStateBundle * next);
void Predict(T2TStateBundle * next, XTensor * encoding, XTensor * inputEnc, XTensor * paddingEnc);
};
}
......
......@@ -20,6 +20,8 @@
*/
#include "T2TSearch.h"
#include "T2TUtility.h"
#include "../../tensor/core/CHeader.h"
using namespace nts;
......@@ -27,6 +29,16 @@ namespace transformer
{
/*
initialize the model
>> argc - number of arguments
>> argv - list of pointers to the arguments
*/
void T2TSearch::InitModel(int argc, char ** argv)
{
LoadParamInt(argc, argv, "beamsize", &beamSize, 1);
}
/*
search for the most promising states
>> model - the transformer model
>> input - input of the model
......@@ -37,6 +49,7 @@ void T2TSearch::Search(T2TModel * model, XTensor * input, XTensor * padding, XTe
{
XTensor maskEnc;
XTensor encoding;
T2TPredictor predictor;
/* encoder mask */
model->MakeMTMaskEnc(*input, *padding, maskEnc);
......@@ -45,30 +58,28 @@ void T2TSearch::Search(T2TModel * model, XTensor * input, XTensor * padding, XTe
encoding = model->MakeEncoder(*input, maskEnc, false);
encoding.SetName(ENCODING_NAME);
T2TPredictor predictor;
T2TStateBundle state1, state2;
T2TStateBundle * cur = &state1;
T2TStateBundle * next = &state2;
T2TStateBundle * states = new T2TStateBundle[maxLength];
T2TStateBundle * first = states;
/* initialize the predictor */
predictor.Init(model, &encoding, cur);
/* create the first state */
predictor.Create(model, &encoding, first);
/* generate the sequence from left-to-right */
/* generate the sequence from left to right */
for(int i = 0 ; i < maxLength; i++){
T2TStateBundle * cur = states + i;
T2TStateBundle * next = states + i + 1;
/* read the current state */
predictor.Read(model, cur);
/* predict the next state */
predictor.Predict(next);
predictor.Predict(next, &encoding, input, padding);
/* pruning */
Prune(next);
T2TStateBundle * backup = cur;
cur = next;
next = backup;
}
delete[] states;
}
/*
......@@ -77,11 +88,27 @@ beam pruning
*/
void T2TSearch::Prune(T2TStateBundle * beam)
{
int dims[MAX_TENSOR_DIM_NUM];
XTensor scoreTopK;
XTensor &score = beam->score;
XTensor &index = beam->prediction;
for(int i = 0; i < score.order; i++)
dims[i] = score.GetDim(i);
dims[score.order - 1] = beamSize;
InitTensor(&scoreTopK, score.order, score.dimSize, score.dataType,
1.0F, score.devID, score.mem);
InitTensor(&index, score.order, score.dimSize, X_INT,
1.0F, score.devID, score.mem);
TopK(score, scoreTopK, index, 0, beamSize);
}
/*
save the output sequences in a tensor
>> beam -
>> beam - the beam that keeps a number of states
*/
void T2TSearch::DumpOutput(T2TStateBundle * beam, XTensor * output)
{
......
......@@ -34,10 +34,16 @@ namespace transformer
The output can be the path with the highest model score. */
class T2TSearch
{
public:
private:
/* predictor */
T2TPredictor predictor;
/* max length of the generated sequence */
int maxLength;
/* beam size */
int beamSize;
public:
/* constructor */
T2TSearch() {};
......@@ -45,6 +51,9 @@ public:
/* de-constructor */
~T2TSearch() {};
/* initialize the model */
void InitModel(int argc, char ** argv);
/* search for the most promising states */
void Search(T2TModel * model, XTensor * input, XTensor * padding, XTensor * output);
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论