xuchen / Fairseq-S2T

Commit 55702466, authored Feb 24, 2022 by xuchen

add the pure ctc arch for pds method

Parent: bab6c520
Showing 10 changed files with 154 additions and 73 deletions (+154, −73):
  egs/libri_trans/asr/conf/debug.yaml                  +20  −28
  egs/mustc/asr/conf/purectc_pds_base_8.yaml           +43   −0
  examples/speech_to_text/prep_audio_data.py            +7   −2
  fairseq/criterions/ctc.py                             +1   −3
  fairseq/models/__init__.py                            +2   −0
  fairseq/models/speech_to_text/pdss2t_transformer.py  +67  −28
  fairseq/models/speech_to_text/s2t_ctc.py              +0   −0
  fairseq/models/speech_to_text/s2t_transformer.py      +6   −5
  fairseq/modules/pds_layer.py                          +0   −0
  fairseq/modules/s2t_transformer_layer.py              +8   −7
egs/libri_trans/asr/conf/debug.yaml

-arch: s2t_transformer_s
+arch: s2t_ctc
+encoder-type: pds
+#arch: pdss2t_transformer_s_8
+encoder-embed-dim: 256
+pds-stages: 4
+ctc-layer: 12
+pds-layers: 3_3_3_3
+pds-ratios: 2_2_1_2
+pds-fusion: True
+pds-fusion-method: all_conv
+pds-embed-dims: 256_256_256_256
+pds-ds-method: conv
+pds-embed-norm: True
+pds-position-embed: 1_1_1_1
+pds-kernel-sizes: 5_5_5_5
+pds-ffn-ratios: 8_8_8_8
+pds-attn-heads: 4_4_4_4
 share-decoder-input-output-embed: True
 optimizer: adam
 clip-norm: 10.0
 lr-scheduler: inverse_sqrt
 warmup-init-lr: 1e-7
 warmup-updates: 10000
-weight-decay: 1e-6
 lr: 2e-3
 adam_betas: (0.9,0.98)
-criterion: label_smoothed_cross_entropy_with_ctc
+criterion: ctc
-label_smoothing: 0.1
-subsampling-type: conv1d
-subsmapling-layers: 2
-subsampling-filter: 1024
-subsampling-kernel: 5
-subsampling-stride: 2
-subsampling-norm: none
-subsampling-activation: glu
-ctc-weight: 0.2
-intermedia-ctc-layers: 6,9
-intermedia-adapter: league
-intermedia-ctc-weight: 0.1
-intermedia-drop-prob: 0.5
-ctc-self-distill-weight: 0
-post-process: sentencepiece
 dropout: 0.1
 activation-fn: relu
-encoder-embed-dim: 256
 encoder-ffn-embed-dim: 2048
 encoder-layers: 12
 decoder-layers: 6

@@ -40,8 +38,3 @@ encoder-attention-heads: 4
 decoder-embed-dim: 256
 decoder-ffn-embed-dim: 2048
 decoder-attention-heads: 4
-attention-dropout: 0.1
-activation-dropout: 0.1
-#load-pretrained-encoder-from:
-#load-pretrained-decoder-from:
\ No newline at end of file
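The pds-* options pack one value per encoder stage into an underscore-separated
string. A minimal sketch of how these strings are parsed, mirroring the
[int(n) for n in args.pds_layers.split("_")] pattern used in
pdss2t_transformer.py further down:

    # Parse stage-wise "a_b_c_d" config strings into per-stage integer lists.
    def parse_stage_option(value: str) -> list:
        return [int(n) for n in value.split("_")]

    pds_layers = parse_stage_option("3_3_3_3")  # encoder layers per stage
    pds_ratios = parse_stage_option("2_2_1_2")  # down-sampling ratio per stage
    assert len(pds_layers) == 4                 # must agree with pds-stages: 4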
egs/mustc/asr/conf/purectc_pds_base_8.yaml  (new file, mode 100644)

arch: s2t_ctc
encoder-type: pds
encoder-embed-dim: 256

pds-stages: 4
ctc-layer: 12
pds-layers: 3_3_3_3
pds-ratios: 2_2_1_2
pds-fusion: True
pds-fusion-method: all_conv
pds-embed-dims: 256_256_256_256
pds-ds-method: conv
pds-embed-norm: True
pds-position-embed: 1_1_1_1
pds-kernel-sizes: 5_5_5_5
pds-ffn-ratios: 8_8_8_8
pds-attn-heads: 4_4_4_4

share-decoder-input-output-embed: True
optimizer: adam
clip-norm: 10.0
lr-scheduler: inverse_sqrt
warmup-init-lr: 1e-7
warmup-updates: 10000
lr: 2e-3
adam_betas: (0.9,0.98)

criterion: ctc

dropout: 0.1
activation-fn: relu
encoder-ffn-embed-dim: 2048
encoder-layers: 12
decoder-layers: 6
encoder-attention-heads: 4

decoder-embed-dim: 256
decoder-ffn-embed-dim: 2048
decoder-attention-heads: 4
#load-pretrained-encoder-from:
#load-pretrained-decoder-from:
(no newline at end of file)
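The run scripts under egs/ typically expand a config like this into fairseq-train
command-line flags (each YAML key becomes --key). A rough sketch of that mapping,
using a hypothetical helper rather than the repo's actual shell script:

    # Hypothetical: expand a YAML config into a fairseq-train flag list.
    import yaml  # assumes PyYAML is available

    def yaml_to_flags(path: str) -> list:
        with open(path) as f:
            cfg = yaml.safe_load(f)
        flags = []
        for key, value in cfg.items():
            if value is True:                  # booleans become bare flags
                flags.append(f"--{key}")
            else:
                flags.extend([f"--{key}", str(value)])
        return flags

    # yaml_to_flags("purectc_pds_base_8.yaml")
    # -> ["--arch", "s2t_ctc", "--encoder-type", "pds", "--encoder-embed-dim", "256", ...]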
examples/speech_to_text/prep_audio_data.py

@@ -137,7 +137,8 @@ class AudioDataset(Dataset):
             for i, segment in enumerate(seg_group):
                 offset = int(float(segment["offset"]) * sample_rate)
                 n_frames = int(float(segment["duration"]) * sample_rate)
-                _id = f"{split}_{wav_path.stem}_{i}"
+                # _id = f"{split}_{wav_path.stem}_{i}"
+                _id = f"{wav_path.stem}_{i}"
                 item = dict()
                 item["audio"] = wav_path.as_posix()

@@ -263,8 +264,12 @@ def process(args):
             utt_id = item['id']
             features_path = (feature_root / f"{utt_id}.npy").as_posix()
+            tag_features_path = (feature_root / f"{split}_{utt_id}.npy").as_posix()

-            if os.path.exists(features_path):
+            if os.path.exists(tag_features_path):
+                continue
+            if os.path.exists(features_path) and not os.path.exists(tag_features_path):
+                shutil.move(features_path, tag_features_path)
                 continue

             waveform, sample_rate, _ = dataset.get(idx, need_waveform=True)
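Two related changes in the hunks above: utterance ids lose the split prefix
(train_xxx_0 becomes xxx_0), while cached feature files gain one
(f"{split}_{utt_id}.npy"). The new branch migrates features cached under the old
untagged name instead of recomputing them; schematically (hypothetical paths):

    import os
    import shutil

    features_path = "fbank/utt1.npy"            # old, untagged cache file
    tag_features_path = "fbank/train_utt1.npy"  # new, split-tagged name
    if os.path.exists(features_path) and not os.path.exists(tag_features_path):
        shutil.move(features_path, tag_features_path)  # reuse instead of recompute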
fairseq/criterions/ctc.py

@@ -20,8 +20,6 @@ from fairseq.tasks import FairseqTask
 from fairseq.logging.meters import safe_round
-
-

 @dataclass
 class CtcCriterionConfig(FairseqDataclass):
     zero_infinity: bool = field(

@@ -30,7 +28,7 @@ class CtcCriterionConfig(FairseqDataclass):
     )
     sentence_avg: bool = II("optimization.sentence_avg")
     post_process: str = field(
-        default="letter",
+        default="sentencepiece",
         metadata={
             "help": "how to post process predictions into words. can be letter, "
             "wordpiece, BPE symbols, etc. "
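The post_process default moves from letter to sentencepiece, which changes how
CTC hypotheses are detokenized before WER is computed during training. Roughly,
via fairseq's post_process utility (behaviour sketched from the symbols'
documented meaning):

    from fairseq.data.data_utils import post_process

    # "letter": strip spaces, then turn the word delimiter "|" into spaces
    post_process("h e l l o | w o r l d |", "letter")           # -> "hello world"

    # "sentencepiece": strip spaces, then turn "\u2581" markers into spaces
    post_process("\u2581he llo \u2581wor ld", "sentencepiece")  # -> "hello world"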
fairseq/models/__init__.py

@@ -47,6 +47,8 @@ __all__ = [
     "FairseqLanguageModel",
     "FairseqModel",
     "FairseqMultiModel",
+    "register_model",
+    "register_model_architecture"
 ]
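Re-exporting the two registration decorators lets model files import them
directly from fairseq.models. The new s2t_ctc model presumably registers itself
through the usual pattern; an illustrative sketch, not the actual file contents:

    from fairseq.models import (
        FairseqEncoderModel,
        register_model,
        register_model_architecture,
    )

    @register_model("s2t_ctc")
    class S2TCTCModel(FairseqEncoderModel):
        """Encoder-only speech model trained with the pure CTC criterion."""

    @register_model_architecture("s2t_ctc", "s2t_ctc")
    def base_architecture(args):
        args.encoder_embed_dim = getattr(args, "encoder_embed_dim", 256)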
fairseq/models/speech_to_text/pdss2t_transformer.py

@@ -19,6 +19,8 @@ from fairseq.modules import (
     FairseqDropout,
     LayerNorm,
     PositionalEmbedding,
+    RelPositionalEncoding,
+    LegacyRelPositionalEncoding,
     PDSTransformerEncoderLayer,
     DownSampleConvolutionModule
 )

@@ -137,19 +139,6 @@ class PDSS2TTransformerModel(S2TTransformerModel):
     @staticmethod
     def add_args(parser):
         """Add model-specific arguments to the parser."""
-        # input
-        parser.add_argument(
-            "--conv-kernel-sizes",
-            type=str,
-            metavar="N",
-            help="kernel sizes of Conv1d subsampling layers",
-        )
-        parser.add_argument(
-            "--conv-channels",
-            type=int,
-            metavar="N",
-            help="# of channels in Conv1d subsampling layers",
-        )
         # Transformer
         parser.add_argument(
             "--activation-fn",

@@ -199,6 +188,10 @@ class PDSS2TTransformerModel(S2TTransformerModel):
             "reduced",
             "rel_selfattn",
             "relative",
+            "rel_pos_legacy",
+            "rel_pos",
+            "rope",
+            "abs",
         ],
         help="transformer encoder self-attention layer type"
     )
@@ -333,6 +326,14 @@ class PDSS2TTransformerModel(S2TTransformerModel):
                             help='dropout for history output')
         parser.add_argument('--history-window-size', type=int, default='-1',
                             help='how many past layers are considered. -1 means all')

+        # CTC
+        parser.add_argument(
+            "--ctc-layer",
+            default=0,
+            type=int,
+            help="the position of the ctc loss",
+        )
         # local modeling
         parser.add_argument(
             '--hard-mask-window',

@@ -358,6 +359,13 @@ class PDSS2TTransformerModel(S2TTransformerModel):
         # Conformer setting
         parser.add_argument(
+            "--encoder-activation-fn",
+            type=str,
+            default="relu",
+            choices=utils.get_available_activation_fns(),
+            help="activation function to use",
+        )
+        parser.add_argument(
             "--macaron-style",
             default=False,
             type=bool,

@@ -380,14 +388,7 @@ class PDSS2TTransformerModel(S2TTransformerModel):
             "The legacy relative positional encoding will be deprecated in the future."
             "More Details can be found in https://github.com/espnet/espnet/pull/2816.",
         )
-        # CTC
-        parser.add_argument(
-            "--ctc-layer",
-            default=0,
-            type=int,
-            help="the position of the ctc loss",
-        )
-        # Conformer module
+        # CNN module
         parser.add_argument(
             "--use-cnn-module",
             default=False,

@@ -443,7 +444,6 @@ class PDSS2TTransformerModel(S2TTransformerModel):
             type=str,
             help="use the position embedding or not before each encoding",
         )
-
         parser.add_argument(
             "--pds-attn-heads",
             type=str,

@@ -479,6 +479,8 @@ class PDSS2TTransformerModel(S2TTransformerModel):
             type=str,
             help="use the ctc after each stage",
         )
+
+        # intermedia ctc
         parser.add_argument(
             "--intermedia-ctc-layers",
             default=None,

@@ -491,6 +493,18 @@ class PDSS2TTransformerModel(S2TTransformerModel):
             type=str,
             help="type of intermedia adapter",
         )
+        parser.add_argument(
+            "--intermedia-distribution-cutoff",
+            default=-1,
+            type=int,
+            help="cutoff of the distribution",
+        )
+        parser.add_argument(
+            "--intermedia-drop-prob",
+            default=0,
+            type=float,
+            help="probability of dropping the followed layers",
+        )
         pass

     @classmethod
@@ -504,6 +518,7 @@ class PDSS2TTransformerModel(S2TTransformerModel):
                 f"loaded pretrained encoder from: "
                 f"{args.load_pretrained_encoder_from}"
             )
+
         return encoder

@@ -535,7 +550,6 @@ class PDSS2TTransformerEncoder(FairseqEncoder):
         self.pds_kernel_sizes = [int(n) for n in args.pds_kernel_sizes.split("_")]
         self.pds_embed_norm = args.pds_embed_norm
         self.pds_position_embed = [int(n) for n in args.pds_position_embed.split("_")]
         self.pds_attn_heads = [int(n) for n in args.pds_attn_heads.split("_")]
         self.pds_ffn_ratios = [int(n) for n in args.pds_ffn_ratios.split("_")]
-
         if self.attn_type == "reduced":

@@ -596,6 +610,7 @@ class PDSS2TTransformerEncoder(FairseqEncoder):
         if args.no_scale_embedding:
             self.embed_scale = 1.0

+            # down-sampling
             downsampling = Downsampling(
                 self.pds_ds_method,
                 self.pds_embed_norm,

@@ -605,8 +620,23 @@ class PDSS2TTransformerEncoder(FairseqEncoder):
                 stride=ds_ratio,
                 padding=(kernel_size - 1) // 2,
             )

+            # position encoding
             if use_pos_embed:
-                pos_embed = PositionalEmbedding(args.max_source_positions, embed_dim, self.padding_idx)
+                if self.attn_type == "rel_pos":
+                    pos_embed = RelPositionalEncoding(
+                        args.max_source_positions, args.encoder_embed_dim)
+                elif self.attn_type in ["rel_selfattn", "rel_pos_legacy"]:
+                    pos_embed = LegacyRelPositionalEncoding(
+                        args.encoder_embed_dim, args.dropout, args.max_source_positions)
+                elif self.attn_type == "rope":
+                    self.embed_positions = None
+                else:
+                    # Use absolute positional embedding
+                    pos_embed = PositionalEmbedding(
+                        args.max_source_positions, args.encoder_embed_dim, self.padding_idx)
             else:
                 pos_embed = None
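Each stage now chooses its positional-encoding module from the attention type
instead of always building absolute embeddings. Condensed into a helper (a
sketch; constructors and argument orders as used in this file's diff):

    from fairseq.modules import (
        PositionalEmbedding,
        RelPositionalEncoding,
        LegacyRelPositionalEncoding,
    )

    def build_pos_embed(attn_type, args, embed_dim, padding_idx):
        if attn_type == "rel_pos":
            return RelPositionalEncoding(args.max_source_positions, embed_dim)
        elif attn_type in ["rel_selfattn", "rel_pos_legacy"]:
            return LegacyRelPositionalEncoding(
                embed_dim, args.dropout, args.max_source_positions)
        elif attn_type == "rope":
            return None  # rotary embedding is applied inside the attention module
        else:
            # "abs" and any other value fall back to absolute embeddings
            return PositionalEmbedding(
                args.max_source_positions, embed_dim, padding_idx)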
@@ -614,6 +644,7 @@ class PDSS2TTransformerEncoder(FairseqEncoder):
                 PDSTransformerEncoderLayer(args, embed_dim, embed_dim * ffn_ratio, num_head, attn_ds_ratio)
                 for _ in range(num_layers)])

+            # representation fusion
             fusion_pre_layer_norm = None
             fusion_post_layer_norm = None
             fusion_downsampling = None

@@ -700,7 +731,8 @@ class PDSS2TTransformerEncoder(FairseqEncoder):
             self.fusion_weight.data = self.fusion_weight.data / self.fusion_weight.data.sum(0, keepdim=True)

         self.use_ctc = "sate" in args.arch or \
-                       (("ctc" in getattr(args, "criterion", False)) and
+                       (getattr(args, "criterion", "") == "ctc") or \
+                       (("ctc" in getattr(args, "criterion", "")) and
                         (getattr(args, "ctc_weight", False) > 0))
         if self.use_ctc:
             self.ctc_layer = (args.ctc_layer + args.encoder_layers) % args.encoder_layers
@@ -799,9 +831,16 @@ class PDSS2TTransformerEncoder(FairseqEncoder):
...
@@ -799,9 +831,16 @@ class PDSS2TTransformerEncoder(FairseqEncoder):
# add the position encoding and dropout
# add the position encoding and dropout
if
pos_embed
:
if
pos_embed
:
positions
=
pos_embed
(
encoder_padding_mask
)
.
transpose
(
0
,
1
)
if
self
.
attn_type
in
[
"rel_pos"
,
"rel_pos_legacy"
,
"rel_selfattn"
]:
x
+=
positions
positions
=
pos_embed
(
x
)
positions
=
self
.
dropout
(
positions
)
elif
self
.
attn_type
==
"rope"
:
positions
=
None
else
:
positions
=
pos_embed
(
encoder_padding_mask
)
.
transpose
(
0
,
1
)
x
+=
positions
positions
=
None
else
:
else
:
positions
=
None
positions
=
None
...
...
fairseq/models/speech_to_text/s2t_ctc.py

(Diff collapsed; contents not shown in this view.)
fairseq/models/speech_to_text/s2t_transformer.py

@@ -401,13 +401,13 @@ class S2TTransformerModel(FairseqEncoderDecoderModel):
     def build_encoder(cls, args, task=None, embed_tokens=None):
         encoder = S2TTransformerEncoder(args, task, embed_tokens)
         if getattr(args, "load_pretrained_encoder_from", None):
-            encoder = checkpoint_utils.load_pretrained_component_from_model(
-                component=encoder, checkpoint=args.load_pretrained_encoder_from, strict=False
-            )
             logger.info(
                 f"loaded pretrained encoder from: "
                 f"{args.load_pretrained_encoder_from}"
             )
+            encoder = checkpoint_utils.load_pretrained_component_from_model(
+                component=encoder, checkpoint=args.load_pretrained_encoder_from, strict=False
+            )
         return encoder

@@ -501,7 +501,7 @@ class S2TTransformerEncoder(FairseqEncoder):
         self.padding_idx = 1

         self.subsample = subsampling(args)
-        self.linear = nn.Linear(dim, dim)
+        # self.linear = nn.Linear(dim, dim)

         self.attn_type = getattr(args, "encoder_attention_type", "selfattn")

@@ -535,6 +535,7 @@ class S2TTransformerEncoder(FairseqEncoder):
         self.history = None

         self.use_ctc = "sate" in args.arch or \
+                       (getattr(args, "criterion", "") == "ctc") or \
                        (("ctc" in getattr(args, "criterion", "")) and
                         (getattr(args, "ctc_weight", 0) > 0))
         if self.use_ctc:
             self.ctc_layer = args.ctc_layer

@@ -640,7 +641,7 @@ class S2TTransformerEncoder(FairseqEncoder):
             x += positions
             positions = None

-        x = self.linear(x)
+        # x = self.linear(x)
         x = self.dropout_module(x)

         # add emb into history
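With the extra clause, the encoder also builds its CTC head when the pure ctc
criterion is selected, even though that criterion defines no ctc_weight. A quick
sanity check of the boolean (hypothetical args namespaces):

    from argparse import Namespace

    def use_ctc(args):
        # mirrors the condition above
        return "sate" in args.arch or \
               (getattr(args, "criterion", "") == "ctc") or \
               (("ctc" in getattr(args, "criterion", "")) and
                (getattr(args, "ctc_weight", 0) > 0))

    use_ctc(Namespace(arch="s2t_ctc", criterion="ctc"))                   # True (new clause)
    use_ctc(Namespace(arch="s2t_transformer",
                      criterion="label_smoothed_cross_entropy_with_ctc",
                      ctc_weight=0.2))                                    # True
    use_ctc(Namespace(arch="s2t_transformer", criterion="cross_entropy")) # False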
fairseq/modules/pds_layer.py

(Diff collapsed; contents not shown in this view.)
fairseq/modules/s2t_transformer_layer.py

@@ -88,11 +88,11 @@ class S2TTransformerEncoderLayer(nn.Module):
         embed_dim = args.encoder_embed_dim
         ffn_dim = args.encoder_ffn_embed_dim
         dropout = args.dropout
-        self.embed_dim = args.encoder_embed_dim
+        self.embed_dim = embed_dim
         self.quant_noise = getattr(args, 'quant_noise_pq', 0)
         self.quant_noise_block_size = getattr(args, 'quant_noise_pq_block_size', 8) or 8

         self.attn_type = getattr(args, "encoder_attention_type", "selfattn")
-        self.self_attn = self.build_self_attention(self.embed_dim, args)
+        self.self_attn = self.build_self_attention(args, self.embed_dim)
         self.self_attn_layer_norm = LayerNorm(self.embed_dim)
         self.dropout_module = FairseqDropout(
             dropout, module_name=self.__class__.__name__

@@ -138,7 +138,7 @@ class S2TTransformerEncoderLayer(nn.Module):
         )
         self.ffn_norm = LayerNorm(self.embed_dim)

-    def build_self_attention(self, embed_dim, args):
+    def build_self_attention(self, args, embed_dim):
         attention_heads = args.encoder_attention_heads
         dropout = args.dropout

@@ -147,7 +147,8 @@ class S2TTransformerEncoderLayer(nn.Module):
         elif self.attn_type == "rel_selfattn":
             attn_func = RelPositionMultiheadAttention
         elif self.attn_type == "relative":
-            max_relative_length = getattr(args, "max_encoder_relative_length", -1)
+            max_relative_length = max(getattr(args, "max_encoder_relative_length", -1),
+                                      getattr(args, "max_relative_length", -1))
             if max_relative_length != -1:
                 return RelativeMultiheadAttention(
                     embed_dim,

@@ -188,13 +189,13 @@ class S2TTransformerEncoderLayer(nn.Module):
             )
         else:
             attn_func = MultiheadAttention
-            print("The attention type %s is not supported!" % self.attn_type)
+            print("The encoder attention type %s is not supported!" % self.attn_type)
             exit(1)

         return attn_func(
             embed_dim,
-            args.encoder_attention_heads,
-            dropout=args.attention_dropout,
+            attention_heads,
+            dropout=dropout,
             self_attention=True,
             q_noise=self.quant_noise,
             qn_block_size=self.quant_noise_block_size,
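build_self_attention now takes (args, embed_dim), and the relative variant falls
back to the generic --max-relative-length flag when --max-encoder-relative-length
is unset (-1 means unset, so max() prefers whichever is given). A small check of
that fallback with hypothetical args:

    from argparse import Namespace

    def resolve_max_relative_length(args):
        # mirrors the fallback above
        return max(getattr(args, "max_encoder_relative_length", -1),
                   getattr(args, "max_relative_length", -1))

    resolve_max_relative_length(Namespace(max_encoder_relative_length=8))  # 8
    resolve_max_relative_length(Namespace(max_relative_length=16))         # 16
    resolve_max_relative_length(Namespace())  # -1 -> relative attention disabled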