#!/usr/bin/env python
# encoding: utf-8
"""
@author: Wang Qiang
@contact: wangqiangneu@gmail.com
@file: fast_new_to_fairseq.py
@time: 2018/11/15 14:16
@desc: Convert an online Transformer model and its vocabulary to a fairseq checkpoint and dictionaries
"""

import argparse
import math
import struct
import time
from collections import OrderedDict

import numpy
import numpy as np
import torch

# Width in bytes of every fixed-size string field in the binary model file.
STR_BINARY_LENGTH = 64

# Settings header markers: one byte between a key and its value, and the
# key that ends the header (tensor records follow it).
SPLIT_FLAG_BETWEEN_SETTING_KEY_AND_VALUE = '-'
SPLIT_FLAG_BETWEEN_SETTING_AND_WEIGHT = '|||||'

# NOTE(review): not referenced by any code visible in this file.
POSITIONAL_EMBEDDING_LENGTH = 1024


def parse_args():
    """Parse the command line; all four file paths are mandatory.

    :return: ``argparse.Namespace`` with ``model_input``, ``vocab_input``,
        ``model_output`` and ``vocab_output``
    """
    parser = argparse.ArgumentParser('fast converter from online Transformer model to fairseq')
    options = (
        ('--model-input', 'online model path'),
        ('--vocab-input', 'online vocab path'),
        ('--model-output', 'fairseq checkpoint path'),
        ('--vocab-output', 'fairseq dict path'),
    )
    for flag, help_text in options:
        parser.add_argument(flag, required=True, help=help_text)
    return parser.parse_args()


def read_str(data, offset, size=STR_BINARY_LENGTH):
    """Read a fixed-width string field from the model bytes.

    :param data: raw bytes of the model file
    :param offset: byte position where the field starts
    :param size: field width in bytes (``STR_BINARY_LENGTH`` by default)
    :return: (decoded string, offset just past the field)
    """
    end = offset + size
    (raw,) = struct.unpack('%ds' % size, data[offset:end])
    # NUL padding must be stripped, otherwise the text never compares equal
    # to a plain str; undecodable bytes are dropped instead of raising.
    text = raw.strip(b'\x00').decode(errors='ignore')
    return text, end


def read_int(data, offset, n=1):
    """Read ``n`` consecutive ints (struct format ``'i'``, 4 bytes each).

    :param data: raw bytes of the model file
    :param offset: byte position of the first int
    :param n: number of ints to read
    :return: (value, next_offset); value is a bare int when ``n == 1``,
        otherwise a tuple of ``n`` ints
    """
    width = 4 * n
    ints = struct.unpack('%di' % n, data[offset:offset + width])
    if n != 1:
        return ints, offset + width
    return ints[0], offset + width


def read_float(data, offset, n=1):
    """Read ``n`` consecutive floats (struct format ``'f'``, 4 bytes each).

    :param data: raw bytes of the model file
    :param offset: byte position of the first float
    :param n: number of floats to read
    :return: (value, next_offset); value is a bare float when ``n == 1``,
        otherwise a tuple of ``n`` floats
    """
    end = offset + 4 * n
    floats = struct.unpack('%df' % n, data[offset:end])
    value = floats[0] if n == 1 else floats
    return value, end


def load_setting(data, offset=0):
    """Parse the settings header at the start of an online model file.

    The header is a run of ``key '-' value`` records (key and value are
    ``STR_BINARY_LENGTH``-byte string fields, the separator is one byte),
    ended by a key equal to ``SPLIT_FLAG_BETWEEN_SETTING_AND_WEIGHT``.

    :param data: raw bytes of the model file
    :param offset: byte position where the header starts
    :return: (OrderedDict of setting name -> string value, offset of the
        first byte after the terminator)
    :raises ValueError: if the data ends before the terminator is found or a
        record lacks the key/value separator
    """
    setting = OrderedDict()
    while True:
        # Explicit checks instead of ``assert``: asserts disappear under
        # ``python -O`` and a missing terminator would otherwise surface as
        # an unhelpful struct.error.
        if offset + STR_BINARY_LENGTH > len(data):
            raise ValueError(
                'settings header ends at byte %d without the %r terminator'
                % (offset, SPLIT_FLAG_BETWEEN_SETTING_AND_WEIGHT))
        key, offset = read_str(data, offset)
        if key == SPLIT_FLAG_BETWEEN_SETTING_AND_WEIGHT:
            return setting, offset
        tag_offset = offset
        tag, offset = read_str(data, offset, 1)
        if tag != SPLIT_FLAG_BETWEEN_SETTING_KEY_AND_VALUE:
            raise ValueError(
                'expected %r after setting key %r at byte %d, got %r'
                % (SPLIT_FLAG_BETWEEN_SETTING_KEY_AND_VALUE, key, tag_offset, tag))
        value, offset = read_str(data, offset)
        setting[key] = value


def read_tensor(data, offset):
    """Read one tensor record: name, rank, dims, type string, then weights.

    :param data: raw bytes of the model file
    :param offset: byte position where the record starts
    :return: ((name, order, dims, data_type, weights), next_offset); ``dims``
        and ``weights`` are bare scalars when they hold a single element
        (see ``read_int`` / ``read_float``)
    """
    name, offset = read_str(data, offset)
    order, offset = read_int(data, offset)
    dims, offset = read_int(data, offset, n=order)
    data_type, offset = read_str(data, offset)

    # read_int yields a bare int for a rank-1 tensor, a tuple otherwise.
    shape = (dims,) if isinstance(dims, int) else dims
    element_count = 1
    for extent in shape:
        element_count *= extent

    weights, offset = read_float(data, offset, n=element_count)
    return (name, order, dims, data_type, weights), offset


def load_param(data, offset):
    """Read tensor records from ``offset`` through the end of ``data``.

    :param data: raw bytes of the model file
    :param offset: byte position of the first tensor record
    :return: OrderedDict mapping tensor name -> (order, dims, data_type,
        weights), in file order
    """
    params = OrderedDict()
    end = len(data)
    while offset < end:
        (name, order, dims, data_type, weights), offset = read_tensor(data, offset)
        params[name] = (order, dims, data_type, weights)
    return params

def convert_setting(settings):
    """Translate the online model's settings into fairseq checkpoint args.

    Architecture sizes come from ``settings`` (string values converted with
    ``int``); every other field is a fixed choice of this converter.

    :param settings: mapping with at least ``src_layer_num``,
        ``tgt_layer_num``, ``embedding_size``, ``inner_hidden_size`` and
        ``head_num``
    :return: ``argparse.Namespace`` stored as the checkpoint's ``args``
    :raises KeyError: if a required setting is missing
    """
    print(settings)
    # The settings carry a single embedding/FFN/head value, used for both
    # the encoder and the decoder.
    return argparse.Namespace(
        arch='transformer',
        encoder_layers=int(settings['src_layer_num']),
        decoder_layers=int(settings['tgt_layer_num']),
        encoder_embed_dim=int(settings['embedding_size']),
        decoder_embed_dim=int(settings['embedding_size']),
        encoder_ffn_embed_dim=int(settings['inner_hidden_size']),
        decoder_ffn_embed_dim=int(settings['inner_hidden_size']),
        encoder_attention_heads=int(settings['head_num']),
        decoder_attention_heads=int(settings['head_num']),
        # t2t and fairseq implement sinusoidal position embeddings
        # differently, so the t2t table is imported as a learned embedding.
        encoder_learned_pos=True,
        decoder_learned_pos=True,
        encoder_normal_type='layer',
        decoder_normal_type='layer',
        encoder_normalize_before=False,
        decoder_normalize_before=False,
        left_pad_source=True,
        left_pad_target=False,
        share_all_embeddings=False,
        share_decoder_input_output_embed=False,
        # dropout rates have no effect at inference time
        dropout=0.1,
        attention_dropout=0.0,
        relu_dropout=0.0,
    )

def get_tensor(name, params, transpose=True):
    """Build a float32 torch tensor for the parameter called ``name``.

    :param name: key into ``params``
    :param params: mapping name -> (order, dims, data_type, weights), as
        produced by ``load_param``; ``dims``/``weights`` may be bare scalars
    :param transpose: if true, transpose tensors with more than one
        dimension via ``.t()`` (NOTE(review): presumably the online format
        stores linear weights as (in, out) and fairseq expects (out, in) --
        confirm)
    :return: tensor viewed as ``dims`` (transposed if requested); a final
        (n, 1) matrix is squeezed to shape (n,). Also prints name and size.
    """
    _order, dims, _data_type, weights = params[name]
    if isinstance(dims, int):
        dims = [dims]
    # Build float32 directly. The weights were read as 4-byte floats, so this
    # is exact and avoids a float64 intermediate twice the tensor's size.
    tensor = torch.from_numpy(numpy.asarray(weights, dtype=numpy.float32))
    tensor = tensor.view(*dims)

    if transpose and tensor.dim() > 1:
        tensor = tensor.t()

    # Column vectors (e.g. a (1, n) bias after transposing) become 1-D.
    if tensor.dim() == 2 and tensor.size(1) == 1:
        tensor = tensor.squeeze(1)
    print('%s size=%s' % (name, str(tensor.size())))
    return tensor

def convert_param(settings, params):
    """Map the online model's tensors onto a fairseq Transformer state dict.

    Tensors are fetched through ``get_tensor`` (which prints each name and
    size) in encoder -> decoder -> output-projection order.

    :param settings: header settings; ``src_layer_num`` / ``tgt_layer_num``
        give the number of encoder / decoder layers
    :param params: tensors from ``load_param``, keyed by their online name
    :return: dict of fairseq parameter name -> torch tensor
    """
    state = {}

    def fetch(name, transpose=True):
        return get_tensor(name, params, transpose=transpose)

    def with_lua_row(matrix):
        # Online vocab rows start <PAD> <EOS> <UNK>; the target layout is
        # <Lua> <PAD> <EOS> <UNK>, so a zero row is prepended for <Lua>.
        return torch.cat([torch.zeros(1, matrix.size(1)), matrix], dim=0)

    def shift_positions(table):
        # fairseq position indices start at 2: prepend two zero rows and drop
        # the last two so the table keeps its length.
        return torch.cat([torch.zeros(2, table.size(1)), table[:-2, :]], dim=0)

    def copy_self_attn(dst, src):
        # fused q/k/v input projection followed by the output projection
        state[dst + '.in_proj_weight'] = fetch(src + '/qkv-trans-w')
        state[dst + '.in_proj_bias'] = fetch(src + '/qkv-trans-b')
        state[dst + '.out_proj.weight'] = fetch(src + '/out-trans-w')
        state[dst + '.out_proj.bias'] = fetch(src + '/out-trans-b')

    def copy_ffn(dst, src):
        for idx in (1, 2):
            state['%s.fc%d.weight' % (dst, idx)] = fetch('%s/weight%d' % (src, idx))
            state['%s.fc%d.bias' % (dst, idx)] = fetch('%s/bias%d' % (src, idx))

    def copy_layer_norms(dst, src_layer, sublayers):
        # layer_norms.<i> takes the norm of the i-th entry of ``sublayers``
        for slot, sublayer in enumerate(sublayers):
            state['%s.%d.weight' % (dst, slot)] = fetch('%s/%s/scale' % (src_layer, sublayer))
            state['%s.%d.bias' % (dst, slot)] = fetch('%s/%s/bias' % (src_layer, sublayer))

    # ---- encoder ----
    enc = 'transformer_encoder'
    word_emb = fetch(enc + '/word-emb', transpose=False)
    # t2t embeds as scale*word-emb + space-emb; folding space-emb/scale into
    # the word table reproduces that when the table is later scaled by
    # scale = sqrt(embed_dim).
    space_emb = fetch(enc + '/space-emb', transpose=False)
    word_emb = word_emb + space_emb / (word_emb.size(1) ** 0.5)
    state['encoder.embed_tokens.weight'] = with_lua_row(word_emb)
    state['encoder.embed_positions.weight'] = shift_positions(
        fetch(enc + '/positional-emb', transpose=False))

    for layer in range(int(settings["src_layer_num"])):
        src = '%s/%d' % (enc, layer + 1)  # online layer ids are 1-based
        dst = 'encoder.layers.%d' % layer
        copy_self_attn(dst + '.self_attn', src + '/sa')
        copy_ffn(dst, src + '/fnn')
        copy_layer_norms(dst + '.layer_norms', src, ('sa-ln', 'fnn-ln'))

    # ---- decoder ----
    dec = 'transformer_decoder'
    # NOTE(review): unlike the encoder, no space-emb is folded in here --
    # confirm the online decoder has none.
    state['decoder.embed_tokens.weight'] = with_lua_row(
        fetch(dec + '/word-emb', transpose=False))
    state['decoder.embed_positions.weight'] = shift_positions(
        fetch(dec + '/positional-emb', transpose=False))

    for layer in range(int(settings["tgt_layer_num"])):
        src = '%s/%d' % (dec, layer + 1)
        dst = 'decoder.layers.%d' % layer
        copy_self_attn(dst + '.self_attn', src + '/sa')

        # encoder-decoder attention: the separate q and fused kv projections
        # are stacked (q rows first) into one in_proj weight/bias
        eda = src + '/eda'
        q_w = fetch(eda + '/q-trans-w')
        q_b = fetch(eda + '/q-trans-b')
        kv_w = fetch(eda + '/kv-trans-w')
        kv_b = fetch(eda + '/kv-trans-b')
        attn = dst + '.encoder_attn'
        state[attn + '.in_proj_weight'] = torch.cat([q_w, kv_w], dim=0)
        state[attn + '.in_proj_bias'] = torch.cat([q_b, kv_b], dim=0)
        state[attn + '.out_proj.weight'] = fetch(eda + '/out-trans-w')
        state[attn + '.out_proj.bias'] = fetch(eda + '/out-trans-b')

        copy_ffn(dst, src + '/fnn')
        copy_layer_norms(dst + '.layer_norms', src, ('sa-ln', 'eda-ln', 'fnn-ln'))

    # ---- output projection (the <Lua> row gets all-zero weights) ----
    state['decoder.embed_out'] = with_lua_row(fetch('output-w'))
    return state

def convert_dict(vocab_input, vocab_output):
    """Split an online vocabulary file into fairseq source/target dicts.

    The input holds ``<index> <token>`` lines for the source vocabulary, a
    separator line starting with ``=====``, then the target vocabulary.
    Tokens listed in ``specials`` are skipped (NOTE(review): presumably
    because fairseq's Dictionary supplies its own special symbols). Every
    other token is written as ``<token> 1`` (a dummy count) to
    ``vocab_output + '.src'`` or ``vocab_output + '.tgt'``. Blank lines are
    ignored. All files are read/written as UTF-8.

    :param vocab_input: path of the online vocabulary file
    :param vocab_output: path prefix of the two output dictionary files
    :raises ValueError: on a second separator line, or a non-blank line that
        is not exactly ``<index> <token>``
    """
    specials = {'<PAD>', '<UNK>', '<EOS>', '<LUA>'}
    # Explicit UTF-8 so the result does not depend on the platform locale.
    with open(vocab_input, 'r', encoding='utf-8') as infile, \
            open(vocab_output + '.src', 'w', encoding='utf-8') as src_ofile, \
            open(vocab_output + '.tgt', 'w', encoding='utf-8') as tgt_ofile:
        is_src = True
        for lineno, line in enumerate(infile, 1):
            # strip only the newline: a bare strip() could remove characters
            # that belong to the token itself
            line = line.rstrip('\n')
            if not line:
                continue
            if line.startswith('====='):
                if not is_src:
                    raise ValueError('%s:%d: duplicate source/target separator'
                                     % (vocab_input, lineno))
                is_src = False
                continue

            # split on a single space only; the default split() would also
            # break on other whitespace characters that may be tokens
            items = line.split(' ')
            if len(items) != 2:
                raise ValueError('%s:%d: expected "<index> <token>", got %r'
                                 % (vocab_input, lineno, line))
            _, token = items
            if token in specials:
                continue
            print('%s %d' % (token, 1), file=src_ofile if is_src else tgt_ofile)


def main(args):
    """Convert the online model and vocabulary named on the command line.

    Parses the settings header and tensor records of ``args.model_input``,
    saves ``{'args': Namespace, 'model': state_dict}`` with ``torch.save``
    to ``args.model_output``, then writes the source/target dictionaries
    under the ``args.vocab_output`` prefix. Prints the settings and the
    elapsed time.

    :param args: namespace from ``parse_args``
    """
    start = time.time()

    # Read the whole model at once; the context manager closes the handle
    # immediately instead of leaving it to garbage collection.
    with open(args.model_input, 'rb') as model_file:
        data = model_file.read()

    settings, offset = load_setting(data, offset=0)
    for key, value in settings.items():
        print('{}={}'.format(key, value))

    params = load_param(data, offset)

    ck_state = {
        'args': convert_setting(settings),
        'model': convert_param(settings, params),
    }
    torch.save(ck_state, args.model_output)

    convert_dict(args.vocab_input, args.vocab_output)

    elapsed = time.time() - start
    print('use %.2f seconds' % elapsed)

if __name__ == '__main__':
    # Script entry point: parse the command line and run the conversion.
    main(parse_args())
