Commit 47e0f6e0 by xuchen

add the multibranch S2T architecture.

I also found and fixed some bugs in the dual architecture.
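For reference, a minimal sketch (added by the editor, not part of the commit) of how --collaboration-step and --collaboration-direction schedule the two branches, mirroring the loop in S2TMultiBranchEncoder.forward; all names below are illustrative only:

# Illustrative sketch only: plain callables stand in for encoder layers.
from typing import Callable, List

def collaborate(acoustic_layers: List[Callable[[], None]],
                textual_layers: List[Callable[[], None]],
                step: str = "1:1", direction: str = "both") -> None:
    # step "a:t" runs a senior acoustic layers and t textual layers per round;
    # direction decides which branch runs first (in the real model it also
    # decides which branch attends to the other via cross-branch attention).
    a_step, t_step = (int(s) for s in step.split(":")) if ":" in step else (1, 1)
    a_idx = t_idx = -1
    while True:
        order = ([("textual", t_step), ("acoustic", a_step)]
                 if direction == "acoustic"
                 else [("acoustic", a_step), ("textual", t_step)])
        for branch, n in order:
            for _ in range(n):
                if branch == "acoustic":
                    a_idx += 1
                    acoustic_layers[a_idx]()
                else:
                    t_idx += 1
                    textual_layers[t_idx]()
        # as in the commit, both branches must exhaust their layers in the
        # same round, otherwise the indices run past the layer lists
        if a_idx == len(acoustic_layers) - 1 and t_idx == len(textual_layers) - 1:
            break

# Example: 6 senior acoustic layers, 3 textual layers, two acoustic layers per textual layer.
collaborate([lambda i=i: print(f"acoustic layer {i}") for i in range(6)],
            [lambda i=i: print(f"textual layer {i}") for i in range(3)],
            step="2:1", direction="both")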
parent 793f553a
......@@ -10,3 +10,4 @@ from .pdss2t_transformer import * # noqa
from .s2t_sate import * # noqa
from .s2t_dual import * # noqa
from .s2t_ctc import *
from .s2t_multibranch import *
from fairseq.models import (
FairseqEncoder,
FairseqEncoderModel,
register_model,
register_model_architecture,
)
import logging
import math
from typing import Dict, List, Optional, Tuple
import torch
......@@ -35,17 +28,6 @@ from fairseq.models.transformer_s2 import (
TransformerS2Encoder,
TransformerS2Decoder,
)
from fairseq.modules import (
FairseqDropout,
LayerNorm,
PositionalEmbedding,
LegacyRelPositionalEncoding,
RelPositionalEncoding,
S2TTransformerEncoderLayer,
DynamicLinearCombination,
TransformerS2DecoderLayer,
TransformerS2EncoderLayer,
)
logger = logging.getLogger(__name__)
......
import logging
from typing import Dict, List, Optional, Tuple
import torch.nn as nn
from torch import Tensor
from fairseq import checkpoint_utils, utils
from fairseq.models.speech_to_text import (
S2TTransformerModel,
S2TTransformerEncoder,
PDSS2TTransformerModel,
PDSS2TTransformerEncoder,
S2TSATEModel,
)
from fairseq.models import (
FairseqEncoder,
FairseqEncoderDecoderModel,
register_model,
register_model_architecture,
)
from fairseq.modules import (
LayerNorm,
PositionalEmbedding,
LegacyRelPositionalEncoding,
RelPositionalEncoding,
S2TTransformerS2EncoderLayer,
TransformerS2EncoderLayer,
)
from fairseq.modules.speech_to_text import Adapter
from fairseq.models.transformer_s2 import (
Embedding,
TransformerS2Decoder,
)
logger = logging.getLogger(__name__)
@register_model("s2t_multibranch")
class S2TMultiBranchModel(FairseqEncoderDecoderModel):
def __init__(self, encoder, decoder):
super().__init__(encoder, decoder)
@staticmethod
def add_args(parser):
"""Add model-specific arguments to the parser."""
S2TTransformerModel.add_args(parser)
PDSS2TTransformerModel.add_specific_args(parser)
S2TSATEModel.add_specific_args(parser)
S2TMultiBranchModel.add_specific_args(parser)
@staticmethod
def add_specific_args(parser):
# multibranch
parser.add_argument(
"--junior-acoustic-encoder",
default="transformer",
choices=["transformer", "pds", "sate", "wav2vec"],
type=str,
help="the architecture of the junior acoustic encoder",
)
parser.add_argument(
"--senior-acoustic-encoder",
default="transformer",
choices=["transformer", "pds", "sate", "wav2vec"],
type=str,
help="the architecture of the senior acoustic ASR encoder",
)
parser.add_argument(
"--textual-encoder",
default="transformer",
type=str,
help="the architecture of the MT encoder",
)
parser.add_argument(
"--textual-encoder-dim",
type=int,
help="the dimension of the textual encoder",
)
parser.add_argument(
"--junior-acoustic-encoder-layers",
default=6,
type=int,
help="the layers of the senior acoustic encoder",
)
parser.add_argument(
"--senior-acoustic-encoder-layers",
default=6,
type=int,
help="the layers of the senior acoustic encoder",
)
parser.add_argument(
"--textual-encoder-layers",
default=6,
type=int,
help="the layers of the textual encoder",
)
# collaboration
parser.add_argument(
"--collaboration-direction",
default="none",
type=str,
help="direction of collaboration",
)
parser.add_argument(
"--collaboration-step",
default="1:1",
type=str,
help="collaboration step in two encoders",
)
parser.add_argument(
"--encoder-collaboration-mode",
default="serial",
type=str,
help="how to calculate attention during league in encoder",
)
parser.add_argument(
"--decoder-collaboration-mode",
default="serial",
type=str,
help="how to calculate attention during league in encoder",
)
# league
parser.add_argument(
"--encoder-league-s1-ratio",
default=0.5,
type=float,
help="league ratio of the s1 representation",
)
parser.add_argument(
"--encoder-league-s2-ratio",
default=0.5,
type=float,
help="league ratio of the s2 representation",
)
parser.add_argument(
"--encoder-league-drop-net",
action="store_true",
help="drop one input during league",
)
parser.add_argument(
"--encoder-league-drop-net-prob",
default=0.0,
type=float,
help="probability of dropping one representations",
)
parser.add_argument(
"--encoder-league-drop-net-mix",
action="store_true",
help="mix the two input with any probability",
)
parser.add_argument(
"--decoder-league-s1-ratio",
default=0.5,
type=float,
help="league ratio of the s1 representation",
)
parser.add_argument(
"--decoder-league-s2-ratio",
default=0.5,
type=float,
help="league ratio of the s2 representation",
)
parser.add_argument(
"--decoder-league-drop-net",
action="store_true",
help="drop one input during league",
)
parser.add_argument(
"--decoder-league-drop-net-prob",
default=0.0,
type=float,
help="probability of dropping one representations",
)
parser.add_argument(
"--decoder-league-drop-net-mix",
action="store_true",
help="mix the two input with any probability",
)
parser.add_argument(
"--load-pretrained-junior-acoustic-encoder-from",
type=str,
metavar="STR",
help="model to take junior acoustic encoder weights from (for initialization)",
)
parser.add_argument(
"--load-pretrained-senior-acoustic-encoder-from",
type=str,
metavar="STR",
help="model to take senior acoustic encoder weights from (for initialization)",
)
parser.add_argument(
"--load-pretrained-textual-encoder-from",
type=str,
metavar="STR",
help="model to take textual encoder weights from (for initialization)",
)
@classmethod
def build_encoder(cls, args, task=None, embed_tokens=None):
encoder = S2TMultiBranchEncoder(args, task, embed_tokens)
if getattr(args, "load_pretrained_encoder_from", None):
encoder = checkpoint_utils.load_pretrained_component_from_model(
component=encoder, checkpoint=args.load_pretrained_encoder_from, strict=False
)
logger.info(
f"loaded pretrained encoder from: "
f"{args.load_pretrained_encoder_from}"
)
if getattr(args, "load_pretrained_junior_encoder_from", None):
encoder.junior_acoustic_encoder = checkpoint_utils.load_pretrained_component_from_model(
component=encoder.asr_encoder, checkpoint=args.load_pretrained_junior_encoder_from, strict=False
)
logger.info(
f"loaded pretrained junior acoustic encoder from: "
f"{args.load_pretrained_junior_encoder_from}"
)
if getattr(args, "load_pretrained_senior_encoder_from", None):
encoder.senior_acoustic_encoder = checkpoint_utils.load_pretrained_component_from_model(
component=encoder.asr_encoder, checkpoint=args.load_pretrained_senior_encoder_from, strict=False
)
logger.info(
f"loaded pretrained senior acoustic encoder from: "
f"{args.load_pretrained_senior_encoder_from}"
)
if getattr(args, "load_pretrained_textual_encoder_from", None):
encoder.textual_encoder = checkpoint_utils.load_pretrained_component_from_model(
component=encoder.mt_encoder, checkpoint=args.load_pretrained_textual_encoder_from, strict=False
)
logger.info(
f"loaded pretrained textual encoder from: "
f"{args.load_pretrained_textual_encoder_from}"
)
return encoder
@classmethod
def build_decoder(cls, args, task, embed_tokens):
decoder = TransformerS2Decoder(args, task.target_dictionary, embed_tokens)
if getattr(args, "load_pretrained_decoder_from", None):
logger.info(
f"loaded pretrained decoder from: "
f"{args.load_pretrained_decoder_from}"
)
decoder = checkpoint_utils.load_pretrained_component_from_model(
component=decoder, checkpoint=args.load_pretrained_decoder_from, strict=False
)
return decoder
@classmethod
def build_model(cls, args, task):
"""Build a new model instance."""
# make sure all arguments are present in older models
base_architecture(args)
def build_embedding(dictionary, embed_dim):
num_embeddings = len(dictionary)
padding_idx = dictionary.pad()
return Embedding(num_embeddings, embed_dim, padding_idx)
src_dict, tgt_dict = task.source_dictionary, task.target_dictionary
if args.share_all_embeddings:
if src_dict != tgt_dict:
raise ValueError("--share-all-embeddings requires a joined dictionary")
if args.encoder_embed_dim != args.decoder_embed_dim:
raise ValueError(
"--share-all-embeddings requires --encoder-embed-dim to match --decoder-embed-dim"
)
encoder_embed_tokens = build_embedding(
src_dict, args.encoder_embed_dim
)
decoder_embed_tokens = encoder_embed_tokens
args.share_decoder_input_output_embed = True
else:
encoder_embed_tokens = build_embedding(
src_dict, args.encoder_embed_dim
)
decoder_embed_tokens = build_embedding(
tgt_dict, args.decoder_embed_dim
)
encoder = cls.build_encoder(args, task, encoder_embed_tokens)
if getattr(args, "encoder_freeze_module", None):
utils.freeze_parameters(encoder, args.encoder_freeze_module)
logging.info("freeze the encoder module: {}".format(args.encoder_freeze_module))
decoder = cls.build_decoder(args, task, decoder_embed_tokens)
if getattr(args, "decoder_freeze_module", None):
utils.freeze_parameters(decoder, args.decoder_freeze_module)
logging.info("freeze the decoder module: {}".format(args.decoder_freeze_module))
return cls(encoder, decoder)
def get_normalized_probs(
self,
net_output: Tuple[Tensor, Optional[Dict[str, List[Optional[Tensor]]]]],
log_probs: bool,
sample: Optional[Dict[str, Tensor]] = None,
):
# net_output['encoder_out'] is a (B, T, D) tensor
lprobs = self.get_normalized_probs_scriptable(net_output, log_probs, sample)
lprobs.batch_first = True
return lprobs
def forward(self, src_tokens, src_lengths, prev_output_tokens, **kwargs):
"""
The forward method inherited from the base class has a **kwargs
argument in its input, which is not supported in torchscript. This
method overwrites the forward method definition without **kwargs.
"""
encoder_out = self.encoder(src_tokens, src_lengths)
decoder_out = self.decoder(
prev_output_tokens=prev_output_tokens, encoder_out=encoder_out
)
return decoder_out
class S2TMultiBranchEncoder(FairseqEncoder):
"""Speech-to-text Transformer encoder that consists of input subsampler and
Transformer encoder."""
def __init__(self, args, task=None, embed_tokens=None):
super().__init__(None)
self.padding_idx = 1
setattr(args, "encoder_layers", args.junior_acoustic_encoder_layers)
junior_encoder_type = args.junior_acoustic_encoder
if junior_encoder_type == "transformer":
self.junior_acoustic_encoder = S2TTransformerEncoder(args, task, embed_tokens)
elif junior_encoder_type == "pds":
self.junior_acoustic_encoder = PDSS2TTransformerEncoder(args, task, embed_tokens)
else:
logger.error("Unsupported junior acoustic architecture: %s." % junior_encoder_type)
self.senior_acoustic_attn_type = getattr(args, "encoder_attention_type", "selfattn")
if self.senior_acoustic_attn_type == "rel_pos":
self.senior_acoustic_embed_positions = RelPositionalEncoding(
args.max_source_positions, args.encoder_embed_dim
)
elif self.senior_acoustic_attn_type in ["rel_selfattn", "rel_pos_legacy"]:
self.senior_acoustic_embed_positions = LegacyRelPositionalEncoding(
args.encoder_embed_dim, args.dropout, args.max_source_positions
)
else: # Use absolute positional embedding
self.senior_acoustic_embed_positions = None
setattr(args, "collaboration_mode", args.encoder_collaboration_mode)
self.senior_acoustic_encoder_layer_num = args.senior_acoustic_encoder_layers
self.senior_acoustic_encoder_layers = nn.ModuleList(
[S2TTransformerS2EncoderLayer(args) for _ in range(self.senior_acoustic_encoder_layer_num)])
# adapter
self.adapter_temperature = args.adapter_temperature
strategy = {
"embed_norm": getattr(args, "adapter_embed_norm", False),
"out_norm": getattr(args, "adapter_out_norm", False),
"ctc_compress_strategy": getattr(args, "ctc_compress_strategy", None),
"distribution_cutoff": getattr(args, "adapter_distribution_cutoff", None),
"drop_prob": getattr(args, "adapter_drop_prob", 0),
}
self.adapter = Adapter(args.encoder_embed_dim,
args.adapter,
len(task.source_dictionary),
strategy=strategy)
assert not (args.share_adapter_and_ctc and args.share_adapter_and_embed), "Can not be True at the same time"
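# tie the adapter's embedding projection either to the CTC projection or to the token embeddings (mutually exclusive, per the assertion above)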
if args.share_adapter_and_ctc and hasattr(self.adapter, "embed_adapter"):
self.adapter.embed_adapter.weight = self.junior_acoustic_encoder.ctc.ctc_projection.weight
if args.share_adapter_and_embed and hasattr(self.adapter, "embed_adapter"):
self.adapter.embed_adapter.weight = embed_tokens.weight
# textual encoder
self.textual_embed_positions = PositionalEmbedding(
args.max_source_positions, args.encoder_embed_dim, self.padding_idx
)
attn_type = args.encoder_attention_type
setattr(args, "encoder_attention_type", "selfattn")
self.textual_encoder_layer_num = args.textual_encoder_layers
self.textual_encoder_layers = nn.ModuleList(
[TransformerS2EncoderLayer(args) for _ in range(self.textual_encoder_layer_num)])
setattr(args, "encoder_attention_type", attn_type)
# collaboration
collaboration_step = args.collaboration_step
if len(collaboration_step.split(":")) == 2:
self.collaboration_step = [int(s) for s in collaboration_step.split(":")]
else:
self.collaboration_step = [1, 1]
self.collaboration_direction = args.collaboration_direction
self.acoustic_norm = LayerNorm(args.encoder_embed_dim)
self.textual_norm = LayerNorm(args.encoder_embed_dim)
def forward(self, src_tokens, src_lengths=None, **kwargs):
junior_acoustic_encoder_out = self.junior_acoustic_encoder(src_tokens, src_lengths, **kwargs)
acoustic_x = junior_acoustic_encoder_out["encoder_out"][0]
acoustic_encoder_padding_mask = junior_acoustic_encoder_out["encoder_padding_mask"][0]
if "ctc_logit" in junior_acoustic_encoder_out and len(junior_acoustic_encoder_out["ctc_logit"]) > 0:
ctc_logit = junior_acoustic_encoder_out["ctc_logit"][0]
else:
ctc_logit = None
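# the adapter turns the junior acoustic states (plus the CTC logits, when available) into the initial input of the textual branch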
x = (acoustic_x, ctc_logit)
adapter_x, adapter_encoder_padding_mask = self.adapter(x, acoustic_encoder_padding_mask)
textual_x = adapter_x + self.textual_embed_positions(adapter_encoder_padding_mask).transpose(0, 1)
# textual_x = self.dropout_module(textual_x)
textual_encoder_padding_mask = adapter_encoder_padding_mask
senior_acoustic_encoder_idx = -1
textual_encoder_idx = -1
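# interleave the senior acoustic and textual branches: each round runs collaboration_step[0] acoustic layers
# and collaboration_step[1] textual layers, with cross-branch (s2) attention according to collaboration_direction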
while True:
if self.collaboration_direction == "acoustic":
for _ in range(self.collaboration_step[1]):
textual_encoder_idx += 1
textual_x = self.textual_encoder_layers[textual_encoder_idx](
textual_x, encoder_padding_mask=textual_encoder_padding_mask,
)
for _ in range(self.collaboration_step[0]):
senior_acoustic_encoder_idx += 1
acoustic_x = self.senior_acoustic_encoder_layers[senior_acoustic_encoder_idx](
acoustic_x, encoder_padding_mask=acoustic_encoder_padding_mask,
s2=textual_x, s2_encoder_padding_mask=textual_encoder_padding_mask
)
elif self.collaboration_direction == "textual":
for _ in range(self.collaboration_step[0]):
senior_acoustic_encoder_idx += 1
acoustic_x = self.senior_acoustic_encoder_layers[senior_acoustic_encoder_idx](
acoustic_x, encoder_padding_mask=acoustic_encoder_padding_mask,
)
for _ in range(self.collaboration_step[1]):
textual_encoder_idx += 1
textual_x = self.textual_encoder_layers[textual_encoder_idx](
textual_x, encoder_padding_mask=textual_encoder_padding_mask,
s2=acoustic_x, s2_encoder_padding_mask=acoustic_encoder_padding_mask
)
elif self.collaboration_direction == "both":
for _ in range(self.collaboration_step[0]):
senior_acoustic_encoder_idx += 1
acoustic_x = self.senior_acoustic_encoder_layers[senior_acoustic_encoder_idx](
acoustic_x, encoder_padding_mask=acoustic_encoder_padding_mask,
s2=textual_x, s2_encoder_padding_mask=textual_encoder_padding_mask
)
for _ in range(self.collaboration_step[1]):
textual_encoder_idx += 1
textual_x = self.textual_encoder_layers[textual_encoder_idx](
textual_x, encoder_padding_mask=textual_encoder_padding_mask,
s2=acoustic_x, s2_encoder_padding_mask=acoustic_encoder_padding_mask
)
elif self.collaboration_direction == "none":
for _ in range(self.collaboration_step[0]):
senior_acoustic_encoder_idx += 1
acoustic_x = self.senior_acoustic_encoder_layers[senior_acoustic_encoder_idx](
acoustic_x, encoder_padding_mask=acoustic_encoder_padding_mask,
)
for _ in range(self.collaboration_step[1]):
textual_encoder_idx += 1
textual_x = self.textual_encoder_layers[textual_encoder_idx](
textual_x, encoder_padding_mask=textual_encoder_padding_mask,
)
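# both branches must exhaust their layers in the same round, otherwise the indices run past the layer lists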
if senior_acoustic_encoder_idx == self.senior_acoustic_encoder_layer_num - 1 and \
textual_encoder_idx == self.textual_encoder_layer_num - 1:
break
acoustic_x = self.acoustic_norm(acoustic_x)
textual_x = self.textual_norm(textual_x)
junior_acoustic_encoder_out["encoder_out"] = [acoustic_x]
junior_acoustic_encoder_out["encoder_padding_mask"] = [acoustic_encoder_padding_mask]
junior_acoustic_encoder_out["s2_encoder_out"] = [textual_x]
junior_acoustic_encoder_out["s2_encoder_padding_mask"] = [textual_encoder_padding_mask]
return junior_acoustic_encoder_out
def reorder_encoder_out(self, encoder_out, new_order):
"""
Reorder encoder output according to *new_order*.
Args:
encoder_out: output from the ``forward()`` method
new_order (LongTensor): desired order
Returns:
*encoder_out* rearranged according to *new_order*
"""
if len(encoder_out["encoder_out"]) == 0:
new_encoder_out = []
else:
new_encoder_out = [encoder_out["encoder_out"][0].index_select(1, new_order)]
if len(encoder_out["encoder_padding_mask"]) == 0:
new_encoder_padding_mask = []
else:
new_encoder_padding_mask = [
encoder_out["encoder_padding_mask"][0].index_select(0, new_order)
]
if len(encoder_out["s2_encoder_out"]) == 0:
new_s2_encoder_out = []
else:
new_s2_encoder_out = [encoder_out["s2_encoder_out"][0].index_select(1, new_order)]
if len(encoder_out["s2_encoder_padding_mask"]) == 0:
new_s2_encoder_padding_mask = []
else:
new_s2_encoder_padding_mask = [
encoder_out["s2_encoder_padding_mask"][0].index_select(0, new_order)
]
if len(encoder_out["encoder_embedding"]) == 0:
new_encoder_embedding = []
else:
new_encoder_embedding = [
encoder_out["encoder_embedding"][0].index_select(0, new_order)
]
if len(encoder_out["src_tokens"]) == 0:
src_tokens = []
else:
src_tokens = [(encoder_out["src_tokens"][0]).index_select(0, new_order)]
if len(encoder_out["src_lengths"]) == 0:
src_lengths = []
else:
src_lengths = [(encoder_out["src_lengths"][0]).index_select(0, new_order)]
encoder_states = encoder_out["encoder_states"]
if len(encoder_states) > 0:
for idx, state in enumerate(encoder_states):
encoder_states[idx] = state.index_select(1, new_order)
return {
"encoder_out": new_encoder_out, # T x B x C
"encoder_padding_mask": new_encoder_padding_mask, # B x T
"s2_encoder_out": new_s2_encoder_out, # T x B x C
"s2_encoder_padding_mask": new_s2_encoder_padding_mask, # B x T
"encoder_embedding": new_encoder_embedding, # B x T x C
"encoder_states": encoder_states, # List[T x B x C]
"src_tokens": src_tokens, # B x T
"src_lengths": src_lengths, # B x 1
}
@register_model_architecture(model_name="s2t_multibranch", arch_name="s2t_multibranch")
def base_architecture(args):
# Convolutional subsampler
args.subsampling_type = getattr(args, "subsampling_type", "conv1d")
args.subsampling_layers = getattr(args, "subsampling_layers", 2)
args.subsampling_filter = getattr(args, "subsampling_filter", 1024)
args.subsampling_kernel = getattr(args, "subsampling_kernel", 5)
args.subsampling_stride = getattr(args, "subsampling_stride", 2)
args.subsampling_norm = getattr(args, "subsampling_norm", "none")
args.subsampling_activation = getattr(args, "subsampling_activation", "glu")
# Transformer
args.encoder_embed_dim = getattr(args, "encoder_embed_dim", 512)
args.encoder_ffn_embed_dim = getattr(args, "encoder_ffn_embed_dim", 2048)
args.encoder_layers = getattr(args, "encoder_layers", 12)
args.encoder_attention_type = getattr(args, "encoder_attention_type", "selfattn")
args.encoder_attention_heads = getattr(args, "encoder_attention_heads", 8)
args.encoder_normalize_before = getattr(args, "encoder_normalize_before", True)
args.decoder_embed_dim = getattr(args, "decoder_embed_dim", args.encoder_embed_dim)
args.decoder_ffn_embed_dim = getattr(
args, "decoder_ffn_embed_dim", args.encoder_ffn_embed_dim
)
args.decoder_layers = getattr(args, "decoder_layers", 6)
args.decoder_attention_type = getattr(args, "decoder_attention_type", "selfattn")
args.decoder_attention_heads = getattr(args, "decoder_attention_heads", 8)
args.decoder_normalize_before = getattr(args, "decoder_normalize_before", True)
args.decoder_learned_pos = getattr(args, "decoder_learned_pos", False)
args.dropout = getattr(args, "dropout", 0.1)
args.attention_dropout = getattr(args, "attention_dropout", args.dropout)
args.activation_dropout = getattr(args, "activation_dropout", args.dropout)
args.activation_fn = getattr(args, "activation_fn", "relu")
args.adaptive_softmax_cutoff = getattr(args, "adaptive_softmax_cutoff", None)
args.adaptive_softmax_dropout = getattr(args, "adaptive_softmax_dropout", 0)
args.tie_adaptive_weights = getattr(args, "tie_adaptive_weights", False)
args.tie_adaptive_proj = getattr(args, "tie_adaptive_proj", False)
args.adaptive_softmax_factor = getattr(args, "adaptive_softmax_factor", 4)
args.share_decoder_input_output_embed = getattr(
args, "share_decoder_input_output_embed", False
)
args.share_all_embeddings = getattr(args, "share_all_embeddings", False)
args.no_token_positional_embeddings = getattr(
args, "no_token_positional_embeddings", False
)
args.adaptive_input = getattr(args, "adaptive_input", False)
args.encoder_layerdrop = getattr(args, "encoder_layerdrop", 0.0)
args.decoder_layerdrop = getattr(args, "decoder_layerdrop", 0.0)
args.decoder_output_dim = getattr(
args, "decoder_output_dim", args.decoder_embed_dim
)
args.decoder_input_dim = getattr(args, "decoder_input_dim", args.decoder_embed_dim)
args.no_scale_embedding = getattr(args, "no_scale_embedding", False)
args.encoder_no_scale_embedding = getattr(args, "encoder_no_scale_embedding", False)
args.quant_noise_pq = getattr(args, "quant_noise_pq", 0)
args.encoder_embed_linear = getattr(args, "encoder_embed_linear", False)
args.encoder_embed_norm = getattr(args, "encoder_embed_norm", False)
# CTC
args.ctc_layer = getattr(args, "ctc_layer", 0)
args.share_ctc_and_embed = getattr(args, "share_ctc_and_embed", False)
# Conformer
args.encoder_activation_fn = getattr(args, "encoder_activation_fn", "relu")
args.macaron_style = getattr(args, "macaron_style", False)
args.use_cnn_module = getattr(args, "use_cnn_module", False)
args.cnn_module_kernel = getattr(args, "cnn_module_kernel", 31)
args.cnn_module_norm = getattr(args, "cnn_module_norm", "batch_norm")
# settings for DLCL
args.use_enc_dlcl = getattr(args, "use_enc_dlcl", False)
args.use_dec_dlcl = getattr(args, "use_dec_dlcl", False)
args.init_value = getattr(args, 'init_value', 'avg')
args.weight_type = getattr(args, 'weight_type', 'scalar')
args.encoder_learnable = getattr(args, 'encoder_learnable', True)
args.normalize_embed = getattr(args, 'normalize_embed', False)
args.history_dropout = getattr(args, 'history_dropout', 0.0)
args.history_window_size = getattr(args, 'history_window_size', -1)
# Relative position encoding
args.max_encoder_relative_length = getattr(args, 'max_encoder_relative_length', -1)
args.k_only = getattr(args, 'k_only', True)
# local modeling
args.hard_mask_window = getattr(args, 'hard_mask_window', 0)
args.gauss_mask_sigma = getattr(args, 'gauss_mask_sigma', 0)
args.init_mask_weight = getattr(args, 'init_mask_weight', 0)
# interleaved CTC
args.interleaved_ctc_layers = getattr(args, "interleaved_ctc_layers", None)
args.interleaved_ctc_temperature = getattr(args, "interleaved_ctc_temperature", 1)
args.interleaved_ctc_drop_prob = getattr(args, "interleaved_ctc_drop_prob", 0)
# Semantics-augmented Encoding (sae)
args.sae_adapter = getattr(args, "sae_adapter", "none")
args.target_sae_adapter = getattr(args, "target_sae_adapter", args.sae_adapter)
args.share_sae_and_ctc = getattr(args, "share_sae_and_ctc", False)
args.share_target_sae_and_ctc = getattr(args, "share_target_sae_and_ctc", False)
args.sae_drop_prob = getattr(args, "sae_drop_prob", 0)
args.sae_distribution_cutoff = getattr(args, "sae_distribution_cutoff", None)
args.sae_distribution_hard = getattr(args, "sae_distribution_hard", False)
args.sae_gumbel = getattr(args, "sae_gumbel", False)
# mixup
args.inter_mixup = getattr(args, "inter_mixup", False)
args.inter_mixup_layer = getattr(args, "inter_mixup_layer", None)
args.inter_mixup_beta = getattr(args, "inter_mixup_beta", 0.5)
args.inter_mixup_prob = getattr(args, "inter_mixup_prob", 1)
args.inter_mixup_ratio = getattr(args, "inter_mixup_ratio", 0.3)
args.inter_mixup_keep_org = getattr(args, "inter_mixup_keep_org", False)
# PDS
args.pds_stages = getattr(args, "pds_stages", None)
args.pds_layers = getattr(args, "pds_layers", None)
args.pds_ratios = getattr(args, "pds_ratios", None)
args.pds_ds_method = getattr(args, "pds_ds_method", "conv")
args.pds_embed_dims = getattr(args, "pds_embed_dims", None)
args.pds_embed_norm = getattr(args, "pds_embed_norm", False)
args.pds_position_embed = getattr(args, "pds_position_embed", None)
args.pds_attn_heads = getattr(args, "pds_attn_heads", None)
args.pds_ffn_ratios = getattr(args, "pds_ffn_ratios", None)
args.pds_cnn_kernel_sizes = getattr(args, "pds_cnn_kernel_sizes", None)
args.pds_attn_ds_ratios = getattr(args, "pds_attn_ds_ratios", None)
args.pds_conv_strides = getattr(args, "pds_conv_strides", None)
args.pds_attn_strides = getattr(args, "pds_attn_strides", None)
args.pds_dropout = getattr(args, "pds_dropout", args.dropout)
args.pds_fusion = getattr(args, "pds_fusion", False)
args.pds_fusion_method = getattr(args, "pds_fusion_method", "all_conv")
# dual
args.junior_acoustic_encoder = getattr(args, "junior_acoustic_encoder", "transformer")
args.senior_acoustic_encoder = getattr(args, "senior_acoustic_encoder", "transformer")
args.textual_encoder = getattr(args, "textual_encoder", "transformer")
args.textual_encoder_dim = getattr(args, "textual_encoder", args.encoder_embed_dim)
args.junior_acoustic_encoder_layers = getattr(args, "junior_acoustic_encoder_layers", 6)
args.senior_acoustic_encoder_layers = getattr(args, "senior_acoustic_encoder_layers", 6)
args.textual_encoder_layers = getattr(args, "textual_encoder_layers", 6)
args.collaboration_direction = getattr(args, "collaboration_direction", "none")
args.collaboration_step = getattr(args, "collaboration_step", "1:1")
args.encoder_collaboration_mode = getattr(args, "encoder_collaboration_mode", "serial")
args.decoder_collaboration_mode = getattr(args, "decoder_collaboration_mode", "serial")
args.encoder_league_s1_ratio = getattr(args, "encoder_league_s1_ratio", 0.5)
args.encoder_league_s2_ratio = getattr(args, "encoder_league_s2_ratio", 0.5)
args.encoder_league_drop_net = getattr(args, "encoder_league_drop_net", False)
args.encoder_league_drop_net_prob = getattr(args, "encoder_league_drop_net_prob", 0.0)
args.encoder_league_drop_net_mix = getattr(args, "encoder_league_drop_net_mix", False)
args.decoder_league_s1_ratio = getattr(args, "decoder_league_s1_ratio", 0.5)
args.decoder_league_s2_ratio = getattr(args, "decoder_league_s2_ratio", 0.5)
args.decoder_league_drop_net = getattr(args, "decoder_league_drop_net", False)
args.decoder_league_drop_net_prob = getattr(args, "decoder_league_drop_net_prob", 0.0)
args.decoder_league_drop_net_mix = getattr(args, "decoder_league_drop_net_mix", False)
# args.encoder_asr_ratio = getattr(args, "encoder_asr_ratio", 1.0)
# args.encoder_mt_ratio = getattr(args, "encoder_mt_ratio", 1.0)
# args.encoder_drop_net = getattr(args, "encoder_drop_net", False)
# args.encoder_drop_net_prob = getattr(args, "encoder_drop_net_prob", 1.0)
# args.encoder_drop_net_mix = getattr(args, "encoder_drop_net_mix", False)
@register_model_architecture("s2t_multibranch", "s2t_multibranch_s")
def s2t_multibranch_s(args):
args.encoder_embed_dim = getattr(args, "encoder_embed_dim", 256)
args.encoder_ffn_embed_dim = getattr(args, "encoder_ffn_embed_dim", 256 * 8)
args.encoder_attention_heads = getattr(args, "encoder_attention_heads", 4)
args.dropout = getattr(args, "dropout", 0.1)
base_architecture(args)
@register_model_architecture("s2t_multibranch", "s2t_multibranch_s_relative")
def s2t_multibranch_s_relative(args):
args.max_encoder_relative_length = 100
args.k_only = True
s2t_multibranch_s(args)
@register_model_architecture("s2t_multibranch", "s2t_multibranch_xs")
def s2t_multibranch_xs(args):
args.encoder_layers = getattr(args, "encoder_layers", 6)
args.encoder_ffn_embed_dim = getattr(args, "encoder_ffn_embed_dim", 256 * 4)
args.dropout = getattr(args, "dropout", 0.3)
s2t_multibranch_s(args)
@register_model_architecture("s2t_multibranch", "s2t_multibranch_sp")
def s2t_multibranch_sp(args):
args.encoder_layers = getattr(args, "encoder_layers", 16)
s2t_multibranch_s(args)
@register_model_architecture("s2t_multibranch", "s2t_multibranch_m")
def s2t_multibranch_m(args):
args.encoder_embed_dim = getattr(args, "encoder_embed_dim", 512)
args.encoder_ffn_embed_dim = getattr(args, "encoder_ffn_embed_dim", 512 * 4)
args.encoder_attention_heads = getattr(args, "encoder_attention_heads", 8)
args.dropout = getattr(args, "dropout", 0.15)
base_architecture(args)
@register_model_architecture("s2t_multibranch", "s2t_multibranch_mp")
def s2t_multibranch_mp(args):
args.encoder_layers = getattr(args, "encoder_layers", 16)
s2t_multibranch_m(args)
@register_model_architecture("s2t_multibranch", "s2t_multibranch_l")
def s2t_multibranch_l(args):
args.encoder_embed_dim = getattr(args, "encoder_embed_dim", 1024)
args.encoder_ffn_embed_dim = getattr(args, "encoder_ffn_embed_dim", 1024 * 4)
args.encoder_attention_heads = getattr(args, "encoder_attention_heads", 16)
args.dropout = getattr(args, "dropout", 0.2)
base_architecture(args)
@register_model_architecture("s2t_multibranch", "s2t_multibranch_lp")
def s2t_multibranch_lp(args):
args.encoder_layers = getattr(args, "encoder_layers", 16)
s2t_multibranch_l(args)
......@@ -343,8 +343,8 @@ class TransformerS2Decoder(TransformerDecoder):
and len(encoder_out["encoder_padding_mask"]) > 0
)
else None,
encoder_out_s2=encoder_out["encoder_out_s2"][0],
encoder_padding_mask_s2=encoder_out["encoder_padding_mask_s2"][0],
encoder_out_s2=encoder_out["s2_encoder_out"][0],
encoder_padding_mask_s2=encoder_out["s2_encoder_padding_mask"][0],
incremental_state=incremental_state,
self_attn_mask=self_attn_mask,
self_attn_padding_mask=self_attn_padding_mask,
......
......@@ -61,6 +61,7 @@ from .espnet_multihead_attention import (
)
from .convolution import ConvolutionModule
from .s2t_transformer_layer import S2TTransformerEncoderLayer
from .s2t_transformer_s2_layer import S2TTransformerS2EncoderLayer
from .pds_layer import PDSTransformerEncoderLayer
__all__ = [
......@@ -70,6 +71,7 @@ __all__ = [
"BeamableMM",
"CharacterTokenEmbedder",
"S2TTransformerEncoderLayer",
"S2TTransformerS2EncoderLayer",
"ConvolutionModule",
"ConvTBC",
"cross_entropy",
......
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
from typing import Optional
import torch
import torch.nn as nn
from fairseq.modules import (
LayerNorm,
MultiheadAttention,
RelPositionMultiheadAttention,
RelativeMultiheadAttention,
ConvolutionModule,
ESPNETMultiHeadedAttention,
RelPositionMultiHeadedAttention,
LegacyRelPositionMultiHeadedAttention,
RotaryPositionMultiHeadedAttention,
)
from fairseq.modules.fairseq_dropout import FairseqDropout
from torch import Tensor
from fairseq.modules.activations import get_activation_class
class FeedForwardModule(torch.nn.Module):
"""Positionwise feed forward layer used in conformer"""
def __init__(
self,
input_feat,
hidden_units,
dropout1,
dropout2,
activation_fn="relu",
bias=True,
):
"""
Args:
input_feat: Input feature dimension
hidden_units: Hidden unit dimension
dropout1: dropout value for layer1
dropout2: dropout value for layer2
activation_fn: Name of activation function
bias: If linear layers should have bias
"""
super(FeedForwardModule, self).__init__()
self.w_1 = torch.nn.Linear(input_feat, hidden_units, bias=bias)
self.w_2 = torch.nn.Linear(hidden_units, input_feat, bias=bias)
self.dropout1 = torch.nn.Dropout(dropout1)
self.dropout2 = torch.nn.Dropout(dropout2)
self.activation = get_activation_class(activation_fn)
def forward(self, x):
"""
Args:
x: Input Tensor of shape T X B X C
Returns:
Tensor of shape T X B X C
"""
x = self.w_1(x)
x = self.activation(x)
x = self.dropout1(x)
x = self.w_2(x)
return self.dropout2(x)
class S2TTransformerS2EncoderLayer(nn.Module):
"""Encoder layer block.
In the original paper each operation (multi-head attention or FFN) is
postprocessed with: `dropout -> add residual -> layernorm`. In the
tensor2tensor code they suggest that learning is more robust when
preprocessing each layer with layernorm and postprocessing with:
`dropout -> add residual`. We default to the approach in the paper, but the
tensor2tensor approach can be enabled by setting
*args.encoder_normalize_before* to ``True``.
Args:
args (argparse.Namespace): parsed command-line arguments
"""
def __init__(self, args):
super().__init__()
self.args = args
embed_dim = args.encoder_embed_dim
ffn_dim = args.encoder_ffn_embed_dim
dropout = args.dropout
self.embed_dim = embed_dim
self.quant_noise = getattr(args, 'quant_noise_pq', 0)
self.quant_noise_block_size = getattr(args, 'quant_noise_pq_block_size', 8) or 8
self.attn_type = getattr(args, "encoder_attention_type", "selfattn")
self.self_attn = self.build_self_attention(args, self.embed_dim)
self.self_attn_layer_norm = LayerNorm(self.embed_dim)
self.dropout_module = FairseqDropout(
dropout, module_name=self.__class__.__name__
)
self.normalize_before = args.encoder_normalize_before
activation = getattr(args, 'encoder_activation_fn', 'relu')
if args.macaron_style:
self.macaron_ffn = FeedForwardModule(
embed_dim,
ffn_dim,
dropout,
dropout,
activation
)
self.macaron_norm = LayerNorm(embed_dim)
self.ffn_scale = 0.5
else:
self.macaron_ffn = None
self.macaron_norm = None
self.ffn_scale = 1.0
if args.use_cnn_module:
self.conv_norm = LayerNorm(embed_dim)
self.conv_module = ConvolutionModule(
self.embed_dim,
self.embed_dim,
depthwise_kernel_size=args.cnn_module_kernel,
dropout=args.dropout,
activation_fn=getattr(args, 'activation_fn', 'swish'),
norm_type=args.cnn_module_norm
)
self.final_norm = LayerNorm(embed_dim)
else:
self.conv_norm = None
self.conv_module = None
self.final_norm = None
self.ffn = FeedForwardModule(
embed_dim,
ffn_dim,
dropout,
dropout,
activation
)
self.ffn_norm = LayerNorm(self.embed_dim)
self.s2_norm = LayerNorm(self.embed_dim)
self.s2_attn_norm = LayerNorm(self.embed_dim)
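# cross-branch attention: the acoustic branch attends over the textual branch (s2) representation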
self.s2_attn = MultiheadAttention(
self.embed_dim,
args.encoder_attention_heads,
kdim=getattr(args, "s2_encoder_embed_dim", self.embed_dim),
vdim=getattr(args, "s2_encoder_embed_dim", self.embed_dim),
dropout=args.attention_dropout,
self_attention=False,
)
def build_self_attention(self, args, embed_dim):
attention_heads = args.encoder_attention_heads
dropout = args.dropout
if self.attn_type == "selfattn":
attn_func = MultiheadAttention
elif self.attn_type == "rel_selfattn":
attn_func = RelPositionMultiheadAttention
elif self.attn_type == "relative":
max_relative_length = max(getattr(args, "max_encoder_relative_length", -1),
getattr(args, "max_relative_length", -1))
if max_relative_length != -1:
return RelativeMultiheadAttention(
embed_dim,
attention_heads,
dropout=dropout,
self_attention=True,
q_noise=self.quant_noise,
qn_block_size=self.quant_noise_block_size,
max_relative_length=max_relative_length,
)
else:
print("The maximum encoder relative length %d can not be -1!" % max_relative_length)
exit(1)
elif self.attn_type == "rel_pos":
return RelPositionMultiHeadedAttention(
embed_dim,
attention_heads,
dropout=dropout,
)
elif self.attn_type == "rel_pos_legacy":
return LegacyRelPositionMultiHeadedAttention(
embed_dim,
attention_heads,
dropout=dropout,
)
elif self.attn_type == "rope":
return RotaryPositionMultiHeadedAttention(
embed_dim,
attention_heads,
dropout=dropout,
precision=args.fp16
)
elif self.attn_type == "abs":
return ESPNETMultiHeadedAttention(
embed_dim,
attention_heads,
dropout=dropout,
)
else:
attn_func = MultiheadAttention
print("The encoder attention type %s is not supported!" % self.attn_type)
exit(1)
return attn_func(
embed_dim,
attention_heads,
dropout=dropout,
self_attention=True,
q_noise=self.quant_noise,
qn_block_size=self.quant_noise_block_size,
)
def residual_connection(self, x, residual):
return residual + x
def upgrade_state_dict_named(self, state_dict, name):
"""
Rename layer norm states from `...layer_norms.0.weight` to
`...self_attn_layer_norm.weight` and `...layer_norms.1.weight` to
`...final_layer_norm.weight`
"""
layer_norm_map = {"0": "self_attn_layer_norm", "1": "final_layer_norm"}
for old, new in layer_norm_map.items():
for m in ("weight", "bias"):
k = "{}.layer_norms.{}.{}".format(name, old, m)
if k in state_dict:
state_dict["{}.{}.{}".format(name, new, m)] = state_dict[k]
del state_dict[k]
def forward(self, x,
encoder_padding_mask: Optional[Tensor],
s2 = None,
s2_encoder_padding_mask = None,
attn_mask: Optional[Tensor] = None,
pos_emb: Optional[Tensor] = None):
"""
Args:
x (Tensor): input to the layer of shape `(seq_len, batch, embed_dim)`
encoder_padding_mask (ByteTensor): binary ByteTensor of shape
`(batch, seq_len)` where padding elements are indicated by ``1``.
attn_mask (ByteTensor): binary tensor of shape `(tgt_len, src_len)`,
where `tgt_len` is the length of output and `src_len` is the
length of input, though here both are equal to `seq_len`.
`attn_mask[tgt_i, src_j] = 1` means that when calculating the
embedding for `tgt_i`, we exclude (mask out) `src_j`. This is
useful for strided self-attention.
pos_emb (Tensor): the position embedding for relative position encoding
Returns:
encoded output of shape `(seq_len, batch, embed_dim)`
"""
# anything in original attn_mask = 1, becomes -1e8
# anything in original attn_mask = 0, becomes 0
# Note that we cannot use -inf here, because at some edge cases,
# the attention weight (before softmax) for some padded element in query
# will become -inf, which results in NaN in model parameters
if attn_mask is not None:
attn_mask = attn_mask.masked_fill(attn_mask.to(torch.bool), -1e8)
# whether to use macaron style
if self.macaron_norm is not None:
residual = x
if self.normalize_before:
x = self.macaron_norm(x)
x = self.macaron_ffn(x)
x = residual + self.ffn_scale * x
if not self.normalize_before:
x = self.macaron_norm(x)
residual = x
if self.normalize_before:
x = self.self_attn_layer_norm(x)
if self.attn_type in ["rel_pos", "rel_pos_legacy", "rel_selfattn"]:
assert pos_emb is not None, "Positional encodings are necessary for RPE!"
x, _ = self.self_attn(
query=x,
key=x,
value=x,
key_padding_mask=encoder_padding_mask,
need_weights=False,
attn_mask=attn_mask,
pos_emb=pos_emb
)
else:
x, _ = self.self_attn(
query=x,
key=x,
value=x,
key_padding_mask=encoder_padding_mask,
need_weights=False,
attn_mask=attn_mask,
)
x = self.dropout_module(x)
x = self.residual_connection(x, residual)
if not self.normalize_before:
x = self.self_attn_layer_norm(x)
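# cross-branch attention over the textual branch (s2), applied after self-attention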
if s2 is not None:
residual = x
x = self.s2_attn_norm(x)
s2 = self.s2_norm(s2)
x, _ = self.s2_attn(
query=x,
key=s2,
value=s2,
key_padding_mask=s2_encoder_padding_mask,
need_weights=False,
)
x = self.dropout_module(x)
x = self.residual_connection(x, residual)
# convolution module
if self.conv_module is not None:
residual = x
x = x.transpose(0, 1)
if self.normalize_before:
x = self.conv_norm(x)
x = self.conv_module(x)
x = x.transpose(0, 1)
x = residual + x
if not self.normalize_before:
x = self.conv_norm(x)
residual = x
if self.normalize_before:
x = self.ffn_norm(x)
x = self.ffn(x)
x = self.residual_connection(self.ffn_scale * x, residual)
if not self.normalize_before:
x = self.ffn_norm(x)
if self.conv_module is not None:
x = self.final_norm(x)
return x
......@@ -79,6 +79,8 @@ class TransformerS2EncoderLayer(nn.Module):
if self.use_se:
self.se_attn = SEAttention(self.embed_dim, 16)
self.s2_norm = LayerNorm(self.embed_dim)
self.s2_attn_norm = LayerNorm(self.embed_dim)
self.s2_attn = MultiheadAttention(
self.embed_dim,
args.encoder_attention_heads,
......@@ -87,26 +89,28 @@ class TransformerS2EncoderLayer(nn.Module):
dropout=args.attention_dropout,
self_attention=False,
)
self.s1_ratio = args.encoder_s1_ratio
self.s2_ratio = args.encoder_s2_ratio
self.encoder_collaboration_mode = args.encoder_collaboration_mode
self.league_s1_ratio = args.encoder_league_s1_ratio
self.league_s2_ratio = args.encoder_league_s2_ratio
self.drop_net = args.encoder_drop_net
self.drop_net_prob = args.encoder_drop_net_prob
self.drop_net_mix = args.encoder_drop_net_mix
self.league_drop_net = args.encoder_league_drop_net
self.league_drop_net_prob = args.encoder_league_drop_net_prob
self.league_drop_net_mix = args.encoder_league_drop_net_mix
def get_ratio(self):
if self.drop_net:
if self.league_drop_net:
frand = float(uniform(0, 1))
if self.drop_net_mix and self.training:
return [frand, 1 - frand]
if frand < self.drop_net_prob and self.training:
if frand < self.league_drop_net_prob and self.training:
return [1, 0]
elif frand > 1 - self.drop_net_prob and self.training:
elif frand > 1 - self.league_drop_net_prob and self.training:
return [0, 1]
else:
return [0.5, 0.5]
else:
return [self.s1_ratio, self.s2_ratio]
return [self.league_s1_ratio, self.league_s2_ratio]
def build_fc1(self, input_dim, output_dim, q_noise, qn_block_size):
return quant_noise(
......@@ -186,8 +190,8 @@ class TransformerS2EncoderLayer(nn.Module):
def forward(self, x,
encoder_padding_mask: Optional[Tensor],
x2 = None,
x2_encoder_padding_mask = None,
s2 = None,
s2_encoder_padding_mask = None,
attn_mask: Optional[Tensor] = None,
pos_emb: Optional[Tensor] = None):
"""
......@@ -219,6 +223,7 @@ class TransformerS2EncoderLayer(nn.Module):
residual = x
if self.normalize_before:
x = self.self_attn_layer_norm(x)
attn_x = x
if self.attn_type == "rel_selfattn":
assert pos_emb is not None, "Positions is necessary for RPE!"
x, _ = self.self_attn(
......@@ -240,20 +245,34 @@ class TransformerS2EncoderLayer(nn.Module):
attn_mask=attn_mask,
)
x = self.dropout_module(x)
if x2 is not None:
x2, _ = self.s2_attn(
query=x,
key=x2,
value=x2,
key_padding_mask=x2_encoder_padding_mask)
x2 = self.dropout_module(x2)
ratio = self.get_ratio()
x = x * ratio[0] + x2 * ratio[1]
x = self.residual_connection(x, residual)
if not self.normalize_before:
x = self.self_attn_layer_norm(x)
if s2 is None or self.encoder_collaboration_mode != "parallel":
x = self.residual_connection(x, residual)
if not self.normalize_before:
x = self.self_attn_layer_norm(x)
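# serial mode: a separate cross-attention block over s2 after self-attention;
# parallel mode: the s2 attention reuses the pre-attention input (attn_x) and the two outputs are fused with the league ratios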
if s2 is not None:
s2 = self.s2_norm(s2)
if self.encoder_collaboration_mode == "serial":
residual = x
x = self.s2_attn_norm(x)
x, _ = self.s2_attn(
query=x,
key=s2,
value=s2,
key_padding_mask=s2_encoder_padding_mask)
x = self.dropout_module(x)
x = self.residual_connection(x, residual)
elif self.encoder_collaboration_mode == "parallel":
x2, _ = self.s2_attn(
query=attn_x,
key=s2,
value=s2,
key_padding_mask=s2_encoder_padding_mask)
x2 = self.dropout_module(x2)
ratio = self.get_ratio()
x = x * ratio[0] + x2 * ratio[1]
x = self.residual_connection(x, residual)
residual = x
if self.normalize_before:
......@@ -341,11 +360,12 @@ class TransformerS2DecoderLayer(nn.Module):
self.s2_attn = MultiheadAttention(
self.embed_dim,
args.decoder_attention_heads,
kdim=getattr(args, "encoder_x2_dim", self.embed_dim),
vdim=getattr(args, "encoder_x2_dim", self.embed_dim),
kdim=getattr(args, "encoder_s2_dim", self.embed_dim),
vdim=getattr(args, "encoder_s2_dim", self.embed_dim),
dropout=args.attention_dropout,
encoder_decoder_attention=True,
)
self.s2_attn_layer_norm = LayerNorm(self.embed_dim)
self.fc1 = self.build_fc1(
self.embed_dim,
......@@ -365,26 +385,27 @@ class TransformerS2DecoderLayer(nn.Module):
self.onnx_trace = False
self.s1_ratio = args.encoder_s1_ratio
self.s2_ratio = args.encoder_s2_ratio
self.decoder_collaboration_mode = args.decoder_collaboration_mode
self.league_s1_ratio = args.decoder_league_s1_ratio
self.league_s2_ratio = args.decoder_league_s2_ratio
self.drop_net = args.encoder_drop_net
self.drop_net_prob = args.encoder_drop_net_prob
self.drop_net_mix = args.encoder_drop_net_mix
self.league_drop_net = args.decoder_league_drop_net
self.league_drop_net_prob = args.decoder_league_drop_net_prob
self.league_drop_net_mix = args.decoder_league_drop_net_mix
def get_ratio(self):
if self.drop_net:
if self.league_drop_net:
frand = float(uniform(0, 1))
if self.drop_net_mix and self.training:
return [frand, 1 - frand]
if frand < self.drop_net_prob and self.training:
if frand < self.league_drop_net_prob and self.training:
return [1, 0]
elif frand > 1 - self.drop_net_prob and self.training:
elif frand > 1 - self.league_drop_net_prob and self.training:
return [0, 1]
else:
return [0.5, 0.5]
else:
return [self.s1_ratio, self.s2_ratio]
return [self.league_s1_ratio, self.league_s2_ratio]
def build_fc1(self, input_dim, output_dim, q_noise, qn_block_size):
return quant_noise(nn.Linear(input_dim, output_dim), q_noise, qn_block_size)
......@@ -551,6 +572,8 @@ class TransformerS2DecoderLayer(nn.Module):
residual = x
if self.normalize_before:
x = self.encoder_attn_layer_norm(x)
cross_attn_x = x
if prev_attn_state is not None:
prev_key, prev_value = prev_attn_state[:2]
saved_state: Dict[str, Optional[Tensor]] = {
......@@ -575,25 +598,45 @@ class TransformerS2DecoderLayer(nn.Module):
need_head_weights=need_head_weights,
)
x = self.dropout_module(x)
if encoder_out_s2 is None or self.decoder_collaboration_mode != "parallel":
x = self.residual_connection(x, residual)
if not self.normalize_before:
x = self.encoder_attn_layer_norm(x)
if encoder_out_s2 is not None:
x2, _ = self.s2_attn(
query=x,
key=encoder_out_s2,
value=encoder_out_s2,
key_padding_mask=encoder_padding_mask_s2,
incremental_state=incremental_state,
static_kv=True,
need_weights=need_attn or (not self.training and self.need_attn),
need_head_weights=need_head_weights,
)
x2 = self.dropout_module(x2)
ratios = self.get_ratio()
x = ratios[0] * x + ratios[1] * x2
x = self.residual_connection(x, residual)
if not self.normalize_before:
x = self.encoder_attn_layer_norm(x)
if self.decoder_collaboration_mode == "serial":
residual = x
x = self.s2_attn_layer_norm(x)
x, _ = self.s2_attn(
query=x,
key=encoder_out_s2,
value=encoder_out_s2,
key_padding_mask=encoder_padding_mask_s2,
incremental_state=incremental_state,
static_kv=True,
need_weights=need_attn or (not self.training and self.need_attn),
need_head_weights=need_head_weights,
)
x = self.dropout_module(x)
x = self.residual_connection(x, residual)
elif self.decoder_collaboration_mode == "parallel":
x2, _ = self.s2_attn(
query=cross_attn_x,
key=encoder_out_s2,
value=encoder_out_s2,
key_padding_mask=encoder_padding_mask_s2,
incremental_state=incremental_state,
static_kv=True,
need_weights=need_attn or (not self.training and self.need_attn),
need_head_weights=need_head_weights,
)
x2 = self.dropout_module(x2)
ratios = self.get_ratio()
x = ratios[0] * x + ratios[1] * x2
x = self.residual_connection(x, residual)
if not self.normalize_before:
x = self.encoder_attn_layer_norm(x)
residual = x
if self.normalize_before:
......