Commit a2fd43e7 by xuchen

Fix bugs in the multi-branch architecture

parent 47e0f6e0
@@ -4,4 +4,5 @@ inter-mixup-prob: 1.0
inter-mixup-ratio: 1.0
inter-mixup-beta: 0.5
inter-mixup-keep-org: True
ctc-mixupr-consistent-weight: 1
ctc-mixup-consistent-weight: 1
mixup-consistent-weight: 1
\ No newline at end of file
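The rename above matters because fairseq derives attribute names from the registered CLI flags, so a key that matches no registered flag presumably never sets the intended weight. A minimal sketch of that mapping, using plain argparse as a stand-in for fairseq's config loading (not the project's actual loader):

```python
# Sketch only: a flag "--ctc-mixup-consistent-weight" maps to
# args.ctc_mixup_consistent_weight. The old misspelled key
# ("ctc-mixupr-...") would not match any registered argument, so the
# consistency-loss weight would silently stay at its default.
import argparse

parser = argparse.ArgumentParser()
parser.add_argument("--ctc-mixup-consistent-weight", type=float, default=0.0)

args = parser.parse_args(["--ctc-mixup-consistent-weight", "1"])
print(args.ctc_mixup_consistent_weight)  # 1.0
```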
@@ -19,6 +19,7 @@ from fairseq.models import (
register_model_architecture,
)
from fairseq.modules import (
FairseqDropout,
LayerNorm,
PositionalEmbedding,
LegacyRelPositionalEncoding,
@@ -103,6 +104,12 @@ class S2TMultiBranchModel(FairseqEncoderDecoderModel):
help="direction of collaboration",
)
parser.add_argument(
"--collaboration-start",
default="0:0",
type=str,
help="start collaboration in two encoders",
)
parser.add_argument(
"--collaboration-step",
default="1:1",
type=str,
@@ -329,6 +336,9 @@ class S2TMultiBranchEncoder(FairseqEncoder):
def __init__(self, args, task=None, embed_tokens=None):
super().__init__(None)
self.padding_idx = 1
self.dropout_module = FairseqDropout(
p=args.dropout, module_name=self.__class__.__name__
)
setattr(args, "encoder_layers", args.junior_acoustic_encoder_layers)
junior_encoder_type = args.junior_acoustic_encoder
@@ -392,11 +402,22 @@ class S2TMultiBranchEncoder(FairseqEncoder):
# collaboration
collaboration_step = args.collaboration_step
self.collaboration_direction = args.collaboration_direction
collaboration_start = args.collaboration_start
if len(collaboration_start.split(":")) == 2:
self.collaboration_start = [int(s) for s in collaboration_start.split(":")]
elif len(collaboration_start.split(":")) == 1:
self.collaboration_start = [int(collaboration_start), int(collaboration_start)]
else:
self.collaboration_start = [0, 0]
if len(collaboration_step.split(":")) == 2:
self.collaboration_step = [int(s) for s in collaboration_step.split(":")]
elif len(collaboration_step.split(":")) == 1:
self.collaboration_step = [int(collaboration_step), int(collaboration_step)]
else:
self.collaboration_step = [1, 1]
self.acoustic_norm = LayerNorm(args.encoder_embed_dim)
self.textual_norm = LayerNorm(args.encoder_embed_dim)
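The two parsing branches above are identical apart from their fallback defaults; a small helper (hypothetical, not part of the commit) captures the 'a:b' / single-value / fallback behavior and makes it testable:

```python
# Hypothetical refactor of the parsing above: "a:b" gives [a, b],
# a single value "a" gives [a, a], anything else falls back to the default.
def parse_pair(value: str, default: int) -> list:
    parts = value.split(":")
    if len(parts) == 2:
        return [int(p) for p in parts]
    if len(parts) == 1:
        return [int(value)] * 2
    return [default, default]

assert parse_pair("3:1", 0) == [3, 1]
assert parse_pair("2", 1) == [2, 2]
assert parse_pair("1:2:3", 0) == [0, 0]  # malformed input -> default
```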
@@ -414,7 +435,7 @@ class S2TMultiBranchEncoder(FairseqEncoder):
adapter_x, adapter_encoder_padding_mask = self.adapter(x, acoustic_encoder_padding_mask)
textual_x = adapter_x + self.textual_embed_positions(adapter_encoder_padding_mask).transpose(0, 1)
# textual_x = self.dropout_module(textual_x)
textual_x = self.dropout_module(textual_x)
textual_encoder_padding_mask = adapter_encoder_padding_mask
senior_acoustic_encoder_idx = -1
@@ -425,51 +446,67 @@ class S2TMultiBranchEncoder(FairseqEncoder):
for _ in range(self.collaboration_step[1]):
textual_encoder_idx += 1
textual_x = self.textual_encoder_layers[textual_encoder_idx](
textual_x, encoder_padding_mask=textual_encoder_padding_mask,
textual_x,
encoder_padding_mask=textual_encoder_padding_mask,
)
for _ in range(self.collaboration_step[0]):
senior_acoustic_encoder_idx += 1
acoustic_x = self.senior_acoustic_encoder_layers[senior_acoustic_encoder_idx](
acoustic_x, encoder_padding_mask=acoustic_encoder_padding_mask,
s2=textual_x, s2_encoder_padding_mask=textual_encoder_padding_mask
acoustic_x,
encoder_padding_mask=acoustic_encoder_padding_mask,
s2=textual_x if senior_acoustic_encoder_idx >= self.collaboration_start[0] else None,
s2_encoder_padding_mask=textual_encoder_padding_mask if senior_acoustic_encoder_idx >= self.collaboration_start[0] else None,
)
elif self.collaboration_direction == "textual":
for _ in range(self.collaboration_step[0]):
senior_acoustic_encoder_idx += 1
acoustic_x = self.senior_acoustic_encoder_layers[senior_acoustic_encoder_idx](
acoustic_x, encoder_padding_mask=acoustic_encoder_padding_mask,
acoustic_x,
encoder_padding_mask=acoustic_encoder_padding_mask,
)
for _ in range(self.collaboration_step[1]):
textual_encoder_idx += 1
textual_x = self.textual_encoder_layers[textual_encoder_idx](
textual_x, encoder_padding_mask=textual_encoder_padding_mask,
xs=acoustic_x, s2_encoder_padding_mask=acoustic_encoder_padding_mask
textual_x,
encoder_padding_mask=textual_encoder_padding_mask,
s2=acoustic_x if textual_encoder_idx >= self.collaboration_start[1] else None,
s2_encoder_padding_mask=acoustic_encoder_padding_mask if textual_encoder_idx >= self.collaboration_start[1] else None,
)
elif self.collaboration_direction == "both":
for _ in range(self.collaboration_step[0]):
senior_acoustic_encoder_idx += 1
acoustic_x = self.senior_acoustic_encoder_layers[senior_acoustic_encoder_idx](
acoustic_x, encoder_padding_mask=acoustic_encoder_padding_mask,
s2=textual_x, s2_encoder_padding_mask=textual_encoder_padding_mask
acoustic_x,
encoder_padding_mask=acoustic_encoder_padding_mask,
s2=textual_x if senior_acoustic_encoder_idx >= self.collaboration_start[0] else None,
s2_encoder_padding_mask=textual_encoder_padding_mask if senior_acoustic_encoder_idx >= self.collaboration_start[0] else None,
)
for _ in range(self.collaboration_step[1]):
textual_encoder_idx += 1
textual_x = self.textual_encoder_layers[textual_encoder_idx](
textual_x, encoder_padding_mask=textual_encoder_padding_mask,
s2=acoustic_x, s2_encoder_padding_mask=acoustic_encoder_padding_mask
textual_x,
encoder_padding_mask=textual_encoder_padding_mask,
s2=acoustic_x if textual_encoder_idx >= self.collaboration_start[1] else None,
s2_encoder_padding_mask=acoustic_encoder_padding_mask if textual_encoder_idx >= self.collaboration_start[1] else None,
)
elif self.collaboration_direction == "none":
for _ in range(self.collaboration_step[0]):
senior_acoustic_encoder_idx += 1
acoustic_x = self.senior_acoustic_encoder_layers[senior_acoustic_encoder_idx](
acoustic_x, encoder_padding_mask=acoustic_encoder_padding_mask,
acoustic_x,
encoder_padding_mask=acoustic_encoder_padding_mask,
)
for _ in range(self.collaboration_step[1]):
textual_encoder_idx += 1
textual_x = self.textual_encoder_layers[textual_encoder_idx](
textual_x, encoder_padding_mask=textual_encoder_padding_mask,
textual_x,
encoder_padding_mask=textual_encoder_padding_mask,
)
if senior_acoustic_encoder_idx == self.senior_acoustic_encoder_layer_num - 1 and \
textual_encoder_idx == self.textual_encoder_layer_num - 1:
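The collaboration-start gating above only feeds the partner branch's states (s2) into a layer once that layer's index reaches the configured threshold. A standalone trace of which layers attend to the other branch, with purely illustrative settings (not project code):

```python
# Trace which layers receive the partner stream, mirroring the loop above.
# Illustrative settings: 4 rounds, collaboration_step "1:1", start "2:0".
collaboration_step = [1, 1]
collaboration_start = [2, 0]

acoustic_idx = textual_idx = -1
for _ in range(4):
    for _ in range(collaboration_step[0]):
        acoustic_idx += 1
        on = acoustic_idx >= collaboration_start[0]
        print(f"acoustic layer {acoustic_idx}: s2 {'on' if on else 'off'}")
    for _ in range(collaboration_step[1]):
        textual_idx += 1
        on = textual_idx >= collaboration_start[1]
        print(f"textual layer {textual_idx}: s2 {'on' if on else 'off'}")
```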
@@ -683,7 +720,7 @@ def base_architecture(args):
args.pds_fusion = getattr(args, "pds_fusion", False)
args.pds_fusion_method = getattr(args, "pds_fusion_method", "all_conv")
# dual
# multibranch
args.junior_acoustic_encoder = getattr(args, "junior_acoustic_encoder", "transformer")
args.senior_acoustic_encoder = getattr(args, "senior_acoustic_encoder", "transformer")
args.textual_encoder = getattr(args, "textual_encoder", "transformer")
@@ -695,6 +732,7 @@ def base_architecture(args):
args.collaboration_direction = getattr(args, "collaboration_direction", "none")
args.collaboration_step = getattr(args, "collaboration_step", "1:1")
args.collaboration_start = getattr(args, "collaboration_start", "0:0")
args.encoder_collaboration_mode = getattr(args, "encoder_collaboration_mode", "serial")
args.decoder_collaboration_mode = getattr(args, "decoder_collaboration_mode", "serial")
@@ -4,6 +4,7 @@
# LICENSE file in the root directory of this source tree.
from typing import Optional
from numpy.random import uniform
import torch
import torch.nn as nn
@@ -151,6 +152,28 @@ class S2TTransformerS2EncoderLayer(nn.Module):
self_attention=False,
)
self.encoder_collaboration_mode = args.encoder_collaboration_mode
self.league_s1_ratio = args.encoder_league_s1_ratio
self.league_s2_ratio = args.encoder_league_s2_ratio
self.league_drop_net = args.encoder_league_drop_net
self.league_drop_net_prob = args.encoder_league_drop_net_prob
self.league_drop_net_mix = args.encoder_league_drop_net_mix
def get_ratio(self):
if self.league_drop_net:
frand = float(uniform(0, 1))
if self.league_drop_net_mix and self.training:
return [frand, 1 - frand]
if frand < self.league_drop_net_prob and self.training:
return [1, 0]
elif frand > 1 - self.league_drop_net_prob and self.training:
return [0, 1]
else:
return [0.5, 0.5]
else:
return [self.league_s1_ratio, self.league_s2_ratio]
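get_ratio returns per-branch mixing weights: during training, drop-net either samples a random convex mix or stochastically zeroes one branch; with drop-net off, the configured static ratios apply. A quick standalone check of those regimes, with hard-coded stand-ins for the layer's attributes:

```python
# Standalone sketch of the drop-net regimes above (illustrative values).
from numpy.random import uniform

def get_ratio(drop_net, drop_net_prob, drop_net_mix, s1, s2, training):
    if drop_net:
        frand = float(uniform(0, 1))
        if drop_net_mix and training:
            return [frand, 1 - frand]   # random convex combination
        if frand < drop_net_prob and training:
            return [1, 0]               # keep branch 1 only
        elif frand > 1 - drop_net_prob and training:
            return [0, 1]               # keep branch 2 only
        else:
            return [0.5, 0.5]           # even mix; also the eval-time value
    return [s1, s2]                     # static ratios when drop-net is off

print(get_ratio(True, 0.2, False, 0.5, 0.5, training=True))
print(get_ratio(False, 0.2, False, 0.7, 0.3, training=False))  # [0.7, 0.3]
```

Note that with drop-net enabled, the eval-time ratio is always [0.5, 0.5] regardless of the configured league ratios.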
def build_self_attention(self, args, embed_dim):
attention_heads = args.encoder_attention_heads
dropout = args.dropout
@@ -274,6 +297,7 @@ class S2TTransformerS2EncoderLayer(nn.Module):
residual = x
if self.normalize_before:
x = self.self_attn_layer_norm(x)
attn_x = x
if self.attn_type in ["rel_pos", "rel_pos_legacy", "rel_selfattn"]:
assert pos_emb is not None, "Positions is necessary for RPE!"
x, _ = self.self_attn(
@@ -295,16 +319,17 @@ class S2TTransformerS2EncoderLayer(nn.Module):
attn_mask=attn_mask,
)
x = self.dropout_module(x)
if s2 is None or self.encoder_collaboration_mode != "parallel":
x = self.residual_connection(x, residual)
if not self.normalize_before:
x = self.self_attn_layer_norm(x)
if s2 is not None:
s2 = self.s2_norm(s2)
if self.encoder_collaboration_mode == "serial":
residual = x
x = self.s2_attn_norm(x)
s2 = self.s2_norm(s2)
x, _ = self.self_attn(
x, _ = self.s2_attn(
query=x,
key=s2,
value=s2,
@@ -313,6 +338,16 @@ class S2TTransformerS2EncoderLayer(nn.Module):
)
x = self.dropout_module(x)
x = self.residual_connection(x, residual)
elif self.encoder_collaboration_mode == "parallel":
x2, _ = self.s2_attn(
query=attn_x,
key=s2,
value=s2,
key_padding_mask=s2_encoder_padding_mask)
x2 = self.dropout_module(x2)
ratio = self.get_ratio()
x = x * ratio[0] + x2 * ratio[1]
x = self.residual_connection(x, residual)
# convolution module
if self.conv_module is not None:
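In parallel mode, both attentions read the same pre-attention input (the attn_x saved above), their outputs are mixed by get_ratio, and a single residual connection is applied afterwards; serial mode instead chains the s2 cross-attention after the self-attention block. A shape-level sketch of the parallel combination, with assumed dimensions and stand-in tensors:

```python
# Parallel league combination, shape-level sketch (seq_len, batch, dim).
import torch

T, B, D = 5, 2, 8
residual = torch.randn(T, B, D)   # input saved before self-attention
x_self = torch.randn(T, B, D)     # stand-in for self-attention output
x_cross = torch.randn(T, B, D)    # stand-in for s2 cross-attention output

ratio = [0.5, 0.5]                # from get_ratio()
x = x_self * ratio[0] + x_cross * ratio[1]
x = x + residual                  # single residual add after mixing
assert x.shape == (T, B, D)
```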
@@ -396,7 +396,7 @@ class TransformerS2DecoderLayer(nn.Module):
def get_ratio(self):
if self.league_drop_net:
frand = float(uniform(0, 1))
if self.drop_net_mix and self.training:
if self.league_drop_net_mix and self.training:
return [frand, 1 - frand]
if frand < self.league_drop_net_prob and self.training:
return [1, 0]