Commit f0e3290f by xuchen

Add the traditional relative multi-head attention

parent a3e0f4c2
@@ -182,6 +182,13 @@ def s2t_conformer_s(args):
base_architecture(args)
@register_model_architecture("s2t_conformer", "s2t_conformer_s_relative")
def s2t_conformer_s_relative(args):
args.max_relative_length = 20
args.k_only = True
s2t_conformer_s(args)
@register_model_architecture("s2t_conformer", "s2t_conformer_xs")
def s2t_conformer_xs(args):
args.encoder_layers = getattr(args, "encoder_layers", 6)
......
@@ -357,6 +357,13 @@ def s2t_sate_s(args):
base_architecture(args)
@register_model_architecture("s2t_sate", "s2t_sate_s_relative")
def s2t_sate_s_relative(args):
args.max_relative_length = 20
args.k_only = True
s2t_sate_s(args)
@register_model_architecture("s2t_sate", "s2t_sate_xs")
def s2t_sate_xs(args):
args.encoder_layers = getattr(args, "encoder_layers", 6)
......
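Note (not part of the diff): the new *_relative architecture aliases above assign max_relative_length and k_only unconditionally before delegating to the small configuration, whereas base_architecture (further down in this commit) only fills those attributes via getattr when they are absent. A minimal sketch of that layering, using a bare namespace in place of parsed arguments and delegating straight to base_architecture for brevity:

from types import SimpleNamespace

def base_architecture(args):
    # fills a default only when the attribute is absent
    args.max_relative_length = getattr(args, "max_relative_length", -1)
    args.k_only = getattr(args, "k_only", True)

def s2t_conformer_s_relative(args):
    # the alias presets relative attention unconditionally, then delegates
    args.max_relative_length = 20
    args.k_only = True
    base_architecture(args)

args = SimpleNamespace()
s2t_conformer_s_relative(args)
assert (args.max_relative_length, args.k_only) == (20, True)

With the plain base architectures, max_relative_length stays at -1 and relative attention remains disabled.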
@@ -217,6 +217,10 @@ class S2TTransformerModel(FairseqEncoderDecoderModel):
action="store_true",
help="if True, don't scale embeddings",
)
parser.add_argument('--max-relative-length', type=int, default=-1,
help='maximum relative position distance (-1 to disable relative attention)')
parser.add_argument('--k-only', default=False, action='store_true',
help='apply relative position representations to keys only (not to values)')
parser.add_argument(
"--load-pretrained-encoder-from",
type=str,
@@ -518,6 +522,9 @@ def base_architecture(args):
args.no_scale_embedding = getattr(args, "no_scale_embedding", False)
args.quant_noise_pq = getattr(args, "quant_noise_pq", 0)
args.max_relative_length = getattr(args, 'max_relative_length', -1)
args.k_only = getattr(args, 'k_only', True)
@register_model_architecture("s2t_transformer", "s2t_transformer_s")
def s2t_transformer_s(args):
@@ -529,6 +536,13 @@ def s2t_transformer_s(args):
base_architecture(args)
@register_model_architecture("s2t_transformer", "s2t_transformer_s_relative")
def s2t_transformer_s_relative(args):
args.max_relative_length = 20
args.k_only = True
s2t_transformer_s(args)
@register_model_architecture("s2t_transformer", "s2t_transformer_xs")
def s2t_transformer_xs(args):
args.encoder_layers = getattr(args, "encoder_layers", 6)
......
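For reference (an illustration of the technique, not code from this commit): --max-relative-length bounds how far apart two positions may be before their relative offset is clipped. A rough sketch of the usual mapping from pairwise offsets to indices into a table of 2 * max_relative_length + 1 learned embeddings:

import torch

def relative_position_index(length: int, max_relative_length: int) -> torch.Tensor:
    """Map pairwise offsets (j - i) to indices in [0, 2 * max_relative_length]."""
    pos = torch.arange(length)
    rel = pos[None, :] - pos[:, None]                           # offsets j - i
    rel = rel.clamp(-max_relative_length, max_relative_length)  # clip long-range offsets
    return rel + max_relative_length                            # shift to non-negative indices

print(relative_position_index(5, 2))
# tensor([[2, 3, 4, 4, 4],
#         [1, 2, 3, 4, 4],
#         [0, 1, 2, 3, 4],
#         [0, 0, 1, 2, 3],
#         [0, 0, 0, 1, 2]])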
@@ -5,10 +5,11 @@
import math
from typing import Any, Dict, List, Optional, Tuple
import logging
import torch
import torch.nn as nn
from fairseq import utils
from fairseq import checkpoint_utils, utils
from fairseq.distributed import fsdp_wrap
from fairseq.models import (
FairseqEncoder,
@@ -35,6 +36,8 @@ from torch import Tensor
DEFAULT_MAX_SOURCE_POSITIONS = 1024
DEFAULT_MAX_TARGET_POSITIONS = 1024
logger = logging.getLogger(__name__)
@register_model("transformer")
class TransformerModel(FairseqEncoderDecoderModel):
@@ -191,6 +194,35 @@ class TransformerModel(FairseqEncoderDecoderModel):
help='block size of quantization noise at training time')
parser.add_argument('--quant-noise-scalar', type=float, metavar='D', default=0,
help='scalar quantization noise and scalar quantization at training time')
parser.add_argument('--max-relative-length', type=int, default=-1,
help='maximum relative position distance (-1 to disable relative attention)')
parser.add_argument('--k-only', default=False, action='store_true',
help='apply relative position representations to keys only (not to values)')
# args for loading pre-trained models
parser.add_argument(
"--load-pretrained-encoder-from",
type=str,
metavar="STR",
help="model to take encoder weights from (for initialization)",
)
parser.add_argument(
"--load-pretrained-decoder-from",
type=str,
metavar="STR",
help="model to take decoder weights from (for initialization)",
)
parser.add_argument(
"--encoder-freeze-module",
type=str,
metavar="STR",
help="freeze the module of the encoder",
)
parser.add_argument(
"--decoder-freeze-module",
type=str,
metavar="STR",
help="freeze the module of the decoder",
)
# fmt: on
@classmethod
@@ -240,7 +272,15 @@ class TransformerModel(FairseqEncoderDecoderModel):
if getattr(args, "offload_activations", False):
args.checkpoint_activations = True # offloading implies checkpointing
encoder = cls.build_encoder(args, src_dict, encoder_embed_tokens)
if getattr(args, "encoder_freeze_module", None):
utils.freeze_parameters(encoder, args.encoder_freeze_module)
logger.info("freezing encoder module: {}".format(args.encoder_freeze_module))
decoder = cls.build_decoder(args, tgt_dict, decoder_embed_tokens)
if getattr(args, "decoder_freeze_module", None):
utils.freeze_parameters(decoder, args.decoder_freeze_module)
logger.info("freezing decoder module: {}".format(args.decoder_freeze_module))
if not args.share_all_embeddings:
encoder = fsdp_wrap(encoder, min_num_params=1e8)
decoder = fsdp_wrap(decoder, min_num_params=1e8)
@@ -260,17 +300,38 @@ class TransformerModel(FairseqEncoderDecoderModel):
@classmethod
def build_encoder(cls, args, src_dict, embed_tokens):
return TransformerEncoder(args, src_dict, embed_tokens)
encoder = TransformerEncoder(args, src_dict, embed_tokens)
if getattr(args, "load_pretrained_encoder_from", None):
logger.info(
f"loading pretrained encoder from: "
f"{args.load_pretrained_encoder_from}"
)
encoder = checkpoint_utils.load_pretrained_component_from_model(
component=encoder, checkpoint=args.load_pretrained_encoder_from, strict=False
)
return encoder
@classmethod
def build_decoder(cls, args, tgt_dict, embed_tokens):
return TransformerDecoder(
decoder = TransformerDecoder(
args,
tgt_dict,
embed_tokens,
no_encoder_attn=getattr(args, "no_cross_attention", False),
)
if getattr(args, "load_pretrained_decoder_from", None):
logger.info(
f"loading pretrained decoder from: "
f"{args.load_pretrained_decoder_from}"
)
decoder = checkpoint_utils.load_pretrained_component_from_model(
component=decoder, checkpoint=args.load_pretrained_decoder_from, strict=False
)
return decoder
# TorchScript doesn't support optional arguments with variable length (**kwargs).
# Current workaround is to add union of all arguments in child classes.
def forward(
@@ -1073,6 +1134,15 @@ def base_architecture(args):
args.quant_noise_pq = getattr(args, "quant_noise_pq", 0)
args.quant_noise_pq_block_size = getattr(args, "quant_noise_pq_block_size", 8)
args.quant_noise_scalar = getattr(args, "quant_noise_scalar", 0)
args.max_relative_length = getattr(args, 'max_relative_length', -1)
args.k_only = getattr(args, 'k_only', True)
@register_model_architecture("transformer", "transformer_relative")
def transformer_rpr(args):
args.max_relative_length = 20
args.k_only = True
base_architecture(args)
@register_model_architecture("transformer", "transformer_iwslt_de_en")
......
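Side note: utils.freeze_parameters, which the new --encoder-freeze-module / --decoder-freeze-module handling relies on, is not shown in this diff. A plausible minimal sketch of such a helper, stated as an assumption about its behavior rather than the project's actual implementation, would match parameter names against the given prefix and turn off their gradients:

import torch.nn as nn

def freeze_parameters(module: nn.Module, freeze_name: str) -> None:
    # Assumed behavior: freeze every parameter whose dotted name starts with
    # the requested prefix, e.g. "embed_tokens" or "layers.0".
    for name, param in module.named_parameters():
        if name.startswith(freeze_name):
            param.requires_grad = False

# Toy check
layer = nn.Linear(4, 4)
freeze_parameters(layer, "weight")
assert not layer.weight.requires_grad and layer.bias.requires_grad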
@@ -21,6 +21,7 @@ from .grad_multiply import GradMultiply
from .gumbel_vector_quantizer import GumbelVectorQuantizer
from .kmeans_vector_quantizer import KmeansVectorQuantizer
from .layer_drop import LayerDropModuleList
from .layer_history import CreateLayerHistory
from .layer_norm import Fp32LayerNorm, LayerNorm
from .learned_positional_embedding import LearnedPositionalEmbedding
from .lightweight_convolution import LightweightConv, LightweightConv1dTBC
@@ -28,6 +29,7 @@ from .linearized_convolution import LinearizedConvolution
from .multihead_attention import MultiheadAttention
from .positional_embedding import PositionalEmbedding
from .rel_position_multihead_attention import RelPositionMultiheadAttention
from .relative_multihead_attention import RelativeMultiheadAttention
from .same_pad import SamePad
from .scalar_bias import ScalarBias
from .sinusoidal_positional_embedding import SinusoidalPositionalEmbedding
@@ -47,6 +49,7 @@ __all__ = [
"ConformerEncoderLayer",
"ConvolutionModule",
"ConvTBC",
"CreateLayerHistory",
"cross_entropy",
"DownsampledMultiHeadAttention",
"DynamicConv1dTBC",
@@ -69,6 +72,7 @@ __all__ = [
"MultiheadAttention",
"PositionalEmbedding",
"RelPositionMultiheadAttention",
"RelativeMultiheadAttention",
"SamePad",
"ScalarBias",
"SinusoidalPositionalEmbedding",
......
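For orientation, the newly exported RelativeMultiheadAttention implements the "traditional" relative position scheme usually attributed to Shaw et al. (2018). A self-contained sketch of the key-only ("k_only") variant of that computation follows; it is illustrative code, not the module's implementation, and the real class's scaling, masking, and projection details may differ:

import math
import torch
import torch.nn.functional as F

def relative_self_attention(q, k, v, rel_key_emb, max_relative_length):
    """
    q, k, v: (batch * heads, seq_len, head_dim)
    rel_key_emb: (2 * max_relative_length + 1, head_dim) learned table
    Returns the attention output of shape (batch * heads, seq_len, head_dim).
    """
    _, seq_len, head_dim = q.size()

    # Pairwise offsets clipped to [-k, k] and shifted into table indices.
    pos = torch.arange(seq_len)
    rel = (pos[None, :] - pos[:, None]).clamp(-max_relative_length, max_relative_length)
    rel_k = rel_key_emb[rel + max_relative_length]        # (seq_len, seq_len, head_dim)

    # Content term QK^T plus the relative-key term ("k_only": no relative values).
    logits = torch.bmm(q, k.transpose(1, 2))              # (B*H, T, T)
    logits = logits + torch.einsum("bid,ijd->bij", q, rel_k)
    attn = F.softmax(logits / math.sqrt(head_dim), dim=-1)
    return torch.bmm(attn, v)

# Toy shapes: 2 (batch * heads), 5 positions, head_dim 8, clipping distance 20.
q = k = v = torch.randn(2, 5, 8)
out = relative_self_attention(q, k, v, torch.randn(2 * 20 + 1, 8), 20)
print(out.shape)  # torch.Size([2, 5, 8])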
@@ -8,7 +8,12 @@ from typing import Dict, List, Optional
import torch
import torch.nn as nn
from fairseq import utils
from fairseq.modules import LayerNorm, MultiheadAttention, RelPositionMultiheadAttention
from fairseq.modules import (
LayerNorm,
MultiheadAttention,
RelPositionMultiheadAttention,
RelativeMultiheadAttention
)
from fairseq.modules.fairseq_dropout import FairseqDropout
from fairseq.modules.quant_noise import quant_noise
from torch import Tensor
@@ -82,6 +87,16 @@ class TransformerEncoderLayer(nn.Module):
attn_func = MultiheadAttention
elif self.attn_type == "rel_selfattn":
attn_func = RelPositionMultiheadAttention
elif self.attn_type == "relative" or getattr(args, "max_relative_length", -1) != -1:
return RelativeMultiheadAttention(
embed_dim,
args.encoder_attention_heads,
dropout=args.attention_dropout,
self_attention=True,
q_noise=self.quant_noise,
qn_block_size=self.quant_noise_block_size,
max_relative_length=args.max_relative_length,
)
else:
print("The attention type %s is not supported!" % self.attn_type)
exit(1)
@@ -277,6 +292,16 @@ class TransformerDecoderLayer(nn.Module):
attn_func = MultiheadAttention
elif self.attn_type == "rel_selfattn":
attn_func = RelPositionMultiheadAttention
elif self.attn_type == "relative" or getattr(args, "max_relative_length", -1) != -1:
return RelativeMultiheadAttention(
embed_dim,
args.decoder_attention_heads,
dropout=args.attention_dropout,
self_attention=True,
q_noise=self.quant_noise,
qn_block_size=self.quant_noise_block_size,
max_relative_length=args.max_relative_length,
)
else:
print("The attention type %s is not supported!" % self.attn_type)
exit(1)
......