Commit dd402ec2 by xuchen

modify the implementation of the relative position encoding

parent 306dd6fc
@@ -319,6 +319,7 @@ class S2TTransformerEncoder(FairseqEncoder):
         else:
             self.layer_norm = None
+        self.attn_type = getattr(args, "encoder_attention_type", "selfattn")
         self.use_ctc = ("ctc" in getattr(args, "criterion", False)) and \
             (getattr(args, "ctc_weight", False) > 0)
         if self.use_ctc:
@@ -344,11 +345,13 @@ class S2TTransformerEncoder(FairseqEncoder):
         encoder_padding_mask = lengths_to_padding_mask(input_lengths)
         positions = self.embed_positions(encoder_padding_mask).transpose(0, 1)
-        x += positions
+        if self.attn_type != "rel_selfattn":
+            x += positions
+        # x += positions
         x = self.dropout_module(x)
         for layer in self.transformer_layers:
-            x = layer(x, encoder_padding_mask)
+            x = layer(x, encoder_padding_mask, pos_emb=positions)
         if self.layer_norm is not None:
             x = self.layer_norm(x)
...
@@ -685,6 +685,7 @@ class TransformerDecoder(FairseqIncrementalDecoder):
             ]
         )
         self.num_layers = len(self.layers)
+        self.attn_type = getattr(args, "decoder_attention_type", "selfattn")
         if args.decoder_normalize_before and not getattr(
             args, "no_decoder_final_norm", False
@@ -892,6 +893,7 @@ class TransformerDecoder(FairseqIncrementalDecoder):
                 self_attn_padding_mask=self_attn_padding_mask,
                 need_attn=bool((idx == alignment_layer)),
                 need_head_weights=bool((idx == alignment_layer)),
+                pos_emb=positions
             )
             inner_states.append(x)
             if layer_attn is not None and idx == alignment_layer:
...
@@ -55,7 +55,7 @@ class RelPositionMultiheadAttention(MultiheadAttention):
         # linear transformation for positional encoding
         self.linear_pos = quant_noise(
-            nn.Linear(embed_dim, embed_dim, bias=False), q_noise, qn_block_size
+            nn.Linear(embed_dim, embed_dim, bias=bias), q_noise, qn_block_size
         )
@@ -63,7 +63,7 @@ class RelPositionMultiheadAttention(MultiheadAttention):
         # these two learnable bias are used in matrix c and matrix d
         self.pos_bias_u = Parameter(torch.Tensor(self.num_heads, self.head_dim))
         self.pos_bias_v = Parameter(torch.Tensor(self.num_heads, self.head_dim))
-        nn.init.xavier_uniform_(self.linear_pos.weight)
+        # nn.init.xavier_uniform_(self.linear_pos.weight)
         nn.init.xavier_normal_(self.pos_bias_u)
         nn.init.xavier_normal_(self.pos_bias_v)
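
For reference, pos_bias_u and pos_bias_v are the u and v vectors in the relative attention score decomposition of Transformer-XL (https://arxiv.org/abs/1901.02860, Section 3.3), which the comments in this file refer to as matrices (a)-(d):

    A^{rel}_{i,j} = \underbrace{E_{x_i}^\top W_q^\top W_{k,E} E_{x_j}}_{(a)}
                  + \underbrace{E_{x_i}^\top W_q^\top W_{k,R} R_{i-j}}_{(b)}
                  + \underbrace{u^\top W_{k,E} E_{x_j}}_{(c)}
                  + \underbrace{v^\top W_{k,R} R_{i-j}}_{(d)}

In the hunks below, matrix_ac combines terms (a) and (c) by adding pos_bias_u to the query before multiplying with the content keys, and matrix_bd combines terms (b) and (d) by adding pos_bias_v before multiplying with the projected position embeddings; linear_pos plays the role of W_{k,R}.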
@@ -109,6 +109,7 @@ class RelPositionMultiheadAttention(MultiheadAttention):
         assert list(query.size()) == [tgt_len, bsz, embed_dim]
         if (
+            False and
             not self.onnx_trace
             and not is_tpu  # don't use PyTorch version on TPUs
             and incremental_state is None
@@ -196,6 +197,8 @@ class RelPositionMultiheadAttention(MultiheadAttention):
        #     .view(tgt_len, bsz * self.num_heads, self.head_dim)
        #     .transpose(0, 1)
        # )
+        # prepare q for RPE  # (tgt_len, bsz, num_heads, head_dim)
+        q = q.contiguous().view(tgt_len, bsz, self.num_heads, self.head_dim)
         if k is not None:
             k = (
                 k.contiguous()
@@ -279,18 +282,19 @@ class RelPositionMultiheadAttention(MultiheadAttention):
                     dim=1,
                 )
+        pos_emb = pos_emb.transpose(0, 1)
         p_rep = self.linear_pos(pos_emb).view(bsz, -1, self.num_heads, self.head_dim)
-        p_rep = p_rep.contiguous().transpose(1, 2).view(bsz * self.num_heads, -1, self.head_dim)
+        p_rep = p_rep.transpose(1, 2).contiguous().view(bsz * self.num_heads, -1, self.head_dim)
         # (batch * head, time1, d_k)
         q_with_bias_u = (
-            (q + self.pos_bias_u) .contiguous()
+            (q + self.pos_bias_u).contiguous()
             .view(tgt_len, bsz * self.num_heads, self.head_dim)
             .transpose(0, 1)
         )
         # (batch * head, time1, d_k)
         q_with_bias_v = (
-            (q + self.pos_bias_v) .contiguous()
+            (q + self.pos_bias_v).contiguous()
             .view(tgt_len, bsz * self.num_heads, self.head_dim)
             .transpose(0, 1)
         )
@@ -298,15 +302,15 @@ class RelPositionMultiheadAttention(MultiheadAttention):
         # compute attention score
         # first compute matrix a and matrix c
         # as described in https://arxiv.org/abs/1901.02860 Section 3.3
-        # (batch, head, time1, time2)
+        # (batch * head, time1, time2)
         matrix_ac = torch.bmm(q_with_bias_u, k.transpose(1, 2))
         # compute matrix b and matrix d
-        # (batch, head, time1, time2)
+        # (batch * head, time1, time2)
         matrix_bd = torch.bmm(q_with_bias_v, p_rep.transpose(1, 2))
         def rel_shift(x, zero_triu=False):
-            """Compute relative positinal encoding.
+            """Compute relative positional encoding.
             :param torch.Tensor x: (batch, time, size)
             :param bool zero_triu: return the lower triangular part of the matrix
@@ -323,8 +327,11 @@ class RelPositionMultiheadAttention(MultiheadAttention):
             return x
-        matrix_bd = rel_shift(matrix_bd)
-        attn_weights = (matrix_ac + matrix_bd) / self.scaling
+        matrix_bd = matrix_bd.contiguous().view(bsz, self.num_heads, matrix_bd.size(-2), matrix_bd.size(-1))
+        matrix_bd = rel_shift(
+            matrix_bd,
+        ).contiguous().view(bsz * self.num_heads, matrix_bd.size(-2), matrix_bd.size(-1))
+        attn_weights = (matrix_ac + matrix_bd) * self.scaling
         # attn_weights = torch.bmm(q, k.transpose(1, 2))
         attn_weights = self.apply_sparse_mask(attn_weights, tgt_len, src_len, bsz)
...
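
The body of rel_shift is collapsed in the view above. For orientation only, the zero-pad-and-reshape trick that implementations of this step typically use (Transformer-XL / ESPnet style) looks like the sketch below; it operates on the (batch, head, time1, time2) view that matrix_bd is reshaped into before the call, and the committed code may differ in detail. Note also that, assuming self.scaling is inherited unchanged from fairseq's MultiheadAttention (head_dim ** -0.5), switching from division by self.scaling to multiplication applies the usual 1/sqrt(d_k) factor to the combined score instead of its inverse.

import torch

def rel_shift_sketch(x: torch.Tensor, zero_triu: bool = False) -> torch.Tensor:
    # x: (batch, head, time1, time2) scores of queries against relative positions
    b, h, t1, t2 = x.size()
    zero_pad = x.new_zeros(b, h, t1, 1)
    # pad one zero column, then reinterpret the last two dims so that row i
    # ends up shifted left by (t1 - 1 - i): the Transformer-XL "shift trick".
    # Entries shifted in from beyond a row are not meaningful; a mask such as
    # zero_triu (or a long enough relative-position axis) takes care of them.
    x_padded = torch.cat([zero_pad, x], dim=-1)        # (b, h, t1, t2 + 1)
    x_padded = x_padded.view(b, h, t2 + 1, t1)
    x = x_padded[:, :, 1:].view_as(x)                  # (b, h, t1, t2)
    if zero_triu:
        # optionally keep only the causal (lower-triangular) part
        ones = torch.ones((t1, t2), device=x.device, dtype=x.dtype)
        x = x * torch.tril(ones, t2 - t1)[None, None, :, :]
    return x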
@@ -35,6 +35,7 @@ class TransformerEncoderLayer(nn.Module):
         self.embed_dim = args.encoder_embed_dim
         self.quant_noise = getattr(args, 'quant_noise_pq', 0)
         self.quant_noise_block_size = getattr(args, 'quant_noise_pq_block_size', 8) or 8
+        self.attn_type = getattr(args, "encoder_attention_type", "selfattn")
         self.self_attn = self.build_self_attention(self.embed_dim, args)
         self.self_attn_layer_norm = LayerNorm(self.embed_dim)
         self.dropout_module = FairseqDropout(
@@ -77,13 +78,12 @@ class TransformerEncoderLayer(nn.Module):
         )
     def build_self_attention(self, embed_dim, args):
-        attn_type = getattr(args, "encoder_attention_type", "selfattn")
-        if attn_type == "selfattn":
+        if self.attn_type == "selfattn":
             attn_func = MultiheadAttention
-        elif attn_type == "rel_selfattn":
+        elif self.attn_type == "rel_selfattn":
             attn_func = RelPositionMultiheadAttention
         else:
-            print("The attention type %s is not supported!" % attn_type)
+            print("The attention type %s is not supported!" % self.attn_type)
             exit(1)
         return attn_func(
@@ -112,7 +112,10 @@ class TransformerEncoderLayer(nn.Module):
                     state_dict["{}.{}.{}".format(name, new, m)] = state_dict[k]
                     del state_dict[k]
-    def forward(self, x, encoder_padding_mask: Optional[Tensor], attn_mask: Optional[Tensor] = None):
+    def forward(self, x,
+                encoder_padding_mask: Optional[Tensor],
+                attn_mask: Optional[Tensor] = None,
+                pos_emb: Optional[Tensor] = None):
         """
         Args:
             x (Tensor): input to the layer of shape `(seq_len, batch, embed_dim)`
@@ -124,6 +127,7 @@ class TransformerEncoderLayer(nn.Module):
                 `attn_mask[tgt_i, src_j] = 1` means that when calculating the
                 embedding for `tgt_i`, we exclude (mask out) `src_j`. This is
                 useful for strided self-attention.
+            positions (Tensor): the position embedding for relative position encoding
         Returns:
             encoded output of shape `(seq_len, batch, embed_dim)`
@@ -139,6 +143,18 @@ class TransformerEncoderLayer(nn.Module):
         residual = x
         if self.normalize_before:
             x = self.self_attn_layer_norm(x)
+        if self.attn_type == "rel_selfattn":
+            assert pos_emb is not None, "Positions is necessary for RPE!"
+            x, _ = self.self_attn(
+                query=x,
+                key=x,
+                value=x,
+                key_padding_mask=encoder_padding_mask,
+                need_weights=False,
+                attn_mask=attn_mask,
+                pos_emb=pos_emb
+            )
+        else:
             x, _ = self.self_attn(
                 query=x,
                 key=x,
@@ -195,6 +211,7 @@ class TransformerDecoderLayer(nn.Module):
         self.cross_self_attention = getattr(args, "cross_self_attention", False)
+        self.attn_type = getattr(args, "decoder_attention_type", "selfattn")
         self.self_attn = self.build_self_attention(
             self.embed_dim,
             args,
@@ -256,13 +273,12 @@ class TransformerDecoderLayer(nn.Module):
     def build_self_attention(
         self, embed_dim, args, add_bias_kv=False, add_zero_attn=False
     ):
-        attn_type = getattr(args, "decoder_attention_type", "selfattn")
-        if attn_type == "selfattn":
+        if self.attn_type == "selfattn":
             attn_func = MultiheadAttention
-        elif attn_type == "rel_selfattn":
+        elif self.attn_type == "rel_selfattn":
             attn_func = RelPositionMultiheadAttention
         else:
-            print("The attention type %s is not supported!" % attn_type)
+            print("The attention type %s is not supported!" % self.attn_type)
             exit(1)
         return attn_func(
@@ -277,16 +293,7 @@ class TransformerDecoderLayer(nn.Module):
         )
     def build_encoder_attention(self, embed_dim, args):
-        attn_type = getattr(args, "decoder_attention_type", "selfattn")
-        if attn_type == "selfattn":
-            attn_func = MultiheadAttention
-        elif attn_type == "rel_selfattn":
-            attn_func = RelPositionMultiheadAttention
-        else:
-            print("The attention type %s is not supported!" % attn_type)
-            exit(1)
-        return attn_func(
+        return MultiheadAttention(
             embed_dim,
             args.decoder_attention_heads,
             kdim=getattr(args, "encoder_embed_dim", None),
@@ -315,6 +322,7 @@ class TransformerDecoderLayer(nn.Module):
         self_attn_padding_mask: Optional[torch.Tensor] = None,
         need_attn: bool = False,
         need_head_weights: bool = False,
+        pos_emb: Optional[Tensor] = None,
     ):
         """
         Args:
@@ -370,6 +378,19 @@ class TransformerDecoderLayer(nn.Module):
         else:
             y = x
+        if self.attn_type == "rel_selfattn":
+            assert pos_emb is not None, "Positions is necessary for RPE!"
+            x, attn = self.self_attn(
+                query=x,
+                key=y,
+                value=y,
+                key_padding_mask=self_attn_padding_mask,
+                incremental_state=incremental_state,
+                need_weights=False,
+                attn_mask=self_attn_mask,
+                pos_emb=pos_emb
+            )
+        else:
             x, attn = self.self_attn(
                 query=x,
                 key=y,
...
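
As a quick sanity check of the tensor layout used in the attention hunks above, the following standalone snippet reproduces the shape bookkeeping with made-up toy sizes (the linear layer, bias tensors, and dimensions here are placeholders, not the repository's parameters):

import torch
import torch.nn as nn

# hypothetical toy sizes, chosen only to exercise the reshapes
tgt_len, src_len, bsz, num_heads, head_dim = 5, 5, 2, 4, 16
embed_dim = num_heads * head_dim

q = torch.randn(tgt_len, bsz, num_heads, head_dim)       # q after the new view in the diff
k = torch.randn(bsz * num_heads, src_len, head_dim)      # keys, as reshaped by MultiheadAttention
pos_emb = torch.randn(src_len, bsz, embed_dim)           # (time, batch, dim), as passed in by the encoder
pos_bias_u = torch.zeros(num_heads, head_dim)
pos_bias_v = torch.zeros(num_heads, head_dim)
linear_pos = nn.Linear(embed_dim, embed_dim, bias=False)  # stand-in for self.linear_pos

# project the position embeddings and fold the heads into the batch dimension
p_rep = linear_pos(pos_emb.transpose(0, 1)).view(bsz, -1, num_heads, head_dim)
p_rep = p_rep.transpose(1, 2).contiguous().view(bsz * num_heads, -1, head_dim)

# queries with the learnable biases, laid out as (batch * head, time1, d_k)
q_with_bias_u = (q + pos_bias_u).contiguous().view(tgt_len, bsz * num_heads, head_dim).transpose(0, 1)
q_with_bias_v = (q + pos_bias_v).contiguous().view(tgt_len, bsz * num_heads, head_dim).transpose(0, 1)

matrix_ac = torch.bmm(q_with_bias_u, k.transpose(1, 2))      # (batch * head, time1, time2)
matrix_bd = torch.bmm(q_with_bias_v, p_rep.transpose(1, 2))  # (batch * head, time1, time2)

assert matrix_ac.shape == (bsz * num_heads, tgt_len, src_len)
assert matrix_bd.shape == (bsz * num_heads, tgt_len, src_len)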