Commit f0e3290f by xuchen

add the traditional relative multihead attention

parent a3e0f4c2
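The "traditional" relative attention added here is the clipped relative position scheme of Shaw et al. (2018): the attention score between positions i and j also sees a learned embedding of the clipped offset j - i. A rough, self-contained sketch of that scoring rule (illustrative names only, not the code added below):

import torch

def relative_scores(q, k, rel_k, scale):
    # q, k: (heads, length, head_dim); rel_k: (length, length, head_dim)
    # score[h, i, j] = scale * q[h, i] . (k[h, j] + rel_k[i, j])
    content = torch.einsum("hid,hjd->hij", q, k)
    position = torch.einsum("hid,ijd->hij", q, rel_k)
    return scale * (content + position)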
@@ -182,6 +182,13 @@ def s2t_conformer_s(args):
base_architecture(args)
@register_model_architecture("s2t_conformer", "s2t_conformer_s_relative")
def s2t_conformer_s_relative(args):
args.max_relative_length = 20
args.k_only = True
s2t_conformer_s(args)
@register_model_architecture("s2t_conformer", "s2t_conformer_xs") @register_model_architecture("s2t_conformer", "s2t_conformer_xs")
def s2t_conformer_xs(args): def s2t_conformer_xs(args):
args.encoder_layers = getattr(args, "encoder_layers", 6) args.encoder_layers = getattr(args, "encoder_layers", 6)
......
@@ -357,6 +357,13 @@ def s2t_sate_s(args):
base_architecture(args)
@register_model_architecture("s2t_sate", "s2t_sate_s_relative")
def s2t_sate_s_relative(args):
args.max_relative_length = 20
args.k_only = True
s2t_sate_s(args)
@register_model_architecture("s2t_sate", "s2t_sate_xs") @register_model_architecture("s2t_sate", "s2t_sate_xs")
def s2t_sate_xs(args): def s2t_sate_xs(args):
args.encoder_layers = getattr(args, "encoder_layers", 6) args.encoder_layers = getattr(args, "encoder_layers", 6)
......
@@ -217,6 +217,10 @@ class S2TTransformerModel(FairseqEncoderDecoderModel):
action="store_true",
help="if True, dont scale embeddings",
)
parser.add_argument('--max-relative-length', type=int, default=-1,
help='maximum relative position length (-1 disables relative position attention)')
parser.add_argument('--k-only', default=False, action='store_true',
help='apply relative position embeddings to the keys only (not the values)')
parser.add_argument(
"--load-pretrained-encoder-from",
type=str,
@@ -518,6 +522,9 @@ def base_architecture(args):
args.no_scale_embedding = getattr(args, "no_scale_embedding", False)
args.quant_noise_pq = getattr(args, "quant_noise_pq", 0)
args.max_relative_length = getattr(args, 'max_relative_length', -1)
args.k_only = getattr(args, 'k_only', True)
@register_model_architecture("s2t_transformer", "s2t_transformer_s") @register_model_architecture("s2t_transformer", "s2t_transformer_s")
def s2t_transformer_s(args): def s2t_transformer_s(args):
...@@ -529,6 +536,13 @@ def s2t_transformer_s(args): ...@@ -529,6 +536,13 @@ def s2t_transformer_s(args):
base_architecture(args) base_architecture(args)
@register_model_architecture("s2t_transformer", "s2t_transformer_s_relative")
def s2t_transformer_s_relative(args):
args.max_relative_length = 20
args.k_only = True
s2t_transformer_s(args)
@register_model_architecture("s2t_transformer", "s2t_transformer_xs") @register_model_architecture("s2t_transformer", "s2t_transformer_xs")
def s2t_transformer_xs(args): def s2t_transformer_xs(args):
args.encoder_layers = getattr(args, "encoder_layers", 6) args.encoder_layers = getattr(args, "encoder_layers", 6)
......
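The *_relative presets above work because they assign args.max_relative_length and args.k_only before delegating to the shared base architecture, whose getattr(args, ..., default) calls only fill in attributes that are still missing. A condensed, standalone sketch of that pattern (skipping the intermediate s2t_transformer_s step):

from argparse import Namespace

def base_architecture(args):
    # getattr-style defaults: only applied when the preset has not already set a value
    args.max_relative_length = getattr(args, "max_relative_length", -1)
    args.k_only = getattr(args, "k_only", True)

def s2t_transformer_s_relative(args):
    args.max_relative_length = 20  # preset value survives the default pass below
    args.k_only = True
    base_architecture(args)

args = Namespace()
s2t_transformer_s_relative(args)
print(args.max_relative_length, args.k_only)  # -> 20 True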
@@ -5,10 +5,11 @@
import math
from typing import Any, Dict, List, Optional, Tuple
import logging
import torch
import torch.nn as nn
from fairseq import checkpoint_utils, utils
from fairseq.distributed import fsdp_wrap
from fairseq.models import (
FairseqEncoder,
@@ -35,6 +36,8 @@ from torch import Tensor
DEFAULT_MAX_SOURCE_POSITIONS = 1024
DEFAULT_MAX_TARGET_POSITIONS = 1024
logger = logging.getLogger(__name__)
@register_model("transformer") @register_model("transformer")
class TransformerModel(FairseqEncoderDecoderModel): class TransformerModel(FairseqEncoderDecoderModel):
...@@ -191,6 +194,35 @@ class TransformerModel(FairseqEncoderDecoderModel): ...@@ -191,6 +194,35 @@ class TransformerModel(FairseqEncoderDecoderModel):
help='block size of quantization noise at training time') help='block size of quantization noise at training time')
parser.add_argument('--quant-noise-scalar', type=float, metavar='D', default=0, parser.add_argument('--quant-noise-scalar', type=float, metavar='D', default=0,
help='scalar quantization noise and scalar quantization at training time') help='scalar quantization noise and scalar quantization at training time')
parser.add_argument('--max-relative-length', type=int, default=-1,
help='maximum relative position length (-1 disables relative position attention)')
parser.add_argument('--k-only', default=False, action='store_true',
help='apply relative position embeddings to the keys only (not the values)')
# args for loading pre-trained models
parser.add_argument(
"--load-pretrained-encoder-from",
type=str,
metavar="STR",
help="model to take encoder weights from (for initialization)",
)
parser.add_argument(
"--load-pretrained-decoder-from",
type=str,
metavar="STR",
help="model to take decoder weights from (for initialization)",
)
parser.add_argument(
"--encoder-freeze-module",
type=str,
metavar="STR",
help="freeze the module of the encoder",
)
parser.add_argument(
"--decoder-freeze-module",
type=str,
metavar="STR",
help="freeze the module of the decoder",
)
# fmt: on
@classmethod
@@ -240,7 +272,15 @@ class TransformerModel(FairseqEncoderDecoderModel):
if getattr(args, "offload_activations", False):
args.checkpoint_activations = True  # offloading implies checkpointing
encoder = cls.build_encoder(args, src_dict, encoder_embed_tokens)
if getattr(args, "encoder_freeze_module", None):
utils.freeze_parameters(encoder, args.encoder_freeze_module)
logging.info("freeze the encoder module: {}".format(args.encoder_freeze_module))
decoder = cls.build_decoder(args, tgt_dict, decoder_embed_tokens) decoder = cls.build_decoder(args, tgt_dict, decoder_embed_tokens)
if getattr(args, "decoder_freeze_module", None):
utils.freeze_parameters(decoder, args.decoder_freeze_module)
logging.info("freeze the decoder module: {}".format(args.decoder_freeze_module))
if not args.share_all_embeddings:
encoder = fsdp_wrap(encoder, min_num_params=1e8)
decoder = fsdp_wrap(decoder, min_num_params=1e8)
@@ -260,17 +300,38 @@ class TransformerModel(FairseqEncoderDecoderModel):
@classmethod
def build_encoder(cls, args, src_dict, embed_tokens):
encoder = TransformerEncoder(args, src_dict, embed_tokens)
if getattr(args, "load_pretrained_encoder_from", None):
logger.info(
f"loaded pretrained encoder from: "
f"{args.load_pretrained_encoder_from}"
)
encoder = checkpoint_utils.load_pretrained_component_from_model(
component=encoder, checkpoint=args.load_pretrained_encoder_from, strict=False
)
return encoder
@classmethod
def build_decoder(cls, args, tgt_dict, embed_tokens):
decoder = TransformerDecoder(
args,
tgt_dict,
embed_tokens,
no_encoder_attn=getattr(args, "no_cross_attention", False),
)
if getattr(args, "load_pretrained_decoder_from", None):
logger.info(
f"loaded pretrained decoder from: "
f"{args.load_pretrained_decoder_from}"
)
decoder = checkpoint_utils.load_pretrained_component_from_model(
component=decoder, checkpoint=args.load_pretrained_decoder_from, strict=False
)
return decoder
# TorchScript doesn't support optional arguments with variable length (**kwargs).
# Current workaround is to add union of all arguments in child classes.
def forward(
@@ -1073,6 +1134,15 @@ def base_architecture(args):
args.quant_noise_pq = getattr(args, "quant_noise_pq", 0)
args.quant_noise_pq_block_size = getattr(args, "quant_noise_pq_block_size", 8)
args.quant_noise_scalar = getattr(args, "quant_noise_scalar", 0)
args.max_relative_length = getattr(args, 'max_relative_length', -1)
args.k_only = getattr(args, 'k_only', True)
@register_model_architecture("transformer", "transformer_relative")
def transformer_rpr(args):
args.max_relative_length = 20
args.k_only = True
base_architecture(args)
@register_model_architecture("transformer", "transformer_iwslt_de_en") @register_model_architecture("transformer", "transformer_iwslt_de_en")
......
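The new --encoder-freeze-module / --decoder-freeze-module options delegate to utils.freeze_parameters, a helper from this repository rather than upstream fairseq. A minimal sketch of what such a helper is assumed to do (hypothetical implementation, shown only to clarify the intent):

import torch.nn as nn

def freeze_parameters(model: nn.Module, module_names: str) -> None:
    # disable gradients for every parameter whose qualified name starts
    # with one of the given comma-separated prefixes, e.g. "embed_tokens,layers.0"
    prefixes = tuple(name.strip() for name in module_names.split(","))
    for name, param in model.named_parameters():
        if name.startswith(prefixes):
            param.requires_grad = False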
@@ -21,6 +21,7 @@ from .grad_multiply import GradMultiply
from .gumbel_vector_quantizer import GumbelVectorQuantizer
from .kmeans_vector_quantizer import KmeansVectorQuantizer
from .layer_drop import LayerDropModuleList
from .layer_history import CreateLayerHistory
from .layer_norm import Fp32LayerNorm, LayerNorm
from .learned_positional_embedding import LearnedPositionalEmbedding
from .lightweight_convolution import LightweightConv, LightweightConv1dTBC
@@ -28,6 +29,7 @@ from .linearized_convolution import LinearizedConvolution
from .multihead_attention import MultiheadAttention
from .positional_embedding import PositionalEmbedding
from .rel_position_multihead_attention import RelPositionMultiheadAttention
from .relative_multihead_attention import RelativeMultiheadAttention
from .same_pad import SamePad
from .scalar_bias import ScalarBias
from .sinusoidal_positional_embedding import SinusoidalPositionalEmbedding
@@ -47,6 +49,7 @@ __all__ = [
"ConformerEncoderLayer",
"ConvolutionModule",
"ConvTBC",
"CreateLayerHistory",
"cross_entropy", "cross_entropy",
"DownsampledMultiHeadAttention", "DownsampledMultiHeadAttention",
"DynamicConv1dTBC", "DynamicConv1dTBC",
...@@ -69,6 +72,7 @@ __all__ = [ ...@@ -69,6 +72,7 @@ __all__ = [
"MultiheadAttention", "MultiheadAttention",
"PositionalEmbedding", "PositionalEmbedding",
"RelPositionMultiheadAttention", "RelPositionMultiheadAttention",
"RelativeMultiheadAttention",
"SamePad", "SamePad",
"ScalarBias", "ScalarBias",
"SinusoidalPositionalEmbedding", "SinusoidalPositionalEmbedding",
......
import torch
import torch.nn as nn
# import the submodule directly to avoid a circular import with fairseq.models
from fairseq.modules.layer_norm import LayerNorm
import queue
import numpy as np
def CreateLayerHistory(args, is_encoder):
history_type = getattr(args, 'encoder_history_type', None) if is_encoder else getattr(args, 'decoder_history_type', None)
if history_type is None:
return None
elif history_type == "residual":
return ResidualLayerHistory(args, is_encoder)
elif history_type == "dense":
return DenseLayerHistory(args, is_encoder)
elif history_type == "learnable_dense":
return LearnableDenseLayerHistory(args, is_encoder)
elif history_type == "learnable_dense_mask":
return LearnableDenseMaskLayerHistory(args, is_encoder)
elif history_type == "learnable_dense_nonorm":
return LearnableDenseNoNormLayerHistory(args, is_encoder)
elif history_type == "gru":
return GruLayerHistory(args, is_encoder)
else:
raise ValueError("unknown layer history type: {}".format(history_type))
class BaseLayerHistory(nn.Module):
def __init__(self, args, is_encoder):
super(BaseLayerHistory, self).__init__()
self.is_encoder = is_encoder
self.normalize_before = args.encoder_normalize_before if is_encoder else args.decoder_normalize_before
# the first layer (aka. embedding layer) does not have layer normalization
layers = args.encoder_layers if is_encoder else args.decoder_layers
dim = args.encoder_embed_dim if is_encoder else args.decoder_embed_dim
self.layer_norms = nn.ModuleList(LayerNorm(dim) for _ in range(layers))
def add(self, layer):
raise NotImplementedError
def pop(self):
raise NotImplementedError
def clean(self):
raise NotImplementedError
class ResidualLayerHistory(BaseLayerHistory):
"""
x_n = x_{n-1} + y_{n-1}
"""
def __init__(self, args, is_encoder):
super(ResidualLayerHistory, self).__init__(args, is_encoder)
self.count = 0
self.x = None
self.y = None
def add(self, layer):
if self.x is None:
self.x = layer
self.count += 1
return
self.count += 1
if self.normalize_before:
self.y = self.layer_norms[self.count - 2](layer)
else:
self.y = layer
def pop(self):
assert self.x is not None
if self.y is None:
return self.x
ret = self.x + self.y
if not self.normalize_before:
ret = self.layer_norms[self.count - 2](ret)
self.x = ret
return ret
def clean(self):
self.x = None
self.y = None
self.count = 0
class DenseLayerHistory(BaseLayerHistory):
"""
x_n = (x_1 + y_1 + y_2 + ... y_{n-1}) / n
"""
def __init__(self, args, is_encoder):
super(DenseLayerHistory, self).__init__(args, is_encoder)
self.sum = None
self.count = 0
self.individuals = None # store past individual value, used for windows_size > 0
self.integration_type = getattr(args, 'encoder_integration_type', 'avg') if is_encoder else \
getattr(args, 'decoder_integration_type', 'avg')
# windows_size = 1 means the residual connection is not used
self.windows_size = getattr(args, 'encoder_windows_size', -1) if is_encoder else \
getattr(args, 'decoder_windows_size', -1)
if self.windows_size > 0:
assert self.windows_size <= ((args.encoder_layers + 1) if is_encoder else (args.decoder_layers + 1))
self.individuals = queue.Queue(self.windows_size)
def add(self, layer):
self.count += 1
# first layer
if self.sum is None:
self.sum = layer
if self.individuals is not None:
self.individuals.put(layer)
return
# following layer
if self.normalize_before:
layer = self.layer_norms[self.count - 2](layer)
self.sum = self.sum + layer
if self.windows_size != -1 and self.count > self.windows_size:
self.sum = self.sum - self.individuals.get()
if self.individuals is not None:
self.individuals.put(layer)
def pop(self):
assert self.sum is not None
if self.integration_type == 'sum':
ret = self.sum
else:
if self.windows_size == -1:
ret = self.sum / self.count
else:
ret = self.sum / min(self.count, self.windows_size)
if self.count == 1 or self.normalize_before:
return ret
return self.layer_norms[self.count - 2](ret)
def clean(self):
self.sum = None
self.count = 0
if self.individuals is not None:
self.individuals.queue.clear()
class LearnableDenseLayerHistory(BaseLayerHistory):
"""
x_n = w_{n,1} * x_1 + w_{n,2} * y_1 + ... + w_{n,n} * y_{n-1}, with learnable row-normalized weights w
"""
def __init__(self, args, is_encoder):
super(LearnableDenseLayerHistory, self).__init__(args, is_encoder)
self.sum = None
self.count = 0
self.layers = []
self.layer_num = 1 + (args.encoder_layers if is_encoder else args.decoder_layers)
self.weight = nn.Parameter(torch.Tensor(self.layer_num, self.layer_num).fill_(1.0).tril())
self.weight.data = self.weight.data / self.weight.data.sum(1, keepdim=True)
def extra_repr(self):
return 'n_layers={layer_num}, '.format(**self.__dict__)
def add(self, layer):
self.count += 1
# first layer
if self.sum is None:
self.sum = layer
self.layers.append(layer)
return
# following layer
if self.normalize_before:
layer = self.layer_norms[self.count - 2](layer)
self.layers.append(layer)
def pop(self):
assert len(self.layers) > 0
ret = (torch.stack(self.layers, 0) * self.weight[self.count - 1, : self.count].view(-1, 1, 1, 1)).sum(0)
if self.count == 1 or self.normalize_before:
return ret
return self.layer_norms[self.count - 2](ret)
def clean(self):
self.sum = None
self.count = 0
self.layers = []
def get_loss(self):
return (0.5 * (self.weight.sum(1) - 1.0) ** 2).mean()
class LearnableDenseMaskLayerHistory(BaseLayerHistory):
"""
x_n = w_{n,1} * x_1 + ... + w_{n,n} * y_{n-1}, with learnable weights; a weight mask is loaded from encoder_mask.txt / decoder_mask.txt
"""
def __init__(self, args, is_encoder):
super(LearnableDenseMaskLayerHistory, self).__init__(args, is_encoder)
self.sum = None
self.count = 0
self.layers = []
self.layer_num = 1 + (args.encoder_layers if is_encoder else args.decoder_layers)
if is_encoder:
self.weight_mask = np.loadtxt("encoder_mask.txt", dtype=float, delimiter=' ')
else:
self.weight_mask = np.loadtxt("decoder_mask.txt", dtype=float, delimiter=' ')
self.weight = nn.Parameter(torch.Tensor(self.layer_num, self.layer_num).fill_(1.0).tril())
self.weight.data = self.weight.data / self.weight.data.sum(1, keepdim=True)
def add(self, layer):
self.count += 1
# first layer
if self.sum is None:
self.sum = layer
self.layers.append(layer)
return
# following layer
if self.normalize_before:
layer = self.layer_norms[self.count - 2](layer)
self.layers.append(layer)
def pop(self):
assert len(self.layers) > 0
ret = (torch.stack(self.layers, 0) * self.weight[self.count - 1, : self.count].view(-1, 1, 1, 1)).sum(0)
if self.count == 1 or self.normalize_before:
return ret
return self.layer_norms[self.count - 2](ret)
def clean(self):
self.sum = None
self.count = 0
self.layers = []
def get_loss(self):
return (0.5 * (self.weight.sum(1) - 1.0) ** 2).mean()
class LearnableDenseNoNormLayerHistory(BaseLayerHistory):
"""
x_n = w_{n,1} * x_1 + ... + w_{n,n} * y_{n-1}, with learnable weights and no layer normalization
"""
def __init__(self, args, is_encoder):
super(LearnableDenseNoNormLayerHistory, self).__init__(args, is_encoder)
self.sum = None
self.count = 0
self.layer_num = 1 + (args.encoder_layers if is_encoder else args.decoder_layers)
self.weight = nn.Parameter(torch.Tensor(self.layer_num, self.layer_num).fill_(1.0).tril())
self.weight.data = self.weight.data / self.weight.data.sum(1, keepdim=True)
self.layers = []
self.layer_norms = None
def add(self, layer):
self.count += 1
# first layer
if self.sum is None:
self.sum = layer
self.layers.append(layer)
return
self.layers.append(layer)
def pop(self):
assert len(self.layers) > 0
ret = (torch.stack(self.layers, 0) * self.weight[self.count - 1, : self.count].view(-1, 1, 1, 1)).sum(0)
# this variant has no layer norms, so return the weighted sum directly
return ret
def clean(self):
self.sum = None
self.count = 0
self.layers = []
class GruLayerHistory(BaseLayerHistory):
"""
aggregates successive layer outputs with a GRU cell
"""
def __init__(self, args, is_encoder):
super(GruLayerHistory, self).__init__(args, is_encoder)
self.count = 0
self.gru = nn.GRUCell(args.encoder_embed_dim, args.encoder_embed_dim)
self.gru_cells = []
self.layer_norms = nn.ModuleList(LayerNorm(args.encoder_embed_dim) for _ in range(args.decoder_layers + 1))
self.decoder_layers = args.decoder_layers
def compute_gru(self, layer_output):
if len(self.gru_cells) == 0:
self.gru_cells.append(layer_output)
return self.layer_norms[self.count](layer_output)
self.count += 1
prev_h = self.gru_cells[-1]
L, B, H = layer_output.size()
layer_output = torch.reshape(layer_output, (-1, H))
prev_h = torch.reshape(prev_h, (-1, H))
h = self.gru(layer_output, prev_h).view(L, B, H)
self.gru_cells.append(h)
if self.count != self.decoder_layers:
return self.layer_norms[self.count](h)
else:
return None
def clean(self):
self.gru_cells = []
self.count = 0
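The history classes above are driven from an encoder/decoder forward pass: the embedding output and every layer output are pushed with add, and each layer reads its aggregated input back with pop. The encoder integration is not part of this commit; the loop below is only a hedged sketch of the intended usage, assuming args.encoder_history_type (and the related options) are set:

history = CreateLayerHistory(args, is_encoder=True)

def encode(x, layers, history):
    if history is None:                 # no history type configured
        for layer in layers:
            x = layer(x)
        return x
    history.clean()                     # reset state for this forward pass
    history.add(x)                      # slot 0: the embedding output
    for layer in layers:
        x = history.pop()               # aggregated input for this layer
        x = layer(x)
        history.add(x)
    return history.pop()                # aggregated encoder output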
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import math
from typing import Dict, Optional, Tuple
import torch
import torch.nn.functional as F
from fairseq import utils
from fairseq.modules.multihead_attention import MultiheadAttention
from torch import Tensor, nn
from torch.nn import Parameter
class RelativeMultiheadAttention(MultiheadAttention):
"""Multi-headed attention with relative position representations.
See "Self-Attention with Relative Position Representations"
(Shaw et al., 2018) for more details.
"""
def __init__(
self,
embed_dim,
num_heads,
kdim=None,
vdim=None,
dropout=0.0,
bias=True,
add_bias_kv=False,
add_zero_attn=False,
self_attention=False,
encoder_decoder_attention=False,
q_noise=0.0,
qn_block_size=8,
max_relative_length=-1,
k_only=True
):
super().__init__(
embed_dim,
num_heads,
kdim,
vdim,
dropout,
bias,
add_bias_kv,
add_zero_attn,
self_attention,
encoder_decoder_attention,
q_noise,
qn_block_size,
)
self.max_relative_length = max_relative_length
self.k_only = k_only
self.relative_position_keys = Parameter(torch.Tensor(2 * self.max_relative_length + 1, self.head_dim))
if not self.k_only:
self.relative_position_values = Parameter(torch.Tensor(2 * self.max_relative_length + 1, self.head_dim))
nn.init.xavier_uniform_(self.relative_position_keys)
if not self.k_only:
nn.init.xavier_uniform_(self.relative_position_values)
def forward(
self,
query,
key: Optional[Tensor],
value: Optional[Tensor],
key_padding_mask: Optional[Tensor] = None,
incremental_state: Optional[Dict[str, Dict[str, Optional[Tensor]]]] = None,
need_weights: bool = True,
static_kv: bool = False,
attn_mask: Optional[Tensor] = None,
before_softmax: bool = False,
need_head_weights: bool = False,
) -> Tuple[Tensor, Optional[Tensor]]:
"""Input shape: Time x Batch x Channel
Args:
key_padding_mask (ByteTensor, optional): mask to exclude
keys that are pads, of shape `(batch, src_len)`, where
padding elements are indicated by 1s.
need_weights (bool, optional): return the attention weights,
averaged over heads (default: False).
attn_mask (ByteTensor, optional): typically used to
implement causal attention, where the mask prevents the
attention from looking forward in time (default: None).
before_softmax (bool, optional): return the raw attention
weights and values before the attention softmax.
need_head_weights (bool, optional): return the attention
weights for each head. Implies *need_weights*. Default:
return the average attention weights over all heads.
"""
if need_head_weights:
need_weights = True
is_tpu = query.device.type == "xla"
tgt_len, bsz, embed_dim = query.size()
assert embed_dim == self.embed_dim
assert list(query.size()) == [tgt_len, bsz, embed_dim]
if incremental_state is not None:
saved_state = self._get_input_buffer(incremental_state)
if saved_state is not None and "prev_key" in saved_state:
# previous time steps are cached - no need to recompute
# key and value if they are static
if static_kv:
assert self.encoder_decoder_attention and not self.self_attention
key = value = None
else:
saved_state = None
if self.self_attention:
q = self.q_proj(query)
k = self.k_proj(query)
v = self.v_proj(query)
elif self.encoder_decoder_attention:
# encoder-decoder attention
q = self.q_proj(query)
if key is None:
assert value is None
k = v = None
else:
k = self.k_proj(key)
v = self.v_proj(key)
else:
assert key is not None and value is not None
q = self.q_proj(query)
k = self.k_proj(key)
v = self.v_proj(value)
q *= self.scaling
if self.bias_k is not None:
assert self.bias_v is not None
k = torch.cat([k, self.bias_k.repeat(1, bsz, 1)])
v = torch.cat([v, self.bias_v.repeat(1, bsz, 1)])
if attn_mask is not None:
attn_mask = torch.cat(
[attn_mask, attn_mask.new_zeros(attn_mask.size(0), 1)], dim=1
)
if key_padding_mask is not None:
key_padding_mask = torch.cat(
[
key_padding_mask,
key_padding_mask.new_zeros(key_padding_mask.size(0), 1),
],
dim=1,
)
q = (
q.contiguous()
.view(tgt_len, bsz * self.num_heads, self.head_dim)
.transpose(0, 1)
)
if k is not None:
k = (
k.contiguous()
.view(-1, bsz * self.num_heads, self.head_dim)
.transpose(0, 1)
)
if v is not None:
v = (
v.contiguous()
.view(-1, bsz * self.num_heads, self.head_dim)
.transpose(0, 1)
)
if saved_state is not None:
# saved states are stored with shape (bsz, num_heads, seq_len, head_dim)
if "prev_key" in saved_state:
_prev_key = saved_state["prev_key"]
assert _prev_key is not None
prev_key = _prev_key.view(bsz * self.num_heads, -1, self.head_dim)
if static_kv:
k = prev_key
else:
assert k is not None
k = torch.cat([prev_key, k], dim=1)
if "prev_value" in saved_state:
_prev_value = saved_state["prev_value"]
assert _prev_value is not None
prev_value = _prev_value.view(bsz * self.num_heads, -1, self.head_dim)
if static_kv:
v = prev_value
else:
assert v is not None
v = torch.cat([prev_value, v], dim=1)
prev_key_padding_mask: Optional[Tensor] = None
if "prev_key_padding_mask" in saved_state:
prev_key_padding_mask = saved_state["prev_key_padding_mask"]
assert k is not None and v is not None
key_padding_mask = MultiheadAttention._append_prev_key_padding_mask(
key_padding_mask=key_padding_mask,
prev_key_padding_mask=prev_key_padding_mask,
batch_size=bsz,
src_len=k.size(1),
static_kv=static_kv,
)
saved_state["prev_key"] = k.view(bsz, self.num_heads, -1, self.head_dim)
saved_state["prev_value"] = v.view(bsz, self.num_heads, -1, self.head_dim)
saved_state["prev_key_padding_mask"] = key_padding_mask
# In this branch incremental_state is never None
assert incremental_state is not None
incremental_state = self._set_input_buffer(incremental_state, saved_state)
assert k is not None
src_len = k.size(1)
# This is part of a workaround to get around fork/join parallelism
# not supporting Optional types.
if key_padding_mask is not None and key_padding_mask.dim() == 0:
key_padding_mask = None
if key_padding_mask is not None:
assert key_padding_mask.size(0) == bsz
assert key_padding_mask.size(1) == src_len
if self.add_zero_attn:
assert v is not None
src_len += 1
k = torch.cat([k, k.new_zeros((k.size(0), 1) + k.size()[2:])], dim=1)
v = torch.cat([v, v.new_zeros((v.size(0), 1) + v.size()[2:])], dim=1)
if attn_mask is not None:
attn_mask = torch.cat(
[attn_mask, attn_mask.new_zeros(attn_mask.size(0), 1)], dim=1
)
if key_padding_mask is not None:
key_padding_mask = torch.cat(
[
key_padding_mask,
torch.zeros(key_padding_mask.size(0), 1).type_as(
key_padding_mask
),
],
dim=1,
)
relative_positions_matrix = self._generate_relative_positions_matrix(
src_len, self.max_relative_length, incremental_state
)
# keep the embedding lookup on the same device as the projected queries
relative_positions_matrix = relative_positions_matrix.long().to(q.device)
relation_keys = F.embedding(relative_positions_matrix, self.relative_position_keys)
if not self.k_only:
relation_values = F.embedding(relative_positions_matrix, self.relative_position_values)
# the relative-position-aware scores replace the plain q * k^T dot product
attn_weights = self._relative_attention_inner(q, k, relation_keys, transpose=True)
attn_weights = self.apply_sparse_mask(attn_weights, tgt_len, src_len, bsz)
assert list(attn_weights.size()) == [bsz * self.num_heads, tgt_len, src_len]
if attn_mask is not None:
attn_mask = attn_mask.unsqueeze(0)
if self.onnx_trace:
attn_mask = attn_mask.repeat(attn_weights.size(0), 1, 1)
attn_weights += attn_mask
if key_padding_mask is not None:
# don't attend to padding symbols
attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len)
if not is_tpu:
attn_weights = attn_weights.masked_fill(
key_padding_mask.unsqueeze(1).unsqueeze(2).to(torch.bool),
float("-inf"),
)
else:
attn_weights = attn_weights.transpose(0, 2)
attn_weights = attn_weights.masked_fill(key_padding_mask, float("-inf"))
attn_weights = attn_weights.transpose(0, 2)
attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)
if before_softmax:
return attn_weights, v
attn_weights_float = utils.softmax(
attn_weights, dim=-1, onnx_trace=self.onnx_trace
)
attn_weights = attn_weights_float.type_as(attn_weights)
attn_probs = self.dropout_module(attn_weights)
assert v is not None
# key-only mode: relative position embeddings are applied to the keys only
if self.k_only:
attn = torch.bmm(attn_probs, v)
# full formulation: relative position embeddings on the values as well
else:
attn = self._relative_attention_inner(attn_probs, v, relation_values, transpose=False)
assert list(attn.size()) == [bsz * self.num_heads, tgt_len, self.head_dim]
if self.onnx_trace and attn.size(1) == 1:
# when ONNX tracing a single decoder step (sequence length == 1)
# the transpose is a no-op copy before view, thus unnecessary
attn = attn.contiguous().view(tgt_len, bsz, embed_dim)
else:
attn = attn.transpose(0, 1).contiguous().view(tgt_len, bsz, embed_dim)
attn = self.out_proj(attn)
attn_weights: Optional[Tensor] = None
if need_weights:
attn_weights = attn_weights_float.view(
bsz, self.num_heads, tgt_len, src_len
).transpose(1, 0)
if not need_head_weights:
# average attention weights over heads
attn_weights = attn_weights.mean(dim=0)
return attn, attn_weights
def _generate_relative_positions_matrix(self, length, max_relative_length, incremental_state):
if not incremental_state:
# training process
range_vec = torch.arange(length)
range_mat = range_vec.repeat(length, 1)
distance_mat = range_mat - range_mat.transpose(0, 1)
else:
# incremental decoding: a single new query attends to all cached positions
distance_mat = torch.arange(-length + 1, 1).view(1, -1)
distance_mat_clipped = torch.clamp(distance_mat, -max_relative_length, max_relative_length)
# position difference.
final_mat = distance_mat_clipped + max_relative_length
return final_mat
def _relative_attention_inner(self, x, y, z, transpose=True):
"""Relative position-aware dot-product attention inner calculation.
This batches matrix multiply calculations to avoid unnecessary broadcasting.
Args:
x: Tensor with shape [batch_size*heads, length, length or depth].
y: Tensor with shap e [batch_size*heads, length, depth].
z: Tensor with shape [length, length, depth].
transpose: Whether to tranpose inner matrices of y and z. Should be true if
last dimension of x is depth, not length.
Returns:
A Tensor with shape [batch_size*heads, length, length or depth].
wq: this function actually does 'X(Y+Z)', where Z is vector,
but factor above formular as: 'XY + XZ'
"""
batch_size_mul_head = x.size()[0]
length = z.size()[0]
# xy_matmul is [batch_size*heads, length, length or depth]
if transpose:
y = y.transpose(1, 2)
xy_matmul = torch.bmm(x, y)
# x_t is [length, batch_size * heads, length or depth]
x_t = x.transpose(0, 1)
# x_tz_matmul is [length, batch_size * heads, length or depth]
if transpose:
z = z.transpose(1, 2)
x_tz_matmul = torch.bmm(x_t, z).transpose(0, 1).view(batch_size_mul_head, length, -1)
attn = xy_matmul + x_tz_matmul
return attn
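For a concrete view of the indexing, this is what _generate_relative_positions_matrix yields for a length-4 sequence with max_relative_length=2: offsets j - i are clipped to [-2, 2] and shifted into [0, 4] so they can index the relative embedding tables:

import torch

length, max_relative_length = 4, 2
range_vec = torch.arange(length)
range_mat = range_vec.repeat(length, 1)
distance_mat = range_mat - range_mat.transpose(0, 1)  # entry (i, j) = j - i
final_mat = torch.clamp(distance_mat, -max_relative_length, max_relative_length) + max_relative_length
print(final_mat)
# tensor([[2, 3, 4, 4],
#         [1, 2, 3, 4],
#         [0, 1, 2, 3],
#         [0, 0, 1, 2]])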
@@ -8,7 +8,12 @@ from typing import Dict, List, Optional
import torch
import torch.nn as nn
from fairseq import utils
from fairseq.modules import (
LayerNorm,
MultiheadAttention,
RelPositionMultiheadAttention,
RelativeMultiheadAttention
)
from fairseq.modules.fairseq_dropout import FairseqDropout
from fairseq.modules.quant_noise import quant_noise
from torch import Tensor
@@ -82,6 +87,16 @@ class TransformerEncoderLayer(nn.Module):
attn_func = MultiheadAttention
elif self.attn_type == "rel_selfattn":
attn_func = RelPositionMultiheadAttention
elif self.attn_type == "relative" or getattr(args, "max_relative_length", -1) != -1:
return RelativeMultiheadAttention(
embed_dim,
args.encoder_attention_heads,
dropout=args.attention_dropout,
self_attention=True,
q_noise=self.quant_noise,
qn_block_size=self.quant_noise_block_size,
max_relative_length=args.max_relative_length,
)
else:
print("The attention type %s is not supported!" % self.attn_type)
exit(1)
@@ -277,6 +292,16 @@ class TransformerDecoderLayer(nn.Module):
attn_func = MultiheadAttention
elif self.attn_type == "rel_selfattn":
attn_func = RelPositionMultiheadAttention
elif self.attn_type == "relative" or getattr(args, "max_relative_length", -1) != -1:
return RelativeMultiheadAttention(
embed_dim,
args.encoder_attention_heads,
dropout=args.attention_dropout,
self_attention=True,
q_noise=self.quant_noise,
qn_block_size=self.quant_noise_block_size,
max_relative_length=args.max_relative_length,
)
else:
print("The attention type %s is not supported!" % self.attn_type)
exit(1)
...
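Finally, a minimal smoke test of the new module, assuming this checkout is installed so that RelativeMultiheadAttention is importable from fairseq.modules and that the module and its inputs live on the same device; inputs follow the (time, batch, channel) layout described in the forward docstring:

import torch
from fairseq.modules import RelativeMultiheadAttention

attn = RelativeMultiheadAttention(
    embed_dim=8,
    num_heads=2,
    self_attention=True,
    max_relative_length=4,
    k_only=True,
)
x = torch.rand(5, 2, 8)          # (tgt_len, batch, embed_dim)
out, weights = attn(x, x, x)
print(out.shape, weights.shape)  # torch.Size([5, 2, 8]) torch.Size([2, 5, 5])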