Commit edd90176 by xiaotong

t2t code (with bugs)

parent 76c6523c
@@ -152,7 +152,14 @@ void XNet::Backward(XList &roots, XList &golds, LOSS_FUNCTION_NAME loss)
         if(node->visitMark == NODE_FINISHED)
             continue;
 
+        //if(i == 1)
+        //    return;
+
         BackwardNode(node);
+
+        if(node->mem != NULL){
+            CheckNTErrors(node->mem->bufUsed == 0, "Illegal access of buffer!");
+        }
     }
 }
......
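Note on the new check in XNet::Backward: bufUsed tracks how much of the pool's temporary buffer is currently handed out, so asserting that it is zero after BackwardNode() catches any backward kernel that takes buffer space and never returns it. A minimal sketch of the allocation pattern being policed, assuming XMem's usual AllocBuf/ReleaseBuf pairing (the function name and size below are illustrative, not library code):

    /* hypothetical backward step drawing scratch space from the pool */
    void BackwardNodeSketch(XTensor * node)
    {
        XMem * mem = node->mem;
        MTYPE size = node->unitNum * node->unitSize;

        /* scratch memory taken from the pool's temporary buffer */
        void * buf = mem->AllocBuf(mem->devID, size);

        /* ... gradient computation that uses buf ... */

        /* the matching release; forgetting it leaves bufUsed > 0 and
           trips the "Illegal access of buffer!" assertion above */
        mem->ReleaseBuf(mem->devID, size);
    }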
@@ -51,16 +51,19 @@ initialize the model
 void T2TModel::InitModel(int argc, const char ** argv)
 {
     bool useMem = false;
+    int memSize = 0;
 
     LoadParamInt(argc, argv, "dev", &devID, -1);
     LoadParamBool(argc, argv, "mem", &useMem, useMem);
+    LoadParamInt(argc, argv, "memsize", &memSize, 256);
     LoadParamBool(argc, argv, "lm", &isLM, true);
     LoadParamBool(argc, argv, "mt", &isMT, false);
     LoadParamInt(argc, argv, "nhead", &nhead, 8);
 
     if(useMem){
         delete mem;
-        mem = new XMem(devID, UNI_FREE, MILLION * 512, 1024, MILLION * 128);
+        mem = new XMem(devID, UNI_FREE, (MTYPE)MILLION * 128, 1024, MILLION * 128);
+        mem->SetDesiredSize(devID, 0, (MTYPE)memSize * MILLION);
     }
 
     encoder.InitModel(argc, argv, isLM, isLM ? 1 : 0, devID, mem);
......
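Note on the new pool sizing in InitModel: the initial allocation drops from MILLION * 512 to MILLION * 128 bytes, and the pool is instead given a target size through SetDesiredSize, taken from the new "memsize" option (read as megabytes, default 256). An annotated sketch of the resulting setup, with parameter meanings assumed from the XMem constructor (block size, block count, buffer size):

    int memSize = 0;
    LoadParamInt(argc, argv, "memsize", &memSize, 256);  /* pool target, in MB */

    if(useMem){
        delete mem;
        mem = new XMem(devID, UNI_FREE,
                       (MTYPE)MILLION * 128,   /* size of each memory block */
                       1024,                   /* maximum number of blocks */
                       MILLION * 128);         /* size of the temporary buffer */
        mem->SetDesiredSize(devID, 0, (MTYPE)memSize * MILLION);
    }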
@@ -24,6 +24,7 @@
 #include "T2TUtility.h"
 #include "../../tensor/XUtility.h"
 #include "../../tensor/core/CHeader.h"
+#include "../../network/XNoder.h"
 
 namespace transformer
 {
@@ -87,11 +88,13 @@ void T2TTrainer::Train(const char * fn, T2TModel * model)
     float loss = 0;
     float lr = 0;
 
+    PrepareModel(model);
+
     int devID = model->devID;
-    XMem * mem = new XMem(devID, UNI_FREE, MILLION * 256, 1024, MILLION * 64);
+    XMem * mem = model->mem;
 
-    mem->SetPin();
-    model->mem->SetPin();
+    if(mem != NULL)
+        mem->SetPin();
 
     XNet net;
@@ -104,7 +107,7 @@ void T2TTrainer::Train(const char * fn, T2TModel * model)
         wordCount = 0;
 
-        model->mem->BackToPin();
-        mem->BackToPin();
+        if(mem != NULL)
+            mem->BackToPin();
 
         /* batch of input sequences */
@@ -153,7 +156,7 @@ void T2TTrainer::Train(const char * fn, T2TModel * model)
                    lr, elapsed, step, epoch + 1, wordCountTotal, exp(loss / wordCount));
         }
 
-        model->mem->BackToPin();
-        mem->BackToPin();
+        if(mem != NULL)
+            mem->BackToPin();
     }
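Note on the pin calls in Train: the trainer no longer builds a second pool of its own; it borrows model->mem, which is NULL when "-mem" is off, hence the new guards. SetPin/BackToPin appear to act as a stack mark on the pool: SetPin records the current position once the long-lived tensors exist, and BackToPin frees everything allocated after that mark in one step. A sketch of the per-iteration pattern (the loop bound maxStep is illustrative):

    XMem * mem = model->mem;   /* shared with the model; NULL without "-mem" */

    if(mem != NULL)
        mem->SetPin();         /* mark: parameters and gradients live below this */

    for(int step = 0; step < maxStep; step++){
        /* the forward/backward pass allocates transient tensors here */

        if(mem != NULL)
            mem->BackToPin();  /* roll back to the mark, reclaiming every
                                  per-step tensor in one cheap operation */
    }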
@@ -428,6 +431,38 @@ void T2TTrainer::Update(T2TModel * model, const float lr)
 }
 
 /*
+prepare model for training
+>> model - the model for training
+*/
+void T2TTrainer::PrepareModel(T2TModel * model)
+{
+    XList ws(100);
+
+    ws.Add(&model->outputLayer.w);
+
+    for(int i = 0; i < model->encoder.nlayer; i++){
+        ws.Add(&model->encoder.fnns[i].w1);
+        ws.Add(&model->encoder.fnns[i].b1);
+        ws.Add(&model->encoder.fnns[i].w2);
+        ws.Add(&model->encoder.fnns[i].b2);
+        ws.Add(&model->encoder.attentions[i].wk);
+        ws.Add(&model->encoder.attentions[i].wq);
+        ws.Add(&model->encoder.attentions[i].wv);
+        ws.Add(&model->encoder.fnnLayerNorms[i].w);
+        ws.Add(&model->encoder.fnnLayerNorms[i].b);
+        ws.Add(&model->encoder.attLayerNorms[i].w);
+        ws.Add(&model->encoder.attLayerNorms[i].b);
+    }
+
+    ws.Add(&model->encoder.embedder.w);
+
+    for(int i = 0; i < ws.count; i++){
+        XTensor * para = (XTensor*)ws.Get(i);
+        XNoder::MakeGrad(para);
+    }
+}
+
+/*
 do padding on the output
 >> output - output tensor of the network
 >> padding - padding of a batch of sentences
......
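Note on PrepareModel: it gathers every trainable tensor of the model (output projection, per-layer FNN weights and biases, attention projections, layer-norm parameters, embeddings) and calls XNoder::MakeGrad on each, which is why the new XNoder.h include appears above. Allocating gradients up front, before Train() pins the pool, means they sit below the pin and survive each BackToPin(). Only encoder and output-layer weights are listed, which matches the language-model configuration that the defaults above select. A rough sketch of what MakeGrad is expected to do per node (a paraphrase, not the XNoder source):

    /* give a node gradient storage of the same shape once */
    void MakeGradSketch(XTensor * node)
    {
        if(node == NULL)
            return;

        if(node->grad == NULL){
            node->grad = NewTensor(node);  /* same shape, device, data type */
            node->grad->SetZeroAll();      /* start from zero gradients */
        }
    }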
@@ -110,6 +110,9 @@ public:
     /* update the model by delta rule */
     void Update(T2TModel * model, const float lr);
 
+    /* prepare model for training */
+    void PrepareModel(T2TModel * model);
+
     /* do padding on the output */
     void PadOutput(XTensor * output, XTensor * padding);
 };
......
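For orientation, the call sequence from the driver side now looks roughly like this (Init and the variable trainFN are assumptions based on the surrounding sample code, not shown in this commit):

    T2TModel model;
    model.InitModel(argc, argv);     /* builds the shared XMem pool if "-mem" is set */

    T2TTrainer trainer;
    trainer.Init(argc, argv);        /* assumed trainer setup, mirroring InitModel */
    trainer.Train(trainFN, &model);  /* calls PrepareModel(), then the epoch loop */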