Skip to content
项目
群组
代码片段
帮助
当前项目
正在载入...
登录 / 注册
切换导航面板
F
FairseqDecoder
概览
Overview
Details
Activity
Cycle Analytics
版本库
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Charts
问题
0
Issues
0
列表
Board
标记
里程碑
合并请求
0
Merge Requests
0
CI / CD
CI / CD
流水线
作业
日程表
图表
维基
Wiki
代码片段
Snippets
成员
Collapse sidebar
Close sidebar
活动
图像
聊天
创建新问题
作业
提交
Issue Boards
Open sidebar
libei
FairseqDecoder
Commits
0d041293
Commit
0d041293
authored
Mar 20, 2019
by
libei
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
add relative_transformer model to support rpr fast decode
parent
d2cb256f
隐藏空白字符变更
内嵌
并排
正在显示
2 个修改的文件
包含
776 行增加
和
0 行删除
+776
-0
fairseq/models/relative_transformer.py
+500
-0
fairseq/modules/relative_multihead_attention.py
+276
-0
没有找到文件。
fairseq/models/relative_transformer.py
查看文件 @
0d041293
# Copyright (c) 2017-present, Facebook, Inc.
# All rights reserved.
#
# This source code is licensed under the license found in the LICENSE file in
# the root directory of this source tree. An additional grant of patent rights
# can be found in the PATENTS file in the same directory.
import
math
import
torch
import
torch.nn
as
nn
import
torch.nn.functional
as
F
from
fairseq.modules
import
(
LearnedPositionalEmbedding
,
MultiheadAttention
,
SinusoidalPositionalEmbedding
,
)
from
fairseq.modules.relative_multihead_attention
import
RelativeMultiheadAttention
from
.
import
(
FairseqIncrementalDecoder
,
FairseqEncoder
,
FairseqModel
,
register_model
,
register_model_architecture
,
)
@register_model
(
'relative_transformer'
)
class
RelativeTransformerModel
(
FairseqModel
):
def
__init__
(
self
,
encoder
,
decoder
):
super
()
.
__init__
(
encoder
,
decoder
)
@staticmethod
def
add_args
(
parser
):
"""Add model-specific arguments to the parser."""
parser
.
add_argument
(
'--dropout'
,
type
=
float
,
metavar
=
'D'
,
help
=
'dropout probability'
)
parser
.
add_argument
(
'--attention-dropout'
,
type
=
float
,
metavar
=
'D'
,
help
=
'dropout probability for attention weights'
)
parser
.
add_argument
(
'--relu-dropout'
,
type
=
float
,
metavar
=
'D'
,
help
=
'dropout probability after ReLU in FFN'
)
parser
.
add_argument
(
'--encoder-embed-dim'
,
type
=
int
,
metavar
=
'N'
,
help
=
'encoder embedding dimension'
)
parser
.
add_argument
(
'--encoder-ffn-embed-dim'
,
type
=
int
,
metavar
=
'N'
,
help
=
'encoder embedding dimension for FFN'
)
parser
.
add_argument
(
'--encoder-layers'
,
type
=
int
,
metavar
=
'N'
,
help
=
'num encoder layers'
)
parser
.
add_argument
(
'--encoder-attention-heads'
,
type
=
int
,
metavar
=
'N'
,
help
=
'num encoder attention heads'
)
parser
.
add_argument
(
'--encoder-normalize-before'
,
default
=
False
,
action
=
'store_true'
,
help
=
'apply layernorm before each encoder block'
)
parser
.
add_argument
(
'--encoder-learned-pos'
,
default
=
False
,
action
=
'store_true'
,
help
=
'use learned positional embeddings in the encoder'
)
parser
.
add_argument
(
'--decoder-embed-dim'
,
type
=
int
,
metavar
=
'N'
,
help
=
'decoder embedding dimension'
)
parser
.
add_argument
(
'--decoder-ffn-embed-dim'
,
type
=
int
,
metavar
=
'N'
,
help
=
'decoder embedding dimension for FFN'
)
parser
.
add_argument
(
'--decoder-layers'
,
type
=
int
,
metavar
=
'N'
,
help
=
'num decoder layers'
)
parser
.
add_argument
(
'--decoder-attention-heads'
,
type
=
int
,
metavar
=
'N'
,
help
=
'num decoder attention heads'
)
parser
.
add_argument
(
'--decoder-learned-pos'
,
default
=
False
,
action
=
'store_true'
,
help
=
'use learned positional embeddings in the decoder'
)
parser
.
add_argument
(
'--decoder-normalize-before'
,
default
=
False
,
action
=
'store_true'
,
help
=
'apply layernorm before each decoder block'
)
parser
.
add_argument
(
'--share-decoder-input-output-embed'
,
default
=
False
,
action
=
'store_true'
,
help
=
'share decoder input and output embeddings'
)
parser
.
add_argument
(
'--share-all-embeddings'
,
default
=
False
,
action
=
'store_true'
,
help
=
'share encoder, decoder and output embeddings'
' (requires shared dictionary and embed dim)'
)
@classmethod
def
build_model
(
cls
,
args
,
task
):
"""Build a new model instance."""
src_dict
,
tgt_dict
=
task
.
source_dictionary
,
task
.
target_dictionary
def
build_embedding
(
dictionary
,
embed_dim
):
num_embeddings
=
len
(
dictionary
)
padding_idx
=
dictionary
.
pad
()
return
Embedding
(
num_embeddings
,
embed_dim
,
padding_idx
)
if
args
.
share_all_embeddings
:
if
src_dict
!=
tgt_dict
:
raise
RuntimeError
(
'--share-all-embeddings requires a joined dictionary'
)
if
args
.
encoder_embed_dim
!=
args
.
decoder_embed_dim
:
raise
RuntimeError
(
'--share-all-embeddings requires --encoder-embed-dim to match --decoder-embed-dim'
)
encoder_embed_tokens
=
build_embedding
(
src_dict
,
args
.
encoder_embed_dim
)
decoder_embed_tokens
=
encoder_embed_tokens
args
.
share_decoder_input_output_embed
=
True
else
:
encoder_embed_tokens
=
build_embedding
(
src_dict
,
args
.
encoder_embed_dim
)
decoder_embed_tokens
=
build_embedding
(
tgt_dict
,
args
.
decoder_embed_dim
)
encoder
=
RelativeTransformerEncoder
(
args
,
src_dict
,
encoder_embed_tokens
)
decoder
=
RelativeTransformerDecoder
(
args
,
tgt_dict
,
decoder_embed_tokens
)
return
RelativeTransformerModel
(
encoder
,
decoder
)
class
RelativeTransformerEncoder
(
FairseqEncoder
):
"""Transformer encoder."""
def
__init__
(
self
,
args
,
dictionary
,
embed_tokens
,
left_pad
=
True
):
super
()
.
__init__
(
dictionary
)
self
.
dropout
=
args
.
dropout
embed_dim
=
embed_tokens
.
embedding_dim
self
.
padding_idx
=
embed_tokens
.
padding_idx
self
.
embed_tokens
=
embed_tokens
self
.
embed_scale
=
math
.
sqrt
(
embed_dim
)
self
.
embed_positions
=
PositionalEmbedding
(
1024
,
embed_dim
,
self
.
padding_idx
,
left_pad
=
left_pad
,
learned
=
args
.
encoder_learned_pos
,
)
self
.
layers
=
nn
.
ModuleList
([])
self
.
layers
.
extend
([
RelativeTransformerEncoderLayer
(
args
)
for
i
in
range
(
args
.
encoder_layers
)
])
self
.
normalize
=
args
.
encoder_normalize_before
if
self
.
normalize
:
self
.
layer_norm
=
LayerNorm
(
embed_dim
)
def
forward
(
self
,
src_tokens
,
src_lengths
):
# embed tokens and positions
emb_token
=
self
.
embed_tokens
(
src_tokens
)
x
=
self
.
embed_scale
*
emb_token
pos_emb
=
self
.
embed_positions
(
src_tokens
)
x
+=
pos_emb
x
=
F
.
dropout
(
x
,
p
=
self
.
dropout
,
training
=
self
.
training
)
#print("encoder_input:",x)
# B x T x C -> T x B x C
x
=
x
.
transpose
(
0
,
1
)
# compute padding mask
encoder_padding_mask
=
src_tokens
.
eq
(
self
.
padding_idx
)
if
not
encoder_padding_mask
.
any
():
encoder_padding_mask
=
None
# encoder layers
for
layer
in
self
.
layers
:
x
=
layer
(
x
,
encoder_padding_mask
)
if
self
.
normalize
:
x
=
self
.
layer_norm
(
x
)
# print('enc_output:', x.size())
#print('enc_output:', x)
return
{
'encoder_out'
:
x
,
# T x B x C
'encoder_padding_mask'
:
encoder_padding_mask
,
# B x T
}
def
max_positions
(
self
):
"""Maximum input length supported by the encoder."""
return
self
.
embed_positions
.
max_positions
()
def
upgrade_state_dict
(
self
,
state_dict
):
if
isinstance
(
self
.
embed_positions
,
SinusoidalPositionalEmbedding
):
if
'encoder.embed_positions.weights'
in
state_dict
:
del
state_dict
[
'encoder.embed_positions.weights'
]
if
'encoder.embed_positions._float_tensor'
not
in
state_dict
:
state_dict
[
'encoder.embed_positions._float_tensor'
]
=
torch
.
FloatTensor
()
return
state_dict
class
RelativeTransformerDecoder
(
FairseqIncrementalDecoder
):
"""Transformer decoder."""
def
__init__
(
self
,
args
,
dictionary
,
embed_tokens
,
left_pad
=
False
,
final_norm
=
True
):
super
()
.
__init__
(
dictionary
)
self
.
dropout
=
args
.
dropout
self
.
share_input_output_embed
=
args
.
share_decoder_input_output_embed
embed_dim
=
embed_tokens
.
embedding_dim
padding_idx
=
embed_tokens
.
padding_idx
self
.
embed_tokens
=
embed_tokens
self
.
embed_scale
=
math
.
sqrt
(
embed_dim
)
self
.
embed_positions
=
PositionalEmbedding
(
1024
,
embed_dim
,
padding_idx
,
left_pad
=
left_pad
,
learned
=
args
.
decoder_learned_pos
,
)
self
.
layers
=
nn
.
ModuleList
([])
self
.
layers
.
extend
([
RelativeTransformerDecoderLayer
(
args
)
for
i
in
range
(
args
.
decoder_layers
)
])
self
.
normalize
=
args
.
decoder_normalize_before
and
final_norm
if
self
.
normalize
:
self
.
layer_norm
=
LayerNorm
(
embed_dim
)
if
not
self
.
share_input_output_embed
:
self
.
embed_out
=
nn
.
Parameter
(
torch
.
Tensor
(
len
(
dictionary
),
embed_dim
))
nn
.
init
.
normal_
(
self
.
embed_out
,
mean
=
0
,
std
=
embed_dim
**
-
0.5
)
def
forward
(
self
,
prev_output_tokens
,
encoder_out
,
incremental_state
=
None
):
# embed positions
positions
=
self
.
embed_positions
(
prev_output_tokens
,
incremental_state
=
incremental_state
,
)
# print('raw position:',positions.size())
if
incremental_state
is
not
None
:
is_first_step
=
prev_output_tokens
.
size
(
1
)
==
1
prev_output_tokens
=
prev_output_tokens
[:,
-
1
:]
positions
=
positions
[:,
-
1
:]
# embed tokens and positions
x
=
self
.
embed_scale
*
self
.
embed_tokens
(
prev_output_tokens
)
if
is_first_step
:
x
=
positions
.
expand_as
(
x
)
else
:
x
+=
positions
else
:
# embed tokens and positions
x
=
self
.
embed_scale
*
self
.
embed_tokens
(
prev_output_tokens
)
x
[:,
0
,
:]
.
fill_
(
0
)
x
+=
positions
# print('dec-final-emb:', x.size())
x
=
F
.
dropout
(
x
,
p
=
self
.
dropout
,
training
=
self
.
training
)
# B x T x C -> T x B x C
x
=
x
.
transpose
(
0
,
1
)
# decoder layers
for
layer
in
self
.
layers
:
x
,
attn
=
layer
(
x
,
encoder_out
[
'encoder_out'
],
encoder_out
[
'encoder_padding_mask'
],
incremental_state
,
)
if
self
.
normalize
:
x
=
self
.
layer_norm
(
x
)
# T x B x C -> B x T x C
x
=
x
.
transpose
(
0
,
1
)
# project back to size of vocabulary
if
self
.
share_input_output_embed
:
x
=
F
.
linear
(
x
,
self
.
embed_tokens
.
weight
)
else
:
x
=
F
.
linear
(
x
,
self
.
embed_out
)
# x[:, 0] = float(-10)
# print(x[:5,:10])
return
x
,
attn
def
reorder_encoder_out
(
self
,
encoder_out_dict
,
new_order
):
if
encoder_out_dict
[
'encoder_padding_mask'
]
is
not
None
:
encoder_out_dict
[
'encoder_padding_mask'
]
=
\
encoder_out_dict
[
'encoder_padding_mask'
]
.
index_select
(
0
,
new_order
)
return
encoder_out_dict
def
max_positions
(
self
):
"""Maximum output length supported by the decoder."""
return
self
.
embed_positions
.
max_positions
()
def
upgrade_state_dict
(
self
,
state_dict
):
if
isinstance
(
self
.
embed_positions
,
SinusoidalPositionalEmbedding
):
if
'decoder.embed_positions.weights'
in
state_dict
:
del
state_dict
[
'decoder.embed_positions.weights'
]
if
'decoder.embed_positions._float_tensor'
not
in
state_dict
:
state_dict
[
'decoder.embed_positions._float_tensor'
]
=
torch
.
FloatTensor
()
return
state_dict
class
RelativeTransformerEncoderLayer
(
nn
.
Module
):
"""Encoder layer block.
In the original paper each operation (multi-head attention or FFN) is
postprocessed with: dropout -> add residual -> layernorm.
In the tensor2tensor code they suggest that learning is more robust when
preprocessing each layer with layernorm and postprocessing with:
dropout -> add residual.
We default to the approach in the paper, but the tensor2tensor approach can
be enabled by setting `normalize_before=True`.
"""
def
__init__
(
self
,
args
):
super
()
.
__init__
()
self
.
embed_dim
=
args
.
encoder_embed_dim
self
.
self_attn
=
RelativeMultiheadAttention
(
self
.
embed_dim
,
args
.
encoder_attention_heads
,
args
.
max_relative_length
,
dropout
=
args
.
attention_dropout
,
)
self
.
dropout
=
args
.
dropout
self
.
relu_dropout
=
args
.
relu_dropout
self
.
normalize_before
=
args
.
encoder_normalize_before
self
.
fc1
=
Linear
(
self
.
embed_dim
,
args
.
encoder_ffn_embed_dim
)
self
.
fc2
=
Linear
(
args
.
encoder_ffn_embed_dim
,
self
.
embed_dim
)
self
.
layer_norms
=
nn
.
ModuleList
([
LayerNorm
(
self
.
embed_dim
)
for
i
in
range
(
2
)])
def
forward
(
self
,
x
,
encoder_padding_mask
):
residual
=
x
x
=
self
.
maybe_layer_norm
(
0
,
x
,
before
=
True
)
x
,
_
=
self
.
self_attn
(
query
=
x
,
key
=
x
,
value
=
x
,
key_padding_mask
=
encoder_padding_mask
)
#print("rpr_out:",x)
x
=
F
.
dropout
(
x
,
p
=
self
.
dropout
,
training
=
self
.
training
)
x
=
residual
+
x
x
=
self
.
maybe_layer_norm
(
0
,
x
,
after
=
True
)
residual
=
x
x
=
self
.
maybe_layer_norm
(
1
,
x
,
before
=
True
)
x
=
F
.
relu
(
self
.
fc1
(
x
))
x
=
F
.
dropout
(
x
,
p
=
self
.
relu_dropout
,
training
=
self
.
training
)
x
=
self
.
fc2
(
x
)
x
=
F
.
dropout
(
x
,
p
=
self
.
dropout
,
training
=
self
.
training
)
x
=
residual
+
x
x
=
self
.
maybe_layer_norm
(
1
,
x
,
after
=
True
)
return
x
def
maybe_layer_norm
(
self
,
i
,
x
,
before
=
False
,
after
=
False
):
assert
before
^
after
if
after
^
self
.
normalize_before
:
return
self
.
layer_norms
[
i
](
x
)
else
:
return
x
class
RelativeTransformerDecoderLayer
(
nn
.
Module
):
"""Decoder layer block."""
def
__init__
(
self
,
args
):
super
()
.
__init__
()
self
.
embed_dim
=
args
.
decoder_embed_dim
self
.
self_attn
=
RelativeMultiheadAttention
(
self
.
embed_dim
,
args
.
decoder_attention_heads
,
args
.
max_relative_length
,
dropout
=
args
.
attention_dropout
,
)
self
.
dropout
=
args
.
dropout
self
.
relu_dropout
=
args
.
relu_dropout
self
.
normalize_before
=
args
.
decoder_normalize_before
self
.
encoder_attn
=
MultiheadAttention
(
self
.
embed_dim
,
args
.
decoder_attention_heads
,
dropout
=
args
.
attention_dropout
,
)
self
.
fc1
=
Linear
(
self
.
embed_dim
,
args
.
decoder_ffn_embed_dim
)
self
.
fc2
=
Linear
(
args
.
decoder_ffn_embed_dim
,
self
.
embed_dim
)
self
.
layer_norms
=
nn
.
ModuleList
([
LayerNorm
(
self
.
embed_dim
)
for
i
in
range
(
3
)])
def
forward
(
self
,
x
,
encoder_out
,
encoder_padding_mask
,
incremental_state
):
residual
=
x
x
=
self
.
maybe_layer_norm
(
0
,
x
,
before
=
True
)
x
,
_
=
self
.
self_attn
(
query
=
x
,
key
=
x
,
value
=
x
,
mask_future_timesteps
=
True
,
incremental_state
=
incremental_state
,
need_weights
=
False
,
)
x
=
F
.
dropout
(
x
,
p
=
self
.
dropout
,
training
=
self
.
training
)
x
=
residual
+
x
x
=
self
.
maybe_layer_norm
(
0
,
x
,
after
=
True
)
residual
=
x
x
=
self
.
maybe_layer_norm
(
1
,
x
,
before
=
True
)
x
,
attn
=
self
.
encoder_attn
(
query
=
x
,
key
=
encoder_out
,
value
=
encoder_out
,
key_padding_mask
=
encoder_padding_mask
,
incremental_state
=
incremental_state
,
static_kv
=
True
,
)
x
=
F
.
dropout
(
x
,
p
=
self
.
dropout
,
training
=
self
.
training
)
x
=
residual
+
x
x
=
self
.
maybe_layer_norm
(
1
,
x
,
after
=
True
)
residual
=
x
x
=
self
.
maybe_layer_norm
(
2
,
x
,
before
=
True
)
x
=
F
.
relu
(
self
.
fc1
(
x
))
x
=
F
.
dropout
(
x
,
p
=
self
.
relu_dropout
,
training
=
self
.
training
)
x
=
self
.
fc2
(
x
)
x
=
F
.
dropout
(
x
,
p
=
self
.
dropout
,
training
=
self
.
training
)
x
=
residual
+
x
x
=
self
.
maybe_layer_norm
(
2
,
x
,
after
=
True
)
return
x
,
attn
def
maybe_layer_norm
(
self
,
i
,
x
,
before
=
False
,
after
=
False
):
assert
before
^
after
if
after
^
self
.
normalize_before
:
return
self
.
layer_norms
[
i
](
x
)
else
:
return
x
def
Embedding
(
num_embeddings
,
embedding_dim
,
padding_idx
):
m
=
nn
.
Embedding
(
num_embeddings
,
embedding_dim
,
padding_idx
=
padding_idx
)
nn
.
init
.
normal_
(
m
.
weight
,
mean
=
0
,
std
=
embedding_dim
**
-
0.5
)
return
m
def
LayerNorm
(
embedding_dim
):
m
=
nn
.
LayerNorm
(
embedding_dim
)
return
m
def
Linear
(
in_features
,
out_features
,
bias
=
True
):
m
=
nn
.
Linear
(
in_features
,
out_features
,
bias
)
nn
.
init
.
xavier_uniform_
(
m
.
weight
)
nn
.
init
.
constant_
(
m
.
bias
,
0.
)
return
m
def
PositionalEmbedding
(
num_embeddings
,
embedding_dim
,
padding_idx
,
left_pad
,
learned
=
False
):
if
learned
:
m
=
LearnedPositionalEmbedding
(
num_embeddings
,
embedding_dim
,
padding_idx
,
left_pad
)
nn
.
init
.
normal_
(
m
.
weight
,
mean
=
0
,
std
=
embedding_dim
**
-
0.5
)
nn
.
init
.
constant_
(
m
.
weight
[
padding_idx
],
0
)
else
:
m
=
SinusoidalPositionalEmbedding
(
embedding_dim
,
padding_idx
,
left_pad
,
init_size
=
num_embeddings
)
return
m
@register_model_architecture
(
'relative_transformer'
,
'relative_transformer'
)
def
base_architecture
(
args
):
args
.
encoder_embed_dim
=
getattr
(
args
,
'encoder_embed_dim'
,
512
)
args
.
encoder_ffn_embed_dim
=
getattr
(
args
,
'encoder_ffn_embed_dim'
,
2048
)
args
.
encoder_layers
=
getattr
(
args
,
'encoder_layers'
,
6
)
args
.
encoder_attention_heads
=
getattr
(
args
,
'encoder_attention_heads'
,
8
)
args
.
decoder_embed_dim
=
getattr
(
args
,
'decoder_embed_dim'
,
args
.
encoder_embed_dim
)
args
.
decoder_ffn_embed_dim
=
getattr
(
args
,
'decoder_ffn_embed_dim'
,
args
.
encoder_ffn_embed_dim
)
args
.
decoder_layers
=
getattr
(
args
,
'decoder_layers'
,
6
)
args
.
decoder_attention_heads
=
getattr
(
args
,
'decoder_attention_heads'
,
8
)
args
.
attention_dropout
=
getattr
(
args
,
'attention_dropout'
,
0.
)
args
.
relu_dropout
=
getattr
(
args
,
'relu_dropout'
,
0.
)
args
.
dropout
=
getattr
(
args
,
'dropout'
,
0.1
)
args
.
max_relative_length
=
getattr
(
args
,
'max_relative_length'
,
16
)
@register_model_architecture
(
'relative_transformer'
,
'relative_transformer_wmt_en_de'
)
def
relative_transformer_wmt_en_de
(
args
):
base_architecture
(
args
)
@register_model_architecture
(
'relative_transformer'
,
'relative_transformer_t2t_wmt_en_de'
)
def
relative_transformer_t2t_wmt_en_de
(
args
):
args
.
encoder_normalize_before
=
True
args
.
decoder_normalize_before
=
True
args
.
attention_dropout
=
getattr
(
args
,
'attention_dropout'
,
0.1
)
args
.
relu_dropout
=
getattr
(
args
,
'relu_dropout'
,
0.1
)
args
.
max_relative_length
=
16
base_architecture
(
args
)
@register_model_architecture
(
'relative_transformer'
,
'relative_transformer_toy'
)
def
relative_transformer_toy
(
args
):
args
.
encoder_embed_dim
=
getattr
(
args
,
'encoder_embed_dim'
,
64
)
args
.
encoder_ffn_embed_dim
=
getattr
(
args
,
'encoder_ffn_embed_dim'
,
256
)
args
.
encoder_attention_heads
=
getattr
(
args
,
'encoder_attention_heads'
,
4
)
args
.
decoder_embed_dim
=
getattr
(
args
,
'decoder_embed_dim'
,
64
)
args
.
decoder_ffn_embed_dim
=
getattr
(
args
,
'decoder_ffn_embed_dim'
,
256
)
args
.
decoder_attention_heads
=
getattr
(
args
,
'decoder_attention_heads'
,
4
)
relative_transformer_t2t_wmt_en_de
(
args
)
# parameters used in the "Attention Is All You Need" paper (Vaswani, et al, 2017)
@register_model_architecture
(
'relative_transformer'
,
'relative_transformer_vaswani_wmt_en_de_big'
)
def
relative_transformer_vaswani_wmt_en_de_big
(
args
):
args
.
encoder_embed_dim
=
getattr
(
args
,
'encoder_embed_dim'
,
1024
)
args
.
encoder_ffn_embed_dim
=
getattr
(
args
,
'encoder_ffn_embed_dim'
,
4096
)
args
.
encoder_attention_heads
=
getattr
(
args
,
'encoder_attention_heads'
,
16
)
args
.
encoder_normalize_before
=
getattr
(
args
,
'encoder_normalize_before'
,
False
)
args
.
decoder_embed_dim
=
getattr
(
args
,
'decoder_embed_dim'
,
1024
)
args
.
decoder_ffn_embed_dim
=
getattr
(
args
,
'decoder_ffn_embed_dim'
,
4096
)
args
.
decoder_attention_heads
=
getattr
(
args
,
'decoder_attention_heads'
,
16
)
args
.
dropout
=
getattr
(
args
,
'dropout'
,
0.3
)
base_architecture
(
args
)
@register_model_architecture
(
'relative_transformer'
,
'relative_transformer_vaswani_wmt_en_fr_big'
)
def
relative_transformer_vaswani_wmt_en_fr_big
(
args
):
args
.
dropout
=
getattr
(
args
,
'dropout'
,
0.1
)
relative_transformer_vaswani_wmt_en_de_big
(
args
)
@register_model_architecture
(
'relative_transformer'
,
'relative_transformer_wmt_en_de_big'
)
def
relative_transformer_wmt_en_de_big
(
args
):
args
.
attention_dropout
=
getattr
(
args
,
'attention_dropout'
,
0.1
)
relative_transformer_vaswani_wmt_en_de_big
(
args
)
# default parameters used in tensor2tensor implementation
@register_model_architecture
(
'relative_transformer'
,
'relative_transformer_wmt_en_de_big_t2t'
)
def
relative_transformer_wmt_en_de_big_t2t
(
args
):
args
.
encoder_normalize_before
=
getattr
(
args
,
'encoder_normalize_before'
,
True
)
args
.
encoder_normalize_before
=
getattr
(
args
,
'decoder_normalize_before'
,
True
)
args
.
attention_dropout
=
getattr
(
args
,
'attention_dropout'
,
0.1
)
args
.
relu_dropout
=
getattr
(
args
,
'relu_dropout'
,
0.1
)
relative_transformer_vaswani_wmt_en_de_big
(
args
)
fairseq/modules/relative_multihead_attention.py
查看文件 @
0d041293
# Copyright (c) 2017-present, Facebook, Inc.
# All rights reserved.
#
# This source code is licensed under the license found in the LICENSE file in
# the root directory of this source tree. An additional grant of patent rights
# can be found in the PATENTS file in the same directory.
import
torch
from
torch
import
nn
from
torch.nn
import
Parameter
import
torch.nn.functional
as
F
from
fairseq
import
utils
class
RelativeMultiheadAttention
(
nn
.
Module
):
"""Multi-headed attention use relative position information to enchance the attention.
See "Self-Attention with Relative Position Representations" for more details.
"""
def
__init__
(
self
,
embed_dim
,
num_heads
,
max_relative_length
,
dropout
=
0.
,
bias
=
True
):
super
()
.
__init__
()
self
.
embed_dim
=
embed_dim
self
.
num_heads
=
num_heads
self
.
max_relative_length
=
max_relative_length
self
.
dropout
=
dropout
self
.
head_dim
=
embed_dim
//
num_heads
assert
self
.
head_dim
*
num_heads
==
self
.
embed_dim
self
.
scaling
=
self
.
head_dim
**
-
0.5
self
.
_mask
=
None
self
.
in_proj_weight
=
Parameter
(
torch
.
Tensor
(
3
*
embed_dim
,
embed_dim
))
if
bias
:
self
.
in_proj_bias
=
Parameter
(
torch
.
Tensor
(
3
*
embed_dim
))
else
:
self
.
register_parameter
(
'in_proj_bias'
,
None
)
self
.
out_proj
=
nn
.
Linear
(
embed_dim
,
embed_dim
,
bias
=
bias
)
self
.
relative_position_keys
=
Parameter
(
torch
.
Tensor
(
2
*
self
.
max_relative_length
+
1
,
self
.
head_dim
))
self
.
relative_position_values
=
Parameter
(
torch
.
Tensor
(
2
*
self
.
max_relative_length
+
1
,
self
.
head_dim
))
self
.
reset_parameters
()
def
reset_parameters
(
self
):
nn
.
init
.
xavier_uniform_
(
self
.
in_proj_weight
)
nn
.
init
.
xavier_uniform_
(
self
.
out_proj
.
weight
)
nn
.
init
.
xavier_normal_
(
self
.
relative_position_keys
)
nn
.
init
.
xavier_normal_
(
self
.
relative_position_values
)
if
self
.
in_proj_bias
is
not
None
:
nn
.
init
.
constant_
(
self
.
in_proj_bias
,
0.
)
nn
.
init
.
constant_
(
self
.
out_proj
.
bias
,
0.
)
def
forward
(
self
,
query
,
key
,
value
,
mask_future_timesteps
=
False
,
key_padding_mask
=
None
,
incremental_state
=
None
,
need_weights
=
True
,
static_kv
=
False
):
"""Input shape: Time x Batch x Channel
Self-attention can be implemented by passing in the same arguments for
query, key and value. Future timesteps can be masked with the
`mask_future_timesteps` argument. Padding elements can be excluded from
the key by passing a binary ByteTensor (`key_padding_mask`) with shape:
batch x src_len, where padding elements are indicated by 1s.
"""
qkv_same
=
query
.
data_ptr
()
==
key
.
data_ptr
()
==
value
.
data_ptr
()
kv_same
=
key
.
data_ptr
()
==
value
.
data_ptr
()
tgt_len
,
bsz
,
embed_dim
=
query
.
size
()
assert
embed_dim
==
self
.
embed_dim
assert
list
(
query
.
size
())
==
[
tgt_len
,
bsz
,
embed_dim
]
assert
key
.
size
()
==
value
.
size
()
if
incremental_state
is
not
None
:
saved_state
=
self
.
_get_input_buffer
(
incremental_state
)
if
'prev_key'
in
saved_state
:
# previous time steps are cached - no need to recompute
# key and value if they are static
if
static_kv
:
assert
kv_same
and
not
qkv_same
key
=
value
=
None
else
:
saved_state
=
None
if
qkv_same
:
# self-attention
q
,
k
,
v
=
self
.
in_proj_qkv
(
query
)
elif
kv_same
:
# encoder-decoder attention
q
=
self
.
in_proj_q
(
query
)
if
key
is
None
:
assert
value
is
None
# this will allow us to concat it with previous value and get
# just get the previous value
k
=
v
=
q
.
new
(
0
)
else
:
k
,
v
=
self
.
in_proj_kv
(
key
)
else
:
q
=
self
.
in_proj_q
(
query
)
k
=
self
.
in_proj_k
(
key
)
v
=
self
.
in_proj_v
(
value
)
q
*=
self
.
scaling
if
saved_state
is
not
None
:
if
'prev_key'
in
saved_state
:
k
=
torch
.
cat
((
saved_state
[
'prev_key'
],
k
),
dim
=
0
)
if
'prev_value'
in
saved_state
:
v
=
torch
.
cat
((
saved_state
[
'prev_value'
],
v
),
dim
=
0
)
saved_state
[
'prev_key'
]
=
k
saved_state
[
'prev_value'
]
=
v
self
.
_set_input_buffer
(
incremental_state
,
saved_state
)
src_len
=
k
.
size
(
0
)
if
key_padding_mask
is
not
None
:
assert
key_padding_mask
.
size
(
0
)
==
bsz
assert
key_padding_mask
.
size
(
1
)
==
src_len
q
=
q
.
contiguous
()
.
view
(
tgt_len
,
bsz
*
self
.
num_heads
,
self
.
head_dim
)
.
transpose
(
0
,
1
)
k
=
k
.
contiguous
()
.
view
(
src_len
,
bsz
*
self
.
num_heads
,
self
.
head_dim
)
.
transpose
(
0
,
1
)
v
=
v
.
contiguous
()
.
view
(
src_len
,
bsz
*
self
.
num_heads
,
self
.
head_dim
)
.
transpose
(
0
,
1
)
if
not
incremental_state
:
assert
q
.
size
()
==
k
.
size
()
==
v
.
size
()
length
=
k
.
size
()[
1
]
relative_positions_matrix
=
self
.
_generate_relative_positions_matrix
(
length
,
self
.
max_relative_length
,
incremental_state
)
#print("relative_positions_matrix : {}".format(relative_positions_matrix))
relation_keys
=
F
.
embedding
(
relative_positions_matrix
.
long
()
.
cuda
(),
self
.
relative_position_keys
)
relation_values
=
F
.
embedding
(
relative_positions_matrix
.
long
()
.
cuda
(),
self
.
relative_position_values
)
relative_attn_weights
=
self
.
_relative_attention_inner
(
q
,
k
,
relation_keys
,
transpose
=
True
)
assert
list
(
relative_attn_weights
.
size
())
==
[
bsz
*
self
.
num_heads
,
tgt_len
,
src_len
]
# only apply masking at training time (when incremental state is None)
if
mask_future_timesteps
and
incremental_state
is
None
:
assert
query
.
size
()
==
key
.
size
(),
\
'mask_future_timesteps only applies to self-attention'
relative_attn_weights
+=
self
.
buffered_mask
(
relative_attn_weights
)
.
unsqueeze
(
0
)
if
key_padding_mask
is
not
None
:
# don't attend to padding symbols
relative_attn_weights
=
relative_attn_weights
.
view
(
bsz
,
self
.
num_heads
,
tgt_len
,
src_len
)
relative_attn_weights
=
relative_attn_weights
.
float
()
.
masked_fill
(
key_padding_mask
.
unsqueeze
(
1
)
.
unsqueeze
(
2
),
float
(
'-inf'
),
)
.
type_as
(
relative_attn_weights
)
# FP16 support: cast to float and back
relative_attn_weights
=
relative_attn_weights
.
view
(
bsz
*
self
.
num_heads
,
tgt_len
,
src_len
)
relative_attn_weights
=
F
.
softmax
(
relative_attn_weights
.
float
(),
dim
=-
1
)
.
type_as
(
relative_attn_weights
)
relative_attn_weights
=
F
.
dropout
(
relative_attn_weights
,
p
=
self
.
dropout
,
training
=
self
.
training
)
attn
=
self
.
_relative_attention_inner
(
relative_attn_weights
,
v
,
relation_values
,
transpose
=
False
)
assert
list
(
attn
.
size
())
==
[
bsz
*
self
.
num_heads
,
tgt_len
,
self
.
head_dim
]
attn
=
attn
.
transpose
(
0
,
1
)
.
contiguous
()
.
view
(
tgt_len
,
bsz
,
embed_dim
)
attn
=
self
.
out_proj
(
attn
)
#print("attn : {}".format(attn))
if
need_weights
:
# average attention weights over heads
relative_attn_weights
=
relative_attn_weights
.
view
(
bsz
,
self
.
num_heads
,
tgt_len
,
src_len
)
relative_attn_weights
=
relative_attn_weights
.
sum
(
dim
=
1
)
/
self
.
num_heads
else
:
relative_attn_weights
=
None
return
attn
,
relative_attn_weights
def
in_proj_qkv
(
self
,
query
):
return
self
.
_in_proj
(
query
)
.
chunk
(
3
,
dim
=-
1
)
def
in_proj_kv
(
self
,
key
):
return
self
.
_in_proj
(
key
,
start
=
self
.
embed_dim
)
.
chunk
(
2
,
dim
=-
1
)
def
in_proj_q
(
self
,
query
):
return
self
.
_in_proj
(
query
,
end
=
self
.
embed_dim
)
def
in_proj_k
(
self
,
key
):
return
self
.
_in_proj
(
key
,
start
=
self
.
embed_dim
,
end
=
2
*
self
.
embed_dim
)
def
in_proj_v
(
self
,
value
):
return
self
.
_in_proj
(
value
,
start
=
2
*
self
.
embed_dim
)
def
_in_proj
(
self
,
input
,
start
=
None
,
end
=
None
):
weight
=
self
.
in_proj_weight
bias
=
self
.
in_proj_bias
if
end
is
not
None
:
weight
=
weight
[:
end
,
:]
if
bias
is
not
None
:
bias
=
bias
[:
end
]
if
start
is
not
None
:
weight
=
weight
[
start
:,
:]
if
bias
is
not
None
:
bias
=
bias
[
start
:]
return
F
.
linear
(
input
,
weight
,
bias
)
def
buffered_mask
(
self
,
tensor
):
dim
=
tensor
.
size
(
-
1
)
if
self
.
_mask
is
None
:
self
.
_mask
=
torch
.
triu
(
utils
.
fill_with_neg_inf
(
tensor
.
new
(
dim
,
dim
)),
1
)
if
self
.
_mask
.
size
(
0
)
<
dim
:
self
.
_mask
=
torch
.
triu
(
utils
.
fill_with_neg_inf
(
self
.
_mask
.
resize_
(
dim
,
dim
)),
1
)
return
self
.
_mask
[:
dim
,
:
dim
]
def
reorder_incremental_state
(
self
,
incremental_state
,
new_order
):
"""Reorder buffered internal state (for incremental generation)."""
input_buffer
=
self
.
_get_input_buffer
(
incremental_state
)
if
input_buffer
is
not
None
:
for
k
in
input_buffer
.
keys
():
input_buffer
[
k
]
=
input_buffer
[
k
]
.
index_select
(
1
,
new_order
)
self
.
_set_input_buffer
(
incremental_state
,
input_buffer
)
def
_get_input_buffer
(
self
,
incremental_state
):
return
utils
.
get_incremental_state
(
self
,
incremental_state
,
'attn_state'
,
)
or
{}
def
_set_input_buffer
(
self
,
incremental_state
,
buffer
):
utils
.
set_incremental_state
(
self
,
incremental_state
,
'attn_state'
,
buffer
,
)
def
_generate_relative_positions_matrix
(
self
,
length
,
max_relative_length
,
incremental_state
):
if
not
incremental_state
:
# training process
range_vec
=
torch
.
arange
(
length
)
range_mat
=
range_vec
.
repeat
(
length
,
1
)
distance_mat
=
range_mat
-
range_mat
.
transpose
(
0
,
1
)
else
:
distance_mat
=
torch
.
range
(
-
length
+
1
,
0
)
.
view
(
1
,
-
1
)
distance_mat_clipped
=
torch
.
clamp
(
distance_mat
,
-
max_relative_length
,
max_relative_length
)
# position difference.
final_mat
=
distance_mat_clipped
+
max_relative_length
return
final_mat
def
_relative_attention_inner
(
self
,
x
,
y
,
z
,
transpose
=
True
):
"""Relative position-aware dot-product attention inner calculation.
This batches matrix multiply calculations to avoid unnecessary broadcasting.
Args:
x: Tensor with shape [batch_size*heads, length, length or depth].
y: Tensor with shape [batch_size*heads, length, depth].
z: Tensor with shape [length, length, depth].
transpose: Whether to tranpose inner matrices of y and z. Should be true if
last dimension of x is depth, not length.
Returns:
A Tensor with shape [batch_size*heads, length, length or depth].
wq: this function actually does 'X(Y+Z)', where Z is vector,
but factor above formular as: 'XY + XZ'
"""
batch_size_mul_head
=
x
.
size
()[
0
]
length
=
z
.
size
()[
0
]
#print(batch_size_mul_head, length)
# xy_matmul is [batch_size*heads, length, length or depth]
if
transpose
:
y
=
y
.
transpose
(
1
,
2
)
xy_matmul
=
torch
.
bmm
(
x
,
y
)
# x_t is [length, batch_size * heads, length or depth]
x_t
=
x
.
transpose
(
0
,
1
)
# x_tz_matmul is [length, batch_size * heads, length or depth]
if
transpose
:
z
=
z
.
transpose
(
1
,
2
)
x_tz_matmul
=
torch
.
bmm
(
x_t
,
z
)
.
transpose
(
0
,
1
)
.
view
(
batch_size_mul_head
,
length
,
-
1
)
#assert xy_matmul.size() == x_tz_matmul.size()
return
xy_matmul
+
x_tz_matmul
编写
预览
Markdown
格式
0%
重试
或
添加新文件
添加附件
取消
您添加了
0
人
到此讨论。请谨慎行事。
请先完成此评论的编辑!
取消
请
注册
或者
登录
后发表评论