Commit 0b804492 by xuchen

Fix a bug in the ST preprocessing

parent bec87003
@@ -103,10 +103,11 @@ class MUSTC(Dataset):
         items = []
         if self.speed_perturb is None:
             waveform, _ = torchaudio.load(wav_path, frame_offset=offset, num_frames=n_frames)
-            items.append([waveform, sr, src_utt, tgt_utt, spk_id, utt_id])
+            items.append([waveform, sr, n_frames, src_utt, tgt_utt, spk_id, utt_id])
         else:
             for speed in self.speed_perturb:
                 sp_utt_id = f"sp{speed}_" + utt_id
+                sp_n_frames = n_frames / speed
                 if speed == 1.0:
                     waveform, _ = torchaudio.load(wav_path, frame_offset=offset, num_frames=n_frames)
                 else:
@@ -117,7 +118,7 @@ class MUSTC(Dataset):
                     ]
                     waveform, _ = torchaudio.sox_effects.apply_effects_tensor(waveform, sr, effects)
-                items.append([waveform, sr, src_utt, tgt_utt, spk_id, sp_utt_id])
+                items.append([waveform, sr, sp_n_frames, src_utt, tgt_utt, spk_id, sp_utt_id])
         return items

     def get_fast(self, n: int):
@@ -125,11 +126,12 @@ class MUSTC(Dataset):
         items = []
         if self.speed_perturb is None:
-            items.append([wav_path, sr, src_utt, tgt_utt, spk_id, utt_id])
+            items.append([wav_path, sr, n_frames, src_utt, tgt_utt, spk_id, utt_id])
         else:
             for speed in self.speed_perturb:
                 sp_utt_id = f"sp{speed}_" + utt_id
-                items.append([wav_path, sr, src_utt, tgt_utt, spk_id, sp_utt_id])
+                sp_n_frames = n_frames / speed
+                items.append([wav_path, sr, sp_n_frames, src_utt, tgt_utt, spk_id, sp_utt_id])
         return items

     def get_src_text(self):
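The core of the change in get() and get_fast() is that every item now carries a frame count, and for a speed-perturbed copy that count is estimated as n_frames / speed rather than measured from the decoded tensor. A minimal sketch (not part of the patch; it assumes a torchaudio build with sox support) showing that the scaled count tracks the length the sox speed/rate chain actually produces:

import torch
import torchaudio

# A synthetic 1-second, 16 kHz tone standing in for a MuST-C segment.
sr = 16000
t = torch.arange(sr, dtype=torch.float32) / sr
waveform = torch.sin(2 * torch.pi * 440.0 * t).unsqueeze(0)  # shape (1, sr)

for speed in (0.9, 1.0, 1.1):
    if speed == 1.0:
        perturbed = waveform
    else:
        # Same effect chain as in the patch: change speed, resample back to sr.
        effects = [["speed", f"{speed}"], ["rate", f"{sr}"]]
        perturbed, _ = torchaudio.sox_effects.apply_effects_tensor(waveform, sr, effects)
    # n_frames / speed is the estimate the patch stores as sp_n_frames.
    print(speed, perturbed.size(1), int(waveform.size(1) / speed))

The process() hunks below then read this carried count instead of the frame.pkl cache.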
@@ -170,19 +172,13 @@ def process(args):
         zip_path = output_root / "fbank80_sp.zip"
     else:
         zip_path = output_root / "fbank80.zip"
-    frame_path = output_root / "frame.pkl"
-    frame_dict = {}

     index = 0
     gen_feature_flag = False
     if not Path.exists(zip_path):
         gen_feature_flag = True
-    gen_frame_flag = False
-    if not Path.exists(frame_path):
-        gen_frame_flag = True
-    if args.overwrite or gen_feature_flag or gen_frame_flag:
+    if args.overwrite or gen_feature_flag:
         for split in MUSTC.SPLITS:
             print(f"Fetching split {split}...")
             dataset = MUSTC(root.as_posix(), lang, split, args.speed_perturb)
@@ -195,9 +191,8 @@ def process(args):
             for items in tqdm(dataset):
                 for item in items:
                     index += 1
-                    waveform, sr, _, _, _, utt_id = item
-                    frame_dict[utt_id] = waveform.size(1)
+                    waveform, sr, _, _, _, _, utt_id = item

                     if gen_feature_flag:
                         features_path = (feature_root / f"{utt_id}.npy").as_posix()
                         features = extract_fbank_features(waveform, sr, Path(features_path))
@@ -215,12 +210,9 @@ def process(args):
         with open(output_root / "gcmvn.npz", "wb") as f:
             np.savez(f, mean=stats["mean"], std=stats["std"])
-        with open(frame_path, "wb") as f:
-            pickle.dump(frame_dict, f)

         # Pack features into ZIP
         print("ZIPing features...")
         create_zip(feature_root, zip_path)

     gen_manifest_flag = False
     for split in MUSTC.SPLITS:
@@ -230,9 +222,6 @@ def process(args):
         train_text = []
         if args.overwrite or gen_manifest_flag:
-            if len(frame_dict) == 0:
-                with open(frame_path, "rb") as f:
-                    frame_dict = pickle.load(f)
             print("Fetching ZIP manifest...")
             zip_manifest = get_zip_manifest(zip_path)
@@ -247,10 +236,10 @@ def process(args):
             for idx in range(len(dataset)):
                 items = dataset.get_fast(idx)
                 for item in items:
-                    _, sr, src_utt, tgt_utt, speaker_id, utt_id = item
+                    _, sr, n_frames, src_utt, tgt_utt, speaker_id, utt_id = item
                     manifest["id"].append(utt_id)
                     manifest["audio"].append(zip_manifest[utt_id])
-                    duration_ms = int(frame_dict[utt_id] / sr * 1000)
+                    duration_ms = int(n_frames / sr * 1000)
                     manifest["n_frames"].append(int(1 + (duration_ms - 25) / 10))
                     if args.lowercase_src:
                         src_utt = src_utt.lower()
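With the count carried in each item, the manifest duration is computed directly as duration_ms = n_frames / sr * 1000, and the number of 80-dim fbank frames follows from the 25 ms window and 10 ms hop implied by 1 + (duration_ms - 25) / 10. A small self-contained sketch of that estimate (the function name is illustrative, not from the patch):

def estimated_fbank_frames(n_samples: float, sample_rate: int) -> int:
    # Clip length in milliseconds, as in the manifest loop above.
    duration_ms = int(n_samples / sample_rate * 1000)
    # A 25 ms analysis window shifted by 10 ms fits 1 + (duration_ms - 25) / 10 times.
    return int(1 + (duration_ms - 25) / 10)

# A 3-second clip at 16 kHz gives about 298 frames; its 0.9x-speed copy is longer.
print(estimated_fbank_frames(48000, 16000))        # 298
print(estimated_fbank_frames(48000 / 0.9, 16000))  # 331

The second file applies the same change to the generic ST_Dataset script: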
@@ -98,10 +98,11 @@ class ST_Dataset(Dataset):
         items = []
         if self.speed_perturb is None:
             waveform, _ = torchaudio.load(wav_path, frame_offset=offset, num_frames=n_frames)
-            items.append([waveform, sr, src_utt, tgt_utt, spk_id, utt_id])
+            items.append([waveform, sr, n_frames, src_utt, tgt_utt, spk_id, utt_id])
         else:
             for speed in self.speed_perturb:
                 sp_utt_id = f"sp{speed}_" + utt_id
+                sp_n_frames = n_frames / speed
                 if speed == 1.0:
                     waveform, _ = torchaudio.load(wav_path, frame_offset=offset, num_frames=n_frames)
                 else:
@@ -112,7 +113,7 @@ class ST_Dataset(Dataset):
                     ]
                     waveform, _ = torchaudio.sox_effects.apply_effects_tensor(waveform, sr, effects)
-                items.append([waveform, sr, src_utt, tgt_utt, spk_id, sp_utt_id])
+                items.append([waveform, sr, sp_n_frames, src_utt, tgt_utt, spk_id, sp_utt_id])
         return items

     def get_fast(self, n: int):
@@ -120,11 +121,12 @@ class ST_Dataset(Dataset):
         items = []
         if self.speed_perturb is None:
-            items.append([wav_path, sr, src_utt, tgt_utt, spk_id, utt_id])
+            items.append([wav_path, sr, n_frames, src_utt, tgt_utt, spk_id, utt_id])
         else:
             for speed in self.speed_perturb:
                 sp_utt_id = f"sp{speed}_" + utt_id
-                items.append([wav_path, sr, src_utt, tgt_utt, spk_id, sp_utt_id])
+                sp_n_frames = n_frames / speed
+                items.append([wav_path, sr, sp_n_frames, src_utt, tgt_utt, spk_id, sp_utt_id])
         return items

     def get_src_text(self):
@@ -167,19 +169,13 @@ def process(args):
         zip_path = output_root / "fbank80_sp.zip"
     else:
         zip_path = output_root / "fbank80.zip"
-    frame_path = output_root / "frame.pkl"
-    frame_dict = {}

     index = 0
     gen_feature_flag = False
     if not Path.exists(zip_path):
         gen_feature_flag = True
-    gen_frame_flag = False
-    if not Path.exists(frame_path):
-        gen_frame_flag = True
-    if args.overwrite or gen_feature_flag or gen_frame_flag:
+    if args.overwrite or gen_feature_flag:
         for split in splits:
             print(f"Fetching split {split}...")
             dataset = ST_Dataset(root.as_posix(), src_lang, tgt_lang, split, args.speed_perturb)
@@ -194,7 +190,6 @@ def process(args):
                     index += 1
                     waveform, sr, _, _, _, utt_id = item
-                    frame_dict[utt_id] = waveform.size(1)

                     if gen_feature_flag:
                         features_path = (feature_root / f"{utt_id}.npy").as_posix()
                         features = extract_fbank_features(waveform, sr, Path(features_path))
@@ -212,12 +207,9 @@ def process(args):
         with open(output_root / "gcmvn.npz", "wb") as f:
             np.savez(f, mean=stats["mean"], std=stats["std"])
-        with open(frame_path, "wb") as f:
-            pickle.dump(frame_dict, f)

         # Pack features into ZIP
         print("ZIPing features...")
         create_zip(feature_root, zip_path)

     gen_manifest_flag = False
     for split in splits:
@@ -227,10 +219,6 @@ def process(args):
         train_text = []
         if args.overwrite or gen_manifest_flag:
-            if len(frame_dict) == 0:
-                with open(frame_path, "rb") as f:
-                    frame_dict = pickle.load(f)
             print("Fetching ZIP manifest...")
             zip_manifest = get_zip_manifest(zip_path)
         # Generate TSV manifest
@@ -244,10 +232,10 @@ def process(args):
             for idx in range(len(dataset)):
                 items = dataset.get_fast(idx)
                 for item in items:
-                    _, sr, src_utt, tgt_utt, speaker_id, utt_id = item
+                    _, sr, n_frames, src_utt, tgt_utt, speaker_id, utt_id = item
                     manifest["id"].append(utt_id)
                     manifest["audio"].append(zip_manifest[utt_id])
-                    duration_ms = int(frame_dict[utt_id] / sr * 1000)
+                    duration_ms = int(n_frames / sr * 1000)
                     manifest["n_frames"].append(int(1 + (duration_ms - 25) / 10))
                     if args.lowercase_src:
                         src_utt = src_utt.lower()
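One point worth spelling out: get_fast() returns the wav path rather than a decoded waveform, so the manifest loop never has a tensor whose length it could measure; carrying the frame count inside each item is what makes the frame.pkl lookup unnecessary. A pure-Python sketch of that bookkeeping (make_fast_entries is an illustrative name, not part of the patch):

from typing import List, Optional, Tuple

def make_fast_entries(
    utt_id: str,
    n_frames: int,
    speed_perturb: Optional[List[float]],
) -> List[Tuple[str, float]]:
    # Mirror get_fast(): one (utterance id, frame count) entry per output copy.
    if speed_perturb is None:
        return [(utt_id, float(n_frames))]
    # Perturbed copies reuse the source count scaled by the speed factor,
    # so no audio has to be decoded just to know how long each copy is.
    return [(f"sp{speed}_" + utt_id, n_frames / speed) for speed in speed_perturb]

# Example: one 48000-sample utterance expanded into three perturbed entries.
for entry in make_fast_entries("ted_1042_7", 48000, [0.9, 1.0, 1.1]):
    print(entry)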