Commit c7242ff4 by xuchen

implement the w2v2-transformer arch

parent afa5095d
inter_mixup: True inter-mixup: True
inter_mixup_layer: -1 inter-mixup-layer: -1
inter_mixup_prob: 1.0 inter-mixup-prob: 1.0
inter_mixup_ratio: 0.2 inter-mixup-ratio: 1.0
inter_mixup_beta: 0.2 inter-mixup-beta: 0.5
inter-mixup-keep-org: True
ctc-mixup-consistent-weight: 1
mixup-consistent-weight: 1
arch: s2t_w2v2_transformer
share-decoder-input-output-embed: True
optimizer: adam
clip-norm: 10.0
lr-scheduler: inverse_sqrt
warmup-init-lr: 1e-7
warmup-updates: 10000
lr: 2e-3
adam_betas: (0.9,0.98)
criterion: label_smoothed_cross_entropy_with_ctc
label_smoothing: 0.1
encoder-embed-norm: True
encoder-no-scale-embedding: True
subsampling-type: conv1d
subsampling-layers: 2
subsampling-filter: 1024
subsampling-kernel: 5
subsampling-stride: 2
subsampling-norm: none
subsampling-activation: glu
dropout: 0.1
activation-fn: relu
encoder-embed-dim: 256
encoder-ffn-embed-dim: 2048
encoder-layers: 12
decoder-layers: 6
encoder-attention-heads: 4
decoder-embed-dim: 256
decoder-ffn-embed-dim: 2048
decoder-attention-heads: 4
attention-dropout: 0.1
activation-dropout: 0.1
w2v2-model-path: /home/xuchen/st/models/w2v2/wav2vec_small.pt
freeze-w2v: False
#load-pretrained-encoder-from:
#load-pretrained-decoder-from:
\ No newline at end of file
...@@ -12,3 +12,4 @@ from .s2t_dual import * # noqa ...@@ -12,3 +12,4 @@ from .s2t_dual import * # noqa
from .s2t_ctc import * from .s2t_ctc import *
from .s2t_multibranch import * from .s2t_multibranch import *
from .s2t_dynamic_transformer import * from .s2t_dynamic_transformer import *
from .s2t_w2v2_transformer import *
import logging
import math
from typing import Dict, List, Optional, Tuple
import os
import numpy as np
import torch
import torch.nn as nn
from fairseq import checkpoint_utils, utils, tasks
from fairseq.data.data_utils import lengths_to_padding_mask
from fairseq.models import (
FairseqEncoder,
FairseqEncoderDecoderModel,
register_model,
register_model_architecture,
)
from fairseq.modules.speech_to_text import Adapter, CTC
from fairseq.models.transformer import Embedding, TransformerDecoder
from fairseq.models.speech_to_text import S2TTransformerModel, S2TTransformerEncoder
from fairseq.models.wav2vec import Wav2Vec2Model, Wav2VecCtc
from fairseq.modules import (
FairseqDropout,
LayerNorm,
PositionalEmbedding,
TransformerEncoderLayer,
S2TTransformerEncoderLayer,
LegacyRelPositionalEncoding,
RelPositionalEncoding,
S2TTransformerEncoderLayer,
DynamicLinearCombination,
)
from fairseq.modules.speech_to_text import (
subsampling
)
from torch import Tensor
logger = logging.getLogger(__name__)
@register_model("s2t_w2v2_transformer")
class S2TW2V2TransformerModel(S2TTransformerModel):
"""Adapted Transformer model (https://arxiv.org/abs/1706.03762) for
speech-to-text tasks. The Transformer encoder/decoder remains the same.
A trainable input subsampler is prepended to the Transformer encoder to
project inputs into the encoder dimension as well as downsample input
sequence for computational efficiency."""
def __init__(self, encoder, decoder):
super().__init__(encoder, decoder)
@staticmethod
def add_args(parser):
"""Add model-specific arguments to the parser."""
S2TTransformerModel.add_args(parser)
parser.add_argument("--w2v2-model-path", type=str, metavar="N",
help="path/to/wav2vec/model, support hdfs")
parser.add_argument("--freeze-w2v", action="store_true",
help="if we want to freeze the w2v features")
parser.add_argument("--use-asr-finetune-w2v", action="store_true",
help="if we want to load wav2vec2.0 asr finetuned data")
pass
@classmethod
def build_encoder(cls, args, task=None, embed_tokens=None):
encoder = S2TW2V2TransformerEncoder(args, task, embed_tokens)
if getattr(args, "load_pretrained_encoder_from", None):
encoder = checkpoint_utils.load_pretrained_component_from_model(
component=encoder, checkpoint=args.load_pretrained_encoder_from, strict=False
)
logger.info(
f"loaded pretrained encoder from: "
f"{args.load_pretrained_encoder_from}"
)
return encoder
class S2TW2V2TransformerEncoder(FairseqEncoder):
"""Speech-to-text Transformer encoder that consists of input subsampler and
Transformer encoder."""
def __init__(self, args, task=None, embed_tokens=None):
super().__init__(None)
assert args.w2v2_model_path is not None
self.w2v2_model_path = args.w2v2_model_path
self.use_asr_finetune_w2v = args.use_asr_finetune_w2v
ckpt = torch.load(self.w2v2_model_path)
self.w2v_args = ckpt["args"]
if not self.use_asr_finetune_w2v: # if use ssl-trained only
self.w2v_args = ckpt["args"]
self.wav2vec_model = Wav2Vec2Model.build_model(ckpt['args'], task=None)
self.wav2vec_model.load_state_dict(ckpt['model'])
else: # wav2vec-ctc model
ckpt["args"].data = args.data
if not os.path.exists(os.path.join(ckpt["args"].data, f"dict.{ckpt['args'].labels}.txt")):
os.system(f"wget -P {ckpt['args'].data} https://dl.fbaipublicfiles.com/fairseq/wav2vec/dict.ltr.txt")
task = tasks.setup_task(ckpt["args"])
model_finetuned = Wav2VecCtc.build_model(ckpt["args"], task=task)
model_finetuned.load_state_dict(ckpt['model'])
self.wav2vec_model = model_finetuned.w2v_encoder.w2v_model
self.w2v_args = ckpt["args"].w2v_args["model"]
self.freeze_w2v = args.freeze_w2v
# w2v_output_dim = 512
w2v_output_dim = self.w2v_args.encoder_embed_dim
self.encoder = S2TTransformerEncoder(args, task, embed_tokens)
del self.encoder.subsample
self.encoder.subsample = subsampling(args, in_dim=w2v_output_dim)
def _get_w2v_feature(self, src_tokens, src_lengths):
"""
:param src_tokens: b x frames
:param src_lengths: b-dim length
:return: w2v_feature: b x short_frames x feature-dim;
w2v_lengths: b-dim tensor
w2v_padding_mask: b x short_frames x feature-dim T/F tensor
"""
padding_mask = lengths_to_padding_mask(src_lengths)
# print("padding mask:", padding_mask.size())
# print(padding_mask)
# w2v_feature = self.wav2vec_model.feature_extractor(src_tokens).transpose(1,2)
w2v_feature, padding_mask = self.wav2vec_model.extract_features(src_tokens, padding_mask)
# print("after extraction, padding:", padding_mask)
output_length = (1 - padding_mask.int()).sum(dim=1)
# output_length = (torch.ones(padding_mask.size()) - padding_mask.int()).sum(dim=1)
return w2v_feature, padding_mask, output_length
def forward(self, src_tokens, src_lengths):
# 1. wav2vec
if self.freeze_w2v:
with torch.no_grad():
w2v_feature, encoder_padding_mask, input_lengths = self._get_w2v_feature(
src_tokens, src_lengths)
else:
w2v_feature, encoder_padding_mask, input_lengths = self._get_w2v_feature(
src_tokens, src_lengths)
return self.encoder.forward(w2v_feature, input_lengths)
def reorder_encoder_out(self, encoder_out, new_order):
return self.encoder.reorder_encoder_out(encoder_out, new_order)
@register_model_architecture(model_name="s2t_w2v2_transformer", arch_name="s2t_w2v2_transformer")
def base_architecture(args):
# Convolutional subsampler
args.subsampling_type = getattr(args, "subsampling_type", "conv1d")
args.subsampling_layers = getattr(args, "subsampling_layers", 2)
args.subsampling_filter = getattr(args, "subsampling_filter", 1024)
args.subsampling_kernel = getattr(args, "subsampling_kernel", 5)
args.subsampling_stride = getattr(args, "subsampling_stride", 2)
args.subsampling_norm = getattr(args, "subsampling_norm", "none")
args.subsampling_activation = getattr(args, "subsampling_activation", "glu")
# Transformer
args.encoder_embed_dim = getattr(args, "encoder_embed_dim", 512)
args.encoder_ffn_embed_dim = getattr(args, "encoder_ffn_embed_dim", 2048)
args.encoder_layers = getattr(args, "encoder_layers", 12)
args.encoder_attention_type = getattr(args, "encoder_attention_type", "selfattn")
args.encoder_attention_heads = getattr(args, "encoder_attention_heads", 8)
args.encoder_normalize_before = getattr(args, "encoder_normalize_before", True)
args.decoder_embed_dim = getattr(args, "decoder_embed_dim", args.encoder_embed_dim)
args.decoder_ffn_embed_dim = getattr(
args, "decoder_ffn_embed_dim", args.encoder_ffn_embed_dim
)
args.decoder_layers = getattr(args, "decoder_layers", 6)
args.decoder_attention_type = getattr(args, "decoder_attention_type", "selfattn")
args.decoder_attention_heads = getattr(args, "decoder_attention_heads", 8)
args.decoder_normalize_before = getattr(args, "decoder_normalize_before", True)
args.decoder_learned_pos = getattr(args, "decoder_learned_pos", False)
args.dropout = getattr(args, "dropout", 0.1)
args.attention_dropout = getattr(args, "attention_dropout", args.dropout)
args.activation_dropout = getattr(args, "activation_dropout", args.dropout)
args.activation_fn = getattr(args, "activation_fn", "relu")
args.adaptive_softmax_cutoff = getattr(args, "adaptive_softmax_cutoff", None)
args.adaptive_softmax_dropout = getattr(args, "adaptive_softmax_dropout", 0)
args.tie_adaptive_weights = getattr(args, "tie_adaptive_weights", False)
args.tie_adaptive_proj = getattr(args, "tie_adaptive_proj", False)
args.adaptive_softmax_factor = getattr(args, "adaptive_softmax_factor", 4)
args.share_decoder_input_output_embed = getattr(
args, "share_decoder_input_output_embed", False
)
args.share_all_embeddings = getattr(args, "share_all_embeddings", False)
args.no_token_positional_embeddings = getattr(
args, "no_token_positional_embeddings", False
)
args.adaptive_input = getattr(args, "adaptive_input", False)
args.encoder_layerdrop = getattr(args, "encoder_layerdrop", 0.0)
args.decoder_layerdrop = getattr(args, "decoder_layerdrop", 0.0)
args.decoder_output_dim = getattr(
args, "decoder_output_dim", args.decoder_embed_dim
)
args.decoder_input_dim = getattr(args, "decoder_input_dim", args.decoder_embed_dim)
args.no_scale_embedding = getattr(args, "no_scale_embedding", False)
args.encoder_no_scale_embedding = getattr(args, "encoder_no_scale_embedding", False)
args.quant_noise_pq = getattr(args, "quant_noise_pq", 0)
args.encoder_embed_linear = getattr(args, "encoder_embed_linear", False)
args.encoder_embed_norm = getattr(args, "encoder_embed_norm", False)
# CTC
args.ctc_layer = getattr(args, "ctc_layer", 0)
args.share_ctc_and_embed = getattr(args, "share_ctc_and_embed", False)
# Conformer
args.encoder_activation_fn = getattr(args, "encoder_activation_fn", "relu")
args.macaron_style = getattr(args, "macaron_style", False)
args.use_cnn_module = getattr(args, "use_cnn_module", False)
args.cnn_module_kernel = getattr(args, "cnn_module_kernel", 31)
args.cnn_module_norm = getattr(args, "cnn_module_norm", "batch_norm")
# settings for DLCL
args.use_enc_dlcl = getattr(args, "use_enc_dlcl", False)
args.use_dec_dlcl = getattr(args, "use_dec_dlcl", False)
args.init_value = getattr(args, 'init_value', 'avg')
args.weight_type = getattr(args, 'weight_type', 'scalar')
args.encoder_learnable = getattr(args, 'encoder_learnable', True)
args.decoder_learnable = getattr(args, 'decoder_learnable', True)
args.normalize_embed = getattr(args, 'normalize_embed', False)
args.history_dropout = getattr(args, 'history_dropout', 0.0)
args.history_window_size = getattr(args, 'history_window_size', -1)
# Relative position encoding
args.max_encoder_relative_length = getattr(args, 'max_encoder_relative_length', -1)
args.max_decoder_relative_length = getattr(args, 'max_decoder_relative_length', -1)
args.k_only = getattr(args, 'k_only', True)
# local modeling
args.hard_mask_window = getattr(args, 'hard_mask_window', 0)
args.gauss_mask_sigma = getattr(args, 'gauss_mask_sigma', 0)
args.init_mask_weight = getattr(args, 'init_mask_weight', 0)
# interleaved CTC
args.interleaved_ctc_layers = getattr(args, "interleaved_ctc_layers", None)
args.interleaved_ctc_temperature = getattr(args, "interleaved_ctc_temperature", 1)
args.interleaved_ctc_drop_prob = getattr(args, "interleaved_ctc_drop_prob", 0)
# Semantics-augmented Encoding (sae)
args.sae_adapter = getattr(args, "sae_adapter", "none")
args.share_sae_and_ctc = getattr(args, "share_sae_and_ctc", False)
args.sae_embed_norm = getattr(args, "sae_embed_norm", False)
args.sae_out_norm = getattr(args, "sae_out_norm", False)
args.sae_drop_prob = getattr(args, "sae_drop_prob", 0)
args.sae_distribution_cutoff = getattr(args, "sae_distribution_cutoff", None)
args.sae_distribution_hard = getattr(args, "sae_distribution_hard", False)
args.sae_gumbel = getattr(args, "sae_gumbel", False)
# mixup
args.inter_mixup = getattr(args, "inter_mixup", False)
args.inter_mixup_layer = getattr(args, "inter_mixup_layer", "-1")
args.inter_mixup_beta = getattr(args, "inter_mixup_beta", 0.5)
args.inter_mixup_prob = getattr(args, "inter_mixup_prob", 1)
args.inter_mixup_ratio = getattr(args, "inter_mixup_ratio", 0.3)
args.inter_mixup_keep_org = getattr(args, "inter_mixup_keep_org", False)
args.condensation_metric = getattr(args, "condensation_metric", "ratio")
args.condensation_mode = getattr(args, "condensation_mode", "create")
args.condensation_layers = getattr(args, "condensation_layers", None)
args.condensation_threshold = getattr(args, "condensation_threshold", "1.0")
args.condensation_ratio = getattr(args, "condensation_ratio", "0.0")
# Wav2vec2.0 feature-extractor
args.w2v2_model_path = getattr(args, "w2v2_model_path", "./wav2vec_small.pt")
args.freeze_w2v = getattr(args, "freeze_w2v", False) # default is false, 'store_true'
args.use_asr_finetune_w2v = getattr(args, "use_asr_finetune_w2v", False)
@register_model_architecture("s2t_w2v2_transformer", "s2t_w2v2_transformer_s")
def s2t_w2v2_transformer_s(args):
args.encoder_embed_dim = getattr(args, "encoder_embed_dim", 256)
args.encoder_ffn_embed_dim = getattr(args, "encoder_ffn_embed_dim", 256 * 8)
args.encoder_attention_heads = getattr(args, "encoder_attention_heads", 4)
args.decoder_attention_heads = getattr(args, "decoder_attention_heads", 4)
args.dropout = getattr(args, "dropout", 0.1)
base_architecture(args)
@register_model_architecture("s2t_w2v2_transformer", "s2t_w2v2_transformer_s_relative")
def s2t_w2v2_transformer_s_relative(args):
args.max_encoder_relative_length = 100
args.max_decoder_relative_length = 20
args.k_only = True
s2t_w2v2_transformer_s(args)
@register_model_architecture("s2t_w2v2_transformer", "s2t_w2v2_transformer_xs")
def s2t_w2v2_transformer_xs(args):
args.encoder_layers = getattr(args, "encoder_layers", 6)
args.decoder_layers = getattr(args, "decoder_layers", 3)
args.encoder_ffn_embed_dim = getattr(args, "encoder_ffn_embed_dim", 256 * 4)
args.dropout = getattr(args, "dropout", 0.3)
s2t_w2v2_transformer_s(args)
@register_model_architecture("s2t_w2v2_transformer", "s2t_w2v2_transformer_sp")
def s2t_w2v2_transformer_sp(args):
args.encoder_layers = getattr(args, "encoder_layers", 16)
s2t_w2v2_transformer_s(args)
@register_model_architecture("s2t_w2v2_transformer", "s2t_w2v2_transformer_m")
def s2t_w2v2_transformer_m(args):
args.encoder_embed_dim = getattr(args, "encoder_embed_dim", 512)
args.encoder_ffn_embed_dim = getattr(args, "encoder_ffn_embed_dim", 512 * 4)
args.encoder_attention_heads = getattr(args, "encoder_attention_heads", 8)
args.decoder_attention_heads = getattr(args, "decoder_attention_heads", 8)
args.dropout = getattr(args, "dropout", 0.15)
base_architecture(args)
@register_model_architecture("s2t_w2v2_transformer", "s2t_w2v2_transformer_mp")
def s2t_w2v2_transformer_mp(args):
args.encoder_layers = getattr(args, "encoder_layers", 16)
s2t_w2v2_transformer_m(args)
@register_model_architecture("s2t_w2v2_transformer", "s2t_w2v2_transformer_l")
def s2t_w2v2_transformer_l(args):
args.encoder_embed_dim = getattr(args, "encoder_embed_dim", 1024)
args.encoder_ffn_embed_dim = getattr(args, "encoder_ffn_embed_dim", 1024 * 4)
args.encoder_attention_heads = getattr(args, "encoder_attention_heads", 16)
args.decoder_attention_heads = getattr(args, "decoder_attention_heads", 16)
args.dropout = getattr(args, "dropout", 0.2)
base_architecture(args)
@register_model_architecture("s2t_w2v2_transformer", "s2t_w2v2_transformer_lp")
def s2t_w2v2_transformer_lp(args):
args.encoder_layers = getattr(args, "encoder_layers", 16)
s2t_w2v2_transformer_l(args)
...@@ -14,7 +14,7 @@ import torch.nn.functional as F ...@@ -14,7 +14,7 @@ import torch.nn.functional as F
from fairseq import utils from fairseq import utils
from fairseq.data.data_utils import compute_mask_indices from fairseq.data.data_utils import compute_mask_indices
from fairseq.dataclass import ChoiceEnum, FairseqDataclass from fairseq.dataclass import ChoiceEnum, FairseqDataclass
from fairseq.models import BaseFairseqModel, register_model from fairseq.models import BaseFairseqModel, register_model, register_model_architecture
from fairseq.modules import ( from fairseq.modules import (
Fp32GroupNorm, Fp32GroupNorm,
Fp32LayerNorm, Fp32LayerNorm,
...@@ -228,6 +228,18 @@ class Wav2Vec2Model(BaseFairseqModel): ...@@ -228,6 +228,18 @@ class Wav2Vec2Model(BaseFairseqModel):
feature_enc_layers = eval(cfg.conv_feature_layers) feature_enc_layers = eval(cfg.conv_feature_layers)
self.embed = feature_enc_layers[-1][0] self.embed = feature_enc_layers[-1][0]
# cfg.extractor_mode = "default"
# cfg.mask_other = 0
# cfg.mask_length = 10
# cfg.mask_channel_prob = 0
# cfg.mask_channel_selection = "static"
# cfg.mask_channel_other = 0
# cfg.mask_channel_length = 10
# cfg.mask_channel_min_space = 1
# cfg.latent_dim = 0
# cfg.layer_norm_first = False
# cfg.target_glu = False
self.feature_extractor = ConvFeatureExtractionModel( self.feature_extractor = ConvFeatureExtractionModel(
conv_layers=feature_enc_layers, conv_layers=feature_enc_layers,
dropout=0.0, dropout=0.0,
...@@ -328,6 +340,7 @@ class Wav2Vec2Model(BaseFairseqModel): ...@@ -328,6 +340,7 @@ class Wav2Vec2Model(BaseFairseqModel):
def build_model(cls, cfg: Wav2Vec2Config, task=None): def build_model(cls, cfg: Wav2Vec2Config, task=None):
"""Build a new model instance.""" """Build a new model instance."""
base_architecture(cfg)
return cls(cfg) return cls(cfg)
def apply_mask(self, x, padding_mask): def apply_mask(self, x, padding_mask):
...@@ -763,7 +776,7 @@ class TransformerEncoder(nn.Module): ...@@ -763,7 +776,7 @@ class TransformerEncoder(nn.Module):
x_conv = self.pos_conv(x.transpose(1, 2)) x_conv = self.pos_conv(x.transpose(1, 2))
x_conv = x_conv.transpose(1, 2) x_conv = x_conv.transpose(1, 2)
x += x_conv x = x + x_conv
if not self.layer_norm_first: if not self.layer_norm_first:
x = self.layer_norm(x) x = self.layer_norm(x)
...@@ -898,3 +911,72 @@ class TransformerSentenceEncoderLayer(nn.Module): ...@@ -898,3 +911,72 @@ class TransformerSentenceEncoderLayer(nn.Module):
x = self.final_layer_norm(x) x = self.final_layer_norm(x)
return x, attn return x, attn
def base_architecture(args):
args.extractor_mode = getattr(args, "extractor_mode", "default")
args.encoder_layers = getattr(args, "encoder_layers", 12)
args.encoder_embed_dim = getattr(args, "encoder_embed_dim", 768)
args.encoder_ffn_embed_dim = getattr(args, "encoder_ffn_embed_dim", 3072)
args.encoder_attention_heads = getattr(args, "encoder_attention_heads", 12)
args.activation_fn = getattr(args, "activation_fn", "gelu")
args.dropout = getattr(args, "dropout", 0.1)
args.attention_dropout = getattr(args, "attention_dropout", 0.1)
args.activation_dropout = getattr(args, "activation_dropout", 0.0)
args.final_dim = getattr(args, "final_dim", 0)
args.layer_norm_first = getattr(args, "layer_norm_first", False)
args.encoder_layerdrop = getattr(args, "encoder_layerdrop", 0.0)
conv_feature_layers = "[(512, 10, 5)]"
conv_feature_layers += " + [(512, 8, 4)]"
conv_feature_layers += " + [(512, 4, 2)] * 3"
conv_feature_layers += " + [(512, 1, 1)]"
args.conv_feature_layers = getattr(args, "conv_feature_layers", conv_feature_layers)
args.logit_temp = getattr(args, "logit_temp", 0.1)
args.quantize_targets = getattr(args, "quantize_targets", False)
args.quantize_input = getattr(args, "quantize_input", False)
args.same_quantizer = getattr(args, "same_quantizer", False)
args.feature_grad_mult = getattr(args, "feature_grad_mult", 1.0)
args.latent_vars = getattr(args, "latent_vars", 320)
args.latent_groups = getattr(args, "latent_groups", 2)
args.latent_dim = getattr(args, "latent_dim", 0)
args.mask_length = getattr(args, "mask_length", 10)
args.mask_prob = getattr(args, "mask_prob", 0.65)
args.mask_selection = getattr(args, "mask_selection", "static")
args.mask_other = getattr(args, "mask_other", 0)
args.no_mask_overlap = getattr(args, "no_mask_overlap", False)
args.mask_min_space = getattr(args, "mask_min_space", 1)
args.mask_channel_length = getattr(args, "mask_channel_length", 10)
args.mask_channel_prob = getattr(args, "mask_channel_prob", 0)
args.mask_channel_selection = getattr(args, "mask_channel_selection", "static")
args.mask_channel_other = getattr(args, "mask_channel_other", 0)
args.no_mask_channel_overlap = getattr(args, "no_mask_channel_overlap", False)
args.mask_channel_min_space = getattr(args, "mask_channel_min_space", 1)
args.dropout_input = getattr(args, "dropout_input", 0)
args.dropout_features = getattr(args, "dropout_features", 0)
args.num_negatives = getattr(args, "num_negatives", 100)
args.negatives_from_everywhere = getattr(args, "negatives_from_everywhere", False)
args.cross_sample_negatives = getattr(args, "cross_sample_negatives", 0)
args.codebook_negatives = getattr(args, "codebook_negatives", 0)
args.conv_pos = getattr(args, "conv_pos", 128)
args.conv_pos_groups = getattr(args, "conv_pos_groups", 16)
args.latent_temp = getattr(args, "latent_temp", "(2,0.5,0.999995)")
args.target_glu = getattr(args, "target_glu", False)
args.conv_bias = getattr(args, "conv_bias", False)
\ No newline at end of file
...@@ -228,10 +228,11 @@ class Conv2dSubsampling(nn.Module): ...@@ -228,10 +228,11 @@ class Conv2dSubsampling(nn.Module):
return x, x_len return x, x_len
def subsampling(args, out_dim=None): def subsampling(args, in_dim=None, out_dim=None):
subsampling_type = getattr(args, "subsampling_type", "conv1d") subsampling_type = getattr(args, "subsampling_type", "conv1d")
layers = getattr(args, "subsampling_layers", 2) layers = getattr(args, "subsampling_layers", 2)
in_dim = args.input_feat_per_channel * args.input_channels if in_dim is None:
in_dim = args.input_feat_per_channel * args.input_channels
filters = [getattr(args, "subsampling_filter")] + [args.encoder_embed_dim if out_dim is None else out_dim] filters = [getattr(args, "subsampling_filter")] + [args.encoder_embed_dim if out_dim is None else out_dim]
kernel_size = getattr(args, "subsampling_kernel", 5) kernel_size = getattr(args, "subsampling_kernel", 5)
stride = getattr(args, "subsampling_stride", 2) stride = getattr(args, "subsampling_stride", 2)
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论