Commit 31e7c426 by xuchen

fix the bug of the intermedia ctc losses

parent 2215ade0
 arch: multi_ctc_s2t_transformer_s
-multi-ctc-layers: 6,8,10,12
+intermedia-ctc-layers: 6,8,10
 share-decoder-input-output-embed: True
 optimizer: adam
 clip-norm: 10.0
......
@@ -29,21 +29,21 @@ cat $GEN | cut -f 3 > $REF
 cat $GEN | cut -f 4 > $SYS
 #detokenize the decodes file to format the manner to do tokenize
-perl $detokenizer -l de < $SYS > $SYS.dtk
-perl $detokenizer -l de < $REF > $REF.dtk
+$detokenizer -l de < $SYS > $SYS.dtk
+$detokenizer -l de < $REF > $REF.dtk
 #replace unicode
-perl $replace_unicode_punctuation -l de < $SYS.dtk > $SYS.dtk.punc
-perl $replace_unicode_punctuation -l de < $REF.dtk > $REF.dtk.punc
+$replace_unicode_punctuation -l de < $SYS.dtk > $SYS.dtk.punc
+$replace_unicode_punctuation -l de < $REF.dtk > $REF.dtk.punc
 #tokenize the decodes file by moses tokenizer.perl
-perl $tokenizer -l de < $SYS.dtk.punc > $SYS.dtk.punc.tok
-perl $tokenizer -l de < $REF.dtk.punc > $REF.dtk.punc.tok
+$tokenizer -l de < $SYS.dtk.punc > $SYS.dtk.punc.tok
+$tokenizer -l de < $REF.dtk.punc > $REF.dtk.punc.tok
 #"rich-text format" --> rich ##AT##-##AT## text format.
 perl -ple 's{(\S)-(\S)}{$1 ##AT##-##AT## $2}g' < $SYS.dtk.punc.tok > $SYS.dtk.punc.tok.atat
 perl -ple 's{(\S)-(\S)}{$1 ##AT##-##AT## $2}g' < $REF.dtk.punc.tok > $REF.dtk.punc.tok.atat
-perl $multi_bleu $REF.dtk.punc.tok.atat < $SYS.dtk.punc.tok.atat
+$multi_bleu $REF.dtk.punc.tok.atat < $SYS.dtk.punc.tok.atat
 rm -f $SYS.dtk $SYS.dtk.punc $SYS.dtk.punc.tok $REF.dtk $REF.dtk.punc $REF.dtk.punc.tok
\ No newline at end of file
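Note: the edit above only drops the explicit perl interpreter in front of the Moses script variables (presumably the scripts are now invoked directly). The ##AT##-##AT## substitution in this hunk splits intra-word hyphens before multi-bleu scoring; a rough Python equivalent of that Perl one-liner, for illustration only (split_hyphens is not part of the script):

import re

def split_hyphens(line):
    # Rough equivalent of: perl -ple 's{(\S)-(\S)}{$1 ##AT##-##AT## $2}g'
    return re.sub(r"(\S)-(\S)", r"\1 ##AT##-##AT## \2", line)

print(split_hyphens("rich-text format"))  # rich ##AT##-##AT## text format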
@@ -19,7 +19,8 @@ from .label_smoothed_cross_entropy import LabelSmoothedCrossEntropyCriterion
 class LabelSmoothedCrossEntropyCriterionWithCTC(
     LabelSmoothedCrossEntropyCriterion
 ):
-    def __init__(self, task, sentence_avg, label_smoothing, post_process="letter", ctc_weight=0.0):
+    def __init__(self, task, sentence_avg, label_smoothing, post_process="letter",
+                 ctc_weight=0.0, intermedia_ctc_weight=0.0):
         super().__init__(task, sentence_avg, label_smoothing)
         self.blank_idx = task.target_dictionary.index(task.blank_symbol) if hasattr(task, 'blank_symbol') else 0
         self.pad_idx = task.target_dictionary.pad()
@@ -29,7 +29,8 @@ class LabelSmoothedCrossEntropyCriterionWithCTC(
         assert 0 <= ctc_weight
         self.ctc_weight = ctc_weight
-        if self.ctc_weight > 0:
+        self.intermedia_ctc_weight = intermedia_ctc_weight
+        if self.ctc_weight > 0 or self.intermedia_ctc_weight > 0:
             assert getattr(task, "src_dict", None) is not None, "CTC need a source dictionary."
             self.post_process = post_process
             self.ctc_loss = torch.nn.CTCLoss(blank=self.blank_idx, reduction="sum", zero_infinity=True)
@@ -52,6 +54,13 @@ class LabelSmoothedCrossEntropyCriterionWithCTC(
             help="weight of CTC loss",
         )
         parser.add_argument(
+            "--intermedia-ctc-weight",
+            default=0.0,
+            type=float,
+            metavar="D",
+            help="weight of intermedia CT loss",
+        )
+        parser.add_argument(
             "--post-process",
             default="letter",
             type=str,
@@ -91,10 +100,9 @@ class LabelSmoothedCrossEntropyCriterionWithCTC(
             logging_output["n_correct"] = utils.item(n_correct.data)
             logging_output["total"] = utils.item(total.data)
-        if self.ctc_weight > 0:
+        if self.ctc_weight > 0 or self.intermedia_ctc_weight > 0:
             ctc_loss, logging_output = self.compute_ctc_loss(model, sample, encoder_out, logging_output)
-            logging_output["ctc_loss"] = utils.item(ctc_loss.data)
-            loss = (1 - self.ctc_weight) * loss + self.ctc_weight * ctc_loss
+            loss = (1 - self.ctc_weight) * loss + ctc_loss
         logging_output["loss"] = utils.item(loss.data) if reduce else loss.data
         return loss, sample_size, logging_output
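Note: after this change compute_ctc_loss (below) already returns the weighted sum of the two CTC terms, so the effective objective is roughly the following sketch; ce_loss stands for the label-smoothed cross entropy, and the function name and example weights are illustrative, not from the repo:

def combined_loss(ce_loss, ctc_loss, intermedia_ctc_loss,
                  ctc_weight=0.3, intermedia_ctc_weight=0.2):
    # compute_ctc_loss returns: ctc_weight * ctc + intermedia_ctc_weight * inter_ctc
    weighted_ctc = ctc_weight * ctc_loss + intermedia_ctc_weight * intermedia_ctc_loss
    # forward() then mixes it with the down-weighted cross entropy
    return (1 - ctc_weight) * ce_loss + weighted_ctc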
@@ -114,10 +122,7 @@ class LabelSmoothedCrossEntropyCriterionWithCTC(
         transcript_lengths = pad_mask.sum(-1)
         ctc_loss = 0
-        ctc_num = len(encoder_out["ctc_logit"])
-        assert ctc_num != 0, "No ctc logit for loss!"
-        for i in range(ctc_num):
+        if "ctc_logit" in encoder_out and len(encoder_out["ctc_logit"]) > 0:
             ctc_logit = encoder_out["ctc_logit"][0]
             lprobs = model.get_normalized_probs(
                 [ctc_logit], log_probs=True
@@ -125,17 +130,41 @@ class LabelSmoothedCrossEntropyCriterionWithCTC(
             lprobs.batch_first = False
             with torch.backends.cudnn.flags(enabled=False):
-                loss = self.ctc_loss(
+                ctc_loss = self.ctc_loss(
                     lprobs,
                     targets_flat,
                     input_lengths,
                     transcript_lengths,
                 )
-            ctc_loss += loss
-        ctc_loss /= ctc_num
         logging_output["ctc_loss"] = utils.item(ctc_loss.data)
-        if not model.training:
+        intermedia_ctc_num = 0
+        intermedia_ctc_loss = 0
+        if "intermedia_ctc_logit" in encoder_out:
+            intermedia_ctc_num = len(encoder_out["intermedia_ctc_logit"])
+        if intermedia_ctc_num > 0:
+            for i in range(intermedia_ctc_num):
+                ctc_logit = encoder_out["intermedia_ctc_logit"][i]
+                inter_lprobs = model.get_normalized_probs(
+                    [ctc_logit], log_probs=True
+                ).contiguous()  # (T, B, C) from the encoder
+                inter_lprobs.batch_first = False
+                with torch.backends.cudnn.flags(enabled=False):
+                    loss = self.ctc_loss(
+                        inter_lprobs,
+                        targets_flat,
+                        input_lengths,
+                        transcript_lengths,
+                    )
+                intermedia_ctc_loss += loss
+            intermedia_ctc_loss /= intermedia_ctc_num
+            logging_output["intermedia_ctc_loss"] = utils.item(intermedia_ctc_loss.data)
+        loss = self.ctc_weight * ctc_loss + self.intermedia_ctc_weight * intermedia_ctc_loss
+        if not model.training and self.ctc_weight > 0:
             import editdistance
             with torch.no_grad():
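Note: both the top-layer and the intermedia branches feed the same torch.nn.CTCLoss instance, which expects (T, B, C) log-probabilities plus per-utterance input and target lengths. A minimal, self-contained reminder of that standard PyTorch API (toy shapes, not tied to this model):

import torch

ctc = torch.nn.CTCLoss(blank=0, reduction="sum", zero_infinity=True)
T, B, C, S = 50, 4, 100, 10                       # time, batch, vocab, target length
log_probs = torch.randn(T, B, C).log_softmax(-1)  # (T, B, C)
targets = torch.randint(1, C, (B, S))             # (B, S), blank index excluded
input_lengths = torch.full((B,), T, dtype=torch.long)
target_lengths = torch.full((B,), S, dtype=torch.long)
loss = ctc(log_probs, targets, input_lengths, target_lengths)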
@@ -189,7 +218,7 @@ class LabelSmoothedCrossEntropyCriterionWithCTC(
             logging_output["c_errors"] = c_err
             logging_output["c_total"] = c_len
-        return ctc_loss, logging_output
+        return loss, logging_output
     @staticmethod
     def reduce_metrics(logging_outputs) -> None:
@@ -204,6 +233,9 @@ class LabelSmoothedCrossEntropyCriterionWithCTC(
         ctc_loss_sum = utils.item(
             sum(log.get("ctc_loss", 0) for log in logging_outputs)
         )
+        inter_ctc_loss_sum = utils.item(
+            sum(log.get("intermedia_ctc_loss", 0) for log in logging_outputs)
+        )
         ntokens = utils.item(sum(log.get("ntokens", 0) for log in logging_outputs))
         sample_size = utils.item(
             sum(log.get("sample_size", 0) for log in logging_outputs)
@@ -226,6 +258,13 @@ class LabelSmoothedCrossEntropyCriterionWithCTC(
                 sample_size,
                 round=3,
             )
+        if inter_ctc_loss_sum > 0:
+            metrics.log_scalar(
+                "intermedia_ctc_loss",
+                inter_ctc_loss_sum / sample_size / math.log(2),
+                sample_size,
+                round=3,
+            )
         metrics.log_derived(
             "ppl", lambda meters: utils.get_perplexity(meters["nll_loss"].avg)
         )
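Note: dividing by sample_size and math.log(2) follows the existing fairseq convention of reporting losses in base 2 (bits) rather than nats; the helper below only restates that conversion and is not part of the repo:

import math

def to_bits(loss_sum_nats, sample_size):
    # summed natural-log loss -> average loss in base 2
    return loss_sum_nats / sample_size / math.log(2)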
......
@@ -7,7 +7,7 @@ from .berard import *  # noqa
 from .ctc import *  # noqa
 from .convtransformer import *  # noqa
 from .s2t_transformer import *  # noqa
-from .multi_ctc_s2t_transformer import *  # noqa
+from .inter_ctc_s2t_transformer import *  # noqa
 from .s2t_conformer import *  # noqa
 from .pdss2t_transformer import *  # noqa
 from .s2t_sate import *  # noqa
@@ -39,17 +39,14 @@ class CTC(nn.Module):
         x = self.ctc_projection(self.ctc_dropout_module(x))
         return x
-    @staticmethod
-    def softmax(ctc_logit, temperature=1.0):
-        return torch.nn.functional.softmax(ctc_logit / temperature, dim=-1)
+    def softmax(self, x, temperature=1.0):
+        return torch.nn.functional.softmax(self.ctc_projection(x) / temperature, dim=-1)
-    @staticmethod
-    def log_softmax(ctc_logit, temperature=1.0):
-        return torch.nn.functional.log_softmax(ctc_logit / temperature, dim=-1)
+    def log_softmax(self, x, temperature=1.0):
+        return torch.nn.functional.log_softmax(self.ctc_projection(x) / temperature, dim=-1)
-    @staticmethod
-    def argmax(ctc_logit):
-        return torch.argmax(ctc_logit, dim=-1)
+    def argmax(self, x):
+        return torch.argmax(self.ctc_projection(x), dim=-1)
 class CTCCompressStrategy:
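Note: with this refactor softmax/log_softmax/argmax take the encoder hidden states and apply ctc_projection internally, which is why the encoder later calls ctc.softmax(x) instead of ctc.softmax(logit). A minimal stand-in illustrating the new call pattern; TinyCTC and the shapes are assumptions, not the repo's class:

import torch
import torch.nn as nn
import torch.nn.functional as F

class TinyCTC(nn.Module):
    def __init__(self, dim, vocab):
        super().__init__()
        self.ctc_projection = nn.Linear(dim, vocab)

    def forward(self, x):
        return self.ctc_projection(x)

    def softmax(self, x, temperature=1.0):
        # callers now pass hidden states, not logits
        return F.softmax(self.ctc_projection(x) / temperature, dim=-1)

x = torch.randn(50, 4, 256)   # (T, B, D) hidden states
ctc = TinyCTC(256, 1000)
logit = ctc(x)                # fed to the CTC loss
prob = ctc.softmax(x)         # fed to the adapter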
......
@@ -25,7 +25,7 @@ logger = logging.getLogger(__name__)
 class Adapter(nn.Module):
-    def __init__(self, args, dictionary, embed_tokens):
+    def __init__(self, args, dictionary, embed_tokens=None):
         super().__init__()
         embed_dim = args.encoder_embed_dim
@@ -45,7 +45,7 @@ class Adapter(nn.Module):
         if self.adapter_type in ["embed", "context", "league", "gated_league", "gated_league2"]:
             if embed_tokens is None:
                 num_embeddings = len(dictionary)
-                self.embed_adapter = Embedding(num_embeddings, embed_dim, self.padding_idx)
+                self.embed_adapter = Embedding(num_embeddings, embed_dim, dictionary.pad())
             else:
                 self.embed_adapter = embed_tokens
@@ -115,9 +115,9 @@ class Adapter(nn.Module):
         return out, padding
-@register_model("multi_ctc_s2t_transformer")
-class MultiCTCS2TTransformerModel(S2TTransformerModel):
-    """Speech-to-Text Transformer with multiple CTC Loss in different layers"""
+@register_model("inter_ctc_s2t_transformer")
+class InterCTCS2TTransformerModel(S2TTransformerModel):
+    """Speech-to-Text Transformer with intermedia CTC Loss in different layers"""
     def __init__(self, encoder, decoder):
         super().__init__(encoder, decoder)
@@ -126,10 +126,10 @@ class MultiCTCS2TTransformerModel(S2TTransformerModel):
     def add_args(parser):
         S2TTransformerModel.add_args(parser)
         parser.add_argument(
-            "--multi-ctc-layers",
+            "--intermedia-ctc-layers",
             default=None,
             type=str,
-            help="the position of the ctc loss, separated by ",
+            help="the position of the ctc loss, separated by comma ",
         )
         parser.add_argument(
             "--adapter",
@@ -147,7 +147,7 @@ class MultiCTCS2TTransformerModel(S2TTransformerModel):
     @classmethod
     def build_encoder(cls, args, task=None, embed_tokens=None):
-        encoder = S2TMultiCTCTransformerEncoder(args, task, embed_tokens)
+        encoder = S2TInterCTCTransformerEncoder(args, task, embed_tokens)
         if getattr(args, "load_pretrained_encoder_from", None):
             encoder = checkpoint_utils.load_pretrained_component_from_model(
                 component=encoder, checkpoint=args.load_pretrained_encoder_from, strict=False
@@ -159,40 +159,34 @@ class MultiCTCS2TTransformerModel(S2TTransformerModel):
         return encoder
-class S2TMultiCTCTransformerEncoder(S2TTransformerEncoder):
-    """Speech-to-text Transformer encoder that consists of multiple input subsampler and
-    Conformer encoder."""
+class S2TInterCTCTransformerEncoder(S2TTransformerEncoder):
+    """Speech-to-text Transformer encoder that consists of intermedia ctc losses """
     def __init__(self, args, task=None, embed_tokens=None):
         super().__init__(args, task, embed_tokens)
-        if self.use_ctc:
-            del self.ctc
-        self.multi_ctc_layers = []
-        if args.multi_ctc_layers is not None:
-            multi_ctc_layers = args.multi_ctc_layers.split(",")
-            for layer_idx in multi_ctc_layers:
+        self.intermedia_ctc_layers = []
+        if args.intermedia_ctc_layers is not None:
+            intermedia_ctc_layers = args.intermedia_ctc_layers.split(",")
+            for layer_idx in intermedia_ctc_layers:
                 layer_idx = int(layer_idx)
                 if layer_idx <= 0:
                     layer_idx += args.encoder_layers
-                self.multi_ctc_layers.append(layer_idx)
-                inter_ctc = True if layer_idx != args.encoder_layers else False
-                if inter_ctc:
-                    logger.info("Intermedia CTC loss in layer %d" % layer_idx)
+                self.intermedia_ctc_layers.append(layer_idx)
+                logger.info("Intermedia CTC loss in layer %d" % layer_idx)
                 ctc = CTC(args.encoder_embed_dim,
                           dictionary_size=len(task.source_dictionary),
                           dropout=args.dropout,
-                          need_layernorm=inter_ctc)
+                          need_layernorm=True)
                 if task.source_dictionary == task.target_dictionary and embed_tokens is not None:
-                    ctc.ctc_projection.weight = embed_tokens.weight
+                    ctc.ctc_projection.weight = self.ctc.ctc_projection.weight
+                ctc.LayerNorm = self.layer_norm
                 setattr(self, f"ctc{layer_idx}", ctc)
-                if inter_ctc:
-                    adapter = Adapter(args, task.source_dictionary, ctc.ctc_projection)
+                adapter = Adapter(args, task.source_dictionary)
+                # adapter = Adapter(args, task.source_dictionary, ctc.ctc_projection)
                 setattr(self, f"adapter{layer_idx}", adapter)
     def forward(self, src_tokens, src_lengths):
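Note: --intermedia-ctc-layers is a comma-separated string and, per the loop above, zero or negative indices count back from the last encoder layer. A small sketch of that parsing (parse_ctc_layers is illustrative, not repo code):

def parse_ctc_layers(spec, encoder_layers=12):
    layers = []
    for idx in spec.split(","):
        idx = int(idx)
        if idx <= 0:          # e.g. "-2" -> 10 when encoder_layers == 12
            idx += encoder_layers
        layers.append(idx)
    return layers

print(parse_ctc_layers("6,8,10"))  # [6, 8, 10]
print(parse_ctc_layers("-2,0"))    # [10, 12]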
@@ -223,7 +217,8 @@ class S2TMultiCTCTransformerEncoder(S2TTransformerEncoder):
            self.history.push(x)
        layer_idx = 0
-       ctc_logit = []
+       ctc_logit = None
+       intermedia_ctc_logit = []
        for layer in self.layers:
            layer_idx += 1
@@ -234,14 +229,14 @@ class S2TMultiCTCTransformerEncoder(S2TTransformerEncoder):
            x = layer(x, encoder_padding_mask, pos_emb=positions)
            # interleave CTC
-           if self.use_ctc and layer_idx in self.multi_ctc_layers and layer_idx != len(self.layers):
+           if layer_idx in self.intermedia_ctc_layers:
                ctc = getattr(self, f"ctc{layer_idx}")
                adapter = getattr(self, f"adapter{layer_idx}")
                logit = ctc(x)
-               prob = ctc.softmax(logit)
+               prob = ctc.softmax(x)
                x, encoder_padding_mask = adapter([x, prob], encoder_padding_mask)
-               ctc_logit.append(ctc(x))
+               intermedia_ctc_logit.append(logit)
            if layer_idx != len(self.layers) \
                    and self.interleaved_dropout is not None \
@@ -257,13 +252,13 @@ class S2TMultiCTCTransformerEncoder(S2TTransformerEncoder):
        if self.layer_norm is not None:
            x = self.layer_norm(x)
-       if self.use_ctc and len(self.layers) in self.multi_ctc_layers:
-           ctc = getattr(self, f"ctc{len(self.layers)}")
-           ctc_logit.append(ctc(x))
+       if self.use_ctc:
+           ctc_logit = self.ctc(x)
        return {
            "encoder_out": [x],  # T x B x C
-           "ctc_logit": ctc_logit,  # B x T x C
+           "ctc_logit": [] if ctc_logit is None else [ctc_logit],  # B x T x C
+           "intermedia_ctc_logit": intermedia_ctc_logit,  # B x T x C
            "encoder_padding_mask": [encoder_padding_mask],  # B x T
            "encoder_embedding": [],  # B x T x C
            "encoder_states": [],  # List[T x B x C]
@@ -272,7 +267,7 @@ class S2TMultiCTCTransformerEncoder(S2TTransformerEncoder):
        }
-@register_model_architecture(model_name="multi_ctc_s2t_transformer", arch_name="multi_ctc_s2t_transformer")
+@register_model_architecture(model_name="inter_ctc_s2t_transformer", arch_name="inter_ctc_s2t_transformer")
 def base_architecture(args):
     # Convolutional subsampler
     args.conv_kernel_sizes = getattr(args, "conv_kernel_sizes", "5,5")
@@ -321,7 +316,7 @@ def base_architecture(args):
     args.quant_noise_pq = getattr(args, "quant_noise_pq", 0)
     # CTC
-    args.multi_ctc_layers = getattr(args, "multi_ctc_layers", 0)
+    args.intermedia_ctc_layers = getattr(args, "intermedia_ctc_layers", 0)
     # Conformer
     args.macaron_style = getattr(args, "macaron_style", False)
@@ -356,13 +351,13 @@ def base_architecture(args):
     args.cl_dropout_strategy = getattr(args, "cl_dropout_strategy", "linear")
-@register_model_architecture("multi_ctc_s2t_transformer", "multi_ctc_s2t_transformer_s")
-def multi_ctc_s2t_transformer_s(args):
+@register_model_architecture("inter_ctc_s2t_transformer", "inter_ctc_s2t_transformer_s")
+def inter_ctc_s2t_transformer_s(args):
     args.encoder_embed_dim = getattr(args, "encoder_embed_dim", 256)
     args.encoder_ffn_embed_dim = getattr(args, "encoder_ffn_embed_dim", 256 * 8)
     args.encoder_attention_heads = getattr(args, "encoder_attention_heads", 4)
     args.decoder_attention_heads = getattr(args, "decoder_attention_heads", 4)
     args.dropout = getattr(args, "dropout", 0.1)
-    args.multi_ctc_layers = getattr(args, "multi_ctc_layers", None)
+    args.intermedia_ctc_layers = getattr(args, "intermedia_ctc_layers", None)
     base_architecture(args)
@@ -170,7 +170,7 @@ class Adapter(nn.Module):
         if self.adapter_type in ["embed", "context", "league", "gated_league", "gated_league2"]:
             if embed_tokens is None:
                 num_embeddings = len(dictionary)
-                self.embed_adapter = Embedding(num_embeddings, embed_dim, self.padding_idx)
+                self.embed_adapter = Embedding(num_embeddings, embed_dim, dictionary.pad())
             else:
                 self.embed_adapter = embed_tokens
......