Commit ebd1be88 by xuchen

fix the implementation of the relative position encoding in conformer, optimize…

fix the implementation of the relative position encoding in conformer, optimize the code of ctc loss
parent 0c7e71c7
macaron-style: True macaron-style: True
use-cnn-module: True use-cnn-module: True
cnn-module-kernel: 31 cnn-module-kernel: 31
encoder-attention-type: rel_pos
\ No newline at end of file
...@@ -13,7 +13,7 @@ label_smoothing: 0.1 ...@@ -13,7 +13,7 @@ label_smoothing: 0.1
subsampling-type: conv1d subsampling-type: conv1d
subsmapling-layers: 2 subsmapling-layers: 2
subsampling-filter: 1024 subsampling-filter: 512
subsampling-kernel: 5 subsampling-kernel: 5
subsampling-stride: 2 subsampling-stride: 2
subsampling-norm: none subsampling-norm: none
...@@ -32,3 +32,9 @@ decoder-ffn-embed-dim: 2048 ...@@ -32,3 +32,9 @@ decoder-ffn-embed-dim: 2048
decoder-attention-heads: 4 decoder-attention-heads: 4
attention-dropout: 0.1 attention-dropout: 0.1
activation-dropout: 0.1 activation-dropout: 0.1
macaron-style: True
use-cnn-module: True
cnn-module-kernel: 31
encoder-activation-fn: swish
encoder-attention-type: rel_pos
encoder-attention-type: rel_selfattn encoder-attention-type: rel_pos
#encoder-attention-type: relative #encoder-attention-type: relative
#max-encoder-relative-length: 100 #max-encoder-relative-length: 100
...@@ -4,14 +4,20 @@ clip-norm: 10.0 ...@@ -4,14 +4,20 @@ clip-norm: 10.0
lr-scheduler: inverse_sqrt lr-scheduler: inverse_sqrt
warmup-init-lr: 1e-7 warmup-init-lr: 1e-7
warmup-updates: 10000 warmup-updates: 10000
lr: 2e-3 lr: 0.0015
#adam_betas: (0.9,0.98) adam_betas: (0.9,0.98)
criterion: ctc criterion: ctc
post-process: sentencepiece post-process: sentencepiece
conv-kernel-sizes: 5,5 subsampling-type: conv1d
conv-channels: 704 subsmapling-layers: 2
subsampling-filter: 1024
subsampling-kernel: 5
subsampling-stride: 2
subsampling-norm: none
subsampling-activation: glu
dropout: 0.1 dropout: 0.1
activation-fn: relu activation-fn: relu
encoder-embed-dim: 176 encoder-embed-dim: 176
...@@ -22,4 +28,5 @@ encoder-attention-heads: 4 ...@@ -22,4 +28,5 @@ encoder-attention-heads: 4
macaron-style: True macaron-style: True
use-cnn-module: True use-cnn-module: True
cnn-module-kernel: 31 cnn-module-kernel: 31
encoder-attention-type: rel_selfattn encoder-activation-fn: swish
\ No newline at end of file encoder-attention-type: rel_pos
\ No newline at end of file
...@@ -6,13 +6,19 @@ lr-scheduler: inverse_sqrt ...@@ -6,13 +6,19 @@ lr-scheduler: inverse_sqrt
warmup-init-lr: 1e-7 warmup-init-lr: 1e-7
warmup-updates: 10000 warmup-updates: 10000
lr: 2e-3 lr: 2e-3
#adam_betas: (0.9,0.98) adam_betas: (0.9,0.98)
criterion: label_smoothed_cross_entropy_with_ctc criterion: label_smoothed_cross_entropy_with_ctc
label_smoothing: 0.1 label_smoothing: 0.1
conv-kernel-sizes: 5,5 subsampling-type: conv1d
conv-channels: 1024 subsmapling-layers: 2
subsampling-filter: 1024
subsampling-kernel: 5
subsampling-stride: 2
subsampling-norm: none
subsampling-activation: glu
dropout: 0.1 dropout: 0.1
activation-fn: relu activation-fn: relu
encoder-embed-dim: 256 encoder-embed-dim: 256
......
macaron-style: True macaron-style: True
use-cnn-module: True use-cnn-module: True
cnn-module-kernel: 31 cnn-module-kernel: 31
encoder-activation-fn: swish
encoder-attention-type: rel_pos
\ No newline at end of file
...@@ -6,13 +6,19 @@ lr-scheduler: inverse_sqrt ...@@ -6,13 +6,19 @@ lr-scheduler: inverse_sqrt
warmup-init-lr: 1e-7 warmup-init-lr: 1e-7
warmup-updates: 10000 warmup-updates: 10000
lr: 2e-3 lr: 2e-3
#adam_betas: (0.9,0.98) adam_betas: (0.9,0.98)
criterion: label_smoothed_cross_entropy_with_ctc criterion: label_smoothed_cross_entropy_with_ctc
label_smoothing: 0.1 label_smoothing: 0.1
conv-kernel-sizes: 5,5 subsampling-type: conv1d
conv-channels: 1024 subsmapling-layers: 2
subsampling-filter: 1024
subsampling-kernel: 5
subsampling-stride: 2
subsampling-norm: none
subsampling-activation: glu
dropout: 0.1 dropout: 0.1
activation-fn: relu activation-fn: relu
encoder-embed-dim: 256 encoder-embed-dim: 256
......
ctc-weight: 0.2
intermedia-ctc-layers: 6,9 intermedia-ctc-layers: 6,9
intermedia-adapter: league intermedia-adapter: league
intermedia-ctc-weight: 0.15 intermedia-ctc-weight: 0.1
ctc-self-distill-weight: 1 ctc-self-distill-weight: 0
\ No newline at end of file post-process: sentencepiece
\ No newline at end of file
...@@ -5,21 +5,28 @@ lr-scheduler: inverse_sqrt ...@@ -5,21 +5,28 @@ lr-scheduler: inverse_sqrt
warmup-init-lr: 1e-7 warmup-init-lr: 1e-7
warmup-updates: 10000 warmup-updates: 10000
lr: 2e-3 lr: 2e-3
#adam_betas: (0.9,0.98) adam_betas: (0.9,0.98)
criterion: ctc criterion: ctc
zero_infinity: True zero_infinity: True
post-process: sentencepiece post-process: sentencepiece
label_smoothing: 0.1
conv-kernel-sizes: 5,5 subsampling-type: conv1d
conv-channels: 1024 subsmapling-layers: 2
subsampling-filter: 1024
subsampling-kernel: 5
subsampling-stride: 2
subsampling-norm: none
subsampling-activation: glu
dropout: 0.1 dropout: 0.1
attention-dropout: 0.1
activation-dropout: 0.1
activation-fn: relu activation-fn: relu
encoder-embed-dim: 256 encoder-embed-dim: 256
encoder-ffn-embed-dim: 2048 encoder-ffn-embed-dim: 2048
encoder-layers: 12 encoder-layers: 12
encoder-attention-heads: 4 encoder-attention-heads: 4
attention-dropout: 0.1 #load-pretrained-encoder-from:
activation-dropout: 0.1 \ No newline at end of file
\ No newline at end of file
...@@ -6,7 +6,6 @@ gpu_num=8 ...@@ -6,7 +6,6 @@ gpu_num=8
update_freq=1 update_freq=1
max_tokens=40000 max_tokens=40000
extra_tag= extra_tag=
extra_parameter= extra_parameter=
#extra_tag="${extra_tag}" #extra_tag="${extra_tag}"
...@@ -14,12 +13,12 @@ extra_parameter= ...@@ -14,12 +13,12 @@ extra_parameter=
exp_tag= exp_tag=
#config_list=(base) config_list=(base ctc)
#config_list=(ctc) config_list=(purectc)
#config_list=(base conformer) #config_list=(base conformer)
#config_list=(pds_base_16) #config_list=(pds_base_16)
config_list=(pds_base_16 conformer rpr) #config_list=(pds_base_16 conformer rpr)
# exp full name # exp full name
exp_name= exp_name=
......
...@@ -6,13 +6,19 @@ lr-scheduler: inverse_sqrt ...@@ -6,13 +6,19 @@ lr-scheduler: inverse_sqrt
warmup-init-lr: 1e-7 warmup-init-lr: 1e-7
warmup-updates: 10000 warmup-updates: 10000
lr: 2e-3 lr: 2e-3
#adam_betas: (0.9,0.98) adam_betas: (0.9,0.98)
criterion: label_smoothed_cross_entropy_with_ctc criterion: label_smoothed_cross_entropy_with_ctc
label_smoothing: 0.1 label_smoothing: 0.1
conv-kernel-sizes: 5,5 subsampling-type: conv1d
conv-channels: 1024 subsmapling-layers: 2
subsampling-filter: 1024
subsampling-kernel: 5
subsampling-stride: 2
subsampling-norm: none
subsampling-activation: glu
dropout: 0.1 dropout: 0.1
activation-fn: relu activation-fn: relu
encoder-embed-dim: 256 encoder-embed-dim: 256
......
ctc-weight: 0.2
intermedia-ctc-layers: 6,9 intermedia-ctc-layers: 6,9
intermedia-adapter: league intermedia-adapter: league
intermedia-ctc-weight: 0.15 intermedia-ctc-weight: 0.1
ctc-self-distill-weight: 1 ctc-self-distill-weight: 0
\ No newline at end of file post-process: sentencepiece
\ No newline at end of file
...@@ -20,6 +20,8 @@ from fairseq.tasks import FairseqTask ...@@ -20,6 +20,8 @@ from fairseq.tasks import FairseqTask
from fairseq.logging.meters import safe_round from fairseq.logging.meters import safe_round
@dataclass @dataclass
class CtcCriterionConfig(FairseqDataclass): class CtcCriterionConfig(FairseqDataclass):
zero_infinity: bool = field( zero_infinity: bool = field(
...@@ -35,6 +37,19 @@ class CtcCriterionConfig(FairseqDataclass): ...@@ -35,6 +37,19 @@ class CtcCriterionConfig(FairseqDataclass):
"See fairseq.data.data_utils.post_process() for full list of options" "See fairseq.data.data_utils.post_process() for full list of options"
}, },
) )
ctc_entropy: float = field(
default=0.0,
metadata={"help": "weight of CTC entropy"},
)
intermedia_ctc_weight: float = field(
default=0.0,
metadata={"help": "weight of intermedia CTC loss"},
)
ctc_self_distill_weight: float = field(
default=0.0,
metadata={"help": "weight of the self distillation CTC loss"},
)
wer_kenlm_model: Optional[str] = field( wer_kenlm_model: Optional[str] = field(
default=None, default=None,
metadata={ metadata={
...@@ -64,12 +79,8 @@ class CtcCriterionConfig(FairseqDataclass): ...@@ -64,12 +79,8 @@ class CtcCriterionConfig(FairseqDataclass):
@register_criterion("ctc", dataclass=CtcCriterionConfig) @register_criterion("ctc", dataclass=CtcCriterionConfig)
class CtcCriterion(FairseqCriterion): class CtcCriterion(FairseqCriterion):
def __init__(self, cfg: CtcCriterionConfig, task: FairseqTask): def __init__(self, cfg: CtcCriterionConfig, task: FairseqTask, ctc_weight=1.0):
super().__init__(task) super().__init__(task)
self.blank_idx = task.target_dictionary.index(task.blank_symbol) if hasattr(task, 'blank_symbol') else 0
self.pad_idx = task.target_dictionary.pad()
self.eos_idx = task.target_dictionary.eos()
self.post_process = cfg.post_process
if cfg.wer_args is not None: if cfg.wer_args is not None:
( (
...@@ -99,48 +110,151 @@ class CtcCriterion(FairseqCriterion): ...@@ -99,48 +110,151 @@ class CtcCriterion(FairseqCriterion):
else: else:
self.w2l_decoder = None self.w2l_decoder = None
self.zero_infinity = cfg.zero_infinity self.blank_idx = task.target_dictionary.index(task.blank_symbol) if hasattr(task, 'blank_symbol') else 0
self.pad_idx = task.target_dictionary.pad()
self.eos_idx = task.target_dictionary.eos()
self.post_process = cfg.post_process
self.sentence_avg = cfg.sentence_avg self.sentence_avg = cfg.sentence_avg
self.ctc_weight = ctc_weight
self.intermedia_ctc_weight = cfg.intermedia_ctc_weight
self.ctc_self_distill_weight = cfg.ctc_self_distill_weight
self.ctc_entropy = cfg.ctc_entropy
self.all_ctc_weight = self.ctc_weight + self.intermedia_ctc_weight + self.ctc_self_distill_weight + self.ctc_entropy
if self.all_ctc_weight > 0:
assert getattr(task, "src_dict", None) is not None, "CTC need a source dictionary."
self.ctc_loss = torch.nn.CTCLoss(blank=self.blank_idx, reduction="sum", zero_infinity=True)
def forward(self, model, sample, reduce=True): def forward(self, model, sample, reduce=True):
net_output = model(**sample["net_input"]) net_output = model(**sample["net_input"])
lprobs = model.get_normalized_probs(
net_output, log_probs=True
).contiguous() # (T, B, C) from the encoder
ntokens = sample["ntokens"]
sample_size = sample["target"].size(0) if self.sentence_avg else ntokens
logging_output = {
"ntokens": ntokens,
"nsentences": sample["id"].numel(),
"sample_size": sample_size,
}
loss, logging_output = self.compute_ctc_loss(model, sample, net_output, logging_output)
return loss, sample_size, logging_output
def compute_ctc_loss(self, model, sample, net_output, logging_output):
transcript = sample["transcript"]
if "ctc_padding_mask" in net_output:
non_padding_mask = ~net_output["ctc_padding_mask"][0]
else:
non_padding_mask = ~net_output["encoder_padding_mask"][0] non_padding_mask = ~net_output["encoder_padding_mask"][0]
input_lengths = non_padding_mask.long().sum(-1) input_lengths = non_padding_mask.long().sum(-1)
pad_mask = (sample["target"] != self.pad_idx) & ( pad_mask = (transcript["tokens"] != self.pad_idx) & (
sample["target"] != self.eos_idx transcript["tokens"] != self.eos_idx
) )
targets_flat = sample["target"].masked_select(pad_mask) targets_flat = transcript["tokens"].masked_select(pad_mask)
target_lengths = pad_mask.sum(-1) transcript_lengths = pad_mask.sum(-1)
ctc_loss = 0
ctc_entropy = 0
lprobs = None
if self.ctc_weight > 0 and "ctc_logit" in net_output and len(net_output["ctc_logit"]) > 0:
ctc_logit = net_output["ctc_logit"][0]
lprobs = model.get_normalized_probs(
[ctc_logit], log_probs=True
).contiguous() # (T, B, C) from the encoder
lprobs.batch_first = False
with torch.backends.cudnn.flags(enabled=False): with torch.backends.cudnn.flags(enabled=False):
loss = F.ctc_loss( ctc_loss = self.ctc_loss(
lprobs, lprobs,
targets_flat, targets_flat,
input_lengths, input_lengths,
target_lengths, transcript_lengths,
blank=self.blank_idx, )
reduction="sum", if self.ctc_entropy > 0:
zero_infinity=self.zero_infinity, from torch.distributions import Categorical
# ctc_logit = ctc_logit.sort(dim=-1, descending=True)[0][:, :, 0:100]
# ctc_logit = ctc_logit / ctc_logit.sum(dim=-1, keepdim=True)
cut_ctc_logit = ctc_logit.sort(dim=-1, descending=True)[0][:, :, 0:100]
ctc_entropy = Categorical(logits=cut_ctc_logit).entropy().sum()
# ctc_entropy = Categorical(logits=ctc_logit).entropy().sum()
logging_output["ctc_entropy"] = utils.item(ctc_entropy.data)
logging_output["ctc_loss"] = utils.item(ctc_loss.data)
intermedia_ctc_num = 0
intermedia_ctc_loss = 0
if "intermedia_ctc_logits" in net_output:
intermedia_ctc_num = len(net_output["intermedia_ctc_logits"])
# calculate the intermedia CTC loss
if self.intermedia_ctc_weight > 0 and intermedia_ctc_num > 0:
for i in range(intermedia_ctc_num):
out = net_output["intermedia_ctc_logits"][i]
if type(out) == list:
inter_ctc_logit = out[0]
padding = ~out[1]
input_lengths = padding.long().sum(-1)
else:
inter_ctc_logit = out
inter_lprobs = model.get_normalized_probs(
[inter_ctc_logit], log_probs=True
).contiguous() # (T, B, C) from the encoder
inter_lprobs.batch_first = False
with torch.backends.cudnn.flags(enabled=False):
loss = self.ctc_loss(
inter_lprobs,
targets_flat,
input_lengths,
transcript_lengths,
) )
intermedia_ctc_loss += loss
intermedia_ctc_loss /= intermedia_ctc_num
logging_output["intermedia_ctc_loss"] = utils.item(intermedia_ctc_loss.data)
if lprobs is None:
lprobs = inter_lprobs
# calculate the self distillation CTC loss
ctc_self_distill_loss = 0
ctc_self_distill_num = 0
if self.ctc_weight > 0 and self.ctc_self_distill_weight > 0 and intermedia_ctc_num > 0:
for i in range(intermedia_ctc_num):
out = net_output["intermedia_ctc_logits"][i]
if type(out) == list:
inter_ctc_logit = out[0]
padding = ~out[1]
input_lengths = padding.long().sum(-1)
else:
inter_ctc_logit = out
if inter_ctc_logit.size() != ctc_logit.size():
continue
ntokens = ( ctc_self_distill_num += 1
sample["ntokens"] if "ntokens" in sample else target_lengths.sum().item() loss = F.kl_div(
F.log_softmax(inter_ctc_logit, dim=-1),
F.softmax(ctc_logit, dim=-1),
reduction="none",
) )
loss = loss.sum(-1).transpose(0, 1).masked_fill_(~non_padding_mask, 0.0)
loss = loss.sum()
ctc_self_distill_loss += loss
sample_size = sample["target"].size(0) if self.sentence_avg else ntokens ctc_self_distill_loss /= ctc_self_distill_num
logging_output = { logging_output["ctc_self_distill_loss"] = utils.item(ctc_self_distill_loss.data)
"loss": utils.item(loss.data), # * sample['ntokens'],
"ntokens": ntokens,
"nsentences": sample["id"].numel(),
"sample_size": sample_size,
}
if not model.training: loss = \
self.ctc_weight * ctc_loss + \
self.intermedia_ctc_weight * intermedia_ctc_loss + \
self.ctc_self_distill_weight * ctc_self_distill_loss + \
self.ctc_entropy * ctc_entropy
logging_output["all_ctc_loss"] = utils.item(loss.data)
if not model.training and self.ctc_weight > 0:
import editdistance import editdistance
with torch.no_grad(): with torch.no_grad():
...@@ -153,9 +267,7 @@ class CtcCriterion(FairseqCriterion): ...@@ -153,9 +267,7 @@ class CtcCriterion(FairseqCriterion):
wv_errs = 0 wv_errs = 0
for lp, t, inp_l in zip( for lp, t, inp_l in zip(
lprobs_t, lprobs_t,
sample["target_label"] sample["transcript"]["tokens"] if "transcript" in sample else sample["target"],
if "target_label" in sample
else sample["target"],
input_lengths, input_lengths,
): ):
lp = lp[:inp_l].unsqueeze(0) lp = lp[:inp_l].unsqueeze(0)
...@@ -207,13 +319,28 @@ class CtcCriterion(FairseqCriterion): ...@@ -207,13 +319,28 @@ class CtcCriterion(FairseqCriterion):
logging_output["c_errors"] = c_err logging_output["c_errors"] = c_err
logging_output["c_total"] = c_len logging_output["c_total"] = c_len
return loss, sample_size, logging_output return loss, logging_output
@staticmethod @staticmethod
def reduce_metrics(logging_outputs) -> None: def reduce_metrics(logging_outputs) -> None:
"""Aggregate logging outputs from data parallel training.""" """Aggregate logging outputs from data parallel training."""
loss_sum = utils.item(sum(log.get("loss", 0) for log in logging_outputs)) ctc_loss_sum = utils.item(
sum(log.get("ctc_loss", 0) for log in logging_outputs)
)
ctc_entropy_sum = utils.item(
sum(log.get("ctc_entropy", 0) for log in logging_outputs)
)
inter_ctc_loss_sum = utils.item(
sum(log.get("intermedia_ctc_loss", 0) for log in logging_outputs)
)
ctc_self_distill_loss_sum = utils.item(
sum(log.get("ctc_self_distill_loss", 0) for log in logging_outputs)
)
all_ctc_loss_sum = utils.item(
sum(log.get("all_ctc_loss", 0) for log in logging_outputs)
)
# loss_sum = utils.item(sum(log.get("loss", 0) for log in logging_outputs))
ntokens = utils.item(sum(log.get("ntokens", 0) for log in logging_outputs)) ntokens = utils.item(sum(log.get("ntokens", 0) for log in logging_outputs))
nsentences = utils.item( nsentences = utils.item(
sum(log.get("nsentences", 0) for log in logging_outputs) sum(log.get("nsentences", 0) for log in logging_outputs)
...@@ -221,15 +348,56 @@ class CtcCriterion(FairseqCriterion): ...@@ -221,15 +348,56 @@ class CtcCriterion(FairseqCriterion):
sample_size = utils.item( sample_size = utils.item(
sum(log.get("sample_size", 0) for log in logging_outputs) sum(log.get("sample_size", 0) for log in logging_outputs)
) )
if all_ctc_loss_sum > 0:
if "loss" not in logging_outputs[0]:
metrics.log_scalar( metrics.log_scalar(
"loss", loss_sum / sample_size / math.log(2), sample_size, round=3 "loss",
all_ctc_loss_sum / sample_size / math.log(2),
sample_size,
round=3,
) )
else:
if all_ctc_loss_sum != ctc_loss_sum:
metrics.log_scalar(
"all_ctc_loss",
all_ctc_loss_sum / sample_size / math.log(2),
sample_size,
round=3,
)
if ctc_loss_sum > 0:
metrics.log_scalar(
"ctc_loss",
ctc_loss_sum / sample_size / math.log(2),
sample_size,
round=3,
)
if ctc_entropy_sum > 0:
metrics.log_scalar(
"ctc_entropy",
ctc_entropy_sum / sample_size / math.log(2),
sample_size,
round=3,
)
if inter_ctc_loss_sum > 0:
metrics.log_scalar(
"intermedia_ctc_loss",
inter_ctc_loss_sum / sample_size / math.log(2),
sample_size,
round=3,
)
if ctc_self_distill_loss_sum > 0:
metrics.log_scalar(
"ctc_self_distill_loss",
ctc_self_distill_loss_sum / sample_size / math.log(2),
sample_size,
round=3,
)
metrics.log_scalar("ntokens", ntokens) metrics.log_scalar("ntokens", ntokens)
metrics.log_scalar("nsentences", nsentences) metrics.log_scalar("nsentences", nsentences)
if sample_size != ntokens: if sample_size != ntokens:
metrics.log_scalar( metrics.log_scalar(
"nll_loss", loss_sum / ntokens / math.log(2), ntokens, round=3 "nll_loss", ctc_loss_sum / ntokens / math.log(2), ntokens, round=3
) )
c_errors = sum(log.get("c_errors", 0) for log in logging_outputs) c_errors = sum(log.get("c_errors", 0) for log in logging_outputs)
......
...@@ -13,41 +13,28 @@ from fairseq.data.data_utils import post_process ...@@ -13,41 +13,28 @@ from fairseq.data.data_utils import post_process
from fairseq.logging.meters import safe_round from fairseq.logging.meters import safe_round
from .label_smoothed_cross_entropy import LabelSmoothedCrossEntropyCriterion from .label_smoothed_cross_entropy import LabelSmoothedCrossEntropyCriterion
from .ctc import CtcCriterion, CtcCriterionConfig
@register_criterion("label_smoothed_cross_entropy_with_ctc") @register_criterion("label_smoothed_cross_entropy_with_ctc")
class LabelSmoothedCrossEntropyCriterionWithCTC( class LabelSmoothedCrossEntropyCriterionWithCTC(
LabelSmoothedCrossEntropyCriterion LabelSmoothedCrossEntropyCriterion
): ):
def __init__(self, task, sentence_avg, label_smoothing, post_process="letter", def __init__(self, task, label_smoothing,
ctc_weight=0.0, intermedia_ctc_weight=0.0, ctc_self_distill_weight=0.0): sentence_avg,
cfg: CtcCriterionConfig,
ctc_weight=0.0):
super().__init__(task, sentence_avg, label_smoothing) super().__init__(task, sentence_avg, label_smoothing)
self.blank_idx = task.target_dictionary.index(task.blank_symbol) if hasattr(task, 'blank_symbol') else 0
self.pad_idx = task.target_dictionary.pad()
self.eos_idx = task.target_dictionary.eos()
self.report_accuracy = True self.report_accuracy = True
self.ctc_weight = ctc_weight
assert 0 <= ctc_weight self.ctc_criterion = CtcCriterion(cfg, task, ctc_weight)
self.top_ctc_weight = ctc_weight
self.intermedia_ctc_weight = intermedia_ctc_weight
self.ctc_self_distill_weight = ctc_self_distill_weight
self.ctc_weight = ctc_weight + intermedia_ctc_weight + ctc_self_distill_weight
if self.ctc_weight > 0:
assert getattr(task, "src_dict", None) is not None, "CTC need a source dictionary."
self.post_process = post_process
self.ctc_loss = torch.nn.CTCLoss(blank=self.blank_idx, reduction="sum", zero_infinity=True)
@staticmethod @staticmethod
def add_args(parser): def add_args(parser):
"""Add criterion-specific arguments to the parser.""" """Add criterion-specific arguments to the parser."""
LabelSmoothedCrossEntropyCriterion.add_args(parser) LabelSmoothedCrossEntropyCriterion.add_args(parser)
parser.add_argument( CtcCriterion.add_args(parser)
"--zero-infinity",
default=True,
type=bool,
help="zero inf loss when source length <= target length",
)
parser.add_argument( parser.add_argument(
"--ctc-weight", "--ctc-weight",
default=0.0, default=0.0,
...@@ -55,33 +42,6 @@ class LabelSmoothedCrossEntropyCriterionWithCTC( ...@@ -55,33 +42,6 @@ class LabelSmoothedCrossEntropyCriterionWithCTC(
metavar="D", metavar="D",
help="weight of CTC loss", help="weight of CTC loss",
) )
parser.add_argument(
"--intermedia-ctc-weight",
default=0.0,
type=float,
metavar="D",
help="weight of intermedia CTC loss",
)
parser.add_argument(
"--ctc-self-distill",
action="store_true",
help="use self distillation for intermedia CTC loss",
)
parser.add_argument(
"--ctc-self-distill-weight",
default=0.0,
type=float,
metavar="D",
help="weight of the self distillation CTC loss",
)
parser.add_argument(
"--post-process",
default="letter",
type=str,
help="how to post process predictions into words. can be letter, "
"word-piece, BPE symbols, etc. "
"See fairseq.data.data_utils.post_process() for full list of options",
)
def forward(self, model, sample, reduce=True): def forward(self, model, sample, reduce=True):
"""Compute the loss for the given sample. """Compute the loss for the given sample.
...@@ -114,173 +74,13 @@ class LabelSmoothedCrossEntropyCriterionWithCTC( ...@@ -114,173 +74,13 @@ class LabelSmoothedCrossEntropyCriterionWithCTC(
logging_output["n_correct"] = utils.item(n_correct.data) logging_output["n_correct"] = utils.item(n_correct.data)
logging_output["total"] = utils.item(total.data) logging_output["total"] = utils.item(total.data)
if self.ctc_weight > 0: if self.ctc_criterion.all_ctc_weight > 0:
ctc_loss, logging_output = self.compute_ctc_loss(model, sample, encoder_out, logging_output) ctc_loss, logging_output = self.ctc_criterion.compute_ctc_loss(model, sample, encoder_out, logging_output)
loss = (1 - self.top_ctc_weight - self.intermedia_ctc_weight) * loss + ctc_loss loss = (1 - self.ctc_weight) * loss + ctc_loss
logging_output["loss"] = utils.item(loss.data) if reduce else loss.data logging_output["loss"] = utils.item(loss.data) if reduce else loss.data
return loss, sample_size, logging_output return loss, sample_size, logging_output
def compute_ctc_loss(self, model, sample, encoder_out, logging_output):
transcript = sample["transcript"]
if "ctc_padding_mask" in encoder_out:
non_padding_mask = ~encoder_out["ctc_padding_mask"][0]
else:
non_padding_mask = ~encoder_out["encoder_padding_mask"][0]
input_lengths = non_padding_mask.long().sum(-1)
pad_mask = (transcript["tokens"] != self.pad_idx) & (
transcript["tokens"] != self.eos_idx
)
targets_flat = transcript["tokens"].masked_select(pad_mask)
transcript_lengths = pad_mask.sum(-1)
ctc_loss = 0
lprobs = None
if self.top_ctc_weight > 0 and "ctc_logit" in encoder_out and len(encoder_out["ctc_logit"]) > 0:
ctc_logit = encoder_out["ctc_logit"][0]
lprobs = model.get_normalized_probs(
[ctc_logit], log_probs=True
).contiguous() # (T, B, C) from the encoder
lprobs.batch_first = False
with torch.backends.cudnn.flags(enabled=False):
ctc_loss = self.ctc_loss(
lprobs,
targets_flat,
input_lengths,
transcript_lengths,
)
logging_output["ctc_loss"] = utils.item(ctc_loss.data)
intermedia_ctc_num = 0
intermedia_ctc_loss = 0
if "intermedia_ctc_logits" in encoder_out:
intermedia_ctc_num = len(encoder_out["intermedia_ctc_logits"])
# calculate the intermedia CTC loss
if self.intermedia_ctc_weight > 0 and intermedia_ctc_num > 0:
for i in range(intermedia_ctc_num):
out = encoder_out["intermedia_ctc_logits"][i]
if type(out) == list:
inter_ctc_logit = out[0]
padding = ~out[1]
input_lengths = padding.long().sum(-1)
else:
inter_ctc_logit = out
inter_lprobs = model.get_normalized_probs(
[inter_ctc_logit], log_probs=True
).contiguous() # (T, B, C) from the encoder
inter_lprobs.batch_first = False
with torch.backends.cudnn.flags(enabled=False):
loss = self.ctc_loss(
inter_lprobs,
targets_flat,
input_lengths,
transcript_lengths,
)
intermedia_ctc_loss += loss
intermedia_ctc_loss /= intermedia_ctc_num
logging_output["intermedia_ctc_loss"] = utils.item(intermedia_ctc_loss.data)
if lprobs is None:
lprobs = inter_lprobs
# calculate the self distillation CTC loss
ctc_self_distill_loss = 0
ctc_self_distill_num = 0
if self.top_ctc_weight > 0 and self.ctc_self_distill_weight > 0 and intermedia_ctc_num > 0:
for i in range(intermedia_ctc_num):
out = encoder_out["intermedia_ctc_logits"][i]
if type(out) == list:
inter_ctc_logit = out[0]
padding = ~out[1]
input_lengths = padding.long().sum(-1)
else:
inter_ctc_logit = out
if inter_ctc_logit.size() != ctc_logit.size():
continue
ctc_self_distill_num += 1
loss = F.kl_div(
F.log_softmax(inter_ctc_logit, dim=-1),
F.softmax(ctc_logit, dim=-1),
reduction="none",
)
loss = loss.sum(-1).transpose(0, 1).masked_fill_(~non_padding_mask, 0.0)
loss = loss.sum()
ctc_self_distill_loss += loss
ctc_self_distill_loss /= ctc_self_distill_num
logging_output["ctc_self_distill_loss"] = utils.item(ctc_self_distill_loss.data)
loss = \
self.ctc_weight * ctc_loss + \
self.intermedia_ctc_weight * intermedia_ctc_loss + \
self.ctc_self_distill_weight * ctc_self_distill_loss
if self.intermedia_ctc_weight > 0 or self.ctc_self_distill_weight > 0:
logging_output["all_ctc_loss"] = utils.item(loss.data)
if not model.training and self.ctc_weight > 0:
import editdistance
with torch.no_grad():
lprobs_t = lprobs.transpose(0, 1).float().contiguous().cpu()
c_err = 0
c_len = 0
w_errs = 0
w_len = 0
wv_errs = 0
for lp, t, inp_l in zip(
lprobs_t,
sample["transcript"]["tokens"] if "transcript" in sample else sample["target"],
input_lengths,
):
lp = lp[:inp_l].unsqueeze(0)
decoded = None
p = (t != self.task.target_dictionary.pad()) & (
t != self.task.target_dictionary.eos()
)
targ = t[p]
targ_units = self.task.target_dictionary.string(targ)
targ_units_arr = targ.tolist()
toks = lp.argmax(dim=-1).unique_consecutive()
pred_units_arr = toks[toks != self.blank_idx].tolist()
c_err += editdistance.eval(pred_units_arr, targ_units_arr)
c_len += len(targ_units_arr)
targ_words = post_process(targ_units, self.post_process).split()
pred_units = self.task.target_dictionary.string(pred_units_arr)
pred_words_raw = post_process(pred_units, self.post_process).split()
if decoded is not None and "words" in decoded:
pred_words = decoded["words"]
w_errs += editdistance.eval(pred_words, targ_words)
wv_errs += editdistance.eval(pred_words_raw, targ_words)
else:
dist = editdistance.eval(pred_words_raw, targ_words)
w_errs += dist
wv_errs += dist
w_len += len(targ_words)
logging_output["wv_errors"] = wv_errs
logging_output["w_errors"] = w_errs
logging_output["w_total"] = w_len
logging_output["c_errors"] = c_err
logging_output["c_total"] = c_len
return loss, logging_output
@staticmethod @staticmethod
def reduce_metrics(logging_outputs) -> None: def reduce_metrics(logging_outputs) -> None:
"""Aggregate logging outputs from data parallel training.""" """Aggregate logging outputs from data parallel training."""
...@@ -291,19 +91,6 @@ class LabelSmoothedCrossEntropyCriterionWithCTC( ...@@ -291,19 +91,6 @@ class LabelSmoothedCrossEntropyCriterionWithCTC(
nll_loss_sum = utils.item( nll_loss_sum = utils.item(
sum(log.get("nll_loss", 0) for log in logging_outputs) sum(log.get("nll_loss", 0) for log in logging_outputs)
) )
ctc_loss_sum = utils.item(
sum(log.get("ctc_loss", 0) for log in logging_outputs)
)
inter_ctc_loss_sum = utils.item(
sum(log.get("intermedia_ctc_loss", 0) for log in logging_outputs)
)
ctc_self_distill_loss_sum = utils.item(
sum(log.get("ctc_self_distill_loss", 0) for log in logging_outputs)
)
all_ctc_loss_sum = utils.item(
sum(log.get("all_ctc_loss", 0) for log in logging_outputs)
)
ntokens = utils.item(sum(log.get("ntokens", 0) for log in logging_outputs)) ntokens = utils.item(sum(log.get("ntokens", 0) for log in logging_outputs))
sample_size = utils.item( sample_size = utils.item(
sum(log.get("sample_size", 0) for log in logging_outputs) sum(log.get("sample_size", 0) for log in logging_outputs)
...@@ -319,38 +106,11 @@ class LabelSmoothedCrossEntropyCriterionWithCTC( ...@@ -319,38 +106,11 @@ class LabelSmoothedCrossEntropyCriterionWithCTC(
metrics.log_scalar( metrics.log_scalar(
"nll_loss", nll_loss_sum / ntokens / math.log(2), ntokens, round=3 "nll_loss", nll_loss_sum / ntokens / math.log(2), ntokens, round=3
) )
if ctc_loss_sum > 0:
metrics.log_scalar(
"ctc_loss",
ctc_loss_sum / sample_size / math.log(2),
sample_size,
round=3,
)
if inter_ctc_loss_sum > 0:
metrics.log_scalar(
"intermedia_ctc_loss",
inter_ctc_loss_sum / sample_size / math.log(2),
sample_size,
round=3,
)
if ctc_self_distill_loss_sum > 0:
metrics.log_scalar(
"ctc_self_distill_loss",
ctc_self_distill_loss_sum / sample_size / math.log(2),
sample_size,
round=3,
)
if all_ctc_loss_sum > 0:
metrics.log_scalar(
"all_ctc_loss",
all_ctc_loss_sum / sample_size / math.log(2),
sample_size,
round=3,
)
metrics.log_derived( metrics.log_derived(
"ppl", lambda meters: utils.get_perplexity(meters["nll_loss"].avg) "ppl", lambda meters: utils.get_perplexity(meters["nll_loss"].avg)
) )
if "ctc_loss" in logging_outputs[0] or "all_ctc_loss" in logging_outputs[0]:
CtcCriterion.reduce_metrics(logging_outputs)
total = utils.item(sum(log.get("total", 0) for log in logging_outputs)) total = utils.item(sum(log.get("total", 0) for log in logging_outputs))
if total > 0: if total > 0:
...@@ -368,44 +128,6 @@ class LabelSmoothedCrossEntropyCriterionWithCTC( ...@@ -368,44 +128,6 @@ class LabelSmoothedCrossEntropyCriterionWithCTC(
else float("nan"), else float("nan"),
) )
c_errors = sum(log.get("c_errors", 0) for log in logging_outputs)
metrics.log_scalar("_c_errors", c_errors)
c_total = sum(log.get("c_total", 0) for log in logging_outputs)
metrics.log_scalar("_c_total", c_total)
w_errors = sum(log.get("w_errors", 0) for log in logging_outputs)
metrics.log_scalar("_w_errors", w_errors)
wv_errors = sum(log.get("wv_errors", 0) for log in logging_outputs)
metrics.log_scalar("_wv_errors", wv_errors)
w_total = sum(log.get("w_total", 0) for log in logging_outputs)
metrics.log_scalar("_w_total", w_total)
if c_total > 0:
metrics.log_derived(
"cer",
lambda meters: safe_round(
meters["_c_errors"].sum * 100.0 / meters["_c_total"].sum, 3
)
if meters["_c_total"].sum > 0
else float("nan"),
)
if w_total > 0:
metrics.log_derived(
"wer",
lambda meters: safe_round(
meters["_w_errors"].sum * 100.0 / meters["_w_total"].sum, 3
)
if meters["_w_total"].sum > 0
else float("nan"),
)
metrics.log_derived(
"raw_wer",
lambda meters: safe_round(
meters["_wv_errors"].sum * 100.0 / meters["_w_total"].sum, 3
)
if meters["_w_total"].sum > 0
else float("nan"),
)
@staticmethod @staticmethod
def logging_outputs_can_be_summed() -> bool: def logging_outputs_can_be_summed() -> bool:
""" """
......
...@@ -85,34 +85,45 @@ class InterAdapter(nn.Module): ...@@ -85,34 +85,45 @@ class InterAdapter(nn.Module):
if self.adapter_type == "shrink": if self.adapter_type == "shrink":
self.ctc_compress = getattr(CTCCompressStrategy, strategy) self.ctc_compress = getattr(CTCCompressStrategy, strategy)
logger.info("CTC Compress Strategy: %s" % strategy)
elif self.adapter_type == "league":
self.distribution_cutoff = strategy
if self.distribution_cutoff != -1:
logger.info("Distribution cutoff: %d" % int(strategy))
def forward(self, x, padding): def forward(self, x, padding):
representation, distribution = x representation, distribution = x
dim1, dim2, dim = representation.size() dim1, dim2, dim = representation.size()
org_distribution = distribution org_distribution = distribution
if distribution is not None:
distribution = distribution.view(-1, distribution.size(-1))
lengths = (~padding).long().sum(-1) lengths = (~padding).long().sum(-1)
if self.adapter_type == "linear": if self.adapter_type == "linear":
out = self.linear_adapter(representation) out = self.linear_adapter(representation)
elif self.adapter_type == "context": elif self.adapter_type == "context":
distribution = distribution.view(-1, distribution.size(-1))
out = torch.mm(distribution, self.embed_adapter.weight).view(dim1, dim2, -1) out = torch.mm(distribution, self.embed_adapter.weight).view(dim1, dim2, -1)
elif self.adapter_type == "league": elif self.adapter_type == "league":
linear_out = self.linear_adapter(representation) linear_out = self.linear_adapter(representation)
if self.distribution_cutoff != -1:
cutoff = min(int(self.distribution_cutoff), distribution.size(-1) - 1)
threshold = distribution.sort(dim=-1, descending=True)[0][:, :, cutoff:cutoff+1]
distribution = torch.where(distribution > threshold, distribution, torch.zeros_like(distribution))
distribution = distribution.view(-1, distribution.size(-1))
soft_out = torch.mm(distribution, self.embed_adapter.weight).view(dim1, dim2, -1) soft_out = torch.mm(distribution, self.embed_adapter.weight).view(dim1, dim2, -1)
out = linear_out + soft_out out = linear_out + soft_out
elif self.adapter_type == "gated_league": elif self.adapter_type == "gated_league":
linear_out = self.linear_adapter(representation) linear_out = self.linear_adapter(representation)
distribution = distribution.view(-1, distribution.size(-1))
soft_out = torch.mm(distribution, self.embed_adapter.weight).view(dim1, dim2, -1) soft_out = torch.mm(distribution, self.embed_adapter.weight).view(dim1, dim2, -1)
coef = (self.gate_linear(torch.cat([linear_out, soft_out], dim=-1))).sigmoid() coef = (self.gate_linear(torch.cat([linear_out, soft_out], dim=-1))).sigmoid()
out = coef * linear_out + (1 - coef) * soft_out out = coef * linear_out + (1 - coef) * soft_out
elif self.adapter_type == "inter_league": elif self.adapter_type == "inter_league":
distribution = distribution.view(-1, distribution.size(-1))
soft_out = torch.mm(distribution, self.embed_adapter.weight).view(dim1, dim2, -1) soft_out = torch.mm(distribution, self.embed_adapter.weight).view(dim1, dim2, -1)
out = representation + soft_out out = representation + soft_out
......
...@@ -15,73 +15,23 @@ from fairseq.models import ( ...@@ -15,73 +15,23 @@ from fairseq.models import (
register_model_architecture, register_model_architecture,
) )
from fairseq.models.speech_to_text.modules import InterAdapter, CTC from fairseq.models.speech_to_text.modules import InterAdapter, CTC
from fairseq.models.transformer import Embedding, TransformerDecoder
from fairseq.modules import ( from fairseq.modules import (
FairseqDropout, FairseqDropout,
LayerNorm, LayerNorm,
PositionalEmbedding, PositionalEmbedding,
RelPositionalEncoding,
S2TTransformerEncoderLayer, S2TTransformerEncoderLayer,
DynamicLinearCombination, DynamicLinearCombination,
) )
from fairseq.modules.speech_to_text import (
subsampling
)
from torch import Tensor from torch import Tensor
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
class Conv1dSubsampler(nn.Module):
"""Convolutional subsampler: a stack of 1D convolution (along temporal
dimension) followed by non-linear activation via gated linear units
(https://arxiv.org/abs/1911.08460)
Args:
in_channels (int): the number of input channels
mid_channels (int): the number of intermediate channels
out_channels (int): the number of output channels
kernel_sizes (List[int]): the kernel size for each convolutional layer
"""
def __init__(
self,
in_channels: int,
mid_channels: int,
out_channels: int,
kernel_sizes: List[int] = (3, 3),
):
super(Conv1dSubsampler, self).__init__()
self.n_layers = len(kernel_sizes)
self.conv_layers = nn.ModuleList(
nn.Conv1d(
in_channels if i == 0 else mid_channels // 2,
mid_channels if i < self.n_layers - 1 else out_channels * 2,
k,
stride=2,
padding=k // 2,
)
for i, k in enumerate(kernel_sizes)
)
def get_out_seq_lens_tensor(self, in_seq_lens_tensor):
out = in_seq_lens_tensor.clone()
for _ in range(self.n_layers):
out = ((out.float() - 1) / 2 + 1).floor().long()
return out
def forward(self, src_tokens, src_lengths):
bsz, in_seq_len, _ = src_tokens.size() # B x T x (C x D)
x = src_tokens.transpose(1, 2).contiguous() # -> B x (C x D) x T
inner_x = []
for conv in self.conv_layers:
x = conv(x)
x = nn.functional.glu(x, dim=1)
inner_x.append(x)
_, _, out_seq_len = x.size()
# x = x.transpose(1, 2).transpose(0, 1).contiguous() # -> T x B x (C x D)
out_inner_x = []
for x in inner_x:
out_inner_x.append(x.transpose(1, 2).transpose(0, 1).contiguous())
return out_inner_x, self.get_out_seq_lens_tensor(src_lengths)
@register_model("s2t_ctc") @register_model("s2t_ctc")
class S2TCTCModel(FairseqEncoderModel): class S2TCTCModel(FairseqEncoderModel):
...@@ -91,18 +41,43 @@ class S2TCTCModel(FairseqEncoderModel): ...@@ -91,18 +41,43 @@ class S2TCTCModel(FairseqEncoderModel):
@staticmethod @staticmethod
def add_args(parser): def add_args(parser):
"""Add model-specific arguments to the parser.""" """Add model-specific arguments to the parser."""
# input # subsampling
parser.add_argument( parser.add_argument(
"--conv-kernel-sizes", "--subsampling-type",
type=str, type=str,
metavar="N", help="subsampling type, like conv1d and conv2d",
help="kernel sizes of Conv1d subsampling layers",
) )
parser.add_argument( parser.add_argument(
"--conv-channels", "--subsampling-layers",
type=int, type=int,
metavar="N", help="subsampling layers",
help="# of channels in Conv1d subsampling layers", )
parser.add_argument(
"--subsampling-filter",
type=int,
help="subsampling filter",
)
parser.add_argument(
"--subsampling-kernel",
type=int,
help="subsampling kernel",
)
parser.add_argument(
"--subsampling-stride",
type=int,
help="subsampling stride",
)
parser.add_argument(
"--subsampling-norm",
type=str,
default="none",
help="subsampling normalization type",
)
parser.add_argument(
"--subsampling-activation",
type=str,
default="none",
help="subsampling activation function type",
) )
# Transformer # Transformer
parser.add_argument( parser.add_argument(
...@@ -153,6 +128,9 @@ class S2TCTCModel(FairseqEncoderModel): ...@@ -153,6 +128,9 @@ class S2TCTCModel(FairseqEncoderModel):
"reduced", "reduced",
"rel_selfattn", "rel_selfattn",
"relative", "relative",
"rel_pos",
"rope",
"abs"
], ],
help="transformer encoder self-attention layer type" help="transformer encoder self-attention layer type"
) )
...@@ -320,6 +298,13 @@ class S2TCTCModel(FairseqEncoderModel): ...@@ -320,6 +298,13 @@ class S2TCTCModel(FairseqEncoderModel):
# Conformer setting # Conformer setting
parser.add_argument( parser.add_argument(
"--encoder-activation-fn",
type=str,
default="relu",
choices=utils.get_available_activation_fns(),
help="activation function to use",
)
parser.add_argument(
"--macaron-style", "--macaron-style",
default=False, default=False,
type=bool, type=bool,
...@@ -390,6 +375,12 @@ class S2TCTCModel(FairseqEncoderModel): ...@@ -390,6 +375,12 @@ class S2TCTCModel(FairseqEncoderModel):
type=str, type=str,
help="type of intermedia adapter", help="type of intermedia adapter",
) )
parser.add_argument(
"--intermedia-distribution-cutoff",
default=-1,
type=int,
help="cutoff of the distribution",
)
pass pass
@classmethod @classmethod
...@@ -422,13 +413,15 @@ class S2TCTCModel(FairseqEncoderModel): ...@@ -422,13 +413,15 @@ class S2TCTCModel(FairseqEncoderModel):
def get_normalized_probs( def get_normalized_probs(
self, self,
net_output: Tuple[Tensor, Optional[Dict[str, List[Optional[Tensor]]]]], net_output,
log_probs: bool, log_probs: bool,
sample: Optional[Dict[str, Tensor]] = None, sample: Optional[Dict[str, Tensor]] = None,
): ):
# net_output['encoder_out'] is a (T, B, D) tensor # net_output['encoder_out'] is a (T, B, D) tensor
if isinstance(net_output, list):
logits = net_output[0]
else:
logits = net_output["ctc_logit"][0] logits = net_output["ctc_logit"][0]
# logits = logits.transpose(0, 1)
if log_probs: if log_probs:
return utils.log_softmax(logits.float(), dim=-1) return utils.log_softmax(logits.float(), dim=-1)
else: else:
...@@ -461,16 +454,19 @@ class S2TCTCEncoder(FairseqEncoder): ...@@ -461,16 +454,19 @@ class S2TCTCEncoder(FairseqEncoder):
self.embed_scale = 1.0 self.embed_scale = 1.0
self.padding_idx = 1 self.padding_idx = 1
self.subsample = Conv1dSubsampler( self.subsample = subsampling(args)
args.input_feat_per_channel * args.input_channels,
args.conv_channels,
dim,
[int(k) for k in args.conv_kernel_sizes.split(",")],
)
self.attn_type = getattr(args, "encoder_attention_type", "selfattn") self.attn_type = getattr(args, "encoder_attention_type", "selfattn")
if self.attn_type == "rel_pos":
self.embed_positions = RelPositionalEncoding(
args.max_source_positions, args.encoder_embed_dim
)
elif self.attn_type == "rope":
self.embed_positions = None
else: # Use absolute positional embedding
self.embed_positions = PositionalEmbedding( self.embed_positions = PositionalEmbedding(
args.max_source_positions, dim, self.padding_idx args.max_source_positions, args.encoder_embed_dim, self.padding_idx
) )
self.layers = nn.ModuleList( self.layers = nn.ModuleList(
...@@ -513,6 +509,8 @@ class S2TCTCEncoder(FairseqEncoder): ...@@ -513,6 +509,8 @@ class S2TCTCEncoder(FairseqEncoder):
strategy = None strategy = None
if args.intermedia_adapter == "shrink": if args.intermedia_adapter == "shrink":
strategy = getattr(args, "ctc_compress_strategy", None) strategy = getattr(args, "ctc_compress_strategy", None)
elif args.intermedia_adapter == "league":
strategy = getattr(args, "intermedia_distribution_cutoff", -1)
self.adapter = InterAdapter(dim, args.intermedia_adapter, self.adapter = InterAdapter(dim, args.intermedia_adapter,
task.source_dictionary, strategy=strategy) task.source_dictionary, strategy=strategy)
...@@ -547,26 +545,28 @@ class S2TCTCEncoder(FairseqEncoder): ...@@ -547,26 +545,28 @@ class S2TCTCEncoder(FairseqEncoder):
# down-sampling # down-sampling
x, input_lengths = self.subsample(src_tokens, src_lengths) x, input_lengths = self.subsample(src_tokens, src_lengths)
# (B, T, D) -> (T, B, D)
if type(x) == list: x = x.transpose(0, 1)
inner_x = x
# gather cosine similarity
if self.gather_cos_sim:
for x in inner_x:
cos_sim_idx += 1
self.add_to_dict(x, dis, cos_sim_idx)
x = inner_x[-1]
# embedding scaling # embedding scaling
x = self.embed_scale * x x = self.embed_scale * x
# padding and position embedding # padding and position embedding
encoder_padding_mask = lengths_to_padding_mask(input_lengths) encoder_padding_mask = lengths_to_padding_mask(input_lengths)
if self.attn_type == "rel_pos":
positions = self.embed_positions(x)
elif self.attn_type == "rope":
positions = None
else:
positions = self.embed_positions(encoder_padding_mask).transpose(0, 1) positions = self.embed_positions(encoder_padding_mask).transpose(0, 1)
if self.attn_type != "rel_selfattn":
x += positions x += positions
positions = None
x = self.dropout_module(x) x = self.dropout_module(x)
positions = self.dropout_module(positions) # positions = self.dropout_module(positions)
# add emb into history # add emb into history
if self.history is not None: if self.history is not None:
...@@ -723,16 +723,7 @@ class CTCDecoder(object): ...@@ -723,16 +723,7 @@ class CTCDecoder(object):
ctc_logit = encoder_outs["ctc_logit"][0].transpose(0, 1) ctc_logit = encoder_outs["ctc_logit"][0].transpose(0, 1)
beam_results, beam_scores, timesteps, out_lens = self.ctc_decoder.decode(F.softmax(ctc_logit, -1), src_lengths) beam_results, beam_scores, timesteps, out_lens = self.ctc_decoder.decode(F.softmax(ctc_logit, -1), src_lengths)
# beam_results = beam_results[:, :, :out_lens.max()]
# for beam_idx in range(beam_size):
# top_beam_tokens = beam_results[:, beam_idx, :]
# top_beam_len = out_lens[:, beam_idx]
# mask = torch.arange(0, top_beam_tokens.size(1)).type_as(top_beam_len). \
# repeat(top_beam_len.size(0), 1).lt(top_beam_len.unsqueeze(1))
# top_beam_tokens[~mask] = self.pad
finalized = [] finalized = []
for idx in range(bsz): for idx in range(bsz):
hypos = [] hypos = []
for beam_idx in range(beam_size): for beam_idx in range(beam_size):
...@@ -752,8 +743,13 @@ class CTCDecoder(object): ...@@ -752,8 +743,13 @@ class CTCDecoder(object):
@register_model_architecture(model_name="s2t_ctc", arch_name="s2t_ctc") @register_model_architecture(model_name="s2t_ctc", arch_name="s2t_ctc")
def base_architecture(args): def base_architecture(args):
# Convolutional subsampler # Convolutional subsampler
args.conv_kernel_sizes = getattr(args, "conv_kernel_sizes", "5,5") args.subsampling_type = getattr(args, "subsampling_type", "conv1d")
args.conv_channels = getattr(args, "conv_channels", 1024) args.subsampling_layers = getattr(args, "subsampling_layers", 2)
args.subsampling_filter = getattr(args, "subsampling_filter", 1024)
args.subsampling_kernel = getattr(args, "subsampling_kernel", 5)
args.subsampling_stride = getattr(args, "subsampling_stride", 2)
args.subsampling_norm = getattr(args, "subsampling_norm", "none")
args.subsampling_activation = getattr(args, "subsampling_activation", "glu")
# Transformer # Transformer
args.encoder_embed_dim = getattr(args, "encoder_embed_dim", 512) args.encoder_embed_dim = getattr(args, "encoder_embed_dim", 512)
......
...@@ -19,6 +19,7 @@ from fairseq.modules import ( ...@@ -19,6 +19,7 @@ from fairseq.modules import (
FairseqDropout, FairseqDropout,
LayerNorm, LayerNorm,
PositionalEmbedding, PositionalEmbedding,
RelPositionalEncoding,
S2TTransformerEncoderLayer, S2TTransformerEncoderLayer,
DynamicLinearCombination, DynamicLinearCombination,
) )
...@@ -132,6 +133,9 @@ class S2TTransformerModel(FairseqEncoderDecoderModel): ...@@ -132,6 +133,9 @@ class S2TTransformerModel(FairseqEncoderDecoderModel):
"reduced", "reduced",
"rel_selfattn", "rel_selfattn",
"relative", "relative",
"rel_pos",
"rope",
"abs"
], ],
help="transformer encoder self-attention layer type" help="transformer encoder self-attention layer type"
) )
...@@ -299,6 +303,13 @@ class S2TTransformerModel(FairseqEncoderDecoderModel): ...@@ -299,6 +303,13 @@ class S2TTransformerModel(FairseqEncoderDecoderModel):
# Conformer setting # Conformer setting
parser.add_argument( parser.add_argument(
"--encoder-activation-fn",
type=str,
default="relu",
choices=utils.get_available_activation_fns(),
help="activation function to use",
)
parser.add_argument(
"--macaron-style", "--macaron-style",
default=False, default=False,
type=bool, type=bool,
...@@ -369,6 +380,12 @@ class S2TTransformerModel(FairseqEncoderDecoderModel): ...@@ -369,6 +380,12 @@ class S2TTransformerModel(FairseqEncoderDecoderModel):
type=str, type=str,
help="type of intermedia adapter", help="type of intermedia adapter",
) )
parser.add_argument(
"--intermedia-distribution-cutoff",
default=-1,
type=int,
help="cutoff of the distribution",
)
pass pass
@classmethod @classmethod
...@@ -477,8 +494,16 @@ class S2TTransformerEncoder(FairseqEncoder): ...@@ -477,8 +494,16 @@ class S2TTransformerEncoder(FairseqEncoder):
self.subsample = subsampling(args) self.subsample = subsampling(args)
self.attn_type = getattr(args, "encoder_attention_type", "selfattn") self.attn_type = getattr(args, "encoder_attention_type", "selfattn")
if self.attn_type == "rel_pos":
self.embed_positions = RelPositionalEncoding(
args.max_source_positions, args.encoder_embed_dim
)
elif self.attn_type == "rope":
self.embed_positions = None
else: # Use absolute positional embedding
self.embed_positions = PositionalEmbedding( self.embed_positions = PositionalEmbedding(
args.max_source_positions, dim, self.padding_idx args.max_source_positions, args.encoder_embed_dim, self.padding_idx
) )
self.layers = nn.ModuleList( self.layers = nn.ModuleList(
...@@ -540,6 +565,8 @@ class S2TTransformerEncoder(FairseqEncoder): ...@@ -540,6 +565,8 @@ class S2TTransformerEncoder(FairseqEncoder):
strategy = None strategy = None
if args.intermedia_adapter == "shrink": if args.intermedia_adapter == "shrink":
strategy = getattr(args, "ctc_compress_strategy", None) strategy = getattr(args, "ctc_compress_strategy", None)
elif args.intermedia_adapter == "league":
strategy = getattr(args, "intermedia_distribution_cutoff", -1)
self.adapter = InterAdapter(dim, args.intermedia_adapter, self.adapter = InterAdapter(dim, args.intermedia_adapter,
task.source_dictionary, strategy=strategy) task.source_dictionary, strategy=strategy)
...@@ -586,11 +613,20 @@ class S2TTransformerEncoder(FairseqEncoder): ...@@ -586,11 +613,20 @@ class S2TTransformerEncoder(FairseqEncoder):
# padding and position embedding # padding and position embedding
encoder_padding_mask = lengths_to_padding_mask(input_lengths) encoder_padding_mask = lengths_to_padding_mask(input_lengths)
if self.attn_type == "rel_pos":
positions = self.embed_positions(x)
elif self.attn_type == "rope":
positions = None
else:
positions = self.embed_positions(encoder_padding_mask).transpose(0, 1) positions = self.embed_positions(encoder_padding_mask).transpose(0, 1)
if self.attn_type != "rel_selfattn":
x += positions x += positions
positions = None
x = self.dropout_module(x) x = self.dropout_module(x)
positions = self.dropout_module(positions) # positions = self.dropout_module(positions)
# add emb into history # add emb into history
if self.history is not None: if self.history is not None:
...@@ -742,8 +778,13 @@ class TransformerDecoderScriptable(TransformerDecoder): ...@@ -742,8 +778,13 @@ class TransformerDecoderScriptable(TransformerDecoder):
@register_model_architecture(model_name="s2t_transformer", arch_name="s2t_transformer") @register_model_architecture(model_name="s2t_transformer", arch_name="s2t_transformer")
def base_architecture(args): def base_architecture(args):
# Convolutional subsampler # Convolutional subsampler
args.conv_kernel_sizes = getattr(args, "conv_kernel_sizes", "5,5") args.subsampling_type = getattr(args, "subsampling_type", "conv1d")
args.conv_channels = getattr(args, "conv_channels", 1024) args.subsampling_layers = getattr(args, "subsampling_layers", 2)
args.subsampling_filter = getattr(args, "subsampling_filter", 1024)
args.subsampling_kernel = getattr(args, "subsampling_kernel", 5)
args.subsampling_stride = getattr(args, "subsampling_stride", 2)
args.subsampling_norm = getattr(args, "subsampling_norm", "none")
args.subsampling_activation = getattr(args, "subsampling_activation", "glu")
# Transformer # Transformer
args.encoder_embed_dim = getattr(args, "encoder_embed_dim", 512) args.encoder_embed_dim = getattr(args, "encoder_embed_dim", 512)
...@@ -791,6 +832,7 @@ def base_architecture(args): ...@@ -791,6 +832,7 @@ def base_architecture(args):
args.ctc_layer = getattr(args, "ctc_layer", 0) args.ctc_layer = getattr(args, "ctc_layer", 0)
# Conformer # Conformer
args.encoder_activation_fn = getattr(args, "encoder_activation_fn", "relu")
args.macaron_style = getattr(args, "macaron_style", False) args.macaron_style = getattr(args, "macaron_style", False)
args.use_cnn_module = getattr(args, "use_cnn_module", False) args.use_cnn_module = getattr(args, "use_cnn_module", False)
args.cnn_module_kernel = getattr(args, "cnn_module_kernel", 31) args.cnn_module_kernel = getattr(args, "cnn_module_kernel", 31)
......
...@@ -10,7 +10,6 @@ from .adaptive_input import AdaptiveInput ...@@ -10,7 +10,6 @@ from .adaptive_input import AdaptiveInput
from .adaptive_softmax import AdaptiveSoftmax from .adaptive_softmax import AdaptiveSoftmax
from .beamable_mm import BeamableMM from .beamable_mm import BeamableMM
from .character_token_embedder import CharacterTokenEmbedder from .character_token_embedder import CharacterTokenEmbedder
from .convolution import ConvolutionModule
from .downsample_convolution import DownSampleConvolutionModule from .downsample_convolution import DownSampleConvolutionModule
from .conv_tbc import ConvTBC from .conv_tbc import ConvTBC
from .cross_entropy import cross_entropy from .cross_entropy import cross_entropy
...@@ -30,6 +29,7 @@ from .learned_positional_embedding import LearnedPositionalEmbedding ...@@ -30,6 +29,7 @@ from .learned_positional_embedding import LearnedPositionalEmbedding
from .lightweight_convolution import LightweightConv, LightweightConv1dTBC from .lightweight_convolution import LightweightConv, LightweightConv1dTBC
from .linearized_convolution import LinearizedConvolution from .linearized_convolution import LinearizedConvolution
from .local_multihead_attention import LocalMultiheadAttention from .local_multihead_attention import LocalMultiheadAttention
from .location_attention import LocationAttention
from .multihead_attention import MultiheadAttention from .multihead_attention import MultiheadAttention
from .positional_embedding import PositionalEmbedding from .positional_embedding import PositionalEmbedding
from .reduced_multihead_attention import ReducedMultiheadAttention from .reduced_multihead_attention import ReducedMultiheadAttention
...@@ -44,6 +44,16 @@ from .transpose_last import TransposeLast ...@@ -44,6 +44,16 @@ from .transpose_last import TransposeLast
from .unfold import unfold1d from .unfold import unfold1d
from .transformer_layer import TransformerDecoderLayer, TransformerEncoderLayer from .transformer_layer import TransformerDecoderLayer, TransformerEncoderLayer
from .vggblock import VGGBlock from .vggblock import VGGBlock
from .espnet_multihead_attention import (
ESPNETMultiHeadedAttention,
RelPositionMultiHeadedAttention,
RotaryPositionMultiHeadedAttention,
)
from .rotary_positional_embedding import RotaryPositionalEmbedding
from .positional_encoding import (
RelPositionalEncoding,
)
from .convolution import ConvolutionModule
from .s2t_transformer_layer import S2TTransformerEncoderLayer from .s2t_transformer_layer import S2TTransformerEncoderLayer
from .pds_layer import PDSTransformerEncoderLayer from .pds_layer import PDSTransformerEncoderLayer
...@@ -77,6 +87,7 @@ __all__ = [ ...@@ -77,6 +87,7 @@ __all__ = [
"LightweightConv", "LightweightConv",
"LinearizedConvolution", "LinearizedConvolution",
"LocalMultiheadAttention", "LocalMultiheadAttention",
"MultiheadAttention", "MultiheadAttention",
"PositionalEmbedding", "PositionalEmbedding",
"PDSTransformerEncoderLayer", "PDSTransformerEncoderLayer",
...@@ -96,4 +107,10 @@ __all__ = [ ...@@ -96,4 +107,10 @@ __all__ = [
"TransposeLast", "TransposeLast",
"VGGBlock", "VGGBlock",
"unfold1d", "unfold1d",
"ESPNETMultiheadedAttention",
"PositionalEmbedding",
"RelPositionMultiHeadedAttention",
"RelPositionalEncoding",
"RotaryPositionalEmbedding",
"RotaryPositionMultiHeadedAttention",
] ]
import torch import torch
import torch.nn as nn import torch.nn as nn
import torch.nn.functional as F
from typing import Callable
def get_activation_fn(activation: str) -> Callable:
""" Returns the activation function corresponding to `activation` """
from fairseq.modules import gelu, gelu_accurate
if activation == "relu":
return F.relu
elif activation == "gelu":
return gelu
elif activation == "gelu_fast":
return gelu_accurate
elif activation == "gelu_accurate":
return gelu_accurate
elif activation == "tanh":
return torch.tanh
elif activation == "linear":
return lambda x: x
elif activation == "swish":
return torch.nn.SiLU
else:
raise RuntimeError("--activation-fn {} not supported".format(activation))
def get_activation_class(activation: str, dim=None): def get_activation_class(activation: str, dim=None):
......
# Copyright 2021, Maxime Burchi.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import torch
import torch.nn as nn
import torch.nn.functional as F
###############################################################################
# Multi-Head Attention Layers
###############################################################################
class MultiHeadAttention(nn.Module):
"""Multi-Head Attention Layer
Args:
dim_model: model feature dimension
num_heads: number of attention heads
References:
Attention Is All You Need, Vaswani et al.
https://arxiv.org/abs/1706.03762
"""
def __init__(self, dim_model, num_heads):
super(MultiHeadAttention, self).__init__()
# Attention Params
self.num_heads = num_heads # H
self.dim_model = dim_model # D
self.dim_head = dim_model // num_heads # d
# Linear Layers
self.query_layer = nn.Linear(self.dim_model, self.dim_model)
self.key_layer = nn.Linear(self.dim_model, self.dim_model)
self.value_layer = nn.Linear(self.dim_model, self.dim_model)
self.output_layer = nn.Linear(self.dim_model, self.dim_model)
def forward(self, query, key, value, mask=None):
"""Scaled Dot-Product Multi-Head Attention
Args:
query: Query of shape (B, T, D)
key: Key of shape (B, T, D)
value: Value of shape (B, T, D)
mask: Optional position mask of shape (1 or B, 1 or H, 1 or T, 1 or T)
Return:
out: Attention output of shape (B, T, D)
att_w: Attention weights of shape (B, H, T, T)
"""
# Batch size B
batch_size = query.size(0)
# Linear Layers
query = self.query_layer(query)
key = self.key_layer(key)
value = self.value_layer(value)
# Reshape and Transpose (B, T, D) -> (B, H, T, d)
query = query.reshape(batch_size, -1, self.num_heads, self.dim_head).transpose(1, 2)
key = key.reshape(batch_size, -1, self.num_heads, self.dim_head).transpose(1, 2)
value = value.reshape(batch_size, -1, self.num_heads, self.dim_head).transpose(1, 2)
# Att scores (B, H, T, T)
att_scores = query.matmul(key.transpose(2, 3)) / key.shape[-1] ** 0.5
# Apply mask
if mask is not None:
att_scores += (mask * -1e9)
# Att weights (B, H, T, T)
att_w = att_scores.softmax(dim=-1)
# Att output (B, H, T, d)
out = att_w.matmul(value)
# Transpose and Reshape (B, H, T, d) -> (B, T, D)
out = out.transpose(1, 2).reshape(batch_size, -1, self.dim_model)
# Output linear layer
out = self.output_layer(out)
return out, att_w.detach()
def pad(self, query, key, value, mask, chunk_size):
# Compute Overflows
overflow_Q = query.size(1) % chunk_size
overflow_KV = key.size(1) % chunk_size
padding_Q = chunk_size - overflow_Q if overflow_Q else 0
padding_KV = chunk_size - overflow_KV if overflow_KV else 0
batch_size, seq_len_KV, _ = key.size()
# Input Padding (B, T, D) -> (B, T + P, D)
query = F.pad(query, (0, 0, 0, padding_Q), value=0)
key = F.pad(key, (0, 0, 0, padding_KV), value=0)
value = F.pad(value, (0, 0, 0, padding_KV), value=0)
# Update Padding Mask
if mask is not None:
# (B, 1, 1, T) -> (B, 1, 1, T + P)
if mask.size(2) == 1:
mask = F.pad(mask, pad=(0, padding_KV), value=1)
# (B, 1, T, T) -> (B, 1, T + P, T + P)
else:
mask = F.pad(mask, pad=(0, padding_Q, 0, padding_KV), value=1)
elif padding_KV:
# None -> (B, 1, 1, T + P)
mask = F.pad(query.new_zeros(batch_size, 1, 1, seq_len_KV), pad=(0, padding_KV), value=1)
return query, key, value, mask, padding_Q
class GroupedMultiHeadAttention(MultiHeadAttention):
"""Grouped Multi-Head Attention Layer
Grouped multi-head attention reduces attention complexity from out(T2·D) to out(T2·D/G)
by grouping neighbouring time elements along the feature dimension before applying
scaled dot-product attention.
Args:
dim_model: model feature dimension
num_heads: number of attention heads
group_size: attention group size
"""
def __init__(self, dim_model, num_heads, group_size):
super(GroupedMultiHeadAttention, self).__init__(dim_model, num_heads)
# Attention Params
self.group_size = group_size # G
self.dim_head = (self.group_size * dim_model) // self.num_heads # d
def forward(self, query, key, value, mask=None):
# Batch size B
batch_size = query.size(0)
# Linear Layers
query = self.query_layer(query)
key = self.key_layer(key)
value = self.value_layer(value)
# Chunk Padding
query, key, value, mask, padding = self.pad(query, key, value, mask, chunk_size=self.group_size)
# Reshape and Transpose (B, T, D) -> (B, H, T//G, d)
query = query.reshape(batch_size, -1, self.num_heads, self.dim_head).transpose(1, 2)
key = key.reshape(batch_size, -1, self.num_heads, self.dim_head).transpose(1, 2)
value = value.reshape(batch_size, -1, self.num_heads, self.dim_head).transpose(1, 2)
# Att scores (B, H, T//G, T//G)
att_scores = query.matmul(key.transpose(2, 3)) / key.shape[-1] ** 0.5
# Apply mask
if mask is not None:
# Slice Mask (B, 1, T, T) -> (B, 1, T//G, T//G)
mask = mask[:, :, ::self.group_size, ::self.group_size]
# Apply mask
att_scores += (mask * -1e9)
# Att weights (B, H, T//G, T//G)
att_w = att_scores.softmax(dim=-1)
# Att output (B, H, T//G, d)
out = att_w.matmul(value)
# Transpose and Reshape (B, H, T//G, d) -> (B, T, D)
out = out.transpose(1, 2).reshape(batch_size, -1, self.dim_model)
# Slice Padding
out = out[:, :out.size(1) - padding]
# Output linear layer
out = self.output_layer(out)
return out, att_w.detach()
class LocalMultiHeadAttention(MultiHeadAttention):
"""Local Multi-Head Attention Layer
Local multi-head attention restricts the attended positions to a local neighborhood
around the query position. This is achieved by segmenting the hidden sequence into
non overlapping blocks of size key and performing scaled dot-product attention in
parallel for each of these blocks.
Args:
dim_model: model feature dimension
num_heads: number of attention heads
kernel_size: attention kernel size / window
References:
Image Transformer, Parmar et al.
https://arxiv.org/abs/1802.05751
"""
def __init__(self, dim_model, num_heads, kernel_size):
super(LocalMultiHeadAttention, self).__init__(dim_model, num_heads)
# Attention Params
self.kernel_size = kernel_size # key
def forward(self, query, key, value, mask=None):
# Batch size B
batch_size = query.size(0)
# Linear Layers
query = self.query_layer(query)
key = self.key_layer(key)
value = self.value_layer(value)
# Chunk Padding
query, key, value, mask, padding = self.pad(query, key, value, mask, chunk_size=self.kernel_size)
# Reshape and Transpose (B, T, D) -> (B, T//key, H, key, d)
query = query.reshape(batch_size, -1, self.kernel_size, self.num_heads, self.dim_head).transpose(2, 3)
key = key.reshape(batch_size, -1, self.kernel_size, self.num_heads, self.dim_head).transpose(2, 3)
value = value.reshape(batch_size, -1, self.kernel_size, self.num_heads, self.dim_head).transpose(2, 3)
# Att scores (B, T//key, H, key, key)
att_scores = query.matmul(key.transpose(3, 4)) / key.shape[-1] ** 0.5
# Apply mask
if mask is not None:
# Slice mask (B, 1, T, T) -> (B, T//key, 1, key, key)
masks = []
for m in range(mask.size(-1) // self.kernel_size):
masks.append(mask[:, :, m * self.kernel_size: (m + 1) * self.kernel_size,
m * self.kernel_size: (m + 1) * self.kernel_size])
mask = torch.stack(masks, dim=1)
# Apply mask
att_scores = att_scores.float() - mask.float() * 1e9
# Att weights (B, T//key, H, key, key)
att_w = att_scores.softmax(dim=-1)
# Att output (B, T//key, H, key, d)
out = att_w.matmul(value)
# Transpose and Reshape (B, T//key, H, key, d) -> (B, T, D)
out = out.transpose(2, 3).reshape(batch_size, -1, self.dim_model)
# Slice Padding
out = out[:, :out.size(1) - padding]
# Output linear layer
out = self.output_layer(out)
return out, att_w.detach()
class StridedMultiHeadAttention(MultiHeadAttention):
"""Strided Multi-Head Attention Layer
Strided multi-head attention performs global sequence downsampling by striding
the attention query before aplying scaled dot-product attention. This results in
strided attention maps where query positions can attend to the entire sequence
context to perform downsampling.
Args:
dim_model: model feature dimension
num_heads: number of attention heads
stride: query stride
"""
def __init__(self, dim_model, num_heads, stride):
super(StridedMultiHeadAttention, self).__init__(dim_model, num_heads)
# Attention Params
self.stride = stride # S
def forward(self, query, key, value, mask=None):
# Query Subsampling (B, T, D) -> (B, T//S, D)
query = query[:, ::self.stride]
# Mask Subsampling (B, 1, T, T) -> (B, 1, T//S, T)
if mask is not None:
mask = mask[:, :, ::self.stride]
# Multi-Head Attention
return super(StridedMultiHeadAttention, self).forward(query, key, value, mask)
class StridedLocalMultiHeadAttention(MultiHeadAttention):
"""Strided Local Multi-Head Attention Layer
Args:
dim_model: model feature dimension
num_heads: number of attention heads
kernel_size: attention kernel size / window
stride: query stride
"""
def __init__(self, dim_model, num_heads, kernel_size, stride):
super(StridedLocalMultiHeadAttention, self).__init__(dim_model, num_heads)
# Assert
assert kernel_size % stride == 0, "Attention kernel size has to be a multiple of attention stride"
# Attention Params
self.kernel_size = kernel_size # key
self.stride = stride # S
def forward(self, query, key, value, mask=None):
# Batch size B
batch_size = query.size(0)
# Query Subsampling (B, T, D) -> (B, T//S, D)
query = query[:, ::self.stride]
# Linear Layers
query = self.query_layer(query)
key = self.key_layer(key)
value = self.value_layer(value)
# Chunk Padding
query, key, value, mask, padding = self.pad(query, key, value, mask, chunk_size=self.kernel_size)
# Reshape and Transpose (B, T//S, D) -> (B, T//key, H, key//S, d)
query = query.reshape(batch_size, -1, self.kernel_size // self.stride, self.num_heads, self.dim_head).transpose(2, 3)
# Reshape and Transpose (B, T, D) -> (B, T//key, H, key, d)
key = key.reshape(batch_size, -1, self.kernel_size, self.num_heads, self.dim_head).transpose(2, 3)
value = value.reshape(batch_size, -1, self.kernel_size, self.num_heads, self.dim_head).transpose(2, 3)
# Att scores (B, T//key, H, key//S, key)
att_scores = query.matmul(key.transpose(3, 4)) / key.shape[-1] ** 0.5
# Apply mask
if mask is not None:
# Slice mask (B, 1, T, T) -> (B, T//key, 1, key, key)
masks = []
for m in range(mask.size(-1) // self.kernel_size):
masks.append(mask[:, :, m * self.kernel_size: (m + 1) * self.kernel_size,
m * self.kernel_size: (m + 1) * self.kernel_size])
mask = torch.stack(masks, dim=1)
# Subsample mask (B, T//key, 1, key, key) -> (B, T//key, 1, key//S, key)
mask = mask[:, :, :, ::self.stride]
# Apply mask
att_scores = att_scores.float() - mask.float() * 1e9
# Att weights (B, T//key, H, key//S, key)
att_w = att_scores.softmax(dim=-1)
# Att output (B, T//key, H, key//S, d)
out = att_w.matmul(value)
# Transpose and Reshape (B, T//key, H, key//S, d) -> (B, T//S, D)
out = out.transpose(2, 3).reshape(batch_size, -1, self.dim_model)
# Slice Padding
out = out[:, :(out.size(1) - padding - 1) // self.stride + 1]
# Output linear layer
out = self.output_layer(out)
return out, att_w.detach()
class MultiHeadLinearAttention(MultiHeadAttention):
"""Multi-Head Linear Attention
Args:
dim_model: model feature dimension
num_heads: number of attention heads
References:
Efficient Attention: Attention with Linear Complexities, Shen et al.
https://arxiv.org/abs/1812.01243
Efficient conformer-based speech recognition with linear attention, Li et al.
https://arxiv.org/abs/2104.06865
"""
def __init__(self, dim_model, num_heads):
super(MultiHeadLinearAttention, self).__init__(dim_model, num_heads)
def forward(self, query, key, value):
# Batch size B
batch_size = query.size(0)
# Linear Layers
query = self.query_layer(query)
key = self.key_layer(key)
value = self.value_layer(value)
# Reshape and Transpose (B, T, D) -> (B, N, T, d)
query = query.reshape(batch_size, -1, self.num_heads, self.dim_head).transpose(1, 2)
key = key.reshape(batch_size, -1, self.num_heads, self.dim_head).transpose(1, 2)
value = value.reshape(batch_size, -1, self.num_heads, self.dim_head).transpose(1, 2)
# Global Context Vector (B, N, d, d)
KV = (key / key.shape[-1] ** (1.0 / 4.0)).softmax(dim=-2).transpose(2, 3).matmul(value)
# Attention Output (B, N, T, d)
out = (query / query.shape[-1] ** (1.0 / 4.0)).softmax(dim=-1).matmul(KV)
# Transpose and Reshape (B, N, T, d) -> (B, T, D)
out = out.transpose(1, 2).reshape(batch_size, -1, self.dim_model)
# Output linear layer
out = self.output_layer(out)
return out, KV.detach()
###############################################################################
# Multi-Head Self-Attention Layers with Relative Sinusoidal Poditional Encodings
###############################################################################
class RelPosMultiHeadSelfAttention(MultiHeadAttention):
"""Multi-Head Self-Attention Layer with Relative Sinusoidal Positional Encodings
Args:
dim_model: model feature dimension
num_heads: number of attention heads
causal: whether the attention is causal or unmasked
max_pos_encoding: maximum relative distance between elements
References:
Transformer-XL: Attentive Language Models Beyond a Fixed-Length Context, Dai et al.
https://arxiv.org/abs/1901.02860
"""
def __init__(self, dim_model, num_heads, causal, max_pos_encoding):
super(RelPosMultiHeadSelfAttention, self).__init__(dim_model, num_heads)
# Position Embedding Layer
self.pos_layer = nn.Linear(self.dim_model, self.dim_model)
self.causal = causal
# Global content and positional bias
self.u = nn.Parameter(torch.Tensor(self.dim_model)) # Content bias
self.v = nn.Parameter(torch.Tensor(self.dim_model)) # Pos bias
torch.nn.init.xavier_uniform_(self.u.reshape(self.num_heads, self.dim_head)) # glorot uniform
torch.nn.init.xavier_uniform_(self.v.reshape(self.num_heads, self.dim_head)) # glorot uniform
# Relative Sinusoidal Positional Encodings
self.rel_pos_enc = RelativeSinusoidalPositionalEncoding(max_pos_encoding, self.dim_model, self.causal)
def rel_to_abs(self, att_scores):
"""Relative to absolute position indexing
Args:
att_scores: absolute-by-relative indexed attention scores of shape
(B, H, T, Th + 2*T-1) for full context and (B, H, T, Th + T) for causal context
Return:
att_scores: absolute-by-absolute indexed attention scores of shape (B, H, T, Th + T)
References:
causal context:
Music Transformer, Huang et al.
https://arxiv.org/abs/1809.04281
full context:
Attention Augmented Convolutional Networks, Bello et al.
https://arxiv.org/abs/1904.09925
"""
# Causal Context
if self.causal:
# Att Scores (B, H, T, Th + T)
batch_size, num_heads, seq_length1, seq_length2 = att_scores.size()
# Column Padding (B, H, T, 1 + Th + T)
att_scores = F.pad(att_scores, pad=(1, 0), value=0)
# Flatten (B, H, T + TTh + TT)
att_scores = att_scores.reshape(batch_size, num_heads, -1)
# Start Padding (B, H, Th + T + TTh + TT)
att_scores = F.pad(att_scores, pad=(seq_length2 - seq_length1, 0), value=0)
# Reshape (B, H, 1 + T, Th + T)
att_scores = att_scores.reshape(batch_size, num_heads, 1 + seq_length1, seq_length2)
# Slice (B, H, T, Th + T)
att_scores = att_scores[:, :, 1:]
# Full Context
else:
# Att Scores (B, H, T, Th + 2*T-1)
batch_size, num_heads, seq_length1, seq_length2 = att_scores.size()
# Column Padding (B, H, T, Th + 2*T)
att_scores = F.pad(att_scores, pad=(0, 1), value=0)
# Flatten (B, H, TTh + 2*TT)
att_scores = att_scores.reshape(batch_size, num_heads, -1)
# End Padding (B, H, TTh + 2*TT + Th + T - 1)
att_scores = F.pad(att_scores, pad=(0, seq_length2 - seq_length1), value=0)
# Reshape (B, H, T + 1, Th + 2*T-1)
att_scores = att_scores.reshape(batch_size, num_heads, 1 + seq_length1, seq_length2)
# Slice (B, H, T, Th + T)
att_scores = att_scores[:, :, :seq_length1, seq_length1 - 1:]
return att_scores
def forward(self, query, key, value, mask=None, hidden=None):
"""Scaled Dot-Product Self-Attention with relative sinusoidal position encodings
Args:
query: Query of shape (B, T, D)
key: Key of shape (B, T, D)
value: Value of shape (B, T, D)
mask: Optional position mask of shape (1 or B, 1 or H, 1 or T, 1 or T)
hidden: Optional Key and Value hidden states for decoding
Return:
out: Attention output of shape (B, T, D)
att_w: Attention weights of shape (B, H, T, Th + T)
hidden: Key and value hidden states
"""
# Batch size B
batch_size = query.size(0)
# Linear Layers
query = self.query_layer(query)
key = self.key_layer(key)
value = self.value_layer(value)
# Hidden State Provided
if hidden:
key = torch.cat([hidden["key"], key], dim=1)
value = torch.cat([hidden["value"], value], dim=1)
# Update Hidden State
hidden = {"key": key.detach(), "value": value.detach()}
# Add Bias
Qu = query + self.u
Qv = query + self.v
# Relative Positional Embeddings (B, Th + 2*T-1, D) / (B, Th + T, D)
E = self.pos_layer(self.rel_pos_enc(batch_size, query.size(1), key.size(1) - query.size(1)))
# Reshape and Transpose (B, T, D) -> (B, H, T, d)
Qu = Qu.reshape(batch_size, -1, self.num_heads, self.dim_head).transpose(1, 2)
Qv = Qv.reshape(batch_size, -1, self.num_heads, self.dim_head).transpose(1, 2)
# Reshape and Transpose (B, Th + T, D) -> (B, H, Th + T, d)
key = key.reshape(batch_size, -1, self.num_heads, self.dim_head).transpose(1, 2)
value = value.reshape(batch_size, -1, self.num_heads, self.dim_head).transpose(1, 2)
# Reshape and Transpose (B, Th + 2*T-1, D) -> (B, H, Th + 2*T-1, d) / (B, Th + T, D) -> (B, H, Th + T, d)
E = E.reshape(batch_size, -1, self.num_heads, self.dim_head).transpose(1, 2)
# att_scores (B, H, T, Th + T)
att_scores_K = Qu.matmul(key.transpose(2, 3))
att_scores_E = self.rel_to_abs(Qv.matmul(E.transpose(2, 3)))
att_scores = (att_scores_K + att_scores_E) / key.shape[-1] ** 0.5
# Apply mask
if mask is not None:
att_scores += (mask * -1e9)
# Att weights (B, H, T, Th + T)
att_w = att_scores.softmax(dim=-1)
# Att output (B, H, T, d)
out = att_w.matmul(value)
# Transpose and Reshape (B, H, T, d) -> (B, T, D)
out = out.transpose(1, 2).reshape(batch_size, -1, self.dim_model)
# Output linear layer
out = self.output_layer(out)
return out, att_w.detach(), hidden
class GroupedRelPosMultiHeadSelfAttention(RelPosMultiHeadSelfAttention):
"""Grouped Multi-Head Self-Attention Layer with Relative Sinusoidal Positional Encodings
Args:
dim_model: model feature dimension
num_heads: number of attention heads
causal: whether the attention is causal or unmasked
max_pos_encoding: maximum relative distance between elements
group_size: attention group size
"""
def __init__(self, dim_model, num_heads, causal, max_pos_encoding, group_size):
super(GroupedRelPosMultiHeadSelfAttention, self).__init__(dim_model, num_heads, causal, max_pos_encoding)
# Attention Params
self.group_size = group_size # G
self.dim_head = (self.group_size * dim_model) // self.num_heads # d
# Grouped Relative Sinusoidal Positional Encodings
self.rel_pos_enc = GroupedRelativeSinusoidalPositionalEncoding(max_pos_encoding, self.dim_model,
self.group_size, self.causal)
def forward(self, query, key, value, mask=None, hidden=None):
# Batch size B
batch_size = query.size(0)
# Linear Layers
query = self.query_layer(query)
key = self.key_layer(key)
value = self.value_layer(value)
# Hidden State Provided
if hidden:
Kh = torch.cat([hidden["key"], key], dim=1)
Vh = torch.cat([hidden["value"], value], dim=1)
key = torch.cat([hidden["key"][:, hidden["key"].size(1) % self.group_size:], key], dim=1)
value = torch.cat([hidden["value"][:, hidden["value"].size(1) % self.group_size:], value], dim=1)
# Update Hidden State
hidden = {"key": Kh.detach(), "value": Vh.detach()}
else:
# Update Hidden State
hidden = {"key": key.detach(), "value": value.detach()}
# Chunk Padding
query, key, value, mask, padding = self.pad(query, key, value, mask, chunk_size=self.group_size)
# Add Bias
Qu = query + self.u
Qv = query + self.v
# Relative Positional Embeddings (B, Th + 2*T-G, D) / (B, Th + T, D)
E = self.pos_layer(self.rel_pos_enc(batch_size, query.size(1), key.size(1) - query.size(1)))
# Reshape and Transpose (B, T, D) -> (B, H, T//G, d)
Qu = Qu.reshape(batch_size, -1, self.num_heads, self.dim_head).transpose(1, 2)
Qv = Qv.reshape(batch_size, -1, self.num_heads, self.dim_head).transpose(1, 2)
# Reshape and Transpose (B, Th + T, D) -> (B, H, Th//G + T//G, d)
key = key.reshape(batch_size, -1, self.num_heads, self.dim_head).transpose(1, 2)
value = value.reshape(batch_size, -1, self.num_heads, self.dim_head).transpose(1, 2)
# Reshape and Transpose (B, Th + 2*T-G, D) -> (B, H, Th//G + 2*T//G-1, d) / (B, Th + T, D) -> (B, H, Th//G + T//G, d)
E = E.reshape(batch_size, -1, self.num_heads, self.dim_head).transpose(1, 2)
# att_scores (B, H, T//G, Th//G + T//G)
att_scores_K = Qu.matmul(key.transpose(2, 3))
att_scores_E = self.rel_to_abs(Qv.matmul(E.transpose(2, 3)))
att_scores = (att_scores_K + att_scores_E) / key.shape[-1] ** 0.5
# Apply mask
if mask is not None:
# Slice Mask (B, 1, T, T) -> (B, 1, T//G, T//G)
mask = mask[:, :, ::self.group_size, ::self.group_size]
# Apply mask
att_scores += (mask * -1e9)
# Att weights (B, H, T//G, Th//G + T//G)
att_w = att_scores.softmax(dim=-1)
# Att output (B, H, T//G, d)
out = att_w.matmul(value)
# Transpose and Reshape (B, H, T//G, d) -> (B, T, D)
out = out.transpose(1, 2).reshape(batch_size, -1, self.dim_model)
# Slice Padding
out = out[:, :out.size(1) - padding]
# Output linear layer
out = self.output_layer(out)
return out, att_w.detach(), hidden
class LocalRelPosMultiHeadSelfAttention(RelPosMultiHeadSelfAttention):
"""Local Multi-Head Self-Attention with Relative Sinusoidal Positional Encodings
Args:
dim_model: model feature dimension
num_heads: number of attention heads
causal: whether the attention is causal or unmasked
kernel_size: attention kernel size / window
References:
Music Transformer, Huang et al.
https://arxiv.org/abs/1809.04281
"""
def __init__(self, dim_model, num_heads, causal, kernel_size):
super(LocalRelPosMultiHeadSelfAttention, self).__init__(dim_model, num_heads, causal, kernel_size)
# Attention Params
self.kernel_size = kernel_size # key
def rel_to_abs(self, att_scores):
"""Relative to absolute position indexing
Args:
att_scores: absolute-by-relative indexed attention scores of shape
(B, N, T, 2 * key - 1) for full context and (B, H, T, key) for causal context
Return:
att_scores: absolute-by-absolute indexed attention scores of shape (B, T//key, H, key, key)
References:
Causal context:
Music Transformer, Huang et al.
https://arxiv.org/abs/1809.04281
"""
# Causal Context
if self.causal:
# Att Scores (B, H, T, key)
batch_size, num_heads, seq_length1, seq_length2 = att_scores.size()
# Reshape (B, T//key, H, key, key)
att_scores = att_scores.reshape(batch_size, -1, self.num_heads, self.kernel_size, self.kernel_size)
# Column Padding (B, T//key, H, key, 1 + key)
att_scores = F.pad(att_scores, pad=(1, 0), value=0)
# Reshape (B, T//key, H, 1 + key, key)
att_scores = att_scores.reshape(batch_size, -1, self.num_heads, self.kernel_size + 1, self.kernel_size)
# Slice (B, T//key, H, key, key)
att_scores = att_scores[:, :, :, 1:]
# Full Context
else:
# Att Scores (B, H, T, 2 * key - 1)
batch_size, num_heads, seq_length1, seq_length2 = att_scores.size()
# Reshape (B, T//key, H, key, 2 * key - 1)
att_scores = att_scores.reshape(batch_size, -1, self.num_heads, self.kernel_size, seq_length2)
# Column Padding (B, T//key, H, key, 2 * key)
att_scores = F.pad(att_scores, pad=(0, 1), value=0)
# Flatten (B, T//key, H, key * 2 * key)
att_scores = att_scores.reshape(batch_size, -1, self.num_heads, 2 * self.kernel_size ** 2)
# End Padding (B, T//key, H, key * 2 * key + key - 1)
att_scores = F.pad(att_scores, pad=(0, self.kernel_size - 1), value=0)
# Reshape (B, T//key, H, key + 1, 2 * key - 1)
att_scores = att_scores.reshape(batch_size, -1, self.num_heads, self.kernel_size + 1, seq_length2)
# Slice (B, T//key, H, key, key)
att_scores = att_scores[:, :, :, :self.kernel_size, self.kernel_size - 1:]
return att_scores
def forward(self, query, key, value, mask=None, hidden=None):
# Batch size B
batch_size = query.size(0)
# Linear Layers
query = self.query_layer(query)
key = self.key_layer(key)
value = self.value_layer(value)
# Chunk Padding
query, key, value, mask, padding = self.pad(query, key, value, mask, chunk_size=self.kernel_size)
# Add Bias
Qu = query + self.u
Qv = query + self.v
# Relative Positional Embeddings (B, 2*key-1, D) / (B, key, D)
E = self.pos_layer(self.rel_pos_enc(batch_size))
# Reshape and Transpose (B, T, D) -> (B, H, T, d)
Qv = Qv.reshape(batch_size, -1, self.num_heads, self.dim_head).transpose(1, 2)
# Reshape and Transpose (B, T, D) -> (B, T//key, H, key, d)
Qu = Qu.reshape(batch_size, -1, self.kernel_size, self.num_heads, self.dim_head).transpose(2, 3)
key = key.reshape(batch_size, -1, self.kernel_size, self.num_heads, self.dim_head).transpose(2, 3)
value = value.reshape(batch_size, -1, self.kernel_size, self.num_heads, self.dim_head).transpose(2, 3)
# Reshape and Transpose (B, 2*key-1, D) -> (B, H, 2*key-1, d) / (B, key, D) -> (B, H, key, d)
E = E.reshape(batch_size, -1, self.num_heads, self.dim_head).transpose(1, 2)
# att_scores (B, T//key, H, key, key)
att_scores_K = Qu.matmul(key.transpose(3, 4))
att_scores_E = self.rel_to_abs(Qv.matmul(E.transpose(2, 3)))
att_scores = (att_scores_K + att_scores_E) / key.shape[-1] ** 0.5
# Mask scores
if mask is not None:
# Diagonal Mask (B, 1, T, T) -> (B, T//key, 1, key, key)
masks = []
for m in range(mask.size(-1) // self.kernel_size):
masks.append(mask[:, :, m * self.kernel_size: (m + 1) * self.kernel_size,
m * self.kernel_size: (m + 1) * self.kernel_size])
mask = torch.stack(masks, dim=1)
# Apply Mask
att_scores = att_scores.float() - mask.float() * 1e9
# Attention weights (B, T//key, H, key, key)
att_w = att_scores.softmax(dim=-1)
# Attention output (B, T//key, H, key, d)
out = att_w.matmul(value)
# Transpose and Reshape (B, T//key, H, key, d) -> (B, T, D)
out = out.transpose(2, 3).reshape(batch_size, -1, self.dim_model)
# Slice Padding
out = out[:, :out.size(1) - padding]
# Output linear layer
out = self.output_layer(out)
return out, att_w.detach(), hidden
class StridedRelPosMultiHeadSelfAttention(RelPosMultiHeadSelfAttention):
"""Strided Multi-Head Self-Attention with Relative Sinusoidal Positional Encodings
Args:
dim_model: model feature dimension
num_heads: number of attention heads
causal: whether the attention is causal or unmasked
max_pos_encoding: maximum relative distance between elements
stride: query stride
"""
def __init__(self, dim_model, num_heads, causal, max_pos_encoding, stride):
super(StridedRelPosMultiHeadSelfAttention, self).__init__(dim_model, num_heads, causal, max_pos_encoding)
# Attention Params
self.stride = stride # S
def rel_to_abs(self, att_scores):
"""Relative to absolute position indexing
Args:
att_scores: absolute-by-relative indexed attention scores of shape
(B, H, T//S, Th + 2 * T - 1) for full context and (B, H, T//S, Th + T) for causal context
Return:
att_scores: absolute-by-absolute indexed attention scores of shape (B, H, T//S,Th + T)
"""
# Causal Context
if self.causal:
# Att Scores (B, H, T // S, Th + T)
batch_size, num_heads, seq_length1, seq_length2 = att_scores.size()
# Column Padding (B, H, T // S, Th + T + S)
att_scores = F.pad(att_scores, pad=(1, self.stride - 1), value=0)
# Flatten (B, H, TTh//S + TT//S + T)
att_scores = att_scores.reshape(batch_size, num_heads, -1)
# Start Padding (B, H, TTh//S + TT//S + T + Th)
att_scores = F.pad(att_scores, pad=(seq_length2 - self.stride * seq_length1, 0), value=0)
# Reshape (B, H, 1 + T // S, Th + T)
att_scores = att_scores.reshape(batch_size, num_heads, seq_length1 + 1, seq_length2)
# Slice (B, H, T // S, Th + T)
att_scores = att_scores[:, :, 1:]
# Full Context
else:
# Att Scores (B, H, T // S, Th + 2*T-1)
batch_size, num_heads, seq_length1, seq_length2 = att_scores.size()
# Column Padding (B, H, T // S, Th + 2*T-1 + S)
att_scores = F.pad(att_scores, pad=(0, self.stride), value=0)
# Flatten (B, H, TTh//S + 2*TT//S - T//S + T)
att_scores = att_scores.reshape(batch_size, num_heads, -1)
# End Padding (B, H, TTh//S + 2*TT//S - T//S + Th + 2T-1)
att_scores = F.pad(att_scores, pad=(0, seq_length2 - seq_length1 * self.stride), value=0)
# Reshape (B, H, T//S + 1, Th + 2*T-1)
att_scores = att_scores.reshape(batch_size, num_heads, seq_length1 + 1, seq_length2)
# Slice (B, H, T // S, Th + T)
att_scores = att_scores[:, :, :seq_length1, seq_length1 * self.stride - 1:]
return att_scores
def forward(self, query, key, value, mask=None, hidden=None):
# Batch size B
batch_size = query.size(0)
# Linear Layers
query = self.query_layer(query)
key = self.key_layer(key)
value = self.value_layer(value)
# Hidden State Provided
if hidden:
key = torch.cat([hidden["key"], key], dim=1)
value = torch.cat([hidden["value"], value], dim=1)
# Update Hidden State
hidden = {"key": key.detach(), "value": value.detach()}
# Chunk Padding
query, key, value, mask, _ = self.pad(query, key, value, mask, chunk_size=self.stride)
# Query Subsampling (B, T, D) -> (B, T//S, D)
query = query[:, ::self.stride]
# Add Bias
Qu = query + self.u
Qv = query + self.v
# Relative Positional Embeddings (B, Th + 2*T-1, D) / (B, Th + T, D)
E = self.pos_layer(self.rel_pos_enc(batch_size, self.stride * query.size(1), key.size(1) - self.stride * query.size(1)))
# Reshape and Transpose (B, T//S, D) -> (B, H, T//S, d)
Qu = Qu.reshape(batch_size, -1, self.num_heads, self.dim_head).transpose(1, 2)
Qv = Qv.reshape(batch_size, -1, self.num_heads, self.dim_head).transpose(1, 2)
# Reshape and Transpose (B, Th + T, D) -> (B, H, Th + T, d)
key = key.reshape(batch_size, -1, self.num_heads, self.dim_head).transpose(1, 2)
value = value.reshape(batch_size, -1, self.num_heads, self.dim_head).transpose(1, 2)
# Reshape and Transpose (B, Th + 2*T-1, D) -> (B, H, Th + 2*T-1, d) / (B, Th + T, D) -> (B, H, Th + T, d)
E = E.reshape(batch_size, -1, self.num_heads, self.dim_head).transpose(1, 2)
# att_scores (B, H, T//S, Th + T)
att_scores_K = Qu.matmul(key.transpose(2, 3))
att_scores_E = self.rel_to_abs(Qv.matmul(E.transpose(2, 3)))
att_scores = (att_scores_K + att_scores_E) / key.shape[-1] ** 0.5
# Apply mask
if mask is not None:
# Mask Subsampling (B, 1, T, T) -> (B, 1, T//S, T)
if mask is not None:
mask = mask[:, :, ::self.stride]
# Apply mask
att_scores += (mask * -1e9)
# Att weights (B, H, T//S, Th + T)
att_w = att_scores.softmax(dim=-1)
# Att output (B, H, T//S, d)
out = att_w.matmul(value)
# Transpose and Reshape (B, H, T//S, d) -> (B, T//S, D)
out = out.transpose(1, 2).reshape(batch_size, -1, self.dim_model)
# Output linear layer
out = self.output_layer(out)
return out, att_w.detach(), hidden
class StridedLocalRelPosMultiHeadSelfAttention(RelPosMultiHeadSelfAttention):
"""Strided Local Multi-Head Self-Attention with Relative Sinusoidal Positional Encodings
Args:
dim_model: model feature dimension
num_heads: number of attention heads
causal: whether the attention is causal or unmasked
kernel_size: attention kernel size / window
stride: query stride
"""
def __init__(self, dim_model, num_heads, causal, kernel_size, stride):
super(StridedLocalRelPosMultiHeadSelfAttention, self).__init__(dim_model, num_heads, causal, kernel_size)
# Assert
assert kernel_size % stride == 0, "Attention kernel size has to be a multiple of attention stride"
# Attention Params
self.kernel_size = kernel_size # key
self.stride = stride # S
def rel_to_abs(self, att_scores):
"""Relative to absolute position indexing
Args:
att_scores: absolute-by-relative indexed attention scores of shape
(B, H, T//S, 2 * key - 1) for full context and (B, H, T//S, key) for causal context
Return:
att_scores: absolute-by-absolute indexed attention scores of shape (B, T//key, H, key//S, key)
"""
# Causal Context
if self.causal:
# Att Scores (B, H, T//S, key)
batch_size, num_heads, seq_length1, seq_length2 = att_scores.size()
# Reshape (B, T//key, H, key//S, key)
att_scores = att_scores.reshape(batch_size, -1, self.num_heads, self.kernel_size // self.stride,
self.kernel_size)
# Column Padding (B, T//key, H, key//S, key + S)
att_scores = F.pad(att_scores, pad=(1, self.stride - 1), value=0)
# Reshape (B, T//key, H, 1 + key//S, key)
att_scores = att_scores.reshape(batch_size, -1, self.num_heads, self.kernel_size // self.stride + 1,
self.kernel_size)
# Slice (B, T//key, H, key//S, key)
att_scores = att_scores[:, :, :, 1:]
# Full Context
else:
# Att Scores (B, H, T//S, 2*key-1)
batch_size, num_heads, seq_length1, seq_length2 = att_scores.size()
# Reshape (B, T//key, H, key//S, 2*key-1)
att_scores = att_scores.reshape(batch_size, -1, self.num_heads, self.kernel_size // self.stride,
seq_length2)
# Column Padding (B, T//key, H, key//S, 2*key-1 + S)
att_scores = F.pad(att_scores, pad=(0, self.stride), value=0)
# Flatten (B, T//key, H, 2KK//S - key//S + key)
att_scores = att_scores.reshape(batch_size, -1, self.num_heads,
self.kernel_size // self.stride * (2 * self.kernel_size - 1 + self.stride))
# End Padding (B, T//key, H, 2KK//S - key//S + 2K-1)
att_scores = F.pad(att_scores, pad=(0, self.kernel_size - 1), value=0)
# Reshape (B, T//key, H, key//S + 1, 2*key-1)
att_scores = att_scores.reshape(batch_size, -1, self.num_heads, self.kernel_size // self.stride + 1,
seq_length2)
# Slice (B, T//key, H, key//S, key)
att_scores = att_scores[:, :, :, :self.kernel_size // self.stride, self.kernel_size - 1:]
return att_scores
def forward(self, query, key, value, mask=None, hidden=None):
# Batch size B
batch_size = query.size(0)
# Chunk Padding
query, key, value, mask, padding = self.pad(query, key, value, mask, chunk_size=self.kernel_size)
# Query Subsampling (B, T, D) -> (B, T//S, D)
query = query[:, ::self.stride]
# Linear Layers
query = self.query_layer(query)
key = self.key_layer(key)
value = self.value_layer(value)
# Add Bias
Qu = query + self.u
Qv = query + self.v
# Relative Positional Embeddings (B, 2*key-1, D) / (B, key, D)
E = self.pos_layer(self.rel_pos_enc(batch_size))
# Reshape and Transpose (B, T//S, D) -> (B, H, T//S, d)
Qv = Qu.reshape(batch_size, -1, self.num_heads, self.dim_head).transpose(1, 2)
# Reshape and Transpose (B, T//S, D) -> (B, T//key, H, key//S, d)
Qu = Qv.reshape(batch_size, -1, self.kernel_size // self.stride, self.num_heads, self.dim_head).transpose(2, 3)
# Reshape and Transpose (B, T, D) -> (B, T//key, H, key, d)
key = key.reshape(batch_size, -1, self.kernel_size, self.num_heads, self.dim_head).transpose(2, 3)
value = value.reshape(batch_size, -1, self.kernel_size, self.num_heads, self.dim_head).transpose(2, 3)
# Reshape and Transpose (B, 2*key-1, D) -> (B, H, 2*key-1, d) / (B, key, D) -> (B, H, key, d)
E = E.reshape(batch_size, -1, self.num_heads, self.dim_head).transpose(1, 2)
# att_scores (B, T//key, H, key//S, key)
att_scores_K = Qu.matmul(key.transpose(3, 4))
att_scores_E = self.rel_to_abs(Qv.matmul(E.transpose(2, 3)))
att_scores = (att_scores_K + att_scores_E) / key.shape[-1] ** 0.5
# Mask scores
if mask is not None:
# Diagonal Mask (B, 1, T, T) -> (B, T//key, 1, key, key)
masks = []
for m in range(mask.size(-1) // self.kernel_size):
masks.append(mask[:, :, m * self.kernel_size: (m + 1) * self.kernel_size,
m * self.kernel_size: (m + 1) * self.kernel_size])
mask = torch.stack(masks, dim=1)
# Stride Mask (B, T//key, 1, key, key) -> (B, T//key, 1, key//S, key)
mask = mask[:, :, :, ::self.stride]
# Apply Mask
att_scores = att_scores.float() - mask.float() * 1e9
# Attention weights (B, T//key, H, key//S, key)
att_w = att_scores.softmax(dim=-1)
# Attention output (B, T//key, H, key//S, d)
out = att_w.matmul(value)
# Transpose and Reshape (B, T//key, H, key//S, d) -> (B, T//S, D)
out = out.transpose(2, 3).reshape(batch_size, -1, self.dim_model)
# Slice Padding
out = out[:, :(self.stride * out.size(1) - padding - 1) // self.stride + 1]
# Output linear layer
out = self.output_layer(out)
return out, att_w.detach(), hidden
###############################################################################
# Positional Encodings
###############################################################################
class SinusoidalPositionalEncoding(nn.Module):
"""
Sinusoidal Positional Encoding
Reference: "Attention Is All You Need" by Vaswani et al.
https://arxiv.org/abs/1706.03762
"""
def __init__(self, max_len, dim_model):
super(SinusoidalPositionalEncoding, self).__init__()
pos_encoding = torch.zeros(max_len, dim_model)
pos = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
i = torch.arange(0, dim_model // 2, dtype=torch.float).unsqueeze(0)
angles = pos / 10000 ** (2 * i / dim_model)
pos_encoding[:, 0::2] = angles.sin()
pos_encoding[:, 1::2] = angles.cos()
pos_encoding = pos_encoding.unsqueeze(0)
self.register_buffer('pos_encoding', pos_encoding, persistent=False)
def forward(self, batch_size=1, seq_len=None):
# (B, T, D)
if seq_len is not None:
P = self.pos_encoding[:, :seq_len]
# (B, Tmax, D)
else:
P = self.pos_encoding
return P.repeat(batch_size, 1, 1)
class RelativeSinusoidalPositionalEncoding(nn.Module):
"""
Relative Sinusoidal Positional Encoding
Positional encoding for left context (sin) and right context (cos)
Total context = 2 * max_len - 1
"""
def __init__(self, max_len, dim_model, causal=False):
super(RelativeSinusoidalPositionalEncoding, self).__init__()
# PE
pos_encoding = torch.zeros(2 * max_len - 1, dim_model)
# Positions (max_len - 1, ..., max_len - 1)
pos_left = torch.arange(start=max_len - 1, end=0, step=-1, dtype=torch.float)
pos_right = torch.arange(start=0, end=-max_len, step=-1, dtype=torch.float)
pos = torch.cat([pos_left, pos_right], dim=0).unsqueeze(1)
# Angles
angles = pos / 10000 ** (2 * torch.arange(0, dim_model // 2, dtype=torch.float).unsqueeze(0) / dim_model)
# Rel Sinusoidal PE
pos_encoding[:, 0::2] = angles.sin()
pos_encoding[:, 1::2] = angles.cos()
pos_encoding = pos_encoding.unsqueeze(0)
self.register_buffer('pos_encoding', pos_encoding, persistent=False)
self.max_len = max_len
self.causal = causal
def forward(self, batch_size=1, seq_len=None, hidden_len=0):
# Causal Context
if self.causal:
# (B, Th + T, D)
if seq_len is not None:
R = self.pos_encoding[:, self.max_len - seq_len - hidden_len: self.max_len]
# (B, Tmax, D)
else:
R = self.pos_encoding[:, :self.max_len]
# Full Context
else:
# (B, Th + 2*T-1, D)
if seq_len is not None:
R = self.pos_encoding[:, self.max_len - seq_len - hidden_len: self.max_len - 1 + seq_len]
# (B, 2*Tmax-1, D)
else:
R = self.pos_encoding
return R.repeat(batch_size, 1, 1)
class GroupedRelativeSinusoidalPositionalEncoding(nn.Module):
"""
Relative Sinusoidal Positional Encoding for grouped multi-head attention
Positional encoding for left context (sin) and right context (cos)
Total context = 2 * max_len - group_size
"""
def __init__(self, max_len, dim_model, group_size=1, causal=False):
super(GroupedRelativeSinusoidalPositionalEncoding, self).__init__()
# PE
pos_encoding = torch.zeros(2 * max_len - group_size % 2, dim_model)
# Positions (max_len - 1, ..., max_len - 1)
pos_left = torch.arange(start=max_len - 1, end=group_size % 2 - 1, step=-1, dtype=torch.float)
pos_right = torch.arange(start=0, end=-max_len, step=-1, dtype=torch.float)
pos = torch.cat([pos_left, pos_right], dim=0).unsqueeze(1)
# Angles
angles = pos / 10000 ** (2 * torch.arange(0, dim_model // 2, dtype=torch.float).unsqueeze(0) / dim_model)
# Rel Sinusoidal PE
pos_encoding[:, 0::2] = angles.sin()
pos_encoding[:, 1::2] = angles.cos()
pos_encoding = pos_encoding.unsqueeze(0)
self.register_buffer('pos_encoding', pos_encoding, persistent=False)
self.max_len = max_len
self.causal = causal
self.group_size = group_size
def forward(self, batch_size=1, seq_len=None, hidden_len=0):
# Causal Context
if self.causal:
# (B, Th + T, D)
if seq_len is not None:
R = self.pos_encoding[:, self.max_len - seq_len - hidden_len: self.max_len]
# (B, Tmax, D)
else:
R = self.pos_encoding[:, :self.max_len]
else:
# (B, Th + 2*T-G, D)
if seq_len is not None:
R = self.pos_encoding[:,
self.max_len - seq_len + self.group_size // 2 - hidden_len: self.max_len - self.group_size % 2 + seq_len - self.group_size // 2]
# (B, 2*Tmax-G, D)
else:
R = self.pos_encoding
return R.repeat(batch_size, 1, 1)
class MultiHeadSelfAttentionModule(nn.Module):
"""Multi-Head Self-Attention Module
Args:
dim_model: model feature dimension
num_heads: number of attention heads
Pdrop: residual dropout probability
max_pos_encoding: maximum position
relative_pos_enc: whether to use relative postion embedding
causal: True for causal attention with masked future context
group_size: Attention group size
kernel_size: Attention kernel size
stride: Query stride
linear_att: whether to use multi-head linear self-attention
"""
def __init__(self,
dim_model,
num_heads,
Pdrop,
max_pos_encoding,
relative_pos_enc,
causal,
group_size,
kernel_size,
stride,
linear_att):
super(MultiHeadSelfAttentionModule, self).__init__()
# Assert
assert not (group_size > 1 and kernel_size is not None), "Local grouped attention not implemented"
assert not (group_size > 1 and stride > 1 is not None), "Strided grouped attention not implemented"
assert not (linear_att and relative_pos_enc), "Linear attention requires absolute positional encodings"
# Pre Norm
self.norm = nn.LayerNorm(dim_model, eps=1e-6)
# Multi-Head Linear Attention
if linear_att:
self.mhsa = MultiHeadLinearAttention(dim_model, num_heads)
# Grouped Multi-Head Self-Attention
elif group_size > 1:
if relative_pos_enc:
self.mhsa = GroupedRelPosMultiHeadSelfAttention(dim_model, num_heads, causal, max_pos_encoding,
group_size)
else:
self.mhsa = GroupedMultiHeadAttention(dim_model, num_heads, group_size)
# Local Multi-Head Self-Attention
elif kernel_size is not None and stride == 1:
if relative_pos_enc:
self.mhsa = LocalRelPosMultiHeadSelfAttention(dim_model, num_heads, causal, kernel_size)
else:
self.mhsa = LocalMultiHeadAttention(dim_model, num_heads, kernel_size)
# Strided Multi-Head Self-Attention
elif kernel_size is None and stride > 1:
if relative_pos_enc:
self.mhsa = StridedRelPosMultiHeadSelfAttention(dim_model, num_heads, causal, max_pos_encoding, stride)
else:
self.mhsa = StridedMultiHeadAttention(dim_model, num_heads, stride)
# Strided Local Multi-Head Self-Attention
elif stride > 1 and kernel_size is not None:
if relative_pos_enc:
self.mhsa = StridedLocalRelPosMultiHeadSelfAttention(dim_model, num_heads, causal, kernel_size, stride)
else:
self.mhsa = StridedLocalMultiHeadAttention(dim_model, num_heads, kernel_size, stride)
# Multi-Head Self-Attention
else:
if relative_pos_enc:
self.mhsa = RelPosMultiHeadSelfAttention(dim_model, num_heads, causal, max_pos_encoding)
else:
self.mhsa = MultiHeadAttention(dim_model, num_heads)
# Dropout
self.dropout = nn.Dropout(Pdrop)
# Module Params
self.rel_pos_enc = relative_pos_enc
self.linear_att = linear_att
def forward(self, x, mask=None, hidden=None):
# Pre Norm
x = self.norm(x)
# Multi-Head Self-Attention
if self.linear_att:
x, attention = self.mhsa(x, x, x)
elif self.rel_pos_enc:
x, attention, hidden = self.mhsa(x, x, x, mask, hidden)
else:
x, attention = self.mhsa(x, x, x, mask)
# Dropout
x = self.dropout(x)
return x, attention, hidden
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
# Copyright 2021 Mobvoi Inc. All Rights Reserved.
# Author: di.wu@mobvoi.com (DI WU)
"""ConvolutionModule definition."""
from typing import Optional, Tuple
import torch import torch
from torch import nn from torch import nn
from fairseq.modules.layer_norm import LayerNorm
from fairseq.modules.activations import get_activation_class from fairseq.modules.activations import get_activation_class
class ConvolutionModule(nn.Module): class ConvolutionModule(nn.Module):
"""ConvolutionModule in Conformer model.""" """Convolution block used in the conformer block"""
def __init__(self,
channels: int,
kernel_size: int = 15,
norm: str = "batch_norm",
bias: bool = True):
"""Construct an ConvolutionModule object.
Args:
channels (int): The number of channels of conv layers.
kernel_size (int): Kernel size of conv layers.
causal (int): Whether use causal convolution or not
"""
super().__init__()
self.pointwise_conv1 = nn.Conv1d( def __init__(
self,
embed_dim,
channels, channels,
depthwise_kernel_size,
dropout,
activation_fn="swish",
bias=False,
export=False,
):
"""
Args:
embed_dim: Embedding dimension
channels: Number of channels in depthwise conv layers
depthwise_kernel_size: Depthwise conv layer kernel size
dropout: dropout value
activation_fn: Activation function to use after depthwise convolution kernel
bias: If bias should be added to conv layers
export: If layernorm should be exported to jit
"""
super(ConvolutionModule, self).__init__()
assert (
depthwise_kernel_size - 1
) % 2 == 0, "kernel_size should be a odd number for 'SAME' padding"
self.pointwise_conv1 = torch.nn.Conv1d(
embed_dim,
2 * channels, 2 * channels,
kernel_size=1, kernel_size=1,
stride=1, stride=1,
padding=0, padding=0,
bias=bias, bias=bias,
) )
self.glu = torch.nn.GLU(dim=1)
# kernel_size should be an odd number for none causal convolution self.depthwise_conv = torch.nn.Conv1d(
assert (kernel_size - 1) % 2 == 0
padding = (kernel_size - 1) // 2
self.depthwise_conv = nn.Conv1d(
channels, channels,
channels, channels,
kernel_size, depthwise_kernel_size,
stride=1, stride=1,
padding=padding, padding=(depthwise_kernel_size - 1) // 2,
groups=channels, groups=channels,
bias=bias, bias=bias,
) )
self.batch_norm = nn.BatchNorm1d(channels)
assert norm in ['batch_norm', 'layer_norm'] self.activation = get_activation_class(activation_fn)
if norm == "batch_norm": self.pointwise_conv2 = torch.nn.Conv1d(
self.use_layer_norm = False
self.norm = nn.BatchNorm1d(channels)
else:
self.use_layer_norm = True
self.norm = LayerNorm(channels)
self.pointwise_conv2 = nn.Conv1d(
channels,
channels, channels,
embed_dim,
kernel_size=1, kernel_size=1,
stride=1, stride=1,
padding=0, padding=0,
bias=bias, bias=bias,
) )
self.activation = get_activation_class("swish") self.dropout = torch.nn.Dropout(dropout)
def forward( def forward(self, x):
self, """
x: torch.Tensor,
mask_pad: Optional[torch.Tensor] = None,
) -> Tuple[torch.Tensor, torch.Tensor]:
"""Compute convolution module.
Args: Args:
x (torch.Tensor): Input tensor (#batch, time, channels). x: Input of shape B X T X C
mask_pad (torch.Tensor): used for batch padding
Returns: Returns:
torch.Tensor: Output tensor (#batch, time, channels). Tensor of shape B X T X C
""" """
# exchange the temporal dimension and the feature dimension # exchange the temporal dimension and the feature dimension
x = x.transpose(1, 2) x = x.transpose(1, 2)
zero_mask_pad = mask_pad.unsqueeze(1).repeat(1, x.size(1), 1)
# mask batch padding
if mask_pad is not None:
x.masked_fill_(zero_mask_pad, 0.0)
# GLU mechanism # GLU mechanism
x = self.pointwise_conv1(x) # (batch, 2*channel, time) x = self.pointwise_conv1(x) # (batch, 2*channel, dim)
x = nn.functional.glu(x, dim=1) # (batch, channel, time) x = self.glu(x) # (batch, channel, dim)
# 1D Depthwise Conv # 1D Depthwise Conv
x = self.depthwise_conv(x) x = self.depthwise_conv(x)
if self.use_layer_norm: x = self.batch_norm(x)
x = x.transpose(1, 2) x = self.activation(x)
x = self.activation(self.norm(x))
if self.use_layer_norm:
x = x.transpose(1, 2)
x = self.pointwise_conv2(x)
# mask batch padding
if zero_mask_pad is not None:
x.masked_fill_(zero_mask_pad, 0.0)
x = self.pointwise_conv2(x)
x = self.dropout(x)
return x.transpose(1, 2) return x.transpose(1, 2)
#
# class ConvolutionModule(nn.Module):
# """ConvolutionModule in Conformer model."""
# def __init__(self,
# channels: int,
# kernel_size: int = 15,
# norm: str = "batch_norm",
# bias: bool = True):
# """Construct an ConvolutionModule object.
# Args:
# channels (int): The number of channels of conv layers.
# kernel_size (int): Kernel size of conv layers.
# causal (int): Whether use causal convolution or not
# """
# super().__init__()
#
# self.pointwise_conv1 = nn.Conv1d(
# channels,
# 2 * channels,
# kernel_size=1,
# stride=1,
# padding=0,
# bias=bias,
# )
#
# # kernel_size should be an odd number for none causal convolution
# assert (kernel_size - 1) % 2 == 0
# padding = (kernel_size - 1) // 2
#
# self.depthwise_conv = nn.Conv1d(
# channels,
# channels,
# kernel_size,
# stride=1,
# padding=padding,
# groups=channels,
# bias=bias,
# )
#
# assert norm in ['batch_norm', 'layer_norm']
# if norm == "batch_norm":
# self.use_layer_norm = False
# self.norm = nn.BatchNorm1d(channels)
# else:
# self.use_layer_norm = True
# self.norm = LayerNorm(channels)
#
# self.pointwise_conv2 = nn.Conv1d(
# channels,
# channels,
# kernel_size=1,
# stride=1,
# padding=0,
# bias=bias,
# )
# self.activation = get_activation_class("swish")
#
# def forward(
# self,
# x: torch.Tensor,
# mask_pad: Optional[torch.Tensor] = None,
# ) -> Tuple[torch.Tensor, torch.Tensor]:
# """Compute convolution module.
# Args:
# x (torch.Tensor): Input tensor (#batch, time, channels).
# mask_pad (torch.Tensor): used for batch padding
# Returns:
# torch.Tensor: Output tensor (#batch, time, channels).
# """
# # exchange the temporal dimension and the feature dimension
# x = x.transpose(1, 2)
#
# # zero_mask_pad = mask_pad.unsqueeze(1).repeat(1, x.size(1), 1)
# # # mask batch padding
# # if mask_pad is not None:
# # x.masked_fill_(zero_mask_pad, 0.0)
#
# # GLU mechanism
# x = self.pointwise_conv1(x) # (batch, 2*channel, time)
# x = nn.functional.glu(x, dim=1) # (batch, channel, time)
#
# # 1D Depthwise Conv
# x = self.depthwise_conv(x)
# if self.use_layer_norm:
# x = x.transpose(1, 2)
# x = self.activation(self.norm(x))
# if self.use_layer_norm:
# x = x.transpose(1, 2)
# x = self.pointwise_conv2(x)
#
# # # mask batch padding
# # if zero_mask_pad is not None:
# # x.masked_fill_(zero_mask_pad, 0.0)
#
# return x.transpose(1, 2)
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
# Copyright 2019 Shigeki Karita
# Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0)
"""Multi-Head Attention layer definition."""
import math
import torch
from torch import nn
from fairseq.modules.rotary_positional_embedding import (
RotaryPositionalEmbedding,
apply_rotary_pos_emb,
)
class ESPNETMultiHeadedAttention(nn.Module):
"""Multi-Head Attention layer.
Args:
n_head: The number of heads.
n_feat: The number of features.
dropout: Dropout rate.
"""
def __init__(self, n_feat, n_head, dropout):
"""Construct an MultiHeadedAttention object."""
super(ESPNETMultiHeadedAttention, self).__init__()
assert n_feat % n_head == 0
# We assume d_v always equals d_k
self.d_k = n_feat // n_head
self.h = n_head
self.linear_q = nn.Linear(n_feat, n_feat)
self.linear_k = nn.Linear(n_feat, n_feat)
self.linear_v = nn.Linear(n_feat, n_feat)
self.linear_out = nn.Linear(n_feat, n_feat)
self.attn = None
self.dropout = nn.Dropout(p=dropout)
def forward_qkv(self, query, key, value, **kwargs):
"""Transform query, key and value.
Args:
query: Query tensor B X T1 X C
key: Key tensor B X T2 X C
value: Value tensor B X T2 X C
Returns:
torch.Tensor: Transformed query tensor B X n_head X T1 X d_k
torch.Tensor: Transformed key tensor B X n_head X T2 X d_k
torch.Tensor: Transformed value tensor B X n_head X T2 X d_k
"""
n_batch = query.size(0)
q = self.linear_q(query).view(n_batch, -1, self.h, self.d_k)
k = self.linear_k(key).view(n_batch, -1, self.h, self.d_k)
v = self.linear_v(value).view(n_batch, -1, self.h, self.d_k)
q = q.transpose(1, 2) # (batch, head, time1, d_k)
k = k.transpose(1, 2) # (batch, head, time2, d_k)
v = v.transpose(1, 2) # (batch, head, time2, d_k)
return q, k, v
def forward_attention(self, value, scores, mask):
"""Compute attention context vector.
Args:
value: Transformed value B X n_head X T2 X d_k.
scores: Attention score B X n_head X T1 X T2
mask: Mask T2 X B
Returns:
torch.Tensor: Transformed value B X T1 X d_model
weighted by the attention score B X T1 X T2
"""
n_batch = value.size(0)
if mask is not None:
scores = scores.masked_fill(
mask.unsqueeze(1).unsqueeze(2).to(bool),
float("-inf"), # (batch, head, time1, time2)
)
self.attn = torch.softmax(scores, dim=-1) # (batch, head, time1, time2)
else:
self.attn = torch.softmax(scores, dim=-1) # (batch, head, time1, time2)
p_attn = self.dropout(self.attn)
x = torch.matmul(p_attn, value) # (batch, head, time1, d_k)
x = (
x.transpose(1, 2).contiguous().view(n_batch, -1, self.h * self.d_k)
) # (batch, time1, d_model)
return self.linear_out(x) # (batch, time1, d_model)
def forward(self, query, key, value, key_padding_mask=None, **kwargs):
"""Compute scaled dot product attention.
Args:
query (torch.Tensor): Query tensor T X B X C
key (torch.Tensor): Key tensor T X B X C
value (torch.Tensor): Value tensor T X B X C
mask (torch.Tensor): Mask tensor T X B
Returns:
torch.Tensor: Output tensor T X B X D.
"""
query = query.transpose(0, 1)
key = key.transpose(0, 1)
value = value.transpose(0, 1)
q, k, v = self.forward_qkv(query, key, value)
scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(self.d_k)
scores = self.forward_attention(v, scores, key_padding_mask)
scores = scores.transpose(0, 1)
return scores, None
class RelPositionMultiHeadedAttention(ESPNETMultiHeadedAttention):
"""Multi-Head Attention layer with relative position encoding.
Paper: https://arxiv.org/abs/1901.02860
Args:
n_head: The number of heads.
n_feat: The number of features.
dropout: Dropout rate.
zero_triu: Whether to zero the upper triangular part of attention matrix.
"""
def __init__(self, n_feat, n_head, dropout, zero_triu=False):
"""Construct an RelPositionMultiHeadedAttention object."""
super().__init__(n_feat, n_head, dropout)
self.zero_triu = zero_triu
# linear transformation for positional encoding
self.linear_pos = nn.Linear(n_feat, n_feat, bias=False)
# these two learnable bias are used in matrix c and matrix d
# as described in https://arxiv.org/abs/1901.02860 Section 3.3
self.pos_bias_u = nn.Parameter(torch.Tensor(self.h, self.d_k))
self.pos_bias_v = nn.Parameter(torch.Tensor(self.h, self.d_k))
torch.nn.init.xavier_uniform_(self.pos_bias_u)
torch.nn.init.xavier_uniform_(self.pos_bias_v)
def rel_shift(self, x):
"""Compute relative positional encoding.
Args:
x: Input tensor B X n_head X T X 2T-1
Returns:
torch.Tensor: Output tensor.
"""
zero_pad = torch.zeros((*x.size()[:3], 1), device=x.device, dtype=x.dtype)
x_padded = torch.cat([zero_pad, x], dim=-1)
x_padded = x_padded.view(*x.size()[:2], x.size(3) + 1, x.size(2))
x = x_padded[:, :, 1:].view_as(x)[
:, :, :, : x.size(-1) // 2 + 1
] # only keep the positions from 0 to time2
if self.zero_triu:
ones = torch.ones((x.size(2), x.size(3)), device=x.device)
x = x * torch.tril(ones, x.size(3) - x.size(2))[None, None, :, :]
return x
def forward(self, query, key, value, pos_emb, key_padding_mask=None, **kwargs):
"""Compute scaled dot product attention.
Args:
query: Query tensor T X B X C
key: Key tensor T X B X C
value: Value tensor T X B X C
pos_emb: Positional embedding tensor B X 2T-1 X C
key_padding_mask: Mask tensor T X B
Returns:
torch.Tensor: Output tensor T X B X C.
"""
query = query.transpose(0, 1)
key = key.transpose(0, 1)
value = value.transpose(0, 1)
pos_emb = pos_emb.transpose(0, 1)
q, k, v = self.forward_qkv(query, key, value)
q = q.transpose(1, 2) # (batch, time1, head, d_k)
n_batch_pos = pos_emb.size(0)
p = self.linear_pos(pos_emb).view(n_batch_pos, -1, self.h, self.d_k)
p = p.transpose(1, 2) # (batch, head, 2*time1-1, d_k)
# (batch, head, time1, d_k)
q_with_bias_u = (q + self.pos_bias_u).transpose(1, 2)
# (batch, head, time1, d_k)
q_with_bias_v = (q + self.pos_bias_v).transpose(1, 2)
# compute attention score
# first compute matrix a and matrix c
# as described in https://arxiv.org/abs/1901.02860 Section 3.3
# (batch, head, time1, time2)
matrix_ac = torch.matmul(q_with_bias_u, k.transpose(-2, -1))
# compute matrix b and matrix d
# (batch, head, time1, 2*time1-1)
matrix_bd = torch.matmul(q_with_bias_v, p.transpose(-2, -1))
matrix_bd = self.rel_shift(matrix_bd)
scores = (matrix_ac + matrix_bd) / math.sqrt(
self.d_k
) # (batch, head, time1, time2)
scores = self.forward_attention(v, scores, key_padding_mask)
scores = scores.transpose(0, 1)
return scores, None
class RotaryPositionMultiHeadedAttention(ESPNETMultiHeadedAttention):
def __init__(
self,
n_feat,
n_head,
dropout,
precision,
rotary_emd_base=10000,
):
"""Construct an RotaryPositionMultiHeadedAttention object."""
super().__init__(n_feat, n_head, dropout)
precision = torch.float
self.rotary_ndims = self.d_k # also try self.d_k//2
if precision == "fp16":
precision = torch.half
self.rotary_emb = RotaryPositionalEmbedding(
self.rotary_ndims, base=rotary_emd_base, precision=precision
)
def forward(self, query, key, value, key_padding_mask=None, **kwargs):
"""Compute rotary position attention.
Args:
query: Query tensor T X B X C
key: Key tensor T X B X C
value: Value tensor T X B X C
key_padding_mask: Mask tensor T X B
Returns:
torch.Tensor: Output tensor T X B X D.
Notes:
Assumes self attn
"""
T, B, C = value.size()
query = query.view(T, B, self.h, self.d_k)
key = key.view(T, B, self.h, self.d_k)
value = value.view(T, B, self.h, self.d_k)
cos, sin = self.rotary_emb(value, seq_len=T)
query, key = apply_rotary_pos_emb(
query, key, cos, sin, offset=0
) # offset is based on layer_past
query = query.view(T, B, self.h * self.d_k)
key = key.view(T, B, self.h * self.d_k)
value = value.view(T, B, self.h * self.d_k)
# TBD to BTD
query = query.transpose(0, 1)
key = key.transpose(0, 1)
value = value.transpose(0, 1)
q, k, v = self.forward_qkv(query, key, value)
scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(self.d_k)
scores = self.forward_attention(v, scores, key_padding_mask)
scores = scores.transpose(0, 1)
return scores, None
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import torch.nn as nn
import torch
import torch.nn.functional as F
class LocationAttention(nn.Module):
"""
Attention-Based Models for Speech Recognition
https://arxiv.org/pdf/1506.07503.pdf
:param int encoder_dim: # projection-units of encoder
:param int decoder_dim: # units of decoder
:param int attn_dim: attention dimension
:param int conv_dim: # channels of attention convolution
:param int conv_kernel_size: filter size of attention convolution
"""
def __init__(
self,
attn_dim,
encoder_dim,
decoder_dim,
attn_state_kernel_size,
conv_dim,
conv_kernel_size,
scaling=2.0,
):
super(LocationAttention, self).__init__()
self.attn_dim = attn_dim
self.decoder_dim = decoder_dim
self.scaling = scaling
self.proj_enc = nn.Linear(encoder_dim, attn_dim)
self.proj_dec = nn.Linear(decoder_dim, attn_dim, bias=False)
self.proj_attn = nn.Linear(conv_dim, attn_dim, bias=False)
self.conv = nn.Conv1d(
attn_state_kernel_size,
conv_dim,
2 * conv_kernel_size + 1,
padding=conv_kernel_size,
bias=False,
)
self.proj_out = nn.Sequential(nn.Tanh(), nn.Linear(attn_dim, 1))
self.proj_enc_out = None # cache
def clear_cache(self):
self.proj_enc_out = None
def forward(self, encoder_out, encoder_padding_mask, decoder_h, attn_state):
"""
:param torch.Tensor encoder_out: padded encoder hidden state B x T x D
:param torch.Tensor encoder_padding_mask: encoder padding mask
:param torch.Tensor decoder_h: decoder hidden state B x D
:param torch.Tensor attn_prev: previous attention weight B x K x T
:return: attention weighted encoder state (B, D)
:rtype: torch.Tensor
:return: previous attention weights (B x T)
:rtype: torch.Tensor
"""
bsz, seq_len, _ = encoder_out.size()
if self.proj_enc_out is None:
self.proj_enc_out = self.proj_enc(encoder_out)
# B x K x T -> B x C x T
attn = self.conv(attn_state)
# B x C x T -> B x T x C -> B x T x D
attn = self.proj_attn(attn.transpose(1, 2))
if decoder_h is None:
decoder_h = encoder_out.new_zeros(bsz, self.decoder_dim)
dec_h = self.proj_dec(decoder_h).view(bsz, 1, self.attn_dim)
out = self.proj_out(attn + self.proj_enc_out + dec_h).squeeze(2)
out.masked_fill_(encoder_padding_mask, -float("inf"))
w = F.softmax(self.scaling * out, dim=1)
c = torch.sum(encoder_out * w.view(bsz, seq_len, 1), dim=1)
return c, w
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import torch.nn as nn
import math
import torch
class PositionalEncoding(nn.Module):
"""Positional encoding.
Args:
d_model: Embedding dimension.
dropout_rate: Dropout rate.
max_len: Maximum input length.
reverse: Whether to reverse the input position.
"""
def __init__(self, d_model, dropout_rate, max_len=5000, reverse=False):
"""Construct an PositionalEncoding object."""
super(PositionalEncoding, self).__init__()
self.d_model = d_model
self.reverse = reverse
self.xscale = math.sqrt(self.d_model)
self.dropout = nn.Dropout(p=dropout_rate)
self.pe = None
self.extend_pe(torch.tensor(0.0).expand(1, max_len))
def extend_pe(self, x):
"""Reset the positional encodings."""
if self.pe is not None:
if self.pe.size(1) >= x.size(1):
if self.pe.dtype != x.dtype or self.pe.device != x.device:
self.pe = self.pe.to(dtype=x.dtype, device=x.device)
return
pe = torch.zeros(x.size(1), self.d_model)
if self.reverse:
position = torch.arange(
x.size(1) - 1, -1, -1.0, dtype=torch.float32
).unsqueeze(1)
else:
position = torch.arange(0, x.size(1), dtype=torch.float32).unsqueeze(1)
div_term = torch.exp(
torch.arange(0, self.d_model, 2, dtype=torch.float32)
* -(math.log(10000.0) / self.d_model)
)
pe[:, 0::2] = torch.sin(position * div_term)
pe[:, 1::2] = torch.cos(position * div_term)
pe = pe.unsqueeze(0)
self.pe = pe.to(device=x.device, dtype=x.dtype)
def forward(self, x: torch.Tensor):
"""Add positional encoding.
Args:
x (torch.Tensor): Input tensor B X T X C
Returns:
torch.Tensor: Encoded tensor B X T X C
"""
self.extend_pe(x)
x = x * self.xscale + self.pe[:, : x.size(1)]
return self.dropout(x)
class RelPositionalEncoding(nn.Module):
"""Relative positional encoding module (new implementation).
Args:
d_model: Embedding dimension.
dropout_rate: Dropout rate.
max_len: Maximum input length.
"""
def __init__(self, max_len, d_model):
"""Construct an PositionalEncoding object."""
super(RelPositionalEncoding, self).__init__()
self.d_model = d_model
self.pe = None
self.extend_pe(torch.tensor(0.0).expand(1, max_len))
def extend_pe(self, x):
"""Reset the positional encodings."""
if self.pe is not None:
# self.pe contains both positive and negative parts
# the length of self.pe is 2 * input_len - 1
if self.pe.size(1) >= x.size(1) * 2 - 1:
if self.pe.dtype != x.dtype or self.pe.device != x.device:
self.pe = self.pe.to(dtype=x.dtype, device=x.device)
return
# Suppose `i` means to the position of query vecotr and `j` means the
# position of key vector. We use position relative positions when keys
# are to the left (i>j) and negative relative positions otherwise (i<j).
pe_positive = torch.zeros(x.size(1), self.d_model)
pe_negative = torch.zeros(x.size(1), self.d_model)
position = torch.arange(0, x.size(1), dtype=torch.float32).unsqueeze(1)
div_term = torch.exp(
torch.arange(0, self.d_model, 2, dtype=torch.float32)
* -(math.log(10000.0) / self.d_model)
)
pe_positive[:, 0::2] = torch.sin(position * div_term)
pe_positive[:, 1::2] = torch.cos(position * div_term)
pe_negative[:, 0::2] = torch.sin(-1 * position * div_term)
pe_negative[:, 1::2] = torch.cos(-1 * position * div_term)
# Reserve the order of positive indices and concat both positive and
# negative indices. This is used to support the shifting trick
# as in https://arxiv.org/abs/1901.02860
pe_positive = torch.flip(pe_positive, [0]).unsqueeze(0)
pe_negative = pe_negative[1:].unsqueeze(0)
pe = torch.cat([pe_positive, pe_negative], dim=1)
self.pe = pe.to(device=x.device, dtype=x.dtype)
def forward(self, x: torch.Tensor):
"""Add positional encoding.
Args:
x : Input tensor T X B X C.
Returns:
torch.Tensor: Encoded tensor T X B X C.
"""
x = x.transpose(0, 1) # Change TBC to BTC
self.extend_pe(x)
pos_emb = self.pe[
:,
self.pe.size(1) // 2 - x.size(1) + 1 : self.pe.size(1) // 2 + x.size(1),
]
pos_emb = pos_emb.transpose(0, 1) # change to TBC
return pos_emb
import torch
class RotaryPositionalEmbedding(torch.nn.Module):
def __init__(self, dim, base=10000, precision=torch.half):
"""Rotary positional embedding
Reference : https://blog.eleuther.ai/rotary-embeddings/
Paper: https://arxiv.org/pdf/2104.09864.pdf
Args:
dim: Dimension of embedding
base: Base value for exponential
precision: precision to use for numerical values
"""
super().__init__()
inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float() / dim))
self.register_buffer("inv_freq", inv_freq)
self.seq_len_cached = None
self.cos_cached = None
self.sin_cached = None
self.precision = precision
def forward(self, x, seq_len=None):
"""
Args:
x: Input x with T X B X C
seq_len: Sequence length of input x
"""
if seq_len != self.seq_len_cached:
self.seq_len_cached = seq_len
t = torch.arange(seq_len, device=x.device).type_as(self.inv_freq)
freqs = torch.einsum("i,j->ij", t, self.inv_freq)
emb = torch.cat((freqs, freqs), dim=-1).to(x.device)
self.cos_cached = emb.cos()[:, None, None, :]
self.sin_cached = emb.sin()[:, None, None, :]
return self.cos_cached, self.sin_cached
# rotary pos emb helpers:
def rotate_half(x):
x1, x2 = x[..., : x.shape[-1] // 2], x[..., x.shape[-1] // 2 :]
return torch.cat(
(-x2, x1), dim=x1.ndim - 1
) # dim=-1 triggers a bug in earlier torch versions
def apply_rotary_pos_emb(q, k, cos, sin, offset: int = 0):
cos, sin = (
cos[offset : q.shape[0] + offset, ...],
sin[offset : q.shape[0] + offset, ...],
)
return (q * cos) + (rotate_half(q) * sin), (k * cos) + (rotate_half(k) * sin)
...@@ -7,17 +7,62 @@ from typing import Optional ...@@ -7,17 +7,62 @@ from typing import Optional
import torch import torch
import torch.nn as nn import torch.nn as nn
from fairseq import utils
from fairseq.modules import ( from fairseq.modules import (
LayerNorm, LayerNorm,
MultiheadAttention, MultiheadAttention,
RelPositionMultiheadAttention, RelPositionMultiheadAttention,
RelativeMultiheadAttention, RelativeMultiheadAttention,
ConvolutionModule ConvolutionModule,
ESPNETMultiHeadedAttention,
RelPositionMultiHeadedAttention,
RotaryPositionMultiHeadedAttention,
) )
from fairseq.modules.fairseq_dropout import FairseqDropout from fairseq.modules.fairseq_dropout import FairseqDropout
from fairseq.modules.quant_noise import quant_noise
from torch import Tensor from torch import Tensor
from fairseq.modules.activations import get_activation_fn, get_activation_class
class FeedForwardModule(torch.nn.Module):
"""Positionwise feed forward layer used in conformer"""
def __init__(
self,
input_feat,
hidden_units,
dropout1,
dropout2,
activation_fn="relu",
bias=True,
):
"""
Args:
input_feat: Input feature dimension
hidden_units: Hidden unit dimension
dropout1: dropout value for layer1
dropout2: dropout value for layer2
activation_fn: Name of activation function
bias: If linear layers should have bias
"""
super(FeedForwardModule, self).__init__()
self.w_1 = torch.nn.Linear(input_feat, hidden_units, bias=bias)
self.w_2 = torch.nn.Linear(hidden_units, input_feat, bias=bias)
self.dropout1 = torch.nn.Dropout(dropout1)
self.dropout2 = torch.nn.Dropout(dropout2)
self.activation = get_activation_class(activation_fn)
def forward(self, x):
"""
Args:
x: Input Tensor of shape T X B X C
Returns:
Tensor of shape T X B X C
"""
x = self.w_1(x)
x = self.activation(x)
x = self.dropout1(x)
x = self.w_2(x)
return self.dropout2(x)
class S2TTransformerEncoderLayer(nn.Module): class S2TTransformerEncoderLayer(nn.Module):
...@@ -38,6 +83,10 @@ class S2TTransformerEncoderLayer(nn.Module): ...@@ -38,6 +83,10 @@ class S2TTransformerEncoderLayer(nn.Module):
def __init__(self, args): def __init__(self, args):
super().__init__() super().__init__()
self.args = args self.args = args
embed_dim = args.encoder_embed_dim
ffn_dim = args.encoder_ffn_embed_dim
dropout = args.dropout
self.embed_dim = args.encoder_embed_dim self.embed_dim = args.encoder_embed_dim
self.quant_noise = getattr(args, 'quant_noise_pq', 0) self.quant_noise = getattr(args, 'quant_noise_pq', 0)
self.quant_noise_block_size = getattr(args, 'quant_noise_pq_block_size', 8) or 8 self.quant_noise_block_size = getattr(args, 'quant_noise_pq_block_size', 8) or 8
...@@ -45,77 +94,53 @@ class S2TTransformerEncoderLayer(nn.Module): ...@@ -45,77 +94,53 @@ class S2TTransformerEncoderLayer(nn.Module):
self.self_attn = self.build_self_attention(self.embed_dim, args) self.self_attn = self.build_self_attention(self.embed_dim, args)
self.self_attn_layer_norm = LayerNorm(self.embed_dim) self.self_attn_layer_norm = LayerNorm(self.embed_dim)
self.dropout_module = FairseqDropout( self.dropout_module = FairseqDropout(
args.dropout, module_name=self.__class__.__name__ dropout, module_name=self.__class__.__name__
)
self.activation_fn = utils.get_activation_fn(
activation=getattr(args, 'activation_fn', 'relu') or "relu"
)
activation_dropout_p = getattr(args, "activation_dropout", 0) or 0
if activation_dropout_p == 0:
# for backwards compatibility with models that use args.relu_dropout
activation_dropout_p = getattr(args, "relu_dropout", 0) or 0
self.activation_dropout_module = FairseqDropout(
float(activation_dropout_p), module_name=self.__class__.__name__
) )
self.normalize_before = args.encoder_normalize_before
activation = getattr(args, 'encoder_activation_fn', 'relu')
if args.macaron_style: if args.macaron_style:
self.macaron_fc1 = self.build_fc1( self.macaron_ffn = FeedForwardModule(
self.embed_dim, embed_dim,
args.encoder_ffn_embed_dim, ffn_dim,
self.quant_noise, dropout,
self.quant_noise_block_size, dropout,
) activation
self.macaron_fc2 = self.build_fc2(
args.encoder_ffn_embed_dim,
self.embed_dim,
self.quant_noise,
self.quant_noise_block_size,
) )
self.macaron_norm = LayerNorm(self.embed_dim) self.macaron_norm = LayerNorm(embed_dim)
self.ffn_scale = 0.5 self.ffn_scale = 0.5
else: else:
self.macaron_fc1 = None self.macaron_ffn = None
self.macaron_fc2 = None
self.macaron_norm = None self.macaron_norm = None
self.ffn_scale = 1.0 self.ffn_scale = 1.0
if args.use_cnn_module: if args.use_cnn_module:
self.conv_norm = LayerNorm(self.embed_dim) self.conv_norm = LayerNorm(embed_dim)
self.conv_module = ConvolutionModule( self.conv_module = ConvolutionModule(
self.embed_dim, self.embed_dim,
args.cnn_module_kernel) self.embed_dim,
self.final_norm = LayerNorm(self.embed_dim) depthwise_kernel_size=args.cnn_module_kernel,
dropout=args.dropout,
activation_fn=getattr(args, 'activation_fn', 'swish'))
self.final_norm = LayerNorm(embed_dim)
else: else:
self.conv_norm = None self.conv_norm = None
self.conv_module = None self.conv_module = None
self.final_norm = None self.final_norm = None
self.normalize_before = args.encoder_normalize_before self.ffn = FeedForwardModule(
self.fc1 = self.build_fc1( embed_dim,
self.embed_dim, ffn_dim,
args.encoder_ffn_embed_dim, dropout,
self.quant_noise, dropout,
self.quant_noise_block_size, activation
)
self.fc2 = self.build_fc2(
args.encoder_ffn_embed_dim,
self.embed_dim,
self.quant_noise,
self.quant_noise_block_size,
) )
self.ffn_norm = LayerNorm(self.embed_dim) self.ffn_norm = LayerNorm(self.embed_dim)
def build_fc1(self, input_dim, output_dim, q_noise, qn_block_size):
return quant_noise(
nn.Linear(input_dim, output_dim), p=q_noise, block_size=qn_block_size
)
def build_fc2(self, input_dim, output_dim, q_noise, qn_block_size):
return quant_noise(
nn.Linear(input_dim, output_dim), p=q_noise, block_size=qn_block_size
)
def build_self_attention(self, embed_dim, args): def build_self_attention(self, embed_dim, args):
attention_heads = args.encoder_attention_heads
dropout = args.dropout
if self.attn_type == "selfattn": if self.attn_type == "selfattn":
attn_func = MultiheadAttention attn_func = MultiheadAttention
elif self.attn_type == "rel_selfattn": elif self.attn_type == "rel_selfattn":
...@@ -125,8 +150,8 @@ class S2TTransformerEncoderLayer(nn.Module): ...@@ -125,8 +150,8 @@ class S2TTransformerEncoderLayer(nn.Module):
if max_relative_length != -1: if max_relative_length != -1:
return RelativeMultiheadAttention( return RelativeMultiheadAttention(
embed_dim, embed_dim,
args.encoder_attention_heads, attention_heads,
dropout=args.attention_dropout, dropout=dropout,
self_attention=True, self_attention=True,
q_noise=self.quant_noise, q_noise=self.quant_noise,
qn_block_size=self.quant_noise_block_size, qn_block_size=self.quant_noise_block_size,
...@@ -135,6 +160,25 @@ class S2TTransformerEncoderLayer(nn.Module): ...@@ -135,6 +160,25 @@ class S2TTransformerEncoderLayer(nn.Module):
else: else:
print("The maximum encoder relative length %d can not be -1!" % max_relative_length) print("The maximum encoder relative length %d can not be -1!" % max_relative_length)
exit(1) exit(1)
elif self.attn_type == "rel_pos":
return RelPositionMultiHeadedAttention(
embed_dim,
attention_heads,
dropout=dropout,
)
elif self.attn_type == "rope":
return RotaryPositionMultiHeadedAttention(
embed_dim,
attention_heads,
dropout=dropout,
precision=args.fp16
)
elif self.attn_type == "abs":
return ESPNETMultiHeadedAttention(
embed_dim,
attention_heads,
dropout=dropout,
)
else: else:
attn_func = MultiheadAttention attn_func = MultiheadAttention
print("The attention type %s is not supported!" % self.attn_type) print("The attention type %s is not supported!" % self.attn_type)
...@@ -191,23 +235,23 @@ class S2TTransformerEncoderLayer(nn.Module): ...@@ -191,23 +235,23 @@ class S2TTransformerEncoderLayer(nn.Module):
# Note that we cannot use -inf here, because at some edge cases, # Note that we cannot use -inf here, because at some edge cases,
# the attention weight (before softmax) for some padded element in query # the attention weight (before softmax) for some padded element in query
# will become -inf, which results in NaN in model parameters # will become -inf, which results in NaN in model parameters
if attn_mask is not None: # if attn_mask is not None:
attn_mask = attn_mask.masked_fill(attn_mask.to(torch.bool), -1e8) # attn_mask = attn_mask.masked_fill(attn_mask.to(torch.bool), -1e8)
# whether to use macaron style # whether to use macaron style
if self.macaron_norm is not None: if self.macaron_norm is not None:
residual = x residual = x
if self.normalize_before: if self.normalize_before:
x = self.macaron_norm(x) x = self.macaron_norm(x)
x = self.macaron_fc2(self.activation_dropout_module(self.activation_fn(self.macaron_fc1(x)))) x = self.macaron_ffn(x)
x = residual + self.ffn_scale * self.dropout_module(x) x = residual + self.ffn_scale * x
if not self.normalize_before: if not self.normalize_before:
x = self.macaron_norm(x) x = self.macaron_norm(x)
residual = x residual = x
if self.normalize_before: if self.normalize_before:
x = self.self_attn_layer_norm(x) x = self.self_attn_layer_norm(x)
if self.attn_type == "rel_selfattn": if self.attn_type == "rel_selfattn" or self.attn_type == "rel_pos":
assert pos_emb is not None, "Positions is necessary for RPE!" assert pos_emb is not None, "Positions is necessary for RPE!"
x, _ = self.self_attn( x, _ = self.self_attn(
query=x, query=x,
...@@ -234,22 +278,23 @@ class S2TTransformerEncoderLayer(nn.Module): ...@@ -234,22 +278,23 @@ class S2TTransformerEncoderLayer(nn.Module):
# convolution module # convolution module
if self.conv_module is not None: if self.conv_module is not None:
x = x.transpose(0, 1)
residual = x residual = x
x = x.transpose(0, 1)
if self.normalize_before: if self.normalize_before:
x = self.conv_norm(x) x = self.conv_norm(x)
x = residual + self.dropout_module(self.conv_module(x, encoder_padding_mask))
x = self.conv_module(x)
x = x.transpose(0, 1)
x = residual + x
if not self.normalize_before: if not self.normalize_before:
x = self.conv_norm(x) x = self.conv_norm(x)
x = x.transpose(0, 1)
residual = x residual = x
if self.normalize_before: if self.normalize_before:
x = self.ffn_norm(x) x = self.ffn_norm(x)
x = self.activation_fn(self.fc1(x)) x = self.ffn(x)
x = self.activation_dropout_module(x)
x = self.fc2(x)
x = self.dropout_module(x)
x = self.residual_connection(self.ffn_scale * x, residual) x = self.residual_connection(self.ffn_scale * x, residual)
if not self.normalize_before: if not self.normalize_before:
x = self.ffn_norm(x) x = self.ffn_norm(x)
......
import torch import torch
import torch.nn as nn import torch.nn as nn
from typing import List
from fairseq.modules.activations import Swish from fairseq.modules.activations import Swish
from fairseq.modules.layer_norm import LayerNorm from fairseq.modules.layer_norm import LayerNorm
...@@ -46,6 +48,61 @@ def get_norm(norm_type, size, transpose=False): ...@@ -46,6 +48,61 @@ def get_norm(norm_type, size, transpose=False):
raise RuntimeError("normalization type {} not supported".format(norm_type)) raise RuntimeError("normalization type {} not supported".format(norm_type))
class Conv1dSubsampler(nn.Module):
"""Convolutional subsampler: a stack of 1D convolution (along temporal
dimension) followed by non-linear activation via gated linear units
(https://arxiv.org/abs/1911.08460)
Args:
in_channels (int): the number of input channels
mid_channels (int): the number of intermediate channels
out_channels (int): the number of output channels
kernel_sizes (List[int]): the kernel size for each convolutional layer
"""
def __init__(
self,
in_channels: int,
mid_channels: int,
out_channels: int,
kernel_sizes: List[int] = (3, 3),
):
super(Conv1dSubsampler, self).__init__()
self.n_layers = len(kernel_sizes)
self.conv_layers = nn.ModuleList(
nn.Conv1d(
in_channels if i == 0 else mid_channels // 2,
mid_channels if i < self.n_layers - 1 else out_channels * 2,
k,
stride=2,
padding=k // 2,
)
for i, k in enumerate(kernel_sizes)
)
def get_out_seq_lens_tensor(self, in_seq_lens_tensor):
out = in_seq_lens_tensor.clone()
for _ in range(self.n_layers):
out = ((out.float() - 1) / 2 + 1).floor().long()
return out
def forward(self, src_tokens, src_lengths):
bsz, in_seq_len, _ = src_tokens.size() # B x T x (C x D)
x = src_tokens.transpose(1, 2).contiguous() # -> B x (C x D) x T
inner_x = []
for conv in self.conv_layers:
x = conv(x)
x = nn.functional.glu(x, dim=1)
inner_x.append(x)
_, _, out_seq_len = x.size()
# x = x.transpose(1, 2).transpose(0, 1).contiguous() # -> T x B x (C x D)
out_inner_x = []
for x in inner_x:
out_inner_x.append(x.transpose(1, 2).transpose(0, 1).contiguous())
return out_inner_x, self.get_out_seq_lens_tensor(src_lengths)
# fairseq style
class Conv1dSubsampling(nn.Module): class Conv1dSubsampling(nn.Module):
"""Conv1d Subsampling Block """Conv1d Subsampling Block
...@@ -74,12 +131,14 @@ class Conv1dSubsampling(nn.Module): ...@@ -74,12 +131,14 @@ class Conv1dSubsampling(nn.Module):
# Layers # Layers
self.layers = nn.ModuleList([nn.Sequential( self.layers = nn.ModuleList([nn.Sequential(
nn.Conv1d(in_dim if layer_id == 0 else filters[layer_id - 1], nn.Conv1d(in_dim if layer_id == 0 else filters[layer_id - 1] // 2 if act == "glu" else filters[layer_id - 1],
filters[layer_id] * 2 if act == "glu" else filters[layer_id], filters[layer_id] * 2 if act == "glu" and layer_id == num_layers - 1 else filters[layer_id],
kernel_size, kernel_size,
stride=stride, stride=stride,
padding=(kernel_size - 1) // 2), padding=(kernel_size - 1) // 2),
get_norm(norm, filters[layer_id], transpose=True if norm == "layer" else False), get_norm(norm,
filters[layer_id] * 2 if act == "glu" and layer_id == num_layers - 1 else filters[layer_id],
transpose=True if norm == "layer" else False),
get_activation_class(act, dim=1) get_activation_class(act, dim=1)
) for layer_id in range(num_layers)]) ) for layer_id in range(num_layers)])
...@@ -126,12 +185,14 @@ class Conv2dSubsampling(nn.Module): ...@@ -126,12 +185,14 @@ class Conv2dSubsampling(nn.Module):
# Conv 2D Subsampling Layers # Conv 2D Subsampling Layers
self.layers = nn.ModuleList([nn.Sequential( self.layers = nn.ModuleList([nn.Sequential(
nn.Conv2d(1 if layer_id == 0 else filters[layer_id - 1], nn.Conv2d(1 if layer_id == 0 else filters[layer_id - 1] // 2 if act == "glu" else filters[layer_id - 1],
filters[layer_id] * 2 if act =="glu" else filters[layer_id], filters[layer_id] * 2 if act == "glu" and layer_id == num_layers - 1 else filters[layer_id],
kernel_size, kernel_size,
stride=stride, stride=stride,
padding=(kernel_size - 1) // 2), padding=(kernel_size - 1) // 2),
get_norm(norm, filters[layer_id], transpose=True if norm == "layer" else False), get_norm(norm,
filters[layer_id] * 2 if act == "glu" and layer_id == num_layers - 1 else filters[layer_id],
transpose=True if norm == "layer" else False),
get_activation_class(act, dim=1) get_activation_class(act, dim=1)
) for layer_id in range(num_layers)]) ) for layer_id in range(num_layers)])
self.linear = nn.Linear(filters[-1] * in_dim // 2 ** num_layers, filters[-1]) self.linear = nn.Linear(filters[-1] * in_dim // 2 ** num_layers, filters[-1])
...@@ -139,7 +200,7 @@ class Conv2dSubsampling(nn.Module): ...@@ -139,7 +200,7 @@ class Conv2dSubsampling(nn.Module):
def forward(self, x, x_len): def forward(self, x, x_len):
# (B, T, D) -> (B, D, T) -> (B, 1, D, T) # (B, T, D) -> (B, D, T) -> (B, 1, D, T)
x = x.tranpose(1, 2).unsqueeze(dim=1) x = x.transpose(1, 2).unsqueeze(dim=1)
# Layers # Layers
for layer in self.layers: for layer in self.layers:
......
...@@ -17,7 +17,9 @@ from typing import Callable, Dict, List, Optional ...@@ -17,7 +17,9 @@ from typing import Callable, Dict, List, Optional
import torch import torch
import torch.nn.functional as F import torch.nn.functional as F
from fairseq.modules.multihead_attention import MultiheadAttention from fairseq.modules.multihead_attention import MultiheadAttention
from torch import Tensor from torch import Tensor
...@@ -514,6 +516,8 @@ def get_activation_fn(activation: str) -> Callable: ...@@ -514,6 +516,8 @@ def get_activation_fn(activation: str) -> Callable:
return torch.tanh return torch.tanh
elif activation == "linear": elif activation == "linear":
return lambda x: x return lambda x: x
elif activation == "swish":
return torch.nn.SiLU
else: else:
raise RuntimeError("--activation-fn {} not supported".format(activation)) raise RuntimeError("--activation-fn {} not supported".format(activation))
...@@ -526,6 +530,7 @@ def get_available_activation_fns() -> List: ...@@ -526,6 +530,7 @@ def get_available_activation_fns() -> List:
"gelu_accurate", "gelu_accurate",
"tanh", "tanh",
"linear", "linear",
"swish",
] ]
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论