Commit a440c641 by libei

add transformer_dla.py to support Dynamic Layer Aggregation training

parent 90010cd3
.idea/workspace.xml

@@ -3,7 +3,6 @@
   <component name="ChangeListManager">
     <list default="true" id="7d6d9926-f879-4708-ad8e-442bac96b62a" name="Default" comment="">
       <change beforePath="$PROJECT_DIR$/.idea/workspace.xml" afterPath="$PROJECT_DIR$/.idea/workspace.xml" />
-      <change beforePath="$PROJECT_DIR$/tensor2tensor/models/transformer.py" afterPath="$PROJECT_DIR$/tensor2tensor/models/transformer.py" />
       <change beforePath="$PROJECT_DIR$/tensor2tensor/models/transformer_dla.py" afterPath="$PROJECT_DIR$/tensor2tensor/models/transformer_dla.py" />
     </list>
     <option name="EXCLUDED_CONVERTED_TO_IGNORED" value="true" />
@@ -66,8 +65,8 @@
     <file leaf-file-name="transformer_dla.py" pinned="false" current-in-tab="true">
       <entry file="file://$PROJECT_DIR$/tensor2tensor/models/transformer_dla.py">
         <provider selected="true" editor-type-id="text-editor">
-          <state relative-caret-position="-1723">
-            <caret line="209" column="0" lean-forward="false" selection-start-line="209" selection-start-column="0" selection-end-line="209" selection-end-column="0" />
+          <state relative-caret-position="379">
+            <caret line="239" column="5" lean-forward="true" selection-start-line="239" selection-start-column="5" selection-end-line="239" selection-end-column="5" />
             <folding>
               <element signature="e#738#776#0" expanded="true" />
             </folding>
@@ -220,6 +219,7 @@
   </component>
   <component name="ToolWindowManager">
     <frame x="-8" y="-8" width="1936" height="1056" extended-state="7" />
+    <editor active="true" />
     <layout>
       <window_info id="TODO" active="false" anchor="bottom" auto_hide="false" internal_type="DOCKED" type="DOCKED" visible="false" show_stripe_button="true" weight="0.33" sideWeight="0.5" order="11" side_tool="false" content_ui="tabs" />
       <window_info id="Event Log" active="false" anchor="bottom" auto_hide="false" internal_type="DOCKED" type="DOCKED" visible="false" show_stripe_button="true" weight="0.33" sideWeight="0.5" order="0" side_tool="true" content_ui="tabs" />
@@ -422,8 +422,8 @@
     </entry>
     <entry file="file://$PROJECT_DIR$/tensor2tensor/models/transformer_dla.py">
       <provider selected="true" editor-type-id="text-editor">
-        <state relative-caret-position="-1723">
-          <caret line="209" column="0" lean-forward="false" selection-start-line="209" selection-start-column="0" selection-end-line="209" selection-end-column="0" />
+        <state relative-caret-position="379">
+          <caret line="239" column="5" lean-forward="true" selection-start-line="239" selection-start-column="5" selection-end-line="239" selection-end-column="5" />
         <folding>
           <element signature="e#738#776#0" expanded="true" />
         </folding>
......
tensor2tensor/models/transformer_dla.py

@@ -172,7 +172,6 @@ def transformer_encoder(encoder_input,
     encoder_layer.add(x)
     for layer in xrange(hparams.encoder_layers):
       with tf.variable_scope("layer_%d" % layer):
         # self-attention network
         residual = x
         x = may_be_layernorm(x, hparams, before=True)
@@ -205,9 +204,14 @@ def transformer_encoder(encoder_input,
                                         broadcast_dims=residual_dropout_broadcast_dims)
         x = residual + x
         x = may_be_layernorm(x, hparams, after=True)
+        # add the layer output to the history for dynamic layer aggregation
+        with tf.variable_scope("layer_history"):
+          encoder_layer.add(x)
+        x = encoder_layer.pop()
+    # when normalize_before is used, the final output must be normalized as well
     if hparams.normalize_before:
       x = may_be_layernorm(x, hparams, before=True, name="norm_top")
     return x
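Note: the encoder_layer / decoder_layer objects whose add() and pop() methods are called above are defined elsewhere in transformer_dla.py and are not part of this diff. As a rough, hypothetical sketch only, a layer-history container consistent with those calls could look like the following, assuming a softmax-weighted sum over the stored layer outputs (the class name LayerHistory and the weighting scheme are assumptions, not taken from the commit):

import tensorflow as tf

class LayerHistory(object):
  """Hypothetical layer-history container matching the add()/pop() calls above."""

  def __init__(self, name="layer_history"):
    self._name = name
    self._layers = []  # per-layer outputs, each of shape [batch, length, hidden]

  def add(self, x):
    # Record the output of the layer that just finished.
    self._layers.append(x)

  def pop(self):
    # Aggregate everything recorded so far with learned softmax weights
    # (one possible form of dynamic layer aggregation).
    n = len(self._layers)
    with tf.variable_scope(self._name, reuse=tf.AUTO_REUSE):
      logits = tf.get_variable("agg_weights_%d" % n, [n],
                               initializer=tf.zeros_initializer())
      weights = tf.nn.softmax(logits)
      stacked = tf.stack(self._layers, axis=0)       # [n, batch, length, hidden]
      return tf.tensordot(weights, stacked, axes=1)  # back to [batch, length, hidden]

In the hunks above, pop() is called after every layer, so each subsequent layer consumes an aggregate of all preceding layer outputs rather than only the output of the layer directly below it.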
@@ -246,6 +250,8 @@ def transformer_decoder(decoder_input,
   # Summaries don't work in multi-problem setting yet.
   summaries = "problems" not in hparams.values() or len(hparams.problems) == 1
   with tf.variable_scope(name):
+    if hparams.use_emb:
+      decoder_layer.add(x)
     for layer in xrange(hparams.decoder_layers):
       with tf.variable_scope("layer_%d" % layer):
         # self-attention network
@@ -300,6 +306,11 @@ def transformer_decoder(decoder_input,
                                         broadcast_dims=residual_dropout_broadcast_dims)
         x = residual + x
         x = may_be_layernorm(x, hparams, after=True)
+        # add the layer output to the history for dynamic layer aggregation
+        with tf.variable_scope("layer_history"):
+          decoder_layer.add(x)
+        x = decoder_layer.pop()
+    # when normalize_before is used, the final output must be normalized as well
     if hparams.normalize_before:
       x = may_be_layernorm(x, hparams, before=True, name="norm_top")
......
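The hunks also reference hparams.encoder_layers, hparams.decoder_layers, hparams.normalize_before and hparams.use_emb. A hedged sketch of how such fields might be wired up with the standard tensor2tensor hparams machinery follows; the function name transformer_dla_base and every default value below are assumptions, and the commit may register them differently:

from tensor2tensor.models import transformer
from tensor2tensor.utils import registry

@registry.register_hparams
def transformer_dla_base():
  """Hypothetical hparams set exposing the fields used in the diff above."""
  hparams = transformer.transformer_base()
  hparams.add_hparam("encoder_layers", 6)       # depth of the DLA encoder stack
  hparams.add_hparam("decoder_layers", 6)       # depth of the DLA decoder stack
  hparams.add_hparam("normalize_before", True)  # apply layer norm before each sub-layer
  hparams.add_hparam("use_emb", True)           # seed the decoder history with the embeddings
  return hparams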