Commit df09abef by xiaotong

generation via beam pruning

parent 8eae2dbf
......@@ -75,35 +75,52 @@ void T2TSearch::Search(T2TModel * model, XTensor * input, XTensor * padding, XTe
/* predict the next state */
predictor.Predict(next, &encoding, input, padding);
/* pruning */
Prune(next);
/* beam pruning */
Generate(next);
}
delete[] states;
}
/*
beam pruning
generate tokens for the next state via beam pruning
>> beam - the beam that keeps a number of states
*/
void T2TSearch::Prune(T2TStateBundle * beam)
void T2TSearch::Generate(T2TStateBundle * beam)
{
int dims[MAX_TENSOR_DIM_NUM];
int dimsBeam[MAX_TENSOR_DIM_NUM];
int dimsTopK[MAX_TENSOR_DIM_NUM];
XTensor scoreTopK;
XTensor &score = beam->score;
XTensor &index = beam->prediction;
int order = score.order;
CheckNTErrors(order >= 2, "The tensor must be of order 2 or larger.");
CheckNTErrors(dimsBeam[order - 2] % beamSize == 0, "Wrong dimension size!");
for(int i = 0; i < score.order; i++)
for (int i = 0; i < order; i++) {
dims[i] = score.GetDim(i);
dims[score.order - 1] = beamSize;
dimsBeam[i] = score.GetDim(i);
dimsTopK[i] = score.GetDim(i);
}
dimsBeam[order - 2] /= beamSize;
dimsBeam[order - 1] *= beamSize;
dimsTopK[order - 2] = dimsBeam[order - 2];
dimsTopK[order - 1] = beamSize;
InitTensor(&scoreTopK, score.order, score.dimSize, score.dataType,
InitTensor(&scoreTopK, order, dimsTopK, score.dataType,
1.0F, score.devID, score.mem);
InitTensor(&index, score.order, score.dimSize, X_INT,
InitTensor(&index, order, dimsTopK, X_INT,
1.0F, score.devID, score.mem);
score.Reshape(order, dimsBeam);
TopK(score, scoreTopK, index, 0, beamSize);
score.Reshape(order, dims);
}
/*
......
......@@ -28,9 +28,9 @@
namespace transformer
{
/* The class organizes the search process. It calls “predictors” to generate
/* The class organizes the search process. It calls "predictors" to generate
distributions of the predictions and prunes the search space by beam pruning.
It results in a graph where each path represents a translation hypothesis.
This makes a graph where each path represents a translation hypothesis.
The output can be the path with the highest model score. */
class T2TSearch
{
......@@ -57,8 +57,8 @@ public:
/* search for the most promising states */
void Search(T2TModel * model, XTensor * input, XTensor * padding, XTensor * output);
/* beam pruning */
void Prune(T2TStateBundle * beam);
/* generate token indices via beam pruning */
void Generate(T2TStateBundle * beam);
/* save the output sequences in a tensor */
void DumpOutput(T2TStateBundle * beam, XTensor * output);
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论