Commit f0e3290f by xuchen

add the traditional relative multihead attention

parent a3e0f4c2
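The "traditional" relative multihead attention added here is the Shaw et al. (2018) formulation: pairwise query-key distances are clamped to a window of ±max_relative_length and mapped to learned embeddings that are added into the attention logits. With k_only=True, only the keys get relative embeddings (no separate table for the values). The snippet below is a minimal single-head sketch of that idea, not the implementation introduced by this commit; all names in it are illustrative.

```python
import torch
import torch.nn as nn


def relative_positions(length: int, k: int) -> torch.Tensor:
    """Clamped relative distances: entry (i, j) = clip(j - i, -k, k) + k."""
    pos = torch.arange(length)
    rel = pos[None, :] - pos[:, None]           # (length, length)
    return rel.clamp(-k, k) + k                 # shifted to [0, 2k] for embedding lookup


# Toy single-head example with the window used by the new presets (max_relative_length = 20).
k, d_head, seq_len = 20, 64, 10
rel_key_emb = nn.Embedding(2 * k + 1, d_head)   # 2k + 1 learned relative key embeddings

q = torch.randn(seq_len, d_head)
key = torch.randn(seq_len, d_head)
a_k = rel_key_emb(relative_positions(seq_len, k))    # (len, len, d_head)

# Content term plus the relative term on the keys only (the k_only variant).
scores = (q @ key.T + torch.einsum("id,ijd->ij", q, a_k)) / d_head ** 0.5
attn = scores.softmax(dim=-1)
```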
@@ -182,6 +182,13 @@ def s2t_conformer_s(args):
     base_architecture(args)


+@register_model_architecture("s2t_conformer", "s2t_conformer_s_relative")
+def s2t_conformer_s_relative(args):
+    args.max_relative_length = 20
+    args.k_only = True
+    s2t_conformer_s(args)


 @register_model_architecture("s2t_conformer", "s2t_conformer_xs")
 def s2t_conformer_xs(args):
     args.encoder_layers = getattr(args, "encoder_layers", 6)
...
@@ -357,6 +357,13 @@ def s2t_sate_s(args):
     base_architecture(args)


+@register_model_architecture("s2t_sate", "s2t_sate_s_relative")
+def s2t_sate_s_relative(args):
+    args.max_relative_length = 20
+    args.k_only = True
+    s2t_sate_s(args)


 @register_model_architecture("s2t_sate", "s2t_sate_xs")
 def s2t_sate_xs(args):
     args.encoder_layers = getattr(args, "encoder_layers", 6)
...
@@ -217,6 +217,10 @@ class S2TTransformerModel(FairseqEncoderDecoderModel):
             action="store_true",
             help="if True, dont scale embeddings",
         )
+        parser.add_argument('--max-relative-length', type=int, default=-1,
+                            help='the max relative length')
+        parser.add_argument('--k-only', default=False, action='store_true',
+                            help='select the relative mode to map relative position information')
         parser.add_argument(
             "--load-pretrained-encoder-from",
             type=str,
@@ -518,6 +522,9 @@ def base_architecture(args):
     args.no_scale_embedding = getattr(args, "no_scale_embedding", False)
     args.quant_noise_pq = getattr(args, "quant_noise_pq", 0)
+    args.max_relative_length = getattr(args, 'max_relative_length', -1)
+    args.k_only = getattr(args, 'k_only', True)


 @register_model_architecture("s2t_transformer", "s2t_transformer_s")
 def s2t_transformer_s(args):
@@ -529,6 +536,13 @@ def s2t_transformer_s(args):
     base_architecture(args)


+@register_model_architecture("s2t_transformer", "s2t_transformer_s_relative")
+def s2t_transformer_s_relative(args):
+    args.max_relative_length = 20
+    args.k_only = True
+    s2t_transformer_s(args)


 @register_model_architecture("s2t_transformer", "s2t_transformer_xs")
 def s2t_transformer_xs(args):
     args.encoder_layers = getattr(args, "encoder_layers", 6)
...
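One interaction worth noting in the options added above: --k-only is a store_true flag with default False, so when arguments come through the parser the getattr(args, 'k_only', True) fallback in base_architecture never fires; it only applies to args namespaces built without the parser. A small demonstration with a throwaway parser that mirrors the new flags:

```python
# Hypothetical stand-alone parser, mirroring only the two flags added above.
import argparse

parser = argparse.ArgumentParser()
parser.add_argument('--max-relative-length', type=int, default=-1)
parser.add_argument('--k-only', default=False, action='store_true')

args = parser.parse_args([])                    # no flags given
assert args.max_relative_length == -1           # relative attention stays disabled
assert args.k_only is False                     # parser default wins; the getattr fallback is unused

args = parser.parse_args(['--max-relative-length', '20', '--k-only'])
assert (args.max_relative_length, args.k_only) == (20, True)   # what the *_relative presets set
```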
@@ -5,10 +5,11 @@
 import math
 from typing import Any, Dict, List, Optional, Tuple
+import logging

 import torch
 import torch.nn as nn
-from fairseq import utils
+from fairseq import checkpoint_utils, utils
 from fairseq.distributed import fsdp_wrap
 from fairseq.models import (
     FairseqEncoder,
@@ -35,6 +36,8 @@ from torch import Tensor
 DEFAULT_MAX_SOURCE_POSITIONS = 1024
 DEFAULT_MAX_TARGET_POSITIONS = 1024

+logger = logging.getLogger(__name__)


 @register_model("transformer")
 class TransformerModel(FairseqEncoderDecoderModel):
@@ -191,6 +194,35 @@ class TransformerModel(FairseqEncoderDecoderModel):
                             help='block size of quantization noise at training time')
         parser.add_argument('--quant-noise-scalar', type=float, metavar='D', default=0,
                             help='scalar quantization noise and scalar quantization at training time')
+        parser.add_argument('--max-relative-length', type=int, default=-1,
+                            help='the max relative length')
+        parser.add_argument('--k-only', default=False, action='store_true',
+                            help='select the relative mode to map relative position information')
+        # args for loading pre-trained models
+        parser.add_argument(
+            "--load-pretrained-encoder-from",
+            type=str,
+            metavar="STR",
+            help="model to take encoder weights from (for initialization)",
+        )
+        parser.add_argument(
+            "--load-pretrained-decoder-from",
+            type=str,
+            metavar="STR",
+            help="model to take decoder weights from (for initialization)",
+        )
+        parser.add_argument(
+            "--encoder-freeze-module",
+            type=str,
+            metavar="STR",
+            help="freeze the module of the encoder",
+        )
+        parser.add_argument(
+            "--decoder-freeze-module",
+            type=str,
+            metavar="STR",
+            help="freeze the module of the decoder",
+        )
         # fmt: on

     @classmethod
@@ -240,7 +272,15 @@ class TransformerModel(FairseqEncoderDecoderModel):
         if getattr(args, "offload_activations", False):
             args.checkpoint_activations = True  # offloading implies checkpointing
         encoder = cls.build_encoder(args, src_dict, encoder_embed_tokens)
+        if getattr(args, "encoder_freeze_module", None):
+            utils.freeze_parameters(encoder, args.encoder_freeze_module)
+            logging.info("freeze the encoder module: {}".format(args.encoder_freeze_module))
+
         decoder = cls.build_decoder(args, tgt_dict, decoder_embed_tokens)
+        if getattr(args, "decoder_freeze_module", None):
+            utils.freeze_parameters(decoder, args.decoder_freeze_module)
+            logging.info("freeze the decoder module: {}".format(args.decoder_freeze_module))
+
         if not args.share_all_embeddings:
             encoder = fsdp_wrap(encoder, min_num_params=1e8)
             decoder = fsdp_wrap(decoder, min_num_params=1e8)
@@ -260,17 +300,38 @@ class TransformerModel(FairseqEncoderDecoderModel):
     @classmethod
     def build_encoder(cls, args, src_dict, embed_tokens):
-        return TransformerEncoder(args, src_dict, embed_tokens)
+        encoder = TransformerEncoder(args, src_dict, embed_tokens)
+        if getattr(args, "load_pretrained_encoder_from", None):
+            logger.info(
+                f"loaded pretrained encoder from: "
+                f"{args.load_pretrained_encoder_from}"
+            )
+            encoder = checkpoint_utils.load_pretrained_component_from_model(
+                component=encoder, checkpoint=args.load_pretrained_encoder_from, strict=False
+            )
+        return encoder

     @classmethod
     def build_decoder(cls, args, tgt_dict, embed_tokens):
-        return TransformerDecoder(
+        decoder = TransformerDecoder(
             args,
             tgt_dict,
             embed_tokens,
             no_encoder_attn=getattr(args, "no_cross_attention", False),
         )
+        if getattr(args, "load_pretrained_decoder_from", None):
+            logger.info(
+                f"loaded pretrained decoder from: "
+                f"{args.load_pretrained_decoder_from}"
+            )
+            decoder = checkpoint_utils.load_pretrained_component_from_model(
+                component=decoder, checkpoint=args.load_pretrained_decoder_from, strict=False
+            )
+        return decoder

     # TorchScript doesn't support optional arguments with variable length (**kwargs).
     # Current workaround is to add union of all arguments in child classes.
     def forward(
@@ -1073,6 +1134,15 @@ def base_architecture(args):
     args.quant_noise_pq = getattr(args, "quant_noise_pq", 0)
     args.quant_noise_pq_block_size = getattr(args, "quant_noise_pq_block_size", 8)
     args.quant_noise_scalar = getattr(args, "quant_noise_scalar", 0)
+    args.max_relative_length = getattr(args, 'max_relative_length', -1)
+    args.k_only = getattr(args, 'k_only', True)


+@register_model_architecture("transformer", "transformer_relative")
+def transformer_rpr(args):
+    args.max_relative_length = 20
+    args.k_only = True
+    base_architecture(args)


 @register_model_architecture("transformer", "transformer_iwslt_de_en")
...
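Besides the relative-attention options, transformer.py also gains --encoder-freeze-module / --decoder-freeze-module and --load-pretrained-{encoder,decoder}-from. utils.freeze_parameters is a helper from this fork rather than upstream fairseq; the sketch below shows the behavior the calls above appear to rely on, assuming it matches parameters by name prefix and disables their gradients.

```python
import torch.nn as nn


def freeze_parameters(module: nn.Module, freeze_prefix: str) -> None:
    """Sketch only: turn off gradients for every parameter whose name starts with the prefix.

    The fork's utils.freeze_parameters may differ in matching rules and error handling.
    """
    for name, param in module.named_parameters():
        if name.startswith(freeze_prefix):
            param.requires_grad = False


# e.g. --encoder-freeze-module embed_tokens would leave the encoder embeddings untrained.
```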
@@ -21,6 +21,7 @@ from .grad_multiply import GradMultiply
 from .gumbel_vector_quantizer import GumbelVectorQuantizer
 from .kmeans_vector_quantizer import KmeansVectorQuantizer
 from .layer_drop import LayerDropModuleList
+from .layer_history import CreateLayerHistory
 from .layer_norm import Fp32LayerNorm, LayerNorm
 from .learned_positional_embedding import LearnedPositionalEmbedding
 from .lightweight_convolution import LightweightConv, LightweightConv1dTBC
@@ -28,6 +29,7 @@ from .linearized_convolution import LinearizedConvolution
 from .multihead_attention import MultiheadAttention
 from .positional_embedding import PositionalEmbedding
 from .rel_position_multihead_attention import RelPositionMultiheadAttention
+from .relative_multihead_attention import RelativeMultiheadAttention
 from .same_pad import SamePad
 from .scalar_bias import ScalarBias
 from .sinusoidal_positional_embedding import SinusoidalPositionalEmbedding
@@ -47,6 +49,7 @@ __all__ = [
     "ConformerEncoderLayer",
     "ConvolutionModule",
     "ConvTBC",
+    "CreateLayerHistory",
     "cross_entropy",
     "DownsampledMultiHeadAttention",
     "DynamicConv1dTBC",
@@ -69,6 +72,7 @@ __all__ = [
     "MultiheadAttention",
     "PositionalEmbedding",
     "RelPositionMultiheadAttention",
+    "RelativeMultiheadAttention",
     "SamePad",
     "ScalarBias",
     "SinusoidalPositionalEmbedding",
...
@@ -8,7 +8,12 @@ from typing import Dict, List, Optional
 import torch
 import torch.nn as nn
 from fairseq import utils
-from fairseq.modules import LayerNorm, MultiheadAttention, RelPositionMultiheadAttention
+from fairseq.modules import (
+    LayerNorm,
+    MultiheadAttention,
+    RelPositionMultiheadAttention,
+    RelativeMultiheadAttention
+)
 from fairseq.modules.fairseq_dropout import FairseqDropout
 from fairseq.modules.quant_noise import quant_noise
 from torch import Tensor
@@ -82,6 +87,16 @@ class TransformerEncoderLayer(nn.Module):
             attn_func = MultiheadAttention
         elif self.attn_type == "rel_selfattn":
             attn_func = RelPositionMultiheadAttention
+        elif self.attn_type == "relative" or getattr(args, "max_relative_length", -1) != -1:
+            return RelativeMultiheadAttention(
+                embed_dim,
+                args.encoder_attention_heads,
+                dropout=args.attention_dropout,
+                self_attention=True,
+                q_noise=self.quant_noise,
+                qn_block_size=self.quant_noise_block_size,
+                max_relative_length=args.max_relative_length,
+            )
         else:
             print("The attention type %s is not supported!" % self.attn_type)
             exit(1)
@@ -277,6 +292,16 @@ class TransformerDecoderLayer(nn.Module):
             attn_func = MultiheadAttention
         elif self.attn_type == "rel_selfattn":
             attn_func = RelPositionMultiheadAttention
+        elif self.attn_type == "relative" or getattr(args, "max_relative_length", -1) != -1:
+            return RelativeMultiheadAttention(
+                embed_dim,
+                args.decoder_attention_heads,
+                dropout=args.attention_dropout,
+                self_attention=True,
+                q_noise=self.quant_noise,
+                qn_block_size=self.quant_noise_block_size,
+                max_relative_length=args.max_relative_length,
+            )
         else:
             print("The attention type %s is not supported!" % self.attn_type)
             exit(1)
...
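For reference, the new branch constructs the module with the same knobs as the vanilla path plus the window size. A hypothetical stand-alone construction with toy dimensions (assumes this fork is installed; inside the layer these values come from args and self):

```python
from fairseq.modules import RelativeMultiheadAttention

# Mirrors the call in the encoder layer above, with toy values in place of args.*.
self_attn = RelativeMultiheadAttention(
    256,                      # embed_dim
    4,                        # args.encoder_attention_heads
    dropout=0.1,              # args.attention_dropout
    self_attention=True,
    q_noise=0.0,              # self.quant_noise
    qn_block_size=8,          # self.quant_noise_block_size
    max_relative_length=20,   # args.max_relative_length
)
```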