Commit e99d31c0 by xuchen

Add a translation path so that the top-hypothesis translation results (source, target, hypothesis) are saved to a separate file

parent b9f02c3b
......@@ -43,8 +43,12 @@ def main(cfg: DictConfig):
cfg.common_eval.results_path,
"generate-{}.txt".format(cfg.dataset.gen_subset),
)
translation_path = os.path.join(
cfg.common_eval.results_path,
"translation-{}.txt".format(cfg.dataset.gen_subset)
)
with open(output_path, "w", buffering=1, encoding="utf-8") as h:
return _main(cfg, h)
return _main(cfg, h, translation_path)
else:
return _main(cfg, sys.stdout)
......@@ -56,7 +60,7 @@ def get_symbols_to_strip_from_output(generator):
return {generator.eos}
def _main(cfg: DictConfig, output_file):
def _main(cfg: DictConfig, output_file, translation_path=None):
logging.basicConfig(
format="%(asctime)s | %(levelname)s | %(name)s | %(message)s",
datefmt="%Y-%m-%d %H:%M:%S",
......@@ -179,6 +183,7 @@ def _main(cfg: DictConfig, output_file):
return x
scorer = scoring.build_scorer(cfg.scoring, tgt_dict)
translation_list = []
num_sentences = 0
has_target = True
......@@ -233,7 +238,7 @@ def _main(cfg: DictConfig, output_file):
sample_id
)
else:
if src_dict is not None and src_tokens in [torch.int32, torch.int64]:
if src_dict is not None and src_tokens.dtype in [torch.int32, torch.int64]:
src_str = src_dict.string(src_tokens, cfg.common_eval.post_process)
else:
src_str = ""
......@@ -347,6 +352,7 @@ def _main(cfg: DictConfig, output_file):
# Score only the top hypothesis
if has_target and j == 0:
translation_list.append([src_str, target_str, detok_hypo_str])
if align_dict is not None or cfg.common_eval.post_process is not None:
# Convert back to tokens for evaluation with unk replacement and/or without BPE
target_tokens = tgt_dict.encode_line(
......@@ -393,6 +399,11 @@ def _main(cfg: DictConfig, output_file):
),
file=output_file,
)
if translation_path is not None:
with open(translation_path, "w", encoding="utf-8") as f:
for item in translation_list:
f.write("{}\n".format("\t".join(item)))
return scorer
......
Markdown 格式
0%
您添加了 0 人到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 登录 后发表评论