Commit c02e267c by xuchen

Modify the preprocessing of S2T (speech-to-text) data

parent bb6ce82c
@@ -213,6 +213,7 @@ def process(args):
     # Extract features
     feature_root = output_root / "fbank80"
     feature_root.mkdir(exist_ok=True)
     for split in CoVoST.SPLITS:
         print(f"Fetching split {split}...")
         dataset = CoVoST(root, split, args.src_lang, args.tgt_lang)
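As the fbank80 directory name suggests, each utterance is stored as an 80-dimensional log-mel filterbank matrix. A minimal sketch of that per-utterance extraction, assuming torchaudio is available (the script itself goes through its own feature-extraction helper; the function name here is mine):

import numpy as np
import torchaudio
from torchaudio.compliance.kaldi import fbank

def extract_fbank_80(wav_path, out_path):
    # Compute Kaldi-compatible 80-dim log-mel filterbank features
    # and save one .npy file per utterance.
    waveform, sample_rate = torchaudio.load(wav_path)
    feats = fbank(
        waveform * (2 ** 15),  # kaldi.fbank expects 16-bit integer range
        num_mel_bins=80,
        sample_frequency=sample_rate,
    )
    np.save(out_path, feats.numpy())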
@@ -270,24 +271,28 @@ def process(args):
     v_size_str = "" if args.vocab_type == "char" else str(args.vocab_size)
     spm_filename_prefix = f"spm_{args.vocab_type}{v_size_str}_{task}"
     asr_spm_filename = None
+    gen_vocab_flag = True
     if args.task == "st" and args.add_src:
         if args.share:
             if args.st_spm_prefix is not None:
+                gen_vocab_flag = False
                 spm_filename_prefix = args.st_spm_prefix
             else:
                 spm_filename_prefix = f"spm_{args.vocab_type}{v_size_str}_{args.task}_share"
             asr_spm_filename = spm_filename_prefix + ".model"
         else:
             if args.st_spm_prefix is not None:
+                gen_vocab_flag = False
                 spm_filename_prefix = args.st_spm_prefix
             assert args.asr_prefix is not None
             asr_spm_filename = args.asr_prefix + ".model"
     elif args.task == "asr":
         if args.asr_prefix is not None:
+            gen_vocab_flag = False
             spm_filename_prefix = args.asr_prefix
-    if args.st_spm_prefix is None:
+    if gen_vocab_flag:
         with NamedTemporaryFile(mode="w") as f:
             for t in train_text:
                 f.write(t + "\n")
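The new gen_vocab_flag decides whether a SentencePiece vocabulary is trained from scratch or an existing model prefix is reused. A condensed, hypothetical restatement of that decision (the helper name and return convention are mine, not the script's):

def pick_spm_prefix(args, default_prefix):
    # Returns (spm_prefix, need_training); mirrors gen_vocab_flag above.
    if args.task == "st" and args.add_src:
        if args.st_spm_prefix is not None:
            return args.st_spm_prefix, False        # reuse an existing ST model
        if args.share:
            return default_prefix + "_share", True  # train a shared src/tgt model
        return default_prefix, True
    if args.task == "asr" and args.asr_prefix is not None:
        return args.asr_prefix, False               # reuse the ASR model
    return default_prefix, True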
...
@@ -13,7 +13,7 @@ from itertools import groupby
 from tempfile import NamedTemporaryFile
 from typing import Tuple
 import string
-import pickle
+import csv
 import numpy as np
 import pandas as pd
@@ -263,45 +263,71 @@ def process(args):
         df = filter_manifest_df(df, is_train_split=is_train_split)
         save_df_to_tsv(df, output_root / f"{split}_{args.task}.tsv")
-    if len(train_text) == 0:
-        print("Loading the training text to build dictionary...")
-        for split in MUSTC.SPLITS:
-            if split.startswith("train"):
-                dataset = MUSTC(args.data_root, lang, split, args.speed_perturb, args.tokenizer)
-                src_text = dataset.get_src_text()
-                tgt_text = dataset.get_tgt_text()
-                for src_utt, tgt_utt in zip(src_text, tgt_text):
-                    if args.task == "st" and args.add_src and args.share:
-                        if args.lowercase_src:
-                            src_utt = src_utt.lower()
-                        if args.rm_punc_src:
-                            for w in string.punctuation:
-                                src_utt = src_utt.replace(w, "")
-                            src_utt = src_utt.replace(" ", "")
-                        train_text.append(src_utt)
-                    train_text.append(tgt_utt)
     # Generate vocab
     v_size_str = "" if args.vocab_type == "char" else str(args.vocab_size)
     spm_filename_prefix = f"spm_{args.vocab_type}{v_size_str}_{args.task}"
     asr_spm_filename = None
+    gen_vocab_flag = True
     if args.task == "st" and args.add_src:
         if args.share:
-            spm_filename_prefix = f"spm_{args.vocab_type}{v_size_str}_{args.task}_share"
+            if args.st_spm_prefix is not None:
+                gen_vocab_flag = False
+                spm_filename_prefix = args.st_spm_prefix
+            else:
+                spm_filename_prefix = f"spm_{args.vocab_type}{v_size_str}_{args.task}_share"
             asr_spm_filename = spm_filename_prefix + ".model"
         else:
+            if args.st_spm_prefix is not None:
+                gen_vocab_flag = False
+                spm_filename_prefix = args.st_spm_prefix
             assert args.asr_prefix is not None
             asr_spm_filename = args.asr_prefix + ".model"
-        else:
-            asr_spm_filename = None
+    elif args.task == "asr":
+        if args.asr_prefix is not None:
+            gen_vocab_flag = False
+            spm_filename_prefix = args.asr_prefix
+    if gen_vocab_flag:
+        if len(train_text) == 0:
+            print("Loading the training text to build dictionary...")
+            for split in MUSTC.SPLITS:
+                if split.startswith("train"):
+                    csv_path = output_root / f"{split}_{args.task}.tsv"
+                    with open(csv_path) as f:
+                        reader = csv.DictReader(
+                            f,
+                            delimiter="\t",
+                            quotechar=None,
+                            doublequote=False,
+                            lineterminator="\n",
+                            quoting=csv.QUOTE_NONE,
+                        )
+                        if args.task == "st" and args.add_src and args.share:
+                            for e in reader:
+                                src_utt = dict(e)["src_text"]
+                                if args.lowercase_src:
+                                    src_utt = src_utt.lower()
+                                if args.rm_punc_src:
+                                    for w in string.punctuation:
+                                        src_utt = src_utt.replace(w, "")
+                                    src_utt = src_utt.replace(" ", "")
+                                train_text.append(src_utt)
+                                train_text.append(dict(e)["tgt_text"])
+                        else:
+                            tgt_text = [dict(e)["tgt_text"] for e in reader]
+                            train_text.extend(tgt_text)
+        with NamedTemporaryFile(mode="w") as f:
+            for t in train_text:
+                f.write(t + "\n")
+            gen_vocab(
+                Path(f.name),
+                output_root / spm_filename_prefix,
+                args.vocab_type,
+                args.vocab_size,
+            )
-    with NamedTemporaryFile(mode="w") as f:
-        for t in train_text:
-            f.write(t + "\n")
-        gen_vocab(
-            Path(f.name),
-            output_root / spm_filename_prefix,
-            args.vocab_type,
-            args.vocab_size,
-        )
     # Generate config YAML
     yaml_filename = f"config_{args.task}.yaml"
     if args.task == "st" and args.add_src and args.share:
@@ -388,6 +414,7 @@ def main():
                         help="share the tokenizer and dictionary of the transcription and translation")
     parser.add_argument("--add-src", action="store_true", help="add the src text for st task")
     parser.add_argument("--asr-prefix", type=str, help="prefix of the asr dict")
+    parser.add_argument("--st-spm-prefix", type=str, default=None, help="prefix of the existing st dict")
     parser.add_argument("--lowercase-src", action="store_true", help="lowercase the source text")
     parser.add_argument("--rm-punc-src", action="store_true", help="remove the punctuation of the source text")
     parser.add_argument("--tokenizer", action="store_true", help="use tokenizer txt")
...
@@ -12,6 +12,7 @@ import shutil
 from itertools import groupby
 from tempfile import NamedTemporaryFile
 import string
+import csv
 import numpy as np
 import pandas as pd
@@ -159,11 +160,6 @@ def process(args):
     # Extract features
-    if args.speed_perturb:
-        feature_root = output_root / "fbank80_sp"
-    else:
-        feature_root = output_root / "fbank80"
-    feature_root.mkdir(exist_ok=True)
     if args.speed_perturb:
         zip_path = output_root / "fbank80_sp.zip"
     else:
         zip_path = output_root / "fbank80.zip"
@@ -174,6 +170,12 @@ def process(args):
     gen_feature_flag = True
     if args.overwrite or gen_feature_flag:
+        if args.speed_perturb:
+            feature_root = output_root / "fbank80_sp"
+        else:
+            feature_root = output_root / "fbank80"
+        feature_root.mkdir(exist_ok=True)
         for split in splits:
             print(f"Fetching split {split}...")
             dataset = ST_Dataset(root.as_posix(), src_lang, tgt_lang, split, args.speed_perturb)
@@ -209,6 +211,9 @@ def process(args):
         print("ZIPing features...")
         create_zip(feature_root, zip_path)
+        # Clean up
+        shutil.rmtree(feature_root)
+    gen_manifest_flag = False
     for split in splits:
         if not Path.exists(output_root / f"{split}_{args.task}.tsv"):
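Moving the cleanup next to create_zip means the loose .npy files are deleted as soon as they are packed, so a failure later in manifest generation no longer leaves the feature directory behind. A minimal sketch of this pack-then-clean pattern, assuming one .npy file per utterance (this is not the repo's create_zip implementation, just the idea):

import shutil
import zipfile
from pathlib import Path

def pack_features(feature_root: Path, zip_path: Path) -> None:
    # ZIP_STORED (no compression) keeps packing fast; fbank matrices
    # barely compress anyway.
    with zipfile.ZipFile(zip_path, "w", zipfile.ZIP_STORED) as zf:
        for npy in sorted(feature_root.glob("*.npy")):
            zf.write(npy, arcname=npy.name)
    # Remove the loose files as soon as the zip is complete.
    shutil.rmtree(feature_root)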
@@ -226,6 +231,7 @@ def process(args):
         manifest = {c: [] for c in MANIFEST_COLUMNS}
         if args.task == "st" and args.add_src:
             manifest["src_text"] = []
+        dataset = ST_Dataset(args.data_root, src_lang, tgt_lang, split, args.speed_perturb)
         for idx in range(len(dataset)):
             items = dataset.get_fast(idx)
@@ -251,50 +257,65 @@ def process(args):
         if args.task == "st" and args.add_src and args.share:
             train_text.extend(manifest["src_text"])
         train_text.extend(manifest["tgt_text"])
         df = pd.DataFrame.from_dict(manifest)
         df = filter_manifest_df(df, is_train_split=is_train_split)
         save_df_to_tsv(df, output_root / f"{split}_{args.task}.tsv")
     # Generate vocab
     v_size_str = "" if args.vocab_type == "char" else str(args.vocab_size)
     spm_filename_prefix = f"spm_{args.vocab_type}{v_size_str}_{args.task}"
     asr_spm_filename = None
+    gen_vocab_flag = True
     if args.task == "st" and args.add_src:
         if args.share:
             if args.st_spm_prefix is not None:
+                gen_vocab_flag = False
                 spm_filename_prefix = args.st_spm_prefix
             else:
                 spm_filename_prefix = f"spm_{args.vocab_type}{v_size_str}_{args.task}_share"
             asr_spm_filename = spm_filename_prefix + ".model"
         else:
             if args.st_spm_prefix is not None:
+                gen_vocab_flag = False
                 spm_filename_prefix = args.st_spm_prefix
             else:
                 spm_filename_prefix = f"spm_{args.vocab_type}{v_size_str}_{args.task}"
             assert args.asr_prefix is not None
             asr_spm_filename = args.asr_prefix + ".model"
     elif args.task == "asr":
         if args.asr_prefix is not None:
+            gen_vocab_flag = False
             spm_filename_prefix = args.asr_prefix
         else:
             spm_filename_prefix = f"spm_{args.vocab_type}{v_size_str}_{args.task}"
-    if args.st_spm_prefix is None:
+    if gen_vocab_flag:
         if len(train_text) == 0:
             print("Loading the training text to build dictionary...")
-            for split in splits:
+            for split in args.SPLITS:
                 if split.startswith("train"):
-                    dataset = ST_Dataset(args.data_root, src_lang, tgt_lang, split)
-                    src_text = dataset.get_src_text()
-                    tgt_text = dataset.get_tgt_text()
-                    for src_utt, tgt_utt in zip(src_text, tgt_text):
+                    csv_path = output_root / f"{split}_{args.task}.tsv"
+                    with open(csv_path) as f:
+                        reader = csv.DictReader(
+                            f,
+                            delimiter="\t",
+                            quotechar=None,
+                            doublequote=False,
+                            lineterminator="\n",
+                            quoting=csv.QUOTE_NONE,
+                        )
                         if args.task == "st" and args.add_src and args.share:
+                            for e in reader:
+                                src_utt = dict(e)["src_text"]
                                 if args.lowercase_src:
                                     src_utt = src_utt.lower()
                                 if args.rm_punc_src:
-                                    src_utt = src_utt.translate(None, string.punctuation)
+                                    for w in string.punctuation:
+                                        src_utt = src_utt.replace(w, "")
+                                    src_utt = src_utt.replace(" ", "")
                                 train_text.append(src_utt)
-                            train_text.append(tgt_utt)
+                                train_text.append(dict(e)["tgt_text"])
+                        else:
+                            tgt_text = [dict(e)["tgt_text"] for e in reader]
+                            train_text.extend(tgt_text)
         with NamedTemporaryFile(mode="w") as f:
             for t in train_text:
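When gen_vocab_flag is true, the collected text is written to a temporary file and handed to gen_vocab, which wraps SentencePiece training. A rough equivalent using the sentencepiece package directly (keyword names follow sentencepiece, not the fairseq-style helper; the function name is mine):

import sentencepiece as spm

def train_spm(text_file, model_prefix, vocab_type="unigram", vocab_size=10000):
    # Train on one-sentence-per-line text, which is exactly what the
    # temporary file above contains; writes model_prefix.model/.vocab.
    spm.SentencePieceTrainer.train(
        input=text_file,
        model_prefix=model_prefix,
        model_type=vocab_type,  # "unigram", "bpe", or "char"
        vocab_size=vocab_size,
    )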
@@ -325,9 +346,6 @@ def process(args):
         share_src_and_tgt=True if args.task == "asr" else False
     )
-    # Clean up
-    shutil.rmtree(feature_root)

 def main():
     parser = argparse.ArgumentParser()
...