Commit 5cd1be65 by xiaotong

new code of t2t trainer

parent 67bbdfd2
......@@ -580,7 +580,7 @@ void Update(FNNModel &model, FNNModel &grad, float epsilon, bool isNodeGrad)
get prediction probabilites of the gold words
>> output - output probabilities
>> gold - gold standard
>>
>> wordPobs - probability of each word
<< return - probability of the batch
*/
float GetProb(XTensor &output, XTensor &gold, XTensor * wordProbs)
......
......@@ -19,8 +19,11 @@
* $Created by: XIAO Tong (xiaotong@mail.neu.edu.cn) 2018-08-02
*/
#include <math.h>
#include "T2TTrainer.h"
#include "T2TUtility.h"
#include "../../tensor/XUtility.h"
#include "../../tensor/core/CHeader.h"
namespace transformer
{
......@@ -52,6 +55,7 @@ void T2TTrainer::Init(int argc, const char ** argv)
LoadParamInt(argc, argv, "wbatch", &wBatchSize, 1);
LoadParamInt(argc, argv, "nepoch", &nepoch, 1);
LoadParamInt(argc, argv, "nstep", &nstep, 1);
LoadParamBool(argc, argv, "sorted", &isLenSorted, false);
int maxUnitInBuf;
LoadParamInt(argc, argv, "bufsize", &maxUnitInBuf, 20000);
......@@ -67,6 +71,64 @@ train the model
*/
void T2TTrainer::Train(const char * fn, T2TModel * model)
{
int epoch = 0;
int step = 0;
int wc = 0;
int wordCount = 0;
int wordCountTotal = 0;
bool isEnd = false;
float loss = 0;
double startT = GetClockSec();
for(epoch = 0; epoch < nepoch; epoch++){
FILE * file = fopen(fn, "rb");
CheckNTErrors(file, "cannot open training file!");
wordCount = 0;
/* batch of input sequences */
XTensor batch;
/* output probabilities */
XTensor output;
while(LoadBatch(file, &batch, 1, vSize, sBatchSize, wBatchSize, isLenSorted, wc)){
/* make the network */
model->Make(&batch, &output);
/* TODO: update the model!!!! */
/* get probabilities */
float prob = GetProb(&output, &batch, NULL);
loss += -prob;
wordCount += wc;
wordCountTotal += wc;
if(++step >= nstep){
isEnd = true;
break;
}
if (step % 100 == 0) {
double elapsed = GetClockSec() - startT;
XPRINT5(0, stderr, "[INFO] elapsed=%.1fs, step=%d, epoch=%d, ngram=%d, ppl=%.3f\n",
elapsed, step, epoch + 1, wordCountTotal, exp(loss / wordCount));
}
}
fclose(file);
}
double elapsed = GetClockSec() - startT;
XPRINT5(0, stderr, "[INFO] elapsed=%.1fs, step=%d, epoch=%d, ngram=%d, ppl=%.3f\n",
elapsed, step, epoch, wordCountTotal, exp(loss / wordCount));
XPRINT3(0, stderr, "[INFO] training finished (took %.1fs, step=%d and epoch=%d)\n",
elapsed, step, epoch);
}
char line[MAX_SEQUENCE_LENGTH];
......@@ -126,10 +188,7 @@ int T2TTrainer::LoadBuf(FILE * file)
wordCount += wNum;
lineCount++;
if(wordCount >= wBatchSize)
break;
if(lineCount >= sBatchSize)
if(wordCount >= wBatchSize || lineCount >= sBatchSize)
break;
}
......@@ -148,8 +207,9 @@ load a batch of sequences
>> sBatch - batch size of sequences
>> wBatch - batch size of words
>> isSorted - indicates whether the sequences are sorted by length
>> wCount - word count
*/
int T2TTrainer::LoadBatch(FILE * file, XTensor * batch, int step, int vs, int sBatch, int wBatch, bool isSorted)
int T2TTrainer::LoadBatch(FILE * file, XTensor * batch, int step, int vs, int sBatch, int wBatch, bool isSorted, int &wCount)
{
if(nextSeq >= nseqBuf)
LoadBuf(file);
......@@ -182,14 +242,55 @@ int T2TTrainer::LoadBatch(FILE * file, XTensor * batch, int step, int vs, int sB
batch->SetZeroAll();
/* this might be slow on GPUs :( */
for(int s = seq; s < seq + sc; s++){
for(int w = 0; w < seqLen[s]; w++){
batch->Set3D(1.0F, s - seq, w, buf[seqOffset[s] + w]);
wCount++;
}
}
}
return sc;
}
/*
get word probabilities for a batch of sequences
>> output - word distribution for each position
>> gold - gold standard
>> wordProbs - word probability for gold prediction
*/
float T2TTrainer::GetProb(XTensor * output, XTensor * gold, XTensor * wordProbs)
{
XTensor probs;
InitTensor(&probs, output);
/* probs[i,j] = output[i,j] * gold[i,j] */
_Multiply(output, gold, &probs);
/* probability of each word */
XTensor wprobs;
InitTensor1D(&wprobs, output->unitNum/output->GetDim(-1), X_FLOAT, output->devID, output->mem);
int dims[2] = {output->unitNum/output->GetDim(-1), output->GetDim(-1)};
probs.Reshape(2, dims);
_ReduceSum(&probs, &wprobs, 1);
if(wordProbs != NULL)
_CopyValues(&wprobs, wordProbs);
/* reshape the tensor to fit it into the reduce procedure
TODO: XTensor supports scalars */
dims[0] = 1;
dims[1] = probs.unitNum;
probs.Reshape(2, dims);
/* probability for the batch */
XTensor result;
InitTensor1D(&result, 1, X_FLOAT, output->devID, output->mem);
_ReduceSum(&probs, &result, 1);
return result.Get1D(0);
}
}
\ No newline at end of file
}
......@@ -57,6 +57,9 @@ public:
/* offset for next sequence in the buffer */
int nextSeq;
/* indicates whether the sequence is sorted by length */
bool isLenSorted;
/* vocabulary size of the source side */
int vSize;
......@@ -93,10 +96,13 @@ public:
int LoadBuf(FILE * file);
/* load a batch of sequences */
int LoadBatch(FILE * file, XTensor * batch, int step, int vs, int sBatch, int wBatch, bool isSorted);
int LoadBatch(FILE * file, XTensor * batch, int step, int vs, int sBatch, int wBatch, bool isSorted, int &wCount);
/* get word probabilities for a batch of sequences */
float GetProb(XTensor * output, XTensor * gold, XTensor * wordProbs);
};
}
#endif
\ No newline at end of file
#endif
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论