Commit bd8fae9b by xuchen

fix the buf of the reduced SA

parent 7cb7e508
......@@ -86,7 +86,7 @@ class ReducedMultiheadAttention(nn.Module):
self.add_zero_attn = add_zero_attn
self.sample_ratio = sample_ratio
if self.sample_ratio > 1:
self.sr = nn.Conv2d(embed_dim, embed_dim, kernel_size=sample_ratio, stride=sample_ratio)
self.sr = nn.Conv1d(embed_dim, embed_dim, kernel_size=sample_ratio, stride=sample_ratio)
self.norm = nn.LayerNorm(embed_dim)
self.reset_parameters()
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论