Commit f7a9b1a0 by xuchen

for BIL-CTC

parent 6ca43d9e
......@@ -414,7 +414,8 @@ def process(args):
asr_spm_filename = None
gen_vocab_flag = True
if task == "st" and args.add_src:
# if task == "st" and args.add_src:
if args.add_src:
if args.share:
if args.st_spm_prefix is not None:
gen_vocab_flag = False
......@@ -450,7 +451,8 @@ def process(args):
quoting=csv.QUOTE_NONE,
)
if task == "st" and args.add_src and args.share:
# if task == "st" and args.add_src and args.share:
if args.add_src and args.share:
for e in reader:
src_utt = dict(e)["src_text"]
tgt_utt = dict(e)["tgt_text"]
......
......@@ -38,7 +38,7 @@ SPLITS = [
"test-other",
]
MANIFEST_COLUMNS = ["id", "audio", "n_frames", "tgt_text", "speaker"]
MANIFEST_COLUMNS = ["id", "audio", "n_frames", "tgt_text", "speaker", "src_text"]
def process(args):
......@@ -169,8 +169,11 @@ def process(args):
lineterminator="\n",
quoting=csv.QUOTE_NONE,
)
tgt_text = [(dict(e))["tgt_text"] for e in reader]
train_text.extend(tgt_text)
for e in reader:
e = dict(e)
train_text.append(e["tgt_text"])
if "src_text" in e:
train_text.append(e["src_text"])
for t in train_text:
f.write(t + "\n")
gen_vocab(
......
......@@ -421,9 +421,9 @@ class TextualEncoder(FairseqEncoder):
self.pae_gt_decay = False
decay_params = getattr(args, "xctc_pae_ground_truth_ratio_decay", None)
self.gt_ratio = self.pae_ground_truth_ratio
if self.pae_ground_truth_ratio != 0:
self.pae_adaptive_gt = getattr(args, "xctc_pae_ground_truth_ratio_adaptive", False)
self.pae_gt_only_mistake = getattr(args, "xctc_pae_ground_truth_only_mistake", False)
if self.pae_ground_truth_ratio != 0:
if decay_params is not None and len(decay_params.split(":")) == 3:
self.pae_gt_decay = True
params = [float(item) for item in decay_params.split(":")]
......
......@@ -1121,6 +1121,12 @@ class S2TTransformerEncoder(FairseqEncoder):
self.xctc.ctc_projection.weight = embed_tokens.weight
self.inter_xctc_layers = []
self.pae_adaptive_gt = getattr(
args, "xctc_pae_ground_truth_ratio_adaptive", False
)
self.pae_gt_only_mistake = getattr(
args, "xctc_pae_ground_truth_only_mistake", False
)
inter_xctc_layers = getattr(args, "inter_xctc_layers", None)
if (
getattr(args, "disable_xctc", False) is False
......@@ -1138,13 +1144,8 @@ class S2TTransformerEncoder(FairseqEncoder):
self.pae_unnorm_input = getattr(args, "pae_unnorm_input", False)
self.pae_gt_decay = False
decay_params = getattr(args, "xctc_pae_ground_truth_ratio_decay", None)
if self.xctc_pae_ground_truth_ratio != 0:
self.pae_adaptive_gt = getattr(
args, "xctc_pae_ground_truth_ratio_adaptive", False
)
self.pae_gt_only_mistake = getattr(
args, "xctc_pae_ground_truth_only_mistake", False
)
if decay_params is not None and len(decay_params.split(":")) == 3:
self.pae_gt_decay = True
params = [float(item) for item in decay_params.split(":")]
......@@ -1806,9 +1807,9 @@ class S2TTransformerEncoder(FairseqEncoder):
xctc_logit = None
inter_xctc_logits = []
# CTC alignment
oracle = None
oracle_mask = None
force_emit = None
ctc_oracle = None
ctc_oracle_mask = None
ctc_force_emit = None
xctc_oracle = None
xctc_oracle_mask = None
xctc_force_emit = None
......@@ -1899,20 +1900,39 @@ class S2TTransformerEncoder(FairseqEncoder):
ctc_alignment_oracle is not None
and ctc_alignment_oracle.get("ctc", None) is not None
):
if oracle is None:
oracle, best_aligns_pad = ctc_alignment_oracle["ctc"]
oracle_mask = (
torch.rand(oracle.size(), device=oracle.device)
< self.ctc_pae_ground_truth_ratio
).bool()
force_emit = best_aligns_pad.masked_fill(~oracle_mask, -1)
if ctc_oracle is None:
(
ctc_oracle,
best_aligns_pad,
mistake_flag,
mistake_ratio,
) = ctc_alignment_oracle["ctc"]
if self.pae_adaptive_gt:
prob = (
self.ctc_pae_ground_truth_ratio
* mistake_ratio.unsqueeze(-1)
)
else:
inter_logit = [logit, None, force_emit]
prob = self.ctc_pae_ground_truth_ratio
ctc_oracle_mask = (
torch.rand(
ctc_oracle.size(), device=ctc_oracle.device
)
< prob
).bool()
if self.pae_gt_only_mistake:
ctc_oracle_mask.masked_fill_(~mistake_flag, False)
ctc_force_emit = best_aligns_pad.masked_fill(
~ctc_oracle_mask, -1
)
inter_logit = [logit, None, ctc_force_emit]
pae_input = x if self.pae_unnorm_input else norm_x
if pae.adapter_type != "none":
x, encoder_padding_mask = pae(
[pae_input, logit], encoder_padding_mask, oracle, oracle_mask
[pae_input, logit], encoder_padding_mask, ctc_oracle, ctc_oracle_mask
)
self.show_debug(x, "x after pae")
......@@ -2106,8 +2126,8 @@ class S2TTransformerEncoder(FairseqEncoder):
)
self.show_debug(x, "x after xctc")
if force_emit is not None:
ctc_logit = [ctc_logit, None, force_emit]
if ctc_force_emit is not None:
ctc_logit = [ctc_logit, None, ctc_force_emit]
if xctc_force_emit is not None:
xctc_logit = [xctc_logit, None, xctc_force_emit]
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论