Commit 87bb27ee
authored Mar 09, 2021 by xiaotong

updates

parent b69e10f6

Showing 18 changed files with 185 additions and 173 deletions
Changed files:

    source/network/XBackwardFunc.cpp    +1   -0
    source/network/XBackwardLoss.cpp    +1   -0
    source/network/XBackwardMath.cpp    +4   -77
    source/network/XBackwardShape.cpp   +18  -0
    source/network/XNet.cpp             +1   -0
    source/tensor/XList.cpp             +35  -84
    source/tensor/XMem.cpp              +3   -0
    source/tensor/XQueue.cpp            +5   -1
    source/train/TTrain.cpp             +2   -2
    source/train/XLeader.cpp            +34  -0
    source/train/XLeader.h              +3   -0
    source/train/XTrainer.cpp           +1   -0
    source/train/XWorker.cpp            +7   -0
    source/train/XWorker.h              +6   -0
    source/train/XWorkerBroadcast.cpp   +18  -2
    source/train/XWorkerCollect.cpp     +18  -2
    source/train/XWorkerJob.cpp         +20  -3
    source/train/XWorkerUpdate.cpp      +8   -2
source/network/XBackwardFunc.cpp

@@ -93,6 +93,7 @@ void XFuncGrad::MakeGrad(XTensor * node, bool isEfficient)
     }
 
     node->visitMark = NODE_FINISHED;
+    node->isGradFinished = true;
 }
 
 /* indicates whether the node is for an activation function */
source/network/XBackwardLoss.cpp

@@ -89,6 +89,7 @@ void XLossGrad::MakeGrad(XTensor * node, bool isEfficient)
     }
 
     node->visitMark = NODE_FINISHED;
+    node->isGradFinished = true;
 }
 
 /* indicates whether the node is for a loss computation */
source/network/XBackwardMath.cpp

@@ -125,6 +125,9 @@ void XMathGrad::MakeGrad(XTensor * node, bool isEfficient)
     else {
         ShowNTErrors("Unsupported backward computation! TODO!");
     }
+
+    node->visitMark = NODE_FINISHED;
+    node->isGradFinished = true;
 }
 
 /* indicates whether the node is for a math operation */

@@ -162,8 +165,6 @@ void XMathGrad::GradAbsolute(XTensor * node, bool isEfficient)
         DelTensorBuf(tmp);
     }
-
-    node->visitMark = NODE_FINISHED;
 }
 
 /*

@@ -194,8 +195,6 @@ void XMathGrad::GradCos(XTensor * node, bool isEfficient)
         DelTensorBuf(tmp);
     }
-
-    node->visitMark = NODE_FINISHED;
 }
 
 /*

@@ -225,8 +224,6 @@ void XMathGrad::GradExp(XTensor * node, bool isEfficient)
         DelTensorBuf(tmp);
     }
-
-    node->visitMark = NODE_FINISHED;
 }
 
 /*

@@ -251,8 +248,6 @@ void XMathGrad::GradLog(XTensor * node, bool isEfficient)
         XNoder::MakeGrad(a);
         _Div(node->grad, a, a->grad, 1.0F);
     }
-
-    node->visitMark = NODE_FINISHED;
 }
 
 /*

@@ -276,8 +271,6 @@ void XMathGrad::GradRound(XTensor * node, bool isEfficient)
     if (!isEfficient || a->isGrad) {
         XNoder::MakeGrad(a);
     }
-
-    node->visitMark = NODE_FINISHED;
 }
 
 /*

@@ -301,8 +294,6 @@ void XMathGrad::GradSign(XTensor * node, bool isEfficient)
     if (!isEfficient || a->isGrad) {
         XNoder::MakeGrad(a);
     }
-
-    node->visitMark = NODE_FINISHED;
 }
 
 /*

@@ -332,8 +323,6 @@ void XMathGrad::GradSin(XTensor * node, bool isEfficient)
         DelTensorBuf(tmp);
     }
-
-    node->visitMark = NODE_FINISHED;
 }
 
 /*

@@ -364,8 +353,6 @@ void XMathGrad::GradTan(XTensor * node, bool isEfficient)
         DelTensorBuf(tmp);
     }
-
-    node->visitMark = NODE_FINISHED;
 }
 
 /*

@@ -398,8 +385,6 @@ void XMathGrad::GradClip(XTensor * node, bool isEfficient)
         DelTensorBuf(tmp);
     }
-
-    node->visitMark = NODE_FINISHED;
 }
 
 /*

@@ -440,8 +425,6 @@ void XMathGrad::GradDiv(XTensor * node, bool isEfficient)
         DelTensorBuf(tmp);
     }
-
-    node->visitMark = NODE_FINISHED;
 }
 
 /*

@@ -539,8 +522,6 @@ void XMathGrad::GradDivDim(XTensor * node, bool isEfficient)
         DelTensorBuf(aTMP2);
         DelTensorBuf(aTMP1);
     }
-
-    node->visitMark = NODE_FINISHED;
 }
 
 /*

@@ -602,8 +583,6 @@ void XMathGrad::GradMatrixMul(XTensor * node, bool isEfficient)
     else {
         ShowNTErrors("TODO!");
     }
-
-    node->visitMark = NODE_FINISHED;
 }
 
 /*

@@ -757,8 +736,6 @@ void XMathGrad::GradMatrixMulBatched(XTensor * node, bool isEfficient)
         if (!isEfficient || b->isGrad)
             _MatrixMulBatched(dedc, X_TRANS, a, X_TRANS, dedb, alpha, 1.0F);
     }
-
-    node->visitMark = NODE_FINISHED;
 }
 
 /*

@@ -793,8 +770,6 @@ void XMathGrad::GradMultiply(XTensor * node, bool isEfficient)
         XNoder::MakeGrad(b);
         _Multiply(node->grad, a, b->grad, 1.0F);
     }
-
-    node->visitMark = NODE_FINISHED;
 }
 
 /*

@@ -879,8 +854,6 @@ void XMathGrad::GradMultiplyDim(XTensor * node, bool isEfficient)
         }
 
         DelTensorBuf(bGradTMP);
     }
-
-    node->visitMark = NODE_FINISHED;
 }
 
 /*

@@ -916,8 +889,6 @@ void XMathGrad::GradMultiplyBroadcast(XTensor * node, bool isEfficient)
         if (b->isVar || b->income.tailNum > 0)
             ShowNTErrors("TODO");
     }
-
-    node->visitMark = NODE_FINISHED;
 }
 
 /*

@@ -942,8 +913,6 @@ void XMathGrad::GradNegate(XTensor * node, bool isEfficient)
         XNoder::MakeGrad(a);
         _Sum(a->grad, node->grad, a->grad, -1.0F);
     }
-
-    node->visitMark = NODE_FINISHED;
 }
 
 /*

@@ -987,8 +956,6 @@ void XMathGrad::GradPower(XTensor * node, bool isEfficient)
         DelTensorBuf(tmp);
     }
-
-    node->visitMark = NODE_FINISHED;
 }

@@ -1019,8 +986,6 @@ void XMathGrad::GradReciprocal(XTensor* node, bool isEfficient)
         DelTensorBuf(tmp);
     }
-
-    node->visitMark = NODE_FINISHED;
 }
 
 /*

@@ -1049,8 +1014,6 @@ void XMathGrad::GradSqrt(XTensor * node, bool isEfficient)
         DelTensorBuf(tmp);
     }
-
-    node->visitMark = NODE_FINISHED;
 }
 
 /*

@@ -1080,8 +1043,6 @@ void XMathGrad::GradSquare(XTensor * node, bool isEfficient)
         DelTensorBuf(tmp);
     }
-
-    node->visitMark = NODE_FINISHED;
 }
 
 /*

@@ -1109,8 +1070,6 @@ void XMathGrad::GradScaleAndShift(XTensor * node, bool isEfficient)
         _Sum(a->grad, node->grad, a->grad, scale);
     }
-
-    node->visitMark = NODE_FINISHED;
 }
 
 /*

@@ -1138,8 +1097,6 @@ void XMathGrad::GradScale(XTensor * node, bool isEfficient)
         _Sum(a->grad, node->grad, a->grad, scale);
     }
-
-    node->visitMark = NODE_FINISHED;
 }
 
 /*

@@ -1166,9 +1123,7 @@ void XMathGrad::GradDescale(XTensor * node, bool isEfficient)
         XNoder::MakeGrad(a);
         _Sum(a->grad, node->grad, a->grad, 1 / descale);
     }
-
-    node->visitMark = NODE_FINISHED;
 }
 
 /*

@@ -1194,8 +1149,6 @@ void XMathGrad::GradShift(XTensor * node, bool isEfficient)
         _Sum(a->grad, node->grad, a->grad);
     }
-
-    node->visitMark = NODE_FINISHED;
 }
 
 /*

@@ -1229,8 +1182,6 @@ void XMathGrad::GradSub(XTensor * node, bool isEfficient)
         XNoder::MakeGrad(b);
         _Sum(b->grad, node->grad, b->grad, -beta);
     }
-
-    node->visitMark = NODE_FINISHED;
 }
 
 /*

@@ -1317,8 +1268,6 @@ void XMathGrad::GradSubDim(XTensor * node, bool isEfficient)
             DelTensorBuf(interGrad);
         }
     }
-
-    node->visitMark = NODE_FINISHED;
 }
 
 /*

@@ -1352,8 +1301,6 @@ void XMathGrad::GradSum(XTensor * node, bool isEfficient)
         XNoder::MakeGrad(b);
         _Sum(b->grad, node->grad, b->grad, beta);
     }
-
-    node->visitMark = NODE_FINISHED;
 }
 
 /*

@@ -1441,8 +1388,6 @@ void XMathGrad::GradSumDim(XTensor * node, bool isEfficient)
             DelTensorBuf(interGrad);
         }
     }
-
-    node->visitMark = NODE_FINISHED;
 }
 
 /*

@@ -1480,8 +1425,6 @@ void XMathGrad::GradSumBroadcast(XTensor * node, bool isEfficient)
             ShowNTErrors("TODO");
         }
     }
-
-    node->visitMark = NODE_FINISHED;
 }
 
 /*

@@ -1516,8 +1459,6 @@ void XMathGrad::GradReduceMean(XTensor * node, bool isEfficient)
         DelTensorBuf(tmp);
     }
-
-    node->visitMark = NODE_FINISHED;
 }
 
 /*

@@ -1550,8 +1491,6 @@ void XMathGrad::GradReduceSum(XTensor * node, bool isEfficient)
         _Sum(a->grad, tmp, a->grad);
         DelTensorBuf(tmp);
     }
-
-    node->visitMark = NODE_FINISHED;
 }
 
 /*

@@ -1582,8 +1521,6 @@ void XMathGrad::GradReduceSumAll(XTensor * node, bool isEfficient)
         _Sum(a->grad, tmp, a->grad);
         DelTensorBuf(tmp);
     }
-
-    node->visitMark = NODE_FINISHED;
 }
 
 /*

@@ -1639,8 +1576,6 @@ void XMathGrad::GradReduceSumSquared(XTensor * node, bool isEfficient)
     DelTensorBuf(e);
     DelTensorBuf(d);
     DelTensorBuf(c);
-
-    node->visitMark = NODE_FINISHED;
 }
 
 /*

@@ -1696,8 +1631,6 @@ void XMathGrad::GradReduceVariance(XTensor * node, bool isEfficient)
     DelTensorBuf(e);
     DelTensorBuf(d);
    DelTensorBuf(c);
-
-    node->visitMark = NODE_FINISHED;
 }
 
 /*

@@ -1815,9 +1748,6 @@ void XMathGrad::GradMulAndShift(XTensor * node, bool isEfficient)
         dedx->Reshape(orderBackupX, dimsBackupX);
         dedc->Reshape(orderBackupC, dimsBackupC);
     }
-
-    node->visitMark = NODE_FINISHED;
-
 }
 
 /*

@@ -1933,9 +1863,6 @@ void XMathGrad::GradMLP(XTensor* node, bool isEfficient)
         dedx->Reshape(orderBackupX, dimsBackupX);
         dedc->Reshape(orderBackupC, dimsBackupC);
     }
-
-    node->visitMark = NODE_FINISHED;
-
 }
 
 }
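
A note on the pattern: the 77 deletions in this file are one hoisting refactor. Each GradX routine used to set node->visitMark itself; the first hunk moves that bookkeeping (plus the new isGradFinished flag) into the MakeGrad dispatcher so it runs exactly once per node. A minimal sketch of the shape, with hypothetical names rather than the project's real types:

    // Sketch only: 'Node', 'GradOp' and 'finished' are hypothetical stand-ins
    // for XTensor, the GradX functions and the visitMark/isGradFinished flags.
    struct Node { bool finished = false; };

    using GradOp = void (*)(Node*);   // e.g. GradAbsolute, GradCos, ...

    void MakeGrad(Node* node, GradOp op)
    {
        op(node);                // the op no longer touches the flags
        node->finished = true;   // bookkeeping now happens in one place
    }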
source/network/XBackwardShape.cpp

@@ -111,6 +111,9 @@ void XShapeGrad::GradConvertDataType(XTensor* node, bool isEfficient)
         DelTensorBuf(tmp);
     }
+
+    node->visitMark = NODE_FINISHED;
+    node->isGradFinished = true;
 }
 
 /*

@@ -144,6 +147,9 @@ void XShapeGrad::GradCopyIndexed(XTensor * node, bool isEfficient)
         DelTensorBuf(tmp);
     }
+
+    node->visitMark = NODE_FINISHED;
+    node->isGradFinished = true;
 }
 
 /*

@@ -176,6 +182,7 @@ void XShapeGrad::GradGather(XTensor * node, bool isEfficient)
     }
 
     node->visitMark = NODE_FINISHED;
+    node->isGradFinished = true;
 }
 
 /*

@@ -208,6 +215,7 @@ void XShapeGrad::GradDropoutWithIndex(XTensor * node, bool isEfficient)
     }
 
     node->visitMark = NODE_FINISHED;
+    node->isGradFinished = true;
 }
 
 /*

@@ -299,6 +307,7 @@ void XShapeGrad::GradMerge(XTensor * node, bool isEfficient)
     }
 
     node->visitMark = NODE_FINISHED;
+    node->isGradFinished = true;
 }
 
 /*

@@ -382,6 +391,7 @@ void XShapeGrad::GradMergeList(XTensor * node, bool isEfficient)
     }
 
     node->visitMark = NODE_FINISHED;
+    node->isGradFinished = true;
 }
 
 /*

@@ -410,6 +420,7 @@ void XShapeGrad::GradReshape(XTensor * node, bool isEfficient)
     }
 
     node->visitMark = NODE_FINISHED;
+    node->isGradFinished = true;
 }
 
 /*

@@ -455,6 +466,7 @@ void XShapeGrad::GradSplit(XTensor * node, bool isEfficient)
     }
 
     node->visitMark = NODE_FINISHED;
+    node->isGradFinished = true;
 }
 
 /*

@@ -539,6 +551,9 @@ void XShapeGrad::GradSplitListPost(XTensor * node, bool isEfficient)
             DelTensorBuf(nodeGradTMP);
         }
     }
+
+    node->visitMark = NODE_DOING;
+    node->isGradFinished = true;
 }
 
 /*

@@ -577,6 +592,7 @@ void XShapeGrad::GradTranspose(XTensor * node, bool isEfficient)
     }
 
     node->visitMark = NODE_FINISHED;
+    node->isGradFinished = true;
 }
 
 /*

@@ -615,6 +631,7 @@ void XShapeGrad::GradUnsqueeze(XTensor * node, bool isEfficient)
     }
 
     node->visitMark = NODE_FINISHED;
+    node->isGradFinished = true;
 }
 
 }
\ No newline at end of file
source/network/XNet.cpp

@@ -101,6 +101,7 @@ void XNet::Backward(TensorList &roots)
     for (int i = 0; i < nodes.count; i++) {
         XTensor * node = (XTensor*)nodes.Get(i);
         node->visitMark = NODE_UNFINISHED;
+        node->isGradFinished = false;
     }
 
     /* back-propagation from output to input */
source/tensor/XList.cpp

@@ -36,7 +36,7 @@ TensorListBase<T>::TensorListBase()
 {
     maxNum = 1;
     count = 0;
-    items = (T*)malloc(sizeof(T) * 1);
+    items = new T[1];
 }
 
 /*

@@ -49,7 +49,7 @@ TensorListBase<T>::TensorListBase(int myMaxNum)
     CheckNTErrors(myMaxNum > 0, "check if the input number > 0");
     maxNum = myMaxNum;
     count = 0;
-    items = (T*)malloc(sizeof(T) * myMaxNum);
+    items = new T[myMaxNum];
 }
 
 /*

@@ -62,7 +62,7 @@ TensorListBase<T>::TensorListBase(const T* inputItems, int inputItemCount)
     CheckNTErrors(inputItemCount > 0, "check if the input number > 0");
     maxNum = inputItemCount;
     count = inputItemCount;
-    items = (T*)malloc(sizeof(T) * inputItemCount);
+    items = new T[inputItemCount];
     memcpy(items, inputItems, inputItemCount * sizeof(T));
 }

@@ -73,7 +73,7 @@ TensorListBase<T>::TensorListBase(const TensorListBase<T>& l)
     CheckNTErrors(l.maxNum > 0, "check if the input number > 0");
     maxNum = l.maxNum;
     count = l.count;
-    items = (T*)malloc(sizeof(T) * maxNum);
+    items = new T[maxNum];
     memcpy(items, l.items, l.count * sizeof(T));
 }

@@ -94,7 +94,7 @@ TensorListBase<T> TensorListBase<T>::operator=(const TensorListBase<T>& l)
 {
     maxNum = l.maxNum;
     count = l.count;
-    items = (T*)malloc(sizeof(T) * maxNum);
+    items = new T[maxNum];
     memcpy(items, l.items, l.count * sizeof(T));
     return *this;
 }

@@ -105,7 +105,7 @@ TensorListBase<T> TensorListBase<T>::operator=(TensorListBase<T>&& l)
 {
     maxNum = l.maxNum;
     count = l.count;
-    items = (T*)malloc(sizeof(T) * maxNum);
+    items = new T[maxNum];
     memcpy(items, l.items, l.count * sizeof(T));
     return *this;
 }

@@ -115,7 +115,7 @@ template <typename T>
 TensorListBase<T>::~TensorListBase()
 {
     if (items != NULL)
-        free(items);
+        delete[] items;
     items = NULL;
 }

@@ -127,17 +127,10 @@ template <typename T>
 void TensorListBase<T>::Reallocate(int itemNum)
 {
     if (maxNum < itemNum) {
-        T* newItems;
-        newItems = (T*)realloc(items, sizeof(T) * itemNum);
-        if (newItems != NULL)
-            items = newItems;
-        else {
-            newItems = (T*)malloc(sizeof(T) * itemNum);
-            memcpy(newItems, items, count * sizeof(T));
-            free(items);
-            items = newItems;
-        }
+        T* newItems = new T[itemNum];
+        memcpy(newItems, items, count * sizeof(T));
+        delete[] items;
+        items = newItems;
 
         maxNum = itemNum;
     }
 }

@@ -150,20 +143,10 @@ template <typename T>
 void TensorListBase<T>::Add(T&& item)
 {
     if (count == maxNum) {
-        T* newItems;
-        newItems = (T*)realloc(items, sizeof(T) * (count * 2 + 1));
-        if (newItems != NULL)
-            items = newItems;
-        else {
-            newItems = (T*)malloc(sizeof(T) * (count * 2 + 1));
-            memcpy(newItems, items, count * sizeof(T));
-            free(items);
-            items = newItems;
-        }
+        T* newItems = new T[count * 2 + 1];
+        memcpy(newItems, items, count * sizeof(T));
+        delete[] items;
+        items = newItems;
 
         maxNum = count * 2 + 1;
     }
 
     items[count++] = item;

@@ -184,18 +167,10 @@ template <typename T>
 void TensorListBase<T>::Add(const T& item)
 {
     if (count == maxNum) {
-        T* newItems;
-        newItems = (T*)realloc(items, sizeof(T) * (count * 2 + 1));
-        if (newItems != NULL)
-            items = newItems;
-        else {
-            newItems = (T*)malloc(sizeof(T) * (count * 2 + 1));
-            memcpy(newItems, items, count * sizeof(T));
-            free(items);
-            items = newItems;
-        }
+        T* newItems = new T[count * 2 + 1];
+        memcpy(newItems, items, count * sizeof(T));
+        delete[] items;
+        items = newItems;
 
         maxNum = count * 2 + 1;
     }

@@ -244,18 +219,10 @@ template <typename T>
 void TensorListBase<T>::Add(const T* inputItems, int inputItemCount)
 {
     if (count + inputItemCount >= maxNum) {
-        T* newItems;
-        newItems = (T*)realloc(items, sizeof(T) * (count + inputItemCount + 1));
-        if (newItems != NULL)
-            items = newItems;
-        else {
-            newItems = (T*)malloc(sizeof(T) * (maxNum + count + inputItemCount + 1));
-            memcpy(newItems, items, count * sizeof(T));
-            free(items);
-            items = newItems;
-        }
+        T* newItems = new T[maxNum + count + inputItemCount + 1];
+        memcpy(newItems, items, count * sizeof(T));
+        delete[] items;
+        items = newItems;
 
         maxNum += (count + inputItemCount + 1);
     }
 
     memcpy(items + count, inputItems, sizeof(T) * inputItemCount);

@@ -281,18 +248,10 @@ template <typename T>
 void TensorListBase<T>::Insert(int pos, const T& item)
 {
     if (count == maxNum) {
-        T* newItems;
-        newItems = (T*)realloc(items, sizeof(T) * (count * 2 + 1));
-        if (newItems != NULL)
-            items = newItems;
-        else {
-            newItems = (T*)malloc(sizeof(T) * (count * 2 + 1));
-            memcpy(newItems, items, count * sizeof(T));
-            free(items);
-            items = newItems;
-        }
+        T* newItems = new T[count * 2 + 1];
+        memcpy(newItems, items, count * sizeof(T));
+        delete[] items;
+        items = newItems;
 
         maxNum = count * 2 + 1;
     }

@@ -306,18 +265,10 @@ template<typename T>
 void TensorListBase<T>::Insert(int pos, T&& item)
 {
     if (count == maxNum) {
-        T* newItems;
-        newItems = (T*)realloc(items, sizeof(T) * (count * 2 + 1));
-        if (newItems != NULL)
-            items = newItems;
-        else {
-            newItems = (T*)malloc(sizeof(T) * (count * 2 + 1));
-            memcpy(newItems, items, count * sizeof(T));
-            free(items);
-            items = newItems;
-        }
+        T* newItems = new T[count * 2 + 1];
+        memcpy(newItems, items, count * sizeof(T));
+        delete[] items;
+        items = newItems;
 
         maxNum = count * 2 + 1;
     }

@@ -459,7 +410,7 @@ void TensorListBase<T>::Clear()
     count = 0;
     maxNum = 0;
     if (items != NULL)
-        free(items);
+        delete[] items;
     items = NULL;
 }

@@ -514,7 +465,7 @@ void TensorListBase<T>::Reserve(int n)
         return;
     }
 
-    items = (T*)malloc(sizeof(T) * n);
+    items = new T[n];
 }
 
 /*

@@ -560,8 +511,8 @@ void TensorListBase<T>::ReadFromFile(FILE* fp, int num)
         if (!items)
             Reserve(num - maxNum);
         else {
-            free(items);
-            items = (T*)malloc(sizeof(T) * num);
+            delete[] items;
+            items = new T[num];
         }
     }
 
     fread(items, sizeof(T), num, fp);
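
All of the XList.cpp hunks apply one substitution: malloc/realloc/free on T* becomes new[]/memcpy/delete[]. This matters for a template container because realloc moves raw bytes and never runs constructors or destructors, which is only safe for trivially copyable element types; new[]/delete[] at least runs them, though the retained memcpy still assumes T is trivially copyable. A minimal sketch of the grow-and-copy idiom the commit adopts, under that same assumption:

    #include <cstring>

    // Grow-and-copy idiom from the commit, sketched in isolation. As in
    // XList.cpp, memcpy assumes T is trivially copyable; general T would
    // need element-wise copy or move instead.
    template <typename T>
    void Grow(T*& items, int count, int& maxNum, int newCap)
    {
        if (maxNum >= newCap)
            return;
        T* newItems = new T[newCap];                  // runs constructors
        std::memcpy(newItems, items, count * sizeof(T));
        delete[] items;                               // runs destructors
        items = newItems;
        maxNum = newCap;
    }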
source/tensor/XMem.cpp

@@ -1604,6 +1604,9 @@ void XMemManager::GetBufferSize(MTYPE freeMem, MTYPE * myBufSize)
             }
         }
     }
+    else {
+        ShowNTErrors("No enough memory for buffer allocation!");
+    }
 }
 
 /* initialize it and set the global memory information */
source/tensor/XQueue.cpp

@@ -250,7 +250,11 @@ bool XQueue::GetJobBreak()
 /* get the number of jobs */
 int XQueue::GetJobNum()
 {
-    return runningJobCount;
+    MUTEX_LOCK(jobQueueMutex);
+    int c = runningJobCount;
+    MUTEX_UNLOCK(jobQueueMutex);
+
+    return c;
 }
 
 } /* end of the nts (NiuTrans.Tensor) namespace */
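
The GetJobNum change is the standard lock-copy-unlock read of a counter that worker threads mutate concurrently. The same shape in portable C++, with std::mutex standing in for the project's MUTEX_LOCK/MUTEX_UNLOCK macros:

    #include <mutex>

    // Portable sketch of the locked read in XQueue::GetJobNum; the real code
    // uses the project's MUTEX_* macros rather than std::mutex.
    class JobQueue {
    public:
        int GetJobNum()
        {
            std::lock_guard<std::mutex> guard(jobQueueMutex); // unlocks on return
            return runningJobCount;
        }
    private:
        std::mutex jobQueueMutex;
        int runningJobCount = 0;
    };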
source/train/TTrain.cpp

@@ -306,7 +306,7 @@ run the neural network
 */
 bool TTModel::RunSimple(XList * inputs, XList * outputs, XList * golds, XList* losses)
 {
-    //fprintf(stderr, "run simple 0\n");
+    fprintf(stderr, "run simple 0\n");
     CheckNTErrors(inputs != NULL && inputs->count >= 1, "Wrong arguments!");
     CheckNTErrors(outputs != NULL && outputs->count >= 1, "Wrong arguments!");
     CheckNTErrors(golds != NULL && golds->count >= 1, "Wrong arguments!");

@@ -340,7 +340,7 @@ bool TTModel::RunSimple(XList * inputs, XList * outputs, XList * golds, XList* l
     delete[] dims;
 
-    //fprintf(stderr, "run simple 1\n");
+    fprintf(stderr, "run simple 1\n");
     return true;
 }
source/train/XLeader.cpp

@@ -141,6 +141,30 @@ void XLeader::SetMode(XLEADER_MODE myMode)
     mode = myMode;
 }
 
+/* set the flag of instant run */
+void XLeader::SetInstantRun(bool flag)
+{
+    for (int i = 0; i < jworkers.count; i++) {
+        XWorkerJob * worker = (XWorkerJob*)jworkers.GetItem(i);
+        worker->SetInstantRun(flag);
+    }
+
+    for (int i = 0; i < cworkers.count; i++) {
+        XWorkerJob * worker = (XWorkerJob*)cworkers.GetItem(i);
+        worker->SetInstantRun(flag);
+    }
+
+    for (int i = 0; i < uworkers.count; i++) {
+        XWorkerJob * worker = (XWorkerJob*)uworkers.GetItem(i);
+        worker->SetInstantRun(flag);
+    }
+
+    for (int i = 0; i < bworkers.count; i++) {
+        XWorkerJob * worker = (XWorkerJob*)bworkers.GetItem(i);
+        worker->SetInstantRun(flag);
+    }
+}
+
 /* start the workers */
 void XLeader::Start()
 {

@@ -368,6 +392,16 @@ void XLeader::WaitForFinishing(int sleepTime)
             }
         }
 
+        if (finished) {
+            for (int i = 0; i < bworkers.count; i++) {
+                XWorkerJob * worker = (XWorkerJob*)bworkers[i];
+                if (worker->GetJobNum() > 0) {
+                    finished = false;
+                    break;
+                }
+            }
+        }
+
         if (finished)
             break;
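
The WaitForFinishing hunk appears to close a shutdown race: the leader could see every worker state as finished while broadcast jobs were still queued. Re-checking each broadcast worker's job count before leaving the wait loop keeps the barrier honest. A condensed sketch, with a hypothetical Worker type exposing the same GetJobNum() query as XWorkerJob:

    #include <vector>

    // Sketch of the extra completion check; 'Worker' is a hypothetical
    // stand-in for XWorkerJob with the same GetJobNum() query.
    struct Worker { int pending = 0; int GetJobNum() const { return pending; } };

    bool AllFinished(bool statesFinished, const std::vector<Worker>& bworkers)
    {
        if (!statesFinished)
            return false;
        for (const Worker& w : bworkers)
            if (w.GetJobNum() > 0)    // broadcast jobs still outstanding
                return false;
        return true;
    }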
source/train/XLeader.h

@@ -123,6 +123,9 @@ public:
     /* set the communication mode */
     void SetMode(XLEADER_MODE myMode);
 
+    /* set the flag of instant run */
+    void SetInstantRun(bool flag = true);
+
     /* add a number of job workers (given their device ids) */
     void AddJobWorker(XModel * model, int n, int * ids);
source/train/XTrainer.cpp

@@ -110,6 +110,7 @@ void XTrainer::Run(XConfig * config, DataDistributeBase * dataDistributor,
     leader.AddJobCollectWorker();
     leader.AddJobUpdateWorker(model, optimizer);
     leader.AddJobBroadcastWorker();
+    //leader.SetInstantRun();
     leader.SetServerModel(config, model);
     leader.Start();
source/train/XWorker.cpp

@@ -37,6 +37,7 @@ XWorker::XWorker()
     devID = -1;
     id = -1;
     state = XWORKER_UNSTARTED;
+    isInstantRun = false;
 }
 
 /* de-constructor */

@@ -69,6 +70,12 @@ int XWorker::GetID()
     return id;
 }
 
+/* set the flag of instant run */
+void XWorker::SetInstantRun(bool flag)
+{
+    isInstantRun = flag;
+}
+
 /*
 enqueue a new job
 >> job - the job function
source/train/XWorker.h

@@ -60,6 +60,9 @@ protected:
     /* state of the worker */
     XWORKER_STATE state;
 
+    /* fire the flag of instant run */
+    bool isInstantRun;
+
 public:
     /* constructor */
     XWorker();

@@ -79,6 +82,9 @@ public:
     /* get worker id */
     int GetID();
 
+    /* set the flag of instant run */
+    void SetInstantRun(bool flag = true);
+
     /* enqueue a new job */
     void AddJob(void * job, XList * jobArgs);
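
isInstantRun gives every worker two execution paths for the same job function: run it synchronously in the calling thread (useful for debugging and single-threaded runs) or enqueue it for the worker thread. The AddJob* methods in the remaining files all adopt the same dispatch, condensed here with hypothetical Args/Job/Queue stand-ins:

    // Condensed sketch of the instant-run dispatch used by the AddJob*
    // methods below; Args, Job and Queue are hypothetical stand-ins.
    struct Args {};
    using Job = void (*)(Args*);

    struct Queue {
        void EnqueueJob(Job job, Args* args) { (void)job; (void)args; /* hand off to worker thread */ }
    };

    void Dispatch(bool isInstantRun, Queue& queue, Job job, Args* args)
    {
        if (isInstantRun)
            job(args);                   // run right now, in the caller
        else
            queue.EnqueueJob(job, args); // defer to the worker thread
    }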
source/train/XWorkerBroadcast.cpp

@@ -59,6 +59,8 @@ void XWorkerBroadcast::BroadcastData(XModel * source, XList * targetList, long s
 {
     TensorList &sp = source->params;
     int finished = 0;
+    int * finishedFlag = new int[sp.count];
+    memset(finishedFlag, 0, sizeof(int) * sp.count);
 
     /* check */
     for (int i = 0; i < targetList->count; i++) {

@@ -69,7 +71,7 @@ void XWorkerBroadcast::BroadcastData(XModel * source, XList * targetList, long s
     /* the major body of broadcasting */
     while (1) {
         for (int i = 0; i < sp.count; i++) {
-            if (source->flags[i] == PARAM_STATE_UPDATED) {
+            if (source->flags[i] == PARAM_STATE_UPDATED && finishedFlag[i] == 0) {
                 for (int j = 0; j < targetList->count; j++) {
                     XModel * target = (XModel*)targetList->GetItem(j);
                     TensorList &tp = target->params;

@@ -81,12 +83,21 @@ void XWorkerBroadcast::BroadcastData(XModel * source, XList * targetList, long s
                     target->flags[i] = PARAM_STATE_UPDATED;
                     finished++;
                 }
+                finishedFlag[i] = 1;
             }
         }
 
         if (finished == sp.count * targetList->count)
             break;
+
+#ifdef _WIN32
+        Sleep((DWORD)sleepTime);
+#else
+        sleep((unsigned)sleepTime / 1000);
+#endif
     }
+
+    delete[] finishedFlag;
 }
 
 /*

@@ -95,6 +106,7 @@ wrapper of BroadcastData
 */
 void XWorkerBroadcast::Broadcast(XList * args)
 {
+    fprintf(stderr, "broadcast 0\n");
     XWorkerBroadcast * broadcaster = (XWorkerBroadcast*)args->GetItem(0);
     XModel * source = (XModel*)args->GetItem(1);

@@ -107,6 +119,7 @@ void XWorkerBroadcast::Broadcast(XList * args)
     }
 
     broadcaster->BroadcastData(source, &target, SLEEP_TIME_IN_BROADCASTING);
+    fprintf(stderr, "broadcast 1\n");
 }
 
 /*

@@ -139,7 +152,10 @@ bool XWorkerBroadcast::AddJobBroadcast(XModel * source, XList * targetList)
     args.AddInt(targetList->count);
     args.AddList(targetList);
 
-    queue.EnqueueJob((void*)(char*)XWorkerBroadcast::Broadcast, &args);
+    if (isInstantRun)
+        XWorkerBroadcast::Broadcast(&args);
+    else
+        queue.EnqueueJob((void*)(char*)XWorkerBroadcast::Broadcast, &args);
 
     return true;
 }
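
The finishedFlag array appears aimed at making the polling loop idempotent: without it, a parameter whose flag stays PARAM_STATE_UPDATED across iterations is copied to the targets again on every pass, so finished can overshoot the sp.count * targetList->count exit condition. Recording each index once avoids that. A stripped-down sketch of the guard:

    #include <vector>

    // One polling pass with the once-only guard; updated[i] stands in for
    // source->flags[i] == PARAM_STATE_UPDATED in the real loop.
    int BroadcastPass(const std::vector<bool>& updated, std::vector<char>& done,
                      int targets, int finished)
    {
        for (int i = 0; i < (int)updated.size(); i++) {
            if (updated[i] && !done[i]) { // the new guard: fire once per index
                finished += targets;      // i.e. push parameter i to each target
                done[i] = 1;              // never re-broadcast it
            }
        }
        return finished;
    }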
source/train/XWorkerCollect.cpp

@@ -68,6 +68,8 @@ void XWorkerCollect::CollectData(XList * sourceList, XModel * target, long sleep
         CheckNTErrors(sp.count == tp.count, "Incompatiable models!");
     }
 
+    //fprintf(stderr, "collect data in 0\n");
+
     /* This is a simple implementation of the wait-and-collect process. But
        there is a risk that some models are not available, that is, the
        loop would never stop. A solution might be that we force the loop

@@ -173,6 +175,8 @@ void XWorkerCollect::CollectData(XList * sourceList, XModel * target, long sleep
 /* wrapper of CollectData */
 void XWorkerCollect::Collect(XList * args)
 {
+    fprintf(stderr, "collect data 0\n");
+
     XWorkerCollect * collecter = (XWorkerCollect*)args->GetItem(0);
     int sourceNum = args->GetItemInt(1);

@@ -187,6 +191,8 @@ void XWorkerCollect::Collect(XList * args)
     XModel * target = (XModel*)args->GetItem(2 + sourceNum);
 
     collecter->CollectData(&source, target, SLEEP_TIME_IN_COLLECTING);
+
+    fprintf(stderr, "collect data 1\n");
 }
 
 /*

@@ -253,7 +259,10 @@ bool XWorkerCollect::AddJobCollect(XList * sourceList, XModel * target)
     args.AddList(sourceList);
     args.Add(target);
 
-    queue.EnqueueJob((void*)(char*)XWorkerCollect::Collect, &args);
+    if (isInstantRun)
+        XWorkerCollect::Collect(&args);
+    else
+        queue.EnqueueJob((void*)(char*)XWorkerCollect::Collect, &args);
 
     return true;
 }

@@ -302,6 +311,8 @@ void XWorkerCollect::CollectOtherData(XList* sourceList, XNNRecord* target, long
 /* wrapper of CollectOtherData */
 void XWorkerCollect::CollectOther(XList* args)
 {
+    //fprintf(stderr, "collect data other 0\n");
+
     XWorkerCollect * collecter = (XWorkerCollect*)args->GetItem(0);
     int sourceNum = args->GetItemInt(1);

@@ -316,6 +327,8 @@ void XWorkerCollect::CollectOther(XList* args)
     XNNRecord * target = (XNNRecord*)args->GetItem(2 + sourceNum);
 
     collecter->CollectOtherData(&source, target, SLEEP_TIME_IN_COLLECTING_OTHER);
+
+    //fprintf(stderr, "collect data other 1\n");
 }
 
 /*

@@ -335,7 +348,10 @@ bool XWorkerCollect::AddJobCollectOther(XList* sourceList, XNNRecord* target)
     args.AddList(sourceList);
     args.Add(target);
 
-    queue.EnqueueJob((void*)(char*)XWorkerCollect::CollectOther, &args);
+    if (isInstantRun)
+        XWorkerCollect::CollectOther(&args);
+    else
+        queue.EnqueueJob((void*)(char*)XWorkerCollect::CollectOther, &args);
 
     return true;
 }
source/train/XWorkerJob.cpp

@@ -180,12 +180,19 @@ add a new job of model refreshment
 */
 bool XWorkerJob::AddJobRefresh(XModel * myModel)
 {
+    //fprintf(stderr, "refresh 0\n");
+
     CheckNTErrors(myModel != NULL, "no parameter keeper!");
 
     XList args(1);
     args.Add(myModel);
-    queue.EnqueueJob((void*)(char*)XModel::Refresh, &args);
+
+    if (isInstantRun)
+        XModel::Refresh(&args);
+    else
+        queue.EnqueueJob((void*)(char*)XModel::Refresh, &args);
+
+    //fprintf(stderr, "refresh 1\n");
 
     return true;
 }

@@ -213,7 +220,10 @@ bool XWorkerJob::AddJobNeuralNet(XModel * myModel,
     args.Add(golds);
     args.Add(losses);
 
-    queue.EnqueueJob((void*)(char*)XModel::Run, &args);
+    if (isInstantRun)
+        XModel::Run(&args);
+    else
+        queue.EnqueueJob((void*)(char*)XModel::Run, &args);
 
     SetState(XWORKER_STARTED);

@@ -226,7 +236,10 @@ bool XWorkerJob::AddJobRecord()
     XList args;
     args.Add(this);
 
-    queue.EnqueueJob((void*)(char*)XWorkerJob::RecordMeStatic, &args);
+    if (isInstantRun)
+        XWorkerJob::RecordMeStatic(&args);
+    else
+        queue.EnqueueJob((void*)(char*)XWorkerJob::RecordMeStatic, &args);
 
     return true;
 }

@@ -234,6 +247,8 @@ bool XWorkerJob::AddJobRecord()
 /* wrapper of RecordMe */
 void XWorkerJob::RecordMeStatic(XList* args)
 {
+    //fprintf(stderr, "record static 0\n");
+
     CheckNTErrors(args != NULL && args->count > 0, "Illegal arguments!");
 
     XWorkerJob * worker = (XWorkerJob*)args->GetItem(0);

@@ -241,6 +256,8 @@ void XWorkerJob::RecordMeStatic(XList* args)
     worker->RecordMe();
     worker->SetState(XWORKER_FINISHED);
+
+    //fprintf(stderr, "record static 1\n");
 }
 
 } /* end of the nts (NiuTrans.Tensor) namespace */
source/train/XWorkerUpdate.cpp

@@ -81,7 +81,6 @@ void XWorkerUpdate::UpdateModel(XModel * model, XOptimizer * optimizer, long sle
                 flags[i] = PARAM_STATE_UPDATED;
                 finished++;
             }
         }
-
         if (finished == params.count)

@@ -103,6 +102,8 @@ wrapper of UpdateModel
 */
 void XWorkerUpdate::Update(XList * args)
 {
+    fprintf(stderr, "update 0\n");
+
     CheckNTErrors(args != NULL && args->count >= 3, "Illegal argument list!");
 
     XWorkerUpdate * updater = (XWorkerUpdate*)args->GetItem(0);

@@ -110,6 +111,8 @@ void XWorkerUpdate::Update(XList * args)
     XOptimizer * optimizer = (XOptimizer*)args->GetItem(2);
 
     updater->UpdateModel(model, optimizer, SLEEP_TIME_IN_MODEL_UPDATE);
+
+    fprintf(stderr, "update 1\n");
 }
 
 /*

@@ -127,7 +130,10 @@ bool XWorkerUpdate::AddJobUpdate(XModel * model, XOptimizer * optimizer)
     args.Add(model);
     args.Add(optimizer);
 
-    queue.EnqueueJob((void*)(char*)XWorkerUpdate::Update, &args);
+    if (isInstantRun)
+        XWorkerUpdate::Update(&args);
+    else
+        queue.EnqueueJob((void*)(char*)XWorkerUpdate::Update, &args);
 
     return true;
 }