Commit ca78c4b8 by xuchen

Add local attention

parent f1cf477d
@@ -135,3 +135,4 @@ experimental/*
# Weights and Biases logs
wandb/
/examples/translation/iwslt14.tokenized.de-en/
toy/
@@ -150,6 +150,7 @@ class S2TTransformerModel(FairseqEncoderDecoderModel):
                "selfattn",
                "rel_selfattn",
                "relative",
                "local",
            ],
            help="transformer encoder self-attention layer type"
        )
@@ -187,6 +188,7 @@ class S2TTransformerModel(FairseqEncoderDecoderModel):
                "selfattn",
                "rel_selfattn",
                "relative",
                "local",
            ],
            help="transformer decoder self-attention layer type"
        )
@@ -277,6 +279,29 @@ class S2TTransformerModel(FairseqEncoderDecoderModel):
            default="learnable_dense",
            help='decoder layer history type'
        )
        parser.add_argument(
            '--hard-mask-window',
            type=float,
            metavar="D",
            default=0,
            help='window size of local mask'
        )
        parser.add_argument(
            '--gauss-mask-sigma',
            type=float,
            metavar="D",
            default=0,
            help='standard deviation of the gauss mask'
        )
        parser.add_argument(
            '--init-mask-weight',
            type=float,
            metavar="D",
            default=0.5,
            help='initialized weight for local mask'
        )
        pass

    @classmethod
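
Note: the three flags above only register hyperparameters; how LocalMultiheadAttention actually consumes them is not visible in this diff. As a rough, purely illustrative sketch (the function name and shapes below are assumptions, not the committed code), a hard window and a Gaussian decay could be combined into an additive attention mask like this:

import torch

def build_local_mask(seq_len, hard_mask_window=0.0, gauss_mask_sigma=0.0):
    # Additive mask over |i - j|, to be added to the attention logits before softmax.
    pos = torch.arange(seq_len)
    dist = (pos.unsqueeze(0) - pos.unsqueeze(1)).abs().float()
    mask = torch.zeros(seq_len, seq_len)
    if hard_mask_window > 0:
        # Hard locality: positions farther than the window get zero attention weight.
        mask = mask.masked_fill(dist > hard_mask_window, float("-inf"))
    if gauss_mask_sigma > 0:
        # Soft locality: a Gaussian penalty that grows with distance.
        mask = mask - dist.pow(2) / (2 * gauss_mask_sigma ** 2)
    return mask

# --init-mask-weight (default 0.5) plausibly initializes a learnable scalar that
# blends the locally masked scores with the unmasked ones.
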
@@ -627,6 +652,10 @@ def base_architecture(args):
    args.max_decoder_relative_length = getattr(args, 'max_decoder_relative_length', -1)
    args.k_only = getattr(args, 'k_only', True)
    args.hard_mask_window = getattr(args, 'hard_mask_window', 0)
    args.gauss_mask_sigma = getattr(args, 'gauss_mask_sigma', 0)
    args.init_mask_weight = getattr(args, 'init_mask_weight', 0)


@register_model_architecture("s2t_transformer", "s2t_transformer_s")
def s2t_transformer_s(args):
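
One detail worth noting in the hunk above: the parser default for --init-mask-weight is 0.5, while the base_architecture fallback is 0. This is not a conflict in the usual flow, because getattr only uses its second argument when the attribute is missing, i.e. when args was built without the parser. A tiny illustration (the Namespace contents are made up):

from argparse import Namespace

# args built programmatically, without the new --init-mask-weight flag
args = Namespace(hard_mask_window=4.0, gauss_mask_sigma=1.5)
args.init_mask_weight = getattr(args, "init_mask_weight", 0)  # falls back to 0
args.hard_mask_window = getattr(args, "hard_mask_window", 0)  # keeps the parsed 4.0
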
@@ -26,6 +26,7 @@ from .layer_norm import Fp32LayerNorm, LayerNorm
from .learned_positional_embedding import LearnedPositionalEmbedding
from .lightweight_convolution import LightweightConv, LightweightConv1dTBC
from .linearized_convolution import LinearizedConvolution
from .local_multihead_attention import LocalMultiheadAttention
from .multihead_attention import MultiheadAttention
from .positional_embedding import PositionalEmbedding
from .rel_position_multihead_attention import RelPositionMultiheadAttention
@@ -70,6 +71,7 @@ __all__ = [
    "LightweightConv1dTBC",
    "LightweightConv",
    "LinearizedConvolution",
    "LocalMultiheadAttention",
    "MultiheadAttention",
    "PositionalEmbedding",
    "RelPositionMultiheadAttention",
@@ -12,7 +12,8 @@ from fairseq.modules import (
    LayerNorm,
    MultiheadAttention,
    RelPositionMultiheadAttention,
    RelativeMultiheadAttention
    RelativeMultiheadAttention,
    LocalMultiheadAttention,
)
from fairseq.modules.fairseq_dropout import FairseqDropout
from fairseq.modules.quant_noise import quant_noise
@@ -103,6 +104,21 @@ class TransformerEncoderLayer(nn.Module):
            else:
                print("The maximum encoder relative length %d can not be -1!" % max_relative_length)
                exit(1)
        elif self.attn_type == "local":
            hard_mask_window = getattr(args, "hard_mask_window", 0)
            gauss_mask_sigma = getattr(args, "gauss_mask_sigma", 0)
            init_mask_weight = getattr(args, "init_mask_weight", 0)
            return LocalMultiheadAttention(
                embed_dim,
                args.encoder_attention_heads,
                dropout=args.attention_dropout,
                self_attention=True,
                q_noise=self.quant_noise,
                qn_block_size=self.quant_noise_block_size,
                hard_mask_window=hard_mask_window,
                gauss_mask_sigma=gauss_mask_sigma,
                init_mask_weight=init_mask_weight
            )
        else:
            print("The encoder attention type %s is not supported!" % self.attn_type)
            exit(1)
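
The new branch forwards the three hyperparameters into LocalMultiheadAttention, whose implementation (the .local_multihead_attention module imported earlier) is not part of the hunks shown here. A hypothetical skeleton of the interface this call implies, assuming the module subclasses fairseq's MultiheadAttention (illustrative only, not the committed code):

import torch
import torch.nn as nn
from fairseq.modules import MultiheadAttention

class LocalMultiheadAttentionSketch(MultiheadAttention):
    def __init__(self, embed_dim, num_heads, dropout=0.0, self_attention=False,
                 q_noise=0.0, qn_block_size=8,
                 hard_mask_window=0, gauss_mask_sigma=0, init_mask_weight=0):
        super().__init__(embed_dim, num_heads, dropout=dropout,
                         self_attention=self_attention,
                         q_noise=q_noise, qn_block_size=qn_block_size)
        self.hard_mask_window = hard_mask_window
        self.gauss_mask_sigma = gauss_mask_sigma
        # Learnable gate for blending local and global attention,
        # initialized from --init-mask-weight.
        self.mask_weight = nn.Parameter(torch.tensor(float(init_mask_weight)))
        # forward() would add a distance-based mask (hard window and/or Gaussian
        # penalty) to the attention logits before the softmax.
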