libei / WMT19-1.0.14 · Commits

Commit 6097530a authored Feb 15, 2019 by libei

remove file

parent 037d45dc
Showing 1 changed file with 0 additions and 570 deletions.

tensor2tensor/models/libei.py  +0 −570  deleted 100644 → 0
#!/usr/bin/env python
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import copy

from tensorflow.python.ops import init_ops
import six
from six.moves import xrange  # pylint: disable=redefined-builtin

from transformer.utils.data_parallel import data_parallelism
from transformer.layers.hparams import *
from transformer.layers import modules
from transformer.layers import attention_layer
from transformer.layers import ffn_layer
from transformer.utils import beam_search
from transformer.optimizer import optimize, learning_rate
class Transformer(object):

    def __init__(self, hparams, vocab_tables, batched_input, mode):
        """ The Transformer model including Beam search.
        Args:
            hparams: Total params.
            vocab_tables: a dict, from dataset.
            batched_input: a iterator, shape [src_len, tgt_len].
            mode: a string, TRAIN or INFER or EVAL.
        """
        tf.logging.info("Creating graph.")
        self._hparams = eval(hparams.model_params)(hparams)
        self._mode = mode
        self._decode_gpu = hparams.decode_gpu
        tf.logging.info("batch size is %d" % self._hparams.batch_size)
        tf.logging.info("Optimizer is %s" % self._hparams.optimizer)
        if self._mode != "TRAIN":
            self._hparams.residual_dropout = 0.0
            self._hparams.attention_dropout = 0.0
            self._hparams.relu_dropout = 0.0
            self._hparams.dropout = 0.0
            self._hparams.worker_gpu = 1
            self._hparams.batch_size = hparams.decode_batch_size
            self._data_parallelism = data_parallelism(
                self._hparams.worker_gpu, self._decode_gpu)
        else:
            self._data_parallelism = data_parallelism(self._hparams.worker_gpu)
        self._devices = self._data_parallelism.devices
        self._num_datashards = self._data_parallelism.n
        self._vocab_tables = vocab_tables
        tf.train.get_or_create_global_step()
        tf.get_variable_scope().set_initializer(
            init_ops.variance_scaling_initializer(
                self._hparams.initializer_gain, mode="fan_avg", distribution="uniform"))

        features = {}
        # source: shape [batch, src_len]
        features["inputs"] = batched_input.source
        self.inputs = features["inputs"]
        # target: shape [batch, tgt_len]
        features["targets"] = batched_input.target
        self.targets = features["targets"]
        features["label"] = batched_input.target[:, 1:]
        self.label = features["label"]

        if self._mode == "INFER":
            self.result = self._beam_decode(features)
            tf.logging.info("Create decoding graph finished.")
        elif self._mode == "EVAL":
            self.batch_loss, self.word_num = self.sharded_model(features)
        else:
            sharded_logits, training_loss = self.sharded_model(features)
            modules.print_all_weights()
            with tf.name_scope("learing_rate"):
                lr = learning_rate.learning_rate_schedule(self._hparams)
                tf.summary.scalar("learning_rate", lr / 500)
            train_op = optimize.optimize(training_loss, lr, self._hparams)
            self.training_loss = training_loss
            self.train_op = train_op
            tf.logging.info("Create training graph finished.")
    def sharded_model(self, features):
        # sharded_features: a dict, e.g. source shape [worker_gpu, batch_size, src_len]
        sharded_features = self._shard_features(features)
        # shape [worker_gpu, batch_size, tgt_len, tgt_vocab]
        sharded_logits = self._data_parallelism(
            self.encode_decoder,
            sharded_features["inputs"],
            sharded_features["targets"])
        if self._mode != "INFER":
            loss_num, loss_den = self._data_parallelism(
                modules.padded_cross_entropy,
                sharded_logits,
                sharded_features["label"],
                self._hparams.label_smoothing,
                weights_fn=modules.weights_nonzero)
            training_loss = tf.add_n(loss_num) / tf.maximum(1.0, tf.add_n(loss_den))
        if self._mode == "TRAIN":
            return sharded_logits, training_loss
        elif self._mode == "EVAL":
            batch_loss = tf.add_n(loss_num)
            word_num = tf.add_n(loss_den)
            return batch_loss, word_num
        else:
            return sharded_logits
    def encode_decoder(self, inputs, targets):
        """ The whole Transformer model including encoder and decoder.
        Args:
            inputs: source ids, shape [batch_size, src_len]
            targets: target ids, shape [batch_size, tgt_len]
        Other Variables:
            encoder_input: embedded input with pos, shape [batch_size, src_len, hidden_size]
            encoder_attention_bias: encoder mask-matrix(-1e9), shape [batch_size, src_len, hidden_size]
            encoder_output: encoder_output, shape [batch_size, src_len, hidden_size]
            decoder_input: embedded decoder input with pos, shape [batch_size, tgt_len, hidden_size]
            decoder_self_attention_bias: lower triangle mask-matrix(-1e9), shape [1, 1, tgt_len, tgt_len]
        Returns:
            logits: logits(before softmax), shape [batch_size, tgt_len, tgt_vocab]
        """
        encoder_input, encoder_attention_bias = self.prepare_encoder(inputs)
        encoder_output = self.encoder(encoder_input, encoder_attention_bias)
        decoder_input, decoder_self_attention_bias = self.prepare_decoder(targets)
        logits = self.decoder(decoder_input, decoder_self_attention_bias,
                              encoder_output, encoder_attention_bias)
        return logits
    def prepare_encoder(self, inputs):
        """input embedding, positional embedding and mask-matrix.
        Args:
            inputs: encoder input ids, shape [batch_size, src_len]
        Other Variables:
            inputs_embedding: input embedding with padding, shape [batch_size, src_len, hidden_size]
            positional_encoding: positional embedding, shape [batch_size, src_len, hidden_size]
        Returns:
            encoder_input: embedded input with pos, shape [batch_size, src_len, hidden_size]
            encoder_self_attention_bias: encoder mask-matrix(-1e9), shape [batch_size, src_len, hidden_size]
        """
        with tf.variable_scope("input_embedding", reuse=False):
            inputs_embedding = self.input_embedding(inputs, self._vocab_tables["src_size"])
        self.input_emb = inputs_embedding
        encoder_self_attention_bias = modules.attention_bias_ignore_padding(inputs)
        positional_encoding = modules.add_positional_encoding(
            tf.shape(inputs)[1], self._hparams.hidden_size)
        encoder_input = inputs_embedding + positional_encoding
        return encoder_input, encoder_self_attention_bias
    def encoder(self, encoder_input, encoder_attention_bias):
        """ transformer encoder.
        Args:
            encoder_input: embedded input with pos, shape [batch_size, src_len, hidden_size]
            encoder_attention_bias: encoder mask-matrix(-1e9), shape [batch_size, src_len, hidden_size]
        Returns:
            encoder_output: encoder_output, shape [batch_size, src_len, hidden_size]
        """
        hparams = copy.copy(self._hparams)
        if hparams.residual_dropout:
            encoder_input = tf.nn.dropout(encoder_input, 1.0 - hparams.residual_dropout)
        encoder_output = self.transformer_encoder(
            encoder_input, encoder_attention_bias, hparams)
        return encoder_output
    def transformer_encoder(self, encoder_input, encoder_self_attention_bias,
                            hparams, name="encoder"):
        """ transformer encoder.
        Args:
            encoder_input: embedded input with pos, shape [batch_size, src_len, hidden_size]
            residual_fn: residual function.
            encoder_self_attention_bias: encoder mask-matrix(-1e9), shape [batch_size, src_len, hidden_size]
            hparams: model params
            name: variable name
        Returns:
            x: encoder_output, shape [batch_size, src_len, hidden_size]
        """
        x = encoder_input
        residual_dropout_broadcast_dims = (
            modules.comma_separated_string_to_integer_list(
                getattr(hparams, "residual_dropout_broadcast_dims", "")))
        attention_dropout_broadcast_dims = (
            modules.comma_separated_string_to_integer_list(
                getattr(hparams, "attention_dropout_broadcast_dims", "")))
        with tf.variable_scope(name):
            for layer in xrange(hparams.encoder_layers):
                with tf.variable_scope("layer_%d" % layer):
                    # self-attention network
                    residual = x
                    x = self.may_be_layernorm(x, before=True, name="self_attention_before")
                    x = attention_layer.multihead_attention(
                        x,
                        None,
                        encoder_self_attention_bias,
                        hparams.hidden_size,
                        hparams.hidden_size,
                        hparams.hidden_size,
                        hparams.num_heads,
                        hparams.attention_dropout,
                        dropout_broadcast_dims=attention_dropout_broadcast_dims,
                        name="encoder_self_attention")
                    x = modules.dropout_with_broadcast_dims(
                        x, 1.0 - self._hparams.residual_dropout,
                        broadcast_dims=residual_dropout_broadcast_dims)
                    x = residual + x
                    x = self.may_be_layernorm(x, after=True, name="self_attention_after")
                    # feed-forward network
                    residual = x
                    x = self.may_be_layernorm(x, before=True, name="ffn_before")
                    x = ffn_layer.transformer_ffn_layer(x, hparams)
                    x = modules.dropout_with_broadcast_dims(
                        x, 1.0 - self._hparams.residual_dropout,
                        broadcast_dims=residual_dropout_broadcast_dims)
                    x = residual + x
                    x = self.may_be_layernorm(x, after=True, name="ffn_after")
            if self._hparams.normalize_before:
                x = self.may_be_layernorm(x, before=True, name="norm_last")
            return x
    def prepare_decoder(self, targets):
        """input embedding, positional embedding and lower triangle mask-matrix.
        Args:
            targets: decoder input ids, shape [batch_size, tgt_len]
        Other Variables:
            targets_embedding: decoder input embedding with padding,
                shape [batch_size, 0 + tgt_len - EOS_ID, hidden_size]
            positional_encoding: positional embedding, shape [batch_size, tgt_len, hidden_size]
        Returns:
            decoder_input: embedded decoder input with pos, shape [batch_size, tgt_len, hidden_size]
            decoder_attention_bias: lower triangle mask-matrix(-1e9), shape [1, 1, tgt_len, tgt_len]
        """
        with tf.variable_scope("target_embedding", reuse=tf.AUTO_REUSE):
            targets_embedding = self.input_embedding(
                targets, self._vocab_tables["tgt_size"], reuse=tf.AUTO_REUSE)
        self.tgt_emb = targets_embedding
        decoder_attention_bias = modules.attention_bias_lower_triangle(tf.shape(targets)[1])
        # targets_embedding = modules.shift_left_3d(targets_embedding)
        positional_encoding = modules.add_positional_encoding(
            tf.shape(targets)[1], self._hparams.hidden_size)
        decoder_input = targets_embedding + positional_encoding
        return decoder_input, decoder_attention_bias
    def decoder(self, decoder_input, decoder_self_attention_bias, encoder_output,
                encoder_attention_bias, cache=None):
        """
        Args:
            decoder_input: embedded decoder input with pos, shape [batch_size, tgt_len, hidden_size]
            decoder_self_attention_bias: lower triangle mask-matrix(-1e9), shape [1, 1, tgt_len, tgt_len]
            encoder_output: encoder_output, shape [batch_size, src_len, hidden_size]
            encoder_attention_bias: encoder mask-matrix(-1e9), shape [batch_size, src_len, hidden_size]
            cache: cache for fast decoding.
        Other Variables:
            decoder_output: decoder_output, shape [batch_size, tgt_len, hidden_size]
        Returns:
            logits: decoder output embedding, shape [batch_size, tgt_len, tgt_vocab]
        """
        hparams = copy.copy(self._hparams)
        if hparams.residual_dropout:
            decoder_input = tf.nn.dropout(decoder_input, 1.0 - hparams.residual_dropout)
        decoder_output = self.transformer_decoder(
            decoder_input, encoder_output, decoder_self_attention_bias,
            encoder_attention_bias, hparams, cache=cache)
        if self._hparams.share_decoder_input_output:
            with tf.variable_scope("target_embedding", reuse=tf.AUTO_REUSE):
                logits = self.shared_output_embedding(decoder_output)
        else:
            with tf.variable_scope("output_embedding", reuse=False):
                logits = self.output_embedding(decoder_output)
        return logits
    def transformer_decoder(self, decoder_input, encoder_output, decoder_self_attention_bias,
                            encoder_decoder_attention_bias, hparams, cache=None, name="decoder"):
        """ transformer decoder.
        Args:
            decoder_input: embedded decoder input with pos, shape [batch_size, tgt_len, hidden_size]
            encoder_output: encoder_output, shape [batch_size, src_len, hidden_size]
            residual_fn: residual function.
            decoder_self_attention_bias: lower triangle mask-matrix(-1e9), shape [1, 1, tgt_len, tgt_len]
            encoder_decoder_attention_bias: encoder mask-matrix(-1e9), shape [batch_size, src_len, hidden_size]
            hparams: model params.
            cache: cache for fast decoding.
            name: variable name.
        Returns:
            x: decoder_output, shape [batch_size, tgt_len, hidden_size]
        """
        x = decoder_input
        residual_dropout_broadcast_dims = (
            modules.comma_separated_string_to_integer_list(
                getattr(hparams, "residual_dropout_broadcast_dims", "")))
        attention_dropout_broadcast_dims = (
            modules.comma_separated_string_to_integer_list(
                getattr(hparams, "attention_dropout_broadcast_dims", "")))
        with tf.variable_scope(name):
            for layer in xrange(hparams.decoder_layers):
                layer_name = "layer_%d" % layer
                layer_cache = cache[layer_name] if cache is not None else None
                with tf.variable_scope("layer_%d" % layer):
                    # self-attention network
                    residual = x
                    x = self.may_be_layernorm(x, before=True, name="self_attention_before")
                    x = attention_layer.multihead_attention(
                        x,
                        None,
                        decoder_self_attention_bias,
                        hparams.hidden_size,
                        hparams.hidden_size,
                        hparams.hidden_size,
                        hparams.num_heads,
                        hparams.attention_dropout,
                        dropout_broadcast_dims=attention_dropout_broadcast_dims,
                        cache=layer_cache,
                        name="decoder_self_attention")
                    x = modules.dropout_with_broadcast_dims(
                        x, 1.0 - self._hparams.residual_dropout,
                        broadcast_dims=residual_dropout_broadcast_dims)
                    x = residual + x
                    x = self.may_be_layernorm(x, after=True, name="self_attention_after")
                    # encoder-decoder-attention network
                    residual = x
                    x = self.may_be_layernorm(x, before=True, name="encdec_attention_before")
                    x = attention_layer.multihead_attention(
                        x,
                        encoder_output,
                        encoder_decoder_attention_bias,
                        hparams.hidden_size,
                        hparams.hidden_size,
                        hparams.hidden_size,
                        hparams.num_heads,
                        hparams.attention_dropout,
                        dropout_broadcast_dims=attention_dropout_broadcast_dims,
                        cache=layer_cache,
                        name="encoder_decoder_attention")
                    x = modules.dropout_with_broadcast_dims(
                        x, 1.0 - self._hparams.residual_dropout,
                        broadcast_dims=residual_dropout_broadcast_dims)
                    x = residual + x
                    x = self.may_be_layernorm(x, after=True, name="encdec_attention_after")
                    # feed-forward network
                    residual = x
                    x = self.may_be_layernorm(x, before=True, name="ffn_before")
                    x = ffn_layer.transformer_ffn_layer(x, hparams)
                    x = modules.dropout_with_broadcast_dims(
                        x, 1.0 - self._hparams.residual_dropout,
                        broadcast_dims=residual_dropout_broadcast_dims)
                    x = residual + x
                    x = self.may_be_layernorm(x, after=True, name="ffn_after")
            if self._hparams.normalize_before:
                x = self.may_be_layernorm(x, before=True, name="norm_last")
            return x
    def _beam_decode(self, features):
        beam_size = self._hparams.beam_size
        decode_length = self._hparams.extra_decode_length
        alpha = self._hparams.decode_alpha
        print("beam size: %d" % beam_size)
        print("alpha: %f" % alpha)
        print("decode length: %d" % decode_length)
        every_decode_length = decode_length + tf.reduce_sum(
            tf.to_int32(tf.cast(features["inputs"], dtype=bool)), axis=1)
        batch_size = tf.shape(features["inputs"])[0]
        initial_ids = tf.ones([batch_size], dtype=tf.int32)
        vocab_size = self._vocab_tables["tgt_size"]
        # Setting decode length to input length + decode_length
        decode_length = tf.shape(features["inputs"])[1] + tf.constant(decode_length)

        encoder_input, encoder_attention_bias = self.prepare_encoder(features["inputs"])
        encoder_output = self.encoder(encoder_input, encoder_attention_bias)
        decoder_attention_bias = modules.attention_bias_lower_triangle(decode_length)
        positional_encoding = modules.add_positional_encoding(
            decode_length, self._hparams.hidden_size)

        cache = {
            "layer_%d" % layer: {
                "k": modules.split_heads(
                    tf.zeros([batch_size, 0, self._hparams.hidden_size]),
                    self._hparams.num_heads),
                "v": modules.split_heads(
                    tf.zeros([batch_size, 0, self._hparams.hidden_size]),
                    self._hparams.num_heads),
            }
            for layer in range(self._hparams.decoder_layers)
        }
        for layer in range(self._hparams.decoder_layers):
            layer_name = "layer_%d" % layer
            with tf.variable_scope("decoder/%s/encoder_decoder_attention" % layer_name):
                combined = modules.conv(
                    encoder_output, 2 * self._hparams.hidden_size, name="kv_transform")
                k_encdec, v_encdec = tf.split(
                    combined,
                    [self._hparams.hidden_size, self._hparams.hidden_size],
                    axis=2)
                k_encdec = modules.split_heads(k_encdec, self._hparams.num_heads)
                v_encdec = modules.split_heads(v_encdec, self._hparams.num_heads)
            cache[layer_name]["k_encdec"] = k_encdec
            cache[layer_name]["v_encdec"] = v_encdec
        cache["encoder_output"] = encoder_output
        cache["encoder_attention_bias"] = encoder_attention_bias
        cache["every_decode_length"] = every_decode_length

        def symbols_to_logits_fn(ids, i, cache):
            """Go from ids to logits."""
            ids = ids[:, -1:]
            with tf.variable_scope("target_embedding", reuse=tf.AUTO_REUSE):
                targets_embedding = self.input_embedding(ids, self._vocab_tables["tgt_size"])
            decoder_input = targets_embedding + positional_encoding[:, i:i + 1, :]
            d_bias = decoder_attention_bias[:, :, i:i + 1, :i + 1]
            logits = self.decoder(
                decoder_input,
                d_bias,
                cache.get("encoder_output"),
                cache.get("encoder_attention_bias"),
                cache=cache)
            return logits, cache

        ids, scores = beam_search.beam_search(
            symbols_to_logits_fn,
            initial_ids,
            beam_size,
            decode_length,
            vocab_size,
            alpha,
            states=cache)
        result = ids[:, 0, 1:]
        result = self._vocab_tables["ids2tgt"].lookup(tf.to_int64(result))
        return result
    def input_embedding(self, x, vocab_size, name="input_emb", reuse=None):
        with tf.variable_scope(name, reuse=reuse):
            embedding = tf.get_variable(
                "weights",
                [vocab_size, self._hparams.hidden_size],
                initializer=tf.random_normal_initializer(
                    0.0, self._hparams.hidden_size ** -0.5))
            ret = tf.gather(embedding, x)
            if self._hparams.multiply_embedding_mode == "sqrt_depth":
                ret *= self._hparams.hidden_size ** 0.5
            ret *= tf.expand_dims(tf.to_float(tf.not_equal(x, 0)), -1)
            return ret
    def output_embedding(self, body_output):
        with tf.variable_scope("output_emb", reuse=False):
            embedding = tf.get_variable(
                "weights",
                [self._hparams.hidden_size, self._vocab_tables["tgt_size"]],
                initializer=tf.random_normal_initializer(
                    0.0, self._hparams.hidden_size ** -0.5))
            shape = tf.shape(body_output)[:-1]
            body_output = tf.reshape(body_output, [-1, self._hparams.hidden_size])
            logits = tf.matmul(body_output, embedding)
            logits = tf.reshape(
                logits, tf.concat([shape, [self._vocab_tables["tgt_size"]]], 0))
            # insert a channels dimension
            return logits
    def shared_output_embedding(self, body_output):
        with tf.variable_scope("input_emb", reuse=tf.AUTO_REUSE):
            output_embedding = tf.get_variable(
                "weights",
                [self._vocab_tables["tgt_size"], self._hparams.hidden_size],
                initializer=tf.random_normal_initializer(
                    0.0, self._hparams.hidden_size ** -0.5))
            shape = tf.shape(body_output)[:-1]
            body_output = tf.reshape(body_output, [-1, self._hparams.hidden_size])
            logits = tf.matmul(body_output, output_embedding, transpose_b=True)
            logits = tf.reshape(
                logits, tf.concat([shape, [self._vocab_tables["tgt_size"]]], 0))
            return logits
    def residual_fn(self, x, y, dropout_broadcast_dims=None, name=None):
        return modules.layer_norm(
            x + modules.dropout_with_broadcast_dims(
                y, 1.0 - self._hparams.residual_dropout,
                broadcast_dims=dropout_broadcast_dims),
            name=name)
    def may_be_layernorm(self, input, before=False, after=False, name=None):
        assert before ^ after
        if after ^ self._hparams.normalize_before:
            return modules.layer_norm(input, name=name)
        else:
            return input
    def _shard_features(self, features):  # pylint: disable=missing-docstring
        sharded_features = dict()
        for k, v in six.iteritems(features):
            v = tf.convert_to_tensor(v)
            if not v.shape.as_list():
                v = tf.expand_dims(v, axis=-1)
                v = tf.tile(v, [self._num_datashards])
            sharded_features[k] = self._data_parallelism(
                tf.identity, tf.split(v, self._num_datashards, 0))
        return sharded_features
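For reference, the deleted file's may_be_layernorm toggle implements the usual pre-norm / post-norm switch: a call site tagged before=True normalizes only when normalize_before is set, and a call site tagged after=True normalizes only when it is not. Below is a minimal, self-contained NumPy sketch of that selection logic; the layer_norm helper is a plain illustrative stand-in (no learned scale or bias), not the project's modules.layer_norm.

import numpy as np

def layer_norm(x, eps=1e-6):
    # Plain layer normalization over the last axis (illustrative only).
    mean = x.mean(axis=-1, keepdims=True)
    var = x.var(axis=-1, keepdims=True)
    return (x - mean) / np.sqrt(var + eps)

def may_be_layernorm(x, normalize_before, before=False, after=False):
    # Mirrors the XOR in the model: `before` normalizes only in the
    # pre-norm configuration, `after` only in the post-norm one.
    assert before ^ after
    if after ^ normalize_before:
        return layer_norm(x)
    return x

x = np.random.randn(2, 5, 8)
# Pre-norm block: normalization happens before the sub-layer.
pre = may_be_layernorm(x, normalize_before=True, before=True)
# Post-norm block: normalization happens after the residual add.
post = may_be_layernorm(x + pre, normalize_before=False, after=True)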
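Similarly, input_embedding scales the gathered vectors by hidden_size ** 0.5 when multiply_embedding_mode == "sqrt_depth" and zeroes out positions whose token id is 0 (padding). A small NumPy sketch of that lookup, with a toy embedding table standing in for the "weights" variable:

import numpy as np

hidden_size = 8
vocab_size = 16
rng = np.random.default_rng(0)
# Toy stand-in for the "weights" variable, drawn from N(0, hidden_size ** -0.5).
embedding = rng.normal(0.0, hidden_size ** -0.5, size=(vocab_size, hidden_size))

def input_embedding(ids, multiply_embedding_mode="sqrt_depth"):
    # Gather rows, optionally rescale by sqrt(depth), then mask padding (id 0).
    ret = embedding[ids]
    if multiply_embedding_mode == "sqrt_depth":
        ret *= hidden_size ** 0.5
    ret *= (ids != 0).astype(np.float32)[..., None]
    return ret

ids = np.array([[4, 7, 2, 0, 0]])    # one sentence with trailing padding
emb = input_embedding(ids)           # shape (1, 5, 8)
assert np.allclose(emb[0, 3:], 0.0)  # padded positions are zeroed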