Commit 67d8695f by xuchen

add target ctc

parent d4255246
arch: s2t_ctc
encoder-type: pds
#pds-ctc: 0_1_1_0
#intermedia-adapter: league
#intermedia-ctc-weight: 1
#encoder-attention-type: reduced
#pds-attn-ds-ratios: 4_2_1_1
#attention-reduced-method: pool
#attention-reduced-q: True
encoder-embed-dim: 240
pds-stages: 3
#ctc-layer: 15
pds-layers: 4_5_6
pds-ratios: 2_2_2
pds-fusion: False
pds-fusion-method: all_conv
pds-embed-dims: 120_168_240
pds-ds-method: conv
pds-embed-norm: True
pds-position-embed: 1_1_1
pds-kernel-sizes: 3_3_3
pds-ffn-ratios: 4_4_4
pds-attn-heads: 4_4_4
arch: s2t_sate
share-decoder-input-output-embed: True
optimizer: adam
clip-norm: 10.0
lr-scheduler: inverse_sqrt
warmup-init-lr: 1e-7
warmup-updates: 10000
lr: 0.0015
lr: 2e-3
adam_betas: (0.9,0.98)
criterion: ctc
post-process: sentencepiece
ctc-weight: 0.3
target-ctc-weight: 0.2
target-ctc-layers: 3,6
criterion: label_smoothed_cross_entropy_with_ctc
label_smoothing: 0.1
encoder-normalize-before: True
decoder-normalize-before: True
subsampling-type: conv1d
subsampling-layers: 2
subsampling-filter: 1024
subsampling-kernel: 5
subsampling-stride: 2
subsampling-norm: none
subsampling-activation: glu
dropout: 0.1
activation-fn: relu
encoder-layers: 15
encoder-embed-dim: 256
encoder-ffn-embed-dim: 2048
encoder-layers: 12
text-encoder-layers: 6
decoder-layers: 6
encoder-attention-heads: 4
decoder-embed-dim: 256
decoder-ffn-embed-dim: 2048
decoder-attention-heads: 4
macaron-style: True
use-cnn-module: True
cnn-module-kernel: 15
encoder-activation-fn: swish
encoder-attention-type: rel_pos
acoustic-encoder: transformer
adapter: league
#load-pretrained-encoder-from:
#load-pretrained-acoustic-encoder-from:
#load-pretrained-text-encoder-from:
#load-pretrained-decoder-from:
\ No newline at end of file
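Taken together, ctc-weight: 0.3 and target-ctc-weight: 0.2 feed the weighted sum that CtcCriterion computes later in this commit (loss = ctc_weight * ctc_loss + intermedia_ctc_weight * intermedia_ctc_loss + target_ctc_weight * target_ctc_loss + ...). A minimal sketch of that combination, with the unused weights left at their 0.0 defaults; the function name is illustrative, not the repo's API:

def all_ctc_loss(ctc_loss, intermedia_ctc_loss, target_ctc_loss,
                 ctc_self_distill_loss, ctc_entropy,
                 ctc_weight=0.3, intermedia_ctc_weight=0.0,
                 target_ctc_weight=0.2, ctc_self_distill_weight=0.0,
                 entropy_weight=0.0):
    # mirrors the sum in CtcCriterion below: each CTC flavour
    # contributes in proportion to its configured weight
    return (ctc_weight * ctc_loss
            + intermedia_ctc_weight * intermedia_ctc_loss
            + target_ctc_weight * target_ctc_loss
            + ctc_self_distill_weight * ctc_self_distill_loss
            + entropy_weight * ctc_entropy)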
......@@ -13,15 +13,15 @@ encoder-type: pds
encoder-embed-dim: 240
pds-stages: 3
#ctc-layer: 15
pds-layers: 4_5_6
pds-layers: 5_5_5
pds-ratios: 2_2_2
pds-fusion: False
pds-fusion: True
pds-fusion-method: all_conv
pds-embed-dims: 120_168_240
pds-ds-method: conv
pds-embed-norm: True
pds-position-embed: 1_1_1
pds-kernel-sizes: 3_3_3
pds-kernel-sizes: 5_5_5
pds-ffn-ratios: 4_4_4
pds-attn-heads: 4_4_4
......
arch: s2t_ctc
encoder-type: pds
#pds-ctc: 0_1_1_0
#intermedia-adapter: league
#intermedia-ctc-weight: 1
#intermedia-temperature: 5
encoder-attention-type: rel_pos
#encoder-attention-type: reduced_rel_pos
#pds-attn-ds-ratios: 4_2_2_1
#attention-reduced-method: pool
#attention-reduced-q: True
encoder-embed-dim: 512
pds-stages: 4
#ctc-layer: 15
encoder-layers: 10
pds-layers: 3_2_2_3
pds-ratios: 2_2_1_2
pds-fusion: True
pds-fusion-method: all_conv
pds-embed-dims: 256_384_384_512
pds-ds-method: conv
pds-embed-norm: True
pds-position-embed: 1_1_1_1
pds-kernel-sizes: 5_5_5_5
pds-ffn-ratios: 8_4_4_4
pds-attn-heads: 4_6_6_8
optimizer: adam
clip-norm: 10.0
lr-scheduler: inverse_sqrt
warmup-init-lr: 1e-7
warmup-updates: 10000
lr: 0.002
adam_betas: (0.9,0.98)
criterion: ctc
post-process: sentencepiece
dropout: 0.1
activation-fn: relu
macaron-style: True
use-cnn-module: True
cnn-module-kernel: 31
encoder-activation-fn: swish
#load-pretrained-encoder-from:
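All pds-* options are underscore-separated per-stage lists: with pds-stages: 4, pds-ratios: 2_2_1_2 gives four encoder stages and an overall temporal downsampling of 2 * 2 * 1 * 2 = 8. A small helper in the same style as the encoder's own parsing (it splits on "_" exactly as PDSS2TTransformerEncoder does below); the helper name is illustrative:

def parse_pds_option(value, pds_stages):
    # split an option such as "2_2_1_2" into one integer per stage
    values = [int(n) for n in value.split("_")]
    assert len(values) == pds_stages, "one value per stage is required"
    return values

ratios = parse_pds_option("2_2_1_2", 4)
total = 1
for r in ratios:
    total *= r
assert total == 8  # overall downsampling of this config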
......@@ -8,7 +8,7 @@ with open(in_file, "r", encoding="utf-8") as f:
for line in f.readlines():
line = line.strip().lower()
for w in string.punctuation:
line = line.replace(w, "")
line = line.replace(" ", "")
if w != "'":
line = line.replace(w, "")
line = line.replace("  ", " ")
print(line)
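The hunk above changes local/lower_rm.py to keep apostrophes when stripping punctuation and to normalize doubled spaces instead of deleting them. A self-contained reconstruction of the script after this commit, assuming the lines outside the hunk:

import string
import sys

in_file = sys.argv[1]

with open(in_file, "r", encoding="utf-8") as f:
    for line in f.readlines():
        line = line.strip().lower()
        for w in string.punctuation:
            if w != "'":  # keep apostrophes, e.g. in "don't"
                line = line.replace(w, "")
        line = line.replace("  ", " ")  # collapse doubled spaces
        print(line)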
......@@ -44,10 +44,10 @@ lcrm=1
tokenizer=1
use_specific_dict=1
specific_prefix=asr5k_st10k
specific_dir=${root_dir}/data/iwslt2022/st_lcrm_asr
src_vocab_prefix=spm_unigram5000_asr
tgt_vocab_prefix=spm_unigram10000_st
specific_prefix=unified
specific_dir=${root_dir}/data/wmt20/vocab
src_vocab_prefix=spm_en
tgt_vocab_prefix=spm_zh
org_data_dir=${root_dir}/data/${dataset}
data_dir=${root_dir}/data/${dataset}/mt
......@@ -141,6 +141,9 @@ if [ ${stage} -le 0 ] && [ ${stop_stage} -ge 0 ]; then
if [[ ! -e ${data_dir} ]]; then
mkdir -p ${data_dir}
fi
if [[ ! -e ${data_dir}/data ]]; then
mkdir -p ${data_dir}/data
fi
if [[ ! -f ${data_dir}/${src_vocab_prefix}.txt || ! -f ${data_dir}/${tgt_vocab_prefix}.txt ]]; then
if [[ ${use_specific_dict} -eq 0 ]]; then
......@@ -154,52 +157,31 @@ if [ ${stage} -le 0 ] && [ ${stop_stage} -ge 0 ]; then
--tgt-vocab-type ${tgt_vocab_type}
--src-vocab-size ${src_vocab_size}
--tgt-vocab-size ${tgt_vocab_size}"
if [[ $share_dict -eq 1 ]]; then
cmd="$cmd
--share"
fi
echo -e "\033[34mRun command: \n${cmd} \033[0m"
[[ $eval -eq 1 ]] && eval ${cmd}
else
cp -r ${specific_dir}/${src_vocab_prefix}.* ${data_dir}
cp ${specific_dir}/${tgt_vocab_prefix}.* ${data_dir}
fi
fi
mkdir -p ${data_dir}/data
for split in ${train_subset} ${valid_subset} ${trans_subset}; do
{
if [[ -d ${org_data_dir}/data/${split}/txt ]]; then
text_dir=${org_data_dir}/data/${split}/txt
else
text_dir=${org_data_dir}/data/${split}
cmd="python ${code_dir}/examples/speech_to_text/prep_mt_data.py
--data-root ${org_data_dir}
--output-root ${data_dir}
--splits ${train_subset},${valid_subset},${trans_subset}
--src-lang ${src_lang}
--tgt-lang ${tgt_lang}
--src-vocab-prefix ${src_vocab_prefix}
--tgt-vocab-prefix ${tgt_vocab_prefix}"
fi
if [[ $share_dict -eq 1 ]]; then
cmd="$cmd
--share"
fi
src_text=${text_dir}/${split}.${src_lang}
tgt_text=${text_dir}/${split}.${tgt_lang}
cmd="cat ${src_text}"
if [[ ${lcrm} -eq 1 ]]; then
cmd="python local/lower_rm.py ${src_text}"
cmd="$cmd
--lowercase-src
--rm-punc-src"
fi
cmd="${cmd}
| spm_encode --model ${data_dir}/${src_vocab_prefix}.model
--output_format=piece
> ${data_dir}/data/${split}.${src_lang}"
echo -e "\033[34mRun command: \n${cmd} \033[0m"
[[ $eval -eq 1 ]] && eval ${cmd}
cmd="spm_encode
--model ${data_dir}/${tgt_vocab_prefix}.model
--output_format=piece
< ${tgt_text}
> ${data_dir}/data/${split}.${tgt_lang}"
echo -e "\033[34mRun command: \n${cmd} \033[0m"
[[ $eval -eq 1 ]] && eval ${cmd}
}&
done
wait
fi
cmd="python ${code_dir}/fairseq_cli/preprocess.py
--source-lang ${src_lang} --target-lang ${tgt_lang}
......
......@@ -296,6 +296,9 @@ def process(args):
gen_manifest_flag = True
break
punctuation_str = string.punctuation
punctuation_str = punctuation_str.replace("'", "")
train_text = []
if args.overwrite or gen_manifest_flag:
if not use_raw:
......@@ -340,7 +343,7 @@ def process(args):
if args.lowercase_src:
src_utt = src_utt.lower()
if args.rm_punc_src:
for w in string.punctuation:
for w in punctuation_str:
src_utt = src_utt.replace(w, "")
src_utt = " ".join(src_utt.split(" "))
else:
......@@ -414,7 +417,7 @@ def process(args):
if args.lowercase_src:
src_utt = src_utt.lower()
if args.rm_punc_src:
for w in string.punctuation:
for w in punctuation_str:
src_utt = src_utt.replace(w, "")
src_utt = " ".join(src_utt.split(" "))
train_text.append(src_utt)
......
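The same apostrophe-preserving rule drives the punctuation_str changes above. Since Python strings are immutable, str.replace returns a new string and must be reassigned; a two-line check of the intended behaviour:

import string

punctuation_str = string.punctuation.replace("'", "")
assert "'" in string.punctuation and "'" not in punctuation_str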
......@@ -11,6 +11,7 @@ from pathlib import Path
from tempfile import NamedTemporaryFile
from typing import Tuple
import string
import sentencepiece as spm
from examples.speech_to_text.data_utils import (
gen_vocab,
......@@ -62,6 +63,8 @@ def process(args):
splits = args.splits.split(",")
src_train_text = []
tgt_train_text = []
manifest = {c: [] for c in MANIFEST_COLUMNS}
sent_num = [0]
lang = f"{src_lang}-{tgt_lang}"
cur_root = Path(args.data_root).absolute() / lang
......@@ -70,20 +73,22 @@ def process(args):
else:
output_root = Path(args.output_root).absolute()
punctuation_str = string.punctuation
punctuation_str = punctuation_str.replace("'", "")
# Generate TSV manifest
print("Generating manifest...")
for split in splits:
is_train_split = split.startswith("train")
manifest = {c: [] for c in MANIFEST_COLUMNS}
dataset = MTDataset(args.data_root, src_lang, tgt_lang, split, args.tokenizer)
for src_text, tgt_text in tqdm(dataset):
if args.lowercase_src:
src_text = src_text.lower()
if args.rm_punc_src:
for w in string.punctuation:
for w in punctuation_str:
src_text = src_text.replace(w, "")
src_text = src_text.replace(" ", "")
src_text = src_text.replace("  ", " ")
manifest["src_text"].append(src_text)
manifest["tgt_text"].append(tgt_text)
......@@ -94,34 +99,50 @@ def process(args):
if is_train_split:
src_train_text.extend(manifest["src_text"])
tgt_train_text.extend(manifest["tgt_text"])
sent_num.append(len(manifest["src_text"]))
# Generate vocab and yaml
print("Generating vocabulary...")
tgt_v_size_str = "" if args.tgt_vocab_type == "char" else str(args.tgt_vocab_size)
tgt_spm_filename_prefix = f"spm_{args.tgt_vocab_type}{tgt_v_size_str}"
if args.share:
tgt_train_text.extend(src_train_text)
tgt_spm_filename_prefix = tgt_spm_filename_prefix + "_share"
if args.tgt_vocab_prefix is not None:
tgt_spm_filename_prefix = args.tgt_vocab_prefix
else:
tgt_train_text.extend(src_train_text)
tgt_spm_filename_prefix = tgt_spm_filename_prefix + "_share"
src_spm_filename_prefix = tgt_spm_filename_prefix
else:
src_v_size_str = "" if args.src_vocab_type == "char" else str(args.src_vocab_size)
src_spm_filename_prefix = f"spm_{args.src_vocab_type}{src_v_size_str}"
src_spm_filename_prefix = src_spm_filename_prefix + "_" + src_lang
tgt_spm_filename_prefix = tgt_spm_filename_prefix + "_" + tgt_lang
with NamedTemporaryFile(mode="w") as f:
for t in tgt_train_text:
f.write(t + "\n")
gen_vocab(
Path(f.name),
output_root / tgt_spm_filename_prefix,
args.tgt_vocab_type,
args.tgt_vocab_size,
normalization_rule_name="identity" if tgt_lang == "zh" else None
)
if not args.share:
if args.tgt_vocab_prefix is not None:
tgt_spm_filename_prefix = args.tgt_vocab_prefix
else:
tgt_spm_filename_prefix = tgt_spm_filename_prefix + "_" + tgt_lang
if args.src_vocab_prefix is not None:
src_spm_filename_prefix = args.src_vocab_prefix
else:
src_v_size_str = "" if args.src_vocab_type == "char" else str(args.src_vocab_size)
src_spm_filename_prefix = f"spm_{args.src_vocab_type}{src_v_size_str}"
src_spm_filename_prefix = src_spm_filename_prefix + "_" + src_lang
src_spm_model = (output_root / (src_spm_filename_prefix + ".model")).as_posix()
tgt_spm_model = (output_root / (tgt_spm_filename_prefix + ".model")).as_posix()
if not os.path.exists(tgt_spm_model):
with NamedTemporaryFile(mode="w") as f:
for t in tgt_train_text:
f.write(t + "\n")
gen_vocab(
Path(f.name),
output_root / tgt_spm_filename_prefix,
args.tgt_vocab_type,
args.tgt_vocab_size,
normalization_rule_name="identity" if tgt_lang == "zh" else None
)
if not args.share and not os.path.exists(src_spm_model):
with NamedTemporaryFile(mode="w") as f:
for t in src_train_text:
f.write(t + "\n")
......@@ -133,6 +154,38 @@ def process(args):
normalization_rule_name="identity" if tgt_lang == "zh" else None
)
# Generate sentencepiece
print("Applying sentencepiece...")
tgt_sp = spm.SentencePieceProcessor()
tgt_sp.Load(tgt_spm_model)
if args.share:
src_sp = tgt_sp
else:
src_sp = spm.SentencePieceProcessor()
src_sp.Load(src_spm_model)
index = 0
for split in splits:
src_text = manifest["src_text"][sent_num[index]: sent_num[index + 1]]
tgt_text = manifest["tgt_text"][sent_num[index]: sent_num[index + 1]]
index += 1
src_spm_name = (output_root / "data" / (split + "." + src_lang)).as_posix()
tgt_spm_name = (output_root / "data" / (split + "." + tgt_lang)).as_posix()
with open(src_spm_name, 'w') as f:
for sentence in src_text:
pieces = src_sp.EncodeAsPieces(sentence)
result = " ".join(pieces)
f.write(result + "\n")
with open(tgt_spm_name, 'w') as f:
for sentence in tgt_text:
pieces = tgt_sp.EncodeAsPieces(sentence)
result = " ".join(pieces)
f.write(result + "\n")
# Generate config YAML
yaml_filename = f"config.yaml"
if args.share:
......@@ -162,19 +215,19 @@ def main():
parser.add_argument(
"--src-vocab-type",
default="unigram",
required=True,
type=str,
choices=["bpe", "unigram", "char"],
)
parser.add_argument(
"--tgt-vocab-type",
default="unigram",
required=True,
type=str,
choices=["bpe", "unigram", "char"],
)
parser.add_argument("--src-vocab-size", default=10000, type=int)
parser.add_argument("--tgt-vocab-size", default=10000, type=int)
parser.add_argument("--src-vocab-prefix", default=None, type=str, help="prefix of the specific source vocabulary")
parser.add_argument("--tgt-vocab-prefix", default=None, type=str, help="prefix of the specific target vocabulary")
parser.add_argument("--size", default=-1, type=int)
parser.add_argument("--splits", default="train,dev,test", type=str)
parser.add_argument("--lowercase-src", action="store_true", help="lowercase the source text")
......
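The new block in prep_mt_data.py applies the trained models in-process with the standard sentencepiece Python API (Load, EncodeAsPieces) instead of shelling out to spm_encode. A minimal standalone usage; the model filename is illustrative:

import sentencepiece as spm

sp = spm.SentencePieceProcessor()
sp.Load("spm_unigram10000_st.model")  # illustrative path

pieces = sp.EncodeAsPieces("this is a test")
print(" ".join(pieces))  # e.g. "▁this ▁is ▁a ▁test"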
......@@ -9,6 +9,8 @@ from argparse import Namespace
from dataclasses import dataclass, field
from omegaconf import II
from typing import Optional
import numpy as np
import logging
import torch
import torch.nn.functional as F
......@@ -19,6 +21,7 @@ from fairseq.data.data_utils import post_process
from fairseq.tasks import FairseqTask
from fairseq.logging.meters import safe_round
logger = logging.getLogger(__name__)
@dataclass
class CtcCriterionConfig(FairseqDataclass):
......@@ -31,8 +34,8 @@ class CtcCriterionConfig(FairseqDataclass):
default="sentencepiece",
metadata={
"help": "how to post process predictions into words. can be letter, "
"wordpiece, BPE symbols, etc. "
"See fairseq.data.data_utils.post_process() for full list of options"
"wordpiece, BPE symbols, etc. "
"See fairseq.data.data_utils.post_process() for full list of options"
},
)
ctc_entropy: float = field(
......@@ -43,6 +46,10 @@ class CtcCriterionConfig(FairseqDataclass):
default=0.0,
metadata={"help": "weight of intermedia CTC loss"},
)
target_ctc_weight: float = field(
default=0.0,
metadata={"help": "weight of intermedia CTC loss for target sentence"},
)
ctc_self_distill_weight: float = field(
default=0.0,
metadata={"help": "weight of the self distillation CTC loss"},
......@@ -116,10 +123,12 @@ class CtcCriterion(FairseqCriterion):
self.ctc_weight = ctc_weight
self.intermedia_ctc_weight = cfg.intermedia_ctc_weight
self.target_ctc_weight = cfg.target_ctc_weight
self.ctc_self_distill_weight = cfg.ctc_self_distill_weight
self.ctc_entropy = cfg.ctc_entropy
self.all_ctc_weight = self.ctc_weight + self.intermedia_ctc_weight + self.ctc_self_distill_weight + self.ctc_entropy
self.all_ctc_weight = self.ctc_weight + self.intermedia_ctc_weight + self.target_ctc_weight + \
self.ctc_self_distill_weight + self.ctc_entropy
if self.all_ctc_weight > 0:
assert getattr(task, "src_dict", None) is not None, "CTC need a source dictionary."
self.ctc_loss = torch.nn.CTCLoss(blank=self.blank_idx, reduction="sum", zero_infinity=True)
......@@ -145,7 +154,7 @@ class CtcCriterion(FairseqCriterion):
non_padding_mask = ~net_output["ctc_padding_mask"][0]
else:
non_padding_mask = ~net_output["encoder_padding_mask"][0]
input_lengths = non_padding_mask.long().sum(-1)
ctc_input_lengths = input_lengths = non_padding_mask.long().sum(-1)
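# keep an alias of the encoder-side lengths: the target CTC branch
# below consumes ctc_input_lengths even where input_lengths is rebound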
pad_mask = (transcript["tokens"] != self.pad_idx) & (
transcript["tokens"] != self.eos_idx
......@@ -215,6 +224,43 @@ class CtcCriterion(FairseqCriterion):
if lprobs is None:
lprobs = inter_lprobs
target_ctc_num = 0
target_ctc_loss = 0
if "target_ctc_logits" in net_output:
target_ctc_num = len(net_output["target_ctc_logits"])
# calculate the target CTC loss
if self.target_ctc_weight > 0 and target_ctc_num > 0:
target = sample["target"]
pad_mask = (target != self.pad_idx) & (target != self.eos_idx)
targets_flat = target.masked_select(pad_mask)
target_length = pad_mask.sum(-1)
for i in range(target_ctc_num):
out = net_output["target_ctc_logits"][i]
if type(out) == list:
inter_ctc_logit = out[0]
padding = ~out[1]
input_lengths = padding.long().sum(-1)
else:
inter_ctc_logit = out
inter_lprobs = model.get_normalized_probs(
[inter_ctc_logit], log_probs=True
).contiguous() # (T, B, C) from the encoder
inter_lprobs.batch_first = False
with torch.backends.cudnn.flags(enabled=False):
loss = self.ctc_loss(
inter_lprobs,
targets_flat,
ctc_input_lengths,
target_length,
)
target_ctc_loss += loss
target_ctc_loss /= target_ctc_num
logging_output["target_ctc_loss"] = utils.item(target_ctc_loss.data)
# calculate the self distillation CTC loss
ctc_self_distill_loss = 0
ctc_self_distill_num = 0
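The target CTC term above reuses the criterion's torch.nn.CTCLoss (built with blank=self.blank_idx, reduction="sum", zero_infinity=True) with translation tokens as labels instead of transcripts. A self-contained sketch of the pattern, with illustrative shapes:

import torch

T, B, V, S = 50, 2, 100, 12  # frames, batch, vocab size, target length
ctc = torch.nn.CTCLoss(blank=0, reduction="sum", zero_infinity=True)

lprobs = torch.randn(T, B, V).log_softmax(-1)          # (T, B, V) log-probs
targets = torch.randint(1, V, (B, S))                  # no blank/pad ids here
targets_flat = targets.reshape(-1)                     # as masked_select yields
target_lengths = torch.full((B,), S, dtype=torch.long)
input_lengths = torch.full((B,), T, dtype=torch.long)  # encoder lengths

with torch.backends.cudnn.flags(enabled=False):        # as the criterion does
    loss = ctc(lprobs, targets_flat, input_lengths, target_lengths)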
......@@ -247,6 +293,7 @@ class CtcCriterion(FairseqCriterion):
loss = \
self.ctc_weight * ctc_loss + \
self.intermedia_ctc_weight * intermedia_ctc_loss + \
self.target_ctc_weight * target_ctc_loss + \
self.ctc_self_distill_weight * ctc_self_distill_loss + \
self.ctc_entropy * ctc_entropy
......@@ -264,9 +311,9 @@ class CtcCriterion(FairseqCriterion):
w_len = 0
wv_errs = 0
for lp, t, inp_l in zip(
lprobs_t,
sample["transcript"]["tokens"] if "transcript" in sample else sample["target"],
input_lengths,
lprobs_t,
sample["transcript"]["tokens"] if "transcript" in sample else sample["target"],
input_lengths,
):
lp = lp[:inp_l].unsqueeze(0)
......@@ -283,7 +330,7 @@ class CtcCriterion(FairseqCriterion):
decoded = decoded[0]
p = (t != self.task.target_dictionary.pad()) & (
t != self.task.target_dictionary.eos()
t != self.task.target_dictionary.eos()
)
targ = t[p]
targ_units = self.task.target_dictionary.string(targ)
......@@ -332,6 +379,9 @@ class CtcCriterion(FairseqCriterion):
inter_ctc_loss_sum = utils.item(
sum(log.get("intermedia_ctc_loss", 0) for log in logging_outputs)
)
target_ctc_loss_sum = utils.item(
sum(log.get("target_ctc_loss", 0) for log in logging_outputs)
)
ctc_self_distill_loss_sum = utils.item(
sum(log.get("ctc_self_distill_loss", 0) for log in logging_outputs)
)
......@@ -346,6 +396,9 @@ class CtcCriterion(FairseqCriterion):
sample_size = utils.item(
sum(log.get("sample_size", 0) for log in logging_outputs)
)
if np.isnan(all_ctc_loss_sum) or np.isinf(all_ctc_loss_sum) or all_ctc_loss_sum < 0:
logger.error("Illegal loss %f!" % all_ctc_loss_sum)
if all_ctc_loss_sum > 0:
if "loss" not in logging_outputs[0]:
metrics.log_scalar(
......@@ -383,6 +436,14 @@ class CtcCriterion(FairseqCriterion):
sample_size,
round=3,
)
if target_ctc_loss_sum > 0:
metrics.log_scalar(
"target_ctc_loss",
target_ctc_loss_sum / sample_size / math.log(2),
sample_size,
round=3,
)
if ctc_self_distill_loss_sum > 0:
metrics.log_scalar(
"ctc_self_distill_loss",
......@@ -404,8 +465,8 @@ class CtcCriterion(FairseqCriterion):
metrics.log_scalar("_c_total", c_total)
w_errors = sum(log.get("w_errors", 0) for log in logging_outputs)
metrics.log_scalar("_w_errors", w_errors)
wv_errors = sum(log.get("wv_errors", 0) for log in logging_outputs)
metrics.log_scalar("_wv_errors", wv_errors)
# wv_errors = sum(log.get("wv_errors", 0) for log in logging_outputs)
# metrics.log_scalar("_wv_errors", wv_errors)
w_total = sum(log.get("w_total", 0) for log in logging_outputs)
metrics.log_scalar("_w_total", w_total)
......@@ -427,14 +488,14 @@ class CtcCriterion(FairseqCriterion):
if meters["_w_total"].sum > 0
else float("nan"),
)
metrics.log_derived(
"raw_wer",
lambda meters: safe_round(
meters["_wv_errors"].sum * 100.0 / meters["_w_total"].sum, 3
)
if meters["_w_total"].sum > 0
else float("nan"),
)
# metrics.log_derived(
# "raw_wer",
# lambda meters: safe_round(
# meters["_wv_errors"].sum * 100.0 / meters["_w_total"].sum, 3
# )
# if meters["_w_total"].sum > 0
# else float("nan"),
# )
@staticmethod
def logging_outputs_can_be_summed() -> bool:
......
......@@ -17,6 +17,7 @@ class CTC(nn.Module):
def __init__(self, embed_dim, dictionary_size, dropout, need_layernorm=False):
super(CTC, self).__init__()
self.embed_dim = embed_dim
self.ctc_projection = nn.Linear(embed_dim, dictionary_size, bias=False)
nn.init.normal_(
......
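The hunk only shows the projection line being added to the CTC module; for context, a minimal stand-in for what such a head amounts to (hypothetical, not the repo's exact class):

import torch.nn as nn

class CTCHead(nn.Module):
    # project encoder states (T, B, C) to vocabulary logits (T, B, V)
    def __init__(self, embed_dim, dictionary_size, dropout=0.1):
        super().__init__()
        self.embed_dim = embed_dim
        self.dropout = nn.Dropout(dropout)
        self.ctc_projection = nn.Linear(embed_dim, dictionary_size, bias=False)

    def forward(self, x):
        return self.ctc_projection(self.dropout(x))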
......@@ -232,6 +232,7 @@ class PDSS2TTransformerModel(S2TTransformerModel):
"rope",
"abs",
"transfer",
"reduced_rel_pos",
],
help="transformer encoder self-attention layer type"
)
......@@ -579,6 +580,12 @@ class PDSS2TTransformerModel(S2TTransformerModel):
type=float,
help="probability of dropping the followed layers",
)
parser.add_argument(
"--intermedia-temperature",
default=1,
type=float,
help="temperature of the intermedia ctc probability",
)
pass
@classmethod
......@@ -626,10 +633,7 @@ class PDSS2TTransformerEncoder(FairseqEncoder):
self.pds_position_embed = [int(n) for n in args.pds_position_embed.split("_")]
self.pds_attn_heads = [int(n) for n in args.pds_attn_heads.split("_")]
self.pds_ffn_ratios = [int(n) for n in args.pds_ffn_ratios.split("_")]
if self.attn_type == "reduced":
self.pds_attn_ds_ratios = [int(n) for n in args.pds_attn_ds_ratios.split("_")]
else:
self.pds_attn_ds_ratios = None
self.pds_attn_ds_ratios = [int(n) for n in args.pds_attn_ds_ratios.split("_")]
self.pds_conv_strides = [int(n) for n in args.pds_conv_strides.split("_")]
self.pds_attn_strides = [int(n) for n in args.pds_attn_strides.split("_")]
......@@ -674,7 +678,7 @@ class PDSS2TTransformerEncoder(FairseqEncoder):
ffn_ratio = self.pds_ffn_ratios[i]
num_head = self.pds_attn_heads[i]
attn_ds_ratio = self.pds_attn_ds_ratios[i] if self.attn_type == "reduced" else -1
attn_ds_ratio = self.pds_attn_ds_ratios[i] # if self.attn_type == "reduced" else -1
conv_stride = self.pds_conv_strides[i]
attn_stride = self.pds_attn_strides[i]
if conv_stride != 1 or attn_stride != 1:
......@@ -712,7 +716,7 @@ class PDSS2TTransformerEncoder(FairseqEncoder):
# position encoding
if use_pos_embed:
if self.attn_type == "rel_pos":
if self.attn_type in ["rel_pos", "reduced_rel_pos"]:
pos_embed = RelPositionalEncoding(
args.max_source_positions, embed_dim
)
......@@ -850,7 +854,7 @@ class PDSS2TTransformerEncoder(FairseqEncoder):
if ctc_layer <= 0:
embed_dim = self.pds_embed_dims[i]
break
if inter_ctc_module is None:
if inter_ctc_module is None or embed_dim != inter_ctc_module.embed_dim:
self.ctc = CTC(embed_dim,
dictionary_size=len(task.source_dictionary),
dropout=args.dropout,
......@@ -866,6 +870,7 @@ class PDSS2TTransformerEncoder(FairseqEncoder):
else:
self.layer_norm = None
self.intermedia_temperature = getattr(args, "intermedia_temperature", 1)
self.gather_cos_sim = getattr(args, "gather_cos_sim", False)
self.dis = 2
self.cos_sim = dict()
......@@ -933,7 +938,7 @@ class PDSS2TTransformerEncoder(FairseqEncoder):
# add the position encoding and dropout
if pos_embed:
if self.attn_type in ["rel_pos", "rel_pos_legacy", "rel_selfattn"]:
if self.attn_type in ["rel_pos", "rel_pos_legacy", "rel_selfattn", "reduced_rel_pos"]:
positions = pos_embed(x)
elif self.attn_type == "rope":
......@@ -981,7 +986,7 @@ class PDSS2TTransformerEncoder(FairseqEncoder):
logit = ctc(x.clone())
intermedia_ctc_logits.append([logit, encoder_padding_mask])
prob = utils.softmax(logit, dim=-1)
prob = utils.softmax(logit / self.intermedia_temperature, dim=-1)
x, encoder_padding_mask = adapter([x, prob], encoder_padding_mask)
if self.fusion_stages_num != 0:
......@@ -1131,9 +1136,9 @@ def base_architecture(args):
args.pds_position_embed = getattr(args, "pds_position_embed", None)
args.pds_attn_heads = getattr(args, "pds_attn_heads", None)
args.pds_attn_ds_ratios = getattr(args, "pds_attn_ds_ratios", None)
args.pds_ffn_ratios = getattr(args, "pds_ffn_ratios", None)
args.pds_attn_ds_ratios = getattr(args, "pds_attn_ds_ratios", "1_1_1_1")
args.pds_conv_strides = getattr(args, "pds_conv_strides", "1_1_1_1")
args.pds_attn_strides = getattr(args, "pds_attn_strides", "1_1_1_1")
......
......@@ -118,6 +118,7 @@ class S2TCTCModel(FairseqEncoderModel):
"rope",
"abs",
"transfer",
"reduced_rel_pos",
],
help="transformer encoder self-attention layer type"
)
......@@ -739,9 +740,9 @@ def base_architecture(args):
args.pds_position_embed = getattr(args, "pds_position_embed", None)
args.pds_attn_heads = getattr(args, "pds_attn_heads", None)
args.pds_attn_ds_ratios = getattr(args, "pds_attn_ds_ratios", None)
args.pds_ffn_ratios = getattr(args, "pds_ffn_ratios", None)
args.pds_attn_ds_ratios = getattr(args, "pds_attn_ds_ratios", "1_1_1_1")
args.pds_conv_strides = getattr(args, "pds_conv_strides", "1_1_1_1")
args.pds_attn_strides = getattr(args, "pds_attn_strides", "1_1_1_1")
......
......@@ -395,6 +395,12 @@ class S2TTransformerModel(FairseqEncoderDecoderModel):
type=float,
help="probability of dropping the followed layers",
)
parser.add_argument(
"--intermedia-temperature",
default=1,
type=float,
help="temperature of the intermedia ctc probability",
)
pass
@classmethod
......@@ -585,6 +591,7 @@ class S2TTransformerEncoder(FairseqEncoder):
self.adapter = Adapter(dim, args.intermedia_adapter,
task.source_dictionary, strategy=strategy)
self.intermedia_drop_prob = getattr(args, "intermedia_drop_prob", 0)
self.intermedia_temperature = getattr(args, "intermedia_temperature", 1)
@staticmethod
def pooling_ratio():
......@@ -683,7 +690,7 @@ class S2TTransformerEncoder(FairseqEncoder):
intermedia_ctc_logits.append(logit)
# prob = self.ctc.softmax(norm_x)
prob = utils.softmax(logit, dim=-1)
prob = utils.softmax(logit / self.intermedia_temperature, dim=-1)
x, encoder_padding_mask = self.adapter([x, prob], encoder_padding_mask)
# gather cosine similarity
......
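Both encoders now divide the intermedia CTC logits by --intermedia-temperature before the softmax that feeds the adapter; temperatures above 1 flatten the distribution the adapter sees. A toy illustration:

import torch
import torch.nn.functional as F

logit = torch.tensor([2.0, 1.0, 0.1])
for temperature in (1.0, 5.0):
    print(temperature, F.softmax(logit / temperature, dim=-1))
# temperature 5.0 yields a much flatter distribution than 1.0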
......@@ -54,6 +54,7 @@ from .positional_encoding import (
from .espnet_multihead_attention import (
ESPNETMultiHeadedAttention,
RelPositionMultiHeadedAttention,
ReducedRelPositionMultiHeadedAttention,
LegacyRelPositionMultiHeadedAttention,
RotaryPositionMultiHeadedAttention,
)
......@@ -113,10 +114,11 @@ __all__ = [
"unfold1d",
"ESPNETMultiHeadedAttention",
"PositionalEmbedding",
"RelPositionMultiHeadedAttention",
"PositionalEncoding",
"LegacyRelPositionalEncoding",
"RelPositionalEncoding",
"RelPositionMultiHeadedAttention",
"ReducedRelPositionMultiHeadedAttention",
"LegacyRelPositionMultiHeadedAttention",
"RotaryPositionalEmbedding",
"RotaryPositionMultiHeadedAttention",
......
......@@ -1347,7 +1347,7 @@ class MultiHeadSelfAttentionModule(nn.Module):
# Assert
assert not (group_size > 1 and kernel_size is not None), "Local grouped attention not implemented"
assert not (group_size > 1 and stride > 1 is not None), "Strided grouped attention not implemented"
assert not (group_size > 1 and stride > 1), "Strided grouped attention not implemented"
assert not (linear_att and relative_pos_enc), "Linear attention requires absolute positional encodings"
# Pre Norm
......
......@@ -14,6 +14,7 @@ from fairseq.modules.rotary_positional_embedding import (
RotaryPositionalEmbedding,
apply_rotary_pos_emb,
)
from .layer_norm import LayerNorm
class ESPNETMultiHeadedAttention(nn.Module):
......@@ -72,6 +73,7 @@ class ESPNETMultiHeadedAttention(nn.Module):
if mask is not None:
scores = scores.masked_fill(
mask.unsqueeze(1).unsqueeze(2).to(bool),
# -1e8 if scores.dtype == torch.float32 else -1e4
float("-inf"), # (batch, head, time1, time2)
)
self.attn = F.softmax(scores, dim=-1, dtype=torch.float32).type_as(scores) # (batch, head, time1, time2)
......@@ -195,6 +197,131 @@ class RelPositionMultiHeadedAttention(ESPNETMultiHeadedAttention):
return scores, None
class ReducedRelPositionMultiHeadedAttention(RelPositionMultiHeadedAttention):
"""Multi-Head Attention layer with relative position encoding.
Paper: https://arxiv.org/abs/1901.02860
Args:
n_head: The number of heads.
n_feat: The number of features.
dropout: Dropout rate.
zero_triu: Whether to zero the upper triangular part of attention matrix.
"""
def __init__(self, n_feat, n_head, dropout, zero_triu=False,
sample_ratio=1,
reduced_method="conv",
reduced_q=False,
):
"""Construct an RelPositionMultiHeadedAttention object."""
super().__init__(n_feat, n_head, dropout)
self.zero_triu = zero_triu
# linear transformation for positional encoding
self.linear_pos = nn.Linear(n_feat, n_feat, bias=False)
# these two learnable bias are used in matrix c and matrix d
# as described in https://arxiv.org/abs/1901.02860 Section 3.3
self.pos_bias_u = nn.Parameter(torch.Tensor(self.h, self.d_k))
self.pos_bias_v = nn.Parameter(torch.Tensor(self.h, self.d_k))
torch.nn.init.xavier_uniform_(self.pos_bias_u)
torch.nn.init.xavier_uniform_(self.pos_bias_v)
super().__init__(n_feat, n_head, dropout, zero_triu)
self.sample_ratio = sample_ratio
self.reduced_method = reduced_method
self.reduced_q = reduced_q
if reduced_q:
assert self.reduced_method == 'group', "only the grouped method supports query reduction"
if self.sample_ratio > 1:
if reduced_method == "conv":
self.sr = nn.Conv1d(n_feat, n_feat,
kernel_size=sample_ratio,
stride=sample_ratio,
)
self.norm = LayerNorm(n_feat)
elif reduced_method == "pool":
self.linear = nn.Linear(n_feat, n_feat)
self.norm = LayerNorm(n_feat)
self.act = nn.GELU()
elif reduced_method == "group":
pass
def forward(self, query, key, value, pos_emb, key_padding_mask=None, **kwargs):
"""Compute scaled dot product attention.
Args:
query: Query tensor T X B X C
key: Key tensor T X B X C
value: Value tensor T X B X C
pos_emb: Positional embedding tensor 2T-1 X B(1) X C
key_padding_mask: Mask tensor T X B
Returns:
torch.Tensor: Output tensor T X B X C.
"""
# (bsz, seq_len, dim)
query = query.transpose(0, 1)
key = key.transpose(0, 1)
value = value.transpose(0, 1)
pos_emb = pos_emb.transpose(0, 1)
tgt_len = query.size(1)
query_ = query
if self.sample_ratio > 1:
assert tgt_len % self.sample_ratio == 0, \
("sample ratio %d is mismatched with length %d" % (self.sample_ratio, tgt_len))
if self.reduced_method == "conv":
query_ = query.transpose(1, 2) # bsz, dim, seq_len
query_ = self.sr(query_).transpose(1, 2) # bsz, seq_len, dim
query_ = self.norm(query_)
elif self.reduced_method == "pool":
query_ = query.transpose(1, 2) # bsz, dim, seq_len
pool_length = int(tgt_len / self.sample_ratio)
query_ = nn.functional.adaptive_max_pool1d(query_, pool_length).transpose(1, 2)
query_ = self.act(self.norm(query_))
key = value = query_
if key_padding_mask is not None:
key_padding_mask = key_padding_mask[:, ::self.sample_ratio]
n_batch = query.size(0)
q = self.linear_q(query).view(n_batch, -1, self.h, self.d_k)
k = self.linear_k(key).view(n_batch, -1, self.h, self.d_k)
v = self.linear_v(value).view(n_batch, -1, self.h, self.d_k)
q = q.transpose(1, 2) # (batch, head, time1, d_k)
k = k.transpose(1, 2) # (batch, head, time2, d_k)
v = v.transpose(1, 2) # (batch, head, time2, d_k)
# q, k, v = self.forward_qkv(query, key, value)
q = q.transpose(1, 2) # (batch, time1, head, d_k)
n_batch_pos = pos_emb.size(0)
p = self.linear_pos(pos_emb).view(n_batch_pos, -1, self.h, self.d_k)
p = p.transpose(1, 2) # (batch, head, 2*time1-1, d_k)
# (batch, head, time1, d_k)
q_with_bias_u = (q + self.pos_bias_u).transpose(1, 2)
# (batch, head, time1, d_k)
q_with_bias_v = (q + self.pos_bias_v).transpose(1, 2)
# compute attention score
# first compute matrix a and matrix c
# as described in https://arxiv.org/abs/1901.02860 Section 3.3
# (batch, head, time1, time2)
matrix_ac = torch.matmul(q_with_bias_u, k.transpose(-2, -1))
# compute matrix b and matrix d
# (batch, head, time1, 2*time1-1)
matrix_bd = torch.matmul(q_with_bias_v, p.transpose(-2, -1))
matrix_bd = self.rel_shift(matrix_bd)
scores = (matrix_ac + matrix_bd) / math.sqrt(
self.d_k
) # (batch, head, time1, time2)
scores = self.forward_attention(v, scores, key_padding_mask)
scores = scores.transpose(0, 1)
return scores, None
class LegacyRelPositionMultiHeadedAttention(RelPositionMultiHeadedAttention):
"""Multi-Head Attention layer with relative position encoding (old version).
......
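A hedged smoke test for ReducedRelPositionMultiHeadedAttention, importable via fairseq.modules after the __init__.py change above. Sizing the relative-position tensor to the reduced key length (2 * T/ratio - 1) is an assumption that makes the matrix_ac and matrix_bd shapes agree; the diff does not show the actual caller:

import torch
from fairseq.modules import ReducedRelPositionMultiHeadedAttention

T, B, C, heads, ratio = 8, 2, 64, 4, 2  # T must be divisible by ratio
attn = ReducedRelPositionMultiHeadedAttention(
    C, heads, dropout=0.1, sample_ratio=ratio, reduced_method="conv")

x = torch.randn(T, B, C)                 # T x B x C, per the docstring
pos_emb = torch.randn(2 * (T // ratio) - 1, 1, C)
out, _ = attn(x, x, x, pos_emb)          # keys/values reduced to T // ratio
assert out.shape == (T, B, C)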
......@@ -12,6 +12,7 @@ from fairseq.modules import (
ConvolutionModule,
ESPNETMultiHeadedAttention,
RelPositionMultiHeadedAttention,
ReducedRelPositionMultiHeadedAttention,
LegacyRelPositionMultiHeadedAttention,
LocalMultiheadAttention,
ReducedMultiheadAttention,
......@@ -91,6 +92,7 @@ class PDSTransformerEncoderLayer(nn.Module):
self.macaron_norm = None
self.ffn_scale = 1.0
self.conv_stride = conv_stride
if args.use_cnn_module:
self.conv_norm = LayerNorm(embed_dim)
self.conv_module = ConvolutionModule(
......@@ -104,7 +106,6 @@ class PDSTransformerEncoderLayer(nn.Module):
self.final_norm = LayerNorm(expand_embed_dim)
# Convolution Residual
self.conv_stride = conv_stride
self.conv_res = nn.Sequential(
Permute3D(1, 2, 0),
nn.Conv1d(embed_dim, expand_embed_dim, kernel_size=1, stride=conv_stride),
......@@ -173,6 +174,15 @@ class PDSTransformerEncoderLayer(nn.Module):
attention_heads,
dropout=dropout,
)
elif self.attn_type == "reduced_rel_pos":
return ReducedRelPositionMultiHeadedAttention(
embed_dim,
attention_heads,
dropout=dropout,
sample_ratio=sample_ratio,
reduced_method=getattr(args, "attention_reduced_method", "conv"),
reduced_q=getattr(args, "attention_reduced_q", False)
)
elif self.attn_type == "rel_pos_legacy":
return LegacyRelPositionMultiHeadedAttention(
embed_dim,
......@@ -284,7 +294,7 @@ class PDSTransformerEncoderLayer(nn.Module):
residual = x
if self.normalize_before:
x = self.self_attn_layer_norm(x)
if self.attn_type in ["rel_pos", "rel_pos_legacy", "rel_selfattn"]:
if self.attn_type in ["rel_pos", "rel_pos_legacy", "rel_selfattn", "reduced_rel_pos"]:
assert pos_emb is not None, "Positions is necessary for RPE!"
x, _ = self.self_attn(
query=x,
......