# -*- coding: utf-8 -*-
import codecs, re, sys
import argparse

# Command-line interface. Every option is mandatory; the parsed values are
# copied into module-level names that the driver code at the bottom of the
# file reads.
parser = argparse.ArgumentParser(
    description='Convert translation results to the XML format that the official eval tools require.')
parser.add_argument('--src_testfile_path', required=True, help="the path of source test file.")
parser.add_argument('--refs_testfile_path', required=True,
                    help="the path(s) of the reference file(s); separate multiple files with commas or whitespace.")
parser.add_argument('--tst_testfile_path', required=True, help="the path of translation of source file.")
parser.add_argument('--output_path', required=True, help="the path of output.")
parser.add_argument('--srclang', required=True)
parser.add_argument('--tgtlang', required=True)
args = parser.parse_args()

src_testfile_path = args.src_testfile_path
refs_testfile_path = args.refs_testfile_path
tst_testfile_path = args.tst_testfile_path
output_path = args.output_path
srclang = args.srclang
tgtlang = args.tgtlang

def get_system_info(organization, system_identify, system_description_info):
    """Build the lines of a ``<system>`` element.

    :param organization: value placed in the ``site`` attribute.
    :param system_identify: value placed in the ``sysid`` attribute.
    :param system_description_info: iterable of description strings; each one
        is stripped and emitted as its own line.
    :return: list of strings: the opening tag, the description lines, and the
        closing ``</system>`` tag.
    """
    opening_tag = "".join(["<system site=\"", organization, "\" sysid=\"", system_identify, "\">"])
    description_lines = [info.strip() for info in system_description_info]
    return [opening_tag] + description_lines + ["</system>"]

def get_firstLine_info():
    """Return the XML declaration line wrapped in a one-element list."""
    xml_declaration = '<?xml version="1.0" encoding="UTF-8"?>'
    return [xml_declaration]

def get_secondLine_info(setclass, setid, srclang, tgtlang):
    """Return the opening tag of the document set wrapped in a one-element list.

    :param setclass: element name of the set (the callers in this file pass
        ``srcset``, ``refset`` and ``tstset``).
    :param setid: value of the ``setid`` attribute.
    :param srclang: value of the ``srclang`` attribute.
    :param tgtlang: value of the ``trglang`` attribute.
    :return: ``['<setclass setid="..." srclang="..." trglang="...">']``
    """
    pieces = [
        "<", setclass,
        " setid=\"", setid, "\"",
        " srclang=\"", srclang, "\"",
        " trglang=\"", tgtlang, "\">",
    ]
    return ["".join(pieces)]

def get_tstTail_info():
    """Return the closing tag of a ``tstset`` document wrapped in a one-element list."""
    closing_tag = '</tstset>'
    return [closing_tag]

def get_refTail_info():
    """Return the closing tag of a ``refset`` document wrapped in a one-element list."""
    closing_tag = '</refset>'
    return [closing_tag]

def XMLformat(line):
    """Escape the five XML predefined special characters in *line*.

    ``&``, ``<``, ``>``, ``"`` and ``'`` are replaced by ``&amp;``, ``&lt;``,
    ``&gt;``, ``&quot;`` and ``&apos;`` respectively; every other character is
    copied through unchanged.

    :param line: text to escape.
    :return: the escaped text.
    """
    # Each input character is looked up exactly once, so the "&" that starts a
    # replacement entity is never escaped a second time.
    entities = {
        "&": "&amp;",
        "<": "&lt;",
        ">": "&gt;",
        "\"": "&quot;",
        "'": "&apos;",
    }
    return "".join([entities.get(char, char) for char in line])

def handle_src_file(src_file, srclang, tgtlang, out_path):
    """Convert a plain-text source file into ``srcset`` XML and write it out.

    Each line of *src_file* becomes one ``<seg id="N">`` element (ids start at
    1) inside a single ``<DOC docid="news">`` element. The result is written
    to ``out_path + "src.txt.xml"``, one XML line per output line.

    :param src_file: path of the source text file (UTF-8; the ``utf_8_sig``
        codec skips a leading BOM if present).
    :param srclang: source language code, used in the set id and attributes.
    :param tgtlang: target language code, used in the set id and attributes.
    :param out_path: prefix that is concatenated with ``"src.txt.xml"`` to
        form the output path.
    :return: list of all XML lines that were written, header and tail included.
    """
    # Read everything up front so the input handle is released even if a later
    # step raises.
    with codecs.open(src_file, "r", "utf_8_sig") as src_file_handle:
        raw_lines = src_file_handle.read().strip().split("\n")

    # Header: XML declaration, <srcset ...> opening tag, then the document tag.
    src_xml_content = [
        get_firstLine_info()[0],
        get_secondLine_info(setclass="srcset", setid=srclang + "_" + tgtlang + "_news_trans",
                            srclang=srclang, tgtlang=tgtlang)[0],
        "<DOC docid=\"news\">",
    ]

    # Body: one escaped segment per source line.
    for seg_id, src_line in enumerate(raw_lines, 1):
        src_xml_content.append("<seg id=\"" + str(seg_id) + "\">" + XMLformat(src_line.strip()) + "</seg>")

    # Tail.
    src_xml_content.append("</DOC>")
    src_xml_content.append("</srcset>")

    # Output: the context manager closes the file even if a write fails.
    with codecs.open(out_path + "src.txt.xml", "w", "utf_8_sig") as src_file_out:
        for xml_line in src_xml_content:
            src_file_out.write(xml_line + "\n")
    return src_xml_content

def get_src_content(src_file, src_lang, tgt_lang, out_path, occasion):
    """Return the inner lines of the source XML.

    :param src_file: path of the source file. It is plain text when
        *occasion* is ``True``; otherwise its lines are used as they are.
    :param src_lang: source language code (only used when converting).
    :param tgt_lang: target language code (only used when converting).
    :param out_path: output prefix handed to ``handle_src_file`` (only used
        when converting).
    :param occasion: when equal to ``True`` the file is converted with
        ``handle_src_file`` (which also writes the source XML file); otherwise
        the file is read line by line without conversion.
    :return: the lines with the first two and the last one removed. For
        converted input these are the XML declaration, the set opening tag
        and the set closing tag.
    """
    # The equality test (rather than plain truthiness) is kept on purpose so
    # existing callers see exactly the same branch selection as before.
    if occasion == True:  # noqa: E712
        src_xml_content = handle_src_file(src_file, src_lang, tgt_lang, out_path)
    else:
        # The context manager closes the handle even if reading fails.
        with codecs.open(src_file, "r", "utf_8_sig") as src_file_handle:
            src_xml_content = src_file_handle.read().strip().split("\n")

    return src_xml_content[2:-1]

def _read_raw_lines(path):
    """Return the stripped content of the text file at *path*, split on newlines."""
    with codecs.open(path, "r", "utf_8_sig") as handle:
        return handle.read().strip().split("\n")


def _project_onto_src_layout(src_xml_content, raw_lines, doc_attribute):
    """Lay *raw_lines* out in the same XML structure as *src_xml_content*.

    :param src_xml_content: inner lines of the source XML (DOC, p, seg lines).
    :param raw_lines: plain-text lines, consumed one per matched ``<seg>``.
    :param doc_attribute: attribute text (e.g. ``site="transorg1"``) inserted
        before the closing ``>`` of every ``<DOC ...>`` line.
    :return: list of XML lines mirroring the source structure.
    """
    seg_open_tag = re.compile(r"<seg id=\"(.*?)\">")  # e.g. <seg id="402">
    projected = []
    raw_line_no = 0
    raw_line = ""
    for src_line in src_xml_content:
        # If the raw file has fewer lines than the source has segments, the
        # last raw line is reused; surplus raw lines are ignored.
        if raw_line_no < len(raw_lines):
            raw_line = raw_lines[raw_line_no]

        if "<DOC" in src_line:
            # <DOC docid="news"> -> <DOC docid="news" site="transorg1">
            projected.append(src_line.strip()[:-1] + " " + doc_attribute + ">")
        elif "<p>" in src_line or "</p>" in src_line:
            # Paragraph markers sit on their own line and are copied as is.
            projected.append(src_line)
        elif "</DOC" in src_line:
            # The closing document tag sits on its own line and is copied as is.
            projected.append(src_line)
        elif "<seg" in src_line:
            seg_match = seg_open_tag.match(src_line)
            if seg_match:
                # Keep the source segment's opening tag (and therefore its id).
                projected.append(seg_match.group(0) + XMLformat(raw_line.strip()) + "</seg>")
                raw_line_no += 1
    return projected


def _write_joined(path, lines):
    """Write *lines* to *path* separated by newlines, with no trailing newline."""
    with codecs.open(path, "w", "utf_8_sig") as handle:
        handle.write("\n".join(lines))


def main(src_file, refs_file, tst_file, out_path, src_lang="zh", tgt_lang="en", occasion=True):
    """Write the reference and translation XML files that mirror the source XML.

    :param src_file: path of the source file (see ``get_src_content``).
    :param refs_file: list of reference file paths; the i-th file (1-based)
        is written with ``site="transorg<i>"``.
    :param tst_file: path of the translation output; only one translation
        result is supported.
    :param out_path: prefix concatenated with ``"ref.txt.xml"`` and
        ``"tst.txt.xml"`` to form the output paths.
    :param src_lang: source language code.
    :param tgt_lang: target language code.
    :param occasion: forwarded to ``get_src_content``.
    :return: None
    """
    # Convert src_file to the XML layout, or read an already converted file.
    src_xml_content = get_src_content(src_file, src_lang, tgt_lang, out_path, occasion)

    set_id = src_lang + "_" + tgt_lang + "_news"
    system_id = src_lang + "_" + tgt_lang + "_trans"

    # References: one <DOC> per reference file, following the source layout.
    refs_xml_content = []
    for transorg, ref_file in enumerate(refs_file, 1):
        ref_site = "site=\"transorg" + str(transorg) + "\""
        refs_xml_content.extend(
            _project_onto_src_layout(src_xml_content, _read_raw_lines(ref_file), ref_site))

    refs_out_contents = (get_firstLine_info()
                         + get_secondLine_info("refset", set_id, src_lang, tgt_lang)
                         + refs_xml_content
                         + get_refTail_info())

    # Translation: <DOC docid="..." sysid="..."> following the source layout.
    transed_xml_content = _project_onto_src_layout(
        src_xml_content, _read_raw_lines(tst_file), "sysid=\"" + system_id + "\"")

    # The <system site="..." sysid="..."> block precedes the documents.
    transed_out_contents = (get_firstLine_info()
                            + get_secondLine_info("tstset", set_id, src_lang, tgt_lang)
                            + get_system_info("Niu", system_id, ["系统描述信息"])
                            + transed_xml_content
                            + get_tstTail_info())

    # Save.
    _write_joined(out_path + "ref.txt.xml", refs_out_contents)
    _write_joined(out_path + "tst.txt.xml", transed_out_contents)

# main("./eval/ensemble/mt12-wb/input.token", ["./eval/ensemble/mt12-wb/ref0","./eval/ensemble/mt12-wb/ref1", "./eval/ensemble/mt12-wb/ref2", "./eval/ensemble/mt12-wb/ref3"], "./eval/ensemble\mt12-wb/mt12-wb.ensemble", "./eval/ensemble\mt12-wb/")
print ("refs: " + refs_testfile_path)

# Reference paths are split on whitespace first; if that yields a single
# entry that contains commas, the commas are taken as the separator instead.
refs_list = refs_testfile_path.strip().split()
comma_separated = refs_testfile_path.strip().split(",")
if len(refs_list) == 1 and len(comma_separated) > 1:
    refs_list = comma_separated

# Mixed separators ("a, b") would otherwise leave a stray comma attached to a
# path, and a trailing comma would produce an empty path; clean both up.
refs_list = [ref_path.strip().strip(",").strip() for ref_path in refs_list]
refs_list = [ref_path for ref_path in refs_list if ref_path]

main(src_testfile_path, refs_list, tst_testfile_path, output_path, srclang, tgtlang)
print("xml format file has created.")
