Commit 32216b6d by xuchen

Fix a bug in the block attention

parent 0ce623be
@@ -273,8 +273,6 @@ class PyS2TTransformerEncoder(FairseqEncoder):
                         vdim=self.pyramid_embed_dims[i-1],
                         dropout=args.attention_dropout,
                         encoder_decoder_attention=True,
-                        q_noise=self.quant_noise,
-                        qn_block_size=self.quant_noise_block_size,
                     )
                 else:
                     attn = None
@@ -373,7 +371,7 @@ class PyS2TTransformerEncoder(FairseqEncoder):
                 x = layer(x, encoder_padding_mask, pos_emb=positions)
                 layer_idx += 1
-                if attn is not None:
+                if block_attn is not None:
                     residual = x
                     x, attn = block_attn(
                         query=x,
@@ -383,8 +381,8 @@ class PyS2TTransformerEncoder(FairseqEncoder):
                     )
                     x += residual
-                prev_state[i] = x
-                prev_padding[i] = encoder_padding_mask
+                prev_state.append(x)
+                prev_padding.append(encoder_padding_mask)
                 if self.use_ctc and self.inter_ctc and self.ctc_layer == layer_idx:
                     ctc_logit = self.ctc_layer_norm(x)
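For reference, a minimal sketch of the corrected per-stage logic after this commit. Only the two fixes are illustrated: the guard now tests the block-attention module (block_attn) rather than the returned weight tensor (attn), and per-stage outputs are appended to prev_state / prev_padding instead of indexed. The names block_attn, prev_state, prev_padding, and encoder_padding_mask come from the diff; the attention module, its key/value arguments, and the tensor shapes are assumptions (torch.nn.MultiheadAttention stands in for fairseq's MultiheadAttention), so this is not the actual encoder code.

# Hypothetical sketch, not the fairseq implementation: it only illustrates
# the two fixes in this commit.
import torch
import torch.nn as nn

def apply_block_attention(x, encoder_padding_mask, block_attn,
                          prev_state, prev_padding):
    # Guard on the module itself; block_attn may be None for stages without
    # block attention. The old code tested `attn`, which holds attention
    # weights, not the module.
    if block_attn is not None:
        residual = x
        # Assumed call pattern: attend from the current stage (query) to the
        # previously stored stage output (key/value).
        x, attn = block_attn(
            query=x,
            key=prev_state[-1],
            value=prev_state[-1],
            key_padding_mask=prev_padding[-1],
        )
        x += residual
    # Append instead of `prev_state[i] = x`, so the lists simply grow as
    # stages are processed.
    prev_state.append(x)
    prev_padding.append(encoder_padding_mask)
    return x

# Toy usage with nn.MultiheadAttention as a stand-in (seq_len, batch, dim).
if __name__ == "__main__":
    T, B, D = 10, 2, 16
    block_attn = nn.MultiheadAttention(embed_dim=D, num_heads=4)
    prev_state = [torch.randn(T, B, D)]
    prev_padding = [torch.zeros(B, T, dtype=torch.bool)]
    x = torch.randn(T, B, D)
    mask = torch.zeros(B, T, dtype=torch.bool)
    out = apply_block_attention(x, mask, block_attn, prev_state, prev_padding)
    print(out.shape)  # torch.Size([10, 2, 16])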