Fairseq-S2T · Commits · ca78c4b8

Commit ca78c4b8 authored Aug 09, 2021 by xuchen
add the local attention

parent f1cf477d
Showing 5 changed files with 546 additions and 1 deletion (+546 -1)
.gitignore                                          +1   -0
fairseq/models/speech_to_text/s2t_transformer.py    +29  -0
fairseq/modules/__init__.py                         +2   -0
fairseq/modules/local_multihead_attention.py        +497 -0
fairseq/modules/transformer_layer.py                +17  -1
.gitignore

```diff
@@ -135,3 +135,4 @@ experimental/*
 # Weights and Biases logs
 wandb/
 /examples/translation/iwslt14.tokenized.de-en/
+toy/
```
fairseq/models/speech_to_text/s2t_transformer.py

```diff
@@ -150,6 +150,7 @@ class S2TTransformerModel(FairseqEncoderDecoderModel):
                 "selfattn",
                 "rel_selfattn",
                 "relative",
+                "local",
             ],
             help="transformer encoder self-attention layer type")
@@ -187,6 +188,7 @@ class S2TTransformerModel(FairseqEncoderDecoderModel):
                 "selfattn",
                 "rel_selfattn",
                 "relative",
+                "local",
             ],
             help="transformer decoder self-attention layer type")
@@ -277,6 +279,29 @@ class S2TTransformerModel(FairseqEncoderDecoderModel):
                             default="learnable_dense",
                             help='decoder layer history type')
+        parser.add_argument('--hard-mask-window',
+                            type=float,
+                            metavar="D",
+                            default=0,
+                            help='window size of local mask')
+        parser.add_argument('--gauss-mask-sigma',
+                            type=float,
+                            metavar="D",
+                            default=0,
+                            help='standard deviation of the gauss mask')
+        parser.add_argument('--init-mask-weight',
+                            type=float,
+                            metavar="D",
+                            default=0.5,
+                            help='initialized weight for local mask')
         pass

     @classmethod
@@ -627,6 +652,10 @@ def base_architecture(args):
     args.max_decoder_relative_length = getattr(args, 'max_decoder_relative_length', -1)
     args.k_only = getattr(args, 'k_only', True)
+    args.hard_mask_window = getattr(args, 'hard_mask_window', 0)
+    args.gauss_mask_sigma = getattr(args, 'gauss_mask_sigma', 0)
+    args.init_mask_weight = getattr(args, 'init_mask_weight', 0)


 @register_model_architecture("s2t_transformer", "s2t_transformer_s")
 def s2t_transformer_s(args):
```
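For orientation, the sketch below only restates the three new hyperparameters and their registered defaults; it is illustrative, not the model's actual build path, and the option whose choices gained "local" (the encoder/decoder self-attention layer type flag) is selected separately and not named here because its name is not visible in this hunk.

```python
# Minimal sketch of the new options added above; real values come from
# fairseq's argument parser, not a hand-built Namespace.
from argparse import Namespace

args = Namespace(
    hard_mask_window=16,   # --hard-mask-window, default 0 (hard window disabled)
    gauss_mask_sigma=3.0,  # --gauss-mask-sigma, default 0 (Gaussian mask disabled)
    init_mask_weight=0.5,  # --init-mask-weight, default 0.5
)
print(args.hard_mask_window, args.gauss_mask_sigma, args.init_mask_weight)
```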
fairseq/modules/__init__.py

```diff
@@ -26,6 +26,7 @@ from .layer_norm import Fp32LayerNorm, LayerNorm
 from .learned_positional_embedding import LearnedPositionalEmbedding
 from .lightweight_convolution import LightweightConv, LightweightConv1dTBC
 from .linearized_convolution import LinearizedConvolution
+from .local_multihead_attention import LocalMultiheadAttention
 from .multihead_attention import MultiheadAttention
 from .positional_embedding import PositionalEmbedding
 from .rel_position_multihead_attention import RelPositionMultiheadAttention
@@ -70,6 +71,7 @@ __all__ = [
     "LightweightConv1dTBC",
     "LightweightConv",
     "LinearizedConvolution",
+    "LocalMultiheadAttention",
     "MultiheadAttention",
     "PositionalEmbedding",
     "RelPositionMultiheadAttention",
```
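With the export above in place, the new class is importable alongside the other attention modules. A trivial check, assuming this branch of Fairseq-S2T is installed:

```python
# Assumes this branch of Fairseq-S2T is installed; verifies the new export.
from fairseq.modules import LocalMultiheadAttention

print(LocalMultiheadAttention.__module__)  # fairseq.modules.local_multihead_attention
```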
fairseq/modules/local_multihead_attention.py (new file, mode 100644)

```python
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import math
from typing import Dict, Optional, Tuple

import torch
import torch.nn.functional as F
from fairseq import utils
from fairseq.incremental_decoding_utils import with_incremental_state
from fairseq.modules.fairseq_dropout import FairseqDropout
from fairseq.modules.quant_noise import quant_noise
from torch import Tensor, nn
from torch.nn import Parameter


@with_incremental_state
class LocalMultiheadAttention(nn.Module):
    """Multi-headed attention.

    See "Attention Is All You Need" for more details.
    """

    def __init__(
        self,
        embed_dim,
        num_heads,
        kdim=None,
        vdim=None,
        dropout=0.0,
        bias=True,
        add_bias_kv=False,
        add_zero_attn=False,
        self_attention=False,
        encoder_decoder_attention=False,
        q_noise=0.0,
        qn_block_size=8,
        hard_mask_window=0,
        gauss_mask_sigma=0,
        init_mask_weight=0.5,
    ):
        super().__init__()
        self.embed_dim = embed_dim
        self.kdim = kdim if kdim is not None else embed_dim
        self.vdim = vdim if vdim is not None else embed_dim
        self.qkv_same_dim = self.kdim == embed_dim and self.vdim == embed_dim

        self.num_heads = num_heads
        self.dropout_module = FairseqDropout(
            dropout, module_name=self.__class__.__name__
        )

        self.head_dim = embed_dim // num_heads
        assert (
            self.head_dim * num_heads == self.embed_dim
        ), "embed_dim must be divisible by num_heads"
        self.scaling = self.head_dim ** -0.5

        self.self_attention = self_attention
        self.encoder_decoder_attention = encoder_decoder_attention

        assert not self.self_attention or self.qkv_same_dim, (
            "Self-attention requires query, key and "
            "value to be of the same size"
        )

        self.k_proj = quant_noise(
            nn.Linear(self.kdim, embed_dim, bias=bias), q_noise, qn_block_size
        )
        self.v_proj = quant_noise(
            nn.Linear(self.vdim, embed_dim, bias=bias), q_noise, qn_block_size
        )
        self.q_proj = quant_noise(
            nn.Linear(embed_dim, embed_dim, bias=bias), q_noise, qn_block_size
        )

        self.out_proj = quant_noise(
            nn.Linear(embed_dim, embed_dim, bias=bias), q_noise, qn_block_size
        )

        if add_bias_kv:
            self.bias_k = Parameter(torch.Tensor(1, 1, embed_dim))
            self.bias_v = Parameter(torch.Tensor(1, 1, embed_dim))
        else:
            self.bias_k = self.bias_v = None

        self.add_zero_attn = add_zero_attn

        self.reset_parameters()

        self.onnx_trace = False

        if hard_mask_window != 0:
            self.hard_mask_window = hard_mask_window
        else:
            self.hard_mask_window = 0

        if gauss_mask_sigma != 0:
            self.multihead_gauss_mask_sigma = Parameter(torch.Tensor(num_heads, 1, 1))
            nn.init.constant_(self.multihead_gauss_mask_sigma, gauss_mask_sigma)
        else:
            self.multihead_gauss_mask_sigma = None

        self.multihead_mask_weight = Parameter(torch.Tensor(num_heads, 1, 1))
        nn.init.constant_(self.multihead_mask_weight, init_mask_weight)

    def prepare_for_onnx_export_(self):
        self.onnx_trace = True

    def reset_parameters(self):
        if self.qkv_same_dim:
            # Empirically observed the convergence to be much better with
            # the scaled initialization
            nn.init.xavier_uniform_(self.k_proj.weight, gain=1 / math.sqrt(2))
            nn.init.xavier_uniform_(self.v_proj.weight, gain=1 / math.sqrt(2))
            nn.init.xavier_uniform_(self.q_proj.weight, gain=1 / math.sqrt(2))
        else:
            nn.init.xavier_uniform_(self.k_proj.weight)
            nn.init.xavier_uniform_(self.v_proj.weight)
            nn.init.xavier_uniform_(self.q_proj.weight)

        nn.init.xavier_uniform_(self.out_proj.weight)
        if self.out_proj.bias is not None:
            nn.init.constant_(self.out_proj.bias, 0.0)
        if self.bias_k is not None:
            nn.init.xavier_normal_(self.bias_k)
        if self.bias_v is not None:
            nn.init.xavier_normal_(self.bias_v)

    def forward(
        self,
        query,
        key: Optional[Tensor],
        value: Optional[Tensor],
        key_padding_mask: Optional[Tensor] = None,
        incremental_state: Optional[Dict[str, Dict[str, Optional[Tensor]]]] = None,
        need_weights: bool = True,
        static_kv: bool = False,
        attn_mask: Optional[Tensor] = None,
        before_softmax: bool = False,
        need_head_weights: bool = False,
    ) -> Tuple[Tensor, Optional[Tensor]]:
        """Input shape: Time x Batch x Channel

        Args:
            key_padding_mask (ByteTensor, optional): mask to exclude
                keys that are pads, of shape `(batch, src_len)`, where
                padding elements are indicated by 1s.
            need_weights (bool, optional): return the attention weights,
                averaged over heads (default: False).
            attn_mask (ByteTensor, optional): typically used to
                implement causal attention, where the mask prevents the
                attention from looking forward in time (default: None).
            before_softmax (bool, optional): return the raw attention
                weights and values before the attention softmax.
            need_head_weights (bool, optional): return the attention
                weights for each head. Implies *need_weights*. Default:
                return the average attention weights over all heads.
        """
        if need_head_weights:
            need_weights = True

        is_tpu = query.device.type == "xla"

        tgt_len, bsz, embed_dim = query.size()
        assert embed_dim == self.embed_dim
        assert list(query.size()) == [tgt_len, bsz, embed_dim]

        if incremental_state is not None:
            saved_state = self._get_input_buffer(incremental_state)
            if saved_state is not None and "prev_key" in saved_state:
                # previous time steps are cached - no need to recompute
                # key and value if they are static
                if static_kv:
                    assert self.encoder_decoder_attention and not self.self_attention
                    key = value = None
        else:
            saved_state = None

        if self.self_attention:
            q = self.q_proj(query)
            k = self.k_proj(query)
            v = self.v_proj(query)
        elif self.encoder_decoder_attention:
            # encoder-decoder attention
            q = self.q_proj(query)
            if key is None:
                assert value is None
                k = v = None
            else:
                k = self.k_proj(key)
                v = self.v_proj(key)
        else:
            assert key is not None and value is not None
            q = self.q_proj(query)
            k = self.k_proj(key)
            v = self.v_proj(value)
        q *= self.scaling

        if self.bias_k is not None:
            assert self.bias_v is not None
            k = torch.cat([k, self.bias_k.repeat(1, bsz, 1)])
            v = torch.cat([v, self.bias_v.repeat(1, bsz, 1)])
            if attn_mask is not None:
                attn_mask = torch.cat(
                    [attn_mask, attn_mask.new_zeros(attn_mask.size(0), 1)], dim=1
                )
            if key_padding_mask is not None:
                key_padding_mask = torch.cat(
                    [
                        key_padding_mask,
                        key_padding_mask.new_zeros(key_padding_mask.size(0), 1),
                    ],
                    dim=1,
                )

        q = (
            q.contiguous()
            .view(tgt_len, bsz * self.num_heads, self.head_dim)
            .transpose(0, 1)
        )
        if k is not None:
            k = (
                k.contiguous()
                .view(-1, bsz * self.num_heads, self.head_dim)
                .transpose(0, 1)
            )
        if v is not None:
            v = (
                v.contiguous()
                .view(-1, bsz * self.num_heads, self.head_dim)
                .transpose(0, 1)
            )

        if saved_state is not None:
            # saved states are stored with shape (bsz, num_heads, seq_len, head_dim)
            if "prev_key" in saved_state:
                _prev_key = saved_state["prev_key"]
                assert _prev_key is not None
                prev_key = _prev_key.view(bsz * self.num_heads, -1, self.head_dim)
                if static_kv:
                    k = prev_key
                else:
                    assert k is not None
                    k = torch.cat([prev_key, k], dim=1)
            if "prev_value" in saved_state:
                _prev_value = saved_state["prev_value"]
                assert _prev_value is not None
                prev_value = _prev_value.view(bsz * self.num_heads, -1, self.head_dim)
                if static_kv:
                    v = prev_value
                else:
                    assert v is not None
                    v = torch.cat([prev_value, v], dim=1)
            prev_key_padding_mask: Optional[Tensor] = None
            if "prev_key_padding_mask" in saved_state:
                prev_key_padding_mask = saved_state["prev_key_padding_mask"]
            assert k is not None and v is not None
            key_padding_mask = LocalMultiheadAttention._append_prev_key_padding_mask(
                key_padding_mask=key_padding_mask,
                prev_key_padding_mask=prev_key_padding_mask,
                batch_size=bsz,
                src_len=k.size(1),
                static_kv=static_kv,
            )

            saved_state["prev_key"] = k.view(bsz, self.num_heads, -1, self.head_dim)
            saved_state["prev_value"] = v.view(bsz, self.num_heads, -1, self.head_dim)
            saved_state["prev_key_padding_mask"] = key_padding_mask
            # In this branch incremental_state is never None
            assert incremental_state is not None
            incremental_state = self._set_input_buffer(incremental_state, saved_state)
        assert k is not None
        src_len = k.size(1)

        # This is part of a workaround to get around fork/join parallelism
        # not supporting Optional types.
        if key_padding_mask is not None and key_padding_mask.dim() == 0:
            key_padding_mask = None

        if key_padding_mask is not None:
            assert key_padding_mask.size(0) == bsz
            assert key_padding_mask.size(1) == src_len

        if self.add_zero_attn:
            assert v is not None
            src_len += 1
            k = torch.cat([k, k.new_zeros((k.size(0), 1) + k.size()[2:])], dim=1)
            v = torch.cat([v, v.new_zeros((v.size(0), 1) + v.size()[2:])], dim=1)
            if attn_mask is not None:
                attn_mask = torch.cat(
                    [attn_mask, attn_mask.new_zeros(attn_mask.size(0), 1)], dim=1
                )
            if key_padding_mask is not None:
                key_padding_mask = torch.cat(
                    [
                        key_padding_mask,
                        torch.zeros(key_padding_mask.size(0), 1).type_as(
                            key_padding_mask
                        ),
                    ],
                    dim=1,
                )

        attn_weights = torch.bmm(q, k.transpose(1, 2))
        attn_weights = self.apply_sparse_mask(attn_weights, tgt_len, src_len, bsz)

        assert list(attn_weights.size()) == [bsz * self.num_heads, tgt_len, src_len]

        if self.hard_mask_window != 0:
            hard_mask_window = self.hard_mask_window
            if 0 < self.hard_mask_window <= 1:
                hard_mask_window = int(src_len * self.hard_mask_window)
            x1 = torch.arange(-1, src_len - 1, 1).view(-1, 1)
            x2 = torch.arange(-1, src_len - 1, 1).view(1, -1)
            dis = x2 - x1
            mask_diag = torch.abs(dis) > hard_mask_window
            mask_diag = mask_diag.unsqueeze(0)
            attn_weights = attn_weights.masked_fill(mask_diag, float("-inf"))

        if self.multihead_gauss_mask_sigma is not None:
            x1 = torch.arange(-1, src_len - 1, 1).view(-1, 1)
            x2 = torch.arange(-1, src_len - 1, 1).view(1, -1)
            diag_growing = -(x1 - x2) ** 2 / 2.0
            e_diag_gauss_mask = diag_growing.unsqueeze(0).repeat(self.num_heads, 1, 1)
            e_sigma_square = 1 / torch.square(self.multihead_gauss_mask_sigma)
            e_diag_gauss_mask_final = e_diag_gauss_mask * e_sigma_square
            e_diag_gauss_mask_final = torch.unsqueeze(e_diag_gauss_mask_final, 0)

            attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len)
            multihead_mask_weight = torch.sigmoid(self.multihead_mask_weight.unsqueeze(0))
            attn_weights = (1 - multihead_mask_weight) * attn_weights + \
                multihead_mask_weight * e_diag_gauss_mask_final
            attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)

        if attn_mask is not None:
            attn_mask = attn_mask.unsqueeze(0)
            if self.onnx_trace:
                attn_mask = attn_mask.repeat(attn_weights.size(0), 1, 1)
            attn_weights += attn_mask

        if key_padding_mask is not None:
            # don't attend to padding symbols
            attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len)
            if not is_tpu:
                attn_weights = attn_weights.masked_fill(
                    key_padding_mask.unsqueeze(1).unsqueeze(2).to(torch.bool),
                    float("-inf"),
                )
            else:
                attn_weights = attn_weights.transpose(0, 2)
                attn_weights = attn_weights.masked_fill(key_padding_mask, float("-inf"))
                attn_weights = attn_weights.transpose(0, 2)
            attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)

        if before_softmax:
            return attn_weights, v

        attn_weights_float = utils.softmax(
            attn_weights, dim=-1, onnx_trace=self.onnx_trace
        )
        attn_weights = attn_weights_float.type_as(attn_weights)
        attn_probs = self.dropout_module(attn_weights)

        assert v is not None
        attn = torch.bmm(attn_probs, v)
        assert list(attn.size()) == [bsz * self.num_heads, tgt_len, self.head_dim]
        if self.onnx_trace and attn.size(1) == 1:
            # when ONNX tracing a single decoder step (sequence length == 1)
            # the transpose is a no-op copy before view, thus unnecessary
            attn = attn.contiguous().view(tgt_len, bsz, embed_dim)
        else:
            attn = attn.transpose(0, 1).contiguous().view(tgt_len, bsz, embed_dim)
        attn = self.out_proj(attn)
        attn_weights: Optional[Tensor] = None
        if need_weights:
            attn_weights = attn_weights_float.view(
                bsz, self.num_heads, tgt_len, src_len
            ).transpose(1, 0)
            if not need_head_weights:
                # average attention weights over heads
                attn_weights = attn_weights.mean(dim=0)

        return attn, attn_weights

    @staticmethod
    def _append_prev_key_padding_mask(
        key_padding_mask: Optional[Tensor],
        prev_key_padding_mask: Optional[Tensor],
        batch_size: int,
        src_len: int,
        static_kv: bool,
    ) -> Optional[Tensor]:
        # saved key padding masks have shape (bsz, seq_len)
        if prev_key_padding_mask is not None and static_kv:
            new_key_padding_mask = prev_key_padding_mask
        elif prev_key_padding_mask is not None and key_padding_mask is not None:
            new_key_padding_mask = torch.cat(
                [prev_key_padding_mask.float(), key_padding_mask.float()], dim=1
            )
        # During incremental decoding, as the padding token enters and
        # leaves the frame, there will be a time when prev or current
        # is None
        elif prev_key_padding_mask is not None:
            filler = torch.zeros(
                (batch_size, src_len - prev_key_padding_mask.size(1)),
                device=prev_key_padding_mask.device,
            )
            new_key_padding_mask = torch.cat(
                [prev_key_padding_mask.float(), filler.float()], dim=1
            )
        elif key_padding_mask is not None:
            filler = torch.zeros(
                (batch_size, src_len - key_padding_mask.size(1)),
                device=key_padding_mask.device,
            )
            new_key_padding_mask = torch.cat(
                [filler.float(), key_padding_mask.float()], dim=1
            )
        else:
            new_key_padding_mask = prev_key_padding_mask
        return new_key_padding_mask

    @torch.jit.export
    def reorder_incremental_state(
        self,
        incremental_state: Dict[str, Dict[str, Optional[Tensor]]],
        new_order: Tensor,
    ):
        """Reorder buffered internal state (for incremental generation)."""
        input_buffer = self._get_input_buffer(incremental_state)
        if input_buffer is not None:
            for k in input_buffer.keys():
                input_buffer_k = input_buffer[k]
                if input_buffer_k is not None:
                    if self.encoder_decoder_attention and input_buffer_k.size(
                        0
                    ) == new_order.size(0):
                        break
                    input_buffer[k] = input_buffer_k.index_select(0, new_order)
            incremental_state = self._set_input_buffer(incremental_state, input_buffer)
        return incremental_state

    def _get_input_buffer(
        self, incremental_state: Optional[Dict[str, Dict[str, Optional[Tensor]]]]
    ) -> Dict[str, Optional[Tensor]]:
        result = self.get_incremental_state(incremental_state, "attn_state")
        if result is not None:
            return result
        else:
            empty_result: Dict[str, Optional[Tensor]] = {}
            return empty_result

    def _set_input_buffer(
        self,
        incremental_state: Dict[str, Dict[str, Optional[Tensor]]],
        buffer: Dict[str, Optional[Tensor]],
    ):
        return self.set_incremental_state(incremental_state, "attn_state", buffer)

    def apply_sparse_mask(self, attn_weights, tgt_len: int, src_len: int, bsz: int):
        return attn_weights

    def upgrade_state_dict_named(self, state_dict, name):
        prefix = name + "." if name != "" else ""
        items_to_add = {}
        keys_to_remove = []
        for k in state_dict.keys():
            if k.endswith(prefix + "in_proj_weight"):
                # in_proj_weight used to be q + k + v with same dimensions
                dim = int(state_dict[k].shape[0] / 3)
                items_to_add[prefix + "q_proj.weight"] = state_dict[k][:dim]
                items_to_add[prefix + "k_proj.weight"] = state_dict[k][dim : 2 * dim]
                items_to_add[prefix + "v_proj.weight"] = state_dict[k][2 * dim :]

                keys_to_remove.append(k)

                k_bias = prefix + "in_proj_bias"
                if k_bias in state_dict.keys():
                    dim = int(state_dict[k].shape[0] / 3)
                    items_to_add[prefix + "q_proj.bias"] = state_dict[k_bias][:dim]
                    items_to_add[prefix + "k_proj.bias"] = state_dict[k_bias][dim : 2 * dim]
                    items_to_add[prefix + "v_proj.bias"] = state_dict[k_bias][2 * dim :]

                    keys_to_remove.append(prefix + "in_proj_bias")

        for k in keys_to_remove:
            del state_dict[k]

        for key, value in items_to_add.items():
            state_dict[key] = value
```
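Reading `forward()` above: when `gauss_mask_sigma` is non-zero, each head h blends its raw scores with a Gaussian locality bias, `(1 - w_h) * score + w_h * (-(i - j)^2 / (2 * sigma_h^2))` with `w_h = sigmoid(mask_weight_h)`; when `hard_mask_window` is non-zero, scores with `|i - j|` beyond the window are set to `-inf` (a fractional window in (0, 1] is interpreted as a fraction of the source length). The snippet below is a standalone usage sketch with assumed dimensions and hyperparameters, not part of the commit.

```python
import torch

from fairseq.modules import LocalMultiheadAttention

# Standalone sketch (assumed dims/hyperparameters). Inputs follow the
# "Time x Batch x Channel" convention from the docstring above.
attn = LocalMultiheadAttention(
    embed_dim=256,
    num_heads=4,
    dropout=0.1,
    self_attention=True,
    hard_mask_window=16,   # mask keys farther than 16 positions from the query
    gauss_mask_sigma=3.0,  # per-head learnable Gaussian locality bias
    init_mask_weight=0.5,  # initial (pre-sigmoid) blend weight
)

x = torch.randn(50, 2, 256)  # (seq_len, batch, embed_dim)
out, weights = attn(query=x, key=x, value=x, need_weights=True)

print(out.shape)      # torch.Size([50, 2, 256])
print(weights.shape)  # torch.Size([2, 50, 50]) -- averaged over heads
```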
fairseq/modules/transformer_layer.py

```diff
@@ -12,7 +12,8 @@ from fairseq.modules import (
     LayerNorm,
     MultiheadAttention,
     RelPositionMultiheadAttention,
-    RelativeMultiheadAttention
+    RelativeMultiheadAttention,
+    LocalMultiheadAttention,
 )
 from fairseq.modules.fairseq_dropout import FairseqDropout
 from fairseq.modules.quant_noise import quant_noise
@@ -103,6 +104,21 @@ class TransformerEncoderLayer(nn.Module):
             else:
                 print("The maximum encoder relative length %d can not be -1!" % max_relative_length)
                 exit(1)
+        elif self.attn_type == "local":
+            hard_mask_window = getattr(args, "hard_mask_window", 0)
+            gauss_mask_sigma = getattr(args, "gauss_mask_sigma", 0)
+            init_mask_weight = getattr(args, "init_mask_weight", 0)
+            return LocalMultiheadAttention(
+                embed_dim,
+                args.encoder_attention_heads,
+                dropout=args.attention_dropout,
+                self_attention=True,
+                q_noise=self.quant_noise,
+                qn_block_size=self.quant_noise_block_size,
+                hard_mask_window=hard_mask_window,
+                gauss_mask_sigma=gauss_mask_sigma,
+                init_mask_weight=init_mask_weight,
+            )
         else:
             print("The encoder attention type %s is not supported!" % self.attn_type)
             exit(1)
```
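Note that with the `getattr` defaults above (`hard_mask_window=0`, `gauss_mask_sigma=0`) both masks are disabled, so the "local" branch builds a module whose forward pass reduces to plain multi-head attention. Below is a hedged, hand-inlined restatement of what that branch constructs, with illustrative values standing in for the layer's real `args`:

```python
from argparse import Namespace

from fairseq.modules import LocalMultiheadAttention

# Illustrative stand-in for the encoder layer's args; values are assumptions.
args = Namespace(
    encoder_attention_heads=4,
    attention_dropout=0.1,
    hard_mask_window=0,   # defaults: hard window disabled ...
    gauss_mask_sigma=0,   # ... and Gaussian mask disabled
    init_mask_weight=0,
)
embed_dim = 256

# Mirrors the "local" branch above (quant-noise shown at its fairseq defaults,
# where the layer would pass self.quant_noise / self.quant_noise_block_size).
self_attn = LocalMultiheadAttention(
    embed_dim,
    args.encoder_attention_heads,
    dropout=args.attention_dropout,
    self_attention=True,
    q_noise=0.0,
    qn_block_size=8,
    hard_mask_window=getattr(args, "hard_mask_window", 0),
    gauss_mask_sigma=getattr(args, "gauss_mask_sigma", 0),
    init_mask_weight=getattr(args, "init_mask_weight", 0),
)
print(type(self_attn).__name__)  # LocalMultiheadAttention
```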