Skip to content
项目
群组
代码片段
帮助
当前项目
正在载入...
登录 / 注册
切换导航面板
F
Fairseq-S2T
概览
Overview
Details
Activity
Cycle Analytics
版本库
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Charts
问题
0
Issues
0
列表
Board
标记
里程碑
合并请求
0
Merge Requests
0
CI / CD
CI / CD
流水线
作业
日程表
图表
维基
Wiki
代码片段
Snippets
成员
Collapse sidebar
Close sidebar
活动
图像
聊天
创建新问题
作业
提交
Issue Boards
Open sidebar
xuchen
Fairseq-S2T
Commits
0e2452b9
Commit
0e2452b9
authored
Apr 10, 2021
by
xuchen
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
fix the bug of the relative multihead attention
parent
61cf1afa
隐藏空白字符变更
内嵌
并排
正在显示
4 个修改的文件
包含
35 行增加
和
8 行删除
+35
-8
fairseq/models/speech_to_text/s2t_transformer.py
+4
-2
fairseq/models/transformer.py
+27
-0
fairseq/modules/layer_history.py
+1
-1
fairseq/modules/relative_multihead_attention.py
+3
-5
没有找到文件。
fairseq/models/speech_to_text/s2t_transformer.py
查看文件 @
0e2452b9
...
...
@@ -147,7 +147,8 @@ class S2TTransformerModel(FairseqEncoderDecoderModel):
default
=
"selfattn"
,
choices
=
[
"selfattn"
,
"rel_selfattn"
"rel_selfattn"
,
"relative"
,
],
help
=
"transformer encoder self-attention layer type"
)
...
...
@@ -183,7 +184,8 @@ class S2TTransformerModel(FairseqEncoderDecoderModel):
default
=
"selfattn"
,
choices
=
[
"selfattn"
,
"rel_selfattn"
"rel_selfattn"
,
"relative"
,
],
help
=
"transformer decoder self-attention layer type"
)
...
...
fairseq/models/transformer.py
查看文件 @
0e2452b9
...
...
@@ -194,6 +194,29 @@ class TransformerModel(FairseqEncoderDecoderModel):
help
=
'block size of quantization noise at training time'
)
parser
.
add_argument
(
'--quant-noise-scalar'
,
type
=
float
,
metavar
=
'D'
,
default
=
0
,
help
=
'scalar quantization noise and scalar quantization at training time'
)
parser
.
add_argument
(
"--encoder-attention-type"
,
type
=
str
,
default
=
"selfattn"
,
choices
=
[
"selfattn"
,
"rel_selfattn"
,
"relative"
,
],
help
=
"transformer encoder self-attention layer type"
)
parser
.
add_argument
(
"--decoder-attention-type"
,
type
=
str
,
default
=
"selfattn"
,
choices
=
[
"selfattn"
,
"rel_selfattn"
,
"relative"
,
],
help
=
"transformer decoder self-attention layer type"
)
parser
.
add_argument
(
'--max-relative-length'
,
type
=
int
,
default
=-
1
,
help
=
'the max relative length'
)
parser
.
add_argument
(
'--k-only'
,
default
=
False
,
action
=
'store_true'
,
...
...
@@ -1134,6 +1157,10 @@ def base_architecture(args):
args
.
quant_noise_pq
=
getattr
(
args
,
"quant_noise_pq"
,
0
)
args
.
quant_noise_pq_block_size
=
getattr
(
args
,
"quant_noise_pq_block_size"
,
8
)
args
.
quant_noise_scalar
=
getattr
(
args
,
"quant_noise_scalar"
,
0
)
args
.
encoder_attention_type
=
getattr
(
args
,
"encoder_attention_type"
,
"selfattn"
)
args
.
decoder_attention_type
=
getattr
(
args
,
"decoder_attention_type"
,
"selfattn"
)
args
.
max_relative_length
=
getattr
(
args
,
'max_relative_length'
,
-
1
)
args
.
k_only
=
getattr
(
args
,
'k_only'
,
True
)
...
...
fairseq/modules/layer_history.py
查看文件 @
0e2452b9
import
torch
import
torch.nn
as
nn
from
fairseq.mod
els.transformer
import
LayerNorm
from
fairseq.mod
ules.layer_norm
import
LayerNorm
import
queue
import
numpy
as
np
...
...
fairseq/modules/relative_multihead_attention.py
查看文件 @
0e2452b9
...
...
@@ -3,13 +3,12 @@
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import
math
from
typing
import
Dict
,
Optional
,
Tuple
import
torch
import
torch.nn.functional
as
F
from
fairseq
import
utils
from
fairseq.modules
.multihead_attention
import
MultiheadAttention
from
fairseq.modules
import
MultiheadAttention
from
torch
import
Tensor
,
nn
from
torch.nn
import
Parameter
...
...
@@ -56,11 +55,10 @@ class RelativeMultiheadAttention(MultiheadAttention):
self
.
k_only
=
k_only
self
.
relative_position_keys
=
Parameter
(
torch
.
Tensor
(
2
*
self
.
max_relative_length
+
1
,
self
.
head_dim
))
if
not
self
.
k_only
:
self
.
relative_position_values
=
Parameter
(
torch
.
Tensor
(
2
*
self
.
max_relative_length
+
1
,
self
.
head_dim
))
nn
.
init
.
xavier_uniform_
(
self
.
relative_position_keys
)
if
not
self
.
k_only
:
self
.
relative_position_values
=
Parameter
(
torch
.
Tensor
(
2
*
self
.
max_relative_length
+
1
,
self
.
head_dim
))
nn
.
init
.
xavier_uniform_
(
self
.
relative_position_values
)
def
forward
(
...
...
编写
预览
Markdown
格式
0%
重试
或
添加新文件
添加附件
取消
您添加了
0
人
到此讨论。请谨慎行事。
请先完成此评论的编辑!
取消
请
注册
或者
登录
后发表评论