Commit 87bb27ee by xiaotong

updates

parent b69e10f6
@@ -93,6 +93,7 @@ void XFuncGrad::MakeGrad(XTensor * node, bool isEfficient)
     }
 
     node->visitMark = NODE_FINISHED;
+    node->isGradFinished = true;
 }
 
 /* indicates whether the node is for an activation function */
...
@@ -89,6 +89,7 @@ void XLossGrad::MakeGrad(XTensor * node, bool isEfficient)
     }
 
     node->visitMark = NODE_FINISHED;
+    node->isGradFinished = true;
 }
 
 /* indicates whether the node is for a loss computation */
...
@@ -125,6 +125,9 @@ void XMathGrad::MakeGrad(XTensor * node, bool isEfficient)
     else{
         ShowNTErrors("Unsupported backward computation! TODO!");
     }
+
+    node->visitMark = NODE_FINISHED;
+    node->isGradFinished = true;
 }
 
 /* indicates whether the node is for a math operation */
@@ -162,8 +165,6 @@ void XMathGrad::GradAbsolute(XTensor * node, bool isEfficient)
         DelTensorBuf(tmp);
     }
-
-    node->visitMark = NODE_FINISHED;
 }
 
 /*
@@ -194,8 +195,6 @@ void XMathGrad::GradCos(XTensor * node, bool isEfficient)
         DelTensorBuf(tmp);
     }
-
-    node->visitMark = NODE_FINISHED;
 }
 
 /*
@@ -225,8 +224,6 @@ void XMathGrad::GradExp(XTensor * node, bool isEfficient)
         DelTensorBuf(tmp);
     }
-
-    node->visitMark = NODE_FINISHED;
 }
 
 /*
@@ -251,8 +248,6 @@ void XMathGrad::GradLog(XTensor * node, bool isEfficient)
         XNoder::MakeGrad(a);
         _Div(node->grad, a, a->grad, 1.0F);
     }
-
-    node->visitMark = NODE_FINISHED;
 }
 
 /*
@@ -276,8 +271,6 @@ void XMathGrad::GradRound(XTensor * node, bool isEfficient)
     if (!isEfficient || a->isGrad) {
         XNoder::MakeGrad(a);
     }
-
-    node->visitMark = NODE_FINISHED;
 }
 
 /*
@@ -301,8 +294,6 @@ void XMathGrad::GradSign(XTensor * node, bool isEfficient)
     if (!isEfficient || a->isGrad) {
         XNoder::MakeGrad(a);
     }
-
-    node->visitMark = NODE_FINISHED;
 }
 
 /*
@@ -332,8 +323,6 @@ void XMathGrad::GradSin(XTensor * node, bool isEfficient)
         DelTensorBuf(tmp);
     }
-
-    node->visitMark = NODE_FINISHED;
 }
 
 /*
@@ -364,8 +353,6 @@ void XMathGrad::GradTan(XTensor * node, bool isEfficient)
         DelTensorBuf(tmp);
     }
-
-    node->visitMark = NODE_FINISHED;
 }
 
 /*
@@ -398,8 +385,6 @@ void XMathGrad::GradClip(XTensor * node, bool isEfficient)
         DelTensorBuf(tmp);
     }
-
-    node->visitMark = NODE_FINISHED;
 }
 
 /*
@@ -440,8 +425,6 @@ void XMathGrad::GradDiv(XTensor * node, bool isEfficient)
         DelTensorBuf(tmp);
     }
-
-    node->visitMark = NODE_FINISHED;
 }
 
 /*
@@ -539,8 +522,6 @@ void XMathGrad::GradDivDim(XTensor * node, bool isEfficient)
         DelTensorBuf(aTMP2);
         DelTensorBuf(aTMP1);
     }
-
-    node->visitMark = NODE_FINISHED;
 }
 
 /*
@@ -602,8 +583,6 @@ void XMathGrad::GradMatrixMul(XTensor * node, bool isEfficient)
     else{
         ShowNTErrors("TODO!");
     }
-
-    node->visitMark = NODE_FINISHED;
 }
 
 /*
@@ -757,8 +736,6 @@ void XMathGrad::GradMatrixMulBatched(XTensor * node, bool isEfficient)
         if (!isEfficient || b->isGrad)
             _MatrixMulBatched(dedc, X_TRANS, a, X_TRANS, dedb, alpha, 1.0F);
     }
-
-    node->visitMark = NODE_FINISHED;
 }
 
 /*
@@ -793,8 +770,6 @@ void XMathGrad::GradMultiply(XTensor * node, bool isEfficient)
         XNoder::MakeGrad(b);
         _Multiply(node->grad, a, b->grad, 1.0F);
    }
-
-    node->visitMark = NODE_FINISHED;
 }
 
 /*
@@ -879,8 +854,6 @@ void XMathGrad::GradMultiplyDim(XTensor * node, bool isEfficient)
         }
 
         DelTensorBuf(bGradTMP);
     }
-
-    node->visitMark = NODE_FINISHED;
 }
 
 /*
@@ -916,8 +889,6 @@ void XMathGrad::GradMultiplyBroadcast(XTensor * node, bool isEfficient)
         if (b->isVar || b->income.tailNum > 0)
             ShowNTErrors("TODO");
     }
-
-    node->visitMark = NODE_FINISHED;
 }
 
 /*
@@ -942,8 +913,6 @@ void XMathGrad::GradNegate(XTensor * node, bool isEfficient)
         XNoder::MakeGrad(a);
         _Sum(a->grad, node->grad, a->grad, -1.0F);
     }
-
-    node->visitMark = NODE_FINISHED;
 }
 
 /*
@@ -987,8 +956,6 @@ void XMathGrad::GradPower(XTensor * node, bool isEfficient)
         DelTensorBuf(tmp);
     }
-
-    node->visitMark = NODE_FINISHED;
 }
@@ -1019,8 +986,6 @@ void XMathGrad::GradReciprocal(XTensor* node, bool isEfficient)
         DelTensorBuf(tmp);
     }
-
-    node->visitMark = NODE_FINISHED;
 }
 
 /*
@@ -1049,8 +1014,6 @@ void XMathGrad::GradSqrt(XTensor * node, bool isEfficient)
         DelTensorBuf(tmp);
     }
-
-    node->visitMark = NODE_FINISHED;
 }
 
 /*
@@ -1080,8 +1043,6 @@ void XMathGrad::GradSquare(XTensor * node, bool isEfficient)
         DelTensorBuf(tmp);
     }
-
-    node->visitMark = NODE_FINISHED;
 }
 
 /*
@@ -1109,8 +1070,6 @@ void XMathGrad::GradScaleAndShift(XTensor * node, bool isEfficient)
         _Sum(a->grad, node->grad, a->grad, scale);
     }
-
-    node->visitMark = NODE_FINISHED;
 }
 
 /*
@@ -1138,8 +1097,6 @@ void XMathGrad::GradScale(XTensor * node, bool isEfficient)
         _Sum(a->grad, node->grad, a->grad, scale);
     }
-
-    node->visitMark = NODE_FINISHED;
 }
 
 /*
@@ -1166,9 +1123,7 @@ void XMathGrad::GradDescale(XTensor * node, bool isEfficient)
         XNoder::MakeGrad(a);
         _Sum(a->grad, node->grad, a->grad, 1 / descale);
     }
-
-    node->visitMark = NODE_FINISHED;
 }
 
 /*
@@ -1194,8 +1149,6 @@ void XMathGrad::GradShift(XTensor * node, bool isEfficient)
         _Sum(a->grad, node->grad, a->grad);
     }
-
-    node->visitMark = NODE_FINISHED;
 }
 
 /*
@@ -1229,8 +1182,6 @@ void XMathGrad::GradSub(XTensor * node, bool isEfficient)
         XNoder::MakeGrad(b);
         _Sum(b->grad, node->grad, b->grad, -beta);
     }
-
-    node->visitMark = NODE_FINISHED;
 }
 
 /*
@@ -1317,8 +1268,6 @@ void XMathGrad::GradSubDim(XTensor * node, bool isEfficient)
             DelTensorBuf(interGrad);
         }
     }
-
-    node->visitMark = NODE_FINISHED;
 }
 
 /*
@@ -1352,8 +1301,6 @@ void XMathGrad::GradSum(XTensor * node, bool isEfficient)
         XNoder::MakeGrad(b);
         _Sum(b->grad, node->grad, b->grad, beta);
     }
-
-    node->visitMark = NODE_FINISHED;
 }
 
 /*
@@ -1441,8 +1388,6 @@ void XMathGrad::GradSumDim(XTensor * node, bool isEfficient)
             DelTensorBuf(interGrad);
         }
     }
-
-    node->visitMark = NODE_FINISHED;
 }
 
 /*
@@ -1480,8 +1425,6 @@ void XMathGrad::GradSumBroadcast(XTensor * node, bool isEfficient)
            ShowNTErrors("TODO");
         }
     }
-
-    node->visitMark = NODE_FINISHED;
 }
 
 /*
@@ -1516,8 +1459,6 @@ void XMathGrad::GradReduceMean(XTensor * node, bool isEfficient)
         DelTensorBuf(tmp);
     }
-
-    node->visitMark = NODE_FINISHED;
 }
 
 /*
@@ -1550,8 +1491,6 @@ void XMathGrad::GradReduceSum(XTensor * node, bool isEfficient)
         _Sum(a->grad, tmp, a->grad);
         DelTensorBuf(tmp);
     }
-
-    node->visitMark = NODE_FINISHED;
 }
 
 /*
@@ -1582,8 +1521,6 @@ void XMathGrad::GradReduceSumAll(XTensor * node, bool isEfficient)
        _Sum(a->grad, tmp, a->grad);
        DelTensorBuf(tmp);
     }
-
-    node->visitMark = NODE_FINISHED;
 }
 
 /*
@@ -1639,8 +1576,6 @@ void XMathGrad::GradReduceSumSquared(XTensor * node, bool isEfficient)
     DelTensorBuf(e);
     DelTensorBuf(d);
     DelTensorBuf(c);
-
-    node->visitMark = NODE_FINISHED;
 }
 
 /*
@@ -1696,8 +1631,6 @@ void XMathGrad::GradReduceVariance(XTensor * node, bool isEfficient)
     DelTensorBuf(e);
     DelTensorBuf(d);
     DelTensorBuf(c);
-
-    node->visitMark = NODE_FINISHED;
 }
 
 /*
@@ -1815,9 +1748,6 @@ void XMathGrad::GradMulAndShift(XTensor * node, bool isEfficient)
         dedx->Reshape(orderBackupX, dimsBackupX);
         dedc->Reshape(orderBackupC, dimsBackupC);
     }
-
-    node->visitMark = NODE_FINISHED;
-
 }
 
 /*
@@ -1933,9 +1863,6 @@ void XMathGrad::GradMLP(XTensor* node, bool isEfficient)
         dedx->Reshape(orderBackupX, dimsBackupX);
         dedc->Reshape(orderBackupC, dimsBackupC);
     }
-
-    node->visitMark = NODE_FINISHED;
-
 }
 
 }
...
@@ -111,6 +111,9 @@ void XShapeGrad::GradConvertDataType(XTensor* node, bool isEfficient)
         DelTensorBuf(tmp);
     }
+
+    node->visitMark = NODE_FINISHED;
+    node->isGradFinished = true;
 }
 
 /*
@@ -144,6 +147,9 @@ void XShapeGrad::GradCopyIndexed(XTensor * node, bool isEfficient)
         DelTensorBuf(tmp);
     }
+
+    node->visitMark = NODE_FINISHED;
+    node->isGradFinished = true;
 }
 
 /*
@@ -176,6 +182,7 @@ void XShapeGrad::GradGather(XTensor * node, bool isEfficient)
     }
 
     node->visitMark = NODE_FINISHED;
+    node->isGradFinished = true;
 }
 
 /*
@@ -208,6 +215,7 @@ void XShapeGrad::GradDropoutWithIndex(XTensor * node, bool isEfficient)
     }
 
     node->visitMark = NODE_FINISHED;
+    node->isGradFinished = true;
 }
 
 /*
@@ -299,6 +307,7 @@ void XShapeGrad::GradMerge(XTensor * node, bool isEfficient)
     }
 
     node->visitMark = NODE_FINISHED;
+    node->isGradFinished = true;
 }
 
 /*
@@ -382,6 +391,7 @@ void XShapeGrad::GradMergeList(XTensor * node, bool isEfficient)
     }
 
     node->visitMark = NODE_FINISHED;
+    node->isGradFinished = true;
 }
 
 /*
@@ -410,6 +420,7 @@ void XShapeGrad::GradReshape(XTensor * node, bool isEfficient)
     }
 
     node->visitMark = NODE_FINISHED;
+    node->isGradFinished = true;
 }
 
 /*
@@ -455,6 +466,7 @@ void XShapeGrad::GradSplit(XTensor * node, bool isEfficient)
     }
 
     node->visitMark = NODE_FINISHED;
+    node->isGradFinished = true;
 }
 
 /*
@@ -539,6 +551,9 @@ void XShapeGrad::GradSplitListPost(XTensor * node, bool isEfficient)
             DelTensorBuf(nodeGradTMP);
         }
     }
+
+    node->visitMark = NODE_DOING;
+    node->isGradFinished = true;
 }
 
 /*
@@ -577,6 +592,7 @@ void XShapeGrad::GradTranspose(XTensor * node, bool isEfficient)
     }
 
     node->visitMark = NODE_FINISHED;
+    node->isGradFinished = true;
 }
 
 /*
@@ -615,6 +631,7 @@ void XShapeGrad::GradUnsqueeze(XTensor * node, bool isEfficient)
     }
 
     node->visitMark = NODE_FINISHED;
+    node->isGradFinished = true;
 }
 
 }
\ No newline at end of file
...
@@ -101,6 +101,7 @@ void XNet::Backward(TensorList &roots)
     for(int i = 0; i < nodes.count; i++){
         XTensor * node = (XTensor*)nodes.Get(i);
         node->visitMark = NODE_UNFINISHED;
+        node->isGradFinished = false;
    }
 
     /* back-propagation from output to input */
...
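Note: taken together, the hunks above move the completion bookkeeping out of the per-operation routines (XMathGrad::GradAbsolute and friends) and into the dispatching MakeGrad functions, which now set visitMark and the new isGradFinished flag exactly once per node; XNet::Backward resets both flags before the backward pass starts. A minimal sketch of the resulting control flow, using simplified stand-in types rather than the real XTensor API:

#include <vector>

enum VisitMark { NODE_UNFINISHED, NODE_DOING, NODE_FINISHED };

struct Node {
    VisitMark visitMark = NODE_UNFINISHED;
    bool isGradFinished = false;
};

/* per-op backward computation only; no completion flags here anymore */
void GradForOp(Node * node) { /* ... */ }

/* the dispatcher owns the bookkeeping */
void MakeGrad(Node * node)
{
    GradForOp(node);                  /* dispatch on the op type in the real code */
    node->visitMark = NODE_FINISHED;  /* set once, centrally */
    node->isGradFinished = true;
}

/* mirrors the reset loop added to XNet::Backward */
void ResetBeforeBackward(std::vector<Node*> & nodes)
{
    for (Node * node : nodes) {
        node->visitMark = NODE_UNFINISHED;
        node->isGradFinished = false;
    }
}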
@@ -36,7 +36,7 @@ TensorListBase<T>::TensorListBase()
 {
     maxNum = 1;
     count = 0;
-    items = (T*)malloc(sizeof(T) * 1);
+    items = new T[1];
 }
 
 /*
@@ -49,7 +49,7 @@ TensorListBase<T>::TensorListBase(int myMaxNum)
     CheckNTErrors(myMaxNum > 0, "check if the input number > 0");
     maxNum = myMaxNum;
     count = 0;
-    items = (T*)malloc(sizeof(T) * myMaxNum);
+    items = new T[myMaxNum];
 }
 
 /*
@@ -62,7 +62,7 @@ TensorListBase<T>::TensorListBase(const T* inputItems, int inputItemCount)
     CheckNTErrors(inputItemCount > 0, "check if the input number > 0");
     maxNum = inputItemCount;
     count = inputItemCount;
-    items = (T*)malloc(sizeof(T) * inputItemCount);
+    items = new T[inputItemCount];
     memcpy(items, inputItems, inputItemCount * sizeof(T));
 }
@@ -73,7 +73,7 @@ TensorListBase<T>::TensorListBase(const TensorListBase<T>& l)
     CheckNTErrors(l.maxNum > 0, "check if the input number > 0");
     maxNum = l.maxNum;
     count = l.count;
-    items = (T*)malloc(sizeof(T) * maxNum);
+    items = new T[maxNum];
     memcpy(items, l.items, l.count * sizeof(T));
 }
@@ -94,7 +94,7 @@ TensorListBase<T> TensorListBase<T>::operator=(const TensorListBase<T>& l)
 {
     maxNum = l.maxNum;
     count = l.count;
-    items = (T*)malloc(sizeof(T) * maxNum);
+    items = new T[maxNum];
     memcpy(items, l.items, l.count * sizeof(T));
     return *this;
 }
@@ -105,7 +105,7 @@ TensorListBase<T> TensorListBase<T>::operator=(TensorListBase<T>&& l)
 {
     maxNum = l.maxNum;
     count = l.count;
-    items = (T*)malloc(sizeof(T) * maxNum);
+    items = new T[maxNum];
     memcpy(items, l.items, l.count * sizeof(T));
     return *this;
 }
@@ -115,7 +115,7 @@ template <typename T>
 TensorListBase<T>::~TensorListBase()
 {
     if(items != NULL)
-        free(items);
+        delete[] items;
     items = NULL;
 }
@@ -127,17 +127,10 @@ template <typename T>
 void TensorListBase<T>::Reallocate(int itemNum)
 {
     if (maxNum < itemNum) {
-        T* newItems;
-
-        newItems = (T*)realloc(items, sizeof(T) * itemNum);
-        if (newItems != NULL)
-            items = newItems;
-        else {
-            newItems = (T*)malloc(sizeof(T) * itemNum);
-            memcpy(newItems, items, count * sizeof(T));
-            free(items);
-            items = newItems;
-        }
+        T * newItems = new T[itemNum];
+        memcpy(newItems, items, count * sizeof(T));
+        delete[] items;
+        items = newItems;
         maxNum = itemNum;
     }
 }
@@ -150,20 +143,10 @@ template <typename T>
 void TensorListBase<T>::Add(T&& item)
 {
     if (count == maxNum) {
-        T* newItems;
-
-        newItems = (T*)realloc(items, sizeof(T) * (count * 2 + 1));
-        if (newItems != NULL)
-            items = newItems;
-        else {
-            newItems = (T*)malloc(sizeof(T) * (count * 2 + 1));
-            memcpy(newItems, items, count * sizeof(T));
-            free(items);
-            items = newItems;
-        }
+        T * newItems = new T[count * 2 + 1];
+        memcpy(newItems, items, count * sizeof(T));
+        delete[] items;
+        items = newItems;
         maxNum = count * 2 + 1;
     }
 
     items[count++] = item;
@@ -184,18 +167,10 @@ template <typename T>
 void TensorListBase<T>::Add(const T& item)
 {
     if (count == maxNum) {
-        T* newItems;
-
-        newItems = (T*)realloc(items, sizeof(T) * (count * 2 + 1));
-        if (newItems != NULL)
-            items = newItems;
-        else {
-            newItems = (T*)malloc(sizeof(T) * (count * 2 + 1));
-            memcpy(newItems, items, count * sizeof(T));
-            free(items);
-            items = newItems;
-        }
+        T * newItems = new T[count * 2 + 1];
+        memcpy(newItems, items, count * sizeof(T));
+        delete[] items;
+        items = newItems;
         maxNum = count * 2 + 1;
     }
@@ -244,18 +219,10 @@ template <typename T>
 void TensorListBase<T>::Add(const T* inputItems, int inputItemCount)
 {
     if (count + inputItemCount >= maxNum) {
-        T* newItems;
-
-        newItems = (T*)realloc(items, sizeof(T) * (count + inputItemCount + 1));
-        if (newItems != NULL)
-            items = newItems;
-        else {
-            newItems = (T*)malloc(sizeof(T) * (maxNum + count + inputItemCount + 1));
-            memcpy(newItems, items, count * sizeof(T));
-            free(items);
-            items = newItems;
-        }
+        T* newItems = new T[maxNum + count + inputItemCount + 1];
+        memcpy(newItems, items, count * sizeof(T));
+        delete[] items;
+        items = newItems;
         maxNum += (count + inputItemCount + 1);
     }
 
     memcpy(items + count, inputItems, sizeof(T) * inputItemCount);
@@ -281,18 +248,10 @@ template <typename T>
 void TensorListBase<T>::Insert(int pos, const T& item)
 {
     if (count == maxNum) {
-        T* newItems;
-
-        newItems = (T*)realloc(items, sizeof(T) * (count * 2 + 1));
-        if (newItems != NULL)
-            items = newItems;
-        else {
-            newItems = (T*)malloc(sizeof(T) * (count * 2 + 1));
-            memcpy(newItems, items, count * sizeof(T));
-            free(items);
-            items = newItems;
-        }
+        T * newItems = new T[count * 2 + 1];
+        memcpy(newItems, items, count * sizeof(T));
+        delete[] items;
+        items = newItems;
         maxNum = count * 2 + 1;
     }
@@ -306,18 +265,10 @@ template<typename T>
 void TensorListBase<T>::Insert(int pos, T&& item)
 {
     if (count == maxNum) {
-        T* newItems;
-
-        newItems = (T*)realloc(items, sizeof(T) * (count * 2 + 1));
-        if (newItems != NULL)
-            items = newItems;
-        else {
-            newItems = (T*)malloc(sizeof(T) * (count * 2 + 1));
-            memcpy(newItems, items, count * sizeof(T));
-            free(items);
-            items = newItems;
-        }
+        T * newItems = new T[count * 2 + 1];
+        memcpy(newItems, items, count * sizeof(T));
+        delete[] items;
+        items = newItems;
         maxNum = count * 2 + 1;
     }
@@ -459,7 +410,7 @@ void TensorListBase<T>::Clear()
     count = 0;
     maxNum = 0;
     if(items != NULL)
-        free(items);
+        delete[] items;
     items = NULL;
 }
@@ -514,7 +465,7 @@ void TensorListBase<T>::Reserve(int n)
         return;
     }
 
-    items = (T*)malloc(sizeof(T) * n);
+    items = new T[n];
 }
 
 /*
@@ -560,8 +511,8 @@ void TensorListBase<T>::ReadFromFile(FILE* fp, int num)
         if(!items)
             Reserve(num - maxNum);
         else {
-            free(items);
-            items = (T*)malloc(sizeof(T) * num);
+            delete[] items;
+            items = new T[num];
         }
     }
 
     fread(items, sizeof(T), num, fp);
...
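Note: all TensorListBase allocations above now go through new[]/delete[] rather than malloc/realloc/free, so a single allocator owns the buffer for its whole lifetime and the old realloc-plus-fallback branching disappears. A condensed sketch of the growth pattern the list now uses everywhere (hypothetical names; a trivially copyable T is assumed, as the memcpy calls in the real code also require):

#include <cstring>

template <typename T>
struct GrowableList {
    T * items = new T[1];
    int count = 0;
    int maxNum = 1;

    /* allocate, copy the live prefix, release the old block */
    void Grow(int newMax)
    {
        T * newItems = new T[newMax];
        std::memcpy(newItems, items, count * sizeof(T));
        delete[] items;
        items = newItems;
        maxNum = newMax;
    }

    void Add(const T & item)
    {
        if (count == maxNum)
            Grow(count * 2 + 1);   /* same growth factor as the commit */
        items[count++] = item;
    }

    ~GrowableList() { delete[] items; }
};

Unlike realloc, new[] can never extend a block in place, so every growth pays a copy; in exchange the code no longer mixes C and C++ allocators on the same pointer.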
@@ -1604,6 +1604,9 @@ void XMemManager::GetBufferSize(MTYPE freeMem, MTYPE * myBufSize)
             }
         }
     }
+    else {
+        ShowNTErrors("No enough memory for buffer allocation!");
+    }
 }
 
 /* initialize it and set the global memory information */
...
@@ -250,7 +250,11 @@ bool XQueue::GetJobBreak()
 /* get the number of jobs */
 int XQueue::GetJobNum()
 {
-    return runningJobCount;
+    MUTEX_LOCK(jobQueueMutex);
+    int c = runningJobCount;
+    MUTEX_UNLOCK(jobQueueMutex);
+
+    return c;
 }
 
 } /* end of the nts (NiuTrans.Tensor) namespace */
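Note: GetJobNum used to return runningJobCount with no synchronization while worker threads update it under jobQueueMutex; the change reads the counter under the same lock. A sketch of the fix in portable C++, with std::mutex standing in for the MUTEX_LOCK/MUTEX_UNLOCK macros:

#include <mutex>

class JobQueue {
public:
    /* read the counter under the same lock that writers hold */
    int GetJobNum()
    {
        std::lock_guard<std::mutex> guard(jobQueueMutex);
        return runningJobCount;
    }

private:
    std::mutex jobQueueMutex;
    int runningJobCount = 0;
};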
@@ -306,7 +306,7 @@ run the neural network
 */
 bool TTModel::RunSimple(XList * inputs, XList * outputs, XList * golds, XList* losses)
 {
-    //fprintf(stderr, "run simple 0\n");
+    fprintf(stderr, "run simple 0\n");
     CheckNTErrors(inputs != NULL && inputs->count >= 1, "Wrong arguments!");
     CheckNTErrors(outputs != NULL && outputs->count >= 1, "Wrong arguments!");
     CheckNTErrors(golds != NULL && golds->count >= 1, "Wrong arguments!");
@@ -340,7 +340,7 @@ bool TTModel::RunSimple(XList * inputs, XList * outputs, XList * golds, XList* losses)
     delete[] dims;
 
-    //fprintf(stderr, "run simple 1\n");
+    fprintf(stderr, "run simple 1\n");
 
     return true;
 }
...
@@ -141,6 +141,30 @@ void XLeader::SetMode(XLEADER_MODE myMode)
     mode = myMode;
 }
 
+/* set the flag of instant run */
+void XLeader::SetInstantRun(bool flag)
+{
+    for (int i = 0; i < jworkers.count; i++) {
+        XWorkerJob * worker = (XWorkerJob*)jworkers.GetItem(i);
+        worker->SetInstantRun(flag);
+    }
+
+    for (int i = 0; i < cworkers.count; i++) {
+        XWorkerJob * worker = (XWorkerJob*)cworkers.GetItem(i);
+        worker->SetInstantRun(flag);
+    }
+
+    for (int i = 0; i < uworkers.count; i++) {
+        XWorkerJob * worker = (XWorkerJob*)uworkers.GetItem(i);
+        worker->SetInstantRun(flag);
+    }
+
+    for (int i = 0; i < bworkers.count; i++) {
+        XWorkerJob * worker = (XWorkerJob*)bworkers.GetItem(i);
+        worker->SetInstantRun(flag);
+    }
+}
+
 /* start the workers */
 void XLeader::Start()
 {
@@ -368,6 +392,16 @@ void XLeader::WaitForFinishing(int sleepTime)
             }
         }
 
+        if (finished) {
+            for (int i = 0; i < bworkers.count; i++) {
+                XWorkerJob* worker = (XWorkerJob*)bworkers[i];
+                if (worker->GetJobNum() > 0) {
+                    finished = false;
+                    break;
+                }
+            }
+        }
+
         if (finished)
             break;
...
@@ -123,6 +123,9 @@ public:
     /* set the communication mode */
     void SetMode(XLEADER_MODE myMode);
 
+    /* set the flag of instant run */
+    void SetInstantRun(bool flag = true);
+
     /* add a number of job workers (given their device ids) */
     void AddJobWorker(XModel * model, int n, int * ids);
...
@@ -110,6 +110,7 @@ void XTrainer::Run(XConfig * config, DataDistributeBase * dataDistributor,
     leader.AddJobCollectWorker();
     leader.AddJobUpdateWorker(model, optimizer);
     leader.AddJobBroadcastWorker();
+    //leader.SetInstantRun();
     leader.SetServerModel(config, model);
     leader.Start();
...
@@ -37,6 +37,7 @@ XWorker::XWorker()
     devID = -1;
     id = -1;
     state = XWORKER_UNSTARTED;
+    isInstantRun = false;
 }
 
 /* de-constructor */
@@ -69,6 +70,12 @@ int XWorker::GetID()
     return id;
 }
 
+/* set the flag of instant run */
+void XWorker::SetInstantRun(bool flag)
+{
+    isInstantRun = flag;
+}
+
 /*
 enqueue a new job
 >> job - the job function
...
@@ -60,6 +60,9 @@ protected:
     /* state of the worker */
     XWORKER_STATE state;
 
+    /* fire the flag of instant run */
+    bool isInstantRun;
+
 public:
     /* constructor */
     XWorker();
@@ -79,6 +82,9 @@ public:
     /* get worker id */
     int GetID();
 
+    /* set the flag of instant run */
+    void SetInstantRun(bool flag = true);
+
     /* enqueue a new job */
     void AddJob(void * job, XList * jobArgs);
...
@@ -59,6 +59,8 @@ void XWorkerBroadcast::BroadcastData(XModel * source, XList * targetList, long s
 {
     TensorList & sp = source->params;
     int finished = 0;
+    int * finishedFlag = new int[sp.count];
+    memset(finishedFlag, 0, sizeof(int) * sp.count);
 
     /* check */
     for (int i = 0; i < targetList->count; i++) {
@@ -69,7 +71,7 @@ void XWorkerBroadcast::BroadcastData(XModel * source, XList * targetList, long s
     /* the major body of broadcasting */
     while (1) {
         for (int i = 0; i < sp.count; i++) {
-            if (source->flags[i] == PARAM_STATE_UPDATED) {
+            if (source->flags[i] == PARAM_STATE_UPDATED && finishedFlag[i] == 0) {
                 for (int j = 0; j < targetList->count; j++) {
                     XModel * target = (XModel*)targetList->GetItem(j);
                     TensorList & tp = target->params;
@@ -81,12 +83,21 @@ void XWorkerBroadcast::BroadcastData(XModel * source, XList * targetList, long s
                     target->flags[i] = PARAM_STATE_UPDATED;
                     finished++;
                 }
+
+                finishedFlag[i] = 1;
             }
         }
 
         if (finished == sp.count * targetList->count)
             break;
+#ifdef _WIN32
+        Sleep((DWORD)sleepTime);
+#else
+        sleep((unsigned)sleepTime / 1000);
+#endif
     }
+
+    delete[] finishedFlag;
 }
 
 /*
@@ -95,6 +106,7 @@ wrapper of BroadcastData
 */
 void XWorkerBroadcast::Broadcast(XList * args)
 {
+    fprintf(stderr, "broadcast 0\n");
     XWorkerBroadcast * broadcaster = (XWorkerBroadcast*)args->GetItem(0);
     XModel * source = (XModel*)args->GetItem(1);
@@ -107,6 +119,7 @@ void XWorkerBroadcast::Broadcast(XList * args)
     }
 
     broadcaster->BroadcastData(source, &target, SLEEP_TIME_IN_BROADCASTING);
+    fprintf(stderr, "broadcast 1\n");
 }
 
 /*
@@ -139,7 +152,10 @@ bool XWorkerBroadcast::AddJobBroadcast(XModel * source, XList * targetList)
     args.AddInt(targetList->count);
     args.AddList(targetList);
 
-    queue.EnqueueJob((void*)(char*)XWorkerBroadcast::Broadcast, &args);
+    if (isInstantRun)
+        XWorkerBroadcast::Broadcast(&args);
+    else
+        queue.EnqueueJob((void*)(char*)XWorkerBroadcast::Broadcast, &args);
 
     return true;
 }
...
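Note: BroadcastData gains a per-parameter finishedFlag array so a parameter that has already been broadcast is not sent again on later scans, and the polling loop now sleeps between scans instead of spinning. A stripped-down sketch of that loop under simplified, hypothetical types (the real code copies each parameter tensor to every target model where the comment stands):

#include <chrono>
#include <thread>
#include <vector>

enum ParamState { PARAM_STATE_NOT_READY, PARAM_STATE_UPDATED };

void BroadcastData(std::vector<ParamState> & sourceFlags, int targetNum, long sleepTimeMS)
{
    const int paramNum = (int)sourceFlags.size();
    std::vector<int> finishedFlag(paramNum, 0);
    int finished = 0;

    while (true) {
        for (int i = 0; i < paramNum; i++) {
            if (sourceFlags[i] == PARAM_STATE_UPDATED && finishedFlag[i] == 0) {
                /* ... copy parameter i to every target model ... */
                finished += targetNum;
                finishedFlag[i] = 1;   /* never broadcast this parameter twice */
            }
        }

        if (finished == paramNum * targetNum)
            break;

        /* portable stand-in for the Win32/POSIX sleep pair in the commit */
        std::this_thread::sleep_for(std::chrono::milliseconds(sleepTimeMS));
    }
}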
@@ -68,6 +68,8 @@ void XWorkerCollect::CollectData(XList * sourceList, XModel * target, long sleep
         CheckNTErrors(sp.count == tp.count, "Incompatiable models!");
     }
 
+    //fprintf(stderr, "collect data in 0\n");
+
     /* This is a simple implementation of the wait-and-collect process. But
        there is a risk that some models are not available, that is, the
        loop would never stop. A solution might be that we force the loop
@@ -173,6 +175,8 @@ void XWorkerCollect::CollectData(XList * sourceList, XModel * target, long sleep
 /* wrapper of CollectData */
 void XWorkerCollect::Collect(XList * args)
 {
+    fprintf(stderr, "collect data 0\n");
+
     XWorkerCollect * collecter = (XWorkerCollect*)args->GetItem(0);
     int sourceNum = args->GetItemInt(1);
@@ -187,6 +191,8 @@ void XWorkerCollect::Collect(XList * args)
     XModel * target = (XModel*)args->GetItem(2 + sourceNum);
 
     collecter->CollectData(&source, target, SLEEP_TIME_IN_COLLECTING);
+
+    fprintf(stderr, "collect data 1\n");
 }
 
 /*
@@ -253,7 +259,10 @@ bool XWorkerCollect::AddJobCollect(XList * sourceList, XModel * target)
     args.AddList(sourceList);
     args.Add(target);
 
-    queue.EnqueueJob((void*)(char*)XWorkerCollect::Collect, &args);
+    if (isInstantRun)
+        XWorkerCollect::Collect(&args);
+    else
+        queue.EnqueueJob((void*)(char*)XWorkerCollect::Collect, &args);
 
     return true;
 }
@@ -302,6 +311,8 @@ void XWorkerCollect::CollectOtherData(XList* sourceList, XNNRecord* target, long
 /* wrapper of CollectOtherData */
 void XWorkerCollect::CollectOther(XList* args)
 {
+    //fprintf(stderr, "collect data other 0\n");
+
     XWorkerCollect* collecter = (XWorkerCollect*)args->GetItem(0);
     int sourceNum = args->GetItemInt(1);
@@ -316,6 +327,8 @@ void XWorkerCollect::CollectOther(XList* args)
     XNNRecord* target = (XNNRecord*)args->GetItem(2 + sourceNum);
 
     collecter->CollectOtherData(&source, target, SLEEP_TIME_IN_COLLECTING_OTHER);
+
+    //fprintf(stderr, "collect data other 1\n");
 }
 
 /*
@@ -335,7 +348,10 @@ bool XWorkerCollect::AddJobCollectOther(XList* sourceList, XNNRecord* target)
     args.AddList(sourceList);
     args.Add(target);
 
-    queue.EnqueueJob((void*)(char*)XWorkerCollect::CollectOther, &args);
+    if (isInstantRun)
+        XWorkerCollect::CollectOther(&args);
+    else
+        queue.EnqueueJob((void*)(char*)XWorkerCollect::CollectOther, &args);
 
     return true;
 }
...
@@ -180,12 +180,19 @@ add a new job of model refreshment
 */
 bool XWorkerJob::AddJobRefresh(XModel * myModel)
 {
+    //fprintf(stderr, "refresh 0\n");
+
     CheckNTErrors(myModel != NULL, "no parameter keeper!");
 
     XList args(1);
     args.Add(myModel);
 
-    queue.EnqueueJob((void*)(char*)XModel::Refresh, &args);
+    if(isInstantRun)
+        XModel::Refresh(&args);
+    else
+        queue.EnqueueJob((void*)(char*)XModel::Refresh, &args);
+
+    //fprintf(stderr, "refresh 1\n");
 
     return true;
 }
@@ -213,7 +220,10 @@ bool XWorkerJob::AddJobNeuralNet(XModel * myModel,
     args.Add(golds);
     args.Add(losses);
 
-    queue.EnqueueJob((void*)(char*)XModel::Run, &args);
+    if(isInstantRun)
+        XModel::Run(&args);
+    else
+        queue.EnqueueJob((void*)(char*)XModel::Run, &args);
 
     SetState(XWORKER_STARTED);
@@ -226,7 +236,10 @@ bool XWorkerJob::AddJobRecord()
     XList args;
     args.Add(this);
 
-    queue.EnqueueJob((void*)(char*)XWorkerJob::RecordMeStatic, &args);
+    if (isInstantRun)
+        XWorkerJob::RecordMeStatic(&args);
+    else
+        queue.EnqueueJob((void*)(char*)XWorkerJob::RecordMeStatic, &args);
 
     return true;
 }
@@ -234,6 +247,8 @@ bool XWorkerJob::AddJobRecord()
 /* wrapper of RecordMe */
 void XWorkerJob::RecordMeStatic(XList* args)
 {
+    //fprintf(stderr, "record static 0\n");
+
     CheckNTErrors(args != NULL && args->count > 0, "Illegal arguments!");
 
     XWorkerJob * worker = (XWorkerJob*)args->GetItem(0);
@@ -241,6 +256,8 @@ void XWorkerJob::RecordMeStatic(XList* args)
     worker->RecordMe();
     worker->SetState(XWORKER_FINISHED);
+
+    //fprintf(stderr, "record static 1\n");
 }
 
 } /* end of the nts (NiuTrans.Tensor) namespace */
...
@@ -81,7 +81,6 @@ void XWorkerUpdate::UpdateModel(XModel * model, XOptimizer * optimizer, long sle
                 flags[i] = PARAM_STATE_UPDATED;
                 finished++;
             }
-
         }
 
         if (finished == params.count)
@@ -103,6 +102,8 @@ wrapper of UpdateModel
 */
 void XWorkerUpdate::Update(XList * args)
 {
+    fprintf(stderr, "update 0\n");
+
     CheckNTErrors(args != NULL && args->count >= 3, "Illegal argument list!");
 
     XWorkerUpdate * updater = (XWorkerUpdate*)args->GetItem(0);
@@ -110,6 +111,8 @@ void XWorkerUpdate::Update(XList * args)
     XOptimizer * optimizer = (XOptimizer*)args->GetItem(2);
 
     updater->UpdateModel(model, optimizer, SLEEP_TIME_IN_MODEL_UPDATE);
+
+    fprintf(stderr, "update 1\n");
 }
 
 /*
@@ -127,7 +130,10 @@ bool XWorkerUpdate::AddJobUpdate(XModel * model, XOptimizer * optimizer)
     args.Add(model);
     args.Add(optimizer);
 
-    queue.EnqueueJob((void*)(char*)XWorkerUpdate::Update, &args);
+    if(isInstantRun)
+        XWorkerUpdate::Update(&args);
+    else
+        queue.EnqueueJob((void*)(char*)XWorkerUpdate::Update, &args);
 
     return true;
 }
...
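Note: the same two-way dispatch now appears in every AddJob* method across XWorkerBroadcast, XWorkerCollect, XWorkerJob, and XWorkerUpdate: when isInstantRun is set (fanned out by XLeader::SetInstantRun, and left commented out in XTrainer::Run above), the job function runs synchronously on the caller's thread, which makes the training pipeline easy to step through; otherwise the job is queued for the worker thread as before. A sketch of the pattern with hypothetical types:

#include <functional>
#include <queue>

class Worker {
public:
    void SetInstantRun(bool flag = true) { isInstantRun = flag; }

    bool AddJob(const std::function<void()> & job)
    {
        if (isInstantRun)
            job();              /* run now, on the caller's thread */
        else
            jobQueue.push(job); /* normal asynchronous path */
        return true;
    }

private:
    bool isInstantRun = false;
    std::queue<std::function<void()>> jobQueue;
};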