Skip to content
项目
群组
代码片段
帮助
当前项目
正在载入...
登录 / 注册
切换导航面板
F
Fairseq-S2T
概览
Overview
Details
Activity
Cycle Analytics
版本库
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Charts
问题
0
Issues
0
列表
Board
标记
里程碑
合并请求
0
Merge Requests
0
CI / CD
CI / CD
流水线
作业
日程表
图表
维基
Wiki
代码片段
Snippets
成员
Collapse sidebar
Close sidebar
活动
图像
聊天
创建新问题
作业
提交
Issue Boards
Open sidebar
xuchen
Fairseq-S2T
Commits
47e0f6e0
You need to sign in or sign up before continuing.
Commit
47e0f6e0
authored
Aug 31, 2022
by
xuchen
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
add the multibranch S2T architecture.
I also found some bugs in the dual architecture.
parent
793f553a
显示空白字符变更
内嵌
并排
正在显示
7 个修改的文件
包含
1200 行增加
和
52 行删除
+1200
-52
fairseq/models/speech_to_text/__init__.py
+1
-0
fairseq/models/speech_to_text/s2t_dual.py
+0
-18
fairseq/models/speech_to_text/s2t_multibranch.py
+777
-0
fairseq/models/transformer_s2.py
+2
-2
fairseq/modules/__init__.py
+2
-0
fairseq/modules/s2t_transformer_s2_layer.py
+343
-0
fairseq/modules/transformer_s2_layer.py
+75
-32
没有找到文件。
fairseq/models/speech_to_text/__init__.py
查看文件 @
47e0f6e0
...
...
@@ -10,3 +10,4 @@ from .pdss2t_transformer import * # noqa
from
.s2t_sate
import
*
# noqa
from
.s2t_dual
import
*
# noqa
from
.s2t_ctc
import
*
from
.s2t_multibranch
import
*
fairseq/models/speech_to_text/s2t_dual.py
查看文件 @
47e0f6e0
from
fairseq.models
import
(
FairseqEncoder
,
FairseqEncoderModel
,
register_model
,
register_model_architecture
,
)
import
logging
import
math
from
typing
import
Dict
,
List
,
Optional
,
Tuple
import
torch
...
...
@@ -35,17 +28,6 @@ from fairseq.models.transformer_s2 import (
TransformerS2Encoder
,
TransformerS2Decoder
,
)
from
fairseq.modules
import
(
FairseqDropout
,
LayerNorm
,
PositionalEmbedding
,
LegacyRelPositionalEncoding
,
RelPositionalEncoding
,
S2TTransformerEncoderLayer
,
DynamicLinearCombination
,
TransformerS2DecoderLayer
,
TransformerS2EncoderLayer
,
)
logger
=
logging
.
getLogger
(
__name__
)
...
...
fairseq/models/speech_to_text/s2t_multibranch.py
0 → 100644
查看文件 @
47e0f6e0
import
logging
from
typing
import
Dict
,
List
,
Optional
,
Tuple
import
torch.nn
as
nn
from
torch
import
Tensor
from
fairseq
import
checkpoint_utils
,
utils
from
fairseq.models.speech_to_text
import
(
S2TTransformerModel
,
S2TTransformerEncoder
,
PDSS2TTransformerModel
,
PDSS2TTransformerEncoder
,
S2TSATEModel
,
)
from
fairseq.models
import
(
FairseqEncoder
,
FairseqEncoderDecoderModel
,
register_model
,
register_model_architecture
,
)
from
fairseq.modules
import
(
LayerNorm
,
PositionalEmbedding
,
LegacyRelPositionalEncoding
,
RelPositionalEncoding
,
S2TTransformerS2EncoderLayer
,
TransformerS2EncoderLayer
,
)
from
fairseq.modules.speech_to_text
import
Adapter
from
fairseq.models.transformer_s2
import
(
Embedding
,
TransformerS2Decoder
,
)
logger
=
logging
.
getLogger
(
__name__
)
@register_model("s2t_multibranch")
class S2TMultiBranchModel(FairseqEncoderDecoderModel):
    """Multi-branch speech-to-text model.

    The encoder (S2TMultiBranchEncoder) runs a junior acoustic encoder, an
    adapter, and two collaborating branches (senior acoustic / textual); the
    decoder is a TransformerS2Decoder that attends over both branches.
    """

    def __init__(self, encoder, decoder):
        super().__init__(encoder, decoder)

    @staticmethod
    def add_args(parser):
        """Add model-specific arguments to the parser."""
        S2TTransformerModel.add_args(parser)
        PDSS2TTransformerModel.add_specific_args(parser)
        S2TSATEModel.add_specific_args(parser)
        S2TMultiBranchModel.add_specific_args(parser)

    @staticmethod
    def add_specific_args(parser):
        """Register the multibranch-only options."""
        # multibranch
        parser.add_argument(
            "--junior-acoustic-encoder",
            default="transformer",
            choices=["transformer", "pds", "sate", "wav2vec"],
            type=str,
            help="the architecture of the junior acoustic encoder",
        )
        parser.add_argument(
            "--senior-acoustic-encoder",
            default="transformer",
            choices=["transformer", "pds", "sate", "wav2vec"],
            type=str,
            help="the architecture of the senior acoustic ASR encoder",
        )
        parser.add_argument(
            "--textual-encoder",
            default="transformer",
            type=str,
            help="the architecture of the MT encoder",
        )
        parser.add_argument(
            "--textual-encoder-dim",
            type=int,
            help="the dimension of the textual encoder",
        )
        parser.add_argument(
            "--junior-acoustic-encoder-layers",
            default=6,
            type=int,
            # fixed copy-paste: this option controls the *junior* encoder
            help="the layers of the junior acoustic encoder",
        )
        parser.add_argument(
            "--senior-acoustic-encoder-layers",
            default=6,
            type=int,
            help="the layers of the senior acoustic encoder",
        )
        parser.add_argument(
            "--textual-encoder-layers",
            default=6,
            type=int,
            help="the layers of the textual encoder",
        )
        # collaboration
        parser.add_argument(
            "--collaboration-direction",
            default="none",
            type=str,
            help="direction of collaboration",
        )
        parser.add_argument(
            "--collaboration-step",
            default="1:1",
            type=str,
            help="collaboration step in two encoders",
        )
        parser.add_argument(
            "--encoder-collaboration-mode",
            default="serial",
            type=str,
            help="how to calculate attention during league in encoder",
        )
        parser.add_argument(
            "--decoder-collaboration-mode",
            default="serial",
            type=str,
            # fixed copy-paste: this option applies to the decoder
            help="how to calculate attention during league in decoder",
        )
        # league
        parser.add_argument(
            "--encoder-league-s1-ratio",
            default=0.5,
            type=float,
            help="league ratio of the s1 representation",
        )
        parser.add_argument(
            "--encoder-league-s2-ratio",
            default=0.5,
            type=float,
            help="league ratio of the s2 representation",
        )
        parser.add_argument(
            "--encoder-league-drop-net",
            action="store_true",
            help="drop one input during league",
        )
        parser.add_argument(
            "--encoder-league-drop-net-prob",
            default=0.0,
            type=float,
            help="probability of dropping one representations",
        )
        parser.add_argument(
            "--encoder-league-drop-net-mix",
            action="store_true",
            help="mix the two input with any probability",
        )
        parser.add_argument(
            "--decoder-league-s1-ratio",
            default=0.5,
            type=float,
            help="league ratio of the s1 representation",
        )
        parser.add_argument(
            "--decoder-league-s2-ratio",
            default=0.5,
            type=float,
            help="league ratio of the s2 representation",
        )
        parser.add_argument(
            "--decoder-league-drop-net",
            action="store_true",
            help="drop one input during league",
        )
        parser.add_argument(
            "--decoder-league-drop-net-prob",
            default=0.0,
            type=float,
            help="probability of dropping one representations",
        )
        parser.add_argument(
            "--decoder-league-drop-net-mix",
            action="store_true",
            help="mix the two input with any probability",
        )
        parser.add_argument(
            "--load-pretrained-junior-acoustic-encoder-from",
            type=str,
            metavar="STR",
            help="model to take junior acoustic encoder weights from (for initialization)",
        )
        parser.add_argument(
            "--load-pretrained-senior-acoustic-encoder-from",
            type=str,
            metavar="STR",
            help="model to take senior acoustic encoder weights from (for initialization)",
        )
        parser.add_argument(
            "--load-pretrained-textual-encoder-from",
            type=str,
            metavar="STR",
            help="model to take textual encoder weights from (for initialization)",
        )

    @classmethod
    def build_encoder(cls, args, task=None, embed_tokens=None):
        """Build the multibranch encoder and optionally initialize its
        sub-modules from pretrained checkpoints."""
        encoder = S2TMultiBranchEncoder(args, task, embed_tokens)

        if getattr(args, "load_pretrained_encoder_from", None):
            encoder = checkpoint_utils.load_pretrained_component_from_model(
                component=encoder,
                checkpoint=args.load_pretrained_encoder_from,
                strict=False,
            )
            logger.info(
                f"loaded pretrained encoder from: "
                f"{args.load_pretrained_encoder_from}"
            )
        # BUGFIX: the parser registers
        # --load-pretrained-junior-acoustic-encoder-from, but the original
        # code queried "load_pretrained_junior_encoder_from" (never set) and
        # loaded the weights into the non-existent attribute "asr_encoder".
        if getattr(args, "load_pretrained_junior_acoustic_encoder_from", None):
            encoder.junior_acoustic_encoder = (
                checkpoint_utils.load_pretrained_component_from_model(
                    component=encoder.junior_acoustic_encoder,
                    checkpoint=args.load_pretrained_junior_acoustic_encoder_from,
                    strict=False,
                )
            )
            logger.info(
                f"loaded pretrained junior acoustic encoder from: "
                f"{args.load_pretrained_junior_acoustic_encoder_from}"
            )
        if getattr(args, "load_pretrained_senior_acoustic_encoder_from", None):
            # NOTE(review): the senior branch is a ModuleList
            # (senior_acoustic_encoder_layers), not a FairseqEncoder, so
            # component loading may need dedicated handling — TODO confirm.
            encoder.senior_acoustic_encoder_layers = (
                checkpoint_utils.load_pretrained_component_from_model(
                    component=encoder.senior_acoustic_encoder_layers,
                    checkpoint=args.load_pretrained_senior_acoustic_encoder_from,
                    strict=False,
                )
            )
            logger.info(
                f"loaded pretrained senior acoustic encoder from: "
                f"{args.load_pretrained_senior_acoustic_encoder_from}"
            )
        if getattr(args, "load_pretrained_textual_encoder_from", None):
            # NOTE(review): the textual branch is a ModuleList
            # (textual_encoder_layers); the original targeted the
            # non-existent attribute "mt_encoder" — TODO confirm handling.
            encoder.textual_encoder_layers = (
                checkpoint_utils.load_pretrained_component_from_model(
                    component=encoder.textual_encoder_layers,
                    checkpoint=args.load_pretrained_textual_encoder_from,
                    strict=False,
                )
            )
            logger.info(
                f"loaded pretrained textual encoder from: "
                f"{args.load_pretrained_textual_encoder_from}"
            )
        return encoder

    @classmethod
    def build_decoder(cls, args, task, embed_tokens):
        """Build the dual-attention decoder, optionally from a checkpoint."""
        decoder = TransformerS2Decoder(args, task.target_dictionary, embed_tokens)

        if getattr(args, "load_pretrained_decoder_from", None):
            logger.info(
                f"loaded pretrained decoder from: "
                f"{args.load_pretrained_decoder_from}"
            )
            decoder = checkpoint_utils.load_pretrained_component_from_model(
                component=decoder,
                checkpoint=args.load_pretrained_decoder_from,
                strict=False,
            )
        return decoder

    @classmethod
    def build_model(cls, args, task):
        """Build a new model instance."""
        # make sure all arguments are present in older models
        base_architecture(args)

        def build_embedding(dictionary, embed_dim):
            num_embeddings = len(dictionary)
            padding_idx = dictionary.pad()
            return Embedding(num_embeddings, embed_dim, padding_idx)

        src_dict, tgt_dict = task.source_dictionary, task.target_dictionary

        if args.share_all_embeddings:
            if src_dict != tgt_dict:
                raise ValueError("--share-all-embeddings requires a joined dictionary")
            if args.encoder_embed_dim != args.decoder_embed_dim:
                raise ValueError(
                    "--share-all-embeddings requires --encoder-embed-dim to match --decoder-embed-dim"
                )
            encoder_embed_tokens = build_embedding(src_dict, args.encoder_embed_dim)
            decoder_embed_tokens = encoder_embed_tokens
            args.share_decoder_input_output_embed = True
        else:
            encoder_embed_tokens = build_embedding(src_dict, args.encoder_embed_dim)
            decoder_embed_tokens = build_embedding(tgt_dict, args.decoder_embed_dim)

        encoder = cls.build_encoder(args, task, encoder_embed_tokens)
        if getattr(args, "encoder_freeze_module", None):
            utils.freeze_parameters(encoder, args.encoder_freeze_module)
            logging.info("freeze the encoder module: {}".format(args.encoder_freeze_module))

        decoder = cls.build_decoder(args, task, decoder_embed_tokens)
        if getattr(args, "decoder_freeze_module", None):
            utils.freeze_parameters(decoder, args.decoder_freeze_module)
            logging.info("freeze the decoder module: {}".format(args.decoder_freeze_module))
        return cls(encoder, decoder)

    def get_normalized_probs(
        self,
        net_output: Tuple[Tensor, Optional[Dict[str, List[Optional[Tensor]]]]],
        log_probs: bool,
        sample: Optional[Dict[str, Tensor]] = None,
    ):
        """Get normalized probabilities (or log probs) from a net's output."""
        # net_output['encoder_out'] is a (B, T, D) tensor
        lprobs = self.get_normalized_probs_scriptable(net_output, log_probs, sample)
        lprobs.batch_first = True
        return lprobs

    def forward(self, src_tokens, src_lengths, prev_output_tokens, **kwargs):
        """
        The forward method inherited from the base class has a **kwargs
        argument in its input, which is not supported in torchscript. This
        method overwrites the forward method definition without **kwargs.
        """
        encoder_out = self.encoder(src_tokens, src_lengths)
        decoder_out = self.decoder(
            prev_output_tokens=prev_output_tokens, encoder_out=encoder_out
        )
        return decoder_out
class S2TMultiBranchEncoder(FairseqEncoder):
    """Speech-to-text Transformer encoder that consists of input subsampler and
    Transformer encoder.

    A junior acoustic encoder produces acoustic states; an adapter maps them
    into the textual space; then a senior acoustic branch and a textual branch
    run their layers in an interleaved schedule, optionally attending to each
    other ("collaboration").
    """

    def __init__(self, args, task=None, embed_tokens=None):
        super().__init__(None)

        self.padding_idx = 1

        # The junior encoder consumes args.encoder_layers, so override it
        # with the junior-specific layer count before construction.
        setattr(args, "encoder_layers", args.junior_acoustic_encoder_layers)
        junior_encoder_type = args.junior_acoustic_encoder
        if junior_encoder_type == "transformer":
            self.junior_acoustic_encoder = S2TTransformerEncoder(args, task, embed_tokens)
        elif junior_encoder_type == "pds":
            self.junior_acoustic_encoder = PDSS2TTransformerEncoder(args, task, embed_tokens)
        else:
            logger.error("Unsupported junior acoustic architecture: %s." % junior_encoder_type)

        self.senior_acoustic_attn_type = getattr(args, "encoder_attention_type", "selfattn")
        if self.senior_acoustic_attn_type == "rel_pos":
            self.senior_acoustic_embed_positions = RelPositionalEncoding(
                args.max_source_positions, args.encoder_embed_dim
            )
        elif self.senior_acoustic_attn_type in ["rel_selfattn", "rel_pos_legacy"]:
            self.senior_acoustic_embed_positions = LegacyRelPositionalEncoding(
                args.encoder_embed_dim, args.dropout, args.max_source_positions
            )
        else:
            # Use absolute positional embedding
            self.senior_acoustic_embed_positions = None

        setattr(args, "collaboration_mode", args.encoder_collaboration_mode)
        self.senior_acoustic_encoder_layer_num = args.senior_acoustic_encoder_layers
        self.senior_acoustic_encoder_layers = nn.ModuleList(
            [
                S2TTransformerS2EncoderLayer(args)
                for _ in range(self.senior_acoustic_encoder_layer_num)
            ]
        )

        # adapter
        self.adapter_temperature = args.adapter_temperature
        strategy = {
            "embed_norm": getattr(args, "adapter_embed_norm", False),
            "out_norm": getattr(args, "adapter_out_norm", False),
            "ctc_compress_strategy": getattr(args, "ctc_compress_strategy", None),
            "distribution_cutoff": getattr(args, "adapter_distribution_cutoff", None),
            "drop_prob": getattr(args, "adapter_drop_prob", 0),
        }
        self.adapter = Adapter(
            args.encoder_embed_dim,
            args.adapter,
            len(task.source_dictionary),
            strategy=strategy,
        )
        assert not (
            args.share_adapter_and_ctc and args.share_adapter_and_embed
        ), "Can not be True at the same time"
        if args.share_adapter_and_ctc and hasattr(self.adapter, "embed_adapter"):
            # BUGFIX: the original read self.acoustic_encoder, which is never
            # defined on this class; the CTC head lives on the junior encoder.
            self.adapter.embed_adapter.weight = (
                self.junior_acoustic_encoder.ctc.ctc_projection.weight
            )
        if args.share_adapter_and_embed and hasattr(self.adapter, "embed_adapter"):
            self.adapter.embed_adapter.weight = embed_tokens.weight

        # textual encoder
        self.textual_embed_positions = PositionalEmbedding(
            args.max_source_positions, args.encoder_embed_dim, self.padding_idx
        )

        # Textual layers always use plain self-attention; restore the
        # configured attention type afterwards.
        attn_type = args.encoder_attention_type
        setattr(args, "encoder_attention_type", "selfattn")
        self.textual_encoder_layer_num = args.textual_encoder_layers
        self.textual_encoder_layers = nn.ModuleList(
            [TransformerS2EncoderLayer(args) for _ in range(self.textual_encoder_layer_num)]
        )
        setattr(args, "encoder_attention_type", attn_type)

        # collaboration
        collaboration_step = args.collaboration_step
        if len(collaboration_step.split(":")) == 2:
            self.collaboration_step = [int(s) for s in collaboration_step.split(":")]
        else:
            self.collaboration_step = [1, 1]
        self.collaboration_direction = args.collaboration_direction

        self.acoustic_norm = LayerNorm(args.encoder_embed_dim)
        self.textual_norm = LayerNorm(args.encoder_embed_dim)

    def forward(self, src_tokens, src_lengths=None, **kwargs):
        """Run the junior encoder, the adapter, then the interleaved senior
        acoustic / textual layer schedule; returns the junior encoder's output
        dict with the two branch outputs written in."""
        junior_acoustic_encoder_out = self.junior_acoustic_encoder(
            src_tokens, src_lengths, **kwargs
        )

        acoustic_x = junior_acoustic_encoder_out["encoder_out"][0]
        acoustic_encoder_padding_mask = junior_acoustic_encoder_out[
            "encoder_padding_mask"
        ][0]

        if (
            "ctc_logit" in junior_acoustic_encoder_out
            and len(junior_acoustic_encoder_out["ctc_logit"]) > 0
        ):
            ctc_logit = junior_acoustic_encoder_out["ctc_logit"][0]
        else:
            ctc_logit = None

        x = (acoustic_x, ctc_logit)
        adapter_x, adapter_encoder_padding_mask = self.adapter(
            x, acoustic_encoder_padding_mask
        )

        textual_x = adapter_x + self.textual_embed_positions(
            adapter_encoder_padding_mask
        ).transpose(0, 1)
        # textual_x = self.dropout_module(textual_x)
        textual_encoder_padding_mask = adapter_encoder_padding_mask

        # NOTE(review): the loop below assumes the two layer counts are exact
        # multiples of their collaboration steps; otherwise the break
        # condition is never met — TODO confirm / validate the configuration.
        senior_acoustic_encoder_idx = -1
        textual_encoder_idx = -1
        while True:
            if self.collaboration_direction == "acoustic":
                for _ in range(self.collaboration_step[1]):
                    textual_encoder_idx += 1
                    textual_x = self.textual_encoder_layers[textual_encoder_idx](
                        textual_x,
                        encoder_padding_mask=textual_encoder_padding_mask,
                    )
                for _ in range(self.collaboration_step[0]):
                    senior_acoustic_encoder_idx += 1
                    acoustic_x = self.senior_acoustic_encoder_layers[
                        senior_acoustic_encoder_idx
                    ](
                        acoustic_x,
                        encoder_padding_mask=acoustic_encoder_padding_mask,
                        s2=textual_x,
                        s2_encoder_padding_mask=textual_encoder_padding_mask,
                    )
            elif self.collaboration_direction == "textual":
                for _ in range(self.collaboration_step[0]):
                    senior_acoustic_encoder_idx += 1
                    acoustic_x = self.senior_acoustic_encoder_layers[
                        senior_acoustic_encoder_idx
                    ](
                        acoustic_x,
                        encoder_padding_mask=acoustic_encoder_padding_mask,
                    )
                for _ in range(self.collaboration_step[1]):
                    textual_encoder_idx += 1
                    # BUGFIX: the original passed xs=acoustic_x here while
                    # every other branch passes s2= to the same layer type.
                    textual_x = self.textual_encoder_layers[textual_encoder_idx](
                        textual_x,
                        encoder_padding_mask=textual_encoder_padding_mask,
                        s2=acoustic_x,
                        s2_encoder_padding_mask=acoustic_encoder_padding_mask,
                    )
            elif self.collaboration_direction == "both":
                for _ in range(self.collaboration_step[0]):
                    senior_acoustic_encoder_idx += 1
                    acoustic_x = self.senior_acoustic_encoder_layers[
                        senior_acoustic_encoder_idx
                    ](
                        acoustic_x,
                        encoder_padding_mask=acoustic_encoder_padding_mask,
                        s2=textual_x,
                        s2_encoder_padding_mask=textual_encoder_padding_mask,
                    )
                for _ in range(self.collaboration_step[1]):
                    textual_encoder_idx += 1
                    textual_x = self.textual_encoder_layers[textual_encoder_idx](
                        textual_x,
                        encoder_padding_mask=textual_encoder_padding_mask,
                        s2=acoustic_x,
                        s2_encoder_padding_mask=acoustic_encoder_padding_mask,
                    )
            elif self.collaboration_direction == "none":
                for _ in range(self.collaboration_step[0]):
                    senior_acoustic_encoder_idx += 1
                    acoustic_x = self.senior_acoustic_encoder_layers[
                        senior_acoustic_encoder_idx
                    ](
                        acoustic_x,
                        encoder_padding_mask=acoustic_encoder_padding_mask,
                    )
                for _ in range(self.collaboration_step[1]):
                    textual_encoder_idx += 1
                    textual_x = self.textual_encoder_layers[textual_encoder_idx](
                        textual_x,
                        encoder_padding_mask=textual_encoder_padding_mask,
                    )
            if (
                senior_acoustic_encoder_idx == self.senior_acoustic_encoder_layer_num - 1
                and textual_encoder_idx == self.textual_encoder_layer_num - 1
            ):
                break

        acoustic_x = self.acoustic_norm(acoustic_x)
        # BUGFIX: the original normalized textual_x with acoustic_norm,
        # leaving self.textual_norm unused.
        textual_x = self.textual_norm(textual_x)

        junior_acoustic_encoder_out["encoder_out"] = [acoustic_x]
        junior_acoustic_encoder_out["encoder_padding_mask"] = [
            acoustic_encoder_padding_mask
        ]
        junior_acoustic_encoder_out["s2_encoder_out"] = [textual_x]
        junior_acoustic_encoder_out["s2_encoder_padding_mask"] = [
            textual_encoder_padding_mask
        ]

        return junior_acoustic_encoder_out

    def reorder_encoder_out(self, encoder_out, new_order):
        """
        Reorder encoder output according to *new_order*.

        Args:
            encoder_out: output from the ``forward()`` method
            new_order (LongTensor): desired order

        Returns:
            *encoder_out* rearranged according to *new_order*
        """
        if len(encoder_out["encoder_out"]) == 0:
            new_encoder_out = []
        else:
            new_encoder_out = [encoder_out["encoder_out"][0].index_select(1, new_order)]

        if len(encoder_out["encoder_padding_mask"]) == 0:
            new_encoder_padding_mask = []
        else:
            new_encoder_padding_mask = [
                encoder_out["encoder_padding_mask"][0].index_select(0, new_order)
            ]

        if len(encoder_out["s2_encoder_out"]) == 0:
            new_s2_encoder_out = []
        else:
            new_s2_encoder_out = [
                encoder_out["s2_encoder_out"][0].index_select(1, new_order)
            ]

        if len(encoder_out["s2_encoder_padding_mask"]) == 0:
            new_s2_encoder_padding_mask = []
        else:
            new_s2_encoder_padding_mask = [
                encoder_out["s2_encoder_padding_mask"][0].index_select(0, new_order)
            ]

        if len(encoder_out["encoder_embedding"]) == 0:
            new_encoder_embedding = []
        else:
            new_encoder_embedding = [
                encoder_out["encoder_embedding"][0].index_select(0, new_order)
            ]

        if len(encoder_out["src_tokens"]) == 0:
            src_tokens = []
        else:
            src_tokens = [(encoder_out["src_tokens"][0]).index_select(0, new_order)]

        if len(encoder_out["src_lengths"]) == 0:
            src_lengths = []
        else:
            src_lengths = [(encoder_out["src_lengths"][0]).index_select(0, new_order)]

        encoder_states = encoder_out["encoder_states"]
        if len(encoder_states) > 0:
            for idx, state in enumerate(encoder_states):
                encoder_states[idx] = state.index_select(1, new_order)

        return {
            "encoder_out": new_encoder_out,  # T x B x C
            "encoder_padding_mask": new_encoder_padding_mask,  # B x T
            "s2_encoder_out": new_s2_encoder_out,  # T x B x C
            "s2_encoder_padding_mask": new_s2_encoder_padding_mask,  # B x T
            "encoder_embedding": new_encoder_embedding,  # B x T x C
            "encoder_states": encoder_states,  # List[T x B x C]
            "src_tokens": src_tokens,  # B x T
            "src_lengths": src_lengths,  # B x 1
        }
@register_model_architecture(model_name="s2t_multibranch", arch_name="s2t_multibranch")
def base_architecture(args):
    """Fill in every model option the multibranch architecture reads, so that
    checkpoints trained with older option sets still load."""
    # Convolutional subsampler
    args.subsampling_type = getattr(args, "subsampling_type", "conv1d")
    args.subsampling_layers = getattr(args, "subsampling_layers", 2)
    args.subsampling_filter = getattr(args, "subsampling_filter", 1024)
    args.subsampling_kernel = getattr(args, "subsampling_kernel", 5)
    args.subsampling_stride = getattr(args, "subsampling_stride", 2)
    args.subsampling_norm = getattr(args, "subsampling_norm", "none")
    args.subsampling_activation = getattr(args, "subsampling_activation", "glu")
    # Transformer
    args.encoder_embed_dim = getattr(args, "encoder_embed_dim", 512)
    args.encoder_ffn_embed_dim = getattr(args, "encoder_ffn_embed_dim", 2048)
    args.encoder_layers = getattr(args, "encoder_layers", 12)
    args.encoder_attention_type = getattr(args, "encoder_attention_type", "selfattn")
    args.encoder_attention_heads = getattr(args, "encoder_attention_heads", 8)
    args.encoder_normalize_before = getattr(args, "encoder_normalize_before", True)
    args.decoder_embed_dim = getattr(args, "decoder_embed_dim", args.encoder_embed_dim)
    args.decoder_ffn_embed_dim = getattr(
        args, "decoder_ffn_embed_dim", args.encoder_ffn_embed_dim
    )
    args.decoder_layers = getattr(args, "decoder_layers", 6)
    args.decoder_attention_type = getattr(args, "decoder_attention_type", "selfattn")
    args.decoder_attention_heads = getattr(args, "decoder_attention_heads", 8)
    args.decoder_normalize_before = getattr(args, "decoder_normalize_before", True)
    args.decoder_learned_pos = getattr(args, "decoder_learned_pos", False)
    args.dropout = getattr(args, "dropout", 0.1)
    args.attention_dropout = getattr(args, "attention_dropout", args.dropout)
    args.activation_dropout = getattr(args, "activation_dropout", args.dropout)
    args.activation_fn = getattr(args, "activation_fn", "relu")
    args.adaptive_softmax_cutoff = getattr(args, "adaptive_softmax_cutoff", None)
    args.adaptive_softmax_dropout = getattr(args, "adaptive_softmax_dropout", 0)
    args.tie_adaptive_weights = getattr(args, "tie_adaptive_weights", False)
    args.tie_adaptive_proj = getattr(args, "tie_adaptive_proj", False)
    args.adaptive_softmax_factor = getattr(args, "adaptive_softmax_factor", 4)
    args.share_decoder_input_output_embed = getattr(
        args, "share_decoder_input_output_embed", False
    )
    args.share_all_embeddings = getattr(args, "share_all_embeddings", False)
    args.no_token_positional_embeddings = getattr(
        args, "no_token_positional_embeddings", False
    )
    args.adaptive_input = getattr(args, "adaptive_input", False)
    args.encoder_layerdrop = getattr(args, "encoder_layerdrop", 0.0)
    args.decoder_layerdrop = getattr(args, "decoder_layerdrop", 0.0)
    args.decoder_output_dim = getattr(args, "decoder_output_dim", args.decoder_embed_dim)
    args.decoder_input_dim = getattr(args, "decoder_input_dim", args.decoder_embed_dim)
    args.no_scale_embedding = getattr(args, "no_scale_embedding", False)
    args.encoder_no_scale_embedding = getattr(args, "encoder_no_scale_embedding", False)
    args.quant_noise_pq = getattr(args, "quant_noise_pq", 0)
    args.encoder_embed_linear = getattr(args, "encoder_embed_linear", False)
    args.encoder_embed_norm = getattr(args, "encoder_embed_norm", False)
    # CTC
    args.ctc_layer = getattr(args, "ctc_layer", 0)
    args.share_ctc_and_embed = getattr(args, "share_ctc_and_embed", False)
    # Conformer
    args.encoder_activation_fn = getattr(args, "encoder_activation_fn", "relu")
    args.macaron_style = getattr(args, "macaron_style", False)
    args.use_cnn_module = getattr(args, "use_cnn_module", False)
    args.cnn_module_kernel = getattr(args, "cnn_module_kernel", 31)
    args.cnn_module_norm = getattr(args, "cnn_module_norm", "batch_norm")
    # settings for DLCL
    args.use_enc_dlcl = getattr(args, "use_enc_dlcl", False)
    args.use_dec_dlcl = getattr(args, "use_dec_dlcl", False)
    args.init_value = getattr(args, 'init_value', 'avg')
    args.weight_type = getattr(args, 'weight_type', 'scalar')
    args.encoder_learnable = getattr(args, 'encoder_learnable', True)
    args.normalize_embed = getattr(args, 'normalize_embed', False)
    args.history_dropout = getattr(args, 'history_dropout', 0.0)
    args.history_window_size = getattr(args, 'history_window_size', -1)
    # Relative position encoding
    args.max_encoder_relative_length = getattr(args, 'max_encoder_relative_length', -1)
    args.k_only = getattr(args, 'k_only', True)
    # local modeling
    args.hard_mask_window = getattr(args, 'hard_mask_window', 0)
    args.gauss_mask_sigma = getattr(args, 'gauss_mask_sigma', 0)
    args.init_mask_weight = getattr(args, 'init_mask_weight', 0)
    # interleaved CTC
    args.interleaved_ctc_layers = getattr(args, "interleaved_ctc_layers", None)
    args.interleaved_ctc_temperature = getattr(args, "interleaved_ctc_temperature", 1)
    args.interleaved_ctc_drop_prob = getattr(args, "interleaved_ctc_drop_prob", 0)
    # Semantics-augmented Encoding (sae)
    args.sae_adapter = getattr(args, "sae_adapter", "none")
    args.target_sae_adapter = getattr(args, "target_sae_adapter", args.sae_adapter)
    args.share_sae_and_ctc = getattr(args, "share_sae_and_ctc", False)
    args.share_target_sae_and_ctc = getattr(args, "share_target_sae_and_ctc", False)
    args.sae_drop_prob = getattr(args, "sae_drop_prob", 0)
    args.sae_distribution_cutoff = getattr(args, "sae_distribution_cutoff", None)
    args.sae_distribution_hard = getattr(args, "sae_distribution_hard", False)
    args.sae_gumbel = getattr(args, "sae_gumbel", False)
    # mixup
    args.inter_mixup = getattr(args, "inter_mixup", False)
    args.inter_mixup_layer = getattr(args, "inter_mixup_layer", None)
    args.inter_mixup_beta = getattr(args, "inter_mixup_beta", 0.5)
    args.inter_mixup_prob = getattr(args, "inter_mixup_prob", 1)
    args.inter_mixup_ratio = getattr(args, "inter_mixup_ratio", 0.3)
    args.inter_mixup_keep_org = getattr(args, "inter_mixup_keep_org", False)
    # PDS
    args.pds_stages = getattr(args, "pds_stages", None)
    args.pds_layers = getattr(args, "pds_layers", None)
    args.pds_ratios = getattr(args, "pds_ratios", None)
    args.pds_ds_method = getattr(args, "pds_ds_method", "conv")
    args.pds_embed_dims = getattr(args, "pds_embed_dims", None)
    args.pds_embed_norm = getattr(args, "pds_embed_norm", False)
    args.pds_position_embed = getattr(args, "pds_position_embed", None)
    args.pds_attn_heads = getattr(args, "pds_attn_heads", None)
    args.pds_ffn_ratios = getattr(args, "pds_ffn_ratios", None)
    args.pds_cnn_kernel_sizes = getattr(args, "pds_cnn_kernel_sizes", None)
    args.pds_attn_ds_ratios = getattr(args, "pds_attn_ds_ratios", None)
    args.pds_conv_strides = getattr(args, "pds_conv_strides", None)
    args.pds_attn_strides = getattr(args, "pds_attn_strides", None)
    args.pds_dropout = getattr(args, "pds_dropout", args.dropout)
    args.pds_fusion = getattr(args, "pds_fusion", False)
    args.pds_fusion_method = getattr(args, "pds_fusion_method", "all_conv")
    # dual
    args.junior_acoustic_encoder = getattr(args, "junior_acoustic_encoder", "transformer")
    args.senior_acoustic_encoder = getattr(args, "senior_acoustic_encoder", "transformer")
    args.textual_encoder = getattr(args, "textual_encoder", "transformer")
    # BUGFIX: the original read getattr(args, "textual_encoder", ...), which
    # always returned the architecture string and discarded any
    # user-supplied --textual-encoder-dim.
    args.textual_encoder_dim = getattr(args, "textual_encoder_dim", args.encoder_embed_dim)
    args.junior_acoustic_encoder_layers = getattr(
        args, "junior_acoustic_encoder_layers", 6
    )
    args.senior_acoustic_encoder_layers = getattr(
        args, "senior_acoustic_encoder_layers", 6
    )
    args.textual_encoder_layers = getattr(args, "textual_encoder_layers", 6)
    args.collaboration_direction = getattr(args, "collaboration_direction", "none")
    args.collaboration_step = getattr(args, "collaboration_step", "1:1")
    args.encoder_collaboration_mode = getattr(args, "encoder_collaboration_mode", "serial")
    args.decoder_collaboration_mode = getattr(args, "decoder_collaboration_mode", "serial")
    args.encoder_league_s1_ratio = getattr(args, "encoder_league_s1_ratio", 0.5)
    args.encoder_league_s2_ratio = getattr(args, "encoder_league_s2_ratio", 0.5)
    args.encoder_league_drop_net = getattr(args, "encoder_league_drop_net", False)
    args.encoder_league_drop_net_prob = getattr(args, "encoder_league_drop_net_prob", 0.0)
    args.encoder_league_drop_net_mix = getattr(args, "encoder_league_drop_net_mix", False)
    args.decoder_league_s1_ratio = getattr(args, "decoder_league_s1_ratio", 0.5)
    args.decoder_league_s2_ratio = getattr(args, "decoder_league_s2_ratio", 0.5)
    args.decoder_league_drop_net = getattr(args, "decoder_league_drop_net", False)
    args.decoder_league_drop_net_prob = getattr(args, "decoder_league_drop_net_prob", 0.0)
    args.decoder_league_drop_net_mix = getattr(args, "decoder_league_drop_net_mix", False)
    # args.encoder_asr_ratio = getattr(args, "encoder_asr_ratio", 1.0)
    # args.encoder_mt_ratio = getattr(args, "encoder_mt_ratio", 1.0)
    # args.encoder_drop_net = getattr(args, "encoder_drop_net", False)
    # args.encoder_drop_net_prob = getattr(args, "encoder_drop_net_prob", 1.0)
    # args.encoder_drop_net_mix = getattr(args, "encoder_drop_net_mix", False)
@register_model_architecture("s2t_multibranch", "s2t_multibranch_s")
def s2t_multibranch_s(args):
    """Small configuration: 256-dim embeddings, 8x FFN, 4 heads."""
    defaults = {
        "encoder_embed_dim": 256,
        "encoder_ffn_embed_dim": 256 * 8,
        "encoder_attention_heads": 4,
        "dropout": 0.1,
    }
    for name, value in defaults.items():
        setattr(args, name, getattr(args, name, value))
    base_architecture(args)
@register_model_architecture("s2t_multibranch", "s2t_multibranch_s_relative")
def s2t_multibranch_s_relative(args):
    """Small configuration with relative position encoding (window 100, keys only)."""
    # Unconditional overrides: relative attention is the point of this arch.
    args.max_encoder_relative_length = 100
    args.k_only = True
    s2t_multibranch_s(args)
@register_model_architecture("s2t_multibranch", "s2t_multibranch_xs")
def s2t_multibranch_xs(args):
    """Extra-small configuration: 6 layers, 4x FFN, heavier dropout."""
    defaults = {
        "encoder_layers": 6,
        "encoder_ffn_embed_dim": 256 * 4,
        "dropout": 0.3,
    }
    for name, value in defaults.items():
        setattr(args, name, getattr(args, name, value))
    s2t_multibranch_s(args)
@register_model_architecture("s2t_multibranch", "s2t_multibranch_sp")
def s2t_multibranch_sp(args):
    """Small-deep configuration: the small model with 16 encoder layers."""
    args.encoder_layers = getattr(args, "encoder_layers", 16)
    s2t_multibranch_s(args)
@register_model_architecture("s2t_multibranch", "s2t_multibranch_m")
def s2t_multibranch_m(args):
    """Medium multibranch S2T architecture: 512-dim encoder, 8 heads."""
    for key, default in (
        ("encoder_embed_dim", 512),
        ("encoder_ffn_embed_dim", 512 * 4),
        ("encoder_attention_heads", 8),
        ("dropout", 0.15),
    ):
        setattr(args, key, getattr(args, key, default))
    base_architecture(args)
@register_model_architecture("s2t_multibranch", "s2t_multibranch_mp")
def s2t_multibranch_mp(args):
    """Deep medium multibranch config: 16 encoder layers on top of `m`."""
    setattr(args, "encoder_layers", getattr(args, "encoder_layers", 16))
    s2t_multibranch_m(args)
@register_model_architecture("s2t_multibranch", "s2t_multibranch_l")
def s2t_multibranch_l(args):
    """Large multibranch S2T architecture: 1024-dim encoder, 16 heads."""
    for key, default in (
        ("encoder_embed_dim", 1024),
        ("encoder_ffn_embed_dim", 1024 * 4),
        ("encoder_attention_heads", 16),
        ("dropout", 0.2),
    ):
        setattr(args, key, getattr(args, key, default))
    base_architecture(args)
@register_model_architecture("s2t_multibranch", "s2t_multibranch_lp")
def s2t_multibranch_lp(args):
    """Deep large multibranch config: 16 encoder layers on top of `l`."""
    setattr(args, "encoder_layers", getattr(args, "encoder_layers", 16))
    s2t_multibranch_l(args)
fairseq/models/transformer_s2.py
查看文件 @
47e0f6e0
...
...
@@ -343,8 +343,8 @@ class TransformerS2Decoder(TransformerDecoder):
and
len
(
encoder_out
[
"encoder_padding_mask"
])
>
0
)
else
None
,
encoder_out_s2
=
encoder_out
[
"
encoder_out_s2
"
][
0
],
encoder_padding_mask_s2
=
encoder_out
[
"
encoder_padding_mask_s2
"
][
0
],
encoder_out_s2
=
encoder_out
[
"
s2_encoder_out
"
][
0
],
encoder_padding_mask_s2
=
encoder_out
[
"
s2_encoder_padding_mask
"
][
0
],
incremental_state
=
incremental_state
,
self_attn_mask
=
self_attn_mask
,
self_attn_padding_mask
=
self_attn_padding_mask
,
...
...
fairseq/modules/__init__.py
查看文件 @
47e0f6e0
...
...
@@ -61,6 +61,7 @@ from .espnet_multihead_attention import (
)
from
.convolution
import
ConvolutionModule
from
.s2t_transformer_layer
import
S2TTransformerEncoderLayer
from
.s2t_transformer_s2_layer
import
S2TTransformerS2EncoderLayer
from
.pds_layer
import
PDSTransformerEncoderLayer
__all__
=
[
...
...
@@ -70,6 +71,7 @@ __all__ = [
"BeamableMM"
,
"CharacterTokenEmbedder"
,
"S2TTransformerEncoderLayer"
,
"S2TTransformerS2EncoderLayer"
,
"ConvolutionModule"
,
"ConvTBC"
,
"cross_entropy"
,
...
...
fairseq/modules/s2t_transformer_s2_layer.py
0 → 100644
查看文件 @
47e0f6e0
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
from
typing
import
Optional
import
torch
import
torch.nn
as
nn
from
fairseq.modules
import
(
LayerNorm
,
MultiheadAttention
,
RelPositionMultiheadAttention
,
RelativeMultiheadAttention
,
ConvolutionModule
,
ESPNETMultiHeadedAttention
,
RelPositionMultiHeadedAttention
,
LegacyRelPositionMultiHeadedAttention
,
RotaryPositionMultiHeadedAttention
,
)
from
fairseq.modules.fairseq_dropout
import
FairseqDropout
from
torch
import
Tensor
from
fairseq.modules.activations
import
get_activation_class
class FeedForwardModule(torch.nn.Module):
    """Position-wise feed-forward layer used in conformer blocks.

    Applies ``Linear -> activation -> dropout -> Linear -> dropout`` while
    keeping the feature dimension unchanged.
    """

    def __init__(
        self,
        input_feat,
        hidden_units,
        dropout1,
        dropout2,
        activation_fn="relu",
        bias=True,
    ):
        """
        Args:
            input_feat: input (and output) feature dimension
            hidden_units: hidden dimension of the expansion layer
            dropout1: dropout probability applied after the activation
            dropout2: dropout probability applied after projecting back
            activation_fn: name of the activation function
            bias: whether the linear layers carry a bias term
        """
        super().__init__()
        # Attribute names (w_1, w_2, dropout1, dropout2) are part of the
        # state-dict layout; do not rename them.
        self.w_1 = torch.nn.Linear(input_feat, hidden_units, bias=bias)
        self.w_2 = torch.nn.Linear(hidden_units, input_feat, bias=bias)
        self.dropout1 = torch.nn.Dropout(dropout1)
        self.dropout2 = torch.nn.Dropout(dropout2)
        self.activation = get_activation_class(activation_fn)

    def forward(self, x):
        """
        Args:
            x: input tensor of shape ``(T, B, C)``
        Returns:
            tensor of shape ``(T, B, C)``
        """
        hidden = self.dropout1(self.activation(self.w_1(x)))
        return self.dropout2(self.w_2(hidden))
class S2TTransformerS2EncoderLayer(nn.Module):
    """Encoder layer block with an extra cross-attention over a second stream.

    In the original paper each operation (multi-head attention or FFN) is
    postprocessed with: `dropout -> add residual -> layernorm`. In the
    tensor2tensor code they suggest that learning is more robust when
    preprocessing each layer with layernorm and postprocessing with:
    `dropout -> add residual`. We default to the approach in the paper, but the
    tensor2tensor approach can be enabled by setting
    *args.encoder_normalize_before* to ``True``.

    On top of the standard (optionally conformer-style) encoder layer, this
    layer cross-attends from the main stream to a second stream ``s2`` via a
    dedicated attention module (``s2_attn``).

    Args:
        args (argparse.Namespace): parsed command-line arguments
    """

    def __init__(self, args):
        super().__init__()
        self.args = args
        embed_dim = args.encoder_embed_dim
        ffn_dim = args.encoder_ffn_embed_dim
        dropout = args.dropout

        self.embed_dim = embed_dim
        self.quant_noise = getattr(args, 'quant_noise_pq', 0)
        self.quant_noise_block_size = getattr(args, 'quant_noise_pq_block_size', 8) or 8
        self.attn_type = getattr(args, "encoder_attention_type", "selfattn")
        self.self_attn = self.build_self_attention(args, self.embed_dim)
        self.self_attn_layer_norm = LayerNorm(self.embed_dim)
        self.dropout_module = FairseqDropout(
            dropout, module_name=self.__class__.__name__
        )
        self.normalize_before = args.encoder_normalize_before

        activation = getattr(args, 'encoder_activation_fn', 'relu')
        # Macaron style sandwiches the layer between two half-weighted FFNs.
        if args.macaron_style:
            self.macaron_ffn = FeedForwardModule(
                embed_dim,
                ffn_dim,
                dropout,
                dropout,
                activation
            )
            self.macaron_norm = LayerNorm(embed_dim)
            self.ffn_scale = 0.5
        else:
            self.macaron_ffn = None
            self.macaron_norm = None
            self.ffn_scale = 1.0

        # Optional conformer convolution module.
        if args.use_cnn_module:
            self.conv_norm = LayerNorm(embed_dim)
            self.conv_module = ConvolutionModule(
                self.embed_dim,
                self.embed_dim,
                depthwise_kernel_size=args.cnn_module_kernel,
                dropout=args.dropout,
                activation_fn=getattr(args, 'activation_fn', 'swish'),
                norm_type=args.cnn_module_norm)
            self.final_norm = LayerNorm(embed_dim)
        else:
            self.conv_norm = None
            self.conv_module = None
            self.final_norm = None

        self.ffn = FeedForwardModule(
            embed_dim,
            ffn_dim,
            dropout,
            dropout,
            activation
        )
        self.ffn_norm = LayerNorm(self.embed_dim)

        # Cross-attention onto the second stream (key/value come from s2).
        self.s2_norm = LayerNorm(self.embed_dim)
        self.s2_attn_norm = LayerNorm(self.embed_dim)
        self.s2_attn = MultiheadAttention(
            self.embed_dim,
            args.encoder_attention_heads,
            kdim=getattr(args, "s2_encoder_embed_dim", self.embed_dim),
            vdim=getattr(args, "s2_encoder_embed_dim", self.embed_dim),
            dropout=args.attention_dropout,
            self_attention=False,
        )

    def build_self_attention(self, args, embed_dim):
        """Build the self-attention module selected by ``encoder_attention_type``.

        Unsupported types (or a missing relative-length setting for the
        "relative" type) print an error and terminate the process.
        """
        attention_heads = args.encoder_attention_heads
        dropout = args.dropout

        if self.attn_type == "selfattn":
            attn_func = MultiheadAttention
        elif self.attn_type == "rel_selfattn":
            attn_func = RelPositionMultiheadAttention
        elif self.attn_type == "relative":
            max_relative_length = max(
                getattr(args, "max_encoder_relative_length", -1),
                getattr(args, "max_relative_length", -1))
            if max_relative_length != -1:
                return RelativeMultiheadAttention(
                    embed_dim,
                    attention_heads,
                    dropout=dropout,
                    self_attention=True,
                    q_noise=self.quant_noise,
                    qn_block_size=self.quant_noise_block_size,
                    max_relative_length=max_relative_length,
                )
            else:
                print("The maximum encoder relative length %d can not be -1!" % max_relative_length)
                exit(1)
        elif self.attn_type == "rel_pos":
            return RelPositionMultiHeadedAttention(
                embed_dim,
                attention_heads,
                dropout=dropout,
            )
        elif self.attn_type == "rel_pos_legacy":
            return LegacyRelPositionMultiHeadedAttention(
                embed_dim,
                attention_heads,
                dropout=dropout,
            )
        elif self.attn_type == "rope":
            return RotaryPositionMultiHeadedAttention(
                embed_dim,
                attention_heads,
                dropout=dropout,
                precision=args.fp16
            )
        elif self.attn_type == "abs":
            return ESPNETMultiHeadedAttention(
                embed_dim,
                attention_heads,
                dropout=dropout,
            )
        else:
            print("The encoder attention type %s is not supported!" % self.attn_type)
            exit(1)

        return attn_func(
            embed_dim,
            attention_heads,
            dropout=dropout,
            self_attention=True,
            q_noise=self.quant_noise,
            qn_block_size=self.quant_noise_block_size,
        )

    def residual_connection(self, x, residual):
        return residual + x

    def upgrade_state_dict_named(self, state_dict, name):
        """
        Rename layer norm states from `...layer_norms.0.weight` to
        `...self_attn_layer_norm.weight` and `...layer_norms.1.weight` to
        `...final_layer_norm.weight`
        """
        # NOTE(review): legacy mapping inherited from the base encoder layer;
        # it is a no-op when the old keys are absent from the checkpoint.
        layer_norm_map = {"0": "self_attn_layer_norm", "1": "final_layer_norm"}
        for old, new in layer_norm_map.items():
            for m in ("weight", "bias"):
                k = "{}.layer_norms.{}.{}".format(name, old, m)
                if k in state_dict:
                    state_dict[
                        "{}.{}.{}".format(name, new, m)
                    ] = state_dict[k]
                    del state_dict[k]

    def forward(self,
                x,
                encoder_padding_mask: Optional[Tensor],
                s2=None,
                s2_encoder_padding_mask=None,
                attn_mask: Optional[Tensor] = None,
                pos_emb: Optional[Tensor] = None):
        """
        Args:
            x (Tensor): input to the layer of shape `(seq_len, batch, embed_dim)`
            encoder_padding_mask (ByteTensor): binary ByteTensor of shape
                `(batch, seq_len)` where padding elements are indicated by ``1``.
            s2 (Tensor, optional): features of the second stream, attended as
                key/value by the cross-attention (``s2_attn``)
            s2_encoder_padding_mask (ByteTensor, optional): padding mask for `s2`
            attn_mask (ByteTensor): binary tensor of shape `(tgt_len, src_len)`,
                where `tgt_len` is the length of output and `src_len` is the
                length of input, though here both are equal to `seq_len`.
                `attn_mask[tgt_i, src_j] = 1` means that when calculating the
                embedding for `tgt_i`, we exclude (mask out) `src_j`. This is
                useful for strided self-attention.
            pos_emb (Tensor): the position embedding for relative position encoding

        Returns:
            encoded output of shape `(seq_len, batch, embed_dim)`
        """
        # anything in original attn_mask = 1, becomes -1e8
        # anything in original attn_mask = 0, becomes 0
        # Note that we cannot use -inf here, because at some edge cases,
        # the attention weight (before softmax) for some padded element in query
        # will become -inf, which results in NaN in model parameters
        if attn_mask is not None:
            attn_mask = attn_mask.masked_fill(attn_mask.to(torch.bool), -1e8)

        # whether to use macaron style
        if self.macaron_norm is not None:
            residual = x
            if self.normalize_before:
                x = self.macaron_norm(x)
            x = self.macaron_ffn(x)
            x = residual + self.ffn_scale * x
            if not self.normalize_before:
                x = self.macaron_norm(x)

        residual = x
        if self.normalize_before:
            x = self.self_attn_layer_norm(x)

        if self.attn_type in ["rel_pos", "rel_pos_legacy", "rel_selfattn"]:
            assert pos_emb is not None, "Positions is necessary for RPE!"
            x, _ = self.self_attn(
                query=x,
                key=x,
                value=x,
                key_padding_mask=encoder_padding_mask,
                need_weights=False,
                attn_mask=attn_mask,
                pos_emb=pos_emb
            )
        else:
            x, _ = self.self_attn(
                query=x,
                key=x,
                value=x,
                key_padding_mask=encoder_padding_mask,
                need_weights=False,
                attn_mask=attn_mask,
            )
        x = self.dropout_module(x)
        x = self.residual_connection(x, residual)
        if not self.normalize_before:
            x = self.self_attn_layer_norm(x)

        if s2 is not None:
            residual = x
            x = self.s2_attn_norm(x)
            s2 = self.s2_norm(s2)
            # BUG FIX: cross-attend to the second stream with the dedicated
            # s2_attn module. The original code called self.self_attn here,
            # which (a) left s2_attn unused and (b) would crash for the
            # relative-position attention types that require `pos_emb`.
            x, _ = self.s2_attn(
                query=x,
                key=s2,
                value=s2,
                key_padding_mask=s2_encoder_padding_mask,
                need_weights=False,
            )
            x = self.dropout_module(x)
            x = self.residual_connection(x, residual)

        # convolution module
        if self.conv_module is not None:
            residual = x
            x = x.transpose(0, 1)
            if self.normalize_before:
                x = self.conv_norm(x)
            x = self.conv_module(x)
            x = x.transpose(0, 1)
            x = residual + x
            if not self.normalize_before:
                x = self.conv_norm(x)

        residual = x
        if self.normalize_before:
            x = self.ffn_norm(x)
        x = self.ffn(x)
        x = self.residual_connection(self.ffn_scale * x, residual)
        if not self.normalize_before:
            x = self.ffn_norm(x)

        if self.conv_module is not None:
            x = self.final_norm(x)

        return x
fairseq/modules/transformer_s2_layer.py
查看文件 @
47e0f6e0
...
...
@@ -79,6 +79,8 @@ class TransformerS2EncoderLayer(nn.Module):
if
self
.
use_se
:
self
.
se_attn
=
SEAttention
(
self
.
embed_dim
,
16
)
self
.
s2_norm
=
LayerNorm
(
self
.
embed_dim
)
self
.
s2_attn_norm
=
LayerNorm
(
self
.
embed_dim
)
self
.
s2_attn
=
MultiheadAttention
(
self
.
embed_dim
,
args
.
encoder_attention_heads
,
...
...
@@ -87,26 +89,28 @@ class TransformerS2EncoderLayer(nn.Module):
dropout
=
args
.
attention_dropout
,
self_attention
=
False
,
)
self
.
s1_ratio
=
args
.
encoder_s1_ratio
self
.
s2_ratio
=
args
.
encoder_s2_ratio
self
.
drop_net
=
args
.
encoder_drop_net
self
.
drop_net_prob
=
args
.
encoder_drop_net_prob
self
.
drop_net_mix
=
args
.
encoder_drop_net_mix
self
.
encoder_collaboration_mode
=
args
.
encoder_collaboration_mode
self
.
league_s1_ratio
=
args
.
encoder_league_s1_ratio
self
.
league_s2_ratio
=
args
.
encoder_league_s2_ratio
self
.
league_drop_net
=
args
.
encoder_league_drop_net
self
.
league_drop_net_prob
=
args
.
encoder_league_drop_net_prob
self
.
league_drop_net_mix
=
args
.
encoder_league_drop_net_mix
def
get_ratio
(
self
):
if
self
.
drop_net
:
if
self
.
league_
drop_net
:
frand
=
float
(
uniform
(
0
,
1
))
if
self
.
drop_net_mix
and
self
.
training
:
return
[
frand
,
1
-
frand
]
if
frand
<
self
.
drop_net_prob
and
self
.
training
:
if
frand
<
self
.
league_
drop_net_prob
and
self
.
training
:
return
[
1
,
0
]
elif
frand
>
1
-
self
.
drop_net_prob
and
self
.
training
:
elif
frand
>
1
-
self
.
league_
drop_net_prob
and
self
.
training
:
return
[
0
,
1
]
else
:
return
[
0.5
,
0.5
]
else
:
return
[
self
.
s1_ratio
,
self
.
s2_ratio
]
return
[
self
.
league_s1_ratio
,
self
.
league_
s2_ratio
]
def
build_fc1
(
self
,
input_dim
,
output_dim
,
q_noise
,
qn_block_size
):
return
quant_noise
(
...
...
@@ -186,8 +190,8 @@ class TransformerS2EncoderLayer(nn.Module):
def
forward
(
self
,
x
,
encoder_padding_mask
:
Optional
[
Tensor
],
x
2
=
None
,
x
2_encoder_padding_mask
=
None
,
s
2
=
None
,
s
2_encoder_padding_mask
=
None
,
attn_mask
:
Optional
[
Tensor
]
=
None
,
pos_emb
:
Optional
[
Tensor
]
=
None
):
"""
...
...
@@ -219,6 +223,7 @@ class TransformerS2EncoderLayer(nn.Module):
residual
=
x
if
self
.
normalize_before
:
x
=
self
.
self_attn_layer_norm
(
x
)
attn_x
=
x
if
self
.
attn_type
==
"rel_selfattn"
:
assert
pos_emb
is
not
None
,
"Positions is necessary for RPE!"
x
,
_
=
self
.
self_attn
(
...
...
@@ -240,20 +245,34 @@ class TransformerS2EncoderLayer(nn.Module):
attn_mask
=
attn_mask
,
)
x
=
self
.
dropout_module
(
x
)
if
s2
is
None
or
self
.
encoder_collaboration_mode
!=
"parallel"
:
x
=
self
.
residual_connection
(
x
,
residual
)
if
not
self
.
normalize_before
:
x
=
self
.
self_attn_layer_norm
(
x
)
if
x2
is
not
None
:
x2
,
_
=
self
.
s2_attn
(
if
s2
is
not
None
:
s2
=
self
.
s2_norm
(
s2
)
if
self
.
encoder_collaboration_mode
==
"serial"
:
residual
=
x
x
=
self
.
s2_attn_norm
(
x
)
x
,
_
=
self
.
s2_attn
(
query
=
x
,
key
=
x2
,
value
=
x2
,
key_padding_mask
=
x2_encoder_padding_mask
)
key
=
s2
,
value
=
s2
,
key_padding_mask
=
s2_encoder_padding_mask
)
x
=
self
.
dropout_module
(
x
)
x
=
self
.
residual_connection
(
x
,
residual
)
elif
self
.
encoder_collaboration_mode
==
"parallel"
:
x2
,
_
=
self
.
s2_attn
(
query
=
attn_x
,
key
=
s2
,
value
=
s2
,
key_padding_mask
=
s2_encoder_padding_mask
)
x2
=
self
.
dropout_module
(
x2
)
ratio
=
self
.
get_ratio
()
x
=
x
*
ratio
[
0
]
+
x2
*
ratio
[
1
]
x
=
self
.
residual_connection
(
x
,
residual
)
if
not
self
.
normalize_before
:
x
=
self
.
self_attn_layer_norm
(
x
)
residual
=
x
if
self
.
normalize_before
:
...
...
@@ -341,11 +360,12 @@ class TransformerS2DecoderLayer(nn.Module):
self
.
s2_attn
=
MultiheadAttention
(
self
.
embed_dim
,
args
.
decoder_attention_heads
,
kdim
=
getattr
(
args
,
"encoder_
x
2_dim"
,
self
.
embed_dim
),
vdim
=
getattr
(
args
,
"encoder_
x
2_dim"
,
self
.
embed_dim
),
kdim
=
getattr
(
args
,
"encoder_
s
2_dim"
,
self
.
embed_dim
),
vdim
=
getattr
(
args
,
"encoder_
s
2_dim"
,
self
.
embed_dim
),
dropout
=
args
.
attention_dropout
,
encoder_decoder_attention
=
True
,
)
self
.
s2_attn_layer_norm
=
LayerNorm
(
self
.
embed_dim
)
self
.
fc1
=
self
.
build_fc1
(
self
.
embed_dim
,
...
...
@@ -365,26 +385,27 @@ class TransformerS2DecoderLayer(nn.Module):
self
.
onnx_trace
=
False
self
.
s1_ratio
=
args
.
encoder_s1_ratio
self
.
s2_ratio
=
args
.
encoder_s2_ratio
self
.
decoder_collaboration_mode
=
args
.
decoder_collaboration_mode
self
.
league_s1_ratio
=
args
.
decoder_league_s1_ratio
self
.
league_s2_ratio
=
args
.
decoder_league_s2_ratio
self
.
drop_net
=
args
.
encoder
_drop_net
self
.
drop_net_prob
=
args
.
encoder
_drop_net_prob
self
.
drop_net_mix
=
args
.
encoder
_drop_net_mix
self
.
league_drop_net
=
args
.
decoder_league
_drop_net
self
.
league_drop_net_prob
=
args
.
decoder_league
_drop_net_prob
self
.
league_drop_net_mix
=
args
.
decoder_league
_drop_net_mix
def
get_ratio
(
self
):
if
self
.
drop_net
:
if
self
.
league_
drop_net
:
frand
=
float
(
uniform
(
0
,
1
))
if
self
.
drop_net_mix
and
self
.
training
:
return
[
frand
,
1
-
frand
]
if
frand
<
self
.
drop_net_prob
and
self
.
training
:
if
frand
<
self
.
league_
drop_net_prob
and
self
.
training
:
return
[
1
,
0
]
elif
frand
>
1
-
self
.
drop_net_prob
and
self
.
training
:
elif
frand
>
1
-
self
.
league_
drop_net_prob
and
self
.
training
:
return
[
0
,
1
]
else
:
return
[
0.5
,
0.5
]
else
:
return
[
self
.
s1_ratio
,
self
.
s2_ratio
]
return
[
self
.
league_s1_ratio
,
self
.
league_
s2_ratio
]
def
build_fc1
(
self
,
input_dim
,
output_dim
,
q_noise
,
qn_block_size
):
return
quant_noise
(
nn
.
Linear
(
input_dim
,
output_dim
),
q_noise
,
qn_block_size
)
...
...
@@ -551,6 +572,8 @@ class TransformerS2DecoderLayer(nn.Module):
residual
=
x
if
self
.
normalize_before
:
x
=
self
.
encoder_attn_layer_norm
(
x
)
cross_attn_x
=
x
if
prev_attn_state
is
not
None
:
prev_key
,
prev_value
=
prev_attn_state
[:
2
]
saved_state
:
Dict
[
str
,
Optional
[
Tensor
]]
=
{
...
...
@@ -575,9 +598,16 @@ class TransformerS2DecoderLayer(nn.Module):
need_head_weights
=
need_head_weights
,
)
x
=
self
.
dropout_module
(
x
)
if
encoder_out_s2
is
None
or
self
.
decoder_collaboration_mode
!=
"parallel"
:
x
=
self
.
residual_connection
(
x
,
residual
)
if
not
self
.
normalize_before
:
x
=
self
.
encoder_attn_layer_norm
(
x
)
if
encoder_out_s2
is
not
None
:
x2
,
_
=
self
.
s2_attn
(
if
self
.
decoder_collaboration_mode
==
"serial"
:
residual
=
x
x
=
self
.
s2_attn_layer_norm
(
x
)
x
,
_
=
self
.
s2_attn
(
query
=
x
,
key
=
encoder_out_s2
,
value
=
encoder_out_s2
,
...
...
@@ -587,10 +617,23 @@ class TransformerS2DecoderLayer(nn.Module):
need_weights
=
need_attn
or
(
not
self
.
training
and
self
.
need_attn
),
need_head_weights
=
need_head_weights
,
)
x
=
self
.
dropout_module
(
x
)
x
=
self
.
residual_connection
(
x
,
residual
)
elif
self
.
decoder_collaboration_mode
==
"parallel"
:
x2
,
_
=
self
.
s2_attn
(
query
=
cross_attn_x
,
key
=
encoder_out_s2
,
value
=
encoder_out_s2
,
key_padding_mask
=
encoder_padding_mask_s2
,
incremental_state
=
incremental_state
,
static_kv
=
True
,
need_weights
=
need_attn
or
(
not
self
.
training
and
self
.
need_attn
),
need_head_weights
=
need_head_weights
,
)
x2
=
self
.
dropout_module
(
x2
)
ratios
=
self
.
get_ratio
()
x
=
ratios
[
0
]
*
x
+
ratios
[
1
]
*
x2
x
=
x
+
x2
x
=
self
.
residual_connection
(
x
,
residual
)
if
not
self
.
normalize_before
:
x
=
self
.
encoder_attn_layer_norm
(
x
)
...
...
编写
预览
Markdown
格式
0%
重试
或
添加新文件
添加附件
取消
您添加了
0
人
到此讨论。请谨慎行事。
请先完成此评论的编辑!
取消
请
注册
或者
登录
后发表评论