Commit 8b47f344 by xuchen

Add options to freeze selected encoder/decoder modules

parent 6c5436e5
...@@ -229,6 +229,18 @@ class S2TTransformerModel(FairseqEncoderDecoderModel): ...@@ -229,6 +229,18 @@ class S2TTransformerModel(FairseqEncoderDecoderModel):
metavar="STR", metavar="STR",
help="model to take decoder weights from (for initialization)", help="model to take decoder weights from (for initialization)",
) )
parser.add_argument(
"--encoder-freeze-module",
type=str,
metavar="STR",
help="freeze the module of the encoder",
)
parser.add_argument(
"--decoder-freeze-module",
type=str,
metavar="STR",
help="freeze the module of the decoder",
)
pass pass
@classmethod @classmethod
...@@ -273,7 +285,14 @@ class S2TTransformerModel(FairseqEncoderDecoderModel): ...@@ -273,7 +285,14 @@ class S2TTransformerModel(FairseqEncoderDecoderModel):
task.target_dictionary, args.decoder_embed_dim task.target_dictionary, args.decoder_embed_dim
) )
encoder = cls.build_encoder(args, task, decoder_embed_tokens) encoder = cls.build_encoder(args, task, decoder_embed_tokens)
if getattr(args, "encoder_freeze_module", None):
utils.freeze_parameters(encoder, args.encoder_freeze_module)
logging.info("freeze the encoder module: {}".format(args.encoder_freeze_module))
decoder = cls.build_decoder(args, task, decoder_embed_tokens) decoder = cls.build_decoder(args, task, decoder_embed_tokens)
if getattr(args, "decoder_freeze_module", None):
utils.freeze_parameters(decoder, args.decoder_freeze_module)
logging.info("freeze the decoder module: {}".format(args.decoder_freeze_module))
return cls(encoder, decoder) return cls(encoder, decoder)
def get_normalized_probs( def get_normalized_probs(
......
...@@ -738,3 +738,14 @@ def eval_bool(x, default=False): ...@@ -738,3 +738,14 @@ def eval_bool(x, default=False):
return bool(eval(x)) return bool(eval(x))
except TypeError: except TypeError:
return default return default
def freeze_parameters(module, freeze_module_name):
    """Disable gradients for parameters whose name matches any given pattern.

    Args:
        module: object exposing ``named_parameters()`` that yields
            ``(name, parameter)`` pairs (e.g. a ``torch.nn.Module``).
        freeze_module_name: comma-separated string of name fragments. A
            parameter is frozen when any fragment occurs as a substring of
            its qualified name. Surrounding whitespace is ignored and empty
            entries (e.g. from a trailing comma) are skipped.

    Returns:
        None. Matching parameters get ``requires_grad = False`` in place.
    """
    if not freeze_module_name:
        return

    # Strip whitespace and drop empty fragments: an empty string is a
    # substring of every name and would otherwise freeze the whole module.
    patterns = [part.strip() for part in freeze_module_name.split(",")]
    patterns = [part for part in patterns if part]
    if not patterns:
        return

    # Single pass over the parameters instead of one pass per pattern.
    for key, param in module.named_parameters():
        if any(pattern in key for pattern in patterns):
            param.requires_grad = False
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 登录 后发表评论