Commit 47e0f6e0 by xuchen

add the multibranch S2T architecture.

I also found and fixed some bugs in the dual architecture.
parent 793f553a
@@ -10,3 +10,4 @@ from .pdss2t_transformer import * # noqa
from .s2t_sate import * # noqa
from .s2t_dual import * # noqa
from .s2t_ctc import *
from .s2t_multibranch import *
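
The wildcard import added above is what lets fairseq pick up the new model: registration happens as a side effect of importing the module. Below is a minimal sketch, not the actual contents of s2t_multibranch.py, of how such a module typically hooks into the registry via register_model and register_model_architecture; the class name, base class, and registry name are illustrative assumptions.

# Hypothetical skeleton; the real s2t_multibranch.py is not shown in this diff.
# Importing it (via the wildcard import in __init__.py above) executes the
# decorators below, which add the model to fairseq's registry.
from fairseq.models import (
    FairseqEncoderDecoderModel,
    register_model,
    register_model_architecture,
)


@register_model("s2t_multibranch")  # registry name assumed for illustration
class S2TMultiBranchModel(FairseqEncoderDecoderModel):
    """Placeholder for the multibranch speech-to-text model."""


@register_model_architecture("s2t_multibranch", "s2t_multibranch")
def base_architecture(args):
    # Named architectures set default hyperparameters on args.
    args.encoder_embed_dim = getattr(args, "encoder_embed_dim", 256)
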
from fairseq.models import (
    FairseqEncoder,
    FairseqEncoderModel,
    register_model,
    register_model_architecture,
)
import logging
import math
from typing import Dict, List, Optional, Tuple
import torch
@@ -35,17 +28,6 @@ from fairseq.models.transformer_s2 import (
    TransformerS2Encoder,
    TransformerS2Decoder,
)
from fairseq.modules import (
    FairseqDropout,
    LayerNorm,
    PositionalEmbedding,
    LegacyRelPositionalEncoding,
    RelPositionalEncoding,
    S2TTransformerEncoderLayer,
    DynamicLinearCombination,
    TransformerS2DecoderLayer,
    TransformerS2EncoderLayer,
)
logger = logging.getLogger(__name__)
@@ -343,8 +343,8 @@ class TransformerS2Decoder(TransformerDecoder):
                    and len(encoder_out["encoder_padding_mask"]) > 0
                )
                else None,
                encoder_out_s2=encoder_out["encoder_out_s2"][0],
                encoder_padding_mask_s2=encoder_out["encoder_padding_mask_s2"][0],
                encoder_out_s2=encoder_out["s2_encoder_out"][0],
                encoder_padding_mask_s2=encoder_out["s2_encoder_padding_mask"][0],
                incremental_state=incremental_state,
                self_attn_mask=self_attn_mask,
                self_attn_padding_mask=self_attn_padding_mask,
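
The renamed dictionary keys appear to be the dual-architecture bug the commit message mentions: the decoder layer was looking up the second branch under "encoder_out_s2" / "encoder_padding_mask_s2", while the encoder packs it under "s2_encoder_out" / "s2_encoder_padding_mask". A self-contained toy illustration of that mismatch follows; pack_encoder_out is an invented stand-in for the dual-branch encoder's return value, not code from this repository.

# Toy illustration of the key mismatch fixed above.
from typing import Dict, List

import torch


def pack_encoder_out(
    x: torch.Tensor, x_s2: torch.Tensor,
    pad: torch.Tensor, pad_s2: torch.Tensor,
) -> Dict[str, List[torch.Tensor]]:
    # fairseq-style encoder output: every entry is a list of tensors.
    return {
        "encoder_out": [x],                    # T x B x C, first branch
        "encoder_padding_mask": [pad],         # B x T
        "s2_encoder_out": [x_s2],              # T x B x C, second branch
        "s2_encoder_padding_mask": [pad_s2],   # B x T
    }


enc = pack_encoder_out(
    torch.zeros(5, 2, 256), torch.zeros(7, 2, 256),
    torch.zeros(2, 5, dtype=torch.bool), torch.zeros(2, 7, dtype=torch.bool),
)

# New lookup, matching the keys the encoder actually writes:
x_s2 = enc["s2_encoder_out"][0]

# The old lookup would fail, which is the bug being fixed:
# enc["encoder_out_s2"]  -> KeyError
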
@@ -61,6 +61,7 @@ from .espnet_multihead_attention import (
)
from .convolution import ConvolutionModule
from .s2t_transformer_layer import S2TTransformerEncoderLayer
from .s2t_transformer_s2_layer import S2TTransformerS2EncoderLayer
from .pds_layer import PDSTransformerEncoderLayer
__all__ = [
@@ -70,6 +71,7 @@ __all__ = [
"BeamableMM",
"CharacterTokenEmbedder",
"S2TTransformerEncoderLayer",
"S2TTransformerS2EncoderLayer",
"ConvolutionModule",
"ConvTBC",
"cross_entropy",