Skip to content
项目
群组
代码片段
帮助
当前项目
正在载入...
登录 / 注册
切换导航面板
N
NiuTrans.Tensor
概览
Overview
Details
Activity
Cycle Analytics
版本库
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Charts
问题
0
Issues
0
列表
Board
标记
里程碑
合并请求
0
Merge Requests
0
CI / CD
CI / CD
流水线
作业
日程表
图表
维基
Wiki
代码片段
Snippets
成员
Collapse sidebar
Close sidebar
活动
图像
聊天
创建新问题
作业
提交
Issue Boards
Open sidebar
杨迪
NiuTrans.Tensor
Commits
6f90577d
Commit
6f90577d
authored
6 years ago
by
xiaotong
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
fix the bug in sorting sequences
parent
a037d802
隐藏空白字符变更
内嵌
并排
正在显示
2 个修改的文件
包含
8 行增加
和
7 行删除
+8
-7
source/sample/transformer/T2TModel.cpp
+2
-2
source/sample/transformer/T2TTrainer.cpp
+6
-5
没有找到文件。
source/sample/transformer/T2TModel.cpp
查看文件 @
6f90577d
...
...
@@ -194,8 +194,8 @@ void T2TModel::MakeMT(XTensor &inputEnc, XTensor &inputDec, XTensor &output, XTe
InitTensor
(
&
maskDec
,
inputDec
.
order
+
1
,
dims
,
X_FLOAT
,
1.0
F
,
inputDec
.
devID
,
inputDec
.
mem
);
/* a upper triangular matrix where the cells of the upper triangular are set to -1e-9.
this matrix can be used to prevent the attention to current or following words in
a given sequence. */
this matrix can be used to prevent the attention to current or following words in
a given sequence. */
_SetDataLowTri
(
&
maskDec
,
1e9
F
,
0
);
_ScaleAndShiftMe
(
&
maskDec
,
1.0
F
,
-
1e9
F
);
...
...
This diff is collapsed.
Click to expand it.
source/sample/transformer/T2TTrainer.cpp
查看文件 @
6f90577d
...
...
@@ -525,6 +525,7 @@ int T2TTrainer::LoadBuf(FILE * file, bool isSorted, int step)
/* sort the sequences by length */
if
(
isSorted
)
{
CheckNTErrors
(
seqCount
%
step
==
0
,
"Wrong number of sequences!"
);
SampleNode
*
nodes
=
new
SampleNode
[
seqCount
];
int
count
=
0
;
int
offset
=
0
;
...
...
@@ -540,18 +541,18 @@ int T2TTrainer::LoadBuf(FILE * file, bool isSorted, int step)
offset
+=
node
.
size
;
}
qsort
(
nodes
,
seqC
ount
,
sizeof
(
SampleNode
),
CompareSampleNode
);
qsort
(
nodes
,
c
ount
,
sizeof
(
SampleNode
),
CompareSampleNode
);
count
=
0
;
offset
=
0
;
for
(
int
i
=
0
;
i
<
seqCount
;
i
++
){
for
(
int
i
=
0
;
i
<
seqCount
;
i
+=
step
){
SampleNode
&
node
=
nodes
[
count
];
memcpy
(
buf2
+
offset
,
node
.
p
,
sizeof
(
int
)
*
node
.
size
);
for
(
int
j
=
0
;
j
<
step
;
j
++
){
seqLen2
[
count
+
j
]
=
seqLen
[
node
.
id
+
j
];
seqOffset
[
count
+
j
]
=
offset
+
(
j
>
0
?
seqLen
[
node
.
id
+
j
-
1
]
:
0
);
seqLen2
[
i
+
j
]
=
seqLen
[
node
.
id
+
j
];
seqOffset
[
i
+
j
]
=
offset
+
(
j
>
0
?
seqLen
[
node
.
id
+
j
-
1
]
:
0
);
}
count
+=
step
;
count
+=
1
;
offset
+=
node
.
size
;
}
...
...
This diff is collapsed.
Click to expand it.
编写
预览
Markdown
格式
0%
重试
或
添加新文件
添加附件
取消
您添加了
0
人
到此讨论。请谨慎行事。
请先完成此评论的编辑!
取消
请
注册
或者
登录
后发表评论