杨迪 / NiuTrans.Tensor / Commits / dea97945

Commit dea97945 authored 5 years ago by xiaotong
bug fixes
parent fde02e21
Showing 7 changed files with 31 additions and 17 deletions.
source/sample/transformer/T2TLengthPenalty.cpp   +0 -1
source/sample/transformer/T2TPredictor.cpp       +2 -0
source/sample/transformer/T2TSearch.cpp          +15 -5
source/sample/transformer/T2TSearch.h            +3 -2
source/sample/transformer/T2TTester.cpp          +3 -1
source/tensor/core/getandset/SetData.cpp         +2 -2
source/tensor/core/getandset/SetData.cu          +6 -6
source/sample/transformer/T2TLengthPenalty.cpp (view file @ dea97945)

@@ -36,7 +36,6 @@ XTensor T2TLengthPenalizer::GNMT(const XTensor & length, float alpha)
     XTensor lp;
     base = (length + 5)/(1 + 5);
     lp = Power(base, alpha);
     return lp;
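For reference, the tensor expression above is the GNMT length penalty, lp = ((length + 5) / (1 + 5)) ^ alpha, computed element-wise over the per-hypothesis lengths. A minimal scalar sketch of the same formula, not part of the commit (the function name is hypothetical):

    #include <cmath>

    // Hypothetical scalar version of the GNMT length penalty used above:
    // lp = ((length + 5) / (1 + 5)) ^ alpha, which grows with hypothesis length.
    float GNMTLengthPenaltyScalar(float length, float alpha)
    {
        float base = (length + 5.0F) / (1.0F + 5.0F);
        return std::pow(base, alpha);
    }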
source/sample/transformer/T2TPredictor.cpp (view file @ dea97945)

@@ -227,6 +227,8 @@ void T2TPredictor::Predict(T2TStateBundle * next, XTensor * encoding,
     /* generate the output probabilities */
     m->outputLayer->Make(decodingStep, output);
+    _LogMe(&output);
     next->layersEnc.AddList(&s->layersEnc);
     next->layersDec.Add(&inputDec);
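The added _LogMe(&output) call presumably moves the output distribution into log space before the search scores it, so per-step scores can be accumulated by addition rather than multiplication. A standalone sketch of that idea (hypothetical names, not code from this commit):

    #include <cmath>
    #include <vector>

    // Hypothetical illustration: summing log-probabilities is numerically safer
    // than multiplying raw probabilities, which underflows for long sequences.
    float AccumulateLogScore(const std::vector<float> & stepProbs)
    {
        float score = 0.0F;
        for (float p : stepProbs)
            score += std::log(p);
        return score;
    }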
source/sample/transformer/T2TSearch.cpp (view file @ dea97945)

@@ -73,8 +73,10 @@ search for the most promising states
 >> input - input of the model
 >> padding - padding of the input
 >> output - output that represents the sequences as rows
+>> score - score of the sequences
 */
-void T2TSearch::Search(T2TModel * model, XTensor * input, XTensor * padding, XTensor * output)
+void T2TSearch::Search(T2TModel * model, XTensor * input, XTensor * padding, XTensor * output, XTensor * score)
 {
     T2TPredictor predictor;
     XTensor maskEnc;

@@ -155,7 +157,7 @@ void T2TSearch::Search(T2TModel * model, XTensor * input, XTensor * padding, XTe
     /* fill the heap with imcomplete hypotheses if neccesary */
     FillHeap(next);

-    Dump(output);
+    Dump(output, score);

     delete[] states;
 }

@@ -237,6 +239,7 @@ void T2TSearch::Score(T2TStateBundle * prev, T2TStateBundle * beam)
     InitTensor(&mask, prev->endMark.order, prev->endMark.dimSize, X_FLOAT, 1.0F, prev->endMark.devID, prev->endMark.mem);
     mask.SetZeroAll();
     _SetDataFixedCond(&mask, &prev->endMark, -1e9F);
     mask.Reshape(mask.unitNum);

@@ -514,23 +517,28 @@ void T2TSearch::FillHeap(T2TStateBundle * beam)
 /*
 save the output sequences in a tensor
 >> output - output sequences (for return)
+>> score - score of thes sequences
 */
-void T2TSearch::Dump(XTensor * output)
+void T2TSearch::Dump(XTensor * output, XTensor * score)
 {
     int dims[3] = {batchSize, beamSize, maxLength};
     int * words = new int[maxLength];

     InitTensor(output, 3, dims, X_INT);
+    InitTensor(score, 2, dims, X_FLOAT);
     SetDataFixedInt(*output, -1);
+    score->SetZeroAll();

     /* heap for an input sentence in the batch */
     for(int h = 0; h < batchSize; h++){
         XHeap<MIN_HEAP, float> & heap = fullHypos[h];
+        int c = heap.Count();

         /* for each output in the beam */
         for(int i = 0; i < beamSize && heap.Count() > 0; i++){
-            T2TState * state = (T2TState *)heap.Pop().index;
+            HeapNode<float> node = heap.Pop();
+            T2TState * state = (T2TState *)node.index;
             int count = 0;
             bool isCompleted = true;

@@ -548,7 +556,9 @@ void T2TSearch::Dump(XTensor * output)
             /* dump the sentence to the output tensor */
             for(int w = 0; w < count; w++)
-                output->Set3DInt(words[count - w - 1], h, beamSize - i - 1, w);
+                output->Set3DInt(words[count - w - 1], h, c - i - 1, w);
+
+            score->Set2D(node.value, h, c - i - 1);
         }
     }
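With this change, Dump writes each surviving hypothesis of sentence h into beam slot c - i - 1 of an int tensor shaped [batchSize, beamSize, maxLength], and its heap score into the matching cell of a float tensor shaped [batchSize, beamSize]. A standalone sketch of the equivalent row-major indexing (hypothetical helpers, assuming a dense row-major layout):

    // Hypothetical helpers mirroring the layout that Dump fills above.
    // output: [batchSize, beamSize, maxLength] ints; score: [batchSize, beamSize] floats.
    int OutputOffset(int h, int slot, int w, int beamSize, int maxLength)
    {
        return (h * beamSize + slot) * maxLength + w;  // word w of the hypothesis in this slot for sentence h
    }

    int ScoreOffset(int h, int slot, int beamSize)
    {
        return h * beamSize + slot;                    // score of the hypothesis in this slot for sentence h
    }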
source/sample/transformer/T2TSearch.h (view file @ dea97945)

@@ -73,7 +73,8 @@ public:
     void Init(int argc, char ** argv);

     /* search for the most promising states */
-    void Search(T2TModel * model, XTensor * input, XTensor * padding, XTensor * output);
+    void Search(T2TModel * model, XTensor * input, XTensor * padding, XTensor * output, XTensor * score);

     /* preparation */
     void Prepare(int myBatchSize, int myBeamSize);

@@ -94,7 +95,7 @@ public:
     void FillHeap(T2TStateBundle * beam);

     /* save the output sequences in a tensor */
-    void Dump(XTensor * output);
+    void Dump(XTensor * output, XTensor * score);

     /* check if the token is an end symbol */
     bool IsEnd(int token);
source/sample/transformer/T2TTester.cpp (view file @ dea97945)

@@ -112,10 +112,12 @@ void T2TTester::Test(const char * fn, const char * ofn, T2TModel * model)
     CheckNTErrors(!model->isLM, "Only MT model is supported!");

     XTensor output;
+    XTensor score;

-    seacher.Search(model, &batchEnc, &paddingEnc, &output);
+    seacher.Search(model, &batchEnc, &paddingEnc, &output, &score);

     Dump(ofile, &output);
+    //score.Dump(ofile, "score:");

     float prob = 0;
source/tensor/core/getandset/SetData.cpp (view file @ dea97945)

@@ -238,12 +238,12 @@ void _SetDataFixedCond(XTensor * tensor, XTensor * condition, DTYPE p)
     int num = tensor->unitNum;

     CheckNTErrors(num == condition->unitNum, "Wrong size of the condition tensor!");
-    CheckNTErrors(condition->unitSize == sizeof(float), "TODO!");
+    CheckNTErrors(condition->unitSize == sizeof(int), "TODO!");

     if(tensor->dataType == DEFAULT_DTYPE){
         if(tensor->devID < 0){
             DTYPE * data = (DTYPE*)tensor->data;
-            DTYPE * cond = (DTYPE*)condition->data;
+            int * cond = (int*)condition->data;
             for(int i = 0; i < num; i++){
                 if(cond[i] != 0)
                     data[i] = p;
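After this fix, the condition tensor passed to _SetDataFixedCond is expected to hold int values, with non-zero entries selecting the positions to overwrite, matching how prev->endMark is used in T2TSearch::Score. A minimal CPU-side usage sketch, assuming the NiuTensor tensor headers and namespace are in scope and with made-up shapes:

    // Hypothetical usage sketch (shapes invented for illustration).
    int dims[2] = {4, 8};

    XTensor mask;
    XTensor condition;
    InitTensor(&mask, 2, dims, X_FLOAT);      // float tensor to be partially overwritten
    InitTensor(&condition, 2, dims, X_INT);   // int condition tensor, as required after this commit

    mask.SetZeroAll();
    condition.SetZeroAll();
    ((int*)condition.data)[3] = 1;            // select one position (flat index 3)

    _SetDataFixedCond(&mask, &condition, -1e9F);  // only the selected position becomes -1e9F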
source/tensor/core/getandset/SetData.cu (view file @ dea97945)

@@ -159,7 +159,7 @@ if the condition entry is non-zero
 >> p - the initial value
 */
 __global__
-void KernelSetDataFixedCondFloat(float * d, float * c, int size, float p)
+void KernelSetDataFixedCondFloat(float * d, int * c, int size, float p)
 {
     int i = blockDim.x * blockIdx.x + threadIdx.x;

@@ -178,7 +178,7 @@ if the condition entry is non-zero
 void _CudaSetDataFixedCondFloat(XTensor * tensor, XTensor * condition, float p)
 {
     CheckNTErrors(tensor->dataType == X_FLOAT, "the tensor must be in X_FLOAT!");
-    CheckNTErrors(condition->unitSize == sizeof(float), "TODO!");
+    CheckNTErrors(condition->unitSize == sizeof(int), "TODO!");

     int gridSize[3];
     int blockSize[3];

@@ -191,7 +191,7 @@ void _CudaSetDataFixedCondFloat(XTensor * tensor, XTensor * condition, float p)
     int devIDBackup;
     ProtectCudaDev(tensor->devID, devIDBackup);

-    KernelSetDataFixedCondFloat <<<blocks, threads >>>((float*)tensor->data, (float*)condition->data,
+    KernelSetDataFixedCondFloat <<<blocks, threads >>>((float*)tensor->data, (int*)condition->data,
                                                        tensor->unitNum, p);

     BacktoCudaDev(tensor->devID, devIDBackup);

@@ -206,7 +206,7 @@ if the condition entry is non-zero
 >> p - the initial value
 */
 __global__
-void KernelSetDataFixedCondInt(int * d, float * c, int size, int p)
+void KernelSetDataFixedCondInt(int * d, int * c, int size, int p)
 {
     int i = blockDim.x * blockIdx.x + threadIdx.x;

@@ -225,7 +225,7 @@ if the condition entry is non-zero
 void _CudaSetDataFixedCondInt(XTensor * tensor, XTensor * condition, int p)
 {
     CheckNTErrors(tensor->dataType == X_FLOAT, "the tensor must be in X_FLOAT!");
-    CheckNTErrors(condition->unitSize == sizeof(float), "TODO!");
+    CheckNTErrors(condition->unitSize == sizeof(int), "TODO!");

     int gridSize[3];
     int blockSize[3];

@@ -238,7 +238,7 @@ void _CudaSetDataFixedCondInt(XTensor * tensor, XTensor * condition, int p)
     int devIDBackup;
     ProtectCudaDev(tensor->devID, devIDBackup);

-    KernelSetDataFixedCondInt <<<blocks, threads >>>((int*)tensor->data, (float*)condition->data,
+    KernelSetDataFixedCondInt <<<blocks, threads >>>((int*)tensor->data, (int*)condition->data,
                                                      tensor->unitNum, p);

     BacktoCudaDev(tensor->devID, devIDBackup);
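The GPU path now reads the condition as int* as well. The kernel bodies are collapsed beyond the thread-index line, but based on the CPU loop in SetData.cpp the per-thread work presumably looks like the sketch below (a guess at the unshown body, not the committed code):

    // Hypothetical sketch of the per-thread work after the fix: the condition
    // array is int-typed, and non-zero entries select elements to overwrite.
    __global__
    void KernelSetDataFixedCondFloatSketch(float * d, int * c, int size, float p)
    {
        int i = blockDim.x * blockIdx.x + threadIdx.x;
        if (i < size && c[i] != 0)
            d[i] = p;
    }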