Commit 0ce623be by xuchen

add the block attention for pyramid transformer

parent 6292949b
@@ -12,7 +12,7 @@ log-interval: 100
 seed: 1
 report-accuracy: True
-#arch: s2t_transformer_s
+arch: s2t_transformer_s
 share-decoder-input-output-embed: True
 optimizer: adam
 clip-norm: 10.0
@@ -26,7 +26,7 @@ ctc-weight: 0.3
 criterion: label_smoothed_cross_entropy_with_ctc
 label_smoothing: 0.1
-#conv-kernel-sizes: 5,5
+conv-kernel-sizes: 5,5
 conv-channels: 1024
 dropout: 0.1
 activation-fn: relu
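The conv-kernel-sizes: 5,5 option touched above configures the convolutional front end that shortens the acoustic feature sequence before the Transformer stages. As a rough illustration only, and assuming the usual fairseq Conv1dSubsampler behaviour where each comma-separated kernel size corresponds to one stride-2 convolution, the reduced length can be estimated as follows (the helper name subsampled_length is made up for this sketch):

def subsampled_length(src_len: int, conv_kernel_sizes: str = "5,5") -> int:
    # One stride-2 Conv1d per comma-separated kernel size; each layer maps
    # a length L to floor((L - 1) / 2 + 1).
    out = src_len
    for _ in conv_kernel_sizes.split(","):
        out = (out - 1) // 2 + 1
    return out

# e.g. a 1000-frame utterance is reduced roughly 4x before the encoder:
# subsampled_length(1000) -> 250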
@@ -20,6 +20,7 @@ from fairseq.modules import (
     LayerNorm,
     PositionalEmbedding,
     PyramidTransformerEncoderLayer,
+    MultiheadAttention,
 )
 
 logger = logging.getLogger(__name__)
@@ -61,6 +62,12 @@ class ReducedEmbed(nn.Module):
         # self.norm = LayerNorm(out_channels)
         self.norm = LayerNorm(in_channels)
+
+        if out_channels % in_channels == 0:
+            self.residual = True
+        else:
+            self.residual = False
+        self.residual = False
 
     def forward(self, x, lengths):
         seq_len, bsz, dim = x.size()
         assert seq_len % self.stride == 0, "The sequence length %d must be a multiple of %d." % (seq_len, self.stride)
@@ -73,6 +80,9 @@ class ReducedEmbed(nn.Module):
             x.masked_fill_(mask_pad, 0.0)
             x = x.transpose(0, 1)
 
+        if self.residual:
+            origin_x = x.transpose(0, 1).contiguous().view(bsz, int(seq_len / self.stride), -1).transpose(0, 1)
+
         if self.embed_norm:
             x = self.norm(x)
         x = x.permute(1, 2, 0)  # B * D * T
@@ -91,6 +101,12 @@ class ReducedEmbed(nn.Module):
             x.masked_fill_(mask_pad, 0.0)
             x = x.transpose(0, 1)
 
+        if self.residual:
+            if x.size() == origin_x.size():
+                x += origin_x
+            else:
+                logging.error("The size is unmatched {} and {}".format(x.size(), origin_x.size()))
+
         return x, lengths, padding_mask
@@ -249,10 +265,25 @@ class PyS2TTransformerEncoder(FairseqEncoder):
                 PyramidTransformerEncoderLayer(args, embed_dim, embed_dim * ffn_ratio, num_head, attn_sample_ratio)
                 for _ in range(num_layers)])
 
+            if i != 0:
+                attn = MultiheadAttention(
+                    embed_dim,
+                    num_head,
+                    kdim=self.pyramid_embed_dims[i-1],
+                    vdim=self.pyramid_embed_dims[i-1],
+                    dropout=args.attention_dropout,
+                    encoder_decoder_attention=True,
+                    q_noise=self.quant_noise,
+                    qn_block_size=self.quant_noise_block_size,
+                )
+            else:
+                attn = None
+
             setattr(self, f"reduced_embed{i + 1}", reduced_embed)
             setattr(self, f"pos_embed{i + 1}", pos_embed)
             setattr(self, f"dropout{i + 1}", dropout)
             setattr(self, f"block{i + 1}", block)
+            setattr(self, f"attn{i + 1}", attn)
 
             if i == self.pyramid_stages - 1:
                 if args.encoder_normalize_before:
@@ -310,11 +341,14 @@ class PyS2TTransformerEncoder(FairseqEncoder):
         layer_idx = 0
         ctc_logit = None
+        prev_state = []
+        prev_padding = []
 
         for i in range(self.pyramid_stages):
             reduced_embed = getattr(self, f"reduced_embed{i + 1}")
             pos_embed = getattr(self, f"pos_embed{i + 1}")
             dropout = getattr(self, f"dropout{i + 1}")
             block = getattr(self, f"block{i + 1}")
+            block_attn = getattr(self, f"attn{i + 1}")
 
             if i == 0:
                 x = self.embed_scale * x
@@ -339,6 +373,19 @@ class PyS2TTransformerEncoder(FairseqEncoder):
                 x = layer(x, encoder_padding_mask, pos_emb=positions)
                 layer_idx += 1
 
+            # Block attention: from the second stage on, attend back to the
+            # previous stage's output and add it through a residual connection.
+            if block_attn is not None:
+                residual = x
+                x, attn = block_attn(
+                    query=x,
+                    key=prev_state[i-1],
+                    value=prev_state[i-1],
+                    key_padding_mask=prev_padding[i-1],
+                )
+                x += residual
+
+            prev_state.append(x)
+            prev_padding.append(encoder_padding_mask)
+
             if self.use_ctc and self.inter_ctc and self.ctc_layer == layer_idx:
                 ctc_logit = self.ctc_layer_norm(x)
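Taken together, the encoder changes above wire a cross-stage ("block") attention into every pyramid stage after the first: the current stage's states act as queries against the previous stage's longer, lower-dimensional output, and the result is added back through a residual connection. Below is a minimal standalone sketch of that pattern, written against torch.nn.MultiheadAttention rather than fairseq.modules.MultiheadAttention; the class name CrossStageBlockAttention and its arguments are illustrative, not the commit's API.

import torch
import torch.nn as nn


class CrossStageBlockAttention(nn.Module):
    def __init__(self, embed_dim: int, num_heads: int, prev_dim: int, dropout: float = 0.1):
        super().__init__()
        # kdim/vdim let the keys and values come from a stage with a different
        # model dimension, mirroring the kdim/vdim arguments in the diff.
        self.attn = nn.MultiheadAttention(
            embed_dim, num_heads, dropout=dropout, kdim=prev_dim, vdim=prev_dim
        )

    def forward(self, x, prev_state, prev_padding_mask=None):
        # x: (T, B, embed_dim); prev_state: (T_prev, B, prev_dim);
        # prev_padding_mask: (B, T_prev) with True at padded positions.
        residual = x
        x, _ = self.attn(
            query=x,
            key=prev_state,
            value=prev_state,
            key_padding_mask=prev_padding_mask,
        )
        return x + residual


# Usage: a 50-frame current stage (dim 512) attends to a 100-frame previous
# stage (dim 256) and keeps its own shape.
cur = torch.randn(50, 4, 512)
prev = torch.randn(100, 4, 256)
block_attn = CrossStageBlockAttention(embed_dim=512, num_heads=8, prev_dim=256)
print(block_attn(cur, prev).shape)  # torch.Size([50, 4, 512])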
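For reference, here is a standalone sketch (not part of the commit) of the residual path added to ReducedEmbed: the pre-reduction input of shape (T, B, in_channels) is folded so that each group of stride consecutive frames becomes one frame of width in_channels * stride, which can only be added to the reduced output when that width matches out_channels. The helper name fold_residual is hypothetical.

import torch


def fold_residual(x: torch.Tensor, stride: int) -> torch.Tensor:
    # (T, B, D) -> (B, T, D) -> (B, T // stride, D * stride) -> (T // stride, B, D * stride)
    seq_len, bsz, dim = x.size()
    assert seq_len % stride == 0, "sequence length must be a multiple of the stride"
    return x.transpose(0, 1).contiguous().view(bsz, seq_len // stride, -1).transpose(0, 1)


# A (8, 2, 256) input with stride 2 folds to (4, 2, 512); adding it to the
# reduced output therefore requires out_channels == in_channels * stride.
x = torch.randn(8, 2, 256)
origin_x = fold_residual(x, stride=2)
print(origin_x.shape)  # torch.Size([4, 2, 512])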