import argparse
import os
# Command-line interface. Arguments are parsed at import time; the helper
# functions in this file read `root_dir`, `tag` and `split_str` as module
# globals rather than taking them as parameters.
# NOTE(review): because parsing happens on import, importing this module from
# other code requires --root_dir and --tag on sys.argv.
parser = argparse.ArgumentParser()
parser.add_argument('--root_dir', required = True, help = 'a folder contains all feature informations used in reranking')
parser.add_argument('--tag', required = True, help = 'dataset name')
args = parser.parse_args()

root_dir = args.root_dir
tag = args.tag

# Module-level placeholders; they are only reassigned in the __main__ block.
feature_name = []
weight = []
weight_dic = {} # {feature1: weight1, feature2: weight2, ...}
# Marker line in the input file that closes each group of candidate lines;
# get_best_score_ids emits one chosen line index per marker.
split_str = '=================================='

def get_feature_weight_dic(feature_names, weights):
  """Pair every feature name with its weight.

  Args:
    feature_names: a list of feature names, e.g. ['LS-LT', 'LT-LS', ...]
    weights: a list of weights aligned with feature_names, e.g. [0.1, 0.2, ...]

  Returns:
    A dict mapping each feature name to its weight,
      e.g. {'LS-LT': 0.1, 'LT-LS': 0.2, ...}
    Pairing stops at the shorter of the two lists; if a name repeats, the
    weight of its last occurrence wins (the key keeps its first position).
  """
  return dict(zip(feature_names, weights))

def get_sten_list(baseline_path='LS-LT'):
  """Read the input lines and hypothesis lines from one feature folder.

  Reads ``<root_dir>/<baseline_path>/<tag>.input.txt`` and
  ``<root_dir>/<baseline_path>/<tag>.output.txt`` line by line in lockstep,
  using the module-level ``root_dir`` and ``tag``. Reading stops at EOF or at
  the first blank line of the input file. If the output file runs out first,
  the missing hypotheses are recorded as ''.

  Args:
    baseline_path: name of the feature folder under root_dir that holds the
      input/output files.

  Returns:
    (s_list, t_list, sten_count): the stripped input lines, the matching
    stripped output lines, and the number of lines read.
  """
  # os.path.join inserts the separator itself, so root_dir works with or
  # without a trailing slash (plain '+' concatenation did not).
  feature_dir = os.path.join(root_dir, baseline_path)
  source_path = os.path.join(feature_dir, tag + '.input.txt')
  target_path = os.path.join(feature_dir, tag + '.output.txt')
  s_list = []
  t_list = []
  # Nested `with` closes both files even if opening the second or reading fails.
  with open(source_path, 'r', encoding='utf-8') as source_file:
    with open(target_path, 'r', encoding='utf-8') as target_file:
      while True:
        line1 = source_file.readline().strip()
        line2 = target_file.readline().strip()
        if not line1:
          break
        s_list.append(line1)
        t_list.append(line2)
  return s_list, t_list, len(s_list)

def create_feature_dir(root_dir, config_filename = 'feature_weight.config'):
  """Read the feature/weight config file and return feature names and weights.

    Each line of the config is ``<feature_name> <weight> [...]``; only the
    first two whitespace-separated fields are used. Reading stops at EOF or at
    the first blank line. Weights are returned as strings, exactly as written.

    Args:
      root_dir: root dir used in rerank; the config file lives directly in it.
      config_filename: name of the config file inside root_dir.

    Returns:
      (feature_names, weights): two parallel lists of strings.

    Raises:
      ValueError: if a non-blank line has fewer than two fields.
  """

  feature_names = []
  weights = []
  config_path = os.path.join(root_dir, config_filename)
  with open(config_path, 'r', encoding='utf-8') as config_file:
    for lineno, raw_line in enumerate(config_file, start=1):
      line = raw_line.strip()
      if not line:
        break
      fields = line.split()
      if len(fields) < 2:
        # Report the offending line instead of a bare IndexError.
        raise ValueError('%s:%d: expected "<feature_name> <weight>", got %r'
                         % (config_path, lineno, line))
      feature_names.append(fields[0])
      weights.append(fields[1])
  return feature_names, weights

def get_score_dic(root_dir, feature_names):
  """For each feature, read the per-line scores from its score file.

    For every name, looks at ``<root_dir>/<name>/<tag>.score.txt`` (``tag`` is
    the module-level global). If the feature folder does not exist, it is
    created and the feature is skipped. If the score file is missing, the
    feature is skipped. Scores are read one per line, kept as strings, and
    reading stops at EOF or at the first blank line.

    Args:
      root_dir: root dir used in rerank
      feature_names: a list of feature names
        e.g.['LS-LT', 'LT-LS' ....]

    Returns:
      score_dic: a dic which key is feature-name and value is a list of
        score strings, e.g.{'LS-LT':['1','2','3'], 'LS-RT':['4','5','6']}.
        Skipped features are absent.
  """

  score_filename = tag + '.score.txt'
  score_dic = {}
  for name in feature_names:
    feature_dir = os.path.join(root_dir, name)
    if not os.path.exists(feature_dir):
      # Create the missing feature folder; it cannot contain scores yet.
      os.makedirs(feature_dir)
      continue
    score_path = os.path.join(feature_dir, score_filename)
    if not os.path.exists(score_path):
      continue
    scores = []
    # `with` guarantees the handle is closed (the original never closed it).
    with open(score_path, 'r', encoding='utf-8') as score_file:
      for raw_line in score_file:
        line = raw_line.strip()  # one score per line
        if not line:
          break
        scores.append(line)
    score_dic[name] = scores
  return score_dic

def compute_score(score_dic, weight_dic, sten_count):
  """Compute the weighted score of every line in the n-best output file.

    final_scores[i] = sum over features f of weight[f] * score_dic[f][i].
    A feature is skipped when it has no entry in score_dic, or when its
    weight is numerically zero ('0', '0.0', 0.0, ...).

    Args:
      score_dic: a dic of per-line scores (strings or numbers)
        e.g. {'feature-name':[score1,score2,score3], 'feature-name2':[score4, score5,score6]...}
      weight_dic: a dic of weights (strings or numbers)
        e.g. {'LS-LT':0.1, 'LT-LS':0.2, ....}
      sten_count: count of lines to score

    Returns:
      final_scores: a list of sten_count scores, one per line of the n-best
        output file. A line with no contributing feature scores 0.

    Raises:
      ValueError: if a weight of a present feature, or a used score, is not numeric.
      IndexError: if a contributing feature has fewer than sten_count scores.
  """
  final_scores = [0 for _ in range(sten_count)]
  # Resolve the contributing features once, not once per line.
  # Membership is checked first so weights of absent features are never parsed.
  active = []
  for feature, raw_weight in weight_dic.items():
    if feature not in score_dic:
      continue
    feature_weight = float(raw_weight)
    if feature_weight == 0:
      continue
    active.append((feature_weight, score_dic[feature]))
  # Keep the original line-major / feature-order summation so float
  # results are bit-identical.
  for i in range(sten_count):
    for feature_weight, feature_scores in active:
      final_scores[i] += feature_weight * float(feature_scores[i])
  return final_scores

def get_best_score_ids(input_list, scores, separator=None):
  """Return the line number of the best-scoring line in each group.

    Lines of input_list are grouped by separator lines. Within each group the
    line with the highest score wins; ties keep the earliest line. One index
    is emitted per separator. A trailing group that is not followed by a
    separator is also emitted, so the last group is never dropped.

    Args:
      input_list: a list of lines, aligned with scores
      scores: a list of scores for every line in the n-best output file
      separator: the group-separator line; defaults to the module-level
        ``split_str``

    Returns:
      final_indexs: a list of ids
        id means the line number of the best translation. An empty scores
        list yields [].
  """
  if separator is None:
    separator = split_str
  final_indexs = []
  if not scores:
    return final_indexs
  index = 0
  max_score = scores[0]
  for s, score in enumerate(scores):
    is_separator = input_list[s] == separator
    if score > max_score and not is_separator:
      max_score = score
      index = s
    if is_separator:
      final_indexs.append(index)
      s_next = s + 1
      # Restart the search at the first line of the next group.
      if s_next < len(scores):
        max_score = scores[s_next]
        index = s_next
  # Flush a final group that has no closing separator.
  if input_list[len(scores) - 1] != separator:
    final_indexs.append(index)
  return final_indexs

def get_final_translation(output_list, indexs):
  """Select the output lines whose line numbers appear in indexs.

    Args:
      output_list: the hypothesis lines of the n-best output file
      indexs: line numbers to keep (0-based); duplicates and order are ignored

    Returns:
      The selected lines, in their original file order.
  """
  # A set makes each membership test O(1) instead of scanning the index list.
  wanted = set(indexs)
  return [output for linenu, output in enumerate(output_list) if linenu in wanted]

def write_into_result_file(tag,result_list):
  """Write the reranked lines to ``<root_dir>/RESULT/<tag>.rerank.output.txt``.

  Uses the module-level ``root_dir``. The RESULT folder is created if it
  does not exist. Any existing file is overwritten, one line per entry.

  Args:
    tag: dataset name used in the output filename
    result_list: the lines to write (strings without trailing newlines)
  """
  result_dir = os.path.join(root_dir, 'RESULT')
  # Without this, a fresh root_dir made open() fail with FileNotFoundError.
  os.makedirs(result_dir, exist_ok=True)
  output_file = os.path.join(result_dir, tag + '.rerank.output.txt')
  with open(output_file, 'w', encoding='utf-8') as out:
    out.writelines(line + '\n' for line in result_list)

if __name__ == '__main__':
  # Reranking pipeline. The results are rebound to module-level names.
  # Feature names and (string) weights from <root_dir>/feature_weight.config.
  feature_name, weight = create_feature_dir(root_dir)
  weight_dic = get_feature_weight_dic(feature_name, weight)
  # Input lines and hypotheses, read from the default 'LS-LT' feature folder.
  s_list, t_list, sten_count = get_sten_list()
  # Per-feature score lists; missing feature folders are created here.
  score_dic = get_score_dic(root_dir, feature_name)
  # Weighted sum per line, then one best line index per separator-delimited group.
  scores = compute_score(score_dic, weight_dic, sten_count)
  scores_index = get_best_score_ids(s_list, scores)
  result_list = get_final_translation(t_list, scores_index)
  write_into_result_file(tag, result_list)
