#!/usr/bin/env python3 # Copyright (c) Facebook, Inc. and its affiliates. # # This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this source tree. import argparse import logging import os from pathlib import Path import shutil from itertools import groupby from tempfile import NamedTemporaryFile import string import csv import numpy as np import pandas as pd import torchaudio from examples.speech_to_text.data_utils import ( create_zip, extract_fbank_features, filter_manifest_df, gen_config_yaml, gen_vocab, get_zip_manifest, load_df_from_tsv, save_df_to_tsv, cal_gcmvn_stats, ) from torch.utils.data import Dataset from tqdm import tqdm log = logging.getLogger(__name__) MANIFEST_COLUMNS = ["id", "audio", "n_frames", "tgt_text", "speaker"] class MUSTC(Dataset): """ Create a Dataset for MuST-C. Each item is a tuple of the form: waveform, sample_rate, source utterance, target utterance, speaker_id, utterance_id """ SPLITS = ["dev", "tst-COMMON", "train"] LANGUAGES = ["de", "es", "fr", "it", "nl", "pt", "ro", "ru"] def __init__(self, root: str, lang: str, split: str, speed_perturb: bool = False, tokenizer: bool = False) -> None: assert split in self.SPLITS and lang in self.LANGUAGES _root = Path(root) / f"en-{lang}" / "data" / split wav_root, txt_root = _root / "wav", _root / "txt" if tokenizer: txt_root = _root / "txt.tok" assert _root.is_dir() and wav_root.is_dir() and txt_root.is_dir(), (_root, wav_root, txt_root) # Load audio segments try: import yaml except ImportError: print("Please install PyYAML to load the MuST-C YAML files") with open(txt_root / f"{split}.yaml") as f: segments = yaml.load(f, Loader=yaml.BaseLoader) self.speed_perturb = [0.9, 1.0, 1.1] if speed_perturb and split.startswith("train") else None # Load source and target utterances for _lang in ["en", lang]: with open(txt_root / f"{split}.{_lang}") as f: utterances = [r.strip() for r in f] assert len(segments) == len(utterances) for i, u in enumerate(utterances): segments[i][_lang] = u # Gather info self.data = [] for wav_filename, _seg_group in groupby(segments, lambda x: x["wav"]): wav_path = wav_root / wav_filename sample_rate = torchaudio.info(wav_path.as_posix()).sample_rate seg_group = sorted(_seg_group, key=lambda x: float(x["offset"])) for i, segment in enumerate(seg_group): offset = int(float(segment["offset"]) * sample_rate) n_frames = int(float(segment["duration"]) * sample_rate) _id = f"{split}_{wav_path.stem}_{i}" self.data.append( ( wav_path.as_posix(), offset, n_frames, sample_rate, segment["en"], segment[lang], segment["speaker_id"], _id, ) ) def __getitem__(self, n: int): wav_path, offset, n_frames, sr, src_utt, tgt_utt, spk_id, utt_id = self.data[n] items = [] if self.speed_perturb is None: waveform, _ = torchaudio.load(wav_path, frame_offset=offset, num_frames=n_frames) items.append([waveform, sr, n_frames, src_utt, tgt_utt, spk_id, utt_id]) else: for speed in self.speed_perturb: sp_utt_id = f"sp{speed}_" + utt_id sp_n_frames = n_frames / speed if speed == 1.0: waveform, _ = torchaudio.load(wav_path, frame_offset=offset, num_frames=n_frames) else: waveform, _ = torchaudio.load(wav_path, frame_offset=offset, num_frames=n_frames) effects = [ ["speed", f"{speed}"], ["rate", f"{sr}"] ] waveform, _ = torchaudio.sox_effects.apply_effects_tensor(waveform, sr, effects) items.append([waveform, sr, sp_n_frames, src_utt, tgt_utt, spk_id, sp_utt_id]) return items def get_wav(self, n: int, speed_perturb=1.0): wav_path, offset, n_frames, sr, src_utt, tgt_utt, spk_id, utt_id = self.data[n] if self.speed_perturb is None or speed_perturb == 1.0: waveform, _ = torchaudio.load(wav_path, frame_offset=offset, num_frames=n_frames) else: waveform, _ = torchaudio.load(wav_path, frame_offset=offset, num_frames=n_frames) effects = [ ["speed", f"{speed_perturb}"], ["rate", f"{sr}"] ] waveform, _ = torchaudio.sox_effects.apply_effects_tensor(waveform, sr, effects) return waveform def get_fast(self, n: int): wav_path, offset, n_frames, sr, src_utt, tgt_utt, spk_id, utt_id = self.data[n] items = [] if self.speed_perturb is None: items.append([wav_path, sr, n_frames, src_utt, tgt_utt, spk_id, utt_id]) else: for speed in self.speed_perturb: sp_utt_id = f"sp{speed}_" + utt_id sp_n_frames = n_frames / speed items.append([wav_path, sr, sp_n_frames, src_utt, tgt_utt, spk_id, sp_utt_id]) return items def get_src_text(self): src_text = [] for item in self.data: src_text.append(item[4]) return src_text def get_tgt_text(self): tgt_text = [] for item in self.data: tgt_text.append(item[5]) return tgt_text def __len__(self) -> int: return len(self.data) def process(args): root = Path(args.data_root).absolute() for lang in MUSTC.LANGUAGES: cur_root = root / f"en-{lang}" if not cur_root.is_dir(): print(f"{cur_root.as_posix()} does not exist. Skipped.") continue if args.output_root is None: output_root = cur_root else: output_root = Path(args.output_root).absolute() / f"en-{lang}" if args.speed_perturb: zip_path = output_root / "fbank80_sp.zip" else: zip_path = output_root / "fbank80.zip" index = 0 # Extract features if args.overwrite or not Path.exists(zip_path): if args.speed_perturb: feature_root = output_root / "fbank80_sp" else: feature_root = output_root / "fbank80" feature_root.mkdir(exist_ok=True) for split in MUSTC.SPLITS: print(f"Fetching split {split}...") dataset = MUSTC(root.as_posix(), lang, split, args.speed_perturb, args.tokenizer) is_train_split = split.startswith("train") print("Extracting log mel filter bank features...") if is_train_split and args.cmvn_type == "global": print("And estimating cepstral mean and variance stats...") gcmvn_feature_list = [] for idx in tqdm(range(len(dataset))): items = dataset.get_fast(idx) for item in items: index += 1 wav_path, sr, _, _, _, _, utt_id = item features_path = (feature_root / f"{utt_id}.npy").as_posix() if not os.path.exists(features_path): sp = 1.0 if dataset.speed_perturb is not None: sp = float(utt_id.split("_")[0].replace("sp", "")) waveform = dataset.get_wav(idx, sp) if waveform.shape[1] == 0: continue features = extract_fbank_features(waveform, sr, Path(features_path)) if split == 'train' and args.cmvn_type == "global" and not utt_id.startswith("sp"): if len(gcmvn_feature_list) < args.gcmvn_max_num: gcmvn_feature_list.append(features) if is_train_split and args.size != -1 and index > args.size: break if is_train_split and args.cmvn_type == "global": # Estimate and save cmv stats = cal_gcmvn_stats(gcmvn_feature_list) with open(output_root / "gcmvn.npz", "wb") as f: np.savez(f, mean=stats["mean"], std=stats["std"]) # Pack features into ZIP print("ZIPing features...") create_zip(feature_root, zip_path) # # Clean up # shutil.rmtree(feature_root) gen_manifest_flag = False for split in MUSTC.SPLITS: if not Path.exists(output_root / f"{split}_{args.task}.tsv"): gen_manifest_flag = True break train_text = [] if args.overwrite or gen_manifest_flag: print("Fetching ZIP manifest...") zip_manifest = get_zip_manifest(zip_path) # Generate TSV manifest print("Generating manifest...") for split in MUSTC.SPLITS: is_train_split = split.startswith("train") manifest = {c: [] for c in MANIFEST_COLUMNS} if args.task == "st" and args.add_src: manifest["src_text"] = [] dataset = MUSTC(args.data_root, lang, split, args.speed_perturb, args.tokenizer) for idx in range(len(dataset)): items = dataset.get_fast(idx) for item in items: _, sr, n_frames, src_utt, tgt_utt, speaker_id, utt_id = item manifest["id"].append(utt_id) manifest["audio"].append(zip_manifest[utt_id]) duration_ms = int(n_frames / sr * 1000) manifest["n_frames"].append(int(1 + (duration_ms - 25) / 10)) if args.lowercase_src: src_utt = src_utt.lower() if args.rm_punc_src: for w in string.punctuation: src_utt = src_utt.replace(w, "") src_utt = src_utt.replace(" ", "") manifest["tgt_text"].append(src_utt if args.task == "asr" else tgt_utt) if args.task == "st" and args.add_src: manifest["src_text"].append(src_utt) manifest["speaker"].append(speaker_id) if is_train_split and args.size != -1 and len(manifest["id"]) > args.size: break if is_train_split: if args.task == "st" and args.add_src and args.share: train_text.extend(manifest["src_text"]) train_text.extend(manifest["tgt_text"]) df = pd.DataFrame.from_dict(manifest) df = filter_manifest_df(df, is_train_split=is_train_split) save_df_to_tsv(df, output_root / f"{split}_{args.task}.tsv") # Generate vocab v_size_str = "" if args.vocab_type == "char" else str(args.vocab_size) spm_filename_prefix = f"spm_{args.vocab_type}{v_size_str}_{args.task}" asr_spm_filename = None gen_vocab_flag = True if args.task == "st" and args.add_src: if args.share: if args.st_spm_prefix is not None: gen_vocab_flag = False spm_filename_prefix = args.st_spm_prefix else: spm_filename_prefix = f"spm_{args.vocab_type}{v_size_str}_{args.task}_share" asr_spm_filename = spm_filename_prefix + ".model" else: if args.st_spm_prefix is not None: gen_vocab_flag = False spm_filename_prefix = args.st_spm_prefix assert args.asr_prefix is not None asr_spm_filename = args.asr_prefix + ".model" elif args.task == "asr": if args.asr_prefix is not None: gen_vocab_flag = False spm_filename_prefix = args.asr_prefix if gen_vocab_flag: if len(train_text) == 0: print("Loading the training text to build dictionary...") for split in MUSTC.SPLITS: if split.startswith("train"): csv_path = output_root / f"{split}_{args.task}.tsv" with open(csv_path) as f: reader = csv.DictReader( f, delimiter="\t", quotechar=None, doublequote=False, lineterminator="\n", quoting=csv.QUOTE_NONE, ) if args.task == "st" and args.add_src and args.share: for e in reader: src_utt = dict(e)["src_text"] if args.lowercase_src: src_utt = src_utt.lower() if args.rm_punc_src: for w in string.punctuation: src_utt = src_utt.replace(w, "") src_utt = src_utt.replace(" ", "") train_text.append(src_utt) tgt_text = [dict(e)["tgt_text"] for e in reader] train_text.extend(tgt_text) with NamedTemporaryFile(mode="w") as f: for t in train_text: f.write(t + "\n") gen_vocab( Path(f.name), output_root / spm_filename_prefix, args.vocab_type, args.vocab_size, ) # Generate config YAML yaml_filename = f"config_{args.task}.yaml" if args.task == "st" and args.add_src and args.share: yaml_filename = f"config_{args.task}_share.yaml" gen_config_yaml( output_root, spm_filename_prefix + ".model", yaml_filename=yaml_filename, specaugment_policy="lb", cmvn_type=args.cmvn_type, gcmvn_path=( output_root / "gcmvn.npz" if args.cmvn_type == "global" else None ), asr_spm_filename=asr_spm_filename, share_src_and_tgt=True if args.task == "asr" else False ) def process_joint(args): cur_root = Path(args.data_root) assert all((cur_root / f"en-{lang}").is_dir() for lang in MUSTC.LANGUAGES), \ "do not have downloaded data available for all 8 languages" if args.output_root is None: output_root = cur_root else: output_root = Path(args.output_root).absolute() # Generate vocab vocab_size_str = "" if args.vocab_type == "char" else str(args.vocab_size) spm_filename_prefix = f"spm_{args.vocab_type}{vocab_size_str}_{args.task}" with NamedTemporaryFile(mode="w") as f: for lang in MUSTC.LANGUAGES: tsv_path = output_root / f"en-{lang}" / f"train_{args.task}.tsv" df = load_df_from_tsv(tsv_path) for t in df["tgt_text"]: f.write(t + "\n") special_symbols = None if args.task == 'st': special_symbols = [f'<lang:{lang}>' for lang in MUSTC.LANGUAGES] gen_vocab( Path(f.name), output_root / spm_filename_prefix, args.vocab_type, args.vocab_size, special_symbols=special_symbols ) # Generate config YAML gen_config_yaml( output_root, spm_filename_prefix + ".model", yaml_filename=f"config_{args.task}.yaml", specaugment_policy="ld", prepend_tgt_lang_tag=(args.task == "st"), ) # Make symbolic links to manifests for lang in MUSTC.LANGUAGES: for split in MUSTC.SPLITS: src_path = output_root / f"en-{lang}" / f"{split}_{args.task}.tsv" desc_path = output_root / f"{split}_{lang}_{args.task}.tsv" if not desc_path.is_symlink(): os.symlink(src_path, desc_path) def main(): parser = argparse.ArgumentParser() parser.add_argument("--data-root", "-d", required=True, type=str) parser.add_argument("--output-root", "-o", default=None, type=str) parser.add_argument( "--vocab-type", default="unigram", required=True, type=str, choices=["bpe", "unigram", "char"], ), parser.add_argument("--vocab-size", default=8000, type=int) parser.add_argument("--task", type=str, choices=["asr", "st"]) parser.add_argument("--size", default=-1, type=int) parser.add_argument("--speed-perturb", action="store_true", default=False, help="apply speed perturbation on wave file") parser.add_argument("--joint", action="store_true", help="") parser.add_argument("--share", action="store_true", help="share the tokenizer and dictionary of the transcription and translation") parser.add_argument("--add-src", action="store_true", help="add the src text for st task") parser.add_argument("--asr-prefix", type=str, default=None, help="prefix of the asr dict") parser.add_argument("--st-spm-prefix", type=str, default=None, help="prefix of the existing st dict") parser.add_argument("--lowercase-src", action="store_true", help="lowercase the source text") parser.add_argument("--rm-punc-src", action="store_true", help="remove the punctuation of the source text") parser.add_argument("--tokenizer", action="store_true", help="use tokenizer txt") parser.add_argument("--cmvn-type", default="utterance", choices=["global", "utterance"], help="The type of cepstral mean and variance normalization") parser.add_argument("--overwrite", action="store_true", help="overwrite the existing files") parser.add_argument("--gcmvn-max-num", default=150000, type=int, help=( "Maximum number of sentences to use to estimate" "global mean and variance" )) args = parser.parse_args() if args.joint: process_joint(args) else: process(args) if __name__ == "__main__": main()