libei / WMT19-1.0.14

Commit c31fd16c authored Apr 18, 2019 by libei
add conv-self attention
parent 868f56a6
Showing 1 changed file with 2 additions and 4 deletions
tensor2tensor/models/common_attention.py  +2 -4
@@ -382,7 +382,6 @@ def multihead_attention(query_antecedent,
                         dropout_rate,
                         attention_type="dot_product",
                         max_relative_length=16,
-                        window_size=-1,
                         conv_mask=None,
                         summaries=False,
                         image_shapes=None,
@@ -467,7 +466,7 @@ def multihead_attention(query_antecedent,
          q, k, v, bias, max_relative_length, dropout_rate, summaries,
          image_shapes, dropout_broadcast_dims=dropout_broadcast_dims)
    elif attention_type == "convolutional_self_attention":
-     x = convolutional_self_attention(q, k, v, bias, dropout_rate, summaries,
-                                      image_shapes,
-                                      dropout_broadcast_dims=dropout_broadcast_dims,
-                                      window_size=window_size, conv_mask=conv_mask)
+     x = convolutional_self_attention(q, k, v, bias, dropout_rate, summaries,
+                                      image_shapes,
+                                      dropout_broadcast_dims=dropout_broadcast_dims,
+                                      conv_mask=conv_mask)
    else:
      raise ValueError
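The conv_mask forwarded in the call above is an additive bias on the attention logits; its use appears in the last hunk below (logits += conv_mask). The helper here is a hypothetical illustration, not part of this commit: assuming the usual tensor2tensor logit shape [batch, heads, length, length], it builds a bias that is 0 inside a band of half-width window_size // 2 around the diagonal and a large negative number elsewhere, so the softmax assigns near-zero weight to positions outside the local window.

import numpy as np
import tensorflow as tf

def local_attention_bias(length, window_size):
  # Hypothetical sketch of a conv_mask; the fork builds its mask elsewhere.
  # band[i, j] == 1 iff |i - j| <= window_size // 2.
  half = window_size // 2
  band = np.tril(np.triu(np.ones([length, length]), -half), half)
  # 0 inside the window, -1e9 outside; shaped [1, 1, length, length]
  # so it broadcasts over batch and heads like the regular bias does.
  return tf.constant(((1.0 - band) * -1e9)[None, None, :, :],
                     dtype=tf.float32)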
@@ -852,7 +851,6 @@ def convolutional_self_attention(q,
                                  summaries=False,
                                  image_shapes=None,
                                  dropout_broadcast_dims=None,
-                                 window_size=-1,
                                  conv_mask=None,
                                  name=None):
   """dot-product attention.
@@ -877,7 +875,7 @@ def convolutional_self_attention(q,
     logits = tf.matmul(q, k, transpose_b=True)
     if bias is not None:
       logits += bias
-    if window_size != -1:
+    if conv_mask is not None:
       logits += conv_mask
     weights = tf.nn.softmax(logits, name="attention_weights")
     # broadcast dropout can save memory
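After this commit, the core of convolutional_self_attention is ordinary dot-product attention with one extra additive mask. A minimal self-contained sketch of that code path (the real function also handles summaries, image_shapes, and broadcast dropout, which are omitted here):

def convolutional_self_attention_sketch(q, k, v, bias=None, conv_mask=None):
  # q, k, v: [batch, heads, length, depth]
  logits = tf.matmul(q, k, transpose_b=True)  # [batch, heads, len_q, len_k]
  if bias is not None:
    logits += bias          # e.g. padding or causal bias
  if conv_mask is not None:
    logits += conv_mask     # local-window bias, e.g. local_attention_bias(...)
  weights = tf.nn.softmax(logits, name="attention_weights")
  return tf.matmul(weights, v)

With a conv_mask like the earlier band-shaped sketch, each position can only attend within its local window, which is the convolution-like behavior the commit message refers to.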