Commit b4e95869 by xuchen

Optimize code

parent 4f452308
@@ -403,10 +403,6 @@ class PDSS2TTransformerEncoder(FairseqEncoder):
inter_ctc_mlo = getattr(args, "inter_ctc_mlo", "")
if inter_ctc_mlo != "":
inter_ctc_mlo = [int(x) for x in inter_ctc_mlo.split(":")]
# assert len(inter_ctc_mlo.split(":")) - 1 == self.pds_stages, (
# inter_ctc_mlo,
# self.pds_stages,
# )
if self.share_inter_ctc is True:
self.share_inter_ctc = False
logger.info("Overwrite the config share_inter_ctc to False for MLO.")
@@ -462,44 +458,47 @@ class PDSS2TTransformerEncoder(FairseqEncoder):
"linear_init": getattr(args, "pae_linear_init", False)
}
for i in range(self.pds_stages):
num_layers = self.pds_layers[i]
ds_ratio = self.pds_ratios[i]
def get_pds_settings(list, idx, return_bool=False, default=None):
if list is None:
if return_bool:
return False if default is None else default
else:
return None if default is None else default
if idx >= len(list):
return list[-1] if default is None else default
else:
return list[idx]
embed_dim = self.pds_embed_dims[i]
kernel_size = self.pds_kernel_sizes[i]
use_pos_embed = self.pds_position_embed[i]
use_ctc = self.pds_ctc[i] if self.pds_ctc is not None else False
use_xctc = self.pds_xctc[i] if self.pds_xctc is not None else False
num_head = self.pds_attn_heads[i]
ffn_ratio = self.pds_ffn_ratios[i]
cnn_kernel_size = (
self.pds_cnn_kernel_sizes[i]
if self.pds_cnn_kernel_sizes is not None
else None
)
for i in range(self.pds_stages):
num_layers = get_pds_settings(self.pds_layers, i, default=1)
ds_ratio = get_pds_settings(self.pds_ratios, i, default=0)
embed_dim = get_pds_settings(self.pds_embed_dims, i)
kernel_size = get_pds_settings(self.pds_kernel_sizes, i)
use_pos_embed = get_pds_settings(self.pds_position_embed, i)
use_ctc = get_pds_settings(self.pds_ctc, i, True)
use_xctc = get_pds_settings(self.pds_xctc, i, True)
num_head = get_pds_settings(self.pds_attn_heads, i)
ffn_ratio = get_pds_settings(self.pds_ffn_ratios, i)
cnn_kernel_size = get_pds_settings(self.pds_cnn_kernel_sizes, i)
attn_ds_ratio = (
self.pds_attn_ds_ratios[i]
if self.pds_conv_strides is not None and self.attn_type == "reduced"
else 1
)
conv_stride = (
self.pds_conv_strides[i] if self.pds_conv_strides is not None else 1
)
attn_stride = (
self.pds_attn_strides[i] if self.pds_attn_strides is not None else 1
get_pds_settings(self.pds_attn_ds_ratios, i, default=1)
if self.attn_type == "reduced" else 1
)
conv_stride = get_pds_settings(self.pds_conv_strides, i, default=1)
attn_stride = get_pds_settings(self.pds_attn_strides, i, default=1)
if conv_stride != 1 or attn_stride != 1:
expand_embed_dim = (
embed_dim
if i == self.pds_stages - 1
else self.pds_embed_dims[i + 1]
else get_pds_settings(self.pds_embed_dims, i + 1)
)
else:
expand_embed_dim = None
fusion = self.pds_fusion_layers[i]
fusion = get_pds_settings(self.pds_fusion_layers, i)
logger.info(
"The stage {}: layer {}, down-sample ratio {}, embed dim {}, "
@@ -534,13 +533,15 @@ class PDSS2TTransformerEncoder(FairseqEncoder):
if ds_ratio == -1:
# downsampling = subsampling(args, embed_dim)
downsampling = subsampling(args)
elif ds_ratio == 0:
downsampling = None
else:
downsampling = Downsampling(
self.pds_ds_method,
self.pds_embed_norm,
args.input_feat_per_channel * args.input_channels
if i == 0
else self.pds_embed_dims[i - 1],
else get_pds_settings(self.pds_embed_dims, i - 1),
embed_dim,
kernel_sizes=kernel_size,
stride=ds_ratio,
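
As a side note (not part of the commit), a small sketch of the values a stage's ds_ratio can now take and the front-end each one selects; 0 is the newly handled case that skips downsampling, which is why forward() later guards the call with "if downsampling is not None":

# Illustration only: how a stage's ds_ratio now picks its downsampling front-end.
def build_stage_frontend(ds_ratio):
    if ds_ratio == -1:
        return "subsampling(args)"                    # the generic subsampling module
    elif ds_ratio == 0:
        return None                                   # new: keep this stage's resolution
    else:
        return "Downsampling(stride=%d)" % ds_ratio   # down-sample by ds_ratio

for r in (-1, 0, 2):
    print(r, "->", build_stage_frontend(r))
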
@@ -654,6 +655,7 @@ class PDSS2TTransformerEncoder(FairseqEncoder):
)
),
dropout=args.dropout,
dictionary=task.source_dictionary
)
inter_ctc = ctc
else:
@@ -661,6 +663,7 @@ class PDSS2TTransformerEncoder(FairseqEncoder):
embed_dim,
dictionary_size=len(task.source_dictionary),
dropout=args.dropout,
dictionary=task.source_dictionary
)
if (
@@ -848,6 +851,7 @@ class PDSS2TTransformerEncoder(FairseqEncoder):
dictionary_size=len(task.source_dictionary),
dropout=args.dropout,
need_layernorm=True if self.inter_ctc else False,
dictionary=task.source_dictionary
)
if (
@@ -928,6 +932,10 @@ class PDSS2TTransformerEncoder(FairseqEncoder):
self.gather_cos_sim_dis = 2
self.cos_sim = dict()
self.early_exit_count = 0
self.early_exit_layer_record = []
self.early_exit_layer = 0
def set_flag(self, **kwargs):
for i in range(self.pds_stages):
stage = getattr(self, f"stage{i + 1}")
@@ -939,11 +947,19 @@ class PDSS2TTransformerEncoder(FairseqEncoder):
self.mixup_infer = kwargs.get("mixup_infer", False)
self.gather_cos_sim = kwargs.get("gather_cos_sim", False)
self.gather_cos_sim_dis = kwargs.get("gather_cos_sim_dis", 2)
self.early_exit_layer = kwargs.get("early_exit_layer", 0)
if self.early_exit_layer != 0:
logger.info("Using the logit in layer %d to infer." % self.early_exit_layer)
# if self.mixup_infer:
# self.mixup_keep_org = True
if self.mixup_infer:
self.mixup_keep_org = True
def dump(self, fstream, info=""):
print("Early exit layer.", file=fstream)
if self.early_exit_count != 0:
print("\n".join([str(l) for l in self.early_exit_layer_record]), file=fstream)
for i in range(self.pds_stages):
idx = 0
stage = getattr(self, f"stage{i + 1}")
@@ -1016,8 +1032,9 @@ class PDSS2TTransformerEncoder(FairseqEncoder):
return x, encoder_padding_mask, input_lengths, mixup
def set_ctc_infer(
self, ctc_infer, post_process, src_dict=None, tgt_dict=None, path=None
self, ctc_infer, post_process, src_dict=None, tgt_dict=None, path=None, early_exit_count=0
):
self.early_exit_count = early_exit_count
if hasattr(self, "ctc"):
import os
assert src_dict is not None
@@ -1040,6 +1057,27 @@ class PDSS2TTransformerEncoder(FairseqEncoder):
logger.error("No ctc module in the encoder")
def early_exit_or_not(self, history, new_logit, count):
history.append(new_logit)
length = len(history)
if count == 0 or length < count:
return False
else:
# for logit in history[length - count: length - 1]:
# if new_logit.size() != logit.size() or not (new_logit == logit).all():
# return False
# return True
hit = 0
for logit in history[: length - 1]:
if new_logit.size() == logit.size() and (new_logit == logit).all():
hit += 1
if hit >= count:
return True
else:
return False
def forward(self, src_tokens, src_lengths, **kwargs):
batch = src_tokens.size(0)
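
Before the forward pass, a small self-contained sketch of the criterion that the new early_exit_or_not method implements: the current layer's greedy CTC prediction is appended to the history, and the encoder may stop once at least `count` earlier layers already produced an identical prediction (the tensors below are made-up token-id sequences):

# Illustration only, not part of the commit: a standalone restatement of
# early_exit_or_not with hypothetical per-layer predictions.
import torch

def early_exit_or_not(history, new_logit, count):
    history.append(new_logit)
    length = len(history)
    if count == 0 or length < count:
        return False
    # count how many earlier layers produced exactly the same prediction
    hit = sum(
        1
        for logit in history[: length - 1]
        if new_logit.size() == logit.size() and (new_logit == logit).all()
    )
    return hit >= count

history = []
layer_preds = [
    torch.tensor([5, 5, 9]),   # layer 1
    torch.tensor([5, 7, 9]),   # layer 2
    torch.tensor([5, 7, 9]),   # layer 3
    torch.tensor([5, 7, 9]),   # layer 4: two earlier layers agree -> exit
]
for layer_idx, pred in enumerate(layer_preds, start=1):
    if early_exit_or_not(history, pred, count=2):
        print("early exit after layer", layer_idx)   # prints: early exit after layer 4
        break
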
@@ -1071,6 +1109,18 @@ class PDSS2TTransformerEncoder(FairseqEncoder):
inter_ctc_logits = []
xctc_logit = None
inter_xctc_logits = []
# Early exit during inference
org_bsz = x.size(1)
batch_idx_dict = dict()
inter_ctc_logits_history = dict()
final_ctc_logits = dict()
final_encoder_padding_mask = dict()
early_exit_layer = dict()
for i in range(x.size(1)):
inter_ctc_logits_history[i] = []
batch_idx_dict[i] = i
for i in range(self.pds_stages):
downsampling = getattr(self, f"downsampling{i + 1}")
pos_embed = getattr(self, f"pos_embed{i + 1}")
@@ -1088,6 +1138,7 @@ class PDSS2TTransformerEncoder(FairseqEncoder):
x, encoder_padding_mask
)
if downsampling is not None:
x, input_lengths = downsampling(x, input_lengths)
encoder_padding_mask = lengths_to_padding_mask_with_maxlen(
input_lengths, x.size(0)
@@ -1172,6 +1223,19 @@ class PDSS2TTransformerEncoder(FairseqEncoder):
inter_ctc_logits.append([logit.clone(), encoder_padding_mask])
if not self.training and self.early_exit_count != 0:
predicts = ctc.predict(logit, encoder_padding_mask)
if len(inter_ctc_logits) < self.early_exit_count:
for i in range(x.size(1)):
inter_ctc_logits_history[i].append(predicts[i])
else:
if org_bsz == 1:
early_exit_flag = self.early_exit_or_not(inter_ctc_logits_history[0], predicts[0], self.early_exit_count)
if early_exit_flag:
ctc_logit = logit
self.early_exit_layer_record.append(layer_idx)
break
# Inter XCTC
if xctc is not None:
norm_x = xctc_norm(x)
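
Putting the pieces together, a stripped-down, self-contained sketch of the early-exit path added to forward() for a single-utterance batch (org_bsz == 1); layer_logits stands in for the per-layer inter-CTC logits and the predict argument for ctc.predict(logit, encoder_padding_mask), both of which are placeholders:

# Illustration only, not part of the commit.
import torch

def run_with_early_exit(layer_logits, predict, early_exit_count):
    history, ctc_logit, exit_layer = [], None, None
    for layer_idx, logit in enumerate(layer_logits, start=1):
        pred = predict(logit)
        history.append(pred)
        # same criterion as early_exit_or_not: stop once at least
        # early_exit_count earlier layers produced an identical prediction
        hits = sum(1 for p in history[:-1]
                   if p.size() == pred.size() and (p == pred).all())
        if early_exit_count != 0 and hits >= early_exit_count:
            ctc_logit, exit_layer = logit, layer_idx
            break
        ctc_logit = logit          # otherwise keep the last layer's logit
    return ctc_logit, exit_layer

# toy check: layers 2-4 all predict the same sequence, so with count 2 the
# loop exits at layer 4
same = torch.tensor([[0.0, 1.0], [1.0, 0.0]])
layer_logits = [torch.tensor([[1.0, 0.0], [0.0, 1.0]]), same, same.clone(), same.clone()]
greedy = lambda logit: logit.argmax(dim=-1)
print(run_with_early_exit(layer_logits, greedy, early_exit_count=2)[1])   # -> 4

In the commit itself this path is switched on by passing early_exit_count to set_ctc_infer, while the new early_exit_layer flag in set_flag selects a fixed layer's logit for inference; dump() then writes the recorded exit layers to the given stream.
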