Commit 31d0303e by xuchen

Support the PPM (pyramid pooling module) for the pyramid Transformer; also add BlockFuse cross-stage attention and a "proj" option for the reduced embedding.

parent c9d8dbc3
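In this commit, the --pyramid-use-ppm path pools the hidden states of the earlier encoder stages to the length of the final stage, projects each of them to the final embedding dimension with a 1x1 Conv1d + BatchNorm1d + ReLU, and sums all stages with a learned per-stage weight. Before the diff itself, here is a minimal, self-contained sketch of that fusion for orientation only; the class and argument names (PPMFusionSketch, stage_dims, out_dim, prev_states) are illustrative and not part of the committed code, and the weight initialization is simplified to a uniform value.

import torch
import torch.nn as nn
import torch.nn.functional as F

class PPMFusionSketch(nn.Module):
    """Illustrative re-implementation of the PPM fusion added in this commit (not the committed module)."""

    def __init__(self, stage_dims, out_dim):
        super().__init__()
        # One LayerNorm + 1x1 conv projection per earlier stage, mapping its width to out_dim.
        self.norms = nn.ModuleList([nn.LayerNorm(d) for d in stage_dims])
        self.ppms = nn.ModuleList([
            nn.Sequential(
                nn.Conv1d(d, out_dim, kernel_size=1, bias=False),
                nn.BatchNorm1d(out_dim),
                nn.ReLU(),
            )
            for d in stage_dims
        ])
        # Learned per-stage weights (earlier stages + final output), initialized uniformly.
        num_stages = len(stage_dims) + 1
        self.weight = nn.Parameter(torch.full((num_stages,), 1.0 / num_stages))

    def forward(self, x, prev_states):
        # x: (T, B, out_dim) final-stage output; prev_states[i]: (T_i, B, stage_dims[i])
        seq_len = x.size(0)
        pooled = [x]
        for state, norm, ppm in zip(prev_states, self.norms, self.ppms):
            state = norm(state).permute(1, 2, 0)           # B x D_i x T_i
            state = F.adaptive_avg_pool1d(state, seq_len)   # pool every stage to the final length
            state = ppm(state).permute(2, 0, 1)             # back to T x B x out_dim
            pooled.append(state)
        # Weighted sum over all stages.
        return (torch.stack(pooled, dim=0) * self.weight.view(-1, 1, 1, 1)).sum(0)

For example, with stage_dims=[128, 192, 256] and out_dim=384, prev_states holding the three earlier stage outputs and x the final-stage output, the module returns a tensor with the same shape as x.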
@@ -54,19 +54,17 @@ class ReducedEmbed(nn.Module):
         elif self.reduced_way == "glu":
             self.conv = nn.Conv1d(in_channels, out_channels * 2, kernel_sizes, stride=stride, padding=padding)
             self.glu = nn.GLU(dim=1)
+        elif self.reduced_way == "proj":
+            self.proj = nn.Linear(2 * in_channels, out_channels, bias=False)
         else:
             logger.error("Unsupported reduced way!")
 
         self.embed_norm = embed_norm
         if self.embed_norm:
-            # self.norm = LayerNorm(out_channels)
-            self.norm = LayerNorm(in_channels)
-
-        if out_channels % in_channels == 0:
-            self.residual = True
-        else:
-            self.residual = False
-        self.residual = False
+            if self.reduced_way == "proj":
+                self.norm = LayerNorm(2 * in_channels)
+            else:
+                self.norm = LayerNorm(out_channels)
 
     def forward(self, x, lengths):
         seq_len, bsz, dim = x.size()
@@ -80,16 +78,18 @@ class ReducedEmbed(nn.Module):
             x.masked_fill_(mask_pad, 0.0)
             x = x.transpose(0, 1)
 
-        if self.residual:
-            origin_x = x.transpose(0, 1).contiguous().view(bsz, int(seq_len / self.stride), -1).transpose(0, 1)
-
-        if self.embed_norm:
-            x = self.norm(x)
-        x = x.permute(1, 2, 0)  # B * D * T
-        x = self.conv(x)
-        if self.reduced_way == "glu":
-            x = self.glu(x)
-        x = x.permute(2, 0, 1)  # T * B * D
+        if self.reduced_way == "proj":
+            x = x.transpose(0, 1).contiguous().view(bsz, int(seq_len / 2), -1)
+            x = self.proj(self.norm(x))
+            x = x.transpose(0, 1)
+        else:
+            x = x.permute(1, 2, 0)  # B * D * T
+            x = self.conv(x)
+            if self.reduced_way == "glu":
+                x = self.glu(x)
+            x = x.permute(2, 0, 1)  # T * B * D
+            if self.embed_norm:
+                x = self.norm(x)
 
         lengths = lengths / self.stride
@@ -101,15 +101,49 @@ class ReducedEmbed(nn.Module):
             x.masked_fill_(mask_pad, 0.0)
             x = x.transpose(0, 1)
 
-        if self.residual:
-            if x.size() == origin_x.size():
-                x += origin_x
-            else:
-                logging.error("The size is unmatched {} and {}".format(x.size(), origin_x.size()))
-
         return x, lengths, padding_mask
 
 
+class BlockFuse(nn.Module):
+    def __init__(self, embed_dim, prev_embed_dim, dropout, num_head, fuse_way="add"):
+        super().__init__()
+
+        self.attn = MultiheadAttention(
+            embed_dim,
+            num_head,
+            kdim=prev_embed_dim,
+            vdim=prev_embed_dim,
+            dropout=dropout,
+            encoder_decoder_attention=True,
+        )
+        self.q_layer_norm = LayerNorm(embed_dim)
+        self.kv_layer_norm = LayerNorm(prev_embed_dim)
+
+        self.fuse_way = fuse_way
+        if self.fuse_way == "gated":
+            self.gate_linear = nn.Linear(2 * embed_dim, embed_dim)
+
+    def forward(self, x, state, padding):
+        residual = x
+        x = self.q_layer_norm(x)
+        state = self.kv_layer_norm(state)
+        x, attn = self.attn(
+            query=x,
+            key=state,
+            value=state,
+            key_padding_mask=padding,
+            static_kv=True,
+        )
+
+        if self.fuse_way == "add":
+            x = residual + x
+        elif self.fuse_way == "gated":
+            coef = (self.gate_linear(torch.cat([x, residual], dim=-1))).sigmoid()
+            x = coef * x + (1 - coef) * residual
+
+        return x
+
+
 @register_model("pys2t_transformer")
 class PYS2TTransformerModel(S2TTransformerModel):
     """Adapted Transformer model (https://arxiv.org/abs/1706.03762) for
@@ -148,7 +182,7 @@ class PYS2TTransformerModel(S2TTransformerModel):
         parser.add_argument(
             "--pyramid-reduced-embed",
             type=str,
-            choices=["glu", "conv"],
+            choices=["glu", "conv", "proj"],
             help="the reduced way of the embedding",
         )
         parser.add_argument(
@@ -157,6 +191,16 @@ class PYS2TTransformerModel(S2TTransformerModel):
             help="use layer norm in reduced embedding",
         )
         parser.add_argument(
+            "--pyramid-block-attn",
+            action="store_true",
+            help="use block attention",
+        )
+        parser.add_argument(
+            "--pyramid-fuse-way",
+            type=str,
+            help="fused way for block attention",
+        )
+        parser.add_argument(
             "--pyramid-position-embed",
             type=str,
             help="use the position embedding or not",
@@ -181,6 +225,11 @@ class PYS2TTransformerModel(S2TTransformerModel):
             type=str,
             help="the number of the attention heads",
         )
+        parser.add_argument(
+            "--pyramid-use-ppm",
+            action="store_true",
+            help="use ppm",
+        )
         parser.add_argument(
             "--ctc-layer",
@@ -211,7 +260,11 @@ class PyS2TTransformerEncoder(FairseqEncoder):
         self.padding_idx = 1
         self.attn_type = getattr(args, "encoder_attention_type", "selfattn")
+        self.embed_dim = args.encoder_embed_dim
+        self.dropout = FairseqDropout(
+            p=args.dropout, module_name=self.__class__.__name__
+        )
 
         self.pyramid_stages = getattr(args, "pyramid_stages", 4)
         self.pyramid_layers = [int(n) for n in args.pyramid_layers.split("_")]
@@ -225,6 +278,10 @@ class PyS2TTransformerEncoder(FairseqEncoder):
         self.pyramid_reduced_embed = args.pyramid_reduced_embed
         self.pyramid_embed_norm = args.pyramid_embed_norm
+        self.pyramid_block_attn = getattr(args, "pyramid_block_attn", False)
+        self.pyramid_fuse_way = getattr(args, "pyramid_fuse_way", "add")
+        self.use_ppm = getattr(args, "pyramid_use_ppm", False)
 
         for i in range(self.pyramid_stages):
             num_layers = self.pyramid_layers[i]
             sr_ratio = self.pyramid_sr_ratios[i]
@@ -242,7 +299,8 @@ class PyS2TTransformerEncoder(FairseqEncoder):
             reduced_embed = ReducedEmbed(
                 self.pyramid_reduced_embed,
-                self.pyramid_embed_norm if i != 0 else False,
+                self.pyramid_embed_norm,
+                # self.pyramid_embed_norm if i != 0 else False,
                 args.input_feat_per_channel * args.input_channels if i == 0 else self.pyramid_embed_dims[i-1],
                 embed_dim,
                 kernel_sizes=kernel_size,
@@ -257,34 +315,33 @@ class PyS2TTransformerEncoder(FairseqEncoder):
             else:
                 pos_embed = None
 
-            dropout = FairseqDropout(
-                p=args.dropout, module_name=self.__class__.__name__
-            )
-
             block = nn.ModuleList([
                 PyramidTransformerEncoderLayer(args, embed_dim, embed_dim * ffn_ratio, num_head, attn_sample_ratio)
                 for _ in range(num_layers)])
 
-            if i != 0:
-                attn = MultiheadAttention(
-                    embed_dim,
-                    num_head,
-                    kdim=self.pyramid_embed_dims[i-1],
-                    vdim=self.pyramid_embed_dims[i-1],
-                    dropout=args.attention_dropout,
-                    encoder_decoder_attention=True,
-                )
-                attn_layer_norm = LayerNorm(embed_dim)
-            else:
-                attn = None
-                attn_layer_norm = None
+            block_attn = None
+            if self.pyramid_block_attn:
+                if i != 0:
+                    block_attn = BlockFuse(embed_dim, self.pyramid_embed_dims[i-1],
+                                           args.dropout, num_head, self.pyramid_fuse_way)
+
+            if self.use_ppm:
+                ppm_layer_norm = LayerNorm(embed_dim)
+                ppm = nn.Sequential(
+                    nn.Conv1d(embed_dim, self.embed_dim, kernel_size=1, bias=False),
+                    nn.BatchNorm1d(self.embed_dim),
+                    nn.ReLU(),
+                )
+            else:
+                ppm_layer_norm = None
+                ppm = None
 
             setattr(self, f"reduced_embed{i + 1}", reduced_embed)
             setattr(self, f"pos_embed{i + 1}", pos_embed)
-            setattr(self, f"dropout{i + 1}", dropout)
             setattr(self, f"block{i + 1}", block)
-            setattr(self, f"attn{i + 1}", attn)
-            setattr(self, f"attn_layer_norm{i + 1}", attn_layer_norm)
+            setattr(self, f"block_attn{i + 1}", block_attn)
+            setattr(self, f"ppm{i + 1}", ppm)
+            setattr(self, f"ppm_layer_norm{i + 1}", ppm_layer_norm)
 
             if i == self.pyramid_stages - 1:
                 if args.encoder_normalize_before:
@@ -292,6 +349,10 @@ class PyS2TTransformerEncoder(FairseqEncoder):
                 else:
                     self.layer_norm = None
 
+        if self.use_ppm:
+            self.ppm_weight = nn.Parameter(torch.Tensor(self.pyramid_stages).fill_(1.0))
+            self.ppm_weight.data = self.ppm_weight.data / self.ppm_weight.data.sum(0, keepdim=True)
+
         self.use_ctc = "sate" in args.arch or \
                        (("ctc" in getattr(args, "criterion", False)) and
                         (getattr(args, "ctc_weight", False) > 0))
@@ -347,30 +408,25 @@ class PyS2TTransformerEncoder(FairseqEncoder):
         for i in range(self.pyramid_stages):
             reduced_embed = getattr(self, f"reduced_embed{i + 1}")
             pos_embed = getattr(self, f"pos_embed{i + 1}")
-            dropout = getattr(self, f"dropout{i + 1}")
             block = getattr(self, f"block{i + 1}")
-            block_attn = getattr(self, f"attn{i + 1}")
-            attn_layer_norm = getattr(self, f"attn_layer_norm{i + 1}")
+            block_attn = getattr(self, f"block_attn{i + 1}")
 
-            if i == 0:
-                x = self.embed_scale * x
+            # if i == 0:
+            #     x = self.embed_scale * x
 
+            # reduced embed
             x, input_lengths, encoder_padding_mask = reduced_embed(x, input_lengths)
-            # max_lens = int(x.size(0))
-            # encoder_padding_mask = lengths_to_padding_mask_with_maxlen(input_lengths, max_lens)
 
             # add the position encoding and dropout
             if pos_embed:
                 positions = pos_embed(encoder_padding_mask).transpose(0, 1)
                 if self.attn_type != "rel_selfattn":
                     x += positions
-                if i == 0:
-                    x = dropout(x)
-                positions = dropout(positions)
+                positions = self.dropout(positions)
             else:
                 positions = None
 
+            if i == 0:
+                x = self.dropout(x)
+
             for layer in block:
                 x = layer(x, encoder_padding_mask, pos_emb=positions)
                 layer_idx += 1
@@ -379,19 +435,28 @@ class PyS2TTransformerEncoder(FairseqEncoder):
             prev_padding.append(encoder_padding_mask)
 
             if block_attn is not None:
-                residual = x
-                x = attn_layer_norm(x)
-                x, attn = block_attn(
-                    query=x,
-                    key=prev_state[i-1],
-                    value=prev_state[i-1],
-                    key_padding_mask=prev_padding[i-1],
-                )
-                x = residual + x
+                x = block_attn(x, prev_state[-1], prev_padding[-1])
 
             if self.use_ctc and self.inter_ctc and self.ctc_layer == layer_idx:
                 ctc_logit = self.ctc_layer_norm(x)
 
+        if self.use_ppm:
+            pool_state = [x]
+            seq_len, bsz, dim = x.size()
+            i = -1
+            for state in prev_state[:-1]:
+                i += 1
+                ppm = getattr(self, f"ppm{i + 1}")
+                ppm_layer_norm = getattr(self, f"ppm_layer_norm{i + 1}")
+                state = ppm_layer_norm(state)
+                state = state.permute(1, 2, 0)
+                state = nn.functional.adaptive_avg_pool1d(state, seq_len)
+                state = ppm(state)
+                state = state.permute(2, 0, 1)
+                pool_state.append(state)
+            x = (torch.stack(pool_state, dim=0) * self.ppm_weight.view(-1, 1, 1, 1)).sum(0)
+
         if self.layer_norm is not None:
             x = self.layer_norm(x)
...
@@ -204,7 +204,7 @@ class ReducedMultiheadAttention(nn.Module):
         q = self.q_proj(query)
         if self.self_attention:
             if self.sample_ratio > 1:
                 query_ = query.permute(1, 2, 0)  # bsz, dim, seq_len
                 query_ = self.sr(query_).permute(2, 0, 1)  # seq_len, bsz, dim
                 query = self.norm(query_)
...
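A side note on the new BlockFuse module: when --pyramid-fuse-way is "gated", the cross-attention output is blended with its residual through a sigmoid gate instead of a plain addition. The sketch below isolates just that gating step in plain PyTorch; the helper name gated_fuse and the example dimensions are ours, not part of the commit.

import torch
import torch.nn as nn

def gated_fuse(x, residual, gate_linear):
    """Blend attention output x with its residual via a learned gate (cf. BlockFuse, fuse_way='gated')."""
    # gate_linear is an nn.Linear(2 * embed_dim, embed_dim); the gate is computed from both inputs.
    coef = torch.sigmoid(gate_linear(torch.cat([x, residual], dim=-1)))
    return coef * x + (1 - coef) * residual

# Example with the encoder's T x B x D layout (illustrative sizes).
embed_dim = 256
gate = nn.Linear(2 * embed_dim, embed_dim)
x = torch.randn(10, 2, embed_dim)
residual = torch.randn(10, 2, embed_dim)
out = gated_fuse(x, residual, gate)  # same shape as x: 10 x 2 x 256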