Commit cb2f2bcb by xuchen

2023.11

parent 51395037
@@ -83,10 +83,12 @@ epoch_ensemble=0
 best_ensemble=1
 infer_debug=0
 infer_score=0
+infer_tag=
+infer_parameter=
 infer_tag=ee6
-infer_parameters="--early-exit-count 6"
-#infer_parameters="--early-exit-layer 12"
-#infer_parameters="--cal-monotonic-cross-attn-weights --cal-localness --localness-window 0.1 --cal-topk-cross-attn-weights --topk-cross-attn-weights 15 --cal-entropy"
+infer_parameter="--early-exit-count 6"
+#infer_parameter="--early-exit-layer 12"
+#infer_parameter="--cal-monotonic-cross-attn-weights --cal-localness --localness-window 0.1 --cal-topk-cross-attn-weights --topk-cross-attn-weights 15 --cal-entropy"
 data_config=config.yaml
@@ -416,9 +418,9 @@ if [ ${stage} -le 2 ] && [ ${stop_stage} -ge 2 ]; then
 cmd="${cmd}
 --score-reference"
 fi
-if [[ -n ${infer_parameters} ]]; then
+if [[ -n ${infer_parameter} ]]; then
 cmd="${cmd}
-${infer_parameters}"
+${infer_parameter}"
 fi
 echo -e "\033[34mRun command: \n${cmd} \033[0m"
......
@@ -37,7 +37,7 @@ from tqdm import tqdm
 logger = logging.getLogger(__name__)
-MANIFEST_COLUMNS = ["id", "audio", "n_frames", "tgt_text"]
+MANIFEST_COLUMNS = ["id", "audio", "n_frames", "tgt_text", "tgt_lang"]
 class AudioDataset(Dataset):
@@ -398,6 +398,7 @@ def process(args):
 if args.add_src and src_utt is not None:
 manifest["src_text"].append(src_utt)
 manifest["tgt_text"].append(tgt_utt)
+manifest["tgt_lang"].append(tgt_lang)
 if is_train_split:
 if args.task == "st" and args.add_src and args.share:
@@ -454,8 +455,8 @@ def process(args):
 # if task == "st" and args.add_src and args.share:
 if args.add_src and args.share:
 for e in reader:
+if "src_text" in dict(e):
 src_utt = dict(e)["src_text"]
-tgt_utt = dict(e)["tgt_text"]
 if args.lowercase_src:
 src_utt = src_utt.lower()
 if args.rm_punc_src:
@@ -463,6 +464,8 @@ def process(args):
 src_utt = src_utt.replace(w, "")
 src_utt = " ".join(src_utt.split(" "))
 train_text.append(src_utt)
+tgt_utt = dict(e)["tgt_text"]
 train_text.append(tgt_utt)
 else:
 tgt_text = [(dict(e))["tgt_text"] for e in reader]
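For reference, a minimal standalone sketch of the source-text normalisation applied in this block (lowercasing and punctuation stripping while keeping apostrophes); the helper name and defaults are illustrative, not part of the script:

import string

# punctuation table used by the script: all ASCII punctuation except the apostrophe
PUNCTUATION = string.punctuation.replace("'", "")

def normalize_src(utt, lowercase=True, rm_punc=True):
    if lowercase:
        utt = utt.lower()
    if rm_punc:
        for w in PUNCTUATION:
            utt = utt.replace(w, "")      # drop each punctuation character
    return " ".join(utt.split(" "))       # re-join on single spaces, as the script does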
@@ -471,11 +474,16 @@ def process(args):
 with NamedTemporaryFile(mode="w") as f:
 for t in train_text:
 f.write(t + "\n")
+special_symbols = None
+if args.add_syms:
+special_symbols = [f'<lang:{lang}>' for lang in args.tgt_langs.split(",")]
 gen_vocab(
 Path(f.name),
 output_root / spm_filename_prefix,
 args.vocab_type,
 args.vocab_size,
+special_symbols=special_symbols
 )
 # Generate config YAML
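The special_symbols list makes sure each <lang:xx> tag exists as a single piece in the target vocabulary. A sketch of how such tags are typically registered when training a SentencePiece model; the file names and sizes below are placeholders, and this is not the repository's gen_vocab implementation:

import sentencepiece as spm

tgt_langs = "de,fr,es"   # hypothetical --tgt-langs value
special_symbols = [f"<lang:{lang}>" for lang in tgt_langs.split(",")]

spm.SentencePieceTrainer.train(
    input="train_text.txt",            # pooled training text, one sentence per line
    model_prefix="spm_unigram10000_st",
    vocab_size=10000,
    model_type="unigram",
    user_defined_symbols=special_symbols,   # tags are kept as atomic pieces
)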
@@ -491,9 +499,94 @@ def process(args):
 cmvn_type=args.cmvn_type,
 gcmvn_path=(output_root / "gcmvn.npz" if args.cmvn_type == "global" else None),
 asr_spm_filename=asr_spm_filename,
-share_src_and_tgt=True if task == "asr" else False,
+share_src_and_tgt=True if task == "asr" and not args.add_src else False,
+prepend_tgt_lang_tag=(args.add_syms),
 )
+def process_joint(args):
+cur_root = Path(args.data_root).absolute()
+task = args.task
+languages = args.languages.split(",")
+assert all((cur_root / f"{lang}").is_dir() for lang in languages), \
+"do not have downloaded data available for all languages"
+if args.output_root is None:
+output_root = cur_root
+else:
+output_root = Path(args.output_root).absolute()
+# Generate vocab
+v_size_str = "" if args.vocab_type == "char" else str(args.vocab_size)
+spm_filename_prefix = f"spm_{args.vocab_type}{v_size_str}_{args.task}"
+asr_spm_filename = None
+if args.add_src:
+if args.share:
+if args.st_spm_prefix is not None:
+spm_filename_prefix = args.st_spm_prefix
+else:
+spm_filename_prefix = f"spm_{args.vocab_type}{v_size_str}_{task}_share"
+asr_spm_filename = spm_filename_prefix + ".model"
+else:
+if args.st_spm_prefix is not None:
+spm_filename_prefix = args.st_spm_prefix
+assert args.asr_prefix is not None
+asr_spm_filename = args.asr_prefix + ".model"
+elif task == "asr":
+if args.asr_prefix is not None:
+spm_filename_prefix = args.asr_prefix
+punctuation_str = string.punctuation
+punctuation_str = punctuation_str.replace("'", "")
+with NamedTemporaryFile(mode="w") as f:
+for lang in languages:
+tsv_path = cur_root / f"{lang}" / f"{args.task}" / f"train.tsv"
+df = load_df_from_tsv(tsv_path)
+for t in df["tgt_text"]:
+f.write(t + "\n")
+if args.add_src:
+for src_utt in df["src_text"]:
+if args.lowercase_src:
+src_utt = src_utt.lower()
+if args.rm_punc_src:
+for w in punctuation_str:
+src_utt = src_utt.replace(w, "")
+src_utt = " ".join(src_utt.split(" "))
+f.write(src_utt + "\n")
+special_symbols = None
+if args.task == 'st':
+special_symbols = [f'<lang:{lang.split("-")[1]}>' for lang in languages]
+gen_vocab(
+Path(f.name),
+output_root / spm_filename_prefix,
+args.vocab_type,
+args.vocab_size,
+special_symbols=special_symbols
+)
+# Generate config YAML
+yaml_filename = f"config.yaml"
+if task == "st" and args.add_src and args.share:
+yaml_filename = f"config_share.yaml"
+gen_config_yaml(
+output_root,
+spm_filename_prefix + ".model",
+yaml_filename=yaml_filename,
+specaugment_policy="ld2",
+asr_spm_filename=asr_spm_filename,
+share_src_and_tgt=True if task == "asr" else False,
+prepend_tgt_lang_tag=(args.task == "st"),
+)
+# Make symbolic links to manifests
+for lang in languages:
+for split in args.splits.split(","):
+src_path = cur_root / f"{lang}" / f"{task}" / f"{split}.tsv"
+desc_path = output_root / f"{split}_{lang}.tsv"
+if not desc_path.is_symlink():
+shutil.copy(src_path, desc_path)
 def main():
 parser = argparse.ArgumentParser()
@@ -501,8 +594,12 @@ def main():
 parser.add_argument("--data-root", "-d", required=True, type=str)
 parser.add_argument("--output-root", "-o", default=None, type=str)
 parser.add_argument("--task", type=str, default="st", choices=["asr", "st"])
-parser.add_argument("--src-lang", type=str, required=True, help="source language")
+parser.add_argument("--joint", action="store_true", help="")
+parser.add_argument("--add-syms", action="store_true", help="")
+parser.add_argument("--src-lang", type=str, help="source language")
 parser.add_argument("--tgt-lang", type=str, help="target language")
+parser.add_argument("--tgt-langs", type=str, help="target languages for multilingual training")
+parser.add_argument("--languages", type=str, help="languages for multilingual training")
 parser.add_argument(
 "--splits", type=str, default="train,dev,test", help="dataset splits"
 )
@@ -569,6 +666,9 @@ def main():
 args = parser.parse_args()
+if args.joint:
+process_joint(args)
+else:
 process(args)
......
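With the new tgt_lang manifest column, the <lang:xx> vocabulary symbols, and prepend_tgt_lang_tag enabled in the generated config, the data loader can prepend a language tag to every target sequence. A sketch of that lookup, assuming a fairseq-style dictionary; prepend_lang_tag is an illustrative helper, not code from this commit:

def prepend_lang_tag(target_ids, tgt_lang, tgt_dict):
    lang_tag = f"<lang:{tgt_lang}>"
    lang_idx = tgt_dict.index(lang_tag)      # fairseq Dictionary.index() falls back to unk
    assert lang_idx != tgt_dict.unk(), f"{lang_tag} is missing from the vocabulary"
    return [lang_idx] + list(target_ids)     # the tag becomes the first target token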
@@ -25,7 +25,7 @@ class LabelSmoothedCrossEntropyCriterionConfig(FairseqDataclass):
 default=0.0,
 metadata={"help": "the weight for consistency regularization of mixup"},
 )
-cal_mixup_loss: bool = field(
+mixup_no_hard_loss: bool = field(
 default=False,
 metadata={"help": "calculate the loss for the mixed samples"},
 )
@@ -71,7 +71,7 @@ class LabelSmoothedCrossEntropyCriterion(FairseqCriterion):
 label_smoothing,
 ignore_prefix_size=0,
 report_accuracy=False,
-cal_mixup_loss=True,
+mixup_no_hard_loss=False,
 mixup_consistent_weight=0.0,
 ):
 super().__init__(task)
@@ -79,7 +79,7 @@ class LabelSmoothedCrossEntropyCriterion(FairseqCriterion):
 self.eps = float(label_smoothing)
 self.ignore_prefix_size = ignore_prefix_size
 self.report_accuracy = report_accuracy
-self.cal_mixup_loss = cal_mixup_loss
+self.mixup_no_hard_loss = mixup_no_hard_loss
 self.mixup_consistent_weight = mixup_consistent_weight
 def forward(self, model, sample, reduce=True):
@@ -173,7 +173,7 @@ class LabelSmoothedCrossEntropyCriterion(FairseqCriterion):
 mixup_coef = net_output[1]["mixup"]["coef"][mixup_flag]
 loss_coef = [mixup_coef, 1 - mixup_coef]
-if self.cal_mixup_loss:
+if not self.mixup_no_hard_loss:
 for item_lprobs, item_target, item_coef in zip(mixup_lprobs, mixup_targets, loss_coef):
 batch_size = item_target.size(0)
 item_loss, item_nll_loss = label_smoothed_nll_loss(
......
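The renamed flag gates the "hard" part of the mixup objective in the loop above: when two utterances are mixed with coefficient c, the label-smoothed loss is computed against both original target sequences and weighted by c and 1 - c. A simplified, self-contained sketch of that computation; the real criterion uses fairseq's label_smoothed_nll_loss and per-example coefficients:

import torch
import torch.nn.functional as F

def mixup_hard_loss(lprobs, target_a, target_b, coef, eps=0.1, pad_idx=1):
    # lprobs: (batch, seq, vocab) log-probabilities from the mixed forward pass
    total = 0.0
    for target, c in ((target_a, coef), (target_b, 1.0 - coef)):
        nll = F.nll_loss(lprobs.transpose(1, 2), target,
                         ignore_index=pad_idx, reduction="none").sum(-1)
        smooth = -lprobs.sum(-1).sum(-1) / lprobs.size(-1)   # uniform smoothing term (simplified)
        total = total + (c * ((1.0 - eps) * nll + eps * smooth)).sum()
    return total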
@@ -30,19 +30,19 @@ class LabelSmoothedCrossEntropyCriterionWithCTC(
 cfg: CtcCriterionConfig,
 ctc_weight=0.0,
 save_dir=None,
-cal_mixup_loss=True,
+mixup_no_hard_loss=False,
 mixup_consistent_weight=0.0,
 only_train_enc_prob=0.0,
 get_oracle_when_only_train_enc=False
 ):
 super().__init__(task, sentence_avg, label_smoothing,
 report_accuracy=True,
-cal_mixup_loss=cal_mixup_loss,
+mixup_no_hard_loss=mixup_no_hard_loss,
 mixup_consistent_weight=mixup_consistent_weight)
 self.report_accuracy = True
 self.ctc_weight = ctc_weight
-self.ctc_criterion = CtcCriterion(cfg, task, ctc_weight, save_dir)
+self.ctc_criterion = CtcCriterion(cfg, task, ctc_weight, save_dir, mixup_no_hard_loss)
 self.save_dir = save_dir
 self.only_train_enc_prob = only_train_enc_prob
......
@@ -198,6 +198,12 @@ class CommonConfig(FairseqDataclass):
 "help": "training steps in each epoch"
 }
 )
+sharded_data_load: bool = field(
+default=False,
+metadata={
+"help": "Use sharded data for efficient data load"
+},
+)
 @dataclass
@@ -812,6 +818,14 @@ class GenerationConfig(FairseqDataclass):
 default=0.0,
 metadata={"help": "weight for ctc probs for lm fusion"},
 )
+early_exit_count: int = field(
+default=0,
+metadata={"help": "early exit during decoding when n consecutive predictions are the same"},
+)
+early_exit_layer: int = field(
+default=0,
+metadata={"help": "early exit during decoding at layer n"},
+)
 # arguments for iterative refinement generator
 iter_decode_eos_penalty: float = field(
......
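A sketch of the decoding rule the two new generation options describe: --early-exit-layer stops the encoder unconditionally at layer n, while --early-exit-count stops once the greedy CTC hypothesis has been identical for n consecutive layers. The names layers, x and ctc_greedy stand in for the model's own components:

def encode_with_early_exit(layers, x, ctc_greedy, early_exit_count=0, early_exit_layer=0):
    prev_hyp, same = None, 0
    exit_layer = len(layers)
    for i, layer in enumerate(layers, start=1):
        x = layer(x)
        if early_exit_layer and i >= early_exit_layer:
            exit_layer = i
            break                          # fixed-depth exit
        if early_exit_count:
            hyp = ctc_greedy(x)            # greedy CTC prediction after this layer
            same = same + 1 if hyp == prev_hyp else 1
            prev_hyp = hyp
            if same >= early_exit_count:
                exit_layer = i
                break                      # predictions stable for n layers
    return x, exit_layer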
@@ -1019,12 +1019,13 @@ class PDSS2TTransformerEncoder(FairseqEncoder):
 self, ctc_infer, post_process, src_dict=None, tgt_dict=None, path=None
 ):
 if hasattr(self, "ctc"):
+import os
 assert src_dict is not None
 self.ctc.set_infer(
 ctc_infer,
 post_process,
 src_dict,
-path=path + ".ctc" if path is not None else None,
+path=os.path.splitext(path)[0] + ".ctc" if path is not None else None,
 )
 def ctc_valid(self, lprobs, targets, input_lengths, dictionary, lang="source"):
......
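The os.path.splitext change keeps the CTC transcript file from stacking extensions when the translation path already ends in one (the path below is only an example):

import os

path = "results/generate-test.txt"
print(path + ".ctc")                          # results/generate-test.txt.ctc (old behaviour)
print(os.path.splitext(path)[0] + ".ctc")     # results/generate-test.ctc     (new behaviour)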
@@ -250,7 +250,6 @@ class CTCDecoder(object):
 logger.info("GMACs: %f. GFLOPs: %f" % (gmacs, gmacs * 2))
 print("GMACs: %f. GFLOPs: %f" % (gmacs, gmacs * 2))
 from torchprofile import profile_macs
 macs = profile_macs(self.model, [src_tokens, src_lengths])
 gmacs = macs / 1e9
@@ -269,20 +268,22 @@ class CTCDecoder(object):
 inter_logits = encoder_outs.get("inter_xctc_logits", [])
 if ctc_logit is None:
 ctc_logit = encoder_outs["ctc_logit"][0].transpose(0, 1)
-if len(inter_logits) > 0:
+if len(inter_logits) == 0:
 inter_logits = encoder_outs.get("inter_ctc_logits", [])
 inter_logits_num = len(inter_logits)
 encoder_padding_mask = encoder_outs["encoder_padding_mask"][0]
 if self.ctc_inter_logit != 0:
-assert inter_logits_num >= self.ctc_inter_logit
 if inter_logits_num != 0:
+assert self.ctc_inter_logit <= inter_logits_num
 ctc_logit_item = inter_logits[-self.ctc_inter_logit]
 if isinstance(ctc_logit_item, list):
 ctc_logit = ctc_logit_item[0].transpose(0, 1)
 if len(ctc_logit_item) >= 2:
 encoder_padding_mask = ctc_logit_item[1]
+else:
+ctc_logit = ctc_logit_item.transpose(0, 1)
 logit_length = (~encoder_padding_mask).long().sum(-1)
 finalized = []
@@ -318,7 +319,7 @@ class CTCDecoder(object):
 else:
 logit = inter_logits[i]
-inter_logits_prob = utils.log_softmax(logits.transpose(0, 1), -1)
+inter_logits_prob = utils.log_softmax(logit.transpose(0, 1), -1)
 ctc_probs += inter_logits_prob
 topk_prob, topk_index = ctc_probs.topk(1, dim=2)
......
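The last hunk ensembles the final CTC head with the intermediate CTC heads by summing their log-probabilities before taking the greedy path; the one-character fix makes each loop iteration use that head's own logit instead of the outer logits tensor. A condensed sketch of the intended computation, assuming (batch, time, vocab) logits:

import torch
import torch.nn.functional as F

def ensemble_ctc_greedy(ctc_logit, inter_logits):
    ctc_probs = F.log_softmax(ctc_logit, dim=-1)
    for logit in inter_logits:                        # one entry per intermediate CTC head
        ctc_probs = ctc_probs + F.log_softmax(logit, dim=-1)
    topk_prob, topk_index = ctc_probs.topk(1, dim=2)  # greedy path over the summed scores
    return topk_index.squeeze(2), topk_prob.squeeze(2)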
@@ -74,6 +74,19 @@ class CTC(nn.Module):
 def argmax(self, x):
 return torch.argmax(self.ctc_projection(x), dim=-1)
+def predict(self, logits, padding):
+input_lengths = (~padding).sum(-1)
+logits = logits.transpose(0, 1).float().contiguous()
+predicts = []
+for logit, inp_l in zip(logits, input_lengths):
+toks = logit[:inp_l].argmax(dim=-1).unique_consecutive()
+pred_units_arr = toks[toks != self.dictionary.bos()]
+# pred_units_arr = logit[:inp_l].argmax(dim=-1)
+predicts.append(pred_units_arr)
+return predicts
 def infer(self, logits_or_probs, lengths, tag=None):
 for lp, inp_l in zip(
 logits_or_probs,
......
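The new predict() is a greedy CTC decoder: per-frame argmax, collapse of consecutive repeats, and removal of the blank symbol (this module uses the dictionary's bos index as the blank). A standalone sketch of the same logic with an explicit blank id:

import torch

def ctc_greedy_decode(logits, padding, blank_id=0):
    # logits: (time, batch, vocab); padding: (batch, time), True on padded frames
    input_lengths = (~padding).sum(-1)
    logits = logits.transpose(0, 1).float().contiguous()    # (batch, time, vocab)
    predictions = []
    for logit, length in zip(logits, input_lengths):
        toks = logit[:length].argmax(dim=-1).unique_consecutive()  # collapse repeats
        predictions.append(toks[toks != blank_id])                 # drop blanks
    return predictions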
@@ -130,6 +130,9 @@ class FairseqTask(object):
 """
 return cls(cfg, **kwargs)
+def sharded_data_load(self):
+return getattr(self.cfg, "sharded_data_load", False)
 def has_sharded_data(self, split):
 return os.pathsep in getattr(self.cfg, "data", "")
@@ -619,6 +622,9 @@ class LegacyFairseqTask(FairseqTask):
 """
 return cls(args, **kwargs)
+def sharded_data_load(self):
+return getattr(self.args, "sharded_data_load", False)
 def has_sharded_data(self, split):
 return os.pathsep in getattr(self.args, "data", "")
......
@@ -521,16 +521,25 @@ class Trainer(object):
 disable_iterator_cache=False,
 ):
 """Return an EpochBatchIterator over the training set for a given epoch."""
+if self.task.sharded_data_load():
+datasets = self.cfg.dataset.train_subset.split(",")
+curr_dataset = datasets[(epoch - 1) % len(datasets)]
+logger.info("sharded loading the training subset {}".format(curr_dataset))
+else:
+curr_dataset = self.cfg.dataset.train_subset
+load_dataset = load_dataset or self.task.sharded_data_load()
+disable_iterator_cache = disable_iterator_cache or self.task.sharded_data_load()
 if load_dataset:
 logger.info("loading train data for epoch {}".format(epoch))
 self.task.load_dataset(
-self.cfg.dataset.train_subset,
+curr_dataset,
 epoch=epoch,
 combine=combine,
 data_selector=data_selector,
 )
 batch_iterator = self.task.get_batch_iterator(
-dataset=self.task.dataset(self.cfg.dataset.train_subset),
+dataset=self.task.dataset(curr_dataset),
 max_tokens=self.cfg.dataset.max_tokens,
 max_sentences=self.cfg.dataset.batch_size,
 max_positions=utils.resolve_max_positions(
......
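With sharded_data_load enabled, the subsets listed in train_subset are treated as shards and loaded one per epoch in round-robin order, so each epoch only loads a single shard. The schedule reduces to the following (subset names are hypothetical):

def shard_for_epoch(train_subset, epoch):
    datasets = train_subset.split(",")
    return datasets[(epoch - 1) % len(datasets)]   # 1-based epochs, wrap around

assert shard_for_epoch("train_1,train_2,train_3", 1) == "train_1"
assert shard_for_epoch("train_1,train_2,train_3", 2) == "train_2"
assert shard_for_epoch("train_1,train_2,train_3", 4) == "train_1"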
@@ -754,3 +754,18 @@ def freeze_parameters(module, freeze_module_name):
 freeze_module_name = freeze_module_name.split(",")
 for name in freeze_module_name:
 freeze_module_params_by_name(module, name)
+def distribution_soft_to_hard(logit_or_prob):
+argmax_prob = torch.argmax(logit_or_prob, dim=-1, keepdim=True)
+hard_distribution = (
+(
+argmax_prob
+== torch.arange(logit_or_prob.size(-1), device=logit_or_prob.device).unsqueeze(
+0
+)
+)
+.to(logit_or_prob.dtype)
+)
+return hard_distribution
\ No newline at end of file
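Usage sketch for the new distribution_soft_to_hard helper: it replaces a soft distribution with a one-hot distribution centred on its argmax, keeping shape and dtype (values below are illustrative, and the function above must be in scope):

import torch

probs = torch.tensor([[0.1, 0.7, 0.2],
                      [0.5, 0.3, 0.2]])
hard = distribution_soft_to_hard(probs)
print(hard)   # tensor([[0., 1., 0.],
              #         [1., 0., 0.]])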
@@ -108,7 +108,7 @@ def _main(cfg: DictConfig, output_file, translation_path=None):
 for model in models:
 if hasattr(model, "encoder") and hasattr(model.encoder, "set_ctc_infer"):
 model.encoder.set_ctc_infer(cfg.generation.ctc_infer, "sentencepiece",
-src_dict, tgt_dict, translation_path)
+src_dict, tgt_dict, translation_path, cfg.generation.early_exit_count)
 if hasattr(model, "encoder") and hasattr(model.encoder, "set_flag"):
 model.encoder.set_flag(
 cal_localness=cfg.generation.cal_localness,
@@ -120,6 +120,7 @@ def _main(cfg: DictConfig, output_file, translation_path=None):
 mixup_infer=cfg.generation.mixup_infer,
 gather_cos_sim=cfg.generation.gather_cos_sim,
 gather_cos_sim_dis=cfg.generation.gather_cos_sim_dis,
+early_exit_layer=cfg.generation.early_exit_layer,
 )
 if hasattr(model, "decoder") and hasattr(model.decoder, "set_flag"):
 model.decoder.set_flag(
@@ -246,9 +247,15 @@ def _main(cfg: DictConfig, output_file, translation_path=None):
 # Remove padding
 if "src_tokens" in sample["net_input"]:
+if sample["net_input"]["src_tokens"].dtype in [torch.int32, torch.int64]:
 src_tokens = utils.strip_pad(
-sample["net_input"]["src_tokens"][i, :], tgt_dict.pad()
+sample["net_input"]["src_tokens"][i, :], src_dict.pad()
 )
+elif "transcript" in sample:
+src_tokens = utils.strip_pad(
+sample["transcript"]["tokens"][i, :], src_dict.pad()
+)
 else:
 src_tokens = None
......
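The added dtype check distinguishes text inputs from speech inputs when writing source lines: speech batches carry float feature frames in src_tokens, which cannot be de-tokenised, so the source is recovered from the batch's transcript field instead. A sketch of the guard, where strip_pad stands in for fairseq's utils.strip_pad:

import torch

def source_tokens_for_sample(sample, i, src_dict, strip_pad):
    src = sample["net_input"].get("src_tokens")
    if src is not None and src.dtype in (torch.int32, torch.int64):
        return strip_pad(src[i, :], src_dict.pad())              # token-id (text) input
    if "transcript" in sample:
        return strip_pad(sample["transcript"]["tokens"][i, :], src_dict.pad())
    return None                                                   # no readable source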