Commit 71eae3fd by xuchen

Modify the preprocessing of the S2T data

parent a8105353
@@ -50,10 +50,12 @@ class MUSTC(Dataset):
     # SPLITS = ["train_debug", "dev"]
     LANGUAGES = ["de", "es", "fr", "it", "nl", "pt", "ro", "ru"]
 
-    def __init__(self, root: str, lang: str, split: str, speed_perturb: bool = False) -> None:
+    def __init__(self, root: str, lang: str, split: str, speed_perturb: bool = False, tokenizer: bool = False) -> None:
         assert split in self.SPLITS and lang in self.LANGUAGES
         _root = Path(root) / f"en-{lang}" / "data" / split
         wav_root, txt_root = _root / "wav", _root / "txt"
+        if tokenizer:
+            txt_root = _root / "txt.tok"
         assert _root.is_dir() and wav_root.is_dir() and txt_root.is_dir(), (_root, wav_root, txt_root)
         # Load audio segments
         try:
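The new tokenizer flag only redirects the text root from txt to txt.tok; the commit itself does not show how that directory is produced. Below is a minimal sketch of one way to build it, assuming Moses-style tokenization via sacremoses and the usual MUST-C layout ({split}.yaml plus {split}.en and {split}.<lang> under txt); the helper name and language pair are illustrative, not part of the commit.

# Hypothetical helper: create en-xx/data/<split>/txt.tok next to the original txt directory.
import shutil
from pathlib import Path
from sacremoses import MosesTokenizer

def build_tok_dir(split_root: Path, langs=("en", "de")) -> None:
    txt_root = split_root / "txt"
    tok_root = split_root / "txt.tok"
    tok_root.mkdir(exist_ok=True)
    for yaml_file in txt_root.glob("*.yaml"):
        # Assumption: the segment .yaml files are read from the same text root, so copy them over.
        shutil.copy(yaml_file, tok_root / yaml_file.name)
    for lang in langs:
        tokenizer = MosesTokenizer(lang=lang)
        for in_path in txt_root.glob(f"*.{lang}"):
            with open(in_path) as fin, open(tok_root / in_path.name, "w") as fout:
                for line in fin:
                    fout.write(tokenizer.tokenize(line.strip(), return_str=True) + "\n")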
@@ -162,26 +164,23 @@ def process(args):
     else:
         output_root = Path(args.output_root).absolute() / f"en-{lang}"
 
-    # Extract features
-    if args.speed_perturb:
-        feature_root = output_root / "fbank80_sp"
-    else:
-        feature_root = output_root / "fbank80"
-    feature_root.mkdir(exist_ok=True)
     if args.speed_perturb:
         zip_path = output_root / "fbank80_sp.zip"
     else:
         zip_path = output_root / "fbank80.zip"
     index = 0
 
-    gen_feature_flag = False
-    if not Path.exists(zip_path):
-        gen_feature_flag = True
-
-    if args.overwrite or gen_feature_flag:
+    # Extract features
+    if args.overwrite or not Path.exists(zip_path):
+        if args.speed_perturb:
+            feature_root = output_root / "fbank80_sp"
+        else:
+            feature_root = output_root / "fbank80"
+        feature_root.mkdir(exist_ok=True)
+
         for split in MUSTC.SPLITS:
             print(f"Fetching split {split}...")
-            dataset = MUSTC(root.as_posix(), lang, split, args.speed_perturb)
+            dataset = MUSTC(root.as_posix(), lang, split, args.speed_perturb, args.tokenizer)
             is_train_split = split.startswith("train")
             print("Extracting log mel filter bank features...")
             if is_train_split and args.cmvn_type == "global":
@@ -193,13 +192,12 @@ def process(args):
                     index += 1
                     waveform, sr, _, _, _, _, utt_id = item
-                    if gen_feature_flag:
-                        features_path = (feature_root / f"{utt_id}.npy").as_posix()
-                        features = extract_fbank_features(waveform, sr, Path(features_path))
+                    features_path = (feature_root / f"{utt_id}.npy").as_posix()
+                    features = extract_fbank_features(waveform, sr, Path(features_path))
 
-                        if split == 'train' and args.cmvn_type == "global" and not utt_id.startswith("sp"):
-                            if len(gcmvn_feature_list) < args.gcmvn_max_num:
-                                gcmvn_feature_list.append(features)
+                    if split == 'train' and args.cmvn_type == "global" and not utt_id.startswith("sp"):
+                        if len(gcmvn_feature_list) < args.gcmvn_max_num:
+                            gcmvn_feature_list.append(features)
 
                     if is_train_split and args.size != -1 and index > args.size:
                         break
 
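With --cmvn-type global, the loop above only collects up to --gcmvn-max-num feature matrices from the original (non-speed-perturbed) training utterances; the per-dimension statistics are computed from that sample afterwards. A minimal numpy sketch of what those global CMVN statistics amount to, assuming each collected entry is a (num_frames, 80) log-mel array; the function name is illustrative.

import numpy as np

def compute_gcmvn_stats(gcmvn_feature_list):
    # Stack the sampled utterances along the time axis and take per-dimension stats.
    stacked = np.concatenate(gcmvn_feature_list, axis=0)  # (total_frames, 80)
    mean = stacked.mean(axis=0).astype("float32")
    std = stacked.std(axis=0).astype("float32")
    return {"mean": mean, "std": std}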
@@ -214,6 +212,9 @@ def process(args):
         print("ZIPing features...")
         create_zip(feature_root, zip_path)
 
+        # # Clean up
+        # shutil.rmtree(feature_root)
+
     gen_manifest_flag = False
     for split in MUSTC.SPLITS:
         if not Path.exists(output_root / f"{split}_{args.task}.tsv"):
@@ -232,7 +233,7 @@ def process(args):
         manifest = {c: [] for c in MANIFEST_COLUMNS}
         if args.task == "st" and args.add_src:
             manifest["src_text"] = []
-        dataset = MUSTC(args.data_root, lang, split, args.speed_perturb)
+        dataset = MUSTC(args.data_root, lang, split, args.speed_perturb, args.tokenizer)
         for idx in range(len(dataset)):
             items = dataset.get_fast(idx)
             for item in items:
@@ -262,23 +263,11 @@ def process(args):
         df = filter_manifest_df(df, is_train_split=is_train_split)
         save_df_to_tsv(df, output_root / f"{split}_{args.task}.tsv")
 
-    # Generate vocab
-    v_size_str = "" if args.vocab_type == "char" else str(args.vocab_size)
-    spm_filename_prefix = f"spm_{args.vocab_type}{v_size_str}_{args.task}"
-    if args.task == "st" and args.add_src:
-        if args.share:
-            spm_filename_prefix = f"spm_{args.vocab_type}{v_size_str}_{args.task}_share"
-            asr_spm_filename = spm_filename_prefix + ".model"
-        else:
-            asr_spm_filename = args.asr_prefix + ".model"
-    else:
-        asr_spm_filename = None
-
     if len(train_text) == 0:
         print("Loading the training text to build dictionary...")
         for split in MUSTC.SPLITS:
             if split.startswith("train"):
-                dataset = MUSTC(args.data_root, lang, split)
+                dataset = MUSTC(args.data_root, lang, split, args.speed_perturb, args.tokenizer)
                 src_text = dataset.get_src_text()
                 tgt_text = dataset.get_tgt_text()
                 for src_utt, tgt_utt in zip(src_text, tgt_text):
@@ -292,6 +281,18 @@ def process(args):
                     train_text.append(src_utt)
                     train_text.append(tgt_utt)
 
+    # Generate vocab
+    v_size_str = "" if args.vocab_type == "char" else str(args.vocab_size)
+    spm_filename_prefix = f"spm_{args.vocab_type}{v_size_str}_{args.task}"
+    if args.task == "st" and args.add_src:
+        if args.share:
+            spm_filename_prefix = f"spm_{args.vocab_type}{v_size_str}_{args.task}_share"
+            asr_spm_filename = spm_filename_prefix + ".model"
+        else:
+            asr_spm_filename = args.asr_prefix + ".model"
+    else:
+        asr_spm_filename = None
+
     with NamedTemporaryFile(mode="w") as f:
         for t in train_text:
             f.write(t + "\n")
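gen_vocab trains a SentencePiece model on the text written to the temporary file above. A rough sketch of the equivalent direct call, with a hypothetical input file and the prefix spelled out the way the script builds it from --vocab-type, --vocab-size and --task:

import sentencepiece as spm

spm.SentencePieceTrainer.train(
    input="train_text.txt",              # one sentence per line, as written to the temp file
    model_prefix="spm_unigram10000_st",  # hypothetical prefix; produces .model and .vocab files
    model_type="unigram",                # --vocab-type
    vocab_size=10000,                    # --vocab-size
)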
@@ -320,9 +321,6 @@ def process(args):
         share_src_and_tgt=True if args.task == "asr" else False
     )
 
-    # Clean up
-    shutil.rmtree(feature_root)
-
 
 def process_joint(args):
     cur_root = Path(args.data_root)
@@ -392,6 +390,7 @@ def main():
     parser.add_argument("--asr-prefix", type=str, help="prefix of the asr dict")
     parser.add_argument("--lowercase-src", action="store_true", help="lowercase the source text")
     parser.add_argument("--rm-punc-src", action="store_true", help="remove the punctuation of the source text")
+    parser.add_argument("--tokenizer", action="store_true", help="use tokenizer txt")
     parser.add_argument("--cmvn-type", default="utterance",
                         choices=["global", "utterance"],
                         help="The type of cepstral mean and variance normalization")
@@ -11,9 +11,7 @@ from pathlib import Path
 import shutil
 from itertools import groupby
 from tempfile import NamedTemporaryFile
-from typing import Tuple
 import string
-import pickle
 
 import numpy as np
 import pandas as pd
@@ -74,11 +72,11 @@ class ST_Dataset(Dataset):
                sample_rate = torchaudio.info(wav_path.as_posix())[0].rate
            except TypeError:
                sample_rate = torchaudio.info(wav_path.as_posix()).sample_rate
-           seg_group = sorted(_seg_group, key=lambda x: x["offset"])
+           seg_group = sorted(_seg_group, key=lambda x: float(x["offset"]))
            for i, segment in enumerate(seg_group):
                offset = int(float(segment["offset"]) * sample_rate)
                n_frames = int(float(segment["duration"]) * sample_rate)
-               _id = f"{wav_path.stem}_{i}"
+               _id = f"{split}_{wav_path.stem}_{i}"
                self.data.append(
                    (
                        wav_path.as_posix(),
@@ -87,7 +85,7 @@
                        sample_rate,
                        segment[src_lang],
                        segment[tgt_lang],
-                       segment["speaker_id"],
+                       segment["speaker_id"] if "speaker_id" in segment else "spk1",
                        _id,
                    )
                )
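Two small dataset-side fixes above are easy to miss: segments are now ordered by the numeric value of their offset rather than by the raw string, and segments whose YAML entry has no speaker_id fall back to a dummy "spk1" speaker. A tiny illustration of why the float() cast matters, with made-up offsets:

# Lexicographic sorting of the raw strings puts "10.20" before "9.98",
# i.e. the later segment would come first; casting to float restores time order.
segs = [{"offset": "10.20"}, {"offset": "9.98"}]
print(sorted(segs, key=lambda x: x["offset"]))         # [{'offset': '10.20'}, {'offset': '9.98'}]
print(sorted(segs, key=lambda x: float(x["offset"])))  # [{'offset': '9.98'}, {'offset': '10.20'}]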
@@ -188,7 +186,7 @@ def process(args):
            for items in tqdm(dataset):
                for item in items:
                    index += 1
-                   waveform, sr, _, _, _, utt_id = item
+                   waveform, sr, _, _, _, _, utt_id = item
                    if gen_feature_flag:
                        features_path = (feature_root / f"{utt_id}.npy").as_posix()
@@ -259,41 +257,55 @@ def process(args):
     # Generate vocab
     v_size_str = "" if args.vocab_type == "char" else str(args.vocab_size)
-    spm_filename_prefix = f"spm_{args.vocab_type}{v_size_str}_{args.task}"
+    asr_spm_filename = None
     if args.task == "st" and args.add_src:
         if args.share:
-            spm_filename_prefix = f"spm_{args.vocab_type}{v_size_str}_{args.task}_share"
+            if args.st_spm_prefix is not None:
+                spm_filename_prefix = args.st_spm_prefix
+            else:
+                spm_filename_prefix = f"spm_{args.vocab_type}{v_size_str}_{args.task}_share"
             asr_spm_filename = spm_filename_prefix + ".model"
         else:
+            if args.st_spm_prefix is not None:
+                spm_filename_prefix = args.st_spm_prefix
+            else:
+                spm_filename_prefix = f"spm_{args.vocab_type}{v_size_str}_{args.task}"
+            assert args.asr_prefix is not None
            asr_spm_filename = args.asr_prefix + ".model"
-    else:
-        asr_spm_filename = None
+    elif args.task == "asr":
+        if args.asr_prefix is not None:
+            spm_filename_prefix = args.asr_prefix
+        else:
+            spm_filename_prefix = f"spm_{args.vocab_type}{v_size_str}_{args.task}"
 
-    if len(train_text) == 0:
-        print("Loading the training text to build dictionary...")
-        for split in splits:
-            if split.startswith("train"):
-                dataset = ST_Dataset(args.data_root, src_lang, tgt_lang, split)
-                src_text = dataset.get_src_text()
-                tgt_text = dataset.get_tgt_text()
-                for src_utt, tgt_utt in zip(src_text, tgt_text):
-                    if args.task == "st" and args.add_src and args.share:
-                        if args.lowercase_src:
-                            src_utt = src_utt.lower()
-                        if args.rm_punc_src:
-                            src_utt = src_utt.translate(None, string.punctuation)
-                    train_text.append(src_utt)
-                    train_text.append(tgt_utt)
-
-    with NamedTemporaryFile(mode="w") as f:
-        for t in train_text:
-            f.write(t + "\n")
-        gen_vocab(
-            Path(f.name),
-            output_root / spm_filename_prefix,
-            args.vocab_type,
-            args.vocab_size,
-        )
+    if args.st_spm_prefix is None:
+        if len(train_text) == 0:
+            print("Loading the training text to build dictionary...")
+            for split in splits:
+                if split.startswith("train"):
+                    dataset = ST_Dataset(args.data_root, src_lang, tgt_lang, split)
+                    src_text = dataset.get_src_text()
+                    tgt_text = dataset.get_tgt_text()
+                    for src_utt, tgt_utt in zip(src_text, tgt_text):
+                        if args.task == "st" and args.add_src and args.share:
+                            if args.lowercase_src:
+                                src_utt = src_utt.lower()
+                            if args.rm_punc_src:
+                                src_utt = src_utt.translate(None, string.punctuation)
+                        train_text.append(src_utt)
+                        train_text.append(tgt_utt)
+
+        with NamedTemporaryFile(mode="w") as f:
+            for t in train_text:
+                f.write(t + "\n")
+            gen_vocab(
+                Path(f.name),
+                output_root / spm_filename_prefix,
+                args.vocab_type,
+                args.vocab_size,
+            )
 
     # Generate config YAML
     yaml_filename = f"config_{args.task}.yaml"
     if args.task == "st" and args.add_src and args.share:
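When the transcription and translation share a vocabulary (--add-src --share), the source side is optionally lowercased and stripped of punctuation before being added to the training text. A minimal Python 3 sketch of that normalization; note that str.translate in Python 3 expects a translation table, whereas the two-argument translate(None, string.punctuation) form in the hunk above is the Python 2 signature.

import string

def normalize_src(utt: str, lowercase: bool = True, rm_punc: bool = True) -> str:
    if lowercase:
        utt = utt.lower()
    if rm_punc:
        # Strip string.punctuation, written against the Python 3 str.translate API.
        utt = utt.translate(str.maketrans("", "", string.punctuation))
    return utt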
@@ -324,7 +336,6 @@ def main():
     parser.add_argument(
         "--vocab-type",
         default="unigram",
-        required=True,
         type=str,
         choices=["bpe", "unigram", "char"],
     ),
@@ -339,7 +350,8 @@ def main():
     parser.add_argument("--share", action="store_true",
                         help="share the tokenizer and dictionary of the transcription and translation")
     parser.add_argument("--add-src", action="store_true", help="add the src text for st task")
-    parser.add_argument("--asr-prefix", type=str, help="prefix of the asr dict")
+    parser.add_argument("--asr-prefix", type=str, default=None, help="prefix of the asr dict")
+    parser.add_argument("--st-spm-prefix", type=str, default=None, help="prefix of the existing st dict")
     parser.add_argument("--lowercase-src", action="store_true", help="lowercase the source text")
     parser.add_argument("--rm-punc-src", action="store_true", help="remove the punctuation of the source text")
     parser.add_argument("--cmvn-type", default="utterance",