# usage: python <this_script>.py -model_dir MODEL_DIR -model_num N -out_dir OUT_DIR

import argparse
import os
import re
import tensorflow as tf
import numpy as np
import six
# Matches TF checkpoint index files such as "model.ckpt-12000.index".
# Anchored at both ends so temp/backup files like "model.ckpt-5.index.tmp"
# (whose derived "model.ckpt-5" prefix may not exist) are ignored.
CKPT_INDEX_RE = re.compile(r'model\.ckpt-(\d+)\.index$')


def parse_args(argv=None):
    """Parse and validate the command-line arguments.

    Args:
      argv: list of argument strings; ``None`` means ``sys.argv[1:]``.

    Returns:
      argparse.Namespace with ``model_dir``, ``model_num`` and ``out_dir``.

    Validation failures exit through ``parser.error`` (SystemExit) rather than
    ``assert``, so the checks still run under ``python -O``.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('-model_dir', required=True, type=str, help='saved models path')
    parser.add_argument('-model_num', required=True, type=int, help='ensembled model numbers, we use the last models')
    parser.add_argument('-out_dir', required=True, type=str, help='output ensembled model path, do not set same as model_dir')
    args = parser.parse_args(argv)

    # The directory is scanned with os.walk below, so it must be a directory.
    if not os.path.isdir(args.model_dir):
        parser.error('check model dir!')
    # Compare resolved paths so "dir" vs "dir/" (or a symlink) is also rejected.
    if os.path.realpath(args.out_dir) == os.path.realpath(args.model_dir):
        parser.error('do not set model_dir == output_dir')
    # 0 would leave nothing to average; negatives would slice the wrong models.
    if args.model_num < 1:
        parser.error('model_num must be >= 1, got %d' % args.model_num)
    return args


def find_checkpoint_steps(file_names):
    """Return the steps of all ``model.ckpt-<step>.index`` files, newest first.

    Args:
      file_names: iterable of bare file names (no directory part).

    Returns:
      list of int steps sorted in descending order, so the most recent
      checkpoint comes first.
    """
    steps = []
    for file_name in file_names:
        match = CKPT_INDEX_RE.match(file_name)
        if match:
            steps.append(int(match.group(1)))
    return sorted(steps, reverse=True)


def select_recent_steps(steps, model_num):
    """Pick the first ``model_num`` entries of ``steps`` (newest-first order).

    If fewer steps exist than requested, prints a warning and uses them all.

    Returns:
      (selected_steps, effective_model_num)
    """
    if model_num > len(steps):
        print('warning: you set model_num=%d, however only %d files are detected. so reset model_num=%d'%(model_num, len(steps), len(steps)))
        model_num = len(steps)
    return steps[:model_num], model_num


def average_checkpoints(checkpoints):
    """Element-wise average of every variable across ``checkpoints``.

    The variable list is taken from the first checkpoint; names starting with
    "global_step" are skipped. Each later checkpoint must contain every
    remaining variable (``get_tensor`` raises otherwise).

    Returns:
      (var_values, var_dtypes): dicts keyed by variable name. ``var_values``
      holds float64 numpy arrays (np.zeros default dtype) with the mean;
      ``var_dtypes`` holds the dtype each tensor had in the checkpoints.
    """
    var_list = tf.contrib.framework.list_variables(checkpoints[0])
    var_values, var_dtypes = {}, {}
    for name, shape in var_list:
        if not name.startswith("global_step"):
            var_values[name] = np.zeros(shape)
    for checkpoint in checkpoints:
        reader = tf.contrib.framework.load_checkpoint(checkpoint)
        for name in var_values:
            tensor = reader.get_tensor(name)
            var_dtypes[name] = tensor.dtype
            var_values[name] += tensor
        tf.logging.info("Read from checkpoint %s", checkpoint)
    for name in var_values:
        var_values[name] /= len(checkpoints)
    return var_values, var_dtypes


def save_averaged_checkpoint(var_values, var_dtypes, save_path):
    """Build a graph containing only the averaged variables and save it.

    Each variable is created with its own original dtype. (Previously the
    dtype of the last-iterated variable was applied to all of them.) A fresh
    int64 ``global_step`` with value 0 is added and used as the save suffix.

    Returns:
      the checkpoint path prefix returned by ``Saver.save``.
    """
    # Materialize once so variables, placeholders and fed values share one order.
    items = list(var_values.items())
    tf_vars = [
        tf.get_variable(name, shape=value.shape, dtype=var_dtypes[name])
        for name, value in items
    ]
    placeholders = [tf.placeholder(v.dtype.base_dtype, shape=v.shape) for v in tf_vars]
    assign_ops = [tf.assign(v, p) for v, p in zip(tf_vars, placeholders)]
    global_step = tf.Variable(
        0, name="global_step", trainable=False, dtype=tf.int64)
    saver = tf.train.Saver(tf.global_variables())

    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        for p, assign_op, (_, value) in zip(placeholders, assign_ops, items):
            sess.run(assign_op, {p: value})
        return saver.save(sess, save_path, global_step=global_step)


def main(argv=None):
    """Average the newest checkpoints of -model_dir and save them in -out_dir."""
    tf.logging.set_verbosity(tf.logging.INFO)
    args = parse_args(argv)

    # Only the top level of model_dir is inspected.
    _, _, file_names = next(os.walk(args.model_dir))
    all_steps = find_checkpoint_steps(file_names)
    print('total find %d model index' % len(all_steps))
    print(all_steps)
    if not all_steps:
        raise SystemExit('no model.ckpt-<step>.index files found in %s' % args.model_dir)

    index_list, model_num = select_recent_steps(all_steps, args.model_num)
    print('using following index-model')
    print(index_list)

    if not os.path.isdir(args.out_dir):
        os.makedirs(args.out_dir)

    tf.logging.info("Reading variables and averaging checkpoints:")
    checkpoints = [os.path.join(args.model_dir, 'model.ckpt-{}'.format(step)) for step in index_list]
    for c in checkpoints:
        tf.logging.info("%s ", c)

    var_values, var_dtypes = average_checkpoints(checkpoints)
    save_averaged_checkpoint(
        var_values, var_dtypes,
        os.path.join(args.out_dir, 'ensemble_%d' % model_num))

    tf.logging.info("Averaged checkpoints saved in %s", args.out_dir)


if __name__ == '__main__':
    main()
