Commit b23817e0 by xuchen

implement the pyramid transformer

parent 7802e6f7
@@ -7,4 +7,5 @@ from .berard import * # noqa
from .convtransformer import * # noqa
from .s2t_transformer import * # noqa
from .s2t_conformer import * # noqa
from .pys2t_transformer import * # noqa
from .s2t_sate import * # noqa
#!/usr/bin/env python3
import logging
import math
import torch
from functools import reduce
import torch.nn as nn
from fairseq import checkpoint_utils
from fairseq.data.data_utils import lengths_to_padding_mask
from fairseq.models import (
FairseqEncoder,
register_model,
register_model_architecture,
)
from fairseq.models.speech_to_text import S2TTransformerModel
from fairseq.modules import (
FairseqDropout,
LayerNorm,
PositionalEmbedding,
PyramidTransformerEncoderLayer,
)
logger = logging.getLogger(__name__)
def lengths_to_padding_mask_with_maxlen(lens, max_length):
bsz = lens.size(0)
mask = torch.arange(max_length).to(lens.device).view(1, max_length)
mask = mask.expand(bsz, -1) >= lens.view(bsz, 1).expand(-1, max_length)
return mask
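# Illustration (added note, values are examples only):
#   >>> lengths_to_padding_mask_with_maxlen(torch.tensor([2, 4]), 4)
#   tensor([[False, False,  True,  True],
#           [False, False, False, False]])
# Positions at or beyond each sequence length are marked as padding.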
class ReducedEmbed(nn.Module):
# Reduced embedding for Pyramid Transformer
def __init__(
self,
reduced_way: str,
embed_norm: bool,
in_channels: int,
out_channels: int,
kernel_sizes: int,
stride: int,
padding: int,
):
super().__init__()
self.stride = stride
self.reduced_way = reduced_way
if self.reduced_way == "conv":
self.conv = nn.Conv1d(in_channels, out_channels, kernel_sizes, stride=stride, padding=padding)
elif self.reduced_way == "glu":
self.conv = nn.Conv1d(in_channels, out_channels * 2, kernel_sizes, stride=stride, padding=padding)
self.glu = nn.GLU(dim=1)
else:
logger.error("Unsupported reduced way!")
self.embed_norm = embed_norm
if self.embed_norm:
# self.norm = LayerNorm(out_channels)
self.norm = LayerNorm(in_channels)
def forward(self, x, lengths):
seq_len, bsz, dim = x.size()
assert seq_len % self.stride == 0, "The sequence length %d must be a multiple of %d." % (seq_len, self.stride)
padding_mask = lengths_to_padding_mask_with_maxlen(lengths, seq_len) # bsz, seq_len
mask_pad = padding_mask.unsqueeze(2)
# mask batch padding
if mask_pad is not None:
x = x.transpose(0, 1)
x.masked_fill_(mask_pad, 0.0)
x = x.transpose(0, 1)
if self.embed_norm:
x = self.norm(x)
x = x.permute(1, 2, 0) # B * D * T
x = self.conv(x)
if self.reduced_way == "glu":
x = self.glu(x)
x = x.permute(2, 0, 1) # T * B * D
lengths = (lengths + self.stride - 1) // self.stride  # ceil(lengths / stride), keep integer dtype
padding_mask = lengths_to_padding_mask_with_maxlen(lengths, x.size(0))
mask_pad = padding_mask.unsqueeze(2)
# mask batch padding
if mask_pad is not None:
x = x.transpose(0, 1)
x.masked_fill_(mask_pad, 0.0)
x = x.transpose(0, 1)
return x, lengths, padding_mask
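# Shape walk-through for ReducedEmbed (illustrative, assuming stride == 2):
#   x: (T, B, D_in) -> zero out padded frames -> optional LayerNorm over D_in
#   -> Conv1d along time with stride 2 (optionally followed by GLU)
#   -> (T', B, D_out) with T' roughly T // 2; lengths and the padding mask are
#   recomputed for the shorter sequence before returning.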
@register_model("pys2t_transformer")
class PYS2TTransformerModel(S2TTransformerModel):
"""Adapted Transformer model (https://arxiv.org/abs/1706.03762) for
speech-to-text tasks. The Transformer encoder/decoder remains the same.
A trainable input subsampler is prepended to the Transformer encoder to
project inputs into the encoder dimension as well as downsample input
sequence for computational efficiency."""
def __init__(self, encoder, decoder):
super().__init__(encoder, decoder)
@staticmethod
def add_args(parser):
"""Add model-specific arguments to the parser."""
S2TTransformerModel.add_args(parser)
parser.add_argument(
"--pyramid-stages",
type=int,
help="the number of the stage",
)
parser.add_argument(
"--pyramid-layers",
type=str,
help="the number of the encoder layers",
)
parser.add_argument(
"--pyramid-sr-ratios",
type=str,
help="the ratio of the subsampling",
)
parser.add_argument(
"--pyramid-attn-sample-ratio",
type=str,
help="the ratio of the subsampling in the self attention module",
)
parser.add_argument(
"--pyramid-reduced-embed",
type=str,
choices=["glu", "conv"],
help="the reduced way of the embedding",
)
parser.add_argument(
"--pyramid-embed-norm",
action="store_true",
help="use layer norm in reduced embedding",
)
parser.add_argument(
"--pyramid-position-embed",
type=str,
help="use the position embedding or not",
)
parser.add_argument(
"--pyramid-embed-dims",
type=str,
help="the embedding dimension",
)
parser.add_argument(
"--pyramid-kernel-sizes",
type=str,
help="the kernel size of the reduced embedding",
)
parser.add_argument(
"--pyramid-ffn-ratios",
type=str,
help="the ratio of the ffn",
)
parser.add_argument(
"--pyramid-heads",
type=str,
help="the number of the attention heads",
)
parser.add_argument(
"--ctc-layer",
type=int,
help="the position of the ctc loss",
)
@classmethod
def build_encoder(cls, args, task=None, embed_tokens=None):
encoder = PyS2TTransformerEncoder(args, task, embed_tokens)
if getattr(args, "load_pretrained_encoder_from", None):
encoder = checkpoint_utils.load_pretrained_component_from_model(
component=encoder, checkpoint=args.load_pretrained_encoder_from, strict=False
)
logger.info(
f"loaded pretrained encoder from: "
f"{args.load_pretrained_encoder_from}"
)
return encoder
class PyS2TTransformerEncoder(FairseqEncoder):
"""Speech-to-text Pyramid Transformer encoder"""
def __init__(self, args, task=None, embed_tokens=None):
super().__init__(None)
self.padding_idx = 1
self.attn_type = getattr(args, "encoder_attention_type", "selfattn")
self.pyramid_stages = getattr(args, "pyramid_stages", 4)
self.pyramid_layers = [int(n) for n in args.pyramid_layers.split("_")]
self.pyramid_sr_ratios = [int(n) for n in args.pyramid_sr_ratios.split("_")]
self.pyramid_attn_sample_ratios = [int(n) for n in args.pyramid_attn_sample_ratios.split("_")]
self.pyramid_embed_dims = [int(n) for n in args.pyramid_embed_dims.split("_")]
self.pyramid_position_embed = [int(n) for n in args.pyramid_position_embed.split("_")]
self.pyramid_kernel_sizes = [int(n) for n in args.pyramid_kernel_sizes.split("_")]
self.pyramid_ffn_ratios = [int(n) for n in args.pyramid_ffn_ratios.split("_")]
self.pyramid_heads = [int(n) for n in args.pyramid_heads.split("_")]
self.pyramid_reduced_embed = args.pyramid_reduced_embed
self.pyramid_embed_norm = args.pyramid_embed_norm
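# Per-stage hyper-parameters are parsed from "_"-separated strings, one entry per
# stage. For example (values taken from the pys2t_transformer_s defaults below):
#   pyramid_layers     "3_3_3_3"        -> [3, 3, 3, 3] encoder layers per stage
#   pyramid_sr_ratios  "2_2_2_2"        -> halve the time axis before each stage (16x in total)
#   pyramid_embed_dims "64_128_256_512" -> widen the model at every stage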
for i in range(self.pyramid_stages):
num_layers = self.pyramid_layers[i]
sr_ratio = self.pyramid_sr_ratios[i]
attn_sample_ratio = self.pyramid_attn_sample_ratios[i]
embed_dim = self.pyramid_embed_dims[i]
kernel_size = self.pyramid_kernel_sizes[i]
ffn_ratio = self.pyramid_ffn_ratios[i]
num_head = self.pyramid_heads[i]
use_pos_embed = self.pyramid_position_embed[i]
if i == 0:
self.embed_scale = math.sqrt(embed_dim)
if args.no_scale_embedding:
self.embed_scale = 1.0
reduced_embed = ReducedEmbed(
self.pyramid_reduced_embed,
self.pyramid_embed_norm if i != 0 else False,
args.input_feat_per_channel * args.input_channels if i == 0 else self.pyramid_embed_dims[i-1],
embed_dim,
kernel_sizes=kernel_size,
stride=sr_ratio,
padding=kernel_size // 2,
)
if use_pos_embed:
pos_embed = PositionalEmbedding(
args.max_source_positions, embed_dim,
self.padding_idx, pos_emb_type=self.attn_type
)
else:
pos_embed = None
dropout = FairseqDropout(
p=args.dropout, module_name=self.__class__.__name__
)
block = nn.ModuleList([
PyramidTransformerEncoderLayer(args, embed_dim, embed_dim * ffn_ratio, num_head, attn_sample_ratio)
for _ in range(num_layers)])
setattr(self, f"reduced_embed{i + 1}", reduced_embed)
setattr(self, f"pos_embed{i + 1}", pos_embed)
setattr(self, f"dropout{i + 1}", dropout)
setattr(self, f"block{i + 1}", block)
if i == self.pyramid_stages - 1:
if args.encoder_normalize_before:
self.layer_norm = LayerNorm(embed_dim)
else:
self.layer_norm = None
self.use_ctc = "sate" in args.arch or \
(("ctc" in getattr(args, "criterion", False)) and
(getattr(args, "ctc_weight", False) > 0))
if self.use_ctc:
self.ctc_layer = (args.encoder_layers + args.ctc_layer) % args.encoder_layers
self.inter_ctc = True if self.ctc_layer != args.encoder_layers - 1 else False
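# Example: with encoder_layers = 12 and the default ctc_layer = -1, the modulo
# above gives ctc_layer = 11, i.e. the CTC head sits on top of the last layer
# and inter_ctc stays False; a smaller positive ctc_layer enables intermediate CTC.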
if task.source_dictionary == task.target_dictionary and getattr(args, "share_all_embeddings", False):
self.ctc_projection = nn.Linear(
embed_tokens.weight.shape[1],
embed_tokens.weight.shape[0],
bias=False,
)
self.ctc_projection.weight = embed_tokens.weight
else:
embed_dim = self.pyramid_embed_dims[-1]
if self.inter_ctc:
ctc_layer = self.ctc_layer
for i in range(self.pyramid_stages):
ctc_layer -= self.pyramid_layers[i]
if ctc_layer <= 0:
embed_dim = self.pyramid_embed_dims[i]
self.ctc_layer_norm = LayerNorm(embed_dim)
self.ctc_projection = nn.Linear(embed_dim, len(task.source_dictionary), bias=False)
nn.init.normal_(
self.ctc_projection.weight, mean=0, std=embed_dim ** -0.5
)
self.ctc_dropout_module = FairseqDropout(
p=args.dropout, module_name=self.__class__.__name__
)
self.softmax = nn.Softmax(dim=-1)
def forward(self, src_tokens, src_lengths):
batch = src_tokens.size(0)
x = src_tokens.transpose(0, 1)
input_lengths = src_lengths
# pad the time dimension to a multiple of the total subsampling ratio
max_len = x.size(0)
length = reduce(lambda a, b: a * b, self.pyramid_sr_ratios)
padding_to_len = (length - max_len % length) % length
if padding_to_len > 0:
padding_for_pyramid = x.new_zeros((padding_to_len, batch, x.size(2)))
x = torch.cat([x, padding_for_pyramid], dim=0)
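# Example: with pyramid_sr_ratios "2_2_2_2" the total reduction is 16, so a
# 50-frame batch is zero-padded to 64 frames here; input_lengths is left
# unchanged, so the appended frames are treated as padding throughout the encoder.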
layer_idx = 0
ctc_logit = None
for i in range(self.pyramid_stages):
reduced_embed = getattr(self, f"reduced_embed{i + 1}")
pos_embed = getattr(self, f"pos_embed{i + 1}")
dropout = getattr(self, f"dropout{i + 1}")
block = getattr(self, f"block{i + 1}")
if i == 0:
x = self.embed_scale * x
# reduced embed
x, input_lengths, encoder_padding_mask = reduced_embed(x, input_lengths)
# max_lens = int(x.size(0))
# encoder_padding_mask = lengths_to_padding_mask_with_maxlen(input_lengths, max_lens)
# add the position encoding and dropout
if pos_embed:
positions = pos_embed(encoder_padding_mask).transpose(0, 1)
if self.attn_type != "rel_selfattn":
x += positions
if i == 0:
x = dropout(x)
positions = dropout(positions)
else:
positions = None
for layer in block:
x = layer(x, encoder_padding_mask, pos_emb=positions)
layer_idx += 1
if self.use_ctc and self.inter_ctc and self.ctc_layer == layer_idx:
ctc_logit = self.ctc_layer_norm(x)
if self.layer_norm is not None:
x = self.layer_norm(x)
return {
"encoder_out": [x], # T x B x C
"encoder_padding_mask": [encoder_padding_mask], # B x T
"encoder_embedding": [], # B x T x C
"encoder_states": [], # List[T x B x C]
"ctc_logit": [ctc_logit if ctc_logit is not None else x],
"src_tokens": [],
"src_lengths": [],
}
def compute_ctc_logit(self, encoder_out):
assert self.use_ctc, "CTC is not available!"
if isinstance(encoder_out, dict) and "encoder_out" in encoder_out:
encoder_state = encoder_out["encoder_out"][0]
else:
encoder_state = encoder_out
ctc_logit = self.ctc_projection(self.ctc_dropout_module(encoder_state))
return ctc_logit
def compute_ctc_prob(self, encoder_out, temperature=1.0):
assert self.use_ctc, "CTC is not available!"
ctc_logit = self.compute_ctc_logit(encoder_out) / temperature
return self.softmax(ctc_logit)
def reorder_encoder_out(self, encoder_out, new_order):
new_encoder_out = (
[] if len(encoder_out["encoder_out"]) == 0
else [x.index_select(1, new_order) for x in encoder_out["encoder_out"]]
)
new_encoder_padding_mask = (
[] if len(encoder_out["encoder_padding_mask"]) == 0
else [x.index_select(0, new_order) for x in encoder_out["encoder_padding_mask"]]
)
new_encoder_embedding = (
[] if len(encoder_out["encoder_embedding"]) == 0
else [x.index_select(0, new_order) for x in encoder_out["encoder_embedding"]]
)
encoder_states = encoder_out["encoder_states"]
if len(encoder_states) > 0:
for idx, state in enumerate(encoder_states):
encoder_states[idx] = state.index_select(1, new_order)
return {
"encoder_out": new_encoder_out, # T x B x C
"encoder_padding_mask": new_encoder_padding_mask, # B x T
"encoder_embedding": new_encoder_embedding, # B x T x C
"encoder_states": encoder_states, # List[T x B x C]
"src_tokens": [], # B x T
"src_lengths": [], # B x 1
}
@register_model_architecture(model_name="pys2t_transformer", arch_name="pys2t_transformer")
def base_architecture(args):
# Convolutional subsampler
args.conv_kernel_sizes = getattr(args, "conv_kernel_sizes", "")
args.conv_channels = getattr(args, "conv_channels", 1024)
# Pyramid
args.pyramid_stages = getattr(args, "pyramid_stages", None)
args.pyramid_layers = getattr(args, "pyramid_layers", None)
args.pyramid_sr_ratios = getattr(args, "pyramid_sr_ratios", None)
args.pyramid_attn_sample_ratios = getattr(args, "pyramid_attn_sample_ratios", None)
args.pyramid_embed_dims = getattr(args, "pyramid_embed_dims", None)
args.pyramid_kernel_sizes = getattr(args, "pyramid_kernel_sizes", None)
args.pyramid_ffn_ratios = getattr(args, "pyramid_ffn_ratios", None)
args.pyramid_heads = getattr(args, "pyramid_heads", None)
args.pyramid_position_embed = getattr(args, "pyramid_position_embed", None)
args.pyramid_reduced_embed = getattr(args, "pyramid_reduced_embed", "conv")
args.pyramid_embed_norm = getattr(args, "pyramid_embed_norm", False)
args.ctc_layer = getattr(args, "ctc_layer", -1)
args.encoder_embed_dim = getattr(args, "encoder_embed_dim", 512)
args.encoder_ffn_embed_dim = getattr(args, "encoder_ffn_embed_dim", 2048)
args.encoder_layers = getattr(args, "encoder_layers", 12)
args.encoder_attention_type = getattr(args, "encoder_attention_type", "selfattn")
args.encoder_attention_heads = getattr(args, "encoder_attention_heads", 8)
args.encoder_normalize_before = getattr(args, "encoder_normalize_before", True)
args.decoder_embed_dim = getattr(args, "decoder_embed_dim", args.encoder_embed_dim)
args.decoder_ffn_embed_dim = getattr(
args, "decoder_ffn_embed_dim", args.encoder_ffn_embed_dim
)
args.decoder_layers = getattr(args, "decoder_layers", 6)
args.decoder_attention_type = getattr(args, "decoder_attention_type", "selfattn")
args.decoder_attention_heads = getattr(args, "decoder_attention_heads", 8)
args.decoder_normalize_before = getattr(args, "decoder_normalize_before", True)
args.decoder_learned_pos = getattr(args, "decoder_learned_pos", False)
args.dropout = getattr(args, "dropout", 0.1)
args.attention_dropout = getattr(args, "attention_dropout", args.dropout)
args.activation_dropout = getattr(args, "activation_dropout", args.dropout)
args.activation_fn = getattr(args, "activation_fn", "relu")
args.adaptive_softmax_cutoff = getattr(args, "adaptive_softmax_cutoff", None)
args.adaptive_softmax_dropout = getattr(args, "adaptive_softmax_dropout", 0)
args.share_decoder_input_output_embed = getattr(
args, "share_decoder_input_output_embed", False
)
args.no_token_positional_embeddings = getattr(
args, "no_token_positional_embeddings", False
)
args.adaptive_input = getattr(args, "adaptive_input", False)
args.decoder_layerdrop = getattr(args, "decoder_layerdrop", 0.0)
args.decoder_output_dim = getattr(
args, "decoder_output_dim", args.decoder_embed_dim
)
args.decoder_input_dim = getattr(args, "decoder_input_dim", args.decoder_embed_dim)
args.no_scale_embedding = getattr(args, "no_scale_embedding", False)
args.quant_noise_pq = getattr(args, "quant_noise_pq", 0)
args.max_encoder_relative_length = getattr(args, 'max_encoder_relative_length', -1)
args.max_decoder_relative_length = getattr(args, 'max_decoder_relative_length', -1)
args.k_only = getattr(args, 'k_only', True)
@register_model_architecture("pys2t_transformer", "pys2t_transformer_s")
def pys2t_transformer_s(args):
args.encoder_embed_dim = getattr(args, "encoder_embed_dim", 256)
args.encoder_ffn_embed_dim = getattr(args, "encoder_ffn_embed_dim", 256 * 8)
args.encoder_attention_heads = getattr(args, "encoder_attention_heads", 4)
args.decoder_attention_heads = getattr(args, "decoder_attention_heads", 4)
args.dropout = getattr(args, "dropout", 0.1)
args.pyramid_stages = getattr(args, "pyramid_stages", 4)
args.pyramid_layers = getattr(args, "pyramid_layers", "3_3_3_3")
args.pyramid_embed_dims = getattr(args, "pyramid_embed_dims", "64_128_256_512")
args.pyramid_kernel_sizes = getattr(args, "pyramid_kernel_sizes", "2_2_2_2")
args.pyramid_ffn_ratios = getattr(args, "pyramid_ffn_ratios", "4_4_4_4")
args.pyramid_attn_sample_ratios = getattr(args, "pyramid_attn_sample_ratios", "8_4_2_1")
args.pyramid_sr_ratios = getattr(args, "pyramid_sr_ratios", "2_2_2_2")
args.pyramid_heads = getattr(args, "pyramid_heads", "1_2_4_8")
args.pyramid_position_embed = getattr(args, "pyramid_position_embed", "1_1_1_1")
args.pyramid_reduced_embed = getattr(args, "pyramid_reduced_embed", "conv")
args.pyramid_embed_norm = getattr(args, "pyramid_embed_norm", False)
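# Summary of this "small" configuration: 4 stages of 3 layers each, widths
# 64 -> 128 -> 256 -> 512, heads 1 -> 2 -> 4 -> 8, the time axis halved before
# every stage (16x overall), and the attention key/value sampling ratio decaying
# 8 -> 4 -> 2 -> 1 as the sequence shortens.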
base_architecture(args)
@register_model_architecture("pys2t_transformer", "pys2t_transformer_s_relative")
def pys2t_transformer_s_relative(args):
args.max_encoder_relative_length = 100
args.max_decoder_relative_length = 20
args.k_only = True
pys2t_transformer_s(args)
@register_model_architecture("pys2t_transformer", "pys2t_transformer_xs")
def pys2t_transformer_xs(args):
args.encoder_layers = getattr(args, "encoder_layers", 6)
args.decoder_layers = getattr(args, "decoder_layers", 3)
args.encoder_ffn_embed_dim = getattr(args, "encoder_ffn_embed_dim", 256 * 4)
args.dropout = getattr(args, "dropout", 0.3)
pys2t_transformer_s(args)
@register_model_architecture("pys2t_transformer", "pys2t_transformer_sp")
def pys2t_transformer_sp(args):
args.encoder_layers = getattr(args, "encoder_layers", 16)
pys2t_transformer_s(args)
@register_model_architecture("pys2t_transformer", "pys2t_transformer_m")
def pys2t_transformer_m(args):
args.encoder_embed_dim = getattr(args, "encoder_embed_dim", 512)
args.encoder_ffn_embed_dim = getattr(args, "encoder_ffn_embed_dim", 512 * 4)
args.encoder_attention_heads = getattr(args, "encoder_attention_heads", 8)
args.decoder_attention_heads = getattr(args, "decoder_attention_heads", 8)
args.dropout = getattr(args, "dropout", 0.15)
base_architecture(args)
@register_model_architecture("pys2t_transformer", "pys2t_transformer_mp")
def pys2t_transformer_mp(args):
args.encoder_layers = getattr(args, "encoder_layers", 16)
pys2t_transformer_m(args)
@register_model_architecture("pys2t_transformer", "pys2t_transformer_l")
def pys2t_transformer_l(args):
args.encoder_embed_dim = getattr(args, "encoder_embed_dim", 1024)
args.encoder_ffn_embed_dim = getattr(args, "encoder_ffn_embed_dim", 1024 * 4)
args.encoder_attention_heads = getattr(args, "encoder_attention_heads", 16)
args.decoder_attention_heads = getattr(args, "decoder_attention_heads", 16)
args.dropout = getattr(args, "dropout", 0.2)
base_architecture(args)
@register_model_architecture("pys2t_transformer", "pys2t_transformer_lp")
def pys2t_transformer_lp(args):
args.encoder_layers = getattr(args, "encoder_layers", 16)
pys2t_transformer_l(args)
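# Usage note (an assumption based on the registrations above, not part of the
# original commit): the model is selected with "--arch pys2t_transformer_s"
# (or the _xs/_m/_l/_sp/_mp/_lp and *_relative variants), and the per-stage
# "_"-separated options such as --pyramid-layers or --pyramid-embed-dims can
# override the defaults set in these architecture functions.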
@@ -17,7 +17,9 @@ from fairseq.models.speech_to_text import (
S2TTransformerModel,
S2TTransformerEncoder,
S2TConformerEncoder,
S2TConformerModel
S2TConformerModel,
PYS2TTransformerModel,
PyS2TTransformerEncoder,
)
from fairseq.models.speech_to_text.s2t_transformer import Conv1dSubsampler
from fairseq.modules import (
@@ -46,6 +48,7 @@ class S2TSATEModel(S2TTransformerModel):
def add_args(parser):
"""Add model-specific arguments to the parser."""
S2TConformerModel.add_args(parser)
PYS2TTransformerModel.add_args(parser)
parser.add_argument(
"--text-encoder-layers",
@@ -195,13 +198,16 @@ class Adapter(nn.Module):
linear_out = self.linear_adapter(representation)
soft_out = torch.mm(distribution.view(-1, embed_dim), self.embed_adapter.weight).view(batch, seq_len, -1)
out = linear_out + soft_out
elif self.adapter_type == "gated_league":
linear_out = self.linear_adapter(representation)
soft_out = torch.mm(distribution.view(-1, embed_dim), self.embed_adapter.weight).view(batch, seq_len, -1)
coef = (self.gate_linear(torch.cat([linear_out, soft_out], dim=-1))).sigmoid()
out = coef * linear_out + (1 - coef) * soft_out
elif self.adapter_type == "none":
out = representation
else:
out = None
logging.error("Unsupported adapter type: {}.".format(self.adapter_type))
@@ -262,6 +268,8 @@ class S2TSATEEncoder(FairseqEncoder):
self.acoustic_encoder = S2TTransformerEncoder(args, task, embed_tokens)
elif acoustic_encoder_type == "conformer":
self.acoustic_encoder = S2TConformerEncoder(args, task, embed_tokens)
elif acoustic_encoder_type == "pyramid":
self.acoustic_encoder = PyS2TTransformerEncoder(args, task, embed_tokens)
else:
logging.error("Unsupported model arch {}!".format(acoustic_encoder_type))
@@ -277,9 +285,9 @@ class S2TSATEEncoder(FairseqEncoder):
# )
acoustic_encoder_attention_type = args.encoder_attention_type
if acoustic_encoder_attention_type != "selfattn":
args.encoder_attention_type = "selfattn"
logger.info("Force self attention for text encoder.")
# if acoustic_encoder_attention_type != "selfattn":
# args.encoder_attention_type = "selfattn"
# logger.info("Force self attention for text encoder.")
# text encoder
self.text_encoder = TextEncoder(args, embed_tokens)
@@ -378,6 +386,9 @@ def base_architecture(args):
args.use_cnn_module = getattr(args, "use_cnn_module", False)
args.cnn_module_kernel = getattr(args, "cnn_module_kernel", 31)
# Pyramid
args.pyramid_layers = getattr(args, "pyramid_layers", None)
args.encoder_embed_dim = getattr(args, "encoder_embed_dim", 512)
args.encoder_ffn_embed_dim = getattr(args, "encoder_ffn_embed_dim", 2048)
args.encoder_layers = getattr(args, "encoder_layers", 12)
@@ -147,10 +147,11 @@ class S2TTransformerModel(FairseqEncoderDecoderModel):
type=str,
default="selfattn",
choices=[
"local",
"selfattn",
"reduced",
"rel_selfattn",
"relative",
"local",
],
help="transformer encoder self-attention layer type"
)
@@ -29,6 +29,7 @@ from .linearized_convolution import LinearizedConvolution
from .local_multihead_attention import LocalMultiheadAttention
from .multihead_attention import MultiheadAttention
from .positional_embedding import PositionalEmbedding
from .reduced_multihead_attention import ReducedMultiheadAttention
from .rel_position_multihead_attention import RelPositionMultiheadAttention
from .relative_multihead_attention import RelativeMultiheadAttention
from .same_pad import SamePad
@@ -41,6 +42,7 @@ from .unfold import unfold1d
from .transformer_layer import TransformerDecoderLayer, TransformerEncoderLayer
from .vggblock import VGGBlock
from .conformer_layer import ConformerEncoderLayer
from .pyramid_layer import PyramidTransformerEncoderLayer
__all__ = [
"AdaptiveInput",
@@ -74,6 +76,8 @@ __all__ = [
"LocalMultiheadAttention",
"MultiheadAttention",
"PositionalEmbedding",
"PyramidTransformerEncoderLayer",
"ReducedMultiheadAttention",
"RelPositionMultiheadAttention",
"RelativeMultiheadAttention",
"SamePad",
@@ -325,7 +325,6 @@ class LocalMultiheadAttention(nn.Module):
multihead_mask_weight = None
gauss_bias = None
if self.multihead_gauss_mask_sigma is not None:
data_type = attn_weights.dtype
x1 = torch.arange(-1, src_len - 1, 1).view(-1, 1).to(attn_weights.device)
x2 = torch.arange(-1, src_len - 1, 1).view(1, -1).to(attn_weights.device)
dis_square = -(x1 - x2) ** 2 / 2.0
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
from typing import Dict, List, Optional
import torch
import torch.nn as nn
from fairseq import utils
from fairseq.modules import (
LayerNorm,
MultiheadAttention,
ReducedMultiheadAttention,
RelPositionMultiheadAttention,
RelativeMultiheadAttention,
LocalMultiheadAttention,
)
from fairseq.modules.fairseq_dropout import FairseqDropout
from fairseq.modules.quant_noise import quant_noise
from torch import Tensor
class PyramidTransformerEncoderLayer(nn.Module):
"""Encoder layer block.
In the original paper each operation (multi-head attention or FFN) is
postprocessed with: `dropout -> add residual -> layernorm`. In the
tensor2tensor code they suggest that learning is more robust when
preprocessing each layer with layernorm and postprocessing with:
`dropout -> add residual`. We default to the approach in the paper, but the
tensor2tensor approach can be enabled by setting
*args.encoder_normalize_before* to ``True``.
Args:
args (argparse.Namespace): parsed command-line arguments
"""
def __init__(self, args, embed_dim, ffn_embed_dim, num_head, att_sample_ratio=1):
super().__init__()
self.args = args
self.embed_dim = embed_dim
self.encoder_ffn_embed_dim = ffn_embed_dim
self.quant_noise = getattr(args, 'quant_noise_pq', 0)
self.quant_noise_block_size = getattr(args, 'quant_noise_pq_block_size', 8) or 8
self.attn_type = getattr(args, "encoder_attention_type", "selfattn")
self.self_attn = self.build_self_attention(self.embed_dim, num_head, args, att_sample_ratio)
self.self_attn_layer_norm = LayerNorm(self.embed_dim)
self.dropout_module = FairseqDropout(
args.dropout, module_name=self.__class__.__name__
)
self.activation_fn = utils.get_activation_fn(
activation=getattr(args, 'activation_fn', 'relu') or "relu"
)
activation_dropout_p = getattr(args, "activation_dropout", 0) or 0
if activation_dropout_p == 0:
# for backwards compatibility with models that use args.relu_dropout
activation_dropout_p = getattr(args, "relu_dropout", 0) or 0
self.activation_dropout_module = FairseqDropout(
float(activation_dropout_p), module_name=self.__class__.__name__
)
self.normalize_before = args.encoder_normalize_before
self.fc1 = self.build_fc1(
self.embed_dim,
self.encoder_ffn_embed_dim,
self.quant_noise,
self.quant_noise_block_size,
)
self.fc2 = self.build_fc2(
self.encoder_ffn_embed_dim,
self.embed_dim,
self.quant_noise,
self.quant_noise_block_size,
)
self.final_layer_norm = LayerNorm(self.embed_dim)
def build_fc1(self, input_dim, output_dim, q_noise, qn_block_size):
return quant_noise(
nn.Linear(input_dim, output_dim), p=q_noise, block_size=qn_block_size
)
def build_fc2(self, input_dim, output_dim, q_noise, qn_block_size):
return quant_noise(
nn.Linear(input_dim, output_dim), p=q_noise, block_size=qn_block_size
)
def build_self_attention(self, embed_dim, num_head, args, sample_ratio=1):
encoder_attention_heads = num_head
if self.attn_type == "selfattn":
attn_func = MultiheadAttention
elif self.attn_type == "rel_selfattn":
attn_func = RelPositionMultiheadAttention
elif self.attn_type == "relative":
# max_relative_length = getattr(args, "max_encoder_relative_length", -1)
max_relative_length = max(getattr(args, "max_encoder_relative_length", -1),
getattr(args, "max_relative_length", -1))
if max_relative_length != -1:
return RelativeMultiheadAttention(
embed_dim,
encoder_attention_heads,
dropout=args.attention_dropout,
self_attention=True,
q_noise=self.quant_noise,
qn_block_size=self.quant_noise_block_size,
max_relative_length=max_relative_length,
)
else:
print("The maximum encoder relative length %d can not be -1!" % max_relative_length)
exit(1)
elif self.attn_type == "local":
hard_mask_window = getattr(args, "hard_mask_window", 0)
gauss_mask_sigma = getattr(args, "gauss_mask_sigma", 0)
init_mask_weight = getattr(args, "init_mask_weight", 0)
return LocalMultiheadAttention(
embed_dim,
encoder_attention_heads,
dropout=args.attention_dropout,
self_attention=True,
q_noise=self.quant_noise,
qn_block_size=self.quant_noise_block_size,
hard_mask_window=hard_mask_window,
gauss_mask_sigma=gauss_mask_sigma,
init_mask_weight=init_mask_weight
)
elif self.attn_type == "reduced":
return ReducedMultiheadAttention(
embed_dim,
encoder_attention_heads,
dropout=args.attention_dropout,
self_attention=True,
q_noise=self.quant_noise,
qn_block_size=self.quant_noise_block_size,
sample_ratio=sample_ratio,
)
else:
print("The encoder attention type %s is not supported!" % self.attn_type)
exit(1)
return attn_func(
embed_dim,
encoder_attention_heads,
dropout=args.attention_dropout,
self_attention=True,
q_noise=self.quant_noise,
qn_block_size=self.quant_noise_block_size,
)
def residual_connection(self, x, residual):
return residual + x
def upgrade_state_dict_named(self, state_dict, name):
"""
Rename layer norm states from `...layer_norms.0.weight` to
`...self_attn_layer_norm.weight` and `...layer_norms.1.weight` to
`...final_layer_norm.weight`
"""
layer_norm_map = {"0": "self_attn_layer_norm", "1": "final_layer_norm"}
for old, new in layer_norm_map.items():
for m in ("weight", "bias"):
k = "{}.layer_norms.{}.{}".format(name, old, m)
if k in state_dict:
state_dict["{}.{}.{}".format(name, new, m)] = state_dict[k]
del state_dict[k]
def forward(self, x,
encoder_padding_mask: Optional[Tensor],
attn_mask: Optional[Tensor] = None,
pos_emb: Optional[Tensor] = None):
"""
Args:
x (Tensor): input to the layer of shape `(seq_len, batch, embed_dim)`
encoder_padding_mask (ByteTensor): binary ByteTensor of shape
`(batch, seq_len)` where padding elements are indicated by ``1``.
attn_mask (ByteTensor): binary tensor of shape `(tgt_len, src_len)`,
where `tgt_len` is the length of output and `src_len` is the
length of input, though here both are equal to `seq_len`.
`attn_mask[tgt_i, src_j] = 1` means that when calculating the
embedding for `tgt_i`, we exclude (mask out) `src_j`. This is
useful for strided self-attention.
pos_emb (Tensor): the position embedding for relative position encoding
Returns:
encoded output of shape `(seq_len, batch, embed_dim)`
"""
# anything in original attn_mask = 1, becomes -1e8
# anything in original attn_mask = 0, becomes 0
# Note that we cannot use -inf here, because at some edge cases,
# the attention weight (before softmax) for some padded element in query
# will become -inf, which results in NaN in model parameters
if attn_mask is not None:
attn_mask = attn_mask.masked_fill(attn_mask.to(torch.bool), -1e8)
residual = x
if self.normalize_before:
x = self.self_attn_layer_norm(x)
if self.attn_type == "rel_selfattn":
assert pos_emb is not None, "Positions is necessary for RPE!"
x, _ = self.self_attn(
query=x,
key=x,
value=x,
key_padding_mask=encoder_padding_mask,
need_weights=False,
attn_mask=attn_mask,
pos_emb=pos_emb
)
else:
x, _ = self.self_attn(
query=x,
key=x,
value=x,
key_padding_mask=encoder_padding_mask,
need_weights=False,
attn_mask=attn_mask,
)
x = self.dropout_module(x)
x = self.residual_connection(x, residual)
if not self.normalize_before:
x = self.self_attn_layer_norm(x)
residual = x
if self.normalize_before:
x = self.final_layer_norm(x)
x = self.activation_fn(self.fc1(x))
x = self.activation_dropout_module(x)
x = self.fc2(x)
x = self.dropout_module(x)
x = self.residual_connection(x, residual)
if not self.normalize_before:
x = self.final_layer_norm(x)
return x
class TransformerDecoderLayer(nn.Module):
"""Decoder layer block.
In the original paper each operation (multi-head attention, encoder
attention or FFN) is postprocessed with: `dropout -> add residual ->
layernorm`. In the tensor2tensor code they suggest that learning is more
robust when preprocessing each layer with layernorm and postprocessing with:
`dropout -> add residual`. We default to the approach in the paper, but the
tensor2tensor approach can be enabled by setting
*args.decoder_normalize_before* to ``True``.
Args:
args (argparse.Namespace): parsed command-line arguments
no_encoder_attn (bool, optional): whether to attend to encoder outputs
(default: False).
"""
def __init__(
self, args, no_encoder_attn=False, add_bias_kv=False, add_zero_attn=False
):
super().__init__()
self.embed_dim = args.decoder_embed_dim
self.dropout_module = FairseqDropout(
args.dropout, module_name=self.__class__.__name__
)
self.quant_noise = getattr(args, "quant_noise_pq", 0)
self.quant_noise_block_size = getattr(args, "quant_noise_pq_block_size", 8)
self.cross_self_attention = getattr(args, "cross_self_attention", False)
self.attn_type = getattr(args, "decoder_attention_type", "selfattn")
self.self_attn = self.build_self_attention(
self.embed_dim,
args,
add_bias_kv=add_bias_kv,
add_zero_attn=add_zero_attn,
)
self.activation_fn = utils.get_activation_fn(
activation=str(args.activation_fn)
if getattr(args, "activation_fn", None) is not None
else "relu"
)
activation_dropout_p = getattr(args, "activation_dropout", 0) or 0
if activation_dropout_p == 0:
# for backwards compatibility with models that use args.relu_dropout
activation_dropout_p = getattr(args, "relu_dropout", 0) or 0
self.activation_dropout_module = FairseqDropout(
float(activation_dropout_p), module_name=self.__class__.__name__
)
self.normalize_before = args.decoder_normalize_before
# use layerNorm rather than FusedLayerNorm for exporting.
# char_inputs can be used to determine this.
# TODO remove this once we update apex with the fix
export = getattr(args, "char_inputs", False)
self.self_attn_layer_norm = LayerNorm(self.embed_dim, export=export)
if no_encoder_attn:
self.encoder_attn = None
self.encoder_attn_layer_norm = None
else:
self.encoder_attn = self.build_encoder_attention(self.embed_dim, args)
self.encoder_attn_layer_norm = LayerNorm(self.embed_dim, export=export)
self.fc1 = self.build_fc1(
self.embed_dim,
args.decoder_ffn_embed_dim,
self.quant_noise,
self.quant_noise_block_size,
)
self.fc2 = self.build_fc2(
args.decoder_ffn_embed_dim,
self.embed_dim,
self.quant_noise,
self.quant_noise_block_size,
)
self.final_layer_norm = LayerNorm(self.embed_dim, export=export)
self.need_attn = True
self.onnx_trace = False
def build_fc1(self, input_dim, output_dim, q_noise, qn_block_size):
return quant_noise(nn.Linear(input_dim, output_dim), q_noise, qn_block_size)
def build_fc2(self, input_dim, output_dim, q_noise, qn_block_size):
return quant_noise(nn.Linear(input_dim, output_dim), q_noise, qn_block_size)
def build_self_attention(
self, embed_dim, args, add_bias_kv=False, add_zero_attn=False
):
if self.attn_type == "selfattn":
attn_func = MultiheadAttention
elif self.attn_type == "rel_selfattn":
attn_func = RelPositionMultiheadAttention
elif self.attn_type == "relative":
max_relative_length = max(getattr(args, "max_decoder_relative_length", -1), getattr(args, "max_relative_length", -1))
if max_relative_length != -1:
return RelativeMultiheadAttention(
embed_dim,
args.decoder_attention_heads,
dropout=args.attention_dropout,
self_attention=True,
q_noise=self.quant_noise,
qn_block_size=self.quant_noise_block_size,
max_relative_length=max_relative_length,
)
else:
print("The maximum decoder relative length %d can not be -1!" % max_relative_length)
exit(1)
else:
print("The decoder attention type %s is not supported!" % self.attn_type)
exit(1)
return attn_func(
embed_dim,
args.decoder_attention_heads,
dropout=args.attention_dropout,
add_bias_kv=add_bias_kv,
add_zero_attn=add_zero_attn,
self_attention=not getattr(args, "cross_self_attention", False),
q_noise=self.quant_noise,
qn_block_size=self.quant_noise_block_size,
)
def build_encoder_attention(self, embed_dim, args):
return MultiheadAttention(
embed_dim,
args.decoder_attention_heads,
kdim=getattr(args, "encoder_embed_dim", None),
vdim=getattr(args, "encoder_embed_dim", None),
dropout=args.attention_dropout,
encoder_decoder_attention=True,
q_noise=self.quant_noise,
qn_block_size=self.quant_noise_block_size,
)
def prepare_for_onnx_export_(self):
self.onnx_trace = True
def residual_connection(self, x, residual):
return residual + x
def forward(
self,
x,
encoder_out: Optional[torch.Tensor] = None,
encoder_padding_mask: Optional[torch.Tensor] = None,
incremental_state: Optional[Dict[str, Dict[str, Optional[Tensor]]]] = None,
prev_self_attn_state: Optional[List[torch.Tensor]] = None,
prev_attn_state: Optional[List[torch.Tensor]] = None,
self_attn_mask: Optional[torch.Tensor] = None,
self_attn_padding_mask: Optional[torch.Tensor] = None,
need_attn: bool = False,
need_head_weights: bool = False,
pos_emb: Optional[Tensor] = None,
):
"""
Args:
x (Tensor): input to the layer of shape `(seq_len, batch, embed_dim)`
encoder_padding_mask (ByteTensor, optional): binary
ByteTensor of shape `(batch, src_len)` where padding
elements are indicated by ``1``.
need_attn (bool, optional): return attention weights
need_head_weights (bool, optional): return attention weights
for each head (default: return average over heads).
Returns:
encoded output of shape `(seq_len, batch, embed_dim)`
"""
if need_head_weights:
need_attn = True
residual = x
if self.normalize_before:
x = self.self_attn_layer_norm(x)
if prev_self_attn_state is not None:
prev_key, prev_value = prev_self_attn_state[:2]
saved_state: Dict[str, Optional[Tensor]] = {
"prev_key": prev_key,
"prev_value": prev_value,
}
if len(prev_self_attn_state) >= 3:
saved_state["prev_key_padding_mask"] = prev_self_attn_state[2]
assert incremental_state is not None
self.self_attn._set_input_buffer(incremental_state, saved_state)
_self_attn_input_buffer = self.self_attn._get_input_buffer(incremental_state)
if self.cross_self_attention and not (
incremental_state is not None
and _self_attn_input_buffer is not None
and "prev_key" in _self_attn_input_buffer
):
if self_attn_mask is not None:
assert encoder_out is not None
self_attn_mask = torch.cat(
(x.new_zeros(x.size(0), encoder_out.size(0)), self_attn_mask), dim=1
)
if self_attn_padding_mask is not None:
if encoder_padding_mask is None:
assert encoder_out is not None
encoder_padding_mask = self_attn_padding_mask.new_zeros(
encoder_out.size(1), encoder_out.size(0)
)
self_attn_padding_mask = torch.cat(
(encoder_padding_mask, self_attn_padding_mask), dim=1
)
assert encoder_out is not None
y = torch.cat((encoder_out, x), dim=0)
else:
y = x
if self.attn_type == "rel_selfattn":
assert pos_emb is not None, "Positions is necessary for RPE!"
x, attn = self.self_attn(
query=x,
key=y,
value=y,
key_padding_mask=self_attn_padding_mask,
incremental_state=incremental_state,
need_weights=False,
attn_mask=self_attn_mask,
pos_emb=pos_emb
)
else:
x, attn = self.self_attn(
query=x,
key=y,
value=y,
key_padding_mask=self_attn_padding_mask,
incremental_state=incremental_state,
need_weights=False,
attn_mask=self_attn_mask,
)
x = self.dropout_module(x)
x = self.residual_connection(x, residual)
if not self.normalize_before:
x = self.self_attn_layer_norm(x)
if self.encoder_attn is not None and encoder_out is not None:
residual = x
if self.normalize_before:
x = self.encoder_attn_layer_norm(x)
if prev_attn_state is not None:
prev_key, prev_value = prev_attn_state[:2]
saved_state: Dict[str, Optional[Tensor]] = {
"prev_key": prev_key,
"prev_value": prev_value,
}
if len(prev_attn_state) >= 3:
saved_state["prev_key_padding_mask"] = prev_attn_state[2]
assert incremental_state is not None
self.encoder_attn._set_input_buffer(incremental_state, saved_state)
x, attn = self.encoder_attn(
query=x,
key=encoder_out,
value=encoder_out,
key_padding_mask=encoder_padding_mask,
incremental_state=incremental_state,
static_kv=True,
need_weights=need_attn or (not self.training and self.need_attn),
need_head_weights=need_head_weights,
)
x = self.dropout_module(x)
x = self.residual_connection(x, residual)
if not self.normalize_before:
x = self.encoder_attn_layer_norm(x)
residual = x
if self.normalize_before:
x = self.final_layer_norm(x)
x = self.activation_fn(self.fc1(x))
x = self.activation_dropout_module(x)
x = self.fc2(x)
x = self.dropout_module(x)
x = self.residual_connection(x, residual)
if not self.normalize_before:
x = self.final_layer_norm(x)
if self.onnx_trace and incremental_state is not None:
saved_state = self.self_attn._get_input_buffer(incremental_state)
assert saved_state is not None
if self_attn_padding_mask is not None:
self_attn_state = [
saved_state["prev_key"],
saved_state["prev_value"],
saved_state["prev_key_padding_mask"],
]
else:
self_attn_state = [saved_state["prev_key"], saved_state["prev_value"]]
return x, attn, self_attn_state
return x, attn, None
def make_generation_fast_(self, need_attn: bool = False, **kwargs):
self.need_attn = need_attn
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import math
from typing import Dict, Optional, Tuple
import torch
import torch.nn.functional as F
from fairseq import utils
from fairseq.incremental_decoding_utils import with_incremental_state
from fairseq.modules.fairseq_dropout import FairseqDropout
from fairseq.modules.quant_noise import quant_noise
from torch import Tensor, nn
from torch.nn import Parameter
@with_incremental_state
class ReducedMultiheadAttention(nn.Module):
"""Multi-headed attention.
See "Attention Is All You Need" for more details.
"""
def __init__(
self,
embed_dim,
num_heads,
kdim=None,
vdim=None,
dropout=0.0,
bias=True,
add_bias_kv=False,
add_zero_attn=False,
self_attention=False,
encoder_decoder_attention=False,
q_noise=0.0,
qn_block_size=8,
sample_ratio=1,
):
super().__init__()
self.embed_dim = embed_dim
self.kdim = kdim if kdim is not None else embed_dim
self.vdim = vdim if vdim is not None else embed_dim
self.qkv_same_dim = self.kdim == embed_dim and self.vdim == embed_dim
self.num_heads = num_heads
self.dropout_module = FairseqDropout(
dropout, module_name=self.__class__.__name__
)
self.head_dim = embed_dim // num_heads
assert (
self.head_dim * num_heads == self.embed_dim
), "embed_dim must be divisible by num_heads"
self.scaling = self.head_dim ** -0.5
self.self_attention = self_attention
self.encoder_decoder_attention = encoder_decoder_attention
assert not self.self_attention or self.qkv_same_dim, (
"Self-attention requires query, key and " "value to be of the same size"
)
self.k_proj = quant_noise(
nn.Linear(self.kdim, embed_dim, bias=bias), q_noise, qn_block_size
)
self.v_proj = quant_noise(
nn.Linear(self.vdim, embed_dim, bias=bias), q_noise, qn_block_size
)
self.q_proj = quant_noise(
nn.Linear(embed_dim, embed_dim, bias=bias), q_noise, qn_block_size
)
self.out_proj = quant_noise(
nn.Linear(embed_dim, embed_dim, bias=bias), q_noise, qn_block_size
)
if add_bias_kv:
self.bias_k = Parameter(torch.Tensor(1, 1, embed_dim))
self.bias_v = Parameter(torch.Tensor(1, 1, embed_dim))
else:
self.bias_k = self.bias_v = None
self.add_zero_attn = add_zero_attn
self.sample_ratio = sample_ratio
if self.sample_ratio > 1:
self.sr = nn.Conv1d(embed_dim, embed_dim, kernel_size=sample_ratio, stride=sample_ratio)
self.norm = nn.LayerNorm(embed_dim)
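# When sample_ratio > 1, keys and values are computed from a strided Conv1d
# downsampling of the query sequence (see forward below), so the attention
# matrix shrinks from (tgt_len x src_len) to roughly (tgt_len x src_len / sample_ratio).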
self.reset_parameters()
self.onnx_trace = False
def prepare_for_onnx_export_(self):
self.onnx_trace = True
def reset_parameters(self):
if self.qkv_same_dim:
# Empirically observed the convergence to be much better with
# the scaled initialization
nn.init.xavier_uniform_(self.k_proj.weight, gain=1 / math.sqrt(2))
nn.init.xavier_uniform_(self.v_proj.weight, gain=1 / math.sqrt(2))
nn.init.xavier_uniform_(self.q_proj.weight, gain=1 / math.sqrt(2))
else:
nn.init.xavier_uniform_(self.k_proj.weight)
nn.init.xavier_uniform_(self.v_proj.weight)
nn.init.xavier_uniform_(self.q_proj.weight)
nn.init.xavier_uniform_(self.out_proj.weight)
if self.out_proj.bias is not None:
nn.init.constant_(self.out_proj.bias, 0.0)
if self.bias_k is not None:
nn.init.xavier_normal_(self.bias_k)
if self.bias_v is not None:
nn.init.xavier_normal_(self.bias_v)
def forward(
self,
query,
key: Optional[Tensor],
value: Optional[Tensor],
key_padding_mask: Optional[Tensor] = None,
incremental_state: Optional[Dict[str, Dict[str, Optional[Tensor]]]] = None,
need_weights: bool = True,
static_kv: bool = False,
attn_mask: Optional[Tensor] = None,
before_softmax: bool = False,
need_head_weights: bool = False,
) -> Tuple[Tensor, Optional[Tensor]]:
"""Input shape: Time x Batch x Channel
Args:
key_padding_mask (ByteTensor, optional): mask to exclude
keys that are pads, of shape `(batch, src_len)`, where
padding elements are indicated by 1s.
need_weights (bool, optional): return the attention weights,
averaged over heads (default: False).
attn_mask (ByteTensor, optional): typically used to
implement causal attention, where the mask prevents the
attention from looking forward in time (default: None).
before_softmax (bool, optional): return the raw attention
weights and values before the attention softmax.
need_head_weights (bool, optional): return the attention
weights for each head. Implies *need_weights*. Default:
return the average attention weights over all heads.
"""
if need_head_weights:
need_weights = True
is_tpu = query.device.type == "xla"
tgt_len, bsz, embed_dim = query.size()
assert embed_dim == self.embed_dim
assert list(query.size()) == [tgt_len, bsz, embed_dim]
if (
self.sample_ratio == 1 and
not self.onnx_trace
and not is_tpu # don't use PyTorch version on TPUs
and incremental_state is None
and not static_kv
# A workaround for quantization to work. Otherwise JIT compilation
# treats bias in linear module as method.
and not torch.jit.is_scripting()
):
assert key is not None and value is not None
return F.multi_head_attention_forward(
query,
key,
value,
self.embed_dim,
self.num_heads,
torch.empty([0]),
torch.cat((self.q_proj.bias, self.k_proj.bias, self.v_proj.bias)),
self.bias_k,
self.bias_v,
self.add_zero_attn,
self.dropout_module.p,
self.out_proj.weight,
self.out_proj.bias,
self.training or self.dropout_module.apply_during_inference,
key_padding_mask,
need_weights,
attn_mask,
use_separate_proj_weight=True,
q_proj_weight=self.q_proj.weight,
k_proj_weight=self.k_proj.weight,
v_proj_weight=self.v_proj.weight,
)
if incremental_state is not None:
saved_state = self._get_input_buffer(incremental_state)
if saved_state is not None and "prev_key" in saved_state:
# previous time steps are cached - no need to recompute
# key and value if they are static
if static_kv:
assert self.encoder_decoder_attention and not self.self_attention
key = value = None
else:
saved_state = None
q = self.q_proj(query)
if self.self_attention:
if self.sample_ratio > 1:
query_ = query.permute(1, 2, 0) # bsz, dim, seq_len
query_ = self.sr(query_).permute(2, 0, 1) # seq_len, bsz, dim
query = self.norm(query_)
k = self.k_proj(query)
v = self.v_proj(query)
elif self.encoder_decoder_attention:
# encoder-decoder attention
if key is None:
assert value is None
k = v = None
else:
k = self.k_proj(key)
v = self.v_proj(key)
else:
assert key is not None and value is not None
k = self.k_proj(key)
v = self.v_proj(value)
q *= self.scaling
if self.bias_k is not None:
assert self.bias_v is not None
k = torch.cat([k, self.bias_k.repeat(1, bsz, 1)])
v = torch.cat([v, self.bias_v.repeat(1, bsz, 1)])
if attn_mask is not None:
attn_mask = torch.cat(
[attn_mask, attn_mask.new_zeros(attn_mask.size(0), 1)], dim=1
)
if key_padding_mask is not None:
key_padding_mask = torch.cat(
[
key_padding_mask,
key_padding_mask.new_zeros(key_padding_mask.size(0), 1),
],
dim=1,
)
q = (
q.contiguous()
.view(tgt_len, bsz * self.num_heads, self.head_dim)
.transpose(0, 1)
)
if k is not None:
k = (
k.contiguous()
.view(-1, bsz * self.num_heads, self.head_dim)
.transpose(0, 1)
)
if v is not None:
v = (
v.contiguous()
.view(-1, bsz * self.num_heads, self.head_dim)
.transpose(0, 1)
)
if saved_state is not None:
# saved states are stored with shape (bsz, num_heads, seq_len, head_dim)
if "prev_key" in saved_state:
_prev_key = saved_state["prev_key"]
assert _prev_key is not None
prev_key = _prev_key.view(bsz * self.num_heads, -1, self.head_dim)
if static_kv:
k = prev_key
else:
assert k is not None
k = torch.cat([prev_key, k], dim=1)
if "prev_value" in saved_state:
_prev_value = saved_state["prev_value"]
assert _prev_value is not None
prev_value = _prev_value.view(bsz * self.num_heads, -1, self.head_dim)
if static_kv:
v = prev_value
else:
assert v is not None
v = torch.cat([prev_value, v], dim=1)
prev_key_padding_mask: Optional[Tensor] = None
if "prev_key_padding_mask" in saved_state:
prev_key_padding_mask = saved_state["prev_key_padding_mask"]
assert k is not None and v is not None
key_padding_mask = ReducedMultiheadAttention._append_prev_key_padding_mask(
key_padding_mask=key_padding_mask,
prev_key_padding_mask=prev_key_padding_mask,
batch_size=bsz,
src_len=k.size(1),
static_kv=static_kv,
)
saved_state["prev_key"] = k.view(bsz, self.num_heads, -1, self.head_dim)
saved_state["prev_value"] = v.view(bsz, self.num_heads, -1, self.head_dim)
saved_state["prev_key_padding_mask"] = key_padding_mask
# In this branch incremental_state is never None
assert incremental_state is not None
incremental_state = self._set_input_buffer(incremental_state, saved_state)
assert k is not None
src_len = k.size(1)
# This is part of a workaround to get around fork/join parallelism
# not supporting Optional types.
if key_padding_mask is not None and key_padding_mask.dim() == 0:
key_padding_mask = None
if key_padding_mask is not None:
assert key_padding_mask.size(0) == bsz
assert key_padding_mask.size(1) == src_len
if self.add_zero_attn:
assert v is not None
src_len += 1
k = torch.cat([k, k.new_zeros((k.size(0), 1) + k.size()[2:])], dim=1)
v = torch.cat([v, v.new_zeros((v.size(0), 1) + v.size()[2:])], dim=1)
if attn_mask is not None:
attn_mask = torch.cat(
[attn_mask, attn_mask.new_zeros(attn_mask.size(0), 1)], dim=1
)
if key_padding_mask is not None:
key_padding_mask = torch.cat(
[
key_padding_mask,
torch.zeros(key_padding_mask.size(0), 1).type_as(
key_padding_mask
),
],
dim=1,
)
attn_weights = torch.bmm(q, k.transpose(1, 2))
attn_weights = self.apply_sparse_mask(attn_weights, tgt_len, src_len, bsz)
assert list(attn_weights.size()) == [bsz * self.num_heads, tgt_len, src_len]
if attn_mask is not None:
attn_mask = attn_mask.unsqueeze(0)
if self.onnx_trace:
attn_mask = attn_mask.repeat(attn_weights.size(0), 1, 1)
attn_weights += attn_mask
if key_padding_mask is not None:
# don't attend to padding symbols
attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len)
if not is_tpu:
attn_weights = attn_weights.masked_fill(
key_padding_mask.unsqueeze(1).unsqueeze(2).to(torch.bool),
float("-inf"),
)
else:
attn_weights = attn_weights.transpose(0, 2)
attn_weights = attn_weights.masked_fill(key_padding_mask, float("-inf"))
attn_weights = attn_weights.transpose(0, 2)
attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)
if before_softmax:
return attn_weights, v
attn_weights_float = utils.softmax(
attn_weights, dim=-1, onnx_trace=self.onnx_trace
)
attn_weights = attn_weights_float.type_as(attn_weights)
attn_probs = self.dropout_module(attn_weights)
assert v is not None
attn = torch.bmm(attn_probs, v)
assert list(attn.size()) == [bsz * self.num_heads, tgt_len, self.head_dim]
if self.onnx_trace and attn.size(1) == 1:
# when ONNX tracing a single decoder step (sequence length == 1)
# the transpose is a no-op copy before view, thus unnecessary
attn = attn.contiguous().view(tgt_len, bsz, embed_dim)
else:
attn = attn.transpose(0, 1).contiguous().view(tgt_len, bsz, embed_dim)
attn = self.out_proj(attn)
attn_weights: Optional[Tensor] = None
if need_weights:
attn_weights = attn_weights_float.view(
bsz, self.num_heads, tgt_len, src_len
).transpose(1, 0)
if not need_head_weights:
# average attention weights over heads
attn_weights = attn_weights.mean(dim=0)
return attn, attn_weights
@staticmethod
def _append_prev_key_padding_mask(
key_padding_mask: Optional[Tensor],
prev_key_padding_mask: Optional[Tensor],
batch_size: int,
src_len: int,
static_kv: bool,
) -> Optional[Tensor]:
# saved key padding masks have shape (bsz, seq_len)
if prev_key_padding_mask is not None and static_kv:
new_key_padding_mask = prev_key_padding_mask
elif prev_key_padding_mask is not None and key_padding_mask is not None:
new_key_padding_mask = torch.cat(
[prev_key_padding_mask.float(), key_padding_mask.float()], dim=1
)
# During incremental decoding, as the padding token enters and
# leaves the frame, there will be a time when prev or current
# is None
elif prev_key_padding_mask is not None:
filler = torch.zeros(
(batch_size, src_len - prev_key_padding_mask.size(1)),
device=prev_key_padding_mask.device,
)
new_key_padding_mask = torch.cat(
[prev_key_padding_mask.float(), filler.float()], dim=1
)
elif key_padding_mask is not None:
filler = torch.zeros(
(batch_size, src_len - key_padding_mask.size(1)),
device=key_padding_mask.device,
)
new_key_padding_mask = torch.cat(
[filler.float(), key_padding_mask.float()], dim=1
)
else:
new_key_padding_mask = prev_key_padding_mask
return new_key_padding_mask
@torch.jit.export
def reorder_incremental_state(
self,
incremental_state: Dict[str, Dict[str, Optional[Tensor]]],
new_order: Tensor,
):
"""Reorder buffered internal state (for incremental generation)."""
input_buffer = self._get_input_buffer(incremental_state)
if input_buffer is not None:
for k in input_buffer.keys():
input_buffer_k = input_buffer[k]
if input_buffer_k is not None:
if self.encoder_decoder_attention and input_buffer_k.size(
0
) == new_order.size(0):
break
input_buffer[k] = input_buffer_k.index_select(0, new_order)
incremental_state = self._set_input_buffer(incremental_state, input_buffer)
return incremental_state
def _get_input_buffer(
self, incremental_state: Optional[Dict[str, Dict[str, Optional[Tensor]]]]
) -> Dict[str, Optional[Tensor]]:
result = self.get_incremental_state(incremental_state, "attn_state")
if result is not None:
return result
else:
empty_result: Dict[str, Optional[Tensor]] = {}
return empty_result
def _set_input_buffer(
self,
incremental_state: Dict[str, Dict[str, Optional[Tensor]]],
buffer: Dict[str, Optional[Tensor]],
):
return self.set_incremental_state(incremental_state, "attn_state", buffer)
def apply_sparse_mask(self, attn_weights, tgt_len: int, src_len: int, bsz: int):
return attn_weights
def upgrade_state_dict_named(self, state_dict, name):
prefix = name + "." if name != "" else ""
items_to_add = {}
keys_to_remove = []
for k in state_dict.keys():
if k.endswith(prefix + "in_proj_weight"):
# in_proj_weight used to be q + k + v with same dimensions
dim = int(state_dict[k].shape[0] / 3)
items_to_add[prefix + "q_proj.weight"] = state_dict[k][:dim]
items_to_add[prefix + "k_proj.weight"] = state_dict[k][dim : 2 * dim]
items_to_add[prefix + "v_proj.weight"] = state_dict[k][2 * dim :]
keys_to_remove.append(k)
k_bias = prefix + "in_proj_bias"
if k_bias in state_dict.keys():
dim = int(state_dict[k].shape[0] / 3)
items_to_add[prefix + "q_proj.bias"] = state_dict[k_bias][:dim]
items_to_add[prefix + "k_proj.bias"] = state_dict[k_bias][
dim : 2 * dim
]
items_to_add[prefix + "v_proj.bias"] = state_dict[k_bias][2 * dim :]
keys_to_remove.append(prefix + "in_proj_bias")
for k in keys_to_remove:
del state_dict[k]
for key, value in items_to_add.items():
state_dict[key] = value