Commit 67d8695f by xuchen

add target ctc

parent d4255246
-arch: s2t_ctc
+arch: s2t_sate
-encoder-type: pds
+share-decoder-input-output-embed: True
#pds-ctc: 0_1_1_0
#intermedia-adapter: league
#intermedia-ctc-weight: 1
#encoder-attention-type: reduced
#pds-attn-ds-ratios: 4_2_1_1
#attention-reduced-method: pool
#attention-reduced-q: True
encoder-embed-dim: 240
pds-stages: 3
#ctc-layer: 15
pds-layers: 4_5_6
pds-ratios: 2_2_2
pds-fusion: False
pds-fusion-method: all_conv
pds-embed-dims: 120_168_240
pds-ds-method: conv
pds-embed-norm: True
pds-position-embed: 1_1_1
pds-kernel-sizes: 3_3_3
pds-ffn-ratios: 4_4_4
pds-attn-heads: 4_4_4
optimizer: adam
clip-norm: 10.0
lr-scheduler: inverse_sqrt
warmup-init-lr: 1e-7
warmup-updates: 10000
-lr: 0.0015
+lr: 2e-3
adam_betas: (0.9,0.98)
-criterion: ctc
+ctc-weight: 0.3
-post-process: sentencepiece
+target-ctc-weight: 0.2
target-ctc-layers: 3,6
criterion: label_smoothed_cross_entropy_with_ctc
label_smoothing: 0.1
encoder-normalize-before: True
decoder-normalize-before: True
subsampling-type: conv1d
subsampling-layers: 2
subsampling-filter: 1024
subsampling-kernel: 5
subsampling-stride: 2
subsampling-norm: none
subsampling-activation: glu
dropout: 0.1
activation-fn: relu
-encoder-layers: 15
+encoder-embed-dim: 256
encoder-ffn-embed-dim: 2048
encoder-layers: 12
text-encoder-layers: 6
decoder-layers: 6
encoder-attention-heads: 4
decoder-embed-dim: 256
decoder-ffn-embed-dim: 2048
decoder-attention-heads: 4
-macaron-style: True
+acoustic-encoder: transformer
-use-cnn-module: True
+adapter: league
cnn-module-kernel: 15
encoder-activation-fn: swish
encoder-attention-type: rel_pos
#load-pretrained-encoder-from:
#load-pretrained-acoustic-encoder-from:
#load-pretrained-text-encoder-from:
#load-pretrained-decoder-from:
\ No newline at end of file
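The options above introduce a weighted mix of cross-entropy and CTC objectives: ctc-weight scales the source-side CTC loss and target-ctc-weight the CTC loss computed on text-encoder layers 3 and 6 (target-ctc-layers). A minimal Python sketch of how these weights are expected to combine; this is illustrative only and not part of the commit (the exact mixing, e.g. whether the CE term is rescaled, is defined by the label_smoothed_cross_entropy_with_ctc criterion itself):

# Illustrative sketch only; mirrors the weighting added to CtcCriterion later in this diff.
def combined_loss(ce_loss, ctc_loss, target_ctc_loss,
                  ctc_weight=0.3, target_ctc_weight=0.2):
    # ce_loss: label-smoothed cross entropy from the decoder
    # ctc_loss: CTC over the source transcript on the acoustic encoder output
    # target_ctc_loss: CTC over the target sentence on selected text-encoder layers
    return ce_loss + ctc_weight * ctc_loss + target_ctc_weight * target_ctc_loss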
...@@ -13,15 +13,15 @@ encoder-type: pds
encoder-embed-dim: 240
pds-stages: 3
#ctc-layer: 15
-pds-layers: 4_5_6
+pds-layers: 5_5_5
pds-ratios: 2_2_2
-pds-fusion: False
+pds-fusion: True
pds-fusion-method: all_conv
pds-embed-dims: 120_168_240
pds-ds-method: conv
pds-embed-norm: True
pds-position-embed: 1_1_1
-pds-kernel-sizes: 3_3_3
+pds-kernel-sizes: 5_5_5
pds-ffn-ratios: 4_4_4
pds-attn-heads: 4_4_4
......
arch: s2t_ctc
encoder-type: pds
#pds-ctc: 0_1_1_0
#intermedia-adapter: league
#intermedia-ctc-weight: 1
#intermedia-temperature: 5
encoder-attention-type: rel_pos
#encoder-attention-type: reduced_rel_pos
#pds-attn-ds-ratios: 4_2_2_1
#attention-reduced-method: pool
#attention-reduced-q: True
encoder-embed-dim: 512
pds-stages: 4
#ctc-layer: 15
encoder-layers: 10
pds-layers: 3_2_2_3
pds-ratios: 2_2_1_2
pds-fusion: True
pds-fusion-method: all_conv
pds-embed-dims: 256_384_384_512
pds-ds-method: conv
pds-embed-norm: True
pds-position-embed: 1_1_1_1
pds-kernel-sizes: 5_5_5_5
pds-ffn-ratios: 8_4_4_4
pds-attn-heads: 4_6_6_8
optimizer: adam
clip-norm: 10.0
lr-scheduler: inverse_sqrt
warmup-init-lr: 1e-7
warmup-updates: 10000
lr: 0.002
adam_betas: (0.9,0.98)
criterion: ctc
post-process: sentencepiece
dropout: 0.1
activation-fn: relu
macaron-style: True
use-cnn-module: True
cnn-module-kernel: 31
encoder-activation-fn: swish
#load-pretrained-encoder-from:
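For reference (not part of the commit): pds-ratios gives the per-stage temporal downsampling, so 2_2_1_2 shrinks the feature sequence by 2·2·1·2 = 8× overall. A quick check in Python:

# Total downsampling implied by a pds-ratios setting such as "2_2_1_2".
ratios = [int(r) for r in "2_2_1_2".split("_")]
total = 1
for r in ratios:
    total *= r
print(total)  # 8 -> frames are 8x fewer after the last PDS stage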
...@@ -8,7 +8,7 @@ with open(in_file, "r", encoding="utf-8") as f:
    for line in f.readlines():
        line = line.strip().lower()
        for w in string.punctuation:
-           line = line.replace(w, "")
-       line = line.replace(" ", "")
+           if w != "'":
+               line = line.replace(w, "")
+       line = line.replace("  ", " ")
        print(line)
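The change above keeps apostrophes when stripping punctuation and then collapses the double spaces this can leave behind. A self-contained sketch of the updated behaviour (the function name here is chosen for illustration, not taken from the script):

import string

def lower_rm(line):
    # lowercase, drop punctuation except apostrophes, collapse double spaces
    line = line.strip().lower()
    for w in string.punctuation:
        if w != "'":
            line = line.replace(w, "")
    return line.replace("  ", " ")

print(lower_rm("It's 10 o'clock -- isn't it?"))  # it's 10 o'clock isn't it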
...@@ -44,10 +44,10 @@ lcrm=1
tokenizer=1
use_specific_dict=1
-specific_prefix=asr5k_st10k
+specific_prefix=unified
-specific_dir=${root_dir}/data/iwslt2022/st_lcrm_asr
+specific_dir=${root_dir}/data/wmt20/vocab
-src_vocab_prefix=spm_unigram5000_asr
+src_vocab_prefix=spm_en
-tgt_vocab_prefix=spm_unigram10000_st
+tgt_vocab_prefix=spm_zh
org_data_dir=${root_dir}/data/${dataset}
data_dir=${root_dir}/data/${dataset}/mt

...@@ -141,6 +141,9 @@ if [ ${stage} -le 0 ] && [ ${stop_stage} -ge 0 ]; then
    if [[ ! -e ${data_dir} ]]; then
        mkdir -p ${data_dir}
    fi
if [[ ! -e ${data_dir}/data ]]; then
mkdir -p ${data_dir}/data
fi
    if [[ ! -f ${data_dir}/${src_vocab_prefix}.txt || ! -f ${data_dir}/${tgt_vocab_prefix}.txt ]]; then
        if [[ ${use_specific_dict} -eq 0 ]]; then
...@@ -154,52 +157,31 @@ if [ ${stage} -le 0 ] && [ ${stop_stage} -ge 0 ]; then
                --tgt-vocab-type ${tgt_vocab_type}
                --src-vocab-size ${src_vocab_size}
                --tgt-vocab-size ${tgt_vocab_size}"
if [[ $share_dict -eq 1 ]]; then
cmd="$cmd
--share"
fi
echo -e "\033[34mRun command: \n${cmd} \033[0m"
[[ $eval -eq 1 ]] && eval ${cmd}
        else
            cp -r ${specific_dir}/${src_vocab_prefix}.* ${data_dir}
            cp ${specific_dir}/${tgt_vocab_prefix}.* ${data_dir}
        fi
    fi

-       mkdir -p ${data_dir}/data
-       for split in ${train_subset} ${valid_subset} ${trans_subset}; do
-       {
-           if [[ -d ${org_data_dir}/data/${split}/txt ]]; then
-               text_dir=${org_data_dir}/data/${split}/txt
-           else
-               text_dir=${org_data_dir}/data/${split}
+       cmd="python ${code_dir}/examples/speech_to_text/prep_mt_data.py
+           --data-root ${org_data_dir}
+           --output-root ${data_dir}
+           --splits ${train_subset},${valid_subset},${trans_subset}
+           --src-lang ${src_lang}
+           --tgt-lang ${tgt_lang}
+           --src-vocab-prefix ${src_vocab_prefix}
            --tgt-vocab-prefix ${tgt_vocab_prefix}"
            fi
        if [[ $share_dict -eq 1 ]]; then
            cmd="$cmd
                --share"
        fi
            src_text=${text_dir}/${split}.${src_lang}
            tgt_text=${text_dir}/${split}.${tgt_lang}
            cmd="cat ${src_text}"
        if [[ ${lcrm} -eq 1 ]]; then
-           cmd="python local/lower_rm.py ${src_text}"
+           cmd="$cmd
                --lowercase-src
                --rm-punc-src"
        fi
            cmd="${cmd}
                | spm_encode --model ${data_dir}/${src_vocab_prefix}.model
                --output_format=piece
                > ${data_dir}/data/${split}.${src_lang}"
        echo -e "\033[34mRun command: \n${cmd} \033[0m"
        [[ $eval -eq 1 ]] && eval ${cmd}
fi
cmd="spm_encode
--model ${data_dir}/${tgt_vocab_prefix}.model
--output_format=piece
< ${tgt_text}
> ${data_dir}/data/${split}.${tgt_lang}"
echo -e "\033[34mRun command: \n${cmd} \033[0m"
[[ $eval -eq 1 ]] && eval ${cmd}
}&
done
wait
cmd="python ${code_dir}/fairseq_cli/preprocess.py cmd="python ${code_dir}/fairseq_cli/preprocess.py
--source-lang ${src_lang} --target-lang ${tgt_lang} --source-lang ${src_lang} --target-lang ${tgt_lang}
......
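With the per-split cat/spm_encode loop removed, manifest generation and subword encoding are delegated to prep_mt_data.py. A hedged sketch of the equivalent invocation from Python; the flags are taken from the command assembled above, while the concrete paths are placeholders:

import subprocess

cmd = [
    "python", "examples/speech_to_text/prep_mt_data.py",
    "--data-root", "data/wmt20",          # stands in for ${org_data_dir}
    "--output-root", "data/wmt20/mt",     # stands in for ${data_dir}
    "--splits", "train,valid,test",
    "--src-lang", "en", "--tgt-lang", "zh",
    "--src-vocab-prefix", "spm_en",
    "--tgt-vocab-prefix", "spm_zh",
    "--lowercase-src", "--rm-punc-src",   # appended when lcrm=1
]
# "--share" is appended when share_dict=1
subprocess.run(cmd, check=True)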
...@@ -296,6 +296,9 @@ def process(args):
                gen_manifest_flag = True
                break

    punctuation_str = string.punctuation
    punctuation_str = punctuation_str.replace("'", "")  # keep apostrophes; str.replace returns a new string

    train_text = []
    if args.overwrite or gen_manifest_flag:
        if not use_raw:
...@@ -340,7 +343,7 @@ def process(args):
                    if args.lowercase_src:
                        src_utt = src_utt.lower()
                    if args.rm_punc_src:
-                       for w in string.punctuation:
+                       for w in punctuation_str:
                            src_utt = src_utt.replace(w, "")
                        src_utt = " ".join(src_utt.split(" "))
                else:
...@@ -414,7 +417,7 @@ def process(args):
                if args.lowercase_src:
                    src_utt = src_utt.lower()
                if args.rm_punc_src:
-                   for w in string.punctuation:
+                   for w in punctuation_str:
                        src_utt = src_utt.replace(w, "")
                    src_utt = " ".join(src_utt.split(" "))
                train_text.append(src_utt)
......
...@@ -11,6 +11,7 @@ from pathlib import Path
from tempfile import NamedTemporaryFile
from typing import Tuple
import string
import sentencepiece as spm

from examples.speech_to_text.data_utils import (
    gen_vocab,
...@@ -62,6 +63,8 @@ def process(args):
    splits = args.splits.split(",")
    src_train_text = []
    tgt_train_text = []
    manifest = {c: [] for c in MANIFEST_COLUMNS}
    sent_num = [0]

    lang = f"{src_lang}-{tgt_lang}"
    cur_root = Path(args.data_root).absolute() / lang
...@@ -70,20 +73,22 @@ def process(args):
    else:
        output_root = Path(args.output_root).absolute()

    punctuation_str = string.punctuation
    punctuation_str = punctuation_str.replace("'", "")

    # Generate TSV manifest
    print("Generating manifest...")
    for split in splits:
        is_train_split = split.startswith("train")
        manifest = {c: [] for c in MANIFEST_COLUMNS}
        dataset = MTDataset(args.data_root, src_lang, tgt_lang, split, args.tokenizer)
        for src_text, tgt_text in tqdm(dataset):
            if args.lowercase_src:
                src_text = src_text.lower()
            if args.rm_punc_src:
-               for w in string.punctuation:
+               for w in punctuation_str:
                    src_text = src_text.replace(w, "")
-               src_text = src_text.replace(" ", "")
+               src_text = src_text.replace("  ", " ")
            manifest["src_text"].append(src_text)
            manifest["tgt_text"].append(tgt_text)
...@@ -94,34 +99,50 @@ def process(args):
        if is_train_split:
            src_train_text.extend(manifest["src_text"])
            tgt_train_text.extend(manifest["tgt_text"])
        sent_num.append(len(manifest["src_text"]))

    # Generate vocab and yaml
    print("Generating vocabulary...")
    tgt_v_size_str = "" if args.tgt_vocab_type == "char" else str(args.tgt_vocab_size)
    tgt_spm_filename_prefix = f"spm_{args.tgt_vocab_type}{tgt_v_size_str}"
    if args.share:
-       tgt_train_text.extend(src_train_text)
-       tgt_spm_filename_prefix = tgt_spm_filename_prefix + "_share"
+       if args.tgt_vocab_prefix is not None:
+           tgt_spm_filename_prefix = args.tgt_vocab_prefix
+       else:
+           tgt_train_text.extend(src_train_text)
+           tgt_spm_filename_prefix = tgt_spm_filename_prefix + "_share"
        src_spm_filename_prefix = tgt_spm_filename_prefix
    else:
-       src_v_size_str = "" if args.src_vocab_type == "char" else str(args.src_vocab_size)
-       src_spm_filename_prefix = f"spm_{args.src_vocab_type}{src_v_size_str}"
-       src_spm_filename_prefix = src_spm_filename_prefix + "_" + src_lang
-       tgt_spm_filename_prefix = tgt_spm_filename_prefix + "_" + tgt_lang
-
-   with NamedTemporaryFile(mode="w") as f:
-       for t in tgt_train_text:
-           f.write(t + "\n")
-       gen_vocab(
-           Path(f.name),
-           output_root / tgt_spm_filename_prefix,
-           args.tgt_vocab_type,
-           args.tgt_vocab_size,
-           normalization_rule_name="identity" if tgt_lang == "zh" else None
-       )
-
-   if not args.share:
+       if args.tgt_vocab_prefix is not None:
+           tgt_spm_filename_prefix = args.tgt_vocab_prefix
+       else:
+           tgt_spm_filename_prefix = tgt_spm_filename_prefix + "_" + tgt_lang
+
+       if args.src_vocab_prefix is not None:
+           src_spm_filename_prefix = args.src_vocab_prefix
+       else:
+           src_v_size_str = "" if args.src_vocab_type == "char" else str(args.src_vocab_size)
+           src_spm_filename_prefix = f"spm_{args.src_vocab_type}{src_v_size_str}"
+           src_spm_filename_prefix = src_spm_filename_prefix + "_" + src_lang
+
+   src_spm_model = (output_root / (src_spm_filename_prefix + ".model")).as_posix()
+   tgt_spm_model = (output_root / (tgt_spm_filename_prefix + ".model")).as_posix()
+
+   if not os.path.exists(tgt_spm_model):
+       with NamedTemporaryFile(mode="w") as f:
+           for t in tgt_train_text:
+               f.write(t + "\n")
+           gen_vocab(
+               Path(f.name),
+               output_root / tgt_spm_filename_prefix,
+               args.tgt_vocab_type,
+               args.tgt_vocab_size,
+               normalization_rule_name="identity" if tgt_lang == "zh" else None
+           )
+
+   if not args.share and not os.path.exists(src_spm_model):
        with NamedTemporaryFile(mode="w") as f:
            for t in src_train_text:
                f.write(t + "\n")
...@@ -133,6 +154,38 @@ def process(args):
                normalization_rule_name="identity" if tgt_lang == "zh" else None
            )
# Generate sentencepiece
print("Applying sentencepiece...")
tgt_sp = spm.SentencePieceProcessor()
tgt_sp.Load(tgt_spm_model)
if args.share:
src_sp = tgt_sp
else:
src_sp = spm.SentencePieceProcessor()
src_sp.Load(src_spm_model)
index = 0
for split in splits:
src_text = manifest["src_text"][sent_num[index]: sent_num[index + 1]]
tgt_text = manifest["tgt_text"][sent_num[index]: sent_num[index + 1]]
index += 1
src_spm_name = (output_root / "data" / (split + "." + src_lang)).as_posix()
tgt_spm_name = (output_root / "data" / (split + "." + tgt_lang)).as_posix()
with open(src_spm_name, 'w') as f:
for sentence in src_text:
pieces = src_sp.EncodeAsPieces(sentence)
result = " ".join(pieces)
f.write(result + "\n")
with open(tgt_spm_name, 'w') as f:
for sentence in tgt_text:
pieces = tgt_sp.EncodeAsPieces(sentence)
result = " ".join(pieces)
f.write(result + "\n")
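The loop above writes one space-joined piece sequence per sentence using the trained SentencePiece models. A minimal usage sketch of the sentencepiece calls involved (assumes an existing model file such as spm_en.model):

import sentencepiece as spm

sp = spm.SentencePieceProcessor()
sp.Load("spm_en.model")                      # produced by gen_vocab / spm_train
pieces = sp.EncodeAsPieces("machine translation")
print(" ".join(pieces))                      # e.g. "▁machine ▁translation", depending on the model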
    # Generate config YAML
    yaml_filename = f"config.yaml"
    if args.share:
...@@ -162,19 +215,19 @@ def main():
    parser.add_argument(
        "--src-vocab-type",
        default="unigram",
-       required=True,
        type=str,
        choices=["bpe", "unigram", "char"],
    )
    parser.add_argument(
        "--tgt-vocab-type",
        default="unigram",
-       required=True,
        type=str,
        choices=["bpe", "unigram", "char"],
    )
    parser.add_argument("--src-vocab-size", default=10000, type=int)
    parser.add_argument("--tgt-vocab-size", default=10000, type=int)
parser.add_argument("--src-vocab-prefix", default=None, type=str, help="prefix of the specific source vocabulary")
parser.add_argument("--tgt-vocab-prefix", default=None, type=str, help="prefix of the specific target vocabulary")
parser.add_argument("--size", default=-1, type=int) parser.add_argument("--size", default=-1, type=int)
parser.add_argument("--splits", default="train,dev,test", type=str) parser.add_argument("--splits", default="train,dev,test", type=str)
parser.add_argument("--lowercase-src", action="store_true", help="lowercase the source text") parser.add_argument("--lowercase-src", action="store_true", help="lowercase the source text")
......
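The new --src-vocab-prefix/--tgt-vocab-prefix options let the script reuse an existing SentencePiece model instead of training one. A compact sketch of the target-side prefix resolution implemented above (argument names as in the diff):

def resolve_tgt_prefix(tgt_vocab_prefix, tgt_vocab_type, tgt_vocab_size, tgt_lang, share):
    if tgt_vocab_prefix is not None:
        return tgt_vocab_prefix                      # reuse the provided vocabulary
    size = "" if tgt_vocab_type == "char" else str(tgt_vocab_size)
    prefix = f"spm_{tgt_vocab_type}{size}"
    return prefix + ("_share" if share else "_" + tgt_lang)

print(resolve_tgt_prefix(None, "unigram", 10000, "zh", share=False))     # spm_unigram10000_zh
print(resolve_tgt_prefix("spm_zh", "unigram", 10000, "zh", share=False)) # spm_zh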
...@@ -9,6 +9,8 @@ from argparse import Namespace ...@@ -9,6 +9,8 @@ from argparse import Namespace
from dataclasses import dataclass, field from dataclasses import dataclass, field
from omegaconf import II from omegaconf import II
from typing import Optional from typing import Optional
import numpy as np
import logging
import torch import torch
import torch.nn.functional as F import torch.nn.functional as F
...@@ -19,6 +21,7 @@ from fairseq.data.data_utils import post_process ...@@ -19,6 +21,7 @@ from fairseq.data.data_utils import post_process
from fairseq.tasks import FairseqTask from fairseq.tasks import FairseqTask
from fairseq.logging.meters import safe_round from fairseq.logging.meters import safe_round
logger = logging.getLogger(__name__)
@dataclass
class CtcCriterionConfig(FairseqDataclass):
...@@ -31,8 +34,8 @@ class CtcCriterionConfig(FairseqDataclass):
        default="sentencepiece",
        metadata={
            "help": "how to post process predictions into words. can be letter, "
                    "wordpiece, BPE symbols, etc. "
                    "See fairseq.data.data_utils.post_process() for full list of options"
        },
    )
ctc_entropy: float = field( ctc_entropy: float = field(
...@@ -43,6 +46,10 @@ class CtcCriterionConfig(FairseqDataclass): ...@@ -43,6 +46,10 @@ class CtcCriterionConfig(FairseqDataclass):
default=0.0, default=0.0,
metadata={"help": "weight of intermedia CTC loss"}, metadata={"help": "weight of intermedia CTC loss"},
) )
target_ctc_weight: float = field(
default=0.0,
metadata={"help": "weight of intermedia CTC loss for target sentence"},
)
    ctc_self_distill_weight: float = field(
        default=0.0,
        metadata={"help": "weight of the self distillation CTC loss"},
...@@ -116,10 +123,12 @@ class CtcCriterion(FairseqCriterion):
        self.ctc_weight = ctc_weight
        self.intermedia_ctc_weight = cfg.intermedia_ctc_weight
        self.target_ctc_weight = cfg.target_ctc_weight
        self.ctc_self_distill_weight = cfg.ctc_self_distill_weight
        self.ctc_entropy = cfg.ctc_entropy
-       self.all_ctc_weight = self.ctc_weight + self.intermedia_ctc_weight + self.ctc_self_distill_weight + self.ctc_entropy
+       self.all_ctc_weight = self.ctc_weight + self.intermedia_ctc_weight + self.target_ctc_weight + \
+                             self.ctc_self_distill_weight + self.ctc_entropy
        if self.all_ctc_weight > 0:
            assert getattr(task, "src_dict", None) is not None, "CTC need a source dictionary."
            self.ctc_loss = torch.nn.CTCLoss(blank=self.blank_idx, reduction="sum", zero_infinity=True)
...@@ -145,7 +154,7 @@ class CtcCriterion(FairseqCriterion):
            non_padding_mask = ~net_output["ctc_padding_mask"][0]
        else:
            non_padding_mask = ~net_output["encoder_padding_mask"][0]
-       input_lengths = non_padding_mask.long().sum(-1)
+       ctc_input_lengths = input_lengths = non_padding_mask.long().sum(-1)

        pad_mask = (transcript["tokens"] != self.pad_idx) & (
            transcript["tokens"] != self.eos_idx
...@@ -215,6 +224,43 @@ class CtcCriterion(FairseqCriterion):
        if lprobs is None:
            lprobs = inter_lprobs
target_ctc_num = 0
target_ctc_loss = 0
if "target_ctc_logits" in net_output:
target_ctc_num = len(net_output["target_ctc_logits"])
# calculate the target CTC loss
if self.target_ctc_weight > 0 and target_ctc_num > 0:
target = sample["target"]
pad_mask = (target != self.pad_idx) & (target != self.eos_idx)
targets_flat = target.masked_select(pad_mask)
target_length = pad_mask.sum(-1)
for i in range(target_ctc_num):
out = net_output["target_ctc_logits"][i]
if type(out) == list:
inter_ctc_logit = out[0]
padding = ~out[1]
input_lengths = padding.long().sum(-1)
else:
inter_ctc_logit = out
inter_lprobs = model.get_normalized_probs(
[inter_ctc_logit], log_probs=True
).contiguous() # (T, B, C) from the encoder
inter_lprobs.batch_first = False
with torch.backends.cudnn.flags(enabled=False):
loss = self.ctc_loss(
inter_lprobs,
targets_flat,
ctc_input_lengths,
target_length,
)
target_ctc_loss += loss
target_ctc_loss /= target_ctc_num
logging_output["target_ctc_loss"] = utils.item(target_ctc_loss.data)
        # calculate the self distillation CTC loss
        ctc_self_distill_loss = 0
        ctc_self_distill_num = 0
...@@ -247,6 +293,7 @@ class CtcCriterion(FairseqCriterion):
        loss = \
            self.ctc_weight * ctc_loss + \
            self.intermedia_ctc_weight * intermedia_ctc_loss + \
            self.target_ctc_weight * target_ctc_loss + \
            self.ctc_self_distill_weight * ctc_self_distill_loss + \
            self.ctc_entropy * ctc_entropy
...@@ -264,9 +311,9 @@ class CtcCriterion(FairseqCriterion): ...@@ -264,9 +311,9 @@ class CtcCriterion(FairseqCriterion):
w_len = 0 w_len = 0
wv_errs = 0 wv_errs = 0
for lp, t, inp_l in zip( for lp, t, inp_l in zip(
lprobs_t, lprobs_t,
sample["transcript"]["tokens"] if "transcript" in sample else sample["target"], sample["transcript"]["tokens"] if "transcript" in sample else sample["target"],
input_lengths, input_lengths,
): ):
lp = lp[:inp_l].unsqueeze(0) lp = lp[:inp_l].unsqueeze(0)
...@@ -283,7 +330,7 @@ class CtcCriterion(FairseqCriterion): ...@@ -283,7 +330,7 @@ class CtcCriterion(FairseqCriterion):
decoded = decoded[0] decoded = decoded[0]
p = (t != self.task.target_dictionary.pad()) & ( p = (t != self.task.target_dictionary.pad()) & (
t != self.task.target_dictionary.eos() t != self.task.target_dictionary.eos()
) )
targ = t[p] targ = t[p]
targ_units = self.task.target_dictionary.string(targ) targ_units = self.task.target_dictionary.string(targ)
...@@ -332,6 +379,9 @@ class CtcCriterion(FairseqCriterion): ...@@ -332,6 +379,9 @@ class CtcCriterion(FairseqCriterion):
inter_ctc_loss_sum = utils.item( inter_ctc_loss_sum = utils.item(
sum(log.get("intermedia_ctc_loss", 0) for log in logging_outputs) sum(log.get("intermedia_ctc_loss", 0) for log in logging_outputs)
) )
target_ctc_loss_sum = utils.item(
sum(log.get("target_ctc_loss", 0) for log in logging_outputs)
)
ctc_self_distill_loss_sum = utils.item( ctc_self_distill_loss_sum = utils.item(
sum(log.get("ctc_self_distill_loss", 0) for log in logging_outputs) sum(log.get("ctc_self_distill_loss", 0) for log in logging_outputs)
) )
...@@ -346,6 +396,9 @@ class CtcCriterion(FairseqCriterion): ...@@ -346,6 +396,9 @@ class CtcCriterion(FairseqCriterion):
sample_size = utils.item( sample_size = utils.item(
sum(log.get("sample_size", 0) for log in logging_outputs) sum(log.get("sample_size", 0) for log in logging_outputs)
) )
if np.isnan(all_ctc_loss_sum) or np.isinf(all_ctc_loss_sum) or all_ctc_loss_sum < 0:
logger.error("Illegal loss %f!" % all_ctc_loss_sum)
if all_ctc_loss_sum > 0: if all_ctc_loss_sum > 0:
if "loss" not in logging_outputs[0]: if "loss" not in logging_outputs[0]:
metrics.log_scalar( metrics.log_scalar(
...@@ -383,6 +436,14 @@ class CtcCriterion(FairseqCriterion): ...@@ -383,6 +436,14 @@ class CtcCriterion(FairseqCriterion):
sample_size, sample_size,
round=3, round=3,
) )
if target_ctc_loss_sum > 0:
metrics.log_scalar(
"target_ctc_loss",
target_ctc_loss_sum / sample_size / math.log(2),
sample_size,
round=3,
)
        if ctc_self_distill_loss_sum > 0:
            metrics.log_scalar(
                "ctc_self_distill_loss",
...@@ -404,8 +465,8 @@ class CtcCriterion(FairseqCriterion):
            metrics.log_scalar("_c_total", c_total)
            w_errors = sum(log.get("w_errors", 0) for log in logging_outputs)
            metrics.log_scalar("_w_errors", w_errors)
-           wv_errors = sum(log.get("wv_errors", 0) for log in logging_outputs)
-           metrics.log_scalar("_wv_errors", wv_errors)
+           # wv_errors = sum(log.get("wv_errors", 0) for log in logging_outputs)
+           # metrics.log_scalar("_wv_errors", wv_errors)
            w_total = sum(log.get("w_total", 0) for log in logging_outputs)
            metrics.log_scalar("_w_total", w_total)
...@@ -427,14 +488,14 @@ class CtcCriterion(FairseqCriterion):
                if meters["_w_total"].sum > 0
                else float("nan"),
            )
-           metrics.log_derived(
-               "raw_wer",
-               lambda meters: safe_round(
-                   meters["_wv_errors"].sum * 100.0 / meters["_w_total"].sum, 3
-               )
-               if meters["_w_total"].sum > 0
-               else float("nan"),
-           )
+           # metrics.log_derived(
+           #     "raw_wer",
+           #     lambda meters: safe_round(
+           #         meters["_wv_errors"].sum * 100.0 / meters["_w_total"].sum, 3
+           #     )
+           #     if meters["_w_total"].sum > 0
+           #     else float("nan"),
+           # )
@staticmethod @staticmethod
def logging_outputs_can_be_summed() -> bool: def logging_outputs_can_be_summed() -> bool:
......
...@@ -17,6 +17,7 @@ class CTC(nn.Module):
    def __init__(self, embed_dim, dictionary_size, dropout, need_layernorm=False):
        super(CTC, self).__init__()

        self.embed_dim = embed_dim
        self.ctc_projection = nn.Linear(embed_dim, dictionary_size, bias=False)

        nn.init.normal_(
......
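The CTC module now records embed_dim so callers can check dimension compatibility before reusing an intermediate CTC head (see the PDS encoder change further down). A minimal stand-alone sketch of such a head, illustrative only and not the repository class itself:

import torch.nn as nn

class TinyCTCHead(nn.Module):
    def __init__(self, embed_dim, dictionary_size, dropout=0.1):
        super().__init__()
        self.embed_dim = embed_dim                      # exposed so callers can test compatibility
        self.dropout = nn.Dropout(dropout)
        self.ctc_projection = nn.Linear(embed_dim, dictionary_size, bias=False)

    def forward(self, x):                               # x: T x B x embed_dim
        return self.ctc_projection(self.dropout(x))     # T x B x dictionary_size logits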
...@@ -232,6 +232,7 @@ class PDSS2TTransformerModel(S2TTransformerModel):
            "rope",
            "abs",
            "transfer",
            "reduced_rel_pos",
        ],
        help="transformer encoder self-attention layer type"
    )
...@@ -579,6 +580,12 @@ class PDSS2TTransformerModel(S2TTransformerModel): ...@@ -579,6 +580,12 @@ class PDSS2TTransformerModel(S2TTransformerModel):
type=float, type=float,
help="probability of dropping the followed layers", help="probability of dropping the followed layers",
) )
parser.add_argument(
"--intermedia-temperature",
default=1,
type=float,
help="temperature of the intermedia ctc probability",
)
        pass

    @classmethod
...@@ -626,10 +633,7 @@ class PDSS2TTransformerEncoder(FairseqEncoder):
        self.pds_position_embed = [int(n) for n in args.pds_position_embed.split("_")]
        self.pds_attn_heads = [int(n) for n in args.pds_attn_heads.split("_")]
        self.pds_ffn_ratios = [int(n) for n in args.pds_ffn_ratios.split("_")]
-       if self.attn_type == "reduced":
-           self.pds_attn_ds_ratios = [int(n) for n in args.pds_attn_ds_ratios.split("_")]
-       else:
-           self.pds_attn_ds_ratios = None
+       self.pds_attn_ds_ratios = [int(n) for n in args.pds_attn_ds_ratios.split("_")]
        self.pds_conv_strides = [int(n) for n in args.pds_conv_strides.split("_")]
        self.pds_attn_strides = [int(n) for n in args.pds_attn_strides.split("_")]
...@@ -674,7 +678,7 @@ class PDSS2TTransformerEncoder(FairseqEncoder):
            ffn_ratio = self.pds_ffn_ratios[i]
            num_head = self.pds_attn_heads[i]
-           attn_ds_ratio = self.pds_attn_ds_ratios[i] if self.attn_type == "reduced" else -1
+           attn_ds_ratio = self.pds_attn_ds_ratios[i]  # if self.attn_type == "reduced" else -1
            conv_stride = self.pds_conv_strides[i]
            attn_stride = self.pds_attn_strides[i]
            if conv_stride != 1 or attn_stride != 1:
...@@ -712,7 +716,7 @@ class PDSS2TTransformerEncoder(FairseqEncoder):
            # position encoding
            if use_pos_embed:
-               if self.attn_type == "rel_pos":
+               if self.attn_type in ["rel_pos", "reduced_rel_pos"]:
                    pos_embed = RelPositionalEncoding(
                        args.max_source_positions, embed_dim
                    )
...@@ -850,7 +854,7 @@ class PDSS2TTransformerEncoder(FairseqEncoder):
                if ctc_layer <= 0:
                    embed_dim = self.pds_embed_dims[i]
                    break
-           if inter_ctc_module is None:
+           if inter_ctc_module is None or embed_dim != inter_ctc_module.embed_dim:
                self.ctc = CTC(embed_dim,
                               dictionary_size=len(task.source_dictionary),
                               dropout=args.dropout,
...@@ -866,6 +870,7 @@ class PDSS2TTransformerEncoder(FairseqEncoder): ...@@ -866,6 +870,7 @@ class PDSS2TTransformerEncoder(FairseqEncoder):
else: else:
self.layer_norm = None self.layer_norm = None
self.intermedia_temperature = getattr(args, "intermedia_temperature", 1)
self.gather_cos_sim = getattr(args, "gather_cos_sim", False) self.gather_cos_sim = getattr(args, "gather_cos_sim", False)
self.dis = 2 self.dis = 2
self.cos_sim = dict() self.cos_sim = dict()
...@@ -933,7 +938,7 @@ class PDSS2TTransformerEncoder(FairseqEncoder):
            # add the position encoding and dropout
            if pos_embed:
-               if self.attn_type in ["rel_pos", "rel_pos_legacy", "rel_selfattn"]:
+               if self.attn_type in ["rel_pos", "rel_pos_legacy", "rel_selfattn", "reduced_rel_pos"]:
                    positions = pos_embed(x)

                elif self.attn_type == "rope":
...@@ -981,7 +986,7 @@ class PDSS2TTransformerEncoder(FairseqEncoder):
                logit = ctc(x.clone())
                intermedia_ctc_logits.append([logit, encoder_padding_mask])

-               prob = utils.softmax(logit, dim=-1)
+               prob = utils.softmax(logit / self.intermedia_temperature, dim=-1)
                x, encoder_padding_mask = adapter([x, prob], encoder_padding_mask)

        if self.fusion_stages_num != 0:
...@@ -1131,9 +1136,9 @@ def base_architecture(args):
    args.pds_position_embed = getattr(args, "pds_position_embed", None)
    args.pds_attn_heads = getattr(args, "pds_attn_heads", None)
-   args.pds_attn_ds_ratios = getattr(args, "pds_attn_ds_ratios", None)
    args.pds_ffn_ratios = getattr(args, "pds_ffn_ratios", None)
+   args.pds_attn_ds_ratios = getattr(args, "pds_attn_ds_ratios", "1_1_1_1")
    args.pds_conv_strides = getattr(args, "pds_conv_strides", "1_1_1_1")
    args.pds_attn_strides = getattr(args, "pds_attn_strides", "1_1_1_1")
......
...@@ -118,6 +118,7 @@ class S2TCTCModel(FairseqEncoderModel): ...@@ -118,6 +118,7 @@ class S2TCTCModel(FairseqEncoderModel):
"rope", "rope",
"abs", "abs",
"transfer", "transfer",
"reduced_rel_pos",
], ],
help="transformer encoder self-attention layer type" help="transformer encoder self-attention layer type"
) )
...@@ -739,9 +740,9 @@ def base_architecture(args):
    args.pds_position_embed = getattr(args, "pds_position_embed", None)
    args.pds_attn_heads = getattr(args, "pds_attn_heads", None)
-   args.pds_attn_ds_ratios = getattr(args, "pds_attn_ds_ratios", None)
    args.pds_ffn_ratios = getattr(args, "pds_ffn_ratios", None)
+   args.pds_attn_ds_ratios = getattr(args, "pds_attn_ds_ratios", "1_1_1_1")
    args.pds_conv_strides = getattr(args, "pds_conv_strides", "1_1_1_1")
    args.pds_attn_strides = getattr(args, "pds_attn_strides", "1_1_1_1")
......
...@@ -4,7 +4,7 @@ import math
import torch
import torch.nn as nn
import torch.nn.functional as F
-from fairseq import checkpoint_utils
+from fairseq import checkpoint_utils, utils
from fairseq.models import (
    FairseqEncoder,
    register_model,
...@@ -16,7 +16,7 @@ from fairseq.models.speech_to_text import (
    PDSS2TTransformerModel,
    PDSS2TTransformerEncoder,
)
-from fairseq.models.speech_to_text.modules import CTCCompressStrategy, Adapter
+from fairseq.models.speech_to_text.modules import Adapter, CTC
from fairseq.modules import (
    FairseqDropout,
    LayerNorm,
...@@ -88,6 +88,12 @@ class S2TSATEModel(S2TTransformerModel):
            help="the architecture of the acoustic encoder",
        )
        parser.add_argument(
"--target-ctc-layers",
default=None,
type=str,
help="ctc layers for target sentence",
)
parser.add_argument(
"--load-pretrained-acoustic-encoder-from", "--load-pretrained-acoustic-encoder-from",
            type=str,
            metavar="STR",
...@@ -138,113 +144,15 @@ class S2TSATEModel(S2TTransformerModel):
        return encoder
# class Adapter(nn.Module):
# def __init__(self, args, dictionary, embed_tokens):
# super().__init__()
#
# embed_dim = args.encoder_embed_dim
#
# self.adapter_type = args.adapter
# if self.adapter_type in ["linear", "league", "gated_league", "gated_league2"]:
# self.linear_adapter = nn.Sequential(
# nn.Linear(embed_dim, embed_dim),
# LayerNorm(args.encoder_embed_dim),
# nn.ReLU(),
# )
# elif self.adapter_type == "linear2":
# self.linear_adapter = nn.Sequential(
# nn.Linear(embed_dim, embed_dim),
# )
#
# if self.adapter_type in ["embed", "context", "league", "gated_league", "gated_league2"]:
# if embed_tokens is None:
# num_embeddings = len(dictionary)
# self.embed_adapter = Embedding(num_embeddings, embed_dim, dictionary.pad())
# else:
# self.embed_adapter = embed_tokens
#
# if self.adapter_type == "gated_league":
# self.gate_linear = nn.Linear(2 * embed_dim, embed_dim)
# elif self.adapter_type == "gated_league2":
# self.gate_linear1 = nn.Linear(embed_dim, embed_dim)
# self.gate_linear2 = nn.Linear(embed_dim, embed_dim)
#
# if self.adapter_type == "shrink":
# self.ctc_compress_method = getattr(CTCCompressStrategy, args.ctc_compress_strategy)
#
# def forward(self, x, padding):
#
# representation, distribution = x
# batch, seq_len, embed_dim = representation.size()
# org_distribution = distribution
# if distribution is not None:
# distribution = distribution.view(-1, distribution.size(-1))
# lengths = (~padding).long().sum(-1)
#
# if self.adapter_type == "linear":
# out = self.linear_adapter(representation)
#
# elif self.adapter_type == "context":
# out = torch.mm(
# distribution, self.embed_adapter.weight.float()
# ).view(batch, seq_len, -1).type_as(representation)
#
# elif self.adapter_type == "league":
# linear_out = self.linear_adapter(representation)
# soft_out = torch.mm(
# distribution, self.embed_adapter.weight.float()
# ).view(batch, seq_len, -1).type_as(linear_out)
# out = linear_out + soft_out
#
# elif self.adapter_type == "gated_league":
# linear_out = self.linear_adapter(representation)
# soft_out = torch.mm(
# distribution, self.embed_adapter.weight.float()
# ).view(batch, seq_len, -1).type_as(linear_out)
# coef = (self.gate_linear(torch.cat([linear_out, soft_out], dim=-1))).sigmoid()
# out = coef * linear_out + (1 - coef) * soft_out
#
# elif self.adapter_type == "none":
# out = representation
#
# elif self.adapter_type == "shrink":
# from itertools import groupby
#
# with torch.no_grad():
# batch_predicted = []
# prob_ctc = org_distribution.transpose(0, 1) # T x B x D -> B x T x D
# for b in range(prob_ctc.shape[0]):
# predicted = prob_ctc[b][: lengths[b]].argmax(-1).tolist()
# batch_predicted.append([(p[0], len(list(p[1]))) for p in groupby(predicted)])
#
# new_lengths = [len(p) for p in batch_predicted]
# weights_matrix = self.ctc_compress_method(prob_ctc, batch_predicted, new_lengths,
# prob_ctc.dtype, prob_ctc.device)
#
# # x is T x B x C -> B x C x T; weights_matrix is B x T x T'
# data_type = representation.dtype
# representation = representation.permute(1, 2, 0).float()
# compressed_output = representation.bmm(weights_matrix).type_as(data_type) # B x C x T'
# out = compressed_output.permute(2, 0, 1)
#
# out_lengths = lengths.new(new_lengths)
# padding = lengths_to_padding_mask(out_lengths)
#
# else:
# out = None
# logging.error("Unsupported adapter type: {}.".format(self.adapter_type))
#
# return out, padding
class TextEncoder(FairseqEncoder):

-   def __init__(self, args, dictionary):
+   def __init__(self, args, dictionary, embed_tokens=None):

        super().__init__(None)

        self.embed_tokens = None

        embed_dim = args.encoder_embed_dim
        layer_num = args.text_encoder_layers
        self.layer_num = layer_num
        self.embed_scale = math.sqrt(embed_dim)
        if args.no_scale_embedding:
            self.embed_scale = 1.0
...@@ -259,13 +167,44 @@ class TextEncoder(FairseqEncoder):
        )

        self.layers = nn.ModuleList(
-           [TransformerEncoderLayer(args) for _ in range(args.text_encoder_layers)]
+           [TransformerEncoderLayer(args) for _ in range(layer_num)]
        )

        if args.encoder_normalize_before:
            self.layer_norm = LayerNorm(args.encoder_embed_dim)
        else:
            self.layer_norm = None
self.intermedia_ctc_layers = []
self.target_ctc_layers = getattr(args, "target_ctc_layers", None)
if self.target_ctc_layers is not None:
intermedia_ctc_layers = self.target_ctc_layers.split(",")
for layer_idx in intermedia_ctc_layers:
layer_idx = int(layer_idx)
assert layer_idx <= layer_num, (layer_idx, layer_num)
if layer_idx <= 0:
layer_idx += layer_num
self.intermedia_ctc_layers.append(layer_idx)
logger.info("Intermedia CTC loss in layer %d" % layer_idx)
self.ctc = CTC(embed_dim,
dictionary_size=len(dictionary),
dropout=args.dropout)
if embed_tokens is not None:
self.ctc.ctc_projection.weight = embed_tokens.weight
strategy = None
if args.intermedia_adapter == "shrink":
strategy = getattr(args, "ctc_compress_strategy", None)
elif args.intermedia_adapter == "league":
strategy = getattr(args, "intermedia_distribution_cutoff", None)
self.adapter = Adapter(embed_dim, args.intermedia_adapter,
dictionary, strategy=strategy)
self.intermedia_drop_prob = getattr(args, "intermedia_drop_prob", 0)
self.intermedia_temperature = getattr(args, "intermedia_temperature", 1)
    def forward(self, x, encoder_padding_mask=None, history=None):

        x = self.embed_scale * x
...@@ -273,10 +212,28 @@ class TextEncoder(FairseqEncoder):
        x = positions + x

        x = self.dropout_module(x)

        target_ctc_logits = []
        layer_idx = 0
        for layer in self.layers:
            layer_idx += 1
            if history is not None:
                x = history.pop()

            x = layer(x, encoder_padding_mask, pos_emb=positions)
if layer_idx != self.layer_num and layer_idx in self.intermedia_ctc_layers:
if self.intermedia_drop_prob > 0:
p = torch.rand(1).uniform_()
if p < self.intermedia_drop_prob:
break
norm_x = self.layer_norm(x)
logit = self.ctc(norm_x)
target_ctc_logits.append(logit)
prob = utils.softmax(logit / self.intermedia_temperature, dim=-1)
x, encoder_padding_mask = self.adapter([x, prob], encoder_padding_mask)
            if history is not None:
                history.push(x)
...@@ -286,7 +243,11 @@ class TextEncoder(FairseqEncoder):
        if self.layer_norm is not None:
            x = self.layer_norm(x)

-       return x
+       if layer_idx in self.intermedia_ctc_layers:
logit = self.ctc(x)
target_ctc_logits.append(logit)
return x, target_ctc_logits
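The rewritten TextEncoder applies the intermediate-CTC pattern over the target dictionary: at every layer listed in --target-ctc-layers it normalizes the states, projects them with a (possibly embedding-tied) CTC head, softmaxes with the intermedia temperature, and passes the distribution through the adapter before the next layer. A condensed sketch of that per-layer step, illustrative only and assuming an adapter with the same [hidden, probability] interface used above:

import torch.nn.functional as F

def intermediate_ctc_step(x, layer_norm, ctc_head, adapter, padding_mask, temperature=1.0):
    # x: T x B x C hidden states after one text-encoder layer (sketch only)
    logit = ctc_head(layer_norm(x))                     # T x B x |target vocab|
    prob = F.softmax(logit / temperature, dim=-1)       # temperature-scaled distribution
    x, padding_mask = adapter([x, prob], padding_mask)  # e.g. "league": hidden + soft embedding
    return x, padding_mask, logit                       # logit is collected for the target CTC loss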
class S2TSATEEncoder(FairseqEncoder):
...@@ -327,7 +288,7 @@ class S2TSATEEncoder(FairseqEncoder):
        args.encoder_attention_type = args.text_attention_type

        # text encoder
-       self.text_encoder = TextEncoder(args, task.source_dictionary)
+       self.text_encoder = TextEncoder(args, task.source_dictionary, embed_tokens)

        args.encoder_attention_type = acoustic_encoder_attention_type
...@@ -367,12 +328,13 @@ class S2TSATEEncoder(FairseqEncoder):
            self.history.push(x)

-       x = self.text_encoder(x, encoder_padding_mask, self.history)
+       x, target_ctc_logits = self.text_encoder(x, encoder_padding_mask, self.history)
return { return {
"encoder_out": [x], # T x B x C "encoder_out": [x], # T x B x C
"ctc_logit": [ctc_logit], # T x B x C "ctc_logit": [ctc_logit], # T x B x C
"intermedia_ctc_logits": acoustic_encoder_out.get("intermedia_ctc_logits", []), # B x T x C "intermedia_ctc_logits": acoustic_encoder_out.get("intermedia_ctc_logits", []), # B x T x C
"target_ctc_logits": target_ctc_logits, # B x T x C
"ctc_padding_mask": [ctc_padding_mask], # B x T "ctc_padding_mask": [ctc_padding_mask], # B x T
"encoder_padding_mask": [encoder_padding_mask], # B x T "encoder_padding_mask": [encoder_padding_mask], # B x T
"encoder_embedding": [], # B x T x C "encoder_embedding": [], # B x T x C
...@@ -490,15 +452,23 @@ def base_architecture(args):
    args.pds_position_embed = getattr(args, "pds_position_embed", None)
    args.pds_attn_heads = getattr(args, "pds_attn_heads", None)
-   args.pds_attn_ds_ratios = getattr(args, "pds_attn_ds_ratios", None)
    args.pds_ffn_ratios = getattr(args, "pds_ffn_ratios", None)
+   args.pds_attn_ds_ratios = getattr(args, "pds_attn_ds_ratios", "1_1_1_1")
+   args.pds_conv_strides = getattr(args, "pds_conv_strides", "1_1_1_1")
+   args.pds_attn_strides = getattr(args, "pds_attn_strides", "1_1_1_1")
    args.ctc_layer = getattr(args, "ctc_layer", 0)
    args.pds_dropout = getattr(args, "pds_dropout", args.dropout)

    args.pds_fusion = getattr(args, "pds_fusion", False)
    args.pds_fusion_method = getattr(args, "pds_fusion_method", "all_conv")
# intermedia CTC
args.pds_ctc = getattr(args, "pds_ctc", "0_0_0_0")
args.intermedia_adapter = getattr(args, "intermedia_adapter", "none")
args.intermedia_drop_prob = getattr(args, "intermedia_drop_prob", 0)
@register_model_architecture("s2t_sate", "s2t_sate_s") @register_model_architecture("s2t_sate", "s2t_sate_s")
def s2t_sate_s(args): def s2t_sate_s(args):
......
...@@ -395,6 +395,12 @@ class S2TTransformerModel(FairseqEncoderDecoderModel): ...@@ -395,6 +395,12 @@ class S2TTransformerModel(FairseqEncoderDecoderModel):
type=float, type=float,
help="probability of dropping the followed layers", help="probability of dropping the followed layers",
) )
parser.add_argument(
"--intermedia-temperature",
default=1,
type=float,
help="temperature of the intermedia ctc probability",
)
pass pass
@classmethod @classmethod
...@@ -585,6 +591,7 @@ class S2TTransformerEncoder(FairseqEncoder): ...@@ -585,6 +591,7 @@ class S2TTransformerEncoder(FairseqEncoder):
self.adapter = Adapter(dim, args.intermedia_adapter, self.adapter = Adapter(dim, args.intermedia_adapter,
task.source_dictionary, strategy=strategy) task.source_dictionary, strategy=strategy)
self.intermedia_drop_prob = getattr(args, "intermedia_drop_prob", 0) self.intermedia_drop_prob = getattr(args, "intermedia_drop_prob", 0)
self.intermedia_temperature = getattr(args, "intermedia_temperature", 1)
@staticmethod @staticmethod
def pooling_ratio(): def pooling_ratio():
...@@ -683,7 +690,7 @@ class S2TTransformerEncoder(FairseqEncoder):
                intermedia_ctc_logits.append(logit)

                # prob = self.ctc.softmax(norm_x)
-               prob = utils.softmax(logit, dim=-1)
+               prob = utils.softmax(logit / self.intermedia_temperature, dim=-1)
                x, encoder_padding_mask = adapter([x, prob], encoder_padding_mask)

        # gather cosine similarity
......
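Both encoders now divide intermediate CTC logits by --intermedia-temperature before the softmax that feeds the adapter; values above 1 flatten the distribution and values below 1 sharpen it. A quick numeric check (illustrative, not from the commit):

import torch

logit = torch.tensor([2.0, 1.0, 0.1])
for t in (0.5, 1.0, 5.0):
    print(t, torch.softmax(logit / t, dim=-1))
# t = 0.5 concentrates mass on the argmax, t = 5.0 moves towards uniform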
...@@ -54,6 +54,7 @@ from .positional_encoding import ( ...@@ -54,6 +54,7 @@ from .positional_encoding import (
from .espnet_multihead_attention import ( from .espnet_multihead_attention import (
ESPNETMultiHeadedAttention, ESPNETMultiHeadedAttention,
RelPositionMultiHeadedAttention, RelPositionMultiHeadedAttention,
ReducedRelPositionMultiHeadedAttention,
LegacyRelPositionMultiHeadedAttention, LegacyRelPositionMultiHeadedAttention,
RotaryPositionMultiHeadedAttention, RotaryPositionMultiHeadedAttention,
) )
...@@ -113,10 +114,11 @@ __all__ = [ ...@@ -113,10 +114,11 @@ __all__ = [
"unfold1d", "unfold1d",
"ESPNETMultiHeadedAttention", "ESPNETMultiHeadedAttention",
"PositionalEmbedding", "PositionalEmbedding",
"RelPositionMultiHeadedAttention",
"PositionalEncoding", "PositionalEncoding",
"LegacyRelPositionalEncoding", "LegacyRelPositionalEncoding",
"RelPositionalEncoding", "RelPositionalEncoding",
"RelPositionMultiHeadedAttention",
"ReducedRelPositionMultiHeadedAttention",
"LegacyRelPositionMultiHeadedAttention", "LegacyRelPositionMultiHeadedAttention",
"RotaryPositionalEmbedding", "RotaryPositionalEmbedding",
"RotaryPositionMultiHeadedAttention", "RotaryPositionMultiHeadedAttention",
......
...@@ -1347,7 +1347,7 @@ class MultiHeadSelfAttentionModule(nn.Module):
        # Assert
        assert not (group_size > 1 and kernel_size is not None), "Local grouped attention not implemented"
-       assert not (group_size > 1 and stride > 1 is not None), "Strided grouped attention not implemented"
+       assert not (group_size > 1 and stride > 1), "Strided grouped attention not implemented"
        assert not (linear_att and relative_pos_enc), "Linear attention requires absolute positional encodings"

        # Pre Norm
......
...@@ -14,6 +14,7 @@ from fairseq.modules.rotary_positional_embedding import ( ...@@ -14,6 +14,7 @@ from fairseq.modules.rotary_positional_embedding import (
RotaryPositionalEmbedding, RotaryPositionalEmbedding,
apply_rotary_pos_emb, apply_rotary_pos_emb,
) )
from .layer_norm import LayerNorm
class ESPNETMultiHeadedAttention(nn.Module): class ESPNETMultiHeadedAttention(nn.Module):
...@@ -72,6 +73,7 @@ class ESPNETMultiHeadedAttention(nn.Module): ...@@ -72,6 +73,7 @@ class ESPNETMultiHeadedAttention(nn.Module):
if mask is not None: if mask is not None:
scores = scores.masked_fill( scores = scores.masked_fill(
mask.unsqueeze(1).unsqueeze(2).to(bool), mask.unsqueeze(1).unsqueeze(2).to(bool),
# -1e8 if scores.dtype == torch.float32 else -1e4
float("-inf"), # (batch, head, time1, time2) float("-inf"), # (batch, head, time1, time2)
) )
self.attn = F.softmax(scores, dim=-1, dtype=torch.float32).type_as(scores) # (batch, head, time1, time2) self.attn = F.softmax(scores, dim=-1, dtype=torch.float32).type_as(scores) # (batch, head, time1, time2)
...@@ -195,6 +197,131 @@ class RelPositionMultiHeadedAttention(ESPNETMultiHeadedAttention): ...@@ -195,6 +197,131 @@ class RelPositionMultiHeadedAttention(ESPNETMultiHeadedAttention):
return scores, None return scores, None
class ReducedRelPositionMultiHeadedAttention(RelPositionMultiHeadedAttention):
"""Multi-Head Attention layer with relative position encoding.
Paper: https://arxiv.org/abs/1901.02860
Args:
n_head: The number of heads.
n_feat: The number of features.
dropout: Dropout rate.
zero_triu: Whether to zero the upper triangular part of attention matrix.
"""
def __init__(self, n_feat, n_head, dropout, zero_triu=False,
sample_ratio=1,
reduced_method="conv",
reduced_q=False,
):
"""Construct an RelPositionMultiHeadedAttention object."""
super().__init__(n_feat, n_head, dropout)
self.zero_triu = zero_triu
# linear transformation for positional encoding
self.linear_pos = nn.Linear(n_feat, n_feat, bias=False)
# these two learnable bias are used in matrix c and matrix d
# as described in https://arxiv.org/abs/1901.02860 Section 3.3
self.pos_bias_u = nn.Parameter(torch.Tensor(self.h, self.d_k))
self.pos_bias_v = nn.Parameter(torch.Tensor(self.h, self.d_k))
torch.nn.init.xavier_uniform_(self.pos_bias_u)
torch.nn.init.xavier_uniform_(self.pos_bias_v)
super().__init__(n_feat, n_head, dropout, zero_triu)
self.sample_ratio = sample_ratio
self.reduced_method = reduced_method
self.reduced_q = reduced_q
if reduced_q:
assert self.reduced_method == 'group', "only support grouped method for query reduction"
if self.sample_ratio > 1:
if reduced_method == "conv":
self.sr = nn.Conv1d(n_feat, n_feat,
kernel_size=sample_ratio,
stride=sample_ratio,
)
self.norm = LayerNorm(n_feat)
elif reduced_method == "pool":
self.linear = nn.Linear(n_feat, n_feat)
self.norm = LayerNorm(n_feat)
self.act = nn.GELU()
elif reduced_method == "group":
pass
def forward(self, query, key, value, pos_emb, key_padding_mask=None, **kwargs):
"""Compute scaled dot product attention.
Args:
query: Query tensor T X B X C
key: Key tensor T X B X C
value: Value tensor T X B X C
pos_emb: Positional embedding tensor 2T-1 X B(1) X C
key_padding_mask: Mask tensor T X B
Returns:
torch.Tensor: Output tensor T X B X C.
"""
# (bsz, seq_len, dim)
query = query.transpose(0, 1)
key = key.transpose(0, 1)
value = value.transpose(0, 1)
pos_emb = pos_emb.transpose(0, 1)
tgt_len = query.size(1)
query_ = query
if self.sample_ratio > 1:
assert tgt_len % self.sample_ratio == 0, \
("sample ratio %d is mismatched with length %d" % (self.sample_ratio, tgt_len))
if self.reduced_method == "conv":
query_ = query.transpose(1, 2) # bsz, dim, seq_len
query_ = self.sr(query_).transpose(1, 2) # bsz, seq_len, dim
query_ = self.norm(query_)
elif self.reduced_method == "pool":
query_ = query.transpose(1, 2) # bsz, dim, seq_len
pool_length = int(tgt_len / self.sample_ratio)
query_ = nn.functional.adaptive_max_pool1d(query_, pool_length).transpose(1, 2)
query_ = self.act(self.norm(query_))
key = value = query_
if key_padding_mask is not None:
key_padding_mask = key_padding_mask[:, ::self.sample_ratio]
n_batch = query.size(0)
q = self.linear_q(query).view(n_batch, -1, self.h, self.d_k)
k = self.linear_k(key).view(n_batch, -1, self.h, self.d_k)
v = self.linear_v(value).view(n_batch, -1, self.h, self.d_k)
q = q.transpose(1, 2) # (batch, head, time1, d_k)
k = k.transpose(1, 2) # (batch, head, time2, d_k)
v = v.transpose(1, 2) # (batch, head, time2, d_k)
# q, k, v = self.forward_qkv(query, key, value)
q = q.transpose(1, 2) # (batch, time1, head, d_k)
n_batch_pos = pos_emb.size(0)
p = self.linear_pos(pos_emb).view(n_batch_pos, -1, self.h, self.d_k)
p = p.transpose(1, 2) # (batch, head, 2*time1-1, d_k)
# (batch, head, time1, d_k)
q_with_bias_u = (q + self.pos_bias_u).transpose(1, 2)
# (batch, head, time1, d_k)
q_with_bias_v = (q + self.pos_bias_v).transpose(1, 2)
# compute attention score
# first compute matrix a and matrix c
# as described in https://arxiv.org/abs/1901.02860 Section 3.3
# (batch, head, time1, time2)
matrix_ac = torch.matmul(q_with_bias_u, k.transpose(-2, -1))
# compute matrix b and matrix d
# (batch, head, time1, 2*time1-1)
matrix_bd = torch.matmul(q_with_bias_v, p.transpose(-2, -1))
matrix_bd = self.rel_shift(matrix_bd)
scores = (matrix_ac + matrix_bd) / math.sqrt(
self.d_k
) # (batch, head, time1, time2)
scores = self.forward_attention(v, scores, key_padding_mask)
scores = scores.transpose(0, 1)
return scores, None
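ReducedRelPositionMultiHeadedAttention shortens the key/value (and optionally the query) sequence by sample_ratio before attention, using either a strided Conv1d or adaptive max pooling. A stand-alone sketch of just that reduction step on a (batch, time, dim) tensor, illustrative only:

import torch
import torch.nn as nn

def reduce_sequence(x, sample_ratio=2, method="pool", conv=None):
    # x: (batch, time, dim); mirrors the "conv"/"pool" branches above (sketch only)
    b, t, d = x.size()
    assert t % sample_ratio == 0, "length must be divisible by sample_ratio"
    x = x.transpose(1, 2)                                   # (batch, dim, time)
    if method == "conv":
        conv = conv or nn.Conv1d(d, d, kernel_size=sample_ratio, stride=sample_ratio)
        x = conv(x)
    else:
        x = nn.functional.adaptive_max_pool1d(x, t // sample_ratio)
    return x.transpose(1, 2)                                # (batch, time // sample_ratio, dim)

print(reduce_sequence(torch.randn(2, 8, 16)).shape)         # torch.Size([2, 4, 16])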
class LegacyRelPositionMultiHeadedAttention(RelPositionMultiHeadedAttention):
    """Multi-Head Attention layer with relative position encoding (old version).
......
...@@ -12,6 +12,7 @@ from fairseq.modules import ( ...@@ -12,6 +12,7 @@ from fairseq.modules import (
ConvolutionModule, ConvolutionModule,
ESPNETMultiHeadedAttention, ESPNETMultiHeadedAttention,
RelPositionMultiHeadedAttention, RelPositionMultiHeadedAttention,
ReducedRelPositionMultiHeadedAttention,
LegacyRelPositionMultiHeadedAttention, LegacyRelPositionMultiHeadedAttention,
LocalMultiheadAttention, LocalMultiheadAttention,
ReducedMultiheadAttention, ReducedMultiheadAttention,
...@@ -91,6 +92,7 @@ class PDSTransformerEncoderLayer(nn.Module): ...@@ -91,6 +92,7 @@ class PDSTransformerEncoderLayer(nn.Module):
self.macaron_norm = None self.macaron_norm = None
self.ffn_scale = 1.0 self.ffn_scale = 1.0
self.conv_stride = conv_stride
if args.use_cnn_module: if args.use_cnn_module:
self.conv_norm = LayerNorm(embed_dim) self.conv_norm = LayerNorm(embed_dim)
self.conv_module = ConvolutionModule( self.conv_module = ConvolutionModule(
...@@ -104,7 +106,6 @@ class PDSTransformerEncoderLayer(nn.Module): ...@@ -104,7 +106,6 @@ class PDSTransformerEncoderLayer(nn.Module):
self.final_norm = LayerNorm(expand_embed_dim) self.final_norm = LayerNorm(expand_embed_dim)
# Convolution Residual # Convolution Residual
self.conv_stride = conv_stride
self.conv_res = nn.Sequential( self.conv_res = nn.Sequential(
Permute3D(1, 2, 0), Permute3D(1, 2, 0),
nn.Conv1d(embed_dim, expand_embed_dim, kernel_size=1, stride=conv_stride), nn.Conv1d(embed_dim, expand_embed_dim, kernel_size=1, stride=conv_stride),
...@@ -173,6 +174,15 @@ class PDSTransformerEncoderLayer(nn.Module): ...@@ -173,6 +174,15 @@ class PDSTransformerEncoderLayer(nn.Module):
attention_heads, attention_heads,
dropout=dropout, dropout=dropout,
) )
elif self.attn_type == "reduced_rel_pos":
return ReducedRelPositionMultiHeadedAttention(
embed_dim,
attention_heads,
dropout=dropout,
sample_ratio=sample_ratio,
reduced_method=getattr(args, "attention_reduced_method", "conv"),
reduced_q=getattr(args, "attention_reduced_q", False)
)
elif self.attn_type == "rel_pos_legacy": elif self.attn_type == "rel_pos_legacy":
return LegacyRelPositionMultiHeadedAttention( return LegacyRelPositionMultiHeadedAttention(
embed_dim, embed_dim,
...@@ -284,7 +294,7 @@ class PDSTransformerEncoderLayer(nn.Module):
        residual = x
        if self.normalize_before:
            x = self.self_attn_layer_norm(x)
-       if self.attn_type in ["rel_pos", "rel_pos_legacy", "rel_selfattn"]:
+       if self.attn_type in ["rel_pos", "rel_pos_legacy", "rel_selfattn", "reduced_rel_pos"]:
            assert pos_emb is not None, "Positions is necessary for RPE!"
            x, _ = self.self_attn(
                query=x,
......