Commit 1288e535 by xuchen

fix checkpoint-loading bugs and add stage-specific CNN kernel sizes for the PDS Conformer

parent 408e2b95
......@@ -721,7 +721,7 @@ def load_pretrained_component_from_model(
mismatch_keys.append(key)
for name, child in modules._modules.items():
check(load_state_dict, child, name + prefix + ".")
check(load_state_dict, child, prefix + name + ".")
check(component_state_dict, component)
# parameters = component.named_parameters()
......
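The first hunk fixes the key construction in load_pretrained_component_from_model: when recursing through a module's _modules, the child name must be appended to the accumulated prefix (prefix + name + "."), because torch state-dict keys are built outside-in (e.g. "layers.0.self_attn.k_proj.weight"). A minimal sketch of the corrected recursion (the helper collect_mismatched_keys and the toy module are illustrative, not the fairseq function itself):

import torch.nn as nn

def collect_mismatched_keys(state_dict, component, prefix=""):
    # Record every parameter key of `component` that is absent from `state_dict`.
    mismatch_keys = []

    def check(module, prefix):
        for name, _ in module.named_parameters(recurse=False):
            key = prefix + name
            if key not in state_dict:
                mismatch_keys.append(key)
        for name, child in module._modules.items():
            if child is not None:
                # Correct order: accumulated prefix first, then the child name.
                check(child, prefix + name + ".")

    check(component, prefix)
    return mismatch_keys

# Toy module whose state_dict holds keys "0.weight", "0.bias", "1.weight", "1.bias".
toy = nn.Sequential(nn.Linear(4, 4), nn.Linear(4, 4))
assert collect_mismatched_keys(toy.state_dict(), toy) == []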
......@@ -526,6 +526,11 @@ class PDSS2TTransformerModel(S2TTransformerModel):
help="the ratio of the ffn in each stage",
)
parser.add_argument(
"--pds-cnn-kernel-sizes",
type=str,
help="the kernel size of convolutional modules in Conformer",
)
parser.add_argument(
"--pds-conv-strides",
type=str,
help="the strides of the convolutional module (conformer) in each stage",
......@@ -664,6 +669,9 @@ class PDSS2TTransformerEncoder(FairseqEncoder):
self.pds_position_embed = [int(n) for n in args.pds_position_embed.split("_")]
self.pds_attn_heads = [int(n) for n in args.pds_attn_heads.split("_")]
self.pds_ffn_ratios = [int(n) for n in args.pds_ffn_ratios.split("_")]
self.pds_cnn_kernel_sizes = \
[int(n) for n in args.pds_cnn_kernel_sizes.split("_")] \
if getattr(args, "pds_cnn_kernel_sizes", None) is not None else None
self.pds_attn_ds_ratios = \
[int(n) for n in args.pds_attn_ds_ratios.split("_")] if args.pds_attn_ds_ratios is not None else None
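The new --pds-cnn-kernel-sizes option follows the same convention as the other PDS options: one value per pyramid stage, joined by underscores, and parsed in the encoder as shown above. A small illustration (the value "31_31_15_15" is hypothetical):

# Hypothetical CLI value: a 4-stage PDS encoder with wider kernels in the early stages.
pds_cnn_kernel_sizes_arg = "31_31_15_15"
pds_cnn_kernel_sizes = [int(n) for n in pds_cnn_kernel_sizes_arg.split("_")]
assert pds_cnn_kernel_sizes == [31, 31, 15, 15]  # one Conformer kernel size per stage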
......@@ -710,8 +718,9 @@ class PDSS2TTransformerEncoder(FairseqEncoder):
use_pos_embed = self.pds_position_embed[i]
use_ctc = self.pds_ctc[i] if self.pds_ctc is not None else False
ffn_ratio = self.pds_ffn_ratios[i]
num_head = self.pds_attn_heads[i]
ffn_ratio = self.pds_ffn_ratios[i]
cnn_kernel_size = self.pds_cnn_kernel_sizes[i] if self.pds_cnn_kernel_sizes is not None else None
attn_ds_ratio = self.pds_attn_ds_ratios[i] \
if self.pds_conv_strides is not None and self.attn_type == "reduced" else 1
conv_stride = self.pds_conv_strides[i] if self.pds_conv_strides is not None else 1
......@@ -778,6 +787,7 @@ class PDSS2TTransformerEncoder(FairseqEncoder):
conv_stride=conv_stride if layer_idx == num_layers - 1 else 1,
attn_stride=attn_stride if layer_idx == num_layers - 1 else 1,
expand_embed_dim=expand_embed_dim if layer_idx == num_layers - 1 else None,
cnn_kernel_size=cnn_kernel_size,
)
for layer_idx in range(num_layers)])
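Within a stage, down-sampling (conv/attention stride, embed-dim expansion) is applied only by the last layer, while the new cnn_kernel_size is passed to every layer of the stage. A minimal sketch of that selection logic (DummyLayer is a stand-in for PDSTransformerEncoderLayer):

num_layers = 3
conv_stride, attn_stride, expand_embed_dim, cnn_kernel_size = 2, 2, 512, 15

class DummyLayer:
    # Stand-in recording only the arguments relevant to this hunk.
    def __init__(self, conv_stride, attn_stride, expand_embed_dim, cnn_kernel_size):
        self.conv_stride = conv_stride
        self.attn_stride = attn_stride
        self.expand_embed_dim = expand_embed_dim
        self.cnn_kernel_size = cnn_kernel_size

stage = [
    DummyLayer(
        conv_stride=conv_stride if layer_idx == num_layers - 1 else 1,
        attn_stride=attn_stride if layer_idx == num_layers - 1 else 1,
        expand_embed_dim=expand_embed_dim if layer_idx == num_layers - 1 else None,
        cnn_kernel_size=cnn_kernel_size,
    )
    for layer_idx in range(num_layers)
]
assert [layer.conv_stride for layer in stage] == [1, 1, 2]         # only the last layer strides
assert [layer.cnn_kernel_size for layer in stage] == [15, 15, 15]  # kernel size applies to every layer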
......@@ -1238,6 +1248,7 @@ def base_architecture(args):
args.pds_attn_heads = getattr(args, "pds_attn_heads", None)
args.pds_ffn_ratios = getattr(args, "pds_ffn_ratios", None)
args.pds_cnn_kernel_sizes = getattr(args, "pds_cnn_kernel_sizes", None)
args.pds_attn_ds_ratios = getattr(args, "pds_attn_ds_ratios", None)
args.pds_conv_strides = getattr(args, "pds_conv_strides", None)
......
......@@ -742,6 +742,7 @@ def base_architecture(args):
args.pds_attn_heads = getattr(args, "pds_attn_heads", None)
args.pds_ffn_ratios = getattr(args, "pds_ffn_ratios", None)
args.pds_cnn_kernel_sizes = getattr(args, "pds_cnn_kernel_sizes", None)
args.pds_attn_ds_ratios = getattr(args, "pds_attn_ds_ratios", "1_1_1_1")
args.pds_conv_strides = getattr(args, "pds_conv_strides", "1_1_1_1")
......
......@@ -499,6 +499,7 @@ def base_architecture(args):
args.pds_attn_heads = getattr(args, "pds_attn_heads", None)
args.pds_ffn_ratios = getattr(args, "pds_ffn_ratios", None)
args.pds_cnn_kernel_sizes = getattr(args, "pds_cnn_kernel_sizes", None)
args.pds_attn_ds_ratios = getattr(args, "pds_attn_ds_ratios", None)
args.pds_conv_strides = getattr(args, "pds_conv_strides", None)
......
......@@ -527,6 +527,7 @@ def base_architecture(args):
args.pds_attn_heads = getattr(args, "pds_attn_heads", None)
args.pds_ffn_ratios = getattr(args, "pds_ffn_ratios", None)
args.pds_cnn_kernel_sizes = getattr(args, "pds_cnn_kernel_sizes", None)
args.pds_attn_ds_ratios = getattr(args, "pds_attn_ds_ratios", None)
args.pds_conv_strides = getattr(args, "pds_conv_strides", None)
......
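Each base_architecture registers the new option through getattr so that configurations written before this commit still build: when pds_cnn_kernel_sizes is absent it defaults to None, the per-stage value stays None (see the encoder hunk above), and the convolution module falls back to the global cnn_module_kernel (last hunk below). Minimal illustration:

from argparse import Namespace

args = Namespace()  # an older config with no pds_cnn_kernel_sizes attribute
args.pds_cnn_kernel_sizes = getattr(args, "pds_cnn_kernel_sizes", None)
assert args.pds_cnn_kernel_sizes is None  # later code then uses args.cnn_module_kernel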
......@@ -151,7 +151,6 @@ class MultiheadAttention(nn.Module):
assert list(query.size()) == [tgt_len, bsz, embed_dim]
if (
False and
not self.onnx_trace
and not is_tpu # don't use PyTorch version on TPUs
and incremental_state is None
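Dropping the hard-coded "False and" restores the original guard; in upstream fairseq this branch dispatches to torch's fused F.multi_head_attention_forward when neither ONNX tracing, TPU execution, nor incremental decoding is in play. Prepending "False and" had short-circuited the condition so the fast path was never taken; a simplified version of the guard:

onnx_trace, is_tpu, incremental_state = False, False, None

# With the leading "False and" this expression was constantly False.
use_fast_path = (
    not onnx_trace
    and not is_tpu                     # don't use the fused PyTorch version on TPUs
    and incremental_state is None
)
assert use_fast_path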
......@@ -350,8 +349,8 @@ class MultiheadAttention(nn.Module):
if before_softmax:
return attn_weights, v
attn_weights = attn_weights.clamp(min=-1e8 if attn_weights.dtype == torch.float32 else -1e4,
max=1e8 if attn_weights.dtype == torch.float32 else 1e4)
# attn_weights = attn_weights.clamp(min=-1e8 if attn_weights.dtype == torch.float32 else -1e4,
# max=1e8 if attn_weights.dtype == torch.float32 else 1e4)
attn_weights_float = F.softmax(attn_weights, dim=-1, dtype=torch.float32)
attn_weights = attn_weights_float.type_as(attn_weights)
......
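The clamp on the attention logits (about ±1e8 in fp32, ±1e4 in fp16) is commented out rather than deleted; presumably it is redundant here because the softmax immediately below is computed in float32 regardless of the attention dtype. A small check of that numerical point (values are illustrative):

import torch
import torch.nn.functional as F

# Large fp16 logits, as can occur after masking, without any pre-softmax clamp.
attn_weights = torch.tensor([[-6.0e4, 0.0, 3.0e4]], dtype=torch.float16)
attn_weights_float = F.softmax(attn_weights, dim=-1, dtype=torch.float32)
assert torch.isfinite(attn_weights_float).all()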
......@@ -46,7 +46,8 @@ class PDSTransformerEncoderLayer(nn.Module):
attn_sample_ratio=1,
attn_stride=1,
conv_stride=1,
expand_embed_dim=None):
expand_embed_dim=None,
cnn_kernel_size=None):
super().__init__()
self.args = args
......@@ -98,7 +99,7 @@ class PDSTransformerEncoderLayer(nn.Module):
self.conv_module = ConvolutionModule(
embed_dim,
expand_embed_dim,
depthwise_kernel_size=args.cnn_module_kernel,
depthwise_kernel_size=args.cnn_module_kernel if cnn_kernel_size is None else cnn_kernel_size,
dropout=args.dropout,
activation_fn=activation,
stride=conv_stride
......
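The Conformer convolution module keeps the global --cnn-module-kernel as its depthwise kernel size unless a stage-specific value from --pds-cnn-kernel-sizes is given. The fallback in isolation (values are illustrative):

cnn_module_kernel = 31  # global Conformer kernel (args.cnn_module_kernel)
for cnn_kernel_size in (None, 15):
    depthwise_kernel_size = cnn_module_kernel if cnn_kernel_size is None else cnn_kernel_size
    # None -> fall back to the global 31; a stage-specific 15 overrides it.
    assert depthwise_kernel_size == (31 if cnn_kernel_size is None else 15)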