Commit 28d33ad8 by xuchen

modify the preprocessing of s2t

parent c02e267c
@@ -29,19 +29,27 @@ pwd_dir=$PWD
# dataset
src_lang=swa
lang=${src_lang}
tgt_lang=en
lang=${src_lang}-${tgt_lang}
dataset=lower
task=speech_to_text
vocab_type=unigram
vocab_size=10000
speed_perturb=0
vocab_size=1000
speed_perturb=1
lcrm=1
org_data_dir=/media/data/${dataset}
data_dir=~/st/data/${dataset}
use_specific_dict=0
specific_prefix=valid
specific_dir=/home/xuchen/st/data/mustc/st_lcrm/en-de
asr_vocab_prefix=spm_unigram10000_st_share
org_data_dir=~/st/data/${dataset}/asr
data_dir=~/st/data/${dataset}/asr
test_subset=test
# exp
exp_prefix=${time}
extra_tag=
extra_parameter=
exp_tag=baseline
@@ -49,7 +57,7 @@ exp_name=
# config
train_config=train_ctc.yaml
data_config=config.yaml
data_config=config_asr.yaml
# training setting
fp16=1
@@ -62,6 +70,15 @@ beam_size=5
if [[ ${speed_perturb} -eq 1 ]]; then
data_dir=${data_dir}_sp
exp_prefix=${exp_prefix}_sp
fi
if [[ ${lcrm} -eq 1 ]]; then
data_dir=${data_dir}_lcrm
exp_prefix=${exp_prefix}_lcrm
fi
if [[ ${use_specific_dict} -eq 1 ]]; then
data_dir=${data_dir}_${specific_prefix}
exp_prefix=${exp_prefix}_${specific_prefix}
fi
. ./local/parse_options.sh || exit 1;
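For reference, the new `lcrm` switch corresponds to the `--lowercase-src --rm-punc-src` flags passed to the preparation script in stage 0 below. A minimal sketch of what such a normalization step typically does (the function name and punctuation handling are illustrative, not the script's actual implementation):

```python
import string

def lcrm(text: str) -> str:
    """Lowercase and remove punctuation (illustrative, not the script's code)."""
    text = text.lower()
    # Strip ASCII punctuation; the real flags may also handle unicode punctuation.
    text = text.translate(str.maketrans("", "", string.punctuation))
    # Collapse whitespace runs left behind by punctuation removal.
    return " ".join(text.split())

print(lcrm("Hello, World!"))  # -> "hello world"
```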
@@ -69,13 +86,10 @@ fi
# full path
train_config=$pwd_dir/conf/${train_config}
if [[ -z ${exp_name} ]]; then
exp_name=$(basename ${train_config%.*})_${exp_tag}
exp_name=${exp_prefix}_$(basename ${train_config%.*})_${exp_tag}
if [[ -n ${extra_tag} ]]; then
exp_name=${exp_name}_${extra_tag}
fi
if [[ ${speed_perturb} -eq 1 ]]; then
exp_name=sp_${exp_name}
fi
fi
model_dir=$root_dir/../checkpoints/$dataset/asr/${exp_name}
@@ -87,26 +101,42 @@ fi
if [ ${stage} -le 0 ] && [ ${stop_stage} -ge 0 ]; then
### Task dependent. You have to do the following data preparation by yourself.
### But you can utilize Kaldi recipes in most cases.
echo "stage 0: Data Preparation"
if [[ ! -e ${data_dir} ]]; then
mkdir -p ${data_dir}
echo "stage 0: ASR Data Preparation"
if [[ ! -e ${data_dir}/${lang} ]]; then
mkdir -p ${data_dir}/${lang}
fi
source ~/tools/audio/bin/activate
cmd="python ${root_dir}/examples/speech_to_text/prep_librispeech_data.py
cmd="python ${root_dir}/examples/speech_to_text/prep_st_data.py
--data-root ${org_data_dir}
--output-root ${data_dir}
--src-lang ${src_lang}
--tgt-lang ${tgt_lang}
--task asr
--vocab-type ${vocab_type}
--vocab-size ${vocab_size}"
if [[ ${use_specific_dict} -eq 1 ]]; then
cp -r ${specific_dir}/${asr_vocab_prefix}.* ${data_dir}/${lang}
cmd="$cmd
--asr-prefix ${asr_vocab_prefix}"
fi
if [[ ${speed_perturb} -eq 1 ]]; then
cmd="$cmd
--speed-perturb"
fi
if [[ ${lcrm} -eq 1 ]]; then
cmd="$cmd
--lowercase-src
--rm-punc-src"
fi
echo -e "\033[34mRun command: \n${cmd} \033[0m"
[[ $eval -eq 1 ]] && eval $cmd
[[ $eval -eq 1 ]] && eval ${cmd}
deactivate
fi
data_dir=${data_dir}/${lang}
if [ ${stage} -le 1 ] && [ ${stop_stage} -ge 1 ]; then
echo "stage 1: ASR Network Training"
[[ ! -d ${data_dir} ]] && echo "The data dir ${data_dir} does not exist!" && exit 1;
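The stage-0 block above builds the preparation command incrementally, appending optional flags before `eval`. The same pattern in Python, using an argument list instead of string concatenation (the helper is a hypothetical equivalent, not part of the repository):

```python
import subprocess

def build_prep_cmd(org_data_dir, data_dir, src_lang, tgt_lang,
                   vocab_type, vocab_size, speed_perturb, lcrm):
    # Mirrors the cmd="$cmd ..." accumulation in the shell script above.
    cmd = [
        "python", "examples/speech_to_text/prep_st_data.py",
        "--data-root", org_data_dir,
        "--output-root", data_dir,
        "--src-lang", src_lang,
        "--tgt-lang", tgt_lang,
        "--task", "asr",
        "--vocab-type", vocab_type,
        "--vocab-size", str(vocab_size),
    ]
    if speed_perturb:
        cmd.append("--speed-perturb")
    if lcrm:
        cmd += ["--lowercase-src", "--rm-punc-src"]
    return cmd

# Example: subprocess.run(build_prep_cmd(...), check=True)
```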
@@ -242,7 +272,7 @@ if [ ${stage} -le 2 ] && [ ${stop_stage} -ge 2 ]; then
test_subset=(${test_subset//,/ })
for subset in ${test_subset[@]}; do
subset=${subset}
subset=${subset}_asr
cmd="python ${root_dir}/fairseq_cli/generate.py
${data_dir}
--config-yaml ${data_config}
@@ -252,7 +282,11 @@ if [ ${stage} -le 2 ] && [ ${stop_stage} -ge 2 ]; then
--results-path ${model_dir}
--max-tokens ${max_tokens}
--beam ${beam_size}
--scoring wer"
--scoring wer
--wer-tokenizer 13a
--wer-lowercase
--wer-remove-punct
"
echo -e "\033[34mRun command: \n${cmd} \033[0m"
if [[ $eval -eq 1 ]]; then
......
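The added `--scoring wer` options normalize hypotheses and references before scoring. A self-contained sketch of word error rate under lowercasing and punctuation removal; note that fairseq's scorer additionally applies the sacrebleu `13a` tokenizer selected by `--wer-tokenizer`, which this sketch omits:

```python
import string

def normalize(text: str) -> list:
    # Approximates --wer-lowercase and --wer-remove-punct.
    text = text.lower().translate(str.maketrans("", "", string.punctuation))
    return text.split()

def wer(ref: str, hyp: str) -> float:
    r, h = normalize(ref), normalize(hyp)
    # Standard Levenshtein distance over words.
    d = [[0] * (len(h) + 1) for _ in range(len(r) + 1)]
    for i in range(len(r) + 1):
        d[i][0] = i
    for j in range(len(h) + 1):
        d[0][j] = j
    for i in range(1, len(r) + 1):
        for j in range(1, len(h) + 1):
            sub = d[i - 1][j - 1] + (r[i - 1] != h[j - 1])
            d[i][j] = min(sub, d[i - 1][j] + 1, d[i][j - 1] + 1)
    return d[len(r)][len(h)] / max(len(r), 1)

print(wer("Hello, world!", "hello word"))  # -> 0.5
```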
@@ -58,7 +58,6 @@ exp_name=
# config
train_config=train_ctc.yaml
data_config=config_asr.yaml
data_config=config_st_share.yaml
# training setting
fp16=1
......
@@ -137,6 +137,7 @@ class CoVoST(Dataset):
        if self.no_translation:
            print("No target translation.")
            df = cv_tsv[["path", "sentence", "client_id"]]
            df = df.set_index(["path"], drop=False)
        else:
            covost_url = self.COVOST_URL_TEMPLATE.format(
                src_lang=source_language, tgt_lang=target_language
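The `set_index(["path"], drop=False)` added above makes rows addressable by clip path while keeping `path` available as an ordinary column. A small standalone illustration with dummy data:

```python
import pandas as pd

df = pd.DataFrame({
    "path": ["a.mp3", "b.mp3"],
    "sentence": ["hi", "yo"],
    "client_id": ["c1", "c2"],
})
# drop=False keeps "path" as a column as well as the index,
# so later column-wise access to df["path"] still works.
df = df.set_index(["path"], drop=False)
print(df.loc["a.mp3", "sentence"])  # -> "hi"
```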
@@ -165,26 +166,26 @@
        self.data = []
        for e in data:
            try:
                # path = self.root / "clips" / e["path"]
                # _ = torchaudio.info(path.as_posix())
                path = self.root / "wav" / e["path"]
                _ = torchaudio.info(path.as_posix())
                self.data.append(e)
            except RuntimeError:
                pass
    def __getitem__(
        self, n: int
    ) -> Tuple[Path, int, int, str, str, Optional[str], str, str]:
    ) -> Tuple[Path, int, int, str, str, str, str]:
        """Load the n-th sample from the dataset.

        Args:
            n (int): The index of the sample to be loaded

        Returns:
            tuple: ``(waveform, sample_rate, sentence, translation, speaker_id,
            tuple: ``(wav_path, sample_rate, n_frames, sentence, translation, speaker_id,
            sample_id)``
        """
        data = self.data[n]
        path = self.root / "clips" / data["path"]
        path = self.root / "wav" / data["path"]
        info = torchaudio.info(path)
        sample_rate = info.sample_rate
        n_frames = info.num_frames
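`torchaudio.info` reads only the file header, so recording `sample_rate` and `num_frames` per clip stays cheap even for large corpora. A usage sketch (the file name is illustrative):

```python
import torchaudio

# Reads only metadata; the waveform itself is not decoded.
info = torchaudio.info("clip.wav")
print(info.sample_rate, info.num_frames, info.num_channels)

# Loading can then be restricted to a segment, as get_wav() does:
waveform, sr = torchaudio.load("clip.wav", frame_offset=0,
                               num_frames=info.num_frames)
```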
@@ -235,9 +236,9 @@ def process(args):
    # Generate TSV manifest
    print("Generating manifest...")
    train_text = []
    task = f"asr_{args.src_lang}"
    if args.tgt_lang is not None:
        task = f"st_{args.src_lang}_{args.tgt_lang}"
    task = args.task
    # if args.tgt_lang is not None:
    #     task = f"st_{args.src_lang}_{args.tgt_lang}"
    for split in CoVoST.SPLITS:
        manifest = {c: [] for c in MANIFEST_COLUMNS}
        if args.task == "st" and args.add_src:
@@ -255,7 +256,7 @@
                    src_utt = src_utt.replace(w, "")
                src_utt = src_utt.replace("  ", " ")
            manifest["tgt_text"].append(src_utt if args.tgt_lang is None else tgt_utt)
            if args.tgt_lang is not None:
            if args.task == "st" and args.add_src:
                manifest["src_text"].append(src_utt)
            manifest["speaker"].append(speaker_id)
        is_train_split = split.startswith("train")
......
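Each split's manifest is accumulated column-wise and later written out as a TSV. A minimal sketch of that pattern; the column set here is an assumption, the real `MANIFEST_COLUMNS` is defined in the preparation script:

```python
import csv

# The column set is an assumption; the real MANIFEST_COLUMNS lives in the prep script.
MANIFEST_COLUMNS = ["id", "audio", "n_frames", "tgt_text", "speaker"]

manifest = {c: [] for c in MANIFEST_COLUMNS}
manifest["id"].append("sample_0001")
manifest["audio"].append("fbank/sample_0001.npy")
manifest["n_frames"].append(523)
manifest["tgt_text"].append("hello world")
manifest["speaker"].append("spk_01")

# One row per utterance, columns in a fixed order.
with open("train_asr.tsv", "w", newline="") as f:
    writer = csv.writer(f, delimiter="\t")
    writer.writerow(MANIFEST_COLUMNS)
    writer.writerows(zip(*(manifest[c] for c in MANIFEST_COLUMNS)))
```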
@@ -46,7 +46,7 @@ class MUSTC(Dataset):
utterance_id
"""
SPLITS = ["dev", "tst-COMMON", "tst-HE", "train"]
SPLITS = ["dev", "tst-COMMON", "train"]
# SPLITS = ["train_debug", "dev"]
LANGUAGES = ["de", "es", "fr", "it", "nl", "pt", "ro", "ru"]
@@ -123,6 +123,20 @@
            items.append([waveform, sr, sp_n_frames, src_utt, tgt_utt, spk_id, sp_utt_id])
        return items
    def get_wav(self, n: int, speed_perturb=1.0):
        wav_path, offset, n_frames, sr, src_utt, tgt_utt, spk_id, utt_id = self.data[n]
        if self.speed_perturb is None or speed_perturb == 1.0:
            waveform, _ = torchaudio.load(wav_path, frame_offset=offset, num_frames=n_frames)
        else:
            waveform, _ = torchaudio.load(wav_path, frame_offset=offset, num_frames=n_frames)
            effects = [
                ["speed", f"{speed_perturb}"],
                ["rate", f"{sr}"]
            ]
            waveform, _ = torchaudio.sox_effects.apply_effects_tensor(waveform, sr, effects)
        return waveform
    def get_fast(self, n: int):
        wav_path, offset, n_frames, sr, src_utt, tgt_utt, spk_id, utt_id = self.data[n]
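In the new `get_wav`, the sox `speed` effect changes duration (and pitch) while `rate` resamples back to the original sample rate, so perturbed utterances stay compatible with the feature extractor. A standalone usage sketch with dummy audio:

```python
import torch
import torchaudio

sr = 16000
waveform = torch.randn(1, sr)  # one second of dummy audio

for speed in (0.9, 1.0, 1.1):
    effects = [["speed", f"{speed}"], ["rate", f"{sr}"]]
    out, out_sr = torchaudio.sox_effects.apply_effects_tensor(
        waveform, sr, effects)
    # Faster speeds yield proportionally fewer frames at the same rate.
    print(speed, out_sr, out.shape[1])
```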
@@ -187,13 +201,21 @@ def process(args):
print("And estimating cepstral mean and variance stats...")
gcmvn_feature_list = []
for items in tqdm(dataset):
for idx in tqdm(range(len(dataset))):
items = dataset.get_fast(idx)
for item in items:
index += 1
waveform, sr, _, _, _, _, utt_id = item
wav_path, sr, _, _, _, _, utt_id = item
features_path = (feature_root / f"{utt_id}.npy").as_posix()
features = extract_fbank_features(waveform, sr, Path(features_path))
if not os.path.exists(features_path):
sp = 1.0
if dataset.speed_perturb is not None:
sp = float(utt_id.split("_")[0].replace("sp", ""))
waveform = dataset.get_wav(idx, sp)
if waveform.shape[1] == 0:
continue
features = extract_fbank_features(waveform, sr, Path(features_path))
if split == 'train' and args.cmvn_type == "global" and not utt_id.startswith("sp"):
if len(gcmvn_feature_list) < args.gcmvn_max_num:
......
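The reworked extraction loop above computes features lazily: the `.npy` cache is checked first, and audio is decoded and fbank features extracted only on a miss. The same cache-on-miss pattern in isolation (names are placeholders):

```python
import os
import numpy as np

def cached_features(utt_id, feature_root, compute_fn):
    """Return features for utt_id, computing and caching only on a cache miss."""
    path = os.path.join(feature_root, f"{utt_id}.npy")
    if not os.path.exists(path):
        feats = compute_fn(utt_id)  # expensive: decode audio, extract fbank
        np.save(path, feats)
        return feats
    return np.load(path)
```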