Commit a440c641 by libei

add transformer_dla.py to support Dynamic Layer Aggregation training

parent 90010cd3
@@ -3,7 +3,6 @@
<component name="ChangeListManager">
<list default="true" id="7d6d9926-f879-4708-ad8e-442bac96b62a" name="Default" comment="">
<change beforePath="$PROJECT_DIR$/.idea/workspace.xml" afterPath="$PROJECT_DIR$/.idea/workspace.xml" />
<change beforePath="$PROJECT_DIR$/tensor2tensor/models/transformer.py" afterPath="$PROJECT_DIR$/tensor2tensor/models/transformer.py" />
<change beforePath="$PROJECT_DIR$/tensor2tensor/models/transformer_dla.py" afterPath="$PROJECT_DIR$/tensor2tensor/models/transformer_dla.py" />
</list>
<option name="EXCLUDED_CONVERTED_TO_IGNORED" value="true" />
@@ -66,8 +65,8 @@
<file leaf-file-name="transformer_dla.py" pinned="false" current-in-tab="true">
<entry file="file://$PROJECT_DIR$/tensor2tensor/models/transformer_dla.py">
<provider selected="true" editor-type-id="text-editor">
<state relative-caret-position="-1723">
<caret line="209" column="0" lean-forward="false" selection-start-line="209" selection-start-column="0" selection-end-line="209" selection-end-column="0" />
<state relative-caret-position="379">
<caret line="239" column="5" lean-forward="true" selection-start-line="239" selection-start-column="5" selection-end-line="239" selection-end-column="5" />
<folding>
<element signature="e#738#776#0" expanded="true" />
</folding>
@@ -220,6 +219,7 @@
</component>
<component name="ToolWindowManager">
<frame x="-8" y="-8" width="1936" height="1056" extended-state="7" />
<editor active="true" />
<layout>
<window_info id="TODO" active="false" anchor="bottom" auto_hide="false" internal_type="DOCKED" type="DOCKED" visible="false" show_stripe_button="true" weight="0.33" sideWeight="0.5" order="11" side_tool="false" content_ui="tabs" />
<window_info id="Event Log" active="false" anchor="bottom" auto_hide="false" internal_type="DOCKED" type="DOCKED" visible="false" show_stripe_button="true" weight="0.33" sideWeight="0.5" order="0" side_tool="true" content_ui="tabs" />
@@ -422,8 +422,8 @@
</entry>
<entry file="file://$PROJECT_DIR$/tensor2tensor/models/transformer_dla.py">
<provider selected="true" editor-type-id="text-editor">
<state relative-caret-position="-1723">
<caret line="209" column="0" lean-forward="false" selection-start-line="209" selection-start-column="0" selection-end-line="209" selection-end-column="0" />
<state relative-caret-position="379">
<caret line="239" column="5" lean-forward="true" selection-start-line="239" selection-start-column="5" selection-end-line="239" selection-end-column="5" />
<folding>
<element signature="e#738#776#0" expanded="true" />
</folding>
@@ -172,7 +172,6 @@ def transformer_encoder(encoder_input,
    encoder_layer.add(x)
    for layer in xrange(hparams.encoder_layers):
      with tf.variable_scope("layer_%d" % layer):
        # self-attention network
        residual = x
        x = may_be_layernorm(x, hparams, before=True)
@@ -205,9 +204,14 @@ def transformer_encoder(encoder_input,
            broadcast_dims=residual_dropout_broadcast_dims)
        x = residual + x
        x = may_be_layernorm(x, hparams, after=True)
        # add the layer output to the history for dynamic layer aggregation
        with tf.variable_scope("layer_history"):
          encoder_layer.add(x)
        x = encoder_layer.pop()
    # with normalize_before, the final output also needs to be normalized
    if hparams.normalize_before:
      x = may_be_layernorm(x, hparams, before=True, name="norm_top")
    return x
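Note: the diff only shows the call sites of the layer history (`add()` records a layer's output, `pop()` returns the aggregated state); the history object itself is defined elsewhere in transformer_dla.py and is not part of this hunk. Below is a minimal sketch of one plausible implementation, assuming a learned softmax-weighted sum over all recorded outputs; the class name LayerHistory and the weighting scheme are assumptions, not the commit's actual code.

import tensorflow as tf


class LayerHistory(object):
  """Records each layer's output and aggregates the history on demand.

  Only the add()/pop() interface is taken from the diff above; the
  softmax-weighted sum below is one plausible reading of "dynamic
  layer aggregation", not the commit's actual implementation.
  """

  def __init__(self):
    self._layers = []

  def add(self, x):
    # Record one layer's output (the embedding may be recorded first,
    # as the decoder's hparams.use_emb branch suggests).
    self._layers.append(x)

  def pop(self):
    # Combine every output recorded so far with a learned,
    # softmax-normalized weight per layer. Called inside a per-layer
    # variable scope, each layer gets its own weight vector.
    n = len(self._layers)
    logits = tf.get_variable(
        "aggregation_logits", [n], initializer=tf.zeros_initializer())
    weights = tf.nn.softmax(logits)
    return tf.add_n([weights[i] * self._layers[i] for i in range(n)])

Under this reading, each layer consumes a learned mixture of all preceding layer outputs instead of only the previous one, which is the point of dynamic layer aggregation.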
@@ -246,6 +250,8 @@ def transformer_decoder(decoder_input,
  # Summaries don't work in multi-problem setting yet.
  summaries = "problems" not in hparams.values() or len(hparams.problems) == 1
  with tf.variable_scope(name):
    if hparams.use_emb:
      decoder_layer.add(x)
    for layer in xrange(hparams.decoder_layers):
      with tf.variable_scope("layer_%d" % layer):
        # self-attention network
@@ -300,6 +306,11 @@ def transformer_decoder(decoder_input,
            broadcast_dims=residual_dropout_broadcast_dims)
        x = residual + x
        x = may_be_layernorm(x, hparams, after=True)
        # add the layer output to the history for dynamic layer aggregation
        with tf.variable_scope("layer_history"):
          decoder_layer.add(x)
        x = decoder_layer.pop()
    # with normalize_before, the final output also needs to be normalized
    if hparams.normalize_before:
      x = may_be_layernorm(x, hparams, before=True, name="norm_top")
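Similarly, may_be_layernorm is only visible through its call sites. Judging from the before=True/after=True flags and the hparams.normalize_before check above, it presumably applies layer normalization only when the flag matches the configured pre-norm/post-norm variant; a sketch under that assumption follows (the body is guessed, only the signature comes from the diff):

import tensorflow as tf


def may_be_layernorm(x, hparams, before=False, after=False, name=None):
  # Assumed behavior: with normalize_before=True (pre-norm) only the
  # before=True call sites normalize; with normalize_before=False
  # (post-norm) only the after=True call sites do.
  assert before != after
  if hparams.normalize_before == before:
    return tf.contrib.layers.layer_norm(x, scope=name or "layer_norm")
  return x

That would also explain the final norm_top call: under pre-norm, the last layer's output leaves the stack unnormalized, so it receives one extra normalization on the way out.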