Commit 1018d402 by xuchen

acc update, including an additional CTC linear transform, bug fixes, and other changes

parent b984a889
......@@ -294,6 +294,7 @@ if [ ${stage} -le 1 ] && [ ${stop_stage} -ge 1 ]; then
idx=$((idx + 1))
done
#cmd="python3 -u -m debugpy --listen 0.0.0.0:5678 --wait-for-client ${code_dir}/fairseq_cli/train.py
cmd="python3 -u ${code_dir}/fairseq_cli/train.py
${data_dir}
--source-lang ${src_lang}
......
......@@ -3,7 +3,7 @@ encoder-type: pds
encoder-embed-dim: 256
pds-stages: 4
pds-layers: 2_6_6_4
pds-layers: 2_4_8_4
pds-ratios: 2_2_2_0
pds-fusion: False
pds-fusion-method: all_conv2
......
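Note: both settings keep 18 encoder layers in total (2+6+6+4 and 2+4+8+4); the change only redistributes depth across the four PDS stages. A small sketch assuming the underscore-separated pds-layers string is split into per-stage counts (parse_pds_layers is a hypothetical helper, not from the repo):

def parse_pds_layers(spec, num_stages):
    # Split "a_b_c_d" into per-stage layer counts and sanity-check the shape.
    layers = [int(n) for n in spec.split("_")]
    assert len(layers) == num_stages, "expected one entry per PDS stage"
    return layers

old = parse_pds_layers("2_6_6_4", 4)   # [2, 6, 6, 4]
new = parse_pds_layers("2_4_8_4", 4)   # [2, 4, 8, 4]
assert sum(old) == sum(new) == 18      # same total depth, redistributed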
warmup-updates: 1000
lr: 1e-4
subsampling-type: conv1d
subsampling-layers: 1
subsampling-filter: 2048
subsampling-kernel: 5
subsampling-stride: 2
......@@ -12,7 +12,7 @@ if [ "$#" -eq 1 ]; then
fi
sacrebleu=1
ctc_infer=0
ctc_infer=1
n_average=10
beam_size=5
infer_ctc_weight=0
......
......@@ -79,7 +79,7 @@ step_valid=0
bleu_valid=0
# Decoding Settings
batch_size=1
batch_size=0
sacrebleu=1
dec_model=checkpoint_best.pt
ctc_infer=0
......@@ -93,7 +93,7 @@ infer_debug=0
infer_score=0
infer_tag=
infer_parameter=
#infer_parameters="--cal-monotonic-cross-attn-weights --cal-localness --localness-window 0.1 --cal-topk-cross-attn-weights --topk-cross-attn-weights 15 --cal-entropy"
#infer_parameter="--cal-monotonic-cross-attn-weights --cal-localness --localness-window 0.1 --cal-topk-cross-attn-weights --topk-cross-attn-weights 15 --cal-entropy"
# Parsing Options
if [[ ${share_dict} -eq 1 ]]; then
......@@ -289,6 +289,7 @@ if [ ${stage} -le 1 ] && [ ${stop_stage} -ge 1 ]; then
idx=$((idx + 1))
done
#cmd="python3 -u -m debugpy --listen 0.0.0.0:5678 --wait-for-client ${code_dir}/fairseq_cli/train.py
cmd="python3 -u ${code_dir}/fairseq_cli/train.py
${data_dir}
--source-lang ${src_lang}
......
......@@ -554,6 +554,11 @@ def process_joint(args):
with NamedTemporaryFile(mode="w") as f:
for lang in languages:
tsv_path = cur_root / f"{lang}" / f"{args.task}" / f"train.tsv"
gather_tsv = output_root / "train_all.tsv"
if os.path.exists(gather_tsv):
tsv_path = gather_tsv
df = load_df_from_tsv(tsv_path)
for t in df["tgt_text"]:
f.write(t + "\n")
......@@ -566,6 +571,9 @@ def process_joint(args):
src_utt = src_utt.replace(w, "")
src_utt = " ".join(src_utt.split(" "))
f.write(src_utt + "\n")
if os.path.exists(gather_tsv):
break
special_symbols = None
if args.task == 'st':
......
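Note: when a merged train_all.tsv exists under the output root, it replaces the per-language train.tsv files and the language loop exits after the first pass. A minimal sketch of that control flow; write_target_texts is a hypothetical wrapper, while load_df_from_tsv is the repo's TSV loader passed in as a callable:

import os
from pathlib import Path

def write_target_texts(f, output_root: Path, cur_root: Path, languages, task, load_df_from_tsv):
    # Prefer a pre-merged train_all.tsv over per-language train.tsv files.
    gather_tsv = output_root / "train_all.tsv"
    for lang in languages:
        tsv_path = cur_root / lang / task / "train.tsv"
        if os.path.exists(gather_tsv):
            tsv_path = gather_tsv
        df = load_df_from_tsv(tsv_path)
        for t in df["tgt_text"]:
            f.write(t + "\n")
        if os.path.exists(gather_tsv):
            break  # the merged file already covers every language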
......@@ -391,11 +391,13 @@ class CtcCriterion(FairseqCriterion):
pad_mask = (tokens != self.pad_idx) & (tokens != self.eos_idx)
target_lengths = pad_mask.sum(-1)
if "ctc_padding_mask" in encoder_out:
non_padding_mask = ~encoder_out["ctc_padding_mask"][0]
else:
non_padding_mask = ~encoder_out["encoder_padding_mask"][0]
input_lengths = non_padding_mask.long().sum(-1)
input_lengths = (~encoder_out["encoder_padding_mask"][0]).long().sum(-1)
if isinstance(ctc_logit, list):
input_lengths = ((~ctc_logit[1]).long().sum(-1)
if ctc_logit[1] is not None
else input_lengths
)
ctc_logit = ctc_logit[0]
ctc_alignment_oracle["ctc"] = get_ctc_align(
ctc_logit,
......@@ -416,11 +418,13 @@ class CtcCriterion(FairseqCriterion):
xctc_logit = encoder_out["inter_xctc_logits"][-1]
if xctc_logit is not None:
if "ctc_padding_mask" in encoder_out:
non_padding_mask = ~encoder_out["ctc_padding_mask"][0]
else:
non_padding_mask = ~encoder_out["encoder_padding_mask"][0]
input_lengths = non_padding_mask.long().sum(-1)
input_lengths = (~encoder_out["encoder_padding_mask"][0]).long().sum(-1)
if isinstance(xctc_logit, list):
input_lengths = ((~xctc_logit[1]).long().sum(-1)
if xctc_logit[1] is not None
else input_lengths
)
xctc_logit = xctc_logit[0]
tokens = self.get_ctc_target_text(sample)
target_pad_mask = (tokens != self.pad_idx) & (tokens != self.eos_idx)
......
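Note: the criterion now derives input_lengths from the encoder padding mask and, when the (X)CTC logit arrives packed as [logits, padding_mask], prefers the mask carried with the logit. A minimal sketch of that length derivation with toy tensors:

import torch

# True marks padded positions; lengths are the unpadded counts per sequence.
encoder_padding_mask = torch.tensor([[False, False, True],
                                     [False, True, True]])
input_lengths = (~encoder_padding_mask).long().sum(-1)            # tensor([2, 1])

# When the logit comes packed as [logits, padding_mask], use its own mask.
ctc_logit = [torch.randn(3, 2, 100),                              # (T, B, V) logits
             torch.tensor([[False, False, False],
                           [False, False, True]])]
if isinstance(ctc_logit, list):
    if ctc_logit[1] is not None:
        input_lengths = (~ctc_logit[1]).long().sum(-1)            # tensor([3, 2])
    ctc_logit = ctc_logit[0]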
......@@ -512,6 +512,11 @@ class S2TTransformerModel(FairseqEncoderDecoderModel):
help="share the weight of all intermediate ctc modules",
)
parser.add_argument(
"--no-inter-ctc-norm",
action="store_true",
help="do not using the layer norm between inter ctc and final norm",
)
parser.add_argument(
"--share-inter-ctc-norm",
action="store_true",
help="share the weight of layer norm between inter ctc and final norm",
......@@ -780,6 +785,19 @@ class S2TTransformerModel(FairseqEncoderDecoderModel):
type=str,
help="position to insert the language tag, input or before prediction",
)
parser.add_argument(
"--bil-ctc-pae-fusion",
default="serial",
choices=["parallel", "serial"],
type=str,
help="position to insert the language tag, input or before prediction",
)
parser.add_argument(
"--ctc-linear-transform",
action="store_true",
help="introduce additional linear transform in CTC",
)
pass
@classmethod
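Note: three options are new in add_args: --no-inter-ctc-norm, --bil-ctc-pae-fusion, and --ctc-linear-transform. A minimal, hedged sketch of how they parse, using plain argparse rather than fairseq's option machinery; the defaults mirror the getattr() fallbacks used later in the encoder:

import argparse

parser = argparse.ArgumentParser()
parser.add_argument("--no-inter-ctc-norm", action="store_true",
                    help="skip the layer norm applied before intermediate CTC")
parser.add_argument("--bil-ctc-pae-fusion", default="serial",
                    choices=["parallel", "serial"], type=str,
                    help="how CTC and XCTC PAE outputs are combined")
parser.add_argument("--ctc-linear-transform", action="store_true",
                    help="extra linear layer before the CTC projection")

args = parser.parse_args(["--ctc-linear-transform",
                          "--bil-ctc-pae-fusion", "parallel"])
assert args.ctc_linear_transform is True
assert args.no_inter_ctc_norm is False        # matches the getattr(..., False) fallback
assert args.bil_ctc_pae_fusion == "parallel"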
......@@ -956,6 +974,7 @@ class S2TTransformerEncoder(FairseqEncoder):
else:
self.history = None
ctc_linear_transform = getattr(args, "ctc_linear_transform", False)
self.pae_ground_truth_ratio = getattr(
args, "ctc_pae_ground_truth_ratio", 0
) + getattr(args, "xctc_pae_ground_truth_ratio", 0)
......@@ -972,6 +991,7 @@ class S2TTransformerEncoder(FairseqEncoder):
dictionary_size=len(task.source_dictionary),
dropout=args.dropout,
need_layernorm=True if self.inter_ctc else False,
ctc_linear_transform=ctc_linear_transform
)
if (
......@@ -987,15 +1007,17 @@ class S2TTransformerEncoder(FairseqEncoder):
self.inter_ctc_layers = []
if args.inter_ctc_layers is not None:
self.share_inter_ctc_norm = args.share_inter_ctc_norm
if self.share_inter_ctc_norm:
logger.info(
"Share layer norm in intermediate CTC %s." % args.inter_ctc_layers
)
else:
logger.info(
"Do not Share layer norm in intermediate CTC %s."
% args.inter_ctc_layers
)
self.no_inter_ctc_norm = getattr(args, "no_inter_ctc_norm", False)
if not self.no_inter_ctc_norm:
if self.share_inter_ctc_norm:
logger.info(
"Share layer norm in intermediate CTC %s." % args.inter_ctc_layers
)
else:
logger.info(
"Do not Share layer norm in intermediate CTC %s."
% args.inter_ctc_layers
)
inter_ctc_layers = args.inter_ctc_layers.split(",")
inter_ctc_mlo = getattr(args, "inter_ctc_mlo", "")
......@@ -1037,19 +1059,17 @@ class S2TTransformerEncoder(FairseqEncoder):
)
),
dropout=args.dropout,
dictionary=task.source_dictionary
dictionary=task.source_dictionary,
ctc_linear_transform=ctc_linear_transform
)
setattr(self, "inter_ctc%d" % layer_idx, inter_ctc)
# inter_layer_norm = LayerNorm(dim)
# setattr(
# self, "inter_layer_norm%d" % layer_idx, inter_layer_norm
# )
else:
self.ctc = CTC(
dim,
dictionary_size=len(task.source_dictionary),
dropout=args.dropout,
dictionary=task.source_dictionary,
ctc_linear_transform=ctc_linear_transform
)
if (
getattr(args, "share_ctc_and_embed", False)
......@@ -1129,6 +1149,7 @@ class S2TTransformerEncoder(FairseqEncoder):
dropout=args.dropout,
need_layernorm=True if self.inter_xctc else False,
dictionary=task.target_dictionary,
ctc_linear_transform=ctc_linear_transform
)
if (
......@@ -1192,15 +1213,16 @@ class S2TTransformerEncoder(FairseqEncoder):
)
self.share_inter_xctc_norm = args.share_inter_xctc_norm
if self.share_inter_xctc_norm:
logger.info(
"Share layer norm in intermediate XCTC %s." % inter_xctc_layers
)
else:
logger.info(
"Do not Share layer norm in intermediate XCTC %s."
% inter_xctc_layers
)
if not self.no_inter_ctc_norm:
if self.share_inter_xctc_norm:
logger.info(
"Share layer norm in intermediate XCTC %s." % inter_xctc_layers
)
else:
logger.info(
"Do not Share layer norm in intermediate XCTC %s."
% inter_xctc_layers
)
inter_xctc_layers = inter_xctc_layers.split(",")
for layer_idx in inter_xctc_layers:
......@@ -1221,6 +1243,7 @@ class S2TTransformerEncoder(FairseqEncoder):
dim,
dictionary_size=len(task.target_dictionary),
dropout=args.dropout,
ctc_linear_transform=ctc_linear_transform
)
if (
......@@ -1374,6 +1397,7 @@ class S2TTransformerEncoder(FairseqEncoder):
self.compression_stat = False
self.enc_lang_tag_pos = getattr(args, "enc_lang_tag_pos", "input")
self.bil_ctc_pae_fusion = getattr(args, "bil_ctc_pae_fusion", "serial")
self.log_flag_dict = dict()
# gather cosine similarity
......@@ -1975,6 +1999,8 @@ class S2TTransformerEncoder(FairseqEncoder):
is_top=True,
)
input_for_ctc = x
ctc_pae_output = None
# Inter CTC
if layer_idx in self.inter_ctc_layers:
if self.training and self.inter_ctc_drop_prob > 0:
......@@ -1989,13 +2015,16 @@ class S2TTransformerEncoder(FairseqEncoder):
inter_ctc = getattr(self, "inter_ctc%d" % layer_idx)
pae = getattr(self, "pae%d" % layer_idx)
if self.share_inter_ctc_norm:
layer_norm = self.layer_norm
if self.no_inter_ctc_norm:
x_for_logit = input_for_ctc
else:
layer_norm = getattr(self, "ctc_norm%d" % layer_idx)
norm_x = layer_norm(x)
if self.share_inter_ctc_norm:
layer_norm = self.layer_norm
else:
layer_norm = getattr(self, "ctc_norm%d" % layer_idx)
x_for_logit = layer_norm(input_for_ctc)
logit, logit_list = inter_ctc.forward_with_lang_tag(norm_x, encoder_padding_mask,
logit, logit_list = inter_ctc.forward_with_lang_tag(x_for_logit, encoder_padding_mask,
"CTC Layer %d" % layer_idx, "source",
self.embed_tokens(src_lang_idx)
if src_lang_idx is not None and self.enc_lang_tag_pos == "predict" else None)
......@@ -2035,12 +2064,12 @@ class S2TTransformerEncoder(FairseqEncoder):
# logit_list = [logit, None, ctc_force_emit]
pae_input = x if self.pae_unnorm_input else norm_x
ctc_pae_input = input_for_ctc if self.pae_unnorm_input else x_for_logit
if pae.adapter_type != "none":
x, encoder_padding_mask = pae(
[pae_input, logit], encoder_padding_mask, ctc_oracle, ctc_oracle_mask
ctc_pae_output, encoder_padding_mask = pae(
[ctc_pae_input, logit], encoder_padding_mask, ctc_oracle, ctc_oracle_mask
)
self.show_debug(x, "x after pae")
self.show_debug(ctc_pae_output, "x after pae")
inter_ctc_logits.append(logit_list)
......@@ -2177,19 +2206,27 @@ class S2TTransformerEncoder(FairseqEncoder):
x = x.transpose(0, 1)
# Inter XCTC
if self.bil_ctc_pae_fusion == "serial" and ctc_pae_output is not None:
input_for_xctc = ctc_pae_output
else:
input_for_xctc = input_for_ctc
xctc_pae_output = None
if layer_idx in self.inter_xctc_layers:
if self.inter_xctc_drop_prob > 0:
p = torch.rand(1).uniform_()
if p < self.inter_xctc_drop_prob:
break
if self.share_inter_xctc_norm:
norm_x = self.layer_norm(x)
if self.no_inter_ctc_norm:
x_for_logit = input_for_xctc
else:
norm = getattr(self, "xctc_norm%d" % layer_idx)
norm_x = norm(x)
if self.share_inter_xctc_norm:
layer_norm = self.layer_norm
else:
layer_norm = getattr(self, "xctc_norm%d" % layer_idx)
x_for_logit = layer_norm(input_for_xctc)
logit, logit_list = self.xctc.forward_with_lang_tag(norm_x, encoder_padding_mask,
logit, logit_list = self.xctc.forward_with_lang_tag(x_for_logit, encoder_padding_mask,
"XCTC Layer %d" % layer_idx, "target",
self.embed_tokens(tgt_lang_idx)
if tgt_lang_idx is not None and self.enc_lang_tag_pos == "predict" else None)
......@@ -2228,10 +2265,14 @@ class S2TTransformerEncoder(FairseqEncoder):
)
# logit_list = [logit, None, xctc_force_emit]
pae_input = x if self.pae_unnorm_input else norm_x
if self.pae_unnorm_input:
xctc_pae_input = input_for_xctc
else:
xctc_pae_input = x_for_logit
if self.xctc_pae.adapter_type != "none":
x, encoder_padding_mask = self.xctc_pae(
[pae_input, logit],
xctc_pae_output, encoder_padding_mask = self.xctc_pae(
[xctc_pae_input, logit],
encoder_padding_mask,
xctc_oracle,
xctc_oracle_mask,
......@@ -2239,6 +2280,13 @@ class S2TTransformerEncoder(FairseqEncoder):
inter_xctc_logits.append(logit_list)
if self.bil_ctc_pae_fusion == "parallel" and ctc_pae_output is not None and xctc_pae_output is not None:
x = (ctc_pae_output + xctc_pae_output) / 2
elif ctc_pae_output is None and xctc_pae_output is not None:
x = xctc_pae_output
elif ctc_pae_output is not None and xctc_pae_output is None:
x = ctc_pae_output
if self.history is not None:
self.history.push(x)
......
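Note: the PAE outputs are no longer written straight back to x; the new bil_ctc_pae_fusion option decides how the source CTC and target XCTC PAE outputs are merged. In serial mode the CTC PAE output feeds the XCTC branch as its input, in parallel mode the two outputs are averaged. A minimal sketch that mirrors the branch added above, with toy tensors standing in for the encoder states:

import torch

def fuse_pae_outputs(x, ctc_pae_output, xctc_pae_output, bil_ctc_pae_fusion="serial"):
    # Parallel mode with both outputs present: average them.
    # Only one output present: it replaces x.
    # Serial mode with both present: x keeps its current value, since the
    # XCTC branch already consumed the CTC PAE output as its input.
    if (bil_ctc_pae_fusion == "parallel"
            and ctc_pae_output is not None and xctc_pae_output is not None):
        x = (ctc_pae_output + xctc_pae_output) / 2
    elif ctc_pae_output is None and xctc_pae_output is not None:
        x = xctc_pae_output
    elif ctc_pae_output is not None and xctc_pae_output is None:
        x = ctc_pae_output
    return x

x = torch.randn(10, 2, 256)                     # (T, B, C) encoder states
ctc_out, xctc_out = torch.randn_like(x), torch.randn_like(x)
fused = fuse_pae_outputs(x, ctc_out, xctc_out, bil_ctc_pae_fusion="parallel")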
......@@ -16,10 +16,15 @@ logger = logging.getLogger(__name__)
class CTC(nn.Module):
def __init__(self, embed_dim, dictionary_size, dropout,
need_layernorm=False, dictionary=None):
need_layernorm=False, dictionary=None, ctc_linear_transform=False):
super(CTC, self).__init__()
self.embed_dim = embed_dim
if ctc_linear_transform is True:
self.ctc_linear_transform = nn.Linear(embed_dim, embed_dim)
else:
self.ctc_linear_transform = None
self.ctc_projection = nn.Linear(embed_dim, dictionary_size)
nn.init.normal_(self.ctc_projection.weight, mean=0, std=embed_dim ** -0.5)
......@@ -53,6 +58,9 @@ class CTC(nn.Module):
self.save_stream = None
def forward(self, x, padding=None, tag=None, is_top=False):
if self.ctc_linear_transform is not None:
x = self.ctc_linear_transform(x)
if self.need_layernorm:
x = self.LayerNorm(x)
......
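Note: the new ctc_linear_transform flag inserts an extra embed_dim x embed_dim linear layer ahead of the (optional) layer norm and the CTC projection. A simplified stand-in for the repo's CTC module, showing only that addition:

import torch
import torch.nn as nn

class CTCHead(nn.Module):
    # Simplified sketch of the CTC head; the real module also handles
    # layer norm, dropout, and language-tag prediction.
    def __init__(self, embed_dim, dictionary_size, ctc_linear_transform=False):
        super().__init__()
        self.ctc_linear_transform = (
            nn.Linear(embed_dim, embed_dim) if ctc_linear_transform else None
        )
        self.ctc_projection = nn.Linear(embed_dim, dictionary_size)
        nn.init.normal_(self.ctc_projection.weight, mean=0, std=embed_dim ** -0.5)

    def forward(self, x):
        if self.ctc_linear_transform is not None:
            x = self.ctc_linear_transform(x)   # new: extra transform before projection
        return self.ctc_projection(x)

head = CTCHead(embed_dim=256, dictionary_size=1000, ctc_linear_transform=True)
logits = head(torch.randn(10, 2, 256))          # (T, B, V) CTC logits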
......@@ -556,7 +556,7 @@ class SpeechToTextTask(LegacyFairseqTask):
if hasattr(generator, "symbols_to_strip_from_output"):
symbols_to_strip_from_output = generator.symbols_to_strip_from_output
else:
symbols_to_strip_from_output = generator.eos
symbols_to_strip_from_output = [generator.eos]
s = self.tgt_dict.string(
toks.int().cpu(),
remove_bpe,
......
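Note: the fix wraps the bare eos id in a list so the fallback has the same iterable shape as generator.symbols_to_strip_from_output; downstream code that builds a set from it or iterates over it then works in both branches. A minimal illustration with a hypothetical generator object:

class DummyGenerator:
    eos = 2   # hypothetical eos index; real generators may also expose
              # symbols_to_strip_from_output

generator = DummyGenerator()
if hasattr(generator, "symbols_to_strip_from_output"):
    symbols_to_strip_from_output = generator.symbols_to_strip_from_output
else:
    symbols_to_strip_from_output = [generator.eos]   # list, not a bare int

ignore = set(symbols_to_strip_from_output)           # {2}; set(2) would raise TypeError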