Commit 0d041293 by libei

add relative_transformer model to support RPR (relative position representations) fast decoding
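For reference, an architecture registered this way would typically be selected at training time roughly as follows (an illustrative command, not part of this commit; every flag other than --arch and --max-relative-length is a standard fairseq training option):

python train.py data-bin/wmt_en_de --arch relative_transformer_wmt_en_de --max-relative-length 16 --optimizer adam --lr 0.0007 --dropout 0.1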

parent d2cb256f
# Copyright (c) 2017-present, Facebook, Inc.
# All rights reserved.
#
# This source code is licensed under the license found in the LICENSE file in
# the root directory of this source tree. An additional grant of patent rights
# can be found in the PATENTS file in the same directory.
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
from fairseq.modules import (
LearnedPositionalEmbedding, MultiheadAttention,
SinusoidalPositionalEmbedding,
)
from fairseq.modules.relative_multihead_attention import RelativeMultiheadAttention
from . import (
FairseqIncrementalDecoder, FairseqEncoder, FairseqModel,
register_model, register_model_architecture,
)
@register_model('relative_transformer')
class RelativeTransformerModel(FairseqModel):
def __init__(self, encoder, decoder):
super().__init__(encoder, decoder)
@staticmethod
def add_args(parser):
"""Add model-specific arguments to the parser."""
parser.add_argument('--dropout', type=float, metavar='D',
help='dropout probability')
parser.add_argument('--attention-dropout', type=float, metavar='D',
help='dropout probability for attention weights')
parser.add_argument('--relu-dropout', type=float, metavar='D',
help='dropout probability after ReLU in FFN')
parser.add_argument('--encoder-embed-dim', type=int, metavar='N',
help='encoder embedding dimension')
parser.add_argument('--encoder-ffn-embed-dim', type=int, metavar='N',
help='encoder embedding dimension for FFN')
parser.add_argument('--encoder-layers', type=int, metavar='N',
help='num encoder layers')
parser.add_argument('--encoder-attention-heads', type=int, metavar='N',
help='num encoder attention heads')
parser.add_argument('--encoder-normalize-before', default=False, action='store_true',
help='apply layernorm before each encoder block')
parser.add_argument('--encoder-learned-pos', default=False, action='store_true',
help='use learned positional embeddings in the encoder')
parser.add_argument('--decoder-embed-dim', type=int, metavar='N',
help='decoder embedding dimension')
parser.add_argument('--decoder-ffn-embed-dim', type=int, metavar='N',
help='decoder embedding dimension for FFN')
parser.add_argument('--decoder-layers', type=int, metavar='N',
help='num decoder layers')
parser.add_argument('--decoder-attention-heads', type=int, metavar='N',
help='num decoder attention heads')
parser.add_argument('--decoder-learned-pos', default=False, action='store_true',
help='use learned positional embeddings in the decoder')
parser.add_argument('--decoder-normalize-before', default=False, action='store_true',
help='apply layernorm before each decoder block')
parser.add_argument('--share-decoder-input-output-embed', default=False, action='store_true',
help='share decoder input and output embeddings')
parser.add_argument('--share-all-embeddings', default=False, action='store_true',
help='share encoder, decoder and output embeddings'
' (requires shared dictionary and embed dim)')
parser.add_argument('--max-relative-length', type=int, default=16, metavar='N',
help='maximum clipping distance for relative position representations')
@classmethod
def build_model(cls, args, task):
"""Build a new model instance."""
src_dict, tgt_dict = task.source_dictionary, task.target_dictionary
def build_embedding(dictionary, embed_dim):
num_embeddings = len(dictionary)
padding_idx = dictionary.pad()
return Embedding(num_embeddings, embed_dim, padding_idx)
if args.share_all_embeddings:
if src_dict != tgt_dict:
raise RuntimeError('--share-all-embeddings requires a joined dictionary')
if args.encoder_embed_dim != args.decoder_embed_dim:
raise RuntimeError(
'--share-all-embeddings requires --encoder-embed-dim to match --decoder-embed-dim')
encoder_embed_tokens = build_embedding(src_dict, args.encoder_embed_dim)
decoder_embed_tokens = encoder_embed_tokens
args.share_decoder_input_output_embed = True
else:
encoder_embed_tokens = build_embedding(src_dict, args.encoder_embed_dim)
decoder_embed_tokens = build_embedding(tgt_dict, args.decoder_embed_dim)
encoder = RelativeTransformerEncoder(args, src_dict, encoder_embed_tokens)
decoder = RelativeTransformerDecoder(args, tgt_dict, decoder_embed_tokens)
return RelativeTransformerModel(encoder, decoder)
class RelativeTransformerEncoder(FairseqEncoder):
"""Transformer encoder."""
def __init__(self, args, dictionary, embed_tokens, left_pad=True):
super().__init__(dictionary)
self.dropout = args.dropout
embed_dim = embed_tokens.embedding_dim
self.padding_idx = embed_tokens.padding_idx
self.embed_tokens = embed_tokens
self.embed_scale = math.sqrt(embed_dim)
self.embed_positions = PositionalEmbedding(
1024, embed_dim, self.padding_idx,
left_pad=left_pad,
learned=args.encoder_learned_pos,
)
self.layers = nn.ModuleList([])
self.layers.extend([
RelativeTransformerEncoderLayer(args)
for i in range(args.encoder_layers)
])
self.normalize = args.encoder_normalize_before
if self.normalize:
self.layer_norm = LayerNorm(embed_dim)
def forward(self, src_tokens, src_lengths):
# embed tokens and positions
emb_token = self.embed_tokens(src_tokens)
x = self.embed_scale * emb_token
pos_emb = self.embed_positions(src_tokens)
x += pos_emb
x = F.dropout(x, p=self.dropout, training=self.training)
#print("encoder_input:",x)
# B x T x C -> T x B x C
x = x.transpose(0, 1)
# compute padding mask
encoder_padding_mask = src_tokens.eq(self.padding_idx)
if not encoder_padding_mask.any():
encoder_padding_mask = None
# encoder layers
for layer in self.layers:
x = layer(x, encoder_padding_mask)
if self.normalize:
x = self.layer_norm(x)
# print('enc_output:', x.size())
#print('enc_output:', x)
return {
'encoder_out': x, # T x B x C
'encoder_padding_mask': encoder_padding_mask, # B x T
}
def max_positions(self):
"""Maximum input length supported by the encoder."""
return self.embed_positions.max_positions()
def upgrade_state_dict(self, state_dict):
if isinstance(self.embed_positions, SinusoidalPositionalEmbedding):
if 'encoder.embed_positions.weights' in state_dict:
del state_dict['encoder.embed_positions.weights']
if 'encoder.embed_positions._float_tensor' not in state_dict:
state_dict['encoder.embed_positions._float_tensor'] = torch.FloatTensor()
return state_dict
class RelativeTransformerDecoder(FairseqIncrementalDecoder):
"""Transformer decoder."""
def __init__(self, args, dictionary, embed_tokens, left_pad=False, final_norm=True):
super().__init__(dictionary)
self.dropout = args.dropout
self.share_input_output_embed = args.share_decoder_input_output_embed
embed_dim = embed_tokens.embedding_dim
padding_idx = embed_tokens.padding_idx
self.embed_tokens = embed_tokens
self.embed_scale = math.sqrt(embed_dim)
self.embed_positions = PositionalEmbedding(
1024, embed_dim, padding_idx,
left_pad=left_pad,
learned=args.decoder_learned_pos,
)
self.layers = nn.ModuleList([])
self.layers.extend([
RelativeTransformerDecoderLayer(args)
for i in range(args.decoder_layers)
])
self.normalize = args.decoder_normalize_before and final_norm
if self.normalize:
self.layer_norm = LayerNorm(embed_dim)
if not self.share_input_output_embed:
self.embed_out = nn.Parameter(torch.Tensor(len(dictionary), embed_dim))
nn.init.normal_(self.embed_out, mean=0, std=embed_dim ** -0.5)
def forward(self, prev_output_tokens, encoder_out, incremental_state=None):
# embed positions
positions = self.embed_positions(
prev_output_tokens,
incremental_state=incremental_state,
)
# print('raw position:',positions.size())
if incremental_state is not None:
is_first_step = prev_output_tokens.size(1) == 1
prev_output_tokens = prev_output_tokens[:, -1:]
positions = positions[:, -1:]
# embed tokens and positions
x = self.embed_scale * self.embed_tokens(prev_output_tokens)
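# at the first incremental step there is no real previous token, so the token
# embedding of the start symbol is discarded and only the positional embedding is
# fed, mirroring the x[:, 0, :].fill_(0) used in the non-incremental branch below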
if is_first_step:
x = positions.expand_as(x)
else:
x += positions
else:
# embed tokens and positions
x = self.embed_scale * self.embed_tokens(prev_output_tokens)
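# zero out the embedding of the first target token so the initial decoder input is
# its positional embedding only, consistent with the first incremental step above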
x[:, 0, :].fill_(0)
x += positions
# print('dec-final-emb:', x.size())
x = F.dropout(x, p=self.dropout, training=self.training)
# B x T x C -> T x B x C
x = x.transpose(0, 1)
# decoder layers
for layer in self.layers:
x, attn = layer(
x,
encoder_out['encoder_out'],
encoder_out['encoder_padding_mask'],
incremental_state,
)
if self.normalize:
x = self.layer_norm(x)
# T x B x C -> B x T x C
x = x.transpose(0, 1)
# project back to size of vocabulary
if self.share_input_output_embed:
x = F.linear(x, self.embed_tokens.weight)
else:
x = F.linear(x, self.embed_out)
# x[:, 0] = float(-10)
# print(x[:5,:10])
return x, attn
def reorder_encoder_out(self, encoder_out_dict, new_order):
if encoder_out_dict['encoder_padding_mask'] is not None:
encoder_out_dict['encoder_padding_mask'] = \
encoder_out_dict['encoder_padding_mask'].index_select(0, new_order)
return encoder_out_dict
def max_positions(self):
"""Maximum output length supported by the decoder."""
return self.embed_positions.max_positions()
def upgrade_state_dict(self, state_dict):
if isinstance(self.embed_positions, SinusoidalPositionalEmbedding):
if 'decoder.embed_positions.weights' in state_dict:
del state_dict['decoder.embed_positions.weights']
if 'decoder.embed_positions._float_tensor' not in state_dict:
state_dict['decoder.embed_positions._float_tensor'] = torch.FloatTensor()
return state_dict
class RelativeTransformerEncoderLayer(nn.Module):
"""Encoder layer block.
In the original paper each operation (multi-head attention or FFN) is
postprocessed with: dropout -> add residual -> layernorm.
In the tensor2tensor code they suggest that learning is more robust when
preprocessing each layer with layernorm and postprocessing with:
dropout -> add residual.
We default to the approach in the paper, but the tensor2tensor approach can
be enabled by setting `normalize_before=True`.
"""
def __init__(self, args):
super().__init__()
self.embed_dim = args.encoder_embed_dim
self.self_attn = RelativeMultiheadAttention(
self.embed_dim, args.encoder_attention_heads,
args.max_relative_length,
dropout=args.attention_dropout,
)
self.dropout = args.dropout
self.relu_dropout = args.relu_dropout
self.normalize_before = args.encoder_normalize_before
self.fc1 = Linear(self.embed_dim, args.encoder_ffn_embed_dim)
self.fc2 = Linear(args.encoder_ffn_embed_dim, self.embed_dim)
self.layer_norms = nn.ModuleList([LayerNorm(self.embed_dim) for i in range(2)])
def forward(self, x, encoder_padding_mask):
residual = x
x = self.maybe_layer_norm(0, x, before=True)
x, _ = self.self_attn(query=x, key=x, value=x, key_padding_mask=encoder_padding_mask)
#print("rpr_out:",x)
x = F.dropout(x, p=self.dropout, training=self.training)
x = residual + x
x = self.maybe_layer_norm(0, x, after=True)
residual = x
x = self.maybe_layer_norm(1, x, before=True)
x = F.relu(self.fc1(x))
x = F.dropout(x, p=self.relu_dropout, training=self.training)
x = self.fc2(x)
x = F.dropout(x, p=self.dropout, training=self.training)
x = residual + x
x = self.maybe_layer_norm(1, x, after=True)
return x
def maybe_layer_norm(self, i, x, before=False, after=False):
assert before ^ after
if after ^ self.normalize_before:
return self.layer_norms[i](x)
else:
return x
class RelativeTransformerDecoderLayer(nn.Module):
"""Decoder layer block."""
def __init__(self, args):
super().__init__()
self.embed_dim = args.decoder_embed_dim
self.self_attn = RelativeMultiheadAttention(
self.embed_dim, args.decoder_attention_heads,
args.max_relative_length,
dropout=args.attention_dropout,
)
self.dropout = args.dropout
self.relu_dropout = args.relu_dropout
self.normalize_before = args.decoder_normalize_before
self.encoder_attn = MultiheadAttention(
self.embed_dim, args.decoder_attention_heads,
dropout=args.attention_dropout,
)
self.fc1 = Linear(self.embed_dim, args.decoder_ffn_embed_dim)
self.fc2 = Linear(args.decoder_ffn_embed_dim, self.embed_dim)
self.layer_norms = nn.ModuleList([LayerNorm(self.embed_dim) for i in range(3)])
def forward(self, x, encoder_out, encoder_padding_mask, incremental_state):
residual = x
x = self.maybe_layer_norm(0, x, before=True)
x, _ = self.self_attn(
query=x,
key=x,
value=x,
mask_future_timesteps=True,
incremental_state=incremental_state,
need_weights=False,
)
x = F.dropout(x, p=self.dropout, training=self.training)
x = residual + x
x = self.maybe_layer_norm(0, x, after=True)
residual = x
x = self.maybe_layer_norm(1, x, before=True)
x, attn = self.encoder_attn(
query=x,
key=encoder_out,
value=encoder_out,
key_padding_mask=encoder_padding_mask,
incremental_state=incremental_state,
static_kv=True,
)
x = F.dropout(x, p=self.dropout, training=self.training)
x = residual + x
x = self.maybe_layer_norm(1, x, after=True)
residual = x
x = self.maybe_layer_norm(2, x, before=True)
x = F.relu(self.fc1(x))
x = F.dropout(x, p=self.relu_dropout, training=self.training)
x = self.fc2(x)
x = F.dropout(x, p=self.dropout, training=self.training)
x = residual + x
x = self.maybe_layer_norm(2, x, after=True)
return x, attn
def maybe_layer_norm(self, i, x, before=False, after=False):
assert before ^ after
if after ^ self.normalize_before:
return self.layer_norms[i](x)
else:
return x
def Embedding(num_embeddings, embedding_dim, padding_idx):
m = nn.Embedding(num_embeddings, embedding_dim, padding_idx=padding_idx)
nn.init.normal_(m.weight, mean=0, std=embedding_dim ** -0.5)
return m
def LayerNorm(embedding_dim):
m = nn.LayerNorm(embedding_dim)
return m
def Linear(in_features, out_features, bias=True):
m = nn.Linear(in_features, out_features, bias)
nn.init.xavier_uniform_(m.weight)
nn.init.constant_(m.bias, 0.)
return m
def PositionalEmbedding(num_embeddings, embedding_dim, padding_idx, left_pad, learned=False):
if learned:
m = LearnedPositionalEmbedding(num_embeddings, embedding_dim, padding_idx, left_pad)
nn.init.normal_(m.weight, mean=0, std=embedding_dim ** -0.5)
nn.init.constant_(m.weight[padding_idx], 0)
else:
m = SinusoidalPositionalEmbedding(embedding_dim, padding_idx, left_pad, init_size=num_embeddings)
return m
@register_model_architecture('relative_transformer', 'relative_transformer')
def base_architecture(args):
args.encoder_embed_dim = getattr(args, 'encoder_embed_dim', 512)
args.encoder_ffn_embed_dim = getattr(args, 'encoder_ffn_embed_dim', 2048)
args.encoder_layers = getattr(args, 'encoder_layers', 6)
args.encoder_attention_heads = getattr(args, 'encoder_attention_heads', 8)
args.decoder_embed_dim = getattr(args, 'decoder_embed_dim', args.encoder_embed_dim)
args.decoder_ffn_embed_dim = getattr(args, 'decoder_ffn_embed_dim', args.encoder_ffn_embed_dim)
args.decoder_layers = getattr(args, 'decoder_layers', 6)
args.decoder_attention_heads = getattr(args, 'decoder_attention_heads', 8)
args.attention_dropout = getattr(args, 'attention_dropout', 0.)
args.relu_dropout = getattr(args, 'relu_dropout', 0.)
args.dropout = getattr(args, 'dropout', 0.1)
args.max_relative_length = getattr(args, 'max_relative_length', 16)
@register_model_architecture('relative_transformer', 'relative_transformer_wmt_en_de')
def relative_transformer_wmt_en_de(args):
base_architecture(args)
@register_model_architecture('relative_transformer', 'relative_transformer_t2t_wmt_en_de')
def relative_transformer_t2t_wmt_en_de(args):
args.encoder_normalize_before = True
args.decoder_normalize_before = True
args.attention_dropout = getattr(args, 'attention_dropout', 0.1)
args.relu_dropout = getattr(args, 'relu_dropout', 0.1)
args.max_relative_length = getattr(args, 'max_relative_length', 16)
base_architecture(args)
@register_model_architecture('relative_transformer', 'relative_transformer_toy')
def relative_transformer_toy(args):
args.encoder_embed_dim = getattr(args, 'encoder_embed_dim', 64)
args.encoder_ffn_embed_dim = getattr(args, 'encoder_ffn_embed_dim', 256)
args.encoder_attention_heads = getattr(args, 'encoder_attention_heads', 4)
args.decoder_embed_dim = getattr(args, 'decoder_embed_dim', 64)
args.decoder_ffn_embed_dim = getattr(args, 'decoder_ffn_embed_dim', 256)
args.decoder_attention_heads = getattr(args, 'decoder_attention_heads', 4)
relative_transformer_t2t_wmt_en_de(args)
# parameters used in the "Attention Is All You Need" paper (Vaswani, et al, 2017)
@register_model_architecture('relative_transformer', 'relative_transformer_vaswani_wmt_en_de_big')
def relative_transformer_vaswani_wmt_en_de_big(args):
args.encoder_embed_dim = getattr(args, 'encoder_embed_dim', 1024)
args.encoder_ffn_embed_dim = getattr(args, 'encoder_ffn_embed_dim', 4096)
args.encoder_attention_heads = getattr(args, 'encoder_attention_heads', 16)
args.encoder_normalize_before = getattr(args, 'encoder_normalize_before', False)
args.decoder_embed_dim = getattr(args, 'decoder_embed_dim', 1024)
args.decoder_ffn_embed_dim = getattr(args, 'decoder_ffn_embed_dim', 4096)
args.decoder_attention_heads = getattr(args, 'decoder_attention_heads', 16)
args.dropout = getattr(args, 'dropout', 0.3)
base_architecture(args)
@register_model_architecture('relative_transformer', 'relative_transformer_vaswani_wmt_en_fr_big')
def relative_transformer_vaswani_wmt_en_fr_big(args):
args.dropout = getattr(args, 'dropout', 0.1)
relative_transformer_vaswani_wmt_en_de_big(args)
@register_model_architecture('relative_transformer', 'relative_transformer_wmt_en_de_big')
def relative_transformer_wmt_en_de_big(args):
args.attention_dropout = getattr(args, 'attention_dropout', 0.1)
relative_transformer_vaswani_wmt_en_de_big(args)
# default parameters used in tensor2tensor implementation
@register_model_architecture('relative_transformer', 'relative_transformer_wmt_en_de_big_t2t')
def relative_transformer_wmt_en_de_big_t2t(args):
args.encoder_normalize_before = getattr(args, 'encoder_normalize_before', True)
args.decoder_normalize_before = getattr(args, 'decoder_normalize_before', True)
args.attention_dropout = getattr(args, 'attention_dropout', 0.1)
args.relu_dropout = getattr(args, 'relu_dropout', 0.1)
relative_transformer_vaswani_wmt_en_de_big(args)
# Copyright (c) 2017-present, Facebook, Inc.
# All rights reserved.
#
# This source code is licensed under the license found in the LICENSE file in
# the root directory of this source tree. An additional grant of patent rights
# can be found in the PATENTS file in the same directory.
import torch
from torch import nn
from torch.nn import Parameter
import torch.nn.functional as F
from fairseq import utils
class RelativeMultiheadAttention(nn.Module):
"""Multi-headed attention use relative position information to enchance the attention.
See "Self-Attention with Relative Position Representations" for more details.
"""
def __init__(self, embed_dim, num_heads, max_relative_length, dropout=0., bias=True):
super().__init__()
self.embed_dim = embed_dim
self.num_heads = num_heads
self.max_relative_length = max_relative_length
self.dropout = dropout
self.head_dim = embed_dim // num_heads
assert self.head_dim * num_heads == self.embed_dim
self.scaling = self.head_dim ** -0.5
self._mask = None
self.in_proj_weight = Parameter(torch.Tensor(3 * embed_dim, embed_dim))
if bias:
self.in_proj_bias = Parameter(torch.Tensor(3 * embed_dim))
else:
self.register_parameter('in_proj_bias', None)
self.out_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
self.relative_position_keys = Parameter(torch.Tensor(2 * self.max_relative_length + 1, self.head_dim))
self.relative_position_values = Parameter(torch.Tensor(2 * self.max_relative_length + 1, self.head_dim))
self.reset_parameters()
def reset_parameters(self):
nn.init.xavier_uniform_(self.in_proj_weight)
nn.init.xavier_uniform_(self.out_proj.weight)
nn.init.xavier_normal_(self.relative_position_keys)
nn.init.xavier_normal_(self.relative_position_values)
if self.in_proj_bias is not None:
nn.init.constant_(self.in_proj_bias, 0.)
nn.init.constant_(self.out_proj.bias, 0.)
def forward(self, query, key, value, mask_future_timesteps=False,
key_padding_mask=None, incremental_state=None,
need_weights=True, static_kv=False):
"""Input shape: Time x Batch x Channel
Self-attention can be implemented by passing in the same arguments for
query, key and value. Future timesteps can be masked with the
`mask_future_timesteps` argument. Padding elements can be excluded from
the key by passing a binary ByteTensor (`key_padding_mask`) with shape:
batch x src_len, where padding elements are indicated by 1s.
"""
qkv_same = query.data_ptr() == key.data_ptr() == value.data_ptr()
kv_same = key.data_ptr() == value.data_ptr()
tgt_len, bsz, embed_dim = query.size()
assert embed_dim == self.embed_dim
assert list(query.size()) == [tgt_len, bsz, embed_dim]
assert key.size() == value.size()
if incremental_state is not None:
saved_state = self._get_input_buffer(incremental_state)
if 'prev_key' in saved_state:
# previous time steps are cached - no need to recompute
# key and value if they are static
if static_kv:
assert kv_same and not qkv_same
key = value = None
else:
saved_state = None
if qkv_same:
# self-attention
q, k, v = self.in_proj_qkv(query)
elif kv_same:
# encoder-decoder attention
q = self.in_proj_q(query)
if key is None:
assert value is None
# this will allow us to concat it with the previous value and
# just get the previous value
k = v = q.new(0)
else:
k, v = self.in_proj_kv(key)
else:
q = self.in_proj_q(query)
k = self.in_proj_k(key)
v = self.in_proj_v(value)
q *= self.scaling
if saved_state is not None:
if 'prev_key' in saved_state:
k = torch.cat((saved_state['prev_key'], k), dim=0)
if 'prev_value' in saved_state:
v = torch.cat((saved_state['prev_value'], v), dim=0)
saved_state['prev_key'] = k
saved_state['prev_value'] = v
self._set_input_buffer(incremental_state, saved_state)
src_len = k.size(0)
if key_padding_mask is not None:
assert key_padding_mask.size(0) == bsz
assert key_padding_mask.size(1) == src_len
q = q.contiguous().view(tgt_len, bsz * self.num_heads, self.head_dim).transpose(0, 1)
k = k.contiguous().view(src_len, bsz * self.num_heads, self.head_dim).transpose(0, 1)
v = v.contiguous().view(src_len, bsz * self.num_heads, self.head_dim).transpose(0, 1)
if not incremental_state:
assert q.size() == k.size() == v.size()
length = k.size()[1]
relative_positions_matrix = self._generate_relative_positions_matrix(
length, self.max_relative_length, incremental_state
)
#print("relative_positions_matrix : {}".format(relative_positions_matrix))
relation_keys = F.embedding(relative_positions_matrix.long().to(q.device), self.relative_position_keys)
relation_values = F.embedding(relative_positions_matrix.long().to(q.device), self.relative_position_values)
relative_attn_weights = self._relative_attention_inner(q, k, relation_keys, transpose=True)
assert list(relative_attn_weights.size()) == [bsz * self.num_heads, tgt_len, src_len]
# only apply masking at training time (when incremental state is None)
if mask_future_timesteps and incremental_state is None:
assert query.size() == key.size(), \
'mask_future_timesteps only applies to self-attention'
relative_attn_weights += self.buffered_mask(relative_attn_weights).unsqueeze(0)
if key_padding_mask is not None:
# don't attend to padding symbols
relative_attn_weights = relative_attn_weights.view(bsz, self.num_heads, tgt_len, src_len)
relative_attn_weights = relative_attn_weights.float().masked_fill(
key_padding_mask.unsqueeze(1).unsqueeze(2),
float('-inf'),
).type_as(relative_attn_weights) # FP16 support: cast to float and back
relative_attn_weights = relative_attn_weights.view(bsz * self.num_heads, tgt_len, src_len)
relative_attn_weights = F.softmax(relative_attn_weights.float(), dim=-1).type_as(relative_attn_weights)
relative_attn_weights = F.dropout(relative_attn_weights, p=self.dropout, training=self.training)
attn = self._relative_attention_inner(relative_attn_weights, v, relation_values, transpose=False)
assert list(attn.size()) == [bsz * self.num_heads, tgt_len, self.head_dim]
attn = attn.transpose(0, 1).contiguous().view(tgt_len, bsz, embed_dim)
attn = self.out_proj(attn)
#print("attn : {}".format(attn))
if need_weights:
# average attention weights over heads
relative_attn_weights = relative_attn_weights.view(bsz, self.num_heads, tgt_len, src_len)
relative_attn_weights = relative_attn_weights.sum(dim=1) / self.num_heads
else:
relative_attn_weights = None
return attn, relative_attn_weights
def in_proj_qkv(self, query):
return self._in_proj(query).chunk(3, dim=-1)
def in_proj_kv(self, key):
return self._in_proj(key, start=self.embed_dim).chunk(2, dim=-1)
def in_proj_q(self, query):
return self._in_proj(query, end=self.embed_dim)
def in_proj_k(self, key):
return self._in_proj(key, start=self.embed_dim, end=2 * self.embed_dim)
def in_proj_v(self, value):
return self._in_proj(value, start=2 * self.embed_dim)
def _in_proj(self, input, start=None, end=None):
weight = self.in_proj_weight
bias = self.in_proj_bias
if end is not None:
weight = weight[:end, :]
if bias is not None:
bias = bias[:end]
if start is not None:
weight = weight[start:, :]
if bias is not None:
bias = bias[start:]
return F.linear(input, weight, bias)
def buffered_mask(self, tensor):
dim = tensor.size(-1)
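# e.g. for dim=3 the cached mask is [[0, -inf, -inf], [0, 0, -inf], [0, 0, 0]],
# added to the attention logits so that each position cannot attend to future positions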
if self._mask is None:
self._mask = torch.triu(utils.fill_with_neg_inf(tensor.new(dim, dim)), 1)
if self._mask.size(0) < dim:
self._mask = torch.triu(utils.fill_with_neg_inf(self._mask.resize_(dim, dim)), 1)
return self._mask[:dim, :dim]
def reorder_incremental_state(self, incremental_state, new_order):
"""Reorder buffered internal state (for incremental generation)."""
input_buffer = self._get_input_buffer(incremental_state)
if input_buffer is not None:
for k in input_buffer.keys():
input_buffer[k] = input_buffer[k].index_select(1, new_order)
self._set_input_buffer(incremental_state, input_buffer)
def _get_input_buffer(self, incremental_state):
return utils.get_incremental_state(
self,
incremental_state,
'attn_state',
) or {}
def _set_input_buffer(self, incremental_state, buffer):
utils.set_incremental_state(
self,
incremental_state,
'attn_state',
buffer,
)
def _generate_relative_positions_matrix(self, length, max_relative_length, incremental_state):
if not incremental_state:
# training process
range_vec = torch.arange(length)
range_mat = range_vec.repeat(length, 1)
distance_mat = range_mat - range_mat.transpose(0, 1)
else:
# incremental decoding: the single current step attends to all previous steps
distance_mat = torch.arange(-length + 1, 1).view(1, -1)
distance_mat_clipped = torch.clamp(distance_mat, -max_relative_length, max_relative_length)
# position difference.
final_mat = distance_mat_clipped + max_relative_length
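# Worked example (length=4, max_relative_length=2): distance_mat_clipped is
#     [[ 0,  1,  2,  2], [-1,  0,  1,  2], [-2, -1,  0,  1], [-2, -2, -1,  0]]
# and adding max_relative_length shifts it to indices into the 2*k + 1 = 5 learned
# relative position embeddings:
#     [[2, 3, 4, 4], [1, 2, 3, 4], [0, 1, 2, 3], [0, 0, 1, 2]]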
return final_mat
def _relative_attention_inner(self, x, y, z, transpose=True):
"""Relative position-aware dot-product attention inner calculation.
This batches matrix multiply calculations to avoid unnecessary broadcasting.
Args:
x: Tensor with shape [batch_size*heads, length, length or depth].
y: Tensor with shape [batch_size*heads, length, depth].
z: Tensor with shape [length, length, depth].
transpose: Whether to transpose the inner matrices of y and z. Should be true if
the last dimension of x is depth, not length.
Returns:
A Tensor with shape [batch_size*heads, length, length or depth].
Note: this function effectively computes x(y + z), where z holds the relative
position representations, but factors the computation as xy + xz using batched matmuls.
"""
batch_size_mul_head = x.size()[0]
length = z.size()[0]
#print(batch_size_mul_head, length)
# xy_matmul is [batch_size*heads, length, length or depth]
if transpose:
y = y.transpose(1, 2)
xy_matmul = torch.bmm(x, y)
# x_t is [length, batch_size * heads, length or depth]
x_t = x.transpose(0, 1)
# x_tz_matmul is [length, batch_size * heads, length or depth]
if transpose:
z = z.transpose(1, 2)
x_tz_matmul = torch.bmm(x_t, z).transpose(0, 1).view(batch_size_mul_head, length, -1)
#assert xy_matmul.size() == x_tz_matmul.size()
return xy_matmul + x_tz_matmul