Commit c4748e46 by xuchen

Optimize relative attention for faster computation

parent 00945388
@@ -238,14 +238,14 @@ class RelativeMultiheadAttention(MultiheadAttention):
         )
         relative_positions_matrix = self._generate_relative_positions_matrix(
-            src_len, self.max_relative_length, incremental_state
+            src_len, self.max_relative_length, k.device, incremental_state
         )
         if self.k_only:
-            relation_keys = F.embedding(relative_positions_matrix.long().to(k.device), self.relative_position_keys)
+            relation_keys = F.embedding(relative_positions_matrix.long(), self.relative_position_keys)
         else:
-            relation_keys = F.embedding(relative_positions_matrix.long().to(k.device), self.relative_position_keys)
-            relation_values = F.embedding(relative_positions_matrix.long().to(k.device), self.relative_position_values)
+            relation_keys = F.embedding(relative_positions_matrix.long(), self.relative_position_keys)
+            relation_values = F.embedding(relative_positions_matrix.long(), self.relative_position_values)
         attn_weights = self._relative_attention_inner(q, k, relation_keys, transpose=True)
         assert list(attn_weights.size()) == [bsz * self.num_heads, tgt_len, src_len]
@@ -313,14 +313,14 @@ class RelativeMultiheadAttention(MultiheadAttention):
         return attn, attn_weights
     @staticmethod
-    def _generate_relative_positions_matrix(length, max_relative_length, incremental_state):
+    def _generate_relative_positions_matrix(length, max_relative_length, device, incremental_state):
         if not incremental_state:
             # training process
-            range_vec = torch.arange(length)
+            range_vec = torch.arange(length).to(device)
             range_mat = range_vec.repeat(length, 1)
             distance_mat = range_mat - range_mat.transpose(0, 1)
         else:
-            distance_mat = torch.arange(-length + 1, 1).view(1, -1)
+            distance_mat = torch.arange(-length + 1, 1).view(1, -1).to(device)
         distance_mat_clipped = torch.clamp(distance_mat, -max_relative_length, max_relative_length)
......
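For reference, below is a minimal, self-contained sketch of the optimized helper and its call site, reconstructed from the hunks above. It is not the full module: `k` and `relative_position_keys` are stand-ins for the key tensor and `self.relative_position_keys`, the final shift of the clipped distances into the range [0, 2 * max_relative_length] is assumed from the standard relative-attention formulation (that part of the method lies outside the shown hunk), and `torch.arange(..., device=device)` is used as an idiomatic equivalent of the patch's `.to(device)`.

```python
import torch
import torch.nn.functional as F


def generate_relative_positions_matrix(length, max_relative_length, device, incremental_state=None):
    """Build the clipped relative-position index matrix directly on `device`."""
    if not incremental_state:
        # Training / full-sequence case: (length, length) matrix of pairwise distances.
        range_vec = torch.arange(length, device=device)  # allocated on the target device
        range_mat = range_vec.repeat(length, 1)
        distance_mat = range_mat - range_mat.transpose(0, 1)
    else:
        # Incremental decoding: only distances to the newest position are needed.
        distance_mat = torch.arange(-length + 1, 1, device=device).view(1, -1)
    distance_mat_clipped = torch.clamp(distance_mat, -max_relative_length, max_relative_length)
    # Shift into [0, 2 * max_relative_length] so the matrix can index an embedding table
    # (assumed here; this step is outside the truncated hunk above).
    return distance_mat_clipped + max_relative_length


if __name__ == "__main__":
    src_len, max_rel, head_dim = 5, 2, 8
    k = torch.randn(src_len, head_dim)  # stand-in for the key tensor; a CUDA tensor when running on GPU
    relative_position_keys = torch.randn(2 * max_rel + 1, head_dim)  # stand-in for self.relative_position_keys

    # Indices are created on k.device, so the lookup needs no extra .to(k.device) copy.
    matrix = generate_relative_positions_matrix(src_len, max_rel, k.device)
    relation_keys = F.embedding(matrix.long(), relative_position_keys)
    print(relation_keys.shape)  # torch.Size([5, 5, 8])
```

The speedup comes from where the index matrix lives: before the patch it was built on the CPU and copied to the key tensor's device inside each `F.embedding` call (twice per forward pass when `k_only` is false); after the patch it is placed on `k.device` once, when it is generated.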