Commit 31d0303e by xuchen

support the ppm for pyramid transformer

parent c9d8dbc3
......@@ -204,7 +204,7 @@ class ReducedMultiheadAttention(nn.Module):
q = self.q_proj(query)
if self.self_attention:
if self.sample_ratio > 1:
query_ = query.permute(1, 2, 0) # bsz, dim, seq_len
query_ = query.permute(1, 2, 0) # bsz, dim, seq_len:
query_ = self.sr(query_).permute(2, 0, 1) # seq_len, bsz, dim
query = self.norm(query_)
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论