libei / FairseqDecoder · Commits

Commit e186b036, authored Mar 05, 2019 by libei
parent b97a0cfd

    revise share all embedding bugs and inference bugs
Showing 2 changed files, with 32 additions and 30 deletions:

  fairseq/utils.py                     +29  -29
  scripts/convert_t2t_to_fairseq.py    +3   -1
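Read together, the two hunks below make one fix each: fairseq/utils.py retires the old load_ensemble_for_inference (kept commented out) and promotes the reworked load_ensemble_for_inference_2 under the original name, which is the inference fix; scripts/convert_t2t_to_fairseq.py drops a stray '$' anchor from the shared-embedding regex and additionally records src_vocab_size and emb_size, which is the share-all-embedding fix.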
fairseq/utils.py

@@ -131,36 +131,36 @@ def _upgrade_state_dict(state):
     return state
 
-def load_ensemble_for_inference(filenames, task, model_arg_overrides=None):
-    """Load an ensemble of models for inference.
-
-    model_arg_overrides allows you to pass a dictionary model_arg_overrides --
-    {'arg_name': arg} -- to override model args that were used during model
-    training
-    """
-    # load model architectures and weights
-    states = []
-    for filename in filenames:
-        if not os.path.exists(filename):
-            raise IOError('Model file not found: {}'.format(filename))
-        state = torch.load(filename, map_location=lambda s, l: default_restore_location(s, 'cpu'))
-        #state = _upgrade_state_dict(state)
-        states.append(state)
-    args = states[0]['args']
-    if model_arg_overrides is not None:
-        args = _override_model_args(args, model_arg_overrides)
-
-    # build ensemble
-    ensemble = []
-    for state in states:
-        model = task.build_model(args)
-        print(model)
-        model.upgrade_state_dict(state['model'])
-        model.load_state_dict(state['model'], strict=True)
-        ensemble.append(model)
-    return ensemble, args
+# def load_ensemble_for_inference(filenames, task, model_arg_overrides=None):
+#     """Load an ensemble of models for inference.
+#
+#     model_arg_overrides allows you to pass a dictionary model_arg_overrides --
+#     {'arg_name': arg} -- to override model args that were used during model
+#     training
+#     """
+#     # load model architectures and weights
+#     states = []
+#     for filename in filenames:
+#         if not os.path.exists(filename):
+#             raise IOError('Model file not found: {}'.format(filename))
+#         state = torch.load(filename, map_location=lambda s, l: default_restore_location(s, 'cpu'))
+#         #state = _upgrade_state_dict(state)
+#         states.append(state)
+#     args = states[0]['args']
+#     if model_arg_overrides is not None:
+#         args = _override_model_args(args, model_arg_overrides)
+#
+#     # build ensemble
+#     ensemble = []
+#     for state in states:
+#         model = task.build_model(args)
+#         print(model)
+#         model.upgrade_state_dict(state['model'])
+#         model.load_state_dict(state['model'], strict=True)
+#         ensemble.append(model)
+#     return ensemble, args
 
-def load_ensemble_for_inference_2(filenames, task, model_arg_overrides=None):
+def load_ensemble_for_inference(filenames, task, model_arg_overrides=None):
     """Load an ensemble of models for inference.
 
     model_arg_overrides allows you to pass a dictionary model_arg_overrides --
     ...
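After this hunk, callers reach the reworked loader under its original name. A minimal usage sketch, assuming the stock fairseq task API of that era (fairseq.tasks.setup_task) and hypothetical checkpoint paths -- none of this appears in the commit itself, and the fork's entry points may differ:

import torch
from fairseq import tasks, utils

# Hypothetical checkpoint paths; any fairseq-style .pt files would do.
filenames = ['checkpoints/model1.pt', 'checkpoints/model2.pt']

# Rebuild the task from the args stored in the first checkpoint
# (setup_task is the standard fairseq entry point; an assumption here).
state = torch.load(filenames[0], map_location='cpu')
task = tasks.setup_task(state['args'])

# Load all models, overriding one stored training arg as an example.
ensemble, args = utils.load_ensemble_for_inference(
    filenames, task, model_arg_overrides={'max_source_positions': 1024})

for model in ensemble:
    model.eval()  # inference mode: disables dropout etc.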
scripts/convert_t2t_to_fairseq.py

@@ -91,10 +91,12 @@ def find_useful_param(model_file):
             share_decoder_input_and_softmax = True
             tgt_vocab_size = int(match.group(1))
-        match = re.match('symbol_modality_(\d+)_(\d+)/shared/$', key)
+        match = re.match('symbol_modality_(\d+)_(\d+)/shared/', key)
         if match and share_all_embedding is False:
             share_all_embedding = True
+            src_vocab_size = int(match.group(1))
             tgt_vocab_size = int(match.group(1))
+            emb_size = int(match.group(2))
     except Exception as e:  # pylint: disable=broad-except
         print(str(e))
     ...
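This hunk is the share-all-embedding fix named in the commit message: with the trailing '$', re.match only accepts keys that end exactly at 'shared/', so real tensor2tensor variable names, which continue past that prefix, never matched and the shared-embedding case was silently skipped. The new lines also capture src_vocab_size and emb_size from the same match. A self-contained sketch of the before/after behavior -- the sample key is illustrative, not taken from the commit:

import re

# Illustrative T2T-style variable name (checkpoints append shard
# suffixes such as 'weights_0' after 'shared/').
key = 'symbol_modality_32768_512/shared/weights_0'

old = re.match(r'symbol_modality_(\d+)_(\d+)/shared/$', key)
new = re.match(r'symbol_modality_(\d+)_(\d+)/shared/', key)

print(old)                          # None: '$' demands the string end at 'shared/'
print(new.group(1), new.group(2))   # '32768' '512' -> vocab size, embedding size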