Commit cb2f2bcb by xuchen

2023.11

parent 51395037
...@@ -83,10 +83,12 @@ epoch_ensemble=0 ...@@ -83,10 +83,12 @@ epoch_ensemble=0
best_ensemble=1 best_ensemble=1
infer_debug=0 infer_debug=0
infer_score=0 infer_score=0
infer_tag=
infer_parameter=
infer_tag=ee6 infer_tag=ee6
infer_parameters="--early-exit-count 6" infer_parameter="--early-exit-count 6"
#infer_parameters="--early-exit-layer 12" #infer_parameter="--early-exit-layer 12"
#infer_parameters="--cal-monotonic-cross-attn-weights --cal-localness --localness-window 0.1 --cal-topk-cross-attn-weights --topk-cross-attn-weights 15 --cal-entropy" #infer_parameter="--cal-monotonic-cross-attn-weights --cal-localness --localness-window 0.1 --cal-topk-cross-attn-weights --topk-cross-attn-weights 15 --cal-entropy"
data_config=config.yaml data_config=config.yaml
...@@ -416,9 +418,9 @@ if [ ${stage} -le 2 ] && [ ${stop_stage} -ge 2 ]; then ...@@ -416,9 +418,9 @@ if [ ${stage} -le 2 ] && [ ${stop_stage} -ge 2 ]; then
cmd="${cmd} cmd="${cmd}
--score-reference" --score-reference"
fi fi
if [[ -n ${infer_parameters} ]]; then if [[ -n ${infer_parameter} ]]; then
cmd="${cmd} cmd="${cmd}
${infer_parameters}" ${infer_parameter}"
fi fi
echo -e "\033[34mRun command: \n${cmd} \033[0m" echo -e "\033[34mRun command: \n${cmd} \033[0m"
......
...@@ -37,7 +37,7 @@ from tqdm import tqdm ...@@ -37,7 +37,7 @@ from tqdm import tqdm
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
MANIFEST_COLUMNS = ["id", "audio", "n_frames", "tgt_text"] MANIFEST_COLUMNS = ["id", "audio", "n_frames", "tgt_text", "tgt_lang"]
class AudioDataset(Dataset): class AudioDataset(Dataset):
...@@ -398,6 +398,7 @@ def process(args): ...@@ -398,6 +398,7 @@ def process(args):
if args.add_src and src_utt is not None: if args.add_src and src_utt is not None:
manifest["src_text"].append(src_utt) manifest["src_text"].append(src_utt)
manifest["tgt_text"].append(tgt_utt) manifest["tgt_text"].append(tgt_utt)
manifest["tgt_lang"].append(tgt_lang)
if is_train_split: if is_train_split:
if args.task == "st" and args.add_src and args.share: if args.task == "st" and args.add_src and args.share:
...@@ -454,8 +455,8 @@ def process(args): ...@@ -454,8 +455,8 @@ def process(args):
# if task == "st" and args.add_src and args.share: # if task == "st" and args.add_src and args.share:
if args.add_src and args.share: if args.add_src and args.share:
for e in reader: for e in reader:
if "src_text" in dict(e):
src_utt = dict(e)["src_text"] src_utt = dict(e)["src_text"]
tgt_utt = dict(e)["tgt_text"]
if args.lowercase_src: if args.lowercase_src:
src_utt = src_utt.lower() src_utt = src_utt.lower()
if args.rm_punc_src: if args.rm_punc_src:
...@@ -463,6 +464,8 @@ def process(args): ...@@ -463,6 +464,8 @@ def process(args):
src_utt = src_utt.replace(w, "") src_utt = src_utt.replace(w, "")
src_utt = " ".join(src_utt.split(" ")) src_utt = " ".join(src_utt.split(" "))
train_text.append(src_utt) train_text.append(src_utt)
tgt_utt = dict(e)["tgt_text"]
train_text.append(tgt_utt) train_text.append(tgt_utt)
else: else:
tgt_text = [(dict(e))["tgt_text"] for e in reader] tgt_text = [(dict(e))["tgt_text"] for e in reader]
...@@ -471,11 +474,16 @@ def process(args): ...@@ -471,11 +474,16 @@ def process(args):
with NamedTemporaryFile(mode="w") as f: with NamedTemporaryFile(mode="w") as f:
for t in train_text: for t in train_text:
f.write(t + "\n") f.write(t + "\n")
special_symbols = None
if args.add_syms:
special_symbols = [f'<lang:{lang}>' for lang in args.tgt_langs.split(",")]
gen_vocab( gen_vocab(
Path(f.name), Path(f.name),
output_root / spm_filename_prefix, output_root / spm_filename_prefix,
args.vocab_type, args.vocab_type,
args.vocab_size, args.vocab_size,
special_symbols=special_symbols
) )
# Generate config YAML # Generate config YAML
...@@ -491,9 +499,94 @@ def process(args): ...@@ -491,9 +499,94 @@ def process(args):
cmvn_type=args.cmvn_type, cmvn_type=args.cmvn_type,
gcmvn_path=(output_root / "gcmvn.npz" if args.cmvn_type == "global" else None), gcmvn_path=(output_root / "gcmvn.npz" if args.cmvn_type == "global" else None),
asr_spm_filename=asr_spm_filename, asr_spm_filename=asr_spm_filename,
share_src_and_tgt=True if task == "asr" else False, share_src_and_tgt=True if task == "asr" and not args.add_src else False,
prepend_tgt_lang_tag=(args.add_syms),
)
def process_joint(args):
cur_root = Path(args.data_root).absolute()
task = args.task
languages = args.languages.split(",")
assert all((cur_root / f"{lang}").is_dir() for lang in languages), \
"do not have downloaded data available for all languages"
if args.output_root is None:
output_root = cur_root
else:
output_root = Path(args.output_root).absolute()
# Generate vocab
v_size_str = "" if args.vocab_type == "char" else str(args.vocab_size)
spm_filename_prefix = f"spm_{args.vocab_type}{v_size_str}_{args.task}"
asr_spm_filename = None
if args.add_src:
if args.share:
if args.st_spm_prefix is not None:
spm_filename_prefix = args.st_spm_prefix
else:
spm_filename_prefix = f"spm_{args.vocab_type}{v_size_str}_{task}_share"
asr_spm_filename = spm_filename_prefix + ".model"
else:
if args.st_spm_prefix is not None:
spm_filename_prefix = args.st_spm_prefix
assert args.asr_prefix is not None
asr_spm_filename = args.asr_prefix + ".model"
elif task == "asr":
if args.asr_prefix is not None:
spm_filename_prefix = args.asr_prefix
punctuation_str = string.punctuation
punctuation_str = punctuation_str.replace("'", "")
with NamedTemporaryFile(mode="w") as f:
for lang in languages:
tsv_path = cur_root / f"{lang}" / f"{args.task}" / f"train.tsv"
df = load_df_from_tsv(tsv_path)
for t in df["tgt_text"]:
f.write(t + "\n")
if args.add_src:
for src_utt in df["src_text"]:
if args.lowercase_src:
src_utt = src_utt.lower()
if args.rm_punc_src:
for w in punctuation_str:
src_utt = src_utt.replace(w, "")
src_utt = " ".join(src_utt.split(" "))
f.write(src_utt + "\n")
special_symbols = None
if args.task == 'st':
special_symbols = [f'<lang:{lang.split("-")[1]}>' for lang in languages]
gen_vocab(
Path(f.name),
output_root / spm_filename_prefix,
args.vocab_type,
args.vocab_size,
special_symbols=special_symbols
) )
# Generate config YAML
yaml_filename = f"config.yaml"
if task == "st" and args.add_src and args.share:
yaml_filename = f"config_share.yaml"
gen_config_yaml(
output_root,
spm_filename_prefix + ".model",
yaml_filename=yaml_filename,
specaugment_policy="ld2",
asr_spm_filename=asr_spm_filename,
share_src_and_tgt=True if task == "asr" else False,
prepend_tgt_lang_tag=(args.task == "st"),
)
# Copy per-language manifests into the shared output root
for lang in languages:
for split in args.splits.split(","):
src_path = cur_root / f"{lang}" / f"{task}" / f"{split}.tsv"
desc_path = output_root / f"{split}_{lang}.tsv"
if not desc_path.is_symlink():
shutil.copy(src_path, desc_path)
def main(): def main():
parser = argparse.ArgumentParser() parser = argparse.ArgumentParser()
...@@ -501,8 +594,12 @@ def main(): ...@@ -501,8 +594,12 @@ def main():
parser.add_argument("--data-root", "-d", required=True, type=str) parser.add_argument("--data-root", "-d", required=True, type=str)
parser.add_argument("--output-root", "-o", default=None, type=str) parser.add_argument("--output-root", "-o", default=None, type=str)
parser.add_argument("--task", type=str, default="st", choices=["asr", "st"]) parser.add_argument("--task", type=str, default="st", choices=["asr", "st"])
parser.add_argument("--src-lang", type=str, required=True, help="source language") parser.add_argument("--joint", action="store_true", help="")
parser.add_argument("--add-syms", action="store_true", help="")
parser.add_argument("--src-lang", type=str, help="source language")
parser.add_argument("--tgt-lang", type=str, help="target language") parser.add_argument("--tgt-lang", type=str, help="target language")
parser.add_argument("--tgt-langs", type=str, help="target languages for multilingual training")
parser.add_argument("--languages", type=str, help="languages for multilingual training")
parser.add_argument( parser.add_argument(
"--splits", type=str, default="train,dev,test", help="dataset splits" "--splits", type=str, default="train,dev,test", help="dataset splits"
) )
...@@ -569,6 +666,9 @@ def main(): ...@@ -569,6 +666,9 @@ def main():
args = parser.parse_args() args = parser.parse_args()
if args.joint:
process_joint(args)
else:
process(args) process(args)
......
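For reference, the --joint / --add-syms path above adds a tgt_lang column to the manifest, pools every language pair's train.tsv text into one SentencePiece training corpus, and registers <lang:xx> tags as special symbols so the target-language tag can be prepended during training. A minimal sketch of that flow, assuming the directory layout and column names shown in the diff (the helper names are illustrative, not the repository's code):

```python
import csv
from pathlib import Path

def lang_tag_symbols(languages):
    """Language pairs are written like "en-de"; the tag is built from the target side."""
    return [f'<lang:{pair.split("-")[1]}>' for pair in languages]

def collect_joint_text(data_root, languages, task="st", add_src=False, lowercase_src=False):
    """Yield the sentences used to train the shared vocabulary, per the --joint path."""
    for pair in languages:
        tsv_path = Path(data_root) / pair / task / "train.tsv"
        with open(tsv_path, newline="", encoding="utf-8") as f:
            for row in csv.DictReader(f, delimiter="\t"):
                yield row["tgt_text"]
                if add_src and "src_text" in row:
                    yield row["src_text"].lower() if lowercase_src else row["src_text"]

# e.g. lang_tag_symbols(["en-de", "en-fr"]) == ["<lang:de>", "<lang:fr>"]
```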
...@@ -125,6 +125,18 @@ class CtcCriterionConfig(FairseqDataclass): ...@@ -125,6 +125,18 @@ class CtcCriterionConfig(FairseqDataclass):
default=0, default=0,
metadata={"help": "consistent regularization for inter CTC loss in mixup"}, metadata={"help": "consistent regularization for inter CTC loss in mixup"},
) )
xctc_mixup_consistent_weight: float = field(
default=0,
metadata={"help": "consistent regularization for XCTC loss in mixup"},
)
inter_xctc_mixup_consistent_weight: float = field(
default=0,
metadata={"help": "consistent regularization for Inter XCTC loss in mixup"},
)
ctc_mixup_consistent_hard_target: bool = field(
default=False,
metadata={"help": "use hard distribution during mixup consistent learning"},
)
wer_kenlm_model: Optional[str] = field( wer_kenlm_model: Optional[str] = field(
default=None, default=None,
...@@ -156,7 +168,11 @@ class CtcCriterionConfig(FairseqDataclass): ...@@ -156,7 +168,11 @@ class CtcCriterionConfig(FairseqDataclass):
@register_criterion("ctc", dataclass=CtcCriterionConfig) @register_criterion("ctc", dataclass=CtcCriterionConfig)
class CtcCriterion(FairseqCriterion): class CtcCriterion(FairseqCriterion):
def __init__( def __init__(
self, cfg: CtcCriterionConfig, task: FairseqTask, ctc_weight=1.0, save_dir=None self, cfg: CtcCriterionConfig,
task: FairseqTask,
ctc_weight=1.0,
save_dir=None,
mixup_no_hard_loss=False,
): ):
super().__init__(task) super().__init__(task)
...@@ -224,6 +240,10 @@ class CtcCriterion(FairseqCriterion): ...@@ -224,6 +240,10 @@ class CtcCriterion(FairseqCriterion):
self.ctc_mixup_consistent_weight = cfg.ctc_mixup_consistent_weight self.ctc_mixup_consistent_weight = cfg.ctc_mixup_consistent_weight
self.inter_ctc_mixup_consistent_weight = cfg.inter_ctc_mixup_consistent_weight self.inter_ctc_mixup_consistent_weight = cfg.inter_ctc_mixup_consistent_weight
self.xctc_mixup_consistent_weight = cfg.xctc_mixup_consistent_weight
self.inter_xctc_mixup_consistent_weight = cfg.inter_xctc_mixup_consistent_weight
self.mixup_no_hard_loss = mixup_no_hard_loss
self.ctc_mixup_consistent_hard_target = cfg.ctc_mixup_consistent_hard_target
self.all_ctc_weight = ( self.all_ctc_weight = (
self.ctc_weight self.ctc_weight
...@@ -441,6 +461,7 @@ class CtcCriterion(FairseqCriterion): ...@@ -441,6 +461,7 @@ class CtcCriterion(FairseqCriterion):
target_lengths, target_lengths,
loss_coef, loss_coef,
force_emit=None, force_emit=None,
loss_mask_flag=None
): ):
lprobs = model.get_normalized_probs( lprobs = model.get_normalized_probs(
[ctc_logit], log_probs=True [ctc_logit], log_probs=True
...@@ -470,6 +491,9 @@ class CtcCriterion(FairseqCriterion): ...@@ -470,6 +491,9 @@ class CtcCriterion(FairseqCriterion):
input_lengths, input_lengths,
item_target_lengths, item_target_lengths,
) )
if loss_mask_flag is not None:
item_loss = item_loss * loss_mask_flag
loss += (item_loss * item_coef).sum() loss += (item_loss * item_coef).sum()
return loss, lprobs return loss, lprobs
...@@ -518,6 +542,7 @@ class CtcCriterion(FairseqCriterion): ...@@ -518,6 +542,7 @@ class CtcCriterion(FairseqCriterion):
target_tokens != self.eos_idx target_tokens != self.eos_idx
) )
mixup_flag = None
if "mixup" in net_output and net_output["mixup"] is not None: if "mixup" in net_output and net_output["mixup"] is not None:
mixup_coef = net_output["mixup"]["coef"] mixup_coef = net_output["mixup"]["coef"]
mixup_idx1 = net_output["mixup"]["index1"] mixup_idx1 = net_output["mixup"]["index1"]
...@@ -532,12 +557,14 @@ class CtcCriterion(FairseqCriterion): ...@@ -532,12 +557,14 @@ class CtcCriterion(FairseqCriterion):
target_tokens = [target_tokens1, target_tokens2] target_tokens = [target_tokens1, target_tokens2]
target_lengths = [target_lengths1, target_lengths2] target_lengths = [target_lengths1, target_lengths2]
loss_coef = [mixup_coef, 1 - mixup_coef] loss_coef = [mixup_coef, 1 - mixup_coef]
if self.mixup_no_hard_loss:
mixup_flag = ~net_output["mixup"]["mixup_flag"]
else: else:
target_tokens = [target_tokens.masked_select(target_pad_mask)] target_tokens = [target_tokens.masked_select(target_pad_mask)]
target_lengths = [target_pad_mask.sum(-1)] target_lengths = [target_pad_mask.sum(-1)]
loss_coef = [1] loss_coef = [1]
return target_tokens, target_lengths, loss_coef return target_tokens, target_lengths, loss_coef, mixup_flag
def compute_ctc_loss(self, model, sample, net_output, logging_output): def compute_ctc_loss(self, model, sample, net_output, logging_output):
if "transcript" in sample: if "transcript" in sample:
...@@ -557,7 +584,7 @@ class CtcCriterion(FairseqCriterion): ...@@ -557,7 +584,7 @@ class CtcCriterion(FairseqCriterion):
nfeatures = input_lengths.sum().item() nfeatures = input_lengths.sum().item()
logging_output["nfeatures"] = nfeatures logging_output["nfeatures"] = nfeatures
transcripts, transcript_lengths, loss_coef = self.get_targets_for_ctc_loss(tokens, net_output) transcripts, transcript_lengths, loss_coef, mixup_flag = self.get_targets_for_ctc_loss(tokens, net_output)
all_ctc_logits = dict() all_ctc_logits = dict()
self.ctc_names = [] self.ctc_names = []
...@@ -570,17 +597,17 @@ class CtcCriterion(FairseqCriterion): ...@@ -570,17 +597,17 @@ class CtcCriterion(FairseqCriterion):
if "inter_ctc_logits" in net_output: if "inter_ctc_logits" in net_output:
inter_ctc_num = len(net_output["inter_ctc_logits"]) inter_ctc_num = len(net_output["inter_ctc_logits"])
# calculate the inter CTC loss # calculate the Inter CTC loss
if self.inter_ctc_weight > 0 and inter_ctc_num > 0: if self.inter_ctc_weight > 0 and inter_ctc_num > 0:
logits = net_output["inter_ctc_logits"] logits = net_output["inter_ctc_logits"]
for i in range(inter_ctc_num): for i in range(inter_ctc_num):
inter_transcripts, inter_transcript_lengths, inter_loss_coef = transcripts, transcript_lengths, loss_coef inter_transcripts, inter_transcript_lengths, inter_loss_coef, inter_mixup_flag = transcripts, transcript_lengths, loss_coef, mixup_flag
if self.inter_ctc_mlo is not None: if self.inter_ctc_mlo is not None:
order = self.inter_ctc_mlo[i] order = self.inter_ctc_mlo[i]
tokens_key = "transcript%s" % order tokens_key = "transcript%s" % order
if sample.get(tokens_key, None): if sample.get(tokens_key, None):
inter_tokens = sample[tokens_key]["tokens"] inter_tokens = sample[tokens_key]["tokens"]
inter_transcripts, inter_transcript_lengths, inter_loss_coef = self.get_targets_for_ctc_loss(inter_tokens, net_output) inter_transcripts, inter_transcript_lengths, inter_loss_coef, inter_mixup_flag = self.get_targets_for_ctc_loss(inter_tokens, net_output)
logit = logits[i] logit = logits[i]
force_emit = None force_emit = None
...@@ -625,6 +652,7 @@ class CtcCriterion(FairseqCriterion): ...@@ -625,6 +652,7 @@ class CtcCriterion(FairseqCriterion):
inter_transcript_lengths, inter_transcript_lengths,
inter_loss_coef, inter_loss_coef,
force_emit, force_emit,
inter_mixup_flag
) )
inter_ctc_loss += inter_loss inter_ctc_loss += inter_loss
lprobs = inter_lprobs lprobs = inter_lprobs
...@@ -641,7 +669,6 @@ class CtcCriterion(FairseqCriterion): ...@@ -641,7 +669,6 @@ class CtcCriterion(FairseqCriterion):
): ):
use_ctc = True use_ctc = True
logit = net_output["ctc_logit"][0] logit = net_output["ctc_logit"][0]
# all_ctc_logits["ctc_logit"] = [ctc_logit, input_lengths]
force_emit = None force_emit = None
if type(logit) == list: if type(logit) == list:
...@@ -664,6 +691,7 @@ class CtcCriterion(FairseqCriterion): ...@@ -664,6 +691,7 @@ class CtcCriterion(FairseqCriterion):
transcript_lengths, transcript_lengths,
loss_coef, loss_coef,
force_emit, force_emit,
mixup_flag
) )
if self.ctc_entropy_weight > 0: if self.ctc_entropy_weight > 0:
...@@ -687,7 +715,7 @@ class CtcCriterion(FairseqCriterion): ...@@ -687,7 +715,7 @@ class CtcCriterion(FairseqCriterion):
if self.use_axctc: if self.use_axctc:
aligned_target_tokens = self.get_aligned_target_text(sample) aligned_target_tokens = self.get_aligned_target_text(sample)
target_tokens, target_lengths, loss_coef = self.get_targets_for_ctc_loss( target_tokens, target_lengths, loss_coef, target_mixup_flag = self.get_targets_for_ctc_loss(
aligned_target_tokens, net_output aligned_target_tokens, net_output
) )
...@@ -711,7 +739,6 @@ class CtcCriterion(FairseqCriterion): ...@@ -711,7 +739,6 @@ class CtcCriterion(FairseqCriterion):
inter_axctc_logit = logit inter_axctc_logit = logit
inter_input_lengths = input_lengths inter_input_lengths = input_lengths
# all_ctc_logits["inter_axctc_logit%d" % i] = [inter_axctc_logit, inter_input_lengths]
inter_loss, target_inter_lprobs = self.get_ctc_loss( inter_loss, target_inter_lprobs = self.get_ctc_loss(
model, model,
inter_axctc_logit, inter_axctc_logit,
...@@ -720,6 +747,7 @@ class CtcCriterion(FairseqCriterion): ...@@ -720,6 +747,7 @@ class CtcCriterion(FairseqCriterion):
target_lengths, target_lengths,
loss_coef, loss_coef,
force_emit, force_emit,
target_mixup_flag
) )
inter_axctc_loss += inter_loss inter_axctc_loss += inter_loss
target_lprobs = target_inter_lprobs target_lprobs = target_inter_lprobs
...@@ -730,7 +758,6 @@ class CtcCriterion(FairseqCriterion): ...@@ -730,7 +758,6 @@ class CtcCriterion(FairseqCriterion):
if self.axctc_weight > 0: if self.axctc_weight > 0:
assert "axctc_logit" in net_output assert "axctc_logit" in net_output
logit = net_output["axctc_logit"][0] logit = net_output["axctc_logit"][0]
# all_ctc_logits["axctc_logit"] = [axctc_logit, input_lengths]
force_emit = None force_emit = None
if type(logit) == list: if type(logit) == list:
...@@ -753,6 +780,7 @@ class CtcCriterion(FairseqCriterion): ...@@ -753,6 +780,7 @@ class CtcCriterion(FairseqCriterion):
target_lengths, target_lengths,
loss_coef, loss_coef,
force_emit, force_emit,
target_mixup_flag
) )
logging_output["axctc_loss"] = utils.item(axctc_loss.data) logging_output["axctc_loss"] = utils.item(axctc_loss.data)
...@@ -762,7 +790,7 @@ class CtcCriterion(FairseqCriterion): ...@@ -762,7 +790,7 @@ class CtcCriterion(FairseqCriterion):
if self.use_xctc: if self.use_xctc:
ctc_target_tokens = self.get_ctc_target_text(sample) ctc_target_tokens = self.get_ctc_target_text(sample)
target_tokens, target_lengths, loss_coef = self.get_targets_for_ctc_loss( target_tokens, target_lengths, loss_coef, target_mixup_flag = self.get_targets_for_ctc_loss(
ctc_target_tokens, net_output ctc_target_tokens, net_output
) )
...@@ -787,7 +815,6 @@ class CtcCriterion(FairseqCriterion): ...@@ -787,7 +815,6 @@ class CtcCriterion(FairseqCriterion):
inter_xctc_logit = logit inter_xctc_logit = logit
inter_input_lengths = input_lengths inter_input_lengths = input_lengths
# all_ctc_logits["inter_xctc_logit%d" % i] = [inter_xctc_logit, inter_input_lengths]
inter_loss, target_inter_lprobs = self.get_ctc_loss( inter_loss, target_inter_lprobs = self.get_ctc_loss(
model, model,
inter_xctc_logit, inter_xctc_logit,
...@@ -796,6 +823,7 @@ class CtcCriterion(FairseqCriterion): ...@@ -796,6 +823,7 @@ class CtcCriterion(FairseqCriterion):
target_lengths, target_lengths,
loss_coef, loss_coef,
force_emit, force_emit,
target_mixup_flag
) )
inter_xctc_loss += inter_loss inter_xctc_loss += inter_loss
target_lprobs = target_inter_lprobs target_lprobs = target_inter_lprobs
...@@ -819,7 +847,6 @@ class CtcCriterion(FairseqCriterion): ...@@ -819,7 +847,6 @@ class CtcCriterion(FairseqCriterion):
force_emit = logit[2] force_emit = logit[2]
else: else:
xctc_logit = logit xctc_logit = logit
# all_ctc_logits["xctc_logit"] = [xctc_logit, input_lengths]
xctc_loss, target_lprobs = self.get_ctc_loss( xctc_loss, target_lprobs = self.get_ctc_loss(
model, model,
...@@ -829,6 +856,7 @@ class CtcCriterion(FairseqCriterion): ...@@ -829,6 +856,7 @@ class CtcCriterion(FairseqCriterion):
target_lengths, target_lengths,
loss_coef, loss_coef,
force_emit, force_emit,
target_mixup_flag
) )
logging_output["xctc_loss"] = utils.item(xctc_loss.data) logging_output["xctc_loss"] = utils.item(xctc_loss.data)
...@@ -928,21 +956,8 @@ class CtcCriterion(FairseqCriterion): ...@@ -928,21 +956,8 @@ class CtcCriterion(FairseqCriterion):
xctc_self_distill_loss * self.xctc_self_distill_weight xctc_self_distill_loss * self.xctc_self_distill_weight
) )
ctc_mixup_consistent_loss = 0 # calculate KD loss for interpolation augmentation
inter_ctc_mixup_consistent_loss = 0 def get_mixup_consistent_loss(ctc_logit, non_padding_mask, mixup_pos, mixup_real_idx1, mixup_real_idx2):
if use_ctc and mixup is True:
mixup_coef = net_output["mixup"]["coef"]
mixup_idx1 = net_output["mixup"]["index1"]
mixup_idx2 = net_output["mixup"]["index2"]
mixup_pos = mixup_idx1 != mixup_idx2
mixup_real_coef = mixup_coef[mixup_pos]
loss_coef = [mixup_real_coef, 1 - mixup_real_coef]
mixup_real_idx1 = mixup_idx1[mixup_pos]
mixup_real_idx2 = mixup_idx2[mixup_pos]
def get_ctc_mixup_consistent_loss(ctc_logit, non_padding_mask):
mixup_consistent_loss = 0 mixup_consistent_loss = 0
mixup_real_logit = ctc_logit[:, mixup_pos, :] mixup_real_logit = ctc_logit[:, mixup_pos, :]
no_mixup_logit = ctc_logit[:, ~mixup_pos, :] no_mixup_logit = ctc_logit[:, ~mixup_pos, :]
...@@ -958,9 +973,16 @@ class CtcCriterion(FairseqCriterion): ...@@ -958,9 +973,16 @@ class CtcCriterion(FairseqCriterion):
for logit, pad, coef in zip( for logit, pad, coef in zip(
mixup_target_logit, mixup_target_pad_mask, loss_coef mixup_target_logit, mixup_target_pad_mask, loss_coef
): ):
if self.ctc_mixup_consistent_hard_target:
loss = F.kl_div(
F.log_softmax(mixup_real_logit, dim=-1, dtype=torch.float32),
utils.distribution_soft_to_hard(logit.detach()).to(torch.float32),
log_target=False,
reduction="none",
)
else:
loss = F.kl_div( loss = F.kl_div(
F.log_softmax(mixup_real_logit, dim=-1, dtype=torch.float32), F.log_softmax(mixup_real_logit, dim=-1, dtype=torch.float32),
# F.log_softmax(logit, dim=-1, dtype=torch.float32),
F.log_softmax(logit.detach(), dim=-1, dtype=torch.float32), F.log_softmax(logit.detach(), dim=-1, dtype=torch.float32),
log_target=True, log_target=True,
reduction="none", reduction="none",
...@@ -970,12 +992,33 @@ class CtcCriterion(FairseqCriterion): ...@@ -970,12 +992,33 @@ class CtcCriterion(FairseqCriterion):
).sum() ).sum()
return mixup_consistent_loss return mixup_consistent_loss
ctc_mixup_consistent_loss = 0
inter_ctc_mixup_consistent_loss = 0
xctc_mixup_consistent_loss = 0
inter_xctc_mixup_consistent_loss = 0
if use_ctc and mixup is True:
mixup_coef = net_output["mixup"]["coef"]
mixup_idx1 = net_output["mixup"]["index1"]
mixup_idx2 = net_output["mixup"]["index2"]
mixup_pos = mixup_idx1 != mixup_idx2
mixup_real_coef = mixup_coef[mixup_pos]
loss_coef = [mixup_real_coef, 1 - mixup_real_coef]
mixup_real_idx1 = mixup_idx1[mixup_pos]
mixup_real_idx2 = mixup_idx2[mixup_pos]
if self.ctc_mixup_consistent_weight > 0: if self.ctc_mixup_consistent_weight > 0:
ctc_logit = net_output["ctc_logit"][0] ctc_logit = net_output["ctc_logit"][0]
ctc_mixup_consistent_loss = get_ctc_mixup_consistent_loss(ctc_logit, non_padding_mask) ctc_mixup_consistent_loss = get_mixup_consistent_loss(ctc_logit, non_padding_mask, mixup_pos, mixup_real_idx1, mixup_real_idx2)
logging_output["ctc_mixup_consistent_loss"] = utils.item( logging_output["ctc_mixup_consistent_loss"] = utils.item(
ctc_mixup_consistent_loss.data ctc_mixup_consistent_loss.data
) )
if self.xctc_mixup_consistent_weight > 0:
xctc_logit = net_output["xctc_logit"][0]
xctc_mixup_consistent_loss = get_mixup_consistent_loss(xctc_logit, non_padding_mask, mixup_pos, mixup_real_idx1, mixup_real_idx2)
logging_output["xctc_mixup_consistent_loss"] = utils.item(
xctc_mixup_consistent_loss.data
)
if self.inter_ctc_mixup_consistent_weight > 0: if self.inter_ctc_mixup_consistent_weight > 0:
if inter_ctc_num > 0: if inter_ctc_num > 0:
...@@ -989,12 +1032,40 @@ class CtcCriterion(FairseqCriterion): ...@@ -989,12 +1032,40 @@ class CtcCriterion(FairseqCriterion):
inter_ctc_logit = logit inter_ctc_logit = logit
inter_non_padding_mask = non_padding_mask inter_non_padding_mask = non_padding_mask
inter_ctc_mixup_consistent_loss += get_ctc_mixup_consistent_loss(inter_ctc_logit, inter_non_padding_mask) inter_ctc_mixup_consistent_loss += get_mixup_consistent_loss(
inter_ctc_logit,
inter_non_padding_mask,
mixup_pos,
mixup_real_idx1,
mixup_real_idx2)
logging_output["inter_ctc_mixup_consistent_loss"] = utils.item( logging_output["inter_ctc_mixup_consistent_loss"] = utils.item(
inter_ctc_mixup_consistent_loss.data inter_ctc_mixup_consistent_loss.data
) )
if self.inter_xctc_mixup_consistent_weight > 0:
if inter_xctc_num > 0:
logits = net_output["inter_xctc_logits"]
for i in range(inter_xctc_num):
logit = logits[i]
if type(logit) == list:
inter_xctc_logit = logit[0]
inter_non_padding_mask = ~logit[1] if logit[1] is not None else non_padding_mask
else:
inter_xctc_logit = logit
inter_non_padding_mask = non_padding_mask
inter_xctc_mixup_consistent_loss += get_mixup_consistent_loss(
inter_xctc_logit,
inter_non_padding_mask,
mixup_pos,
mixup_real_idx1,
mixup_real_idx2)
logging_output["inter_xctc_mixup_consistent_loss"] = utils.item(
inter_xctc_mixup_consistent_loss.data
)
if len(ctc_entropy) != 0: if len(ctc_entropy) != 0:
ctc_entropy = sum(ctc_entropy) / len(ctc_entropy) ctc_entropy = sum(ctc_entropy) / len(ctc_entropy)
logging_output["ctc_entropy"] = utils.item(ctc_entropy.data) logging_output["ctc_entropy"] = utils.item(ctc_entropy.data)
...@@ -1012,6 +1083,8 @@ class CtcCriterion(FairseqCriterion): ...@@ -1012,6 +1083,8 @@ class CtcCriterion(FairseqCriterion):
+ self.ctc_entropy_weight * ctc_entropy + self.ctc_entropy_weight * ctc_entropy
+ self.ctc_mixup_consistent_weight * ctc_mixup_consistent_loss + self.ctc_mixup_consistent_weight * ctc_mixup_consistent_loss
+ self.inter_ctc_mixup_consistent_weight * inter_ctc_mixup_consistent_loss + self.inter_ctc_mixup_consistent_weight * inter_ctc_mixup_consistent_loss
+ self.xctc_mixup_consistent_weight * xctc_mixup_consistent_loss
+ self.inter_xctc_mixup_consistent_weight * inter_xctc_mixup_consistent_loss
) )
if loss != 0: if loss != 0:
...@@ -1137,6 +1210,13 @@ class CtcCriterion(FairseqCriterion): ...@@ -1137,6 +1210,13 @@ class CtcCriterion(FairseqCriterion):
inter_ctc_mixup_consistent_loss = utils.item( inter_ctc_mixup_consistent_loss = utils.item(
sum(log.get("inter_ctc_mixup_consistent_loss", 0) for log in logging_outputs) sum(log.get("inter_ctc_mixup_consistent_loss", 0) for log in logging_outputs)
) )
xctc_mixup_consistent_loss = utils.item(
sum(log.get("xctc_mixup_consistent_loss", 0) for log in logging_outputs)
)
inter_xctc_mixup_consistent_loss = utils.item(
sum(log.get("inter_xctc_mixup_consistent_loss", 0) for log in logging_outputs)
)
all_ctc_loss_sum = utils.item( all_ctc_loss_sum = utils.item(
sum(log.get("all_ctc_loss", 0) for log in logging_outputs) sum(log.get("all_ctc_loss", 0) for log in logging_outputs)
) )
...@@ -1245,6 +1325,20 @@ class CtcCriterion(FairseqCriterion): ...@@ -1245,6 +1325,20 @@ class CtcCriterion(FairseqCriterion):
sample_size, sample_size,
round=3, round=3,
) )
if xctc_mixup_consistent_loss > 0:
metrics.log_scalar(
"xctc_mixup_consistent_loss",
xctc_mixup_consistent_loss / sample_size / math.log(2),
sample_size,
round=3,
)
if inter_xctc_mixup_consistent_loss > 0:
metrics.log_scalar(
"inter_xctc_mixup_consistent_loss",
inter_xctc_mixup_consistent_loss / sample_size / math.log(2),
sample_size,
round=3,
)
metrics.log_scalar("ntokens", ntokens) metrics.log_scalar("ntokens", ntokens)
metrics.log_scalar("nsentences", nsentences) metrics.log_scalar("nsentences", nsentences)
......
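The criterion changes above extend mixup consistency regularization from the CTC head to the XCTC and inter-XCTC heads, add a hard-target variant (KL divergence against the one-hot produced by utils.distribution_soft_to_hard), and let mixup_no_hard_loss mask the plain CTC loss on mixed samples. A hedged sketch of the per-head consistency term, assuming (T, B, V) logits, a (T, B) padding mask, and a per-utterance mixup coefficient; the function names are illustrative:

```python
import torch
import torch.nn.functional as F

def soft_to_hard(scores):
    """One-hot distribution at the argmax over the vocabulary dimension."""
    idx = scores.argmax(dim=-1, keepdim=True)
    return torch.zeros_like(scores).scatter_(-1, idx, 1.0)

def mixup_consistent_kl(mixed_logit, teacher_logit, pad_mask, coef, hard_target=False):
    """KL(mixed frames || detached teacher frames), weighted by the mixup coefficient."""
    log_p = F.log_softmax(mixed_logit, dim=-1, dtype=torch.float32)
    if hard_target:
        target = soft_to_hard(teacher_logit.detach()).to(torch.float32)
        kl = F.kl_div(log_p, target, log_target=False, reduction="none")
    else:
        target = F.log_softmax(teacher_logit.detach(), dim=-1, dtype=torch.float32)
        kl = F.kl_div(log_p, target, log_target=True, reduction="none")
    kl = kl.sum(-1).masked_fill(pad_mask, 0.0)   # sum over vocab, zero out padded frames
    return (kl.sum(0) * coef).sum()              # per-utterance weight, then reduce
```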
...@@ -25,7 +25,7 @@ class LabelSmoothedCrossEntropyCriterionConfig(FairseqDataclass): ...@@ -25,7 +25,7 @@ class LabelSmoothedCrossEntropyCriterionConfig(FairseqDataclass):
default=0.0, default=0.0,
metadata={"help": "the weight for consistency regularization of mixup"}, metadata={"help": "the weight for consistency regularization of mixup"},
) )
cal_mixup_loss: bool = field( mixup_no_hard_loss: bool = field(
default=False, default=False,
metadata={"help": "calculate the loss for the mixed samples"}, metadata={"help": "calculate the loss for the mixed samples"},
) )
...@@ -71,7 +71,7 @@ class LabelSmoothedCrossEntropyCriterion(FairseqCriterion): ...@@ -71,7 +71,7 @@ class LabelSmoothedCrossEntropyCriterion(FairseqCriterion):
label_smoothing, label_smoothing,
ignore_prefix_size=0, ignore_prefix_size=0,
report_accuracy=False, report_accuracy=False,
cal_mixup_loss=True, mixup_no_hard_loss=False,
mixup_consistent_weight=0.0, mixup_consistent_weight=0.0,
): ):
super().__init__(task) super().__init__(task)
...@@ -79,7 +79,7 @@ class LabelSmoothedCrossEntropyCriterion(FairseqCriterion): ...@@ -79,7 +79,7 @@ class LabelSmoothedCrossEntropyCriterion(FairseqCriterion):
self.eps = float(label_smoothing) self.eps = float(label_smoothing)
self.ignore_prefix_size = ignore_prefix_size self.ignore_prefix_size = ignore_prefix_size
self.report_accuracy = report_accuracy self.report_accuracy = report_accuracy
self.cal_mixup_loss = cal_mixup_loss self.mixup_no_hard_loss = mixup_no_hard_loss
self.mixup_consistent_weight = mixup_consistent_weight self.mixup_consistent_weight = mixup_consistent_weight
def forward(self, model, sample, reduce=True): def forward(self, model, sample, reduce=True):
...@@ -173,7 +173,7 @@ class LabelSmoothedCrossEntropyCriterion(FairseqCriterion): ...@@ -173,7 +173,7 @@ class LabelSmoothedCrossEntropyCriterion(FairseqCriterion):
mixup_coef = net_output[1]["mixup"]["coef"][mixup_flag] mixup_coef = net_output[1]["mixup"]["coef"][mixup_flag]
loss_coef = [mixup_coef, 1 - mixup_coef] loss_coef = [mixup_coef, 1 - mixup_coef]
if self.cal_mixup_loss: if not self.mixup_no_hard_loss:
for item_lprobs, item_target, item_coef in zip(mixup_lprobs, mixup_targets, loss_coef): for item_lprobs, item_target, item_coef in zip(mixup_lprobs, mixup_targets, loss_coef):
batch_size = item_target.size(0) batch_size = item_target.size(0)
item_loss, item_nll_loss = label_smoothed_nll_loss( item_loss, item_nll_loss = label_smoothed_nll_loss(
......
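Note that cal_mixup_loss is replaced by mixup_no_hard_loss with inverted semantics: previously the hard label-smoothed loss on mixed samples was computed by default, whereas the new flag, when set, skips it and keeps only the consistency term. A tiny illustrative gate under that assumption (names are hypothetical):

```python
def combine_mixup_losses(hard_loss, consistent_loss,
                         mixup_no_hard_loss, mixup_consistent_weight):
    """Hypothetical reduction: drop the hard-target part when mixup_no_hard_loss is set."""
    loss = 0.0 if mixup_no_hard_loss else hard_loss
    return loss + mixup_consistent_weight * consistent_loss
```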
...@@ -30,19 +30,19 @@ class LabelSmoothedCrossEntropyCriterionWithCTC( ...@@ -30,19 +30,19 @@ class LabelSmoothedCrossEntropyCriterionWithCTC(
cfg: CtcCriterionConfig, cfg: CtcCriterionConfig,
ctc_weight=0.0, ctc_weight=0.0,
save_dir=None, save_dir=None,
cal_mixup_loss=True, mixup_no_hard_loss=False,
mixup_consistent_weight=0.0, mixup_consistent_weight=0.0,
only_train_enc_prob=0.0, only_train_enc_prob=0.0,
get_oracle_when_only_train_enc=False get_oracle_when_only_train_enc=False
): ):
super().__init__(task, sentence_avg, label_smoothing, super().__init__(task, sentence_avg, label_smoothing,
report_accuracy=True, report_accuracy=True,
cal_mixup_loss=cal_mixup_loss, mixup_no_hard_loss=mixup_no_hard_loss,
mixup_consistent_weight=mixup_consistent_weight) mixup_consistent_weight=mixup_consistent_weight)
self.report_accuracy = True self.report_accuracy = True
self.ctc_weight = ctc_weight self.ctc_weight = ctc_weight
self.ctc_criterion = CtcCriterion(cfg, task, ctc_weight, save_dir) self.ctc_criterion = CtcCriterion(cfg, task, ctc_weight, save_dir, mixup_no_hard_loss)
self.save_dir = save_dir self.save_dir = save_dir
self.only_train_enc_prob = only_train_enc_prob self.only_train_enc_prob = only_train_enc_prob
......
...@@ -198,6 +198,12 @@ class CommonConfig(FairseqDataclass): ...@@ -198,6 +198,12 @@ class CommonConfig(FairseqDataclass):
"help": "training steps in each epoch" "help": "training steps in each epoch"
} }
) )
sharded_data_load: bool = field(
default=False,
metadata={
"help": "Use sharded data for efficient data load"
},
)
@dataclass @dataclass
...@@ -812,6 +818,14 @@ class GenerationConfig(FairseqDataclass): ...@@ -812,6 +818,14 @@ class GenerationConfig(FairseqDataclass):
default=0.0, default=0.0,
metadata={"help": "weight for ctc probs for lm fusion"}, metadata={"help": "weight for ctc probs for lm fusion"},
) )
early_exit_count: int = field(
default=0,
metadata={"help": "early exit during decoding when n consecutive predictions are the same"},
)
early_exit_layer: int = field(
default=0,
metadata={"help": "early exit during decoding at layer n"},
)
# arguments for iterative refinement generator # arguments for iterative refinement generator
iter_decode_eos_penalty: float = field( iter_decode_eos_penalty: float = field(
......
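The new CommonConfig.sharded_data_load and GenerationConfig.early_exit_count / early_exit_layer fields surface as the --sharded-data-load, --early-exit-count, and --early-exit-layer options used by the run script at the top of this commit, since fairseq derives option names from dataclass field names by replacing underscores with dashes. A trivial sketch of that mapping:

```python
def field_to_flag(name):
    """Dataclass field name -> command-line option name."""
    return "--" + name.replace("_", "-")

assert field_to_flag("early_exit_count") == "--early-exit-count"
assert field_to_flag("sharded_data_load") == "--sharded-data-load"
```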
...@@ -1019,12 +1019,13 @@ class PDSS2TTransformerEncoder(FairseqEncoder): ...@@ -1019,12 +1019,13 @@ class PDSS2TTransformerEncoder(FairseqEncoder):
self, ctc_infer, post_process, src_dict=None, tgt_dict=None, path=None self, ctc_infer, post_process, src_dict=None, tgt_dict=None, path=None
): ):
if hasattr(self, "ctc"): if hasattr(self, "ctc"):
import os
assert src_dict is not None assert src_dict is not None
self.ctc.set_infer( self.ctc.set_infer(
ctc_infer, ctc_infer,
post_process, post_process,
src_dict, src_dict,
path=path + ".ctc" if path is not None else None, path=os.path.splitext(path)[0] + ".ctc" if path is not None else None,
) )
def ctc_valid(self, lprobs, targets, input_lengths, dictionary, lang="source"): def ctc_valid(self, lprobs, targets, input_lengths, dictionary, lang="source"):
......
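The PDS encoder now derives the CTC dump path by replacing the extension of the translation output file rather than appending ".ctc" to the full name. A quick illustration (file names are made up):

```python
import os
from typing import Optional

def ctc_dump_path(path: Optional[str]) -> Optional[str]:
    """Swap the extension of the translation output path for ".ctc"."""
    return os.path.splitext(path)[0] + ".ctc" if path is not None else None

assert ctc_dump_path("results/generate-test.txt") == "results/generate-test.ctc"
assert ctc_dump_path(None) is None
```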
...@@ -250,7 +250,6 @@ class CTCDecoder(object): ...@@ -250,7 +250,6 @@ class CTCDecoder(object):
logger.info("GMACs: %f. GFLOPs: %f" % (gmacs, gmacs * 2)) logger.info("GMACs: %f. GFLOPs: %f" % (gmacs, gmacs * 2))
print("GMACs: %f. GFLOPs: %f" % (gmacs, gmacs * 2)) print("GMACs: %f. GFLOPs: %f" % (gmacs, gmacs * 2))
from torchprofile import profile_macs from torchprofile import profile_macs
macs = profile_macs(self.model, [src_tokens, src_lengths]) macs = profile_macs(self.model, [src_tokens, src_lengths])
gmacs = macs / 1e9 gmacs = macs / 1e9
...@@ -269,20 +268,22 @@ class CTCDecoder(object): ...@@ -269,20 +268,22 @@ class CTCDecoder(object):
inter_logits = encoder_outs.get("inter_xctc_logits", []) inter_logits = encoder_outs.get("inter_xctc_logits", [])
if ctc_logit is None: if ctc_logit is None:
ctc_logit = encoder_outs["ctc_logit"][0].transpose(0, 1) ctc_logit = encoder_outs["ctc_logit"][0].transpose(0, 1)
if len(inter_logits) > 0: if len(inter_logits) == 0:
inter_logits = encoder_outs.get("inter_ctc_logits", []) inter_logits = encoder_outs.get("inter_ctc_logits", [])
inter_logits_num = len(inter_logits) inter_logits_num = len(inter_logits)
encoder_padding_mask = encoder_outs["encoder_padding_mask"][0] encoder_padding_mask = encoder_outs["encoder_padding_mask"][0]
if self.ctc_inter_logit != 0: if self.ctc_inter_logit != 0:
assert inter_logits_num >= self.ctc_inter_logit
if inter_logits_num != 0: if inter_logits_num != 0:
assert self.ctc_inter_logit <= inter_logits_num
ctc_logit_item = inter_logits[-self.ctc_inter_logit] ctc_logit_item = inter_logits[-self.ctc_inter_logit]
if isinstance(ctc_logit_item, list): if isinstance(ctc_logit_item, list):
ctc_logit = ctc_logit_item[0].transpose(0, 1) ctc_logit = ctc_logit_item[0].transpose(0, 1)
if len(ctc_logit_item) >= 2: if len(ctc_logit_item) >= 2:
encoder_padding_mask = ctc_logit_item[1] encoder_padding_mask = ctc_logit_item[1]
else:
ctc_logit = ctc_logit_item.transpose(0, 1)
logit_length = (~encoder_padding_mask).long().sum(-1) logit_length = (~encoder_padding_mask).long().sum(-1)
finalized = [] finalized = []
...@@ -318,7 +319,7 @@ class CTCDecoder(object): ...@@ -318,7 +319,7 @@ class CTCDecoder(object):
else: else:
logit = inter_logits[i] logit = inter_logits[i]
inter_logits_prob = utils.log_softmax(logits.transpose(0, 1), -1) inter_logits_prob = utils.log_softmax(logit.transpose(0, 1), -1)
ctc_probs += inter_logits_prob ctc_probs += inter_logits_prob
topk_prob, topk_index = ctc_probs.topk(1, dim=2) topk_prob, topk_index = ctc_probs.topk(1, dim=2)
......
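The decoder fix above falls back to inter_ctc_logits when no inter_xctc_logits are present, validates --ctc-inter-logit against the number of available intermediate logits, and handles both storage formats (a bare tensor or a [logit, padding_mask, ...] list). A hedged sketch of that selection, assuming (T, B, V) logits and a (B, T) padding mask:

```python
def pick_inter_logit(inter_logits, k, default_padding_mask):
    """Return (logit as B x T x V, padding mask as B x T) for the k-th logit from the end."""
    assert 1 <= k <= len(inter_logits)
    item = inter_logits[-k]
    if isinstance(item, list):
        logit = item[0].transpose(0, 1)  # (T, B, V) -> (B, T, V)
        padding = item[1] if len(item) >= 2 and item[1] is not None else default_padding_mask
    else:
        logit = item.transpose(0, 1)
        padding = default_padding_mask
    return logit, padding
```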
...@@ -888,6 +888,8 @@ class S2TTransformerEncoder(FairseqEncoder): ...@@ -888,6 +888,8 @@ class S2TTransformerEncoder(FairseqEncoder):
super().__init__(None) super().__init__(None)
dim = args.encoder_embed_dim dim = args.encoder_embed_dim
self.source_dictionary = task.source_dictionary
self.target_dictionary = task.target_dictionary
layer_num = args.encoder_layers layer_num = args.encoder_layers
self.dropout_module = FairseqDropout( self.dropout_module = FairseqDropout(
p=args.dropout, module_name=self.__class__.__name__ p=args.dropout, module_name=self.__class__.__name__
...@@ -1027,6 +1029,7 @@ class S2TTransformerEncoder(FairseqEncoder): ...@@ -1027,6 +1029,7 @@ class S2TTransformerEncoder(FairseqEncoder):
) )
), ),
dropout=args.dropout, dropout=args.dropout,
dictionary=task.source_dictionary
) )
setattr(self, "inter_ctc%d" % layer_idx, inter_ctc) setattr(self, "inter_ctc%d" % layer_idx, inter_ctc)
# inter_layer_norm = LayerNorm(dim) # inter_layer_norm = LayerNorm(dim)
...@@ -1038,6 +1041,7 @@ class S2TTransformerEncoder(FairseqEncoder): ...@@ -1038,6 +1041,7 @@ class S2TTransformerEncoder(FairseqEncoder):
dim, dim,
dictionary_size=len(task.source_dictionary), dictionary_size=len(task.source_dictionary),
dropout=args.dropout, dropout=args.dropout,
dictionary=task.source_dictionary,
) )
if ( if (
getattr(args, "share_ctc_and_embed", False) getattr(args, "share_ctc_and_embed", False)
...@@ -1116,6 +1120,7 @@ class S2TTransformerEncoder(FairseqEncoder): ...@@ -1116,6 +1120,7 @@ class S2TTransformerEncoder(FairseqEncoder):
else len(task.target_dictionary), else len(task.target_dictionary),
dropout=args.dropout, dropout=args.dropout,
need_layernorm=True if self.inter_xctc else False, need_layernorm=True if self.inter_xctc else False,
dictionary=task.target_dictionary,
) )
if ( if (
...@@ -1375,6 +1380,10 @@ class S2TTransformerEncoder(FairseqEncoder): ...@@ -1375,6 +1380,10 @@ class S2TTransformerEncoder(FairseqEncoder):
self.mixup_infer = False self.mixup_infer = False
self.rep_dict = dict() self.rep_dict = dict()
self.early_exit_count = 0
self.early_exit_layer_record = []
self.early_exit_layer = 0
@staticmethod @staticmethod
def build_encoder_layer(args): def build_encoder_layer(args):
return S2TTransformerEncoderLayer(args) return S2TTransformerEncoderLayer(args)
...@@ -1400,6 +1409,10 @@ class S2TTransformerEncoder(FairseqEncoder): ...@@ -1400,6 +1409,10 @@ class S2TTransformerEncoder(FairseqEncoder):
layer, "dump" layer, "dump"
) else None ) else None
print("Early exit layer.", file=fstream)
if self.early_exit_count != 0:
print("\n".join([str(l) for l in self.early_exit_layer_record]), file=fstream)
if self.gather_cos_sim: if self.gather_cos_sim:
print( print(
"\nCosine similarity of distance %d" % self.gather_cos_sim_dis, "\nCosine similarity of distance %d" % self.gather_cos_sim_dis,
...@@ -1540,13 +1553,17 @@ class S2TTransformerEncoder(FairseqEncoder): ...@@ -1540,13 +1553,17 @@ class S2TTransformerEncoder(FairseqEncoder):
self.mixup_infer = kwargs.get("mixup_infer", False) self.mixup_infer = kwargs.get("mixup_infer", False)
self.gather_cos_sim = kwargs.get("gather_cos_sim", False) self.gather_cos_sim = kwargs.get("gather_cos_sim", False)
self.gather_cos_sim_dis = kwargs.get("gather_cos_sim_dis", 2) self.gather_cos_sim_dis = kwargs.get("gather_cos_sim_dis", 2)
self.early_exit_layer = kwargs.get("early_exit_layer", 0)
if self.early_exit_layer != 0:
logger.info("Using the logit in layer %d to infer." % self.early_exit_layer)
if self.mixup_infer: if self.mixup_infer:
self.mixup_keep_org = True self.mixup_keep_org = True
def set_ctc_infer( def set_ctc_infer(
self, ctc_infer, post_process, src_dict=None, tgt_dict=None, path=None self, ctc_infer, post_process, src_dict=None, tgt_dict=None, path=None, early_exit_count=0
): ):
self.early_exit_count = early_exit_count
if hasattr(self, "ctc"): if hasattr(self, "ctc"):
assert src_dict is not None assert src_dict is not None
self.ctc.set_infer( self.ctc.set_infer(
...@@ -1711,6 +1728,27 @@ class S2TTransformerEncoder(FairseqEncoder): ...@@ -1711,6 +1728,27 @@ class S2TTransformerEncoder(FairseqEncoder):
org_x = x[:, ~flag, :].mean(0) org_x = x[:, ~flag, :].mean(0)
rep_dict[layer_idx].append(org_x) rep_dict[layer_idx].append(org_x)
def early_exit_or_not(self, history, new_logit, count):
history.append(new_logit)
length = len(history)
if count == 0 or length < count:
return False
else:
# for logit in history[length - count: length - 1]:
# if new_logit.size() != logit.size() or not (new_logit == logit).all():
# return False
# return True
hit = 0
for logit in history[: length - 1]:
if new_logit.size() == logit.size() and (new_logit == logit).all():
hit += 1
if hit >= count:
return True
else:
return False
def forward(self, src_tokens, src_lengths=None, **kwargs): def forward(self, src_tokens, src_lengths=None, **kwargs):
layer_idx = -1 layer_idx = -1
...@@ -1727,6 +1765,7 @@ class S2TTransformerEncoder(FairseqEncoder): ...@@ -1727,6 +1765,7 @@ class S2TTransformerEncoder(FairseqEncoder):
# (B, T, D) -> (T, B, D) # (B, T, D) -> (T, B, D)
x = src_tokens.transpose(0, 1) x = src_tokens.transpose(0, 1)
input_lengths = src_lengths input_lengths = src_lengths
org_bsz = x.size(1)
if ( if (
(self.training or self.mixup_infer) (self.training or self.mixup_infer)
...@@ -1821,6 +1860,16 @@ class S2TTransformerEncoder(FairseqEncoder): ...@@ -1821,6 +1860,16 @@ class S2TTransformerEncoder(FairseqEncoder):
xctc_oracle_mask = None xctc_oracle_mask = None
xctc_force_emit = None xctc_force_emit = None
# Infer early exit
batch_idx_dict = dict()
inter_ctc_logits_history = dict()
final_ctc_logits = dict()
final_encoder_padding_mask = dict()
early_exit_layer = dict()
for i in range(x.size(1)):
inter_ctc_logits_history[i] = []
batch_idx_dict[i] = i
for layer in self.layers: for layer in self.layers:
if self.history is not None: if self.history is not None:
x = self.history.pop() x = self.history.pop()
...@@ -1879,7 +1928,7 @@ class S2TTransformerEncoder(FairseqEncoder): ...@@ -1879,7 +1928,7 @@ class S2TTransformerEncoder(FairseqEncoder):
# Inter CTC # Inter CTC
if layer_idx in self.inter_ctc_layers: if layer_idx in self.inter_ctc_layers:
if self.inter_ctc_drop_prob > 0: if self.training and self.inter_ctc_drop_prob > 0:
p = torch.rand(1).uniform_() p = torch.rand(1).uniform_()
if p < self.inter_ctc_drop_prob: if p < self.inter_ctc_drop_prob:
break break
...@@ -1945,6 +1994,44 @@ class S2TTransformerEncoder(FairseqEncoder): ...@@ -1945,6 +1994,44 @@ class S2TTransformerEncoder(FairseqEncoder):
inter_ctc_logits.append(inter_logit) inter_ctc_logits.append(inter_logit)
if not self.training and self.early_exit_layer == layer_idx:
ctc_logit = inter_logit[0]
break
if not self.training and self.early_exit_count != 0:
predicts = inter_ctc.predict(inter_logit[0], encoder_padding_mask)
if len(inter_ctc_logits) < self.early_exit_count:
for i in range(x.size(1)):
inter_ctc_logits_history[i].append(predicts[i])
else:
if org_bsz == 1:
early_exit_flag = self.early_exit_or_not(inter_ctc_logits_history[0], predicts[0], self.early_exit_count)
if early_exit_flag:
ctc_logit = inter_logit[0]
self.early_exit_layer_record.append(layer_idx)
break
else:
idx = 0
keep_idx = []
new_batch_idx_dict = dict()
for i in range(x.size(1)):
real_idx = batch_idx_dict[i]
early_exit_flag = self.early_exit_or_not(inter_ctc_logits_history[real_idx], predicts[i], self.early_exit_count)
if early_exit_flag:
final_ctc_logits[real_idx] = inter_logit[0][:, i, :]
final_encoder_padding_mask[real_idx] = encoder_padding_mask[i, :]
early_exit_layer[real_idx] = layer_idx
else:
keep_idx.append(i)
new_batch_idx_dict[idx] = real_idx
idx += 1
if idx == 0:
break
if idx < x.size(1):
batch_idx_dict = new_batch_idx_dict
x = x[:, keep_idx, :].contiguous()
encoder_padding_mask = encoder_padding_mask[keep_idx, :].contiguous()
if layer_idx in self.compression_layers: if layer_idx in self.compression_layers:
ctc_prob = utils.softmax(logit, dim=-1) # (T B C) ctc_prob = utils.softmax(logit, dim=-1) # (T B C)
blank_prob = ctc_prob[:, :, 0] blank_prob = ctc_prob[:, :, 0]
...@@ -2133,6 +2220,25 @@ class S2TTransformerEncoder(FairseqEncoder): ...@@ -2133,6 +2220,25 @@ class S2TTransformerEncoder(FairseqEncoder):
) )
self.show_debug(x, "x after xctc") self.show_debug(x, "x after xctc")
if not self.training and self.early_exit_count != 0 and org_bsz != 1:
if layer_idx == len(self.layers) + 1:
for i in range(x.size(1)):
real_idx = batch_idx_dict[i]
final_ctc_logits[real_idx] = ctc_logit[:, i, :]
final_encoder_padding_mask[real_idx] = encoder_padding_mask[i, :]
early_exit_layer[real_idx] = layer_idx - 1
output_logits = []
output_encoder_padding_mask = []
output_layers = []
for i in range(len(final_ctc_logits)):
output_logits.append(final_ctc_logits[i])
output_encoder_padding_mask.append(final_encoder_padding_mask[i])
output_layers.append(early_exit_layer[i])
ctc_logit = torch.stack(output_logits, dim=0).transpose(0, 1)
encoder_padding_mask = torch.stack(output_encoder_padding_mask, dim=0)
self.early_exit_layer_record.extend(output_layers)
if ctc_force_emit is not None: if ctc_force_emit is not None:
ctc_logit = [ctc_logit, None, ctc_force_emit] ctc_logit = [ctc_logit, None, ctc_force_emit]
...@@ -2174,6 +2280,11 @@ class S2TTransformerEncoder(FairseqEncoder): ...@@ -2174,6 +2280,11 @@ class S2TTransformerEncoder(FairseqEncoder):
if len(encoder_out["xctc_logit"]) == 0 if len(encoder_out["xctc_logit"]) == 0
else [x.index_select(1, new_order) for x in encoder_out["xctc_logit"]] else [x.index_select(1, new_order) for x in encoder_out["xctc_logit"]]
) )
new_inter_ctc_logits = (
[]
if len(encoder_out["inter_ctc_logits"]) == 0
else [[x[0].index_select(1, new_order)] + x[1:] if isinstance(x, list) else x.index_select(1, new_order) for x in encoder_out["inter_ctc_logits"] if x is not None]
)
new_encoder_padding_mask = ( new_encoder_padding_mask = (
[] []
if len(encoder_out["encoder_padding_mask"]) == 0 if len(encoder_out["encoder_padding_mask"]) == 0
...@@ -2200,6 +2311,7 @@ class S2TTransformerEncoder(FairseqEncoder): ...@@ -2200,6 +2311,7 @@ class S2TTransformerEncoder(FairseqEncoder):
"encoder_out": new_encoder_out, # T x B x C "encoder_out": new_encoder_out, # T x B x C
"ctc_logit": new_ctc_logit, # T x B x C "ctc_logit": new_ctc_logit, # T x B x C
"xctc_logit": new_xctc_logit, "xctc_logit": new_xctc_logit,
"inter_ctc_logits": new_inter_ctc_logits,
"encoder_padding_mask": new_encoder_padding_mask, # B x T "encoder_padding_mask": new_encoder_padding_mask, # B x T
"encoder_embedding": new_encoder_embedding, # B x T x C "encoder_embedding": new_encoder_embedding, # B x T x C
"encoder_states": encoder_states, # List[T x B x C] "encoder_states": encoder_states, # List[T x B x C]
......
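The encoder changes above implement inference-time early exit: with --early-exit-layer the forward pass stops at a fixed intermediate CTC head, while with --early-exit-count it stops per utterance once the greedy CTC prediction has already been produced by that many earlier layers, removing finished utterances from the batch and stitching their logits back together at the end. A minimal, self-contained sketch of the exit test, mirroring early_exit_or_not (tensors stand in for collapsed CTC predictions):

```python
import torch

def should_exit(history, new_pred, count):
    """True once `count` earlier layers produced exactly the same prediction."""
    history.append(new_pred)
    if count == 0 or len(history) < count:
        return False
    hits = sum(
        1 for prev in history[:-1]
        if prev.size() == new_pred.size() and bool((prev == new_pred).all())
    )
    return hits >= count

# With count=2, the third identical layer-wise prediction triggers the exit.
history, pred = [], torch.tensor([5, 9, 3])
assert not should_exit(history, pred, 2)
assert not should_exit(history, pred.clone(), 2)   # one earlier match so far
assert should_exit(history, pred.clone(), 2)       # two earlier matches -> exit
```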
...@@ -74,6 +74,19 @@ class CTC(nn.Module): ...@@ -74,6 +74,19 @@ class CTC(nn.Module):
def argmax(self, x): def argmax(self, x):
return torch.argmax(self.ctc_projection(x), dim=-1) return torch.argmax(self.ctc_projection(x), dim=-1)
def predict(self, logits, padding):
input_lengths = (~padding).sum(-1)
logits = logits.transpose(0, 1).float().contiguous()
predicts = []
for logit, inp_l in zip(logits, input_lengths):
toks = logit[:inp_l].argmax(dim=-1).unique_consecutive()
pred_units_arr = toks[toks != self.dictionary.bos()]
# pred_units_arr = logit[:inp_l].argmax(dim=-1)
predicts.append(pred_units_arr)
return predicts
def infer(self, logits_or_probs, lengths, tag=None): def infer(self, logits_or_probs, lengths, tag=None):
for lp, inp_l in zip( for lp, inp_l in zip(
logits_or_probs, logits_or_probs,
......
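CTC.predict above is a standard greedy CTC collapse: per-frame argmax, merge repeats with unique_consecutive, then drop the blank index (the diff uses dictionary.bos() as the blank). A self-contained illustration with a made-up three-symbol vocabulary where index 0 is the blank:

```python
import torch

def greedy_ctc_collapse(frame_logits, blank_id=0):
    """Argmax per frame, merge repeats, drop blanks."""
    toks = frame_logits.argmax(dim=-1).unique_consecutive()
    return toks[toks != blank_id]

logits = torch.tensor([[9., 0., 0.],
                       [9., 0., 0.],
                       [0., 9., 0.],
                       [0., 9., 0.],
                       [9., 0., 0.],
                       [0., 0., 9.]])
assert greedy_ctc_collapse(logits).tolist() == [1, 2]
```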
...@@ -130,6 +130,9 @@ class FairseqTask(object): ...@@ -130,6 +130,9 @@ class FairseqTask(object):
""" """
return cls(cfg, **kwargs) return cls(cfg, **kwargs)
def sharded_data_load(self):
return getattr(self.cfg, "sharded_data_load", False)
def has_sharded_data(self, split): def has_sharded_data(self, split):
return os.pathsep in getattr(self.cfg, "data", "") return os.pathsep in getattr(self.cfg, "data", "")
...@@ -619,6 +622,9 @@ class LegacyFairseqTask(FairseqTask): ...@@ -619,6 +622,9 @@ class LegacyFairseqTask(FairseqTask):
""" """
return cls(args, **kwargs) return cls(args, **kwargs)
def sharded_data_load(self):
return getattr(self.args, "sharded_data_load", False)
def has_sharded_data(self, split): def has_sharded_data(self, split):
return os.pathsep in getattr(self.args, "data", "") return os.pathsep in getattr(self.args, "data", "")
......
...@@ -521,16 +521,25 @@ class Trainer(object): ...@@ -521,16 +521,25 @@ class Trainer(object):
disable_iterator_cache=False, disable_iterator_cache=False,
): ):
"""Return an EpochBatchIterator over the training set for a given epoch.""" """Return an EpochBatchIterator over the training set for a given epoch."""
if self.task.sharded_data_load():
datasets = self.cfg.dataset.train_subset.split(",")
curr_dataset = datasets[(epoch - 1) % len(datasets)]
logger.info("sharded loading the training subset {}".format(curr_dataset))
else:
curr_dataset = self.cfg.dataset.train_subset
load_dataset = load_dataset or self.task.sharded_data_load()
disable_iterator_cache = disable_iterator_cache or self.task.sharded_data_load()
if load_dataset: if load_dataset:
logger.info("loading train data for epoch {}".format(epoch)) logger.info("loading train data for epoch {}".format(epoch))
self.task.load_dataset( self.task.load_dataset(
self.cfg.dataset.train_subset, curr_dataset,
epoch=epoch, epoch=epoch,
combine=combine, combine=combine,
data_selector=data_selector, data_selector=data_selector,
) )
batch_iterator = self.task.get_batch_iterator( batch_iterator = self.task.get_batch_iterator(
dataset=self.task.dataset(self.cfg.dataset.train_subset), dataset=self.task.dataset(curr_dataset),
max_tokens=self.cfg.dataset.max_tokens, max_tokens=self.cfg.dataset.max_tokens,
max_sentences=self.cfg.dataset.batch_size, max_sentences=self.cfg.dataset.batch_size,
max_positions=utils.resolve_max_positions( max_positions=utils.resolve_max_positions(
......
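With sharded_data_load enabled, the trainer treats the comma-separated train_subset as a list of shards, trains on one shard per epoch in round-robin order, and forces the dataset to be reloaded (iterator cache disabled) every epoch. The rotation itself, as a small sketch:

```python
def shard_for_epoch(train_subset, epoch):
    """Pick the shard for a 1-based epoch index, cycling through the list."""
    shards = train_subset.split(",")
    return shards[(epoch - 1) % len(shards)]

assert shard_for_epoch("train_0,train_1,train_2", 1) == "train_0"
assert shard_for_epoch("train_0,train_1,train_2", 4) == "train_0"
assert shard_for_epoch("train_0,train_1,train_2", 5) == "train_1"
```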
...@@ -754,3 +754,18 @@ def freeze_parameters(module, freeze_module_name): ...@@ -754,3 +754,18 @@ def freeze_parameters(module, freeze_module_name):
freeze_module_name = freeze_module_name.split(",") freeze_module_name = freeze_module_name.split(",")
for name in freeze_module_name: for name in freeze_module_name:
freeze_module_params_by_name(module, name) freeze_module_params_by_name(module, name)
def distribution_soft_to_hard(logit_or_prob):
argmax_prob = torch.argmax(logit_or_prob, dim=-1, keepdim=True)
hard_distribution = (
(
argmax_prob
== torch.arange(logit_or_prob.size(-1), device=logit_or_prob.device).unsqueeze(
0
)
)
.to(logit_or_prob.dtype)
)
return hard_distribution
\ No newline at end of file
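distribution_soft_to_hard converts any score vector into a one-hot distribution at its argmax and backs the ctc_mixup_consistent_hard_target option above. A quick numeric check of the broadcast-compare construction it uses:

```python
import torch

logits = torch.tensor([[0.1, 2.0, -1.0],
                       [3.0, 0.0,  0.5]])
# argmax index per row, compared against the full index range, cast back to the input dtype
hard = (logits.argmax(-1, keepdim=True)
        == torch.arange(logits.size(-1)).unsqueeze(0)).to(logits.dtype)
assert hard.tolist() == [[0.0, 1.0, 0.0], [1.0, 0.0, 0.0]]
```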
...@@ -108,7 +108,7 @@ def _main(cfg: DictConfig, output_file, translation_path=None): ...@@ -108,7 +108,7 @@ def _main(cfg: DictConfig, output_file, translation_path=None):
for model in models: for model in models:
if hasattr(model, "encoder") and hasattr(model.encoder, "set_ctc_infer"): if hasattr(model, "encoder") and hasattr(model.encoder, "set_ctc_infer"):
model.encoder.set_ctc_infer(cfg.generation.ctc_infer, "sentencepiece", model.encoder.set_ctc_infer(cfg.generation.ctc_infer, "sentencepiece",
src_dict, tgt_dict, translation_path) src_dict, tgt_dict, translation_path, cfg.generation.early_exit_count)
if hasattr(model, "encoder") and hasattr(model.encoder, "set_flag"): if hasattr(model, "encoder") and hasattr(model.encoder, "set_flag"):
model.encoder.set_flag( model.encoder.set_flag(
cal_localness=cfg.generation.cal_localness, cal_localness=cfg.generation.cal_localness,
...@@ -120,6 +120,7 @@ def _main(cfg: DictConfig, output_file, translation_path=None): ...@@ -120,6 +120,7 @@ def _main(cfg: DictConfig, output_file, translation_path=None):
mixup_infer=cfg.generation.mixup_infer, mixup_infer=cfg.generation.mixup_infer,
gather_cos_sim=cfg.generation.gather_cos_sim, gather_cos_sim=cfg.generation.gather_cos_sim,
gather_cos_sim_dis=cfg.generation.gather_cos_sim_dis, gather_cos_sim_dis=cfg.generation.gather_cos_sim_dis,
early_exit_layer=cfg.generation.early_exit_layer,
) )
if hasattr(model, "decoder") and hasattr(model.decoder, "set_flag"): if hasattr(model, "decoder") and hasattr(model.decoder, "set_flag"):
model.decoder.set_flag( model.decoder.set_flag(
...@@ -246,9 +247,15 @@ def _main(cfg: DictConfig, output_file, translation_path=None): ...@@ -246,9 +247,15 @@ def _main(cfg: DictConfig, output_file, translation_path=None):
# Remove padding # Remove padding
if "src_tokens" in sample["net_input"]: if "src_tokens" in sample["net_input"]:
if sample["net_input"]["src_tokens"].dtype in [torch.int32, torch.int64]:
src_tokens = utils.strip_pad( src_tokens = utils.strip_pad(
sample["net_input"]["src_tokens"][i, :], tgt_dict.pad() sample["net_input"]["src_tokens"][i, :], src_dict.pad()
) )
elif "transcript" in sample:
src_tokens = utils.strip_pad(
sample["transcript"]["tokens"][i, :], src_dict.pad()
)
else: else:
src_tokens = None src_tokens = None
......
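generate.py now strips padding with the source dictionary and only when src_tokens are integer ids; for speech inputs (float features) it falls back to the reference transcript tokens when available. A hedged sketch of that selection, with strip_pad replaced by boolean indexing for self-containment:

```python
import torch

def recover_src_tokens(sample, i, pad_idx):
    """Return the i-th source token sequence without padding, or None for pure speech input."""
    src = sample["net_input"].get("src_tokens")
    if src is not None and src.dtype in (torch.int32, torch.int64):
        toks = src[i]
        return toks[toks != pad_idx]
    if "transcript" in sample:
        toks = sample["transcript"]["tokens"][i]
        return toks[toks != pad_idx]
    return None
```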