import tensorflow as tf
from tensor2tensor.models import common_layers
import numpy as np

def CreateLayerHistory(hparams, is_encoder, name=None):
    """Build the layer-history object selected by the hyper-parameters.

    Args:
        hparams: object exposing ``encoder_history_type`` and
            ``decoder_history_type``; it is also forwarded to the history
            class that gets constructed.
        is_encoder: if true, ``hparams.encoder_history_type`` picks the
            history class, otherwise ``hparams.decoder_history_type`` does.
        name: only forwarded to ``LearnableDenseLayerHistory``, which uses it
            as a variable-scope prefix; ignored by the other history types.

    Returns:
        ``None`` when the selected history type is ``None``, otherwise a new
        history instance of the matching class.

    Raises:
        ValueError: if the selected history type is not one of ``None``,
            ``"dense"``, ``"learnable_dense"`` or ``"SqueezeExcitation"``.
    """
    history_type = hparams.encoder_history_type if is_encoder else hparams.decoder_history_type
    if history_type is None:
        return None
    if history_type == "dense":
        return DenseLayerHistory(hparams, is_encoder)
    if history_type == "learnable_dense":
        return LearnableDenseLayerHistory(hparams, is_encoder, name=name)
    if history_type == "SqueezeExcitation":
        return SqueezeExcitationLayerHistory(hparams, is_encoder)
    # Name the offending value so a misconfigured run can be diagnosed.
    raise ValueError(
        "unknown %s history type: %r (expected None, 'dense', "
        "'learnable_dense' or 'SqueezeExcitation')"
        % ("encoder" if is_encoder else "decoder", history_type))


class BaseLayerHistory(object):
    """Common state and interface for the layer-history implementations.

    Subclasses override ``add``, ``pop`` and ``clean``.

    Attributes set by ``__init__``:
        hparams: the hyper-parameter object passed in.
        is_encoder: whether this history belongs to the encoder side.
        layer_num: ``hparams.encoder_layers`` or ``hparams.decoder_layers``.
        dim: ``hparams.hidden_size``.
        layer_norms: ``layer_num`` references to ``common_layers.layer_norm``.
        normalize_before: ``hparams.normalize_before``.
    """

    def __init__(self, hparams, is_encoder):
        super(BaseLayerHistory, self).__init__()
        self.hparams = hparams
        self.is_encoder = is_encoder
        # the first layer (aka. embedding layer) does not have layer normalization
        self.layer_num = hparams.encoder_layers if is_encoder else hparams.decoder_layers
        self.dim = hparams.hidden_size
        self.layer_norms = [common_layers.layer_norm for _ in range(self.layer_num)]
        self.normalize_before = hparams.normalize_before

    def add(self, layer):
        """Record ``layer`` in the history; must be overridden."""
        # NotImplemented is a comparison sentinel, not an exception: raising
        # it produces a confusing TypeError. NotImplementedError is correct.
        raise NotImplementedError

    def pop(self):
        """Return the aggregated history; must be overridden."""
        raise NotImplementedError

    def clean(self):
        """Reset the history; must be overridden."""
        raise NotImplementedError


class DenseLayerHistory(BaseLayerHistory):
    """Layer history that averages every layer added so far.

    x_n = (x_1 + y_1 + y_2 + ... y_{n-1}) / n

    Only a running sum and a counter are kept. With ``normalize_before``
    set, each layer after the first is layer-normalised as it is added and
    ``pop`` returns the plain average; otherwise layers are summed as given
    and ``pop`` layer-normalises the average (except for the first layer).
    """

    def __init__(self, hparams, is_encoder):
        super(DenseLayerHistory, self).__init__(hparams, is_encoder)
        self.count = 0
        self.sum = None

    def add(self, layer):
        """Fold ``layer`` into the running sum."""
        self.count += 1

        if self.sum is not None:
            # Any layer after the first one.
            if self.normalize_before:
                norm_idx = self.count - 2
                layer = self.layer_norms[norm_idx](layer, name="layer_norm%d" % norm_idx)
            self.sum = self.sum + layer
        else:
            # The first layer is stored as given, without normalisation.
            self.sum = layer

    def pop(self):
        """Return the average of the added layers (layer-normalised if post-norm)."""
        assert self.sum is not None
        averaged = self.sum / self.count

        needs_post_norm = self.count != 1 and not self.normalize_before
        if not needs_post_norm:
            return averaged

        norm_idx = self.count - 2
        return self.layer_norms[norm_idx](averaged, name="layer_norm%d" % norm_idx)

    def clean(self):
        """Forget all added layers."""
        self.count = 0
        self.sum = None


class LearnableDenseLayerHistory(BaseLayerHistory):
    """Layer history that combines the stored layers with learned weights.

    x_n = sum_{i < n} W[n-1, i] * layer_i

    ``W`` is a trainable ``(layer_num + 1) x (layer_num + 1)`` variable
    initialised to a lower-triangular matrix of ones and then divided by its
    row sums, so at initialisation row ``n-1`` is a uniform average over the
    first ``n`` layers.
    """

    def __init__(self, hparams, is_encoder, name):
        super(LearnableDenseLayerHistory, self).__init__(hparams, is_encoder)
        self.sum = None
        self.count = 0
        # Ones on and below the diagonal, zeros above.
        lower_triangle = tf.matrix_band_part(tf.ones([self.layer_num + 1, self.layer_num + 1]), -1, 0)
        with tf.variable_scope("%s/layer_history" % name):
            self.weight = tf.get_variable("layer_weight", initializer=lower_triangle)
            # Row-normalise; self.weight now refers to the normalised tensor,
            # not to the raw variable.
            scale = tf.reduce_sum(self.weight, axis=1, keep_dims=True)
            self.weight = self.weight / scale

        self.layers = []

    def add(self, layer):
        """Append ``layer`` to the history (layer-normalised first if pre-norm)."""
        self.count += 1

        # first layer: stored as given; self.sum only serves as the
        # "something has been added" flag in this class.
        if self.sum is None:
            self.sum = layer
            self.layers.append(layer)
            return

        # following layer
        if self.normalize_before:
            layer = self.layer_norms[self.count - 2](layer, name="layer_norm%d" % (self.count - 2))

        self.layers.append(layer)

    def pop(self):
        """Return the weighted sum of the stored layers using row ``count - 1`` of the weights."""
        assert len(self.layers) > 0
        # The [-1, 1, 1, 1] reshape broadcasts one scalar weight per stacked
        # layer; assumes every stored layer is rank 3 -- TODO confirm.
        layer_weights = tf.reshape(self.weight[self.count - 1, :self.count], shape=[-1, 1, 1, 1])
        ret = tf.reduce_sum(tf.stack(self.layers, axis=0) * layer_weights, axis=0)
        if self.count == 1 or self.normalize_before:
            return ret
        return self.layer_norms[self.count - 2](ret, name="layer_norm%d" % (self.count - 2))

    def clean(self):
        """Forget all added layers (the learned weights are kept)."""
        self.sum = None
        self.count = 0
        self.layers = []

    def get_loss(self):
        """Return the mean over rows of ``0.5 * (row_sum - 1) ** 2`` for the weights.

        Uses TensorFlow reductions: ``tf.Tensor`` has no ``.sum()`` or
        ``.mean()`` methods.
        """
        # NOTE(review): self.weight is already row-normalised in __init__, so
        # this penalty looks like it is ~0 by construction -- confirm intent.
        row_sums = tf.reduce_sum(self.weight, axis=1)
        return tf.reduce_mean(0.5 * (row_sums - 1.0) ** 2)


class SqueezeExcitationLayerHistory(BaseLayerHistory):
    """Layer history that gates the newest layer with a squeeze-excitation scale.

    The average of all added layers is passed through two
    ``common_layers.Linear`` calls (the first with a ReLU activation and
    ``dim / se_scale`` units, the second with ``dim`` units) and a sigmoid;
    the result multiplies the most recently added layer.
    """

    def __init__(self, hparams, is_encoder):
        super(SqueezeExcitationLayerHistory, self).__init__(hparams, is_encoder)
        self.layers = []
        self.count = 0
        self.sum = None
        self.fc1_list = [common_layers.Linear for _ in range(self.layer_num)]
        self.fc2_list = [common_layers.Linear for _ in range(self.layer_num)]

    def add(self, layer):
        """Store ``layer`` and fold it into the running sum."""
        self.count += 1
        is_first = self.sum is None

        # Only layers after the first are layer-normalised on the way in.
        if not is_first and self.normalize_before:
            norm_idx = self.count - 2
            layer = self.layer_norms[norm_idx](layer, name="layer_norm%d" % norm_idx)

        self.sum = layer if is_first else self.sum + layer
        self.layers.append(layer)

    def pop(self):
        """Return the gated newest layer (plus residual and layer norm if post-norm)."""
        assert len(self.layers) > 0
        averaged = self.sum / self.count
        if self.count == 1:
            # Only the first layer so far: return it without gating.
            return averaged

        idx = self.count - 2
        squeezed = self.fc1_list[idx](
            averaged,
            int(self.dim / self.hparams.se_scale),
            activation=tf.nn.relu,
            name="se_dense1_layer%d" % idx)
        excited = self.fc2_list[idx](squeezed, self.dim, name="se_dense2_layer%d" % idx)
        gate = tf.nn.sigmoid(excited)

        newest = self.layers[-1]
        gated = newest * gate
        if self.normalize_before:
            return gated

        # Post-norm: add the newest layer back as a residual, then normalise.
        return self.layer_norms[idx](gated + newest, name="layer_norm%d" % idx)

    def clean(self):
        """Forget all added layers."""
        self.layers = []
        self.count = 0
        self.sum = None
