Commit c9d8dbc3 by xuchen

fix the bug in the block attention

parent 32216b6d
@@ -274,14 +274,17 @@ class PyS2TTransformerEncoder(FairseqEncoder):
                     dropout=args.attention_dropout,
                     encoder_decoder_attention=True,
                 )
+                attn_layer_norm = LayerNorm(embed_dim)
             else:
                 attn = None
+                attn_layer_norm = None
 
             setattr(self, f"reduced_embed{i + 1}", reduced_embed)
             setattr(self, f"pos_embed{i + 1}", pos_embed)
             setattr(self, f"dropout{i + 1}", dropout)
             setattr(self, f"block{i + 1}", block)
             setattr(self, f"attn{i + 1}", attn)
+            setattr(self, f"attn_layer_norm{i + 1}", attn_layer_norm)
 
             if i == self.pyramid_stages - 1:
                 if args.encoder_normalize_before:
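For readers unfamiliar with this pattern: each pyramid stage's attention and its new layer norm are registered under an indexed attribute name with setattr, so the forward pass can retrieve them per stage with getattr. Below is a minimal, hypothetical sketch of that registration pattern using plain torch.nn modules; the class name, constructor arguments, and stage condition are illustrative, not the repository's actual code.

import torch.nn as nn

class PyramidEncoderSketch(nn.Module):
    # Hypothetical, simplified constructor mirroring the indexed setattr registration.
    def __init__(self, pyramid_stages=3, embed_dim=256, num_heads=4):
        super().__init__()
        self.pyramid_stages = pyramid_stages
        for i in range(pyramid_stages):
            if i > 0:  # assume only later stages attend back to an earlier stage
                attn = nn.MultiheadAttention(embed_dim, num_heads)
                attn_layer_norm = nn.LayerNorm(embed_dim)
            else:
                attn = None
                attn_layer_norm = None
            # nn.Module.__setattr__ registers the submodules under the indexed names,
            # so forward() can later fetch them with getattr(self, f"attn{i + 1}").
            setattr(self, f"attn{i + 1}", attn)
            setattr(self, f"attn_layer_norm{i + 1}", attn_layer_norm)

With this setup, getattr(self, f"attn{i + 1}") in forward() returns either the stage's attention module or None, matching the block_attn / attn_layer_norm lookups in the next hunk.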
@@ -347,6 +350,7 @@ class PyS2TTransformerEncoder(FairseqEncoder):
             dropout = getattr(self, f"dropout{i + 1}")
             block = getattr(self, f"block{i + 1}")
             block_attn = getattr(self, f"attn{i + 1}")
+            attn_layer_norm = getattr(self, f"attn_layer_norm{i + 1}")
 
             if i == 0:
                 x = self.embed_scale * x
@@ -371,18 +375,19 @@ class PyS2TTransformerEncoder(FairseqEncoder):
                 x = layer(x, encoder_padding_mask, pos_emb=positions)
                 layer_idx += 1
 
+            prev_state.append(x)
+            prev_padding.append(encoder_padding_mask)
+
             if block_attn is not None:
                 residual = x
+                x = attn_layer_norm(x)
                 x, attn = block_attn(
                     query=x,
                     key=prev_state[i-1],
                     value=prev_state[i-1],
                     key_padding_mask=prev_padding[i-1],
                 )
-                x += residual
-
-            prev_state.append(x)
-            prev_padding.append(encoder_padding_mask)
+                x = residual + x
 
             if self.use_ctc and self.inter_ctc and self.ctc_layer == layer_idx:
                 ctc_logit = self.ctc_layer_norm(x)
......
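The forward-pass hunk is the substance of the fix: the current stage's output is stored in prev_state and prev_padding before the cross-stage attention runs, the newly added layer norm is applied to the query ahead of the attention, and the residual is added back with an out-of-place add. A minimal runnable sketch of that fixed flow follows, using torch.nn.MultiheadAttention as a stand-in for the repository's attention module; the tensor sizes, stage index, and variable setup are assumptions for illustration.

import torch
from torch.nn import LayerNorm, MultiheadAttention

B, D = 2, 256                                   # batch size and embed dim (illustrative)
T_prev, T_cur = 20, 10                          # previous stage is longer (pre-downsampling)

block_attn = MultiheadAttention(D, num_heads=4)
attn_layer_norm = LayerNorm(D)

# State collected from the previous pyramid stage (stage i-1).
prev_state = [torch.randn(T_prev, B, D)]
prev_padding = [torch.zeros(B, T_prev, dtype=torch.bool)]

x = torch.randn(T_cur, B, D)                    # output of the current stage's layers
encoder_padding_mask = torch.zeros(B, T_cur, dtype=torch.bool)
i = 1                                           # current stage index (illustrative)

# Fixed order: store the current stage's output before attending,
# so the stored state is the pre-attention representation.
prev_state.append(x)
prev_padding.append(encoder_padding_mask)

residual = x
x = attn_layer_norm(x)                          # pre-norm added by this commit
x, attn = block_attn(
    query=x,
    key=prev_state[i - 1],                      # attend over the previous stage
    value=prev_state[i - 1],
    key_padding_mask=prev_padding[i - 1],
)
x = residual + x                                # residual around the block attention

Because the append now happens before the attention, later stages attend over each earlier stage's pre-attention representation rather than its post-attention one, which is the ordering the commit changes.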