# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import logging
from typing import List, Tuple

import torch
import torch.nn.functional as F
from fairseq.data import Dictionary
from torch import nn


# Character index used for padding in character-level inputs: a word whose
# first character equals this value is treated as a pad token.
CHAR_PAD_IDX = 0
# First-character value that marks an end-of-sentence word in character-level
# inputs. It lies outside the 257-entry character embedding table, so such rows
# are zeroed before the lookup and their output is replaced by a learned EOS
# vector.
CHAR_EOS_IDX = 257


logger = logging.getLogger(__name__)


class CharacterTokenEmbedder(torch.nn.Module):
    """Embed tokens by running a character-level CNN over their spelling.

    Each word is represented as ``max_char_len`` character indices (UTF-8
    byte value + 1, with 0 reserved for padding). The characters are
    embedded, passed through a bank of 1-D convolutions (max-pooled over the
    character axis, then ReLU), optionally through a :class:`Highway`
    network, and finally linearly projected to ``word_embed_dim``.
    Pad positions get a zero vector; EOS positions (and, for word-index
    inputs, UNK positions) get dedicated learned vectors from
    ``symbol_embeddings``.

    Args:
        vocab: word dictionary used to map token indices to characters; may
            be ``None`` only when ``char_inputs`` is True.
        filters: list of ``(kernel_width, out_channels)`` convolution specs.
        char_embed_dim: size of each character embedding.
        word_embed_dim: size of the produced word embeddings.
        highway_layers: number of highway layers (0 disables the network).
        max_char_len: number of characters kept per word; longer vocabulary
            words are truncated.
        char_inputs: if True, ``forward`` receives character indices directly
            instead of word indices.
    """

    def __init__(
        self,
        vocab: Dictionary,
        filters: List[Tuple[int, int]],
        char_embed_dim: int,
        word_embed_dim: int,
        highway_layers: int,
        max_char_len: int = 50,
        char_inputs: bool = False,
    ):
        super(CharacterTokenEmbedder, self).__init__()

        self.onnx_trace = False
        self.embedding_dim = word_embed_dim
        self.max_char_len = max_char_len
        # 257 rows: index 0 is padding, indices 1..256 are UTF-8 byte values + 1
        # (see set_vocab).
        self.char_embeddings = nn.Embedding(257, char_embed_dim, padding_idx=0)
        # Learned replacement vectors for special tokens, indexed by
        # self.eos_idx / self.unk_idx; initialized in reset_parameters().
        self.symbol_embeddings = nn.Parameter(torch.FloatTensor(2, word_embed_dim))
        self.eos_idx, self.unk_idx = 0, 1
        self.char_inputs = char_inputs

        # One Conv1d per (kernel_width, out_channels) filter spec.
        self.convolutions = nn.ModuleList()
        for width, out_c in filters:
            self.convolutions.append(
                nn.Conv1d(char_embed_dim, out_c, kernel_size=width)
            )

        # The pooled outputs of all convolutions are concatenated.
        last_dim = sum(f[1] for f in filters)

        self.highway = Highway(last_dim, highway_layers) if highway_layers > 0 else None

        self.projection = nn.Linear(last_dim, word_embed_dim)

        assert (
            vocab is not None or char_inputs
        ), "vocab must be set if not using char inputs"
        self.vocab = None
        if vocab is not None:
            self.set_vocab(vocab, max_char_len)

        self.reset_parameters()

    def prepare_for_onnx_export_(self):
        """Switch masking in ``forward`` to trace-friendly ``torch.where`` ops."""
        self.onnx_trace = True

    def set_vocab(self, vocab, max_char_len):
        """Build the word-index -> character-index lookup table.

        Row ``i`` of ``self.word_to_char`` holds the UTF-8 bytes of
        ``vocab[i]`` shifted by +1 (so 0 stays free for padding), truncated
        to ``max_char_len`` and right-padded with 0. The first
        ``vocab.nspecial`` rows (special symbols) are all padding.

        Args:
            vocab: dictionary supporting ``len``, integer indexing to the
                symbol string and a ``nspecial`` attribute.
            max_char_len: number of character slots per word.
        """
        # Zero-initialized, so padding and special-symbol rows need no extra work.
        word_to_char = torch.zeros(len(vocab), max_char_len, dtype=torch.long)

        truncated = 0
        for i in range(vocab.nspecial, len(vocab)):
            # +1 shifts byte values 0..255 to 1..256, keeping 0 for padding.
            char_idxs = [c + 1 for c in vocab[i].encode()]
            if len(char_idxs) > max_char_len:
                truncated += 1
                char_idxs = char_idxs[:max_char_len]
            word_to_char[i, : len(char_idxs)] = torch.tensor(
                char_idxs, dtype=torch.long
            )

        if truncated > 0:
            logger.info(
                "truncated %d words longer than %d characters",
                truncated,
                max_char_len,
            )

        self.vocab = vocab
        self.word_to_char = word_to_char

    @property
    def padding_idx(self):
        """Word-level pad index: from ``self.vocab`` if set, else a default Dictionary."""
        return Dictionary().pad() if self.vocab is None else self.vocab.pad()

    def reset_parameters(self):
        """(Re-)initialize embeddings and the output projection.

        The character padding row and the projection bias are set to zero.
        """
        nn.init.xavier_normal_(self.char_embeddings.weight)
        nn.init.xavier_normal_(self.symbol_embeddings)
        nn.init.xavier_uniform_(self.projection.weight)

        nn.init.constant_(
            self.char_embeddings.weight[self.char_embeddings.padding_idx], 0.0
        )
        nn.init.constant_(self.projection.bias, 0.0)

    def forward(
        self,
        input: torch.Tensor,
    ):
        """Embed a batch of tokens.

        Args:
            input: word indices into ``self.vocab`` when ``char_inputs`` is
                False; otherwise character indices viewable as
                ``(-1, max_char_len)`` (presumably
                ``(batch, seq, max_char_len)`` -- TODO confirm with callers).
                The tensor is never modified.

        Returns:
            Tensor of shape ``input.shape[:2] + (word_embed_dim,)``. Pad
            positions are zero vectors, EOS positions are
            ``symbol_embeddings[eos_idx]`` and, for word inputs, UNK
            positions are ``symbol_embeddings[unk_idx]``.
        """
        if self.char_inputs:
            chars = input.view(-1, self.max_char_len)
            # A word's kind is decided by its first character.
            pads = chars[:, 0].eq(CHAR_PAD_IDX)
            eos = chars[:, 0].eq(CHAR_EOS_IDX)
            if eos.any():
                # CHAR_EOS_IDX is outside the embedding table, so EOS rows are
                # blanked before lookup (their output is overwritten below).
                # This must be out-of-place: `chars` is a view of `input`, and
                # writing into it would silently alter the caller's tensor, so a
                # later call on the same input would see those rows as padding.
                chars = torch.where(eos.unsqueeze(1), chars.new_zeros(1), chars)

            unk = None
        else:
            flat_words = input.view(-1)
            # The lookup table is indexed on its own device/type, then the
            # result is cast back to the input's tensor type.
            chars = self.word_to_char[flat_words.type_as(self.word_to_char)].type_as(
                input
            )
            pads = flat_words.eq(self.vocab.pad())
            eos = flat_words.eq(self.vocab.eos())
            unk = flat_words.eq(self.vocab.unk())

        word_embs = self._convolve(chars)
        if self.onnx_trace:
            # Out-of-place masking so the ops are traceable.
            if pads.any():
                word_embs = torch.where(
                    pads.unsqueeze(1), word_embs.new_zeros(1), word_embs
                )
            if eos.any():
                word_embs = torch.where(
                    eos.unsqueeze(1), self.symbol_embeddings[self.eos_idx], word_embs
                )
            if unk is not None and unk.any():
                word_embs = torch.where(
                    unk.unsqueeze(1), self.symbol_embeddings[self.unk_idx], word_embs
                )
        else:
            # word_embs is a fresh activation, so in-place masking is safe here.
            if pads.any():
                word_embs[pads] = 0
            if eos.any():
                word_embs[eos] = self.symbol_embeddings[self.eos_idx]
            if unk is not None and unk.any():
                word_embs[unk] = self.symbol_embeddings[self.unk_idx]

        return word_embs.view(input.size()[:2] + (-1,))

    def _convolve(
        self,
        char_idxs: torch.Tensor,
    ):
        """Map ``(num_words, max_char_len)`` char indices to ``(num_words, word_embed_dim)``.

        Each convolution output is max-pooled over the character axis and
        passed through ReLU; the results are concatenated, optionally fed
        through the highway network, and projected.
        """
        char_embs = self.char_embeddings(char_idxs)
        char_embs = char_embs.transpose(1, 2)  # BTC -> BCT

        conv_result = []

        for conv in self.convolutions:
            x = conv(char_embs)
            x, _ = torch.max(x, -1)
            x = F.relu(x)
            conv_result.append(x)

        x = torch.cat(conv_result, dim=-1)

        if self.highway is not None:
            x = self.highway(x)
        x = self.projection(x)

        return x


class Highway(torch.nn.Module):
    """
    A `Highway layer <https://arxiv.org/abs/1505.00387>`_, following the
    AllenNLP implementation.

    Every layer is a single ``Linear(input_dim, 2 * input_dim)`` whose output
    is split into a candidate ``h`` and a gate pre-activation ``t``:

        y = sigmoid(t) * x + (1 - sigmoid(t)) * relu(h)

    Layers are applied sequentially and preserve the last dimension.
    """

    def __init__(self, input_dim: int, num_layers: int = 1):
        super().__init__()
        self.input_dim = input_dim
        doubled = 2 * input_dim
        self.layers = nn.ModuleList(
            nn.Linear(input_dim, doubled) for _ in range(num_layers)
        )
        self.activation = nn.ReLU()

        self.reset_parameters()

    def reset_parameters(self):
        dim = self.input_dim
        for linear in self.layers:
            # Following AllenNLP: start the gate biased towards carrying the
            # input forward. The first half of the bias feeds the candidate,
            # the second half feeds the gate, which gets a positive bias of 1.
            nn.init.constant_(linear.bias[:dim], 0)
            nn.init.constant_(linear.bias[dim:], 1)
            nn.init.xavier_normal_(linear.weight)

    def forward(self, x: torch.Tensor):
        for linear in self.layers:
            candidate, gate_logits = linear(x).chunk(2, dim=-1)
            carry = torch.sigmoid(gate_logits)
            transformed = self.activation(candidate)
            x = carry * x + (carry.new_tensor([1]) - carry) * transformed
        return x