Commit 6097530a by libei

remove file

parent 037d45dc
#!/usr/bin/env python
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import copy
import tensorflow as tf
from tensorflow.python.ops import init_ops
import six
from six.moves import xrange # pylint: disable=redefined-builtin
from transformer.utils.data_parallel import data_parallelism
from transformer.layers.hparams import *
from transformer.layers import modules
from transformer.layers import attention_layer
from transformer.layers import ffn_layer
from transformer.utils import beam_search
from transformer.optimizer import optimize, learning_rate
class Transformer(object):
def __init__(self,
hparams,
vocab_tables,
batched_input,
mode):
""" The Transformer model including Beam search.
Args:
hparams: Total params.
vocab_tables: a dict, from dataset.
batched_input: a iterator, shape [src_len, tgt_len].
mode: a string, TRAIN or INFER or EVAL.
"""
tf.logging.info("Creating graph.")
self._hparams = eval(hparams.model_params)(hparams)
self._mode = mode
self._decode_gpu = hparams.decode_gpu
tf.logging.info("batch size is %d" % self._hparams.batch_size)
tf.logging.info("Optimizer is %s" % self._hparams.optimizer)
if self._mode != "TRAIN":
self._hparams.residual_dropout = 0.0
self._hparams.attention_dropout = 0.0
self._hparams.relu_dropout = 0.0
self._hparams.dropout = 0.0
self._hparams.worker_gpu = 1
self._hparams.batch_size = hparams.decode_batch_size
self._data_parallelism = data_parallelism(self._hparams.worker_gpu, self._decode_gpu)
else:
self._data_parallelism = data_parallelism(self._hparams.worker_gpu)
self._devices = self._data_parallelism.devices
self._num_datashards = self._data_parallelism.n
self._vocab_tables = vocab_tables
tf.train.get_or_create_global_step()
tf.get_variable_scope().set_initializer(init_ops.variance_scaling_initializer(
self._hparams.initializer_gain, mode="fan_avg", distribution="uniform"))
features = {}
# source: shape [batch, src_len]
features["inputs"] = batched_input.source
self.inputs = features["inputs"]
# target: shape [batch, tgt_len]
features["targets"] = batched_input.target
self.targets = features["targets"]
features["label"] = batched_input.target[:, 1:]
self.label = features["label"]
if self._mode == "INFER":
self.result = self._beam_decode(features)
tf.logging.info("Create decoding graph finished.")
elif self._mode == "EVAL":
self.batch_loss, self.word_num = self.sharded_model(features)
else:
sharded_logits, training_loss = self.sharded_model(features)
modules.print_all_weights()
with tf.name_scope("learing_rate"):
lr = learning_rate.learning_rate_schedule(self._hparams)
tf.summary.scalar("learning_rate", lr / 500)
train_op = optimize.optimize(training_loss, lr, self._hparams)
self.training_loss = training_loss
self.train_op = train_op
tf.logging.info("Create training graph finished.")
def sharded_model(self, features):
# sharded_features: a dict, e.g. source shape [worker_gpu, batch_size, src_len]
sharded_features = self._shard_features(features)
# shape [worker_gpu, batch_size, tgt_len, tgt_vocab]
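# data_parallelism applies encode_decoder once per shard, each on its own GPU slice of the batch.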
sharded_logits = self._data_parallelism(self.encode_decoder,
sharded_features["inputs"],
sharded_features["targets"])
if self._mode != "INFER":
loss_num, loss_den = self._data_parallelism(
modules.padded_cross_entropy,
sharded_logits,
sharded_features["label"],
self._hparams.label_smoothing,
weights_fn=modules.weights_nonzero)
training_loss = tf.add_n(loss_num) / tf.maximum(1.0, tf.add_n(loss_den))
if self._mode == "TRAIN":
return sharded_logits, training_loss
elif self._mode == "EVAL":
batch_loss = tf.add_n(loss_num)
word_num = tf.add_n(loss_den)
return batch_loss, word_num
else:
return sharded_logits
def encode_decoder(self, inputs, targets):
""" The whole Transformer model including encoder and decoder.
Args:
inputs: source ids, shape [batch_size, src_len]
targets: target ids, shape [batch_size, tgt_len]
Other Variables:
encoder_input: embedded input plus positional encoding, shape [batch_size, src_len, hidden_size]
encoder_attention_bias: padding mask (-1e9 at padded positions), shape [batch_size, 1, 1, src_len]
encoder_output: encoder output, shape [batch_size, src_len, hidden_size]
decoder_input: embedded decoder input plus positional encoding, shape [batch_size, tgt_len, hidden_size]
decoder_self_attention_bias: lower-triangular mask (-1e9 above the diagonal), shape [1, 1, tgt_len, tgt_len]
Returns:
logits: logits(before softmax), shape [batch_size, tgt_len, tgt_vocab]
"""
encoder_input, encoder_attention_bias = self.prepare_encoder(inputs)
encoder_output = self.encoder(encoder_input, encoder_attention_bias)
decoder_input, decoder_self_attention_bias = self.prepare_decoder(targets)
logits = self.decoder(decoder_input,
decoder_self_attention_bias,
encoder_output,
encoder_attention_bias)
return logits
def prepare_encoder(self, inputs):
"""input embedding, positional embedding and mask-matrix.
Args:
inputs: encoder input ids, shape [batch_size, src_len]
Other Variables:
inputs_embedding: input embedding with padding, shape [batch_size, src_len, hidden_size]
positional_encoding: positional embedding, shape [batch_size, src_len, hidden_size]
Returns:
encoder_input: embedded input plus positional encoding, shape [batch_size, src_len, hidden_size]
encoder_self_attention_bias: padding mask (-1e9 at padded positions), shape [batch_size, 1, 1, src_len]
"""
with tf.variable_scope("input_embedding", reuse=False):
inputs_embedding = self.input_embedding(inputs, self._vocab_tables["src_size"])
self.input_emb = inputs_embedding
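# Padded source positions (id == 0) get a large negative bias so attention softmax ignores them.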
encoder_self_attention_bias = modules.attention_bias_ignore_padding(inputs)
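# Positional encodings are added to the token embeddings; add_positional_encoding is
# presumably the fixed sinusoidal encoding of "Attention Is All You Need",
# PE(pos, 2i) = sin(pos / 10000^(2i/d)), PE(pos, 2i+1) = cos(pos / 10000^(2i/d)).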
positional_encoding = modules.add_positional_encoding(tf.shape(inputs)[1], self._hparams.hidden_size)
encoder_input = inputs_embedding + positional_encoding
return encoder_input, encoder_self_attention_bias
def encoder(self,
encoder_input,
encoder_attention_bias):
""" transformer encoder.
Args:
encoder_input: embedded input plus positional encoding, shape [batch_size, src_len, hidden_size]
encoder_attention_bias: padding mask (-1e9 at padded positions), shape [batch_size, 1, 1, src_len]
Returns:
encoder_output: encoder_output, shape [batch_size, src_len, hidden_size]
"""
hparams = copy.copy(self._hparams)
if hparams.residual_dropout:
encoder_input = tf.nn.dropout(encoder_input, 1.0 - hparams.residual_dropout)
encoder_output = self.transformer_encoder(encoder_input,
encoder_attention_bias,
hparams)
return encoder_output
def transformer_encoder(self,
encoder_input,
encoder_self_attention_bias,
hparams,
name="encoder"):
""" transformer encoder.
Args:
encoder_input: embedded input plus positional encoding, shape [batch_size, src_len, hidden_size]
encoder_self_attention_bias: padding mask (-1e9 at padded positions), shape [batch_size, 1, 1, src_len]
hparams: model params
name: variable name
Returns:
x: encoder_output, shape [batch_size, src_len, hidden_size]
"""
x = encoder_input
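# Broadcasting a dropout mask over the listed dims (e.g. batch or length) saves memory
# while keeping the expected behaviour of ordinary dropout.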
residual_dropout_broadcast_dims = (
modules.comma_separated_string_to_integer_list(
getattr(hparams, "residual_dropout_broadcast_dims", ""))
)
attention_dropout_broadcast_dims = (
modules.comma_separated_string_to_integer_list(
getattr(hparams, "attention_dropout_broadcast_dims", "")))
with tf.variable_scope(name):
for layer in xrange(hparams.encoder_layers):
with tf.variable_scope("layer_%d" % layer):
# self-attention network
residual = x
x = self.may_be_layernorm(x, before=True, name="self_attention_before")
x = attention_layer.multihead_attention(
x,
None,
encoder_self_attention_bias,
hparams.hidden_size,
hparams.hidden_size,
hparams.hidden_size,
hparams.num_heads,
hparams.attention_dropout,
dropout_broadcast_dims=attention_dropout_broadcast_dims,
name="encoder_self_attention")
x = modules.dropout_with_broadcast_dims(x,
1.0 - self._hparams.residual_dropout,
broadcast_dims=residual_dropout_broadcast_dims)
x = residual + x
x = self.may_be_layernorm(x, after=True, name="self_attention_after")
# feed-forward network
residual = x
x = self.may_be_layernorm(x, before=True, name="ffn_before")
x = ffn_layer.transformer_ffn_layer(x, hparams)
x = modules.dropout_with_broadcast_dims(x,
1.0 - self._hparams.residual_dropout,
broadcast_dims=residual_dropout_broadcast_dims)
x = residual + x
x = self.may_be_layernorm(x, after=True, name="ffn_after")
if self._hparams.normalize_before:
x = self.may_be_layernorm(x, before=True, name="norm_last")
return x
def prepare_decoder(self, targets):
"""input embedding, positional embedding and lower triangle mask-matrix.
Args:
targets: decoder input ids, shape [batch_size, tgt_len]
Other Variables:
targets_embedding: decoder input embedding with padding, shape [batch_size, tgt_len, hidden_size]
positional_encoding: positional embedding, shape [batch_size, tgt_len, hidden_size]
Returns:
decoder_input: embedded decoder input with pos, shape [batch_size, tgt_len, hidden_size]
decoder_attention_bias: lower triangle mask-matrix(-1e9), shape [1, 1, tgt_len, tgt_len]
"""
with tf.variable_scope("target_embedding", reuse=tf.AUTO_REUSE):
targets_embedding = self.input_embedding(targets, self._vocab_tables["tgt_size"], reuse=tf.AUTO_REUSE)
self.tgt_emb = targets_embedding
decoder_attention_bias = modules.attention_bias_lower_triangle(tf.shape(targets)[1])
# targets_embedding = modules.shift_left_3d(targets_embedding)
positional_encoding = modules.add_positional_encoding(tf.shape(targets)[1], self._hparams.hidden_size)
decoder_input = targets_embedding + positional_encoding
return decoder_input, decoder_attention_bias
def decoder(self,
decoder_input,
decoder_self_attention_bias,
encoder_output,
encoder_attention_bias,
cache=None):
"""
Args:
decoder_input: embedded decoder input with pos, shape [batch_size, tgt_len, hidden_size]
decoder_self_attention_bias: lower triangle mask-matrix(-1e9), shape [1, 1, tgt_len, tgt_len]
encoder_output: encoder_output, shape [batch_size, src_len, hidden_size]
encoder_attention_bias: padding mask (-1e9 at padded positions), shape [batch_size, 1, 1, src_len]
cache: cache for fast decoding.
Other Variables:
decoder_output: decoder_output, shape [batch_size, tgt_len, hidden_size]
Returns:
logits: output logits (before softmax), shape [batch_size, tgt_len, tgt_vocab]
"""
hparams = copy.copy(self._hparams)
if hparams.residual_dropout:
decoder_input = tf.nn.dropout(decoder_input, 1.0 - hparams.residual_dropout)
decoder_output = self.transformer_decoder(
decoder_input, encoder_output, decoder_self_attention_bias,
encoder_attention_bias, hparams, cache=cache)
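# Optionally tie the output projection to the target embedding matrix (weight sharing
# as in the original Transformer); otherwise use a separate output embedding.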
if self._hparams.share_decoder_input_output:
with tf.variable_scope("target_embedding", reuse=tf.AUTO_REUSE):
logits = self.shared_output_embedding(decoder_output)
else:
with tf.variable_scope("output_embedding", reuse=False):
logits = self.output_embedding(decoder_output)
return logits
def transformer_decoder(self,
decoder_input,
encoder_output,
decoder_self_attention_bias,
encoder_decoder_attention_bias,
hparams,
cache=None,
name="decoder"):
""" transformer decoder.
Args:
decoder_input: embedded decoder input with pos, shape [batch_size, tgt_len, hidden_size]
encoder_output: encoder output, shape [batch_size, src_len, hidden_size]
decoder_self_attention_bias: lower-triangular mask (-1e9 above the diagonal), shape [1, 1, tgt_len, tgt_len]
encoder_decoder_attention_bias: padding mask (-1e9 at padded positions), shape [batch_size, 1, 1, src_len]
hparams: model params.
cache: cache for fast decoding.
name: variable name.
Returns:
x: decoder_output, shape [batch_size, tgt_len, hidden_size]
"""
x = decoder_input
residual_dropout_broadcast_dims = (
modules.comma_separated_string_to_integer_list(
getattr(hparams, "residual_dropout_broadcast_dims", ""))
)
attention_dropout_broadcast_dims = (
modules.comma_separated_string_to_integer_list(
getattr(hparams, "attention_dropout_broadcast_dims", "")))
with tf.variable_scope(name):
for layer in xrange(hparams.decoder_layers):
layer_name = "layer_%d" % layer
layer_cache = cache[layer_name] if cache is not None else None
with tf.variable_scope("layer_%d" % layer):
# self-attention network
residual = x
x = self.may_be_layernorm(x, before=True, name="self_attention_before")
x = attention_layer.multihead_attention(
x,
None,
decoder_self_attention_bias,
hparams.hidden_size,
hparams.hidden_size,
hparams.hidden_size,
hparams.num_heads,
hparams.attention_dropout,
dropout_broadcast_dims=attention_dropout_broadcast_dims,
cache=layer_cache,
name="decoder_self_attention")
x = modules.dropout_with_broadcast_dims(x,
1.0 - self._hparams.residual_dropout,
broadcast_dims=residual_dropout_broadcast_dims)
x = residual + x
x = self.may_be_layernorm(x, after=True, name="self_attention_after")
# encoder-decoder-attention network
residual = x
x = self.may_be_layernorm(x, before=True, name="encdec_attention_before")
x = attention_layer.multihead_attention(
x,
encoder_output,
encoder_decoder_attention_bias,
hparams.hidden_size,
hparams.hidden_size,
hparams.hidden_size,
hparams.num_heads,
hparams.attention_dropout,
dropout_broadcast_dims=attention_dropout_broadcast_dims,
cache=layer_cache,
name="encoder_decoder_attention")
x = modules.dropout_with_broadcast_dims(x,
1.0 - self._hparams.residual_dropout,
broadcast_dims=residual_dropout_broadcast_dims)
x = residual + x
x = self.may_be_layernorm(x, after=True, name="encdec_attention_after")
# feed-forward network
residual = x
x = self.may_be_layernorm(x, before=True, name="ffn_before")
x = ffn_layer.transformer_ffn_layer(x, hparams)
x = modules.dropout_with_broadcast_dims(x,
1.0 - self._hparams.residual_dropout,
broadcast_dims=residual_dropout_broadcast_dims)
x = residual + x
x = self.may_be_layernorm(x, after=True, name="ffn_after")
if self._hparams.normalize_before:
x = self.may_be_layernorm(x, before=True, name="norm_last")
return x
def _beam_decode(self, features):
beam_size = self._hparams.beam_size
decode_length = self._hparams.extra_decode_length
alpha = self._hparams.decode_alpha
print("beam size: %d" % beam_size)
print("alpha: %f" % alpha)
print("decode length: %d" % decode_length)
every_decode_length = decode_length + tf.reduce_sum(
tf.to_int32(tf.not_equal(features["inputs"], 0)), axis=1)
batch_size = tf.shape(features["inputs"])[0]
initial_ids = tf.ones([batch_size], dtype=tf.int32)
vocab_size = self._vocab_tables["tgt_size"]
# Setting decode length to input length + decode_length
decode_length = tf.shape(features["inputs"])[1] + tf.constant(decode_length)
encoder_input, encoder_attention_bias = self.prepare_encoder(features["inputs"])
encoder_output = self.encoder(encoder_input, encoder_attention_bias)
decoder_attention_bias = modules.attention_bias_lower_triangle(decode_length)
positional_encoding = modules.add_positional_encoding(decode_length, self._hparams.hidden_size)
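# Per-layer self-attention cache: keys and values start empty along the time axis and are
# appended to at every decoding step, so earlier positions are never recomputed.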
cache = {
"layer_%d" % layer: {
"k":
modules.split_heads(
tf.zeros([batch_size, 0, self._hparams.hidden_size]), self._hparams.num_heads),
"v":
modules.split_heads(
tf.zeros([batch_size, 0, self._hparams.hidden_size]), self._hparams.num_heads),
} for layer in range(self._hparams.decoder_layers)
}
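# The encoder-decoder attention keys/values depend only on the encoder output, so they are
# computed once here and reused at every decoding step.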
for layer in range(self._hparams.decoder_layers):
layer_name = "layer_%d" % layer
with tf.variable_scope(
"decoder/%s/encoder_decoder_attention" % layer_name):
combined = modules.conv(
encoder_output,
2 * self._hparams.hidden_size,
name="kv_transform")
k_encdec, v_encdec = tf.split(combined, [self._hparams.hidden_size, self._hparams.hidden_size], axis=2)
k_encdec = modules.split_heads(k_encdec, self._hparams.num_heads)
v_encdec = modules.split_heads(v_encdec, self._hparams.num_heads)
cache[layer_name]["k_encdec"] = k_encdec
cache[layer_name]["v_encdec"] = v_encdec
cache["encoder_output"] = encoder_output
cache["encoder_attention_bias"] = encoder_attention_bias
cache["every_decode_length"] = every_decode_length
def symbols_to_logits_fn(ids, i, cache):
"""Go from ids to logits."""
ids = ids[:, -1:]
with tf.variable_scope("target_embedding", reuse=tf.AUTO_REUSE):
targets_embedding = self.input_embedding(ids, self._vocab_tables["tgt_size"])
decoder_input = targets_embedding + positional_encoding[:, i:i + 1, :]
d_bias = decoder_attention_bias[:, :, i:i + 1, :i + 1]
logits = self.decoder(decoder_input,
d_bias,
cache.get("encoder_output"),
cache.get("encoder_attention_bias"),
cache=cache)
return logits, cache
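# Beam search keeps beam_size hypotheses per sentence; alpha controls the length
# normalization applied to hypothesis scores (presumably the GNMT-style length penalty).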
ids, scores = beam_search.beam_search(symbols_to_logits_fn, initial_ids,
beam_size, decode_length, vocab_size,
alpha, states=cache)
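# Keep the top-scoring beam for each sentence and drop the initial id before mapping
# the ids back to target tokens.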
result = ids[:, 0, 1:]
result = self._vocab_tables["ids2tgt"].lookup(tf.to_int64(result))
return result
def input_embedding(self, x, vocab_size, name="input_emb", reuse=None):
with tf.variable_scope(name, reuse=reuse):
embedding = tf.get_variable("weights",
[vocab_size, self._hparams.hidden_size],
initializer=tf.random_normal_initializer(
0.0, self._hparams.hidden_size ** -0.5))
ret = tf.gather(embedding, x)
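# With multiply_embedding_mode == "sqrt_depth", scale embeddings by sqrt(hidden_size) as in
# the original Transformer; rows belonging to padding (id == 0) are zeroed out.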
if self._hparams.multiply_embedding_mode == "sqrt_depth":
ret *= self._hparams.hidden_size ** 0.5
ret *= tf.expand_dims(tf.to_float(tf.not_equal(x, 0)), -1)
return ret
def output_embedding(self, body_output):
with tf.variable_scope("output_emb", reuse=False):
embedding = tf.get_variable("weights",
[self._hparams.hidden_size, self._vocab_tables["tgt_size"]],
initializer=tf.random_normal_initializer(
0.0, self._hparams.hidden_size ** -0.5))
shape = tf.shape(body_output)[:-1]
body_output = tf.reshape(body_output, [-1, self._hparams.hidden_size])
logits = tf.matmul(body_output, embedding)
logits = tf.reshape(logits, tf.concat([shape, [self._vocab_tables["tgt_size"]]], 0))
return logits
def shared_output_embedding(self, body_output):
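# Output projection tied to the target input embedding: the same "input_emb/weights"
# matrix is reused (transposed) to produce the logits.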
with tf.variable_scope("input_emb", reuse=tf.AUTO_REUSE):
output_embedding = tf.get_variable("weights",
[self._vocab_tables["tgt_size"], self._hparams.hidden_size],
initializer=tf.random_normal_initializer(
0.0, self._hparams.hidden_size ** -0.5))
shape = tf.shape(body_output)[:-1]
body_output = tf.reshape(body_output, [-1, self._hparams.hidden_size])
logits = tf.matmul(body_output, output_embedding, transpose_b=True)
logits = tf.reshape(logits, tf.concat([shape, [self._vocab_tables["tgt_size"]]], 0))
return logits
def residual_fn(self, x, y, dropout_broadcast_dims=None, name=None):
return modules.layer_norm(x + modules.dropout_with_broadcast_dims(
y, 1.0 - self._hparams.residual_dropout, broadcast_dims=dropout_broadcast_dims), name=name)
def may_be_layernorm(self, input, before=False, after=False, name=None):
assert before ^ after
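# With normalize_before (pre-norm), the "before" call applies layer norm and the "after"
# call is a no-op; with post-norm the roles are reversed.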
if after ^ self._hparams.normalize_before:
return modules.layer_norm(input, name=name)
else:
return input
def _shard_features(self, features): # pylint: disable=missing-docstring
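# Split each feature along the batch dimension into one slice per data shard (GPU);
# scalar features are tiled so every shard receives a copy.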
sharded_features = dict()
for k, v in six.iteritems(features):
v = tf.convert_to_tensor(v)
if not v.shape.as_list():
v = tf.expand_dims(v, axis=-1)
v = tf.tile(v, [self._num_datashards])
sharded_features[k] = self._data_parallelism(tf.identity,
tf.split(
v, self._num_datashards,
0))
return sharded_features