Commit e99d31c0 by xuchen

Add a translation path for saving the translation results

parent b9f02c3b
...@@ -43,8 +43,12 @@ def main(cfg: DictConfig): ...@@ -43,8 +43,12 @@ def main(cfg: DictConfig):
cfg.common_eval.results_path, cfg.common_eval.results_path,
"generate-{}.txt".format(cfg.dataset.gen_subset), "generate-{}.txt".format(cfg.dataset.gen_subset),
) )
translation_path = os.path.join(
cfg.common_eval.results_path,
"translation-{}.txt".format(cfg.dataset.gen_subset)
)
with open(output_path, "w", buffering=1, encoding="utf-8") as h: with open(output_path, "w", buffering=1, encoding="utf-8") as h:
return _main(cfg, h) return _main(cfg, h, translation_path)
else: else:
return _main(cfg, sys.stdout) return _main(cfg, sys.stdout)
...@@ -56,7 +60,7 @@ def get_symbols_to_strip_from_output(generator): ...@@ -56,7 +60,7 @@ def get_symbols_to_strip_from_output(generator):
return {generator.eos} return {generator.eos}
def _main(cfg: DictConfig, output_file): def _main(cfg: DictConfig, output_file, translation_path=None):
logging.basicConfig( logging.basicConfig(
format="%(asctime)s | %(levelname)s | %(name)s | %(message)s", format="%(asctime)s | %(levelname)s | %(name)s | %(message)s",
datefmt="%Y-%m-%d %H:%M:%S", datefmt="%Y-%m-%d %H:%M:%S",
...@@ -179,6 +183,7 @@ def _main(cfg: DictConfig, output_file): ...@@ -179,6 +183,7 @@ def _main(cfg: DictConfig, output_file):
return x return x
scorer = scoring.build_scorer(cfg.scoring, tgt_dict) scorer = scoring.build_scorer(cfg.scoring, tgt_dict)
translation_list = []
num_sentences = 0 num_sentences = 0
has_target = True has_target = True
...@@ -233,7 +238,7 @@ def _main(cfg: DictConfig, output_file): ...@@ -233,7 +238,7 @@ def _main(cfg: DictConfig, output_file):
sample_id sample_id
) )
else: else:
if src_dict is not None and src_tokens in [torch.int32, torch.int64]: if src_dict is not None and src_tokens.dtype in [torch.int32, torch.int64]:
src_str = src_dict.string(src_tokens, cfg.common_eval.post_process) src_str = src_dict.string(src_tokens, cfg.common_eval.post_process)
else: else:
src_str = "" src_str = ""
...@@ -347,6 +352,7 @@ def _main(cfg: DictConfig, output_file): ...@@ -347,6 +352,7 @@ def _main(cfg: DictConfig, output_file):
# Score only the top hypothesis # Score only the top hypothesis
if has_target and j == 0: if has_target and j == 0:
translation_list.append([src_str, target_str, detok_hypo_str])
if align_dict is not None or cfg.common_eval.post_process is not None: if align_dict is not None or cfg.common_eval.post_process is not None:
# Convert back to tokens for evaluation with unk replacement and/or without BPE # Convert back to tokens for evaluation with unk replacement and/or without BPE
target_tokens = tgt_dict.encode_line( target_tokens = tgt_dict.encode_line(
...@@ -393,6 +399,11 @@ def _main(cfg: DictConfig, output_file): ...@@ -393,6 +399,11 @@ def _main(cfg: DictConfig, output_file):
), ),
file=output_file, file=output_file,
) )
if translation_path is not None:
with open(translation_path, "w", encoding="utf-8") as f:
for item in translation_list:
f.write("{}\n".format("\t".join(item)))
return scorer return scorer
......
Markdown 格式
0%
您添加了 0 人到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 登录 后发表评论