Commit 5fb50cc3 by xuchen

Implement the dual speech-to-text model (still needs optimization) and a CTC loss for MT

parent 6cbfe851
arch: pdss2t_transformer_s_8 arch: s2t_dual
#pds-ctc: 0_1_1_0 #pds-ctc: 0_1_1_0
#intermedia-adapter: league #intermedia-adapter: league
...@@ -9,11 +9,15 @@ arch: pdss2t_transformer_s_8 ...@@ -9,11 +9,15 @@ arch: pdss2t_transformer_s_8
#attention-reduced-method: pool #attention-reduced-method: pool
#attention-reduced-q: True #attention-reduced-q: True
inter-mixup: True #inter-mixup: True
inter-mixup-layer: 0 #inter-mixup-layer: 0
inter-mixup-beta: 0.5 #inter-mixup-beta: 0.5
encoder-embed-dim: 384 asr-encoder: sate
mt-encoder-layers: 3
mt-encoder-dim: 256
encoder-embed-dim: 256
pds-stages: 4 pds-stages: 4
#ctc-layer: 15 #ctc-layer: 15
encoder-layers: 6 encoder-layers: 6
...@@ -39,7 +43,7 @@ lr: 2e-3 ...@@ -39,7 +43,7 @@ lr: 2e-3
adam_betas: (0.9,0.98) adam_betas: (0.9,0.98)
ctc-weight: 0.3 ctc-weight: 0.3
criterion: label_smoothed_cross_entropy_with_ctc criterion: join_speech_and_text_loss
label_smoothing: 0.1 label_smoothing: 0.1
dropout: 0.1 dropout: 0.1
......
train-subset: train train-subset: train
valid-subset: valid valid-subset: dev
max-epoch: 50 max-epoch: 50
max-update: 100000 max-update: 100000
......
arch: transformer
share-all-embeddings: True
optimizer: adam
clip-norm: 10.0
lr-scheduler: inverse_sqrt
warmup-init-lr: 1e-7
warmup-updates: 8000
lr: 1e-3
adam_betas: (0.9,0.997)
criterion: label_smoothed_cross_entropy_with_ctc
ctc-weight: 0.3
intermedia-ctc-layers: 2,4
label_smoothing: 0.1
dropout: 0.1
attention-dropout: 0.1
activation-dropout: 0.1
activation-fn: relu
encoder-normalize-before: True
decoder-normalize-before: True
encoder-embed-dim: 512
encoder-ffn-embed-dim: 2048
encoder-layers: 6
decoder-layers: 6
encoder-attention-heads: 8
decoder-embed-dim: 512
decoder-ffn-embed-dim: 2048
decoder-attention-heads: 8
#load-pretrained-encoder-from:
#load-pretrained-decoder-from:
\ No newline at end of file
ctc-weight: 0.2 ctc-weight: 0.2
intermedia-ctc-layers: 6,9 intermedia-ctc-layers: 6,9
intermedia-adapter: league intermedia-adapter: league
intermedia-ctc-weight: 0.1 intermedia-ctc-weight: 0.1
#intermedia-drop-prob: 0.2 #intermedia-drop-prob: 0.2
#intermedia-temperature: 5
#target-ctc-weight: 0.5
#target-ctc-layers: 2,4,6
ctc-self-distill-weight: 0 ctc-self-distill-weight: 0
post-process: sentencepiece post-process: sentencepiece
\ No newline at end of file
...@@ -13,7 +13,7 @@ pds-stages: 4 ...@@ -13,7 +13,7 @@ pds-stages: 4
#ctc-layer: 12 #ctc-layer: 12
pds-layers: 3_3_3_3 pds-layers: 3_3_3_3
pds-ratios: 2_2_1_2 pds-ratios: 2_2_1_2
#pds-fusion: True pds-fusion: True
pds-fusion-method: all_conv pds-fusion-method: all_conv
pds-embed-dims: 256_256_256_256 pds-embed-dims: 256_256_256_256
pds-ds-method: conv pds-ds-method: conv
......
...@@ -29,7 +29,7 @@ acoustic-encoder: pds ...@@ -29,7 +29,7 @@ acoustic-encoder: pds
adapter: league adapter: league
encoder-embed-dim: 256 encoder-embed-dim: 256
ctc-layer: 12 #ctc-layer: 12
pds-stages: 4 pds-stages: 4
pds-layers: 3_3_3_3 pds-layers: 3_3_3_3
pds-ratios: 2_2_1_2 pds-ratios: 2_2_1_2
......
...@@ -23,6 +23,7 @@ from fairseq.logging.meters import safe_round ...@@ -23,6 +23,7 @@ from fairseq.logging.meters import safe_round
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@dataclass @dataclass
class CtcCriterionConfig(FairseqDataclass): class CtcCriterionConfig(FairseqDataclass):
zero_infinity: bool = field( zero_infinity: bool = field(
...@@ -48,6 +49,10 @@ class CtcCriterionConfig(FairseqDataclass): ...@@ -48,6 +49,10 @@ class CtcCriterionConfig(FairseqDataclass):
) )
target_ctc_weight: float = field( target_ctc_weight: float = field(
default=0.0, default=0.0,
metadata={"help": "weight of CTC loss for target sentence"},
)
target_intermedia_ctc_weight: float = field(
default=0.0,
metadata={"help": "weight of intermedia CTC loss for target sentence"}, metadata={"help": "weight of intermedia CTC loss for target sentence"},
) )
ctc_self_distill_weight: float = field( ctc_self_distill_weight: float = field(
...@@ -124,13 +129,15 @@ class CtcCriterion(FairseqCriterion): ...@@ -124,13 +129,15 @@ class CtcCriterion(FairseqCriterion):
self.ctc_weight = ctc_weight self.ctc_weight = ctc_weight
self.intermedia_ctc_weight = cfg.intermedia_ctc_weight self.intermedia_ctc_weight = cfg.intermedia_ctc_weight
self.target_ctc_weight = cfg.target_ctc_weight self.target_ctc_weight = cfg.target_ctc_weight
self.target_intermedia_ctc_weight = cfg.target_intermedia_ctc_weight
self.ctc_self_distill_weight = cfg.ctc_self_distill_weight self.ctc_self_distill_weight = cfg.ctc_self_distill_weight
self.ctc_entropy = cfg.ctc_entropy self.ctc_entropy = cfg.ctc_entropy
self.all_ctc_weight = self.ctc_weight + self.intermedia_ctc_weight + self.target_ctc_weight + \ self.all_ctc_weight = self.ctc_weight + self.intermedia_ctc_weight + \
self.target_ctc_weight + self.target_intermedia_ctc_weight + \
self.ctc_self_distill_weight + self.ctc_entropy self.ctc_self_distill_weight + self.ctc_entropy
if self.all_ctc_weight > 0: if self.all_ctc_weight > 0:
assert getattr(task, "src_dict", None) is not None, "CTC need a source dictionary." # assert getattr(task, "src_dict", None) is not None, "CTC need a source dictionary."
self.ctc_loss = torch.nn.CTCLoss(blank=self.blank_idx, reduction="sum", zero_infinity=True) self.ctc_loss = torch.nn.CTCLoss(blank=self.blank_idx, reduction="sum", zero_infinity=True)
def forward(self, model, sample, reduce=True): def forward(self, model, sample, reduce=True):
...@@ -159,7 +166,10 @@ class CtcCriterion(FairseqCriterion): ...@@ -159,7 +166,10 @@ class CtcCriterion(FairseqCriterion):
return ctc_loss return ctc_loss
def compute_ctc_loss(self, model, sample, net_output, logging_output): def compute_ctc_loss(self, model, sample, net_output, logging_output):
transcript = sample["transcript"] if "transcript" in sample:
tokens = sample["transcript"]["tokens"]
else:
tokens = sample["target"]
# if "ctc_padding_mask" in net_output: # if "ctc_padding_mask" in net_output:
# non_padding_mask = ~net_output["ctc_padding_mask"][0] # non_padding_mask = ~net_output["ctc_padding_mask"][0]
# else: # else:
...@@ -175,21 +185,21 @@ class CtcCriterion(FairseqCriterion): ...@@ -175,21 +185,21 @@ class CtcCriterion(FairseqCriterion):
non_padding_mask = ~net_output["encoder_padding_mask"][0] non_padding_mask = ~net_output["encoder_padding_mask"][0]
input_lengths = non_padding_mask.long().sum(-1) input_lengths = non_padding_mask.long().sum(-1)
pad_mask = (transcript["tokens"] != self.pad_idx) & ( pad_mask = (tokens != self.pad_idx) & (
transcript["tokens"] != self.eos_idx tokens != self.eos_idx
) )
if mixup: if mixup:
mask1 = pad_mask[mixup_idx1] mask1 = pad_mask[mixup_idx1]
mask2 = pad_mask[mixup_idx2] mask2 = pad_mask[mixup_idx2]
transcript_flat1 = transcript["tokens"][[mixup_idx1]].masked_select(mask1) transcript_flat1 = tokens[[mixup_idx1]].masked_select(mask1)
transcript_flat2 = transcript["tokens"][mixup_idx2].masked_select(mask2) transcript_flat2 = tokens[mixup_idx2].masked_select(mask2)
transcript_lengths1 = mask1.sum(-1) transcript_lengths1 = mask1.sum(-1)
transcript_lengths2 = mask2.sum(-1) transcript_lengths2 = mask2.sum(-1)
transcript_flat = [transcript_flat1, transcript_flat2] transcript_flat = [transcript_flat1, transcript_flat2]
transcript_lengths = [transcript_lengths1, transcript_lengths2] transcript_lengths = [transcript_lengths1, transcript_lengths2]
loss_coef = [mixup_coef, 1 - mixup_coef] loss_coef = [mixup_coef, 1 - mixup_coef]
else: else:
transcript_flat = [transcript["tokens"].masked_select(pad_mask)] transcript_flat = [tokens.masked_select(pad_mask)]
transcript_lengths = [pad_mask.sum(-1)] transcript_lengths = [pad_mask.sum(-1)]
loss_coef = [1] loss_coef = [1]
...@@ -247,13 +257,11 @@ class CtcCriterion(FairseqCriterion): ...@@ -247,13 +257,11 @@ class CtcCriterion(FairseqCriterion):
if lprobs is None: if lprobs is None:
lprobs = inter_lprobs lprobs = inter_lprobs
target_ctc_num = 0
target_ctc_loss = 0 target_ctc_loss = 0
if "target_ctc_logits" in net_output: target_intermedia_ctc_loss = 0
target_ctc_num = len(net_output["target_ctc_logits"])
# calculate the target CTC loss # calculate the target CTC loss
if self.target_ctc_weight > 0 and target_ctc_num > 0: if self.target_ctc_weight > 0 or self.target_intermedia_ctc_weight:
target = sample["target"] target = sample["target"]
pad_mask = (target != self.pad_idx) & (target != self.eos_idx) pad_mask = (target != self.pad_idx) & (target != self.eos_idx)
...@@ -272,8 +280,24 @@ class CtcCriterion(FairseqCriterion): ...@@ -272,8 +280,24 @@ class CtcCriterion(FairseqCriterion):
target_length = [pad_mask.sum(-1)] target_length = [pad_mask.sum(-1)]
loss_coef = [1] loss_coef = [1]
for i in range(target_ctc_num): if self.target_ctc_weight > 0:
out = net_output["target_ctc_logits"][i] assert "target_ctc_logit" in net_output
target_ctc_logit = net_output["target_ctc_logit"]
tgt_lprobs = model.get_normalized_probs(
[target_ctc_logit], log_probs=True
).contiguous() # (T, B, C) from the encoder
tgt_lprobs.batch_first = False
for flat, lengths, coef in zip(target_flat, target_length, loss_coef):
target_ctc_loss += self.get_loss(tgt_lprobs, flat, input_lengths, lengths) * coef
target_intermedia_ctc_num = 0
if "target_intermedia_ctc_logits" in net_output:
target_intermedia_ctc_num = len(net_output["target_intermedia_ctc_logits"])
for i in range(target_intermedia_ctc_num):
out = net_output["target_intermedia_ctc_logits"][i]
if type(out) == list: if type(out) == list:
inter_ctc_logit = out[0] inter_ctc_logit = out[0]
padding = ~out[1] padding = ~out[1]
...@@ -288,10 +312,10 @@ class CtcCriterion(FairseqCriterion): ...@@ -288,10 +312,10 @@ class CtcCriterion(FairseqCriterion):
tgt_inter_lprobs.batch_first = False tgt_inter_lprobs.batch_first = False
for flat, lengths, coef in zip(target_flat, target_length, loss_coef): for flat, lengths, coef in zip(target_flat, target_length, loss_coef):
target_ctc_loss += self.get_loss(tgt_inter_lprobs, flat, tgt_input_lengths, lengths) * coef target_intermedia_ctc_loss += self.get_loss(tgt_inter_lprobs, flat, tgt_input_lengths, lengths) * coef
target_ctc_loss /= target_ctc_num target_intermedia_ctc_loss /= target_intermedia_ctc_num
logging_output["target_ctc_loss"] = utils.item(target_ctc_loss.data) logging_output["target_intermedia_ctc_loss"] = utils.item(target_intermedia_ctc_loss.data)
# calculate the self distillation CTC loss # calculate the self distillation CTC loss
ctc_self_distill_loss = 0 ctc_self_distill_loss = 0
...@@ -344,7 +368,7 @@ class CtcCriterion(FairseqCriterion): ...@@ -344,7 +368,7 @@ class CtcCriterion(FairseqCriterion):
with torch.no_grad(): with torch.no_grad():
lprobs_t = lprobs.transpose(0, 1).float().contiguous().cpu() lprobs_t = lprobs.transpose(0, 1).float().contiguous().cpu()
target = sample["transcript"]["tokens"] if "transcript" in sample else sample["target"] target = tokens
if mixup: if mixup:
idx = mixup_idx1 idx = mixup_idx1
if mixup_coef < 0.5: if mixup_coef < 0.5:
......
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import math
import torch
import torch.nn.functional as F
from fairseq import metrics, utils
from fairseq.criterions import register_criterion
from fairseq.data.data_utils import post_process
from fairseq.logging.meters import safe_round
from .label_smoothed_cross_entropy import LabelSmoothedCrossEntropyCriterion
from .ctc import CtcCriterion, CtcCriterionConfig
@register_criterion("join_speech_and_text_loss")
class JoinSpeechTextLoss(
    LabelSmoothedCrossEntropyCriterion
):
    """Joint criterion for dual speech-and-text training.

    Combines the label-smoothed cross-entropy translation loss with the
    CTC losses computed by :class:`CtcCriterion` over the encoder output
    (source CTC, intermediate CTC, and target CTC, as configured).
    """

    def __init__(self, task, label_smoothing,
                 sentence_avg,
                 cfg: CtcCriterionConfig,
                 ctc_weight=0.0):
        super().__init__(task, sentence_avg, label_smoothing)
        self.report_accuracy = True
        self.ctc_weight = ctc_weight
        # Delegate all CTC bookkeeping (weights, blank index, loss fn) to
        # the existing CTC criterion instead of duplicating it here.
        self.ctc_criterion = CtcCriterion(cfg, task, ctc_weight)

    @staticmethod
    def add_args(parser):
        """Add criterion-specific arguments to the parser."""
        LabelSmoothedCrossEntropyCriterion.add_args(parser)
        CtcCriterion.add_args(parser)
        parser.add_argument(
            "--ctc-weight",
            default=0.0,
            type=float,
            metavar="D",
            help="weight of CTC loss",
        )

    def forward(self, model, sample, reduce=True):
        """Compute the loss for the given sample.

        Returns a tuple with three elements:
        1) the loss
        2) the sample size, which is used as the denominator for the gradient
        3) logging outputs to display while training
        """
        net_input = sample["net_input"]
        # Access the collater outputs by key rather than relying on dict
        # insertion order (the original ``.values()`` unpacking silently
        # breaks if the dataset ever adds or reorders net_input keys).
        # Keys follow the fairseq speech_to_text collater convention.
        speech_tokens = net_input["src_tokens"]
        speech_lengths = net_input["src_lengths"]
        prev_output_tokens = net_input["prev_output_tokens"]
        text_src_tokens = sample["transcript"]["tokens"]
        text_src_lengths = sample["transcript"]["lengths"]

        # Dual encoder: consumes both the speech features and the
        # transcript tokens; decoder attends over the encoder output.
        encoder_out = model.encoder(speech_tokens, speech_lengths,
                                    text_src_tokens, text_src_lengths)
        net_output = model.decoder(
            prev_output_tokens=prev_output_tokens, encoder_out=encoder_out
        )

        use_mixup = False
        if "mixup" in encoder_out and encoder_out["mixup"] is not None:
            use_mixup = True

        loss, nll_loss = self.compute_loss(model, net_output, sample, reduce=reduce)
        sample_size = (
            sample["target"].size(0) if self.sentence_avg else sample["ntokens"]
        )

        n_tokens = sample["ntokens"]
        n_sentences = sample["target"].size(0)
        if use_mixup:
            # Mixup doubles the batch internally; halve the counters so the
            # reported averages refer to the real number of examples.
            sample_size //= 2
            n_tokens //= 2
            n_sentences //= 2

        logging_output = {
            "trans_loss": utils.item(loss.data) if reduce else loss.data,
            "nll_loss": utils.item(nll_loss.data) if reduce else nll_loss.data,
            "ntokens": n_tokens,
            "nsentences": n_sentences,
            "sample_size": sample_size,
        }

        if self.report_accuracy:
            n_correct, total = self.compute_accuracy(model, net_output, sample)
            logging_output["n_correct"] = utils.item(n_correct.data)
            logging_output["total"] = utils.item(total.data)

        if self.ctc_criterion.all_ctc_weight > 0:
            ctc_loss, logging_output = self.ctc_criterion.compute_ctc_loss(model, sample, encoder_out, logging_output)
            # Interpolate translation loss with the (already weighted) CTC loss.
            loss = (1 - self.ctc_weight) * loss + ctc_loss

        logging_output["loss"] = utils.item(loss.data) if reduce else loss.data

        return loss, sample_size, logging_output

    @staticmethod
    def reduce_metrics(logging_outputs) -> None:
        """Aggregate logging outputs from data parallel training."""
        loss_sum = utils.item(sum(log.get("loss", 0) for log in logging_outputs))
        trans_loss_sum = utils.item(
            sum(log.get("trans_loss", 0) for log in logging_outputs)
        )
        nll_loss_sum = utils.item(
            sum(log.get("nll_loss", 0) for log in logging_outputs)
        )
        ntokens = utils.item(sum(log.get("ntokens", 0) for log in logging_outputs))
        sample_size = utils.item(
            sum(log.get("sample_size", 0) for log in logging_outputs)
        )

        metrics.log_scalar(
            "loss", loss_sum / sample_size / math.log(2), sample_size, round=3
        )
        # Only report the translation component separately when CTC (or
        # another auxiliary term) made the total differ from it.
        if trans_loss_sum != loss_sum:
            metrics.log_scalar(
                "trans_loss", trans_loss_sum / ntokens / math.log(2), ntokens, round=3
            )
        metrics.log_scalar(
            "nll_loss", nll_loss_sum / ntokens / math.log(2), ntokens, round=3
        )
        metrics.log_derived(
            "ppl", lambda meters: utils.get_perplexity(meters["nll_loss"].avg)
        )

        # Guard against an empty list before peeking at logging_outputs[0]
        # (the original indexed it unconditionally).
        if logging_outputs and (
            "ctc_loss" in logging_outputs[0] or "all_ctc_loss" in logging_outputs[0]
        ):
            CtcCriterion.reduce_metrics(logging_outputs)

        total = utils.item(sum(log.get("total", 0) for log in logging_outputs))
        if total > 0:
            metrics.log_scalar("total", total)
            n_correct = utils.item(
                sum(log.get("n_correct", 0) for log in logging_outputs)
            )
            metrics.log_scalar("n_correct", n_correct)
            metrics.log_derived(
                "accuracy",
                lambda meters: round(
                    meters["n_correct"].sum * 100.0 / meters["total"].sum, 3
                )
                if meters["total"].sum > 0
                else float("nan"),
            )

    @staticmethod
    def logging_outputs_can_be_summed() -> bool:
        """
        Whether the logging outputs returned by `forward` can be summed
        across workers prior to calling `reduce_metrics`. Setting this
        to True will improves distributed training speed.
        """
        return True
...@@ -9,3 +9,4 @@ from .s2t_ctc import * ...@@ -9,3 +9,4 @@ from .s2t_ctc import *
from .s2t_transformer import * # noqa from .s2t_transformer import * # noqa
from .pdss2t_transformer import * # noqa from .pdss2t_transformer import * # noqa
from .s2t_sate import * # noqa from .s2t_sate import * # noqa
from .s2t_dual import * # noqa
...@@ -6,7 +6,6 @@ import torch.nn as nn ...@@ -6,7 +6,6 @@ import torch.nn as nn
import torch.nn.functional as F import torch.nn.functional as F
from fairseq.data.data_utils import lengths_to_padding_mask from fairseq.data.data_utils import lengths_to_padding_mask
from fairseq.models.transformer import Embedding
from fairseq.modules import LayerNorm from fairseq.modules import LayerNorm
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
...@@ -75,7 +74,7 @@ class Adapter(nn.Module): ...@@ -75,7 +74,7 @@ class Adapter(nn.Module):
if self.adapter_type in ["context", "league", "gated_league", "gated_league2", "inter_league"]: if self.adapter_type in ["context", "league", "gated_league", "gated_league2", "inter_league"]:
if embed_tokens is None: if embed_tokens is None:
num_embeddings = len(dictionary) num_embeddings = len(dictionary)
self.embed_adapter = Embedding(num_embeddings, dim, dictionary.pad()) self.embed_adapter = nn.Linear(num_embeddings, dim) # Embedding(num_embeddings, dim, dictionary.pad())
else: else:
self.embed_adapter = embed_tokens self.embed_adapter = embed_tokens
......
...@@ -88,12 +88,6 @@ class S2TSATEModel(S2TTransformerModel): ...@@ -88,12 +88,6 @@ class S2TSATEModel(S2TTransformerModel):
help="the architecture of the acoustic encoder", help="the architecture of the acoustic encoder",
) )
parser.add_argument( parser.add_argument(
"--target-ctc-layers",
default=None,
type=str,
help="ctc layers for target sentence",
)
parser.add_argument(
"--load-pretrained-acoustic-encoder-from", "--load-pretrained-acoustic-encoder-from",
type=str, type=str,
metavar="STR", metavar="STR",
...@@ -105,11 +99,24 @@ class S2TSATEModel(S2TTransformerModel): ...@@ -105,11 +99,24 @@ class S2TSATEModel(S2TTransformerModel):
metavar="STR", metavar="STR",
help="model to take text encoder weights from (for initialization)", help="model to take text encoder weights from (for initialization)",
) )
# target CTC
parser.add_argument(
"--target-ctc-layer",
default=None,
type=str,
help="ctc layer for target sentence",
)
parser.add_argument(
"--target-intermedia-ctc-layer",
default=None,
type=str,
help="intermedia ctc layers for target sentence",
)
pass pass
@classmethod @classmethod
def build_encoder(cls, args, task=None, embed_tokens=None): def build_encoder(cls, args, task=None, decoder_embed_tokens=None):
encoder = S2TSATEEncoder(args, task, embed_tokens) encoder = S2TSATEEncoder(args, task, decoder_embed_tokens)
if getattr(args, "load_pretrained_encoder_from", None): if getattr(args, "load_pretrained_encoder_from", None):
logger.info( logger.info(
...@@ -174,11 +181,25 @@ class TextEncoder(FairseqEncoder): ...@@ -174,11 +181,25 @@ class TextEncoder(FairseqEncoder):
else: else:
self.layer_norm = None self.layer_norm = None
# CTC
self.use_ctc = getattr(args, "target_ctc_weight", 0) > 0
if self.use_ctc:
self.ctc_layer = args.target_ctc_layer
self.inter_ctc = True if self.ctc_layer != args.encoder_layers else False
if self.inter_ctc:
logger.info("Target CTC loss in layer %d" % self.ctc_layer)
self.ctc = CTC(embed_dim,
dictionary_size=embed_tokens.num_embeddings,
dropout=args.dropout,
need_layernorm=True if self.inter_ctc else False)
self.ctc.ctc_projection.weight = embed_tokens.weight
self.intermedia_ctc_layers = [] self.intermedia_ctc_layers = []
self.target_ctc_layers = getattr(args, "target_ctc_layers", None) self.target_intermedia_ctc_layers = getattr(args, "target_intermedia_ctc_layers", None)
if self.target_ctc_layers is not None: if self.target_intermedia_ctc_layers is not None:
intermedia_ctc_layers = self.target_ctc_layers.split(",") target_intermedia_ctc_layers = self.target_intermedia_ctc_layers.split(",")
for layer_idx in intermedia_ctc_layers: for layer_idx in target_intermedia_ctc_layers:
layer_idx = int(layer_idx) layer_idx = int(layer_idx)
assert layer_idx <= layer_num, (layer_idx, layer_num) assert layer_idx <= layer_num, (layer_idx, layer_num)
...@@ -186,7 +207,7 @@ class TextEncoder(FairseqEncoder): ...@@ -186,7 +207,7 @@ class TextEncoder(FairseqEncoder):
layer_idx += layer_num layer_idx += layer_num
self.intermedia_ctc_layers.append(layer_idx) self.intermedia_ctc_layers.append(layer_idx)
logger.info("Intermedia CTC loss in layer %d" % layer_idx) logger.info("Intermedia target CTC loss in layer %d" % layer_idx)
self.ctc = CTC(embed_dim, self.ctc = CTC(embed_dim,
dictionary_size=len(dictionary), dictionary_size=len(dictionary),
...@@ -201,7 +222,8 @@ class TextEncoder(FairseqEncoder): ...@@ -201,7 +222,8 @@ class TextEncoder(FairseqEncoder):
elif args.intermedia_adapter == "league": elif args.intermedia_adapter == "league":
strategy = getattr(args, "intermedia_distribution_cutoff", None) strategy = getattr(args, "intermedia_distribution_cutoff", None)
self.adapter = Adapter(embed_dim, args.intermedia_adapter, self.adapter = Adapter(embed_dim, args.intermedia_adapter,
dictionary, strategy=strategy) dictionary, embed_tokens=embed_tokens,
strategy=strategy)
self.intermedia_drop_prob = getattr(args, "intermedia_drop_prob", 0) self.intermedia_drop_prob = getattr(args, "intermedia_drop_prob", 0)
self.intermedia_temperature = getattr(args, "intermedia_temperature", 1) self.intermedia_temperature = getattr(args, "intermedia_temperature", 1)
...@@ -212,14 +234,17 @@ class TextEncoder(FairseqEncoder): ...@@ -212,14 +234,17 @@ class TextEncoder(FairseqEncoder):
x = positions + x x = positions + x
x = self.dropout_module(x) x = self.dropout_module(x)
target_ctc_logits = [] target_ctc_logit = None
target_intermedia_ctc_logits = []
layer_idx = 0 layer_idx = 0
for layer in self.layers: for layer in self.layers:
layer_idx += 1
if history is not None: if history is not None:
x = history.pop() x = history.pop()
x = layer(x, encoder_padding_mask, pos_emb=positions) x = layer(x, encoder_padding_mask, pos_emb=positions)
layer_idx += 1
if self.use_ctc and self.inter_ctc and self.ctc_layer == layer_idx:
target_ctc_logit = self.ctc(x.clone())
if layer_idx != self.layer_num and layer_idx in self.intermedia_ctc_layers: if layer_idx != self.layer_num and layer_idx in self.intermedia_ctc_layers:
if self.intermedia_drop_prob > 0: if self.intermedia_drop_prob > 0:
...@@ -229,7 +254,7 @@ class TextEncoder(FairseqEncoder): ...@@ -229,7 +254,7 @@ class TextEncoder(FairseqEncoder):
norm_x = self.layer_norm(x) norm_x = self.layer_norm(x)
logit = self.ctc(norm_x) logit = self.ctc(norm_x)
target_ctc_logits.append(logit) target_intermedia_ctc_logits.append(logit)
prob = utils.softmax(logit / self.intermedia_temperature, dim=-1) prob = utils.softmax(logit / self.intermedia_temperature, dim=-1)
x, encoder_padding_mask = self.adapter([x, prob], encoder_padding_mask) x, encoder_padding_mask = self.adapter([x, prob], encoder_padding_mask)
...@@ -243,18 +268,17 @@ class TextEncoder(FairseqEncoder): ...@@ -243,18 +268,17 @@ class TextEncoder(FairseqEncoder):
if self.layer_norm is not None: if self.layer_norm is not None:
x = self.layer_norm(x) x = self.layer_norm(x)
if layer_idx in self.intermedia_ctc_layers: if self.use_ctc and target_ctc_logit is None:
logit = self.ctc(x) target_ctc_logit = self.ctc(x)
target_ctc_logits.append(logit)
return x, target_ctc_logits return x, target_ctc_logit, target_intermedia_ctc_logits
class S2TSATEEncoder(FairseqEncoder): class S2TSATEEncoder(FairseqEncoder):
"""Speech-to-text Conformer encoder that consists of input subsampler and """Speech-to-text Conformer encoder that consists of input subsampler and
Transformer encoder.""" Transformer encoder."""
def __init__(self, args, task=None, embed_tokens=None): def __init__(self, args, task=None, decoder_embed_tokens=None):
super().__init__(None) super().__init__(None)
# acoustic encoder # acoustic encoder
...@@ -278,7 +302,7 @@ class S2TSATEEncoder(FairseqEncoder): ...@@ -278,7 +302,7 @@ class S2TSATEEncoder(FairseqEncoder):
self.adapter = Adapter(args.encoder_embed_dim, self.adapter = Adapter(args.encoder_embed_dim,
args.adapter, args.adapter,
task.source_dictionary, task.source_dictionary,
embed_tokens if task.source_dictionary == task.target_dictionary else None, decoder_embed_tokens if task.source_dictionary == task.target_dictionary else None,
strategy=strategy) strategy=strategy)
if args.share_ctc_and_adapter and hasattr(self.adapter, "embed_adapter"): if args.share_ctc_and_adapter and hasattr(self.adapter, "embed_adapter"):
...@@ -288,7 +312,7 @@ class S2TSATEEncoder(FairseqEncoder): ...@@ -288,7 +312,7 @@ class S2TSATEEncoder(FairseqEncoder):
args.encoder_attention_type = args.text_attention_type args.encoder_attention_type = args.text_attention_type
# text encoder # text encoder
self.text_encoder = TextEncoder(args, task.source_dictionary, embed_tokens) self.text_encoder = TextEncoder(args, task.source_dictionary, decoder_embed_tokens)
args.encoder_attention_type = acoustic_encoder_attention_type args.encoder_attention_type = acoustic_encoder_attention_type
...@@ -383,9 +407,15 @@ class S2TSATEEncoder(FairseqEncoder): ...@@ -383,9 +407,15 @@ class S2TSATEEncoder(FairseqEncoder):
@register_model_architecture(model_name="s2t_sate", arch_name="s2t_sate") @register_model_architecture(model_name="s2t_sate", arch_name="s2t_sate")
def base_architecture(args): def base_architecture(args):
# Convolutional subsampler # Convolutional subsampler
args.conv_kernel_sizes = getattr(args, "conv_kernel_sizes", "5,5") args.subsampling_type = getattr(args, "subsampling_type", "conv1d")
args.conv_channels = getattr(args, "conv_channels", 1024) args.subsampling_layers = getattr(args, "subsampling_layers", 2)
args.subsampling_filter = getattr(args, "subsampling_filter", 1024)
args.subsampling_kernel = getattr(args, "subsampling_kernel", 5)
args.subsampling_stride = getattr(args, "subsampling_stride", 2)
args.subsampling_norm = getattr(args, "subsampling_norm", "none")
args.subsampling_activation = getattr(args, "subsampling_activation", "glu")
# transformer
args.encoder_embed_dim = getattr(args, "encoder_embed_dim", 512) args.encoder_embed_dim = getattr(args, "encoder_embed_dim", 512)
args.encoder_ffn_embed_dim = getattr(args, "encoder_ffn_embed_dim", 2048) args.encoder_ffn_embed_dim = getattr(args, "encoder_ffn_embed_dim", 2048)
args.encoder_layers = getattr(args, "encoder_layers", 12) args.encoder_layers = getattr(args, "encoder_layers", 12)
...@@ -415,6 +445,7 @@ def base_architecture(args): ...@@ -415,6 +445,7 @@ def base_architecture(args):
args, "no_token_positional_embeddings", False args, "no_token_positional_embeddings", False
) )
args.adaptive_input = getattr(args, "adaptive_input", False) args.adaptive_input = getattr(args, "adaptive_input", False)
args.encoder_layerdrop = getattr(args, "encoder_layerdrop", 0.0)
args.decoder_layerdrop = getattr(args, "decoder_layerdrop", 0.0) args.decoder_layerdrop = getattr(args, "decoder_layerdrop", 0.0)
args.decoder_output_dim = getattr( args.decoder_output_dim = getattr(
args, "decoder_output_dim", args.decoder_embed_dim args, "decoder_output_dim", args.decoder_embed_dim
......
...@@ -245,6 +245,12 @@ class S2TTransformerModel(FairseqEncoderDecoderModel): ...@@ -245,6 +245,12 @@ class S2TTransformerModel(FairseqEncoderDecoderModel):
metavar="STR", metavar="STR",
help="freeze the module of the decoder", help="freeze the module of the decoder",
) )
# args for "Reducing Transformer Depth on Demand with Structured Dropout" (Fan et al., 2019)
parser.add_argument('--encoder-layerdrop', type=float, metavar='D', default=0,
help='LayerDrop probability for encoder')
parser.add_argument('--decoder-layerdrop', type=float, metavar='D', default=0,
help='LayerDrop probability for decoder')
# DLCL
parser.add_argument( parser.add_argument(
"--use-enc-dlcl", "--use-enc-dlcl",
default=False, default=False,
...@@ -540,9 +546,10 @@ class S2TTransformerEncoder(FairseqEncoder): ...@@ -540,9 +546,10 @@ class S2TTransformerEncoder(FairseqEncoder):
else: else:
self.history = None self.history = None
self.use_ctc = "sate" in args.arch or \ # self.use_ctc = "sate" in args.arch or \
(getattr(args, "criterion", "") == "ctc") or \ # (getattr(args, "criterion", "") == "ctc") or \
(("ctc" in getattr(args, "criterion", "")) and (getattr(args, "ctc_weight", 0) > 0)) # (("ctc" in getattr(args, "criterion", "")) and (getattr(args, "ctc_weight", 0) > 0))
self.use_ctc = "sate" in args.arch or getattr(args, "ctc_weight", 0) > 0
if self.use_ctc: if self.use_ctc:
self.ctc_layer = args.ctc_layer self.ctc_layer = args.ctc_layer
self.inter_ctc = True if self.ctc_layer != 0 and self.ctc_layer != args.encoder_layers else False self.inter_ctc = True if self.ctc_layer != 0 and self.ctc_layer != args.encoder_layers else False
...@@ -679,6 +686,9 @@ class S2TTransformerEncoder(FairseqEncoder): ...@@ -679,6 +686,9 @@ class S2TTransformerEncoder(FairseqEncoder):
# encoder layer # encoder layer
x = layer(x, encoder_padding_mask, pos_emb=positions) x = layer(x, encoder_padding_mask, pos_emb=positions)
if self.use_ctc and self.inter_ctc and self.ctc_layer == layer_idx:
ctc_logit = self.ctc(x.clone())
# interleave CTC # interleave CTC
if layer_idx in self.intermedia_ctc_layers: if layer_idx in self.intermedia_ctc_layers:
if self.intermedia_drop_prob > 0: if self.intermedia_drop_prob > 0:
...@@ -690,7 +700,6 @@ class S2TTransformerEncoder(FairseqEncoder): ...@@ -690,7 +700,6 @@ class S2TTransformerEncoder(FairseqEncoder):
logit = self.ctc(norm_x) logit = self.ctc(norm_x)
intermedia_ctc_logits.append(logit) intermedia_ctc_logits.append(logit)
# prob = self.ctc.softmax(norm_x)
prob = utils.softmax(logit / self.intermedia_temperature, dim=-1) prob = utils.softmax(logit / self.intermedia_temperature, dim=-1)
x, encoder_padding_mask = self.adapter([x, prob], encoder_padding_mask) x, encoder_padding_mask = self.adapter([x, prob], encoder_padding_mask)
...@@ -713,7 +722,7 @@ class S2TTransformerEncoder(FairseqEncoder): ...@@ -713,7 +722,7 @@ class S2TTransformerEncoder(FairseqEncoder):
return { return {
"encoder_out": [x], # T x B x C "encoder_out": [x], # T x B x C
"ctc_logit": [] if ctc_logit is None else [ctc_logit], # B x T x C "ctc_logit": [] if ctc_logit is None else [ctc_logit], # T x B x C
"intermedia_ctc_logits": intermedia_ctc_logits, # B x T x C "intermedia_ctc_logits": intermedia_ctc_logits, # B x T x C
"encoder_padding_mask": [encoder_padding_mask], # B x T "encoder_padding_mask": [encoder_padding_mask], # B x T
"encoder_embedding": [], # B x T x C "encoder_embedding": [], # B x T x C
...@@ -849,6 +858,7 @@ def base_architecture(args): ...@@ -849,6 +858,7 @@ def base_architecture(args):
args, "no_token_positional_embeddings", False args, "no_token_positional_embeddings", False
) )
args.adaptive_input = getattr(args, "adaptive_input", False) args.adaptive_input = getattr(args, "adaptive_input", False)
args.encoder_layerdrop = getattr(args, "encoder_layerdrop", 0.0)
args.decoder_layerdrop = getattr(args, "decoder_layerdrop", 0.0) args.decoder_layerdrop = getattr(args, "decoder_layerdrop", 0.0)
args.decoder_output_dim = getattr( args.decoder_output_dim = getattr(
args, "decoder_output_dim", args.decoder_embed_dim args, "decoder_output_dim", args.decoder_embed_dim
......
...@@ -18,6 +18,7 @@ from fairseq.models import ( ...@@ -18,6 +18,7 @@ from fairseq.models import (
register_model, register_model,
register_model_architecture, register_model_architecture,
) )
from fairseq.models.speech_to_text.modules import Adapter, CTC
from fairseq.modules import ( from fairseq.modules import (
AdaptiveSoftmax, AdaptiveSoftmax,
FairseqDropout, FairseqDropout,
...@@ -294,6 +295,42 @@ class TransformerModel(FairseqEncoderDecoderModel): ...@@ -294,6 +295,42 @@ class TransformerModel(FairseqEncoderDecoderModel):
action='store_true', action='store_true',
help="use squeeze and excitation method", help="use squeeze and excitation method",
) )
# CTC
parser.add_argument(
"--ctc-layer",
type=int,
help="ctc layers for target sentence",
)
parser.add_argument(
"--intermedia-ctc-layers",
default=None,
type=str,
help="the position of the ctc loss, separated by comma ",
)
parser.add_argument(
"--intermedia-adapter",
default="none",
type=str,
help="type of intermedia adapter",
)
parser.add_argument(
"--intermedia-distribution-cutoff",
default=None,
type=int,
help="cutoff of the distribution",
)
parser.add_argument(
"--intermedia-drop-prob",
default=0,
type=float,
help="probability of dropping the followed layers",
)
parser.add_argument(
"--intermedia-temperature",
default=1,
type=float,
help="temperature of the intermedia ctc probability",
)
# fmt: on # fmt: on
@classmethod @classmethod
...@@ -342,7 +379,7 @@ class TransformerModel(FairseqEncoderDecoderModel): ...@@ -342,7 +379,7 @@ class TransformerModel(FairseqEncoderDecoderModel):
) )
if getattr(args, "offload_activations", False): if getattr(args, "offload_activations", False):
args.checkpoint_activations = True # offloading implies checkpointing args.checkpoint_activations = True # offloading implies checkpointing
encoder = cls.build_encoder(args, src_dict, encoder_embed_tokens) encoder = cls.build_encoder(args, src_dict, encoder_embed_tokens, decoder_embed_tokens)
if getattr(args, "encoder_freeze_module", None): if getattr(args, "encoder_freeze_module", None):
utils.freeze_parameters(encoder, args.encoder_freeze_module) utils.freeze_parameters(encoder, args.encoder_freeze_module)
logging.info("freeze the encoder module: {}".format(args.encoder_freeze_module)) logging.info("freeze the encoder module: {}".format(args.encoder_freeze_module))
...@@ -370,8 +407,8 @@ class TransformerModel(FairseqEncoderDecoderModel): ...@@ -370,8 +407,8 @@ class TransformerModel(FairseqEncoderDecoderModel):
return emb return emb
@classmethod @classmethod
def build_encoder(cls, args, src_dict, embed_tokens): def build_encoder(cls, args, src_dict, embed_tokens, decoder_embed_tokens=None):
encoder = TransformerEncoder(args, src_dict, embed_tokens) encoder = TransformerEncoder(args, src_dict, embed_tokens, decoder_embed_tokens)
if getattr(args, "load_pretrained_encoder_from", None): if getattr(args, "load_pretrained_encoder_from", None):
logger.info( logger.info(
f"loaded pretrained encoder from: " f"loaded pretrained encoder from: "
...@@ -460,7 +497,7 @@ class TransformerEncoder(FairseqEncoder): ...@@ -460,7 +497,7 @@ class TransformerEncoder(FairseqEncoder):
embed_tokens (torch.nn.Embedding): input embedding embed_tokens (torch.nn.Embedding): input embedding
""" """
def __init__(self, args, dictionary, embed_tokens): def __init__(self, args, dictionary, embed_tokens, decoder_embed_tokens=None):
self.args = args self.args = args
super().__init__(dictionary) super().__init__(dictionary)
self.register_buffer("version", torch.Tensor([3])) self.register_buffer("version", torch.Tensor([3]))
...@@ -534,6 +571,48 @@ class TransformerEncoder(FairseqEncoder): ...@@ -534,6 +571,48 @@ class TransformerEncoder(FairseqEncoder):
else: else:
self.history = None self.history = None
# CTC
self.use_ctc = getattr(args, "ctc_weight", 0) > 0
if self.use_ctc:
self.ctc_layer = args.ctc_layer
self.inter_ctc = True if self.ctc_layer != args.encoder_layers else False
if self.inter_ctc:
logger.info("Intermedia CTC loss in layer %d" % self.ctc_layer)
self.ctc = CTC(embed_dim,
dictionary_size=decoder_embed_tokens.num_embeddings,
dropout=args.dropout,
need_layernorm=True if self.inter_ctc else False)
self.ctc.ctc_projection.weight = embed_tokens.weight
self.intermedia_ctc_layers = []
if args.intermedia_ctc_layers is not None:
intermedia_ctc_layers = args.intermedia_ctc_layers.split(",")
for layer_idx in intermedia_ctc_layers:
layer_idx = int(layer_idx)
if layer_idx <= 0:
layer_idx += args.encoder_layers
self.intermedia_ctc_layers.append(layer_idx)
logger.info("Intermedia CTC loss in layer %d" % layer_idx)
if not self.use_ctc:
self.ctc = CTC(embed_dim,
dictionary_size=decoder_embed_tokens.num_embeddings,
dropout=args.dropout)
self.ctc.ctc_projection.weight = embed_tokens.weight
strategy = None
if args.intermedia_adapter == "shrink":
strategy = getattr(args, "ctc_compress_strategy", None)
elif args.intermedia_adapter == "league":
strategy = getattr(args, "intermedia_distribution_cutoff", None)
self.adapter = Adapter(embed_dim, args.intermedia_adapter,
None, embed_tokens=decoder_embed_tokens, strategy=strategy)
self.intermedia_drop_prob = getattr(args, "intermedia_drop_prob", 0)
self.intermedia_temperature = getattr(args, "intermedia_temperature", 1)
def build_encoder_layer(self, args): def build_encoder_layer(self, args):
layer = TransformerEncoderLayer(args) layer = TransformerEncoderLayer(args)
if getattr(args, "checkpoint_activations", False): if getattr(args, "checkpoint_activations", False):
...@@ -653,6 +732,9 @@ class TransformerEncoder(FairseqEncoder): ...@@ -653,6 +732,9 @@ class TransformerEncoder(FairseqEncoder):
self.history.push(x) self.history.push(x)
# encoder layers # encoder layers
layer_idx = 0
ctc_logit = None
intermedia_ctc_logits = []
for layer in self.layers: for layer in self.layers:
if self.history is not None: if self.history is not None:
x = self.history.pop() x = self.history.pop()
...@@ -660,10 +742,29 @@ class TransformerEncoder(FairseqEncoder): ...@@ -660,10 +742,29 @@ class TransformerEncoder(FairseqEncoder):
x = layer( x = layer(
x, encoder_padding_mask=encoder_padding_mask if has_pads else None x, encoder_padding_mask=encoder_padding_mask if has_pads else None
) )
layer_idx += 1
if return_all_hiddens: if return_all_hiddens:
assert encoder_states is not None assert encoder_states is not None
encoder_states.append(x) encoder_states.append(x)
# CTC
if self.use_ctc and self.inter_ctc and self.ctc_layer == layer_idx:
ctc_logit = self.ctc(x.clone())
# Intermedia CTC
if layer_idx in self.intermedia_ctc_layers:
if self.intermedia_drop_prob > 0:
p = torch.rand(1).uniform_()
if p < self.intermedia_drop_prob:
break
norm_x = self.layer_norm(x)
logit = self.ctc(norm_x)
intermedia_ctc_logits.append(logit)
prob = utils.softmax(logit / self.intermedia_temperature, dim=-1)
x, encoder_padding_mask = self.adapter([x, prob], encoder_padding_mask)
if self.history is not None: if self.history is not None:
self.history.push(x) self.history.push(x)
...@@ -673,12 +774,16 @@ class TransformerEncoder(FairseqEncoder): ...@@ -673,12 +774,16 @@ class TransformerEncoder(FairseqEncoder):
if self.layer_norm is not None: if self.layer_norm is not None:
x = self.layer_norm(x) x = self.layer_norm(x)
if self.use_ctc and ctc_logit is None:
ctc_logit = self.ctc(x)
# The Pytorch Mobile lite interpreter does not supports returning NamedTuple in # The Pytorch Mobile lite interpreter does not supports returning NamedTuple in
# `forward` so we use a dictionary instead. # `forward` so we use a dictionary instead.
# TorchScript does not support mixed values so the values are all lists. # TorchScript does not support mixed values so the values are all lists.
# The empty list is equivalent to None. # The empty list is equivalent to None.
return { return {
"encoder_out": [x], # T x B x C "encoder_out": [x], # T x B x C
"ctc_logit": [] if ctc_logit is None else [ctc_logit], # T x B x C
"encoder_padding_mask": [encoder_padding_mask], # B x T "encoder_padding_mask": [encoder_padding_mask], # B x T
"encoder_embedding": [encoder_embedding], # B x T x C "encoder_embedding": [encoder_embedding], # B x T x C
"encoder_states": encoder_states, # List[T x B x C] "encoder_states": encoder_states, # List[T x B x C]
...@@ -1331,6 +1436,12 @@ def base_architecture(args): ...@@ -1331,6 +1436,12 @@ def base_architecture(args):
args.max_decoder_relative_length = getattr(args, 'max_decoder_relative_length', -1) args.max_decoder_relative_length = getattr(args, 'max_decoder_relative_length', -1)
args.k_only = getattr(args, 'k_only', True) args.k_only = getattr(args, 'k_only', True)
# CTC
args.ctc_layer = getattr(args, "ctc_layer", args.encoder_layers)
args.intermedia_ctc_layers = getattr(args, "intermedia_ctc_layers", None)
args.intermedia_adapter = getattr(args, "intermedia_adapter", None)
args.intermedia_drop_prob = getattr(args, "intermedia_drop_prob", 0)
@register_model_architecture("transformer", "transformer_relative") @register_model_architecture("transformer", "transformer_relative")
def transformer_rpr(args): def transformer_rpr(args):
......
...@@ -44,6 +44,7 @@ from .transformer_sentence_encoder import TransformerSentenceEncoder ...@@ -44,6 +44,7 @@ from .transformer_sentence_encoder import TransformerSentenceEncoder
from .transpose_last import TransposeLast from .transpose_last import TransposeLast
from .unfold import unfold1d from .unfold import unfold1d
from .transformer_layer import TransformerDecoderLayer, TransformerEncoderLayer from .transformer_layer import TransformerDecoderLayer, TransformerEncoderLayer
from .transformer_s2_layer import TransformerS2EncoderLayer, TransformerS2DecoderLayer
from .vggblock import VGGBlock from .vggblock import VGGBlock
from .rotary_positional_embedding import RotaryPositionalEmbedding from .rotary_positional_embedding import RotaryPositionalEmbedding
from .positional_encoding import ( from .positional_encoding import (
...@@ -109,6 +110,8 @@ __all__ = [ ...@@ -109,6 +110,8 @@ __all__ = [
"TransformerSentenceEncoder", "TransformerSentenceEncoder",
"TransformerDecoderLayer", "TransformerDecoderLayer",
"TransformerEncoderLayer", "TransformerEncoderLayer",
"TransformerS2DecoderLayer",
"TransformerS2EncoderLayer",
"TransposeLast", "TransposeLast",
"VGGBlock", "VGGBlock",
"unfold1d", "unfold1d",
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论