Commit f1cf477d by xuchen

optimize the code

parent ad1caa72
gpu_num=1
cmd=""
while :
do
...
@@ -312,7 +312,8 @@ class RelativeMultiheadAttention(MultiheadAttention):
         return attn, attn_weights

-    def _generate_relative_positions_matrix(self, length, max_relative_length, incremental_state):
+    @staticmethod
+    def _generate_relative_positions_matrix(length, max_relative_length, incremental_state):
         if not incremental_state:
             # training process
             range_vec = torch.arange(length)
@@ -328,7 +329,8 @@ class RelativeMultiheadAttention(MultiheadAttention):
         return final_mat

-    def _relative_attention_inner(self, x, y, z, transpose=True):
+    @staticmethod
+    def _relative_attention_inner(x, y, z, transpose=True):
         """Relative position-aware dot-product attention inner calculation.
         This batches matrix multiply calculations to avoid unnecessary broadcasting.
@@ -337,7 +339,7 @@ class RelativeMultiheadAttention(MultiheadAttention):
           x: Tensor with shape [batch_size*heads, length, length or depth].
           y: Tensor with shape [batch_size*heads, length, depth].
           z: Tensor with shape [length, length, depth].
-          transpose: Whether to tranpose inner matrices of y and z. Should be true if
+          transpose: Whether to transpose inner matrices of y and z. Should be true if
             last dimension of x is depth, not length.
         Returns:
@@ -348,11 +350,12 @@ class RelativeMultiheadAttention(MultiheadAttention):
         """
         batch_size_mul_head = x.size()[0]
         length = z.size()[0]
-        # print(batch_size_mul_head, length)
         # xy_matmul is [batch_size*heads, length, length or depth]
         if transpose:
             y = y.transpose(1, 2)
         xy_matmul = torch.bmm(x, y)
         # x_t is [length, batch_size * heads, length or depth]
         x_t = x.transpose(0, 1)
         # x_tz_matmul is [length, batch_size * heads, length or depth]
...
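For reference only (not part of this commit): most of the body of _generate_relative_positions_matrix is collapsed in the diff above. The sketch below reconstructs it from the standard clipped relative-distance construction of Shaw et al. (2018), keeping the argument names used in the commit; the incremental-decoding branch and the exact tensor operations are assumptions, not taken from this repository.

import torch

def _generate_relative_positions_matrix(length, max_relative_length, incremental_state=None):
    # Pairwise relative distances, clipped to [-max_relative_length, max_relative_length]
    # and shifted to non-negative indices into an embedding table of size
    # 2 * max_relative_length + 1.
    if not incremental_state:
        # training / full-sequence case: distance (j - i) for every query i and key j
        range_vec = torch.arange(length)
        range_mat = range_vec.unsqueeze(0).expand(length, length)
        distance_mat = range_mat - range_mat.transpose(0, 1)
    else:
        # incremental decoding (assumed branch): only the newest query attends,
        # so a single row of distances to all earlier positions suffices
        distance_mat = torch.arange(-length + 1, 1).unsqueeze(0)
    distance_mat_clipped = torch.clamp(distance_mat, -max_relative_length, max_relative_length)
    final_mat = distance_mat_clipped + max_relative_length
    return final_mat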
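Similarly, the tail of _relative_attention_inner is collapsed after the removed debug print. Below is a minimal sketch of the batched computation the docstring describes, matching the shapes and comments visible in the diff; everything after the last visible comment, plus the example sizes, is an assumption rather than code from this commit.

import torch

def _relative_attention_inner(x, y, z, transpose=True):
    # x: [batch*heads, length, length or depth]
    # y: [batch*heads, length, depth]
    # z: [length, length, depth] (relative position embeddings)
    # First term: ordinary batched matmul between x and y.
    if transpose:
        y = y.transpose(1, 2)                      # [batch*heads, depth, length]
    xy_matmul = torch.bmm(x, y)                    # [batch*heads, length, length or depth]
    # Second term: fold the length dimension into the batch so the relative
    # embeddings z can be multiplied without broadcasting.
    x_t = x.transpose(0, 1)                        # [length, batch*heads, length or depth]
    z_t = z.transpose(1, 2) if transpose else z    # [length, depth, length] or [length, length, depth]
    x_tz_matmul = torch.bmm(x_t, z_t)              # [length, batch*heads, length or depth]
    x_tz_matmul_t = x_tz_matmul.transpose(0, 1)    # [batch*heads, length, length or depth]
    return xy_matmul + x_tz_matmul_t

# Tiny shape check with hypothetical sizes.
q = torch.randn(8, 5, 16)          # queries, [batch*heads, length, depth]
k = torch.randn(8, 5, 16)          # keys
rel_k = torch.randn(5, 5, 16)      # relative key embeddings
logits = _relative_attention_inner(q, k, rel_k, transpose=True)
print(logits.shape)                # torch.Size([8, 5, 5])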