xuchen / Fairseq-S2T / Commits / f0e3290f

Commit f0e3290f, authored Apr 10, 2021 by xuchen
add the traditional relative multihead attention
Parent: a3e0f4c2
Showing 8 changed files with 814 additions and 4 deletions (+814 / -4):
  fairseq/models/speech_to_text/s2t_conformer.py    +7    -0
  fairseq/models/speech_to_text/s2t_sate.py         +7    -0
  fairseq/models/speech_to_text/s2t_transformer.py  +14   -0
  fairseq/models/transformer.py                     +73   -3
  fairseq/modules/__init__.py                       +4    -0
  fairseq/modules/layer_history.py                  +316  -0
  fairseq/modules/relative_multihead_attention.py   +367  -0
  fairseq/modules/transformer_layer.py              +26   -1
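In brief, the commit adds self-attention with clipped relative position representations (in the style of Shaw et al., 2018), switched on through the new --max-relative-length and --k-only options and a set of "*_relative" architecture presets. As a toy illustration of the clipped index matrix that _generate_relative_positions_matrix (added below in relative_multihead_attention.py) builds, here is a short, self-contained sketch in plain PyTorch with illustrative sizes only:

    import torch

    length, max_relative_length = 6, 2
    range_vec = torch.arange(length)
    range_mat = range_vec.repeat(length, 1)                   # each row: 0 .. length-1
    distance = range_mat - range_mat.transpose(0, 1)          # signed key-minus-query offsets
    clipped = torch.clamp(distance, -max_relative_length, max_relative_length)
    indices = clipped + max_relative_length                   # shift into [0, 2 * max_relative_length]
    print(indices)

Each row of indices selects rows of a (2 * max_relative_length + 1, head_dim) embedding table, so offsets beyond the clip window share a single embedding.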
fairseq/models/speech_to_text/s2t_conformer.py

@@ -182,6 +182,13 @@ def s2t_conformer_s(args):
     base_architecture(args)
 
 
+@register_model_architecture("s2t_conformer", "s2t_conformer_s_relative")
+def s2t_conformer_s_relative(args):
+    args.max_relative_length = 20
+    args.k_only = True
+    s2t_conformer_s(args)
+
+
 @register_model_architecture("s2t_conformer", "s2t_conformer_xs")
 def s2t_conformer_xs(args):
     args.encoder_layers = getattr(args, "encoder_layers", 6)
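The "*_relative" presets only pin two hyperparameters before falling through to the existing base preset. A minimal sketch of that pattern, with a bare namespace standing in for fairseq's real args object (the base preset below is a hypothetical stand-in, not the toolkit's actual defaults):

    from types import SimpleNamespace

    def s2t_conformer_s(args):                       # stand-in for the existing base preset
        args.encoder_embed_dim = getattr(args, "encoder_embed_dim", 256)

    def s2t_conformer_s_relative(args):              # mirrors the function added above
        args.max_relative_length = 20                # clip window for relative positions
        args.k_only = True                           # key-side relative embeddings only
        s2t_conformer_s(args)

    args = SimpleNamespace()
    s2t_conformer_s_relative(args)
    print(args.max_relative_length, args.k_only, args.encoder_embed_dim)   # 20 True 256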
fairseq/models/speech_to_text/s2t_sate.py

@@ -357,6 +357,13 @@ def s2t_sate_s(args):
     base_architecture(args)
 
 
+@register_model_architecture("s2t_sate", "s2t_sate_s_relative")
+def s2t_sate_s_relative(args):
+    args.max_relative_length = 20
+    args.k_only = True
+    s2t_sate_s(args)
+
+
 @register_model_architecture("s2t_sate", "s2t_sate_xs")
 def s2t_sate_xs(args):
     args.encoder_layers = getattr(args, "encoder_layers", 6)
fairseq/models/speech_to_text/s2t_transformer.py

@@ -217,6 +217,10 @@ class S2TTransformerModel(FairseqEncoderDecoderModel):
             action="store_true",
             help="if True, dont scale embeddings",
         )
+        parser.add_argument('--max-relative-length', type=int, default=-1,
+                            help='the max relative length')
+        parser.add_argument('--k-only', default=False, action='store_true',
+                            help='select the relative mode to map relative position information')
         parser.add_argument(
             "--load-pretrained-encoder-from",
             type=str,

@@ -518,6 +522,9 @@ def base_architecture(args):
     args.no_scale_embedding = getattr(args, "no_scale_embedding", False)
     args.quant_noise_pq = getattr(args, "quant_noise_pq", 0)
+    args.max_relative_length = getattr(args, 'max_relative_length', -1)
+    args.k_only = getattr(args, 'k_only', True)
 
 
 @register_model_architecture("s2t_transformer", "s2t_transformer_s")
 def s2t_transformer_s(args):

@@ -529,6 +536,13 @@ def s2t_transformer_s(args):
     base_architecture(args)
 
 
+@register_model_architecture("s2t_transformer", "s2t_transformer_s_relative")
+def s2t_transformer_s_relative(args):
+    args.max_relative_length = 20
+    args.k_only = True
+    s2t_transformer_s(args)
+
+
 @register_model_architecture("s2t_transformer", "s2t_transformer_xs")
 def s2t_transformer_xs(args):
     args.encoder_layers = getattr(args, "encoder_layers", 6)
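For reference, the two new command-line options behave like ordinary argparse flags; a standalone sketch using the exact add_argument calls from the hunk above shows the defaults (-1 leaves relative attention disabled, --k-only is a store-true switch):

    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument('--max-relative-length', type=int, default=-1,
                        help='the max relative length')
    parser.add_argument('--k-only', default=False, action='store_true',
                        help='select the relative mode to map relative position information')

    print(parser.parse_args([]))
    # Namespace(max_relative_length=-1, k_only=False)
    print(parser.parse_args(['--max-relative-length', '20', '--k-only']))
    # Namespace(max_relative_length=20, k_only=True)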
fairseq/models/transformer.py

@@ -5,10 +5,11 @@
 import math
 from typing import Any, Dict, List, Optional, Tuple
+import logging
 
 import torch
 import torch.nn as nn
-from fairseq import utils
+from fairseq import checkpoint_utils, utils
 from fairseq.distributed import fsdp_wrap
 from fairseq.models import (
     FairseqEncoder,

@@ -35,6 +36,8 @@ from torch import Tensor
 DEFAULT_MAX_SOURCE_POSITIONS = 1024
 DEFAULT_MAX_TARGET_POSITIONS = 1024
 
+logger = logging.getLogger(__name__)
+
 
 @register_model("transformer")
 class TransformerModel(FairseqEncoderDecoderModel):

@@ -191,6 +194,35 @@ class TransformerModel(FairseqEncoderDecoderModel):
                             help='block size of quantization noise at training time')
         parser.add_argument('--quant-noise-scalar', type=float, metavar='D', default=0,
                             help='scalar quantization noise and scalar quantization at training time')
+        parser.add_argument('--max-relative-length', type=int, default=-1,
+                            help='the max relative length')
+        parser.add_argument('--k-only', default=False, action='store_true',
+                            help='select the relative mode to map relative position information')
+        # args for loading pre-trained models
+        parser.add_argument(
+            "--load-pretrained-encoder-from",
+            type=str,
+            metavar="STR",
+            help="model to take encoder weights from (for initialization)",
+        )
+        parser.add_argument(
+            "--load-pretrained-decoder-from",
+            type=str,
+            metavar="STR",
+            help="model to take decoder weights from (for initialization)",
+        )
+        parser.add_argument(
+            "--encoder-freeze-module",
+            type=str,
+            metavar="STR",
+            help="freeze the module of the encoder",
+        )
+        parser.add_argument(
+            "--decoder-freeze-module",
+            type=str,
+            metavar="STR",
+            help="freeze the module of the decoder",
+        )
         # fmt: on
 
     @classmethod

@@ -240,7 +272,15 @@ class TransformerModel(FairseqEncoderDecoderModel):
         if getattr(args, "offload_activations", False):
             args.checkpoint_activations = True  # offloading implies checkpointing
 
         encoder = cls.build_encoder(args, src_dict, encoder_embed_tokens)
+        if getattr(args, "encoder_freeze_module", None):
+            utils.freeze_parameters(encoder, args.encoder_freeze_module)
+            logging.info("freeze the encoder module: {}".format(args.encoder_freeze_module))
+
         decoder = cls.build_decoder(args, tgt_dict, decoder_embed_tokens)
+        if getattr(args, "decoder_freeze_module", None):
+            utils.freeze_parameters(decoder, args.decoder_freeze_module)
+            logging.info("freeze the decoder module: {}".format(args.decoder_freeze_module))
         if not args.share_all_embeddings:
             encoder = fsdp_wrap(encoder, min_num_params=1e8)
             decoder = fsdp_wrap(decoder, min_num_params=1e8)

@@ -260,17 +300,38 @@ class TransformerModel(FairseqEncoderDecoderModel):
     @classmethod
     def build_encoder(cls, args, src_dict, embed_tokens):
-        return TransformerEncoder(args, src_dict, embed_tokens)
+        encoder = TransformerEncoder(args, src_dict, embed_tokens)
+        if getattr(args, "load_pretrained_encoder_from", None):
+            logger.info(
+                f"loaded pretrained encoder from: "
+                f"{args.load_pretrained_encoder_from}"
+            )
+            encoder = checkpoint_utils.load_pretrained_component_from_model(
+                component=encoder, checkpoint=args.load_pretrained_encoder_from, strict=False
+            )
+        return encoder
 
     @classmethod
     def build_decoder(cls, args, tgt_dict, embed_tokens):
-        return TransformerDecoder(
+        decoder = TransformerDecoder(
             args,
             tgt_dict,
             embed_tokens,
             no_encoder_attn=getattr(args, "no_cross_attention", False),
         )
+        if getattr(args, "load_pretrained_decoder_from", None):
+            logger.info(
+                f"loaded pretrained decoder from: "
+                f"{args.load_pretrained_decoder_from}"
+            )
+            decoder = checkpoint_utils.load_pretrained_component_from_model(
+                component=decoder, checkpoint=args.load_pretrained_decoder_from, strict=False
+            )
+        return decoder
 
     # TorchScript doesn't support optional arguments with variable length (**kwargs).
     # Current workaround is to add union of all arguments in child classes.
     def forward(

@@ -1073,6 +1134,15 @@ def base_architecture(args):
     args.quant_noise_pq = getattr(args, "quant_noise_pq", 0)
     args.quant_noise_pq_block_size = getattr(args, "quant_noise_pq_block_size", 8)
     args.quant_noise_scalar = getattr(args, "quant_noise_scalar", 0)
+    args.max_relative_length = getattr(args, 'max_relative_length', -1)
+    args.k_only = getattr(args, 'k_only', True)
+
+
+@register_model_architecture("transformer", "transformer_relative")
+def transformer_rpr(args):
+    args.max_relative_length = 20
+    args.k_only = True
+    base_architecture(args)
 
 
 @register_model_architecture("transformer", "transformer_iwslt_de_en")
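The new --encoder-freeze-module / --decoder-freeze-module options rely on utils.freeze_parameters, which is not part of this diff. A minimal sketch of a freeze-by-name helper with the same intent (an assumption about its behaviour, not the repository's actual implementation):

    import torch.nn as nn

    def freeze_parameters(module: nn.Module, freeze_module_name: str) -> None:
        # Hypothetical stand-in: disable gradients for every parameter whose
        # dotted name contains the requested submodule name.
        for name, param in module.named_parameters():
            if freeze_module_name in name:
                param.requires_grad = False

    class ToyEncoder(nn.Module):
        def __init__(self):
            super().__init__()
            self.embed_tokens = nn.Embedding(100, 16)
            self.output_projection = nn.Linear(16, 16)

    encoder = ToyEncoder()
    freeze_parameters(encoder, "embed_tokens")
    print({name: p.requires_grad for name, p in encoder.named_parameters()})
    # {'embed_tokens.weight': False, 'output_projection.weight': True, 'output_projection.bias': True}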
fairseq/modules/__init__.py

@@ -21,6 +21,7 @@ from .grad_multiply import GradMultiply
 from .gumbel_vector_quantizer import GumbelVectorQuantizer
 from .kmeans_vector_quantizer import KmeansVectorQuantizer
 from .layer_drop import LayerDropModuleList
+from .layer_history import CreateLayerHistory
 from .layer_norm import Fp32LayerNorm, LayerNorm
 from .learned_positional_embedding import LearnedPositionalEmbedding
 from .lightweight_convolution import LightweightConv, LightweightConv1dTBC

@@ -28,6 +29,7 @@ from .linearized_convolution import LinearizedConvolution
 from .multihead_attention import MultiheadAttention
 from .positional_embedding import PositionalEmbedding
 from .rel_position_multihead_attention import RelPositionMultiheadAttention
+from .relative_multihead_attention import RelativeMultiheadAttention
 from .same_pad import SamePad
 from .scalar_bias import ScalarBias
 from .sinusoidal_positional_embedding import SinusoidalPositionalEmbedding

@@ -47,6 +49,7 @@ __all__ = [
     "ConformerEncoderLayer",
     "ConvolutionModule",
     "ConvTBC",
+    "CreateLayerHistory",
     "cross_entropy",
     "DownsampledMultiHeadAttention",
     "DynamicConv1dTBC",

@@ -69,6 +72,7 @@ __all__ = [
     "MultiheadAttention",
     "PositionalEmbedding",
     "RelPositionMultiheadAttention",
+    "RelativeMultiheadAttention",
     "SamePad",
     "ScalarBias",
     "SinusoidalPositionalEmbedding",
fairseq/modules/layer_history.py (new file, mode 100644)

import torch
import torch.nn as nn
from fairseq.models.transformer import LayerNorm
import queue
import numpy as np


def CreateLayerHistory(args, is_encoder):
    history_type = args.encoder_history_type if is_encoder else args.decoder_history_type
    if history_type is None:
        return None
    elif history_type == "residual":
        return ResidualLayerHistory(args, is_encoder)
    elif history_type == "dense":
        return DenseLayerHistory(args, is_encoder)
    elif history_type == "learnable_dense":
        return LearnableDenseLayerHistory(args, is_encoder)
    elif history_type == "learnable_dense_mask":
        return LearnableDenseMaskLayerHistory(args, is_encoder)
    elif history_type == "learnable_dense_nonorm":
        return LearnableDenseNoNormLayerHistory(args, is_encoder)
    elif history_type == "gru":
        return GruLayerHistory(args, is_encoder)
    else:
        raise ValueError


class BaseLayerHistory(nn.Module):

    def __init__(self, args, is_encoder):
        super(BaseLayerHistory, self).__init__()
        self.is_encoder = is_encoder
        self.normalize_before = args.encoder_normalize_before if is_encoder else args.decoder_normalize_before

        # the first layer (aka. embedding layer) does not have layer normalization
        layers = args.encoder_layers if is_encoder else args.decoder_layers
        dim = args.encoder_embed_dim if is_encoder else args.decoder_embed_dim
        self.layer_norms = nn.ModuleList(LayerNorm(dim) for _ in range(layers))

    def add(self, layer):
        raise NotImplemented

    def pop(self):
        raise NotImplemented

    def clean(self):
        raise NotImplemented


class ResidualLayerHistory(BaseLayerHistory):
    """
    x_n = x_{n-1} + y_{n-1}
    """

    def __init__(self, args, is_encoder):
        super(ResidualLayerHistory, self).__init__(args, is_encoder)
        self.count = 0
        self.x = None
        self.y = None

    def add(self, layer):
        if self.x is None:
            self.x = layer
            self.count += 1
            return
        self.count += 1
        if self.normalize_before:
            self.y = self.layer_norms[self.count - 2](layer)
        else:
            self.y = layer

    def pop(self):
        assert self.x is not None
        if self.y is None:
            return self.x
        ret = self.x + self.y
        if not self.normalize_before:
            ret = self.layer_norms[self.count - 2](ret)
        self.x = ret
        return ret

    def clean(self):
        self.x = None
        self.y = None
        self.count = 0


class DenseLayerHistory(BaseLayerHistory):
    """
    x_n = (x_1 + y_1 + y_2 + ... y_{n-1}) / n
    """

    def __init__(self, args, is_encoder):
        super(DenseLayerHistory, self).__init__(args, is_encoder)
        self.sum = None
        self.count = 0
        self.individuals = None  # store past individual value, used for windows_size > 0

        self.integration_type = getattr(args, 'encoder_integration_type', 'avg') if is_encoder else \
            getattr(args, 'decoder_integration_type', 'avg')
        # windows = 1 means not use residual connection
        self.windows_size = getattr(args, 'encoder_windows_size', -1) if is_encoder else \
            getattr(args, 'decoder_windows_size', -1)
        if self.windows_size > 0:
            assert self.windows_size <= (args.encoder_layers + 1) if is_encoder else (args.decoder_layers + 1)
            self.individuals = queue.Queue(self.windows_size)

    def add(self, layer):
        self.count += 1

        # first layer
        if self.sum is None:
            self.sum = layer
            if self.individuals is not None:
                self.individuals.put(layer)
            return

        # following layer
        if self.normalize_before:
            layer = self.layer_norms[self.count - 2](layer)

        self.sum = self.sum + layer
        if self.windows_size != -1 and self.count > self.windows_size:
            self.sum = self.sum - self.individuals.get()

        if self.individuals is not None:
            self.individuals.put(layer)

    def pop(self):
        assert self.sum is not None
        if self.integration_type == 'sum':
            ret = self.sum
        else:
            if self.windows_size == -1:
                ret = self.sum / self.count
            else:
                ret = self.sum / min(self.count, self.windows_size)
        if self.count == 1 or self.normalize_before:
            return ret
        return self.layer_norms[self.count - 2](ret)

    def clean(self):
        self.sum = None
        self.count = 0
        if self.individuals is not None:
            self.individuals.queue.clear()


class LearnableDenseLayerHistory(BaseLayerHistory):
    """
    x_n = (x_1 + y_1 + y_2 + ... y_{n-1}) / n
    """

    def __init__(self, args, is_encoder):
        super(LearnableDenseLayerHistory, self).__init__(args, is_encoder)
        self.sum = None
        self.count = 0
        self.layer_num = 1 + (args.encoder_layers if is_encoder else args.decoder_layers)
        self.weight = nn.Parameter(torch.Tensor(self.layer_num, self.layer_num).fill_(1.0).tril())
        self.weight.data = self.weight.data / self.weight.data.sum(1, keepdim=True)

    def extra_repr(self):
        return 'n_layers={layer_num}, '.format(**self.__dict__)

    def add(self, layer):
        self.count += 1

        # first layer
        if self.sum is None:
            self.sum = layer
            self.layers.append(layer)
            return

        # following layer
        if self.normalize_before:
            layer = self.layer_norms[self.count - 2](layer)

        self.layers.append(layer)

    def pop(self):
        assert len(self.layers) > 0
        ret = (torch.stack(self.layers, 0) * self.weight[self.count - 1, : self.count].view(-1, 1, 1, 1)).sum(0)
        if self.count == 1 or self.normalize_before:
            return ret
        return self.layer_norms[self.count - 2](ret)

    def clean(self):
        self.sum = None
        self.count = 0
        self.layers = []

    def get_loss(self):
        return (0.5 * (self.weight.sum(1) - 1.0) ** 2).mean()


class LearnableDenseMaskLayerHistory(BaseLayerHistory):
    """
    x_n = (x_1 + y_1 + y_2 + ... y_{n-1}) / n
    """

    def __init__(self, args, is_encoder):
        super(LearnableDenseMaskLayerHistory, self).__init__(args, is_encoder)
        self.sum = None
        self.count = 0
        self.layer_num = 1 + (args.encoder_layers if is_encoder else args.decoder_layers)
        if is_encoder:
            self.weight_mask = np.loadtxt("encoder_mask.txt", dtype=float, delimiter=' ')
        else:
            self.weight_mask = np.loadtxt("decoder_mask.txt", dtype=float, delimiter=' ')
        self.weight = nn.Parameter(torch.Tensor(self.layer_num, self.layer_num).fill_(1.0).tril())
        self.weight.data = self.weight.data / self.weight.data.sum(1, keepdim=True)

    def add(self, layer):
        self.count += 1

        # first layer
        if self.sum is None:
            self.sum = layer
            self.layers.append(layer)
            return

        # following layer
        if self.normalize_before:
            layer = self.layer_norms[self.count - 2](layer)

        self.layers.append(layer)

    def pop(self):
        assert len(self.layers) > 0
        ret = (torch.stack(self.layers, 0) * self.weight[self.count - 1, : self.count].view(-1, 1, 1, 1)).sum(0)
        if self.count == 1 or self.normalize_before:
            return ret
        return self.layer_norms[self.count - 2](ret)

    def clean(self):
        self.sum = None
        self.count = 0
        self.layers = []

    def get_loss(self):
        return (0.5 * (self.weight.sum(1) - 1.0) ** 2).mean()


class LearnableDenseNoNormLayerHistory(BaseLayerHistory):
    """
    x_n = (x_1 + y_1 + y_2 + ... y_{n-1}) / n
    """

    def __init__(self, args, is_encoder):
        super(LearnableDenseNoNormLayerHistory, self).__init__(args, is_encoder)
        self.sum = None
        self.count = 0
        self.layer_num = 1 + (args.encoder_layers if is_encoder else args.decoder_layers)
        self.weight = nn.Parameter(torch.Tensor(self.layer_num, self.layer_num).fill_(1.0).tril())
        self.weight.data = self.weight.data / self.weight.data.sum(1, keepdim=True)
        self.layers = []
        self.layer_norms = None

    def add(self, layer):
        self.count += 1

        # first layer
        if self.sum is None:
            self.sum = layer
            self.layers.append(layer)
            return
        self.layers.append(layer)

    def pop(self):
        assert len(self.layers) > 0
        ret = (torch.stack(self.layers, 0) * self.weight[self.count - 1, : self.count].view(-1, 1, 1, 1)).sum(0)
        if self.count == 1 or self.normalize_before:
            return ret
        return self.layer_norms[self.count - 2](ret)

    def clean(self):
        self.sum = None
        self.count = 0
        self.layers = []


class GruLayerHistory(BaseLayerHistory):
    """
    x_n = (x_1 + y_1 + y_2 + ... y_{n-1}) / n
    """

    def __init__(self, args, is_encoder):
        super(GruLayerHistory, self).__init__(args, is_encoder)
        self.count = 0
        self.gru = nn.GRUCell(args.encoder_embed_dim, args.encoder_embed_dim)
        self.gru_cells = []
        self.layer_norms = nn.ModuleList(LayerNorm(args.encoder_embed_dim) for _ in range(args.decoder_layers + 1))
        self.decoder_layers = args.decoder_layers

    def compute_gru(self, layer_output):
        if len(self.gru_cells) == 0:
            self.gru_cells.append(layer_output)
            return self.layer_norms[self.count](layer_output)

        self.count += 1
        prev_h = self.gru_cells[-1]
        L, B, H = layer_output.size()
        layer_output = torch.reshape(layer_output, (-1, H))
        prev_h = torch.reshape(prev_h, (-1, H))
        h = self.gru(layer_output, prev_h).view(L, B, H)
        self.gru_cells.append(h)
        if self.count != self.decoder_layers:
            return self.layer_norms[self.count](h)
        else:
            return None

    def clean(self):
        self.gru_cells = []
        self.count = 0
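None of the models touched by this commit instantiate CreateLayerHistory yet; the add/pop driving pattern below is an assumption based on how dense-connection ("DLCL"-style) encoders usually use such a history, sketched with a stand-in for the encoder layers (requires fairseq with this commit installed):

    from types import SimpleNamespace
    import torch
    from fairseq.modules import CreateLayerHistory

    args = SimpleNamespace(
        encoder_history_type="dense",      # one of: residual, dense, learnable_dense, ...
        encoder_normalize_before=True,
        encoder_layers=6,
        encoder_embed_dim=8,
    )
    history = CreateLayerHistory(args, is_encoder=True)

    x = torch.randn(10, 2, 8)              # (seq_len, batch, dim) embedding output
    history.add(x)
    for _ in range(args.encoder_layers):
        x = history.pop()                  # aggregated input for the next layer
        x = x + 0.0                        # stand-in for an actual encoder layer
        history.add(x)
    out = history.pop()
    history.clean()
    print(out.shape)                       # torch.Size([10, 2, 8])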
fairseq/modules/relative_multihead_attention.py (new file, mode 100644)

# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import math
from typing import Dict, Optional, Tuple

import torch
import torch.nn.functional as F
from fairseq import utils
from fairseq.modules.multihead_attention import MultiheadAttention
from torch import Tensor, nn
from torch.nn import Parameter


class RelativeMultiheadAttention(MultiheadAttention):
    """Multi-headed attention.

    See "Attention Is All You Need" for more details.
    """

    def __init__(
        self,
        embed_dim,
        num_heads,
        kdim=None,
        vdim=None,
        dropout=0.0,
        bias=True,
        add_bias_kv=False,
        add_zero_attn=False,
        self_attention=False,
        encoder_decoder_attention=False,
        q_noise=0.0,
        qn_block_size=8,
        max_relative_length=-1,
        k_only=True,
    ):
        super().__init__(
            embed_dim,
            num_heads,
            kdim,
            vdim,
            dropout,
            bias,
            add_bias_kv,
            add_zero_attn,
            self_attention,
            encoder_decoder_attention,
            q_noise,
            qn_block_size,
        )

        self.max_relative_length = max_relative_length
        self.k_only = k_only

        self.relative_position_keys = Parameter(torch.Tensor(2 * self.max_relative_length + 1, self.head_dim))
        if not self.k_only:
            self.relative_position_values = Parameter(torch.Tensor(2 * self.max_relative_length + 1, self.head_dim))

        nn.init.xavier_uniform_(self.relative_position_keys)
        if not self.k_only:
            nn.init.xavier_uniform_(self.relative_position_values)

    def forward(
        self,
        query,
        key: Optional[Tensor],
        value: Optional[Tensor],
        key_padding_mask: Optional[Tensor] = None,
        incremental_state: Optional[Dict[str, Dict[str, Optional[Tensor]]]] = None,
        need_weights: bool = True,
        static_kv: bool = False,
        attn_mask: Optional[Tensor] = None,
        before_softmax: bool = False,
        need_head_weights: bool = False,
    ) -> Tuple[Tensor, Optional[Tensor]]:
        """Input shape: Time x Batch x Channel

        Args:
            key_padding_mask (ByteTensor, optional): mask to exclude
                keys that are pads, of shape `(batch, src_len)`, where
                padding elements are indicated by 1s.
            need_weights (bool, optional): return the attention weights,
                averaged over heads (default: False).
            attn_mask (ByteTensor, optional): typically used to
                implement causal attention, where the mask prevents the
                attention from looking forward in time (default: None).
            before_softmax (bool, optional): return the raw attention
                weights and values before the attention softmax.
            need_head_weights (bool, optional): return the attention
                weights for each head. Implies *need_weights*. Default:
                return the average attention weights over all heads.
        """
        if need_head_weights:
            need_weights = True

        is_tpu = query.device.type == "xla"

        tgt_len, bsz, embed_dim = query.size()
        assert embed_dim == self.embed_dim
        assert list(query.size()) == [tgt_len, bsz, embed_dim]

        if incremental_state is not None:
            saved_state = self._get_input_buffer(incremental_state)
            if saved_state is not None and "prev_key" in saved_state:
                # previous time steps are cached - no need to recompute
                # key and value if they are static
                if static_kv:
                    assert self.encoder_decoder_attention and not self.self_attention
                    key = value = None
        else:
            saved_state = None

        if self.self_attention:
            q = self.q_proj(query)
            k = self.k_proj(query)
            v = self.v_proj(query)
        elif self.encoder_decoder_attention:
            # encoder-decoder attention
            q = self.q_proj(query)
            if key is None:
                assert value is None
                k = v = None
            else:
                k = self.k_proj(key)
                v = self.v_proj(key)
        else:
            assert key is not None and value is not None
            q = self.q_proj(query)
            k = self.k_proj(key)
            v = self.v_proj(value)
        q *= self.scaling

        if self.bias_k is not None:
            assert self.bias_v is not None
            k = torch.cat([k, self.bias_k.repeat(1, bsz, 1)])
            v = torch.cat([v, self.bias_v.repeat(1, bsz, 1)])
            if attn_mask is not None:
                attn_mask = torch.cat(
                    [attn_mask, attn_mask.new_zeros(attn_mask.size(0), 1)], dim=1
                )
            if key_padding_mask is not None:
                key_padding_mask = torch.cat(
                    [
                        key_padding_mask,
                        key_padding_mask.new_zeros(key_padding_mask.size(0), 1),
                    ],
                    dim=1,
                )

        q = (
            q.contiguous()
            .view(tgt_len, bsz * self.num_heads, self.head_dim)
            .transpose(0, 1)
        )
        if k is not None:
            k = (
                k.contiguous()
                .view(-1, bsz * self.num_heads, self.head_dim)
                .transpose(0, 1)
            )
        if v is not None:
            v = (
                v.contiguous()
                .view(-1, bsz * self.num_heads, self.head_dim)
                .transpose(0, 1)
            )

        if saved_state is not None:
            # saved states are stored with shape (bsz, num_heads, seq_len, head_dim)
            if "prev_key" in saved_state:
                _prev_key = saved_state["prev_key"]
                assert _prev_key is not None
                prev_key = _prev_key.view(bsz * self.num_heads, -1, self.head_dim)
                if static_kv:
                    k = prev_key
                else:
                    assert k is not None
                    k = torch.cat([prev_key, k], dim=1)
            if "prev_value" in saved_state:
                _prev_value = saved_state["prev_value"]
                assert _prev_value is not None
                prev_value = _prev_value.view(bsz * self.num_heads, -1, self.head_dim)
                if static_kv:
                    v = prev_value
                else:
                    assert v is not None
                    v = torch.cat([prev_value, v], dim=1)
            prev_key_padding_mask: Optional[Tensor] = None
            if "prev_key_padding_mask" in saved_state:
                prev_key_padding_mask = saved_state["prev_key_padding_mask"]
            assert k is not None and v is not None
            key_padding_mask = MultiheadAttention._append_prev_key_padding_mask(
                key_padding_mask=key_padding_mask,
                prev_key_padding_mask=prev_key_padding_mask,
                batch_size=bsz,
                src_len=k.size(1),
                static_kv=static_kv,
            )

            saved_state["prev_key"] = k.view(bsz, self.num_heads, -1, self.head_dim)
            saved_state["prev_value"] = v.view(bsz, self.num_heads, -1, self.head_dim)
            saved_state["prev_key_padding_mask"] = key_padding_mask
            # In this branch incremental_state is never None
            assert incremental_state is not None
            incremental_state = self._set_input_buffer(incremental_state, saved_state)
        assert k is not None
        src_len = k.size(1)

        # This is part of a workaround to get around fork/join parallelism
        # not supporting Optional types.
        if key_padding_mask is not None and key_padding_mask.dim() == 0:
            key_padding_mask = None

        if key_padding_mask is not None:
            assert key_padding_mask.size(0) == bsz
            assert key_padding_mask.size(1) == src_len

        if self.add_zero_attn:
            assert v is not None
            src_len += 1
            k = torch.cat([k, k.new_zeros((k.size(0), 1) + k.size()[2:])], dim=1)
            v = torch.cat([v, v.new_zeros((v.size(0), 1) + v.size()[2:])], dim=1)
            if attn_mask is not None:
                attn_mask = torch.cat(
                    [attn_mask, attn_mask.new_zeros(attn_mask.size(0), 1)], dim=1
                )
            if key_padding_mask is not None:
                key_padding_mask = torch.cat(
                    [
                        key_padding_mask,
                        torch.zeros(key_padding_mask.size(0), 1).type_as(
                            key_padding_mask
                        ),
                    ],
                    dim=1,
                )

        relative_positions_matrix = self._generate_relative_positions_matrix(
            src_len, self.max_relative_length, incremental_state
        )

        if self.k_only:
            relation_keys = F.embedding(relative_positions_matrix.long().cuda(), self.relative_position_keys)
        else:
            relation_keys = F.embedding(relative_positions_matrix.long().cuda(), self.relative_position_keys)
            relation_values = F.embedding(relative_positions_matrix.long().cuda(), self.relative_position_values)

        attn_weights = self._relative_attention_inner(q, k, relation_keys, transpose=True)
        assert list(attn_weights.size()) == [bsz * self.num_heads, tgt_len, src_len]

        attn_weights = torch.bmm(q, k.transpose(1, 2))
        attn_weights = self.apply_sparse_mask(attn_weights, tgt_len, src_len, bsz)

        assert list(attn_weights.size()) == [bsz * self.num_heads, tgt_len, src_len]

        if attn_mask is not None:
            attn_mask = attn_mask.unsqueeze(0)
            if self.onnx_trace:
                attn_mask = attn_mask.repeat(attn_weights.size(0), 1, 1)
            attn_weights += attn_mask

        if key_padding_mask is not None:
            # don't attend to padding symbols
            attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len)
            if not is_tpu:
                attn_weights = attn_weights.masked_fill(
                    key_padding_mask.unsqueeze(1).unsqueeze(2).to(torch.bool),
                    float("-inf"),
                )
            else:
                attn_weights = attn_weights.transpose(0, 2)
                attn_weights = attn_weights.masked_fill(key_padding_mask, float("-inf"))
                attn_weights = attn_weights.transpose(0, 2)
            attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)

        if before_softmax:
            return attn_weights, v

        attn_weights_float = utils.softmax(
            attn_weights, dim=-1, onnx_trace=self.onnx_trace
        )
        attn_weights = attn_weights_float.type_as(attn_weights)
        attn_probs = self.dropout_module(attn_weights)

        assert v is not None
        # key only mode
        if self.k_only:
            attn = torch.bmm(attn_probs, v)
        # original implementation
        else:
            attn = self._relative_attention_inner(attn_probs, v, relation_values, transpose=False)
        # attn = torch.bmm(attn_probs, v)
        assert list(attn.size()) == [bsz * self.num_heads, tgt_len, self.head_dim]
        if self.onnx_trace and attn.size(1) == 1:
            # when ONNX tracing a single decoder step (sequence length == 1)
            # the transpose is a no-op copy before view, thus unnecessary
            attn = attn.contiguous().view(tgt_len, bsz, embed_dim)
        else:
            attn = attn.transpose(0, 1).contiguous().view(tgt_len, bsz, embed_dim)
        attn = self.out_proj(attn)
        attn_weights: Optional[Tensor] = None
        if need_weights:
            attn_weights = attn_weights_float.view(
                bsz, self.num_heads, tgt_len, src_len
            ).transpose(1, 0)
            if not need_head_weights:
                # average attention weights over heads
                attn_weights = attn_weights.mean(dim=0)

        return attn, attn_weights

    def _generate_relative_positions_matrix(self, length, max_relative_length, incremental_state):
        if not incremental_state:
            # training process
            range_vec = torch.arange(length)
            range_mat = range_vec.repeat(length, 1)
            distance_mat = range_mat - range_mat.transpose(0, 1)
        else:
            distance_mat = torch.range(-length + 1, 0).view(1, -1)

        distance_mat_clipped = torch.clamp(distance_mat, -max_relative_length, max_relative_length)

        # position difference.
        final_mat = distance_mat_clipped + max_relative_length

        return final_mat

    def _relative_attention_inner(self, x, y, z, transpose=True):
        """Relative position-aware dot-product attention inner calculation.

        This batches matrix multiply calculations to avoid unnecessary broadcasting.

        Args:
            x: Tensor with shape [batch_size*heads, length, length or depth].
            y: Tensor with shape [batch_size*heads, length, depth].
            z: Tensor with shape [length, length, depth].
            transpose: Whether to transpose inner matrices of y and z. Should be true if
                last dimension of x is depth, not length.

        Returns:
            A Tensor with shape [batch_size*heads, length, length or depth].

        wq: this function actually does 'X(Y+Z)', where Z is a vector,
        but factors the formula above as: 'XY + XZ'
        """
        batch_size_mul_head = x.size()[0]
        length = z.size()[0]
        # print(batch_size_mul_head, length)
        # xy_matmul is [batch_size*heads, length, length or depth]
        if transpose:
            y = y.transpose(1, 2)
        xy_matmul = torch.bmm(x, y)
        # x_t is [length, batch_size * heads, length or depth]
        x_t = x.transpose(0, 1)
        # x_tz_matmul is [length, batch_size * heads, length or depth]
        if transpose:
            z = z.transpose(1, 2)
        x_tz_matmul = torch.bmm(x_t, z).transpose(0, 1).view(batch_size_mul_head, length, -1)

        attn = xy_matmul + x_tz_matmul

        return attn
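_relative_attention_inner batches the content term and the relative-position term as "XY + XZ". A self-contained shape check of that decomposition in plain PyTorch (toy sizes, illustrative names only):

    import torch

    bsz_heads, length, depth = 16, 5, 64
    q = torch.randn(bsz_heads, length, depth)                 # queries, already scaled
    k = torch.randn(bsz_heads, length, depth)                 # keys
    rel_k = torch.randn(length, length, depth)                # per-pair relative key embeddings

    content_scores = torch.bmm(q, k.transpose(1, 2))          # "XY": the usual dot-product term
    position_scores = (
        torch.bmm(q.transpose(0, 1), rel_k.transpose(1, 2))   # "XZ": query x relative-key term
        .transpose(0, 1)
        .view(bsz_heads, length, -1)
    )
    scores = content_scores + position_scores
    print(scores.shape)                                       # torch.Size([16, 5, 5])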
fairseq/modules/transformer_layer.py

@@ -8,7 +8,12 @@ from typing import Dict, List, Optional
 import torch
 import torch.nn as nn
 from fairseq import utils
-from fairseq.modules import LayerNorm, MultiheadAttention, RelPositionMultiheadAttention
+from fairseq.modules import (
+    LayerNorm,
+    MultiheadAttention,
+    RelPositionMultiheadAttention,
+    RelativeMultiheadAttention
+)
 from fairseq.modules.fairseq_dropout import FairseqDropout
 from fairseq.modules.quant_noise import quant_noise
 from torch import Tensor

@@ -82,6 +87,16 @@ class TransformerEncoderLayer(nn.Module):
             attn_func = MultiheadAttention
         elif self.attn_type == "rel_selfattn":
             attn_func = RelPositionMultiheadAttention
+        elif self.attn_type == "relative" or getattr(args, "max_relative_length", -1) != -1:
+            return RelativeMultiheadAttention(
+                embed_dim,
+                args.encoder_attention_heads,
+                dropout=args.attention_dropout,
+                self_attention=True,
+                q_noise=self.quant_noise,
+                qn_block_size=self.quant_noise_block_size,
+                max_relative_length=args.max_relative_length,
+            )
         else:
             print("The attention type %s is not supported!" % self.attn_type)
             exit(1)

@@ -277,6 +292,16 @@ class TransformerDecoderLayer(nn.Module):
             attn_func = MultiheadAttention
         elif self.attn_type == "rel_selfattn":
             attn_func = RelPositionMultiheadAttention
+        elif self.attn_type == "relative" or getattr(args, "max_relative_length", -1) != -1:
+            return RelativeMultiheadAttention(
+                embed_dim,
+                args.encoder_attention_heads,
+                dropout=args.attention_dropout,
+                self_attention=True,
+                q_noise=self.quant_noise,
+                qn_block_size=self.quant_noise_block_size,
+                max_relative_length=args.max_relative_length,
+            )
        else:
             print("The attention type %s is not supported!" % self.attn_type)
             exit(1)
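Putting it together, any encoder or decoder layer whose args.max_relative_length is not -1 now builds a RelativeMultiheadAttention instead of the vanilla module. A small instantiation sketch (requires fairseq with this commit installed; the forward pass needs CUDA because the committed code calls .cuda() on the relative-position index matrix):

    import torch
    from fairseq.modules import RelativeMultiheadAttention

    attn = RelativeMultiheadAttention(
        embed_dim=512,
        num_heads=8,
        dropout=0.1,
        self_attention=True,
        max_relative_length=20,   # clip window; table holds 2 * 20 + 1 = 41 relative offsets
        k_only=True,              # key-side relative embeddings only
    )
    print(attn.relative_position_keys.shape)   # torch.Size([41, 64]); 64 = 512 / 8 head_dim

    if torch.cuda.is_available():
        x = torch.randn(7, 2, 512).cuda()      # (tgt_len, batch, embed_dim)
        attn = attn.cuda()
        out, weights = attn(x, x, x)
        print(out.shape)                       # torch.Size([7, 2, 512])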