Commit 47e0f6e0 by xuchen

Add the multibranch S2T architecture.

Also fix several bugs found in the dual architecture.
parent 793f553a
......@@ -10,3 +10,4 @@ from .pdss2t_transformer import * # noqa
from .s2t_sate import * # noqa
from .s2t_dual import * # noqa
from .s2t_ctc import *
from .s2t_multibranch import *
from fairseq.models import (
FairseqEncoder,
FairseqEncoderModel,
register_model,
register_model_architecture,
)
import logging
import math
from typing import Dict, List, Optional, Tuple
import torch
......@@ -35,17 +28,6 @@ from fairseq.models.transformer_s2 import (
TransformerS2Encoder,
TransformerS2Decoder,
)
from fairseq.modules import (
FairseqDropout,
LayerNorm,
PositionalEmbedding,
LegacyRelPositionalEncoding,
RelPositionalEncoding,
S2TTransformerEncoderLayer,
DynamicLinearCombination,
TransformerS2DecoderLayer,
TransformerS2EncoderLayer,
)
logger = logging.getLogger(__name__)
......
......@@ -343,8 +343,8 @@ class TransformerS2Decoder(TransformerDecoder):
and len(encoder_out["encoder_padding_mask"]) > 0
)
else None,
encoder_out_s2=encoder_out["encoder_out_s2"][0],
encoder_padding_mask_s2=encoder_out["encoder_padding_mask_s2"][0],
encoder_out_s2=encoder_out["s2_encoder_out"][0],
encoder_padding_mask_s2=encoder_out["s2_encoder_padding_mask"][0],
incremental_state=incremental_state,
self_attn_mask=self_attn_mask,
self_attn_padding_mask=self_attn_padding_mask,
......
......@@ -61,6 +61,7 @@ from .espnet_multihead_attention import (
)
from .convolution import ConvolutionModule
from .s2t_transformer_layer import S2TTransformerEncoderLayer
from .s2t_transformer_s2_layer import S2TTransformerS2EncoderLayer
from .pds_layer import PDSTransformerEncoderLayer
__all__ = [
......@@ -70,6 +71,7 @@ __all__ = [
"BeamableMM",
"CharacterTokenEmbedder",
"S2TTransformerEncoderLayer",
"S2TTransformerS2EncoderLayer",
"ConvolutionModule",
"ConvTBC",
"cross_entropy",
......
......@@ -79,6 +79,8 @@ class TransformerS2EncoderLayer(nn.Module):
if self.use_se:
self.se_attn = SEAttention(self.embed_dim, 16)
self.s2_norm = LayerNorm(self.embed_dim)
self.s2_attn_norm = LayerNorm(self.embed_dim)
self.s2_attn = MultiheadAttention(
self.embed_dim,
args.encoder_attention_heads,
......@@ -87,26 +89,28 @@ class TransformerS2EncoderLayer(nn.Module):
dropout=args.attention_dropout,
self_attention=False,
)
self.s1_ratio = args.encoder_s1_ratio
self.s2_ratio = args.encoder_s2_ratio
self.drop_net = args.encoder_drop_net
self.drop_net_prob = args.encoder_drop_net_prob
self.drop_net_mix = args.encoder_drop_net_mix
self.encoder_collaboration_mode = args.encoder_collaboration_mode
self.league_s1_ratio = args.encoder_league_s1_ratio
self.league_s2_ratio = args.encoder_league_s2_ratio
self.league_drop_net = args.encoder_league_drop_net
self.league_drop_net_prob = args.encoder_league_drop_net_prob
self.league_drop_net_mix = args.encoder_league_drop_net_mix
def get_ratio(self):
    """Return mixing weights ``[s1, s2]`` for combining the two encoder branches.

    With league drop-net enabled during training, one branch is randomly
    dropped (or the two are randomly mixed when ``league_drop_net_mix`` is
    set); at inference time both branches are averaged. Without drop-net,
    the configured static ratios are used.
    """
    if self.league_drop_net:
        frand = float(uniform(0, 1))
        # Fix: __init__ renamed this flag to league_drop_net_mix; the old
        # self.drop_net_mix attribute no longer exists and would raise.
        if self.league_drop_net_mix and self.training:
            return [frand, 1 - frand]
        if frand < self.league_drop_net_prob and self.training:
            return [1, 0]  # keep branch 1 only
        elif frand > 1 - self.league_drop_net_prob and self.training:
            return [0, 1]  # keep branch 2 only
        else:
            return [0.5, 0.5]  # inference / middle of range: average
    else:
        return [self.league_s1_ratio, self.league_s2_ratio]
def build_fc1(self, input_dim, output_dim, q_noise, qn_block_size):
return quant_noise(
......@@ -186,8 +190,8 @@ class TransformerS2EncoderLayer(nn.Module):
def forward(self, x,
encoder_padding_mask: Optional[Tensor],
x2 = None,
x2_encoder_padding_mask = None,
s2 = None,
s2_encoder_padding_mask = None,
attn_mask: Optional[Tensor] = None,
pos_emb: Optional[Tensor] = None):
"""
......@@ -219,6 +223,7 @@ class TransformerS2EncoderLayer(nn.Module):
residual = x
if self.normalize_before:
x = self.self_attn_layer_norm(x)
attn_x = x
if self.attn_type == "rel_selfattn":
assert pos_emb is not None, "Positions is necessary for RPE!"
x, _ = self.self_attn(
......@@ -240,20 +245,34 @@ class TransformerS2EncoderLayer(nn.Module):
attn_mask=attn_mask,
)
x = self.dropout_module(x)
if s2 is None or self.encoder_collaboration_mode != "parallel":
x = self.residual_connection(x, residual)
if not self.normalize_before:
x = self.self_attn_layer_norm(x)
if x2 is not None:
x2, _ = self.s2_attn(
if s2 is not None:
s2 = self.s2_norm(s2)
if self.encoder_collaboration_mode == "serial":
residual = x
x = self.s2_attn_norm(x)
x, _ = self.s2_attn(
query=x,
key=x2,
value=x2,
key_padding_mask=x2_encoder_padding_mask)
key=s2,
value=s2,
key_padding_mask=s2_encoder_padding_mask)
x = self.dropout_module(x)
x = self.residual_connection(x, residual)
elif self.encoder_collaboration_mode == "parallel":
x2, _ = self.s2_attn(
query=attn_x,
key=s2,
value=s2,
key_padding_mask=s2_encoder_padding_mask)
x2 = self.dropout_module(x2)
ratio = self.get_ratio()
x = x * ratio[0] + x2 * ratio[1]
x = self.residual_connection(x, residual)
if not self.normalize_before:
x = self.self_attn_layer_norm(x)
residual = x
if self.normalize_before:
......@@ -341,11 +360,12 @@ class TransformerS2DecoderLayer(nn.Module):
self.s2_attn = MultiheadAttention(
self.embed_dim,
args.decoder_attention_heads,
kdim=getattr(args, "encoder_x2_dim", self.embed_dim),
vdim=getattr(args, "encoder_x2_dim", self.embed_dim),
kdim=getattr(args, "encoder_s2_dim", self.embed_dim),
vdim=getattr(args, "encoder_s2_dim", self.embed_dim),
dropout=args.attention_dropout,
encoder_decoder_attention=True,
)
self.s2_attn_layer_norm = LayerNorm(self.embed_dim)
self.fc1 = self.build_fc1(
self.embed_dim,
......@@ -365,26 +385,27 @@ class TransformerS2DecoderLayer(nn.Module):
self.onnx_trace = False
self.s1_ratio = args.encoder_s1_ratio
self.s2_ratio = args.encoder_s2_ratio
self.decoder_collaboration_mode = args.decoder_collaboration_mode
self.league_s1_ratio = args.decoder_league_s1_ratio
self.league_s2_ratio = args.decoder_league_s2_ratio
self.drop_net = args.encoder_drop_net
self.drop_net_prob = args.encoder_drop_net_prob
self.drop_net_mix = args.encoder_drop_net_mix
self.league_drop_net = args.decoder_league_drop_net
self.league_drop_net_prob = args.decoder_league_drop_net_prob
self.league_drop_net_mix = args.decoder_league_drop_net_mix
def get_ratio(self):
    """Return mixing weights ``[s1, s2]`` for combining the two decoder cross-attentions.

    With league drop-net enabled during training, one branch is randomly
    dropped (or the two are randomly mixed when ``league_drop_net_mix`` is
    set); at inference time both branches are averaged. Without drop-net,
    the configured static ratios are used.
    """
    if self.league_drop_net:
        frand = float(uniform(0, 1))
        # Fix: __init__ renamed this flag to league_drop_net_mix; the old
        # self.drop_net_mix attribute no longer exists and would raise.
        if self.league_drop_net_mix and self.training:
            return [frand, 1 - frand]
        if frand < self.league_drop_net_prob and self.training:
            return [1, 0]  # keep branch 1 only
        elif frand > 1 - self.league_drop_net_prob and self.training:
            return [0, 1]  # keep branch 2 only
        else:
            return [0.5, 0.5]  # inference / middle of range: average
    else:
        return [self.league_s1_ratio, self.league_s2_ratio]
def build_fc1(self, input_dim, output_dim, q_noise, qn_block_size):
    """Build the first feed-forward projection, wrapped with quantization noise."""
    linear = nn.Linear(input_dim, output_dim)
    return quant_noise(linear, q_noise, qn_block_size)
......@@ -551,6 +572,8 @@ class TransformerS2DecoderLayer(nn.Module):
residual = x
if self.normalize_before:
x = self.encoder_attn_layer_norm(x)
cross_attn_x = x
if prev_attn_state is not None:
prev_key, prev_value = prev_attn_state[:2]
saved_state: Dict[str, Optional[Tensor]] = {
......@@ -575,9 +598,16 @@ class TransformerS2DecoderLayer(nn.Module):
need_head_weights=need_head_weights,
)
x = self.dropout_module(x)
if encoder_out_s2 is None or self.decoder_collaboration_mode != "parallel":
x = self.residual_connection(x, residual)
if not self.normalize_before:
x = self.encoder_attn_layer_norm(x)
if encoder_out_s2 is not None:
x2, _ = self.s2_attn(
if self.decoder_collaboration_mode == "serial":
residual = x
x = self.s2_attn_layer_norm(x)
x, _ = self.s2_attn(
query=x,
key=encoder_out_s2,
value=encoder_out_s2,
......@@ -587,10 +617,23 @@ class TransformerS2DecoderLayer(nn.Module):
need_weights=need_attn or (not self.training and self.need_attn),
need_head_weights=need_head_weights,
)
x = self.dropout_module(x)
x = self.residual_connection(x, residual)
elif self.decoder_collaboration_mode == "parallel":
x2, _ = self.s2_attn(
query=cross_attn_x,
key=encoder_out_s2,
value=encoder_out_s2,
key_padding_mask=encoder_padding_mask_s2,
incremental_state=incremental_state,
static_kv=True,
need_weights=need_attn or (not self.training and self.need_attn),
need_head_weights=need_head_weights,
)
x2 = self.dropout_module(x2)
ratios = self.get_ratio()
x = ratios[0] * x + ratios[1] * x2
x = x + x2
x = self.residual_connection(x, residual)
if not self.normalize_before:
x = self.encoder_attn_layer_norm(x)
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论