Commit 0b804492 by xuchen

Fix a bug in the ST preprocessing: carry n_frames through the dataset item tuples and derive each utterance's duration from it, instead of collecting waveform lengths into a separate frame.pkl file.

parent bec87003
@@ -103,10 +103,11 @@ class MUSTC(Dataset):
         items = []
         if self.speed_perturb is None:
             waveform, _ = torchaudio.load(wav_path, frame_offset=offset, num_frames=n_frames)
-            items.append([waveform, sr, src_utt, tgt_utt, spk_id, utt_id])
+            items.append([waveform, sr, n_frames, src_utt, tgt_utt, spk_id, utt_id])
         else:
             for speed in self.speed_perturb:
                 sp_utt_id = f"sp{speed}_" + utt_id
+                sp_n_frames = n_frames / speed
                 if speed == 1.0:
                     waveform, _ = torchaudio.load(wav_path, frame_offset=offset, num_frames=n_frames)
                 else:
@@ -117,7 +118,7 @@ class MUSTC(Dataset):
                     ]
                     waveform, _ = torchaudio.sox_effects.apply_effects_tensor(waveform, sr, effects)
-                items.append([waveform, sr, src_utt, tgt_utt, spk_id, sp_utt_id])
+                items.append([waveform, sr, sp_n_frames, src_utt, tgt_utt, spk_id, sp_utt_id])
         return items
     def get_fast(self, n: int):
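The two hunks above thread the segment length through the speed-perturbation branch: sox re-renders the audio, so the sample count shrinks or grows by roughly a factor of 1/speed, and that is what the new sp_n_frames records. A minimal sketch of the same load path (the helper name is illustrative, and the exact effects list is truncated in the diff, so the `speed`/`rate` pair here is an assumption):

```python
import torchaudio

def load_speed_perturbed(wav_path, offset, n_frames, sr, speed):
    # Load the original segment, then re-render it with sox effects,
    # mirroring the perturbation branch shown in the diff above.
    waveform, _ = torchaudio.load(wav_path, frame_offset=offset, num_frames=n_frames)
    if speed != 1.0:
        effects = [["speed", f"{speed}"], ["rate", f"{sr}"]]
        waveform, _ = torchaudio.sox_effects.apply_effects_tensor(waveform, sr, effects)
    # The resulting sample count is approximately n_frames / speed,
    # which is the quantity sp_n_frames estimates without decoding audio.
    return waveform, waveform.size(1)
```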
@@ -125,11 +126,12 @@ class MUSTC(Dataset):
         items = []
         if self.speed_perturb is None:
-            items.append([wav_path, sr, src_utt, tgt_utt, spk_id, utt_id])
+            items.append([wav_path, sr, n_frames, src_utt, tgt_utt, spk_id, utt_id])
         else:
             for speed in self.speed_perturb:
                 sp_utt_id = f"sp{speed}_" + utt_id
-                items.append([wav_path, sr, src_utt, tgt_utt, spk_id, sp_utt_id])
+                sp_n_frames = n_frames / speed
+                items.append([wav_path, sr, sp_n_frames, src_utt, tgt_utt, spk_id, sp_utt_id])
         return items
     def get_src_text(self):
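get_fast returns metadata only, without decoding audio, so the perturbed length has to be estimated as n_frames / speed. Illustrative numbers for a 3 s, 16 kHz utterance (values chosen for the example, not taken from the dataset); the estimate is a float, which is harmless because it only feeds the duration calculation in the manifest:

```python
n_frames, sr = 48000, 16000             # 3.0 s clip (illustrative values)
for speed in (0.9, 1.0, 1.1):
    sp_n_frames = n_frames / speed      # 53333.3, 48000.0, 43636.4 samples
    print(f"sp{speed}: {sp_n_frames / sr:.2f} s")   # ~3.33 s, 3.00 s, ~2.73 s
```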
@@ -170,19 +172,13 @@ def process(args):
         zip_path = output_root / "fbank80_sp.zip"
     else:
         zip_path = output_root / "fbank80.zip"
-    frame_path = output_root / "frame.pkl"
-    frame_dict = {}
     index = 0
     gen_feature_flag = False
     if not Path.exists(zip_path):
         gen_feature_flag = True
-    gen_frame_flag = False
-    if not Path.exists(frame_path):
-        gen_frame_flag = True
-    if args.overwrite or gen_feature_flag or gen_frame_flag:
+    if args.overwrite or gen_feature_flag:
         for split in MUSTC.SPLITS:
             print(f"Fetching split {split}...")
             dataset = MUSTC(root.as_posix(), lang, split, args.speed_perturb)
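With frame.pkl gone, the regeneration check reduces to "rebuild features if the packed archive is missing or --overwrite is set". A minimal sketch of that decision (the function name is illustrative, not part of the script):

```python
from pathlib import Path

def need_feature_pass(zip_path: Path, overwrite: bool) -> bool:
    # Features are regenerated only when the fbank ZIP is absent
    # or the user explicitly asked to overwrite it.
    return overwrite or not zip_path.exists()
```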
@@ -195,9 +191,8 @@ def process(args):
             for items in tqdm(dataset):
                 for item in items:
                     index += 1
-                    waveform, sr, _, _, _, utt_id = item
+                    waveform, sr, _, _, _, _, utt_id = item
-                    frame_dict[utt_id] = waveform.size(1)
                     if gen_feature_flag:
                         features_path = (feature_root / f"{utt_id}.npy").as_posix()
                         features = extract_fbank_features(waveform, sr, Path(features_path))
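extract_fbank_features comes from the fairseq speech-to-text data utilities and computes 80-dimensional log-mel filter banks (25 ms window, 10 ms shift) saved as .npy. A rough approximation with torchaudio's Kaldi-compatible frontend, shown only to make the data flow concrete; the rescaling factor and the save step are assumptions, not the exact fairseq implementation:

```python
import numpy as np
import torch
import torchaudio.compliance.kaldi as ta_kaldi

def fbank80(waveform: torch.Tensor, sample_rate: int, out_path: str) -> np.ndarray:
    # torchaudio.load returns floats in [-1, 1]; Kaldi-style frontends expect
    # 16-bit PCM range, hence the rescaling (an assumption about the helper).
    feats = ta_kaldi.fbank(
        waveform * (2 ** 15),
        num_mel_bins=80,
        sample_frequency=sample_rate,
    ).numpy()
    np.save(out_path, feats)
    return feats
```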
@@ -215,9 +210,6 @@ def process(args):
                 with open(output_root / "gcmvn.npz", "wb") as f:
                     np.savez(f, mean=stats["mean"], std=stats["std"])
-    with open(frame_path, "wb") as f:
-        pickle.dump(frame_dict, f)
     # Pack features into ZIP
     print("ZIPing features...")
     create_zip(feature_root, zip_path)
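The gcmvn.npz written above holds global mean/std statistics over the training features; the diff only shows the np.savez call. One way such stats can be accumulated without keeping every feature matrix in memory, sketched under the assumption that features are stored as per-utterance .npy files under feature_root:

```python
import numpy as np
from pathlib import Path

def gcmvn_stats(feature_root: Path) -> dict:
    # Running per-dimension sums over all utterances.
    n, s, sq = 0, 0.0, 0.0
    for npy in feature_root.glob("*.npy"):
        feats = np.load(npy)                    # shape (frames, 80)
        n += feats.shape[0]
        s = s + feats.sum(axis=0)
        sq = sq + (feats ** 2).sum(axis=0)
    mean = s / n
    std = np.sqrt(np.maximum(sq / n - mean ** 2, 1e-8))
    return {"mean": mean.astype(np.float32), "std": std.astype(np.float32)}
```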
@@ -230,9 +222,6 @@ def process(args):
     train_text = []
     if args.overwrite or gen_manifest_flag:
-        if len(frame_dict) == 0:
-            with open(frame_path, "rb") as f:
-                frame_dict = pickle.load(f)
         print("Fetching ZIP manifest...")
         zip_manifest = get_zip_manifest(zip_path)
@@ -247,10 +236,10 @@ def process(args):
         for idx in range(len(dataset)):
             items = dataset.get_fast(idx)
             for item in items:
-                _, sr, src_utt, tgt_utt, speaker_id, utt_id = item
+                _, sr, n_frames, src_utt, tgt_utt, speaker_id, utt_id = item
                 manifest["id"].append(utt_id)
                 manifest["audio"].append(zip_manifest[utt_id])
-                duration_ms = int(frame_dict[utt_id] / sr * 1000)
+                duration_ms = int(n_frames / sr * 1000)
                 manifest["n_frames"].append(int(1 + (duration_ms - 25) / 10))
                 if args.lowercase_src:
                     src_utt = src_utt.lower()
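This is the core of the fix: the duration now comes from the n_frames carried in each item (samples divided by the sample rate) instead of a frame_dict lookup, and the manifest's n_frames column counts 25 ms analysis windows at a 10 ms shift over that duration. A quick check with illustrative values:

```python
n_frames, sr = 160000, 16000                      # 10 s utterance (illustrative)
duration_ms = int(n_frames / sr * 1000)           # 10000 ms
fbank_frames = int(1 + (duration_ms - 25) / 10)   # 998 filter-bank frames
print(duration_ms, fbank_frames)
```

The same change is applied to the generic ST_Dataset script in the hunks below.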
@@ -98,10 +98,11 @@ class ST_Dataset(Dataset):
         items = []
         if self.speed_perturb is None:
             waveform, _ = torchaudio.load(wav_path, frame_offset=offset, num_frames=n_frames)
-            items.append([waveform, sr, src_utt, tgt_utt, spk_id, utt_id])
+            items.append([waveform, sr, n_frames, src_utt, tgt_utt, spk_id, utt_id])
         else:
             for speed in self.speed_perturb:
                 sp_utt_id = f"sp{speed}_" + utt_id
+                sp_n_frames = n_frames / speed
                 if speed == 1.0:
                     waveform, _ = torchaudio.load(wav_path, frame_offset=offset, num_frames=n_frames)
                 else:
@@ -112,7 +113,7 @@ class ST_Dataset(Dataset):
                     ]
                     waveform, _ = torchaudio.sox_effects.apply_effects_tensor(waveform, sr, effects)
-                items.append([waveform, sr, src_utt, tgt_utt, spk_id, sp_utt_id])
+                items.append([waveform, sr, sp_n_frames, src_utt, tgt_utt, spk_id, sp_utt_id])
         return items
     def get_fast(self, n: int):
@@ -120,11 +121,12 @@ class ST_Dataset(Dataset):
         items = []
         if self.speed_perturb is None:
-            items.append([wav_path, sr, src_utt, tgt_utt, spk_id, utt_id])
+            items.append([wav_path, sr, n_frames, src_utt, tgt_utt, spk_id, utt_id])
         else:
             for speed in self.speed_perturb:
                 sp_utt_id = f"sp{speed}_" + utt_id
-                items.append([wav_path, sr, src_utt, tgt_utt, spk_id, sp_utt_id])
+                sp_n_frames = n_frames / speed
+                items.append([wav_path, sr, sp_n_frames, src_utt, tgt_utt, spk_id, sp_utt_id])
         return items
     def get_src_text(self):
@@ -167,19 +169,13 @@ def process(args):
         zip_path = output_root / "fbank80_sp.zip"
     else:
         zip_path = output_root / "fbank80.zip"
-    frame_path = output_root / "frame.pkl"
-    frame_dict = {}
     index = 0
     gen_feature_flag = False
     if not Path.exists(zip_path):
         gen_feature_flag = True
-    gen_frame_flag = False
-    if not Path.exists(frame_path):
-        gen_frame_flag = True
-    if args.overwrite or gen_feature_flag or gen_frame_flag:
+    if args.overwrite or gen_feature_flag:
         for split in splits:
             print(f"Fetching split {split}...")
             dataset = ST_Dataset(root.as_posix(), src_lang, tgt_lang, split, args.speed_perturb)
@@ -194,7 +190,6 @@ def process(args):
                     index += 1
                     waveform, sr, _, _, _, utt_id = item
-                    frame_dict[utt_id] = waveform.size(1)
                     if gen_feature_flag:
                         features_path = (feature_root / f"{utt_id}.npy").as_posix()
                         features = extract_fbank_features(waveform, sr, Path(features_path))
@@ -212,9 +207,6 @@ def process(args):
                 with open(output_root / "gcmvn.npz", "wb") as f:
                     np.savez(f, mean=stats["mean"], std=stats["std"])
-    with open(frame_path, "wb") as f:
-        pickle.dump(frame_dict, f)
     # Pack features into ZIP
     print("ZIPing features...")
     create_zip(feature_root, zip_path)
@@ -227,10 +219,6 @@ def process(args):
     train_text = []
     if args.overwrite or gen_manifest_flag:
-        if len(frame_dict) == 0:
-            with open(frame_path, "rb") as f:
-                frame_dict = pickle.load(f)
         print("Fetching ZIP manifest...")
         zip_manifest = get_zip_manifest(zip_path)
         # Generate TSV manifest
@@ -244,10 +232,10 @@ def process(args):
         for idx in range(len(dataset)):
             items = dataset.get_fast(idx)
             for item in items:
-                _, sr, src_utt, tgt_utt, speaker_id, utt_id = item
+                _, sr, n_frames, src_utt, tgt_utt, speaker_id, utt_id = item
                 manifest["id"].append(utt_id)
                 manifest["audio"].append(zip_manifest[utt_id])
-                duration_ms = int(frame_dict[utt_id] / sr * 1000)
+                duration_ms = int(n_frames / sr * 1000)
                 manifest["n_frames"].append(int(1 + (duration_ms - 25) / 10))
                 if args.lowercase_src:
                     src_utt = src_utt.lower()