import argparse
import struct
import time
import numpy as np
import math
from collections import OrderedDict

# Fixed width, in bytes, of every string field packed by write_setting / write_param.
STR_BINARY_LENGTH = 64
# One-character marker written between a setting's key and its value.
SPLIT_FLAG_BETWEEN_SETTING_KEY_AND_VALUE = "-"
# Marker written after the last setting, i.e. right before the weights section.
SPLIT_FLAG_BETWEEN_SETTING_AND_WEIGHT = "|||||"
# Row count of the recomputed positional-embedding table; also stored as
# 'src_max_length' and 'tgt_max_length' in the settings.
POSITIONAL_EMBEDDING_LENGTH = 1024

# Position embedding; its shape is max-seq-length * emb_size.
# Note that this tensor does not need training: it is a fixed tensor, so we precompute it for speed.
# Returns L * H values.
def get_pos_emb(seq_len, emb_size, min_timescale=1.0, max_timescale=1.0e4):
    """Build the fixed sinusoidal position-embedding table as a flat list.

    For every position p and timescale index k the table holds
    sin(p * inv_timescales[k]) in the first half of the channels and
    cos(p * inv_timescales[k]) in the second half. When ``emb_size`` is odd,
    one trailing zero column is appended so each row has ``emb_size`` values.

    Args:
        seq_len: number of positions (rows, L).
        emb_size: number of channels per position (columns, H).
        min_timescale: first timescale of the geometric progression.
        max_timescale: last timescale of the geometric progression.

    Returns:
        A Python list of ``seq_len * emb_size`` floats in row-major order.
    """
    num_timescales = emb_size // 2
    # With a single timescale (emb_size of 2 or 3) "num_timescales - 1" is 0.
    # Clamp the divisor to 1 so the table can still be built instead of
    # raising ZeroDivisionError; for larger sizes the value is unchanged.
    log_timescale_increment = (
        math.log(float(max_timescale) / float(min_timescale)) /
        max(num_timescales - 1, 1))
    inv_timescales = min_timescale * np.exp(
        np.arange(num_timescales) * -log_timescale_increment)
    # Outer product of positions and inverse timescales: shape is L * (H // 2).
    scaled_time = (np.expand_dims(np.arange(seq_len), 1) *
                   np.expand_dims(inv_timescales, 0))
    # sin half followed by cos half: shape is L * (2 * (H // 2)).
    signal = np.concatenate([np.sin(scaled_time), np.cos(scaled_time)], axis=1)
    # Odd H: pad one zero column on the right so the final shape is L * H.
    signal = np.pad(signal, [[0, 0], [0, emb_size % 2]], 'constant')
    return signal.flatten().tolist()

def parse_args():
    """Parse the command line and return the argparse namespace.

    Both options are mandatory: ``-i/--input`` and ``-o/--output``.
    """
    cli = argparse.ArgumentParser('fast online Transformer model converter')
    option_table = (
        ('-i', '--input', 'old.txt online model path'),
        ('-o', '--output', 'new online model path'),
    )
    for short_flag, long_flag, help_text in option_table:
        cli.add_argument(short_flag, long_flag, required=True, help=help_text)
    return cli.parse_args()

def load_setting(data, offset=0):
    """Parse the model header and return ``(setting, offset)``.

    The header starts with 7 native ints (src_layer_num, tgt_layer_num,
    embedding_size, inner_hidden_size, head_num, src_vocab_size,
    tgt_vocab_size). Two optional sections may follow:

    * a 16-byte, NUL-padded activation name ('relu' or 'swish');
    * two native ints: use_rpr (0 or 1) and the max relative length.

    When an optional section does not look valid, the file is treated as an
    older model: defaults are kept for it and for everything after it, and
    the returned offset points at the first byte that was not consumed.

    Args:
        data: bytes of the whole model file.
        offset: byte position at which the header starts.

    Returns:
        (setting, offset): an OrderedDict of settings in write order, and the
        byte position right after the consumed header.

    Raises:
        struct.error: if ``data`` is too short for a section being read.
    """
    setting = OrderedDict()
    setting['src_max_length'] = POSITIONAL_EMBEDDING_LENGTH
    setting['tgt_max_length'] = POSITIONAL_EMBEDDING_LENGTH

    # Mandatory part: 7 integers, in this order.
    header_names = ('src_layer_num', 'tgt_layer_num', 'embedding_size',
                    'inner_hidden_size', 'head_num', 'src_vocab_size',
                    'tgt_vocab_size')
    size = 7 * 4
    header_values = struct.unpack('7i', data[offset:offset + size])
    offset += size
    for name, value in zip(header_names, header_values):
        setting[name] = value

    # Defaults for the optional sections (activation, use_rpr,
    # max_relative_length); overwritten below when present in the file.
    setting['activation'] = 'relu'
    setting['use_rpr'] = 0
    setting['max_relative_length'] = 16

    size = 16
    # Strip the NUL padding, otherwise the name never compares equal.
    raw_activation = struct.unpack('16s', data[offset:offset + size])[0].strip(b'\x00')
    # In an old model these bytes are weight data, so ignore decode errors.
    activation = raw_activation.decode(errors='ignore')
    if activation not in ('relu', 'swish'):
        # No activation section means this is an old online model.
        print('[WARN] not found activation ...')
        # Use %-formatting; passing the value as a second print() argument
        # would emit the literal "%s" instead of substituting it.
        print('existing activation=%s' % activation)
        return setting, offset
    setting['activation'] = activation
    offset += size

    size = 2 * 4
    use_rpr, max_relative_position = struct.unpack('2i', data[offset:offset + size])
    if use_rpr not in (0, 1):
        print('[WARN] not found rpr ...')
        print('existing use_rpr=%d max_relative_position=%d' % (use_rpr, max_relative_position))
        return setting, offset
    setting['use_rpr'] = use_rpr
    setting['max_relative_length'] = max_relative_position
    offset += size

    return setting, offset

def load_weight_(data, offset, shape):
    """Read one float tensor of the given shape from ``data``.

    Args:
        data: bytes of the whole model file.
        offset: byte position of the tensor's first float.
        shape: sequence of dimensions; the element count is their product.

    Returns:
        ((shape, values), new_offset): ``values`` is a flat tuple of floats
        and ``new_offset`` is the byte position right after the tensor.

    Raises:
        AssertionError: if the tensor would extend past the end of ``data``.
    """
    count = 1
    for dim in shape:
        count *= dim
    size = count * 4
    # Raise explicitly instead of using an ``assert`` statement so the bound
    # check still runs under ``python -O``; the exception type is unchanged.
    if offset + size > len(data):
        raise AssertionError("beyond model bound")
    # unpack_from reads in place, so large tensors are not copied into a
    # temporary slice first.
    values = struct.unpack_from('%df' % count, data, offset)
    return (shape, values), offset + size

def load_sa_param(prefix, params, data, offset, hidden_size, use_rpr, max_relative_length):
    """Read the 'sa' tensors for one layer into ``params``.

    Tensors are read consecutively from ``data`` starting at ``offset`` and
    stored under ``prefix + <suffix>``. The two 'rp-*' tensors are only read
    when ``use_rpr`` is truthy.

    Returns:
        (params, offset): the same dict, and the position after the last
        tensor that was read.
    """
    layout = [
        ('qkv-trans-w', [hidden_size, 3 * hidden_size]),
        ('qkv-trans-b', [1, 3 * hidden_size]),
        ('out-trans-w', [hidden_size, hidden_size]),
        ('out-trans-b', [1, hidden_size]),
    ]
    if use_rpr:
        # Each relative-position table has 2 * max_relative_length + 1 rows.
        table_rows = 2 * max_relative_length + 1
        layout.append(('rp-key', [table_rows, hidden_size]))
        layout.append(('rp-value', [table_rows, hidden_size]))

    for suffix, dims in layout:
        params[prefix + suffix], offset = load_weight_(data, offset, dims)
    return params, offset

def load_ln_param(prefix, params, data, offset, hidden_size):
    """Read the 'scale' and 'bias' tensors, each shaped [1, hidden_size].

    Returns:
        (params, offset): the same dict, and the position after both tensors.
    """
    for suffix in ('scale', 'bias'):
        params[prefix + suffix], offset = load_weight_(data, offset, [1, hidden_size])
    return params, offset

def load_fnn_param(prefix, params, data, offset, hidden_size, inner_hidden_size):
    """Read the four 'fnn' tensors (weight1, bias1, weight2, bias2).

    Tensors are read consecutively from ``data`` starting at ``offset`` and
    stored under ``prefix + <suffix>``.

    Returns:
        (params, offset): the same dict, and the position after the last
        tensor that was read.
    """
    layout = (
        ('weight1', [hidden_size, inner_hidden_size]),
        ('bias1', [1, inner_hidden_size]),
        ('weight2', [inner_hidden_size, hidden_size]),
        ('bias2', [1, hidden_size]),
    )
    for suffix, dims in layout:
        params[prefix + suffix], offset = load_weight_(data, offset, dims)
    return params, offset

def load_eda_param(prefix, params, data, offset, hidden_size):
    """Read the six 'eda' tensors for one decoder layer into ``params``.

    Tensors are read consecutively from ``data`` starting at ``offset`` and
    stored under ``prefix + <suffix>``.

    Returns:
        (params, offset): the same dict, and the position after the last
        tensor that was read.
    """
    layout = (
        ('q-trans-w', [hidden_size, hidden_size]),
        ('q-trans-b', [1, hidden_size]),
        ('kv-trans-w', [hidden_size, 2 * hidden_size]),
        ('kv-trans-b', [1, 2 * hidden_size]),
        ('out-trans-w', [hidden_size, hidden_size]),
        ('out-trans-b', [1, hidden_size]),
    )
    for suffix, dims in layout:
        params[prefix + suffix], offset = load_weight_(data, offset, dims)
    return params, offset

def load_param(data, offset, setting):
    """Read every weight tensor of the old-format model.

    Layout consumed from ``data`` starting at ``offset``:
    encoder word/space embeddings, the stored encoder positional table
    (skipped), encoder layers, decoder word embedding, the stored decoder
    positional table (skipped), decoder layers, and the output projection.

    The stored 200-row positional tables are not read; a table with
    POSITIONAL_EMBEDDING_LENGTH rows is recomputed with get_pos_emb instead.

    Args:
        data: bytes of the whole model file.
        offset: byte position of the first weight (right after the header).
        setting: mapping produced by load_setting.

    Returns:
        OrderedDict mapping parameter name to ``(shape, values)``, in the
        order the parameters are to be written.
    """
    params = OrderedDict()
    emb_size = setting['embedding_size']
    hidden_size = emb_size
    inner_hidden_size = setting['inner_hidden_size']
    use_rpr = setting['use_rpr']
    max_relative_length = setting['max_relative_length']
    src_vocab_size = setting['src_vocab_size']
    src_layer_num = setting['src_layer_num']
    tgt_vocab_size = setting['tgt_vocab_size']
    tgt_layer_num = setting['tgt_layer_num']

    # The old file stores a 200-row table of 4-byte floats per side; it is
    # skipped because a longer table is recomputed below.
    stored_pos_emb_rows = 200
    stored_pos_emb_bytes = stored_pos_emb_rows * emb_size * 4

    # src word embedding & target space embedding & position embedding
    prefix = 'transformer_encoder/'
    params[prefix + 'word-emb'], offset = load_weight_(data, offset, [src_vocab_size, emb_size])
    params[prefix + 'space-emb'], offset = load_weight_(data, offset, [1, emb_size])
    # The table depends only on (length, emb_size), so compute it once and
    # share the same (shape, values) entry between encoder and decoder.
    pos_emb = ((POSITIONAL_EMBEDDING_LENGTH, emb_size),
               get_pos_emb(POSITIONAL_EMBEDDING_LENGTH, emb_size))
    params[prefix + 'positional-emb'] = pos_emb
    offset += stored_pos_emb_bytes

    # encoder layers (layer ids in the parameter names are 1-based)
    for layer_no in range(1, src_layer_num + 1):
        base = 'transformer_encoder/%d/' % layer_no
        params, offset = load_sa_param(base + 'sa/', params, data, offset,
                                       hidden_size, use_rpr, max_relative_length)
        params, offset = load_ln_param(base + 'sa-ln/', params, data, offset, hidden_size)
        params, offset = load_fnn_param(base + 'fnn/', params, data, offset,
                                        hidden_size, inner_hidden_size)
        params, offset = load_ln_param(base + 'fnn-ln/', params, data, offset, hidden_size)

    # tgt embedding
    prefix = 'transformer_decoder/'
    params[prefix + 'word-emb'], offset = load_weight_(data, offset, [tgt_vocab_size, emb_size])
    params[prefix + 'positional-emb'] = pos_emb
    offset += stored_pos_emb_bytes

    # decoder layers
    for layer_no in range(1, tgt_layer_num + 1):
        base = 'transformer_decoder/%d/' % layer_no
        params, offset = load_sa_param(base + 'sa/', params, data, offset,
                                       hidden_size, use_rpr, max_relative_length)
        params, offset = load_ln_param(base + 'sa-ln/', params, data, offset, hidden_size)
        params, offset = load_eda_param(base + 'eda/', params, data, offset, hidden_size)
        params, offset = load_ln_param(base + 'eda-ln/', params, data, offset, hidden_size)
        params, offset = load_fnn_param(base + 'fnn/', params, data, offset,
                                        hidden_size, inner_hidden_size)
        params, offset = load_ln_param(base + 'fnn-ln/', params, data, offset, hidden_size)

    params['output-w'], offset = load_weight_(data, offset, [hidden_size, tgt_vocab_size])
    return params

def write_setting(f, settings):
    """Write every setting, then the settings/weights marker, to ``f``.

    Each setting is emitted as: key (fixed-width string), the one-byte
    key/value split flag, value converted with ``str`` (fixed-width string).
    All fixed-width strings occupy STR_BINARY_LENGTH bytes.
    """
    fixed_fmt = '%ds' % STR_BINARY_LENGTH

    def fixed_field(text):
        # ASCII-encode and pack into one fixed-width field.
        return struct.pack(fixed_fmt, text.encode('ascii'))

    for key, value in settings.items():
        f.write(fixed_field(key))
        # Split flag between key and value; lets a reader detect an illegal model.
        f.write(struct.pack('1s', SPLIT_FLAG_BETWEEN_SETTING_KEY_AND_VALUE.encode('ascii')))
        # Every value is written as a string.
        f.write(fixed_field(str(value)))
    f.write(fixed_field(SPLIT_FLAG_BETWEEN_SETTING_AND_WEIGHT))


def write_param(f, params):
    """Write every parameter in ``params`` to ``f`` and log progress.

    Per parameter the record is: name (fixed-width string), number of
    dimensions (int), each dimension (int), the data-type tag 'X_FLOAT'
    (fixed-width string), then all values packed as floats.
    """
    fixed_fmt = '%ds' % STR_BINARY_LENGTH
    preview_pending = True
    for name, entry in params.items():
        dims = entry[0]
        values = entry[1]
        # param name
        f.write(struct.pack(fixed_fmt, name.encode('ascii')))
        # order (number of dimensions)
        f.write(struct.pack('i', len(dims)))
        # dimensions
        for dim in dims:
            f.write(struct.pack('i', dim))
        # data type, default is float
        f.write(struct.pack(fixed_fmt, 'X_FLOAT'.encode('ascii')))
        # Pack all floats in one call; faster than writing them one by one.
        f.write(struct.pack('%df' % len(values), *values))
        # Dump the first 10 floats of the first tensor as a sanity check.
        if preview_pending:
            preview_pending = False
            for idx, val in zip(range(10), values):
                print('[%d] %.4f' % (idx, val))
        print('write %s over ...' % name)

def main(args):
    """Convert the model at ``args.input`` and write it to ``args.output``.

    Reads the whole input file, parses the settings and weights, prints the
    settings, writes the new-format file and reports the elapsed time.

    Args:
        args: namespace with ``input`` and ``output`` path attributes.
    """
    start = time.time()

    # Use a context manager so the input handle is closed right after the
    # read instead of being left for the garbage collector.
    with open(args.input, 'rb') as fin:
        data = fin.read()

    settings, offset = load_setting(data, offset=0)
    for k, v in settings.items():
        print('{}={}'.format(k, v))
    params = load_param(data, offset, settings)

    with open(args.output, 'wb') as fout:
        write_setting(fout, settings)
        write_param(fout, params)

    elapsed = time.time() - start
    print('use %.2f seconds' % elapsed)

# Script entry point: parse the command line, then run the conversion.
if __name__ == '__main__':
    args = parse_args()
    main(args)