Commit 0c7e71c7 by xuchen

Happy Valentine's Day!

I reformatted the code, implemented the intermediate ("intermedia") CTC loss with an adapter, and added CTC decoding.
parent f286e56c
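For context, a minimal config sketch of the intermediate CTC options this commit introduces; the keys are taken from the configs changed below, and the values are the illustrative ones used there rather than recommended settings:

arch: s2t_transformer_s
ctc-weight: 0.1
intermedia-ctc-layers: 6,8,10
intermedia-adapter: league
intermedia-ctc-weight: 0.2
ctc-self-distill-weight: 1
post-process: sentencepiece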
...@@ -11,8 +11,14 @@ lr: 2e-3
criterion: label_smoothed_cross_entropy_with_ctc
label_smoothing: 0.1
-conv-kernel-sizes: 5,5
-conv-channels: 1024
+subsampling-type: conv1d
+subsampling-layers: 2
+subsampling-filter: 1024
+subsampling-kernel: 5
+subsampling-stride: 2
+subsampling-norm: none
+subsampling-activation: glu
dropout: 0.1
activation-fn: relu
encoder-embed-dim: 256
...
ctc-weight: 0.3
+zero_infinity: True
+post-process: sentencepiece
\ No newline at end of file
+#arch: pdss2t_transformer_s_8
arch: s2t_transformer_s
+#pds-ctc: 1_1_1_1
+#pds-ctc: 0_0_0_1
+intermedia-ctc-layers: 6,8,10
+intermedia-adapter: league
+intermedia-ctc-weight: 0.2
+ctc-self-distill-weight: 1
+ctc-weight: 0.1
share-decoder-input-output-embed: True
optimizer: adam
clip-norm: 10.0
...@@ -20,8 +11,14 @@ lr: 2e-3
criterion: label_smoothed_cross_entropy_with_ctc
label_smoothing: 0.1
-conv-kernel-sizes: 5,5
-conv-channels: 1024
+subsampling-type: conv1d
+subsampling-layers: 2
+subsampling-filter: 1024
+subsampling-kernel: 5
+subsampling-stride: 2
+subsampling-norm: none
+subsampling-activation: glu
dropout: 0.1
activation-fn: relu
encoder-embed-dim: 256
...
arch: s2t_ctc
optimizer: adam
clip-norm: 10.0
lr-scheduler: inverse_sqrt
warmup-init-lr: 1e-7
warmup-updates: 10000
lr: 2e-3
#adam_betas: (0.9,0.98)
criterion: ctc
zero_infinity: True
post-process: sentencepiece
label_smoothing: 0.1
conv-kernel-sizes: 5,5
conv-channels: 1024
dropout: 0.1
activation-fn: relu
encoder-embed-dim: 256
encoder-ffn-embed-dim: 2048
encoder-layers: 12
encoder-attention-heads: 4
attention-dropout: 0.1
activation-dropout: 0.1
\ No newline at end of file
...@@ -3,7 +3,7 @@
gpu_num=1
data_dir=
-test_subset=(tst-COMMON)
+test_subset=(test)
exp_name=
if [ "$#" -eq 1 ]; then
...
...@@ -37,7 +37,7 @@ dataset=libri_trans
task=speech_to_text
vocab_type=unigram
vocab_size=1000
-speed_perturb=1
+speed_perturb=0
lcrm=1
tokenizer=0
use_raw_audio=1
...
arch: s2t_ctc
optimizer: adam
clip-norm: 10.0
lr-scheduler: inverse_sqrt
warmup-init-lr: 1e-7
warmup-updates: 10000
lr: 2e-3
#adam_betas: (0.9,0.98)
criterion: ctc
post-process: sentencepiece
conv-kernel-sizes: 5,5
conv-channels: 704
dropout: 0.1
activation-fn: relu
encoder-embed-dim: 176
encoder-ffn-embed-dim: 704
encoder-layers: 16
encoder-attention-heads: 4
macaron-style: True
use-cnn-module: True
cnn-module-kernel: 31
encoder-attention-type: rel_selfattn
\ No newline at end of file
arch: s2t_ctc
optimizer: adam
clip-norm: 10.0
lr-scheduler: inverse_sqrt
warmup-init-lr: 1e-7
warmup-updates: 10000
lr: 2e-3
#adam_betas: (0.9,0.98)
criterion: ctc
post-process: sentencepiece
conv-kernel-sizes: 5,5
conv-channels: 1024
dropout: 0.1
activation-fn: relu
encoder-embed-dim: 256
encoder-ffn-embed-dim: 2048
encoder-layers: 12
encoder-attention-heads: 4
\ No newline at end of file
arch: s2t_ctc
optimizer: adam
clip-norm: 10.0
lr-scheduler: inverse_sqrt
warmup-init-lr: 1e-7
warmup-updates: 10000
lr: 2e-3
#adam_betas: (0.9,0.98)
criterion: ctc
zero_infinity: True
post-process: sentencepiece
label_smoothing: 0.1
conv-kernel-sizes: 5,5
conv-channels: 1024
dropout: 0.1
activation-fn: relu
encoder-embed-dim: 256
encoder-ffn-embed-dim: 2048
encoder-layers: 12
encoder-attention-heads: 4
attention-dropout: 0.1
activation-dropout: 0.1
\ No newline at end of file
...@@ -43,7 +43,7 @@ tokenizer=1
use_specific_dict=1
subword=1
-specific_prefix=subword32000_share_tok
+specific_prefix=subword32000_share
specific_dir=${root_dir}/data/mustc/st
src_vocab_prefix=spm_unigram10000_st_share
tgt_vocab_prefix=spm_unigram10000_st_share
...
...@@ -108,20 +108,14 @@ class CtcCriterion(FairseqCriterion):
net_output, log_probs=True
).contiguous() # (T, B, C) from the encoder
-if "src_lengths" in sample["net_input"]:
-input_lengths = sample["net_input"]["src_lengths"]
-else:
-non_padding_mask = ~net_output["padding_mask"]
-input_lengths = non_padding_mask.long().sum(-1)
+non_padding_mask = ~net_output["encoder_padding_mask"][0]
+input_lengths = non_padding_mask.long().sum(-1)
pad_mask = (sample["target"] != self.pad_idx) & (
sample["target"] != self.eos_idx
)
targets_flat = sample["target"].masked_select(pad_mask)
-if "target_lengths" in sample:
-target_lengths = sample["target_lengths"]
-else:
-target_lengths = pad_mask.sum(-1)
+target_lengths = pad_mask.sum(-1)
with torch.backends.cudnn.flags(enabled=False):
loss = F.ctc_loss(
...
...@@ -5,7 +5,7 @@
from .berard import * # noqa
from .convtransformer import * # noqa
+from .s2t_ctc import *
from .s2t_transformer import * # noqa
-from .s2t_conformer import * # noqa
from .pdss2t_transformer import * # noqa
from .s2t_sate import * # noqa
...@@ -663,7 +663,7 @@ class PDSS2TTransformerEncoder(FairseqEncoder):
dropout=args.dropout,
need_layernorm=True)
if task.source_dictionary == task.target_dictionary and embed_tokens is not None:
-self.ctc.ctc_projection.weight = embed_tokens.weight
+ctc.ctc_projection.weight = embed_tokens.weight
inter_ctc_module = ctc
else:
...
#!/usr/bin/env python3
import logging
import torch.nn as nn
from fairseq import checkpoint_utils, utils
from fairseq.data.data_utils import lengths_to_padding_mask
from fairseq.models import (
register_model,
register_model_architecture,
)
from fairseq.models.speech_to_text import S2TTransformerModel, S2TTransformerEncoder
from fairseq.modules import (
ConformerEncoderLayer,
)
logger = logging.getLogger(__name__)
@register_model("s2t_conformer")
class S2TConformerModel(S2TTransformerModel):
"""Adapted Transformer model (https://arxiv.org/abs/1706.03762) for
speech-to-text tasks. The Transformer encoder/decoder remains the same.
A trainable input subsampler is prepended to the Transformer encoder to
project inputs into the encoder dimension as well as downsample input
sequence for computational efficiency."""
def __init__(self, encoder, decoder):
super().__init__(encoder, decoder)
@staticmethod
def add_args(parser):
"""Add model-specific arguments to the parser."""
# input
parser.add_argument(
"--conv-kernel-sizes",
type=str,
metavar="N",
help="kernel sizes of Conv1d subsampling layers",
)
parser.add_argument(
"--conv-channels",
type=int,
metavar="N",
help="# of channels in Conv1d subsampling layers",
)
# Transformer
parser.add_argument(
"--activation-fn",
type=str,
default="relu",
choices=utils.get_available_activation_fns(),
help="activation function to use",
)
parser.add_argument(
"--dropout", type=float, metavar="D", help="dropout probability"
)
parser.add_argument(
"--attention-dropout",
type=float,
metavar="D",
help="dropout probability for attention weights",
)
parser.add_argument(
"--activation-dropout",
"--relu-dropout",
type=float,
metavar="D",
help="dropout probability after activation in FFN.",
)
parser.add_argument(
"--encoder-embed-dim",
type=int,
metavar="N",
help="encoder embedding dimension",
)
parser.add_argument(
"--encoder-ffn-embed-dim",
type=int,
metavar="N",
help="encoder embedding dimension for FFN",
)
parser.add_argument(
"--encoder-layers", type=int, metavar="N", help="num encoder layers"
)
parser.add_argument(
"--encoder-attention-type",
type=str,
default="selfattn",
choices=[
"local",
"selfattn",
"reduced",
"rel_selfattn",
"relative",
],
help="transformer encoder self-attention layer type"
)
parser.add_argument(
"--encoder-attention-heads",
type=int,
metavar="N",
help="num encoder attention heads",
)
parser.add_argument(
"--encoder-normalize-before",
action="store_true",
help="apply layernorm before each encoder block",
)
parser.add_argument(
"--decoder-embed-dim",
type=int,
metavar="N",
help="decoder embedding dimension",
)
parser.add_argument(
"--decoder-ffn-embed-dim",
type=int,
metavar="N",
help="decoder embedding dimension for FFN",
)
parser.add_argument(
"--decoder-layers", type=int, metavar="N", help="num decoder layers"
)
parser.add_argument(
"--decoder-attention-type",
type=str,
default="selfattn",
choices=[
"selfattn",
"rel_selfattn",
"relative",
"local",
],
help="transformer decoder self-attention layer type"
)
parser.add_argument(
"--decoder-attention-heads",
type=int,
metavar="N",
help="num decoder attention heads",
)
parser.add_argument(
"--decoder-normalize-before",
action="store_true",
help="apply layernorm before each decoder block",
)
parser.add_argument(
"--share-decoder-input-output-embed",
action="store_true",
help="share decoder input and output embeddings",
)
parser.add_argument('--share-all-embeddings',
action='store_true',
help='share encoder, decoder and output embeddings'
' (requires shared dictionary and embed dim)')
parser.add_argument(
"--layernorm-embedding",
action="store_true",
help="add layernorm to embedding",
)
parser.add_argument(
"--no-scale-embedding",
action="store_true",
help="if True, dont scale embeddings",
)
parser.add_argument('--adaptive-softmax-cutoff', metavar='EXPR',
help='comma separated list of adaptive softmax cutoff points. '
'Must be used with adaptive_loss criterion'),
parser.add_argument('--adaptive-softmax-dropout', type=float, metavar='D',
help='sets adaptive softmax dropout for the tail projections')
parser.add_argument('--max-encoder-relative-length', type=int, default=-1,
help='the max relative length')
parser.add_argument('--max-decoder-relative-length', type=int, default=-1,
help='the max relative length')
parser.add_argument('--k-only', default=False, action='store_true',
help='select the relative mode to map relative position information')
parser.add_argument(
"--load-pretrained-encoder-from",
type=str,
metavar="STR",
help="model to take encoder weights from (for initialization)",
)
parser.add_argument(
"--load-pretrained-decoder-from",
type=str,
metavar="STR",
help="model to take decoder weights from (for initialization)",
)
parser.add_argument(
"--encoder-freeze-module",
type=str,
metavar="STR",
help="freeze the module of the encoder",
)
parser.add_argument(
"--decoder-freeze-module",
type=str,
metavar="STR",
help="freeze the module of the decoder",
)
parser.add_argument(
"--use-enc-dlcl",
default=False,
action='store_true',
help="use dlcl encoder",
)
parser.add_argument(
"--use-dec-dlcl",
default=False,
action='store_true',
help="use dlcl encoder",
)
parser.add_argument(
'--encoder-history-type',
default="learnable_dense",
help='encoder layer history type'
)
parser.add_argument(
'--decoder-history-type',
default="learnable_dense",
help='decoder layer history type'
)
parser.add_argument(
'--hard-mask-window',
type=float,
metavar="D",
default=0,
help='window size of local mask'
)
parser.add_argument(
'--gauss-mask-sigma',
type=float,
metavar="D",
default=0,
help='standard deviation of the gauss mask'
)
parser.add_argument(
'--init-mask-weight',
type=float,
metavar="D",
default=0.5,
help='initialized weight for local mask'
)
# Conformer setting
parser.add_argument(
"--macaron-style",
default=False,
type=bool,
help="Whether to use macaron style for positionwise layer",
)
# Attention
parser.add_argument(
"--zero-triu",
default=False,
type=bool,
help="If true, zero the upper triangular part of attention matrix.",
)
# Relative positional encoding
parser.add_argument(
"--rel-pos-type",
type=str,
default="legacy",
choices=["legacy", "latest"],
help="Whether to use the latest relative positional encoding or the legacy one."
"The legacy relative positional encoding will be deprecated in the future."
"More Details can be found in https://github.com/espnet/espnet/pull/2816.",
)
# CNN module
parser.add_argument(
"--use-cnn-module",
default=False,
type=bool,
help="Use convolution module or not",
)
parser.add_argument(
"--cnn-module-kernel",
default=31,
type=int,
help="Kernel size of convolution module.",
)
pass
@classmethod
def build_encoder(cls, args, task=None, embed_tokens=None):
encoder = S2TConformerEncoder(args, task, embed_tokens)
if getattr(args, "load_pretrained_encoder_from", None):
encoder = checkpoint_utils.load_pretrained_component_from_model(
component=encoder, checkpoint=args.load_pretrained_encoder_from, strict=False
)
logger.info(
f"loaded pretrained encoder from: "
f"{args.load_pretrained_encoder_from}"
)
return encoder
class S2TConformerEncoder(S2TTransformerEncoder):
"""Speech-to-text Conformer encoder that consists of input subsampler and
Conformer encoder."""
def __init__(self, args, task=None, embed_tokens=None):
super().__init__(args, task, embed_tokens)
del self.layers
self.layers = nn.ModuleList(
[ConformerEncoderLayer(args) for _ in range(args.encoder_layers)]
)
def forward(self, src_tokens, src_lengths):
if self.history is not None:
self.history.clean()
cos_sim_idx = -1
dis = self.dis
if self.gather_cos_sim:
x = src_tokens
x = x.transpose(0, 1)
self.add_to_dict(x, dis, cos_sim_idx)
x, input_lengths = self.subsample(src_tokens, src_lengths)
if type(x) == list:
inner_x = x
if self.gather_cos_sim:
for x in inner_x:
cos_sim_idx += 1
self.add_to_dict(x, dis, cos_sim_idx)
x = inner_x[-1]
x = self.embed_scale * x
encoder_padding_mask = lengths_to_padding_mask(input_lengths)
positions = self.embed_positions(encoder_padding_mask).transpose(0, 1)
if self.attn_type != "rel_selfattn":
x += positions
x = self.dropout_module(x)
positions = self.dropout_module(positions)
# add emb into history
if self.history is not None:
self.history.add(x)
cos_sim_idx = (cos_sim_idx + 10) // 10 * 10 - 1
if self.gather_cos_sim:
cos_sim_idx += 1
self.add_to_dict(x, dis, cos_sim_idx)
for layer in self.layers:
if self.history is not None:
x = self.history.pop()
x = layer(x, encoder_padding_mask, pos_emb=positions)
if self.history is not None:
self.history.add(x)
if self.gather_cos_sim:
cos_sim_idx += 1
self.add_to_dict(x, dis, cos_sim_idx)
if self.history is not None:
x = self.history.pop()
if self.layer_norm is not None:
x = self.layer_norm(x)
return {
"encoder_out": [x], # T x B x C
"encoder_padding_mask": [encoder_padding_mask], # B x T
"encoder_embedding": [], # B x T x C
"encoder_states": [], # List[T x B x C]
"src_tokens": [],
"src_lengths": [],
}
@register_model_architecture(model_name="s2t_conformer", arch_name="s2t_conformer")
def base_architecture(args):
# Convolutional subsampler
args.conv_kernel_sizes = getattr(args, "conv_kernel_sizes", "5,5")
args.conv_channels = getattr(args, "conv_channels", 1024)
# Conformer
args.macaron_style = getattr(args, "macaron_style", False)
args.use_cnn_module = getattr(args, "use_cnn_module", False)
args.cnn_module_kernel = getattr(args, "cnn_module_kernel", 31)
args.encoder_embed_dim = getattr(args, "encoder_embed_dim", 512)
args.encoder_ffn_embed_dim = getattr(args, "encoder_ffn_embed_dim", 2048)
args.encoder_layers = getattr(args, "encoder_layers", 12)
args.encoder_attention_type = getattr(args, "encoder_attention_type", "selfattn")
args.encoder_attention_heads = getattr(args, "encoder_attention_heads", 8)
args.encoder_normalize_before = getattr(args, "encoder_normalize_before", True)
args.decoder_embed_dim = getattr(args, "decoder_embed_dim", args.encoder_embed_dim)
args.decoder_ffn_embed_dim = getattr(
args, "decoder_ffn_embed_dim", args.encoder_ffn_embed_dim
)
args.decoder_layers = getattr(args, "decoder_layers", 6)
args.decoder_attention_type = getattr(args, "decoder_attention_type", "selfattn")
args.decoder_attention_heads = getattr(args, "decoder_attention_heads", 8)
args.decoder_normalize_before = getattr(args, "decoder_normalize_before", True)
args.decoder_learned_pos = getattr(args, "decoder_learned_pos", False)
args.dropout = getattr(args, "dropout", 0.1)
args.attention_dropout = getattr(args, "attention_dropout", args.dropout)
args.activation_dropout = getattr(args, "activation_dropout", args.dropout)
args.activation_fn = getattr(args, "activation_fn", "relu")
args.adaptive_softmax_cutoff = getattr(args, "adaptive_softmax_cutoff", None)
args.adaptive_softmax_dropout = getattr(args, "adaptive_softmax_dropout", 0)
args.share_decoder_input_output_embed = getattr(
args, "share_decoder_input_output_embed", False
)
args.no_token_positional_embeddings = getattr(
args, "no_token_positional_embeddings", False
)
args.adaptive_input = getattr(args, "adaptive_input", False)
args.decoder_layerdrop = getattr(args, "decoder_layerdrop", 0.0)
args.decoder_output_dim = getattr(
args, "decoder_output_dim", args.decoder_embed_dim
)
args.decoder_input_dim = getattr(args, "decoder_input_dim", args.decoder_embed_dim)
args.no_scale_embedding = getattr(args, "no_scale_embedding", False)
args.quant_noise_pq = getattr(args, "quant_noise_pq", 0)
args.max_encoder_relative_length = getattr(args, 'max_encoder_relative_length', -1)
args.max_decoder_relative_length = getattr(args, 'max_decoder_relative_length', -1)
args.k_only = getattr(args, 'k_only', True)
@register_model_architecture("s2t_conformer", "s2t_conformer_s")
def s2t_conformer_s(args):
args.encoder_embed_dim = getattr(args, "encoder_embed_dim", 256)
args.encoder_ffn_embed_dim = getattr(args, "encoder_ffn_embed_dim", 256 * 8)
args.encoder_attention_heads = getattr(args, "encoder_attention_heads", 4)
args.decoder_attention_heads = getattr(args, "decoder_attention_heads", 4)
args.dropout = getattr(args, "dropout", 0.1)
base_architecture(args)
@register_model_architecture("s2t_conformer", "s2t_conformer_s_relative")
def s2t_conformer_s_relative(args):
args.max_encoder_relative_length = 100
args.max_decoder_relative_length = 20
args.k_only = True
s2t_conformer_s(args)
@register_model_architecture("s2t_conformer", "s2t_conformer_xs")
def s2t_conformer_xs(args):
args.encoder_layers = getattr(args, "encoder_layers", 6)
args.decoder_layers = getattr(args, "decoder_layers", 3)
args.encoder_ffn_embed_dim = getattr(args, "encoder_ffn_embed_dim", 256 * 4)
args.dropout = getattr(args, "dropout", 0.3)
s2t_conformer_s(args)
@register_model_architecture("s2t_conformer", "s2t_conformer_sp")
def s2t_conformer_sp(args):
args.encoder_layers = getattr(args, "encoder_layers", 16)
s2t_conformer_s(args)
@register_model_architecture("s2t_conformer", "s2t_conformer_m")
def s2t_conformer_m(args):
args.encoder_embed_dim = getattr(args, "encoder_embed_dim", 512)
args.encoder_ffn_embed_dim = getattr(args, "encoder_ffn_embed_dim", 512 * 4)
args.encoder_attention_heads = getattr(args, "encoder_attention_heads", 8)
args.decoder_attention_heads = getattr(args, "decoder_attention_heads", 8)
args.dropout = getattr(args, "dropout", 0.15)
base_architecture(args)
@register_model_architecture("s2t_conformer", "s2t_conformer_mp")
def s2t_conformer_mp(args):
args.encoder_layers = getattr(args, "encoder_layers", 16)
s2t_conformer_m(args)
@register_model_architecture("s2t_conformer", "s2t_conformer_l")
def s2t_conformer_l(args):
args.encoder_embed_dim = getattr(args, "encoder_embed_dim", 1024)
args.encoder_ffn_embed_dim = getattr(args, "encoder_ffn_embed_dim", 1024 * 4)
args.encoder_attention_heads = getattr(args, "encoder_attention_heads", 16)
args.decoder_attention_heads = getattr(args, "decoder_attention_heads", 16)
args.dropout = getattr(args, "dropout", 0.2)
base_architecture(args)
@register_model_architecture("s2t_conformer", "s2t_conformer_lp")
def s2t_conformer_lp(args):
args.encoder_layers = getattr(args, "encoder_layers", 16)
s2t_conformer_l(args)
import logging
import math
from typing import Dict, List, Optional, Tuple
import torch
import torch.nn as nn
import torch.nn.functional as F
from fairseq import checkpoint_utils, utils
from fairseq.data.data_utils import lengths_to_padding_mask
from fairseq.models import (
FairseqEncoder,
FairseqEncoderModel,
register_model,
register_model_architecture,
)
from fairseq.models.speech_to_text.modules import InterAdapter, CTC
from fairseq.models.transformer import Embedding, TransformerDecoder
from fairseq.modules import (
FairseqDropout,
LayerNorm,
PositionalEmbedding,
S2TTransformerEncoderLayer,
DynamicLinearCombination,
)
from torch import Tensor
logger = logging.getLogger(__name__)
class Conv1dSubsampler(nn.Module):
"""Convolutional subsampler: a stack of 1D convolution (along temporal
dimension) followed by non-linear activation via gated linear units
(https://arxiv.org/abs/1911.08460)
Args:
in_channels (int): the number of input channels
mid_channels (int): the number of intermediate channels
out_channels (int): the number of output channels
kernel_sizes (List[int]): the kernel size for each convolutional layer
"""
def __init__(
self,
in_channels: int,
mid_channels: int,
out_channels: int,
kernel_sizes: List[int] = (3, 3),
):
super(Conv1dSubsampler, self).__init__()
self.n_layers = len(kernel_sizes)
self.conv_layers = nn.ModuleList(
nn.Conv1d(
in_channels if i == 0 else mid_channels // 2,
mid_channels if i < self.n_layers - 1 else out_channels * 2,
k,
stride=2,
padding=k // 2,
)
for i, k in enumerate(kernel_sizes)
)
def get_out_seq_lens_tensor(self, in_seq_lens_tensor):
out = in_seq_lens_tensor.clone()
for _ in range(self.n_layers):
out = ((out.float() - 1) / 2 + 1).floor().long()
return out
def forward(self, src_tokens, src_lengths):
bsz, in_seq_len, _ = src_tokens.size() # B x T x (C x D)
x = src_tokens.transpose(1, 2).contiguous() # -> B x (C x D) x T
inner_x = []
for conv in self.conv_layers:
x = conv(x)
x = nn.functional.glu(x, dim=1)
inner_x.append(x)
_, _, out_seq_len = x.size()
# x = x.transpose(1, 2).transpose(0, 1).contiguous() # -> T x B x (C x D)
out_inner_x = []
for x in inner_x:
out_inner_x.append(x.transpose(1, 2).transpose(0, 1).contiguous())
return out_inner_x, self.get_out_seq_lens_tensor(src_lengths)
@register_model("s2t_ctc")
class S2TCTCModel(FairseqEncoderModel):
def __init__(self, encoder):
super().__init__(encoder)
@staticmethod
def add_args(parser):
"""Add model-specific arguments to the parser."""
# input
parser.add_argument(
"--conv-kernel-sizes",
type=str,
metavar="N",
help="kernel sizes of Conv1d subsampling layers",
)
parser.add_argument(
"--conv-channels",
type=int,
metavar="N",
help="# of channels in Conv1d subsampling layers",
)
# Transformer
parser.add_argument(
"--activation-fn",
type=str,
default="relu",
choices=utils.get_available_activation_fns(),
help="activation function to use",
)
parser.add_argument(
"--dropout", type=float, metavar="D", help="dropout probability"
)
parser.add_argument(
"--attention-dropout",
type=float,
metavar="D",
help="dropout probability for attention weights",
)
parser.add_argument(
"--activation-dropout",
"--relu-dropout",
type=float,
metavar="D",
help="dropout probability after activation in FFN.",
)
parser.add_argument(
"--encoder-embed-dim",
type=int,
metavar="N",
help="encoder embedding dimension",
)
parser.add_argument(
"--encoder-ffn-embed-dim",
type=int,
metavar="N",
help="encoder embedding dimension for FFN",
)
parser.add_argument(
"--encoder-layers", type=int, metavar="N", help="num encoder layers"
)
parser.add_argument(
"--encoder-attention-type",
type=str,
default="selfattn",
choices=[
"local",
"selfattn",
"reduced",
"rel_selfattn",
"relative",
],
help="transformer encoder self-attention layer type"
)
parser.add_argument(
"--encoder-attention-heads",
type=int,
metavar="N",
help="num encoder attention heads",
)
parser.add_argument(
"--encoder-normalize-before",
action="store_true",
help="apply layernorm before each encoder block",
)
parser.add_argument(
"--decoder-embed-dim",
type=int,
metavar="N",
help="decoder embedding dimension",
)
parser.add_argument(
"--decoder-ffn-embed-dim",
type=int,
metavar="N",
help="decoder embedding dimension for FFN",
)
parser.add_argument(
"--decoder-layers", type=int, metavar="N", help="num decoder layers"
)
parser.add_argument(
"--decoder-attention-type",
type=str,
default="selfattn",
choices=[
"selfattn",
"rel_selfattn",
"relative",
"local",
],
help="transformer decoder self-attention layer type"
)
parser.add_argument(
"--decoder-attention-heads",
type=int,
metavar="N",
help="num decoder attention heads",
)
parser.add_argument(
"--decoder-normalize-before",
action="store_true",
help="apply layernorm before each decoder block",
)
parser.add_argument(
"--share-decoder-input-output-embed",
action="store_true",
help="share decoder input and output embeddings",
)
parser.add_argument('--share-all-embeddings',
action='store_true',
help='share encoder, decoder and output embeddings'
' (requires shared dictionary and embed dim)')
parser.add_argument(
"--layernorm-embedding",
action="store_true",
help="add layernorm to embedding",
)
parser.add_argument(
"--no-scale-embedding",
action="store_true",
help="if True, dont scale embeddings",
)
parser.add_argument('--adaptive-softmax-cutoff', metavar='EXPR',
help='comma separated list of adaptive softmax cutoff points. '
'Must be used with adaptive_loss criterion'),
parser.add_argument('--adaptive-softmax-dropout', type=float, metavar='D',
help='sets adaptive softmax dropout for the tail projections')
parser.add_argument('--max-encoder-relative-length', type=int, default=-1,
help='the max relative length')
parser.add_argument('--max-decoder-relative-length', type=int, default=-1,
help='the max relative length')
parser.add_argument('--k-only', default=False, action='store_true',
help='select the relative mode to map relative position information')
parser.add_argument(
"--load-pretrained-encoder-from",
type=str,
metavar="STR",
help="model to take encoder weights from (for initialization)",
)
parser.add_argument(
"--load-pretrained-decoder-from",
type=str,
metavar="STR",
help="model to take decoder weights from (for initialization)",
)
parser.add_argument(
"--encoder-freeze-module",
type=str,
metavar="STR",
help="freeze the module of the encoder",
)
parser.add_argument(
"--decoder-freeze-module",
type=str,
metavar="STR",
help="freeze the module of the decoder",
)
parser.add_argument(
"--use-enc-dlcl",
default=False,
action='store_true',
help="use dlcl encoder",
)
parser.add_argument(
"--use-dec-dlcl",
default=False,
action='store_true',
help="use dlcl encoder",
)
parser.add_argument('--init-value', type=str, default='avg', choices=['avg', 'one'],
help='how to init the learned weight matrix')
parser.add_argument('--weight-type', type=str, default='scalar',
help='type of learned weight [scalar, scalar_n(n>1), vector]')
parser.add_argument('--encoder-learnable', type=eval, default='True',
help='enable to learn weights for encoder')
parser.add_argument('--decoder-learnable', type=eval, default='True',
help='enable to learn weights for decoder')
parser.add_argument('--normalize-learned-weight', type=eval, default='False',
help='normalize learned weight by softmax')
parser.add_argument('--normalize-embedding', type=eval, default='False',
help='normalize the input of embedding')
parser.add_argument('--history-dropout', type=float, default=0.0, metavar='D',
help='dropout for history output')
parser.add_argument('--history-window-size', type=int, default='-1',
help='how many past layers are considered. -1 means all')
# CTC
parser.add_argument(
"--ctc-layer",
default=0,
type=int,
help="the position of the ctc loss",
)
# local modeling
parser.add_argument(
'--hard-mask-window',
type=float,
metavar="D",
default=0,
help='window size of local mask'
)
parser.add_argument(
'--gauss-mask-sigma',
type=float,
metavar="D",
default=0,
help='standard deviation of the gauss mask'
)
parser.add_argument(
'--init-mask-weight',
type=float,
metavar="D",
default=0.5,
help='initialized weight for local mask'
)
# Conformer setting
parser.add_argument(
"--macaron-style",
default=False,
type=bool,
help="Whether to use macaron style for positionwise layer",
)
# Attention
parser.add_argument(
"--zero-triu",
default=False,
type=bool,
help="If true, zero the upper triangular part of attention matrix.",
)
# Relative positional encoding
parser.add_argument(
"--rel-pos-type",
type=str,
default="legacy",
choices=["legacy", "latest"],
help="Whether to use the latest relative positional encoding or the legacy one."
"The legacy relative positional encoding will be deprecated in the future."
"More Details can be found in https://github.com/espnet/espnet/pull/2816.",
)
# CNN module
parser.add_argument(
"--use-cnn-module",
default=False,
type=bool,
help="Use convolution module or not",
)
parser.add_argument(
"--cnn-module-kernel",
default=31,
type=int,
help="Kernel size of convolution module.",
)
# Simultaneous speech translation
parser.add_argument(
"--simul",
default=False,
action="store_true",
help="Simultaneous speech translation or not",
)
# interleaved dropout
parser.add_argument('--interleave-dropout', type=int,
help='interleaved dropout probability')
parser.add_argument('--cl-dropout',
action="store_true",
default=False,
help='interleaved dropout probability')
parser.add_argument('--cl-dropout-epoch',
type=int,
default=None,
help='interleaved dropout probability')
parser.add_argument('--cl-dropout-strategy',
type=str,
help='interleaved dropout probability')
# intermedia CTC loss
parser.add_argument(
"--intermedia-ctc-layers",
default=None,
type=str,
help="the position of the ctc loss, separated by comma ",
)
parser.add_argument(
"--intermedia-adapter",
default="none",
type=str,
help="type of intermedia adapter",
)
pass
@classmethod
def build_encoder(cls, args, task=None, embed_tokens=None):
encoder = S2TCTCEncoder(args, task)
if getattr(args, "load_pretrained_encoder_from", None):
logger.info(
f"loaded pretrained encoder from: "
f"{args.load_pretrained_encoder_from}"
)
encoder = checkpoint_utils.load_pretrained_component_from_model(
component=encoder, checkpoint=args.load_pretrained_encoder_from, strict=False
)
return encoder
@classmethod
def build_model(cls, args, task):
"""Build a new model instance."""
# make sure all arguments are present in older models
base_architecture(args)
encoder = cls.build_encoder(args, task)
if getattr(args, "encoder_freeze_module", None):
utils.freeze_parameters(encoder, args.encoder_freeze_module)
logging.info("freeze the encoder module: {}".format(args.encoder_freeze_module))
return cls(encoder)
def get_normalized_probs(
self,
net_output: Tuple[Tensor, Optional[Dict[str, List[Optional[Tensor]]]]],
log_probs: bool,
sample: Optional[Dict[str, Tensor]] = None,
):
# net_output['encoder_out'] is a (T, B, D) tensor
logits = net_output["ctc_logit"][0]
# logits = logits.transpose(0, 1)
if log_probs:
return utils.log_softmax(logits.float(), dim=-1)
else:
return utils.softmax(logits.float(), dim=-1)
def forward(self, src_tokens, src_lengths, prev_output_tokens=None):
"""
The forward method inherited from the base class has a **kwargs
argument in its input, which is not supported in torchscript. This
method overwrites the forward method definition without **kwargs.
"""
encoder_out = self.encoder(src_tokens=src_tokens, src_lengths=src_lengths)
return encoder_out
class S2TCTCEncoder(FairseqEncoder):
"""Speech-to-text Transformer encoder that consists of input subsampler and
Transformer encoder."""
def __init__(self, args, task=None):
super().__init__(None)
dim = args.encoder_embed_dim
self.dropout_module = FairseqDropout(
p=args.dropout, module_name=self.__class__.__name__
)
self.embed_scale = math.sqrt(dim)
if args.no_scale_embedding:
self.embed_scale = 1.0
self.padding_idx = 1
self.subsample = Conv1dSubsampler(
args.input_feat_per_channel * args.input_channels,
args.conv_channels,
dim,
[int(k) for k in args.conv_kernel_sizes.split(",")],
)
self.attn_type = getattr(args, "encoder_attention_type", "selfattn")
self.embed_positions = PositionalEmbedding(
args.max_source_positions, dim, self.padding_idx
)
self.layers = nn.ModuleList(
[S2TTransformerEncoderLayer(args) for _ in range(args.encoder_layers)]
)
if args.encoder_normalize_before:
self.layer_norm = LayerNorm(dim)
else:
self.layer_norm = None
if args.use_enc_dlcl:
self.history = DynamicLinearCombination(args, is_encoder=True)
else:
self.history = None
self.ctc = CTC(dim,
dictionary_size=len(task.source_dictionary),
dropout=args.dropout,
)
# gather cosine similarity of the representation
self.gather_cos_sim = getattr(args, "gather_cos_sim", False)
# self.gather_cos_sim = True
self.dis = 2
self.cos_sim = dict()
self.intermedia_ctc_layers = []
if args.intermedia_ctc_layers is not None:
intermedia_ctc_layers = args.intermedia_ctc_layers.split(",")
for layer_idx in intermedia_ctc_layers:
layer_idx = int(layer_idx)
if layer_idx <= 0:
layer_idx += args.encoder_layers
self.intermedia_ctc_layers.append(layer_idx)
logger.info("Intermedia CTC loss in layer %d" % layer_idx)
strategy = None
if args.intermedia_adapter == "shrink":
strategy = getattr(args, "ctc_compress_strategy", None)
self.adapter = InterAdapter(dim, args.intermedia_adapter,
task.source_dictionary, strategy=strategy)
def add_to_dict(self, x, dis, idx):
sim = 0
seq_len = x.size(0)
cos = nn.CosineSimilarity(dim=-1, eps=1e-6)
for i in range(dis, seq_len - dis):
a = x[i, :, :]
for j in range(-dis, dis + 1):
if j == 0:
continue
b = x[i + j, :, :]
sim_j = cos(a, b).mean()
sim += sim_j
sim = sim / 2 / dis / (seq_len - 2 * dis)
if idx not in self.cos_sim:
self.cos_sim[idx] = []
self.cos_sim[idx].append(float(sim))
def forward(self, src_tokens, src_lengths, **kwargs):
if self.history is not None:
self.history.clean()
# gather cosine similarity
cos_sim_idx = -1
dis = self.dis
if self.gather_cos_sim:
self.add_to_dict(src_tokens.transpose(0, 1), dis, cos_sim_idx)
# down-sampling
x, input_lengths = self.subsample(src_tokens, src_lengths)
if type(x) == list:
inner_x = x
# gather cosine similarity
if self.gather_cos_sim:
for x in inner_x:
cos_sim_idx += 1
self.add_to_dict(x, dis, cos_sim_idx)
x = inner_x[-1]
# embedding scaling
x = self.embed_scale * x
# padding and position embedding
encoder_padding_mask = lengths_to_padding_mask(input_lengths)
positions = self.embed_positions(encoder_padding_mask).transpose(0, 1)
if self.attn_type != "rel_selfattn":
x += positions
x = self.dropout_module(x)
positions = self.dropout_module(positions)
# add emb into history
if self.history is not None:
self.history.push(x)
# gather cosine similarity
cos_sim_idx = (cos_sim_idx + 10) // 10 * 10 - 1
if self.gather_cos_sim:
cos_sim_idx += 1
self.add_to_dict(x, dis, cos_sim_idx)
layer_idx = 0
intermedia_ctc_logits = []
for layer in self.layers:
layer_idx += 1
if self.history is not None:
x = self.history.pop()
# encoder layer
x = layer(x, encoder_padding_mask, pos_emb=positions)
# interleave CTC
if layer_idx in self.intermedia_ctc_layers:
norm_x = self.layer_norm(x)
logit = self.ctc(norm_x)
intermedia_ctc_logits.append(logit)
prob = F.softmax(logit, dim=-1)
x, encoder_padding_mask = self.adapter([x, prob], encoder_padding_mask)
# gather cosine similarity
if self.gather_cos_sim:
cos_sim_idx += 1
self.add_to_dict(x, dis, cos_sim_idx)
if self.history is not None:
self.history.push(x)
if self.history is not None:
x = self.history.pop()
if self.layer_norm is not None:
x = self.layer_norm(x)
ctc_logit = self.ctc(x)
return {
"encoder_out": [x], # T x B x C
"ctc_logit": [ctc_logit], # B x T x C
"intermedia_ctc_logits": intermedia_ctc_logits, # B x T x C
"encoder_padding_mask": [encoder_padding_mask], # B x T
"encoder_embedding": [], # B x T x C
"encoder_states": [], # List[T x B x C]
"src_tokens": [],
"src_lengths": [],
}
def reorder_encoder_out(self, encoder_out, new_order):
new_encoder_out = (
[] if len(encoder_out["encoder_out"]) == 0
else [x.index_select(1, new_order) for x in encoder_out["encoder_out"]]
)
new_ctc_logit = (
[] if len(encoder_out["ctc_logit"]) == 0
else [x.index_select(1, new_order) for x in encoder_out["ctc_logit"] if x is not None]
)
new_encoder_padding_mask = (
[] if len(encoder_out["encoder_padding_mask"]) == 0
else [x.index_select(0, new_order) for x in encoder_out["encoder_padding_mask"]]
)
new_encoder_embedding = (
[] if len(encoder_out["encoder_embedding"]) == 0
else [x.index_select(0, new_order) for x in encoder_out["encoder_embedding"]]
)
encoder_states = encoder_out["encoder_states"]
if len(encoder_states) > 0:
for idx, state in enumerate(encoder_states):
encoder_states[idx] = state.index_select(1, new_order)
return {
"encoder_out": new_encoder_out, # T x B x C
"ctc_logit": new_ctc_logit, # T x B x C
"encoder_padding_mask": new_encoder_padding_mask, # B x T
"encoder_embedding": new_encoder_embedding, # B x T x C
"encoder_states": encoder_states, # List[T x B x C]
"src_tokens": [], # B x T
"src_lengths": [], # B x 1
}
class CTCDecoder(object):
def __init__(self, models, args, dictionary, blank_idx):
self.dict = dictionary
self.vocab_size = len(dictionary)
self.blank = blank_idx
self.pad = dictionary.pad()
self.unk = dictionary.unk()
self.eos = dictionary.eos()
self.vocab_size = len(dictionary)
self.beam_size = args.beam
# the max beam size is the dictionary size - 1, since we never select pad
self.beam_size = min(self.beam_size, self.vocab_size - 1)
# from fairseq.sequence_generator import EnsembleModel
from fairseq.sequence_generator import EnsembleModel
if isinstance(models, EnsembleModel):
self.model = models
else:
self.model = EnsembleModel(models)
self.model = models[0]
self.model.eval()
self.lm_model = getattr(args, "kenlm_model", None)
self.lm_weight = getattr(args, "lm_weight", 0)
if self.lm_model is not None:
self.lm_model.eval()
from ctcdecode import CTCBeamDecoder
self.ctc_decoder = CTCBeamDecoder(
dictionary.symbols,
model_path=self.lm_model,
alpha=self.lm_weight,
beta=0,
cutoff_top_n=40,
cutoff_prob=1.0,
beam_width=self.beam_size,
num_processes=20,
blank_id=self.blank,
log_probs_input=False
)
def generate(self, models, sample: Dict[str, Dict[str, Tensor]], **kwargs):
net_input = sample["net_input"]
# bsz: total number of sentences in beam
# Note that src_tokens may have more than 2 dimensions (i.e. audio features)
src_tokens = net_input["src_tokens"]
src_lengths = net_input["src_lengths"]
bsz, src_len = src_tokens.size()[:2]
beam_size = self.beam_size
encoder_outs = self.model(src_tokens=src_tokens,
src_lengths=src_lengths)
ctc_logit = encoder_outs["ctc_logit"][0].transpose(0, 1)
beam_results, beam_scores, timesteps, out_lens = self.ctc_decoder.decode(F.softmax(ctc_logit, -1), src_lengths)
# beam_results = beam_results[:, :, :out_lens.max()]
# for beam_idx in range(beam_size):
# top_beam_tokens = beam_results[:, beam_idx, :]
# top_beam_len = out_lens[:, beam_idx]
# mask = torch.arange(0, top_beam_tokens.size(1)).type_as(top_beam_len). \
# repeat(top_beam_len.size(0), 1).lt(top_beam_len.unsqueeze(1))
# top_beam_tokens[~mask] = self.pad
finalized = []
for idx in range(bsz):
hypos = []
for beam_idx in range(beam_size):
hypo = dict()
length = out_lens[idx][beam_idx]
scores = beam_scores[idx, beam_idx]
hypo["tokens"] = beam_results[idx, beam_idx, : length]
hypo["score"] = scores
hypo["attention"] = None
hypo["alignment"] = None
hypo["positional_scores"] = torch.Tensor([scores / length] * length)
hypos.append(hypo)
finalized.append(hypos)
return finalized
@register_model_architecture(model_name="s2t_ctc", arch_name="s2t_ctc")
def base_architecture(args):
# Convolutional subsampler
args.conv_kernel_sizes = getattr(args, "conv_kernel_sizes", "5,5")
args.conv_channels = getattr(args, "conv_channels", 1024)
# Transformer
args.encoder_embed_dim = getattr(args, "encoder_embed_dim", 512)
args.encoder_ffn_embed_dim = getattr(args, "encoder_ffn_embed_dim", 2048)
args.encoder_layers = getattr(args, "encoder_layers", 12)
args.encoder_attention_type = getattr(args, "encoder_attention_type", "selfattn")
args.encoder_attention_heads = getattr(args, "encoder_attention_heads", 8)
args.encoder_normalize_before = getattr(args, "encoder_normalize_before", True)
args.dropout = getattr(args, "dropout", 0.1)
args.attention_dropout = getattr(args, "attention_dropout", args.dropout)
args.activation_dropout = getattr(args, "activation_dropout", args.dropout)
args.activation_fn = getattr(args, "activation_fn", "relu")
args.adaptive_softmax_cutoff = getattr(args, "adaptive_softmax_cutoff", None)
args.adaptive_softmax_dropout = getattr(args, "adaptive_softmax_dropout", 0)
args.tie_adaptive_weights = getattr(args, "tie_adaptive_weights", False)
args.tie_adaptive_proj = getattr(args, "tie_adaptive_proj", False)
args.adaptive_softmax_factor = getattr(args, "adaptive_softmax_factor", 4)
args.share_all_embeddings = getattr(args, "share_all_embeddings", False)
args.no_token_positional_embeddings = getattr(
args, "no_token_positional_embeddings", False
)
args.adaptive_input = getattr(args, "adaptive_input", False)
args.no_scale_embedding = getattr(args, "no_scale_embedding", False)
args.quant_noise_pq = getattr(args, "quant_noise_pq", 0)
# CTC
args.ctc_layer = getattr(args, "ctc_layer", 0)
# Conformer
args.macaron_style = getattr(args, "macaron_style", False)
args.use_cnn_module = getattr(args, "use_cnn_module", False)
args.cnn_module_kernel = getattr(args, "cnn_module_kernel", 31)
# settings for DLCL
args.use_enc_dlcl = getattr(args, "use_enc_dlcl", False)
args.use_dec_dlcl = getattr(args, "use_dec_dlcl", False)
args.init_value = getattr(args, 'init_value', 'avg')
args.weight_type = getattr(args, 'weight_type', 'scalar')
args.encoder_learnable = getattr(args, 'encoder_learnable', True)
args.normalize_embed = getattr(args, 'normalize_embed', False)
args.history_dropout = getattr(args, 'history_dropout', 0.0)
args.history_window_size = getattr(args, 'history_window_size', -1)
# Relative position encoding
args.max_encoder_relative_length = getattr(args, 'max_encoder_relative_length', -1)
args.k_only = getattr(args, 'k_only', True)
# local modeling
args.hard_mask_window = getattr(args, 'hard_mask_window', 0)
args.gauss_mask_sigma = getattr(args, 'gauss_mask_sigma', 0)
args.init_mask_weight = getattr(args, 'init_mask_weight', 0)
# interleaved dropout
args.interleave_dropout = getattr(args, "interleave_dropout", None)
args.cl_dropout = getattr(args, "cl_dropout", False)
args.cl_dropout_epoch = getattr(args, "cl_dropout_epoch", None)
args.cl_dropout_strategy = getattr(args, "cl_dropout_strategy", "linear")
# intermedia CTC
args.intermedia_ctc_layers = getattr(args, "intermedia_ctc_layers", None)
args.intermedia_adapter = getattr(args, "intermedia_adapter", None)
@register_model_architecture("s2t_ctc", "s2t_ctc_s")
def s2t_ctc_s(args):
args.encoder_embed_dim = getattr(args, "encoder_embed_dim", 256)
args.encoder_ffn_embed_dim = getattr(args, "encoder_ffn_embed_dim", 256 * 8)
args.encoder_attention_heads = getattr(args, "encoder_attention_heads", 4)
args.dropout = getattr(args, "dropout", 0.1)
base_architecture(args)
@register_model_architecture("s2t_ctc", "s2t_ctc_s_relative")
def s2t_ctc_s_relative(args):
args.max_encoder_relative_length = 100
args.k_only = True
s2t_ctc_s(args)
@register_model_architecture("s2t_ctc", "s2t_ctc_xs")
def s2t_ctc_xs(args):
args.encoder_layers = getattr(args, "encoder_layers", 6)
args.encoder_ffn_embed_dim = getattr(args, "encoder_ffn_embed_dim", 256 * 4)
args.dropout = getattr(args, "dropout", 0.3)
s2t_ctc_s(args)
@register_model_architecture("s2t_ctc", "s2t_ctc_sp")
def s2t_ctc_sp(args):
args.encoder_layers = getattr(args, "encoder_layers", 16)
s2t_ctc_s(args)
@register_model_architecture("s2t_ctc", "s2t_ctc_m")
def s2t_ctc_m(args):
args.encoder_embed_dim = getattr(args, "encoder_embed_dim", 512)
args.encoder_ffn_embed_dim = getattr(args, "encoder_ffn_embed_dim", 512 * 4)
args.encoder_attention_heads = getattr(args, "encoder_attention_heads", 8)
args.dropout = getattr(args, "dropout", 0.15)
base_architecture(args)
@register_model_architecture("s2t_ctc", "s2t_ctc_mp")
def s2t_ctc_mp(args):
args.encoder_layers = getattr(args, "encoder_layers", 16)
s2t_ctc_m(args)
@register_model_architecture("s2t_ctc", "s2t_ctc_l")
def s2t_ctc_l(args):
args.encoder_embed_dim = getattr(args, "encoder_embed_dim", 1024)
args.encoder_ffn_embed_dim = getattr(args, "encoder_ffn_embed_dim", 1024 * 4)
args.encoder_attention_heads = getattr(args, "encoder_attention_heads", 16)
args.dropout = getattr(args, "dropout", 0.2)
base_architecture(args)
@register_model_architecture("s2t_ctc", "s2t_ctc_lp")
def s2t_ctc_lp(args):
args.encoder_layers = getattr(args, "encoder_layers", 16)
s2t_ctc_l(args)
...@@ -19,7 +19,6 @@ from fairseq.models.speech_to_text import (
PDSS2TTransformerEncoder,
)
from fairseq.models.speech_to_text.modules import CTCCompressStrategy
-from fairseq.models.speech_to_text.s2t_transformer import Conv1dSubsampler
from fairseq.modules import (
FairseqDropout,
LayerNorm,
...@@ -158,13 +157,6 @@ class Adapter(nn.Module):
self.linear_adapter = nn.Sequential(
nn.Linear(embed_dim, embed_dim),
)
-elif self.adapter_type == "subsample":
-self.subsample_adaptor = Conv1dSubsampler(
-embed_dim,
-args.conv_channels,
-embed_dim,
-[int(k) for k in args.conv_kernel_sizes.split(",")],
-)
if self.adapter_type in ["embed", "context", "league", "gated_league", "gated_league2"]:
if embed_tokens is None:
...@@ -197,11 +189,6 @@ class Adapter(nn.Module):
elif self.adapter_type == "context":
out = torch.mm(distribution, self.embed_adapter.weight).view(batch, seq_len, -1)
-elif self.adapter_type == "subsample":
-representation = representation.transpose(0, 1)
-out, input_lengths = self.subsample_adaptor(representation, lengths)
-padding = lengths_to_padding_mask(input_lengths)
elif self.adapter_type == "league":
linear_out = self.linear_adapter(representation)
soft_out = torch.mm(distribution, self.embed_adapter.weight).view(batch, seq_len, -1)
...
...@@ -19,68 +19,18 @@ from fairseq.modules import (
FairseqDropout,
LayerNorm,
PositionalEmbedding,
-ConformerEncoderLayer,
+S2TTransformerEncoderLayer,
DynamicLinearCombination,
)
+from fairseq.modules.speech_to_text import (
+subsampling
+)
from torch import Tensor
logger = logging.getLogger(__name__)
-class Conv1dSubsampler(nn.Module):
-"""Convolutional subsampler: a stack of 1D convolution (along temporal
-dimension) followed by non-linear activation via gated linear units
-(https://arxiv.org/abs/1911.08460)
-Args:
-in_channels (int): the number of input channels
-mid_channels (int): the number of intermediate channels
-out_channels (int): the number of output channels
-kernel_sizes (List[int]): the kernel size for each convolutional layer
-"""
-def __init__(
-self,
-in_channels: int,
-mid_channels: int,
-out_channels: int,
-kernel_sizes: List[int] = (3, 3),
-):
-super(Conv1dSubsampler, self).__init__()
-self.n_layers = len(kernel_sizes)
-self.conv_layers = nn.ModuleList(
-nn.Conv1d(
-in_channels if i == 0 else mid_channels // 2,
-mid_channels if i < self.n_layers - 1 else out_channels * 2,
-k,
-stride=2,
-padding=k // 2,
-)
-for i, k in enumerate(kernel_sizes)
-)
-def get_out_seq_lens_tensor(self, in_seq_lens_tensor):
-out = in_seq_lens_tensor.clone()
-for _ in range(self.n_layers):
-out = ((out.float() - 1) / 2 + 1).floor().long()
-return out
-def forward(self, src_tokens, src_lengths):
-bsz, in_seq_len, _ = src_tokens.size() # B x T x (C x D)
-x = src_tokens.transpose(1, 2).contiguous() # -> B x (C x D) x T
-inner_x = []
-for conv in self.conv_layers:
-x = conv(x)
-x = nn.functional.glu(x, dim=1)
-inner_x.append(x)
-_, _, out_seq_len = x.size()
-# x = x.transpose(1, 2).transpose(0, 1).contiguous() # -> T x B x (C x D)
-out_inner_x = []
-for x in inner_x:
-out_inner_x.append(x.transpose(1, 2).transpose(0, 1).contiguous())
-return out_inner_x, self.get_out_seq_lens_tensor(src_lengths)
@register_model("s2t_transformer")
class S2TTransformerModel(FairseqEncoderDecoderModel):
"""Adapted Transformer model (https://arxiv.org/abs/1706.03762) for
...@@ -95,18 +45,43 @@ class S2TTransformerModel(FairseqEncoderDecoderModel):
@staticmethod
def add_args(parser):
"""Add model-specific arguments to the parser."""
-# input
-parser.add_argument(
-"--conv-kernel-sizes",
-type=str,
-metavar="N",
-help="kernel sizes of Conv1d subsampling layers",
-)
-parser.add_argument(
-"--conv-channels",
-type=int,
-metavar="N",
-help="# of channels in Conv1d subsampling layers",
-)
+# subsampling
+parser.add_argument(
+"--subsampling-type",
+type=str,
+help="subsampling type, like conv1d and conv2d",
+)
+parser.add_argument(
+"--subsampling-layers",
+type=int,
+help="subsampling layers",
+)
+parser.add_argument(
+"--subsampling-filter",
+type=int,
+help="subsampling filter",
+)
+parser.add_argument(
+"--subsampling-kernel",
+type=int,
+help="subsampling kernel",
+)
+parser.add_argument(
+"--subsampling-stride",
+type=int,
+help="subsampling stride",
+)
+parser.add_argument(
+"--subsampling-norm",
+type=str,
+default="none",
+help="subsampling normalization type",
+)
+parser.add_argument(
+"--subsampling-activation",
+type=str,
+default="none",
+help="subsampling activation function type",
+)
# Transformer
parser.add_argument(
...@@ -499,12 +474,7 @@ class S2TTransformerEncoder(FairseqEncoder):
self.embed_scale = 1.0
self.padding_idx = 1
-self.subsample = Conv1dSubsampler(
-args.input_feat_per_channel * args.input_channels,
-args.conv_channels,
-dim,
-[int(k) for k in args.conv_kernel_sizes.split(",")],
-)
+self.subsample = subsampling(args)
self.attn_type = getattr(args, "encoder_attention_type", "selfattn")
self.embed_positions = PositionalEmbedding(
...@@ -512,7 +482,7 @@ class S2TTransformerEncoder(FairseqEncoder):
)
self.layers = nn.ModuleList(
-[ConformerEncoderLayer(args) for _ in range(args.encoder_layers)]
+[S2TTransformerEncoderLayer(args) for _ in range(args.encoder_layers)]
)
if args.encoder_normalize_before:
...@@ -608,15 +578,8 @@ class S2TTransformerEncoder(FairseqEncoder):
# down-sampling
x, input_lengths = self.subsample(src_tokens, src_lengths)
-if type(x) == list:
-inner_x = x
-# gather cosine similarity
-if self.gather_cos_sim:
-for x in inner_x:
-cos_sim_idx += 1
-self.add_to_dict(x, dis, cos_sim_idx)
-x = inner_x[-1]
+# (B, T, D) -> (T, B, D)
+x = x.transpose(0, 1)
# embedding scaling
x = self.embed_scale * x
...
...@@ -5,6 +5,7 @@
"""isort:skip_file"""
from .squeeze_excitation import SEAttention
+from .activations import swish, Swish
from .adaptive_input import AdaptiveInput
from .adaptive_softmax import AdaptiveSoftmax
from .beamable_mm import BeamableMM
...@@ -43,7 +44,7 @@ from .transpose_last import TransposeLast
from .unfold import unfold1d
from .transformer_layer import TransformerDecoderLayer, TransformerEncoderLayer
from .vggblock import VGGBlock
-from .conformer_layer import ConformerEncoderLayer
+from .s2t_transformer_layer import S2TTransformerEncoderLayer
from .pds_layer import PDSTransformerEncoderLayer
__all__ = [
...@@ -52,7 +53,7 @@ __all__ = [
"AdaptiveSoftmax",
"BeamableMM",
"CharacterTokenEmbedder",
-"ConformerEncoderLayer",
+"S2TTransformerEncoderLayer",
"ConvolutionModule",
"ConvTBC",
"cross_entropy",
...@@ -86,6 +87,8 @@ __all__ = [
"ScalarBias",
"SEAttention",
"SinusoidalPositionalEmbedding",
+"swish",
+"Swish",
"TransformerSentenceEncoderLayer",
"TransformerSentenceEncoder",
"TransformerDecoderLayer",
...
import torch
import torch.nn as nn
def get_activation_class(activation: str, dim=None):
""" Returns the activation function corresponding to `activation` """
if activation == "relu":
return nn.ReLU()
elif activation == "gelu":
return nn.GELU()
elif activation == "glu":
assert dim is not None
return nn.GLU(dim=dim)
elif activation == "swish":
return Swish()
elif activation == "none":
return nn.Identity()
else:
raise RuntimeError("activation function {} not supported".format(activation))
def swish(x: torch.Tensor) -> torch.Tensor:
return x * torch.sigmoid(x)
class Swish(nn.Module):
def __init__(self):
super(Swish, self).__init__()
def forward(self, x):
return x * x.sigmoid()
...@@ -9,14 +9,9 @@ from typing import Optional, Tuple
import torch
from torch import nn
-from fairseq.modules.layer_norm import LayerNorm
-class Swish(nn.Module):
-"""Construct an Swish object."""
-def forward(self, x: torch.Tensor) -> torch.Tensor:
-"""Return Swish activation function."""
-return x * torch.sigmoid(x)
+from fairseq.modules.layer_norm import LayerNorm
+from fairseq.modules.activations import get_activation_class
class ConvolutionModule(nn.Module):
...@@ -73,7 +68,7 @@ class ConvolutionModule(nn.Module):
padding=0,
bias=bias,
)
-self.activation = Swish()
+self.activation = get_activation_class("swish")
def forward(
self,
...
...@@ -20,7 +20,7 @@ from fairseq.modules.quant_noise import quant_noise
from torch import Tensor
-class ConformerEncoderLayer(nn.Module):
+class S2TTransformerEncoderLayer(nn.Module):
"""Encoder layer block.
In the original paper each operation (multi-head attention or FFN) is
...
from .subsampling import *
\ No newline at end of file
import torch
import torch.nn as nn
from fairseq.modules.activations import get_activation_class
from fairseq.modules.layer_norm import LayerNorm
class TransposeLast(nn.Module):
@staticmethod
def forward(x):
return x.transpose(-1, -2).contiguous()
def get_norm(norm_type, size, transpose=False):
trans = nn.Identity()
if transpose:
trans = TransposeLast()
if norm_type == "batch1d":
return nn.Sequential(trans, nn.BatchNorm1d(size), trans)
elif norm_type == "batch2d":
return nn.Sequential(trans, nn.BatchNorm2d(size), trans)
elif norm_type == "layer":
return nn.Sequential(trans, LayerNorm(size), trans)
elif norm_type == "none":
return nn.Identity()
else:
raise RuntimeError("normalization type {} not supported".format(norm_type))
class Conv1dSubsampling(nn.Module):
"""Conv1d Subsampling Block
Args:
        num_layers: number of strided convolution layers
        in_dim: input feature dimension
        filters: list of output channels, one per convolution layer
        kernel_size: convolution kernel size
        stride: stride of each convolution layer (default: 2)
        norm: normalization type ("batch1d", "layer" or "none")
        act: activation function ("relu", "swish", "glu" or "none")
Shape:
Input: (batch_size, in_length, in_dim)
Output: (batch_size, out_length, out_dim)
"""
def __init__(self, num_layers,
in_dim, filters, kernel_size, stride=2,
norm="none", act="glu"):
        super(Conv1dSubsampling, self).__init__()
        self.stride = stride
# Assert
assert norm in ["batch1d", "layer", "none"]
assert act in ["relu", "swish", "glu", "none"]
# Layers
        # Each block: strided Conv1d -> norm -> activation.
        # With GLU the conv outputs 2x channels, so the norm must also see 2x channels
        # (GLU halves them back afterwards).
        self.layers = nn.ModuleList([nn.Sequential(
            nn.Conv1d(in_dim if layer_id == 0 else filters[layer_id - 1],
                      filters[layer_id] * 2 if act == "glu" else filters[layer_id],
                      kernel_size,
                      stride=stride,
                      padding=(kernel_size - 1) // 2),
            get_norm(norm,
                     filters[layer_id] * 2 if act == "glu" else filters[layer_id],
                     transpose=(norm == "layer")),
            get_activation_class(act, dim=1)
        ) for layer_id in range(num_layers)])
    def forward(self, x, x_len):
        # (B, T, D) -> (B, D, T)
        x = x.transpose(1, 2)

        # Layers
        for layer in self.layers:
            x = layer(x)

            # Update sequence lengths once per strided layer
            # ("same" padding with an odd kernel: out_len = floor((in_len - 1) / stride) + 1)
            if x_len is not None:
                x_len = torch.div(x_len - 1, self.stride, rounding_mode='floor') + 1

        # (B, D, T) -> (B, T, D)
        x = x.transpose(1, 2)
        return x, x_len
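A hedged usage sketch for Conv1dSubsampling, with values chosen to mirror the yaml options in this commit (two conv1d layers, GLU activation, no norm); the input dimensions below are assumed, not taken from the repo:

# Illustrative only: 80-dim features, batch of 8, 200 frames.
sub = Conv1dSubsampling(num_layers=2, in_dim=80, filters=[1024, 256],
                        kernel_size=5, stride=2, norm="none", act="glu")
feats = torch.randn(8, 200, 80)                      # (B, T, D)
lengths = torch.full((8,), 200, dtype=torch.long)
out, out_lengths = sub(feats, lengths)
assert out.shape == (8, 50, 256)                     # T reduced by stride ** num_layers = 4
assert int(out_lengths[0]) == 50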
class Conv2dSubsampling(nn.Module):
"""Conv2d Subsampling Block
Args:
        num_layers: number of strided convolution layers
        in_dim: input feature dimension
        filters: list of output channels, one per convolution layer
        kernel_size: convolution kernel size
        stride: stride of each convolution layer (default: 2)
        norm: normalization type ("batch2d" or "none")
        act: activation function ("relu", "swish", "glu" or "none")
Shape:
Input: (batch_size, in_length, in_dim)
Output: (batch_size, out_length, out_dim)
"""
def __init__(self, num_layers,
in_dim, filters, kernel_size, stride=2,
norm="none", act="glu"):
        super(Conv2dSubsampling, self).__init__()
        self.stride = stride
# Assert
assert norm in ["batch2d", "none"]
assert act in ["relu", "swish", "glu", "none"]
        # Conv 2D subsampling layers: strided Conv2d -> norm -> activation.
        # As in the 1D variant, the norm has to match the pre-GLU channel count.
        self.layers = nn.ModuleList([nn.Sequential(
            nn.Conv2d(1 if layer_id == 0 else filters[layer_id - 1],
                      filters[layer_id] * 2 if act == "glu" else filters[layer_id],
                      kernel_size,
                      stride=stride,
                      padding=(kernel_size - 1) // 2),
            get_norm(norm,
                     filters[layer_id] * 2 if act == "glu" else filters[layer_id],
                     transpose=False),
            get_activation_class(act, dim=1)
        ) for layer_id in range(num_layers)])

        # The conv stack shrinks the feature axis by `stride` per layer ("same" padding);
        # project the flattened (channels x subsampled features) back to the model dim.
        subsampled_dim = in_dim
        for _ in range(num_layers):
            subsampled_dim = (subsampled_dim - 1) // stride + 1
        self.linear = nn.Linear(filters[-1] * subsampled_dim, filters[-1])
    def forward(self, x, x_len):
        # (B, T, D) -> (B, D, T) -> (B, 1, D, T)
        x = x.transpose(1, 2).unsqueeze(dim=1)

        # Layers
        for layer in self.layers:
            x = layer(x)

            # Update sequence lengths once per strided layer
            if x_len is not None:
                x_len = torch.div(x_len - 1, self.stride, rounding_mode='floor') + 1

        # (B, C, D // S, T // S) -> (B, T // S, C * D // S)
        batch_size, channels, subsampled_dim, subsampled_length = x.size()
        x = x.reshape(batch_size, channels * subsampled_dim, subsampled_length).transpose(1, 2)
        x = self.linear(x)
        return x, x_len
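A matching hedged sketch for the 2D variant (assumed dimensions): the frame rate is reduced exactly as in the 1D case, and the subsampled feature axis is folded into the channel dimension before the final linear projection.

# Illustrative only; filters and batch shape are example values.
sub2d = Conv2dSubsampling(num_layers=2, in_dim=80, filters=[64, 256],
                          kernel_size=5, stride=2, norm="none", act="glu")
out, out_lengths = sub2d(torch.randn(8, 200, 80),
                         torch.full((8,), 200, dtype=torch.long))
assert out.shape == (8, 50, 256)
assert int(out_lengths[0]) == 50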
def subsampling(args):
subsampling_type = getattr(args, "subsampling_type", "conv1d")
layers = getattr(args, "subsampling_layers", 2)
in_dim = args.input_feat_per_channel * args.input_channels
    # one entry per conv layer: intermediate layers use subsampling_filter, the last maps to encoder_embed_dim
    filters = [getattr(args, "subsampling_filter")] * (layers - 1) + [args.encoder_embed_dim]
kernel_size = getattr(args, "subsampling_kernel", 5)
stride = getattr(args, "subsampling_stride", 2)
norm = getattr(args, "subsampling_norm", "none")
activation = getattr(args, "subsampling_activation", "none")
if subsampling_type == "conv1d":
return Conv1dSubsampling(layers, in_dim, filters, kernel_size, stride, norm, activation)
elif subsampling_type == "conv2d":
return Conv2dSubsampling(layers, in_dim, filters, kernel_size, stride, norm, activation)
else:
raise RuntimeError("Subsampling type {} not supported".format(subsampling_type))
...@@ -16,7 +16,6 @@ from fairseq.data.audio.speech_to_text_dataset import ( ...@@ -16,7 +16,6 @@ from fairseq.data.audio.speech_to_text_dataset import (
) )
from fairseq.tasks import LegacyFairseqTask, register_task from fairseq.tasks import LegacyFairseqTask, register_task
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
...@@ -143,12 +142,18 @@ class SpeechToTextTask(LegacyFairseqTask): ...@@ -143,12 +142,18 @@ class SpeechToTextTask(LegacyFairseqTask):
return super(SpeechToTextTask, self).build_model(args) return super(SpeechToTextTask, self).build_model(args)
def build_generator( def build_generator(
self, self,
models, models,
args, args,
seq_gen_cls=None, seq_gen_cls=None,
extra_gen_cls_kwargs=None, extra_gen_cls_kwargs=None,
): ):
from fairseq.models.speech_to_text import S2TCTCModel, CTCDecoder
if isinstance(models[0], S2TCTCModel):
blank_idx = self.target_dictionary.index(self.blank_symbol) if hasattr(self, 'blank_symbol') else 0
return CTCDecoder(models, args,
self.target_dictionary,
blank_idx)
if self.data_cfg.prepend_tgt_lang_tag and args.prefix_size != 1: if self.data_cfg.prepend_tgt_lang_tag and args.prefix_size != 1:
raise ValueError( raise ValueError(
'Please set "--prefix-size 1" since ' 'Please set "--prefix-size 1" since '
......
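For context, the change above routes S2TCTCModel through a CTC decoder instead of the usual sequence generator. CTC decoding in its simplest (greedy) form collapses repeated frame-level predictions and drops the blank symbol; the sketch below illustrates that idea with a hypothetical helper, not the repo's actual CTCDecoder.

# Hedged sketch of greedy CTC decoding for one utterance; `ctc_greedy_decode` is illustrative.
import torch

def ctc_greedy_decode(log_probs: torch.Tensor, blank_idx: int = 0) -> list:
    """log_probs: (T, V) frame-level log-probabilities for a single utterance."""
    pred = log_probs.argmax(dim=-1).tolist()
    # collapse consecutive repeats, then remove blanks
    collapsed = [p for i, p in enumerate(pred) if i == 0 or p != pred[i - 1]]
    return [p for p in collapsed if p != blank_idx]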
...@@ -406,7 +406,7 @@ def _main(cfg: DictConfig, output_file, translation_path=None): ...@@ -406,7 +406,7 @@ def _main(cfg: DictConfig, output_file, translation_path=None):
for item in translation_list: for item in translation_list:
f.write("{}\n".format("\t".join(item))) f.write("{}\n".format("\t".join(item)))
if models[0].decoder.gather_attn_weight: if hasattr(models[0], "decoder") and models[0].decoder.gather_attn_weight:
weights = models[0].decoder.attn_weights weights = models[0].decoder.attn_weights
sort_weights = sorted(weights.items(), key=lambda k: k[0]) sort_weights = sorted(weights.items(), key=lambda k: k[0])
num = sum([k[1] for k in sort_weights]) num = sum([k[1] for k in sort_weights])
...@@ -419,8 +419,6 @@ def _main(cfg: DictConfig, output_file, translation_path=None): ...@@ -419,8 +419,6 @@ def _main(cfg: DictConfig, output_file, translation_path=None):
with open("cos_sim", "w", encoding="utf-8") as fw: with open("cos_sim", "w", encoding="utf-8") as fw:
for layer, sim in cos_sim.items(): for layer, sim in cos_sim.items():
sim = sum(sim) / len(sim) * 100 sim = sum(sim) / len(sim) * 100
# if layer >= 10:
# layer -= 10
fw.write("%d\t%f\n" % (layer, sim)) fw.write("%d\t%f\n" % (layer, sim))
return scorer return scorer
......