Commit 8f45faa2 by xuchen

update the dual arch

parent cbeb5521
 arch: s2t_dual
-asr-encoder: pds
+asr-encoder: transformer
 mt-encoder-layers: 6
 mt-encoder: transformer
-encoder-drop-net: True
-encoder-drop-net-prob: 0.5
-encoder-embed-dim: 256
-pds-stages: 4
-#ctc-layer: 12
-pds-layers: 3_3_3_3
-pds-ratios: 2_2_1_2
-pds-fusion: True
-pds-fusion-method: all_conv
-pds-embed-dims: 256_256_256_256
-pds-ds-method: conv
-pds-embed-norm: True
-pds-position-embed: 1_1_1_1
-pds-kernel-sizes: 5_5_5_5
-pds-ffn-ratios: 8_8_8_8
-pds-attn-heads: 4_4_4_4
+encoder-collaboration-mode: parallel
+decoder-collaboration-mode: parallel
+encoder-league-s1-ratio: 0.5
+encoder-league-s2-ratio: 0.5
+encoder-league-drop-net: False
+encoder-league-drop-net-prob: 0.2
+encoder-league-drop-net-mix: False
+decoder-league-s1-ratio: 0.5
+decoder-league-s2-ratio: 0.5
+decoder-league-drop-net: False
+decoder-league-drop-net-prob: 0.0
+decoder-league-drop-net-mix: False
 share-decoder-input-output-embed: True
 optimizer: adam
@@ -35,6 +33,7 @@ label_smoothing: 0.1
 dropout: 0.1
 activation-fn: relu
+encoder-embed-dim: 256
 encoder-ffn-embed-dim: 2048
 encoder-layers: 12
 decoder-layers: 6
...
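For orientation, the league-* options above control how the MT encoder mixes its own hidden states (s1) with the second stream coming from the ASR encoder (s2). A minimal sketch of that mixing rule, consistent with the get_ratio() logic shown further down; the standalone helper and its name are illustrative, not code from this repository:

import random

import torch


def league_combine(s1, s2, s1_ratio=0.5, s2_ratio=0.5,
                   drop_net=False, drop_net_prob=0.0, drop_net_mix=False,
                   training=True):
    # Hypothetical helper mirroring the league options in the config:
    # mix two representations, optionally dropping or re-weighting one
    # of them during training (drop-net).
    if drop_net and training:
        frand = random.random()
        if drop_net_mix:
            # mix the two inputs with a random interpolation weight
            ratios = (frand, 1 - frand)
        elif frand < drop_net_prob:
            # drop the first stream for this forward pass
            ratios = (0.0, 1.0)
        else:
            ratios = (0.5, 0.5)
    else:
        ratios = (s1_ratio, s2_ratio)
    return ratios[0] * s1 + ratios[1] * s2


# usage: two T x B x C tensors, e.g. from the MT and ASR encoders
s1, s2 = torch.randn(10, 2, 256), torch.randn(10, 2, 256)
out = league_combine(s1, s2, drop_net=True, drop_net_prob=0.2)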
@@ -35,13 +35,6 @@ class JoinSpeechTextLoss(
         """Add criterion-specific arguments to the parser."""
         LabelSmoothedCrossEntropyCriterion.add_args(parser)
         CtcCriterion.add_args(parser)
-        parser.add_argument(
-            "--ctc-weight",
-            default=0.0,
-            type=float,
-            metavar="D",
-            help="weight of CTC loss",
-        )

     def forward(self, model, sample, reduce=True):
         """Compute the loss for the given sample.
@@ -65,7 +58,7 @@ class JoinSpeechTextLoss(
         if "mixup" in encoder_out and encoder_out["mixup"] is not None:
             use_mixup = True

-        loss, nll_loss = self.compute_loss(model, net_output, sample, reduce=reduce)
+        loss, nll_loss, other_loss = self.compute_loss(model, net_output, sample, reduce=reduce)
         sample_size = (
             sample["target"].size(0) if self.sentence_avg else sample["ntokens"]
         )
...
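For context, a fairseq criterion's forward() returns the loss, the sample size used to normalize gradients, and a logging dictionary; the change above widens compute_loss() to also hand back auxiliary terms. A hedged, runnable toy stand-in for that contract — the class, placeholder losses, and logging keys are illustrative, not this repository's exact code:

import torch


class ToyCriterion:
    # Minimal stand-in for the fairseq criterion contract that
    # JoinSpeechTextLoss follows: forward() -> (loss, sample_size, logging_output).
    def __init__(self, sentence_avg=False):
        self.sentence_avg = sentence_avg

    def compute_loss(self, model, net_output, sample, reduce=True):
        # placeholder values; the real criterion returns the label-smoothed CE,
        # its NLL component, and auxiliary terms such as a CTC loss
        loss = net_output.sum()
        nll_loss = loss.detach()
        other_loss = {}
        return loss, nll_loss, other_loss

    def forward(self, model, sample, reduce=True):
        net_output = model(sample["net_input"])
        loss, nll_loss, other_loss = self.compute_loss(model, net_output, sample, reduce)
        sample_size = sample["target"].size(0) if self.sentence_avg else sample["ntokens"]
        logging_output = {"loss": float(loss), "nll_loss": float(nll_loss),
                          "ntokens": sample["ntokens"], "sample_size": sample_size}
        return loss, sample_size, logging_output


model = lambda x: x @ torch.randn(4, 8)   # toy "model"
sample = {"net_input": torch.randn(3, 4), "target": torch.zeros(3, 5), "ntokens": 15}
loss, sample_size, log = ToyCriterion().forward(model, sample)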
@@ -73,31 +73,75 @@ class S2TDualModel(FairseqEncoderDecoderModel):
             type=int,
             help="the layers of the MT encoder",
         )
+        # collaboration
         parser.add_argument(
-            "--encoder-asr-ratio",
+            "--encoder-collaboration-mode",
+            default="none",
+            type=str,
+            help="how to calculate attention during league in encoder",
+        )
+        parser.add_argument(
+            "--decoder-collaboration-mode",
+            default="none",
+            type=str,
+            help="how to calculate attention during league in encoder",
+        )
+        # league
+        parser.add_argument(
+            "--encoder-league-s1-ratio",
             default=0.5,
             type=float,
-            help="the ratio of the asr representation",
+            help="league ratio of the s1 representation",
         )
         parser.add_argument(
-            "--encoder-mt-ratio",
+            "--encoder-league-s2-ratio",
             default=0.5,
             type=float,
-            help="the ratio of the mt representation",
+            help="league ratio of the s2 representation",
         )
         parser.add_argument(
-            "--encoder-drop-net",
+            "--encoder-league-drop-net",
             action="store_true",
-            help="drop an input",
+            help="drop one input during league",
         )
         parser.add_argument(
-            "--encoder-drop-net-prob",
-            default=0.2,
+            "--encoder-league-drop-net-prob",
+            default=0.0,
             type=float,
-            help="probability of dropping one of the representations",
+            help="probability of dropping one representations",
         )
         parser.add_argument(
-            "--encoder-drop-net-mix",
+            "--encoder-league-drop-net-mix",
+            action="store_true",
+            help="mix the two input with any probability",
+        )
+        parser.add_argument(
+            "--decoder-league-s1-ratio",
+            default=0.5,
+            type=float,
+            help="league ratio of the s1 representation",
+        )
+        parser.add_argument(
+            "--decoder-league-s2-ratio",
+            default=0.5,
+            type=float,
+            help="league ratio of the s2 representation",
+        )
+        parser.add_argument(
+            "--decoder-league-drop-net",
+            action="store_true",
+            help="drop one input during league",
+        )
+        parser.add_argument(
+            "--decoder-league-drop-net-prob",
+            default=0.0,
+            type=float,
+            help="probability of dropping one representations",
+        )
+        parser.add_argument(
+            "--decoder-league-drop-net-mix",
             action="store_true",
             help="mix the two input with any probability",
         )
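The two values accepted by --encoder-collaboration-mode differ in where the attention over the second stream is applied. A rough, self-contained sketch of the distinction, abstracted from the layer code shown later in this commit; the class, method, and attribute names here are illustrative only:

import torch
import torch.nn as nn


class CollaborationBlock(nn.Module):
    # Illustrative block: combine the self-attention output over x with a
    # second stream s2, either serially (cross-attend, then continue) or in
    # parallel (cross-attend separately, then league-mix with fixed ratios).
    def __init__(self, dim, heads, mode="parallel", s1_ratio=0.5, s2_ratio=0.5):
        super().__init__()
        self.mode = mode
        self.s1_ratio, self.s2_ratio = s1_ratio, s2_ratio
        self.self_attn = nn.MultiheadAttention(dim, heads)
        self.s2_attn = nn.MultiheadAttention(dim, heads)

    def forward(self, x, s2):
        # x: (T1, B, C), s2: (T2, B, C)
        y, _ = self.self_attn(x, x, x)
        if self.mode == "serial":
            # serial: cross-attend into s2 after self-attention
            y, _ = self.s2_attn(y, s2, s2)
            return y
        # parallel: cross-attend from the original input, then league-mix
        y2, _ = self.s2_attn(x, s2, s2)
        return self.s1_ratio * y + self.s2_ratio * y2


x, s2 = torch.randn(10, 2, 256), torch.randn(20, 2, 256)
out = CollaborationBlock(256, 4, mode="parallel")(x, s2)   # (10, 2, 256)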
@@ -241,6 +285,8 @@ class S2TDualEncoder(FairseqEncoder):
         super().__init__(None)

         asr_encoder_type = args.asr_encoder
+        args.encoder_layers = 12
+
         if asr_encoder_type == "transformer":
             self.asr_encoder = S2TTransformerEncoder(args, task)
         elif asr_encoder_type == "pds":
@@ -250,14 +296,15 @@ class S2TDualEncoder(FairseqEncoder):
         else:
             logger.error("Unsupported ASR architecture: %s." % asr_encoder_type)

-        setattr(args, "encoder_s1_ratio", args.encoder_asr_ratio)
-        setattr(args, "encoder_s2_ratio", args.encoder_mt_ratio)
+        self.encoder_collaboration_mode = args.encoder_collaboration_mode
+        setattr(args, "use_s2_attn_norm", False)
+        asr_encoder_layers = args.encoder_layers
         setattr(args, "encoder_layers", args.mt_encoder_layers)
         attn_type = args.encoder_attention_type
         setattr(args, "encoder_attention_type", "selfattn")
         self.mt_encoder = TransformerS2Encoder(args, task.source_dictionary, embed_tokens)
         setattr(args, "encoder_attention_type", attn_type)
+        setattr(args, "encoder_layers", asr_encoder_layers)

     def forward(self, speech_src_tokens, speech_src_lengths, text_src_tokens, text_src_lengths, **kwargs):
         asr_encoder_out = self.asr_encoder(speech_src_tokens, speech_src_lengths)
@@ -269,6 +316,13 @@ class S2TDualEncoder(FairseqEncoder):
         encoder_out["ctc_logit"] = asr_encoder_out["ctc_logit"]
         encoder_out["ctc_padding_mask"] = asr_encoder_out["encoder_padding_mask"]

+        # encoder_out["encoder_out"] = encoder_out["s2_encoder_out"]
+        # encoder_out["encoder_padding_mask"] = encoder_out["s2_encoder_padding_mask"]
+        #
+        # encoder_out["s2_encoder_out"] = []
+        # encoder_out["s2_encoder_padding_mask"] = []
+
         return encoder_out
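Both encoders are built from the same args namespace, which is why the hunk above now saves the ASR encoder's depth, overrides encoder_layers while constructing the MT encoder, and restores it afterwards. A tiny runnable illustration of that save/restore pattern (the encoder construction itself is elided):

from argparse import Namespace

args = Namespace(encoder_layers=12, mt_encoder_layers=6)

# the ASR encoder would be built first, seeing 12 layers
asr_layers_used = args.encoder_layers

# temporarily override the shared namespace for the MT encoder, then restore,
# mirroring the lines added in this hunk
asr_encoder_layers = args.encoder_layers
setattr(args, "encoder_layers", args.mt_encoder_layers)
mt_layers_used = args.encoder_layers                      # 6
setattr(args, "encoder_layers", asr_encoder_layers)

print(asr_layers_used, mt_layers_used, args.encoder_layers)   # 12 6 12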
     def forward_torchscript(self, net_input: Dict[str, Tensor]):
@@ -301,15 +355,15 @@ class S2TDualEncoder(FairseqEncoder):
             new_encoder_padding_mask = [
                 encoder_out["encoder_padding_mask"][0].index_select(0, new_order)
             ]
-        if len(encoder_out["encoder_out_s2"]) == 0:
-            new_encoder_out_s2 = []
+        if len(encoder_out["s2_encoder_out"]) == 0:
+            new_s2_encoder_out = []
         else:
-            new_encoder_out_s2 = [encoder_out["encoder_out_s2"][0].index_select(1, new_order)]
-        if len(encoder_out["encoder_padding_mask_s2"]) == 0:
-            new_encoder_padding_mask_s2 = []
+            new_s2_encoder_out = [encoder_out["s2_encoder_out"][0].index_select(1, new_order)]
+        if len(encoder_out["s2_encoder_padding_mask"]) == 0:
+            new_s2_encoder_padding_mask = []
         else:
-            new_encoder_padding_mask_s2 = [
-                encoder_out["encoder_padding_mask_s2"][0].index_select(0, new_order)
+            new_s2_encoder_padding_mask = [
+                encoder_out["s2_encoder_padding_mask"][0].index_select(0, new_order)
             ]
         if len(encoder_out["encoder_embedding"]) == 0:
             new_encoder_embedding = []
@@ -336,8 +390,8 @@ class S2TDualEncoder(FairseqEncoder):
         return {
             "encoder_out": new_encoder_out,  # T x B x C
             "encoder_padding_mask": new_encoder_padding_mask,  # B x T
-            "encoder_out_s2": new_encoder_out_s2,  # T x B x C
-            "encoder_padding_mask_s2": new_encoder_padding_mask_s2,  # B x T
+            "s2_encoder_out": new_s2_encoder_out,  # T x B x C
+            "s2_encoder_padding_mask": new_s2_encoder_padding_mask,  # B x T
             "encoder_embedding": new_encoder_embedding,  # B x T x C
             "encoder_states": encoder_states,  # List[T x B x C]
             "src_tokens": src_tokens,  # B x T
@@ -356,8 +410,7 @@ def base_architecture(args):
     args.subsampling_norm = getattr(args, "subsampling_norm", "none")
     args.subsampling_activation = getattr(args, "subsampling_activation", "glu")

-    args.encoder_learned_pos = getattr(args, "encoder_learned_pos", False)
+    # Transformer
     args.encoder_embed_dim = getattr(args, "encoder_embed_dim", 512)
     args.encoder_ffn_embed_dim = getattr(args, "encoder_ffn_embed_dim", 2048)
     args.encoder_layers = getattr(args, "encoder_layers", 12)
@@ -372,6 +425,7 @@ def base_architecture(args):
     args.decoder_attention_type = getattr(args, "decoder_attention_type", "selfattn")
     args.decoder_attention_heads = getattr(args, "decoder_attention_heads", 8)
     args.decoder_normalize_before = getattr(args, "decoder_normalize_before", True)
+    args.encoder_learned_pos = getattr(args, "encoder_learned_pos", False)
     args.decoder_learned_pos = getattr(args, "decoder_learned_pos", False)
     args.dropout = getattr(args, "dropout", 0.1)
     args.attention_dropout = getattr(args, "attention_dropout", args.dropout)
@@ -379,6 +433,10 @@ def base_architecture(args):
     args.activation_fn = getattr(args, "activation_fn", "relu")
     args.adaptive_softmax_cutoff = getattr(args, "adaptive_softmax_cutoff", None)
     args.adaptive_softmax_dropout = getattr(args, "adaptive_softmax_dropout", 0)
+    args.tie_adaptive_weights = getattr(args, "tie_adaptive_weights", False)
+    args.tie_adaptive_proj = getattr(args, "tie_adaptive_proj", False)
+    args.adaptive_softmax_factor = getattr(args, "adaptive_softmax_factor", 4)
     args.share_decoder_input_output_embed = getattr(
         args, "share_decoder_input_output_embed", False
     )
@@ -394,25 +452,64 @@ def base_architecture(args):
     )
     args.decoder_input_dim = getattr(args, "decoder_input_dim", args.decoder_embed_dim)
     args.no_scale_embedding = getattr(args, "no_scale_embedding", False)
+    args.encoder_no_scale_embedding = getattr(args, "encoder_no_scale_embedding", False)
     args.quant_noise_pq = getattr(args, "quant_noise_pq", 0)
+    args.encoder_embed_linear = getattr(args, "encoder_embed_linear", False)
+    args.encoder_embed_norm = getattr(args, "encoder_embed_norm", False)

-    args.max_encoder_relative_length = getattr(args, 'max_encoder_relative_length', -1)
-    args.max_decoder_relative_length = getattr(args, 'max_decoder_relative_length', -1)
-    args.k_only = getattr(args, 'k_only', True)
+    # CTC
+    args.ctc_layer = getattr(args, "ctc_layer", 0)
+    args.share_ctc_and_embed = getattr(args, "share_ctc_and_embed", False)

     # Conformer
+    args.encoder_activation_fn = getattr(args, "encoder_activation_fn", "relu")
     args.macaron_style = getattr(args, "macaron_style", False)
     args.use_cnn_module = getattr(args, "use_cnn_module", False)
     args.cnn_module_kernel = getattr(args, "cnn_module_kernel", 31)
+    args.cnn_module_norm = getattr(args, "cnn_module_norm", "batch_norm")

+    # settings for DLCL
+    args.use_enc_dlcl = getattr(args, "use_enc_dlcl", False)
+    args.use_dec_dlcl = getattr(args, "use_dec_dlcl", False)
+    args.init_value = getattr(args, 'init_value', 'avg')
+    args.weight_type = getattr(args, 'weight_type', 'scalar')
+    args.encoder_learnable = getattr(args, 'encoder_learnable', True)
+    args.normalize_embed = getattr(args, 'normalize_embed', False)
+    args.history_dropout = getattr(args, 'history_dropout', 0.0)
+    args.history_window_size = getattr(args, 'history_window_size', -1)

+    # Relative position encoding
+    args.max_encoder_relative_length = getattr(args, 'max_encoder_relative_length', -1)
+    args.k_only = getattr(args, 'k_only', True)

-    # SATE
-    args.acoustic_encoder = getattr(args, "acoustic_encoder", "transformer")
-    args.adapter = getattr(args, "adapter", "league")
-    args.ctc_compress_strategy = getattr(args, "ctc_compress_strategy", "avg")
-    args.temperature = getattr(args, "temperature", 1.0)
-    args.text_encoder_layers = getattr(args, "text_encoder_layers", 6)
-    args.text_attention_type = getattr(args, "text_attention_type", "selfattn")
-    args.share_ctc_and_adapter = getattr(args, "share_ctc_and_adapter", False)
+    # local modeling
+    args.hard_mask_window = getattr(args, 'hard_mask_window', 0)
+    args.gauss_mask_sigma = getattr(args, 'gauss_mask_sigma', 0)
+    args.init_mask_weight = getattr(args, 'init_mask_weight', 0)

+    # interleaved CTC
+    args.interleaved_ctc_layers = getattr(args, "interleaved_ctc_layers", None)
+    args.interleaved_ctc_temperature = getattr(args, "interleaved_ctc_temperature", 1)
+    args.interleaved_ctc_drop_prob = getattr(args, "interleaved_ctc_drop_prob", 0)

+    # Semantics-augmented Encoding (sae)
+    args.sae_adapter = getattr(args, "sae_adapter", "none")
+    args.target_sae_adapter = getattr(args, "target_sae_adapter", args.sae_adapter)
+    args.share_sae_and_ctc = getattr(args, "share_sae_and_ctc", False)
+    args.share_target_sae_and_ctc = getattr(args, "share_target_sae_and_ctc", False)
+    args.sae_drop_prob = getattr(args, "sae_drop_prob", 0)
+    args.sae_distribution_cutoff = getattr(args, "sae_distribution_cutoff", None)
+    args.sae_distribution_hard = getattr(args, "sae_distribution_hard", False)
+    args.sae_gumbel = getattr(args, "sae_gumbel", False)

+    # mixup
+    args.inter_mixup = getattr(args, "inter_mixup", False)
+    args.inter_mixup_layer = getattr(args, "inter_mixup_layer", None)
+    args.inter_mixup_beta = getattr(args, "inter_mixup_beta", 0.5)
+    args.inter_mixup_prob = getattr(args, "inter_mixup_prob", 1)
+    args.inter_mixup_ratio = getattr(args, "inter_mixup_ratio", 0.3)
+    args.inter_mixup_keep_org = getattr(args, "inter_mixup_keep_org", False)

     # PDS
     args.pds_stages = getattr(args, "pds_stages", None)
@@ -432,23 +529,25 @@ def base_architecture(args):
     args.pds_conv_strides = getattr(args, "pds_conv_strides", None)
     args.pds_attn_strides = getattr(args, "pds_attn_strides", None)

-    args.ctc_layer = getattr(args, "ctc_layer", 0)
     args.pds_dropout = getattr(args, "pds_dropout", args.dropout)
     args.pds_fusion = getattr(args, "pds_fusion", False)
     args.pds_fusion_method = getattr(args, "pds_fusion_method", "all_conv")

-    # intermedia CTC
-    args.pds_ctc = getattr(args, "pds_ctc", "0_0_0_0")
-    args.intermedia_adapter = getattr(args, "intermedia_adapter", "none")
-    args.intermedia_drop_prob = getattr(args, "intermedia_drop_prob", 0)

     # dual
-    args.encoder_asr_ratio = getattr(args, "encoder_asr_ratio", 1.0)
-    args.encoder_mt_ratio = getattr(args, "encoder_mt_ratio", 1.0)
-    args.encoder_drop_net = getattr(args, "encoder_drop_net", False)
-    args.encoder_drop_net_prob = getattr(args, "encoder_drop_net_prob", 1.0)
-    args.encoder_drop_net_mix = getattr(args, "encoder_drop_net_mix", False)
+    args.encoder_collaboration_mode = getattr(args, "encoder_collaboration_mode", "none")
+    args.decoder_collaboration_mode = getattr(args, "decoder_collaboration_mode", "none")
+    args.encoder_league_s1_ratio = getattr(args, "encoder_league_s1_ratio", 0.5)
+    args.encoder_league_s2_ratio = getattr(args, "encoder_league_s2_ratio", 0.5)
+    args.encoder_league_drop_net = getattr(args, "encoder_league_drop_net", False)
+    args.encoder_league_drop_net_prob = getattr(args, "encoder_league_drop_net_prob", 0.0)
+    args.encoder_league_drop_net_mix = getattr(args, "encoder_league_drop_net_mix", False)
+    args.decoder_league_s1_ratio = getattr(args, "decoder_league_s1_ratio", 0.5)
+    args.decoder_league_s2_ratio = getattr(args, "decoder_league_s2_ratio", 0.5)
+    args.decoder_league_drop_net = getattr(args, "decoder_league_drop_net", False)
+    args.decoder_league_drop_net_prob = getattr(args, "decoder_league_drop_net_prob", 0.0)
+    args.decoder_league_drop_net_mix = getattr(args, "decoder_league_drop_net_mix", False)


 @register_model_architecture("s2t_dual", "s2t_dual_s")
...
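All of these defaults rely on the same idiom: getattr(args, name, default) fills in any option the user did not set, so a registered architecture only needs to override what differs from the base. A tiny self-contained illustration of that behaviour (the function and option names here are just examples):

from argparse import Namespace


def toy_base_architecture(args):
    # mirror the pattern used above: only apply the default if the attribute is absent
    args.encoder_embed_dim = getattr(args, "encoder_embed_dim", 512)
    args.encoder_layers = getattr(args, "encoder_layers", 12)


args = Namespace(encoder_layers=6)   # the user overrode only the layer count
toy_base_architecture(args)
print(args.encoder_embed_dim, args.encoder_layers)   # -> 512 6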
@@ -79,18 +79,23 @@ class TransformerS2EncoderLayer(nn.Module):
         if self.use_se:
             self.se_attn = SEAttention(self.embed_dim, 16)

-        self.s2_norm = LayerNorm(self.embed_dim)
-        self.s2_attn_norm = LayerNorm(self.embed_dim)
-        self.s2_attn = MultiheadAttention(
-            self.embed_dim,
-            args.encoder_attention_heads,
-            kdim=getattr(args, "encoder_x2_dim", self.embed_dim),
-            vdim=getattr(args, "encoder_x2_dim", self.embed_dim),
-            dropout=args.attention_dropout,
-            self_attention=False,
-        )
-        self.encoder_collaboration_mode = args.encoder_collaboration_mode
+        self.use_s2_attn_norm = args.use_s2_attn_norm
+        if self.use_s2_attn_norm:
+            self.s2_norm = LayerNorm(self.embed_dim)
+
+        self.encoder_collaboration_mode = args.encoder_collaboration_mode
+        if self.encoder_collaboration_mode != "none":
+            if self.encoder_collaboration_mode == "serial":
+                self.s2_attn_norm = LayerNorm(self.embed_dim)
+            self.s2_attn = MultiheadAttention(
+                self.embed_dim,
+                args.encoder_attention_heads,
+                kdim=getattr(args, "encoder_s2_dim", self.embed_dim),
+                vdim=getattr(args, "encoder_s2_dim", self.embed_dim),
+                dropout=args.attention_dropout,
+                self_attention=False,
+            )

         self.league_s1_ratio = args.encoder_league_s1_ratio
         self.league_s2_ratio = args.encoder_league_s2_ratio
@@ -251,7 +256,9 @@ class TransformerS2EncoderLayer(nn.Module):
             x = self.self_attn_layer_norm(x)

         if s2 is not None:
-            s2 = self.s2_norm(s2)
+            if self.use_s2_attn_norm:
+                s2 = self.s2_norm(s2)
             if self.encoder_collaboration_mode == "serial":
                 residual = x
                 x = self.s2_attn_norm(x)
@@ -399,9 +406,9 @@ class TransformerS2DecoderLayer(nn.Module):
             if self.league_drop_net_mix and self.training:
                 return [frand, 1 - frand]
             if frand < self.league_drop_net_prob and self.training:
-                return [1, 0]
-            elif frand > 1 - self.league_drop_net_prob and self.training:
                 return [0, 1]
+            # elif frand > 1 - self.league_drop_net_prob and self.training:
+            #     return [1, 0]
             else:
                 return [0.5, 0.5]
         else:
@@ -633,7 +640,6 @@ class TransformerS2DecoderLayer(nn.Module):
             x2 = self.dropout_module(x2)
             ratios = self.get_ratio()
             x = ratios[0] * x + ratios[1] * x2
-            x = x + x2
             x = self.residual_connection(x, residual)
             if not self.normalize_before:
                 x = self.encoder_attn_layer_norm(x)
...
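The last hunk is worth a note: with the stray x = x + x2 in place, the second stream was added once more after the league mix, so a nominal 0.5/0.5 mix actually weighted x2 three times as heavily as x. A small numeric check of the difference (stand-in values, not the layer's real activations):

import torch

x = torch.tensor([2.0])    # stand-in for the first attention output
x2 = torch.tensor([4.0])   # stand-in for the second-stream attention output
ratios = [0.5, 0.5]

league = ratios[0] * x + ratios[1] * x2        # intended mix: 3.0
buggy = ratios[0] * x + ratios[1] * x2 + x2    # old extra addition: 7.0

print(league.item(), buggy.item())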