Commit d88a22ef by xuchen

acc update

parent 478c694b
@@ -85,8 +85,8 @@ dec_model=checkpoint_best.pt
 n_average=10
 beam_size=5
 len_penalty=1.0
-infer_score=0
-infer_parameters=
+infer_score=1
+infer_parameters="--cal-monotonic-cross-attn-weights --cal-localness --localness-window 0.1 --cal-topk-cross-attn-weights --topk-cross-attn-weights 15 --cal-entropy"

 # Parsing Options
 if [[ ${speed_perturb} -eq 1 ]]; then
@@ -370,6 +370,7 @@ if [ ${stage} -le 2 ] && [ ${stop_stage} -ge 2 ]; then
         --max-tokens ${max_tokens}
         --beam ${beam_size}
         --lenpen ${len_penalty}
+        --batch-size 1
         --scoring wer
         --wer-tokenizer 13a
         --wer-lowercase
...
@@ -21,12 +21,12 @@ sacrebleu=1
 n_average=10
 beam_size=5
 len_penalty=1.0
-max_tokens=40000
+max_tokens=50000
 dec_model=checkpoint_best.pt

 cmd="./run.sh
-    --stage 3
-    --stop_stage 3
+    --stage 2
+    --stop_stage 2
     --src_lang ${src_lang}
     --tgt_lang ${tgt_lang}
     --share_dict ${share_dict}
...
@@ -2,8 +2,7 @@
 # Processing MuST-C Datasets
-# Copyright 2021 Natural Language Processing Laboratory
-# Xu Chen (xuchenneu@163.com)
+# Copyright 2021 Chen Xu (xuchennlp@outlook.com)

 # Set bash to 'debug' mode, it will exit on :
 # -e 'error', -u 'undefined variable', -o ... 'error in pipeline', -x 'print commands',
@@ -16,22 +15,21 @@ eval=1
 time=$(date "+%m%d_%H%M")

 stage=1
-stop_stage=4
+stop_stage=2

-######## hardware ########
-# devices
+######## Hardware ########
+# Devices
 device=(0)
 gpu_num=1
 update_freq=1

-hdfs_get=0
-root_dir=/opt/tiger
-data_root_dir=/mnt/bn/nas-xc-1
-code_dir=${root_dir}/s2t
 pwd_dir=$PWD
+root_dir=${ST_ROOT}
+data_root_dir=${root_dir}
+code_dir=${root_dir}/S2T

-# dataset
+# Dataset
 src_lang=en
 tgt_lang=de
 dataset=must_c
@@ -65,18 +63,16 @@ valid_subset=dev
 trans_subset=tst-COMMON
 test_subset=valid,test

-# exp
+# Exp
+sub_tag=
 exp_prefix=$(date "+%m%d")
-# exp_subfix=${ARNOLD_JOB_ID}_${ARNOLD_TASK_ID}_${ARNOLD_TRIAL_ID}
 extra_tag=
 extra_parameter=
 exp_tag=baseline
 exp_name=

-# config
+# Training Settings
 train_config=small
-# training setting
 fp16=1
 max_tokens=8192
 step_valid=0
@@ -88,7 +84,10 @@ dec_model=checkpoint_best.pt
 n_average=10
 beam_size=5
 len_penalty=1.0
+infer_score=1
+infer_parameters="--cal-monotonic-cross-attn-weights --cal-localness --localness-window 0.1 --cal-topk-cross-attn-weights --topk-cross-attn-weights 15 --cal-entropy"

+# Parsing Options
 . ./local/parse_options.sh || exit 1;

 if [[ ${use_specific_dict} -eq 1 ]]; then
@@ -127,19 +126,6 @@ if [[ ${tokenizer} -eq 1 ]]; then
     exp_prefix=${exp_prefix}_tok
 fi

-# setup nccl envs
-export NCCL_IB_DISABLE=0
-export NCCL_IB_HCA=$ARNOLD_RDMA_DEVICE:1
-export NCCL_IB_GID_INDEX=3
-export NCCL_SOCKET_IFNAME=eth0
-HOSTS=$ARNOLD_WORKER_HOSTS
-HOST=(${HOSTS//,/ })
-HOST_SPLIT=(${HOST//:/ })
-PORT=${HOST_SPLIT[1]}
-INIT_METHOD="tcp://${ARNOLD_WORKER_0_HOST}:${ARNOLD_WORKER_0_PORT}"
-DIST_RANK=$((ARNOLD_ID * ARNOLD_WORKER_GPU))

 export PATH=$PATH:${code_dir}/scripts
 . ./local/parse_options.sh || exit 1;
@@ -153,20 +139,27 @@ if [[ -z ${exp_name} ]]; then
         exp_name=${exp_name}_${exp_subfix}
     fi
 fi

-model_dir=${code_dir}/checkpoints/${data_model_subfix}/${exp_name}
-echo "stage: $stage"
-echo "stop_stage: $stop_stage"
+ckpt_dir=${root_dir}/checkpoints/
+model_dir=${root_dir}/checkpoints/${data_model_subfix}/${sub_tag}/${exp_name}

+# Start
 cd ${code_dir}
+echo "Start Stage: $stage"
+echo "Stop Stage: $stop_stage"

+if [[ `pip list | grep fairseq | wc -l` -eq 0 ]]; then
+    echo "Default Stage: env configure"
+    pip3 install -e ${code_dir}
+fi

 if [ ${stage} -le -1 ] && [ ${stop_stage} -ge -1 ]; then
-    echo "stage -1: Data Download"
+    echo "Stage -1: Data Download"
+    # pass
 fi

 if [ ${stage} -le 0 ] && [ ${stop_stage} -ge 0 ]; then
     ### Task dependent. You have to make data the following preparation part by yourself.
-    echo "stage 0: MT Data Preparation"
+    echo "Stage 0: Data Preparation"
     if [[ ! -e ${data_dir} ]]; then
         mkdir -p ${data_dir}
     fi
@@ -230,32 +223,9 @@ if [ ${stage} -le 0 ] && [ ${stop_stage} -ge 0 ]; then
     [[ $eval -eq 1 ]] && eval ${cmd}
 fi

-echo "stage 1: env configure"
-if [[ `pip list | grep fairseq | wc -l` -eq 0 ]]; then
-    pip3 install -e ${code_dir} -i https://bytedpypi.byted.org/simple --no-build-isolation --default-timeout=10000
-fi
-if [[ -d /mnt/bn/nas-xc-1/checkpoints && ! -d ${code_dir}/checkpoints ]]; then
-    ln -s /mnt/bn/nas-xc-1/checkpoints ${code_dir}
-fi
-# if [ ${stage} -le 1 ] && [ ${stop_stage} -ge 1 ]; then
-if [ ${hdfs_get} -eq 1 ]; then
-    ln_data_dir=`echo ${data_dir} | sed -e "s#${data_root_dir}#${code_dir}#"`
-    echo ${ln_data_dir}
-    mkdir -p ${ln_data_dir}
-    ln -s ${data_dir}/../* ${ln_data_dir}
-    rm -r ${ln_data_dir}
-    hdfs_path=`echo ${data_dir} | sed -e "s#${data_root_dir}#hdfs://haruna/home/byte_arnold_lq_mlnlc/user/xuchen/#"`
-    hdfs dfs -get ${hdfs_path} ${ln_data_dir}
-    data_dir=${ln_data_dir}
-fi
-# fi
 data_dir=${data_dir}/data-bin

-if [ ${stage} -le 2 ] && [ ${stop_stage} -ge 2 ]; then
-    echo "stage 2: MT Network Training"
+if [ ${stage} -le 1 ] && [ ${stop_stage} -ge 1 ]; then
+    echo "Stage 1: Network Training"
     [[ ! -d ${data_dir} ]] && echo "The data dir ${data_dir} is not existing!" && exit 1;

     if [[ -z ${device} || ${#device[@]} -eq 0 ]]; then
@@ -265,6 +235,7 @@ if [ ${stage} -le 2 ] && [ ${stop_stage} -ge 2 ]; then
             source ./local/utils.sh
             device=$(get_devices $gpu_num 0)
         fi
+        export CUDA_VISIBLE_DEVICES=${device}
     fi

     echo -e "data=${data_dir} model=${model_dir}"
@@ -369,18 +340,13 @@ if [ ${stage} -le 2 ] && [ ${stop_stage} -ge 2 ]; then
         echo "${time} | ${data_dir} | ${exp_name} | ${model_dir} " >> $log
         tail -n 50 ${log} > tmp.log
         mv tmp.log $log
-        # export CUDA_VISIBLE_DEVICES=${device}

         log=${model_dir}/train.log
         cmd="${cmd} 2>&1 | tee -a ${log}"
         #cmd="nohup ${cmd} >> ${log} 2>&1 &"
         if [[ $eval -eq 1 ]]; then
             # tensorboard
-            if [[ -z ${ARNOLD_TENSORBOARD_CURRENT_PORT} ]]; then
-                port=6666
-            else
-                port=${ARNOLD_TENSORBOARD_CURRENT_PORT}
-            fi
+            port=6666
             tensorboard --logdir ${model_dir} --port ${port} --bind_all &

             echo "${cmd}" > ${model_dir}/cmd
@@ -390,8 +356,8 @@ if [ ${stage} -le 2 ] && [ ${stop_stage} -ge 2 ]; then
     fi
 fi

-if [ ${stage} -le 3 ] && [ ${stop_stage} -ge 3 ]; then
-    echo "stage 3: MT Decoding"
+if [ ${stage} -le 2 ] && [ ${stop_stage} -ge 2 ]; then
+    echo "Stage 2: Decoding"
     if [[ ${n_average} -ne 1 ]]; then
         # Average models
         dec_model=avg_${n_average}_checkpoint.pt
@@ -411,12 +377,12 @@ if [ ${stage} -le 3 ] && [ ${stop_stage} -ge 3 ]; then
     if [[ -z ${device} || ${#device[@]} -eq 0 ]]; then
         if [[ ${gpu_num} -eq 0 ]]; then
             device=""
         else
             source ./local/utils.sh
             device=$(get_devices $gpu_num 0)
         fi
+        export CUDA_VISIBLE_DEVICES=${device}
     fi
-    # export CUDA_VISIBLE_DEVICES=${device}

     suffix=beam${beam_size}_alpha${len_penalty}_tokens${max_tokens}
     if [[ ${n_average} -ne 1 ]]; then
@@ -427,6 +393,9 @@ if [ ${stage} -le 3 ] && [ ${stop_stage} -ge 3 ]; then
     else
         suffix=${suffix}_multibleu
     fi
+    if [[ ${infer_score} -eq 1 ]]; then
+        suffix=${suffix}_score
+    fi

     result_file=${model_dir}/decode_result_${suffix}
     [[ -f ${result_file} ]] && rm ${result_file}
@@ -442,6 +411,7 @@ if [ ${stage} -le 3 ] && [ ${stop_stage} -ge 3 ]; then
         --results-path ${model_dir}
         --max-tokens ${max_tokens}
         --beam ${beam_size}
+        --batch-size 1
         --lenpen ${len_penalty}
         --post-process sentencepiece"
@@ -462,27 +432,34 @@ if [ ${stage} -le 3 ] && [ ${stop_stage} -ge 3 ]; then
             --target-lang ${tgt_lang}"
         fi
     fi

+        if [[ ${infer_score} -eq 1 ]]; then
+            cmd="${cmd}
+        --score-reference"
+        fi
+        if [[ -n ${infer_parameters} ]]; then
+            cmd="${cmd}
+        ${infer_parameters}"
+        fi

         echo -e "\033[34mRun command: \n${cmd} \033[0m"
         cd ${code_dir}
         if [[ $eval -eq 1 ]]; then
             eval $cmd
+            echo "" >> ${result_file}
             tail -n 2 ${model_dir}/generate-${subset}.txt >> ${result_file}
             mv ${model_dir}/generate-${subset}.txt ${model_dir}/generate-${subset}-${suffix}.txt
             mv ${model_dir}/translation-${subset}.txt ${model_dir}/translation-${subset}-${suffix}.txt
             cd ${pwd_dir}
+            if [[ -f ${model_dir}/enc_dump ]]; then
+                mv ${model_dir}/enc_dump ${model_dir}/dump-${subset}-enc-${suffix}
+            fi
+            if [[ -f ${model_dir}/dec_dump ]]; then
+                mv ${model_dir}/dec_dump ${model_dir}/dump-${subset}-dec-${suffix}
+            fi
         fi
     done

     echo
     cat ${result_file}
 fi

-# if [ ${stage} -le 4 ] && [ ${stop_stage} -ge 4 ]; then
-#     cd ${fairseq_dir}
-#     echo "Stage 4: Upload model and log"
-#     echo "Path: hdfs://haruna/home/byte_arnold_lq_mlnlc/user/xuchen/s2t/checkpoints/${data_model_subfix}/${exp_name}"
-#     hdfs dfs -mkdir -p hdfs://haruna/home/byte_arnold_lq_mlnlc/user/xuchen/s2t/checkpoints/${data_model_subfix}
-#     hdfs dfs -put -f ${model_dir} hdfs://haruna/home/byte_arnold_lq_mlnlc/user/xuchen/s2t/checkpoints/${data_model_subfix}
-# fi
...
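A note on the new inference path: `--score-reference` is fairseq's generation option that scores the given reference translation instead of running beam search, so the attention statistics are collected against gold target sequences. The `--cal-*` options are this repository's own flags; `--batch-size 1` presumably keeps padding out of the per-sentence statistics, though the commit does not say so. The diff does not show `cal_entropy_func` itself, so as a rough illustration of the kind of quantity these flags accumulate, here is a minimal sketch of attention-entropy computation; the function name and exact definition are assumptions, not the repository's code.

```python
import torch

def attention_entropy(attn: torch.Tensor) -> torch.Tensor:
    """Mean entropy (in nats) of attention rows.

    attn: (bsz, tgt_len, src_len); each row attn[b, t, :] sums to 1.
    Low entropy means peaked attention; high entropy means diffuse attention.
    """
    eps = 1e-8  # avoid log(0) on exactly-zero weights
    row_entropy = -(attn * (attn + eps).log()).sum(dim=-1)  # (bsz, tgt_len)
    return row_entropy.mean()
```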
@@ -3,7 +3,7 @@
 gpu_num=1
 data_dir=
-test_subset=(dev tst-COMMON)
+test_subset=(tst-COMMON)
 exp_name=

 if [ "$#" -eq 1 ]; then
@@ -11,16 +11,16 @@ if [ "$#" -eq 1 ]; then
 fi

 sacrebleu=1
-ctc_infer=1
+ctc_infer=0
 n_average=10
 beam_size=5
 len_penalty=1.0
-max_tokens=80000
+max_tokens=50000
 dec_model=checkpoint_best.pt

 cmd="./run.sh
-    --stage 3
-    --stop_stage 3
+    --stage 2
+    --stop_stage 2
     --gpu_num ${gpu_num}
     --exp_name ${exp_name}
     --sacrebleu ${sacrebleu}
...
@@ -24,7 +24,7 @@ gpu_num=8
 update_freq=1

 pwd_dir=$PWD
-root_dir=${pwd_dir}/../../../../
+root_dir=${ST_ROOT}
 data_root_dir=${root_dir}
 code_dir=${root_dir}/S2T
@@ -85,8 +85,9 @@ ctc_infer=0
 n_average=10
 beam_size=5
 len_penalty=1.0
-infer_score=0
-infer_parameters=
+infer_score=1
+infer_parameters="--cal-monotonic-cross-attn-weights --cal-localness --localness-window 0.1 --cal-topk-cross-attn-weights --topk-cross-attn-weights 15 --cal-entropy"

 # Parsing Options
 if [[ ${share_dict} -eq 1 ]]; then
@@ -136,8 +137,8 @@ if [[ -z ${exp_name} ]]; then
     fi
 fi

-ckpt_dir=${code_dir}/checkpoints/
-model_dir=${code_dir}/checkpoints/${data_model_subfix}/${sub_tag}/${exp_name}
+ckpt_dir=${root_dir}/checkpoints/
+model_dir=${root_dir}/checkpoints/${data_model_subfix}/${sub_tag}/${exp_name}

 # Start
 cd ${code_dir}
@@ -427,6 +428,7 @@ if [ ${stage} -le 2 ] && [ ${stop_stage} -ge 2 ]; then
         --path ${model_dir}/${dec_model}
         --results-path ${model_dir}
         --max-tokens ${max_tokens}
+        --batch-size 1
         --beam ${beam_size}
         --skip-invalid-size-inputs-valid-test
         --lenpen ${len_penalty}"
...
@@ -109,8 +109,8 @@ _code_to_dtype = {
     3: np.int16,
     4: np.int32,
     5: np.int64,
-    6: np.float,
-    7: np.double,
+    6: np.float32,
+    7: np.float64,
     8: np.uint16,
     9: np.uint32,
     10: np.uint64,
@@ -316,8 +316,8 @@ class IndexedDatasetBuilder:
     np.int16: 2,
     np.int32: 4,
     np.int64: 8,
-    np.float: 4,
-    np.double: 8,
+    np.float32: 4,
+    np.float64: 8,
 }

 def __init__(self, out_file, dtype=np.int32):
...
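This dtype fix is more than cosmetic: `np.float` and `np.double` were aliases for the built-in 64-bit `float`, and `np.float` was deprecated in NumPy 1.20 and removed in 1.24, so the old code would break on current NumPy. It also appears inconsistent: code 6 mapped to a 64-bit type while the element-size table claimed 4 bytes. With explicit sized dtypes the two tables agree, as this small check (a sketch, not repository code) confirms:

```python
import numpy as np

# With explicit dtypes, the hand-written element-size table above
# can be verified (or derived) from NumPy itself.
for dtype, size in [(np.float32, 4), (np.float64, 8), (np.int64, 8)]:
    assert np.dtype(dtype).itemsize == size
```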
@@ -563,6 +563,17 @@ class TransformerEncoder(FairseqEncoder):
             x = self.quant_noise(x)
         return x, embed

+    def set_flag(self, **kwargs):
+        # Broadcast measurement flags to every layer that supports them.
+        for layer in self.layers:
+            if hasattr(layer, "set_flag"):
+                layer.set_flag(**kwargs)
+
+    def dump(self, fstream, info=""):
+        # Collect per-layer statistics into the given file stream.
+        for i, layer in enumerate(self.layers):
+            if hasattr(layer, "dump"):
+                layer.dump(fstream, "%s Layer %d" % (info, i))
+
     def forward(
         self,
         src_tokens,
...
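These two helpers give the encoder a uniform way to switch the diagnostics on before generation and to harvest them afterwards. A self-contained toy showing the intended call pattern; `ToyLayer`/`ToyEncoder` are stand-ins invented here, while in the repository the layers are fairseq `TransformerEncoderLayer`s whose own `set_flag`/`dump` presumably forward to their attention modules:

```python
import sys

class ToyLayer:
    """Stand-in for a transformer layer with set_flag/dump hooks."""
    def set_flag(self, **kwargs):
        self.flags = dict(kwargs)

    def dump(self, fstream, info=""):
        print("%s: flags=%s" % (info, self.flags), file=fstream)

class ToyEncoder:
    def __init__(self, num_layers):
        self.layers = [ToyLayer() for _ in range(num_layers)]

    def set_flag(self, **kwargs):  # same shape as the method added above
        for layer in self.layers:
            if hasattr(layer, "set_flag"):
                layer.set_flag(**kwargs)

    def dump(self, fstream, info=""):
        for i, layer in enumerate(self.layers):
            if hasattr(layer, "dump"):
                layer.dump(fstream, "%s Layer %d" % (info, i))

enc = ToyEncoder(2)
enc.set_flag(cal_localness=True, localness_window=0.1)
enc.dump(sys.stdout, "Encoder")
```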
@@ -113,9 +113,9 @@ class MultiheadAttention(nn.Module):
         if kwargs.get("cal_localness", False) and not self.encoder_decoder_attention:
             self.cal_localness = True
             self.localness_window = kwargs.get("localness_window", 0.1)
-        if kwargs.get("cal_entropy", False) and self.encoder_decoder_attention:
+        if kwargs.get("cal_entropy", False):
             self.cal_entropy = True
-        if kwargs.get("cal_topk_cross_attn_weights", False) and self.encoder_decoder_attention:
+        if kwargs.get("cal_topk_cross_attn_weights", False):
             self.cal_topk = True
             self.weights_topk = kwargs.get("topk_cross_attn_weights", 1)
         if kwargs.get("cal_monotonic_cross_attn_weights", False) and self.encoder_decoder_attention:
@@ -123,7 +123,7 @@ class MultiheadAttention(nn.Module):
     def dump(self, fstream, info):
         if self.cal_localness:
-            print("%s window size: %f localness: %.2f" % (info, self.localness_window, self.localness), file=fstream)
+            print("%s window size: %.2f localness: %.4f" % (info, self.localness_window, self.localness), file=fstream)
         if self.cal_entropy:
             print("%s Entropy: %.2f" % (info, self.entropy), file=fstream)
@@ -423,36 +423,55 @@ class MultiheadAttention(nn.Module):
             # average attention weights over heads
             attn_weights = attn_weights.mean(dim=0)

-        self.cal_localness_func(attn_weights_float, bsz, src_len, tgt_len)
+        self.cal_localness_func(attn_weights_float, bsz, src_len, tgt_len, key_padding_mask)
         self.cal_entropy_func(attn_weights_float, bsz, src_len, tgt_len)
         self.cal_topk_func(attn_weights_float, bsz, src_len, tgt_len)
         self.cal_monotonic_func(attn_weights_float, bsz, src_len, tgt_len)

         return attn, attn_weights

-    def cal_localness_func(self, attn_weights_float, bsz, src_len, tgt_len):
+    def cal_localness_func(self, attn_weights_float, bsz, src_len, tgt_len, key_padding_mask):
         if not self.training and self.cal_localness:
             weights = attn_weights_float.view(
                 bsz, self.num_heads, tgt_len, src_len
-            ).transpose(1, 0).mean(dim=0)
+            ).transpose(1, 0).mean(0).cpu()

             localness = 0
-            item_localness = 0
             window = int(src_len * self.localness_window)
-            for i in range(window, src_len - window):
-                item_localness = 0
-                for j in range(-window, window + 1):
-                    # if j == 0:
-                    #     continue
-                    item_localness += weights[:, i, i + j]
-                localness += item_localness
-            localness = localness / (src_len - 2 * window)
-            localness *= 100
+            for i in range(bsz):
+                sum_num = 0
+                item_localness = 0
+                for j in range(window, src_len - window):
+                    # skip padded key positions
+                    if key_padding_mask is not None and key_padding_mask[i, j]:
+                        continue
+                    unit_localness = 0
+                    for k in range(-window, window + 1):
+                        unit_localness += weights[i, j, j + k]
+                    item_localness += unit_localness
+                    sum_num += 1
+                if sum_num > 0:
+                    localness += item_localness / sum_num
+            localness = localness / bsz

             if self.localness_num == 0:
                 self.localness = localness.mean()
             else:
                 self.localness = (self.localness * self.localness_num + localness.mean()) / (self.localness_num + 1)
             self.localness_num += 1

     def cal_entropy_func(self, attn_weights_float, bsz, src_len, tgt_len):
...
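In words, the reworked localness statistic asks: for each non-padded query position j (staying `window` tokens away from the sequence edges), how much attention mass falls inside the band [j - window, j + window]? That per-position mass is averaged over valid positions, then over the batch, and accumulated as a running mean across calls. A compact re-implementation of the same idea, as a sketch rather than the repository's exact code:

```python
import torch

def localness(attn, window_frac=0.1, key_padding_mask=None):
    """Average attention mass in a +/-window band around the diagonal.

    attn: (bsz, seq_len, seq_len) head-averaged self-attention weights.
    key_padding_mask: optional (bsz, seq_len) bool tensor, True = padding.
    """
    bsz, _, src_len = attn.shape
    window = int(src_len * window_frac)
    total = 0.0
    for i in range(bsz):
        band_mass, valid = 0.0, 0
        for j in range(window, src_len - window):
            if key_padding_mask is not None and key_padding_mask[i, j]:
                continue  # skip padded positions
            band_mass += attn[i, j, j - window:j + window + 1].sum().item()
            valid += 1
        if valid > 0:
            total += band_mass / valid
    return total / bsz
```

As a reference point, perfectly diffuse attention over a length-n sequence puts roughly (2*window + 1)/n of its mass in the band, so values well above that indicate genuinely local attention.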
@@ -5,8 +5,11 @@
 import unicodedata

+import sacrebleu as sb

 from fairseq.dataclass import ChoiceEnum

+SACREBLEU_V2_ABOVE = int(sb.__version__[0]) >= 2

 class EvaluationTokenizer(object):
     """A generic evaluation-time tokenizer, which leverages built-in tokenizers
@@ -24,7 +27,12 @@ class EvaluationTokenizer(object):
     SPACE = chr(32)
     SPACE_ESCAPE = chr(9601)
-    ALL_TOKENIZER_TYPES = ChoiceEnum(["none", "13a", "intl", "zh", "ja-mecab"])
+    _ALL_TOKENIZER_TYPES = (
+        sb.BLEU.TOKENIZERS
+        if SACREBLEU_V2_ABOVE
+        else ["none", "13a", "intl", "zh", "ja-mecab"]
+    )
+    ALL_TOKENIZER_TYPES = ChoiceEnum(_ALL_TOKENIZER_TYPES)

     def __init__(
         self,
@@ -33,13 +41,16 @@ class EvaluationTokenizer(object):
         punctuation_removal: bool = False,
         character_tokenization: bool = False,
     ):
-        from sacrebleu.tokenizers import TOKENIZERS
-        assert tokenizer_type in TOKENIZERS, f"{tokenizer_type}, {TOKENIZERS}"
+        assert (
+            tokenizer_type in self._ALL_TOKENIZER_TYPES
+        ), f"{tokenizer_type}, {self._ALL_TOKENIZER_TYPES}"
         self.lowercase = lowercase
         self.punctuation_removal = punctuation_removal
         self.character_tokenization = character_tokenization
-        self.tokenizer = TOKENIZERS[tokenizer_type]
+        if SACREBLEU_V2_ABOVE:
+            self.tokenizer = sb.BLEU(tokenize=str(tokenizer_type)).tokenizer
+        else:
+            self.tokenizer = sb.tokenizers.TOKENIZERS[tokenizer_type]()

     @classmethod
     def remove_punctuation(cls, sent: str):
@@ -51,7 +62,7 @@ class EvaluationTokenizer(object):
         )

     def tokenize(self, sent: str):
-        tokenized = self.tokenizer()(sent)
+        tokenized = self.tokenizer(sent)

         if self.punctuation_removal:
             tokenized = self.remove_punctuation(tokenized)
...
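This shim papers over sacrebleu's API break between major versions: in 1.x, `sacrebleu.tokenizers.TOKENIZERS` is a dict of tokenizer classes (hence the old `self.tokenizer()(sent)` that first instantiated, then applied), whereas in 2.x tokenizers are reached through a metric instance via `BLEU(tokenize=...).tokenizer` and the valid names live in `BLEU.TOKENIZERS`. Either way the stored object ends up directly callable on a sentence. A minimal usage sketch of the version-agnostic path:

```python
import sacrebleu as sb

# Pick a tokenizer the same way the shim above does.
if int(sb.__version__[0]) >= 2:
    tokenize = sb.BLEU(tokenize="13a").tokenizer
else:
    tokenize = sb.tokenizers.TOKENIZERS["13a"]()

# The 13a tokenizer splits off punctuation, roughly: "Hello , world !"
print(tokenize("Hello, world!"))
```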