#!/usr/bin/env bash
# Ensemble decoding run: generate translations with generate_test.py, pass them
# through rerank.py and score the result with a multi-bleu script.
#
# Optional environment variables read further down:
#   ensemble    number of epoch checkpoints to average before decoding
#   multi_bleu  path of the BLEU scoring script that is run with perl

# Abort on the first failing command. pipefail extends that to pipelines, so a
# crash on the left-hand side of `cmd | tee log` is no longer hidden by tee's
# successful exit status.
set -e
set -o pipefail

# GPU id(s) handed to the decoder through CUDA_VISIBLE_DEVICES.
device=(0)
model_root_dir=checkpoints
# set tag
# NOTE: only the first tag selects model_dir; any further entries are ignored.
model_dir_tag=(baseline_v2 baseline_v3)
model_dir=$model_root_dir/${model_dir_tag[0]}
ensemble_dir=$model_root_dir/ensemble

# Checkpoint file name looked up inside each model directory.
checkpoint=fairseq.pt
# Dataset directory below data-bin/.
data_dir=wmt19_zh2en
beam=12
lenpen=1.2
# Subset passed to --gen-subset; also selects the BLEU reference (valid/test).
who=test

# Optional checkpoint averaging. When $ensemble is non-empty, decoding uses
# last$ensemble.ensemble.pt instead of the default checkpoint name; the
# averaged file is built with scripts/average_checkpoints.py unless it already
# exists in $model_dir.
# NOTE(review): the decode step looks this file name up in every ensemble
# member directory, but it is only created under $model_dir -- confirm the
# other directories already contain it before enabling $ensemble.
if [[ -n "${ensemble:-}" ]]; then
	averaged_ckpt="last${ensemble}.ensemble.pt"
	if [[ ! -e "$model_dir/$averaged_ckpt" ]]; then
		# PYTHONPATH is set to the current directory for the averaging script.
		PYTHONPATH="$(pwd)" python3 scripts/average_checkpoints.py \
			--inputs "$model_dir" \
			--output "$model_dir/$averaged_ckpt" \
			--num-epoch-checkpoints "$ensemble"
	fi
	checkpoint=$averaged_ckpt
fi

# Log file that receives a copy of the decoder's stdout (via tee).
output=$model_dir/translation.log

# Model directories (below $model_root_dir) whose checkpoints are decoded
# together. Order is kept exactly as in the original --path argument.
ensemble_members=(
	big_v3_multistep4
	baseline_epoch10
	baseline_epoch20
	filter_base
	baseline_v3
	baseline_v3_seed2
	baseline_v2_seed2
	baseline
)

# Build the single colon-separated value passed to --path.
model_paths=
for member in "${ensemble_members[@]}"; do
	model_paths+="${model_paths:+:}$model_root_dir/$member/$checkpoint"
done

# Join every entry of $device with commas; a one-element array such as (0)
# yields the same value as the plain "$device" expansion.
visible_devices=$(IFS=,; printf '%s' "${device[*]}")

# Make sure the hypothesis directory and the log directory exist up front,
# rather than failing on a missing path only once decoding is under way.
mkdir -p -- "$ensemble_dir" "$(dirname -- "$output")"

CUDA_VISIBLE_DEVICES="$visible_devices" python3 generate_test.py \
	"data-bin/$data_dir" \
	--path "$model_paths" \
	--gen-subset "$who" \
	--output "$ensemble_dir/hypo.$who.beam$beam.lenpen$lenpen" \
	--batch-size 8 \
	--beam "$beam" \
	--lenpen "$lenpen" \
	--remove-bpe | tee "$output"

#python3 parse_translation_log.py -i $output --tgt_lang de

# Run rerank.py on the decoder output. Its second argument, the .decodes file,
# is what gets scored against the reference afterwards.
hypo_file=$ensemble_dir/hypo.$who.beam$beam.lenpen$lenpen
python3 rerank.py "$hypo_file" "$hypo_file.decodes"
# remove the intermediate output
rm -- "$hypo_file"

# Score the .decodes file with the multi-bleu script. Only wmt19_zh2en has
# references here: the wmt17 dev set for "valid" and the wmt17 test set for
# "test". Any other dataset or subset is skipped without scoring.
bleu_ref=
if [[ "$data_dir" == "wmt19_zh2en" ]]; then
	case "$who" in
		valid) bleu_ref=reference/wmt17-dev-ref ;;
		test) bleu_ref=reference/wmt17-test-ref ;;
	esac
fi

if [[ -n "$bleu_ref" ]]; then
	# multi_bleu is not assigned anywhere in this script, so it has to come
	# from the environment. Without this guard an unset value would make perl
	# execute the reference file itself as a program.
	: "${multi_bleu:?set multi_bleu to the path of the multi-bleu perl script}"
	perl "$multi_bleu" "$bleu_ref" < "$ensemble_dir/hypo.$who.beam$beam.lenpen$lenpen.decodes"
fi
