#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import argparse
import logging
import os
from pathlib import Path
import shutil
from itertools import groupby
from tempfile import NamedTemporaryFile
import string
import csv

import numpy as np
import pandas as pd
import torchaudio
from examples.speech_to_text.data_utils import (
    create_zip,
    extract_fbank_features,
    filter_manifest_df,
    gen_config_yaml,
    gen_vocab,
    get_zip_manifest,
    load_df_from_tsv,
    save_df_to_tsv,
    cal_gcmvn_stats,
)
from torch.utils.data import Dataset
from tqdm import tqdm


log = logging.getLogger(__name__)


MANIFEST_COLUMNS = ["id", "audio", "n_frames", "tgt_text", "speaker"]

# Translation table that deletes every ASCII punctuation character in one pass.
_PUNCTUATION_TABLE = str.maketrans("", "", string.punctuation)


def _normalize_src_text(text, lowercase, rm_punc):
    """Return *text* optionally lowercased and stripped of ASCII punctuation."""
    if lowercase:
        text = text.lower()
    if rm_punc:
        text = text.translate(_PUNCTUATION_TABLE)
    return text


class ST_Dataset(Dataset):
    """
    Create a Dataset for MuST-C.

    Indexing the dataset returns a *list* of items (one per speed-perturbation
    factor, or a single item when perturbation is disabled). Each item is a
    list of the form:
    waveform, sample_rate, n_frames, source utterance, target utterance,
    speaker_id, utterance_id
    """

    def __init__(self, root: str, src_lang: str, tgt_lang: str, split: str,
                 speed_perturb: bool = False, tokenizer: bool = False) -> None:
        """Index every segment of ``<root>/<src_lang>-<tgt_lang>/<split>``.

        Args:
            root: directory that contains the ``<src>-<tgt>`` language-pair folder.
            src_lang: source language code (also the transcript file suffix).
            tgt_lang: target language code (also the translation file suffix).
            split: split name; it is both a sub-directory and a file stem.
            speed_perturb: when true and ``split`` starts with ``"train"``,
                every segment is exposed at speeds 0.9, 1.0 and 1.1.
            tokenizer: read text from ``txt.tok`` instead of ``txt``.

        Raises:
            AssertionError: if an expected directory is missing or the number
                of utterances differs from the number of YAML segments.
            ImportError: if PyYAML is not installed.
        """
        _root = Path(root) / f"{src_lang}-{tgt_lang}" / split
        wav_root, txt_root = _root / "wav", _root / "txt"
        if tokenizer:
            txt_root = _root / "txt.tok"
        assert _root.is_dir() and wav_root.is_dir() and txt_root.is_dir(), (_root, wav_root, txt_root)

        # Load audio segments
        try:
            import yaml
        except ImportError as err:
            # Fail here with a clear message instead of a NameError below.
            raise ImportError(
                "Please install PyYAML to load the MuST-C YAML files"
            ) from err
        with open(txt_root / f"{split}.yaml") as f:
            # BaseLoader keeps every scalar as a string; numbers are parsed below.
            segments = yaml.load(f, Loader=yaml.BaseLoader)

        self.speed_perturb = [0.9, 1.0, 1.1] if speed_perturb and split.startswith("train") else None

        # Load source and target utterances
        for _lang in [src_lang, tgt_lang]:
            with open(txt_root / f"{split}.{_lang}") as f:
                utterances = [r.strip() for r in f]
            assert len(segments) == len(utterances)
            for i, u in enumerate(utterances):
                segments[i][_lang] = u

        # Gather info
        self.data = []
        for wav_filename, _seg_group in groupby(segments, lambda x: x["wav"]):
            wav_path = wav_root / wav_filename
            # Two torchaudio.info return shapes are supported: an indexable
            # result exposing ``.rate`` and an object exposing ``.sample_rate``.
            try:
                sample_rate = torchaudio.info(wav_path.as_posix())[0].rate
            except TypeError:
                sample_rate = torchaudio.info(wav_path.as_posix()).sample_rate
            seg_group = sorted(_seg_group, key=lambda x: float(x["offset"]))
            for i, segment in enumerate(seg_group):
                # Offsets/durations are multiplied by the sample rate to get
                # sample counts (presumably the YAML stores seconds).
                offset = int(float(segment["offset"]) * sample_rate)
                n_frames = int(float(segment["duration"]) * sample_rate)
                _id = f"{split}_{wav_path.stem}_{i}"
                self.data.append(
                    (
                        wav_path.as_posix(),
                        offset,
                        n_frames,
                        sample_rate,
                        segment[src_lang],
                        segment[tgt_lang],
                        segment["speaker_id"] if "speaker_id" in segment else "spk1",
                        _id,
                    )
                )

    def __getitem__(self, n: int):
        """Load segment *n* and return a list of items, one per speed factor.

        Without speed perturbation the list has a single entry. With it, each
        entry's id is prefixed with ``sp<speed>_`` and its frame count is the
        original count divided by the speed (a float).
        """
        _, _, n_frames, sr, src_utt, tgt_utt, spk_id, utt_id = self.data[n]
        if self.speed_perturb is None:
            return [[self.get_wav(n), sr, n_frames, src_utt, tgt_utt, spk_id, utt_id]]

        items = []
        for speed in self.speed_perturb:
            sp_utt_id = f"sp{speed}_" + utt_id
            sp_n_frames = n_frames / speed
            waveform = self.get_wav(n, speed)
            items.append([waveform, sr, sp_n_frames, src_utt, tgt_utt, spk_id, sp_utt_id])
        return items

    def get_wav(self, n: int, speed_perturb=1.0):
        """Return the waveform of segment *n*, optionally speed-perturbed.

        The sox ``speed`` effect followed by ``rate`` is applied only when the
        dataset has perturbation enabled and ``speed_perturb`` is not 1.0.
        """
        wav_path, offset, n_frames, sr, _, _, _, _ = self.data[n]
        waveform, _ = torchaudio.load(wav_path, frame_offset=offset, num_frames=n_frames)
        if self.speed_perturb is None or speed_perturb == 1.0:
            return waveform
        effects = [
            ["speed", f"{speed_perturb}"],
            ["rate", f"{sr}"],
        ]
        waveform, _ = torchaudio.sox_effects.apply_effects_tensor(waveform, sr, effects)
        return waveform

    def get_fast(self, n: int):
        """Like ``__getitem__`` but without reading audio.

        Each returned item carries the wav path in place of the waveform.
        """
        wav_path, _, n_frames, sr, src_utt, tgt_utt, spk_id, utt_id = self.data[n]
        if self.speed_perturb is None:
            return [[wav_path, sr, n_frames, src_utt, tgt_utt, spk_id, utt_id]]
        return [
            [wav_path, sr, n_frames / speed, src_utt, tgt_utt, spk_id, f"sp{speed}_" + utt_id]
            for speed in self.speed_perturb
        ]

    def get_src_text(self):
        """Return the source utterance of every segment, in dataset order."""
        return [item[4] for item in self.data]

    def get_tgt_text(self):
        """Return the target utterance of every segment, in dataset order."""
        return [item[5] for item in self.data]

    def __len__(self) -> int:
        return len(self.data)


def _load_train_text_from_tsv(tsv_path, args, include_src):
    """Read vocabulary training text back from an existing manifest TSV.

    Returns the source sentences (only when *include_src*) followed by the
    target sentences. The file is read in a single pass: a ``csv.DictReader``
    can only be iterated once.
    """
    src_text, tgt_text = [], []
    with open(tsv_path) as f:
        reader = csv.DictReader(
            f,
            delimiter="\t",
            quotechar=None,
            doublequote=False,
            lineterminator="\n",
            quoting=csv.QUOTE_NONE,
        )
        for row in reader:
            if include_src:
                src_utt = _normalize_src_text(
                    row["src_text"], args.lowercase_src, args.rm_punc_src
                )
                if args.rm_punc_src:
                    # NOTE(review): this removes *every* space from the source
                    # sentence; confirm the literal was not meant to collapse
                    # doubled spaces instead.
                    src_utt = src_utt.replace(" ", "")
                src_text.append(src_utt)
            tgt_text.append(row["tgt_text"])
    return src_text + tgt_text


def process(args):
    """Run the full preparation pipeline for one language pair.

    Steps: extract filter-bank features into a ZIP (skipped when the ZIP
    exists and ``--overwrite`` is not set), write one TSV manifest per split,
    build the vocabulary unless an existing prefix was supplied, and write
    the config YAML.

    Args:
        args: the namespace produced by ``main``'s argument parser.
    """
    root = Path(args.data_root).absolute()
    splits = args.splits.split(",")
    src_lang = args.src_lang
    tgt_lang = args.tgt_lang
    cur_root = root / f"{src_lang}-{tgt_lang}"
    if not cur_root.is_dir():
        print(f"{cur_root.as_posix()} does not exist. Skipped.")
        return
    if args.output_root is None:
        output_root = cur_root
    else:
        output_root = Path(args.output_root).absolute() / f"{src_lang}-{tgt_lang}"

    # Extract features
    feature_dirname = "fbank80_sp" if args.speed_perturb else "fbank80"
    zip_path = output_root / f"{feature_dirname}.zip"
    index = 0  # utterances visited so far, accumulated across splits

    if args.overwrite or not zip_path.exists():
        feature_root = output_root / feature_dirname
        # parents=True: a user-supplied output root may not exist yet.
        feature_root.mkdir(parents=True, exist_ok=True)

        for split in splits:
            print(f"Fetching split {split}...")
            dataset = ST_Dataset(root.as_posix(), src_lang, tgt_lang, split, args.speed_perturb, args.tokenizer)
            is_train_split = split.startswith("train")
            print("Extracting log mel filter bank features...")
            gcmvn_feature_list = []
            if is_train_split and args.cmvn_type == "global":
                print("And estimating cepstral mean and variance stats...")

            for idx in tqdm(range(len(dataset))):
                for item in dataset.get_fast(idx):
                    index += 1
                    _, sr, _, _, _, _, utt_id = item
                    features_path = feature_root / f"{utt_id}.npy"

                    # NOTE(review): with speed perturbation every id starts
                    # with "sp" (including "sp1.0_"), so nothing is collected
                    # for the global statistics in that mode -- confirm intent.
                    collect_gcmvn = (
                        split == 'train'
                        and args.cmvn_type == "global"
                        and not utt_id.startswith("sp")
                        and len(gcmvn_feature_list) < args.gcmvn_max_num
                    )

                    if os.path.exists(features_path.as_posix()):
                        if not collect_gcmvn:
                            continue
                        # Reuse features cached by an earlier run; assumes
                        # extract_fbank_features stored a plain .npy array --
                        # TODO confirm.
                        features = np.load(features_path)
                    else:
                        sp = 1.0
                        if dataset.speed_perturb is not None:
                            # Ids look like "sp<speed>_<split>_..."; recover the speed.
                            sp = float(utt_id.split("_")[0].replace("sp", ""))
                        waveform = dataset.get_wav(idx, sp)
                        if waveform.shape[1] == 0:
                            continue
                        features = extract_fbank_features(waveform, sr, features_path)

                    if collect_gcmvn:
                        gcmvn_feature_list.append(features)

                if is_train_split and args.size != -1 and index > args.size:
                    break

            if is_train_split and args.cmvn_type == "global":
                # Estimate and save cmv
                stats = cal_gcmvn_stats(gcmvn_feature_list)
                with open(output_root / "gcmvn.npz", "wb") as f:
                    np.savez(f, mean=stats["mean"], std=stats["std"])

        # Pack features into ZIP
        print("ZIPing features...")
        create_zip(feature_root, zip_path)

        # Clean up
        shutil.rmtree(feature_root)

    gen_manifest_flag = any(
        not (output_root / f"{split}_{args.task}.tsv").exists() for split in splits
    )
    use_src = args.task == "st" and args.add_src
    share_src = use_src and args.share

    train_text = []
    if args.overwrite or gen_manifest_flag:
        print("Fetching ZIP manifest...")
        zip_manifest = get_zip_manifest(zip_path)

        # Generate TSV manifest
        print("Generating manifest...")
        for split in splits:
            is_train_split = split.startswith("train")
            manifest = {c: [] for c in MANIFEST_COLUMNS}
            if use_src:
                manifest["src_text"] = []
            dataset = ST_Dataset(root.as_posix(), src_lang, tgt_lang, split, args.speed_perturb, args.tokenizer)
            for idx in range(len(dataset)):
                for item in dataset.get_fast(idx):
                    _, sr, n_frames, src_utt, tgt_utt, speaker_id, utt_id = item
                    if utt_id not in zip_manifest:
                        # Feature extraction skips empty waveforms, so some
                        # ids legitimately have no entry in the archive.
                        log.warning("No features found for %s; skipped.", utt_id)
                        continue
                    manifest["id"].append(utt_id)
                    manifest["audio"].append(zip_manifest[utt_id])
                    # Convert samples to a feature-frame count using a 25 ms
                    # window and a 10 ms shift.
                    duration_ms = int(n_frames / sr * 1000)
                    manifest["n_frames"].append(int(1 + (duration_ms - 25) / 10))
                    src_utt = _normalize_src_text(src_utt, args.lowercase_src, args.rm_punc_src)
                    manifest["tgt_text"].append(src_utt if args.task == "asr" else tgt_utt)
                    if use_src:
                        manifest["src_text"].append(src_utt)
                    manifest["speaker"].append(speaker_id)

                if is_train_split and args.size != -1 and len(manifest["id"]) > args.size:
                    break

            if is_train_split:
                if share_src:
                    train_text.extend(manifest["src_text"])
                train_text.extend(manifest["tgt_text"])
            df = pd.DataFrame.from_dict(manifest)
            df = filter_manifest_df(df, is_train_split=is_train_split)
            save_df_to_tsv(df, output_root / f"{split}_{args.task}.tsv")

    # Generate vocab
    v_size_str = "" if args.vocab_type == "char" else str(args.vocab_size)
    spm_filename_prefix = f"spm_{args.vocab_type}{v_size_str}_{args.task}"
    asr_spm_filename = None
    gen_vocab_flag = True

    if use_src:
        if args.share:
            if args.st_spm_prefix is not None:
                gen_vocab_flag = False
                spm_filename_prefix = args.st_spm_prefix
            else:
                spm_filename_prefix = f"spm_{args.vocab_type}{v_size_str}_{args.task}_share"
            asr_spm_filename = spm_filename_prefix + ".model"
        else:
            if args.st_spm_prefix is not None:
                gen_vocab_flag = False
                spm_filename_prefix = args.st_spm_prefix
            assert args.asr_prefix is not None
            asr_spm_filename = args.asr_prefix + ".model"
    elif args.task == "asr":
        if args.asr_prefix is not None:
            gen_vocab_flag = False
            spm_filename_prefix = args.asr_prefix

    if gen_vocab_flag:
        if len(train_text) == 0:
            # Manifests were not regenerated in this run: read the text back.
            print("Loading the training text to build dictionary...")
            for split in splits:
                if split.startswith("train"):
                    csv_path = output_root / f"{split}_{args.task}.tsv"
                    train_text.extend(_load_train_text_from_tsv(csv_path, args, share_src))

        with NamedTemporaryFile(mode="w") as f:
            for t in train_text:
                f.write(t + "\n")
            # gen_vocab is handed the file *name*, so buffered text must be
            # on disk before it runs.
            f.flush()
            gen_vocab(
                Path(f.name),
                output_root / spm_filename_prefix,
                args.vocab_type,
                args.vocab_size,
            )

    # Generate config YAML
    yaml_filename = f"config_{args.task}.yaml"
    if share_src:
        yaml_filename = f"config_{args.task}_share.yaml"

    gen_config_yaml(
        output_root,
        spm_filename_prefix + ".model",
        yaml_filename=yaml_filename,
        specaugment_policy="lb",
        cmvn_type=args.cmvn_type,
        gcmvn_path=(
            output_root / "gcmvn.npz" if args.cmvn_type == "global" else None
        ),
        asr_spm_filename=asr_spm_filename,
        share_src_and_tgt=args.task == "asr",
    )


def main():
    """Parse the command line and run ``process``."""
    parser = argparse.ArgumentParser()
    parser.add_argument("--data-root", "-d", required=True, type=str)
    parser.add_argument("--output-root", "-o", default=None, type=str)
    # Not marked required: that would make the declared default unreachable.
    parser.add_argument(
        "--vocab-type",
        default="unigram",
        type=str,
        choices=["bpe", "unigram", "char"],
    )
    parser.add_argument("--vocab-size", default=8000, type=int)
    parser.add_argument("--task", type=str, default="st", choices=["asr", "st"])
    parser.add_argument("--src-lang", type=str, required=True, help="source language")
    parser.add_argument("--tgt-lang", type=str, required=True, help="target language")
    parser.add_argument("--splits", type=str, default="train,dev,test", help="dataset splits")
    parser.add_argument("--speed-perturb", action="store_true", default=False,
                        help="apply speed perturbation on wave file")
    parser.add_argument("--share", action="store_true",
                        help="share the tokenizer and dictionary of the transcription and translation")
    parser.add_argument("--add-src", action="store_true", help="add the src text for st task")
    parser.add_argument("--asr-prefix", type=str, default=None, help="prefix of the asr dict")
    parser.add_argument("--st-spm-prefix", type=str, default=None, help="prefix of the existing st dict")
    parser.add_argument("--lowercase-src", action="store_true", help="lowercase the source text")
    parser.add_argument("--rm-punc-src", action="store_true", help="remove the punctuation of the source text")
    parser.add_argument("--tokenizer", action="store_true", help="use tokenizer txt")
    parser.add_argument("--cmvn-type", default="utterance",
                        choices=["global", "utterance"],
                        help="The type of cepstral mean and variance normalization")
    parser.add_argument("--overwrite", action="store_true", help="overwrite the existing files")
    # process() reads args.size; -1 disables the limit.
    parser.add_argument("--size", default=-1, type=int,
                        help="stop processing a training split once more than "
                             "this many utterances were seen (-1: no limit)")
    parser.add_argument("--gcmvn-max-num", default=150000, type=int,
                        help=(
                            "Maximum number of sentences to use to estimate "
                            "global mean and variance"
                        ))
    args = parser.parse_args()

    process(args)


if __name__ == "__main__":
    main()