Commit 5f345e87 by xiaotong

bug fixes

parent e6c92495
@@ -148,7 +148,6 @@ get a batch of samples
 */
 bool TTDataLoader::GetBatchSimple(XList * inputs, XList * golds)
 {
-    fprintf(stderr, "get batch 0\n");
     CheckNTErrors(file != NULL, "No input file specificed!");
     CheckNTErrors(inputs != NULL && inputs->count >= 1, "Wrong argument!");
     CheckNTErrors(golds != NULL && golds->count >= 1, "Wrong argument!");
@@ -184,16 +183,14 @@ bool TTDataLoader::GetBatchSimple(XList * inputs, XList * golds)
         InitTensor2D(input, count, 3, X_INT);
         InitTensor2D(gold, count, 1, X_INT);
-        input->SetData(input, count * 3);
-        gold->SetData(gold, count);
+        input->SetData(inputBatch, count * 3);
+        gold->SetData(goldBatch, count);
     }

     delete[] line;
     delete[] inputBatch;
     delete[] goldBatch;

-    fprintf(stderr, "get batch 1\n");
-
     if (count > 0)
         return true;
     else
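Note on the SetData change above: the old calls passed the tensor itself as the data source, so the freshly initialized tensors were filled from their own memory; the fix points SetData at the host-side staging buffers inputBatch and goldBatch that the loader filled while parsing the file. Below is a standalone sketch of the intended copy and the row-major layout it assumes (plain C++, not NiuTensor code; the buffer names mirror the diff, the sizes and sample values are made up).

// Sketch of the staging buffers that SetData now copies from, assuming the
// row-major layout implied by InitTensor2D(count, 3) and InitTensor2D(count, 1).
#include <cstdio>
#include <cstring>
#include <vector>

int main()
{
    const int count = 4;                       /* samples in this batch */
    std::vector<int> inputBatch(count * 3);    /* three ints per sample */
    std::vector<int> goldBatch(count);         /* one label per sample  */

    for (int i = 0; i < count; i++) {
        inputBatch[i * 3 + 0] = i;             /* fake sample: (i, i+1, i+2) -> i+3 */
        inputBatch[i * 3 + 1] = i + 1;
        inputBatch[i * 3 + 2] = i + 2;
        goldBatch[i] = i + 3;
    }

    /* what input->SetData(inputBatch, count * 3) conceptually does:
       copy count * 3 ints into the [count, 3] tensor, row by row */
    std::vector<int> tensorData(count * 3);
    std::memcpy(tensorData.data(), inputBatch.data(), count * 3 * sizeof(int));

    /* element (i, j) of the [count, 3] tensor in row-major order */
    int i = 1, j = 2;
    printf("input[%d][%d] = %d, gold[%d] = %d\n", i, j, tensorData[i * 3 + j], i, goldBatch[i]);
    return 0;
}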
@@ -225,15 +222,17 @@ void TTModel::Init(XConfig &myConfig, int devID)
 {
     SetConfig(myConfig);

-    int vSize = MAX_INT_IN_TTRAIN + 1;
-    int eSize = config.GetInt("esize", TT_EMBEDDING_SIZE);
-    int hSize = config.GetInt("hsize", TT_HIDDEN_SIZE);
+    vSize = MAX_INT_IN_TTRAIN + 1;
+    eSize = config.GetInt("esize", TT_EMBEDDING_SIZE);
+    hSize = config.GetInt("hsize", TT_HIDDEN_SIZE);

     InitTensor2D(&embeddingW, vSize, eSize, X_FLOAT, devID);
     InitTensor2D(&hiddenW, 3 * eSize, hSize, X_FLOAT, devID);
+    InitTensor2D(&outputW, hSize, vSize, X_FLOAT, devID);

     embeddingW.SetDataRand(-0.1F, 0.1F);
     hiddenW.SetDataRand(-0.1F, 0.1F);
+    outputW.SetDataRand(-0.1F, 0.1F);
 }

 /* create the model */
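Note on TTModel::Init: besides promoting vSize/eSize/hSize to class members, the change adds a separate output projection outputW. In the old forward pass the softmax was applied directly to the hidden state, so the output dimension was hSize rather than the vocabulary size; outputW of shape [hSize, vSize] closes that gap. A minimal shape sketch follows (plain C++; the concrete numbers are placeholders, since the values of TT_EMBEDDING_SIZE and TT_HIDDEN_SIZE are not visible in this diff).

// Shape sketch only: how the three parameter matrices line up after this change.
#include <cstdio>

int main()
{
    const int vSize = 512;   /* assumed stand-in for MAX_INT_IN_TTRAIN + 1 */
    const int eSize = 128;   /* assumed stand-in for the "esize" setting   */
    const int hSize = 512;   /* assumed stand-in for the "hsize" setting   */

    printf("embeddingW : [%d, %d]\n", vSize, eSize);      /* token id -> embedding         */
    printf("hiddenW    : [%d, %d]\n", 3 * eSize, hSize);  /* 3 merged embeddings -> hidden */
    printf("outputW    : [%d, %d]\n", hSize, vSize);      /* hidden -> vocabulary scores   */

    long params = (long)vSize * eSize + 3L * eSize * hSize + (long)hSize * vSize;
    printf("total parameters: %ld\n", params);
    return 0;
}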
@@ -243,21 +242,17 @@ void TTModel::Forward(int devID, XTensor * input, XTensor * output)
     XTensor embeddingCat;
     XTensor hidden;

-    fprintf(stderr, "forward 0\n");
-
     /* [e_0, e_1, e_2] = w_e * input(one-hot) */
     embedding = Gather(embeddingW, *input);

     /* e = merge(e_0, e_1, e_2) */
-    embeddingCat = Merge(embedding, 0, 1);
+    embeddingCat = Merge(embedding, embedding.order - 1, embedding.order - 2);

-    /* h = e * w_h */
-    hidden = MMul(embeddingCat, hiddenW);
+    /* h = hardtanh(e * w_h) */
+    hidden = HardTanH(MMul(embeddingCat, hiddenW));

-    /* output = Softmax(h) */
-    *output = Softmax(hidden, 0);
-
-    fprintf(stderr, "forward 1\n");
+    /* output = Softmax(h * w_o) */
+    *output = Softmax(MMul(hidden, outputW), -1);
 }

 /* clear the model */
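Note on TTModel::Forward: the rewritten pass is gather, then a merge of the last two dimensions, then hardtanh(e * w_h), then a softmax over the last axis of h * w_o. The sketch below walks the same computation over flat arrays so the shape changes are easy to follow: [batch, 3] ids -> [batch, 3, eSize] -> [batch, 3*eSize] -> [batch, hSize] -> [batch, vSize]. It is plain C++ for illustration only; the gather/merge memory layout and the tiny sizes are assumptions, not taken from NiuTensor internals.

#include <cmath>
#include <cstdio>
#include <vector>

/* y[r, j] = sum_i x[r, i] * w[i, j] for row-major matrices */
static void matmul(const std::vector<float>& x, const std::vector<float>& w,
                   std::vector<float>& y, int rows, int k, int n)
{
    y.assign((size_t)rows * n, 0.0f);
    for (int r = 0; r < rows; r++)
        for (int i = 0; i < k; i++)
            for (int j = 0; j < n; j++)
                y[r * n + j] += x[r * k + i] * w[i * n + j];
}

int main()
{
    const int batch = 2, ctx = 3, vSize = 8, eSize = 4, hSize = 5;

    std::vector<int>   input = {1, 2, 3, 4, 5, 6};                    /* [batch, 3]       */
    std::vector<float> embeddingW((size_t)vSize * eSize, 0.01f);      /* [vSize, eSize]   */
    std::vector<float> hiddenW((size_t)ctx * eSize * hSize, 0.01f);   /* [3*eSize, hSize] */
    std::vector<float> outputW((size_t)hSize * vSize, 0.01f);         /* [hSize, vSize]   */

    /* gather + merge: each sample's 3 embeddings laid out back to back -> [batch, 3*eSize] */
    std::vector<float> embeddingCat((size_t)batch * ctx * eSize);
    for (int b = 0; b < batch; b++)
        for (int t = 0; t < ctx; t++)
            for (int e = 0; e < eSize; e++)
                embeddingCat[(b * ctx + t) * eSize + e] = embeddingW[input[b * ctx + t] * eSize + e];

    /* h = hardtanh(e * w_h): clip to [-1, 1] */
    std::vector<float> hidden;
    matmul(embeddingCat, hiddenW, hidden, batch, ctx * eSize, hSize);
    for (float& v : hidden) v = std::fmax(-1.0f, std::fmin(1.0f, v));

    /* output = softmax(h * w_o) over the last (vocabulary) dimension */
    std::vector<float> out;
    matmul(hidden, outputW, out, batch, hSize, vSize);
    for (int b = 0; b < batch; b++) {
        float sum = 0.0f;
        for (int v = 0; v < vSize; v++) sum += std::exp(out[b * vSize + v]);
        for (int v = 0; v < vSize; v++) out[b * vSize + v] = std::exp(out[b * vSize + v]) / sum;
    }

    printf("output shape: [%d, %d], each row sums to ~1\n", batch, vSize);
    return 0;
}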
@@ -292,15 +287,26 @@ bool TTModel::RunSimple(XList * inputs, XList * outputs, XList * golds)
     XTensor * output = (XTensor*)outputs->GetItem(0);
     XTensor * gold = (XTensor*)golds->GetItem(0);
     XTensor loss;
+    XTensor goldOneHot;

     XNet net;

     Forward(devID, input, output);

-    loss = CrossEntropy(output, gold);
+    goldOneHot = IndexToOnehot(*gold, vSize, 0.0F);
+
+    int* dims = new int[goldOneHot.order];
+    for (int i = 0; i < goldOneHot.order - 2; i++)
+        dims[i] = goldOneHot.GetDim(i);
+    dims[goldOneHot.order - 2] = goldOneHot.GetDim(goldOneHot.order - 1);
+    goldOneHot.Reshape(goldOneHot.order - 1, dims);
+
+    loss = CrossEntropy(output, goldOneHot);

     net.Backward(loss);

+    delete[] dims;
+
     return true;
 }
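Note on TTModel::RunSimple: CrossEntropy now compares the output against a one-hot version of the gold indices. Since gold is built as a [batch, 1] index tensor, its one-hot form presumably has shape [batch, 1, vSize]; the dims loop drops the singleton middle axis so the gold matches the output's [batch, vSize]. A standalone sketch of just that index bookkeeping follows (plain C++; the shapes are inferred from the diff, not verified against the library).

#include <cstdio>
#include <vector>

int main()
{
    const int batch = 4, vSize = 8;
    std::vector<int> oldDims = {batch, 1, vSize};   /* goldOneHot before Reshape */
    int order = (int)oldDims.size();

    /* same index arithmetic as the diff: keep dims 0..order-3, then replace the
       second-to-last dimension with the last one (the vocabulary axis) */
    std::vector<int> dims(order - 1);
    for (int i = 0; i < order - 2; i++)
        dims[i] = oldDims[i];
    dims[order - 2] = oldDims[order - 1];

    printf("reshaped to order %d: [%d, %d]\n", order - 1, dims[0], dims[1]);  /* [4, 8] */
    return 0;
}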
......
@@ -111,6 +111,18 @@ protected:
     /* parameter matrix of the hidden layer */
     XTensor hiddenW;

+    /* parameter matrix of the output layer */
+    XTensor outputW;
+
+    /* vocabulary size */
+    int vSize;
+
+    /* embedding size */
+    int eSize;
+
+    /* hidden layer size */
+    int hSize;
+
 public:
     /* constructor */
     TTModel();
......
@@ -206,6 +206,7 @@ bool XLeader::Run(XConfig * config, DataDistributeBase * dataDistributor,
                   XModel * model, XOptimizer * optimizer)
 {
     bool isDataOK = true;
+    int activeJobCount = 0;

     /* Feed the input to each worker and geneate the output.
        For each worker, we define a job queue and enqueue jobs
@@ -218,18 +219,21 @@ bool XLeader::Run(XConfig * config, DataDistributeBase * dataDistributor,
         /* get a batch of samples */
         bool fetched = dataDistributor->GetBatchSimple(worker->GetInput(), worker->GetGold());

+        if (!fetched)
+            isDataOK = false;
+        else {
             /* job in queue 1: refresh the model */
             worker->AddJobRefresh(jmodel);

             /* job in queue 1: run the model */
             worker->AddJobNeuralNet(jmodel, worker->GetInput(), worker->GetOutput(), worker->GetGold());
-            /* clear it */
-            worker->Clear();
-
-            if (!fetched)
-                isDataOK = false;
+            activeJobCount++;
+        }
     }

+    if (activeJobCount == 0)
+        return false;
+
     XList members(jworkers.count);

     for (int i = 0; i < jworkers.count; i++) {
@@ -266,6 +270,11 @@ bool XLeader::Run(XConfig * config, DataDistributeBase * dataDistributor,
     WaitForFinishing();

+    for (int i = 0; i < jworkers.count; i++) {
+        XWorkerJob * worker = (XWorkerJob*)jworkers[i];
+        worker->Clear();
+    }
+
     return isDataOK;
 }
......
@@ -124,6 +124,9 @@ void XTrainer::Run(XConfig * config, DataDistributeBase * dataDistributor,
             /* one step of udpate */
             ok = leader.Run(config, dataDistributor, model, optimizer);

+            if ((step + 1) % 100 == 0)
+                fprintf(stderr, "epoch:%d step:%d\n", epoch + 1, step + 1);
+
             if (step++ >= nstep)
                 break;
         }
@@ -135,6 +138,8 @@ void XTrainer::Run(XConfig * config, DataDistributeBase * dataDistributor,
         }

         delete[] ids;
+
+        fprintf(stderr, "epoch:%d step:%d\n", epoch, step);
     }
 }

 } /* end of the nts (NiuTrans.Tensor) namespace */
@@ -105,7 +105,6 @@ add a new job of model refreshment
 */
 bool XWorkerJob::AddJobRefresh(XModel * myModel)
 {
-    fprintf(stderr, "refresh 0\n");
     CheckNTErrors(myModel != NULL, "no parameter keeper!");

     XList args(1);
@@ -113,8 +112,6 @@ bool XWorkerJob::AddJobRefresh(XModel * myModel)
     queue.EnqueueJob((void*)(char*)XModel::Refresh, &args);

-    fprintf(stderr, "refresh 1\n");
-
     return true;
 }
......