Commit 31e7c426 by xuchen

fix the bug in the intermedia CTC losses

parent 2215ade0
arch: multi_ctc_s2t_transformer_s
multi-ctc-layers: 6,8,10,12
intermedia-ctc-layers: 6,8,10
share-decoder-input-output-embed: True
optimizer: adam
clip-norm: 10.0
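For reference, intermedia-ctc-layers takes a comma-separated list of encoder layer indices. A minimal sketch of how the encoder parses it, mirroring the loop further down in this commit (parse_intermedia_ctc_layers is an illustrative helper, not part of the repo):

def parse_intermedia_ctc_layers(spec, encoder_layers):
    # non-positive indices count back from the last encoder layer
    layers = []
    for layer_idx in spec.split(","):
        layer_idx = int(layer_idx)
        if layer_idx <= 0:
            layer_idx += encoder_layers
        layers.append(layer_idx)
    return layers

# with the 12-layer encoder configured above: "6,8,10" -> [6, 8, 10]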
@@ -29,21 +29,21 @@ cat $GEN | cut -f 3 > $REF
cat $GEN | cut -f 4 > $SYS
#detokenize the decoded file so it can be re-tokenized in a consistent manner
perl $detokenizer -l de < $SYS > $SYS.dtk
perl $detokenizer -l de < $REF > $REF.dtk
$detokenizer -l de < $SYS > $SYS.dtk
$detokenizer -l de < $REF > $REF.dtk
#replace unicode punctuation
perl $replace_unicode_punctuation -l de < $SYS.dtk > $SYS.dtk.punc
perl $replace_unicode_punctuation -l de < $REF.dtk > $REF.dtk.punc
$replace_unicode_punctuation -l de < $SYS.dtk > $SYS.dtk.punc
$replace_unicode_punctuation -l de < $REF.dtk > $REF.dtk.punc
#tokenize the decoded file with the Moses tokenizer.perl
perl $tokenizer -l de < $SYS.dtk.punc > $SYS.dtk.punc.tok
perl $tokenizer -l de < $REF.dtk.punc > $REF.dtk.punc.tok
$tokenizer -l de < $SYS.dtk.punc > $SYS.dtk.punc.tok
$tokenizer -l de < $REF.dtk.punc > $REF.dtk.punc.tok
#"rich-text format" --> rich ##AT##-##AT## text format.
perl -ple 's{(\S)-(\S)}{$1 ##AT##-##AT## $2}g' < $SYS.dtk.punc.tok > $SYS.dtk.punc.tok.atat
perl -ple 's{(\S)-(\S)}{$1 ##AT##-##AT## $2}g' < $REF.dtk.punc.tok > $REF.dtk.punc.tok.atat
perl $multi_bleu $REF.dtk.punc.tok.atat < $SYS.dtk.punc.tok.atat
$multi_bleu $REF.dtk.punc.tok.atat < $SYS.dtk.punc.tok.atat
rm -f $SYS.dtk $SYS.dtk.punc $SYS.dtk.punc.tok $REF.dtk $REF.dtk.punc $REF.dtk.punc.tok
\ No newline at end of file
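The ##AT## step above splits intra-word hyphens so that they are scored as separate tokens by multi-bleu. A rough Python equivalent of the Perl one-liner, for illustration only:

import re

def split_hyphens(line):
    # "rich-text format" -> "rich ##AT##-##AT## text format"
    return re.sub(r"(\S)-(\S)", r"\1 ##AT##-##AT## \2", line)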
@@ -19,7 +19,8 @@ from .label_smoothed_cross_entropy import LabelSmoothedCrossEntropyCriterion
class LabelSmoothedCrossEntropyCriterionWithCTC(
LabelSmoothedCrossEntropyCriterion
):
def __init__(self, task, sentence_avg, label_smoothing, post_process="letter", ctc_weight=0.0):
def __init__(self, task, sentence_avg, label_smoothing, post_process="letter",
ctc_weight=0.0, intermedia_ctc_weight=0.0):
super().__init__(task, sentence_avg, label_smoothing)
self.blank_idx = task.target_dictionary.index(task.blank_symbol) if hasattr(task, 'blank_symbol') else 0
self.pad_idx = task.target_dictionary.pad()
@@ -29,7 +30,8 @@ class LabelSmoothedCrossEntropyCriterionWithCTC(
assert 0 <= ctc_weight
self.ctc_weight = ctc_weight
if self.ctc_weight > 0:
self.intermedia_ctc_weight = intermedia_ctc_weight
if self.ctc_weight > 0 or self.intermedia_ctc_weight > 0:
assert getattr(task, "src_dict", None) is not None, "CTC needs a source dictionary."
self.post_process = post_process
self.ctc_loss = torch.nn.CTCLoss(blank=self.blank_idx, reduction="sum", zero_infinity=True)
@@ -52,6 +54,13 @@ class LabelSmoothedCrossEntropyCriterionWithCTC(
help="weight of CTC loss",
)
parser.add_argument(
"--intermedia-ctc-weight",
default=0.0,
type=float,
metavar="D",
help="weight of intermedia CT loss",
)
parser.add_argument(
"--post-process",
default="letter",
type=str,
@@ -91,10 +100,9 @@ class LabelSmoothedCrossEntropyCriterionWithCTC(
logging_output["n_correct"] = utils.item(n_correct.data)
logging_output["total"] = utils.item(total.data)
if self.ctc_weight > 0:
if self.ctc_weight > 0 or self.intermedia_ctc_weight > 0:
ctc_loss, logging_output = self.compute_ctc_loss(model, sample, encoder_out, logging_output)
logging_output["ctc_loss"] = utils.item(ctc_loss.data)
loss = (1 - self.ctc_weight) * loss + self.ctc_weight * ctc_loss
loss = (1 - self.ctc_weight) * loss + ctc_loss
logging_output["loss"] = utils.item(loss.data) if reduce else loss.data
return loss, sample_size, logging_output
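In other words, compute_ctc_loss now returns the already-weighted CTC term, so the effective objective becomes (1 - ctc_weight) * cross-entropy + ctc_weight * ctc_loss + intermedia_ctc_weight * mean(intermedia losses). A minimal sketch of that combination (combine_losses is an illustrative helper, not part of the repo):

def combine_losses(ce_loss, ctc_loss, inter_losses, ctc_weight, intermedia_ctc_weight):
    # average the per-layer intermedia CTC losses, as in compute_ctc_loss below
    inter_mean = sum(inter_losses) / len(inter_losses) if inter_losses else 0.0
    # compute_ctc_loss returns the weighted sum of both CTC terms ...
    weighted_ctc = ctc_weight * ctc_loss + intermedia_ctc_weight * inter_mean
    # ... and forward() then mixes it with the down-weighted cross entropy
    return (1 - ctc_weight) * ce_loss + weighted_ctc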
@@ -114,10 +122,7 @@ class LabelSmoothedCrossEntropyCriterionWithCTC(
transcript_lengths = pad_mask.sum(-1)
ctc_loss = 0
ctc_num = len(encoder_out["ctc_logit"])
assert ctc_num != 0, "No ctc logit for loss!"
for i in range(ctc_num):
if "ctc_logit" in encoder_out and len(encoder_out["ctc_logit"]) > 0:
ctc_logit = encoder_out["ctc_logit"][0]
lprobs = model.get_normalized_probs(
[ctc_logit], log_probs=True
@@ -125,17 +130,41 @@ class LabelSmoothedCrossEntropyCriterionWithCTC(
lprobs.batch_first = False
with torch.backends.cudnn.flags(enabled=False):
loss = self.ctc_loss(
ctc_loss = self.ctc_loss(
lprobs,
targets_flat,
input_lengths,
transcript_lengths,
)
ctc_loss += loss
ctc_loss /= ctc_num
logging_output["ctc_loss"] = utils.item(ctc_loss.data)
logging_output["ctc_loss"] = utils.item(ctc_loss.data)
if not model.training:
intermedia_ctc_num = 0
intermedia_ctc_loss = 0
if "intermedia_ctc_logit" in encoder_out:
intermedia_ctc_num = len(encoder_out["intermedia_ctc_logit"])
if intermedia_ctc_num > 0:
for i in range(intermedia_ctc_num):
ctc_logit = encoder_out["intermedia_ctc_logit"][i]
inter_lprobs = model.get_normalized_probs(
[ctc_logit], log_probs=True
).contiguous() # (T, B, C) from the encoder
inter_lprobs.batch_first = False
with torch.backends.cudnn.flags(enabled=False):
loss = self.ctc_loss(
inter_lprobs,
targets_flat,
input_lengths,
transcript_lengths,
)
intermedia_ctc_loss += loss
intermedia_ctc_loss /= intermedia_ctc_num
logging_output["intermedia_ctc_loss"] = utils.item(intermedia_ctc_loss.data)
loss = self.ctc_weight * ctc_loss + self.intermedia_ctc_weight * intermedia_ctc_loss
if not model.training and self.ctc_weight > 0:
import editdistance
with torch.no_grad():
@@ -189,7 +218,7 @@ class LabelSmoothedCrossEntropyCriterionWithCTC(
logging_output["c_errors"] = c_err
logging_output["c_total"] = c_len
return ctc_loss, logging_output
return loss, logging_output
@staticmethod
def reduce_metrics(logging_outputs) -> None:
@@ -204,6 +233,9 @@ class LabelSmoothedCrossEntropyCriterionWithCTC(
ctc_loss_sum = utils.item(
sum(log.get("ctc_loss", 0) for log in logging_outputs)
)
inter_ctc_loss_sum = utils.item(
sum(log.get("intermedia_ctc_loss", 0) for log in logging_outputs)
)
ntokens = utils.item(sum(log.get("ntokens", 0) for log in logging_outputs))
sample_size = utils.item(
sum(log.get("sample_size", 0) for log in logging_outputs)
@@ -226,6 +258,13 @@ class LabelSmoothedCrossEntropyCriterionWithCTC(
sample_size,
round=3,
)
if inter_ctc_loss_sum > 0:
metrics.log_scalar(
"intermedia_ctc_loss",
inter_ctc_loss_sum / sample_size / math.log(2),
sample_size,
round=3,
)
metrics.log_derived(
"ppl", lambda meters: utils.get_perplexity(meters["nll_loss"].avg)
)
@@ -7,7 +7,7 @@ from .berard import * # noqa
from .ctc import * # noqa
from .convtransformer import * # noqa
from .s2t_transformer import * # noqa
from .multi_ctc_s2t_transformer import * # noqa
from .inter_ctc_s2t_transformer import * # noqa
from .s2t_conformer import * # noqa
from .pdss2t_transformer import * # noqa
from .s2t_sate import * # noqa
@@ -39,17 +39,14 @@ class CTC(nn.Module):
x = self.ctc_projection(self.ctc_dropout_module(x))
return x
@staticmethod
def softmax(ctc_logit, temperature=1.0):
return torch.nn.functional.softmax(ctc_logit / temperature, dim=-1)
def softmax(self, x, temperature=1.0):
return torch.nn.functional.softmax(self.ctc_projection(x) / temperature, dim=-1)
@staticmethod
def log_softmax(ctc_logit, temperature=1.0):
return torch.nn.functional.log_softmax(ctc_logit / temperature, dim=-1)
def log_softmax(self, x, temperature=1.0):
return torch.nn.functional.log_softmax(self.ctc_projection(x) / temperature, dim=-1)
@staticmethod
def argmax(ctc_logit):
return torch.argmax(ctc_logit, dim=-1)
def argmax(self, x):
return torch.argmax(self.ctc_projection(x), dim=-1)
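For orientation, a minimal self-contained sketch of the CTC helper after this hunk; the real module also handles the optional LayerNorm and fairseq's dropout module, so CTCSketch and its defaults are illustrative only:

import torch
import torch.nn as nn
import torch.nn.functional as F

class CTCSketch(nn.Module):
    """Minimal stand-in: softmax/log_softmax/argmax now take encoder states
    and apply the CTC projection themselves instead of expecting logits."""

    def __init__(self, embed_dim, dictionary_size, dropout=0.0):
        super().__init__()
        self.ctc_projection = nn.Linear(embed_dim, dictionary_size)
        self.ctc_dropout_module = nn.Dropout(dropout)

    def forward(self, x):
        # logits over the source dictionary
        return self.ctc_projection(self.ctc_dropout_module(x))

    def softmax(self, x, temperature=1.0):
        return F.softmax(self.ctc_projection(x) / temperature, dim=-1)

    def log_softmax(self, x, temperature=1.0):
        return F.log_softmax(self.ctc_projection(x) / temperature, dim=-1)

    def argmax(self, x):
        return torch.argmax(self.ctc_projection(x), dim=-1)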
class CTCCompressStrategy:
@@ -25,7 +25,7 @@ logger = logging.getLogger(__name__)
class Adapter(nn.Module):
def __init__(self, args, dictionary, embed_tokens):
def __init__(self, args, dictionary, embed_tokens=None):
super().__init__()
embed_dim = args.encoder_embed_dim
@@ -45,7 +45,7 @@ class Adapter(nn.Module):
if self.adapter_type in ["embed", "context", "league", "gated_league", "gated_league2"]:
if embed_tokens is None:
num_embeddings = len(dictionary)
self.embed_adapter = Embedding(num_embeddings, embed_dim, self.padding_idx)
self.embed_adapter = Embedding(num_embeddings, embed_dim, dictionary.pad())
else:
self.embed_adapter = embed_tokens
@@ -115,9 +115,9 @@ class Adapter(nn.Module):
return out, padding
@register_model("multi_ctc_s2t_transformer")
class MultiCTCS2TTransformerModel(S2TTransformerModel):
"""Speech-to-Text Transformer with multiple CTC Loss in different layers"""
@register_model("inter_ctc_s2t_transformer")
class InterCTCS2TTransformerModel(S2TTransformerModel):
"""Speech-to-Text Transformer with intermedia CTC Loss in different layers"""
def __init__(self, encoder, decoder):
super().__init__(encoder, decoder)
@@ -126,10 +126,10 @@ class MultiCTCS2TTransformerModel(S2TTransformerModel):
def add_args(parser):
S2TTransformerModel.add_args(parser)
parser.add_argument(
"--multi-ctc-layers",
"--intermedia-ctc-layers",
default=None,
type=str,
help="the position of the ctc loss, separated by ",
help="the position of the ctc loss, separated by comma ",
)
parser.add_argument(
"--adapter",
@@ -147,7 +147,7 @@ class MultiCTCS2TTransformerModel(S2TTransformerModel):
@classmethod
def build_encoder(cls, args, task=None, embed_tokens=None):
encoder = S2TMultiCTCTransformerEncoder(args, task, embed_tokens)
encoder = S2TInterCTCTransformerEncoder(args, task, embed_tokens)
if getattr(args, "load_pretrained_encoder_from", None):
encoder = checkpoint_utils.load_pretrained_component_from_model(
component=encoder, checkpoint=args.load_pretrained_encoder_from, strict=False
@@ -159,41 +159,35 @@ class MultiCTCS2TTransformerModel(S2TTransformerModel):
return encoder
class S2TMultiCTCTransformerEncoder(S2TTransformerEncoder):
"""Speech-to-text Transformer encoder that consists of multiple input subsampler and
Conformer encoder."""
class S2TInterCTCTransformerEncoder(S2TTransformerEncoder):
"""Speech-to-text Transformer encoder that consists of intermedia ctc losses """
def __init__(self, args, task=None, embed_tokens=None):
super().__init__(args, task, embed_tokens)
if self.use_ctc:
del self.ctc
self.multi_ctc_layers = []
if args.multi_ctc_layers is not None:
multi_ctc_layers = args.multi_ctc_layers.split(",")
for layer_idx in multi_ctc_layers:
self.intermedia_ctc_layers = []
if args.intermedia_ctc_layers is not None:
intermedia_ctc_layers = args.intermedia_ctc_layers.split(",")
for layer_idx in intermedia_ctc_layers:
layer_idx = int(layer_idx)
if layer_idx <= 0:
layer_idx += args.encoder_layers
self.multi_ctc_layers.append(layer_idx)
self.intermedia_ctc_layers.append(layer_idx)
inter_ctc = True if layer_idx != args.encoder_layers else False
if inter_ctc:
logger.info("Intermedia CTC loss in layer %d" % layer_idx)
logger.info("Intermedia CTC loss in layer %d" % layer_idx)
ctc = CTC(args.encoder_embed_dim,
dictionary_size=len(task.source_dictionary),
dropout=args.dropout,
need_layernorm=inter_ctc)
need_layernorm=True)
if task.source_dictionary == task.target_dictionary and embed_tokens is not None:
ctc.ctc_projection.weight = embed_tokens.weight
ctc.ctc_projection.weight = self.ctc.ctc_projection.weight
ctc.LayerNorm = self.layer_norm
setattr(self, f"ctc{layer_idx}", ctc)
if inter_ctc:
adapter = Adapter(args, task.source_dictionary, ctc.ctc_projection)
setattr(self, f"adapter{layer_idx}", adapter)
adapter = Adapter(args, task.source_dictionary)
# adapter = Adapter(args, task.source_dictionary, ctc.ctc_projection)
setattr(self, f"adapter{layer_idx}", adapter)
def forward(self, src_tokens, src_lengths):
@@ -223,7 +217,8 @@ class S2TMultiCTCTransformerEncoder(S2TTransformerEncoder):
self.history.push(x)
layer_idx = 0
ctc_logit = []
ctc_logit = None
intermedia_ctc_logit = []
for layer in self.layers:
layer_idx += 1
@@ -234,14 +229,14 @@ class S2TMultiCTCTransformerEncoder(S2TTransformerEncoder):
x = layer(x, encoder_padding_mask, pos_emb=positions)
# interleave CTC
if self.use_ctc and layer_idx in self.multi_ctc_layers and layer_idx != len(self.layers):
if layer_idx in self.intermedia_ctc_layers:
ctc = getattr(self, f"ctc{layer_idx}")
adapter = getattr(self, f"adapter{layer_idx}")
logit = ctc(x)
prob = ctc.softmax(logit)
prob = ctc.softmax(x)
x, encoder_padding_mask = adapter([x, prob], encoder_padding_mask)
ctc_logit.append(ctc(x))
intermedia_ctc_logit.append(logit)
if layer_idx != len(self.layers) \
and self.interleaved_dropout is not None \
@@ -257,13 +252,13 @@ class S2TMultiCTCTransformerEncoder(S2TTransformerEncoder):
if self.layer_norm is not None:
x = self.layer_norm(x)
if self.use_ctc and len(self.layers) in self.multi_ctc_layers:
ctc = getattr(self, f"ctc{len(self.layers)}")
ctc_logit.append(ctc(x))
if self.use_ctc:
ctc_logit = self.ctc(x)
return {
"encoder_out": [x], # T x B x C
"ctc_logit": ctc_logit, # B x T x C
"ctc_logit": [] if ctc_logit is None else [ctc_logit], # B x T x C
"intermedia_ctc_logit": intermedia_ctc_logit, # B x T x C
"encoder_padding_mask": [encoder_padding_mask], # B x T
"encoder_embedding": [], # B x T x C
"encoder_states": [], # List[T x B x C]
@@ -272,7 +267,7 @@ class S2TMultiCTCTransformerEncoder(S2TTransformerEncoder):
}
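To summarize the interleaved-CTC path added above: at every layer listed in --intermedia-ctc-layers the encoder projects the states, feeds the softmax probabilities through that layer's adapter, and keeps the logits for the intermedia loss. A minimal sketch of the per-layer step (apply_intermedia_ctc is an illustrative helper, not part of the repo):

def apply_intermedia_ctc(x, encoder_padding_mask, ctc, adapter, intermedia_ctc_logit):
    logit = ctc(x)                # T x B x V logits for this layer
    prob = ctc.softmax(x)         # probabilities consumed by the adapter
    x, encoder_padding_mask = adapter([x, prob], encoder_padding_mask)
    intermedia_ctc_logit.append(logit)   # later weighted by --intermedia-ctc-weight
    return x, encoder_padding_mask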
@register_model_architecture(model_name="multi_ctc_s2t_transformer", arch_name="multi_ctc_s2t_transformer")
@register_model_architecture(model_name="inter_ctc_s2t_transformer", arch_name="inter_ctc_s2t_transformer")
def base_architecture(args):
# Convolutional subsampler
args.conv_kernel_sizes = getattr(args, "conv_kernel_sizes", "5,5")
@@ -321,7 +316,7 @@ def base_architecture(args):
args.quant_noise_pq = getattr(args, "quant_noise_pq", 0)
# CTC
args.multi_ctc_layers = getattr(args, "multi_ctc_layers", 0)
args.intermedia_ctc_layers = getattr(args, "intermedia_ctc_layers", 0)
# Conformer
args.macaron_style = getattr(args, "macaron_style", False)
@@ -356,13 +351,13 @@ def base_architecture(args):
args.cl_dropout_strategy = getattr(args, "cl_dropout_strategy", "linear")
@register_model_architecture("multi_ctc_s2t_transformer", "multi_ctc_s2t_transformer_s")
def multi_ctc_s2t_transformer_s(args):
@register_model_architecture("inter_ctc_s2t_transformer", "inter_ctc_s2t_transformer_s")
def inter_ctc_s2t_transformer_s(args):
args.encoder_embed_dim = getattr(args, "encoder_embed_dim", 256)
args.encoder_ffn_embed_dim = getattr(args, "encoder_ffn_embed_dim", 256 * 8)
args.encoder_attention_heads = getattr(args, "encoder_attention_heads", 4)
args.decoder_attention_heads = getattr(args, "decoder_attention_heads", 4)
args.dropout = getattr(args, "dropout", 0.1)
args.multi_ctc_layers = getattr(args, "multi_ctc_layers", None)
args.intermedia_ctc_layers = getattr(args, "intermedia_ctc_layers", None)
base_architecture(args)
@@ -170,7 +170,7 @@ class Adapter(nn.Module):
if self.adapter_type in ["embed", "context", "league", "gated_league", "gated_league2"]:
if embed_tokens is None:
num_embeddings = len(dictionary)
self.embed_adapter = Embedding(num_embeddings, embed_dim, self.padding_idx)
self.embed_adapter = Embedding(num_embeddings, embed_dim, dictionary.pad())
else:
self.embed_adapter = embed_tokens