# Copyright 2017 The Tensor2Tensor Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""transformer (attention).

encoder: [Self-Attention, Feed-forward] x n
decoder: [Self-Attention, Source-Target-Attention, Feed-forward] x n
"""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import copy

# Dependency imports

from six.moves import xrange  # pylint: disable=redefined-builtin

from tensor2tensor.models import common_attention
from tensor2tensor.models import common_hparams
from tensor2tensor.models import common_layers
from tensor2tensor.utils import registry
from tensor2tensor.utils import t2t_model

import tensorflow as tf


@registry.register_model
class TransformerMLRF(t2t_model.T2TModel):
    """Attention net with optional fusion of layer outputs.  See file docstring."""

    def model_fn_body(self, features):
        """Build the encoder-decoder body of the model.

        Args:
          features: a dict; "targets" is required, "inputs" and
            "target_space_id" are looked up with `dict.get`.

        Returns:
          The decoder output with an extra axis inserted at position 2.
        """
        hparams = copy.copy(self._hparams)
        raw_targets = features["targets"]
        raw_inputs = features.get("inputs")
        target_space = features.get("target_space_id")

        flat_inputs = common_layers.flatten4d3d(raw_inputs)
        flat_targets = common_layers.flatten4d3d(raw_targets)

        # The third element (encoder padding) is not needed here.
        encoder_prepared = transformer_prepare_encoder(
            flat_inputs, target_space, hparams)
        encoder_input = encoder_prepared[0]
        encoder_attention_bias = encoder_prepared[1]

        decoder_prepared = transformer_prepare_decoder(flat_targets, hparams)
        decoder_input, decoder_self_attention_bias = decoder_prepared

        def residual_fn(x, y):
            """Apply dropout to y, add it to x and layer-normalize the sum."""
            dropped = tf.nn.dropout(y, 1.0 - hparams.residual_dropout)
            return common_layers.layer_norm(x + dropped)

        # Dropout on the inputs of both stacks, using the residual rate.
        encoder_input = tf.nn.dropout(
            encoder_input, 1.0 - hparams.residual_dropout)
        decoder_input = tf.nn.dropout(
            decoder_input, 1.0 - hparams.residual_dropout)

        encoder_output = transformer_encoder(
            encoder_input, residual_fn, encoder_attention_bias, hparams)
        decoder_output = transformer_decoder(
            decoder_input,
            encoder_output,
            residual_fn,
            decoder_self_attention_bias,
            encoder_attention_bias,
            hparams)

        return tf.expand_dims(decoder_output, 2)


def transformer_prepare_encoder(inputs, target_space, hparams):
    """Prepare one shard of the model for the encoder.

    Derives the padding mask and attention bias from `inputs`, adds a
    target-space embedding to every position and, when `hparams.pos` is
    "timing", adds the timing signal.

    Args:
      inputs: a Tensor.
      target_space: a Tensor.
      hparams: run hyperparameters

    Returns:
      encoder_input: a Tensor, bottom of encoder stack
      encoder_self_attention_bias: a Tensor, containing large negative values
        to implement masked attention and possibly biases for diagonal
        alignments
      encoder_padding: a Tensor
    """
    static_dims = inputs.shape.as_list()
    # Padding and bias are computed from the raw inputs, before anything is
    # added to them.
    padding = common_attention.embedding_to_padding(inputs)
    self_attention_bias = common_attention.attention_bias_ignore_padding(
        padding)
    # Add the target_space_id embedding to every input position; it is sized
    # to the last static dimension of the inputs.
    space_embedding = common_layers.embedding(
        target_space, 32, static_dims[-1], name="target_space_embedding")
    space_embedding = tf.reshape(space_embedding, [1, 1, -1])
    prepared = inputs + space_embedding
    if hparams.pos == "timing":
        prepared = common_attention.add_timing_signal_1d(prepared)
    return (prepared, self_attention_bias, padding)


def transformer_prepare_decoder(targets, hparams):
    """Prepare one shard of the model for the decoder.

    When `hparams.fix_padding` is set, the self-attention bias masks both
    future positions and padded target positions; otherwise only future
    positions are masked (lower-triangle bias).

    Args:
      targets: a Tensor.
      hparams: run hyperparameters

    Returns:
      decoder_input: a Tensor, bottom of decoder stack
      decoder_self_attention_bias: a Tensor, containing large negative values
        to implement masked attention and possibly biases for diagonal
        alignments
    """
    if hparams.fix_padding:
        # (B, Lt) padding indicator computed from the (unshifted) targets.
        # The `|` below requires this to be a boolean tensor.
        decoder_padding = common_attention.embedding_to_padding(targets)
        # (B, 1, 1, Lt)
        decoder_padding_exp = tf.expand_dims(
            tf.expand_dims(decoder_padding, 1), 1)
        length = tf.shape(targets)[1]
        # (Lt, Lt): 1 strictly above the diagonal (future positions), else 0.
        history_padding = 1 - tf.matrix_band_part(
            tf.ones([length, length]), -1, 0)
        # (1, 1, Lt, Lt)
        history_padding = tf.expand_dims(
            tf.expand_dims(history_padding, 0), 0)
        # (B, 1, Lt, Lt): a key is masked if it is padding or in the future.
        padding = decoder_padding_exp | tf.cast(history_padding, dtype=tf.bool)
        decoder_self_attention_bias = tf.to_float(padding) * -1e9
    else:
        # Causal mask only; padded positions are not masked.
        decoder_self_attention_bias = (
            common_attention.attention_bias_lower_triangle(
                tf.shape(targets)[1]))

    decoder_input = common_layers.shift_left_3d(targets)
    if hparams.pos == "timing":
        decoder_input = common_attention.add_timing_signal_1d(decoder_input)
    return (decoder_input, decoder_self_attention_bias)


def transformer_encoder(encoder_input,
                        residual_fn,
                        encoder_self_attention_bias,
                        hparams,
                        name="encoder"):
    """A stack of transformer encoder layers with optional layer fusion.

    Each of the `hparams.num_hidden_layers` layers applies self-attention and
    then the feed-forward layer, each combined with its input by
    `residual_fn`.  When `hparams.fuse_encoder` is set, the output of every
    layer (preceded by the stack input when `hparams.fuse_word_embedding` is
    set) is collected and passed to `layer_fusion`.

    Args:
      encoder_input: a Tensor
      residual_fn: a function from (layer_input, layer_output) -> combined_output
      encoder_self_attention_bias: bias Tensor for self-attention
         (see common_attention.attention_bias())
      hparams: hyperparameters for model
      name: a string

    Returns:
      a Tensor: the last layer's output, or the result of `layer_fusion` when
      `hparams.fuse_encoder` is set.
    """
    # Summaries don't work in multi-problem setting yet.
    emit_summaries = (
        "problems" not in hparams.values() or len(hparams.problems) == 1)
    fuse_layers = hparams.fuse_encoder
    collected = None
    hidden = encoder_input
    with tf.variable_scope(name):
        if fuse_layers:
            collected = [hidden] if hparams.fuse_word_embedding else []

        for layer_index in xrange(hparams.num_hidden_layers):
            with tf.variable_scope("layer_%d" % layer_index):
                attended = common_attention.multihead_attention(
                    hidden,
                    None,
                    encoder_self_attention_bias,
                    hparams.attention_key_channels or hparams.hidden_size,
                    hparams.attention_value_channels or hparams.hidden_size,
                    hparams.hidden_size,
                    hparams.num_heads,
                    hparams.attention_dropout,
                    summaries=emit_summaries,
                    name="encoder_self_attention")
                hidden = residual_fn(hidden, attended)
                ffn_output = transformer_ffn_layer(hidden, hparams)
                hidden = residual_fn(hidden, ffn_output)
            # Keep this layer's output for the fusion step.
            if fuse_layers:
                collected.append(hidden)

    if not fuse_layers:
        return hidden
    # Fusion runs after the `name` variable scope has been left.
    return layer_fusion(collected, name, hparams, residual_fn, is_encoder=True)


def layer_fusion(layer_rep, name, hparams, residual_fn, is_encoder=True):
    """Fuse the representations of several layers into a single one.

    The layer representations are concatenated and reshaped so that the
    layers form their own axis, an optional learned per-layer embedding is
    added, and the result is reduced by the fusion function selected with
    `hparams.fuse_method`.

    Args:
      layer_rep: non-empty list of Tensors, one per fused layer; per the shape
        comments below each is (B, L, H).
      name: a string, forwarded to the fusion function.
      hparams: hyperparameters for model
      residual_fn: a function from (layer_input, layer_output) -> combined_output
      is_encoder: unused here; kept for interface compatibility.

    Returns:
      The Tensor returned by `mhsa_fusion` or `avg_fusion`.

    Raises:
      ValueError: if `hparams.fuse_method` is neither 'mhsa' nor 'avg'.
    """
    B = tf.shape(layer_rep[0])[0]
    L = tf.shape(layer_rep[0])[1]
    H = tf.shape(layer_rep[0])[2]
    N = len(layer_rep)

    # (B,L,N*H)
    fusion = tf.concat(layer_rep, 2)
    # (B*L, N, H)
    x = tf.reshape(fusion, shape=[B * L, N, H])
    if hparams.use_layer_embedding:
        # One embedding row per fused layer.  The row count must equal N so
        # the addition below broadcasts against x; the callers in this file
        # build layer_rep with num_hidden_layers entries, plus one when
        # fuse_word_embedding is set.
        # NOTE(review): 'layer_emb' is created in whatever variable scope is
        # active at the call site, not under `name`.  The encoder and decoder
        # in this file both call layer_fusion after leaving their
        # tf.variable_scope(name) block, so they request the same variable
        # name -- confirm this is intended.
        layer_emb = tf.get_variable(
            'layer_emb', [N, hparams.hidden_size],
            initializer=tf.random_normal_initializer(
                0.0, hparams.hidden_size ** -0.5))
        if hparams.scale_layer_embedding:
            layer_emb *= (hparams.hidden_size ** 0.5)
        x += tf.expand_dims(layer_emb, axis=0)

    if hparams.fuse_method == 'mhsa':
        x = mhsa_fusion(x, hparams, name, residual_fn, layer_rep)
    elif hparams.fuse_method == 'avg':
        x = avg_fusion(x, hparams, name, residual_fn, layer_rep)
    else:
        # Without a fusion step x would be returned still shaped
        # (B*L, N, H), which no caller expects.
        raise ValueError(
            "Unknown fuse_method %r; expected 'mhsa' or 'avg'."
            % (hparams.fuse_method,))

    return x

def avg_fusion(x, hparams, name, residual_fn, layer_rep):
    """
    Fuse layer representations by averaging over the layer axis.

    This function itself creates no variables.
    :param x: stacked layer representations, (B*L, N, H)
    :param hparams: unused
    :param name: unused
    :param residual_fn: unused
    :param layer_rep: list of per-layer Tensors; only the first one is read,
        for its dynamic and static shape
    :return: averaged representation, (B, L, H)
    """
    reference = layer_rep[0]
    batch = tf.shape(reference)[0]
    length = tf.shape(reference)[1]
    depth = tf.shape(reference)[2]

    # Mean over axis 1 (the layer axis): (B*L, N, H) -> (B*L, H)
    pooled = tf.reduce_mean(x, axis=1)
    # Restore the separate batch and length axes.
    pooled = tf.reshape(pooled, [batch, length, depth])
    # Re-attach the static shape of the reference layer.
    pooled.set_shape(reference.get_shape())
    return pooled

def mhsa_fusion(x, hparams, name, residual_fn, layer_rep):
    """
    Fuse layer representations with multi-head attention.

    The slice of the last layer is passed as the first argument of
    multihead_attention (and as the residual input), the full stack as the
    second, and no bias is given.  The result is reshaped to (B, L, H) and
    passed through the MLRF feed-forward layer, both steps combined by
    residual_fn.  Variables are created under the scope `name + '_mlrf'`.
    :param x: stacked layer representations, (B*L, N, H), N is n_layer
    :param hparams: hyperparameters for model
    :param name: prefix of the variable scope
    :param residual_fn: a function from (layer_input, layer_output) to the
        combined output
    :param layer_rep: list of per-layer Tensors; only the first one is read,
        for its dynamic and static shape
    :return: fused representation, (B, L, H)
    """
    reference = layer_rep[0]
    batch = tf.shape(reference)[0]
    length = tf.shape(reference)[1]
    depth = tf.shape(reference)[2]

    with tf.variable_scope(name + '_mlrf'):
        # Last layer's slice with the layer axis kept: (B*L, 1, H)
        last_layer = tf.expand_dims(x[:, -1, :], axis=1)
        attended = common_attention.multihead_attention(
            last_layer,
            x,
            None,
            hparams.attention_key_channels or hparams.hidden_size,
            hparams.attention_value_channels or hparams.hidden_size,
            hparams.hidden_size,
            hparams.fused_num_heads,
            hparams.fused_attention_dropout,
            summaries=False,
            name="mlrf_self_attention")
        # (B*L, 1, H)
        fused = residual_fn(last_layer, attended)
        # (B*L, 1, H) -> (B, L, H), then re-attach the static shape.
        fused = tf.reshape(fused, [batch, length, depth])
        fused.set_shape(reference.get_shape())
        ffn_output = transformer_ffn_layer(fused, hparams, is_mlrf=True)
        return residual_fn(fused, ffn_output)

def transformer_decoder(decoder_input,
                        encoder_output,
                        residual_fn,
                        decoder_self_attention_bias,
                        encoder_decoder_attention_bias,
                        hparams,
                        name="decoder"):
    """A stack of transformer decoder layers with optional layer fusion.

    Each of the `hparams.num_hidden_layers` layers applies self-attention,
    attention over `encoder_output` and the feed-forward layer, each combined
    with its input by `residual_fn`.  When `hparams.fuse_decoder` is set, the
    output of every layer (preceded by the stack input when
    `hparams.fuse_word_embedding` is set) is collected and passed to
    `layer_fusion`.

    Args:
      decoder_input: a Tensor
      encoder_output: a Tensor
      residual_fn: a function from (layer_input, layer_output) -> combined_output
      decoder_self_attention_bias: bias Tensor for self-attention
        (see common_attention.attention_bias())
      encoder_decoder_attention_bias: bias Tensor for encoder-decoder attention
        (see common_attention.attention_bias())
      hparams: hyperparameters for model
      name: a string

    Returns:
      a Tensor: the last layer's output, or the result of `layer_fusion` when
      `hparams.fuse_decoder` is set.
    """
    # Summaries don't work in multi-problem setting yet.
    emit_summaries = (
        "problems" not in hparams.values() or len(hparams.problems) == 1)
    fuse_layers = hparams.fuse_decoder
    collected = None
    hidden = decoder_input
    with tf.variable_scope(name):
        if fuse_layers:
            collected = [hidden] if hparams.fuse_word_embedding else []
        for layer_index in xrange(hparams.num_hidden_layers):
            with tf.variable_scope("layer_%d" % layer_index):
                self_attended = common_attention.multihead_attention(
                    hidden,
                    None,
                    decoder_self_attention_bias,
                    hparams.attention_key_channels or hparams.hidden_size,
                    hparams.attention_value_channels or hparams.hidden_size,
                    hparams.hidden_size,
                    hparams.num_heads,
                    hparams.attention_dropout,
                    summaries=emit_summaries,
                    name="decoder_self_attention")
                hidden = residual_fn(hidden, self_attended)

                source_attended = common_attention.multihead_attention(
                    hidden,
                    encoder_output,
                    encoder_decoder_attention_bias,
                    hparams.attention_key_channels or hparams.hidden_size,
                    hparams.attention_value_channels or hparams.hidden_size,
                    hparams.hidden_size,
                    hparams.num_heads,
                    hparams.attention_dropout,
                    summaries=emit_summaries,
                    name="encdec_attention")
                hidden = residual_fn(hidden, source_attended)

                ffn_output = transformer_ffn_layer(hidden, hparams)
                hidden = residual_fn(hidden, ffn_output)
            # Keep this layer's output for the fusion step.
            if fuse_layers:
                collected.append(hidden)

    if fuse_layers:
        # Fusion runs after the `name` variable scope has been left.
        hidden = layer_fusion(
            collected, name, hparams, residual_fn, is_encoder=False)
    return hidden


def transformer_ffn_layer(x, hparams, is_mlrf=False):
    """Feed-forward layer in the transformer.

    The implementation is selected by `hparams.ffn_layer`, which must be one
    of "conv_hidden_relu", "parameter_attention",
    "conv_hidden_relu_with_sepconv" or "none" (which returns `x` unchanged).

    Args:
      x: a Tensor of shape [batch_size, length, hparams.hidden_size]
      hparams: hyperparameters for model
      is_mlrf: if True, the "conv_hidden_relu" variant takes its inner size
        and dropout from `hparams.fused_inner_hidden` and
        `hparams.fused_relu_dropout` instead of `hparams.filter_size` and
        `hparams.relu_dropout`.  The other variants ignore this flag.

    Returns:
      a Tensor of shape [batch_size, length, hparams.hidden_size]
    """
    ffn_kind = hparams.ffn_layer

    if ffn_kind == "conv_hidden_relu":
        if is_mlrf:
            inner_size = hparams.fused_inner_hidden
        else:
            inner_size = hparams.filter_size
        output_size = hparams.hidden_size
        if is_mlrf:
            dropout_rate = hparams.fused_relu_dropout
        else:
            dropout_rate = hparams.relu_dropout
        return common_layers.conv_hidden_relu(
            x, inner_size, output_size, dropout=dropout_rate)

    if ffn_kind == "parameter_attention":
        return common_attention.parameter_attention(
            x,
            hparams.parameter_attention_key_channels or hparams.hidden_size,
            hparams.parameter_attention_value_channels or hparams.hidden_size,
            hparams.hidden_size,
            hparams.filter_size,
            hparams.num_heads,
            hparams.attention_dropout)

    if ffn_kind == "conv_hidden_relu_with_sepconv":
        return common_layers.conv_hidden_relu(
            x,
            hparams.filter_size,
            hparams.hidden_size,
            kernel_size=(3, 1),
            second_kernel_size=(31, 1),
            padding="LEFT",
            dropout=hparams.relu_dropout)

    # Only the explicit pass-through setting remains.
    assert ffn_kind == "none"
    return x


@registry.register_hparams
def transformer_mlrf_base():
    """Base set of hyperparameters for the MLRF transformer."""
    hp = common_hparams.basic_params1()

    # Values assigned directly on the object returned by basic_params1().
    hp.hidden_size = 512
    hp.batch_size = 4096
    hp.max_length = 256
    hp.dropout = 0.0
    hp.clip_grad_norm = 0.0  # i.e. no gradient clipping
    hp.optimizer_adam_epsilon = 1e-9
    hp.learning_rate_decay_scheme = "noam"
    hp.learning_rate = 0.1
    hp.learning_rate_warmup_steps = 4000
    hp.initializer_gain = 1.0
    hp.num_hidden_layers = 6
    hp.initializer = "uniform_unit_scaling"
    hp.weight_decay = 0.0
    hp.optimizer_adam_beta1 = 0.9
    hp.optimizer_adam_beta2 = 0.98
    hp.num_sampled_classes = 0
    hp.label_smoothing = 0.1
    hp.shared_embedding_and_softmax_weights = int(True)

    # Hyperparameters registered with add_hparam, in registration order.
    added_hparams = (
        ("filter_size", 2048),
        # attention-related flags
        ("num_heads", 8),
        ("attention_key_channels", 0),
        ("attention_value_channels", 0),
        ("ffn_layer", "conv_hidden_relu"),
        ("parameter_attention_key_channels", 0),
        ("parameter_attention_value_channels", 0),
        # All hyperparameters ending in "dropout" are automatically set to
        # 0.0 when not in training mode.
        ("attention_dropout", 0.0),
        ("relu_dropout", 0.0),
        ("residual_dropout", 0.1),
        ("pos", "timing"),  # timing, none
        ("nbr_decoder_problems", 1),
        # MLRF parameters
        ("fuse_method", "mhsa"),
        ("fuse_word_embedding", True),
        ("use_layer_embedding", True),
        ("fuse_encoder", True),
        ("fuse_decoder", True),
        ("fused_num_heads", 8),
        ("fused_inner_hidden", 2048),
        ("fused_relu_dropout", 0.0),
        ("fused_attention_dropout", 0.0),
        ("scale_layer_embedding", True),
        ("fix_padding", False),
    )
    for hparam_name, default in added_hparams:
        hp.add_hparam(hparam_name, default)
    return hp

@registry.register_hparams
def transformer_mlrf_base_noscale_fixpad():
    """Base set with layer-embedding scaling off and fix_padding on."""
    hp = transformer_mlrf_base()
    hp.scale_layer_embedding = False
    hp.fix_padding = True
    return hp

@registry.register_hparams
def transformer_mlrf_base_noscale_dropout1():
    """Base set with 0.1 fusion dropouts and layer-embedding scaling off."""
    hp = transformer_mlrf_base()
    # Dropout inside the fusion attention and feed-forward layers.
    hp.fused_relu_dropout = 0.1
    hp.fused_attention_dropout = 0.1
    hp.scale_layer_embedding = False
    return hp

@registry.register_hparams
def transformer_mlrf_base_both_half_noscale():
    """Base set with half-sized fusion (4 heads, 1024 inner), scaling off."""
    hp = transformer_mlrf_base()
    # Half of the base values 8 and 2048.
    hp.fused_num_heads = 4
    hp.fused_inner_hidden = 1024
    hp.scale_layer_embedding = False
    return hp

@registry.register_hparams
def transformer_mlrf_base_decoder_half_noscale():
    """Half-sized fusion with encoder fusion disabled, scaling off."""
    hp = transformer_mlrf_base()
    hp.fuse_encoder = False
    # Half of the base values 8 and 2048.
    hp.fused_num_heads = 4
    hp.fused_inner_hidden = 1024
    hp.scale_layer_embedding = False
    return hp

@registry.register_hparams
def transformer_mlrf_base_encoder_half_noscale():
    """Half-sized fusion with decoder fusion disabled, scaling off."""
    hp = transformer_mlrf_base()
    hp.fuse_decoder = False
    # Half of the base values 8 and 2048.
    hp.fused_num_heads = 4
    hp.fused_inner_hidden = 1024
    hp.scale_layer_embedding = False
    return hp



@registry.register_hparams
def transformer_mlrf_big():
    """HParams for transformer big model on WMT."""
    hp = transformer_mlrf_base()
    hp.hidden_size = 1024
    hp.filter_size = 4096
    hp.num_heads = 16
    hp.batching_mantissa_bits = 2
    hp.residual_dropout = 0.3

    # Fusion sizes set to the same values as num_heads and filter_size above.
    hp.fused_num_heads = 16
    hp.fused_inner_hidden = 4096
    return hp


@registry.register_hparams
def transformer_mlrf_big_single_gpu():
    """HParams for transformer big model for single gpu."""
    hp = transformer_mlrf_big()
    hp.residual_dropout = 0.1
    hp.learning_rate_warmup_steps = 16000
    hp.optimizer_adam_beta2 = 0.998
    hp.batching_mantissa_bits = 3
    return hp


@registry.register_hparams
def transformer_mlrf_base_single_gpu():
    """HParams for transformer base model for single gpu."""
    hp = transformer_mlrf_base()
    hp.batch_size = 8192
    hp.learning_rate_warmup_steps = 16000
    hp.batching_mantissa_bits = 2
    return hp


@registry.register_hparams
def transformer_mlrf_base_debug():
    """Small configuration: 3 layers, hidden size 64, fix_padding on."""
    hp = transformer_mlrf_base()
    hp.num_hidden_layers = 3
    hp.fused_inner_hidden = 128
    hp.hidden_size = 64
    hp.filter_size = 128
    hp.batch_size = 512
    hp.fix_padding = True
    return hp

@registry.register_hparams
def transformer_mlrf_base_decoder_noscale_nosmooth_dropout1():
    """Decoder-only fusion, 0.1 dropouts, no scaling, no label smoothing."""
    hp = transformer_mlrf_base()
    # Dropout 0.1 in both the fusion layers and the regular layers.
    hp.fused_relu_dropout = 0.1
    hp.fused_attention_dropout = 0.1
    hp.relu_dropout = 0.1
    hp.attention_dropout = 0.1
    hp.scale_layer_embedding = False
    hp.fuse_encoder = False
    hp.label_smoothing = 0.0
    return hp

@registry.register_hparams
def transformer_mlrf_base_decoder_dropout1():
    """Decoder-only fusion, 0.1 dropouts, no scaling; smoothing left at base."""
    hp = transformer_mlrf_base()
    # Dropout 0.1 in both the fusion layers and the regular layers.
    hp.fused_relu_dropout = 0.1
    hp.fused_attention_dropout = 0.1
    hp.relu_dropout = 0.1
    hp.attention_dropout = 0.1
    hp.scale_layer_embedding = False
    hp.fuse_encoder = False
    return hp


@registry.register_hparams
def transformer_mlrf_base_decoder_nosmooth():
    """Decoder-only fusion with scaling and label smoothing turned off."""
    hp = transformer_mlrf_base()
    hp.scale_layer_embedding = False
    hp.fuse_encoder = False
    hp.label_smoothing = 0.0
    return hp