import unicodedata
import re
import jiwer
import jiwer.transforms as tr
import sys

# Command line: <script> REF_FILE HYP_FILE -- one transcript per line, with
# line i of REF_FILE aligned to line i of HYP_FILE. Fail with a usage message
# instead of an IndexError traceback when an argument is missing.
if len(sys.argv) < 3:
    sys.exit(f"usage: {sys.argv[0]} REF_FILE HYP_FILE")
ref_file = sys.argv[1]
hyp_file = sys.argv[2]


# Placeholder that every "<<unk>>" token is rewritten to, so an unknown word
# counts as a single word (WER) / single character (CER) on both sides.
# It must not be a punctuation character (e.g. "@" is Unicode category Po):
# RemovePunctuation would delete it and <<unk>> tokens would silently vanish.
# "\u00a4" ("¤") is category Sc and is not in string.punctuation either.
_UNK_PLACEHOLDER = "\u00a4"

# Word-level normalization for WER. The order of the steps matters:
#   * "<<unk>>" is replaced first, before RemoveKaldiNonWords can strip its
#     "<...>" markup;
#   * contractions are expanded before punctuation removal, because the
#     ExpandCommonEnglishContractions patterns all contain an apostrophe;
#   * Kaldi non-words ("<noise>", "[laughter]") are removed before punctuation
#     removal, otherwise "[" / "]" are stripped first and the word survives.
wer_standardize = tr.Compose(
    [
        tr.ToLowerCase(),
        tr.SubstituteRegexes({r"<<unk>>": _UNK_PLACEHOLDER}),
        tr.ExpandCommonEnglishContractions(),
        tr.RemoveKaldiNonWords(),
        tr.RemovePunctuation(),
        tr.RemoveWhiteSpace(replace_by_space=True),
        tr.ReduceToListOfListOfWords(),
    ]
)

# Character-level normalization for CER. process_text() puts a space around
# every CJK ideograph so that WER scores ideographs as words. Those artificial
# spaces would otherwise be counted as characters by ReduceToListOfListOfChars,
# so whitespace touching an ideograph is removed again here. Spaces between
# non-CJK words are kept, and runs left behind by punctuation removal are
# collapsed to one.
cer_standardize = tr.Compose(
    [
        tr.ToLowerCase(),
        tr.SubstituteRegexes({r"<<unk>>": _UNK_PLACEHOLDER}),
        tr.RemovePunctuation(),
        tr.SubstituteRegexes({r"\s*([\u4e00-\u9fff])\s*": r"\1"}),
        tr.RemoveMultipleSpaces(),
        tr.Strip(),
        tr.ReduceToListOfListOfChars(),
    ]
)

def process_text(text):
    """Space-separate CJK ideographs so each one is scored as its own word.

    Every character in the CJK Unified Ideographs block (U+4E00-U+9FFF) is
    padded with spaces on both sides. Runs of whitespace (including the
    trailing newline of a line read from a file) are then collapsed to a
    single space, and the result is stripped. Latin letters, digits and
    punctuation are otherwise left as they are.

    Args:
        text: one transcript line.

    Returns:
        The normalized line, e.g. ``"hello,中国"`` -> ``"hello, 中 国"``.
    """
    # Pad *both* sides of every ideograph, not just the side next to an
    # alphanumeric. Otherwise an ideograph preceded by punctuation (e.g. ",")
    # stays glued to it. Once RemovePunctuation deletes the "," downstream,
    # the ideograph merges into the preceding word.
    text = re.sub(r"([\u4e00-\u9fff])", r" \1 ", text)

    # Collapse the doubled spaces introduced above, plus any original runs of
    # whitespace, into single spaces.
    return re.sub(r"\s+", " ", text).strip()

def _read_processed(path):
    """Read *path* and return its lines normalized by process_text().

    The file is decoded explicitly instead of with the locale's default codec,
    which may not be UTF-8 (e.g. on Windows) and would garble Chinese text.
    "utf-8-sig" also drops a leading byte-order mark if present, so it cannot
    end up glued to the first word of the first line. The ``with`` block
    closes the file handle.
    """
    with open(path, "r", encoding="utf-8-sig") as f:
        return [process_text(line) for line in f]


ref_lines = _read_processed(ref_file)
hyp_lines = _read_processed(hyp_file)

# jiwer aligns references and hypotheses by position; report a mismatch in
# terms of the input files rather than as an opaque jiwer error.
if len(ref_lines) != len(hyp_lines):
    sys.exit(
        f"line count mismatch: {ref_file} has {len(ref_lines)} lines, "
        f"{hyp_file} has {len(hyp_lines)} lines"
    )

# Show a sample of the normalized hypotheses for a quick sanity check.
print(hyp_lines[:10])

wer = jiwer.wer(ref_lines, hyp_lines,
                truth_transform=wer_standardize,
                hypothesis_transform=wer_standardize,
                )
cer = jiwer.cer(ref_lines, hyp_lines,
                truth_transform=cer_standardize,
                hypothesis_transform=cer_standardize,
                )

print(f"WER: {wer:.4f}")
print(f"CER: {cer:.4f}")
