Commit 0e2452b9 by xuchen

fix a bug in the relative multi-head attention

parent 61cf1afa
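The change set registers a new "relative" self-attention type for both the S2T Transformer and the base Transformer, exposes it through the --encoder-attention-type and --decoder-attention-type options together with --max-relative-length and --k-only, and restructures the parameter setup in RelativeMultiheadAttention so that relative_position_values is created and initialized in one place, only when k_only is disabled.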
@@ -147,7 +147,8 @@ class S2TTransformerModel(FairseqEncoderDecoderModel):
             default="selfattn",
             choices=[
                 "selfattn",
-                "rel_selfattn"
+                "rel_selfattn",
+                "relative",
             ],
             help="transformer encoder self-attention layer type"
         )
@@ -183,7 +184,8 @@ class S2TTransformerModel(FairseqEncoderDecoderModel):
             default="selfattn",
             choices=[
                 "selfattn",
-                "rel_selfattn"
+                "rel_selfattn",
+                "relative",
             ],
             help="transformer decoder self-attention layer type"
         )
......
@@ -194,6 +194,29 @@ class TransformerModel(FairseqEncoderDecoderModel):
                             help='block size of quantization noise at training time')
         parser.add_argument('--quant-noise-scalar', type=float, metavar='D', default=0,
                             help='scalar quantization noise and scalar quantization at training time')
+        parser.add_argument(
+            "--encoder-attention-type",
+            type=str,
+            default="selfattn",
+            choices=[
+                "selfattn",
+                "rel_selfattn",
+                "relative",
+            ],
+            help="transformer encoder self-attention layer type"
+        )
+        parser.add_argument(
+            "--decoder-attention-type",
+            type=str,
+            default="selfattn",
+            choices=[
+                "selfattn",
+                "rel_selfattn",
+                "relative",
+            ],
+            help="transformer decoder self-attention layer type"
+        )
+        parser.add_argument('--max-relative-length', type=int, default=-1,
+                            help='the max relative length')
+        parser.add_argument('--k-only', default=False, action='store_true',
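Taken together, the three new options behave like ordinary argparse flags. The following self-contained sketch (a stand-in parser, not the fairseq CLI itself) shows what a typical invocation yields:

    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument("--encoder-attention-type", type=str, default="selfattn",
                        choices=["selfattn", "rel_selfattn", "relative"])
    parser.add_argument("--max-relative-length", type=int, default=-1)
    parser.add_argument("--k-only", default=False, action="store_true")

    # e.g. clipped relative positions applied to the keys only
    args = parser.parse_args(
        ["--encoder-attention-type", "relative", "--max-relative-length", "8", "--k-only"])
    print(args.encoder_attention_type, args.max_relative_length, args.k_only)
    # -> relative 8 True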
@@ -1134,6 +1157,10 @@ def base_architecture(args):
     args.quant_noise_pq = getattr(args, "quant_noise_pq", 0)
     args.quant_noise_pq_block_size = getattr(args, "quant_noise_pq_block_size", 8)
     args.quant_noise_scalar = getattr(args, "quant_noise_scalar", 0)
+    args.encoder_attention_type = getattr(args, "encoder_attention_type", "selfattn")
+    args.decoder_attention_type = getattr(args, "decoder_attention_type", "selfattn")
+    args.max_relative_length = getattr(args, 'max_relative_length', -1)
+    args.k_only = getattr(args, 'k_only', True)
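Note that the fallback default for k_only here (True) differs from the argparse default above (False). Since argparse always sets the attribute on the parsed namespace, the getattr fallback only fires for args objects assembled without the CLI parser. A self-contained sketch of the pattern:

    from argparse import Namespace

    def apply_defaults(args):
        # mirrors the getattr pattern of base_architecture (illustrative copy)
        args.max_relative_length = getattr(args, "max_relative_length", -1)
        args.k_only = getattr(args, "k_only", True)

    cli_args = Namespace(k_only=False)   # what argparse produces by default
    apply_defaults(cli_args)
    print(cli_args.k_only)               # False -- an existing attribute wins

    bare_args = Namespace()              # args built programmatically
    apply_defaults(bare_args)
    print(bare_args.k_only)              # True -- the fallback applies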
......
 import torch
 import torch.nn as nn
-from fairseq.models.transformer import LayerNorm
+from fairseq.modules.layer_norm import LayerNorm
 import queue
 import numpy as np
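LayerNorm is now imported from its defining module, fairseq.modules.layer_norm, rather than re-imported through fairseq.models.transformer; going through the models package pulls in the full transformer model definition and, presumably, risked a circular import here.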
......
@@ -3,13 +3,12 @@
 # This source code is licensed under the MIT license found in the
 # LICENSE file in the root directory of this source tree.
 import math
 from typing import Dict, Optional, Tuple
 import torch
 import torch.nn.functional as F
 from fairseq import utils
-from fairseq.modules.multihead_attention import MultiheadAttention
+from fairseq.modules import MultiheadAttention
 from torch import Tensor, nn
 from torch.nn import Parameter
@@ -56,11 +55,10 @@ class RelativeMultiheadAttention(MultiheadAttention):
         self.k_only = k_only
         self.relative_position_keys = Parameter(torch.Tensor(2 * self.max_relative_length + 1, self.head_dim))
-        if not self.k_only:
-            self.relative_position_values = Parameter(torch.Tensor(2 * self.max_relative_length + 1, self.head_dim))
         nn.init.xavier_uniform_(self.relative_position_keys)
         if not self.k_only:
+            self.relative_position_values = Parameter(torch.Tensor(2 * self.max_relative_length + 1, self.head_dim))
             nn.init.xavier_uniform_(self.relative_position_values)

     def forward(
......
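The fix above moves the creation of relative_position_values inside the same k_only guard as its initialization, so a keys-only model never allocates the unused value table. A minimal, self-contained sketch of the pattern, in the style of Shaw et al. (2018) relative position representations (this class is illustrative, not the fairseq source):

    import torch
    import torch.nn as nn
    from torch.nn import Parameter

    class RelPosEmbeddings(nn.Module):
        def __init__(self, max_relative_length, head_dim, k_only=True):
            super().__init__()
            self.max_relative_length = max_relative_length
            self.k_only = k_only
            # one embedding per clipped offset in [-max, ..., +max]
            num_positions = 2 * max_relative_length + 1
            self.relative_position_keys = Parameter(
                torch.Tensor(num_positions, head_dim))
            nn.init.xavier_uniform_(self.relative_position_keys)
            if not self.k_only:
                # value embeddings exist only when k_only is disabled
                self.relative_position_values = Parameter(
                    torch.Tensor(num_positions, head_dim))
                nn.init.xavier_uniform_(self.relative_position_values)

        def forward(self, length):
            # Pairwise offsets j - i, clipped to [-max, +max] and shifted
            # to be >= 0, then used to index the key embedding table.
            pos = torch.arange(length)
            dist = (pos[None, :] - pos[:, None]).clamp(
                -self.max_relative_length, self.max_relative_length)
            return self.relative_position_keys[dist + self.max_relative_length]

    emb = RelPosEmbeddings(max_relative_length=8, head_dim=64)
    print(emb(5).shape)  # torch.Size([5, 5, 64])

Clipping the offsets keeps the embedding table at a fixed size of 2 * max_relative_length + 1 rows regardless of sequence length, which is what makes the --max-relative-length option above meaningful.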