Commit a64cdfcc by xuchen

update pds

parent e59c8eb4
@@ -403,9 +403,9 @@ class PDSS2TTransformerEncoder(FairseqEncoder):
         inter_ctc_mlo = getattr(args, "inter_ctc_mlo", "")
         if inter_ctc_mlo != "":
             inter_ctc_mlo = [int(x) for x in inter_ctc_mlo.split(":")]
-            if self.share_inter_ctc is True:
-                self.share_inter_ctc = False
-                logger.info("Overwrite the config share_inter_ctc to False for MLO.")
+            # if self.share_inter_ctc is True:
+            #     self.share_inter_ctc = False
+            #     logger.info("Overwrite the config share_inter_ctc to False for MLO.")
 
         # PDS XCTC
         args.pds_xctc = getattr(args, "pds_xctc", None)
@@ -416,9 +416,9 @@ class PDSS2TTransformerEncoder(FairseqEncoder):
         )
         self.share_inter_xctc = getattr(args, "share_inter_xctc", False)
 
+        ctc_dict = dict()
+        ctc_pae_dict = dict()
         ctc_idx = 0
-        inter_ctc = None
-        inter_ctc_pae = None
         xctc_idx = 0
         inter_xctc = None
@@ -644,76 +644,44 @@ class PDSS2TTransformerEncoder(FairseqEncoder):
             # Inter CTC
             if use_ctc:
                 ctc_norm = LayerNorm(embed_dim)
-                if not self.share_inter_ctc:
-                    ctc = CTC(
-                        embed_dim,
-                        dictionary_size=len(
-                            task.get_source_dictionary(
-                                inter_ctc_mlo[ctc_idx] - 1
-                                if inter_ctc_mlo != ""
-                                else -1
-                            )
-                        ),
-                        dropout=args.dropout,
-                        dictionary=task.source_dictionary
-                    )
-                    inter_ctc = ctc
-                else:
-                    ctc = CTC(
-                        embed_dim,
-                        dictionary_size=len(task.source_dictionary),
-                        dropout=args.dropout,
-                        dictionary=task.source_dictionary
-                    )
+                vocab = task.get_source_dictionary(inter_ctc_mlo[ctc_idx] - 1
+                                                   if inter_ctc_mlo != "" else -1)
+                vocab_size = len(vocab)
+                ctc = CTC(
+                    embed_dim,
+                    dictionary_size=vocab_size,
+                    dropout=args.dropout,
+                    dictionary=vocab
+                )
+                if vocab_size not in ctc_dict:
+                    ctc_dict[vocab_size] = ctc
 
+                if self.share_inter_ctc:
                     if (
                         getattr(args, "share_ctc_and_embed", False)
-                        and task.source_dictionary == task.target_dictionary
+                        and vocab == task.target_dictionary
                         and embed_tokens is not None
                         and embed_dim == embed_tokens.embedding_dim
                     ):
                         ctc.ctc_projection.weight = embed_tokens.weight
 
-                    if (
-                        inter_ctc is not None
-                        and ctc.ctc_projection.weight.shape
-                        == inter_ctc.ctc_projection.weight.shape
-                    ):
-                        ctc = inter_ctc
-                    else:
-                        inter_ctc = ctc
+                    if vocab_size in ctc_dict:
+                        logger.warning("Use the existing CTC.")
+                        ctc = ctc_dict[vocab_size]
 
+                ctc_pae = None
                 if i != self.pds_stages - 1:
-                    if not self.share_inter_ctc:
-                        ctc_pae = Adapter(
-                            embed_dim,
-                            args.ctc_pae,
-                            len(
-                                task.get_source_dictionary(
-                                    inter_ctc_mlo[ctc_idx] - 1
-                                    if inter_ctc_mlo != ""
-                                    else -1
-                                )
-                            ),
-                            strategy=ctc_pae_strategy,
-                        )
-                    else:
-                        ctc_pae = Adapter(
-                            embed_dim,
-                            args.ctc_pae,
-                            len(task.get_source_dictionary(i)),
-                            strategy=ctc_pae_strategy,
-                        )
-                        if (
-                            inter_ctc_pae is not None
-                            and ctc_pae.dim == inter_ctc_pae.dim
-                            and ctc_pae.dict_size == inter_ctc_pae.dict_size
-                        ):
-                            ctc_pae = inter_ctc_pae
-                        else:
-                            inter_ctc_pae = ctc_pae
-                else:
-                    ctc_pae = None
+                    ctc_pae = Adapter(
+                        embed_dim,
+                        args.ctc_pae,
+                        vocab_size,
+                        strategy=ctc_pae_strategy,)
+
+                    if self.share_inter_ctc:
+                        if (vocab_size in ctc_pae_dict
+                            and ctc_pae.dim == ctc_pae_dict[vocab_size].dim
+                        ):
+                            ctc_pae = ctc_pae_dict[vocab_size]
                 ctc_idx += 1
             else:
                 ctc = None
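Note: the change above replaces the single `inter_ctc` / `inter_ctc_pae` pair with `ctc_dict` / `ctc_pae_dict`, keyed by vocabulary size, so that when `share_inter_ctc` is enabled, stages whose source dictionaries have the same size reuse one CTC head (and one adapter) instead of building new ones. Below is a minimal, self-contained sketch of that sharing scheme; `CTCHead` is a stand-in for the real `CTC` module and the sizes are made up:

    import torch.nn as nn

    class CTCHead(nn.Module):
        # Stand-in for the CTC module: a projection to the vocabulary.
        def __init__(self, embed_dim, vocab_size, dropout=0.1):
            super().__init__()
            self.ctc_projection = nn.Linear(embed_dim, vocab_size)
            self.dropout = nn.Dropout(dropout)

        def forward(self, x):
            return self.ctc_projection(self.dropout(x)).log_softmax(dim=-1)

    embed_dim = 256
    stage_vocab_sizes = [8000, 8000, 16000]    # hypothetical per-stage dictionary sizes
    share_inter_ctc = True

    ctc_dict = {}                              # vocabulary size -> shared CTC head
    stage_ctcs = []
    for vocab_size in stage_vocab_sizes:
        ctc = CTCHead(embed_dim, vocab_size)
        if vocab_size not in ctc_dict:
            ctc_dict[vocab_size] = ctc
        if share_inter_ctc:
            ctc = ctc_dict[vocab_size]         # reuse the cached head
        stage_ctcs.append(ctc)

    assert stage_ctcs[0] is stage_ctcs[1]      # equal sizes share one module
    assert stage_ctcs[0] is not stage_ctcs[2]  # different sizes stay separate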
@@ -838,6 +806,7 @@ class PDSS2TTransformerEncoder(FairseqEncoder):
             if self.inter_ctc:
                 logger.info("Intermediate CTC loss in layer %d" % self.ctc_layer)
+            vocab_size = len(task.source_dictionary)
 
             embed_dim = self.embed_dim
             if self.inter_ctc:
                 ctc_layer = self.ctc_layer
...@@ -848,10 +817,10 @@ class PDSS2TTransformerEncoder(FairseqEncoder): ...@@ -848,10 +817,10 @@ class PDSS2TTransformerEncoder(FairseqEncoder):
break break
self.ctc = CTC( self.ctc = CTC(
embed_dim, embed_dim,
dictionary_size=len(task.source_dictionary), dictionary_size=vocab_size,
dropout=args.dropout, dropout=args.dropout,
dictionary=task.source_dictionary,
need_layernorm=True if self.inter_ctc else False, need_layernorm=True if self.inter_ctc else False,
dictionary=task.source_dictionary
) )
if ( if (
@@ -863,11 +832,11 @@ class PDSS2TTransformerEncoder(FairseqEncoder):
                 self.ctc.ctc_projection.weight = embed_tokens.weight
 
             if (
-                inter_ctc is not None
+                vocab_size in ctc_dict
                 and self.ctc.ctc_projection.weight.shape
-                == inter_ctc.ctc_projection.weight.shape
+                == ctc_dict[vocab_size].ctc_projection.weight.shape
             ):
-                self.ctc.ctc_projection = inter_ctc.ctc_projection
+                self.ctc.ctc_projection = ctc_dict[vocab_size].ctc_projection
 
         # XCTC
         self.use_xctc = getattr(args, "disable_xctc", False) is False and getattr(args, "xctc_weight", 0) > 0
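Note: for the final CTC, the hunks above key the lookup by `vocab_size` as well and reuse the cached projection via `self.ctc.ctc_projection = ctc_dict[vocab_size].ctc_projection`. Assigning the submodule like this ties the parameters, so the final and intermediate heads share one weight matrix. A small sketch of that effect, reusing the `CTCHead` stand-in from the note above with made-up sizes:

    import torch
    import torch.nn as nn

    class CTCHead(nn.Module):
        def __init__(self, embed_dim, vocab_size):
            super().__init__()
            self.ctc_projection = nn.Linear(embed_dim, vocab_size)

    inter_ctc = CTCHead(256, 8000)
    final_ctc = CTCHead(256, 8000)

    # Same projection shape, so the final head can reuse the intermediate one.
    if final_ctc.ctc_projection.weight.shape == inter_ctc.ctc_projection.weight.shape:
        final_ctc.ctc_projection = inter_ctc.ctc_projection

    assert final_ctc.ctc_projection.weight is inter_ctc.ctc_projection.weight
    with torch.no_grad():
        inter_ctc.ctc_projection.weight.zero_()
    # The change is visible through both heads: one shared parameter tensor.
    assert final_ctc.ctc_projection.weight.abs().sum() == 0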
...
@@ -244,18 +244,30 @@ class CTCDecoder(object):
         bsz, src_len = src_tokens.size()[:2]
 
         if self.cal_flops:
-            from thop import profile
-            macs, encoder_outs = profile(self.model, inputs=(net_input.values()))
-            gmacs = macs / 1e9
-            logger.info("GMACs: %f. GFLOPs: %f" % (gmacs, gmacs * 2))
-            print("GMACs: %f. GFLOPs: %f" % (gmacs, gmacs * 2))
+            # from thop import profile
+            # macs, encoder_outs = profile(self.model, inputs=[src_tokens, src_lengths])
+            # gmacs = macs / 1e9
+            # logger.info("GMACs: %f. GFLOPs: %f" % (gmacs, gmacs * 2))
+            # print("GMACs: %f. GFLOPs: %f" % (gmacs, gmacs * 2))
 
-            from torchprofile import profile_macs
-            macs = profile_macs(self.model, [src_tokens, src_lengths])
-            gmacs = macs / 1e9
-            logger.info("GMACs: %f. GFLOPs: %f" % (gmacs, gmacs * 2))
-            print("GMACs: %f. GFLOPs: %f" % (gmacs, gmacs * 2))
+            # from torchprofile import profile_macs
+            # macs = profile_macs(self.model, [src_tokens, src_lengths])
+            # gmacs = macs / 1e9
+            # logger.info("GMACs: %f. GFLOPs: %f" % (gmacs, gmacs * 2))
+            # print("GMACs: %f. GFLOPs: %f" % (gmacs, gmacs * 2))
+
+            from deepspeed.profiling.flops_profiler import get_model_profile
+            from deepspeed.accelerator import get_accelerator
+
+            with get_accelerator().device(0):
+                flops, macs, params = get_model_profile(model=self.model,
+                                                        kwargs={"src_tokens": src_tokens, "src_lengths": src_lengths},
+                                                        print_profile=True,
+                                                        detailed=True,
+                                                        )
+            logger.info("flops: %s. macs: %s, params: %s" % (flops, macs, params))
+            print("flops: %s. macs: %s, params: %s" % (flops, macs, params))
             exit()
 
         encoder_outs = self.model(src_tokens=src_tokens,
                                   src_lengths=src_lengths)
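Note: the profiling block switches from thop / torchprofile to DeepSpeed's FLOPs profiler, which accepts keyword inputs and can print a per-module breakdown. A minimal sketch of the same call on a stand-in model; it assumes `deepspeed` and `torch` are installed, the tiny model and input shapes are placeholders (not the actual S2T encoder), and the `get_accelerator().device(0)` context used in the commit is omitted so the sketch also runs on CPU:

    import torch
    import torch.nn as nn
    from deepspeed.profiling.flops_profiler import get_model_profile

    class TinyEncoder(nn.Module):
        def __init__(self, feat_dim=80, hidden=256):
            super().__init__()
            self.proj = nn.Linear(feat_dim, hidden)
            self.layer = nn.TransformerEncoderLayer(hidden, nhead=4, batch_first=True)

        def forward(self, src_tokens, src_lengths=None):
            # src_tokens: (batch, frames, feat_dim); src_lengths is ignored here
            return self.layer(self.proj(src_tokens))

    if __name__ == "__main__":
        model = TinyEncoder()
        src_tokens = torch.randn(2, 100, 80)      # (batch, frames, features)
        src_lengths = torch.tensor([100, 80])
        flops, macs, params = get_model_profile(
            model=model,
            kwargs={"src_tokens": src_tokens, "src_lengths": src_lengths},
            print_profile=True,
            detailed=False,
        )
        print("flops: %s, macs: %s, params: %s" % (flops, macs, params))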
...