Commit 922ef3d9 by xuchen

add the implementation of conformer

parent 0150d9ac
@@ -6,3 +6,4 @@
 from .berard import *  # noqa
 from .convtransformer import *  # noqa
 from .s2t_transformer import *  # noqa
+from .s2t_conformer import *  # noqa
#!/usr/bin/env python3
import logging
import torch.nn as nn
from fairseq import checkpoint_utils
from fairseq.data.data_utils import lengths_to_padding_mask
from fairseq.models import (
register_model,
register_model_architecture,
)
from fairseq.models.speech_to_text import S2TTransformerModel, S2TTransformerEncoder
from fairseq.modules import (
ConformerEncoderLayer,
)
logger = logging.getLogger(__name__)
@register_model("s2t_conformer")
class S2TConformerModel(S2TTransformerModel):
"""Adapted Transformer model (https://arxiv.org/abs/1706.03762) for
speech-to-text tasks. The Transformer encoder/decoder remains the same.
A trainable input subsampler is prepended to the Transformer encoder to
project inputs into the encoder dimension as well as downsample input
sequence for computational efficiency."""
def __init__(self, encoder, decoder):
super().__init__(encoder, decoder)
@staticmethod
def add_args(parser):
"""Add model-specific arguments to the parser."""
S2TTransformerModel.add_args(parser)
        parser.add_argument(
            "--macaron-style",
            action="store_true",
            help="Whether to use macaron-style position-wise feed-forward layers",
        )
        # Attention
        parser.add_argument(
            "--zero-triu",
            action="store_true",
            help="If set, zero the upper triangular part of the attention matrix.",
        )
# Relative positional encoding
parser.add_argument(
"--rel-pos-type",
type=str,
default="legacy",
choices=["legacy", "latest"],
help="Whether to use the latest relative positional encoding or the legacy one."
"The legacy relative positional encoding will be deprecated in the future."
"More Details can be found in https://github.com/espnet/espnet/pull/2816.",
)
# CNN module
        parser.add_argument(
            "--use-cnn-module",
            action="store_true",
            help="Whether to use the convolution module",
        )
parser.add_argument(
"--cnn-module-kernel",
default=31,
type=int,
help="Kernel size of convolution module.",
)
@classmethod
def build_encoder(cls, args, task=None, embed_tokens=None):
encoder = S2TConformerEncoder(args, task, embed_tokens)
if getattr(args, "load_pretrained_encoder_from", None):
encoder = checkpoint_utils.load_pretrained_component_from_model(
component=encoder, checkpoint=args.load_pretrained_encoder_from
)
logger.info(
f"loaded pretrained encoder from: "
f"{args.load_pretrained_encoder_from}"
)
return encoder
class S2TConformerEncoder(S2TTransformerEncoder):
"""Speech-to-text Conformer encoder that consists of input subsampler and
Transformer encoder."""
def __init__(self, args, task=None, embed_tokens=None):
super().__init__(args, task, embed_tokens)
self.transformer_layers = nn.ModuleList(
[ConformerEncoderLayer(args) for _ in range(args.encoder_layers)]
)
def forward(self, src_tokens, src_lengths):
x, input_lengths = self.subsample(src_tokens, src_lengths)
x = self.embed_scale * x
encoder_padding_mask = lengths_to_padding_mask(input_lengths)
positions = self.embed_positions(encoder_padding_mask).transpose(0, 1)
if self.attn_type != "rel_selfattn":
x += positions
x = self.dropout_module(x)
positions = self.dropout_module(positions)
for layer in self.transformer_layers:
x = layer(x, encoder_padding_mask, pos_emb=positions)
if self.layer_norm is not None:
x = self.layer_norm(x)
return {
"encoder_out": [x], # T x B x C
"encoder_padding_mask": [encoder_padding_mask], # B x T
"encoder_embedding": [], # B x T x C
"encoder_states": [], # List[T x B x C]
"src_tokens": [],
"src_lengths": [],
}
@register_model_architecture(model_name="s2t_conformer", arch_name="s2t_conformer")
def base_architecture(args):
# Convolutional subsampler
args.conv_kernel_sizes = getattr(args, "conv_kernel_sizes", "5,5")
args.conv_channels = getattr(args, "conv_channels", 1024)
# Conformer
args.macaron_style = getattr(args, "macaron_style", True)
args.macaron_style = getattr(args, "use_cnn_module", True)
args.macaron_style = getattr(args, "cnn_module_kernel", 31)
args.encoder_embed_dim = getattr(args, "encoder_embed_dim", 512)
args.encoder_ffn_embed_dim = getattr(args, "encoder_ffn_embed_dim", 2048)
args.encoder_layers = getattr(args, "encoder_layers", 12)
args.encoder_attention_type = getattr(args, "encoder_attention_type", "selfattn")
args.encoder_attention_heads = getattr(args, "encoder_attention_heads", 8)
args.encoder_normalize_before = getattr(args, "encoder_normalize_before", True)
args.decoder_embed_dim = getattr(args, "decoder_embed_dim", args.encoder_embed_dim)
args.decoder_ffn_embed_dim = getattr(
args, "decoder_ffn_embed_dim", args.encoder_ffn_embed_dim
)
args.decoder_layers = getattr(args, "decoder_layers", 6)
args.decoder_attention_type = getattr(args, "decoder_attention_type", "selfattn")
args.decoder_attention_heads = getattr(args, "decoder_attention_heads", 8)
args.decoder_normalize_before = getattr(args, "decoder_normalize_before", True)
args.decoder_learned_pos = getattr(args, "decoder_learned_pos", False)
args.dropout = getattr(args, "dropout", 0.1)
args.attention_dropout = getattr(args, "attention_dropout", args.dropout)
args.activation_dropout = getattr(args, "activation_dropout", args.dropout)
args.activation_fn = getattr(args, "activation_fn", "relu")
args.adaptive_softmax_cutoff = getattr(args, "adaptive_softmax_cutoff", None)
args.adaptive_softmax_dropout = getattr(args, "adaptive_softmax_dropout", 0)
args.share_decoder_input_output_embed = getattr(
args, "share_decoder_input_output_embed", False
)
args.no_token_positional_embeddings = getattr(
args, "no_token_positional_embeddings", False
)
args.adaptive_input = getattr(args, "adaptive_input", False)
args.decoder_layerdrop = getattr(args, "decoder_layerdrop", 0.0)
args.decoder_output_dim = getattr(
args, "decoder_output_dim", args.decoder_embed_dim
)
args.decoder_input_dim = getattr(args, "decoder_input_dim", args.decoder_embed_dim)
args.no_scale_embedding = getattr(args, "no_scale_embedding", False)
args.quant_noise_pq = getattr(args, "quant_noise_pq", 0)
@register_model_architecture("s2t_conformer", "s2t_conformer_s")
def s2t_conformer_s(args):
args.encoder_embed_dim = getattr(args, "encoder_embed_dim", 256)
args.encoder_ffn_embed_dim = getattr(args, "encoder_ffn_embed_dim", 256 * 8)
args.encoder_attention_heads = getattr(args, "encoder_attention_heads", 4)
args.decoder_attention_heads = getattr(args, "decoder_attention_heads", 4)
args.dropout = getattr(args, "dropout", 0.1)
base_architecture(args)
@register_model_architecture("s2t_conformer", "s2t_conformer_xs")
def s2t_conformer_xs(args):
args.encoder_layers = getattr(args, "encoder_layers", 6)
args.decoder_layers = getattr(args, "decoder_layers", 3)
args.encoder_ffn_embed_dim = getattr(args, "encoder_ffn_embed_dim", 256 * 4)
args.dropout = getattr(args, "dropout", 0.3)
s2t_conformer_s(args)
@register_model_architecture("s2t_conformer", "s2t_conformer_sp")
def s2t_conformer_sp(args):
args.encoder_layers = getattr(args, "encoder_layers", 16)
s2t_conformer_s(args)
@register_model_architecture("s2t_conformer", "s2t_conformer_m")
def s2t_conformer_m(args):
args.encoder_embed_dim = getattr(args, "encoder_embed_dim", 512)
args.encoder_ffn_embed_dim = getattr(args, "encoder_ffn_embed_dim", 512 * 4)
args.encoder_attention_heads = getattr(args, "encoder_attention_heads", 8)
args.decoder_attention_heads = getattr(args, "decoder_attention_heads", 8)
args.dropout = getattr(args, "dropout", 0.15)
base_architecture(args)
@register_model_architecture("s2t_conformer", "s2t_conformer_mp")
def s2t_conformer_mp(args):
args.encoder_layers = getattr(args, "encoder_layers", 16)
s2t_conformer_m(args)
@register_model_architecture("s2t_conformer", "s2t_conformer_l")
def s2t_conformer_l(args):
args.encoder_embed_dim = getattr(args, "encoder_embed_dim", 1024)
args.encoder_ffn_embed_dim = getattr(args, "encoder_ffn_embed_dim", 1024 * 4)
args.encoder_attention_heads = getattr(args, "encoder_attention_heads", 16)
args.decoder_attention_heads = getattr(args, "decoder_attention_heads", 16)
args.dropout = getattr(args, "dropout", 0.2)
base_architecture(args)
@register_model_architecture("s2t_conformer", "s2t_conformer_lp")
def s2t_conformer_lp(args):
args.encoder_layers = getattr(args, "encoder_layers", 16)
s2t_conformer_l(args)
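
Note on the architecture registrations above (illustrative aside, not part of this commit): the sub-architectures compose through `getattr` defaults, so a variant such as s2t_conformer_s fills in its own values first, then defers to base_architecture, and anything already set on args (for example from the command line) always wins. A minimal standalone sketch of that pattern with plain argparse, using hypothetical helper names:

from argparse import Namespace

def toy_base_architecture(args):
    # fall back to base defaults only when a field is still unset
    args.encoder_embed_dim = getattr(args, "encoder_embed_dim", 512)
    args.encoder_layers = getattr(args, "encoder_layers", 12)
    args.macaron_style = getattr(args, "macaron_style", True)

def toy_s2t_conformer_s(args):
    # the "small" variant narrows the model, then defers to the base defaults
    args.encoder_embed_dim = getattr(args, "encoder_embed_dim", 256)
    toy_base_architecture(args)

args = Namespace(encoder_layers=6)   # pretend the user passed --encoder-layers 6
toy_s2t_conformer_s(args)
print(args.encoder_embed_dim, args.encoder_layers, args.macaron_style)  # 256 6 True
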
@@ -307,8 +307,9 @@ class S2TTransformerEncoder(FairseqEncoder):
             [int(k) for k in args.conv_kernel_sizes.split(",")],
         )
+        self.attn_type = getattr(args, "encoder_attention_type", "selfattn")
         self.embed_positions = PositionalEmbedding(
-            args.max_source_positions, args.encoder_embed_dim, self.padding_idx
+            args.max_source_positions, args.encoder_embed_dim, self.padding_idx, pos_emb_type=self.attn_type
         )

         self.transformer_layers = nn.ModuleList(
@@ -319,7 +320,6 @@ class S2TTransformerEncoder(FairseqEncoder):
         else:
             self.layer_norm = None

-        self.attn_type = getattr(args, "encoder_attention_type", "selfattn")
         self.use_ctc = ("ctc" in getattr(args, "criterion", False)) and \
             (getattr(args, "ctc_weight", False) > 0)
         if self.use_ctc:
@@ -348,6 +348,7 @@ class S2TTransformerEncoder(FairseqEncoder):
         if self.attn_type != "rel_selfattn":
             x += positions
         x = self.dropout_module(x)
+        positions = self.dropout_module(positions)

         for layer in self.transformer_layers:
             x = layer(x, encoder_padding_mask, pos_emb=positions)
...
@@ -8,6 +8,7 @@ from .adaptive_input import AdaptiveInput
 from .adaptive_softmax import AdaptiveSoftmax
 from .beamable_mm import BeamableMM
 from .character_token_embedder import CharacterTokenEmbedder
+from .convolution import ConvolutionModule
 from .conv_tbc import ConvTBC
 from .cross_entropy import cross_entropy
 from .downsampled_multihead_attention import DownsampledMultiHeadAttention
@@ -36,12 +37,15 @@ from .transpose_last import TransposeLast
 from .unfold import unfold1d
 from .transformer_layer import TransformerDecoderLayer, TransformerEncoderLayer
 from .vggblock import VGGBlock
+from .conformer_layer import ConformerEncoderLayer

 __all__ = [
     "AdaptiveInput",
     "AdaptiveSoftmax",
     "BeamableMM",
     "CharacterTokenEmbedder",
+    "ConformerEncoderLayer",
+    "ConvolutionModule",
     "ConvTBC",
     "cross_entropy",
     "DownsampledMultiHeadAttention",
...
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
from typing import Optional
import torch
import torch.nn as nn
from fairseq import utils
from fairseq.modules import LayerNorm, MultiheadAttention, RelPositionMultiheadAttention, ConvolutionModule
# from .layer_norm import LayerNorm
from fairseq.modules.fairseq_dropout import FairseqDropout
from fairseq.modules.quant_noise import quant_noise
from torch import Tensor
class ConformerEncoderLayer(nn.Module):
"""Encoder layer block.
In the original paper each operation (multi-head attention or FFN) is
postprocessed with: `dropout -> add residual -> layernorm`. In the
tensor2tensor code they suggest that learning is more robust when
preprocessing each layer with layernorm and postprocessing with:
`dropout -> add residual`. We default to the approach in the paper, but the
tensor2tensor approach can be enabled by setting
*args.encoder_normalize_before* to ``True``.
Args:
args (argparse.Namespace): parsed command-line arguments
"""
def __init__(self, args):
super().__init__()
self.args = args
self.embed_dim = args.encoder_embed_dim
self.quant_noise = getattr(args, 'quant_noise_pq', 0)
self.quant_noise_block_size = getattr(args, 'quant_noise_pq_block_size', 8) or 8
self.attn_type = getattr(args, "encoder_attention_type", "selfattn")
self.self_attn = self.build_self_attention(self.embed_dim, args)
self.self_attn_layer_norm = LayerNorm(self.embed_dim)
self.dropout_module = FairseqDropout(
args.dropout, module_name=self.__class__.__name__
)
self.activation_fn = utils.get_activation_fn(
activation=getattr(args, 'activation_fn', 'relu') or "relu"
)
activation_dropout_p = getattr(args, "activation_dropout", 0) or 0
if activation_dropout_p == 0:
# for backwards compatibility with models that use args.relu_dropout
activation_dropout_p = getattr(args, "relu_dropout", 0) or 0
self.activation_dropout_module = FairseqDropout(
float(activation_dropout_p), module_name=self.__class__.__name__
)
if args.macaron_style:
self.macaron_fc1 = self.build_fc1(
self.embed_dim,
args.encoder_ffn_embed_dim,
self.quant_noise,
self.quant_noise_block_size,
)
self.macaron_fc2 = self.build_fc2(
args.encoder_ffn_embed_dim,
self.embed_dim,
self.quant_noise,
self.quant_noise_block_size,
)
self.macaron_norm = LayerNorm(self.embed_dim)
self.ff_scale = 0.5
else:
self.macaron_fc1 = None
self.macaron_fc2 = None
self.macaron_norm = None
self.ff_scale = 1.0
        if args.use_cnn_module:
            self.norm_conv = LayerNorm(self.embed_dim)
            self.conv_module = ConvolutionModule(self.embed_dim, args.cnn_module_kernel, self.activation_fn)
            self.norm_final = LayerNorm(self.embed_dim)
        else:
            self.norm_conv = None
            self.conv_module = None
            self.norm_final = None
self.normalize_before = args.encoder_normalize_before
self.fc1 = self.build_fc1(
self.embed_dim,
args.encoder_ffn_embed_dim,
self.quant_noise,
self.quant_noise_block_size,
)
self.fc2 = self.build_fc2(
args.encoder_ffn_embed_dim,
self.embed_dim,
self.quant_noise,
self.quant_noise_block_size,
)
self.ff_norm = LayerNorm(self.embed_dim)
def build_fc1(self, input_dim, output_dim, q_noise, qn_block_size):
return quant_noise(
nn.Linear(input_dim, output_dim), p=q_noise, block_size=qn_block_size
)
def build_fc2(self, input_dim, output_dim, q_noise, qn_block_size):
return quant_noise(
nn.Linear(input_dim, output_dim), p=q_noise, block_size=qn_block_size
)
    def build_self_attention(self, embed_dim, args):
        if self.attn_type == "selfattn":
            attn_func = MultiheadAttention
        elif self.attn_type == "rel_selfattn":
            attn_func = RelPositionMultiheadAttention
        else:
            raise ValueError(f"The attention type {self.attn_type} is not supported!")
return attn_func(
embed_dim,
args.encoder_attention_heads,
dropout=args.attention_dropout,
self_attention=True,
q_noise=self.quant_noise,
qn_block_size=self.quant_noise_block_size,
)
def residual_connection(self, x, residual):
return residual + x
def upgrade_state_dict_named(self, state_dict, name):
"""
Rename layer norm states from `...layer_norms.0.weight` to
`...self_attn_layer_norm.weight` and `...layer_norms.1.weight` to
`...final_layer_norm.weight`
"""
layer_norm_map = {"0": "self_attn_layer_norm", "1": "final_layer_norm"}
for old, new in layer_norm_map.items():
for m in ("weight", "bias"):
k = "{}.layer_norms.{}.{}".format(name, old, m)
if k in state_dict:
state_dict["{}.{}.{}".format(name, new, m)] = state_dict[k]
del state_dict[k]
def forward(self, x,
encoder_padding_mask: Optional[Tensor],
attn_mask: Optional[Tensor] = None,
pos_emb: Optional[Tensor] = None):
"""
Args:
x (Tensor): input to the layer of shape `(seq_len, batch, embed_dim)`
encoder_padding_mask (ByteTensor): binary ByteTensor of shape
`(batch, seq_len)` where padding elements are indicated by ``1``.
attn_mask (ByteTensor): binary tensor of shape `(tgt_len, src_len)`,
where `tgt_len` is the length of output and `src_len` is the
length of input, though here both are equal to `seq_len`.
`attn_mask[tgt_i, src_j] = 1` means that when calculating the
embedding for `tgt_i`, we exclude (mask out) `src_j`. This is
useful for strided self-attention.
            pos_emb (Tensor): the position embedding for relative position encoding
Returns:
encoded output of shape `(seq_len, batch, embed_dim)`
"""
# anything in original attn_mask = 1, becomes -1e8
# anything in original attn_mask = 0, becomes 0
# Note that we cannot use -inf here, because at some edge cases,
# the attention weight (before softmax) for some padded element in query
# will become -inf, which results in NaN in model parameters
if attn_mask is not None:
attn_mask = attn_mask.masked_fill(attn_mask.to(torch.bool), -1e8)
# whether to use macaron style
if self.macaron_norm is not None:
residual = x
if self.normalize_before:
x = self.macaron_norm(x)
x = self.macaron_fc2(self.activation_dropout_module(self.activation_fn(self.macaron_fc1(x))))
x = residual + self.ff_scale * self.dropout_module(x)
if not self.normalize_before:
x = self.macaron_norm(x)
residual = x
if self.normalize_before:
x = self.self_attn_layer_norm(x)
if self.attn_type == "rel_selfattn":
assert pos_emb is not None, "Positions is necessary for RPE!"
x, _ = self.self_attn(
query=x,
key=x,
value=x,
key_padding_mask=encoder_padding_mask,
need_weights=False,
attn_mask=attn_mask,
pos_emb=pos_emb
)
else:
x, _ = self.self_attn(
query=x,
key=x,
value=x,
key_padding_mask=encoder_padding_mask,
need_weights=False,
attn_mask=attn_mask,
)
x = self.dropout_module(x)
x = self.residual_connection(x, residual)
if not self.normalize_before:
x = self.self_attn_layer_norm(x)
        # convolution module
        if self.conv_module is not None:
            residual = x
            if self.normalize_before:
                x = self.norm_conv(x)
            # ConvolutionModule operates on (batch, time, channels); x is (time, batch, channels)
            x = self.conv_module(x.transpose(0, 1)).transpose(0, 1)
            x = residual + self.dropout_module(x)
            if not self.normalize_before:
                x = self.norm_conv(x)
residual = x
if self.normalize_before:
x = self.ff_norm(x)
x = self.activation_fn(self.fc1(x))
x = self.activation_dropout_module(x)
x = self.fc2(x)
x = self.dropout_module(x)
x = self.residual_connection(x, residual)
if not self.normalize_before:
x = self.ff_norm(x)
if self.conv_module is not None:
x = self.norm_final(x)
return x
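
For orientation (illustrative aside, not part of this commit): the layer above follows the Conformer block ordering of half-step feed-forward, self-attention, convolution module, feed-forward, and a final LayerNorm, all in the fairseq T x B x C layout. A condensed toy sketch of that data flow in plain torch (dropout, relative-position attention, and padding masks omitted):

import torch
import torch.nn as nn

class TinyConformerBlock(nn.Module):
    # toy pre-norm block: 0.5 * FFN -> MHSA -> Conv -> 0.5 * FFN -> LayerNorm
    def __init__(self, d=64, heads=4, kernel=7, ffn=256):
        super().__init__()
        self.ffn1 = nn.Sequential(nn.LayerNorm(d), nn.Linear(d, ffn), nn.SiLU(), nn.Linear(ffn, d))
        self.norm_attn = nn.LayerNorm(d)
        self.attn = nn.MultiheadAttention(d, heads)
        self.norm_conv = nn.LayerNorm(d)
        self.conv = nn.Sequential(  # pointwise + GLU -> depthwise -> pointwise, along time
            nn.Conv1d(d, 2 * d, 1), nn.GLU(dim=1),
            nn.Conv1d(d, d, kernel, padding=kernel // 2, groups=d),
            nn.BatchNorm1d(d), nn.SiLU(), nn.Conv1d(d, d, 1),
        )
        self.ffn2 = nn.Sequential(nn.LayerNorm(d), nn.Linear(d, ffn), nn.SiLU(), nn.Linear(ffn, d))
        self.norm_out = nn.LayerNorm(d)

    def forward(self, x):  # x: (time, batch, d)
        x = x + 0.5 * self.ffn1(x)                          # macaron half-step FFN
        h = self.norm_attn(x)
        x = x + self.attn(h, h, h)[0]                       # self-attention
        c = self.conv(self.norm_conv(x).permute(1, 2, 0))   # Conv1d wants (batch, d, time)
        x = x + c.permute(2, 0, 1)
        x = x + 0.5 * self.ffn2(x)                          # second half-step FFN
        return self.norm_out(x)

x = torch.randn(20, 2, 64)                 # (time, batch, channels)
print(TinyConformerBlock()(x).shape)       # torch.Size([20, 2, 64])

(The paper scales both feed-forward branches by 0.5; the module in this commit applies the 0.5 scale only on the macaron branch and uses a plain residual for the second feed-forward.)
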
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
# Copyright 2020 Johns Hopkins University (Shinji Watanabe)
# Northwestern Polytechnical University (Pengcheng Guo)
# Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0)
"""ConvolutionModule definition."""
from torch import nn
class ConvolutionModule(nn.Module):
"""ConvolutionModule in Conformer model.
Args:
        channels (int): The number of channels of the conv layers.
        kernel_size (int): Kernel size of the depthwise conv layer.
"""
def __init__(self, channels, kernel_size, activation=nn.ReLU(), bias=True):
"""Construct an ConvolutionModule object."""
super(ConvolutionModule, self).__init__()
# kernerl_size should be a odd number for 'SAME' padding
assert (kernel_size - 1) % 2 == 0
self.pointwise_conv1 = nn.Conv1d(
channels,
2 * channels,
kernel_size=1,
stride=1,
padding=0,
bias=bias,
)
self.depthwise_conv = nn.Conv1d(
channels,
channels,
kernel_size,
stride=1,
padding=(kernel_size - 1) // 2,
groups=channels,
bias=bias,
)
self.norm = nn.BatchNorm1d(channels)
self.pointwise_conv2 = nn.Conv1d(
channels,
channels,
kernel_size=1,
stride=1,
padding=0,
bias=bias,
)
self.activation = activation
def forward(self, x):
"""Compute convolution module.
Args:
x (torch.Tensor): Input tensor (#batch, time, channels).
Returns:
torch.Tensor: Output tensor (#batch, time, channels).
"""
# exchange the temporal dimension and the feature dimension
x = x.transpose(1, 2)
# GLU mechanism
        x = self.pointwise_conv1(x)  # (batch, 2 * channels, time)
        x = nn.functional.glu(x, dim=1)  # (batch, channels, time)
# 1D Depthwise Conv
x = self.depthwise_conv(x)
x = self.activation(self.norm(x))
x = self.pointwise_conv2(x)
return x.transpose(1, 2)
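
Quick usage sketch for the module above (illustrative, not part of this commit; it assumes this commit is installed so that ConvolutionModule is exported from fairseq.modules as registered in __init__.py): input and output are both (batch, time, channels), and the (kernel_size - 1) // 2 padding of the depthwise convolution keeps the time dimension unchanged.

import torch
from fairseq.modules import ConvolutionModule

conv = ConvolutionModule(channels=256, kernel_size=31)
conv.eval()                    # put BatchNorm1d in eval mode for a deterministic single pass

x = torch.randn(8, 100, 256)   # (batch, time, channels)
y = conv(x)
print(y.shape)                 # torch.Size([8, 100, 256]); time is preserved by 'SAME' padding
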
@@ -6,7 +6,7 @@
 import torch.nn as nn

 from .learned_positional_embedding import LearnedPositionalEmbedding
-from .sinusoidal_positional_embedding import SinusoidalPositionalEmbedding
+from .sinusoidal_positional_embedding import SinusoidalPositionalEmbedding, RelPositionalEmbedding


 def PositionalEmbedding(
@@ -14,8 +14,9 @@ def PositionalEmbedding(
     embedding_dim: int,
     padding_idx: int,
     learned: bool = False,
+    pos_emb_type: str = None,
 ):
-    if learned:
+    if learned or pos_emb_type == "learned":
         # if padding_idx is specified then offset the embedding ids by
         # this index and adjust num_embeddings appropriately
         # TODO: The right place for this offset would be inside
@@ -26,6 +27,12 @@ def PositionalEmbedding(
         nn.init.normal_(m.weight, mean=0, std=embedding_dim ** -0.5)
         if padding_idx is not None:
             nn.init.constant_(m.weight[padding_idx], 0)
+    elif pos_emb_type is not None and pos_emb_type.startswith("debug"):
+        m = RelPositionalEmbedding(
+            embedding_dim,
+            padding_idx,
+            init_size=num_embeddings + padding_idx + 1,
+        )
     else:
         m = SinusoidalPositionalEmbedding(
             embedding_dim,
...
@@ -3,13 +3,10 @@
 # This source code is licensed under the MIT license found in the
 # LICENSE file in the root directory of this source tree.

-import math
 from typing import Dict, Optional, Tuple

 import torch
-import torch.nn.functional as F
 from fairseq import utils
-from fairseq.incremental_decoding_utils import with_incremental_state
 from fairseq.modules.multihead_attention import MultiheadAttention
 from fairseq.modules.quant_noise import quant_noise
 from torch import Tensor, nn
@@ -67,7 +64,6 @@ class RelPositionMultiheadAttention(MultiheadAttention):
         nn.init.xavier_normal_(self.pos_bias_u)
         nn.init.xavier_normal_(self.pos_bias_v)

     def forward(
         self,
         query,
@@ -108,41 +104,6 @@ class RelPositionMultiheadAttention(MultiheadAttention):
         assert embed_dim == self.embed_dim
         assert list(query.size()) == [tgt_len, bsz, embed_dim]

-        if (
-            False and
-            not self.onnx_trace
-            and not is_tpu  # don't use PyTorch version on TPUs
-            and incremental_state is None
-            and not static_kv
-            # A workaround for quantization to work. Otherwise JIT compilation
-            # treats bias in linear module as method.
-            and not torch.jit.is_scripting()
-        ):
-            assert key is not None and value is not None
-            return F.multi_head_attention_forward(
-                query,
-                key,
-                value,
-                self.embed_dim,
-                self.num_heads,
-                torch.empty([0]),
-                torch.cat((self.q_proj.bias, self.k_proj.bias, self.v_proj.bias)),
-                self.bias_k,
-                self.bias_v,
-                self.add_zero_attn,
-                self.dropout_module.p,
-                self.out_proj.weight,
-                self.out_proj.bias,
-                self.training or self.dropout_module.apply_during_inference,
-                key_padding_mask,
-                need_weights,
-                attn_mask,
-                use_separate_proj_weight=True,
-                q_proj_weight=self.q_proj.weight,
-                k_proj_weight=self.k_proj.weight,
-                v_proj_weight=self.v_proj.weight,
-            )
         if incremental_state is not None:
             saved_state = self._get_input_buffer(incremental_state)
             if saved_state is not None and "prev_key" in saved_state:
@@ -197,14 +158,19 @@ class RelPositionMultiheadAttention(MultiheadAttention):
         #     .view(tgt_len, bsz * self.num_heads, self.head_dim)
         #     .transpose(0, 1)
         # )
-        # prepare q for RPE  # (tgt_len, bsz, num_heads, head_dim)
-        q = q.contiguous().view(tgt_len, bsz, self.num_heads, self.head_dim)
+        # prepare q for RPE  # (bsz, tgt_len, num_heads, head_dim)
+        q = q.contiguous().view(tgt_len, bsz, self.num_heads, self.head_dim).transpose(0, 1)
+        # k (bsz * num_heads, tgt_len, head_dim)
         if k is not None:
             k = (
                 k.contiguous()
                 .view(-1, bsz * self.num_heads, self.head_dim)
                 .transpose(0, 1)
             )
+        # v (bsz * num_heads, tgt_len, head_dim)
         if v is not None:
             v = (
                 v.contiguous()
@@ -283,31 +249,32 @@ class RelPositionMultiheadAttention(MultiheadAttention):
             )
             pos_emb = pos_emb.transpose(0, 1)

-        p_rep = self.linear_pos(pos_emb).view(bsz, -1, self.num_heads, self.head_dim)
-        p_rep = p_rep.transpose(1, 2).contiguous().view(bsz * self.num_heads, -1, self.head_dim)
+        p = self.linear_pos(pos_emb).view(bsz, -1, self.num_heads, self.head_dim)
+        # p (bsz * num_heads, tgt_len, head_dim)
+        p = p.transpose(1, 2).contiguous().view(bsz * self.num_heads, -1, self.head_dim)

         # (batch * head, time1, d_k)
         q_with_bias_u = (
-            (q + self.pos_bias_u).contiguous()
-            .view(tgt_len, bsz * self.num_heads, self.head_dim)
-            .transpose(0, 1)
+            (q + self.pos_bias_u).transpose(1, 2)
+            .contiguous()
+            .view(bsz * self.num_heads, tgt_len, self.head_dim)
         )

         # (batch * head, time1, d_k)
         q_with_bias_v = (
-            (q + self.pos_bias_v).contiguous()
-            .view(tgt_len, bsz * self.num_heads, self.head_dim)
-            .transpose(0, 1)
+            (q + self.pos_bias_v).transpose(1, 2)
+            .contiguous()
+            .view(bsz * self.num_heads, tgt_len, self.head_dim)
         )

         # compute attention score
         # first compute matrix a and matrix c
         # as described in https://arxiv.org/abs/1901.02860 Section 3.3
         # (batch * head, time1, time2)
-        matrix_ac = torch.bmm(q_with_bias_u, k.transpose(1, 2))
+        matrix_ac = torch.matmul(q_with_bias_u, k.transpose(1, 2))

         # compute matrix b and matrix d
         # (batch * head, time1, time2)
-        matrix_bd = torch.bmm(q_with_bias_v, p_rep.transpose(1, 2))
+        matrix_bd = torch.matmul(q_with_bias_v, p.transpose(1, 2))

         def rel_shift(x, zero_triu=False):
             """Compute relative positional encoding.
@@ -315,36 +282,30 @@ class RelPositionMultiheadAttention(MultiheadAttention):
             Args:
                 x (torch.Tensor): Input tensor (batch, head, time1, 2*time1-1).
                     time1 means the length of query vector.
-                zero_triu (bool): If true, return the lower triangular part of
-                    the matrix.
             Returns:
                 torch.Tensor: Output tensor.
             """
-            zero_pad = torch.zeros((*x.size()[:3], 1), device=x.device, dtype=x.dtype)
+            zero_pad = torch.zeros((x.size()[0], x.size()[1], 1),
+                                   device=x.device,
+                                   dtype=x.dtype)
             x_padded = torch.cat([zero_pad, x], dim=-1)

-            x_padded = x_padded.view(*x.size()[:2], x.size(3) + 1, x.size(2))
-            x = x_padded[:, :, 1:].view_as(x)
+            x_padded = x_padded.view(x.size()[0],
+                                     x.size()[2] + 1, x.size()[1])
+            x = x_padded[:, 1:].view_as(x)
+            # zero_pad = torch.zeros((*x.size()[:3], 1), device=x.device, dtype=x.dtype)
+            # x_padded = torch.cat([zero_pad, x], dim=-1)
+            #
+            # x_padded = x_padded.view(*x.size()[:2], x.size(3) + 1, x.size(2))
+            # x = x_padded[:, :, 1:].view_as(x)[
+            #     :, :, :, : x.size(-1) // 2 + 1
+            # ]  # only keep the positions from 0 to time2

             if zero_triu:
-                ones = torch.ones((x.size(2), x.size(3)), device=x.device)
-                x = x * torch.tril(ones, x.size(3) - x.size(2))[None, None, :, :]
+                ones = torch.ones((x.size(1), x.size(2)), device=x.device)
+                x = x * torch.tril(ones, x.size(2) - x.size(1))[None, :, :]

             return x

-        # matrix_bd = matrix_bd.contiguous().view(bsz, self.num_heads, matrix_bd.size(-2), matrix_bd.size(-1))
-        # matrix_bd = rel_shift(
-        #     matrix_bd,
-        # ).contiguous().view(bsz * self.num_heads, matrix_bd.size(-2), matrix_bd.size(-1))
+        matrix_bd = rel_shift(matrix_bd)

         attn_weights = (matrix_ac + matrix_bd) * self.scaling
-        # attn_weights = torch.bmm(q, k.transpose(1, 2))
         attn_weights = self.apply_sparse_mask(attn_weights, tgt_len, src_len, bsz)

         assert list(attn_weights.size()) == [bsz * self.num_heads, tgt_len, src_len]
...
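
Reading aid for the rel_shift helper above (illustrative, not part of this commit): the zero-pad-and-reshape trick from Transformer-XL (https://arxiv.org/abs/1901.02860) shifts row i of the score matrix left by (time1 - 1 - i) slots, which is how the absolute-position axis of matrix_bd is realigned to relative offsets. A standalone demonstration on a tiny tensor, mirroring the 3-D variant used here:

import torch

def rel_shift(x):
    # x: (batch, time1, time2), as in the modified attention above
    zero_pad = torch.zeros((x.size(0), x.size(1), 1), dtype=x.dtype)
    x_padded = torch.cat([zero_pad, x], dim=-1)              # (batch, time1, time2 + 1)
    x_padded = x_padded.view(x.size(0), x.size(2) + 1, x.size(1))
    return x_padded[:, 1:].view_as(x)

x = torch.arange(9.0).view(1, 3, 3)    # rows: query positions, columns: position index
print(x[0])             # [[0., 1., 2.], [3., 4., 5.], [6., 7., 8.]]
print(rel_shift(x)[0])  # [[2., 0., 3.], [4., 5., 0.], [6., 7., 8.]]
# row 0 is shifted left by 2, row 1 by 1, row 2 by 0; the values pulled in from the
# padding / following row land beyond the valid relative range
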
@@ -103,3 +103,37 @@ class SinusoidalPositionalEmbedding(nn.Module):
             .view(bsz, seq_len, -1)
             .detach()
         )
+
+
+class RelPositionalEmbedding(SinusoidalPositionalEmbedding):
+    """Relative positional encoding module.
+    See: Appendix B in https://arxiv.org/abs/1901.02860
+    Args:
+        embedding_dim (int): Embedding dimension.
+        padding_idx (int): Padding index.
+        init_size (int): Maximum input length.
+    """
+
+    def __init__(self, embedding_dim, padding_idx, init_size=1024):
+        super().__init__(embedding_dim, padding_idx, init_size)
+        self.max_size = init_size
+
+    def forward(
+        self,
+        input,
+        incremental_state: Optional[Any] = None,
+        timestep: Optional[Tensor] = None,
+        positions: Optional[Any] = None,
+        offset: int = 0,
+    ):
+        """Compute positional encoding.
+        Args:
+            input (torch.Tensor): Input tensor (batch, time, `*`).
+        Returns:
+            torch.Tensor: Positional embedding tensor (1, time, `*`).
+        """
+        assert offset + input.size(1) < self.max_size
+        self.weights = self.weights.to(input.device)
+        # slice the first `time` rows of the sinusoidal table: (1, time, embedding_dim)
+        pos_emb = self.weights[offset : offset + input.size(1)].unsqueeze(0)
+        return pos_emb
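
For readers unfamiliar with the contract of RelPositionalEmbedding above (illustrative, not part of this commit): it reuses the parent's precomputed sinusoidal table and slices out its first time rows, returning a (1, time, embedding_dim) tensor shared across the batch. A dependency-free sketch of the same idea (the table construction below is the standard sinusoidal formula; names are illustrative):

import math
import torch

def sinusoidal_table(num_positions, dim):
    # fixed sinusoidal embeddings, shape (num_positions, dim); dim assumed even
    half = dim // 2
    freqs = torch.exp(torch.arange(half, dtype=torch.float) * -(math.log(10000.0) / (half - 1)))
    angles = torch.arange(num_positions, dtype=torch.float).unsqueeze(1) * freqs.unsqueeze(0)
    return torch.cat([torch.sin(angles), torch.cos(angles)], dim=1)

weights = sinusoidal_table(1024, 256)                   # analogous to self.weights in the parent
padding_mask = torch.zeros(4, 37, dtype=torch.bool)     # (batch, time) input of forward()
pos_emb = weights[: padding_mask.size(1)].unsqueeze(0)  # (1, time, dim)
print(pos_emb.shape)                                    # torch.Size([1, 37, 256])
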