Commit 80e64569 by xuchen

update the block fuse module of the pyramid transformer

parent 9fadf1f4
@@ -118,7 +118,7 @@ class BlockFuse(nn.Module):
         super().__init__()
         self.conv = nn.Sequential(
-            nn.Conv1d(prev_embed_dim, embed_dim, kernel_size=1, bias=False),
+            nn.Conv1d(prev_embed_dim, embed_dim, kernel_size=1),
             nn.ReLU()
         )
         self.layer_norm = LayerNorm(embed_dim)
@@ -146,7 +146,6 @@ class BlockFuse(nn.Module):
             # x = self.gate(x, state).view(seq_len, bsz, dim)
             coef = (self.gate_linear(torch.cat([x, state], dim=-1))).sigmoid()
             x = coef * x + (1 - coef) * state
-            x = state + x
         else:
             x = x + state
@@ -345,7 +344,7 @@ class PyS2TTransformerEncoder(FairseqEncoder):
             ppm_pre_layer_norm = LayerNorm(embed_dim)
             ppm_post_layer_norm = LayerNorm(self.embed_dim)
             ppm = nn.Sequential(
-                nn.Conv1d(embed_dim, self.embed_dim, kernel_size=1, bias=False),
+                nn.Conv1d(embed_dim, self.embed_dim, kernel_size=1),
                 nn.BatchNorm1d(self.embed_dim),
                 nn.ReLU(),
             )
@@ -361,7 +360,7 @@ class PyS2TTransformerEncoder(FairseqEncoder):
             setattr(self, f"block_fuse{i + 1}", block_fuse)
             setattr(self, f"ppm{i + 1}", ppm)
             setattr(self, f"ppm_pre_layer_norm{i + 1}", ppm_pre_layer_norm)
-            setattr(self, f"ppm_layer_norm2{i + 1}", ppm_post_layer_norm)
+            setattr(self, f"ppm_post_layer_norm{i + 1}", ppm_post_layer_norm)
         if args.encoder_normalize_before:
             self.layer_norm = LayerNorm(self.embed_dim)
...
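For reference, a minimal standalone sketch of the gated fusion path this commit simplifies. The class name GatedFuse, the 2 * dim -> dim shape of gate_linear, and the (seq_len, batch, dim) tensor layout are assumptions for illustration only; the real BlockFuse module also applies the 1x1 conv projection and LayerNorm shown in the first hunk.

import torch
import torch.nn as nn

class GatedFuse(nn.Module):
    """Hypothetical minimal module illustrating the gated fusion in BlockFuse.

    Only the gate itself is shown; the conv projection and LayerNorm from the
    actual module are omitted.
    """

    def __init__(self, dim):
        super().__init__()
        # Assumption: gate_linear maps the concatenated [x, state] (2 * dim)
        # back to dim, matching torch.cat([x, state], dim=-1) in the diff.
        self.gate_linear = nn.Linear(2 * dim, dim)

    def forward(self, x, state):
        # x, state: (seq_len, batch, dim) -- assumed fairseq encoder layout
        coef = self.gate_linear(torch.cat([x, state], dim=-1)).sigmoid()
        # Convex combination of the two streams; after this commit the extra
        # residual "x = state + x" that followed this line is removed.
        return coef * x + (1 - coef) * state

# Usage sketch
x = torch.randn(10, 2, 512)
state = torch.randn(10, 2, 512)
fused = GatedFuse(512)(x, state)   # shape: (10, 2, 512)

With the duplicate residual gone, the sigmoid gate alone decides how much of the lower block's state is mixed into the current representation.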