Commit b23817e0 by xuchen

implement the pyramid transformer

parent 7802e6f7
@@ -7,4 +7,5 @@ from .berard import * # noqa
 from .convtransformer import * # noqa
 from .s2t_transformer import * # noqa
 from .s2t_conformer import * # noqa
+from .pys2t_transformer import * # noqa
 from .s2t_sate import * # noqa
@@ -17,7 +17,9 @@ from fairseq.models.speech_to_text import (
     S2TTransformerModel,
     S2TTransformerEncoder,
     S2TConformerEncoder,
-    S2TConformerModel
+    S2TConformerModel,
+    PYS2TTransformerModel,
+    PyS2TTransformerEncoder,
 )
 from fairseq.models.speech_to_text.s2t_transformer import Conv1dSubsampler
 from fairseq.modules import (
@@ -46,6 +48,7 @@ class S2TSATEModel(S2TTransformerModel):
     def add_args(parser):
         """Add model-specific arguments to the parser."""
         S2TConformerModel.add_args(parser)
+        PYS2TTransformerModel.add_args(parser)
         parser.add_argument(
             "--text-encoder-layers",
@@ -195,13 +198,16 @@ class Adapter(nn.Module):
             linear_out = self.linear_adapter(representation)
             soft_out = torch.mm(distribution.view(-1, embed_dim), self.embed_adapter.weight).view(batch, seq_len, -1)
             out = linear_out + soft_out
         elif self.adapter_type == "gated_league":
             linear_out = self.linear_adapter(representation)
             soft_out = torch.mm(distribution.view(-1, embed_dim), self.embed_adapter.weight).view(batch, seq_len, -1)
             coef = (self.gate_linear(torch.cat([linear_out, soft_out], dim=-1))).sigmoid()
             out = coef * linear_out + (1 - coef) * soft_out
         elif self.adapter_type == "none":
             out = representation
         else:
             out = None
             logging.error("Unsupported adapter type: {}.".format(self.adapter_type))
@@ -262,6 +268,8 @@ class S2TSATEEncoder(FairseqEncoder):
             self.acoustic_encoder = S2TTransformerEncoder(args, task, embed_tokens)
         elif acoustic_encoder_type == "conformer":
             self.acoustic_encoder = S2TConformerEncoder(args, task, embed_tokens)
+        elif acoustic_encoder_type == "pyramid":
+            self.acoustic_encoder = PyS2TTransformerEncoder(args, task, embed_tokens)
         else:
             logging.error("Unsupported model arch {}!".format(acoustic_encoder_type))
@@ -277,9 +285,9 @@ class S2TSATEEncoder(FairseqEncoder):
         # )
         acoustic_encoder_attention_type = args.encoder_attention_type
-        if acoustic_encoder_attention_type != "selfattn":
-            args.encoder_attention_type = "selfattn"
-            logger.info("Force self attention for text encoder.")
+        # if acoustic_encoder_attention_type != "selfattn":
+        #     args.encoder_attention_type = "selfattn"
+        #     logger.info("Force self attention for text encoder.")
         # text encoder
         self.text_encoder = TextEncoder(args, embed_tokens)
@@ -378,6 +386,9 @@ def base_architecture(args):
     args.use_cnn_module = getattr(args, "use_cnn_module", False)
     args.cnn_module_kernel = getattr(args, "cnn_module_kernel", 31)
+    # Pyramid
+    args.pyramid_layers = getattr(args, "pyramid_layers", None)
     args.encoder_embed_dim = getattr(args, "encoder_embed_dim", 512)
     args.encoder_ffn_embed_dim = getattr(args, "encoder_ffn_embed_dim", 2048)
     args.encoder_layers = getattr(args, "encoder_layers", 12)
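base_architecture only fills in options the user left unset, so the new pyramid_layers default of None leaves non-pyramid configurations untouched. A tiny sketch of that getattr-defaulting pattern on a plain Namespace, reduced to a few of the options visible in this hunk:

from argparse import Namespace

def base_architecture(args):
    # existing defaults from the diff context
    args.use_cnn_module = getattr(args, "use_cnn_module", False)
    args.cnn_module_kernel = getattr(args, "cnn_module_kernel", 31)
    # Pyramid: only consulted when the pyramid acoustic encoder is selected
    args.pyramid_layers = getattr(args, "pyramid_layers", None)
    args.encoder_embed_dim = getattr(args, "encoder_embed_dim", 512)

args = Namespace(encoder_embed_dim=256)  # user set only one option
base_architecture(args)
print(args.pyramid_layers, args.encoder_embed_dim)  # None 256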
......
@@ -147,10 +147,11 @@ class S2TTransformerModel(FairseqEncoderDecoderModel):
             type=str,
             default="selfattn",
             choices=[
+                "local",
                 "selfattn",
+                "reduced",
                 "rel_selfattn",
                 "relative",
-                "local",
             ],
             help="transformer encoder self-attention layer type"
         )
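The reordered choices list exposes five self-attention variants through --encoder-attention-type, including the new "reduced" option. How each string is wired to an attention class is not shown in this commit; the sketch below is one assumed way to keep such a lookup in one place, reusing the class names from the fairseq.modules hunks further down but with stand-in definitions:

# stand-ins; the real classes live in fairseq.modules (see the __init__.py hunks below)
class MultiheadAttention: ...
class RelPositionMultiheadAttention: ...
class RelativeMultiheadAttention: ...
class LocalMultiheadAttention: ...
class ReducedMultiheadAttention: ...

ATTENTION_TYPES = {
    "selfattn": MultiheadAttention,
    "rel_selfattn": RelPositionMultiheadAttention,
    "relative": RelativeMultiheadAttention,
    "local": LocalMultiheadAttention,
    "reduced": ReducedMultiheadAttention,
}

def build_self_attention(encoder_attention_type):
    # this mapping is an assumption about how the choices are wired up,
    # not code taken from the repository
    return ATTENTION_TYPES[encoder_attention_type]()

print(type(build_self_attention("reduced")).__name__)  # ReducedMultiheadAttention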
......
@@ -29,6 +29,7 @@ from .linearized_convolution import LinearizedConvolution
 from .local_multihead_attention import LocalMultiheadAttention
 from .multihead_attention import MultiheadAttention
 from .positional_embedding import PositionalEmbedding
+from .reduced_multihead_attention import ReducedMultiheadAttention
 from .rel_position_multihead_attention import RelPositionMultiheadAttention
 from .relative_multihead_attention import RelativeMultiheadAttention
 from .same_pad import SamePad
@@ -41,6 +42,7 @@ from .unfold import unfold1d
 from .transformer_layer import TransformerDecoderLayer, TransformerEncoderLayer
 from .vggblock import VGGBlock
 from .conformer_layer import ConformerEncoderLayer
+from .pyramid_layer import PyramidTransformerEncoderLayer
 __all__ = [
     "AdaptiveInput",
@@ -74,6 +76,8 @@ __all__ = [
     "LocalMultiheadAttention",
     "MultiheadAttention",
     "PositionalEmbedding",
+    "PyramidTransformerEncoderLayer",
+    "ReducedMultiheadAttention",
     "RelPositionMultiheadAttention",
     "RelativeMultiheadAttention",
     "SamePad",
......
@@ -325,7 +325,6 @@ class LocalMultiheadAttention(nn.Module):
         multihead_mask_weight = None
         gauss_bias = None
         if self.multihead_gauss_mask_sigma is not None:
-            data_type = attn_weights.dtype
             x1 = torch.arange(-1, src_len - 1, 1).view(-1, 1).to(attn_weights.device)
             x2 = torch.arange(-1, src_len - 1, 1).view(1, -1).to(attn_weights.device)
             dis_square = -(x1 - x2) ** 2 / 2.0
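The removed line was only an unused dtype lookup; the surrounding code builds a Gaussian locality bias from pairwise position distances. A standalone sketch of that distance term, plus an assumed continuation (scaling by a per-head sigma) for the lines this hunk does not show:

import torch

src_len, num_heads = 6, 4
# pairwise position grid, matching the arange(-1, src_len - 1) indexing in the diff
x1 = torch.arange(-1, src_len - 1, 1).view(-1, 1)
x2 = torch.arange(-1, src_len - 1, 1).view(1, -1)
dis_square = -(x1 - x2) ** 2 / 2.0          # (src_len, src_len), 0 on the diagonal

# assumed continuation: divide by sigma^2 per head to get an additive log-space bias
sigma = torch.full((num_heads, 1, 1), 3.0)  # hypothetical per-head sigma
gauss_bias = dis_square.unsqueeze(0) / (sigma ** 2)

print(gauss_bias.shape)  # torch.Size([4, 6, 6])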
......