Commit 922ef3d9 authored Mar 24, 2021 by xuchen

add the implementation of conformer

parent 0150d9ac
Showing 9 changed files with 623 additions and 73 deletions.
fairseq/models/speech_to_text/__init__.py             +1    -0
fairseq/models/speech_to_text/s2t_conformer.py         +227  -0
fairseq/models/speech_to_text/s2t_transformer.py       +3    -2
fairseq/modules/__init__.py                            +4    -0
fairseq/modules/conformer_layer.py                     +236  -0
fairseq/modules/convolution.py                         +79   -0
fairseq/modules/positional_embedding.py                +9    -2
fairseq/modules/rel_position_multihead_attention.py    +30   -69
fairseq/modules/sinusoidal_positional_embedding.py     +34   -0
fairseq/models/speech_to_text/__init__.py

```diff
@@ -6,3 +6,4 @@
 from .berard import *  # noqa
 from .convtransformer import *  # noqa
 from .s2t_transformer import *  # noqa
+from .s2t_conformer import *  # noqa
```
fairseq/models/speech_to_text/s2t_conformer.py (new file, mode 100644)

```python
#!/usr/bin/env python3

import logging

import torch.nn as nn

from fairseq import checkpoint_utils
from fairseq.data.data_utils import lengths_to_padding_mask
from fairseq.models import (
    register_model,
    register_model_architecture,
)
from fairseq.models.speech_to_text import S2TTransformerModel, S2TTransformerEncoder
from fairseq.modules import (
    ConformerEncoderLayer,
)

logger = logging.getLogger(__name__)


@register_model("s2t_conformer")
class S2TConformerModel(S2TTransformerModel):
    """Adapted Transformer model (https://arxiv.org/abs/1706.03762) for
    speech-to-text tasks. The Transformer encoder/decoder remains the same.
    A trainable input subsampler is prepended to the Transformer encoder to
    project inputs into the encoder dimension as well as downsample input
    sequence for computational efficiency."""

    def __init__(self, encoder, decoder):
        super().__init__(encoder, decoder)

    @staticmethod
    def add_args(parser):
        """Add model-specific arguments to the parser."""
        S2TTransformerModel.add_args(parser)
        parser.add_argument(
            "--macaron-style",
            default=False,
            type=bool,
            help="Whether to use macaron style for positionwise layer",
        )
        # Attention
        parser.add_argument(
            "--zero-triu",
            default=False,
            type=bool,
            help="If true, zero the upper triangular part of attention matrix.",
        )
        # Relative positional encoding
        parser.add_argument(
            "--rel-pos-type",
            type=str,
            default="legacy",
            choices=["legacy", "latest"],
            help="Whether to use the latest relative positional encoding or the legacy one. "
            "The legacy relative positional encoding will be deprecated in the future. "
            "More details can be found in https://github.com/espnet/espnet/pull/2816.",
        )
        # CNN module
        parser.add_argument(
            "--use-cnn-module",
            default=False,
            type=bool,
            help="Use convolution module or not",
        )
        parser.add_argument(
            "--cnn-module-kernel",
            default=31,
            type=int,
            help="Kernel size of convolution module.",
        )

    @classmethod
    def build_encoder(cls, args, task=None, embed_tokens=None):
        encoder = S2TConformerEncoder(args, task, embed_tokens)
        if getattr(args, "load_pretrained_encoder_from", None):
            encoder = checkpoint_utils.load_pretrained_component_from_model(
                component=encoder, checkpoint=args.load_pretrained_encoder_from
            )
            logger.info(
                f"loaded pretrained encoder from: "
                f"{args.load_pretrained_encoder_from}"
            )
        return encoder


class S2TConformerEncoder(S2TTransformerEncoder):
    """Speech-to-text Conformer encoder that consists of input subsampler and
    Transformer encoder."""

    def __init__(self, args, task=None, embed_tokens=None):
        super().__init__(args, task, embed_tokens)
        self.transformer_layers = nn.ModuleList(
            [ConformerEncoderLayer(args) for _ in range(args.encoder_layers)]
        )

    def forward(self, src_tokens, src_lengths):
        x, input_lengths = self.subsample(src_tokens, src_lengths)
        x = self.embed_scale * x

        encoder_padding_mask = lengths_to_padding_mask(input_lengths)
        positions = self.embed_positions(encoder_padding_mask).transpose(0, 1)
        if self.attn_type != "rel_selfattn":
            x += positions
        x = self.dropout_module(x)
        positions = self.dropout_module(positions)

        for layer in self.transformer_layers:
            x = layer(x, encoder_padding_mask, pos_emb=positions)

        if self.layer_norm is not None:
            x = self.layer_norm(x)

        return {
            "encoder_out": [x],  # T x B x C
            "encoder_padding_mask": [encoder_padding_mask],  # B x T
            "encoder_embedding": [],  # B x T x C
            "encoder_states": [],  # List[T x B x C]
            "src_tokens": [],
            "src_lengths": [],
        }


@register_model_architecture(model_name="s2t_conformer", arch_name="s2t_conformer")
def base_architecture(args):
    # Convolutional subsampler
    args.conv_kernel_sizes = getattr(args, "conv_kernel_sizes", "5,5")
    args.conv_channels = getattr(args, "conv_channels", 1024)
    # Conformer
    args.macaron_style = getattr(args, "macaron_style", True)
    args.use_cnn_module = getattr(args, "use_cnn_module", True)
    args.cnn_module_kernel = getattr(args, "cnn_module_kernel", 31)
    # Transformer
    args.encoder_embed_dim = getattr(args, "encoder_embed_dim", 512)
    args.encoder_ffn_embed_dim = getattr(args, "encoder_ffn_embed_dim", 2048)
    args.encoder_layers = getattr(args, "encoder_layers", 12)
    args.encoder_attention_type = getattr(args, "encoder_attention_type", "selfattn")
    args.encoder_attention_heads = getattr(args, "encoder_attention_heads", 8)
    args.encoder_normalize_before = getattr(args, "encoder_normalize_before", True)
    args.decoder_embed_dim = getattr(args, "decoder_embed_dim", args.encoder_embed_dim)
    args.decoder_ffn_embed_dim = getattr(
        args, "decoder_ffn_embed_dim", args.encoder_ffn_embed_dim
    )
    args.decoder_layers = getattr(args, "decoder_layers", 6)
    args.decoder_attention_type = getattr(args, "decoder_attention_type", "selfattn")
    args.decoder_attention_heads = getattr(args, "decoder_attention_heads", 8)
    args.decoder_normalize_before = getattr(args, "decoder_normalize_before", True)
    args.decoder_learned_pos = getattr(args, "decoder_learned_pos", False)
    args.dropout = getattr(args, "dropout", 0.1)
    args.attention_dropout = getattr(args, "attention_dropout", args.dropout)
    args.activation_dropout = getattr(args, "activation_dropout", args.dropout)
    args.activation_fn = getattr(args, "activation_fn", "relu")
    args.adaptive_softmax_cutoff = getattr(args, "adaptive_softmax_cutoff", None)
    args.adaptive_softmax_dropout = getattr(args, "adaptive_softmax_dropout", 0)
    args.share_decoder_input_output_embed = getattr(
        args, "share_decoder_input_output_embed", False
    )
    args.no_token_positional_embeddings = getattr(
        args, "no_token_positional_embeddings", False
    )
    args.adaptive_input = getattr(args, "adaptive_input", False)
    args.decoder_layerdrop = getattr(args, "decoder_layerdrop", 0.0)
    args.decoder_output_dim = getattr(args, "decoder_output_dim", args.decoder_embed_dim)
    args.decoder_input_dim = getattr(args, "decoder_input_dim", args.decoder_embed_dim)
    args.no_scale_embedding = getattr(args, "no_scale_embedding", False)
    args.quant_noise_pq = getattr(args, "quant_noise_pq", 0)


@register_model_architecture("s2t_conformer", "s2t_conformer_s")
def s2t_conformer_s(args):
    args.encoder_embed_dim = getattr(args, "encoder_embed_dim", 256)
    args.encoder_ffn_embed_dim = getattr(args, "encoder_ffn_embed_dim", 256 * 8)
    args.encoder_attention_heads = getattr(args, "encoder_attention_heads", 4)
    args.decoder_attention_heads = getattr(args, "decoder_attention_heads", 4)
    args.dropout = getattr(args, "dropout", 0.1)
    base_architecture(args)


@register_model_architecture("s2t_conformer", "s2t_conformer_xs")
def s2t_conformer_xs(args):
    args.encoder_layers = getattr(args, "encoder_layers", 6)
    args.decoder_layers = getattr(args, "decoder_layers", 3)
    args.encoder_ffn_embed_dim = getattr(args, "encoder_ffn_embed_dim", 256 * 4)
    args.dropout = getattr(args, "dropout", 0.3)
    s2t_conformer_s(args)


@register_model_architecture("s2t_conformer", "s2t_conformer_sp")
def s2t_conformer_sp(args):
    args.encoder_layers = getattr(args, "encoder_layers", 16)
    s2t_conformer_s(args)


@register_model_architecture("s2t_conformer", "s2t_conformer_m")
def s2t_conformer_m(args):
    args.encoder_embed_dim = getattr(args, "encoder_embed_dim", 512)
    args.encoder_ffn_embed_dim = getattr(args, "encoder_ffn_embed_dim", 512 * 4)
    args.encoder_attention_heads = getattr(args, "encoder_attention_heads", 8)
    args.decoder_attention_heads = getattr(args, "decoder_attention_heads", 8)
    args.dropout = getattr(args, "dropout", 0.15)
    base_architecture(args)


@register_model_architecture("s2t_conformer", "s2t_conformer_mp")
def s2t_conformer_mp(args):
    args.encoder_layers = getattr(args, "encoder_layers", 16)
    s2t_conformer_m(args)


@register_model_architecture("s2t_conformer", "s2t_conformer_l")
def s2t_conformer_l(args):
    args.encoder_embed_dim = getattr(args, "encoder_embed_dim", 1024)
    args.encoder_ffn_embed_dim = getattr(args, "encoder_ffn_embed_dim", 1024 * 4)
    args.encoder_attention_heads = getattr(args, "encoder_attention_heads", 16)
    args.decoder_attention_heads = getattr(args, "decoder_attention_heads", 16)
    args.dropout = getattr(args, "dropout", 0.2)
    base_architecture(args)


@register_model_architecture("s2t_conformer", "s2t_conformer_lp")
def s2t_conformer_lp(args):
    args.encoder_layers = getattr(args, "encoder_layers", 16)
    s2t_conformer_l(args)
```
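As a quick illustration of how these registered architectures compose (not part of the commit, and the import path is an assumption about how the module would be loaded): every `getattr(args, name, default)` above only fills in values that are still missing, so user-supplied settings always win over the architecture defaults.

```python
# Minimal sketch, assuming this branch of fairseq is installed.
from argparse import Namespace
from fairseq.models.speech_to_text.s2t_conformer import s2t_conformer_s

args = Namespace(encoder_layers=16)   # user-specified value

s2t_conformer_s(args)                 # applies the small config, then base_architecture(args)

print(args.encoder_layers)            # 16   (kept: getattr found an existing value)
print(args.encoder_embed_dim)         # 256  (filled in by s2t_conformer_s)
print(args.macaron_style)             # True (filled in by base_architecture)
print(args.cnn_module_kernel)         # 31   (filled in by base_architecture)
```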
fairseq/models/speech_to_text/s2t_transformer.py

```diff
@@ -307,8 +307,9 @@ class S2TTransformerEncoder(FairseqEncoder):
             [int(k) for k in args.conv_kernel_sizes.split(",")],
         )
+        self.attn_type = getattr(args, "encoder_attention_type", "selfattn")
         self.embed_positions = PositionalEmbedding(
-            args.max_source_positions, args.encoder_embed_dim, self.padding_idx
+            args.max_source_positions, args.encoder_embed_dim, self.padding_idx, pos_emb_type=self.attn_type
         )
         self.transformer_layers = nn.ModuleList(
@@ -319,7 +320,6 @@ class S2TTransformerEncoder(FairseqEncoder):
         else:
             self.layer_norm = None
-        self.attn_type = getattr(args, "encoder_attention_type", "selfattn")
         self.use_ctc = ("ctc" in getattr(args, "criterion", False)) and \
                        (getattr(args, "ctc_weight", False) > 0)
         if self.use_ctc:
@@ -348,6 +348,7 @@ class S2TTransformerEncoder(FairseqEncoder):
         if self.attn_type != "rel_selfattn":
             x += positions
         x = self.dropout_module(x)
+        positions = self.dropout_module(positions)

         for layer in self.transformer_layers:
             x = layer(x, encoder_padding_mask, pos_emb=positions)
```
fairseq/modules/__init__.py

```diff
@@ -8,6 +8,7 @@ from .adaptive_input import AdaptiveInput
 from .adaptive_softmax import AdaptiveSoftmax
 from .beamable_mm import BeamableMM
 from .character_token_embedder import CharacterTokenEmbedder
+from .convolution import ConvolutionModule
 from .conv_tbc import ConvTBC
 from .cross_entropy import cross_entropy
 from .downsampled_multihead_attention import DownsampledMultiHeadAttention
@@ -36,12 +37,15 @@ from .transpose_last import TransposeLast
 from .unfold import unfold1d
 from .transformer_layer import TransformerDecoderLayer, TransformerEncoderLayer
 from .vggblock import VGGBlock
+from .conformer_layer import ConformerEncoderLayer

 __all__ = [
     "AdaptiveInput",
     "AdaptiveSoftmax",
     "BeamableMM",
     "CharacterTokenEmbedder",
+    "ConformerEncoderLayer",
+    "ConvolutionModule",
     "ConvTBC",
     "cross_entropy",
     "DownsampledMultiHeadAttention",
```
fairseq/modules/conformer_layer.py (new file, mode 100644)

```python
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

from typing import Optional

import torch
import torch.nn as nn

from fairseq import utils
from fairseq.modules import (
    LayerNorm,
    MultiheadAttention,
    RelPositionMultiheadAttention,
    ConvolutionModule,
)
from fairseq.modules.fairseq_dropout import FairseqDropout
from fairseq.modules.quant_noise import quant_noise
from torch import Tensor


class ConformerEncoderLayer(nn.Module):
    """Encoder layer block.

    In the original paper each operation (multi-head attention or FFN) is
    postprocessed with: `dropout -> add residual -> layernorm`. In the
    tensor2tensor code they suggest that learning is more robust when
    preprocessing each layer with layernorm and postprocessing with:
    `dropout -> add residual`. We default to the approach in the paper, but the
    tensor2tensor approach can be enabled by setting
    *args.encoder_normalize_before* to ``True``.

    Args:
        args (argparse.Namespace): parsed command-line arguments
    """

    def __init__(self, args):
        super().__init__()
        self.args = args
        self.embed_dim = args.encoder_embed_dim
        self.quant_noise = getattr(args, "quant_noise_pq", 0)
        self.quant_noise_block_size = getattr(args, "quant_noise_pq_block_size", 8) or 8
        self.attn_type = getattr(args, "encoder_attention_type", "selfattn")
        self.self_attn = self.build_self_attention(self.embed_dim, args)
        self.self_attn_layer_norm = LayerNorm(self.embed_dim)
        self.dropout_module = FairseqDropout(
            args.dropout, module_name=self.__class__.__name__
        )
        self.activation_fn = utils.get_activation_fn(
            activation=getattr(args, "activation_fn", "relu") or "relu"
        )
        activation_dropout_p = getattr(args, "activation_dropout", 0) or 0
        if activation_dropout_p == 0:
            # for backwards compatibility with models that use args.relu_dropout
            activation_dropout_p = getattr(args, "relu_dropout", 0) or 0
        self.activation_dropout_module = FairseqDropout(
            float(activation_dropout_p), module_name=self.__class__.__name__
        )

        # macaron-style feed-forward module placed before self-attention
        if args.macaron_style:
            self.macaron_fc1 = self.build_fc1(
                self.embed_dim,
                args.encoder_ffn_embed_dim,
                self.quant_noise,
                self.quant_noise_block_size,
            )
            self.macaron_fc2 = self.build_fc2(
                args.encoder_ffn_embed_dim,
                self.embed_dim,
                self.quant_noise,
                self.quant_noise_block_size,
            )
            self.macaron_norm = LayerNorm(self.embed_dim)
            self.ff_scale = 0.5
        else:
            self.macaron_fc1 = None
            self.macaron_fc2 = None
            self.macaron_norm = None
            self.ff_scale = 1.0

        # convolution module
        if args.use_cnn_module:
            self.norm_conv = LayerNorm(self.embed_dim)
            self.conv_module = ConvolutionModule(
                self.embed_dim, args.cnn_module_kernel, self.activation_fn
            )
            self.norm_final = LayerNorm(self.embed_dim)
        else:
            self.norm_conv = None
            self.conv_module = None
            self.norm_final = None

        self.normalize_before = args.encoder_normalize_before
        self.fc1 = self.build_fc1(
            self.embed_dim,
            args.encoder_ffn_embed_dim,
            self.quant_noise,
            self.quant_noise_block_size,
        )
        self.fc2 = self.build_fc2(
            args.encoder_ffn_embed_dim,
            self.embed_dim,
            self.quant_noise,
            self.quant_noise_block_size,
        )
        self.ff_norm = LayerNorm(self.embed_dim)

    def build_fc1(self, input_dim, output_dim, q_noise, qn_block_size):
        return quant_noise(
            nn.Linear(input_dim, output_dim), p=q_noise, block_size=qn_block_size
        )

    def build_fc2(self, input_dim, output_dim, q_noise, qn_block_size):
        return quant_noise(
            nn.Linear(input_dim, output_dim), p=q_noise, block_size=qn_block_size
        )

    def build_self_attention(self, embed_dim, args):
        if self.attn_type == "selfattn":
            attn_func = MultiheadAttention
        elif self.attn_type == "rel_selfattn":
            attn_func = RelPositionMultiheadAttention
        else:
            attn_func = MultiheadAttention
            print("The attention type %s is not supported!" % self.attn_type)
            exit(1)
        return attn_func(
            embed_dim,
            args.encoder_attention_heads,
            dropout=args.attention_dropout,
            self_attention=True,
            q_noise=self.quant_noise,
            qn_block_size=self.quant_noise_block_size,
        )

    def residual_connection(self, x, residual):
        return residual + x

    def upgrade_state_dict_named(self, state_dict, name):
        """
        Rename layer norm states from `...layer_norms.0.weight` to
        `...self_attn_layer_norm.weight` and `...layer_norms.1.weight` to
        `...final_layer_norm.weight`
        """
        layer_norm_map = {"0": "self_attn_layer_norm", "1": "final_layer_norm"}
        for old, new in layer_norm_map.items():
            for m in ("weight", "bias"):
                k = "{}.layer_norms.{}.{}".format(name, old, m)
                if k in state_dict:
                    state_dict["{}.{}.{}".format(name, new, m)] = state_dict[k]
                    del state_dict[k]

    def forward(
        self,
        x,
        encoder_padding_mask: Optional[Tensor],
        attn_mask: Optional[Tensor] = None,
        pos_emb: Optional[Tensor] = None,
    ):
        """
        Args:
            x (Tensor): input to the layer of shape `(seq_len, batch, embed_dim)`
            encoder_padding_mask (ByteTensor): binary ByteTensor of shape
                `(batch, seq_len)` where padding elements are indicated by ``1``.
            attn_mask (ByteTensor): binary tensor of shape `(tgt_len, src_len)`,
                where `tgt_len` is the length of output and `src_len` is the
                length of input, though here both are equal to `seq_len`.
                `attn_mask[tgt_i, src_j] = 1` means that when calculating the
                embedding for `tgt_i`, we exclude (mask out) `src_j`. This is
                useful for strided self-attention.
            pos_emb (Tensor): the position embedding for relative position encoding

        Returns:
            encoded output of shape `(seq_len, batch, embed_dim)`
        """
        # anything in original attn_mask = 1, becomes -1e8
        # anything in original attn_mask = 0, becomes 0
        # Note that we cannot use -inf here, because at some edge cases,
        # the attention weight (before softmax) for some padded element in query
        # will become -inf, which results in NaN in model parameters
        if attn_mask is not None:
            attn_mask = attn_mask.masked_fill(attn_mask.to(torch.bool), -1e8)

        # whether to use macaron style
        if self.macaron_norm is not None:
            residual = x
            if self.normalize_before:
                x = self.macaron_norm(x)
            x = self.macaron_fc2(
                self.activation_dropout_module(self.activation_fn(self.macaron_fc1(x)))
            )
            x = residual + self.ff_scale * self.dropout_module(x)
            if not self.normalize_before:
                x = self.macaron_norm(x)

        # multi-headed self-attention module
        residual = x
        if self.normalize_before:
            x = self.self_attn_layer_norm(x)
        if self.attn_type == "rel_selfattn":
            assert pos_emb is not None, "Positions is necessary for RPE!"
            x, _ = self.self_attn(
                query=x,
                key=x,
                value=x,
                key_padding_mask=encoder_padding_mask,
                need_weights=False,
                attn_mask=attn_mask,
                pos_emb=pos_emb,
            )
        else:
            x, _ = self.self_attn(
                query=x,
                key=x,
                value=x,
                key_padding_mask=encoder_padding_mask,
                need_weights=False,
                attn_mask=attn_mask,
            )
        x = self.dropout_module(x)
        x = self.residual_connection(x, residual)
        if not self.normalize_before:
            x = self.self_attn_layer_norm(x)

        # convolution module
        if self.conv_module is not None:
            residual = x
            if self.normalize_before:
                x = self.norm_conv(x)
            x = residual + self.dropout_module(self.conv_module(x))
            if not self.normalize_before:
                x = self.norm_conv(x)

        # feed-forward module
        residual = x
        if self.normalize_before:
            x = self.ff_norm(x)
        x = self.activation_fn(self.fc1(x))
        x = self.activation_dropout_module(x)
        x = self.fc2(x)
        x = self.dropout_module(x)
        x = self.residual_connection(x, residual)
        if not self.normalize_before:
            x = self.ff_norm(x)

        if self.conv_module is not None:
            x = self.norm_final(x)

        return x
```
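A minimal forward-pass sketch for a single layer, assuming fairseq with this commit installed; the configuration values are illustrative only and simply mirror the options consumed in `__init__` above. Plain self-attention is chosen so that no relative position embedding has to be passed.

```python
import torch
from argparse import Namespace
from fairseq.modules import ConformerEncoderLayer  # exported by this commit

args = Namespace(
    encoder_embed_dim=256,
    encoder_ffn_embed_dim=2048,
    encoder_attention_heads=4,
    encoder_attention_type="selfattn",   # "rel_selfattn" would also require pos_emb in forward()
    encoder_normalize_before=True,
    macaron_style=True,
    use_cnn_module=True,
    cnn_module_kernel=31,
    dropout=0.1,
    attention_dropout=0.1,
    activation_dropout=0.1,
    activation_fn="relu",
)
layer = ConformerEncoderLayer(args)

x = torch.randn(50, 2, 256)                           # (seq_len, batch, embed_dim)
padding_mask = torch.zeros(2, 50, dtype=torch.bool)   # no padded positions
y = layer(x, padding_mask)
print(y.shape)                                        # torch.Size([50, 2, 256])
```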
fairseq/modules/convolution.py (new file, mode 100644)

```python
#!/usr/bin/env python3
# -*- coding: utf-8 -*-

# Copyright 2020 Johns Hopkins University (Shinji Watanabe)
#                Northwestern Polytechnical University (Pengcheng Guo)
#  Apache 2.0  (http://www.apache.org/licenses/LICENSE-2.0)

"""ConvolutionModule definition."""

from torch import nn


class ConvolutionModule(nn.Module):
    """ConvolutionModule in Conformer model.

    Args:
        channels (int): The number of channels of conv layers.
        kernel_size (int): Kernel size of conv layers.
    """

    def __init__(self, channels, kernel_size, activation=nn.ReLU(), bias=True):
        """Construct a ConvolutionModule object."""
        super(ConvolutionModule, self).__init__()
        # kernel_size should be an odd number for 'SAME' padding
        assert (kernel_size - 1) % 2 == 0

        self.pointwise_conv1 = nn.Conv1d(
            channels,
            2 * channels,
            kernel_size=1,
            stride=1,
            padding=0,
            bias=bias,
        )
        self.depthwise_conv = nn.Conv1d(
            channels,
            channels,
            kernel_size,
            stride=1,
            padding=(kernel_size - 1) // 2,
            groups=channels,
            bias=bias,
        )
        self.norm = nn.BatchNorm1d(channels)
        self.pointwise_conv2 = nn.Conv1d(
            channels,
            channels,
            kernel_size=1,
            stride=1,
            padding=0,
            bias=bias,
        )
        self.activation = activation

    def forward(self, x):
        """Compute convolution module.

        Args:
            x (torch.Tensor): Input tensor (#batch, time, channels).

        Returns:
            torch.Tensor: Output tensor (#batch, time, channels).
        """
        # exchange the temporal dimension and the feature dimension
        x = x.transpose(1, 2)

        # GLU mechanism
        x = self.pointwise_conv1(x)  # (batch, 2*channel, dim)
        x = nn.functional.glu(x, dim=1)  # (batch, channel, dim)

        # 1D Depthwise Conv
        x = self.depthwise_conv(x)
        x = self.activation(self.norm(x))

        x = self.pointwise_conv2(x)

        return x.transpose(1, 2)
```
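A small, hedged shape check for the module above (assuming this branch of fairseq is installed; `ConvolutionModule` is exported by the `fairseq/modules/__init__.py` change in this commit). With an odd kernel and `(kernel_size - 1) // 2` padding, the time dimension is preserved, and the GLU halves the `2 * channels` expansion back to `channels`.

```python
import torch
from fairseq.modules import ConvolutionModule

conv = ConvolutionModule(channels=256, kernel_size=31)
x = torch.randn(4, 100, 256)   # (batch, time, channels)
out = conv(x)
print(out.shape)               # torch.Size([4, 100, 256])
```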
fairseq/modules/positional_embedding.py

```diff
@@ -6,7 +6,7 @@
 import torch.nn as nn

 from .learned_positional_embedding import LearnedPositionalEmbedding
-from .sinusoidal_positional_embedding import SinusoidalPositionalEmbedding
+from .sinusoidal_positional_embedding import SinusoidalPositionalEmbedding, RelPositionalEmbedding


 def PositionalEmbedding(
@@ -14,8 +14,9 @@ def PositionalEmbedding(
     embedding_dim: int,
     padding_idx: int,
     learned: bool = False,
+    pos_emb_type: str = None,
 ):
-    if learned:
+    if learned or pos_emb_type == "learned":
         # if padding_idx is specified then offset the embedding ids by
         # this index and adjust num_embeddings appropriately
         # TODO: The right place for this offset would be inside
@@ -26,6 +27,12 @@ def PositionalEmbedding(
         nn.init.normal_(m.weight, mean=0, std=embedding_dim ** -0.5)
         if padding_idx is not None:
             nn.init.constant_(m.weight[padding_idx], 0)
+    elif pos_emb_type is not None and pos_emb_type.startswith("debug"):
+        m = RelPositionalEmbedding(
+            embedding_dim,
+            padding_idx,
+            init_size=num_embeddings + padding_idx + 1,
+        )
     else:
         m = SinusoidalPositionalEmbedding(
             embedding_dim,
```
fairseq/modules/rel_position_multihead_attention.py

```diff
@@ -3,13 +3,10 @@
 # This source code is licensed under the MIT license found in the
 # LICENSE file in the root directory of this source tree.

-import math
 from typing import Dict, Optional, Tuple

 import torch
-import torch.nn.functional as F
 from fairseq import utils
-from fairseq.incremental_decoding_utils import with_incremental_state
 from fairseq.modules.multihead_attention import MultiheadAttention
 from fairseq.modules.quant_noise import quant_noise
 from torch import Tensor, nn
@@ -67,7 +64,6 @@ class RelPositionMultiheadAttention(MultiheadAttention):
         nn.init.xavier_normal_(self.pos_bias_u)
         nn.init.xavier_normal_(self.pos_bias_v)

     def forward(
         self,
         query,
@@ -108,41 +104,6 @@ class RelPositionMultiheadAttention(MultiheadAttention):
         assert embed_dim == self.embed_dim
         assert list(query.size()) == [tgt_len, bsz, embed_dim]

-        if (
-            False
-            and not self.onnx_trace
-            and not is_tpu  # don't use PyTorch version on TPUs
-            and incremental_state is None
-            and not static_kv
-            # A workaround for quantization to work. Otherwise JIT compilation
-            # treats bias in linear module as method.
-            and not torch.jit.is_scripting()
-        ):
-            assert key is not None and value is not None
-            return F.multi_head_attention_forward(
-                query,
-                key,
-                value,
-                self.embed_dim,
-                self.num_heads,
-                torch.empty([0]),
-                torch.cat((self.q_proj.bias, self.k_proj.bias, self.v_proj.bias)),
-                self.bias_k,
-                self.bias_v,
-                self.add_zero_attn,
-                self.dropout_module.p,
-                self.out_proj.weight,
-                self.out_proj.bias,
-                self.training or self.dropout_module.apply_during_inference,
-                key_padding_mask,
-                need_weights,
-                attn_mask,
-                use_separate_proj_weight=True,
-                q_proj_weight=self.q_proj.weight,
-                k_proj_weight=self.k_proj.weight,
-                v_proj_weight=self.v_proj.weight,
-            )
         if incremental_state is not None:
             saved_state = self._get_input_buffer(incremental_state)
             if saved_state is not None and "prev_key" in saved_state:
@@ -197,14 +158,19 @@ class RelPositionMultiheadAttention(MultiheadAttention):
         #     .view(tgt_len, bsz * self.num_heads, self.head_dim)
         #     .transpose(0, 1)
         # )
-        # prepare q for RPE  # (tgt_len, bsz, num_heads, head_dim)
-        q = q.contiguous().view(tgt_len, bsz, self.num_heads, self.head_dim)
+        # prepare q for RPE  # (bsz, tgt_len, num_heads, head_dim)
+        q = q.contiguous().view(tgt_len, bsz, self.num_heads, self.head_dim).transpose(0, 1)
+        # k (bsz * num_heads, tgt_len, head_dim)
         if k is not None:
             k = (
                 k.contiguous()
                 .view(-1, bsz * self.num_heads, self.head_dim)
                 .transpose(0, 1)
             )
+        # v (bsz * num_heads, tgt_len, head_dim)
         if v is not None:
             v = (
                 v.contiguous()
@@ -283,31 +249,32 @@ class RelPositionMultiheadAttention(MultiheadAttention):
         )
         pos_emb = pos_emb.transpose(0, 1)

-        p_rep = self.linear_pos(pos_emb).view(bsz, -1, self.num_heads, self.head_dim)
-        p_rep = p_rep.transpose(1, 2).contiguous().view(bsz * self.num_heads, -1, self.head_dim)
+        p = self.linear_pos(pos_emb).view(bsz, -1, self.num_heads, self.head_dim)
+        # p (bsz * num_heads, tgt_len, head_dim)
+        p = p.transpose(1, 2).contiguous().view(bsz * self.num_heads, -1, self.head_dim)

         # (batch * head, time1, d_k)
         q_with_bias_u = (
-            (q + self.pos_bias_u).contiguous()
-            .view(tgt_len, bsz * self.num_heads, self.head_dim)
-            .transpose(0, 1)
+            (q + self.pos_bias_u).transpose(1, 2)
+            .contiguous()
+            .view(bsz * self.num_heads, tgt_len, self.head_dim)
         )

         # (batch * head, time1, d_k)
         q_with_bias_v = (
-            (q + self.pos_bias_v).contiguous()
-            .view(tgt_len, bsz * self.num_heads, self.head_dim)
-            .transpose(0, 1)
+            (q + self.pos_bias_v).transpose(0, 1)
+            .contiguous()
+            .view(bsz * self.num_heads, tgt_len, self.head_dim)
         )

         # compute attention score
         # first compute matrix a and matrix c
         # as described in https://arxiv.org/abs/1901.02860 Section 3.3
         # (batch * head, time1, time2)
-        matrix_ac = torch.bmm(q_with_bias_u, k.transpose(1, 2))
+        matrix_ac = torch.matmul(q_with_bias_u, k.transpose(1, 2))

         # compute matrix b and matrix d
         # (batch * head, time1, time2)
-        matrix_bd = torch.bmm(q_with_bias_v, p_rep.transpose(1, 2))
+        matrix_bd = torch.matmul(q_with_bias_v, p.transpose(1, 2))

         def rel_shift(x, zero_triu=False):
             """Compute relative positional encoding.
@@ -315,36 +282,30 @@ class RelPositionMultiheadAttention(MultiheadAttention):
             Args:
                 x (torch.Tensor): Input tensor (batch, head, time1, 2*time1-1).
                     time1 means the length of query vector.
-                zero_triu (bool): If true, return the lower triangular part of
-                    the matrix.

             Returns:
                 torch.Tensor: Output tensor.
             """
-            zero_pad = torch.zeros((*x.size()[:3], 1), device=x.device, dtype=x.dtype)
+            zero_pad = torch.zeros((x.size()[0], x.size()[1], 1), device=x.device, dtype=x.dtype)
             x_padded = torch.cat([zero_pad, x], dim=-1)
-            x_padded = x_padded.view(*x.size()[:2], x.size(3) + 1, x.size(2))
-            x = x_padded[:, :, 1:].view_as(x)
+            x_padded = x_padded.view(x.size()[0],
+                                     x.size()[2] + 1, x.size()[1])
+            x = x_padded[:, 1:].view_as(x)
+            # zero_pad = torch.zeros((*x.size()[:3], 1), device=x.device, dtype=x.dtype)
+            # x_padded = torch.cat([zero_pad, x], dim=-1)
+            #
+            # x_padded = x_padded.view(*x.size()[:2], x.size(3) + 1, x.size(2))
+            # x = x_padded[:, :, 1:].view_as(x)[
+            #     :, :, :, : x.size(-1) // 2 + 1
+            # ]  # only keep the positions from 0 to time2

             if zero_triu:
-                ones = torch.ones((x.size(2), x.size(3)), device=x.device)
-                x = x * torch.tril(ones, x.size(3) - x.size(2))[None, None, :, :]
+                ones = torch.ones((x.size(1), x.size(2)), device=x.device)
+                x = x * torch.tril(ones, x.size(2) - x.size(1))[None, :, :]

             return x

+        # matrix_bd = matrix_bd.contiguous().view(bsz, self.num_heads, matrix_bd.size(-2), matrix_bd.size(-1))
         matrix_bd = rel_shift(matrix_bd)
+        # matrix_bd = rel_shift(
+        #     matrix_bd,
+        # ).contiguous().view(bsz * self.num_heads, matrix_bd.size(-2), matrix_bd.size(-1))

         attn_weights = (matrix_ac + matrix_bd) * self.scaling
+        # attn_weights = torch.bmm(q, k.transpose(1, 2))

         attn_weights = self.apply_sparse_mask(attn_weights, tgt_len, src_len, bsz)
         assert list(attn_weights.size()) == [bsz * self.num_heads, tgt_len, src_len]
```
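To make the new 3-D `rel_shift` easier to follow, here is a self-contained re-implementation of just that trick, mirroring the added lines above (it is not the repository's API, only an illustration): left-padding the relative-position scores with one zero column and reshaping realigns each query row against absolute key positions.

```python
import torch

def rel_shift(x):
    """x: (batch * heads, time1, n_pos) scores over relative positions."""
    zero_pad = torch.zeros((x.size(0), x.size(1), 1), dtype=x.dtype, device=x.device)
    x_padded = torch.cat([zero_pad, x], dim=-1)                    # (B*H, time1, n_pos + 1)
    x_padded = x_padded.view(x.size(0), x.size(2) + 1, x.size(1))  # fold the padded column in
    return x_padded[:, 1:].view_as(x)

scores = torch.arange(12.0).view(1, 3, 4)
print(rel_shift(scores))
# row i ends up shifted left by (time1 - 1 - i) columns, the Transformer-XL style relative shift
```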
fairseq/modules/sinusoidal_positional_embedding.py

```diff
@@ -103,3 +103,37 @@ class SinusoidalPositionalEmbedding(nn.Module):
             .view(bsz, seq_len, -1)
             .detach()
         )
+
+
+class RelPositionalEmbedding(SinusoidalPositionalEmbedding):
+    """Relative positional encoding module.
+
+    See: Appendix B in https://arxiv.org/abs/1901.02860
+
+    Args:
+        d_model (int): Embedding dimension.
+        dropout_rate (float): Dropout rate.
+        max_len (int): Maximum input length.
+    """
+
+    def __init__(self, embedding_dim, padding_idx, init_size=1024):
+        super().__init__(embedding_dim, padding_idx, init_size)
+        self.max_size = init_size
+
+    def forward(
+        self,
+        input,
+        incremental_state: Optional[Any] = None,
+        timestep: Optional[Tensor] = None,
+        positions: Optional[Any] = None,
+        offset: int = 0,
+    ):
+        """Compute positional encoding.
+
+        Args:
+            input (torch.Tensor): Input tensor (batch, time, `*`).
+
+        Returns:
+            torch.Tensor: Positional embedding tensor (1, time, `*`).
+        """
+        assert offset + input.size(1) < self.max_size
+        self.weights = self.weights.to(input.device)
+        pos_emb = self.weights[:, offset : offset + input.size(1)]
+        return pos_emb
```