Commit 1018d402 by xuchen

acc update: additional CTC linear transform, bug fixes, and related changes

parent b984a889
@@ -294,6 +294,7 @@ if [ ${stage} -le 1 ] && [ ${stop_stage} -ge 1 ]; then
         idx=$((idx + 1))
     done
+    #cmd="python3 -u -m debugpy --listen 0.0.0.0:5678 --wait-for-client ${code_dir}/fairseq_cli/train.py
     cmd="python3 -u ${code_dir}/fairseq_cli/train.py
         ${data_dir}
         --source-lang ${src_lang}
......
@@ -3,7 +3,7 @@ encoder-type: pds
 encoder-embed-dim: 256
 pds-stages: 4
-pds-layers: 2_6_6_4
+pds-layers: 2_4_8_4
 pds-ratios: 2_2_2_0
 pds-fusion: False
 pds-fusion-method: all_conv2
......
+warmup-updates: 1000
+lr: 1e-4
+subsampling-type: conv1d
+subsampling-layers: 1
+subsampling-filter: 2048
+subsampling-kernel: 5
+subsampling-stride: 2
@@ -12,7 +12,7 @@ if [ "$#" -eq 1 ]; then
 fi
 sacrebleu=1
-ctc_infer=0
+ctc_infer=1
 n_average=10
 beam_size=5
 infer_ctc_weight=0
......
@@ -79,7 +79,7 @@ step_valid=0
 bleu_valid=0
 # Decoding Settings
-batch_size=1
+batch_size=0
 sacrebleu=1
 dec_model=checkpoint_best.pt
 ctc_infer=0
@@ -93,7 +93,7 @@ infer_debug=0
 infer_score=0
 infer_tag=
 infer_parameter=
-#infer_parameters="--cal-monotonic-cross-attn-weights --cal-localness --localness-window 0.1 --cal-topk-cross-attn-weights --topk-cross-attn-weights 15 --cal-entropy"
+#infer_parameter="--cal-monotonic-cross-attn-weights --cal-localness --localness-window 0.1 --cal-topk-cross-attn-weights --topk-cross-attn-weights 15 --cal-entropy"
 # Parsing Options
 if [[ ${share_dict} -eq 1 ]]; then
@@ -289,6 +289,7 @@ if [ ${stage} -le 1 ] && [ ${stop_stage} -ge 1 ]; then
         idx=$((idx + 1))
     done
+    #cmd="python3 -u -m debugpy --listen 0.0.0.0:5678 --wait-for-client ${code_dir}/fairseq_cli/train.py
     cmd="python3 -u ${code_dir}/fairseq_cli/train.py
         ${data_dir}
         --source-lang ${src_lang}
......
@@ -554,6 +554,11 @@ def process_joint(args):
     with NamedTemporaryFile(mode="w") as f:
         for lang in languages:
             tsv_path = cur_root / f"{lang}" / f"{args.task}" / f"train.tsv"
+            gather_tsv = output_root / "train_all.tsv"
+            if os.path.exists(gather_tsv):
+                tsv_path = gather_tsv
             df = load_df_from_tsv(tsv_path)
             for t in df["tgt_text"]:
                 f.write(t + "\n")
@@ -566,6 +571,9 @@ def process_joint(args):
                     src_utt = src_utt.replace(w, "")
                 src_utt = " ".join(src_utt.split(" "))
                 f.write(src_utt + "\n")
+            if os.path.exists(gather_tsv):
+                break
     special_symbols = None
     if args.task == 'st':
......
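For context, the hunk above short-circuits the per-language loop when a gathered TSV is present. A minimal standalone sketch of the same logic (the function name iter_training_tsvs is illustrative, not part of the repository):

import os
from pathlib import Path

def iter_training_tsvs(output_root: Path, cur_root: Path, languages, task: str):
    # If a combined train_all.tsv exists, it already covers every language,
    # so yield it once and stop instead of reading each per-language file.
    gather_tsv = output_root / "train_all.tsv"
    if os.path.exists(gather_tsv):
        yield gather_tsv
        return
    for lang in languages:
        yield cur_root / lang / task / "train.tsv"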
@@ -391,11 +391,13 @@ class CtcCriterion(FairseqCriterion):
             pad_mask = (tokens != self.pad_idx) & (tokens != self.eos_idx)
             target_lengths = pad_mask.sum(-1)
-            if "ctc_padding_mask" in encoder_out:
-                non_padding_mask = ~encoder_out["ctc_padding_mask"][0]
-            else:
-                non_padding_mask = ~encoder_out["encoder_padding_mask"][0]
-            input_lengths = non_padding_mask.long().sum(-1)
+            input_lengths = (~encoder_out["encoder_padding_mask"][0]).long().sum(-1)
+            if isinstance(ctc_logit, list):
+                input_lengths = ((~ctc_logit[1]).long().sum(-1)
+                                 if ctc_logit[1] is not None
+                                 else input_lengths
+                                 )
+                ctc_logit = ctc_logit[0]
             ctc_alignment_oracle["ctc"] = get_ctc_align(
                 ctc_logit,
@@ -416,11 +418,13 @@ class CtcCriterion(FairseqCriterion):
             xctc_logit = encoder_out["inter_xctc_logits"][-1]
             if xctc_logit is not None:
-                if "ctc_padding_mask" in encoder_out:
-                    non_padding_mask = ~encoder_out["ctc_padding_mask"][0]
-                else:
-                    non_padding_mask = ~encoder_out["encoder_padding_mask"][0]
-                input_lengths = non_padding_mask.long().sum(-1)
+                input_lengths = (~encoder_out["encoder_padding_mask"][0]).long().sum(-1)
+                if isinstance(xctc_logit, list):
+                    input_lengths = ((~xctc_logit[1]).long().sum(-1)
+                                     if xctc_logit[1] is not None
+                                     else input_lengths
+                                     )
+                    xctc_logit = xctc_logit[0]
                 tokens = self.get_ctc_target_text(sample)
                 target_pad_mask = (tokens != self.pad_idx) & (tokens != self.eos_idx)
......
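The change above fixes how CTC input lengths are computed: when an (X)CTC logit entry is a [logits, padding_mask] pair, the lengths now come from the mask produced at the logits' own temporal resolution instead of a possibly stale encoder-level mask. A hedged standalone sketch of that rule (the function name is illustrative):

import torch

def ctc_input_lengths(logit_entry, encoder_padding_mask):
    # Default: count non-padded encoder positions (mask is True at padding).
    input_lengths = (~encoder_padding_mask).long().sum(-1)
    if isinstance(logit_entry, list):
        logits, logit_padding = logit_entry[0], logit_entry[1]
        if logit_padding is not None:
            # Prefer the mask carried with the logits, which may be downsampled
            # relative to the top-level encoder output.
            input_lengths = (~logit_padding).long().sum(-1)
        return logits, input_lengths
    return logit_entry, input_lengths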
@@ -512,6 +512,11 @@ class S2TTransformerModel(FairseqEncoderDecoderModel):
             help="share the weight of all intermediate ctc modules",
         )
         parser.add_argument(
+            "--no-inter-ctc-norm",
+            action="store_true",
+            help="do not use the layer norm between intermediate CTC and the final norm",
+        )
+        parser.add_argument(
             "--share-inter-ctc-norm",
             action="store_true",
             help="share the weight of layer norm between inter ctc and final norm",
@@ -780,6 +785,19 @@ class S2TTransformerModel(FairseqEncoderDecoderModel):
             type=str,
             help="position to insert the language tag, input or before prediction",
         )
+        parser.add_argument(
+            "--bil-ctc-pae-fusion",
+            default="serial",
+            choices=["parallel", "serial"],
+            type=str,
+            help="how to fuse the CTC and XCTC PAE outputs, parallel or serial",
+        )
+        parser.add_argument(
+            "--ctc-linear-transform",
+            action="store_true",
+            help="introduce an additional linear transform in CTC",
+        )
         pass

     @classmethod
@@ -956,6 +974,7 @@ class S2TTransformerEncoder(FairseqEncoder):
         else:
             self.history = None
+        ctc_linear_transform = getattr(args, "ctc_linear_transform", False)
         self.pae_ground_truth_ratio = getattr(
             args, "ctc_pae_ground_truth_ratio", 0
         ) + getattr(args, "xctc_pae_ground_truth_ratio", 0)
@@ -972,6 +991,7 @@ class S2TTransformerEncoder(FairseqEncoder):
                 dictionary_size=len(task.source_dictionary),
                 dropout=args.dropout,
                 need_layernorm=True if self.inter_ctc else False,
+                ctc_linear_transform=ctc_linear_transform
             )
             if (
@@ -987,15 +1007,17 @@ class S2TTransformerEncoder(FairseqEncoder):
         self.inter_ctc_layers = []
         if args.inter_ctc_layers is not None:
             self.share_inter_ctc_norm = args.share_inter_ctc_norm
-            if self.share_inter_ctc_norm:
-                logger.info(
-                    "Share layer norm in intermediate CTC %s." % args.inter_ctc_layers
-                )
-            else:
-                logger.info(
-                    "Do not Share layer norm in intermediate CTC %s."
-                    % args.inter_ctc_layers
-                )
+            self.no_inter_ctc_norm = getattr(args, "no_inter_ctc_norm", False)
+            if not self.no_inter_ctc_norm:
+                if self.share_inter_ctc_norm:
+                    logger.info(
+                        "Share layer norm in intermediate CTC %s." % args.inter_ctc_layers
+                    )
+                else:
+                    logger.info(
+                        "Do not Share layer norm in intermediate CTC %s."
+                        % args.inter_ctc_layers
+                    )
             inter_ctc_layers = args.inter_ctc_layers.split(",")
             inter_ctc_mlo = getattr(args, "inter_ctc_mlo", "")
@@ -1037,19 +1059,17 @@ class S2TTransformerEncoder(FairseqEncoder):
                         )
                     ),
                     dropout=args.dropout,
-                    dictionary=task.source_dictionary
+                    dictionary=task.source_dictionary,
+                    ctc_linear_transform=ctc_linear_transform
                 )
                 setattr(self, "inter_ctc%d" % layer_idx, inter_ctc)
-                # inter_layer_norm = LayerNorm(dim)
-                # setattr(
-                #     self, "inter_layer_norm%d" % layer_idx, inter_layer_norm
-                # )
         else:
             self.ctc = CTC(
                 dim,
                 dictionary_size=len(task.source_dictionary),
                 dropout=args.dropout,
                 dictionary=task.source_dictionary,
+                ctc_linear_transform=ctc_linear_transform
             )
             if (
                 getattr(args, "share_ctc_and_embed", False)
@@ -1129,6 +1149,7 @@ class S2TTransformerEncoder(FairseqEncoder):
                 dropout=args.dropout,
                 need_layernorm=True if self.inter_xctc else False,
                 dictionary=task.target_dictionary,
+                ctc_linear_transform=ctc_linear_transform
             )
             if (
@@ -1192,15 +1213,16 @@ class S2TTransformerEncoder(FairseqEncoder):
             )
             self.share_inter_xctc_norm = args.share_inter_xctc_norm
-            if self.share_inter_xctc_norm:
-                logger.info(
-                    "Share layer norm in intermediate XCTC %s." % inter_xctc_layers
-                )
-            else:
-                logger.info(
-                    "Do not Share layer norm in intermediate XCTC %s."
-                    % inter_xctc_layers
-                )
+            if not self.no_inter_ctc_norm:
+                if self.share_inter_xctc_norm:
+                    logger.info(
+                        "Share layer norm in intermediate XCTC %s." % inter_xctc_layers
+                    )
+                else:
+                    logger.info(
+                        "Do not Share layer norm in intermediate XCTC %s."
+                        % inter_xctc_layers
+                    )
             inter_xctc_layers = inter_xctc_layers.split(",")
             for layer_idx in inter_xctc_layers:
@@ -1221,6 +1243,7 @@ class S2TTransformerEncoder(FairseqEncoder):
                 dim,
                 dictionary_size=len(task.target_dictionary),
                 dropout=args.dropout,
+                ctc_linear_transform=ctc_linear_transform
             )
             if (
@@ -1374,6 +1397,7 @@ class S2TTransformerEncoder(FairseqEncoder):
         self.compression_stat = False
         self.enc_lang_tag_pos = getattr(args, "enc_lang_tag_pos", "input")
+        self.bil_ctc_pae_fusion = getattr(args, "bil_ctc_pae_fusion", "serial")
         self.log_flag_dict = dict()
         # gather cosine similarity
@@ -1975,6 +1999,8 @@ class S2TTransformerEncoder(FairseqEncoder):
                     is_top=True,
                 )
+            input_for_ctc = x
+            ctc_pae_output = None
             # Inter CTC
             if layer_idx in self.inter_ctc_layers:
                 if self.training and self.inter_ctc_drop_prob > 0:
@@ -1989,13 +2015,16 @@ class S2TTransformerEncoder(FairseqEncoder):
                 inter_ctc = getattr(self, "inter_ctc%d" % layer_idx)
                 pae = getattr(self, "pae%d" % layer_idx)
-                if self.share_inter_ctc_norm:
-                    layer_norm = self.layer_norm
-                else:
-                    layer_norm = getattr(self, "ctc_norm%d" % layer_idx)
-                norm_x = layer_norm(x)
+                if self.no_inter_ctc_norm:
+                    x_for_logit = input_for_ctc
+                else:
+                    if self.share_inter_ctc_norm:
+                        layer_norm = self.layer_norm
+                    else:
+                        layer_norm = getattr(self, "ctc_norm%d" % layer_idx)
+                    x_for_logit = layer_norm(input_for_ctc)
-                logit, logit_list = inter_ctc.forward_with_lang_tag(norm_x, encoder_padding_mask,
+                logit, logit_list = inter_ctc.forward_with_lang_tag(x_for_logit, encoder_padding_mask,
                                                                     "CTC Layer %d" % layer_idx, "source",
                                                                     self.embed_tokens(src_lang_idx)
                                                                     if src_lang_idx is not None and self.enc_lang_tag_pos == "predict" else None)
@@ -2035,12 +2064,12 @@ class S2TTransformerEncoder(FairseqEncoder):
                 # logit_list = [logit, None, ctc_force_emit]
-                pae_input = x if self.pae_unnorm_input else norm_x
+                ctc_pae_input = input_for_ctc if self.pae_unnorm_input else x_for_logit
                 if pae.adapter_type != "none":
-                    x, encoder_padding_mask = pae(
-                        [pae_input, logit], encoder_padding_mask, ctc_oracle, ctc_oracle_mask
+                    ctc_pae_output, encoder_padding_mask = pae(
+                        [ctc_pae_input, logit], encoder_padding_mask, ctc_oracle, ctc_oracle_mask
                     )
-                    self.show_debug(x, "x after pae")
+                    self.show_debug(ctc_pae_output, "x after pae")
                 inter_ctc_logits.append(logit_list)
@@ -2177,19 +2206,27 @@ class S2TTransformerEncoder(FairseqEncoder):
             x = x.transpose(0, 1)
             # Inter XCTC
+            if self.bil_ctc_pae_fusion == "serial" and ctc_pae_output is not None:
+                input_for_xctc = ctc_pae_output
+            else:
+                input_for_xctc = input_for_ctc
+            xctc_pae_output = None
             if layer_idx in self.inter_xctc_layers:
                 if self.inter_xctc_drop_prob > 0:
                     p = torch.rand(1).uniform_()
                     if p < self.inter_xctc_drop_prob:
                         break
-                if self.share_inter_xctc_norm:
-                    norm_x = self.layer_norm(x)
-                else:
-                    norm = getattr(self, "xctc_norm%d" % layer_idx)
-                    norm_x = norm(x)
+                if self.no_inter_ctc_norm:
+                    x_for_logit = input_for_xctc
+                else:
+                    if self.share_inter_xctc_norm:
+                        layer_norm = self.layer_norm
+                    else:
+                        layer_norm = getattr(self, "xctc_norm%d" % layer_idx)
+                    x_for_logit = layer_norm(input_for_xctc)
-                logit, logit_list = self.xctc.forward_with_lang_tag(norm_x, encoder_padding_mask,
+                logit, logit_list = self.xctc.forward_with_lang_tag(x_for_logit, encoder_padding_mask,
                                                                     "XCTC Layer %d" % layer_idx, "target",
                                                                     self.embed_tokens(tgt_lang_idx)
                                                                     if tgt_lang_idx is not None and self.enc_lang_tag_pos == "predict" else None)
@@ -2228,10 +2265,14 @@ class S2TTransformerEncoder(FairseqEncoder):
                     )
                 # logit_list = [logit, None, xctc_force_emit]
-                pae_input = x if self.pae_unnorm_input else norm_x
+                if self.pae_unnorm_input:
+                    xctc_pae_input = input_for_xctc
+                else:
+                    xctc_pae_input = x_for_logit
                 if self.xctc_pae.adapter_type != "none":
-                    x, encoder_padding_mask = self.xctc_pae(
-                        [pae_input, logit],
+                    xctc_pae_output, encoder_padding_mask = self.xctc_pae(
+                        [xctc_pae_input, logit],
                         encoder_padding_mask,
                         xctc_oracle,
                         xctc_oracle_mask,
@@ -2239,6 +2280,13 @@ class S2TTransformerEncoder(FairseqEncoder):
                 inter_xctc_logits.append(logit_list)
+            if self.bil_ctc_pae_fusion == "parallel" and ctc_pae_output is not None and xctc_pae_output is not None:
+                x = (ctc_pae_output + xctc_pae_output) / 2
+            elif ctc_pae_output is None and xctc_pae_output is not None:
+                x = xctc_pae_output
+            elif ctc_pae_output is not None and xctc_pae_output is None:
+                x = ctc_pae_output
             if self.history is not None:
                 self.history.push(x)
......
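The new bil_ctc_pae_fusion option controls how the CTC and XCTC prediction-aware encoding (PAE) outputs recombine into the layer output x, as implemented in the elif chain above. A hedged standalone sketch of the two modes (a restatement, not the repository's exact code path):

def fuse_pae_outputs(x, ctc_pae_output, xctc_pae_output, fusion="serial"):
    # parallel: both adapters read the same input; average their outputs.
    if fusion == "parallel" and ctc_pae_output is not None and xctc_pae_output is not None:
        return (ctc_pae_output + xctc_pae_output) / 2
    # serial: the XCTC adapter already consumed the CTC PAE output upstream,
    # so whichever single output exists becomes the new representation.
    if xctc_pae_output is not None:
        return xctc_pae_output
    if ctc_pae_output is not None:
        return ctc_pae_output
    return x  # neither adapter fired at this layer

Note that the hunk above leaves x unchanged when both outputs exist in serial mode; the sketch assumes the serially composed XCTC output should win in that case.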
@@ -16,10 +16,15 @@ logger = logging.getLogger(__name__)
 class CTC(nn.Module):
     def __init__(self, embed_dim, dictionary_size, dropout,
-                 need_layernorm=False, dictionary=None):
+                 need_layernorm=False, dictionary=None, ctc_linear_transform=False):
         super(CTC, self).__init__()
         self.embed_dim = embed_dim
+        if ctc_linear_transform is True:
+            self.ctc_linear_transform = nn.Linear(embed_dim, embed_dim)
+        else:
+            self.ctc_linear_transform = None
         self.ctc_projection = nn.Linear(embed_dim, dictionary_size)
         nn.init.normal_(self.ctc_projection.weight, mean=0, std=embed_dim ** -0.5)
@@ -53,6 +58,9 @@ class CTC(nn.Module):
         self.save_stream = None
     def forward(self, x, padding=None, tag=None, is_top=False):
+        if self.ctc_linear_transform is not None:
+            x = self.ctc_linear_transform(x)
         if self.need_layernorm:
             x = self.LayerNorm(x)
......
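A short hedged usage sketch of the new constructor flag (dimensions and the call shape are illustrative; it assumes the CTC class above is importable and that forward returns the projected logits):

import torch

ctc = CTC(embed_dim=256, dictionary_size=10000, dropout=0.1,
          ctc_linear_transform=True)  # inserts nn.Linear(256, 256) before the projection
feats = torch.randn(50, 8, 256)       # (time, batch, dim) encoder states
logits = ctc(feats)                   # linear transform -> optional norm -> projection

With the flag off (the default), ctc_linear_transform is None and the forward pass is unchanged, so existing checkpoints remain compatible.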
@@ -556,7 +556,7 @@ class SpeechToTextTask(LegacyFairseqTask):
         if hasattr(generator, "symbols_to_strip_from_output"):
             symbols_to_strip_from_output = generator.symbols_to_strip_from_output
         else:
-            symbols_to_strip_from_output = generator.eos
+            symbols_to_strip_from_output = [generator.eos]
         s = self.tgt_dict.string(
             toks.int().cpu(),
             remove_bpe,
......
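The one-line fix above wraps the EOS id in a list because the downstream string conversion treats symbols_to_strip_from_output as a collection and runs membership tests over it; a bare int breaks that. A minimal illustration (values are illustrative):

eos = 2
symbols = eos    # old behavior: an int, so `tok in symbols` raises TypeError
symbols = [eos]  # fixed: a collection, so membership tests work
assert eos in symbols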