xuchen / Fairseq-S2T / Commits

Commit b23817e0, authored Sep 03, 2021 by xuchen
implement the pyramid transformer
parent 7802e6f7
Showing 8 changed files with 1591 additions and 6 deletions.
fairseq/models/speech_to_text/__init__.py  (+1, -0)
fairseq/models/speech_to_text/pys2t_transformer.py  (+544, -0)
fairseq/models/speech_to_text/s2t_sate.py  (+15, -4)
fairseq/models/speech_to_text/s2t_transformer.py  (+2, -1)
fairseq/modules/__init__.py  (+4, -0)
fairseq/modules/local_multihead_attention.py  (+0, -1)
fairseq/modules/pyramid_layer.py  (+530, -0)
fairseq/modules/reduced_multihead_attention.py  (+495, -0)
fairseq/models/speech_to_text/__init__.py
@@ -7,4 +7,5 @@ from .berard import *  # noqa
 from .convtransformer import *  # noqa
 from .s2t_transformer import *  # noqa
 from .s2t_conformer import *  # noqa
+from .pys2t_transformer import *  # noqa
 from .s2t_sate import *  # noqa
fairseq/models/speech_to_text/pys2t_transformer.py
new file mode 100644
#!/usr/bin/env python3

import logging
import math

import torch
from functools import reduce
import torch.nn as nn

from fairseq import checkpoint_utils
from fairseq.data.data_utils import lengths_to_padding_mask
from fairseq.models import (
    FairseqEncoder,
    register_model,
    register_model_architecture,
)
from fairseq.models.speech_to_text import S2TTransformerModel
from fairseq.modules import (
    FairseqDropout,
    LayerNorm,
    PositionalEmbedding,
    PyramidTransformerEncoderLayer,
)

logger = logging.getLogger(__name__)


def lengths_to_padding_mask_with_maxlen(lens, max_length):
    bsz = lens.size(0)
    mask = torch.arange(max_length).to(lens.device).view(1, max_length)
    mask = mask.expand(bsz, -1) >= lens.view(bsz, 1).expand(-1, max_length)
    return mask
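A quick sanity check of the helper above, with hypothetical lengths (not part of the committed file): positions at or beyond each utterance length are marked True.

```python
import torch

lens = torch.tensor([2, 4])
mask = lengths_to_padding_mask_with_maxlen(lens, max_length=4)
# tensor([[False, False,  True,  True],
#         [False, False, False, False]])
```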
class ReducedEmbed(nn.Module):
    # Reduced embedding for Pyramid Transformer
    def __init__(
        self,
        reduced_way: str,
        embed_norm: bool,
        in_channels: int,
        out_channels: int,
        kernel_sizes: int,
        stride: int,
        padding: int,
    ):
        super().__init__()
        self.stride = stride
        self.reduced_way = reduced_way
        if self.reduced_way == "conv":
            self.conv = nn.Conv1d(in_channels, out_channels, kernel_sizes, stride=stride, padding=padding)
        elif self.reduced_way == "glu":
            self.conv = nn.Conv1d(in_channels, out_channels * 2, kernel_sizes, stride=stride, padding=padding)
            self.glu = nn.GLU(dim=1)
        else:
            logger.error("Unsupported reduced way!")

        self.embed_norm = embed_norm
        if self.embed_norm:
            # self.norm = LayerNorm(out_channels)
            self.norm = LayerNorm(in_channels)

    def forward(self, x, lengths):
        seq_len, bsz, dim = x.size()
        assert seq_len % self.stride == 0, \
            "The sequence length %d must be a multiple of %d." % (seq_len, self.stride)

        padding_mask = lengths_to_padding_mask_with_maxlen(lengths, seq_len)  # bsz, seq_len
        mask_pad = padding_mask.unsqueeze(2)
        # mask batch padding
        if mask_pad is not None:
            x = x.transpose(0, 1)
            x.masked_fill_(mask_pad, 0.0)
            x = x.transpose(0, 1)

        if self.embed_norm:
            x = self.norm(x)
        x = x.permute(1, 2, 0)  # B * D * T
        x = self.conv(x)
        if self.reduced_way == "glu":
            x = self.glu(x)
        x = x.permute(2, 0, 1)  # T * B * D

        lengths = lengths / self.stride
        padding_mask = lengths_to_padding_mask_with_maxlen(lengths, x.size(0))
        mask_pad = padding_mask.unsqueeze(2)
        # mask batch padding
        if mask_pad is not None:
            x = x.transpose(0, 1)
            x.masked_fill_(mask_pad, 0.0)
            x = x.transpose(0, 1)

        return x, lengths, padding_mask
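As a rough illustration of the reduction step (hypothetical shapes, run in the context of this module), a conv reduction with stride 2 halves the time axis and updates the lengths and padding mask accordingly:

```python
import torch

reduce_embed = ReducedEmbed("conv", embed_norm=False, in_channels=80,
                            out_channels=128, kernel_sizes=3, stride=2, padding=1)
x = torch.randn(8, 2, 80)            # T x B x D speech features
lengths = torch.tensor([8, 6])
y, new_lengths, padding_mask = reduce_embed(x, lengths)
print(y.shape)                       # torch.Size([4, 2, 128])
```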
@register_model("pys2t_transformer")
class PYS2TTransformerModel(S2TTransformerModel):
    """Adapted Transformer model (https://arxiv.org/abs/1706.03762) for
    speech-to-text tasks. The Transformer encoder/decoder remains the same.
    A trainable input subsampler is prepended to the Transformer encoder to
    project inputs into the encoder dimension as well as downsample input
    sequence for computational efficiency."""

    def __init__(self, encoder, decoder):
        super().__init__(encoder, decoder)

    @staticmethod
    def add_args(parser):
        """Add model-specific arguments to the parser."""
        S2TTransformerModel.add_args(parser)
        parser.add_argument("--pyramid-stages", type=int,
                            help="the number of the stage")
        parser.add_argument("--pyramid-layers", type=str,
                            help="the number of the encoder layers")
        parser.add_argument("--pyramid-sr-ratios", type=str,
                            help="the ratio of the subsampling")
        parser.add_argument("--pyramid-attn-sample-ratio", type=str,
                            help="the ratio of the subsampling in the self attention module")
        parser.add_argument("--pyramid-reduced-embed", type=str, choices=["glu", "conv"],
                            help="the reduced way of the embedding")
        parser.add_argument("--pyramid-embed-norm", action="store_true",
                            help="use layer norm in reduced embedding")
        parser.add_argument("--pyramid-position-embed", type=str,
                            help="use the position embedding or not")
        parser.add_argument("--pyramid-embed-dims", type=str,
                            help="the embedding dimension")
        parser.add_argument("--pyramid-kernel-sizes", type=str,
                            help="the kernel size of the reduced embedding")
        parser.add_argument("--pyramid-ffn-ratios", type=str,
                            help="the ratio of the ffn")
        parser.add_argument("--pyramid-heads", type=str,
                            help="the number of the attention heads")
        parser.add_argument("--ctc-layer", type=int,
                            help="the position of the ctc loss")
        pass

    @classmethod
    def build_encoder(cls, args, task=None, embed_tokens=None):
        encoder = PyS2TTransformerEncoder(args, task, embed_tokens)
        if getattr(args, "load_pretrained_encoder_from", None):
            encoder = checkpoint_utils.load_pretrained_component_from_model(
                component=encoder, checkpoint=args.load_pretrained_encoder_from, strict=False
            )
            logger.info(
                f"loaded pretrained encoder from: "
                f"{args.load_pretrained_encoder_from}"
            )
        return encoder
class PyS2TTransformerEncoder(FairseqEncoder):
    """Speech-to-text Pyramid Transformer encoder"""

    def __init__(self, args, task=None, embed_tokens=None):
        super().__init__(None)

        self.padding_idx = 1
        self.attn_type = getattr(args, "encoder_attention_type", "selfattn")

        self.pyramid_stages = getattr(args, "pyramid_stages", 4)

        self.pyramid_layers = [int(n) for n in args.pyramid_layers.split("_")]
        self.pyramid_sr_ratios = [int(n) for n in args.pyramid_sr_ratios.split("_")]
        self.pyramid_attn_sample_ratios = [int(n) for n in args.pyramid_attn_sample_ratios.split("_")]
        self.pyramid_embed_dims = [int(n) for n in args.pyramid_embed_dims.split("_")]
        self.pyramid_position_embed = [int(n) for n in args.pyramid_position_embed.split("_")]
        self.pyramid_kernel_sizes = [int(n) for n in args.pyramid_kernel_sizes.split("_")]
        self.pyramid_ffn_ratios = [int(n) for n in args.pyramid_ffn_ratios.split("_")]
        self.pyramid_heads = [int(n) for n in args.pyramid_heads.split("_")]
        self.pyramid_reduced_embed = args.pyramid_reduced_embed
        self.pyramid_embed_norm = args.pyramid_embed_norm

        for i in range(self.pyramid_stages):
            num_layers = self.pyramid_layers[i]
            sr_ratio = self.pyramid_sr_ratios[i]
            attn_sample_ratio = self.pyramid_attn_sample_ratios[i]
            embed_dim = self.pyramid_embed_dims[i]
            kernel_size = self.pyramid_kernel_sizes[i]
            ffn_ratio = self.pyramid_ffn_ratios[i]
            num_head = self.pyramid_heads[i]
            use_pos_embed = self.pyramid_position_embed[i]

            if i == 0:
                self.embed_scale = math.sqrt(embed_dim)
                if args.no_scale_embedding:
                    self.embed_scale = 1.0

            reduced_embed = ReducedEmbed(
                self.pyramid_reduced_embed,
                self.pyramid_embed_norm if i != 0 else False,
                args.input_feat_per_channel * args.input_channels if i == 0 else self.pyramid_embed_dims[i - 1],
                embed_dim,
                kernel_sizes=kernel_size,
                stride=sr_ratio,
                padding=kernel_size // 2,
            )
            if use_pos_embed:
                pos_embed = PositionalEmbedding(
                    args.max_source_positions, embed_dim,
                    self.padding_idx, pos_emb_type=self.attn_type
                )
            else:
                pos_embed = None
            dropout = FairseqDropout(p=args.dropout, module_name=self.__class__.__name__)
            block = nn.ModuleList([
                PyramidTransformerEncoderLayer(args, embed_dim, embed_dim * ffn_ratio, num_head, attn_sample_ratio)
                for _ in range(num_layers)
            ])

            setattr(self, f"reduced_embed{i + 1}", reduced_embed)
            setattr(self, f"pos_embed{i + 1}", pos_embed)
            setattr(self, f"dropout{i + 1}", dropout)
            setattr(self, f"block{i + 1}", block)

            if i == self.pyramid_stages - 1:
                if args.encoder_normalize_before:
                    self.layer_norm = LayerNorm(embed_dim)
                else:
                    self.layer_norm = None

        self.use_ctc = "sate" in args.arch or \
                       (("ctc" in getattr(args, "criterion", False)) and
                        (getattr(args, "ctc_weight", False) > 0))
        if self.use_ctc:
            self.ctc_layer = (args.encoder_layers + args.ctc_layer) % args.encoder_layers
            self.inter_ctc = True if self.ctc_layer != args.encoder_layers - 1 else False
            if task.source_dictionary == task.target_dictionary and \
                    getattr(args, "share_all_embeddings", False):
                self.ctc_projection = nn.Linear(
                    embed_tokens.weight.shape[1],
                    embed_tokens.weight.shape[0],
                    bias=False,
                )
                self.ctc_projection.weight = embed_tokens.weight
            else:
                embed_dim = self.pyramid_embed_dims[-1]
                if self.inter_ctc:
                    ctc_layer = self.ctc_layer
                    for i in range(self.pyramid_stages):
                        ctc_layer -= self.pyramid_layers[i]
                        if ctc_layer <= 0:
                            embed_dim = self.pyramid_embed_dims[i]
                            self.ctc_layer_norm = LayerNorm(embed_dim)
                self.ctc_projection = nn.Linear(embed_dim, len(task.source_dictionary), bias=False)
                nn.init.normal_(self.ctc_projection.weight, mean=0, std=embed_dim ** -0.5)

            self.ctc_dropout_module = FairseqDropout(
                p=args.dropout, module_name=self.__class__.__name__
            )
            self.softmax = nn.Softmax(dim=-1)
    def forward(self, src_tokens, src_lengths):
        batch = src_tokens.size(0)

        x = src_tokens.transpose(0, 1)
        input_lengths = src_lengths

        # pad the sequence length to a multiple of the total downsampling ratio
        max_len = x.size(0)
        length = reduce(lambda a, b: a * b, self.pyramid_sr_ratios)
        padding_to_len = (length - max_len % length)
        if padding_to_len > 0:
            padding_for_pyramid = x.new_zeros((padding_to_len, batch, x.size(2)))
            x = torch.cat([x, padding_for_pyramid], dim=0)

        layer_idx = 0
        ctc_logit = None
        for i in range(self.pyramid_stages):
            reduced_embed = getattr(self, f"reduced_embed{i + 1}")
            pos_embed = getattr(self, f"pos_embed{i + 1}")
            dropout = getattr(self, f"dropout{i + 1}")
            block = getattr(self, f"block{i + 1}")

            if i == 0:
                x = self.embed_scale * x

            # reduced embed
            x, input_lengths, encoder_padding_mask = reduced_embed(x, input_lengths)
            # max_lens = int(x.size(0))
            # encoder_padding_mask = lengths_to_padding_mask_with_maxlen(input_lengths, max_lens)

            # add the position encoding and dropout
            if pos_embed:
                positions = pos_embed(encoder_padding_mask).transpose(0, 1)
                if self.attn_type != "rel_selfattn":
                    x += positions
                if i == 0:
                    x = dropout(x)
                positions = dropout(positions)
            else:
                positions = None

            for layer in block:
                x = layer(x, encoder_padding_mask, pos_emb=positions)
                layer_idx += 1

                if self.use_ctc and self.inter_ctc and self.ctc_layer == layer_idx:
                    ctc_logit = self.ctc_layer_norm(x)

        if self.layer_norm is not None:
            x = self.layer_norm(x)

        return {
            "encoder_out": [x],  # T x B x C
            "encoder_padding_mask": [encoder_padding_mask],  # B x T
            "encoder_embedding": [],  # B x T x C
            "encoder_states": [],  # List[T x B x C]
            "ctc_logit": [ctc_logit if ctc_logit is not None else x],
            "src_tokens": [],
            "src_lengths": [],
        }

    def compute_ctc_logit(self, encoder_out):
        assert self.use_ctc, "CTC is not available!"

        if isinstance(encoder_out, dict) and "encoder_out" in encoder_out:
            encoder_state = encoder_out["encoder_out"][0]
        else:
            encoder_state = encoder_out
        ctc_logit = self.ctc_projection(self.ctc_dropout_module(encoder_state))

        return ctc_logit

    def compute_ctc_prob(self, encoder_out, temperature=1.0):
        assert self.use_ctc, "CTC is not available!"

        ctc_logit = self.compute_ctc_logit(encoder_out) / temperature

        return self.softmax(ctc_logit)

    def reorder_encoder_out(self, encoder_out, new_order):
        new_encoder_out = (
            [] if len(encoder_out["encoder_out"]) == 0
            else [x.index_select(1, new_order) for x in encoder_out["encoder_out"]]
        )
        new_encoder_padding_mask = (
            [] if len(encoder_out["encoder_padding_mask"]) == 0
            else [x.index_select(0, new_order) for x in encoder_out["encoder_padding_mask"]]
        )
        new_encoder_embedding = (
            [] if len(encoder_out["encoder_embedding"]) == 0
            else [x.index_select(0, new_order) for x in encoder_out["encoder_embedding"]]
        )

        encoder_states = encoder_out["encoder_states"]
        if len(encoder_states) > 0:
            for idx, state in enumerate(encoder_states):
                encoder_states[idx] = state.index_select(1, new_order)

        return {
            "encoder_out": new_encoder_out,  # T x B x C
            "encoder_padding_mask": new_encoder_padding_mask,  # B x T
            "encoder_embedding": new_encoder_embedding,  # B x T x C
            "encoder_states": encoder_states,  # List[T x B x C]
            "src_tokens": [],  # B x T
            "src_lengths": [],  # B x 1
        }
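For orientation, a small sketch (hypothetical, mirroring the pys2t_transformer_s defaults registered further below) of how the underscore-separated per-stage options consumed by this encoder decompose; with stage reduction ratios of 2_2_2_2 the time axis shrinks 16x overall.

```python
from functools import reduce

layers = [int(n) for n in "3_3_3_3".split("_")]             # encoder layers per stage
embed_dims = [int(n) for n in "64_128_256_512".split("_")]  # embedding size per stage
sr_ratios = [int(n) for n in "2_2_2_2".split("_")]          # temporal reduction per stage

print(sum(layers))                            # 12 encoder layers in total
print(reduce(lambda a, b: a * b, sr_ratios))  # 16x downsampling of the input frames
```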
@register_model_architecture(model_name="pys2t_transformer", arch_name="pys2t_transformer")
def base_architecture(args):
    # Convolutional subsampler
    args.conv_kernel_sizes = getattr(args, "conv_kernel_sizes", "")
    args.conv_channels = getattr(args, "conv_channels", 1024)
    # Pyramid
    args.pyramid_stages = getattr(args, "pyramid_stages", None)
    args.pyramid_layers = getattr(args, "pyramid_layers", None)
    args.pyramid_sr_ratios = getattr(args, "pyramid_sr_ratios", None)
    args.pyramid_attn_sample_ratios = getattr(args, "pyramid_attn_sample_ratios", None)
    args.pyramid_embed_dims = getattr(args, "pyramid_embed_dims", None)
    args.pyramid_kernel_sizes = getattr(args, "pyramid_kernel_sizes", None)
    args.pyramid_ffn_ratios = getattr(args, "pyramid_ffn_ratios", None)
    args.pyramid_heads = getattr(args, "pyramid_heads", None)
    args.pyramid_position_embed = getattr(args, "pyramid_position_embed", None)
    args.pyramid_reduced_embed = getattr(args, "pyramid_reduced_embed", "conv")
    args.pyramid_embed_norm = getattr(args, "pyramid_embed_norm", False)
    args.ctc_layer = getattr(args, "ctc_layer", -1)

    args.encoder_embed_dim = getattr(args, "encoder_embed_dim", 512)
    args.encoder_ffn_embed_dim = getattr(args, "encoder_ffn_embed_dim", 2048)
    args.encoder_layers = getattr(args, "encoder_layers", 12)
    args.encoder_attention_type = getattr(args, "encoder_attention_type", "selfattn")
    args.encoder_attention_heads = getattr(args, "encoder_attention_heads", 8)
    args.encoder_normalize_before = getattr(args, "encoder_normalize_before", True)
    args.decoder_embed_dim = getattr(args, "decoder_embed_dim", args.encoder_embed_dim)
    args.decoder_ffn_embed_dim = getattr(args, "decoder_ffn_embed_dim", args.encoder_ffn_embed_dim)
    args.decoder_layers = getattr(args, "decoder_layers", 6)
    args.decoder_attention_type = getattr(args, "decoder_attention_type", "selfattn")
    args.decoder_attention_heads = getattr(args, "decoder_attention_heads", 8)
    args.decoder_normalize_before = getattr(args, "decoder_normalize_before", True)
    args.decoder_learned_pos = getattr(args, "decoder_learned_pos", False)
    args.dropout = getattr(args, "dropout", 0.1)
    args.attention_dropout = getattr(args, "attention_dropout", args.dropout)
    args.activation_dropout = getattr(args, "activation_dropout", args.dropout)
    args.activation_fn = getattr(args, "activation_fn", "relu")
    args.adaptive_softmax_cutoff = getattr(args, "adaptive_softmax_cutoff", None)
    args.adaptive_softmax_dropout = getattr(args, "adaptive_softmax_dropout", 0)
    args.share_decoder_input_output_embed = getattr(args, "share_decoder_input_output_embed", False)
    args.no_token_positional_embeddings = getattr(args, "no_token_positional_embeddings", False)
    args.adaptive_input = getattr(args, "adaptive_input", False)
    args.decoder_layerdrop = getattr(args, "decoder_layerdrop", 0.0)
    args.decoder_output_dim = getattr(args, "decoder_output_dim", args.decoder_embed_dim)
    args.decoder_input_dim = getattr(args, "decoder_input_dim", args.decoder_embed_dim)
    args.no_scale_embedding = getattr(args, "no_scale_embedding", False)
    args.quant_noise_pq = getattr(args, "quant_noise_pq", 0)

    args.max_encoder_relative_length = getattr(args, 'max_encoder_relative_length', -1)
    args.max_decoder_relative_length = getattr(args, 'max_decoder_relative_length', -1)
    args.k_only = getattr(args, 'k_only', True)
@register_model_architecture("pys2t_transformer", "pys2t_transformer_s")
def pys2t_transformer_s(args):
    args.encoder_embed_dim = getattr(args, "encoder_embed_dim", 256)
    args.encoder_ffn_embed_dim = getattr(args, "encoder_ffn_embed_dim", 256 * 8)
    args.encoder_attention_heads = getattr(args, "encoder_attention_heads", 4)
    args.decoder_attention_heads = getattr(args, "decoder_attention_heads", 4)
    args.dropout = getattr(args, "dropout", 0.1)

    args.pyramid_stages = getattr(args, "pyramid_stages", 4)
    args.pyramid_layers = getattr(args, "pyramid_layers", "3_3_3_3")
    args.pyramid_embed_dims = getattr(args, "pyramid_embed_dims", "64_128_256_512")
    args.pyramid_kernel_sizes = getattr(args, "pyramid_kernel_sizes", "2_2_2_2")
    args.pyramid_ffn_ratios = getattr(args, "pyramid_ffn_ratios", "4_4_4_4")
    args.pyramid_attn_sample_ratios = getattr(args, "pyramid_attn_sample_ratios", "8_4_2_1")
    args.pyramid_sr_ratios = getattr(args, "pyramid_sr_ratios", "2_2_2_2")
    args.pyramid_heads = getattr(args, "pyramid_heads", "1_2_4_8")
    args.pyramid_position_embed = getattr(args, "pyramid_position_embed", "1_1_1_1")
    args.pyramid_reduced_embed = getattr(args, "pyramid_reduced_embed", "conv")
    args.pyramid_embed_norm = getattr(args, "pyramid_embed_norm", False)

    base_architecture(args)


@register_model_architecture("pys2t_transformer", "pys2t_transformer_s_relative")
def pys2t_transformer_s_relative(args):
    args.max_encoder_relative_length = 100
    args.max_decoder_relative_length = 20
    args.k_only = True
    pys2t_transformer_s(args)


@register_model_architecture("pys2t_transformer", "pys2t_transformer_xs")
def pys2t_transformer_xs(args):
    args.encoder_layers = getattr(args, "encoder_layers", 6)
    args.decoder_layers = getattr(args, "decoder_layers", 3)
    args.encoder_ffn_embed_dim = getattr(args, "encoder_ffn_embed_dim", 256 * 4)
    args.dropout = getattr(args, "dropout", 0.3)
    pys2t_transformer_s(args)


@register_model_architecture("pys2t_transformer", "pys2t_transformer_sp")
def pys2t_transformer_sp(args):
    args.encoder_layers = getattr(args, "encoder_layers", 16)
    pys2t_transformer_s(args)


@register_model_architecture("pys2t_transformer", "pys2t_transformer_m")
def pys2t_transformer_m(args):
    args.encoder_embed_dim = getattr(args, "encoder_embed_dim", 512)
    args.encoder_ffn_embed_dim = getattr(args, "encoder_ffn_embed_dim", 512 * 4)
    args.encoder_attention_heads = getattr(args, "encoder_attention_heads", 8)
    args.decoder_attention_heads = getattr(args, "decoder_attention_heads", 8)
    args.dropout = getattr(args, "dropout", 0.15)
    base_architecture(args)


@register_model_architecture("pys2t_transformer", "pys2t_transformer_mp")
def pys2t_transformer_mp(args):
    args.encoder_layers = getattr(args, "encoder_layers", 16)
    pys2t_transformer_m(args)


@register_model_architecture("pys2t_transformer", "pys2t_transformer_l")
def pys2t_transformer_l(args):
    args.encoder_embed_dim = getattr(args, "encoder_embed_dim", 1024)
    args.encoder_ffn_embed_dim = getattr(args, "encoder_ffn_embed_dim", 1024 * 4)
    args.encoder_attention_heads = getattr(args, "encoder_attention_heads", 16)
    args.decoder_attention_heads = getattr(args, "decoder_attention_heads", 16)
    args.dropout = getattr(args, "dropout", 0.2)
    base_architecture(args)


@register_model_architecture("pys2t_transformer", "pys2t_transformer_lp")
def pys2t_transformer_lp(args):
    args.encoder_layers = getattr(args, "encoder_layers", 16)
    pys2t_transformer_l(args)
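Once this module is imported, the registrations above make the new model visible through fairseq's architecture registry; a minimal check (assuming a fairseq installation that includes this commit):

```python
from fairseq.models import ARCH_MODEL_REGISTRY

assert "pys2t_transformer_s" in ARCH_MODEL_REGISTRY
print(ARCH_MODEL_REGISTRY["pys2t_transformer_s"].__name__)   # PYS2TTransformerModel
```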
fairseq/models/speech_to_text/s2t_sate.py
@@ -17,7 +17,9 @@ from fairseq.models.speech_to_text import (
     S2TTransformerModel,
     S2TTransformerEncoder,
     S2TConformerEncoder,
-    S2TConformerModel
+    S2TConformerModel,
+    PYS2TTransformerModel,
+    PyS2TTransformerEncoder,
 )
 from fairseq.models.speech_to_text.s2t_transformer import Conv1dSubsampler
 from fairseq.modules import (

@@ -46,6 +48,7 @@ class S2TSATEModel(S2TTransformerModel):
     def add_args(parser):
         """Add model-specific arguments to the parser."""
         S2TConformerModel.add_args(parser)
+        PYS2TTransformerModel.add_args(parser)
         parser.add_argument(
             "--text-encoder-layers",

@@ -195,13 +198,16 @@ class Adapter(nn.Module):
             linear_out = self.linear_adapter(representation)
             soft_out = torch.mm(
                 distribution.view(-1, embed_dim), self.embed_adapter.weight
             ).view(batch, seq_len, -1)
             out = linear_out + soft_out
         elif self.adapter_type == "gated_league":
             linear_out = self.linear_adapter(representation)
             soft_out = torch.mm(
                 distribution.view(-1, embed_dim), self.embed_adapter.weight
             ).view(batch, seq_len, -1)
             coef = (self.gate_linear(torch.cat([linear_out, soft_out], dim=-1))).sigmoid()
             out = coef * linear_out + (1 - coef) * soft_out
         elif self.adapter_type == "none":
             out = representation
         else:
             out = None
             logging.error("Unsupported adapter type: {}.".format(self.adapter_type))

@@ -262,6 +268,8 @@ class S2TSATEEncoder(FairseqEncoder):
             self.acoustic_encoder = S2TTransformerEncoder(args, task, embed_tokens)
         elif acoustic_encoder_type == "conformer":
             self.acoustic_encoder = S2TConformerEncoder(args, task, embed_tokens)
+        elif acoustic_encoder_type == "pyramid":
+            self.acoustic_encoder = PyS2TTransformerEncoder(args, task, embed_tokens)
         else:
             logging.error("Unsupported model arch {}!".format(acoustic_encoder_type))

@@ -277,9 +285,9 @@ class S2TSATEEncoder(FairseqEncoder):
         # )

         acoustic_encoder_attention_type = args.encoder_attention_type
-        if acoustic_encoder_attention_type != "selfattn":
-            args.encoder_attention_type = "selfattn"
-            logger.info("Force self attention for text encoder.")
+        # if acoustic_encoder_attention_type != "selfattn":
+        #     args.encoder_attention_type = "selfattn"
+        #     logger.info("Force self attention for text encoder.")

         # text encoder
         self.text_encoder = TextEncoder(args, embed_tokens)

@@ -378,6 +386,9 @@ def base_architecture(args):
     args.use_cnn_module = getattr(args, "use_cnn_module", False)
     args.cnn_module_kernel = getattr(args, "cnn_module_kernel", 31)

+    # Pyramid
+    args.pyramid_layers = getattr(args, "pyramid_layers", None)
+
     args.encoder_embed_dim = getattr(args, "encoder_embed_dim", 512)
     args.encoder_ffn_embed_dim = getattr(args, "encoder_ffn_embed_dim", 2048)
     args.encoder_layers = getattr(args, "encoder_layers", 12)
fairseq/models/speech_to_text/s2t_transformer.py
@@ -147,10 +147,11 @@ class S2TTransformerModel(FairseqEncoderDecoderModel):
             type=str,
             default="selfattn",
             choices=[
-                "local",
                 "selfattn",
+                "reduced",
                 "rel_selfattn",
                 "relative",
+                "local",
             ],
             help="transformer encoder self-attention layer type")
fairseq/modules/__init__.py
@@ -29,6 +29,7 @@ from .linearized_convolution import LinearizedConvolution
 from .local_multihead_attention import LocalMultiheadAttention
 from .multihead_attention import MultiheadAttention
 from .positional_embedding import PositionalEmbedding
+from .reduced_multihead_attention import ReducedMultiheadAttention
 from .rel_position_multihead_attention import RelPositionMultiheadAttention
 from .relative_multihead_attention import RelativeMultiheadAttention
 from .same_pad import SamePad

@@ -41,6 +42,7 @@ from .unfold import unfold1d
 from .transformer_layer import TransformerDecoderLayer, TransformerEncoderLayer
 from .vggblock import VGGBlock
 from .conformer_layer import ConformerEncoderLayer
+from .pyramid_layer import PyramidTransformerEncoderLayer

 __all__ = [
     "AdaptiveInput",

@@ -74,6 +76,8 @@ __all__ = [
     "LocalMultiheadAttention",
     "MultiheadAttention",
     "PositionalEmbedding",
+    "PyramidTransformerEncoderLayer",
+    "ReducedMultiheadAttention",
     "RelPositionMultiheadAttention",
     "RelativeMultiheadAttention",
     "SamePad",
fairseq/modules/local_multihead_attention.py
@@ -325,7 +325,6 @@ class LocalMultiheadAttention(nn.Module):
         multihead_mask_weight = None
         gauss_bias = None
         if self.multihead_gauss_mask_sigma is not None:
-            data_type = attn_weights.dtype
             x1 = torch.arange(-1, src_len - 1, 1).view(-1, 1).to(attn_weights.device)
             x2 = torch.arange(-1, src_len - 1, 1).view(1, -1).to(attn_weights.device)
             dis_square = - (x1 - x2) ** 2 / 2.0
fairseq/modules/pyramid_layer.py
new file mode 100644
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

from typing import Dict, List, Optional

import torch
import torch.nn as nn
from fairseq import utils
from fairseq.modules import (
    LayerNorm,
    MultiheadAttention,
    ReducedMultiheadAttention,
    RelPositionMultiheadAttention,
    RelativeMultiheadAttention,
    LocalMultiheadAttention,
)
from fairseq.modules.fairseq_dropout import FairseqDropout
from fairseq.modules.quant_noise import quant_noise
from torch import Tensor


class PyramidTransformerEncoderLayer(nn.Module):
    """Encoder layer block.

    In the original paper each operation (multi-head attention or FFN) is
    postprocessed with: `dropout -> add residual -> layernorm`. In the
    tensor2tensor code they suggest that learning is more robust when
    preprocessing each layer with layernorm and postprocessing with:
    `dropout -> add residual`. We default to the approach in the paper, but the
    tensor2tensor approach can be enabled by setting
    *args.encoder_normalize_before* to ``True``.

    Args:
        args (argparse.Namespace): parsed command-line arguments
    """

    def __init__(self, args, embed_dim, ffn_embed_dim, num_head, att_sample_ratio=1):
        super().__init__()
        self.args = args
        self.embed_dim = embed_dim
        self.encoder_ffn_embed_dim = ffn_embed_dim
        self.quant_noise = getattr(args, 'quant_noise_pq', 0)
        self.quant_noise_block_size = getattr(args, 'quant_noise_pq_block_size', 8) or 8
        self.attn_type = getattr(args, "encoder_attention_type", "selfattn")
        # build the self-attention sub-layer; arguments follow the signature of
        # build_self_attention below, and the per-stage attention sample ratio
        # is passed through to the reduced attention variant
        self.self_attn = self.build_self_attention(args, self.embed_dim, num_head, att_sample_ratio)
        self.self_attn_layer_norm = LayerNorm(self.embed_dim)
        self.dropout_module = FairseqDropout(
            args.dropout, module_name=self.__class__.__name__
        )
        self.activation_fn = utils.get_activation_fn(
            activation=getattr(args, 'activation_fn', 'relu') or "relu"
        )
        activation_dropout_p = getattr(args, "activation_dropout", 0) or 0
        if activation_dropout_p == 0:
            # for backwards compatibility with models that use args.relu_dropout
            activation_dropout_p = getattr(args, "relu_dropout", 0) or 0
        self.activation_dropout_module = FairseqDropout(
            float(activation_dropout_p), module_name=self.__class__.__name__
        )
        self.normalize_before = args.encoder_normalize_before
        self.fc1 = self.build_fc1(
            self.embed_dim,
            self.encoder_ffn_embed_dim,
            self.quant_noise,
            self.quant_noise_block_size,
        )
        self.fc2 = self.build_fc2(
            self.encoder_ffn_embed_dim,
            self.embed_dim,
            self.quant_noise,
            self.quant_noise_block_size,
        )
        self.final_layer_norm = LayerNorm(self.embed_dim)
    def build_fc1(self, input_dim, output_dim, q_noise, qn_block_size):
        return quant_noise(
            nn.Linear(input_dim, output_dim), p=q_noise, block_size=qn_block_size
        )

    def build_fc2(self, input_dim, output_dim, q_noise, qn_block_size):
        return quant_noise(
            nn.Linear(input_dim, output_dim), p=q_noise, block_size=qn_block_size
        )

    def build_self_attention(self, args, embed_dim, num_head, sample_ratio=1):
        encoder_attention_heads = num_head

        if self.attn_type == "selfattn":
            attn_func = MultiheadAttention
        elif self.attn_type == "rel_selfattn":
            attn_func = RelPositionMultiheadAttention
        elif self.attn_type == "relative":
            # max_relative_length = getattr(args, "max_encoder_relative_length", -1)
            max_relative_length = max(getattr(args, "max_encoder_relative_length", -1),
                                      getattr(args, "max_relative_length", -1))
            if max_relative_length != -1:
                return RelativeMultiheadAttention(
                    embed_dim,
                    encoder_attention_heads,
                    dropout=args.attention_dropout,
                    self_attention=True,
                    q_noise=self.quant_noise,
                    qn_block_size=self.quant_noise_block_size,
                    max_relative_length=max_relative_length,
                )
            else:
                print("The maximum encoder relative length %d can not be -1!" % max_relative_length)
                exit(1)
        elif self.attn_type == "local":
            hard_mask_window = getattr(args, "hard_mask_window", 0)
            gauss_mask_sigma = getattr(args, "gauss_mask_sigma", 0)
            init_mask_weight = getattr(args, "init_mask_weight", 0)
            return LocalMultiheadAttention(
                embed_dim,
                encoder_attention_heads,
                dropout=args.attention_dropout,
                self_attention=True,
                q_noise=self.quant_noise,
                qn_block_size=self.quant_noise_block_size,
                hard_mask_window=hard_mask_window,
                gauss_mask_sigma=gauss_mask_sigma,
                init_mask_weight=init_mask_weight
            )
        elif self.attn_type == "reduced":
            return ReducedMultiheadAttention(
                embed_dim,
                encoder_attention_heads,
                dropout=args.attention_dropout,
                self_attention=True,
                q_noise=self.quant_noise,
                qn_block_size=self.quant_noise_block_size,
                sample_ratio=sample_ratio,
            )
        else:
            print("The encoder attention type %s is not supported!" % self.attn_type)
            exit(1)

        return attn_func(
            embed_dim,
            encoder_attention_heads,
            dropout=args.attention_dropout,
            self_attention=True,
            q_noise=self.quant_noise,
            qn_block_size=self.quant_noise_block_size,
        )
    def residual_connection(self, x, residual):
        return residual + x

    def upgrade_state_dict_named(self, state_dict, name):
        """
        Rename layer norm states from `...layer_norms.0.weight` to
        `...self_attn_layer_norm.weight` and `...layer_norms.1.weight` to
        `...final_layer_norm.weight`
        """
        layer_norm_map = {"0": "self_attn_layer_norm", "1": "final_layer_norm"}
        for old, new in layer_norm_map.items():
            for m in ("weight", "bias"):
                k = "{}.layer_norms.{}.{}".format(name, old, m)
                if k in state_dict:
                    state_dict["{}.{}.{}".format(name, new, m)] = state_dict[k]
                    del state_dict[k]

    def forward(self, x, encoder_padding_mask: Optional[Tensor],
                attn_mask: Optional[Tensor] = None,
                pos_emb: Optional[Tensor] = None):
        """
        Args:
            x (Tensor): input to the layer of shape `(seq_len, batch, embed_dim)`
            encoder_padding_mask (ByteTensor): binary ByteTensor of shape
                `(batch, seq_len)` where padding elements are indicated by ``1``.
            attn_mask (ByteTensor): binary tensor of shape `(tgt_len, src_len)`,
                where `tgt_len` is the length of output and `src_len` is the
                length of input, though here both are equal to `seq_len`.
                `attn_mask[tgt_i, src_j] = 1` means that when calculating the
                embedding for `tgt_i`, we exclude (mask out) `src_j`. This is
                useful for strided self-attention.
            pos_emb (Tensor): the position embedding for relative position encoding

        Returns:
            encoded output of shape `(seq_len, batch, embed_dim)`
        """
        # anything in original attn_mask = 1, becomes -1e8
        # anything in original attn_mask = 0, becomes 0
        # Note that we cannot use -inf here, because at some edge cases,
        # the attention weight (before softmax) for some padded element in query
        # will become -inf, which results in NaN in model parameters
        if attn_mask is not None:
            attn_mask = attn_mask.masked_fill(attn_mask.to(torch.bool), -1e8)

        residual = x
        if self.normalize_before:
            x = self.self_attn_layer_norm(x)
        if self.attn_type == "rel_selfattn":
            assert pos_emb is not None, "Positions is necessary for RPE!"
            x, _ = self.self_attn(
                query=x, key=x, value=x,
                key_padding_mask=encoder_padding_mask,
                need_weights=False,
                attn_mask=attn_mask,
                pos_emb=pos_emb
            )
        else:
            x, _ = self.self_attn(
                query=x, key=x, value=x,
                key_padding_mask=encoder_padding_mask,
                need_weights=False,
                attn_mask=attn_mask,
            )
        x = self.dropout_module(x)
        x = self.residual_connection(x, residual)
        if not self.normalize_before:
            x = self.self_attn_layer_norm(x)

        residual = x
        if self.normalize_before:
            x = self.final_layer_norm(x)
        x = self.activation_fn(self.fc1(x))
        x = self.activation_dropout_module(x)
        x = self.fc2(x)
        x = self.dropout_module(x)
        x = self.residual_connection(x, residual)
        if not self.normalize_before:
            x = self.final_layer_norm(x)
        return x
class TransformerDecoderLayer(nn.Module):
    """Decoder layer block.

    In the original paper each operation (multi-head attention, encoder
    attention or FFN) is postprocessed with: `dropout -> add residual ->
    layernorm`. In the tensor2tensor code they suggest that learning is more
    robust when preprocessing each layer with layernorm and postprocessing with:
    `dropout -> add residual`. We default to the approach in the paper, but the
    tensor2tensor approach can be enabled by setting
    *args.decoder_normalize_before* to ``True``.

    Args:
        args (argparse.Namespace): parsed command-line arguments
        no_encoder_attn (bool, optional): whether to attend to encoder outputs
            (default: False).
    """

    def __init__(self, args, no_encoder_attn=False, add_bias_kv=False, add_zero_attn=False):
        super().__init__()
        self.embed_dim = args.decoder_embed_dim
        self.dropout_module = FairseqDropout(
            args.dropout, module_name=self.__class__.__name__
        )
        self.quant_noise = getattr(args, "quant_noise_pq", 0)
        self.quant_noise_block_size = getattr(args, "quant_noise_pq_block_size", 8)

        self.cross_self_attention = getattr(args, "cross_self_attention", False)

        self.attn_type = getattr(args, "decoder_attention_type", "selfattn")

        self.self_attn = self.build_self_attention(
            self.embed_dim,
            args,
            add_bias_kv=add_bias_kv,
            add_zero_attn=add_zero_attn,
        )

        self.activation_fn = utils.get_activation_fn(
            activation=str(args.activation_fn)
            if getattr(args, "activation_fn", None) is not None
            else "relu"
        )
        activation_dropout_p = getattr(args, "activation_dropout", 0) or 0
        if activation_dropout_p == 0:
            # for backwards compatibility with models that use args.relu_dropout
            activation_dropout_p = getattr(args, "relu_dropout", 0) or 0
        self.activation_dropout_module = FairseqDropout(
            float(activation_dropout_p), module_name=self.__class__.__name__
        )
        self.normalize_before = args.decoder_normalize_before

        # use layerNorm rather than FusedLayerNorm for exporting.
        # char_inputs can be used to determine this.
        # TODO  remove this once we update apex with the fix
        export = getattr(args, "char_inputs", False)
        self.self_attn_layer_norm = LayerNorm(self.embed_dim, export=export)

        if no_encoder_attn:
            self.encoder_attn = None
            self.encoder_attn_layer_norm = None
        else:
            self.encoder_attn = self.build_encoder_attention(self.embed_dim, args)
            self.encoder_attn_layer_norm = LayerNorm(self.embed_dim, export=export)

        self.fc1 = self.build_fc1(
            self.embed_dim,
            args.decoder_ffn_embed_dim,
            self.quant_noise,
            self.quant_noise_block_size,
        )
        self.fc2 = self.build_fc2(
            args.decoder_ffn_embed_dim,
            self.embed_dim,
            self.quant_noise,
            self.quant_noise_block_size,
        )

        self.final_layer_norm = LayerNorm(self.embed_dim, export=export)
        self.need_attn = True

        self.onnx_trace = False

    def build_fc1(self, input_dim, output_dim, q_noise, qn_block_size):
        return quant_noise(nn.Linear(input_dim, output_dim), q_noise, qn_block_size)

    def build_fc2(self, input_dim, output_dim, q_noise, qn_block_size):
        return quant_noise(nn.Linear(input_dim, output_dim), q_noise, qn_block_size)

    def build_self_attention(self, embed_dim, args, add_bias_kv=False, add_zero_attn=False):
        if self.attn_type == "selfattn":
            attn_func = MultiheadAttention
        elif self.attn_type == "rel_selfattn":
            attn_func = RelPositionMultiheadAttention
        elif self.attn_type == "relative":
            max_relative_length = max(getattr(args, "max_decoder_relative_length", -1),
                                      getattr(args, "max_relative_length", -1))
            if max_relative_length != -1:
                return RelativeMultiheadAttention(
                    embed_dim,
                    args.decoder_attention_heads,
                    dropout=args.attention_dropout,
                    self_attention=True,
                    q_noise=self.quant_noise,
                    qn_block_size=self.quant_noise_block_size,
                    max_relative_length=max_relative_length,
                )
            else:
                print("The maximum decoder relative length %d can not be -1!" % max_relative_length)
                exit(1)
        else:
            print("The decoder attention type %s is not supported!" % self.attn_type)
            exit(1)

        return attn_func(
            embed_dim,
            args.decoder_attention_heads,
            dropout=args.attention_dropout,
            add_bias_kv=add_bias_kv,
            add_zero_attn=add_zero_attn,
            self_attention=not getattr(args, "cross_self_attention", False),
            q_noise=self.quant_noise,
            qn_block_size=self.quant_noise_block_size,
        )

    def build_encoder_attention(self, embed_dim, args):
        return MultiheadAttention(
            embed_dim,
            args.decoder_attention_heads,
            kdim=getattr(args, "encoder_embed_dim", None),
            vdim=getattr(args, "encoder_embed_dim", None),
            dropout=args.attention_dropout,
            encoder_decoder_attention=True,
            q_noise=self.quant_noise,
            qn_block_size=self.quant_noise_block_size,
        )

    def prepare_for_onnx_export_(self):
        self.onnx_trace = True

    def residual_connection(self, x, residual):
        return residual + x

    def forward(
        self,
        x,
        encoder_out: Optional[torch.Tensor] = None,
        encoder_padding_mask: Optional[torch.Tensor] = None,
        incremental_state: Optional[Dict[str, Dict[str, Optional[Tensor]]]] = None,
        prev_self_attn_state: Optional[List[torch.Tensor]] = None,
        prev_attn_state: Optional[List[torch.Tensor]] = None,
        self_attn_mask: Optional[torch.Tensor] = None,
        self_attn_padding_mask: Optional[torch.Tensor] = None,
        need_attn: bool = False,
        need_head_weights: bool = False,
        pos_emb: Optional[Tensor] = None,
    ):
        """
        Args:
            x (Tensor): input to the layer of shape `(seq_len, batch, embed_dim)`
            encoder_padding_mask (ByteTensor, optional): binary
                ByteTensor of shape `(batch, src_len)` where padding
                elements are indicated by ``1``.
            need_attn (bool, optional): return attention weights
            need_head_weights (bool, optional): return attention weights
                for each head (default: return average over heads).

        Returns:
            encoded output of shape `(seq_len, batch, embed_dim)`
        """
        if need_head_weights:
            need_attn = True

        residual = x
        if self.normalize_before:
            x = self.self_attn_layer_norm(x)
        if prev_self_attn_state is not None:
            prev_key, prev_value = prev_self_attn_state[:2]
            saved_state: Dict[str, Optional[Tensor]] = {
                "prev_key": prev_key,
                "prev_value": prev_value,
            }
            if len(prev_self_attn_state) >= 3:
                saved_state["prev_key_padding_mask"] = prev_self_attn_state[2]
            assert incremental_state is not None
            self.self_attn._set_input_buffer(incremental_state, saved_state)
        _self_attn_input_buffer = self.self_attn._get_input_buffer(incremental_state)
        if self.cross_self_attention and not (
            incremental_state is not None
            and _self_attn_input_buffer is not None
            and "prev_key" in _self_attn_input_buffer
        ):
            if self_attn_mask is not None:
                assert encoder_out is not None
                self_attn_mask = torch.cat(
                    (x.new_zeros(x.size(0), encoder_out.size(0)), self_attn_mask), dim=1
                )
            if self_attn_padding_mask is not None:
                if encoder_padding_mask is None:
                    assert encoder_out is not None
                    encoder_padding_mask = self_attn_padding_mask.new_zeros(
                        encoder_out.size(1), encoder_out.size(0)
                    )
                self_attn_padding_mask = torch.cat(
                    (encoder_padding_mask, self_attn_padding_mask), dim=1
                )
            assert encoder_out is not None
            y = torch.cat((encoder_out, x), dim=0)
        else:
            y = x

        if self.attn_type == "rel_selfattn":
            assert pos_emb is not None, "Positions is necessary for RPE!"
            x, attn = self.self_attn(
                query=x, key=y, value=y,
                key_padding_mask=self_attn_padding_mask,
                incremental_state=incremental_state,
                need_weights=False,
                attn_mask=self_attn_mask,
                pos_emb=pos_emb
            )
        else:
            x, attn = self.self_attn(
                query=x, key=y, value=y,
                key_padding_mask=self_attn_padding_mask,
                incremental_state=incremental_state,
                need_weights=False,
                attn_mask=self_attn_mask,
            )
        x = self.dropout_module(x)
        x = self.residual_connection(x, residual)
        if not self.normalize_before:
            x = self.self_attn_layer_norm(x)

        if self.encoder_attn is not None and encoder_out is not None:
            residual = x
            if self.normalize_before:
                x = self.encoder_attn_layer_norm(x)
            if prev_attn_state is not None:
                prev_key, prev_value = prev_attn_state[:2]
                saved_state: Dict[str, Optional[Tensor]] = {
                    "prev_key": prev_key,
                    "prev_value": prev_value,
                }
                if len(prev_attn_state) >= 3:
                    saved_state["prev_key_padding_mask"] = prev_attn_state[2]
                assert incremental_state is not None
                self.encoder_attn._set_input_buffer(incremental_state, saved_state)

            x, attn = self.encoder_attn(
                query=x,
                key=encoder_out,
                value=encoder_out,
                key_padding_mask=encoder_padding_mask,
                incremental_state=incremental_state,
                static_kv=True,
                need_weights=need_attn or (not self.training and self.need_attn),
                need_head_weights=need_head_weights,
            )
            x = self.dropout_module(x)
            x = self.residual_connection(x, residual)
            if not self.normalize_before:
                x = self.encoder_attn_layer_norm(x)

        residual = x
        if self.normalize_before:
            x = self.final_layer_norm(x)

        x = self.activation_fn(self.fc1(x))
        x = self.activation_dropout_module(x)
        x = self.fc2(x)
        x = self.dropout_module(x)
        x = self.residual_connection(x, residual)
        if not self.normalize_before:
            x = self.final_layer_norm(x)
        if self.onnx_trace and incremental_state is not None:
            saved_state = self.self_attn._get_input_buffer(incremental_state)
            assert saved_state is not None
            if self_attn_padding_mask is not None:
                self_attn_state = [
                    saved_state["prev_key"],
                    saved_state["prev_value"],
                    saved_state["prev_key_padding_mask"],
                ]
            else:
                self_attn_state = [saved_state["prev_key"], saved_state["prev_value"]]
            return x, attn, self_attn_state
        return x, attn, None

    def make_generation_fast_(self, need_attn: bool = False, **kwargs):
        self.need_attn = need_attn
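A minimal sketch of constructing the encoder layer above in isolation (hypothetical argument values; only the fields the layer actually reads are filled in, the stock self-attention type is used, and the fairseq modules imported at the top of this file are assumed to be available):

```python
from argparse import Namespace
import torch

args = Namespace(dropout=0.1, attention_dropout=0.1, activation_dropout=0.0,
                 activation_fn="relu", encoder_normalize_before=True,
                 encoder_attention_type="selfattn")
layer = PyramidTransformerEncoderLayer(args, embed_dim=64, ffn_embed_dim=256,
                                        num_head=4, att_sample_ratio=1)

x = torch.randn(10, 2, 64)                          # T x B x C
padding_mask = torch.zeros(2, 10, dtype=torch.bool)
print(layer(x, padding_mask).shape)                 # torch.Size([10, 2, 64])
```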
fairseq/modules/reduced_multihead_attention.py
new file mode 100644
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import math
from typing import Dict, Optional, Tuple

import torch
import torch.nn.functional as F
from fairseq import utils
from fairseq.incremental_decoding_utils import with_incremental_state
from fairseq.modules.fairseq_dropout import FairseqDropout
from fairseq.modules.quant_noise import quant_noise
from torch import Tensor, nn
from torch.nn import Parameter


@with_incremental_state
class ReducedMultiheadAttention(nn.Module):
    """Multi-headed attention.

    See "Attention Is All You Need" for more details.
    """

    def __init__(
        self,
        embed_dim,
        num_heads,
        kdim=None,
        vdim=None,
        dropout=0.0,
        bias=True,
        add_bias_kv=False,
        add_zero_attn=False,
        self_attention=False,
        encoder_decoder_attention=False,
        q_noise=0.0,
        qn_block_size=8,
        sample_ratio=1,
    ):
        super().__init__()
        self.embed_dim = embed_dim
        self.kdim = kdim if kdim is not None else embed_dim
        self.vdim = vdim if vdim is not None else embed_dim
        self.qkv_same_dim = self.kdim == embed_dim and self.vdim == embed_dim

        self.num_heads = num_heads
        self.dropout_module = FairseqDropout(
            dropout, module_name=self.__class__.__name__
        )

        self.head_dim = embed_dim // num_heads
        assert (
            self.head_dim * num_heads == self.embed_dim
        ), "embed_dim must be divisible by num_heads"
        self.scaling = self.head_dim ** -0.5

        self.self_attention = self_attention
        self.encoder_decoder_attention = encoder_decoder_attention

        assert not self.self_attention or self.qkv_same_dim, (
            "Self-attention requires query, key and " "value to be of the same size"
        )

        self.k_proj = quant_noise(nn.Linear(self.kdim, embed_dim, bias=bias), q_noise, qn_block_size)
        self.v_proj = quant_noise(nn.Linear(self.vdim, embed_dim, bias=bias), q_noise, qn_block_size)
        self.q_proj = quant_noise(nn.Linear(embed_dim, embed_dim, bias=bias), q_noise, qn_block_size)

        self.out_proj = quant_noise(nn.Linear(embed_dim, embed_dim, bias=bias), q_noise, qn_block_size)

        if add_bias_kv:
            self.bias_k = Parameter(torch.Tensor(1, 1, embed_dim))
            self.bias_v = Parameter(torch.Tensor(1, 1, embed_dim))
        else:
            self.bias_k = self.bias_v = None

        self.add_zero_attn = add_zero_attn

        self.sample_ratio = sample_ratio
        if self.sample_ratio > 1:
            # strided 1-D convolution over the time axis that shortens the
            # key/value sequence by the sample ratio
            self.sr = nn.Conv1d(embed_dim, embed_dim, kernel_size=sample_ratio, stride=sample_ratio)
            self.norm = nn.LayerNorm(embed_dim)

        self.reset_parameters()

        self.onnx_trace = False
    def prepare_for_onnx_export_(self):
        self.onnx_trace = True

    def reset_parameters(self):
        if self.qkv_same_dim:
            # Empirically observed the convergence to be much better with
            # the scaled initialization
            nn.init.xavier_uniform_(self.k_proj.weight, gain=1 / math.sqrt(2))
            nn.init.xavier_uniform_(self.v_proj.weight, gain=1 / math.sqrt(2))
            nn.init.xavier_uniform_(self.q_proj.weight, gain=1 / math.sqrt(2))
        else:
            nn.init.xavier_uniform_(self.k_proj.weight)
            nn.init.xavier_uniform_(self.v_proj.weight)
            nn.init.xavier_uniform_(self.q_proj.weight)

        nn.init.xavier_uniform_(self.out_proj.weight)
        if self.out_proj.bias is not None:
            nn.init.constant_(self.out_proj.bias, 0.0)
        if self.bias_k is not None:
            nn.init.xavier_normal_(self.bias_k)
        if self.bias_v is not None:
            nn.init.xavier_normal_(self.bias_v)
    def forward(
        self,
        query,
        key: Optional[Tensor],
        value: Optional[Tensor],
        key_padding_mask: Optional[Tensor] = None,
        incremental_state: Optional[Dict[str, Dict[str, Optional[Tensor]]]] = None,
        need_weights: bool = True,
        static_kv: bool = False,
        attn_mask: Optional[Tensor] = None,
        before_softmax: bool = False,
        need_head_weights: bool = False,
    ) -> Tuple[Tensor, Optional[Tensor]]:
        """Input shape: Time x Batch x Channel

        Args:
            key_padding_mask (ByteTensor, optional): mask to exclude
                keys that are pads, of shape `(batch, src_len)`, where
                padding elements are indicated by 1s.
            need_weights (bool, optional): return the attention weights,
                averaged over heads (default: False).
            attn_mask (ByteTensor, optional): typically used to
                implement causal attention, where the mask prevents the
                attention from looking forward in time (default: None).
            before_softmax (bool, optional): return the raw attention
                weights and values before the attention softmax.
            need_head_weights (bool, optional): return the attention
                weights for each head. Implies *need_weights*. Default:
                return the average attention weights over all heads.
        """
        if need_head_weights:
            need_weights = True

        is_tpu = query.device.type == "xla"

        tgt_len, bsz, embed_dim = query.size()
        assert embed_dim == self.embed_dim
        assert list(query.size()) == [tgt_len, bsz, embed_dim]

        if (
            self.sample_ratio == 1
            and not self.onnx_trace
            and not is_tpu  # don't use PyTorch version on TPUs
            and incremental_state is None
            and not static_kv
            # A workaround for quantization to work. Otherwise JIT compilation
            # treats bias in linear module as method.
            and not torch.jit.is_scripting()
        ):
            assert key is not None and value is not None
            return F.multi_head_attention_forward(
                query, key, value,
                self.embed_dim, self.num_heads,
                torch.empty([0]),
                torch.cat((self.q_proj.bias, self.k_proj.bias, self.v_proj.bias)),
                self.bias_k, self.bias_v,
                self.add_zero_attn,
                self.dropout_module.p,
                self.out_proj.weight, self.out_proj.bias,
                self.training or self.dropout_module.apply_during_inference,
                key_padding_mask, need_weights, attn_mask,
                use_separate_proj_weight=True,
                q_proj_weight=self.q_proj.weight,
                k_proj_weight=self.k_proj.weight,
                v_proj_weight=self.v_proj.weight,
            )

        if incremental_state is not None:
            saved_state = self._get_input_buffer(incremental_state)
            if saved_state is not None and "prev_key" in saved_state:
                # previous time steps are cached - no need to recompute
                # key and value if they are static
                if static_kv:
                    assert self.encoder_decoder_attention and not self.self_attention
                    key = value = None
        else:
            saved_state = None

        q = self.q_proj(query)
        if self.self_attention:
            if self.sample_ratio > 1:
                query_ = query.permute(1, 2, 0)  # bsz, dim, seq_len
                query_ = self.sr(query_).permute(2, 0, 1)  # seq_len, bsz, dim
                query = self.norm(query_)
            k = self.k_proj(query)
            v = self.v_proj(query)
        elif self.encoder_decoder_attention:
            # encoder-decoder attention
            if key is None:
                assert value is None
                k = v = None
            else:
                k = self.k_proj(key)
                v = self.v_proj(key)
        else:
            assert key is not None and value is not None
            k = self.k_proj(key)
            v = self.v_proj(value)
        q *= self.scaling

        if self.bias_k is not None:
            assert self.bias_v is not None
            k = torch.cat([k, self.bias_k.repeat(1, bsz, 1)])
            v = torch.cat([v, self.bias_v.repeat(1, bsz, 1)])
            if attn_mask is not None:
                attn_mask = torch.cat(
                    [attn_mask, attn_mask.new_zeros(attn_mask.size(0), 1)], dim=1
                )
            if key_padding_mask is not None:
                key_padding_mask = torch.cat(
                    [key_padding_mask, key_padding_mask.new_zeros(key_padding_mask.size(0), 1)],
                    dim=1,
                )

        q = q.contiguous().view(tgt_len, bsz * self.num_heads, self.head_dim).transpose(0, 1)
        if k is not None:
            k = k.contiguous().view(-1, bsz * self.num_heads, self.head_dim).transpose(0, 1)
        if v is not None:
            v = v.contiguous().view(-1, bsz * self.num_heads, self.head_dim).transpose(0, 1)

        if saved_state is not None:
            # saved states are stored with shape (bsz, num_heads, seq_len, head_dim)
            if "prev_key" in saved_state:
                _prev_key = saved_state["prev_key"]
                assert _prev_key is not None
                prev_key = _prev_key.view(bsz * self.num_heads, -1, self.head_dim)
                if static_kv:
                    k = prev_key
                else:
                    assert k is not None
                    k = torch.cat([prev_key, k], dim=1)
            if "prev_value" in saved_state:
                _prev_value = saved_state["prev_value"]
                assert _prev_value is not None
                prev_value = _prev_value.view(bsz * self.num_heads, -1, self.head_dim)
                if static_kv:
                    v = prev_value
                else:
                    assert v is not None
                    v = torch.cat([prev_value, v], dim=1)
            prev_key_padding_mask: Optional[Tensor] = None
            if "prev_key_padding_mask" in saved_state:
                prev_key_padding_mask = saved_state["prev_key_padding_mask"]
            assert k is not None and v is not None
            key_padding_mask = ReducedMultiheadAttention._append_prev_key_padding_mask(
                key_padding_mask=key_padding_mask,
                prev_key_padding_mask=prev_key_padding_mask,
                batch_size=bsz,
                src_len=k.size(1),
                static_kv=static_kv,
            )

            saved_state["prev_key"] = k.view(bsz, self.num_heads, -1, self.head_dim)
            saved_state["prev_value"] = v.view(bsz, self.num_heads, -1, self.head_dim)
            saved_state["prev_key_padding_mask"] = key_padding_mask
            # In this branch incremental_state is never None
            assert incremental_state is not None
            incremental_state = self._set_input_buffer(incremental_state, saved_state)
        assert k is not None
        src_len = k.size(1)

        # This is part of a workaround to get around fork/join parallelism
        # not supporting Optional types.
        if key_padding_mask is not None and key_padding_mask.dim() == 0:
            key_padding_mask = None

        if key_padding_mask is not None:
            assert key_padding_mask.size(0) == bsz
            assert key_padding_mask.size(1) == src_len

        if self.add_zero_attn:
            assert v is not None
            src_len += 1
            k = torch.cat([k, k.new_zeros((k.size(0), 1) + k.size()[2:])], dim=1)
            v = torch.cat([v, v.new_zeros((v.size(0), 1) + v.size()[2:])], dim=1)
            if attn_mask is not None:
                attn_mask = torch.cat(
                    [attn_mask, attn_mask.new_zeros(attn_mask.size(0), 1)], dim=1
                )
            if key_padding_mask is not None:
                key_padding_mask = torch.cat(
                    [key_padding_mask, torch.zeros(key_padding_mask.size(0), 1).type_as(key_padding_mask)],
                    dim=1,
                )

        attn_weights = torch.bmm(q, k.transpose(1, 2))
        attn_weights = self.apply_sparse_mask(attn_weights, tgt_len, src_len, bsz)

        assert list(attn_weights.size()) == [bsz * self.num_heads, tgt_len, src_len]

        if attn_mask is not None:
            attn_mask = attn_mask.unsqueeze(0)
            if self.onnx_trace:
                attn_mask = attn_mask.repeat(attn_weights.size(0), 1, 1)
            attn_weights += attn_mask

        if key_padding_mask is not None:
            # don't attend to padding symbols
            attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len)
            if not is_tpu:
                attn_weights = attn_weights.masked_fill(
                    key_padding_mask.unsqueeze(1).unsqueeze(2).to(torch.bool),
                    float("-inf"),
                )
            else:
                attn_weights = attn_weights.transpose(0, 2)
                attn_weights = attn_weights.masked_fill(key_padding_mask, float("-inf"))
                attn_weights = attn_weights.transpose(0, 2)
            attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)

        if before_softmax:
            return attn_weights, v

        attn_weights_float = utils.softmax(
            attn_weights, dim=-1, onnx_trace=self.onnx_trace
        )
        attn_weights = attn_weights_float.type_as(attn_weights)
        attn_probs = self.dropout_module(attn_weights)

        assert v is not None
        attn = torch.bmm(attn_probs, v)
        assert list(attn.size()) == [bsz * self.num_heads, tgt_len, self.head_dim]
        if self.onnx_trace and attn.size(1) == 1:
            # when ONNX tracing a single decoder step (sequence length == 1)
            # the transpose is a no-op copy before view, thus unnecessary
            attn = attn.contiguous().view(tgt_len, bsz, embed_dim)
        else:
            attn = attn.transpose(0, 1).contiguous().view(tgt_len, bsz, embed_dim)
        attn = self.out_proj(attn)
        attn_weights: Optional[Tensor] = None
        if need_weights:
            attn_weights = attn_weights_float.view(
                bsz, self.num_heads, tgt_len, src_len
            ).transpose(1, 0)
            if not need_head_weights:
                # average attention weights over heads
                attn_weights = attn_weights.mean(dim=0)

        return attn, attn_weights
    @staticmethod
    def _append_prev_key_padding_mask(
        key_padding_mask: Optional[Tensor],
        prev_key_padding_mask: Optional[Tensor],
        batch_size: int,
        src_len: int,
        static_kv: bool,
    ) -> Optional[Tensor]:
        # saved key padding masks have shape (bsz, seq_len)
        if prev_key_padding_mask is not None and static_kv:
            new_key_padding_mask = prev_key_padding_mask
        elif prev_key_padding_mask is not None and key_padding_mask is not None:
            new_key_padding_mask = torch.cat(
                [prev_key_padding_mask.float(), key_padding_mask.float()], dim=1
            )
        # During incremental decoding, as the padding token enters and
        # leaves the frame, there will be a time when prev or current
        # is None
        elif prev_key_padding_mask is not None:
            filler = torch.zeros(
                (batch_size, src_len - prev_key_padding_mask.size(1)),
                device=prev_key_padding_mask.device,
            )
            new_key_padding_mask = torch.cat(
                [prev_key_padding_mask.float(), filler.float()], dim=1
            )
        elif key_padding_mask is not None:
            filler = torch.zeros(
                (batch_size, src_len - key_padding_mask.size(1)),
                device=key_padding_mask.device,
            )
            new_key_padding_mask = torch.cat(
                [filler.float(), key_padding_mask.float()], dim=1
            )
        else:
            new_key_padding_mask = prev_key_padding_mask
        return new_key_padding_mask

    @torch.jit.export
    def reorder_incremental_state(
        self,
        incremental_state: Dict[str, Dict[str, Optional[Tensor]]],
        new_order: Tensor,
    ):
        """Reorder buffered internal state (for incremental generation)."""
        input_buffer = self._get_input_buffer(incremental_state)
        if input_buffer is not None:
            for k in input_buffer.keys():
                input_buffer_k = input_buffer[k]
                if input_buffer_k is not None:
                    if (
                        self.encoder_decoder_attention
                        and input_buffer_k.size(0) == new_order.size(0)
                    ):
                        break
                    input_buffer[k] = input_buffer_k.index_select(0, new_order)
            incremental_state = self._set_input_buffer(incremental_state, input_buffer)
        return incremental_state

    def _get_input_buffer(
        self, incremental_state: Optional[Dict[str, Dict[str, Optional[Tensor]]]]
    ) -> Dict[str, Optional[Tensor]]:
        result = self.get_incremental_state(incremental_state, "attn_state")
        if result is not None:
            return result
        else:
            empty_result: Dict[str, Optional[Tensor]] = {}
            return empty_result

    def _set_input_buffer(
        self,
        incremental_state: Dict[str, Dict[str, Optional[Tensor]]],
        buffer: Dict[str, Optional[Tensor]],
    ):
        return self.set_incremental_state(incremental_state, "attn_state", buffer)

    def apply_sparse_mask(self, attn_weights, tgt_len: int, src_len: int, bsz: int):
        return attn_weights

    def upgrade_state_dict_named(self, state_dict, name):
        prefix = name + "." if name != "" else ""
        items_to_add = {}
        keys_to_remove = []
        for k in state_dict.keys():
            if k.endswith(prefix + "in_proj_weight"):
                # in_proj_weight used to be q + k + v with same dimensions
                dim = int(state_dict[k].shape[0] / 3)
                items_to_add[prefix + "q_proj.weight"] = state_dict[k][:dim]
                items_to_add[prefix + "k_proj.weight"] = state_dict[k][dim : 2 * dim]
                items_to_add[prefix + "v_proj.weight"] = state_dict[k][2 * dim :]

                keys_to_remove.append(k)

                k_bias = prefix + "in_proj_bias"
                if k_bias in state_dict.keys():
                    dim = int(state_dict[k].shape[0] / 3)
                    items_to_add[prefix + "q_proj.bias"] = state_dict[k_bias][:dim]
                    items_to_add[prefix + "k_proj.bias"] = state_dict[k_bias][dim : 2 * dim]
                    items_to_add[prefix + "v_proj.bias"] = state_dict[k_bias][2 * dim :]

                    keys_to_remove.append(prefix + "in_proj_bias")

        for k in keys_to_remove:
            del state_dict[k]

        for key, value in items_to_add.items():
            state_dict[key] = value
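And a corresponding sketch of the reduced attention on its own (hypothetical shapes, using the 1-D convolutional reduction above): with sample_ratio=4 the keys and values are computed from a 4x-shorter sequence, while the queries, and therefore the output length, stay at the original resolution.

```python
import torch

attn = ReducedMultiheadAttention(embed_dim=64, num_heads=4, dropout=0.1,
                                 self_attention=True, sample_ratio=4)
x = torch.randn(16, 2, 64)             # T x B x C
out, _ = attn(query=x, key=x, value=x, key_padding_mask=None, need_weights=False)
print(out.shape)                       # torch.Size([16, 2, 64])
```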