Commit 07f0ab14 by xuchen

fix the device bug in the local attention

parent 8b97a50b
gpu_num=1
cmd=""
cmd="sh train.sh"
while :
do
......
......
@@ -319,22 +319,21 @@ class LocalMultiheadAttention(nn.Module):
             x2 = torch.arange(-1, src_len - 1, 1).view(1, -1)
             dis = x2 - x1
             mask_diag = torch.abs(dis) > hard_mask_window
-            mask_diag = mask_diag.unsqueeze(0)
+            mask_diag = mask_diag.unsqueeze(0).to(attn_weights.device)
             attn_weights = attn_weights.masked_fill(mask_diag, float("-inf"))
         if self.multihead_gauss_mask_sigma is not None:
-            x1 = torch.arange(-1, src_len - 1, 1).view(-1, 1)
-            x2 = torch.arange(-1, src_len - 1, 1).view(1, -1)
-            diag_growing = -(x1 - x2) ** 2 / 2.0
-            e_diag_gauss_mask = diag_growing.unsqueeze(0).repeat(self.num_heads, 1, 1)
-            e_sigma_square = 1 / torch.square(self.multihead_gauss_mask_sigma)
-            e_diag_gauss_mask_final = e_diag_gauss_mask * e_sigma_square
-            e_diag_gauss_mask_final = torch.unsqueeze(e_diag_gauss_mask_final, 0)
+            x1 = torch.arange(-1, src_len - 1, 1).view(-1, 1).to(attn_weights.device)
+            x2 = torch.arange(-1, src_len - 1, 1).view(1, -1).to(attn_weights.device)
+            dis_square = -(x1 - x2) ** 2 / 2.0
+            multihead_dis_square = dis_square.unsqueeze(0).repeat(self.num_heads, 1, 1)
+            sigma_square = 1 / torch.square(self.multihead_gauss_mask_sigma)
+            gauss_bias = multihead_dis_square * sigma_square
+            gauss_bias = torch.unsqueeze(gauss_bias, 0)
             attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len)
             multihead_mask_weight = torch.sigmoid(self.multihead_mask_weight.unsqueeze(0))
-            attn_weights = (1 - multihead_mask_weight) * attn_weights + multihead_mask_weight * e_diag_gauss_mask_final
+            attn_weights = (1 - multihead_mask_weight) * attn_weights + multihead_mask_weight * gauss_bias
             attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)
         if attn_mask is not None:
......
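Context for the fix: torch.arange builds its result on the CPU by default, so the index grids created here end up on a different device than attn_weights whenever the model runs on a GPU, and masked_fill / the gated sum then fail with a device-mismatch error. The standalone sketch below illustrates the corrected pattern with the index tensors moved to attn_weights.device; the function name local_attention_bias, the sigma argument, and the shapes are illustrative assumptions, not the module's actual API.

import torch

def local_attention_bias(attn_weights, num_heads, sigma, hard_mask_window=4):
    """Illustrative sketch only: hard local mask plus per-head Gaussian bias.

    attn_weights: (bsz * num_heads, seq_len, seq_len) self-attention logits,
    possibly on CUDA. sigma: (num_heads, 1, 1) per-head Gaussian widths.
    """
    src_len = attn_weights.size(-1)
    device = attn_weights.device  # build index grids on the same device as the logits

    # torch.arange defaults to CPU; without .to(device), the masked_fill below
    # raises a device-mismatch error as soon as attn_weights lives on a GPU.
    x1 = torch.arange(-1, src_len - 1, 1).view(-1, 1).to(device)
    x2 = torch.arange(-1, src_len - 1, 1).view(1, -1).to(device)

    # Hard window: positions farther than hard_mask_window from the diagonal get -inf.
    mask_diag = (torch.abs(x2 - x1) > hard_mask_window).unsqueeze(0)
    attn_weights = attn_weights.masked_fill(mask_diag, float("-inf"))

    # Per-head Gaussian bias centred on the diagonal: -(i - j)^2 / (2 * sigma_h^2).
    dis_square = -(x1 - x2) ** 2 / 2.0
    gauss_bias = dis_square.unsqueeze(0).repeat(num_heads, 1, 1) / torch.square(sigma)
    return attn_weights, gauss_bias

# Usage: the logits live on whatever device is available; the masks follow automatically.
device = "cuda" if torch.cuda.is_available() else "cpu"
weights = torch.randn(2 * 4, 10, 10, device=device)   # bsz=2, num_heads=4, seq_len=10
sigma = torch.full((4, 1, 1), 3.0, device=device)
weights, bias = local_attention_bias(weights, num_heads=4, sigma=sigma)

Note that multihead_gauss_mask_sigma and multihead_mask_weight are presumably registered as module parameters and therefore already follow the model's device; only the ad-hoc torch.arange tensors needed the explicit .to(attn_weights.device).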