Commit f7a9b1a0 by xuchen

for BIL-CTC

parent 6ca43d9e
...@@ -414,7 +414,8 @@ def process(args): ...@@ -414,7 +414,8 @@ def process(args):
asr_spm_filename = None asr_spm_filename = None
gen_vocab_flag = True gen_vocab_flag = True
if task == "st" and args.add_src: # if task == "st" and args.add_src:
if args.add_src:
if args.share: if args.share:
if args.st_spm_prefix is not None: if args.st_spm_prefix is not None:
gen_vocab_flag = False gen_vocab_flag = False
...@@ -450,7 +451,8 @@ def process(args): ...@@ -450,7 +451,8 @@ def process(args):
quoting=csv.QUOTE_NONE, quoting=csv.QUOTE_NONE,
) )
if task == "st" and args.add_src and args.share: # if task == "st" and args.add_src and args.share:
if args.add_src and args.share:
for e in reader: for e in reader:
src_utt = dict(e)["src_text"] src_utt = dict(e)["src_text"]
tgt_utt = dict(e)["tgt_text"] tgt_utt = dict(e)["tgt_text"]
......
...@@ -38,7 +38,7 @@ SPLITS = [ ...@@ -38,7 +38,7 @@ SPLITS = [
"test-other", "test-other",
] ]
MANIFEST_COLUMNS = ["id", "audio", "n_frames", "tgt_text", "speaker"] MANIFEST_COLUMNS = ["id", "audio", "n_frames", "tgt_text", "speaker", "src_text"]
def process(args): def process(args):
...@@ -169,8 +169,11 @@ def process(args): ...@@ -169,8 +169,11 @@ def process(args):
lineterminator="\n", lineterminator="\n",
quoting=csv.QUOTE_NONE, quoting=csv.QUOTE_NONE,
) )
tgt_text = [(dict(e))["tgt_text"] for e in reader] for e in reader:
train_text.extend(tgt_text) e = dict(e)
train_text.append(e["tgt_text"])
if "src_text" in e:
train_text.append(e["src_text"])
for t in train_text: for t in train_text:
f.write(t + "\n") f.write(t + "\n")
gen_vocab( gen_vocab(
......
...@@ -421,9 +421,9 @@ class TextualEncoder(FairseqEncoder): ...@@ -421,9 +421,9 @@ class TextualEncoder(FairseqEncoder):
self.pae_gt_decay = False self.pae_gt_decay = False
decay_params = getattr(args, "xctc_pae_ground_truth_ratio_decay", None) decay_params = getattr(args, "xctc_pae_ground_truth_ratio_decay", None)
self.gt_ratio = self.pae_ground_truth_ratio self.gt_ratio = self.pae_ground_truth_ratio
if self.pae_ground_truth_ratio != 0:
self.pae_adaptive_gt = getattr(args, "xctc_pae_ground_truth_ratio_adaptive", False) self.pae_adaptive_gt = getattr(args, "xctc_pae_ground_truth_ratio_adaptive", False)
self.pae_gt_only_mistake = getattr(args, "xctc_pae_ground_truth_only_mistake", False) self.pae_gt_only_mistake = getattr(args, "xctc_pae_ground_truth_only_mistake", False)
if self.pae_ground_truth_ratio != 0:
if decay_params is not None and len(decay_params.split(":")) == 3: if decay_params is not None and len(decay_params.split(":")) == 3:
self.pae_gt_decay = True self.pae_gt_decay = True
params = [float(item) for item in decay_params.split(":")] params = [float(item) for item in decay_params.split(":")]
......
...@@ -1121,6 +1121,12 @@ class S2TTransformerEncoder(FairseqEncoder): ...@@ -1121,6 +1121,12 @@ class S2TTransformerEncoder(FairseqEncoder):
self.xctc.ctc_projection.weight = embed_tokens.weight self.xctc.ctc_projection.weight = embed_tokens.weight
self.inter_xctc_layers = [] self.inter_xctc_layers = []
self.pae_adaptive_gt = getattr(
args, "xctc_pae_ground_truth_ratio_adaptive", False
)
self.pae_gt_only_mistake = getattr(
args, "xctc_pae_ground_truth_only_mistake", False
)
inter_xctc_layers = getattr(args, "inter_xctc_layers", None) inter_xctc_layers = getattr(args, "inter_xctc_layers", None)
if ( if (
getattr(args, "disable_xctc", False) is False getattr(args, "disable_xctc", False) is False
...@@ -1138,13 +1144,8 @@ class S2TTransformerEncoder(FairseqEncoder): ...@@ -1138,13 +1144,8 @@ class S2TTransformerEncoder(FairseqEncoder):
self.pae_unnorm_input = getattr(args, "pae_unnorm_input", False) self.pae_unnorm_input = getattr(args, "pae_unnorm_input", False)
self.pae_gt_decay = False self.pae_gt_decay = False
decay_params = getattr(args, "xctc_pae_ground_truth_ratio_decay", None) decay_params = getattr(args, "xctc_pae_ground_truth_ratio_decay", None)
if self.xctc_pae_ground_truth_ratio != 0: if self.xctc_pae_ground_truth_ratio != 0:
self.pae_adaptive_gt = getattr(
args, "xctc_pae_ground_truth_ratio_adaptive", False
)
self.pae_gt_only_mistake = getattr(
args, "xctc_pae_ground_truth_only_mistake", False
)
if decay_params is not None and len(decay_params.split(":")) == 3: if decay_params is not None and len(decay_params.split(":")) == 3:
self.pae_gt_decay = True self.pae_gt_decay = True
params = [float(item) for item in decay_params.split(":")] params = [float(item) for item in decay_params.split(":")]
...@@ -1806,9 +1807,9 @@ class S2TTransformerEncoder(FairseqEncoder): ...@@ -1806,9 +1807,9 @@ class S2TTransformerEncoder(FairseqEncoder):
xctc_logit = None xctc_logit = None
inter_xctc_logits = [] inter_xctc_logits = []
# CTC alignment # CTC alignment
oracle = None ctc_oracle = None
oracle_mask = None ctc_oracle_mask = None
force_emit = None ctc_force_emit = None
xctc_oracle = None xctc_oracle = None
xctc_oracle_mask = None xctc_oracle_mask = None
xctc_force_emit = None xctc_force_emit = None
...@@ -1899,20 +1900,39 @@ class S2TTransformerEncoder(FairseqEncoder): ...@@ -1899,20 +1900,39 @@ class S2TTransformerEncoder(FairseqEncoder):
ctc_alignment_oracle is not None ctc_alignment_oracle is not None
and ctc_alignment_oracle.get("ctc", None) is not None and ctc_alignment_oracle.get("ctc", None) is not None
): ):
if oracle is None: if ctc_oracle is None:
oracle, best_aligns_pad = ctc_alignment_oracle["ctc"] (
oracle_mask = ( ctc_oracle,
torch.rand(oracle.size(), device=oracle.device) best_aligns_pad,
< self.ctc_pae_ground_truth_ratio mistake_flag,
).bool() mistake_ratio,
force_emit = best_aligns_pad.masked_fill(~oracle_mask, -1) ) = ctc_alignment_oracle["ctc"]
if self.pae_adaptive_gt:
prob = (
self.ctc_pae_ground_truth_ratio
* mistake_ratio.unsqueeze(-1)
)
else: else:
inter_logit = [logit, None, force_emit] prob = self.ctc_pae_ground_truth_ratio
ctc_oracle_mask = (
torch.rand(
ctc_oracle.size(), device=ctc_oracle.device
)
< prob
).bool()
if self.pae_gt_only_mistake:
ctc_oracle_mask.masked_fill_(~mistake_flag, False)
ctc_force_emit = best_aligns_pad.masked_fill(
~ctc_oracle_mask, -1
)
inter_logit = [logit, None, ctc_force_emit]
pae_input = x if self.pae_unnorm_input else norm_x pae_input = x if self.pae_unnorm_input else norm_x
if pae.adapter_type != "none": if pae.adapter_type != "none":
x, encoder_padding_mask = pae( x, encoder_padding_mask = pae(
[pae_input, logit], encoder_padding_mask, oracle, oracle_mask [pae_input, logit], encoder_padding_mask, ctc_oracle, ctc_oracle_mask
) )
self.show_debug(x, "x after pae") self.show_debug(x, "x after pae")
...@@ -2106,8 +2126,8 @@ class S2TTransformerEncoder(FairseqEncoder): ...@@ -2106,8 +2126,8 @@ class S2TTransformerEncoder(FairseqEncoder):
) )
self.show_debug(x, "x after xctc") self.show_debug(x, "x after xctc")
if force_emit is not None: if ctc_force_emit is not None:
ctc_logit = [ctc_logit, None, force_emit] ctc_logit = [ctc_logit, None, ctc_force_emit]
if xctc_force_emit is not None: if xctc_force_emit is not None:
xctc_logit = [xctc_logit, None, xctc_force_emit] xctc_logit = [xctc_logit, None, xctc_force_emit]
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论