Commit f286e56c by xuchen

add the intermedia ctc

parent 31e7c426
# NiuTrans-Fairseq-S2T
This project adapts the [fairseq](https://github.com/pytorch/fairseq) toolkit for speech-to-text tasks, including speech recognition and speech translation.
It contains the implementations of the following methods proposed by the NiuTrans Team:
[Stacked Acoustic-and-Textual Encoding: Integrating the Pre-trained Models into Speech Translation Encoders](https://arxiv.org/abs/2105.05752)
## Key Features
### Training
- Support the Kaldi-style complete recipes
- ASR, MT, and ST pipelines (bin)
- Read training configuration from yaml files
- CTC multi-task learning
- MT training in the ST-like way (online tokenizer; this may be slow)
- Speed perturbation during pre-processing
### Model
- Conformer architecture
- Load pre-trained modules
- Relative position representation
- Stacked acoustic-and-textual encoding
- Progressive down-sampling for acoustic encoding
## Installation
@@ -43,9 +45,10 @@
```
make -j src.build CUDA_HOME=<path to cuda install>
pip install pandas sentencepiece configargparse gpustat tensorboard editdistance
```
## Code Structure
We supply the recipes for multiple benchmarks in the egs folder, including machine translation, speech recognition, and speech translation corpora.
We also provide a template for other benchmarks.
@@ -53,41 +56,40 @@
Here is an example for MuST-C:
```
mustc
├── asr
│   ├── binary.sh
│   ├── conf/
│   ├── decode.sh
│   ├── local/
│   ├── run.sh
│   └── train.sh
├── mt
│   ├── binary.sh
│   ├── conf/
│   ├── decode.sh
│   ├── local/
│   ├── run.sh
│   └── train.sh
└── st
    ├── binary.sh
    ├── conf/
    ├── decode.sh
    ├── ensemble.sh
    ├── local/
    ├── run.sh
    └── train.sh
```
* run.sh: the core script that includes the whole pipeline
* train.sh: calls run.sh for training
* decode.sh: calls run.sh for decoding
* binary.sh: generates the binarized datasets only
* conf/: the folder that stores the configuration files (.yaml)
* local/: the folder that stores the utility scripts
  * monitor.sh: checks for available GPUs and runs the program automatically
  * parse_options.sh: parses the command-line parameters for run.sh
  * path.sh: currently unused
  * utils.sh: utility shell functions
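In a typical Kaldi-style workflow, train.sh and decode.sh are launched from inside the corresponding recipe directory (e.g. egs/mustc/st), and both read their model and training settings from the yaml files under conf/.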
## Citations
```bibtex
@inproceedings{xu-etal-2021-stacked,
title = "Stacked Acoustic-and-Textual Encoding: Integrating the Pre-trained Models into Speech Translation Encoders",
author = "Xu, Chen and
......
@@ -17,3 +17,4 @@ log-interval: 100
seed: 1
report-accuracy: True
skip-invalid-size-inputs-valid-test: True
ctc-weight: 0.3
post-process: sentencepiece
#arch: pdss2t_transformer_s_8
arch: s2t_transformer_s
#pds-ctc: 1_1_1_1
#pds-ctc: 0_0_0_1
intermedia-ctc-layers: 6,8,10
intermedia-adapter: league
intermedia-ctc-weight: 0.2
ctc-self-distill-weight: 1
ctc-weight: 0.1
share-decoder-input-output-embed: True
optimizer: adam
clip-norm: 10.0
@@ -9,7 +17,6 @@ warmup-updates: 10000
lr: 2e-3
#adam_betas: (0.9,0.98)
criterion: label_smoothed_cross_entropy_with_ctc
label_smoothing: 0.1
......
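For reference, the options added by this commit behave as follows: intermedia-ctc-layers is a comma-separated list of encoder layers after which an auxiliary CTC module is applied; intermedia-adapter selects how the resulting CTC distribution is fed back into the encoder states (league combines a linear branch with the soft embedding of the CTC distribution); intermedia-ctc-weight scales the auxiliary CTC losses; and ctc-self-distill-weight adds a KL term that distills the top CTC prediction into the intermediate ones (see the criterion changes below).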
@@ -17,3 +17,4 @@ log-interval: 100
seed: 1
report-accuracy: True
skip-invalid-size-inputs-valid-test: True
ctc-weight: 0.3
post-process: sentencepiece
\ No newline at end of file
intermedia-ctc-layers: 6,9
intermedia-adapter: league
intermedia-ctc-weight: 0.15
ctc-self-distill-weight: 1
\ No newline at end of file
intermedia-ctc-layers: 6,9
intermedia-adapter: league
intermedia-ctc-weight: 0.15
ctc-self-distill-weight: 1
\ No newline at end of file
arch: pdss2t_transformer_s_8
pds-ctc: 1_1_1_1
intermedia-adapter: league
intermedia-ctc-weight: 0.15
encoder-embed-dim: 256
pds-stages: 4
ctc-layer: 12
......
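In the pdss2t configuration above, pds-ctc is an underscore-separated per-stage flag (1_1_1_1 attaches a CTC module after each of the pds-stages stages), while ctc-layer selects which encoder layer feeds the standard CTC loss.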
@@ -20,7 +20,7 @@ class LabelSmoothedCrossEntropyCriterionWithCTC(
LabelSmoothedCrossEntropyCriterion
):
def __init__(self, task, sentence_avg, label_smoothing, post_process="letter",
ctc_weight=0.0, intermedia_ctc_weight=0.0, ctc_self_distill_weight=0.0):
super().__init__(task, sentence_avg, label_smoothing)
self.blank_idx = task.target_dictionary.index(task.blank_symbol) if hasattr(task, 'blank_symbol') else 0
self.pad_idx = task.target_dictionary.pad()
@@ -29,9 +29,11 @@ class LabelSmoothedCrossEntropyCriterionWithCTC(
self.report_accuracy = True
assert 0 <= ctc_weight
self.top_ctc_weight = ctc_weight
self.intermedia_ctc_weight = intermedia_ctc_weight
self.ctc_self_distill_weight = ctc_self_distill_weight
self.ctc_weight = ctc_weight + intermedia_ctc_weight + ctc_self_distill_weight
if self.ctc_weight > 0:
assert getattr(task, "src_dict", None) is not None, "CTC need a source dictionary."
self.post_process = post_process
self.ctc_loss = torch.nn.CTCLoss(blank=self.blank_idx, reduction="sum", zero_infinity=True)
@@ -58,7 +60,19 @@ class LabelSmoothedCrossEntropyCriterionWithCTC(
default=0.0,
type=float,
metavar="D",
help="weight of intermedia CT loss",
help="weight of intermedia CTC loss",
)
parser.add_argument(
"--ctc-self-distill",
action="store_true",
help="use self distillation for intermedia CTC loss",
)
parser.add_argument(
"--ctc-self-distill-weight",
default=0.0,
type=float,
metavar="D",
help="weight of the self distillation CTC loss",
)
parser.add_argument(
"--post-process",
@@ -100,9 +114,9 @@ class LabelSmoothedCrossEntropyCriterionWithCTC(
logging_output["n_correct"] = utils.item(n_correct.data)
logging_output["total"] = utils.item(total.data)
if self.ctc_weight > 0:
ctc_loss, logging_output = self.compute_ctc_loss(model, sample, encoder_out, logging_output)
loss = (1 - self.top_ctc_weight - self.intermedia_ctc_weight) * loss + ctc_loss
logging_output["loss"] = utils.item(loss.data) if reduce else loss.data
return loss, sample_size, logging_output
@@ -122,7 +136,8 @@ class LabelSmoothedCrossEntropyCriterionWithCTC(
transcript_lengths = pad_mask.sum(-1)
ctc_loss = 0
if "ctc_logit" in encoder_out and len(encoder_out["ctc_logit"]) > 0:
lprobs = None
if self.top_ctc_weight > 0 and "ctc_logit" in encoder_out and len(encoder_out["ctc_logit"]) > 0:
ctc_logit = encoder_out["ctc_logit"][0]
lprobs = model.get_normalized_probs(
[ctc_logit], log_probs=True
@@ -140,15 +155,22 @@ class LabelSmoothedCrossEntropyCriterionWithCTC(
intermedia_ctc_num = 0
intermedia_ctc_loss = 0
if "intermedia_ctc_logit" in encoder_out:
intermedia_ctc_num = len(encoder_out["intermedia_ctc_logit"])
if "intermedia_ctc_logits" in encoder_out:
intermedia_ctc_num = len(encoder_out["intermedia_ctc_logits"])
if intermedia_ctc_num > 0:
# calculate the intermedia CTC loss
if self.intermedia_ctc_weight > 0 and intermedia_ctc_num > 0:
for i in range(intermedia_ctc_num):
out = encoder_out["intermedia_ctc_logits"][i]
if type(out) == list:
inter_ctc_logit = out[0]
padding = ~out[1]
input_lengths = padding.long().sum(-1)
else:
inter_ctc_logit = out
inter_lprobs = model.get_normalized_probs(
[inter_ctc_logit], log_probs=True
).contiguous() # (T, B, C) from the encoder
inter_lprobs.batch_first = False
@@ -162,7 +184,46 @@ class LabelSmoothedCrossEntropyCriterionWithCTC(
intermedia_ctc_loss += loss
intermedia_ctc_loss /= intermedia_ctc_num
logging_output["intermedia_ctc_loss"] = utils.item(intermedia_ctc_loss.data)
if lprobs is None:
lprobs = inter_lprobs
# calculate the self distillation CTC loss
ctc_self_distill_loss = 0
ctc_self_distill_num = 0
if self.top_ctc_weight > 0 and self.ctc_self_distill_weight > 0 and intermedia_ctc_num > 0:
for i in range(intermedia_ctc_num):
out = encoder_out["intermedia_ctc_logits"][i]
if type(out) == list:
inter_ctc_logit = out[0]
padding = ~out[1]
input_lengths = padding.long().sum(-1)
else:
inter_ctc_logit = out
if inter_ctc_logit.size() != ctc_logit.size():
continue
ctc_self_distill_num += 1
loss = F.kl_div(
F.log_softmax(inter_ctc_logit, dim=-1),
F.softmax(ctc_logit, dim=-1),
reduction="none",
)
loss = loss.sum(-1).transpose(0, 1).masked_fill_(~non_padding_mask, 0.0)
loss = loss.sum()
ctc_self_distill_loss += loss
ctc_self_distill_loss /= ctc_self_distill_num
logging_output["ctc_self_distill_loss"] = utils.item(ctc_self_distill_loss.data)
loss = \
self.top_ctc_weight * ctc_loss + \
self.intermedia_ctc_weight * intermedia_ctc_loss + \
self.ctc_self_distill_weight * ctc_self_distill_loss
if self.intermedia_ctc_weight > 0 or self.ctc_self_distill_weight > 0:
logging_output["all_ctc_loss"] = utils.item(loss.data)
if not model.training and self.ctc_weight > 0:
import editdistance
@@ -236,6 +297,13 @@ class LabelSmoothedCrossEntropyCriterionWithCTC(
inter_ctc_loss_sum = utils.item(
sum(log.get("intermedia_ctc_loss", 0) for log in logging_outputs)
)
ctc_self_distill_loss_sum = utils.item(
sum(log.get("ctc_self_distill_loss", 0) for log in logging_outputs)
)
all_ctc_loss_sum = utils.item(
sum(log.get("all_ctc_loss", 0) for log in logging_outputs)
)
ntokens = utils.item(sum(log.get("ntokens", 0) for log in logging_outputs))
sample_size = utils.item(
sum(log.get("sample_size", 0) for log in logging_outputs)
@@ -265,6 +333,21 @@ class LabelSmoothedCrossEntropyCriterionWithCTC(
sample_size,
round=3,
)
if ctc_self_distill_loss_sum > 0:
metrics.log_scalar(
"ctc_self_distill_loss",
ctc_self_distill_loss_sum / sample_size / math.log(2),
sample_size,
round=3,
)
if all_ctc_loss_sum > 0:
metrics.log_scalar(
"all_ctc_loss",
all_ctc_loss_sum / sample_size / math.log(2),
sample_size,
round=3,
)
metrics.log_derived(
"ppl", lambda meters: utils.get_perplexity(meters["nll_loss"].avg)
)
......
@@ -4,10 +4,8 @@
# LICENSE file in the root directory of this source tree.
from .berard import * # noqa
from .convtransformer import * # noqa
from .s2t_transformer import * # noqa
from .s2t_conformer import * # noqa
from .pdss2t_transformer import * # noqa
from .s2t_sate import * # noqa
from .adapter import *
from .ctc import *
import logging
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
from fairseq.data.data_utils import lengths_to_padding_mask
from fairseq.models.transformer import Embedding
from fairseq.modules import LayerNorm
logger = logging.getLogger(__name__)
class CTCCompressStrategy:
@staticmethod
def avg(prob_ctc, predicted, new_lengths, dtype, device):
new_maxlen = max(new_lengths)
weights_matrix = torch.zeros((prob_ctc.shape[0], prob_ctc.shape[1], new_maxlen), dtype=dtype)
for b_idx, pred in enumerate(predicted):
processed_inputs_cnt = 0
for t_idx, same in enumerate(pred):
new_processed_inputs_cnt = processed_inputs_cnt + same[1]
weights_matrix[b_idx, processed_inputs_cnt:new_processed_inputs_cnt, t_idx] = 1.0 / same[1]
processed_inputs_cnt = new_processed_inputs_cnt
return weights_matrix.to(device)
@staticmethod
def weighted(prob_ctc, predicted, new_lengths, dtype, device):
new_maxlen = max(new_lengths)
weights_matrix = torch.zeros((prob_ctc.shape[0], prob_ctc.shape[1], new_maxlen), dtype=dtype, device=device)
for b_idx, pred in enumerate(predicted):
processed_inputs_cnt = 0
for t_idx, same in enumerate(pred):
new_processed_inputs_cnt = processed_inputs_cnt + same[1]
# Get the probabilities of the prediction for the different time steps as weight
weights = prob_ctc[b_idx, processed_inputs_cnt:new_processed_inputs_cnt, same[0]]
weights_matrix[b_idx, processed_inputs_cnt:new_processed_inputs_cnt, t_idx] = \
weights / weights.sum()
processed_inputs_cnt = new_processed_inputs_cnt
return weights_matrix
@staticmethod
def softmax(prob_ctc, predicted, new_lengths, dtype, device):
new_maxlen = max(new_lengths)
weights_matrix = torch.zeros((prob_ctc.shape[0], prob_ctc.shape[1], new_maxlen), dtype=dtype, device=device)
for b_idx, pred in enumerate(predicted):
processed_inputs_cnt = 0
for t_idx, same in enumerate(pred):
new_processed_inputs_cnt = processed_inputs_cnt + same[1]
# Get the probabilities of the prediction for the different time steps as weight
weights = F.softmax(prob_ctc[b_idx, processed_inputs_cnt:new_processed_inputs_cnt, same[0]], dim=-1)
weights_matrix[b_idx, processed_inputs_cnt:new_processed_inputs_cnt, t_idx] = \
weights / weights.sum()
processed_inputs_cnt = new_processed_inputs_cnt
return weights_matrix
class InterAdapter(nn.Module):
def __init__(self, dim, adapter_type, dictionary, embed_tokens=None, strategy=None):
super().__init__()
self.adapter_type = adapter_type
if self.adapter_type in ["linear", "league", "gated_league", "gated_league2"]:
self.linear_adapter = nn.Sequential(
nn.Linear(dim, dim),
LayerNorm(dim),
nn.ReLU(),
)
if self.adapter_type in ["context", "league", "gated_league", "gated_league2", "inter_league"]:
if embed_tokens is None:
num_embeddings = len(dictionary)
self.embed_adapter = Embedding(num_embeddings, dim, dictionary.pad())
else:
self.embed_adapter = embed_tokens
if self.adapter_type == "gated_league":
self.gate_linear = nn.Linear(2 * dim, dim)
elif self.adapter_type == "gated_league2":
self.gate_linear1 = nn.Linear(dim, dim)
self.gate_linear2 = nn.Linear(dim, dim)
if self.adapter_type == "shrink":
self.ctc_compress = getattr(CTCCompressStrategy, strategy)
def forward(self, x, padding):
representation, distribution = x
dim1, dim2, dim = representation.size()
org_distribution = distribution
if distribution is not None:
distribution = distribution.view(-1, distribution.size(-1))
lengths = (~padding).long().sum(-1)
if self.adapter_type == "linear":
out = self.linear_adapter(representation)
elif self.adapter_type == "context":
out = torch.mm(distribution, self.embed_adapter.weight).view(dim1, dim2, -1)
elif self.adapter_type == "league":
linear_out = self.linear_adapter(representation)
soft_out = torch.mm(distribution, self.embed_adapter.weight).view(dim1, dim2, -1)
out = linear_out + soft_out
elif self.adapter_type == "gated_league":
linear_out = self.linear_adapter(representation)
soft_out = torch.mm(distribution, self.embed_adapter.weight).view(dim1, dim2, -1)
coef = (self.gate_linear(torch.cat([linear_out, soft_out], dim=-1))).sigmoid()
out = coef * linear_out + (1 - coef) * soft_out
elif self.adapter_type == "inter_league":
soft_out = torch.mm(distribution, self.embed_adapter.weight).view(dim1, dim2, -1)
out = representation + soft_out
elif self.adapter_type == "none":
out = representation
elif self.adapter_type == "shrink":
from itertools import groupby
with torch.no_grad():
batch_predicted = []
prob_ctc = org_distribution.transpose(0, 1) # T x B x D -> B x T x D
for b in range(prob_ctc.shape[0]):
predicted = prob_ctc[b][: lengths[b]].argmax(-1).tolist()
batch_predicted.append([(p[0], len(list(p[1]))) for p in groupby(predicted)])
new_lengths = [len(p) for p in batch_predicted]
weights_matrix = self.ctc_compress(prob_ctc, batch_predicted, new_lengths,
representation.dtype, representation.device)
# x is T x B x C -> B x C x T; weights_matrix is B x T x T'
compressed_output = representation.permute(1, 2, 0).bmm(weights_matrix) # B x C x T'
out = compressed_output.permute(2, 0, 1)
out_lengths = lengths.new(new_lengths)
padding = lengths_to_padding_mask(out_lengths)
else:
out = None
logging.error("Unsupported adapter type: {}.".format(self.adapter_type))
return out, padding
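As a minimal, self-contained sketch of what the shrink adapter does (toy values; the real code operates on batched tensors inside forward above), the snippet below collapses a CTC argmax path with groupby and then averages the frames of each segment, mirroring CTCCompressStrategy.avg:
```python
import torch
from itertools import groupby

# Hypothetical frame-level CTC argmax path for one utterance.
predicted = [3, 3, 3, 0, 0, 7, 7, 5]

# Collapse consecutive repeats into (token, run_length) segments,
# as the "shrink" adapter does before building the weight matrix.
segments = [(tok, len(list(grp))) for tok, grp in groupby(predicted)]
print(segments)            # [(3, 3), (0, 2), (7, 2), (5, 1)]

T, C, T_new = len(predicted), 8, len(segments)
representation = torch.randn(T, 1, C)    # T x B x C with B = 1

# Build the "avg" weight matrix (B x T x T'): each compressed position
# averages the frames of its segment.
weights = torch.zeros(1, T, T_new)
start = 0
for t_idx, (_, run) in enumerate(segments):
    weights[0, start:start + run, t_idx] = 1.0 / run
    start += run

# T x B x C -> B x C x T, compress to B x C x T', then back to T' x B x C.
compressed = representation.permute(1, 2, 0).bmm(weights).permute(2, 0, 1)
print(compressed.shape)    # torch.Size([4, 1, 8])
```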
#!/usr/bin/env python3
import logging
import torch
import torch.nn as nn
import torch.nn.functional as F
from fairseq.modules import (
FairseqDropout,
LayerNorm,
@@ -27,7 +26,6 @@ class CTC(nn.Module):
self.ctc_dropout_module = FairseqDropout(
p=dropout, module_name=self.__class__.__name__
)
self.need_layernorm = need_layernorm
if self.need_layernorm:
self.LayerNorm = LayerNorm(embed_dim)
@@ -40,54 +38,11 @@
return x
def softmax(self, x, temperature=1.0):
return F.softmax(self.ctc_projection(x) / temperature, dim=-1)
def log_softmax(self, x, temperature=1.0):
return F.log_softmax(self.ctc_projection(x) / temperature, dim=-1)
def argmax(self, x):
return torch.argmax(self.ctc_projection(x), dim=-1)
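For orientation, the encoders in this diff construct and query the CTC helper roughly as in the sketch below (made-up dimensions; assumes this repository is installed so the new module path is importable):
```python
import torch
from fairseq.models.speech_to_text.modules import CTC  # new module path in this commit

# Hypothetical sizes: 256-dim encoder states, a 10k-entry source dictionary.
ctc = CTC(256, dictionary_size=10000, dropout=0.1, need_layernorm=True)

x = torch.randn(50, 4, 256)        # T x B x C encoder states
logit = ctc(x)                     # T x B x V unnormalized CTC logits
log_probs = ctc.log_softmax(x)     # T x B x V log-probabilities for torch.nn.CTCLoss
best_path = ctc.argmax(x)          # T x B greedy CTC path
```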
#!/usr/bin/env python3
import logging
import math
from functools import reduce
import torch
import torch.nn as nn
import torch.nn.functional as F
from fairseq import checkpoint_utils, utils
from fairseq.models import (
FairseqEncoder,
register_model,
register_model_architecture,
)
from fairseq.models.speech_to_text import S2TTransformerModel
from fairseq.models.speech_to_text.modules import CTC, InterAdapter
from fairseq.modules import (
FairseqDropout,
LayerNorm,
PositionalEmbedding,
PDSTransformerEncoderLayer,
MultiheadAttention,
DownSampleConvolutionModule
)
@@ -459,7 +459,6 @@ class PDSS2TTransformerModel(S2TTransformerModel):
type=str,
help="the ratio of the ffn in each stage",
)
parser.add_argument(
"--pds-fusion",
action="store_true",
@@ -475,6 +474,23 @@ class PDSS2TTransformerModel(S2TTransformerModel):
type=float,
help="dropout in each stage",
)
parser.add_argument(
"--pds-ctc",
type=str,
help="use the ctc after each stage",
)
parser.add_argument(
"--intermedia-ctc-layers",
default=None,
type=str,
help="the position of the ctc loss, separated by comma ",
)
parser.add_argument(
"--intermedia-adapter",
default="none",
type=str,
help="type of intermedia adapter",
)
pass
@classmethod
@@ -550,6 +566,11 @@ class PDSS2TTransformerEncoder(FairseqEncoder):
fusion_stages_num = 0
self.fusion_stages_num = fusion_stages_num
args.pds_ctc = getattr(args, "pds_ctc", "0_0_0_0")
self.pds_ctc = [int(n) for n in args.pds_ctc.split("_")]
inter_ctc_module = None
inter_adapter = None
for i in range(self.pds_stages):
num_layers = self.pds_layers[i]
ds_ratio = self.pds_ratios[i]
@@ -557,6 +578,7 @@ class PDSS2TTransformerEncoder(FairseqEncoder):
embed_dim = self.pds_embed_dims[i]
kernel_size = self.pds_kernel_sizes[i]
use_pos_embed = self.pds_position_embed[i]
use_ctc = self.pds_ctc[i]
num_head = self.pds_attn_heads[i]
attn_ds_ratio = self.pds_attn_ds_ratios[i] if self.attn_type == "reduced" else -1
@@ -577,7 +599,7 @@ class PDSS2TTransformerEncoder(FairseqEncoder):
downsampling = Downsampling(
self.pds_ds_method,
self.pds_embed_norm,
args.input_feat_per_channel * args.input_channels if i == 0 else self.pds_embed_dims[i - 1],
embed_dim,
kernel_sizes=kernel_size,
stride=ds_ratio,
@@ -633,6 +655,36 @@ class PDSS2TTransformerEncoder(FairseqEncoder):
else:
logger.error("Unsupported fusion transform!")
# intermedia modules for each stage
if use_ctc:
if inter_ctc_module is None:
ctc = CTC(embed_dim,
dictionary_size=len(task.source_dictionary),
dropout=args.dropout,
need_layernorm=True)
if task.source_dictionary == task.target_dictionary and embed_tokens is not None:
ctc.ctc_projection.weight = embed_tokens.weight
inter_ctc_module = ctc
else:
ctc = inter_ctc_module
if i != self.pds_stages - 1:
if inter_adapter is None:
strategy = None
if args.intermedia_adapter == "shrink":
strategy = getattr(args, "ctc_compress_strategy", "avg")
adapter = InterAdapter(embed_dim, args.intermedia_adapter,
task.source_dictionary, strategy=strategy)
inter_adapter = adapter
else:
adapter = inter_adapter
else:
adapter = InterAdapter(embed_dim, "none",
task.source_dictionary)
else:
ctc = None
adapter = None
setattr(self, f"downsampling{i + 1}", downsampling)
setattr(self, f"pos_embed{i + 1}", pos_embed)
setattr(self, f"stage{i + 1}", stage)
@@ -640,6 +692,9 @@ class PDSS2TTransformerEncoder(FairseqEncoder):
setattr(self, f"fusion_pre_layer_norm{i + 1}", fusion_pre_layer_norm)
setattr(self, f"fusion_post_layer_norm{i + 1}", fusion_post_layer_norm)
setattr(self, f"ctc{i + 1}", ctc)
setattr(self, f"adapter{i + 1}", adapter)
if self.fusion_stages_num != 0:
self.fusion_weight = nn.Parameter(torch.Tensor(fusion_stages_num).fill_(1.0))
self.fusion_weight.data = self.fusion_weight.data / self.fusion_weight.data.sum(0, keepdim=True)
@@ -648,10 +703,9 @@ class PDSS2TTransformerEncoder(FairseqEncoder):
(("ctc" in getattr(args, "criterion", False)) and
(getattr(args, "ctc_weight", False) > 0))
if self.use_ctc:
self.ctc_layer = (args.ctc_layer + args.encoder_layers) % args.encoder_layers
self.ctc_layer = args.encoder_layers if self.ctc_layer == 0 else self.ctc_layer
self.inter_ctc = True if self.ctc_layer != args.encoder_layers or self.fusion_stages_num != 0 else False
if self.inter_ctc:
logger.info("Intermedia CTC loss in layer %d" % self.ctc_layer)
@@ -663,7 +717,7 @@ class PDSS2TTransformerEncoder(FairseqEncoder):
if ctc_layer <= 0:
embed_dim = self.pds_embed_dims[i]
break
if inter_ctc_module is None:
self.ctc = CTC(embed_dim,
dictionary_size=len(task.source_dictionary),
dropout=args.dropout,
@@ -671,6 +725,8 @@ class PDSS2TTransformerEncoder(FairseqEncoder):
if task.source_dictionary == task.target_dictionary and embed_tokens is not None:
self.ctc.ctc_projection.weight = embed_tokens.weight
else:
self.ctc = inter_ctc_module
if args.encoder_normalize_before:
self.layer_norm = LayerNorm(self.embed_dim)
@@ -708,7 +764,7 @@ class PDSS2TTransformerEncoder(FairseqEncoder):
# padding to the multiply of 2
max_len = x.size(0)
length = reduce(lambda a, b: a * b, self.pds_ratios)
padding_to_len = (length - max_len % length)
if padding_to_len > 0:
padding_for_pds = x.new_zeros((padding_to_len, batch, x.size(2)))
@@ -724,10 +780,13 @@ class PDSS2TTransformerEncoder(FairseqEncoder):
ctc_logit = None
prev_state = []
prev_padding = []
intermedia_ctc_logits = []
for i in range(self.pds_stages):
downsampling = getattr(self, f"downsampling{i + 1}")
pos_embed = getattr(self, f"pos_embed{i + 1}")
stage = getattr(self, f"stage{i + 1}")
ctc = getattr(self, f"ctc{i + 1}")
adapter = getattr(self, f"adapter{i + 1}")
x, input_lengths, encoder_padding_mask = downsampling(x, input_lengths)
@@ -766,6 +825,14 @@ class PDSS2TTransformerEncoder(FairseqEncoder):
prev_state.append(x)
prev_padding.append(encoder_padding_mask)
# interleave CTC
if ctc is not None:
logit = ctc(x.clone())
intermedia_ctc_logits.append([logit, encoder_padding_mask])
prob = F.softmax(logit, dim=-1)
x, encoder_padding_mask = adapter([x, prob], encoder_padding_mask)
if self.fusion_stages_num != 0:
fusion_state = []
i = -1
@@ -802,6 +869,7 @@ class PDSS2TTransformerEncoder(FairseqEncoder):
return {
"encoder_out": [x], # T x B x C
"ctc_logit": [] if ctc_logit is None else [ctc_logit], # T x B x C
"intermedia_ctc_logits": intermedia_ctc_logits, # T x B x C
"encoder_padding_mask": [encoder_padding_mask], # B x T
"encoder_embedding": [], # B x T x C
"encoder_states": [], # List[T x B x C]
@@ -917,6 +985,11 @@ def base_architecture(args):
args.pds_fusion = getattr(args, "pds_fusion", False)
args.pds_fusion_method = getattr(args, "pds_fusion_method", "all_conv")
# intermedia CTC
args.pds_ctc = getattr(args, "pds_ctc", "0_0_0_0")
args.intermedia_adapter = getattr(args, "intermedia_adapter", "none")
args.ctc_self_distill = getattr(args, "ctc_self_distill", False)
def set_pds_base_8(args):
args.pds_stages = getattr(args, "pds_stages", 4)
......
#!/usr/bin/env python3
import logging
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
from fairseq import checkpoint_utils
from fairseq.data.data_utils import lengths_to_padding_mask
from fairseq.models import (
@@ -18,8 +17,8 @@ from fairseq.models.speech_to_text import (
S2TTransformerEncoder,
PDSS2TTransformerModel,
PDSS2TTransformerEncoder,
)
from fairseq.models.speech_to_text.modules import CTCCompressStrategy
from fairseq.models.speech_to_text.s2t_transformer import Conv1dSubsampler
from fairseq.modules import (
FairseqDropout,
@@ -329,7 +328,7 @@ class S2TSATEEncoder(FairseqEncoder):
args.encoder_attention_type = acoustic_encoder_attention_type
if getattr(args, "use_enc_dlcl", False):
layer_num = args.encoder_layers + args.text_encoder_layers + 2
self.history = DynamicLinearCombination(args, is_encoder=True, layer_num=layer_num)
else:
self.history = None
@@ -346,7 +345,8 @@ class S2TSATEEncoder(FairseqEncoder):
if "ctc_logit" in acoustic_encoder_out and len(acoustic_encoder_out["ctc_logit"]) > 0:
ctc_logit = acoustic_encoder_out["ctc_logit"][0]
ctc_prob = F.softmax(ctc_logit / self.temperature, dim=-1)
# ctc_prob = self.acoustic_encoder.ctc.softmax(encoder_out, self.temperature)
else:
ctc_logit = None
ctc_prob = None
@@ -357,19 +357,19 @@ class S2TSATEEncoder(FairseqEncoder):
if self.history is not None:
acoustic_history = self.acoustic_encoder.history
layer_num = acoustic_history.layer_num
idx = torch.arange(layer_num).unsqueeze(0).T.repeat(1, layer_num).to(x.device).unsqueeze(2)
self.history.weight.scatter(0, idx, acoustic_history.weight)
self.history.layers.extend(acoustic_history.layers)
self.history.count = acoustic_history.count
self.history.sum = acoustic_history.sum
self.history.push(x)
x = self.text_encoder(x, encoder_padding_mask, self.history)
return {
"encoder_out": [x], # T x B x C
"ctc_logit": [ctc_logit], # T x B x C
"intermedia_ctc_logits": acoustic_encoder_out.get("intermedia_ctc_logits", []), # B x T x C
"ctc_padding_mask": [ctc_padding_mask], # B x T
"encoder_padding_mask": [encoder_padding_mask], # B x T
"encoder_embedding": [], # B x T x C
......
#!/usr/bin/env python3
import logging
import math
from typing import Dict, List, Optional, Tuple
import torch.nn as nn
import torch.nn.functional as F
from fairseq import checkpoint_utils, utils
from fairseq.data.data_utils import lengths_to_padding_mask
from fairseq.models import (
@@ -13,13 +13,12 @@ from fairseq.models import (
register_model,
register_model_architecture,
)
from fairseq.models.speech_to_text.modules import InterAdapter, CTC
from fairseq.models.transformer import Embedding, TransformerDecoder
from fairseq.modules import (
FairseqDropout,
LayerNorm,
PositionalEmbedding,
TransformerEncoderLayer,
ConformerEncoderLayer,
DynamicLinearCombination,
)
@@ -382,7 +381,19 @@ class S2TTransformerModel(FairseqEncoderDecoderModel):
parser.add_argument('--cl-dropout-strategy',
type=str,
help='interleaved dropout probability')
# intermedia CTC loss
parser.add_argument(
"--intermedia-ctc-layers",
default=None,
type=str,
help="the position of the ctc loss, separated by comma ",
)
parser.add_argument(
"--intermedia-adapter",
default="none",
type=str,
help="type of intermedia adapter",
)
pass
@classmethod
@@ -479,10 +490,11 @@ class S2TTransformerEncoder(FairseqEncoder):
def __init__(self, args, task=None, embed_tokens=None):
super().__init__(None)
dim = args.encoder_embed_dim
self.dropout_module = FairseqDropout(
p=args.dropout, module_name=self.__class__.__name__
)
self.embed_scale = math.sqrt(dim)
if args.no_scale_embedding:
self.embed_scale = 1.0
self.padding_idx = 1
@@ -490,13 +502,13 @@ class S2TTransformerEncoder(FairseqEncoder):
self.subsample = Conv1dSubsampler(
args.input_feat_per_channel * args.input_channels,
args.conv_channels,
dim,
[int(k) for k in args.conv_kernel_sizes.split(",")],
)
self.attn_type = getattr(args, "encoder_attention_type", "selfattn")
self.embed_positions = PositionalEmbedding(
args.max_source_positions, dim, self.padding_idx
)
self.layers = nn.ModuleList(
@@ -504,7 +516,7 @@ class S2TTransformerEncoder(FairseqEncoder):
)
if args.encoder_normalize_before:
self.layer_norm = LayerNorm(dim)
else:
self.layer_norm = None
@@ -520,7 +532,7 @@ class S2TTransformerEncoder(FairseqEncoder):
self.inter_ctc = True if self.ctc_layer != 0 and self.ctc_layer != args.encoder_layers else False
if self.inter_ctc:
logger.info("Intermedia CTC loss in layer %d" % self.ctc_layer)
self.ctc = CTC(dim,
dictionary_size=len(task.source_dictionary),
dropout=args.dropout,
need_layernorm=True if self.inter_ctc else False)
@@ -535,6 +547,32 @@ class S2TTransformerEncoder(FairseqEncoder):
self.dis = 2
self.cos_sim = dict()
self.intermedia_ctc_layers = []
if args.intermedia_ctc_layers is not None:
intermedia_ctc_layers = args.intermedia_ctc_layers.split(",")
for layer_idx in intermedia_ctc_layers:
layer_idx = int(layer_idx)
if layer_idx <= 0:
layer_idx += args.encoder_layers
self.intermedia_ctc_layers.append(layer_idx)
logger.info("Intermedia CTC loss in layer %d" % layer_idx)
if not self.use_ctc:
self.ctc = CTC(dim,
dictionary_size=len(task.source_dictionary),
dropout=args.dropout)
if task.source_dictionary == task.target_dictionary and embed_tokens is not None:
self.ctc.ctc_projection.weight = embed_tokens.weight
strategy = None
if args.intermedia_adapter == "shrink":
strategy = getattr(args, "ctc_compress_strategy", None)
self.adapter = InterAdapter(dim, args.intermedia_adapter,
task.source_dictionary, strategy=strategy)
@staticmethod
def pooling_ratio():
return 4
@@ -601,21 +639,32 @@ class S2TTransformerEncoder(FairseqEncoder):
cos_sim_idx += 1
self.add_to_dict(x, dis, cos_sim_idx)
layer_idx = 0
ctc_logit = None
intermedia_ctc_logits = []
for layer in self.layers:
layer_idx += 1
if self.history is not None:
x = self.history.pop()
if layer_idx != len(self.layers) \
and self.interleaved_dropout is not None \
and layer_idx % self.interleaved_dropout == 0:
x = self.dropout_module(x)
# encoder layer
x = layer(x, encoder_padding_mask, pos_emb=positions)
# interleave CTC
if layer_idx in self.intermedia_ctc_layers:
norm_x = self.layer_norm(x)
logit = self.ctc(norm_x)
intermedia_ctc_logits.append(logit)
# prob = self.ctc.softmax(norm_x)
prob = F.softmax(logit, dim=-1)
x, encoder_padding_mask = self.adapter([x, prob], encoder_padding_mask)
# gather cosine similarity
if self.gather_cos_sim:
@@ -637,6 +686,7 @@ class S2TTransformerEncoder(FairseqEncoder):
return {
"encoder_out": [x], # T x B x C
"ctc_logit": [] if ctc_logit is None else [ctc_logit], # B x T x C
"intermedia_ctc_logits": intermedia_ctc_logits, # B x T x C
"encoder_padding_mask": [encoder_padding_mask], # B x T
"encoder_embedding": [], # B x T x C
"encoder_states": [], # List[T x B x C]
@@ -809,6 +859,10 @@ def base_architecture(args):
args.cl_dropout_epoch = getattr(args, "cl_dropout_epoch", None)
args.cl_dropout_strategy = getattr(args, "cl_dropout_strategy", "linear")
# intermedia CTC
args.intermedia_ctc_layers = getattr(args, "intermedia_ctc_layers", None)
args.intermedia_adapter = getattr(args, "intermedia_adapter", "none")
@register_model_architecture("s2t_transformer", "s2t_transformer_s")
def s2t_transformer_s(args):
......
@@ -26,6 +26,7 @@ class DynamicLinearCombination(nn.Module):
else:
layer_num = 1 + (args.encoder_layers if is_encoder else args.decoder_layers)
self.layer_num = layer_num
# init weights and corresponding masks
learnable = args.encoder_learnable if is_encoder else args.decoder_learnable
self.weight, self.weight_mask = self._init(layer_num, args.init_value, args.weight_type,
......