Commit 000a990c by xiaotong

add end symbols

parent 03fed61d
......@@ -56,6 +56,7 @@ void T2TStateBundle::MakeStates(int num)
for(int i = 0; i < num; i++){
states[i].prediction = -1;
states[i].pid = T2T_PID_EMPTY;
states[i].isEnd = 0;
states[i].prob = 0;
states[i].probPath = 0;
states[i].modelScore = 0;
......
......@@ -44,6 +44,9 @@ public:
an empty hypothesis if id = -1 */
int pid;
/* indicates whether the state is an end */
int isEnd;
/* probability of every prediction (last state of the path) */
float prob;
......
......@@ -32,6 +32,7 @@ namespace transformer
T2TSearch::T2TSearch()
{
fullHypos = NULL;
endSymbols = NULL;
}
/* de-constructor */
......@@ -39,6 +40,8 @@ T2TSearch::~T2TSearch()
{
if(fullHypos != NULL)
delete[] fullHypos;
if(endSymbols != NULL)
delete[] endSymbols;
}
/*
......@@ -104,8 +107,11 @@ prepare for search
>> batchSize - size of the batch
>> beamSize - size of the beam
*/
void T2TSearch::Prepare(int batchSize, int beamSize)
void T2TSearch::Prepare(int myBatchSize, int myBeamSize)
{
batchSize = myBatchSize;
beamSize = myBeamSize;
if (fullHypos != NULL)
delete[] fullHypos;
......@@ -279,6 +285,14 @@ void T2TSearch::Collect(T2TStateBundle * beam)
T2TState & state = states[i];
state.pid = state.last->pid;
CheckNTErrors(state.pid >= 0 && state.pid < batchSize,
"Invalid sample id!");
if(IsEnd(state.prediction)){
state.isEnd = 1;
}
else
state.isEnd = 0;
}
}
......@@ -290,4 +304,40 @@ void T2TSearch::DumpOutput(T2TStateBundle * beam, XTensor * output)
{
}
/*
check if the token is an end symbol
>> token - token to be checked
*/
bool T2TSearch::IsEnd(int token)
{
CheckNTErrors(endSymbolNum > 0, "No end symbol?");
for(int i = 0; i < endSymbolNum; i++){
if(endSymbols[i] == token)
return true;
}
return false;
}
/*
set end symbols for search
>> tokens - end symbols
>> tokenNum - number of the end symbols
*/
void T2TSearch::SetEnd(const int * tokens, const int tokenNum)
{
if(endSymbols != NULL)
delete[] endSymbols;
if(tokenNum <= 0)
return;
/* we may have multiple end symbols */
tokens = new int[tokenNum];
for(int i = 0; i < tokenNum; i++)
endSymbols[i] = tokens[i];
endSymbolNum = tokenNum;
}
}
......@@ -47,9 +47,18 @@ private:
/* beam size */
int beamSize;
/* batch size */
int batchSize;
/* we keep the final hypotheses in a heap for each sentence in the batch. */
XHeap<MIN_HEAP, float> * fullHypos;
/* array of the end symbols */
int * endSymbols;
/* number of the end symbols */
int endSymbolNum;
public:
/* constructor */
T2TSearch();
......@@ -64,7 +73,7 @@ public:
void Search(T2TModel * model, XTensor * input, XTensor * padding, XTensor * output);
/* preparation */
void Prepare(int batchSize,int beamSize);
void Prepare(int myBatchSize,int myBeamSize);
/* compute the model score for each hypothesis */
void Score(T2TStateBundle * prev, T2TStateBundle * beam);
......@@ -80,6 +89,12 @@ public:
/* save the output sequences in a tensor */
void DumpOutput(T2TStateBundle * beam, XTensor * output);
/* check if the token is an end symbol */
bool IsEnd(int token);
/* set end symbols for search */
void SetEnd(const int * tokens, const int tokenNum);
};
}
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论