xuchen / Fairseq-S2T / Commits / 55702466

Commit 55702466, authored Feb 24, 2022 by xuchen
add the pure ctc arch for pds method
parent bab6c520
Showing 10 changed files with 513 additions and 632 deletions.
egs/libri_trans/asr/conf/debug.yaml                    +20   -28
egs/mustc/asr/conf/purectc_pds_base_8.yaml             +43    -0
examples/speech_to_text/prep_audio_data.py              +7    -2
fairseq/criterions/ctc.py                               +1    -3
fairseq/models/__init__.py                              +2    -0
fairseq/models/speech_to_text/pdss2t_transformer.py    +67   -28
fairseq/models/speech_to_text/s2t_ctc.py              +269  -166
fairseq/models/speech_to_text/s2t_transformer.py        +6    -5
fairseq/modules/pds_layer.py                           +90  -393
fairseq/modules/s2t_transformer_layer.py                +8    -7
egs/libri_trans/asr/conf/debug.yaml
-arch: s2t_transformer_s
+arch: s2t_ctc
+encoder-type: pds
+#arch: pdss2t_transformer_s_8
+encoder-embed-dim: 256
+pds-stages: 4
+ctc-layer: 12
+pds-layers: 3_3_3_3
+pds-ratios: 2_2_1_2
+pds-fusion: True
+pds-fusion-method: all_conv
+pds-embed-dims: 256_256_256_256
+pds-ds-method: conv
+pds-embed-norm: True
+pds-position-embed: 1_1_1_1
+pds-kernel-sizes: 5_5_5_5
+pds-ffn-ratios: 8_8_8_8
+pds-attn-heads: 4_4_4_4
 share-decoder-input-output-embed: True

 optimizer: adam
 clip-norm: 10.0
 lr-scheduler: inverse_sqrt
 warmup-init-lr: 1e-7
 warmup-updates: 10000
-weight-decay: 1e-6
 lr: 2e-3
 adam_betas: (0.9,0.98)

-criterion: label_smoothed_cross_entropy_with_ctc
-label_smoothing: 0.1
+criterion: ctc

-subsampling-type: conv1d
-subsmapling-layers: 2
-subsampling-filter: 1024
-subsampling-kernel: 5
-subsampling-stride: 2
-subsampling-norm: none
-subsampling-activation: glu

-ctc-weight: 0.2
-intermedia-ctc-layers: 6,9
-intermedia-adapter: league
-intermedia-ctc-weight: 0.1
-intermedia-drop-prob: 0.5
-ctc-self-distill-weight: 0
-post-process: sentencepiece

 dropout: 0.1
 activation-fn: relu
-encoder-embed-dim: 256
 encoder-ffn-embed-dim: 2048
 encoder-layers: 12
 decoder-layers: 6

@@ -40,8 +38,3 @@ encoder-attention-heads: 4
 decoder-embed-dim: 256
 decoder-ffn-embed-dim: 2048
 decoder-attention-heads: 4
-attention-dropout: 0.1
-activation-dropout: 0.1
-#load-pretrained-encoder-from:
-#load-pretrained-decoder-from:
\ No newline at end of file
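How a config like this is consumed is not shown in the commit itself. As a rough, hypothetical sketch (assuming the usual pattern of the egs/ run scripts, which may differ in detail), each "key: value" pair is turned into a "--key value" option for fairseq-train, with booleans becoming bare flags:

    # Hypothetical helper, not part of this commit; assumes PyYAML is installed.
    import yaml

    def config_to_args(path):
        with open(path) as f:
            cfg = yaml.safe_load(f)
        args = []
        for key, value in cfg.items():
            if isinstance(value, bool):
                if value:
                    args.append(f"--{key}")       # e.g. pds-fusion: True -> --pds-fusion
            else:
                args.extend([f"--{key}", str(value)])
        return args

    # {"arch": "s2t_ctc", "criterion": "ctc", "pds-fusion": True}
    # -> ["--arch", "s2t_ctc", "--criterion", "ctc", "--pds-fusion"]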
egs/mustc/asr/conf/purectc_pds_base_8.yaml
0 → 100644
arch: s2t_ctc
encoder-type: pds
encoder-embed-dim: 256

pds-stages: 4
ctc-layer: 12
pds-layers: 3_3_3_3
pds-ratios: 2_2_1_2
pds-fusion: True
pds-fusion-method: all_conv
pds-embed-dims: 256_256_256_256
pds-ds-method: conv
pds-embed-norm: True
pds-position-embed: 1_1_1_1
pds-kernel-sizes: 5_5_5_5
pds-ffn-ratios: 8_8_8_8
pds-attn-heads: 4_4_4_4

share-decoder-input-output-embed: True
optimizer: adam
clip-norm: 10.0
lr-scheduler: inverse_sqrt
warmup-init-lr: 1e-7
warmup-updates: 10000
lr: 2e-3
adam_betas: (0.9,0.98)
criterion: ctc

dropout: 0.1
activation-fn: relu
encoder-ffn-embed-dim: 2048
encoder-layers: 12
decoder-layers: 6
encoder-attention-heads: 4

decoder-embed-dim: 256
decoder-ffn-embed-dim: 2048
decoder-attention-heads: 4
#load-pretrained-encoder-from:
#load-pretrained-decoder-from:
\ No newline at end of file
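A note on the underscore-separated values above (pds-layers, pds-ratios, pds-kernel-sizes, ...): the PDS encoder splits them into one value per stage, exactly as in the parsing shown later in this commit (e.g. [int(n) for n in args.pds_kernel_sizes.split("_")]). A minimal sketch with an illustrative helper name:

    def split_stage_setting(value):
        """Turn "3_3_3_3" into [3, 3, 3, 3] -- one entry per PDS stage."""
        return [int(n) for n in value.split("_")]

    assert split_stage_setting("3_3_3_3") == [3, 3, 3, 3]   # pds-layers
    assert split_stage_setting("2_2_1_2") == [2, 2, 1, 2]   # pds-ratios: 2*2*1*2 = 8x down-sampling overall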
examples/speech_to_text/prep_audio_data.py
@@ -137,7 +137,8 @@ class AudioDataset(Dataset):
             for i, segment in enumerate(seg_group):
                 offset = int(float(segment["offset"]) * sample_rate)
                 n_frames = int(float(segment["duration"]) * sample_rate)
-                _id = f"{split}_{wav_path.stem}_{i}"
+                # _id = f"{split}_{wav_path.stem}_{i}"
+                _id = f"{wav_path.stem}_{i}"
                 item = dict()
                 item["audio"] = wav_path.as_posix()

@@ -263,8 +264,12 @@ def process(args):
         utt_id = item['id']
         features_path = (feature_root / f"{utt_id}.npy").as_posix()
+        tag_features_path = (feature_root / f"{split}_{utt_id}.npy").as_posix()
-        if os.path.exists(features_path):
+        if os.path.exists(tag_features_path):
             continue
+        if os.path.exists(features_path) and not os.path.exists(tag_features_path):
+            shutil.move(features_path, tag_features_path)
+            continue

         waveform, sample_rate, _ = dataset.get(idx, need_waveform=True)
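A standalone sketch of the cache-migration logic in the second hunk (the function name is illustrative, not from the repo): features were previously cached as "<utt_id>.npy"; the commit switches to "<split>_<utt_id>.npy" and moves an existing file under the old name instead of re-extracting it.

    import os
    import shutil
    from pathlib import Path

    def feature_is_cached(feature_root: Path, split: str, utt_id: str) -> bool:
        """Return True if the feature file already exists, renaming old-style files on the way."""
        old_path = (feature_root / f"{utt_id}.npy").as_posix()
        new_path = (feature_root / f"{split}_{utt_id}.npy").as_posix()
        if os.path.exists(new_path):
            return True
        if os.path.exists(old_path):
            shutil.move(old_path, new_path)
            return True
        return False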
fairseq/criterions/ctc.py
@@ -20,8 +20,6 @@ from fairseq.tasks import FairseqTask
 from fairseq.logging.meters import safe_round


 @dataclass
 class CtcCriterionConfig(FairseqDataclass):
     zero_infinity: bool = field(

@@ -30,7 +28,7 @@ class CtcCriterionConfig(FairseqDataclass):
     )
     sentence_avg: bool = II("optimization.sentence_avg")
     post_process: str = field(
-        default="letter",
+        default="sentencepiece",
         metadata={
             "help": "how to post process predictions into words. can be letter, "
             "wordpiece, BPE symbols, etc. "
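For context on the default change: post_process controls how token-level CTC hypotheses are mapped back to words (e.g. when the criterion reports WER). A simplified re-implementation of the two relevant modes, for illustration only -- this is not fairseq's own post_process code:

    def post_process_letter(sent):
        # letter-level output uses "|" as the word boundary
        return sent.replace(" ", "").replace("|", " ").strip()

    def post_process_sentencepiece(sent):
        # sentencepiece output marks word starts with "\u2581"
        return sent.replace(" ", "").replace("\u2581", " ").strip()

    print(post_process_letter("H E L L O | W O R L D |"))          # HELLO WORLD
    print(post_process_sentencepiece("\u2581hel lo \u2581wor ld"))  # hello world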
fairseq/models/__init__.py
@@ -47,6 +47,8 @@ __all__ = [
     "FairseqLanguageModel",
     "FairseqModel",
     "FairseqMultiModel",
+    "register_model",
+    "register_model_architecture"
 ]
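The two names added to __all__ are fairseq's model-registration decorators. A minimal usage sketch (toy names, not from this commit) of how an architecture such as s2t_ctc gets wired into the --arch option:

    from fairseq.models import FairseqEncoderModel, register_model, register_model_architecture

    @register_model("my_toy_ctc")
    class MyToyCTCModel(FairseqEncoderModel):
        @classmethod
        def build_model(cls, args, task):
            ...  # construct the encoder and return cls(encoder)

    @register_model_architecture("my_toy_ctc", "my_toy_ctc_base")
    def my_toy_ctc_base(args):
        # fill in defaults for any option the user did not set
        args.encoder_embed_dim = getattr(args, "encoder_embed_dim", 256)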
fairseq/models/speech_to_text/pdss2t_transformer.py
@@ -19,6 +19,8 @@ from fairseq.modules import (
     FairseqDropout,
     LayerNorm,
     PositionalEmbedding,
+    RelPositionalEncoding,
+    LegacyRelPositionalEncoding,
     PDSTransformerEncoderLayer,
     DownSampleConvolutionModule
 )

@@ -137,19 +139,6 @@ class PDSS2TTransformerModel(S2TTransformerModel):
     @staticmethod
     def add_args(parser):
         """Add model-specific arguments to the parser."""
-        # input
-        parser.add_argument("--conv-kernel-sizes", type=str, metavar="N",
-                            help="kernel sizes of Conv1d subsampling layers")
-        parser.add_argument("--conv-channels", type=int, metavar="N",
-                            help="# of channels in Conv1d subsampling layers")
         # Transformer
         parser.add_argument(
             "--activation-fn",

@@ -199,6 +188,10 @@ class PDSS2TTransformerModel(S2TTransformerModel):
             "reduced",
             "rel_selfattn",
             "relative",
+            "rel_pos_legacy",
+            "rel_pos",
+            "rope",
+            "abs",
         ],
         help="transformer encoder self-attention layer type"
     )

@@ -333,6 +326,14 @@ class PDSS2TTransformerModel(S2TTransformerModel):
                             help='dropout for history output')
         parser.add_argument('--history-window-size', type=int, default='-1',
                             help='how many past layers are considered. -1 means all')
+        # CTC
+        parser.add_argument("--ctc-layer", default=0, type=int,
+                            help="the position of the ctc loss")
         # local modeling
         parser.add_argument(
             '--hard-mask-window',

@@ -358,6 +359,13 @@ class PDSS2TTransformerModel(S2TTransformerModel):
         # Conformer setting
         parser.add_argument(
+            "--encoder-activation-fn",
+            type=str,
+            default="relu",
+            choices=utils.get_available_activation_fns(),
+            help="activation function to use",
+        )
+        parser.add_argument(
             "--macaron-style",
             default=False,
             type=bool,

@@ -380,14 +388,7 @@ class PDSS2TTransformerModel(S2TTransformerModel):
                  "The legacy relative positional encoding will be deprecated in the future."
                  "More Details can be found in https://github.com/espnet/espnet/pull/2816.",
         )
-        # CTC
-        parser.add_argument("--ctc-layer", default=0, type=int,
-                            help="the position of the ctc loss")
-        # Conformer module
+        # CNN module
         parser.add_argument(
             "--use-cnn-module",
             default=False,

@@ -443,7 +444,6 @@ class PDSS2TTransformerModel(S2TTransformerModel):
             type=str,
             help="use the position embedding or not before each encoding",
         )
         parser.add_argument(
             "--pds-attn-heads",
             type=str,

@@ -479,6 +479,8 @@ class PDSS2TTransformerModel(S2TTransformerModel):
             type=str,
             help="use the ctc after each stage",
         )
+        # intermedia ctc
         parser.add_argument(
             "--intermedia-ctc-layers",
             default=None,

@@ -491,6 +493,18 @@ class PDSS2TTransformerModel(S2TTransformerModel):
             type=str,
             help="type of intermedia adapter",
         )
+        parser.add_argument("--intermedia-distribution-cutoff", default=-1, type=int,
+                            help="cutoff of the distribution")
+        parser.add_argument("--intermedia-drop-prob", default=0, type=float,
+                            help="probability of dropping the followed layers")
         pass

     @classmethod

@@ -504,6 +518,7 @@ class PDSS2TTransformerModel(S2TTransformerModel):
                 f"loaded pretrained encoder from: "
                 f"{args.load_pretrained_encoder_from}"
             )
         return encoder

@@ -535,7 +550,6 @@ class PDSS2TTransformerEncoder(FairseqEncoder):
         self.pds_kernel_sizes = [int(n) for n in args.pds_kernel_sizes.split("_")]
         self.pds_embed_norm = args.pds_embed_norm
         self.pds_position_embed = [int(n) for n in args.pds_position_embed.split("_")]
         self.pds_attn_heads = [int(n) for n in args.pds_attn_heads.split("_")]
         self.pds_ffn_ratios = [int(n) for n in args.pds_ffn_ratios.split("_")]

         if self.attn_type == "reduced":

@@ -596,6 +610,7 @@ class PDSS2TTransformerEncoder(FairseqEncoder):
             if args.no_scale_embedding:
                 self.embed_scale = 1.0

+            # down-sampling
             downsampling = Downsampling(
                 self.pds_ds_method,
                 self.pds_embed_norm,

@@ -605,8 +620,23 @@ class PDSS2TTransformerEncoder(FairseqEncoder):
                 stride=ds_ratio,
                 padding=(kernel_size - 1) // 2,
             )
+            # position encoding
             if use_pos_embed:
-                pos_embed = PositionalEmbedding(args.max_source_positions, embed_dim, self.padding_idx)
+                if self.attn_type == "rel_pos":
+                    pos_embed = RelPositionalEncoding(
+                        args.max_source_positions, args.encoder_embed_dim
+                    )
+                elif self.attn_type in ["rel_selfattn", "rel_pos_legacy"]:
+                    pos_embed = LegacyRelPositionalEncoding(
+                        args.encoder_embed_dim, args.dropout, args.max_source_positions
+                    )
+                elif self.attn_type == "rope":
+                    self.embed_positions = None
+                else:  # Use absolute positional embedding
+                    pos_embed = PositionalEmbedding(
+                        args.max_source_positions, args.encoder_embed_dim, self.padding_idx
+                    )
             else:
                 pos_embed = None

@@ -614,6 +644,7 @@ class PDSS2TTransformerEncoder(FairseqEncoder):
                 PDSTransformerEncoderLayer(args, embed_dim, embed_dim * ffn_ratio, num_head, attn_ds_ratio)
                 for _ in range(num_layers)])

+            # representation fusion
             fusion_pre_layer_norm = None
             fusion_post_layer_norm = None
             fusion_downsampling = None

@@ -700,7 +731,8 @@ class PDSS2TTransformerEncoder(FairseqEncoder):
             self.fusion_weight.data = self.fusion_weight.data / self.fusion_weight.data.sum(0, keepdim=True)

         self.use_ctc = "sate" in args.arch or \
-                       (("ctc" in getattr(args, "criterion", False)) and
-                        (getattr(args, "ctc_weight", False) > 0))
+                       (getattr(args, "criterion", "") == "ctc") or \
+                       (("ctc" in getattr(args, "criterion", "")) and (getattr(args, "ctc_weight", False) > 0))
         if self.use_ctc:
             self.ctc_layer = (args.ctc_layer + args.encoder_layers) % args.encoder_layers

@@ -799,9 +831,16 @@ class PDSS2TTransformerEncoder(FairseqEncoder):
             # add the position encoding and dropout
             if pos_embed:
-                positions = pos_embed(encoder_padding_mask).transpose(0, 1)
-                x += positions
+                if self.attn_type in ["rel_pos", "rel_pos_legacy", "rel_selfattn"]:
+                    positions = pos_embed(x)
+                    positions = self.dropout(positions)
+                elif self.attn_type == "rope":
+                    positions = None
+                else:
+                    positions = pos_embed(encoder_padding_mask).transpose(0, 1)
+                    x += positions
+                    positions = None
             else:
                 positions = None
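A worked example of the ctc-layer index arithmetic kept in the hunk above: (args.ctc_layer + args.encoder_layers) % args.encoder_layers lets non-positive values count back from the end of the encoder stack.

    encoder_layers = 12
    for ctc_layer in (12, 6, 0, -1):
        print(ctc_layer, "->", (ctc_layer + encoder_layers) % encoder_layers)
    # 12 -> 0, 6 -> 6, 0 -> 0, -1 -> 11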
fairseq/models/speech_to_text/s2t_ctc.py
@@ -363,6 +363,85 @@ class S2TCTCModel(FairseqEncoderModel):
         parser.add_argument('--cl-dropout-strategy',
                             type=str,
                             help='interleaved dropout probability')

+        # pds setting
+        parser.add_argument("--pds-stages", type=int, help="the number of the stage")
+        parser.add_argument("--pds-layers", type=str, help="the number of the encoder layers in each stage")
+        parser.add_argument("--pds-ratios", type=str, help="the ratio of the down-sampling in each stage")
+        parser.add_argument("--pds-ds-method", type=str, choices=["glu", "conv", "proj", "fusion"],
+                            help="the down-sampling method")
+        parser.add_argument("--pds-embed-dims", type=str, help="the embedding dimension in each stage")
+        parser.add_argument("--pds-kernel-sizes", type=str,
+                            help="the kernel size of the down-sampling module in each stage")
+        parser.add_argument("--pds-embed-norm", action="store_true",
+                            help="use layer norm in the down-sampling module")
+        parser.add_argument("--pds-position-embed", type=str,
+                            help="use the position embedding or not before each encoding")
+        parser.add_argument("--pds-attn-heads", type=str, help="the number of the attention heads in each stage")
+        parser.add_argument("--pds-attn-ds-ratio", type=str,
+                            help="the ratio of the down-sampling in the self attention module")
+        parser.add_argument("--pds-ffn-ratios", type=str, help="the ratio of the ffn in each stage")
+        parser.add_argument("--pds-fusion", action="store_true", help="use the representation fusion method")
+        parser.add_argument("--pds-fusion-method", type=str, help="the fusion method")
+        parser.add_argument("--pds-dropout", type=float, help="dropout in each stage")
+        parser.add_argument("--pds-ctc", type=str, help="use the ctc after each stage")

         # intermedia CTC loss
         parser.add_argument(
             "--intermedia-ctc-layers",

@@ -388,6 +467,14 @@ class S2TCTCModel(FairseqEncoderModel):
             type=float,
             help="probability of dropping the followed layers",
         )
+        # encoder
+        parser.add_argument("--encoder-type", default="transformer", type=str, help="encoder type")
         pass

     @classmethod

@@ -452,79 +539,90 @@ class S2TCTCEncoder(FairseqEncoder):
     def __init__(self, args, task=None):
         super().__init__(None)

-        dim = args.encoder_embed_dim
-        self.dropout_module = FairseqDropout(
-            p=args.dropout, module_name=self.__class__.__name__
-        )
-        self.embed_scale = math.sqrt(dim)
-        if args.no_scale_embedding:
-            self.embed_scale = 1.0
-        self.padding_idx = 1
-        self.subsample = subsampling(args)
-        self.attn_type = getattr(args, "encoder_attention_type", "selfattn")
-        if self.attn_type == "rel_pos":
-            self.embed_positions = RelPositionalEncoding(
-                args.max_source_positions, args.encoder_embed_dim
-            )
-        elif self.attn_type in ["rel_selfattn", "rel_pos_legacy"]:
-            self.embed_positions = LegacyRelPositionalEncoding(
-                args.encoder_embed_dim, args.dropout, args.max_source_positions
-            )
-        elif self.attn_type == "rope":
-            self.embed_positions = None
-        else:  # Use absolute positional embedding
-            self.embed_positions = PositionalEmbedding(
-                args.max_source_positions, args.encoder_embed_dim, self.padding_idx
-            )
-        self.layers = nn.ModuleList(
-            [S2TTransformerEncoderLayer(args) for _ in range(args.encoder_layers)]
-        )
-        if args.encoder_normalize_before:
-            self.layer_norm = LayerNorm(dim)
-        else:
-            self.layer_norm = None
-        if args.use_enc_dlcl:
-            self.history = DynamicLinearCombination(args, is_encoder=True)
-        else:
-            self.history = None
-        self.ctc = CTC(dim,
-                       dictionary_size=len(task.source_dictionary),
-                       dropout=args.dropout,
-                       )
-        # gather cosine similarity of the representation
-        self.gather_cos_sim = getattr(args, "gather_cos_sim", False)
-        # self.gather_cos_sim = True
-        self.dis = 2
-        self.cos_sim = dict()
-        self.intermedia_ctc_layers = []
-        if args.intermedia_ctc_layers is not None:
-            intermedia_ctc_layers = args.intermedia_ctc_layers.split(",")
-            for layer_idx in intermedia_ctc_layers:
-                layer_idx = int(layer_idx)
-                if layer_idx <= 0:
-                    layer_idx += args.encoder_layers
-                self.intermedia_ctc_layers.append(layer_idx)
-                logger.info("Intermedia CTC loss in layer %d" % layer_idx)
-        strategy = None
-        if args.intermedia_adapter == "shrink":
-            strategy = getattr(args, "ctc_compress_strategy", "avg")
-        elif args.intermedia_adapter == "league":
-            strategy = getattr(args, "intermedia_distribution_cutoff", -1)
-        self.adapter = Adapter(dim, args.intermedia_adapter,
-                               task.source_dictionary, strategy=strategy)
-        self.intermedia_drop_prob = getattr(args, "intermedia_drop_prob", 0)
+        encoder_type = getattr(args, "encoder_type", "transformer")
+        if encoder_type == "transformer":
+            from .s2t_transformer import S2TTransformerEncoder
+            self.encoder = S2TTransformerEncoder(args, task)
+        elif encoder_type == "pds":
+            from .pdss2t_transformer import PDSS2TTransformerEncoder
+            self.encoder = PDSS2TTransformerEncoder(args, task)
+        else:
+            logger.error("Unsupported architecture: %s." % encoder_type)
+            return
+        # dim = args.encoder_embed_dim
+        # self.dropout_module = FairseqDropout(
+        #     p=args.dropout, module_name=self.__class__.__name__
+        # )
+        # ... (every line removed above is re-added here as commented-out code) ...

     def add_to_dict(self, x, dis, idx):
         sim = 0

@@ -546,102 +644,107 @@ class S2TCTCEncoder(FairseqEncoder):
     def forward(self, src_tokens, src_lengths, **kwargs):

-        if self.history is not None:
-            self.history.clean()
-
-        # gather cosine similarity
-        cos_sim_idx = -1
-        dis = self.dis
-        if self.gather_cos_sim:
-            self.add_to_dict(src_tokens.transpose(0, 1), dis, cos_sim_idx)
-
-        # down-sampling
-        x, input_lengths = self.subsample(src_tokens, src_lengths)
-        # (B, T, D) -> (T, B, D)
-        x = x.transpose(0, 1)
-
-        # embedding scaling
-        x = self.embed_scale * x
-
-        # padding and position embedding
-        encoder_padding_mask = lengths_to_padding_mask(input_lengths)
-
-        if self.attn_type in ["rel_selfattn", "rel_pos", "rel_pos_legacy"]:
-            positions = self.embed_positions(x)
-        elif self.attn_type == "rope":
-            positions = None
-        else:
-            positions = self.embed_positions(encoder_padding_mask).transpose(0, 1)
-            x += positions
-            positions = None
-
-        x = self.dropout_module(x)
-
-        # add emb into history
-        if self.history is not None:
-            self.history.push(x)
-
-        # gather cosine similarity
-        cos_sim_idx = (cos_sim_idx + 10) // 10 * 10 - 1
-        if self.gather_cos_sim:
-            cos_sim_idx += 1
-            self.add_to_dict(x, dis, cos_sim_idx)
-
-        layer_idx = 0
-        intermedia_ctc_logits = []
-        for layer in self.layers:
-            layer_idx += 1
-
-            if self.history is not None:
-                x = self.history.pop()
-
-            # encoder layer
-            x = layer(x, encoder_padding_mask, pos_emb=positions)
-
-            # interleave CTC
-            if layer_idx in self.intermedia_ctc_layers:
-                if self.intermedia_drop_prob > 0:
-                    p = torch.rand(1).uniform_()
-                    if p < self.intermedia_drop_prob:
-                        break
-
-                norm_x = self.layer_norm(x)
-                logit = self.ctc(norm_x)
-                intermedia_ctc_logits.append(logit)
-
-                prob = F.softmax(logit, dim=-1, dtype=torch.float32)
-                x, encoder_padding_mask = self.adapter([x, prob], encoder_padding_mask)
-
-            # gather cosine similarity
-            if self.gather_cos_sim:
-                cos_sim_idx += 1
-                self.add_to_dict(x, dis, cos_sim_idx)
-
-            if self.history is not None:
-                self.history.push(x)
-
-        if self.history is not None:
-            x = self.history.pop()
-
-        if self.layer_norm is not None:
-            x = self.layer_norm(x)
-
-        ctc_logit = self.ctc(x)
-
-        return {
-            "encoder_out": [x],  # T x B x C
-            "ctc_logit": [ctc_logit],  # B x T x C
-            "intermedia_ctc_logits": intermedia_ctc_logits,  # B x T x C
-            "encoder_padding_mask": [encoder_padding_mask],  # B x T
-            "encoder_embedding": [],  # B x T x C
-            "encoder_states": [],  # List[T x B x C]
-            "src_tokens": [],
-            "src_lengths": [],
-        }
+        return self.encoder(src_tokens, src_lengths, **kwargs)
+        # ... (the whole removed body above is re-added here as commented-out code) ...

     def reorder_encoder_out(self, encoder_out, new_order):
+        self.encoder.reorder_encoder_out(encoder_out, new_order)
+        return
         new_encoder_out = (
             []
             if len(encoder_out["encoder_out"]) == 0
             else [x.index_select(1, new_order) for x in encoder_out["encoder_out"]]
         )
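Condensed, for readability, the net effect of the encoder changes above (class and option names as in this commit; imports and error handling simplified -- this is a sketch, not the file itself):

    class S2TCTCEncoderSketch:
        """The pure-CTC encoder now just wraps an existing encoder and delegates to it."""

        def __init__(self, args, task=None):
            encoder_type = getattr(args, "encoder_type", "transformer")
            if encoder_type == "transformer":
                from fairseq.models.speech_to_text.s2t_transformer import S2TTransformerEncoder
                self.encoder = S2TTransformerEncoder(args, task)
            elif encoder_type == "pds":
                from fairseq.models.speech_to_text.pdss2t_transformer import PDSS2TTransformerEncoder
                self.encoder = PDSS2TTransformerEncoder(args, task)
            else:
                raise ValueError("Unsupported encoder type: %s" % encoder_type)

        def forward(self, src_tokens, src_lengths, **kwargs):
            return self.encoder(src_tokens, src_lengths, **kwargs)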
fairseq/models/speech_to_text/s2t_transformer.py
@@ -401,13 +401,13 @@ class S2TTransformerModel(FairseqEncoderDecoderModel):
     def build_encoder(cls, args, task=None, embed_tokens=None):
         encoder = S2TTransformerEncoder(args, task, embed_tokens)
         if getattr(args, "load_pretrained_encoder_from", None):
-            encoder = checkpoint_utils.load_pretrained_component_from_model(
-                component=encoder, checkpoint=args.load_pretrained_encoder_from, strict=False
-            )
             logger.info(
                 f"loaded pretrained encoder from: "
                 f"{args.load_pretrained_encoder_from}"
             )
+            encoder = checkpoint_utils.load_pretrained_component_from_model(
+                component=encoder, checkpoint=args.load_pretrained_encoder_from, strict=False
+            )
         return encoder

@@ -501,7 +501,7 @@ class S2TTransformerEncoder(FairseqEncoder):
         self.padding_idx = 1

         self.subsample = subsampling(args)
-        self.linear = nn.Linear(dim, dim)
+        # self.linear = nn.Linear(dim, dim)

         self.attn_type = getattr(args, "encoder_attention_type", "selfattn")

@@ -535,6 +535,7 @@ class S2TTransformerEncoder(FairseqEncoder):
             self.history = None

         self.use_ctc = "sate" in args.arch or \
+                       (getattr(args, "criterion", "") == "ctc") or \
                        (("ctc" in getattr(args, "criterion", "")) and (getattr(args, "ctc_weight", 0) > 0))
         if self.use_ctc:
             self.ctc_layer = args.ctc_layer

@@ -640,7 +641,7 @@ class S2TTransformerEncoder(FairseqEncoder):
             x += positions
             positions = None

-        x = self.linear(x)
+        # x = self.linear(x)
         x = self.dropout_module(x)

         # add emb into history
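A quick check of the widened use_ctc condition above against the two kinds of configuration touched by this commit (values are illustrative):

    def use_ctc(arch, criterion, ctc_weight=0):
        return ("sate" in arch) or (criterion == "ctc") or \
               (("ctc" in criterion) and (ctc_weight > 0))

    print(use_ctc("s2t_transformer_s", "ctc"))                                         # True  (pure CTC criterion)
    print(use_ctc("s2t_transformer_s", "label_smoothed_cross_entropy_with_ctc", 0.2))  # True  (joint loss, ctc-weight > 0)
    print(use_ctc("s2t_transformer_s", "label_smoothed_cross_entropy"))                # False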
fairseq/modules/pds_layer.py
-# Copyright (c) Facebook, Inc. and its affiliates.
-#
-# This source code is licensed under the MIT license found in the
-# LICENSE file in the root directory of this source tree.
-
-from typing import Dict, List, Optional
+from typing import Optional

 import torch
 import torch.nn as nn
-from fairseq import utils
 from fairseq.modules import (
     LayerNorm,
     MultiheadAttention,
-    ReducedMultiheadAttention,
     RelPositionMultiheadAttention,
     RelativeMultiheadAttention,
+    ConvolutionModule,
+    ESPNETMultiHeadedAttention,
+    RelPositionMultiHeadedAttention,
+    LegacyRelPositionMultiHeadedAttention,
     LocalMultiheadAttention,
-    ConvolutionModule
+    ReducedMultiheadAttention,
+    RotaryPositionMultiHeadedAttention,
 )
+from fairseq.modules.s2t_transformer_layer import FeedForwardModule
 from fairseq.modules.fairseq_dropout import FairseqDropout
-from fairseq.modules.quant_noise import quant_noise
 from torch import Tensor
@@ -40,104 +38,76 @@ class PDSTransformerEncoderLayer(nn.Module):
     def __init__(self, args, embed_dim, ffn_embed_dim, num_head, att_sample_ratio=1):
         super().__init__()
         self.args = args

-        self.embed_dim = embed_dim
-        self.encoder_ffn_embed_dim = ffn_embed_dim
+        embed_dim = embed_dim
+        ffn_dim = args.encoder_ffn_embed_dim
+        dropout = args.dropout

         self.quant_noise = getattr(args, 'quant_noise_pq', 0)
         self.quant_noise_block_size = getattr(args, 'quant_noise_pq_block_size', 8) or 8

         self.attn_type = getattr(args, "encoder_attention_type", "selfattn")
-        self.self_attn = self.build_self_attention(args, self.embed_dim, num_head, att_sample_ratio)
-        self.self_attn_layer_norm = LayerNorm(self.embed_dim)
+        self.self_attn = self.build_self_attention(args, embed_dim, num_head, att_sample_ratio)
+        self.self_attn_layer_norm = LayerNorm(embed_dim)

         self.dropout_module = FairseqDropout(
-            args.dropout, module_name=self.__class__.__name__
+            dropout, module_name=self.__class__.__name__
         )
-        self.activation_fn = utils.get_activation_fn(
-            activation=getattr(args, 'activation_fn', 'relu') or "relu"
-        )
-        activation_dropout_p = getattr(args, "activation_dropout", 0) or 0
-        if activation_dropout_p == 0:
-            # for backwards compatibility with models that use args.relu_dropout
-            activation_dropout_p = getattr(args, "relu_dropout", 0) or 0
-        self.activation_dropout_module = FairseqDropout(
-            float(activation_dropout_p), module_name=self.__class__.__name__
-        )
+        self.normalize_before = args.encoder_normalize_before
+        activation = getattr(args, 'encoder_activation_fn', 'relu')
+
+        args.macaron_style = getattr(args, "macaron_style", False)
+        args.use_cnn_module = getattr(args, "use_cnn_module", False)
+        args.cnn_module_kernel = getattr(args, "cnn_module_kernel", 31)

         if args.macaron_style:
-            self.macaron_fc1 = self.build_fc1(
-                self.embed_dim,
-                args.encoder_ffn_embed_dim,
-                self.quant_noise,
-                self.quant_noise_block_size,
-            )
-            self.macaron_fc2 = self.build_fc2(
-                args.encoder_ffn_embed_dim,
-                self.embed_dim,
-                self.quant_noise,
-                self.quant_noise_block_size,
-            )
-            self.macaron_norm = LayerNorm(self.embed_dim)
+            self.macaron_ffn = FeedForwardModule(
+                embed_dim,
+                ffn_dim,
+                dropout,
+                dropout,
+                activation
+            )
+            self.macaron_norm = LayerNorm(embed_dim)
             self.ffn_scale = 0.5
         else:
-            self.macaron_fc1 = None
-            self.macaron_fc2 = None
+            self.macaron_ffn = None
             self.macaron_norm = None
             self.ffn_scale = 1.0

         if args.use_cnn_module:
-            self.conv_norm = LayerNorm(self.embed_dim)
-            self.conv_module = ConvolutionModule(
-                self.embed_dim,
-                args.cnn_module_kernel)
-            self.final_norm = LayerNorm(self.embed_dim)
+            self.conv_norm = LayerNorm(embed_dim)
+            self.conv_module = ConvolutionModule(
+                embed_dim,
+                embed_dim,
+                depthwise_kernel_size=args.cnn_module_kernel,
+                dropout=args.dropout,
+                activation_fn=getattr(args, 'activation_fn', 'swish'))
+            self.final_norm = LayerNorm(embed_dim)
         else:
             self.conv_norm = None
             self.conv_module = None
             self.final_norm = None

-        self.normalize_before = args.encoder_normalize_before
-        self.fc1 = self.build_fc1(
-            self.embed_dim,
-            self.encoder_ffn_embed_dim,
-            self.quant_noise,
-            self.quant_noise_block_size,
-        )
-        self.fc2 = self.build_fc2(
-            self.encoder_ffn_embed_dim,
-            self.embed_dim,
-            self.quant_noise,
-            self.quant_noise_block_size,
-        )
-        self.final_layer_norm = LayerNorm(self.embed_dim)
-
-    def build_fc1(self, input_dim, output_dim, q_noise, qn_block_size):
-        return quant_noise(nn.Linear(input_dim, output_dim), p=q_noise, block_size=qn_block_size)
-
-    def build_fc2(self, input_dim, output_dim, q_noise, qn_block_size):
-        return quant_noise(nn.Linear(input_dim, output_dim), p=q_noise, block_size=qn_block_size)
+        self.ffn = FeedForwardModule(
+            embed_dim,
+            ffn_dim,
+            dropout,
+            dropout,
+            activation
+        )
+        self.ffn_norm = LayerNorm(embed_dim)

     def build_self_attention(self, args, embed_dim, num_head, sample_ratio=1):
-        encoder_attention_heads = num_head
+        attention_heads = num_head
+        dropout = args.dropout

         if self.attn_type == "selfattn":
             attn_func = MultiheadAttention
         elif self.attn_type == "rel_selfattn":
             attn_func = RelPositionMultiheadAttention
         elif self.attn_type == "relative":
-            # max_relative_length = getattr(args, "max_encoder_relative_length", -1)
             max_relative_length = max(getattr(args, "max_encoder_relative_length", -1),
                                       getattr(args, "max_relative_length", -1))
             if max_relative_length != -1:
                 return RelativeMultiheadAttention(
                     embed_dim,
-                    encoder_attention_heads,
-                    dropout=args.attention_dropout,
+                    attention_heads,
+                    dropout=dropout,
                     self_attention=True,
                     q_noise=self.quant_noise,
                     qn_block_size=self.quant_noise_block_size,
@@ -152,8 +122,8 @@ class PDSTransformerEncoderLayer(nn.Module):
...
@@ -152,8 +122,8 @@ class PDSTransformerEncoderLayer(nn.Module):
init_mask_weight
=
getattr
(
args
,
"init_mask_weight"
,
0
)
init_mask_weight
=
getattr
(
args
,
"init_mask_weight"
,
0
)
return
LocalMultiheadAttention
(
return
LocalMultiheadAttention
(
embed_dim
,
embed_dim
,
encoder_
attention_heads
,
attention_heads
,
dropout
=
args
.
attention_
dropout
,
dropout
=
dropout
,
self_attention
=
True
,
self_attention
=
True
,
q_noise
=
self
.
quant_noise
,
q_noise
=
self
.
quant_noise
,
qn_block_size
=
self
.
quant_noise_block_size
,
qn_block_size
=
self
.
quant_noise_block_size
,
...
@@ -161,11 +131,36 @@ class PDSTransformerEncoderLayer(nn.Module):
...
@@ -161,11 +131,36 @@ class PDSTransformerEncoderLayer(nn.Module):
gauss_mask_sigma
=
gauss_mask_sigma
,
gauss_mask_sigma
=
gauss_mask_sigma
,
init_mask_weight
=
init_mask_weight
init_mask_weight
=
init_mask_weight
)
)
elif
self
.
attn_type
==
"rel_pos"
:
return
RelPositionMultiHeadedAttention
(
embed_dim
,
attention_heads
,
dropout
=
dropout
,
)
elif
self
.
attn_type
==
"rel_pos_legacy"
:
return
LegacyRelPositionMultiHeadedAttention
(
embed_dim
,
attention_heads
,
dropout
=
dropout
,
)
elif
self
.
attn_type
==
"rope"
:
return
RotaryPositionMultiHeadedAttention
(
embed_dim
,
attention_heads
,
dropout
=
dropout
,
precision
=
args
.
fp16
)
elif
self
.
attn_type
==
"abs"
:
return
ESPNETMultiHeadedAttention
(
embed_dim
,
attention_heads
,
dropout
=
dropout
,
)
elif
self
.
attn_type
==
"reduced"
:
elif
self
.
attn_type
==
"reduced"
:
return
ReducedMultiheadAttention
(
return
ReducedMultiheadAttention
(
embed_dim
,
embed_dim
,
encoder_
attention_heads
,
attention_heads
,
dropout
=
args
.
attention_
dropout
,
dropout
=
dropout
,
self_attention
=
True
,
self_attention
=
True
,
q_noise
=
self
.
quant_noise
,
q_noise
=
self
.
quant_noise
,
qn_block_size
=
self
.
quant_noise_block_size
,
qn_block_size
=
self
.
quant_noise_block_size
,
...
@@ -177,8 +172,8 @@ class PDSTransformerEncoderLayer(nn.Module):
...
@@ -177,8 +172,8 @@ class PDSTransformerEncoderLayer(nn.Module):
return
attn_func
(
return
attn_func
(
embed_dim
,
embed_dim
,
encoder_
attention_heads
,
attention_heads
,
dropout
=
args
.
attention_
dropout
,
dropout
=
dropout
,
self_attention
=
True
,
self_attention
=
True
,
q_noise
=
self
.
quant_noise
,
q_noise
=
self
.
quant_noise
,
qn_block_size
=
self
.
quant_noise_block_size
,
qn_block_size
=
self
.
quant_noise_block_size
,
...
@@ -234,15 +229,15 @@ class PDSTransformerEncoderLayer(nn.Module):
...
@@ -234,15 +229,15 @@ class PDSTransformerEncoderLayer(nn.Module):
residual
=
x
residual
=
x
if
self
.
normalize_before
:
if
self
.
normalize_before
:
x
=
self
.
macaron_norm
(
x
)
x
=
self
.
macaron_norm
(
x
)
x
=
self
.
macaron_f
c2
(
self
.
activation_dropout_module
(
self
.
activation_fn
(
self
.
macaron_fc1
(
x
)))
)
x
=
self
.
macaron_f
fn
(
x
)
x
=
residual
+
self
.
ffn_scale
*
self
.
dropout_module
(
x
)
x
=
residual
+
self
.
ffn_scale
*
x
if
not
self
.
normalize_before
:
if
not
self
.
normalize_before
:
x
=
self
.
macaron_norm
(
x
)
x
=
self
.
macaron_norm
(
x
)
residual
=
x
residual
=
x
if
self
.
normalize_before
:
if
self
.
normalize_before
:
x
=
self
.
self_attn_layer_norm
(
x
)
x
=
self
.
self_attn_layer_norm
(
x
)
if
self
.
attn_type
==
"rel_selfattn"
:
if
self
.
attn_type
in
[
"rel_pos"
,
"rel_pos_legacy"
,
"rel_selfattn"
]
:
assert
pos_emb
is
not
None
,
"Positions is necessary for RPE!"
assert
pos_emb
is
not
None
,
"Positions is necessary for RPE!"
x
,
_
=
self
.
self_attn
(
x
,
_
=
self
.
self_attn
(
query
=
x
,
query
=
x
,
...
@@ -269,326 +264,28 @@ class PDSTransformerEncoderLayer(nn.Module):
...
@@ -269,326 +264,28 @@ class PDSTransformerEncoderLayer(nn.Module):
# convolution module
# convolution module
if
self
.
conv_module
is
not
None
:
if
self
.
conv_module
is
not
None
:
x
=
x
.
transpose
(
0
,
1
)
residual
=
x
residual
=
x
x
=
x
.
transpose
(
0
,
1
)
if
self
.
normalize_before
:
if
self
.
normalize_before
:
x
=
self
.
conv_norm
(
x
)
x
=
self
.
conv_norm
(
x
)
x
=
residual
+
self
.
dropout_module
(
self
.
conv_module
(
x
,
encoder_padding_mask
))
x
=
self
.
conv_module
(
x
)
x
=
x
.
transpose
(
0
,
1
)
x
=
residual
+
x
if
not
self
.
normalize_before
:
if
not
self
.
normalize_before
:
x
=
self
.
conv_norm
(
x
)
x
=
self
.
conv_norm
(
x
)
x
=
x
.
transpose
(
0
,
1
)
residual
=
x
residual
=
x
if
self
.
normalize_before
:
if
self
.
normalize_before
:
x
=
self
.
final_layer_norm
(
x
)
x
=
self
.
ffn_norm
(
x
)
x
=
self
.
activation_fn
(
self
.
fc1
(
x
))
x
=
self
.
ffn
(
x
)
x
=
self
.
activation_dropout_module
(
x
)
x
=
self
.
fc2
(
x
)
x
=
self
.
dropout_module
(
x
)
x
=
self
.
residual_connection
(
self
.
ffn_scale
*
x
,
residual
)
x
=
self
.
residual_connection
(
self
.
ffn_scale
*
x
,
residual
)
if
not
self
.
normalize_before
:
if
not
self
.
normalize_before
:
x
=
self
.
f
inal_layer
_norm
(
x
)
x
=
self
.
f
fn
_norm
(
x
)
if
self
.
conv_module
is
not
None
:
if
self
.
conv_module
is
not
None
:
x
=
self
.
final_norm
(
x
)
x
=
self
.
final_norm
(
x
)
return
x
return
x
class
TransformerDecoderLayer
(
nn
.
Module
):
"""Decoder layer block.
In the original paper each operation (multi-head attention, encoder
attention or FFN) is postprocessed with: `dropout -> add residual ->
layernorm`. In the tensor2tensor code they suggest that learning is more
robust when preprocessing each layer with layernorm and postprocessing with:
`dropout -> add residual`. We default to the approach in the paper, but the
tensor2tensor approach can be enabled by setting
*args.decoder_normalize_before* to ``True``.
Args:
args (argparse.Namespace): parsed command-line arguments
no_encoder_attn (bool, optional): whether to attend to encoder outputs
(default: False).
"""
def
__init__
(
self
,
args
,
no_encoder_attn
=
False
,
add_bias_kv
=
False
,
add_zero_attn
=
False
):
super
()
.
__init__
()
self
.
embed_dim
=
args
.
decoder_embed_dim
self
.
dropout_module
=
FairseqDropout
(
args
.
dropout
,
module_name
=
self
.
__class__
.
__name__
)
self
.
quant_noise
=
getattr
(
args
,
"quant_noise_pq"
,
0
)
self
.
quant_noise_block_size
=
getattr
(
args
,
"quant_noise_pq_block_size"
,
8
)
self
.
cross_self_attention
=
getattr
(
args
,
"cross_self_attention"
,
False
)
self
.
attn_type
=
getattr
(
args
,
"decoder_attention_type"
,
"selfattn"
)
self
.
self_attn
=
self
.
build_self_attention
(
self
.
embed_dim
,
args
,
add_bias_kv
=
add_bias_kv
,
add_zero_attn
=
add_zero_attn
,
)
self
.
activation_fn
=
utils
.
get_activation_fn
(
activation
=
str
(
args
.
activation_fn
)
if
getattr
(
args
,
"activation_fn"
,
None
)
is
not
None
else
"relu"
)
activation_dropout_p
=
getattr
(
args
,
"activation_dropout"
,
0
)
or
0
if
activation_dropout_p
==
0
:
# for backwards compatibility with models that use args.relu_dropout
activation_dropout_p
=
getattr
(
args
,
"relu_dropout"
,
0
)
or
0
self
.
activation_dropout_module
=
FairseqDropout
(
float
(
activation_dropout_p
),
module_name
=
self
.
__class__
.
__name__
)
self
.
normalize_before
=
args
.
decoder_normalize_before
# use layerNorm rather than FusedLayerNorm for exporting.
# char_inputs can be used to determint this.
# TODO remove this once we update apex with the fix
export
=
getattr
(
args
,
"char_inputs"
,
False
)
self
.
self_attn_layer_norm
=
LayerNorm
(
self
.
embed_dim
,
export
=
export
)
if
no_encoder_attn
:
self
.
encoder_attn
=
None
self
.
encoder_attn_layer_norm
=
None
else
:
self
.
encoder_attn
=
self
.
build_encoder_attention
(
self
.
embed_dim
,
args
)
self
.
encoder_attn_layer_norm
=
LayerNorm
(
self
.
embed_dim
,
export
=
export
)
self
.
fc1
=
self
.
build_fc1
(
self
.
embed_dim
,
args
.
decoder_ffn_embed_dim
,
self
.
quant_noise
,
self
.
quant_noise_block_size
,
)
self
.
fc2
=
self
.
build_fc2
(
args
.
decoder_ffn_embed_dim
,
self
.
embed_dim
,
self
.
quant_noise
,
self
.
quant_noise_block_size
,
)
self
.
final_layer_norm
=
LayerNorm
(
self
.
embed_dim
,
export
=
export
)
self
.
need_attn
=
True
self
.
onnx_trace
=
False
def
build_fc1
(
self
,
input_dim
,
output_dim
,
q_noise
,
qn_block_size
):
return
quant_noise
(
nn
.
Linear
(
input_dim
,
output_dim
),
q_noise
,
qn_block_size
)
def
build_fc2
(
self
,
input_dim
,
output_dim
,
q_noise
,
qn_block_size
):
return
quant_noise
(
nn
.
Linear
(
input_dim
,
output_dim
),
q_noise
,
qn_block_size
)
def
build_self_attention
(
self
,
embed_dim
,
args
,
add_bias_kv
=
False
,
add_zero_attn
=
False
):
if
self
.
attn_type
==
"selfattn"
:
attn_func
=
MultiheadAttention
elif
self
.
attn_type
==
"rel_selfattn"
:
attn_func
=
RelPositionMultiheadAttention
elif
self
.
attn_type
==
"relative"
:
max_relative_length
=
max
(
getattr
(
args
,
"max_decoder_relative_length"
,
-
1
),
getattr
(
args
,
"max_relative_length"
,
-
1
))
if
max_relative_length
!=
-
1
:
return
RelativeMultiheadAttention
(
embed_dim
,
args
.
decoder_attention_heads
,
dropout
=
args
.
attention_dropout
,
self_attention
=
True
,
q_noise
=
self
.
quant_noise
,
qn_block_size
=
self
.
quant_noise_block_size
,
max_relative_length
=
max_relative_length
,
)
else
:
print
(
"The maximum decoder relative length
%
d can not be -1!"
%
max_relative_length
)
exit
(
1
)
else
:
print
(
"The decoder attention type
%
s is not supported!"
%
self
.
attn_type
)
exit
(
1
)
return
attn_func
(
embed_dim
,
args
.
decoder_attention_heads
,
dropout
=
args
.
attention_dropout
,
add_bias_kv
=
add_bias_kv
,
add_zero_attn
=
add_zero_attn
,
self_attention
=
not
getattr
(
args
,
"cross_self_attention"
,
False
),
q_noise
=
self
.
quant_noise
,
qn_block_size
=
self
.
quant_noise_block_size
,
)
def
build_encoder_attention
(
self
,
embed_dim
,
args
):
return
MultiheadAttention
(
embed_dim
,
args
.
decoder_attention_heads
,
kdim
=
getattr
(
args
,
"encoder_embed_dim"
,
None
),
vdim
=
getattr
(
args
,
"encoder_embed_dim"
,
None
),
dropout
=
args
.
attention_dropout
,
encoder_decoder_attention
=
True
,
q_noise
=
self
.
quant_noise
,
qn_block_size
=
self
.
quant_noise_block_size
,
)
def
prepare_for_onnx_export_
(
self
):
self
.
onnx_trace
=
True
def
residual_connection
(
self
,
x
,
residual
):
return
residual
+
x
def
forward
(
self
,
x
,
encoder_out
:
Optional
[
torch
.
Tensor
]
=
None
,
encoder_padding_mask
:
Optional
[
torch
.
Tensor
]
=
None
,
incremental_state
:
Optional
[
Dict
[
str
,
Dict
[
str
,
Optional
[
Tensor
]]]]
=
None
,
prev_self_attn_state
:
Optional
[
List
[
torch
.
Tensor
]]
=
None
,
prev_attn_state
:
Optional
[
List
[
torch
.
Tensor
]]
=
None
,
self_attn_mask
:
Optional
[
torch
.
Tensor
]
=
None
,
self_attn_padding_mask
:
Optional
[
torch
.
Tensor
]
=
None
,
need_attn
:
bool
=
False
,
need_head_weights
:
bool
=
False
,
pos_emb
:
Optional
[
Tensor
]
=
None
,
):
"""
Args:
x (Tensor): input to the layer of shape `(seq_len, batch, embed_dim)`
encoder_padding_mask (ByteTensor, optional): binary
ByteTensor of shape `(batch, src_len)` where padding
elements are indicated by ``1``.
need_attn (bool, optional): return attention weights
need_head_weights (bool, optional): return attention weights
for each head (default: return average over heads).
Returns:
encoded output of shape `(seq_len, batch, embed_dim)`
"""
if
need_head_weights
:
need_attn
=
True
residual
=
x
if
self
.
normalize_before
:
x
=
self
.
self_attn_layer_norm
(
x
)
if
prev_self_attn_state
is
not
None
:
prev_key
,
prev_value
=
prev_self_attn_state
[:
2
]
saved_state
:
Dict
[
str
,
Optional
[
Tensor
]]
=
{
"prev_key"
:
prev_key
,
"prev_value"
:
prev_value
,
}
if
len
(
prev_self_attn_state
)
>=
3
:
saved_state
[
"prev_key_padding_mask"
]
=
prev_self_attn_state
[
2
]
assert
incremental_state
is
not
None
self
.
self_attn
.
_set_input_buffer
(
incremental_state
,
saved_state
)
_self_attn_input_buffer
=
self
.
self_attn
.
_get_input_buffer
(
incremental_state
)
if
self
.
cross_self_attention
and
not
(
incremental_state
is
not
None
and
_self_attn_input_buffer
is
not
None
and
"prev_key"
in
_self_attn_input_buffer
):
if
self_attn_mask
is
not
None
:
assert
encoder_out
is
not
None
self_attn_mask
=
torch
.
cat
(
(
x
.
new_zeros
(
x
.
size
(
0
),
encoder_out
.
size
(
0
)),
self_attn_mask
),
dim
=
1
)
if
self_attn_padding_mask
is
not
None
:
if
encoder_padding_mask
is
None
:
assert
encoder_out
is
not
None
encoder_padding_mask
=
self_attn_padding_mask
.
new_zeros
(
encoder_out
.
size
(
1
),
encoder_out
.
size
(
0
)
)
self_attn_padding_mask
=
torch
.
cat
(
(
encoder_padding_mask
,
self_attn_padding_mask
),
dim
=
1
)
assert
encoder_out
is
not
None
y
=
torch
.
cat
((
encoder_out
,
x
),
dim
=
0
)
else
:
y
=
x
if
self
.
attn_type
==
"rel_selfattn"
:
assert
pos_emb
is
not
None
,
"Positions is necessary for RPE!"
x
,
attn
=
self
.
self_attn
(
query
=
x
,
key
=
y
,
value
=
y
,
key_padding_mask
=
self_attn_padding_mask
,
incremental_state
=
incremental_state
,
need_weights
=
False
,
attn_mask
=
self_attn_mask
,
pos_emb
=
pos_emb
)
else
:
x
,
attn
=
self
.
self_attn
(
query
=
x
,
key
=
y
,
value
=
y
,
key_padding_mask
=
self_attn_padding_mask
,
incremental_state
=
incremental_state
,
need_weights
=
False
,
attn_mask
=
self_attn_mask
,
)
x
=
self
.
dropout_module
(
x
)
x
=
self
.
residual_connection
(
x
,
residual
)
if
not
self
.
normalize_before
:
x
=
self
.
self_attn_layer_norm
(
x
)
if
self
.
encoder_attn
is
not
None
and
encoder_out
is
not
None
:
residual
=
x
if
self
.
normalize_before
:
x
=
self
.
encoder_attn_layer_norm
(
x
)
if
prev_attn_state
is
not
None
:
prev_key
,
prev_value
=
prev_attn_state
[:
2
]
saved_state
:
Dict
[
str
,
Optional
[
Tensor
]]
=
{
"prev_key"
:
prev_key
,
"prev_value"
:
prev_value
,
}
if
len
(
prev_attn_state
)
>=
3
:
saved_state
[
"prev_key_padding_mask"
]
=
prev_attn_state
[
2
]
assert
incremental_state
is
not
None
self
.
encoder_attn
.
_set_input_buffer
(
incremental_state
,
saved_state
)
x
,
attn
=
self
.
encoder_attn
(
query
=
x
,
key
=
encoder_out
,
value
=
encoder_out
,
key_padding_mask
=
encoder_padding_mask
,
incremental_state
=
incremental_state
,
static_kv
=
True
,
need_weights
=
need_attn
or
(
not
self
.
training
and
self
.
need_attn
),
need_head_weights
=
need_head_weights
,
)
x
=
self
.
dropout_module
(
x
)
x
=
self
.
residual_connection
(
x
,
residual
)
if
not
self
.
normalize_before
:
x
=
self
.
encoder_attn_layer_norm
(
x
)
residual
=
x
if
self
.
normalize_before
:
x
=
self
.
final_layer_norm
(
x
)
x
=
self
.
activation_fn
(
self
.
fc1
(
x
))
x
=
self
.
activation_dropout_module
(
x
)
x
=
self
.
fc2
(
x
)
x
=
self
.
dropout_module
(
x
)
x
=
self
.
residual_connection
(
x
,
residual
)
if
not
self
.
normalize_before
:
x
=
self
.
final_layer_norm
(
x
)
if
self
.
onnx_trace
and
incremental_state
is
not
None
:
saved_state
=
self
.
self_attn
.
_get_input_buffer
(
incremental_state
)
assert
saved_state
is
not
None
if
self_attn_padding_mask
is
not
None
:
self_attn_state
=
[
saved_state
[
"prev_key"
],
saved_state
[
"prev_value"
],
saved_state
[
"prev_key_padding_mask"
],
]
else
:
self_attn_state
=
[
saved_state
[
"prev_key"
],
saved_state
[
"prev_value"
]]
return
x
,
attn
,
self_attn_state
return
x
,
attn
,
None
def
make_generation_fast_
(
self
,
need_attn
:
bool
=
False
,
**
kwargs
):
self
.
need_attn
=
need_attn
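To make the refactored layer easier to follow, here is a self-contained toy block that mirrors the pre-norm call order the new forward uses: optional macaron half-step FFN, self-attention, optional convolution module, FFN, with a final norm only when the conv module is active. All submodules and sizes are simplified stand-ins, not the actual PDSTransformerEncoderLayer:

    import torch
    import torch.nn as nn

    class TinyConformerBlock(nn.Module):
        def __init__(self, dim=8, macaron=True, use_conv=True):
            super().__init__()
            self.macaron_ffn = nn.Sequential(nn.Linear(dim, 4 * dim), nn.ReLU(), nn.Linear(4 * dim, dim)) if macaron else None
            self.macaron_norm = nn.LayerNorm(dim)
            self.ffn_scale = 0.5 if macaron else 1.0          # 0.5 only in macaron style, as above
            self.attn = nn.MultiheadAttention(dim, num_heads=2)
            self.attn_norm = nn.LayerNorm(dim)
            self.conv = nn.Conv1d(dim, dim, kernel_size=3, padding=1) if use_conv else None
            self.conv_norm = nn.LayerNorm(dim)
            self.ffn = nn.Sequential(nn.Linear(dim, 4 * dim), nn.ReLU(), nn.Linear(4 * dim, dim))
            self.ffn_norm = nn.LayerNorm(dim)
            self.final_norm = nn.LayerNorm(dim)

        def forward(self, x):                                 # x: (T, B, C), as in the encoder layers
            if self.macaron_ffn is not None:                  # macaron half-step FFN
                x = x + self.ffn_scale * self.macaron_ffn(self.macaron_norm(x))
            y = self.attn_norm(x)                             # self-attention
            x = x + self.attn(y, y, y, need_weights=False)[0]
            if self.conv is not None:                         # convolution module
                y = self.conv_norm(x).permute(1, 2, 0)        # (B, C, T) for Conv1d
                x = x + self.conv(y).permute(2, 0, 1)
            x = x + self.ffn_scale * self.ffn(self.ffn_norm(x))   # feed-forward
            if self.conv is not None:
                x = self.final_norm(x)
            return x

    print(TinyConformerBlock()(torch.randn(5, 2, 8)).shape)   # torch.Size([5, 2, 8])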
fairseq/modules/s2t_transformer_layer.py
@@ -88,11 +88,11 @@ class S2TTransformerEncoderLayer(nn.Module):
         embed_dim = args.encoder_embed_dim
         ffn_dim = args.encoder_ffn_embed_dim
         dropout = args.dropout

-        self.embed_dim = args.encoder_embed_dim
+        self.embed_dim = embed_dim
         self.quant_noise = getattr(args, 'quant_noise_pq', 0)
         self.quant_noise_block_size = getattr(args, 'quant_noise_pq_block_size', 8) or 8

         self.attn_type = getattr(args, "encoder_attention_type", "selfattn")
-        self.self_attn = self.build_self_attention(self.embed_dim, args)
+        self.self_attn = self.build_self_attention(args, self.embed_dim)
         self.self_attn_layer_norm = LayerNorm(self.embed_dim)
         self.dropout_module = FairseqDropout(
             dropout, module_name=self.__class__.__name__

@@ -138,7 +138,7 @@ class S2TTransformerEncoderLayer(nn.Module):
         )
         self.ffn_norm = LayerNorm(self.embed_dim)

-    def build_self_attention(self, embed_dim, args):
+    def build_self_attention(self, args, embed_dim):
         attention_heads = args.encoder_attention_heads
         dropout = args.dropout

@@ -147,7 +147,8 @@ class S2TTransformerEncoderLayer(nn.Module):
         elif self.attn_type == "rel_selfattn":
             attn_func = RelPositionMultiheadAttention
         elif self.attn_type == "relative":
-            max_relative_length = getattr(args, "max_encoder_relative_length", -1)
+            max_relative_length = max(getattr(args, "max_encoder_relative_length", -1),
+                                      getattr(args, "max_relative_length", -1))
             if max_relative_length != -1:
                 return RelativeMultiheadAttention(
                     embed_dim,

@@ -188,13 +189,13 @@ class S2TTransformerEncoderLayer(nn.Module):
             )
         else:
             attn_func = MultiheadAttention
-            print("The attention type %s is not supported!" % self.attn_type)
+            print("The encoder attention type %s is not supported!" % self.attn_type)
             exit(1)

         return attn_func(
             embed_dim,
-            args.encoder_attention_heads,
-            dropout=args.attention_dropout,
+            attention_heads,
+            dropout=dropout,
             self_attention=True,
             q_noise=self.quant_noise,
             qn_block_size=self.quant_noise_block_size,
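A tiny worked example of the fallback added above: -1 encodes "unset", so taking max() of the encoder-specific and shared options picks whichever one was actually set.

    def resolve_max_relative_length(max_encoder_relative_length=-1, max_relative_length=-1):
        return max(max_encoder_relative_length, max_relative_length)

    print(resolve_max_relative_length())         # -1  -> relative attention rejected
    print(resolve_max_relative_length(-1, 20))   # 20  -> falls back to --max-relative-length
    print(resolve_max_relative_length(8, -1))    # 8   -> --max-encoder-relative-length wins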