Commit 0c7e71c7 by xuchen

Happy Valentine's Day!

I reformatted the code, implemented intermediate CTC with an adapter, and added CTC decoding.
parent f286e56c
@@ -11,8 +11,14 @@ lr: 2e-3
criterion: label_smoothed_cross_entropy_with_ctc
label_smoothing: 0.1
-conv-kernel-sizes: 5,5
-conv-channels: 1024
+subsampling-type: conv1d
+subsampling-layers: 2
+subsampling-filter: 1024
+subsampling-kernel: 5
+subsampling-stride: 2
+subsampling-norm: none
+subsampling-activation: glu
dropout: 0.1
activation-fn: relu
encoder-embed-dim: 256
......
ctc-weight: 0.3
+zero_infinity: True
+post-process: sentencepiece
\ No newline at end of file
+#arch: pdss2t_transformer_s_8
arch: s2t_transformer_s
+#pds-ctc: 1_1_1_1
+#pds-ctc: 0_0_0_1
+intermedia-ctc-layers: 6,8,10
+intermedia-adapter: league
+intermedia-ctc-weight: 0.2
+ctc-self-distill-weight: 1
+ctc-weight: 0.1
share-decoder-input-output-embed: True
optimizer: adam
clip-norm: 10.0
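The ctc-weight, intermedia-ctc-weight and ctc-self-distill-weight values above feed into a single training objective. The exact scaling used by the criterion is not visible in this diff, so the sketch below only illustrates one plausible weighted-sum formulation; every name in it is illustrative.

    def combined_loss(ce_loss, ctc_loss, inter_ctc_loss, self_distill_loss,
                      ctc_weight=0.1, inter_ctc_weight=0.2, self_distill_weight=1.0):
        # Illustrative only: decoder cross-entropy, CTC on the top encoder layer,
        # averaged CTC after layers 6/8/10, and a self-distillation term.
        return ((1.0 - ctc_weight - inter_ctc_weight) * ce_loss
                + ctc_weight * ctc_loss
                + inter_ctc_weight * inter_ctc_loss
                + self_distill_weight * self_distill_loss)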
@@ -20,8 +11,14 @@ lr: 2e-3
criterion: label_smoothed_cross_entropy_with_ctc
label_smoothing: 0.1
-conv-kernel-sizes: 5,5
-conv-channels: 1024
+subsampling-type: conv1d
+subsampling-layers: 2
+subsampling-filter: 1024
+subsampling-kernel: 5
+subsampling-stride: 2
+subsampling-norm: none
+subsampling-activation: glu
dropout: 0.1
activation-fn: relu
encoder-embed-dim: 256
......
arch: s2t_ctc
optimizer: adam
clip-norm: 10.0
lr-scheduler: inverse_sqrt
warmup-init-lr: 1e-7
warmup-updates: 10000
lr: 2e-3
#adam_betas: (0.9,0.98)
criterion: ctc
zero_infinity: True
post-process: sentencepiece
label_smoothing: 0.1
conv-kernel-sizes: 5,5
conv-channels: 1024
dropout: 0.1
activation-fn: relu
encoder-embed-dim: 256
encoder-ffn-embed-dim: 2048
encoder-layers: 12
encoder-attention-heads: 4
attention-dropout: 0.1
activation-dropout: 0.1
\ No newline at end of file
@@ -3,7 +3,7 @@
gpu_num=1
data_dir=
-test_subset=(tst-COMMON)
+test_subset=(test)
exp_name=
if [ "$#" -eq 1 ]; then
......
@@ -37,7 +37,7 @@ dataset=libri_trans
task=speech_to_text
vocab_type=unigram
vocab_size=1000
-speed_perturb=1
+speed_perturb=0
lcrm=1
tokenizer=0
use_raw_audio=1
......
arch: s2t_ctc
optimizer: adam
clip-norm: 10.0
lr-scheduler: inverse_sqrt
warmup-init-lr: 1e-7
warmup-updates: 10000
lr: 2e-3
#adam_betas: (0.9,0.98)
criterion: ctc
post-process: sentencepiece
conv-kernel-sizes: 5,5
conv-channels: 704
dropout: 0.1
activation-fn: relu
encoder-embed-dim: 176
encoder-ffn-embed-dim: 704
encoder-layers: 16
encoder-attention-heads: 4
macaron-style: True
use-cnn-module: True
cnn-module-kernel: 31
encoder-attention-type: rel_selfattn
\ No newline at end of file
arch: s2t_ctc
optimizer: adam
clip-norm: 10.0
lr-scheduler: inverse_sqrt
warmup-init-lr: 1e-7
warmup-updates: 10000
lr: 2e-3
#adam_betas: (0.9,0.98)
criterion: ctc
post-process: sentencepiece
conv-kernel-sizes: 5,5
conv-channels: 1024
dropout: 0.1
activation-fn: relu
encoder-embed-dim: 256
encoder-ffn-embed-dim: 2048
encoder-layers: 12
encoder-attention-heads: 4
\ No newline at end of file
arch: s2t_ctc
optimizer: adam
clip-norm: 10.0
lr-scheduler: inverse_sqrt
warmup-init-lr: 1e-7
warmup-updates: 10000
lr: 2e-3
#adam_betas: (0.9,0.98)
criterion: ctc
zero_infinity: True
post-process: sentencepiece
label_smoothing: 0.1
conv-kernel-sizes: 5,5
conv-channels: 1024
dropout: 0.1
activation-fn: relu
encoder-embed-dim: 256
encoder-ffn-embed-dim: 2048
encoder-layers: 12
encoder-attention-heads: 4
attention-dropout: 0.1
activation-dropout: 0.1
\ No newline at end of file
@@ -43,7 +43,7 @@ tokenizer=1
use_specific_dict=1
subword=1
-specific_prefix=subword32000_share_tok
+specific_prefix=subword32000_share
specific_dir=${root_dir}/data/mustc/st
src_vocab_prefix=spm_unigram10000_st_share
tgt_vocab_prefix=spm_unigram10000_st_share
......
@@ -108,19 +108,13 @@ class CtcCriterion(FairseqCriterion):
net_output, log_probs=True
).contiguous() # (T, B, C) from the encoder
-if "src_lengths" in sample["net_input"]:
-input_lengths = sample["net_input"]["src_lengths"]
-else:
-non_padding_mask = ~net_output["padding_mask"]
+non_padding_mask = ~net_output["encoder_padding_mask"][0]
input_lengths = non_padding_mask.long().sum(-1)
pad_mask = (sample["target"] != self.pad_idx) & (
sample["target"] != self.eos_idx
)
targets_flat = sample["target"].masked_select(pad_mask)
-if "target_lengths" in sample:
-target_lengths = sample["target_lengths"]
-else:
target_lengths = pad_mask.sum(-1)
with torch.backends.cudnn.flags(enabled=False):
......
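The change above derives the CTC input lengths from the encoder padding mask, which already reflects subsampling, instead of the raw src_lengths. A minimal standalone illustration of that computation, assuming the usual fairseq convention that the mask is True at padded positions:

    import torch

    # Two utterances with 2 and 3 valid subsampled frames (True marks padding).
    encoder_padding_mask = torch.tensor([[False, False, True],
                                         [False, False, False]])
    non_padding_mask = ~encoder_padding_mask
    input_lengths = non_padding_mask.long().sum(-1)  # tensor([2, 3])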
@@ -5,7 +5,7 @@
from .berard import * # noqa
from .convtransformer import * # noqa
+from .s2t_ctc import *
from .s2t_transformer import * # noqa
-from .s2t_conformer import * # noqa
from .pdss2t_transformer import * # noqa
from .s2t_sate import * # noqa
@@ -663,7 +663,7 @@ class PDSS2TTransformerEncoder(FairseqEncoder):
dropout=args.dropout,
need_layernorm=True)
if task.source_dictionary == task.target_dictionary and embed_tokens is not None:
-self.ctc.ctc_projection.weight = embed_tokens.weight
+ctc.ctc_projection.weight = embed_tokens.weight
inter_ctc_module = ctc
else:
......
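The weight-tying line above (ctc.ctc_projection.weight = embed_tokens.weight) shares the intermediate CTC projection with the embedding when source and target dictionaries match. A minimal illustration of what that sharing means (the sizes are made up for the example):

    import torch.nn as nn

    vocab_size, embed_dim = 1000, 256
    embed_tokens = nn.Embedding(vocab_size, embed_dim)
    ctc_projection = nn.Linear(embed_dim, vocab_size, bias=False)
    # After tying, CTC logits and token embeddings share one (vocab_size, embed_dim)
    # matrix, so both objectives update the same parameters.
    ctc_projection.weight = embed_tokens.weight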
@@ -19,7 +19,6 @@ from fairseq.models.speech_to_text import (
PDSS2TTransformerEncoder,
)
from fairseq.models.speech_to_text.modules import CTCCompressStrategy
-from fairseq.models.speech_to_text.s2t_transformer import Conv1dSubsampler
from fairseq.modules import (
FairseqDropout,
LayerNorm,
@@ -158,13 +157,6 @@ class Adapter(nn.Module):
self.linear_adapter = nn.Sequential(
nn.Linear(embed_dim, embed_dim),
)
-elif self.adapter_type == "subsample":
-self.subsample_adaptor = Conv1dSubsampler(
-embed_dim,
-args.conv_channels,
-embed_dim,
-[int(k) for k in args.conv_kernel_sizes.split(",")],
-)
if self.adapter_type in ["embed", "context", "league", "gated_league", "gated_league2"]:
if embed_tokens is None:
@@ -197,11 +189,6 @@ class Adapter(nn.Module):
elif self.adapter_type == "context":
out = torch.mm(distribution, self.embed_adapter.weight).view(batch, seq_len, -1)
-elif self.adapter_type == "subsample":
-representation = representation.transpose(0, 1)
-out, input_lengths = self.subsample_adaptor(representation, lengths)
-padding = lengths_to_padding_mask(input_lengths)
elif self.adapter_type == "league":
linear_out = self.linear_adapter(representation)
soft_out = torch.mm(distribution, self.embed_adapter.weight).view(batch, seq_len, -1)
......
@@ -19,68 +19,18 @@ from fairseq.modules import (
FairseqDropout,
LayerNorm,
PositionalEmbedding,
-ConformerEncoderLayer,
+S2TTransformerEncoderLayer,
DynamicLinearCombination,
)
+from fairseq.modules.speech_to_text import (
+subsampling
+)
from torch import Tensor
logger = logging.getLogger(__name__)
-class Conv1dSubsampler(nn.Module):
-"""Convolutional subsampler: a stack of 1D convolution (along temporal
-dimension) followed by non-linear activation via gated linear units
-(https://arxiv.org/abs/1911.08460)
-Args:
-in_channels (int): the number of input channels
-mid_channels (int): the number of intermediate channels
-out_channels (int): the number of output channels
-kernel_sizes (List[int]): the kernel size for each convolutional layer
-"""
-def __init__(
-self,
-in_channels: int,
-mid_channels: int,
-out_channels: int,
-kernel_sizes: List[int] = (3, 3),
-):
-super(Conv1dSubsampler, self).__init__()
-self.n_layers = len(kernel_sizes)
-self.conv_layers = nn.ModuleList(
-nn.Conv1d(
-in_channels if i == 0 else mid_channels // 2,
-mid_channels if i < self.n_layers - 1 else out_channels * 2,
-k,
-stride=2,
-padding=k // 2,
-)
-for i, k in enumerate(kernel_sizes)
-)
-def get_out_seq_lens_tensor(self, in_seq_lens_tensor):
-out = in_seq_lens_tensor.clone()
-for _ in range(self.n_layers):
-out = ((out.float() - 1) / 2 + 1).floor().long()
-return out
-def forward(self, src_tokens, src_lengths):
-bsz, in_seq_len, _ = src_tokens.size() # B x T x (C x D)
-x = src_tokens.transpose(1, 2).contiguous() # -> B x (C x D) x T
-inner_x = []
-for conv in self.conv_layers:
-x = conv(x)
-x = nn.functional.glu(x, dim=1)
-inner_x.append(x)
-_, _, out_seq_len = x.size()
-# x = x.transpose(1, 2).transpose(0, 1).contiguous() # -> T x B x (C x D)
-out_inner_x = []
-for x in inner_x:
-out_inner_x.append(x.transpose(1, 2).transpose(0, 1).contiguous())
-return out_inner_x, self.get_out_seq_lens_tensor(src_lengths)
@register_model("s2t_transformer")
class S2TTransformerModel(FairseqEncoderDecoderModel):
"""Adapted Transformer model (https://arxiv.org/abs/1706.03762) for
@@ -95,18 +45,43 @@ class S2TTransformerModel(FairseqEncoderDecoderModel):
@staticmethod
def add_args(parser):
"""Add model-specific arguments to the parser."""
-# input
+# subsampling
parser.add_argument(
-"--conv-kernel-sizes",
+"--subsampling-type",
type=str,
-metavar="N",
-help="kernel sizes of Conv1d subsampling layers",
+help="subsampling type, like conv1d and conv2d",
)
parser.add_argument(
-"--conv-channels",
+"--subsampling-layers",
type=int,
-metavar="N",
-help="# of channels in Conv1d subsampling layers",
+help="subsampling layers",
)
+parser.add_argument(
+"--subsampling-filter",
+type=int,
+help="subsampling filter",
+)
+parser.add_argument(
+"--subsampling-kernel",
+type=int,
+help="subsampling kernel",
+)
+parser.add_argument(
+"--subsampling-stride",
+type=int,
+help="subsampling stride",
+)
+parser.add_argument(
+"--subsampling-norm",
+type=str,
+default="none",
+help="subsampling normalization type",
+)
+parser.add_argument(
+"--subsampling-activation",
+type=str,
+default="none",
+help="subsampling activation function type",
+)
# Transformer
parser.add_argument(
@@ -499,12 +474,7 @@ class S2TTransformerEncoder(FairseqEncoder):
self.embed_scale = 1.0
self.padding_idx = 1
-self.subsample = Conv1dSubsampler(
-args.input_feat_per_channel * args.input_channels,
-args.conv_channels,
-dim,
-[int(k) for k in args.conv_kernel_sizes.split(",")],
-)
+self.subsample = subsampling(args)
self.attn_type = getattr(args, "encoder_attention_type", "selfattn")
self.embed_positions = PositionalEmbedding(
@@ -512,7 +482,7 @@ class S2TTransformerEncoder(FairseqEncoder):
)
self.layers = nn.ModuleList(
-[ConformerEncoderLayer(args) for _ in range(args.encoder_layers)]
+[S2TTransformerEncoderLayer(args) for _ in range(args.encoder_layers)]
)
if args.encoder_normalize_before:
@@ -608,15 +578,8 @@ class S2TTransformerEncoder(FairseqEncoder):
# down-sampling
x, input_lengths = self.subsample(src_tokens, src_lengths)
-if type(x) == list:
-inner_x = x
-# gather cosine similarity
-if self.gather_cos_sim:
-for x in inner_x:
-cos_sim_idx += 1
-self.add_to_dict(x, dis, cos_sim_idx)
-x = inner_x[-1]
+# (B, T, D) -> (T, B, D)
+x = x.transpose(0, 1)
# embedding scaling
x = self.embed_scale * x
......
@@ -5,6 +5,7 @@
"""isort:skip_file"""
from .squeeze_excitation import SEAttention
+from .activations import swish, Swish
from .adaptive_input import AdaptiveInput
from .adaptive_softmax import AdaptiveSoftmax
from .beamable_mm import BeamableMM
@@ -43,7 +44,7 @@ from .transpose_last import TransposeLast
from .unfold import unfold1d
from .transformer_layer import TransformerDecoderLayer, TransformerEncoderLayer
from .vggblock import VGGBlock
-from .conformer_layer import ConformerEncoderLayer
+from .s2t_transformer_layer import S2TTransformerEncoderLayer
from .pds_layer import PDSTransformerEncoderLayer
__all__ = [
@@ -52,7 +53,7 @@ __all__ = [
"AdaptiveSoftmax",
"BeamableMM",
"CharacterTokenEmbedder",
-"ConformerEncoderLayer",
+"S2TTransformerEncoderLayer",
"ConvolutionModule",
"ConvTBC",
"cross_entropy",
@@ -86,6 +87,8 @@ __all__ = [
"ScalarBias",
"SEAttention",
"SinusoidalPositionalEmbedding",
+"swish",
+"Swish",
"TransformerSentenceEncoderLayer",
"TransformerSentenceEncoder",
"TransformerDecoderLayer",
......
import torch
import torch.nn as nn
def get_activation_class(activation: str, dim=None):
""" Returns the activation function corresponding to `activation` """
if activation == "relu":
return nn.ReLU()
elif activation == "gelu":
return nn.GELU()
elif activation == "glu":
assert dim is not None
return nn.GLU(dim=dim)
elif activation == "swish":
return Swish()
elif activation == "none":
return nn.Identity()
else:
raise RuntimeError("activation function {} not supported".format(activation))
def swish(x: torch.Tensor) -> torch.Tensor:
return x * torch.sigmoid(x)
class Swish(nn.Module):
def __init__(self):
super(Swish, self).__init__()
def forward(self, x):
return x * x.sigmoid()
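Note that GLU halves the size of the gated dimension, which is why the subsampling convolutions introduced later in this commit double their output channels when the activation is glu. A quick check, shown here only as a usage example:

    import torch

    glu = get_activation_class("glu", dim=1)
    x = torch.randn(4, 512, 50)   # (B, 2 * C, T), e.g. a conv output with doubled channels
    print(glu(x).shape)           # torch.Size([4, 256, 50])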
@@ -9,14 +9,9 @@ from typing import Optional, Tuple
import torch
from torch import nn
from fairseq.modules.layer_norm import LayerNorm
+from fairseq.modules.activations import get_activation_class
-class Swish(nn.Module):
-"""Construct an Swish object."""
-def forward(self, x: torch.Tensor) -> torch.Tensor:
-"""Return Swish activation function."""
-return x * torch.sigmoid(x)
class ConvolutionModule(nn.Module):
@@ -73,7 +68,7 @@ class ConvolutionModule(nn.Module):
padding=0,
bias=bias,
)
-self.activation = Swish()
+self.activation = get_activation_class("swish")
def forward(
self,
......
@@ -20,7 +20,7 @@ from fairseq.modules.quant_noise import quant_noise
from torch import Tensor
-class ConformerEncoderLayer(nn.Module):
+class S2TTransformerEncoderLayer(nn.Module):
"""Encoder layer block.
In the original paper each operation (multi-head attention or FFN) is
......
from .subsampling import *
\ No newline at end of file
import torch
import torch.nn as nn
from fairseq.modules.activations import Swish
from fairseq.modules.layer_norm import LayerNorm
def get_activation_class(activation: str, dim=None):
""" Returns the activation function corresponding to `activation` """
if activation == "relu":
return nn.ReLU()
elif activation == "gelu":
return nn.GELU()
elif activation == "glu":
assert dim is not None
return nn.GLU(dim=dim)
elif activation == "swish":
return Swish()
elif activation == "none":
return nn.Identity()
else:
raise RuntimeError("activation function {} not supported".format(activation))
class TransposeLast(nn.Module):
@staticmethod
def forward(x):
return x.transpose(-1, -2).contiguous()
def get_norm(norm_type, size, transpose=False):
trans = nn.Identity()
if transpose:
trans = TransposeLast()
if norm_type == "batch1d":
return nn.Sequential(trans, nn.BatchNorm1d(size), trans)
elif norm_type == "batch2d":
return nn.Sequential(trans, nn.BatchNorm2d(size), trans)
elif norm_type == "layer":
return nn.Sequential(trans, LayerNorm(size), trans)
elif norm_type == "none":
return nn.Identity()
else:
raise RuntimeError("normalization type {} not supported".format(norm_type))
class Conv1dSubsampling(nn.Module):
"""Conv1d Subsampling Block
Args:
num_layers: number of strided convolution layers
in_dim: input feature dimension
filters: list of convolution layers filters
kernel_size: convolution kernel size
norm: normalization
act: activation function
Shape:
Input: (batch_size, in_length, in_dim)
Output: (batch_size, out_length, out_dim)
"""
def __init__(self, num_layers,
in_dim, filters, kernel_size, stride=2,
norm="none", act="glu"):
super(Conv1dSubsampling, self).__init__()
# Assert
assert norm in ["batch1d", "layer", "none"]
assert act in ["relu", "swish", "glu", "none"]
# Layers
self.layers = nn.ModuleList([nn.Sequential(
nn.Conv1d(in_dim if layer_id == 0 else filters[layer_id - 1],
filters[layer_id] * 2 if act == "glu" else filters[layer_id],
kernel_size,
stride=stride,
padding=(kernel_size - 1) // 2),
get_norm(norm, filters[layer_id], transpose=True if norm == "layer" else False),
get_activation_class(act, dim=1)
) for layer_id in range(num_layers)])
def forward(self, x, x_len):
# (B, T, D) -> (B, D, T)
x = x.transpose(1, 2)
# Layers
for layer in self.layers:
x = layer(x)
# Update Sequence Lengths
if x_len is not None:
x_len = torch.div(x_len - 1, 2, rounding_mode='floor') + 1
x = x.transpose(1, 2)
return x, x_len
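Each stride-2 layer maps a length L to floor((L - 1) / 2) + 1, so with the two layers configured above 100 input frames become 50 and then 25. A small usage sketch; the 80-dimensional features and the filter sizes follow the configs in this commit but are otherwise assumptions:

    import torch

    sub = Conv1dSubsampling(num_layers=2, in_dim=80, filters=[1024, 256],
                            kernel_size=5, stride=2, norm="none", act="glu")
    feats = torch.randn(8, 100, 80)         # (B, T, D) filterbank features
    feat_lengths = torch.full((8,), 100)
    out, out_lengths = sub(feats, feat_lengths)
    print(out.shape)                        # torch.Size([8, 25, 256])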
class Conv2dSubsampling(nn.Module):
"""Conv2d Subsampling Block
Args:
num_layers: number of strided convolution layers
filters: list of convolution layers filters
kernel_size: convolution kernel size
norm: normalization
act: activation function
Shape:
Input: (batch_size, in_length, in_dim)
Output: (batch_size, out_length, out_dim)
"""
def __init__(self, num_layers,
in_dim, filters, kernel_size, stride=2,
norm="none", act="glu"):
super(Conv2dSubsampling, self).__init__()
# Assert
assert norm in ["batch2d", "none"]
assert act in ["relu", "swish", "glu", "none"]
# Conv 2D Subsampling Layers
self.layers = nn.ModuleList([nn.Sequential(
nn.Conv2d(1 if layer_id == 0 else filters[layer_id - 1],
filters[layer_id] * 2 if act =="glu" else filters[layer_id],
kernel_size,
stride=stride,
padding=(kernel_size - 1) // 2),
get_norm(norm, filters[layer_id], transpose=True if norm == "layer" else False),
get_activation_class(act, dim=1)
) for layer_id in range(num_layers)])
self.linear = nn.Linear(filters[-1] * in_dim // 2 ** num_layers, filters[-1])
def forward(self, x, x_len):
# (B, T, D) -> (B, D, T) -> (B, 1, D, T)
x = x.transpose(1, 2).unsqueeze(dim=1)
# Layers
for layer in self.layers:
x = layer(x)
# Update Sequence Lengths
if x_len is not None:
x_len = torch.div(x_len - 1, 2, rounding_mode='floor') + 1
# (B, C, D // S, T // S) -> (B, C * D // S, T // S)
batch_size, channels, subsampled_dim, subsampled_length = x.size()
x = x.reshape(batch_size, channels * subsampled_dim, subsampled_length).transpose(1, 2)
x = self.linear(x)
return x, x_len
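For example, with in_dim = 80 (a typical filterbank size, assumed here), num_layers = 2 and filters[-1] = 256, the flattened features entering the linear layer have 256 * 80 // 2 ** 2 = 5120 dimensions, which the layer projects back down to 256.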
def subsampling(args):
subsampling_type = getattr(args, "subsampling_type", "conv1d")
layers = getattr(args, "subsampling_layers", 2)
in_dim = args.input_feat_per_channel * args.input_channels
filters = [getattr(args, "subsampling_filter")] + [args.encoder_embed_dim]
kernel_size = getattr(args, "subsampling_kernel", 5)
stride = getattr(args, "subsampling_stride", 2)
norm = getattr(args, "subsampling_norm", "none")
activation = getattr(args, "subsampling_activation", "none")
if subsampling_type == "conv1d":
return Conv1dSubsampling(layers, in_dim, filters, kernel_size, stride, norm, activation)
elif subsampling_type == "conv2d":
return Conv2dSubsampling(layers, in_dim, filters, kernel_size, stride, norm, activation)
else:
raise RuntimeError("Subsampling type {} not supported".format(subsampling_type))
@@ -16,7 +16,6 @@ from fairseq.data.audio.speech_to_text_dataset import (
)
from fairseq.tasks import LegacyFairseqTask, register_task
logger = logging.getLogger(__name__)
@@ -149,6 +148,12 @@ class SpeechToTextTask(LegacyFairseqTask):
seq_gen_cls=None,
extra_gen_cls_kwargs=None,
):
+from fairseq.models.speech_to_text import S2TCTCModel, CTCDecoder
+if isinstance(models[0], S2TCTCModel):
+blank_idx = self.target_dictionary.index(self.blank_symbol) if hasattr(self, 'blank_symbol') else 0
+return CTCDecoder(models, args,
+self.target_dictionary,
+blank_idx)
if self.data_cfg.prepend_tgt_lang_tag and args.prefix_size != 1:
raise ValueError(
'Please set "--prefix-size 1" since '
......
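The CTCDecoder used above is not part of this diff. As a rough illustration of what greedy CTC decoding does (take the argmax per frame, collapse repeats, drop blanks), assuming blank index 0:

    import torch

    def ctc_greedy_decode(log_probs, blank_idx=0):
        """log_probs: (T, B, V) log-probabilities from the encoder's CTC head."""
        preds = log_probs.argmax(dim=-1).transpose(0, 1)  # (B, T)
        hyps = []
        for seq in preds:
            tokens, prev = [], None
            for t in seq.tolist():
                if t != prev and t != blank_idx:          # collapse repeats, drop blanks
                    tokens.append(t)
                prev = t
            hyps.append(tokens)
        return hyps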
@@ -406,7 +406,7 @@ def _main(cfg: DictConfig, output_file, translation_path=None):
for item in translation_list:
f.write("{}\n".format("\t".join(item)))
-if models[0].decoder.gather_attn_weight:
+if hasattr(models[0], "decoder") and models[0].decoder.gather_attn_weight:
weights = models[0].decoder.attn_weights
sort_weights = sorted(weights.items(), key=lambda k: k[0])
num = sum([k[1] for k in sort_weights])
@@ -419,8 +419,6 @@ def _main(cfg: DictConfig, output_file, translation_path=None):
with open("cos_sim", "w", encoding="utf-8") as fw:
for layer, sim in cos_sim.items():
sim = sum(sim) / len(sim) * 100
-# if layer >= 10:
-# layer -= 10
fw.write("%d\t%f\n" % (layer, sim))
return scorer
......