Commit cbeb5521 by xuchen

update the mixup implementation

parent c7242ff4
@@ -3,6 +3,6 @@ inter-mixup-layer: -1
 inter-mixup-prob: 1.0
 inter-mixup-ratio: 1.0
 inter-mixup-beta: 0.5
-inter-mixup-keep-org: True
-ctc-mixup-consistent-weight: 1
-mixup-consistent-weight: 1
+inter-mixup-keep-org: False
+ctc-mixup-consistent-weight: 0
+mixup-consistent-weight: 0
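
For orientation, a minimal sketch (not from the repo) of how these options plausibly feed the mixup machinery touched below: `inter-mixup-beta` is assumed to parametrize a symmetric Beta distribution for the interpolation coefficients, `inter-mixup-ratio` scales how many mixed examples are drawn per batch, and `inter-mixup-keep-org` decides whether the original examples are kept alongside them.

from torch.distributions import Beta

# Hedged sketch: the names mirror the yaml keys above; the symmetric Beta(b, b)
# is an assumption about how inter-mixup-beta is used (the diff only shows
# `self.beta.sample(...)` inside apply_mixup).
mixup_beta = 0.5        # inter-mixup-beta
mixup_ratio = 1.0       # inter-mixup-ratio
mixup_keep_org = False  # inter-mixup-keep-org, flipped to False by this commit

beta = Beta(mixup_beta, mixup_beta)
batch = 8
mixup_size = int(batch * mixup_ratio)  # mixed examples per batch, as in apply_mixup
coef = beta.sample([mixup_size])       # one interpolation coefficient per mixed pair
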
@@ -203,6 +203,9 @@ class LabelSmoothedCrossEntropyCriterion(FairseqCriterion):
         else:
             target = target.view(-1)
+        if lprobs.size(0) == 0:
+            return torch.Tensor([0]), torch.Tensor([0])
         mask = target.ne(self.padding_idx)
         n_correct = torch.sum(
             lprobs.argmax(1).masked_select(mask).eq(target.masked_select(mask))
......
@@ -79,9 +79,19 @@ class LabelSmoothedCrossEntropyCriterionWithCTC(
         n_sentences = sample["target"].size(0)
         if "mixup" in encoder_out and encoder_out["mixup"] is not None:
-            sample_size //= net_output[0].size(0) if self.sentence_avg else encoder_out["mixup"]["ratio"]
-            n_tokens //= encoder_out["mixup"]["ratio"]
-            n_sentences //= net_output[0].size(0)
+            mixup = encoder_out["mixup"]
+            ratio = mixup["ratio"]
+            if mixup["keep_org"]:
+                n_tokens = int(sample_size * (1 + ratio))
+            else:
+                n_tokens = int(sample_size * ratio)
+            if self.sentence_avg:
+                sample_size = net_output[0].size(0)
+            else:
+                sample_size = n_tokens
+            n_sentences = net_output[0].size(0)
         logging_output = {
             "trans_loss": utils.item(loss.data) if reduce else loss.data,
......
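
A compact sketch of the bookkeeping the new branch performs (an illustrative helper, not part of the repo): with `keep_org` the post-mixup batch contains the original sentences plus the mixed ones, so the effective token count grows by a factor of (1 + ratio); without it only the mixed portion counts, and `sentence_avg` switches the normalizer between the post-mixup batch size and that token count.

# Hypothetical helper mirroring the hunk above; `batch_after_mixup` stands in
# for net_output[0].size(0).
def mixup_counts(sample_size, batch_after_mixup, ratio, keep_org, sentence_avg):
    # keep_org: originals + mixed examples -> counts scale by (1 + ratio)
    n_tokens = int(sample_size * (1 + ratio)) if keep_org else int(sample_size * ratio)
    # normalize by sentences (the post-mixup batch) or by the adjusted token count
    sample_size = batch_after_mixup if sentence_avg else n_tokens
    n_sentences = batch_after_mixup
    return sample_size, n_tokens, n_sentences
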
@@ -825,27 +825,6 @@ class S2TTransformerEncoder(FairseqEncoder):
         batch = x.size(1)
         org_indices = np.arange(batch)
-        # indices = np.random.permutation(batch)
-        # if self.mixup_ratio == 1:
-        #     if len(indices) % 2 != 0:
-        #         indices = np.append(indices, (indices[-1]))
-        #     idx1 = indices[0::2]
-        #     idx2 = indices[1::2]
-        #
-        #     if self.mixup_keep_org:
-        #         idx1 = np.append(org_indices, idx1)
-        #         idx2 = np.append(org_indices, idx2)
-        #
-        # else:
-        #     mix_size = int(max(2, batch * self.mixup_ratio // 2 * 2))
-        #     mix_indices = indices[: mix_size]
-        #     if self.mixup_keep_org:
-        #         idx1 = np.append(org_indices, mix_indices[0::2])
-        #         idx2 = np.append(org_indices, mix_indices[1::2])
-        #     else:
-        #         idx1 = np.append(mix_indices[0::2], (indices[mix_size:]))
-        #         idx2 = np.append(mix_indices[1::2], (indices[mix_size:]))
         mixup_size = int(batch * self.mixup_ratio)
         if mixup_size <= batch:
             mixup_index1 = np.random.permutation(mixup_size)
@@ -853,6 +832,7 @@ class S2TTransformerEncoder(FairseqEncoder):
         else:
             mixup_index1 = np.random.randint(0, batch, mixup_size)
             mixup_index2 = np.random.randint(0, batch, mixup_size)
         if self.mixup_keep_org:
             idx1 = np.append(org_indices, mixup_index1)
             idx2 = np.append(org_indices, mixup_index2)
@@ -864,15 +844,16 @@ class S2TTransformerEncoder(FairseqEncoder):
             idx1 = np.append(keep_indices, mixup_index1)
             idx2 = np.append(keep_indices, mixup_index2)
-        idx1 = torch.from_numpy(idx1).to(x.device)
-        idx2 = torch.from_numpy(idx2).to(x.device)
+        idx1 = torch.from_numpy(idx1).to(x.device).long()
+        idx2 = torch.from_numpy(idx2).to(x.device).long()
         x1 = x[:, idx1]
         x2 = x[:, idx2]
         coef = self.beta.sample([len(idx1)]).to(x.device).type_as(x).view(-1)
         mixup_coef = coef.view(1, -1, 1)
-        x = (mixup_coef * x1 + (1 - mixup_coef) * x2)
+        x = mixup_coef * x1 + (1 - mixup_coef) * x2
+        x = x.contiguous()
         pad1 = encoder_padding_mask[idx1]
         pad2 = encoder_padding_mask[idx2]
@@ -881,6 +862,7 @@ class S2TTransformerEncoder(FairseqEncoder):
         mixup = {
             "ratio": self.mixup_ratio,
+            "keep_org": self.mixup_keep_org,
             "coef": coef,
             "index1": idx1,
             "index2": idx2,
@@ -1046,6 +1028,10 @@ class S2TTransformerEncoder(FairseqEncoder):
         x = self.layer_norm(x)
         self.show_debug(x, "x after encoding layer norm")
+        if self.training and self.mixup and layer_idx == mixup_layer:
+            if torch.rand(1) < self.mixup_prob:
+                x, encoder_padding_mask, input_lengths, mixup = self.apply_mixup(x, encoder_padding_mask)
         if self.use_ctc and ctc_logit is None:
             ctc_logit = self.ctc(x, encoder_padding_mask, "Source output", is_top=True)
             self.show_debug(x, "x after ctc")
......
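
Pulling the hunks above together, here is a self-contained sketch of the index selection and interpolation that `apply_mixup` now performs. The function name and standalone form are illustrative; the second permutation, the padding-mask combination, and the `keep_indices` path for partially mixed batches are truncated in the diff, so they are filled in by assumption and marked as such.

import numpy as np
import torch
from torch.distributions import Beta

def apply_mixup_sketch(x, encoder_padding_mask, mixup_ratio=1.0, keep_org=False, beta=0.5):
    # x: T x B x C encoder states, encoder_padding_mask: B x T bool mask,
    # matching the layout used in the surrounding code.
    batch = x.size(1)
    org_indices = np.arange(batch)

    mixup_size = int(batch * mixup_ratio)
    if mixup_size <= batch:
        mixup_index1 = np.random.permutation(mixup_size)
        mixup_index2 = np.random.permutation(mixup_size)  # assumed to mirror index1; truncated in the diff
    else:
        mixup_index1 = np.random.randint(0, batch, mixup_size)
        mixup_index2 = np.random.randint(0, batch, mixup_size)

    if keep_org:
        # keep every original sentence and append the mixed ones behind it
        idx1 = np.append(org_indices, mixup_index1)
        idx2 = np.append(org_indices, mixup_index2)
    else:
        # the repo additionally appends keep_indices when mixup_ratio < 1; omitted here
        idx1, idx2 = mixup_index1, mixup_index2

    idx1 = torch.from_numpy(idx1).to(x.device).long()
    idx2 = torch.from_numpy(idx2).to(x.device).long()

    coef = Beta(beta, beta).sample([len(idx1)]).to(x.device).type_as(x).view(-1)
    mixup_coef = coef.view(1, -1, 1)
    x = (mixup_coef * x[:, idx1] + (1 - mixup_coef) * x[:, idx2]).contiguous()

    # assumed combination: a position is padding only if it is padding in both sources
    encoder_padding_mask = encoder_padding_mask[idx1] & encoder_padding_mask[idx2]

    mixup = {"ratio": mixup_ratio, "keep_org": keep_org, "coef": coef,
             "index1": idx1, "index2": idx2}
    return x, encoder_padding_mask, mixup

In this sketch, with keep_org=False and ratio 1.0 (the values the config now uses), the batch size is unchanged and every example in it is a mixture of two originals.
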
@@ -1060,6 +1060,7 @@ class TransformerDecoder(FairseqIncrementalDecoder):
         x2 = x[:, idx2]
         mixup_coef = coef.view(1, -1, 1)
         x = mixup_coef * x1 + (1 - mixup_coef) * x2
+        x = x.contiguous()
         if self_attn_padding_mask is not None:
             pad1 = self_attn_padding_mask[idx1]
......
@@ -36,7 +36,6 @@ from fairseq.modules.checkpoint_activations import checkpoint_wrapper
 from fairseq.modules.quant_noise import quant_noise as apply_quant_noise_
 from torch import Tensor
-
 DEFAULT_MAX_SOURCE_POSITIONS = 1024
 DEFAULT_MAX_TARGET_POSITIONS = 1024
@@ -66,13 +65,13 @@ class TransformerS2Encoder(TransformerEncoder):
         return layer

     def forward(
         self,
         src_tokens,
         src_lengths: Optional[torch.Tensor] = None,
-        x2 = None,
-        x2_encoder_padding_mask = None,
+        s2=None,
+        s2_encoder_padding_mask=None,
         return_all_hiddens: bool = False,
         token_embeddings: Optional[torch.Tensor] = None,
     ):
         """
         Args:
@@ -99,8 +98,8 @@ class TransformerS2Encoder(TransformerEncoder):
         """
         return self.forward_scriptable(src_tokens,
                                        src_lengths,
-                                       x2,
-                                       x2_encoder_padding_mask,
+                                       s2,
+                                       s2_encoder_padding_mask,
                                        return_all_hiddens,
                                        token_embeddings)
@@ -109,13 +108,13 @@ class TransformerS2Encoder(TransformerEncoder):
     # Current workaround is to add a helper function with different name and
     # call the helper function from scriptable Subclass.
     def forward_scriptable(
         self,
         src_tokens,
         src_lengths: Optional[torch.Tensor] = None,
-        x2=None,
-        x2_encoder_padding_mask=None,
+        s2=None,
+        s2_encoder_padding_mask=None,
         return_all_hiddens: bool = False,
         token_embeddings: Optional[torch.Tensor] = None,
     ):
         """
         Args:
@@ -172,7 +171,7 @@ class TransformerS2Encoder(TransformerEncoder):
             x = layer(
                 x, encoder_padding_mask=encoder_padding_mask if has_pads else None,
-                x2=x2, x2_encoder_padding_mask=x2_encoder_padding_mask,
+                s2=s2, s2_encoder_padding_mask=s2_encoder_padding_mask,
             )
             if return_all_hiddens:
                 assert encoder_states is not None
@@ -194,8 +193,8 @@ class TransformerS2Encoder(TransformerEncoder):
         return {
             "encoder_out": [x],  # T x B x C
             "encoder_padding_mask": [encoder_padding_mask],  # B x T
-            "encoder_out_s2": [x2],  # T x B x C
-            "encoder_padding_mask_s2": [x2_encoder_padding_mask],  # B x T
+            "s2_encoder_out": [s2],  # T x B x C
+            "s2_encoder_padding_mask": [s2_encoder_padding_mask],  # B x T
             "encoder_embedding": [encoder_embedding],  # B x T x C
             "encoder_states": encoder_states,  # List[T x B x C]
             "src_tokens": [],
@@ -229,13 +228,13 @@ class TransformerS2Decoder(TransformerDecoder):
         return layer

     def extract_features_scriptable(
         self,
         prev_output_tokens,
         encoder_out: Optional[Dict[str, List[Tensor]]],
         incremental_state: Optional[Dict[str, Dict[str, Optional[Tensor]]]] = None,
         full_context_alignment: bool = False,
         alignment_layer: Optional[int] = None,
         alignment_heads: Optional[int] = None,
     ):
         """
         Similar to *forward* but only return features.
@@ -339,12 +338,13 @@ class TransformerS2Decoder(TransformerDecoder):
                 else None,
                 encoder_out["encoder_padding_mask"][0]
                 if (
                     encoder_out is not None
                     and len(encoder_out["encoder_padding_mask"]) > 0
                 )
                 else None,
-                encoder_out_s2=encoder_out["s2_encoder_out"][0],
-                encoder_padding_mask_s2=encoder_out["s2_encoder_padding_mask"][0],
+                encoder_out_s2=encoder_out["s2_encoder_out"][0] if len(encoder_out["s2_encoder_out"]) > 0 else None,
+                encoder_padding_mask_s2=encoder_out["s2_encoder_padding_mask"][0] if len(
+                    encoder_out["s2_encoder_padding_mask"]) > 0 else None,
                 incremental_state=incremental_state,
                 self_attn_mask=self_attn_mask,
                 self_attn_padding_mask=self_attn_padding_mask,
@@ -411,4 +411,4 @@ def Embedding(num_embeddings, embedding_dim, padding_idx):
     m = nn.Embedding(num_embeddings, embedding_dim, padding_idx=padding_idx)
     nn.init.normal_(m.weight, mean=0, std=embedding_dim ** -0.5)
     nn.init.constant_(m.weight[padding_idx], 0)
     return m
\ No newline at end of file
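
Downstream callers follow the rename as well; a hedged usage sketch (the wrapper function and variable names are hypothetical, while the keyword names and output keys come from the hunks above):

# Hedged usage sketch of the renamed s2 interface. `encoder` is assumed to be an
# instantiated TransformerS2Encoder; src_tokens/src_lengths and the auxiliary
# stream s2_out (T x B x C) with its B x T padding mask are hypothetical inputs.
def run_with_s2(encoder, src_tokens, src_lengths, s2_out, s2_padding_mask):
    encoder_out = encoder(
        src_tokens,
        src_lengths=src_lengths,
        s2=s2_out,                                # renamed from x2
        s2_encoder_padding_mask=s2_padding_mask,  # renamed from x2_encoder_padding_mask
    )
    # The decoder-side lookup now matches these keys, falling back to None when
    # the lists are empty (see the extract_features_scriptable hunk above).
    s2_states = encoder_out["s2_encoder_out"][0] if len(encoder_out["s2_encoder_out"]) > 0 else None
    s2_padding = encoder_out["s2_encoder_padding_mask"][0] if len(encoder_out["s2_encoder_padding_mask"]) > 0 else None
    return s2_states, s2_padding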