Commit f1cf477d by xuchen

optimize the code

parent ad1caa72
@@ -312,7 +312,8 @@ class RelativeMultiheadAttention(MultiheadAttention):
         return attn, attn_weights
 
-    def _generate_relative_positions_matrix(self, length, max_relative_length, incremental_state):
+    @staticmethod
+    def _generate_relative_positions_matrix(length, max_relative_length, incremental_state):
         if not incremental_state:
             # training process
             range_vec = torch.arange(length)
@@ -328,7 +329,8 @@ class RelativeMultiheadAttention(MultiheadAttention):
         return final_mat
 
-    def _relative_attention_inner(self, x, y, z, transpose=True):
+    @staticmethod
+    def _relative_attention_inner(x, y, z, transpose=True):
         """Relative position-aware dot-product attention inner calculation.
 
         This batches matrix multiply calculations to avoid unnecessary broadcasting.
@@ -337,7 +339,7 @@ class RelativeMultiheadAttention(MultiheadAttention):
           x: Tensor with shape [batch_size*heads, length, length or depth].
           y: Tensor with shape [batch_size*heads, length, depth].
           z: Tensor with shape [length, length, depth].
-          transpose: Whether to tranpose inner matrices of y and z. Should be true if
+          transpose: Whether to transpose inner matrices of y and z. Should be true if
             last dimension of x is depth, not length.
 
         Returns:
@@ -348,11 +350,12 @@ class RelativeMultiheadAttention(MultiheadAttention):
         """
         batch_size_mul_head = x.size()[0]
         length = z.size()[0]
-        # print(batch_size_mul_head, length)
         # xy_matmul is [batch_size*heads, length, length or depth]
         if transpose:
             y = y.transpose(1, 2)
         xy_matmul = torch.bmm(x, y)
         # x_t is [length, batch_size * heads, length or depth]
         x_t = x.transpose(0, 1)
         # x_tz_matmul is [length, batch_size * heads, length or depth]
......
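The first hunk turns `_generate_relative_positions_matrix` into a `@staticmethod`. As a minimal sketch of what this helper computes, the following builds the clipped relative-distance matrix used to index relative position embeddings; the incremental-decoding branch and the clamp/offset details are assumptions based on common relative-attention implementations, not lines taken from this repository:

```python
import torch

def generate_relative_positions_matrix(length, max_relative_length, incremental_state=None):
    """Sketch only: pairwise relative distances, clipped and shifted to embedding indices."""
    if not incremental_state:
        # training / full-sequence case: distance[i, j] = j - i
        range_vec = torch.arange(length)
        range_mat = range_vec.unsqueeze(0).expand(length, length)
        distance_mat = range_mat - range_mat.transpose(0, 1)
    else:
        # incremental decoding (assumed): only the newest position attends to all previous ones
        distance_mat = torch.arange(-length + 1, 1).unsqueeze(0)
    # clip to [-max_relative_length, max_relative_length], then shift so indices start at 0
    distance_mat_clipped = torch.clamp(distance_mat, -max_relative_length, max_relative_length)
    final_mat = distance_mat_clipped + max_relative_length
    return final_mat

# length=4, max_relative_length=2 -> 4x4 matrix of indices in [0, 2 * max_relative_length]
print(generate_relative_positions_matrix(4, 2))
```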
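The remaining hunks touch `_relative_attention_inner`, whose docstring spells out the tensor shapes. Below is a minimal sketch consistent with those shapes, assuming the usual pattern of one `bmm` against `y` plus one `bmm` against the relative embeddings `z`; the exact transpose order is an assumption, not code copied from this commit:

```python
import torch

def relative_attention_inner(x, y, z, transpose=True):
    """Sketch only. Shapes follow the docstring:
    x: [batch*heads, length, depth or length]
    y: [batch*heads, length, depth]
    z: [length, length, depth]  (relative position embeddings)
    """
    if transpose:
        # logits path: contract over the depth dimension of y and z
        y = y.transpose(1, 2)  # [batch*heads, depth, length]
        z = z.transpose(1, 2)  # [length, depth, length]
    # xy_matmul is [batch*heads, length, length or depth]
    xy_matmul = torch.bmm(x, y)
    # x_t is [length, batch*heads, length or depth]
    x_t = x.transpose(0, 1)
    # x_tz_matmul is [length, batch*heads, length or depth]
    x_tz_matmul = torch.bmm(x_t, z)
    # back to [batch*heads, length, length or depth], then combine both terms
    return xy_matmul + x_tz_matmul.transpose(0, 1)

# toy check of the logits path: batch*heads=8, length=5, depth=16
q = torch.randn(8, 5, 16)
k = torch.randn(8, 5, 16)
rel_k = torch.randn(5, 5, 16)
print(relative_attention_inner(q, k, rel_k, transpose=True).shape)  # torch.Size([8, 5, 5])
```

In this sketch, `transpose=True` yields attention logits of shape [batch*heads, length, length], while `transpose=False` would apply attention weights to values plus relative value embeddings, matching the "Should be true if last dimension of x is depth" note in the docstring.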