Commit f0605efa by xuchen

update the pyramid transformer

parent ed623111
......@@ -4,6 +4,7 @@ pyramid-stages: 4
pyramid-layers: 2_2_6_2
#encoder-attention-type: reduced
#pyramid-attn-sample-ratios: 8_4_2_1
#pyramid-block-attn: True
pyramid-sr-ratios: 2_2_2_2
pyramid-use-ppm: True
pyramid-embed-dims: 128_128_256_512
......@@ -42,7 +43,7 @@ lr: 2e-3
#adam_betas: (0.9,0.98)
criterion: label_smoothed_cross_entropy_with_ctc
ctc-weight: 0.3
ctc-weight: 0.0
label_smoothing: 0.1
conv-channels: 1024
......
......@@ -94,8 +94,6 @@ class LabelSmoothedCrossEntropyCriterionWithCTC(
ctc_loss = self.compute_ctc_loss(model, sample, encoder_out)
logging_output["ctc_loss"] = utils.item(ctc_loss.data)
loss = (1 - self.ctc_weight) * loss + self.ctc_weight * ctc_loss
else:
loss = (1 - self.ctc_weight) * loss
logging_output["loss"] = utils.item(loss.data) if reduce else loss.data
return loss, sample_size, logging_output
......
......@@ -55,16 +55,18 @@ class ReducedEmbed(nn.Module):
self.conv = nn.Conv1d(in_channels, out_channels * 2, kernel_sizes, stride=stride, padding=padding)
self.glu = nn.GLU(dim=1)
elif self.reduced_way == "proj":
self.proj = nn.Linear(2 * in_channels, out_channels, bias=False)
self.conv = nn.Conv1d(in_channels, out_channels, kernel_sizes, padding=padding)
elif self.reduced_way == "fuse":
self.conv = nn.Conv1d(in_channels, out_channels, kernel_sizes, stride=stride, padding=padding)
self.conv_proj = nn.Conv1d(in_channels, out_channels, kernel_sizes, padding=padding)
else:
logger.error("Unsupported reduced way!")
self.embed_norm = embed_norm
if self.embed_norm:
if self.reduced_way == "proj":
self.norm = LayerNorm(2 * in_channels)
else:
self.norm = LayerNorm(out_channels)
if self.reduced_way in ["proj", "fuse"]:
self.in_norm = nn.BatchNorm1d(in_channels)
self.norm = LayerNorm(out_channels)
def forward(self, x, lengths):
seq_len, bsz, dim = x.size()
......@@ -78,15 +80,22 @@ class ReducedEmbed(nn.Module):
x.masked_fill_(mask_pad, 0.0)
x = x.transpose(0, 1)
# x = self.in_norm(x)
if self.reduced_way == "proj":
x = x.transpose(0, 1).contiguous().view(bsz, int(seq_len / 2), -1)
x = self.proj(self.norm(x))
x = x.transpose(0, 1)
x = x.permute(1, 2, 0) # bsz, dim, seq_len
x = nn.functional.adaptive_avg_pool1d(x, int(seq_len // self.stride))
x = self.conv(self.in_norm(x))
x = x.permute(2, 0, 1) # seq_len, bsz, dim
else:
x = x.permute(1, 2, 0) # B * D * T
origin_x = x
x = self.conv(x)
if self.reduced_way == "glu":
x = self.glu(x)
if self.reduced_way == "fuse":
x2 = nn.functional.adaptive_avg_pool1d(origin_x, int(seq_len // self.stride))
x2 = self.conv_proj(self.in_norm(x2))
x = x + x2
x = x.permute(2, 0, 1) # T * B * D
if self.embed_norm:
x = self.norm(x)
......@@ -182,7 +191,7 @@ class PYS2TTransformerModel(S2TTransformerModel):
parser.add_argument(
"--pyramid-reduced-embed",
type=str,
choices=["glu", "conv", "proj"],
choices=["glu", "conv", "proj", "fuse"],
help="the reduced way of the embedding",
)
parser.add_argument(
......@@ -285,12 +294,16 @@ class PyS2TTransformerEncoder(FairseqEncoder):
for i in range(self.pyramid_stages):
num_layers = self.pyramid_layers[i]
sr_ratio = self.pyramid_sr_ratios[i]
attn_sample_ratio = self.pyramid_attn_sample_ratios[i]
attn_sample_ratio = self.pyramid_attn_sample_ratios[i] if self.attn_type == "reduced" else -1
embed_dim = self.pyramid_embed_dims[i]
kernel_size = self.pyramid_kernel_sizes[i]
ffn_ratio = self.pyramid_ffn_ratios[i]
num_head = self.pyramid_heads[i]
use_pos_embed = self.pyramid_position_embed[i]
logger.info("The stage {}: layer {}, sample ratio {}, attention sample ratio {}, embed dim {}, "
"kernel size {}, ffn ratio {}, num head {}, position embed {}".
format(i, num_layers, sr_ratio, attn_sample_ratio,
embed_dim, kernel_size, ffn_ratio, num_head, use_pos_embed))
if i == 0:
self.embed_scale = math.sqrt(embed_dim)
......@@ -300,12 +313,11 @@ class PyS2TTransformerEncoder(FairseqEncoder):
reduced_embed = ReducedEmbed(
self.pyramid_reduced_embed,
self.pyramid_embed_norm,
# self.pyramid_embed_norm if i != 0 else False,
args.input_feat_per_channel * args.input_channels if i == 0 else self.pyramid_embed_dims[i-1],
embed_dim,
kernel_sizes=kernel_size,
stride=sr_ratio,
padding=kernel_size // 2,
padding=(kernel_size - 1) // 2,
)
if use_pos_embed:
pos_embed = PositionalEmbedding(
......@@ -327,13 +339,16 @@ class PyS2TTransformerEncoder(FairseqEncoder):
if self.use_ppm:
ppm_layer_norm = LayerNorm(embed_dim)
ppm_layer_norm2 = LayerNorm(self.embed_dim)
ppm = nn.Sequential(
nn.Conv1d(embed_dim, self.embed_dim, kernel_size=1, bias=False),
nn.BatchNorm1d(self.embed_dim),
nn.ReLU(),
)
else:
ppm_layer_norm = None
ppm_layer_norm2 = None
ppm = None
setattr(self, f"reduced_embed{i + 1}", reduced_embed)
......@@ -342,12 +357,12 @@ class PyS2TTransformerEncoder(FairseqEncoder):
setattr(self, f"block_attn{i + 1}", block_attn)
setattr(self, f"ppm{i + 1}", ppm)
setattr(self, f"ppm_layer_norm{i + 1}", ppm_layer_norm)
setattr(self, f"ppm_layer_norm2{i + 1}", ppm_layer_norm2)
if i == self.pyramid_stages - 1:
if args.encoder_normalize_before:
self.layer_norm = LayerNorm(embed_dim)
else:
self.layer_norm = None
if args.encoder_normalize_before:
self.layer_norm = LayerNorm(self.embed_dim)
else:
self.layer_norm = None
if self.use_ppm:
self.ppm_weight = nn.Parameter(torch.Tensor(self.pyramid_stages).fill_(1.0))
......@@ -442,21 +457,25 @@ class PyS2TTransformerEncoder(FairseqEncoder):
x = block_attn(x, prev_state[-1], prev_padding[-1])
if self.use_ppm:
pool_state = [x]
pool_state = []
seq_len, bsz, dim = x.size()
i = -1
for state in prev_state[:-1]:
for state in prev_state:
i += 1
ppm = getattr(self, f"ppm{i + 1}")
ppm_layer_norm = getattr(self, f"ppm_layer_norm{i + 1}")
ppm_layer_norm2 = getattr(self, f"ppm_layer_norm2{i + 1}")
state = ppm_layer_norm(state)
state = state.permute(1, 2, 0)
state = nn.functional.adaptive_avg_pool1d(state, seq_len)
state = state.permute(1, 2, 0) # bsz, dim, seq_len
if i != self.pyramid_stages - 1:
state = nn.functional.adaptive_avg_pool1d(state, seq_len)
state = ppm(state)
state = state.permute(2, 0, 1)
state = ppm_layer_norm2(state)
pool_state.append(state)
x = (torch.stack(pool_state, dim=0) * self.ppm_weight.view(-1, 1, 1, 1)).sum(0)
ppm_weight = self.ppm_weight
x = (torch.stack(pool_state, dim=0) * ppm_weight.view(-1, 1, 1, 1)).sum(0)
if self.layer_norm is not None:
x = self.layer_norm(x)
......
......@@ -128,7 +128,6 @@ class RelPositionMultiheadAttention(MultiheadAttention):
else:
k = self.k_proj(key)
v = self.v_proj(key)
else:
assert key is not None and value is not None
q = self.q_proj(query)
......@@ -301,7 +300,7 @@ class RelPositionMultiheadAttention(MultiheadAttention):
x = x * torch.tril(ones, x.size(2) - x.size(1))[None, :, :]
return x
# matrix_bd = rel_shift(matrix_bd)
matrix_bd = rel_shift(matrix_bd)
attn_weights = (matrix_ac + matrix_bd) * self.scaling
......@@ -456,7 +455,7 @@ class RelPositionMultiheadAttention(MultiheadAttention):
dim = int(state_dict[k].shape[0] / 3)
items_to_add[prefix + "q_proj.bias"] = state_dict[k_bias][:dim]
items_to_add[prefix + "k_proj.bias"] = state_dict[k_bias][
dim : 2 * dim
dim: 2 * dim
]
items_to_add[prefix + "v_proj.bias"] = state_dict[k_bias][2 * dim :]
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论