Commit c845197f by xuchen

fix mixup bugs and support manifold mixup

parent 37eaeb25
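For context: manifold mixup interpolates hidden representations at a randomly chosen encoder layer rather than only at the input, which is what this commit adds. A minimal sketch of the idea, assuming the (seq_len, batch, dim) layout used in fairseq encoders; all names here are illustrative, not this repository's API:

```python
from random import choice

import torch
from torch.distributions import Beta


def forward_with_manifold_mixup(x, layers, mix_layers, beta=0.5):
    """x: (seq_len, batch, dim); layers: encoder layers (nn.Modules);
    mix_layers: candidate layer indices, one drawn per forward pass."""
    mix_at = choice(mix_layers)
    for idx, layer in enumerate(layers):
        if idx == mix_at:
            # interpolate each sample's hidden state with a random partner
            lam = Beta(torch.tensor([beta]), torch.tensor([beta])).sample()
            perm = torch.randperm(x.size(1))
            x = lam * x + (1 - lam) * x[:, perm]
        x = layer(x)
    return x
```

The diff below wires the same idea into S2TTransformerEncoder through the reworked `--inter-mixup-layer` option.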
@@ -3,6 +3,8 @@ import math
 from typing import Dict, List, Optional, Tuple
 import numpy as np
+from random import choice
 import torch
 import torch.nn as nn
@@ -493,8 +495,8 @@ class S2TTransformerModel(FairseqEncoderDecoderModel):
         )
         parser.add_argument(
             "--inter-mixup-layer",
-            default=None,
-            type=int,
+            default="-1",
+            type=str,
             help="the layers to apply mixup",
         )
         parser.add_argument(
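The flag now accepts either a single index (e.g. `--inter-mixup-layer 6`) or a comma-separated list (e.g. `--inter-mixup-layer 0,4,8`) from which one layer is drawn per forward pass; the default "-1" targets the position before the first layer, since layer_idx starts at -1 in the forward pass below.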
@@ -750,7 +752,11 @@ class S2TTransformerEncoder(FairseqEncoder):
         # mixup
         self.mixup = getattr(args, "inter_mixup", False)
         if self.mixup:
-            self.mixup_layer = args.inter_mixup_layer
+            str_mixup_layer = args.inter_mixup_layer
+            if len(str_mixup_layer.split(",")) == 1:
+                self.mixup_layer = int(str_mixup_layer)
+            else:
+                self.mixup_layer = [int(layer) for layer in str_mixup_layer.split(",")]
             self.mixup_prob = args.inter_mixup_prob
             self.mixup_ratio = args.inter_mixup_ratio
             self.mixup_keep_org = args.inter_mixup_keep_org
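The parsing rule: a bare integer stays an int, a comma-separated string becomes a list of ints. The same logic as a standalone, testable sketch (the function name is hypothetical):

```python
def parse_mixup_layer(s: str):
    parts = s.split(",")
    return int(parts[0]) if len(parts) == 1 else [int(p) for p in parts]


assert parse_mixup_layer("-1") == -1            # single layer -> int
assert parse_mixup_layer("0,4,8") == [0, 4, 8]  # several layers -> list
```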
@@ -758,8 +764,8 @@ class S2TTransformerEncoder(FairseqEncoder):
             beta = args.inter_mixup_beta
             from torch.distributions import Beta
             self.beta = Beta(torch.Tensor([beta]), torch.Tensor([beta]))
-            logger.info("Use mixup in layer %d with beta %.2f, prob %.2f, ratio %.2f, keep original data %r." % (
-                self.mixup_layer, beta, self.mixup_prob, self.mixup_ratio, self.mixup_keep_org))
+            logger.info("Use mixup in layer %s with beta %.2f, prob %.2f, ratio %.2f, keep original data %r." % (
+                str_mixup_layer, beta, self.mixup_prob, self.mixup_ratio, self.mixup_keep_org))

         # gather cosine similarity
         self.gather_cos_sim = getattr(args, "gather_cos_sim", False)
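self.beta supplies the interpolation weight that apply_mixup uses. A hedged sketch of how a Beta-drawn coefficient is typically applied, with illustrative tensors; note that Beta(0.5, 0.5) is U-shaped, so most draws land near 0 or 1 and a mixed sample stays close to one of its two sources:

```python
import torch
from torch.distributions import Beta

beta = Beta(torch.Tensor([0.5]), torch.Tensor([0.5]))
lam = beta.sample()                # interpolation weight in (0, 1)
x1, x2 = torch.randn(10, 8), torch.randn(10, 8)
x_mix = lam * x1 + (1 - lam) * x2  # convex combination of two samples
```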
@@ -817,8 +823,9 @@ class S2TTransformerEncoder(FairseqEncoder):
     def apply_mixup(self, x, encoder_padding_mask):
         batch = x.size(1)
-        indices = np.random.permutation(batch)
+        org_indices = np.arange(batch)
+        # indices = np.random.permutation(batch)
         # if self.mixup_ratio == 1:
         #     if len(indices) % 2 != 0:
         #         indices = np.append(indices, (indices[-1]))
@@ -840,8 +847,12 @@ class S2TTransformerEncoder(FairseqEncoder):
         #     idx2 = np.append(mix_indices[1::2], (indices[mix_size:]))
         mixup_size = int(batch * self.mixup_ratio)
-        mixup_index1 = np.random.permutation(mixup_size)
-        mixup_index2 = np.random.permutation(mixup_size)
+        if mixup_size <= batch:
+            mixup_index1 = np.random.permutation(mixup_size)
+            mixup_index2 = np.random.permutation(mixup_size)
+        else:
+            mixup_index1 = np.random.randint(0, batch, mixup_size)
+            mixup_index2 = np.random.randint(0, batch, mixup_size)
         if self.mixup_keep_org:
             idx1 = np.append(org_indices, mixup_index1)
             idx2 = np.append(org_indices, mixup_index2)
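This branch fixes the out-of-range bug behind the commit message: np.random.permutation(mixup_size) returns a shuffle of 0..mixup_size-1, so with inter-mixup-ratio > 1 it produced indices beyond the batch. Drawing with replacement avoids that:

```python
import numpy as np

batch = 4
mixup_size = int(batch * 1.5)  # ratio > 1 -> 6 mixup pairs from 4 samples
# np.random.permutation(mixup_size) would emit indices up to 5, out of
# range for a batch of 4; randint samples valid indices with replacement
idx1 = np.random.randint(0, batch, mixup_size)
idx2 = np.random.randint(0, batch, mixup_size)
```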
@@ -888,6 +899,10 @@ class S2TTransformerEncoder(FairseqEncoder):
         layer_idx = -1
         mixup = None
+        if type(self.mixup_layer) is list:
+            mixup_layer = choice(self.mixup_layer)
+        else:
+            mixup_layer = self.mixup_layer
         if self.history is not None:
             self.history.clean()
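Drawing the layer once, before the loop, keeps the three gate sites below consistent within a forward pass. The dispatch as a standalone sketch (using isinstance in place of the type(...) is list test):

```python
from random import choice


def draw_mixup_layer(cfg):
    # cfg is an int (fixed layer) or a list of ints (manifold mixup)
    return choice(cfg) if isinstance(cfg, list) else cfg
```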
@@ -903,7 +918,7 @@ class S2TTransformerEncoder(FairseqEncoder):
         if self.gather_cos_sim:
             self.add_to_dict(x, dis, cos_sim_idx)

-        if self.training and self.mixup and layer_idx == self.mixup_layer:
+        if self.training and self.mixup and layer_idx == mixup_layer:
             if torch.rand(1) < self.mixup_prob:
                 encoder_padding_mask = lengths_to_padding_mask(input_lengths)
                 x, encoder_padding_mask, input_lengths, mixup = self.apply_mixup(x, encoder_padding_mask)
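The inner check is a per-batch Bernoulli gate: mixup fires on roughly inter-mixup-prob of the training batches that reach the chosen layer. As a sketch:

```python
import torch


def mixup_fires(mixup_prob: float) -> bool:
    # true with probability mixup_prob per forward pass
    return bool(torch.rand(1) < mixup_prob)
```

(The second gate site below uses <= where the other two use <; that inconsistency predates this commit.)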
@@ -957,7 +972,7 @@ class S2TTransformerEncoder(FairseqEncoder):
         ctc_logit = None
         interleaved_ctc_logits = []

-        if self.training and self.mixup and layer_idx == self.mixup_layer:
+        if self.training and self.mixup and layer_idx == mixup_layer:
             if torch.rand(1) <= self.mixup_prob:
                 x, encoder_padding_mask, input_lengths, mixup = self.apply_mixup(x, encoder_padding_mask)
@@ -971,7 +986,7 @@ class S2TTransformerEncoder(FairseqEncoder):
             layer_idx += 1
             self.show_debug(x, "x after layer %d" % layer_idx)

-            if self.training and self.mixup and layer_idx == self.mixup_layer:
+            if self.training and self.mixup and layer_idx == mixup_layer:
                 if torch.rand(1) < self.mixup_prob:
                     x, encoder_padding_mask, input_lengths, mixup = self.apply_mixup(x, encoder_padding_mask)
@@ -1236,7 +1251,7 @@ def base_architecture(args):
     # mixup
     args.inter_mixup = getattr(args, "inter_mixup", False)
-    args.inter_mixup_layer = getattr(args, "inter_mixup_layer", None)
+    args.inter_mixup_layer = getattr(args, "inter_mixup_layer", "-1")
     args.inter_mixup_beta = getattr(args, "inter_mixup_beta", 0.5)
     args.inter_mixup_prob = getattr(args, "inter_mixup_prob", 1)
     args.inter_mixup_ratio = getattr(args, "inter_mixup_ratio", 0.3)
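Putting the defaults together, a hypothetical configuration enabling manifold mixup; the layer list and the keep-org value are illustrative, the rest mirror the defaults above:

```python
from argparse import Namespace

args = Namespace(
    inter_mixup=True,
    inter_mixup_layer="0,4,8",  # one of these layers drawn per forward pass
    inter_mixup_beta=0.5,
    inter_mixup_prob=1,
    inter_mixup_ratio=0.3,
    inter_mixup_keep_org=False,  # illustrative; default not shown in this diff
)
```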