Commit 0e2452b9 by xuchen

fix a bug in the relative multi-head attention

parent 61cf1afa
@@ -147,7 +147,8 @@ class S2TTransformerModel(FairseqEncoderDecoderModel):
             default="selfattn",
             choices=[
                 "selfattn",
-                "rel_selfattn"
+                "rel_selfattn",
+                "relative",
             ],
             help="transformer encoder self-attention layer type"
         )
@@ -183,7 +184,8 @@ class S2TTransformerModel(FairseqEncoderDecoderModel):
             default="selfattn",
             choices=[
                 "selfattn",
-                "rel_selfattn"
+                "rel_selfattn",
+                "relative",
             ],
             help="transformer decoder self-attention layer type"
         )
......
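The two hunks above only extend the accepted values of the S2T encoder/decoder attention-type flags with "relative". A minimal, self-contained argparse sketch of the resulting command-line behaviour (illustrative only, not the repository's code):

import argparse

parser = argparse.ArgumentParser()
parser.add_argument(
    "--encoder-attention-type",
    type=str,
    default="selfattn",
    choices=["selfattn", "rel_selfattn", "relative"],  # "relative" is newly accepted
)

args = parser.parse_args(["--encoder-attention-type", "relative"])
print(args.encoder_attention_type)  # -> relative
# Any value outside the choices list would make argparse exit with a usage error.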
@@ -194,6 +194,29 @@ class TransformerModel(FairseqEncoderDecoderModel):
                             help='block size of quantization noise at training time')
         parser.add_argument('--quant-noise-scalar', type=float, metavar='D', default=0,
                             help='scalar quantization noise and scalar quantization at training time')
+        parser.add_argument(
+            "--encoder-attention-type",
+            type=str,
+            default="selfattn",
+            choices=[
+                "selfattn",
+                "rel_selfattn",
+                "relative",
+            ],
+            help="transformer encoder self-attention layer type"
+        )
+        parser.add_argument(
+            "--decoder-attention-type",
+            type=str,
+            default="selfattn",
+            choices=[
+                "selfattn",
+                "rel_selfattn",
+                "relative",
+            ],
+            help="transformer decoder self-attention layer type"
+        )
         parser.add_argument('--max-relative-length', type=int, default=-1,
                             help='the max relative length')
         parser.add_argument('--k-only', default=False, action='store_true',
@@ -1134,6 +1157,10 @@ def base_architecture(args):
     args.quant_noise_pq = getattr(args, "quant_noise_pq", 0)
     args.quant_noise_pq_block_size = getattr(args, "quant_noise_pq_block_size", 8)
     args.quant_noise_scalar = getattr(args, "quant_noise_scalar", 0)
+    args.encoder_attention_type = getattr(args, "encoder_attention_type", "selfattn")
+    args.decoder_attention_type = getattr(args, "decoder_attention_type", "selfattn")
     args.max_relative_length = getattr(args, 'max_relative_length', -1)
     args.k_only = getattr(args, 'k_only', True)
......
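The two added getattr lines follow fairseq's base_architecture convention: an attribute is only filled in when the user (or a named architecture) has not already set it. A small sketch of that pattern, using a bare argparse.Namespace instead of the real fairseq args object (the helper name is illustrative):

from argparse import Namespace

def set_defaults(args):
    # Same pattern as the two added lines: keep an explicit value, else fall back to "selfattn".
    args.encoder_attention_type = getattr(args, "encoder_attention_type", "selfattn")
    args.decoder_attention_type = getattr(args, "decoder_attention_type", "selfattn")

args = Namespace(encoder_attention_type="relative")
set_defaults(args)
print(args.encoder_attention_type, args.decoder_attention_type)  # -> relative selfattn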
 import torch
 import torch.nn as nn
-from fairseq.models.transformer import LayerNorm
+from fairseq.modules.layer_norm import LayerNorm
 import queue
 import numpy as np
......
@@ -3,13 +3,12 @@
 # This source code is licensed under the MIT license found in the
 # LICENSE file in the root directory of this source tree.
-import math
 from typing import Dict, Optional, Tuple
 import torch
 import torch.nn.functional as F
 from fairseq import utils
-from fairseq.modules.multihead_attention import MultiheadAttention
+from fairseq.modules import MultiheadAttention
 from torch import Tensor, nn
 from torch.nn import Parameter
@@ -56,11 +55,10 @@ class RelativeMultiheadAttention(MultiheadAttention):
         self.k_only = k_only
         self.relative_position_keys = Parameter(torch.Tensor(2 * self.max_relative_length + 1, self.head_dim))
-        if not self.k_only:
-            self.relative_position_values = Parameter(torch.Tensor(2 * self.max_relative_length + 1, self.head_dim))
         nn.init.xavier_uniform_(self.relative_position_keys)
         if not self.k_only:
+            self.relative_position_values = Parameter(torch.Tensor(2 * self.max_relative_length + 1, self.head_dim))
             nn.init.xavier_uniform_(self.relative_position_values)
     def forward(
......
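The table initialised in the last hunk, relative_position_keys of shape (2 * max_relative_length + 1, head_dim), is the kind of embedding table used for Shaw-style relative position attention: pairwise query/key distances are clipped to [-max_relative_length, max_relative_length] and shifted so they can index the table. A standalone sketch of that lookup, assuming this interpretation; the function name and logic here are illustrative, not this repository's forward() implementation:

import torch
import torch.nn as nn

max_relative_length = 8
head_dim = 64

# Mirrors the shape used in the hunk above: one row per clipped relative distance.
relative_position_keys = nn.Parameter(torch.empty(2 * max_relative_length + 1, head_dim))
nn.init.xavier_uniform_(relative_position_keys)

def relative_key_embeddings(length: int) -> torch.Tensor:
    """Return a (length, length, head_dim) tensor of relative key embeddings."""
    positions = torch.arange(length)
    # Pairwise relative distances j - i, clipped to [-max_relative_length, max_relative_length].
    distance = positions[None, :] - positions[:, None]
    distance = distance.clamp(-max_relative_length, max_relative_length)
    # Shift into [0, 2 * max_relative_length] so the distances index the table.
    return relative_position_keys[distance + max_relative_length]

emb = relative_key_embeddings(10)
print(emb.shape)  # torch.Size([10, 10, 64])

When k_only is False, the second table created in the fixed branch (relative_position_values) would presumably be indexed the same way and folded into the attention-weighted sum, as in Shaw et al. (2018).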