# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import logging
import math
import os

import torch
import torch.nn as nn
import torch.nn.functional as F
from fairseq import checkpoint_utils
from fairseq.incremental_decoding_utils import with_incremental_state
from fairseq.models import (
    CompositeEncoder,
    FairseqDecoder,
    FairseqEncoder,
    FairseqEncoderDecoderModel,
    register_model,
    register_model_architecture,
)
from fairseq.modules import (
    DownsampledMultiHeadAttention,
    FairseqDropout,
    GradMultiply,
    LayerNorm,
    LearnedPositionalEmbedding,
    LinearizedConvolution,
)


logger = logging.getLogger(__name__)


@register_model("fconv_self_att")
class FConvModelSelfAtt(FairseqEncoderDecoderModel):
    """Convolutional encoder-decoder model with self-attention.

    When a pretrained encoder is supplied, both the trainable encoder and the
    pretrained one are wrapped in a single ``CompositeEncoder``; their outputs
    are forwarded together and combined inside the decoder ("fusion").
    """

    @classmethod
    def hub_models(cls):
        """Return the mapping of hub names to downloadable model/data specs."""
        archive_url = (
            "https://dl.fbaipublicfiles.com/fairseq/models/stories_checkpoint.tar.gz"
        )
        pretrained_spec = {
            "path": archive_url,
            "checkpoint_file": "pretrained_checkpoint.pt",
            "tokenizer": "nltk",
        }
        fusion_spec = {
            "path": archive_url,
            "checkpoint_file": "fusion_checkpoint.pt",
            "tokenizer": "nltk",
            "pretrained": "True",
            "pretrained_checkpoint": "./pretrained_checkpoint.pt",
        }
        return {
            "conv.stories.pretrained": pretrained_spec,
            "conv.stories": fusion_spec,
            # Test set containing dictionaries
            "data.stories": "https://dl.fbaipublicfiles.com/fairseq/data/stories_test.tar.bz2",
        }

    def __init__(self, encoder, decoder, pretrained_encoder=None):
        """Wire the encoder(s) and decoder together.

        Args:
            encoder: the trainable encoder.
            decoder: the decoder; its ``attention`` list is inspected to count
                the layers that actually attend over the encoder output.
            pretrained_encoder: optional encoder of a pretrained model; when
                given it is exposed under the ``"pretrained"`` key of the
                composite encoder.
        """
        super().__init__(encoder, decoder)

        # tell the encoder how many decoder layers attend over its output
        attending_layers = 0
        for layer in decoder.attention:
            if layer is not None:
                attending_layers += 1
        self.encoder.num_attention_layers = attending_layers

        self.pretrained_encoder = pretrained_encoder

        # for the fusion model the CompositeEncoder holds both the pretrained
        # and the training encoder; both are run and then merged in the decoder
        encoders = {"encoder": encoder}
        if self.pretrained_encoder is not None:
            encoders["pretrained"] = self.pretrained_encoder
        self.encoder = CompositeEncoder(encoders)

    @staticmethod
    def add_args(parser):
        """Add model-specific arguments to the parser."""
        add = parser.add_argument
        # fmt: off
        add('--dropout', type=float, metavar='D',
            help='dropout probability')
        add('--encoder-embed-dim', type=int, metavar='N',
            help='encoder embedding dimension')
        add('--encoder-layers', type=str, metavar='EXPR',
            help='encoder layers [(dim, kernel_size), ...]')
        add('--decoder-embed-dim', type=int, metavar='N',
            help='decoder embedding dimension')
        add('--decoder-layers', type=str, metavar='EXPR',
            help='decoder layers [(dim, kernel_size), ...]')
        add('--decoder-out-embed-dim', type=int, metavar='N',
            help='decoder output embedding dimension')
        add('--decoder-attention', type=str, metavar='EXPR',
            help='decoder attention [True, ...]')
        add('--self-attention', type=str, metavar='EXPR',
            help='decoder self-attention layers, ex: [True] + [False]*5')
        add('--multihead-attention-nheads', type=int,
            help='Number of heads to use in attention')
        add('--multihead-self-attention-nheads', type=int,
            help='Number of heads to use in self-attention')
        add('--encoder-attention', type=str, metavar='EXPR',
            help='encoder attention [True, ...]')
        add('--encoder-attention-nheads', type=int,
            help='Number of heads to use in encoder attention')
        add('--project-input', type=str, metavar='EXPR',
            help='Use projections in self-attention [True, ...]')
        add('--gated-attention', type=str, metavar='EXPR',
            help='Use GLU layers in self-attention projections [True, ...]')
        add('--downsample', type=str, metavar='EXPR',
            help='Use downsampling in self-attention [True, ...]')
        add('--pretrained-checkpoint', metavar='DIR',
            help='path to load checkpoint from pretrained model')
        add('--pretrained', type=str, metavar='EXPR',
            help='use pretrained model when training [True, ...]')
        # fmt: on

    @classmethod
    def build_model(cls, args, task):
        """Build a new model instance.

        NOTE(review): every ``EXPR`` argument is a Python expression string
        that is passed to ``eval``; these must only ever come from a trusted
        command line / config, never from untrusted input.
        """
        trained_encoder = None
        trained_decoder = None
        pretrained = eval(args.pretrained)

        if pretrained:
            logger.info("loading pretrained model")

            # fall back to looking for the checkpoint inside the data directory
            if not os.path.exists(args.pretrained_checkpoint):
                candidate = os.path.join(args.data, args.pretrained_checkpoint)
                if os.path.exists(candidate):
                    args.pretrained_checkpoint = candidate

            ensemble = checkpoint_utils.load_model_ensemble(
                filenames=[args.pretrained_checkpoint],
                task=task,
            )
            trained_model = ensemble[0][0]

            # first child is taken as the encoder, second as the decoder
            submodules = list(trained_model.children())
            trained_decoder = submodules[1]
            trained_encoder = submodules[0]

            # freeze pretrained model
            for frozen in (trained_decoder, trained_encoder):
                for param in frozen.parameters():
                    param.requires_grad = False

        encoder = FConvEncoder(
            task.source_dictionary,
            embed_dim=args.encoder_embed_dim,
            convolutions=eval(args.encoder_layers),
            dropout=args.dropout,
            max_positions=args.max_source_positions,
            attention=eval(args.encoder_attention),
            attention_nheads=args.encoder_attention_nheads,
        )

        decoder = FConvDecoder(
            task.target_dictionary,
            embed_dim=args.decoder_embed_dim,
            convolutions=eval(args.decoder_layers),
            out_embed_dim=args.decoder_out_embed_dim,
            attention=eval(args.decoder_attention),
            dropout=args.dropout,
            max_positions=args.max_target_positions,
            selfattention=eval(args.self_attention),
            attention_nheads=args.multihead_attention_nheads,
            selfattention_nheads=args.multihead_self_attention_nheads,
            project_input=eval(args.project_input),
            gated_attention=eval(args.gated_attention),
            downsample=eval(args.downsample),
            pretrained=pretrained,
            trained_decoder=trained_decoder,
        )

        return FConvModelSelfAtt(encoder, decoder, trained_encoder)

    @property
    def pretrained(self):
        """Whether this model was built with a pretrained encoder (fusion)."""
        return self.pretrained_encoder is not None


class FConvEncoder(FairseqEncoder):
    """Convolutional encoder.

    Embeds source tokens and positions, runs them through a stack of gated
    (GLU) convolutions with residual connections and optional per-layer
    self-attention, and projects the result back to the embedding dimension.

    Args:
        dictionary: source dictionary; provides the vocabulary size (``len``)
            and the padding index (``pad()``).
        embed_dim: embedding dimension.
        max_positions: size passed to the positional embedding factory.
        convolutions: sequence of ``(out_channels, kernel_size)`` pairs, one
            per convolutional layer.
        dropout: dropout probability (also forwarded to the layer factories).
        attention: a bool (applied to every layer) or a per-layer sequence of
            bools selecting which layers get a ``SelfAttention`` module.
        attention_nheads: number of heads for the encoder self-attention.
    """

    def __init__(
        self,
        dictionary,
        embed_dim=512,
        max_positions=1024,
        convolutions=((512, 3),) * 20,
        dropout=0.1,
        attention=False,
        attention_nheads=1,
    ):
        super().__init__(dictionary)
        self.dropout_module = FairseqDropout(
            dropout, module_name=self.__class__.__name__
        )
        # number of layers that attend over this encoder's output; assigned
        # from outside after construction (used for gradient scaling in forward)
        self.num_attention_layers = None

        num_embeddings = len(dictionary)
        self.padding_idx = dictionary.pad()
        self.embed_tokens = Embedding(num_embeddings, embed_dim, self.padding_idx)
        self.embed_positions = PositionalEmbedding(
            max_positions,
            embed_dim,
            self.padding_idx,
        )

        def expand_bool_array(val):
            if isinstance(val, bool):
                # expand True into [True, True, ...] and do the same with False
                return [val] * len(convolutions)
            return val

        attention = expand_bool_array(attention)

        in_channels = convolutions[0][0]
        self.fc1 = Linear(embed_dim, in_channels, dropout=dropout)
        self.projections = nn.ModuleList()
        self.convolutions = nn.ModuleList()
        self.attention = nn.ModuleList()
        self.attproj = nn.ModuleList()
        for i, (out_channels, kernel_size) in enumerate(convolutions):
            # residual projection is only needed when the channel count changes
            self.projections.append(
                Linear(in_channels, out_channels)
                if in_channels != out_channels
                else None
            )
            # twice the output channels because the GLU in forward halves them
            self.convolutions.append(
                ConvTBC(in_channels, out_channels * 2, kernel_size, dropout=dropout)
            )

            self.attention.append(
                SelfAttention(out_channels, embed_dim, attention_nheads)
                if attention[i]
                else None
            )
            in_channels = out_channels

        self.fc2 = Linear(in_channels, embed_dim)

    def forward(self, src_tokens, src_lengths):
        """Encode a batch of source tokens.

        Args:
            src_tokens: token indices (the in-code shape comments treat this
                as ``B x T``).
            src_lengths: unused by this implementation.

        Returns:
            dict with
              - ``"encoder_out"``: tuple ``(x, y)`` where ``x`` is the projected
                convolution output and ``y`` is ``x`` combined with the input
                embedding, scaled by ``sqrt(0.5)``;
              - ``"encoder_padding_mask"``: padding mask (``B x T``), or
                ``None`` when the batch contains no padding.
        """
        # embed tokens and positions
        x = self.embed_tokens(src_tokens) + self.embed_positions(src_tokens)
        x = self.dropout_module(x)
        input_embedding = x.transpose(0, 1)

        # project to size of convolution
        x = self.fc1(x)

        encoder_padding_mask = src_tokens.eq(self.padding_idx).t()  # -> T x B
        if not encoder_padding_mask.any():
            encoder_padding_mask = None

        # B x T x C -> T x B x C
        x = x.transpose(0, 1)

        # temporal convolutions
        for proj, conv, attention in zip(
            self.projections, self.convolutions, self.attention
        ):
            residual = x if proj is None else proj(x)

            if encoder_padding_mask is not None:
                x = x.masked_fill(encoder_padding_mask.unsqueeze(-1), 0)

            x = self.dropout_module(x)
            # pad the time dimension so the convolution keeps the sequence length
            padding_l = (conv.kernel_size[0] - 1) // 2
            padding_r = conv.kernel_size[0] // 2
            x = F.pad(x, (0, 0, 0, 0, padding_l, padding_r))
            x = conv(x)
            x = F.glu(x, dim=2)
            if attention is not None:
                x = attention(x)
            x = (x + residual) * math.sqrt(0.5)

        # T x B x C -> B x T x C
        x = x.transpose(1, 0)

        # project back to size of embedding
        x = self.fc2(x)

        if encoder_padding_mask is not None:
            encoder_padding_mask = encoder_padding_mask.t()  # -> B x T
            x = x.masked_fill(encoder_padding_mask.unsqueeze(-1), 0)

        # scale gradients (this only affects backward, not forward).
        # num_attention_layers is None until it is assigned from outside and
        # may be 0 when no layer attends over this output; the division would
        # raise in both cases, so the scaling is skipped instead.
        if self.num_attention_layers:
            x = GradMultiply.apply(x, 1.0 / (2.0 * self.num_attention_layers))

        # add output to input embedding for attention
        y = (x + input_embedding.transpose(0, 1)) * math.sqrt(0.5)

        return {
            "encoder_out": (x, y),
            "encoder_padding_mask": encoder_padding_mask,  # B x T
        }

    def reorder_encoder_out(self, encoder_out, new_order):
        """Reorder the batch dimension (dim 0) of ``encoder_out`` in place.

        Args:
            encoder_out: dict as returned by :meth:`forward`; an optional
                ``"pretrained"`` entry with its own ``"encoder_out"`` tuple is
                reordered as well.
            new_order: index tensor passed to ``index_select``.

        Returns:
            the same (mutated) ``encoder_out`` dict.
        """
        encoder_out["encoder_out"] = tuple(
            eo.index_select(0, new_order) for eo in encoder_out["encoder_out"]
        )

        if encoder_out["encoder_padding_mask"] is not None:
            encoder_out["encoder_padding_mask"] = encoder_out[
                "encoder_padding_mask"
            ].index_select(0, new_order)

        if "pretrained" in encoder_out:
            encoder_out["pretrained"]["encoder_out"] = tuple(
                eo.index_select(0, new_order)
                for eo in encoder_out["pretrained"]["encoder_out"]
            )

        return encoder_out

    def max_positions(self):
        """Maximum input length supported by the encoder."""
        return self.embed_positions.max_positions


@with_incremental_state
class FConvDecoder(FairseqDecoder):
    """Convolutional decoder.

    Stacks gated (GLU) convolutions with residual connections; each layer may
    additionally attend over the encoder output and/or apply self-attention.
    When ``pretrained`` is set, the hidden state of a second (pretrained)
    decoder is captured through a forward hook and fused with this decoder's
    hidden state via learned gates before the final vocabulary projection.
    """

    def __init__(
        self,
        dictionary,
        embed_dim=512,
        out_embed_dim=256,
        max_positions=1024,
        convolutions=((512, 3),) * 8,
        attention=True,
        dropout=0.1,
        selfattention=False,
        attention_nheads=1,
        selfattention_nheads=1,
        project_input=False,
        gated_attention=False,
        downsample=False,
        pretrained=False,
        trained_decoder=None,
    ):
        """Build the decoder layers.

        Args:
            dictionary: target dictionary; provides vocabulary size (``len``)
                and padding index (``pad()``).
            embed_dim: embedding dimension.
            out_embed_dim: size of the hidden state produced by ``fc2``.
            max_positions: size passed to the positional embedding factory.
            convolutions: sequence of ``(out_channels, kernel_size)`` pairs.
            attention: bool or per-layer list of bools enabling encoder
                attention.
            dropout: dropout probability.
            selfattention: bool or per-layer list of bools enabling
                self-attention.
            attention_nheads: number of heads for encoder attention.
            selfattention_nheads: number of heads for self-attention.
            project_input: forwarded to the attention modules.
            gated_attention: forwarded as ``gated`` to the self-attention
                modules only.
            downsample: forwarded as ``downsample`` to the self-attention
                modules only.
            pretrained: whether to build the fusion layers.
            trained_decoder: the pretrained decoder; its ``fc2`` is hooked
                when ``pretrained`` is true.

        Raises:
            ValueError: if ``attention`` (after bool expansion) is not a list
                with one entry per convolution.
        """
        super().__init__(dictionary)
        self.register_buffer("version", torch.Tensor([2]))
        self.pretrained = pretrained
        self.pretrained_decoder = trained_decoder
        self.dropout_module = FairseqDropout(
            dropout, module_name=self.__class__.__name__
        )
        # whether forward should accumulate attention scores (eval mode only)
        self.need_attn = True
        in_channels = convolutions[0][0]

        def expand_bool_array(val):
            if isinstance(val, bool):
                # expand True into [True, True, ...] and do the same with False
                return [val] * len(convolutions)
            return val

        attention = expand_bool_array(attention)
        selfattention = expand_bool_array(selfattention)

        if not isinstance(attention, list) or len(attention) != len(convolutions):
            raise ValueError(
                "Attention is expected to be a list of booleans of "
                "length equal to the number of layers."
            )

        num_embeddings = len(dictionary)
        padding_idx = dictionary.pad()
        self.embed_tokens = Embedding(num_embeddings, embed_dim, padding_idx)

        self.embed_positions = PositionalEmbedding(
            max_positions,
            embed_dim,
            padding_idx,
        )

        self.fc1 = Linear(embed_dim, in_channels, dropout=dropout)
        # parallel per-layer lists; entries are None where a layer has no such module
        self.projections = nn.ModuleList()
        self.convolutions = nn.ModuleList()
        self.attention = nn.ModuleList()
        self.selfattention = nn.ModuleList()
        self.attproj = nn.ModuleList()
        for i, (out_channels, kernel_size) in enumerate(convolutions):
            # residual projection is only needed when the channel count changes
            self.projections.append(
                Linear(in_channels, out_channels)
                if in_channels != out_channels
                else None
            )
            # twice the output channels because the GLU in forward halves them
            self.convolutions.append(
                LinearizedConv1d(
                    in_channels,
                    out_channels * 2,
                    kernel_size,
                    padding=(kernel_size - 1),
                    dropout=dropout,
                )
            )

            # encoder attention: never gated or downsampled
            self.attention.append(
                DownsampledMultiHeadAttention(
                    out_channels,
                    embed_dim,
                    attention_nheads,
                    project_input=project_input,
                    gated=False,
                    downsample=False,
                )
                if attention[i]
                else None
            )

            # projects the layer output to embed_dim before it is used as attention query
            self.attproj.append(
                Linear(out_channels, embed_dim, dropout=dropout)
                if attention[i]
                else None
            )
            self.selfattention.append(
                SelfAttention(
                    out_channels,
                    embed_dim,
                    selfattention_nheads,
                    project_input=project_input,
                    gated=gated_attention,
                    downsample=downsample,
                )
                if selfattention[i]
                else None
            )
            in_channels = out_channels

        self.fc2 = Linear(in_channels, out_embed_dim)
        self.fc3 = Linear(out_embed_dim, num_embeddings, dropout=dropout)

        # model fusion
        if self.pretrained:
            # independent gates are learned from the concatenated input
            self.gate1 = nn.Sequential(
                Linear(out_embed_dim * 2, out_embed_dim), nn.Sigmoid()
            )
            self.gate2 = nn.Sequential(
                Linear(out_embed_dim * 2, out_embed_dim), nn.Sigmoid()
            )
            # pretrained and trained models are joined
            self.joining = nn.Sequential(
                Linear(out_embed_dim * 2, out_embed_dim * 2),
                LayerNorm(out_embed_dim * 2),
                nn.GLU(),
                Linear(out_embed_dim, out_embed_dim * 2),
                LayerNorm(out_embed_dim * 2),
                nn.GLU(),
                Linear(out_embed_dim, out_embed_dim),
                LayerNorm(out_embed_dim),
            )
            # pretrained model contains an output layer that is nhid -> vocab size
            # but the models are combined in their hidden state
            # the hook stores the output of the pretrained model forward
            self.pretrained_outputs = {}

            def save_output():
                # returns a forward hook that stashes the hooked module's output
                def hook(a, b, output):
                    self.pretrained_outputs["out"] = output

                return hook

            # capture the pretrained decoder's fc2 output on every forward pass
            self.pretrained_decoder.fc2.register_forward_hook(save_output())

    def forward(self, prev_output_tokens, encoder_out):
        """Run the decoder.

        Args:
            prev_output_tokens: previous target tokens.
            encoder_out: dict holding the trainable encoder's output under
                ``"encoder"`` and, for fusion models, the pretrained encoder's
                output under ``"pretrained"``.

        Returns:
            tuple ``(output, avg_attn_scores)``. ``output`` is the ``fc3``
            projection (of the fused hidden state when ``self.pretrained``).
            ``avg_attn_scores`` is the sum of the attention scores of all
            attending layers, collected only when not training and
            ``self.need_attn`` is set; otherwise ``None``.
        """
        trained_encoder_out = encoder_out["pretrained"] if self.pretrained else None
        encoder_out = encoder_out["encoder"]["encoder_out"]

        encoder_a, encoder_b = self._split_encoder_out(encoder_out)

        # embed positions
        positions = self.embed_positions(prev_output_tokens)

        # embed tokens and positions
        x = self.embed_tokens(prev_output_tokens) + positions
        x = self.dropout_module(x)
        target_embedding = x.transpose(0, 1)

        # project to size of convolution
        x = self.fc1(x)

        # B x T x C -> T x B x C
        x = x.transpose(0, 1)

        # temporal convolutions
        avg_attn_scores = None
        for proj, conv, attention, selfattention, attproj in zip(
            self.projections,
            self.convolutions,
            self.attention,
            self.selfattention,
            self.attproj,
        ):
            residual = x if proj is None else proj(x)

            x = self.dropout_module(x)
            x = conv(x)
            x = F.glu(x, dim=2)

            # attention
            if attention is not None:
                r = x
                x, attn_scores = attention(
                    attproj(x) + target_embedding, encoder_a, encoder_b
                )
                x = x + r
                if not self.training and self.need_attn:
                    if avg_attn_scores is None:
                        avg_attn_scores = attn_scores
                    else:
                        # accumulated in place (a running sum; no division is applied)
                        avg_attn_scores.add_(attn_scores)

            if selfattention is not None:
                x = selfattention(x)

            x = (x + residual) * math.sqrt(0.5)

        # T x B x C -> B x T x C
        x = x.transpose(0, 1)

        # project back to size of vocabulary
        x = self.fc2(x)
        x = self.dropout_module(x)
        if not self.pretrained:
            x = self.fc3(x)

        # fusion gating
        if self.pretrained:
            # run the pretrained decoder only for its side effect: the forward
            # hook on its fc2 stores the hidden state in self.pretrained_outputs
            # (the returned trained_x itself is not used)
            trained_x, _ = self.pretrained_decoder.forward(
                prev_output_tokens, trained_encoder_out
            )
            y = torch.cat([x, self.pretrained_outputs["out"]], dim=-1)
            gate1 = self.gate1(y)
            gate2 = self.gate2(y)
            gated_x1 = gate1 * x
            gated_x2 = gate2 * self.pretrained_outputs["out"]
            fusion = torch.cat([gated_x1, gated_x2], dim=-1)
            fusion = self.joining(fusion)
            fusion_output = self.fc3(fusion)
            return fusion_output, avg_attn_scores
        else:
            return x, avg_attn_scores

    def max_positions(self):
        """Maximum output length supported by the decoder."""
        return self.embed_positions.max_positions

    def make_generation_fast_(self, need_attn=False, **kwargs):
        """Set whether attention scores are collected during generation."""
        self.need_attn = need_attn

    def _split_encoder_out(self, encoder_out):
        """Split and transpose encoder outputs."""
        # transpose only once to speed up attention layers
        encoder_a, encoder_b = encoder_out
        encoder_a = encoder_a.transpose(0, 1).contiguous()
        encoder_b = encoder_b.transpose(0, 1).contiguous()
        result = (encoder_a, encoder_b)
        return result


class SelfAttention(nn.Module):
    """Self-attention block: Q/K/V projections, multi-head attention with
    future timesteps masked, a residual connection and layer normalisation.
    """

    def __init__(
        self,
        out_channels,
        embed_dim,
        num_heads,
        project_input=False,
        gated=False,
        downsample=False,
    ):
        super().__init__()
        attention_options = dict(
            dropout=0,
            bias=True,
            project_input=project_input,
            gated=gated,
            downsample=downsample,
        )
        self.attention = DownsampledMultiHeadAttention(
            out_channels, embed_dim, num_heads, **attention_options
        )
        # query, key and value projections, created in this fixed order
        for proj_name in ("in_proj_q", "in_proj_k", "in_proj_v"):
            setattr(self, proj_name, Linear(out_channels, embed_dim))
        self.ln = LayerNorm(out_channels)

    def forward(self, x):
        """Apply masked self-attention to ``x`` and return the normalised
        sum of the attention output and the input."""
        q = self.in_proj_q(x)
        k = self.in_proj_k(x)
        v = self.in_proj_v(x)
        attended, _ = self.attention(
            q, k, v, mask_future_timesteps=True, use_scalar_bias=True
        )
        return self.ln(attended + x)


def Embedding(num_embeddings, embedding_dim, padding_idx):
    """Create an ``nn.Embedding`` initialised from N(0, 0.1).

    Args:
        num_embeddings: number of rows in the embedding table.
        embedding_dim: size of each embedding vector.
        padding_idx: index of the padding entry, or ``None`` for no padding.

    Returns:
        the initialised ``nn.Embedding``; the padding row (if any) is all zeros.
    """
    m = nn.Embedding(num_embeddings, embedding_dim, padding_idx=padding_idx)
    m.weight.data.normal_(0, 0.1)
    # normal_() re-initialises every row, including the padding row that
    # nn.Embedding had zeroed; restore it so padding contributes nothing
    if m.padding_idx is not None:
        m.weight.data[m.padding_idx].zero_()
    return m


def PositionalEmbedding(num_embeddings, embedding_dim, padding_idx):
    """Create a ``LearnedPositionalEmbedding`` initialised from N(0, 0.1).

    Args:
        num_embeddings: number of rows in the positional table.
        embedding_dim: size of each positional vector.
        padding_idx: index of the padding entry, or ``None`` for no padding.

    Returns:
        the initialised module; the padding row (if any) is all zeros.
    """
    m = LearnedPositionalEmbedding(num_embeddings, embedding_dim, padding_idx)
    m.weight.data.normal_(0, 0.1)
    # normal_() overwrites every row of the table; zero the padding row again
    # so that padded positions do not receive a random positional vector
    if padding_idx is not None:
        m.weight.data[padding_idx].zero_()
    return m


def Linear(in_features, out_features, dropout=0.0):
    """Build an ``nn.Linear`` (input: N x T x C) with dropout-aware init.

    Weights are drawn from a zero-mean normal distribution with standard
    deviation ``sqrt((1 - dropout) / in_features)``; the bias is zeroed.
    """
    layer = nn.Linear(in_features, out_features)
    init_std = math.sqrt((1 - dropout) / in_features)
    nn.init.normal_(layer.weight, mean=0, std=init_std)
    nn.init.constant_(layer.bias, 0)
    return layer


def LinearizedConv1d(in_channels, out_channels, kernel_size, dropout=0.0, **kwargs):
    """Build a ``LinearizedConvolution`` (Conv1d optimized for decoding).

    Weights are drawn from a zero-mean normal distribution with standard
    deviation ``sqrt(4 * (1 - dropout) / (kernel_width * in_channels))``;
    the bias is zeroed. Extra keyword arguments go to the constructor.
    """
    conv = LinearizedConvolution(in_channels, out_channels, kernel_size, **kwargs)
    fan_in = conv.kernel_size[0] * in_channels
    init_std = math.sqrt((4 * (1.0 - dropout)) / fan_in)
    nn.init.normal_(conv.weight, mean=0, std=init_std)
    nn.init.constant_(conv.bias, 0)
    return conv


def ConvTBC(in_channels, out_channels, kernel_size, dropout=0.0, **kwargs):
    """Build a fairseq ``ConvTBC`` layer with dropout-aware initialisation.

    Weights are drawn from a zero-mean normal distribution with standard
    deviation ``sqrt(4 * (1 - dropout) / (kernel_width * in_channels))``;
    the bias is zeroed. Extra keyword arguments go to the constructor.
    """
    # imported here under an alias: this factory has the same name as the
    # class, so the module-level name refers to this function
    from fairseq.modules import ConvTBC as _ConvTBC

    conv = _ConvTBC(in_channels, out_channels, kernel_size, **kwargs)
    fan_in = conv.kernel_size[0] * in_channels
    init_std = math.sqrt((4 * (1.0 - dropout)) / fan_in)
    nn.init.normal_(conv.weight, mean=0, std=init_std)
    nn.init.constant_(conv.bias, 0)
    return conv


@register_model_architecture("fconv_self_att", "fconv_self_att")
def base_architecture(args):
    """Fill in the default hyper-parameters of the base architecture.

    Every attribute listed below is set on ``args``; a value that is already
    present is kept, otherwise the listed default is used.
    """
    defaults = (
        ("dropout", 0.1),
        ("encoder_embed_dim", 512),
        ("encoder_layers", "[(512, 3)] * 3"),
        ("decoder_embed_dim", 512),
        ("decoder_layers", "[(512, 3)] * 8"),
        ("decoder_out_embed_dim", 256),
        ("decoder_attention", "True"),
        ("self_attention", "False"),
        ("encoder_attention", "False"),
        ("multihead_attention_nheads", 1),
        ("multihead_self_attention_nheads", 1),
        ("encoder_attention_nheads", 1),
        ("project_input", "False"),
        ("gated_attention", "False"),
        ("downsample", "False"),
        ("pretrained_checkpoint", ""),
        ("pretrained", "False"),
    )
    for attr, fallback in defaults:
        setattr(args, attr, getattr(args, attr, fallback))


@register_model_architecture("fconv_self_att", "fconv_self_att_wp")
def fconv_self_att_wp(args):
    """Fill in the defaults of the ``fconv_self_att_wp`` architecture.

    Values already present on ``args`` are kept; the remaining attributes are
    then completed by ``base_architecture``.
    """
    overrides = (
        ("encoder_embed_dim", 256),
        ("encoder_layers", "[(128, 3)] * 2 + [(512,3)] * 1"),
        ("decoder_embed_dim", 256),
        (
            "decoder_layers",
            "[(512, 4)] * 4 + [(768, 4)] * 2 + [(1024, 4)] * 1",
        ),
        ("decoder_out_embed_dim", 256),
        ("self_attention", "True"),
        ("multihead_self_attention_nheads", 4),
        ("project_input", "True"),
        ("gated_attention", "True"),
        ("downsample", "True"),
    )
    for attr, fallback in overrides:
        setattr(args, attr, getattr(args, attr, fallback))
    base_architecture(args)