#!/usr/bin/env bash
#
# Decode a fairseq data split with beam search, rerank the hypotheses with
# rerank.py, and (for wmt19_zh2en test/valid) score them with multi-bleu.perl.
#
# Environment:
#   multi_bleu  path to multi-bleu.perl; required whenever a BLEU reference
#               is configured for $data_dir/$who below.
#
# pipefail matters here: without it a crash in generate.py is hidden by
# `tee`, and the script would go on to rerank a missing/truncated file.
set -euo pipefail

die() { printf 'error: %s\n' "$*" >&2; exit 1; }

model_root_dir=checkpoints
# set tag
model_dir_tag=baseline_v2
model_dir=$model_root_dir/$model_dir_tag

# When non-empty, the last $ensemble epoch checkpoints are averaged (once)
# into last${ensemble}.ensemble.pt and that file is decoded instead.
ensemble=
checkpoint=fairseq.pt
data_dir=wmt19_zh2en
beam=11
lenpen=1.4
who=test  # split passed to --gen-subset

# BLEU reference for this data_dir/split; empty means "do not score".
reference=
if [[ "$data_dir" == "wmt19_zh2en" ]]; then
  case "$who" in
    valid) reference=reference/wmt17-dev-ref ;;
    test) reference=reference/wmt17-test-ref ;;
  esac
fi

# Validate the scorer before the slow decoding step. An empty path would
# otherwise make `perl` run the reference file itself as a Perl script.
multi_bleu=${multi_bleu:-}
if [[ -n "$reference" && -z "$multi_bleu" ]]; then
  die "set multi_bleu to the path of multi-bleu.perl to score against $reference"
fi

if [[ -n "$ensemble" ]]; then
  averaged=last${ensemble}.ensemble.pt
  if [[ ! -e "$model_dir/$averaged" ]]; then
    PYTHONPATH=$PWD python3 scripts/average_checkpoints.py \
      --inputs "$model_dir" \
      --output "$model_dir/$averaged" \
      --num-epoch-checkpoints "$ensemble"
  fi
  checkpoint=$averaged
fi

output=$model_dir/translation.log
# Raw hypotheses written by generate.py; rerank.py reads this and writes
# "$hypo.decodes", after which this intermediate file is deleted.
hypo=$model_dir/hypo.${who}.beam${beam}.lenpen${lenpen}

python3 generate.py \
  "data-bin/$data_dir" \
  --path "$model_dir/$checkpoint" \
  --gen-subset "$who" \
  --output "$hypo" \
  --batch-size 32 \
  --beam "$beam" \
  --lenpen "$lenpen" \
  --remove-bpe | tee "$output"

# NOTE(review): disabled; --tgt_lang de does not match the zh->en data_dir
# above — confirm the right language before re-enabling.
#python3 parse_translation_log.py -i "$output" --tgt_lang de

python3 rerank.py "$hypo" "$hypo.decodes"
# remove the intermediate output
rm -- "$hypo"

if [[ -n "$reference" ]]; then
  perl "$multi_bleu" "$reference" < "$hypo.decodes"
fi
