Commit c7242ff4 by xuchen

implement the w2v2-transformer arch

parent afa5095d
inter_mixup: True
inter_mixup_layer: -1
inter_mixup_prob: 1.0
inter_mixup_ratio: 0.2
inter_mixup_beta: 0.2

inter-mixup: True
inter-mixup-layer: -1
inter-mixup-prob: 1.0
inter-mixup-ratio: 1.0
inter-mixup-beta: 0.5
inter-mixup-keep-org: True
ctc-mixup-consistent-weight: 1
mixup-consistent-weight: 1
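These two blocks configure inter-sequence mixup: with probability `inter-mixup-prob`, hidden states of paired samples in a batch are interpolated at encoder layer `inter-mixup-layer` (-1 presumably meaning the input to the first layer), with the coefficient drawn from Beta(`inter-mixup-beta`, `inter-mixup-beta`); `inter-mixup-keep-org` keeps the original samples alongside the mixed ones. A hypothetical sketch of the step these options likely control (the helper and its signature are illustrative, not the repo's actual code):

```python
import torch
import numpy as np

def inter_mixup(x, beta=0.5, ratio=1.0, prob=1.0, keep_org=True):
    """x: (B, T, C) encoder hidden states.

    Returns the (possibly enlarged) batch plus the permutation and
    coefficient needed to mix the targets the same way.
    """
    if np.random.rand() > prob:                # inter-mixup-prob
        return x, None, None
    coef = float(np.random.beta(beta, beta))   # inter-mixup-beta
    n_mix = max(1, int(x.size(0) * ratio))     # inter-mixup-ratio
    perm = torch.randperm(x.size(0), device=x.device)[:n_mix]
    mixed = coef * x[:n_mix] + (1.0 - coef) * x[perm]
    if keep_org:                               # append instead of replacing
        x = torch.cat([x, mixed], dim=0)
    else:
        x = torch.cat([mixed, x[n_mix:]], dim=0)
    return x, perm, coef
```

The consistency weights (`ctc-mixup-consistent-weight`, `mixup-consistent-weight`) then presumably add loss terms tying predictions on the mixed states to the correspondingly interpolated targets.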
arch: s2t_w2v2_transformer
share-decoder-input-output-embed: True
optimizer: adam
clip-norm: 10.0
lr-scheduler: inverse_sqrt
warmup-init-lr: 1e-7
warmup-updates: 10000
lr: 2e-3
adam_betas: (0.9,0.98)
criterion: label_smoothed_cross_entropy_with_ctc
label_smoothing: 0.1
encoder-embed-norm: True
encoder-no-scale-embedding: True
subsampling-type: conv1d
subsampling-layers: 2
subsampling-filter: 1024
subsampling-kernel: 5
subsampling-stride: 2
subsampling-norm: none
subsampling-activation: glu
dropout: 0.1
activation-fn: relu
encoder-embed-dim: 256
encoder-ffn-embed-dim: 2048
encoder-layers: 12
decoder-layers: 6
encoder-attention-heads: 4
decoder-embed-dim: 256
decoder-ffn-embed-dim: 2048
decoder-attention-heads: 4
attention-dropout: 0.1
activation-dropout: 0.1
w2v2-model-path: /home/xuchen/st/models/w2v2/wav2vec_small.pt
freeze-w2v: False
#load-pretrained-encoder-from:
#load-pretrained-decoder-from:
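The keys above follow fairseq's CLI flags (dashes and underscores are interchangeable once parsed), so the run scripts can expand the YAML into a `fairseq-train` command line. A minimal sketch of that expansion, assuming a plain key-to-flag mapping (`yaml_to_cli` is a hypothetical helper, not part of the repo):

```python
import yaml

def yaml_to_cli(path):
    """Expand a flat YAML config like the one above into argv-style flags."""
    with open(path) as f:
        cfg = yaml.safe_load(f)
    argv = []
    for key, value in cfg.items():
        flag = "--" + str(key).replace("_", "-")
        if isinstance(value, bool):
            if value:                 # store_true flags take no value
                argv.append(flag)
        else:
            argv.extend([flag, str(value)])
    return argv
```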
@@ -12,3 +12,4 @@ from .s2t_dual import * # noqa
from .s2t_ctc import *
from .s2t_multibranch import *
from .s2t_dynamic_transformer import *
from .s2t_w2v2_transformer import *
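The new module is imported for its side effect: registering the architecture with fairseq's model registry. A sketch of what `s2t_w2v2_transformer.py` presumably contains (the class name and base class are assumptions; only the registry name `s2t_w2v2_transformer` comes from the config above):

```python
from fairseq.models import register_model, register_model_architecture
from fairseq.models.speech_to_text import S2TTransformerModel  # assumed base class

@register_model("s2t_w2v2_transformer")
class S2TW2V2TransformerModel(S2TTransformerModel):
    """S2T Transformer whose acoustic front-end is a pretrained wav2vec 2.0."""

@register_model_architecture("s2t_w2v2_transformer", "s2t_w2v2_transformer")
def s2t_w2v2_transformer_base(args):
    # fill unset fields with defaults via getattr(args, ..., default),
    # in the same style as base_architecture further below
    pass
```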
@@ -14,7 +14,7 @@ import torch.nn.functional as F
from fairseq import utils
from fairseq.data.data_utils import compute_mask_indices
from fairseq.dataclass import ChoiceEnum, FairseqDataclass
-from fairseq.models import BaseFairseqModel, register_model
+from fairseq.models import BaseFairseqModel, register_model, register_model_architecture
from fairseq.modules import (
Fp32GroupNorm,
Fp32LayerNorm,
@@ -228,6 +228,18 @@ class Wav2Vec2Model(BaseFairseqModel):
feature_enc_layers = eval(cfg.conv_feature_layers)
self.embed = feature_enc_layers[-1][0]
# cfg.extractor_mode = "default"
# cfg.mask_other = 0
# cfg.mask_length = 10
# cfg.mask_channel_prob = 0
# cfg.mask_channel_selection = "static"
# cfg.mask_channel_other = 0
# cfg.mask_channel_length = 10
# cfg.mask_channel_min_space = 1
# cfg.latent_dim = 0
# cfg.layer_norm_first = False
# cfg.target_glu = False
self.feature_extractor = ConvFeatureExtractionModel(
conv_layers=feature_enc_layers,
dropout=0.0,
@@ -328,6 +340,7 @@ class Wav2Vec2Model(BaseFairseqModel):
def build_model(cls, cfg: Wav2Vec2Config, task=None):
"""Build a new model instance."""
base_architecture(cfg)
return cls(cfg)
def apply_mask(self, x, padding_mask):
@@ -763,7 +776,7 @@ class TransformerEncoder(nn.Module):
x_conv = self.pos_conv(x.transpose(1, 2))
x_conv = x_conv.transpose(1, 2)
-        x += x_conv
+        x = x + x_conv
if not self.layer_norm_first:
x = self.layer_norm(x)
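The `x += x_conv` to `x = x + x_conv` change swaps an in-place accumulation for an out-of-place add. A plausible motivation (hedged; the commit message doesn't say): in-place ops fail under autograd when the mutated tensor was saved for a backward pass. A minimal reproduction of that failure mode, separate from the model code:

```python
import torch

x = torch.ones(3, requires_grad=True)
y = x.sigmoid()      # autograd saves sigmoid's output to compute its backward
y += 1               # in-place update bumps y's version counter
y.sum().backward()   # RuntimeError: a variable needed for gradient computation
                     # has been modified by an inplace operation
```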
@@ -898,3 +911,72 @@ class TransformerSentenceEncoderLayer(nn.Module):
x = self.final_layer_norm(x)
return x, attn
def base_architecture(args):
args.extractor_mode = getattr(args, "extractor_mode", "default")
args.encoder_layers = getattr(args, "encoder_layers", 12)
args.encoder_embed_dim = getattr(args, "encoder_embed_dim", 768)
args.encoder_ffn_embed_dim = getattr(args, "encoder_ffn_embed_dim", 3072)
args.encoder_attention_heads = getattr(args, "encoder_attention_heads", 12)
args.activation_fn = getattr(args, "activation_fn", "gelu")
args.dropout = getattr(args, "dropout", 0.1)
args.attention_dropout = getattr(args, "attention_dropout", 0.1)
args.activation_dropout = getattr(args, "activation_dropout", 0.0)
args.final_dim = getattr(args, "final_dim", 0)
args.layer_norm_first = getattr(args, "layer_norm_first", False)
args.encoder_layerdrop = getattr(args, "encoder_layerdrop", 0.0)
conv_feature_layers = "[(512, 10, 5)]"
conv_feature_layers += " + [(512, 8, 4)]"
conv_feature_layers += " + [(512, 4, 2)] * 3"
conv_feature_layers += " + [(512, 1, 1)]"
args.conv_feature_layers = getattr(args, "conv_feature_layers", conv_feature_layers)
args.logit_temp = getattr(args, "logit_temp", 0.1)
args.quantize_targets = getattr(args, "quantize_targets", False)
args.quantize_input = getattr(args, "quantize_input", False)
args.same_quantizer = getattr(args, "same_quantizer", False)
args.feature_grad_mult = getattr(args, "feature_grad_mult", 1.0)
args.latent_vars = getattr(args, "latent_vars", 320)
args.latent_groups = getattr(args, "latent_groups", 2)
args.latent_dim = getattr(args, "latent_dim", 0)
args.mask_length = getattr(args, "mask_length", 10)
args.mask_prob = getattr(args, "mask_prob", 0.65)
args.mask_selection = getattr(args, "mask_selection", "static")
args.mask_other = getattr(args, "mask_other", 0)
args.no_mask_overlap = getattr(args, "no_mask_overlap", False)
args.mask_min_space = getattr(args, "mask_min_space", 1)
args.mask_channel_length = getattr(args, "mask_channel_length", 10)
args.mask_channel_prob = getattr(args, "mask_channel_prob", 0)
args.mask_channel_selection = getattr(args, "mask_channel_selection", "static")
args.mask_channel_other = getattr(args, "mask_channel_other", 0)
args.no_mask_channel_overlap = getattr(args, "no_mask_channel_overlap", False)
args.mask_channel_min_space = getattr(args, "mask_channel_min_space", 1)
args.dropout_input = getattr(args, "dropout_input", 0)
args.dropout_features = getattr(args, "dropout_features", 0)
args.num_negatives = getattr(args, "num_negatives", 100)
args.negatives_from_everywhere = getattr(args, "negatives_from_everywhere", False)
args.cross_sample_negatives = getattr(args, "cross_sample_negatives", 0)
args.codebook_negatives = getattr(args, "codebook_negatives", 0)
args.conv_pos = getattr(args, "conv_pos", 128)
args.conv_pos_groups = getattr(args, "conv_pos_groups", 16)
args.latent_temp = getattr(args, "latent_temp", "(2,0.5,0.999995)")
args.target_glu = getattr(args, "target_glu", False)
args.conv_bias = getattr(args, "conv_bias", False)
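`base_architecture` pins every unset field to the wav2vec 2.0 defaults via `getattr`. One consequence worth checking is the frame rate implied by the `conv_feature_layers` string built above; a quick computation of the overall stride and receptive field:

```python
# same layer list the conv_feature_layers string above evaluates to
layers = [(512, 10, 5)] + [(512, 8, 4)] + [(512, 4, 2)] * 3 + [(512, 1, 1)]

hop, rf = 1, 1
for _dim, kernel, stride in layers:
    rf += (kernel - 1) * hop   # receptive field grows by (k - 1) * current hop
    hop *= stride              # strides multiply down the stack

print(hop, rf)  # 160, 465: one frame per 160 samples (10 ms at 16 kHz),
                # each covering a 465-sample (~29 ms) window
```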
@@ -228,10 +228,11 @@ class Conv2dSubsampling(nn.Module):
return x, x_len
-def subsampling(args, out_dim=None):
+def subsampling(args, in_dim=None, out_dim=None):
subsampling_type = getattr(args, "subsampling_type", "conv1d")
layers = getattr(args, "subsampling_layers", 2)
-    in_dim = args.input_feat_per_channel * args.input_channels
+    if in_dim is None:
+        in_dim = args.input_feat_per_channel * args.input_channels
filters = [getattr(args, "subsampling_filter")] + [args.encoder_embed_dim if out_dim is None else out_dim]
kernel_size = getattr(args, "subsampling_kernel", 5)
stride = getattr(args, "subsampling_stride", 2)
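Threading `in_dim` through lets the new wav2vec2 encoder hand the subsampler its own feature dimension instead of the fbank-derived `input_feat_per_channel * input_channels`. A usage sketch (Namespace stands in for parsed fairseq args; the 80-dim fbank input and the 768-dim wav2vec_small feature size are assumptions):

```python
from argparse import Namespace

args = Namespace(
    subsampling_type="conv1d", subsampling_layers=2, subsampling_filter=1024,
    subsampling_kernel=5, subsampling_stride=2, subsampling_norm="none",
    subsampling_activation="glu", encoder_embed_dim=256,
    input_feat_per_channel=80, input_channels=1,
)

fbank_sub = subsampling(args)             # old path: in_dim = 80 * 1
w2v2_sub = subsampling(args, in_dim=768)  # new path: wav2vec 2.0 feature size
```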