Commit 59d8b4bf by xiaotong

bug fix

parent 78bdfb45
......@@ -121,10 +121,10 @@ XTensor T2TAttention::Make(XTensor &k, XTensor &q, XTensor &v, XTensor &mask)
dot = BMMul(qheads, X_NOTRANS, kheads, X_TRANS);
if(isMasked)
dot = dot + mask;
scalar = Softmax(Linear(dot, 1/(float)sqrt((float)dk)), -1);
scalar = Softmax(Linear(dot, 1.0F/(float)sqrt((float)dk)), -1);
if(ignored > 0)
_SetDataDim(&scalar, 0, ignored, scalar.order - 2, 1e-9F);
//if(ignored > 0)
// _SetDataDim(&scalar, 0, ignored, scalar.order - 2, 1e-9F);
att = BMMul(scalar, vheads);
......
......@@ -123,7 +123,8 @@ XTensor T2TEmbedder::Make(XTensor &input)
}
/* we make positional embeddings first */
if(!match){
//if(!match){
if(true){
InitTensor(&posEmbedding, input.order, dims, X_FLOAT, 1.0F, devID, mem);
XTensor * posTMP = NewTensorBuf(2, dims + 1, X_FLOAT, 1.0F, devID, mem);
......
......@@ -55,7 +55,7 @@ void T2TModel::InitModel(int argc, const char ** argv)
LoadParamInt(argc, argv, "dev", &devID, -1);
LoadParamBool(argc, argv, "mem", &useMem, useMem);
LoadParamInt(argc, argv, "memsize", &memSize, 256);
LoadParamInt(argc, argv, "memsize", &memSize, 1024);
LoadParamBool(argc, argv, "lm", &isLM, true);
LoadParamBool(argc, argv, "mt", &isMT, false);
LoadParamInt(argc, argv, "nhead", &nhead, 8);
......@@ -66,7 +66,7 @@ void T2TModel::InitModel(int argc, const char ** argv)
mem->SetDesiredSize(devID, 0, (MTYPE)memSize * MILLION);
}
encoder.InitModel(argc, argv, isLM, isLM ? 1 : 0, devID, mem);
encoder.InitModel(argc, argv, isLM, 0, devID, mem);
outputLayer.InitModel(argc, argv, devID, mem);
}
......@@ -104,7 +104,7 @@ void T2TModel::Make(XTensor &input, XTensor &output)
/* a upper triangular matrix where the cells of the upper triangular are set to -1e-9.
this matrix can be used to prevent the attention to current or following words in
a given sequence. */
_SetDataLowTri(&mask, 1e9F, -1);
_SetDataLowTri(&mask, 1e9F, 0);
_ScaleAndShiftMe(&mask, 1.0F, -1e9F);
encoding = MakeEncoding(input, mask, true);
......
......@@ -85,6 +85,22 @@ public:
/* traing step number */
int nstep;
/* indicates whether we use adam */
bool useAdam;
/* hyper parameters of adam*/
float adamBeta1;
float adamBeta2;
float adamDelta;
float adamBeta1T;
float adamBeta2T;
/* list of the moment of the parameter matrics */
XList moments;
/* list of the 2nd order moment of the parameter matrics */
XList moments2nd;
public:
/* constructor */
T2TTrainer();
......@@ -98,11 +114,19 @@ public:
/* train the model */
void Train(const char * fn, T2TModel * model);
/* test the model */
void Test(const char * fn, const char * ofn, T2TModel * model);
/* load data to buffer */
int LoadBuf(FILE * file);
/* clear data buffer */
void ClearBuf();
/* load a batch of sequences */
int LoadBatch(FILE * file, XTensor * batch, XTensor * padding,
int LoadBatch(FILE * file, bool isLM,
XTensor * batch, XTensor * padding, XTensor * output,
int * seqs,
int step, int vs, int sBatch, int wBatch,
bool isSorted, int &wCount,
int devID, XMem * mem);
......
......@@ -40,20 +40,23 @@ int TransformerMain(int argc, const char ** argv)
char * trainFN = new char[MAX_LINE_LENGTH];
char * modelFN = new char[MAX_LINE_LENGTH];
char * testFN = new char[MAX_LINE_LENGTH];
char * outputFN = new char[MAX_LINE_LENGTH];
LoadParamString(argc, argv, "train", trainFN, "");
LoadParamString(argc, argv, "model", modelFN, "");
LoadParamString(argc, argv, "test", testFN, "");
LoadParamString(argc, argv, "output", outputFN, "");
T2TTrainer trainer;
trainer.Init(argc, argv);
T2TModel model;
model.InitModel(argc, argv);
/* learn model parameters */
if(strcmp(trainFN, "")){
T2TTrainer trainer;
trainer.Init(argc, argv);
if(strcmp(trainFN, ""))
trainer.Train(trainFN, &model);
}
/* save the final model */
if(strcmp(modelFN, "") && strcmp(trainFN, ""))
......@@ -63,9 +66,14 @@ int TransformerMain(int argc, const char ** argv)
if(strcmp(modelFN, ""))
model.Read(modelFN);
/* test the model on the new data */
if(strcmp(testFN, "") && strcmp(outputFN, ""))
trainer.Test(testFN, outputFN, &model);
delete[] trainFN;
delete[] modelFN;
delete[] testFN;
delete[] outputFN;
fclose(tmpFILE);
......
......@@ -147,6 +147,7 @@ extern bool useCUDA;
#define XPRINT4(VERBOSE,FILEH,STR,ARG,ARG2,ARG3,ARG4) {if(VERBOSE<=verboseLevel) {fprintf(FILEH,STR,ARG,ARG2,ARG3,ARG4);FFLUSH(FILEH);}}
#define XPRINT5(VERBOSE,FILEH,STR,ARG,ARG2,ARG3,ARG4,ARG5) {if(VERBOSE<=verboseLevel) {fprintf(FILEH,STR,ARG,ARG2,ARG3,ARG4,ARG5);FFLUSH(FILEH);}}
#define XPRINT6(VERBOSE,FILEH,STR,ARG,ARG2,ARG3,ARG4,ARG5,ARG6) {if(VERBOSE<=verboseLevel) {fprintf(FILEH,STR,ARG,ARG2,ARG3,ARG4,ARG5,ARG6);FFLUSH(FILEH);}}
#define XPRINT7(VERBOSE,FILEH,STR,ARG,ARG2,ARG3,ARG4,ARG5,ARG6,ARG7) {if(VERBOSE<=verboseLevel) {fprintf(FILEH,STR,ARG,ARG2,ARG3,ARG4,ARG5,ARG6,ARG7);FFLUSH(FILEH);}}
#define B2I(V) V==0?false:true
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论