Commit 32216b6d by xuchen

fix a bug in the block attention

parent 0ce623be
@@ -273,8 +273,6 @@ class PyS2TTransformerEncoder(FairseqEncoder):
                     vdim=self.pyramid_embed_dims[i-1],
                     dropout=args.attention_dropout,
                     encoder_decoder_attention=True,
-                    q_noise=self.quant_noise,
-                    qn_block_size=self.quant_noise_block_size,
                 )
             else:
                 attn = None
@@ -373,7 +371,7 @@ class PyS2TTransformerEncoder(FairseqEncoder):
             x = layer(x, encoder_padding_mask, pos_emb=positions)
             layer_idx += 1
-            if attn is not None:
+            if block_attn is not None:
                 residual = x
                 x, attn = block_attn(
                     query=x,
@@ -383,8 +381,8 @@ class PyS2TTransformerEncoder(FairseqEncoder):
                 )
                 x += residual
-            prev_state[i] = x
-            prev_padding[i] = encoder_padding_mask
+            prev_state.append(x)
+            prev_padding.append(encoder_padding_mask)
             if self.use_ctc and self.inter_ctc and self.ctc_layer == layer_idx:
                 ctc_logit = self.ctc_layer_norm(x)
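
For context, a minimal runnable sketch of the pattern this commit restores: each pyramid stage cross-attends its output against the previous stage's stored state, adds a residual, and then appends its output and padding mask to the history lists (the old indexed assignment prev_state[i] = x fails on a list that has not been pre-sized). The shapes, the loop structure, and the plain torch.nn.MultiheadAttention stand-in are illustrative assumptions, not the repository's actual classes.

import torch
import torch.nn as nn

torch.manual_seed(0)

embed_dim, num_heads, seq_len, batch = 16, 4, 10, 2

# Stand-in for the per-stage block attention built in __init__ (the real code
# uses fairseq's MultiheadAttention, now without the quant-noise arguments).
block_attn = nn.MultiheadAttention(embed_dim, num_heads)

prev_state = []    # outputs of earlier pyramid stages
prev_padding = []  # padding masks of earlier pyramid stages

x = torch.randn(seq_len, batch, embed_dim)                            # (T, B, C)
encoder_padding_mask = torch.zeros(batch, seq_len, dtype=torch.bool)  # (B, T)

for stage in range(3):
    # ... the stage's encoder layers would run here ...
    if block_attn is not None and prev_state:
        residual = x
        # Cross-attend the current stage's output to the previous stage's state.
        x, _ = block_attn(
            query=x,
            key=prev_state[-1],
            value=prev_state[-1],
            key_padding_mask=prev_padding[-1],
        )
        x = x + residual
    # The fix: grow the history with append() instead of assigning to an
    # index that does not yet exist on an initially empty list.
    prev_state.append(x)
    prev_padding.append(encoder_padding_mask)

print(len(prev_state), x.shape)  # 3 torch.Size([10, 2, 16])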