Commit 0ab28954 by xuchen

add the MT recipe, which supports the same pipeline as the ASR and ST recipes.

parent 5cf7717c
...@@ -134,3 +134,4 @@ experimental/*
# Weights and Biases logs
wandb/
/examples/translation/iwslt14.tokenized.de-en/
train-subset: train
valid-subset: dev
max-epoch: 50
max-update: 100000
num-workers: 8
patience: 10
no-progress-bar: True
log-interval: 100
seed: 1
report-accuracy: True
#load-params:
#load-pretrained-encoder-from:
arch: transformer
share-decoder-input-output-embed: True
optimizer: adam
clip-norm: 10.0
lr-scheduler: inverse_sqrt
warmup-init-lr: 1e-7
warmup-updates: 8000
lr: 1e-3
adam_betas: (0.9,0.98)
criterion: label_smoothed_cross_entropy
label_smoothing: 0.1
dropout: 0.1
activation-fn: relu
encoder-embed-dim: 256
encoder-ffn-embed-dim: 2048
encoder-layers: 6
decoder-layers: 6
encoder-attention-heads: 4
#decoder-embed-dim: 256
#decoder-ffn-embed-dim: 2048
#decoder-attention-heads: 4
#attention-dropout: 0.1
#activation-dropout: 0.1
#! /bin/bash
gpu_num=1
test_subset=(tst-COMMON)
exp_name=
if [ "$#" -eq 1 ]; then
exp_name=$1
fi
n_average=10
beam_size=5
max_tokens=40000
cmd="./run.sh
--stage 2
--stop_stage 2
--gpu_num ${gpu_num}
--exp_name ${exp_name}
--test_subset ${test_subset}
--n_average ${n_average}
--beam_size ${beam_size}
--max_tokens ${max_tokens}
"
echo $cmd
eval $cmd
gpu_num=1
while :
do
all_devices=$(seq 0 `gpustat | sed '1,2d' | wc -l`);
count=0
for dev in ${all_devices[@]}
do
line=`expr $dev + 2`
use=`gpustat -p | head -n $line | tail -1 | cut -d '|' -f4 | wc -w`
if [[ $use -eq 0 ]]; then
device[$count]=$dev
count=`expr $count + 1`
if [[ $count -eq $gpu_num ]]; then
break
fi
fi
done
if [[ ${#device[@]} -lt $gpu_num ]]; then
sleep 60s
else
echo "Run $cmd"
eval $cmd
sleep 10s
exit
fi
done
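The GPU-polling loop above waits until gpu_num devices are reported idle by gpustat and only then evals whatever command is stored in $cmd (here, the run.sh stage-2 decoding command assembled earlier). Assuming the experiment directory name matches what training produced, a typical call is simply (the argument is optional and purely illustrative):

# decode a named experiment
sh decode.sh train_baseline
# or let run.sh resolve the default experiment name
sh decode.sh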
#!/usr/bin/env bash
# Copyright 2012 Johns Hopkins University (Author: Daniel Povey);
# Arnab Ghoshal, Karel Vesely
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED
# WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE,
# MERCHANTABILITY OR NON-INFRINGEMENT.
# See the Apache 2 License for the specific language governing permissions and
# limitations under the License.
# Parse command-line options.
# To be sourced by another script (as in ". parse_options.sh").
# Option format is: --option-name arg
# and shell variable "option_name" gets set to value "arg."
# The exception is --help, which takes no arguments, but prints the
# $help_message variable (if defined).
###
### The --config file options have lower priority to command line
### options, so we need to import them first...
###
# Now import all the configs specified by command-line, in left-to-right order
for ((argpos=1; argpos<$#; argpos++)); do
if [ "${!argpos}" == "--config" ]; then
argpos_plus1=$((argpos+1))
config=${!argpos_plus1}
[ ! -r $config ] && echo "$0: missing config '$config'" && exit 1
. $config # source the config file.
fi
done
###
### Now we process the command line options
###
while true; do
[ -z "${1:-}" ] && break; # break if there are no arguments
case "$1" in
# If the enclosing script is called with --help option, print the help
# message and exit. Scripts should put help messages in $help_message
--help|-h) if [ -z "$help_message" ]; then echo "No help found." 1>&2;
else printf "$help_message\n" 1>&2 ; fi;
exit 0 ;;
--*=*) echo "$0: options to scripts must be of the form --name value, got '$1'"
exit 1 ;;
# If the first command-line argument begins with "--" (e.g. --foo-bar),
# then work out the variable name as $name, which will equal "foo_bar".
--*) name=`echo "$1" | sed s/^--// | sed s/-/_/g`;
# Next we test whether the variable in question is undefined-- if so it's
# an invalid option and we die. Note: $0 evaluates to the name of the
# enclosing script.
# The test [ -z ${foo_bar+xxx} ] will return true if the variable foo_bar
# is undefined. We then have to wrap this test inside "eval" because
# foo_bar is itself inside a variable ($name).
eval '[ -z "${'$name'+xxx}" ]' && echo "$0: invalid option $1" 1>&2 && exit 1;
oldval="`eval echo \\$$name`";
# Work out whether we seem to be expecting a Boolean argument.
if [ "$oldval" == "true" ] || [ "$oldval" == "false" ]; then
was_bool=true;
else
was_bool=false;
fi
# Set the variable to the right value-- the escaped quotes make it work if
# the option had spaces, like --cmd "queue.pl -sync y"
eval $name=\"$2\";
# Check that Boolean-valued arguments are really Boolean.
if $was_bool && [[ "$2" != "true" && "$2" != "false" ]]; then
echo "$0: expected \"true\" or \"false\": $1 $2" 1>&2
exit 1;
fi
shift 2;
;;
*) break;
esac
done
# Check for an empty argument to the --cmd option, which can easily occur as a
# result of scripting errors.
[ ! -z "${cmd+xxx}" ] && [ -z "$cmd" ] && echo "$0: empty argument to --cmd option" 1>&2 && exit 1;
true; # so this script returns exit code 0.
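parse_options.sh is the stock Kaldi option parser: any --option-name value pair on the command line overwrites the shell variable option_name, provided that variable already has a default before the script is sourced. A minimal sketch of the pattern run.sh relies on (the wrapper script name is illustrative; the variables mirror run.sh defaults):

#!/bin/bash
gpu_num=1                        # defaults must be defined before parsing
exp_tag=baseline
. ./local/parse_options.sh || exit 1
# called as:  ./myscript.sh --gpu_num 4 --exp_tag large
# afterwards: gpu_num=4, exp_tag=large; an unknown option such as --foo bar aborts with "invalid option"
echo "gpu_num=${gpu_num} exp_tag=${exp_tag}"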
MAIN_ROOT=$PWD/../../..
KALDI_ROOT=$MAIN_ROOT/tools/kaldi
export PATH=$PWD/utils/:$KALDI_ROOT/tools/openfst/bin:$PATH
[ ! -f $KALDI_ROOT/tools/config/common_path.sh ] && echo >&2 "The standard file $KALDI_ROOT/tools/config/common_path.sh is not present -> Exit!" && exit 1
. $KALDI_ROOT/tools/config/common_path.sh
export LC_ALL=C
export LD_LIBRARY_PATH=${LD_LIBRARY_PATH}:$MAIN_ROOT/src/lib
export LD_LIBRARY_PATH=${LD_LIBRARY_PATH}:$MAIN_ROOT/tools/chainer_ctc/ext/warp-ctc/build
. "${MAIN_ROOT}"/tools/activate_python.sh && . "${MAIN_ROOT}"/tools/extra_path.sh
export PATH=$MAIN_ROOT/utils:$MAIN_ROOT/espnet/bin:$PATH
export OMP_NUM_THREADS=1
# check extra module installation
if ! which tokenizer.perl > /dev/null; then
echo "Error: it seems that moses is not installed." >&2
echo "Error: please install moses as follows." >&2
echo "Error: cd ${MAIN_ROOT}/tools && make moses.done" >&2
return 1
fi
# NOTE(kan-bayashi): Use UTF-8 in Python to avoid UnicodeDecodeError when LC_ALL=C
export PYTHONIOENCODING=UTF-8
get_devices(){
gpu_num=$1
use_cpu=$2
device=()
while :
do
record=`mktemp -t temp.record.XXXXXX`
gpustat > $record
all_devices=$(seq 0 `cat $record | sed '1,2d' | wc -l`);
count=0
for dev in ${all_devices[@]}
do
line=`expr $dev + 2`
use=`cat $record | head -n $line | tail -1 | cut -d '|' -f3 | cut -d '/' -f1`
if [[ $use -lt 10 ]]; then
device[$count]=$dev
count=`expr $count + 1`
if [[ $count -eq $gpu_num ]]; then
break
fi
fi
done
if [[ ${#device[@]} -lt $gpu_num ]]; then
if [[ $use_cpu -eq 1 ]]; then
device=(-1)
else
sleep 60s
fi
else
break
fi
done
echo ${device[*]} | sed 's/ /,/g'
return $?
}
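get_devices blocks until its first argument's worth of GPUs report almost no memory in use in gpustat, then echoes their indices as a comma-separated list; passing 1 as the second argument makes it fall back to CPU (-1) instead of waiting. run.sh consumes it like this:

source ./local/utils.sh
device=$(get_devices ${gpu_num} 0)   # e.g. "0,1" when two free GPUs are found
export CUDA_VISIBLE_DEVICES=${device}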
#! /bin/bash
# Processing MuST-C Datasets
# Copyright 2021 Natural Language Processing Laboratory
# Xu Chen (xuchenneu@163.com)
# Set bash to 'debug' mode, it will exit on :
# -e 'error', -u 'undefined variable', -o ... 'error in pipeline', -x 'print commands',
set -e
#set -u
set -o pipefail
export PYTHONIOENCODING=UTF-8
eval=1
time=$(date "+%m%d_%H%M")
stage=0
stop_stage=0
######## hardware ########
# devices
device=()
gpu_num=8
update_freq=1
root_dir=~/st/fairseq
pwd_dir=$PWD
# dataset
src_lang=en
tgt_lang=de
lang=${src_lang}-${tgt_lang}
dataset=mustc
task=translation_with_tokenizer
vocab_type=unigram
vocab_size=10000
share_dict=1
org_data_dir=/media/data/${dataset}
data_dir=~/st/data/${dataset}/mt/${lang}
test_subset=(tst-COMMON)
# exp
extra_tag=
extra_parameter=
exp_tag=baseline
exp_name=
# config
train_config=st_train_ctc.yaml
# training setting
fp16=1
max_tokens=40000
step_valid=0
bleu_valid=0
# decoding setting
n_average=10
beam_size=5
. ./local/parse_options.sh || exit 1;
if [[ $step_valid -eq 1 ]]; then
validate_interval=10000
save_interval=10000
no_epoch_checkpoints=1
save_interval_updates=5000
keep_interval_updates=3
else
validate_interval=1
keep_last_epochs=10
fi
if [[ ${share_dict} -eq 1 ]]; then
data_config=config_share.yaml
else
data_config=config.yaml
fi
# full path
train_config=$pwd_dir/conf/${train_config}
if [[ -z ${exp_name} ]]; then
exp_name=$(basename ${train_config%.*})_${exp_tag}
if [[ -n ${extra_tag} ]]; then
exp_name=${exp_name}_${extra_tag}
fi
fi
model_dir=$root_dir/../checkpoints/$dataset/mt/${exp_name}
if [ ${stage} -le -1 ] && [ ${stop_stage} -ge -1 ]; then
echo "stage -1: Data Download"
# pass
fi
if [ ${stage} -le 0 ] && [ ${stop_stage} -ge 0 ]; then
### Task dependent. You have to prepare the data in the following stage by yourself.
echo "stage 0: MT Data Preparation"
if [[ ! -e ${data_dir} ]]; then
mkdir -p ${data_dir}
fi
cmd="python ${root_dir}/examples/speech_to_text/prep_mt_data.py
--data-root ${org_data_dir}
--output-root ${data_dir}
--splits train,dev,tst-COMMON,tst-HE
--src-lang ${src_lang}
--tgt-lang ${tgt_lang}
--vocab-type ${vocab_type}
--vocab-size ${vocab_size}"
if [[ $share_dict -eq 1 ]]; then
cmd="$cmd
--share"
fi
echo -e "\033[34mRun command: \n${cmd} \033[0m"
[[ $eval -eq 1 ]] && eval ${cmd}
fi
if [ ${stage} -le 1 ] && [ ${stop_stage} -ge 1 ]; then
echo "stage 1: MT Network Training"
[[ ! -d ${data_dir} ]] && echo "The data dir ${data_dir} does not exist!" && exit 1;
if [[ -z ${device} || ${#device[@]} -eq 0 ]]; then
if [[ ${gpu_num} -eq 0 ]]; then
device=()
else
source ./local/utils.sh
device=$(get_devices $gpu_num 0)
fi
fi
echo -e "dev=${device} data=${data_dir} model=${model_dir}"
if [[ ! -d ${model_dir} ]]; then
mkdir -p ${model_dir}
else
echo "${model_dir} exists."
fi
cp ${BASH_SOURCE[0]} ${model_dir}
cp ${PWD}/train.sh ${model_dir}
cp ${train_config} ${model_dir}
cmd="python3 -u ${root_dir}/fairseq_cli/train.py
${data_dir}
--source-lang ${src_lang}
--target-lang ${tgt_lang}
--config-yaml ${data_config}
--train-config ${train_config}
--task ${task}
--max-tokens ${max_tokens}
--update-freq ${update_freq}
--log-interval 100
--save-dir ${model_dir}
--tensorboard-logdir ${model_dir}"
if [[ -n ${extra_parameter} ]]; then
cmd="${cmd}
${extra_parameter}"
fi
if [[ ${gpu_num} -gt 0 ]]; then
cmd="${cmd}
--distributed-world-size $gpu_num
--ddp-backend no_c10d"
fi
if [[ $fp16 -eq 1 ]]; then
cmd="${cmd}
--fp16"
fi
if [[ $bleu_valid -eq 1 ]]; then
cmd="$cmd
--eval-bleu
--eval-bleu-args '{\"beam\": 1}'
--eval-tokenized-bleu
--eval-bleu-remove-bpe
--best-checkpoint-metric bleu
--maximize-best-checkpoint-metric"
fi
if [[ -n $no_epoch_checkpoints && $no_epoch_checkpoints -eq 1 ]]; then
cmd="$cmd
--no-epoch-checkpoints"
fi
if [[ -n $validate_interval ]]; then
cmd="${cmd}
--validate-interval $validate_interval "
fi
if [[ -n $save_interval ]]; then
cmd="${cmd}
--save-interval $save_interval "
fi
if [[ -n $keep_last_epochs ]]; then
cmd="${cmd}
--keep-last-epochs $keep_last_epochs "
fi
if [[ -n $save_interval_updates ]]; then
cmd="${cmd}
--save-interval-updates $save_interval_updates"
if [[ -n $keep_interval_updates ]]; then
cmd="${cmd}
--keep-interval-updates $keep_interval_updates"
fi
fi
echo -e "\033[34mRun command: \n${cmd} \033[0m"
# save info
log=./history.log
echo "${time} | ${device} | ${data_dir} | ${model_dir} " >> $log
cat $log | tail -n 50 > tmp.log
mv tmp.log $log
export CUDA_VISIBLE_DEVICES=${device}
cmd="nohup ${cmd} >> ${model_dir}/train.log 2>&1 &"
if [[ $eval -eq 1 ]]; then
eval $cmd
sleep 2s
tail -n `wc -l ${model_dir}/train.log | awk '{print $1+1}'` -f ${model_dir}/train.log
fi
fi
wait
if [ ${stage} -le 2 ] && [ ${stop_stage} -ge 2 ]; then
echo "stage 2: MT Decoding"
if [[ ${n_average} -ne 1 ]]; then
# Average models
dec_model=avg_${n_average}_checkpoint.pt
cmd="python ${root_dir}/scripts/average_checkpoints.py
--inputs ${model_dir}
--num-epoch-checkpoints ${n_average}
--output ${model_dir}/${dec_model}"
echo -e "\033[34mRun command: \n${cmd} \033[0m"
[[ $eval -eq 1 ]] && eval $cmd
else
dec_model=checkpoint_best.pt
fi
if [[ -z ${device} || ${#device[@]} -eq 0 ]]; then
if [[ ${gpu_num} -eq 0 ]]; then
device=()
else
source ./local/utils.sh
device=$(get_devices $gpu_num 0)
fi
fi
export CUDA_VISIBLE_DEVICES=${device}
#tmp_file=$(mktemp ${model_dir}/tmp-XXXXX)
#trap 'rm -rf ${tmp_file}' EXIT
result_file=${model_dir}/decode_result
[[ -f ${result_file} ]] && rm ${result_file}
for subset in ${test_subset[@]}; do
cmd="python ${root_dir}/fairseq_cli/generate.py
${data_dir}
--config-yaml ${data_config}
--gen-subset ${subset}
--task ${task}
--path ${model_dir}/${dec_model}
--results-path ${model_dir}
--max-tokens ${max_tokens}
--beam ${beam_size}
--scoring sacrebleu"
echo -e "\033[34mRun command: \n${cmd} \033[0m"
if [[ $eval -eq 1 ]]; then
eval $cmd
tail -n 1 ${model_dir}/generate-${subset}.txt >> ${result_file}
fi
done
cat ${result_file}
fi
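run.sh is organized into stages selected with --stage/--stop_stage: 0 prepares the MT data, 1 trains, 2 averages checkpoints and decodes. A typical end-to-end session therefore looks like the following (GPU count and tags are only examples):

# stage 0: TSV manifests, SentencePiece vocabulary and config_share.yaml
./run.sh --stage 0 --stop_stage 0
# stage 1: training on 4 GPUs under a tagged experiment directory
./run.sh --stage 1 --stop_stage 1 --gpu_num 4 --exp_tag baseline
# stage 2: average the last 10 checkpoints and decode tst-COMMON with beam 5
./run.sh --stage 2 --stop_stage 2 --gpu_num 1 --n_average 10 --beam_size 5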
#! /bin/bash
# training the model
gpu_num=1
update_freq=1
max_tokens=4096
extra_tag=
extra_parameter=
#extra_tag="${extra_tag}"
#extra_parameter="${extra_parameter} "
exp_tag=
train_config=train.yaml
cmd="./run.sh
--stage 1
--stop_stage 1
--gpu_num ${gpu_num}
--update_freq ${update_freq}
--train_config ${train_config}
--max_tokens ${max_tokens}
"
if [[ -n ${exp_tag} ]]; then
cmd="$cmd --exp_tag ${exp_tag}"
fi
if [[ -n ${extra_tag} ]]; then
cmd="$cmd --extra_tag ${extra_tag}"
fi
if [[ -n ${extra_parameter} ]]; then
cmd="$cmd --extra_parameter \"${extra_parameter}\""
fi
echo $cmd
eval $cmd
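train.sh is a thin wrapper around run.sh stage 1; extra_parameter is appended verbatim to the fairseq training command, so additional fairseq flags can be injected without editing the YAML. A hedged example of how the two variables might be set (the values are hypothetical, the flags are standard fairseq options):

extra_tag=dp0.2
extra_parameter="--dropout 0.2 --attention-dropout 0.1"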
#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import argparse
import logging
import os
from pathlib import Path
from tempfile import NamedTemporaryFile
from typing import Tuple
import pandas as pd
from examples.speech_to_text.data_utils import (
gen_vocab,
save_df_to_tsv,
)
from torch.utils.data import Dataset
from tqdm import tqdm
log = logging.getLogger(__name__)
MANIFEST_COLUMNS = ["src_text", "tgt_text"]
class MTData(Dataset):
"""
Create a Dataset for MuST-C. Each item is a tuple of the form:
waveform, sample_rate, source utterance, target utterance, speaker_id,
utterance_id
"""
def __init__(self, root: str, src_lang, tgt_lang: str, split: str) -> None:
_root = Path(root) / f"{src_lang}-{tgt_lang}" / "data" / split
txt_root = _root / "txt"
assert _root.is_dir() and txt_root.is_dir()
# Load source and target text
self.data = []
for _lang in [src_lang, tgt_lang]:
with open(txt_root / f"{split}.{_lang}") as f:
texts = [r.strip() for r in f]
self.data.append(texts)
self.data = list(zip(self.data[0], self.data[1]))
def __getitem__(self, n: int) -> Tuple[str, str]:
src_text, tgt_text = self.data[n]
return src_text, tgt_text
def __len__(self) -> int:
return len(self.data)
def process(args):
src_lang = args.src_lang
tgt_lang = args.tgt_lang
splits = args.splits.split(",")
src_train_text = []
tgt_train_text = []
lang = f"{src_lang}-{tgt_lang}"
cur_root = Path(args.data_root).absolute() / lang
if args.output_root is None:
output_root = cur_root
else:
output_root = Path(args.output_root).absolute()
# Generate TSV manifest
print("Generating manifest...")
for split in splits:
is_train_split = split.startswith("train")
manifest = {c: [] for c in MANIFEST_COLUMNS}
dataset = MTData(args.data_root, src_lang, tgt_lang, split)
for src_text, tgt_text in tqdm(dataset):
manifest["src_text"].append(src_text)
manifest["tgt_text"].append(tgt_text)
if is_train_split and args.size != -1 and len(manifest["src_text"]) > args.size:
break
if is_train_split:
src_train_text.extend(manifest["src_text"])
tgt_train_text.extend(manifest["tgt_text"])
df = pd.DataFrame.from_dict(manifest)
save_df_to_tsv(df, output_root / f"{split}.tsv")
# Generate vocab and yaml
v_size_str = "" if args.vocab_type == "char" else str(args.vocab_size)
spm_filename_prefix = f"spm_{args.vocab_type}{v_size_str}"
if args.share:
tgt_train_text.extend(src_train_text)
src_spm_filename_prefix = spm_filename_prefix + "_share"
tgt_spm_filename_prefix = src_spm_filename_prefix
else:
src_spm_filename_prefix = spm_filename_prefix + "_" + src_lang
tgt_spm_filename_prefix = spm_filename_prefix + "_" + tgt_lang
with NamedTemporaryFile(mode="w") as f:
for t in tgt_train_text:
f.write(t + "\n")
gen_vocab(
Path(f.name),
output_root / tgt_spm_filename_prefix,
args.vocab_type,
args.vocab_size,
)
if not args.share:
with NamedTemporaryFile(mode="w") as f:
for t in src_train_text:
f.write(t + "\n")
gen_vocab(
Path(f.name),
output_root / src_spm_filename_prefix,
args.vocab_type,
args.vocab_size,
)
# Generate config YAML
yaml_filename = "config.yaml"
if args.share:
yaml_filename = "config_share.yaml"
conf = {}
conf["src_vocab_filename"] = src_spm_filename_prefix + ".txt"
conf["tgt_vocab_filename"] = tgt_spm_filename_prefix + ".txt"
conf["src_bpe_tokenizer"] = {
"bpe": "sentencepiece",
"sentencepiece_model": (output_root / (src_spm_filename_prefix + ".model")).as_posix(),
}
conf["tgt_bpe_tokenizer"] = {
"bpe": "sentencepiece",
"sentencepiece_model": (output_root / (tgt_spm_filename_prefix + ".model")).as_posix(),
}
import yaml
with open(os.path.join(output_root, yaml_filename), "w") as f:
yaml.dump(conf, f)
def main():
parser = argparse.ArgumentParser()
parser.add_argument("--data-root", "-d", required=True, type=str)
parser.add_argument("--output-root", "-o", default=None, type=str)
parser.add_argument(
"--vocab-type",
default="unigram",
required=True,
type=str,
choices=["bpe", "unigram", "char"],
)
parser.add_argument("--vocab-size", default=10000, type=int)
parser.add_argument("--size", default=-1, type=int)
parser.add_argument("--splits", default="train,dev,test", type=str)
parser.add_argument("--src-lang", required=True, type=str)
parser.add_argument("--tgt-lang", required=True, type=str)
parser.add_argument("--share", action="store_true", help="share the source and target vocabulary")
args = parser.parse_args()
process(args)
if __name__ == "__main__":
main()
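prep_mt_data.py writes one {split}.tsv manifest with src_text/tgt_text columns per split, the SentencePiece model(s) plus the matching fairseq dictionaries, and the data config YAML under --output-root; stage 0 of run.sh builds exactly this command. Invoked by hand with the recipe's default paths (illustrative, adjust to your layout):

python ${root_dir}/examples/speech_to_text/prep_mt_data.py \
    --data-root /media/data/mustc \
    --output-root ~/st/data/mustc/mt/en-de \
    --splits train,dev,tst-COMMON,tst-HE \
    --src-lang en --tgt-lang de \
    --vocab-type unigram --vocab-size 10000 --share
# expected outputs in the shared-vocabulary case:
#   train.tsv dev.tsv tst-COMMON.tsv tst-HE.tsv
#   spm_unigram10000_share.model / spm_unigram10000_share.txt
#   config_share.yaml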
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
from dataclasses import dataclass, field
import csv
import json
import logging
import os
import os.path as op
from argparse import Namespace
from typing import Dict, Optional
import numpy as np
from fairseq import metrics, utils
from fairseq.data import (
AppendTokenDataset,
ConcatDataset,
LanguagePairDataset,
PrependTokenDataset,
StripTokenDataset,
TruncateDataset,
data_utils,
encoders,
indexed_dataset,
)
from fairseq.data.fairseq_dataset import FairseqDataset
from functools import lru_cache
from fairseq.tasks import register_task
from fairseq.tasks.translation import TranslationTask, TranslationConfig
EVAL_BLEU_ORDER = 4
logger = logging.getLogger(__name__)
class ListTextDataset(FairseqDataset):
"""Takes a text file as input and binarizes it in memory at instantiation.
Original lines are also kept in memory"""
def __init__(self, text, dictionary, append_eos=True, reverse_order=False):
self.tokens_list = []
self.lines = text
self.sizes = []
self.append_eos = append_eos
self.reverse_order = reverse_order
self.read_data(dictionary)
self.size = len(self.tokens_list)
def read_data(self, dictionary):
for line in self.lines:
tokens = dictionary.encode_line(
line,
add_if_not_exist=False,
append_eos=self.append_eos,
reverse_order=self.reverse_order,
).long()
self.tokens_list.append(tokens)
self.sizes.append(len(tokens))
self.sizes = np.array(self.sizes)
def check_index(self, i):
if i < 0 or i >= self.size:
raise IndexError("index out of range")
@lru_cache(maxsize=8)
def __getitem__(self, i):
self.check_index(i)
return self.tokens_list[i]
def get_original_text(self, i):
self.check_index(i)
return self.lines[i]
def __del__(self):
pass
def __len__(self):
return self.size
def num_tokens(self, index):
return self.sizes[index]
def size(self, index):
return self.sizes[index]
def load_langpair_dataset(
data_path,
split,
src,
src_dict,
tgt,
tgt_dict,
combine,
dataset_impl,
upsample_primary,
left_pad_source,
left_pad_target,
max_source_positions,
max_target_positions,
prepend_bos=False,
load_alignments=False,
truncate_source=False,
append_source_id=False,
num_buckets=0,
shuffle=True,
pad_to_multiple=1,
pre_tokenizer=None,
src_bpe_tokenizer=None,
tgt_bpe_tokenizer=None,
):
def tokenize_text(text: str, pre_tokenizer, bpe_tokenizer):
if pre_tokenizer is not None:
text = pre_tokenizer.encode(text)
if bpe_tokenizer is not None:
text = bpe_tokenizer.encode(text)
return text
src_datasets = []
tgt_datasets = []
root = data_path
_splits = split.split(",")
for split in _splits:
src_texts = []
tgt_texts = []
tsv_path = op.join(root, f"{split}.tsv")
if not op.isfile(tsv_path):
raise FileNotFoundError(f"Dataset not found: {tsv_path}")
with open(tsv_path) as f:
reader = csv.DictReader(
f,
delimiter="\t",
quotechar=None,
doublequote=False,
lineterminator="\n",
quoting=csv.QUOTE_NONE,
)
for e in reader:
item = dict(e)
src_text = item.get("src_text", "")
tgt_text = item.get("tgt_text", "")
src_tokenized = tokenize_text(src_text, pre_tokenizer, src_bpe_tokenizer)
tgt_tokenized = tokenize_text(tgt_text, pre_tokenizer, tgt_bpe_tokenizer)
src_texts.append(src_tokenized)
tgt_texts.append(tgt_tokenized)
assert len(src_texts) > 0 and len(src_texts) == len(tgt_texts), (len(src_texts), len(tgt_texts))
src_dataset = ListTextDataset(src_texts, dictionary=src_dict)
if truncate_source:
src_dataset = AppendTokenDataset(
TruncateDataset(
StripTokenDataset(src_dataset, src_dict.eos()),
max_source_positions - 1,
),
src_dict.eos(),
)
src_datasets.append(src_dataset)
tgt_dataset = ListTextDataset(tgt_texts, dictionary=tgt_dict)
if tgt_dataset is not None:
tgt_datasets.append(tgt_dataset)
logger.info(
"{} {} {}-{} {} examples".format(
data_path, split, src, tgt, len(src_datasets[-1])
)
)
assert len(src_datasets) == len(tgt_datasets) or len(tgt_datasets) == 0
if len(src_datasets) == 1:
src_dataset = src_datasets[0]
tgt_dataset = tgt_datasets[0] if len(tgt_datasets) > 0 else None
else:
sample_ratios = [1] * len(src_datasets)
sample_ratios[0] = upsample_primary
src_dataset = ConcatDataset(src_datasets, sample_ratios)
if len(tgt_datasets) > 0:
tgt_dataset = ConcatDataset(tgt_datasets, sample_ratios)
else:
tgt_dataset = None
if prepend_bos:
assert hasattr(src_dict, "bos_index") and hasattr(tgt_dict, "bos_index")
src_dataset = PrependTokenDataset(src_dataset, src_dict.bos())
if tgt_dataset is not None:
tgt_dataset = PrependTokenDataset(tgt_dataset, tgt_dict.bos())
eos = None
if append_source_id:
src_dataset = AppendTokenDataset(
src_dataset, src_dict.index("[{}]".format(src))
)
if tgt_dataset is not None:
tgt_dataset = AppendTokenDataset(
tgt_dataset, tgt_dict.index("[{}]".format(tgt))
)
eos = tgt_dict.index("[{}]".format(tgt))
align_dataset = None
if load_alignments:
align_path = os.path.join(data_path, "{}.align.{}-{}".format(split, src, tgt))
if indexed_dataset.dataset_exists(align_path, impl=dataset_impl):
align_dataset = data_utils.load_indexed_dataset(
align_path, None, dataset_impl
)
tgt_dataset_sizes = tgt_dataset.sizes if tgt_dataset is not None else None
return LanguagePairDataset(
src_dataset,
src_dataset.sizes,
src_dict,
tgt_dataset,
tgt_dataset_sizes,
tgt_dict,
left_pad_source=left_pad_source,
left_pad_target=left_pad_target,
align_dataset=align_dataset,
eos=eos,
num_buckets=num_buckets,
shuffle=shuffle,
pad_to_multiple=pad_to_multiple,
)
class TransDataConfig(object):
"""Wrapper class for data config YAML"""
def __init__(self, yaml_path):
try:
import yaml
except ImportError:
print("Please install PyYAML to load YAML files for " "S2T data config")
self.config = {}
if op.isfile(yaml_path):
try:
with open(yaml_path) as f:
self.config = yaml.load(f, Loader=yaml.FullLoader)
except Exception as e:
logger.info(f"Failed to load config from {yaml_path}: {e}")
else:
logger.error(f"Cannot find {yaml_path}")
@property
def src_vocab_filename(self):
"""fairseq vocabulary file under data root"""
return self.config.get("src_vocab_filename", None)
@property
def tgt_vocab_filename(self):
"""fairseq vocabulary file under data root"""
return self.config.get("tgt_vocab_filename", None)
@property
def share_src_and_tgt(self) -> bool:
return self.config.get("share_src_and_tgt", False)
@property
def shuffle(self) -> bool:
"""Shuffle dataset samples before batching"""
return self.config.get("shuffle", False)
@property
def pre_tokenizer(self) -> Dict:
"""Pre-tokenizer to apply before subword tokenization. Returning
a dictionary with `tokenizer` providing the tokenizer name and
the other items providing the tokenizer-specific arguments.
Tokenizers are defined in `fairseq.data.encoders.*`"""
return self.config.get("pre_tokenizer", {"tokenizer": None})
@property
def src_bpe_tokenizer(self) -> Dict:
"""Subword tokenizer to apply after pre-tokenization. Returning
a dictionary with `bpe` providing the tokenizer name and
the other items providing the tokenizer-specific arguments.
Tokenizers are defined in `fairseq.data.encoders.*`"""
return self.config.get("src_bpe_tokenizer", {"bpe": None})
@property
def tgt_bpe_tokenizer(self) -> Dict:
"""Subword tokenizer to apply after pre-tokenization. Returning
a dictionary with `bpe` providing the tokenizer name and
the other items providing the tokenizer-specific arguments.
Tokenizers are defined in `fairseq.data.encoders.*`"""
return self.config.get("tgt_bpe_tokenizer", None)
@property
def prepend_tgt_lang_tag(self) -> bool:
"""Prepend target lang ID token as the target BOS (e.g. for to-many
multilingual setting). During inference, this requires `--prefix-size 1`
to force BOS to be lang ID token."""
return self.config.get("prepend_tgt_lang_tag", False)
@dataclass
class TranslationWithTokenizerConfig(TranslationConfig):
config_yaml: Optional[str] = field(
default="config.yaml",
metadata={
"help": "Configuration YAML filename (under manifest root)"
},
)
@register_task("translation_with_tokenizer", dataclass=TranslationWithTokenizerConfig)
class TranslationWithTokenizerTask(TranslationTask):
"""
Translate from one (source) language to another (target) language.
Args:
src_dict (~fairseq.data.Dictionary): dictionary for the source language
tgt_dict (~fairseq.data.Dictionary): dictionary for the target language
.. note::
The translation task is compatible with :mod:`fairseq-train`,
:mod:`fairseq-generate` and :mod:`fairseq-interactive`.
"""
cfg: TranslationWithTokenizerConfig
def __init__(self, cfg: TranslationWithTokenizerConfig, src_dict, tgt_dict, data_cfg,
pre_tokenizer, src_bpe_tokenizer, tgt_bpe_tokenizer):
super().__init__(cfg, src_dict, tgt_dict)
self.data_cfg = data_cfg
self.pre_tokenizer = pre_tokenizer
self.src_bpe_tokenizer = src_bpe_tokenizer
self.tgt_bpe_tokenizer = tgt_bpe_tokenizer
@classmethod
def setup_task(cls, cfg: TranslationWithTokenizerConfig, **kwargs):
"""Setup the task (e.g., load dictionaries).
Args:
cfg (TranslationWithTokenizerConfig): parsed command-line configuration
"""
data_cfg = TransDataConfig(op.join(cfg.data, cfg.config_yaml))
paths = utils.split_paths(cfg.data)
assert len(paths) > 0
# find language pair automatically
if cfg.source_lang is None or cfg.target_lang is None:
raise Exception(
"Could not infer language pair, please provide it explicitly"
)
# load
src_dict_path = op.join(cfg.data, data_cfg.src_vocab_filename)
if not op.isfile(src_dict_path):
raise FileNotFoundError(f"source Dict not found: {src_dict_path}")
src_dict = cls.load_dictionary(src_dict_path)
tgt_dict_path = op.join(cfg.data, data_cfg.tgt_vocab_filename)
if not op.isfile(tgt_dict_path):
raise FileNotFoundError(f"target Dict not found: {src_dict_path}")
tgt_dict = cls.load_dictionary(src_dict_path)
assert src_dict.pad() == tgt_dict.pad()
assert src_dict.eos() == tgt_dict.eos()
assert src_dict.unk() == tgt_dict.unk()
logger.info("[{}] source dictionary: {} types".format(cfg.source_lang, len(src_dict)))
logger.info("[{}] target dictionary: {} types".format(cfg.target_lang, len(tgt_dict)))
logger.info(f"pre-tokenizer: {data_cfg.pre_tokenizer}")
pre_tokenizer = encoders.build_tokenizer(Namespace(**data_cfg.pre_tokenizer))
logger.info(f"source tokenizer: {data_cfg.src_bpe_tokenizer}")
src_bpe_tokenizer = encoders.build_bpe(Namespace(**data_cfg.src_bpe_tokenizer))
logger.info(f"target tokenizer: {data_cfg.tgt_bpe_tokenizer}")
tgt_bpe_tokenizer = encoders.build_bpe(Namespace(**data_cfg.tgt_bpe_tokenizer))
return cls(cfg, src_dict, tgt_dict, data_cfg, pre_tokenizer, src_bpe_tokenizer, tgt_bpe_tokenizer)
def load_dataset(self, split, epoch=1, combine=False, **kwargs):
"""Load a given dataset split.
Args:
split (str): name of the split (e.g., train, valid, test)
"""
paths = utils.split_paths(self.cfg.data)
assert len(paths) > 0
if split != self.cfg.train_subset:
# if not training data set, use the first shard for valid and test
paths = paths[:1]
data_path = paths[(epoch - 1) % len(paths)]
# infer langcode
src, tgt = self.cfg.source_lang, self.cfg.target_lang
self.datasets[split] = load_langpair_dataset(
data_path,
split,
src,
self.src_dict,
tgt,
self.tgt_dict,
combine=combine,
dataset_impl=self.cfg.dataset_impl,
upsample_primary=self.cfg.upsample_primary,
left_pad_source=self.cfg.left_pad_source,
left_pad_target=self.cfg.left_pad_target,
max_source_positions=self.cfg.max_source_positions,
max_target_positions=self.cfg.max_target_positions,
load_alignments=self.cfg.load_alignments,
truncate_source=self.cfg.truncate_source,
num_buckets=self.cfg.num_batch_buckets,
shuffle=(split != "test"),
pad_to_multiple=self.cfg.required_seq_len_multiple,
pre_tokenizer=self.pre_tokenizer,
src_bpe_tokenizer=self.src_bpe_tokenizer,
tgt_bpe_tokenizer=self.tgt_bpe_tokenizer,
)
def build_dataset_for_inference(self, src_tokens, src_lengths, constraints=None):
return LanguagePairDataset(
src_tokens,
src_lengths,
self.source_dictionary,
tgt_dict=self.target_dictionary,
constraints=constraints,
)
def build_model(self, cfg):
model = super().build_model(cfg)
if self.cfg.eval_bleu:
detok_args = json.loads(self.cfg.eval_bleu_detok_args)
self.tokenizer = encoders.build_tokenizer(
Namespace(tokenizer=self.cfg.eval_bleu_detok, **detok_args)
)
gen_args = json.loads(self.cfg.eval_bleu_args)
self.sequence_generator = self.build_generator(
[model], Namespace(**gen_args)
)
return model
def build_tokenizer(self, args):
logger.info(f"pre-tokenizer: {self.data_cfg.pre_tokenizer}")
return encoders.build_tokenizer(Namespace(**self.data_cfg.pre_tokenizer))
def build_src_bpe(self):
logger.info(f"source tokenizer: {self.data_cfg.src_bpe_tokenizer}")
return encoders.build_bpe(Namespace(**self.data_cfg.src_bpe_tokenizer))
def build_tgt_bpe(self, args):
logger.info(f"target tokenizer: {self.data_cfg.tgt_bpe_tokenizer}")
return encoders.build_bpe(Namespace(**self.data_cfg.tgt_bpe_tokenizer))
@property
def source_dictionary(self):
"""Return the source :class:`~fairseq.data.Dictionary`."""
return self.src_dict
@property
def target_dictionary(self):
"""Return the target :class:`~fairseq.data.Dictionary`."""
return self.tgt_dict
def _inference_with_bleu(self, generator, sample, model):
import sacrebleu
def decode(toks, escape_unk=False):
s = self.tgt_dict.string(
toks.int().cpu(),
self.cfg.eval_bleu_remove_bpe,
# The default unknown string in fairseq is `<unk>`, but
# this is tokenized by sacrebleu as `< unk >`, inflating
# BLEU scores. Instead, we use a somewhat more verbose
# alternative that is unlikely to appear in the real
# reference, but doesn't get split into multiple tokens.
unk_string=("UNKNOWNTOKENINREF" if escape_unk else "UNKNOWNTOKENINHYP"),
)
if self.tokenizer:
s = self.tokenizer.decode(s)
return s
gen_out = self.inference_step(generator, [model], sample, prefix_tokens=None)
hyps, refs = [], []
for i in range(len(gen_out)):
hyps.append(decode(gen_out[i][0]["tokens"]))
refs.append(
decode(
utils.strip_pad(sample["target"][i], self.tgt_dict.pad()),
escape_unk=True, # don't count <unk> as matches to the hypo
)
)
if self.cfg.eval_bleu_print_samples:
logger.info("example hypothesis: " + hyps[0])
logger.info("example reference: " + refs[0])
if self.cfg.eval_tokenized_bleu:
return sacrebleu.corpus_bleu(hyps, [refs], tokenize="none")
else:
return sacrebleu.corpus_bleu(hyps, [refs])
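TranslationWithTokenizerTask reads its vocabulary and tokenizer settings from the --config-yaml file in the data directory; the config_share.yaml emitted by prep_mt_data.py has roughly the following shape (key names come from TransDataConfig above, the model path is a placeholder):

cat > config_share.yaml << 'EOF'
src_vocab_filename: spm_unigram10000_share.txt
tgt_vocab_filename: spm_unigram10000_share.txt
src_bpe_tokenizer:
  bpe: sentencepiece
  sentencepiece_model: /path/to/spm_unigram10000_share.model
tgt_bpe_tokenizer:
  bpe: sentencepiece
  sentencepiece_model: /path/to/spm_unigram10000_share.model
EOF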