Commit cb2f2bcb by xuchen

2023.11

parent 51395037
......@@ -83,10 +83,12 @@ epoch_ensemble=0
best_ensemble=1
infer_debug=0
infer_score=0
infer_tag=
infer_parameter=
infer_tag=ee6
infer_parameters="--early-exit-count 6"
#infer_parameters="--early-exit-layer 12"
#infer_parameters="--cal-monotonic-cross-attn-weights --cal-localness --localness-window 0.1 --cal-topk-cross-attn-weights --topk-cross-attn-weights 15 --cal-entropy"
infer_parameter="--early-exit-count 6"
#infer_parameter="--early-exit-layer 12"
#infer_parameter="--cal-monotonic-cross-attn-weights --cal-localness --localness-window 0.1 --cal-topk-cross-attn-weights --topk-cross-attn-weights 15 --cal-entropy"
data_config=config.yaml
......@@ -416,9 +418,9 @@ if [ ${stage} -le 2 ] && [ ${stop_stage} -ge 2 ]; then
cmd="${cmd}
--score-reference"
fi
if [[ -n ${infer_parameters} ]]; then
if [[ -n ${infer_parameter} ]]; then
cmd="${cmd}
${infer_parameters}"
${infer_parameter}"
fi
echo -e "\033[34mRun command: \n${cmd} \033[0m"
......
......@@ -37,7 +37,7 @@ from tqdm import tqdm
logger = logging.getLogger(__name__)
MANIFEST_COLUMNS = ["id", "audio", "n_frames", "tgt_text"]
MANIFEST_COLUMNS = ["id", "audio", "n_frames", "tgt_text", "tgt_lang"]
class AudioDataset(Dataset):
......@@ -398,6 +398,7 @@ def process(args):
if args.add_src and src_utt is not None:
manifest["src_text"].append(src_utt)
manifest["tgt_text"].append(tgt_utt)
manifest["tgt_lang"].append(tgt_lang)
if is_train_split:
if args.task == "st" and args.add_src and args.share:
......@@ -454,15 +455,17 @@ def process(args):
# if task == "st" and args.add_src and args.share:
if args.add_src and args.share:
for e in reader:
src_utt = dict(e)["src_text"]
if "src_text" in dict(e):
src_utt = dict(e)["src_text"]
if args.lowercase_src:
src_utt = src_utt.lower()
if args.rm_punc_src:
for w in punctuation_str:
src_utt = src_utt.replace(w, "")
src_utt = " ".join(src_utt.split(" "))
train_text.append(src_utt)
tgt_utt = dict(e)["tgt_text"]
if args.lowercase_src:
src_utt = src_utt.lower()
if args.rm_punc_src:
for w in punctuation_str:
src_utt = src_utt.replace(w, "")
src_utt = " ".join(src_utt.split(" "))
train_text.append(src_utt)
train_text.append(tgt_utt)
else:
tgt_text = [(dict(e))["tgt_text"] for e in reader]
......@@ -471,11 +474,16 @@ def process(args):
with NamedTemporaryFile(mode="w") as f:
for t in train_text:
f.write(t + "\n")
special_symbols = None
if args.add_syms:
special_symbols = [f'<lang:{lang}>' for lang in args.tgt_langs.split(",")]
gen_vocab(
Path(f.name),
output_root / spm_filename_prefix,
args.vocab_type,
args.vocab_size,
special_symbols=special_symbols
)
# Generate config YAML
......@@ -491,18 +499,107 @@ def process(args):
cmvn_type=args.cmvn_type,
gcmvn_path=(output_root / "gcmvn.npz" if args.cmvn_type == "global" else None),
asr_spm_filename=asr_spm_filename,
share_src_and_tgt=True if task == "asr" else False,
share_src_and_tgt=True if task == "asr" and not args.add_src else False,
prepend_tgt_lang_tag=(args.add_syms),
)
def process_joint(args):
cur_root = Path(args.data_root).absolute()
task = args.task
languages = args.languages.split(",")
assert all((cur_root / f"{lang}").is_dir() for lang in languages), \
"do not have downloaded data available for all languages"
if args.output_root is None:
output_root = cur_root
else:
output_root = Path(args.output_root).absolute()
# Generate vocab
v_size_str = "" if args.vocab_type == "char" else str(args.vocab_size)
spm_filename_prefix = f"spm_{args.vocab_type}{v_size_str}_{args.task}"
asr_spm_filename = None
if args.add_src:
if args.share:
if args.st_spm_prefix is not None:
spm_filename_prefix = args.st_spm_prefix
else:
spm_filename_prefix = f"spm_{args.vocab_type}{v_size_str}_{task}_share"
asr_spm_filename = spm_filename_prefix + ".model"
else:
if args.st_spm_prefix is not None:
spm_filename_prefix = args.st_spm_prefix
assert args.asr_prefix is not None
asr_spm_filename = args.asr_prefix + ".model"
elif task == "asr":
if args.asr_prefix is not None:
spm_filename_prefix = args.asr_prefix
punctuation_str = string.punctuation
punctuation_str = punctuation_str.replace("'", "")
with NamedTemporaryFile(mode="w") as f:
for lang in languages:
tsv_path = cur_root / f"{lang}" / f"{args.task}" / f"train.tsv"
df = load_df_from_tsv(tsv_path)
for t in df["tgt_text"]:
f.write(t + "\n")
if args.add_src:
for src_utt in df["src_text"]:
if args.lowercase_src:
src_utt = src_utt.lower()
if args.rm_punc_src:
for w in punctuation_str:
src_utt = src_utt.replace(w, "")
src_utt = " ".join(src_utt.split(" "))
f.write(src_utt + "\n")
special_symbols = None
if args.task == 'st':
special_symbols = [f'<lang:{lang.split("-")[1]}>' for lang in languages]
gen_vocab(
Path(f.name),
output_root / spm_filename_prefix,
args.vocab_type,
args.vocab_size,
special_symbols=special_symbols
)
# Generate config YAML
yaml_filename = f"config.yaml"
if task == "st" and args.add_src and args.share:
yaml_filename = f"config_share.yaml"
gen_config_yaml(
output_root,
spm_filename_prefix + ".model",
yaml_filename=yaml_filename,
specaugment_policy="ld2",
asr_spm_filename=asr_spm_filename,
share_src_and_tgt=True if task == "asr" else False,
prepend_tgt_lang_tag=(args.task == "st"),
)
# Make symbolic links to manifests
for lang in languages:
for split in args.splits.split(","):
src_path = cur_root / f"{lang}" / f"{task}" / f"{split}.tsv"
desc_path = output_root / f"{split}_{lang}.tsv"
if not desc_path.is_symlink():
shutil.copy(src_path, desc_path)
def main():
parser = argparse.ArgumentParser()
# general setting
parser.add_argument("--data-root", "-d", required=True, type=str)
parser.add_argument("--output-root", "-o", default=None, type=str)
parser.add_argument("--task", type=str, default="st", choices=["asr", "st"])
parser.add_argument("--src-lang", type=str, required=True, help="source language")
parser.add_argument("--joint", action="store_true", help="")
parser.add_argument("--add-syms", action="store_true", help="")
parser.add_argument("--src-lang", type=str, help="source language")
parser.add_argument("--tgt-lang", type=str, help="target language")
parser.add_argument("--tgt-langs", type=str, help="target languages for multilingual training")
parser.add_argument("--languages", type=str, help="languages for multilingual training")
parser.add_argument(
"--splits", type=str, default="train,dev,test", help="dataset splits"
)
......@@ -569,7 +666,10 @@ def main():
args = parser.parse_args()
process(args)
if args.joint:
process_joint(args)
else:
process(args)
if __name__ == "__main__":
......
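Note on the vocabulary changes above: the SentencePiece model is now built with per-target-language special symbols so that a language tag can be prepended to the target text during multilingual training. A small standalone illustration of how the tags are constructed, following the expressions in process (from --tgt-langs) and process_joint (from language-pair names such as en-de); the example values are hypothetical:

# Tags built in process() from --tgt-langs when --add-syms is set
tgt_langs = "de,fr"
tags = [f"<lang:{lang}>" for lang in tgt_langs.split(",")]
# ['<lang:de>', '<lang:fr>']

# Tags built in process_joint() from language-pair directory names for the st task
languages = ["en-de", "en-fr"]
tags = [f"<lang:{lang.split('-')[1]}>" for lang in languages]
# ['<lang:de>', '<lang:fr>']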
......@@ -25,7 +25,7 @@ class LabelSmoothedCrossEntropyCriterionConfig(FairseqDataclass):
default=0.0,
metadata={"help": "the weight for consistency regularization of mixup"},
)
cal_mixup_loss: bool = field(
mixup_no_hard_loss: bool = field(
default=False,
metadata={"help": "calculate the loss for the mixed samples"},
)
......@@ -71,7 +71,7 @@ class LabelSmoothedCrossEntropyCriterion(FairseqCriterion):
label_smoothing,
ignore_prefix_size=0,
report_accuracy=False,
cal_mixup_loss=True,
mixup_no_hard_loss=False,
mixup_consistent_weight=0.0,
):
super().__init__(task)
......@@ -79,7 +79,7 @@ class LabelSmoothedCrossEntropyCriterion(FairseqCriterion):
self.eps = float(label_smoothing)
self.ignore_prefix_size = ignore_prefix_size
self.report_accuracy = report_accuracy
self.cal_mixup_loss = cal_mixup_loss
self.mixup_no_hard_loss = mixup_no_hard_loss
self.mixup_consistent_weight = mixup_consistent_weight
def forward(self, model, sample, reduce=True):
......@@ -173,7 +173,7 @@ class LabelSmoothedCrossEntropyCriterion(FairseqCriterion):
mixup_coef = net_output[1]["mixup"]["coef"][mixup_flag]
loss_coef = [mixup_coef, 1 - mixup_coef]
if self.cal_mixup_loss:
if not self.mixup_no_hard_loss:
for item_lprobs, item_target, item_coef in zip(mixup_lprobs, mixup_targets, loss_coef):
batch_size = item_target.size(0)
item_loss, item_nll_loss = label_smoothed_nll_loss(
......
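The renamed criterion flag inverts its polarity while keeping the default behaviour: with the old cal_mixup_loss=True default and the new mixup_no_hard_loss=False default, the hard loss on mixed samples is still computed unless the new flag is explicitly enabled. A minimal standalone sketch of the equivalence (not the criterion code itself):

def old_guard(cal_mixup_loss=True):
    # old condition: if self.cal_mixup_loss
    return cal_mixup_loss

def new_guard(mixup_no_hard_loss=False):
    # new condition: if not self.mixup_no_hard_loss
    return not mixup_no_hard_loss

assert old_guard() == new_guard()  # the defaults select the same branch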
......@@ -30,19 +30,19 @@ class LabelSmoothedCrossEntropyCriterionWithCTC(
cfg: CtcCriterionConfig,
ctc_weight=0.0,
save_dir=None,
cal_mixup_loss=True,
mixup_no_hard_loss=False,
mixup_consistent_weight=0.0,
only_train_enc_prob=0.0,
get_oracle_when_only_train_enc=False
):
super().__init__(task, sentence_avg, label_smoothing,
report_accuracy=True,
cal_mixup_loss=cal_mixup_loss,
mixup_no_hard_loss=mixup_no_hard_loss,
mixup_consistent_weight=mixup_consistent_weight)
self.report_accuracy = True
self.ctc_weight = ctc_weight
self.ctc_criterion = CtcCriterion(cfg, task, ctc_weight, save_dir)
self.ctc_criterion = CtcCriterion(cfg, task, ctc_weight, save_dir, mixup_no_hard_loss)
self.save_dir = save_dir
self.only_train_enc_prob = only_train_enc_prob
......
......@@ -358,7 +358,7 @@ class SpeechToTextDataset(FairseqDataset):
def check_tgt_lang_tag(self):
if self.data_cfg.prepend_tgt_lang_tag:
assert self.tgt_langs is not None and self.tgt_dict is not None
tgt_lang_tags = [
tgt_lang_tags = [
self.LANG_TAG_TEMPLATE.format(t) for t in set(self.tgt_langs)
]
assert all(t in self.tgt_dict for t in tgt_lang_tags)
......
......@@ -198,6 +198,12 @@ class CommonConfig(FairseqDataclass):
"help": "training steps in each epoch"
}
)
sharded_data_load: bool = field(
default=False,
metadata={
"help": "Use sharded data for efficient data load"
},
)
@dataclass
......@@ -812,6 +818,14 @@ class GenerationConfig(FairseqDataclass):
default=0.0,
metadata={"help": "weight for ctc probs for lm fusion"},
)
early_exit_count: int = field(
default=0,
metadata={"help": "early exit during decoding when n consecutive predictions are the same"},
)
early_exit_layer: int = field(
default=0,
metadata={"help": "early exit during decoding at layer n"},
)
# arguments for iterative refinement generator
iter_decode_eos_penalty: float = field(
......
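The two new generation fields only declare the options; they are consumed by the encoder and generation changes elsewhere in this commit. As a hedged sketch of the early_exit_count idea described by its help string (illustrative only, assuming per-layer argmax predictions are collected as plain Python lists):

def should_early_exit(layer_predictions, early_exit_count):
    # Exit once the prediction has been identical for early_exit_count
    # consecutive layers; 0 (the default) disables the check.
    if early_exit_count <= 0 or len(layer_predictions) < early_exit_count:
        return False
    recent = layer_predictions[-early_exit_count:]
    return all(p == recent[0] for p in recent)

This corresponds to the --early-exit-count 6 setting used in the updated run script at the top of the commit.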
......@@ -1019,12 +1019,13 @@ class PDSS2TTransformerEncoder(FairseqEncoder):
self, ctc_infer, post_process, src_dict=None, tgt_dict=None, path=None
):
if hasattr(self, "ctc"):
import os
assert src_dict is not None
self.ctc.set_infer(
ctc_infer,
post_process,
src_dict,
path=path + ".ctc" if path is not None else None,
path=os.path.splitext(path)[0] + ".ctc" if path is not None else None,
)
def ctc_valid(self, lprobs, targets, input_lengths, dictionary, lang="source"):
......
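The CTC output path now replaces the extension of the translation file instead of appending to it. A quick illustration with a hypothetical path:

import os

path = "results/generate-test.txt"
old = path + ".ctc"                        # results/generate-test.txt.ctc
new = os.path.splitext(path)[0] + ".ctc"   # results/generate-test.ctc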
......@@ -250,7 +250,6 @@ class CTCDecoder(object):
logger.info("GMACs: %f. GFLOPs: %f" % (gmacs, gmacs * 2))
print("GMACs: %f. GFLOPs: %f" % (gmacs, gmacs * 2))
from torchprofile import profile_macs
macs = profile_macs(self.model, [src_tokens, src_lengths])
gmacs = macs / 1e9
......@@ -269,20 +268,22 @@ class CTCDecoder(object):
inter_logits = encoder_outs.get("inter_xctc_logits", [])
if ctc_logit is None:
ctc_logit = encoder_outs["ctc_logit"][0].transpose(0, 1)
if len(inter_logits) > 0:
if len(inter_logits) == 0:
inter_logits = encoder_outs.get("inter_ctc_logits", [])
inter_logits_num = len(inter_logits)
encoder_padding_mask = encoder_outs["encoder_padding_mask"][0]
if self.ctc_inter_logit != 0:
assert inter_logits_num >= self.ctc_inter_logit
if inter_logits_num != 0:
assert self.ctc_inter_logit <= inter_logits_num
ctc_logit_item = inter_logits[-self.ctc_inter_logit]
if isinstance(ctc_logit_item, list):
ctc_logit = ctc_logit_item[0].transpose(0, 1)
if len(ctc_logit_item) >= 2:
encoder_padding_mask = ctc_logit_item[1]
else:
ctc_logit = ctc_logit_item.transpose(0, 1)
logit_length = (~encoder_padding_mask).long().sum(-1)
finalized = []
......@@ -318,7 +319,7 @@ class CTCDecoder(object):
else:
logit = inter_logits[i]
inter_logits_prob = utils.log_softmax(logits.transpose(0, 1), -1)
inter_logits_prob = utils.log_softmax(logit.transpose(0, 1), -1)
ctc_probs += inter_logits_prob
topk_prob, topk_index = ctc_probs.topk(1, dim=2)
......
......@@ -74,6 +74,19 @@ class CTC(nn.Module):
def argmax(self, x):
return torch.argmax(self.ctc_projection(x), dim=-1)
def predict(self, logits, padding):
input_lengths = (~padding).sum(-1)
logits = logits.transpose(0, 1).float().contiguous()
predicts = []
for logit, inp_l in zip(logits, input_lengths):
toks = logit[:inp_l].argmax(dim=-1).unique_consecutive()
pred_units_arr = toks[toks != self.dictionary.bos()]
# pred_units_arr = logit[:inp_l].argmax(dim=-1)
predicts.append(pred_units_arr)
return predicts
def infer(self, logits_or_probs, lengths, tag=None):
for lp, inp_l in zip(
logits_or_probs,
......
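The new CTC.predict method performs a greedy CTC collapse per utterance: frame-level argmax, merging of consecutive repeats, then removal of the blank token (the diff filters out self.dictionary.bos(), used here as the blank index). A self-contained sketch of the same collapse on a toy sequence:

import torch

frame_argmax = torch.tensor([0, 0, 5, 5, 0, 7, 7, 7, 0])  # per-frame argmax ids
blank_id = 0                                               # assumed blank index
toks = frame_argmax.unique_consecutive()                   # tensor([0, 5, 0, 7, 0])
pred = toks[toks != blank_id]                              # tensor([5, 7])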
......@@ -130,6 +130,9 @@ class FairseqTask(object):
"""
return cls(cfg, **kwargs)
def sharded_data_load(self):
return getattr(self.cfg, "sharded_data_load", False)
def has_sharded_data(self, split):
return os.pathsep in getattr(self.cfg, "data", "")
......@@ -619,6 +622,9 @@ class LegacyFairseqTask(FairseqTask):
"""
return cls(args, **kwargs)
def sharded_data_load(self):
return getattr(self.args, "sharded_data_load", False)
def has_sharded_data(self, split):
return os.pathsep in getattr(self.args, "data", "")
......
......@@ -521,16 +521,25 @@ class Trainer(object):
disable_iterator_cache=False,
):
"""Return an EpochBatchIterator over the training set for a given epoch."""
if self.task.sharded_data_load():
datasets = self.cfg.dataset.train_subset.split(",")
curr_dataset = datasets[(epoch - 1) % len(datasets)]
logger.info("sharded loading the training subset {}".format(curr_dataset))
else:
curr_dataset = self.cfg.dataset.train_subset
load_dataset = load_dataset or self.task.sharded_data_load()
disable_iterator_cache = disable_iterator_cache or self.task.sharded_data_load()
if load_dataset:
logger.info("loading train data for epoch {}".format(epoch))
self.task.load_dataset(
self.cfg.dataset.train_subset,
curr_dataset,
epoch=epoch,
combine=combine,
data_selector=data_selector,
)
batch_iterator = self.task.get_batch_iterator(
dataset=self.task.dataset(self.cfg.dataset.train_subset),
dataset=self.task.dataset(curr_dataset),
max_tokens=self.cfg.dataset.max_tokens,
max_sentences=self.cfg.dataset.batch_size,
max_positions=utils.resolve_max_positions(
......
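When sharded_data_load is enabled, the training subsets listed in train_subset are rotated round-robin over epochs, and the dataset is reloaded (with the iterator cache disabled) each epoch. A minimal sketch of the shard selection, assuming a comma-separated subset list such as "train_shard0,train_shard1":

def pick_shard(train_subset: str, epoch: int) -> str:
    shards = train_subset.split(",")
    return shards[(epoch - 1) % len(shards)]

assert pick_shard("s0,s1,s2", 1) == "s0"
assert pick_shard("s0,s1,s2", 4) == "s0"
assert pick_shard("s0,s1,s2", 5) == "s1"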
......@@ -754,3 +754,18 @@ def freeze_parameters(module, freeze_module_name):
freeze_module_name = freeze_module_name.split(",")
for name in freeze_module_name:
freeze_module_params_by_name(module, name)
def distribution_soft_to_hard(logit_or_prob):
argmax_prob = torch.argmax(logit_or_prob, dim=-1, keepdim=True)
hard_distribution = (
(
argmax_prob
== torch.arange(logit_or_prob.size(-1), device=logit_or_prob.device).unsqueeze(
0
)
)
.to(logit_or_prob.dtype)
)
return hard_distribution
\ No newline at end of file
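Usage sketch for the new distribution_soft_to_hard helper, which one-hot-encodes the argmax of a soft distribution over the last dimension; the helper is reproduced here so the snippet is self-contained:

import torch

def distribution_soft_to_hard(logit_or_prob):
    # same logic as the helper added above, copied for illustration
    argmax_prob = torch.argmax(logit_or_prob, dim=-1, keepdim=True)
    index = torch.arange(logit_or_prob.size(-1), device=logit_or_prob.device).unsqueeze(0)
    return (argmax_prob == index).to(logit_or_prob.dtype)

probs = torch.tensor([[0.1, 0.7, 0.2],
                      [0.6, 0.3, 0.1]])
print(distribution_soft_to_hard(probs))
# tensor([[0., 1., 0.],
#         [1., 0., 0.]])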
......@@ -108,8 +108,8 @@ def _main(cfg: DictConfig, output_file, translation_path=None):
for model in models:
if hasattr(model, "encoder") and hasattr(model.encoder, "set_ctc_infer"):
model.encoder.set_ctc_infer(cfg.generation.ctc_infer, "sentencepiece",
src_dict, tgt_dict, translation_path)
if hasattr(model, "encoder") and hasattr(model.encoder, "set_flag"):
src_dict, tgt_dict, translation_path, cfg.generation.early_exit_count)
if hasattr(model, "encoder") and hasattr(model.encoder, "set_flag"):
model.encoder.set_flag(
cal_localness=cfg.generation.cal_localness,
localness_window=cfg.generation.localness_window,
......@@ -120,6 +120,7 @@ def _main(cfg: DictConfig, output_file, translation_path=None):
mixup_infer=cfg.generation.mixup_infer,
gather_cos_sim=cfg.generation.gather_cos_sim,
gather_cos_sim_dis=cfg.generation.gather_cos_sim_dis,
early_exit_layer=cfg.generation.early_exit_layer,
)
if hasattr(model, "decoder") and hasattr(model.decoder, "set_flag"):
model.decoder.set_flag(
......@@ -246,9 +247,15 @@ def _main(cfg: DictConfig, output_file, translation_path=None):
# Remove padding
if "src_tokens" in sample["net_input"]:
src_tokens = utils.strip_pad(
sample["net_input"]["src_tokens"][i, :], tgt_dict.pad()
)
if sample["net_input"]["src_tokens"].dtype in [torch.int32, torch.int64]:
src_tokens = utils.strip_pad(
sample["net_input"]["src_tokens"][i, :], src_dict.pad()
)
elif "transcript" in sample:
src_tokens = utils.strip_pad(
sample["transcript"]["tokens"][i, :], src_dict.pad()
)
else:
src_tokens = None
......
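The source-token handling during generation now distinguishes integer token inputs (text) from float speech features: padding is stripped from src_tokens with the source dictionary's pad index only when they are integer ids, and otherwise the aligned transcript tokens are used when present. A toy check of the dtype guard (tensors are hypothetical):

import torch

speech_features = torch.randn(1, 100, 80)        # float features -> fall back to transcript
text_tokens = torch.tensor([[12, 57, 3, 1, 1]])  # integer ids    -> strip pad directly
print(speech_features.dtype in [torch.int32, torch.int64])  # False
print(text_tokens.dtype in [torch.int32, torch.int64])      # True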