Commit f286e56c by xuchen

add the intermedia ctc

parent 31e7c426
# NiuTrans-Fairseq-S2T
This project adapts the [fairseq](https://github.com/pytorch/fairseq) toolkit for speech-to-text tasks, including speech recognition and speech translation.
It contains the implementations of the following methods proposed by the NiuTrans Team.
[Stacked Acoustic-and-Textual Encoding: Integrating the Pre-trained Models into Speech Translation Encoders](https://arxiv.org/abs/2105.05752)
## Key Features
### Training
- Support the Kaldi-style complete recipes
- ASR, MT, and ST pipelines (bin)
- Read training config from a YAML file
- CTC multi-task learning
- MT training in the ST-like way (online tokenizer; this may be slow)
- Speed perturbation during pre-processing
### Model
- Conformer architecture
- Load pre-trained modules
- Relative position representation
- Stacked acoustic-and-textual encoding
- Progressive down-sampling for acoustic encoding
## Installation
...@@ -43,9 +45,10 @@ make -j src.build CUDA_HOME=<path to cuda install>
pip install pandas sentencepiece configargparse gpustat tensorboard editdistance
```
## Code Structure
We supply the recipes for multiple benchmarks in the egs folder, including machine translation, speech recognition, and speech translation corpora.
We also provide templates for other benchmarks.
Here is an example for MuST-C:
...@@ -53,41 +56,40 @@ Here is an example for MuST-C:
mustc
├── asr
│   ├── binary.sh
│   ├── conf/
│   ├── decode.sh
│   ├── local/
│   ├── run.sh
│   └── train.sh
├── mt
│   ├── binary.sh
│   ├── conf/
│   ├── decode.sh
│   ├── local/
│   ├── run.sh
│   └── train.sh
└── st
    ├── binary.sh
    ├── conf/
    ├── decode.sh
    ├── local/
    ├── run.sh
    └── train.sh
```
* run.sh: the core script that includes the whole pipeline
* train.sh: calls run.sh for training
* decode.sh: calls run.sh for decoding
* binary.sh: generates the datasets alone
* conf: the folder for the configuration files (.yaml)
* local: the folder for the utility scripts
  * monitor.sh: checks for free GPUs and runs the program automatically
  * parse_options.sh: parses the parameters for run.sh
  * utils.sh: the utility shell functions
## Citations
```bibtex
@inproceedings{xu-etal-2021-stacked,
    title = "Stacked Acoustic-and-Textual Encoding: Integrating the Pre-trained Models into Speech Translation Encoders",
    author = "Xu, Chen and
......
...@@ -16,4 +16,5 @@ no-progress-bar: True
log-interval: 100
seed: 1
report-accuracy: True
skip-invalid-size-inputs-valid-test: True
post-process: sentencepiece
ctc-weight: 0.3
#arch: pdss2t_transformer_s_8
arch: s2t_transformer_s
#pds-ctc: 1_1_1_1
#pds-ctc: 0_0_0_1
intermedia-ctc-layers: 6,8,10
intermedia-adapter: league
intermedia-ctc-weight: 0.2
ctc-self-distill-weight: 1
ctc-weight: 0.1
share-decoder-input-output-embed: True
optimizer: adam
clip-norm: 10.0
...@@ -9,7 +17,6 @@ warmup-updates: 10000
lr: 2e-3
#adam_betas: (0.9,0.98)
ctc-weight: 0.3
criterion: label_smoothed_cross_entropy_with_ctc
label_smoothing: 0.1
......
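Read together with the criterion changes further below, the CTC-related weights in these configs appear to combine into a single training objective. A hedged summary, where λ_ctc, λ_inter, and λ_distill correspond to ctc-weight, intermedia-ctc-weight, and ctc-self-distill-weight, and the remaining mass scales the label-smoothed cross-entropy:
```latex
\mathcal{L} = (1 - \lambda_{\mathrm{ctc}} - \lambda_{\mathrm{inter}})\,\mathcal{L}_{\mathrm{CE}}
            + \lambda_{\mathrm{ctc}}\,\mathcal{L}_{\mathrm{CTC}}
            + \lambda_{\mathrm{inter}}\,\mathcal{L}_{\mathrm{InterCTC}}
            + \lambda_{\mathrm{distill}}\,\mathcal{L}_{\mathrm{SelfDistill}}
```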
...@@ -16,4 +16,5 @@ no-progress-bar: True
log-interval: 100
seed: 1
report-accuracy: True
skip-invalid-size-inputs-valid-test: True
post-process: sentencepiece
ctc-weight: 0.3
intermedia-ctc-layers: 6,9
intermedia-adapter: league
intermedia-ctc-weight: 0.15
ctc-self-distill-weight: 1
intermedia-ctc-layers: 6,9
intermedia-adapter: league
intermedia-ctc-weight: 0.15
ctc-self-distill-weight: 1
arch: pdss2t_transformer_s_8
pds-ctc: 1_1_1_1
intermedia-adapter: league
intermedia-ctc-weight: 0.15
encoder-embed-dim: 256
pds-stages: 4
ctc-layer: 12
......
...@@ -20,7 +20,7 @@ class LabelSmoothedCrossEntropyCriterionWithCTC(
LabelSmoothedCrossEntropyCriterion
):
def __init__(self, task, sentence_avg, label_smoothing, post_process="letter",
ctc_weight=0.0, intermedia_ctc_weight=0.0, ctc_self_distill_weight=0.0):
super().__init__(task, sentence_avg, label_smoothing)
self.blank_idx = task.target_dictionary.index(task.blank_symbol) if hasattr(task, 'blank_symbol') else 0
self.pad_idx = task.target_dictionary.pad()
...@@ -29,9 +29,11 @@ class LabelSmoothedCrossEntropyCriterionWithCTC(
self.report_accuracy = True
assert 0 <= ctc_weight
self.top_ctc_weight = ctc_weight
self.intermedia_ctc_weight = intermedia_ctc_weight
self.ctc_self_distill_weight = ctc_self_distill_weight
self.ctc_weight = ctc_weight + intermedia_ctc_weight + ctc_self_distill_weight
if self.ctc_weight > 0:
assert getattr(task, "src_dict", None) is not None, "CTC need a source dictionary."
self.post_process = post_process
self.ctc_loss = torch.nn.CTCLoss(blank=self.blank_idx, reduction="sum", zero_infinity=True)
...@@ -58,7 +60,19 @@ class LabelSmoothedCrossEntropyCriterionWithCTC(
default=0.0,
type=float,
metavar="D",
help="weight of intermedia CTC loss",
)
parser.add_argument(
"--ctc-self-distill",
action="store_true",
help="use self distillation for intermedia CTC loss",
)
parser.add_argument(
"--ctc-self-distill-weight",
default=0.0,
type=float,
metavar="D",
help="weight of the self distillation CTC loss",
)
parser.add_argument(
"--post-process",
...@@ -100,9 +114,9 @@ class LabelSmoothedCrossEntropyCriterionWithCTC(
logging_output["n_correct"] = utils.item(n_correct.data)
logging_output["total"] = utils.item(total.data)
if self.ctc_weight > 0:
ctc_loss, logging_output = self.compute_ctc_loss(model, sample, encoder_out, logging_output)
loss = (1 - self.top_ctc_weight - self.intermedia_ctc_weight) * loss + ctc_loss
logging_output["loss"] = utils.item(loss.data) if reduce else loss.data
return loss, sample_size, logging_output
...@@ -122,7 +136,8 @@ class LabelSmoothedCrossEntropyCriterionWithCTC(
transcript_lengths = pad_mask.sum(-1)
ctc_loss = 0
lprobs = None
if self.top_ctc_weight > 0 and "ctc_logit" in encoder_out and len(encoder_out["ctc_logit"]) > 0:
ctc_logit = encoder_out["ctc_logit"][0]
lprobs = model.get_normalized_probs(
[ctc_logit], log_probs=True
...@@ -140,15 +155,22 @@ class LabelSmoothedCrossEntropyCriterionWithCTC(
intermedia_ctc_num = 0
intermedia_ctc_loss = 0
if "intermedia_ctc_logits" in encoder_out:
intermedia_ctc_num = len(encoder_out["intermedia_ctc_logits"])
# calculate the intermedia CTC loss
if self.intermedia_ctc_weight > 0 and intermedia_ctc_num > 0:
for i in range(intermedia_ctc_num):
out = encoder_out["intermedia_ctc_logits"][i]
if type(out) == list:
inter_ctc_logit = out[0]
padding = ~out[1]
input_lengths = padding.long().sum(-1)
else:
inter_ctc_logit = out
inter_lprobs = model.get_normalized_probs(
[inter_ctc_logit], log_probs=True
).contiguous()  # (T, B, C) from the encoder
inter_lprobs.batch_first = False
...@@ -162,7 +184,46 @@ class LabelSmoothedCrossEntropyCriterionWithCTC(
intermedia_ctc_loss += loss
intermedia_ctc_loss /= intermedia_ctc_num
logging_output["intermedia_ctc_loss"] = utils.item(intermedia_ctc_loss.data)
if lprobs is None:
lprobs = inter_lprobs
# calculate the self distillation CTC loss
ctc_self_distill_loss = 0
ctc_self_distill_num = 0
if self.top_ctc_weight > 0 and self.ctc_self_distill_weight > 0 and intermedia_ctc_num > 0:
for i in range(intermedia_ctc_num):
out = encoder_out["intermedia_ctc_logits"][i]
if type(out) == list:
inter_ctc_logit = out[0]
padding = ~out[1]
input_lengths = padding.long().sum(-1)
else:
inter_ctc_logit = out
if inter_ctc_logit.size() != ctc_logit.size():
continue
ctc_self_distill_num += 1
loss = F.kl_div(
F.log_softmax(inter_ctc_logit, dim=-1),
F.softmax(ctc_logit, dim=-1),
reduction="none",
)
loss = loss.sum(-1).transpose(0, 1).masked_fill_(~non_padding_mask, 0.0)
loss = loss.sum()
ctc_self_distill_loss += loss
ctc_self_distill_loss /= ctc_self_distill_num
logging_output["ctc_self_distill_loss"] = utils.item(ctc_self_distill_loss.data)
loss = \
self.ctc_weight * ctc_loss + \
self.intermedia_ctc_weight * intermedia_ctc_loss + \
self.ctc_self_distill_weight * ctc_self_distill_loss
if self.intermedia_ctc_weight > 0 or self.ctc_self_distill_weight > 0:
logging_output["all_ctc_loss"] = utils.item(loss.data)
if not model.training and self.ctc_weight > 0:
import editdistance
...@@ -236,6 +297,13 @@ class LabelSmoothedCrossEntropyCriterionWithCTC(
inter_ctc_loss_sum = utils.item(
sum(log.get("intermedia_ctc_loss", 0) for log in logging_outputs)
)
ctc_self_distill_loss_sum = utils.item(
sum(log.get("ctc_self_distill_loss", 0) for log in logging_outputs)
)
all_ctc_loss_sum = utils.item(
sum(log.get("all_ctc_loss", 0) for log in logging_outputs)
)
ntokens = utils.item(sum(log.get("ntokens", 0) for log in logging_outputs))
sample_size = utils.item(
sum(log.get("sample_size", 0) for log in logging_outputs)
...@@ -265,6 +333,21 @@ class LabelSmoothedCrossEntropyCriterionWithCTC(
sample_size,
round=3,
)
if ctc_self_distill_loss_sum > 0:
metrics.log_scalar(
"ctc_self_distill_loss",
ctc_self_distill_loss_sum / sample_size / math.log(2),
sample_size,
round=3,
)
if all_ctc_loss_sum > 0:
metrics.log_scalar(
"all_ctc_loss",
all_ctc_loss_sum / sample_size / math.log(2),
sample_size,
round=3,
)
metrics.log_derived(
"ppl", lambda meters: utils.get_perplexity(meters["nll_loss"].avg)
)
......
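For reference, the self-distillation term introduced above pulls every intermediate CTC prediction toward the top-layer CTC prediction. A minimal, self-contained sketch of that term with dummy tensors (variable names mirror the criterion; shapes and values are illustrative only):
```python
import torch
import torch.nn.functional as F

T, B, V = 50, 4, 100                      # time steps, batch size, source vocabulary size
ctc_logit = torch.randn(T, B, V)          # top-layer CTC logits (teacher)
inter_ctc_logit = torch.randn(T, B, V)    # intermediate CTC logits (student)
non_padding_mask = torch.ones(B, T, dtype=torch.bool)  # True where the frame is not padding

# KL(teacher || student) per frame and vocabulary entry, as in the updated criterion
kl = F.kl_div(
    F.log_softmax(inter_ctc_logit, dim=-1),
    F.softmax(ctc_logit, dim=-1),
    reduction="none",
)
# sum over vocabulary, move to (B, T), zero out padded frames, then sum to a scalar
distill_loss = kl.sum(-1).transpose(0, 1).masked_fill_(~non_padding_mask, 0.0).sum()
print(distill_loss)
```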
...@@ -4,10 +4,8 @@
# LICENSE file in the root directory of this source tree.
from .berard import *  # noqa
from .ctc import *  # noqa
from .convtransformer import *  # noqa
from .s2t_transformer import *  # noqa
from .inter_ctc_s2t_transformer import *  # noqa
from .s2t_conformer import *  # noqa
from .pdss2t_transformer import *  # noqa
from .s2t_sate import *  # noqa
#!/usr/bin/env python3
import logging
import torch
import torch.nn as nn
from fairseq import checkpoint_utils, utils
from fairseq.data.data_utils import lengths_to_padding_mask
from fairseq.models import (
register_model,
register_model_architecture,
)
from fairseq.models.transformer import Embedding
from fairseq.models.speech_to_text import (
S2TTransformerModel,
S2TTransformerEncoder,
CTC,
CTCCompressStrategy,
)
from fairseq.modules import (
LayerNorm
)
logger = logging.getLogger(__name__)
class Adapter(nn.Module):
def __init__(self, args, dictionary, embed_tokens=None):
super().__init__()
embed_dim = args.encoder_embed_dim
self.adapter_type = args.adapter
if self.adapter_type in ["linear", "league", "gated_league", "gated_league2"]:
self.linear_adapter = nn.Sequential(
nn.Linear(embed_dim, embed_dim),
LayerNorm(args.encoder_embed_dim),
nn.ReLU(),
)
elif self.adapter_type == "linear2":
self.linear_adapter = nn.Sequential(
nn.Linear(embed_dim, embed_dim),
)
if self.adapter_type in ["embed", "context", "league", "gated_league", "gated_league2"]:
if embed_tokens is None:
num_embeddings = len(dictionary)
self.embed_adapter = Embedding(num_embeddings, embed_dim, dictionary.pad())
else:
self.embed_adapter = embed_tokens
if self.adapter_type == "gated_league":
self.gate_linear = nn.Linear(2 * embed_dim, embed_dim)
elif self.adapter_type == "gated_league2":
self.gate_linear1 = nn.Linear(embed_dim, embed_dim)
self.gate_linear2 = nn.Linear(embed_dim, embed_dim)
if self.adapter_type == "shrink":
self.ctc_compress_method = getattr(CTCCompressStrategy, args.ctc_compress_strategy)
def forward(self, x, padding):
representation, distribution = x
batch, seq_len, embed_dim = representation.size()
org_distribution = distribution
if distribution is not None:
distribution = distribution.view(-1, distribution.size(-1))
lengths = (~padding).long().sum(-1)
if self.adapter_type == "linear":
out = self.linear_adapter(representation)
elif self.adapter_type == "context":
out = torch.mm(distribution, self.embed_adapter.weight).view(batch, seq_len, -1)
elif self.adapter_type == "league":
linear_out = self.linear_adapter(representation)
soft_out = torch.mm(distribution, self.embed_adapter.weight).view(batch, seq_len, -1)
out = linear_out + soft_out
elif self.adapter_type == "gated_league":
linear_out = self.linear_adapter(representation)
soft_out = torch.mm(distribution, self.embed_adapter.weight).view(batch, seq_len, -1)
coef = (self.gate_linear(torch.cat([linear_out, soft_out], dim=-1))).sigmoid()
out = coef * linear_out + (1 - coef) * soft_out
elif self.adapter_type == "none":
out = representation
elif self.adapter_type == "shrink":
from itertools import groupby
with torch.no_grad():
batch_predicted = []
prob_ctc = org_distribution.transpose(0, 1) # T x B x D -> B x T x D
for b in range(prob_ctc.shape[0]):
predicted = prob_ctc[b][: lengths[b]].argmax(-1).tolist()
batch_predicted.append([(p[0], len(list(p[1]))) for p in groupby(predicted)])
new_lengths = [len(p) for p in batch_predicted]
weights_matrix = self.ctc_compress_method(prob_ctc, batch_predicted, new_lengths,
representation.dtype, representation.device)
# x is T x B x C -> B x C x T; weights_matrix is B x T x T'
compressed_output = representation.permute(1, 2, 0).bmm(weights_matrix) # B x C x T'
out = compressed_output.permute(2, 0, 1)
out_lengths = lengths.new(new_lengths)
padding = lengths_to_padding_mask(out_lengths)
else:
out = None
logging.error("Unsupported adapter type: {}.".format(self.adapter_type))
return out, padding
@register_model("inter_ctc_s2t_transformer")
class InterCTCS2TTransformerModel(S2TTransformerModel):
"""Speech-to-Text Transformer with intermedia CTC Loss in different layers"""
def __init__(self, encoder, decoder):
super().__init__(encoder, decoder)
@staticmethod
def add_args(parser):
S2TTransformerModel.add_args(parser)
parser.add_argument(
"--intermedia-ctc-layers",
default=None,
type=str,
help="the position of the ctc loss, separated by comma ",
)
parser.add_argument(
"--adapter",
default="league",
type=str,
help="adapter type",
)
parser.add_argument(
"--ctc-compress-strategy",
default="avg",
type=str,
help="compress strategy, such as avg, weighted, and softmax",
)
pass
@classmethod
def build_encoder(cls, args, task=None, embed_tokens=None):
encoder = S2TInterCTCTransformerEncoder(args, task, embed_tokens)
if getattr(args, "load_pretrained_encoder_from", None):
encoder = checkpoint_utils.load_pretrained_component_from_model(
component=encoder, checkpoint=args.load_pretrained_encoder_from, strict=False
)
logger.info(
f"loaded pretrained encoder from: "
f"{args.load_pretrained_encoder_from}"
)
return encoder
class S2TInterCTCTransformerEncoder(S2TTransformerEncoder):
"""Speech-to-text Transformer encoder that consists of intermedia ctc losses """
def __init__(self, args, task=None, embed_tokens=None):
super().__init__(args, task, embed_tokens)
self.intermedia_ctc_layers = []
if args.intermedia_ctc_layers is not None:
intermedia_ctc_layers = args.intermedia_ctc_layers.split(",")
for layer_idx in intermedia_ctc_layers:
layer_idx = int(layer_idx)
if layer_idx <= 0:
layer_idx += args.encoder_layers
self.intermedia_ctc_layers.append(layer_idx)
logger.info("Intermedia CTC loss in layer %d" % layer_idx)
ctc = CTC(args.encoder_embed_dim,
dictionary_size=len(task.source_dictionary),
dropout=args.dropout,
need_layernorm=True)
if task.source_dictionary == task.target_dictionary and embed_tokens is not None:
ctc.ctc_projection.weight = self.ctc.ctc_projection.weight
ctc.LayerNorm = self.layer_norm
setattr(self, f"ctc{layer_idx}", ctc)
adapter = Adapter(args, task.source_dictionary)
# adapter = Adapter(args, task.source_dictionary, ctc.ctc_projection)
setattr(self, f"adapter{layer_idx}", adapter)
def forward(self, src_tokens, src_lengths):
if self.history is not None:
self.history.clean()
# down-sampling
x, input_lengths = self.subsample(src_tokens, src_lengths)
if type(x) == list:
inner_x = x
x = inner_x[-1]
# embedding scaling
x = self.embed_scale * x
# padding and position embedding
encoder_padding_mask = lengths_to_padding_mask(input_lengths)
positions = self.embed_positions(encoder_padding_mask).transpose(0, 1)
if self.attn_type != "rel_selfattn":
x += positions
x = self.dropout_module(x)
positions = self.dropout_module(positions)
# add emb into history
if self.history is not None:
self.history.push(x)
layer_idx = 0
ctc_logit = None
intermedia_ctc_logits = []
for layer in self.layers:
layer_idx += 1
if self.history is not None:
x = self.history.pop()
# encoder layer
x = layer(x, encoder_padding_mask, pos_emb=positions)
# interleave CTC
if layer_idx in self.intermedia_ctc_layers:
ctc = getattr(self, f"ctc{layer_idx}")
adapter = getattr(self, f"adapter{layer_idx}")
logit = ctc(x)
prob = ctc.softmax(x)
x, encoder_padding_mask = adapter([x, prob], encoder_padding_mask)
intermedia_ctc_logits.append(logit)
if layer_idx != len(self.layers) \
and self.interleaved_dropout is not None \
and layer_idx % self.interleaved_dropout == 0:
x = self.dropout_module(x)
if self.history is not None:
self.history.push(x)
if self.history is not None:
x = self.history.pop()
if self.layer_norm is not None:
x = self.layer_norm(x)
if self.use_ctc:
ctc_logit = self.ctc(x)
return {
"encoder_out": [x], # T x B x C
"ctc_logit": [] if ctc_logit is None else [ctc_logit], # B x T x C
"intermedia_ctc_logit": intermedia_ctc_logit, # B x T x C
"encoder_padding_mask": [encoder_padding_mask], # B x T
"encoder_embedding": [], # B x T x C
"encoder_states": [], # List[T x B x C]
"src_tokens": [],
"src_lengths": [],
}
@register_model_architecture(model_name="inter_ctc_s2t_transformer", arch_name="inter_ctc_s2t_transformer")
def base_architecture(args):
# Convolutional subsampler
args.conv_kernel_sizes = getattr(args, "conv_kernel_sizes", "5,5")
args.conv_channels = getattr(args, "conv_channels", 1024)
# Transformer
args.encoder_embed_dim = getattr(args, "encoder_embed_dim", 512)
args.encoder_ffn_embed_dim = getattr(args, "encoder_ffn_embed_dim", 2048)
args.encoder_layers = getattr(args, "encoder_layers", 12)
args.encoder_attention_type = getattr(args, "encoder_attention_type", "selfattn")
args.encoder_attention_heads = getattr(args, "encoder_attention_heads", 8)
args.encoder_normalize_before = getattr(args, "encoder_normalize_before", True)
args.decoder_embed_dim = getattr(args, "decoder_embed_dim", args.encoder_embed_dim)
args.decoder_ffn_embed_dim = getattr(
args, "decoder_ffn_embed_dim", args.encoder_ffn_embed_dim
)
args.decoder_layers = getattr(args, "decoder_layers", 6)
args.decoder_attention_type = getattr(args, "decoder_attention_type", "selfattn")
args.decoder_attention_heads = getattr(args, "decoder_attention_heads", 8)
args.decoder_normalize_before = getattr(args, "decoder_normalize_before", True)
args.decoder_learned_pos = getattr(args, "decoder_learned_pos", False)
args.dropout = getattr(args, "dropout", 0.1)
args.attention_dropout = getattr(args, "attention_dropout", args.dropout)
args.activation_dropout = getattr(args, "activation_dropout", args.dropout)
args.activation_fn = getattr(args, "activation_fn", "relu")
args.adaptive_softmax_cutoff = getattr(args, "adaptive_softmax_cutoff", None)
args.adaptive_softmax_dropout = getattr(args, "adaptive_softmax_dropout", 0)
args.tie_adaptive_weights = getattr(args, "tie_adaptive_weights", False)
args.tie_adaptive_proj = getattr(args, "tie_adaptive_proj", False)
args.adaptive_softmax_factor = getattr(args, "adaptive_softmax_factor", 4)
args.share_decoder_input_output_embed = getattr(
args, "share_decoder_input_output_embed", False
)
args.share_all_embeddings = getattr(args, "share_all_embeddings", False)
args.no_token_positional_embeddings = getattr(
args, "no_token_positional_embeddings", False
)
args.adaptive_input = getattr(args, "adaptive_input", False)
args.decoder_layerdrop = getattr(args, "decoder_layerdrop", 0.0)
args.decoder_output_dim = getattr(
args, "decoder_output_dim", args.decoder_embed_dim
)
args.decoder_input_dim = getattr(args, "decoder_input_dim", args.decoder_embed_dim)
args.no_scale_embedding = getattr(args, "no_scale_embedding", False)
args.quant_noise_pq = getattr(args, "quant_noise_pq", 0)
# CTC
args.intermedia_ctc_layers = getattr(args, "intermedia_ctc_layers", None)
# Conformer
args.macaron_style = getattr(args, "macaron_style", False)
args.use_cnn_module = getattr(args, "use_cnn_module", False)
args.cnn_module_kernel = getattr(args, "cnn_module_kernel", 31)
# settings for DLCL
args.use_enc_dlcl = getattr(args, "use_enc_dlcl", False)
args.use_dec_dlcl = getattr(args, "use_dec_dlcl", False)
args.init_value = getattr(args, 'init_value', 'avg')
args.weight_type = getattr(args, 'weight_type', 'scalar')
args.encoder_learnable = getattr(args, 'encoder_learnable', True)
args.decoder_learnable = getattr(args, 'decoder_learnable', True)
args.normalize_embed = getattr(args, 'normalize_embed', False)
args.history_dropout = getattr(args, 'history_dropout', 0.0)
args.history_window_size = getattr(args, 'history_window_size', -1)
# Relative position encoding
args.max_encoder_relative_length = getattr(args, 'max_encoder_relative_length', -1)
args.max_decoder_relative_length = getattr(args, 'max_decoder_relative_length', -1)
args.k_only = getattr(args, 'k_only', True)
# local modeling
args.hard_mask_window = getattr(args, 'hard_mask_window', 0)
args.gauss_mask_sigma = getattr(args, 'gauss_mask_sigma', 0)
args.init_mask_weight = getattr(args, 'init_mask_weight', 0)
# interleaved dropout
args.interleave_dropout = getattr(args, "interleave_dropout", None)
args.cl_dropout = getattr(args, "cl_dropout", False)
args.cl_dropout_epoch = getattr(args, "cl_dropout_epoch", None)
args.cl_dropout_strategy = getattr(args, "cl_dropout_strategy", "linear")
@register_model_architecture("inter_ctc_s2t_transformer", "inter_ctc_s2t_transformer_s")
def inter_ctc_s2t_transformer_s(args):
args.encoder_embed_dim = getattr(args, "encoder_embed_dim", 256)
args.encoder_ffn_embed_dim = getattr(args, "encoder_ffn_embed_dim", 256 * 8)
args.encoder_attention_heads = getattr(args, "encoder_attention_heads", 4)
args.decoder_attention_heads = getattr(args, "decoder_attention_heads", 4)
args.dropout = getattr(args, "dropout", 0.1)
args.intermedia_ctc_layers = getattr(args, "intermedia_ctc_layers", None)
base_architecture(args)
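For reference, the encoder above resolves --intermedia-ctc-layers by splitting the comma-separated string and wrapping non-positive indices around the number of encoder layers. A standalone sketch of that resolution step (the helper name is illustrative, not part of the codebase):
```python
def resolve_intermedia_ctc_layers(spec, encoder_layers):
    """Turn a spec such as "6,8,10" into absolute 1-based layer indices.

    Non-positive entries count back from the last layer, mirroring the
    `layer_idx <= 0: layer_idx += encoder_layers` handling in the encoder.
    """
    if spec is None:
        return []
    layers = []
    for token in spec.split(","):
        layer_idx = int(token)
        if layer_idx <= 0:
            layer_idx += encoder_layers
        layers.append(layer_idx)
    return layers

# e.g. with a 12-layer encoder: "6,8,10" -> [6, 8, 10], "0,-2" -> [12, 10]
print(resolve_intermedia_ctc_layers("6,8,10", 12))
print(resolve_intermedia_ctc_layers("0,-2", 12))
```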
from .adapter import *
from .ctc import *
import logging
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
from fairseq.data.data_utils import lengths_to_padding_mask
from fairseq.models.transformer import Embedding
from fairseq.modules import LayerNorm
logger = logging.getLogger(__name__)
class CTCCompressStrategy:
@staticmethod
def avg(prob_ctc, predicted, new_lengths, dtype, device):
new_maxlen = max(new_lengths)
weights_matrix = torch.zeros((prob_ctc.shape[0], prob_ctc.shape[1], new_maxlen), dtype=dtype)
for b_idx, pred in enumerate(predicted):
processed_inputs_cnt = 0
for t_idx, same in enumerate(pred):
new_processed_inputs_cnt = processed_inputs_cnt + same[1]
weights_matrix[b_idx, processed_inputs_cnt:new_processed_inputs_cnt, t_idx] = 1.0 / same[1]
processed_inputs_cnt = new_processed_inputs_cnt
return weights_matrix.to(device)
@staticmethod
def weighted(prob_ctc, predicted, new_lengths, dtype, device):
new_maxlen = max(new_lengths)
weights_matrix = torch.zeros((prob_ctc.shape[0], prob_ctc.shape[1], new_maxlen), dtype=dtype, device=device)
for b_idx, pred in enumerate(predicted):
processed_inputs_cnt = 0
for t_idx, same in enumerate(pred):
new_processed_inputs_cnt = processed_inputs_cnt + same[1]
# Get the probabilities of the prediction for the different time steps as weight
weights = prob_ctc[b_idx, processed_inputs_cnt:new_processed_inputs_cnt, same[0]]
weights_matrix[b_idx, processed_inputs_cnt:new_processed_inputs_cnt, t_idx] = \
weights / weights.sum()
processed_inputs_cnt = new_processed_inputs_cnt
return weights_matrix
@staticmethod
def softmax(prob_ctc, predicted, new_lengths, dtype, device):
new_maxlen = max(new_lengths)
weights_matrix = torch.zeros((prob_ctc.shape[0], prob_ctc.shape[1], new_maxlen), dtype=dtype, device=device)
for b_idx, pred in enumerate(predicted):
processed_inputs_cnt = 0
for t_idx, same in enumerate(pred):
new_processed_inputs_cnt = processed_inputs_cnt + same[1]
# Get the probabilities of the prediction for the different time steps as weight
weights = F.softmax(prob_ctc[b_idx, processed_inputs_cnt:new_processed_inputs_cnt, same[0]])
weights_matrix[b_idx, processed_inputs_cnt:new_processed_inputs_cnt, t_idx] = \
weights / weights.sum()
processed_inputs_cnt = new_processed_inputs_cnt
return weights_matrix
class InterAdapter(nn.Module):
def __init__(self, dim, adapter_type, dictionary, embed_tokens=None, strategy=None):
super().__init__()
dim = dim
self.adapter_type = adapter_type
if self.adapter_type in ["linear", "league", "gated_league", "gated_league2"]:
self.linear_adapter = nn.Sequential(
nn.Linear(dim, dim),
LayerNorm(dim),
nn.ReLU(),
)
if self.adapter_type in ["context", "league", "gated_league", "gated_league2", "inter_league"]:
if embed_tokens is None:
num_embeddings = len(dictionary)
self.embed_adapter = Embedding(num_embeddings, dim, dictionary.pad())
else:
self.embed_adapter = embed_tokens
if self.adapter_type == "gated_league":
self.gate_linear = nn.Linear(2 * dim, dim)
elif self.adapter_type == "gated_league2":
self.gate_linear1 = nn.Linear(dim, dim)
self.gate_linear2 = nn.Linear(dim, dim)
if self.adapter_type == "shrink":
self.ctc_compress = getattr(CTCCompressStrategy, strategy)
def forward(self, x, padding):
representation, distribution = x
dim1, dim2, dim = representation.size()
org_distribution = distribution
if distribution is not None:
distribution = distribution.view(-1, distribution.size(-1))
lengths = (~padding).long().sum(-1)
if self.adapter_type == "linear":
out = self.linear_adapter(representation)
elif self.adapter_type == "context":
out = torch.mm(distribution, self.embed_adapter.weight).view(dim1, dim2, -1)
elif self.adapter_type == "league":
linear_out = self.linear_adapter(representation)
soft_out = torch.mm(distribution, self.embed_adapter.weight).view(dim1, dim2, -1)
out = linear_out + soft_out
elif self.adapter_type == "gated_league":
linear_out = self.linear_adapter(representation)
soft_out = torch.mm(distribution, self.embed_adapter.weight).view(dim1, dim2, -1)
coef = (self.gate_linear(torch.cat([linear_out, soft_out], dim=-1))).sigmoid()
out = coef * linear_out + (1 - coef) * soft_out
elif self.adapter_type == "inter_league":
soft_out = torch.mm(distribution, self.embed_adapter.weight).view(dim1, dim2, -1)
out = representation + soft_out
elif self.adapter_type == "none":
out = representation
elif self.adapter_type == "shrink":
from itertools import groupby
with torch.no_grad():
batch_predicted = []
prob_ctc = org_distribution.transpose(0, 1) # T x B x D -> B x T x D
for b in range(prob_ctc.shape[0]):
predicted = prob_ctc[b][: lengths[b]].argmax(-1).tolist()
batch_predicted.append([(p[0], len(list(p[1]))) for p in groupby(predicted)])
new_lengths = [len(p) for p in batch_predicted]
weights_matrix = self.ctc_compress(prob_ctc, batch_predicted, new_lengths,
representation.dtype, representation.device)
# x is T x B x C -> B x C x T; weights_matrix is B x T x T'
compressed_output = representation.permute(1, 2, 0).bmm(weights_matrix) # B x C x T'
out = compressed_output.permute(2, 0, 1)
out_lengths = lengths.new(new_lengths)
padding = lengths_to_padding_mask(out_lengths)
else:
out = None
logging.error("Unsupported adapter type: {}.".format(self.adapter_type))
return out, padding
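For reference, the "league" branch of the adapter above adds a position-wise transform of the hidden states to a soft embedding of the intermediate CTC distribution. A minimal sketch with dummy shapes (plain nn.Embedding and nn.LayerNorm stand in for the fairseq wrappers; all sizes are assumptions):
```python
import torch
import torch.nn as nn

T, B, C, V = 50, 4, 256, 100   # time steps, batch size, embedding dim, source vocab size

representation = torch.randn(T, B, C)                        # encoder states at an intermediate layer
distribution = torch.softmax(torch.randn(T, B, V), dim=-1)   # intermediate CTC posteriors

# the two branches of the "league" adapter: a position-wise transform of the
# hidden states plus a soft embedding of the CTC distribution
linear_adapter = nn.Sequential(nn.Linear(C, C), nn.LayerNorm(C), nn.ReLU())
embed_adapter = nn.Embedding(V, C)

linear_out = linear_adapter(representation)
soft_out = torch.mm(distribution.view(-1, V), embed_adapter.weight).view(T, B, C)
out = linear_out + soft_out                # fused representation fed to the next layer
print(out.shape)                           # torch.Size([50, 4, 256])
```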
#!/usr/bin/env python3
import logging
import torch
import torch.nn as nn
import torch.nn.functional as F
from fairseq.modules import (
FairseqDropout,
LayerNorm,
...@@ -27,7 +26,6 @@ class CTC(nn.Module):
self.ctc_dropout_module = FairseqDropout(
p=dropout, module_name=self.__class__.__name__
)
self.softmax = nn.Softmax(dim=-1)
self.need_layernorm = need_layernorm
if self.need_layernorm:
self.LayerNorm = LayerNorm(embed_dim)
...@@ -40,54 +38,11 @@ class CTC(nn.Module):
return x
def softmax(self, x, temperature=1.0):
return F.softmax(self.ctc_projection(x) / temperature, dim=-1)
def log_softmax(self, x, temperature=1.0):
return F.log_softmax(self.ctc_projection(x) / temperature, dim=-1)
def argmax(self, x):
return torch.argmax(self.ctc_projection(x), dim=-1)
class CTCCompressStrategy:
@staticmethod
def avg(prob_ctc, predicted, new_lengths, dtype, device):
new_maxlen = max(new_lengths)
weights_matrix = torch.zeros((prob_ctc.shape[0], prob_ctc.shape[1], new_maxlen), dtype=dtype)
for b_idx, pred in enumerate(predicted):
processed_inputs_cnt = 0
for t_idx, same in enumerate(pred):
new_processed_inputs_cnt = processed_inputs_cnt + same[1]
weights_matrix[b_idx, processed_inputs_cnt:new_processed_inputs_cnt, t_idx] = 1.0 / same[1]
processed_inputs_cnt = new_processed_inputs_cnt
return weights_matrix.to(device)
@staticmethod
def weighted(prob_ctc, predicted, new_lengths, dtype, device):
new_maxlen = max(new_lengths)
weights_matrix = torch.zeros((prob_ctc.shape[0], prob_ctc.shape[1], new_maxlen), dtype=dtype, device=device)
for b_idx, pred in enumerate(predicted):
processed_inputs_cnt = 0
for t_idx, same in enumerate(pred):
new_processed_inputs_cnt = processed_inputs_cnt + same[1]
# Get the probabilities of the prediction for the different time steps as weight
weights = prob_ctc[b_idx, processed_inputs_cnt:new_processed_inputs_cnt, same[0]]
weights_matrix[b_idx, processed_inputs_cnt:new_processed_inputs_cnt, t_idx] = \
weights / weights.sum()
processed_inputs_cnt = new_processed_inputs_cnt
return weights_matrix
@staticmethod
def softmax(prob_ctc, predicted, new_lengths, dtype, device):
new_maxlen = max(new_lengths)
weights_matrix = torch.zeros((prob_ctc.shape[0], prob_ctc.shape[1], new_maxlen), dtype=dtype, device=device)
for b_idx, pred in enumerate(predicted):
processed_inputs_cnt = 0
for t_idx, same in enumerate(pred):
new_processed_inputs_cnt = processed_inputs_cnt + same[1]
# Get the probabilities of the prediction for the different time steps as weight
weights = F.softmax(prob_ctc[b_idx, processed_inputs_cnt:new_processed_inputs_cnt, same[0]])
weights_matrix[b_idx, processed_inputs_cnt:new_processed_inputs_cnt, t_idx] = \
weights / weights.sum()
processed_inputs_cnt = new_processed_inputs_cnt
return weights_matrix
#!/usr/bin/env python3
import logging
import math
from functools import reduce
import torch
import torch.nn as nn
import torch.nn.functional as F
from fairseq import checkpoint_utils, utils
from fairseq.models import (
FairseqEncoder,
register_model,
register_model_architecture,
)
from fairseq.models.speech_to_text import S2TTransformerModel
from fairseq.models.speech_to_text.modules import CTC, InterAdapter
from fairseq.modules import (
FairseqDropout,
LayerNorm,
PositionalEmbedding,
PDSTransformerEncoderLayer,
MultiheadAttention,
DownSampleConvolutionModule
)
...@@ -50,14 +50,14 @@ class Permute201(nn.Module):
class Downsampling(nn.Module):
# down-sampling module
def __init__(
self,
reduced_way: str,
embed_norm: bool,
in_channels: int,
out_channels: int,
kernel_sizes: int,
stride: int,
padding: int,
):
super().__init__()
...@@ -75,8 +75,8 @@ class Downsampling(nn.Module):
)
elif self.reduced_way == "proj":
self.conv = nn.Sequential(
nn.Conv1d(in_channels, out_channels, kernel_sizes, stride=stride, padding=padding),
nn.ReLU()
)
else:
logger.error("Unsupported reduced way!")
...@@ -91,7 +91,7 @@ class Downsampling(nn.Module):
# mask batch padding
if not torch.all(lengths == seq_len):
padding_mask = lengths_to_padding_mask_with_maxlen(lengths, seq_len)  # bsz, seq_len
mask_pad = padding_mask.unsqueeze(2)
if mask_pad is not None:
x = x.transpose(0, 1)
...@@ -459,7 +459,6 @@ class PDSS2TTransformerModel(S2TTransformerModel):
type=str,
help="the ratio of the ffn in each stage",
)
parser.add_argument(
"--pds-fusion",
action="store_true",
...@@ -475,6 +474,23 @@ class PDSS2TTransformerModel(S2TTransformerModel):
type=float,
help="dropout in each stage",
)
parser.add_argument(
"--pds-ctc",
type=str,
help="use the ctc after each stage",
)
parser.add_argument(
"--intermedia-ctc-layers",
default=None,
type=str,
help="the position of the ctc loss, separated by comma ",
)
parser.add_argument(
"--intermedia-adapter",
default="none",
type=str,
help="type of intermedia adapter",
)
pass
@classmethod
...@@ -502,10 +518,10 @@ class PDSS2TTransformerEncoder(FairseqEncoder):
self.embed_dim = args.encoder_embed_dim
self.dropout = FairseqDropout(
p=args.dropout, module_name=self.__class__.__name__
)
self.pds_dropout = FairseqDropout(
p=getattr(args, "pds_dropout", args.dropout), module_name=self.__class__.__name__
)
self.pds_stages = getattr(args, "pds_stages", 4)
...@@ -550,6 +566,11 @@ class PDSS2TTransformerEncoder(FairseqEncoder):
fusion_stages_num = 0
self.fusion_stages_num = fusion_stages_num
args.pds_ctc = getattr(args, "pds_ctc", "0_0_0_0")
self.pds_ctc = [int(n) for n in args.pds_ctc.split("_")]
inter_ctc_module = None
inter_adapter = None
for i in range(self.pds_stages):
num_layers = self.pds_layers[i]
ds_ratio = self.pds_ratios[i]
...@@ -557,6 +578,7 @@ class PDSS2TTransformerEncoder(FairseqEncoder):
embed_dim = self.pds_embed_dims[i]
kernel_size = self.pds_kernel_sizes[i]
use_pos_embed = self.pds_position_embed[i]
use_ctc = self.pds_ctc[i]
num_head = self.pds_attn_heads[i]
attn_ds_ratio = self.pds_attn_ds_ratios[i] if self.attn_type == "reduced" else -1
...@@ -577,7 +599,7 @@ class PDSS2TTransformerEncoder(FairseqEncoder):
downsampling = Downsampling(
self.pds_ds_method,
self.pds_embed_norm,
args.input_feat_per_channel * args.input_channels if i == 0 else self.pds_embed_dims[i - 1],
embed_dim,
kernel_sizes=kernel_size,
stride=ds_ratio,
...@@ -597,7 +619,7 @@ class PDSS2TTransformerEncoder(FairseqEncoder):
fusion_downsampling = None
if fusion_stages_num != 0:
if self.pds_fusion_method == "all" or (
self.pds_fusion_method == "same" and self.embed_dim == embed_dim
):
if i != self.pds_stages - 1:
ratio = reduce(lambda a, b: a * b, self.pds_ratios[i + 1:])
...@@ -633,6 +655,36 @@ class PDSS2TTransformerEncoder(FairseqEncoder):
else:
logger.error("Unsupported fusion transform!")
# intermedia modules for each stage
if use_ctc:
if inter_ctc_module is None:
ctc = CTC(embed_dim,
dictionary_size=len(task.source_dictionary),
dropout=args.dropout,
need_layernorm=True)
if task.source_dictionary == task.target_dictionary and embed_tokens is not None:
self.ctc.ctc_projection.weight = embed_tokens.weight
inter_ctc_module = ctc
else:
ctc = inter_ctc_module
if i != self.pds_stages - 1:
if inter_adapter is None:
strategy = None
if args.intermedia_adapter == "shrink":
strategy = getattr(args, "ctc_compress_strategy", "avg")
adapter = InterAdapter(embed_dim, args.intermedia_adapter,
task.source_dictionary, strategy=strategy)
inter_adapter = adapter
else:
adapter = inter_adapter
else:
adapter = InterAdapter(embed_dim, "none",
task.source_dictionary)
else:
ctc = None
adapter = None
setattr(self, f"downsampling{i + 1}", downsampling) setattr(self, f"downsampling{i + 1}", downsampling)
setattr(self, f"pos_embed{i + 1}", pos_embed) setattr(self, f"pos_embed{i + 1}", pos_embed)
setattr(self, f"stage{i + 1}", stage) setattr(self, f"stage{i + 1}", stage)
...@@ -640,6 +692,9 @@ class PDSS2TTransformerEncoder(FairseqEncoder): ...@@ -640,6 +692,9 @@ class PDSS2TTransformerEncoder(FairseqEncoder):
setattr(self, f"fusion_pre_layer_norm{i + 1}", fusion_pre_layer_norm) setattr(self, f"fusion_pre_layer_norm{i + 1}", fusion_pre_layer_norm)
setattr(self, f"fusion_post_layer_norm{i + 1}", fusion_post_layer_norm) setattr(self, f"fusion_post_layer_norm{i + 1}", fusion_post_layer_norm)
setattr(self, f"ctc{i + 1}", ctc)
setattr(self, f"adapter{i + 1}", adapter)
if self.fusion_stages_num != 0: if self.fusion_stages_num != 0:
self.fusion_weight = nn.Parameter(torch.Tensor(fusion_stages_num).fill_(1.0)) self.fusion_weight = nn.Parameter(torch.Tensor(fusion_stages_num).fill_(1.0))
self.fusion_weight.data = self.fusion_weight.data / self.fusion_weight.data.sum(0, keepdim=True) self.fusion_weight.data = self.fusion_weight.data / self.fusion_weight.data.sum(0, keepdim=True)
...@@ -648,10 +703,9 @@ class PDSS2TTransformerEncoder(FairseqEncoder): ...@@ -648,10 +703,9 @@ class PDSS2TTransformerEncoder(FairseqEncoder):
(("ctc" in getattr(args, "criterion", False)) and (("ctc" in getattr(args, "criterion", False)) and
(getattr(args, "ctc_weight", False) > 0)) (getattr(args, "ctc_weight", False) > 0))
if self.use_ctc: if self.use_ctc:
# self.ctc_layer = (args.encoder_layers + args.ctc_layer) % args.encoder_layers self.ctc_layer = (args.ctc_layer + args.encoder_layers) % args.encoder_layers
# self.inter_ctc = True if self.ctc_layer != args.encoder_layers else False self.ctc_layer = args.encoder_layers if self.ctc_layer == 0 else self.ctc_layer
self.ctc_layer = args.encoder_layers self.inter_ctc = True if self.ctc_layer != args.encoder_layers or self.fusion_stages_num != 0 else False
self.inter_ctc = True if self.ctc_layer != 0 else False
if self.inter_ctc: if self.inter_ctc:
logger.info("Intermedia CTC loss in layer %d" % self.ctc_layer) logger.info("Intermedia CTC loss in layer %d" % self.ctc_layer)
...@@ -663,14 +717,16 @@ class PDSS2TTransformerEncoder(FairseqEncoder): ...@@ -663,14 +717,16 @@ class PDSS2TTransformerEncoder(FairseqEncoder):
if ctc_layer <= 0: if ctc_layer <= 0:
embed_dim = self.pds_embed_dims[i] embed_dim = self.pds_embed_dims[i]
break break
if inter_ctc_module is None:
self.ctc = CTC(embed_dim,
dictionary_size=len(task.source_dictionary),
dropout=args.dropout,
need_layernorm=True if self.inter_ctc else False)
if task.source_dictionary == task.target_dictionary and embed_tokens is not None:
self.ctc.ctc_projection.weight = embed_tokens.weight
else:
self.ctc = inter_ctc_module
if args.encoder_normalize_before:
self.layer_norm = LayerNorm(self.embed_dim)
...@@ -708,7 +764,7 @@ class PDSS2TTransformerEncoder(FairseqEncoder):
# padding to the multiply of 2
max_len = x.size(0)
length = reduce(lambda a, b: a * b, self.pds_ratios)
padding_to_len = (length - max_len % length)
if padding_to_len > 0:
padding_for_pds = x.new_zeros((padding_to_len, batch, x.size(2)))
...@@ -724,10 +780,13 @@ class PDSS2TTransformerEncoder(FairseqEncoder):
ctc_logit = None
prev_state = []
prev_padding = []
intermedia_ctc_logits = []
for i in range(self.pds_stages):
downsampling = getattr(self, f"downsampling{i + 1}")
pos_embed = getattr(self, f"pos_embed{i + 1}")
stage = getattr(self, f"stage{i + 1}")
ctc = getattr(self, f"ctc{i + 1}")
adapter = getattr(self, f"adapter{i + 1}")
x, input_lengths, encoder_padding_mask = downsampling(x, input_lengths)
...@@ -766,6 +825,14 @@ class PDSS2TTransformerEncoder(FairseqEncoder):
prev_state.append(x)
prev_padding.append(encoder_padding_mask)
# interleave CTC
if ctc is not None:
logit = ctc(x.clone())
intermedia_ctc_logits.append([logit, encoder_padding_mask])
prob = F.softmax(logit, dim=-1)
x, encoder_padding_mask = adapter([x, prob], encoder_padding_mask)
if self.fusion_stages_num != 0:
fusion_state = []
i = -1
...@@ -802,6 +869,7 @@ class PDSS2TTransformerEncoder(FairseqEncoder):
return {
"encoder_out": [x],  # T x B x C
"ctc_logit": [] if ctc_logit is None else [ctc_logit],  # T x B x C
"intermedia_ctc_logits": intermedia_ctc_logits, # T x B x C
"encoder_padding_mask": [encoder_padding_mask], # B x T "encoder_padding_mask": [encoder_padding_mask], # B x T
"encoder_embedding": [], # B x T x C "encoder_embedding": [], # B x T x C
"encoder_states": [], # List[T x B x C] "encoder_states": [], # List[T x B x C]
...@@ -917,6 +985,11 @@ def base_architecture(args):
args.pds_fusion = getattr(args, "pds_fusion", False)
args.pds_fusion_method = getattr(args, "pds_fusion_method", "all_conv")
# intermedia CTC
args.pds_ctc = getattr(args, "pds_ctc", "0_0_0_0")
args.intermedia_adapter = getattr(args, "intermedia_adapter", "none")
args.ctc_self_distill = getattr(args, "ctc_self_distill", False)
def set_pds_base_8(args):
args.pds_stages = getattr(args, "pds_stages", 4)
......
#!/usr/bin/env python3
import logging
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
from fairseq import checkpoint_utils
from fairseq.data.data_utils import lengths_to_padding_mask
from fairseq.models import (
...@@ -18,8 +17,8 @@ from fairseq.models.speech_to_text import (
S2TTransformerEncoder,
PDSS2TTransformerModel,
PDSS2TTransformerEncoder,
CTCCompressStrategy
)
from fairseq.models.speech_to_text.modules import CTCCompressStrategy
from fairseq.models.speech_to_text.s2t_transformer import Conv1dSubsampler
from fairseq.modules import (
FairseqDropout,
...@@ -329,7 +328,7 @@ class S2TSATEEncoder(FairseqEncoder):
args.encoder_attention_type = acoustic_encoder_attention_type
if getattr(args, "use_enc_dlcl", False):
layer_num = args.encoder_layers + args.text_encoder_layers + 2
self.history = DynamicLinearCombination(args, is_encoder=True, layer_num=layer_num)
else:
self.history = None
...@@ -346,7 +345,8 @@ class S2TSATEEncoder(FairseqEncoder):
if "ctc_logit" in acoustic_encoder_out and len(acoustic_encoder_out["ctc_logit"]) > 0:
ctc_logit = acoustic_encoder_out["ctc_logit"][0]
ctc_prob = F.softmax(ctc_logit / self.temperature, dim=-1)
# ctc_prob = self.acoustic_encoder.ctc.softmax(encoder_out, self.temperature)
else:
ctc_logit = None
ctc_prob = None
...@@ -357,19 +357,19 @@ class S2TSATEEncoder(FairseqEncoder):
if self.history is not None:
acoustic_history = self.acoustic_encoder.history
layer_num = acoustic_history.layer_num
idx = torch.arange(layer_num).unsqueeze(0).T.repeat(1, layer_num).to(x.device).unsqueeze(2)
self.history.weight.scatter(0, idx, acoustic_history.weight)
self.history.layers.extend(acoustic_history.layers)
self.history.count = acoustic_history.count
self.history.sum = acoustic_history.sum
self.history.push(x)
x = self.text_encoder(x, encoder_padding_mask, self.history)
return {
"encoder_out": [x],  # T x B x C
"ctc_logit": [ctc_logit],  # T x B x C
"intermedia_ctc_logits": acoustic_encoder_out.get("intermedia_ctc_logits", []),  # B x T x C
"ctc_padding_mask": [ctc_padding_mask],  # B x T
"encoder_padding_mask": [encoder_padding_mask],  # B x T
"encoder_embedding": [],  # B x T x C
......
#!/usr/bin/env python3
import logging
import math
from typing import Dict, List, Optional, Tuple
import torch.nn as nn
import torch.nn.functional as F
from fairseq import checkpoint_utils, utils
from fairseq.data.data_utils import lengths_to_padding_mask
from fairseq.models import (
...@@ -13,13 +13,12 @@ from fairseq.models import (
register_model,
register_model_architecture,
)
from fairseq.models.speech_to_text.modules import InterAdapter, CTC
from fairseq.models.transformer import Embedding, TransformerDecoder
from fairseq.modules import (
FairseqDropout,
LayerNorm,
PositionalEmbedding,
TransformerEncoderLayer,
ConformerEncoderLayer,
DynamicLinearCombination,
)
...@@ -382,7 +381,19 @@ class S2TTransformerModel(FairseqEncoderDecoderModel):
parser.add_argument('--cl-dropout-strategy',
type=str,
help='interleaved dropout probability')
# intermedia CTC loss
parser.add_argument(
"--intermedia-ctc-layers",
default=None,
type=str,
help="the position of the ctc loss, separated by comma ",
)
parser.add_argument(
"--intermedia-adapter",
default="none",
type=str,
help="type of intermedia adapter",
)
pass pass
@classmethod @classmethod
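The two new flags select where the intermedia CTC losses are computed and which adapter follows them. A hedged sketch of how the layer string could be parsed into absolute indices (the standalone parser and the example values are assumptions, not taken from the recipes):

```
import argparse

parser = argparse.ArgumentParser()
parser.add_argument("--intermedia-ctc-layers", default=None, type=str)
parser.add_argument("--intermedia-adapter", default="none", type=str)
parser.add_argument("--encoder-layers", default=12, type=int)

args = parser.parse_args(["--intermedia-ctc-layers", "6,9", "--intermedia-adapter", "shrink"])

layers = []
for layer_idx in (args.intermedia_ctc_layers or "").split(","):
    if not layer_idx:
        continue
    layer_idx = int(layer_idx)
    if layer_idx <= 0:                    # non-positive values count back from the top layer
        layer_idx += args.encoder_layers
    layers.append(layer_idx)

print(layers)                             # [6, 9]
```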
@@ -479,10 +490,11 @@ class S2TTransformerEncoder(FairseqEncoder):
     def __init__(self, args, task=None, embed_tokens=None):
         super().__init__(None)
+        dim = args.encoder_embed_dim
         self.dropout_module = FairseqDropout(
             p=args.dropout, module_name=self.__class__.__name__
         )
-        self.embed_scale = math.sqrt(args.encoder_embed_dim)
+        self.embed_scale = math.sqrt(dim)
         if args.no_scale_embedding:
             self.embed_scale = 1.0
         self.padding_idx = 1
@@ -490,13 +502,13 @@ class S2TTransformerEncoder(FairseqEncoder):
         self.subsample = Conv1dSubsampler(
             args.input_feat_per_channel * args.input_channels,
             args.conv_channels,
-            args.encoder_embed_dim,
+            dim,
             [int(k) for k in args.conv_kernel_sizes.split(",")],
         )
         self.attn_type = getattr(args, "encoder_attention_type", "selfattn")
         self.embed_positions = PositionalEmbedding(
-            args.max_source_positions, args.encoder_embed_dim, self.padding_idx
+            args.max_source_positions, dim, self.padding_idx
         )
         self.layers = nn.ModuleList(
@@ -504,7 +516,7 @@ class S2TTransformerEncoder(FairseqEncoder):
         )
         if args.encoder_normalize_before:
-            self.layer_norm = LayerNorm(args.encoder_embed_dim)
+            self.layer_norm = LayerNorm(dim)
         else:
             self.layer_norm = None
@@ -520,7 +532,7 @@ class S2TTransformerEncoder(FairseqEncoder):
         self.inter_ctc = True if self.ctc_layer != 0 and self.ctc_layer != args.encoder_layers else False
         if self.inter_ctc:
             logger.info("Intermedia CTC loss in layer %d" % self.ctc_layer)
-        self.ctc = CTC(args.encoder_embed_dim,
+        self.ctc = CTC(dim,
                        dictionary_size=len(task.source_dictionary),
                        dropout=args.dropout,
                        need_layernorm=True if self.inter_ctc else False)
@@ -535,6 +547,32 @@ class S2TTransformerEncoder(FairseqEncoder):
             self.dis = 2
             self.cos_sim = dict()

+        self.intermedia_ctc_layers = []
+        if args.intermedia_ctc_layers is not None:
+            intermedia_ctc_layers = args.intermedia_ctc_layers.split(",")
+            for layer_idx in intermedia_ctc_layers:
+                layer_idx = int(layer_idx)
+                if layer_idx <= 0:
+                    layer_idx += args.encoder_layers
+                self.intermedia_ctc_layers.append(layer_idx)
+                logger.info("Intermedia CTC loss in layer %d" % layer_idx)
+            if not self.use_ctc:
+                self.ctc = CTC(dim,
+                               dictionary_size=len(task.source_dictionary),
+                               dropout=args.dropout)
+                if task.source_dictionary == task.target_dictionary and embed_tokens is not None:
+                    self.ctc.ctc_projection.weight = embed_tokens.weight
+            strategy = None
+            if args.intermedia_adapter == "shrink":
+                strategy = getattr(args, "ctc_compress_strategy", None)
+            self.adapter = InterAdapter(dim, args.intermedia_adapter,
+                                        task.source_dictionary, strategy=strategy)

     @staticmethod
     def pooling_ratio():
         return 4
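When the source and target dictionaries coincide, the new initialisation ties the CTC projection to the decoder embedding so both share one matrix. A minimal sketch of that tying; `CTCHead` here is a stand-in, not the repository's `CTC` module:

```
import torch.nn as nn

class CTCHead(nn.Module):
    """Stand-in for a CTC output head: hidden states -> vocabulary logits."""
    def __init__(self, dim, vocab_size, dropout=0.1):
        super().__init__()
        self.dropout = nn.Dropout(dropout)
        self.ctc_projection = nn.Linear(dim, vocab_size, bias=False)

    def forward(self, x):
        return self.ctc_projection(self.dropout(x))

dim, vocab_size = 256, 1000
embed_tokens = nn.Embedding(vocab_size, dim)          # V x C
ctc = CTCHead(dim, vocab_size)                        # projection weight is also V x C

# Weight tying: the shapes match, so the two modules can share storage.
ctc.ctc_projection.weight = embed_tokens.weight
assert ctc.ctc_projection.weight.data_ptr() == embed_tokens.weight.data_ptr()
```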
@@ -601,21 +639,32 @@ class S2TTransformerEncoder(FairseqEncoder):
             cos_sim_idx += 1
             self.add_to_dict(x, dis, cos_sim_idx)

-        layer_index = 0
+        layer_idx = 0
         ctc_logit = None
+        intermedia_ctc_logits = []
         for layer in self.layers:
-            layer_index += 1
+            layer_idx += 1
             if self.history is not None:
                 x = self.history.pop()

+            if layer_idx != len(self.layers) \
+                    and self.interleaved_dropout is not None \
+                    and layer_idx % self.interleaved_dropout == 0:
+                x = self.dropout_module(x)
+
             # encoder layer
             x = layer(x, encoder_padding_mask, pos_emb=positions)

-            if layer_index != len(self.layers) \
-                    and self.interleaved_dropout is not None \
-                    and layer_index % self.interleaved_dropout == 0:
-                x = self.dropout_module(x)
+            # interleave CTC
+            if layer_idx in self.intermedia_ctc_layers:
+                norm_x = self.layer_norm(x)
+                logit = self.ctc(norm_x)
+                intermedia_ctc_logits.append(logit)
+
+                # prob = self.ctc.softmax(norm_x)
+                prob = F.softmax(logit, dim=-1)
+                x, encoder_padding_mask = self.adapter([x, prob], encoder_padding_mask)

             # gather cosine similarity
             if self.gather_cos_sim:
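Taken together, each configured layer now emits an intermediate CTC prediction and the adapter folds the softened prediction back into the hidden states before the next layer. A toy, self-contained version of the pattern; the module names, the residual-style adapter, and all hyper-parameters are assumptions, and the repository's `InterAdapter` may behave differently:

```
import torch
import torch.nn as nn
import torch.nn.functional as F

class ToyInterCTCEncoder(nn.Module):
    def __init__(self, dim=256, vocab_size=1000, num_layers=6, ctc_layers=(2, 4)):
        super().__init__()
        self.layers = nn.ModuleList(
            [nn.TransformerEncoderLayer(dim, nhead=4) for _ in range(num_layers)]
        )
        self.ctc_layers = set(ctc_layers)
        self.layer_norm = nn.LayerNorm(dim)
        self.ctc_proj = nn.Linear(dim, vocab_size)    # shared CTC head
        self.adapter = nn.Linear(vocab_size, dim)     # soft prediction -> hidden space

    def forward(self, x):                             # x: T x B x C
        inter_logits = []
        for idx, layer in enumerate(self.layers, start=1):
            x = layer(x)
            if idx in self.ctc_layers:
                logit = self.ctc_proj(self.layer_norm(x))   # T x B x V
                inter_logits.append(logit)
                prob = F.softmax(logit, dim=-1)
                x = x + self.adapter(prob)                  # re-inject the prediction
        return x, inter_logits

encoder = ToyInterCTCEncoder()
out, logits = encoder(torch.randn(50, 8, 256))
print(out.shape, len(logits))                         # torch.Size([50, 8, 256]) 2
```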
@@ -637,6 +686,7 @@ class S2TTransformerEncoder(FairseqEncoder):
         return {
             "encoder_out": [x],  # T x B x C
             "ctc_logit": [] if ctc_logit is None else [ctc_logit],  # B x T x C
+            "intermedia_ctc_logits": intermedia_ctc_logits,  # B x T x C
             "encoder_padding_mask": [encoder_padding_mask],  # B x T
             "encoder_embedding": [],  # B x T x C
             "encoder_states": [],  # List[T x B x C]
@@ -809,6 +859,10 @@ def base_architecture(args):
     args.cl_dropout_epoch = getattr(args, "cl_dropout_epoch", None)
     args.cl_dropout_strategy = getattr(args, "cl_dropout_strategy", "linear")

+    # intermedia CTC
+    args.intermedia_ctc_layers = getattr(args, "intermedia_ctc_layers", None)
+    args.intermedia_adapter = getattr(args, "intermedia_adapter", None)

 @register_model_architecture("s2t_transformer", "s2t_transformer_s")
 def s2t_transformer_s(args):
...
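The two defaults follow the usual `getattr` pattern of `base_architecture`: anything already set by the CLI or the yaml config wins, otherwise the fallback is filled in. A tiny sketch of that behaviour with an invented `Namespace`:

```
from argparse import Namespace

def base_architecture(args):
    args.intermedia_ctc_layers = getattr(args, "intermedia_ctc_layers", None)
    args.intermedia_adapter = getattr(args, "intermedia_adapter", None)

args = Namespace(intermedia_ctc_layers="6,9")   # pretend the user set only one flag
base_architecture(args)
print(args.intermedia_ctc_layers, args.intermedia_adapter)   # 6,9 None
```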
@@ -26,6 +26,7 @@ class DynamicLinearCombination(nn.Module):
         else:
             layer_num = 1 + (args.encoder_layers if is_encoder else args.decoder_layers)
+        self.layer_num = layer_num

         # init weights and corresponding masks
         learnable = args.encoder_learnable if is_encoder else args.decoder_learnable
         self.weight, self.weight_mask = self._init(layer_num, args.init_value, args.weight_type,
...
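The stored `layer_num` is what the SATE encoder reads when it copies weights between histories in the first hunk of this commit. For context, a dynamic linear combination keeps every layer's output and mixes them with learned, layer-indexed weights; a toy version (not the repository's `DynamicLinearCombination`) might look like:

```
import torch
import torch.nn as nn

class ToyDLCL(nn.Module):
    """Keep all previous layer outputs and combine them with learned weights."""
    def __init__(self, layer_num):
        super().__init__()
        self.layer_num = layer_num
        # Row i holds the mixing weights used as input to layer i.
        self.weight = nn.Parameter(torch.zeros(layer_num, layer_num))
        self.layers = []                 # running list of T x B x C outputs
        self.count = 0

    def push(self, x):
        self.layers.append(x)
        self.count += 1

    def pop(self):
        w = torch.softmax(self.weight[self.count - 1, :self.count], dim=0)
        stacked = torch.stack(self.layers, dim=0)        # count x T x B x C
        return (w.view(-1, 1, 1, 1) * stacked).sum(dim=0)

    def clean(self):
        self.layers, self.count = [], 0

history = ToyDLCL(layer_num=3)
history.push(torch.randn(5, 2, 8))       # e.g. the embedding output
print(history.pop().shape)               # torch.Size([5, 2, 8])
```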