Commit 55702466 by xuchen

Add a pure CTC architecture (s2t_ctc) for the PDS method

parent bab6c520
arch: s2t_transformer_s arch: s2t_ctc
encoder-type: pds
#arch: pdss2t_transformer_s_8
encoder-embed-dim: 256
pds-stages: 4
ctc-layer: 12
pds-layers: 3_3_3_3
pds-ratios: 2_2_1_2
pds-fusion: True
pds-fusion-method: all_conv
pds-embed-dims: 256_256_256_256
pds-ds-method: conv
pds-embed-norm: True
pds-position-embed: 1_1_1_1
pds-kernel-sizes: 5_5_5_5
pds-ffn-ratios: 8_8_8_8
pds-attn-heads: 4_4_4_4
share-decoder-input-output-embed: True share-decoder-input-output-embed: True
optimizer: adam optimizer: adam
clip-norm: 10.0 clip-norm: 10.0
lr-scheduler: inverse_sqrt lr-scheduler: inverse_sqrt
warmup-init-lr: 1e-7 warmup-init-lr: 1e-7
warmup-updates: 10000 warmup-updates: 10000
weight-decay: 1e-6
lr: 2e-3 lr: 2e-3
adam_betas: (0.9,0.98) adam_betas: (0.9,0.98)
criterion: label_smoothed_cross_entropy_with_ctc criterion: ctc
label_smoothing: 0.1
subsampling-type: conv1d
subsmapling-layers: 2
subsampling-filter: 1024
subsampling-kernel: 5
subsampling-stride: 2
subsampling-norm: none
subsampling-activation: glu
ctc-weight: 0.2
intermedia-ctc-layers: 6,9
intermedia-adapter: league
intermedia-ctc-weight: 0.1
intermedia-drop-prob: 0.5
ctc-self-distill-weight: 0
post-process: sentencepiece
dropout: 0.1 dropout: 0.1
activation-fn: relu activation-fn: relu
encoder-embed-dim: 256
encoder-ffn-embed-dim: 2048 encoder-ffn-embed-dim: 2048
encoder-layers: 12 encoder-layers: 12
decoder-layers: 6 decoder-layers: 6
...@@ -40,8 +38,3 @@ encoder-attention-heads: 4 ...@@ -40,8 +38,3 @@ encoder-attention-heads: 4
decoder-embed-dim: 256 decoder-embed-dim: 256
decoder-ffn-embed-dim: 2048 decoder-ffn-embed-dim: 2048
decoder-attention-heads: 4 decoder-attention-heads: 4
attention-dropout: 0.1
activation-dropout: 0.1
#load-pretrained-encoder-from:
#load-pretrained-decoder-from:
\ No newline at end of file
# Training configuration for the pure CTC architecture (s2t_ctc) with a PDS
# encoder. Model and criterion are both CTC-only: see `arch` and `criterion`.

# --- Model / encoder selection ---
arch: s2t_ctc
encoder-type: pds
encoder-embed-dim: 256

# --- PDS encoder layout ---
# The underscore-separated values below each carry one entry per stage
# (four entries, matching pds-stages: 4); the encoder splits them on "_"
# and converts them to integers.
pds-stages: 4
# Position of the CTC loss.
ctc-layer: 12
# 3 + 3 + 3 + 3 = 12, which equals encoder-layers below.
pds-layers: 3_3_3_3
# presumably the per-stage down-sampling ratio (used as the stride of the
# down-sampling module) -- TODO confirm
pds-ratios: 2_2_1_2
pds-fusion: True
pds-fusion-method: all_conv
pds-embed-dims: 256_256_256_256
pds-ds-method: conv
pds-embed-norm: True
# Whether to use the position embedding before each stage's encoding
# (1 = use it).
pds-position-embed: 1_1_1_1
pds-kernel-sizes: 5_5_5_5
# Each stage's feed-forward dimension is its embedding dimension times this ratio.
pds-ffn-ratios: 8_8_8_8
pds-attn-heads: 4_4_4_4

# NOTE(review): decoder-related options are kept in this file although the
# architecture is CTC-only -- confirm whether s2t_ctc actually reads them.
share-decoder-input-output-embed: True

# --- Optimization ---
optimizer: adam
clip-norm: 10.0
lr-scheduler: inverse_sqrt
warmup-init-lr: 1e-7
warmup-updates: 10000
lr: 2e-3
# NOTE(review): this key uses an underscore while the other keys are
# hyphenated -- confirm the config loader accepts this spelling.
adam_betas: (0.9,0.98)

# --- Training criterion ---
# Plain CTC loss; no ctc-weight or label-smoothing settings are given here.
criterion: ctc

# --- Transformer hyper-parameters ---
dropout: 0.1
activation-fn: relu
encoder-ffn-embed-dim: 2048
encoder-layers: 12
decoder-layers: 6
encoder-attention-heads: 4
decoder-embed-dim: 256
decoder-ffn-embed-dim: 2048
decoder-attention-heads: 4

# Optional checkpoints to initialise from; left disabled.
#load-pretrained-encoder-from:
#load-pretrained-decoder-from:
\ No newline at end of file
...@@ -137,7 +137,8 @@ class AudioDataset(Dataset): ...@@ -137,7 +137,8 @@ class AudioDataset(Dataset):
for i, segment in enumerate(seg_group): for i, segment in enumerate(seg_group):
offset = int(float(segment["offset"]) * sample_rate) offset = int(float(segment["offset"]) * sample_rate)
n_frames = int(float(segment["duration"]) * sample_rate) n_frames = int(float(segment["duration"]) * sample_rate)
_id = f"{split}_{wav_path.stem}_{i}" # _id = f"{split}_{wav_path.stem}_{i}"
_id = f"{wav_path.stem}_{i}"
item = dict() item = dict()
item["audio"] = wav_path.as_posix() item["audio"] = wav_path.as_posix()
...@@ -263,8 +264,12 @@ def process(args): ...@@ -263,8 +264,12 @@ def process(args):
utt_id = item['id'] utt_id = item['id']
features_path = (feature_root / f"{utt_id}.npy").as_posix() features_path = (feature_root / f"{utt_id}.npy").as_posix()
tag_features_path = (feature_root / f"{split}_{utt_id}.npy").as_posix()
if os.path.exists(features_path): if os.path.exists(tag_features_path):
continue
if os.path.exists(features_path) and not os.path.exists(tag_features_path):
shutil.move(features_path, tag_features_path)
continue continue
waveform, sample_rate, _ = dataset.get(idx, need_waveform=True) waveform, sample_rate, _ = dataset.get(idx, need_waveform=True)
......
...@@ -20,8 +20,6 @@ from fairseq.tasks import FairseqTask ...@@ -20,8 +20,6 @@ from fairseq.tasks import FairseqTask
from fairseq.logging.meters import safe_round from fairseq.logging.meters import safe_round
@dataclass @dataclass
class CtcCriterionConfig(FairseqDataclass): class CtcCriterionConfig(FairseqDataclass):
zero_infinity: bool = field( zero_infinity: bool = field(
...@@ -30,7 +28,7 @@ class CtcCriterionConfig(FairseqDataclass): ...@@ -30,7 +28,7 @@ class CtcCriterionConfig(FairseqDataclass):
) )
sentence_avg: bool = II("optimization.sentence_avg") sentence_avg: bool = II("optimization.sentence_avg")
post_process: str = field( post_process: str = field(
default="letter", default="sentencepiece",
metadata={ metadata={
"help": "how to post process predictions into words. can be letter, " "help": "how to post process predictions into words. can be letter, "
"wordpiece, BPE symbols, etc. " "wordpiece, BPE symbols, etc. "
......
...@@ -47,6 +47,8 @@ __all__ = [ ...@@ -47,6 +47,8 @@ __all__ = [
"FairseqLanguageModel", "FairseqLanguageModel",
"FairseqModel", "FairseqModel",
"FairseqMultiModel", "FairseqMultiModel",
"register_model",
"register_model_architecture"
] ]
......
...@@ -19,6 +19,8 @@ from fairseq.modules import ( ...@@ -19,6 +19,8 @@ from fairseq.modules import (
FairseqDropout, FairseqDropout,
LayerNorm, LayerNorm,
PositionalEmbedding, PositionalEmbedding,
RelPositionalEncoding,
LegacyRelPositionalEncoding,
PDSTransformerEncoderLayer, PDSTransformerEncoderLayer,
DownSampleConvolutionModule DownSampleConvolutionModule
) )
...@@ -137,19 +139,6 @@ class PDSS2TTransformerModel(S2TTransformerModel): ...@@ -137,19 +139,6 @@ class PDSS2TTransformerModel(S2TTransformerModel):
@staticmethod @staticmethod
def add_args(parser): def add_args(parser):
"""Add model-specific arguments to the parser.""" """Add model-specific arguments to the parser."""
# input
parser.add_argument(
"--conv-kernel-sizes",
type=str,
metavar="N",
help="kernel sizes of Conv1d subsampling layers",
)
parser.add_argument(
"--conv-channels",
type=int,
metavar="N",
help="# of channels in Conv1d subsampling layers",
)
# Transformer # Transformer
parser.add_argument( parser.add_argument(
"--activation-fn", "--activation-fn",
...@@ -199,6 +188,10 @@ class PDSS2TTransformerModel(S2TTransformerModel): ...@@ -199,6 +188,10 @@ class PDSS2TTransformerModel(S2TTransformerModel):
"reduced", "reduced",
"rel_selfattn", "rel_selfattn",
"relative", "relative",
"rel_pos_legacy",
"rel_pos",
"rope",
"abs",
], ],
help="transformer encoder self-attention layer type" help="transformer encoder self-attention layer type"
) )
...@@ -333,6 +326,14 @@ class PDSS2TTransformerModel(S2TTransformerModel): ...@@ -333,6 +326,14 @@ class PDSS2TTransformerModel(S2TTransformerModel):
help='dropout for history output') help='dropout for history output')
parser.add_argument('--history-window-size', type=int, default='-1', parser.add_argument('--history-window-size', type=int, default='-1',
help='how many past layers are considered. -1 means all') help='how many past layers are considered. -1 means all')
# CTC
parser.add_argument(
"--ctc-layer",
default=0,
type=int,
help="the position of the ctc loss",
)
# local modeling # local modeling
parser.add_argument( parser.add_argument(
'--hard-mask-window', '--hard-mask-window',
...@@ -358,6 +359,13 @@ class PDSS2TTransformerModel(S2TTransformerModel): ...@@ -358,6 +359,13 @@ class PDSS2TTransformerModel(S2TTransformerModel):
# Conformer setting # Conformer setting
parser.add_argument( parser.add_argument(
"--encoder-activation-fn",
type=str,
default="relu",
choices=utils.get_available_activation_fns(),
help="activation function to use",
)
parser.add_argument(
"--macaron-style", "--macaron-style",
default=False, default=False,
type=bool, type=bool,
...@@ -380,14 +388,7 @@ class PDSS2TTransformerModel(S2TTransformerModel): ...@@ -380,14 +388,7 @@ class PDSS2TTransformerModel(S2TTransformerModel):
"The legacy relative positional encoding will be deprecated in the future." "The legacy relative positional encoding will be deprecated in the future."
"More Details can be found in https://github.com/espnet/espnet/pull/2816.", "More Details can be found in https://github.com/espnet/espnet/pull/2816.",
) )
# CTC # CNN module
parser.add_argument(
"--ctc-layer",
default=0,
type=int,
help="the position of the ctc loss",
)
# Conformer module
parser.add_argument( parser.add_argument(
"--use-cnn-module", "--use-cnn-module",
default=False, default=False,
...@@ -443,7 +444,6 @@ class PDSS2TTransformerModel(S2TTransformerModel): ...@@ -443,7 +444,6 @@ class PDSS2TTransformerModel(S2TTransformerModel):
type=str, type=str,
help="use the position embedding or not before each encoding", help="use the position embedding or not before each encoding",
) )
parser.add_argument( parser.add_argument(
"--pds-attn-heads", "--pds-attn-heads",
type=str, type=str,
...@@ -479,6 +479,8 @@ class PDSS2TTransformerModel(S2TTransformerModel): ...@@ -479,6 +479,8 @@ class PDSS2TTransformerModel(S2TTransformerModel):
type=str, type=str,
help="use the ctc after each stage", help="use the ctc after each stage",
) )
# intermedia ctc
parser.add_argument( parser.add_argument(
"--intermedia-ctc-layers", "--intermedia-ctc-layers",
default=None, default=None,
...@@ -491,6 +493,18 @@ class PDSS2TTransformerModel(S2TTransformerModel): ...@@ -491,6 +493,18 @@ class PDSS2TTransformerModel(S2TTransformerModel):
type=str, type=str,
help="type of intermedia adapter", help="type of intermedia adapter",
) )
parser.add_argument(
"--intermedia-distribution-cutoff",
default=-1,
type=int,
help="cutoff of the distribution",
)
parser.add_argument(
"--intermedia-drop-prob",
default=0,
type=float,
help="probability of dropping the followed layers",
)
pass pass
@classmethod @classmethod
...@@ -504,6 +518,7 @@ class PDSS2TTransformerModel(S2TTransformerModel): ...@@ -504,6 +518,7 @@ class PDSS2TTransformerModel(S2TTransformerModel):
f"loaded pretrained encoder from: " f"loaded pretrained encoder from: "
f"{args.load_pretrained_encoder_from}" f"{args.load_pretrained_encoder_from}"
) )
return encoder return encoder
...@@ -535,7 +550,6 @@ class PDSS2TTransformerEncoder(FairseqEncoder): ...@@ -535,7 +550,6 @@ class PDSS2TTransformerEncoder(FairseqEncoder):
self.pds_kernel_sizes = [int(n) for n in args.pds_kernel_sizes.split("_")] self.pds_kernel_sizes = [int(n) for n in args.pds_kernel_sizes.split("_")]
self.pds_embed_norm = args.pds_embed_norm self.pds_embed_norm = args.pds_embed_norm
self.pds_position_embed = [int(n) for n in args.pds_position_embed.split("_")] self.pds_position_embed = [int(n) for n in args.pds_position_embed.split("_")]
self.pds_attn_heads = [int(n) for n in args.pds_attn_heads.split("_")] self.pds_attn_heads = [int(n) for n in args.pds_attn_heads.split("_")]
self.pds_ffn_ratios = [int(n) for n in args.pds_ffn_ratios.split("_")] self.pds_ffn_ratios = [int(n) for n in args.pds_ffn_ratios.split("_")]
if self.attn_type == "reduced": if self.attn_type == "reduced":
...@@ -596,6 +610,7 @@ class PDSS2TTransformerEncoder(FairseqEncoder): ...@@ -596,6 +610,7 @@ class PDSS2TTransformerEncoder(FairseqEncoder):
if args.no_scale_embedding: if args.no_scale_embedding:
self.embed_scale = 1.0 self.embed_scale = 1.0
# down-sampling
downsampling = Downsampling( downsampling = Downsampling(
self.pds_ds_method, self.pds_ds_method,
self.pds_embed_norm, self.pds_embed_norm,
...@@ -605,8 +620,23 @@ class PDSS2TTransformerEncoder(FairseqEncoder): ...@@ -605,8 +620,23 @@ class PDSS2TTransformerEncoder(FairseqEncoder):
stride=ds_ratio, stride=ds_ratio,
padding=(kernel_size - 1) // 2, padding=(kernel_size - 1) // 2,
) )
# position encoding
if use_pos_embed: if use_pos_embed:
pos_embed = PositionalEmbedding(args.max_source_positions, embed_dim, self.padding_idx) if self.attn_type == "rel_pos":
pos_embed = RelPositionalEncoding(
args.max_source_positions, args.encoder_embed_dim
)
elif self.attn_type in ["rel_selfattn", "rel_pos_legacy"]:
pos_embed = LegacyRelPositionalEncoding(
args.encoder_embed_dim, args.dropout, args.max_source_positions
)
elif self.attn_type == "rope":
self.embed_positions = None
else: # Use absolute positional embedding
pos_embed = PositionalEmbedding(
args.max_source_positions, args.encoder_embed_dim, self.padding_idx
)
else: else:
pos_embed = None pos_embed = None
...@@ -614,6 +644,7 @@ class PDSS2TTransformerEncoder(FairseqEncoder): ...@@ -614,6 +644,7 @@ class PDSS2TTransformerEncoder(FairseqEncoder):
PDSTransformerEncoderLayer(args, embed_dim, embed_dim * ffn_ratio, num_head, attn_ds_ratio) PDSTransformerEncoderLayer(args, embed_dim, embed_dim * ffn_ratio, num_head, attn_ds_ratio)
for _ in range(num_layers)]) for _ in range(num_layers)])
# representation fusion
fusion_pre_layer_norm = None fusion_pre_layer_norm = None
fusion_post_layer_norm = None fusion_post_layer_norm = None
fusion_downsampling = None fusion_downsampling = None
...@@ -700,7 +731,8 @@ class PDSS2TTransformerEncoder(FairseqEncoder): ...@@ -700,7 +731,8 @@ class PDSS2TTransformerEncoder(FairseqEncoder):
self.fusion_weight.data = self.fusion_weight.data / self.fusion_weight.data.sum(0, keepdim=True) self.fusion_weight.data = self.fusion_weight.data / self.fusion_weight.data.sum(0, keepdim=True)
self.use_ctc = "sate" in args.arch or \ self.use_ctc = "sate" in args.arch or \
(("ctc" in getattr(args, "criterion", False)) and (getattr(args, "criterion", "") == "ctc") or \
(("ctc" in getattr(args, "criterion", "")) and
(getattr(args, "ctc_weight", False) > 0)) (getattr(args, "ctc_weight", False) > 0))
if self.use_ctc: if self.use_ctc:
self.ctc_layer = (args.ctc_layer + args.encoder_layers) % args.encoder_layers self.ctc_layer = (args.ctc_layer + args.encoder_layers) % args.encoder_layers
...@@ -799,9 +831,16 @@ class PDSS2TTransformerEncoder(FairseqEncoder): ...@@ -799,9 +831,16 @@ class PDSS2TTransformerEncoder(FairseqEncoder):
# add the position encoding and dropout # add the position encoding and dropout
if pos_embed: if pos_embed:
positions = pos_embed(encoder_padding_mask).transpose(0, 1) if self.attn_type in ["rel_pos", "rel_pos_legacy", "rel_selfattn"]:
x += positions positions = pos_embed(x)
positions = self.dropout(positions)
elif self.attn_type == "rope":
positions = None
else:
positions = pos_embed(encoder_padding_mask).transpose(0, 1)
x += positions
positions = None
else: else:
positions = None positions = None
......
...@@ -401,13 +401,13 @@ class S2TTransformerModel(FairseqEncoderDecoderModel): ...@@ -401,13 +401,13 @@ class S2TTransformerModel(FairseqEncoderDecoderModel):
def build_encoder(cls, args, task=None, embed_tokens=None): def build_encoder(cls, args, task=None, embed_tokens=None):
encoder = S2TTransformerEncoder(args, task, embed_tokens) encoder = S2TTransformerEncoder(args, task, embed_tokens)
if getattr(args, "load_pretrained_encoder_from", None): if getattr(args, "load_pretrained_encoder_from", None):
encoder = checkpoint_utils.load_pretrained_component_from_model(
component=encoder, checkpoint=args.load_pretrained_encoder_from, strict=False
)
logger.info( logger.info(
f"loaded pretrained encoder from: " f"loaded pretrained encoder from: "
f"{args.load_pretrained_encoder_from}" f"{args.load_pretrained_encoder_from}"
) )
encoder = checkpoint_utils.load_pretrained_component_from_model(
component=encoder, checkpoint=args.load_pretrained_encoder_from, strict=False
)
return encoder return encoder
...@@ -501,7 +501,7 @@ class S2TTransformerEncoder(FairseqEncoder): ...@@ -501,7 +501,7 @@ class S2TTransformerEncoder(FairseqEncoder):
self.padding_idx = 1 self.padding_idx = 1
self.subsample = subsampling(args) self.subsample = subsampling(args)
self.linear = nn.Linear(dim, dim) # self.linear = nn.Linear(dim, dim)
self.attn_type = getattr(args, "encoder_attention_type", "selfattn") self.attn_type = getattr(args, "encoder_attention_type", "selfattn")
...@@ -535,6 +535,7 @@ class S2TTransformerEncoder(FairseqEncoder): ...@@ -535,6 +535,7 @@ class S2TTransformerEncoder(FairseqEncoder):
self.history = None self.history = None
self.use_ctc = "sate" in args.arch or \ self.use_ctc = "sate" in args.arch or \
(getattr(args, "criterion", "") == "ctc") or \
(("ctc" in getattr(args, "criterion", "")) and (getattr(args, "ctc_weight", 0) > 0)) (("ctc" in getattr(args, "criterion", "")) and (getattr(args, "ctc_weight", 0) > 0))
if self.use_ctc: if self.use_ctc:
self.ctc_layer = args.ctc_layer self.ctc_layer = args.ctc_layer
...@@ -640,7 +641,7 @@ class S2TTransformerEncoder(FairseqEncoder): ...@@ -640,7 +641,7 @@ class S2TTransformerEncoder(FairseqEncoder):
x += positions x += positions
positions = None positions = None
x = self.linear(x) # x = self.linear(x)
x = self.dropout_module(x) x = self.dropout_module(x)
# add emb into history # add emb into history
......
...@@ -88,11 +88,11 @@ class S2TTransformerEncoderLayer(nn.Module): ...@@ -88,11 +88,11 @@ class S2TTransformerEncoderLayer(nn.Module):
embed_dim = args.encoder_embed_dim embed_dim = args.encoder_embed_dim
ffn_dim = args.encoder_ffn_embed_dim ffn_dim = args.encoder_ffn_embed_dim
dropout = args.dropout dropout = args.dropout
self.embed_dim = args.encoder_embed_dim self.embed_dim = embed_dim
self.quant_noise = getattr(args, 'quant_noise_pq', 0) self.quant_noise = getattr(args, 'quant_noise_pq', 0)
self.quant_noise_block_size = getattr(args, 'quant_noise_pq_block_size', 8) or 8 self.quant_noise_block_size = getattr(args, 'quant_noise_pq_block_size', 8) or 8
self.attn_type = getattr(args, "encoder_attention_type", "selfattn") self.attn_type = getattr(args, "encoder_attention_type", "selfattn")
self.self_attn = self.build_self_attention(self.embed_dim, args) self.self_attn = self.build_self_attention(args, self.embed_dim)
self.self_attn_layer_norm = LayerNorm(self.embed_dim) self.self_attn_layer_norm = LayerNorm(self.embed_dim)
self.dropout_module = FairseqDropout( self.dropout_module = FairseqDropout(
dropout, module_name=self.__class__.__name__ dropout, module_name=self.__class__.__name__
...@@ -138,7 +138,7 @@ class S2TTransformerEncoderLayer(nn.Module): ...@@ -138,7 +138,7 @@ class S2TTransformerEncoderLayer(nn.Module):
) )
self.ffn_norm = LayerNorm(self.embed_dim) self.ffn_norm = LayerNorm(self.embed_dim)
def build_self_attention(self, embed_dim, args): def build_self_attention(self, args, embed_dim):
attention_heads = args.encoder_attention_heads attention_heads = args.encoder_attention_heads
dropout = args.dropout dropout = args.dropout
...@@ -147,7 +147,8 @@ class S2TTransformerEncoderLayer(nn.Module): ...@@ -147,7 +147,8 @@ class S2TTransformerEncoderLayer(nn.Module):
elif self.attn_type == "rel_selfattn": elif self.attn_type == "rel_selfattn":
attn_func = RelPositionMultiheadAttention attn_func = RelPositionMultiheadAttention
elif self.attn_type == "relative": elif self.attn_type == "relative":
max_relative_length = getattr(args, "max_encoder_relative_length", -1) max_relative_length = max(getattr(args, "max_encoder_relative_length", -1),
getattr(args, "max_relative_length", -1))
if max_relative_length != -1: if max_relative_length != -1:
return RelativeMultiheadAttention( return RelativeMultiheadAttention(
embed_dim, embed_dim,
...@@ -188,13 +189,13 @@ class S2TTransformerEncoderLayer(nn.Module): ...@@ -188,13 +189,13 @@ class S2TTransformerEncoderLayer(nn.Module):
) )
else: else:
attn_func = MultiheadAttention attn_func = MultiheadAttention
print("The attention type %s is not supported!" % self.attn_type) print("The encoder attention type %s is not supported!" % self.attn_type)
exit(1) exit(1)
return attn_func( return attn_func(
embed_dim, embed_dim,
args.encoder_attention_heads, attention_heads,
dropout=args.attention_dropout, dropout=dropout,
self_attention=True, self_attention=True,
q_noise=self.quant_noise, q_noise=self.quant_noise,
qn_block_size=self.quant_noise_block_size, qn_block_size=self.quant_noise_block_size,
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论