Commit 0ce623be by xuchen

Add block attention for the pyramid transformer

parent 6292949b
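This commit wires a cross-attention ("block attention") from each pyramid stage after the first back to the output of the previous stage, with a residual connection around it. Below is a minimal sketch of the idea using torch.nn.MultiheadAttention; the sequence lengths, dimensions, head count, dropout, and padding mask are illustrative assumptions, not values taken from this repository.

    import torch
    import torch.nn as nn

    # Current stage: 25 frames, dim 384; previous stage: 50 frames, dim 256.
    x = torch.randn(25, 2, 384)                      # T_cur x B x C_cur
    prev_x = torch.randn(50, 2, 256)                 # T_prev x B x C_prev
    prev_pad = torch.zeros(2, 50, dtype=torch.bool)  # B x T_prev, True marks padding

    # kdim/vdim let keys and values keep the previous stage's feature dimension.
    block_attn = nn.MultiheadAttention(384, 4, dropout=0.1, kdim=256, vdim=256)

    residual = x
    x, _ = block_attn(query=x, key=prev_x, value=prev_x, key_padding_mask=prev_pad)
    x = x + residual                                 # same residual pattern as in the diff below

The kdim/vdim arguments are what allow the previous stage, which has a different feature dimension, to serve as keys and values for the current stage's queries.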
@@ -12,7 +12,7 @@ log-interval: 100
seed: 1
report-accuracy: True
#arch: s2t_transformer_s
arch: s2t_transformer_s
share-decoder-input-output-embed: True
optimizer: adam
clip-norm: 10.0
@@ -26,7 +26,7 @@ ctc-weight: 0.3
criterion: label_smoothed_cross_entropy_with_ctc
label_smoothing: 0.1
#conv-kernel-sizes: 5,5
conv-kernel-sizes: 5,5
conv-channels: 1024
dropout: 0.1
activation-fn: relu
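Uncommenting conv-kernel-sizes: 5,5 enables the two-layer convolutional front end that subsamples the acoustic input before the encoder, with conv-channels: 1024 presumably setting the intermediate width. A rough sketch of the effect of two stride-2 Conv1d layers follows; the stride of 2, the 80-dim filterbank input, and the 256-dim output are assumptions for illustration, since the exact subsampler is not part of this diff.

    import torch
    import torch.nn as nn

    # Two stride-2 Conv1d layers with kernel size 5 give 4x temporal subsampling.
    subsampler = nn.Sequential(
        nn.Conv1d(80, 1024, kernel_size=5, stride=2, padding=2),
        nn.ReLU(),
        nn.Conv1d(1024, 256, kernel_size=5, stride=2, padding=2),
        nn.ReLU(),
    )

    feats = torch.randn(2, 80, 100)   # B x feature_dim x frames
    out = subsampler(feats)           # B x 256 x 25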
@@ -20,6 +20,7 @@ from fairseq.modules import (
LayerNorm,
PositionalEmbedding,
PyramidTransformerEncoderLayer,
MultiheadAttention,
)
logger = logging.getLogger(__name__)
@@ -61,6 +62,12 @@ class ReducedEmbed(nn.Module):
# self.norm = LayerNorm(out_channels)
self.norm = LayerNorm(in_channels)
if out_channels % in_channels == 0:
self.residual = True
else:
self.residual = False
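# NOTE: the next line unconditionally disables the residual path, overriding the check above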
self.residual = False
def forward(self, x, lengths):
seq_len, bsz, dim = x.size()
assert seq_len % self.stride == 0, "The sequence length %d must be a multiple of %d." % (seq_len, self.stride)
@@ -73,6 +80,9 @@
x.masked_fill_(mask_pad, 0.0)
x = x.transpose(0, 1)
if self.residual:
origin_x = x.transpose(0, 1).contiguous().view(bsz, int(seq_len / self.stride), -1).transpose(0, 1)
if self.embed_norm:
x = self.norm(x)
x = x.permute(1, 2, 0) # B * D * T
@@ -91,6 +101,12 @@
x.masked_fill_(mask_pad, 0.0)
x = x.transpose(0, 1)
if self.residual:
if x.size() == origin_x.size():
x += origin_x
else:
logger.error("The sizes do not match: {} vs. {}".format(x.size(), origin_x.size()))
return x, lengths, padding_mask
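The residual added around the length reduction above can only be applied when the reshaped input matches the reduced output, which in practice means out_channels == stride * in_channels: the view call concatenates every `stride` consecutive frames along the feature axis. A small shape check illustrating this; all sizes below are made up.

    import torch

    seq_len, bsz, in_dim, stride = 8, 2, 4, 2
    x = torch.randn(seq_len, bsz, in_dim)                       # T x B x C

    # B x T x C -> B x (T/stride) x (stride*C): neighbouring frames are stacked on the feature axis.
    origin_x = x.transpose(0, 1).contiguous().view(bsz, seq_len // stride, -1).transpose(0, 1)

    reduced = torch.randn(seq_len // stride, bsz, stride * in_dim)  # stand-in for the reduced output
    assert origin_x.size() == reduced.size()
    out = reduced + origin_x                                    # the residual connection from the diff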
@@ -249,10 +265,25 @@ class PyS2TTransformerEncoder(FairseqEncoder):
PyramidTransformerEncoderLayer(args, embed_dim, embed_dim * ffn_ratio, num_head, attn_sample_ratio)
for _ in range(num_layers)])
if i != 0:
attn = MultiheadAttention(
embed_dim,
num_head,
kdim=self.pyramid_embed_dims[i-1],
vdim=self.pyramid_embed_dims[i-1],
dropout=args.attention_dropout,
encoder_decoder_attention=True,
q_noise=self.quant_noise,
qn_block_size=self.quant_noise_block_size,
)
else:
attn = None
setattr(self, f"reduced_embed{i + 1}", reduced_embed)
setattr(self, f"pos_embed{i + 1}", pos_embed)
setattr(self, f"dropout{i + 1}", dropout)
setattr(self, f"block{i + 1}", block)
setattr(self, f"attn{i + 1}", attn)
if i == self.pyramid_stages - 1:
if args.encoder_normalize_before:
@@ -310,11 +341,14 @@ class PyS2TTransformerEncoder(FairseqEncoder):
layer_idx = 0
ctc_logit = None
prev_state = []
prev_padding = []
for i in range(self.pyramid_stages):
reduced_embed = getattr(self, f"reduced_embed{i + 1}")
pos_embed = getattr(self, f"pos_embed{i + 1}")
dropout = getattr(self, f"dropout{i + 1}")
block = getattr(self, f"block{i + 1}")
block_attn = getattr(self, f"attn{i + 1}")
if i == 0:
x = self.embed_scale * x
@@ -339,6 +373,19 @@
x = layer(x, encoder_padding_mask, pos_emb=positions)
layer_idx += 1
if block_attn is not None:
residual = x
x, attn = block_attn(
query=x,
key=prev_state[i-1],
value=prev_state[i-1],
key_padding_mask=prev_padding[i-1],
)
x += residual
prev_state.append(x)
prev_padding.append(encoder_padding_mask)
if self.use_ctc and self.inter_ctc and self.ctc_layer == layer_idx:
ctc_logit = self.ctc_layer_norm(x)
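For reference, here is a condensed sketch of the per-stage pattern used above: attention modules are registered with setattr and recovered with getattr (assigning a submodule via setattr on an nn.Module still registers its parameters), stage outputs are accumulated in a list, and every stage after the first cross-attends over the previous stage's output with a residual. This sketch assumes the cross-attention and the state update run once per stage, and it takes the per-stage feature sequences as given inputs, whereas the real encoder derives each stage from the previous one; the stage dimensions below are invented.

    import torch
    import torch.nn as nn

    class StagesSketch(nn.Module):
        def __init__(self, dims=(256, 384), num_heads=4):
            super().__init__()
            self.num_stages = len(dims)
            for i, dim in enumerate(dims):
                # Stage 0 has no previous stage to attend to, mirroring `attn = None` above.
                attn = None if i == 0 else nn.MultiheadAttention(
                    dim, num_heads, kdim=dims[i - 1], vdim=dims[i - 1])
                setattr(self, f"attn{i + 1}", attn)

        def forward(self, states, paddings):
            # states[i]: T_i x B x dims[i]; paddings[i]: B x T_i (True marks padding)
            outs = []
            for i in range(self.num_stages):
                x = states[i]
                block_attn = getattr(self, f"attn{i + 1}")
                if block_attn is not None:
                    residual = x
                    x, _ = block_attn(query=x, key=outs[i - 1], value=outs[i - 1],
                                      key_padding_mask=paddings[i - 1])
                    x = x + residual
                outs.append(x)
            return outs

    # Toy usage with two stages (dims 256 -> 384) and a batch of 2.
    model = StagesSketch()
    s1, s2 = torch.randn(40, 2, 256), torch.randn(20, 2, 384)
    p1, p2 = torch.zeros(2, 40, dtype=torch.bool), torch.zeros(2, 20, dtype=torch.bool)
    outs = model([s1, s2], [p1, p2])   # outs[1] has shape 20 x 2 x 384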