Commit fdd64fb0 by xiaotong

new code for the heap in beam search

parent f0937eab
@@ -55,12 +55,15 @@ void T2TStateBundle::MakeStates(int num)
     for(int i = 0; i < num; i++){
         states[i].prediction = -1;
+        states[i].pid = T2T_PID_EMPTY;
         states[i].prob = 0;
         states[i].probPath = 0;
         states[i].modelScore = 0;
         states[i].nstep = 0;
         states[i].last = NULL;
     }
+
+    stateNum = num;
 }

 /* constructor */
......
@@ -29,6 +29,8 @@
 namespace transformer
 {

+#define T2T_PID_EMPTY -1
+
 /* state for search. It keeps the path (back-pointer), prediction distribution,
    etc. It can be regarded as a hypothesis in translation. */
 class T2TState
@@ -37,6 +39,11 @@ public:
     /* we assume that the prediction is an integer */
     int prediction;

+    /* id of the problem. One can regard it as the sentence id when we
+       translate a number of sentences in a batched manner. It is
+       an empty hypothesis if pid = -1 */
+    int pid;
+
     /* probability of every prediction (last state of the path) */
     float prob;
@@ -85,6 +92,9 @@ public:
     /* list of states */
     T2TState * states;

+    /* number of states */
+    int stateNum;
+
 public:
     /* constructor */
     T2TStateBundle();
......
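The pid field added above ties every hypothesis to the sentence it belongs to within the batch (T2T_PID_EMPTY marking an empty slot), while the existing last field is the back-pointer that links a state to the hypothesis it extends. A minimal, self-contained sketch of this layout, using a simplified stand-in struct rather than the real T2TState, shows how a finished hypothesis can be traced back into a token sequence:

#include <cstdio>
#include <vector>

/* simplified stand-in for T2TState: only the fields needed here */
struct State {
    int pid;            /* sentence id in the batch; -1 means empty            */
    int prediction;     /* predicted word id at this step                      */
    float probPath;     /* log-probability of the whole path                   */
    State * last;       /* back-pointer to the state that this one extends     */
};

/* walk the back-pointers and return the tokens in left-to-right order */
std::vector<int> Trace(State * end)
{
    std::vector<int> tokens;
    for (State * s = end; s != NULL && s->pid != -1; s = s->last)
        tokens.push_back(s->prediction);
    return std::vector<int>(tokens.rbegin(), tokens.rend());
}

int main()
{
    /* a toy 3-step hypothesis for sentence 0 in the batch */
    State s0 = { 0, 12, -0.5f, NULL };
    State s1 = { 0,  7, -1.3f, &s0 };
    State s2 = { 0,  2, -1.9f, &s1 };   /* e.g. 2 = end-of-sentence symbol */

    for (int w : Trace(&s2))
        printf("%d ", w);
    printf("\n");                       /* prints: 12 7 2 */
    return 0;
}

Because each state stores only one token and a back-pointer, extending a beam never copies whole sequences; a full output sequence only needs to be materialized when a path is read out.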
@@ -87,6 +87,22 @@ void T2TSearch::Search(T2TModel * model, XTensor * input, XTensor * padding, XTensor * output)
 }

+/*
+prepare for search
+>> batchSize - size of the batch
+>> beamSize - size of the beam
+*/
+void T2TSearch::Prepare(int batchSize, int beamSize)
+{
+    if (heaps != NULL)
+        delete[] heaps;
+
+    heaps = new XHeap<MIN_HEAP, float>[batchSize];
+    for (int i = 0; i < batchSize; i++)
+        heaps[i].Init(beamSize);
+}
+
 /*
 compute the model score for each hypothesis
 >> prev - the beam of the previous state
 >> beam - the beam that keeps a number of states
@@ -197,6 +213,7 @@ void T2TSearch::Expand(T2TStateBundle * prev, T2TStateBundle * beam)
     XTensor & modelScoreRef = beam->modelScore;
     XTensor & probRef = beam->prob;
     XTensor & probPathRef = beam->probPath;
+    XTensor & prediction = beam->prediction;

     XTensor id;
     XTensor modelScore;
     XTensor prob;
@@ -213,8 +230,14 @@ void T2TSearch::Expand(T2TStateBundle * prev, T2TStateBundle * beam)
     CopyValues(modelScoreRef, modelScore);
     CopyValues(prob, probRef);
     CopyValues(probPathRef, probPath);

-    for(int i = 0; i < id.unitNum; i++){
+    CheckNTErrors(beam->stateNum == id.unitNum, "Errors occur in counting!");
+
+    /* we keep information on the states of the graph. All these are maintained
+       on CPUs to ease the implementation of frequent access and modification of
+       the states. An alternative is to do this on GPUs but it needs much more
+       coding work and the speed-up is not obvious. */
+    for(int i = 0; i < beam->stateNum; i++){
         T2TState & state = states[i];

         /* pointer to the previous state */
@@ -224,6 +247,23 @@ void T2TSearch::Expand(T2TStateBundle * prev, T2TStateBundle * beam)
         state.modelScore = modelScore.Get(i);
         state.prob = prob.Get(i);
         state.probPath = probPath.Get(i);
+
+        /* prediction */
+        state.prediction = prediction.GetInt(i);
+    }
+}
+
+/*
+collect hypotheses with ending symbol. Given a beam of hypotheses,
+we remove the finished hypotheses and keep them in a heap.
+>> beam - the beam that keeps a number of states
+*/
+void T2TSearch::Collect(T2TStateBundle * beam)
+{
+    T2TState * states = beam->states;
+
+    for (int i = 0; i < beam->stateNum; i++) {
+        T2TState & state = states[i];
     }
 }
......
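Prepare above allocates one min-heap per sentence in the batch, each initialized with capacity beamSize, and the comment on Collect says that finished hypotheses are to be moved into these heaps. The sketch below illustrates the underlying technique with a standard-library priority queue instead of XHeap (whose Push/Pop interface is not part of this diff): a size-capped min-heap keyed by model score keeps the beamSize best finished hypotheses for each sentence, because the worst kept hypothesis is always at the top and can be evicted in O(log beamSize). The CollectFinished helper and the Finished struct are hypothetical names used only for this illustration, not the actual T2TSearch::Collect.

#include <cstdio>
#include <queue>
#include <vector>

/* a finished hypothesis: its sentence id and its model score */
struct Finished {
    int pid;
    float modelScore;
};

/* ordering that makes the priority queue a min-heap: the worst score sits on top */
struct WorseFirst {
    bool operator()(const Finished & a, const Finished & b) const {
        return a.modelScore > b.modelScore;
    }
};

typedef std::priority_queue<Finished, std::vector<Finished>, WorseFirst> MinHeap;

/* keep at most beamSize hypotheses per sentence, evicting the worst one */
void CollectFinished(MinHeap * heaps, int beamSize, const Finished & hypo)
{
    MinHeap & heap = heaps[hypo.pid];
    if ((int)heap.size() < beamSize)
        heap.push(hypo);
    else if (hypo.modelScore > heap.top().modelScore) {
        heap.pop();
        heap.push(hypo);
    }
}

int main()
{
    const int batchSize = 2, beamSize = 2;

    /* one heap per sentence in the batch, as in T2TSearch::Prepare */
    MinHeap * heaps = new MinHeap[batchSize];

    Finished hypos[] = { {0, -1.2f}, {0, -0.4f}, {0, -2.0f}, {1, -0.9f} };
    for (const Finished & h : hypos)
        CollectFinished(heaps, beamSize, h);

    /* sentence 0 keeps its two best scores: -1.2 and -0.4 */
    while (!heaps[0].empty()) {
        printf("%.1f ", heaps[0].top().modelScore);
        heaps[0].pop();
    }
    printf("\n");

    delete[] heaps;
    return 0;
}

This is also why a MIN_HEAP is the natural choice in Prepare: the heap only ever needs fast access to the currently worst surviving hypothesis.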
@@ -47,6 +47,9 @@ private:
     /* beam size */
     int beamSize;

+    /* we keep the final hypotheses in a heap for each sentence in the batch. */
+    XHeap<MIN_HEAP, float> * heaps;
+
 public:
     /* constructor */
     T2TSearch() {};
@@ -60,6 +63,9 @@ public:
     /* search for the most promising states */
     void Search(T2TModel * model, XTensor * input, XTensor * padding, XTensor * output);

+    /* preparation */
+    void Prepare(int batchSize, int beamSize);
+
     /* compute the model score for each hypothesis */
     void Score(T2TStateBundle * prev, T2TStateBundle * beam);
@@ -69,6 +75,9 @@ public:
     /* expand the search graph */
     void Expand(T2TStateBundle * prev, T2TStateBundle * beam);

+    /* collect hypotheses with ending symbol */
+    void Collect(T2TStateBundle * beam);
+
     /* save the output sequences in a tensor */
     void DumpOutput(T2TStateBundle * beam, XTensor * output);
 };
......
@@ -31,15 +31,15 @@ namespace nts{

 /* constructor */
 template<HeapType hType, typename T>
+XHeap<hType, T>::XHeap()
+{
+}
+
+/* constructor */
+template<HeapType hType, typename T>
 XHeap<hType, T>::XHeap(int mySize, XMem * myMem)
 {
-    mem = myMem;
-    size = mySize;
-    count = 0;
-    if (mem == NULL)
-        items = new HeapNode<T>[mySize];
-    else
-        mem->Alloc(mem->devID, mySize * sizeof(T));
+    Init(mySize, myMem);
 }

 /* destructor */
@@ -50,6 +50,19 @@ XHeap<hType, T>::~XHeap()
 }

 template<HeapType hType, typename T>
+void XHeap<hType, T>::Init(int mySize, XMem * myMem)
+{
+    mem = myMem;
+    size = mySize;
+    count = 0;
+
+    if (mem == NULL)
+        items = new HeapNode<T>[mySize];
+    else
+        mem->Alloc(mem->devID, mySize * sizeof(T));
+}
+
+template<HeapType hType, typename T>
 void XHeap<hType, T>::Clear(T initValue)
 {
     count = 0;
......
@@ -76,11 +76,17 @@ public:

 public:
     /* constructor */
+    XHeap();
+
+    /* constructor */
     XHeap(int mySize, XMem * myMem = NULL);

     /* destructor */
     ~XHeap();

+    /* initialization */
+    void Init(int mySize, XMem * myMem = NULL);
+
     /* clear the data */
     void Clear(T initValue);
......
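The new default constructor and Init method give XHeap two-phase initialization: new XHeap<MIN_HEAP, float>[batchSize] in T2TSearch::Prepare can only invoke a default constructor, so each heap's capacity has to be supplied afterwards via heaps[i].Init(beamSize). A minimal sketch of the same pattern on a toy container (not the real XHeap; the XMem memory-pool branch is omitted):

#include <cstdio>

/* toy fixed-capacity buffer using the same two-phase initialization idea */
template<typename T>
class Buffer {
public:
    T * items;
    int size;
    int count;

    /* default constructor: leaves the object empty so that new Buffer[n] works */
    Buffer() : items(NULL), size(0), count(0) {}

    /* one-step construction is still available for single objects */
    Buffer(int mySize) { Init(mySize); }

    ~Buffer() { delete[] items; }

    /* second phase: allocate storage once the capacity is known */
    void Init(int mySize) {
        items = new T[mySize];
        size  = mySize;
        count = 0;
    }
};

int main()
{
    const int batchSize = 4, beamSize = 8;

    /* array-new can only default-construct the elements ... */
    Buffer<float> * buffers = new Buffer<float>[batchSize];

    /* ... so each element is sized in a second step, as in T2TSearch::Prepare */
    for (int i = 0; i < batchSize; i++)
        buffers[i].Init(beamSize);

    printf("capacity of buffer 0: %d\n", buffers[0].size);
    delete[] buffers;
    return 0;
}

Since the one-argument constructor now simply delegates to Init, heaps that are constructed directly with a size keep their previous behavior.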