Skip to content
项目
群组
代码片段
帮助
当前项目
正在载入...
登录 / 注册
切换导航面板
N
NiuTrans.Tensor
概览
Overview
Details
Activity
Cycle Analytics
版本库
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Charts
问题
0
Issues
0
列表
Board
标记
里程碑
合并请求
0
Merge Requests
0
CI / CD
CI / CD
流水线
作业
日程表
图表
维基
Wiki
代码片段
Snippets
成员
Collapse sidebar
Close sidebar
活动
图像
聊天
创建新问题
作业
提交
Issue Boards
Open sidebar
Emmay
NiuTrans.Tensor
Commits
8d1ae93b
Commit
8d1ae93b
authored
Aug 05, 2018
by
xiaotong
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
renaming and bug fixes
parent
90dc67f2
隐藏空白字符变更
内嵌
并排
正在显示
8 个修改的文件
包含
44 行增加
和
30 行删除
+44
-30
source/network/XBackwardMath.cpp
+15
-8
source/sample/transformer/T2TEmbedding.cpp
+1
-1
source/sample/transformer/T2TEncoder.cpp
+2
-2
source/sample/transformer/T2TFNN.cpp
+2
-2
source/sample/transformer/T2TLayerNormal.cpp
+1
-1
source/sample/transformer/T2TTrainer.cpp
+3
-10
source/tensor/core/math/Power.cpp
+6
-2
source/tensor/core/math/Power.cu
+14
-4
没有找到文件。
source/network/XBackwardMath.cpp
查看文件 @
8d1ae93b
...
...
@@ -459,7 +459,7 @@ gradient for power
for
c = pow(a,p)
we have
dE/da = (dE/dc) * p
*
a^(p-1)
dE/da = (dE/dc) * p
*
a^(p-1)
>> node - the node (c) for backward computation
*/
void
XMathGrad
::
GradPower
(
XTensor
*
node
)
...
...
@@ -942,10 +942,10 @@ void XMathGrad::GradReduceSum(XTensor * node)
/*
gradient for reduceSumSquared
for
c =
reduceSumSquared(a, dim, b)
c =
\sum_i (a_i - b)^2
we have
dE/da = Unsqueeze(dE/dc) * 2a
dE/db =
Unsqueeze(dE/dc) * (-2b)
dE/db =
dE/dc * -2 * n * b
>> node - the node (c) for backward computation
*/
void
XMathGrad
::
GradReduceSumSquared
(
XTensor
*
node
)
...
...
@@ -964,10 +964,13 @@ void XMathGrad::GradReduceSumSquared(XTensor * node)
XNoder
::
MakeGrad
(
a
);
XNoder
::
MakeGrad
(
b
);
/* dE/da = Unsqueeze(dE/dc) * 2a */
_ScaleAndShift
(
a
,
c
,
2.0
F
);
_ScaleAndShift
(
b
,
d
,
-
2.0
F
);
_Unsqueeze
(
node
->
grad
,
e
,
dim
,
n
);
_Multiply
(
e
,
c
,
a
->
grad
,
1.0
F
);
/* dE/db = dE/dc * -2 * n * b */
_ScaleAndShift
(
b
,
d
,
-
2.0
F
*
n
);
_Multiply
(
node
->
grad
,
d
,
b
->
grad
,
1.0
F
);
DelTensorBuf
(
c
);
...
...
@@ -980,10 +983,11 @@ void XMathGrad::GradReduceSumSquared(XTensor * node)
/*
gradient for reduceVariance
for
c = reduceVariance(a, dim, b)
c = (sum_i (a_i - b)^2) * 1/n
where b is the mean, and n is the size of a
we have
dE/da = Unsqueeze(dE/dc) * 2a/
dimSizeA[dim]
dE/db =
Unsqueeze(dE/dc) * (-2a/dimSizeA[dim])
dE/da = Unsqueeze(dE/dc) * 2a/
n
dE/db =
dE/dc * -2 * b
>> node - the node (c) for backward computation
*/
void
XMathGrad
::
GradReduceVariance
(
XTensor
*
node
)
...
...
@@ -1002,10 +1006,13 @@ void XMathGrad::GradReduceVariance(XTensor * node)
XNoder
::
MakeGrad
(
a
);
XNoder
::
MakeGrad
(
b
);
/* dE/da = Unsqueeze(dE/dc) * 2a/n */
_ScaleAndShift
(
a
,
c
,
2.0
F
/
n
);
_ScaleAndShift
(
b
,
d
,
-
2.0
F
/
n
);
_Unsqueeze
(
node
->
grad
,
e
,
dim
,
n
);
_Multiply
(
e
,
c
,
a
->
grad
,
1.0
F
);
/* dE/db = dE/dc * -2 * b */
_ScaleAndShift
(
b
,
d
,
-
2.0
F
);
_Multiply
(
node
->
grad
,
d
,
b
->
grad
,
1.0
F
);
DelTensorBuf
(
c
);
...
...
source/sample/transformer/T2TEmbedding.cpp
查看文件 @
8d1ae93b
...
...
@@ -62,7 +62,7 @@ void T2TEmbedder::InitModel(int argc, const char ** argv, int myDevID, XMem * my
InitTensor2D
(
&
w
,
vSize
,
eSize
,
X_FLOAT
,
devID
,
mem
);
w
.
SetDataRandn
(
0
,
1
/
(
float
)
sqrt
((
float
)
eSize
));
w
.
SetDataRandn
(
0
,
1
.0
F
/
(
float
)
sqrt
((
float
)
eSize
));
/* create the positional embedding matrix */
MakePosEmbedding
(
eSize
,
d
,
maxLength
);
...
...
source/sample/transformer/T2TEncoder.cpp
查看文件 @
8d1ae93b
...
...
@@ -53,13 +53,13 @@ void AttEncoder::InitModel(int argc, const char ** argv, int myDevID, XMem * myM
devID
=
myDevID
;
mem
=
myMem
;
LoadParamInt
(
argc
,
argv
,
"n
stack
"
,
&
nlayer
,
6
);
LoadParamInt
(
argc
,
argv
,
"n
layer
"
,
&
nlayer
,
6
);
LoadParamInt
(
argc
,
argv
,
"hsize"
,
&
hSize
,
512
);
LoadParamInt
(
argc
,
argv
,
"esize"
,
&
eSize
,
512
);
LoadParamInt
(
argc
,
argv
,
"vsize"
,
&
vSize
,
-
1
);
CheckNTErrors
(
nlayer
>
1
,
"We have one encoding layer at least!"
);
CheckNTErrors
(
nlayer
>
=
1
,
"We have one encoding layer at least!"
);
CheckNTErrors
(
vSize
>
1
,
"set vocabulary size by
\"
-vsize
\"
"
);
/* embedding model */
...
...
source/sample/transformer/T2TFNN.cpp
查看文件 @
8d1ae93b
...
...
@@ -88,10 +88,10 @@ XTensor T2TFNN::Make(XTensor &input)
XTensor
t1
;
/* t1 = max(0, x * w1 + b1) */
t1
=
Rectify
(
MMul
(
input
,
X_NOTRANS
,
w1
,
X_NOTRANS
)
+
b1
);
t1
=
Rectify
(
MMul
(
input
,
w1
)
+
b1
);
/* result = t1 * w2 + b2 */
return
MMul
(
t1
,
X_NOTRANS
,
w2
,
X_NOTRANS
)
+
b2
;
return
MMul
(
t1
,
w2
)
+
b2
;
}
...
...
source/sample/transformer/T2TLayerNormal.cpp
查看文件 @
8d1ae93b
...
...
@@ -76,7 +76,7 @@ XTensor T2TLN::Make(XTensor &input)
standard
=
Power
(
variance
,
0.5
F
);
/* unsqueeze mean and standard deviation to fit them into
the same s
iz
e of x */
the same s
hap
e of x */
meanFilled
=
Unsqueeze
(
mean
,
x
.
order
-
1
,
x
.
GetDim
(
-
1
));
standardFilled
=
Unsqueeze
(
standard
,
x
.
order
-
1
,
x
.
GetDim
(
-
1
));
...
...
source/sample/transformer/T2TTrainer.cpp
查看文件 @
8d1ae93b
...
...
@@ -342,6 +342,9 @@ void T2TTrainer::Update(T2TModel * model, const float lr)
ws
.
Add
(
&
model
->
encoder
.
fnns
[
i
].
b1
);
ws
.
Add
(
&
model
->
encoder
.
fnns
[
i
].
w2
);
ws
.
Add
(
&
model
->
encoder
.
fnns
[
i
].
b2
);
ws
.
Add
(
&
model
->
encoder
.
attentions
[
i
].
wk
);
ws
.
Add
(
&
model
->
encoder
.
attentions
[
i
].
wq
);
ws
.
Add
(
&
model
->
encoder
.
attentions
[
i
].
wv
);
}
ws
.
Add
(
&
model
->
encoder
.
embedder
.
w
);
...
...
@@ -352,16 +355,6 @@ void T2TTrainer::Update(T2TModel * model, const float lr)
CheckNTErrors
(
para
!=
NULL
,
"NULL parameter tensor!"
);
CheckNTErrors
(
paraGrad
!=
NULL
,
"NULL gradient tensor!"
);
/*DTYPE * d = (DTYPE*)paraGrad->data;
for(int i = 0; i < paraGrad->unitNum; i++){
if(IsINF(d[i])){
fprintf(stderr, "isinf %d\n", i);
}
if(IsNAN(d[i])){
fprintf(stderr, "isnan %d\n", i);
}
}*/
/* the delta rule */
_Sum
(
para
,
paraGrad
,
para
,
-
lr
);
...
...
source/tensor/core/math/Power.cpp
查看文件 @
8d1ae93b
...
...
@@ -60,8 +60,12 @@ void _Power(const XTensor * a, XTensor * b, DTYPE p)
bData
[
i
]
=
aData
[
i
]
*
aData
[
i
];
}
else
{
for
(
int
i
=
0
;
i
<
a
->
unitNum
;
i
++
)
bData
[
i
]
=
(
DTYPE
)
pow
(
aData
[
i
],
p
);
for
(
int
i
=
0
;
i
<
a
->
unitNum
;
i
++
)
{
if
(
p
<
0
&&
aData
[
i
]
==
0
)
bData
[
i
]
=
1e20
F
;
else
bData
[
i
]
=
(
DTYPE
)
pow
(
aData
[
i
],
p
);
}
}
}
...
...
source/tensor/core/math/Power.cu
查看文件 @
8d1ae93b
...
...
@@ -77,8 +77,13 @@ void KernelPower(DTYPE * a, DTYPE * b, DTYPE p, int size)
{
int i = blockDim.x * blockIdx.x + threadIdx.x;
if (i < size)
b[i] = pow(a[i], p);
if (i < size) {
DTYPE v = a[i];
if (p < 0 && v == 0)
b[i] = 1e20;
else
b[i] = pow(a[i], p);
}
}
/*
...
...
@@ -94,8 +99,13 @@ void KernelPower(__half * a, __half * b, __half p, int size)
#if __CUDA_ARCH__ >= 530 || !defined(__CUDA_ARCH__)
#else
int i = blockDim.x * blockIdx.x + threadIdx.x;
if (i < size)
b[i] = __float2half(pow(__half2float(a[i]), __half2float(p)));
if (i < size) {
float v = __half2float(a[i]);
if (__half2float(p) < 0 && v == 0)
b[i] = __float2half(1e20);
else
b[i] = __float2half(pow(__half2float(a[i]), __half2float(p)));
}
#endif
}
...
...
编写
预览
Markdown
格式
0%
重试
或
添加新文件
添加附件
取消
您添加了
0
人
到此讨论。请谨慎行事。
请先完成此评论的编辑!
取消
请
注册
或者
登录
后发表评论