Commit 000a990c by xiaotong

add end symbols

parent 03fed61d
...@@ -56,6 +56,7 @@ void T2TStateBundle::MakeStates(int num) ...@@ -56,6 +56,7 @@ void T2TStateBundle::MakeStates(int num)
for(int i = 0; i < num; i++){ for(int i = 0; i < num; i++){
states[i].prediction = -1; states[i].prediction = -1;
states[i].pid = T2T_PID_EMPTY; states[i].pid = T2T_PID_EMPTY;
states[i].isEnd = 0;
states[i].prob = 0; states[i].prob = 0;
states[i].probPath = 0; states[i].probPath = 0;
states[i].modelScore = 0; states[i].modelScore = 0;
......
...@@ -44,6 +44,9 @@ public: ...@@ -44,6 +44,9 @@ public:
an empty hypothesis if id = -1 */ an empty hypothesis if id = -1 */
int pid; int pid;
/* indicates whether the state is an end */
int isEnd;
/* probability of every prediction (last state of the path) */ /* probability of every prediction (last state of the path) */
float prob; float prob;
......
...@@ -32,6 +32,7 @@ namespace transformer ...@@ -32,6 +32,7 @@ namespace transformer
T2TSearch::T2TSearch() T2TSearch::T2TSearch()
{ {
fullHypos = NULL; fullHypos = NULL;
endSymbols = NULL;
} }
/* de-constructor */ /* de-constructor */
...@@ -39,6 +40,8 @@ T2TSearch::~T2TSearch() ...@@ -39,6 +40,8 @@ T2TSearch::~T2TSearch()
{ {
if(fullHypos != NULL) if(fullHypos != NULL)
delete[] fullHypos; delete[] fullHypos;
if(endSymbols != NULL)
delete[] endSymbols;
} }
/* /*
...@@ -104,8 +107,11 @@ prepare for search ...@@ -104,8 +107,11 @@ prepare for search
>> batchSize - size of the batch >> batchSize - size of the batch
>> beamSize - size of the beam >> beamSize - size of the beam
*/ */
void T2TSearch::Prepare(int batchSize, int beamSize) void T2TSearch::Prepare(int myBatchSize, int myBeamSize)
{ {
batchSize = myBatchSize;
beamSize = myBeamSize;
if (fullHypos != NULL) if (fullHypos != NULL)
delete[] fullHypos; delete[] fullHypos;
...@@ -279,6 +285,14 @@ void T2TSearch::Collect(T2TStateBundle * beam) ...@@ -279,6 +285,14 @@ void T2TSearch::Collect(T2TStateBundle * beam)
T2TState & state = states[i]; T2TState & state = states[i];
state.pid = state.last->pid; state.pid = state.last->pid;
CheckNTErrors(state.pid >= 0 && state.pid < batchSize,
"Invalid sample id!");
if(IsEnd(state.prediction)){
state.isEnd = 1;
}
else
state.isEnd = 0;
} }
} }
...@@ -290,4 +304,40 @@ void T2TSearch::DumpOutput(T2TStateBundle * beam, XTensor * output) ...@@ -290,4 +304,40 @@ void T2TSearch::DumpOutput(T2TStateBundle * beam, XTensor * output)
{ {
} }
/*
check if the token is an end symbol
>> token - token to be checked
*/
bool T2TSearch::IsEnd(int token)
{
CheckNTErrors(endSymbolNum > 0, "No end symbol?");
for(int i = 0; i < endSymbolNum; i++){
if(endSymbols[i] == token)
return true;
}
return false;
}
/*
set end symbols for search
>> tokens - end symbols
>> tokenNum - number of the end symbols
*/
void T2TSearch::SetEnd(const int * tokens, const int tokenNum)
{
if(endSymbols != NULL)
delete[] endSymbols;
if(tokenNum <= 0)
return;
/* we may have multiple end symbols */
tokens = new int[tokenNum];
for(int i = 0; i < tokenNum; i++)
endSymbols[i] = tokens[i];
endSymbolNum = tokenNum;
}
} }
...@@ -47,9 +47,18 @@ private: ...@@ -47,9 +47,18 @@ private:
/* beam size */ /* beam size */
int beamSize; int beamSize;
/* batch size */
int batchSize;
/* we keep the final hypotheses in a heap for each sentence in the batch. */ /* we keep the final hypotheses in a heap for each sentence in the batch. */
XHeap<MIN_HEAP, float> * fullHypos; XHeap<MIN_HEAP, float> * fullHypos;
/* array of the end symbols */
int * endSymbols;
/* number of the end symbols */
int endSymbolNum;
public: public:
/* constructor */ /* constructor */
T2TSearch(); T2TSearch();
...@@ -64,7 +73,7 @@ public: ...@@ -64,7 +73,7 @@ public:
void Search(T2TModel * model, XTensor * input, XTensor * padding, XTensor * output); void Search(T2TModel * model, XTensor * input, XTensor * padding, XTensor * output);
/* preparation */ /* preparation */
void Prepare(int batchSize,int beamSize); void Prepare(int myBatchSize,int myBeamSize);
/* compute the model score for each hypothesis */ /* compute the model score for each hypothesis */
void Score(T2TStateBundle * prev, T2TStateBundle * beam); void Score(T2TStateBundle * prev, T2TStateBundle * beam);
...@@ -80,6 +89,12 @@ public: ...@@ -80,6 +89,12 @@ public:
/* save the output sequences in a tensor */ /* save the output sequences in a tensor */
void DumpOutput(T2TStateBundle * beam, XTensor * output); void DumpOutput(T2TStateBundle * beam, XTensor * output);
/* check if the token is an end symbol */
bool IsEnd(int token);
/* set end symbols for search */
void SetEnd(const int * tokens, const int tokenNum);
}; };
} }
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论