Commit 1018d402 by xuchen

acc update: add an optional linear transform in the CTC module, fix bugs (CTC input lengths, eos stripping in decoding), and other minor changes

parent b984a889
......
@@ -294,6 +294,7 @@ if [ ${stage} -le 1 ] && [ ${stop_stage} -ge 1 ]; then
         idx=$((idx + 1))
     done

+    #cmd="python3 -u -m debugpy --listen 0.0.0.0:5678 --wait-for-client ${code_dir}/fairseq_cli/train.py
     cmd="python3 -u ${code_dir}/fairseq_cli/train.py
        ${data_dir}
        --source-lang ${src_lang}
......
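Note on the commented-out debugpy line that this hunk (and the identical one further below) adds to the training scripts: it lets the same `train.py` invocation wait for a remote debugger before starting. A minimal sketch of the equivalent in-process pattern, using the same host and port as the command in the diff (everything else here is illustrative, not the repo's code):

```python
# Equivalent of "python3 -m debugpy --listen 0.0.0.0:5678 --wait-for-client
# train.py": open a DAP listener and block until a debugger attaches.
import debugpy

debugpy.listen(("0.0.0.0", 5678))  # same address as the script's comment
debugpy.wait_for_client()          # blocks until a client (e.g. VS Code) attaches
```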
......
@@ -3,7 +3,7 @@ encoder-type: pds
 encoder-embed-dim: 256
 pds-stages: 4
-pds-layers: 2_6_6_4
+pds-layers: 2_4_8_4
 pds-ratios: 2_2_2_0
 pds-fusion: False
 pds-fusion-method: all_conv2
......
warmup-updates: 1000
lr: 1e-4
subsampling-type: conv1d
subsampling-layers: 1
subsampling-filter: 2048
subsampling-kernel: 5
subsampling-stride: 2
......
@@ -12,7 +12,7 @@ if [ "$#" -eq 1 ]; then
 fi

 sacrebleu=1
-ctc_infer=0
+ctc_infer=1
 n_average=10
 beam_size=5
 infer_ctc_weight=0
......
......
@@ -79,7 +79,7 @@ step_valid=0
 bleu_valid=0

 # Decoding Settings
-batch_size=1
+batch_size=0
 sacrebleu=1
 dec_model=checkpoint_best.pt
 ctc_infer=0
......
@@ -93,7 +93,7 @@ infer_debug=0
 infer_score=0
 infer_tag=
 infer_parameter=
-#infer_parameters="--cal-monotonic-cross-attn-weights --cal-localness --localness-window 0.1 --cal-topk-cross-attn-weights --topk-cross-attn-weights 15 --cal-entropy"
+#infer_parameter="--cal-monotonic-cross-attn-weights --cal-localness --localness-window 0.1 --cal-topk-cross-attn-weights --topk-cross-attn-weights 15 --cal-entropy"

 # Parsing Options
 if [[ ${share_dict} -eq 1 ]]; then
......
@@ -289,6 +289,7 @@ if [ ${stage} -le 1 ] && [ ${stop_stage} -ge 1 ]; then
         idx=$((idx + 1))
     done

+    #cmd="python3 -u -m debugpy --listen 0.0.0.0:5678 --wait-for-client ${code_dir}/fairseq_cli/train.py
     cmd="python3 -u ${code_dir}/fairseq_cli/train.py
        ${data_dir}
        --source-lang ${src_lang}
......
......
@@ -554,6 +554,11 @@ def process_joint(args):
     with NamedTemporaryFile(mode="w") as f:
         for lang in languages:
             tsv_path = cur_root / f"{lang}" / f"{args.task}" / f"train.tsv"
+
+            gather_tsv = output_root / "train_all.tsv"
+            if os.path.exists(gather_tsv):
+                tsv_path = gather_tsv
+
             df = load_df_from_tsv(tsv_path)
             for t in df["tgt_text"]:
                 f.write(t + "\n")
......
@@ -566,6 +571,9 @@ def process_joint(args):
                     src_utt = src_utt.replace(w, "")
                 src_utt = " ".join(src_utt.split(" "))
                 f.write(src_utt + "\n")
+
+            if os.path.exists(gather_tsv):
+                break

     special_symbols = None
     if args.task == 'st':
......
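For context, the two hunks above add a shortcut to `process_joint`: if a pre-merged `train_all.tsv` already exists under the output root, the per-language loop reads it once and stops instead of loading one `train.tsv` per language. A standalone sketch of that logic; the function name `collect_targets` and the simplified `load_df_from_tsv` stand-in are assumptions, not the repo's exact code:

```python
import os
from pathlib import Path

import pandas as pd


def load_df_from_tsv(path):
    # simplified stand-in for the fairseq S2T utility of the same name
    return pd.read_csv(path, sep="\t")


def collect_targets(cur_root: Path, output_root: Path, languages, task, out_f):
    gather_tsv = output_root / "train_all.tsv"
    for lang in languages:
        tsv_path = cur_root / lang / task / "train.tsv"
        if os.path.exists(gather_tsv):
            # a merged TSV covering all languages exists: read it instead
            tsv_path = gather_tsv
        df = load_df_from_tsv(tsv_path)
        for t in df["tgt_text"]:
            out_f.write(t + "\n")
        if os.path.exists(gather_tsv):
            break  # the merged file already covered every language
```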
......
@@ -391,11 +391,13 @@ class CtcCriterion(FairseqCriterion):
             pad_mask = (tokens != self.pad_idx) & (tokens != self.eos_idx)
             target_lengths = pad_mask.sum(-1)
-            if "ctc_padding_mask" in encoder_out:
-                non_padding_mask = ~encoder_out["ctc_padding_mask"][0]
-            else:
-                non_padding_mask = ~encoder_out["encoder_padding_mask"][0]
-            input_lengths = non_padding_mask.long().sum(-1)
+            input_lengths = (~encoder_out["encoder_padding_mask"][0]).long().sum(-1)
+            if isinstance(ctc_logit, list):
+                input_lengths = ((~ctc_logit[1]).long().sum(-1)
+                                 if ctc_logit[1] is not None
+                                 else input_lengths
+                                 )
+                ctc_logit = ctc_logit[0]

             ctc_alignment_oracle["ctc"] = get_ctc_align(
                 ctc_logit,
......
@@ -416,11 +418,13 @@ class CtcCriterion(FairseqCriterion):
                 xctc_logit = encoder_out["inter_xctc_logits"][-1]
                 if xctc_logit is not None:
-                    if "ctc_padding_mask" in encoder_out:
-                        non_padding_mask = ~encoder_out["ctc_padding_mask"][0]
-                    else:
-                        non_padding_mask = ~encoder_out["encoder_padding_mask"][0]
-                    input_lengths = non_padding_mask.long().sum(-1)
+                    input_lengths = (~encoder_out["encoder_padding_mask"][0]).long().sum(-1)
+                    if isinstance(xctc_logit, list):
+                        input_lengths = ((~xctc_logit[1]).long().sum(-1)
+                                         if xctc_logit[1] is not None
+                                         else input_lengths
+                                         )
+                        xctc_logit = xctc_logit[0]

                     tokens = self.get_ctc_target_text(sample)
                     target_pad_mask = (tokens != self.pad_idx) & (tokens != self.eos_idx)
......
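The two criterion hunks above replace the `ctc_padding_mask` lookup with a single rule: input lengths come from the encoder padding mask, unless the logit arrives as a `[logit, padding_mask]` pair carrying its own mask. A small illustration of both paths with dummy tensors (shapes are made up for the example):

```python
import torch

# True marks a padded frame; counting the False entries gives each
# sequence's real length, which the CTC loss needs as input_lengths.
encoder_padding_mask = torch.tensor([
    [False, False, False, True],   # length 3
    [False, False, True,  True],   # length 2
])
input_lengths = (~encoder_padding_mask).long().sum(-1)  # tensor([3, 2])

# New list form: a CTC head that changes the time resolution can ship
# its own padding mask alongside the logits and override the lengths.
ctc_logit = [torch.randn(4, 2, 100), encoder_padding_mask]
if isinstance(ctc_logit, list):
    if ctc_logit[1] is not None:
        input_lengths = (~ctc_logit[1]).long().sum(-1)
    ctc_logit = ctc_logit[0]
```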
......
@@ -16,10 +16,15 @@ logger = logging.getLogger(__name__)
 class CTC(nn.Module):
     def __init__(self, embed_dim, dictionary_size, dropout,
-                 need_layernorm=False, dictionary=None):
+                 need_layernorm=False, dictionary=None, ctc_linear_transform=False):
         super(CTC, self).__init__()
         self.embed_dim = embed_dim

+        if ctc_linear_transform is True:
+            self.ctc_linear_transform = nn.Linear(embed_dim, embed_dim)
+        else:
+            self.ctc_linear_transform = None
+
         self.ctc_projection = nn.Linear(embed_dim, dictionary_size)
         nn.init.normal_(self.ctc_projection.weight, mean=0, std=embed_dim ** -0.5)
......
@@ -53,6 +58,9 @@ class CTC(nn.Module):
         self.save_stream = None

     def forward(self, x, padding=None, tag=None, is_top=False):
+        if self.ctc_linear_transform is not None:
+            x = self.ctc_linear_transform(x)
+
         if self.need_layernorm:
             x = self.LayerNorm(x)
......
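Putting the two `CTC` hunks together: the new `ctc_linear_transform` flag inserts an extra `embed_dim x embed_dim` linear layer ahead of the existing vocabulary projection, so the CTC branch no longer reads the encoder's hidden states directly. A condensed sketch of the class after this commit (dropout, layer norm, and the dictionary argument from the real class are omitted for brevity):

```python
import torch
import torch.nn as nn


class CTCSketch(nn.Module):
    """Condensed version of the CTC head after this commit."""

    def __init__(self, embed_dim, dictionary_size, ctc_linear_transform=False):
        super().__init__()
        # new: optional pre-projection transform, enabled by the added flag
        self.ctc_linear_transform = (
            nn.Linear(embed_dim, embed_dim) if ctc_linear_transform else None
        )
        self.ctc_projection = nn.Linear(embed_dim, dictionary_size)
        nn.init.normal_(self.ctc_projection.weight, mean=0, std=embed_dim ** -0.5)

    def forward(self, x):
        if self.ctc_linear_transform is not None:
            x = self.ctc_linear_transform(x)
        return self.ctc_projection(x)


# (T, B, C) input, as in fairseq encoders; output is (T, B, vocab)
logits = CTCSketch(256, 1000, ctc_linear_transform=True)(torch.randn(10, 2, 256))
```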
......
@@ -556,7 +556,7 @@ class SpeechToTextTask(LegacyFairseqTask):
         if hasattr(generator, "symbols_to_strip_from_output"):
             symbols_to_strip_from_output = generator.symbols_to_strip_from_output
         else:
-            symbols_to_strip_from_output = generator.eos
+            symbols_to_strip_from_output = [generator.eos]
         s = self.tgt_dict.string(
             toks.int().cpu(),
             remove_bpe,
......
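The last hunk fixes a type bug: downstream, `symbols_to_strip_from_output` is treated as a collection of symbol ids (fairseq's `Dictionary.string` consumes it as `extra_symbols_to_ignore`), so a bare int either fails or strips nothing. A toy illustration with a made-up eos index:

```python
eos = 2

# old: a bare int cannot be used as a collection of ids
try:
    ignore = set(eos)
except TypeError as err:
    print(err)  # 'int' object is not iterable

# new: wrapping it in a list makes a valid one-element collection,
# so eos is actually stripped from the decoded output
ignore = set([eos])  # {2}
```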