Commit 8645e75b by xuchen

Fix a bug in the RPE (relative positional encoding) (TODO)

parent 8b47f344
@@ -261,7 +261,7 @@ class RelPositionMultiheadAttention(MultiheadAttention):
         )
         # (batch * head, time1, d_k)
         q_with_bias_v = (
-            (q + self.pos_bias_v).transpose(0, 1)
+            (q + self.pos_bias_v).transpose(1, 2)
             .contiguous()
             .view(bsz * self.num_heads, tgt_len, self.head_dim)
         )
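For context, a minimal sketch (not the repository's code) of why transpose(1, 2) is the right call here. It assumes q has the layout (bsz, tgt_len, num_heads, head_dim) at this point, which the "(batch * head, time1, d_k)" comment and the following view() imply; all tensor sizes are made up for illustration:

import torch

# Illustration only; shapes are assumed, not taken from this repository.
bsz, num_heads, tgt_len, head_dim = 2, 4, 5, 8
q = torch.randn(bsz, tgt_len, num_heads, head_dim)  # assumed layout at this point
pos_bias_v = torch.randn(num_heads, head_dim)       # broadcasts over (bsz, tgt_len)

q_with_bias_v = (
    (q + pos_bias_v).transpose(1, 2)                 # -> (bsz, num_heads, tgt_len, head_dim)
    .contiguous()
    .view(bsz * num_heads, tgt_len, head_dim)        # -> (batch * head, time1, d_k)
)
print(q_with_bias_v.shape)  # torch.Size([8, 5, 8])

# With the old transpose(0, 1), the tensor would be (tgt_len, bsz, num_heads, head_dim),
# so the same view() would interleave time steps and batch entries, and the per-head
# rows of the flattened tensor would no longer line up with the keys.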
@@ -280,8 +280,7 @@ class RelPositionMultiheadAttention(MultiheadAttention):
             """Compute relative positional encoding.
             Args:
-                x (torch.Tensor): Input tensor (batch, head, time1, 2*time1-1).
-                    time1 means the length of query vector.
+                x (torch.Tensor): Input tensor (batch * head, time1, time2).
                 zero_triu (bool): If true, return the lower triangular part of
                     the matrix.
             Returns:
@@ -302,7 +301,7 @@ class RelPositionMultiheadAttention(MultiheadAttention):
                 x = x * torch.tril(ones, x.size(2) - x.size(1))[None, :, :]
             return x
-        matrix_bd = rel_shift(matrix_bd)
+        # matrix_bd = rel_shift(matrix_bd)
         attn_weights = (matrix_ac + matrix_bd) * self.scaling
...
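The rel_shift call is commented out with a TODO, so the shift presumably still has to be adapted to the new (batch * head, time1, time2) layout. For reference, below is a sketch of one common way to implement the shift, the pad-and-reshape trick used in Transformer-XL-style implementations, assuming time2 == 2 * time1 - 1 and that column j encodes relative position j - (time1 - 1). The function name rel_shift_3d, the shapes, and that convention are assumptions for illustration, not this repository's implementation:

import torch

def rel_shift_3d(x):
    # x: (batch * head, time1, 2 * time1 - 1); column j holds the score for
    # relative position j - (time1 - 1). After the shift, entry (i, j) holds
    # the score for relative position j - i, i.e. key j relative to query i.
    b, t1, t2 = x.size()
    zero_pad = x.new_zeros(b, t1, 1)
    x_padded = torch.cat([zero_pad, x], dim=-1)  # (b, t1, t2 + 1)
    x_padded = x_padded.view(b, t2 + 1, t1)      # reinterpret the flat buffer
    x = x_padded[:, 1:].reshape(b, t1, t2)       # drop the pad column
    return x[:, :, :t1]                          # keep the (time1 x time1) block

# Quick check with distinguishable values.
t1 = 3
scores = torch.arange(1, 2 * t1).float().repeat(1, t1, 1)  # (1, 3, 5), each row [1..5]
print(rel_shift_3d(scores))
# Row i should read the scores at relative positions -i, 1-i, 2-i:
# tensor([[[3., 4., 5.],
#          [2., 3., 4.],
#          [1., 2., 3.]]])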