Skip to content
项目
群组
代码片段
帮助
当前项目
正在载入...
登录 / 注册
切换导航面板
F
Fairseq-S2T
概览
Overview
Details
Activity
Cycle Analytics
版本库
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Charts
问题
0
Issues
0
列表
Board
标记
里程碑
合并请求
0
Merge Requests
0
CI / CD
CI / CD
流水线
作业
日程表
图表
维基
Wiki
代码片段
Snippets
成员
Collapse sidebar
Close sidebar
活动
图像
聊天
创建新问题
作业
提交
Issue Boards
Open sidebar
xuchen
Fairseq-S2T
Commits
47e0f6e0
Commit
47e0f6e0
authored
Aug 31, 2022
by
xuchen
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
add the multibranch S2T architecture.
I also find some bugs in the dual architecture.
parent
793f553a
隐藏空白字符变更
内嵌
并排
正在显示
7 个修改的文件
包含
1221 行增加
和
73 行删除
+1221
-73
fairseq/models/speech_to_text/__init__.py
+1
-0
fairseq/models/speech_to_text/s2t_dual.py
+0
-18
fairseq/models/speech_to_text/s2t_multibranch.py
+777
-0
fairseq/models/transformer_s2.py
+2
-2
fairseq/modules/__init__.py
+2
-0
fairseq/modules/s2t_transformer_s2_layer.py
+343
-0
fairseq/modules/transformer_s2_layer.py
+96
-53
没有找到文件。
fairseq/models/speech_to_text/__init__.py
查看文件 @
47e0f6e0
...
@@ -10,3 +10,4 @@ from .pdss2t_transformer import * # noqa
...
@@ -10,3 +10,4 @@ from .pdss2t_transformer import * # noqa
from
.s2t_sate
import
*
# noqa
from
.s2t_sate
import
*
# noqa
from
.s2t_dual
import
*
# noqa
from
.s2t_dual
import
*
# noqa
from
.s2t_ctc
import
*
from
.s2t_ctc
import
*
from
.s2t_multibranch
import
*
fairseq/models/speech_to_text/s2t_dual.py
查看文件 @
47e0f6e0
from
fairseq.models
import
(
FairseqEncoder
,
FairseqEncoderModel
,
register_model
,
register_model_architecture
,
)
import
logging
import
logging
import
math
from
typing
import
Dict
,
List
,
Optional
,
Tuple
from
typing
import
Dict
,
List
,
Optional
,
Tuple
import
torch
import
torch
...
@@ -35,17 +28,6 @@ from fairseq.models.transformer_s2 import (
...
@@ -35,17 +28,6 @@ from fairseq.models.transformer_s2 import (
TransformerS2Encoder
,
TransformerS2Encoder
,
TransformerS2Decoder
,
TransformerS2Decoder
,
)
)
from
fairseq.modules
import
(
FairseqDropout
,
LayerNorm
,
PositionalEmbedding
,
LegacyRelPositionalEncoding
,
RelPositionalEncoding
,
S2TTransformerEncoderLayer
,
DynamicLinearCombination
,
TransformerS2DecoderLayer
,
TransformerS2EncoderLayer
,
)
logger
=
logging
.
getLogger
(
__name__
)
logger
=
logging
.
getLogger
(
__name__
)
...
...
fairseq/models/speech_to_text/s2t_multibranch.py
0 → 100644
查看文件 @
47e0f6e0
import
logging
from
typing
import
Dict
,
List
,
Optional
,
Tuple
import
torch.nn
as
nn
from
torch
import
Tensor
from
fairseq
import
checkpoint_utils
,
utils
from
fairseq.models.speech_to_text
import
(
S2TTransformerModel
,
S2TTransformerEncoder
,
PDSS2TTransformerModel
,
PDSS2TTransformerEncoder
,
S2TSATEModel
,
)
from
fairseq.models
import
(
FairseqEncoder
,
FairseqEncoderDecoderModel
,
register_model
,
register_model_architecture
,
)
from
fairseq.modules
import
(
LayerNorm
,
PositionalEmbedding
,
LegacyRelPositionalEncoding
,
RelPositionalEncoding
,
S2TTransformerS2EncoderLayer
,
TransformerS2EncoderLayer
,
)
from
fairseq.modules.speech_to_text
import
Adapter
from
fairseq.models.transformer_s2
import
(
Embedding
,
TransformerS2Decoder
,
)
logger
=
logging
.
getLogger
(
__name__
)
@register_model
(
"s2t_multibranch"
)
class
S2TMultiBranchModel
(
FairseqEncoderDecoderModel
):
def
__init__
(
self
,
encoder
,
decoder
):
super
()
.
__init__
(
encoder
,
decoder
)
@staticmethod
def
add_args
(
parser
):
"""Add model-specific arguments to the parser."""
S2TTransformerModel
.
add_args
(
parser
)
PDSS2TTransformerModel
.
add_specific_args
(
parser
)
S2TSATEModel
.
add_specific_args
(
parser
)
S2TMultiBranchModel
.
add_specific_args
(
parser
)
@staticmethod
def
add_specific_args
(
parser
):
# multibranch
parser
.
add_argument
(
"--junior-acoustic-encoder"
,
default
=
"transformer"
,
choices
=
[
"transformer"
,
"pds"
,
"sate"
,
"wav2vec"
],
type
=
str
,
help
=
"the architecture of the junior acoustic encoder"
,
)
parser
.
add_argument
(
"--senior-acoustic-encoder"
,
default
=
"transformer"
,
choices
=
[
"transformer"
,
"pds"
,
"sate"
,
"wav2vec"
],
type
=
str
,
help
=
"the architecture of the senior acoustic ASR encoder"
,
)
parser
.
add_argument
(
"--textual-encoder"
,
default
=
"transformer"
,
type
=
str
,
help
=
"the architecture of the MT encoder"
,
)
parser
.
add_argument
(
"--textual-encoder-dim"
,
type
=
int
,
help
=
"the dimension of the textual encoder"
,
)
parser
.
add_argument
(
"--junior-acoustic-encoder-layers"
,
default
=
6
,
type
=
int
,
help
=
"the layers of the senior acoustic encoder"
,
)
parser
.
add_argument
(
"--senior-acoustic-encoder-layers"
,
default
=
6
,
type
=
int
,
help
=
"the layers of the senior acoustic encoder"
,
)
parser
.
add_argument
(
"--textual-encoder-layers"
,
default
=
6
,
type
=
int
,
help
=
"the layers of the textual encoder"
,
)
# collaboration
parser
.
add_argument
(
"--collaboration-direction"
,
default
=
"none"
,
type
=
str
,
help
=
"direction of collaboration"
,
)
parser
.
add_argument
(
"--collaboration-step"
,
default
=
"1:1"
,
type
=
str
,
help
=
"collaboration step in two encoders"
,
)
parser
.
add_argument
(
"--encoder-collaboration-mode"
,
default
=
"serial"
,
type
=
str
,
help
=
"how to calculate attention during league in encoder"
,
)
parser
.
add_argument
(
"--decoder-collaboration-mode"
,
default
=
"serial"
,
type
=
str
,
help
=
"how to calculate attention during league in encoder"
,
)
# league
parser
.
add_argument
(
"--encoder-league-s1-ratio"
,
default
=
0.5
,
type
=
float
,
help
=
"league ratio of the s1 representation"
,
)
parser
.
add_argument
(
"--encoder-league-s2-ratio"
,
default
=
0.5
,
type
=
float
,
help
=
"league ratio of the s2 representation"
,
)
parser
.
add_argument
(
"--encoder-league-drop-net"
,
action
=
"store_true"
,
help
=
"drop one input during league"
,
)
parser
.
add_argument
(
"--encoder-league-drop-net-prob"
,
default
=
0.0
,
type
=
float
,
help
=
"probability of dropping one representations"
,
)
parser
.
add_argument
(
"--encoder-league-drop-net-mix"
,
action
=
"store_true"
,
help
=
"mix the two input with any probability"
,
)
parser
.
add_argument
(
"--decoder-league-s1-ratio"
,
default
=
0.5
,
type
=
float
,
help
=
"league ratio of the s1 representation"
,
)
parser
.
add_argument
(
"--decoder-league-s2-ratio"
,
default
=
0.5
,
type
=
float
,
help
=
"league ratio of the s2 representation"
,
)
parser
.
add_argument
(
"--decoder-league-drop-net"
,
action
=
"store_true"
,
help
=
"drop one input during league"
,
)
parser
.
add_argument
(
"--decoder-league-drop-net-prob"
,
default
=
0.0
,
type
=
float
,
help
=
"probability of dropping one representations"
,
)
parser
.
add_argument
(
"--decoder-league-drop-net-mix"
,
action
=
"store_true"
,
help
=
"mix the two input with any probability"
,
)
parser
.
add_argument
(
"--load-pretrained-junior-acoustic-encoder-from"
,
type
=
str
,
metavar
=
"STR"
,
help
=
"model to take junior acoustic encoder weights from (for initialization)"
,
)
parser
.
add_argument
(
"--load-pretrained-senior-acoustic-encoder-from"
,
type
=
str
,
metavar
=
"STR"
,
help
=
"model to take senior acoustic encoder weights from (for initialization)"
,
)
parser
.
add_argument
(
"--load-pretrained-textual-encoder-from"
,
type
=
str
,
metavar
=
"STR"
,
help
=
"model to take textual encoder weights from (for initialization)"
,
)
pass
@classmethod
def
build_encoder
(
cls
,
args
,
task
=
None
,
embed_tokens
=
None
):
encoder
=
S2TMultiBranchEncoder
(
args
,
task
,
embed_tokens
)
if
getattr
(
args
,
"load_pretrained_encoder_from"
,
None
):
encoder
=
checkpoint_utils
.
load_pretrained_component_from_model
(
component
=
encoder
,
checkpoint
=
args
.
load_pretrained_encoder_from
,
strict
=
False
)
logger
.
info
(
f
"loaded pretrained encoder from: "
f
"{args.load_pretrained_encoder_from}"
)
if
getattr
(
args
,
"load_pretrained_junior_encoder_from"
,
None
):
encoder
.
junior_acoustic_encoder
=
checkpoint_utils
.
load_pretrained_component_from_model
(
component
=
encoder
.
asr_encoder
,
checkpoint
=
args
.
load_pretrained_junior_encoder_from
,
strict
=
False
)
logger
.
info
(
f
"loaded pretrained junior acoustic encoder from: "
f
"{args.load_pretrained_junior_encoder_from}"
)
if
getattr
(
args
,
"load_pretrained_senior_encoder_from"
,
None
):
encoder
.
senior_acoustic_encoder
=
checkpoint_utils
.
load_pretrained_component_from_model
(
component
=
encoder
.
asr_encoder
,
checkpoint
=
args
.
load_pretrained_senior_encoder_from
,
strict
=
False
)
logger
.
info
(
f
"loaded pretrained senior acoustic encoder from: "
f
"{args.load_pretrained_senior_encoder_from}"
)
if
getattr
(
args
,
"load_pretrained_textual_encoder_from"
,
None
):
encoder
.
textual_encoder
=
checkpoint_utils
.
load_pretrained_component_from_model
(
component
=
encoder
.
mt_encoder
,
checkpoint
=
args
.
load_pretrained_textual_encoder_from
,
strict
=
False
)
logger
.
info
(
f
"loaded pretrained textual encoder from: "
f
"{args.load_pretrained_textual_encoder_from}"
)
return
encoder
@classmethod
def
build_decoder
(
cls
,
args
,
task
,
embed_tokens
):
decoder
=
TransformerS2Decoder
(
args
,
task
.
target_dictionary
,
embed_tokens
)
if
getattr
(
args
,
"load_pretrained_decoder_from"
,
None
):
logger
.
info
(
f
"loaded pretrained decoder from: "
f
"{args.load_pretrained_decoder_from}"
)
decoder
=
checkpoint_utils
.
load_pretrained_component_from_model
(
component
=
decoder
,
checkpoint
=
args
.
load_pretrained_decoder_from
,
strict
=
False
)
return
decoder
@classmethod
def
build_model
(
cls
,
args
,
task
):
"""Build a new model instance."""
# make sure all arguments are present in older models
base_architecture
(
args
)
def
build_embedding
(
dictionary
,
embed_dim
):
num_embeddings
=
len
(
dictionary
)
padding_idx
=
dictionary
.
pad
()
return
Embedding
(
num_embeddings
,
embed_dim
,
padding_idx
)
src_dict
,
tgt_dict
=
task
.
source_dictionary
,
task
.
target_dictionary
if
args
.
share_all_embeddings
:
if
src_dict
!=
tgt_dict
:
raise
ValueError
(
"--share-all-embeddings requires a joined dictionary"
)
if
args
.
encoder_embed_dim
!=
args
.
decoder_embed_dim
:
raise
ValueError
(
"--share-all-embeddings requires --encoder-embed-dim to match --decoder-embed-dim"
)
encoder_embed_tokens
=
build_embedding
(
src_dict
,
args
.
encoder_embed_dim
)
decoder_embed_tokens
=
encoder_embed_tokens
args
.
share_decoder_input_output_embed
=
True
else
:
encoder_embed_tokens
=
build_embedding
(
src_dict
,
args
.
encoder_embed_dim
)
decoder_embed_tokens
=
build_embedding
(
tgt_dict
,
args
.
decoder_embed_dim
)
encoder
=
cls
.
build_encoder
(
args
,
task
,
encoder_embed_tokens
)
if
getattr
(
args
,
"encoder_freeze_module"
,
None
):
utils
.
freeze_parameters
(
encoder
,
args
.
encoder_freeze_module
)
logging
.
info
(
"freeze the encoder module: {}"
.
format
(
args
.
encoder_freeze_module
))
decoder
=
cls
.
build_decoder
(
args
,
task
,
decoder_embed_tokens
)
if
getattr
(
args
,
"decoder_freeze_module"
,
None
):
utils
.
freeze_parameters
(
decoder
,
args
.
decoder_freeze_module
)
logging
.
info
(
"freeze the decoder module: {}"
.
format
(
args
.
decoder_freeze_module
))
return
cls
(
encoder
,
decoder
)
def
get_normalized_probs
(
self
,
net_output
:
Tuple
[
Tensor
,
Optional
[
Dict
[
str
,
List
[
Optional
[
Tensor
]]]]],
log_probs
:
bool
,
sample
:
Optional
[
Dict
[
str
,
Tensor
]]
=
None
,
):
# net_output['encoder_out'] is a (B, T, D) tensor
lprobs
=
self
.
get_normalized_probs_scriptable
(
net_output
,
log_probs
,
sample
)
lprobs
.
batch_first
=
True
return
lprobs
def
forward
(
self
,
src_tokens
,
src_lengths
,
prev_output_tokens
,
**
kwargs
):
"""
The forward method inherited from the base class has a **kwargs
argument in its input, which is not supported in torchscript. This
method overwrites the forward method definition without **kwargs.
"""
encoder_out
=
self
.
encoder
(
src_tokens
,
src_lengths
)
decoder_out
=
self
.
decoder
(
prev_output_tokens
=
prev_output_tokens
,
encoder_out
=
encoder_out
)
return
decoder_out
class
S2TMultiBranchEncoder
(
FairseqEncoder
):
"""Speech-to-text Transformer encoder that consists of input subsampler and
Transformer encoder."""
def
__init__
(
self
,
args
,
task
=
None
,
embed_tokens
=
None
):
super
()
.
__init__
(
None
)
self
.
padding_idx
=
1
setattr
(
args
,
"encoder_layers"
,
args
.
junior_acoustic_encoder_layers
)
junior_encoder_type
=
args
.
junior_acoustic_encoder
if
junior_encoder_type
==
"transformer"
:
self
.
junior_acoustic_encoder
=
S2TTransformerEncoder
(
args
,
task
,
embed_tokens
)
elif
junior_encoder_type
==
"pds"
:
self
.
junior_acoustic_encoder
=
PDSS2TTransformerEncoder
(
args
,
task
,
embed_tokens
)
else
:
logger
.
error
(
"Unsupported junior acoustic architecture:
%
s."
%
junior_encoder_type
)
self
.
senior_acoustic_attn_type
=
getattr
(
args
,
"encoder_attention_type"
,
"selfattn"
)
if
self
.
senior_acoustic_attn_type
==
"rel_pos"
:
self
.
senior_acoustic_embed_positions
=
RelPositionalEncoding
(
args
.
max_source_positions
,
args
.
encoder_embed_dim
)
elif
self
.
senior_acoustic_attn_type
in
[
"rel_selfattn"
,
"rel_pos_legacy"
]:
self
.
senior_acoustic_embed_positions
=
LegacyRelPositionalEncoding
(
args
.
encoder_embed_dim
,
args
.
dropout
,
args
.
max_source_positions
)
else
:
# Use absolute positional embedding
self
.
senior_acoustic_embed_positions
=
None
setattr
(
args
,
"collaboration_mode"
,
args
.
encoder_collaboration_mode
)
self
.
senior_acoustic_encoder_layer_num
=
args
.
senior_acoustic_encoder_layers
self
.
senior_acoustic_encoder_layers
=
nn
.
ModuleList
(
[
S2TTransformerS2EncoderLayer
(
args
)
for
_
in
range
(
self
.
senior_acoustic_encoder_layer_num
)])
# adapter
self
.
adapter_temperature
=
args
.
adapter_temperature
strategy
=
{
"embed_norm"
:
getattr
(
args
,
"adapter_embed_norm"
,
False
),
"out_norm"
:
getattr
(
args
,
"adapter_out_norm"
,
False
),
"ctc_compress_strategy"
:
getattr
(
args
,
"ctc_compress_strategy"
,
None
),
"distribution_cutoff"
:
getattr
(
args
,
"adapter_distribution_cutoff"
,
None
),
"drop_prob"
:
getattr
(
args
,
"adapter_drop_prob"
,
0
),
}
self
.
adapter
=
Adapter
(
args
.
encoder_embed_dim
,
args
.
adapter
,
len
(
task
.
source_dictionary
),
strategy
=
strategy
)
assert
not
(
args
.
share_adapter_and_ctc
and
args
.
share_adapter_and_embed
),
"Can not be True at the same time"
if
args
.
share_adapter_and_ctc
and
hasattr
(
self
.
adapter
,
"embed_adapter"
):
self
.
adapter
.
embed_adapter
.
weight
=
self
.
acoustic_encoder
.
ctc
.
ctc_projection
.
weight
if
args
.
share_adapter_and_embed
and
hasattr
(
self
.
adapter
,
"embed_adapter"
):
self
.
adapter
.
embed_adapter
.
weight
=
embed_tokens
.
weight
# textual encoder
self
.
textual_embed_positions
=
PositionalEmbedding
(
args
.
max_source_positions
,
args
.
encoder_embed_dim
,
self
.
padding_idx
)
attn_type
=
args
.
encoder_attention_type
setattr
(
args
,
"encoder_attention_type"
,
"selfattn"
)
self
.
textual_encoder_layer_num
=
args
.
textual_encoder_layers
self
.
textual_encoder_layers
=
nn
.
ModuleList
(
[
TransformerS2EncoderLayer
(
args
)
for
_
in
range
(
self
.
textual_encoder_layer_num
)])
setattr
(
args
,
"encoder_attention_type"
,
attn_type
)
# collaboration
collaboration_step
=
args
.
collaboration_step
if
len
(
collaboration_step
.
split
(
":"
))
==
2
:
self
.
collaboration_step
=
[
int
(
s
)
for
s
in
collaboration_step
.
split
(
":"
)]
else
:
self
.
collaboration_step
=
[
1
,
1
]
self
.
collaboration_direction
=
args
.
collaboration_direction
self
.
acoustic_norm
=
LayerNorm
(
args
.
encoder_embed_dim
)
self
.
textual_norm
=
LayerNorm
(
args
.
encoder_embed_dim
)
def
forward
(
self
,
src_tokens
,
src_lengths
=
None
,
**
kwargs
):
junior_acoustic_encoder_out
=
self
.
junior_acoustic_encoder
(
src_tokens
,
src_lengths
,
**
kwargs
)
acoustic_x
=
junior_acoustic_encoder_out
[
"encoder_out"
][
0
]
acoustic_encoder_padding_mask
=
junior_acoustic_encoder_out
[
"encoder_padding_mask"
][
0
]
if
"ctc_logit"
in
junior_acoustic_encoder_out
and
len
(
junior_acoustic_encoder_out
[
"ctc_logit"
])
>
0
:
ctc_logit
=
junior_acoustic_encoder_out
[
"ctc_logit"
][
0
]
else
:
ctc_logit
=
None
x
=
(
acoustic_x
,
ctc_logit
)
adapter_x
,
adapter_encoder_padding_mask
=
self
.
adapter
(
x
,
acoustic_encoder_padding_mask
)
textual_x
=
adapter_x
+
self
.
textual_embed_positions
(
adapter_encoder_padding_mask
)
.
transpose
(
0
,
1
)
# textual_x = self.dropout_module(textual_x)
textual_encoder_padding_mask
=
adapter_encoder_padding_mask
senior_acoustic_encoder_idx
=
-
1
textual_encoder_idx
=
-
1
while
True
:
if
self
.
collaboration_direction
==
"acoustic"
:
for
_
in
range
(
self
.
collaboration_step
[
1
]):
textual_encoder_idx
+=
1
textual_x
=
self
.
textual_encoder_layers
[
textual_encoder_idx
](
textual_x
,
encoder_padding_mask
=
textual_encoder_padding_mask
,
)
for
_
in
range
(
self
.
collaboration_step
[
0
]):
senior_acoustic_encoder_idx
+=
1
acoustic_x
=
self
.
senior_acoustic_encoder_layers
[
senior_acoustic_encoder_idx
](
acoustic_x
,
encoder_padding_mask
=
acoustic_encoder_padding_mask
,
s2
=
textual_x
,
s2_encoder_padding_mask
=
textual_encoder_padding_mask
)
elif
self
.
collaboration_direction
==
"textual"
:
for
_
in
range
(
self
.
collaboration_step
[
0
]):
senior_acoustic_encoder_idx
+=
1
acoustic_x
=
self
.
senior_acoustic_encoder_layers
[
senior_acoustic_encoder_idx
](
acoustic_x
,
encoder_padding_mask
=
acoustic_encoder_padding_mask
,
)
for
_
in
range
(
self
.
collaboration_step
[
1
]):
textual_encoder_idx
+=
1
textual_x
=
self
.
textual_encoder_layers
[
textual_encoder_idx
](
textual_x
,
encoder_padding_mask
=
textual_encoder_padding_mask
,
xs
=
acoustic_x
,
s2_encoder_padding_mask
=
acoustic_encoder_padding_mask
)
elif
self
.
collaboration_direction
==
"both"
:
for
_
in
range
(
self
.
collaboration_step
[
0
]):
senior_acoustic_encoder_idx
+=
1
acoustic_x
=
self
.
senior_acoustic_encoder_layers
[
senior_acoustic_encoder_idx
](
acoustic_x
,
encoder_padding_mask
=
acoustic_encoder_padding_mask
,
s2
=
textual_x
,
s2_encoder_padding_mask
=
textual_encoder_padding_mask
)
for
_
in
range
(
self
.
collaboration_step
[
1
]):
textual_encoder_idx
+=
1
textual_x
=
self
.
textual_encoder_layers
[
textual_encoder_idx
](
textual_x
,
encoder_padding_mask
=
textual_encoder_padding_mask
,
s2
=
acoustic_x
,
s2_encoder_padding_mask
=
acoustic_encoder_padding_mask
)
elif
self
.
collaboration_direction
==
"none"
:
for
_
in
range
(
self
.
collaboration_step
[
0
]):
senior_acoustic_encoder_idx
+=
1
acoustic_x
=
self
.
senior_acoustic_encoder_layers
[
senior_acoustic_encoder_idx
](
acoustic_x
,
encoder_padding_mask
=
acoustic_encoder_padding_mask
,
)
for
_
in
range
(
self
.
collaboration_step
[
1
]):
textual_encoder_idx
+=
1
textual_x
=
self
.
textual_encoder_layers
[
textual_encoder_idx
](
textual_x
,
encoder_padding_mask
=
textual_encoder_padding_mask
,
)
if
senior_acoustic_encoder_idx
==
self
.
senior_acoustic_encoder_layer_num
-
1
and
\
textual_encoder_idx
==
self
.
textual_encoder_layer_num
-
1
:
break
acoustic_x
=
self
.
acoustic_norm
(
acoustic_x
)
textual_x
=
self
.
acoustic_norm
(
textual_x
)
junior_acoustic_encoder_out
[
"encoder_out"
]
=
[
acoustic_x
]
junior_acoustic_encoder_out
[
"encoder_padding_mask"
]
=
[
acoustic_encoder_padding_mask
]
junior_acoustic_encoder_out
[
"s2_encoder_out"
]
=
[
textual_x
]
junior_acoustic_encoder_out
[
"s2_encoder_padding_mask"
]
=
[
textual_encoder_padding_mask
]
return
junior_acoustic_encoder_out
def
reorder_encoder_out
(
self
,
encoder_out
,
new_order
):
"""
Reorder encoder output according to *new_order*.
Args:
encoder_out: output from the ``forward()`` method
new_order (LongTensor): desired order
Returns:
*encoder_out* rearranged according to *new_order*
"""
if
len
(
encoder_out
[
"encoder_out"
])
==
0
:
new_encoder_out
=
[]
else
:
new_encoder_out
=
[
encoder_out
[
"encoder_out"
][
0
]
.
index_select
(
1
,
new_order
)]
if
len
(
encoder_out
[
"encoder_padding_mask"
])
==
0
:
new_encoder_padding_mask
=
[]
else
:
new_encoder_padding_mask
=
[
encoder_out
[
"encoder_padding_mask"
][
0
]
.
index_select
(
0
,
new_order
)
]
if
len
(
encoder_out
[
"s2_encoder_out"
])
==
0
:
new_s2_encoder_out
=
[]
else
:
new_s2_encoder_out
=
[
encoder_out
[
"s2_encoder_out"
][
0
]
.
index_select
(
1
,
new_order
)]
if
len
(
encoder_out
[
"s2_encoder_padding_mask"
])
==
0
:
new_s2_encoder_padding_mask
=
[]
else
:
new_s2_encoder_padding_mask
=
[
encoder_out
[
"s2_encoder_padding_mask"
][
0
]
.
index_select
(
0
,
new_order
)
]
if
len
(
encoder_out
[
"encoder_embedding"
])
==
0
:
new_encoder_embedding
=
[]
else
:
new_encoder_embedding
=
[
encoder_out
[
"encoder_embedding"
][
0
]
.
index_select
(
0
,
new_order
)
]
if
len
(
encoder_out
[
"src_tokens"
])
==
0
:
src_tokens
=
[]
else
:
src_tokens
=
[(
encoder_out
[
"src_tokens"
][
0
])
.
index_select
(
0
,
new_order
)]
if
len
(
encoder_out
[
"src_lengths"
])
==
0
:
src_lengths
=
[]
else
:
src_lengths
=
[(
encoder_out
[
"src_lengths"
][
0
])
.
index_select
(
0
,
new_order
)]
encoder_states
=
encoder_out
[
"encoder_states"
]
if
len
(
encoder_states
)
>
0
:
for
idx
,
state
in
enumerate
(
encoder_states
):
encoder_states
[
idx
]
=
state
.
index_select
(
1
,
new_order
)
return
{
"encoder_out"
:
new_encoder_out
,
# T x B x C
"encoder_padding_mask"
:
new_encoder_padding_mask
,
# B x T
"s2_encoder_out"
:
new_s2_encoder_out
,
# T x B x C
"s2_encoder_padding_mask"
:
new_s2_encoder_padding_mask
,
# B x T
"encoder_embedding"
:
new_encoder_embedding
,
# B x T x C
"encoder_states"
:
encoder_states
,
# List[T x B x C]
"src_tokens"
:
src_tokens
,
# B x T
"src_lengths"
:
src_lengths
,
# B x 1
}
@register_model_architecture
(
model_name
=
"s2t_multibranch"
,
arch_name
=
"s2t_multibranch"
)
def
base_architecture
(
args
):
# Convolutional subsampler
args
.
subsampling_type
=
getattr
(
args
,
"subsampling_type"
,
"conv1d"
)
args
.
subsampling_layers
=
getattr
(
args
,
"subsampling_layers"
,
2
)
args
.
subsampling_filter
=
getattr
(
args
,
"subsampling_filter"
,
1024
)
args
.
subsampling_kernel
=
getattr
(
args
,
"subsampling_kernel"
,
5
)
args
.
subsampling_stride
=
getattr
(
args
,
"subsampling_stride"
,
2
)
args
.
subsampling_norm
=
getattr
(
args
,
"subsampling_norm"
,
"none"
)
args
.
subsampling_activation
=
getattr
(
args
,
"subsampling_activation"
,
"glu"
)
# Transformer
args
.
encoder_embed_dim
=
getattr
(
args
,
"encoder_embed_dim"
,
512
)
args
.
encoder_ffn_embed_dim
=
getattr
(
args
,
"encoder_ffn_embed_dim"
,
2048
)
args
.
encoder_layers
=
getattr
(
args
,
"encoder_layers"
,
12
)
args
.
encoder_attention_type
=
getattr
(
args
,
"encoder_attention_type"
,
"selfattn"
)
args
.
encoder_attention_heads
=
getattr
(
args
,
"encoder_attention_heads"
,
8
)
args
.
encoder_normalize_before
=
getattr
(
args
,
"encoder_normalize_before"
,
True
)
args
.
decoder_embed_dim
=
getattr
(
args
,
"decoder_embed_dim"
,
args
.
encoder_embed_dim
)
args
.
decoder_ffn_embed_dim
=
getattr
(
args
,
"decoder_ffn_embed_dim"
,
args
.
encoder_ffn_embed_dim
)
args
.
decoder_layers
=
getattr
(
args
,
"decoder_layers"
,
6
)
args
.
decoder_attention_type
=
getattr
(
args
,
"decoder_attention_type"
,
"selfattn"
)
args
.
decoder_attention_heads
=
getattr
(
args
,
"decoder_attention_heads"
,
8
)
args
.
decoder_normalize_before
=
getattr
(
args
,
"decoder_normalize_before"
,
True
)
args
.
decoder_learned_pos
=
getattr
(
args
,
"decoder_learned_pos"
,
False
)
args
.
dropout
=
getattr
(
args
,
"dropout"
,
0.1
)
args
.
attention_dropout
=
getattr
(
args
,
"attention_dropout"
,
args
.
dropout
)
args
.
activation_dropout
=
getattr
(
args
,
"activation_dropout"
,
args
.
dropout
)
args
.
activation_fn
=
getattr
(
args
,
"activation_fn"
,
"relu"
)
args
.
adaptive_softmax_cutoff
=
getattr
(
args
,
"adaptive_softmax_cutoff"
,
None
)
args
.
adaptive_softmax_dropout
=
getattr
(
args
,
"adaptive_softmax_dropout"
,
0
)
args
.
tie_adaptive_weights
=
getattr
(
args
,
"tie_adaptive_weights"
,
False
)
args
.
tie_adaptive_proj
=
getattr
(
args
,
"tie_adaptive_proj"
,
False
)
args
.
adaptive_softmax_factor
=
getattr
(
args
,
"adaptive_softmax_factor"
,
4
)
args
.
share_decoder_input_output_embed
=
getattr
(
args
,
"share_decoder_input_output_embed"
,
False
)
args
.
share_all_embeddings
=
getattr
(
args
,
"share_all_embeddings"
,
False
)
args
.
no_token_positional_embeddings
=
getattr
(
args
,
"no_token_positional_embeddings"
,
False
)
args
.
adaptive_input
=
getattr
(
args
,
"adaptive_input"
,
False
)
args
.
encoder_layerdrop
=
getattr
(
args
,
"encoder_layerdrop"
,
0.0
)
args
.
decoder_layerdrop
=
getattr
(
args
,
"decoder_layerdrop"
,
0.0
)
args
.
decoder_output_dim
=
getattr
(
args
,
"decoder_output_dim"
,
args
.
decoder_embed_dim
)
args
.
decoder_input_dim
=
getattr
(
args
,
"decoder_input_dim"
,
args
.
decoder_embed_dim
)
args
.
no_scale_embedding
=
getattr
(
args
,
"no_scale_embedding"
,
False
)
args
.
encoder_no_scale_embedding
=
getattr
(
args
,
"encoder_no_scale_embedding"
,
False
)
args
.
quant_noise_pq
=
getattr
(
args
,
"quant_noise_pq"
,
0
)
args
.
encoder_embed_linear
=
getattr
(
args
,
"encoder_embed_linear"
,
False
)
args
.
encoder_embed_norm
=
getattr
(
args
,
"encoder_embed_norm"
,
False
)
# CTC
args
.
ctc_layer
=
getattr
(
args
,
"ctc_layer"
,
0
)
args
.
share_ctc_and_embed
=
getattr
(
args
,
"share_ctc_and_embed"
,
False
)
# Conformer
args
.
encoder_activation_fn
=
getattr
(
args
,
"encoder_activation_fn"
,
"relu"
)
args
.
macaron_style
=
getattr
(
args
,
"macaron_style"
,
False
)
args
.
use_cnn_module
=
getattr
(
args
,
"use_cnn_module"
,
False
)
args
.
cnn_module_kernel
=
getattr
(
args
,
"cnn_module_kernel"
,
31
)
args
.
cnn_module_norm
=
getattr
(
args
,
"cnn_module_norm"
,
"batch_norm"
)
# settings for DLCL
args
.
use_enc_dlcl
=
getattr
(
args
,
"use_enc_dlcl"
,
False
)
args
.
use_dec_dlcl
=
getattr
(
args
,
"use_dec_dlcl"
,
False
)
args
.
init_value
=
getattr
(
args
,
'init_value'
,
'avg'
)
args
.
weight_type
=
getattr
(
args
,
'weight_type'
,
'scalar'
)
args
.
encoder_learnable
=
getattr
(
args
,
'encoder_learnable'
,
True
)
args
.
normalize_embed
=
getattr
(
args
,
'normalize_embed'
,
False
)
args
.
history_dropout
=
getattr
(
args
,
'history_dropout'
,
0.0
)
args
.
history_window_size
=
getattr
(
args
,
'history_window_size'
,
-
1
)
# Relative position encoding
args
.
max_encoder_relative_length
=
getattr
(
args
,
'max_encoder_relative_length'
,
-
1
)
args
.
k_only
=
getattr
(
args
,
'k_only'
,
True
)
# local modeling
args
.
hard_mask_window
=
getattr
(
args
,
'hard_mask_window'
,
0
)
args
.
gauss_mask_sigma
=
getattr
(
args
,
'gauss_mask_sigma'
,
0
)
args
.
init_mask_weight
=
getattr
(
args
,
'init_mask_weight'
,
0
)
# interleaved CTC
args
.
interleaved_ctc_layers
=
getattr
(
args
,
"interleaved_ctc_layers"
,
None
)
args
.
interleaved_ctc_temperature
=
getattr
(
args
,
"interleaved_ctc_temperature"
,
1
)
args
.
interleaved_ctc_drop_prob
=
getattr
(
args
,
"interleaved_ctc_drop_prob"
,
0
)
# Semantics-augmented Encoding (sae)
args
.
sae_adapter
=
getattr
(
args
,
"sae_adapter"
,
"none"
)
args
.
target_sae_adapter
=
getattr
(
args
,
"target_sae_adapter"
,
args
.
sae_adapter
)
args
.
share_sae_and_ctc
=
getattr
(
args
,
"share_sae_and_ctc"
,
False
)
args
.
share_target_sae_and_ctc
=
getattr
(
args
,
"share_target_sae_and_ctc"
,
False
)
args
.
sae_drop_prob
=
getattr
(
args
,
"sae_drop_prob"
,
0
)
args
.
sae_distribution_cutoff
=
getattr
(
args
,
"sae_distribution_cutoff"
,
None
)
args
.
sae_distribution_hard
=
getattr
(
args
,
"sae_distribution_hard"
,
False
)
args
.
sae_gumbel
=
getattr
(
args
,
"sae_gumbel"
,
False
)
# mixup
args
.
inter_mixup
=
getattr
(
args
,
"inter_mixup"
,
False
)
args
.
inter_mixup_layer
=
getattr
(
args
,
"inter_mixup_layer"
,
None
)
args
.
inter_mixup_beta
=
getattr
(
args
,
"inter_mixup_beta"
,
0.5
)
args
.
inter_mixup_prob
=
getattr
(
args
,
"inter_mixup_prob"
,
1
)
args
.
inter_mixup_ratio
=
getattr
(
args
,
"inter_mixup_ratio"
,
0.3
)
args
.
inter_mixup_keep_org
=
getattr
(
args
,
"inter_mixup_keep_org"
,
False
)
# PDS
args
.
pds_stages
=
getattr
(
args
,
"pds_stages"
,
None
)
args
.
pds_layers
=
getattr
(
args
,
"pds_layers"
,
None
)
args
.
pds_ratios
=
getattr
(
args
,
"pds_ratios"
,
None
)
args
.
pds_ds_method
=
getattr
(
args
,
"pds_ds_method"
,
"conv"
)
args
.
pds_embed_dims
=
getattr
(
args
,
"pds_embed_dims"
,
None
)
args
.
pds_embed_norm
=
getattr
(
args
,
"pds_embed_norm"
,
False
)
args
.
pds_position_embed
=
getattr
(
args
,
"pds_position_embed"
,
None
)
args
.
pds_attn_heads
=
getattr
(
args
,
"pds_attn_heads"
,
None
)
args
.
pds_ffn_ratios
=
getattr
(
args
,
"pds_ffn_ratios"
,
None
)
args
.
pds_cnn_kernel_sizes
=
getattr
(
args
,
"pds_cnn_kernel_sizes"
,
None
)
args
.
pds_attn_ds_ratios
=
getattr
(
args
,
"pds_attn_ds_ratios"
,
None
)
args
.
pds_conv_strides
=
getattr
(
args
,
"pds_conv_strides"
,
None
)
args
.
pds_attn_strides
=
getattr
(
args
,
"pds_attn_strides"
,
None
)
args
.
pds_dropout
=
getattr
(
args
,
"pds_dropout"
,
args
.
dropout
)
args
.
pds_fusion
=
getattr
(
args
,
"pds_fusion"
,
False
)
args
.
pds_fusion_method
=
getattr
(
args
,
"pds_fusion_method"
,
"all_conv"
)
# dual
args
.
junior_acoustic_encoder
=
getattr
(
args
,
"junior_acoustic_encoder"
,
"transformer"
)
args
.
senior_acoustic_encoder
=
getattr
(
args
,
"senior_acoustic_encoder"
,
"transformer"
)
args
.
textual_encoder
=
getattr
(
args
,
"textual_encoder"
,
"transformer"
)
args
.
textual_encoder_dim
=
getattr
(
args
,
"textual_encoder"
,
args
.
encoder_embed_dim
)
args
.
junior_acoustic_encoder_layers
=
getattr
(
args
,
"junior_acoustic_encoder_layers"
,
6
)
args
.
senior_acoustic_encoder_layers
=
getattr
(
args
,
"senior_acoustic_encoder_layers"
,
6
)
args
.
textual_encoder_layers
=
getattr
(
args
,
"textual_encoder_layers"
,
6
)
args
.
collaboration_direction
=
getattr
(
args
,
"collaboration_direction"
,
"none"
)
args
.
collaboration_step
=
getattr
(
args
,
"collaboration_step"
,
"1:1"
)
args
.
encoder_collaboration_mode
=
getattr
(
args
,
"encoder_collaboration_mode"
,
"serial"
)
args
.
decoder_collaboration_mode
=
getattr
(
args
,
"decoder_collaboration_mode"
,
"serial"
)
args
.
encoder_league_s1_ratio
=
getattr
(
args
,
"encoder_league_s1_ratio"
,
0.5
)
args
.
encoder_league_s2_ratio
=
getattr
(
args
,
"encoder_league_s2_ratio"
,
0.5
)
args
.
encoder_league_drop_net
=
getattr
(
args
,
"encoder_league_drop_net"
,
False
)
args
.
encoder_league_drop_net_prob
=
getattr
(
args
,
"encoder_league_drop_net_prob"
,
0.0
)
args
.
encoder_league_drop_net_mix
=
getattr
(
args
,
"encoder_league_drop_net_mix"
,
False
)
args
.
decoder_league_s1_ratio
=
getattr
(
args
,
"decoder_league_s1_ratio"
,
0.5
)
args
.
decoder_league_s2_ratio
=
getattr
(
args
,
"decoder_league_s2_ratio"
,
0.5
)
args
.
decoder_league_drop_net
=
getattr
(
args
,
"decoder_league_drop_net"
,
False
)
args
.
decoder_league_drop_net_prob
=
getattr
(
args
,
"decoder_league_drop_net_prob"
,
0.0
)
args
.
decoder_league_drop_net_mix
=
getattr
(
args
,
"decoder_league_drop_net_mix"
,
False
)
# args.encoder_asr_ratio = getattr(args, "encoder_asr_ratio", 1.0)
# args.encoder_mt_ratio = getattr(args, "encoder_mt_ratio", 1.0)
# args.encoder_drop_net = getattr(args, "encoder_drop_net", False)
# args.encoder_drop_net_prob = getattr(args, "encoder_drop_net_prob", 1.0)
# args.encoder_drop_net_mix = getattr(args, "encoder_drop_net_mix", False)
@register_model_architecture
(
"s2t_multibranch"
,
"s2t_multibranch_s"
)
def
s2t_multibranch_s
(
args
):
args
.
encoder_embed_dim
=
getattr
(
args
,
"encoder_embed_dim"
,
256
)
args
.
encoder_ffn_embed_dim
=
getattr
(
args
,
"encoder_ffn_embed_dim"
,
256
*
8
)
args
.
encoder_attention_heads
=
getattr
(
args
,
"encoder_attention_heads"
,
4
)
args
.
dropout
=
getattr
(
args
,
"dropout"
,
0.1
)
base_architecture
(
args
)
@register_model_architecture
(
"s2t_multibranch"
,
"s2t_multibranch_s_relative"
)
def
s2t_multibranch_s_relative
(
args
):
args
.
max_encoder_relative_length
=
100
args
.
k_only
=
True
s2t_multibranch_s
(
args
)
@register_model_architecture
(
"s2t_multibranch"
,
"s2t_multibranch_xs"
)
def
s2t_multibranch_xs
(
args
):
args
.
encoder_layers
=
getattr
(
args
,
"encoder_layers"
,
6
)
args
.
encoder_ffn_embed_dim
=
getattr
(
args
,
"encoder_ffn_embed_dim"
,
256
*
4
)
args
.
dropout
=
getattr
(
args
,
"dropout"
,
0.3
)
s2t_multibranch_s
(
args
)
@register_model_architecture
(
"s2t_multibranch"
,
"s2t_multibranch_sp"
)
def
s2t_multibranch_sp
(
args
):
args
.
encoder_layers
=
getattr
(
args
,
"encoder_layers"
,
16
)
s2t_multibranch_s
(
args
)
@register_model_architecture
(
"s2t_multibranch"
,
"s2t_multibranch_m"
)
def
s2t_multibranch_m
(
args
):
args
.
encoder_embed_dim
=
getattr
(
args
,
"encoder_embed_dim"
,
512
)
args
.
encoder_ffn_embed_dim
=
getattr
(
args
,
"encoder_ffn_embed_dim"
,
512
*
4
)
args
.
encoder_attention_heads
=
getattr
(
args
,
"encoder_attention_heads"
,
8
)
args
.
dropout
=
getattr
(
args
,
"dropout"
,
0.15
)
base_architecture
(
args
)
@register_model_architecture
(
"s2t_multibranch"
,
"s2t_multibranch_mp"
)
def
s2t_multibranch_mp
(
args
):
args
.
encoder_layers
=
getattr
(
args
,
"encoder_layers"
,
16
)
s2t_multibranch_m
(
args
)
@register_model_architecture
(
"s2t_multibranch"
,
"s2t_multibranch_l"
)
def
s2t_multibranch_l
(
args
):
args
.
encoder_embed_dim
=
getattr
(
args
,
"encoder_embed_dim"
,
1024
)
args
.
encoder_ffn_embed_dim
=
getattr
(
args
,
"encoder_ffn_embed_dim"
,
1024
*
4
)
args
.
encoder_attention_heads
=
getattr
(
args
,
"encoder_attention_heads"
,
16
)
args
.
dropout
=
getattr
(
args
,
"dropout"
,
0.2
)
base_architecture
(
args
)
@register_model_architecture
(
"s2t_multibranch"
,
"s2t_multibranch_lp"
)
def
s2t_multibranch_lp
(
args
):
args
.
encoder_layers
=
getattr
(
args
,
"encoder_layers"
,
16
)
s2t_multibranch_l
(
args
)
fairseq/models/transformer_s2.py
查看文件 @
47e0f6e0
...
@@ -343,8 +343,8 @@ class TransformerS2Decoder(TransformerDecoder):
...
@@ -343,8 +343,8 @@ class TransformerS2Decoder(TransformerDecoder):
and
len
(
encoder_out
[
"encoder_padding_mask"
])
>
0
and
len
(
encoder_out
[
"encoder_padding_mask"
])
>
0
)
)
else
None
,
else
None
,
encoder_out_s2
=
encoder_out
[
"
encoder_out_s2
"
][
0
],
encoder_out_s2
=
encoder_out
[
"
s2_encoder_out
"
][
0
],
encoder_padding_mask_s2
=
encoder_out
[
"
encoder_padding_mask_s2
"
][
0
],
encoder_padding_mask_s2
=
encoder_out
[
"
s2_encoder_padding_mask
"
][
0
],
incremental_state
=
incremental_state
,
incremental_state
=
incremental_state
,
self_attn_mask
=
self_attn_mask
,
self_attn_mask
=
self_attn_mask
,
self_attn_padding_mask
=
self_attn_padding_mask
,
self_attn_padding_mask
=
self_attn_padding_mask
,
...
...
fairseq/modules/__init__.py
查看文件 @
47e0f6e0
...
@@ -61,6 +61,7 @@ from .espnet_multihead_attention import (
...
@@ -61,6 +61,7 @@ from .espnet_multihead_attention import (
)
)
from
.convolution
import
ConvolutionModule
from
.convolution
import
ConvolutionModule
from
.s2t_transformer_layer
import
S2TTransformerEncoderLayer
from
.s2t_transformer_layer
import
S2TTransformerEncoderLayer
from
.s2t_transformer_s2_layer
import
S2TTransformerS2EncoderLayer
from
.pds_layer
import
PDSTransformerEncoderLayer
from
.pds_layer
import
PDSTransformerEncoderLayer
__all__
=
[
__all__
=
[
...
@@ -70,6 +71,7 @@ __all__ = [
...
@@ -70,6 +71,7 @@ __all__ = [
"BeamableMM"
,
"BeamableMM"
,
"CharacterTokenEmbedder"
,
"CharacterTokenEmbedder"
,
"S2TTransformerEncoderLayer"
,
"S2TTransformerEncoderLayer"
,
"S2TTransformerS2EncoderLayer"
,
"ConvolutionModule"
,
"ConvolutionModule"
,
"ConvTBC"
,
"ConvTBC"
,
"cross_entropy"
,
"cross_entropy"
,
...
...
fairseq/modules/s2t_transformer_s2_layer.py
0 → 100644
查看文件 @
47e0f6e0
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
from
typing
import
Optional
import
torch
import
torch.nn
as
nn
from
fairseq.modules
import
(
LayerNorm
,
MultiheadAttention
,
RelPositionMultiheadAttention
,
RelativeMultiheadAttention
,
ConvolutionModule
,
ESPNETMultiHeadedAttention
,
RelPositionMultiHeadedAttention
,
LegacyRelPositionMultiHeadedAttention
,
RotaryPositionMultiHeadedAttention
,
)
from
fairseq.modules.fairseq_dropout
import
FairseqDropout
from
torch
import
Tensor
from
fairseq.modules.activations
import
get_activation_class
class
FeedForwardModule
(
torch
.
nn
.
Module
):
"""Positionwise feed forward layer used in conformer"""
def
__init__
(
self
,
input_feat
,
hidden_units
,
dropout1
,
dropout2
,
activation_fn
=
"relu"
,
bias
=
True
,
):
"""
Args:
input_feat: Input feature dimension
hidden_units: Hidden unit dimension
dropout1: dropout value for layer1
dropout2: dropout value for layer2
activation_fn: Name of activation function
bias: If linear layers should have bias
"""
super
(
FeedForwardModule
,
self
)
.
__init__
()
self
.
w_1
=
torch
.
nn
.
Linear
(
input_feat
,
hidden_units
,
bias
=
bias
)
self
.
w_2
=
torch
.
nn
.
Linear
(
hidden_units
,
input_feat
,
bias
=
bias
)
self
.
dropout1
=
torch
.
nn
.
Dropout
(
dropout1
)
self
.
dropout2
=
torch
.
nn
.
Dropout
(
dropout2
)
self
.
activation
=
get_activation_class
(
activation_fn
)
def
forward
(
self
,
x
):
"""
Args:
x: Input Tensor of shape T X B X C
Returns:
Tensor of shape T X B X C
"""
x
=
self
.
w_1
(
x
)
x
=
self
.
activation
(
x
)
x
=
self
.
dropout1
(
x
)
x
=
self
.
w_2
(
x
)
return
self
.
dropout2
(
x
)
class
S2TTransformerS2EncoderLayer
(
nn
.
Module
):
"""Encoder layer block.
In the original paper each operation (multi-head attention or FFN) is
postprocessed with: `dropout -> add residual -> layernorm`. In the
tensor2tensor code they suggest that learning is more robust when
preprocessing each layer with layernorm and postprocessing with:
`dropout -> add residual`. We default to the approach in the paper, but the
tensor2tensor approach can be enabled by setting
*args.encoder_normalize_before* to ``True``.
Args:
args (argparse.Namespace): parsed command-line arguments
"""
def
__init__
(
self
,
args
):
super
()
.
__init__
()
self
.
args
=
args
embed_dim
=
args
.
encoder_embed_dim
ffn_dim
=
args
.
encoder_ffn_embed_dim
dropout
=
args
.
dropout
self
.
embed_dim
=
embed_dim
self
.
quant_noise
=
getattr
(
args
,
'quant_noise_pq'
,
0
)
self
.
quant_noise_block_size
=
getattr
(
args
,
'quant_noise_pq_block_size'
,
8
)
or
8
self
.
attn_type
=
getattr
(
args
,
"encoder_attention_type"
,
"selfattn"
)
self
.
self_attn
=
self
.
build_self_attention
(
args
,
self
.
embed_dim
)
self
.
self_attn_layer_norm
=
LayerNorm
(
self
.
embed_dim
)
self
.
dropout_module
=
FairseqDropout
(
dropout
,
module_name
=
self
.
__class__
.
__name__
)
self
.
normalize_before
=
args
.
encoder_normalize_before
activation
=
getattr
(
args
,
'encoder_activation_fn'
,
'relu'
)
if
args
.
macaron_style
:
self
.
macaron_ffn
=
FeedForwardModule
(
embed_dim
,
ffn_dim
,
dropout
,
dropout
,
activation
)
self
.
macaron_norm
=
LayerNorm
(
embed_dim
)
self
.
ffn_scale
=
0.5
else
:
self
.
macaron_ffn
=
None
self
.
macaron_norm
=
None
self
.
ffn_scale
=
1.0
if
args
.
use_cnn_module
:
self
.
conv_norm
=
LayerNorm
(
embed_dim
)
self
.
conv_module
=
ConvolutionModule
(
self
.
embed_dim
,
self
.
embed_dim
,
depthwise_kernel_size
=
args
.
cnn_module_kernel
,
dropout
=
args
.
dropout
,
activation_fn
=
getattr
(
args
,
'activation_fn'
,
'swish'
),
norm_type
=
args
.
cnn_module_norm
)
self
.
final_norm
=
LayerNorm
(
embed_dim
)
else
:
self
.
conv_norm
=
None
self
.
conv_module
=
None
self
.
final_norm
=
None
self
.
ffn
=
FeedForwardModule
(
embed_dim
,
ffn_dim
,
dropout
,
dropout
,
activation
)
self
.
ffn_norm
=
LayerNorm
(
self
.
embed_dim
)
self
.
s2_norm
=
LayerNorm
(
self
.
embed_dim
)
self
.
s2_attn_norm
=
LayerNorm
(
self
.
embed_dim
)
self
.
s2_attn
=
MultiheadAttention
(
self
.
embed_dim
,
args
.
encoder_attention_heads
,
kdim
=
getattr
(
args
,
"s2_encoder_embed_dim"
,
self
.
embed_dim
),
vdim
=
getattr
(
args
,
"s2_encoder_embed_dim"
,
self
.
embed_dim
),
dropout
=
args
.
attention_dropout
,
self_attention
=
False
,
)
def
build_self_attention
(
self
,
args
,
embed_dim
):
attention_heads
=
args
.
encoder_attention_heads
dropout
=
args
.
dropout
if
self
.
attn_type
==
"selfattn"
:
attn_func
=
MultiheadAttention
elif
self
.
attn_type
==
"rel_selfattn"
:
attn_func
=
RelPositionMultiheadAttention
elif
self
.
attn_type
==
"relative"
:
max_relative_length
=
max
(
getattr
(
args
,
"max_encoder_relative_length"
,
-
1
),
getattr
(
args
,
"max_relative_length"
,
-
1
))
if
max_relative_length
!=
-
1
:
return
RelativeMultiheadAttention
(
embed_dim
,
attention_heads
,
dropout
=
dropout
,
self_attention
=
True
,
q_noise
=
self
.
quant_noise
,
qn_block_size
=
self
.
quant_noise_block_size
,
max_relative_length
=
max_relative_length
,
)
else
:
print
(
"The maximum encoder relative length
%
d can not be -1!"
%
max_relative_length
)
exit
(
1
)
elif
self
.
attn_type
==
"rel_pos"
:
return
RelPositionMultiHeadedAttention
(
embed_dim
,
attention_heads
,
dropout
=
dropout
,
)
elif
self
.
attn_type
==
"rel_pos_legacy"
:
return
LegacyRelPositionMultiHeadedAttention
(
embed_dim
,
attention_heads
,
dropout
=
dropout
,
)
elif
self
.
attn_type
==
"rope"
:
return
RotaryPositionMultiHeadedAttention
(
embed_dim
,
attention_heads
,
dropout
=
dropout
,
precision
=
args
.
fp16
)
elif
self
.
attn_type
==
"abs"
:
return
ESPNETMultiHeadedAttention
(
embed_dim
,
attention_heads
,
dropout
=
dropout
,
)
else
:
attn_func
=
MultiheadAttention
print
(
"The encoder attention type
%
s is not supported!"
%
self
.
attn_type
)
exit
(
1
)
return
attn_func
(
embed_dim
,
attention_heads
,
dropout
=
dropout
,
self_attention
=
True
,
q_noise
=
self
.
quant_noise
,
qn_block_size
=
self
.
quant_noise_block_size
,
)
def
residual_connection
(
self
,
x
,
residual
):
return
residual
+
x
def
upgrade_state_dict_named
(
self
,
state_dict
,
name
):
"""
Rename layer norm states from `...layer_norms.0.weight` to
`...self_attn_layer_norm.weight` and `...layer_norms.1.weight` to
`...final_layer_norm.weight`
"""
layer_norm_map
=
{
"0"
:
"self_attn_layer_norm"
,
"1"
:
"final_layer_norm"
}
for
old
,
new
in
layer_norm_map
.
items
():
for
m
in
(
"weight"
,
"bias"
):
k
=
"{}.layer_norms.{}.{}"
.
format
(
name
,
old
,
m
)
if
k
in
state_dict
:
state_dict
[
"{}.{}.{}"
.
format
(
name
,
new
,
m
)]
=
state_dict
[
k
]
del
state_dict
[
k
]
def
forward
(
self
,
x
,
encoder_padding_mask
:
Optional
[
Tensor
],
s2
=
None
,
s2_encoder_padding_mask
=
None
,
attn_mask
:
Optional
[
Tensor
]
=
None
,
pos_emb
:
Optional
[
Tensor
]
=
None
):
"""
Args:
x (Tensor): input to the layer of shape `(seq_len, batch, embed_dim)`
encoder_padding_mask (ByteTensor): binary ByteTensor of shape
`(batch, seq_len)` where padding elements are indicated by ``1``.
attn_mask (ByteTensor): binary tensor of shape `(tgt_len, src_len)`,
where `tgt_len` is the length of output and `src_len` is the
length of input, though here both are equal to `seq_len`.
`attn_mask[tgt_i, src_j] = 1` means that when calculating the
embedding for `tgt_i`, we exclude (mask out) `src_j`. This is
useful for strided self-attention.
pos_emb (Tensor): the position embedding for relative position encoding
Returns:
encoded output of shape `(seq_len, batch, embed_dim)`
"""
# anything in original attn_mask = 1, becomes -1e8
# anything in original attn_mask = 0, becomes 0
# Note that we cannot use -inf here, because at some edge cases,
# the attention weight (before softmax) for some padded element in query
# will become -inf, which results in NaN in model parameters
if
attn_mask
is
not
None
:
attn_mask
=
attn_mask
.
masked_fill
(
attn_mask
.
to
(
torch
.
bool
),
-
1e8
)
# whether to use macaron style
if
self
.
macaron_norm
is
not
None
:
residual
=
x
if
self
.
normalize_before
:
x
=
self
.
macaron_norm
(
x
)
x
=
self
.
macaron_ffn
(
x
)
x
=
residual
+
self
.
ffn_scale
*
x
if
not
self
.
normalize_before
:
x
=
self
.
macaron_norm
(
x
)
residual
=
x
if
self
.
normalize_before
:
x
=
self
.
self_attn_layer_norm
(
x
)
if
self
.
attn_type
in
[
"rel_pos"
,
"rel_pos_legacy"
,
"rel_selfattn"
]:
assert
pos_emb
is
not
None
,
"Positions is necessary for RPE!"
x
,
_
=
self
.
self_attn
(
query
=
x
,
key
=
x
,
value
=
x
,
key_padding_mask
=
encoder_padding_mask
,
need_weights
=
False
,
attn_mask
=
attn_mask
,
pos_emb
=
pos_emb
)
else
:
x
,
_
=
self
.
self_attn
(
query
=
x
,
key
=
x
,
value
=
x
,
key_padding_mask
=
encoder_padding_mask
,
need_weights
=
False
,
attn_mask
=
attn_mask
,
)
x
=
self
.
dropout_module
(
x
)
x
=
self
.
residual_connection
(
x
,
residual
)
if
not
self
.
normalize_before
:
x
=
self
.
self_attn_layer_norm
(
x
)
if
s2
is
not
None
:
residual
=
x
x
=
self
.
s2_attn_norm
(
x
)
s2
=
self
.
s2_norm
(
s2
)
x
,
_
=
self
.
self_attn
(
query
=
x
,
key
=
s2
,
value
=
s2
,
key_padding_mask
=
s2_encoder_padding_mask
,
need_weights
=
False
,
)
x
=
self
.
dropout_module
(
x
)
x
=
self
.
residual_connection
(
x
,
residual
)
# convolution module
if
self
.
conv_module
is
not
None
:
residual
=
x
x
=
x
.
transpose
(
0
,
1
)
if
self
.
normalize_before
:
x
=
self
.
conv_norm
(
x
)
x
=
self
.
conv_module
(
x
)
x
=
x
.
transpose
(
0
,
1
)
x
=
residual
+
x
if
not
self
.
normalize_before
:
x
=
self
.
conv_norm
(
x
)
residual
=
x
if
self
.
normalize_before
:
x
=
self
.
ffn_norm
(
x
)
x
=
self
.
ffn
(
x
)
x
=
self
.
residual_connection
(
self
.
ffn_scale
*
x
,
residual
)
if
not
self
.
normalize_before
:
x
=
self
.
ffn_norm
(
x
)
if
self
.
conv_module
is
not
None
:
x
=
self
.
final_norm
(
x
)
return
x
fairseq/modules/transformer_s2_layer.py
查看文件 @
47e0f6e0
...
@@ -79,6 +79,8 @@ class TransformerS2EncoderLayer(nn.Module):
...
@@ -79,6 +79,8 @@ class TransformerS2EncoderLayer(nn.Module):
if
self
.
use_se
:
if
self
.
use_se
:
self
.
se_attn
=
SEAttention
(
self
.
embed_dim
,
16
)
self
.
se_attn
=
SEAttention
(
self
.
embed_dim
,
16
)
self
.
s2_norm
=
LayerNorm
(
self
.
embed_dim
)
self
.
s2_attn_norm
=
LayerNorm
(
self
.
embed_dim
)
self
.
s2_attn
=
MultiheadAttention
(
self
.
s2_attn
=
MultiheadAttention
(
self
.
embed_dim
,
self
.
embed_dim
,
args
.
encoder_attention_heads
,
args
.
encoder_attention_heads
,
...
@@ -87,26 +89,28 @@ class TransformerS2EncoderLayer(nn.Module):
...
@@ -87,26 +89,28 @@ class TransformerS2EncoderLayer(nn.Module):
dropout
=
args
.
attention_dropout
,
dropout
=
args
.
attention_dropout
,
self_attention
=
False
,
self_attention
=
False
,
)
)
self
.
s1_ratio
=
args
.
encoder_s1_ratio
self
.
s2_ratio
=
args
.
encoder_s2_ratio
self
.
encoder_collaboration_mode
=
args
.
encoder_collaboration_mode
self
.
league_s1_ratio
=
args
.
encoder_league_s1_ratio
self
.
league_s2_ratio
=
args
.
encoder_league_s2_ratio
self
.
drop_net
=
args
.
encoder
_drop_net
self
.
league_drop_net
=
args
.
encoder_league
_drop_net
self
.
drop_net_prob
=
args
.
encoder
_drop_net_prob
self
.
league_drop_net_prob
=
args
.
encoder_league
_drop_net_prob
self
.
drop_net_mix
=
args
.
encoder
_drop_net_mix
self
.
league_drop_net_mix
=
args
.
encoder_league
_drop_net_mix
def
get_ratio
(
self
):
def
get_ratio
(
self
):
if
self
.
drop_net
:
if
self
.
league_
drop_net
:
frand
=
float
(
uniform
(
0
,
1
))
frand
=
float
(
uniform
(
0
,
1
))
if
self
.
drop_net_mix
and
self
.
training
:
if
self
.
drop_net_mix
and
self
.
training
:
return
[
frand
,
1
-
frand
]
return
[
frand
,
1
-
frand
]
if
frand
<
self
.
drop_net_prob
and
self
.
training
:
if
frand
<
self
.
league_
drop_net_prob
and
self
.
training
:
return
[
1
,
0
]
return
[
1
,
0
]
elif
frand
>
1
-
self
.
drop_net_prob
and
self
.
training
:
elif
frand
>
1
-
self
.
league_
drop_net_prob
and
self
.
training
:
return
[
0
,
1
]
return
[
0
,
1
]
else
:
else
:
return
[
0.5
,
0.5
]
return
[
0.5
,
0.5
]
else
:
else
:
return
[
self
.
s1_ratio
,
self
.
s2_ratio
]
return
[
self
.
league_s1_ratio
,
self
.
league_
s2_ratio
]
def
build_fc1
(
self
,
input_dim
,
output_dim
,
q_noise
,
qn_block_size
):
def
build_fc1
(
self
,
input_dim
,
output_dim
,
q_noise
,
qn_block_size
):
return
quant_noise
(
return
quant_noise
(
...
@@ -186,8 +190,8 @@ class TransformerS2EncoderLayer(nn.Module):
...
@@ -186,8 +190,8 @@ class TransformerS2EncoderLayer(nn.Module):
def
forward
(
self
,
x
,
def
forward
(
self
,
x
,
encoder_padding_mask
:
Optional
[
Tensor
],
encoder_padding_mask
:
Optional
[
Tensor
],
x
2
=
None
,
s
2
=
None
,
x
2_encoder_padding_mask
=
None
,
s
2_encoder_padding_mask
=
None
,
attn_mask
:
Optional
[
Tensor
]
=
None
,
attn_mask
:
Optional
[
Tensor
]
=
None
,
pos_emb
:
Optional
[
Tensor
]
=
None
):
pos_emb
:
Optional
[
Tensor
]
=
None
):
"""
"""
...
@@ -219,6 +223,7 @@ class TransformerS2EncoderLayer(nn.Module):
...
@@ -219,6 +223,7 @@ class TransformerS2EncoderLayer(nn.Module):
residual
=
x
residual
=
x
if
self
.
normalize_before
:
if
self
.
normalize_before
:
x
=
self
.
self_attn_layer_norm
(
x
)
x
=
self
.
self_attn_layer_norm
(
x
)
attn_x
=
x
if
self
.
attn_type
==
"rel_selfattn"
:
if
self
.
attn_type
==
"rel_selfattn"
:
assert
pos_emb
is
not
None
,
"Positions is necessary for RPE!"
assert
pos_emb
is
not
None
,
"Positions is necessary for RPE!"
x
,
_
=
self
.
self_attn
(
x
,
_
=
self
.
self_attn
(
...
@@ -240,20 +245,34 @@ class TransformerS2EncoderLayer(nn.Module):
...
@@ -240,20 +245,34 @@ class TransformerS2EncoderLayer(nn.Module):
attn_mask
=
attn_mask
,
attn_mask
=
attn_mask
,
)
)
x
=
self
.
dropout_module
(
x
)
x
=
self
.
dropout_module
(
x
)
if
s2
is
None
or
self
.
encoder_collaboration_mode
!=
"parallel"
:
if
x2
is
not
None
:
x
=
self
.
residual_connection
(
x
,
residual
)
x2
,
_
=
self
.
s2_attn
(
if
not
self
.
normalize_before
:
query
=
x
,
x
=
self
.
self_attn_layer_norm
(
x
)
key
=
x2
,
value
=
x2
,
if
s2
is
not
None
:
key_padding_mask
=
x2_encoder_padding_mask
)
s2
=
self
.
s2_norm
(
s2
)
x2
=
self
.
dropout_module
(
x2
)
if
self
.
encoder_collaboration_mode
==
"serial"
:
ratio
=
self
.
get_ratio
()
residual
=
x
x
=
x
*
ratio
[
0
]
+
x2
*
ratio
[
1
]
x
=
self
.
s2_attn_norm
(
x
)
x
,
_
=
self
.
s2_attn
(
x
=
self
.
residual_connection
(
x
,
residual
)
query
=
x
,
if
not
self
.
normalize_before
:
key
=
s2
,
x
=
self
.
self_attn_layer_norm
(
x
)
value
=
s2
,
key_padding_mask
=
s2_encoder_padding_mask
)
x
=
self
.
dropout_module
(
x
)
x
=
self
.
residual_connection
(
x
,
residual
)
elif
self
.
encoder_collaboration_mode
==
"parallel"
:
x2
,
_
=
self
.
s2_attn
(
query
=
attn_x
,
key
=
s2
,
value
=
s2
,
key_padding_mask
=
s2_encoder_padding_mask
)
x2
=
self
.
dropout_module
(
x2
)
ratio
=
self
.
get_ratio
()
x
=
x
*
ratio
[
0
]
+
x2
*
ratio
[
1
]
x
=
self
.
residual_connection
(
x
,
residual
)
residual
=
x
residual
=
x
if
self
.
normalize_before
:
if
self
.
normalize_before
:
...
@@ -341,11 +360,12 @@ class TransformerS2DecoderLayer(nn.Module):
...
@@ -341,11 +360,12 @@ class TransformerS2DecoderLayer(nn.Module):
self
.
s2_attn
=
MultiheadAttention
(
self
.
s2_attn
=
MultiheadAttention
(
self
.
embed_dim
,
self
.
embed_dim
,
args
.
decoder_attention_heads
,
args
.
decoder_attention_heads
,
kdim
=
getattr
(
args
,
"encoder_
x
2_dim"
,
self
.
embed_dim
),
kdim
=
getattr
(
args
,
"encoder_
s
2_dim"
,
self
.
embed_dim
),
vdim
=
getattr
(
args
,
"encoder_
x
2_dim"
,
self
.
embed_dim
),
vdim
=
getattr
(
args
,
"encoder_
s
2_dim"
,
self
.
embed_dim
),
dropout
=
args
.
attention_dropout
,
dropout
=
args
.
attention_dropout
,
encoder_decoder_attention
=
True
,
encoder_decoder_attention
=
True
,
)
)
self
.
s2_attn_layer_norm
=
LayerNorm
(
self
.
embed_dim
)
self
.
fc1
=
self
.
build_fc1
(
self
.
fc1
=
self
.
build_fc1
(
self
.
embed_dim
,
self
.
embed_dim
,
...
@@ -365,26 +385,27 @@ class TransformerS2DecoderLayer(nn.Module):
...
@@ -365,26 +385,27 @@ class TransformerS2DecoderLayer(nn.Module):
self
.
onnx_trace
=
False
self
.
onnx_trace
=
False
self
.
s1_ratio
=
args
.
encoder_s1_ratio
self
.
decoder_collaboration_mode
=
args
.
decoder_collaboration_mode
self
.
s2_ratio
=
args
.
encoder_s2_ratio
self
.
league_s1_ratio
=
args
.
decoder_league_s1_ratio
self
.
league_s2_ratio
=
args
.
decoder_league_s2_ratio
self
.
drop_net
=
args
.
encoder
_drop_net
self
.
league_drop_net
=
args
.
decoder_league
_drop_net
self
.
drop_net_prob
=
args
.
encoder
_drop_net_prob
self
.
league_drop_net_prob
=
args
.
decoder_league
_drop_net_prob
self
.
drop_net_mix
=
args
.
encoder
_drop_net_mix
self
.
league_drop_net_mix
=
args
.
decoder_league
_drop_net_mix
def
get_ratio
(
self
):
def
get_ratio
(
self
):
if
self
.
drop_net
:
if
self
.
league_
drop_net
:
frand
=
float
(
uniform
(
0
,
1
))
frand
=
float
(
uniform
(
0
,
1
))
if
self
.
drop_net_mix
and
self
.
training
:
if
self
.
drop_net_mix
and
self
.
training
:
return
[
frand
,
1
-
frand
]
return
[
frand
,
1
-
frand
]
if
frand
<
self
.
drop_net_prob
and
self
.
training
:
if
frand
<
self
.
league_
drop_net_prob
and
self
.
training
:
return
[
1
,
0
]
return
[
1
,
0
]
elif
frand
>
1
-
self
.
drop_net_prob
and
self
.
training
:
elif
frand
>
1
-
self
.
league_
drop_net_prob
and
self
.
training
:
return
[
0
,
1
]
return
[
0
,
1
]
else
:
else
:
return
[
0.5
,
0.5
]
return
[
0.5
,
0.5
]
else
:
else
:
return
[
self
.
s1_ratio
,
self
.
s2_ratio
]
return
[
self
.
league_s1_ratio
,
self
.
league_
s2_ratio
]
def
build_fc1
(
self
,
input_dim
,
output_dim
,
q_noise
,
qn_block_size
):
def
build_fc1
(
self
,
input_dim
,
output_dim
,
q_noise
,
qn_block_size
):
return
quant_noise
(
nn
.
Linear
(
input_dim
,
output_dim
),
q_noise
,
qn_block_size
)
return
quant_noise
(
nn
.
Linear
(
input_dim
,
output_dim
),
q_noise
,
qn_block_size
)
...
@@ -551,6 +572,8 @@ class TransformerS2DecoderLayer(nn.Module):
...
@@ -551,6 +572,8 @@ class TransformerS2DecoderLayer(nn.Module):
residual
=
x
residual
=
x
if
self
.
normalize_before
:
if
self
.
normalize_before
:
x
=
self
.
encoder_attn_layer_norm
(
x
)
x
=
self
.
encoder_attn_layer_norm
(
x
)
cross_attn_x
=
x
if
prev_attn_state
is
not
None
:
if
prev_attn_state
is
not
None
:
prev_key
,
prev_value
=
prev_attn_state
[:
2
]
prev_key
,
prev_value
=
prev_attn_state
[:
2
]
saved_state
:
Dict
[
str
,
Optional
[
Tensor
]]
=
{
saved_state
:
Dict
[
str
,
Optional
[
Tensor
]]
=
{
...
@@ -575,25 +598,45 @@ class TransformerS2DecoderLayer(nn.Module):
...
@@ -575,25 +598,45 @@ class TransformerS2DecoderLayer(nn.Module):
need_head_weights
=
need_head_weights
,
need_head_weights
=
need_head_weights
,
)
)
x
=
self
.
dropout_module
(
x
)
x
=
self
.
dropout_module
(
x
)
if
encoder_out_s2
is
None
or
self
.
decoder_collaboration_mode
!=
"parallel"
:
x
=
self
.
residual_connection
(
x
,
residual
)
if
not
self
.
normalize_before
:
x
=
self
.
encoder_attn_layer_norm
(
x
)
if
encoder_out_s2
is
not
None
:
if
encoder_out_s2
is
not
None
:
x2
,
_
=
self
.
s2_attn
(
if
self
.
decoder_collaboration_mode
==
"serial"
:
query
=
x
,
residual
=
x
key
=
encoder_out_s2
,
x
=
self
.
s2_attn_layer_norm
(
x
)
value
=
encoder_out_s2
,
x
,
_
=
self
.
s2_attn
(
key_padding_mask
=
encoder_padding_mask_s2
,
query
=
x
,
incremental_state
=
incremental_state
,
key
=
encoder_out_s2
,
static_kv
=
True
,
value
=
encoder_out_s2
,
need_weights
=
need_attn
or
(
not
self
.
training
and
self
.
need_attn
),
key_padding_mask
=
encoder_padding_mask_s2
,
need_head_weights
=
need_head_weights
,
incremental_state
=
incremental_state
,
)
static_kv
=
True
,
x2
=
self
.
dropout_module
(
x2
)
need_weights
=
need_attn
or
(
not
self
.
training
and
self
.
need_attn
),
ratios
=
self
.
get_ratio
()
need_head_weights
=
need_head_weights
,
x
=
ratios
[
0
]
*
x
+
ratios
[
1
]
*
x2
)
x
=
self
.
dropout_module
(
x
)
x
=
self
.
residual_connection
(
x
,
residual
)
x
=
self
.
residual_connection
(
x
,
residual
)
if
not
self
.
normalize_before
:
elif
self
.
decoder_collaboration_mode
==
"parallel"
:
x
=
self
.
encoder_attn_layer_norm
(
x
)
x2
,
_
=
self
.
s2_attn
(
query
=
cross_attn_x
,
key
=
encoder_out_s2
,
value
=
encoder_out_s2
,
key_padding_mask
=
encoder_padding_mask_s2
,
incremental_state
=
incremental_state
,
static_kv
=
True
,
need_weights
=
need_attn
or
(
not
self
.
training
and
self
.
need_attn
),
need_head_weights
=
need_head_weights
,
)
x2
=
self
.
dropout_module
(
x2
)
ratios
=
self
.
get_ratio
()
x
=
ratios
[
0
]
*
x
+
ratios
[
1
]
*
x2
x
=
x
+
x2
x
=
self
.
residual_connection
(
x
,
residual
)
if
not
self
.
normalize_before
:
x
=
self
.
encoder_attn_layer_norm
(
x
)
residual
=
x
residual
=
x
if
self
.
normalize_before
:
if
self
.
normalize_before
:
...
...
编写
预览
Markdown
格式
0%
重试
或
添加新文件
添加附件
取消
您添加了
0
人
到此讨论。请谨慎行事。
请先完成此评论的编辑!
取消
请
注册
或者
登录
后发表评论