Commit 28d33ad8 by xuchen

modify the preprocessing of s2t

parent c02e267c
...@@ -29,19 +29,27 @@ pwd_dir=$PWD ...@@ -29,19 +29,27 @@ pwd_dir=$PWD
# dataset # dataset
src_lang=swa src_lang=swa
lang=${src_lang} tgt_lang=en
lang=${src_lang}-${tgt_lang}
dataset=lower dataset=lower
task=speech_to_text task=speech_to_text
vocab_type=unigram vocab_type=unigram
vocab_size=10000 vocab_size=1000
speed_perturb=0 speed_perturb=1
lcrm=1
org_data_dir=/media/data/${dataset} use_specific_dict=0
data_dir=~/st/data/${dataset} specific_prefix=valid
specific_dir=/home/xuchen/st/data/mustc/st_lcrm/en-de
asr_vocab_prefix=spm_unigram10000_st_share
org_data_dir=~/st/data/${dataset}/asr
data_dir=~/st/data/${dataset}/asr
test_subset=test test_subset=test
# exp # exp
exp_prefix=${time}
extra_tag= extra_tag=
extra_parameter= extra_parameter=
exp_tag=baseline exp_tag=baseline
...@@ -49,7 +57,7 @@ exp_name= ...@@ -49,7 +57,7 @@ exp_name=
# config # config
train_config=train_ctc.yaml train_config=train_ctc.yaml
data_config=config.yaml data_config=config_asr.yaml
# training setting # training setting
fp16=1 fp16=1
...@@ -62,6 +70,15 @@ beam_size=5 ...@@ -62,6 +70,15 @@ beam_size=5
if [[ ${speed_perturb} -eq 1 ]]; then if [[ ${speed_perturb} -eq 1 ]]; then
data_dir=${data_dir}_sp data_dir=${data_dir}_sp
exp_prefix=${exp_prefix}_sp
fi
if [[ ${lcrm} -eq 1 ]]; then
data_dir=${data_dir}_lcrm
exp_prefix=${exp_prefix}_lcrm
fi
if [[ ${use_specific_dict} -eq 1 ]]; then
data_dir=${data_dir}_${specific_prefix}
exp_prefix=${exp_prefix}_${specific_prefix}
fi fi
. ./local/parse_options.sh || exit 1; . ./local/parse_options.sh || exit 1;
...@@ -69,13 +86,10 @@ fi ...@@ -69,13 +86,10 @@ fi
# full path # full path
train_config=$pwd_dir/conf/${train_config} train_config=$pwd_dir/conf/${train_config}
if [[ -z ${exp_name} ]]; then if [[ -z ${exp_name} ]]; then
exp_name=$(basename ${train_config%.*})_${exp_tag} exp_name=${exp_prefix}_$(basename ${train_config%.*})_${exp_tag}
if [[ -n ${extra_tag} ]]; then if [[ -n ${extra_tag} ]]; then
exp_name=${exp_name}_${extra_tag} exp_name=${exp_name}_${extra_tag}
fi fi
if [[ ${speed_perturb} -eq 1 ]]; then
exp_name=sp_${exp_name}
fi
fi fi
model_dir=$root_dir/../checkpoints/$dataset/asr/${exp_name} model_dir=$root_dir/../checkpoints/$dataset/asr/${exp_name}
...@@ -87,26 +101,42 @@ fi ...@@ -87,26 +101,42 @@ fi
if [ ${stage} -le 0 ] && [ ${stop_stage} -ge 0 ]; then if [ ${stage} -le 0 ] && [ ${stop_stage} -ge 0 ]; then
### Task dependent. You have to make data the following preparation part by yourself. ### Task dependent. You have to make data the following preparation part by yourself.
### But you can utilize Kaldi recipes in most cases ### But you can utilize Kaldi recipes in most cases
echo "stage 0: Data Preparation" echo "stage 0: ASR Data Preparation"
if [[ ! -e ${data_dir}/${lang} ]]; then
if [[ ! -e ${data_dir} ]]; then mkdir -p ${data_dir}/${lang}
mkdir -p ${data_dir}
fi fi
source ~/tools/audio/bin/activate source ~/tools/audio/bin/activate
cmd="python ${root_dir}/examples/speech_to_text/prep_librispeech_data.py cmd="python ${root_dir}/examples/speech_to_text/prep_st_data.py
--data-root ${org_data_dir} --data-root ${org_data_dir}
--output-root ${data_dir} --output-root ${data_dir}
--src-lang ${src_lang}
--tgt-lang ${tgt_lang}
--task asr
--vocab-type ${vocab_type} --vocab-type ${vocab_type}
--vocab-size ${vocab_size}" --vocab-size ${vocab_size}"
if [[ ${use_specific_dict} -eq 1 ]]; then
cp -r ${specific_dir}/${asr_vocab_prefix}.* ${data_dir}/${lang}
cmd="$cmd
--asr-prefix ${asr_vocab_prefix}"
fi
if [[ ${speed_perturb} -eq 1 ]]; then if [[ ${speed_perturb} -eq 1 ]]; then
cmd="$cmd cmd="$cmd
--speed-perturb" --speed-perturb"
fi fi
if [[ ${lcrm} -eq 1 ]]; then
cmd="$cmd
--lowercase-src
--rm-punc-src"
fi
echo -e "\033[34mRun command: \n${cmd} \033[0m" echo -e "\033[34mRun command: \n${cmd} \033[0m"
[[ $eval -eq 1 ]] && eval $cmd [[ $eval -eq 1 ]] && eval ${cmd}
deactivate
fi fi
data_dir=${data_dir}/${lang}
if [ ${stage} -le 1 ] && [ ${stop_stage} -ge 1 ]; then if [ ${stage} -le 1 ] && [ ${stop_stage} -ge 1 ]; then
echo "stage 1: ASR Network Training" echo "stage 1: ASR Network Training"
[[ ! -d ${data_dir} ]] && echo "The data dir ${data_dir} is not existing!" && exit 1; [[ ! -d ${data_dir} ]] && echo "The data dir ${data_dir} is not existing!" && exit 1;
...@@ -242,7 +272,7 @@ if [ ${stage} -le 2 ] && [ ${stop_stage} -ge 2 ]; then ...@@ -242,7 +272,7 @@ if [ ${stage} -le 2 ] && [ ${stop_stage} -ge 2 ]; then
test_subset=(${test_subset//,/ }) test_subset=(${test_subset//,/ })
for subset in ${test_subset[@]}; do for subset in ${test_subset[@]}; do
subset=${subset} subset=${subset}_asr
cmd="python ${root_dir}/fairseq_cli/generate.py cmd="python ${root_dir}/fairseq_cli/generate.py
${data_dir} ${data_dir}
--config-yaml ${data_config} --config-yaml ${data_config}
...@@ -252,7 +282,11 @@ if [ ${stage} -le 2 ] && [ ${stop_stage} -ge 2 ]; then ...@@ -252,7 +282,11 @@ if [ ${stage} -le 2 ] && [ ${stop_stage} -ge 2 ]; then
--results-path ${model_dir} --results-path ${model_dir}
--max-tokens ${max_tokens} --max-tokens ${max_tokens}
--beam ${beam_size} --beam ${beam_size}
--scoring wer" --scoring wer
--wer-tokenizer 13a
--wer-lowercase
--wer-remove-punct
"
echo -e "\033[34mRun command: \n${cmd} \033[0m" echo -e "\033[34mRun command: \n${cmd} \033[0m"
if [[ $eval -eq 1 ]]; then if [[ $eval -eq 1 ]]; then
......
...@@ -58,7 +58,6 @@ exp_name= ...@@ -58,7 +58,6 @@ exp_name=
# config # config
train_config=train_ctc.yaml train_config=train_ctc.yaml
data_config=config_asr.yaml data_config=config_asr.yaml
data_config=config_st_share.yaml
# training setting # training setting
fp16=1 fp16=1
......
...@@ -137,6 +137,7 @@ class CoVoST(Dataset): ...@@ -137,6 +137,7 @@ class CoVoST(Dataset):
if self.no_translation: if self.no_translation:
print("No target translation.") print("No target translation.")
df = cv_tsv[["path", "sentence", "client_id"]] df = cv_tsv[["path", "sentence", "client_id"]]
df = df.set_index(["path"], drop=False)
else: else:
covost_url = self.COVOST_URL_TEMPLATE.format( covost_url = self.COVOST_URL_TEMPLATE.format(
src_lang=source_language, tgt_lang=target_language src_lang=source_language, tgt_lang=target_language
...@@ -165,26 +166,26 @@ class CoVoST(Dataset): ...@@ -165,26 +166,26 @@ class CoVoST(Dataset):
self.data = [] self.data = []
for e in data: for e in data:
try: try:
# path = self.root / "clips" / e["path"] path = self.root / "wav" / e["path"]
# _ = torchaudio.info(path.as_posix()) _ = torchaudio.info(path.as_posix())
self.data.append(e) self.data.append(e)
except RuntimeError: except RuntimeError:
pass pass
def __getitem__( def __getitem__(
self, n: int self, n: int
) -> Tuple[Path, int, int, str, str, Optional[str], str, str]: ) -> Tuple[Path, int, int, str, str, str, str]:
"""Load the n-th sample from the dataset. """Load the n-th sample from the dataset.
Args: Args:
n (int): The index of the sample to be loaded n (int): The index of the sample to be loaded
Returns: Returns:
tuple: ``(waveform, sample_rate, sentence, translation, speaker_id, tuple: ``(wav_path, sample_rate, n_frames, sentence, translation, speaker_id,
sample_id)`` sample_id)``
""" """
data = self.data[n] data = self.data[n]
path = self.root / "clips" / data["path"] path = self.root / "wav" / data["path"]
info = torchaudio.info(path) info = torchaudio.info(path)
sample_rate = info.sample_rate sample_rate = info.sample_rate
n_frames = info.num_frames n_frames = info.num_frames
...@@ -235,9 +236,9 @@ def process(args): ...@@ -235,9 +236,9 @@ def process(args):
# Generate TSV manifest # Generate TSV manifest
print("Generating manifest...") print("Generating manifest...")
train_text = [] train_text = []
task = f"asr_{args.src_lang}" task = args.task
if args.tgt_lang is not None: # if args.tgt_lang is not None:
task = f"st_{args.src_lang}_{args.tgt_lang}" # task = f"st_{args.src_lang}_{args.tgt_lang}"
for split in CoVoST.SPLITS: for split in CoVoST.SPLITS:
manifest = {c: [] for c in MANIFEST_COLUMNS} manifest = {c: [] for c in MANIFEST_COLUMNS}
if args.task == "st" and args.add_src: if args.task == "st" and args.add_src:
...@@ -255,7 +256,7 @@ def process(args): ...@@ -255,7 +256,7 @@ def process(args):
src_utt = src_utt.replace(w, "") src_utt = src_utt.replace(w, "")
src_utt = src_utt.replace(" ", "") src_utt = src_utt.replace(" ", "")
manifest["tgt_text"].append(src_utt if args.tgt_lang is None else tgt_utt) manifest["tgt_text"].append(src_utt if args.tgt_lang is None else tgt_utt)
if args.tgt_lang is not None: if args.task == "st" and args.add_src:
manifest["src_text"].append(src_utt) manifest["src_text"].append(src_utt)
manifest["speaker"].append(speaker_id) manifest["speaker"].append(speaker_id)
is_train_split = split.startswith("train") is_train_split = split.startswith("train")
......
...@@ -46,7 +46,7 @@ class MUSTC(Dataset): ...@@ -46,7 +46,7 @@ class MUSTC(Dataset):
utterance_id utterance_id
""" """
SPLITS = ["dev", "tst-COMMON", "tst-HE", "train"] SPLITS = ["dev", "tst-COMMON", "train"]
# SPLITS = ["train_debug", "dev"] # SPLITS = ["train_debug", "dev"]
LANGUAGES = ["de", "es", "fr", "it", "nl", "pt", "ro", "ru"] LANGUAGES = ["de", "es", "fr", "it", "nl", "pt", "ro", "ru"]
...@@ -123,6 +123,20 @@ class MUSTC(Dataset): ...@@ -123,6 +123,20 @@ class MUSTC(Dataset):
items.append([waveform, sr, sp_n_frames, src_utt, tgt_utt, spk_id, sp_utt_id]) items.append([waveform, sr, sp_n_frames, src_utt, tgt_utt, spk_id, sp_utt_id])
return items return items
def get_wav(self, n: int, speed_perturb=1.0):
wav_path, offset, n_frames, sr, src_utt, tgt_utt, spk_id, utt_id = self.data[n]
if self.speed_perturb is None or speed_perturb == 1.0:
waveform, _ = torchaudio.load(wav_path, frame_offset=offset, num_frames=n_frames)
else:
waveform, _ = torchaudio.load(wav_path, frame_offset=offset, num_frames=n_frames)
effects = [
["speed", f"{speed_perturb}"],
["rate", f"{sr}"]
]
waveform, _ = torchaudio.sox_effects.apply_effects_tensor(waveform, sr, effects)
return waveform
def get_fast(self, n: int): def get_fast(self, n: int):
wav_path, offset, n_frames, sr, src_utt, tgt_utt, spk_id, utt_id = self.data[n] wav_path, offset, n_frames, sr, src_utt, tgt_utt, spk_id, utt_id = self.data[n]
...@@ -187,13 +201,21 @@ def process(args): ...@@ -187,13 +201,21 @@ def process(args):
print("And estimating cepstral mean and variance stats...") print("And estimating cepstral mean and variance stats...")
gcmvn_feature_list = [] gcmvn_feature_list = []
for items in tqdm(dataset): for idx in tqdm(range(len(dataset))):
items = dataset.get_fast(idx)
for item in items: for item in items:
index += 1 index += 1
waveform, sr, _, _, _, _, utt_id = item wav_path, sr, _, _, _, _, utt_id = item
features_path = (feature_root / f"{utt_id}.npy").as_posix() features_path = (feature_root / f"{utt_id}.npy").as_posix()
features = extract_fbank_features(waveform, sr, Path(features_path)) if not os.path.exists(features_path):
sp = 1.0
if dataset.speed_perturb is not None:
sp = float(utt_id.split("_")[0].replace("sp", ""))
waveform = dataset.get_wav(idx, sp)
if waveform.shape[1] == 0:
continue
features = extract_fbank_features(waveform, sr, Path(features_path))
if split == 'train' and args.cmvn_type == "global" and not utt_id.startswith("sp"): if split == 'train' and args.cmvn_type == "global" and not utt_id.startswith("sp"):
if len(gcmvn_feature_list) < args.gcmvn_max_num: if len(gcmvn_feature_list) < args.gcmvn_max_num:
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论