Commit f190005c by xuchen

Support relative position encoding (RPE) separately for the encoder and decoder

parent 28d33ad8
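Usage note (not part of the commit): the new options let the encoder and decoder use different maximum relative distances. A minimal sketch of how they might be passed on the command line is below, assuming the repository's --encoder-attention-type / --decoder-attention-type flags select the "relative" attention implementation; the data path and other flags are placeholders:

    fairseq-train data-bin/placeholder \
        --arch transformer \
        --encoder-attention-type relative \
        --decoder-attention-type relative \
        --max-encoder-relative-length 20 \
        --max-decoder-relative-length 8 \
        --k-only

The registered *_relative architectures (e.g. transformer_relative) assign both lengths fixed values in the architecture function, so differing per-side lengths are most easily set via the flags as above.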
@@ -532,13 +532,16 @@ def base_architecture(args):
     args.encoder_integration_type = getattr(args, 'encoder_integration_type', 'avg')
     args.decoder_integration_type = getattr(args, 'decoder_integration_type', 'avg')
-    args.max_relative_length = getattr(args, 'max_relative_length', -1)
+    args.max_encoder_relative_length = getattr(args, 'max_encoder_relative_length', -1)
+    args.max_decoder_relative_length = getattr(args, 'max_decoder_relative_length', -1)
     args.k_only = getattr(args, 'k_only', True)

 @register_model_architecture("dlcl_transformer", "dlcl_transformer_relative")
 def dlcl_transformer_relative(args):
-    args.max_relative_length = 20
+    args.max_encoder_relative_length = 20
+    args.max_decoder_relative_length = 20
     args.k_only = True
     base_architecture(args)
......
@@ -185,7 +185,8 @@ def base_architecture(args):
     args.no_scale_embedding = getattr(args, "no_scale_embedding", False)
     args.quant_noise_pq = getattr(args, "quant_noise_pq", 0)
-    args.max_relative_length = getattr(args, 'max_relative_length', -1)
+    args.max_encoder_relative_length = getattr(args, 'max_encoder_relative_length', -1)
+    args.max_decoder_relative_length = getattr(args, 'max_decoder_relative_length', -1)
     args.k_only = getattr(args, 'k_only', True)

@@ -201,7 +202,8 @@ def s2t_conformer_s(args):
 @register_model_architecture("s2t_conformer", "s2t_conformer_s_relative")
 def s2t_conformer_s_relative(args):
-    args.max_relative_length = 20
+    args.max_encoder_relative_length = 100
+    args.max_decoder_relative_length = 20
     args.k_only = True
     s2t_conformer_s(args)
......
@@ -6,6 +6,7 @@ import math
 import torch
 import torch.nn as nn
 from fairseq import checkpoint_utils
+from fairseq.data.data_utils import lengths_to_padding_mask
 from fairseq.models import (
     FairseqEncoder,
     register_model,

@@ -82,6 +83,15 @@ class S2TSATEModel(S2TTransformerModel):
     def build_encoder(cls, args, task=None, embed_tokens=None):
         encoder = S2TSATEEncoder(args, task, embed_tokens)
+        if getattr(args, "load_pretrained_encoder_from", None):
+            logger.info(
+                f"loaded pretrained acoustic encoder from: "
+                f"{args.load_pretrained_encoder_from}"
+            )
+            encoder = checkpoint_utils.load_pretrained_component_from_model(
+                component=encoder, checkpoint=args.load_pretrained_encoder_from, strict=False
+            )
         if getattr(args, "load_pretrained_acoustic_encoder_from", None):
             logger.info(
                 f"loaded pretrained acoustic encoder from: "

@@ -202,6 +212,7 @@ class TextEncoder(FairseqEncoder):
         super().__init__(None)
         self.embed_tokens = embed_tokens
         self.layers = nn.ModuleList(
             [TransformerEncoderLayer(args) for _ in range(args.text_encoder_layers)]
         )

@@ -247,8 +258,19 @@ class S2TSATEEncoder(FairseqEncoder):
         # adapter
         self.adapter = Adapter(args, task.source_dictionary, embed_tokens)
+        # self.length_adapter = Conv1dSubsampler(
+        #     args.encoder_embed_dim,
+        #     args.conv_channels,
+        #     args.encoder_embed_dim,
+        #     [int(k) for k in args.conv_kernel_sizes.split(",")],
+        # )
+        # acoustic_encoder_attention_type = args.encoder_attention_type
+        # args.encoder_attention_type = "selfattn"
         # text encoder
         self.text_encoder = TextEncoder(args, embed_tokens)
+        # args.encoder_attention_type = acoustic_encoder_attention_type
         if getattr(args, "use_enc_dlcl", False):
             normalize_before = args.encoder_normalize_before

@@ -283,6 +305,11 @@
             self.history.add(x)
+        # src_lengths = (~encoder_padding_mask).sum(1)
+        # x = x.transpose(0, 1)
+        # x, input_lengths = self.length_adapter(x, src_lengths)
+        # encoder_padding_mask = lengths_to_padding_mask(input_lengths)
         x = self.text_encoder(x, encoder_padding_mask, positions, self.history)
         return {

@@ -375,7 +402,8 @@ def base_architecture(args):
     args.no_scale_embedding = getattr(args, "no_scale_embedding", False)
     args.quant_noise_pq = getattr(args, "quant_noise_pq", 0)
-    args.max_relative_length = getattr(args, 'max_relative_length', -1)
+    args.max_encoder_relative_length = getattr(args, 'max_encoder_relative_length', -1)
+    args.max_decoder_relative_length = getattr(args, 'max_decoder_relative_length', -1)
     args.k_only = getattr(args, 'k_only', True)

@@ -391,7 +419,8 @@ def s2t_sate_s(args):
 @register_model_architecture("s2t_sate", "s2t_sate_s_relative")
 def s2t_sate_s_relative(args):
-    args.max_relative_length = 20
+    args.max_encoder_relative_length = 100
+    args.max_decoder_relative_length = 20
     args.k_only = True
     s2t_sate_s(args)
......
@@ -220,7 +220,9 @@ class S2TTransformerModel(FairseqEncoderDecoderModel):
             action="store_true",
             help="if True, dont scale embeddings",
         )
-        parser.add_argument('--max-relative-length', type=int, default=-1,
-                            help='the max relative length')
+        parser.add_argument('--max-encoder-relative-length', type=int, default=-1,
+                            help='the max relative length')
+        parser.add_argument('--max-decoder-relative-length', type=int, default=-1,
+                            help='the max relative length')
         parser.add_argument('--k-only', default=False, action='store_true',
                             help='select the relative mode to map relative position information')

@@ -567,7 +569,8 @@ def base_architecture(args):
     args.no_scale_embedding = getattr(args, "no_scale_embedding", False)
     args.quant_noise_pq = getattr(args, "quant_noise_pq", 0)
-    args.max_relative_length = getattr(args, 'max_relative_length', -1)
+    args.max_encoder_relative_length = getattr(args, 'max_encoder_relative_length', -1)
+    args.max_decoder_relative_length = getattr(args, 'max_decoder_relative_length', -1)
     args.k_only = getattr(args, 'k_only', True)

@@ -583,7 +586,8 @@ def s2t_transformer_s(args):
 @register_model_architecture("s2t_transformer", "s2t_transformer_s_relative")
 def s2t_transformer_s_relative(args):
-    args.max_relative_length = 20
+    args.max_encoder_relative_length = 20
+    args.max_decoder_relative_length = 20
     args.k_only = True
     s2t_transformer_s(args)
......
@@ -218,8 +218,10 @@ class TransformerModel(FairseqEncoderDecoderModel):
             ],
             help="transformer decoder self-attention layer type"
         )
-        parser.add_argument('--max-relative-length', type=int, default=-1,
-                            help='the max relative length')
+        parser.add_argument('--max-encoder-relative-length', type=int, default=-1,
+                            help='the max encoder relative length')
+        parser.add_argument('--max-decoder-relative-length', type=int, default=-1,
+                            help='the max decoder relative length')
         parser.add_argument('--k-only', default=False, action='store_true',
                             help='select the relative mode to map relative position information')
         # args for loading pre-trained models

@@ -1182,13 +1184,15 @@ def base_architecture(args):
     args.encoder_attention_type = getattr(args, "encoder_attention_type", "selfattn")
     args.decoder_attention_type = getattr(args, "decoder_attention_type", "selfattn")
-    args.max_relative_length = getattr(args, 'max_relative_length', -1)
+    args.max_encoder_relative_length = getattr(args, 'max_encoder_relative_length', -1)
+    args.max_decoder_relative_length = getattr(args, 'max_decoder_relative_length', -1)
     args.k_only = getattr(args, 'k_only', True)

 @register_model_architecture("transformer", "transformer_relative")
 def transformer_rpr(args):
-    args.max_relative_length = 20
+    args.max_encoder_relative_length = 20
+    args.max_decoder_relative_length = 20
     args.k_only = True
     base_architecture(args)
......
@@ -8,7 +8,13 @@ from typing import Optional
 import torch
 import torch.nn as nn
 from fairseq import utils
-from fairseq.modules import LayerNorm, MultiheadAttention, RelPositionMultiheadAttention, ConvolutionModule
+from fairseq.modules import (
+    LayerNorm,
+    MultiheadAttention,
+    RelPositionMultiheadAttention,
+    RelativeMultiheadAttention,
+    ConvolutionModule
+)
 from fairseq.modules.fairseq_dropout import FairseqDropout
 from fairseq.modules.quant_noise import quant_noise
 from torch import Tensor

@@ -112,6 +118,21 @@ class ConformerEncoderLayer(nn.Module):
             attn_func = MultiheadAttention
         elif self.attn_type == "rel_selfattn":
             attn_func = RelPositionMultiheadAttention
+        elif self.attn_type == "relative":
+            max_relative_length = getattr(args, "max_encoder_relative_length", -1)
+            if max_relative_length != -1:
+                return RelativeMultiheadAttention(
+                    embed_dim,
+                    args.encoder_attention_heads,
+                    dropout=args.attention_dropout,
+                    self_attention=True,
+                    q_noise=self.quant_noise,
+                    qn_block_size=self.quant_noise_block_size,
+                    max_relative_length=max_relative_length,
+                )
+            else:
+                print("The maximum encoder relative length %d can not be -1!" % max_relative_length)
+                exit(1)
         else:
             attn_func = MultiheadAttention
             print("The attention type %s is not supported!" % self.attn_type)
......
@@ -87,18 +87,24 @@ class TransformerEncoderLayer(nn.Module):
             attn_func = MultiheadAttention
         elif self.attn_type == "rel_selfattn":
             attn_func = RelPositionMultiheadAttention
-        elif self.attn_type == "relative" or getattr(args, "max_relative_length", -1) != -1:
-            return RelativeMultiheadAttention(
-                embed_dim,
-                args.encoder_attention_heads,
-                dropout=args.attention_dropout,
-                self_attention=True,
-                q_noise=self.quant_noise,
-                qn_block_size=self.quant_noise_block_size,
-                max_relative_length=args.max_relative_length,
-            )
+        elif self.attn_type == "relative":
+            # max_relative_length = getattr(args, "max_encoder_relative_length", -1)
+            max_relative_length = max(getattr(args, "max_encoder_relative_length", -1), getattr(args, "max_relative_length", -1))
+            if max_relative_length != -1:
+                return RelativeMultiheadAttention(
+                    embed_dim,
+                    args.encoder_attention_heads,
+                    dropout=args.attention_dropout,
+                    self_attention=True,
+                    q_noise=self.quant_noise,
+                    qn_block_size=self.quant_noise_block_size,
+                    max_relative_length=max_relative_length,
+                )
+            else:
+                print("The maximum encoder relative length %d can not be -1!" % max_relative_length)
+                exit(1)
         else:
-            print("The attention type %s is not supported!" % self.attn_type)
+            print("The encoder attention type %s is not supported!" % self.attn_type)
             exit(1)
         return attn_func(

@@ -292,18 +298,23 @@ class TransformerDecoderLayer(nn.Module):
             attn_func = MultiheadAttention
         elif self.attn_type == "rel_selfattn":
             attn_func = RelPositionMultiheadAttention
-        elif self.attn_type == "relative" or getattr(args, "max_relative_length", -1) != -1:
-            return RelativeMultiheadAttention(
-                embed_dim,
-                args.encoder_attention_heads,
-                dropout=args.attention_dropout,
-                self_attention=True,
-                q_noise=self.quant_noise,
-                qn_block_size=self.quant_noise_block_size,
-                max_relative_length=args.max_relative_length,
-            )
+        elif self.attn_type == "relative":
+            max_relative_length = max(getattr(args, "max_decoder_relative_length", -1), getattr(args, "max_relative_length", -1))
+            if max_relative_length != -1:
+                return RelativeMultiheadAttention(
+                    embed_dim,
+                    args.decoder_attention_heads,
+                    dropout=args.attention_dropout,
+                    self_attention=True,
+                    q_noise=self.quant_noise,
+                    qn_block_size=self.quant_noise_block_size,
+                    max_relative_length=max_relative_length,
+                )
+            else:
+                print("The maximum decoder relative length %d can not be -1!" % max_relative_length)
+                exit(1)
         else:
-            print("The attention type %s is not supported!" % self.attn_type)
+            print("The decoder attention type %s is not supported!" % self.attn_type)
             exit(1)
         return attn_func(
......
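Side note (not part of the commit): the per-side options only change how far apart two positions may be before their relative distance is clipped inside RelativeMultiheadAttention. The sketch below is a simplified, single-head illustration of Shaw-style relative position attention in the "k only" setting; every name in it is made up for this note and it is not the repository's RelativeMultiheadAttention implementation.

    import torch
    import torch.nn as nn

    def clipped_relative_positions(len_q, len_k, max_relative_length):
        # rel[i][j] = j - i, clipped to [-max, +max], then shifted into [0, 2*max]
        rel = torch.arange(len_k).unsqueeze(0) - torch.arange(len_q).unsqueeze(1)
        rel = rel.clamp(-max_relative_length, max_relative_length)
        return rel + max_relative_length

    class TinyRelativeSelfAttention(nn.Module):
        def __init__(self, embed_dim, max_relative_length):
            super().__init__()
            self.max_relative_length = max_relative_length
            # one learned embedding per clipped relative distance, added on the key side ("k only")
            self.rel_keys = nn.Embedding(2 * max_relative_length + 1, embed_dim)

        def forward(self, q, k):
            # q, k: (batch, length, embed_dim)
            content = torch.matmul(q, k.transpose(1, 2))                   # (batch, len_q, len_k)
            idx = clipped_relative_positions(q.size(1), k.size(1), self.max_relative_length)
            rel_k = self.rel_keys(idx)                                     # (len_q, len_k, embed_dim)
            position = torch.einsum("bqd,qkd->bqk", q, rel_k)              # relative-position term
            return torch.softmax((content + position) / q.size(-1) ** 0.5, dim=-1)

A larger max_encoder_relative_length (e.g. 100 for the speech architectures above) keeps long acoustic distances distinguishable, while the decoder can use a smaller window such as 20.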