Commit 31d0303e by xuchen

support PPM for the pyramid transformer

parent c9d8dbc3
@@ -54,19 +54,17 @@ class ReducedEmbed(nn.Module):
elif self.reduced_way == "glu":
self.conv = nn.Conv1d(in_channels, out_channels * 2, kernel_sizes, stride=stride, padding=padding)
self.glu = nn.GLU(dim=1)
elif self.reduced_way == "proj":
self.proj = nn.Linear(2 * in_channels, out_channels, bias=False)
else:
logger.error("Unsupported reduced way!")
self.embed_norm = embed_norm
if self.embed_norm:
# self.norm = LayerNorm(out_channels)
self.norm = LayerNorm(in_channels)
if out_channels % in_channels == 0:
self.residual = True
else:
self.residual = False
self.residual = False
if self.reduced_way == "proj":
self.norm = LayerNorm(2 * in_channels)
else:
self.norm = LayerNorm(out_channels)
def forward(self, x, lengths):
seq_len, bsz, dim = x.size()
@@ -80,16 +78,18 @@ class ReducedEmbed(nn.Module):
x.masked_fill_(mask_pad, 0.0)
x = x.transpose(0, 1)
if self.residual:
origin_x = x.transpose(0, 1).contiguous().view(bsz, int(seq_len / self.stride), -1).transpose(0, 1)
if self.reduced_way == "proj":
x = x.transpose(0, 1).contiguous().view(bsz, int(seq_len / 2), -1)
x = self.proj(self.norm(x))
x = x.transpose(0, 1)
else:
x = x.permute(1, 2, 0) # B * D * T
x = self.conv(x)
if self.reduced_way == "glu":
x = self.glu(x)
x = x.permute(2, 0, 1) # T * B * D
if self.embed_norm:
x = self.norm(x)
x = x.permute(1, 2, 0) # B * D * T
x = self.conv(x)
if self.reduced_way == "glu":
x = self.glu(x)
x = x.permute(2, 0, 1) # T * B * D
lengths = lengths / self.stride
@@ -101,15 +101,49 @@ class ReducedEmbed(nn.Module):
x.masked_fill_(mask_pad, 0.0)
x = x.transpose(0, 1)
if self.residual:
if x.size() == origin_x.size():
x += origin_x
else:
logging.error("The size is unmatched {} and {}".format(x.size(), origin_x.size()))
return x, lengths, padding_mask
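
As a reading aid (not part of the commit), a minimal standalone sketch of the new "proj" reduction path above: adjacent frames are concatenated pairwise along the feature axis, normalized, and projected to the target width. The stride of 2, the even sequence length, and the omission of the padding/length bookkeeping are simplifying assumptions.

import torch
import torch.nn as nn

# Minimal sketch of the "proj" reduction: each pair of adjacent frames is
# concatenated along the feature axis, LayerNorm'd, then linearly projected.
# Assumes an even seq_len and skips the padding/length handling done above.
def proj_reduce(x, norm, proj, stride=2):
    # x: (seq_len, bsz, in_dim)
    seq_len, bsz, in_dim = x.size()
    x = x.transpose(0, 1).contiguous().view(bsz, seq_len // stride, stride * in_dim)
    x = proj(norm(x))            # (bsz, seq_len // stride, out_dim)
    return x.transpose(0, 1)     # (seq_len // stride, bsz, out_dim)

if __name__ == "__main__":
    in_dim, out_dim = 80, 256
    norm = nn.LayerNorm(2 * in_dim)
    proj = nn.Linear(2 * in_dim, out_dim, bias=False)
    x = torch.randn(100, 4, in_dim)          # (T, B, D)
    print(proj_reduce(x, norm, proj).shape)  # torch.Size([50, 4, 256])
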
class BlockFuse(nn.Module):
    def __init__(self, embed_dim, prev_embed_dim, dropout, num_head, fuse_way="add"):
        super().__init__()
        self.attn = MultiheadAttention(
            embed_dim,
            num_head,
            kdim=prev_embed_dim,
            vdim=prev_embed_dim,
            dropout=dropout,
            encoder_decoder_attention=True,
        )
        self.q_layer_norm = LayerNorm(embed_dim)
        self.kv_layer_norm = LayerNorm(prev_embed_dim)
        self.fuse_way = fuse_way
        if self.fuse_way == "gated":
            self.gate_linear = nn.Linear(2 * embed_dim, embed_dim)

    def forward(self, x, state, padding):
        residual = x
        x = self.q_layer_norm(x)
        state = self.kv_layer_norm(state)
        x, attn = self.attn(
            query=x,
            key=state,
            value=state,
            key_padding_mask=padding,
            static_kv=True,
        )
        if self.fuse_way == "add":
            x = residual + x
        elif self.fuse_way == "gated":
            coef = (self.gate_linear(torch.cat([x, residual], dim=-1))).sigmoid()
            x = coef * x + (1 - coef) * residual
        return x
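
The "gated" fusion branch can be read in isolation as below (illustration only, not part of the diff): a sigmoid gate computed from the concatenation of the cross-attention output and the residual decides, per element, how much of each to keep. The attention output is taken as given, and the shapes follow the (T, B, D) convention used above.

import torch
import torch.nn as nn

# Standalone sketch of the gated fusion in BlockFuse.forward.
def gated_fuse(attn_out, residual, gate_linear):
    coef = torch.sigmoid(gate_linear(torch.cat([attn_out, residual], dim=-1)))
    return coef * attn_out + (1 - coef) * residual

if __name__ == "__main__":
    embed_dim = 256
    gate_linear = nn.Linear(2 * embed_dim, embed_dim)
    attn_out = torch.randn(25, 4, embed_dim)   # cross-attention output (T, B, D)
    residual = torch.randn(25, 4, embed_dim)   # current-stage input
    print(gated_fuse(attn_out, residual, gate_linear).shape)  # torch.Size([25, 4, 256])
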
@register_model("pys2t_transformer")
class PYS2TTransformerModel(S2TTransformerModel):
"""Adapted Transformer model (https://arxiv.org/abs/1706.03762) for
@@ -148,7 +182,7 @@ class PYS2TTransformerModel(S2TTransformerModel):
parser.add_argument(
"--pyramid-reduced-embed",
type=str,
choices=["glu", "conv"],
choices=["glu", "conv", "proj"],
help="the reduced way of the embedding",
)
parser.add_argument(
@@ -157,6 +191,16 @@ class PYS2TTransformerModel(S2TTransformerModel):
help="use layer norm in reduced embedding",
)
parser.add_argument(
"--pyramid-block-attn",
action="store_true",
help="use block attention",
)
parser.add_argument(
"--pyramid-fuse-way",
type=str,
help="fused way for block attention",
)
parser.add_argument(
"--pyramid-position-embed",
type=str,
help="use the position embedding or not",
@@ -181,6 +225,11 @@ class PYS2TTransformerModel(S2TTransformerModel):
type=str,
help="the number of the attention heads",
)
parser.add_argument(
"--pyramid-use-ppm",
action="store_true",
help="use ppm",
)
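
For illustration only, this is how the three new options surface on the args namespace inside the encoder; flags left unset fall back to the getattr() defaults used in __init__ (block attention off, "add" fusion, PPM off). The concrete values below are hypothetical.

from argparse import Namespace

# Hypothetical settings for the new pyramid options; unset attributes fall back
# to the getattr() defaults used in PyS2TTransformerEncoder.__init__.
args = Namespace(
    pyramid_block_attn=True,    # --pyramid-block-attn
    pyramid_fuse_way="gated",   # --pyramid-fuse-way gated
    pyramid_use_ppm=True,       # --pyramid-use-ppm
)
print(getattr(args, "pyramid_block_attn", False))  # True
print(getattr(args, "pyramid_fuse_way", "add"))    # gated
print(getattr(args, "pyramid_use_ppm", False))     # True
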
parser.add_argument(
"--ctc-layer",
@@ -211,7 +260,11 @@ class PyS2TTransformerEncoder(FairseqEncoder):
self.padding_idx = 1
self.attn_type = getattr(args, "encoder_attention_type", "selfattn")
self.embed_dim = args.encoder_embed_dim
self.dropout = FairseqDropout(
p=args.dropout, module_name=self.__class__.__name__
)
self.pyramid_stages = getattr(args, "pyramid_stages", 4)
self.pyramid_layers = [int(n) for n in args.pyramid_layers.split("_")]
@@ -225,6 +278,10 @@ class PyS2TTransformerEncoder(FairseqEncoder):
self.pyramid_reduced_embed = args.pyramid_reduced_embed
self.pyramid_embed_norm = args.pyramid_embed_norm
self.pyramid_block_attn = getattr(args, "pyramid_block_attn", False)
self.pyramid_fuse_way = getattr(args, "pyramid_fuse_way", "add")
self.use_ppm = getattr(args, "pyramid_use_ppm", False)
for i in range(self.pyramid_stages):
num_layers = self.pyramid_layers[i]
sr_ratio = self.pyramid_sr_ratios[i]
@@ -242,7 +299,8 @@ class PyS2TTransformerEncoder(FairseqEncoder):
reduced_embed = ReducedEmbed(
self.pyramid_reduced_embed,
self.pyramid_embed_norm if i != 0 else False,
self.pyramid_embed_norm,
# self.pyramid_embed_norm if i != 0 else False,
args.input_feat_per_channel * args.input_channels if i == 0 else self.pyramid_embed_dims[i-1],
embed_dim,
kernel_sizes=kernel_size,
@@ -257,34 +315,33 @@ class PyS2TTransformerEncoder(FairseqEncoder):
else:
pos_embed = None
dropout = FairseqDropout(
p=args.dropout, module_name=self.__class__.__name__
)
block = nn.ModuleList([
PyramidTransformerEncoderLayer(args, embed_dim, embed_dim * ffn_ratio, num_head, attn_sample_ratio)
for _ in range(num_layers)])
if i != 0:
attn = MultiheadAttention(
embed_dim,
num_head,
kdim=self.pyramid_embed_dims[i-1],
vdim=self.pyramid_embed_dims[i-1],
dropout=args.attention_dropout,
encoder_decoder_attention=True,
block_attn = None
if self.pyramid_block_attn:
if i != 0:
block_attn = BlockFuse(embed_dim, self.pyramid_embed_dims[i-1],
args.dropout, num_head, self.pyramid_fuse_way)
if self.use_ppm:
ppm_layer_norm = LayerNorm(embed_dim)
ppm = nn.Sequential(
nn.Conv1d(embed_dim, self.embed_dim, kernel_size=1, bias=False),
nn.BatchNorm1d(self.embed_dim),
nn.ReLU(),
)
attn_layer_norm = LayerNorm(embed_dim)
else:
attn = None
attn_layer_norm = None
ppm_layer_norm = None
ppm = None
setattr(self, f"reduced_embed{i + 1}", reduced_embed)
setattr(self, f"pos_embed{i + 1}", pos_embed)
setattr(self, f"dropout{i + 1}", dropout)
setattr(self, f"block{i + 1}", block)
setattr(self, f"attn{i + 1}", attn)
setattr(self, f"attn_layer_norm{i + 1}", attn_layer_norm)
setattr(self, f"block_attn{i + 1}", block_attn)
setattr(self, f"ppm{i + 1}", ppm)
setattr(self, f"ppm_layer_norm{i + 1}", ppm_layer_norm)
if i == self.pyramid_stages - 1:
if args.encoder_normalize_before:
@@ -292,6 +349,10 @@ class PyS2TTransformerEncoder(FairseqEncoder):
else:
self.layer_norm = None
if self.use_ppm:
self.ppm_weight = nn.Parameter(torch.Tensor(self.pyramid_stages).fill_(1.0))
self.ppm_weight.data = self.ppm_weight.data / self.ppm_weight.data.sum(0, keepdim=True)
self.use_ctc = "sate" in args.arch or \
(("ctc" in getattr(args, "criterion", False)) and
(getattr(args, "ctc_weight", False) > 0))
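
A minimal sketch (outside the diff) of what the constructor builds for the PPM branch, assuming four stages with illustrative embedding widths: per-stage LayerNorm plus a 1x1 Conv1d -> BatchNorm1d -> ReLU projection into the final width, and a learnable stage-weight vector that is normalized to sum to one at initialization only (it is not re-normalized in forward).

import torch
import torch.nn as nn

# Assumed per-stage widths for illustration; the last entry plays the role of
# self.embed_dim (args.encoder_embed_dim) that every PPM head projects into.
stage_dims = [176, 208, 240, 256]
final_dim = stage_dims[-1]

ppm_layer_norms = nn.ModuleList(nn.LayerNorm(d) for d in stage_dims)
ppm_heads = nn.ModuleList(
    nn.Sequential(
        nn.Conv1d(d, final_dim, kernel_size=1, bias=False),
        nn.BatchNorm1d(final_dim),
        nn.ReLU(),
    )
    for d in stage_dims
)

# One weight per stage, uniform after the one-time normalization below; the
# weights remain free parameters afterwards.
ppm_weight = nn.Parameter(torch.Tensor(len(stage_dims)).fill_(1.0))
ppm_weight.data = ppm_weight.data / ppm_weight.data.sum(0, keepdim=True)
print(ppm_weight.data)  # tensor([0.2500, 0.2500, 0.2500, 0.2500])
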
@@ -347,30 +408,25 @@ class PyS2TTransformerEncoder(FairseqEncoder):
for i in range(self.pyramid_stages):
reduced_embed = getattr(self, f"reduced_embed{i + 1}")
pos_embed = getattr(self, f"pos_embed{i + 1}")
dropout = getattr(self, f"dropout{i + 1}")
block = getattr(self, f"block{i + 1}")
block_attn = getattr(self, f"attn{i + 1}")
attn_layer_norm = getattr(self, f"attn_layer_norm{i + 1}")
block_attn = getattr(self, f"block_attn{i + 1}")
if i == 0:
x = self.embed_scale * x
# if i == 0:
# x = self.embed_scale * x
# reduced embed
x, input_lengths, encoder_padding_mask = reduced_embed(x, input_lengths)
# max_lens = int(x.size(0))
# encoder_padding_mask = lengths_to_padding_mask_with_maxlen(input_lengths, max_lens)
# add the position encoding and dropout
if pos_embed:
positions = pos_embed(encoder_padding_mask).transpose(0, 1)
if self.attn_type != "rel_selfattn":
x += positions
if i == 0:
x = dropout(x)
positions = dropout(positions)
positions = self.dropout(positions)
else:
positions = None
if i == 0:
x = self.dropout(x)
for layer in block:
x = layer(x, encoder_padding_mask, pos_emb=positions)
layer_idx += 1
@@ -379,19 +435,28 @@ class PyS2TTransformerEncoder(FairseqEncoder):
prev_padding.append(encoder_padding_mask)
if block_attn is not None:
residual = x
x = attn_layer_norm(x)
x, attn = block_attn(
query=x,
key=prev_state[i-1],
value=prev_state[i-1],
key_padding_mask=prev_padding[i-1],
)
x = residual + x
x = block_attn(x, prev_state[-1], prev_padding[-1])
if self.use_ctc and self.inter_ctc and self.ctc_layer == layer_idx:
ctc_logit = self.ctc_layer_norm(x)
if self.use_ppm:
pool_state = [x]
seq_len, bsz, dim = x.size()
i = -1
for state in prev_state[:-1]:
i += 1
ppm = getattr(self, f"ppm{i + 1}")
ppm_layer_norm = getattr(self, f"ppm_layer_norm{i + 1}")
state = ppm_layer_norm(state)
state = state.permute(1, 2, 0)
state = nn.functional.adaptive_avg_pool1d(state, seq_len)
state = ppm(state)
state = state.permute(2, 0, 1)
pool_state.append(state)
x = (torch.stack(pool_state, dim=0) * self.ppm_weight.view(-1, 1, 1, 1)).sum(0)
if self.layer_norm is not None:
x = self.layer_norm(x)
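
A standalone sketch of the fusion step in the forward pass above (illustrative shapes, not part of the commit): each earlier stage's states are pooled along time to the final stage's length, projected to the final width, stacked with the final-stage output, and combined using the learnable stage weights.

import torch
import torch.nn as nn
import torch.nn.functional as F

# PPM-style fusion of the pyramid stages: pool earlier stages to the final
# length, project them to the final width, then take a weighted sum.
def ppm_fuse(x, prev_states, ppm_layer_norms, ppm_heads, ppm_weight):
    # x: (T, B, D) final-stage output; prev_states: list of (T_i, B, D_i)
    seq_len = x.size(0)
    pooled = [x]
    for state, norm, head in zip(prev_states, ppm_layer_norms, ppm_heads):
        state = norm(state)                            # (T_i, B, D_i)
        state = state.permute(1, 2, 0)                 # (B, D_i, T_i)
        state = F.adaptive_avg_pool1d(state, seq_len)  # (B, D_i, T)
        state = head(state)                            # (B, D, T)
        pooled.append(state.permute(2, 0, 1))          # (T, B, D)
    stacked = torch.stack(pooled, dim=0)               # (stages, T, B, D)
    return (stacked * ppm_weight.view(-1, 1, 1, 1)).sum(0)

if __name__ == "__main__":
    dims, lens, final_dim = [128, 192], [80, 40], 256
    prev = [torch.randn(t, 2, d) for t, d in zip(lens, dims)]
    x = torch.randn(20, 2, final_dim)
    norms = [nn.LayerNorm(d) for d in dims]
    heads = [nn.Sequential(nn.Conv1d(d, final_dim, 1, bias=False),
                           nn.BatchNorm1d(final_dim), nn.ReLU()) for d in dims]
    weight = torch.full((3,), 1.0 / 3)
    print(ppm_fuse(x, prev, norms, heads, weight).shape)  # torch.Size([20, 2, 256])
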
@@ -204,7 +204,7 @@ class ReducedMultiheadAttention(nn.Module):
q = self.q_proj(query)
if self.self_attention:
if self.sample_ratio > 1:
query_ = query.permute(1, 2, 0) # bsz, dim, seq_len
query_ = query.permute(1, 2, 0) # bsz, dim, seq_len:
query_ = self.sr(query_).permute(2, 0, 1) # seq_len, bsz, dim
query = self.norm(query_)
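
For context on the hunk above, a hedged sketch of the query reduction in ReducedMultiheadAttention; the strided Conv1d standing in for self.sr is an assumption, since its definition lies outside this diff.

import torch
import torch.nn as nn

# Assumed definition of self.sr: a strided Conv1d that shortens the time axis
# by sample_ratio before self-attention; the norm is a LayerNorm as above.
sample_ratio, dim = 2, 256
sr = nn.Conv1d(dim, dim, kernel_size=sample_ratio, stride=sample_ratio)
norm = nn.LayerNorm(dim)

query = torch.randn(50, 4, dim)       # (seq_len, bsz, dim)
query_ = query.permute(1, 2, 0)       # (bsz, dim, seq_len)
query_ = sr(query_).permute(2, 0, 1)  # (seq_len // sample_ratio, bsz, dim)
query = norm(query_)
print(query.shape)                    # torch.Size([25, 4, 256])
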