Commit c9d8dbc3 by xuchen

Fix a bug in the block attention

parent 32216b6d
@@ -274,14 +274,17 @@ class PyS2TTransformerEncoder(FairseqEncoder):
                     dropout=args.attention_dropout,
                     encoder_decoder_attention=True,
                 )
+                attn_layer_norm = LayerNorm(embed_dim)
             else:
                 attn = None
+                attn_layer_norm = None
 
             setattr(self, f"reduced_embed{i + 1}", reduced_embed)
             setattr(self, f"pos_embed{i + 1}", pos_embed)
             setattr(self, f"dropout{i + 1}", dropout)
             setattr(self, f"block{i + 1}", block)
             setattr(self, f"attn{i + 1}", attn)
+            setattr(self, f"attn_layer_norm{i + 1}", attn_layer_norm)
 
             if i == self.pyramid_stages - 1:
                 if args.encoder_normalize_before:
@@ -347,6 +350,7 @@ class PyS2TTransformerEncoder(FairseqEncoder):
             dropout = getattr(self, f"dropout{i + 1}")
             block = getattr(self, f"block{i + 1}")
             block_attn = getattr(self, f"attn{i + 1}")
+            attn_layer_norm = getattr(self, f"attn_layer_norm{i + 1}")
 
             if i == 0:
                 x = self.embed_scale * x
@@ -371,18 +375,19 @@ class PyS2TTransformerEncoder(FairseqEncoder):
                 x = layer(x, encoder_padding_mask, pos_emb=positions)
                 layer_idx += 1
 
-            prev_state.append(x)
-            prev_padding.append(encoder_padding_mask)
-
             if block_attn is not None:
                 residual = x
+                x = attn_layer_norm(x)
                 x, attn = block_attn(
                     query=x,
                     key=prev_state[i-1],
                     value=prev_state[i-1],
                     key_padding_mask=prev_padding[i-1],
                 )
-                x += residual
+                x = residual + x
+
+            prev_state.append(x)
+            prev_padding.append(encoder_padding_mask)
 
             if self.use_ctc and self.inter_ctc and self.ctc_layer == layer_idx:
                 ctc_logit = self.ctc_layer_norm(x)
...
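The patch makes three changes to each pyramid stage's block attention: a per-stage LayerNorm is now applied to the query before the cross-stage attention, the residual is added out of place (x = residual + x instead of the in-place x += residual, which previously mutated a tensor already stored in prev_state), and the stage output is appended to prev_state / prev_padding only after the block attention runs, so the next stage attends to the attended output rather than the raw one.

The snippet below is a minimal, self-contained sketch of the corrected flow. It uses plain torch.nn.MultiheadAttention and nn.LayerNorm instead of the fairseq modules the repository uses, and the class name BlockAttentionStage plus the toy shapes and dimensions are illustrative assumptions, not part of the patch.

# Minimal sketch of the corrected block-attention flow (assumptions noted above).
import torch
import torch.nn as nn


class BlockAttentionStage(nn.Module):
    """One pyramid stage's cross-stage attention: pre-norm query + residual."""

    def __init__(self, embed_dim: int, num_heads: int, dropout: float = 0.1):
        super().__init__()
        self.attn = nn.MultiheadAttention(embed_dim, num_heads, dropout=dropout)
        self.attn_layer_norm = nn.LayerNorm(embed_dim)

    def forward(self, x, prev_x, prev_padding_mask):
        # x, prev_x: (T, B, C); prev_padding_mask: (B, T_prev), True at padding
        residual = x
        x = self.attn_layer_norm(x)       # normalize the query first (added by the fix)
        x, _ = self.attn(
            query=x,
            key=prev_x,
            value=prev_x,
            key_padding_mask=prev_padding_mask,
        )
        x = residual + x                  # out-of-place residual add (added by the fix)
        return x


if __name__ == "__main__":
    embed_dim, num_heads = 16, 4
    stage = BlockAttentionStage(embed_dim, num_heads)

    prev_state, prev_padding = [], []
    # pretend output of the previous (longer) stage
    x_prev = torch.randn(10, 2, embed_dim)
    prev_state.append(x_prev)
    prev_padding.append(torch.zeros(2, 10, dtype=torch.bool))

    # current stage output before block attention
    x = torch.randn(5, 2, embed_dim)
    x = stage(x, prev_state[-1], prev_padding[-1])

    # append only after the block attention, as the patch now does
    prev_state.append(x)
    prev_padding.append(torch.zeros(2, 5, dtype=torch.bool))
    print(x.shape)  # torch.Size([5, 2, 16])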