Commit b7026578 by xiaotong

decoding code

parent da1d7ca8
...@@ -49,7 +49,7 @@ initialize the model ...@@ -49,7 +49,7 @@ initialize the model
>> argc - number of arguments >> argc - number of arguments
>> argv - list of pointers to the arguments >> argv - list of pointers to the arguments
*/ */
void T2TSearch::InitModel(int argc, char ** argv) void T2TSearch::Init(int argc, char ** argv)
{ {
LoadParamInt(argc, argv, "beamsize", &beamSize, 1); LoadParamInt(argc, argv, "beamsize", &beamSize, 1);
LoadParamFloat(argc, argv, "lenalpha", &alpha, 0.2F); LoadParamFloat(argc, argv, "lenalpha", &alpha, 0.2F);
...@@ -100,6 +100,8 @@ void T2TSearch::Search(T2TModel * model, XTensor * input, XTensor * padding, XTe ...@@ -100,6 +100,8 @@ void T2TSearch::Search(T2TModel * model, XTensor * input, XTensor * padding, XTe
} }
delete[] states; delete[] states;
Dump(output);
} }
/* /*
......
...@@ -67,7 +67,7 @@ public: ...@@ -67,7 +67,7 @@ public:
~T2TSearch(); ~T2TSearch();
/* initialize the model */ /* initialize the model */
void InitModel(int argc, char ** argv); void Init(int argc, char ** argv);
/* search for the most promising states */ /* search for the most promising states */
void Search(T2TModel * model, XTensor * input, XTensor * padding, XTensor * output); void Search(T2TModel * model, XTensor * input, XTensor * padding, XTensor * output);
......
...@@ -19,7 +19,12 @@ ...@@ -19,7 +19,12 @@
* $Created by: XIAO Tong (xiaotong@mail.neu.edu.cn) 2019-03-27 * $Created by: XIAO Tong (xiaotong@mail.neu.edu.cn) 2019-03-27
*/ */
#include "T2TUtility.h"
#include "T2TTester.h" #include "T2TTester.h"
#include "T2TSearch.h"
#include "../../tensor/XUtility.h"
#include "../../tensor/core/CHeader.h"
#include "../../network/XNoder.h"
using namespace nts; using namespace nts;
...@@ -39,6 +44,11 @@ T2TTester::~T2TTester() ...@@ -39,6 +44,11 @@ T2TTester::~T2TTester()
/* initialize the model */ /* initialize the model */
void T2TTester::InitModel(int argc, char ** argv) void T2TTester::InitModel(int argc, char ** argv)
{ {
LoadParamInt(argc, argv, "vsize", &vSize, 1);
LoadParamInt(argc, argv, "vsizetgt", &vSizeTgt, vSize);
batchLoader.Init(argc, argv);
seacher.Init(argc, argv);
} }
/* /*
...@@ -49,6 +59,78 @@ test the model ...@@ -49,6 +59,78 @@ test the model
*/ */
void T2TTester::Test(const char * fn, const char * ofn, T2TModel * model) void T2TTester::Test(const char * fn, const char * ofn, T2TModel * model)
{ {
int wc = 0;
int ws = 0;
int wordCount = 0;
int wordCountTotal = 0;
int sentCount = 0;
float loss = 0;
/* data files */
FILE * file = fopen(fn, "rb");
CheckNTErrors(file, "Cannot read the test file");
FILE * ofile = fopen(ofn, "wb");
CheckNTErrors(ofile, "Cannot open the output file");
int devID = model->devID;
XMem * mem = model->mem;
XNet net;
double startT = GetClockSec();
wordCount = 0;
/* batch of input sequences */
XTensor batchEnc;
XTensor batchDec;
/* label */
XTensor label;
/* padding */
XTensor paddingEnc;
XTensor paddingDec;
/* gold standard */
XTensor gold;
/* an array that keeps the sequences */
int * seqs = new int[MILLION];
batchLoader.ClearBuf();
while(batchLoader.LoadBatch(file, model->isLM,
&batchEnc, &paddingEnc, &paddingDec, &paddingDec, &gold, &label,
seqs, vSize, vSizeTgt,
1, 1, false, ws, wc, devID, mem, false))
{
CheckNTErrors(batchEnc.order == 2, "wrong tensor order of the sequence batch!");
CheckNTErrors(!model->isLM, "Only MT model is supported!");
XTensor output;
seacher.Search(model, &batchEnc, &paddingEnc, &output);
output.Dump(ofile, "output:");
float prob = 0;
loss += -prob;
wordCount += wc;
wordCountTotal += wc;
sentCount += 1;
}
fclose(file);
fclose(ofile);
delete[] seqs;
double elapsed = GetClockSec() - startT;
XPRINT3(0, stderr, "[INFO] test finished (took %.1fs, word=%d, and ppl=%.3f)\n",
elapsed,wordCountTotal, exp(loss / wordCount));
} }
} }
\ No newline at end of file
...@@ -24,6 +24,7 @@ ...@@ -24,6 +24,7 @@
#define __T2TTESTER_H__ #define __T2TTESTER_H__
#include "T2TSearch.h" #include "T2TSearch.h"
#include "T2TBatchLoader.h"
namespace transformer namespace transformer
{ {
...@@ -32,6 +33,19 @@ namespace transformer ...@@ -32,6 +33,19 @@ namespace transformer
class T2TTester class T2TTester
{ {
public: public:
/* vocabulary size of the source side */
int vSize;
/* vocabulary size of the target side */
int vSizeTgt;
/* for batching */
T2TBatchLoader batchLoader;
/* decoder for inference */
T2TSearch seacher;
public:
/* constructor */ /* constructor */
T2TTester(); T2TTester();
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论