import torch
from torch import nn
import torch.nn.functional as F
from fairseq.models.transformer import Linear, Embedding, LayerNorm
from fairseq.modules.multihead_attention import MultiheadAttention

def CreateLayerStack(args, is_encoder, is_inner):
    """Build the layer stack selected in ``args`` for the encoder or decoder side.

    Args:
        args: model options. Always read: ``stack_dropout``, the
            ``{encoder,decoder}_{inner,outer}_stack_type`` option of the requested
            side, ``{encoder,decoder}_layers``,
            ``{encoder,decoder}_layer_normal_owned_type`` and
            ``{encoder,decoder}_embed_dim``. The other ``stack_*`` options,
            ``add_layer_embed`` and ``{encoder,decoder}_gru_owned_type`` are read
            only by the stack types that need them.
        is_encoder: build the encoder-side stack if True, otherwise the decoder's.
        is_inner: use the ``*_inner_stack_type`` option if True, otherwise
            ``*_outer_stack_type``.

    Returns:
        An instance of one of the ``LayerStackBase`` subclasses.

    Raises:
        ValueError: if the selected stack type is not one of ``"residual"``,
            ``"None"``, ``"gru"``, ``"aan"`` or ``"mhsa"``.
    """
    dropout = args.stack_dropout
    if is_encoder:
        stack_type = args.encoder_inner_stack_type if is_inner else args.encoder_outer_stack_type
        block_num = args.encoder_layers
        # 2 sublayers per encoder block (presumably self-attention + FFN).
        sublayer_num_in_block = 2
        layer_norm_owned_type = args.encoder_layer_normal_owned_type
        layer_norm_dim = args.encoder_embed_dim
    else:
        stack_type = args.decoder_inner_stack_type if is_inner else args.decoder_outer_stack_type
        block_num = args.decoder_layers
        # 3 sublayers per decoder block (presumably self-attn, enc-dec attn, FFN).
        sublayer_num_in_block = 3
        layer_norm_owned_type = args.decoder_layer_normal_owned_type
        layer_norm_dim = args.decoder_embed_dim

    if stack_type == "residual":
        return ResLayerStack(block_num,
                             sublayer_num_in_block,
                             layer_norm_dim,
                             ln_owned_type=layer_norm_owned_type,
                             dropout=dropout)
    elif stack_type == "None":
        return NoneLayerStack(block_num,
                              sublayer_num_in_block,
                              layer_norm_dim,
                              ln_owned_type=layer_norm_owned_type,
                              dropout=dropout)
    elif stack_type == "gru":
        gru_owned_type = args.encoder_gru_owned_type if is_encoder else args.decoder_gru_owned_type
        return GRULayerStack(block_num,
                             sublayer_num_in_block,
                             layer_norm_dim,
                             args.stack_hidden_size,
                             args.stack_hidden_size,
                             ln_owned_type=layer_norm_owned_type,
                             gru_owned_type=gru_owned_type,
                             dropout=dropout)
    elif stack_type == "aan":
        return AANLayerStack(block_num,
                             sublayer_num_in_block,
                             layer_norm_dim,
                             args.stack_hidden_size,
                             args.stack_inner_hidden_size,
                             use_layer_embed=args.add_layer_embed,
                             ln_owned_type=layer_norm_owned_type,
                             relu_dropout=args.stack_relu_dropout,
                             dropout=dropout)
    elif stack_type == "mhsa":
        return MHSALayerStack(block_num,
                              sublayer_num_in_block,
                              layer_norm_dim,
                              args.stack_hidden_size,
                              args.stack_inner_hidden_size,
                              args.stack_head_num,
                              use_layer_embed=args.add_layer_embed,
                              ln_owned_type=layer_norm_owned_type,
                              relu_dropout=args.stack_relu_dropout,
                              dropout=dropout,
                              att_dropout=args.stack_att_dropout)
    else:
        raise ValueError('unsupported stack type = {}'.format(stack_type))


class LayerStackBase(nn.Module):
    """Base class for stacks that combine the outputs of successive sublayers.

    Callers ``push`` tensors onto the stack and ``pop`` a combined result.
    Subclasses implement ``pop``; most finish through ``_pop_post_process``,
    which applies a layer norm chosen by ``ln_owned_type`` and, when
    ``reduce`` is true, pushes the normalized result back (without dropout).

    ``ln_owned_type`` decides which layer norm the ``idx``-th post-processed
    pop (counted since the last ``clear``) uses:

    * ``"every_sublayer"``: ``module_list[idx]``.
    * ``"every_block"``: ``module_list[idx // sublayer_num_in_block]``.
    * ``"every_same_sublayer"``: ``module_list[idx % sublayer_num_in_block]``.
    * ``"every_side"``: ``module_list[0]``.

    Args:
        block_num: number of blocks fed through the stack.
        sublayer_num_in_block: number of sublayers per block.
        ln_dim: feature size of the layer norms; either a single int used for
            all of them, or a list/tuple with one entry per layer norm.
        ln_owned_type: layer-norm ownership scheme (see above).
        dropout: dropout probability applied by ``push``.

    Raises:
        ValueError: if ``ln_owned_type`` is unknown.
        TypeError: if ``ln_dim`` is neither an int nor a list/tuple.
        AssertionError: if a list/tuple ``ln_dim`` does not have one entry per
            layer norm.
    """

    def __init__(self,
                 block_num,
                 sublayer_num_in_block,
                 ln_dim,
                 ln_owned_type="every_sublayer",
                 dropout=0.0):
        super(LayerStackBase, self).__init__()
        # Plain Python list: it holds activations, not registered submodules.
        self.stack = []
        self.dropout = dropout
        ln_num = self._get_module_num(ln_owned_type, block_num, sublayer_num_in_block)
        if isinstance(ln_dim, int):
            self.layer_normals = nn.ModuleList(LayerNorm(ln_dim) for _ in range(ln_num))
        elif isinstance(ln_dim, (list, tuple)):
            assert len(ln_dim) == ln_num, "stack depth={}, but len(dim)={}".format(ln_num, len(ln_dim))
            self.layer_normals = nn.ModuleList(LayerNorm(d) for d in ln_dim)
        else:
            raise TypeError("ln_dim must be an int or a list of ints, got {!r}".format(ln_dim))
        # Number of post-processed pops since the last clear(); selects the
        # layer norm (and, in subclasses, other per-position modules).
        self.pop_index = 0
        self.ln_owned_type = ln_owned_type
        self.block_num = block_num
        self.sublayer_num_in_block = sublayer_num_in_block

    def push(self, layer, use_dropout=True):
        """Push ``layer`` onto the stack, applying dropout first if requested."""
        if use_dropout:
            layer = F.dropout(layer, self.dropout, self.training)
        self.stack.append(layer)

    def pop(self, reduce=True):
        """Combine the stack entries into one tensor; implemented by subclasses."""
        raise NotImplementedError

    def is_empty(self):
        """Return True when nothing has been pushed since the last ``clear``."""
        return not self.stack

    def get_top(self):
        """Return the most recently pushed entry without removing it."""
        assert not self.is_empty(), "stack is empty!"
        return self.stack[-1]

    def clear(self):
        """Drop all entries and restart the pop counter."""
        del self.stack
        self.stack = []
        self.pop_index = 0

    def _pop_post_process(self, x, reduce):
        """Layer-normalize a popped result, optionally push it back, advance the counter."""
        ln = self._pick(self.pop_index, self.ln_owned_type, self.layer_normals)
        x = ln(x)
        if reduce:
            # The merged result already went through dropout upstream.
            self.push(x, use_dropout=False)
        self.pop_index += 1
        return x

    def _pick(self, idx, owned_type, module_list):
        """Select the module owned by the ``idx``-th pop under ``owned_type``."""
        if owned_type == "every_sublayer":
            return module_list[idx]
        if owned_type == "every_block":
            # Integer division: ModuleList rejects float indices.
            return module_list[idx // self.sublayer_num_in_block]
        if owned_type == "every_same_sublayer":
            # Position of the sublayer inside its block; matches the
            # sublayer_num_in_block modules allocated by _get_module_num.
            return module_list[idx % self.sublayer_num_in_block]
        if owned_type == "every_side":
            return module_list[0]
        raise ValueError("unknown owned type = {}".format(owned_type))

    def _get_module_num(self, owned_type, block_num, sublayer_num_in_block):
        """Return how many modules ``owned_type`` needs.

        NOTE(review): "every_sublayer" allocates only ``sublayer_num_in_block``
        modules and "every_block" only one, so ``_pick`` indexes past the end if
        more than one block is popped without ``clear()``. The counts are kept as-is to
        preserve existing state_dict layouts; confirm the intended sizing.
        """
        if owned_type == "every_sublayer":
            return sublayer_num_in_block
        elif owned_type == "every_block":
            return 1
        elif owned_type == "every_same_sublayer":
            return sublayer_num_in_block
        elif owned_type == "every_side":
            return 1
        else:
            raise ValueError("unknown owned type = {}".format(owned_type))

class ResLayerStack(LayerStackBase):
    """
        Residual stack, equal to the original Transformer connection:
        pop() returns LN(x + y) for the two most recent entries, where y was
        already passed through dropout when it was pushed.
    """
    def __init__(self,
                 block_num,
                 sublayer_num_in_block,
                 ln_dim,
                 ln_owned_type="every_sublayer",
                 dropout=0.0):
        super(ResLayerStack, self).__init__(block_num,
                                            sublayer_num_in_block,
                                            ln_dim,
                                            ln_owned_type=ln_owned_type,
                                            dropout=dropout)

    def pop(self, reduce=True):
        """Return LN(previous + newest); a single entry is returned untouched.

        When ``reduce`` is true the normalized sum is pushed back onto the
        stack (without dropout) so it becomes the input of the next sublayer.
        """
        assert not self.is_empty(), "stack is empty!"
        if len(self.stack) == 1:
            return self.stack[0]
        # (L, B, H)
        cur_output = self.stack[-1]
        cur_input = self.stack[-2]

        # \hat_y = LN(x + dropout(y))
        x = cur_input + cur_output
        return self._pop_post_process(x, reduce)

class DenseLayerStack(LayerStackBase):
    """
        Dense connections: pop() concatenates the two most recent entries
        along the last (feature) axis instead of summing them.
    """
    def __init__(self,
                 block_num,
                 sublayer_num_in_block,
                 ln_dim,
                 ln_owned_type="every_sublayer",
                 dropout=0.0):
        # LayerStackBase.__init__ requires the block layout and ln_dim, so the
        # constructor takes the same arguments as the other stacks.
        super(DenseLayerStack, self).__init__(block_num,
                                              sublayer_num_in_block,
                                              ln_dim,
                                              ln_owned_type=ln_owned_type,
                                              dropout=dropout)

    def pop(self, reduce=True):
        """Return ``[x, y]`` concatenated on the last axis; a single entry is returned as-is.

        When ``reduce`` is true the concatenation is pushed back without
        re-applying dropout (``y`` was already dropped out when pushed).
        """
        assert not self.is_empty(), "stack is empty!"
        if len(self.stack) == 1:
            return self.stack[0]
        # (..., H): only the last axis is concatenated
        cur_output = self.stack[-1]
        cur_input = self.stack[-2]

        # \hat_y = [x, dropout(y)]
        ret = torch.cat([cur_input, cur_output], dim=-1)

        if reduce:
            self.push(ret, use_dropout=False)
        return ret


class GRULayerStack(LayerStackBase):
    """Stack that fuses the newest entry with the one below it through a GRUCell.

    The newest entry is the GRU input and the entry below it is the previous
    hidden state; the new hidden state is then layer-normalized by
    ``_pop_post_process``. GRU cells are shared according to ``gru_owned_type``,
    using the same ownership schemes as the layer norms.

    Args:
        input_size: feature size of the GRU input (the newest entry).
        hidden_size: feature size of the GRU hidden state (the entry below).
    """

    def __init__(self,
                 block_num,
                 sublayer_num_in_block,
                 ln_dim,
                 input_size,
                 hidden_size,
                 ln_owned_type="every_sublayer",
                 gru_owned_type="every_side",
                 dropout=0.0):
        super(GRULayerStack, self).__init__(block_num, sublayer_num_in_block, ln_dim,
                                            ln_owned_type=ln_owned_type,
                                            dropout=dropout)
        gru_num = self._get_module_num(gru_owned_type, block_num, sublayer_num_in_block)
        self.grus = nn.ModuleList(nn.GRUCell(input_size, hidden_size) for _ in range(gru_num))
        self.gru_owned_type = gru_owned_type

    def pop(self, reduce=True):
        """Return LN(GRU(newest, previous)); a single entry is returned untouched."""
        assert not self.is_empty(), "stack is empty!"
        if len(self.stack) == 1:
            return self.stack[0]
        # (L, B, H)
        x = self.stack[-1]
        prev_h = self.stack[-2]

        L, B, _ = x.size()
        gru = self._pick(self.pop_index, self.gru_owned_type, self.grus)
        # GRUCell works on 2-D tensors: (L, B, *) -> (L*B, *) and back. reshape
        # tolerates non-contiguous entries, and -1 keeps the output at
        # hidden_size even when it differs from input_size.
        h = gru(x.reshape(L * B, -1), prev_h.reshape(L * B, -1)).view(L, B, -1)
        return self._pop_post_process(h, reduce)


class LSTMLayerStack(LayerStackBase):
    """Stack that fuses the newest entry with the one below it through an LSTMCell.

    The newest entry is the LSTM input and the entry below it is the previous
    hidden state. The memory cell is carried between pops and reset by
    ``clear``. Unlike the residual stack, the very first pop (one entry on the
    stack) also goes through the LSTM, starting from zero hidden and cell
    states.

    Args:
        input_size: feature size of the LSTM input (the newest entry).
        hidden_size: feature size of the LSTM hidden and cell states.
        depth: number of pops after the first one; ``depth + 1`` layer norms
            are allocated.
        dim: feature size of the layer norms (int, or one entry per layer norm).
        dropout: dropout probability applied by ``push``.

    NOTE(review): entries below the top are used as the hidden state, so they
    are assumed to have ``hidden_size`` features — confirm with callers.
    """

    def __init__(self, input_size, hidden_size, depth, dim, dropout=0.0):
        # One "block" holding depth + 1 sublayers gives depth + 1 layer norms
        # (one per pop, the first pop included) under "every_sublayer".
        super(LSTMLayerStack, self).__init__(1, depth + 1, dim,
                                             ln_owned_type="every_sublayer",
                                             dropout=dropout)
        self.hidden_size = hidden_size
        self.lstm = nn.LSTMCell(input_size, hidden_size)
        # LSTM memory cell from the last reduced pop, shape (L*B, hidden_size).
        self.cell = None

    def pop(self, reduce=True):
        """Return LN(h) where ``h, c = LSTM(newest, (previous, cell))``."""
        assert not self.is_empty(), "stack is empty!"
        # (L, B, H)
        x = self.stack[-1]
        L, B, _ = x.size()
        x_flat = x.reshape(L * B, -1)
        if len(self.stack) == 1:
            prev_h = x_flat.new_zeros(L * B, self.hidden_size)
        else:
            prev_h = self.stack[-2].reshape(L * B, -1)
        prev_c = self.cell if self.cell is not None else prev_h.new_zeros(prev_h.size())

        h, c = self.lstm(x_flat, (prev_h, prev_c))
        if reduce:
            # The cell advances together with the hidden state pushed back.
            self.cell = c
        return self._pop_post_process(h.view(L, B, -1), reduce)

    def clear(self):
        """Drop all entries and the carried LSTM cell state."""
        super(LSTMLayerStack, self).clear()
        self.cell = None

class FnnLayer(nn.Module):
    """Two-layer feed-forward network: Linear -> ReLU -> dropout -> Linear.

    Args:
        input_size: feature size of the input.
        inner_hidden_size: feature size of the hidden activation.
        output_size: feature size of the output; falls back to ``input_size``
            when not given (or falsy).
        dropout: dropout probability applied to the hidden activation.
    """

    def __init__(self, input_size, inner_hidden_size, output_size=None, dropout=0.0):
        super(FnnLayer, self).__init__()
        self.output_size = output_size or input_size
        self.f1 = Linear(input_size, inner_hidden_size)
        self.f2 = Linear(inner_hidden_size, self.output_size)
        self.dropout = dropout

    def forward(self, x):
        """Apply the two projections with ReLU and dropout in between."""
        hidden = F.relu(self.f1(x))
        hidden = F.dropout(hidden, self.dropout, self.training)
        return self.f2(hidden)

class AANLayerStack(LayerStackBase):
    """Stack that gates the newest entry against an FFN of the mean of all entries.

    ``push`` keeps a running sum of every entry (optionally adding a learned
    embedding of the entry's position in the stack). ``pop`` computes
    ``g = FFN(sum / count)``, splits ``Linear([top, g])`` into ``i`` and ``f``,
    and returns ``LN(sigmoid(i) * top + sigmoid(f) * g)``.

    Entries are not stored in ``self.stack``, so emptiness and the top entry
    are tracked through ``count`` and ``top``.
    """

    def __init__(self, block_num, sublayer_num_in_block, ln_dim,
                 hidden_size, inner_hidden_size,
                 use_layer_embed=True,
                 ln_owned_type="every_sublayer",
                 dropout=0.0, relu_dropout=0.0):
        super(AANLayerStack, self).__init__(block_num, sublayer_num_in_block, ln_dim,
                                            ln_owned_type=ln_owned_type, dropout=dropout)
        self.sum = None
        self.top = None
        self.count = 0
        self.hidden_size = hidden_size
        self.ffn = FnnLayer(hidden_size, inner_hidden_size, dropout=relu_dropout)
        self.gate_linear = Linear(2*hidden_size, 2*hidden_size)
        self.layer_embed = None
        if use_layer_embed:
            # One embedding per push position, 0 .. block_num * sublayer_num_in_block.
            self.layer_embed = Embedding(block_num * sublayer_num_in_block + 1, hidden_size, padding_idx=None)

    def push(self, x, use_dropout=True):
        """Add ``x`` (plus its position embedding, then dropout) to the running sum."""
        if self.layer_embed is not None:
            # Same shape as x without the feature axis, every value == count.
            idx = x.new_zeros(x.size()[:-1]).fill_(self.count).long()
            x = x + self.layer_embed(idx)
        if use_dropout:
            x = F.dropout(x, self.dropout, self.training)
        if self.count == 0:
            self.sum = x
        else:
            self.sum = self.sum + x
        self.top = x
        self.count += 1

    def pop(self, reduce=False):
        """Return the gated combination of the top entry and FFN(mean of entries).

        A single entry is returned as-is. When ``reduce`` is true the result is
        pushed back (without dropout).
        """
        assert not self.is_empty(), "stack is empty!"
        if self.count == 1:
            return self.sum
        g = self.ffn(self.sum / self.count)
        gates = self.gate_linear(torch.cat([self.top, g], dim=-1))
        i, f = torch.split(gates, self.hidden_size, dim=-1)
        ret = torch.sigmoid(i) * self.top + torch.sigmoid(f) * g
        return self._pop_post_process(ret, reduce)

    def is_empty(self):
        """Return True when nothing has been pushed since the last ``clear``."""
        return self.count == 0

    def get_top(self):
        """Return the most recently pushed entry (after embedding/dropout)."""
        assert not self.is_empty(), "stack is empty!"
        return self.top

    def clear(self):
        """Reset the running sum, top entry and counters."""
        super(AANLayerStack, self).clear()
        self.top = None
        self.sum = None
        self.count = 0


class MHSALayerStack(LayerStackBase):
    """Stack where the newest entry attends over every entry pushed so far.

    Each pushed (T, B, H) tensor is folded to (1, T*B, H). The newest one is
    the attention query, and all pushed entries concatenated on dim 0 are the
    keys/values, so each (t, b) position attends over its own history of
    entries. The attention output is layer-normalized by
    ``_pop_post_process``. Entries are not stored in ``self.stack``, so
    emptiness is tracked through ``count``.
    """

    def __init__(self, block_num, sublayer_num_in_block, ln_dim,
                 hidden_size, inner_hidden_size, head_num,
                 use_layer_embed=True,
                 ln_owned_type="every_sublayer",
                 dropout=0.0, relu_dropout=0.0, att_dropout=0.0):

        super(MHSALayerStack, self).__init__(block_num, sublayer_num_in_block, ln_dim,
                                            ln_owned_type=ln_owned_type, dropout=dropout)
        self.query = None
        self.key = None
        self.count = 0
        self.hidden_size = hidden_size
        self.mhsa = MultiheadAttention(hidden_size, head_num, att_dropout)
        self.layer_embed = None
        if use_layer_embed:
            # One embedding per push position, 0 .. block_num * sublayer_num_in_block.
            self.layer_embed = Embedding(block_num * sublayer_num_in_block + 1, hidden_size, padding_idx=None)
        self.cache = {}
        # Original (T, B, H) size of the most recent push.
        self.shape = None

    # x=(T, B, H)
    def push(self, x, use_dropout=True):
        """Append ``x`` (plus its position embedding, then dropout) to the keys."""
        if self.layer_embed is not None:
            # Same shape as x without the feature axis, every value == count.
            idx = x.new_zeros(x.size()[:-1]).fill_(self.count).long()
            x = x + self.layer_embed(idx)
        if use_dropout:
            x = F.dropout(x, self.dropout, self.training)
        self.shape = x.size()
        T, B, H = self.shape

        # (T, B, H) -> (1, T*B, H); reshape tolerates non-contiguous input.
        x = x.reshape(1, T * B, H)
        self.query = x
        self.key = x if self.count == 0 else torch.cat([self.key, x], 0)
        self.count += 1

    def pop(self, reduce=False):
        """Return LN(attention of the newest entry over all entries), shaped (T, B, H).

        A single entry is returned as-is. When ``reduce`` is true the result is
        pushed back (without dropout).
        """
        assert not self.is_empty(), "stack is empty!"
        if self.count == 1:
            return self.query.view(self.shape)

        # NOTE(review): self.cache is passed as the 4th positional argument of
        # fairseq's MultiheadAttention.forward, which is not `incremental_state`
        # (in recent fairseq it is `key_padding_mask`), so the cache is likely
        # unused or even harmful here — confirm against the installed fairseq.
        # (1, T*B, H)
        ret, _ = self.mhsa(self.query, self.key, self.key, self.cache, need_weights=False)
        return self._pop_post_process(ret.view(self.shape), reduce)

    def is_empty(self):
        """Return True when nothing has been pushed since the last ``clear``."""
        return self.count == 0

    def get_top(self):
        """Return the most recently pushed entry in its (T, B, H) shape."""
        assert not self.is_empty(), "stack is empty!"
        return self.query.view(self.shape)

    def clear(self):
        """Reset the query, keys, counters and attention cache."""
        super(MHSALayerStack, self).clear()
        self.query = None
        self.key = None
        self.count = 0
        self.shape = None
        del self.cache
        self.cache = {}

class NoneLayerStack(LayerStackBase):
    """
        Stack without any cross-layer connection: pop() simply returns the
        most recently pushed entry, and no layer norms are kept.
    """
    def __init__(self, block_num, sublayer_num_in_block, ln_dim,
                 ln_owned_type="every_sublayer", dropout=0.0):
        super(NoneLayerStack, self).__init__(block_num, sublayer_num_in_block, ln_dim,
                                             ln_owned_type=ln_owned_type,
                                             dropout=dropout)
        # Discard the layer norms built by the base class; this stack never normalizes.
        self.layer_normals = None

    def pop(self, reduce=False):
        """Return the top entry unchanged; ``reduce`` is ignored."""
        assert not self.is_empty(), "stack is empty!"
        return self.stack[-1]


