# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

from dataclasses import dataclass, field
from typing import Optional

from fairseq import options, utils
from fairseq.dataclass import ChoiceEnum, FairseqDataclass
from fairseq.models import (
    FairseqLanguageModel,
    register_model,
    register_model_architecture,
)
from fairseq.models.transformer import Embedding, TransformerDecoder
from fairseq.modules import AdaptiveInput, CharacterTokenEmbedder
from omegaconf import II


# Fallback for ``max_target_positions`` when neither it nor
# ``tokens_per_sample`` is available on the args (see ``build_model``).
DEFAULT_MAX_TARGET_POSITIONS = 1024


@dataclass
class TransformerLanguageModelConfig(FairseqDataclass):
    """Options for the ``transformer_lm`` model.

    Each field carries its own ``help`` text in ``metadata``. The last four
    fields are not set on the model directly: they are OmegaConf
    interpolations (``II``) that resolve to values from the ``task`` and
    ``common`` config groups.
    """

    activation_fn: ChoiceEnum(utils.get_available_activation_fns()) = field(
        default="relu", metadata={"help": "activation function to use"}
    )
    dropout: float = field(default=0.1, metadata={"help": "dropout probability"})
    attention_dropout: float = field(
        default=0.0, metadata={"help": "dropout probability for attention weights"}
    )
    activation_dropout: float = field(
        default=0.0, metadata={"help": "dropout probability after activation in FFN."}
    )
    # NOTE(review): same help text as activation_dropout; presumably a legacy
    # alias of it -- confirm against the decoder layer implementation.
    relu_dropout: float = field(
        default=0.0, metadata={"help": "dropout probability after activation in FFN."}
    )
    decoder_embed_dim: int = field(
        default=512, metadata={"help": "decoder embedding dimension"}
    )
    decoder_output_dim: int = field(
        default=512, metadata={"help": "decoder output dimension"}
    )
    decoder_input_dim: int = field(
        default=512, metadata={"help": "decoder input dimension"}
    )
    decoder_ffn_embed_dim: int = field(
        default=2048, metadata={"help": "decoder embedding dimension for FFN"}
    )
    decoder_layers: int = field(default=6, metadata={"help": "num decoder layers"})
    decoder_attention_heads: int = field(
        default=8, metadata={"help": "num decoder attention heads"}
    )
    # NOTE: base_lm_architecture() unconditionally overrides this to True.
    decoder_normalize_before: bool = field(
        default=False, metadata={"help": "apply layernorm before each decoder block"}
    )
    no_decoder_final_norm: bool = field(
        default=False,
        metadata={"help": "don't add an extra layernorm after the last decoder block"},
    )
    adaptive_softmax_cutoff: Optional[str] = field(
        default=None,
        metadata={
            "help": "comma separated list of adaptive softmax cutoff points. "
            "Must be used with adaptive_loss criterion"
        },
    )
    adaptive_softmax_dropout: float = field(
        default=0,
        metadata={"help": "sets adaptive softmax dropout for the tail projections"},
    )
    # Help text previously read "adaptive input factor" (copy-paste from
    # adaptive_input_factor below); this option configures the softmax side.
    adaptive_softmax_factor: float = field(
        default=4, metadata={"help": "adaptive softmax factor"}
    )
    no_token_positional_embeddings: bool = field(
        default=False,
        metadata={
            "help": "if set, disables positional embeddings (outside self attention)"
        },
    )
    share_decoder_input_output_embed: bool = field(
        default=False, metadata={"help": "share decoder input and output embeddings"}
    )
    character_embeddings: bool = field(
        default=False,
        metadata={
            "help": "if set, uses character embedding convolutions to produce token embeddings"
        },
    )
    # The value is a string holding a Python expression; build_model()
    # evaluates it and hands the result to CharacterTokenEmbedder.
    # NOTE(review): the help text duplicates character_embedding_dim's and
    # does not describe this option -- confirm the intended wording.
    character_filters: str = field(
        default="[(1, 64), (2, 128), (3, 192), (4, 256), (5, 256), (6, 256), (7, 256)]",
        metadata={"help": "size of character embeddings"},
    )
    character_embedding_dim: int = field(
        default=4, metadata={"help": "size of character embeddings"}
    )
    char_embedder_highway_layers: int = field(
        default=2,
        metadata={"help": "number of highway layers for character token embeddder"},
    )
    adaptive_input: bool = field(
        default=False, metadata={"help": "if set, uses adaptive input"}
    )
    adaptive_input_factor: float = field(
        default=4, metadata={"help": "adaptive input factor"}
    )
    adaptive_input_cutoff: Optional[str] = field(
        default=None,
        metadata={"help": "comma separated list of adaptive input cutoff points."},
    )
    tie_adaptive_weights: bool = field(
        default=False,
        metadata={
            "help": "if set, ties the weights of adaptive softmax and adaptive input"
        },
    )
    tie_adaptive_proj: bool = field(
        default=False,
        metadata={
            "help": "if set, ties the projection weights of adaptive softmax and adaptive input"
        },
    )
    decoder_learned_pos: bool = field(
        default=False,
        metadata={"help": "use learned positional embeddings in the decoder"},
    )
    decoder_layerdrop: float = field(
        default=0.0, metadata={"help": "LayerDrop probability for decoder"}
    )
    decoder_layers_to_keep: Optional[str] = field(
        default=None,
        metadata={
            "help": "which layers to *keep* when pruning as a comma-separated list"
        },
    )
    layernorm_embedding: bool = field(
        default=False, metadata={"help": "add layernorm to embedding"}
    )
    no_scale_embedding: bool = field(
        default=False, metadata={"help": "if True, dont scale embeddings"}
    )
    checkpoint_activations: bool = field(
        default=False, metadata={"help": "checkpoint activations at each layer"}
    )
    # base_lm_architecture() turns checkpoint_activations on whenever this
    # option is set.
    offload_activations: bool = field(
        default=False,
        metadata={"help": "move checkpointed activations to CPU after they are used."},
    )
    quant_noise_pq: float = field(
        default=0.0,
        metadata={"help": "iterative PQ quantization noise at training time"},
    )
    quant_noise_pq_block_size: int = field(
        default=8,
        metadata={"help": "block size of quantization noise at training time"},
    )
    # TODO common var add to parent
    quant_noise_scalar: float = field(
        default=0.0,
        metadata={
            "help": "scalar quantization noise and scalar quantization at training time"
        },
    )
    # Values interpolated from other config groups rather than set here.
    add_bos_token: bool = II("task.add_bos_token")
    tokens_per_sample: int = II("task.tokens_per_sample")
    max_target_positions: Optional[int] = II("task.max_target_positions")
    tpu: bool = II("common.tpu")


@register_model("transformer_lm", dataclass=TransformerLanguageModelConfig)
class TransformerLanguageModel(FairseqLanguageModel):
    """Language model wrapping a ``TransformerDecoder`` without encoder attention.

    Registered under the name ``transformer_lm`` and configured by
    ``TransformerLanguageModelConfig``.
    """

    @classmethod
    def hub_models(cls):
        """Return the pretrained models known to this class.

        Returns:
            dict: maps a model name either to a download URL string, or to a
            dict with ``path``, ``tokenizer`` and ``bpe`` keys for models that
            need a specific tokenizer/BPE pairing.
        """

        def moses_fastbpe(path):
            return {"path": path, "tokenizer": "moses", "bpe": "fastbpe"}

        def spm(path):
            return {"path": path, "tokenizer": "space", "bpe": "sentencepiece"}

        return {
            "transformer_lm.gbw.adaptive_huge": "https://dl.fbaipublicfiles.com/fairseq/models/lm/adaptive_lm_gbw_huge.tar.bz2",
            "transformer_lm.wiki103.adaptive": "https://dl.fbaipublicfiles.com/fairseq/models/lm/adaptive_lm_wiki103.v2.tar.bz2",
            "transformer_lm.wmt19.en": moses_fastbpe(
                "https://dl.fbaipublicfiles.com/fairseq/models/lm/wmt19.en.tar.bz2"
            ),
            "transformer_lm.wmt19.de": moses_fastbpe(
                "https://dl.fbaipublicfiles.com/fairseq/models/lm/wmt19.de.tar.bz2"
            ),
            "transformer_lm.wmt19.ru": moses_fastbpe(
                "https://dl.fbaipublicfiles.com/fairseq/models/lm/wmt19.ru.tar.bz2"
            ),
            "transformer_lm.wmt20.en": spm(
                "https://dl.fbaipublicfiles.com/fairseq/models/lm/wmt20.en.tar.gz"
            ),
            "transformer_lm.wmt20.ta": spm(
                "https://dl.fbaipublicfiles.com/fairseq/models/lm/wmt20.ta.tar.gz"
            ),
            "transformer_lm.wmt20.iu.news": spm(
                "https://dl.fbaipublicfiles.com/fairseq/models/lm/wmt20.iu.news.tar.gz"
            ),
            "transformer_lm.wmt20.iu.nh": spm(
                "https://dl.fbaipublicfiles.com/fairseq/models/lm/wmt20.iu.nh.tar.gz"
            ),
        }

    def __init__(self, decoder):
        """Store ``decoder`` via the parent class constructor."""
        super().__init__(decoder)

    @classmethod
    def build_model(cls, args, task):
        """Build a new model instance.

        Args:
            args: model options; mutated in place (defaults are filled in by
                ``base_lm_architecture`` and a few derived values are set).
            task: provides ``source_dictionary`` for the input embedding and
                ``target_dictionary`` for the decoder.

        Returns:
            An instance of ``cls`` wrapping the constructed decoder.

        Raises:
            AssertionError: if ``tie_adaptive_weights`` is set but the
                adaptive input and adaptive softmax settings are incompatible.
        """

        # make sure all arguments are present in older models
        base_lm_architecture(args)

        # When an explicit keep-list is given, the layer count follows from it.
        if args.decoder_layers_to_keep:
            args.decoder_layers = len(args.decoder_layers_to_keep.split(","))

        if getattr(args, "max_target_positions", None) is None:
            args.max_target_positions = getattr(
                args, "tokens_per_sample", DEFAULT_MAX_TARGET_POSITIONS
            )

        # Input embedding: character-based, adaptive, or a plain lookup table.
        if args.character_embeddings:
            # SECURITY NOTE(review): eval() executes whatever Python expression
            # is stored in args.character_filters. Left unchanged so existing
            # configs keep working; consider ast.literal_eval if only literal
            # lists of tuples need to be supported.
            embed_tokens = CharacterTokenEmbedder(
                task.source_dictionary,
                eval(args.character_filters),
                args.character_embedding_dim,
                args.decoder_embed_dim,
                args.char_embedder_highway_layers,
            )
        elif args.adaptive_input:
            embed_tokens = AdaptiveInput(
                len(task.source_dictionary),
                task.source_dictionary.pad(),
                args.decoder_input_dim,
                args.adaptive_input_factor,
                args.decoder_embed_dim,
                options.eval_str_list(args.adaptive_input_cutoff, type=int),
                args.quant_noise_pq,
                args.quant_noise_pq_block_size,
            )
        else:
            embed_tokens = cls.build_embedding(
                args, task.source_dictionary, args.decoder_input_dim
            )

        # Tying requires the input and softmax sides to be configured alike.
        if args.tie_adaptive_weights:
            assert (
                args.adaptive_input
            ), "tie_adaptive_weights requires adaptive_input to be set"
            assert (
                args.adaptive_input_factor == args.adaptive_softmax_factor
            ), "adaptive_input_factor ({}) != adaptive_softmax_factor ({})".format(
                args.adaptive_input_factor, args.adaptive_softmax_factor
            )
            assert (
                args.adaptive_softmax_cutoff == args.adaptive_input_cutoff
            ), "{} != {}".format(
                args.adaptive_softmax_cutoff, args.adaptive_input_cutoff
            )
            assert (
                args.decoder_input_dim == args.decoder_output_dim
            ), "decoder_input_dim ({}) != decoder_output_dim ({})".format(
                args.decoder_input_dim, args.decoder_output_dim
            )

        decoder = TransformerDecoder(
            args, task.target_dictionary, embed_tokens, no_encoder_attn=True
        )
        return cls(decoder)

    @classmethod
    def build_embedding(cls, args, dictionary, embed_dim, path=None):
        """Create the token embedding for ``dictionary``.

        Args:
            args: unused here; kept so subclasses can consult model options.
            dictionary: supplies the vocabulary size (``len``) and the padding
                index (``pad()``).
            embed_dim: embedding dimension.
            path: unused here.

        Returns:
            The module produced by ``Embedding``.
        """
        embed_tokens = Embedding(len(dictionary), embed_dim, dictionary.pad())
        return embed_tokens


def base_lm_architecture(args):
    """Fill in default options on ``args`` in place.

    Every ``getattr(args, name, default)`` assignment only supplies the
    default when the attribute is missing; an attribute that is present keeps
    its value. Three settings are not plain defaults:

    * legacy checkpoint flags (``no_tie_adaptive_proj``,
      ``decoder_final_norm``) are translated to their current equivalents;
    * ``decoder_normalize_before`` is always forced to ``True``;
    * ``offload_activations`` implies ``checkpoint_activations``.
    """
    # backward compatibility for older model checkpoints
    if hasattr(args, "no_tie_adaptive_proj"):
        # previous models defined --no-tie-adaptive-proj, so use the existence of
        # that option to determine if this is an "old" model checkpoint
        args.no_decoder_final_norm = True  # old models always set this to True
        if args.no_tie_adaptive_proj is False:
            args.tie_adaptive_proj = True
    if hasattr(args, "decoder_final_norm"):
        args.no_decoder_final_norm = not args.decoder_final_norm

    args.dropout = getattr(args, "dropout", 0.1)
    args.attention_dropout = getattr(args, "attention_dropout", 0.0)

    args.decoder_embed_dim = getattr(args, "decoder_embed_dim", 512)
    args.decoder_ffn_embed_dim = getattr(args, "decoder_ffn_embed_dim", 2048)
    args.decoder_layers = getattr(args, "decoder_layers", 6)
    args.decoder_attention_heads = getattr(args, "decoder_attention_heads", 8)
    args.adaptive_softmax_cutoff = getattr(args, "adaptive_softmax_cutoff", None)
    args.adaptive_softmax_dropout = getattr(args, "adaptive_softmax_dropout", 0)
    args.adaptive_softmax_factor = getattr(args, "adaptive_softmax_factor", 4)
    args.decoder_learned_pos = getattr(args, "decoder_learned_pos", False)
    args.activation_fn = getattr(args, "activation_fn", "relu")

    args.decoder_layerdrop = getattr(args, "decoder_layerdrop", 0)
    args.decoder_layers_to_keep = getattr(args, "decoder_layers_to_keep", None)
    args.quant_noise_pq = getattr(args, "quant_noise_pq", 0)
    args.quant_noise_pq_block_size = getattr(args, "quant_noise_pq_block_size", 8)
    args.quant_noise_scalar = getattr(args, "quant_noise_scalar", 0)

    args.add_bos_token = getattr(args, "add_bos_token", False)
    args.no_token_positional_embeddings = getattr(
        args, "no_token_positional_embeddings", False
    )
    args.share_decoder_input_output_embed = getattr(
        args, "share_decoder_input_output_embed", False
    )
    args.character_embeddings = getattr(args, "character_embeddings", False)

    # Input/output dims fall back to the (already defaulted) embedding dim.
    args.decoder_output_dim = getattr(
        args, "decoder_output_dim", args.decoder_embed_dim
    )
    args.decoder_input_dim = getattr(args, "decoder_input_dim", args.decoder_embed_dim)

    # Model training is not stable without this
    args.decoder_normalize_before = True
    args.no_decoder_final_norm = getattr(args, "no_decoder_final_norm", False)

    args.adaptive_input = getattr(args, "adaptive_input", False)
    args.adaptive_input_factor = getattr(args, "adaptive_input_factor", 4)
    args.adaptive_input_cutoff = getattr(args, "adaptive_input_cutoff", None)

    args.tie_adaptive_weights = getattr(args, "tie_adaptive_weights", False)
    args.tie_adaptive_proj = getattr(args, "tie_adaptive_proj", False)

    args.no_scale_embedding = getattr(args, "no_scale_embedding", False)
    args.layernorm_embedding = getattr(args, "layernorm_embedding", False)
    args.checkpoint_activations = getattr(args, "checkpoint_activations", False)
    args.offload_activations = getattr(args, "offload_activations", False)
    # Offloading only applies to checkpointed activations, so enable those.
    if args.offload_activations:
        args.checkpoint_activations = True


@register_model_architecture("transformer_lm", "transformer_lm_big")
def transformer_lm_big(args):
    """Default to 12 layers, 1024-dim embeddings, 4096-dim FFN and 16 heads.

    Remaining options are filled in by ``base_lm_architecture``.
    """
    args.decoder_layers = getattr(args, "decoder_layers", 12)
    args.decoder_embed_dim = getattr(args, "decoder_embed_dim", 1024)
    args.decoder_ffn_embed_dim = getattr(args, "decoder_ffn_embed_dim", 4096)
    args.decoder_attention_heads = getattr(args, "decoder_attention_heads", 16)
    base_lm_architecture(args)


@register_model_architecture("transformer_lm", "transformer_lm_wiki103")
@register_model_architecture("transformer_lm", "transformer_lm_baevski_wiki103")
def transformer_lm_baevski_wiki103(args):
    """Defaults for the architecture registered under two wiki103 names.

    Sets 16 layers, 8 heads, tied adaptive input/softmax with cutoffs
    ``"20000,60000"`` and the dropout values below, then delegates to
    ``transformer_lm_big``. Values set here take precedence because the later
    functions only fill in attributes that are still missing.
    """
    args.decoder_layers = getattr(args, "decoder_layers", 16)
    args.decoder_attention_heads = getattr(args, "decoder_attention_heads", 8)
    args.dropout = getattr(args, "dropout", 0.3)
    args.adaptive_input = getattr(args, "adaptive_input", True)
    args.tie_adaptive_weights = getattr(args, "tie_adaptive_weights", True)
    args.adaptive_input_cutoff = getattr(args, "adaptive_input_cutoff", "20000,60000")
    args.adaptive_softmax_cutoff = getattr(
        args, "adaptive_softmax_cutoff", "20000,60000"
    )
    args.adaptive_softmax_dropout = getattr(args, "adaptive_softmax_dropout", 0.2)
    args.attention_dropout = getattr(args, "attention_dropout", 0.1)
    args.activation_dropout = getattr(args, "activation_dropout", 0.1)
    args.no_decoder_final_norm = getattr(args, "no_decoder_final_norm", True)
    args.tie_adaptive_proj = getattr(args, "tie_adaptive_proj", True)
    transformer_lm_big(args)


@register_model_architecture("transformer_lm", "transformer_lm_gbw")
@register_model_architecture("transformer_lm", "transformer_lm_baevski_gbw")
def transformer_lm_baevski_gbw(args):
    """Defaults for the architecture registered under two gbw names.

    Sets a 512-dim embedding, 0.1 dropout and attention dropout, and no final
    decoder layernorm, then delegates to ``transformer_lm_big`` for the rest.
    """
    args.decoder_embed_dim = getattr(args, "decoder_embed_dim", 512)
    args.dropout = getattr(args, "dropout", 0.1)
    args.attention_dropout = getattr(args, "attention_dropout", 0.1)
    args.no_decoder_final_norm = getattr(args, "no_decoder_final_norm", True)
    transformer_lm_big(args)


@register_model_architecture("transformer_lm", "transformer_lm_gpt")
def transformer_lm_gpt(args):
    """Default to 12 layers, 768-dim embeddings, 3072-dim FFN, 12 heads, GELU."""
    args.decoder_embed_dim = getattr(args, "decoder_embed_dim", 768)
    args.decoder_ffn_embed_dim = getattr(args, "decoder_ffn_embed_dim", 3072)
    args.decoder_layers = getattr(args, "decoder_layers", 12)
    args.decoder_attention_heads = getattr(args, "decoder_attention_heads", 12)
    args.dropout = getattr(args, "dropout", 0.1)
    args.attention_dropout = getattr(args, "attention_dropout", 0.1)
    args.activation_fn = getattr(args, "activation_fn", "gelu")
    base_lm_architecture(args)


@register_model_architecture("transformer_lm", "transformer_lm_gpt2_small")
def transformer_lm_gpt2_small(args):
    """Default to 24 layers, 1024-dim embeddings, 4096-dim FFN, 16 heads, GELU."""
    args.decoder_embed_dim = getattr(args, "decoder_embed_dim", 1024)
    args.decoder_ffn_embed_dim = getattr(args, "decoder_ffn_embed_dim", 4096)
    args.decoder_layers = getattr(args, "decoder_layers", 24)
    args.decoder_attention_heads = getattr(args, "decoder_attention_heads", 16)
    args.dropout = getattr(args, "dropout", 0.1)
    args.attention_dropout = getattr(args, "attention_dropout", 0.1)
    args.activation_fn = getattr(args, "activation_fn", "gelu")
    base_lm_architecture(args)


@register_model_architecture("transformer_lm", "transformer_lm_gpt2_tiny")
def transformer_lm_gpt2_tiny(args):
    """Default to 2 layers, 64-dim embeddings, 64-dim FFN, 1 head, GELU."""
    args.decoder_embed_dim = getattr(args, "decoder_embed_dim", 64)
    args.decoder_ffn_embed_dim = getattr(args, "decoder_ffn_embed_dim", 64)
    args.decoder_layers = getattr(args, "decoder_layers", 2)
    args.decoder_attention_heads = getattr(args, "decoder_attention_heads", 1)
    args.dropout = getattr(args, "dropout", 0.1)
    args.attention_dropout = getattr(args, "attention_dropout", 0.1)
    args.activation_fn = getattr(args, "activation_fn", "gelu")
    base_lm_architecture(args)


@register_model_architecture("transformer_lm", "transformer_lm_gpt2_medium")
def transformer_lm_gpt2_medium(args):
    """Default to 36 layers, 1280-dim embeddings, 5120-dim FFN, 20 heads, GELU."""
    args.decoder_embed_dim = getattr(args, "decoder_embed_dim", 1280)
    args.decoder_ffn_embed_dim = getattr(args, "decoder_ffn_embed_dim", 5120)
    args.decoder_layers = getattr(args, "decoder_layers", 36)
    args.decoder_attention_heads = getattr(args, "decoder_attention_heads", 20)
    args.dropout = getattr(args, "dropout", 0.1)
    args.attention_dropout = getattr(args, "attention_dropout", 0.1)
    args.activation_fn = getattr(args, "activation_fn", "gelu")
    base_lm_architecture(args)


@register_model_architecture("transformer_lm", "transformer_lm_gpt2_big")
def transformer_lm_gpt2_big(args):
    """Default to 48 layers, 1600-dim embeddings, 6400-dim FFN, 25 heads, GELU."""
    args.decoder_embed_dim = getattr(args, "decoder_embed_dim", 1600)
    args.decoder_ffn_embed_dim = getattr(args, "decoder_ffn_embed_dim", 6400)
    args.decoder_layers = getattr(args, "decoder_layers", 48)
    args.decoder_attention_heads = getattr(args, "decoder_attention_heads", 25)
    args.dropout = getattr(args, "dropout", 0.1)
    args.attention_dropout = getattr(args, "attention_dropout", 0.1)
    args.activation_fn = getattr(args, "activation_fn", "gelu")
    base_lm_architecture(args)