Commit df09abef by xiaotong

generation via beam pruning

parent 8eae2dbf
...@@ -75,35 +75,52 @@ void T2TSearch::Search(T2TModel * model, XTensor * input, XTensor * padding, XTe ...@@ -75,35 +75,52 @@ void T2TSearch::Search(T2TModel * model, XTensor * input, XTensor * padding, XTe
/* predict the next state */ /* predict the next state */
predictor.Predict(next, &encoding, input, padding); predictor.Predict(next, &encoding, input, padding);
/* pruning */ /* beam pruning */
Prune(next); Generate(next);
} }
delete[] states; delete[] states;
} }
/* NOTE(review): every line of this hunk is a side-by-side rendering: the old
   text (T2TSearch::Prune) comes first, the new text (T2TSearch::Generate)
   second. Lines showing a single statement exist in only one revision. */
/* /*
beam pruning generate tokens for the next state via beam pruning
>> beam - the beam that keeps a number of states >> beam - the beam that keeps a number of states
*/ */
void T2TSearch::Prune(T2TStateBundle * beam) void T2TSearch::Generate(T2TStateBundle * beam)
{ {
int dims[MAX_TENSOR_DIM_NUM]; int dims[MAX_TENSOR_DIM_NUM];
int dimsBeam[MAX_TENSOR_DIM_NUM];
int dimsTopK[MAX_TENSOR_DIM_NUM];
XTensor scoreTopK; XTensor scoreTopK;
XTensor &score = beam->score; XTensor &score = beam->score;
XTensor &index = beam->prediction; XTensor &index = beam->prediction;
int order = score.order;
CheckNTErrors(order >= 2, "The tensor must be of order 2 or larger.");
/* NOTE(review): dimsBeam is declared above without an initializer and is
   first assigned inside the loop that follows, so this check reads an
   indeterminate value. It most likely belongs after the loop, before
   dimsBeam[order - 2] is divided by beamSize. */
CheckNTErrors(dimsBeam[order - 2] % beamSize == 0, "Wrong dimension size!");
/* copy the shape of score into the working dimension arrays */
for(int i = 0; i < score.order; i++) for (int i = 0; i < order; i++) {
dims[i] = score.GetDim(i); dims[i] = score.GetDim(i);
dims[score.order - 1] = beamSize; dimsBeam[i] = score.GetDim(i);
dimsTopK[i] = score.GetDim(i);
}
/* new revision: move a factor of beamSize from dimension order-2 into the
   last dimension of dimsBeam; dimsTopK shares the reduced dimension order-2
   and keeps beamSize entries in its last dimension */
dimsBeam[order - 2] /= beamSize;
dimsBeam[order - 1] *= beamSize;
dimsTopK[order - 2] = dimsBeam[order - 2];
dimsTopK[order - 1] = beamSize;
/* the old revision sized both outputs with score.dimSize; the new revision
   sizes them with dimsTopK */
InitTensor(&scoreTopK, score.order, score.dimSize, score.dataType, InitTensor(&scoreTopK, order, dimsTopK, score.dataType,
1.0F, score.devID, score.mem); 1.0F, score.devID, score.mem);
InitTensor(&index, score.order, score.dimSize, X_INT, InitTensor(&index, order, dimsTopK, X_INT,
1.0F, score.devID, score.mem); 1.0F, score.devID, score.mem);
/* new revision: give score the dimsBeam shape for the TopK call and restore
   the original shape (dims) afterwards */
score.Reshape(order, dimsBeam);
/* NOTE(review): TopK receives 0 as its fourth argument, whereas the
   reshaping above widens the LAST dimension. If that argument selects the
   dimension to search, confirm that 0 is intended rather than order - 1. */
TopK(score, scoreTopK, index, 0, beamSize); TopK(score, scoreTopK, index, 0, beamSize);
score.Reshape(order, dims);
} }
/* /*
......
...@@ -28,9 +28,9 @@ ...@@ -28,9 +28,9 @@
namespace transformer namespace transformer
{ {
/* The class organizes the search process. It calls “predictors” to generate /* The class organizes the search process. It calls "predictors" to generate
distributions of the predictions and prunes the search space by beam pruning. distributions of the predictions and prunes the search space by beam pruning.
It results in a graph where each path represents a translation hypothesis. This makes a graph where each path represents a translation hypothesis.
The output can be the path with the highest model score. */ The output can be the path with the highest model score. */
class T2TSearch class T2TSearch
{ {
...@@ -57,8 +57,8 @@ public: ...@@ -57,8 +57,8 @@ public:
/* search for the most promising states */ /* search for the most promising states */
void Search(T2TModel * model, XTensor * input, XTensor * padding, XTensor * output); void Search(T2TModel * model, XTensor * input, XTensor * padding, XTensor * output);
/* beam pruning */ /* generate token indices via beam pruning */
void Prune(T2TStateBundle * beam); void Generate(T2TStateBundle * beam);
/* save the output sequences in a tensor */ /* save the output sequences in a tensor */
void DumpOutput(T2TStateBundle * beam, XTensor * output); void DumpOutput(T2TStateBundle * beam, XTensor * output);
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论