Commit c02e267c by xuchen

modify the preprocessing of s2t

parent bb6ce82c
@@ -213,6 +213,7 @@ def process(args):
     # Extract features
     feature_root = output_root / "fbank80"
     feature_root.mkdir(exist_ok=True)
     for split in CoVoST.SPLITS:
         print(f"Fetching split {split}...")
         dataset = CoVoST(root, split, args.src_lang, args.tgt_lang)
@@ -270,24 +271,28 @@ def process(args):
     v_size_str = "" if args.vocab_type == "char" else str(args.vocab_size)
     spm_filename_prefix = f"spm_{args.vocab_type}{v_size_str}_{task}"
     asr_spm_filename = None
+    gen_vocab_flag = True
     if args.task == "st" and args.add_src:
         if args.share:
             if args.st_spm_prefix is not None:
+                gen_vocab_flag = False
                 spm_filename_prefix = args.st_spm_prefix
             else:
                 spm_filename_prefix = f"spm_{args.vocab_type}{v_size_str}_{args.task}_share"
             asr_spm_filename = spm_filename_prefix + ".model"
         else:
             if args.st_spm_prefix is not None:
+                gen_vocab_flag = False
                 spm_filename_prefix = args.st_spm_prefix
             assert args.asr_prefix is not None
             asr_spm_filename = args.asr_prefix + ".model"
     elif args.task == "asr":
         if args.asr_prefix is not None:
+            gen_vocab_flag = False
            spm_filename_prefix = args.asr_prefix
-    if args.st_spm_prefix is None:
+    if gen_vocab_flag:
         with NamedTemporaryFile(mode="w") as f:
             for t in train_text:
...
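With the new `gen_vocab_flag` guard, SentencePiece training only runs when no existing prefix was supplied; when `--st-spm-prefix` (or `--asr-prefix` for ASR) is given, that prefix is reused as-is and, in the shared setup, also recorded as `asr_spm_filename`. Below is a minimal sketch of what the guarded branch amounts to, calling sentencepiece directly rather than the repository's own vocab helper; the function name, output path and parameters are illustrative, not part of the commit.

from pathlib import Path
from tempfile import NamedTemporaryFile

import sentencepiece as spm


def maybe_train_spm(train_text, output_root: Path, spm_filename_prefix: str,
                    vocab_type: str, vocab_size: int, gen_vocab_flag: bool) -> Path:
    # Train a new SentencePiece model only when no existing prefix was supplied.
    model_path = output_root / f"{spm_filename_prefix}.model"
    if gen_vocab_flag:
        with NamedTemporaryFile(mode="w") as f:
            for t in train_text:
                f.write(t + "\n")
            f.flush()
            spm.SentencePieceTrainer.train(
                input=f.name,
                model_prefix=(output_root / spm_filename_prefix).as_posix(),
                model_type=vocab_type,  # "unigram", "bpe" or "char"
                vocab_size=vocab_size,
            )
    # Either way, the .model file under this prefix is what the generated
    # config YAML is expected to point to.
    return model_path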
@@ -13,7 +13,7 @@ from itertools import groupby
 from tempfile import NamedTemporaryFile
 from typing import Tuple
 import string
-import pickle
+import csv
 import numpy as np
 import pandas as pd
@@ -263,15 +263,51 @@ def process(args):
         df = filter_manifest_df(df, is_train_split=is_train_split)
         save_df_to_tsv(df, output_root / f"{split}_{args.task}.tsv")
+    # Generate vocab
+    v_size_str = "" if args.vocab_type == "char" else str(args.vocab_size)
+    spm_filename_prefix = f"spm_{args.vocab_type}{v_size_str}_{args.task}"
+    asr_spm_filename = None
+    gen_vocab_flag = True
+    if args.task == "st" and args.add_src:
+        if args.share:
+            if args.st_spm_prefix is not None:
+                gen_vocab_flag = False
+                spm_filename_prefix = args.st_spm_prefix
+            else:
+                spm_filename_prefix = f"spm_{args.vocab_type}{v_size_str}_{args.task}_share"
+            asr_spm_filename = spm_filename_prefix + ".model"
+        else:
+            if args.st_spm_prefix is not None:
+                gen_vocab_flag = False
+                spm_filename_prefix = args.st_spm_prefix
+            assert args.asr_prefix is not None
+            asr_spm_filename = args.asr_prefix + ".model"
+    elif args.task == "asr":
+        if args.asr_prefix is not None:
+            gen_vocab_flag = False
+            spm_filename_prefix = args.asr_prefix
+    if gen_vocab_flag:
         if len(train_text) == 0:
             print("Loading the training text to build dictionary...")
             for split in MUSTC.SPLITS:
                 if split.startswith("train"):
-                    dataset = MUSTC(args.data_root, lang, split, args.speed_perturb, args.tokenizer)
-                    src_text = dataset.get_src_text()
-                    tgt_text = dataset.get_tgt_text()
-                    for src_utt, tgt_utt in zip(src_text, tgt_text):
+                    csv_path = output_root / f"{split}_{args.task}.tsv"
+                    with open(csv_path) as f:
+                        reader = csv.DictReader(
+                            f,
+                            delimiter="\t",
+                            quotechar=None,
+                            doublequote=False,
+                            lineterminator="\n",
+                            quoting=csv.QUOTE_NONE,
+                        )
                         if args.task == "st" and args.add_src and args.share:
+                            for e in reader:
+                                src_utt = dict(e)["src_text"]
                                 if args.lowercase_src:
                                     src_utt = src_utt.lower()
                                 if args.rm_punc_src:
@@ -279,19 +315,8 @@ def process(args):
                                         src_utt = src_utt.replace(w, "")
                                 src_utt = src_utt.replace(" ", "")
                                 train_text.append(src_utt)
-                    train_text.append(tgt_utt)
+                        tgt_text = [dict(e)["tgt_text"] for e in reader]
+                        train_text.extend(tgt_text)
-    # Generate vocab
-    v_size_str = "" if args.vocab_type == "char" else str(args.vocab_size)
-    spm_filename_prefix = f"spm_{args.vocab_type}{v_size_str}_{args.task}"
-    if args.task == "st" and args.add_src:
-        if args.share:
-            spm_filename_prefix = f"spm_{args.vocab_type}{v_size_str}_{args.task}_share"
-            asr_spm_filename = spm_filename_prefix + ".model"
-        else:
-            asr_spm_filename = args.asr_prefix + ".model"
-    else:
-        asr_spm_filename = None
     with NamedTemporaryFile(mode="w") as f:
         for t in train_text:
@@ -302,6 +327,7 @@ def process(args):
             args.vocab_type,
             args.vocab_size,
         )
     # Generate config YAML
     yaml_filename = f"config_{args.task}.yaml"
     if args.task == "st" and args.add_src and args.share:
@@ -388,6 +414,7 @@ def main():
                         help="share the tokenizer and dictionary of the transcription and translation")
     parser.add_argument("--add-src", action="store_true", help="add the src text for st task")
     parser.add_argument("--asr-prefix", type=str, help="prefix of the asr dict")
+    parser.add_argument("--st-spm-prefix", type=str, default=None, help="prefix of the existing st dict")
     parser.add_argument("--lowercase-src", action="store_true", help="lowercase the source text")
     parser.add_argument("--rm-punc-src", action="store_true", help="remove the punctuation of the source text")
     parser.add_argument("--tokenizer", action="store_true", help="use tokenizer txt")
...
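The training text is now collected from the `{split}_{args.task}.tsv` manifests the script has just written, instead of re-instantiating the MUSTC dataset; the `csv.DictReader` options (tab delimiter, `QUOTE_NONE`, no doublequoting) mirror how the manifests are saved, so transcripts containing quote characters round-trip unchanged. A self-contained sketch of that reading step, collecting source and target text in a single pass (the helper name and arguments are illustrative, not part of the commit):

import csv
import string
from pathlib import Path


def load_train_text(tsv_path: Path, lowercase_src: bool = False,
                    rm_punc_src: bool = False) -> list:
    # Read one {split}_{task}.tsv manifest and collect text for vocabulary training.
    train_text = []
    with open(tsv_path) as f:
        reader = csv.DictReader(
            f,
            delimiter="\t",
            quotechar=None,
            doublequote=False,
            lineterminator="\n",
            quoting=csv.QUOTE_NONE,  # manifests are written without quoting
        )
        for e in reader:
            src_utt = e.get("src_text")
            if src_utt is not None:
                if lowercase_src:
                    src_utt = src_utt.lower()
                if rm_punc_src:
                    for w in string.punctuation:
                        src_utt = src_utt.replace(w, "")
                train_text.append(src_utt.replace(" ", ""))  # spaces stripped, as the script does
            train_text.append(e["tgt_text"])
    return train_text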
@@ -12,6 +12,7 @@ import shutil
 from itertools import groupby
 from tempfile import NamedTemporaryFile
 import string
+import csv
 import numpy as np
 import pandas as pd
@@ -159,11 +160,6 @@ def process(args):
     # Extract features
     if args.speed_perturb:
-        feature_root = output_root / "fbank80_sp"
-    else:
-        feature_root = output_root / "fbank80"
-    feature_root.mkdir(exist_ok=True)
-    if args.speed_perturb:
         zip_path = output_root / "fbank80_sp.zip"
     else:
         zip_path = output_root / "fbank80.zip"
@@ -174,6 +170,12 @@ def process(args):
         gen_feature_flag = True
     if args.overwrite or gen_feature_flag:
+        if args.speed_perturb:
+            feature_root = output_root / "fbank80_sp"
+        else:
+            feature_root = output_root / "fbank80"
+        feature_root.mkdir(exist_ok=True)
         for split in splits:
             print(f"Fetching split {split}...")
             dataset = ST_Dataset(root.as_posix(), src_lang, tgt_lang, split, args.speed_perturb)
@@ -209,6 +211,9 @@ def process(args):
         print("ZIPing features...")
         create_zip(feature_root, zip_path)
+        # Clean up
+        shutil.rmtree(feature_root)
     gen_manifest_flag = False
     for split in splits:
         if not Path.exists(output_root / f"{split}_{args.task}.tsv"):
@@ -226,6 +231,7 @@ def process(args):
             manifest = {c: [] for c in MANIFEST_COLUMNS}
             if args.task == "st" and args.add_src:
                 manifest["src_text"] = []
             dataset = ST_Dataset(args.data_root, src_lang, tgt_lang, split, args.speed_perturb)
             for idx in range(len(dataset)):
                 items = dataset.get_fast(idx)
@@ -251,50 +257,65 @@ def process(args):
         if args.task == "st" and args.add_src and args.share:
             train_text.extend(manifest["src_text"])
         train_text.extend(manifest["tgt_text"])
         df = pd.DataFrame.from_dict(manifest)
         df = filter_manifest_df(df, is_train_split=is_train_split)
         save_df_to_tsv(df, output_root / f"{split}_{args.task}.tsv")
     # Generate vocab
     v_size_str = "" if args.vocab_type == "char" else str(args.vocab_size)
-    spm_filename_prefix = f"spm_{args.vocab_type}{v_size_str}_{args.task}"
     asr_spm_filename = None
+    gen_vocab_flag = True
     if args.task == "st" and args.add_src:
         if args.share:
             if args.st_spm_prefix is not None:
+                gen_vocab_flag = False
                 spm_filename_prefix = args.st_spm_prefix
             else:
                 spm_filename_prefix = f"spm_{args.vocab_type}{v_size_str}_{args.task}_share"
             asr_spm_filename = spm_filename_prefix + ".model"
         else:
             if args.st_spm_prefix is not None:
+                gen_vocab_flag = False
                 spm_filename_prefix = args.st_spm_prefix
+            else:
+                spm_filename_prefix = f"spm_{args.vocab_type}{v_size_str}_{args.task}"
             assert args.asr_prefix is not None
             asr_spm_filename = args.asr_prefix + ".model"
     elif args.task == "asr":
         if args.asr_prefix is not None:
+            gen_vocab_flag = False
             spm_filename_prefix = args.asr_prefix
+        else:
+            spm_filename_prefix = f"spm_{args.vocab_type}{v_size_str}_{args.task}"
-    if args.st_spm_prefix is None:
+    if gen_vocab_flag:
         if len(train_text) == 0:
             print("Loading the training text to build dictionary...")
-            for split in args.SPLITS:
+            for split in splits:
                 if split.startswith("train"):
-                    dataset = ST_Dataset(args.data_root, src_lang, tgt_lang, split)
-                    src_text = dataset.get_src_text()
-                    tgt_text = dataset.get_tgt_text()
-                    for src_utt, tgt_utt in zip(src_text, tgt_text):
+                    csv_path = output_root / f"{split}_{args.task}.tsv"
+                    with open(csv_path) as f:
+                        reader = csv.DictReader(
+                            f,
+                            delimiter="\t",
+                            quotechar=None,
+                            doublequote=False,
+                            lineterminator="\n",
+                            quoting=csv.QUOTE_NONE,
+                        )
                         if args.task == "st" and args.add_src and args.share:
+                            for e in reader:
+                                src_utt = dict(e)["src_text"]
                                 if args.lowercase_src:
                                     src_utt = src_utt.lower()
                                 if args.rm_punc_src:
-                                    src_utt = src_utt.translate(None, string.punctuation)
+                                    for w in string.punctuation:
+                                        src_utt = src_utt.replace(w, "")
                                 src_utt = src_utt.replace(" ", "")
                                 train_text.append(src_utt)
-                    train_text.append(tgt_utt)
+                        tgt_text = [dict(e)["tgt_text"] for e in reader]
+                        train_text.extend(tgt_text)
     with NamedTemporaryFile(mode="w") as f:
         for t in train_text:
@@ -325,9 +346,6 @@ def process(args):
             share_src_and_tgt=True if args.task == "asr" else False
         )
-    # Clean up
-    shutil.rmtree(feature_root)

 def main():
     parser = argparse.ArgumentParser()
...
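Two related moves in this file: the fbank directory is now created inside the `if args.overwrite or gen_feature_flag:` block, and the `shutil.rmtree(feature_root)` cleanup runs right after `create_zip` instead of at the end of `process()`, so the unzipped features only exist while they are being regenerated and packed. A rough, self-contained sketch of that ordering (the function name, `.npy` glob and inline zipping are illustrative, not the repository's helpers):

import shutil
import zipfile
from pathlib import Path


def pack_features(output_root: Path, speed_perturb: bool, overwrite: bool) -> Path:
    # The fbank directory only exists while features are (re)generated,
    # and it is removed as soon as the zip archive has been written.
    suffix = "_sp" if speed_perturb else ""
    zip_path = output_root / f"fbank80{suffix}.zip"
    if overwrite or not zip_path.exists():
        feature_root = output_root / f"fbank80{suffix}"
        feature_root.mkdir(exist_ok=True)
        # ... extract log-mel filter bank features into feature_root here ...
        with zipfile.ZipFile(zip_path, "w", zipfile.ZIP_STORED) as zf:
            for p in sorted(feature_root.glob("*.npy")):
                zf.write(p, p.name)
        shutil.rmtree(feature_root)  # clean up the unzipped features right away
    return zip_path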