Commit d11db556 by libei

support dense relative transformer decode and model convert

parent 73e8d792
# Copyright (c) 2017-present, Facebook, Inc.
# All rights reserved.
#
# This source code is licensed under the license found in the LICENSE file in
# the root directory of this source tree. An additional grant of patent rights
# can be found in the PATENTS file in the same directory.
import math
from collections import OrderedDict
import torch
import torch.nn as nn
import torch.nn.functional as F
from fairseq.modules.layer_history import CreateLayerHistory
from fairseq.modules import (
LearnedPositionalEmbedding, MultiheadAttention,
SinusoidalPositionalEmbedding,
RelativeMultiheadAttention,
)
from . import (
FairseqIncrementalDecoder, FairseqEncoder, FairseqModel,
register_model, register_model_architecture,
)
import numpy
import fairseq.utils as util
@register_model('dense_relative_transformer')
class DenseTransformerModel(FairseqModel):
def __init__(self, encoder, decoder):
super().__init__(encoder, decoder)
@staticmethod
def add_args(parser):
"""Add model-specific arguments to the parser."""
parser.add_argument('--dropout', type=float, metavar='D',
help='dropout probability')
parser.add_argument('--attention-dropout', type=float, metavar='D',
help='dropout probability for attention weights')
parser.add_argument('--relu-dropout', type=float, metavar='D',
help='dropout probability after ReLU in FFN')
parser.add_argument('--encoder-embed-dim', type=int, metavar='N',
help='encoder embedding dimension')
parser.add_argument('--encoder-ffn-embed-dim', type=int, metavar='N',
help='encoder embedding dimension for FFN')
parser.add_argument('--encoder-layers', type=int, metavar='N',
help='num encoder layers')
parser.add_argument('--encoder-attention-heads', type=int, metavar='N',
help='num encoder attention heads')
parser.add_argument('--encoder-normalize-before', default=False, action='store_true',
help='apply layernorm before each encoder block')
parser.add_argument('--encoder-learned-pos', default=False, action='store_true',
help='use learned positional embeddings in the encoder')
parser.add_argument('--decoder-embed-dim', type=int, metavar='N',
help='decoder embedding dimension')
parser.add_argument('--decoder-ffn-embed-dim', type=int, metavar='N',
help='decoder embedding dimension for FFN')
parser.add_argument('--decoder-layers', type=int, metavar='N',
help='num decoder layers')
parser.add_argument('--decoder-attention-heads', type=int, metavar='N',
help='num decoder attention heads')
parser.add_argument('--decoder-learned-pos', default=False, action='store_true',
help='use learned positional embeddings in the decoder')
parser.add_argument('--decoder-normalize-before', default=False, action='store_true',
help='apply layernorm before each decoder block')
parser.add_argument('--share-decoder-input-output-embed', default=False, action='store_true',
help='share decoder input and output embeddings')
parser.add_argument('--share-all-embeddings', default=False, action='store_true',
help='share encoder, decoder and output embeddings'
' (requires shared dictionary and embed dim)')
### dense layer parameters
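# dense-connection (layer history) options: the history type selects how previous layer outputs
# are recorded (e.g. 'dense' or 'learnable_dense'), and the integration type selects how they are
# combined ('avg' or 'sum') when a block reads from the history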
parser.add_argument('--encoder-history-type',
help='encoder layer history type')
parser.add_argument('--decoder-history-type',
help='decoder layer history type')
parser.add_argument('--encoder-integration-type', choices=['avg', 'sum'],
help='encoder layer integration type')
parser.add_argument('--decoder-integration-type', choices=['avg', 'sum'],
help='decoder layer integration type')
@classmethod
def build_model(cls, args, task):
"""Build a new model instance."""
src_dict, tgt_dict = task.source_dictionary, task.target_dictionary
def build_embedding(dictionary, embed_dim):
num_embeddings = len(dictionary)
padding_idx = dictionary.pad()
return Embedding(num_embeddings, embed_dim, padding_idx)
if args.share_all_embeddings:
if src_dict != tgt_dict:
raise RuntimeError('--share-all-embeddings requires a joined dictionary')
if args.encoder_embed_dim != args.decoder_embed_dim:
raise RuntimeError(
'--share-all-embeddings requires --encoder-embed-dim to match --decoder-embed-dim')
encoder_embed_tokens = build_embedding(src_dict, args.encoder_embed_dim)
decoder_embed_tokens = encoder_embed_tokens
args.share_decoder_input_output_embed = True
else:
encoder_embed_tokens = build_embedding(src_dict, args.encoder_embed_dim)
decoder_embed_tokens = build_embedding(tgt_dict, args.decoder_embed_dim)
encoder = DenseRelativeTransformerEncoder(args, src_dict, encoder_embed_tokens)
decoder = DenseRelativeTransformerDecoder(args, tgt_dict, decoder_embed_tokens)
return DenseTransformerModel(encoder, decoder)
class DenseRelativeTransformerEncoder(FairseqEncoder):
"""Transformer encoder."""
def __init__(self, args, dictionary, embed_tokens, left_pad=True):
super().__init__(dictionary)
self.dropout = args.dropout
embed_dim = embed_tokens.embedding_dim
self.padding_idx = embed_tokens.padding_idx
self.embed_tokens = embed_tokens
self.embed_scale = math.sqrt(embed_dim)
self.embed_positions = PositionalEmbedding(
1024, embed_dim, self.padding_idx,
left_pad=left_pad,
learned=args.encoder_learned_pos,
)
self.layers = nn.ModuleList([])
self.layers.extend([
TransformerEncoderLayer(args)
for i in range(args.encoder_layers)
])
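# the layer history stores the embedding output and the output of every encoder block,
# so that each block can consume a dense combination of all earlier representations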
self.history = CreateLayerHistory(args, is_encoder=True)
self.normalize = args.encoder_normalize_before
if self.normalize:
self.layer_norm = LayerNorm(embed_dim)
def forward(self, src_tokens, src_lengths):
if self.history is not None:
self.history.clean()
# embed tokens and positions
x = self.embed_scale * self.embed_tokens(src_tokens)
# print("emb_word:", x[0])
pos_emb = self.embed_positions(src_tokens)
# print("pos_emb:", pos_emb[0])
x += pos_emb
x = F.dropout(x, p=self.dropout, training=self.training)
# print("encoder_input:", x[0])
# B x T x C -> T x B x C
x = x.transpose(0, 1)
if self.history is not None:
self.history.add(x)
# compute padding mask
encoder_padding_mask = src_tokens.eq(self.padding_idx)
if not encoder_padding_mask.any():
encoder_padding_mask = None
# encoder layers
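# history.pop() yields the aggregated representation of everything recorded so far;
# history.add(x) records this block's output for the following blocks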
for layer_id, layer in enumerate(self.layers):
if self.history is not None:
x = self.history.pop()
x = layer(x, encoder_padding_mask)
if self.history is not None:
self.history.add(x)
if self.history is not None:
x = self.history.pop()
if self.normalize:
x = self.layer_norm(x)
# print('enc_output:', x.size())
# print('enc_output:', x[0][0][:10])
return {
'encoder_out': x, # T x B x C
'encoder_padding_mask': encoder_padding_mask, # B x T
}
def max_positions(self):
"""Maximum input length supported by the encoder."""
return self.embed_positions.max_positions()
def upgrade_state_dict(self, state_dict):
if isinstance(self.embed_positions, SinusoidalPositionalEmbedding):
if 'encoder.embed_positions.weights' in state_dict:
del state_dict['encoder.embed_positions.weights']
if 'encoder.embed_positions._float_tensor' not in state_dict:
state_dict['encoder.embed_positions._float_tensor'] = torch.FloatTensor()
return state_dict
class DenseRelativeTransformerDecoder(FairseqIncrementalDecoder):
"""Transformer decoder."""
def __init__(self, args, dictionary, embed_tokens, left_pad=False, final_norm=True):
super().__init__(dictionary)
self.dropout = args.dropout
self.share_input_output_embed = args.share_decoder_input_output_embed
embed_dim = embed_tokens.embedding_dim
padding_idx = embed_tokens.padding_idx
self.embed_tokens = embed_tokens
self.embed_scale = math.sqrt(embed_dim)
self.embed_positions = PositionalEmbedding(
1024, embed_dim, padding_idx,
left_pad=left_pad,
learned=args.decoder_learned_pos,
)
self.layers = nn.ModuleList([])
self.layers.extend([
TransformerDecoderLayer(args)
for i in range(args.decoder_layers)
])
self.history = CreateLayerHistory(args, is_encoder=False)
self.normalize = args.decoder_normalize_before and final_norm
if self.normalize:
self.layer_norm = LayerNorm(embed_dim)
if not self.share_input_output_embed:
self.embed_out = nn.Parameter(torch.Tensor(len(dictionary), embed_dim))
nn.init.normal_(self.embed_out, mean=0, std=embed_dim ** -0.5)
def forward(self, prev_output_tokens, encoder_out, incremental_state=None):
if self.history is not None:
self.history.clean()
# embed positions
positions = self.embed_positions(
prev_output_tokens,
incremental_state=incremental_state,
)
if incremental_state is not None:
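# incremental decoding: only the newest target token is fed through the decoder;
# earlier positions are served from the cached attention states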
is_first_step = prev_output_tokens.size(1) == 1
prev_output_tokens = prev_output_tokens[:, -1:]
positions = positions[:, -1:]
# embed tokens and positions
x = self.embed_scale * self.embed_tokens(prev_output_tokens)
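# on the very first decoding step the embedding of the first target token is discarded and
# replaced by its positional embedding, matching the fill_(0) applied at position 0 in the
# non-incremental branch below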
if is_first_step:
x = positions.expand_as(x)
else:
x += positions
else:
# embed tokens and positions
x = self.embed_scale * self.embed_tokens(prev_output_tokens)
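# zero the embedding of the first target token so that position 0 contributes only its
# positional embedding, consistent with the incremental first-step branch above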
x[:, 0, :].fill_(0)
x += positions
x = F.dropout(x, p=self.dropout, training=self.training)
# B x T x C -> T x B x C
x = x.transpose(0, 1)
if self.history is not None:
self.history.add(x)
# decoder layers
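# as in the encoder, every block reads the dense aggregate from the history and pushes its own output back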
for layer_id, layer in enumerate(self.layers):
if self.history is not None:
x = self.history.pop()
x, attn = layer(
x,
encoder_out['encoder_out'],
encoder_out['encoder_padding_mask'],
incremental_state,
)
if self.history is not None:
self.history.add(x)
if self.history is not None:
x = self.history.pop()
if self.normalize:
x = self.layer_norm(x)
# T x B x C -> B x T x C
x = x.transpose(0, 1)
# project back to size of vocabulary
if self.share_input_output_embed:
x = F.linear(x, self.embed_tokens.weight)
else:
x = F.linear(x, self.embed_out)
return x, attn
def reorder_encoder_out(self, encoder_out_dict, new_order):
if encoder_out_dict['encoder_padding_mask'] is not None:
encoder_out_dict['encoder_padding_mask'] = \
encoder_out_dict['encoder_padding_mask'].index_select(0, new_order)
return encoder_out_dict
def max_positions(self):
"""Maximum output length supported by the decoder."""
return self.embed_positions.max_positions()
def upgrade_state_dict(self, state_dict):
if isinstance(self.embed_positions, SinusoidalPositionalEmbedding):
if 'decoder.embed_positions.weights' in state_dict:
del state_dict['decoder.embed_positions.weights']
if 'decoder.embed_positions._float_tensor' not in state_dict:
state_dict['decoder.embed_positions._float_tensor'] = torch.FloatTensor()
return state_dict
class TransformerEncoderLayer(nn.Module):
"""Encoder layer block.
In the original paper each operation (multi-head attention or FFN) is
postprocessed with: dropout -> add residual -> layernorm.
In the tensor2tensor code they suggest that learning is more robust when
preprocessing each layer with layernorm and postprocessing with:
dropout -> add residual.
We default to the approach in the paper, but the tensor2tensor approach can
be enabled by setting `normalize_before=True`.
"""
def __init__(self, args):
super().__init__()
self.embed_dim = args.encoder_embed_dim
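# relative-position self-attention (Shaw et al., 2018); max_relative_length clips the relative
# distance used to look up the learned relative-position embeddings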
self.self_attn = RelativeMultiheadAttention(
self.embed_dim, args.encoder_attention_heads,
args.max_relative_length,
dropout=args.attention_dropout,
)
self.dropout = args.dropout
self.relu_dropout = args.relu_dropout
self.normalize_before = args.encoder_normalize_before
self.fc1 = Linear(self.embed_dim, args.encoder_ffn_embed_dim)
self.fc2 = Linear(args.encoder_ffn_embed_dim, self.embed_dim)
self.layer_norms = nn.ModuleList([LayerNorm(self.embed_dim) for i in range(2)])
def forward(self, x, encoder_padding_mask):
residual = x
x = self.maybe_layer_norm(0, x, before=True)
x, _ = self.self_attn(query=x, key=x, value=x, key_padding_mask=encoder_padding_mask)
x = F.dropout(x, p=self.dropout, training=self.training)
x = residual + x
x = self.maybe_layer_norm(0, x, after=True)
residual = x
x = self.maybe_layer_norm(1, x, before=True)
x = F.relu(self.fc1(x))
x = F.dropout(x, p=self.relu_dropout, training=self.training)
x = self.fc2(x)
x = F.dropout(x, p=self.dropout, training=self.training)
x = residual + x
x = self.maybe_layer_norm(1, x, after=True)
return x
def maybe_layer_norm(self, i, x, before=False, after=False):
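# exactly one of before/after may be set; with normalize_before=True the norm fires on the
# 'before' call, otherwise on the 'after' call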
assert before ^ after
if after ^ self.normalize_before:
return self.layer_norms[i](x)
else:
return x
class TransformerDecoderLayer(nn.Module):
"""Decoder layer block."""
def __init__(self, args):
super().__init__()
self.embed_dim = args.decoder_embed_dim
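# relative-position self-attention over the target prefix; the encoder-decoder attention
# below remains standard MultiheadAttention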
self.self_attn = RelativeMultiheadAttention(
self.embed_dim, args.decoder_attention_heads,
args.max_relative_length,
dropout=args.attention_dropout,
)
self.dropout = args.dropout
self.relu_dropout = args.relu_dropout
self.normalize_before = args.decoder_normalize_before
self.encoder_attn = MultiheadAttention(
self.embed_dim, args.decoder_attention_heads,
dropout=args.attention_dropout,
)
self.fc1 = Linear(self.embed_dim, args.decoder_ffn_embed_dim)
self.fc2 = Linear(args.decoder_ffn_embed_dim, self.embed_dim)
self.layer_norms = nn.ModuleList([LayerNorm(self.embed_dim) for i in range(3)])
def forward(self, x, encoder_out, encoder_padding_mask, incremental_state):
residual = x
x = self.maybe_layer_norm(0, x, before=True)
x, _ = self.self_attn(
query=x,
key=x,
value=x,
mask_future_timesteps=True,
incremental_state=incremental_state,
need_weights=False,
)
x = F.dropout(x, p=self.dropout, training=self.training)
x = residual + x
x = self.maybe_layer_norm(0, x, after=True)
residual = x
x = self.maybe_layer_norm(1, x, before=True)
x, attn = self.encoder_attn(
query=x,
key=encoder_out,
value=encoder_out,
key_padding_mask=encoder_padding_mask,
incremental_state=incremental_state,
static_kv=True,
)
x = F.dropout(x, p=self.dropout, training=self.training)
x = residual + x
x = self.maybe_layer_norm(1, x, after=True)
residual = x
x = self.maybe_layer_norm(2, x, before=True)
x = F.relu(self.fc1(x))
x = F.dropout(x, p=self.relu_dropout, training=self.training)
x = self.fc2(x)
x = F.dropout(x, p=self.dropout, training=self.training)
x = residual + x
x = self.maybe_layer_norm(2, x, after=True)
return x, attn
def maybe_layer_norm(self, i, x, before=False, after=False):
assert before ^ after
if after ^ self.normalize_before:
return self.layer_norms[i](x)
else:
return x
def Embedding(num_embeddings, embedding_dim, padding_idx):
m = nn.Embedding(num_embeddings, embedding_dim, padding_idx=padding_idx)
nn.init.normal_(m.weight, mean=0, std=embedding_dim ** -0.5)
return m
def LayerNorm(embedding_dim):
m = nn.LayerNorm(embedding_dim)
return m
def Linear(in_features, out_features, bias=True):
m = nn.Linear(in_features, out_features, bias)
nn.init.xavier_uniform_(m.weight)
nn.init.constant_(m.bias, 0.)
return m
def PositionalEmbedding(num_embeddings, embedding_dim, padding_idx, left_pad, learned=False):
if learned:
m = LearnedPositionalEmbedding(num_embeddings, embedding_dim, padding_idx, left_pad)
nn.init.normal_(m.weight, mean=0, std=embedding_dim ** -0.5)
nn.init.constant_(m.weight[padding_idx], 0)
else:
m = SinusoidalPositionalEmbedding(embedding_dim, padding_idx, left_pad, init_size=num_embeddings)
return m
@register_model_architecture('dense_relative_transformer', 'dense_relative_transformer')
def base_architecture(args):
args.encoder_embed_dim = getattr(args, 'encoder_embed_dim', 512)
args.encoder_ffn_embed_dim = getattr(args, 'encoder_ffn_embed_dim', 2048)
args.encoder_layers = getattr(args, 'encoder_layers', 6)
args.encoder_attention_heads = getattr(args, 'encoder_attention_heads', 8)
args.decoder_embed_dim = getattr(args, 'decoder_embed_dim', args.encoder_embed_dim)
args.decoder_ffn_embed_dim = getattr(args, 'decoder_ffn_embed_dim', args.encoder_ffn_embed_dim)
args.decoder_layers = getattr(args, 'decoder_layers', 6)
args.decoder_attention_heads = getattr(args, 'decoder_attention_heads', 8)
args.attention_dropout = getattr(args, 'attention_dropout', 0.)
args.relu_dropout = getattr(args, 'relu_dropout', 0.)
args.dropout = getattr(args, 'dropout', 0.1)
args.encoder_normalize_before = getattr(args, 'encoder_normalize_before', False)
args.decoder_normalize_before = getattr(args, 'decoder_normalize_before', False)
args.encoder_history_type = getattr(args, 'encoder_history_type', 'dense')
args.decoder_history_type = getattr(args, 'decoder_history_type', 'dense')
args.encoder_integration_type = getattr(args, 'encoder_integration_type', 'avg')
args.decoder_integration_type = getattr(args, 'decoder_integration_type', 'avg')
args.max_relative_length = getattr(args, 'max_relative_length', 20)
@register_model_architecture('dense_relative_transformer', 'dense_relative_transformer_wmt_en_de')
def dense_relative_transformer_wmt_en_de(args):
args.encoder_history_type = getattr(args, 'encoder_history_type', 'learnable_dense')
args.decoder_history_type = getattr(args, 'decoder_history_type', 'learnable_dense')
args.max_relative_length = 20
args.encoder_layers = 6
base_architecture(args)
@register_model_architecture('dense_relative_transformer', 'dense_relative_transformer_t2t_wmt_en_de')
def dense_relative_transformer_t2t_wmt_en_de(args):
args.encoder_normalize_before = True
args.decoder_normalize_before = True
args.attention_dropout = getattr(args, 'attention_dropout', 0.1)
args.relu_dropout = getattr(args, 'relu_dropout', 0.1)
args.encoder_history_type = getattr(args, 'encoder_history_type', 'learnable_dense')
args.decoder_history_type = getattr(args, 'decoder_history_type', 'learnable_dense')
args.max_relative_length = 20
args.encoder_layers = 30
base_architecture(args)
@register_model_architecture('dense_relative_transformer', 'dense_relative_transformer_iwslt_de_en')
def dense_relative_transformer_iwslt_de_en(args):
args.encoder_embed_dim = getattr(args, 'encoder_embed_dim', 256)
args.encoder_ffn_embed_dim = getattr(args, 'encoder_ffn_embed_dim', 512)
args.encoder_attention_heads = getattr(args, 'encoder_attention_heads', 4)
args.encoder_layers = getattr(args, 'encoder_layers', 3)
args.decoder_embed_dim = getattr(args, 'decoder_embed_dim', 256)
args.decoder_ffn_embed_dim = getattr(args, 'decoder_ffn_embed_dim', 512)
args.decoder_attention_heads = getattr(args, 'decoder_attention_heads', 4)
args.decoder_layers = getattr(args, 'decoder_layers', 3)
base_architecture(args)
@register_model_architecture('dense_relative_transformer', 'dense_relative_transformer_toy')
def dense_relative_transformer_toy(args):
args.encoder_embed_dim = getattr(args, 'encoder_embed_dim', 64)
args.encoder_ffn_embed_dim = getattr(args, 'encoder_ffn_embed_dim', 64)
args.encoder_attention_heads = getattr(args, 'encoder_attention_heads', 4)
args.encoder_layers = getattr(args, 'encoder_layers', 3)
args.decoder_embed_dim = getattr(args, 'decoder_embed_dim', 64)
args.decoder_ffn_embed_dim = getattr(args, 'decoder_ffn_embed_dim', 64)
args.decoder_attention_heads = getattr(args, 'decoder_attention_heads', 4)
args.decoder_layers = getattr(args, 'decoder_layers', 3)
base_architecture(args)
# parameters used in the "Learning Deep Transformer Models for Machine Translation" paper (Qiang Wang et al., 2019)
@register_model_architecture('dense_relative_transformer', 'dense_relative_transformer_vaswani_wmt_en_de_big')
def dense_relative_transformer_vaswani_wmt_en_de_big(args):
args.encoder_embed_dim = getattr(args, 'encoder_embed_dim', 1024)
args.encoder_ffn_embed_dim = getattr(args, 'encoder_ffn_embed_dim', 4096)
args.encoder_attention_heads = getattr(args, 'encoder_attention_heads', 16)
args.encoder_normalize_before = getattr(args, 'encoder_normalize_before', False)
args.decoder_embed_dim = getattr(args, 'decoder_embed_dim', 1024)
args.decoder_ffn_embed_dim = getattr(args, 'decoder_ffn_embed_dim', 4096)
args.decoder_attention_heads = getattr(args, 'decoder_attention_heads', 16)
args.dropout = getattr(args, 'dropout', 0.3)
base_architecture(args)
@register_model_architecture('dense_relative_transformer', 'dense_relative_transformer_vaswani_wmt_en_fr_big')
def dense_relative_transformer_vaswani_wmt_en_fr_big(args):
args.dropout = getattr(args, 'dropout', 0.1)
dense_relative_transformer_vaswani_wmt_en_de_big(args)
@register_model_architecture('dense_relative_transformer', 'dense_relative_transformer_wmt_en_de_big')
def dense_relative_transformer_wmt_en_de_big(args):
args.attention_dropout = getattr(args, 'attention_dropout', 0.1)
dense_relative_transformer_vaswani_wmt_en_de_big(args)
@register_model_architecture('dense_relative_transformer', 'dense_relative_transformer_t2t_wmt_en_de_big')
def dense_relative_transformer_t2t_wmt_en_de_big(args):
args.encoder_normalize_before = getattr(args, 'encoder_normalize_before', True)
args.decoder_normalize_before = getattr(args, 'decoder_normalize_before', True)
args.relu_dropout = getattr(args, 'relu_dropout', 0.1)
args.attention_dropout = getattr(args, 'attention_dropout', 0.1)
dense_relative_transformer_vaswani_wmt_en_de_big(args)
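# Minimal usage sketch (assumed invocation for this fairseq fork, not verified against its train.py;
# --max-relative-length is expected to be registered by the relative-attention code elsewhere in the repo):
#   python train.py <data-bin> --arch dense_relative_transformer_t2t_wmt_en_de \
#       --encoder-history-type learnable_dense --decoder-history-type learnable_dense \
#       --encoder-integration-type avg --decoder-integration-type avg \
#       --max-relative-length 20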
@@ -15,7 +15,7 @@ from .linearized_convolution import LinearizedConvolution
from .multihead_attention import MultiheadAttention
from .scalar_bias import ScalarBias
from .sinusoidal_positional_embedding import SinusoidalPositionalEmbedding
from .relative_multihead_attention import RelativeMultiheadAttention
__all__ = [
'AdaptiveSoftmax',
'BeamableMM',
@@ -27,4 +27,6 @@ __all__ = [
'MultiheadAttention',
'ScalarBias',
'SinusoidalPositionalEmbedding',
'RelativeMultiheadAttention',
]
@@ -472,6 +472,8 @@ def convert_settings(settings):
args['max_relative_length'] = int(settings['max_relative_length'])
args['arch'] = 'relative_transformer'
if use_relative_position_representation and use_dense:
args['arch'] = 'dense_relative_transformer'
return argparse.Namespace(**args)