Commit 4b4603c0 by xuchen

optimize the code of pyramid transformer

parent 80e64569
@@ -50,64 +50,81 @@ class ReducedEmbed(nn.Module):
         self.stride = stride
         self.reduced_way = reduced_way
         if self.reduced_way == "conv":
-            self.conv = nn.Conv1d(in_channels, out_channels, kernel_sizes, stride=stride, padding=padding)
+            self.conv = nn.Sequential(
+                nn.Conv1d(in_channels, out_channels, kernel_sizes, stride=stride, padding=padding),
+            )
         elif self.reduced_way == "glu":
-            self.conv = nn.Conv1d(in_channels, out_channels * 2, kernel_sizes, stride=stride, padding=padding)
-            self.glu = nn.GLU(dim=1)
+            self.conv = nn.Sequential(
+                nn.Conv1d(in_channels, out_channels * 2, kernel_sizes, stride=stride, padding=padding),
+                nn.GLU(dim=1)
+            )
         elif self.reduced_way == "proj":
-            self.conv = nn.Conv1d(in_channels, out_channels, kernel_sizes, padding=padding)
+            self.conv = nn.Sequential(
+                nn.Conv1d(in_channels, out_channels, kernel_sizes, stride=stride, padding=padding),
+                nn.ReLU()
+            )
         elif self.reduced_way == "fuse":
-            self.conv = nn.Conv1d(in_channels, out_channels, kernel_sizes, stride=stride, padding=padding)
-            self.conv_proj = nn.Conv1d(in_channels, out_channels, kernel_sizes, padding=padding)
+            self.conv = nn.Sequential(
+                nn.Conv1d(in_channels, out_channels, kernel_sizes, stride=stride, padding=padding),
+            )
+            self.pool_conv = nn.Sequential(
+                nn.Conv1d(in_channels, out_channels, kernel_size=1),
+            )
         else:
             logger.error("Unsupported reduced way!")
 
         self.embed_norm = embed_norm
         if self.embed_norm:
             if self.reduced_way in ["proj", "fuse"]:
                 self.in_norm = nn.BatchNorm1d(in_channels)
             # if self.reduced_way == "fuse":
             #     self.in_norm = LayerNorm(in_channels)
             self.norm = LayerNorm(out_channels)
 
     def forward(self, x, lengths):
         seq_len, bsz, dim = x.size()
-        assert seq_len % self.stride == 0, "The sequence length %d must be a multiple of %d." % (seq_len, self.stride)
-        padding_mask = lengths_to_padding_mask_with_maxlen(lengths, seq_len) # bsz, seq_len
-        mask_pad = padding_mask.unsqueeze(2)
-        # mask batch padding
-        if mask_pad is not None:
-            x = x.transpose(0, 1)
-            x.masked_fill_(mask_pad, 0.0)
-            x = x.transpose(0, 1)
+        # assert seq_len % self.stride == 0, "The sequence length %d must be a multiple of %d." % (seq_len, self.stride)
+        if not torch.all(lengths == seq_len):
+            padding_mask = lengths_to_padding_mask_with_maxlen(lengths, seq_len) # bsz, seq_len
+            mask_pad = padding_mask.unsqueeze(2)
+            if mask_pad is not None:
+                x = x.transpose(0, 1)
+                x.masked_fill_(mask_pad, 0.0)
+                x = x.transpose(0, 1)
+        lengths = ((lengths.float() - 1) / self.stride + 1).floor().long()
+        out_seq_len = max(lengths).item()
         if self.reduced_way == "proj":
             x = x.permute(1, 2, 0) # bsz, dim, seq_len
-            x = nn.functional.adaptive_avg_pool1d(x, int(seq_len // self.stride))
+            x = nn.functional.adaptive_avg_pool1d(x, out_seq_len)
             x = self.conv(self.in_norm(x))
             x = x.permute(2, 0, 1) # seq_len, bsz, dim
         else:
             # if self.reduced_way == "fuse":
             #     x = self.in_norm(x)
             x = x.permute(1, 2, 0) # B * D * T
             origin_x = x
             x = self.conv(x)
-            if self.reduced_way == "glu":
-                x = self.glu(x)
             if self.reduced_way == "fuse":
-                x2 = nn.functional.adaptive_avg_pool1d(origin_x, int(seq_len // self.stride))
-                x2 = self.conv_proj(self.in_norm(x2))
+                x2 = nn.functional.adaptive_avg_pool1d(origin_x, x.size(-1))
+                x2 = self.pool_conv(x2)
                 x = x + x2
             x = x.permute(2, 0, 1) # T * B * D
         if self.embed_norm:
             x = self.norm(x)
-        lengths = lengths / self.stride
         # assert max(lengths) == x.size(0)
         padding_mask = lengths_to_padding_mask_with_maxlen(lengths, x.size(0))
-        mask_pad = padding_mask.unsqueeze(2)
-        # mask batch padding
-        if mask_pad is not None:
-            x = x.transpose(0, 1)
-            x.masked_fill_(mask_pad, 0.0)
-            x = x.transpose(0, 1)
+        if not torch.all(lengths == x.size(-1)):
+            mask_pad = padding_mask.unsqueeze(2)
+            if mask_pad is not None:
+                x = x.transpose(0, 1)
+                x.masked_fill_(mask_pad, 0.0)
+                x = x.transpose(0, 1)
         return x, lengths, padding_mask
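
The reduced output length now follows the strided-convolution formula instead of requiring the input length to be a multiple of the stride. A minimal standalone sketch of that behaviour, with an assumed re-implementation of lengths_to_padding_mask_with_maxlen (the real helper is defined outside this diff):

    # Standalone sketch, not part of the commit.
    import torch

    def reduced_lengths(lengths: torch.Tensor, stride: int) -> torch.Tensor:
        # Same formula as the diff: floor((L - 1) / stride) + 1, i.e. ceil(L / stride).
        return ((lengths.float() - 1) / stride + 1).floor().long()

    def lengths_to_padding_mask_with_maxlen(lengths: torch.Tensor, max_len: int) -> torch.Tensor:
        # Assumed semantics: bool mask of shape (bsz, max_len), True at padded steps.
        steps = torch.arange(max_len, device=lengths.device).unsqueeze(0)
        return steps >= lengths.unsqueeze(1)

    lengths = torch.tensor([10, 7, 4])
    new_lengths = reduced_lengths(lengths, stride=2)  # tensor([5, 4, 2])
    mask = lengths_to_padding_mask_with_maxlen(new_lengths, int(new_lengths.max().item()))
    print(new_lengths, mask.shape)  # torch.Size([3, 5])
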
@@ -119,10 +136,13 @@ class BlockFuse(nn.Module):
         self.conv = nn.Sequential(
             nn.Conv1d(prev_embed_dim, embed_dim, kernel_size=1),
             nn.BatchNorm1d(embed_dim),
             nn.ReLU()
         )
-        self.layer_norm = LayerNorm(embed_dim)
-        self.kv_layer_norm = LayerNorm(prev_embed_dim)
+        self.pre_layer_norm = LayerNorm(prev_embed_dim)
+        self.post_layer_norm = LayerNorm(embed_dim)
+        self.final_layer_norm = LayerNorm(embed_dim)
 
         self.fuse_way = fuse_way
         self.final_stage = final_stage
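
The conv block above maps the previous stage's embedding size to the current one; a kernel_size=1 Conv1d over a (bsz, dim, seq_len) tensor is simply a per-position channel projection. An illustration with made-up dimensions (not taken from the commit):

    # Illustration only: per-position projection with a 1x1 convolution.
    import torch
    import torch.nn as nn

    prev_embed_dim, embed_dim, bsz, seq_len = 32, 64, 2, 10
    proj = nn.Sequential(
        nn.Conv1d(prev_embed_dim, embed_dim, kernel_size=1),
        nn.BatchNorm1d(embed_dim),
        nn.ReLU(),
    )
    state = torch.randn(bsz, prev_embed_dim, seq_len)  # bsz, dim, seq_len
    print(proj(state).shape)  # torch.Size([2, 64, 10])
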
@@ -133,25 +153,29 @@
     def forward(self, x, state, padding):
         seq_len, bsz, dim = x.size()
-        state = self.kv_layer_norm(state)
+        state = self.pre_layer_norm(state)
         state = state.permute(1, 2, 0) # bsz, dim, seq_len
         if state.size(-1) != seq_len:
             state = nn.functional.adaptive_avg_pool1d(state, seq_len)
         state = self.conv(state)
         state = state.permute(2, 0, 1) # seq_len, bsz, dim
+        state = self.post_layer_norm(state)
         if self.fuse_way == "gated":
             # x = x.contiguous().view(-1, dim)
             # state = state.contiguous().view(-1, dim)
             # x = self.gate(x, state).view(seq_len, bsz, dim)
             coef = (self.gate_linear(torch.cat([x, state], dim=-1))).sigmoid()
             x = coef * x + (1 - coef) * state
         else:
             x = x + state
         # if not self.final_stage:
-        x = self.layer_norm(x)
+        x = self.final_layer_norm(x)
         mask_pad = padding.unsqueeze(2)
         # mask batch padding
         if mask_pad is not None:
             x = x.transpose(0, 1)
             x.masked_fill_(mask_pad, 0.0)
             x = x.transpose(0, 1)
         return x
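
In the "gated" branch, gate_linear produces a sigmoid coefficient that interpolates between the current block's output and the projected state from the earlier stage. A self-contained sketch, assuming gate_linear maps 2 * dim to dim (implied by the concatenation along the feature dimension; the layer itself is constructed outside this diff):

    # Sketch of the gated fusion; gate_linear's shape is an assumption.
    import torch
    import torch.nn as nn

    class GatedFuse(nn.Module):
        def __init__(self, dim: int):
            super().__init__()
            self.gate_linear = nn.Linear(2 * dim, dim)

        def forward(self, x: torch.Tensor, state: torch.Tensor) -> torch.Tensor:
            # coef in (0, 1) weights the current representation against the fused state.
            coef = torch.sigmoid(self.gate_linear(torch.cat([x, state], dim=-1)))
            return coef * x + (1 - coef) * state

    seq_len, bsz, dim = 8, 2, 16
    fuse = GatedFuse(dim)
    out = fuse(torch.randn(seq_len, bsz, dim), torch.randn(seq_len, bsz, dim))
    print(out.shape)  # torch.Size([8, 2, 16])
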
@@ -434,7 +458,7 @@ class PyS2TTransformerEncoder(FairseqEncoder):
             # add the position encoding and dropout
             if block_fuse is not None:
-                x = block_fuse(x, prev_state[-1], prev_padding[-1])
+                x = block_fuse(x, prev_state[-1], encoder_padding_mask)
 
             if pos_embed:
                 positions = pos_embed(encoder_padding_mask).transpose(0, 1)