Commit ca78c4b8 by xuchen

add the local attention

parent f1cf477d
@@ -135,3 +135,4 @@ experimental/*
 # Weights and Biases logs
 wandb/
 /examples/translation/iwslt14.tokenized.de-en/
+toy/
@@ -150,6 +150,7 @@ class S2TTransformerModel(FairseqEncoderDecoderModel):
                 "selfattn",
                 "rel_selfattn",
                 "relative",
+                "local",
             ],
             help="transformer encoder self-attention layer type"
         )
@@ -187,6 +188,7 @@ class S2TTransformerModel(FairseqEncoderDecoderModel):
                 "selfattn",
                 "rel_selfattn",
                 "relative",
+                "local",
             ],
             help="transformer decoder self-attention layer type"
         )
@@ -277,6 +279,29 @@ class S2TTransformerModel(FairseqEncoderDecoderModel):
             default="learnable_dense",
             help='decoder layer history type'
         )
+        parser.add_argument(
+            '--hard-mask-window',
+            type=float,
+            metavar="D",
+            default=0,
+            help='window size of local mask'
+        )
+        parser.add_argument(
+            '--gauss-mask-sigma',
+            type=float,
+            metavar="D",
+            default=0,
+            help='standard deviation of the gauss mask'
+        )
+        parser.add_argument(
+            '--init-mask-weight',
+            type=float,
+            metavar="D",
+            default=0.5,
+            help='initialized weight for local mask'
+        )
         pass

     @classmethod
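The three options above are only registered here; the module that consumes them, local_multihead_attention.py, is not included in this excerpt. As a rough, non-authoritative reading: --hard-mask-window plausibly bounds attention to a fixed band around each position, --gauss-mask-sigma softens that band with a Gaussian decay over distance, and --init-mask-weight blends the soft mask with the hard one. The helper below is an illustrative sketch of one way the three values could combine into a mask over attention weights; the function name and the blending rule are assumptions, not the commit's code.

```python
import torch

def local_attention_mask(seq_len, hard_mask_window=0, gauss_mask_sigma=0.0,
                         init_mask_weight=0.5):
    # Illustrative sketch only: combine a hard band mask and a Gaussian
    # distance decay into one (seq_len, seq_len) weighting matrix.
    idx = torch.arange(seq_len)
    dist = (idx[None, :] - idx[:, None]).abs().float()  # |i - j|

    mask = torch.ones(seq_len, seq_len)
    if hard_mask_window > 0:
        # keep only keys within +/- hard_mask_window of the query position
        mask = (dist <= hard_mask_window).float()
    if gauss_mask_sigma > 0:
        # soft decay with distance, blended against the hard/full mask
        gauss = torch.exp(-dist.pow(2) / (2 * gauss_mask_sigma ** 2))
        mask = init_mask_weight * gauss + (1 - init_mask_weight) * mask
    return mask
```

With both window and sigma left at their defaults of 0, this sketch returns an all-ones mask, i.e. plain global attention; the "local" behaviour only kicks in when at least one of the two options is set.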
@@ -627,6 +652,10 @@ def base_architecture(args):
     args.max_decoder_relative_length = getattr(args, 'max_decoder_relative_length', -1)
     args.k_only = getattr(args, 'k_only', True)
+    args.hard_mask_window = getattr(args, 'hard_mask_window', 0)
+    args.gauss_mask_sigma = getattr(args, 'gauss_mask_sigma', 0)
+    args.init_mask_weight = getattr(args, 'init_mask_weight', 0)

 @register_model_architecture("s2t_transformer", "s2t_transformer_s")
 def s2t_transformer_s(args):
...
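base_architecture() follows the usual fairseq pattern of filling in defaults with getattr, so architecture presets only supply a value when the attribute was not already set. Note that the parser default for --init-mask-weight is 0.5 while the getattr fallback here is 0; the fallback only applies when the attribute is missing from args entirely. A minimal standalone illustration of the pattern (the Namespace values are invented for the example):

```python
from argparse import Namespace

# attribute missing: getattr supplies the fallback value
args = Namespace()
args.hard_mask_window = getattr(args, "hard_mask_window", 0)
print(args.hard_mask_window)  # -> 0

# attribute already set (e.g. via the command line): it is kept unchanged
args = Namespace(hard_mask_window=16)
args.hard_mask_window = getattr(args, "hard_mask_window", 0)
print(args.hard_mask_window)  # -> 16
```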
@@ -26,6 +26,7 @@ from .layer_norm import Fp32LayerNorm, LayerNorm
 from .learned_positional_embedding import LearnedPositionalEmbedding
 from .lightweight_convolution import LightweightConv, LightweightConv1dTBC
 from .linearized_convolution import LinearizedConvolution
+from .local_multihead_attention import LocalMultiheadAttention
 from .multihead_attention import MultiheadAttention
 from .positional_embedding import PositionalEmbedding
 from .rel_position_multihead_attention import RelPositionMultiheadAttention

@@ -70,6 +71,7 @@ __all__ = [
     "LightweightConv1dTBC",
     "LightweightConv",
     "LinearizedConvolution",
+    "LocalMultiheadAttention",
     "MultiheadAttention",
     "PositionalEmbedding",
     "RelPositionMultiheadAttention",
...
@@ -12,7 +12,8 @@ from fairseq.modules import (
     LayerNorm,
     MultiheadAttention,
     RelPositionMultiheadAttention,
-    RelativeMultiheadAttention
+    RelativeMultiheadAttention,
+    LocalMultiheadAttention,
 )
 from fairseq.modules.fairseq_dropout import FairseqDropout
 from fairseq.modules.quant_noise import quant_noise

@@ -103,6 +104,21 @@ class TransformerEncoderLayer(nn.Module):
             else:
                 print("The maximum encoder relative length %d can not be -1!" % max_relative_length)
                 exit(1)
+        elif self.attn_type == "local":
+            hard_mask_window = getattr(args, "hard_mask_window", 0)
+            gauss_mask_sigma = getattr(args, "gauss_mask_sigma", 0)
+            init_mask_weight = getattr(args, "init_mask_weight", 0)
+            return LocalMultiheadAttention(
+                embed_dim,
+                args.encoder_attention_heads,
+                dropout=args.attention_dropout,
+                self_attention=True,
+                q_noise=self.quant_noise,
+                qn_block_size=self.quant_noise_block_size,
+                hard_mask_window=hard_mask_window,
+                gauss_mask_sigma=gauss_mask_sigma,
+                init_mask_weight=init_mask_weight
+            )
         else:
             print("The encoder attention type %s is not supported!" % self.attn_type)
             exit(1)
...
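The encoder layer simply forwards the three options into LocalMultiheadAttention, whose implementation is not part of this diff. Purely as an assumed illustration of where such a local mask acts, the toy single-head example below applies a hard band mask to scaled dot-product attention scores before the softmax; the shapes, names, and masking recipe are placeholders, not the commit's module (the Gaussian/blended variant is sketched after the argument definitions above).

```python
import torch
import torch.nn.functional as F

# toy single-head demo; everything here is an assumption, not LocalMultiheadAttention
seq_len, d, window = 10, 64, 3
q, k, v = torch.randn(3, seq_len, d).unbind(0)

idx = torch.arange(seq_len)
band = (idx[None, :] - idx[:, None]).abs() <= window   # True inside the local window

scores = (q @ k.transpose(-2, -1)) / d ** 0.5          # (seq_len, seq_len)
scores = scores.masked_fill(~band, float("-inf"))      # block out-of-window keys
weights = F.softmax(scores, dim=-1)
out = weights @ v                                       # each position attends only locally
```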