import argparse
import os
import random
import sys
import six
import tensorflow as tf


# Command-line interface. Every option is required. parse_args() runs at
# module import time, so importing this file reads sys.argv.
parser = argparse.ArgumentParser()
parser.add_argument('--src_trainfile_path',required=True,help='the path of source training file')
parser.add_argument('--tgt_trainfile_path',required=True,help='the path of target training file')
parser.add_argument('--src_devfile_path',required=True,help='the path of source dev file')
# One or more target-side dev files (e.g. several references for the same source).
parser.add_argument('--tgt_devfile_paths', nargs='+',required=True,help="the paths of target dev files")
# NOTE(review): despite the plural in the help text this is a single path; it
# is used as the output directory that every generated file is joined onto.
parser.add_argument('--final_traindata_path',required=True,help="the paths of final training data")
parser.add_argument('--src_dic_path',required=True,help='the path of source dic file')
parser.add_argument('--tgt_dic_path',required=True,help='the path of target dic file')
args = parser.parse_args()

# Copy the parsed options into module-level globals read by the functions below.
src_trainfile_path=args.src_trainfile_path
tgt_trainfile_path=args.tgt_trainfile_path
src_devfile_path=args.src_devfile_path
tgt_devfile_paths=args.tgt_devfile_paths
src_dic_path=args.src_dic_path
tgt_dic_path=args.tgt_dic_path
final_traindata_path=args.final_traindata_path

# Derived paths and lookup tables; the empty placeholders are filled in at run time.
srcdicfilepath=''#source dictionary path under the output directory; set by create_dir_and_filenames()
tgtdicfilepath=''#target dictionary path under the output directory; set by create_dir_and_filenames()
srcdic={}#source token -> id mapping; the driver rebinds it to read_dic() output
tgtdic={}#target token -> id mapping; the driver rebinds it to read_dic() output
trainsrcnumfilepath=''#id ("number") file generated from the source training file
traintgtnumfilepath=''#id ("number") file generated from the target training file
devsrcnumfilepath=''#id ("number") file generated from the source dev file
devtgtnumfilespath=[]#id ("number") files generated from the target dev files, one per --tgt_devfile_paths entry

# Marker placed in intermediate TFRecord file names; shuffle_dataset() strips it
# to build the name of the shuffled output file.
UNSHUFFLED_SUFFIX = "-unshuffled"

def validate_flags():
  """Check that every input file named on the command line exists.

  Reads the module-level path globals filled from argparse: the training
  files, the source dev file, every target dev file, and the two dictionary
  files that read_dic() loads.

  Raises:
    ValueError: naming the first input file that does not exist.
  """
  required_files = [
      ("source training file", src_trainfile_path),
      ("target training file", tgt_trainfile_path),
      ("source dev file", src_devfile_path),
  ]
  required_files.extend(("target dev file", path) for path in tgt_devfile_paths)
  # The dictionaries are also read by the script; check them up front so a
  # typo fails with the same kind of error as the other inputs.
  required_files.append(("source dic file", src_dic_path))
  required_files.append(("target dic file", tgt_dic_path))
  for label, path in required_files:
    if not os.path.exists(path):
      raise ValueError("%s:%s does not exist!" % (label, path))

def create_dir_and_filenames():
  """Create the output directory and compute the derived file paths.

  Sets the module-level dictionary and number-file path globals to files
  under ``final_traindata_path``. ``devtgtnumfilespath`` is rebuilt with one
  entry per target dev file. Calling this function more than once does not
  accumulate duplicate entries.
  """
  global srcdicfilepath
  global tgtdicfilepath
  global trainsrcnumfilepath
  global traintgtnumfilepath
  global devsrcnumfilepath
  # makedirs also creates missing parent directories; exist_ok makes the call
  # a no-op when the directory is already there.
  os.makedirs(final_traindata_path, exist_ok=True)
  srcdicfilepath = os.path.join(final_traindata_path, 'tokens.vocab.zh.32768')
  tgtdicfilepath = os.path.join(final_traindata_path, 'tokens.vocab.en.32768')
  trainsrcnumfilepath = os.path.join(final_traindata_path, 'trainsrc.num')
  traintgtnumfilepath = os.path.join(final_traindata_path, 'traintgt.num')
  devsrcnumfilepath = os.path.join(final_traindata_path, 'devsrc.num')
  # Slice assignment keeps the same list object (it is indexed elsewhere by
  # its global name) while discarding entries from any earlier call.
  devtgtnumfilespath[:] = [
      os.path.join(final_traindata_path, 'devtgt.num' + str(i))
      for i in range(len(tgt_devfile_paths))]


def read_dic(inputfile):
  """Load a vocabulary file in the format written by create_dic().

  The first two lines of the file are skipped. Every following line holds one
  token with one surrounding quote character on each side (e.g. ``'word'``).
  The token on 0-based file line k gets id k, so ids start at 2 and match
  the token's line position.

  Args:
    inputfile: path to the UTF-8 vocabulary file.
  Returns:
    dict mapping token -> integer id, plus 'PAD' -> 0 and 'EOS' -> 1.
  """
  # NOTE(review): 'PAD'/'EOS' are ordinary string keys, so a corpus token
  # spelled exactly 'PAD' or 'EOS' that is absent from the file maps to 0/1.
  vocab = {'PAD': 0, 'EOS': 1}
  with open(inputfile, 'r', encoding="utf-8") as fr:
    for index, line in enumerate(fr):
      if index < 2:
        continue
      # Drop surrounding whitespace, then the quote on either side.
      vocab[line.strip()[1:-1]] = index
  return vocab

def create_dic(trainfilepath,outputfile,dicsize):
  """Build a vocabulary file from a whitespace-tokenized training file.

  The output starts with three header lines: two ``''`` placeholders and
  ``'<UNK>'``. The most frequent tokens follow, one per line, each wrapped in
  single quotes. The file has at most max(dicsize, 3) lines in total. Tokens
  are ranked by descending frequency, with ties broken by descending token
  string. Blank lines in the training file are skipped rather than ending the
  scan.

  Args:
    trainfilepath: path of the UTF-8 training file, one sentence per line.
    outputfile: path of the UTF-8 vocabulary file to write.
    dicsize: total number of lines in the vocabulary, headers included.
  Return:
    dict mapping each written token to its 0-based line number in the output
    (the first token is on line 3).
  """
  from collections import Counter

  counts = Counter()
  with open(trainfilepath, 'r', encoding='utf-8') as file:
    for alreadyreadline, line in enumerate(file, start=1):
      if alreadyreadline % 50000 == 0:
        print('reading traindata..line=', alreadyreadline)
      counts.update(line.split())

  finallist = sorted(counts.items(), key=lambda item: (item[1], item[0]), reverse=True)
  # The three header lines come first, so only dicsize - 3 tokens fit.
  kept = finallist[:max(dicsize - 3, 0)]

  listdic = {}
  with open(outputfile, 'w', encoding='utf-8') as fileoutput:
    fileoutput.write("''\n")
    fileoutput.write("''\n")
    fileoutput.write("'<UNK>'\n")
    for diclinenu, (word, _) in enumerate(kept, start=3):
      fileoutput.write("'" + word + "'\n")
      listdic[word] = diclinenu
  return listdic

def create_numfile(workname,dic,inputfile,outputfile):
  """Convert a whitespace-tokenized text file into a file of token ids.

  Each input line becomes one output line: the id of every token, each
  followed by a space, then the id 1. Token ids are looked up in ``dic``;
  unknown tokens get 2. A blank line in the middle of the input becomes a
  line containing only ``1``, so output line i still corresponds to input
  line i. Blank lines at the end of the file are dropped.

  Args:
    workname: label used only in progress messages.
    dic: mapping token -> integer id.
    inputfile: path of the UTF-8 text file to convert.
    outputfile: path of the number file to write.
  """
  nu = 0
  pending_blank = 0
  with open(inputfile, 'r', encoding='utf-8') as fin, open(outputfile, 'w') as fout:
    for raw_line in fin:
      nu += 1
      if nu % 10000 == 0:
        print('creating numlist of', workname, 'line=', nu)
      words = raw_line.split()
      if not words:
        # Defer blank lines: they are emitted only if a non-blank line
        # follows, so trailing blank lines at EOF are still ignored.
        pending_blank += 1
        continue
      fout.write('1\n' * pending_blank)
      pending_blank = 0
      fout.write(''.join(str(dic.get(word, 2)) + ' ' for word in words))
      fout.write('1\n')

def get_outputfile_names(maxshard=1,train=True):
  """Build the sharded TFRecord output paths for the train or dev split.

  Args:
    maxshard:
      how many shard file names to produce.
    train:
      True for training-file names, False for dev-file names.
  Return:
    list of ``<final_traindata_path>/<prefix><index>-<maxshard>`` paths, with
    both numbers zero-padded to five digits.
  """
  if train:
    stem = 'wmt_zhen_tokens_32k-unshuffled-train-'
  else:
    stem = 'wmt_zhen_tokens_32k-unshuffled-dev-'
  prefix = os.path.join(final_traindata_path, stem)
  return [prefix + '%.5d-%.5d' % (shard, maxshard) for shard in range(maxshard)]

def to_example(dictionary):
  """Build a tf.train.Example from a (string -> int/float/str/bytes list) dict.

  The feature type of each key is chosen from the type of the first element
  of its value list. str values are UTF-8 encoded into a BytesList. Python 3
  builtins replace the former six shims; the file already requires Python 3,
  since it calls open() with an encoding argument.

  Raises:
    ValueError: if a value list is empty or its element type is unsupported.
  """
  features = {}
  for k, v in dictionary.items():
    if not v:
      # Format the message here; passing the argument separately would make
      # the exception message a tuple instead of a readable string.
      raise ValueError("Empty generated field: %s" % str((k, v)))
    first = v[0]
    if isinstance(first, int):
      features[k] = tf.train.Feature(int64_list=tf.train.Int64List(value=v))
    elif isinstance(first, float):
      features[k] = tf.train.Feature(float_list=tf.train.FloatList(value=v))
    elif isinstance(first, str):
      encoded = [bytes(x, "utf-8") for x in v]
      features[k] = tf.train.Feature(bytes_list=tf.train.BytesList(value=encoded))
    elif isinstance(first, bytes):
      features[k] = tf.train.Feature(bytes_list=tf.train.BytesList(value=v))
    else:
      raise ValueError("Value for %s is not a recognized type; v: %s type: %s" %
                       (k, str(first), str(type(first))))
  return tf.train.Example(features=tf.train.Features(feature=features))

def generate_trainfile(generator, output_filenames, max_cases=None):
  """Generate cases from a generator and save them as sharded TFRecord files.

  Each case is converted with to_example() and written to the output files
  round-robin, one case per file in turn. All writers are closed even if
  conversion or writing fails.

  Args:
    generator: a generator for training yielding (string -> int/float/str list) dictionaries.
    output_filenames: List of output training file paths, one per shard.
    max_cases: maximum number of cases to write; if None or 0 (default), the
      generator is consumed until it is exhausted.
  Raises:
    ValueError: if output_filenames is empty.
  """
  if not output_filenames:
    raise ValueError("output_filenames must not be empty")
  num_shards = len(output_filenames)
  writers = [tf.python_io.TFRecordWriter(fname) for fname in output_filenames]
  counter, shard = 0, 0
  try:
    for case in generator:
      # Check before writing so `counter` always equals the cases written.
      if max_cases and counter >= max_cases:
        break
      if counter > 0 and counter % 100000 == 0:
        print('Generating case=',counter)
      writers[shard].write(to_example(case).SerializeToString())
      counter += 1
      shard = (shard + 1) % num_shards
  finally:
    for writer in writers:
      writer.close()
  print('all training case generate over!,lines=',counter)
def generate_devfile(generators, output_filenames, max_cases=None):
  """Write the cases of several generators into one set of sharded TFRecords.

  Every generator writes into the same output files. Cases go to the shards
  round-robin, and the shard cursor keeps advancing across generators so the
  shards stay balanced. All writers are closed even if a write fails.

  Args:
    generators: iterable of generators yielding (string -> int/float/str list) dictionaries.
    output_filenames: List of output dev file paths, one per shard.
    max_cases: maximum number of cases taken from EACH generator; if None or
      0 (default), each generator is consumed until it is exhausted.
  Raises:
    ValueError: if output_filenames is empty.
  """
  if not output_filenames:
    raise ValueError("output_filenames must not be empty")
  num_shards = len(output_filenames)
  writers = [tf.python_io.TFRecordWriter(fname) for fname in output_filenames]
  shard = 0
  try:
    for generator in generators:
      counter = 0  # per-generator count, used for progress and max_cases
      for case in generator:
        if max_cases and counter >= max_cases:
          break
        if counter > 0 and counter % 100000 == 0:
          print('Generating case=',counter)
        writers[shard].write(to_example(case).SerializeToString())
        counter += 1
        shard = (shard + 1) % num_shards
      print('dev case generate over!,lines=',counter)
  finally:
    for writer in writers:
      writer.close()

def creategenerator(src,tgt):
  """Return a lazy iterator of {"inputs", "targets"} id-list pairs.

  Alias for create_case(src, tgt); no file is read until it is iterated.
  """
  pairs = create_case(src, tgt)
  return pairs

def create_case(src,tgt):
  """Yield one {"inputs": [...], "targets": [...]} dict per source line.

  Line i of the source number file is paired with line i of the target
  number file. Each line is parsed as whitespace-separated integers.
  Iteration stops when the source file is exhausted. If the target file is
  shorter, the missing target lines yield an empty "targets" list.

  Args:
    src: path to the source number file.
    tgt: path to the target number file.
  Yields:
    A dictionary {"inputs": source-ids, "targets": target-ids} of int lists.
  """
  # `with` closes both files when the generator finishes or is closed,
  # instead of leaking the handles.
  with open(src, 'r') as src_file, open(tgt, 'r') as tgt_file:
    for src_line in src_file:
      tgt_line = tgt_file.readline()
      yield {"inputs": [int(tok) for tok in src_line.split()],
             "targets": [int(tok) for tok in tgt_line.split()]}
def read_records(filename):
  """Load every serialized record of a TFRecord file into a list.

  Logs progress through tf.logging every 100000 records.
  """
  records = []
  for seen, record in enumerate(tf.python_io.tf_record_iterator(filename), start=1):
    records.append(record)
    if seen % 100000 == 0:
      tf.logging.info("read: %d", seen)
  return records

def write_records(records, out_filename):
  """Write already-serialized records to a TFRecord file.

  Logs progress through tf.logging every 100000 records. The writer is
  closed even if a write fails part-way through.
  """
  writer = tf.python_io.TFRecordWriter(out_filename)
  try:
    for count, record in enumerate(records):
      writer.write(record)
      if count > 0 and count % 100000 == 0:
        tf.logging.info("write: %d", count)
  finally:
    writer.close()

def shuffle_dataset(filenames):
  """Shuffle each TFRecord file and write it without the unshuffled marker.

  For every path, all records are loaded, shuffled in memory, and written to
  a file whose base name has UNSHUFFLED_SUFFIX removed; the original file is
  then deleted. Only the base name is rewritten, so a directory whose name
  contains the marker is left alone. If the base name lacks the marker, the
  file is shuffled in place and not deleted.
  """
  print("Shuffling data...")
  for fname in filenames:
    dirname, basename = os.path.split(fname)
    shuffled_basename = basename.replace(UNSHUFFLED_SUFFIX, "")
    in_place = shuffled_basename == basename
    out_fname = fname if in_place else os.path.join(dirname, shuffled_basename)
    records = read_records(fname)
    random.shuffle(records)
    write_records(records, out_fname)
    # Deleting when writing in place would remove the freshly written output.
    if not in_place:
      tf.gfile.Remove(fname)
  print('Shuffling data finished!')


# Validate the inputs before touching the output directory.
validate_flags()
create_dir_and_filenames()
srcdic=read_dic(src_dic_path)
tgtdic=read_dic(tgt_dic_path)
create_numfile('sourcedev',srcdic,src_devfile_path,devsrcnumfilepath)
for i, (tgt_devfile, tgt_numfile) in enumerate(zip(tgt_devfile_paths, devtgtnumfilespath)):
  create_numfile('targetdev'+str(i),tgtdic,tgt_devfile,tgt_numfile)
print('generating dev data..')
# One generator per target dev file; each pairs it with the shared source dev file.
devset=[]
for i, tgt_numfile in enumerate(devtgtnumfilespath):
  print('generating dev data'+str(i))
  devset.append(creategenerator(devsrcnumfilepath,tgt_numfile))
dev_output_files=get_outputfile_names(1,False)
generate_devfile(devset,dev_output_files)
shuffle_dataset(dev_output_files)
