Commit c845197f by xuchen

fix the mixup bugs and support manifold mixup

parent 37eaeb25
@@ -3,6 +3,8 @@ import math
 from typing import Dict, List, Optional, Tuple

 import numpy as np
+from random import choice
 import torch
 import torch.nn as nn
@@ -493,8 +495,8 @@ class S2TTransformerModel(FairseqEncoderDecoderModel):
         )
         parser.add_argument(
             "--inter-mixup-layer",
-            default=None,
-            type=int,
+            default="-1",
+            type=str,
             help="the layers to apply mixup",
         )
         parser.add_argument(
@@ -750,7 +752,11 @@ class S2TTransformerEncoder(FairseqEncoder):
         # mixup
         self.mixup = getattr(args, "inter_mixup", False)
         if self.mixup:
-            self.mixup_layer = args.inter_mixup_layer
+            str_mixup_layer = args.inter_mixup_layer
+            if len(str_mixup_layer.split(",")) == 1:
+                self.mixup_layer = int(str_mixup_layer)
+            else:
+                self.mixup_layer = [int(layer) for layer in str_mixup_layer.split(",")]
             self.mixup_prob = args.inter_mixup_prob
             self.mixup_ratio = args.inter_mixup_ratio
             self.mixup_keep_org = args.inter_mixup_keep_org
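Note on the new option: --inter-mixup-layer is now parsed as a comma-separated string, so a single value keeps the old single-layer behaviour and multiple values enable manifold mixup. A minimal sketch of the parsing above (the helper name is illustrative, not part of this file):

    # Sketch of how the new --inter-mixup-layer string is interpreted (mirrors the code above).
    def parse_mixup_layer(str_mixup_layer: str):
        parts = str_mixup_layer.split(",")
        if len(parts) == 1:
            return int(str_mixup_layer)            # e.g. "-1"    -> -1
        return [int(layer) for layer in parts]     # e.g. "0,3,6" -> [0, 3, 6]

    assert parse_mixup_layer("-1") == -1
    assert parse_mixup_layer("0,3,6") == [0, 3, 6]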
@@ -758,8 +764,8 @@ class S2TTransformerEncoder(FairseqEncoder):
             beta = args.inter_mixup_beta
             from torch.distributions import Beta
             self.beta = Beta(torch.Tensor([beta]), torch.Tensor([beta]))
-            logger.info("Use mixup in layer %d with beta %.2f, prob %.2f, ratio %.2f, keep original data %r." % (
-                self.mixup_layer, beta, self.mixup_prob, self.mixup_ratio, self.mixup_keep_org))
+            logger.info("Use mixup in layer %s with beta %.2f, prob %.2f, ratio %.2f, keep original data %r." % (
+                str_mixup_layer, beta, self.mixup_prob, self.mixup_ratio, self.mixup_keep_org))

         # gather cosine similarity
         self.gather_cos_sim = getattr(args, "gather_cos_sim", False)
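For context, self.beta above is the Beta(β, β) distribution that supplies the mixing coefficient. A minimal sketch of how a sampled coefficient interpolates two batches of hidden states under the usual mixup formulation; only the Beta sampling is taken from this diff, the function name and the details inside apply_mixup are assumptions:

    import torch
    from torch.distributions import Beta

    def mix_hidden_states(x1, x2, beta=0.5):
        # Sample lambda ~ Beta(beta, beta) and linearly interpolate the two tensors.
        coef = Beta(torch.Tensor([beta]), torch.Tensor([beta])).sample().type_as(x1)
        return coef * x1 + (1.0 - coef) * x2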
@@ -817,8 +823,9 @@ class S2TTransformerEncoder(FairseqEncoder):
     def apply_mixup(self, x, encoder_padding_mask):
         batch = x.size(1)
-        indices = np.random.permutation(batch)
         org_indices = np.arange(batch)
+        # indices = np.random.permutation(batch)

         # if self.mixup_ratio == 1:
         #     if len(indices) % 2 != 0:
         #         indices = np.append(indices, (indices[-1]))
@@ -840,8 +847,12 @@ class S2TTransformerEncoder(FairseqEncoder):
         # idx2 = np.append(mix_indices[1::2], (indices[mix_size:]))

         mixup_size = int(batch * self.mixup_ratio)
-        mixup_index1 = np.random.permutation(mixup_size)
-        mixup_index2 = np.random.permutation(mixup_size)
+        if mixup_size <= batch:
+            mixup_index1 = np.random.permutation(mixup_size)
+            mixup_index2 = np.random.permutation(mixup_size)
+        else:
+            mixup_index1 = np.random.randint(0, batch, mixup_size)
+            mixup_index2 = np.random.randint(0, batch, mixup_size)
         if self.mixup_keep_org:
             idx1 = np.append(org_indices, mixup_index1)
             idx2 = np.append(org_indices, mixup_index2)
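A self-contained sketch of the sampling behaviour introduced here: when mixup_ratio <= 1 the mixed pairs come from permutations of the first mixup_size positions, when mixup_ratio > 1 (so mixup_size > batch) they are drawn with replacement via randint, and with keep_org the original indices are prepended so the untouched samples stay in the batch. The helper name is illustrative:

    import numpy as np

    def sample_mixup_indices(batch, mixup_ratio, keep_org):
        org_indices = np.arange(batch)
        mixup_size = int(batch * mixup_ratio)
        if mixup_size <= batch:
            idx1 = np.random.permutation(mixup_size)
            idx2 = np.random.permutation(mixup_size)
        else:
            idx1 = np.random.randint(0, batch, mixup_size)
            idx2 = np.random.randint(0, batch, mixup_size)
        if keep_org:
            idx1 = np.append(org_indices, idx1)
            idx2 = np.append(org_indices, idx2)
        return idx1, idx2

    # e.g. batch=8, ratio=0.5, keep_org=True -> 12 entries per index array (8 original + 4 mixed)
    idx1, idx2 = sample_mixup_indices(8, 0.5, True)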
@@ -888,6 +899,10 @@ class S2TTransformerEncoder(FairseqEncoder):
         layer_idx = -1
         mixup = None
+        if type(self.mixup_layer) is list:
+            mixup_layer = choice(self.mixup_layer)
+        else:
+            mixup_layer = self.mixup_layer

         if self.history is not None:
             self.history.clean()
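This is the block that actually enables manifold mixup: when --inter-mixup-layer lists several layers, random.choice draws one of them per forward pass, so the interpolation happens at a randomly selected encoder layer rather than a fixed one. A tiny illustration (values are examples):

    from random import choice

    mixup_layer_cfg = [0, 3, 6]     # parsed from "--inter-mixup-layer 0,3,6"
    mixup_layer = choice(mixup_layer_cfg) if isinstance(mixup_layer_cfg, list) else mixup_layer_cfg
    # each training step may mix at a different layer; a plain int keeps the old single-layer behaviour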
@@ -903,7 +918,7 @@ class S2TTransformerEncoder(FairseqEncoder):
             if self.gather_cos_sim:
                 self.add_to_dict(x, dis, cos_sim_idx)

-        if self.training and self.mixup and layer_idx == self.mixup_layer:
+        if self.training and self.mixup and layer_idx == mixup_layer:
             if torch.rand(1) < self.mixup_prob:
                 encoder_padding_mask = lengths_to_padding_mask(input_lengths)
                 x, encoder_padding_mask, input_lengths, mixup = self.apply_mixup(x, encoder_padding_mask)
@@ -957,7 +972,7 @@ class S2TTransformerEncoder(FairseqEncoder):
         ctc_logit = None
         interleaved_ctc_logits = []

-        if self.training and self.mixup and layer_idx == self.mixup_layer:
+        if self.training and self.mixup and layer_idx == mixup_layer:
             if torch.rand(1) <= self.mixup_prob:
                 x, encoder_padding_mask, input_lengths, mixup = self.apply_mixup(x, encoder_padding_mask)
@@ -971,7 +986,7 @@ class S2TTransformerEncoder(FairseqEncoder):
             layer_idx += 1
             self.show_debug(x, "x after layer %d" % layer_idx)

-            if self.training and self.mixup and layer_idx == self.mixup_layer:
+            if self.training and self.mixup and layer_idx == mixup_layer:
                 if torch.rand(1) < self.mixup_prob:
                     x, encoder_padding_mask, input_lengths, mixup = self.apply_mixup(x, encoder_padding_mask)
@@ -1236,7 +1251,7 @@ def base_architecture(args):
     # mixup
     args.inter_mixup = getattr(args, "inter_mixup", False)
-    args.inter_mixup_layer = getattr(args, "inter_mixup_layer", None)
+    args.inter_mixup_layer = getattr(args, "inter_mixup_layer", "-1")
     args.inter_mixup_beta = getattr(args, "inter_mixup_beta", 0.5)
     args.inter_mixup_prob = getattr(args, "inter_mixup_prob", 1)
     args.inter_mixup_ratio = getattr(args, "inter_mixup_ratio", 0.3)
...
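With these defaults, a training run would enable the feature with flags along the following lines; the spellings are inferred from the argument and attribute names in this file (only --inter-mixup-layer appears verbatim in the diff), and the values are purely illustrative:

    --inter-mixup
    --inter-mixup-layer 0,3,6
    --inter-mixup-beta 0.5
    --inter-mixup-prob 1.0
    --inter-mixup-ratio 0.3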