Commit f190005c by xuchen

support relative position encoding (RPE) separately for the encoder and decoder

parent 28d33ad8
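This commit splits the single --max-relative-length option into --max-encoder-relative-length and --max-decoder-relative-length, so relative position encoding can use a different window on each side. In the Transformer encoder/decoder layers below, the side-specific value is merged with the legacy option via max(), so configurations that only set --max-relative-length keep working; the Conformer encoder layer, by contrast, reads only --max-encoder-relative-length. A minimal sketch of that resolution logic, outside fairseq (the helper name and the example values are illustrative, not part of the commit):

    from argparse import Namespace

    def effective_relative_length(args, side):
        # Mirrors the max() fallback in the Transformer encoder/decoder layers below:
        # the side-specific option wins, otherwise the legacy --max-relative-length applies.
        return max(getattr(args, "max_%s_relative_length" % side, -1),
                   getattr(args, "max_relative_length", -1))

    # Legacy option only: both sides inherit it.
    legacy_only = Namespace(max_relative_length=20,
                            max_encoder_relative_length=-1,
                            max_decoder_relative_length=-1)
    assert effective_relative_length(legacy_only, "encoder") == 20
    assert effective_relative_length(legacy_only, "decoder") == 20

    # Per-side options set: they take precedence.
    per_side = Namespace(max_relative_length=-1,
                         max_encoder_relative_length=100,
                         max_decoder_relative_length=20)
    assert effective_relative_length(per_side, "encoder") == 100
    assert effective_relative_length(per_side, "decoder") == 20
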
......@@ -532,13 +532,16 @@ def base_architecture(args):
args.encoder_integration_type = getattr(args, 'encoder_integration_type', 'avg')
args.decoder_integration_type = getattr(args, 'decoder_integration_type', 'avg')
args.max_relative_length = getattr(args, 'max_relative_length', -1)
args.max_encoder_relative_length = getattr(args, 'max_encoder_relative_length', -1)
args.max_decoder_relative_length = getattr(args, 'max_decoder_relative_length', -1)
args.k_only = getattr(args, 'k_only', True)
@register_model_architecture("dlcl_transformer", "dlcl_transformer_relative")
def dlcl_transformer_relative(args):
args.max_relative_length = 20
args.max_encoder_relative_length = 20
args.max_decoder_relative_length = 20
args.k_only = True
base_architecture(args)
......
......@@ -185,7 +185,8 @@ def base_architecture(args):
args.no_scale_embedding = getattr(args, "no_scale_embedding", False)
args.quant_noise_pq = getattr(args, "quant_noise_pq", 0)
args.max_relative_length = getattr(args, 'max_relative_length', -1)
args.max_encoder_relative_length = getattr(args, 'max_encoder_relative_length', -1)
args.max_decoder_relative_length = getattr(args, 'max_decoder_relative_length', -1)
args.k_only = getattr(args, 'k_only', True)
......@@ -201,7 +202,8 @@ def s2t_conformer_s(args):
@register_model_architecture("s2t_conformer", "s2t_conformer_s_relative")
def s2t_conformer_s_relative(args):
args.max_relative_length = 20
args.max_encoder_relative_length = 100
args.max_decoder_relative_length = 20
args.k_only = True
s2t_conformer_s(args)
......
......@@ -6,6 +6,7 @@ import math
import torch
import torch.nn as nn
from fairseq import checkpoint_utils
from fairseq.data.data_utils import lengths_to_padding_mask
from fairseq.models import (
FairseqEncoder,
register_model,
......@@ -82,6 +83,15 @@ class S2TSATEModel(S2TTransformerModel):
def build_encoder(cls, args, task=None, embed_tokens=None):
encoder = S2TSATEEncoder(args, task, embed_tokens)
if getattr(args, "load_pretrained_encoder_from", None):
logger.info(
f"loaded pretrained acoustic encoder from: "
f"{args.load_pretrained_encoder_from}"
)
encoder = checkpoint_utils.load_pretrained_component_from_model(
component=encoder, checkpoint=args.load_pretrained_encoder_from, strict=False
)
if getattr(args, "load_pretrained_acoustic_encoder_from", None):
logger.info(
f"loaded pretrained acoustic encoder from: "
......@@ -202,6 +212,7 @@ class TextEncoder(FairseqEncoder):
super().__init__(None)
self.embed_tokens = embed_tokens
self.layers = nn.ModuleList(
[TransformerEncoderLayer(args) for _ in range(args.text_encoder_layers)]
)
......@@ -247,8 +258,19 @@ class S2TSATEEncoder(FairseqEncoder):
# adapter
self.adapter = Adapter(args, task.source_dictionary, embed_tokens)
# self.length_adapter = Conv1dSubsampler(
# args.encoder_embed_dim,
# args.conv_channels,
# args.encoder_embed_dim,
# [int(k) for k in args.conv_kernel_sizes.split(",")],
# )
# acoustic_encoder_attention_type = args.encoder_attention_type
# args.encoder_attention_type = "selfattn"
# text encoder
self.text_encoder = TextEncoder(args, embed_tokens)
# args.encoder_attention_type = acoustic_encoder_attention_type
if getattr(args, "use_enc_dlcl", False):
normalize_before = args.encoder_normalize_before
......@@ -283,6 +305,11 @@ class S2TSATEEncoder(FairseqEncoder):
self.history.add(x)
# src_lengths = (~encoder_padding_mask).sum(1)
# x = x.transpose(0, 1)
# x, input_lengths = self.length_adapter(x, src_lengths)
# encoder_padding_mask = lengths_to_padding_mask(input_lengths)
x = self.text_encoder(x, encoder_padding_mask, positions, self.history)
return {
......@@ -375,7 +402,8 @@ def base_architecture(args):
args.no_scale_embedding = getattr(args, "no_scale_embedding", False)
args.quant_noise_pq = getattr(args, "quant_noise_pq", 0)
args.max_relative_length = getattr(args, 'max_relative_length', -1)
args.max_encoder_relative_length = getattr(args, 'max_encoder_relative_length', -1)
args.max_decoder_relative_length = getattr(args, 'max_decoder_relative_length', -1)
args.k_only = getattr(args, 'k_only', True)
......@@ -391,7 +419,8 @@ def s2t_sate_s(args):
@register_model_architecture("s2t_sate", "s2t_sate_s_relative")
def s2t_sate_s_relative(args):
args.max_relative_length = 20
args.max_encoder_relative_length = 100
args.max_decoder_relative_length = 20
args.k_only = True
s2t_sate_s(args)
......
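The S2TSATEModel change above adds a path to warm-start the whole SATE encoder (acoustic encoder, adapter and text encoder) from a checkpoint via --load-pretrained-encoder-from; strict=False is passed, presumably so that loading does not abort when the checkpoint and the model do not match exactly. A hedged usage sketch of the same call as in the diff (the wrapper name is illustrative only):

    from fairseq import checkpoint_utils

    def warm_start_sate_encoder(encoder, checkpoint_path):
        # Non-strict loading, as in the diff above: parameters found in the
        # checkpoint are copied, missing or unexpected ones are tolerated.
        return checkpoint_utils.load_pretrained_component_from_model(
            component=encoder,
            checkpoint=checkpoint_path,
            strict=False,
        )
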
......@@ -220,7 +220,9 @@ class S2TTransformerModel(FairseqEncoderDecoderModel):
action="store_true",
help="if True, dont scale embeddings",
)
parser.add_argument('--max-relative-length', type=int, default=-1,
parser.add_argument('--max-encoder-relative-length', type=int, default=-1,
help='the max encoder relative length')
parser.add_argument('--max-decoder-relative-length', type=int, default=-1,
help='the max decoder relative length')
parser.add_argument('--k-only', default=False, action='store_true',
help='select the relative mode to map relative position information')
......@@ -567,7 +569,8 @@ def base_architecture(args):
args.no_scale_embedding = getattr(args, "no_scale_embedding", False)
args.quant_noise_pq = getattr(args, "quant_noise_pq", 0)
args.max_relative_length = getattr(args, 'max_relative_length', -1)
args.max_encoder_relative_length = getattr(args, 'max_encoder_relative_length', -1)
args.max_decoder_relative_length = getattr(args, 'max_decoder_relative_length', -1)
args.k_only = getattr(args, 'k_only', True)
......@@ -583,7 +586,8 @@ def s2t_transformer_s(args):
@register_model_architecture("s2t_transformer", "s2t_transformer_s_relative")
def s2t_transformer_s_relative(args):
args.max_relative_length = 20
args.max_encoder_relative_length = 20
args.max_decoder_relative_length = 20
args.k_only = True
s2t_transformer_s(args)
......
......@@ -218,8 +218,10 @@ class TransformerModel(FairseqEncoderDecoderModel):
],
help="transformer decoder self-attention layer type"
)
parser.add_argument('--max-relative-length', type=int, default=-1,
help='the max relative length')
parser.add_argument('--max-encoder-relative-length', type=int, default=-1,
help='the max encoder relative length')
parser.add_argument('--max-decoder-relative-length', type=int, default=-1,
help='the max decoder relative length')
parser.add_argument('--k-only', default=False, action='store_true',
help='select the relative mode to map relative position information')
# args for loading pre-trained models
......@@ -1182,13 +1184,15 @@ def base_architecture(args):
args.encoder_attention_type = getattr(args, "encoder_attention_type", "selfattn")
args.decoder_attention_type = getattr(args, "decoder_attention_type", "selfattn")
args.max_relative_length = getattr(args, 'max_relative_length', -1)
args.max_encoder_relative_length = getattr(args, 'max_encoder_relative_length', -1)
args.max_decoder_relative_length = getattr(args, 'max_decoder_relative_length', -1)
args.k_only = getattr(args, 'k_only', True)
@register_model_architecture("transformer", "transformer_relative")
def transformer_rpr(args):
args.max_relative_length = 20
args.max_encoder_relative_length = 20
args.max_decoder_relative_length = 20
args.k_only = True
base_architecture(args)
......
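Both new options default to -1 (disabled), so one convenient way to turn them on is through a registered architecture, as transformer_relative above does. A hedged sketch of a custom variant with asymmetric windows (the architecture name and window sizes are illustrative; the 100/20 split mirrors the s2t_*_relative architectures above, and "relative" is assumed to be a valid encoder/decoder attention type, as required by the layer code in this commit):

    from fairseq.models import register_model_architecture
    from fairseq.models.transformer import base_architecture

    @register_model_architecture("transformer", "transformer_relative_asymmetric")
    def transformer_relative_asymmetric(args):
        # Select the relative-attention implementation on both sides.
        args.encoder_attention_type = "relative"
        args.decoder_attention_type = "relative"
        # Wider relative window for the encoder, narrower for the decoder.
        args.max_encoder_relative_length = 100
        args.max_decoder_relative_length = 20
        args.k_only = True
        base_architecture(args)
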
......@@ -8,7 +8,13 @@ from typing import Optional
import torch
import torch.nn as nn
from fairseq import utils
from fairseq.modules import LayerNorm, MultiheadAttention, RelPositionMultiheadAttention, ConvolutionModule
from fairseq.modules import (
LayerNorm,
MultiheadAttention,
RelPositionMultiheadAttention,
RelativeMultiheadAttention,
ConvolutionModule
)
from fairseq.modules.fairseq_dropout import FairseqDropout
from fairseq.modules.quant_noise import quant_noise
from torch import Tensor
......@@ -112,6 +118,21 @@ class ConformerEncoderLayer(nn.Module):
attn_func = MultiheadAttention
elif self.attn_type == "rel_selfattn":
attn_func = RelPositionMultiheadAttention
elif self.attn_type == "relative":
max_relative_length = getattr(args, "max_encoder_relative_length", -1)
if max_relative_length != -1:
return RelativeMultiheadAttention(
embed_dim,
args.encoder_attention_heads,
dropout=args.attention_dropout,
self_attention=True,
q_noise=self.quant_noise,
qn_block_size=self.quant_noise_block_size,
max_relative_length=max_relative_length,
)
else:
print("The maximum encoder relative length %d can not be -1!" % max_relative_length)
exit(1)
else:
attn_func = MultiheadAttention
print("The attention type %s is not supported!" % self.attn_type)
......
......@@ -87,18 +87,24 @@ class TransformerEncoderLayer(nn.Module):
attn_func = MultiheadAttention
elif self.attn_type == "rel_selfattn":
attn_func = RelPositionMultiheadAttention
elif self.attn_type == "relative" or getattr(args, "max_relative_length", -1) != -1:
return RelativeMultiheadAttention(
embed_dim,
args.encoder_attention_heads,
dropout=args.attention_dropout,
self_attention=True,
q_noise=self.quant_noise,
qn_block_size=self.quant_noise_block_size,
max_relative_length=args.max_relative_length,
)
elif self.attn_type == "relative":
# max_relative_length = getattr(args, "max_encoder_relative_length", -1)
max_relative_length = max(getattr(args, "max_encoder_relative_length", -1), getattr(args, "max_relative_length", -1))
if max_relative_length != -1:
return RelativeMultiheadAttention(
embed_dim,
args.encoder_attention_heads,
dropout=args.attention_dropout,
self_attention=True,
q_noise=self.quant_noise,
qn_block_size=self.quant_noise_block_size,
max_relative_length=max_relative_length,
)
else:
print("The maximum encoder relative length %d can not be -1!" % max_relative_length)
exit(1)
else:
print("The attention type %s is not supported!" % self.attn_type)
print("The encoder attention type %s is not supported!" % self.attn_type)
exit(1)
return attn_func(
......@@ -292,18 +298,23 @@ class TransformerDecoderLayer(nn.Module):
attn_func = MultiheadAttention
elif self.attn_type == "rel_selfattn":
attn_func = RelPositionMultiheadAttention
elif self.attn_type == "relative" or getattr(args, "max_relative_length", -1) != -1:
return RelativeMultiheadAttention(
embed_dim,
args.encoder_attention_heads,
dropout=args.attention_dropout,
self_attention=True,
q_noise=self.quant_noise,
qn_block_size=self.quant_noise_block_size,
max_relative_length=args.max_relative_length,
)
elif self.attn_type == "relative":
max_relative_length = max(getattr(args, "max_decoder_relative_length", -1), getattr(args, "max_relative_length", -1))
if max_relative_length != -1:
return RelativeMultiheadAttention(
embed_dim,
args.decoder_attention_heads,
dropout=args.attention_dropout,
self_attention=True,
q_noise=self.quant_noise,
qn_block_size=self.quant_noise_block_size,
max_relative_length=max_relative_length,
)
else:
print("The maximum decoder relative length %d can not be -1!" % max_relative_length)
exit(1)
else:
print("The attention type %s is not supported!" % self.attn_type)
print("The decoder attention type %s is not supported!" % self.attn_type)
exit(1)
return attn_func(
......