Commit 9fadf1f4 by xuchen

update the pyramid transformer about block fuse

parent f0605efa
......@@ -5,6 +5,7 @@ pyramid-layers: 2_2_6_2
#encoder-attention-type: reduced
#pyramid-attn-sample-ratios: 8_4_2_1
#pyramid-block-attn: True
#pyramid-fuse-way: add
pyramid-sr-ratios: 2_2_2_2
pyramid-use-ppm: True
pyramid-embed-dims: 128_128_256_512
......
......@@ -2,16 +2,12 @@
# training the model
gpu_num=4
gpu_num=8
update_freq=1
max_tokens=80000
max_tokens=40000
#exp_tag=valid_prev_state
#exp_tag=lower128
#exp_tag=sr8
#config_list=(base conformer rpr)
config_list=(pyramid)
#config_list=(pyramid_stage3 rpr)
exp_tag=baseline
config_list=(base)
# exp full name
exp_name=
......@@ -42,8 +38,7 @@ if [[ -n ${extra_tag} ]]; then
cmd="$cmd --extra_tag ${extra_tag}"
fi
if [[ -n ${extra_parameter} ]]; then
# cmd="$cmd --extra_parameter \"${extra_parameter}\""
cmd="$cmd --extra_parameter ${extra_parameter}"
cmd="$cmd --extra_parameter \"${extra_parameter}\""
fi
echo ${cmd}
......
......@@ -7,7 +7,7 @@ update_freq=1
max_tokens=40000
exp_tag=baseline
config_list=(ctc local_attn)
config_list=(ctc)
# exp full name
exp_name=
......@@ -38,8 +38,7 @@ if [[ -n ${extra_tag} ]]; then
cmd="$cmd --extra_tag ${extra_tag}"
fi
if [[ -n ${extra_parameter} ]]; then
# cmd="$cmd --extra_parameter \"${extra_parameter}\""
cmd="$cmd --extra_parameter ${extra_parameter}"
cmd="$cmd --extra_parameter \"${extra_parameter}\""
fi
echo ${cmd}
......
......@@ -80,7 +80,6 @@ class ReducedEmbed(nn.Module):
x.masked_fill_(mask_pad, 0.0)
x = x.transpose(0, 1)
# x = self.in_norm(x)
if self.reduced_way == "proj":
x = x.permute(1, 2, 0) # bsz, dim, seq_len
x = nn.functional.adaptive_avg_pool1d(x, int(seq_len // self.stride))
......@@ -115,40 +114,44 @@ class ReducedEmbed(nn.Module):
class BlockFuse(nn.Module):
def __init__(self, embed_dim, prev_embed_dim, dropout, num_head, fuse_way="add"):
def __init__(self, embed_dim, prev_embed_dim, final_stage, fuse_way="add"):
super().__init__()
self.attn = MultiheadAttention(
embed_dim,
num_head,
kdim=prev_embed_dim,
vdim=prev_embed_dim,
dropout=dropout,
encoder_decoder_attention=True,
self.conv = nn.Sequential(
nn.Conv1d(prev_embed_dim, embed_dim, kernel_size=1, bias=False),
nn.ReLU()
)
self.q_layer_norm = LayerNorm(embed_dim)
self.layer_norm = LayerNorm(embed_dim)
self.kv_layer_norm = LayerNorm(prev_embed_dim)
self.fuse_way = fuse_way
self.final_stage = final_stage
if self.fuse_way == "gated":
self.gate_linear = nn.Linear(2 * embed_dim, embed_dim)
# self.gate = nn.GRUCell(embed_dim, embed_dim)
def forward(self, x, state, padding):
residual = x
x = self.q_layer_norm(x)
seq_len, bsz, dim = x.size()
state = self.kv_layer_norm(state)
x, attn = self.attn(
query=x,
key=state,
value=state,
key_padding_mask=padding,
static_kv=True,
)
state = state.permute(1, 2, 0) # bsz, dim, seq_len
if state.size(-1) != seq_len:
state = nn.functional.adaptive_avg_pool1d(state, seq_len)
state = self.conv(state)
state = state.permute(2, 0, 1) # seq_len, bsz, dim
if self.fuse_way == "add":
x = residual + x
elif self.fuse_way == "gated":
coef = (self.gate_linear(torch.cat([x, residual], dim=-1))).sigmoid()
x = coef * x + (1 - coef) * residual
if self.fuse_way == "gated":
# x = x.contiguous().view(-1, dim)
# state = state.contiguous().view(-1, dim)
# x = self.gate(x, state).view(seq_len, bsz, dim)
coef = (self.gate_linear(torch.cat([x, state], dim=-1))).sigmoid()
x = coef * x + (1 - coef) * state
x = state + x
else:
x = x + state
# if not self.final_stage:
x = self.layer_norm(x)
return x
......@@ -331,15 +334,16 @@ class PyS2TTransformerEncoder(FairseqEncoder):
PyramidTransformerEncoderLayer(args, embed_dim, embed_dim * ffn_ratio, num_head, attn_sample_ratio)
for _ in range(num_layers)])
block_attn = None
block_fuse = None
if self.pyramid_block_attn:
if i != 0:
block_attn = BlockFuse(embed_dim, self.pyramid_embed_dims[i-1],
args.dropout, num_head, self.pyramid_fuse_way)
block_fuse = BlockFuse(embed_dim, self.pyramid_embed_dims[i-1],
final_stage=True if i == self.pyramid_stages - 1 else False,
fuse_way=self.pyramid_fuse_way)
if self.use_ppm:
ppm_layer_norm = LayerNorm(embed_dim)
ppm_layer_norm2 = LayerNorm(self.embed_dim)
ppm_pre_layer_norm = LayerNorm(embed_dim)
ppm_post_layer_norm = LayerNorm(self.embed_dim)
ppm = nn.Sequential(
nn.Conv1d(embed_dim, self.embed_dim, kernel_size=1, bias=False),
nn.BatchNorm1d(self.embed_dim),
......@@ -347,17 +351,17 @@ class PyS2TTransformerEncoder(FairseqEncoder):
)
else:
ppm_layer_norm = None
ppm_layer_norm2 = None
ppm_pre_layer_norm = None
ppm_post_layer_norm = None
ppm = None
setattr(self, f"reduced_embed{i + 1}", reduced_embed)
setattr(self, f"pos_embed{i + 1}", pos_embed)
setattr(self, f"block{i + 1}", block)
setattr(self, f"block_attn{i + 1}", block_attn)
setattr(self, f"block_fuse{i + 1}", block_fuse)
setattr(self, f"ppm{i + 1}", ppm)
setattr(self, f"ppm_layer_norm{i + 1}", ppm_layer_norm)
setattr(self, f"ppm_layer_norm2{i + 1}", ppm_layer_norm2)
setattr(self, f"ppm_pre_layer_norm{i + 1}", ppm_pre_layer_norm)
setattr(self, f"ppm_layer_norm2{i + 1}", ppm_post_layer_norm)
if args.encoder_normalize_before:
self.layer_norm = LayerNorm(self.embed_dim)
......@@ -425,13 +429,14 @@ class PyS2TTransformerEncoder(FairseqEncoder):
reduced_embed = getattr(self, f"reduced_embed{i + 1}")
pos_embed = getattr(self, f"pos_embed{i + 1}")
block = getattr(self, f"block{i + 1}")
block_attn = getattr(self, f"block_attn{i + 1}")
# if i == 0:
# x = self.embed_scale * x
block_fuse = getattr(self, f"block_fuse{i + 1}")
x, input_lengths, encoder_padding_mask = reduced_embed(x, input_lengths)
# add the position encoding and dropout
if block_fuse is not None:
x = block_fuse(x, prev_state[-1], prev_padding[-1])
if pos_embed:
positions = pos_embed(encoder_padding_mask).transpose(0, 1)
if self.attn_type != "rel_selfattn":
......@@ -453,9 +458,6 @@ class PyS2TTransformerEncoder(FairseqEncoder):
prev_state.append(x)
prev_padding.append(encoder_padding_mask)
if block_attn is not None:
x = block_attn(x, prev_state[-1], prev_padding[-1])
if self.use_ppm:
pool_state = []
seq_len, bsz, dim = x.size()
......@@ -463,16 +465,16 @@ class PyS2TTransformerEncoder(FairseqEncoder):
for state in prev_state:
i += 1
ppm = getattr(self, f"ppm{i + 1}")
ppm_layer_norm = getattr(self, f"ppm_layer_norm{i + 1}")
ppm_layer_norm2 = getattr(self, f"ppm_layer_norm2{i + 1}")
ppm_pre_layer_norm = getattr(self, f"ppm_pre_layer_norm{i + 1}")
ppm_post_layer_norm = getattr(self, f"ppm_post_layer_norm{i + 1}")
state = ppm_layer_norm(state)
state = ppm_pre_layer_norm(state)
state = state.permute(1, 2, 0) # bsz, dim, seq_len
if i != self.pyramid_stages - 1:
state = nn.functional.adaptive_avg_pool1d(state, seq_len)
state = ppm(state)
state = state.permute(2, 0, 1)
state = ppm_layer_norm2(state)
state = ppm_post_layer_norm(state)
pool_state.append(state)
ppm_weight = self.ppm_weight
x = (torch.stack(pool_state, dim=0) * ppm_weight.view(-1, 1, 1, 1)).sum(0)
......
......@@ -6,7 +6,7 @@
import torch.nn as nn
from .learned_positional_embedding import LearnedPositionalEmbedding
from .sinusoidal_positional_embedding import SinusoidalPositionalEmbedding, RelPositionalEmbedding
from .sinusoidal_positional_embedding import SinusoidalPositionalEmbedding
def PositionalEmbedding(
......@@ -27,12 +27,6 @@ def PositionalEmbedding(
nn.init.normal_(m.weight, mean=0, std=embedding_dim ** -0.5)
if padding_idx is not None:
nn.init.constant_(m.weight[padding_idx], 0)
elif pos_emb_type is not None and pos_emb_type.startswith("debug"):
m = RelPositionalEmbedding(
embedding_dim,
padding_idx,
init_size=num_embeddings + padding_idx + 1,
)
else:
m = SinusoidalPositionalEmbedding(
embedding_dim,
......
......@@ -103,37 +103,3 @@ class SinusoidalPositionalEmbedding(nn.Module):
.view(bsz, seq_len, -1)
.detach()
)
class RelPositionalEmbedding(SinusoidalPositionalEmbedding):
"""Relative positional encoding module.
See : Appendix B in https://arxiv.org/abs/1901.02860
Args:
d_model (int): Embedding dimension.
dropout_rate (float): Dropout rate.
max_len (int): Maximum input length.
"""
def __init__(self, embedding_dim, padding_idx, init_size=1024):
super().__init__(embedding_dim, padding_idx, init_size)
self.max_size = init_size
def forward(
self,
input,
incremental_state: Optional[Any] = None,
timestep: Optional[Tensor] = None,
positions: Optional[Any] = None,
offset: int = 0
):
"""Compute positional encoding.
Args:
input (torch.Tensor): Input tensor (batch, time, `*`).
Returns:
torch.Tensor: Encoded tensor (batch, time, `*`).
torch.Tensor: Positional embedding tensor (1, time, `*`).
"""
assert offset + input.size(1) < self.max_size
self.weights = self.weights.to(input.device)
pos_emb = self.weights[:, offset:offset + input.size(1)]
return pos_emb
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论