Commit 47e0f6e0 by xuchen

Add the multibranch S2T architecture.

I also found and fixed some bugs in the dual architecture.
parent 793f553a
@@ -10,3 +10,4 @@ from .pdss2t_transformer import *  # noqa
 from .s2t_sate import *  # noqa
 from .s2t_dual import *  # noqa
 from .s2t_ctc import *
+from .s2t_multibranch import *
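Context for this one-line change: fairseq only sees a new architecture once its module has been imported, because the @register_model / @register_model_architecture decorators run at import time, so adding the import here is what lets --arch resolve the multibranch model. A minimal sketch of the registration pattern the new module presumably follows (the class and architecture names below are illustrative, not taken from the commit):

# Illustrative sketch only: the real classes live in s2t_multibranch.py, which the
# hunk above makes importable so its registration decorators actually execute.
from fairseq.models import (
    FairseqEncoderDecoderModel,
    register_model,
    register_model_architecture,
)


@register_model("s2t_multibranch_example")  # hypothetical registry name
class S2TMultibranchExampleModel(FairseqEncoderDecoderModel):
    """Placeholder showing the registration pattern, not the actual model."""
    pass


@register_model_architecture("s2t_multibranch_example", "s2t_multibranch_example_base")
def s2t_multibranch_example_base(args):
    # Architecture functions normally just fill in default hyper-parameters.
    args.encoder_embed_dim = getattr(args, "encoder_embed_dim", 256)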
-from fairseq.models import (
-    FairseqEncoder,
-    FairseqEncoderModel,
-    register_model,
-    register_model_architecture,
-)
 import logging
-import math
 from typing import Dict, List, Optional, Tuple

 import torch
@@ -35,17 +28,6 @@ from fairseq.models.transformer_s2 import (
     TransformerS2Encoder,
     TransformerS2Decoder,
 )

-from fairseq.modules import (
-    FairseqDropout,
-    LayerNorm,
-    PositionalEmbedding,
-    LegacyRelPositionalEncoding,
-    RelPositionalEncoding,
-    S2TTransformerEncoderLayer,
-    DynamicLinearCombination,
-    TransformerS2DecoderLayer,
-    TransformerS2EncoderLayer,
-)

 logger = logging.getLogger(__name__)
@@ -343,8 +343,8 @@ class TransformerS2Decoder(TransformerDecoder):
                    and len(encoder_out["encoder_padding_mask"]) > 0
                )
                else None,
-                encoder_out_s2=encoder_out["encoder_out_s2"][0],
-                encoder_padding_mask_s2=encoder_out["encoder_padding_mask_s2"][0],
+                encoder_out_s2=encoder_out["s2_encoder_out"][0],
+                encoder_padding_mask_s2=encoder_out["s2_encoder_padding_mask"][0],
                incremental_state=incremental_state,
                self_attn_mask=self_attn_mask,
                self_attn_padding_mask=self_attn_padding_mask,
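These two renames address a key mismatch in the dual architecture: the decoder was indexing the second branch under keys that the encoder presumably publishes as "s2_encoder_out" / "s2_encoder_padding_mask". A sketch of the dictionary shape this call now assumes (tensor sizes and the T x B x C layout are illustrative):

import torch

# Illustrative encoder output of a dual/multibranch encoder: branch 1 under the
# standard fairseq keys, branch 2 under the "s2_" prefix consumed above.
encoder_out = {
    "encoder_out": [torch.zeros(50, 4, 256)],                         # branch-1 states (T1, B, C)
    "encoder_padding_mask": [torch.zeros(4, 50, dtype=torch.bool)],   # (B, T1)
    "s2_encoder_out": [torch.zeros(30, 4, 256)],                      # branch-2 states (T2, B, C)
    "s2_encoder_padding_mask": [torch.zeros(4, 30, dtype=torch.bool)],
}

# The decoder layer then receives the second branch explicitly:
s2_states = encoder_out["s2_encoder_out"][0]
s2_padding = encoder_out["s2_encoder_padding_mask"][0]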
@@ -61,6 +61,7 @@ from .espnet_multihead_attention import (
 )
 from .convolution import ConvolutionModule
 from .s2t_transformer_layer import S2TTransformerEncoderLayer
+from .s2t_transformer_s2_layer import S2TTransformerS2EncoderLayer
 from .pds_layer import PDSTransformerEncoderLayer

 __all__ = [
@@ -70,6 +71,7 @@ __all__ = [
     "BeamableMM",
     "CharacterTokenEmbedder",
     "S2TTransformerEncoderLayer",
+    "S2TTransformerS2EncoderLayer",
     "ConvolutionModule",
     "ConvTBC",
     "cross_entropy",
@@ -79,6 +79,8 @@ class TransformerS2EncoderLayer(nn.Module):
         if self.use_se:
             self.se_attn = SEAttention(self.embed_dim, 16)

+        self.s2_norm = LayerNorm(self.embed_dim)
+        self.s2_attn_norm = LayerNorm(self.embed_dim)
         self.s2_attn = MultiheadAttention(
             self.embed_dim,
             args.encoder_attention_heads,
@@ -87,26 +89,28 @@ class TransformerS2EncoderLayer(nn.Module):
             dropout=args.attention_dropout,
             self_attention=False,
         )
-        self.s1_ratio = args.encoder_s1_ratio
-        self.s2_ratio = args.encoder_s2_ratio
+        self.encoder_collaboration_mode = args.encoder_collaboration_mode
+        self.league_s1_ratio = args.encoder_league_s1_ratio
+        self.league_s2_ratio = args.encoder_league_s2_ratio

-        self.drop_net = args.encoder_drop_net
-        self.drop_net_prob = args.encoder_drop_net_prob
-        self.drop_net_mix = args.encoder_drop_net_mix
+        self.league_drop_net = args.encoder_league_drop_net
+        self.league_drop_net_prob = args.encoder_league_drop_net_prob
+        self.league_drop_net_mix = args.encoder_league_drop_net_mix

     def get_ratio(self):
-        if self.drop_net:
+        if self.league_drop_net:
             frand = float(uniform(0, 1))
-            if self.drop_net_mix and self.training:
+            if self.league_drop_net_mix and self.training:
                 return [frand, 1 - frand]
-            if frand < self.drop_net_prob and self.training:
+            if frand < self.league_drop_net_prob and self.training:
                 return [1, 0]
-            elif frand > 1 - self.drop_net_prob and self.training:
+            elif frand > 1 - self.league_drop_net_prob and self.training:
                 return [0, 1]
             else:
                 return [0.5, 0.5]
         else:
-            return [self.s1_ratio, self.s2_ratio]
+            return [self.league_s1_ratio, self.league_s2_ratio]

     def build_fc1(self, input_dim, output_dim, q_noise, qn_block_size):
         return quant_noise(
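The league/drop-net options are renamed here to the encoder_league_* family, and get_ratio must read the renamed mixing flag as well. The policy itself is small enough to show standalone; the sketch below reproduces it outside the layer, with illustrative defaults standing in for the real command-line options:

from random import uniform

class LeagueRatio:
    """Standalone version of the drop-net ratio policy used by the league fusion.
    Defaults are hypothetical; the real values come from the encoder_league_* options."""

    def __init__(self, s1_ratio=0.5, s2_ratio=0.5,
                 drop_net=False, drop_net_prob=0.0, drop_net_mix=False,
                 training=True):
        self.s1_ratio, self.s2_ratio = s1_ratio, s2_ratio
        self.drop_net = drop_net
        self.drop_net_prob = drop_net_prob
        self.drop_net_mix = drop_net_mix
        self.training = training

    def get_ratio(self):
        if self.drop_net:
            frand = float(uniform(0, 1))
            if self.drop_net_mix and self.training:
                return [frand, 1 - frand]   # random convex mixture of the two branches
            if frand < self.drop_net_prob and self.training:
                return [1, 0]               # keep branch 1 only for this step
            elif frand > 1 - self.drop_net_prob and self.training:
                return [0, 1]               # keep branch 2 only for this step
            else:
                return [0.5, 0.5]           # equal mix (also the eval-time fallback)
        return [self.s1_ratio, self.s2_ratio]

# With drop_net enabled and prob 0.2, roughly 20% of training steps use only branch 1,
# 20% only branch 2, and the rest an equal mix; at eval time it falls back to 0.5/0.5.
print(LeagueRatio(drop_net=True, drop_net_prob=0.2).get_ratio())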
@@ -186,8 +190,8 @@ class TransformerS2EncoderLayer(nn.Module):
     def forward(self, x,
                 encoder_padding_mask: Optional[Tensor],
-                x2=None,
-                x2_encoder_padding_mask=None,
+                s2=None,
+                s2_encoder_padding_mask=None,
                 attn_mask: Optional[Tensor] = None,
                 pos_emb: Optional[Tensor] = None):
         """
@@ -219,6 +223,7 @@ class TransformerS2EncoderLayer(nn.Module):
         residual = x
         if self.normalize_before:
             x = self.self_attn_layer_norm(x)
+        attn_x = x
         if self.attn_type == "rel_selfattn":
             assert pos_emb is not None, "Positions is necessary for RPE!"
             x, _ = self.self_attn(
@@ -240,20 +245,34 @@ class TransformerS2EncoderLayer(nn.Module):
                attn_mask=attn_mask,
            )
         x = self.dropout_module(x)
+        if s2 is None or self.encoder_collaboration_mode != "parallel":
+            x = self.residual_connection(x, residual)
+            if not self.normalize_before:
+                x = self.self_attn_layer_norm(x)

-        if x2 is not None:
-            x2, _ = self.s2_attn(
-                query=x,
-                key=x2,
-                value=x2,
-                key_padding_mask=x2_encoder_padding_mask)
-            x2 = self.dropout_module(x2)
-            ratio = self.get_ratio()
-            x = x * ratio[0] + x2 * ratio[1]
-        x = self.residual_connection(x, residual)
-        if not self.normalize_before:
-            x = self.self_attn_layer_norm(x)
+        if s2 is not None:
+            s2 = self.s2_norm(s2)
+            if self.encoder_collaboration_mode == "serial":
+                residual = x
+                x = self.s2_attn_norm(x)
+                x, _ = self.s2_attn(
+                    query=x,
+                    key=s2,
+                    value=s2,
+                    key_padding_mask=s2_encoder_padding_mask)
+                x = self.dropout_module(x)
+                x = self.residual_connection(x, residual)
+            elif self.encoder_collaboration_mode == "parallel":
+                x2, _ = self.s2_attn(
+                    query=attn_x,
+                    key=s2,
+                    value=s2,
+                    key_padding_mask=s2_encoder_padding_mask)
+                x2 = self.dropout_module(x2)
+                ratio = self.get_ratio()
+                x = x * ratio[0] + x2 * ratio[1]
+                x = self.residual_connection(x, residual)
+                if not self.normalize_before:
+                    x = self.self_attn_layer_norm(x)

         residual = x
         if self.normalize_before:
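The rewritten forward gives the primary branch two ways to consume the second branch's states s2: "serial" adds a dedicated cross-attention sub-layer with its own residual after self-attention, while "parallel" runs the cross-attention from the same pre-self-attention input (attn_x) and fuses the two outputs with the ratios from get_ratio before the shared residual. The decoder layer further down applies the same pattern to its two encoder-decoder attentions. A stripped-down, pre-norm-only sketch of the two modes (dropout, padding masks and relative positions omitted; all names illustrative):

import torch
import torch.nn as nn

class TwoBranchFusionSketch(nn.Module):
    """Minimal illustration of the serial/parallel collaboration modes (not the real layer)."""

    def __init__(self, dim=256, heads=4, mode="parallel", s1_ratio=0.5, s2_ratio=0.5):
        super().__init__()
        self.mode = mode
        self.s1_ratio, self.s2_ratio = s1_ratio, s2_ratio
        self.self_attn = nn.MultiheadAttention(dim, heads)  # expects (T, B, C)
        self.s2_attn = nn.MultiheadAttention(dim, heads)    # cross-attention onto branch 2
        self.self_attn_norm = nn.LayerNorm(dim)
        self.s2_attn_norm = nn.LayerNorm(dim)
        self.s2_norm = nn.LayerNorm(dim)

    def forward(self, x, s2=None):
        residual = x
        attn_x = self.self_attn_norm(x)           # pre-norm input, reused in parallel mode
        y, _ = self.self_attn(attn_x, attn_x, attn_x)

        if s2 is None or self.mode != "parallel":
            x = residual + y                      # ordinary residual for self-attention
        if s2 is not None:
            s2 = self.s2_norm(s2)
            if self.mode == "serial":
                residual = x
                q = self.s2_attn_norm(x)
                y2, _ = self.s2_attn(q, s2, s2)   # extra sub-layer after self-attention
                x = residual + y2
            elif self.mode == "parallel":
                y2, _ = self.s2_attn(attn_x, s2, s2)  # branch-2 attention from the same input
                x = residual + self.s1_ratio * y + self.s2_ratio * y2
        return x

x = torch.randn(50, 2, 256)    # (time, batch, dim) for branch 1
s2 = torch.randn(30, 2, 256)   # branch-2 states
print(TwoBranchFusionSketch(mode="serial")(x, s2).shape)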
@@ -341,11 +360,12 @@ class TransformerS2DecoderLayer(nn.Module):
         self.s2_attn = MultiheadAttention(
             self.embed_dim,
             args.decoder_attention_heads,
-            kdim=getattr(args, "encoder_x2_dim", self.embed_dim),
-            vdim=getattr(args, "encoder_x2_dim", self.embed_dim),
+            kdim=getattr(args, "encoder_s2_dim", self.embed_dim),
+            vdim=getattr(args, "encoder_s2_dim", self.embed_dim),
             dropout=args.attention_dropout,
             encoder_decoder_attention=True,
         )
+        self.s2_attn_layer_norm = LayerNorm(self.embed_dim)

         self.fc1 = self.build_fc1(
             self.embed_dim,
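The kdim/vdim arguments (now read from encoder_s2_dim) let this cross-attention consume a second encoder whose hidden size differs from the decoder's embed_dim; the key and value projections map back to embed_dim internally. A small sketch with torch.nn.MultiheadAttention, which takes the same kdim/vdim keywords (all dimensions illustrative):

import torch
import torch.nn as nn

# Decoder states have dim 256; the second encoder branch emits dim 512.
embed_dim, s2_dim, heads = 256, 512, 4
s2_attn = nn.MultiheadAttention(embed_dim, heads, kdim=s2_dim, vdim=s2_dim)

tgt = torch.randn(10, 2, embed_dim)   # (tgt_len, batch, embed_dim) queries
s2 = torch.randn(30, 2, s2_dim)       # (src_len, batch, s2_dim) keys/values

out, _ = s2_attn(query=tgt, key=s2, value=s2)
print(out.shape)                      # torch.Size([10, 2, 256]): projected back to embed_dim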
@@ -365,26 +385,27 @@ class TransformerS2DecoderLayer(nn.Module):
         self.onnx_trace = False

-        self.s1_ratio = args.encoder_s1_ratio
-        self.s2_ratio = args.encoder_s2_ratio
+        self.decoder_collaboration_mode = args.decoder_collaboration_mode
+        self.league_s1_ratio = args.decoder_league_s1_ratio
+        self.league_s2_ratio = args.decoder_league_s2_ratio

-        self.drop_net = args.encoder_drop_net
-        self.drop_net_prob = args.encoder_drop_net_prob
-        self.drop_net_mix = args.encoder_drop_net_mix
+        self.league_drop_net = args.decoder_league_drop_net
+        self.league_drop_net_prob = args.decoder_league_drop_net_prob
+        self.league_drop_net_mix = args.decoder_league_drop_net_mix

     def get_ratio(self):
-        if self.drop_net:
+        if self.league_drop_net:
             frand = float(uniform(0, 1))
-            if self.drop_net_mix and self.training:
+            if self.league_drop_net_mix and self.training:
                 return [frand, 1 - frand]
-            if frand < self.drop_net_prob and self.training:
+            if frand < self.league_drop_net_prob and self.training:
                 return [1, 0]
-            elif frand > 1 - self.drop_net_prob and self.training:
+            elif frand > 1 - self.league_drop_net_prob and self.training:
                 return [0, 1]
             else:
                 return [0.5, 0.5]
         else:
-            return [self.s1_ratio, self.s2_ratio]
+            return [self.league_s1_ratio, self.league_s2_ratio]

     def build_fc1(self, input_dim, output_dim, q_noise, qn_block_size):
         return quant_noise(nn.Linear(input_dim, output_dim), q_noise, qn_block_size)
@@ -551,6 +572,8 @@ class TransformerS2DecoderLayer(nn.Module):
            residual = x
            if self.normalize_before:
                x = self.encoder_attn_layer_norm(x)
+            cross_attn_x = x
+
            if prev_attn_state is not None:
                prev_key, prev_value = prev_attn_state[:2]
                saved_state: Dict[str, Optional[Tensor]] = {
@@ -575,9 +598,16 @@ class TransformerS2DecoderLayer(nn.Module):
                need_head_weights=need_head_weights,
            )
            x = self.dropout_module(x)
+            if encoder_out_s2 is None or self.decoder_collaboration_mode != "parallel":
+                x = self.residual_connection(x, residual)
+                if not self.normalize_before:
+                    x = self.encoder_attn_layer_norm(x)

            if encoder_out_s2 is not None:
-                x2, _ = self.s2_attn(
-                    query=x,
-                    key=encoder_out_s2,
-                    value=encoder_out_s2,
+                if self.decoder_collaboration_mode == "serial":
+                    residual = x
+                    x = self.s2_attn_layer_norm(x)
+                    x, _ = self.s2_attn(
+                        query=x,
+                        key=encoder_out_s2,
+                        value=encoder_out_s2,
@@ -587,10 +617,23 @@ class TransformerS2DecoderLayer(nn.Module):
-                    need_weights=need_attn or (not self.training and self.need_attn),
-                    need_head_weights=need_head_weights,
-                )
-                x2 = self.dropout_module(x2)
-                ratios = self.get_ratio()
-                x = ratios[0] * x + ratios[1] * x2
-                x = x + x2
-            x = self.residual_connection(x, residual)
-            if not self.normalize_before:
-                x = self.encoder_attn_layer_norm(x)
+                        need_weights=need_attn or (not self.training and self.need_attn),
+                        need_head_weights=need_head_weights,
+                    )
+                    x = self.dropout_module(x)
+                    x = self.residual_connection(x, residual)
+                elif self.decoder_collaboration_mode == "parallel":
+                    x2, _ = self.s2_attn(
+                        query=cross_attn_x,
+                        key=encoder_out_s2,
+                        value=encoder_out_s2,
+                        key_padding_mask=encoder_padding_mask_s2,
+                        incremental_state=incremental_state,
+                        static_kv=True,
+                        need_weights=need_attn or (not self.training and self.need_attn),
+                        need_head_weights=need_head_weights,
+                    )
+                    x2 = self.dropout_module(x2)
+                    ratios = self.get_ratio()
+                    x = ratios[0] * x + ratios[1] * x2
+                    x = self.residual_connection(x, residual)
+                    if not self.normalize_before:
+                        x = self.encoder_attn_layer_norm(x)
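The deleted x = x + x2 looks like one of the dual-architecture bugs mentioned in the commit message: after the weighted league combination the second branch was added again, so its effective weight was ratios[1] + 1 rather than ratios[1]. A toy calculation showing the difference:

# Toy scalar outputs of the two cross-attentions, with equal league ratios.
x, x2 = 1.0, 10.0
ratios = [0.5, 0.5]

old = ratios[0] * x + ratios[1] * x2 + x2   # buggy: x2 effectively weighted 0.5 + 1.0
new = ratios[0] * x + ratios[1] * x2        # fixed: x2 weighted exactly ratios[1]
print(old, new)                             # 15.5 5.5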