Commit 5fb50cc3 by xuchen

implement the dual speech-to-text model (needs optimization) and the CTC loss for MT

parent 6cbfe851
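For orientation, here is a minimal sketch of the wiring this commit introduces, in plain Python. The helper name dual_forward and the local variables are illustrative only; the module and field names (asr_encoder, mt_encoder, ctc_logit) come from the diff below, and the real logic lives in S2TDualEncoder.forward and JoinSpeechTextLoss.forward.

def dual_forward(model, speech, speech_lengths, transcript, transcript_lengths, prev_output_tokens):
    # 1. The acoustic (ASR) encoder consumes the speech features and also produces a CTC logit.
    asr_out = model.encoder.asr_encoder(speech, speech_lengths)
    # 2. The MT (text) encoder consumes the source transcript together with the
    #    acoustic representation and its padding mask.
    enc_out = model.encoder.mt_encoder(transcript, transcript_lengths,
                                       asr_out["encoder_out"][0],
                                       asr_out["encoder_padding_mask"][0])
    enc_out["ctc_logit"] = asr_out["ctc_logit"]
    # 3. The decoder generates the target; join_speech_and_text_loss then combines
    #    label-smoothed cross-entropy with the CTC terms.
    return model.decoder(prev_output_tokens=prev_output_tokens, encoder_out=enc_out)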
-arch: pdss2t_transformer_s_8
+arch: s2t_dual
#pds-ctc: 0_1_1_0
#intermedia-adapter: league
...@@ -9,11 +9,15 @@ arch: pdss2t_transformer_s_8
#attention-reduced-method: pool
#attention-reduced-q: True
-inter-mixup: True
-inter-mixup-layer: 0
-inter-mixup-beta: 0.5
-encoder-embed-dim: 384
+#inter-mixup: True
+#inter-mixup-layer: 0
+#inter-mixup-beta: 0.5
+asr-encoder: sate
+mt-encoder-layers: 3
+mt-encoder-dim: 256
+encoder-embed-dim: 256
pds-stages: 4
#ctc-layer: 15
encoder-layers: 6
...@@ -39,7 +43,7 @@ lr: 2e-3
adam_betas: (0.9,0.98)
ctc-weight: 0.3
-criterion: label_smoothed_cross_entropy_with_ctc
+criterion: join_speech_and_text_loss
label_smoothing: 0.1
dropout: 0.1
......
train-subset: train
-valid-subset: valid
+valid-subset: dev
max-epoch: 50
max-update: 100000
......
arch: transformer
share-all-embeddings: True
optimizer: adam
clip-norm: 10.0
lr-scheduler: inverse_sqrt
warmup-init-lr: 1e-7
warmup-updates: 8000
lr: 1e-3
adam_betas: (0.9,0.997)
criterion: label_smoothed_cross_entropy_with_ctc
ctc-weight: 0.3
intermedia-ctc-layers: 2,4
label_smoothing: 0.1
dropout: 0.1
attention-dropout: 0.1
activation-dropout: 0.1
activation-fn: relu
encoder-normalize-before: True
decoder-normalize-before: True
encoder-embed-dim: 512
encoder-ffn-embed-dim: 2048
encoder-layers: 6
decoder-layers: 6
encoder-attention-heads: 8
decoder-embed-dim: 512
decoder-ffn-embed-dim: 2048
decoder-attention-heads: 8
#load-pretrained-encoder-from:
#load-pretrained-decoder-from:
\ No newline at end of file
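As a reading aid for intermedia-ctc-layers: 2,4 above: the encoder resolves the comma-separated string into layer indices, counting non-positive entries back from the last layer. A runnable toy mirroring the parsing added to TransformerEncoder further down (encoder_layers=6 comes from this config):

encoder_layers = 6                              # encoder-layers: 6 in this config
intermedia_ctc_layers = []
for layer_idx in "2,4".split(","):              # intermedia-ctc-layers: 2,4
    layer_idx = int(layer_idx)
    if layer_idx <= 0:                          # non-positive values count back from the top
        layer_idx += encoder_layers
    intermedia_ctc_layers.append(layer_idx)
print(intermedia_ctc_layers)                    # [2, 4]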
ctc-weight: 0.2
intermedia-ctc-layers: 6,9
intermedia-adapter: league
intermedia-ctc-weight: 0.1
#intermedia-drop-prob: 0.2
+#intermedia-temperature: 5
+#target-ctc-weight: 0.5
+#target-ctc-layers: 2,4,6
ctc-self-distill-weight: 0
post-process: sentencepiece
\ No newline at end of file
...@@ -13,7 +13,7 @@ pds-stages: 4
#ctc-layer: 12
pds-layers: 3_3_3_3
pds-ratios: 2_2_1_2
-#pds-fusion: True
+pds-fusion: True
pds-fusion-method: all_conv
pds-embed-dims: 256_256_256_256
pds-ds-method: conv
......
...@@ -29,7 +29,7 @@ acoustic-encoder: pds
adapter: league
encoder-embed-dim: 256
-ctc-layer: 12
+#ctc-layer: 12
pds-stages: 4
pds-layers: 3_3_3_3
pds-ratios: 2_2_1_2
......
...@@ -23,6 +23,7 @@ from fairseq.logging.meters import safe_round
logger = logging.getLogger(__name__)
@dataclass
class CtcCriterionConfig(FairseqDataclass):
zero_infinity: bool = field(
...@@ -48,6 +49,10 @@ class CtcCriterionConfig(FairseqDataclass):
)
target_ctc_weight: float = field(
default=0.0,
+metadata={"help": "weight of CTC loss for target sentence"},
+)
+target_intermedia_ctc_weight: float = field(
+default=0.0,
metadata={"help": "weight of intermedia CTC loss for target sentence"},
)
ctc_self_distill_weight: float = field(
...@@ -124,13 +129,15 @@ class CtcCriterion(FairseqCriterion):
self.ctc_weight = ctc_weight
self.intermedia_ctc_weight = cfg.intermedia_ctc_weight
self.target_ctc_weight = cfg.target_ctc_weight
+self.target_intermedia_ctc_weight = cfg.target_intermedia_ctc_weight
self.ctc_self_distill_weight = cfg.ctc_self_distill_weight
self.ctc_entropy = cfg.ctc_entropy
-self.all_ctc_weight = self.ctc_weight + self.intermedia_ctc_weight + self.target_ctc_weight + \
+self.all_ctc_weight = self.ctc_weight + self.intermedia_ctc_weight + \
+self.target_ctc_weight + self.target_intermedia_ctc_weight + \
self.ctc_self_distill_weight + self.ctc_entropy
if self.all_ctc_weight > 0:
-assert getattr(task, "src_dict", None) is not None, "CTC need a source dictionary."
+# assert getattr(task, "src_dict", None) is not None, "CTC need a source dictionary."
self.ctc_loss = torch.nn.CTCLoss(blank=self.blank_idx, reduction="sum", zero_infinity=True)
def forward(self, model, sample, reduce=True):
...@@ -159,7 +166,10 @@ class CtcCriterion(FairseqCriterion):
return ctc_loss
def compute_ctc_loss(self, model, sample, net_output, logging_output):
-transcript = sample["transcript"]
+if "transcript" in sample:
+tokens = sample["transcript"]["tokens"]
+else:
+tokens = sample["target"]
# if "ctc_padding_mask" in net_output:
# non_padding_mask = ~net_output["ctc_padding_mask"][0]
# else:
...@@ -175,21 +185,21 @@ class CtcCriterion(FairseqCriterion):
non_padding_mask = ~net_output["encoder_padding_mask"][0]
input_lengths = non_padding_mask.long().sum(-1)
-pad_mask = (transcript["tokens"] != self.pad_idx) & (
-transcript["tokens"] != self.eos_idx
+pad_mask = (tokens != self.pad_idx) & (
+tokens != self.eos_idx
)
if mixup:
mask1 = pad_mask[mixup_idx1]
mask2 = pad_mask[mixup_idx2]
-transcript_flat1 = transcript["tokens"][[mixup_idx1]].masked_select(mask1)
-transcript_flat2 = transcript["tokens"][mixup_idx2].masked_select(mask2)
+transcript_flat1 = tokens[[mixup_idx1]].masked_select(mask1)
+transcript_flat2 = tokens[mixup_idx2].masked_select(mask2)
transcript_lengths1 = mask1.sum(-1)
transcript_lengths2 = mask2.sum(-1)
transcript_flat = [transcript_flat1, transcript_flat2]
transcript_lengths = [transcript_lengths1, transcript_lengths2]
loss_coef = [mixup_coef, 1 - mixup_coef]
else:
-transcript_flat = [transcript["tokens"].masked_select(pad_mask)]
+transcript_flat = [tokens.masked_select(pad_mask)]
transcript_lengths = [pad_mask.sum(-1)]
loss_coef = [1]
...@@ -247,13 +257,11 @@ class CtcCriterion(FairseqCriterion):
if lprobs is None:
lprobs = inter_lprobs
-target_ctc_num = 0
target_ctc_loss = 0
-if "target_ctc_logits" in net_output:
-target_ctc_num = len(net_output["target_ctc_logits"])
+target_intermedia_ctc_loss = 0
# calculate the target CTC loss
-if self.target_ctc_weight > 0 and target_ctc_num > 0:
+if self.target_ctc_weight > 0 or self.target_intermedia_ctc_weight:
target = sample["target"]
pad_mask = (target != self.pad_idx) & (target != self.eos_idx)
...@@ -272,8 +280,24 @@ class CtcCriterion(FairseqCriterion):
target_length = [pad_mask.sum(-1)]
loss_coef = [1]
-for i in range(target_ctc_num):
-out = net_output["target_ctc_logits"][i]
+if self.target_ctc_weight > 0:
+assert "target_ctc_logit" in net_output
+target_ctc_logit = net_output["target_ctc_logit"]
+tgt_lprobs = model.get_normalized_probs(
+[target_ctc_logit], log_probs=True
+).contiguous() # (T, B, C) from the encoder
+tgt_lprobs.batch_first = False
+for flat, lengths, coef in zip(target_flat, target_length, loss_coef):
+target_ctc_loss += self.get_loss(tgt_lprobs, flat, input_lengths, lengths) * coef
+target_intermedia_ctc_num = 0
+if "target_intermedia_ctc_logits" in net_output:
+target_intermedia_ctc_num = len(net_output["target_intermedia_ctc_logits"])
+for i in range(target_intermedia_ctc_num):
+out = net_output["target_intermedia_ctc_logits"][i]
if type(out) == list:
inter_ctc_logit = out[0]
padding = ~out[1]
...@@ -288,10 +312,10 @@ class CtcCriterion(FairseqCriterion):
tgt_inter_lprobs.batch_first = False
for flat, lengths, coef in zip(target_flat, target_length, loss_coef):
-target_ctc_loss += self.get_loss(tgt_inter_lprobs, flat, tgt_input_lengths, lengths) * coef
+target_intermedia_ctc_loss += self.get_loss(tgt_inter_lprobs, flat, tgt_input_lengths, lengths) * coef
-target_ctc_loss /= target_ctc_num
+target_intermedia_ctc_loss /= target_intermedia_ctc_num
-logging_output["target_ctc_loss"] = utils.item(target_ctc_loss.data)
+logging_output["target_intermedia_ctc_loss"] = utils.item(target_intermedia_ctc_loss.data)
# calculate the self distillation CTC loss
ctc_self_distill_loss = 0
...@@ -344,7 +368,7 @@ class CtcCriterion(FairseqCriterion):
with torch.no_grad():
lprobs_t = lprobs.transpose(0, 1).float().contiguous().cpu()
-target = sample["transcript"]["tokens"] if "transcript" in sample else sample["target"]
+target = tokens
if mixup:
idx = mixup_idx1
if mixup_coef < 0.5:
......
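To make the transcript handling above concrete, here is a small runnable sketch of how the padded tokens are flattened into the 1-D target tensor that torch.nn.CTCLoss expects; the token ids and the pad/eos indices are made up for the example.

import torch

pad_idx, eos_idx = 1, 2                               # hypothetical special-symbol ids
tokens = torch.tensor([[4, 5, 6, 2, 1],
                       [7, 8, 2, 1, 1]])              # B x T padded transcripts
pad_mask = (tokens != pad_idx) & (tokens != eos_idx)  # same masking as compute_ctc_loss
transcript_flat = tokens.masked_select(pad_mask)      # 1-D targets for torch.nn.CTCLoss
transcript_lengths = pad_mask.sum(-1)                 # per-sentence target lengths
print(transcript_flat.tolist(), transcript_lengths.tolist())  # [4, 5, 6, 7, 8] [3, 2]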
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import math
import torch
import torch.nn.functional as F
from fairseq import metrics, utils
from fairseq.criterions import register_criterion
from fairseq.data.data_utils import post_process
from fairseq.logging.meters import safe_round
from .label_smoothed_cross_entropy import LabelSmoothedCrossEntropyCriterion
from .ctc import CtcCriterion, CtcCriterionConfig
@register_criterion("join_speech_and_text_loss")
class JoinSpeechTextLoss(
LabelSmoothedCrossEntropyCriterion
):
def __init__(self, task, label_smoothing,
sentence_avg,
cfg: CtcCriterionConfig,
ctc_weight=0.0):
super().__init__(task, sentence_avg, label_smoothing)
self.report_accuracy = True
self.ctc_weight = ctc_weight
self.ctc_criterion = CtcCriterion(cfg, task, ctc_weight)
@staticmethod
def add_args(parser):
"""Add criterion-specific arguments to the parser."""
LabelSmoothedCrossEntropyCriterion.add_args(parser)
CtcCriterion.add_args(parser)
parser.add_argument(
"--ctc-weight",
default=0.0,
type=float,
metavar="D",
help="weight of CTC loss",
)
def forward(self, model, sample, reduce=True):
"""Compute the loss for the given sample.
Returns a tuple with three elements:
1) the loss
2) the sample size, which is used as the denominator for the gradient
3) logging outputs to display while training
"""
speech_tokens, speech_lengths, prev_output_tokens = sample["net_input"].values()
text_src_tokens = sample["transcript"]["tokens"]
text_src_lengths = sample["transcript"]["lengths"]
encoder_out = model.encoder(speech_tokens, speech_lengths,
text_src_tokens, text_src_lengths)
net_output = model.decoder(
prev_output_tokens=prev_output_tokens, encoder_out=encoder_out
)
use_mixup = False
if "mixup" in encoder_out and encoder_out["mixup"] is not None:
use_mixup = True
loss, nll_loss = self.compute_loss(model, net_output, sample, reduce=reduce)
sample_size = (
sample["target"].size(0) if self.sentence_avg else sample["ntokens"]
)
n_tokens = sample["ntokens"]
n_sentences = sample["target"].size(0)
if use_mixup:
sample_size //= 2
n_tokens //= 2
n_sentences //= 2
logging_output = {
"trans_loss": utils.item(loss.data) if reduce else loss.data,
"nll_loss": utils.item(nll_loss.data) if reduce else nll_loss.data,
"ntokens": n_tokens,
"nsentences": n_sentences,
"sample_size": sample_size,
}
if self.report_accuracy:
n_correct, total = self.compute_accuracy(model, net_output, sample)
logging_output["n_correct"] = utils.item(n_correct.data)
logging_output["total"] = utils.item(total.data)
if self.ctc_criterion.all_ctc_weight > 0:
ctc_loss, logging_output = self.ctc_criterion.compute_ctc_loss(model, sample, encoder_out, logging_output)
loss = (1 - self.ctc_weight) * loss + ctc_loss
logging_output["loss"] = utils.item(loss.data) if reduce else loss.data
return loss, sample_size, logging_output
@staticmethod
def reduce_metrics(logging_outputs) -> None:
"""Aggregate logging outputs from data parallel training."""
loss_sum = utils.item(sum(log.get("loss", 0) for log in logging_outputs))
trans_loss_sum = utils.item(
sum(log.get("trans_loss", 0) for log in logging_outputs)
)
nll_loss_sum = utils.item(
sum(log.get("nll_loss", 0) for log in logging_outputs)
)
ntokens = utils.item(sum(log.get("ntokens", 0) for log in logging_outputs))
sample_size = utils.item(
sum(log.get("sample_size", 0) for log in logging_outputs)
)
metrics.log_scalar(
"loss", loss_sum / sample_size / math.log(2), sample_size, round=3
)
if trans_loss_sum != loss_sum:
metrics.log_scalar(
"trans_loss", trans_loss_sum / ntokens / math.log(2), ntokens, round=3
)
metrics.log_scalar(
"nll_loss", nll_loss_sum / ntokens / math.log(2), ntokens, round=3
)
metrics.log_derived(
"ppl", lambda meters: utils.get_perplexity(meters["nll_loss"].avg)
)
if "ctc_loss" in logging_outputs[0] or "all_ctc_loss" in logging_outputs[0]:
CtcCriterion.reduce_metrics(logging_outputs)
total = utils.item(sum(log.get("total", 0) for log in logging_outputs))
if total > 0:
metrics.log_scalar("total", total)
n_correct = utils.item(
sum(log.get("n_correct", 0) for log in logging_outputs)
)
metrics.log_scalar("n_correct", n_correct)
metrics.log_derived(
"accuracy",
lambda meters: round(
meters["n_correct"].sum * 100.0 / meters["total"].sum, 3
)
if meters["total"].sum > 0
else float("nan"),
)
@staticmethod
def logging_outputs_can_be_summed() -> bool:
"""
Whether the logging outputs returned by `forward` can be summed
across workers prior to calling `reduce_metrics`. Setting this
to True will improve distributed training speed.
"""
return True
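The weighting in forward above reduces to a single line; a tiny numeric sketch (the loss values are invented, only the combination rule and ctc-weight: 0.3 come from the commit):

ctc_weight = 0.3        # --ctc-weight, as in the configs above
trans_loss = 10.0       # label-smoothed cross-entropy on the decoder output (illustrative)
ctc_loss = 4.0          # already-weighted sum returned by CtcCriterion.compute_ctc_loss (illustrative)

loss = (1 - ctc_weight) * trans_loss + ctc_loss   # mirrors JoinSpeechTextLoss.forward
print(loss)             # 11.0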
...@@ -9,3 +9,4 @@ from .s2t_ctc import *
from .s2t_transformer import * # noqa
from .pdss2t_transformer import * # noqa
from .s2t_sate import * # noqa
+from .s2t_dual import * # noqa
...@@ -6,7 +6,6 @@ import torch.nn as nn
import torch.nn.functional as F
from fairseq.data.data_utils import lengths_to_padding_mask
-from fairseq.models.transformer import Embedding
from fairseq.modules import LayerNorm
logger = logging.getLogger(__name__)
...@@ -75,7 +74,7 @@ class Adapter(nn.Module):
if self.adapter_type in ["context", "league", "gated_league", "gated_league2", "inter_league"]:
if embed_tokens is None:
num_embeddings = len(dictionary)
-self.embed_adapter = Embedding(num_embeddings, dim, dictionary.pad())
+self.embed_adapter = nn.Linear(num_embeddings, dim) # Embedding(num_embeddings, dim, dictionary.pad())
else:
self.embed_adapter = embed_tokens
......
from fairseq.models import (
FairseqEncoder,
FairseqEncoderModel,
register_model,
register_model_architecture,
)
import logging
import math
from typing import Dict, List, Optional, Tuple
import torch
from torch import Tensor
import torch.nn as nn
import torch.nn.functional as F
from fairseq import checkpoint_utils, utils
from fairseq.data.data_utils import lengths_to_padding_mask
from fairseq.models import (
FairseqEncoder,
FairseqEncoderDecoderModel,
register_model,
register_model_architecture,
)
from fairseq.models.speech_to_text import (
S2TTransformerModel,
S2TTransformerEncoder,
PDSS2TTransformerModel,
PDSS2TTransformerEncoder,
S2TSATEEncoder,
)
from fairseq.models.speech_to_text.modules import Adapter, CTC
from fairseq.models.transformer_s2 import (
Embedding,
TransformerS2Encoder,
TransformerS2Decoder,
)
from fairseq.modules import (
FairseqDropout,
LayerNorm,
PositionalEmbedding,
LegacyRelPositionalEncoding,
RelPositionalEncoding,
S2TTransformerEncoderLayer,
DynamicLinearCombination,
TransformerS2DecoderLayer,
TransformerS2EncoderLayer,
)
logger = logging.getLogger(__name__)
@register_model("s2t_dual")
class S2TDualModel(FairseqEncoderDecoderModel):
def __init__(self, encoder, decoder):
super().__init__(encoder, decoder)
@staticmethod
def add_args(parser):
"""Add model-specific arguments to the parser."""
PDSS2TTransformerModel.add_args(parser)
# SATE setting
parser.add_argument(
"--text-encoder-layers",
default=6,
type=int,
help="layers of the text encoder",
)
parser.add_argument(
"--text-attention-type",
default="selfattn",
type=str,
help="attention type of the textual encoder",
)
parser.add_argument(
"--adapter",
default="league",
type=str,
help="adapter type",
)
parser.add_argument(
"--ctc-compress-strategy",
default="avg",
type=str,
help="compress strategy, such as avg, weighted, and softmax",
)
parser.add_argument(
"--share-ctc-and-adapter",
default=False,
action="store_true",
help="share the projection weights of the ctc and adapter",
)
parser.add_argument(
"--temperature",
default=1.0,
type=float,
help="temperature of the CTC softmax",
)
parser.add_argument(
"--acoustic-encoder",
default="transformer",
type=str,
help="the architecture of the acoustic encoder",
)
parser.add_argument(
"--target-ctc-layers",
default=None,
type=str,
help="ctc layers for target sentence",
)
parser.add_argument(
"--load-pretrained-acoustic-encoder-from",
type=str,
metavar="STR",
help="model to take acoustic encoder weights from (for initialization)",
)
parser.add_argument(
"--load-pretrained-text-encoder-from",
type=str,
metavar="STR",
help="model to take text encoder weights from (for initialization)",
)
# multi-encoder
parser.add_argument(
"--asr-encoder",
default="transformer",
type=str,
help="the architecture of the ASR encoder",
)
parser.add_argument(
"--mt-encoder",
default="transformer",
type=str,
help="the architecture of the MT encoder",
)
# parser.add_argument(
# "--mt-encoder-dim",
# default="transformer",
# type=str,
# help="the dimension of the MT encoder",
# )
parser.add_argument(
"--mt-encoder-layers",
default=6,
type=str,
help="the layers of the MT encoder",
)
parser.add_argument(
"--encoder-asr-ratio",
default=1,
type=float,
help="the ratio of the asr representation",
)
parser.add_argument(
"--encoder-mt-ratio",
default=1,
type=float,
help="the ratio of the mt representation",
)
parser.add_argument(
"--encoder-drop-net",
action="store_true",
help="drop an input",
)
parser.add_argument(
"--encoder-drop-net-prob",
default=0.5,
type=float,
help="the probability of dropping",
)
parser.add_argument(
"--encoder-drop-net-mix",
action="store_true",
help="mix the two input with any probability",
)
pass
@classmethod
def build_encoder(cls, args, task=None, embed_tokens=None):
encoder = S2TDualEncoder(args, task, embed_tokens)
if getattr(args, "load_pretrained_encoder_from", None):
encoder = checkpoint_utils.load_pretrained_component_from_model(
component=encoder, checkpoint=args.load_pretrained_encoder_from, strict=False
)
logger.info(
f"loaded pretrained encoder from: "
f"{args.load_pretrained_encoder_from}"
)
if getattr(args, "load_pretrained_asr_encoder_from", None):
encoder = checkpoint_utils.load_pretrained_component_from_model(
component=encoder.asr_encoder, checkpoint=args.load_pretrained_asr_encoder_from, strict=False
)
logger.info(
f"loaded pretrained asr encoder from: "
f"{args.load_pretrained_asr_encoder_from}"
)
if getattr(args, "load_pretrained_mt_encoder_from", None):
encoder = checkpoint_utils.load_pretrained_component_from_model(
component=encoder.mt_encoder, checkpoint=args.load_pretrained_mt_encoder_from, strict=False
)
logger.info(
f"loaded pretrained mt encoder from: "
f"{args.load_pretrained_mt_encoder_from}"
)
return encoder
@classmethod
def build_decoder(cls, args, task, embed_tokens):
decoder = TransformerS2Decoder(args, task.target_dictionary, embed_tokens)
if getattr(args, "load_pretrained_decoder_from", None):
logger.info(
f"loaded pretrained decoder from: "
f"{args.load_pretrained_decoder_from}"
)
decoder = checkpoint_utils.load_pretrained_component_from_model(
component=decoder, checkpoint=args.load_pretrained_decoder_from, strict=False
)
return decoder
@classmethod
def build_model(cls, args, task):
"""Build a new model instance."""
# make sure all arguments are present in older models
base_architecture(args)
def build_embedding(dictionary, embed_dim):
num_embeddings = len(dictionary)
padding_idx = dictionary.pad()
return Embedding(num_embeddings, embed_dim, padding_idx)
src_dict, tgt_dict = task.source_dictionary, task.target_dictionary
if args.share_all_embeddings:
if src_dict != tgt_dict:
raise ValueError("--share-all-embeddings requires a joined dictionary")
if args.encoder_embed_dim != args.decoder_embed_dim:
raise ValueError(
"--share-all-embeddings requires --encoder-embed-dim to match --decoder-embed-dim"
)
encoder_embed_tokens = build_embedding(
src_dict, args.encoder_embed_dim
)
decoder_embed_tokens = encoder_embed_tokens
args.share_decoder_input_output_embed = True
else:
encoder_embed_tokens = build_embedding(
src_dict, args.encoder_embed_dim
)
decoder_embed_tokens = build_embedding(
tgt_dict, args.decoder_embed_dim
)
setattr(args, "encoder_s1_ratio", args.encoder_asr_ratio)
setattr(args, "encoder_s2_ratio", args.encoder_mt_ratio)
encoder = cls.build_encoder(args, task, encoder_embed_tokens)
if getattr(args, "encoder_freeze_module", None):
utils.freeze_parameters(encoder, args.encoder_freeze_module)
logging.info("freeze the encoder module: {}".format(args.encoder_freeze_module))
decoder = cls.build_decoder(args, task, decoder_embed_tokens)
if getattr(args, "decoder_freeze_module", None):
utils.freeze_parameters(decoder, args.decoder_freeze_module)
logging.info("freeze the decoder module: {}".format(args.decoder_freeze_module))
return cls(encoder, decoder)
def get_normalized_probs(
self,
net_output: Tuple[Tensor, Optional[Dict[str, List[Optional[Tensor]]]]],
log_probs: bool,
sample: Optional[Dict[str, Tensor]] = None,
):
# net_output['encoder_out'] is a (B, T, D) tensor
lprobs = self.get_normalized_probs_scriptable(net_output, log_probs, sample)
lprobs.batch_first = True
return lprobs
def forward(self, speech_src_tokens, speech_src_lengths,
text_src_tokens, text_src_lengths,
prev_output_tokens):
"""
The forward method inherited from the base class has a **kwargs
argument in its input, which is not supported in torchscript. This
method overwrites the forward method definition without **kwargs.
"""
encoder_out = self.encoder(speech_src_tokens, speech_src_lengths,
text_src_tokens, text_src_lengths)
decoder_out = self.decoder(
prev_output_tokens=prev_output_tokens, encoder_out=encoder_out
)
return decoder_out
class S2TDualEncoder(FairseqEncoder):
"""Speech-to-text Transformer encoder that consists of input subsampler and
Transformer encoder."""
def __init__(self, args, task=None, embed_tokens=None):
super().__init__(None)
asr_encoder_type = args.asr_encoder
if asr_encoder_type == "transformer":
self.asr_encoder = S2TTransformerEncoder(args, task)
elif asr_encoder_type == "pds":
self.asr_encoder = PDSS2TTransformerEncoder(args, task)
elif asr_encoder_type == "sate":
self.asr_encoder = S2TSATEEncoder(args, task)
else:
logger.error("Unsupported ASR architecture: %s." % asr_encoder_type)
self.mt_encoder = TransformerS2Encoder(args, task.source_dictionary, embed_tokens)
def forward(self, speech_src_tokens, speech_src_lengths, text_src_tokens, text_src_lengths, **kwargs):
asr_encoder_out = self.asr_encoder(speech_src_tokens, speech_src_lengths)
ctc_logit = asr_encoder_out["ctc_logit"]
encoder_representation = asr_encoder_out["encoder_out"][0]
encoder_padding_mask = asr_encoder_out["encoder_padding_mask"][0]
encoder_out = self.mt_encoder(text_src_tokens, text_src_lengths,
encoder_representation, encoder_padding_mask)
encoder_out["ctc_logit"] = ctc_logit
return encoder_out
def reorder_encoder_out(self, encoder_out, new_order):
# delegate to the MT encoder, which owns the standard encoder_out fields
return self.mt_encoder.reorder_encoder_out(encoder_out, new_order)
@register_model_architecture(model_name="s2t_dual", arch_name="s2t_dual")
def base_architecture(args):
# Convolutional subsampler
args.subsampling_type = getattr(args, "subsampling_type", "conv1d")
args.subsampling_layers = getattr(args, "subsampling_layers", 2)
args.subsampling_filter = getattr(args, "subsampling_filter", 1024)
args.subsampling_kernel = getattr(args, "subsampling_kernel", 5)
args.subsampling_stride = getattr(args, "subsampling_stride", 2)
args.subsampling_norm = getattr(args, "subsampling_norm", "none")
args.subsampling_activation = getattr(args, "subsampling_activation", "glu")
args.encoder_learned_pos = getattr(args, "encoder_learned_pos", False)
args.encoder_embed_dim = getattr(args, "encoder_embed_dim", 512)
args.encoder_ffn_embed_dim = getattr(args, "encoder_ffn_embed_dim", 2048)
args.encoder_layers = getattr(args, "encoder_layers", 12)
args.encoder_attention_type = getattr(args, "encoder_attention_type", "selfattn")
args.encoder_attention_heads = getattr(args, "encoder_attention_heads", 8)
args.encoder_normalize_before = getattr(args, "encoder_normalize_before", True)
args.decoder_embed_dim = getattr(args, "decoder_embed_dim", args.encoder_embed_dim)
args.decoder_ffn_embed_dim = getattr(
args, "decoder_ffn_embed_dim", args.encoder_ffn_embed_dim
)
args.decoder_layers = getattr(args, "decoder_layers", 6)
args.decoder_attention_type = getattr(args, "decoder_attention_type", "selfattn")
args.decoder_attention_heads = getattr(args, "decoder_attention_heads", 8)
args.decoder_normalize_before = getattr(args, "decoder_normalize_before", True)
args.decoder_learned_pos = getattr(args, "decoder_learned_pos", False)
args.dropout = getattr(args, "dropout", 0.1)
args.attention_dropout = getattr(args, "attention_dropout", args.dropout)
args.activation_dropout = getattr(args, "activation_dropout", args.dropout)
args.activation_fn = getattr(args, "activation_fn", "relu")
args.adaptive_softmax_cutoff = getattr(args, "adaptive_softmax_cutoff", None)
args.adaptive_softmax_dropout = getattr(args, "adaptive_softmax_dropout", 0)
args.share_decoder_input_output_embed = getattr(
args, "share_decoder_input_output_embed", False
)
args.share_all_embeddings = getattr(args, "share_all_embeddings", False)
args.no_token_positional_embeddings = getattr(
args, "no_token_positional_embeddings", False
)
args.adaptive_input = getattr(args, "adaptive_input", False)
args.encoder_layerdrop = getattr(args, "encoder_layerdrop", 0.0)
args.decoder_layerdrop = getattr(args, "decoder_layerdrop", 0.0)
args.decoder_output_dim = getattr(
args, "decoder_output_dim", args.decoder_embed_dim
)
args.decoder_input_dim = getattr(args, "decoder_input_dim", args.decoder_embed_dim)
args.no_scale_embedding = getattr(args, "no_scale_embedding", False)
args.quant_noise_pq = getattr(args, "quant_noise_pq", 0)
args.max_encoder_relative_length = getattr(args, 'max_encoder_relative_length', -1)
args.max_decoder_relative_length = getattr(args, 'max_decoder_relative_length', -1)
args.k_only = getattr(args, 'k_only', True)
# Conformer
args.macaron_style = getattr(args, "macaron_style", False)
args.use_cnn_module = getattr(args, "use_cnn_module", False)
args.cnn_module_kernel = getattr(args, "cnn_module_kernel", 31)
# SATE
args.acoustic_encoder = getattr(args, "acoustic_encoder", "transformer")
args.adapter = getattr(args, "adapter", "league")
args.ctc_compress_strategy = getattr(args, "ctc_compress_strategy", "avg")
args.temperature = getattr(args, "temperature", 1.0)
args.text_encoder_layers = getattr(args, "text_encoder_layers", 6)
args.text_attention_type = getattr(args, "text_attention_type", "selfattn")
args.share_ctc_and_adapter = getattr(args, "share_ctc_and_adapter", False)
# PDS
args.pds_stages = getattr(args, "pds_stages", None)
args.pds_layers = getattr(args, "pds_layers", None)
args.pds_ratios = getattr(args, "pds_ratios", None)
args.pds_ds_method = getattr(args, "pds_ds_method", "conv")
args.pds_embed_dims = getattr(args, "pds_embed_dims", None)
args.pds_embed_norm = getattr(args, "pds_embed_norm", True)
args.pds_position_embed = getattr(args, "pds_position_embed", None)
args.pds_attn_heads = getattr(args, "pds_attn_heads", None)
args.pds_ffn_ratios = getattr(args, "pds_ffn_ratios", None)
args.pds_attn_ds_ratios = getattr(args, "pds_attn_ds_ratios", "1_1_1_1")
args.pds_conv_strides = getattr(args, "pds_conv_strides", "1_1_1_1")
args.pds_attn_strides = getattr(args, "pds_attn_strides", "1_1_1_1")
args.ctc_layer = getattr(args, "ctc_layer", 0)
args.pds_dropout = getattr(args, "pds_dropout", args.dropout)
args.pds_fusion = getattr(args, "pds_fusion", False)
args.pds_fusion_method = getattr(args, "pds_fusion_method", "all_conv")
# intermedia CTC
args.pds_ctc = getattr(args, "pds_ctc", "0_0_0_0")
args.intermedia_adapter = getattr(args, "intermedia_adapter", "none")
args.intermedia_drop_prob = getattr(args, "intermedia_drop_prob", 0)
# dual
args.encoder_asr_ratio = getattr(args, "encoder_asr_ratio", 1.0)
args.encoder_mt_ratio = getattr(args, "encoder_mt_ratio", 1.0)
args.encoder_drop_net = getattr(args, "encoder_drop_net", False)
args.encoder_drop_net_prob = getattr(args, "encoder_drop_net_prob", 1.0)
args.encoder_drop_net_mix = getattr(args, "encoder_drop_net_mix", False)
@register_model_architecture("s2t_dual", "s2t_dual_s")
def s2t_dual_s(args):
args.encoder_embed_dim = getattr(args, "encoder_embed_dim", 256)
args.encoder_ffn_embed_dim = getattr(args, "encoder_ffn_embed_dim", 256 * 8)
args.encoder_attention_heads = getattr(args, "encoder_attention_heads", 4)
args.dropout = getattr(args, "dropout", 0.1)
base_architecture(args)
@register_model_architecture("s2t_dual", "s2t_dual_s_relative")
def s2t_dual_s_relative(args):
args.max_encoder_relative_length = 100
args.k_only = True
s2t_dual_s(args)
@register_model_architecture("s2t_dual", "s2t_dual_xs")
def s2t_dual_xs(args):
args.encoder_layers = getattr(args, "encoder_layers", 6)
args.encoder_ffn_embed_dim = getattr(args, "encoder_ffn_embed_dim", 256 * 4)
args.dropout = getattr(args, "dropout", 0.3)
s2t_dual_s(args)
@register_model_architecture("s2t_dual", "s2t_dual_sp")
def s2t_dual_sp(args):
args.encoder_layers = getattr(args, "encoder_layers", 16)
s2t_dual_s(args)
@register_model_architecture("s2t_dual", "s2t_dual_m")
def s2t_dual_m(args):
args.encoder_embed_dim = getattr(args, "encoder_embed_dim", 512)
args.encoder_ffn_embed_dim = getattr(args, "encoder_ffn_embed_dim", 512 * 4)
args.encoder_attention_heads = getattr(args, "encoder_attention_heads", 8)
args.dropout = getattr(args, "dropout", 0.15)
base_architecture(args)
@register_model_architecture("s2t_dual", "s2t_dual_mp")
def s2t_dual_mp(args):
args.encoder_layers = getattr(args, "encoder_layers", 16)
s2t_dual_m(args)
@register_model_architecture("s2t_dual", "s2t_dual_l")
def s2t_dual_l(args):
args.encoder_embed_dim = getattr(args, "encoder_embed_dim", 1024)
args.encoder_ffn_embed_dim = getattr(args, "encoder_ffn_embed_dim", 1024 * 4)
args.encoder_attention_heads = getattr(args, "encoder_attention_heads", 16)
args.dropout = getattr(args, "dropout", 0.2)
base_architecture(args)
@register_model_architecture("s2t_dual", "s2t_dual_lp")
def s2t_dual_lp(args):
args.encoder_layers = getattr(args, "encoder_layers", 16)
s2t_dual_l(args)
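The diff truncates transformer_s2.py, so how --encoder-asr-ratio/--encoder-mt-ratio and the --encoder-drop-net* options combine the two representations is not shown here. The sketch below is only a guess at the behaviour those option names and help strings suggest, not the commit's implementation.

import torch

def mix_branches(x_mt, x_asr, mt_ratio=1.0, asr_ratio=1.0,
                 drop_net=False, drop_prob=0.5, drop_mix=False, training=True):
    # Hypothetical drop-net combination; NOT taken from transformer_s2.py.
    if drop_net and training:
        coin = torch.rand(1).item()
        if drop_mix:
            return coin * x_mt + (1 - coin) * x_asr   # "mix the two input with any probability"
        return x_asr if coin < drop_prob else x_mt    # "drop an input" with some probability
    return mt_ratio * x_mt + asr_ratio * x_asr        # weighted sum otherwise

# toy usage with matching (T x B x C) shapes
x_mt, x_asr = torch.ones(4, 2, 8), torch.zeros(4, 2, 8)
print(mix_branches(x_mt, x_asr, training=False).mean().item())  # 1.0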
...@@ -88,12 +88,6 @@ class S2TSATEModel(S2TTransformerModel):
help="the architecture of the acoustic encoder",
)
parser.add_argument(
-"--target-ctc-layers",
-default=None,
-type=str,
-help="ctc layers for target sentence",
-)
-parser.add_argument(
"--load-pretrained-acoustic-encoder-from",
type=str,
metavar="STR",
...@@ -105,11 +99,24 @@ class S2TSATEModel(S2TTransformerModel):
metavar="STR",
help="model to take text encoder weights from (for initialization)",
)
+# target CTC
+parser.add_argument(
+"--target-ctc-layer",
+default=None,
+type=str,
+help="ctc layer for target sentence",
+)
+parser.add_argument(
+"--target-intermedia-ctc-layer",
+default=None,
+type=str,
+help="intermedia ctc layers for target sentence",
+)
pass
@classmethod
-def build_encoder(cls, args, task=None, embed_tokens=None):
-encoder = S2TSATEEncoder(args, task, embed_tokens)
+def build_encoder(cls, args, task=None, decoder_embed_tokens=None):
+encoder = S2TSATEEncoder(args, task, decoder_embed_tokens)
if getattr(args, "load_pretrained_encoder_from", None):
logger.info(
...@@ -174,11 +181,25 @@ class TextEncoder(FairseqEncoder):
else:
self.layer_norm = None
+# CTC
+self.use_ctc = getattr(args, "target_ctc_weight", 0) > 0
+if self.use_ctc:
+self.ctc_layer = args.target_ctc_layer
+self.inter_ctc = True if self.ctc_layer != args.encoder_layers else False
+if self.inter_ctc:
+logger.info("Target CTC loss in layer %d" % self.ctc_layer)
+self.ctc = CTC(embed_dim,
+dictionary_size=embed_tokens.num_embeddings,
+dropout=args.dropout,
+need_layernorm=True if self.inter_ctc else False)
+self.ctc.ctc_projection.weight = embed_tokens.weight
self.intermedia_ctc_layers = []
-self.target_ctc_layers = getattr(args, "target_ctc_layers", None)
-if self.target_ctc_layers is not None:
-intermedia_ctc_layers = self.target_ctc_layers.split(",")
-for layer_idx in intermedia_ctc_layers:
+self.target_intermedia_ctc_layers = getattr(args, "target_intermedia_ctc_layers", None)
+if self.target_intermedia_ctc_layers is not None:
+target_intermedia_ctc_layers = self.target_intermedia_ctc_layers.split(",")
+for layer_idx in target_intermedia_ctc_layers:
layer_idx = int(layer_idx)
assert layer_idx <= layer_num, (layer_idx, layer_num)
...@@ -186,7 +207,7 @@ class TextEncoder(FairseqEncoder):
layer_idx += layer_num
self.intermedia_ctc_layers.append(layer_idx)
-logger.info("Intermedia CTC loss in layer %d" % layer_idx)
+logger.info("Intermedia target CTC loss in layer %d" % layer_idx)
self.ctc = CTC(embed_dim,
dictionary_size=len(dictionary),
...@@ -201,7 +222,8 @@ class TextEncoder(FairseqEncoder):
elif args.intermedia_adapter == "league":
strategy = getattr(args, "intermedia_distribution_cutoff", None)
self.adapter = Adapter(embed_dim, args.intermedia_adapter,
-dictionary, strategy=strategy)
+dictionary, embed_tokens=embed_tokens,
+strategy=strategy)
self.intermedia_drop_prob = getattr(args, "intermedia_drop_prob", 0)
self.intermedia_temperature = getattr(args, "intermedia_temperature", 1)
...@@ -212,14 +234,17 @@ class TextEncoder(FairseqEncoder):
x = positions + x
x = self.dropout_module(x)
-target_ctc_logits = []
+target_ctc_logit = None
+target_intermedia_ctc_logits = []
layer_idx = 0
for layer in self.layers:
-layer_idx += 1
if history is not None:
x = history.pop()
x = layer(x, encoder_padding_mask, pos_emb=positions)
+layer_idx += 1
+if self.use_ctc and self.inter_ctc and self.ctc_layer == layer_idx:
+target_ctc_logit = self.ctc(x.clone())
if layer_idx != self.layer_num and layer_idx in self.intermedia_ctc_layers:
if self.intermedia_drop_prob > 0:
...@@ -229,7 +254,7 @@ class TextEncoder(FairseqEncoder):
norm_x = self.layer_norm(x)
logit = self.ctc(norm_x)
-target_ctc_logits.append(logit)
+target_intermedia_ctc_logits.append(logit)
prob = utils.softmax(logit / self.intermedia_temperature, dim=-1)
x, encoder_padding_mask = self.adapter([x, prob], encoder_padding_mask)
...@@ -243,18 +268,17 @@ class TextEncoder(FairseqEncoder):
if self.layer_norm is not None:
x = self.layer_norm(x)
-if layer_idx in self.intermedia_ctc_layers:
-logit = self.ctc(x)
-target_ctc_logits.append(logit)
+if self.use_ctc and target_ctc_logit is None:
+target_ctc_logit = self.ctc(x)
-return x, target_ctc_logits
+return x, target_ctc_logit, target_intermedia_ctc_logits
class S2TSATEEncoder(FairseqEncoder):
"""Speech-to-text Conformer encoder that consists of input subsampler and
Transformer encoder."""
-def __init__(self, args, task=None, embed_tokens=None):
+def __init__(self, args, task=None, decoder_embed_tokens=None):
super().__init__(None)
# acoustic encoder
...@@ -278,7 +302,7 @@ class S2TSATEEncoder(FairseqEncoder):
self.adapter = Adapter(args.encoder_embed_dim,
args.adapter,
task.source_dictionary,
-embed_tokens if task.source_dictionary == task.target_dictionary else None,
+decoder_embed_tokens if task.source_dictionary == task.target_dictionary else None,
strategy=strategy)
if args.share_ctc_and_adapter and hasattr(self.adapter, "embed_adapter"):
...@@ -288,7 +312,7 @@ class S2TSATEEncoder(FairseqEncoder):
args.encoder_attention_type = args.text_attention_type
# text encoder
-self.text_encoder = TextEncoder(args, task.source_dictionary, embed_tokens)
+self.text_encoder = TextEncoder(args, task.source_dictionary, decoder_embed_tokens)
args.encoder_attention_type = acoustic_encoder_attention_type
...@@ -383,9 +407,15 @@ class S2TSATEEncoder(FairseqEncoder):
@register_model_architecture(model_name="s2t_sate", arch_name="s2t_sate")
def base_architecture(args):
# Convolutional subsampler
-args.conv_kernel_sizes = getattr(args, "conv_kernel_sizes", "5,5")
-args.conv_channels = getattr(args, "conv_channels", 1024)
+args.subsampling_type = getattr(args, "subsampling_type", "conv1d")
+args.subsampling_layers = getattr(args, "subsampling_layers", 2)
+args.subsampling_filter = getattr(args, "subsampling_filter", 1024)
+args.subsampling_kernel = getattr(args, "subsampling_kernel", 5)
+args.subsampling_stride = getattr(args, "subsampling_stride", 2)
+args.subsampling_norm = getattr(args, "subsampling_norm", "none")
+args.subsampling_activation = getattr(args, "subsampling_activation", "glu")
+# transformer
args.encoder_embed_dim = getattr(args, "encoder_embed_dim", 512)
args.encoder_ffn_embed_dim = getattr(args, "encoder_ffn_embed_dim", 2048)
args.encoder_layers = getattr(args, "encoder_layers", 12)
...@@ -415,6 +445,7 @@ def base_architecture(args):
args, "no_token_positional_embeddings", False
)
args.adaptive_input = getattr(args, "adaptive_input", False)
+args.encoder_layerdrop = getattr(args, "encoder_layerdrop", 0.0)
args.decoder_layerdrop = getattr(args, "decoder_layerdrop", 0.0)
args.decoder_output_dim = getattr(
args, "decoder_output_dim", args.decoder_embed_dim
......
...@@ -245,6 +245,12 @@ class S2TTransformerModel(FairseqEncoderDecoderModel):
metavar="STR",
help="freeze the module of the decoder",
)
+# args for "Reducing Transformer Depth on Demand with Structured Dropout" (Fan et al., 2019)
+parser.add_argument('--encoder-layerdrop', type=float, metavar='D', default=0,
+help='LayerDrop probability for encoder')
+parser.add_argument('--decoder-layerdrop', type=float, metavar='D', default=0,
+help='LayerDrop probability for decoder')
+# DLCL
parser.add_argument(
"--use-enc-dlcl",
default=False,
...@@ -540,9 +546,10 @@ class S2TTransformerEncoder(FairseqEncoder):
else:
self.history = None
-self.use_ctc = "sate" in args.arch or \
-(getattr(args, "criterion", "") == "ctc") or \
-(("ctc" in getattr(args, "criterion", "")) and (getattr(args, "ctc_weight", 0) > 0))
+# self.use_ctc = "sate" in args.arch or \
+# (getattr(args, "criterion", "") == "ctc") or \
+# (("ctc" in getattr(args, "criterion", "")) and (getattr(args, "ctc_weight", 0) > 0))
+self.use_ctc = "sate" in args.arch or getattr(args, "ctc_weight", 0) > 0
if self.use_ctc:
self.ctc_layer = args.ctc_layer
self.inter_ctc = True if self.ctc_layer != 0 and self.ctc_layer != args.encoder_layers else False
...@@ -679,6 +686,9 @@ class S2TTransformerEncoder(FairseqEncoder):
# encoder layer
x = layer(x, encoder_padding_mask, pos_emb=positions)
+if self.use_ctc and self.inter_ctc and self.ctc_layer == layer_idx:
+ctc_logit = self.ctc(x.clone())
# interleave CTC
if layer_idx in self.intermedia_ctc_layers:
if self.intermedia_drop_prob > 0:
...@@ -690,7 +700,6 @@ class S2TTransformerEncoder(FairseqEncoder):
logit = self.ctc(norm_x)
intermedia_ctc_logits.append(logit)
-# prob = self.ctc.softmax(norm_x)
prob = utils.softmax(logit / self.intermedia_temperature, dim=-1)
x, encoder_padding_mask = self.adapter([x, prob], encoder_padding_mask)
...@@ -713,7 +722,7 @@ class S2TTransformerEncoder(FairseqEncoder):
return {
"encoder_out": [x], # T x B x C
-"ctc_logit": [] if ctc_logit is None else [ctc_logit], # B x T x C
+"ctc_logit": [] if ctc_logit is None else [ctc_logit], # T x B x C
"intermedia_ctc_logits": intermedia_ctc_logits, # B x T x C
"encoder_padding_mask": [encoder_padding_mask], # B x T
"encoder_embedding": [], # B x T x C
...@@ -849,6 +858,7 @@ def base_architecture(args):
args, "no_token_positional_embeddings", False
)
args.adaptive_input = getattr(args, "adaptive_input", False)
+args.encoder_layerdrop = getattr(args, "encoder_layerdrop", 0.0)
args.decoder_layerdrop = getattr(args, "decoder_layerdrop", 0.0)
args.decoder_output_dim = getattr(
args, "decoder_output_dim", args.decoder_embed_dim
......
...@@ -18,6 +18,7 @@ from fairseq.models import (
register_model,
register_model_architecture,
)
+from fairseq.models.speech_to_text.modules import Adapter, CTC
from fairseq.modules import (
AdaptiveSoftmax,
FairseqDropout,
...@@ -294,6 +295,42 @@ class TransformerModel(FairseqEncoderDecoderModel):
action='store_true',
help="use squeeze and excitation method",
)
+# CTC
+parser.add_argument(
+"--ctc-layer",
+type=int,
+help="ctc layers for target sentence",
+)
+parser.add_argument(
+"--intermedia-ctc-layers",
+default=None,
+type=str,
+help="the position of the ctc loss, separated by comma ",
+)
+parser.add_argument(
+"--intermedia-adapter",
+default="none",
+type=str,
+help="type of intermedia adapter",
+)
+parser.add_argument(
+"--intermedia-distribution-cutoff",
+default=None,
+type=int,
+help="cutoff of the distribution",
+)
+parser.add_argument(
+"--intermedia-drop-prob",
+default=0,
+type=float,
+help="probability of dropping the followed layers",
+)
+parser.add_argument(
+"--intermedia-temperature",
+default=1,
+type=float,
+help="temperature of the intermedia ctc probability",
+)
# fmt: on
@classmethod
...@@ -342,7 +379,7 @@ class TransformerModel(FairseqEncoderDecoderModel):
)
if getattr(args, "offload_activations", False):
args.checkpoint_activations = True # offloading implies checkpointing
-encoder = cls.build_encoder(args, src_dict, encoder_embed_tokens)
+encoder = cls.build_encoder(args, src_dict, encoder_embed_tokens, decoder_embed_tokens)
if getattr(args, "encoder_freeze_module", None):
utils.freeze_parameters(encoder, args.encoder_freeze_module)
logging.info("freeze the encoder module: {}".format(args.encoder_freeze_module))
...@@ -370,8 +407,8 @@ class TransformerModel(FairseqEncoderDecoderModel):
return emb
@classmethod
-def build_encoder(cls, args, src_dict, embed_tokens):
-encoder = TransformerEncoder(args, src_dict, embed_tokens)
+def build_encoder(cls, args, src_dict, embed_tokens, decoder_embed_tokens=None):
+encoder = TransformerEncoder(args, src_dict, embed_tokens, decoder_embed_tokens)
if getattr(args, "load_pretrained_encoder_from", None):
logger.info(
f"loaded pretrained encoder from: "
...@@ -460,7 +497,7 @@ class TransformerEncoder(FairseqEncoder):
embed_tokens (torch.nn.Embedding): input embedding
"""
-def __init__(self, args, dictionary, embed_tokens):
+def __init__(self, args, dictionary, embed_tokens, decoder_embed_tokens=None):
self.args = args
super().__init__(dictionary)
self.register_buffer("version", torch.Tensor([3]))
...@@ -534,6 +571,48 @@ class TransformerEncoder(FairseqEncoder):
else:
self.history = None
+# CTC
+self.use_ctc = getattr(args, "ctc_weight", 0) > 0
+if self.use_ctc:
+self.ctc_layer = args.ctc_layer
+self.inter_ctc = True if self.ctc_layer != args.encoder_layers else False
+if self.inter_ctc:
+logger.info("Intermedia CTC loss in layer %d" % self.ctc_layer)
+self.ctc = CTC(embed_dim,
+dictionary_size=decoder_embed_tokens.num_embeddings,
+dropout=args.dropout,
+need_layernorm=True if self.inter_ctc else False)
+self.ctc.ctc_projection.weight = embed_tokens.weight
+self.intermedia_ctc_layers = []
+if args.intermedia_ctc_layers is not None:
+intermedia_ctc_layers = args.intermedia_ctc_layers.split(",")
+for layer_idx in intermedia_ctc_layers:
+layer_idx = int(layer_idx)
+if layer_idx <= 0:
+layer_idx += args.encoder_layers
+self.intermedia_ctc_layers.append(layer_idx)
+logger.info("Intermedia CTC loss in layer %d" % layer_idx)
+if not self.use_ctc:
+self.ctc = CTC(embed_dim,
+dictionary_size=decoder_embed_tokens.num_embeddings,
+dropout=args.dropout)
+self.ctc.ctc_projection.weight = embed_tokens.weight
+strategy = None
+if args.intermedia_adapter == "shrink":
+strategy = getattr(args, "ctc_compress_strategy", None)
+elif args.intermedia_adapter == "league":
+strategy = getattr(args, "intermedia_distribution_cutoff", None)
+self.adapter = Adapter(embed_dim, args.intermedia_adapter,
+None, embed_tokens=decoder_embed_tokens, strategy=strategy)
+self.intermedia_drop_prob = getattr(args, "intermedia_drop_prob", 0)
+self.intermedia_temperature = getattr(args, "intermedia_temperature", 1)
def build_encoder_layer(self, args):
layer = TransformerEncoderLayer(args)
if getattr(args, "checkpoint_activations", False):
...@@ -653,6 +732,9 @@ class TransformerEncoder(FairseqEncoder):
self.history.push(x)
# encoder layers
+layer_idx = 0
+ctc_logit = None
+intermedia_ctc_logits = []
for layer in self.layers:
if self.history is not None:
x = self.history.pop()
...@@ -660,10 +742,29 @@ class TransformerEncoder(FairseqEncoder):
x = layer(
x, encoder_padding_mask=encoder_padding_mask if has_pads else None
)
+layer_idx += 1
if return_all_hiddens:
assert encoder_states is not None
encoder_states.append(x)
+# CTC
+if self.use_ctc and self.inter_ctc and self.ctc_layer == layer_idx:
+ctc_logit = self.ctc(x.clone())
+# Intermedia CTC
+if layer_idx in self.intermedia_ctc_layers:
+if self.intermedia_drop_prob > 0:
+p = torch.rand(1).uniform_()
+if p < self.intermedia_drop_prob:
+break
+norm_x = self.layer_norm(x)
+logit = self.ctc(norm_x)
+intermedia_ctc_logits.append(logit)
+prob = utils.softmax(logit / self.intermedia_temperature, dim=-1)
+x, encoder_padding_mask = self.adapter([x, prob], encoder_padding_mask)
if self.history is not None:
self.history.push(x)
...@@ -673,12 +774,16 @@ class TransformerEncoder(FairseqEncoder):
if self.layer_norm is not None:
x = self.layer_norm(x)
+if self.use_ctc and ctc_logit is None:
+ctc_logit = self.ctc(x)
# The Pytorch Mobile lite interpreter does not supports returning NamedTuple in
# `forward` so we use a dictionary instead.
# TorchScript does not support mixed values so the values are all lists.
# The empty list is equivalent to None.
return {
"encoder_out": [x], # T x B x C
+"ctc_logit": [] if ctc_logit is None else [ctc_logit], # T x B x C
"encoder_padding_mask": [encoder_padding_mask], # B x T
"encoder_embedding": [encoder_embedding], # B x T x C
"encoder_states": encoder_states, # List[T x B x C]
...@@ -1331,6 +1436,12 @@ def base_architecture(args):
args.max_decoder_relative_length = getattr(args, 'max_decoder_relative_length', -1)
args.k_only = getattr(args, 'k_only', True)
+# CTC
+args.ctc_layer = getattr(args, "ctc_layer", args.encoder_layers)
+args.intermedia_ctc_layers = getattr(args, "intermedia_ctc_layers", None)
+args.intermedia_adapter = getattr(args, "intermedia_adapter", None)
+args.intermedia_drop_prob = getattr(args, "intermedia_drop_prob", 0)
@register_model_architecture("transformer", "transformer_relative")
def transformer_rpr(args):
......
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import math
from typing import Any, Dict, List, Optional, Tuple
import logging
import torch
import torch.nn as nn
from fairseq import checkpoint_utils, utils
from fairseq.distributed import fsdp_wrap
from fairseq.models import (
FairseqEncoder,
FairseqEncoderDecoderModel,
FairseqIncrementalDecoder,
register_model,
register_model_architecture,
)
from fairseq.models.transformer import TransformerEncoder, TransformerDecoder
from fairseq.modules import (
AdaptiveSoftmax,
FairseqDropout,
LayerDropModuleList,
LayerNorm,
PositionalEmbedding,
SinusoidalPositionalEmbedding,
RelPositionalEncoding,
LegacyRelPositionalEncoding,
DynamicLinearCombination,
TransformerS2DecoderLayer,
TransformerS2EncoderLayer
)
from fairseq.modules.checkpoint_activations import checkpoint_wrapper
from fairseq.modules.quant_noise import quant_noise as apply_quant_noise_
from torch import Tensor
DEFAULT_MAX_SOURCE_POSITIONS = 1024
DEFAULT_MAX_TARGET_POSITIONS = 1024
logger = logging.getLogger(__name__)
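# This module appears to implement the second-stream ("S2") Transformer used by the dual
# speech-to-text model: its encoder and decoder layers can additionally attend to a second
# set of encoder states (x2, e.g. the acoustic/ASR encoder output) besides their usual inputs.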
class TransformerS2Encoder(TransformerEncoder):
"""
Transformer encoder consisting of *args.encoder_layers* layers. Each layer
is a :class:`TransformerEncoderLayer`.
Args:
args (argparse.Namespace): parsed command-line arguments
dictionary (~fairseq.data.Dictionary): encoding dictionary
embed_tokens (torch.nn.Embedding): input embedding
"""
def __init__(self, args, dictionary, embed_tokens):
super().__init__(args, dictionary, embed_tokens)
def build_encoder_layer(self, args):
layer = TransformerS2EncoderLayer(args)
if getattr(args, "checkpoint_activations", False):
offload_to_cpu = getattr(args, "offload_activations", False)
layer = checkpoint_wrapper(layer, offload_to_cpu=offload_to_cpu)
layer = fsdp_wrap(layer, min_num_params=1e8)
return layer
def forward(
self,
src_tokens,
src_lengths: Optional[torch.Tensor] = None,
x2 = None,
x2_encoder_padding_mask = None,
return_all_hiddens: bool = False,
token_embeddings: Optional[torch.Tensor] = None,
):
"""
Args:
src_tokens (LongTensor): tokens in the source language of shape
`(batch, src_len)`
src_lengths (torch.LongTensor): lengths of each source sentence of
shape `(batch)`
return_all_hiddens (bool, optional): also return all of the
intermediate hidden states (default: False).
token_embeddings (torch.Tensor, optional): precomputed embeddings
default `None` will recompute embeddings
Returns:
dict:
- **encoder_out** (Tensor): the last encoder layer's output of
shape `(src_len, batch, embed_dim)`
- **encoder_padding_mask** (ByteTensor): the positions of
padding elements of shape `(batch, src_len)`
- **encoder_embedding** (Tensor): the (scaled) embedding lookup
of shape `(batch, src_len, embed_dim)`
- **encoder_states** (List[Tensor]): all intermediate
hidden states of shape `(src_len, batch, embed_dim)`.
Only populated if *return_all_hiddens* is True.
"""
return self.forward_scriptable(src_tokens,
src_lengths,
x2,
x2_encoder_padding_mask,
return_all_hiddens,
token_embeddings)
# TorchScript doesn't support super() method so that the scriptable Subclass
# can't access the base class model in Torchscript.
# Current workaround is to add a helper function with different name and
# call the helper function from scriptable Subclass.
def forward_scriptable(
self,
src_tokens,
src_lengths: Optional[torch.Tensor] = None,
x2=None,
x2_encoder_padding_mask=None,
return_all_hiddens: bool = False,
token_embeddings: Optional[torch.Tensor] = None,
):
"""
Args:
src_tokens (LongTensor): tokens in the source language of shape
`(batch, src_len)`
src_lengths (torch.LongTensor): lengths of each source sentence of
shape `(batch)`
return_all_hiddens (bool, optional): also return all of the
intermediate hidden states (default: False).
token_embeddings (torch.Tensor, optional): precomputed embeddings
default `None` will recompute embeddings
Returns:
dict:
- **encoder_out** (Tensor): the last encoder layer's output of
shape `(src_len, batch, embed_dim)`
- **encoder_padding_mask** (ByteTensor): the positions of
padding elements of shape `(batch, src_len)`
- **encoder_embedding** (Tensor): the (scaled) embedding lookup
of shape `(batch, src_len, embed_dim)`
- **encoder_states** (List[Tensor]): all intermediate
hidden states of shape `(src_len, batch, embed_dim)`.
Only populated if *return_all_hiddens* is True.
"""
# compute padding mask
encoder_padding_mask = src_tokens.eq(self.padding_idx)
has_pads = (src_tokens.device.type == "xla" or encoder_padding_mask.any())
if self.history is not None:
self.history.clean()
x, encoder_embedding = self.forward_embedding(src_tokens, token_embeddings)
# account for padding while computing the representation
if encoder_padding_mask is not None:
x = x * (1 - encoder_padding_mask.unsqueeze(-1).type_as(x))
# B x T x C -> T x B x C
x = x.transpose(0, 1)
encoder_states = []
if return_all_hiddens:
encoder_states.append(x)
# add emb into history
if self.history is not None:
self.history.push(x)
# encoder layers
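# Each TransformerS2EncoderLayer runs ordinary self-attention over x and, when x2 is given,
# cross-attention over the second stream; the two branches are mixed by get_ratio() inside the layer.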
for layer in self.layers:
if self.history is not None:
x = self.history.pop()
x = layer(
x, encoder_padding_mask=encoder_padding_mask if has_pads else None,
x2=x2, x2_encoder_padding_mask=x2_encoder_padding_mask,
)
if return_all_hiddens:
assert encoder_states is not None
encoder_states.append(x)
if self.history is not None:
self.history.push(x)
if self.history is not None:
x = self.history.pop()
if self.layer_norm is not None:
x = self.layer_norm(x)
# The PyTorch Mobile lite interpreter does not support returning NamedTuple in
# `forward` so we use a dictionary instead.
# TorchScript does not support mixed values so the values are all lists.
# The empty list is equivalent to None.
return {
"encoder_out": [x], # T x B x C
"encoder_padding_mask": [encoder_padding_mask], # B x T
"encoder_out_s2": [x2], # T x B x C
"encoder_padding_mask_s2": [x2_encoder_padding_mask], # B x T
"encoder_embedding": [encoder_embedding], # B x T x C
"encoder_states": encoder_states, # List[T x B x C]
"src_tokens": [],
"src_lengths": [],
}
class TransformerS2Decoder(TransformerDecoder):
"""
Transformer decoder consisting of *args.decoder_layers* layers. Each layer
is a :class:`TransformerDecoderLayer`.
Args:
args (argparse.Namespace): parsed command-line arguments
dictionary (~fairseq.data.Dictionary): decoding dictionary
embed_tokens (torch.nn.Embedding): output embedding
no_encoder_attn (bool, optional): whether to attend to encoder outputs
(default: False).
"""
def __init__(self, args, dictionary, embed_tokens, no_encoder_attn=False):
self.args = args
super().__init__(args, dictionary, embed_tokens, no_encoder_attn)
def build_decoder_layer(self, args, no_encoder_attn=False):
layer = TransformerS2DecoderLayer(args, no_encoder_attn)
if getattr(args, "checkpoint_activations", False):
offload_to_cpu = getattr(args, "offload_activations", False)
layer = checkpoint_wrapper(layer, offload_to_cpu=offload_to_cpu)
layer = fsdp_wrap(layer, min_num_params=1e8)
return layer
def extract_features_scriptable(
self,
prev_output_tokens,
encoder_out: Optional[Dict[str, List[Tensor]]],
incremental_state: Optional[Dict[str, Dict[str, Optional[Tensor]]]] = None,
full_context_alignment: bool = False,
alignment_layer: Optional[int] = None,
alignment_heads: Optional[int] = None,
):
"""
Similar to *forward* but only return features.
Includes several features from "Jointly Learning to Align and
Translate with Transformer Models" (Garg et al., EMNLP 2019).
Args:
full_context_alignment (bool, optional): don't apply
auto-regressive mask to self-attention (default: False).
alignment_layer (int, optional): return mean alignment over
heads at this layer (default: last layer).
alignment_heads (int, optional): only average alignment over
this many heads (default: all heads).
Returns:
tuple:
- the decoder's features of shape `(batch, tgt_len, embed_dim)`
- a dictionary with any model-specific outputs
"""
if self.history is not None:
self.history.clean()
if alignment_layer is None:
alignment_layer = self.num_layers - 1
# embed positions
positions = None
if self.embed_positions is not None:
positions = self.embed_positions(
prev_output_tokens, incremental_state=incremental_state
)
if incremental_state is not None:
prev_output_tokens = prev_output_tokens[:, -1:]
if positions is not None:
positions = positions[:, -1:]
# embed tokens and positions
x = self.embed_scale * self.embed_tokens(prev_output_tokens)
if self.quant_noise is not None:
x = self.quant_noise(x)
if self.project_in_dim is not None:
x = self.project_in_dim(x)
if positions is not None and self.attn_type != "rel_selfattn":
x += positions
if self.layernorm_embedding is not None:
x = self.layernorm_embedding(x)
x = self.dropout_module(x)
# B x T x C -> T x B x C
x = x.transpose(0, 1)
# add emb into history
if self.history is not None:
self.history.push(x)
self_attn_padding_mask: Optional[Tensor] = None
if self.cross_self_attention or prev_output_tokens.eq(self.padding_idx).any():
self_attn_padding_mask = prev_output_tokens.eq(self.padding_idx)
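# If the encoder applied inter-mixup, mix the decoder-side embeddings with the same coefficient
# and sample indices so they stay aligned with the mixed encoder output; the padding masks of the
# two samples are merged.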
mixup = None
if "mixup" in encoder_out and encoder_out["mixup"] is not None:
mixup = encoder_out["mixup"]
coef = mixup["coef"]
idx1 = mixup["index1"]
idx2 = mixup["index2"]
x1 = x[:, idx1]
x2 = x[:, idx2]
x = coef * x1 + (1 - coef) * x2
if self_attn_padding_mask is not None:
pad1 = self_attn_padding_mask[idx1]
pad2 = self_attn_padding_mask[idx2]
self_attn_padding_mask = pad1 + pad2
# decoder layers
avg_attn = None
attn: Optional[Tensor] = None
inner_states: List[Optional[Tensor]] = [x]
for idx, layer in enumerate(self.layers):
if self.history is not None:
x = self.history.pop()
if incremental_state is None and not full_context_alignment:
self_attn_mask = self.buffered_future_mask(x)
else:
self_attn_mask = None
x, layer_attn, _ = layer(
x,
encoder_out["encoder_out"][0]
if (encoder_out is not None and len(encoder_out["encoder_out"]) > 0)
else None,
encoder_out["encoder_padding_mask"][0]
if (
encoder_out is not None
and len(encoder_out["encoder_padding_mask"]) > 0
)
else None,
encoder_out_s2=encoder_out["encoder_out_s2"][0],
encoder_padding_mask_s2=encoder_out["encoder_padding_mask_s2"][0],
incremental_state=incremental_state,
self_attn_mask=self_attn_mask,
self_attn_padding_mask=self_attn_padding_mask,
need_attn=bool((idx == alignment_layer) or self.gather_attn_weight),
need_head_weights=bool((idx == alignment_layer) or self.gather_attn_weight),
pos_emb=positions
)
inner_states.append(x)
if layer_attn is not None and idx == alignment_layer:
attn = layer_attn.float().to(x)
if self.history is not None:
self.history.push(x)
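# When gather_attn_weight is set, accumulate the per-layer attention; after the loop it is averaged,
# quantized into 0.001-wide buckets and counted in self.attn_weights, presumably for offline analysis
# of the cross-attention distribution.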
if self.gather_attn_weight:
if avg_attn is None:
avg_attn = layer_attn
else:
avg_attn += layer_attn
if self.gather_attn_weight:
avg_attn = avg_attn / len(self.layers)
attn = avg_attn.mean(0).sum(-2)
attn = torch.reshape(attn, [attn.numel()])
attn = attn // 0.001
attn = attn.int().cpu()
if len(encoder_out["encoder_padding_mask"]) > 0:
mask = encoder_out["encoder_padding_mask"][0]
mask = torch.reshape(mask, [mask.numel()])
else:
mask = None
i = -1
for item in attn:
i += 1
if mask is not None and mask[i]:
continue
idx = int(item) * 0.001
if idx not in self.attn_weights:
self.attn_weights[idx] = 0
self.attn_weights[idx] += 1
elif attn is not None:
if alignment_heads is not None:
attn = attn[:alignment_heads]
# average probabilities over heads
attn = attn.mean(dim=0)
if self.history is not None:
x = self.history.pop()
if self.layer_norm is not None:
x = self.layer_norm(x)
# T x B x C -> B x T x C
x = x.transpose(0, 1)
if self.project_out_dim is not None:
x = self.project_out_dim(x)
return x, {"attn": [attn], "inner_states": inner_states, "mixup": mixup}
def Embedding(num_embeddings, embedding_dim, padding_idx):
m = nn.Embedding(num_embeddings, embedding_dim, padding_idx=padding_idx)
nn.init.normal_(m.weight, mean=0, std=embedding_dim ** -0.5)
nn.init.constant_(m.weight[padding_idx], 0)
return m
\ No newline at end of file
...@@ -44,6 +44,7 @@ from .transformer_sentence_encoder import TransformerSentenceEncoder ...@@ -44,6 +44,7 @@ from .transformer_sentence_encoder import TransformerSentenceEncoder
from .transpose_last import TransposeLast from .transpose_last import TransposeLast
from .unfold import unfold1d from .unfold import unfold1d
from .transformer_layer import TransformerDecoderLayer, TransformerEncoderLayer from .transformer_layer import TransformerDecoderLayer, TransformerEncoderLayer
from .transformer_s2_layer import TransformerS2EncoderLayer, TransformerS2DecoderLayer
from .vggblock import VGGBlock from .vggblock import VGGBlock
from .rotary_positional_embedding import RotaryPositionalEmbedding from .rotary_positional_embedding import RotaryPositionalEmbedding
from .positional_encoding import ( from .positional_encoding import (
...@@ -109,6 +110,8 @@ __all__ = [ ...@@ -109,6 +110,8 @@ __all__ = [
"TransformerSentenceEncoder", "TransformerSentenceEncoder",
"TransformerDecoderLayer", "TransformerDecoderLayer",
"TransformerEncoderLayer", "TransformerEncoderLayer",
"TransformerS2DecoderLayer",
"TransformerS2EncoderLayer",
"TransposeLast", "TransposeLast",
"VGGBlock", "VGGBlock",
"unfold1d", "unfold1d",
......
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
from typing import Dict, List, Optional
from numpy.random import uniform
import torch
import torch.nn as nn
from fairseq import utils
from fairseq.modules import (
LayerNorm,
MultiheadAttention,
RelPositionMultiheadAttention,
RelativeMultiheadAttention,
LocalMultiheadAttention,
SEAttention,
)
from fairseq.modules.fairseq_dropout import FairseqDropout
from fairseq.modules.quant_noise import quant_noise
from torch import Tensor
class TransformerS2EncoderLayer(nn.Module):
"""Encoder layer block.
In the original paper each operation (multi-head attention or FFN) is
postprocessed with: `dropout -> add residual -> layernorm`. In the
tensor2tensor code they suggest that learning is more robust when
preprocessing each layer with layernorm and postprocessing with:
`dropout -> add residual`. We default to the approach in the paper, but the
tensor2tensor approach can be enabled by setting
*args.encoder_normalize_before* to ``True``.
Args:
args (argparse.Namespace): parsed command-line arguments
"""
def __init__(self, args):
super().__init__()
self.args = args
self.embed_dim = args.encoder_embed_dim
self.quant_noise = getattr(args, 'quant_noise_pq', 0)
self.quant_noise_block_size = getattr(args, 'quant_noise_pq_block_size', 8) or 8
self.attn_type = getattr(args, "encoder_attention_type", "selfattn")
self.self_attn = self.build_self_attention(self.embed_dim, args)
self.self_attn_layer_norm = LayerNorm(self.embed_dim)
self.dropout_module = FairseqDropout(
args.dropout, module_name=self.__class__.__name__
)
self.activation_fn = utils.get_activation_fn(
activation=getattr(args, 'activation_fn', 'relu') or "relu"
)
activation_dropout_p = getattr(args, "activation_dropout", 0) or 0
if activation_dropout_p == 0:
# for backwards compatibility with models that use args.relu_dropout
activation_dropout_p = getattr(args, "relu_dropout", 0) or 0
self.activation_dropout_module = FairseqDropout(
float(activation_dropout_p), module_name=self.__class__.__name__
)
self.normalize_before = args.encoder_normalize_before
self.fc1 = self.build_fc1(
self.embed_dim,
args.encoder_ffn_embed_dim,
self.quant_noise,
self.quant_noise_block_size,
)
self.fc2 = self.build_fc2(
args.encoder_ffn_embed_dim,
self.embed_dim,
self.quant_noise,
self.quant_noise_block_size,
)
self.final_layer_norm = LayerNorm(self.embed_dim)
self.use_se = getattr(args, "squeeze_excitation", False)
if self.use_se:
self.se_attn = SEAttention(self.embed_dim, 16)
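# Cross-attention over the second encoder stream; key/value widths fall back to the
# layer width unless encoder_x2_dim is specified.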
self.s2_attn = MultiheadAttention(
self.embed_dim,
args.encoder_attention_heads,
kdim=getattr(args, "encoder_x2_dim", self.embed_dim),
vdim=getattr(args, "encoder_x2_dim", self.embed_dim),
dropout=args.attention_dropout,
self_attention=False,
)
self.s1_ratio = args.encoder_s1_ratio
self.s2_ratio = args.encoder_s2_ratio
self.drop_net = args.encoder_drop_net
self.drop_net_prob = args.encoder_drop_net_prob
self.drop_net_mix = args.encoder_drop_net_mix
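# get_ratio() returns the mixing weights for the self-attention (s1) and cross-stream (s2) branches.
# With drop-net enabled, training keeps only one branch with probability drop_net_prob each
# (or mixes them with a random coefficient when drop_net_mix is set) and otherwise averages them;
# at inference the two branches are always averaged. Without drop-net, the fixed s1/s2 ratios are used.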
def get_ratio(self):
if self.drop_net:
frand = float(uniform(0, 1))
if self.drop_net_mix and self.training:
return [frand, 1 - frand]
if frand < self.drop_net_prob and self.training:
return [1, 0]
elif frand > 1 - self.drop_net_prob and self.training:
return [0, 1]
else:
return [0.5, 0.5]
else:
return [self.s1_ratio, self.s2_ratio]
def build_fc1(self, input_dim, output_dim, q_noise, qn_block_size):
return quant_noise(
nn.Linear(input_dim, output_dim), p=q_noise, block_size=qn_block_size
)
def build_fc2(self, input_dim, output_dim, q_noise, qn_block_size):
return quant_noise(
nn.Linear(input_dim, output_dim), p=q_noise, block_size=qn_block_size
)
def build_self_attention(self, embed_dim, args):
if self.attn_type == "selfattn":
attn_func = MultiheadAttention
elif self.attn_type == "rel_selfattn":
attn_func = RelPositionMultiheadAttention
elif self.attn_type == "relative":
# max_relative_length = getattr(args, "max_encoder_relative_length", -1)
max_relative_length = max(getattr(args, "max_encoder_relative_length", -1), getattr(args, "max_relative_length", -1))
if max_relative_length != -1:
return RelativeMultiheadAttention(
embed_dim,
args.encoder_attention_heads,
dropout=args.attention_dropout,
self_attention=True,
q_noise=self.quant_noise,
qn_block_size=self.quant_noise_block_size,
max_relative_length=max_relative_length,
)
else:
print("The maximum encoder relative length %d can not be -1!" % max_relative_length)
exit(1)
elif self.attn_type == "local":
hard_mask_window = getattr(args, "hard_mask_window", 0)
gauss_mask_sigma = getattr(args, "gauss_mask_sigma", 0)
init_mask_weight = getattr(args, "init_mask_weight", 0)
return LocalMultiheadAttention(
embed_dim,
args.encoder_attention_heads,
dropout=args.attention_dropout,
self_attention=True,
q_noise=self.quant_noise,
qn_block_size=self.quant_noise_block_size,
hard_mask_window=hard_mask_window,
gauss_mask_sigma=gauss_mask_sigma,
init_mask_weight=init_mask_weight
)
else:
print("The encoder attention type %s is not supported!" % self.attn_type)
exit(1)
return attn_func(
embed_dim,
args.encoder_attention_heads,
dropout=args.attention_dropout,
self_attention=True,
q_noise=self.quant_noise,
qn_block_size=self.quant_noise_block_size,
)
def residual_connection(self, x, residual):
return residual + x
def upgrade_state_dict_named(self, state_dict, name):
"""
Rename layer norm states from `...layer_norms.0.weight` to
`...self_attn_layer_norm.weight` and `...layer_norms.1.weight` to
`...final_layer_norm.weight`
"""
layer_norm_map = {"0": "self_attn_layer_norm", "1": "final_layer_norm"}
for old, new in layer_norm_map.items():
for m in ("weight", "bias"):
k = "{}.layer_norms.{}.{}".format(name, old, m)
if k in state_dict:
state_dict["{}.{}.{}".format(name, new, m)] = state_dict[k]
del state_dict[k]
def forward(self, x,
encoder_padding_mask: Optional[Tensor],
x2 = None,
x2_encoder_padding_mask = None,
attn_mask: Optional[Tensor] = None,
pos_emb: Optional[Tensor] = None):
"""
Args:
x (Tensor): input to the layer of shape `(seq_len, batch, embed_dim)`
encoder_padding_mask (ByteTensor): binary ByteTensor of shape
`(batch, seq_len)` where padding elements are indicated by ``1``.
attn_mask (ByteTensor): binary tensor of shape `(tgt_len, src_len)`,
where `tgt_len` is the length of output and `src_len` is the
length of input, though here both are equal to `seq_len`.
`attn_mask[tgt_i, src_j] = 1` means that when calculating the
embedding for `tgt_i`, we exclude (mask out) `src_j`. This is
useful for strided self-attention.
pos_emb (Tensor): the position embedding for relative position encoding
Returns:
encoded output of shape `(seq_len, batch, embed_dim)`
"""
# anything in original attn_mask = 1, becomes -1e8
# anything in original attn_mask = 0, becomes 0
# Note that we cannot use -inf here, because at some edge cases,
# the attention weight (before softmax) for some padded element in query
# will become -inf, which results in NaN in model parameters
if attn_mask is not None:
attn_mask = attn_mask.masked_fill(
attn_mask.to(torch.bool), -1e8 if x.dtype == torch.float32 else -1e4
)
residual = x
if self.normalize_before:
x = self.self_attn_layer_norm(x)
if self.attn_type == "rel_selfattn":
assert pos_emb is not None, "Position embeddings are required for RPE!"
x, _ = self.self_attn(
query=x,
key=x,
value=x,
key_padding_mask=encoder_padding_mask,
need_weights=False,
attn_mask=attn_mask,
pos_emb=pos_emb
)
else:
x, _ = self.self_attn(
query=x,
key=x,
value=x,
key_padding_mask=encoder_padding_mask,
need_weights=False,
attn_mask=attn_mask,
)
x = self.dropout_module(x)
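# Fuse the self-attention output with cross-attention over the second stream (if provided),
# weighted by get_ratio(), before the residual connection.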
if x2 is not None:
x2, _ = self.s2_attn(x, x2, x2, x2_encoder_padding_mask)
x2 = self.dropout_module(x2)
ratio = self.get_ratio()
x = x * ratio[0] + x2 * ratio[1]
x = self.residual_connection(x, residual)
if not self.normalize_before:
x = self.self_attn_layer_norm(x)
residual = x
if self.normalize_before:
x = self.final_layer_norm(x)
x = self.activation_fn(self.fc1(x))
x = self.activation_dropout_module(x)
x = self.fc2(x)
# use squeeze-and-excitation method
if self.use_se:
x = self.se_attn(x)
x = self.dropout_module(x)
x = self.residual_connection(x, residual)
if not self.normalize_before:
x = self.final_layer_norm(x)
return x
class TransformerS2DecoderLayer(nn.Module):
"""Decoder layer block.
In the original paper each operation (multi-head attention, encoder
attention or FFN) is postprocessed with: `dropout -> add residual ->
layernorm`. In the tensor2tensor code they suggest that learning is more
robust when preprocessing each layer with layernorm and postprocessing with:
`dropout -> add residual`. We default to the approach in the paper, but the
tensor2tensor approach can be enabled by setting
*args.decoder_normalize_before* to ``True``.
Args:
args (argparse.Namespace): parsed command-line arguments
no_encoder_attn (bool, optional): whether to attend to encoder outputs
(default: False).
"""
def __init__(
self, args, no_encoder_attn=False, add_bias_kv=False, add_zero_attn=False
):
super().__init__()
self.embed_dim = args.decoder_embed_dim
self.dropout_module = FairseqDropout(
args.dropout, module_name=self.__class__.__name__
)
self.quant_noise = getattr(args, "quant_noise_pq", 0)
self.quant_noise_block_size = getattr(args, "quant_noise_pq_block_size", 8)
self.cross_self_attention = getattr(args, "cross_self_attention", False)
self.attn_type = getattr(args, "decoder_attention_type", "selfattn")
self.self_attn = self.build_self_attention(
self.embed_dim,
args,
add_bias_kv=add_bias_kv,
add_zero_attn=add_zero_attn,
)
self.activation_fn = utils.get_activation_fn(
activation=str(args.activation_fn)
if getattr(args, "activation_fn", None) is not None
else "relu"
)
activation_dropout_p = getattr(args, "activation_dropout", 0) or 0
if activation_dropout_p == 0:
# for backwards compatibility with models that use args.relu_dropout
activation_dropout_p = getattr(args, "relu_dropout", 0) or 0
self.activation_dropout_module = FairseqDropout(
float(activation_dropout_p), module_name=self.__class__.__name__
)
self.normalize_before = args.decoder_normalize_before
# use layerNorm rather than FusedLayerNorm for exporting.
# char_inputs can be used to determine this.
# TODO remove this once we update apex with the fix
export = getattr(args, "char_inputs", False)
self.self_attn_layer_norm = LayerNorm(self.embed_dim, export=export)
if no_encoder_attn:
self.encoder_attn = None
self.encoder_attn_layer_norm = None
else:
self.encoder_attn = self.build_encoder_attention(self.embed_dim, args)
self.encoder_attn_layer_norm = LayerNorm(self.embed_dim, export=export)
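# Extra cross-attention over the second encoder stream (encoder_out_s2); key/value widths
# fall back to the layer width unless encoder_x2_dim is set.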
self.s2_attn = MultiheadAttention(
self.embed_dim,
args.decoder_attention_heads,
kdim=getattr(args, "encoder_x2_dim", self.embed_dim),
vdim=getattr(args, "encoder_x2_dim", self.embed_dim),
dropout=args.attention_dropout,
encoder_decoder_attention=True,
)
self.fc1 = self.build_fc1(
self.embed_dim,
args.decoder_ffn_embed_dim,
self.quant_noise,
self.quant_noise_block_size,
)
self.fc2 = self.build_fc2(
args.decoder_ffn_embed_dim,
self.embed_dim,
self.quant_noise,
self.quant_noise_block_size,
)
self.final_layer_norm = LayerNorm(self.embed_dim, export=export)
self.need_attn = True
self.onnx_trace = False
self.s1_ratio = args.encoder_s1_ratio
self.s2_ratio = args.encoder_s2_ratio
self.drop_net = args.encoder_drop_net
self.drop_net_prob = args.encoder_drop_net_prob
self.drop_net_mix = args.encoder_drop_net_mix
def get_ratio(self):
if self.drop_net:
frand = float(uniform(0, 1))
if self.drop_net_mix and self.training:
return [frand, 1 - frand]
if frand < self.drop_net_prob and self.training:
return [1, 0]
elif frand > 1 - self.drop_net_prob and self.training:
return [0, 1]
else:
return [0.5, 0.5]
else:
return [self.s1_ratio, self.s2_ratio]
def build_fc1(self, input_dim, output_dim, q_noise, qn_block_size):
return quant_noise(nn.Linear(input_dim, output_dim), q_noise, qn_block_size)
def build_fc2(self, input_dim, output_dim, q_noise, qn_block_size):
return quant_noise(nn.Linear(input_dim, output_dim), q_noise, qn_block_size)
def build_self_attention(
self, embed_dim, args, add_bias_kv=False, add_zero_attn=False
):
if self.attn_type == "selfattn":
attn_func = MultiheadAttention
elif self.attn_type == "rel_selfattn":
attn_func = RelPositionMultiheadAttention
elif self.attn_type == "relative":
max_relative_length = max(getattr(args, "max_decoder_relative_length", -1), getattr(args, "max_relative_length", -1))
if max_relative_length != -1:
return RelativeMultiheadAttention(
embed_dim,
args.decoder_attention_heads,
dropout=args.attention_dropout,
self_attention=True,
q_noise=self.quant_noise,
qn_block_size=self.quant_noise_block_size,
max_relative_length=max_relative_length,
)
else:
print("The maximum decoder relative length %d can not be -1!" % max_relative_length)
exit(1)
else:
print("The decoder attention type %s is not supported!" % self.attn_type)
exit(1)
return attn_func(
embed_dim,
args.decoder_attention_heads,
dropout=args.attention_dropout,
add_bias_kv=add_bias_kv,
add_zero_attn=add_zero_attn,
self_attention=not getattr(args, "cross_self_attention", False),
q_noise=self.quant_noise,
qn_block_size=self.quant_noise_block_size,
)
def build_encoder_attention(self, embed_dim, args):
return MultiheadAttention(
embed_dim,
args.decoder_attention_heads,
kdim=getattr(args, "encoder_embed_dim", None),
vdim=getattr(args, "encoder_embed_dim", None),
dropout=args.attention_dropout,
encoder_decoder_attention=True,
q_noise=self.quant_noise,
qn_block_size=self.quant_noise_block_size,
)
def prepare_for_onnx_export_(self):
self.onnx_trace = True
def residual_connection(self, x, residual):
return residual + x
def forward(
self,
x,
encoder_out: Optional[torch.Tensor] = None,
encoder_padding_mask: Optional[torch.Tensor] = None,
encoder_out_s2 = None,
encoder_padding_mask_s2 = None,
incremental_state: Optional[Dict[str, Dict[str, Optional[Tensor]]]] = None,
prev_self_attn_state: Optional[List[torch.Tensor]] = None,
prev_attn_state: Optional[List[torch.Tensor]] = None,
self_attn_mask: Optional[torch.Tensor] = None,
self_attn_padding_mask: Optional[torch.Tensor] = None,
need_attn: bool = False,
need_head_weights: bool = False,
pos_emb: Optional[Tensor] = None,
):
"""
Args:
x (Tensor): input to the layer of shape `(seq_len, batch, embed_dim)`
encoder_padding_mask (ByteTensor, optional): binary
ByteTensor of shape `(batch, src_len)` where padding
elements are indicated by ``1``.
need_attn (bool, optional): return attention weights
need_head_weights (bool, optional): return attention weights
for each head (default: return average over heads).
Returns:
encoded output of shape `(seq_len, batch, embed_dim)`
"""
if need_head_weights:
need_attn = True
residual = x
if self.normalize_before:
x = self.self_attn_layer_norm(x)
if prev_self_attn_state is not None:
prev_key, prev_value = prev_self_attn_state[:2]
saved_state: Dict[str, Optional[Tensor]] = {
"prev_key": prev_key,
"prev_value": prev_value,
}
if len(prev_self_attn_state) >= 3:
saved_state["prev_key_padding_mask"] = prev_self_attn_state[2]
assert incremental_state is not None
self.self_attn._set_input_buffer(incremental_state, saved_state)
_self_attn_input_buffer = self.self_attn._get_input_buffer(incremental_state)
if self.cross_self_attention and not (
incremental_state is not None
and _self_attn_input_buffer is not None
and "prev_key" in _self_attn_input_buffer
):
if self_attn_mask is not None:
assert encoder_out is not None
self_attn_mask = torch.cat(
(x.new_zeros(x.size(0), encoder_out.size(0)), self_attn_mask), dim=1
)
if self_attn_padding_mask is not None:
if encoder_padding_mask is None:
assert encoder_out is not None
encoder_padding_mask = self_attn_padding_mask.new_zeros(
encoder_out.size(1), encoder_out.size(0)
)
self_attn_padding_mask = torch.cat(
(encoder_padding_mask, self_attn_padding_mask), dim=1
)
assert encoder_out is not None
y = torch.cat((encoder_out, x), dim=0)
else:
y = x
if self.attn_type == "rel_selfattn":
assert pos_emb is not None, "Position embeddings are required for RPE!"
x, attn = self.self_attn(
query=x,
key=y,
value=y,
key_padding_mask=self_attn_padding_mask,
incremental_state=incremental_state,
need_weights=False,
attn_mask=self_attn_mask,
pos_emb=pos_emb
)
else:
x, attn = self.self_attn(
query=x,
key=y,
value=y,
key_padding_mask=self_attn_padding_mask,
incremental_state=incremental_state,
need_weights=False,
attn_mask=self_attn_mask,
)
x = self.dropout_module(x)
x = self.residual_connection(x, residual)
if not self.normalize_before:
x = self.self_attn_layer_norm(x)
if self.encoder_attn is not None and encoder_out is not None:
residual = x
if self.normalize_before:
x = self.encoder_attn_layer_norm(x)
if prev_attn_state is not None:
prev_key, prev_value = prev_attn_state[:2]
saved_state: Dict[str, Optional[Tensor]] = {
"prev_key": prev_key,
"prev_value": prev_value,
}
if len(prev_attn_state) >= 3:
saved_state["prev_key_padding_mask"] = prev_attn_state[2]
assert incremental_state is not None
self.encoder_attn._set_input_buffer(incremental_state, saved_state)
x, attn = self.encoder_attn(
query=x,
key=encoder_out,
value=encoder_out,
key_padding_mask=encoder_padding_mask,
incremental_state=incremental_state,
static_kv=True,
need_weights=need_attn or (not self.training and self.need_attn),
need_head_weights=need_head_weights,
)
x = self.dropout_module(x)
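# When a second encoder output is available, attend to it as well and combine the two
# cross-attention results with the weights from get_ratio().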
if encoder_out_s2 is not None:
x2, _ = self.s2_attn(
query=x,
key=encoder_out_s2,
value=encoder_out_s2,
key_padding_mask=encoder_padding_mask_s2,
incremental_state=incremental_state,
static_kv=True,
need_weights=need_attn or (not self.training and self.need_attn),
need_head_weights=need_head_weights,
)
x2 = self.dropout_module(x2)
ratios = self.get_ratio()
x = ratios[0] * x + ratios[1] * x2
x = self.residual_connection(x, residual)
if not self.normalize_before:
x = self.encoder_attn_layer_norm(x)
residual = x
if self.normalize_before:
x = self.final_layer_norm(x)
x = self.activation_fn(self.fc1(x))
x = self.activation_dropout_module(x)
x = self.fc2(x)
x = self.dropout_module(x)
x = self.residual_connection(x, residual)
if not self.normalize_before:
x = self.final_layer_norm(x)
if self.onnx_trace and incremental_state is not None:
saved_state = self.self_attn._get_input_buffer(incremental_state)
assert saved_state is not None
if self_attn_padding_mask is not None:
self_attn_state = [
saved_state["prev_key"],
saved_state["prev_value"],
saved_state["prev_key_padding_mask"],
]
else:
self_attn_state = [saved_state["prev_key"], saved_state["prev_value"]]
return x, attn, self_attn_state
return x, attn, None
def make_generation_fast_(self, need_attn: bool = False, **kwargs):
self.need_attn = need_attn