Commit 9fadf1f4 by xuchen

Update the block fuse module of the pyramid transformer

parent f0605efa
@@ -5,6 +5,7 @@ pyramid-layers: 2_2_6_2
 #encoder-attention-type: reduced
 #pyramid-attn-sample-ratios: 8_4_2_1
 #pyramid-block-attn: True
+#pyramid-fuse-way: add
 pyramid-sr-ratios: 2_2_2_2
 pyramid-use-ppm: True
 pyramid-embed-dims: 128_128_256_512
...
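Note (for context, not part of the commit): the new `pyramid-fuse-way` option selects how BlockFuse merges the previous stage's output into the current stage ("add" or "gated" in the code further down). The underscore-separated pyramid options are read as one value per stage; a minimal sketch of that assumed parsing convention:

```python
# Minimal sketch (assumed parsing convention, not taken from this commit):
# the underscore-separated pyramid options give one value per encoder stage.
pyramid_layers = "2_2_6_2"
pyramid_embed_dims = "128_128_256_512"

layers = [int(v) for v in pyramid_layers.split("_")]          # [2, 2, 6, 2]
embed_dims = [int(v) for v in pyramid_embed_dims.split("_")]  # [128, 128, 256, 512]
assert len(layers) == len(embed_dims)  # one entry per pyramid stage
```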
@@ -2,16 +2,12 @@
 # training the model
-gpu_num=4
+gpu_num=8
 update_freq=1
-max_tokens=80000
+max_tokens=40000
-#exp_tag=valid_prev_state
-#exp_tag=lower128
-#exp_tag=sr8
+exp_tag=baseline
-#config_list=(base conformer rpr)
-config_list=(pyramid)
-#config_list=(pyramid_stage3 rpr)
+config_list=(base)
 # exp full name
 exp_name=
@@ -42,8 +38,7 @@ if [[ -n ${extra_tag} ]]; then
     cmd="$cmd --extra_tag ${extra_tag}"
 fi
 if [[ -n ${extra_parameter} ]]; then
-#    cmd="$cmd --extra_parameter \"${extra_parameter}\""
-    cmd="$cmd --extra_parameter ${extra_parameter}"
+    cmd="$cmd --extra_parameter \"${extra_parameter}\""
 fi
 echo ${cmd}
...
@@ -7,7 +7,7 @@ update_freq=1
 max_tokens=40000
 exp_tag=baseline
-config_list=(ctc local_attn)
+config_list=(ctc)
 # exp full name
 exp_name=
@@ -38,8 +38,7 @@ if [[ -n ${extra_tag} ]]; then
     cmd="$cmd --extra_tag ${extra_tag}"
 fi
 if [[ -n ${extra_parameter} ]]; then
-#    cmd="$cmd --extra_parameter \"${extra_parameter}\""
-    cmd="$cmd --extra_parameter ${extra_parameter}"
+    cmd="$cmd --extra_parameter \"${extra_parameter}\""
 fi
 echo ${cmd}
...
@@ -80,7 +80,6 @@ class ReducedEmbed(nn.Module):
         x.masked_fill_(mask_pad, 0.0)
         x = x.transpose(0, 1)

-        # x = self.in_norm(x)
        if self.reduced_way == "proj":
            x = x.permute(1, 2, 0)  # bsz, dim, seq_len
            x = nn.functional.adaptive_avg_pool1d(x, int(seq_len // self.stride))
@@ -115,40 +114,44 @@ class ReducedEmbed(nn.Module):
 class BlockFuse(nn.Module):

-    def __init__(self, embed_dim, prev_embed_dim, dropout, num_head, fuse_way="add"):
+    def __init__(self, embed_dim, prev_embed_dim, final_stage, fuse_way="add"):
         super().__init__()

-        self.attn = MultiheadAttention(
-            embed_dim,
-            num_head,
-            kdim=prev_embed_dim,
-            vdim=prev_embed_dim,
-            dropout=dropout,
-            encoder_decoder_attention=True,
+        self.conv = nn.Sequential(
+            nn.Conv1d(prev_embed_dim, embed_dim, kernel_size=1, bias=False),
+            nn.ReLU()
         )
-        self.q_layer_norm = LayerNorm(embed_dim)
+        self.layer_norm = LayerNorm(embed_dim)
         self.kv_layer_norm = LayerNorm(prev_embed_dim)

         self.fuse_way = fuse_way
+        self.final_stage = final_stage

         if self.fuse_way == "gated":
             self.gate_linear = nn.Linear(2 * embed_dim, embed_dim)
+            # self.gate = nn.GRUCell(embed_dim, embed_dim)

     def forward(self, x, state, padding):
-        residual = x
-        x = self.q_layer_norm(x)
+        seq_len, bsz, dim = x.size()

         state = self.kv_layer_norm(state)
-        x, attn = self.attn(
-            query=x,
-            key=state,
-            value=state,
-            key_padding_mask=padding,
-            static_kv=True,
-        )
+        state = state.permute(1, 2, 0)  # bsz, dim, seq_len
+        if state.size(-1) != seq_len:
+            state = nn.functional.adaptive_avg_pool1d(state, seq_len)
+        state = self.conv(state)
+        state = state.permute(2, 0, 1)  # seq_len, bsz, dim

-        if self.fuse_way == "add":
-            x = residual + x
-        elif self.fuse_way == "gated":
-            coef = (self.gate_linear(torch.cat([x, residual], dim=-1))).sigmoid()
-            x = coef * x + (1 - coef) * residual
+        if self.fuse_way == "gated":
+            # x = x.contiguous().view(-1, dim)
+            # state = state.contiguous().view(-1, dim)
+            # x = self.gate(x, state).view(seq_len, bsz, dim)
+            coef = (self.gate_linear(torch.cat([x, state], dim=-1))).sigmoid()
+            x = coef * x + (1 - coef) * state
+            x = state + x
+        else:
+            x = x + state

+        # if not self.final_stage:
+        x = self.layer_norm(x)
         return x
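For reference, a self-contained sketch of the reworked BlockFuse above: the previous stage's output is layer-normalized, length-matched with adaptive average pooling, projected by a 1x1 convolution with ReLU, fused into the current stage by addition or a sigmoid gate, and layer-normalized again. This is a simplified illustration (the padding argument, the final_stage flag, and the extra residual in the gated branch are omitted), not the repository code:

```python
# Simplified re-implementation for illustration only; names follow the diff.
import torch
import torch.nn as nn


class BlockFuseSketch(nn.Module):
    def __init__(self, embed_dim, prev_embed_dim, fuse_way="add"):
        super().__init__()
        # 1x1 conv projects the previous stage's width to the current one
        self.conv = nn.Sequential(
            nn.Conv1d(prev_embed_dim, embed_dim, kernel_size=1, bias=False),
            nn.ReLU(),
        )
        self.kv_layer_norm = nn.LayerNorm(prev_embed_dim)
        self.layer_norm = nn.LayerNorm(embed_dim)
        self.fuse_way = fuse_way
        if fuse_way == "gated":
            self.gate_linear = nn.Linear(2 * embed_dim, embed_dim)

    def forward(self, x, state):
        # x:     (seq_len, bsz, embed_dim)       current stage input
        # state: (prev_len, bsz, prev_embed_dim) previous stage output
        seq_len = x.size(0)
        state = self.kv_layer_norm(state)
        state = state.permute(1, 2, 0)                    # bsz, dim, prev_len
        if state.size(-1) != seq_len:                     # match time resolution
            state = nn.functional.adaptive_avg_pool1d(state, seq_len)
        state = self.conv(state)                          # bsz, embed_dim, seq_len
        state = state.permute(2, 0, 1)                    # seq_len, bsz, embed_dim
        if self.fuse_way == "gated":
            coef = torch.sigmoid(self.gate_linear(torch.cat([x, state], dim=-1)))
            x = coef * x + (1 - coef) * state
        else:                                             # "add"
            x = x + state
        return self.layer_norm(x)


# Usage: fuse a longer, lower-dimensional stage-1 output into stage 2.
fuse = BlockFuseSketch(embed_dim=256, prev_embed_dim=128, fuse_way="gated")
x = torch.randn(50, 4, 256)      # seq_len=50, batch=4
prev = torch.randn(100, 4, 128)  # previous stage is twice as long
out = fuse(x, prev)
print(out.shape)                 # torch.Size([50, 4, 256])
```

Compared with the previous multi-head cross-attention fusion, this pooling-plus-1x1-conv variant computes no attention weights and keeps the fusion cost linear in sequence length.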
@@ -331,15 +334,16 @@ class PyS2TTransformerEncoder(FairseqEncoder):
                 PyramidTransformerEncoderLayer(args, embed_dim, embed_dim * ffn_ratio, num_head, attn_sample_ratio)
                 for _ in range(num_layers)])

-            block_attn = None
+            block_fuse = None
             if self.pyramid_block_attn:
                 if i != 0:
-                    block_attn = BlockFuse(embed_dim, self.pyramid_embed_dims[i-1],
-                                           args.dropout, num_head, self.pyramid_fuse_way)
+                    block_fuse = BlockFuse(embed_dim, self.pyramid_embed_dims[i-1],
+                                           final_stage=True if i == self.pyramid_stages - 1 else False,
+                                           fuse_way=self.pyramid_fuse_way)

             if self.use_ppm:
-                ppm_layer_norm = LayerNorm(embed_dim)
-                ppm_layer_norm2 = LayerNorm(self.embed_dim)
+                ppm_pre_layer_norm = LayerNorm(embed_dim)
+                ppm_post_layer_norm = LayerNorm(self.embed_dim)
                 ppm = nn.Sequential(
                     nn.Conv1d(embed_dim, self.embed_dim, kernel_size=1, bias=False),
                     nn.BatchNorm1d(self.embed_dim),
@@ -347,17 +351,17 @@ class PyS2TTransformerEncoder(FairseqEncoder):
                 )
             else:
-                ppm_layer_norm = None
-                ppm_layer_norm2 = None
+                ppm_pre_layer_norm = None
+                ppm_post_layer_norm = None
                 ppm = None

             setattr(self, f"reduced_embed{i + 1}", reduced_embed)
             setattr(self, f"pos_embed{i + 1}", pos_embed)
             setattr(self, f"block{i + 1}", block)
-            setattr(self, f"block_attn{i + 1}", block_attn)
+            setattr(self, f"block_fuse{i + 1}", block_fuse)
             setattr(self, f"ppm{i + 1}", ppm)
-            setattr(self, f"ppm_layer_norm{i + 1}", ppm_layer_norm)
-            setattr(self, f"ppm_layer_norm2{i + 1}", ppm_layer_norm2)
+            setattr(self, f"ppm_pre_layer_norm{i + 1}", ppm_pre_layer_norm)
+            setattr(self, f"ppm_layer_norm2{i + 1}", ppm_post_layer_norm)

         if args.encoder_normalize_before:
             self.layer_norm = LayerNorm(self.embed_dim)
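A side note on the setattr pattern above (plain PyTorch behaviour, shown here with made-up module names and sizes): assigning an nn.Module through setattr registers it as a submodule, so the renamed block_fuse{i + 1} attributes are tracked exactly like the old block_attn{i + 1} ones and are fetched back with getattr in the forward pass.

```python
# Illustrative only: setattr on an nn.Module registers the child as a
# submodule, so its parameters are trained and appear in the state dict.
import torch
import torch.nn as nn


class Stages(nn.Module):
    def __init__(self, dims=(128, 128, 256, 512)):
        super().__init__()
        for i, d in enumerate(dims):
            setattr(self, f"block_fuse{i + 1}", nn.Linear(d, d))

    def forward(self, x, stage):
        block_fuse = getattr(self, f"block_fuse{stage + 1}")
        return block_fuse(x)


model = Stages()
print(sum(p.numel() for p in model.parameters()) > 0)  # True: submodules registered
out = model(torch.randn(2, 256), stage=2)              # uses block_fuse3 (dims[2] == 256)
print(out.shape)                                       # torch.Size([2, 256])
```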
@@ -425,13 +429,14 @@ class PyS2TTransformerEncoder(FairseqEncoder):
             reduced_embed = getattr(self, f"reduced_embed{i + 1}")
             pos_embed = getattr(self, f"pos_embed{i + 1}")
             block = getattr(self, f"block{i + 1}")
-            block_attn = getattr(self, f"block_attn{i + 1}")
+            block_fuse = getattr(self, f"block_fuse{i + 1}")

-            # if i == 0:
-            #     x = self.embed_scale * x
             x, input_lengths, encoder_padding_mask = reduced_embed(x, input_lengths)
             # add the position encoding and dropout
+            if block_fuse is not None:
+                x = block_fuse(x, prev_state[-1], prev_padding[-1])
+
             if pos_embed:
                 positions = pos_embed(encoder_padding_mask).transpose(0, 1)

                 if self.attn_type != "rel_selfattn":
@@ -453,9 +458,6 @@ class PyS2TTransformerEncoder(FairseqEncoder):
             prev_state.append(x)
             prev_padding.append(encoder_padding_mask)

-            if block_attn is not None:
-                x = block_attn(x, prev_state[-1], prev_padding[-1])
-
         if self.use_ppm:
             pool_state = []
             seq_len, bsz, dim = x.size()
@@ -463,16 +465,16 @@ class PyS2TTransformerEncoder(FairseqEncoder):
             for state in prev_state:
                 i += 1
                 ppm = getattr(self, f"ppm{i + 1}")
-                ppm_layer_norm = getattr(self, f"ppm_layer_norm{i + 1}")
-                ppm_layer_norm2 = getattr(self, f"ppm_layer_norm2{i + 1}")
+                ppm_pre_layer_norm = getattr(self, f"ppm_pre_layer_norm{i + 1}")
+                ppm_post_layer_norm = getattr(self, f"ppm_post_layer_norm{i + 1}")

-                state = ppm_layer_norm(state)
+                state = ppm_pre_layer_norm(state)
                 state = state.permute(1, 2, 0)  # bsz, dim, seq_len
                 if i != self.pyramid_stages - 1:
                     state = nn.functional.adaptive_avg_pool1d(state, seq_len)
                 state = ppm(state)
                 state = state.permute(2, 0, 1)
-                state = ppm_layer_norm2(state)
+                state = ppm_post_layer_norm(state)
                 pool_state.append(state)
             ppm_weight = self.ppm_weight
             x = (torch.stack(pool_state, dim=0) * ppm_weight.view(-1, 1, 1, 1)).sum(0)
...
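To make the PPM fusion at the end of the forward pass concrete, here is a simplified, runnable sketch: every stage's output is pooled to the final temporal length, projected to the shared encoder width by a 1x1 convolution, and the per-stage results are combined with normalized weights. The pre/post layer norms, BatchNorm, padding masks, and the learnable ppm_weight of the real module are simplified away, so this is only an illustration:

```python
# Illustrative sketch of the PPM-style multi-stage fusion (simplified).
import torch
import torch.nn as nn

embed_dim = 512
stage_dims = [128, 128, 256, 512]
seq_len, bsz = 25, 4

# one 1x1 conv per stage, projecting that stage's width to the shared embed_dim
ppm = nn.ModuleList([
    nn.Sequential(nn.Conv1d(d, embed_dim, kernel_size=1, bias=False), nn.ReLU())
    for d in stage_dims
])
ppm_weight = torch.softmax(torch.ones(len(stage_dims)), dim=0)  # learnable in the model

# previous-stage outputs at decreasing temporal resolution: (len, bsz, dim)
prev_state = [torch.randn(seq_len * 2 ** (len(stage_dims) - 1 - i), bsz, d)
              for i, d in enumerate(stage_dims)]

pool_state = []
for i, state in enumerate(prev_state):
    state = state.permute(1, 2, 0)                     # bsz, dim, len
    if state.size(-1) != seq_len:                      # pool to the final length
        state = nn.functional.adaptive_avg_pool1d(state, seq_len)
    state = ppm[i](state).permute(2, 0, 1)             # seq_len, bsz, embed_dim
    pool_state.append(state)

# weighted sum over stages
x = (torch.stack(pool_state, dim=0) * ppm_weight.view(-1, 1, 1, 1)).sum(0)
print(x.shape)  # torch.Size([25, 4, 512])
```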
@@ -6,7 +6,7 @@
 import torch.nn as nn

 from .learned_positional_embedding import LearnedPositionalEmbedding
-from .sinusoidal_positional_embedding import SinusoidalPositionalEmbedding, RelPositionalEmbedding
+from .sinusoidal_positional_embedding import SinusoidalPositionalEmbedding


 def PositionalEmbedding(
@@ -27,12 +27,6 @@ def PositionalEmbedding(
         nn.init.normal_(m.weight, mean=0, std=embedding_dim ** -0.5)
         if padding_idx is not None:
             nn.init.constant_(m.weight[padding_idx], 0)
-    elif pos_emb_type is not None and pos_emb_type.startswith("debug"):
-        m = RelPositionalEmbedding(
-            embedding_dim,
-            padding_idx,
-            init_size=num_embeddings + padding_idx + 1,
-        )
     else:
         m = SinusoidalPositionalEmbedding(
             embedding_dim,
...
@@ -103,37 +103,3 @@ class SinusoidalPositionalEmbedding(nn.Module):
             .view(bsz, seq_len, -1)
             .detach()
         )
-
-
-class RelPositionalEmbedding(SinusoidalPositionalEmbedding):
-    """Relative positional encoding module.
-    See : Appendix B in https://arxiv.org/abs/1901.02860
-    Args:
-        d_model (int): Embedding dimension.
-        dropout_rate (float): Dropout rate.
-        max_len (int): Maximum input length.
-    """
-
-    def __init__(self, embedding_dim, padding_idx, init_size=1024):
-        super().__init__(embedding_dim, padding_idx, init_size)
-        self.max_size = init_size
-
-    def forward(
-        self,
-        input,
-        incremental_state: Optional[Any] = None,
-        timestep: Optional[Tensor] = None,
-        positions: Optional[Any] = None,
-        offset: int = 0
-    ):
-        """Compute positional encoding.
-        Args:
-            input (torch.Tensor): Input tensor (batch, time, `*`).
-        Returns:
-            torch.Tensor: Encoded tensor (batch, time, `*`).
-            torch.Tensor: Positional embedding tensor (1, time, `*`).
-        """
-        assert offset + input.size(1) < self.max_size
-        self.weights = self.weights.to(input.device)
-        pos_emb = self.weights[:, offset:offset + input.size(1)]
-        return pos_emb