#!/usr/bin/bash
set -e

# ---- GPUs -----------------------------------------------------------------
# Ids used for CUDA_VISIBLE_DEVICES; each generation job gets one entry.
device=(0 1 3 4 5 6 7)
n_device="${#device[@]}"

# ---- decoding sweep -------------------------------------------------------
# One generation job is started per length penalty listed here.
lenpens=(1.0 1.1 1.2 1.3 1.4 1.5 1.6)
n_lenpens="${#lenpens[@]}"
# Beam width forwarded to generate.py via --beam.
beam="12"
# Subset forwarded to generate.py via --gen-subset.
who="test"
# Binarized data set, looked up under data-bin/.
data_dir="wmt19_zh2en"

# ---- model ----------------------------------------------------------------
# Checkpoints live in <model_root_dir>/<model_dir_tag>; change the tag to
# point the script at a different run.
model_root_dir="checkpoints"
model_dir_tag="dense16"
model_dir="${model_root_dir}/${model_dir_tag}"
# Checkpoint file inside $model_dir that is used for decoding.
checkpoint="fairseq.pt"
# Number of epoch checkpoints to average before decoding; leave empty to
# decode with $checkpoint as-is.
ensemble=""

#######################################
# Switch decoding to an averaged checkpoint when $ensemble is set.
# The averaged file is $model_dir/last<ensemble>.ensemble.pt; it is built
# with scripts/average_checkpoints.py only if it does not exist yet.
# Globals:   ensemble, model_dir (read); checkpoint (written)
# Arguments: none
# Returns:   0 on success (or when $ensemble is empty), otherwise the
#            status of the averaging script
#######################################
select_checkpoint() {
  # An empty $ensemble means: keep whatever $checkpoint already names.
  [[ -n "$ensemble" ]] || return 0

  local averaged="last$ensemble.ensemble.pt"
  if [[ ! -e "$model_dir/$averaged" ]]; then
    # Every path is quoted so a model directory containing whitespace is
    # passed as a single argument.
    PYTHONPATH="$(pwd)" python3 scripts/average_checkpoints.py \
      --inputs "$model_dir" \
      --output "$model_dir/$averaged" \
      --num-epoch-checkpoints "$ensemble" || return
  fi
  checkpoint=$averaged
}

select_checkpoint


echo "$n_device"
echo "$n_lenpens"

#######################################
# Translate the $who subset with one length penalty on one GPU.
# Globals:   model_dir, checkpoint, data_dir, who, beam (read)
# Arguments: $1 - length penalty, forwarded as --lenpen
#            $2 - GPU id, exported to the job as CUDA_VISIBLE_DEVICES
# Outputs:   echoes the command line and a "run ..." banner; the job's
#            stdout is shown and copied to
#            $model_dir/translation.<lenpen>.log
# Returns:   status of the "generate.py | tee" pipeline
#######################################
generate_one() {
  local lenpen=$1 dev=$2
  local output="$model_dir/translation.$lenpen.log"
  # Built as an array (not a string) so each argument survives intact.
  local -a cmd=(
    python3 generate.py
    "data-bin/$data_dir"
    --path "$model_dir/$checkpoint"
    --gen-subset "$who"
    --output "$model_dir/hypo.$who.beam$beam.lenpen$lenpen"
    --batch-size 32
    --beam "$beam"
    --lenpen "$lenpen"
    --remove-bpe
  )
  echo "${cmd[*]}"
  echo "run data=${who} beam=$beam lenpen=$lenpen dev=$dev"
  CUDA_VISIBLE_DEVICES=$dev "${cmd[@]}" | tee "$output"
}

#######################################
# Run one generation job per length penalty, at most one job per GPU at a
# time. Penalties are handled in batches of ${#device[@]}: batch k uses
# lenpens[k*n .. k*n+n-1] on device[0 .. n-1], and the next batch starts
# only after the whole previous batch has finished. When there are at
# least as many GPUs as penalties this is a single batch.
# Globals:   device, lenpens (read)
# Arguments: none
# Returns:   0 if every job succeeded; 1 if a job failed or no GPU is
#            configured while there is work to do
#######################################
run_generation() {
  local n_dev=${#device[@]} n_pen=${#lenpens[@]}
  local start slot idx pid failed=0
  local -a pids

  (( n_pen > 0 )) || return 0
  if (( n_dev == 0 )); then
    echo "run_generation: device list is empty" >&2
    return 1
  fi

  for (( start = 0; start < n_pen; start += n_dev )); do
    pids=()
    for (( slot = 0; slot < n_dev; slot++ )); do
      idx=$(( start + slot ))
      # The last batch may be shorter than the device list.
      (( idx < n_pen )) || break
      # pipefail inside the job: otherwise tee's status hides a failing
      # generate.py.
      ( set -o pipefail; generate_one "${lenpens[idx]}" "${device[slot]}" ) &
      pids+=("$!")
    done
    # A bare `wait` always returns 0; waiting per PID surfaces failures.
    for pid in "${pids[@]}"; do
      wait "$pid" || failed=1
    done
    (( failed == 0 )) || return 1
  done
}

run_generation

#######################################
# Run rerank.py on the hypothesis file of every length penalty, then delete
# the raw hypothesis file. rerank.py receives the hypothesis path and the
# same path with ".decodes" appended (presumably its output file — the
# script itself is not visible here).
# Globals:   lenpens, model_dir, who, beam (read)
# Arguments: none
# Returns:   0 on success; the failing command's status otherwise. The raw
#            file is kept when reranking fails.
#######################################
rerank_hypotheses() {
  local lenpen hypo
  for lenpen in "${lenpens[@]}"; do
    hypo="$model_dir/hypo.$who.beam$beam.lenpen$lenpen"
    # Quoted so a model directory containing whitespace stays one argument.
    python3 rerank.py "$hypo" "$hypo.decodes" || return
    rm -- "$hypo" || return
  done
}

rerank_hypotheses

echo 'multi bleu:'

#######################################
# Print a "beam/lenpen" header for every length penalty and, for the
# wmt19_zh2en valid/test subsets, score the reranked output with the
# multi-bleu Perl script against the matching wmt17 reference.
# Globals:   lenpens, model_dir, data_dir, who, beam (read)
#            multi_bleu (read) - path to the multi-bleu Perl script; it is
#            not assigned in this file, so it must come from the
#            environment
# Arguments: none
# Outputs:   headers and the scorer's output on stdout
# Returns:   0 on success; 1 if scoring is required but multi_bleu is
#            unset; otherwise the scorer's failing status
#######################################
score_hypotheses() {
  local lenpen reference=

  # The reference depends only on data set and subset, so pick it once.
  if [[ "$data_dir" == "wmt19_zh2en" ]]; then
    case "$who" in
      valid) reference=reference/wmt17-dev-ref ;;
      test) reference=reference/wmt17-test-ref ;;
    esac
  fi

  for lenpen in "${lenpens[@]}"; do
    echo "beam: $beam, lenpen: $lenpen"
    # Unknown data set / subset: header only, as before.
    [[ -n "$reference" ]] || continue
    # Without this guard an empty $multi_bleu makes perl execute the
    # reference file as if it were the script.
    if [[ -z "${multi_bleu:-}" ]]; then
      echo "score_hypotheses: multi_bleu is not set (path to the multi-bleu perl script)" >&2
      return 1
    fi
    perl "$multi_bleu" "$reference" \
      < "$model_dir/hypo.$who.beam$beam.lenpen$lenpen.decodes" || return
  done
}

score_hypotheses

