Commit b7026578 by xiaotong

decoding code

parent da1d7ca8
......@@ -49,7 +49,7 @@ initialize the model
>> argc - number of arguments
>> argv - list of pointers to the arguments
*/
void T2TSearch::InitModel(int argc, char ** argv)
void T2TSearch::Init(int argc, char ** argv)
{
LoadParamInt(argc, argv, "beamsize", &beamSize, 1);
LoadParamFloat(argc, argv, "lenalpha", &alpha, 0.2F);
......@@ -100,6 +100,8 @@ void T2TSearch::Search(T2TModel * model, XTensor * input, XTensor * padding, XTe
}
delete[] states;
Dump(output);
}
/*
......
......@@ -67,7 +67,7 @@ public:
~T2TSearch();
/* initialize the model */
void InitModel(int argc, char ** argv);
void Init(int argc, char ** argv);
/* search for the most promising states */
void Search(T2TModel * model, XTensor * input, XTensor * padding, XTensor * output);
......
......@@ -19,7 +19,12 @@
* $Created by: XIAO Tong (xiaotong@mail.neu.edu.cn) 2019-03-27
*/
#include "T2TUtility.h"
#include "T2TTester.h"
#include "T2TSearch.h"
#include "../../tensor/XUtility.h"
#include "../../tensor/core/CHeader.h"
#include "../../network/XNoder.h"
using namespace nts;
......@@ -39,6 +44,11 @@ T2TTester::~T2TTester()
/* initialize the model */
void T2TTester::InitModel(int argc, char ** argv)
{
LoadParamInt(argc, argv, "vsize", &vSize, 1);
LoadParamInt(argc, argv, "vsizetgt", &vSizeTgt, vSize);
batchLoader.Init(argc, argv);
seacher.Init(argc, argv);
}
/*
......@@ -49,6 +59,78 @@ test the model
*/
void T2TTester::Test(const char * fn, const char * ofn, T2TModel * model)
{
int wc = 0;
int ws = 0;
int wordCount = 0;
int wordCountTotal = 0;
int sentCount = 0;
float loss = 0;
/* data files */
FILE * file = fopen(fn, "rb");
CheckNTErrors(file, "Cannot read the test file");
FILE * ofile = fopen(ofn, "wb");
CheckNTErrors(ofile, "Cannot open the output file");
int devID = model->devID;
XMem * mem = model->mem;
XNet net;
double startT = GetClockSec();
wordCount = 0;
/* batch of input sequences */
XTensor batchEnc;
XTensor batchDec;
/* label */
XTensor label;
/* padding */
XTensor paddingEnc;
XTensor paddingDec;
/* gold standard */
XTensor gold;
/* an array that keeps the sequences */
int * seqs = new int[MILLION];
batchLoader.ClearBuf();
while(batchLoader.LoadBatch(file, model->isLM,
&batchEnc, &paddingEnc, &paddingDec, &paddingDec, &gold, &label,
seqs, vSize, vSizeTgt,
1, 1, false, ws, wc, devID, mem, false))
{
CheckNTErrors(batchEnc.order == 2, "wrong tensor order of the sequence batch!");
CheckNTErrors(!model->isLM, "Only MT model is supported!");
XTensor output;
seacher.Search(model, &batchEnc, &paddingEnc, &output);
output.Dump(ofile, "output:");
float prob = 0;
loss += -prob;
wordCount += wc;
wordCountTotal += wc;
sentCount += 1;
}
fclose(file);
fclose(ofile);
delete[] seqs;
double elapsed = GetClockSec() - startT;
XPRINT3(0, stderr, "[INFO] test finished (took %.1fs, word=%d, and ppl=%.3f)\n",
elapsed,wordCountTotal, exp(loss / wordCount));
}
}
\ No newline at end of file
......@@ -24,6 +24,7 @@
#define __T2TTESTER_H__
#include "T2TSearch.h"
#include "T2TBatchLoader.h"
namespace transformer
{
......@@ -32,6 +33,19 @@ namespace transformer
class T2TTester
{
public:
/* vocabulary size of the source side */
int vSize;
/* vocabulary size of the target side */
int vSizeTgt;
/* for batching */
T2TBatchLoader batchLoader;
/* decoder for inference */
T2TSearch seacher;
public:
/* constructor */
T2TTester();
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论