Commit fdd64fb0 by xiaotong

new code for the heap in beam search

parent f0937eab
......@@ -55,12 +55,15 @@ void T2TStateBundle::MakeStates(int num)
for(int i = 0; i < num; i++){
states[i].prediction = -1;
states[i].pid = T2T_PID_EMPTY;
states[i].prob = 0;
states[i].probPath = 0;
states[i].modelScore = 0;
states[i].nstep = 0;
states[i].last = NULL;
}
stateNum = num;
}
/* constructor */
......
......@@ -29,6 +29,8 @@
namespace transformer
{
#define T2T_PID_EMPTY -1
/* state for search. It keeps the path (back-pointer), prediction distribution,
etc. It can be regarded as a hypothesis in translation. */
class T2TState
......@@ -37,6 +39,11 @@ public:
/* we assume that the prediction is an integer */
int prediction;
/* id of the problem. One can regard it as the sentence id when we
translate a number of sentences in a batched manner. It is
an empty hypothesis if id = -1 */
int pid;
/* probability of every prediction (last state of the path) */
float prob;
......@@ -85,6 +92,9 @@ public:
/* list of states */
T2TState * states;
/* number of states */
int stateNum;
public:
/* constructor */
T2TStateBundle();
......
......@@ -87,6 +87,22 @@ void T2TSearch::Search(T2TModel * model, XTensor * input, XTensor * padding, XTe
}
/*
prepare for search: (re)build one result heap per sentence in the batch
>> batchSize - size of the batch
>> beamSize - size of the beam
*/
void T2TSearch::Prepare(int batchSize, int beamSize)
{
    /* drop heaps left over from a previous run (delete[] on NULL is a no-op) */
    delete[] heaps;

    /* one min-heap per sentence, each keeping up to beamSize hypotheses */
    heaps = new XHeap<MIN_HEAP, float>[batchSize];
    for (int k = 0; k < batchSize; k++)
        heaps[k].Init(beamSize);
}
/*
compute the model score for each hypothesis
>> prev - the beam of the previous state
>> beam - the beam that keeps a number of states
......@@ -197,6 +213,7 @@ void T2TSearch::Expand(T2TStateBundle * prev, T2TStateBundle * beam)
XTensor & modelScoreRef = beam->modelScore;
XTensor & probRef = beam->prob;
XTensor & probPathRef = beam->probPath;
XTensor & prediction = beam->prediction;
XTensor id;
XTensor modelScore;
XTensor prob;
......@@ -213,8 +230,14 @@ void T2TSearch::Expand(T2TStateBundle * prev, T2TStateBundle * beam)
CopyValues(modelScoreRef, modelScore);
CopyValues(prob, probRef);
CopyValues(probPathRef, probPath);
CheckNTErrors(beam->stateNum == id.unitNum, "Errors occur in counting!");
for(int i = 0; i < id.unitNum; i++){
/* we keep information on the states of the graph. All these are maintained
on CPUs to ease the implementation of frequent access and modification of
the states. An alternative is to do this on GPUs but it needs much more
coding work and the speed-up is not obvious. */
for(int i = 0; i < beam->stateNum; i++){
T2TState & state = states[i];
/* pointer to the previous state */
......@@ -224,6 +247,23 @@ void T2TSearch::Expand(T2TStateBundle * prev, T2TStateBundle * beam)
state.modelScore = modelScore.Get(i);
state.prob = prob.Get(i);
state.probPath = probPath.Get(i);
/* prediction */
state.prediction = prediction.GetInt(i);
}
}
/*
collect hypotheses with ending symbol. Given a beam of hypotheses,
we remove the finished hypotheses and keep them in a heap.
>> beam - the beam that keeps a number of states
*/
void T2TSearch::Collect(T2TStateBundle * beam)
{
T2TState * states = beam->states;
/* NOTE(review): work-in-progress stub - the loop visits every state in
the beam, but the body that detects finished hypotheses and pushes them
into the per-sentence heap (see Prepare) has not been written yet, so
`state` is currently unused */
for (int i = 0; i < beam->stateNum; i++) {
T2TState & state = states[i];
}
}
......
......@@ -47,6 +47,9 @@ private:
/* beam size */
int beamSize;
/* we keep the final hypotheses in a heap for each sentence in the batch. */
XHeap<MIN_HEAP, float> * heaps;
public:
/* constructor */
T2TSearch() {};
......@@ -60,6 +63,9 @@ public:
/* search for the most promising states */
void Search(T2TModel * model, XTensor * input, XTensor * padding, XTensor * output);
/* preparation */
void Prepare(int batchSize,int beamSize);
/* compute the model score for each hypothesis */
void Score(T2TStateBundle * prev, T2TStateBundle * beam);
......@@ -69,6 +75,9 @@ public:
/* expand the search graph */
void Expand(T2TStateBundle * prev, T2TStateBundle * beam);
/* collect hypotheses with ending symbol */
void Collect(T2TStateBundle * beam);
/* save the output sequences in a tensor */
void DumpOutput(T2TStateBundle * beam, XTensor * output);
};
......
......@@ -31,15 +31,15 @@ namespace nts{
/* constructor. Start from an empty, unallocated state so that the
   destructor and a later Init() call behave well even if Init() was
   never invoked (members are those assigned in Init()) */
template<HeapType hType, typename T>
XHeap<hType, T>::XHeap()
{
    mem = NULL;
    size = 0;
    count = 0;
    items = NULL;
}
/* constructor
>> mySize - capacity of the heap (maximum number of items)
>> myMem - memory pool; NULL means plain heap allocation */
template<HeapType hType, typename T>
XHeap<hType, T>::XHeap(int mySize, XMem * myMem)
{
    /* delegate entirely to Init(). The previous version both ran
       Init()'s allocation logic inline AND called Init(), which
       allocated the item buffer twice and leaked one copy */
    Init(mySize, myMem);
}
/* destructor */
......@@ -50,6 +50,19 @@ XHeap<hType, T>::~XHeap()
}
/*
initialize the heap
>> mySize - capacity of the heap (maximum number of items)
>> myMem - memory pool; if NULL, items are allocated with new[]
*/
template<HeapType hType, typename T>
void XHeap<hType, T>::Init(int mySize, XMem * myMem)
{
    mem = myMem;
    size = mySize;
    count = 0;

    if (mem == NULL)
        items = new HeapNode<T>[mySize];
    else
        /* bug fix: keep the buffer returned by the pool and size it for
           HeapNode<T> entries. The old code discarded the Alloc() result
           (leaving `items` unset) and under-allocated with sizeof(T).
           NOTE(review): assumes XMem::Alloc returns void* - confirm */
        items = (HeapNode<T>*)mem->Alloc(mem->devID, mySize * sizeof(HeapNode<T>));
}
template<HeapType hType, typename T>
void XHeap<hType, T>::Clear(T initValue)
{
count = 0;
......
......@@ -76,11 +76,17 @@ public:
public:
/* constructor */
XHeap();
/* constructor */
XHeap(int mySize, XMem * myMem = NULL);
/* destructor */
~XHeap();
/* initialization */
void Init(int mySize, XMem * myMem = NULL);
/* clear the data */
void Clear(T initValue);
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论