Commit e59c8eb4 by xuchen

bug fix and small optimization

parent 919a3e57
......@@ -4,7 +4,7 @@ gpu_num=1
data_tag=st
test_subset=(tst-COMMON_en-de tst-COMMON_en-fr tst-COMMON_en-es tst-COMMON_en-it tst-COMMON_en-nl tst-COMMON_en-pt tst-COMMON_en-ro tst-COMMON_en-ru)
#test_subset=(tst-COMMON_en-de)
test_subset=(tst-COMMON_en-de)
#test_subset=(test_en-fr_1k)
exp_name=
......@@ -16,7 +16,7 @@ sacrebleu=1
ctc_infer=1
n_average=10
beam_size=5
infer_ctc_weight=0.1
infer_ctc_weight=0
len_penalty=1.0
max_tokens=20000
batch_size=1
......
......@@ -37,7 +37,7 @@ from tqdm import tqdm
logger = logging.getLogger(__name__)
MANIFEST_COLUMNS = ["id", "audio", "n_frames", "tgt_text", "tgt_lang"]
MANIFEST_COLUMNS = ["id", "audio", "n_frames", "tgt_text", "src_lang", "tgt_lang"]
class AudioDataset(Dataset):
......@@ -114,26 +114,33 @@ class AudioDataset(Dataset):
if 0 < self.size < total_length:
utterances = utterances[: self.size]
for idx, u in enumerate(utterances):
segments[idx][_lang] = u
# split = split.replace("_gen", "")
# Gather info
self.data = dict()
if self.mode == "easy":
real_idx = 0
for idx, v in segments.items():
audio_name = f"{split}_{v['audio']}"
v["audio"] = (wav_root / v["audio"].strip()).as_posix() + ".wav"
audio_name = os.path.splitext(audio_name)[0]
full_audio_path = (wav_root / v["audio"].strip()).as_posix()
if not os.path.exists(full_audio_path):
full_audio_path += ".wav"
if not os.path.exists(full_audio_path):
logger.warning("No audio: {}".format(full_audio_path))
continue
v["audio"] = full_audio_path
if self.speed_perturb is not None:
for perturb in self.speed_perturb:
sp_item = copy.deepcopy(v)
sp_item["perturb"] = perturb
sp_item["id"] = f"{audio_name}_sp{perturb}"
sp_item["idx"] = f"{audio_name}_sp{perturb}"
self.data[real_idx] = sp_item
real_idx += 1
else:
v["id"] = audio_name
v["idx"] = audio_name
self.data[real_idx] = v
real_idx += 1
if 0 < self.size <= real_idx:
......@@ -163,12 +170,12 @@ class AudioDataset(Dataset):
if self.speed_perturb is not None:
for perturb in self.speed_perturb:
sp_item = copy.deepcopy(item)
sp_item["id"] = f"{_id}_sp{perturb}"
sp_item["idx"] = f"{_id}_sp{perturb}"
sp_item["perturb"] = perturb
self.data[idx] = sp_item
idx += 1
else:
item["id"] = _id
item["idx"] = _id
self.data[idx] = item
idx += 1
if 0 < self.size <= idx:
......@@ -280,22 +287,22 @@ def process(args):
for idx in tqdm(range(len(dataset))):
item = dataset[idx]
utt_id = item["id"]
utt_id = item["idx"]
features_path = (feature_root / f"{utt_id}.npy").as_posix()
if os.path.exists(features_path):
continue
try:
waveform, sample_rate, _ = dataset.get(idx, need_waveform=True)
if waveform.shape[1] == 0:
continue
try:
features = extract_fbank_features(
waveform, sample_rate, Path(features_path)
)
except AssertionError:
logger.warning("Extract file %s failed." % utt_id)
except RuntimeError:
logger.warning("Get info of audio file %s failed." % utt_id)
if (
split == "train"
......@@ -354,10 +361,14 @@ def process(args):
)
if args.task == "st" and args.add_src and dataset.have_src_utt:
manifest["src_text"] = []
for idx in tqdm(range(len(dataset))):
item = dataset[idx]
try:
_, sample_rate, n_frames = dataset.get(idx, need_waveform=False)
utt_id = item["id"]
except RuntimeError:
logger.warning("Get info of audio file %s failed." % item["idx"])
utt_id = item["idx"]
if use_raw:
audio_path = item["audio"]
......@@ -398,6 +409,7 @@ def process(args):
if args.add_src and src_utt is not None:
manifest["src_text"].append(src_utt)
manifest["tgt_text"].append(tgt_utt)
manifest["src_lang"].append(src_lang)
manifest["tgt_lang"].append(tgt_lang)
if is_train_split:
......@@ -557,7 +569,9 @@ def process_joint(args):
special_symbols = None
if args.task == 'st':
special_symbols = [f'<lang:{lang.split("-")[1]}>' for lang in languages]
special_symbols = [f'<lang:{lang.split("-")[0]}>' for lang in languages]
special_symbols.extend([f'<lang:{lang.split("-")[1]}>' for lang in languages])
special_symbols = list(set(special_symbols))
gen_vocab(
Path(f.name),
output_root / spm_filename_prefix,
......@@ -585,7 +599,8 @@ def process_joint(args):
for split in args.splits.split(","):
src_path = cur_root / f"{lang}" / f"{task}" / f"{split}.tsv"
desc_path = output_root / f"{split}_{lang}.tsv"
if not desc_path.is_symlink():
if not os.path.exists(desc_path) and os.path.exists(src_path):
# if not desc_path.is_symlink():
shutil.copy(src_path, desc_path)
def main():
......@@ -635,7 +650,6 @@ def main():
parser.add_argument(
"--vocab-type",
default="unigram",
required=True,
type=str,
choices=["word", "bpe", "unigram", "char"],
),
......
......@@ -177,7 +177,6 @@ def get_features_from_npy_or_audio(path):
def get_features_or_waveform_from_uncompressed_zip(
path, byte_offset, byte_size, need_waveform=False
):
assert path.endswith(".zip")
data = read_from_uncompressed_zip(path, byte_offset, byte_size)
f = io.BytesIO(data)
if is_npy_data(data):
......@@ -341,6 +340,13 @@ class SpeechToTextDataset(FairseqDataset):
]
assert all(t in self.tgt_dict for t in tgt_lang_tags)
if self.data_cfg.prepend_src_lang_tag_to_enc:
assert self.src_langs is not None and self.src_dict is not None
src_lang_tags = [
self.LANG_TAG_TEMPLATE.format(t) for t in set(self.src_langs)
]
assert all(t in self.src_dict for t in src_lang_tags)
def tokenize_text(self, text: str, is_src=False):
if self.pre_tokenizer is not None:
text = self.pre_tokenizer.encode(text)
......@@ -571,7 +577,7 @@ class SpeechToTextDatasetCreator(object):
KEY_ID, KEY_AUDIO, KEY_N_FRAMES = "id", "audio", "n_frames"
KEY_TGT_TEXT = "tgt_text"
KEY_ALIGNED_TGT_TEXT = "aligned_tgt_text"
KEY_CTC_TGT_TEXT = "ctc_tgt_text"
KEY_CTC_TGT_TEXT = "xctc_text"
# optional columns
KEY_SPEAKER, KEY_SRC_TEXT = "speaker", "src_text"
KEY_SRC_LANG, KEY_TGT_LANG = "src_lang", "tgt_lang"
......
......@@ -438,6 +438,8 @@ class MultiheadAttention(nn.Module):
localness = 0
window = int(src_len * self.localness_window)
if window == 0:
return
for i in range(window, src_len - window):
item_localness = 0
for j in range(-window, window + 1):
......
......@@ -14,11 +14,12 @@ from fairseq.optim import FairseqOptimizer, register_optimizer
from omegaconf import II, DictConfig
try:
from deepspeed.ops.op_builder import CPUAdamBuilder
has_deepspeed_cpu_adam = True
except ImportError:
has_deepspeed_cpu_adam = False
has_deepspeed_cpu_adam = False
# try:
# from deepspeed.ops.op_builder import CPUAdamBuilder
# has_deepspeed_cpu_adam = True
# except ImportError:
# has_deepspeed_cpu_adam = False
@dataclass
......
......@@ -537,12 +537,10 @@ class SpeechToTextTask(LegacyFairseqTask):
if bpe_tokenizer is None:
bpe_tokenizer = self.data_cfg.bpe_tokenizer
logger.info(f"tokenizer: {bpe_tokenizer}")
if bpe_tokenizer is None:
return None
return encoders.build_bpe(Namespace(**bpe_tokenizer))
# def build_src_bpe(self, args):
# logger.info(f"src tokenizer: {self.data_cfg.src_bpe_tokenizer}")
# return encoders.build_bpe(Namespace(**self.data_cfg.src_bpe_tokenizer))
def get_interactive_tokens_and_lengths(self, lines, encode_fn):
n_frames = [get_features_or_waveform(p).shape[0] for p in lines]
return lines, n_frames
......
Markdown 格式
0%
您添加了 0 人到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 登录 后发表评论