xuchen / S2T / Commits

Commit a64cdfcc authored Mar 16, 2024 by xuchen
update pds
parent e59c8eb4
Showing 2 changed files with 61 additions and 80 deletions.
fairseq/models/speech_to_text/pdss2t_transformer.py (+38, -69)
fairseq/models/speech_to_text/s2t_ctc.py (+23, -11)
fairseq/models/speech_to_text/pdss2t_transformer.py (view file @ a64cdfcc)
...
@@ -403,9 +403,9 @@ class PDSS2TTransformerEncoder(FairseqEncoder):
         inter_ctc_mlo = getattr(args, "inter_ctc_mlo", "")
         if inter_ctc_mlo != "":
             inter_ctc_mlo = [int(x) for x in inter_ctc_mlo.split(":")]
-            if self.share_inter_ctc is True:
-                self.share_inter_ctc = False
-                logger.info("Overwrite the config share_inter_ctc to False for MLO.")
+            # if self.share_inter_ctc is True:
+            #     self.share_inter_ctc = False
+            #     logger.info("Overwrite the config share_inter_ctc to False for MLO.")

         # PDS XCTC
         args.pds_xctc = getattr(args, "pds_xctc", None)
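Note: inter_ctc_mlo arrives as a colon-separated string of 1-based dictionary indices; later hunks use inter_ctc_mlo[ctc_idx] - 1 to pick a source dictionary per PDS stage. A minimal sketch of the parsing this hunk keeps, with an illustrative Namespace standing in for the parsed fairseq args:

    # Sketch: parse a stage spec such as "1:2:3" into [1, 2, 3].
    from argparse import Namespace

    args = Namespace(inter_ctc_mlo="1:2:3")  # illustrative value
    inter_ctc_mlo = getattr(args, "inter_ctc_mlo", "")
    if inter_ctc_mlo != "":
        inter_ctc_mlo = [int(x) for x in inter_ctc_mlo.split(":")]
    print(inter_ctc_mlo)  # [1, 2, 3]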
...
@@ -416,9 +416,9 @@ class PDSS2TTransformerEncoder(FairseqEncoder):
         )
         self.share_inter_xctc = getattr(args, "share_inter_xctc", False)

+        ctc_dict = dict()
+        ctc_pae_dict = dict()
         ctc_idx = 0
-        inter_ctc = None
-        inter_ctc_pae = None
         xctc_idx = 0
         inter_xctc = None
...
@@ -644,76 +644,44 @@ class PDSS2TTransformerEncoder(FairseqEncoder):
             # Inter CTC
             if use_ctc:
                 ctc_norm = LayerNorm(embed_dim)
-                if not self.share_inter_ctc:
-                    ctc = CTC(
-                        embed_dim,
-                        dictionary_size=len(
-                            task.get_source_dictionary(
-                                inter_ctc_mlo[ctc_idx] - 1
-                                if inter_ctc_mlo != ""
-                                else -1
-                            )
-                        ),
-                        dropout=args.dropout,
-                        dictionary=task.source_dictionary
-                    )
-                    inter_ctc = ctc
-                else:
-                    ctc = CTC(
-                        embed_dim,
-                        dictionary_size=len(task.source_dictionary),
-                        dropout=args.dropout,
-                        dictionary=task.source_dictionary
-                    )
+                vocab = task.get_source_dictionary(inter_ctc_mlo[ctc_idx] - 1
+                                                   if inter_ctc_mlo != "" else -1)
+                vocab_size = len(vocab)
+                ctc = CTC(
+                    embed_dim,
+                    dictionary_size=vocab_size,
+                    dropout=args.dropout,
+                    dictionary=vocab
+                )
+                if vocab_size not in ctc_dict:
+                    ctc_dict[vocab_size] = ctc
                 if (
                     getattr(args, "share_ctc_and_embed", False)
-                    and task.source_dictionary == task.target_dictionary
+                    and vocab == task.target_dictionary
                     and embed_tokens is not None
                     and embed_dim == embed_tokens.embedding_dim
                 ):
                     ctc.ctc_projection.weight = embed_tokens.weight
-                if (
-                    inter_ctc is not None
-                    and ctc.ctc_projection.weight.shape
-                    == inter_ctc.ctc_projection.weight.shape
-                ):
-                    ctc = inter_ctc
-                else:
-                    inter_ctc = ctc
+                if self.share_inter_ctc:
+                    if vocab_size in ctc_dict:
+                        logger.warning("Use the existing CTC.")
+                        ctc = ctc_dict[vocab_size]

+                ctc_pae = None
                 if i != self.pds_stages - 1:
-                    if not self.share_inter_ctc:
-                        ctc_pae = Adapter(
-                            embed_dim,
-                            args.ctc_pae,
-                            len(
-                                task.get_source_dictionary(
-                                    inter_ctc_mlo[ctc_idx] - 1
-                                    if inter_ctc_mlo != ""
-                                    else -1
-                                )
-                            ),
-                            strategy=ctc_pae_strategy,
-                        )
-                    else:
-                        ctc_pae = Adapter(
-                            embed_dim,
-                            args.ctc_pae,
-                            len(task.get_source_dictionary(i)),
-                            strategy=ctc_pae_strategy,
-                        )
-                    if (
-                        inter_ctc_pae is not None
-                        and ctc_pae.dim == inter_ctc_pae.dim
-                        and ctc_pae.dict_size == inter_ctc_pae.dict_size
-                    ):
-                        ctc_pae = inter_ctc_pae
-                    else:
-                        inter_ctc_pae = ctc_pae
-                else:
-                    ctc_pae = None
+                    ctc_pae = Adapter(
+                        embed_dim,
+                        args.ctc_pae,
+                        vocab_size,
+                        strategy=ctc_pae_strategy,
+                    )
+                    if self.share_inter_ctc:
+                        if (
+                            vocab_size in ctc_pae_dict
+                            and ctc_pae.dim == ctc_pae_dict[vocab_size].dim
+                        ):
+                            ctc_pae = ctc_pae_dict[vocab_size]
                 ctc_idx += 1
             else:
                 ctc = None
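Note: the rewritten block swaps the single inter_ctc / inter_ctc_pae cache for the ctc_dict / ctc_pae_dict registries keyed by vocabulary size, so stages whose dictionaries have the same size can reuse one CTC module. A minimal sketch of that registry pattern, assuming a stand-in CTC head rather than the toolkit's class:

    import torch.nn as nn

    class TinyCTC(nn.Module):
        # Stand-in for the CTC head: a projection from embed_dim to the vocab.
        def __init__(self, embed_dim, vocab_size):
            super().__init__()
            self.ctc_projection = nn.Linear(embed_dim, vocab_size, bias=False)

    ctc_dict = dict()  # vocab_size -> first CTC built with that size
    share_inter_ctc = True

    def build_ctc(embed_dim, vocab_size):
        ctc = TinyCTC(embed_dim, vocab_size)
        if vocab_size not in ctc_dict:
            ctc_dict[vocab_size] = ctc  # register the first instance
        if share_inter_ctc and vocab_size in ctc_dict:
            ctc = ctc_dict[vocab_size]  # later stages pick up the shared one
        return ctc

    a = build_ctc(256, 100)
    b = build_ctc(256, 100)  # same vocab size -> same shared module
    c = build_ctc(256, 200)  # different vocab size -> fresh module
    assert a is b and a is not c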
...
@@ -838,6 +806,7 @@ class PDSS2TTransformerEncoder(FairseqEncoder):
         if self.inter_ctc:
             logger.info("Intermediate CTC loss in layer %d" % self.ctc_layer)

+        vocab_size = len(task.source_dictionary)
         embed_dim = self.embed_dim
         if self.inter_ctc:
             ctc_layer = self.ctc_layer
...
@@ -848,10 +817,10 @@ class PDSS2TTransformerEncoder(FairseqEncoder):
                     break

         self.ctc = CTC(
             embed_dim,
-            dictionary_size=len(task.source_dictionary),
+            dictionary_size=vocab_size,
             dropout=args.dropout,
-            dictionary=task.source_dictionary,
-            need_layernorm=True if self.inter_ctc else False,
+            need_layernorm=True if self.inter_ctc else False,
+            dictionary=task.source_dictionary
         )

         if (
...
@@ -863,11 +832,11 @@ class PDSS2TTransformerEncoder(FairseqEncoder):
                 self.ctc.ctc_projection.weight = embed_tokens.weight

             if (
-                inter_ctc is not None
+                vocab_size in ctc_dict
                 and self.ctc.ctc_projection.weight.shape
-                == inter_ctc.ctc_projection.weight.shape
+                == ctc_dict[vocab_size].ctc_projection.weight.shape
             ):
-                self.ctc.ctc_projection = inter_ctc.ctc_projection
+                self.ctc.ctc_projection = ctc_dict[vocab_size].ctc_projection

         # XCTC
         self.use_xctc = getattr(args, "disable_xctc", False) is False and getattr(args, "xctc_weight", 0) > 0
...
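Note: the last two hunks apply the same registry to the encoder-level CTC: when a cached CTC has a projection of matching shape, self.ctc.ctc_projection is replaced by the cached one, tying those weights. A minimal sketch of shape-guarded weight tying in the spirit of the share_ctc_and_embed branch (toy sizes, stand-in modules):

    import torch.nn as nn

    embed_dim, vocab_size = 256, 100
    embed_tokens = nn.Embedding(vocab_size, embed_dim)
    ctc_projection = nn.Linear(embed_dim, vocab_size, bias=False)

    # Both weights are (vocab_size, embed_dim); tie only when shapes agree.
    if ctc_projection.weight.shape == embed_tokens.weight.shape:
        ctc_projection.weight = embed_tokens.weight

    assert ctc_projection.weight is embed_tokens.weight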
fairseq/models/speech_to_text/s2t_ctc.py (view file @ a64cdfcc)
...
@@ -244,18 +244,30 @@ class CTCDecoder(object):
         bsz, src_len = src_tokens.size()[:2]

         if self.cal_flops:
-            from thop import profile
-            macs, encoder_outs = profile(self.model, inputs=(net_input.values()))
-            gmacs = macs / 1e9
-            logger.info("GMACs: %f. GFLOPs: %f" % (gmacs, gmacs * 2))
-            print("GMACs: %f. GFLOPs: %f" % (gmacs, gmacs * 2))
+            # from thop import profile
+            # macs, encoder_outs = profile(self.model, inputs=[src_tokens, src_lengths])
+            # gmacs = macs / 1e9
+            # logger.info("GMACs: %f. GFLOPs: %f" % (gmacs, gmacs * 2))
+            # print("GMACs: %f. GFLOPs: %f" % (gmacs, gmacs * 2))

-            from torchprofile import profile_macs
-            macs = profile_macs(self.model, [src_tokens, src_lengths])
-            gmacs = macs / 1e9
-            logger.info("GMACs: %f. GFLOPs: %f" % (gmacs, gmacs * 2))
-            print("GMACs: %f. GFLOPs: %f" % (gmacs, gmacs * 2))
+            # from torchprofile import profile_macs
+            # macs = profile_macs(self.model, [src_tokens, src_lengths])
+            # gmacs = macs / 1e9
+            # logger.info("GMACs: %f. GFLOPs: %f" % (gmacs, gmacs * 2))
+            # print("GMACs: %f. GFLOPs: %f" % (gmacs, gmacs * 2))
+
+            from deepspeed.profiling.flops_profiler import get_model_profile
+            from deepspeed.accelerator import get_accelerator
+
+            with get_accelerator().device(0):
+                flops, macs, params = get_model_profile(
+                    model=self.model,
+                    kwargs={"src_tokens": src_tokens, "src_lengths": src_lengths},
+                    print_profile=True,
+                    detailed=True,
+                )
+            logger.info("flops: %s. macs: %s, params: %s" % (flops, macs, params))
+            print("flops: %s. macs: %s, params: %s" % (flops, macs, params))
             exit()

         encoder_outs = self.model(src_tokens=src_tokens,
                                   src_lengths=src_lengths)
...
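Note: the commit replaces the thop and torchprofile measurements with DeepSpeed's FLOPs profiler, which returns FLOPs, MACs, and parameter count in one call. A standalone sketch on a toy module (requires the deepspeed package; the model and input sizes are illustrative, not the toolkit's):

    import torch
    import torch.nn as nn
    from deepspeed.profiling.flops_profiler import get_model_profile

    model = nn.Sequential(nn.Linear(80, 256), nn.ReLU(), nn.Linear(256, 100))

    flops, macs, params = get_model_profile(
        model=model,
        args=[torch.randn(8, 80)],  # positional inputs to model.forward
        print_profile=False,        # True prints the per-module breakdown
        detailed=False,
    )
    print("flops: %s. macs: %s, params: %s" % (flops, macs, params))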