Commit 80e64569 by xuchen

update the block fuse in the pyramid transformer

parent 9fadf1f4
@@ -118,7 +118,7 @@ class BlockFuse(nn.Module):
         super().__init__()
         self.conv = nn.Sequential(
-            nn.Conv1d(prev_embed_dim, embed_dim, kernel_size=1, bias=False),
+            nn.Conv1d(prev_embed_dim, embed_dim, kernel_size=1),
             nn.ReLU()
         )
         self.layer_norm = LayerNorm(embed_dim)
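For context, a minimal sketch of the 1x1 convolutional projection that BlockFuse applies: it maps the previous stage's features from prev_embed_dim to embed_dim channels, now with the conv bias enabled (the commit drops bias=False). The dimensions, the permutes, and the surrounding code are assumptions for illustration, not the repository's exact forward pass; Conv1d simply requires a (batch, channels, time) layout.

import torch
import torch.nn as nn

prev_embed_dim, embed_dim = 256, 512        # illustrative sizes
proj = nn.Sequential(
    nn.Conv1d(prev_embed_dim, embed_dim, kernel_size=1),  # bias enabled, as in this commit
    nn.ReLU(),
)

x = torch.randn(10, 4, prev_embed_dim)      # (time, batch, dim), the usual fairseq layout
x = x.permute(1, 2, 0)                      # -> (batch, dim, time) for Conv1d
x = proj(x)
x = x.permute(2, 0, 1)                      # -> (time, batch, embed_dim)
print(x.shape)                              # torch.Size([10, 4, 512])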
@@ -146,7 +146,6 @@ class BlockFuse(nn.Module):
             # x = self.gate(x, state).view(seq_len, bsz, dim)
             coef = (self.gate_linear(torch.cat([x, state], dim=-1))).sigmoid()
             x = coef * x + (1 - coef) * state
-            x = state + x
         else:
             x = x + state
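The gated fusion this hunk touches is a convex combination of the current features x and the fused state from earlier blocks; the removed line had added state a second time on top of the already-gated output. A self-contained sketch of the gating step, where the shapes and the gate_linear layer definition are assumptions:

import torch
import torch.nn as nn

seq_len, bsz, dim = 10, 4, 512
gate_linear = nn.Linear(2 * dim, dim)       # assumed shape of the gating projection

x = torch.randn(seq_len, bsz, dim)          # current block output
state = torch.randn(seq_len, bsz, dim)      # fused state from previous blocks

# Learned gate in (0, 1) decides, per element, how much of x vs. state to keep.
coef = gate_linear(torch.cat([x, state], dim=-1)).sigmoid()
x = coef * x + (1 - coef) * state           # the extra "x = state + x" residual is dropped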
@@ -345,7 +344,7 @@ class PyS2TTransformerEncoder(FairseqEncoder):
             ppm_pre_layer_norm = LayerNorm(embed_dim)
             ppm_post_layer_norm = LayerNorm(self.embed_dim)
             ppm = nn.Sequential(
-                nn.Conv1d(embed_dim, self.embed_dim, kernel_size=1, bias=False),
+                nn.Conv1d(embed_dim, self.embed_dim, kernel_size=1),
                 nn.BatchNorm1d(self.embed_dim),
                 nn.ReLU(),
             )
@@ -361,7 +360,7 @@ class PyS2TTransformerEncoder(FairseqEncoder):
             setattr(self, f"block_fuse{i + 1}", block_fuse)
             setattr(self, f"ppm{i + 1}", ppm)
             setattr(self, f"ppm_pre_layer_norm{i + 1}", ppm_pre_layer_norm)
-            setattr(self, f"ppm_layer_norm2{i + 1}", ppm_post_layer_norm)
+            setattr(self, f"ppm_post_layer_norm{i + 1}", ppm_post_layer_norm)
         if args.encoder_normalize_before:
             self.layer_norm = LayerNorm(self.embed_dim)
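Because the per-stage modules are registered with setattr under f-string names, the forward pass has to look them up with matching getattr calls, so the rename from ppm_layer_norm2 to ppm_post_layer_norm must be applied on both sides. A minimal sketch of that registration pattern; the StageRegistry class, its stage count, and its forward logic are hypothetical and only illustrate the naming convention:

import torch.nn as nn

class StageRegistry(nn.Module):             # hypothetical module, for illustration only
    def __init__(self, num_stages, dim):
        super().__init__()
        for i in range(num_stages):
            # setattr with an nn.Module value registers each stage's layer as a submodule.
            setattr(self, f"ppm_post_layer_norm{i + 1}", nn.LayerNorm(dim))

    def forward(self, x, stage):
        # Retrieval must use exactly the attribute name that was registered above.
        layer_norm = getattr(self, f"ppm_post_layer_norm{stage + 1}")
        return layer_norm(x)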