Commit c4748e46 by xuchen

Optimize the relative-attention computation for faster execution (hoist the `.to(device)` transfer into `_generate_relative_positions_matrix` so the positions matrix is created on the target device once, instead of moving it before every `F.embedding` call)

parent 00945388
...@@ -238,14 +238,14 @@ class RelativeMultiheadAttention(MultiheadAttention): ...@@ -238,14 +238,14 @@ class RelativeMultiheadAttention(MultiheadAttention):
) )
relative_positions_matrix = self._generate_relative_positions_matrix( relative_positions_matrix = self._generate_relative_positions_matrix(
src_len, self.max_relative_length, incremental_state src_len, self.max_relative_length, k.device, incremental_state
) )
if self.k_only: if self.k_only:
relation_keys = F.embedding(relative_positions_matrix.long().to(k.device), self.relative_position_keys) relation_keys = F.embedding(relative_positions_matrix.long(), self.relative_position_keys)
else: else:
relation_keys = F.embedding(relative_positions_matrix.long().to(k.device), self.relative_position_keys) relation_keys = F.embedding(relative_positions_matrix.long(), self.relative_position_keys)
relation_values = F.embedding(relative_positions_matrix.long().to(k.device), self.relative_position_values) relation_values = F.embedding(relative_positions_matrix.long(), self.relative_position_values)
attn_weights = self._relative_attention_inner(q, k, relation_keys, transpose=True) attn_weights = self._relative_attention_inner(q, k, relation_keys, transpose=True)
assert list(attn_weights.size()) == [bsz * self.num_heads, tgt_len, src_len] assert list(attn_weights.size()) == [bsz * self.num_heads, tgt_len, src_len]
...@@ -313,14 +313,14 @@ class RelativeMultiheadAttention(MultiheadAttention): ...@@ -313,14 +313,14 @@ class RelativeMultiheadAttention(MultiheadAttention):
return attn, attn_weights return attn, attn_weights
@staticmethod @staticmethod
def _generate_relative_positions_matrix(length, max_relative_length, incremental_state): def _generate_relative_positions_matrix(length, max_relative_length, device, incremental_state):
if not incremental_state: if not incremental_state:
# training process # training process
range_vec = torch.arange(length) range_vec = torch.arange(length).to(device)
range_mat = range_vec.repeat(length, 1) range_mat = range_vec.repeat(length, 1)
distance_mat = range_mat - range_mat.transpose(0, 1) distance_mat = range_mat - range_mat.transpose(0, 1)
else: else:
distance_mat = torch.arange(-length + 1, 1).view(1, -1) distance_mat = torch.arange(-length + 1, 1).view(1, -1).to(device)
distance_mat_clipped = torch.clamp(distance_mat, -max_relative_length, max_relative_length) distance_mat_clipped = torch.clamp(distance_mat, -max_relative_length, max_relative_length)
......
Markdown 格式
0%
您添加了 0 人到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论