Commit a33c3231 by xiaotong

new transformer code

parent 3cd237ff
@@ -44,7 +44,7 @@ T2TAttention::~T2TAttention()
/*
initialize the model
>> argc - number of arguments
>> argv - list pf pointers to the arguments
>> argv - list of pointers to the arguments
>> myDevID - device id
>> myMem - the memory pool
*/
......
@@ -44,7 +44,7 @@ T2TEmbedder::~T2TEmbedder()
/*
initialize the model
>> argc - number of arguments
>> argv - list pf pointers to the arguments
>> argv - list of pointers to the arguments
>> myDevID - device id
>> myMem - the memory pool
*/
......
@@ -44,7 +44,7 @@ AttEncoder::~AttEncoder()
/*
initialize the model
>> argc - number of arguments
>> argv - list pf pointers to the arguments
>> argv - list of pointers to the arguments
>> myDevID - device id
>> myMem - the memory pool
*/
......
@@ -43,7 +43,7 @@ T2TFNN::~T2TFNN()
/*
initialize the model
>> argc - number of arguments
>> argv - list pf pointers to the arguments
>> argv - list of pointers to the arguments
>> myDevID - device id
>> myMem - the memory pool
*/
......
@@ -39,7 +39,7 @@ T2TLN::~T2TLN()
/*
initialize the model
>> argc - number of arguments
>> argv - list pf pointers to the arguments
>> argv - list of pointers to the arguments
>> myDevID - device id
>> myMem - the memory pool
*/
......
@@ -41,7 +41,11 @@ T2TModel::~T2TModel()
delete mem;
}
/* initialize the model */
/*
initialize the model
>> argc - number of arguments
>> argv - list of pointers to the arguments
*/
void T2TModel::InitModel(int argc, const char ** argv)
{
bool useMem = false;
......
@@ -43,7 +43,7 @@ T2TOutput::~T2TOutput()
/*
initialize the model
>> argc - number of arguments
>> argv - list pf pointers to the arguments
>> argv - list of pointers to the arguments
>> myDevID - device id
>> myMem - the memory pool
*/
......
/* NiuTrans.Tensor - an open-source tensor library
* Copyright (C) 2018, Natural Language Processing Lab, Northeastern University.
* All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
/*
* $Created by: XIAO Tong (xiaotong@mail.neu.edu.cn) 2018-08-02
*/
#include "T2TTrainer.h"
#include "T2TUtility.h"
namespace transformer
{
/* constructor */
T2TTrainer::T2TTrainer()
{
devID = -1;
mem = NULL;
buf = NULL;
seqLen = NULL;
seqOffset = NULL;
nseqBuf = 0;
nextSeq = -1;
}
/* de-constructor */
T2TTrainer::~T2TTrainer()
{
delete[] buf;
delete[] seqLen;
delete[] seqOffset;
}
/*
initialization
>> argc - number of arguments
>> argv - list of pointers to the arguments
*/
void T2TTrainer::Init(int argc, const char ** argv)
{
LoadParamFloat(argc, argv, "lrate", &lrate, 0.001F);
LoadParamInt(argc, argv, "sbatch", &sBatchSize, 1);
LoadParamInt(argc, argv, "wbatch", &wBatchSize, 1);
LoadParamInt(argc, argv, "nepoch", &nepoch, 1);
LoadParamInt(argc, argv, "nstep", &nstep, 1);
int maxUnitInBuf;
LoadParamInt(argc, argv, "bufsize", &maxUnitInBuf, 20000);
buf = new int[maxUnitInBuf];
seqLen = new int[maxUnitInBuf];
seqOffset = new int[maxUnitInBuf];
}
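For reference, here is a minimal sketch of how these options might be supplied to Init. The exact option syntax is defined by LoadParamInt/LoadParamFloat in T2TUtility, which this commit does not show, so the dash-prefixed flag style below is an assumption made purely for illustration.
/* a hypothetical argument list; the real option syntax lives in T2TUtility
   and may differ from the dash-prefixed style assumed here */
const char * args[] = {"-lrate", "0.0005", "-sbatch", "32", "-wbatch", "2048",
                       "-nepoch", "20", "-nstep", "100000", "-bufsize", "50000"};
T2TTrainer trainer;
trainer.Init(sizeof(args) / sizeof(args[0]), args);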
/*
train the model
>> fn - training data file
>> model - model to train
*/
void T2TTrainer::Train(const char * fn, T2TModel * model)
{
}
char line[MAX_SEQUENCE_LENGTH];
/*
load data to buffer
>> file - where to load data
*/
int T2TTrainer::LoadBuf(FILE * file)
{
int lineCount = 0;
int seqCount = 0;
int wordCount = 0;
while(fgets(line, MAX_SEQUENCE_LENGTH - 1, file)){
int len = (int)strlen(line);
/* strip the trailing newline and carriage return */
while(len > 0 && (line[len - 1] == '\n' || line[len - 1] == '\r'))
line[--len] = 0;
if(len == 0)
continue;
/* how many characters are in a word */
int wSize = 0;
/* how many words are in the sentence */
int wNum = 0;
int wNumLocal = 0;
for(int i = 0; i <= len; i++){
/* load a word (id) separated by space or tab; i == len flushes the last word */
if((i == len || line[i] == ' ' || line[i] == '\t') && wSize > 0){
line[i] = 0;
if(wSize == 3 && line[i - 1] == '|' && line[i - 2] == '|' && line[i - 3] == '|'){
seqLen[seqCount] = wNumLocal;
seqOffset[seqCount] = wordCount + wNum - wNumLocal;
seqCount++;
wNumLocal = 0;
}
else{
buf[wordCount + wNum++] = atoi(line + i - wSize);
wNumLocal++;
}
wSize = 0;
}
else
wSize++;
}
seqLen[seqCount] = wNumLocal;
seqOffset[seqCount] = wordCount + wNum - wNumLocal;
seqCount++;
wordCount += wNum;
lineCount++;
if(wordCount >= wBatchSize)
break;
if(lineCount >= sBatchSize)
break;
}
nseqBuf = seqCount;
nextSeq = 0;
return lineCount;
}
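Judging from the parsing loop above, each input line holds word ids separated by spaces or tabs, and a standalone "|||" token splits a line into several sequences. A hypothetical two-line input (made-up ids, and assuming sBatchSize and wBatchSize are large enough that both lines fit into one call) would be buffered like this:
/* hypothetical input:
       12 7 301 ||| 9 86 4 2
       5 5 23
   after LoadBuf:
       buf       = {12, 7, 301, 9, 86, 4, 2, 5, 5, 23}
       seqLen    = {3, 4, 3}
       seqOffset = {0, 3, 7}
       nseqBuf   = 3 */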
/*
load a batch of sequences
>> file - the handle to the data file
>> batch - the batch
>> step - the step we go over when move to the next sequence
>> vs - vocabulary size
>> sBatch - batch size of sequences
>> wBatch - batch size of words
>> isSorted - indicates whether the sequences are sorted by length
*/
int T2TTrainer::LoadBatch(FILE * file, XTensor * batch, int step, int vs, int sBatch, int wBatch, bool isSorted)
{
if(nextSeq >= nseqBuf)
LoadBuf(file);
int seq = nextSeq;
int wc = 0;
int sc = 0;
int max = 0;
while(seq + sc < nseqBuf){
int len = seqLen[seq + sc];
wc += len;
sc += 1;
/* keep track of the longest sequence in the batch */
if(max < len)
max = len;
if(sc >= sBatch && wc >= wBatch)
break;
}
if(sc > 0){
int dims[MAX_TENSOR_DIM_NUM];
dims[0] = sc;
dims[1] = max;
dims[2] = vs;
if(batch->order != 3 || batch->GetDim(0) != dims[0] ||
batch->GetDim(1) != dims[1] || batch->GetDim(2) != dims[2]){
InitTensor(batch, 3, dims, X_FLOAT, 1.0F, devID, mem);
}
batch->SetZeroAll();
for(int s = seq; s < seq + sc; s++){
for(int w = 0; w < seqLen[s]; w++){
batch->Set3D(1.0F, s - seq, w, buf[seqOffset[s] + w]);
}
}
}
/* move the cursor past the sequences consumed by this batch */
nextSeq = seq + sc;
return sc;
}
}
\ No newline at end of file
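Train is still a stub in this commit, so the following is only a sketch of how it might eventually drive LoadBatch; the tensor name, the loop body, and the use of vSize (which is not yet loaded anywhere in this commit) are illustrative rather than part of the committed code.
/* a sketch of a possible training loop over one data file; "input" is a
   hypothetical name and the forward/backward passes are omitted */
void T2TTrainer::Train(const char * fn, T2TModel * model)
{
    FILE * file = fopen(fn, "rb");
    CheckNTErrors(file != NULL, "cannot open the training file!");
    XTensor input;
    /* each call fills "input" with a (sentence, position, vocabulary) one-hot batch */
    while(LoadBatch(file, &input, 1, vSize, sBatchSize, wBatchSize, false) > 0){
        /* model->Make(...) and the parameter update would go here */
    }
    fclose(file);
}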
/* NiuTrans.Tensor - an open-source tensor library
* Copyright (C) 2018, Natural Language Processing Lab, Northeastern University.
* All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
/*
* $Created by: XIAO Tong (xiaotong@mail.neu.edu.cn) 2018-08-02
*/
#ifndef __T2TTRAINER_H__
#define __T2TTRAINER_H__
#include "T2TModel.h"
#include "../../tensor/function/FHeader.h"
#define MAX_SEQUENCE_LENGTH 1024 * 64
using namespace nts;
namespace transformer
{
/* trainer of the T2T model */
class T2TTrainer
{
public:
/* device id */
int devID;
/* memory pool */
XMem * mem;
/* buffer for loading words */
int * buf;
/* length of each sequence */
int * seqLen;
/* offset of the first word for each sequence */
int * seqOffset;
/* number of sequences in the buffer */
int nseqBuf;
/* offset for next sequence in the buffer */
int nextSeq;
/* vocabulary size of the source side */
int vSize;
/* learning rate */
float lrate;
/* sentence batch size */
int sBatchSize;
/* word batch size */
int wBatchSize;
/* training epoch number */
int nepoch;
/* training step number */
int nstep;
public:
/* constructor */
T2TTrainer();
/* de-constructor */
~T2TTrainer();
/* initialize the trainer */
void Init(int argc, const char ** argv);
/* train the model */
void Train(const char * fn, T2TModel * model);
/* load data to buffer */
int LoadBuf(FILE * file);
/* load a batch of sequences */
int LoadBatch(FILE * file, XTensor * batch, int step, int vs, int sBatch, int wBatch, bool isSorted);
};
}
#endif
\ No newline at end of file