Commit e59c8eb4 by xuchen

bug fix and small optimization

parent 919a3e57
...@@ -4,7 +4,7 @@ gpu_num=1 ...@@ -4,7 +4,7 @@ gpu_num=1
data_tag=st data_tag=st
test_subset=(tst-COMMON_en-de tst-COMMON_en-fr tst-COMMON_en-es tst-COMMON_en-it tst-COMMON_en-nl tst-COMMON_en-pt tst-COMMON_en-ro tst-COMMON_en-ru) test_subset=(tst-COMMON_en-de tst-COMMON_en-fr tst-COMMON_en-es tst-COMMON_en-it tst-COMMON_en-nl tst-COMMON_en-pt tst-COMMON_en-ro tst-COMMON_en-ru)
#test_subset=(tst-COMMON_en-de) test_subset=(tst-COMMON_en-de)
#test_subset=(test_en-fr_1k) #test_subset=(test_en-fr_1k)
exp_name= exp_name=
...@@ -16,7 +16,7 @@ sacrebleu=1 ...@@ -16,7 +16,7 @@ sacrebleu=1
ctc_infer=1 ctc_infer=1
n_average=10 n_average=10
beam_size=5 beam_size=5
infer_ctc_weight=0.1 infer_ctc_weight=0
len_penalty=1.0 len_penalty=1.0
max_tokens=20000 max_tokens=20000
batch_size=1 batch_size=1
......
...@@ -37,7 +37,7 @@ from tqdm import tqdm ...@@ -37,7 +37,7 @@ from tqdm import tqdm
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
MANIFEST_COLUMNS = ["id", "audio", "n_frames", "tgt_text", "tgt_lang"] MANIFEST_COLUMNS = ["id", "audio", "n_frames", "tgt_text", "src_lang", "tgt_lang"]
class AudioDataset(Dataset): class AudioDataset(Dataset):
...@@ -114,26 +114,33 @@ class AudioDataset(Dataset): ...@@ -114,26 +114,33 @@ class AudioDataset(Dataset):
if 0 < self.size < total_length: if 0 < self.size < total_length:
utterances = utterances[: self.size] utterances = utterances[: self.size]
for idx, u in enumerate(utterances): for idx, u in enumerate(utterances):
segments[idx][_lang] = u segments[idx][_lang] = u
# split = split.replace("_gen", "")
# Gather info # Gather info
self.data = dict() self.data = dict()
if self.mode == "easy": if self.mode == "easy":
real_idx = 0 real_idx = 0
for idx, v in segments.items(): for idx, v in segments.items():
audio_name = f"{split}_{v['audio']}" audio_name = f"{split}_{v['audio']}"
v["audio"] = (wav_root / v["audio"].strip()).as_posix() + ".wav" audio_name = os.path.splitext(audio_name)[0]
full_audio_path = (wav_root / v["audio"].strip()).as_posix()
if not os.path.exists(full_audio_path):
full_audio_path += ".wav"
if not os.path.exists(full_audio_path):
logger.warning("No audio: {}".format(full_audio_path))
continue
v["audio"] = full_audio_path
if self.speed_perturb is not None: if self.speed_perturb is not None:
for perturb in self.speed_perturb: for perturb in self.speed_perturb:
sp_item = copy.deepcopy(v) sp_item = copy.deepcopy(v)
sp_item["perturb"] = perturb sp_item["perturb"] = perturb
sp_item["id"] = f"{audio_name}_sp{perturb}" sp_item["idx"] = f"{audio_name}_sp{perturb}"
self.data[real_idx] = sp_item self.data[real_idx] = sp_item
real_idx += 1 real_idx += 1
else: else:
v["id"] = audio_name v["idx"] = audio_name
self.data[real_idx] = v self.data[real_idx] = v
real_idx += 1 real_idx += 1
if 0 < self.size <= real_idx: if 0 < self.size <= real_idx:
...@@ -163,12 +170,12 @@ class AudioDataset(Dataset): ...@@ -163,12 +170,12 @@ class AudioDataset(Dataset):
if self.speed_perturb is not None: if self.speed_perturb is not None:
for perturb in self.speed_perturb: for perturb in self.speed_perturb:
sp_item = copy.deepcopy(item) sp_item = copy.deepcopy(item)
sp_item["id"] = f"{_id}_sp{perturb}" sp_item["idx"] = f"{_id}_sp{perturb}"
sp_item["perturb"] = perturb sp_item["perturb"] = perturb
self.data[idx] = sp_item self.data[idx] = sp_item
idx += 1 idx += 1
else: else:
item["id"] = _id item["idx"] = _id
self.data[idx] = item self.data[idx] = item
idx += 1 idx += 1
if 0 < self.size <= idx: if 0 < self.size <= idx:
...@@ -280,22 +287,22 @@ def process(args): ...@@ -280,22 +287,22 @@ def process(args):
for idx in tqdm(range(len(dataset))): for idx in tqdm(range(len(dataset))):
item = dataset[idx] item = dataset[idx]
utt_id = item["id"] utt_id = item["idx"]
features_path = (feature_root / f"{utt_id}.npy").as_posix() features_path = (feature_root / f"{utt_id}.npy").as_posix()
if os.path.exists(features_path): if os.path.exists(features_path):
continue continue
waveform, sample_rate, _ = dataset.get(idx, need_waveform=True)
if waveform.shape[1] == 0:
continue
try: try:
waveform, sample_rate, _ = dataset.get(idx, need_waveform=True)
if waveform.shape[1] == 0:
continue
features = extract_fbank_features( features = extract_fbank_features(
waveform, sample_rate, Path(features_path) waveform, sample_rate, Path(features_path)
) )
except AssertionError: except RuntimeError:
logger.warning("Extract file %s failed." % utt_id) logger.warning("Get info of audio file %s failed." % utt_id)
if ( if (
split == "train" split == "train"
...@@ -354,10 +361,14 @@ def process(args): ...@@ -354,10 +361,14 @@ def process(args):
) )
if args.task == "st" and args.add_src and dataset.have_src_utt: if args.task == "st" and args.add_src and dataset.have_src_utt:
manifest["src_text"] = [] manifest["src_text"] = []
for idx in tqdm(range(len(dataset))): for idx in tqdm(range(len(dataset))):
item = dataset[idx] item = dataset[idx]
_, sample_rate, n_frames = dataset.get(idx, need_waveform=False) try:
utt_id = item["id"] _, sample_rate, n_frames = dataset.get(idx, need_waveform=False)
except RuntimeError:
logger.warning("Get info of audio file %s failed." % item["idx"])
utt_id = item["idx"]
if use_raw: if use_raw:
audio_path = item["audio"] audio_path = item["audio"]
...@@ -398,6 +409,7 @@ def process(args): ...@@ -398,6 +409,7 @@ def process(args):
if args.add_src and src_utt is not None: if args.add_src and src_utt is not None:
manifest["src_text"].append(src_utt) manifest["src_text"].append(src_utt)
manifest["tgt_text"].append(tgt_utt) manifest["tgt_text"].append(tgt_utt)
manifest["src_lang"].append(src_lang)
manifest["tgt_lang"].append(tgt_lang) manifest["tgt_lang"].append(tgt_lang)
if is_train_split: if is_train_split:
...@@ -557,7 +569,9 @@ def process_joint(args): ...@@ -557,7 +569,9 @@ def process_joint(args):
special_symbols = None special_symbols = None
if args.task == 'st': if args.task == 'st':
special_symbols = [f'<lang:{lang.split("-")[1]}>' for lang in languages] special_symbols = [f'<lang:{lang.split("-")[0]}>' for lang in languages]
special_symbols.extend([f'<lang:{lang.split("-")[1]}>' for lang in languages])
special_symbols = list(set(special_symbols))
gen_vocab( gen_vocab(
Path(f.name), Path(f.name),
output_root / spm_filename_prefix, output_root / spm_filename_prefix,
...@@ -585,7 +599,8 @@ def process_joint(args): ...@@ -585,7 +599,8 @@ def process_joint(args):
for split in args.splits.split(","): for split in args.splits.split(","):
src_path = cur_root / f"{lang}" / f"{task}" / f"{split}.tsv" src_path = cur_root / f"{lang}" / f"{task}" / f"{split}.tsv"
desc_path = output_root / f"{split}_{lang}.tsv" desc_path = output_root / f"{split}_{lang}.tsv"
if not desc_path.is_symlink(): if not os.path.exists(desc_path) and os.path.exists(src_path):
# if not desc_path.is_symlink():
shutil.copy(src_path, desc_path) shutil.copy(src_path, desc_path)
def main(): def main():
...@@ -635,7 +650,6 @@ def main(): ...@@ -635,7 +650,6 @@ def main():
parser.add_argument( parser.add_argument(
"--vocab-type", "--vocab-type",
default="unigram", default="unigram",
required=True,
type=str, type=str,
choices=["word", "bpe", "unigram", "char"], choices=["word", "bpe", "unigram", "char"],
), ),
......
...@@ -177,7 +177,6 @@ def get_features_from_npy_or_audio(path): ...@@ -177,7 +177,6 @@ def get_features_from_npy_or_audio(path):
def get_features_or_waveform_from_uncompressed_zip( def get_features_or_waveform_from_uncompressed_zip(
path, byte_offset, byte_size, need_waveform=False path, byte_offset, byte_size, need_waveform=False
): ):
assert path.endswith(".zip")
data = read_from_uncompressed_zip(path, byte_offset, byte_size) data = read_from_uncompressed_zip(path, byte_offset, byte_size)
f = io.BytesIO(data) f = io.BytesIO(data)
if is_npy_data(data): if is_npy_data(data):
...@@ -341,6 +340,13 @@ class SpeechToTextDataset(FairseqDataset): ...@@ -341,6 +340,13 @@ class SpeechToTextDataset(FairseqDataset):
] ]
assert all(t in self.tgt_dict for t in tgt_lang_tags) assert all(t in self.tgt_dict for t in tgt_lang_tags)
if self.data_cfg.prepend_src_lang_tag_to_enc:
assert self.src_langs is not None and self.src_dict is not None
src_lang_tags = [
self.LANG_TAG_TEMPLATE.format(t) for t in set(self.src_langs)
]
assert all(t in self.src_dict for t in src_lang_tags)
def tokenize_text(self, text: str, is_src=False): def tokenize_text(self, text: str, is_src=False):
if self.pre_tokenizer is not None: if self.pre_tokenizer is not None:
text = self.pre_tokenizer.encode(text) text = self.pre_tokenizer.encode(text)
...@@ -571,7 +577,7 @@ class SpeechToTextDatasetCreator(object): ...@@ -571,7 +577,7 @@ class SpeechToTextDatasetCreator(object):
KEY_ID, KEY_AUDIO, KEY_N_FRAMES = "id", "audio", "n_frames" KEY_ID, KEY_AUDIO, KEY_N_FRAMES = "id", "audio", "n_frames"
KEY_TGT_TEXT = "tgt_text" KEY_TGT_TEXT = "tgt_text"
KEY_ALIGNED_TGT_TEXT = "aligned_tgt_text" KEY_ALIGNED_TGT_TEXT = "aligned_tgt_text"
KEY_CTC_TGT_TEXT = "ctc_tgt_text" KEY_CTC_TGT_TEXT = "xctc_text"
# optional columns # optional columns
KEY_SPEAKER, KEY_SRC_TEXT = "speaker", "src_text" KEY_SPEAKER, KEY_SRC_TEXT = "speaker", "src_text"
KEY_SRC_LANG, KEY_TGT_LANG = "src_lang", "tgt_lang" KEY_SRC_LANG, KEY_TGT_LANG = "src_lang", "tgt_lang"
......
...@@ -438,6 +438,8 @@ class MultiheadAttention(nn.Module): ...@@ -438,6 +438,8 @@ class MultiheadAttention(nn.Module):
localness = 0 localness = 0
window = int(src_len * self.localness_window) window = int(src_len * self.localness_window)
if window == 0:
return
for i in range(window, src_len - window): for i in range(window, src_len - window):
item_localness = 0 item_localness = 0
for j in range(-window, window + 1): for j in range(-window, window + 1):
......
...@@ -14,11 +14,12 @@ from fairseq.optim import FairseqOptimizer, register_optimizer ...@@ -14,11 +14,12 @@ from fairseq.optim import FairseqOptimizer, register_optimizer
from omegaconf import II, DictConfig from omegaconf import II, DictConfig
try: has_deepspeed_cpu_adam = False
from deepspeed.ops.op_builder import CPUAdamBuilder # try:
has_deepspeed_cpu_adam = True # from deepspeed.ops.op_builder import CPUAdamBuilder
except ImportError: # has_deepspeed_cpu_adam = True
has_deepspeed_cpu_adam = False # except ImportError:
# has_deepspeed_cpu_adam = False
@dataclass @dataclass
......
...@@ -537,12 +537,10 @@ class SpeechToTextTask(LegacyFairseqTask): ...@@ -537,12 +537,10 @@ class SpeechToTextTask(LegacyFairseqTask):
if bpe_tokenizer is None: if bpe_tokenizer is None:
bpe_tokenizer = self.data_cfg.bpe_tokenizer bpe_tokenizer = self.data_cfg.bpe_tokenizer
logger.info(f"tokenizer: {bpe_tokenizer}") logger.info(f"tokenizer: {bpe_tokenizer}")
if bpe_tokenizer is None:
return None
return encoders.build_bpe(Namespace(**bpe_tokenizer)) return encoders.build_bpe(Namespace(**bpe_tokenizer))
# def build_src_bpe(self, args):
# logger.info(f"src tokenizer: {self.data_cfg.src_bpe_tokenizer}")
# return encoders.build_bpe(Namespace(**self.data_cfg.src_bpe_tokenizer))
def get_interactive_tokens_and_lengths(self, lines, encode_fn): def get_interactive_tokens_and_lengths(self, lines, encode_fn):
n_frames = [get_features_or_waveform(p).shape[0] for p in lines] n_frames = [get_features_or_waveform(p).shape[0] for p in lines]
return lines, n_frames return lines, n_frames
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论