Commit 8f084189 by xuchen

fix the bugs and optimize the code

parent 4f679c86
arch: transformer
arch: transformer_ctc
share-all-embeddings: True
optimizer: adam
clip-norm: 10.0
......@@ -8,7 +8,7 @@ warmup-updates: 8000
lr: 1e-3
adam_betas: (0.9,0.997)
criterion: label_smoothed_cross_entropy
criterion: label_smoothed_cross_entropy_with_ctc
label_smoothing: 0.1
dropout: 0.3
......
#ctc-weight: 0.2
intermedia-ctc-weight: 0.3
intermedia-ctc-layers: 2,4
#target-ctc-weight: 0.3
#target-ctc-layer: 6
#target-intermedia-ctc-weight: 0.1
#target-intermedia-ctc-layers: 2,4
intermedia-adapter: league
#intermedia-drop-prob: 0.2
#intermedia-temperature: 5
post-process: sentencepiece
\ No newline at end of file
......@@ -23,7 +23,8 @@ asr_vocab_prefix=spm_unigram10000_st_share
src_lang=en
tgt_lang=zh
subsets=(train_covost train_eu train_iwslt train_mustc_ende train_voxpopuil train_mustc_enzh dev tst-COMMON)
subsets=(train_covost train_eu train_iwslt train_mustc_ende train_voxpopuil train_mustc_enzh dev tst-COMMON train_ted)
#subsets=(train_ted)
mkdir -p $data_dir
splits=$(echo ${subsets[*]} | sed 's/ /,/g')
......
#train-subset: train_covost,train_eu,train_iwslt,train_mustc_ende,train_voxpopuil,train_mustc_enzh
train-subset: train_mustc_enzh
train-subset: train_covost,train_eu,train_iwslt,train_mustc_ende,train_voxpopuil,train_mustc_enzh,train_ted,train-clean-100,train-clean-360,train-other-500
#train-subset: train_mustc_enzh
valid-subset: dev
max-epoch: 100
......
arch: transformer
share-decoder-input-output-embed: True
share-all-embeddings: True
optimizer: adam
clip-norm: 10.0
lr-scheduler: inverse_sqrt
......
arch: transformer
share-decoder-input-output-embed: True
optimizer: adam
#clip-norm: 10.0
lr-scheduler: inverse_sqrt
warmup-init-lr: 1e-7
warmup-updates: 1000
lr: 2e-4
adam_betas: (0.9,0.997)
criterion: label_smoothed_cross_entropy
label_smoothing: 0.1
dropout: 0.1
attention-dropout: 0.1
activation-dropout: 0.1
activation-fn: relu
encoder-normalize-before: True
decoder-normalize-before: True
encoder-embed-dim: 512
encoder-ffn-embed-dim: 2048
encoder-layers: 30
decoder-layers: 6
encoder-attention-heads: 8
decoder-embed-dim: 512
decoder-ffn-embed-dim: 2048
decoder-attention-heads: 8
load-pretrained-encoder-from: /home/xuchen/st/checkpoints/wmt20/mt/0317_unified_lcrm_tok_deep_baseline_pretrain/avg_5_checkpoint.pt
load-pretrained-decoder-from: /home/xuchen/st/checkpoints/wmt20/mt/0317_unified_lcrm_tok_deep_baseline_pretrain/avg_5_checkpoint.pt
......@@ -2,6 +2,7 @@
train-subset: train_mustc_enzh
valid-subset: dev
fp16-scale-tolerance: 0.25
max-epoch: 100
max-update: 100000
patience: 20
......
arch: pdss2t_transformer_s_8
pds-fusion: True
ctc-layer: 12
arch: s2t_ctc
encoder-type: transformer
inter_mixup: True
inter_mixup_layer: 0
inter_mixup_ratio: 0.2
share-decoder-input-output-embed: True
optimizer: adam
clip-norm: 10.0
lr-scheduler: inverse_sqrt
warmup-init-lr: 1e-7
warmup-updates: 10000
lr: 2e-3
lr: 0.0015
adam_betas: (0.9,0.98)
criterion: label_smoothed_cross_entropy_with_ctc
label_smoothing: 0.1
criterion: ctc
ctc-weight: 1.0
subsampling-type: conv2d
subsampling-layers: 2
subsampling-filter: 176
subsampling-kernel: 3
subsampling-stride: 2
subsampling-norm: batch2d
subsampling-activation: swish
dropout: 0.1
activation-fn: relu
encoder-ffn-embed-dim: 2048
encoder-layers: 12
decoder-layers: 6
encoder-embed-dim: 176
encoder-ffn-embed-dim: 704
encoder-layers: 16
encoder-attention-heads: 4
decoder-embed-dim: 256
decoder-ffn-embed-dim: 2048
decoder-attention-heads: 4
macaron-style: True
use-cnn-module: True
cnn-module-kernel: 31
encoder-activation-fn: swish
encoder-attention-type: rel_pos
\ No newline at end of file
......@@ -43,6 +43,7 @@ lr: 0.0015
adam_betas: (0.9,0.98)
criterion: ctc
ctc-weight: 1.0
post-process: sentencepiece
dropout: 0.1
......
......@@ -7,6 +7,7 @@ patience: 20
best-checkpoint-metric: loss
maximize-best-checkpoint-metric: False
post-process: sentencepiece
no-epoch-checkpoints: True
#keep-last-epochs: 10
keep-best-checkpoints: 10
......
......@@ -2,16 +2,15 @@ arch: s2t_ctc
encoder-type: transformer
optimizer: adam
#clip-norm: 10.0
clip-norm: 10.0
lr-scheduler: inverse_sqrt
warmup-init-lr: 1e-7
warmup-updates: 10000
weight-decay: 1e-6
lr: 0.0015
adam_betas: (0.9,0.98)
criterion: ctc
post-process: sentencepiece
ctc-weight: 1.0
subsampling-type: conv2d
subsampling-layers: 2
......
......@@ -12,7 +12,7 @@ encoder-type: pds
encoder-embed-dim: 176
pds-stages: 4
ctc-layer: 16
#ctc-layer: 16
pds-layers: 4_4_4_4
pds-ratios: 2_2_1_2
pds-fusion: True
......@@ -38,11 +38,11 @@ post-process: sentencepiece
dropout: 0.1
activation-fn: relu
encoder-layers: 12
encoder-layers: 16
macaron-style: True
use-cnn-module: True
cnn-module-kernel: 31
cnn-module-kernel: 15
encoder-activation-fn: swish
encoder-attention-type: rel_pos
......
......@@ -34,6 +34,7 @@ lr: 0.0015
adam_betas: (0.9,0.98)
criterion: ctc
ctc-weight: 1.0
post-process: sentencepiece
dropout: 0.1
......
......@@ -696,8 +696,8 @@ class PDSS2TTransformerEncoder(FairseqEncoder):
fusion_stages_num = 0
self.fusion_stages_num = fusion_stages_num
args.pds_ctc = getattr(args, "pds_ctc", "0_0_0_0")
self.pds_ctc = [int(n) for n in args.pds_ctc.split("_")]
args.pds_ctc = getattr(args, "pds_ctc", None)
self.pds_ctc = [int(n) for n in args.pds_ctc.split("_")] if args.pds_ctc is not None else None
inter_ctc_module = None
inter_adapter = None
......@@ -708,11 +708,12 @@ class PDSS2TTransformerEncoder(FairseqEncoder):
embed_dim = self.pds_embed_dims[i]
kernel_size = self.pds_kernel_sizes[i]
use_pos_embed = self.pds_position_embed[i]
use_ctc = self.pds_ctc[i]
use_ctc = self.pds_ctc[i] if self.pds_ctc is not None else False
ffn_ratio = self.pds_ffn_ratios[i]
num_head = self.pds_attn_heads[i]
attn_ds_ratio = self.pds_attn_ds_ratios[i] # if self.attn_type == "reduced" else -1
attn_ds_ratio = self.pds_attn_ds_ratios[i] \
if self.pds_conv_strides is not None and self.attn_type == "reduced" else 1
conv_stride = self.pds_conv_strides[i] if self.pds_conv_strides is not None else 1
attn_stride = self.pds_attn_strides[i] if self.pds_attn_strides is not None else 1
if conv_stride != 1 or attn_stride != 1:
......@@ -1231,15 +1232,15 @@ def base_architecture(args):
args.pds_ds_method = getattr(args, "pds_ds_method", "conv")
args.pds_embed_dims = getattr(args, "pds_embed_dims", None)
args.pds_embed_norm = getattr(args, "pds_embed_norm", True)
args.pds_embed_norm = getattr(args, "pds_embed_norm", False)
args.pds_position_embed = getattr(args, "pds_position_embed", None)
args.pds_attn_heads = getattr(args, "pds_attn_heads", None)
args.pds_ffn_ratios = getattr(args, "pds_ffn_ratios", None)
args.pds_attn_ds_ratios = getattr(args, "pds_attn_ds_ratios", "1_1_1_1")
args.pds_conv_strides = getattr(args, "pds_conv_strides", "1_1_1_1")
args.pds_attn_strides = getattr(args, "pds_attn_strides", "1_1_1_1")
args.pds_attn_ds_ratios = getattr(args, "pds_attn_ds_ratios", None)
args.pds_conv_strides = getattr(args, "pds_conv_strides", None)
args.pds_attn_strides = getattr(args, "pds_attn_strides", None)
args.ctc_layer = getattr(args, "ctc_layer", 0)
args.pds_dropout = getattr(args, "pds_dropout", args.dropout)
......@@ -1248,7 +1249,7 @@ def base_architecture(args):
args.pds_fusion_method = getattr(args, "pds_fusion_method", "all_conv")
# intermedia CTC
args.pds_ctc = getattr(args, "pds_ctc", "0_0_0_0")
args.pds_ctc = getattr(args, "pds_ctc", None)
args.intermedia_adapter = getattr(args, "intermedia_adapter", "none")
args.intermedia_drop_prob = getattr(args, "intermedia_drop_prob", 0)
......
......@@ -558,6 +558,7 @@ class S2TCTCEncoder(FairseqEncoder):
def __init__(self, args, task=None):
super().__init__(None)
setattr(args, "ctc_weight", 1.0)
encoder_type = getattr(args, "encoder_type", "transformer")
if encoder_type == "transformer":
from .s2t_transformer import S2TTransformerEncoder
......@@ -575,8 +576,7 @@ class S2TCTCEncoder(FairseqEncoder):
return self.encoder(src_tokens, src_lengths, **kwargs)
def reorder_encoder_out(self, encoder_out, new_order):
self.encoder.reorder_encoder_out(encoder_out, new_order)
return
return self.encoder.reorder_encoder_out(encoder_out, new_order)
class CTCDecoder(object):
......
......@@ -494,15 +494,15 @@ def base_architecture(args):
args.pds_ds_method = getattr(args, "pds_ds_method", "conv")
args.pds_embed_dims = getattr(args, "pds_embed_dims", None)
args.pds_embed_norm = getattr(args, "pds_embed_norm", True)
args.pds_embed_norm = getattr(args, "pds_embed_norm", False)
args.pds_position_embed = getattr(args, "pds_position_embed", None)
args.pds_attn_heads = getattr(args, "pds_attn_heads", None)
args.pds_ffn_ratios = getattr(args, "pds_ffn_ratios", None)
args.pds_attn_ds_ratios = getattr(args, "pds_attn_ds_ratios", "1_1_1_1")
args.pds_conv_strides = getattr(args, "pds_conv_strides", "1_1_1_1")
args.pds_attn_strides = getattr(args, "pds_attn_strides", "1_1_1_1")
args.pds_attn_ds_ratios = getattr(args, "pds_attn_ds_ratios", None)
args.pds_conv_strides = getattr(args, "pds_conv_strides", None)
args.pds_attn_strides = getattr(args, "pds_attn_strides", None)
args.ctc_layer = getattr(args, "ctc_layer", 0)
args.pds_dropout = getattr(args, "pds_dropout", args.dropout)
......
......@@ -112,6 +112,22 @@ class S2TSATEModel(S2TTransformerModel):
type=str,
help="intermedia ctc layers for target sentence",
)
# Freeze options: each flag is a boolean switch (store_true) whose help text
# names the sub-module whose parameters it freezes.
parser.add_argument(
    "--freeze-acoustic-encoder",
    action="store_true",
    help="freeze the parameters of the acoustic encoder",
)
# The help text below previously repeated the acoustic-encoder description;
# it now names the textual encoder, which is what this flag refers to.
parser.add_argument(
    "--freeze-textual-encoder",
    action="store_true",
    help="freeze the parameters of the textual encoder",
)
parser.add_argument(
    "--freeze-decoder",
    action="store_true",
    help="freeze the parameters of the decoder",
)
pass
@classmethod
......@@ -150,6 +166,18 @@ class S2TSATEModel(S2TTransformerModel):
return encoder
def forward(self, src_tokens, src_lengths, prev_output_tokens):
"""
The forward method inherited from the base class has a **kwargs
argument in its input, which is not supported in torchscript. This
method overwrites the forward method definition without **kwargs.
"""
# Encode the source batch; both inputs are passed by keyword to self.encoder.
encoder_out = self.encoder(src_tokens=src_tokens, src_lengths=src_lengths)
# Decode from the previous output tokens, conditioned on the encoder output.
decoder_out = self.decoder(
prev_output_tokens=prev_output_tokens, encoder_out=encoder_out
)
# The decoder's result is returned to the caller unchanged.
return decoder_out
class TextEncoder(FairseqEncoder):
def __init__(self, args, dictionary, embed_tokens=None):
......@@ -224,7 +252,8 @@ class TextEncoder(FairseqEncoder):
elif args.intermedia_adapter == "league":
strategy = getattr(args, "intermedia_distribution_cutoff", None)
self.adapter = Adapter(embed_dim, args.intermedia_adapter,
len(dictionary), embed_tokens=embed_tokens,
len(dictionary),
# embed_tokens=embed_tokens,
strategy=strategy)
self.intermedia_drop_prob = getattr(args, "intermedia_drop_prob", 0)
self.intermedia_temperature = getattr(args, "intermedia_temperature", 1)
......@@ -294,7 +323,7 @@ class S2TSATEEncoder(FairseqEncoder):
# adapter
self.temperature = args.temperature
# self.adapter = Adapter(args, task.source_dictionary, embed_tokens)
strategy = None
if args.adapter == "shrink":
strategy = getattr(args, "ctc_compress_strategy", "avg")
......@@ -318,6 +347,9 @@ class S2TSATEEncoder(FairseqEncoder):
args.encoder_attention_type = acoustic_encoder_attention_type
self.freeze_acoustic_encoder = getattr(args, "freeze_acoustic_encoder", False)
self.freeze_textual_encoder = getattr(args, "freeze_textual_encoder", False)
if getattr(args, "use_enc_dlcl", False):
layer_num = args.encoder_layers + args.text_encoder_layers + 2
self.history = DynamicLinearCombination(args, is_encoder=True, layer_num=layer_num)
......@@ -328,6 +360,10 @@ class S2TSATEEncoder(FairseqEncoder):
if self.history is not None:
self.history.clean()
if self.freeze_acoustic_encoder:
with torch.no_grad():
acoustic_encoder_out = self.acoustic_encoder(src_tokens, src_lengths)
else:
acoustic_encoder_out = self.acoustic_encoder(src_tokens, src_lengths)
encoder_out = acoustic_encoder_out["encoder_out"][0]
......@@ -354,6 +390,10 @@ class S2TSATEEncoder(FairseqEncoder):
self.history.push(x)
if self.freeze_textual_encoder:
with torch.no_grad():
x, target_ctc_logit, target_intermedia_ctc_logits = self.text_encoder(x, encoder_padding_mask, self.history)
else:
x, target_ctc_logit, target_intermedia_ctc_logits = self.text_encoder(x, encoder_padding_mask, self.history)
return {
......@@ -482,15 +522,15 @@ def base_architecture(args):
args.pds_ds_method = getattr(args, "pds_ds_method", "conv")
args.pds_embed_dims = getattr(args, "pds_embed_dims", None)
args.pds_embed_norm = getattr(args, "pds_embed_norm", True)
args.pds_embed_norm = getattr(args, "pds_embed_norm", False)
args.pds_position_embed = getattr(args, "pds_position_embed", None)
args.pds_attn_heads = getattr(args, "pds_attn_heads", None)
args.pds_ffn_ratios = getattr(args, "pds_ffn_ratios", None)
args.pds_attn_ds_ratios = getattr(args, "pds_attn_ds_ratios", "1_1_1_1")
args.pds_conv_strides = getattr(args, "pds_conv_strides", "1_1_1_1")
args.pds_attn_strides = getattr(args, "pds_attn_strides", "1_1_1_1")
args.pds_attn_ds_ratios = getattr(args, "pds_attn_ds_ratios", None)
args.pds_conv_strides = getattr(args, "pds_conv_strides", None)
args.pds_attn_strides = getattr(args, "pds_attn_strides", None)
args.ctc_layer = getattr(args, "ctc_layer", 0)
args.pds_dropout = getattr(args, "pds_dropout", args.dropout)
......
......@@ -598,6 +598,7 @@ class S2TTransformerEncoder(FairseqEncoder):
strategy = getattr(args, "intermedia_distribution_cutoff", None)
self.adapter = Adapter(dim, args.intermedia_adapter,
len(task.source_dictionary), strategy=strategy)
# embed_tokens=embed_tokens if embed_tokens is not None else self.ctc.ctc_projection)
self.intermedia_drop_prob = getattr(args, "intermedia_drop_prob", 0)
self.intermedia_temperature = getattr(args, "intermedia_temperature", 1)
......@@ -700,6 +701,10 @@ class S2TTransformerEncoder(FairseqEncoder):
logit = self.ctc(norm_x)
intermedia_ctc_logits.append(logit)
logit = logit.clamp(min=-1e8 if logit.dtype == torch.float32 else -1e4,
max=1e8 if logit.dtype == torch.float32 else 1e4)
prob = utils.softmax(logit / self.intermedia_temperature, dim=-1)
x, encoder_padding_mask = self.adapter([x, prob], encoder_padding_mask)
......
......@@ -161,8 +161,8 @@ class Wav2VecCtc(BaseFairseqModel):
padding = net_output["padding_mask"]
if padding is not None and padding.any():
padding = padding.T
logits[padding][...,0] = 0
logits[padding][...,1:] = float('-inf')
logits[padding][..., 0] = 0
logits[padding][..., 1:] = float('-inf')
return logits
......
......@@ -130,7 +130,7 @@ class Adapter(nn.Module):
out = coef * linear_out + (1 - coef) * soft_out
elif self.adapter_type == "inter_league":
soft_out = torch.mm(distribution, self.embed_adapter.weight.t()).view(seq_len, bsz, -1)
soft_out = torch.mm(distribution, self.embed_adapter.weight).view(seq_len, bsz, -1)
out = representation + soft_out
elif self.adapter_type == "none":
......@@ -153,7 +153,7 @@ class Adapter(nn.Module):
# x is T x B x C -> B x C x T; weights_matrix is B x T x T'
representation = representation.permute(1, 2, 0)
compressed_output = representation.float().bmm(weights_matrix).type_as(representation) # B x C x T'
compressed_output = representation.bmm(weights_matrix).type_as(representation) # B x C x T'
out = compressed_output.permute(2, 0, 1)
out_lengths = lengths.new(new_lengths)
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 登录 后发表评论