Commit 6e37ffba by xuchen

Cumulative updates.

Implement mix-speech (mixup training) and optimize the CTC entropy loss.
parent 6358474e
......@@ -398,7 +398,7 @@ if [ ${stage} -le 2 ] && [ ${stop_stage} -ge 2 ]; then
cd ..
echo "CTC WER" >> ${result_file}
tail -n 1 ${src_ctc} >> ${result_file}
tail -n 2 ${src_ctc} >> ${result_file}
src_bleu=$(mktemp -t temp.record.XXXXXX)
cd local
......
arch: s2t_dynamic_transformer
condensation-metric: ratio
#condensation-metric: threshold
condensation-mode: create
##condensation-mode: mask
#condensation-layers: 3,6,9
condensation-threshold: 0.95
#condensation-ratio: 0.8
share-ctc-and-embed: True
interleaved-ctc-weight: 0.2
interleaved-ctc-layers: 6,9
share-interleaved-ctc: True
\ No newline at end of file
inter-mixup: True
inter-mixup-layer: -1
inter-mixup-decoder-layer: 0
inter-mixup-prob: 1.0
inter-mixup-ratio: 1.0
inter-mixup-beta: 0.5
inter-mixup-keep-org: True
inter-mixup-decoder-emb: True
cal-mixup-loss: True
ctc-mixup-consistent-weight: 0
mixup-consistent-weight: 0
\ No newline at end of file
......@@ -5,10 +5,13 @@ condensation-metric: ratio
condensation-mode: create
##condensation-mode: mask
#condensation-layers: 3,6,9
condensation-threshold: 0.9
condensation-ratio: 0.8
condensation-threshold: 0.95
#condensation-ratio: 0.8
share-ctc-and-embed: True
interleaved-ctc-weight: 0.2
interleaved-ctc-layers: 3,6,9
interleaved-ctc-layers: 6,9
share-interleaved-ctc: True
ctc-entropy-weight: 0.0
ctc-entropy-cutoff: 0
\ No newline at end of file
......@@ -34,5 +34,5 @@ sae-ctc-temperature: 1
#ctc-self-distill-prob: 0.1
#cal-all-ctc: True
use-aligned-text: True
aligned-target-ctc: True
# use-aligned-text: True
# aligned-target-ctc: True
inter-mixup: True
inter-mixup-layer: -1
inter-mixup-decoder-layer: 0
inter-mixup-prob: 1.0
inter-mixup-ratio: 1.0
inter-mixup-beta: 0.5
inter-mixup-keep-org: False
inter-mixup-keep-org: True
inter-mixup-decoder-emb: True
cal-mixup-loss: True
ctc-mixup-consistent-weight: 0
mixup-consistent-weight: 0
mixup-consistent-weight: 0
\ No newline at end of file
......@@ -14,6 +14,9 @@ label_smoothing: 0.1
encoder-normalize-before: True
decoder-normalize-before: True
encoder-embed-norm: True
encoder-no-scale-embedding: True
subsampling-type: conv1d
subsampling-layers: 2
subsampling-filter: 2048
......@@ -36,7 +39,7 @@ decoder-ffn-embed-dim: 2048
decoder-attention-heads: 8
acoustic-encoder: transformer
adapter: league
adapter: inter_league
#load-pretrained-encoder-from:
#load-pretrained-acoustic-encoder-from:
......
......@@ -11,6 +11,9 @@ adam_betas: (0.9,0.98)
criterion: label_smoothed_cross_entropy_with_ctc
label_smoothing: 0.1
encoder-embed-norm: True
encoder-no-scale-embedding: True
encoder-normalize-before: True
decoder-normalize-before: True
subsampling-type: conv1d
......@@ -34,7 +37,7 @@ decoder-ffn-embed-dim: 2048
decoder-attention-heads: 8
acoustic-encoder: pds
adapter: league
adapter: inter_league
encoder-embed-dim: 512
ctc-layer: 12
......
......@@ -11,8 +11,12 @@ adam_betas: (0.9,0.98)
criterion: label_smoothed_cross_entropy_with_ctc
label_smoothing: 0.1
encoder-embed-norm: True
encoder-no-scale-embedding: True
encoder-normalize-before: True
decoder-normalize-before: True
dropout: 0.1
activation-fn: relu
encoder-ffn-embed-dim: 2048
......@@ -26,10 +30,10 @@ decoder-ffn-embed-dim: 2048
decoder-attention-heads: 4
acoustic-encoder: pds
adapter: league
adapter: inter_league
encoder-embed-dim: 256
#ctc-layer: 12
ctc-layer: 12
pds-stages: 4
pds-layers: 3_3_3_3
pds-ratios: 2_2_1_2
......
#! /bin/bash
set -e
infer_dir=$1
......
#! /bin/bash
set -e
infer_dir=$1
......
#! /bin/bash
gpu_num=4
cmd="sh train.sh"
......
......@@ -449,7 +449,7 @@ if [ ${stage} -le 2 ] && [ ${stop_stage} -ge 2 ]; then
echo "Source language" >> ${result_file}
echo "CTC WER" >> ${result_file}
tail -n 1 ${src_ctc} >> ${result_file}
tail -n 2 ${src_ctc} >> ${result_file}
src_bleu=$(mktemp -t temp.record.XXXXXX)
cd local
......@@ -475,7 +475,7 @@ if [ ${stage} -le 2 ] && [ ${stop_stage} -ge 2 ]; then
echo "Target language" >> ${result_file}
echo "CTC WER" >> ${result_file}
tail -n 1 ${tgt_ctc} >> ${result_file}
tail -n 2 ${tgt_ctc} >> ${result_file}
tgt_bleu=$(mktemp -t temp.record.XXXXXX)
cd local
......
......@@ -44,7 +44,7 @@ class CtcCriterionConfig(FairseqDataclass):
default=0.0,
metadata={"help": "weight of CTC loss"},
)
ctc_entropy: float = field(
ctc_entropy_weight: float = field(
default=0.0,
metadata={"help": "weight of CTC entropy"},
)
......@@ -175,7 +175,7 @@ class CtcCriterion(FairseqCriterion):
self.ctc_self_distill_prob = float(cfg.ctc_self_distill_prob)
self.ctc_self_distill_temperature = float(cfg.ctc_self_distill_temperature)
self.ctc_entropy = cfg.ctc_entropy
self.ctc_entropy_weight = cfg.ctc_entropy_weight
self.ctc_entropy_cutoff = cfg.ctc_entropy_cutoff
self.ctc_mixup_consistent_weight = cfg.ctc_mixup_consistent_weight
......@@ -183,7 +183,7 @@ class CtcCriterion(FairseqCriterion):
self.all_ctc_weight = self.ctc_weight + self.interleaved_ctc_weight + \
self.target_ctc_weight + self.target_interleaved_ctc_weight + \
self.ctc_self_distill_weight + self.target_ctc_self_distill_weight + \
self.ctc_entropy + self.ctc_mixup_consistent_weight
self.ctc_entropy_weight + self.ctc_mixup_consistent_weight
if self.all_ctc_weight > 0:
self.ctc_loss = torch.nn.CTCLoss(blank=self.blank_idx, reduction="none", zero_infinity=True)
......@@ -375,6 +375,7 @@ class CtcCriterion(FairseqCriterion):
self.ctc_names = []
lprobs = None
target_lprobs = None
ctc_entropy = []
interleaved_ctc_num = 0
interleaved_ctc_loss = 0
......@@ -393,6 +394,14 @@ class CtcCriterion(FairseqCriterion):
inter_ctc_logit = logit
inter_input_lengths = input_lengths
if self.ctc_entropy_weight > 0:
if self.ctc_entropy_cutoff != 0:
cut_ctc_logit = inter_ctc_logit.sort(dim=-1, descending=True)[0][:, :, 0:self.ctc_entropy_cutoff]
cut_ctc_logit = cut_ctc_logit / cut_ctc_logit.sum(dim=-1, keepdim=True)
ctc_entropy.append(Categorical(logits=cut_ctc_logit).entropy().sum())
else:
ctc_entropy.append(Categorical(logits=inter_ctc_logit).entropy().sum())
all_ctc_logits["interleaved_ctc_logit%d" % i] = [inter_ctc_logit, inter_input_lengths]
inter_loss, inter_lprobs = self.get_ctc_loss(
model, inter_ctc_logit, transcripts, inter_input_lengths, transcript_lengths, loss_coef)
......@@ -403,7 +412,6 @@ class CtcCriterion(FairseqCriterion):
logging_output["interleaved_ctc_loss"] = utils.item(interleaved_ctc_loss.data)
ctc_loss = 0
ctc_entropy = 0
use_ctc = False
if self.ctc_weight > 0 and "ctc_logit" in net_output and len(net_output["ctc_logit"]) > 0:
use_ctc = True
......@@ -413,15 +421,14 @@ class CtcCriterion(FairseqCriterion):
ctc_loss, lprobs = self.get_ctc_loss(
model, ctc_logit, transcripts, input_lengths, transcript_lengths, loss_coef)
if self.ctc_entropy > 0:
if self.ctc_entropy_weight > 0:
if self.ctc_entropy_cutoff != 0:
cut_ctc_logit = ctc_logit.sort(dim=-1, descending=True)[0][:, :, 0:self.ctc_entropy_cutoff]
cut_ctc_logit = cut_ctc_logit / cut_ctc_logit.sum(dim=-1, keepdim=True)
ctc_entropy = Categorical(logits=cut_ctc_logit).entropy().sum()
ctc_entropy.append(Categorical(logits=cut_ctc_logit).entropy().sum())
else:
ctc_entropy = Categorical(logits=ctc_logit).entropy().sum()
ctc_entropy.append(Categorical(logits=ctc_logit).entropy().sum())
logging_output["ctc_entropy"] = utils.item(ctc_entropy.data)
logging_output["ctc_loss"] = utils.item(ctc_loss.data)
# calculate the target CTC loss
......@@ -584,13 +591,19 @@ class CtcCriterion(FairseqCriterion):
ctc_mixup_consistent_loss += (loss.sum(-1).transpose(0, 1).masked_fill_(~pad, 0.0).sum(-1) * coef).sum()
logging_output["ctc_mixup_consistent_loss"] = utils.item(ctc_mixup_consistent_loss.data)
if len(ctc_entropy) != 0:
ctc_entropy = sum(ctc_entropy) / len(ctc_entropy)
logging_output["ctc_entropy"] = utils.item(ctc_entropy.data)
else:
ctc_entropy = 0
loss = \
self.ctc_weight * ctc_loss + \
self.interleaved_ctc_weight * interleaved_ctc_loss + \
self.target_ctc_weight * target_ctc_loss + \
self.target_interleaved_ctc_weight * target_interleaved_ctc_loss + \
ctc_self_distill_loss + \
self.ctc_entropy * ctc_entropy + \
self.ctc_entropy_weight * ctc_entropy + \
self.ctc_mixup_consistent_weight * ctc_mixup_consistent_loss
if loss != 0:
......@@ -600,10 +613,6 @@ class CtcCriterion(FairseqCriterion):
logger.warning("Illegal loss %f!" % loss)
if ctc_loss != 0 and (torch.isnan(ctc_loss) or torch.isinf(ctc_loss)):
logger.warning("CTC loss %f!" % ctc_loss)
if self.interleaved_ctc_weight != 0:
logger.warning("Intermedia CTC loss %f!" % interleaved_ctc_loss)
if self.target_ctc_weight != 0:
logger.warning("Target CTC loss %f!" % target_ctc_loss)
# CER is not completely accurate and is for reference only.
if not model.training:
......@@ -734,7 +743,7 @@ class CtcCriterion(FairseqCriterion):
if ctc_entropy_sum > 0:
metrics.log_scalar(
"ctc_entropy",
ctc_entropy_sum / nfeatures / math.log(2),
ctc_entropy_sum / nsentences / math.log(2),
sample_size,
round=3,
)
......@@ -763,21 +772,21 @@ class CtcCriterion(FairseqCriterion):
if ctc_self_distill_loss_sum > 0:
metrics.log_scalar(
"ctc_self_distill_loss",
ctc_self_distill_loss_sum / nfeatures / math.log(2),
ctc_self_distill_loss_sum / nsentences / math.log(2),
sample_size,
round=3,
)
if target_ctc_self_distill_loss_sum > 0:
metrics.log_scalar(
"target_ctc_self_distill_loss_sum",
target_ctc_self_distill_loss_sum / nfeatures / math.log(2),
target_ctc_self_distill_loss_sum / nsentences / math.log(2),
sample_size,
round=3,
)
if ctc_mixup_consistent_loss > 0:
metrics.log_scalar(
"ctc_mixup_consistent_loss",
ctc_mixup_consistent_loss / nfeatures / math.log(2),
ctc_mixup_consistent_loss / nsentences / math.log(2),
sample_size,
round=3,
)
......
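For reference, a minimal standalone sketch of the entropy regularizer computed in the hunks above, mirroring the top-k cutoff logic; the function name and signature are illustrative, not part of the commit.

import torch
from torch.distributions import Categorical

def ctc_entropy_term(ctc_logit: torch.Tensor, cutoff: int = 0) -> torch.Tensor:
    # ctc_logit: (T, B, V) raw logits from a CTC head.
    # If cutoff > 0, keep only the top-`cutoff` logits per position and
    # renormalize them (as in the diff) before computing the entropy.
    if cutoff != 0:
        cut = ctc_logit.sort(dim=-1, descending=True)[0][:, :, :cutoff]
        cut = cut / cut.sum(dim=-1, keepdim=True)
        return Categorical(logits=cut).entropy().sum()
    return Categorical(logits=ctc_logit).entropy().sum()

With ctc-entropy-weight > 0, the criterion now collects one such term per interleaved CTC layer plus the final CTC layer, averages them, and scales the result by ctc_entropy_weight in the combined loss.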
......@@ -25,6 +25,10 @@ class LabelSmoothedCrossEntropyCriterionConfig(FairseqDataclass):
default=0.0,
metadata={"help": "the weight for consistency regularization of mixup"},
)
cal_mixup_loss: bool = field(
default=True,
metadata={"help": "calculate the loss for the mixed samples"},
)
report_accuracy: bool = field(
default=False,
metadata={"help": "report accuracy metric"},
......@@ -67,6 +71,7 @@ class LabelSmoothedCrossEntropyCriterion(FairseqCriterion):
label_smoothing,
ignore_prefix_size=0,
report_accuracy=False,
cal_mixup_loss=True,
mixup_consistent_weight=0.0,
):
super().__init__(task)
......@@ -74,6 +79,7 @@ class LabelSmoothedCrossEntropyCriterion(FairseqCriterion):
self.eps = float(label_smoothing)
self.ignore_prefix_size = ignore_prefix_size
self.report_accuracy = report_accuracy
self.cal_mixup_loss = cal_mixup_loss
self.mixup_consistent_weight = mixup_consistent_weight
def forward(self, model, sample, reduce=True):
......@@ -127,49 +133,63 @@ class LabelSmoothedCrossEntropyCriterion(FairseqCriterion):
if "mixup" in net_output[1] and net_output[1]["mixup"] is not None:
mixup = net_output[1]["mixup"]
mixup_idx1 = mixup["index1"]
mixup_idx2 = mixup["index2"]
batch_size = len(mixup_idx1)
target = model.get_targets(sample, net_output)
target1 = target[mixup_idx1].view(-1)
target2 = target[mixup_idx2].view(-1)
targets = [target1, target2]
idx1 = mixup["index1"]
idx2 = mixup["index2"]
mixup_flag = mixup["mixup_flag"]
mixup_idx1 = idx1[mixup_flag]
mixup_idx2 = idx2[mixup_flag]
org_idx = idx1[~mixup_flag]
seq_len = target.size(1)
lprobs = lprobs.view(-1, seq_len, lprobs.size(-1))
if mixup["mixup_decoder_emb"]:
mixup_lprobs = [lprobs[mixup_flag, :, :], lprobs[mixup_flag, :, :]]
else:
decoder_mixup_flag1 = mixup["decoder_mixup_flag1"]
decoder_mixup_flag2 = mixup["decoder_mixup_flag2"]
mixup_lprobs = [lprobs[decoder_mixup_flag1, :, :], lprobs[decoder_mixup_flag2, :, :]]
mixup_coef = net_output[1]["mixup"]["coef"]
org_lprobs = lprobs[org_idx, :, :]
mixup_targets = [target[mixup_idx1], target[mixup_idx2]]
mixup_coef = net_output[1]["mixup"]["coef"][mixup_flag]
loss_coef = [mixup_coef, 1 - mixup_coef]
for item_target, item_coef in zip(targets, loss_coef):
item_loss, item_nll_loss = label_smoothed_nll_loss(
lprobs,
item_target,
if len(org_idx) > 0:
org_target = target[org_idx]
org_loss, org_nll_loss = label_smoothed_nll_loss(
org_lprobs.view(-1, org_lprobs.size(-1)),
org_target.view(-1),
self.eps,
ignore_index=self.padding_idx,
reduce=False,
)
loss += (item_loss.sum(-1).view(batch_size, -1).sum(-1) * item_coef).sum()
nll_loss += (item_nll_loss.sum(-1).view(batch_size, -1).sum(-1) * item_coef).sum()
)
loss += org_loss.sum()
nll_loss += org_nll_loss.sum()
if self.cal_mixup_loss:
for item_lprobs, item_target, item_coef in zip(mixup_lprobs, mixup_targets, loss_coef):
batch_size = item_target.size(0)
item_loss, item_nll_loss = label_smoothed_nll_loss(
item_lprobs.view(-1, item_lprobs.size(-1)),
item_target.view(-1),
self.eps,
ignore_index=self.padding_idx,
reduce=False,
)
loss += (item_loss.sum(-1).view(batch_size, -1).sum(-1) * item_coef).sum()
nll_loss += (item_nll_loss.sum(-1).view(batch_size, -1).sum(-1) * item_coef).sum()
mixup_consistent_loss = 0
if self.mixup_consistent_weight > 0:
lprobs = lprobs.view(batch_size, -1, lprobs.size(-1))
mixup_pos = mixup_idx1 != mixup_idx2
mixup_real_coef = mixup_coef[mixup_pos]
loss_coef = [mixup_real_coef, 1 - mixup_real_coef]
mixup_real_lprobs = lprobs[mixup_pos, :, :]
mixup_real_idx1 = mixup_idx1[mixup_pos]
mixup_real_idx2 = mixup_idx2[mixup_pos]
non_padding_mask = ~target.eq(self.padding_idx)
non_padding_mask = ~org_target.eq(self.padding_idx)
no_mixup_lprobs = lprobs[~mixup_pos, :, :]
mixup_target_lprobs = [no_mixup_lprobs[mixup_real_idx1, :, :], no_mixup_lprobs[mixup_real_idx2, :, :]]
mixup_target_pad_mask = [non_padding_mask[mixup_real_idx1], non_padding_mask[mixup_real_idx2]]
teacher_lprobs = [org_lprobs[mixup_idx1, :, :], org_lprobs[mixup_idx2, :, :]]
target_pad_mask = [non_padding_mask[mixup_idx1], non_padding_mask[mixup_idx2]]
for tgt_lprobs, pad, coef in zip(mixup_target_lprobs, mixup_target_pad_mask, loss_coef):
for item_mixup_lprobs, tgt_lprobs, pad, coef in zip(mixup_lprobs, teacher_lprobs, target_pad_mask, loss_coef):
item_loss = F.kl_div(
F.log_softmax(mixup_real_lprobs, dim=-1, dtype=torch.float32),
F.log_softmax(item_mixup_lprobs, dim=-1, dtype=torch.float32),
F.log_softmax(tgt_lprobs.detach(), dim=-1, dtype=torch.float32),
log_target=True,
reduction="none",
......@@ -194,18 +214,17 @@ class LabelSmoothedCrossEntropyCriterion(FairseqCriterion):
mixup = net_output[1]["mixup"]
mixup_idx1 = mixup["index1"]
mixup_idx2 = mixup["index2"]
batch_size = len(mixup_idx1)
mixup_flag = mixup["mixup_flag"]
no_mixup_pos = mixup_idx1 == mixup_idx2
idx = mixup_idx1[no_mixup_pos]
lprobs = lprobs.view(batch_size, -1, lprobs.size(-1))[idx, :, :].view(-1, lprobs.size(-1))
if all(mixup_flag):
return torch.Tensor([0]), torch.Tensor([0])
idx = mixup_idx1[~mixup_flag]
lprobs = lprobs.view(-1, target.size(1), lprobs.size(-1))[idx, :, :].view(-1, lprobs.size(-1))
target = target[idx].view(-1)
else:
target = target.view(-1)
if lprobs.size(0) == 0:
return torch.Tensor([0]), torch.Tensor([0])
mask = target.ne(self.padding_idx)
n_correct = torch.sum(
lprobs.argmax(1).masked_select(mask).eq(target.masked_select(mask))
......
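A rough sketch of the per-source consistency term in the mixup consistency loss above: the mixed samples' decoder log-probabilities are pulled toward the detached log-probabilities of each source sentence, weighted by the corresponding mixing coefficient. Shapes and the final reduction are illustrative; the committed code uses the same KL formulation.

import torch
import torch.nn.functional as F

def mixup_consistency_term(mixed_lprobs, teacher_lprobs, non_pad_mask, coef):
    # mixed_lprobs:   (B_mix, T, V) log-probs of the mixed samples (student)
    # teacher_lprobs: (B_mix, T, V) log-probs of one source sample (teacher)
    # non_pad_mask:   (B_mix, T) True where the target token is not padding
    # coef:           (B_mix,) mixing coefficient assigned to this source
    kl = F.kl_div(
        F.log_softmax(mixed_lprobs, dim=-1, dtype=torch.float32),
        F.log_softmax(teacher_lprobs.detach(), dim=-1, dtype=torch.float32),
        log_target=True,
        reduction="none",
    )
    # sum over the vocabulary, drop padded positions, sum over time,
    # then weight each mixed sample by its coefficient
    per_sample = kl.sum(-1).masked_fill(~non_pad_mask, 0.0).sum(-1)
    return (per_sample * coef).sum()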
......@@ -25,9 +25,11 @@ class LabelSmoothedCrossEntropyCriterionWithCTC(
cfg: CtcCriterionConfig,
ctc_weight=0.0,
save_dir=None,
cal_mixup_loss=True,
mixup_consistent_weight=0.0):
super().__init__(task, sentence_avg, label_smoothing,
report_accuracy=True,
cal_mixup_loss=cal_mixup_loss,
mixup_consistent_weight=mixup_consistent_weight)
self.report_accuracy = True
......@@ -83,15 +85,14 @@ class LabelSmoothedCrossEntropyCriterionWithCTC(
ratio = mixup["ratio"]
if mixup["keep_org"]:
n_tokens = int(sample_size * (1 + ratio))
n_tokens = int(n_tokens * (1 + ratio))
sample_size = int(sample_size * (1 + ratio)) if self.sentence_avg else n_tokens
n_sentences = int(n_sentences * (1 + ratio))
else:
n_tokens = int(sample_size * ratio)
if self.sentence_avg:
sample_size = net_output[0].size(0)
else:
sample_size = n_tokens
n_sentences = net_output[0].size(0)
if ratio > 1:
n_tokens = int(n_tokens * ratio)
sample_size = int(sample_size * ratio) if self.sentence_avg else n_tokens
n_sentences = int(n_sentences * ratio)
logging_output = {
"trans_loss": utils.item(loss.data) if reduce else loss.data,
......@@ -143,6 +144,7 @@ class LabelSmoothedCrossEntropyCriterionWithCTC(
)
ntokens = utils.item(sum(log.get("ntokens", 0) for log in logging_outputs))
n_sentences = utils.item(sum(log.get("nsentences", 0) for log in logging_outputs))
sample_size = utils.item(
sum(log.get("sample_size", 0) for log in logging_outputs)
)
......@@ -159,7 +161,7 @@ class LabelSmoothedCrossEntropyCriterionWithCTC(
)
if mixup_consistent_loss_sum != 0:
metrics.log_scalar(
"mixup_consistent_loss", mixup_consistent_loss_sum / sample_size / math.log(2), sample_size, round=3
"mixup_consistent_loss", mixup_consistent_loss_sum / n_sentences / math.log(2), n_sentences, round=3
)
metrics.log_derived(
"ppl", lambda meters: utils.get_perplexity(meters["nll_loss"].avg)
......
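A small worked example of the batch-statistics bookkeeping changed above, with hypothetical numbers: when the original samples are kept alongside the mixed ones, token and sentence counts grow by a factor of (1 + ratio); otherwise they grow only when the mixup ratio exceeds 1.

# Hypothetical batch: 8 sentences, 200 target tokens, sentence_avg = False.
n_sentences, n_tokens, sample_size = 8, 200, 200
ratio, keep_org, sentence_avg = 1.0, True, False

if keep_org:
    # both the original and the mixed samples contribute to the statistics
    n_tokens = int(n_tokens * (1 + ratio))                                    # 400
    sample_size = int(sample_size * (1 + ratio)) if sentence_avg else n_tokens
    n_sentences = int(n_sentences * (1 + ratio))                              # 16
elif ratio > 1:
    # only mixed samples remain; counts grow only for ratio > 1
    n_tokens = int(n_tokens * ratio)
    sample_size = int(sample_size * ratio) if sentence_avg else n_tokens
    n_sentences = int(n_sentences * ratio)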
......@@ -170,7 +170,6 @@ def get_features_from_npy_or_audio(path):
def get_features_or_waveform_from_uncompressed_zip(
path, byte_offset, byte_size, need_waveform=False
):
assert path.endswith(".zip")
data = read_from_uncompressed_zip(path, byte_offset, byte_size)
f = io.BytesIO(data)
if is_npy_data(data):
......@@ -214,7 +213,7 @@ def get_features_or_waveform(path: str, need_waveform=False):
return get_features_from_npy_or_audio(_path)
elif len(extra) == 2:
extra = [int(i) for i in extra]
if _path.endswith('.zip'):
if _path.endswith('.zip') or _path.endswith('.tar'):
features_or_waveform = get_features_or_waveform_from_uncompressed_zip(
_path, extra[0], extra[1], need_waveform=need_waveform
)
......
......@@ -347,6 +347,12 @@ class S2TTransformerModel(FairseqEncoderDecoderModel):
default=0.5,
help='initialized weight for local mask'
)
parser.add_argument(
"--layer-padding-mask",
default=False,
type=bool,
help="mask the padding to 0 before each layer"
)
# Conformer setting
parser.add_argument(
......@@ -500,6 +506,12 @@ class S2TTransformerModel(FairseqEncoderDecoderModel):
help="the layers to apply mixup",
)
parser.add_argument(
"--inter-mixup-decoder-layer",
default="0",
type=str,
help="the layers to apply mixup in the decoder",
)
parser.add_argument(
"--inter-mixup-beta",
default=0.5,
type=float,
......@@ -522,6 +534,11 @@ class S2TTransformerModel(FairseqEncoderDecoderModel):
action="store_true",
help="keep original batch",
)
parser.add_argument(
"--inter-mixup-decoder-emb",
action="store_true",
help="mix the embedding in the decoder",
)
pass
@classmethod
......@@ -654,6 +671,7 @@ class S2TTransformerEncoder(FairseqEncoder):
self.layers = nn.ModuleList(
[S2TTransformerEncoderLayer(args) for _ in range(args.encoder_layers)]
)
self.layer_padding_mask = args.layer_padding_mask
if args.encoder_normalize_before:
self.layer_norm = LayerNorm(dim)
......@@ -760,6 +778,7 @@ class S2TTransformerEncoder(FairseqEncoder):
self.mixup_prob = args.inter_mixup_prob
self.mixup_ratio = args.inter_mixup_ratio
self.mixup_keep_org = args.inter_mixup_keep_org
self.mixup_decoder_emb = args.inter_mixup_decoder_emb
beta = args.inter_mixup_beta
from torch.distributions import Beta
......@@ -826,6 +845,7 @@ class S2TTransformerEncoder(FairseqEncoder):
org_indices = np.arange(batch)
mixup_size = int(batch * self.mixup_ratio)
mixup_flag = []
if mixup_size <= batch:
mixup_index1 = np.random.permutation(mixup_size)
mixup_index2 = np.random.permutation(mixup_size)
......@@ -836,13 +856,17 @@ class S2TTransformerEncoder(FairseqEncoder):
if self.mixup_keep_org:
idx1 = np.append(org_indices, mixup_index1)
idx2 = np.append(org_indices, mixup_index2)
else:
mixup_flag.extend([0] * len(org_indices))
mixup_flag.extend([1] * len(mixup_index1))
else:
keep_indices = []
for i in org_indices:
if i not in mixup_index1 and i not in mixup_index2:
keep_indices.append(i)
idx1 = np.append(keep_indices, mixup_index1)
idx2 = np.append(keep_indices, mixup_index2)
mixup_flag.extend([0] * len(keep_indices))
mixup_flag.extend([1] * len(mixup_index1))
idx1 = torch.from_numpy(idx1).to(x.device).long()
idx2 = torch.from_numpy(idx2).to(x.device).long()
......@@ -859,6 +883,7 @@ class S2TTransformerEncoder(FairseqEncoder):
pad2 = encoder_padding_mask[idx2]
encoder_padding_mask = pad1 & pad2
input_lengths = (~encoder_padding_mask).sum(-1)
mixup_flag = torch.Tensor(mixup_flag).to(x.device).bool()
mixup = {
"ratio": self.mixup_ratio,
......@@ -866,6 +891,8 @@ class S2TTransformerEncoder(FairseqEncoder):
"coef": coef,
"index1": idx1,
"index2": idx2,
"mixup_flag": mixup_flag,
"mixup_decoder_emb": self.mixup_decoder_emb,
}
return x, encoder_padding_mask, input_lengths, mixup
......@@ -965,6 +992,13 @@ class S2TTransformerEncoder(FairseqEncoder):
if self.history is not None:
x = self.history.pop()
if self.layer_padding_mask and encoder_padding_mask is not None and not torch.all(encoder_padding_mask):
mask_pad = encoder_padding_mask.unsqueeze(2)
if mask_pad is not None:
x = x.transpose(0, 1)
x.masked_fill_(mask_pad, 0.0)
x = x.transpose(0, 1)
# encoder layer
x = layer(x, encoder_padding_mask, pos_emb=positions)
layer_idx += 1
......@@ -1190,6 +1224,8 @@ def base_architecture(args):
args.encoder_embed_linear = getattr(args, "encoder_embed_linear", False)
args.encoder_embed_norm = getattr(args, "encoder_embed_norm", False)
args.layer_padding_mask = getattr(args, "layer_padding_mask", False)
# CTC
args.ctc_layer = getattr(args, "ctc_layer", 0)
args.share_ctc_and_embed = getattr(args, "share_ctc_and_embed", False)
......@@ -1240,10 +1276,12 @@ def base_architecture(args):
# mixup
args.inter_mixup = getattr(args, "inter_mixup", False)
args.inter_mixup_layer = getattr(args, "inter_mixup_layer", "-1")
args.inter_mixup_decoder_layer = getattr(args, "inter_mixup_decoder_layer", "0")
args.inter_mixup_beta = getattr(args, "inter_mixup_beta", 0.5)
args.inter_mixup_prob = getattr(args, "inter_mixup_prob", 1)
args.inter_mixup_ratio = getattr(args, "inter_mixup_ratio", 0.3)
args.inter_mixup_keep_org = getattr(args, "inter_mixup_keep_org", False)
args.inter_mixup_decoder_emb = getattr(args, "inter_mixup_decoder_emb", False)
@register_model_architecture("s2t_transformer", "s2t_transformer_s")
......
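A compressed standalone sketch of the encoder-side mixup added above: pair samples through random index permutations, draw a Beta-distributed coefficient, interpolate the paired features, and merge the padding masks. The committed implementation additionally handles inter-mixup-ratio, inter-mixup-keep-org, and the mixup_flag bookkeeping; the function below is illustrative only.

import torch
from torch.distributions import Beta

def apply_inter_mixup(x, encoder_padding_mask, beta=0.5):
    # x: (T, B, C) encoder features; encoder_padding_mask: (B, T), True at padding.
    batch = x.size(1)
    idx1 = torch.randperm(batch, device=x.device)
    idx2 = torch.randperm(batch, device=x.device)

    coef = Beta(beta, beta).sample([batch]).to(x.device).type_as(x)
    x = (coef.view(1, -1, 1) * x[:, idx1] +
         (1 - coef).view(1, -1, 1) * x[:, idx2]).contiguous()

    # a position is padding only if it is padding in both source samples
    encoder_padding_mask = encoder_padding_mask[idx1] & encoder_padding_mask[idx2]
    input_lengths = (~encoder_padding_mask).sum(-1)

    mixup = {"coef": coef, "index1": idx1, "index2": idx2}
    return x, encoder_padding_mask, input_lengths, mixup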
......@@ -18,7 +18,8 @@ from fairseq.models import (
from fairseq.modules.speech_to_text import Adapter, CTC
from fairseq.models.transformer import Embedding, TransformerDecoder
from fairseq.models.speech_to_text import S2TTransformerModel, S2TTransformerEncoder
from fairseq.models.wav2vec import Wav2Vec2Model, Wav2VecCtc
from fairseq.models.wav2vec.wav2vec2_asr import Wav2VecCtc
from fairseq.models.wav2vec.wav2vec2 import Wav2Vec2Model
from fairseq.modules import (
FairseqDropout,
LayerNorm,
......
......@@ -6,6 +6,8 @@
import math
from typing import Any, Dict, List, Optional, Tuple
import logging
import copy
import torch
import torch.nn as nn
......@@ -900,6 +902,17 @@ class TransformerDecoder(FairseqIncrementalDecoder):
#self.gather_attn_weight = True
self.attn_weights = dict()
self.mixup = getattr(args, "inter_mixup", False)
if self.mixup:
self.mixup_decoder_emb = args.inter_mixup_decoder_emb
str_mixup_layer = getattr(args, "inter_mixup_decoder_layer", "0")
if len(str_mixup_layer.split(",")) == 1:
self.mixup_layer = int(str_mixup_layer)
else:
self.mixup_layer = [int(layer) for layer in str_mixup_layer.split(",")]
logger.info("Use mixup in the decoder layer %s, mixup decoder embedding %r." % (
str_mixup_layer, self.mixup_decoder_emb))
def build_decoder_layer(self, args, no_encoder_attn=False):
layer = TransformerDecoderLayer(args, no_encoder_attn)
if getattr(args, "checkpoint_activations", False):
......@@ -974,6 +987,68 @@ class TransformerDecoder(FairseqIncrementalDecoder):
this function is made to be used in the subclass instead.
"""
def apply_mixup(self, encoder_out, x, self_attn_padding_mask):
mixup = encoder_out["mixup"]
coef = mixup["coef"]
idx1 = mixup["index1"]
idx2 = mixup["index2"]
flag = mixup["mixup_flag"]
if mixup["mixup_decoder_emb"]:
x1 = x[:, idx1]
x2 = x[:, idx2]
mixup_coef = coef.view(1, -1, 1)
x = mixup_coef * x1 + (1 - mixup_coef) * x2
x = x.contiguous()
if self_attn_padding_mask is not None:
pad1 = self_attn_padding_mask[idx1]
pad2 = self_attn_padding_mask[idx2]
self_attn_padding_mask = pad1 & pad2
else:
mix_idx1 = idx1[flag]
mix_idx2 = idx2[flag]
org_idx = idx1[~flag]
x1 = x[:, mix_idx1]
x2 = x[:, mix_idx2]
if self_attn_padding_mask is not None:
pad1 = self_attn_padding_mask[mix_idx1]
pad2 = self_attn_padding_mask[mix_idx2]
decoder_mixup_flag1 = [0] * len(org_idx)
decoder_mixup_flag2 = [0] * len(org_idx)
if len(org_idx) != 0:
org_x = x[:, org_idx]
x = torch.cat([org_x, x1, x2], dim=1)
if self_attn_padding_mask is not None:
org_pad = self_attn_padding_mask[org_idx]
self_attn_padding_mask = torch.cat([org_pad, pad1, pad2], dim=0)
else:
x = torch.cat([x1, x2], dim=1)
if self_attn_padding_mask is not None:
self_attn_padding_mask = torch.cat([pad1, pad2], dim=0)
decoder_mixup_flag1.extend([1] * len(mix_idx1))
decoder_mixup_flag1.extend([0] * len(mix_idx2))
decoder_mixup_flag2.extend([0] * len(mix_idx1))
decoder_mixup_flag2.extend([1] * len(mix_idx2))
mixup["decoder_mixup_flag1"] = torch.Tensor(decoder_mixup_flag1).to(x.device).bool()
mixup["decoder_mixup_flag2"] = torch.Tensor(decoder_mixup_flag2).to(x.device).bool()
encoder_rep = encoder_out["encoder_out"][0]
mixup_encoder_rep = encoder_rep[:, flag, :]
encoder_out["encoder_out"][0] = torch.cat([encoder_rep, mixup_encoder_rep], dim=1)
padding = encoder_out["encoder_padding_mask"][0]
mixup_padding = padding[flag, :]
encoder_out["encoder_padding_mask"][0] = torch.cat([padding, mixup_padding], dim=0)
return encoder_out, x, self_attn_padding_mask, mixup
def extract_features_scriptable(
self,
prev_output_tokens,
......@@ -1008,6 +1083,24 @@ class TransformerDecoder(FairseqIncrementalDecoder):
if alignment_layer is None:
alignment_layer = self.num_layers - 1
layer_idx = -1
bak_encoder_out = encoder_out["encoder_out"][0]
bak_encoder_padding_mask = encoder_out["encoder_padding_mask"][0]
do_mixup = False
mixup_layer = 0
mixup = None
if "mixup" in encoder_out and encoder_out["mixup"] is not None:
do_mixup = True
if type(self.mixup_layer) is list:
from random import choice
mixup_layer = choice(self.mixup_layer)
else:
mixup_layer = self.mixup_layer
if do_mixup and layer_idx == mixup_layer:
logger.warning("To DO!!!")
# embed positions
positions = None
if self.embed_positions is not None:
......@@ -1048,24 +1141,9 @@ class TransformerDecoder(FairseqIncrementalDecoder):
if self.cross_self_attention or prev_output_tokens.eq(self.padding_idx).any():
self_attn_padding_mask = prev_output_tokens.eq(self.padding_idx)
mixup = None
if "mixup" in encoder_out and encoder_out["mixup"] is not None:
mixup = encoder_out["mixup"]
coef = mixup["coef"]
idx1 = mixup["index1"]
idx2 = mixup["index2"]
x1 = x[:, idx1]
x2 = x[:, idx2]
mixup_coef = coef.view(1, -1, 1)
x = mixup_coef * x1 + (1 - mixup_coef) * x2
x = x.contiguous()
if self_attn_padding_mask is not None:
pad1 = self_attn_padding_mask[idx1]
pad2 = self_attn_padding_mask[idx2]
self_attn_padding_mask = pad1 & pad2
layer_idx += 1
if do_mixup and layer_idx == mixup_layer:
encoder_out, x, self_attn_padding_mask, mixup = self.apply_mixup(encoder_out, x, self_attn_padding_mask)
# decoder layers
avg_attn = None
......@@ -1109,6 +1187,10 @@ class TransformerDecoder(FairseqIncrementalDecoder):
else:
avg_attn += layer_attn
layer_idx += 1
if do_mixup and layer_idx == mixup_layer:
encoder_out, x, self_attn_padding_mask, mixup = self.apply_mixup(encoder_out, x, self_attn_padding_mask)
if self.gather_attn_weight:
avg_attn = avg_attn / len(self.layers)
attn = avg_attn.mean(0).sum(-2)
......@@ -1149,6 +1231,10 @@ class TransformerDecoder(FairseqIncrementalDecoder):
if self.project_out_dim is not None:
x = self.project_out_dim(x)
if do_mixup:
encoder_out["encoder_out"][0] = bak_encoder_out
encoder_out["encoder_padding_mask"][0] = bak_encoder_padding_mask
return x, {"attn": [attn], "inner_states": inner_states, "mixup": mixup}
......
......@@ -67,7 +67,7 @@ class Adapter(nn.Module):
self.cal_context = False
self.shrink = False
if self.adapter_type in ["linear", "league", "gated_league", "gated_league2"]:
if self.adapter_type in ["linear", "league", "gated_league", "gated_league2", "league_shrink"]:
self.cal_linear = True
self.linear_adapter = nn.Sequential(
nn.Linear(dim, 2 * dim),
......@@ -101,7 +101,7 @@ class Adapter(nn.Module):
self.shrink = True
logger.info("CTC Compress Strategy: %s" % ctc_compress_strategy)
if self.cal_context:
if self.cal_context or self.shrink:
self.distribution_cutoff = strategy.get("distribution_cutoff", None)
self.distribution_temperature = strategy.get("ctc_temperature", 1.0)
self.gumbel = strategy.get("gumbel", False)
......@@ -181,7 +181,7 @@ class Adapter(nn.Module):
elif self.adapter_type == "context":
out = soft_out
elif self.adapter_type in ["league", "inter_league_shrink"]:
elif self.adapter_type in ["league", "league_shrink"]:
if self.training and self.drop_prob > 0 and torch.rand(1).uniform_() < self.drop_prob:
if torch.rand(1).uniform_() < 0.5:
out = linear_out
......
......@@ -81,7 +81,7 @@ def last_n_checkpoints(paths, n, combine_choice, upper_bound=None, max_metric=Fa
pt_regexp = re.compile(r"checkpoint_\d+_(\d+)\.pt")
elif combine_choice == "best":
reverse = True if max_metric else False
pt_regexp = re.compile(r"checkpoint\.best_loss_\d+_(\d+\.?\d*)\.pt")
pt_regexp = re.compile(r"checkpoint\.best_\w+_\d+_(\d+\.?\d*)\.pt")
else:
pt_regexp = re.compile(r"checkpoint(\d+)\.pt")
......
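As a quick illustration of the relaxed pattern above, the new regular expression accepts best-checkpoint files saved under any metric name rather than only loss (the filenames below are hypothetical):

import re

old_regexp = re.compile(r"checkpoint\.best_loss_\d+_(\d+\.?\d*)\.pt")
new_regexp = re.compile(r"checkpoint\.best_\w+_\d+_(\d+\.?\d*)\.pt")

names = ["checkpoint.best_loss_10_2.35.pt", "checkpoint.best_bleu_10_27.4.pt"]
print([bool(old_regexp.fullmatch(n)) for n in names])  # [True, False]
print([bool(new_regexp.fullmatch(n)) for n in names])  # [True, True]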