Commit cad766b4 by xiaotong

bug fixes

parent ff8b2cf9
...@@ -527,10 +527,16 @@ void T2TSearch::Dump(XTensor * output) ...@@ -527,10 +527,16 @@ void T2TSearch::Dump(XTensor * output)
T2TState * state = (T2TState *)heap.Pop().index; T2TState * state = (T2TState *)heap.Pop().index;
int count = 0; int count = 0;
bool isCompleted = false;
/* we track the state from the end to the beginning */ /* we track the state from the end to the beginning */
while(state != NULL){ while(state != NULL){
words[count++] = state->isCompleted ? -1 : state->prediction; if (isCompleted)
words[count++] = -1;
else
words[count++] = state->prediction;
if (state->isCompleted)
isCompleted = true;
state = state->last; state = state->last;
} }
......
...@@ -115,8 +115,8 @@ void T2TTester::Test(const char * fn, const char * ofn, T2TModel * model) ...@@ -115,8 +115,8 @@ void T2TTester::Test(const char * fn, const char * ofn, T2TModel * model)
seacher.Search(model, &batchEnc, &paddingEnc, &output); seacher.Search(model, &batchEnc, &paddingEnc, &output);
output.Dump(ofile, "output:"); Dump(ofile, &output);
float prob = 0; float prob = 0;
loss += -prob; loss += -prob;
...@@ -150,14 +150,19 @@ dump the result into the file ...@@ -150,14 +150,19 @@ dump the result into the file
>> file - data file >> file - data file
>> output - output tensor >> output - output tensor
*/ */
void T2TTester::Dump(FILE * file, const XTensor * output) void T2TTester::Dump(FILE * file, XTensor * output)
{ {
int seqLength = output->GetDim(-1); int seqLength = output->GetDim(-1);
for(int i = 0; i < output->unitNum; i += seqLength){ for(int i = 0; i < output->unitNum; i += seqLength){
for(int j = 0; j < seqLength; j++){ for(int j = 0; j < seqLength; j++){
int w = output->GetInt(i + j);
if (w < 0)
break;
fprintf(file, "%d ", w);
} }
fprintf(file, "\n");
} }
} }
......
...@@ -59,7 +59,7 @@ public: ...@@ -59,7 +59,7 @@ public:
void Test(const char * fn, const char * ofn, T2TModel * model); void Test(const char * fn, const char * ofn, T2TModel * model);
/* dump the result into the file */ /* dump the result into the file */
void Dump(FILE * file, const XTensor * output); void Dump(FILE * file, XTensor * output);
}; };
} }
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论