Commit e0da16dd by xuchen

fix the bugs of multibranch arch

parent a2fd43e7
arch: s2t_multibranch
junior-acoustic-encoder: transformer
junior-acoustic-encoder-layers: 12
senior-acoustic-encoder-layers: 6
textual-encoder-layers: 6
#encoder-drop-net: True
#encoder-drop-net-prob: 0.5
#collaboration-direction: none
#collaboration-direction: acoustic
#collaboration-direction: textual
collaboration-direction: both
collaboration-start: 0:0
collaboration-step: 1:1
encoder-collaboration-mode: serial
decoder-collaboration-mode: serial
#encoder-collaboration-mode: parallel
#decoder-collaboration-mode: parallel
encoder-league-s1-ratio: 0.5
encoder-league-s2-ratio: 0.5
encoder-league-drop-net: False
encoder-league-drop-net-prob: 0.2
encoder-league-drop-net-mix: False
decoder-league-s1-ratio: 0.5
decoder-league-s2-ratio: 0.5
decoder-league-drop-net: False
decoder-league-drop-net-prob: 0.0
decoder-league-drop-net-mix: False
share-decoder-input-output-embed: True
optimizer: adam
clip-norm: 10.0
lr-scheduler: inverse_sqrt
warmup-init-lr: 1e-7
warmup-updates: 10000
lr: 2e-3
adam_betas: (0.9,0.98)
criterion: label_smoothed_cross_entropy_with_ctc
label_smoothing: 0.1
subsampling-type: conv1d
subsampling-layers: 2
subsampling-filter: 1024
subsampling-kernel: 5
subsampling-stride: 2
subsampling-norm: none
subsampling-activation: glu
encoder-embed-norm: True
encoder-no-scale-embedding: True
dropout: 0.1
activation-fn: relu
encoder-embed-dim: 256
encoder-ffn-embed-dim: 2048
encoder-layers: 12
decoder-layers: 6
encoder-attention-heads: 4
decoder-embed-dim: 256
decoder-ffn-embed-dim: 2048
decoder-attention-heads: 4
attention-dropout: 0.1
activation-dropout: 0.1
acoustic-encoder: transformer
adapter: inter_league
#adapter: none
#adapter-embed-norm: True
adapter-out-norm: True
#share-adapter-and-ctc: True
#share-adapter-and-embed: True
load-pretrained-junior-acoustic-encoder-from: /home/xuchen/st/checkpoints/mustc/st/0823_tok_base_ctc_baseline/avg_10_checkpoint.pt
#load-pretrained-senior-acoustic-encoder-from:
#load-pretrained-textual-encoder-from:
#load-pretrained-decoder-from:
...@@ -422,6 +422,13 @@ class S2TMultiBranchEncoder(FairseqEncoder): ...@@ -422,6 +422,13 @@ class S2TMultiBranchEncoder(FairseqEncoder):
self.acoustic_norm = LayerNorm(args.encoder_embed_dim) self.acoustic_norm = LayerNorm(args.encoder_embed_dim)
self.textual_norm = LayerNorm(args.encoder_embed_dim) self.textual_norm = LayerNorm(args.encoder_embed_dim)
def set_ctc_infer(self, ctc_infer, post_process, src_dict=None, tgt_dict=None, path=None):
self.junior_acoustic_encoder.set_ctc_infer(ctc_infer, post_process, src_dict=src_dict, tgt_dict=tgt_dict,
path=path)
def ctc_valid(self, lprobs, targets, input_lengths, dictionary, lang="source"):
return self.junior_acoustic_encoder.ctc_valid(lprobs, targets, input_lengths, dictionary, lang)
def forward(self, src_tokens, src_lengths=None, **kwargs): def forward(self, src_tokens, src_lengths=None, **kwargs):
junior_acoustic_encoder_out = self.junior_acoustic_encoder(src_tokens, src_lengths, **kwargs) junior_acoustic_encoder_out = self.junior_acoustic_encoder(src_tokens, src_lengths, **kwargs)
acoustic_x = junior_acoustic_encoder_out["encoder_out"][0] acoustic_x = junior_acoustic_encoder_out["encoder_out"][0]
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论