Commit 0aac9d31 by xiaotong

code of searcher and predictor

parent d68b19b6
...@@ -43,16 +43,16 @@ create an initial state ...@@ -43,16 +43,16 @@ create an initial state
>> top - the top-most layer of the network >> top - the top-most layer of the network
>> state - the state to be initialized >> state - the state to be initialized
*/ */
void T2TPredictor::Init(T2TModel * model, XTensor * top, T2TStateBundle * state) void T2TPredictor::Create(T2TModel * model, XTensor * top, T2TStateBundle * state)
{ {
state->layersEncoding.Clear(); state->layersEnc.Clear();
state->layersDecoding.Clear(); state->layersDec.Clear();
XTensor * encoding = XLink::SearchNode(top, ENCODING_NAME); XTensor * encoding = XLink::SearchNode(top, ENCODING_NAME);
CheckNTErrors(encoding != NULL, "No encoding layers found!"); CheckNTErrors(encoding != NULL, "No encoding layers found!");
state->layersEncoding.Add(encoding); state->layersEnc.Add(encoding);
state->layersDecoding.Add(NULL); state->layersDec.Add(NULL);
} }
/* /*
...@@ -72,39 +72,46 @@ void T2TPredictor::Read(T2TModel * model, T2TStateBundle * state) ...@@ -72,39 +72,46 @@ void T2TPredictor::Read(T2TModel * model, T2TStateBundle * state)
/* /*
predict the next state predict the next state
>> next - next states (assuming that the current state has been read) >> next - next states (assuming that the current state has been read)
>> encoding - encoder output
>> inputEnc - input of the encoder
>> paddingEnc - padding of the encoder
*/ */
void T2TPredictor::Predict(T2TStateBundle * next) void T2TPredictor::Predict(T2TStateBundle * next, XTensor * encoding, XTensor * inputEnc, XTensor * paddingEnc)
{ {
next->layersEncoding.Clear(); next->layersEnc.Clear();
next->layersDecoding.Clear(); next->layersDec.Clear();
AttDecoder &decoder = *m->decoder; AttDecoder &decoder = *m->decoder;
/* word indices of previous positions */ /* word indices of previous positions */
XTensor * inputLast = (XTensor*)s->layersDecoding.GetItem(0); XTensor * inputLast = (XTensor*)s->layersDec.GetItem(0);
/* word indices of positions up to next state */ /* word indices of positions up to next state */
XTensor &input = *NewTensor(); XTensor &inputDec = *NewTensor();
if(inputLast == NULL) if(inputLast == NULL)
input = s->prediction; inputDec = s->prediction;
else else
input = Concatenate(*inputLast, s->prediction, inputLast->GetDim(-1)); inputDec = Concatenate(*inputLast, s->prediction, inputLast->GetDim(-1));
/* prediction probabilities */ /* prediction probabilities */
XTensor &output = next->prediction; XTensor &output = next->score;
/* encoder output */ XTensor paddingDec;
XTensor &outputEnc = *(XTensor*)s->layersEncoding.GetItem(-1); InitTensor3D(&paddingDec, inputDec.GetDim(0), inputDec.GetDim(1), m->outputLayer->vSize, X_INT);
SetDataFixedInt(paddingDec, 1);
/* empty tensors (for masking?) */ XTensor maskDec;
XTensor nullMask; XTensor maskEncDec;
/* decoder mask */
m->MakeMTMaskDec(*inputEnc, inputDec, *paddingEnc, paddingDec, maskDec, maskEncDec);
/* make the decoding network and generate the output probabilities */ /* make the decoding network and generate the output probabilities */
output = decoder.Make(s->prediction, outputEnc, nullMask, nullMask, false); output = decoder.Make(inputDec, *encoding, maskDec, maskEncDec, false);
next->layersEncoding.AddList(&s->layersEncoding); next->layersEnc.AddList(&s->layersEnc);
next->layersDecoding.Add(&input); next->layersDec.Add(&inputDec);
next->layersDecoding.Add(&output); next->layersDec.Add(&output);
} }
} }
......
...@@ -36,11 +36,11 @@ public: ...@@ -36,11 +36,11 @@ public:
/* we assume that the prediction is an integer */ /* we assume that the prediction is an integer */
int prediction; int prediction;
/* probability of the prediction */ /* score of the prediction */
float prob; float score;
/* probability of the path */ /* score of the path */
float pathProb; float scorePath;
/* pointer to the previous state */ /* pointer to the previous state */
T2TState * last; T2TState * last;
...@@ -53,18 +53,18 @@ public: ...@@ -53,18 +53,18 @@ public:
/* predictions */ /* predictions */
XTensor prediction; XTensor prediction;
/* distribution of every prediction (last state of the path) */ /* score of every prediction (last state of the path) */
XTensor probs; XTensor score;
/* distribution of every path */ /* score of every path */
XTensor pathProbs; XTensor scorePath;
/* layers on the encoder side. We actually use the encoder output instead /* layers on the encoder side. We actually use the encoder output instead
of all hidden layers. */ of all hidden layers. */
XList layersEncoding; XList layersEnc;
/* layers on the decoder side */ /* layers on the decoder side */
XList layersDecoding; XList layersDec;
}; };
/* The predictor reads the current state and then predicts the next. /* The predictor reads the current state and then predicts the next.
...@@ -74,6 +74,7 @@ public: ...@@ -74,6 +74,7 @@ public:
indices, hidden states, embeddings and etc.). */ indices, hidden states, embeddings and etc.). */
class T2TPredictor class T2TPredictor
{ {
private:
/* pointer to the transformer model */ /* pointer to the transformer model */
T2TModel * m; T2TModel * m;
...@@ -88,13 +89,13 @@ public: ...@@ -88,13 +89,13 @@ public:
~T2TPredictor(); ~T2TPredictor();
/* create an initial state */ /* create an initial state */
void Init(T2TModel * model, XTensor * top, T2TStateBundle * state); void Create(T2TModel * model, XTensor * top, T2TStateBundle * state);
/* read a state */ /* read a state */
void Read(T2TModel * model, T2TStateBundle * state); void Read(T2TModel * model, T2TStateBundle * state);
/* predict the next state */ /* predict the next state */
void Predict(T2TStateBundle * next); void Predict(T2TStateBundle * next, XTensor * encoding, XTensor * inputEnc, XTensor * paddingEnc);
}; };
} }
......
...@@ -20,6 +20,8 @@ ...@@ -20,6 +20,8 @@
*/ */
#include "T2TSearch.h" #include "T2TSearch.h"
#include "T2TUtility.h"
#include "../../tensor/core/CHeader.h"
using namespace nts; using namespace nts;
...@@ -27,6 +29,16 @@ namespace transformer ...@@ -27,6 +29,16 @@ namespace transformer
{ {
/* /*
initialize the model
>> argc - number of arguments
>> argv - list of pointers to the arguments
*/
void T2TSearch::InitModel(int argc, char ** argv)
{
LoadParamInt(argc, argv, "beamsize", &beamSize, 1);
}
/*
search for the most promising states search for the most promising states
>> model - the transformer model >> model - the transformer model
>> input - input of the model >> input - input of the model
...@@ -37,6 +49,7 @@ void T2TSearch::Search(T2TModel * model, XTensor * input, XTensor * padding, XTe ...@@ -37,6 +49,7 @@ void T2TSearch::Search(T2TModel * model, XTensor * input, XTensor * padding, XTe
{ {
XTensor maskEnc; XTensor maskEnc;
XTensor encoding; XTensor encoding;
T2TPredictor predictor;
/* encoder mask */ /* encoder mask */
model->MakeMTMaskEnc(*input, *padding, maskEnc); model->MakeMTMaskEnc(*input, *padding, maskEnc);
...@@ -45,30 +58,28 @@ void T2TSearch::Search(T2TModel * model, XTensor * input, XTensor * padding, XTe ...@@ -45,30 +58,28 @@ void T2TSearch::Search(T2TModel * model, XTensor * input, XTensor * padding, XTe
encoding = model->MakeEncoder(*input, maskEnc, false); encoding = model->MakeEncoder(*input, maskEnc, false);
encoding.SetName(ENCODING_NAME); encoding.SetName(ENCODING_NAME);
T2TPredictor predictor; T2TStateBundle * states = new T2TStateBundle[maxLength];
T2TStateBundle state1, state2; T2TStateBundle * first = states;
T2TStateBundle * cur = &state1;
T2TStateBundle * next = &state2;
/* initialize the predictor */ /* create the first state */
predictor.Init(model, &encoding, cur); predictor.Create(model, &encoding, first);
/* generate the sequence from left-to-right */ /* generate the sequence from left to right */
for(int i = 0 ; i < maxLength; i++){ for(int i = 0 ; i < maxLength; i++){
T2TStateBundle * cur = states + i;
T2TStateBundle * next = states + i + 1;
/* read the current state */ /* read the current state */
predictor.Read(model, cur); predictor.Read(model, cur);
/* predict the next state */ /* predict the next state */
predictor.Predict(next); predictor.Predict(next, &encoding, input, padding);
/* pruning */ /* pruning */
Prune(next); Prune(next);
T2TStateBundle * backup = cur;
cur = next;
next = backup;
} }
delete[] states;
} }
/* /*
...@@ -77,11 +88,27 @@ beam pruning ...@@ -77,11 +88,27 @@ beam pruning
*/ */
void T2TSearch::Prune(T2TStateBundle * beam) void T2TSearch::Prune(T2TStateBundle * beam)
{ {
int dims[MAX_TENSOR_DIM_NUM];
XTensor scoreTopK;
XTensor &score = beam->score;
XTensor &index = beam->prediction;
for(int i = 0; i < score.order; i++)
dims[i] = score.GetDim(i);
dims[score.order - 1] = beamSize;
InitTensor(&scoreTopK, score.order, score.dimSize, score.dataType,
1.0F, score.devID, score.mem);
InitTensor(&index, score.order, score.dimSize, X_INT,
1.0F, score.devID, score.mem);
TopK(score, scoreTopK, index, 0, beamSize);
} }
/* /*
save the output sequences in a tensor save the output sequences in a tensor
>> beam - >> beam - the beam that keeps a number of states
*/ */
void T2TSearch::DumpOutput(T2TStateBundle * beam, XTensor * output) void T2TSearch::DumpOutput(T2TStateBundle * beam, XTensor * output)
{ {
......
...@@ -34,10 +34,16 @@ namespace transformer ...@@ -34,10 +34,16 @@ namespace transformer
The output can be the path with the highest model score. */ The output can be the path with the highest model score. */
class T2TSearch class T2TSearch
{ {
public: private:
/* predictor */
T2TPredictor predictor;
/* max length of the generated sequence */ /* max length of the generated sequence */
int maxLength; int maxLength;
/* beam size */
int beamSize;
public: public:
/* constructor */ /* constructor */
T2TSearch() {}; T2TSearch() {};
...@@ -45,6 +51,9 @@ public: ...@@ -45,6 +51,9 @@ public:
/* de-constructor */ /* de-constructor */
~T2TSearch() {}; ~T2TSearch() {};
/* initialize the model */
void InitModel(int argc, char ** argv);
/* search for the most promising states */ /* search for the most promising states */
void Search(T2TModel * model, XTensor * input, XTensor * padding, XTensor * output); void Search(T2TModel * model, XTensor * input, XTensor * padding, XTensor * output);
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论