import os.path as op from typing import BinaryIO, Optional, Tuple, Union import numpy as np import torch import torchaudio def get_waveform( path_or_fp: Union[str, BinaryIO], normalization=True, offset=None, size=None ) -> Tuple[np.ndarray, int]: """Get the waveform and sample rate of a 16-bit mono-channel WAV or FLAC. Args: path_or_fp (str or BinaryIO): the path or file-like object normalization (bool): Normalize values to [-1, 1] (Default: True) """ if isinstance(path_or_fp, str): ext = op.splitext(op.basename(path_or_fp))[1] if ext not in {".flac", ".wav"}: raise ValueError(f"Unsupported audio format: {ext}") if offset is not None and size is not None: waveform, sample_rate = torchaudio.load(path_or_fp, frame_offset=offset, num_frames=size) else: waveform, sample_rate = torchaudio.load(path_or_fp) waveform = waveform.squeeze().numpy() if not normalization: waveform *= 2 ** 15 # denormalized to 16-bit signed integers return waveform, sample_rate def _get_kaldi_fbank(waveform, sample_rate, n_bins=80) -> Optional[np.ndarray]: """Get mel-filter bank features via PyKaldi.""" try: from kaldi.feat.mel import MelBanksOptions from kaldi.feat.fbank import FbankOptions, Fbank from kaldi.feat.window import FrameExtractionOptions from kaldi.matrix import Vector mel_opts = MelBanksOptions() mel_opts.num_bins = n_bins frame_opts = FrameExtractionOptions() frame_opts.samp_freq = sample_rate opts = FbankOptions() opts.mel_opts = mel_opts opts.frame_opts = frame_opts fbank = Fbank(opts=opts) features = fbank.compute(Vector(waveform), 1.0).numpy() return features except ImportError: return None def _get_torchaudio_fbank(waveform, sample_rate, n_bins=80) -> Optional[np.ndarray]: """Get mel-filter bank features via TorchAudio.""" try: import torchaudio.compliance.kaldi as ta_kaldi import torchaudio.sox_effects as ta_sox if not isinstance(waveform, torch.Tensor): waveform = torch.from_numpy(waveform) if len(waveform.shape) == 1: # Mono channel: D -> 1 x D waveform = waveform.unsqueeze(0) else: # Merge multiple channels to one: D x C -> 1 x D waveform, _ = ta_sox.apply_effects_tensor(waveform.T, sample_rate, [['channels', '1']]) features = ta_kaldi.fbank( waveform, num_mel_bins=n_bins, sample_frequency=sample_rate ) return features.numpy() except ImportError: return None def get_fbank( path_or_fp: Union[str, BinaryIO], n_bins=80, offset=None, size=None, ) -> np.ndarray: """Get mel-filter bank features via PyKaldi or TorchAudio. Prefer PyKaldi (faster CPP implementation) to TorchAudio (Python implementation). Note that Kaldi/TorchAudio requires 16-bit signed integers as inputs and hence the waveform should not be normalized.""" sound, sample_rate = get_waveform(path_or_fp, normalization=False, offset=offset, size=size) features = _get_kaldi_fbank(sound, sample_rate, n_bins) if features is None: features = _get_torchaudio_fbank(sound, sample_rate, n_bins) if features is None: raise ImportError( "Please install pyKaldi or torchaudio to enable " "online filterbank feature extraction" ) return features def get_fbank_with_perturb(waveform, sample_rate=16000, n_bins=80): import random speed = random.choice([0.9, 1.0, 1.1]) if speed != 1.0: waveform = torch.from_numpy(waveform).float().unsqueeze(0) waveform, _ = torchaudio.sox_effects.apply_effects_tensor( waveform, sample_rate, [['speed', str(speed)], ['rate', str(sample_rate)]]) waveform = waveform.squeeze() features = _get_kaldi_fbank(waveform, sample_rate) if features is None: features = _get_torchaudio_fbank(waveform, sample_rate, n_bins) return features