Commit b4e95869 by xuchen

code optimize

parent 4f452308
@@ -403,10 +403,6 @@ class PDSS2TTransformerEncoder(FairseqEncoder):
         inter_ctc_mlo = getattr(args, "inter_ctc_mlo", "")
         if inter_ctc_mlo != "":
             inter_ctc_mlo = [int(x) for x in inter_ctc_mlo.split(":")]
-            # assert len(inter_ctc_mlo.split(":")) - 1 == self.pds_stages, (
-            #     inter_ctc_mlo,
-            #     self.pds_stages,
-            # )
             if self.share_inter_ctc is True:
                 self.share_inter_ctc = False
                 logger.info("Overwrite the config share_inter_ctc to False for MLO.")
@@ -462,44 +458,47 @@ class PDSS2TTransformerEncoder(FairseqEncoder):
             "linear_init": getattr(args, "pae_linear_init", False)
         }
 
+        def get_pds_settings(list, idx, return_bool=False, default=None):
+            if list is None:
+                if return_bool:
+                    return False if default is None else default
+                else:
+                    return None if default is None else default
+            if idx >= len(list):
+                return list[-1] if default is None else default
+            else:
+                return list[idx]
+
         for i in range(self.pds_stages):
-            num_layers = self.pds_layers[i]
-            ds_ratio = self.pds_ratios[i]
+            num_layers = get_pds_settings(self.pds_layers, i, default=1)
+            ds_ratio = get_pds_settings(self.pds_ratios, i, default=0)
 
-            embed_dim = self.pds_embed_dims[i]
-            kernel_size = self.pds_kernel_sizes[i]
-            use_pos_embed = self.pds_position_embed[i]
-            use_ctc = self.pds_ctc[i] if self.pds_ctc is not None else False
-            use_xctc = self.pds_xctc[i] if self.pds_xctc is not None else False
+            embed_dim = get_pds_settings(self.pds_embed_dims, i)
+            kernel_size = get_pds_settings(self.pds_kernel_sizes, i)
+            use_pos_embed = get_pds_settings(self.pds_position_embed, i)
+            use_ctc = get_pds_settings(self.pds_ctc, i, True)
+            use_xctc = get_pds_settings(self.pds_xctc, i, True)
 
-            num_head = self.pds_attn_heads[i]
-            ffn_ratio = self.pds_ffn_ratios[i]
-            cnn_kernel_size = (
-                self.pds_cnn_kernel_sizes[i]
-                if self.pds_cnn_kernel_sizes is not None
-                else None
-            )
+            num_head = get_pds_settings(self.pds_attn_heads, i)
+            ffn_ratio = get_pds_settings(self.pds_ffn_ratios, i)
+            cnn_kernel_size = get_pds_settings(self.pds_cnn_kernel_sizes, i)
             attn_ds_ratio = (
-                self.pds_attn_ds_ratios[i]
-                if self.pds_conv_strides is not None and self.attn_type == "reduced"
-                else 1
+                get_pds_settings(self.pds_attn_ds_ratios, i, default=1)
+                if self.attn_type == "reduced" else 1
             )
-            conv_stride = (
-                self.pds_conv_strides[i] if self.pds_conv_strides is not None else 1
-            )
-            attn_stride = (
-                self.pds_attn_strides[i] if self.pds_attn_strides is not None else 1
-            )
+            conv_stride = get_pds_settings(self.pds_conv_strides, i, default=1)
+            attn_stride = get_pds_settings(self.pds_attn_strides, i, default=1)
             if conv_stride != 1 or attn_stride != 1:
                 expand_embed_dim = (
                     embed_dim
                     if i == self.pds_stages - 1
-                    else self.pds_embed_dims[i + 1]
+                    else get_pds_settings(self.pds_embed_dims, i + 1)
                 )
             else:
                 expand_embed_dim = None
-            fusion = self.pds_fusion_layers[i]
+            fusion = get_pds_settings(self.pds_fusion_layers, i)
 
             logger.info(
                 "The stage {}: layer {}, down-sample ratio {}, embed dim {}, "
@@ -534,13 +533,15 @@ class PDSS2TTransformerEncoder(FairseqEncoder):
             if ds_ratio == -1:
                 # downsampling = subsampling(args, embed_dim)
                 downsampling = subsampling(args)
+            elif ds_ratio == 0:
+                downsampling = None
             else:
                 downsampling = Downsampling(
                     self.pds_ds_method,
                     self.pds_embed_norm,
                     args.input_feat_per_channel * args.input_channels
                     if i == 0
-                    else self.pds_embed_dims[i - 1],
+                    else get_pds_settings(self.pds_embed_dims, i - 1),
                     embed_dim,
                     kernel_sizes=kernel_size,
                     stride=ds_ratio,
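The new elif branch lets a stage be configured with ds_ratio == 0, in which case no down-sampling module is built at all; the forward pass later guards the call with "if downsampling is not None" (see the 1088/1138 hunk below). A minimal sketch of the three cases, with a hypothetical helper standing in for the real module construction:

def downsampling_for(ds_ratio):
    # -1: shared input subsampling frontend; 0: no module, the stage keeps
    # its input resolution; otherwise a strided Downsampling module.
    if ds_ratio == -1:
        return "subsampling(args)"
    elif ds_ratio == 0:
        return None
    else:
        return f"Downsampling(stride={ds_ratio})"

print(downsampling_for(0) is None)   # True, so forward() must skip the call
print(downsampling_for(2))           # Downsampling(stride=2)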
@@ -654,6 +655,7 @@ class PDSS2TTransformerEncoder(FairseqEncoder):
                         )
                     ),
                     dropout=args.dropout,
+                    dictionary=task.source_dictionary
                 )
                 inter_ctc = ctc
             else:
@@ -661,6 +663,7 @@ class PDSS2TTransformerEncoder(FairseqEncoder):
                     embed_dim,
                     dictionary_size=len(task.source_dictionary),
                     dropout=args.dropout,
+                    dictionary=task.source_dictionary
                 )
                 if (
@@ -848,6 +851,7 @@ class PDSS2TTransformerEncoder(FairseqEncoder):
                 dictionary_size=len(task.source_dictionary),
                 dropout=args.dropout,
                 need_layernorm=True if self.inter_ctc else False,
+                dictionary=task.source_dictionary
             )
             if (
@@ -928,6 +932,10 @@ class PDSS2TTransformerEncoder(FairseqEncoder):
         self.gather_cos_sim_dis = 2
         self.cos_sim = dict()
 
+        self.early_exit_count = 0
+        self.early_exit_layer_record = []
+        self.early_exit_layer = 0
+
     def set_flag(self, **kwargs):
         for i in range(self.pds_stages):
             stage = getattr(self, f"stage{i + 1}")
@@ -939,11 +947,19 @@ class PDSS2TTransformerEncoder(FairseqEncoder):
         self.mixup_infer = kwargs.get("mixup_infer", False)
         self.gather_cos_sim = kwargs.get("gather_cos_sim", False)
         self.gather_cos_sim_dis = kwargs.get("gather_cos_sim_dis", 2)
+        self.early_exit_layer = kwargs.get("early_exit_layer", 0)
+        if self.early_exit_layer != 0:
+            logger.info("Using the logit in layer %d to infer." % self.early_exit_layer)
 
-        # if self.mixup_infer:
-        #     self.mixup_keep_org = True
+        if self.mixup_infer:
+            self.mixup_keep_org = True
 
     def dump(self, fstream, info=""):
+        print("Early exit layer.", file=fstream)
+        if self.early_exit_count != 0:
+            print("\n".join([str(l) for l in self.early_exit_layer_record]), file=fstream)
+
         for i in range(self.pds_stages):
             idx = 0
             stage = getattr(self, f"stage{i + 1}")
@@ -1016,8 +1032,9 @@ class PDSS2TTransformerEncoder(FairseqEncoder):
         return x, encoder_padding_mask, input_lengths, mixup
 
     def set_ctc_infer(
-        self, ctc_infer, post_process, src_dict=None, tgt_dict=None, path=None
+        self, ctc_infer, post_process, src_dict=None, tgt_dict=None, path=None, early_exit_count=0
     ):
+        self.early_exit_count = early_exit_count
         if hasattr(self, "ctc"):
             import os
 
             assert src_dict is not None
@@ -1040,6 +1057,27 @@ class PDSS2TTransformerEncoder(FairseqEncoder):
             logger.error("No ctc module in the encoder")
 
+    def early_exit_or_not(self, history, new_logit, count):
+        history.append(new_logit)
+        length = len(history)
+        if count == 0 or length < count:
+            return False
+        else:
+            # for logit in history[length - count: length - 1]:
+            #     if new_logit.size() != logit.size() or not (new_logit == logit).all():
+            #         return False
+            # return True
+            hit = 0
+            for logit in history[: length - 1]:
+                if new_logit.size() == logit.size() and (new_logit == logit).all():
+                    hit += 1
+            if hit >= count:
+                return True
+            else:
+                return False
+
     def forward(self, src_tokens, src_lengths, **kwargs):
         batch = src_tokens.size(0)
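early_exit_or_not, added above, decides whether inference may stop at an intermediate CTC layer: the newest greedy prediction is compared against every earlier prediction in the history, and the check passes once at least `count` of them match exactly in both shape and content. A toy reproduction of that rule on hypothetical predictions (torch assumed, as in the rest of the codebase):

import torch

def early_exit_or_not(history, new_pred, count):
    # Same rule as the method above, written as a free function for the demo:
    # count how many earlier predictions the new one matches exactly.
    history.append(new_pred)
    if count == 0 or len(history) < count:
        return False
    hit = sum(
        1
        for pred in history[:-1]
        if new_pred.size() == pred.size() and bool((new_pred == pred).all())
    )
    return hit >= count

history = []
print(early_exit_or_not(history, torch.tensor([5, 12, 7]), 2))     # False (history too short)
print(early_exit_or_not(history, torch.tensor([5, 12, 7, 9]), 2))  # False (lengths differ, no match)
print(early_exit_or_not(history, torch.tensor([5, 12, 7, 9]), 2))  # False (only one match so far)
print(early_exit_or_not(history, torch.tensor([5, 12, 7, 9]), 2))  # True  (matches two earlier layers)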
@@ -1071,6 +1109,18 @@ class PDSS2TTransformerEncoder(FairseqEncoder):
         inter_ctc_logits = []
         xctc_logit = None
         inter_xctc_logits = []
 
+        # Infer early exit
+        org_bsz = x.size(1)
+        batch_idx_dict = dict()
+        inter_ctc_logits_history = dict()
+        final_ctc_logits = dict()
+        final_encoder_padding_mask = dict()
+        early_exit_layer = dict()
+        for i in range(x.size(1)):
+            inter_ctc_logits_history[i] = []
+            batch_idx_dict[i] = i
+
         for i in range(self.pds_stages):
             downsampling = getattr(self, f"downsampling{i + 1}")
             pos_embed = getattr(self, f"pos_embed{i + 1}")
@@ -1088,6 +1138,7 @@ class PDSS2TTransformerEncoder(FairseqEncoder):
                     x, encoder_padding_mask
                 )
 
+            if downsampling is not None:
                 x, input_lengths = downsampling(x, input_lengths)
                 encoder_padding_mask = lengths_to_padding_mask_with_maxlen(
                     input_lengths, x.size(0)
@@ -1172,6 +1223,19 @@ class PDSS2TTransformerEncoder(FairseqEncoder):
                     inter_ctc_logits.append([logit.clone(), encoder_padding_mask])
 
+                    if not self.training and self.early_exit_count != 0:
+                        predicts = ctc.predict(logit, encoder_padding_mask)
+                        if len(inter_ctc_logits) < self.early_exit_count:
+                            for i in range(x.size(1)):
+                                inter_ctc_logits_history[i].append(predicts[i])
+                        else:
+                            if org_bsz == 1:
+                                early_exit_flag = self.early_exit_or_not(inter_ctc_logits_history[0], predicts[0], self.early_exit_count)
+                                if early_exit_flag:
+                                    ctc_logit = logit
+                                    self.early_exit_layer_record.append(layer_idx)
+                                    break
+
                 # Inter XCTC
                 if xctc is not None:
                     norm_x = xctc_norm(x)
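The hunk above wires the early exit into the intermediate-CTC loop of forward(): at inference, with early_exit_count != 0 and a batch of one utterance, each intermediate CTC layer's greedy prediction is appended to the history, and once the prediction has stabilised the current logit is promoted to ctc_logit, the exit layer is recorded, and the remaining layers are skipped. A self-contained simulation of that flow (hypothetical logits; argmax is only an assumption about what ctc.predict returns):

import torch

torch.manual_seed(0)

early_exit_count = 2
history, ctc_logit, exit_layer = [], None, None

# Fake (time, batch=1, vocab) logits for 12 intermediate CTC layers;
# from layer 8 on the output no longer changes.
stable = torch.randn(20, 1, 100)
layer_logits = [torch.randn(20, 1, 100) for _ in range(7)] + [stable] * 5

for layer_idx, logits in enumerate(layer_logits, start=1):
    predict = logits.argmax(dim=-1).squeeze(1)   # greedy stand-in for ctc.predict
    history.append(predict)
    hits = sum(
        1
        for p in history[:-1]
        if p.size() == predict.size() and bool((p == predict).all())
    )
    if len(history) >= early_exit_count and hits >= early_exit_count:
        ctc_logit, exit_layer = logits, layer_idx   # promote this layer's logit
        break

print(f"decoded from layer {exit_layer} of {len(layer_logits)}")   # expected: layer 10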