Commit ca4271f2 by xuchen

use inf instead of 1e8 or 1e-8

parent 976237ec
@@ -64,9 +64,9 @@ class ESPNETMultiHeadedAttention(nn.Module):
         if kwargs.get("cal_localness", False) and not self.encoder_decoder_attention:
             self.cal_localness = True
             self.localness_window = kwargs.get("localness_window", 0.1)
-        if kwargs.get("cal_entropy", False):
+        if kwargs.get("cal_entropy", False): # and self.encoder_decoder_attention:
             self.cal_entropy = True
-        if kwargs.get("cal_topk_cross_attn_weights", False) and self.encoder_decoder_attention:
+        if kwargs.get("cal_topk_cross_attn_weights", False):
             self.cal_topk = True
             self.weights_topk = kwargs.get("topk_cross_attn_weights", 1)
         if kwargs.get("cal_monotonic_cross_attn_weights", False) and self.encoder_decoder_attention:
@@ -74,7 +74,7 @@ class ESPNETMultiHeadedAttention(nn.Module):
     def dump(self, fstream, info):
         if self.cal_localness:
-            print("%s window size: %f localness: %.2f" % (info, self.localness_window, self.localness), file=fstream)
+            print("%s window size: %.2f localness: %.4f" % (info, self.localness_window, self.localness), file=fstream)
         if self.cal_entropy:
             print("%s Entropy: %.2f" % (info, self.entropy), file=fstream)
@@ -119,8 +119,8 @@ class ESPNETMultiHeadedAttention(nn.Module):
         if mask is not None:
             scores = scores.masked_fill(
                 mask.unsqueeze(1).unsqueeze(2).to(bool),
-                -1e8 if scores.dtype == torch.float32 else -1e4
-                # float("-inf"), # (batch, head, time1, time2)
+                # -1e8 if scores.dtype == torch.float32 else -1e4
+                float("-inf"), # (batch, head, time1, time2)
             )
         # self.attn = torch.softmax(scores, dim=-1) # (batch, head, time1, time2)
...
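The hunk above switches the score mask from a dtype-dependent finite constant to float("-inf"). A minimal sketch of the same pattern (not part of the commit; shapes and values are illustrative) showing the effect, and why the old constant needed a dtype switch:

import torch

scores = torch.randn(1, 1, 2, 4)                   # (batch, head, time1, time2)
mask = torch.tensor([[False, False, True, True]])  # True marks padded key positions

scores = scores.masked_fill(
    mask.unsqueeze(1).unsqueeze(2).to(bool),       # broadcast to (batch, 1, 1, time2)
    float("-inf"),
)
attn = torch.softmax(scores, dim=-1)
print(attn[0, 0])  # padded columns receive exactly zero weight

# float("-inf") is representable in every dtype; the old constant -1e8
# overflows float16, which is why the code fell back to -1e4 there:
print(torch.tensor(-1e8, dtype=torch.float16))     # tensor(-inf, dtype=torch.float16)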
@@ -286,7 +286,7 @@ class PDSTransformerEncoderLayer(nn.Module):
         # the attention weight (before softmax) for some padded element in query
         # will become -inf, which results in NaN in model parameters
         if attn_mask is not None:
-            attn_mask = attn_mask.masked_fill(attn_mask.to(torch.bool), -1e8)
+            attn_mask = attn_mask.masked_fill(attn_mask.to(torch.bool), -float('inf'))
         # whether to use macaron style
         if self.macaron_norm is not None:
...
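The comment carried through this and the following hunks records the known trade-off of -inf masking: if every key position in a query row is padded, the whole score row becomes -inf and softmax over it yields NaN, which then propagates into gradients and parameters. A quick sketch of both behaviors (illustrative, not from the commit):

import torch

# Fully masked row: softmax over all -inf is NaN.
print(torch.softmax(torch.full((4,), float("-inf")), dim=-1))  # tensor([nan, nan, nan, nan])

# A large finite constant degenerates the same row to a uniform
# distribution instead: numerically safe, though the row is meaningless.
print(torch.softmax(torch.full((4,), -1e4), dim=-1))           # tensor([0.25, 0.25, 0.25, 0.25])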
@@ -252,7 +252,7 @@ class S2TTransformerEncoderLayer(nn.Module):
         # the attention weight (before softmax) for some padded element in query
         # will become -inf, which results in NaN in model parameters
         if attn_mask is not None:
-            attn_mask = attn_mask.masked_fill(attn_mask.to(torch.bool), -1e8)
+            attn_mask = attn_mask.masked_fill(attn_mask.to(torch.bool), -float('inf'))
         # whether to use macaron style
         if self.macaron_norm is not None:
...
@@ -315,7 +315,7 @@ class S2TTransformerS2EncoderLayer(nn.Module):
         # the attention weight (before softmax) for some padded element in query
         # will become -inf, which results in NaN in model parameters
         if attn_mask is not None:
-            attn_mask = attn_mask.masked_fill(attn_mask.to(torch.bool), -1e8)
+            attn_mask = attn_mask.masked_fill(attn_mask.to(torch.bool), -float('inf'))
         # whether to use macaron style
         if self.macaron_norm is not None:
...
@@ -188,7 +188,7 @@ class TransformerEncoderLayer(nn.Module):
         # will become -inf, which results in NaN in model parameters
         if attn_mask is not None:
             attn_mask = attn_mask.masked_fill(
-                attn_mask.to(torch.bool), -1e8 if x.dtype == torch.float32 else -1e4
+                attn_mask.to(torch.bool), -float('inf') # -1e8 if x.dtype == torch.float32 else -1e4
             )
         residual = x
...
@@ -241,7 +241,7 @@ class TransformerS2EncoderLayer(nn.Module):
         # will become -inf, which results in NaN in model parameters
         if attn_mask is not None:
             attn_mask = attn_mask.masked_fill(
-                attn_mask.to(torch.bool), -float('inf') # -1e8 if x.dtype == torch.float32 else -1e4
+                attn_mask.to(torch.bool), -float('inf') # -1e8 if x.dtype == torch.float32 else -1e4
             )
         residual = x
...
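In these encoder layers the mask is additive: the float attn_mask is added to the raw attention scores, so filling its True positions with -float('inf') zeroes them out after softmax. A self-contained sketch of that usage (assumed shapes, not from the commit):

import torch

attn_mask = torch.tensor([[0.0, 0.0, 1.0],
                          [0.0, 0.0, 1.0],
                          [0.0, 0.0, 1.0]])  # nonzero = disallowed position
attn_mask = attn_mask.masked_fill(attn_mask.to(torch.bool), -float("inf"))

scores = torch.randn(3, 3) + attn_mask       # additive mask on raw scores
attn = torch.softmax(scores, dim=-1)
print(attn[:, -1])                           # masked column is exactly zero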