Commit 07f0ab14 by xuchen

fix the device bug of the local attention

parent 8b97a50b
 gpu_num=1
-cmd=""
+cmd="sh train.sh"
 while :
 do
...
@@ -319,22 +319,21 @@ class LocalMultiheadAttention(nn.Module):
             x2 = torch.arange(-1, src_len - 1, 1).view(1, -1)
             dis = x2 - x1
             mask_diag = torch.abs(dis) > hard_mask_window
-            mask_diag = mask_diag.unsqueeze(0)
+            mask_diag = mask_diag.unsqueeze(0).to(attn_weights.device)
             attn_weights = attn_weights.masked_fill(mask_diag, float("-inf"))
         if self.multihead_gauss_mask_sigma is not None:
-            x1 = torch.arange(-1, src_len - 1, 1).view(-1, 1)
-            x2 = torch.arange(-1, src_len - 1, 1).view(1, -1)
-            diag_growing = -(x1 - x2) ** 2 / 2.0
-            e_diag_gauss_mask = diag_growing.unsqueeze(0).repeat(self.num_heads, 1, 1)
-            e_sigma_square = 1 / torch.square(self.multihead_gauss_mask_sigma)
-            e_diag_gauss_mask_final = e_diag_gauss_mask * e_sigma_square
-            e_diag_gauss_mask_final = torch.unsqueeze(e_diag_gauss_mask_final, 0)
+            x1 = torch.arange(-1, src_len - 1, 1).view(-1, 1).to(attn_weights.device)
+            x2 = torch.arange(-1, src_len - 1, 1).view(1, -1).to(attn_weights.device)
+            dis_square = -(x1 - x2) ** 2 / 2.0
+            multihead_dis_square = dis_square.unsqueeze(0).repeat(self.num_heads, 1, 1)
+            sigma_square = 1 / torch.square(self.multihead_gauss_mask_sigma)
+            gauss_bias = multihead_dis_square * sigma_square
+            gauss_bias = torch.unsqueeze(gauss_bias, 0)
             attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len)
             multihead_mask_weight = torch.sigmoid(self.multihead_mask_weight.unsqueeze(0))
-            attn_weights = (1 - multihead_mask_weight) * attn_weights + multihead_mask_weight * e_diag_gauss_mask_final
+            attn_weights = (1 - multihead_mask_weight) * attn_weights + multihead_mask_weight * gauss_bias
             attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)
         if attn_mask is not None:
...
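The substance of the fix is device placement: tensors created with torch.arange live on the CPU by default, so when attn_weights sits on a GPU, the masked_fill call and the Gaussian-bias blend raise a device-mismatch RuntimeError unless the mask and index tensors are first moved with .to(attn_weights.device). The sketch below reproduces that pattern outside the class as two hypothetical standalone helpers (hard_local_mask and gauss_local_bias, with assumed parameter shapes; they are not the repository's actual API) and assumes tgt_len == src_len as in encoder self-attention.

```python
import torch

def hard_local_mask(attn_weights, src_len, hard_mask_window):
    # attn_weights: (bsz * num_heads, src_len, src_len)
    x1 = torch.arange(-1, src_len - 1, 1).view(-1, 1)
    x2 = torch.arange(-1, src_len - 1, 1).view(1, -1)
    dis = x2 - x1
    # Move the boolean mask to attn_weights' device before masked_fill,
    # mirroring the .to(attn_weights.device) added in the commit.
    mask_diag = (torch.abs(dis) > hard_mask_window).unsqueeze(0).to(attn_weights.device)
    return attn_weights.masked_fill(mask_diag, float("-inf"))

def gauss_local_bias(attn_weights, bsz, num_heads, src_len, sigma, mask_weight):
    # sigma, mask_weight: (num_heads, 1, 1) per-head parameters (assumed shapes).
    x1 = torch.arange(-1, src_len - 1, 1).view(-1, 1).to(attn_weights.device)
    x2 = torch.arange(-1, src_len - 1, 1).view(1, -1).to(attn_weights.device)
    dis_square = -(x1 - x2) ** 2 / 2.0                              # (src_len, src_len)
    multihead_dis_square = dis_square.unsqueeze(0).repeat(num_heads, 1, 1)
    sigma_square = 1 / torch.square(sigma)
    gauss_bias = (multihead_dis_square * sigma_square).unsqueeze(0)  # (1, heads, S, S)
    attn = attn_weights.view(bsz, num_heads, src_len, src_len)
    gate = torch.sigmoid(mask_weight.unsqueeze(0))                   # per-head blend gate
    attn = (1 - gate) * attn + gate * gauss_bias
    return attn.view(bsz * num_heads, src_len, src_len)

if __name__ == "__main__":
    device = "cuda" if torch.cuda.is_available() else "cpu"
    bsz, heads, S = 2, 4, 10
    attn = torch.randn(bsz * heads, S, S, device=device)
    attn = hard_local_mask(attn, S, hard_mask_window=2)
    sigma = torch.ones(heads, 1, 1, device=device)
    gate_w = torch.zeros(heads, 1, 1, device=device)
    out = gauss_local_bias(attn, bsz, heads, S, sigma, gate_w)
    print(out.shape, out.device)   # runs on GPU or CPU without device-mismatch errors
```

With the masks kept on CPU (the pre-fix behavior), the same call on a CUDA tensor fails at masked_fill or at the elementwise blend; creating or moving them on attn_weights.device makes the path device-agnostic.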