prep_mustc_data.py 16.3 KB
Newer Older
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15
#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import argparse
import logging
import os
from pathlib import Path
import shutil
from itertools import groupby
from tempfile import NamedTemporaryFile
from typing import Tuple
import string
16
import pickle
17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48

import numpy as np
import pandas as pd
import torchaudio
from examples.speech_to_text.data_utils import (
    create_zip,
    extract_fbank_features,
    filter_manifest_df,
    gen_config_yaml,
    gen_vocab,
    get_zip_manifest,
    load_df_from_tsv,
    save_df_to_tsv,
    cal_gcmvn_stats,
)
from torch.utils.data import Dataset
from tqdm import tqdm


# Module-level logger; the code in this file reports progress with print() instead.
log = logging.getLogger(__name__)


# Columns of every generated ``{split}_{task}.tsv`` manifest, in order.
# For ``--task st --add-src`` an extra ``src_text`` column is appended.
MANIFEST_COLUMNS = ["id", "audio", "n_frames", "tgt_text", "speaker"]


class MUSTC(Dataset):
    """Dataset over the segmented utterances of one MuST-C split.

    ``self[n]`` returns a *list* of items for the n-th segment: one item
    normally, or one item per speed factor when speed perturbation is enabled
    for a training split. Each item is
    ``[waveform, sample_rate, n_frames, src_utt, tgt_utt, speaker_id, utt_id]``.
    Perturbed items have ``utt_id`` prefixed with ``"sp{speed}_"`` and
    ``n_frames`` scaled by ``1 / speed``.
    """

    SPLITS = ["dev", "tst-COMMON", "tst-HE", "train"]
    # SPLITS = ["train_debug", "dev"]
    LANGUAGES = ["de", "es", "fr", "it", "nl", "pt", "ro", "ru"]

    def __init__(self, root: str, lang: str, split: str, speed_perturb: bool = False) -> None:
        """Load segment metadata and transcripts for ``en-{lang}`` / ``split``.

        Args:
            root: MuST-C root containing ``en-{lang}/data/{split}/{wav,txt}``.
            lang: target language code, one of ``LANGUAGES``.
            split: split name, one of ``SPLITS``.
            speed_perturb: apply 0.9/1.0/1.1 speed perturbation (training
                splits only; ignored for other splits).

        Raises:
            ImportError: if PyYAML is not installed.
        """
        assert split in self.SPLITS and lang in self.LANGUAGES
        _root = Path(root) / f"en-{lang}" / "data" / split
        wav_root, txt_root = _root / "wav", _root / "txt"
        assert _root.is_dir() and wav_root.is_dir() and txt_root.is_dir(), (_root, wav_root, txt_root)

        # Load audio segments
        try:
            import yaml
        except ImportError as err:
            # Fail here with a clear message instead of a NameError below.
            raise ImportError("Please install PyYAML to load the MuST-C YAML files") from err
        with open(txt_root / f"{split}.yaml", encoding="utf-8") as f:
            # BaseLoader keeps every scalar as a string (offsets/durations included).
            segments = yaml.load(f, Loader=yaml.BaseLoader)

        self.speed_perturb = [0.9, 1.0, 1.1] if speed_perturb and split.startswith("train") else None

        # Load source and target utterances; line i belongs to segment i.
        for _lang in ["en", lang]:
            with open(txt_root / f"{split}.{_lang}", encoding="utf-8") as f:
                utterances = [r.strip() for r in f]
            assert len(segments) == len(utterances)
            for segment, utterance in zip(segments, utterances):
                segment[_lang] = utterance

        # Gather info
        self.data = []
        # NOTE: groupby only merges *consecutive* segments with the same wav;
        # this relies on the YAML listing each talk's segments contiguously.
        for wav_filename, _seg_group in groupby(segments, lambda x: x["wav"]):
            wav_path = wav_root / wav_filename
            sample_rate = self._get_sample_rate(wav_path)
            # NOTE(review): offsets are strings (BaseLoader), so this is a
            # lexicographic sort. It is kept as-is because utterance IDs below
            # depend on it and must match features already packed in ZIPs.
            seg_group = sorted(_seg_group, key=lambda x: x["offset"])
            for i, segment in enumerate(seg_group):
                offset = int(float(segment["offset"]) * sample_rate)
                n_frames = int(float(segment["duration"]) * sample_rate)
                self.data.append(
                    (
                        wav_path.as_posix(),
                        offset,
                        n_frames,
                        sample_rate,
                        segment["en"],
                        segment[lang],
                        segment["speaker_id"],
                        f"{wav_path.stem}_{i}",
                    )
                )

    @staticmethod
    def _get_sample_rate(wav_path: Path):
        """Return the sample rate of ``wav_path``, supporting old and new torchaudio APIs."""
        info = torchaudio.info(wav_path.as_posix())
        try:
            # Older torchaudio: indexable result whose first element has ``.rate``.
            return info[0].rate
        except TypeError:
            # Newer torchaudio: a non-subscriptable metadata object.
            return info.sample_rate

    def __getitem__(self, n: int):
        """Load the audio of segment ``n`` and return its list of items (see class docstring)."""
        wav_path, offset, n_frames, sr, src_utt, tgt_utt, spk_id, utt_id = self.data[n]
        # The segment is read once and shared by all speed factors.
        waveform, _ = torchaudio.load(wav_path, frame_offset=offset, num_frames=n_frames)

        if self.speed_perturb is None:
            return [[waveform, sr, n_frames, src_utt, tgt_utt, spk_id, utt_id]]

        items = []
        for speed in self.speed_perturb:
            sp_waveform = waveform
            if speed != 1.0:
                # Change speed, then resample back to the original rate.
                effects = [
                    ["speed", f"{speed}"],
                    ["rate", f"{sr}"]
                ]
                sp_waveform, _ = torchaudio.sox_effects.apply_effects_tensor(waveform, sr, effects)
            items.append([sp_waveform, sr, n_frames / speed, src_utt, tgt_utt, spk_id, f"sp{speed}_" + utt_id])
        return items

    def get_fast(self, n: int):
        """Like ``__getitem__`` but without loading audio: the waveform slot holds the wav path."""
        wav_path, _, n_frames, sr, src_utt, tgt_utt, spk_id, utt_id = self.data[n]
        if self.speed_perturb is None:
            return [[wav_path, sr, n_frames, src_utt, tgt_utt, spk_id, utt_id]]
        return [
            [wav_path, sr, n_frames / speed, src_utt, tgt_utt, spk_id, f"sp{speed}_" + utt_id]
            for speed in self.speed_perturb
        ]

    def get_src_text(self):
        """Return the source (English) utterance of every segment, in dataset order."""
        return [item[4] for item in self.data]

    def get_tgt_text(self):
        """Return the target-language utterance of every segment, in dataset order."""
        return [item[5] for item in self.data]

    def __len__(self) -> int:
        return len(self.data)


# Translation table that deletes every character of ``string.punctuation``.
_PUNCTUATION_TABLE = str.maketrans("", "", string.punctuation)


def _normalize_src_text(text, lowercase, rm_punc):
    """Apply the optional --lowercase-src / --rm-punc-src normalization to a source utterance.

    Args:
        text: source (English) utterance.
        lowercase: lowercase ``text`` when true.
        rm_punc: delete every character of ``string.punctuation`` when true.

    Returns:
        The normalized utterance.
    """
    if lowercase:
        text = text.lower()
    if rm_punc:
        # Single pass; same result as replacing each punctuation char with "".
        text = text.translate(_PUNCTUATION_TABLE)
    return text


def _extract_features(args, root, lang, output_root, feature_root, zip_path):
    """Extract fbank features of all splits into ``feature_root`` and pack them into ``zip_path``.

    For training splits with ``--cmvn-type global`` it also estimates global
    CMVN stats from up to ``--gcmvn-max-num`` original-speed utterances and
    saves them to ``output_root / "gcmvn.npz"``.
    """
    for split in MUSTC.SPLITS:
        print(f"Fetching split {split}...")
        dataset = MUSTC(root.as_posix(), lang, split, args.speed_perturb)
        is_train_split = split.startswith("train")
        estimate_gcmvn = is_train_split and args.cmvn_type == "global"
        print("Extracting log mel filter bank features...")
        gcmvn_feature_list = []
        if estimate_gcmvn:
            print("And estimating cepstral mean and variance stats...")

        # Counted per split so --size truncates the training split exactly as
        # the per-split count used when writing the manifest does.
        index = 0
        for items in tqdm(dataset):
            for waveform, sr, _, _, _, _, utt_id in items:
                index += 1
                features_path = feature_root / f"{utt_id}.npy"
                features = extract_fbank_features(waveform, sr, features_path)
                # Only original-speed audio feeds the global stats; with speed
                # perturbation every id starts with "sp", the 1.0 copy with "sp1.0_".
                unperturbed = not args.speed_perturb or utt_id.startswith("sp1.0_")
                if estimate_gcmvn and unperturbed and len(gcmvn_feature_list) < args.gcmvn_max_num:
                    gcmvn_feature_list.append(features)
            if is_train_split and args.size != -1 and index > args.size:
                break

        if estimate_gcmvn:
            # Estimate and save cmv
            stats = cal_gcmvn_stats(gcmvn_feature_list)
            with open(output_root / "gcmvn.npz", "wb") as f:
                np.savez(f, mean=stats["mean"], std=stats["std"])

    # Pack features into ZIP
    print("ZIPing features...")
    create_zip(feature_root, zip_path)


def _generate_manifests(args, root, lang, output_root, zip_path):
    """Write ``{split}_{task}.tsv`` for every split and return the training text for the vocab."""
    print("Fetching ZIP manifest...")
    zip_manifest = get_zip_manifest(zip_path)
    print("Generating manifest...")
    add_src = args.task == "st" and args.add_src
    train_text = []
    for split in MUSTC.SPLITS:
        is_train_split = split.startswith("train")
        manifest = {c: [] for c in MANIFEST_COLUMNS}
        if add_src:
            manifest["src_text"] = []
        dataset = MUSTC(root.as_posix(), lang, split, args.speed_perturb)
        for idx in range(len(dataset)):
            for _, sr, n_frames, src_utt, tgt_utt, speaker_id, utt_id in dataset.get_fast(idx):
                manifest["id"].append(utt_id)
                manifest["audio"].append(zip_manifest[utt_id])
                duration_ms = int(n_frames / sr * 1000)
                # Number of 25 ms windows with a 10 ms shift.
                manifest["n_frames"].append(int(1 + (duration_ms - 25) / 10))
                src_utt = _normalize_src_text(src_utt, args.lowercase_src, args.rm_punc_src)
                manifest["tgt_text"].append(src_utt if args.task == "asr" else tgt_utt)
                if add_src:
                    manifest["src_text"].append(src_utt)
                manifest["speaker"].append(speaker_id)
            if is_train_split and args.size != -1 and len(manifest["id"]) > args.size:
                break
        if is_train_split:
            if add_src and args.share:
                train_text.extend(manifest["src_text"])
            train_text.extend(manifest["tgt_text"])
        df = pd.DataFrame.from_dict(manifest)
        df = filter_manifest_df(df, is_train_split=is_train_split)
        save_df_to_tsv(df, output_root / f"{split}_{args.task}.tsv")
    return train_text


def _load_train_text(args, root, lang):
    """Rebuild the vocabulary training text from the raw data (used when manifests were not regenerated).

    The text mirrors what ``_generate_manifests`` collects: normalized
    transcripts for ASR; translations (plus normalized transcripts when the
    ST vocabulary is shared) for ST.
    """
    print("Loading the training text to build dictionary...")
    train_text = []
    for split in MUSTC.SPLITS:
        if not split.startswith("train"):
            continue
        dataset = MUSTC(root.as_posix(), lang, split)
        for src_utt, tgt_utt in zip(dataset.get_src_text(), dataset.get_tgt_text()):
            src_utt = _normalize_src_text(src_utt, args.lowercase_src, args.rm_punc_src)
            if args.task == "asr":
                # ASR manifests use the transcript as the target text.
                train_text.append(src_utt)
            else:
                if args.task == "st" and args.add_src and args.share:
                    train_text.append(src_utt)
                train_text.append(tgt_utt)
    return train_text


def process(args):
    """Prepare features, manifests, vocabulary and config for each available MuST-C pair.

    For every ``en-{lang}`` directory under ``args.data_root``:
      1. extract fbank features and pack them into ``fbank80[_sp].zip``
         (only if the ZIP is missing or ``--overwrite`` is given);
      2. write ``{split}_{task}.tsv`` manifests (only if one is missing or
         ``--overwrite`` is given);
      3. build the vocabulary with ``gen_vocab`` from the training text;
      4. write the config YAML with ``gen_config_yaml``;
      5. delete the uncompressed feature directory.

    Raises:
        ValueError: for ``--task st --add-src`` without ``--share`` when no
            ``--asr-prefix`` is given (checked before any work is done).
    """
    if args.task == "st" and args.add_src and not args.share and args.asr_prefix is None:
        raise ValueError("--asr-prefix is required for --task st --add-src without --share")

    root = Path(args.data_root).absolute()
    for lang in MUSTC.LANGUAGES:
        cur_root = root / f"en-{lang}"
        if not cur_root.is_dir():
            print(f"{cur_root.as_posix()} does not exist. Skipped.")
            continue
        if args.output_root is None:
            output_root = cur_root
        else:
            output_root = Path(args.output_root).absolute() / f"en-{lang}"

        # Extract features
        feature_name = "fbank80_sp" if args.speed_perturb else "fbank80"
        feature_root = output_root / feature_name
        # parents=True: a custom --output-root/en-{lang} may not exist yet.
        feature_root.mkdir(parents=True, exist_ok=True)
        zip_path = output_root / f"{feature_name}.zip"
        if args.overwrite or not zip_path.exists():
            _extract_features(args, root, lang, output_root, feature_root, zip_path)

        # Generate TSV manifests
        manifests_missing = any(
            not (output_root / f"{split}_{args.task}.tsv").exists() for split in MUSTC.SPLITS
        )
        train_text = []
        if args.overwrite or manifests_missing:
            train_text = _generate_manifests(args, root, lang, output_root, zip_path)
        if not train_text:
            train_text = _load_train_text(args, root, lang)

        # Generate vocab
        v_size_str = "" if args.vocab_type == "char" else str(args.vocab_size)
        spm_filename_prefix = f"spm_{args.vocab_type}{v_size_str}_{args.task}"
        share_vocab = args.task == "st" and args.add_src and args.share
        if share_vocab:
            spm_filename_prefix += "_share"
        if args.task == "st" and args.add_src:
            # Source side uses the shared model, or a separately built ASR model.
            asr_spm_filename = (spm_filename_prefix if share_vocab else args.asr_prefix) + ".model"
        else:
            asr_spm_filename = None

        with NamedTemporaryFile(mode="w", encoding="utf-8") as f:
            f.writelines(t + "\n" for t in train_text)
            # gen_vocab reopens the file by name, so buffered text must hit disk first.
            f.flush()
            gen_vocab(
                Path(f.name),
                output_root / spm_filename_prefix,
                args.vocab_type,
                args.vocab_size,
            )

        # Generate config YAML
        yaml_filename = f"config_{args.task}_share.yaml" if share_vocab else f"config_{args.task}.yaml"
        gen_config_yaml(
            output_root,
            spm_filename_prefix + ".model",
            yaml_filename=yaml_filename,
            specaugment_policy="lb",
            cmvn_type=args.cmvn_type,
            gcmvn_path=(output_root / "gcmvn.npz" if args.cmvn_type == "global" else None),
            asr_spm_filename=asr_spm_filename,
            share_src_and_tgt=args.task == "asr",
        )

        # Clean up
        shutil.rmtree(feature_root)


def process_joint(args):
    """Build one vocabulary and config shared by all MuST-C language pairs.

    Reads ``en-{lang}/train_{task}.tsv`` under the output root for every
    language (as written by ``process``), builds the vocabulary from their
    ``tgt_text`` columns (adding ``<lang:xx>`` special symbols for ST), writes
    ``config_{task}.yaml``, and symlinks each per-language manifest as
    ``{split}_{lang}_{task}.tsv`` in the output root.

    Raises:
        FileNotFoundError: if the data directory of any language is missing.
    """
    cur_root = Path(args.data_root)
    missing = [lang for lang in MUSTC.LANGUAGES if not (cur_root / f"en-{lang}").is_dir()]
    if missing:
        raise FileNotFoundError(
            f"do not have downloaded data available for all {len(MUSTC.LANGUAGES)} languages "
            f"(missing: {', '.join(missing)})"
        )
    output_root = cur_root if args.output_root is None else Path(args.output_root).absolute()

    # Generate vocab
    vocab_size_str = "" if args.vocab_type == "char" else str(args.vocab_size)
    spm_filename_prefix = f"spm_{args.vocab_type}{vocab_size_str}_{args.task}"
    with NamedTemporaryFile(mode="w", encoding="utf-8") as f:
        for lang in MUSTC.LANGUAGES:
            tsv_path = output_root / f"en-{lang}" / f"train_{args.task}.tsv"
            df = load_df_from_tsv(tsv_path)
            for t in df["tgt_text"]:
                f.write(t + "\n")
        # gen_vocab reopens the file by name, so buffered text must hit disk first.
        f.flush()
        special_symbols = None
        if args.task == "st":
            special_symbols = [f"<lang:{lang}>" for lang in MUSTC.LANGUAGES]
        gen_vocab(
            Path(f.name),
            output_root / spm_filename_prefix,
            args.vocab_type,
            args.vocab_size,
            special_symbols=special_symbols
        )

    # Generate config YAML
    gen_config_yaml(
        output_root,
        spm_filename_prefix + ".model",
        yaml_filename=f"config_{args.task}.yaml",
        specaugment_policy="ld",
        prepend_tgt_lang_tag=(args.task == "st"),
    )

    # Make symbolic links to manifests
    for lang in MUSTC.LANGUAGES:
        for split in MUSTC.SPLITS:
            src_path = output_root / f"en-{lang}" / f"{split}_{args.task}.tsv"
            desc_path = output_root / f"{split}_{lang}_{args.task}.tsv"
            if not desc_path.is_symlink():
                os.symlink(src_path, desc_path)


def _build_parser():
    """Return the command-line argument parser for this script."""
    parser = argparse.ArgumentParser()
    parser.add_argument("--data-root", "-d", required=True, type=str)
    parser.add_argument("--output-root", "-o", default=None, type=str)
    # Optional with a default: "required=True" alongside a default made the default dead.
    parser.add_argument(
        "--vocab-type",
        default="unigram",
        type=str,
        choices=["bpe", "unigram", "char"],
    )
    parser.add_argument("--vocab-size", default=8000, type=int)
    parser.add_argument("--task", type=str, choices=["asr", "st"])
    parser.add_argument("--size", default=-1, type=int)
    parser.add_argument("--speed-perturb", action="store_true", default=False,
                        help="apply speed perturbation on wave file")
    parser.add_argument("--joint", action="store_true",
                        help="build a joint vocabulary and config over all languages "
                             "from the existing per-language manifests")
    parser.add_argument("--share", action="store_true",
                        help="share the tokenizer and dictionary of the transcription and translation")
    parser.add_argument("--add-src", action="store_true", help="add the src text for st task")
    parser.add_argument("--asr-prefix", type=str, help="prefix of the asr dict")
    parser.add_argument("--lowercase-src", action="store_true", help="lowercase the source text")
    parser.add_argument("--rm-punc-src", action="store_true", help="remove the punctuation of the source text")
    parser.add_argument("--cmvn-type", default="utterance",
                        choices=["global", "utterance"],
                        help="The type of cepstral mean and variance normalization")
    parser.add_argument("--overwrite", action="store_true", help="overwrite the existing files")
    parser.add_argument("--gcmvn-max-num", default=150000, type=int,
                        help=(
                            "Maximum number of sentences to use to estimate "
                            "global mean and variance"
                        ))
    return parser


def main():
    """Parse the command line and run per-language (default) or joint (--joint) preprocessing."""
    args = _build_parser().parse_args()

    if args.joint:
        process_joint(args)
    else:
        process(args)


# Script entry point: parse command-line arguments and run the preprocessing.
if __name__ == "__main__":
    main()