Commit 1018d402 by xuchen

acc update, including additional CTC linear transform, bug fix, and so on

parent b984a889
@@ -294,6 +294,7 @@ if [ ${stage} -le 1 ] && [ ${stop_stage} -ge 1 ]; then
         idx=$((idx + 1))
     done
+    #cmd="python3 -u -m debugpy --listen 0.0.0.0:5678 --wait-for-client ${code_dir}/fairseq_cli/train.py
     cmd="python3 -u ${code_dir}/fairseq_cli/train.py
         ${data_dir}
         --source-lang ${src_lang}
......
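The commented-out line added above (and repeated in the second training script below) is a debugging convenience: substituting it for the plain `cmd=` launches training under debugpy, which listens on port 5678 and blocks until a remote debugger (e.g. VS Code) attaches before `fairseq_cli/train.py` starts.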
@@ -3,7 +3,7 @@ encoder-type: pds
 encoder-embed-dim: 256
 pds-stages: 4
-pds-layers: 2_6_6_4
+pds-layers: 2_4_8_4
 pds-ratios: 2_2_2_0
 pds-fusion: False
 pds-fusion-method: all_conv2
......
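Note that this only redistributes depth across the four PDS stages: 2+6+6+4 and 2+4+8+4 both total 18 encoder layers, so the third stage gets deeper without changing the encoder's overall depth.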
+warmup-updates: 1000
+lr: 1e-4
+subsampling-type: conv1d
+subsampling-layers: 1
+subsampling-filter: 2048
+subsampling-kernel: 5
+subsampling-stride: 2
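These added settings configure the acoustic front end: a single conv1d subsampling layer with 2048 filters, kernel size 5, and stride 2, which roughly halves the number of frames. A minimal sketch of the effect on the time axis; the 80-dim fbank input, the padding choice, and reading `subsampling-filter` as the conv output channels are assumptions, not taken from this config:

import torch
import torch.nn as nn

# Illustrative stand-in for the configured subsampling: one Conv1d layer,
# 2048 filters, kernel 5, stride 2 -- roughly halving the frame count.
subsample = nn.Conv1d(in_channels=80, out_channels=2048,
                      kernel_size=5, stride=2, padding=2)

feats = torch.randn(4, 80, 100)   # (batch, feature_dim, frames)
out = subsample(feats)
print(out.shape)                  # torch.Size([4, 2048, 50])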
@@ -12,7 +12,7 @@ if [ "$#" -eq 1 ]; then
 fi
 sacrebleu=1
-ctc_infer=0
+ctc_infer=1
 n_average=10
 beam_size=5
 infer_ctc_weight=0
......
@@ -79,7 +79,7 @@ step_valid=0
 bleu_valid=0
 # Decoding Settings
-batch_size=1
+batch_size=0
 sacrebleu=1
 dec_model=checkpoint_best.pt
 ctc_infer=0
@@ -93,7 +93,7 @@ infer_debug=0
 infer_score=0
 infer_tag=
 infer_parameter=
-#infer_parameters="--cal-monotonic-cross-attn-weights --cal-localness --localness-window 0.1 --cal-topk-cross-attn-weights --topk-cross-attn-weights 15 --cal-entropy"
+#infer_parameter="--cal-monotonic-cross-attn-weights --cal-localness --localness-window 0.1 --cal-topk-cross-attn-weights --topk-cross-attn-weights 15 --cal-entropy"
 # Parsing Options
 if [[ ${share_dict} -eq 1 ]]; then
@@ -289,6 +289,7 @@ if [ ${stage} -le 1 ] && [ ${stop_stage} -ge 1 ]; then
         idx=$((idx + 1))
     done
+    #cmd="python3 -u -m debugpy --listen 0.0.0.0:5678 --wait-for-client ${code_dir}/fairseq_cli/train.py
     cmd="python3 -u ${code_dir}/fairseq_cli/train.py
         ${data_dir}
         --source-lang ${src_lang}
......
@@ -554,6 +554,11 @@ def process_joint(args):
     with NamedTemporaryFile(mode="w") as f:
         for lang in languages:
             tsv_path = cur_root / f"{lang}" / f"{args.task}" / f"train.tsv"
+            gather_tsv = output_root / "train_all.tsv"
+            if os.path.exists(gather_tsv):
+                tsv_path = gather_tsv
             df = load_df_from_tsv(tsv_path)
             for t in df["tgt_text"]:
                 f.write(t + "\n")
@@ -567,6 +572,9 @@ def process_joint(args):
                     src_utt = " ".join(src_utt.split(" "))
                     f.write(src_utt + "\n")
+            if os.path.exists(gather_tsv):
+                break
     special_symbols = None
     if args.task == 'st':
         special_symbols = [f'<lang:{lang.split("-")[0]}>' for lang in languages]
......
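The added lines short-circuit the per-language loop in `process_joint`: when a pre-merged `train_all.tsv` already exists under the output root, the joint SPM training text is collected from it once, and the `break` exits after the first language. A minimal sketch of that control flow, using `pandas.read_csv` as a stand-in for the repo's `load_df_from_tsv`:

import os
from pathlib import Path
import pandas as pd

def collect_train_text(cur_root: Path, output_root: Path, languages, task, out_f):
    # Prefer a pre-merged manifest covering all languages, if one exists.
    gather_tsv = output_root / "train_all.tsv"
    for lang in languages:
        tsv_path = cur_root / lang / task / "train.tsv"
        if os.path.exists(gather_tsv):
            tsv_path = gather_tsv
        df = pd.read_csv(tsv_path, sep="\t")   # stand-in for load_df_from_tsv
        for t in df["tgt_text"]:
            out_f.write(t + "\n")
        if os.path.exists(gather_tsv):
            break   # the merged file already contains every language once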
@@ -391,11 +391,13 @@ class CtcCriterion(FairseqCriterion):
                 pad_mask = (tokens != self.pad_idx) & (tokens != self.eos_idx)
                 target_lengths = pad_mask.sum(-1)
-                if "ctc_padding_mask" in encoder_out:
-                    non_padding_mask = ~encoder_out["ctc_padding_mask"][0]
-                else:
-                    non_padding_mask = ~encoder_out["encoder_padding_mask"][0]
-                input_lengths = non_padding_mask.long().sum(-1)
+                input_lengths = (~encoder_out["encoder_padding_mask"][0]).long().sum(-1)
+                if isinstance(ctc_logit, list):
+                    input_lengths = ((~ctc_logit[1]).long().sum(-1)
+                                     if ctc_logit[1] is not None
+                                     else input_lengths
+                                     )
+                    ctc_logit = ctc_logit[0]
                 ctc_alignment_oracle["ctc"] = get_ctc_align(
                     ctc_logit,
@@ -416,11 +418,13 @@ class CtcCriterion(FairseqCriterion):
             xctc_logit = encoder_out["inter_xctc_logits"][-1]
             if xctc_logit is not None:
-                if "ctc_padding_mask" in encoder_out:
-                    non_padding_mask = ~encoder_out["ctc_padding_mask"][0]
-                else:
-                    non_padding_mask = ~encoder_out["encoder_padding_mask"][0]
-                input_lengths = non_padding_mask.long().sum(-1)
+                input_lengths = (~encoder_out["encoder_padding_mask"][0]).long().sum(-1)
+                if isinstance(xctc_logit, list):
+                    input_lengths = ((~xctc_logit[1]).long().sum(-1)
+                                     if xctc_logit[1] is not None
+                                     else input_lengths
+                                     )
+                    xctc_logit = xctc_logit[0]
                 tokens = self.get_ctc_target_text(sample)
                 target_pad_mask = (tokens != self.pad_idx) & (tokens != self.eos_idx)
......
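Both hunks implement the same convention change: a CTC (or XCTC) logit entry may now be either a plain tensor or a `[logits, padding_mask]` pair, presumably so a CTC head computed at its own temporal resolution can carry its own mask, replacing the old `ctc_padding_mask` key in `encoder_out`. A hedged sketch of the unpacking with illustrative shapes (the function name is mine, not the repo's):

import torch

def unpack_ctc_logit(ctc_logit, encoder_padding_mask):
    """ctc_logit: Tensor, or [logits, padding_mask] where the mask may be None.
    Returns (logits, per-sentence input lengths)."""
    # Default: derive lengths from the encoder mask (True marks padded frames,
    # so invert before summing).
    input_lengths = (~encoder_padding_mask).long().sum(-1)
    if isinstance(ctc_logit, list):
        if ctc_logit[1] is not None:
            input_lengths = (~ctc_logit[1]).long().sum(-1)
        ctc_logit = ctc_logit[0]
    return ctc_logit, input_lengths

# T=50 frames, B=4 sentences, V=1000 vocabulary entries.
logits = torch.randn(50, 4, 1000)
mask = torch.zeros(4, 50, dtype=torch.bool)
mask[0, 40:] = True                       # sentence 0 has 10 padded frames
_, lengths = unpack_ctc_logit([logits, mask], mask)
print(lengths)                            # tensor([40, 50, 50, 50])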
@@ -16,10 +16,15 @@ logger = logging.getLogger(__name__)
 class CTC(nn.Module):
     def __init__(self, embed_dim, dictionary_size, dropout,
-                 need_layernorm=False, dictionary=None):
+                 need_layernorm=False, dictionary=None, ctc_linear_transform=False):
         super(CTC, self).__init__()
         self.embed_dim = embed_dim
+        if ctc_linear_transform is True:
+            self.ctc_linear_transform = nn.Linear(embed_dim, embed_dim)
+        else:
+            self.ctc_linear_transform = None
         self.ctc_projection = nn.Linear(embed_dim, dictionary_size)
         nn.init.normal_(self.ctc_projection.weight, mean=0, std=embed_dim ** -0.5)
@@ -53,6 +58,9 @@ class CTC(nn.Module):
         self.save_stream = None
     def forward(self, x, padding=None, tag=None, is_top=False):
+        if self.ctc_linear_transform is not None:
+            x = self.ctc_linear_transform(x)
         if self.need_layernorm:
             x = self.LayerNorm(x)
......
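This is the "additional CTC linear transform" from the commit message: an optional embed_dim-to-embed_dim linear layer applied to the hidden states before the usual CTC vocabulary projection, giving the CTC branch a representation of its own. A stripped-down, self-contained sketch of the resulting forward path; the class name and the omitted layernorm/dropout details are illustrative:

import torch
import torch.nn as nn

class CTCHead(nn.Module):
    """Stripped-down illustration of the CTC module after this commit."""
    def __init__(self, embed_dim, dictionary_size, ctc_linear_transform=False):
        super().__init__()
        # Optional extra transform in front of the vocabulary projection.
        self.ctc_linear_transform = (
            nn.Linear(embed_dim, embed_dim) if ctc_linear_transform else None
        )
        self.ctc_projection = nn.Linear(embed_dim, dictionary_size)
        nn.init.normal_(self.ctc_projection.weight, mean=0, std=embed_dim ** -0.5)

    def forward(self, x):
        if self.ctc_linear_transform is not None:
            x = self.ctc_linear_transform(x)   # applied before the projection
        return self.ctc_projection(x)

head = CTCHead(embed_dim=256, dictionary_size=1000, ctc_linear_transform=True)
print(head(torch.randn(50, 4, 256)).shape)     # torch.Size([50, 4, 1000])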
@@ -556,7 +556,7 @@ class SpeechToTextTask(LegacyFairseqTask):
         if hasattr(generator, "symbols_to_strip_from_output"):
             symbols_to_strip_from_output = generator.symbols_to_strip_from_output
         else:
-            symbols_to_strip_from_output = generator.eos
+            symbols_to_strip_from_output = [generator.eos]
         s = self.tgt_dict.string(
             toks.int().cpu(),
             remove_bpe,
......
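The one-character bug fix wraps `generator.eos` (an int) in a list before it is used as a collection of symbols to strip. In stock fairseq, `Dictionary.string` normalizes that argument with `set(extra_symbols_to_ignore or [])`, so a bare int raises instead of stripping EOS:

eos = 2  # typical fairseq EOS index

# Dictionary.string normalizes via set(extra_symbols_to_ignore or []):
print(set([eos]))       # {2} -- the fixed form: EOS gets stripped
try:
    set(eos)            # the old form: a bare int
except TypeError as e:
    print(e)            # 'int' object is not iterable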