# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
"""
RoBERTa: A Robustly Optimized BERT Pretraining Approach.
"""

import logging

import torch
import torch.nn as nn
import torch.nn.functional as F

from fairseq import utils
from fairseq.model_parallel.models.transformer import ModelParallelTransformerEncoder
from fairseq.models import register_model, register_model_architecture
from fairseq.models.roberta import (
    roberta_base_architecture,
    roberta_prenorm_architecture,
    RobertaEncoder,
    RobertaModel,
)
from fairseq.modules import LayerNorm


try:
    from fairseq.model_parallel.megatron.mpu import (
        copy_to_model_parallel_region,
        gather_from_model_parallel_region,
        ColumnParallelLinear,
        VocabParallelEmbedding,
    )

    has_megatron_submodule = True
except (ImportError, ModuleNotFoundError):
    has_megatron_submodule = False

logger = logging.getLogger(__name__)


@register_model("model_parallel_roberta")
class ModelParallelRobertaModel(RobertaModel):
    def __init__(self, args, encoder):
        super().__init__(args, encoder)

        self.classification_heads = nn.ModuleDict()

    @staticmethod
    def add_args(parser):
        RobertaModel.add_args(parser)
        parser.add_argument(
            "--no-final-layer-norm",
            action="store_true",
            help=(
                "don't add final layernorm (only applicable when "
                "--encoder-normalize-before=True)"
            ),
        )

    @classmethod
    def build_model(cls, args, task):
        """Build a new model instance."""

        # make sure all arguments are present
        base_architecture(args)

        # pad the vocabulary so that it can be partitioned evenly across
        # model parallel ranks
        task.source_dictionary.pad_to_multiple_(args.model_parallel_size * 8)
        task.target_dictionary.pad_to_multiple_(args.model_parallel_size * 8)

        if not hasattr(args, "max_positions"):
            args.max_positions = args.tokens_per_sample

        if getattr(args, "untie_weights_roberta", False):
            raise NotImplementedError(
                "--untie-weights-roberta is not supported in model parallel mode"
            )

        encoder = ModelParallelRobertaEncoder(args, task.source_dictionary)
        return cls(args, encoder)

    def forward(
        self,
        src_tokens,
        features_only=False,
        return_all_hiddens=False,
        classification_head_name=None,
        **kwargs
    ):
        if classification_head_name is not None:
            features_only = True

        x, extra = self.encoder(src_tokens, features_only, return_all_hiddens, **kwargs)

        if classification_head_name is not None:
            x = self.classification_heads[classification_head_name](x)
        return x, extra

    def register_classification_head(
        self, name, num_classes=None, inner_dim=None, **kwargs
    ):
        """Register a classification head."""
        if name in self.classification_heads:
            prev_num_classes = self.classification_heads[name].out_proj.out_features
            prev_inner_dim = self.classification_heads[name].dense.out_features
            if num_classes != prev_num_classes or inner_dim != prev_inner_dim:
                logger.warning(
                    're-registering head "{}" with num_classes {} (prev: {}) '
                    "and inner_dim {} (prev: {})".format(
                        name, num_classes, prev_num_classes, inner_dim, prev_inner_dim
                    )
                )
        self.classification_heads[name] = ModelParallelRobertaClassificationHead(
            self.args.encoder_embed_dim,
            inner_dim or self.args.encoder_embed_dim,
            num_classes,
            self.args.pooler_activation_fn,
            self.args.pooler_dropout,
        )


class ModelParallelRobertaLMHead(nn.Module):
    """Head for masked language modeling."""

    def __init__(self, embed_dim, output_dim, activation_fn, weight=None):
        super().__init__()
        self.dense = ColumnParallelLinear(embed_dim, embed_dim, gather_output=True)
        self.activation_fn = utils.get_activation_fn(activation_fn)
        self.layer_norm = LayerNorm(embed_dim)

        if weight is None:
            weight = nn.Linear(embed_dim, output_dim, bias=False).weight
        self.weight = weight
        self.bias = nn.Parameter(torch.zeros(output_dim))

    def forward(self, features, masked_tokens=None, **kwargs):
        # Only project the unmasked tokens while training,
        # saves both memory and computation
        if masked_tokens is not None:
            features = features[masked_tokens, :]

        x = self.dense(features)
        x = self.activation_fn(x)
        x = self.layer_norm(x)

        # copy features into the model parallel region (identity in the
        # forward pass, gradient all-reduce in the backward pass)
        x = copy_to_model_parallel_region(x)
        # project back to size of vocabulary with bias
        x = F.linear(x, self.weight)
        # gather the partial outputs from all model parallel ranks
        x = gather_from_model_parallel_region(x).contiguous()
        x = x + self.bias
        return x


class ModelParallelRobertaClassificationHead(nn.Module):
    """Head for sentence-level classification tasks."""

    def __init__(
        self, input_dim, inner_dim, num_classes, activation_fn, pooler_dropout
    ):
        super().__init__()
        self.dense = ColumnParallelLinear(input_dim, inner_dim, gather_output=True)
        self.activation_fn = utils.get_activation_fn(activation_fn)
        self.dropout = nn.Dropout(p=pooler_dropout)
        self.out_proj = nn.Linear(inner_dim, num_classes)

    def forward(self, features, **kwargs):
        x = features[:, 0, :]  # take <s> token (equiv. to [CLS])
        x = self.dropout(x)
        x = self.dense(x)
        x = self.activation_fn(x)
        x = self.dropout(x)
        x = self.out_proj(x)
        return x


class ModelParallelRobertaEncoder(RobertaEncoder):
    """RoBERTa encoder."""

    def __init__(self, args, dictionary):
        super().__init__(args, dictionary)
        assert not self.args.untie_weights_roberta

    def build_embedding(self, vocab_size, embedding_dim, padding_idx):
        return VocabParallelEmbedding(vocab_size, embedding_dim, padding_idx)

    def build_encoder(self, args, dictionary, embed_tokens):
        return ModelParallelTransformerEncoder(args, dictionary, embed_tokens)

    def build_lm_head(self, embed_dim, output_dim, activation_fn, weight):
        return ModelParallelRobertaLMHead(embed_dim, output_dim, activation_fn, weight)


@register_model_architecture("model_parallel_roberta", "model_parallel_roberta")
def base_architecture(args):
    args.no_final_layer_norm = getattr(args, "no_final_layer_norm", False)
    # model parallel RoBERTa defaults to "Pre-LN" formulation
    roberta_prenorm_architecture(args)


# earlier versions of model parallel RoBERTa removed the final layer norm
@register_model_architecture("model_parallel_roberta", "model_parallel_roberta_v1")
def model_parallel_roberta_v1_architecture(args):
    args.no_final_layer_norm = getattr(args, "no_final_layer_norm", True)
    base_architecture(args)


@register_model_architecture(
    "model_parallel_roberta", "model_parallel_roberta_postnorm"
)
def model_parallel_roberta_postnorm_architecture(args):
    # the original BERT/RoBERTa uses the "Post-LN" formulation
    roberta_base_architecture(args)


@register_model_architecture("model_parallel_roberta", "model_parallel_roberta_base")
def model_parallel_roberta_base_architecture(args):
    base_architecture(args)


@register_model_architecture("model_parallel_roberta", "model_parallel_roberta_large")
def model_parallel_roberta_large_architecture(args):
    args.encoder_layers = getattr(args, "encoder_layers", 24)
    args.encoder_embed_dim = getattr(args, "encoder_embed_dim", 1024)
    args.encoder_ffn_embed_dim = getattr(args, "encoder_ffn_embed_dim", 4096)
    args.encoder_attention_heads = getattr(args, "encoder_attention_heads", 16)
    base_architecture(args)
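

# Usage sketch (illustrative only, not part of the library API). The
# *_architecture functions above simply fill in argparse defaults, so the
# registered presets can be inspected without GPUs or the megatron submodule.
# The snippet below assumes fairseq is installed; the printed values for the
# "large" preset come from model_parallel_roberta_large_architecture:
#
#     from argparse import Namespace
#     from fairseq.model_parallel.models.roberta.model import (
#         model_parallel_roberta_large_architecture,
#     )
#
#     args = Namespace()
#     model_parallel_roberta_large_architecture(args)
#     print(args.encoder_layers, args.encoder_embed_dim)  # 24 1024
#
# Actual model parallel pretraining goes through the standard fairseq-train
# CLI (the flags shown are existing fairseq options; the dataset path and
# hyperparameter values are placeholders):
#
#     fairseq-train $DATA_DIR \
#         --task masked_lm --criterion masked_lm \
#         --arch model_parallel_roberta_large \
#         --model-parallel-size 2 \
#         --tokens-per-sample 512 --batch-size 16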