Commit 8645e75b by xuchen

fix the bug in the RPE (TODO)

parent 8b47f344
@@ -261,7 +261,7 @@ class RelPositionMultiheadAttention(MultiheadAttention):
         )
         # (batch * head, time1, d_k)
         q_with_bias_v = (
-            (q + self.pos_bias_v).transpose(0, 1)
+            (q + self.pos_bias_v).transpose(1, 2)
             .contiguous()
             .view(bsz * self.num_heads, tgt_len, self.head_dim)
         )
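
For context (not part of the commit): a minimal shape sketch of why the transpose changes. It assumes q holds (batch, time1, head, d_k) at this point, which the (batch * head, time1, d_k) comment and the subsequent view suggest; the standalone tensor names here are illustrative.

```python
import torch

bsz, num_heads, tgt_len, head_dim = 2, 4, 5, 8

# Assumed layout of q at this point in forward(): (batch, time1, head, d_k).
q = torch.randn(bsz, tgt_len, num_heads, head_dim)
pos_bias_v = torch.randn(num_heads, head_dim)  # broadcasts over batch and time

# Old code: transpose(0, 1) yields (time1, batch, head, d_k); the later
# .view(bsz * num_heads, tgt_len, head_dim) still succeeds (same numel),
# but its rows no longer correspond to (batch, head) pairs.
bad = (q + pos_bias_v).transpose(0, 1).contiguous()
scrambled = bad.view(bsz * num_heads, tgt_len, head_dim)

# Fixed code: transpose(1, 2) yields (batch, head, time1, d_k), which the
# view flattens cleanly into (batch * head, time1, d_k), as the comment says.
good = (
    (q + pos_bias_v)
    .transpose(1, 2)
    .contiguous()
    .view(bsz * num_heads, tgt_len, head_dim)
)
print(good.shape)  # torch.Size([8, 5, 8])
```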
@@ -280,8 +280,7 @@ class RelPositionMultiheadAttention(MultiheadAttention):
         """Compute relative positional encoding.
         Args:
-            x (torch.Tensor): Input tensor (batch, head, time1, 2*time1-1).
-                time1 means the length of query vector.
+            x (torch.Tensor): Input tensor (batch * head, time1, time2).
             zero_triu (bool): If true, return the lower triangular part of
                 the matrix.
         Returns:
@@ -302,7 +301,7 @@ class RelPositionMultiheadAttention(MultiheadAttention):
             x = x * torch.tril(ones, x.size(2) - x.size(1))[None, :, :]
             return x
-        matrix_bd = rel_shift(matrix_bd)
+        # matrix_bd = rel_shift(matrix_bd)
         attn_weights = (matrix_ac + matrix_bd) * self.scaling
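
The rel_shift call is disabled here pending the fix. For reference, a self-contained sketch of the pad-reshape-slice trick that a Transformer-XL style rel_shift performs on the (batch * head, time1, time2) layout from the updated docstring; the zero_triu branch mirrors the tril line visible in the hunk above. This is an illustration under those assumptions, not the repository's exact implementation.

```python
import torch

def rel_shift(x, zero_triu=False):
    # x: (batch * head, time1, time2). Pad one zero column, reinterpret the
    # flat buffer, and drop the first row so that column j of row i comes to
    # index relative position (j - i).
    zero_pad = torch.zeros((x.size(0), x.size(1), 1), dtype=x.dtype, device=x.device)
    x_padded = torch.cat([zero_pad, x], dim=-1)            # (B*H, time1, time2 + 1)
    x_padded = x_padded.view(x.size(0), x.size(2) + 1, x.size(1))
    x = x_padded[:, 1:].view_as(x)                         # shifted scores
    if zero_triu:
        # Zero out positions above the (offset) diagonal, as in the hunk above.
        ones = torch.ones((x.size(1), x.size(2)), device=x.device)
        x = x * torch.tril(ones, x.size(2) - x.size(1))[None, :, :]
    return x

scores = torch.randn(8, 5, 5)  # e.g. bsz * num_heads = 8, time1 = time2 = 5
print(rel_shift(scores, zero_triu=True).shape)  # torch.Size([8, 5, 5])
```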