Commit 8f45faa2 by xuchen

update the dual arch

parent cbeb5521
arch: s2t_dual
asr-encoder: pds
asr-encoder: transformer
mt-encoder-layers: 6
mt-encoder: transformer
encoder-drop-net: True
encoder-drop-net-prob: 0.5
encoder-embed-dim: 256
pds-stages: 4
#ctc-layer: 12
pds-layers: 3_3_3_3
pds-ratios: 2_2_1_2
pds-fusion: True
pds-fusion-method: all_conv
pds-embed-dims: 256_256_256_256
pds-ds-method: conv
pds-embed-norm: True
pds-position-embed: 1_1_1_1
pds-kernel-sizes: 5_5_5_5
pds-ffn-ratios: 8_8_8_8
pds-attn-heads: 4_4_4_4
encoder-collaboration-mode: parallel
decoder-collaboration-mode: parallel
encoder-league-s1-ratio: 0.5
encoder-league-s2-ratio: 0.5
encoder-league-drop-net: False
encoder-league-drop-net-prob: 0.2
encoder-league-drop-net-mix: False
decoder-league-s1-ratio: 0.5
decoder-league-s2-ratio: 0.5
decoder-league-drop-net: False
decoder-league-drop-net-prob: 0.0
decoder-league-drop-net-mix: False
share-decoder-input-output-embed: True
optimizer: adam
@@ -35,6 +33,7 @@ label_smoothing: 0.1
dropout: 0.1
activation-fn: relu
encoder-embed-dim: 256
encoder-ffn-embed-dim: 2048
encoder-layers: 12
decoder-layers: 6
......
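For reference, the pds-* options above take one value per stage (pds-stages: 4), joined by underscores: pds-layers: 3_3_3_3 gives three Transformer layers per stage, and pds-ratios: 2_2_1_2 compounds to an overall 8x temporal downsampling. Below is a minimal parsing sketch under that reading of the config; the helper and class names are illustrative, not the repo's actual API.

from dataclasses import dataclass
from typing import List

def split_ints(value: str) -> List[int]:
    """Turn an underscore-joined option such as '2_2_1_2' into [2, 2, 1, 2]."""
    return [int(v) for v in value.split("_")]

@dataclass
class PDSStage:
    layers: int       # Transformer layers in this stage (pds-layers)
    ratio: int        # temporal downsampling ratio of this stage (pds-ratios)
    embed_dim: int    # model dimension inside the stage (pds-embed-dims)
    kernel_size: int  # conv kernel for downsampling (pds-kernel-sizes)
    ffn_ratio: int    # FFN hidden size = ffn_ratio * embed_dim (pds-ffn-ratios)
    attn_heads: int   # attention heads in this stage (pds-attn-heads)

def build_stages() -> List[PDSStage]:
    fields = zip(split_ints("3_3_3_3"),          # pds-layers
                 split_ints("2_2_1_2"),          # pds-ratios
                 split_ints("256_256_256_256"),  # pds-embed-dims
                 split_ints("5_5_5_5"),          # pds-kernel-sizes
                 split_ints("8_8_8_8"),          # pds-ffn-ratios
                 split_ints("4_4_4_4"))          # pds-attn-heads
    return [PDSStage(*f) for f in fields]

stages = build_stages()
total = 1
for s in stages:
    total *= s.ratio
print(len(stages), "stages, overall downsampling x", total)  # 4 stages, x 8
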
@@ -35,13 +35,6 @@ class JoinSpeechTextLoss(
"""Add criterion-specific arguments to the parser."""
LabelSmoothedCrossEntropyCriterion.add_args(parser)
CtcCriterion.add_args(parser)
parser.add_argument(
"--ctc-weight",
default=0.0,
type=float,
metavar="D",
help="weight of CTC loss",
)
def forward(self, model, sample, reduce=True):
"""Compute the loss for the given sample.
@@ -65,7 +58,7 @@ class JoinSpeechTextLoss(
if "mixup" in encoder_out and encoder_out["mixup"] is not None:
use_mixup = True
loss, nll_loss = self.compute_loss(model, net_output, sample, reduce=reduce)
loss, nll_loss, other_loss = self.compute_loss(model, net_output, sample, reduce=reduce)
sample_size = (
sample["target"].size(0) if self.sentence_avg else sample["ntokens"]
)
......
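The hunks above change compute_loss to return an auxiliary term (other_loss) alongside the label-smoothed cross-entropy, and drop the local --ctc-weight argument, presumably because CtcCriterion.add_args already registers it. A minimal sketch of how such a criterion typically mixes the two terms, assuming a ctc_weight hyperparameter; the exact weighting used in the repo may differ.

import torch

def combine_losses(ce_loss: torch.Tensor,
                   other_loss: torch.Tensor,
                   ctc_weight: float) -> torch.Tensor:
    """Mix the translation loss with an auxiliary (e.g. CTC) loss.

    Sketch only: the real criterion also returns nll_loss and computes
    sample_size for logging.
    """
    if ctc_weight > 0.0 and other_loss is not None:
        return ce_loss + ctc_weight * other_loss
    return ce_loss

loss = combine_losses(torch.tensor(2.3), torch.tensor(1.1), ctc_weight=0.3)
print(loss)  # tensor(2.6300)
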
@@ -79,18 +79,23 @@ class TransformerS2EncoderLayer(nn.Module):
if self.use_se:
self.se_attn = SEAttention(self.embed_dim, 16)
self.s2_norm = LayerNorm(self.embed_dim)
self.s2_attn_norm = LayerNorm(self.embed_dim)
self.s2_attn = MultiheadAttention(
self.embed_dim,
args.encoder_attention_heads,
kdim=getattr(args, "encoder_x2_dim", self.embed_dim),
vdim=getattr(args, "encoder_x2_dim", self.embed_dim),
dropout=args.attention_dropout,
self_attention=False,
)
self.use_s2_attn_norm = args.use_s2_attn_norm
if self.use_s2_attn_norm:
self.s2_norm = LayerNorm(self.embed_dim)
self.encoder_collaboration_mode = args.encoder_collaboration_mode
if self.encoder_collaboration_mode != "none":
if self.encoder_collaboration_mode == "serial":
self.s2_attn_norm = LayerNorm(self.embed_dim)
self.s2_attn = MultiheadAttention(
self.embed_dim,
args.encoder_attention_heads,
kdim=getattr(args, "encoder_s2_dim", self.embed_dim),
vdim=getattr(args, "encoder_s2_dim", self.embed_dim),
dropout=args.attention_dropout,
self_attention=False,
)
self.league_s1_ratio = args.encoder_league_s1_ratio
self.league_s2_ratio = args.encoder_league_s2_ratio
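The rewritten __init__ above only builds the second-stream attention (s2_attn) when a collaboration mode is active, adds an extra pre-norm for the serial path, and keeps per-branch league ratios for the parallel path. Below is a simplified sketch of how the two modes could combine the main stream x with the second stream s2, under that reading of the code; dropout, padding masks and exact layer-norm placement are omitted, so this is not the repo's actual forward.

import torch
from torch import nn

def collaborate(x, s2, self_attn_out, s2_attn: nn.MultiheadAttention,
                mode: str, s1_ratio: float = 0.5, s2_ratio: float = 0.5):
    """x, s2: (T, B, C) main and second-stream states; self_attn_out: self-attention over x."""
    if mode == "parallel":
        # Cross-attend from the same input and mix the branches by the league ratios.
        cross, _ = s2_attn(x, s2, s2)
        return s1_ratio * self_attn_out + s2_ratio * cross
    if mode == "serial":
        # Apply the cross-attention as an additional residual sub-layer.
        residual = self_attn_out
        cross, _ = s2_attn(self_attn_out, s2, s2)
        return residual + cross
    return self_attn_out  # mode == "none": plain single-stream layer

attn = nn.MultiheadAttention(embed_dim=256, num_heads=4)
x, s2 = torch.randn(50, 2, 256), torch.randn(30, 2, 256)
out = collaborate(x, s2, x, attn, mode="parallel")
print(out.shape)  # torch.Size([50, 2, 256])
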
@@ -251,7 +256,9 @@ class TransformerS2EncoderLayer(nn.Module):
x = self.self_attn_layer_norm(x)
if s2 is not None:
s2 = self.s2_norm(s2)
if self.use_s2_attn_norm:
s2 = self.s2_norm(s2)
if self.encoder_collaboration_mode == "serial":
residual = x
x = self.s2_attn_norm(x)
@@ -399,9 +406,9 @@ class TransformerS2DecoderLayer(nn.Module):
if self.league_drop_net_mix and self.training:
return [frand, 1 - frand]
if frand < self.league_drop_net_prob and self.training:
return [1, 0]
elif frand > 1 - self.league_drop_net_prob and self.training:
return [0, 1]
# elif frand > 1 - self.league_drop_net_prob and self.training:
# return [1, 0]
else:
return [0.5, 0.5]
else:
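The drop-net hunk above is truncated (the trailing else: has no visible body, and the symmetric branch-drop case is now commented out). A self-contained sketch of the ratio selection it appears to implement: random convex mixing when drop-net-mix is on, dropping the second branch with probability drop-net-prob during training, an even 0.5/0.5 average otherwise, and, as an assumption for the hidden else branch, the configured league ratios outside training or when drop-net is disabled.

import torch

def get_league_ratios(training: bool, drop_net: bool, drop_net_prob: float,
                      drop_net_mix: bool, s1_ratio: float, s2_ratio: float):
    """Sketch of the drop-net ratio selection shown in the hunk above."""
    if drop_net and training:
        frand = float(torch.rand(1))
        if drop_net_mix:
            return [frand, 1 - frand]   # random convex combination of the branches
        if frand < drop_net_prob:
            return [1, 0]               # drop the second branch entirely
        return [0.5, 0.5]               # plain average of the two branches
    return [s1_ratio, s2_ratio]         # assumed fallback: fixed league ratios

print(get_league_ratios(False, False, 0.2, False, 0.5, 0.5))  # [0.5, 0.5]
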
@@ -633,7 +640,6 @@ class TransformerS2DecoderLayer(nn.Module):
x2 = self.dropout_module(x2)
ratios = self.get_ratio()
x = ratios[0] * x + ratios[1] * x2
x = x + x2
x = self.residual_connection(x, residual)
if not self.normalize_before:
x = self.encoder_attn_layer_norm(x)
......
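The final hunk replaces the plain sum of the two decoder attention branches with the ratio-weighted league combination before the residual connection. A toy, self-contained usage of that fusion; shapes and ratios are chosen arbitrarily for illustration.

import torch

T, B, C = 10, 2, 256
residual = torch.randn(T, B, C)   # input to the sub-layer
x = torch.randn(T, B, C)          # standard encoder-decoder attention output
x2 = torch.randn(T, B, C)         # second-stream (s2) attention output

ratios = [0.5, 0.5]               # e.g. from a drop-net selection as sketched above
x = ratios[0] * x + ratios[1] * x2
x = residual + x                  # residual connection
print(x.shape)                    # torch.Size([10, 2, 256])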