Commit 444a1f46 by xuchen

Daily revision; add consistency regularization for mixup

parent cabfc4ea
......@@ -8,9 +8,9 @@ best_checkpoint_metric: loss
maximize_best_checkpoint_metric: False
post-process: sentencepiece
no-epoch-checkpoints: True
#keep-last-epochs: 10
keep-best-checkpoints: 10
keep-last-epochs: 10
#no-epoch-checkpoints: True
#keep-best-checkpoints: 10
num-workers: 8
no-progress-bar: True
......
ctc-weight: 0.3
ctc-weight: 1.0
share-ctc-and-embed: True
share-target-ctc-and-embed: True
interleaved-ctc-weight: 0.2
interleaved-ctc-layers: 6,9
......@@ -9,7 +8,10 @@ interleaved-ctc-drop-prob: 0
sae-adapter: inter_league
sae-drop-prob: 0.0
sae-distribution-cutoff: 0
#sae-distribution-cutoff: 0
share-ctc-and-sae: False
share-interleaved-ctc: True
ctc-self-distill-weight: 0
ctc-self-distill-prob: 0
ctc-self-distill-temperature: 1
inter_mixup: True
inter_mixup_layer: -1
inter_mixup_prob: 1.0
inter_mixup_ratio: 0.2
inter_mixup_beta: 0.2
inter-mixup: True
inter-mixup-layer: -1
inter-mixup-prob: 1.0
inter-mixup-ratio: 1.0
inter-mixup-beta: 0.5
inter-mixup-keep-org: True
ctc-mixup-consistent-weight: 1
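
For context, a minimal sketch of how the inter-mixup-* options above are typically interpreted (all tensors and values here are illustrative, not taken from this repository): the interpolation coefficient is drawn from Beta(inter-mixup-beta, inter-mixup-beta), inter-mixup-ratio sets the fraction of the batch that is mixed, and inter-mixup-keep-org keeps the unmixed originals alongside the mixed samples.

```python
# Illustrative sketch only: a Beta-distributed coefficient mixes two utterances.
import torch
from torch.distributions import Beta

beta = 0.5                                   # inter-mixup-beta
dist = Beta(torch.tensor([beta]), torch.tensor([beta]))

x1 = torch.randn(100, 80)                    # (frames, features) of utterance 1
x2 = torch.randn(100, 80)                    # (frames, features) of utterance 2

coef = dist.sample().item()                  # interpolation coefficient in (0, 1)
x_mix = coef * x1 + (1 - coef) * x2          # mixed input fed to the encoder

# inter-mixup-ratio controls the fraction of the batch that is mixed;
# with inter-mixup-keep-org: True the unmixed originals stay in the batch,
# which is what the consistency regularization below compares against.
```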
......@@ -9,6 +9,7 @@ adam_betas: (0.9,0.98)
criterion: ctc
zero_infinity: True
ctc-weight: 1.0
encoder-embed-norm: True
encoder-no-scale-embedding: True
......
......@@ -300,7 +300,7 @@ if [ ${stage} -le 2 ] && [ ${stop_stage} -ge 2 ]; then
if [[ ! -f ${model_dir}/${dec_model} ]]; then
cmd="python ${code_dir}/scripts/average_checkpoints.py
--inputs ${model_dir}
--num-best-checkpoints ${n_average}
--num-epoch-checkpoints ${n_average}
--output ${model_dir}/${dec_model}"
echo -e "\033[34mRun command: \n${cmd} \033[0m"
[[ $eval -eq 1 ]] && eval $cmd
......@@ -328,6 +328,16 @@ if [ ${stage} -le 2 ] && [ ${stop_stage} -ge 2 ]; then
if [[ ${n_average} -ne 1 ]]; then
suffix=${suffix}_${n_average}
fi
if [[ ${ctc_infer} -eq 1 ]]; then
suffix=${suffix}_ctc
fi
if [[ ${ctc_self_ensemble} -eq 1 ]]; then
suffix=${suffix}_ensemble
fi
if [[ ${ctc_inter_logit} -ne 0 ]]; then
suffix=${suffix}_logit${ctc_inter_logit}
fi
result_file=${model_dir}/decode_result_${suffix}
[[ -f ${result_file} ]] && rm ${result_file}
......@@ -359,9 +369,9 @@ if [ ${stage} -le 2 ] && [ ${stop_stage} -ge 2 ]; then
cmd="${cmd}
--ctc-self-ensemble"
fi
if [[ ${ctc_inter_logit} -eq 1 ]]; then
if [[ ${ctc_inter_logit} -ne 0 ]]; then
cmd="${cmd}
--ctc-inter-logit"
--ctc-inter-logit ${ctc_inter_logit}"
fi
echo -e "\033[34mRun command: \n${cmd} \033[0m"
......
......@@ -11,9 +11,6 @@ from omegaconf import II
from typing import Optional
import numpy as np
import logging
import editdistance
import os
import sys
import torch
import torch.nn.functional as F
......@@ -22,7 +19,6 @@ from torch.distributions import Categorical
from fairseq import metrics, utils
from fairseq.criterions import FairseqCriterion, register_criterion
from fairseq.dataclass import FairseqDataclass
from fairseq.data.data_utils import post_process
from fairseq.tasks import FairseqTask
from fairseq.logging.meters import safe_round
......@@ -95,6 +91,10 @@ class CtcCriterionConfig(FairseqDataclass):
default=1,
metadata={"help": "temperature for ctc self distillation"},
)
ctc_mixup_consistent_weight: float = field(
default=0,
metadata={"help": "consistent regularization loss for mixup"},
)
wer_kenlm_model: Optional[str] = field(
default=None,
......@@ -178,10 +178,12 @@ class CtcCriterion(FairseqCriterion):
self.ctc_entropy = cfg.ctc_entropy
self.ctc_entropy_cutoff = cfg.ctc_entropy_cutoff
self.ctc_mixup_consistent_weight = cfg.ctc_mixup_consistent_weight
self.all_ctc_weight = self.ctc_weight + self.interleaved_ctc_weight + \
self.target_ctc_weight + self.target_interleaved_ctc_weight + \
self.ctc_self_distill_weight + self.target_ctc_self_distill_weight + \
self.ctc_entropy
self.ctc_entropy + self.ctc_mixup_consistent_weight
if self.all_ctc_weight > 0:
self.ctc_loss = torch.nn.CTCLoss(blank=self.blank_idx, reduction="sum", zero_infinity=True)
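
A minimal, self-contained sketch of how a torch.nn.CTCLoss configured like this is applied (shapes, tensors, and lengths below are illustrative only):

```python
import torch
import torch.nn.functional as F

blank_idx = 0
ctc_loss = torch.nn.CTCLoss(blank=blank_idx, reduction="sum", zero_infinity=True)

T, N, C = 50, 4, 32                          # frames, batch, vocabulary size
logits = torch.randn(T, N, C)                # encoder CTC logits
log_probs = F.log_softmax(logits, dim=-1)    # CTCLoss expects log-probabilities

targets = torch.randint(1, C, (N, 10))       # padded targets (no blank symbol)
input_lengths = torch.full((N,), T, dtype=torch.long)
target_lengths = torch.full((N,), 10, dtype=torch.long)

loss = ctc_loss(log_probs, targets, input_lengths, target_lengths)
```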
......@@ -342,6 +344,8 @@ class CtcCriterion(FairseqCriterion):
non_padding_mask = ~net_output["encoder_padding_mask"][0]
input_lengths = non_padding_mask.long().sum(-1)
nfeatures = input_lengths.sum().item()
logging_output["nfeatures"] = nfeatures
pad_mask = (tokens != self.pad_idx) & (tokens != self.eos_idx)
......@@ -551,13 +555,38 @@ class CtcCriterion(FairseqCriterion):
logging_output["target_ctc_self_distill_loss"] = utils.item(target_ctc_self_distill_loss.data)
ctc_self_distill_loss += target_ctc_self_distill_loss * self.target_ctc_self_distill_weight
ctc_mixup_consistent_loss = 0
if mixup is True and self.ctc_mixup_consistent_weight > 0:
mixup_pos = mixup_idx1 != mixup_idx2
ctc_logit = net_output["ctc_logit"][0]
mixup_real_logit = ctc_logit[:, mixup_pos, :]
mixup_real_idx1 = mixup_idx1[mixup_pos]
mixup_real_idx2 = mixup_idx2[mixup_pos]
no_mixup_logit = ctc_logit[:, ~mixup_pos, :]
mixup_target_logit = [no_mixup_logit[:, mixup_real_idx1, :], no_mixup_logit[:, mixup_real_idx2, :]]
mixup_target_pad_mask = [non_padding_mask[mixup_real_idx1], non_padding_mask[mixup_real_idx2]]
for logit, pad, coef in zip(mixup_target_logit, mixup_target_pad_mask, loss_coef):
loss = F.kl_div(
F.log_softmax(mixup_real_logit, dim=-1, dtype=torch.float32),
# F.log_softmax(teacher_logit / temperature, dim=-1, dtype=torch.float32),
F.log_softmax(logit.detach(), dim=-1, dtype=torch.float32),
log_target=True,
reduction="none",
)
ctc_mixup_consistent_loss += loss.sum(-1).transpose(0, 1).masked_fill_(~pad, 0.0).sum() * coef
logging_output["ctc_mixup_consistent_loss"] = utils.item(ctc_mixup_consistent_loss.data)
loss = \
self.ctc_weight * ctc_loss + \
self.interleaved_ctc_weight * interleaved_ctc_loss + \
self.target_ctc_weight * target_ctc_loss + \
self.target_interleaved_ctc_weight * target_interleaved_ctc_loss + \
ctc_self_distill_loss + \
self.ctc_entropy * ctc_entropy
self.ctc_entropy * ctc_entropy + \
self.ctc_mixup_consistent_weight * ctc_mixup_consistent_loss
logging_output["all_ctc_loss"] = utils.item(loss.data)
......@@ -577,7 +606,11 @@ class CtcCriterion(FairseqCriterion):
if lprobs is not None:
lprobs_t = lprobs.transpose(0, 1).float().contiguous().cpu()
if mixup:
idx = mixup_idx1 if mixup_coef > 0.5 else mixup_idx2
# idx = mixup_idx1 if mixup_coef > 0.5 else mixup_idx2
# tokens = tokens[idx]
no_mixup_idx = mixup_idx1 == mixup_idx2
idx = mixup_idx1[no_mixup_idx]
lprobs_t = lprobs_t[idx]
tokens = tokens[idx]
c_err, c_len, w_errs, w_len, wv_errs = encoder.ctc_valid(
......@@ -652,10 +685,15 @@ class CtcCriterion(FairseqCriterion):
target_ctc_self_distill_loss_sum = utils.item(
sum(log.get("target_ctc_self_distill_loss", 0) for log in logging_outputs)
)
ctc_mixup_consistent_loss = utils.item(
sum(log.get("ctc_mixup_consistent_loss", 0) for log in logging_outputs)
)
all_ctc_loss_sum = utils.item(
sum(log.get("all_ctc_loss", 0) for log in logging_outputs)
)
# loss_sum = utils.item(sum(log.get("loss", 0) for log in logging_outputs))
nfeatures = utils.item(sum(log.get("nfeatures", 0) for log in logging_outputs))
ntokens = utils.item(sum(log.get("ntokens", 0) for log in logging_outputs))
nsentences = utils.item(
sum(log.get("nsentences", 0) for log in logging_outputs)
......@@ -664,8 +702,6 @@ class CtcCriterion(FairseqCriterion):
sum(log.get("sample_size", 0) for log in logging_outputs)
)
if np.isnan(all_ctc_loss_sum) or np.isinf(all_ctc_loss_sum) or all_ctc_loss_sum < 0:
logger.warning("Illegal loss %f!" % all_ctc_loss_sum)
if all_ctc_loss_sum > 0:
if "loss" not in logging_outputs[0]:
metrics.log_scalar(
......@@ -692,7 +728,7 @@ class CtcCriterion(FairseqCriterion):
if ctc_entropy_sum > 0:
metrics.log_scalar(
"ctc_entropy",
ctc_entropy_sum / sample_size / math.log(2),
ctc_entropy_sum / nfeatures / math.log(2),
sample_size,
round=3,
)
......@@ -721,14 +757,21 @@ class CtcCriterion(FairseqCriterion):
if ctc_self_distill_loss_sum > 0:
metrics.log_scalar(
"ctc_self_distill_loss",
ctc_self_distill_loss_sum / sample_size / math.log(2),
ctc_self_distill_loss_sum / nfeatures / math.log(2),
sample_size,
round=3,
)
if target_ctc_self_distill_loss_sum > 0:
metrics.log_scalar(
"target_ctc_self_distill_loss_sum",
target_ctc_self_distill_loss_sum / sample_size / math.log(2),
target_ctc_self_distill_loss_sum / nfeatures / math.log(2),
sample_size,
round=3,
)
if ctc_mixup_consistent_loss > 0:
metrics.log_scalar(
"ctc_mixup_consistent_loss",
ctc_mixup_consistent_loss / nfeatures / math.log(2),
sample_size,
round=3,
)
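
For clarity, the values reported above are per-frame losses in bits: the summed loss is divided by the total number of encoder input frames (nfeatures) and by log 2. A toy example with made-up numbers:

```python
import math

# Hypothetical numbers purely for illustration
ctc_mixup_consistent_loss = 1.38e4   # loss summed over the batch
nfeatures = 1.0e4                    # total number of encoder input frames

# Value reported by metrics.log_scalar above: loss per frame, in bits
value = ctc_mixup_consistent_loss / nfeatures / math.log(2)
print(round(value, 3))               # ~1.991
```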
......
......@@ -11,9 +11,6 @@ from fairseq.models import (
register_model_architecture,
)
# from .s2t_sate import S2TSATEModel, S2TSATEEncoder
# from .s2t_transformer import S2TTransformerModel, S2TTransformerEncoder
# from .pdss2t_transformer import PDSS2TTransformerModel, PDSS2TTransformerEncoder
from fairseq.models.speech_to_text import (
S2TTransformerModel,
S2TTransformerEncoder,
......@@ -128,6 +125,9 @@ class S2TCTCEncoder(FairseqEncoder):
def set_ctc_infer(self, ctc_infer, post_process, src_dict=None, tgt_dict=None, path=None):
self.encoder.set_ctc_infer(ctc_infer, post_process, src_dict=src_dict, tgt_dict=tgt_dict, path=path)
def ctc_valid(self, lprobs, targets, input_lengths, dictionary, lang="source"):
return self.encoder.ctc_valid(lprobs, targets, input_lengths, dictionary, lang)
def forward(self, src_tokens, src_lengths, **kwargs):
return self.encoder(src_tokens, src_lengths, **kwargs)
......@@ -208,6 +208,9 @@ class CTCDecoder(object):
encoder_outs = self.model(src_tokens=src_tokens,
src_lengths=src_lengths)
if "target_ctc_logit" in encoder_outs:
ctc_logit = encoder_outs["target_ctc_logit"][0].transpose(0, 1)
else:
ctc_logit = encoder_outs["ctc_logit"][0].transpose(0, 1)
inter_logits = encoder_outs.get("interleaved_ctc_logits", [])
inter_logits_num = len(inter_logits)
......@@ -357,9 +360,9 @@ def base_architecture(args):
# Semantics-augmented Encoding (sae)
args.sae_adapter = getattr(args, "sae_adapter", "none")
args.target_sae_adapter = getattr(args, "target_sae_adapter", args.sae_adapter)
args.share_sae_and_ctc = getattr(args, "share_sae_and_ctc", False)
args.sae_embed_norm = getattr(args, "sae_embed_norm", False)
args.sae_out_norm = getattr(args, "sae_out_norm", False)
args.share_target_sae_and_ctc = getattr(args, "share_target_sae_and_ctc", False)
args.sae_drop_prob = getattr(args, "sae_drop_prob", 0)
args.sae_distribution_cutoff = getattr(args, "sae_distribution_cutoff", None)
args.sae_distribution_hard = getattr(args, "sae_distribution_hard", False)
......@@ -370,7 +373,8 @@ def base_architecture(args):
args.inter_mixup_layer = getattr(args, "inter_mixup_layer", None)
args.inter_mixup_beta = getattr(args, "inter_mixup_beta", 0.5)
args.inter_mixup_prob = getattr(args, "inter_mixup_prob", 1)
args.inter_mixup_ratio = getattr(args, "inter_mixup_ratio", 1)
args.inter_mixup_ratio = getattr(args, "inter_mixup_ratio", 0.3)
args.inter_mixup_keep_org = getattr(args, "inter_mixup_keep_org", False)
# PDS
args.pds_stages = getattr(args, "pds_stages", None)
......
......@@ -296,7 +296,7 @@ class TextualEncoder(FairseqEncoder):
if self.inter_ctc:
logger.info("Target CTC loss in layer %d" % self.ctc_layer)
self.ctc = CTC(embed_dim,
dictionary_size=embed_tokens.num_embeddings,
dictionary_size=embed_tokens.num_embeddings if embed_tokens is not None else len(dictionary),
dropout=args.dropout,
need_layernorm=True if self.inter_ctc else False)
......
......@@ -515,6 +515,11 @@ class S2TTransformerModel(FairseqEncoderDecoderModel):
type=float,
help="the ratio of mixup",
)
parser.add_argument(
"--inter-mixup-keep-org",
action="store_true",
help="keep original batch",
)
pass
@classmethod
......@@ -745,15 +750,16 @@ class S2TTransformerEncoder(FairseqEncoder):
# mixup
self.mixup = getattr(args, "inter_mixup", False)
if self.mixup:
self.mixup_layer = int(args.inter_mixup_layer)
self.mixup_prob = float(args.inter_mixup_prob)
self.mixup_ratio = float(args.inter_mixup_ratio)
self.mixup_layer = args.inter_mixup_layer
self.mixup_prob = args.inter_mixup_prob
self.mixup_ratio = args.inter_mixup_ratio
self.mixup_keep_org = args.inter_mixup_keep_org
beta = float(args.inter_mixup_beta)
beta = args.inter_mixup_beta
from torch.distributions import Beta
self.beta = Beta(torch.Tensor([beta]), torch.Tensor([beta]))
logger.info("Use mixup in layer %d with beta %.2f, prob %.2f, ratio %.2f." % (
self.mixup_layer, beta, self.mixup_prob, self.mixup_ratio))
logger.info("Use mixup in layer %d with beta %.2f, prob %.2f, ratio %.2f, keep original data %r." % (
self.mixup_layer, beta, self.mixup_prob, self.mixup_ratio, self.mixup_keep_org))
# gather cosine similarity
self.gather_cos_sim = getattr(args, "gather_cos_sim", False)
......@@ -812,15 +818,24 @@ class S2TTransformerEncoder(FairseqEncoder):
def apply_mixup(self, x, encoder_padding_mask):
batch = x.size(1)
indices = np.random.permutation(batch)
org_indices = np.arange(batch)
if self.mixup_ratio == 1:
if len(indices) % 2 != 0:
indices = np.append(indices, (indices[-1]))
idx1 = indices[0::2]
idx2 = indices[1::2]
if self.mixup_keep_org:
idx1 = np.append(org_indices, idx1)
idx2 = np.append(org_indices, idx2)
else:
mix_size = int(max(2, batch * self.mixup_ratio // 2 * 2))
mix_indices = indices[: mix_size]
if self.mixup_keep_org:
idx1 = np.append(org_indices, mix_indices[0::2])
idx2 = np.append(org_indices, mix_indices[1::2])
else:
idx1 = np.append(mix_indices[0::2], (indices[mix_size:]))
idx2 = np.append(mix_indices[1::2], (indices[mix_size:]))
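
A standalone sketch of the pairing logic above with illustrative values: when mixup_ratio < 1, an even-sized subset of the shuffled batch is paired for mixing, and mixup_keep_org prepends the original indices so the unmixed samples remain in the batch.

```python
import numpy as np

batch = 7
mixup_ratio = 0.5
mixup_keep_org = True

indices = np.random.permutation(batch)
org_indices = np.arange(batch)

# even-sized subset of the shuffled batch that will be mixed
mix_size = int(max(2, batch * mixup_ratio // 2 * 2))
mix_indices = indices[:mix_size]

if mixup_keep_org:
    # originals first, then the pairs to be mixed
    idx1 = np.append(org_indices, mix_indices[0::2])
    idx2 = np.append(org_indices, mix_indices[1::2])
else:
    # mixed pairs plus the untouched remainder of the batch
    idx1 = np.append(mix_indices[0::2], indices[mix_size:])
    idx2 = np.append(mix_indices[1::2], indices[mix_size:])

# positions where idx1 != idx2 are mixed; equal positions keep the original sample
print(idx1, idx2)
```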
......@@ -928,7 +943,7 @@ class S2TTransformerEncoder(FairseqEncoder):
interleaved_ctc_logits = []
if self.training and self.mixup and layer_idx == self.mixup_layer:
if torch.rand(1) < self.mixup_prob:
if torch.rand(1) <= self.mixup_prob:
x, encoder_padding_mask, input_lengths, mixup = self.apply_mixup(x, encoder_padding_mask)
self.show_debug(x, "x before encoding")
......@@ -1209,7 +1224,8 @@ def base_architecture(args):
args.inter_mixup_layer = getattr(args, "inter_mixup_layer", None)
args.inter_mixup_beta = getattr(args, "inter_mixup_beta", 0.5)
args.inter_mixup_prob = getattr(args, "inter_mixup_prob", 1)
args.inter_mixup_ratio = getattr(args, "inter_mixup_ratio", 1)
args.inter_mixup_ratio = getattr(args, "inter_mixup_ratio", 0.3)
args.inter_mixup_keep_org = getattr(args, "inter_mixup_keep_org", False)
@register_model_architecture("s2t_transformer", "s2t_transformer_s")
......