Commit cad766b4 by xiaotong

bug fixes

parent ff8b2cf9
......@@ -527,10 +527,16 @@ void T2TSearch::Dump(XTensor * output)
T2TState * state = (T2TState *)heap.Pop().index;
int count = 0;
bool isCompleted = false;
/* we track the state from the end to the beginning */
while(state != NULL){
words[count++] = state->isCompleted ? -1 : state->prediction;
if (isCompleted)
words[count++] = -1;
else
words[count++] = state->prediction;
if (state->isCompleted)
isCompleted = true;
state = state->last;
}
......
......@@ -115,8 +115,8 @@ void T2TTester::Test(const char * fn, const char * ofn, T2TModel * model)
seacher.Search(model, &batchEnc, &paddingEnc, &output);
output.Dump(ofile, "output:");
Dump(ofile, &output);
float prob = 0;
loss += -prob;
......@@ -150,14 +150,19 @@ dump the result into the file
>> file - data file
>> output - output tensor
*/
void T2TTester::Dump(FILE * file, const XTensor * output)
void T2TTester::Dump(FILE * file, XTensor * output)
{
int seqLength = output->GetDim(-1);
for(int i = 0; i < output->unitNum; i += seqLength){
for(int j = 0; j < seqLength; j++){
int w = output->GetInt(i + j);
if (w < 0)
break;
fprintf(file, "%d ", w);
}
fprintf(file, "\n");
}
}
......
......@@ -59,7 +59,7 @@ public:
void Test(const char * fn, const char * ofn, T2TModel * model);
/* dump the result into the file */
void Dump(FILE * file, const XTensor * output);
void Dump(FILE * file, XTensor * output);
};
}
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论