xuchen / Fairseq-S2T · Commits · c7242ff4

Commit c7242ff4 authored Sep 07, 2022 by xuchen
implement the w2v2-transformer arch
parent afa5095d
Showing 6 changed files with 480 additions and 8 deletions
egs/mustc/asr/conf/mixup.yaml                           +8   -5
egs/mustc/st/conf/w2v2.yaml                             +43  -0
fairseq/models/speech_to_text/__init__.py               +1   -0
fairseq/models/speech_to_text/s2t_w2v2_transformer.py   +341 -0
fairseq/models/wav2vec/wav2vec2.py                      +85  -2
fairseq/modules/speech_to_text/subsampling.py           +2   -1
egs/mustc/asr/conf/mixup.yaml    View file @ c7242ff4

-inter_mixup: True
-inter_mixup_layer: -1
-inter_mixup_prob: 1.0
-inter_mixup_ratio: 0.2
-inter_mixup_beta: 0.2
+inter-mixup: True
+inter-mixup-layer: -1
+inter-mixup-prob: 1.0
+inter-mixup-ratio: 1.0
+inter-mixup-beta: 0.5
+inter-mixup-keep-org: True
+ctc-mixup-consistent-weight: 1
+mixup-consistent-weight: 1
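For context, these options drive mixup between examples inside the encoder ("inter-mixup"). A minimal sketch of the idea these option names suggest, assuming the standard mixup formulation; the repository's actual implementation lives in the S2T encoder and is more involved:

import torch

def inter_mixup(x, beta=0.5, ratio=1.0):
    # x: batch x time x dim hidden states at the chosen encoder layer.
    # Draw the mixing coefficient from Beta(beta, beta), as in standard
    # mixup, and mix a `ratio`-sized slice of the batch with a shuffled copy.
    lam = torch.distributions.Beta(beta, beta).sample()
    perm = torch.randperm(x.size(0))
    n = max(1, int(x.size(0) * ratio))
    mixed = x.clone()
    mixed[:n] = lam * x[:n] + (1 - lam) * x[perm[:n]]
    return mixed, lam, perm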
egs/mustc/st/conf/w2v2.yaml    new file (mode 100644) @ c7242ff4

arch: s2t_w2v2_transformer
share-decoder-input-output-embed: True
optimizer: adam
clip-norm: 10.0
lr-scheduler: inverse_sqrt
warmup-init-lr: 1e-7
warmup-updates: 10000
lr: 2e-3
adam_betas: (0.9,0.98)
criterion: label_smoothed_cross_entropy_with_ctc
label_smoothing: 0.1

encoder-embed-norm: True
encoder-no-scale-embedding: True

subsampling-type: conv1d
subsampling-layers: 2
subsampling-filter: 1024
subsampling-kernel: 5
subsampling-stride: 2
subsampling-norm: none
subsampling-activation: glu

dropout: 0.1
activation-fn: relu
encoder-embed-dim: 256
encoder-ffn-embed-dim: 2048
encoder-layers: 12
decoder-layers: 6
encoder-attention-heads: 4
decoder-embed-dim: 256
decoder-ffn-embed-dim: 2048
decoder-attention-heads: 4
attention-dropout: 0.1
activation-dropout: 0.1

w2v2-model-path: /home/xuchen/st/models/w2v2/wav2vec_small.pt
freeze-w2v: False

#load-pretrained-encoder-from:
#load-pretrained-decoder-from:
\ No newline at end of file
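The hyphenated keys mirror fairseq command-line flags: the run scripts under egs/ presumably expand the YAML into `--key value` pairs, and argparse then stores them with underscores, which is why the model code below reads `args.w2v2_model_path`. A sketch of that expansion (`yaml_to_cli` is a hypothetical helper, not part of the repo):

import yaml

def yaml_to_cli(path):
    # Hypothetical helper: expand a config file into fairseq-train flags.
    with open(path) as f:
        cfg = yaml.safe_load(f)
    flags = []
    for key, value in cfg.items():
        if value is True:            # store_true flags take no value
            flags.append(f"--{key}")
        elif value is False:         # omit disabled store_true flags
            continue
        else:
            flags.extend([f"--{key}", str(value)])
    return flags

# yaml_to_cli("egs/mustc/st/conf/w2v2.yaml")
# -> ['--arch', 's2t_w2v2_transformer', '--share-decoder-input-output-embed', ...]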
fairseq/models/speech_to_text/__init__.py    View file @ c7242ff4

@@ -12,3 +12,4 @@ from .s2t_dual import *  # noqa
 from .s2t_ctc import *
 from .s2t_multibranch import *
 from .s2t_dynamic_transformer import *
+from .s2t_w2v2_transformer import *
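The star import matters beyond namespace convenience: importing the module executes the @register_model and @register_model_architecture decorators in the new file, which is what makes --arch s2t_w2v2_transformer resolvable. A quick sanity check, assuming a standard fairseq layout for this fork:

from fairseq.models import ARCH_MODEL_REGISTRY

import fairseq.models.speech_to_text  # noqa: F401 -- the import runs the decorators

assert "s2t_w2v2_transformer" in ARCH_MODEL_REGISTRY
assert "s2t_w2v2_transformer_s" in ARCH_MODEL_REGISTRY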
fairseq/models/speech_to_text/s2t_w2v2_transformer.py    new file (mode 100644) @ c7242ff4

import logging
import math
import os
from typing import Dict, List, Optional, Tuple

import numpy as np
import torch
import torch.nn as nn
from torch import Tensor

from fairseq import checkpoint_utils, tasks, utils
from fairseq.data.data_utils import lengths_to_padding_mask
from fairseq.models import (
    FairseqEncoder,
    FairseqEncoderDecoderModel,
    register_model,
    register_model_architecture,
)
from fairseq.models.speech_to_text import S2TTransformerEncoder, S2TTransformerModel
from fairseq.models.transformer import Embedding, TransformerDecoder
from fairseq.models.wav2vec import Wav2Vec2Model, Wav2VecCtc
from fairseq.modules import (
    FairseqDropout,
    LayerNorm,
    PositionalEmbedding,
    TransformerEncoderLayer,
    S2TTransformerEncoderLayer,
    LegacyRelPositionalEncoding,
    RelPositionalEncoding,
    DynamicLinearCombination,
)
from fairseq.modules.speech_to_text import Adapter, CTC, subsampling

logger = logging.getLogger(__name__)
@register_model("s2t_w2v2_transformer")
class S2TW2V2TransformerModel(S2TTransformerModel):
    """Adapted Transformer model (https://arxiv.org/abs/1706.03762) for
    speech-to-text tasks. The Transformer encoder/decoder remains the same.
    A wav2vec 2.0 front-end extracts features from the raw waveform, and a
    trainable input subsampler projects them into the encoder dimension
    while downsampling the sequence for computational efficiency."""

    def __init__(self, encoder, decoder):
        super().__init__(encoder, decoder)

    @staticmethod
    def add_args(parser):
        """Add model-specific arguments to the parser."""
        S2TTransformerModel.add_args(parser)
        parser.add_argument(
            "--w2v2-model-path",
            type=str,
            metavar="N",
            help="path/to/wav2vec/model, support hdfs",
        )
        parser.add_argument(
            "--freeze-w2v",
            action="store_true",
            help="if we want to freeze the w2v features",
        )
        parser.add_argument(
            "--use-asr-finetune-w2v",
            action="store_true",
            help="if we want to load wav2vec2.0 asr finetuned data",
        )

    @classmethod
    def build_encoder(cls, args, task=None, embed_tokens=None):
        encoder = S2TW2V2TransformerEncoder(args, task, embed_tokens)
        if getattr(args, "load_pretrained_encoder_from", None):
            encoder = checkpoint_utils.load_pretrained_component_from_model(
                component=encoder,
                checkpoint=args.load_pretrained_encoder_from,
                strict=False,
            )
            logger.info(
                f"loaded pretrained encoder from: "
                f"{args.load_pretrained_encoder_from}"
            )
        return encoder
class S2TW2V2TransformerEncoder(FairseqEncoder):
    """Speech-to-text Transformer encoder that consists of a wav2vec 2.0
    feature extractor, an input subsampler, and a Transformer encoder."""

    def __init__(self, args, task=None, embed_tokens=None):
        super().__init__(None)

        assert args.w2v2_model_path is not None
        self.w2v2_model_path = args.w2v2_model_path
        self.use_asr_finetune_w2v = args.use_asr_finetune_w2v

        ckpt = torch.load(self.w2v2_model_path)
        self.w2v_args = ckpt["args"]

        if not self.use_asr_finetune_w2v:
            # checkpoint from self-supervised pre-training only
            self.wav2vec_model = Wav2Vec2Model.build_model(ckpt["args"], task=None)
            self.wav2vec_model.load_state_dict(ckpt["model"])
        else:
            # wav2vec-ctc checkpoint fine-tuned on ASR
            ckpt["args"].data = args.data
            dict_path = os.path.join(
                ckpt["args"].data, f"dict.{ckpt['args'].labels}.txt"
            )
            if not os.path.exists(dict_path):
                os.system(
                    f"wget -P {ckpt['args'].data} "
                    f"https://dl.fbaipublicfiles.com/fairseq/wav2vec/dict.ltr.txt"
                )
            task = tasks.setup_task(ckpt["args"])
            model_finetuned = Wav2VecCtc.build_model(ckpt["args"], task=task)
            model_finetuned.load_state_dict(ckpt["model"])
            self.wav2vec_model = model_finetuned.w2v_encoder.w2v_model
            self.w2v_args = ckpt["args"].w2v_args["model"]

        self.freeze_w2v = args.freeze_w2v

        w2v_output_dim = self.w2v_args.encoder_embed_dim
        self.encoder = S2TTransformerEncoder(args, task, embed_tokens)
        # replace the filterbank-based subsampler with one sized for w2v features
        del self.encoder.subsample
        self.encoder.subsample = subsampling(args, in_dim=w2v_output_dim)
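The constructor assumes the legacy checkpoint layout: torch.load returns a dict whose "args" entry is an argparse-style namespace and whose "model" entry is the state dict. A quick inspection before pointing --w2v2-model-path at a file (path taken from w2v2.yaml above; note that newer fairseq checkpoints store a "cfg" dataclass instead, which this code does not handle):

import torch

ckpt = torch.load("/home/xuchen/st/models/w2v2/wav2vec_small.pt", map_location="cpu")
print(ckpt.keys())                      # expect 'args' and 'model' among them
print(ckpt["args"].encoder_embed_dim)   # becomes w2v_output_dim above
print(len(ckpt["model"]))               # tensor count fed to load_state_dict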
    def _get_w2v_feature(self, src_tokens, src_lengths):
        """
        :param src_tokens: b x frames
        :param src_lengths: b-dim length
        :return: w2v_feature: b x short_frames x feature-dim;
                 w2v_padding_mask: b x short_frames T/F tensor;
                 w2v_lengths: b-dim tensor
        """
        padding_mask = lengths_to_padding_mask(src_lengths)
        w2v_feature, padding_mask = self.wav2vec_model.extract_features(
            src_tokens, padding_mask
        )
        # recover the (downsampled) lengths from the returned padding mask
        output_length = (1 - padding_mask.int()).sum(dim=1)
        return w2v_feature, padding_mask, output_length

    def forward(self, src_tokens, src_lengths):
        # 1. extract wav2vec 2.0 features from the raw waveform
        if self.freeze_w2v:
            with torch.no_grad():
                w2v_feature, encoder_padding_mask, input_lengths = self._get_w2v_feature(
                    src_tokens, src_lengths
                )
        else:
            w2v_feature, encoder_padding_mask, input_lengths = self._get_w2v_feature(
                src_tokens, src_lengths
            )

        # 2. subsample the features and run the Transformer encoder
        return self.encoder.forward(w2v_feature, input_lengths)

    def reorder_encoder_out(self, encoder_out, new_order):
        return self.encoder.reorder_encoder_out(encoder_out, new_order)
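lengths_to_padding_mask marks padded positions with True, so (1 - mask.int()).sum(dim=1) recovers per-example lengths after the wav2vec front-end has shortened the sequence. A self-contained illustration, with a minimal re-implementation of the helper for clarity:

import torch

def lengths_to_padding_mask(lens):
    # minimal re-implementation for illustration; True marks padding
    max_len = int(lens.max())
    return torch.arange(max_len).unsqueeze(0) >= lens.unsqueeze(1)

lens = torch.tensor([4, 2])
mask = lengths_to_padding_mask(lens)
# tensor([[False, False, False, False],
#         [False, False,  True,  True]])
assert ((1 - mask.int()).sum(dim=1) == lens).all()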
@register_model_architecture(
    model_name="s2t_w2v2_transformer", arch_name="s2t_w2v2_transformer"
)
def base_architecture(args):
    # Convolutional subsampler
    args.subsampling_type = getattr(args, "subsampling_type", "conv1d")
    args.subsampling_layers = getattr(args, "subsampling_layers", 2)
    args.subsampling_filter = getattr(args, "subsampling_filter", 1024)
    args.subsampling_kernel = getattr(args, "subsampling_kernel", 5)
    args.subsampling_stride = getattr(args, "subsampling_stride", 2)
    args.subsampling_norm = getattr(args, "subsampling_norm", "none")
    args.subsampling_activation = getattr(args, "subsampling_activation", "glu")
    # Transformer
    args.encoder_embed_dim = getattr(args, "encoder_embed_dim", 512)
    args.encoder_ffn_embed_dim = getattr(args, "encoder_ffn_embed_dim", 2048)
    args.encoder_layers = getattr(args, "encoder_layers", 12)
    args.encoder_attention_type = getattr(args, "encoder_attention_type", "selfattn")
    args.encoder_attention_heads = getattr(args, "encoder_attention_heads", 8)
    args.encoder_normalize_before = getattr(args, "encoder_normalize_before", True)
    args.decoder_embed_dim = getattr(args, "decoder_embed_dim", args.encoder_embed_dim)
    args.decoder_ffn_embed_dim = getattr(args, "decoder_ffn_embed_dim", args.encoder_ffn_embed_dim)
    args.decoder_layers = getattr(args, "decoder_layers", 6)
    args.decoder_attention_type = getattr(args, "decoder_attention_type", "selfattn")
    args.decoder_attention_heads = getattr(args, "decoder_attention_heads", 8)
    args.decoder_normalize_before = getattr(args, "decoder_normalize_before", True)
    args.decoder_learned_pos = getattr(args, "decoder_learned_pos", False)
    args.dropout = getattr(args, "dropout", 0.1)
    args.attention_dropout = getattr(args, "attention_dropout", args.dropout)
    args.activation_dropout = getattr(args, "activation_dropout", args.dropout)
    args.activation_fn = getattr(args, "activation_fn", "relu")
    args.adaptive_softmax_cutoff = getattr(args, "adaptive_softmax_cutoff", None)
    args.adaptive_softmax_dropout = getattr(args, "adaptive_softmax_dropout", 0)
    args.tie_adaptive_weights = getattr(args, "tie_adaptive_weights", False)
    args.tie_adaptive_proj = getattr(args, "tie_adaptive_proj", False)
    args.adaptive_softmax_factor = getattr(args, "adaptive_softmax_factor", 4)
    args.share_decoder_input_output_embed = getattr(args, "share_decoder_input_output_embed", False)
    args.share_all_embeddings = getattr(args, "share_all_embeddings", False)
    args.no_token_positional_embeddings = getattr(args, "no_token_positional_embeddings", False)
    args.adaptive_input = getattr(args, "adaptive_input", False)
    args.encoder_layerdrop = getattr(args, "encoder_layerdrop", 0.0)
    args.decoder_layerdrop = getattr(args, "decoder_layerdrop", 0.0)
    args.decoder_output_dim = getattr(args, "decoder_output_dim", args.decoder_embed_dim)
    args.decoder_input_dim = getattr(args, "decoder_input_dim", args.decoder_embed_dim)
    args.no_scale_embedding = getattr(args, "no_scale_embedding", False)
    args.encoder_no_scale_embedding = getattr(args, "encoder_no_scale_embedding", False)
    args.quant_noise_pq = getattr(args, "quant_noise_pq", 0)
    args.encoder_embed_linear = getattr(args, "encoder_embed_linear", False)
    args.encoder_embed_norm = getattr(args, "encoder_embed_norm", False)
    # CTC
    args.ctc_layer = getattr(args, "ctc_layer", 0)
    args.share_ctc_and_embed = getattr(args, "share_ctc_and_embed", False)
    # Conformer
    args.encoder_activation_fn = getattr(args, "encoder_activation_fn", "relu")
    args.macaron_style = getattr(args, "macaron_style", False)
    args.use_cnn_module = getattr(args, "use_cnn_module", False)
    args.cnn_module_kernel = getattr(args, "cnn_module_kernel", 31)
    args.cnn_module_norm = getattr(args, "cnn_module_norm", "batch_norm")
    # settings for DLCL
    args.use_enc_dlcl = getattr(args, "use_enc_dlcl", False)
    args.use_dec_dlcl = getattr(args, "use_dec_dlcl", False)
    args.init_value = getattr(args, "init_value", "avg")
    args.weight_type = getattr(args, "weight_type", "scalar")
    args.encoder_learnable = getattr(args, "encoder_learnable", True)
    args.decoder_learnable = getattr(args, "decoder_learnable", True)
    args.normalize_embed = getattr(args, "normalize_embed", False)
    args.history_dropout = getattr(args, "history_dropout", 0.0)
    args.history_window_size = getattr(args, "history_window_size", -1)
    # Relative position encoding
    args.max_encoder_relative_length = getattr(args, "max_encoder_relative_length", -1)
    args.max_decoder_relative_length = getattr(args, "max_decoder_relative_length", -1)
    args.k_only = getattr(args, "k_only", True)
    # local modeling
    args.hard_mask_window = getattr(args, "hard_mask_window", 0)
    args.gauss_mask_sigma = getattr(args, "gauss_mask_sigma", 0)
    args.init_mask_weight = getattr(args, "init_mask_weight", 0)
    # interleaved CTC
    args.interleaved_ctc_layers = getattr(args, "interleaved_ctc_layers", None)
    args.interleaved_ctc_temperature = getattr(args, "interleaved_ctc_temperature", 1)
    args.interleaved_ctc_drop_prob = getattr(args, "interleaved_ctc_drop_prob", 0)
    # Semantics-augmented Encoding (sae)
    args.sae_adapter = getattr(args, "sae_adapter", "none")
    args.share_sae_and_ctc = getattr(args, "share_sae_and_ctc", False)
    args.sae_embed_norm = getattr(args, "sae_embed_norm", False)
    args.sae_out_norm = getattr(args, "sae_out_norm", False)
    args.sae_drop_prob = getattr(args, "sae_drop_prob", 0)
    args.sae_distribution_cutoff = getattr(args, "sae_distribution_cutoff", None)
    args.sae_distribution_hard = getattr(args, "sae_distribution_hard", False)
    args.sae_gumbel = getattr(args, "sae_gumbel", False)
    # mixup
    args.inter_mixup = getattr(args, "inter_mixup", False)
    args.inter_mixup_layer = getattr(args, "inter_mixup_layer", "-1")
    args.inter_mixup_beta = getattr(args, "inter_mixup_beta", 0.5)
    args.inter_mixup_prob = getattr(args, "inter_mixup_prob", 1)
    args.inter_mixup_ratio = getattr(args, "inter_mixup_ratio", 0.3)
    args.inter_mixup_keep_org = getattr(args, "inter_mixup_keep_org", False)
    # condensation
    args.condensation_metric = getattr(args, "condensation_metric", "ratio")
    args.condensation_mode = getattr(args, "condensation_mode", "create")
    args.condensation_layers = getattr(args, "condensation_layers", None)
    args.condensation_threshold = getattr(args, "condensation_threshold", "1.0")
    args.condensation_ratio = getattr(args, "condensation_ratio", "0.0")
    # wav2vec 2.0 feature extractor
    args.w2v2_model_path = getattr(args, "w2v2_model_path", "./wav2vec_small.pt")
    args.freeze_w2v = getattr(args, "freeze_w2v", False)  # store_true flag, default False
    args.use_asr_finetune_w2v = getattr(args, "use_asr_finetune_w2v", False)
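Every default here uses the getattr pattern, so values already present on args (from the command line or a YAML config) win over the architecture defaults:

from argparse import Namespace

args = Namespace(encoder_embed_dim=256)  # e.g. set via w2v2.yaml above
args.encoder_embed_dim = getattr(args, "encoder_embed_dim", 512)
args.encoder_layers = getattr(args, "encoder_layers", 12)
assert args.encoder_embed_dim == 256  # explicit setting preserved
assert args.encoder_layers == 12      # missing field filled with the default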
@register_model_architecture("s2t_w2v2_transformer", "s2t_w2v2_transformer_s")
def s2t_w2v2_transformer_s(args):
    args.encoder_embed_dim = getattr(args, "encoder_embed_dim", 256)
    args.encoder_ffn_embed_dim = getattr(args, "encoder_ffn_embed_dim", 256 * 8)
    args.encoder_attention_heads = getattr(args, "encoder_attention_heads", 4)
    args.decoder_attention_heads = getattr(args, "decoder_attention_heads", 4)
    args.dropout = getattr(args, "dropout", 0.1)
    base_architecture(args)


@register_model_architecture("s2t_w2v2_transformer", "s2t_w2v2_transformer_s_relative")
def s2t_w2v2_transformer_s_relative(args):
    args.max_encoder_relative_length = 100
    args.max_decoder_relative_length = 20
    args.k_only = True
    s2t_w2v2_transformer_s(args)


@register_model_architecture("s2t_w2v2_transformer", "s2t_w2v2_transformer_xs")
def s2t_w2v2_transformer_xs(args):
    args.encoder_layers = getattr(args, "encoder_layers", 6)
    args.decoder_layers = getattr(args, "decoder_layers", 3)
    args.encoder_ffn_embed_dim = getattr(args, "encoder_ffn_embed_dim", 256 * 4)
    args.dropout = getattr(args, "dropout", 0.3)
    s2t_w2v2_transformer_s(args)


@register_model_architecture("s2t_w2v2_transformer", "s2t_w2v2_transformer_sp")
def s2t_w2v2_transformer_sp(args):
    args.encoder_layers = getattr(args, "encoder_layers", 16)
    s2t_w2v2_transformer_s(args)


@register_model_architecture("s2t_w2v2_transformer", "s2t_w2v2_transformer_m")
def s2t_w2v2_transformer_m(args):
    args.encoder_embed_dim = getattr(args, "encoder_embed_dim", 512)
    args.encoder_ffn_embed_dim = getattr(args, "encoder_ffn_embed_dim", 512 * 4)
    args.encoder_attention_heads = getattr(args, "encoder_attention_heads", 8)
    args.decoder_attention_heads = getattr(args, "decoder_attention_heads", 8)
    args.dropout = getattr(args, "dropout", 0.15)
    base_architecture(args)


@register_model_architecture("s2t_w2v2_transformer", "s2t_w2v2_transformer_mp")
def s2t_w2v2_transformer_mp(args):
    args.encoder_layers = getattr(args, "encoder_layers", 16)
    s2t_w2v2_transformer_m(args)


@register_model_architecture("s2t_w2v2_transformer", "s2t_w2v2_transformer_l")
def s2t_w2v2_transformer_l(args):
    args.encoder_embed_dim = getattr(args, "encoder_embed_dim", 1024)
    args.encoder_ffn_embed_dim = getattr(args, "encoder_ffn_embed_dim", 1024 * 4)
    args.encoder_attention_heads = getattr(args, "encoder_attention_heads", 16)
    args.decoder_attention_heads = getattr(args, "decoder_attention_heads", 16)
    args.dropout = getattr(args, "dropout", 0.2)
    base_architecture(args)


@register_model_architecture("s2t_w2v2_transformer", "s2t_w2v2_transformer_lp")
def s2t_w2v2_transformer_lp(args):
    args.encoder_layers = getattr(args, "encoder_layers", 16)
    s2t_w2v2_transformer_l(args)
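Each named variant sets only the fields it owns and then defers to base_architecture, so the cascade fills in everything else. A sketch, assuming the module above is importable:

from argparse import Namespace

args = Namespace()
s2t_w2v2_transformer_s(args)
assert args.encoder_embed_dim == 256       # set by the small variant
assert args.encoder_ffn_embed_dim == 2048  # 256 * 8
assert args.encoder_layers == 12           # inherited from base_architecture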
fairseq/models/wav2vec/wav2vec2.py    View file @ c7242ff4

@@ -14,7 +14,7 @@ import torch.nn.functional as F
 from fairseq import utils
 from fairseq.data.data_utils import compute_mask_indices
 from fairseq.dataclass import ChoiceEnum, FairseqDataclass
-from fairseq.models import BaseFairseqModel, register_model
+from fairseq.models import BaseFairseqModel, register_model, register_model_architecture
 from fairseq.modules import (
     Fp32GroupNorm,
     Fp32LayerNorm,
...

@@ -228,6 +228,18 @@ class Wav2Vec2Model(BaseFairseqModel):
         feature_enc_layers = eval(cfg.conv_feature_layers)
         self.embed = feature_enc_layers[-1][0]
+        # cfg.extractor_mode = "default"
+        # cfg.mask_other = 0
+        # cfg.mask_length = 10
+        # cfg.mask_channel_prob = 0
+        # cfg.mask_channel_selection = "static"
+        # cfg.mask_channel_other = 0
+        # cfg.mask_channel_length = 10
+        # cfg.mask_channel_min_space = 1
+        # cfg.latent_dim = 0
+        # cfg.layer_norm_first = False
+        # cfg.target_glu = False
+
         self.feature_extractor = ConvFeatureExtractionModel(
             conv_layers=feature_enc_layers,
             dropout=0.0,
...

@@ -328,6 +340,7 @@ class Wav2Vec2Model(BaseFairseqModel):
     def build_model(cls, cfg: Wav2Vec2Config, task=None):
         """Build a new model instance."""
+        base_architecture(cfg)
         return cls(cfg)

     def apply_mask(self, x, padding_mask):
...

@@ -763,7 +776,7 @@ class TransformerEncoder(nn.Module):
         x_conv = self.pos_conv(x.transpose(1, 2))
         x_conv = x_conv.transpose(1, 2)
-        x += x_conv
+        x = x + x_conv

         if not self.layer_norm_first:
             x = self.layer_norm(x)
...
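The `x += x_conv` to `x = x + x_conv` change swaps an in-place addition for an out-of-place one. A plausible motivation (my reading, not stated in the commit) is autograd safety: in-place writes can clobber tensors that backward still needs, as in this minimal reproduction:

import torch

x = torch.randn(4, requires_grad=True)
y = torch.relu(x)      # relu's backward re-uses its output...
y += 1                 # ...so this in-place update invalidates it
try:
    y.sum().backward()
except RuntimeError as e:
    print(e)           # "...modified by an inplace operation..."

# the out-of-place form is always safe:
x2 = torch.randn(4, requires_grad=True)
z = torch.relu(x2)
z = z + 1
z.sum().backward()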
@@ -898,3 +911,72 @@ class TransformerSentenceEncoderLayer(nn.Module):
         x = self.final_layer_norm(x)

         return x, attn
+
+
+def base_architecture(args):
+    args.extractor_mode = getattr(args, "extractor_mode", "default")
+    args.encoder_layers = getattr(args, "encoder_layers", 12)
+    args.encoder_embed_dim = getattr(args, "encoder_embed_dim", 768)
+    args.encoder_ffn_embed_dim = getattr(args, "encoder_ffn_embed_dim", 3072)
+    args.encoder_attention_heads = getattr(args, "encoder_attention_heads", 12)
+    args.activation_fn = getattr(args, "activation_fn", "gelu")
+    args.dropout = getattr(args, "dropout", 0.1)
+    args.attention_dropout = getattr(args, "attention_dropout", 0.1)
+    args.activation_dropout = getattr(args, "activation_dropout", 0.0)
+    args.final_dim = getattr(args, "final_dim", 0)
+    args.layer_norm_first = getattr(args, "layer_norm_first", False)
+    args.encoder_layerdrop = getattr(args, "encoder_layerdrop", 0.0)
+
+    conv_feature_layers = "[(512, 10, 5)]"
+    conv_feature_layers += " + [(512, 8, 4)]"
+    conv_feature_layers += " + [(512, 4, 2)] * 3"
+    conv_feature_layers += " + [(512, 1, 1)]"
+    args.conv_feature_layers = getattr(args, "conv_feature_layers", conv_feature_layers)
+
+    args.logit_temp = getattr(args, "logit_temp", 0.1)
+    args.quantize_targets = getattr(args, "quantize_targets", False)
+    args.quantize_input = getattr(args, "quantize_input", False)
+    args.same_quantizer = getattr(args, "same_quantizer", False)
+    args.feature_grad_mult = getattr(args, "feature_grad_mult", 1.0)
+    args.latent_vars = getattr(args, "latent_vars", 320)
+    args.latent_groups = getattr(args, "latent_groups", 2)
+    args.latent_dim = getattr(args, "latent_dim", 0)
+    args.mask_length = getattr(args, "mask_length", 10)
+    args.mask_prob = getattr(args, "mask_prob", 0.65)
+    args.mask_selection = getattr(args, "mask_selection", "static")
+    args.mask_other = getattr(args, "mask_other", 0)
+    args.no_mask_overlap = getattr(args, "no_mask_overlap", False)
+    args.mask_min_space = getattr(args, "mask_min_space", 1)
+    args.mask_channel_length = getattr(args, "mask_channel_length", 10)
+    args.mask_channel_prob = getattr(args, "mask_channel_prob", 0)
+    args.mask_channel_selection = getattr(args, "mask_channel_selection", "static")
+    args.mask_channel_other = getattr(args, "mask_channel_other", 0)
+    args.no_mask_channel_overlap = getattr(args, "no_mask_channel_overlap", False)
+    args.mask_channel_min_space = getattr(args, "mask_channel_min_space", 1)
+    args.dropout_input = getattr(args, "dropout_input", 0)
+    args.dropout_features = getattr(args, "dropout_features", 0)
+    args.num_negatives = getattr(args, "num_negatives", 100)
+    args.negatives_from_everywhere = getattr(args, "negatives_from_everywhere", False)
+    args.cross_sample_negatives = getattr(args, "cross_sample_negatives", 0)
+    args.codebook_negatives = getattr(args, "codebook_negatives", 0)
+    args.conv_pos = getattr(args, "conv_pos", 128)
+    args.conv_pos_groups = getattr(args, "conv_pos_groups", 16)
+    args.latent_temp = getattr(args, "latent_temp", "(2,0.5,0.999995)")
+    args.target_glu = getattr(args, "target_glu", False)
+    args.conv_bias = getattr(args, "conv_bias", False)
\ No newline at end of file
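The model's constructor passes this string through eval (`feature_enc_layers = eval(cfg.conv_feature_layers)` in the hunk above), yielding a list of (dim, kernel, stride) tuples for the convolutional feature extractor. Expanding the default shows the overall waveform downsampling factor:

import math

s = "[(512, 10, 5)] + [(512, 8, 4)] + [(512, 4, 2)] * 3 + [(512, 1, 1)]"
layers = eval(s)
# [(512, 10, 5), (512, 8, 4), (512, 4, 2), (512, 4, 2), (512, 4, 2), (512, 1, 1)]
factor = math.prod(stride for _, _, stride in layers)
print(factor)  # 160 raw samples per output frame for this configuration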
fairseq/modules/speech_to_text/subsampling.py    View file @ c7242ff4

@@ -228,9 +228,10 @@ class Conv2dSubsampling(nn.Module):
         return x, x_len


-def subsampling(args, out_dim=None):
+def subsampling(args, in_dim=None, out_dim=None):
     subsampling_type = getattr(args, "subsampling_type", "conv1d")
     layers = getattr(args, "subsampling_layers", 2)
-    in_dim = args.input_feat_per_channel * args.input_channels
+    if in_dim is None:
+        in_dim = args.input_feat_per_channel * args.input_channels
     filters = [getattr(args, "subsampling_filter")] + [args.encoder_embed_dim if out_dim is None else out_dim]
     kernel_size = getattr(args, "subsampling_kernel", 5)
...
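The new in_dim parameter lets callers override the filterbank-derived input width, which is exactly what S2TW2V2TransformerEncoder.__init__ above does with the wav2vec 2.0 feature dimension. An illustrative call, with a Namespace standing in for the parsed training config (field values as in w2v2.yaml; 80 filterbank features per channel is the usual fairseq S2T default):

from argparse import Namespace

args = Namespace(
    subsampling_type="conv1d", subsampling_layers=2, subsampling_filter=1024,
    subsampling_kernel=5, subsampling_stride=2, subsampling_norm="none",
    subsampling_activation="glu", encoder_embed_dim=256,
    input_feat_per_channel=80, input_channels=1,
)
sub_default = subsampling(args)          # in_dim falls back to 80 * 1
sub_w2v = subsampling(args, in_dim=768)  # wav2vec 2.0 feature width instead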