Commit edd90176 by xiaotong

t2t code (with bugs)

parent 76c6523c
@@ -145,14 +145,21 @@ void XNet::Backward(XList &roots, XList &golds, LOSS_FUNCTION_NAME loss)
             lossGrad.Compute(gold, root, root->grad, loss);
         }
     }

     /* back-propagation from output to input */
     for(int i = nodes.count - 1; i >= 0; i--){
         XTensor * node = (XTensor*)nodes.Get(i);

         if(node->visitMark == NODE_FINISHED)
             continue;

+        //if(i == 1)
+        //    return;

         BackwardNode(node);

+        if(node->mem != NULL){
+            CheckNTErrors(node->mem->bufUsed == 0, "Illegal access of buffer!");
+        }
     }
 }
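Note: the CheckNTErrors line added above asserts that no scratch-buffer memory is still checked out of a node's memory pool once its backward step has finished; a step that leaked buffer space would corrupt later steps that reuse the same buffer. A minimal sketch of that invariant, with ScratchPool, Borrow, and Give as illustrative names rather than the real XMem API:

#include <cassert>
#include <cstddef>

struct ScratchPool {
    size_t bufUsed = 0;                    /* bytes currently borrowed */
    void Borrow(size_t n) { bufUsed += n; }
    void Give(size_t n)   { bufUsed -= n; }
};

void BackwardNodeSketch(ScratchPool &pool)
{
    pool.Borrow(1024);                     /* a kernel takes scratch space */
    pool.Give(1024);                       /* and must return it before exiting */
}

int main()
{
    ScratchPool pool;
    BackwardNodeSketch(pool);
    assert(pool.bufUsed == 0 && "Illegal access of buffer!");
    return 0;
}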
@@ -51,16 +51,19 @@ initialize the model
 void T2TModel::InitModel(int argc, const char ** argv)
 {
     bool useMem = false;
+    int memSize = 0;

     LoadParamInt(argc, argv, "dev", &devID, -1);
     LoadParamBool(argc, argv, "mem", &useMem, useMem);
+    LoadParamInt(argc, argv, "memsize", &memSize, 256);
     LoadParamBool(argc, argv, "lm", &isLM, true);
     LoadParamBool(argc, argv, "mt", &isMT, false);
     LoadParamInt(argc, argv, "nhead", &nhead, 8);

     if(useMem){
         delete mem;
-        mem = new XMem(devID, UNI_FREE, MILLION * 512, 1024, MILLION * 128);
+        mem = new XMem(devID, UNI_FREE, (MTYPE)MILLION * 128, 1024, MILLION * 128);
+        mem->SetDesiredSize(devID, 0, (MTYPE)memSize * MILLION);
     }

     encoder.InitModel(argc, argv, isLM, isLM ? 1 : 0, devID, mem);
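Note: the (MTYPE) casts matter here. Assuming MILLION is 1,000,000 and MTYPE is a 64-bit integer type, an uncast product such as memSize * MILLION is evaluated in 32-bit int arithmetic and overflows once memSize exceeds 2147 (about 2 GB); casting one operand first widens the whole expression. A self-contained illustration:

#include <cstdint>
#include <cstdio>

int main()
{
    const int MILLION = 1000000;
    int memSize = 4096;                        /* a 4 GB pool requested in MB */

    /* memSize * MILLION would be computed as int * int and overflow
       INT_MAX (undefined behavior). Widening one operand first, as the
       (MTYPE) cast does in InitModel, keeps the product in 64 bits: */
    int64_t bytes = (int64_t)memSize * MILLION;

    printf("%lld bytes\n", (long long)bytes);  /* prints 4096000000 */
    return 0;
}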
@@ -24,6 +24,7 @@
 #include "T2TUtility.h"
 #include "../../tensor/XUtility.h"
 #include "../../tensor/core/CHeader.h"
+#include "../../network/XNoder.h"

 namespace transformer
 {
@@ -87,11 +88,13 @@ void T2TTrainer::Train(const char * fn, T2TModel * model)
     float loss = 0;
     float lr = 0;

+    PrepareModel(model);

     int devID = model->devID;
-    XMem * mem = new XMem(devID, UNI_FREE, MILLION * 256, 1024, MILLION * 64);
+    XMem * mem = model->mem;

-    mem->SetPin();
-    model->mem->SetPin();
+    if(mem != NULL)
+        mem->SetPin();

     XNet net;
@@ -104,8 +107,8 @@ void T2TTrainer::Train(const char * fn, T2TModel * model)
         wordCount = 0;

-        model->mem->BackToPin();
-        mem->BackToPin();
+        if(mem != NULL)
+            mem->BackToPin();

         /* batch of input sequences */
         XTensor batch;
@@ -153,8 +156,8 @@ void T2TTrainer::Train(const char * fn, T2TModel * model)
                 lr, elapsed, step, epoch + 1, wordCountTotal, exp(loss / wordCount));
         }

-        model->mem->BackToPin();
-        mem->BackToPin();
+        if(mem != NULL)
+            mem->BackToPin();
     }

     fclose(file);
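Note: the SetPin()/BackToPin() pairs above, now guarded against a NULL pool and pointed at the model's own pool instead of a second one, suggest a pin-and-rewind allocation scheme: pin the pool once after the long-lived tensors exist, then rewind to the pin at the end of every batch so per-step tensors are recycled without real frees. A minimal arena illustrating that pattern (Arena and Alloc are illustrative names, not XMem's interface):

#include <cassert>
#include <cstddef>
#include <vector>

struct Arena {
    std::vector<char> pool;
    size_t used = 0, pin = 0;

    explicit Arena(size_t bytes) : pool(bytes) {}
    void * Alloc(size_t n)
    {
        assert(used + n <= pool.size());
        void * p = pool.data() + used;
        used += n;
        return p;
    }
    void SetPin()    { pin = used; }   /* remember the high-water mark */
    void BackToPin() { used = pin; }   /* drop everything allocated since */
};

int main()
{
    Arena mem(1 << 20);
    mem.Alloc(4096);                   /* long-lived model parameters */
    mem.SetPin();
    for(int step = 0; step < 3; step++){
        mem.Alloc(512);                /* per-step activations and gradients */
        mem.BackToPin();               /* recycled at the end of the step */
    }
    assert(mem.used == 4096);
    return 0;
}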
@@ -428,6 +431,38 @@ void T2TTrainer::Update(T2TModel * model, const float lr)
 }

+/*
+prepare model for training
+>> model - the model for training
+*/
+void T2TTrainer::PrepareModel(T2TModel * model)
+{
+    XList ws(100);
+
+    ws.Add(&model->outputLayer.w);
+
+    for(int i = 0; i < model->encoder.nlayer; i++){
+        ws.Add(&model->encoder.fnns[i].w1);
+        ws.Add(&model->encoder.fnns[i].b1);
+        ws.Add(&model->encoder.fnns[i].w2);
+        ws.Add(&model->encoder.fnns[i].b2);
+        ws.Add(&model->encoder.attentions[i].wk);
+        ws.Add(&model->encoder.attentions[i].wq);
+        ws.Add(&model->encoder.attentions[i].wv);
+        ws.Add(&model->encoder.fnnLayerNorms[i].w);
+        ws.Add(&model->encoder.fnnLayerNorms[i].b);
+        ws.Add(&model->encoder.attLayerNorms[i].w);
+        ws.Add(&model->encoder.attLayerNorms[i].b);
+    }
+
+    ws.Add(&model->encoder.embedder.w);
+
+    for(int i = 0; i < ws.count; i++){
+        XTensor * para = (XTensor*)ws.Get(i);
+        XNoder::MakeGrad(para);
+    }
+}
+
 /*
 do padding on the output
 >> output - output tensor of the network
 >> padding - padding of a batch of sentences
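Note: PrepareModel collects every trainable tensor of the model (output layer, each encoder layer's FNN, attention, and layer-norm weights, and the embedding matrix) and asks XNoder::MakeGrad to attach a gradient to each, so XNet::Backward has storage to accumulate into. A guess at the shape of such a helper; the real one lives behind the new XNoder.h include, and this sketch is not the library's code:

struct Tensor {
    int order = 0;                 /* number of dimensions */
    int dimSize[8] = {0};          /* size of each dimension */
    Tensor * grad = nullptr;       /* gradient tensor of the same shape */
};

void MakeGradSketch(Tensor * para)
{
    if(para->grad == nullptr){
        para->grad = new Tensor(*para);   /* same shape as the parameter */
        para->grad->grad = nullptr;       /* gradients get no gradients */
    }
}

int main()
{
    Tensor w;
    w.order = 2; w.dimSize[0] = 512; w.dimSize[1] = 512;
    MakeGradSketch(&w);            /* w.grad now exists with w's shape */
    return 0;
}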
@@ -110,6 +110,9 @@ public:
     /* update the model by delta rule */
     void Update(T2TModel * model, const float lr);

+    /* prepare model for training */
+    void PrepareModel(T2TModel * model);
+
     /* do padding on the output */
     void PadOutput(XTensor * output, XTensor * padding);
 };