Commit 4b4603c0 by xuchen

Optimize the code of the pyramid Transformer

parent 80e64569
@@ -50,64 +50,81 @@ class ReducedEmbed(nn.Module):
         self.stride = stride
         self.reduced_way = reduced_way
         if self.reduced_way == "conv":
-            self.conv = nn.Conv1d(in_channels, out_channels, kernel_sizes, stride=stride, padding=padding)
+            self.conv = nn.Sequential(
+                nn.Conv1d(in_channels, out_channels, kernel_sizes, stride=stride, padding=padding),
+            )
         elif self.reduced_way == "glu":
-            self.conv = nn.Conv1d(in_channels, out_channels * 2, kernel_sizes, stride=stride, padding=padding)
-            self.glu = nn.GLU(dim=1)
+            self.conv = nn.Sequential(
+                nn.Conv1d(in_channels, out_channels * 2, kernel_sizes, stride=stride, padding=padding),
+                nn.GLU(dim=1)
+            )
         elif self.reduced_way == "proj":
-            self.conv = nn.Conv1d(in_channels, out_channels, kernel_sizes, padding=padding)
+            self.conv = nn.Sequential(
+                nn.Conv1d(in_channels, out_channels, kernel_sizes, stride=stride, padding=padding),
+                nn.ReLU()
+            )
         elif self.reduced_way == "fuse":
-            self.conv = nn.Conv1d(in_channels, out_channels, kernel_sizes, stride=stride, padding=padding)
-            self.conv_proj = nn.Conv1d(in_channels, out_channels, kernel_sizes, padding=padding)
+            self.conv = nn.Sequential(
+                nn.Conv1d(in_channels, out_channels, kernel_sizes, stride=stride, padding=padding),
+            )
+            self.pool_conv = nn.Sequential(
+                nn.Conv1d(in_channels, out_channels, kernel_size=1),
+            )
         else:
             logger.error("Unsupported reduced way!")
 
         self.embed_norm = embed_norm
         if self.embed_norm:
-            if self.reduced_way in ["proj", "fuse"]:
-                self.in_norm = nn.BatchNorm1d(in_channels)
+            # if self.reduced_way == "fuse":
+            #     self.in_norm = LayerNorm(in_channels)
             self.norm = LayerNorm(out_channels)
 
     def forward(self, x, lengths):
         seq_len, bsz, dim = x.size()
-        assert seq_len % self.stride == 0, "The sequence length %d must be a multiple of %d." % (seq_len, self.stride)
+        # assert seq_len % self.stride == 0, "The sequence length %d must be a multiple of %d." % (seq_len, self.stride)
 
-        padding_mask = lengths_to_padding_mask_with_maxlen(lengths, seq_len)  # bsz, seq_len
-        mask_pad = padding_mask.unsqueeze(2)
         # mask batch padding
-        if mask_pad is not None:
-            x = x.transpose(0, 1)
-            x.masked_fill_(mask_pad, 0.0)
-            x = x.transpose(0, 1)
+        if not torch.all(lengths == seq_len):
+            padding_mask = lengths_to_padding_mask_with_maxlen(lengths, seq_len)  # bsz, seq_len
+            mask_pad = padding_mask.unsqueeze(2)
+            if mask_pad is not None:
+                x = x.transpose(0, 1)
+                x.masked_fill_(mask_pad, 0.0)
+                x = x.transpose(0, 1)
 
+        lengths = ((lengths.float() - 1) / self.stride + 1).floor().long()
+        out_seq_len = max(lengths).item()
         if self.reduced_way == "proj":
             x = x.permute(1, 2, 0)  # bsz, dim, seq_len
-            x = nn.functional.adaptive_avg_pool1d(x, int(seq_len // self.stride))
+            x = nn.functional.adaptive_avg_pool1d(x, out_seq_len)
             x = self.conv(self.in_norm(x))
             x = x.permute(2, 0, 1)  # seq_len, bsz, dim
         else:
+            # if self.reduced_way == "fuse":
+            #     x = self.in_norm(x)
             x = x.permute(1, 2, 0)  # B * D * T
             origin_x = x
             x = self.conv(x)
             if self.reduced_way == "glu":
                 x = self.glu(x)
             if self.reduced_way == "fuse":
-                x2 = nn.functional.adaptive_avg_pool1d(origin_x, int(seq_len // self.stride))
-                x2 = self.conv_proj(self.in_norm(x2))
+                x2 = nn.functional.adaptive_avg_pool1d(origin_x, x.size(-1))
+                x2 = self.pool_conv(x2)
                 x = x + x2
             x = x.permute(2, 0, 1)  # T * B * D
 
         if self.embed_norm:
             x = self.norm(x)
 
-        lengths = lengths / self.stride
+        # assert max(lengths) == x.size(0)
         padding_mask = lengths_to_padding_mask_with_maxlen(lengths, x.size(0))
-        mask_pad = padding_mask.unsqueeze(2)
         # mask batch padding
-        if mask_pad is not None:
-            x = x.transpose(0, 1)
-            x.masked_fill_(mask_pad, 0.0)
-            x = x.transpose(0, 1)
+        if not torch.all(lengths == x.size(-1)):
+            mask_pad = padding_mask.unsqueeze(2)
+            if mask_pad is not None:
+                x = x.transpose(0, 1)
+                x.masked_fill_(mask_pad, 0.0)
+                x = x.transpose(0, 1)
 
         return x, lengths, padding_mask
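
Note on the length update in this hunk: the old integer division lengths / self.stride is replaced by ((lengths.float() - 1) / self.stride + 1).floor().long(), i.e. floor((L - 1) / s) + 1, which equals ceil(L / s) for positive integer lengths. That is why the divisibility assert at the top of forward could be commented out. A minimal sketch of the identity (the helper name is illustrative, not part of the commit):

import math

import torch


def reduced_lengths(lengths: torch.Tensor, stride: int) -> torch.Tensor:
    # Same formula as the commit: floor((L - 1) / s) + 1 == ceil(L / s) for L >= 1.
    return ((lengths.float() - 1) / stride + 1).floor().long()


lens = torch.tensor([7, 8, 9])
print(reduced_lengths(lens, 2).tolist())        # [4, 4, 5]
print([math.ceil(n / 2) for n in (7, 8, 9)])    # [4, 4, 5]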
@@ -119,10 +136,13 @@ class BlockFuse(nn.Module):
         self.conv = nn.Sequential(
             nn.Conv1d(prev_embed_dim, embed_dim, kernel_size=1),
+            nn.BatchNorm1d(embed_dim),
             nn.ReLU()
         )
-        self.layer_norm = LayerNorm(embed_dim)
-        self.kv_layer_norm = LayerNorm(prev_embed_dim)
+        self.pre_layer_norm = LayerNorm(prev_embed_dim)
+        self.post_layer_norm = LayerNorm(embed_dim)
+        self.final_layer_norm = LayerNorm(embed_dim)
 
         self.fuse_way = fuse_way
         self.final_stage = final_stage
@@ -133,25 +153,29 @@ class BlockFuse(nn.Module):
     def forward(self, x, state, padding):
         seq_len, bsz, dim = x.size()
 
-        state = self.kv_layer_norm(state)
+        state = self.pre_layer_norm(state)
         state = state.permute(1, 2, 0)  # bsz, dim, seq_len
         if state.size(-1) != seq_len:
             state = nn.functional.adaptive_avg_pool1d(state, seq_len)
         state = self.conv(state)
         state = state.permute(2, 0, 1)  # seq_len, bsz, dim
+        state = self.post_layer_norm(state)
 
         if self.fuse_way == "gated":
-            # x = x.contiguous().view(-1, dim)
-            # state = state.contiguous().view(-1, dim)
-            # x = self.gate(x, state).view(seq_len, bsz, dim)
             coef = (self.gate_linear(torch.cat([x, state], dim=-1))).sigmoid()
             x = coef * x + (1 - coef) * state
         else:
             x = x + state
 
         # if not self.final_stage:
-        x = self.layer_norm(x)
+        x = self.final_layer_norm(x)
+
+        mask_pad = padding.unsqueeze(2)
+        # mask batch padding
+        if mask_pad is not None:
+            x = x.transpose(0, 1)
+            x.masked_fill_(mask_pad, 0.0)
+            x = x.transpose(0, 1)
 
         return x
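
For reference, the "gated" branch above mixes the current stage's features with the projected lower-stage state through a learned sigmoid gate. A minimal self-contained sketch of that fusion; the 2 * dim -> dim shape of gate_linear is an assumption, since its definition lies outside this hunk:

import torch
import torch.nn as nn


class GatedFuse(nn.Module):
    # Sketch of the "gated" path in BlockFuse.forward; gate_linear's
    # input/output sizes are assumed, not shown in the diff.
    def __init__(self, dim: int):
        super().__init__()
        self.gate_linear = nn.Linear(2 * dim, dim)

    def forward(self, x: torch.Tensor, state: torch.Tensor) -> torch.Tensor:
        # coef in (0, 1): how much of the current branch to keep vs. the fused state
        coef = self.gate_linear(torch.cat([x, state], dim=-1)).sigmoid()
        return coef * x + (1 - coef) * state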
@@ -434,7 +458,7 @@ class PyS2TTransformerEncoder(FairseqEncoder):
             # add the position encoding and dropout
             if block_fuse is not None:
-                x = block_fuse(x, prev_state[-1], prev_padding[-1])
+                x = block_fuse(x, prev_state[-1], encoder_padding_mask)
 
             if pos_embed:
                 positions = pos_embed(encoder_padding_mask).transpose(0, 1)