from __future__ import print_function
# a common method
from io import open as open
from tensorflow.python.tools.inspect_checkpoint import print_tensors_in_checkpoint_file
from tensorflow.python import pywrap_tensorflow
import re
import tensorflow as tf
import math
from collections import OrderedDict
import time
import struct
import argparse
import sys

# Compatibility shim: probe for a built-in named `unicode`. Where the name does
# not exist the lookup raises NameError and `unicode` is aliased to `str`.
try:
    # Only bound when the built-in `unicode` exists; nothing else in this file reads it.
    UNICODE_EXISTS = bool(type(unicode))
except NameError:
    unicode = str


# Command-line interface of the converter. Options use a single leading dash.
parser = argparse.ArgumentParser()

# Options that must always be given.
parser.add_argument(
    '-name',
    help="eg: transformer, transformer_dla",
    required=True,
)
parser.add_argument(
    '-param_set',
    help="eg: transformer_base_before or transformer_base_after to decide the normalize type",
    required=True,
)
parser.add_argument(
    '-model',
    help="trained model prefix, also include dir, e.g. ../data/model-100",
    required=True,
)
parser.add_argument(
    '-src_vocab',
    help="source vocabulary path",
    required=True,
)
parser.add_argument(
    '-tgt_vocab',
    help="target vocabulary path",
    required=True,
)
parser.add_argument(
    '-head_num',
    help="head number in MultiHeadAttention, the value can not be inferred from model file",
    type=int,
    required=True,
)

# Options with defaults.
# NOTE(review): `store_true` combined with `default=True` means this value is
# True whether or not the flag is passed, so text output cannot be selected.
parser.add_argument(
    '-binary',
    help="outputed model is saved in binary format?",
    action='store_true',
    default=True,
)
parser.add_argument(
    '-max_seq_length',
    help="used for position embedding",
    type=int,
    default=200,
)
parser.add_argument(
    '-tgt_space_id',
    help="this is `EN_BPE_TOK`, which is used target-space-id for Problem `wmt_zhen_tokens_32k`",
    type=int,
    default=4,
)
parser.add_argument(
    '-shard_num',
    help="shard number in e.g. embedding or softmax weight",
    type=int,
    default=16,
)

args = parser.parse_args()

# Variable-scope prefix of the model inside the checkpoint (e.g. "transformer").
model_name = args.name
print(model_name)
print(type(model_name))
param_set = args.param_set

# Name of the variable scope that holds the layer-norm parameters. It is
# selected by whether the hyper-parameter set name contains "before" or "after".
# Every layer-norm tensor name is built from it, so an unrecognised param set
# is rejected here instead of failing later while tensor names are assembled.
if "before" in param_set:
    norm_type = "layer_prepostprocess"
elif "after" in param_set:
    norm_type = "layer_postprocess"
else:
    parser.error(
        "-param_set must contain 'before' or 'after' to decide the normalize type, "
        "got {!r}".format(param_set))

# checkpoint prefix, e.g. '../output/model.ckpt-1'
file_name = args.model
# source vocabulary, e.g. '../data/180w/tokens.vocab.zh.32768'
src_vocab_name = args.src_vocab
# target vocabulary, e.g. '../data/180w/tokens.vocab.en.32768'
tgt_vocab_name = args.tgt_vocab

# number of shards the embedding / softmax weights are split into (default 16)
shard = args.shard_num
# number of positions precomputed for the position embedding (default 200)
max_seq_length = args.max_seq_length
# this is `EN_BPE_TOK`, which is used target-space-id for Problem `wmt_zhen_tokens_32k`
# (default 4)
target_space_id = args.tgt_space_id
# whether the converted model is written in binary format
binary = args.binary

# Settings below are inferred from the checkpoint by find_useful_param().
src_vocab_size = None
tgt_vocab_size = None
emb_size = None
hidden_size = None
src_layer_num = None
tgt_layer_num = None
activation_function = None
use_relative_position_representation = None
max_relative_length = -1


def find_useful_param(model_file):
    """Open a checkpoint and infer the model settings from its variable names.

    Every variable name in the checkpoint is printed. Variables whose name
    starts with the module-level ``model_name`` are collected, and their names
    are parsed to fill these module-level settings:

    * ``src_vocab_size`` / ``emb_size``: from ``symbol_modality_<V>_<E>/input_emb``
    * ``tgt_vocab_size``: from ``symbol_modality_<V>_<E>/target_emb``
    * ``src_layer_num`` / ``tgt_layer_num``: highest ``layer_<i>`` index plus one
    * ``activation_function``: from ``conv_hidden_<name>``, ``'relu'`` if absent
    * ``use_relative_position_representation``: 1 if any collected name contains
      ``dot_product_attention_relative``, else 0
    * ``max_relative_length``: ``(first_dim - 1) // 2`` of the first such
      variable (in sorted name order), else -1

    Args:
        model_file: checkpoint prefix handed to ``NewCheckpointReader``.

    Returns:
        A ``(reader, trainable_param)`` tuple: the checkpoint reader and a dict
        mapping each collected variable name to its shape.

    Raises:
        ValueError: if no variable starts with ``model_name``, or if no encoder
            or decoder layer variable is found.
        Any error raised while opening or reading the checkpoint is propagated.
    """
    global src_vocab_size
    global tgt_vocab_size
    global emb_size
    global hidden_size
    global src_layer_num
    global tgt_layer_num
    global activation_function
    global use_relative_position_representation
    global max_relative_length

    # Reader errors are deliberately not swallowed: continuing without a
    # reader could only fail later with a less helpful message.
    reader = pywrap_tensorflow.NewCheckpointReader(model_file)
    var_to_shape_map = reader.get_variable_to_shape_map()

    # model_name is a literal scope name, so escape it before embedding it in
    # a regular expression.
    scope = re.escape(model_name)
    input_emb_re = re.compile(scope + r'/symbol_modality_(\d+)_(\d+)/input_emb')
    target_emb_re = re.compile(scope + r'/symbol_modality_(\d+)_(\d+)/target_emb')
    encoder_layer_re = re.compile(scope + r'/body/encoder/layer_(\d+)/')
    decoder_layer_re = re.compile(scope + r'/body/decoder/layer_(\d+)/')
    activation_re = re.compile(r'/conv_hidden_([^/]+)/')

    # Results are accumulated in locals and published at the end, so calling
    # this function a second time yields the same settings instead of
    # incrementing the layer counts again.
    found_src_vocab = None
    found_tgt_vocab = None
    found_emb = None
    last_encoder_layer = None
    last_decoder_layer = None
    found_activation = None
    found_relative = 0
    found_max_relative = -1

    trainable_param = dict()
    for key in sorted(var_to_shape_map):
        print(key)
        # only variables under the model scope are our real trainable parameters
        if not key.startswith(model_name):
            continue
        trainable_param[key] = var_to_shape_map[key]

        # source vocab size and embedding size (first match wins)
        match = input_emb_re.match(key)
        if match and found_src_vocab is None and found_emb is None:
            found_src_vocab = int(match.group(1))
            found_emb = int(match.group(2))

        # target vocab size (first match wins)
        match = target_emb_re.match(key)
        if match and found_tgt_vocab is None:
            found_tgt_vocab = int(match.group(1))

        # highest encoder layer index
        match = encoder_layer_re.match(key)
        if match:
            index = int(match.group(1))
            if last_encoder_layer is None or index > last_encoder_layer:
                last_encoder_layer = index

        # highest decoder layer index
        match = decoder_layer_re.match(key)
        if match:
            index = int(match.group(1))
            if last_decoder_layer is None or index > last_decoder_layer:
                last_decoder_layer = index

        # activation name embedded in the feed-forward scope (first match wins)
        match = activation_re.search(key)
        if match and found_activation is None:
            found_activation = match.group(1)

        # relative position representation (first match wins)
        if 'dot_product_attention_relative' in key and not found_relative:
            found_relative = 1
            found_max_relative = (var_to_shape_map[key][0] - 1) // 2

    if not trainable_param:
        raise ValueError("not found any trainable parameters")
    if last_encoder_layer is None or last_decoder_layer is None:
        raise ValueError(
            "can not infer layer numbers: no '{0}/body/encoder/layer_<i>/' or "
            "'{0}/body/decoder/layer_<i>/' variable in checkpoint".format(model_name))

    src_vocab_size = found_src_vocab
    tgt_vocab_size = found_tgt_vocab
    emb_size = found_emb
    # add 1 due to index from 0
    src_layer_num = last_encoder_layer + 1
    tgt_layer_num = last_decoder_layer + 1
    activation_function = 'relu' if found_activation is None else found_activation
    # kept as an int (0/1) because it is later packed into the model header
    use_relative_position_representation = found_relative
    max_relative_length = found_max_relative

    print('find src-vocab:{} tgt-vocab:{} emb-size:{} src-layer:{} tgt-layer:{} activation:{} relative:{}'.
          format(src_vocab_size, tgt_vocab_size, emb_size, src_layer_num, tgt_layer_num,
                 activation_function, use_relative_position_representation))
    return reader, trainable_param





def load_param():
    """Collect every tensor of the model into an ordered dict, in dump order.

    Reads the module-level ``reader`` and the settings inferred by
    ``find_useful_param`` (vocab sizes, ``emb_size``, layer numbers,
    ``activation_function``, relative-position flag) together with the
    command-line settings (``model_name``, ``norm_type``, ``shard``,
    ``max_seq_length``, ``target_space_id``).

    The insertion order of the returned dict is the order in which the
    parameters are later written to the model file, so it must not change:
    source embedding, space embedding, source position embedding, all encoder
    layers, target embedding, target position embedding, all decoder layers,
    softmax weight.

    Returns:
        OrderedDict mapping parameter name to a value that is either what
        ``reader.get_tensor`` returned or a ``tf.Tensor`` still to be evaluated.

    Raises:
        ValueError: if ``norm_type`` is unset or a shard count below 1 is used.
        KeyError: if relative position embeddings are enabled but cannot be
            located in the checkpoint.
    """
    print('loadding params ...')
    if norm_type is None:
        raise ValueError(
            "norm_type is not set: -param_set must contain 'before' or 'after'")
    param_dict = OrderedDict()
    # names of all variables in the checkpoint, used to resolve optional tensors
    checkpoint_names = set(reader.get_variable_to_shape_map())

    def _get_tensor(name, shard=1):
        """Read tensor `name`; with shard > 1 concatenate `name_0..name_{shard-1}` on axis 0."""
        if shard < 1:
            raise ValueError('shard must be >= 1, got {}'.format(shard))
        if shard == 1:
            full_name = name
            tensor = reader.get_tensor(full_name)
        else:
            pieces = []
            for i in range(shard):
                full_name = '{}_{}'.format(name, i)
                pieces.append(reader.get_tensor(full_name))
            tensor = tf.concat(pieces, axis=0)
        # for sharded tensors this reports the name of the last shard
        print('{} shape is: {}'.format(full_name, tensor.shape))
        return tensor

    def _get_projection(scope):
        """Read `scope/kernel` (its two leading axes squeezed away) and `scope/bias`."""
        kernel = tf.squeeze(_get_tensor(scope + '/kernel'), axis=[0, 1])
        bias = _get_tensor(scope + '/bias')
        return kernel, bias

    def _get_layer_norm(scope):
        """Read the layer-norm (scale, bias) pair stored under `scope/<norm_type>`."""
        norm_scope = scope + '/' + norm_type + '/layer_norm'
        bias = _get_tensor(norm_scope + '/layer_norm_bias')
        scale = _get_tensor(norm_scope + '/layer_norm_scale')
        return scale, bias

    def _get_relative_embedding(side, layer_id, kind):
        """Read a relative position embedding; `side` is encoder/decoder, `kind` is keys/values.

        The historical tensor name (without the model scope) is tried first.
        Relative positions are detected only on variables that start with
        `model_name`, so when the historical name is absent the tensor is
        looked up under this layer's scope in the checkpoint instead.
        """
        legacy_name = ('body/%s/layer_%d/%s_self_attention/dot_product_attention_relative/'
                       'relative_positions_%s/embeddings' % (side, layer_id, side, kind))
        if legacy_name in checkpoint_names:
            return _get_tensor(legacy_name)

        layer_scope = '%s/body/%s/layer_%d/' % (model_name, side, layer_id)
        suffix = '/dot_product_attention_relative/relative_positions_%s/embeddings' % kind
        candidates = sorted(name for name in checkpoint_names
                            if name.startswith(layer_scope) and name.endswith(suffix))
        if len(candidates) != 1:
            raise KeyError(
                'expected exactly one relative position %s embedding under %s, found %r'
                % (kind, layer_scope, candidates))
        return _get_tensor(candidates[0])

    # position embedding, shape is max-seq-length * emb_size
    # note that this tensor actually has no need training, it is a fixed tensor,
    # we precompute it for speedup
    def get_pos_emb(seq_len=max_seq_length, emb_size=512, min_timescale=1.0, max_timescale=1.0e4):
        """Build the sinusoidal position table, shape seq_len * emb_size (L * H)."""
        length = seq_len
        channels = emb_size
        position = tf.to_float(tf.range(length))
        num_timescales = channels // 2
        log_timescale_increment = (
            math.log(float(max_timescale) / float(min_timescale)) /
            (tf.to_float(num_timescales) - 1))
        inv_timescales = min_timescale * tf.exp(
            tf.to_float(tf.range(num_timescales)) * -log_timescale_increment)
        # shape is L * H/2
        scaled_time = tf.expand_dims(position, 1) * tf.expand_dims(inv_timescales, 0)
        # shape is L * H (if H is odd, the signal only has H-1 columns here)
        signal = tf.concat([tf.sin(scaled_time), tf.cos(scaled_time)], axis=1)
        # if H is odd, pad one column at the end to make the shape L * H
        signal = tf.pad(signal, [[0, 0], [0, tf.mod(channels, 2)]])
        return signal

    # source embedding, shape is src_vocab * emb_size
    param_dict['src_emb'] = _get_tensor(
        model_name + '/symbol_modality_{}_{}/input_emb/weights'.format(src_vocab_size, emb_size),
        shard=shard)

    # space embedding, we only use target_space_id, so the shape is H
    space_emb_tensor = _get_tensor(model_name + '/body/target_space_embedding/kernel')
    param_dict['space_emb'] = tf.gather(space_emb_tensor, tf.constant(target_space_id))
    print('target space id emb shape:', param_dict['space_emb'].shape)

    # note, actually pos_emb is same in source and target.
    # However, in practice, we write the pos_emb in both source and target sides
    # This is convenient for us to init encoder and decoder independently
    param_dict['src_pos_emb'] = get_pos_emb(emb_size=emb_size)

    # encoder
    for layer_id in range(src_layer_num):
        layer_scope = '%s/body/encoder/layer_%d' % (model_name, layer_id)
        attention_scope = layer_scope + '/self_attention/multihead_attention'
        ffn_scope = layer_scope + '/ffn/conv_hidden_%s' % activation_function
        prefix = 'enc_%d_' % layer_id

        # step 1. qvk_trans_w shape is H * 3H (because query, key, value is same); qvk_trans_b shape is 3H
        weight, bias = _get_projection(attention_scope + '/qkv_transform_single')
        param_dict[prefix + 'qvk_trans_w'] = weight
        param_dict[prefix + 'qvk_trans_b'] = bias

        # step 2. output_trans_w shape is H * H; output_trans_b shape is H
        weight, bias = _get_projection(attention_scope + '/output_transform_single')
        param_dict[prefix + 'out_trans_w'] = weight
        param_dict[prefix + 'out_trans_b'] = bias

        # optional. used for relative position representation
        if use_relative_position_representation:
            param_dict[prefix + 'rp_key'] = _get_relative_embedding('encoder', layer_id, 'keys')
            param_dict[prefix + 'rp_value'] = _get_relative_embedding('encoder', layer_id, 'values')

        # step 3. layer normalization
        scale, bias = _get_layer_norm(layer_scope + '/self_attention')
        param_dict[prefix + 'ln_scale_1'] = scale
        param_dict[prefix + 'ln_bias_1'] = bias

        # step 4. fnn_layer_1, w shape is H * X, b shape is X
        weight, bias = _get_projection(ffn_scope + '/conv1_single')
        param_dict[prefix + 'fnn_1_w'] = weight
        param_dict[prefix + 'fnn_1_b'] = bias

        # step 5. fnn_layer_2, w shape is X * H, b shape is H
        weight, bias = _get_projection(ffn_scope + '/conv2_single')
        param_dict[prefix + 'fnn_2_w'] = weight
        param_dict[prefix + 'fnn_2_b'] = bias

        # step 6. layer normalization 2
        scale, bias = _get_layer_norm(layer_scope + '/ffn')
        param_dict[prefix + 'ln_scale_2'] = scale
        param_dict[prefix + 'ln_bias_2'] = bias

    # decoder
    # target embedding, shape is tgt_vocab * emb_size
    param_dict['tgt_emb'] = _get_tensor(
        model_name + '/symbol_modality_{}_{}/target_emb/weights'.format(tgt_vocab_size, emb_size),
        shard=shard)
    param_dict['tgt_pos_emb'] = get_pos_emb(emb_size=emb_size)

    for layer_id in range(tgt_layer_num):
        layer_scope = '%s/body/decoder/layer_%d' % (model_name, layer_id)
        self_attention_scope = layer_scope + '/self_attention/multihead_attention'
        encdec_attention_scope = layer_scope + '/encdec_attention/multihead_attention'
        ffn_scope = layer_scope + '/ffn/conv_hidden_%s' % activation_function
        prefix = 'dec_%d_' % layer_id

        # step 1. decoder self attention
        # qvk_trans_w shape is H * 3H (because query, key, value is same); qvk_trans_b shape is 3H
        weight, bias = _get_projection(self_attention_scope + '/qkv_transform_single')
        param_dict[prefix + 'decatt_qvk_trans_w'] = weight
        param_dict[prefix + 'decatt_qvk_trans_b'] = bias

        # output_trans_w shape is H * H; output_trans_b shape is H
        weight, bias = _get_projection(self_attention_scope + '/output_transform_single')
        param_dict[prefix + 'decatt_out_trans_w'] = weight
        param_dict[prefix + 'decatt_out_trans_b'] = bias

        # optional. used for relative position representation
        if use_relative_position_representation:
            param_dict[prefix + 'rp_key'] = _get_relative_embedding('decoder', layer_id, 'keys')
            param_dict[prefix + 'rp_value'] = _get_relative_embedding('decoder', layer_id, 'values')

        # layer normalization
        scale, bias = _get_layer_norm(layer_scope + '/self_attention')
        param_dict[prefix + 'decatt_ln_scale_1'] = scale
        param_dict[prefix + 'decatt_ln_bias_1'] = bias

        # step 2. encoder-decoder attention
        # query transform, shape is H * H
        weight, bias = _get_projection(encdec_attention_scope + '/q_transform_single')
        param_dict[prefix + 'encdecatt_q_trans_w'] = weight
        param_dict[prefix + 'encdecatt_q_trans_b'] = bias

        # key/value transform, shape is H * 2H
        weight, bias = _get_projection(encdec_attention_scope + '/kv_transform_single')
        param_dict[prefix + 'encdecatt_kv_trans_w'] = weight
        param_dict[prefix + 'encdecatt_kv_trans_b'] = bias

        # output transform, shape is H * H
        weight, bias = _get_projection(encdec_attention_scope + '/output_transform_single')
        param_dict[prefix + 'encdecatt_out_trans_w'] = weight
        param_dict[prefix + 'encdecatt_out_trans_b'] = bias

        scale, bias = _get_layer_norm(layer_scope + '/encdec_attention')
        param_dict[prefix + 'encdecatt_ln_scale_1'] = scale
        param_dict[prefix + 'encdecatt_ln_bias_1'] = bias

        # step 3. fnn layer
        # fnn_layer_1, w shape is H * X, b shape is X
        weight, bias = _get_projection(ffn_scope + '/conv1_single')
        param_dict[prefix + 'fnn_1_w'] = weight
        param_dict[prefix + 'fnn_1_b'] = bias

        # fnn_layer_2, w shape is X * H, b shape is H
        weight, bias = _get_projection(ffn_scope + '/conv2_single')
        param_dict[prefix + 'fnn_2_w'] = weight
        param_dict[prefix + 'fnn_2_b'] = bias

        # layer norm of the ffn sub-layer; the historical "encdecatt" key names
        # are kept unchanged
        scale, bias = _get_layer_norm(layer_scope + '/ffn')
        param_dict[prefix + 'encdecatt_ln_scale_2'] = scale
        param_dict[prefix + 'encdecatt_ln_bias_2'] = bias

    softmax_w_tensor = _get_tensor(
        model_name + '/symbol_modality_{}_{}/softmax/weights'.format(tgt_vocab_size, emb_size),
        shard=shard)
    # note: we transpose the matrix, from (V,H) to (H,V)
    param_dict['softmax-w'] = tf.transpose(softmax_w_tensor, perm=[1, 0])
    return param_dict

def _get_param_numpy(name, sess):
    """Look up `name` in the module-level `param_dict` and return a concrete value.

    An entry that is exactly a `tf.Tensor` is evaluated with `sess.run`; any
    other entry is returned as stored. Asserts that `name` is a known key.
    """
    assert name in param_dict, "unknown param nam:{} in dict".format(name)
    entry = param_dict[name]
    needs_evaluation = type(entry) is tf.Tensor
    return sess.run(entry) if needs_evaluation else entry

def write_model():
    """Write the model header and every parameter of `param_dict` to 'online.model'.

    The header is a sequence of natively packed fields: source layer number,
    target layer number, embedding size, length of 'enc_0_fnn_1_b', head
    number, source vocab size, target vocab size, the activation name as a
    16-byte string, the relative-position flag and the max relative length.
    The parameters follow in `param_dict` order.

    In binary mode each 1-D or 2-D parameter is written as native float32
    values in row-major order. In text mode values are printed with five
    decimals (one line per row) followed by a line holding the parameter name.

    NOTE(review): the header is always written as packed bytes, which a file
    opened in text mode ('w') does not accept, so text mode looks unusable as
    written -- confirm the intended text format before relying on it.
    """
    print('write model ...')
    mode = 'wb' if binary else 'w'

    def _write_tensor(name, tensor, binary=False):
        assert type(tensor) is not tf.Tensor, "can not write tf.Tensor, you should pass a numpy.ndarray"
        rank = len(tensor.shape)
        # only vectors and matrices are written; any other rank is skipped
        if rank not in (1, 2):
            return
        if binary:
            # One buffer write instead of packing element by element: native
            # float32 in row-major order is what struct.pack("f", ...) emits
            # for each element in turn.
            ofile.write(tensor.astype('float32').tobytes())
            return
        # text mode: a vector is written like a matrix with a single row
        rows = [tensor] if rank == 1 else tensor
        for row in rows:
            for value in row:
                print('%.5f ' % value, file=ofile, end='')
            print('', file=ofile)
        print(name, file=ofile)

    # the file and the session are released even if a parameter fails to dump
    with open('online.model', mode) as ofile, tf.Session() as sess:
        start_time = time.time()
        # write model setting: src_layer, tgt_layer, hidden, inner_hidden, head_num, src_vocab, tgt_vocab
        ofile.write(struct.pack("i", src_layer_num))
        ofile.write(struct.pack("i", tgt_layer_num))
        ofile.write(struct.pack("i", emb_size))
        ofile.write(struct.pack("i", _get_param_numpy('enc_0_fnn_1_b', sess).shape[0]))
        ofile.write(struct.pack("i", args.head_num))
        ofile.write(struct.pack("i", src_vocab_size))
        ofile.write(struct.pack("i", tgt_vocab_size))
        # note: we limit the activation name as 16 chars
        ofile.write(struct.pack("16s", activation_function.encode('ascii')))
        ofile.write(struct.pack("i", use_relative_position_representation))
        ofile.write(struct.pack("i", max_relative_length))

        for k in param_dict:
            print('dump {} ...'.format(k))
            param = _get_param_numpy(k, sess)
            _write_tensor(k, param, binary=binary)
        print('convert model using %f seconds' % (time.time() - start_time))

def write_vocab():
    """Write the source and target vocabularies to 'online.vocab'.

    Each vocabulary file (module-level `src_vocab_name`, `tgt_vocab_name`) is
    read as UTF-8, one entry per line, and written as "<index> <token>" lines
    with 0-based indices. The first two entries are always written as <PAD>
    and <EOS>; for the remaining lines the token is the stripped line without
    its first and last character. The two vocabularies are separated by a
    line of '=' characters.

    The output is written as UTF-8, matching how the inputs are read, so
    non-ASCII tokens do not depend on the platform default encoding.
    """
    print('write vocab file ...')

    def _write_to_file(vocab, online_vocab):
        with open(vocab, 'r', encoding='utf-8') as f:
            for index, line in enumerate(f):
                if index == 0:
                    token = u'<PAD>'
                elif index == 1:
                    token = u'<EOS>'
                else:
                    # drop the character wrapping the token on each side
                    token = line.strip()[1:-1]
                print(u'{} {}'.format(index, token), file=online_vocab)

    # `with` closes the output even if reading a vocabulary fails
    with open('online.vocab', 'w', encoding='utf-8') as vocab_file:
        _write_to_file(src_vocab_name, vocab_file)
        print(u"================================", file=vocab_file)
        _write_to_file(tgt_vocab_name, vocab_file)


# Script driver. `reader`, `trainable_param` and `param_dict` are bound at
# module level on purpose: load_param(), _get_param_numpy() and write_model()
# read them as globals, so they must be assigned before those calls.
dump_model = True   # convert the checkpoint and write online.model / online.vocab
show_part = False   # debugging aid: print a few converted parameters
if dump_model:
    reader, trainable_param = find_useful_param(file_name)
    param_dict = load_param()
    print(param_dict.keys())
    write_model()
    write_vocab()
    print("done!")

if show_part:
    reader, trainable_param = find_useful_param(file_name)
    param_dict = load_param()
    with tf.Session() as sess:
        #print(_get_param_numpy('src_emb', sess)[:10,:10])
        #print(_get_param_numpy('space_emb', sess)[:10])
        # the *_rp_key entries are only stored by load_param() when relative
        # position representation was detected in the checkpoint
        print(_get_param_numpy('enc_0_rp_key', sess)[:10])
        print('*'*10)
        print(_get_param_numpy('dec_0_rp_key', sess)[:10])
        print('*'*10)
        print(_get_param_numpy('softmax-w', sess).shape)

# The string literal below is an inert expression statement holding disabled
# debug prints; it is never executed.
"""
    print(_get_param_numpy('pos_emb', sess)[:10, :10])
    print(_get_param_numpy('enc_0_qvk_trans_w', sess)[:10, :10])
    print(_get_param_numpy('enc_0_qvk_trans_b', sess)[:10])
    print(_get_param_numpy('enc_0_out_trans_w', sess)[:10, :10])
    print(_get_param_numpy('enc_0_out_trans_b', sess)[:10])
    print(_get_param_numpy('enc_0_ln_scale_1', sess)[:10])
    print(_get_param_numpy('enc_0_ln_bias_1', sess)[:10])
    print(_get_param_numpy('enc_0_fnn_1_w', sess)[:10,:10])
    print(_get_param_numpy('enc_0_fnn_1_b', sess)[:10])
    print(_get_param_numpy('enc_0_fnn_2_w', sess)[:10,:10])
    print(_get_param_numpy('enc_0_fnn_2_b', sess)[:10])
    print(_get_param_numpy('enc_0_ln_scale_2', sess)[:10])
    print(_get_param_numpy('enc_0_ln_bias_2', sess)[:10])
"""

