Commit b23817e0 by xuchen

implement the pyramid transformer

parent 7802e6f7
@@ -7,4 +7,5 @@ from .berard import * # noqa
from .convtransformer import * # noqa
from .s2t_transformer import * # noqa
from .s2t_conformer import * # noqa
from .pys2t_transformer import * # noqa
from .s2t_sate import * # noqa
@@ -17,7 +17,9 @@ from fairseq.models.speech_to_text import (
S2TTransformerModel,
S2TTransformerEncoder,
S2TConformerEncoder,
S2TConformerModel
S2TConformerModel,
PYS2TTransformerModel,
PyS2TTransformerEncoder,
)
from fairseq.models.speech_to_text.s2t_transformer import Conv1dSubsampler
from fairseq.modules import (
@@ -46,6 +48,7 @@ class S2TSATEModel(S2TTransformerModel):
def add_args(parser):
"""Add model-specific arguments to the parser."""
S2TConformerModel.add_args(parser)
PYS2TTransformerModel.add_args(parser)
parser.add_argument(
"--text-encoder-layers",
@@ -195,13 +198,16 @@ class Adapter(nn.Module):
linear_out = self.linear_adapter(representation)
soft_out = torch.mm(distribution.view(-1, embed_dim), self.embed_adapter.weight).view(batch, seq_len, -1)
out = linear_out + soft_out
elif self.adapter_type == "gated_league":
linear_out = self.linear_adapter(representation)
soft_out = torch.mm(distribution.view(-1, embed_dim), self.embed_adapter.weight).view(batch, seq_len, -1)
coef = (self.gate_linear(torch.cat([linear_out, soft_out], dim=-1))).sigmoid()
out = coef * linear_out + (1 - coef) * soft_out
elif self.adapter_type == "none":
out = representation
else:
out = None
logging.error("Unsupported adapter type: {}.".format(self.adapter_type))
@@ -262,6 +268,8 @@ class S2TSATEEncoder(FairseqEncoder):
self.acoustic_encoder = S2TTransformerEncoder(args, task, embed_tokens)
elif acoustic_encoder_type == "conformer":
self.acoustic_encoder = S2TConformerEncoder(args, task, embed_tokens)
elif acoustic_encoder_type == "pyramid":
self.acoustic_encoder = PyS2TTransformerEncoder(args, task, embed_tokens)
else:
logging.error("Unsupported model arch {}!".format(acoustic_encoder_type))
@@ -277,9 +285,9 @@
# )
acoustic_encoder_attention_type = args.encoder_attention_type
if acoustic_encoder_attention_type != "selfattn":
args.encoder_attention_type = "selfattn"
logger.info("Force self attention for text encoder.")
# if acoustic_encoder_attention_type != "selfattn":
# args.encoder_attention_type = "selfattn"
# logger.info("Force self attention for text encoder.")
# text encoder
self.text_encoder = TextEncoder(args, embed_tokens)
@@ -378,6 +386,9 @@ def base_architecture(args):
args.use_cnn_module = getattr(args, "use_cnn_module", False)
args.cnn_module_kernel = getattr(args, "cnn_module_kernel", 31)
# Pyramid
args.pyramid_layers = getattr(args, "pyramid_layers", None)
args.encoder_embed_dim = getattr(args, "encoder_embed_dim", 512)
args.encoder_ffn_embed_dim = getattr(args, "encoder_ffn_embed_dim", 2048)
args.encoder_layers = getattr(args, "encoder_layers", 12)
......
@@ -147,10 +147,11 @@ class S2TTransformerModel(FairseqEncoderDecoderModel):
type=str,
default="selfattn",
choices=[
"local",
"selfattn",
"reduced",
"rel_selfattn",
"relative",
"local",
],
help="transformer encoder self-attention layer type"
)
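For reference, a small standalone argparse sketch of the flag above with the updated choice list (including the new "reduced" option). The flag name and choices mirror the diff; the parser setup itself is only illustrative.

import argparse

parser = argparse.ArgumentParser()
parser.add_argument(
    "--encoder-attention-type",
    type=str,
    default="selfattn",
    choices=["selfattn", "reduced", "rel_selfattn", "relative", "local"],
    help="transformer encoder self-attention layer type",
)
args = parser.parse_args(["--encoder-attention-type", "reduced"])
print(args.encoder_attention_type)  # -> reduced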
......
@@ -29,6 +29,7 @@ from .linearized_convolution import LinearizedConvolution
from .local_multihead_attention import LocalMultiheadAttention
from .multihead_attention import MultiheadAttention
from .positional_embedding import PositionalEmbedding
from .reduced_multihead_attention import ReducedMultiheadAttention
from .rel_position_multihead_attention import RelPositionMultiheadAttention
from .relative_multihead_attention import RelativeMultiheadAttention
from .same_pad import SamePad
@@ -41,6 +42,7 @@ from .unfold import unfold1d
from .transformer_layer import TransformerDecoderLayer, TransformerEncoderLayer
from .vggblock import VGGBlock
from .conformer_layer import ConformerEncoderLayer
from .pyramid_layer import PyramidTransformerEncoderLayer
__all__ = [
"AdaptiveInput",
@@ -74,6 +76,8 @@ __all__ = [
"LocalMultiheadAttention",
"MultiheadAttention",
"PositionalEmbedding",
"PyramidTransformerEncoderLayer",
"ReducedMultiheadAttention",
"RelPositionMultiheadAttention",
"RelativeMultiheadAttention",
"SamePad",
......
@@ -325,7 +325,6 @@ class LocalMultiheadAttention(nn.Module):
multihead_mask_weight = None
gauss_bias = None
if self.multihead_gauss_mask_sigma is not None:
data_type = attn_weights.dtype
x1 = torch.arange(-1, src_len - 1, 1).view(-1, 1).to(attn_weights.device)
x2 = torch.arange(-1, src_len - 1, 1).view(1, -1).to(attn_weights.device)
dis_square = -(x1 - x2) ** 2 / 2.0
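The hunk above builds the pairwise squared-distance term for the Gaussian locality mask. Below is a small sketch of that computation in isolation; the division by sigma**2 to turn it into the final bias is an assumption based on the surrounding variable names, since only the squared-distance term appears in the diff.

import torch

def gaussian_locality_bias(src_len, sigma=1.0):
    # positions -1 .. src_len-2, matching the arange(-1, src_len - 1) in the diff
    x1 = torch.arange(-1, src_len - 1, 1, dtype=torch.float32).view(-1, 1)
    x2 = torch.arange(-1, src_len - 1, 1, dtype=torch.float32).view(1, -1)
    dis_square = -(x1 - x2) ** 2 / 2.0   # (src_len, src_len), 0 on the diagonal
    return dis_square / (sigma ** 2)     # assumption: scaled by a per-head sigma

bias = gaussian_locality_bias(src_len=4, sigma=2.0)
# attn_weights = attn_weights + bias    # added to the attention logits before softmax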
......