Commit a2fd43e7 by xuchen

Fix bugs in the multi-branch architecture

parent 47e0f6e0
...@@ -4,4 +4,5 @@ inter-mixup-prob: 1.0 ...@@ -4,4 +4,5 @@ inter-mixup-prob: 1.0
inter-mixup-ratio: 1.0 inter-mixup-ratio: 1.0
inter-mixup-beta: 0.5 inter-mixup-beta: 0.5
inter-mixup-keep-org: True inter-mixup-keep-org: True
ctc-mixupr-consistent-weight: 1 ctc-mixup-consistent-weight: 1
mixup-consistent-weight: 1
\ No newline at end of file
...@@ -19,6 +19,7 @@ from fairseq.models import ( ...@@ -19,6 +19,7 @@ from fairseq.models import (
register_model_architecture, register_model_architecture,
) )
from fairseq.modules import ( from fairseq.modules import (
FairseqDropout,
LayerNorm, LayerNorm,
PositionalEmbedding, PositionalEmbedding,
LegacyRelPositionalEncoding, LegacyRelPositionalEncoding,
...@@ -103,6 +104,12 @@ class S2TMultiBranchModel(FairseqEncoderDecoderModel): ...@@ -103,6 +104,12 @@ class S2TMultiBranchModel(FairseqEncoderDecoderModel):
help="direction of collaboration", help="direction of collaboration",
) )
parser.add_argument( parser.add_argument(
"--collaboration-start",
default="0:0",
type=str,
help="start collaboration in two encoders",
)
parser.add_argument(
"--collaboration-step", "--collaboration-step",
default="1:1", default="1:1",
type=str, type=str,
...@@ -329,6 +336,9 @@ class S2TMultiBranchEncoder(FairseqEncoder): ...@@ -329,6 +336,9 @@ class S2TMultiBranchEncoder(FairseqEncoder):
def __init__(self, args, task=None, embed_tokens=None): def __init__(self, args, task=None, embed_tokens=None):
super().__init__(None) super().__init__(None)
self.padding_idx = 1 self.padding_idx = 1
self.dropout_module = FairseqDropout(
p=args.dropout, module_name=self.__class__.__name__
)
setattr(args, "encoder_layers", args.junior_acoustic_encoder_layers) setattr(args, "encoder_layers", args.junior_acoustic_encoder_layers)
junior_encoder_type = args.junior_acoustic_encoder junior_encoder_type = args.junior_acoustic_encoder
...@@ -392,11 +402,22 @@ class S2TMultiBranchEncoder(FairseqEncoder): ...@@ -392,11 +402,22 @@ class S2TMultiBranchEncoder(FairseqEncoder):
# collaboration # collaboration
collaboration_step = args.collaboration_step collaboration_step = args.collaboration_step
self.collaboration_direction = args.collaboration_direction
collaboration_start = args.collaboration_start
if len(collaboration_start.split(":")) == 2:
self.collaboration_start = [int(s) for s in collaboration_start.split(":")]
elif len(collaboration_start.split(":")) == 1:
self.collaboration_start = [int(collaboration_start), int(collaboration_start)]
else:
self.collaboration_start = [0, 0]
if len(collaboration_step.split(":")) == 2: if len(collaboration_step.split(":")) == 2:
self.collaboration_step = [int(s) for s in collaboration_step.split(":")] self.collaboration_step = [int(s) for s in collaboration_step.split(":")]
elif len(collaboration_step.split(":")) == 1:
self.collaboration_step = [int(collaboration_step), int(collaboration_step)]
else: else:
self.collaboration_step = [1, 1] self.collaboration_step = [1, 1]
self.collaboration_direction = args.collaboration_direction
self.acoustic_norm = LayerNorm(args.encoder_embed_dim) self.acoustic_norm = LayerNorm(args.encoder_embed_dim)
self.textual_norm = LayerNorm(args.encoder_embed_dim) self.textual_norm = LayerNorm(args.encoder_embed_dim)
...@@ -414,7 +435,7 @@ class S2TMultiBranchEncoder(FairseqEncoder): ...@@ -414,7 +435,7 @@ class S2TMultiBranchEncoder(FairseqEncoder):
adapter_x, adapter_encoder_padding_mask = self.adapter(x, acoustic_encoder_padding_mask) adapter_x, adapter_encoder_padding_mask = self.adapter(x, acoustic_encoder_padding_mask)
textual_x = adapter_x + self.textual_embed_positions(adapter_encoder_padding_mask).transpose(0, 1) textual_x = adapter_x + self.textual_embed_positions(adapter_encoder_padding_mask).transpose(0, 1)
# textual_x = self.dropout_module(textual_x) textual_x = self.dropout_module(textual_x)
textual_encoder_padding_mask = adapter_encoder_padding_mask textual_encoder_padding_mask = adapter_encoder_padding_mask
senior_acoustic_encoder_idx = -1 senior_acoustic_encoder_idx = -1
...@@ -425,51 +446,67 @@ class S2TMultiBranchEncoder(FairseqEncoder): ...@@ -425,51 +446,67 @@ class S2TMultiBranchEncoder(FairseqEncoder):
for _ in range(self.collaboration_step[1]): for _ in range(self.collaboration_step[1]):
textual_encoder_idx += 1 textual_encoder_idx += 1
textual_x = self.textual_encoder_layers[textual_encoder_idx]( textual_x = self.textual_encoder_layers[textual_encoder_idx](
textual_x, encoder_padding_mask=textual_encoder_padding_mask, textual_x,
encoder_padding_mask=textual_encoder_padding_mask,
) )
for _ in range(self.collaboration_step[0]): for _ in range(self.collaboration_step[0]):
senior_acoustic_encoder_idx += 1 senior_acoustic_encoder_idx += 1
acoustic_x = self.senior_acoustic_encoder_layers[senior_acoustic_encoder_idx]( acoustic_x = self.senior_acoustic_encoder_layers[senior_acoustic_encoder_idx](
acoustic_x, encoder_padding_mask=acoustic_encoder_padding_mask, acoustic_x,
s2=textual_x, s2_encoder_padding_mask=textual_encoder_padding_mask encoder_padding_mask=acoustic_encoder_padding_mask,
s2=textual_x if senior_acoustic_encoder_idx >= self.collaboration_start[0] else None,
s2_encoder_padding_mask=textual_encoder_padding_mask if senior_acoustic_encoder_idx >=
self.collaboration_start[0] else None,
) )
elif self.collaboration_direction == "textual": elif self.collaboration_direction == "textual":
for _ in range(self.collaboration_step[0]): for _ in range(self.collaboration_step[0]):
senior_acoustic_encoder_idx += 1 senior_acoustic_encoder_idx += 1
acoustic_x = self.senior_acoustic_encoder_layers[senior_acoustic_encoder_idx]( acoustic_x = self.senior_acoustic_encoder_layers[senior_acoustic_encoder_idx](
acoustic_x, encoder_padding_mask=acoustic_encoder_padding_mask, acoustic_x,
encoder_padding_mask=acoustic_encoder_padding_mask,
) )
for _ in range(self.collaboration_step[1]): for _ in range(self.collaboration_step[1]):
textual_encoder_idx += 1 textual_encoder_idx += 1
textual_x = self.textual_encoder_layers[textual_encoder_idx]( textual_x = self.textual_encoder_layers[textual_encoder_idx](
textual_x, encoder_padding_mask=textual_encoder_padding_mask, textual_x,
xs=acoustic_x, s2_encoder_padding_mask=acoustic_encoder_padding_mask encoder_padding_mask=textual_encoder_padding_mask,
s2=acoustic_x if textual_encoder_idx >= self.collaboration_start[1] else None,
s2_encoder_padding_mask=acoustic_encoder_padding_mask if textual_encoder_idx >=
self.collaboration_start[1] else None,
) )
elif self.collaboration_direction == "both": elif self.collaboration_direction == "both":
for _ in range(self.collaboration_step[0]): for _ in range(self.collaboration_step[0]):
senior_acoustic_encoder_idx += 1 senior_acoustic_encoder_idx += 1
acoustic_x = self.senior_acoustic_encoder_layers[senior_acoustic_encoder_idx]( acoustic_x = self.senior_acoustic_encoder_layers[senior_acoustic_encoder_idx](
acoustic_x, encoder_padding_mask=acoustic_encoder_padding_mask, acoustic_x,
s2=textual_x, s2_encoder_padding_mask=textual_encoder_padding_mask encoder_padding_mask=acoustic_encoder_padding_mask,
s2=textual_x if senior_acoustic_encoder_idx >= self.collaboration_start[0] else None,
s2_encoder_padding_mask=textual_encoder_padding_mask if senior_acoustic_encoder_idx >=
self.collaboration_start[0] else None
) )
for _ in range(self.collaboration_step[1]): for _ in range(self.collaboration_step[1]):
textual_encoder_idx += 1 textual_encoder_idx += 1
textual_x = self.textual_encoder_layers[textual_encoder_idx]( textual_x = self.textual_encoder_layers[textual_encoder_idx](
textual_x, encoder_padding_mask=textual_encoder_padding_mask, textual_x,
s2=acoustic_x, s2_encoder_padding_mask=acoustic_encoder_padding_mask encoder_padding_mask=textual_encoder_padding_mask,
s2=acoustic_x if textual_encoder_idx >= self.collaboration_start[1] else None,
s2_encoder_padding_mask=acoustic_encoder_padding_mask if textual_encoder_idx >=
self.collaboration_start[1] else None
) )
elif self.collaboration_direction == "none": elif self.collaboration_direction == "none":
for _ in range(self.collaboration_step[0]): for _ in range(self.collaboration_step[0]):
senior_acoustic_encoder_idx += 1 senior_acoustic_encoder_idx += 1
acoustic_x = self.senior_acoustic_encoder_layers[senior_acoustic_encoder_idx]( acoustic_x = self.senior_acoustic_encoder_layers[senior_acoustic_encoder_idx](
acoustic_x, encoder_padding_mask=acoustic_encoder_padding_mask, acoustic_x,
encoder_padding_mask=acoustic_encoder_padding_mask,
) )
for _ in range(self.collaboration_step[1]): for _ in range(self.collaboration_step[1]):
textual_encoder_idx += 1 textual_encoder_idx += 1
textual_x = self.textual_encoder_layers[textual_encoder_idx]( textual_x = self.textual_encoder_layers[textual_encoder_idx](
textual_x, encoder_padding_mask=textual_encoder_padding_mask, textual_x,
encoder_padding_mask=textual_encoder_padding_mask,
) )
if senior_acoustic_encoder_idx == self.senior_acoustic_encoder_layer_num - 1 and \ if senior_acoustic_encoder_idx == self.senior_acoustic_encoder_layer_num - 1 and \
textual_encoder_idx == self.textual_encoder_layer_num - 1: textual_encoder_idx == self.textual_encoder_layer_num - 1:
...@@ -683,7 +720,7 @@ def base_architecture(args): ...@@ -683,7 +720,7 @@ def base_architecture(args):
args.pds_fusion = getattr(args, "pds_fusion", False) args.pds_fusion = getattr(args, "pds_fusion", False)
args.pds_fusion_method = getattr(args, "pds_fusion_method", "all_conv") args.pds_fusion_method = getattr(args, "pds_fusion_method", "all_conv")
# dual # multibranch
args.junior_acoustic_encoder = getattr(args, "junior_acoustic_encoder", "transformer") args.junior_acoustic_encoder = getattr(args, "junior_acoustic_encoder", "transformer")
args.senior_acoustic_encoder = getattr(args, "senior_acoustic_encoder", "transformer") args.senior_acoustic_encoder = getattr(args, "senior_acoustic_encoder", "transformer")
args.textual_encoder = getattr(args, "textual_encoder", "transformer") args.textual_encoder = getattr(args, "textual_encoder", "transformer")
...@@ -695,6 +732,7 @@ def base_architecture(args): ...@@ -695,6 +732,7 @@ def base_architecture(args):
args.collaboration_direction = getattr(args, "collaboration_direction", "none") args.collaboration_direction = getattr(args, "collaboration_direction", "none")
args.collaboration_step = getattr(args, "collaboration_step", "1:1") args.collaboration_step = getattr(args, "collaboration_step", "1:1")
args.collaboration_start = getattr(args, "collaboration_start", "0:0")
args.encoder_collaboration_mode = getattr(args, "encoder_collaboration_mode", "serial") args.encoder_collaboration_mode = getattr(args, "encoder_collaboration_mode", "serial")
args.decoder_collaboration_mode = getattr(args, "decoder_collaboration_mode", "serial") args.decoder_collaboration_mode = getattr(args, "decoder_collaboration_mode", "serial")
......
...@@ -4,6 +4,7 @@ ...@@ -4,6 +4,7 @@
# LICENSE file in the root directory of this source tree. # LICENSE file in the root directory of this source tree.
from typing import Optional from typing import Optional
from numpy.random import uniform
import torch import torch
import torch.nn as nn import torch.nn as nn
...@@ -151,6 +152,28 @@ class S2TTransformerS2EncoderLayer(nn.Module): ...@@ -151,6 +152,28 @@ class S2TTransformerS2EncoderLayer(nn.Module):
self_attention=False, self_attention=False,
) )
self.encoder_collaboration_mode = args.encoder_collaboration_mode
self.league_s1_ratio = args.encoder_league_s1_ratio
self.league_s2_ratio = args.encoder_league_s2_ratio
self.league_drop_net = args.encoder_league_drop_net
self.league_drop_net_prob = args.encoder_league_drop_net_prob
self.league_drop_net_mix = args.encoder_league_drop_net_mix
def get_ratio(self):
    """Return the two-element weight list used to blend the two branches.

    The first entry weights the layer's own attention output and the
    second weights the output of the attention over the second stream.

    Behaviour:
      * ``league_drop_net`` disabled: the fixed pair
        ``[league_s1_ratio, league_s2_ratio]`` is returned.
      * ``league_drop_net`` enabled, not training: always ``[0.5, 0.5]``.
      * ``league_drop_net`` enabled, training, ``league_drop_net_mix``:
        a random convex pair ``[r, 1 - r]`` with ``r`` drawn from U[0, 1).
      * ``league_drop_net`` enabled, training, no mix: ``[1, 0]`` when
        ``r < league_drop_net_prob``, ``[0, 1]`` when
        ``r > 1 - league_drop_net_prob``, otherwise ``[0.5, 0.5]``.

    Returns:
        list: two weights ``[s1_weight, s2_weight]``.
    """
    if not self.league_drop_net:
        return [self.league_s1_ratio, self.league_s2_ratio]

    if not self.training:
        # Outside training the result never depends on a random draw, so
        # do not sample at all: this leaves numpy's global RNG state
        # untouched during evaluation.
        return [0.5, 0.5]

    frand = float(uniform(0, 1))
    if self.league_drop_net_mix:
        return [frand, 1 - frand]
    # The "drop s2" test comes first, so it wins if the two ranges overlap
    # (i.e. when league_drop_net_prob > 0.5).
    if frand < self.league_drop_net_prob:
        return [1, 0]
    if frand > 1 - self.league_drop_net_prob:
        return [0, 1]
    return [0.5, 0.5]
def build_self_attention(self, args, embed_dim): def build_self_attention(self, args, embed_dim):
attention_heads = args.encoder_attention_heads attention_heads = args.encoder_attention_heads
dropout = args.dropout dropout = args.dropout
...@@ -274,6 +297,7 @@ class S2TTransformerS2EncoderLayer(nn.Module): ...@@ -274,6 +297,7 @@ class S2TTransformerS2EncoderLayer(nn.Module):
residual = x residual = x
if self.normalize_before: if self.normalize_before:
x = self.self_attn_layer_norm(x) x = self.self_attn_layer_norm(x)
attn_x = x
if self.attn_type in ["rel_pos", "rel_pos_legacy", "rel_selfattn"]: if self.attn_type in ["rel_pos", "rel_pos_legacy", "rel_selfattn"]:
assert pos_emb is not None, "Positions is necessary for RPE!" assert pos_emb is not None, "Positions is necessary for RPE!"
x, _ = self.self_attn( x, _ = self.self_attn(
...@@ -294,25 +318,36 @@ class S2TTransformerS2EncoderLayer(nn.Module): ...@@ -294,25 +318,36 @@ class S2TTransformerS2EncoderLayer(nn.Module):
need_weights=False, need_weights=False,
attn_mask=attn_mask, attn_mask=attn_mask,
) )
x = self.dropout_module(x) x = self.dropout_module(x)
x = self.residual_connection(x, residual) if s2 is None or self.encoder_collaboration_mode != "parallel":
if not self.normalize_before: x = self.residual_connection(x, residual)
x = self.self_attn_layer_norm(x) if not self.normalize_before:
x = self.self_attn_layer_norm(x)
if s2 is not None: if s2 is not None:
residual = x
x = self.s2_attn_norm(x)
s2 = self.s2_norm(s2) s2 = self.s2_norm(s2)
x, _ = self.self_attn( if self.encoder_collaboration_mode == "serial":
query=x, residual = x
key=s2, x = self.s2_attn_norm(x)
value=s2, x, _ = self.s2_attn(
key_padding_mask=s2_encoder_padding_mask, query=x,
need_weights=False, key=s2,
) value=s2,
x = self.dropout_module(x) key_padding_mask=s2_encoder_padding_mask,
x = self.residual_connection(x, residual) need_weights=False,
)
x = self.dropout_module(x)
x = self.residual_connection(x, residual)
elif self.encoder_collaboration_mode == "parallel":
x2, _ = self.s2_attn(
query=attn_x,
key=s2,
value=s2,
key_padding_mask=s2_encoder_padding_mask)
x2 = self.dropout_module(x2)
ratio = self.get_ratio()
x = x * ratio[0] + x2 * ratio[1]
x = self.residual_connection(x, residual)
# convolution module # convolution module
if self.conv_module is not None: if self.conv_module is not None:
......
...@@ -396,7 +396,7 @@ class TransformerS2DecoderLayer(nn.Module): ...@@ -396,7 +396,7 @@ class TransformerS2DecoderLayer(nn.Module):
def get_ratio(self): def get_ratio(self):
if self.league_drop_net: if self.league_drop_net:
frand = float(uniform(0, 1)) frand = float(uniform(0, 1))
if self.drop_net_mix and self.training: if self.league_drop_net_mix and self.training:
return [frand, 1 - frand] return [frand, 1 - frand]
if frand < self.league_drop_net_prob and self.training: if frand < self.league_drop_net_prob and self.training:
return [1, 0] return [1, 0]
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论