Skip to content
项目
群组
代码片段
帮助
当前项目
正在载入...
登录 / 注册
切换导航面板
F
Fairseq-S2T
概览
Overview
Details
Activity
Cycle Analytics
版本库
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Charts
问题
0
Issues
0
列表
Board
标记
里程碑
合并请求
0
Merge Requests
0
CI / CD
CI / CD
流水线
作业
日程表
图表
维基
Wiki
代码片段
Snippets
成员
Collapse sidebar
Close sidebar
活动
图像
聊天
创建新问题
作业
提交
Issue Boards
Open sidebar
xuchen
Fairseq-S2T
Commits
8b47f344
Commit
8b47f344
authored
4 years ago
by
xuchen
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
implement the setting of freezing the modules
parent
6c5436e5
显示空白字符变更
内嵌
并排
正在显示
2 个修改的文件
包含
30 行增加
和
0 行删除
+30
-0
fairseq/models/speech_to_text/s2t_transformer.py
+19
-0
fairseq/utils.py
+11
-0
没有找到文件。
fairseq/models/speech_to_text/s2t_transformer.py
查看文件 @
8b47f344
...
...
@@ -229,6 +229,18 @@ class S2TTransformerModel(FairseqEncoderDecoderModel):
metavar
=
"STR"
,
help
=
"model to take decoder weights from (for initialization)"
,
)
parser
.
add_argument
(
"--encoder-freeze-module"
,
type
=
str
,
metavar
=
"STR"
,
help
=
"freeze the module of the encoder"
,
)
parser
.
add_argument
(
"--decoder-freeze-module"
,
type
=
str
,
metavar
=
"STR"
,
help
=
"freeze the module of the decoder"
,
)
pass
@classmethod
...
...
@@ -273,7 +285,14 @@ class S2TTransformerModel(FairseqEncoderDecoderModel):
task
.
target_dictionary
,
args
.
decoder_embed_dim
)
encoder
=
cls
.
build_encoder
(
args
,
task
,
decoder_embed_tokens
)
if
getattr
(
args
,
"encoder_freeze_module"
,
None
):
utils
.
freeze_parameters
(
encoder
,
args
.
encoder_freeze_module
)
logging
.
info
(
"freeze the encoder module: {}"
.
format
(
args
.
encoder_freeze_module
))
decoder
=
cls
.
build_decoder
(
args
,
task
,
decoder_embed_tokens
)
if
getattr
(
args
,
"decoder_freeze_module"
,
None
):
utils
.
freeze_parameters
(
decoder
,
args
.
decoder_freeze_module
)
logging
.
info
(
"freeze the decoder module: {}"
.
format
(
args
.
decoder_freeze_module
))
return
cls
(
encoder
,
decoder
)
def
get_normalized_probs
(
...
...
This diff is collapsed.
Click to expand it.
fairseq/utils.py
查看文件 @
8b47f344
...
...
@@ -738,3 +738,14 @@ def eval_bool(x, default=False):
return
bool
(
eval
(
x
))
except
TypeError
:
return
default
def
freeze_parameters
(
module
,
freeze_module_name
):
def
freeze_module_params_by_name
(
module
,
name
):
for
key
,
value
in
module
.
named_parameters
():
if
name
in
key
:
value
.
requires_grad
=
False
freeze_module_name
=
freeze_module_name
.
split
(
","
)
for
name
in
freeze_module_name
:
freeze_module_params_by_name
(
module
,
name
)
This diff is collapsed.
Click to expand it.
编写
预览
Markdown
格式
0%
重试
或
添加新文件
添加附件
取消
您添加了
0
人
到此讨论。请谨慎行事。
请先完成此评论的编辑!
取消
请
注册
或者
登录
后发表评论