Fairseq-S2T commit 61cf1afa, authored Apr 10, 2021 by xuchen
add the dlcl arch
parent 143e11af
Showing 1 changed file with 596 additions and 0 deletions.

fairseq/models/dlcl_transformer.py (new file, mode 100644): +596, -0
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

from typing import Any, Dict, List, Optional, Tuple

import logging

import torch
from fairseq import checkpoint_utils, utils
from fairseq.models import (
    register_model,
    register_model_architecture,
)
from fairseq.models.transformer import (
    TransformerModel,
    TransformerEncoder,
    TransformerDecoder
)
from fairseq.modules.layer_history import CreateLayerHistory
from torch import Tensor

DEFAULT_MAX_SOURCE_POSITIONS = 1024
DEFAULT_MAX_TARGET_POSITIONS = 1024

logger = logging.getLogger(__name__)
@register_model("dlcl_transformer")
class DLCLTransformerModel(TransformerModel):
    """
    Transformer model from `"Attention Is All You Need" (Vaswani, et al, 2017)
    <https://arxiv.org/abs/1706.03762>`_.

    Args:
        encoder (TransformerEncoder): the encoder
        decoder (TransformerDecoder): the decoder

    The Transformer model provides the following named architectures and
    command-line arguments:

    .. argparse::
        :ref: fairseq.models.dlcl_transformer_parser
        :prog:
    """

    def __init__(self, args, encoder, decoder):
        super().__init__(args, encoder, decoder)

    @staticmethod
    def add_args(parser):
        """Add model-specific arguments to the parser."""
        TransformerModel.add_args(parser)
        # dense layer parameters
        parser.add_argument('--encoder-history-type',
                            help='encoder layer history type')
        parser.add_argument('--decoder-history-type',
                            help='decoder layer history type')
        parser.add_argument('--encoder-integration-type', choices=['avg', 'sum'],
                            help='encoder layer integration type')
        parser.add_argument('--decoder-integration-type', choices=['avg', 'sum'],
                            help='decoder layer integration type')

    @classmethod
    def build_encoder(cls, args, src_dict, embed_tokens):
        encoder = DLCLTransformerEncoder(args, src_dict, embed_tokens)
        if getattr(args, "load_pretrained_encoder_from", None):
            logger.info(
                f"loaded pretrained encoder from: "
                f"{args.load_pretrained_encoder_from}"
            )
            encoder = checkpoint_utils.load_pretrained_component_from_model(
                component=encoder, checkpoint=args.load_pretrained_encoder_from, strict=False
            )
        return encoder

    @classmethod
    def build_decoder(cls, args, tgt_dict, embed_tokens):
        decoder = DLCLTransformerDecoder(
            args,
            tgt_dict,
            embed_tokens,
            no_encoder_attn=getattr(args, "no_cross_attention", False),
        )
        if getattr(args, "load_pretrained_decoder_from", None):
            logger.info(
                f"loaded pretrained decoder from: "
                f"{args.load_pretrained_decoder_from}"
            )
            decoder = checkpoint_utils.load_pretrained_component_from_model(
                component=decoder, checkpoint=args.load_pretrained_decoder_from, strict=False
            )
        return decoder
class DLCLTransformerEncoder(TransformerEncoder):
    """
    Transformer encoder consisting of *args.encoder_layers* layers. Each layer
    is a :class:`TransformerEncoderLayer`.

    Args:
        args (argparse.Namespace): parsed command-line arguments
        dictionary (~fairseq.data.Dictionary): encoding dictionary
        embed_tokens (torch.nn.Embedding): input embedding
    """

    def __init__(self, args, dictionary, embed_tokens):
        self.args = args
        super().__init__(args, dictionary, embed_tokens)

        self.history = CreateLayerHistory(args, is_encoder=True)

    def forward(
        self,
        src_tokens,
        src_lengths: Optional[torch.Tensor] = None,
        return_all_hiddens: bool = False,
        token_embeddings: Optional[torch.Tensor] = None,
    ):
        """
        Args:
            src_tokens (LongTensor): tokens in the source language of shape
                `(batch, src_len)`
            src_lengths (torch.LongTensor): lengths of each source sentence of
                shape `(batch)`
            return_all_hiddens (bool, optional): also return all of the
                intermediate hidden states (default: False).
            token_embeddings (torch.Tensor, optional): precomputed embeddings
                default `None` will recompute embeddings

        Returns:
            dict:
                - **encoder_out** (Tensor): the last encoder layer's output of
                  shape `(src_len, batch, embed_dim)`
                - **encoder_padding_mask** (ByteTensor): the positions of
                  padding elements of shape `(batch, src_len)`
                - **encoder_embedding** (Tensor): the (scaled) embedding lookup
                  of shape `(batch, src_len, embed_dim)`
                - **encoder_states** (List[Tensor]): all intermediate
                  hidden states of shape `(src_len, batch, embed_dim)`.
                  Only populated if *return_all_hiddens* is True.
        """
        return self.forward_scriptable(
            src_tokens, src_lengths, return_all_hiddens, token_embeddings
        )

    # TorchScript doesn't support super() method so that the scriptable Subclass
    # can't access the base class model in Torchscript.
    # Current workaround is to add a helper function with different name and
    # call the helper function from scriptable Subclass.
    def forward_scriptable(
        self,
        src_tokens,
        src_lengths: Optional[torch.Tensor] = None,
        return_all_hiddens: bool = False,
        token_embeddings: Optional[torch.Tensor] = None,
    ):
        """
        Args:
            src_tokens (LongTensor): tokens in the source language of shape
                `(batch, src_len)`
            src_lengths (torch.LongTensor): lengths of each source sentence of
                shape `(batch)`
            return_all_hiddens (bool, optional): also return all of the
                intermediate hidden states (default: False).
            token_embeddings (torch.Tensor, optional): precomputed embeddings
                default `None` will recompute embeddings

        Returns:
            dict:
                - **encoder_out** (Tensor): the last encoder layer's output of
                  shape `(src_len, batch, embed_dim)`
                - **encoder_padding_mask** (ByteTensor): the positions of
                  padding elements of shape `(batch, src_len)`
                - **encoder_embedding** (Tensor): the (scaled) embedding lookup
                  of shape `(batch, src_len, embed_dim)`
                - **encoder_states** (List[Tensor]): all intermediate
                  hidden states of shape `(src_len, batch, embed_dim)`.
                  Only populated if *return_all_hiddens* is True.
        """
        if self.history is not None:
            self.history.clean()

        # compute padding mask
        encoder_padding_mask = src_tokens.eq(self.padding_idx)
        has_pads = (src_tokens.device.type == "xla" or encoder_padding_mask.any())

        x, encoder_embedding = self.forward_embedding(src_tokens, token_embeddings)

        # account for padding while computing the representation
        if encoder_padding_mask is not None:
            x = x * (1 - encoder_padding_mask.unsqueeze(-1).type_as(x))

        # B x T x C -> T x B x C
        x = x.transpose(0, 1)

        encoder_states = []

        if return_all_hiddens:
            encoder_states.append(x)

        # add emb into history
        if self.history is not None:
            self.history.add(x)

        # encoder layers
        for layer in self.layers:
            if self.history is not None:
                x = self.history.pop()
            x = layer(
                x, encoder_padding_mask=encoder_padding_mask if has_pads else None
            )
            if return_all_hiddens:
                assert encoder_states is not None
                encoder_states.append(x)
            if self.history is not None:
                self.history.add(x)

        if self.history is not None:
            x = self.history.pop()

        if self.layer_norm is not None:
            x = self.layer_norm(x)

        # The Pytorch Mobile lite interpreter does not supports returning NamedTuple in
        # `forward` so we use a dictionary instead.
        # TorchScript does not support mixed values so the values are all lists.
        # The empty list is equivalent to None.
        return {
            "encoder_out": [x],  # T x B x C
            "encoder_padding_mask": [encoder_padding_mask],  # B x T
            "encoder_embedding": [encoder_embedding],  # B x T x C
            "encoder_states": encoder_states,  # List[T x B x C]
            "src_tokens": [],
            "src_lengths": [],
        }
class DLCLTransformerDecoder(TransformerDecoder):
    """
    Transformer decoder consisting of *args.decoder_layers* layers. Each layer
    is a :class:`TransformerDecoderLayer`.

    Args:
        args (argparse.Namespace): parsed command-line arguments
        dictionary (~fairseq.data.Dictionary): decoding dictionary
        embed_tokens (torch.nn.Embedding): output embedding
        no_encoder_attn (bool, optional): whether to attend to encoder outputs
            (default: False).
    """

    def __init__(self, args, dictionary, embed_tokens, no_encoder_attn=False):
        self.args = args
        super().__init__(args, dictionary, embed_tokens, no_encoder_attn)

        self.history = CreateLayerHistory(args, is_encoder=False)

    def forward(
        self,
        prev_output_tokens,
        encoder_out: Optional[Dict[str, List[Tensor]]] = None,
        incremental_state: Optional[Dict[str, Dict[str, Optional[Tensor]]]] = None,
        features_only: bool = False,
        full_context_alignment: bool = False,
        alignment_layer: Optional[int] = None,
        alignment_heads: Optional[int] = None,
        src_lengths: Optional[Any] = None,
        return_all_hiddens: bool = False,
    ):
        """
        Args:
            prev_output_tokens (LongTensor): previous decoder outputs of shape
                `(batch, tgt_len)`, for teacher forcing
            encoder_out (optional): output from the encoder, used for
                encoder-side attention
            incremental_state (dict): dictionary used for storing state during
                :ref:`Incremental decoding`
            features_only (bool, optional): only return features without
                applying output layer (default: False).
            full_context_alignment (bool, optional): don't apply
                auto-regressive mask to self-attention (default: False).

        Returns:
            tuple:
                - the decoder's output of shape `(batch, tgt_len, vocab)`
                - a dictionary with any model-specific outputs
        """
        x, extra = self.extract_features(
            prev_output_tokens,
            encoder_out=encoder_out,
            incremental_state=incremental_state,
            full_context_alignment=full_context_alignment,
            alignment_layer=alignment_layer,
            alignment_heads=alignment_heads,
        )
        if not features_only:
            x = self.output_layer(x)
        return x, extra

    def extract_features(
        self,
        prev_output_tokens,
        encoder_out: Optional[Dict[str, List[Tensor]]],
        incremental_state: Optional[Dict[str, Dict[str, Optional[Tensor]]]] = None,
        full_context_alignment: bool = False,
        alignment_layer: Optional[int] = None,
        alignment_heads: Optional[int] = None,
    ):
        return self.extract_features_scriptable(
            prev_output_tokens,
            encoder_out,
            incremental_state,
            full_context_alignment,
            alignment_layer,
            alignment_heads,
        )

    """
    A scriptable subclass of this class has an extract_features method and calls
    super().extract_features, but super() is not supported in torchscript. A copy of
    this function is made to be used in the subclass instead.
    """

    def extract_features_scriptable(
        self,
        prev_output_tokens,
        encoder_out: Optional[Dict[str, List[Tensor]]],
        incremental_state: Optional[Dict[str, Dict[str, Optional[Tensor]]]] = None,
        full_context_alignment: bool = False,
        alignment_layer: Optional[int] = None,
        alignment_heads: Optional[int] = None,
    ):
        """
        Similar to *forward* but only return features.

        Includes several features from "Jointly Learning to Align and
        Translate with Transformer Models" (Garg et al., EMNLP 2019).

        Args:
            full_context_alignment (bool, optional): don't apply
                auto-regressive mask to self-attention (default: False).
            alignment_layer (int, optional): return mean alignment over
                heads at this layer (default: last layer).
            alignment_heads (int, optional): only average alignment over
                this many heads (default: all heads).

        Returns:
            tuple:
                - the decoder's features of shape `(batch, tgt_len, embed_dim)`
                - a dictionary with any model-specific outputs
        """
        if self.history is not None:
            self.history.clean()

        if alignment_layer is None:
            alignment_layer = self.num_layers - 1

        # embed positions
        positions = None
        if self.embed_positions is not None:
            positions = self.embed_positions(
                prev_output_tokens, incremental_state=incremental_state
            )

        if incremental_state is not None:
            prev_output_tokens = prev_output_tokens[:, -1:]
            if positions is not None:
                positions = positions[:, -1:]

        # embed tokens and positions
        x = self.embed_scale * self.embed_tokens(prev_output_tokens)

        if self.quant_noise is not None:
            x = self.quant_noise(x)

        if self.project_in_dim is not None:
            x = self.project_in_dim(x)

        if positions is not None and self.attn_type != "rel_selfattn":
            x += positions

        if self.layernorm_embedding is not None:
            x = self.layernorm_embedding(x)

        x = self.dropout_module(x)

        # B x T x C -> T x B x C
        x = x.transpose(0, 1)

        # add emb into history
        if self.history is not None:
            self.history.add(x)

        self_attn_padding_mask: Optional[Tensor] = None
        if self.cross_self_attention or prev_output_tokens.eq(self.padding_idx).any():
            self_attn_padding_mask = prev_output_tokens.eq(self.padding_idx)

        # decoder layers
        attn: Optional[Tensor] = None
        inner_states: List[Optional[Tensor]] = [x]
        for idx, layer in enumerate(self.layers):
            if incremental_state is None and not full_context_alignment:
                self_attn_mask = self.buffered_future_mask(x)
            else:
                self_attn_mask = None

            if self.history is not None:
                x = self.history.pop()

            x, layer_attn, _ = layer(
                x,
                encoder_out["encoder_out"][0]
                if (encoder_out is not None and len(encoder_out["encoder_out"]) > 0)
                else None,
                encoder_out["encoder_padding_mask"][0]
                if (
                    encoder_out is not None
                    and len(encoder_out["encoder_padding_mask"]) > 0
                )
                else None,
                incremental_state,
                self_attn_mask=self_attn_mask,
                self_attn_padding_mask=self_attn_padding_mask,
                need_attn=bool((idx == alignment_layer)),
                need_head_weights=bool((idx == alignment_layer)),
                pos_emb=positions,
            )
            inner_states.append(x)

            if self.history is not None:
                self.history.add(x)

            if layer_attn is not None and idx == alignment_layer:
                attn = layer_attn.float().to(x)

        if attn is not None:
            if alignment_heads is not None:
                attn = attn[:alignment_heads]

            # average probabilities over heads
            attn = attn.mean(dim=0)

        if self.history is not None:
            x = self.history.pop()

        if self.layer_norm is not None:
            x = self.layer_norm(x)

        # T x B x C -> B x T x C
        x = x.transpose(0, 1)

        if self.project_out_dim is not None:
            x = self.project_out_dim(x)

        return x, {"attn": [attn], "inner_states": inner_states}
@register_model_architecture("dlcl_transformer", "dlcl_transformer_tiny")
def tiny_architecture(args):
    args.encoder_embed_dim = getattr(args, "encoder_embed_dim", 64)
    args.encoder_ffn_embed_dim = getattr(args, "encoder_ffn_embed_dim", 64)
    args.encoder_layers = getattr(args, "encoder_layers", 2)
    args.encoder_attention_heads = getattr(args, "encoder_attention_heads", 2)
    args.decoder_layers = getattr(args, "decoder_layers", 2)
    args.decoder_attention_heads = getattr(args, "decoder_attention_heads", 2)
    return base_architecture(args)


@register_model_architecture("dlcl_transformer", "dlcl_transformer")
def base_architecture(args):
    args.encoder_embed_path = getattr(args, "encoder_embed_path", None)
    args.encoder_embed_dim = getattr(args, "encoder_embed_dim", 512)
    args.encoder_ffn_embed_dim = getattr(args, "encoder_ffn_embed_dim", 2048)
    args.encoder_layers = getattr(args, "encoder_layers", 6)
    args.encoder_attention_heads = getattr(args, "encoder_attention_heads", 8)
    args.encoder_normalize_before = getattr(args, "encoder_normalize_before", False)
    args.encoder_learned_pos = getattr(args, "encoder_learned_pos", False)
    args.decoder_embed_path = getattr(args, "decoder_embed_path", None)
    args.decoder_embed_dim = getattr(args, "decoder_embed_dim", args.encoder_embed_dim)
    args.decoder_ffn_embed_dim = getattr(
        args, "decoder_ffn_embed_dim", args.encoder_ffn_embed_dim
    )
    args.decoder_layers = getattr(args, "decoder_layers", 6)
    args.decoder_attention_heads = getattr(args, "decoder_attention_heads", 8)
    args.decoder_normalize_before = getattr(args, "decoder_normalize_before", False)
    args.decoder_learned_pos = getattr(args, "decoder_learned_pos", False)
    args.attention_dropout = getattr(args, "attention_dropout", 0.0)
    args.activation_dropout = getattr(args, "activation_dropout", 0.0)
    args.activation_fn = getattr(args, "activation_fn", "relu")
    args.dropout = getattr(args, "dropout", 0.1)
    args.adaptive_softmax_cutoff = getattr(args, "adaptive_softmax_cutoff", None)
    args.adaptive_softmax_dropout = getattr(args, "adaptive_softmax_dropout", 0)
    args.share_decoder_input_output_embed = getattr(
        args, "share_decoder_input_output_embed", False
    )
    args.share_all_embeddings = getattr(args, "share_all_embeddings", False)
    args.no_token_positional_embeddings = getattr(
        args, "no_token_positional_embeddings", False
    )
    args.adaptive_input = getattr(args, "adaptive_input", False)
    args.no_cross_attention = getattr(args, "no_cross_attention", False)
    args.cross_self_attention = getattr(args, "cross_self_attention", False)

    args.decoder_output_dim = getattr(
        args, "decoder_output_dim", args.decoder_embed_dim
    )
    args.decoder_input_dim = getattr(args, "decoder_input_dim", args.decoder_embed_dim)

    args.no_scale_embedding = getattr(args, "no_scale_embedding", False)
    args.layernorm_embedding = getattr(args, "layernorm_embedding", False)
    args.tie_adaptive_weights = getattr(args, "tie_adaptive_weights", False)
    args.checkpoint_activations = getattr(args, "checkpoint_activations", False)
    args.offload_activations = getattr(args, "offload_activations", False)
    if args.offload_activations:
        args.checkpoint_activations = True
    args.encoder_layers_to_keep = getattr(args, "encoder_layers_to_keep", None)
    args.decoder_layers_to_keep = getattr(args, "decoder_layers_to_keep", None)
    args.encoder_layerdrop = getattr(args, "encoder_layerdrop", 0)
    args.decoder_layerdrop = getattr(args, "decoder_layerdrop", 0)
    args.quant_noise_pq = getattr(args, "quant_noise_pq", 0)
    args.quant_noise_pq_block_size = getattr(args, "quant_noise_pq_block_size", 8)
    args.quant_noise_scalar = getattr(args, "quant_noise_scalar", 0)

    args.encoder_history_type = getattr(args, 'encoder_history_type', 'learnable_dense')
    args.decoder_history_type = getattr(args, 'decoder_history_type', 'learnable_dense')
    args.encoder_integration_type = getattr(args, 'encoder_integration_type', 'avg')
    args.decoder_integration_type = getattr(args, 'decoder_integration_type', 'avg')

    args.max_relative_length = getattr(args, 'max_relative_length', -1)
    args.k_only = getattr(args, 'k_only', True)
@register_model_architecture("dlcl_transformer", "dlcl_transformer_relative")
def dlcl_transformer_relative(args):
    args.max_relative_length = 20
    args.k_only = True
    base_architecture(args)


@register_model_architecture("dlcl_transformer", "dlcl_transformer_iwslt_de_en")
def dlcl_transformer_iwslt_de_en(args):
    args.encoder_embed_dim = getattr(args, "encoder_embed_dim", 512)
    args.encoder_ffn_embed_dim = getattr(args, "encoder_ffn_embed_dim", 1024)
    args.encoder_attention_heads = getattr(args, "encoder_attention_heads", 4)
    args.encoder_layers = getattr(args, "encoder_layers", 6)
    args.decoder_embed_dim = getattr(args, "decoder_embed_dim", 512)
    args.decoder_ffn_embed_dim = getattr(args, "decoder_ffn_embed_dim", 1024)
    args.decoder_attention_heads = getattr(args, "decoder_attention_heads", 4)
    args.decoder_layers = getattr(args, "decoder_layers", 6)
    base_architecture(args)


@register_model_architecture("dlcl_transformer", "dlcl_transformer_wmt_en_de")
def dlcl_transformer_wmt_en_de(args):
    base_architecture(args)


# parameters used in the "Attention Is All You Need" paper (Vaswani et al., 2017)
@register_model_architecture("dlcl_transformer", "dlcl_transformer_vaswani_wmt_en_de_big")
def dlcl_transformer_vaswani_wmt_en_de_big(args):
    args.encoder_embed_dim = getattr(args, "encoder_embed_dim", 1024)
    args.encoder_ffn_embed_dim = getattr(args, "encoder_ffn_embed_dim", 4096)
    args.encoder_attention_heads = getattr(args, "encoder_attention_heads", 16)
    args.encoder_normalize_before = getattr(args, "encoder_normalize_before", False)
    args.decoder_embed_dim = getattr(args, "decoder_embed_dim", 1024)
    args.decoder_ffn_embed_dim = getattr(args, "decoder_ffn_embed_dim", 4096)
    args.decoder_attention_heads = getattr(args, "decoder_attention_heads", 16)
    args.dropout = getattr(args, "dropout", 0.3)
    base_architecture(args)


@register_model_architecture("dlcl_transformer", "dlcl_transformer_vaswani_wmt_en_fr_big")
def dlcl_transformer_vaswani_wmt_en_fr_big(args):
    args.dropout = getattr(args, "dropout", 0.1)
    dlcl_transformer_vaswani_wmt_en_de_big(args)


@register_model_architecture("dlcl_transformer", "dlcl_transformer_wmt_en_de_big")
def dlcl_transformer_wmt_en_de_big(args):
    args.attention_dropout = getattr(args, "attention_dropout", 0.1)
    dlcl_transformer_vaswani_wmt_en_de_big(args)


# default parameters used in tensor2tensor implementation
@register_model_architecture("dlcl_transformer", "dlcl_transformer_wmt_en_de_big_t2t")
def dlcl_transformer_wmt_en_de_big_t2t(args):
    args.encoder_normalize_before = getattr(args, "encoder_normalize_before", True)
    args.decoder_normalize_before = getattr(args, "decoder_normalize_before", True)
    args.attention_dropout = getattr(args, "attention_dropout", 0.1)
    args.activation_dropout = getattr(args, "activation_dropout", 0.1)
    dlcl_transformer_vaswani_wmt_en_de_big(args)
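
The dense-layer behaviour comes from CreateLayerHistory in fairseq.modules.layer_history, which is not part of this diff; the encoder and decoder above only call its clean()/add()/pop() protocol. As a rough, assumption-based illustration of the 'learnable_dense' history type that base_architecture selects by default (merge every previous layer output with learned, normalized weights), the self-contained sketch below mimics that protocol. It is a hypothetical stand-in for reading convenience, not the repository's implementation.

import torch
import torch.nn as nn


class DenseLayerHistorySketch(nn.Module):
    """Hypothetical stand-in for the 'learnable_dense' layer history: it stores
    the embedding and every layer output, and pop() returns a weighted average
    of all states seen so far, with one learnable weight row per step."""

    def __init__(self, num_layers: int):
        super().__init__()
        # row i holds the (unnormalized) weights used after i + 1 states were added
        self.weights = nn.Parameter(torch.ones(num_layers + 1, num_layers + 1))
        self.states = []

    def clean(self):
        self.states = []

    def add(self, x):
        self.states.append(x)

    def pop(self):
        n = len(self.states)
        w = torch.softmax(self.weights[n - 1, :n], dim=-1)  # normalized weights
        stacked = torch.stack(self.states, dim=0)           # (n, T, B, C)
        return (w.view(-1, 1, 1, 1) * stacked).sum(dim=0)


# usage mirroring the loop in DLCLTransformerEncoder.forward_scriptable():
history = DenseLayerHistorySketch(num_layers=2)
layers = [nn.Linear(8, 8) for _ in range(2)]
x = torch.randn(5, 3, 8)        # T x B x C, as inside the encoder
history.clean()
history.add(x)                  # add the embedding
for layer in layers:
    x = layer(history.pop())    # aggregate the history, then run the layer
    history.add(x)
x = history.pop()               # final aggregated representation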
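Since the file registers the model with @register_model("dlcl_transformer") and a set of named architectures with @register_model_architecture, the new encoder/decoder should be selectable through fairseq's usual --arch mechanism (for example --arch dlcl_transformer or --arch dlcl_transformer_iwslt_de_en), with --encoder-history-type, --decoder-history-type, --encoder-integration-type, and --decoder-integration-type exposed as the extra options defined in add_args. As a minimal sanity check of the defaults wired up by base_architecture, the snippet below fills an empty argparse.Namespace; it assumes the Fairseq-S2T package and its dependencies are importable and only exercises the function defined in this file.

import argparse

from fairseq.models.dlcl_transformer import base_architecture  # the module added in this commit

args = argparse.Namespace()
base_architecture(args)  # fills in every option it does not find on the namespace

print(args.encoder_history_type)      # 'learnable_dense'
print(args.encoder_integration_type)  # 'avg'
print(args.encoder_embed_dim)         # 512
print(args.max_relative_length)       # -1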