Commit 21734086 by xuchen

fix the bugs

parent e1d3d2ed
...@@ -185,7 +185,7 @@ class AudioDataset(Dataset): ...@@ -185,7 +185,7 @@ class AudioDataset(Dataset):
if need_waveform: if need_waveform:
offset = item.get('offset', False) offset = item.get('offset', False)
if offset: if offset is not False:
waveform, sample_rate = torchaudio.load(audio, waveform, sample_rate = torchaudio.load(audio,
frame_offset=offset, frame_offset=offset,
num_frames=item["n_frames"]) num_frames=item["n_frames"])
...@@ -272,7 +272,11 @@ def process(args): ...@@ -272,7 +272,11 @@ def process(args):
waveform, sample_rate, _ = dataset.get(idx, need_waveform=True) waveform, sample_rate, _ = dataset.get(idx, need_waveform=True)
if waveform.shape[1] == 0: if waveform.shape[1] == 0:
continue continue
features = extract_fbank_features(waveform, sample_rate, Path(features_path))
try:
features = extract_fbank_features(waveform, sample_rate, Path(features_path))
except AssertionError:
logger.warning("Extract file %s failed." % utt_id)
if split == 'train' and args.cmvn_type == "global" and not utt_id.startswith("sp"): if split == 'train' and args.cmvn_type == "global" and not utt_id.startswith("sp"):
if len(gcmvn_feature_list) < args.gcmvn_max_num: if len(gcmvn_feature_list) < args.gcmvn_max_num:
...@@ -326,16 +330,21 @@ def process(args): ...@@ -326,16 +330,21 @@ def process(args):
_, sample_rate, n_frames = dataset.get(idx, need_waveform=False) _, sample_rate, n_frames = dataset.get(idx, need_waveform=False)
utt_id = item["id"] utt_id = item["id"]
manifest["id"].append(utt_id)
if use_raw: if use_raw:
audio_path = item["audio"] audio_path = item["audio"]
# add offset and frames info # add offset and frames info
if item.get("offset", False): if item.get("offset", False) is not False:
audio_path = f"{audio_path}:{item['offset']}:{n_frames}" audio_path = f"{audio_path}:{item['offset']}:{n_frames}"
manifest["audio"].append(audio_path) manifest["audio"].append(audio_path)
else: else:
manifest["audio"].append(zip_manifest[utt_id]) if utt_id in zip_manifest:
manifest["audio"].append(zip_manifest[utt_id])
else:
logger.warning("%s is not in the zip" % utt_id)
continue
manifest["id"].append(utt_id)
duration_ms = int(n_frames / sample_rate * 1000) duration_ms = int(n_frames / sample_rate * 1000)
manifest["n_frames"].append(int(1 + (duration_ms - 25) / 10)) manifest["n_frames"].append(int(1 + (duration_ms - 25) / 10))
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论