Commit f0e3290f by xuchen

add the traditional relative multihead attention

parent a3e0f4c2
......@@ -182,6 +182,13 @@ def s2t_conformer_s(args):
base_architecture(args)
@register_model_architecture("s2t_conformer", "s2t_conformer_s_relative")
def s2t_conformer_s_relative(args):
args.max_relative_length = 20
args.k_only = True
s2t_conformer_s(args)
@register_model_architecture("s2t_conformer", "s2t_conformer_xs")
def s2t_conformer_xs(args):
args.encoder_layers = getattr(args, "encoder_layers", 6)
......
......@@ -357,6 +357,13 @@ def s2t_sate_s(args):
base_architecture(args)
@register_model_architecture("s2t_sate", "s2t_sate_s_relative")
def s2t_sate_s_relative(args):
args.max_relative_length = 20
args.k_only = True
s2t_sate_s(args)
@register_model_architecture("s2t_sate", "s2t_sate_xs")
def s2t_sate_xs(args):
args.encoder_layers = getattr(args, "encoder_layers", 6)
......
......@@ -217,6 +217,10 @@ class S2TTransformerModel(FairseqEncoderDecoderModel):
action="store_true",
help="if True, dont scale embeddings",
)
        parser.add_argument('--max-relative-length', type=int, default=-1,
                            help='maximum relative distance for relative position attention (-1 to disable)')
        parser.add_argument('--k-only', default=False, action='store_true',
                            help='apply relative position embeddings to keys only, not to values')
parser.add_argument(
"--load-pretrained-encoder-from",
type=str,
......@@ -518,6 +522,9 @@ def base_architecture(args):
args.no_scale_embedding = getattr(args, "no_scale_embedding", False)
args.quant_noise_pq = getattr(args, "quant_noise_pq", 0)
args.max_relative_length = getattr(args, 'max_relative_length', -1)
args.k_only = getattr(args, 'k_only', True)
@register_model_architecture("s2t_transformer", "s2t_transformer_s")
def s2t_transformer_s(args):
......@@ -529,6 +536,13 @@ def s2t_transformer_s(args):
base_architecture(args)
@register_model_architecture("s2t_transformer", "s2t_transformer_s_relative")
def s2t_transformer_s_relative(args):
args.max_relative_length = 20
args.k_only = True
s2t_transformer_s(args)
@register_model_architecture("s2t_transformer", "s2t_transformer_xs")
def s2t_transformer_xs(args):
args.encoder_layers = getattr(args, "encoder_layers", 6)
......
......@@ -5,10 +5,11 @@
import math
from typing import Any, Dict, List, Optional, Tuple
import logging
import torch
import torch.nn as nn
from fairseq import utils
from fairseq import checkpoint_utils, utils
from fairseq.distributed import fsdp_wrap
from fairseq.models import (
FairseqEncoder,
......@@ -35,6 +36,8 @@ from torch import Tensor
DEFAULT_MAX_SOURCE_POSITIONS = 1024
DEFAULT_MAX_TARGET_POSITIONS = 1024
logger = logging.getLogger(__name__)
@register_model("transformer")
class TransformerModel(FairseqEncoderDecoderModel):
......@@ -191,6 +194,35 @@ class TransformerModel(FairseqEncoderDecoderModel):
help='block size of quantization noise at training time')
parser.add_argument('--quant-noise-scalar', type=float, metavar='D', default=0,
help='scalar quantization noise and scalar quantization at training time')
        parser.add_argument('--max-relative-length', type=int, default=-1,
                            help='maximum relative distance for relative position attention (-1 to disable)')
        parser.add_argument('--k-only', default=False, action='store_true',
                            help='apply relative position embeddings to keys only, not to values')
# args for loading pre-trained models
parser.add_argument(
"--load-pretrained-encoder-from",
type=str,
metavar="STR",
help="model to take encoder weights from (for initialization)",
)
parser.add_argument(
"--load-pretrained-decoder-from",
type=str,
metavar="STR",
help="model to take decoder weights from (for initialization)",
)
parser.add_argument(
"--encoder-freeze-module",
type=str,
metavar="STR",
help="freeze the module of the encoder",
)
parser.add_argument(
"--decoder-freeze-module",
type=str,
metavar="STR",
help="freeze the module of the decoder",
)
# fmt: on
@classmethod
......@@ -240,7 +272,15 @@ class TransformerModel(FairseqEncoderDecoderModel):
if getattr(args, "offload_activations", False):
args.checkpoint_activations = True # offloading implies checkpointing
encoder = cls.build_encoder(args, src_dict, encoder_embed_tokens)
if getattr(args, "encoder_freeze_module", None):
utils.freeze_parameters(encoder, args.encoder_freeze_module)
logging.info("freeze the encoder module: {}".format(args.encoder_freeze_module))
decoder = cls.build_decoder(args, tgt_dict, decoder_embed_tokens)
if getattr(args, "decoder_freeze_module", None):
utils.freeze_parameters(decoder, args.decoder_freeze_module)
logging.info("freeze the decoder module: {}".format(args.decoder_freeze_module))
if not args.share_all_embeddings:
encoder = fsdp_wrap(encoder, min_num_params=1e8)
decoder = fsdp_wrap(decoder, min_num_params=1e8)
......@@ -260,17 +300,38 @@ class TransformerModel(FairseqEncoderDecoderModel):
@classmethod
def build_encoder(cls, args, src_dict, embed_tokens):
return TransformerEncoder(args, src_dict, embed_tokens)
encoder = TransformerEncoder(args, src_dict, embed_tokens)
if getattr(args, "load_pretrained_encoder_from", None):
logger.info(
f"loaded pretrained encoder from: "
f"{args.load_pretrained_encoder_from}"
)
encoder = checkpoint_utils.load_pretrained_component_from_model(
component=encoder, checkpoint=args.load_pretrained_encoder_from, strict=False
)
return encoder
@classmethod
def build_decoder(cls, args, tgt_dict, embed_tokens):
return TransformerDecoder(
decoder = TransformerDecoder(
args,
tgt_dict,
embed_tokens,
no_encoder_attn=getattr(args, "no_cross_attention", False),
)
if getattr(args, "load_pretrained_decoder_from", None):
logger.info(
f"loaded pretrained decoder from: "
f"{args.load_pretrained_decoder_from}"
)
decoder = checkpoint_utils.load_pretrained_component_from_model(
component=decoder, checkpoint=args.load_pretrained_decoder_from, strict=False
)
return decoder
# TorchScript doesn't support optional arguments with variable length (**kwargs).
# Current workaround is to add union of all arguments in child classes.
def forward(
......@@ -1073,6 +1134,15 @@ def base_architecture(args):
args.quant_noise_pq = getattr(args, "quant_noise_pq", 0)
args.quant_noise_pq_block_size = getattr(args, "quant_noise_pq_block_size", 8)
args.quant_noise_scalar = getattr(args, "quant_noise_scalar", 0)
args.max_relative_length = getattr(args, 'max_relative_length', -1)
args.k_only = getattr(args, 'k_only', True)
@register_model_architecture("transformer", "transformer_relative")
def transformer_rpr(args):
args.max_relative_length = 20
args.k_only = True
base_architecture(args)
@register_model_architecture("transformer", "transformer_iwslt_de_en")
......
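For reference, a minimal sketch (not part of the commit) of what the new transformer_relative architecture resolves to, assuming this fork keeps transformer_rpr as a module-level function in the single-file fairseq.models.transformer:

from argparse import Namespace
from fairseq.models.transformer import transformer_rpr  # assumption: importable from this fork

args = Namespace()
transformer_rpr(args)  # registered above as "transformer_relative"
# transformer_rpr only overrides the two relative-attention options before
# delegating to base_architecture, so an otherwise empty namespace ends up with:
assert args.max_relative_length == 20
assert args.k_only is True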
......@@ -21,6 +21,7 @@ from .grad_multiply import GradMultiply
from .gumbel_vector_quantizer import GumbelVectorQuantizer
from .kmeans_vector_quantizer import KmeansVectorQuantizer
from .layer_drop import LayerDropModuleList
from .layer_history import CreateLayerHistory
from .layer_norm import Fp32LayerNorm, LayerNorm
from .learned_positional_embedding import LearnedPositionalEmbedding
from .lightweight_convolution import LightweightConv, LightweightConv1dTBC
......@@ -28,6 +29,7 @@ from .linearized_convolution import LinearizedConvolution
from .multihead_attention import MultiheadAttention
from .positional_embedding import PositionalEmbedding
from .rel_position_multihead_attention import RelPositionMultiheadAttention
from .relative_multihead_attention import RelativeMultiheadAttention
from .same_pad import SamePad
from .scalar_bias import ScalarBias
from .sinusoidal_positional_embedding import SinusoidalPositionalEmbedding
......@@ -47,6 +49,7 @@ __all__ = [
"ConformerEncoderLayer",
"ConvolutionModule",
"ConvTBC",
"CreateLayerHistory",
"cross_entropy",
"DownsampledMultiHeadAttention",
"DynamicConv1dTBC",
......@@ -69,6 +72,7 @@ __all__ = [
"MultiheadAttention",
"PositionalEmbedding",
"RelPositionMultiheadAttention",
"RelativeMultiheadAttention",
"SamePad",
"ScalarBias",
"SinusoidalPositionalEmbedding",
......
import torch
import torch.nn as nn
from fairseq.models.transformer import LayerNorm
import queue
import numpy as np
def CreateLayerHistory(args, is_encoder):
history_type = args.encoder_history_type if is_encoder else args.decoder_history_type
if history_type is None:
return None
elif history_type == "residual":
return ResidualLayerHistory(args, is_encoder)
elif history_type == "dense":
return DenseLayerHistory(args, is_encoder)
elif history_type == "learnable_dense":
return LearnableDenseLayerHistory(args, is_encoder)
elif history_type == "learnable_dense_mask":
return LearnableDenseMaskLayerHistory(args, is_encoder)
elif history_type == "learnable_dense_nonorm":
return LearnableDenseNoNormLayerHistory(args, is_encoder)
elif history_type == "gru":
return GruLayerHistory(args, is_encoder)
    else:
        raise ValueError("unknown history type: {}".format(history_type))
class BaseLayerHistory(nn.Module):
def __init__(self, args, is_encoder):
super(BaseLayerHistory, self).__init__()
self.is_encoder = is_encoder
self.normalize_before = args.encoder_normalize_before if is_encoder else args.decoder_normalize_before
        # the first layer (i.e., the embedding output) does not have layer normalization
layers = args.encoder_layers if is_encoder else args.decoder_layers
dim = args.encoder_embed_dim if is_encoder else args.decoder_embed_dim
self.layer_norms = nn.ModuleList(LayerNorm(dim) for _ in range(layers))
def add(self, layer):
        raise NotImplementedError
def pop(self):
        raise NotImplementedError
def clean(self):
        raise NotImplementedError
class ResidualLayerHistory(BaseLayerHistory):
"""
x_n = x_{n-1} + y_{n-1}
"""
def __init__(self, args, is_encoder):
super(ResidualLayerHistory, self).__init__(args, is_encoder)
self.count = 0
self.x = None
self.y = None
def add(self, layer):
if self.x is None:
self.x = layer
self.count += 1
return
self.count += 1
if self.normalize_before:
self.y = self.layer_norms[self.count - 2](layer)
else:
self.y = layer
def pop(self):
assert self.x is not None
if self.y is None:
return self.x
ret = self.x + self.y
if not self.normalize_before:
ret = self.layer_norms[self.count - 2](ret)
self.x = ret
return ret
def clean(self):
self.x = None
self.y = None
self.count = 0
class DenseLayerHistory(BaseLayerHistory):
"""
x_n = (x_1 + y_1 + y_2 + ... y_{n-1}) / n
"""
def __init__(self, args, is_encoder):
super(DenseLayerHistory, self).__init__(args, is_encoder)
self.sum = None
self.count = 0
self.individuals = None # store past individual value, used for windows_size > 0
self.integration_type = getattr(args, 'encoder_integration_type', 'avg') if is_encoder else \
getattr(args, 'decoder_integration_type', 'avg')
        # windows_size = 1 means the residual connection is not used
self.windows_size = getattr(args, 'encoder_windows_size', -1) if is_encoder else \
getattr(args, 'decoder_windows_size', -1)
if self.windows_size > 0:
            assert self.windows_size <= ((args.encoder_layers + 1) if is_encoder else (args.decoder_layers + 1))
self.individuals = queue.Queue(self.windows_size)
def add(self, layer):
self.count += 1
# first layer
if self.sum is None:
self.sum = layer
if self.individuals is not None:
self.individuals.put(layer)
return
# following layer
if self.normalize_before:
layer = self.layer_norms[self.count - 2](layer)
self.sum = self.sum + layer
if self.windows_size != -1 and self.count > self.windows_size:
self.sum = self.sum - self.individuals.get()
if self.individuals is not None:
self.individuals.put(layer)
def pop(self):
assert self.sum is not None
if self.integration_type == 'sum':
ret = self.sum
else:
if self.windows_size == -1:
ret = self.sum / self.count
else:
ret = self.sum / min(self.count, self.windows_size)
if self.count == 1 or self.normalize_before:
return ret
return self.layer_norms[self.count - 2](ret)
def clean(self):
self.sum = None
self.count = 0
if self.individuals is not None:
self.individuals.queue.clear()
class LearnableDenseLayerHistory(BaseLayerHistory):
"""
x_n = (x_1 + y_1 + y_2 + ... y_{n-1}) / n
"""
def __init__(self, args, is_encoder):
super(LearnableDenseLayerHistory, self).__init__(args, is_encoder)
self.sum = None
self.count = 0
self.layer_num = 1 + (args.encoder_layers if is_encoder else args.decoder_layers)
self.weight = nn.Parameter(torch.Tensor(self.layer_num, self.layer_num).fill_(1.0).tril())
        self.weight.data = self.weight.data / self.weight.data.sum(1, keepdim=True)
        self.layers = []  # populated in add(), reset in clean()
def extra_repr(self):
return 'n_layers={layer_num}, '.format(**self.__dict__)
def add(self, layer):
self.count += 1
# first layer
if self.sum is None:
self.sum = layer
self.layers.append(layer)
return
# following layer
if self.normalize_before:
layer = self.layer_norms[self.count - 2](layer)
self.layers.append(layer)
def pop(self):
assert len(self.layers) > 0
ret = (torch.stack(self.layers, 0) * self.weight[self.count - 1, : self.count].view(-1, 1, 1, 1)).sum(0)
if self.count == 1 or self.normalize_before:
return ret
return self.layer_norms[self.count - 2](ret)
def clean(self):
self.sum = None
self.count = 0
self.layers = []
def get_loss(self):
return (0.5 * (self.weight.sum(1) - 1.0) ** 2).mean()
class LearnableDenseMaskLayerHistory(BaseLayerHistory):
"""
x_n = (x_1 + y_1 + y_2 + ... y_{n-1}) / n
"""
def __init__(self, args, is_encoder):
super(LearnableDenseMaskLayerHistory, self).__init__(args, is_encoder)
self.sum = None
self.count = 0
self.layer_num = 1 + (args.encoder_layers if is_encoder else args.decoder_layers)
if is_encoder:
self.weight_mask = np.loadtxt("encoder_mask.txt", dtype=float, delimiter=' ')
else:
self.weight_mask = np.loadtxt("decoder_mask.txt", dtype=float, delimiter=' ')
self.weight = nn.Parameter(torch.Tensor(self.layer_num, self.layer_num).fill_(1.0).tril())
        self.weight.data = self.weight.data / self.weight.data.sum(1, keepdim=True)
        self.layers = []  # populated in add(), reset in clean()
def add(self, layer):
self.count += 1
# first layer
if self.sum is None:
self.sum = layer
self.layers.append(layer)
return
# following layer
if self.normalize_before:
layer = self.layer_norms[self.count - 2](layer)
self.layers.append(layer)
def pop(self):
assert len(self.layers) > 0
ret = (torch.stack(self.layers, 0) * self.weight[self.count - 1, : self.count].view(-1, 1, 1, 1)).sum(0)
if self.count == 1 or self.normalize_before:
return ret
return self.layer_norms[self.count - 2](ret)
def clean(self):
self.sum = None
self.count = 0
self.layers = []
def get_loss(self):
return (0.5 * (self.weight.sum(1) - 1.0) ** 2).mean()
class LearnableDenseNoNormLayerHistory(BaseLayerHistory):
"""
x_n = (x_1 + y_1 + y_2 + ... y_{n-1}) / n
"""
def __init__(self, args, is_encoder):
super(LearnableDenseNoNormLayerHistory, self).__init__(args, is_encoder)
self.sum = None
self.count = 0
self.layer_num = 1 + (args.encoder_layers if is_encoder else args.decoder_layers)
self.weight = nn.Parameter(torch.Tensor(self.layer_num, self.layer_num).fill_(1.0).tril())
self.weight.data = self.weight.data / self.weight.data.sum(1, keepdim=True)
self.layers = []
self.layer_norms = None
def add(self, layer):
self.count += 1
# first layer
if self.sum is None:
self.sum = layer
self.layers.append(layer)
return
self.layers.append(layer)
def pop(self):
assert len(self.layers) > 0
ret = (torch.stack(self.layers, 0) * self.weight[self.count - 1, : self.count].view(-1, 1, 1, 1)).sum(0)
if self.count == 1 or self.normalize_before:
return ret
return self.layer_norms[self.count - 2](ret)
def clean(self):
self.sum = None
self.count = 0
self.layers = []
class GruLayerHistory(BaseLayerHistory):
"""
x_n = (x_1 + y_1 + y_2 + ... y_{n-1}) / n
"""
def __init__(self, args, is_encoder):
super(GruLayerHistory, self).__init__(args, is_encoder)
self.count = 0
self.gru = nn.GRUCell(args.encoder_embed_dim, args.encoder_embed_dim)
self.gru_cells = []
self.layer_norms = nn.ModuleList(LayerNorm(args.encoder_embed_dim) for _ in range(args.decoder_layers + 1))
self.decoder_layers = args.decoder_layers
def compute_gru(self, layer_output):
if len(self.gru_cells) == 0:
self.gru_cells.append(layer_output)
return self.layer_norms[self.count](layer_output)
self.count += 1
prev_h = self.gru_cells[-1]
L, B, H = layer_output.size()
layer_output = torch.reshape(layer_output, (-1, H))
prev_h = torch.reshape(prev_h, (-1, H))
h = self.gru(layer_output, prev_h).view(L, B, H)
self.gru_cells.append(h)
if self.count != self.decoder_layers:
return self.layer_norms[self.count](h)
else:
return None
def clean(self):
self.gru_cells = []
self.count = 0
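As a usage sketch (not part of the commit), the histories are driven through a clean/add/pop cycle by the encoder; something along these lines exercises the learnable dense variant with a hypothetical minimal config:

from argparse import Namespace
import torch
from fairseq.modules import CreateLayerHistory  # exported via fairseq/modules/__init__.py above

# Hypothetical config: only the fields the history classes actually read.
args = Namespace(
    encoder_history_type="learnable_dense",
    decoder_history_type=None,
    encoder_normalize_before=True,
    decoder_normalize_before=True,
    encoder_layers=6,
    decoder_layers=6,
    encoder_embed_dim=8,
    decoder_embed_dim=8,
)

history = CreateLayerHistory(args, is_encoder=True)
history.clean()                                # reset before every forward pass
x = torch.randn(10, 2, 8)                      # (seq_len, batch, embed_dim) embedding output
history.add(x)
for _ in range(args.encoder_layers):
    layer_in = history.pop()                   # fused representation fed to the next layer
    layer_out = layer_in + torch.randn_like(layer_in)  # stand-in for a real encoder layer
    history.add(layer_out)
encoder_out = history.pop()                    # learned weighted combination of all outputs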
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import math
from typing import Dict, Optional, Tuple
import torch
import torch.nn.functional as F
from fairseq import utils
from fairseq.modules.multihead_attention import MultiheadAttention
from torch import Tensor, nn
from torch.nn import Parameter
class RelativeMultiheadAttention(MultiheadAttention):
    """Multi-headed attention with relative position representations.
    See "Self-Attention with Relative Position Representations" (Shaw et al., 2018)
    for more details.
    """
def __init__(
self,
embed_dim,
num_heads,
kdim=None,
vdim=None,
dropout=0.0,
bias=True,
add_bias_kv=False,
add_zero_attn=False,
self_attention=False,
encoder_decoder_attention=False,
q_noise=0.0,
qn_block_size=8,
max_relative_length=-1,
k_only=True
):
super().__init__(
embed_dim,
num_heads,
kdim,
vdim,
dropout,
bias,
add_bias_kv,
add_zero_attn,
self_attention,
encoder_decoder_attention,
q_noise,
qn_block_size,
)
self.max_relative_length = max_relative_length
self.k_only = k_only
self.relative_position_keys = Parameter(torch.Tensor(2 * self.max_relative_length + 1, self.head_dim))
if not self.k_only:
self.relative_position_values = Parameter(torch.Tensor(2 * self.max_relative_length + 1, self.head_dim))
nn.init.xavier_uniform_(self.relative_position_keys)
if not self.k_only:
nn.init.xavier_uniform_(self.relative_position_values)
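    # Note (illustrative): with, e.g., max_relative_length = 20 and head_dim = 64,
    # each relative position table above has shape (2 * 20 + 1, 64) = (41, 64);
    # query-key offsets are clipped to [-20, 20] before indexing into it.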
def forward(
self,
query,
key: Optional[Tensor],
value: Optional[Tensor],
key_padding_mask: Optional[Tensor] = None,
incremental_state: Optional[Dict[str, Dict[str, Optional[Tensor]]]] = None,
need_weights: bool = True,
static_kv: bool = False,
attn_mask: Optional[Tensor] = None,
before_softmax: bool = False,
need_head_weights: bool = False,
) -> Tuple[Tensor, Optional[Tensor]]:
"""Input shape: Time x Batch x Channel
Args:
key_padding_mask (ByteTensor, optional): mask to exclude
keys that are pads, of shape `(batch, src_len)`, where
padding elements are indicated by 1s.
need_weights (bool, optional): return the attention weights,
averaged over heads (default: False).
attn_mask (ByteTensor, optional): typically used to
implement causal attention, where the mask prevents the
attention from looking forward in time (default: None).
before_softmax (bool, optional): return the raw attention
weights and values before the attention softmax.
need_head_weights (bool, optional): return the attention
weights for each head. Implies *need_weights*. Default:
return the average attention weights over all heads.
"""
if need_head_weights:
need_weights = True
is_tpu = query.device.type == "xla"
tgt_len, bsz, embed_dim = query.size()
assert embed_dim == self.embed_dim
assert list(query.size()) == [tgt_len, bsz, embed_dim]
if incremental_state is not None:
saved_state = self._get_input_buffer(incremental_state)
if saved_state is not None and "prev_key" in saved_state:
# previous time steps are cached - no need to recompute
# key and value if they are static
if static_kv:
assert self.encoder_decoder_attention and not self.self_attention
key = value = None
else:
saved_state = None
if self.self_attention:
q = self.q_proj(query)
k = self.k_proj(query)
v = self.v_proj(query)
elif self.encoder_decoder_attention:
# encoder-decoder attention
q = self.q_proj(query)
if key is None:
assert value is None
k = v = None
else:
k = self.k_proj(key)
v = self.v_proj(key)
else:
assert key is not None and value is not None
q = self.q_proj(query)
k = self.k_proj(key)
v = self.v_proj(value)
q *= self.scaling
if self.bias_k is not None:
assert self.bias_v is not None
k = torch.cat([k, self.bias_k.repeat(1, bsz, 1)])
v = torch.cat([v, self.bias_v.repeat(1, bsz, 1)])
if attn_mask is not None:
attn_mask = torch.cat(
[attn_mask, attn_mask.new_zeros(attn_mask.size(0), 1)], dim=1
)
if key_padding_mask is not None:
key_padding_mask = torch.cat(
[
key_padding_mask,
key_padding_mask.new_zeros(key_padding_mask.size(0), 1),
],
dim=1,
)
q = (
q.contiguous()
.view(tgt_len, bsz * self.num_heads, self.head_dim)
.transpose(0, 1)
)
if k is not None:
k = (
k.contiguous()
.view(-1, bsz * self.num_heads, self.head_dim)
.transpose(0, 1)
)
if v is not None:
v = (
v.contiguous()
.view(-1, bsz * self.num_heads, self.head_dim)
.transpose(0, 1)
)
if saved_state is not None:
# saved states are stored with shape (bsz, num_heads, seq_len, head_dim)
if "prev_key" in saved_state:
_prev_key = saved_state["prev_key"]
assert _prev_key is not None
prev_key = _prev_key.view(bsz * self.num_heads, -1, self.head_dim)
if static_kv:
k = prev_key
else:
assert k is not None
k = torch.cat([prev_key, k], dim=1)
if "prev_value" in saved_state:
_prev_value = saved_state["prev_value"]
assert _prev_value is not None
prev_value = _prev_value.view(bsz * self.num_heads, -1, self.head_dim)
if static_kv:
v = prev_value
else:
assert v is not None
v = torch.cat([prev_value, v], dim=1)
prev_key_padding_mask: Optional[Tensor] = None
if "prev_key_padding_mask" in saved_state:
prev_key_padding_mask = saved_state["prev_key_padding_mask"]
assert k is not None and v is not None
key_padding_mask = MultiheadAttention._append_prev_key_padding_mask(
key_padding_mask=key_padding_mask,
prev_key_padding_mask=prev_key_padding_mask,
batch_size=bsz,
src_len=k.size(1),
static_kv=static_kv,
)
saved_state["prev_key"] = k.view(bsz, self.num_heads, -1, self.head_dim)
saved_state["prev_value"] = v.view(bsz, self.num_heads, -1, self.head_dim)
saved_state["prev_key_padding_mask"] = key_padding_mask
# In this branch incremental_state is never None
assert incremental_state is not None
incremental_state = self._set_input_buffer(incremental_state, saved_state)
assert k is not None
src_len = k.size(1)
# This is part of a workaround to get around fork/join parallelism
# not supporting Optional types.
if key_padding_mask is not None and key_padding_mask.dim() == 0:
key_padding_mask = None
if key_padding_mask is not None:
assert key_padding_mask.size(0) == bsz
assert key_padding_mask.size(1) == src_len
if self.add_zero_attn:
assert v is not None
src_len += 1
k = torch.cat([k, k.new_zeros((k.size(0), 1) + k.size()[2:])], dim=1)
v = torch.cat([v, v.new_zeros((v.size(0), 1) + v.size()[2:])], dim=1)
if attn_mask is not None:
attn_mask = torch.cat(
[attn_mask, attn_mask.new_zeros(attn_mask.size(0), 1)], dim=1
)
if key_padding_mask is not None:
key_padding_mask = torch.cat(
[
key_padding_mask,
torch.zeros(key_padding_mask.size(0), 1).type_as(
key_padding_mask
),
],
dim=1,
)
relative_positions_matrix = self._generate_relative_positions_matrix(
src_len, self.max_relative_length, incremental_state
)
        relation_keys = F.embedding(
            relative_positions_matrix.long().to(k.device), self.relative_position_keys
        )
        if not self.k_only:
            relation_values = F.embedding(
                relative_positions_matrix.long().to(k.device), self.relative_position_values
            )
        attn_weights = self._relative_attention_inner(q, k, relation_keys, transpose=True)
        attn_weights = self.apply_sparse_mask(attn_weights, tgt_len, src_len, bsz)
        assert list(attn_weights.size()) == [bsz * self.num_heads, tgt_len, src_len]
if attn_mask is not None:
attn_mask = attn_mask.unsqueeze(0)
if self.onnx_trace:
attn_mask = attn_mask.repeat(attn_weights.size(0), 1, 1)
attn_weights += attn_mask
if key_padding_mask is not None:
# don't attend to padding symbols
attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len)
if not is_tpu:
attn_weights = attn_weights.masked_fill(
key_padding_mask.unsqueeze(1).unsqueeze(2).to(torch.bool),
float("-inf"),
)
else:
attn_weights = attn_weights.transpose(0, 2)
attn_weights = attn_weights.masked_fill(key_padding_mask, float("-inf"))
attn_weights = attn_weights.transpose(0, 2)
attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)
if before_softmax:
return attn_weights, v
attn_weights_float = utils.softmax(
attn_weights, dim=-1, onnx_trace=self.onnx_trace
)
attn_weights = attn_weights_float.type_as(attn_weights)
attn_probs = self.dropout_module(attn_weights)
assert v is not None
# key only mode
if self.k_only:
attn = torch.bmm(attn_probs, v)
# original implementation
else:
attn = self._relative_attention_inner(attn_probs, v, relation_values, transpose=False)
# attn = torch.bmm(attn_probs, v)
assert list(attn.size()) == [bsz * self.num_heads, tgt_len, self.head_dim]
if self.onnx_trace and attn.size(1) == 1:
# when ONNX tracing a single decoder step (sequence length == 1)
# the transpose is a no-op copy before view, thus unnecessary
attn = attn.contiguous().view(tgt_len, bsz, embed_dim)
else:
attn = attn.transpose(0, 1).contiguous().view(tgt_len, bsz, embed_dim)
attn = self.out_proj(attn)
attn_weights: Optional[Tensor] = None
if need_weights:
attn_weights = attn_weights_float.view(
bsz, self.num_heads, tgt_len, src_len
).transpose(1, 0)
if not need_head_weights:
# average attention weights over heads
attn_weights = attn_weights.mean(dim=0)
return attn, attn_weights
def _generate_relative_positions_matrix(self, length, max_relative_length, incremental_state):
if not incremental_state:
# training process
range_vec = torch.arange(length)
range_mat = range_vec.repeat(length, 1)
distance_mat = range_mat - range_mat.transpose(0, 1)
else:
            distance_mat = torch.arange(-length + 1, 1).view(1, -1)
distance_mat_clipped = torch.clamp(distance_mat, -max_relative_length, max_relative_length)
# position difference.
final_mat = distance_mat_clipped + max_relative_length
return final_mat
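    # Worked example (illustrative, not part of the commit): with length = 4 and
    # max_relative_length = 2 in the training branch, entry (i, j) = clip(j - i, -2, 2) + 2:
    #
    #   [[2, 3, 4, 4],
    #    [1, 2, 3, 4],
    #    [0, 1, 2, 3],
    #    [0, 0, 1, 2]]
    #
    # Each entry indexes one of the 2 * max_relative_length + 1 learned relative
    # position embeddings via F.embedding in forward().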
def _relative_attention_inner(self, x, y, z, transpose=True):
"""Relative position-aware dot-product attention inner calculation.
This batches matrix multiply calculations to avoid unnecessary broadcasting.
        Args:
            x: Tensor with shape [batch_size*heads, length, length or depth].
            y: Tensor with shape [batch_size*heads, length, depth].
            z: Tensor with shape [length, length, depth].
            transpose: Whether to transpose the inner matrices of y and z. Should be
                true if the last dimension of x is depth, not length.
        Returns:
            A Tensor with shape [batch_size*heads, length, length or depth].
        Note: this function actually computes 'X(Y+Z)', where Z holds one relative
        vector per position pair, but factors the formula as 'XY + XZ' to avoid
        unnecessary broadcasting.
        """
batch_size_mul_head = x.size()[0]
length = z.size()[0]
# print(batch_size_mul_head, length)
# xy_matmul is [batch_size*heads, length, length or depth]
if transpose:
y = y.transpose(1, 2)
xy_matmul = torch.bmm(x, y)
# x_t is [length, batch_size * heads, length or depth]
x_t = x.transpose(0, 1)
# x_tz_matmul is [length, batch_size * heads, length or depth]
if transpose:
z = z.transpose(1, 2)
x_tz_matmul = torch.bmm(x_t, z).transpose(0, 1).view(batch_size_mul_head, length, -1)
attn = xy_matmul + x_tz_matmul
return attn
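For intuition, a small self-contained check (not part of the commit) that the XY + XZ factorization above matches the naive per-query computation x_i . (y + z_i)^T; the tensor sizes are arbitrary:

import torch

B, L, D = 6, 5, 4                        # batch_size * heads, length, depth (arbitrary)
x = torch.randn(B, L, D)                 # queries (the transpose=True case)
y = torch.randn(B, L, D)                 # keys
z = torch.randn(L, L, D)                 # z[i, j] = relative embedding for key j given query i

# factored path, mirroring _relative_attention_inner(..., transpose=True)
xy = torch.bmm(x, y.transpose(1, 2))                                      # (B, L, L)
x_tz = torch.bmm(x.transpose(0, 1), z.transpose(1, 2)).transpose(0, 1)    # (B, L, L)
factored = xy + x_tz

# naive reference: score[b, i, j] = x[b, i] . (y[b, j] + z[i, j])
naive = torch.stack(
    [torch.bmm(x[:, i:i + 1], (y + z[i]).transpose(1, 2)).squeeze(1) for i in range(L)],
    dim=1,
)
assert torch.allclose(factored, naive, atol=1e-5)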
......@@ -8,7 +8,12 @@ from typing import Dict, List, Optional
import torch
import torch.nn as nn
from fairseq import utils
from fairseq.modules import LayerNorm, MultiheadAttention, RelPositionMultiheadAttention
from fairseq.modules import (
LayerNorm,
MultiheadAttention,
RelPositionMultiheadAttention,
RelativeMultiheadAttention
)
from fairseq.modules.fairseq_dropout import FairseqDropout
from fairseq.modules.quant_noise import quant_noise
from torch import Tensor
......@@ -82,6 +87,16 @@ class TransformerEncoderLayer(nn.Module):
attn_func = MultiheadAttention
elif self.attn_type == "rel_selfattn":
attn_func = RelPositionMultiheadAttention
elif self.attn_type == "relative" or getattr(args, "max_relative_length", -1) != -1:
return RelativeMultiheadAttention(
embed_dim,
args.encoder_attention_heads,
dropout=args.attention_dropout,
self_attention=True,
q_noise=self.quant_noise,
qn_block_size=self.quant_noise_block_size,
max_relative_length=args.max_relative_length,
)
else:
print("The attention type %s is not supported!" % self.attn_type)
exit(1)
......@@ -277,6 +292,16 @@ class TransformerDecoderLayer(nn.Module):
attn_func = MultiheadAttention
elif self.attn_type == "rel_selfattn":
attn_func = RelPositionMultiheadAttention
elif self.attn_type == "relative" or getattr(args, "max_relative_length", -1) != -1:
return RelativeMultiheadAttention(
embed_dim,
                args.decoder_attention_heads,
dropout=args.attention_dropout,
self_attention=True,
q_noise=self.quant_noise,
qn_block_size=self.quant_noise_block_size,
max_relative_length=args.max_relative_length,
)
else:
print("The attention type %s is not supported!" % self.attn_type)
exit(1)
......