from __future__ import print_function
# a common method
from io import open as open
from tensorflow.python.tools.inspect_checkpoint import print_tensors_in_checkpoint_file
from tensorflow.python import pywrap_tensorflow
import re
import tensorflow as tf
import math
from collections import OrderedDict
import argparse
import struct
import time
import numpy as np
import math
import torch
import os


def find_useful_param(model_file):
  """Inspect a TensorFlow checkpoint and infer the model settings from its variable names.

  Every variable whose name starts with the module-level ``model_name`` is
  recorded, and its name is matched against a set of patterns to fill in the
  module-level settings: source/target vocabulary size, embedding size,
  encoder/decoder layer count, FFN activation name, relative-position usage
  (together with ``max_relative_length``), ``normalize_before``, ``use_dense``
  and ``norm_type``. A one-line summary of the detected values is printed.

  Args:
    model_file: checkpoint path/prefix handed to ``NewCheckpointReader``.

  Returns:
    A ``(reader, trainable_param)`` tuple: the checkpoint reader and a dict
    mapping each matching variable name to its shape.

  Raises:
    AssertionError: if no variable starts with ``model_name`` or if no
      encoder/decoder layer variable was found.
  """
  global src_vocab_size
  global tgt_vocab_size
  global emb_size
  global src_layer_num
  global tgt_layer_num
  global activation_function
  global use_relative_position_representation
  global max_relative_length
  global normalize_before
  global use_dense
  global norm_type

  # re.escape keeps any regex metacharacter in the model name literal.
  prefix = re.escape(model_name)
  # Raw strings avoid the invalid "\d" escape sequence, and the patterns are
  # compiled once instead of once per checkpoint variable.
  src_emb_re = re.compile(prefix + r'/symbol_modality_(\d+)_(\d+)/input_emb')
  tgt_emb_re = re.compile(prefix + r'/symbol_modality_(\d+)_(\d+)/target_emb')
  enc_layer_re = re.compile(prefix + r'/body/encoder/layer_(\d+)/')
  dec_layer_re = re.compile(prefix + r'/body/decoder/layer_(\d+)/')
  activation_re = re.compile(r'/conv_hidden_([^/]+)/')

  trainable_param = dict()
  try:
    reader = pywrap_tensorflow.NewCheckpointReader(model_file)
    var_to_shape_map = reader.get_variable_to_shape_map()
    for key in sorted(var_to_shape_map):
      # only variables under the model scope are real trainable parameters
      if not key.startswith(model_name):
        continue
      trainable_param[key] = var_to_shape_map[key]

      # capture source vocab size and emb_size
      match = src_emb_re.match(key)
      if match and (src_vocab_size is None and emb_size is None):
        src_vocab_size = int(match.group(1))
        emb_size = int(match.group(2))

      # capture target vocab size
      match = tgt_emb_re.match(key)
      if match and tgt_vocab_size is None:
        tgt_vocab_size = int(match.group(1))

      # capture the highest encoder layer index
      match = enc_layer_re.match(key)
      if match:
        layer_index = int(match.group(1))
        if src_layer_num is None or layer_index > src_layer_num:
          src_layer_num = layer_index

      # capture the highest decoder layer index
      match = dec_layer_re.match(key)
      if match:
        layer_index = int(match.group(1))
        if tgt_layer_num is None or layer_index > tgt_layer_num:
          tgt_layer_num = layer_index

      # the activation name is embedded in the FFN scope name
      match = activation_re.search(key)
      if match and activation_function is None:
        activation_function = match.group(1)

      # the first relative-attention variable fixes the relative window size
      if ('dot_product_attention_relative' in key
          and use_relative_position_representation is None):
        use_relative_position_representation = 1
        max_relative_length = (var_to_shape_map[key][0] - 1) // 2

      if '/layer_prepostprocess/' in key and normalize_before is False:
        normalize_before = True

      if '/layer_history/layer_weight' in key and use_dense is False:
        use_dense = True

  except Exception as e:  # pylint: disable=broad-except
    print(str(e))
  assert len(trainable_param) > 0, "not found any trainable parameters"
  # Fail with a readable message instead of a TypeError on "None += 1".
  assert src_layer_num is not None, "not found any encoder layer"
  assert tgt_layer_num is not None, "not found any decoder layer"

  # add 1 due to index from 0
  src_layer_num += 1
  tgt_layer_num += 1
  if activation_function is None:
    activation_function = 'relu'
  if use_relative_position_representation is None:
    use_relative_position_representation = False
  # the layer-norm scope name differs between pre-norm and post-norm models
  norm_type = "layer_prepostprocess" if normalize_before else "layer_postprocess"

  print('find src-vocab:{} tgt-vocab:{} emb-size:{} src-layer:{} tgt-layer:{} activation:{} relative:{} normalize_before:{} use_dense:{}'.
        format(src_vocab_size, tgt_vocab_size, emb_size, src_layer_num, tgt_layer_num,
               activation_function, use_relative_position_representation, normalize_before, use_dense))

  return reader, trainable_param


def load_param():
  """Read every tensor needed for the conversion from the checkpoint.

  Relies on module-level state set up beforehand: ``reader``, ``model_name``,
  ``shard``, ``target_space_id`` and the settings filled in by
  ``find_useful_param`` (vocabulary sizes, ``emb_size``, layer counts,
  ``activation_function``, ``norm_type``, ``normalize_before``, ``use_dense``).

  Returns:
    An ``OrderedDict`` mapping short parameter names (``src_emb``,
    ``enc_<i>_qvk_trans_w``, ``dec_<i>_fnn_1_b``, ``softmax_w``, ...) to
    either the array returned by the checkpoint reader or a ``tf`` tensor
    derived from it.

  Raises:
    ValueError: if a shard count below 1 is requested.
  """
  print('loadding params ...')
  param_dict = OrderedDict()

  def _get_tensor(name, shard=1):
    """Fetch `name`; for shard > 1 concatenate `name_0` .. `name_<shard-1>` along axis 0."""
    if shard < 1:
      # Previously this silently produced None and failed much later.
      raise ValueError('shard must be >= 1, got {}'.format(shard))
    if shard == 1:
      return reader.get_tensor(name)
    pieces = [reader.get_tensor('{}_{}'.format(name, i)) for i in range(shard)]
    return tf.concat(pieces, axis=0)

  def _as_row(tensor):
    """Reshape a tensor to a single row, i.e. shape [1, -1]."""
    return tf.reshape(tensor, [1, -1])

  def _load_projection(key_prefix, scope):
    """Store `<scope>/kernel` (axes 0 and 1 squeezed) and `<scope>/bias` (as a row) as `<key_prefix>_w` / `_b`."""
    kernel = tf.squeeze(_get_tensor(scope + '/kernel'), axis=[0, 1])
    bias = _get_tensor(scope + '/bias')
    param_dict[key_prefix + '_w'] = kernel
    param_dict[key_prefix + '_b'] = _as_row(bias)

  def _load_layer_norm(scope, scale_key, bias_key):
    """Store `<scope>/layer_norm_scale` and `<scope>/layer_norm_bias` as rows, scale key first."""
    bias = _get_tensor(scope + '/layer_norm_bias')
    scale = _get_tensor(scope + '/layer_norm_scale')
    param_dict[scale_key] = _as_row(scale)
    param_dict[bias_key] = _as_row(bias)

  def _load_top_layer_norm(side, key_prefix):
    """Store the final layer norm of the encoder/decoder stack, bias key first."""
    scope = model_name + '/body/%s/layer_prepostprocess/layer_norm' % side
    bias = _get_tensor(scope + '/layer_norm_bias')
    scale = _get_tensor(scope + '/layer_norm_scale')
    param_dict[key_prefix + '_ln_bias_top'] = _as_row(bias)
    param_dict[key_prefix + '_ln_scale_top'] = _as_row(scale)

  # source embedding, shape is src_vocab * emb_size
  src_emb_tensor = _get_tensor(model_name + '/symbol_modality_{}_{}/input_emb/weights'.format(src_vocab_size, emb_size), shard=shard)
  param_dict['src_emb'] = src_emb_tensor

  # space embedding, we only use target_space_id, so the shape is H
  space_emb_tensor = _get_tensor(model_name + '/body/target_space_embedding/kernel')
  param_dict['space_emb'] = tf.reshape(tf.gather(space_emb_tensor, tf.constant(target_space_id)), [1, -1])

  # position embedding, shape is max-seq-length * emb_size
  # note that this tensor needs no training; it is a fixed tensor that we precompute for speed
  # return L * H
  def get_pos_emb(seq_len=src_max_length, emb_size=512, min_timescale=1.0, max_timescale=1.0e4):
    length = seq_len
    channels = emb_size
    position = tf.to_float(tf.range(length))
    num_timescales = channels // 2
    log_timescale_increment = (
        math.log(float(max_timescale) / float(min_timescale)) /
        (tf.to_float(num_timescales) - 1))
    inv_timescales = min_timescale * tf.exp(
      tf.to_float(tf.range(num_timescales)) * -log_timescale_increment)
    # shape is L * H/2
    scaled_time = tf.expand_dims(position, 1) * tf.expand_dims(inv_timescales, 0)
    # shape is L * H (if H is odd, the signal is only H-1 wide at this point)
    signal = tf.concat([tf.sin(scaled_time), tf.cos(scaled_time)], axis=1)
    # if H is odd, pad one column at the end so the shape becomes L * H
    signal = tf.pad(signal, [[0, 0], [0, tf.mod(channels, 2)]])
    return signal

  # note, pos_emb is the same on the source and target side.
  # However, we write it for both sides, which is convenient for
  # initialising the encoder and the decoder independently.
  param_dict['src_pos_emb'] = get_pos_emb(src_max_length, emb_size=emb_size)

  # encoder
  for layer_id in range(src_layer_num):
    scope = model_name + '/body/encoder/layer_%d' % layer_id
    key = 'enc_%d' % layer_id

    # step 1/2. self-attention: fused qkv projection (H * 3H) and output projection (H * H)
    attention = scope + '/self_attention/multihead_attention'
    _load_projection(key + '_qvk_trans', attention + '/qkv_transform_single')
    _load_projection(key + '_out_trans', attention + '/output_transform_single')

    # step 3. layer normalization around self-attention
    _load_layer_norm(scope + '/self_attention/' + norm_type + '/layer_norm',
                     key + '_ln_scale_1', key + '_ln_bias_1')

    # step 4/5. feed-forward layers: H * X then X * H
    ffn = scope + '/ffn/conv_hidden_%s' % activation_function
    _load_projection(key + '_fnn_1', ffn + '/conv1_single')
    _load_projection(key + '_fnn_2', ffn + '/conv2_single')

    # step 6. layer normalization around the feed-forward block
    _load_layer_norm(scope + '/ffn/' + norm_type + '/layer_norm',
                     key + '_ln_scale_2', key + '_ln_bias_2')

    if use_dense:
      # step 9. layer normalization of the layer-history (dense) connection
      _load_layer_norm(scope + '/layer_history/layer_norm%d' % layer_id,
                       key + '_ln_scale_dense', key + '_ln_bias_dense')

  # step 7. layer normalization top
  if normalize_before:
    _load_top_layer_norm('encoder', 'enc')

  # step 8. dense transformer weight matrix
  if use_dense:
    param_dict['enc_layer_weight'] = _get_tensor(model_name + '/body/encoder/layer_history/layer_weight')

  # decoder
  # target embedding, shape is tgt_vocab * emb_size
  tgt_emb_tensor = _get_tensor(model_name + '/symbol_modality_{}_{}/target_emb/weights'.format(tgt_vocab_size, emb_size),
                               shard=shard)
  param_dict['tgt_emb'] = tgt_emb_tensor
  param_dict['tgt_pos_emb'] = get_pos_emb(tgt_max_length, emb_size=emb_size)

  for layer_id in range(tgt_layer_num):
    scope = model_name + '/body/decoder/layer_%d' % layer_id
    key = 'dec_%d' % layer_id

    # step 1. decoder self-attention: fused qkv projection and output projection
    attention = scope + '/self_attention/multihead_attention'
    _load_projection(key + '_decatt_qvk_trans', attention + '/qkv_transform_single')
    _load_projection(key + '_decatt_out_trans', attention + '/output_transform_single')
    _load_layer_norm(scope + '/self_attention/' + norm_type + '/layer_norm',
                     key + '_decatt_ln_scale_1', key + '_decatt_ln_bias_1')

    # step 2. encoder-decoder attention: q (H * H), fused kv (H * 2H), output (H * H)
    attention = scope + '/encdec_attention/multihead_attention'
    _load_projection(key + '_encdecatt_q_trans', attention + '/q_transform_single')
    _load_projection(key + '_encdecatt_kv_trans', attention + '/kv_transform_single')
    _load_projection(key + '_encdecatt_out_trans', attention + '/output_transform_single')
    _load_layer_norm(scope + '/encdec_attention/' + norm_type + '/layer_norm',
                     key + '_encdecatt_ln_scale_1', key + '_encdecatt_ln_bias_1')

    # step 3. feed-forward layers: H * X then X * H
    ffn = scope + '/ffn/conv_hidden_%s' % activation_function
    _load_projection(key + '_fnn_1', ffn + '/conv1_single')
    _load_projection(key + '_fnn_2', ffn + '/conv2_single')
    _load_layer_norm(scope + '/ffn/' + norm_type + '/layer_norm',
                     key + '_encdecatt_ln_scale_2', key + '_encdecatt_ln_bias_2')

    if use_dense:
      # step 9. layer normalization of the layer-history (dense) connection
      _load_layer_norm(scope + '/layer_history/layer_norm%d' % layer_id,
                       key + '_ln_scale_dense', key + '_ln_bias_dense')

  # step 7. layer normalization top
  if normalize_before:
    _load_top_layer_norm('decoder', 'dec')

  # step 8. dense transformer weight matrix
  if use_dense:
    param_dict['dec_layer_weight'] = _get_tensor(model_name + '/body/decoder/layer_history/layer_weight')

  softmax_w_tensor = _get_tensor(model_name + '/symbol_modality_{}_{}/softmax/weights'.format(tgt_vocab_size, emb_size),
                                 shard=shard)
  # note: the softmax weight is stored exactly as read; no transpose is applied here
  param_dict['softmax_w'] = softmax_w_tensor
  return param_dict

def _get_param_numpy(name, sess, transpose=False):
  """Return the entry `name` of the module-level ``param_dict`` as a float torch tensor.

  Args:
    name: key in ``param_dict``.
    sess: session used to evaluate the entry when it is a ``tf.Tensor``.
    transpose: when true and the value has more than one dimension, return
      its transpose.

  Returns:
    A float torch tensor; a 2-D result with a single column is squeezed to 1-D.
  """
  assert name in param_dict, "unknown param nam:{} in dict".format(name)
  raw = param_dict[name]

  # graph tensors must be evaluated first; other entries are used as they are
  array = sess.run(raw) if type(raw) is tf.Tensor else raw
  result = torch.from_numpy(array).float()

  if transpose and result.dim() > 1:
    result = result.t()

  # collapse column vectors (e.g. a transposed [1, H] bias) to plain 1-D
  is_column = result.dim() == 2 and result.size(1) == 1
  return result.squeeze(1) if is_column else result


def write_vocab():
  """Write the source and target vocabularies into a single `online.vocab` file.

  Each output line is ``<index> <token>``. The first two lines of every
  input vocabulary are replaced by ``<PAD>`` and ``<EOS>``; for all other
  lines the first and last character of the stripped line are dropped.
  The two vocabularies are separated by a line of ``=`` characters.

  Reads the module-level ``src_vocab_name`` and ``tgt_vocab_name`` paths.

  NOTE: this file defines ``write_vocab`` a second time further down; that
  later definition is the one in effect when the script calls it.
  """
  print('write vocab file ...')

  def _write_to_file(vocab, online_vocab):
    with open(vocab, 'r', encoding='utf-8') as f:
      for index, line in enumerate(f):
        if index == 0:
          token = u'<PAD>'
        elif index == 1:
          token = u'<EOS>'
        else:
          # drop the first and last character of the stripped entry
          token = line.strip()[1:-1]
        print(u'{} {}'.format(index, token), file=online_vocab)

  # "with" guarantees the output file is closed even if reading a vocab fails.
  # The output keeps the default encoding, which is what convert_dict uses
  # when it reads this file back.
  with open('online.vocab', 'w') as vocab_file:
    _write_to_file(src_vocab_name, vocab_file)
    print(u"================================", file=vocab_file)
    _write_to_file(tgt_vocab_name, vocab_file)

def convert_settings(settings):
  """Build the ``argparse.Namespace`` stored under the checkpoint's ``args`` key.

  Args:
    settings: mapping with the keys ``src_layer_num``, ``tgt_layer_num``,
      ``embedding_size``, ``inner_hidden_size`` and ``head_num``; each value
      must be convertible with ``int()``.

  Returns:
    A namespace of model options. The module-level ``use_dense`` switches the
    architecture name and adds the history options; the module-level
    ``normalize_before`` switches the normalize-before flags and the
    attention/relu dropout values.
  """
  num_enc_layers = int(settings['src_layer_num'])
  num_dec_layers = int(settings['tgt_layer_num'])
  embed_dim = int(settings['embedding_size'])
  ffn_dim = int(settings['inner_hidden_size'])
  num_heads = int(settings['head_num'])

  dense = bool(use_dense)
  # the field order is kept because it is the attribute order of the namespace
  fields = [
    ('arch', 'dense_transformer' if dense else 'transformer'),
    ('encoder_layers', num_enc_layers),
    ('decoder_layers', num_dec_layers),
    ('encoder_embed_dim', embed_dim),
    ('decoder_embed_dim', embed_dim),
    ('encoder_ffn_embed_dim', ffn_dim),
    ('decoder_ffn_embed_dim', ffn_dim),
    ('encoder_attention_heads', num_heads),
    ('decoder_attention_heads', num_heads),
  ]
  if dense:
    fields += [
      ('encoder_history_type', 'learnable_dense'),
      ('decoder_history_type', 'learnable_dense'),
      ('encoder_integration_type', 'avg'),
      ('decoder_integration_type', 'avg'),
    ]

  pre_norm = bool(normalize_before)
  # pre-norm models use 0.1 for attention/relu dropout, others 0.0
  inner_dropout = 0.1 if pre_norm else 0.0
  fields += [
    ('encoder_learned_pos', True),
    ('decoder_learned_pos', True),
    ('encoder_normal_type', 'layer'),
    ('decoder_normal_type', 'layer'),
    ('encoder_normalize_before', pre_norm),
    ('decoder_normalize_before', pre_norm),
    ('left_pad_source', True),
    ('left_pad_target', False),
    ('share_all_embeddings', False),
    ('share_decoder_input_output_embed', False),
    ('dropout', 0.1),
    ('attention_dropout', inner_dropout),
    ('relu_dropout', inner_dropout),
  ]
  return argparse.Namespace(**dict(fields))


def creat_settings():
  """Collect the detected model settings into an ordered mapping of strings.

  All values come from module-level state. ``inner_hidden_size`` is taken
  from the second dimension of the ``enc_0_fnn_1_b`` parameter and
  ``head_num`` from the parsed command-line arguments.

  Returns:
    An ``OrderedDict`` whose values are all ``str``.
  """
  entries = [
    ("src_max_length", src_max_length),
    ("tgt_max_length", tgt_max_length),
    ("src_layer_num", src_layer_num),
    ("tgt_layer_num", tgt_layer_num),
    ("embedding_size", emb_size),
    # the first FFN bias is kept as a [1, X] row, so dimension 1 is the inner size
    ("inner_hidden_size", _get_param_numpy('enc_0_fnn_1_b', sess).shape[1]),
    ("head_num", args.head_num),
    ("src_vocab_size", src_vocab_size),
    ("tgt_vocab_size", tgt_vocab_size),
    ("activation", activation_function),
    ("use_rpr", use_relative_position_representation),
    ("max_relative_length", max_relative_length),
    ("normalize_before", normalize_before),
    ("use_dense", use_dense),
  ]
  return OrderedDict((key, str(value)) for key, value in entries)

def convert_param():
  """Assemble the model state dict from the tensors held in ``param_dict``.

  Every entry is fetched through ``_get_param_numpy`` with the module-level
  ``sess``. Per-layer tensors are fetched transposed. The token embeddings and
  the output projection get one extra all-zero row prepended (the ``<Lua>``
  slot), and the position tables are shifted down by two rows.

  Returns:
    A dict mapping parameter names such as
    ``encoder.layers.0.self_attn.in_proj_weight`` to torch tensors.
  """
  print("src_layer_num : %d" % src_layer_num)
  print("tgt_layer_num : %d" % tgt_layer_num)

  def transposed(name):
    return _get_param_numpy(name, sess, transpose=True)

  def with_leading_zero_row(matrix):
    # <PAD> <EOS> <UNK> ... -> <Lua> <PAD> <EOS> <UNK> ...
    return torch.cat([torch.zeros(1, matrix.size(1)), matrix], dim=0)

  def shifted_positions(name):
    # in fairseq, pos-index is from 2
    table = _get_param_numpy(name, sess)
    return torch.cat([torch.zeros(2, table.size(1)), table[:-2, :]], dim=0)

  def row_normalized(name):
    # divide each row of the layer-history weight by its row sum
    weight = _get_param_numpy(name, sess)
    return weight / torch.sum(weight, dim=1, keepdim=True)

  # (target suffix, source suffix) pairs, in the order the keys are emitted
  encoder_layer_map = (
    ('self_attn.in_proj_weight', 'qvk_trans_w'),
    ('self_attn.in_proj_bias', 'qvk_trans_b'),
    ('self_attn.out_proj.weight', 'out_trans_w'),
    ('self_attn.out_proj.bias', 'out_trans_b'),
    ('fc1.weight', 'fnn_1_w'),
    ('fc1.bias', 'fnn_1_b'),
    ('fc2.weight', 'fnn_2_w'),
    ('fc2.bias', 'fnn_2_b'),
    ('layer_norms.0.weight', 'ln_scale_1'),
    ('layer_norms.0.bias', 'ln_bias_1'),
    ('layer_norms.1.weight', 'ln_scale_2'),
    ('layer_norms.1.bias', 'ln_bias_2'),
  )
  decoder_self_attn_map = (
    ('self_attn.in_proj_weight', 'decatt_qvk_trans_w'),
    ('self_attn.in_proj_bias', 'decatt_qvk_trans_b'),
    ('self_attn.out_proj.weight', 'decatt_out_trans_w'),
    ('self_attn.out_proj.bias', 'decatt_out_trans_b'),
  )
  decoder_tail_map = (
    ('encoder_attn.out_proj.weight', 'encdecatt_out_trans_w'),
    ('encoder_attn.out_proj.bias', 'encdecatt_out_trans_b'),
    ('fc1.weight', 'fnn_1_w'),
    ('fc1.bias', 'fnn_1_b'),
    ('fc2.weight', 'fnn_2_w'),
    ('fc2.bias', 'fnn_2_b'),
    ('layer_norms.0.weight', 'decatt_ln_scale_1'),
    ('layer_norms.0.bias', 'decatt_ln_bias_1'),
    ('layer_norms.1.weight', 'encdecatt_ln_scale_1'),
    ('layer_norms.1.bias', 'encdecatt_ln_bias_1'),
    ('layer_norms.2.weight', 'encdecatt_ln_scale_2'),
    ('layer_norms.2.bias', 'encdecatt_ln_bias_2'),
  )

  state = {}

  # ---- encoder ----
  src_embed = _get_param_numpy("src_emb", sess)
  # note: t2t has space-emb, so the final emb = scale*word-emb + space-emb
  # to avoid using space-emb, we can change the word-emb as (word-emb + space-emb/scale)
  space_embed = _get_param_numpy("space_emb", sess)
  src_embed = src_embed + space_embed / (src_embed.size(1) ** 0.5)
  state['encoder.embed_tokens.weight'] = with_leading_zero_row(src_embed)
  state['encoder.embed_positions.weight'] = shifted_positions("src_pos_emb")

  for layer_id in range(int(src_layer_num)):
    for target, source in encoder_layer_map:
      state['encoder.layers.%d.%s' % (layer_id, target)] = transposed('enc_%d_%s' % (layer_id, source))
    if use_dense:
      history = 'encoder.history.layer_norms.%d' % layer_id
      state['%s.weight' % history] = transposed('enc_%d_ln_scale_dense' % layer_id)
      state['%s.bias' % history] = transposed('enc_%d_ln_bias_dense' % layer_id)

  if normalize_before:
    state['encoder.layer_norm.weight'] = transposed('enc_ln_scale_top')
    state['encoder.layer_norm.bias'] = transposed('enc_ln_bias_top')

  if use_dense:
    state['encoder.history.weight'] = row_normalized('enc_layer_weight')

  # ---- decoder ----
  tgt_embed = _get_param_numpy('tgt_emb', sess)
  state['decoder.embed_tokens.weight'] = with_leading_zero_row(tgt_embed)
  state['decoder.embed_positions.weight'] = shifted_positions("tgt_pos_emb")

  for layer_id in range(int(tgt_layer_num)):
    layer = 'decoder.layers.%d' % layer_id
    for target, source in decoder_self_attn_map:
      state['%s.%s' % (layer, target)] = transposed('dec_%d_%s' % (layer_id, source))

    # the separate q and kv projections are stacked into one in-projection
    q_weight = transposed('dec_%d_encdecatt_q_trans_w' % layer_id)
    q_bias = transposed('dec_%d_encdecatt_q_trans_b' % layer_id)
    kv_weight = transposed('dec_%d_encdecatt_kv_trans_w' % layer_id)
    kv_bias = transposed('dec_%d_encdecatt_kv_trans_b' % layer_id)
    state['%s.encoder_attn.in_proj_weight' % layer] = torch.cat([q_weight, kv_weight], dim=0)
    state['%s.encoder_attn.in_proj_bias' % layer] = torch.cat([q_bias, kv_bias], dim=0)

    for target, source in decoder_tail_map:
      state['%s.%s' % (layer, target)] = transposed('dec_%d_%s' % (layer_id, source))

    if use_dense:
      history = 'decoder.history.layer_norms.%d' % layer_id
      state['%s.weight' % history] = transposed('dec_%d_ln_scale_dense' % layer_id)
      state['%s.bias' % history] = transposed('dec_%d_ln_bias_dense' % layer_id)

  if normalize_before:
    state['decoder.layer_norm.weight'] = transposed('dec_ln_scale_top')
    state['decoder.layer_norm.bias'] = transposed('dec_ln_bias_top')

  if use_dense:
    state['decoder.history.weight'] = row_normalized('dec_layer_weight')

  # the output projection also gets the extra all-zero leading row
  softmax_weight = _get_param_numpy('softmax_w', sess)
  state['decoder.embed_out'] = with_leading_zero_row(softmax_weight)
  return state

def write_vocab():
  """Write the source and target vocabularies into a single `online.vocab` file.

  Each output line is ``<index> <token>``. The first two lines of every
  input vocabulary are replaced by ``<PAD>`` and ``<EOS>``; for all other
  lines the first and last character of the stripped line are dropped.
  The two vocabularies are separated by a line of ``=`` characters.

  Reads the module-level ``src_vocab_name`` and ``tgt_vocab_name`` paths.
  """
  print('write vocab file ...')

  def _write_to_file(vocab, online_vocab):
    with open(vocab, 'r', encoding='utf-8') as f:
      for index, line in enumerate(f):
        if index == 0:
          token = u'<PAD>'
        elif index == 1:
          token = u'<EOS>'
        else:
          # drop the first and last character of the stripped entry
          token = line.strip()[1:-1]
        print(u'{} {}'.format(index, token), file=online_vocab)

  # "with" guarantees the output file is closed even if reading a vocab fails.
  # The output keeps the default encoding, which is what convert_dict uses
  # when it reads this file back.
  with open('online.vocab', 'w') as vocab_file:
    _write_to_file(src_vocab_name, vocab_file)
    print(u"================================", file=vocab_file)
    _write_to_file(tgt_vocab_name, vocab_file)


def convert_dict(vocab_input, vocab_output):
  """Split a combined vocabulary file into `<vocab_output>.src` and `<vocab_output>.tgt`.

  The input holds ``<index> <token>`` lines; a line starting with ``=====``
  switches from the source section to the target section. Special tokens are
  skipped and every remaining token is written as ``<token> 1``.

  Args:
    vocab_input: path of the combined vocabulary file.
    vocab_output: path prefix of the two output files.

  Raises:
    AssertionError: if a second separator line is found, or if a line does
      not consist of exactly two space-separated fields.
  """
  skipped_tokens = ['<PAD>', '<UNK>', '<EOS>', '<LUA>']
  with open(vocab_input, 'r') as source, \
    open(vocab_output + '.src', 'w') as src_sink, \
    open(vocab_output + '.tgt', 'w') as tgt_sink:
    sink = src_sink
    for raw in source:
      # strip only '\n': a plain strip() could remove characters that belong to the token
      entry = raw.strip('\n')
      if entry.startswith('====='):
        # only one separator is allowed
        assert sink is src_sink
        sink = tgt_sink
        continue

      # split on a single space explicitly so special symbols are not treated as separators
      fields = entry.split(' ')
      assert len(fields) == 2
      token = fields[1]
      if token not in skipped_tokens:
        print('%s %d' % (token, 1), file=sink)

if __name__ == '__main__':
  # Script entry point: parse the command line, scan and load the checkpoint,
  # convert the parameters, then write the vocabulary/dictionary files and the
  # converted checkpoint. The functions above read the names assigned here as
  # module-level globals, so the order of the statements below matters.

  parser = argparse.ArgumentParser()
  parser.add_argument('-name', required=True, help="eg: transformer, transformer_dla")
  parser.add_argument('-model', required=True, help="trained model prefix, also include dir, e.g. ../data/model-100")
  parser.add_argument('-src_vocab', required=True, help="source vocabulary path")
  parser.add_argument('-tgt_vocab', required=True, help="target vocabulary path")
  parser.add_argument('-head_num', required=True, type=int,
                      help="head number in MultiHeadAttention, the value can not be inferred from model file")
  parser.add_argument('-tgt_space_id', type=int, default=0,
                      help="this is `EN_BPE_TOK`, which is used target-space-id for Problem `wmt_zhen_tokens_32k`")
  parser.add_argument('-shard_num', type=int, default=16, help="shard number in e.g. embedding or softmax weight")
  parser.add_argument('-model_output', required=True, help='fairseq checkpoint path')
  parser.add_argument('-vocab_output', required=True, help='fairseq dict path')

  args = parser.parse_args()

  # variable-name prefix of the model inside the checkpoint
  model_name = args.name
  # file_name = '../output/model.ckpt-1'
  file_name = args.model
  # src_vocab_name = '../data/180w/tokens.vocab.zh.32768'
  src_vocab_name = args.src_vocab
  # tgt_vocab_name = '../data/180w/tokens.vocab.en.32768'
  tgt_vocab_name = args.tgt_vocab

  # number of shards the embedding/softmax weights are split into, e.g. shard = 16
  shard = args.shard_num
  # index gathered from the target-space embedding in load_param
  # this is `EN_BPE_TOK`, which is used target-space-id for Problem `wmt_zhen_tokens_32k`
  # target_space_id = 0 in t2t-1.6.5
  target_space_id = args.tgt_space_id

  # Lengths of the precomputed position tables.
  src_max_length = 1024
  tgt_max_length = 1024
  # The settings below are placeholders that find_useful_param fills in from
  # the checkpoint variable names (hidden_size is not assigned anywhere else
  # in this file).
  src_vocab_size = None
  tgt_vocab_size = None
  emb_size = None
  hidden_size = None
  src_layer_num = None
  tgt_layer_num = None
  activation_function = None
  use_relative_position_representation = None
  max_relative_length = -1
  normalize_before = False
  use_dense = False
  norm_type = None
  start = time.time()

  # The session is used by _get_param_numpy to evaluate tf tensors.
  sess = tf.Session()
  reader, trainable_param = find_useful_param(file_name)

  param_dict = load_param()
  settings = creat_settings()
  # print the setting items
  for k, v in settings.items():
    print('{}={}'.format(k, v))

  # print the shape of every loaded parameter
  for k, v in param_dict.items():
    print('{}:shape={}'.format(k,v.shape))

  # The saved checkpoint is a dict with an 'args' namespace and a 'model' state dict.
  ck_state = {}
  ck_state['args'] = convert_settings(settings)

  ck_state['model'] = convert_param()
  print(ck_state['args'])

  # write_vocab produces 'online.vocab', which convert_dict then splits into
  # '<vocab_output>.src' and '<vocab_output>.tgt'.
  write_vocab()
  convert_dict('online.vocab', args.vocab_output)
  torch.save(ck_state, args.model_output)

  elapsed = time.time() - start
  print('use %.2f seconds' % elapsed)