Commit 8f45faa2 by xuchen

update the dual arch

parent cbeb5521
 arch: s2t_dual
-asr-encoder: pds
+asr-encoder: transformer
 mt-encoder-layers: 6
 mt-encoder: transformer
-encoder-drop-net: True
-encoder-drop-net-prob: 0.5
-encoder-embed-dim: 256
-pds-stages: 4
-#ctc-layer: 12
-pds-layers: 3_3_3_3
-pds-ratios: 2_2_1_2
-pds-fusion: True
-pds-fusion-method: all_conv
-pds-embed-dims: 256_256_256_256
-pds-ds-method: conv
-pds-embed-norm: True
-pds-position-embed: 1_1_1_1
-pds-kernel-sizes: 5_5_5_5
-pds-ffn-ratios: 8_8_8_8
-pds-attn-heads: 4_4_4_4
+encoder-collaboration-mode: parallel
+decoder-collaboration-mode: parallel
+
+encoder-league-s1-ratio: 0.5
+encoder-league-s2-ratio: 0.5
+encoder-league-drop-net: False
+encoder-league-drop-net-prob: 0.2
+encoder-league-drop-net-mix: False
+
+decoder-league-s1-ratio: 0.5
+decoder-league-s2-ratio: 0.5
+decoder-league-drop-net: False
+decoder-league-drop-net-prob: 0.0
+decoder-league-drop-net-mix: False
 share-decoder-input-output-embed: True
 optimizer: adam
...
@@ -35,6 +33,7 @@ label_smoothing: 0.1
 dropout: 0.1
 activation-fn: relu
+encoder-embed-dim: 256
 encoder-ffn-embed-dim: 2048
 encoder-layers: 12
 decoder-layers: 6
...
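For orientation, the new encoder-league-* / decoder-league-* options appear to set the weights used when the dual architecture fuses its two streams (speech and text encoder states). Below is a minimal sketch of that weighted combination, assuming the two hidden states share a shape; the function name and tensor shapes are illustrative, not the repository's exact API:

import torch

def league_fuse(x1: torch.Tensor, x2: torch.Tensor,
                s1_ratio: float = 0.5, s2_ratio: float = 0.5) -> torch.Tensor:
    # Weighted mix of two same-shaped streams, mirroring the
    # *-league-s1-ratio / *-league-s2-ratio settings above (illustrative).
    return s1_ratio * x1 + s2_ratio * x2

# (seq_len, batch, dim) states fused with the 0.5/0.5 defaults from the config.
a = torch.randn(10, 2, 256)
b = torch.randn(10, 2, 256)
assert league_fuse(a, b).shape == a.shape

With both ratios at 0.5 this is a plain average; the *-league-drop-net options control a stochastic variant of these ratios during training (see the decoder diff further down).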
...
@@ -35,13 +35,6 @@ class JoinSpeechTextLoss(
         """Add criterion-specific arguments to the parser."""
         LabelSmoothedCrossEntropyCriterion.add_args(parser)
         CtcCriterion.add_args(parser)
-        parser.add_argument(
-            "--ctc-weight",
-            default=0.0,
-            type=float,
-            metavar="D",
-            help="weight of CTC loss",
-        )

     def forward(self, model, sample, reduce=True):
         """Compute the loss for the given sample.
...
@@ -65,7 +58,7 @@ class JoinSpeechTextLoss(
         if "mixup" in encoder_out and encoder_out["mixup"] is not None:
             use_mixup = True

-        loss, nll_loss = self.compute_loss(model, net_output, sample, reduce=reduce)
+        loss, nll_loss, other_loss = self.compute_loss(model, net_output, sample, reduce=reduce)
         sample_size = (
             sample["target"].size(0) if self.sentence_avg else sample["ntokens"]
         )
...
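The deleted --ctc-weight registration reads as a de-duplication: CtcCriterion.add_args(parser) presumably already registers that flag (an assumption based on this diff, not verified against the repository), and argparse rejects a second registration of the same option string. A minimal standalone reproduction of that failure mode:

import argparse

parser = argparse.ArgumentParser()
parser.add_argument("--ctc-weight", default=0.0, type=float, help="weight of CTC loss")
try:
    # Registering the same flag again, as the old criterion effectively did:
    parser.add_argument("--ctc-weight", default=0.0, type=float, help="weight of CTC loss")
except argparse.ArgumentError as err:
    print(f"duplicate option rejected: {err}")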
...
@@ -79,18 +79,23 @@ class TransformerS2EncoderLayer(nn.Module):
         if self.use_se:
             self.se_attn = SEAttention(self.embed_dim, 16)

-        self.s2_norm = LayerNorm(self.embed_dim)
-        self.s2_attn_norm = LayerNorm(self.embed_dim)
-        self.s2_attn = MultiheadAttention(
-            self.embed_dim,
-            args.encoder_attention_heads,
-            kdim=getattr(args, "encoder_x2_dim", self.embed_dim),
-            vdim=getattr(args, "encoder_x2_dim", self.embed_dim),
-            dropout=args.attention_dropout,
-            self_attention=False,
-        )
+        self.use_s2_attn_norm = args.use_s2_attn_norm
+        if self.use_s2_attn_norm:
+            self.s2_norm = LayerNorm(self.embed_dim)

         self.encoder_collaboration_mode = args.encoder_collaboration_mode
+        if self.encoder_collaboration_mode != "none":
+            if self.encoder_collaboration_mode == "serial":
+                self.s2_attn_norm = LayerNorm(self.embed_dim)
+            self.s2_attn = MultiheadAttention(
+                self.embed_dim,
+                args.encoder_attention_heads,
+                kdim=getattr(args, "encoder_s2_dim", self.embed_dim),
+                vdim=getattr(args, "encoder_s2_dim", self.embed_dim),
+                dropout=args.attention_dropout,
+                self_attention=False,
+            )
         self.league_s1_ratio = args.encoder_league_s1_ratio
         self.league_s2_ratio = args.encoder_league_s2_ratio
...
@@ -251,7 +256,9 @@ class TransformerS2EncoderLayer(nn.Module):
             x = self.self_attn_layer_norm(x)

         if s2 is not None:
-            s2 = self.s2_norm(s2)
+            if self.use_s2_attn_norm:
+                s2 = self.s2_norm(s2)

             if self.encoder_collaboration_mode == "serial":
                 residual = x
                 x = self.s2_attn_norm(x)
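Outside the diff context: encoder-collaboration-mode selects how the second stream s2 enters a layer. After this change, s2_attn is only built when the mode is not "none", the extra pre-norm only for "serial", and the s2 input norm is gated by the new use_s2_attn_norm flag. A simplified, self-contained sketch of the two modes under those flags; this toy layer and its fixed 0.5/0.5 mixing weights are illustrative, not the repository's full implementation:

import torch
import torch.nn as nn

class ToyS2Layer(nn.Module):
    # Toy stand-in for TransformerS2EncoderLayer's handling of the second
    # stream s2; control flow mirrors the diff, internals are simplified.
    def __init__(self, dim, heads, mode="parallel", use_s2_attn_norm=True):
        super().__init__()
        self.mode = mode
        self.use_s2_attn_norm = use_s2_attn_norm
        if use_s2_attn_norm:
            self.s2_norm = nn.LayerNorm(dim)           # normalize incoming s2
        if mode != "none":
            if mode == "serial":
                self.s2_attn_norm = nn.LayerNorm(dim)  # pre-norm for the extra sub-layer
            self.s2_attn = nn.MultiheadAttention(dim, heads)

    def forward(self, x, s2=None):
        if s2 is None or self.mode == "none":
            return x
        if self.use_s2_attn_norm:
            s2 = self.s2_norm(s2)
        if self.mode == "serial":
            # serial: a dedicated cross-attention sub-layer with its own residual
            residual = x
            out, _ = self.s2_attn(self.s2_attn_norm(x), s2, s2)
            return residual + out
        # parallel: cross-attention output mixed with x (league-style weighting)
        out, _ = self.s2_attn(x, s2, s2)
        return 0.5 * x + 0.5 * out

layer = ToyS2Layer(dim=256, heads=4, mode="parallel")
x = torch.randn(10, 2, 256)    # (seq, batch, dim) speech states
s2 = torch.randn(12, 2, 256)   # text encoder states as the second stream
print(layer(x, s2).shape)      # torch.Size([10, 2, 256])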
...
@@ -399,9 +406,9 @@ class TransformerS2DecoderLayer(nn.Module):
             if self.league_drop_net_mix and self.training:
                 return [frand, 1 - frand]
             if frand < self.league_drop_net_prob and self.training:
-                return [1, 0]
-            elif frand > 1 - self.league_drop_net_prob and self.training:
                 return [0, 1]
+            # elif frand > 1 - self.league_drop_net_prob and self.training:
+            #     return [1, 0]
             else:
                 return [0.5, 0.5]
         else:
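The get_ratio edit changes which stream drop-net keeps when the random draw falls under league_drop_net_prob, and comments out the symmetric branch. Restated as a standalone function, assuming index 0 weights the first stream and index 1 the second; the outer condition on the drop-net flag is inferred from the option names, not copied from the repository:

import random

def get_ratio(training, drop_net=True, drop_net_prob=0.2, drop_net_mix=False):
    # Stochastic two-stream weights, restating the new branch structure.
    if drop_net and drop_net_prob > 0 and training:
        frand = float(random.random())
        if drop_net_mix:
            return [frand, 1 - frand]   # random convex mix of both streams
        if frand < drop_net_prob:
            return [0, 1]               # drop stream 1, keep only stream 2
        return [0.5, 0.5]               # otherwise fall back to an equal mix
    return [0.5, 0.5]                   # inference: deterministic equal mix

With the commented-out branch gone, drop-net is now one-sided: it can zero out the first stream but never the second.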
...
@@ -633,7 +640,6 @@ class TransformerS2DecoderLayer(nn.Module):
                 x2 = self.dropout_module(x2)
                 ratios = self.get_ratio()
                 x = ratios[0] * x + ratios[1] * x2
-                x = x + x2
             x = self.residual_connection(x, residual)
             if not self.normalize_before:
                 x = self.encoder_attn_layer_norm(x)
...
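The single deletion in the last hunk looks like a bug fix rather than a behavior toggle: the old code applied the ratio-weighted combination and then added x2 a second time, so with the default [0.5, 0.5] ratios the effective mix was 0.5*x + 1.5*x2 instead of a convex combination. A tiny arithmetic check (names illustrative):

import torch

x, x2 = torch.ones(3), torch.full((3,), 2.0)
ratios = [0.5, 0.5]

old = ratios[0] * x + ratios[1] * x2 + x2   # old path: x2 enters twice
new = ratios[0] * x + ratios[1] * x2        # fixed path: convex combination
print(old)  # tensor([3.5000, 3.5000, 3.5000])  = 0.5*1 + 1.5*2
print(new)  # tensor([1.5000, 1.5000, 1.5000])  = 0.5*1 + 0.5*2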