Commit a64cdfcc by xuchen

update pds: share intermediate CTC heads and adapters across PDS stages by vocabulary size, and switch the CTC decoder's FLOPs measurement to the DeepSpeed profiler

parent e59c8eb4
......@@ -403,9 +403,9 @@ class PDSS2TTransformerEncoder(FairseqEncoder):
inter_ctc_mlo = getattr(args, "inter_ctc_mlo", "")
if inter_ctc_mlo != "":
inter_ctc_mlo = [int(x) for x in inter_ctc_mlo.split(":")]
if self.share_inter_ctc is True:
self.share_inter_ctc = False
logger.info("Overwrite the config share_inter_ctc to False for MLO.")
# if self.share_inter_ctc is True:
# self.share_inter_ctc = False
# logger.info("Overwrite the config share_inter_ctc to False for MLO.")
# PDS XCTC
args.pds_xctc = getattr(args, "pds_xctc", None)
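For reference, a minimal sketch of how the colon-separated inter_ctc_mlo option above is consumed; the value "1:2:3" is purely illustrative, not a recommended setting. Each entry later selects which source dictionary an intermediate CTC head predicts over, via get_source_dictionary(inter_ctc_mlo[ctc_idx] - 1).

inter_ctc_mlo = "1:2:3"  # illustrative value only
if inter_ctc_mlo != "":
    inter_ctc_mlo = [int(x) for x in inter_ctc_mlo.split(":")]
print(inter_ctc_mlo)  # [1, 2, 3]; stage ctc_idx uses get_source_dictionary(inter_ctc_mlo[ctc_idx] - 1)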
......@@ -416,9 +416,9 @@ class PDSS2TTransformerEncoder(FairseqEncoder):
)
self.share_inter_xctc = getattr(args, "share_inter_xctc", False)
ctc_dict = dict()
ctc_pae_dict = dict()
ctc_idx = 0
inter_ctc = None
inter_ctc_pae = None
xctc_idx = 0
inter_xctc = None
......@@ -644,76 +644,44 @@ class PDSS2TTransformerEncoder(FairseqEncoder):
# Inter CTC
if use_ctc:
ctc_norm = LayerNorm(embed_dim)
if not self.share_inter_ctc:
ctc = CTC(
embed_dim,
dictionary_size=len(
task.get_source_dictionary(
inter_ctc_mlo[ctc_idx] - 1
if inter_ctc_mlo != ""
else -1
)
),
dropout=args.dropout,
dictionary=task.source_dictionary
)
inter_ctc = ctc
else:
ctc = CTC(
embed_dim,
dictionary_size=len(task.source_dictionary),
dropout=args.dropout,
dictionary=task.source_dictionary
)
vocab = task.get_source_dictionary(
inter_ctc_mlo[ctc_idx] - 1 if inter_ctc_mlo != "" else -1
)
vocab_size = len(vocab)
ctc = CTC(
embed_dim,
dictionary_size=vocab_size,
dropout=args.dropout,
dictionary=vocab
)
if vocab_size not in ctc_dict:
ctc_dict[vocab_size] = ctc
if self.share_inter_ctc:
if (
getattr(args, "share_ctc_and_embed", False)
and task.source_dictionary == task.target_dictionary
and vocab == task.target_dictionary
and embed_tokens is not None
and embed_dim == embed_tokens.embedding_dim
):
ctc.ctc_projection.weight = embed_tokens.weight
if (
inter_ctc is not None
and ctc.ctc_projection.weight.shape
== inter_ctc.ctc_projection.weight.shape
):
ctc = inter_ctc
else:
inter_ctc = ctc
if vocab_size in ctc_dict:
logger.warning("Use the existing CTC.")
ctc = ctc_dict[vocab_size]
ctc_pae = None
if i != self.pds_stages - 1:
if not self.share_inter_ctc:
ctc_pae = Adapter(
embed_dim,
args.ctc_pae,
len(
task.get_source_dictionary(
inter_ctc_mlo[ctc_idx] - 1
if inter_ctc_mlo != ""
else -1
)
),
strategy=ctc_pae_strategy,
)
else:
ctc_pae = Adapter(
embed_dim,
args.ctc_pae,
len(task.get_source_dictionary(i)),
strategy=ctc_pae_strategy,
)
if (
inter_ctc_pae is not None
and ctc_pae.dim == inter_ctc_pae.dim
and ctc_pae.dict_size == inter_ctc_pae.dict_size
ctc_pae = Adapter(
embed_dim,
args.ctc_pae,
vocab_size,
strategy=ctc_pae_strategy,
)
if self.share_inter_ctc:
if (vocab_size in ctc_pae_dict
and ctc_pae.dim == ctc_pae_dict[vocab_size].dim
):
ctc_pae = inter_ctc_pae
else:
inter_ctc_pae = ctc_pae
else:
ctc_pae = None
ctc_pae = ctc_pae_dict[vocab_size]
ctc_idx += 1
else:
ctc = None
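To make the sharing logic in the hunk above easier to follow, here is a minimal, self-contained sketch of the pattern it introduces; SimpleCTCHead and get_or_create_ctc are illustrative names, not code from this repository. Heads are cached in a dict keyed by vocabulary size, so when share_inter_ctc is set, stages whose dictionaries have the same size reuse one CTC projection; the same idea is applied to the ctc_pae adapters via ctc_pae_dict.

import torch.nn as nn

class SimpleCTCHead(nn.Module):
    # Stand-in for the CTC module: dropout followed by a projection to the vocabulary.
    def __init__(self, embed_dim, vocab_size, dropout=0.1):
        super().__init__()
        self.dropout = nn.Dropout(dropout)
        self.ctc_projection = nn.Linear(embed_dim, vocab_size)

    def forward(self, x):
        return self.ctc_projection(self.dropout(x))

def get_or_create_ctc(ctc_dict, embed_dim, vocab_size, share=True):
    # Reuse the cached head for this vocabulary size when sharing is enabled,
    # otherwise build a new head and register it for later stages.
    if share and vocab_size in ctc_dict:
        return ctc_dict[vocab_size]
    ctc = SimpleCTCHead(embed_dim, vocab_size)
    ctc_dict.setdefault(vocab_size, ctc)
    return ctc

ctc_dict = {}
first = get_or_create_ctc(ctc_dict, 256, 1000)
second = get_or_create_ctc(ctc_dict, 256, 1000)
assert first is second  # same vocabulary size -> one shared head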
......@@ -838,6 +806,7 @@ class PDSS2TTransformerEncoder(FairseqEncoder):
if self.inter_ctc:
logger.info("Intermediate CTC loss in layer %d" % self.ctc_layer)
vocab_size = len(task.source_dictionary)
embed_dim = self.embed_dim
if self.inter_ctc:
ctc_layer = self.ctc_layer
......@@ -848,10 +817,10 @@ class PDSS2TTransformerEncoder(FairseqEncoder):
break
self.ctc = CTC(
embed_dim,
dictionary_size=len(task.source_dictionary),
dictionary_size=vocab_size,
dropout=args.dropout,
dictionary=task.source_dictionary,
need_layernorm=True if self.inter_ctc else False,
dictionary=task.source_dictionary
)
if (
......@@ -863,11 +832,11 @@ class PDSS2TTransformerEncoder(FairseqEncoder):
self.ctc.ctc_projection.weight = embed_tokens.weight
if (
inter_ctc is not None
vocab_size in ctc_dict
and self.ctc.ctc_projection.weight.shape
== inter_ctc.ctc_projection.weight.shape
== ctc_dict[vocab_size].ctc_projection.weight.shape
):
self.ctc.ctc_projection = inter_ctc.ctc_projection
self.ctc.ctc_projection = ctc_dict[vocab_size].ctc_projection
# XCTC
self.use_xctc = getattr(args, "disable_xctc", False) is False and getattr(args, "xctc_weight", 0) > 0
......
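The hunk above ties the top-level CTC projection to previously built modules when the shapes allow it: to the decoder embedding when share_ctc_and_embed is set, and otherwise to the cached intermediate head for the same vocabulary size. Below is a minimal sketch of the shape-guarded parameter tying; the names and sizes are illustrative, not taken from the repository.

import torch.nn as nn

embed_dim, vocab_size = 256, 1000
embed_tokens = nn.Embedding(vocab_size, embed_dim)  # stand-in for the target embedding
final_projection = nn.Linear(embed_dim, vocab_size, bias=False)

# Tie the CTC projection to the embedding only when the weight shapes match exactly;
# assigning the Parameter itself makes both modules train the same tensor.
if final_projection.weight.shape == embed_tokens.weight.shape:
    final_projection.weight = embed_tokens.weight

assert final_projection.weight is embed_tokens.weight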
......@@ -244,18 +244,30 @@ class CTCDecoder(object):
bsz, src_len = src_tokens.size()[:2]
if self.cal_flops:
from thop import profile
macs, encoder_outs = profile(self.model, inputs=(net_input.values()))
gmacs = macs / 1e9
logger.info("GMACs: %f. GFLOPs: %f" % (gmacs, gmacs * 2))
print("GMACs: %f. GFLOPs: %f" % (gmacs, gmacs * 2))
from torchprofile import profile_macs
macs = profile_macs(self.model, [src_tokens, src_lengths])
gmacs = macs / 1e9
logger.info("GMACs: %f. GFLOPs: %f" % (gmacs, gmacs * 2))
print("GMACs: %f. GFLOPs: %f" % (gmacs, gmacs * 2))
# from thop import profile
# macs, encoder_outs = profile(self.model, inputs=[src_tokens, src_lengths])
# gmacs = macs / 1e9
# logger.info("GMACs: %f. GFLOPs: %f" % (gmacs, gmacs * 2))
# print("GMACs: %f. GFLOPs: %f" % (gmacs, gmacs * 2))
# from torchprofile import profile_macs
# macs = profile_macs(self.model, [src_tokens, src_lengths])
# gmacs = macs / 1e9
# logger.info("GMACs: %f. GFLOPs: %f" % (gmacs, gmacs * 2))
# print("GMACs: %f. GFLOPs: %f" % (gmacs, gmacs * 2))
from deepspeed.profiling.flops_profiler import get_model_profile
from deepspeed.accelerator import get_accelerator
with get_accelerator().device(0):
flops, macs, params = get_model_profile(model=self.model,
kwargs={"src_tokens": src_tokens, "src_lengths": src_lengths},
print_profile=True,
detailed=True,
)
logger.info("flops: %s. macs: %s, params: %s" % (flops, macs, params))
print("flops: %s. macs: %s, params: %s" % (flops, macs, params))
exit()
encoder_outs = self.model(src_tokens=src_tokens,
src_lengths=src_lengths)
......
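This hunk replaces the commented-out thop/torchprofile measurements with DeepSpeed's FLOPs profiler. Below is a minimal, self-contained sketch of the same get_model_profile call on a toy module; TinyEncoder is a hypothetical stand-in for the real S2T model, and DeepSpeed must be installed. The profiler runs one forward pass with the given kwargs and returns FLOPs, MACs, and parameter counts as formatted strings.

import torch
import torch.nn as nn
from deepspeed.profiling.flops_profiler import get_model_profile

class TinyEncoder(nn.Module):
    # Hypothetical stand-in: embeds source tokens and projects back to the vocabulary.
    def __init__(self, vocab_size=1000, embed_dim=256):
        super().__init__()
        self.embed = nn.Embedding(vocab_size, embed_dim)
        self.proj = nn.Linear(embed_dim, vocab_size)

    def forward(self, src_tokens, src_lengths=None):
        return self.proj(self.embed(src_tokens))

model = TinyEncoder()
src_tokens = torch.randint(0, 1000, (2, 50))
src_lengths = torch.full((2,), 50)

# Same call pattern as in the decoder above: named inputs are passed through kwargs.
flops, macs, params = get_model_profile(
    model=model,
    kwargs={"src_tokens": src_tokens, "src_lengths": src_lengths},
    print_profile=True,
    detailed=False,
)
print("flops: %s. macs: %s, params: %s" % (flops, macs, params))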