Commit 5cd1be65 by xiaotong

new code of t2t trainer

parent 67bbdfd2
@@ -580,7 +580,7 @@ void Update(FNNModel &model, FNNModel &grad, float epsilon, bool isNodeGrad)
get prediction probabilities of the gold words
>> output - output probabilities
>> gold - gold standard
>> wordProbs - probability of each word
<< return - probability of the batch
*/
float GetProb(XTensor &output, XTensor &gold, XTensor * wordProbs)
......
@@ -19,8 +19,11 @@
* $Created by: XIAO Tong (xiaotong@mail.neu.edu.cn) 2018-08-02
*/
#include <math.h>
#include "T2TTrainer.h" #include "T2TTrainer.h"
#include "T2TUtility.h" #include "T2TUtility.h"
#include "../../tensor/XUtility.h"
#include "../../tensor/core/CHeader.h"
namespace transformer
{

@@ -52,6 +55,7 @@ void T2TTrainer::Init(int argc, const char ** argv)
    LoadParamInt(argc, argv, "wbatch", &wBatchSize, 1);
    LoadParamInt(argc, argv, "nepoch", &nepoch, 1);
    LoadParamInt(argc, argv, "nstep", &nstep, 1);
    LoadParamBool(argc, argv, "sorted", &isLenSorted, false);

    int maxUnitInBuf;
    LoadParamInt(argc, argv, "bufsize", &maxUnitInBuf, 20000);
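
The LoadParam* helpers are declared in T2TUtility.h but their bodies are not part of this diff. Below is a minimal sketch of the assumed behavior, guessing that options are passed on the command line as "-name value" pairs; the real implementations may differ.

/* Hypothetical sketch of the LoadParam* helpers used in Init above;
   the real implementations live in T2TUtility.cpp and may differ. */
#include <cstdio>
#include <cstdlib>
#include <cstring>

void LoadParamInt(int argc, const char ** argv, const char * name, int * p, int defaultP)
{
    char fullName[128];
    snprintf(fullName, sizeof(fullName), "-%s", name);
    *p = defaultP;
    for(int i = 0; i < argc - 1; i++){
        if(!strcmp(argv[i], fullName))
            *p = atoi(argv[i + 1]);         /* take the token that follows the flag */
    }
}

void LoadParamBool(int argc, const char ** argv, const char * name, bool * p, bool defaultP)
{
    char fullName[128];
    snprintf(fullName, sizeof(fullName), "-%s", name);
    *p = defaultP;
    for(int i = 0; i < argc; i++){
        if(!strcmp(argv[i], fullName))
            *p = true;                      /* presence of the flag switches it on */
    }
}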
@@ -67,6 +71,64 @@ train the model
*/
void T2TTrainer::Train(const char * fn, T2TModel * model)
{
    int epoch = 0;
    int step = 0;
    int wc = 0;
    int wordCount = 0;
    int wordCountTotal = 0;
    bool isEnd = false;
    float loss = 0;

    double startT = GetClockSec();

    for(epoch = 0; epoch < nepoch && !isEnd; epoch++){
        FILE * file = fopen(fn, "rb");
        CheckNTErrors(file, "cannot open training file!");

        wordCount = 0;

        /* batch of input sequences */
        XTensor batch;

        /* output probabilities */
        XTensor output;

        while(LoadBatch(file, &batch, 1, vSize, sBatchSize, wBatchSize, isLenSorted, wc)){

            /* make the network */
            model->Make(&batch, &output);

            /* TODO: update the model!!!! */

            /* get probabilities */
            float prob = GetProb(&output, &batch, NULL);

            loss += -prob;
            wordCount += wc;
            wordCountTotal += wc;

            /* stop the whole run after nstep updates in total */
            if(++step >= nstep){
                isEnd = true;
                break;
            }

            if(step % 100 == 0){
                double elapsed = GetClockSec() - startT;
                XPRINT5(0, stderr, "[INFO] elapsed=%.1fs, step=%d, epoch=%d, ngram=%d, ppl=%.3f\n",
                        elapsed, step, epoch + 1, wordCountTotal, exp(loss / wordCountTotal));
            }
        }

        fclose(file);
    }

    double elapsed = GetClockSec() - startT;

    XPRINT5(0, stderr, "[INFO] elapsed=%.1fs, step=%d, epoch=%d, ngram=%d, ppl=%.3f\n",
            elapsed, step, epoch, wordCountTotal, exp(loss / wordCountTotal));
    XPRINT3(0, stderr, "[INFO] training finished (took %.1fs, step=%d and epoch=%d)\n",
            elapsed, step, epoch);
}
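
For context, a minimal driver that would exercise this Train path might look like the sketch below. It is an assumption, not part of this commit: the binary name, the InitModel call, and the training-file path are placeholders; only the trainer options ("wbatch", "nepoch", "nstep", "sorted", "bufsize") come from Init above.

/* Hypothetical driver, e.g. invoked as:
       t2t -wbatch 2048 -nepoch 10 -nstep 100000 -sorted
   T2TModel setup is not part of this diff; InitModel is an assumed initializer. */
#include "T2TModel.h"
#include "T2TTrainer.h"

int main(int argc, const char ** argv)
{
    transformer::T2TModel model;
    model.InitModel(argc, argv);            /* assumed: build the network from the options */

    transformer::T2TTrainer trainer;
    trainer.Init(argc, argv);               /* reads wbatch, nepoch, nstep, sorted, bufsize */
    trainer.Train("train.txt", &model);     /* placeholder path; the real one would come from argv */

    return 0;
}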
char line[MAX_SEQUENCE_LENGTH];

@@ -126,10 +188,7 @@ int T2TTrainer::LoadBuf(FILE * file)
        wordCount += wNum;
        lineCount++;

        if(wordCount >= wBatchSize || lineCount >= sBatchSize)
            break;
    }
@@ -148,8 +207,9 @@ load a batch of sequences
>> sBatch - batch size of sequences
>> wBatch - batch size of words
>> isSorted - indicates whether the sequences are sorted by length
>> wCount - word count of the batch (returned by reference)
*/
int T2TTrainer::LoadBatch(FILE * file, XTensor * batch, int step, int vs, int sBatch, int wBatch, bool isSorted, int &wCount)
{
    if(nextSeq >= nseqBuf)
        LoadBuf(file);

@@ -182,14 +242,55 @@ int T2TTrainer::LoadBatch(FILE * file, XTensor * batch, int step, int vs, int sBatch, int wBatch, bool isSorted, int &wCount)
        batch->SetZeroAll();
        /* this might be slow on GPUs :( */
        for(int s = seq; s < seq + sc; s++){
            for(int w = 0; w < seqLen[s]; w++){
                batch->Set3D(1.0F, s - seq, w, buf[seqOffset[s] + w]);
                wCount++;
            }
        }
    }

    return sc;
}
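
The Set3D loop above builds a one-hot encoding: batch is a 3D tensor of shape (sentence count, maximum sequence length, vocabulary size) with a 1.0 at each (sentence, position, word id) triple, and wCount ends up holding the number of words written. A standalone sketch of the same layout, with made-up sizes:

/* Illustrative sketch of the one-hot batch layout built by LoadBatch;
   the sizes and word ids are invented for the example. */
#include <vector>

int main()
{
    const int sc = 2;                       /* sentences in the batch */
    const int maxLen = 4;                   /* longest sequence length */
    const int vs = 8;                       /* vocabulary size */

    /* word ids of two toy sequences */
    std::vector<std::vector<int>> seqs = { {3, 5, 1}, {2, 7, 4, 0} };

    /* flat (sc x maxLen x vs) tensor, zero-initialized like batch->SetZeroAll() */
    std::vector<float> batch(sc * maxLen * vs, 0.0F);

    int wCount = 0;
    for(int s = 0; s < sc; s++){
        for(int w = 0; w < (int)seqs[s].size(); w++){
            batch[(s * maxLen + w) * vs + seqs[s][w]] = 1.0F;   /* Set3D(1.0F, s, w, id) */
            wCount++;
        }
    }
    /* wCount == 7 here: the value LoadBatch returns through its reference parameter */
    return 0;
}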
/*
get word probabilities for a batch of sequences
>> output - word distribution for each position
>> gold - gold standard
>> wordProbs - word probability for gold prediction
<< return - probability of the batch
*/
float T2TTrainer::GetProb(XTensor * output, XTensor * gold, XTensor * wordProbs)
{
    XTensor probs;
    InitTensor(&probs, output);

    /* probs[i,j] = output[i,j] * gold[i,j] */
    _Multiply(output, gold, &probs);

    /* probability of each word */
    XTensor wprobs;
    InitTensor1D(&wprobs, output->unitNum/output->GetDim(-1), X_FLOAT, output->devID, output->mem);

    int dims[2] = {output->unitNum/output->GetDim(-1), output->GetDim(-1)};
    probs.Reshape(2, dims);
    _ReduceSum(&probs, &wprobs, 1);

    if(wordProbs != NULL)
        _CopyValues(&wprobs, wordProbs);

    /* reshape the tensor to fit it into the reduce procedure
       TODO: XTensor supports scalars */
    dims[0] = 1;
    dims[1] = probs.unitNum;
    probs.Reshape(2, dims);

    /* probability for the batch */
    XTensor result;
    InitTensor1D(&result, 1, X_FLOAT, output->devID, output->mem);
    _ReduceSum(&probs, &result, 1);

    return result.Get1D(0);
}
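
Because gold is one-hot (see the Set3D loop in LoadBatch), the multiply-and-reduce pipeline simply selects the score output assigns to each gold word and sums the results. Assuming output holds log probabilities, which is what the exp(loss / wordCountTotal) perplexity in Train implies, the returned value is:

% batch score returned by GetProb, with g_i the gold word id at position i
\log P(\text{batch}) = \sum_{i}\sum_{j} \text{gold}_{ij}\,\text{output}_{ij}
                     = \sum_{i} \text{output}_{i,\,g_i}

Here i ranges over all positions in the batch; wprobs holds the inner per-position terms before the final reduction, which is why copying it out through wordProbs yields the per-word probabilities at no extra cost.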
}
\ No newline at end of file
@@ -57,6 +57,9 @@ public:
    /* offset for next sequence in the buffer */
    int nextSeq;

    /* indicates whether the sequences are sorted by length */
    bool isLenSorted;
    /* vocabulary size of the source side */
    int vSize;

@@ -93,10 +96,13 @@ public:
    int LoadBuf(FILE * file);

    /* load a batch of sequences */
    int LoadBatch(FILE * file, XTensor * batch, int step, int vs, int sBatch, int wBatch, bool isSorted, int &wCount);

    /* get word probabilities for a batch of sequences */
    float GetProb(XTensor * output, XTensor * gold, XTensor * wordProbs);
};

}

#endif
\ No newline at end of file