Commit 6e37ffba by xuchen

Cumulative updates.

Implement mix-speech (inter-sample mixup in the encoder and decoder) and optimize the CTC entropy loss.
parent 6358474e
...@@ -398,7 +398,7 @@ if [ ${stage} -le 2 ] && [ ${stop_stage} -ge 2 ]; then ...@@ -398,7 +398,7 @@ if [ ${stage} -le 2 ] && [ ${stop_stage} -ge 2 ]; then
cd .. cd ..
echo "CTC WER" >> ${result_file} echo "CTC WER" >> ${result_file}
tail -n 1 ${src_ctc} >> ${result_file} tail -n 2 ${src_ctc} >> ${result_file}
src_bleu=$(mktemp -t temp.record.XXXXXX) src_bleu=$(mktemp -t temp.record.XXXXXX)
cd local cd local
......
arch: s2t_dynamic_transformer
condensation-metric: ratio
#condensation-metric: threshold
condensation-mode: create
##condensation-mode: mask
#condensation-layers: 3,6,9
condensation-threshold: 0.95
#condensation-ratio: 0.8
share-ctc-and-embed: True
interleaved-ctc-weight: 0.2
interleaved-ctc-layers: 6,9
share-interleaved-ctc: True
\ No newline at end of file
inter-mixup: True
inter-mixup-layer: -1
inter-mixup-decoder-layer: 0
inter-mixup-prob: 1.0
inter-mixup-ratio: 1.0
inter-mixup-beta: 0.5
inter-mixup-keep-org: True
inter-mixup-decoder-emb: True
cal-mixup-loss: True
ctc-mixup-consistent-weight: 0
mixup-consistent-weight: 0
\ No newline at end of file
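The mixup options in this new config map onto a simple interpolation scheme: inter-mixup-beta parameterizes the Beta distribution the mixing coefficient is drawn from, while inter-mixup-ratio and inter-mixup-keep-org control how many mixed sentences are created and whether the original batch is kept. A minimal sketch of the coefficient sampling and feature blending, with all tensor names hypothetical:

```python
import torch
from torch.distributions import Beta

def mix_features(x1, x2, beta=0.5):
    """Blend two batches of features, one coefficient per sentence.

    Sketch of the inter-mixup idea: x1 and x2 are (T, B, C) tensors holding
    the two parents of each mixed sentence; the coefficient is drawn from
    Beta(beta, beta), as inter-mixup-beta suggests.
    """
    coef = Beta(beta, beta).sample((x1.size(1),)).to(x1)  # (B,)
    coef = coef.view(1, -1, 1)                            # broadcast over time and channels
    return coef * x1 + (1 - coef) * x2

# usage: mix a toy batch with a shuffled copy of itself
x = torch.randn(50, 8, 256)              # (time, batch, dim)
perm = torch.randperm(x.size(1))
mixed = mix_features(x, x[:, perm], beta=0.5)
```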
...@@ -5,10 +5,13 @@ condensation-metric: ratio ...@@ -5,10 +5,13 @@ condensation-metric: ratio
condensation-mode: create condensation-mode: create
##condensation-mode: mask ##condensation-mode: mask
#condensation-layers: 3,6,9 #condensation-layers: 3,6,9
condensation-threshold: 0.9 condensation-threshold: 0.95
condensation-ratio: 0.8 #condensation-ratio: 0.8
share-ctc-and-embed: True share-ctc-and-embed: True
interleaved-ctc-weight: 0.2 interleaved-ctc-weight: 0.2
interleaved-ctc-layers: 3,6,9 interleaved-ctc-layers: 6,9
share-interleaved-ctc: True share-interleaved-ctc: True
ctc-entropy-weight: 0.0
ctc-entropy-cutoff: 0
\ No newline at end of file
...@@ -34,5 +34,5 @@ sae-ctc-temperature: 1 ...@@ -34,5 +34,5 @@ sae-ctc-temperature: 1
#ctc-self-distill-prob: 0.1 #ctc-self-distill-prob: 0.1
#cal-all-ctc: True #cal-all-ctc: True
use-aligned-text: True # use-aligned-text: True
aligned-target-ctc: True # aligned-target-ctc: True
inter-mixup: True inter-mixup: True
inter-mixup-layer: -1 inter-mixup-layer: -1
inter-mixup-decoder-layer: 0
inter-mixup-prob: 1.0 inter-mixup-prob: 1.0
inter-mixup-ratio: 1.0 inter-mixup-ratio: 1.0
inter-mixup-beta: 0.5 inter-mixup-beta: 0.5
inter-mixup-keep-org: False inter-mixup-keep-org: True
inter-mixup-decoder-emb: True
cal-mixup-loss: True
ctc-mixup-consistent-weight: 0 ctc-mixup-consistent-weight: 0
mixup-consistent-weight: 0 mixup-consistent-weight: 0
\ No newline at end of file
...@@ -14,6 +14,9 @@ label_smoothing: 0.1 ...@@ -14,6 +14,9 @@ label_smoothing: 0.1
encoder-normalize-before: True encoder-normalize-before: True
decoder-normalize-before: True decoder-normalize-before: True
encoder-embed-norm: True
encoder-no-scale-embedding: True
subsampling-type: conv1d subsampling-type: conv1d
subsampling-layers: 2 subsampling-layers: 2
subsampling-filter: 2048 subsampling-filter: 2048
...@@ -36,7 +39,7 @@ decoder-ffn-embed-dim: 2048 ...@@ -36,7 +39,7 @@ decoder-ffn-embed-dim: 2048
decoder-attention-heads: 8 decoder-attention-heads: 8
acoustic-encoder: transformer acoustic-encoder: transformer
adapter: league adapter: inter_league
#load-pretrained-encoder-from: #load-pretrained-encoder-from:
#load-pretrained-acoustic-encoder-from: #load-pretrained-acoustic-encoder-from:
......
...@@ -11,6 +11,9 @@ adam_betas: (0.9,0.98) ...@@ -11,6 +11,9 @@ adam_betas: (0.9,0.98)
criterion: label_smoothed_cross_entropy_with_ctc criterion: label_smoothed_cross_entropy_with_ctc
label_smoothing: 0.1 label_smoothing: 0.1
encoder-embed-norm: True
encoder-no-scale-embedding: True
encoder-normalize-before: True encoder-normalize-before: True
decoder-normalize-before: True decoder-normalize-before: True
subsampling-type: conv1d subsampling-type: conv1d
...@@ -34,7 +37,7 @@ decoder-ffn-embed-dim: 2048 ...@@ -34,7 +37,7 @@ decoder-ffn-embed-dim: 2048
decoder-attention-heads: 8 decoder-attention-heads: 8
acoustic-encoder: pds acoustic-encoder: pds
adapter: league adapter: inter_league
encoder-embed-dim: 512 encoder-embed-dim: 512
ctc-layer: 12 ctc-layer: 12
......
...@@ -11,8 +11,12 @@ adam_betas: (0.9,0.98) ...@@ -11,8 +11,12 @@ adam_betas: (0.9,0.98)
criterion: label_smoothed_cross_entropy_with_ctc criterion: label_smoothed_cross_entropy_with_ctc
label_smoothing: 0.1 label_smoothing: 0.1
encoder-embed-norm: True
encoder-no-scale-embedding: True
encoder-normalize-before: True encoder-normalize-before: True
decoder-normalize-before: True decoder-normalize-before: True
dropout: 0.1 dropout: 0.1
activation-fn: relu activation-fn: relu
encoder-ffn-embed-dim: 2048 encoder-ffn-embed-dim: 2048
...@@ -26,10 +30,10 @@ decoder-ffn-embed-dim: 2048 ...@@ -26,10 +30,10 @@ decoder-ffn-embed-dim: 2048
decoder-attention-heads: 4 decoder-attention-heads: 4
acoustic-encoder: pds acoustic-encoder: pds
adapter: league adapter: inter_league
encoder-embed-dim: 256 encoder-embed-dim: 256
#ctc-layer: 12 ctc-layer: 12
pds-stages: 4 pds-stages: 4
pds-layers: 3_3_3_3 pds-layers: 3_3_3_3
pds-ratios: 2_2_1_2 pds-ratios: 2_2_1_2
......
#! /bin/bash
set -e set -e
ref=$1 ref=$1
......
#! /bin/bash
set -e set -e
infer_dir=$1 infer_dir=$1
......
#! /bin/bash
set -e set -e
infer_dir=$1 infer_dir=$1
......
#! /bin/bash
gpu_num=4 gpu_num=4
cmd="sh train.sh" cmd="sh train.sh"
......
...@@ -449,7 +449,7 @@ if [ ${stage} -le 2 ] && [ ${stop_stage} -ge 2 ]; then ...@@ -449,7 +449,7 @@ if [ ${stage} -le 2 ] && [ ${stop_stage} -ge 2 ]; then
echo "Source language" >> ${result_file} echo "Source language" >> ${result_file}
echo "CTC WER" >> ${result_file} echo "CTC WER" >> ${result_file}
tail -n 1 ${src_ctc} >> ${result_file} tail -n 2 ${src_ctc} >> ${result_file}
src_bleu=$(mktemp -t temp.record.XXXXXX) src_bleu=$(mktemp -t temp.record.XXXXXX)
cd local cd local
...@@ -475,7 +475,7 @@ if [ ${stage} -le 2 ] && [ ${stop_stage} -ge 2 ]; then ...@@ -475,7 +475,7 @@ if [ ${stage} -le 2 ] && [ ${stop_stage} -ge 2 ]; then
echo "Target language" >> ${result_file} echo "Target language" >> ${result_file}
echo "CTC WER" >> ${result_file} echo "CTC WER" >> ${result_file}
tail -n 1 ${tgt_ctc} >> ${result_file} tail -n 2 ${tgt_ctc} >> ${result_file}
tgt_bleu=$(mktemp -t temp.record.XXXXXX) tgt_bleu=$(mktemp -t temp.record.XXXXXX)
cd local cd local
......
...@@ -44,7 +44,7 @@ class CtcCriterionConfig(FairseqDataclass): ...@@ -44,7 +44,7 @@ class CtcCriterionConfig(FairseqDataclass):
default=0.0, default=0.0,
metadata={"help": "weight of CTC loss"}, metadata={"help": "weight of CTC loss"},
) )
ctc_entropy: float = field( ctc_entropy_weight: float = field(
default=0.0, default=0.0,
metadata={"help": "weight of CTC entropy"}, metadata={"help": "weight of CTC entropy"},
) )
...@@ -175,7 +175,7 @@ class CtcCriterion(FairseqCriterion): ...@@ -175,7 +175,7 @@ class CtcCriterion(FairseqCriterion):
self.ctc_self_distill_prob = float(cfg.ctc_self_distill_prob) self.ctc_self_distill_prob = float(cfg.ctc_self_distill_prob)
self.ctc_self_distill_temperature = float(cfg.ctc_self_distill_temperature) self.ctc_self_distill_temperature = float(cfg.ctc_self_distill_temperature)
self.ctc_entropy = cfg.ctc_entropy self.ctc_entropy_weight = cfg.ctc_entropy_weight
self.ctc_entropy_cutoff = cfg.ctc_entropy_cutoff self.ctc_entropy_cutoff = cfg.ctc_entropy_cutoff
self.ctc_mixup_consistent_weight = cfg.ctc_mixup_consistent_weight self.ctc_mixup_consistent_weight = cfg.ctc_mixup_consistent_weight
...@@ -183,7 +183,7 @@ class CtcCriterion(FairseqCriterion): ...@@ -183,7 +183,7 @@ class CtcCriterion(FairseqCriterion):
self.all_ctc_weight = self.ctc_weight + self.interleaved_ctc_weight + \ self.all_ctc_weight = self.ctc_weight + self.interleaved_ctc_weight + \
self.target_ctc_weight + self.target_interleaved_ctc_weight + \ self.target_ctc_weight + self.target_interleaved_ctc_weight + \
self.ctc_self_distill_weight + self.target_ctc_self_distill_weight + \ self.ctc_self_distill_weight + self.target_ctc_self_distill_weight + \
self.ctc_entropy + self.ctc_mixup_consistent_weight self.ctc_entropy_weight + self.ctc_mixup_consistent_weight
if self.all_ctc_weight > 0: if self.all_ctc_weight > 0:
self.ctc_loss = torch.nn.CTCLoss(blank=self.blank_idx, reduction="none", zero_infinity=True) self.ctc_loss = torch.nn.CTCLoss(blank=self.blank_idx, reduction="none", zero_infinity=True)
...@@ -375,6 +375,7 @@ class CtcCriterion(FairseqCriterion): ...@@ -375,6 +375,7 @@ class CtcCriterion(FairseqCriterion):
self.ctc_names = [] self.ctc_names = []
lprobs = None lprobs = None
target_lprobs = None target_lprobs = None
ctc_entropy = []
interleaved_ctc_num = 0 interleaved_ctc_num = 0
interleaved_ctc_loss = 0 interleaved_ctc_loss = 0
...@@ -393,6 +394,14 @@ class CtcCriterion(FairseqCriterion): ...@@ -393,6 +394,14 @@ class CtcCriterion(FairseqCriterion):
inter_ctc_logit = logit inter_ctc_logit = logit
inter_input_lengths = input_lengths inter_input_lengths = input_lengths
if self.ctc_entropy_weight > 0:
if self.ctc_entropy_cutoff != 0:
cut_ctc_logit = inter_ctc_logit.sort(dim=-1, descending=True)[0][:, :, 0:self.ctc_entropy_cutoff]
cut_ctc_logit = cut_ctc_logit / cut_ctc_logit.sum(dim=-1, keepdim=True)
ctc_entropy.append(Categorical(logits=cut_ctc_logit).entropy().sum())
else:
ctc_entropy.append(Categorical(logits=inter_ctc_logit).entropy().sum())
all_ctc_logits["interleaved_ctc_logit%d" % i] = [inter_ctc_logit, inter_input_lengths] all_ctc_logits["interleaved_ctc_logit%d" % i] = [inter_ctc_logit, inter_input_lengths]
inter_loss, inter_lprobs = self.get_ctc_loss( inter_loss, inter_lprobs = self.get_ctc_loss(
model, inter_ctc_logit, transcripts, inter_input_lengths, transcript_lengths, loss_coef) model, inter_ctc_logit, transcripts, inter_input_lengths, transcript_lengths, loss_coef)
...@@ -403,7 +412,6 @@ class CtcCriterion(FairseqCriterion): ...@@ -403,7 +412,6 @@ class CtcCriterion(FairseqCriterion):
logging_output["interleaved_ctc_loss"] = utils.item(interleaved_ctc_loss.data) logging_output["interleaved_ctc_loss"] = utils.item(interleaved_ctc_loss.data)
ctc_loss = 0 ctc_loss = 0
ctc_entropy = 0
use_ctc = False use_ctc = False
if self.ctc_weight > 0 and "ctc_logit" in net_output and len(net_output["ctc_logit"]) > 0: if self.ctc_weight > 0 and "ctc_logit" in net_output and len(net_output["ctc_logit"]) > 0:
use_ctc = True use_ctc = True
...@@ -413,15 +421,14 @@ class CtcCriterion(FairseqCriterion): ...@@ -413,15 +421,14 @@ class CtcCriterion(FairseqCriterion):
ctc_loss, lprobs = self.get_ctc_loss( ctc_loss, lprobs = self.get_ctc_loss(
model, ctc_logit, transcripts, input_lengths, transcript_lengths, loss_coef) model, ctc_logit, transcripts, input_lengths, transcript_lengths, loss_coef)
if self.ctc_entropy > 0: if self.ctc_entropy_weight > 0:
if self.ctc_entropy_cutoff != 0: if self.ctc_entropy_cutoff != 0:
cut_ctc_logit = ctc_logit.sort(dim=-1, descending=True)[0][:, :, 0:self.ctc_entropy_cutoff] cut_ctc_logit = ctc_logit.sort(dim=-1, descending=True)[0][:, :, 0:self.ctc_entropy_cutoff]
cut_ctc_logit = cut_ctc_logit / cut_ctc_logit.sum(dim=-1, keepdim=True) cut_ctc_logit = cut_ctc_logit / cut_ctc_logit.sum(dim=-1, keepdim=True)
ctc_entropy = Categorical(logits=cut_ctc_logit).entropy().sum() ctc_entropy.append(Categorical(logits=cut_ctc_logit).entropy().sum())
else: else:
ctc_entropy = Categorical(logits=ctc_logit).entropy().sum() ctc_entropy.append(Categorical(logits=ctc_logit).entropy().sum())
logging_output["ctc_entropy"] = utils.item(ctc_entropy.data)
logging_output["ctc_loss"] = utils.item(ctc_loss.data) logging_output["ctc_loss"] = utils.item(ctc_loss.data)
# calculate the target CTC loss # calculate the target CTC loss
...@@ -584,13 +591,19 @@ class CtcCriterion(FairseqCriterion): ...@@ -584,13 +591,19 @@ class CtcCriterion(FairseqCriterion):
ctc_mixup_consistent_loss += (loss.sum(-1).transpose(0, 1).masked_fill_(~pad, 0.0).sum(-1) * coef).sum() ctc_mixup_consistent_loss += (loss.sum(-1).transpose(0, 1).masked_fill_(~pad, 0.0).sum(-1) * coef).sum()
logging_output["ctc_mixup_consistent_loss"] = utils.item(ctc_mixup_consistent_loss.data) logging_output["ctc_mixup_consistent_loss"] = utils.item(ctc_mixup_consistent_loss.data)
if len(ctc_entropy) != 0:
ctc_entropy = sum(ctc_entropy) / len(ctc_entropy)
logging_output["ctc_entropy"] = utils.item(ctc_entropy.data)
else:
ctc_entropy = 0
loss = \ loss = \
self.ctc_weight * ctc_loss + \ self.ctc_weight * ctc_loss + \
self.interleaved_ctc_weight * interleaved_ctc_loss + \ self.interleaved_ctc_weight * interleaved_ctc_loss + \
self.target_ctc_weight * target_ctc_loss + \ self.target_ctc_weight * target_ctc_loss + \
self.target_interleaved_ctc_weight * target_interleaved_ctc_loss + \ self.target_interleaved_ctc_weight * target_interleaved_ctc_loss + \
ctc_self_distill_loss + \ ctc_self_distill_loss + \
self.ctc_entropy * ctc_entropy + \ self.ctc_entropy_weight * ctc_entropy + \
self.ctc_mixup_consistent_weight * ctc_mixup_consistent_loss self.ctc_mixup_consistent_weight * ctc_mixup_consistent_loss
if loss != 0: if loss != 0:
...@@ -600,10 +613,6 @@ class CtcCriterion(FairseqCriterion): ...@@ -600,10 +613,6 @@ class CtcCriterion(FairseqCriterion):
logger.warning("Illegal loss %f!" % loss) logger.warning("Illegal loss %f!" % loss)
if ctc_loss != 0 and (torch.isnan(ctc_loss) or torch.isinf(ctc_loss)): if ctc_loss != 0 and (torch.isnan(ctc_loss) or torch.isinf(ctc_loss)):
logger.warning("CTC loss %f!" % ctc_loss) logger.warning("CTC loss %f!" % ctc_loss)
if self.interleaved_ctc_weight != 0:
logger.warning("Intermedia CTC loss %f!" % interleaved_ctc_loss)
if self.target_ctc_weight != 0:
logger.warning("Target CTC loss %f!" % target_ctc_loss)
# CER is not completely accurate and is for reference only. # CER is not completely accurate and is for reference only.
if not model.training: if not model.training:
...@@ -734,7 +743,7 @@ class CtcCriterion(FairseqCriterion): ...@@ -734,7 +743,7 @@ class CtcCriterion(FairseqCriterion):
if ctc_entropy_sum > 0: if ctc_entropy_sum > 0:
metrics.log_scalar( metrics.log_scalar(
"ctc_entropy", "ctc_entropy",
ctc_entropy_sum / nfeatures / math.log(2), ctc_entropy_sum / nsentences / math.log(2),
sample_size, sample_size,
round=3, round=3,
) )
...@@ -763,21 +772,21 @@ class CtcCriterion(FairseqCriterion): ...@@ -763,21 +772,21 @@ class CtcCriterion(FairseqCriterion):
if ctc_self_distill_loss_sum > 0: if ctc_self_distill_loss_sum > 0:
metrics.log_scalar( metrics.log_scalar(
"ctc_self_distill_loss", "ctc_self_distill_loss",
ctc_self_distill_loss_sum / nfeatures / math.log(2), ctc_self_distill_loss_sum / nsentences / math.log(2),
sample_size, sample_size,
round=3, round=3,
) )
if target_ctc_self_distill_loss_sum > 0: if target_ctc_self_distill_loss_sum > 0:
metrics.log_scalar( metrics.log_scalar(
"target_ctc_self_distill_loss_sum", "target_ctc_self_distill_loss_sum",
target_ctc_self_distill_loss_sum / nfeatures / math.log(2), target_ctc_self_distill_loss_sum / nsentences / math.log(2),
sample_size, sample_size,
round=3, round=3,
) )
if ctc_mixup_consistent_loss > 0: if ctc_mixup_consistent_loss > 0:
metrics.log_scalar( metrics.log_scalar(
"ctc_mixup_consistent_loss", "ctc_mixup_consistent_loss",
ctc_mixup_consistent_loss / nfeatures / math.log(2), ctc_mixup_consistent_loss / nsentences / math.log(2),
sample_size, sample_size,
round=3, round=3,
) )
......
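The criterion now collects one entropy term per CTC head (final and interleaved) into a list, averages them, and scales the result by ctc_entropy_weight; with a non-zero ctc_entropy_cutoff only the largest logits enter the distribution. A simplified standalone sketch of that regularizer (omitting the logit renormalization done in the diff), assuming (T, B, V) logits:

```python
import torch
from torch.distributions import Categorical

def ctc_entropy(logits, cutoff=0):
    """Entropy of the CTC output distribution, summed over positions.

    Optionally keep only the `cutoff` largest logits per position before
    building the categorical distribution (the diff additionally
    renormalizes the truncated logits).
    """
    if cutoff > 0:
        logits = logits.sort(dim=-1, descending=True)[0][..., :cutoff]
    return Categorical(logits=logits).entropy().sum()

# one term per CTC head (final + interleaved), averaged as in the new criterion
heads = [torch.randn(40, 4, 100), torch.randn(40, 4, 100)]  # (T, B, V) logits
entropy = sum(ctc_entropy(h, cutoff=40) for h in heads) / len(heads)
loss_term = 0.1 * entropy  # weighted by ctc-entropy-weight
```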
...@@ -25,6 +25,10 @@ class LabelSmoothedCrossEntropyCriterionConfig(FairseqDataclass): ...@@ -25,6 +25,10 @@ class LabelSmoothedCrossEntropyCriterionConfig(FairseqDataclass):
default=0.0, default=0.0,
metadata={"help": "the weight for consistency regularization of mixup"}, metadata={"help": "the weight for consistency regularization of mixup"},
) )
cal_mixup_loss: bool = field(
default=True,
metadata={"help": "calculate the loss for the mixed samples"},
)
report_accuracy: bool = field( report_accuracy: bool = field(
default=False, default=False,
metadata={"help": "report accuracy metric"}, metadata={"help": "report accuracy metric"},
...@@ -67,6 +71,7 @@ class LabelSmoothedCrossEntropyCriterion(FairseqCriterion): ...@@ -67,6 +71,7 @@ class LabelSmoothedCrossEntropyCriterion(FairseqCriterion):
label_smoothing, label_smoothing,
ignore_prefix_size=0, ignore_prefix_size=0,
report_accuracy=False, report_accuracy=False,
cal_mixup_loss=True,
mixup_consistent_weight=0.0, mixup_consistent_weight=0.0,
): ):
super().__init__(task) super().__init__(task)
...@@ -74,6 +79,7 @@ class LabelSmoothedCrossEntropyCriterion(FairseqCriterion): ...@@ -74,6 +79,7 @@ class LabelSmoothedCrossEntropyCriterion(FairseqCriterion):
self.eps = float(label_smoothing) self.eps = float(label_smoothing)
self.ignore_prefix_size = ignore_prefix_size self.ignore_prefix_size = ignore_prefix_size
self.report_accuracy = report_accuracy self.report_accuracy = report_accuracy
self.cal_mixup_loss = cal_mixup_loss
self.mixup_consistent_weight = mixup_consistent_weight self.mixup_consistent_weight = mixup_consistent_weight
def forward(self, model, sample, reduce=True): def forward(self, model, sample, reduce=True):
...@@ -127,49 +133,63 @@ class LabelSmoothedCrossEntropyCriterion(FairseqCriterion): ...@@ -127,49 +133,63 @@ class LabelSmoothedCrossEntropyCriterion(FairseqCriterion):
if "mixup" in net_output[1] and net_output[1]["mixup"] is not None: if "mixup" in net_output[1] and net_output[1]["mixup"] is not None:
mixup = net_output[1]["mixup"] mixup = net_output[1]["mixup"]
mixup_idx1 = mixup["index1"] idx1 = mixup["index1"]
mixup_idx2 = mixup["index2"] idx2 = mixup["index2"]
batch_size = len(mixup_idx1) mixup_flag = mixup["mixup_flag"]
mixup_idx1 = idx1[mixup_flag]
target = model.get_targets(sample, net_output) mixup_idx2 = idx2[mixup_flag]
target1 = target[mixup_idx1].view(-1) org_idx = idx1[~mixup_flag]
target2 = target[mixup_idx2].view(-1)
targets = [target1, target2] seq_len = target.size(1)
lprobs = lprobs.view(-1, seq_len, lprobs.size(-1))
if mixup["mixup_decoder_emb"]:
mixup_lprobs = [lprobs[mixup_flag, :, :], lprobs[mixup_flag, :, :]]
else:
decoder_mixup_flag1 = mixup["decoder_mixup_flag1"]
decoder_mixup_flag2 = mixup["decoder_mixup_flag2"]
mixup_lprobs = [lprobs[decoder_mixup_flag1, :, :], lprobs[decoder_mixup_flag2, :, :]]
mixup_coef = net_output[1]["mixup"]["coef"] org_lprobs = lprobs[org_idx, :, :]
mixup_targets = [target[mixup_idx1], target[mixup_idx2]]
mixup_coef = net_output[1]["mixup"]["coef"][mixup_flag]
loss_coef = [mixup_coef, 1 - mixup_coef] loss_coef = [mixup_coef, 1 - mixup_coef]
for item_target, item_coef in zip(targets, loss_coef): if len(org_idx) > 0:
item_loss, item_nll_loss = label_smoothed_nll_loss( org_target = target[org_idx]
lprobs, org_loss, org_nll_loss = label_smoothed_nll_loss(
item_target, org_lprobs.view(-1, org_lprobs.size(-1)),
org_target.view(-1),
self.eps, self.eps,
ignore_index=self.padding_idx, ignore_index=self.padding_idx,
reduce=False, reduce=False,
) )
loss += (item_loss.sum(-1).view(batch_size, -1).sum(-1) * item_coef).sum() loss += org_loss.sum()
nll_loss += (item_nll_loss.sum(-1).view(batch_size, -1).sum(-1) * item_coef).sum() nll_loss += org_nll_loss.sum()
if self.cal_mixup_loss:
for item_lprobs, item_target, item_coef in zip(mixup_lprobs, mixup_targets, loss_coef):
batch_size = item_target.size(0)
item_loss, item_nll_loss = label_smoothed_nll_loss(
item_lprobs.view(-1, item_lprobs.size(-1)),
item_target.view(-1),
self.eps,
ignore_index=self.padding_idx,
reduce=False,
)
loss += (item_loss.sum(-1).view(batch_size, -1).sum(-1) * item_coef).sum()
nll_loss += (item_nll_loss.sum(-1).view(batch_size, -1).sum(-1) * item_coef).sum()
mixup_consistent_loss = 0 mixup_consistent_loss = 0
if self.mixup_consistent_weight > 0: if self.mixup_consistent_weight > 0:
lprobs = lprobs.view(batch_size, -1, lprobs.size(-1)) non_padding_mask = ~org_target.eq(self.padding_idx)
mixup_pos = mixup_idx1 != mixup_idx2
mixup_real_coef = mixup_coef[mixup_pos]
loss_coef = [mixup_real_coef, 1 - mixup_real_coef]
mixup_real_lprobs = lprobs[mixup_pos, :, :]
mixup_real_idx1 = mixup_idx1[mixup_pos]
mixup_real_idx2 = mixup_idx2[mixup_pos]
non_padding_mask = ~target.eq(self.padding_idx)
no_mixup_lprobs = lprobs[~mixup_pos, :, :] teacher_lprobs = [org_lprobs[mixup_idx1, :, :], org_lprobs[mixup_idx2, :, :]]
mixup_target_lprobs = [no_mixup_lprobs[mixup_real_idx1, :, :], no_mixup_lprobs[mixup_real_idx2, :, :]] target_pad_mask = [non_padding_mask[mixup_idx1], non_padding_mask[mixup_idx2]]
mixup_target_pad_mask = [non_padding_mask[mixup_real_idx1], non_padding_mask[mixup_real_idx2]]
for tgt_lprobs, pad, coef in zip(mixup_target_lprobs, mixup_target_pad_mask, loss_coef): for item_mixup_lprobs, tgt_lprobs, pad, coef in zip(mixup_lprobs, teacher_lprobs, target_pad_mask, loss_coef):
item_loss = F.kl_div( item_loss = F.kl_div(
F.log_softmax(mixup_real_lprobs, dim=-1, dtype=torch.float32), F.log_softmax(item_mixup_lprobs, dim=-1, dtype=torch.float32),
F.log_softmax(tgt_lprobs.detach(), dim=-1, dtype=torch.float32), F.log_softmax(tgt_lprobs.detach(), dim=-1, dtype=torch.float32),
log_target=True, log_target=True,
reduction="none", reduction="none",
...@@ -194,18 +214,17 @@ class LabelSmoothedCrossEntropyCriterion(FairseqCriterion): ...@@ -194,18 +214,17 @@ class LabelSmoothedCrossEntropyCriterion(FairseqCriterion):
mixup = net_output[1]["mixup"] mixup = net_output[1]["mixup"]
mixup_idx1 = mixup["index1"] mixup_idx1 = mixup["index1"]
mixup_idx2 = mixup["index2"] mixup_idx2 = mixup["index2"]
batch_size = len(mixup_idx1) mixup_flag = mixup["mixup_flag"]
no_mixup_pos = mixup_idx1 == mixup_idx2 if all(mixup_flag):
idx = mixup_idx1[no_mixup_pos] return torch.Tensor([0]), torch.Tensor([0])
lprobs = lprobs.view(batch_size, -1, lprobs.size(-1))[idx, :, :].view(-1, lprobs.size(-1))
idx = mixup_idx1[~mixup_flag]
lprobs = lprobs.view(-1, target.size(1), lprobs.size(-1))[idx, :, :].view(-1, lprobs.size(-1))
target = target[idx].view(-1) target = target[idx].view(-1)
else: else:
target = target.view(-1) target = target.view(-1)
if lprobs.size(0) == 0:
return torch.Tensor([0]), torch.Tensor([0])
mask = target.ne(self.padding_idx) mask = target.ne(self.padding_idx)
n_correct = torch.sum( n_correct = torch.sum(
lprobs.argmax(1).masked_select(mask).eq(target.masked_select(mask)) lprobs.argmax(1).masked_select(mask).eq(target.masked_select(mask))
......
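Using mixup_flag, the rewritten criterion computes the ordinary label-smoothed loss on the original sentences and, when cal-mixup-loss is on, adds the losses of the two parent targets weighted by the mixing coefficient and its complement; the optional consistency term pulls the mixed-sentence distribution toward each parent's detached distribution. A rough sketch of both pieces, without label smoothing and with hypothetical shapes:

```python
import torch
import torch.nn.functional as F

def mixed_sample_loss(lprobs, target1, target2, coef, pad_idx=1):
    """Sketch of the cal-mixup-loss branch (label smoothing omitted).

    lprobs:  (B_mix, T, V) log-probabilities of the mixed sentences
    target1: (B_mix, T) targets of the first parent of each pair
    target2: (B_mix, T) targets of the second parent
    coef:    (B_mix,) Beta-sampled mixing coefficients
    """
    loss = 0.0
    for tgt, c in zip((target1, target2), (coef, 1 - coef)):
        nll = F.nll_loss(lprobs.transpose(1, 2), tgt,
                         ignore_index=pad_idx, reduction="none")  # (B_mix, T)
        loss = loss + (nll.sum(-1) * c).sum()   # weight each sentence by its coefficient
    return loss

def mixup_consistency(mixed_lprobs, parent_lprobs):
    """KL pulling the mixed sentence toward one (detached) parent, as in the diff."""
    return F.kl_div(
        F.log_softmax(mixed_lprobs, dim=-1, dtype=torch.float32),
        F.log_softmax(parent_lprobs.detach(), dim=-1, dtype=torch.float32),
        log_target=True,
        reduction="none",
    ).sum()
```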
...@@ -25,9 +25,11 @@ class LabelSmoothedCrossEntropyCriterionWithCTC( ...@@ -25,9 +25,11 @@ class LabelSmoothedCrossEntropyCriterionWithCTC(
cfg: CtcCriterionConfig, cfg: CtcCriterionConfig,
ctc_weight=0.0, ctc_weight=0.0,
save_dir=None, save_dir=None,
cal_mixup_loss=True,
mixup_consistent_weight=0.0): mixup_consistent_weight=0.0):
super().__init__(task, sentence_avg, label_smoothing, super().__init__(task, sentence_avg, label_smoothing,
report_accuracy=True, report_accuracy=True,
cal_mixup_loss=cal_mixup_loss,
mixup_consistent_weight=mixup_consistent_weight) mixup_consistent_weight=mixup_consistent_weight)
self.report_accuracy = True self.report_accuracy = True
...@@ -83,15 +85,14 @@ class LabelSmoothedCrossEntropyCriterionWithCTC( ...@@ -83,15 +85,14 @@ class LabelSmoothedCrossEntropyCriterionWithCTC(
ratio = mixup["ratio"] ratio = mixup["ratio"]
if mixup["keep_org"]: if mixup["keep_org"]:
n_tokens = int(sample_size * (1 + ratio)) n_tokens = int(n_tokens * (1 + ratio))
sample_size = int(sample_size * (1 + ratio)) if self.sentence_avg else n_tokens
n_sentences = int(n_sentences * (1 + ratio))
else: else:
n_tokens = int(sample_size * ratio) if ratio > 1:
n_tokens = int(n_tokens * ratio)
if self.sentence_avg: sample_size = int(sample_size * ratio) if self.sentence_avg else n_tokens
sample_size = net_output[0].size(0) n_sentences = int(n_sentences * ratio)
else:
sample_size = n_tokens
n_sentences = net_output[0].size(0)
logging_output = { logging_output = {
"trans_loss": utils.item(loss.data) if reduce else loss.data, "trans_loss": utils.item(loss.data) if reduce else loss.data,
...@@ -143,6 +144,7 @@ class LabelSmoothedCrossEntropyCriterionWithCTC( ...@@ -143,6 +144,7 @@ class LabelSmoothedCrossEntropyCriterionWithCTC(
) )
ntokens = utils.item(sum(log.get("ntokens", 0) for log in logging_outputs)) ntokens = utils.item(sum(log.get("ntokens", 0) for log in logging_outputs))
n_sentences = utils.item(sum(log.get("nsentences", 0) for log in logging_outputs))
sample_size = utils.item( sample_size = utils.item(
sum(log.get("sample_size", 0) for log in logging_outputs) sum(log.get("sample_size", 0) for log in logging_outputs)
) )
...@@ -159,7 +161,7 @@ class LabelSmoothedCrossEntropyCriterionWithCTC( ...@@ -159,7 +161,7 @@ class LabelSmoothedCrossEntropyCriterionWithCTC(
) )
if mixup_consistent_loss_sum != 0: if mixup_consistent_loss_sum != 0:
metrics.log_scalar( metrics.log_scalar(
"mixup_consistent_loss", mixup_consistent_loss_sum / sample_size / math.log(2), sample_size, round=3 "mixup_consistent_loss", mixup_consistent_loss_sum / n_sentences / math.log(2), n_sentences, round=3
) )
metrics.log_derived( metrics.log_derived(
"ppl", lambda meters: utils.get_perplexity(meters["nll_loss"].avg) "ppl", lambda meters: utils.get_perplexity(meters["nll_loss"].avg)
......
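Because mixup enlarges (or, without keep_org, reshapes) the batch, the logging counters have to be rescaled: with keep_org the token, sentence, and sample-size counts grow by a factor of (1 + ratio), otherwise they scale by the ratio when it exceeds 1. A small sketch of that bookkeeping, mirroring the variable names in the hunk above:

```python
def scale_counts(n_tokens, n_sentences, sample_size, ratio, keep_org, sentence_avg):
    """Sketch of the counter adjustment for mixup batches (same names as above)."""
    if keep_org:
        n_tokens = int(n_tokens * (1 + ratio))
        sample_size = int(sample_size * (1 + ratio)) if sentence_avg else n_tokens
        n_sentences = int(n_sentences * (1 + ratio))
    elif ratio > 1:
        n_tokens = int(n_tokens * ratio)
        sample_size = int(sample_size * ratio) if sentence_avg else n_tokens
        n_sentences = int(n_sentences * ratio)
    return n_tokens, n_sentences, sample_size
```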
...@@ -170,7 +170,6 @@ def get_features_from_npy_or_audio(path): ...@@ -170,7 +170,6 @@ def get_features_from_npy_or_audio(path):
def get_features_or_waveform_from_uncompressed_zip( def get_features_or_waveform_from_uncompressed_zip(
path, byte_offset, byte_size, need_waveform=False path, byte_offset, byte_size, need_waveform=False
): ):
assert path.endswith(".zip")
data = read_from_uncompressed_zip(path, byte_offset, byte_size) data = read_from_uncompressed_zip(path, byte_offset, byte_size)
f = io.BytesIO(data) f = io.BytesIO(data)
if is_npy_data(data): if is_npy_data(data):
...@@ -214,7 +213,7 @@ def get_features_or_waveform(path: str, need_waveform=False): ...@@ -214,7 +213,7 @@ def get_features_or_waveform(path: str, need_waveform=False):
return get_features_from_npy_or_audio(_path) return get_features_from_npy_or_audio(_path)
elif len(extra) == 2: elif len(extra) == 2:
extra = [int(i) for i in extra] extra = [int(i) for i in extra]
if _path.endswith('.zip'): if _path.endswith('.zip') or _path.endswith('.tar'):
features_or_waveform = get_features_or_waveform_from_uncompressed_zip( features_or_waveform = get_features_or_waveform_from_uncompressed_zip(
_path, extra[0], extra[1], need_waveform=need_waveform _path, extra[0], extra[1], need_waveform=need_waveform
) )
......
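The loader change lets a path of the form name.zip:byte_offset:byte_size (and now .tar as well) address a stored record directly by byte range, which is then interpreted as a serialized numpy array or as raw audio. A rough sketch of that lookup, assuming the archive entry is stored uncompressed so a plain seek/read is valid; the function name is made up:

```python
import io
import numpy as np

def read_record(archive_path: str, byte_offset: int, byte_size: int):
    """Read one record from an uncompressed .zip/.tar by byte range (sketch)."""
    with open(archive_path, "rb") as f:
        f.seek(byte_offset)
        data = f.read(byte_size)
    if data[:6] == b"\x93NUMPY":              # .npy magic: pre-extracted features
        return np.load(io.BytesIO(data))
    raise NotImplementedError("waveform decoding is not sketched here")
```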
...@@ -347,6 +347,12 @@ class S2TTransformerModel(FairseqEncoderDecoderModel): ...@@ -347,6 +347,12 @@ class S2TTransformerModel(FairseqEncoderDecoderModel):
default=0.5, default=0.5,
help='initialized weight for local mask' help='initialized weight for local mask'
) )
parser.add_argument(
"--layer-padding-mask",
default=False,
type=bool,
help="mask the padding to 0 before each layer"
)
# Conformer setting # Conformer setting
parser.add_argument( parser.add_argument(
...@@ -500,6 +506,12 @@ class S2TTransformerModel(FairseqEncoderDecoderModel): ...@@ -500,6 +506,12 @@ class S2TTransformerModel(FairseqEncoderDecoderModel):
help="the layers to apply mixup", help="the layers to apply mixup",
) )
parser.add_argument( parser.add_argument(
"--inter-mixup-decoder-layer",
default="0",
type=str,
help="the layers to apply mixup in the decoder",
)
parser.add_argument(
"--inter-mixup-beta", "--inter-mixup-beta",
default=0.5, default=0.5,
type=float, type=float,
...@@ -522,6 +534,11 @@ class S2TTransformerModel(FairseqEncoderDecoderModel): ...@@ -522,6 +534,11 @@ class S2TTransformerModel(FairseqEncoderDecoderModel):
action="store_true", action="store_true",
help="keep original batch", help="keep original batch",
) )
parser.add_argument(
"--inter-mixup-decoder-emb",
action="store_true",
help="mix the embedding in the decoder",
)
pass pass
@classmethod @classmethod
...@@ -654,6 +671,7 @@ class S2TTransformerEncoder(FairseqEncoder): ...@@ -654,6 +671,7 @@ class S2TTransformerEncoder(FairseqEncoder):
self.layers = nn.ModuleList( self.layers = nn.ModuleList(
[S2TTransformerEncoderLayer(args) for _ in range(args.encoder_layers)] [S2TTransformerEncoderLayer(args) for _ in range(args.encoder_layers)]
) )
self.layer_padding_mask = args.layer_padding_mask
if args.encoder_normalize_before: if args.encoder_normalize_before:
self.layer_norm = LayerNorm(dim) self.layer_norm = LayerNorm(dim)
...@@ -760,6 +778,7 @@ class S2TTransformerEncoder(FairseqEncoder): ...@@ -760,6 +778,7 @@ class S2TTransformerEncoder(FairseqEncoder):
self.mixup_prob = args.inter_mixup_prob self.mixup_prob = args.inter_mixup_prob
self.mixup_ratio = args.inter_mixup_ratio self.mixup_ratio = args.inter_mixup_ratio
self.mixup_keep_org = args.inter_mixup_keep_org self.mixup_keep_org = args.inter_mixup_keep_org
self.mixup_decoder_emb = args.inter_mixup_decoder_emb
beta = args.inter_mixup_beta beta = args.inter_mixup_beta
from torch.distributions import Beta from torch.distributions import Beta
...@@ -826,6 +845,7 @@ class S2TTransformerEncoder(FairseqEncoder): ...@@ -826,6 +845,7 @@ class S2TTransformerEncoder(FairseqEncoder):
org_indices = np.arange(batch) org_indices = np.arange(batch)
mixup_size = int(batch * self.mixup_ratio) mixup_size = int(batch * self.mixup_ratio)
mixup_flag = []
if mixup_size <= batch: if mixup_size <= batch:
mixup_index1 = np.random.permutation(mixup_size) mixup_index1 = np.random.permutation(mixup_size)
mixup_index2 = np.random.permutation(mixup_size) mixup_index2 = np.random.permutation(mixup_size)
...@@ -836,13 +856,17 @@ class S2TTransformerEncoder(FairseqEncoder): ...@@ -836,13 +856,17 @@ class S2TTransformerEncoder(FairseqEncoder):
if self.mixup_keep_org: if self.mixup_keep_org:
idx1 = np.append(org_indices, mixup_index1) idx1 = np.append(org_indices, mixup_index1)
idx2 = np.append(org_indices, mixup_index2) idx2 = np.append(org_indices, mixup_index2)
else: mixup_flag.extend([0] * len(org_indices))
mixup_flag.extend([1] * len(mixup_index1))
else:
keep_indices = [] keep_indices = []
for i in org_indices: for i in org_indices:
if i not in mixup_index1 and i not in mixup_index2: if i not in mixup_index1 and i not in mixup_index2:
keep_indices.append(i) keep_indices.append(i)
idx1 = np.append(keep_indices, mixup_index1) idx1 = np.append(keep_indices, mixup_index1)
idx2 = np.append(keep_indices, mixup_index2) idx2 = np.append(keep_indices, mixup_index2)
mixup_flag.extend([0] * len(keep_indices))
mixup_flag.extend([1] * len(mixup_index1))
idx1 = torch.from_numpy(idx1).to(x.device).long() idx1 = torch.from_numpy(idx1).to(x.device).long()
idx2 = torch.from_numpy(idx2).to(x.device).long() idx2 = torch.from_numpy(idx2).to(x.device).long()
...@@ -859,6 +883,7 @@ class S2TTransformerEncoder(FairseqEncoder): ...@@ -859,6 +883,7 @@ class S2TTransformerEncoder(FairseqEncoder):
pad2 = encoder_padding_mask[idx2] pad2 = encoder_padding_mask[idx2]
encoder_padding_mask = pad1 & pad2 encoder_padding_mask = pad1 & pad2
input_lengths = (~encoder_padding_mask).sum(-1) input_lengths = (~encoder_padding_mask).sum(-1)
mixup_flag = torch.Tensor(mixup_flag).to(x.device).bool()
mixup = { mixup = {
"ratio": self.mixup_ratio, "ratio": self.mixup_ratio,
...@@ -866,6 +891,8 @@ class S2TTransformerEncoder(FairseqEncoder): ...@@ -866,6 +891,8 @@ class S2TTransformerEncoder(FairseqEncoder):
"coef": coef, "coef": coef,
"index1": idx1, "index1": idx1,
"index2": idx2, "index2": idx2,
"mixup_flag": mixup_flag,
"mixup_decoder_emb": self.mixup_decoder_emb,
} }
return x, encoder_padding_mask, input_lengths, mixup return x, encoder_padding_mask, input_lengths, mixup
...@@ -965,6 +992,13 @@ class S2TTransformerEncoder(FairseqEncoder): ...@@ -965,6 +992,13 @@ class S2TTransformerEncoder(FairseqEncoder):
if self.history is not None: if self.history is not None:
x = self.history.pop() x = self.history.pop()
if self.layer_padding_mask and encoder_padding_mask is not None and not torch.all(encoder_padding_mask):
mask_pad = encoder_padding_mask.unsqueeze(2)
if mask_pad is not None:
x = x.transpose(0, 1)
x.masked_fill_(mask_pad, 0.0)
x = x.transpose(0, 1)
# encoder layer # encoder layer
x = layer(x, encoder_padding_mask, pos_emb=positions) x = layer(x, encoder_padding_mask, pos_emb=positions)
layer_idx += 1 layer_idx += 1
...@@ -1190,6 +1224,8 @@ def base_architecture(args): ...@@ -1190,6 +1224,8 @@ def base_architecture(args):
args.encoder_embed_linear = getattr(args, "encoder_embed_linear", False) args.encoder_embed_linear = getattr(args, "encoder_embed_linear", False)
args.encoder_embed_norm = getattr(args, "encoder_embed_norm", False) args.encoder_embed_norm = getattr(args, "encoder_embed_norm", False)
args.layer_padding_mask = getattr(args, "layer_padding_mask", False)
# CTC # CTC
args.ctc_layer = getattr(args, "ctc_layer", 0) args.ctc_layer = getattr(args, "ctc_layer", 0)
args.share_ctc_and_embed = getattr(args, "share_ctc_and_embed", False) args.share_ctc_and_embed = getattr(args, "share_ctc_and_embed", False)
...@@ -1240,10 +1276,12 @@ def base_architecture(args): ...@@ -1240,10 +1276,12 @@ def base_architecture(args):
# mixup # mixup
args.inter_mixup = getattr(args, "inter_mixup", False) args.inter_mixup = getattr(args, "inter_mixup", False)
args.inter_mixup_layer = getattr(args, "inter_mixup_layer", "-1") args.inter_mixup_layer = getattr(args, "inter_mixup_layer", "-1")
args.inter_mixup_decoder_layer = getattr(args, "inter_mixup_decoder_layer", "0")
args.inter_mixup_beta = getattr(args, "inter_mixup_beta", 0.5) args.inter_mixup_beta = getattr(args, "inter_mixup_beta", 0.5)
args.inter_mixup_prob = getattr(args, "inter_mixup_prob", 1) args.inter_mixup_prob = getattr(args, "inter_mixup_prob", 1)
args.inter_mixup_ratio = getattr(args, "inter_mixup_ratio", 0.3) args.inter_mixup_ratio = getattr(args, "inter_mixup_ratio", 0.3)
args.inter_mixup_keep_org = getattr(args, "inter_mixup_keep_org", False) args.inter_mixup_keep_org = getattr(args, "inter_mixup_keep_org", False)
args.inter_mixup_decoder_emb = getattr(args, "inter_mixup_decoder_emb", False)
@register_model_architecture("s2t_transformer", "s2t_transformer_s") @register_model_architecture("s2t_transformer", "s2t_transformer_s")
......
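The encoder now records, for every sentence it outputs, whether it is an original or a mixed one (mixup_flag), and the criteria and decoder use this flag to split the batch. A condensed sketch of the index and flag construction for ratio <= 1, following the logic added above:

```python
import numpy as np
import torch

def build_mixup_indices(batch, ratio, keep_org=True):
    """Sketch of the index/flag bookkeeping added to the encoder (ratio <= 1)."""
    org = np.arange(batch)
    mixup_size = int(batch * ratio)
    idx1 = np.random.permutation(mixup_size)    # first parent of each mixed sentence
    idx2 = np.random.permutation(mixup_size)    # second parent
    if keep_org:
        index1 = np.append(org, idx1)
        index2 = np.append(org, idx2)
        flag = [0] * batch + [1] * mixup_size   # 0 = original sentence, 1 = mixed
    else:
        keep = [i for i in org if i not in idx1 and i not in idx2]
        index1 = np.append(keep, idx1)
        index2 = np.append(keep, idx2)
        flag = [0] * len(keep) + [1] * mixup_size
    return (torch.from_numpy(index1).long(),
            torch.from_numpy(index2).long(),
            torch.tensor(flag, dtype=torch.bool))
```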
...@@ -18,7 +18,8 @@ from fairseq.models import ( ...@@ -18,7 +18,8 @@ from fairseq.models import (
from fairseq.modules.speech_to_text import Adapter, CTC from fairseq.modules.speech_to_text import Adapter, CTC
from fairseq.models.transformer import Embedding, TransformerDecoder from fairseq.models.transformer import Embedding, TransformerDecoder
from fairseq.models.speech_to_text import S2TTransformerModel, S2TTransformerEncoder from fairseq.models.speech_to_text import S2TTransformerModel, S2TTransformerEncoder
from fairseq.models.wav2vec import Wav2Vec2Model, Wav2VecCtc from fairseq.models.wav2vec.wav2vec2_asr import Wav2VecCtc
from fairseq.models.wav2vec.wav2vec2 import Wav2Vec2Model
from fairseq.modules import ( from fairseq.modules import (
FairseqDropout, FairseqDropout,
LayerNorm, LayerNorm,
......
...@@ -6,6 +6,8 @@ ...@@ -6,6 +6,8 @@
import math import math
from typing import Any, Dict, List, Optional, Tuple from typing import Any, Dict, List, Optional, Tuple
import logging import logging
import copy
import torch import torch
import torch.nn as nn import torch.nn as nn
...@@ -900,6 +902,17 @@ class TransformerDecoder(FairseqIncrementalDecoder): ...@@ -900,6 +902,17 @@ class TransformerDecoder(FairseqIncrementalDecoder):
#self.gather_attn_weight = True #self.gather_attn_weight = True
self.attn_weights = dict() self.attn_weights = dict()
self.mixup = getattr(args, "inter_mixup", False)
if self.mixup:
self.mixup_decoder_emb = args.inter_mixup_decoder_emb
str_mixup_layer = getattr(args, "inter_mixup_decoder_layer", "0")
if len(str_mixup_layer.split(",")) == 1:
self.mixup_layer = int(str_mixup_layer)
else:
self.mixup_layer = [int(layer) for layer in str_mixup_layer.split(",")]
logger.info("Use mixup in the decoder layer %s, mixup decoder embedding %r." % (
str_mixup_layer, self.mixup_decoder_emb))
def build_decoder_layer(self, args, no_encoder_attn=False): def build_decoder_layer(self, args, no_encoder_attn=False):
layer = TransformerDecoderLayer(args, no_encoder_attn) layer = TransformerDecoderLayer(args, no_encoder_attn)
if getattr(args, "checkpoint_activations", False): if getattr(args, "checkpoint_activations", False):
...@@ -974,6 +987,68 @@ class TransformerDecoder(FairseqIncrementalDecoder): ...@@ -974,6 +987,68 @@ class TransformerDecoder(FairseqIncrementalDecoder):
this function is made to be used in the subclass instead. this function is made to be used in the subclass instead.
""" """
def apply_mixup(self, encoder_out, x, self_attn_padding_mask):
mixup = encoder_out["mixup"]
coef = mixup["coef"]
idx1 = mixup["index1"]
idx2 = mixup["index2"]
flag = mixup["mixup_flag"]
if mixup["mixup_decoder_emb"]:
x1 = x[:, idx1]
x2 = x[:, idx2]
mixup_coef = coef.view(1, -1, 1)
x = mixup_coef * x1 + (1 - mixup_coef) * x2
x = x.contiguous()
if self_attn_padding_mask is not None:
pad1 = self_attn_padding_mask[idx1]
pad2 = self_attn_padding_mask[idx2]
self_attn_padding_mask = pad1 & pad2
else:
mix_idx1 = idx1[flag]
mix_idx2 = idx2[flag]
org_idx = idx1[~flag]
x1 = x[:, mix_idx1]
x2 = x[:, mix_idx2]
if self_attn_padding_mask is not None:
pad1 = self_attn_padding_mask[mix_idx1]
pad2 = self_attn_padding_mask[mix_idx2]
decoder_mixup_flag1 = [0] * len(org_idx)
decoder_mixup_flag2 = [0] * len(org_idx)
if len(org_idx) != 0:
org_x = x[:, org_idx]
x = torch.cat([org_x, x1, x2], dim=1)
if self_attn_padding_mask is not None:
org_pad = self_attn_padding_mask[org_idx]
self_attn_padding_mask = torch.cat([org_pad, pad1, pad2], dim=0)
else:
x = torch.cat([x1, x2], dim=1)
if self_attn_padding_mask is not None:
self_attn_padding_mask = torch.cat([pad1, pad2], dim=0)
decoder_mixup_flag1.extend([1] * len(mix_idx1))
decoder_mixup_flag1.extend([0] * len(mix_idx2))
decoder_mixup_flag2.extend([0] * len(mix_idx1))
decoder_mixup_flag2.extend([1] * len(mix_idx2))
mixup["decoder_mixup_flag1"] = torch.Tensor(decoder_mixup_flag1).to(x.device).bool()
mixup["decoder_mixup_flag2"] = torch.Tensor(decoder_mixup_flag2).to(x.device).bool()
encoder_rep = encoder_out["encoder_out"][0]
mixup_encoder_rep = encoder_rep[:, flag, :]
encoder_out["encoder_out"][0] = torch.cat([encoder_rep, mixup_encoder_rep], dim=1)
padding = encoder_out["encoder_padding_mask"][0]
mixup_padding = padding[flag, :]
encoder_out["encoder_padding_mask"][0] = torch.cat([padding, mixup_padding], dim=0)
return encoder_out, x, self_attn_padding_mask, mixup
def extract_features_scriptable( def extract_features_scriptable(
self, self,
prev_output_tokens, prev_output_tokens,
...@@ -1008,6 +1083,24 @@ class TransformerDecoder(FairseqIncrementalDecoder): ...@@ -1008,6 +1083,24 @@ class TransformerDecoder(FairseqIncrementalDecoder):
if alignment_layer is None: if alignment_layer is None:
alignment_layer = self.num_layers - 1 alignment_layer = self.num_layers - 1
layer_idx = -1
bak_encoder_out = encoder_out["encoder_out"][0]
bak_encoder_padding_mask = encoder_out["encoder_padding_mask"][0]
do_mixup = False
mixup_layer = 0
mixup = None
if "mixup" in encoder_out and encoder_out["mixup"] is not None:
do_mixup = True
if type(self.mixup_layer) is list:
from random import choice
mixup_layer = choice(self.mixup_layer)
else:
mixup_layer = self.mixup_layer
if do_mixup and layer_idx == mixup_layer:
logger.warning("To DO!!!")
# embed positions # embed positions
positions = None positions = None
if self.embed_positions is not None: if self.embed_positions is not None:
...@@ -1048,24 +1141,9 @@ class TransformerDecoder(FairseqIncrementalDecoder): ...@@ -1048,24 +1141,9 @@ class TransformerDecoder(FairseqIncrementalDecoder):
if self.cross_self_attention or prev_output_tokens.eq(self.padding_idx).any(): if self.cross_self_attention or prev_output_tokens.eq(self.padding_idx).any():
self_attn_padding_mask = prev_output_tokens.eq(self.padding_idx) self_attn_padding_mask = prev_output_tokens.eq(self.padding_idx)
mixup = None layer_idx += 1
if "mixup" in encoder_out and encoder_out["mixup"] is not None: if do_mixup and layer_idx == mixup_layer:
mixup = encoder_out["mixup"] encoder_out, x, self_attn_padding_mask, mixup = self.apply_mixup(encoder_out, x, self_attn_padding_mask)
coef = mixup["coef"]
idx1 = mixup["index1"]
idx2 = mixup["index2"]
x1 = x[:, idx1]
x2 = x[:, idx2]
mixup_coef = coef.view(1, -1, 1)
x = mixup_coef * x1 + (1 - mixup_coef) * x2
x = x.contiguous()
if self_attn_padding_mask is not None:
pad1 = self_attn_padding_mask[idx1]
pad2 = self_attn_padding_mask[idx2]
self_attn_padding_mask = pad1 & pad2
# decoder layers # decoder layers
avg_attn = None avg_attn = None
...@@ -1109,6 +1187,10 @@ class TransformerDecoder(FairseqIncrementalDecoder): ...@@ -1109,6 +1187,10 @@ class TransformerDecoder(FairseqIncrementalDecoder):
else: else:
avg_attn += layer_attn avg_attn += layer_attn
layer_idx += 1
if do_mixup and layer_idx == mixup_layer:
encoder_out, x, self_attn_padding_mask, mixup = self.apply_mixup(encoder_out, x, self_attn_padding_mask)
if self.gather_attn_weight: if self.gather_attn_weight:
avg_attn = avg_attn / len(self.layers) avg_attn = avg_attn / len(self.layers)
attn = avg_attn.mean(0).sum(-2) attn = avg_attn.mean(0).sum(-2)
...@@ -1149,6 +1231,10 @@ class TransformerDecoder(FairseqIncrementalDecoder): ...@@ -1149,6 +1231,10 @@ class TransformerDecoder(FairseqIncrementalDecoder):
if self.project_out_dim is not None: if self.project_out_dim is not None:
x = self.project_out_dim(x) x = self.project_out_dim(x)
if do_mixup:
encoder_out["encoder_out"][0] = bak_encoder_out
encoder_out["encoder_padding_mask"][0] = bak_encoder_padding_mask
return x, {"attn": [attn], "inner_states": inner_states, "mixup": mixup} return x, {"attn": [attn], "inner_states": inner_states, "mixup": mixup}
......
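apply_mixup either interpolates the decoder embeddings of the two parents (inter-mixup-decoder-emb) or duplicates the decoder input so each mixed sentence is decoded against both parents' target streams, with the matching encoder states concatenated. A much-reduced sketch of the embedding-interpolation branch, reusing the field names from the diff:

```python
import torch

def apply_mixup_emb(x, self_attn_padding_mask, mixup):
    """Sketch of the mixup_decoder_emb branch of apply_mixup.

    x:     (T, B, C) decoder states at the chosen mixup layer
    mixup: dict with "index1", "index2", "coef" produced by the encoder
    """
    idx1, idx2 = mixup["index1"], mixup["index2"]
    coef = mixup["coef"].view(1, -1, 1)
    x = (coef * x[:, idx1] + (1 - coef) * x[:, idx2]).contiguous()
    if self_attn_padding_mask is not None:
        # a mixed position is padding only if it is padding in both parents
        self_attn_padding_mask = self_attn_padding_mask[idx1] & self_attn_padding_mask[idx2]
    return x, self_attn_padding_mask
```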
...@@ -67,7 +67,7 @@ class Adapter(nn.Module): ...@@ -67,7 +67,7 @@ class Adapter(nn.Module):
self.cal_context = False self.cal_context = False
self.shrink = False self.shrink = False
if self.adapter_type in ["linear", "league", "gated_league", "gated_league2"]: if self.adapter_type in ["linear", "league", "gated_league", "gated_league2", "league_shrink"]:
self.cal_linear = True self.cal_linear = True
self.linear_adapter = nn.Sequential( self.linear_adapter = nn.Sequential(
nn.Linear(dim, 2 * dim), nn.Linear(dim, 2 * dim),
...@@ -101,7 +101,7 @@ class Adapter(nn.Module): ...@@ -101,7 +101,7 @@ class Adapter(nn.Module):
self.shrink = True self.shrink = True
logger.info("CTC Compress Strategy: %s" % ctc_compress_strategy) logger.info("CTC Compress Strategy: %s" % ctc_compress_strategy)
if self.cal_context: if self.cal_context or self.shrink:
self.distribution_cutoff = strategy.get("distribution_cutoff", None) self.distribution_cutoff = strategy.get("distribution_cutoff", None)
self.distribution_temperature = strategy.get("ctc_temperature", 1.0) self.distribution_temperature = strategy.get("ctc_temperature", 1.0)
self.gumbel = strategy.get("gumbel", False) self.gumbel = strategy.get("gumbel", False)
...@@ -181,7 +181,7 @@ class Adapter(nn.Module): ...@@ -181,7 +181,7 @@ class Adapter(nn.Module):
elif self.adapter_type == "context": elif self.adapter_type == "context":
out = soft_out out = soft_out
elif self.adapter_type in ["league", "inter_league_shrink"]: elif self.adapter_type in ["league", "league_shrink"]:
if self.training and self.drop_prob > 0 and torch.rand(1).uniform_() < self.drop_prob: if self.training and self.drop_prob > 0 and torch.rand(1).uniform_() < self.drop_prob:
if torch.rand(1).uniform_() < 0.5: if torch.rand(1).uniform_() < 0.5:
out = linear_out out = linear_out
......
...@@ -81,7 +81,7 @@ def last_n_checkpoints(paths, n, combine_choice, upper_bound=None, max_metric=Fa ...@@ -81,7 +81,7 @@ def last_n_checkpoints(paths, n, combine_choice, upper_bound=None, max_metric=Fa
pt_regexp = re.compile(r"checkpoint_\d+_(\d+)\.pt") pt_regexp = re.compile(r"checkpoint_\d+_(\d+)\.pt")
elif combine_choice == "best": elif combine_choice == "best":
reverse = True if max_metric else False reverse = True if max_metric else False
pt_regexp = re.compile(r"checkpoint\.best_loss_\d+_(\d+\.?\d*)\.pt") pt_regexp = re.compile(r"checkpoint\.best_\w+_\d+_(\d+\.?\d*)\.pt")
else: else:
pt_regexp = re.compile(r"checkpoint(\d+)\.pt") pt_regexp = re.compile(r"checkpoint(\d+)\.pt")
......
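The relaxed pattern lets checkpoint averaging pick up best checkpoints saved for any metric (e.g. checkpoint.best_bleu_…) rather than only checkpoint.best_loss_…. A quick check of what the two regexes accept; the filenames are made up for illustration:

```python
import re

old = re.compile(r"checkpoint\.best_loss_\d+_(\d+\.?\d*)\.pt")
new = re.compile(r"checkpoint\.best_\w+_\d+_(\d+\.?\d*)\.pt")

names = [
    "checkpoint.best_loss_10_2.35.pt",   # matched by both patterns
    "checkpoint.best_bleu_10_27.4.pt",   # only the new pattern matches
]
for name in names:
    print(name, bool(old.fullmatch(name)), bool(new.fullmatch(name)))
```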