Commit 55702466 by xuchen

add the pure ctc arch for pds method

parent bab6c520
arch: s2t_transformer_s
arch: s2t_ctc
encoder-type: pds
#arch: pdss2t_transformer_s_8
encoder-embed-dim: 256
pds-stages: 4
ctc-layer: 12
pds-layers: 3_3_3_3
pds-ratios: 2_2_1_2
pds-fusion: True
pds-fusion-method: all_conv
pds-embed-dims: 256_256_256_256
pds-ds-method: conv
pds-embed-norm: True
pds-position-embed: 1_1_1_1
pds-kernel-sizes: 5_5_5_5
pds-ffn-ratios: 8_8_8_8
pds-attn-heads: 4_4_4_4
share-decoder-input-output-embed: True
optimizer: adam
clip-norm: 10.0
lr-scheduler: inverse_sqrt
warmup-init-lr: 1e-7
warmup-updates: 10000
weight-decay: 1e-6
lr: 2e-3
adam_betas: (0.9,0.98)
criterion: label_smoothed_cross_entropy_with_ctc
label_smoothing: 0.1
subsampling-type: conv1d
subsampling-layers: 2
subsampling-filter: 1024
subsampling-kernel: 5
subsampling-stride: 2
subsampling-norm: none
subsampling-activation: glu
ctc-weight: 0.2
intermedia-ctc-layers: 6,9
intermedia-adapter: league
intermedia-ctc-weight: 0.1
intermedia-drop-prob: 0.5
ctc-self-distill-weight: 0
post-process: sentencepiece
criterion: ctc
dropout: 0.1
activation-fn: relu
encoder-embed-dim: 256
encoder-ffn-embed-dim: 2048
encoder-layers: 12
decoder-layers: 6
@@ -40,8 +38,3 @@ encoder-attention-heads: 4
decoder-embed-dim: 256
decoder-ffn-embed-dim: 2048
decoder-attention-heads: 4
attention-dropout: 0.1
activation-dropout: 0.1
#load-pretrained-encoder-from:
#load-pretrained-decoder-from:
\ No newline at end of file
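The pds-* options above are per-stage settings: each underscore-separated field configures one of the pds-stages. A minimal sketch of how such strings are interpreted, mirroring the split("_") parsing in PDSS2TTransformerEncoder further down in this diff; the variable names are local to the example:

# per-stage settings from the config above (illustrative parsing only)
pds_stages = 4
pds_layers = [int(n) for n in "3_3_3_3".split("_")]   # encoder layers per stage
pds_ratios = [int(n) for n in "2_2_1_2".split("_")]   # temporal downsampling per stage

assert len(pds_layers) == len(pds_ratios) == pds_stages

total_ds = 1
for r in pds_ratios:
    total_ds *= r

print(sum(pds_layers), total_ds)  # 12 encoder layers in total, 8x overall downsampling

The 8x overall downsampling lines up with the _8 suffix of the commented-out pdss2t_transformer_s_8 architecture name.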
arch: s2t_ctc
encoder-type: pds
encoder-embed-dim: 256
pds-stages: 4
ctc-layer: 12
pds-layers: 3_3_3_3
pds-ratios: 2_2_1_2
pds-fusion: True
pds-fusion-method: all_conv
pds-embed-dims: 256_256_256_256
pds-ds-method: conv
pds-embed-norm: True
pds-position-embed: 1_1_1_1
pds-kernel-sizes: 5_5_5_5
pds-ffn-ratios: 8_8_8_8
pds-attn-heads: 4_4_4_4
share-decoder-input-output-embed: True
optimizer: adam
clip-norm: 10.0
lr-scheduler: inverse_sqrt
warmup-init-lr: 1e-7
warmup-updates: 10000
lr: 2e-3
adam_betas: (0.9,0.98)
criterion: ctc
dropout: 0.1
activation-fn: relu
encoder-ffn-embed-dim: 2048
encoder-layers: 12
decoder-layers: 6
encoder-attention-heads: 4
decoder-embed-dim: 256
decoder-ffn-embed-dim: 2048
decoder-attention-heads: 4
#load-pretrained-encoder-from:
#load-pretrained-decoder-from:
\ No newline at end of file
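Both YAML files are flat key/value training configs. A rough sketch of how such a file could be expanded into fairseq-train options; the file name below is assumed, a one-to-one key-to-flag mapping is assumed, and the repo's own run scripts may wire this up differently:

# turn a flat YAML training config into "--key value" pairs (sketch only)
import yaml

with open("pure_ctc.yaml") as f:              # hypothetical file name
    cfg = yaml.safe_load(f)

flags = []
for key, value in cfg.items():
    flags.append("--" + key.replace("_", "-"))
    flags.append(str(value))                  # e.g. "--criterion ctc", "--pds-ratios 2_2_1_2"

print(" ".join(flags))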
@@ -137,7 +137,8 @@ class AudioDataset(Dataset):
for i, segment in enumerate(seg_group):
offset = int(float(segment["offset"]) * sample_rate)
n_frames = int(float(segment["duration"]) * sample_rate)
_id = f"{split}_{wav_path.stem}_{i}"
# _id = f"{split}_{wav_path.stem}_{i}"
_id = f"{wav_path.stem}_{i}"
item = dict()
item["audio"] = wav_path.as_posix()
@@ -263,8 +264,12 @@ def process(args):
utt_id = item['id']
features_path = (feature_root / f"{utt_id}.npy").as_posix()
tag_features_path = (feature_root / f"{split}_{utt_id}.npy").as_posix()
if os.path.exists(features_path):
if os.path.exists(tag_features_path):
continue
if os.path.exists(features_path) and not os.path.exists(tag_features_path):
shutil.move(features_path, tag_features_path)
continue
waveform, sample_rate, _ = dataset.get(idx, need_waveform=True)
......
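The second hunk changes how cached features are reused: they are now looked up under a split-tagged name, and an untagged file left over from an earlier extraction run is renamed rather than recomputed. The same logic in isolation; the directory and utterance id below are illustrative, not taken from the repo:

import os
import shutil

feature_root = "fbank80"                   # assumed cache directory
split, utt_id = "train", "some_clip_0"     # illustrative values

features_path = os.path.join(feature_root, f"{utt_id}.npy")              # untagged name
tag_features_path = os.path.join(feature_root, f"{split}_{utt_id}.npy")  # split-tagged name

if os.path.exists(tag_features_path):
    pass  # already cached under the new naming scheme, nothing to do
elif os.path.exists(features_path):
    # reuse the existing features instead of extracting them again
    shutil.move(features_path, tag_features_path)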
@@ -20,8 +20,6 @@ from fairseq.tasks import FairseqTask
from fairseq.logging.meters import safe_round
@dataclass
class CtcCriterionConfig(FairseqDataclass):
zero_infinity: bool = field(
@@ -30,7 +28,7 @@ class CtcCriterionConfig(FairseqDataclass):
)
sentence_avg: bool = II("optimization.sentence_avg")
post_process: str = field(
default="letter",
default="sentencepiece",
metadata={
"help": "how to post process predictions into words. can be letter, "
"wordpiece, BPE symbols, etc. "
......
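The default post_process is switched from letter to sentencepiece, matching the post-process: sentencepiece line in the configs above. As an illustration (not the repo's code), sentencepiece post-processing amounts to collapsing pieces and turning the U+2581 word-boundary marker back into spaces:

# sentencepiece pieces mark word starts with U+2581 ("▁")
hyp = "\u2581the \u2581pure \u2581ctc \u2581arch"

words = hyp.replace(" ", "").replace("\u2581", " ").strip()
print(words)  # "the pure ctc arch"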
@@ -47,6 +47,8 @@ __all__ = [
"FairseqLanguageModel",
"FairseqModel",
"FairseqMultiModel",
"register_model",
"register_model_architecture"
]
......
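register_model and register_model_architecture are added to fairseq.models.__all__ so model code can import them directly from the package. A minimal, hypothetical registration showing how the two decorators are used (the actual s2t_ctc model is defined elsewhere in this repo; the names below are chosen to avoid clashing with it):

from fairseq.models import (
    FairseqEncoderModel,
    register_model,
    register_model_architecture,
)

@register_model("toy_s2t_ctc")  # hypothetical model name
class ToyS2TCtcModel(FairseqEncoderModel):
    @staticmethod
    def add_args(parser):
        parser.add_argument("--encoder-embed-dim", type=int, metavar="N")

    @classmethod
    def build_model(cls, args, task):
        raise NotImplementedError("sketch only: build and return an encoder here")

@register_model_architecture("toy_s2t_ctc", "toy_s2t_ctc_s")
def toy_s2t_ctc_s(args):
    # architecture functions only fill in defaults on args
    args.encoder_embed_dim = getattr(args, "encoder_embed_dim", 256)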
@@ -19,6 +19,8 @@ from fairseq.modules import (
FairseqDropout,
LayerNorm,
PositionalEmbedding,
RelPositionalEncoding,
LegacyRelPositionalEncoding,
PDSTransformerEncoderLayer,
DownSampleConvolutionModule
)
@@ -137,19 +139,6 @@ class PDSS2TTransformerModel(S2TTransformerModel):
@staticmethod
def add_args(parser):
"""Add model-specific arguments to the parser."""
# input
parser.add_argument(
"--conv-kernel-sizes",
type=str,
metavar="N",
help="kernel sizes of Conv1d subsampling layers",
)
parser.add_argument(
"--conv-channels",
type=int,
metavar="N",
help="# of channels in Conv1d subsampling layers",
)
# Transformer
parser.add_argument(
"--activation-fn",
@@ -199,6 +188,10 @@ class PDSS2TTransformerModel(S2TTransformerModel):
"reduced",
"rel_selfattn",
"relative",
"rel_pos_legacy",
"rel_pos",
"rope",
"abs",
],
help="transformer encoder self-attention layer type"
)
@@ -333,6 +326,14 @@ class PDSS2TTransformerModel(S2TTransformerModel):
help='dropout for history output')
parser.add_argument('--history-window-size', type=int, default='-1',
help='how many past layers are considered. -1 means all')
# CTC
parser.add_argument(
"--ctc-layer",
default=0,
type=int,
help="the position of the ctc loss",
)
# local modeling
parser.add_argument(
'--hard-mask-window',
@@ -358,6 +359,13 @@ class PDSS2TTransformerModel(S2TTransformerModel):
# Conformer setting
parser.add_argument(
"--encoder-activation-fn",
type=str,
default="relu",
choices=utils.get_available_activation_fns(),
help="activation function to use",
)
parser.add_argument(
"--macaron-style",
default=False,
type=bool,
@@ -380,14 +388,7 @@ class PDSS2TTransformerModel(S2TTransformerModel):
"The legacy relative positional encoding will be deprecated in the future."
"More Details can be found in https://github.com/espnet/espnet/pull/2816.",
)
# CTC
parser.add_argument(
"--ctc-layer",
default=0,
type=int,
help="the position of the ctc loss",
)
# Conformer module
# CNN module
parser.add_argument(
"--use-cnn-module",
default=False,
@@ -443,7 +444,6 @@ class PDSS2TTransformerModel(S2TTransformerModel):
type=str,
help="use the position embedding or not before each encoding",
)
parser.add_argument(
"--pds-attn-heads",
type=str,
@@ -479,6 +479,8 @@ class PDSS2TTransformerModel(S2TTransformerModel):
type=str,
help="use the ctc after each stage",
)
# intermedia ctc
parser.add_argument(
"--intermedia-ctc-layers",
default=None,
@@ -491,6 +493,18 @@ class PDSS2TTransformerModel(S2TTransformerModel):
type=str,
help="type of intermedia adapter",
)
parser.add_argument(
"--intermedia-distribution-cutoff",
default=-1,
type=int,
help="cutoff of the distribution",
)
parser.add_argument(
"--intermedia-drop-prob",
default=0,
type=float,
help="probability of dropping the followed layers",
)
pass
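The intermedia-* options added above wire up intermediate CTC: a comma-separated list of encoder layers after which an auxiliary CTC prediction is made, an adapter type that feeds that prediction back into the encoder, a distribution cutoff, and a drop probability. A sketch of how the layer list and drop probability could be consumed; the real behaviour lives in the encoder and criterion, and the helper below plus the reading of "dropping the followed layers" are assumptions for illustration only:

import random

def parse_intermedia_layers(spec):
    """Turn '6,9' into [6, 9]; an empty/None spec disables intermediate CTC."""
    return [int(i) for i in spec.split(",")] if spec else []

layers = parse_intermedia_layers("6,9")   # from intermedia-ctc-layers in the config
drop_prob = 0.5                           # from intermedia-drop-prob

# hedged reading of the help text: with probability drop_prob, stop after an
# intermediate CTC layer for this batch and skip the remaining encoder layers
stop_early = bool(layers) and random.random() < drop_prob
print(layers, stop_early)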
@classmethod
@@ -504,6 +518,7 @@ class PDSS2TTransformerModel(S2TTransformerModel):
f"loaded pretrained encoder from: "
f"{args.load_pretrained_encoder_from}"
)
return encoder
@@ -535,7 +550,6 @@ class PDSS2TTransformerEncoder(FairseqEncoder):
self.pds_kernel_sizes = [int(n) for n in args.pds_kernel_sizes.split("_")]
self.pds_embed_norm = args.pds_embed_norm
self.pds_position_embed = [int(n) for n in args.pds_position_embed.split("_")]
self.pds_attn_heads = [int(n) for n in args.pds_attn_heads.split("_")]
self.pds_ffn_ratios = [int(n) for n in args.pds_ffn_ratios.split("_")]
if self.attn_type == "reduced":
@@ -596,6 +610,7 @@ class PDSS2TTransformerEncoder(FairseqEncoder):
if args.no_scale_embedding:
self.embed_scale = 1.0
# down-sampling
downsampling = Downsampling(
self.pds_ds_method,
self.pds_embed_norm,
@@ -605,8 +620,23 @@ class PDSS2TTransformerEncoder(FairseqEncoder):
stride=ds_ratio,
padding=(kernel_size - 1) // 2,
)
# position encoding
if use_pos_embed:
pos_embed = PositionalEmbedding(args.max_source_positions, embed_dim, self.padding_idx)
if self.attn_type == "rel_pos":
pos_embed = RelPositionalEncoding(
args.max_source_positions, args.encoder_embed_dim
)
elif self.attn_type in ["rel_selfattn", "rel_pos_legacy"]:
pos_embed = LegacyRelPositionalEncoding(
args.encoder_embed_dim, args.dropout, args.max_source_positions
)
elif self.attn_type == "rope":
self.embed_positions = None
else: # Use absolute positional embedding
pos_embed = PositionalEmbedding(
args.max_source_positions, args.encoder_embed_dim, self.padding_idx
)
else:
pos_embed = None
@@ -614,6 +644,7 @@ class PDSS2TTransformerEncoder(FairseqEncoder):
PDSTransformerEncoderLayer(args, embed_dim, embed_dim * ffn_ratio, num_head, attn_ds_ratio)
for _ in range(num_layers)])
# representation fusion
fusion_pre_layer_norm = None
fusion_post_layer_norm = None
fusion_downsampling = None
@@ -700,7 +731,8 @@ class PDSS2TTransformerEncoder(FairseqEncoder):
self.fusion_weight.data = self.fusion_weight.data / self.fusion_weight.data.sum(0, keepdim=True)
self.use_ctc = "sate" in args.arch or \
(("ctc" in getattr(args, "criterion", False)) and
(getattr(args, "criterion", "") == "ctc") or \
(("ctc" in getattr(args, "criterion", "")) and
(getattr(args, "ctc_weight", False) > 0))
if self.use_ctc:
self.ctc_layer = (args.ctc_layer + args.encoder_layers) % args.encoder_layers
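The change above widens the CTC check: the encoder now also enables its CTC head when the criterion is exactly ctc (the new pure-CTC configs), not only when a CTC-mixing criterion has a positive ctc_weight. The same condition written out flat, as a sketch with the attribute names used in the diff:

criterion = getattr(args, "criterion", "")
use_ctc = (
    "sate" in args.arch
    or criterion == "ctc"                      # pure CTC training (criterion: ctc)
    or ("ctc" in criterion                     # e.g. label_smoothed_cross_entropy_with_ctc
        and getattr(args, "ctc_weight", 0) > 0)
)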
@@ -799,9 +831,16 @@ class PDSS2TTransformerEncoder(FairseqEncoder):
# add the position encoding and dropout
if pos_embed:
positions = pos_embed(encoder_padding_mask).transpose(0, 1)
x += positions
positions = self.dropout(positions)
if self.attn_type in ["rel_pos", "rel_pos_legacy", "rel_selfattn"]:
positions = pos_embed(x)
elif self.attn_type == "rope":
positions = None
else:
positions = pos_embed(encoder_padding_mask).transpose(0, 1)
x += positions
positions = None
else:
positions = None
......
@@ -401,13 +401,13 @@ class S2TTransformerModel(FairseqEncoderDecoderModel):
def build_encoder(cls, args, task=None, embed_tokens=None):
encoder = S2TTransformerEncoder(args, task, embed_tokens)
if getattr(args, "load_pretrained_encoder_from", None):
encoder = checkpoint_utils.load_pretrained_component_from_model(
component=encoder, checkpoint=args.load_pretrained_encoder_from, strict=False
)
logger.info(
f"loaded pretrained encoder from: "
f"{args.load_pretrained_encoder_from}"
)
encoder = checkpoint_utils.load_pretrained_component_from_model(
component=encoder, checkpoint=args.load_pretrained_encoder_from, strict=False
)
return encoder
@@ -501,7 +501,7 @@ class S2TTransformerEncoder(FairseqEncoder):
self.padding_idx = 1
self.subsample = subsampling(args)
self.linear = nn.Linear(dim, dim)
# self.linear = nn.Linear(dim, dim)
self.attn_type = getattr(args, "encoder_attention_type", "selfattn")
@@ -535,6 +535,7 @@ class S2TTransformerEncoder(FairseqEncoder):
self.history = None
self.use_ctc = "sate" in args.arch or \
(getattr(args, "criterion", "") == "ctc") or \
(("ctc" in getattr(args, "criterion", "")) and (getattr(args, "ctc_weight", 0) > 0))
if self.use_ctc:
self.ctc_layer = args.ctc_layer
@@ -640,7 +641,7 @@ class S2TTransformerEncoder(FairseqEncoder):
x += positions
positions = None
x = self.linear(x)
# x = self.linear(x)
x = self.dropout_module(x)
# add emb into history
......
@@ -88,11 +88,11 @@ class S2TTransformerEncoderLayer(nn.Module):
embed_dim = args.encoder_embed_dim
ffn_dim = args.encoder_ffn_embed_dim
dropout = args.dropout
self.embed_dim = args.encoder_embed_dim
self.embed_dim = embed_dim
self.quant_noise = getattr(args, 'quant_noise_pq', 0)
self.quant_noise_block_size = getattr(args, 'quant_noise_pq_block_size', 8) or 8
self.attn_type = getattr(args, "encoder_attention_type", "selfattn")
self.self_attn = self.build_self_attention(self.embed_dim, args)
self.self_attn = self.build_self_attention(args, self.embed_dim)
self.self_attn_layer_norm = LayerNorm(self.embed_dim)
self.dropout_module = FairseqDropout(
dropout, module_name=self.__class__.__name__
@@ -138,7 +138,7 @@ class S2TTransformerEncoderLayer(nn.Module):
)
self.ffn_norm = LayerNorm(self.embed_dim)
def build_self_attention(self, embed_dim, args):
def build_self_attention(self, args, embed_dim):
attention_heads = args.encoder_attention_heads
dropout = args.dropout
@@ -147,7 +147,8 @@ class S2TTransformerEncoderLayer(nn.Module):
elif self.attn_type == "rel_selfattn":
attn_func = RelPositionMultiheadAttention
elif self.attn_type == "relative":
max_relative_length = getattr(args, "max_encoder_relative_length", -1)
max_relative_length = max(getattr(args, "max_encoder_relative_length", -1),
getattr(args, "max_relative_length", -1))
if max_relative_length != -1:
return RelativeMultiheadAttention(
embed_dim,
@@ -188,13 +189,13 @@ class S2TTransformerEncoderLayer(nn.Module):
)
else:
attn_func = MultiheadAttention
print("The attention type %s is not supported!" % self.attn_type)
print("The encoder attention type %s is not supported!" % self.attn_type)
exit(1)
return attn_func(
embed_dim,
args.encoder_attention_heads,
dropout=args.attention_dropout,
attention_heads,
dropout=dropout,
self_attention=True,
q_noise=self.quant_noise,
qn_block_size=self.quant_noise_block_size,
......