Commit da1d7ca8 by xiaotong

tester

parent 1faabe78
/* NiuTrans.Tensor - an open-source tensor library
* Copyright (C) 2018, Natural Language Processing Lab, Northeastern University.
* All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
/*
* $Created by: XIAO Tong (xiaotong@mail.neu.edu.cn) 2019-04-25
* It is cold today but I'll move to a warm place tomorrow :)
*/
#ifndef __T2TBATCHLOADER_H__
#define __T2TBATCHLOADER_H__
#include "../../network/XNet.h"
using namespace nts;
namespace transformer
{
#define MAX_SEQUENCE_LENGTH 1024 * 4
/* node to keep batch information */
struct BatchNode
{
/* beginning position */
int beg;
/* end position */
int end;
/* maximum word number on the encoder side */
int maxEnc;
/* maximum word number on the decoder side */
int maxDec;
/* a key for sorting */
int key;
};
class T2TBatchLoader
{
public:
/* buffer for loading words */
int * buf;
/* a second word buffer (counterpart of buf) */
int * buf2;
/* buffer of batch nodes */
BatchNode * bufBatch;
/* buffer size */
int bufSize;
/* size of batch buffer */
int bufBatchSize;
/* length of each sequence */
int * seqLen;
/* a second length array (counterpart of seqLen) */
int * seqLen2;
/* offset of the first word for each sequence */
int * seqOffset;
/* number of sequences in the buffer */
int nseqBuf;
/* offset for next sequence in the buffer */
int nextSeq;
/* offset for next batch */
int nextBatch;
/* indicates whether we double the </s> symbol for the output of language models */
bool isDoubledEnd;
/* indicates whether we use batchsize = max * sc
rather than batchsize = word-number, where max is the maximum
length and sc is the sentence number */
bool isSmallBatch;
/* counterpart of "isSmallBatch" */
bool isBigBatch;
/* randomize batches */
bool isRandomBatch;
/* bucket size */
int bucketSize;
public:
/* constructor */
T2TBatchLoader();
/* destructor */
~T2TBatchLoader();
/* initialization */
void Init(int argc, char ** argv);
/* load data to buffer */
int LoadBuf(FILE * file, bool isSorted, int step);
/* clear data buffer */
void ClearBuf();
/* load a batch of sequences */
int LoadBatch(FILE * file, bool isLM,
XTensor * batchEnc, XTensor * paddingEnc,
XTensor * batchDec, XTensor * paddingDec,
XTensor * gold, XTensor * label,
int * seqs,
int vsEnc, int vsDec, int sBatch, int wBatch,
bool isSorted, int &ws, int &wCount,
int devID, XMem * mem,
bool isTraining);
/* load a batch of sequences (for language modeling) */
int LoadBatchLM(FILE * file,
XTensor * batchEnc, XTensor * paddingEnc,
XTensor * batchDec, XTensor * paddingDec,
XTensor * gold, XTensor * label,
int * seqs, int vs, int sBatch, int wBatch,
bool isSorted, int &wCount,
int devID, XMem * mem,
bool isTraining);
/* load a batch of sequences (for machine translation) */
int LoadBatchMT(FILE * file,
XTensor * batchEnc, XTensor * paddingEnc,
XTensor * batchDec, XTensor * paddingDec,
XTensor * gold, XTensor * label,
int * seqs, int vsEnc, int vsDec, int sBatch, int wBatch,
bool isSorted, int &ws, int &wCount,
int devID, XMem * mem,
bool isTraining);
/* shuffle the data files */
void Shuffle(const char * srcFile, const char * tgtFile);
};
}
#endif
\ No newline at end of file
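The new T2TBatchLoader class bundles everything the trainer used to do for batching: buffering raw word streams, grouping sequences into BatchNode records, and emitting padded encoder/decoder tensors. Below is a minimal sketch of how a caller might drive it; the file name "train.txt", the vocabulary and batch sizes, and the surrounding function are illustrative assumptions, not part of this commit.

#include <cstdio>
#include "T2TBatchLoader.h"

using namespace transformer;

/* hypothetical driver: pull batches until the file is exhausted */
void LoadAllBatchesSketch(int argc, char ** argv)
{
    T2TBatchLoader loader;
    loader.Init(argc, argv);

    FILE * file = fopen("train.txt", "rb");    /* placeholder path */
    XTensor batchEnc, paddingEnc, batchDec, paddingDec, gold, label;
    int ws = 0, wCount = 0;

    /* isLM = true, vocabulary sizes 10000/10000, at most 32 sentences
       or 2048 words per batch, CPU (devID = -1), no memory pool */
    while(loader.LoadBatch(file, true,
          &batchEnc, &paddingEnc, &batchDec, &paddingDec,
          &gold, &label, NULL,
          10000, 10000, 32, 2048,
          true, ws, wCount, -1, NULL, true) > 0)
    {
        /* a forward/backward pass over the batch would go here */
    }

    fclose(file);
}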
......@@ -319,17 +319,40 @@ void T2TSearch::Collect(T2TStateBundle * beam)
/*
save the output sequences in a tensor
->> input - input sequences
>> output - output sequences (for return)
*/
-void T2TSearch::Dump(XTensor * input, XTensor * output)
+void T2TSearch::Dump(XTensor * output)
{
-    int dims[MAX_TENSOR_DIM_NUM];
-    for(int i = 0; i < input->order - 1; i++)
-        dims[i] = input->GetDim(i);
-    dims[input->order - 1] = maxLength;
+    int dims[3] = {batchSize, beamSize, maxLength};
+    int * words = new int[maxLength];
+    InitTensor(output, 3, dims, X_INT);
+    SetDataFixedInt(*output, -1);
+    /* heap for an input sentence in the batch */
+    for(int h = 0; h < batchSize; h++){
+        XHeap<MIN_HEAP, float> &heap = fullHypos[h];
+        /* for each output in the beam */
+        for(int i = 0; i < beamSize; i++){
+            T2TState * state = (T2TState *)heap.Pop().index;
+            int count = 0;
+            /* we track the state from the end to the beginning */
+            while(state != NULL){
+                words[count++] = state->prediction;
+                state = state->last;
+            }
+            /* dump the sentence to the output tensor */
+            for(int w = 0; w < count; w++)
+                output->Set3DInt(words[count - w - 1], h, i, w);
+        }
+    }
-    InitTensor(output, input->order, dims, X_INT);
+    delete[] words;
}
/*
......
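The rewritten Dump() recovers each hypothesis by walking the chain of T2TState objects backwards from the final state, since every state stores only a pointer to its predecessor. Here is a self-contained sketch of that back-tracking idea; the tiny State struct and the Backtrack function are illustrative stand-ins, not the real T2TState API.

#include <cstddef>

struct State
{
    int prediction;    /* word predicted at this step */
    State * last;      /* preceding state in the hypothesis */
};

/* collect words from the last state back to the first, then reverse
   in place so words[] reads left to right; returns the length */
int Backtrack(State * end, int * words, int maxLength)
{
    int count = 0;
    for(State * s = end; s != NULL && count < maxLength; s = s->last)
        words[count++] = s->prediction;
    for(int i = 0; i < count / 2; i++){
        int t = words[i];
        words[i] = words[count - i - 1];
        words[count - i - 1] = t;
    }
    return count;
}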
......@@ -88,7 +88,7 @@ public:
void Collect(T2TStateBundle * beam);
/* save the output sequences in a tensor */
-void Dump(XTensor * input, XTensor * output);
+void Dump(XTensor * output);
/* check if the token is an end symbol */
bool IsEnd(int token);
......
......@@ -25,4 +25,30 @@ using namespace nts;
namespace transformer
{
+/* constructor */
+T2TTester::T2TTester()
+{
+}
+/* destructor */
+T2TTester::~T2TTester()
+{
+}
+/* initialize the model */
+void T2TTester::InitModel(int argc, char ** argv)
+{
+}
+/*
+test the model
+>> fn - test data file
+>> ofn - output data file
+>> model - model that is trained
+*/
+void T2TTester::Test(const char * fn, const char * ofn, T2TModel * model)
+{
+}
}
\ No newline at end of file
......@@ -32,6 +32,17 @@ namespace transformer
class T2TTester
{
public:
+/* constructor */
+T2TTester();
+/* destructor */
+~T2TTester();
+/* initialize the model */
+void InitModel(int argc, char ** argv);
+/* test the model */
+void Test(const char * fn, const char * ofn, T2TModel * model);
};
}
......
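Taken together, the two tester files introduce a class that mirrors the trainer's interface. A minimal sketch of how it might be invoked from a driver routine follows; the file names are placeholders, T2TModel::InitModel is assumed to carry the usual (argc, argv) signature, and both T2TTester methods are still empty stubs in this commit.

#include "T2TModel.h"
#include "T2TTester.h"

using namespace transformer;

/* hypothetical driver for decoding a test set */
void RunTesterSketch(int argc, char ** argv)
{
    T2TModel model;
    model.InitModel(argc, argv);    /* assumed existing T2TModel API */

    T2TTester tester;
    tester.InitModel(argc, argv);
    tester.Test("test.txt", "output.txt", &model);   /* placeholder paths */
}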
......@@ -23,35 +23,14 @@
#define __T2TTRAINER_H__
#include "T2TModel.h"
#include "T2TBatchLoader.h"
#include "../../tensor/function/FHeader.h"
#define MAX_SEQUENCE_LENGTH 1024 * 4
using namespace nts;
namespace transformer
{
/* node to keep batch information */
struct BatchNode
{
/* begining position */
int beg;
/* end position */
int end;
/* maximum word number on the encoder side */
int maxEnc;
/* maximum word number on the decoder side */
int maxDec;
/* a key for sorting */
int key;
};
/* trainer of the T2T model */
class T2TTrainer
{
......@@ -61,42 +40,6 @@ public:
/* parameter array */
char ** argArray;
-/* buffer for loading words */
-int * buf;
-/* another buffer */
-int * buf2;
-/* batch buf */
-BatchNode * bufBatch;
-/* buffer size */
-int bufSize;
-/* size of batch buffer */
-int bufBatchSize;
-/* length of each sequence */
-int * seqLen;
-/* another array */
-int * seqLen2;
-/* offset of the first word for each sequence */
-int * seqOffset;
-/* number of sequences in the buffer */
-int nseqBuf;
-/* offset for next sequence in the buffer */
-int nextSeq;
-/* offset for next batch */
-int nextBatch;
-/* indicates whether the sequence is sorted by length */
-bool isLenSorted;
/* dimension size of each inner layer */
int d;
......@@ -158,26 +101,15 @@ public:
/* number of batches on which we do model update */
int updateStep;
-/* indicates whether we double the </s> symbol for the output of lms */
-bool isDoubledEnd;
-/* indicates whether we use batchsize = max * sc
-rather rather than batchsize = word-number, where max is the maximum
-length and sc is the sentence number */
-bool isSmallBatch;
-/* counterpart of "isSmallBatch" */
-bool isBigBatch;
-/* randomize batches */
-bool isRandomBatch;
/* indicates whether we intend to debug the net */
bool isDebugged;
-/* bucket size */
-int bucketSize;
+/* indicates whether the sequence is sorted by length */
+bool isLenSorted;
+/* for batching */
+T2TBatchLoader batchLoader;
public:
/* constructor */
......@@ -197,46 +129,6 @@ public:
/* make a checkpoint */
void MakeCheckpoint(T2TModel * model, const char * validFN, const char * modelFN, const char * label, int id);
-/* load data to buffer */
-int LoadBuf(FILE * file, bool isSorted, int step);
-/* clear data buffer */
-void ClearBuf();
-/* load a batch of sequences */
-int LoadBatch(FILE * file, bool isLM,
-              XTensor * batchEnc, XTensor * paddingEnc,
-              XTensor * batchDec, XTensor * paddingDec,
-              XTensor * gold, XTensor * label,
-              int * seqs,
-              int vsEnc, int vsDec, int sBatch, int wBatch,
-              bool isSorted, int &ws, int &wCount,
-              int devID, XMem * mem,
-              bool isTraining);
-/* load a batch of sequences (for language modeling) */
-int LoadBatchLM(FILE * file,
-                XTensor * batchEnc, XTensor * paddingEnc,
-                XTensor * batchDec, XTensor * paddingDec,
-                XTensor * gold, XTensor * label,
-                int * seqs, int vs, int sBatch, int wBatch,
-                bool isSorted, int &wCount,
-                int devID, XMem * mem,
-                bool isTraining);
-/* load a batch of sequences (for machine translation) */
-int LoadBatchMT(FILE * file,
-                XTensor * batchEnc, XTensor * paddingEnc,
-                XTensor * batchDec, XTensor * paddingDec,
-                XTensor * gold, XTensor * label,
-                int * seqs, int vsEnc, int vsDec, int sBatch, int wBatch,
-                bool isSorted, int &ws, int &wCount,
-                int devID, XMem * mem,
-                bool isTraining);
-/* shuffle the data file */
-void Shuffle(const char * srcFile, const char * tgtFile);
/* get word probabilities for a batch of sequences */
float GetProb(XTensor * output, XTensor * gold, XTensor * wordProbs);
......
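The net effect of the T2TTrainer.h hunks is that the trainer keeps only its training-specific state and delegates all batching to the new batchLoader member. A sketch of what a call site inside the trainer might now look like follows; the argument values are illustrative, and only the member names (batchLoader, isLenSorted) come from this header.

#include <cstdio>
#include "T2TTrainer.h"

using namespace transformer;

/* hypothetical helper: the old "LoadBatch(...)" call sites inside
   T2TTrainer become "batchLoader.LoadBatch(...)" after the refactoring */
int LoadOneBatchSketch(T2TTrainer & trainer, FILE * file,
                       XTensor * batchEnc, XTensor * paddingEnc,
                       XTensor * batchDec, XTensor * paddingDec,
                       XTensor * gold, XTensor * label,
                       int & ws, int & wCount)
{
    return trainer.batchLoader.LoadBatch(file, true,
           batchEnc, paddingEnc, batchDec, paddingDec,
           gold, label, NULL,
           10000, 10000, 32, 2048,
           trainer.isLenSorted, ws, wCount,
           -1, NULL, true);
}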