Commit 8f084189 by xuchen

fix the bugs and optimize the code

parent 4f679c86
arch: transformer arch: transformer_ctc
share-all-embeddings: True share-all-embeddings: True
optimizer: adam optimizer: adam
clip-norm: 10.0 clip-norm: 10.0
...@@ -8,7 +8,7 @@ warmup-updates: 8000 ...@@ -8,7 +8,7 @@ warmup-updates: 8000
lr: 1e-3 lr: 1e-3
adam_betas: (0.9,0.997) adam_betas: (0.9,0.997)
criterion: label_smoothed_cross_entropy criterion: label_smoothed_cross_entropy_with_ctc
label_smoothing: 0.1 label_smoothing: 0.1
dropout: 0.3 dropout: 0.3
......
#ctc-weight: 0.2
intermedia-ctc-weight: 0.3
intermedia-ctc-layers: 2,4
#target-ctc-weight: 0.3
#target-ctc-layer: 6
#target-intermedia-ctc-weight: 0.1
#target-intermedia-ctc-layers: 2,4
intermedia-adapter: league
#intermedia-drop-prob: 0.2
#intermedia-temperature: 5
post-process: sentencepiece
\ No newline at end of file
...@@ -23,7 +23,8 @@ asr_vocab_prefix=spm_unigram10000_st_share ...@@ -23,7 +23,8 @@ asr_vocab_prefix=spm_unigram10000_st_share
src_lang=en src_lang=en
tgt_lang=zh tgt_lang=zh
subsets=(train_covost train_eu train_iwslt train_mustc_ende train_voxpopuil train_mustc_enzh dev tst-COMMON) subsets=(train_covost train_eu train_iwslt train_mustc_ende train_voxpopuil train_mustc_enzh dev tst-COMMON train_ted)
#subsets=(train_ted)
mkdir -p $data_dir mkdir -p $data_dir
splits=$(echo ${subsets[*]} | sed 's/ /,/g') splits=$(echo ${subsets[*]} | sed 's/ /,/g')
......
#train-subset: train_covost,train_eu,train_iwslt,train_mustc_ende,train_voxpopuil,train_mustc_enzh train-subset: train_covost,train_eu,train_iwslt,train_mustc_ende,train_voxpopuil,train_mustc_enzh,train_ted,train-clean-100,train-clean-360,train-other-500
train-subset: train_mustc_enzh #train-subset: train_mustc_enzh
valid-subset: dev valid-subset: dev
max-epoch: 100 max-epoch: 100
......
arch: transformer arch: transformer
share-decoder-input-output-embed: True share-all-embeddings: True
optimizer: adam optimizer: adam
clip-norm: 10.0 clip-norm: 10.0
lr-scheduler: inverse_sqrt lr-scheduler: inverse_sqrt
......
arch: transformer
share-decoder-input-output-embed: True
optimizer: adam
#clip-norm: 10.0
lr-scheduler: inverse_sqrt
warmup-init-lr: 1e-7
warmup-updates: 1000
lr: 2e-4
adam_betas: (0.9,0.997)
criterion: label_smoothed_cross_entropy
label_smoothing: 0.1
dropout: 0.1
attention-dropout: 0.1
activation-dropout: 0.1
activation-fn: relu
encoder-normalize-before: True
decoder-normalize-before: True
encoder-embed-dim: 512
encoder-ffn-embed-dim: 2048
encoder-layers: 30
decoder-layers: 6
encoder-attention-heads: 8
decoder-embed-dim: 512
decoder-ffn-embed-dim: 2048
decoder-attention-heads: 8
load-pretrained-encoder-from: /home/xuchen/st/checkpoints/wmt20/mt/0317_unified_lcrm_tok_deep_baseline_pretrain/avg_5_checkpoint.pt
load-pretrained-decoder-from: /home/xuchen/st/checkpoints/wmt20/mt/0317_unified_lcrm_tok_deep_baseline_pretrain/avg_5_checkpoint.pt
...@@ -2,6 +2,7 @@ ...@@ -2,6 +2,7 @@
train-subset: train_mustc_enzh train-subset: train_mustc_enzh
valid-subset: dev valid-subset: dev
fp16-scale-tolerance: 0.25
max-epoch: 100 max-epoch: 100
max-update: 100000 max-update: 100000
patience: 20 patience: 20
......
arch: pdss2t_transformer_s_8 arch: s2t_ctc
pds-fusion: True encoder-type: transformer
ctc-layer: 12
inter_mixup: True
inter_mixup_layer: 0
inter_mixup_ratio: 0.2
share-decoder-input-output-embed: True
optimizer: adam optimizer: adam
clip-norm: 10.0 clip-norm: 10.0
lr-scheduler: inverse_sqrt lr-scheduler: inverse_sqrt
warmup-init-lr: 1e-7 warmup-init-lr: 1e-7
warmup-updates: 10000 warmup-updates: 10000
lr: 2e-3 lr: 0.0015
adam_betas: (0.9,0.98) adam_betas: (0.9,0.98)
criterion: label_smoothed_cross_entropy_with_ctc criterion: ctc
label_smoothing: 0.1 ctc-weight: 1.0
subsampling-type: conv2d
subsampling-layers: 2
subsampling-filter: 176
subsampling-kernel: 3
subsampling-stride: 2
subsampling-norm: batch2d
subsampling-activation: swish
dropout: 0.1 dropout: 0.1
activation-fn: relu activation-fn: relu
encoder-ffn-embed-dim: 2048 encoder-embed-dim: 176
encoder-layers: 12 encoder-ffn-embed-dim: 704
decoder-layers: 6 encoder-layers: 16
encoder-attention-heads: 4 encoder-attention-heads: 4
decoder-embed-dim: 256 macaron-style: True
decoder-ffn-embed-dim: 2048 use-cnn-module: True
decoder-attention-heads: 4 cnn-module-kernel: 31
encoder-activation-fn: swish
encoder-attention-type: rel_pos
\ No newline at end of file
...@@ -43,6 +43,7 @@ lr: 0.0015 ...@@ -43,6 +43,7 @@ lr: 0.0015
adam_betas: (0.9,0.98) adam_betas: (0.9,0.98)
criterion: ctc criterion: ctc
ctc-weight: 1.0
post-process: sentencepiece post-process: sentencepiece
dropout: 0.1 dropout: 0.1
......
...@@ -7,6 +7,7 @@ patience: 20 ...@@ -7,6 +7,7 @@ patience: 20
best-checkpoint-metric: loss best-checkpoint-metric: loss
maximize-best-checkpoint-metric: False maximize-best-checkpoint-metric: False
post-process: sentencepiece
no-epoch-checkpoints: True no-epoch-checkpoints: True
#keep-last-epochs: 10 #keep-last-epochs: 10
keep-best-checkpoints: 10 keep-best-checkpoints: 10
......
...@@ -2,16 +2,15 @@ arch: s2t_ctc ...@@ -2,16 +2,15 @@ arch: s2t_ctc
encoder-type: transformer encoder-type: transformer
optimizer: adam optimizer: adam
#clip-norm: 10.0 clip-norm: 10.0
lr-scheduler: inverse_sqrt lr-scheduler: inverse_sqrt
warmup-init-lr: 1e-7 warmup-init-lr: 1e-7
warmup-updates: 10000 warmup-updates: 10000
weight-decay: 1e-6
lr: 0.0015 lr: 0.0015
adam_betas: (0.9,0.98) adam_betas: (0.9,0.98)
criterion: ctc criterion: ctc
post-process: sentencepiece ctc-weight: 1.0
subsampling-type: conv2d subsampling-type: conv2d
subsampling-layers: 2 subsampling-layers: 2
......
...@@ -12,7 +12,7 @@ encoder-type: pds ...@@ -12,7 +12,7 @@ encoder-type: pds
encoder-embed-dim: 176 encoder-embed-dim: 176
pds-stages: 4 pds-stages: 4
ctc-layer: 16 #ctc-layer: 16
pds-layers: 4_4_4_4 pds-layers: 4_4_4_4
pds-ratios: 2_2_1_2 pds-ratios: 2_2_1_2
pds-fusion: True pds-fusion: True
...@@ -38,11 +38,11 @@ post-process: sentencepiece ...@@ -38,11 +38,11 @@ post-process: sentencepiece
dropout: 0.1 dropout: 0.1
activation-fn: relu activation-fn: relu
encoder-layers: 12 encoder-layers: 16
macaron-style: True macaron-style: True
use-cnn-module: True use-cnn-module: True
cnn-module-kernel: 31 cnn-module-kernel: 15
encoder-activation-fn: swish encoder-activation-fn: swish
encoder-attention-type: rel_pos encoder-attention-type: rel_pos
......
...@@ -34,6 +34,7 @@ lr: 0.0015 ...@@ -34,6 +34,7 @@ lr: 0.0015
adam_betas: (0.9,0.98) adam_betas: (0.9,0.98)
criterion: ctc criterion: ctc
ctc-weight: 1.0
post-process: sentencepiece post-process: sentencepiece
dropout: 0.1 dropout: 0.1
......
...@@ -696,8 +696,8 @@ class PDSS2TTransformerEncoder(FairseqEncoder): ...@@ -696,8 +696,8 @@ class PDSS2TTransformerEncoder(FairseqEncoder):
fusion_stages_num = 0 fusion_stages_num = 0
self.fusion_stages_num = fusion_stages_num self.fusion_stages_num = fusion_stages_num
args.pds_ctc = getattr(args, "pds_ctc", "0_0_0_0") args.pds_ctc = getattr(args, "pds_ctc", None)
self.pds_ctc = [int(n) for n in args.pds_ctc.split("_")] self.pds_ctc = [int(n) for n in args.pds_ctc.split("_")] if args.pds_ctc is not None else None
inter_ctc_module = None inter_ctc_module = None
inter_adapter = None inter_adapter = None
...@@ -708,11 +708,12 @@ class PDSS2TTransformerEncoder(FairseqEncoder): ...@@ -708,11 +708,12 @@ class PDSS2TTransformerEncoder(FairseqEncoder):
embed_dim = self.pds_embed_dims[i] embed_dim = self.pds_embed_dims[i]
kernel_size = self.pds_kernel_sizes[i] kernel_size = self.pds_kernel_sizes[i]
use_pos_embed = self.pds_position_embed[i] use_pos_embed = self.pds_position_embed[i]
use_ctc = self.pds_ctc[i] use_ctc = self.pds_ctc[i] if self.pds_ctc is not None else False
ffn_ratio = self.pds_ffn_ratios[i] ffn_ratio = self.pds_ffn_ratios[i]
num_head = self.pds_attn_heads[i] num_head = self.pds_attn_heads[i]
attn_ds_ratio = self.pds_attn_ds_ratios[i] # if self.attn_type == "reduced" else -1 attn_ds_ratio = self.pds_attn_ds_ratios[i] \
if self.pds_conv_strides is not None and self.attn_type == "reduced" else 1
conv_stride = self.pds_conv_strides[i] if self.pds_conv_strides is not None else 1 conv_stride = self.pds_conv_strides[i] if self.pds_conv_strides is not None else 1
attn_stride = self.pds_attn_strides[i] if self.pds_attn_strides is not None else 1 attn_stride = self.pds_attn_strides[i] if self.pds_attn_strides is not None else 1
if conv_stride != 1 or attn_stride != 1: if conv_stride != 1 or attn_stride != 1:
...@@ -1231,15 +1232,15 @@ def base_architecture(args): ...@@ -1231,15 +1232,15 @@ def base_architecture(args):
args.pds_ds_method = getattr(args, "pds_ds_method", "conv") args.pds_ds_method = getattr(args, "pds_ds_method", "conv")
args.pds_embed_dims = getattr(args, "pds_embed_dims", None) args.pds_embed_dims = getattr(args, "pds_embed_dims", None)
args.pds_embed_norm = getattr(args, "pds_embed_norm", True) args.pds_embed_norm = getattr(args, "pds_embed_norm", False)
args.pds_position_embed = getattr(args, "pds_position_embed", None) args.pds_position_embed = getattr(args, "pds_position_embed", None)
args.pds_attn_heads = getattr(args, "pds_attn_heads", None) args.pds_attn_heads = getattr(args, "pds_attn_heads", None)
args.pds_ffn_ratios = getattr(args, "pds_ffn_ratios", None) args.pds_ffn_ratios = getattr(args, "pds_ffn_ratios", None)
args.pds_attn_ds_ratios = getattr(args, "pds_attn_ds_ratios", "1_1_1_1") args.pds_attn_ds_ratios = getattr(args, "pds_attn_ds_ratios", None)
args.pds_conv_strides = getattr(args, "pds_conv_strides", "1_1_1_1") args.pds_conv_strides = getattr(args, "pds_conv_strides", None)
args.pds_attn_strides = getattr(args, "pds_attn_strides", "1_1_1_1") args.pds_attn_strides = getattr(args, "pds_attn_strides", None)
args.ctc_layer = getattr(args, "ctc_layer", 0) args.ctc_layer = getattr(args, "ctc_layer", 0)
args.pds_dropout = getattr(args, "pds_dropout", args.dropout) args.pds_dropout = getattr(args, "pds_dropout", args.dropout)
...@@ -1248,7 +1249,7 @@ def base_architecture(args): ...@@ -1248,7 +1249,7 @@ def base_architecture(args):
args.pds_fusion_method = getattr(args, "pds_fusion_method", "all_conv") args.pds_fusion_method = getattr(args, "pds_fusion_method", "all_conv")
# intermedia CTC # intermedia CTC
args.pds_ctc = getattr(args, "pds_ctc", "0_0_0_0") args.pds_ctc = getattr(args, "pds_ctc", None)
args.intermedia_adapter = getattr(args, "intermedia_adapter", "none") args.intermedia_adapter = getattr(args, "intermedia_adapter", "none")
args.intermedia_drop_prob = getattr(args, "intermedia_drop_prob", 0) args.intermedia_drop_prob = getattr(args, "intermedia_drop_prob", 0)
......
...@@ -558,6 +558,7 @@ class S2TCTCEncoder(FairseqEncoder): ...@@ -558,6 +558,7 @@ class S2TCTCEncoder(FairseqEncoder):
def __init__(self, args, task=None): def __init__(self, args, task=None):
super().__init__(None) super().__init__(None)
setattr(args, "ctc_weight", 1.0)
encoder_type = getattr(args, "encoder_type", "transformer") encoder_type = getattr(args, "encoder_type", "transformer")
if encoder_type == "transformer": if encoder_type == "transformer":
from .s2t_transformer import S2TTransformerEncoder from .s2t_transformer import S2TTransformerEncoder
...@@ -575,8 +576,7 @@ class S2TCTCEncoder(FairseqEncoder): ...@@ -575,8 +576,7 @@ class S2TCTCEncoder(FairseqEncoder):
return self.encoder(src_tokens, src_lengths, **kwargs) return self.encoder(src_tokens, src_lengths, **kwargs)
def reorder_encoder_out(self, encoder_out, new_order): def reorder_encoder_out(self, encoder_out, new_order):
self.encoder.reorder_encoder_out(encoder_out, new_order) return self.encoder.reorder_encoder_out(encoder_out, new_order)
return
class CTCDecoder(object): class CTCDecoder(object):
......
...@@ -494,15 +494,15 @@ def base_architecture(args): ...@@ -494,15 +494,15 @@ def base_architecture(args):
args.pds_ds_method = getattr(args, "pds_ds_method", "conv") args.pds_ds_method = getattr(args, "pds_ds_method", "conv")
args.pds_embed_dims = getattr(args, "pds_embed_dims", None) args.pds_embed_dims = getattr(args, "pds_embed_dims", None)
args.pds_embed_norm = getattr(args, "pds_embed_norm", True) args.pds_embed_norm = getattr(args, "pds_embed_norm", False)
args.pds_position_embed = getattr(args, "pds_position_embed", None) args.pds_position_embed = getattr(args, "pds_position_embed", None)
args.pds_attn_heads = getattr(args, "pds_attn_heads", None) args.pds_attn_heads = getattr(args, "pds_attn_heads", None)
args.pds_ffn_ratios = getattr(args, "pds_ffn_ratios", None) args.pds_ffn_ratios = getattr(args, "pds_ffn_ratios", None)
args.pds_attn_ds_ratios = getattr(args, "pds_attn_ds_ratios", "1_1_1_1") args.pds_attn_ds_ratios = getattr(args, "pds_attn_ds_ratios", None)
args.pds_conv_strides = getattr(args, "pds_conv_strides", "1_1_1_1") args.pds_conv_strides = getattr(args, "pds_conv_strides", None)
args.pds_attn_strides = getattr(args, "pds_attn_strides", "1_1_1_1") args.pds_attn_strides = getattr(args, "pds_attn_strides", None)
args.ctc_layer = getattr(args, "ctc_layer", 0) args.ctc_layer = getattr(args, "ctc_layer", 0)
args.pds_dropout = getattr(args, "pds_dropout", args.dropout) args.pds_dropout = getattr(args, "pds_dropout", args.dropout)
......
...@@ -112,6 +112,22 @@ class S2TSATEModel(S2TTransformerModel): ...@@ -112,6 +112,22 @@ class S2TSATEModel(S2TTransformerModel):
type=str, type=str,
help="intermedia ctc layers for target sentence", help="intermedia ctc layers for target sentence",
) )
# freeze
parser.add_argument(
"--freeze-acoustic-encoder",
action="store_true",
help="freeze the parameters of the acoustic encoder",
)
parser.add_argument(
"--freeze-textual-encoder",
action="store_true",
help="freeze the parameters of the acoustic encoder",
)
parser.add_argument(
"--freeze-decoder",
action="store_true",
help="freeze the parameters of the decoder",
)
pass pass
@classmethod @classmethod
...@@ -150,6 +166,18 @@ class S2TSATEModel(S2TTransformerModel): ...@@ -150,6 +166,18 @@ class S2TSATEModel(S2TTransformerModel):
return encoder return encoder
def forward(self, src_tokens, src_lengths, prev_output_tokens):
"""
The forward method inherited from the base class has a **kwargs
argument in its input, which is not supported in torchscript. This
method overwrites the forward method definition without **kwargs.
"""
encoder_out = self.encoder(src_tokens=src_tokens, src_lengths=src_lengths)
decoder_out = self.decoder(
prev_output_tokens=prev_output_tokens, encoder_out=encoder_out
)
return decoder_out
class TextEncoder(FairseqEncoder): class TextEncoder(FairseqEncoder):
def __init__(self, args, dictionary, embed_tokens=None): def __init__(self, args, dictionary, embed_tokens=None):
...@@ -224,7 +252,8 @@ class TextEncoder(FairseqEncoder): ...@@ -224,7 +252,8 @@ class TextEncoder(FairseqEncoder):
elif args.intermedia_adapter == "league": elif args.intermedia_adapter == "league":
strategy = getattr(args, "intermedia_distribution_cutoff", None) strategy = getattr(args, "intermedia_distribution_cutoff", None)
self.adapter = Adapter(embed_dim, args.intermedia_adapter, self.adapter = Adapter(embed_dim, args.intermedia_adapter,
len(dictionary), embed_tokens=embed_tokens, len(dictionary),
# embed_tokens=embed_tokens,
strategy=strategy) strategy=strategy)
self.intermedia_drop_prob = getattr(args, "intermedia_drop_prob", 0) self.intermedia_drop_prob = getattr(args, "intermedia_drop_prob", 0)
self.intermedia_temperature = getattr(args, "intermedia_temperature", 1) self.intermedia_temperature = getattr(args, "intermedia_temperature", 1)
...@@ -294,7 +323,7 @@ class S2TSATEEncoder(FairseqEncoder): ...@@ -294,7 +323,7 @@ class S2TSATEEncoder(FairseqEncoder):
# adapter # adapter
self.temperature = args.temperature self.temperature = args.temperature
# self.adapter = Adapter(args, task.source_dictionary, embed_tokens)
strategy = None strategy = None
if args.adapter == "shrink": if args.adapter == "shrink":
strategy = getattr(args, "ctc_compress_strategy", "avg") strategy = getattr(args, "ctc_compress_strategy", "avg")
...@@ -318,6 +347,9 @@ class S2TSATEEncoder(FairseqEncoder): ...@@ -318,6 +347,9 @@ class S2TSATEEncoder(FairseqEncoder):
args.encoder_attention_type = acoustic_encoder_attention_type args.encoder_attention_type = acoustic_encoder_attention_type
self.freeze_acoustic_encoder = getattr(args, "freeze_acoustic_encoder", False)
self.freeze_textual_encoder = getattr(args, "freeze_textual_encoder", False)
if getattr(args, "use_enc_dlcl", False): if getattr(args, "use_enc_dlcl", False):
layer_num = args.encoder_layers + args.text_encoder_layers + 2 layer_num = args.encoder_layers + args.text_encoder_layers + 2
self.history = DynamicLinearCombination(args, is_encoder=True, layer_num=layer_num) self.history = DynamicLinearCombination(args, is_encoder=True, layer_num=layer_num)
...@@ -328,7 +360,11 @@ class S2TSATEEncoder(FairseqEncoder): ...@@ -328,7 +360,11 @@ class S2TSATEEncoder(FairseqEncoder):
if self.history is not None: if self.history is not None:
self.history.clean() self.history.clean()
acoustic_encoder_out = self.acoustic_encoder(src_tokens, src_lengths) if self.freeze_acoustic_encoder:
with torch.no_grad():
acoustic_encoder_out = self.acoustic_encoder(src_tokens, src_lengths)
else:
acoustic_encoder_out = self.acoustic_encoder(src_tokens, src_lengths)
encoder_out = acoustic_encoder_out["encoder_out"][0] encoder_out = acoustic_encoder_out["encoder_out"][0]
encoder_padding_mask = acoustic_encoder_out["encoder_padding_mask"][0] encoder_padding_mask = acoustic_encoder_out["encoder_padding_mask"][0]
...@@ -354,7 +390,11 @@ class S2TSATEEncoder(FairseqEncoder): ...@@ -354,7 +390,11 @@ class S2TSATEEncoder(FairseqEncoder):
self.history.push(x) self.history.push(x)
x, target_ctc_logit, target_intermedia_ctc_logits = self.text_encoder(x, encoder_padding_mask, self.history) if self.freeze_textual_encoder:
with torch.no_grad():
x, target_ctc_logit, target_intermedia_ctc_logits = self.text_encoder(x, encoder_padding_mask, self.history)
else:
x, target_ctc_logit, target_intermedia_ctc_logits = self.text_encoder(x, encoder_padding_mask, self.history)
return { return {
"encoder_out": [x], # T x B x C "encoder_out": [x], # T x B x C
...@@ -482,15 +522,15 @@ def base_architecture(args): ...@@ -482,15 +522,15 @@ def base_architecture(args):
args.pds_ds_method = getattr(args, "pds_ds_method", "conv") args.pds_ds_method = getattr(args, "pds_ds_method", "conv")
args.pds_embed_dims = getattr(args, "pds_embed_dims", None) args.pds_embed_dims = getattr(args, "pds_embed_dims", None)
args.pds_embed_norm = getattr(args, "pds_embed_norm", True) args.pds_embed_norm = getattr(args, "pds_embed_norm", False)
args.pds_position_embed = getattr(args, "pds_position_embed", None) args.pds_position_embed = getattr(args, "pds_position_embed", None)
args.pds_attn_heads = getattr(args, "pds_attn_heads", None) args.pds_attn_heads = getattr(args, "pds_attn_heads", None)
args.pds_ffn_ratios = getattr(args, "pds_ffn_ratios", None) args.pds_ffn_ratios = getattr(args, "pds_ffn_ratios", None)
args.pds_attn_ds_ratios = getattr(args, "pds_attn_ds_ratios", "1_1_1_1") args.pds_attn_ds_ratios = getattr(args, "pds_attn_ds_ratios", None)
args.pds_conv_strides = getattr(args, "pds_conv_strides", "1_1_1_1") args.pds_conv_strides = getattr(args, "pds_conv_strides", None)
args.pds_attn_strides = getattr(args, "pds_attn_strides", "1_1_1_1") args.pds_attn_strides = getattr(args, "pds_attn_strides", None)
args.ctc_layer = getattr(args, "ctc_layer", 0) args.ctc_layer = getattr(args, "ctc_layer", 0)
args.pds_dropout = getattr(args, "pds_dropout", args.dropout) args.pds_dropout = getattr(args, "pds_dropout", args.dropout)
......
...@@ -598,6 +598,7 @@ class S2TTransformerEncoder(FairseqEncoder): ...@@ -598,6 +598,7 @@ class S2TTransformerEncoder(FairseqEncoder):
strategy = getattr(args, "intermedia_distribution_cutoff", None) strategy = getattr(args, "intermedia_distribution_cutoff", None)
self.adapter = Adapter(dim, args.intermedia_adapter, self.adapter = Adapter(dim, args.intermedia_adapter,
len(task.source_dictionary), strategy=strategy) len(task.source_dictionary), strategy=strategy)
# embed_tokens=embed_tokens if embed_tokens is not None else self.ctc.ctc_projection)
self.intermedia_drop_prob = getattr(args, "intermedia_drop_prob", 0) self.intermedia_drop_prob = getattr(args, "intermedia_drop_prob", 0)
self.intermedia_temperature = getattr(args, "intermedia_temperature", 1) self.intermedia_temperature = getattr(args, "intermedia_temperature", 1)
...@@ -700,6 +701,10 @@ class S2TTransformerEncoder(FairseqEncoder): ...@@ -700,6 +701,10 @@ class S2TTransformerEncoder(FairseqEncoder):
logit = self.ctc(norm_x) logit = self.ctc(norm_x)
intermedia_ctc_logits.append(logit) intermedia_ctc_logits.append(logit)
logit = logit.clamp(min=-1e8 if logit.dtype == torch.float32 else -1e4,
max=1e8 if logit.dtype == torch.float32 else 1e4)
prob = utils.softmax(logit / self.intermedia_temperature, dim=-1) prob = utils.softmax(logit / self.intermedia_temperature, dim=-1)
x, encoder_padding_mask = self.adapter([x, prob], encoder_padding_mask) x, encoder_padding_mask = self.adapter([x, prob], encoder_padding_mask)
......
...@@ -161,8 +161,8 @@ class Wav2VecCtc(BaseFairseqModel): ...@@ -161,8 +161,8 @@ class Wav2VecCtc(BaseFairseqModel):
padding = net_output["padding_mask"] padding = net_output["padding_mask"]
if padding is not None and padding.any(): if padding is not None and padding.any():
padding = padding.T padding = padding.T
logits[padding][...,0] = 0 logits[padding][..., 0] = 0
logits[padding][...,1:] = float('-inf') logits[padding][..., 1:] = float('-inf')
return logits return logits
......
...@@ -130,7 +130,7 @@ class Adapter(nn.Module): ...@@ -130,7 +130,7 @@ class Adapter(nn.Module):
out = coef * linear_out + (1 - coef) * soft_out out = coef * linear_out + (1 - coef) * soft_out
elif self.adapter_type == "inter_league": elif self.adapter_type == "inter_league":
soft_out = torch.mm(distribution, self.embed_adapter.weight.t()).view(seq_len, bsz, -1) soft_out = torch.mm(distribution, self.embed_adapter.weight).view(seq_len, bsz, -1)
out = representation + soft_out out = representation + soft_out
elif self.adapter_type == "none": elif self.adapter_type == "none":
...@@ -153,7 +153,7 @@ class Adapter(nn.Module): ...@@ -153,7 +153,7 @@ class Adapter(nn.Module):
# x is T x B x C -> B x C x T; weights_matrix is B x T x T' # x is T x B x C -> B x C x T; weights_matrix is B x T x T'
representation = representation.permute(1, 2, 0) representation = representation.permute(1, 2, 0)
compressed_output = representation.float().bmm(weights_matrix).type_as(representation) # B x C x T' compressed_output = representation.bmm(weights_matrix).type_as(representation) # B x C x T'
out = compressed_output.permute(2, 0, 1) out = compressed_output.permute(2, 0, 1)
out_lengths = lengths.new(new_lengths) out_lengths = lengths.new(new_lengths)
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论