Commit 47e0f6e0 by xuchen

add the multibranch S2T architecture.

I also found and fixed some bugs in the dual architecture.
parent 793f553a
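For orientation: a "multibranch" S2T encoder runs two or more parallel encoder branches over the same acoustic features and fuses their outputs before decoding. The snippet below is a minimal, hypothetical illustration of that idea in plain PyTorch; it is not the code added by this commit, and the class name, branch count, and averaging fusion are all assumptions.

    import torch
    import torch.nn as nn

    class ToyMultibranchEncoder(nn.Module):
        """Hypothetical two-branch encoder: each branch is its own small
        Transformer stack, and the branch outputs are fused by averaging."""

        def __init__(self, dim=256, layers=2, heads=4):
            super().__init__()
            # branch 1: its own Transformer stack
            self.branch_a = nn.TransformerEncoder(
                nn.TransformerEncoderLayer(d_model=dim, nhead=heads), num_layers=layers
            )
            # branch 2: separate parameters, same input
            self.branch_b = nn.TransformerEncoder(
                nn.TransformerEncoderLayer(d_model=dim, nhead=heads), num_layers=layers
            )

        def forward(self, x):             # x: (T, B, dim) speech features
            out_a = self.branch_a(x)      # run branch 1
            out_b = self.branch_b(x)      # run branch 2
            return 0.5 * (out_a + out_b)  # fuse: simple average (real code may gate or concatenate)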
@@ -10,3 +10,4 @@ from .pdss2t_transformer import *  # noqa
 from .s2t_sate import *  # noqa
 from .s2t_dual import *  # noqa
 from .s2t_ctc import *
+from .s2t_multibranch import *
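The new wildcard import matters because fairseq discovers architectures as a side effect of importing their modules: the @register_model and @register_model_architecture decorators run at import time and add the model to the registry consulted by --arch. Below is a minimal sketch of that pattern; the model name, class name, and bodies are placeholders (the new module presumably registers something like "s2t_multibranch", but that is an assumption, not this commit's actual code).

    from fairseq.models import (
        FairseqEncoderDecoderModel,
        register_model,
        register_model_architecture,
    )

    @register_model("s2t_multibranch_example")  # hypothetical name used for this sketch
    class S2TMultibranchExampleModel(FairseqEncoderDecoderModel):
        """Placeholder; the real implementation lives in s2t_multibranch.py."""

        @classmethod
        def build_model(cls, args, task):
            raise NotImplementedError("sketch only")

    @register_model_architecture("s2t_multibranch_example", "s2t_multibranch_example")
    def base_architecture(args):
        # default hyper-parameters would be set on the args namespace here
        pass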
-from fairseq.models import (
-    FairseqEncoder,
-    FairseqEncoderModel,
-    register_model,
-    register_model_architecture,
-)
 import logging
-import math
 from typing import Dict, List, Optional, Tuple
 import torch
@@ -35,17 +28,6 @@ from fairseq.models.transformer_s2 import (
     TransformerS2Encoder,
     TransformerS2Decoder,
 )
-from fairseq.modules import (
-    FairseqDropout,
-    LayerNorm,
-    PositionalEmbedding,
-    LegacyRelPositionalEncoding,
-    RelPositionalEncoding,
-    S2TTransformerEncoderLayer,
-    DynamicLinearCombination,
-    TransformerS2DecoderLayer,
-    TransformerS2EncoderLayer,
-)
 logger = logging.getLogger(__name__)
...
@@ -343,8 +343,8 @@ class TransformerS2Decoder(TransformerDecoder):
                     and len(encoder_out["encoder_padding_mask"]) > 0
                 )
                 else None,
-                encoder_out_s2=encoder_out["encoder_out_s2"][0],
-                encoder_padding_mask_s2=encoder_out["encoder_padding_mask_s2"][0],
+                encoder_out_s2=encoder_out["s2_encoder_out"][0],
+                encoder_padding_mask_s2=encoder_out["s2_encoder_padding_mask"][0],
                 incremental_state=incremental_state,
                 self_attn_mask=self_attn_mask,
                 self_attn_padding_mask=self_attn_padding_mask,
...
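The two changed lines in the hunk above appear to be (at least part of) the dual-architecture bug mentioned in the commit message: the decoder indexed the encoder_out dictionary with key names the dual encoder does not actually emit, so the lookup fails at run time. The snippet below is a hedged, standalone illustration of that contract; the tensor shapes, and the assumption that "s2_encoder_out" / "s2_encoder_padding_mask" are the keys the encoder really produces, are inferred from the corrected side of the diff.

    import torch

    # Mimic a fairseq-style encoder_out dict; each entry is wrapped in a list.
    T, B, C = 6, 2, 8
    encoder_out = {
        "encoder_out": [torch.zeros(T, B, C)],                # first (acoustic) branch
        "encoder_padding_mask": [torch.zeros(B, T).bool()],
        "s2_encoder_out": [torch.zeros(T, B, C)],             # second branch / encoder
        "s2_encoder_padding_mask": [torch.zeros(B, T).bool()],
    }

    # Before the fix the decoder used the old key names, which no longer exist:
    try:
        _ = encoder_out["encoder_out_s2"][0]
    except KeyError:
        print("KeyError: decoder and encoder disagree on the key name")

    # After the fix the decoder reads the keys the encoder actually emits:
    states_s2 = encoder_out["s2_encoder_out"][0]
    padding_s2 = encoder_out["s2_encoder_padding_mask"][0]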
@@ -61,6 +61,7 @@ from .espnet_multihead_attention import (
 )
 from .convolution import ConvolutionModule
 from .s2t_transformer_layer import S2TTransformerEncoderLayer
+from .s2t_transformer_s2_layer import S2TTransformerS2EncoderLayer
 from .pds_layer import PDSTransformerEncoderLayer
 __all__ = [
@@ -70,6 +71,7 @@ __all__ = [
     "BeamableMM",
     "CharacterTokenEmbedder",
     "S2TTransformerEncoderLayer",
+    "S2TTransformerS2EncoderLayer",
     "ConvolutionModule",
     "ConvTBC",
     "cross_entropy",
...
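Adding the new layer to both the import list and __all__ of the modules package __init__ is what lets the rest of the codebase import it from the package root, the same way S2TTransformerEncoderLayer is already consumed. A small usage sketch (the construction line is an assumption, not code from this commit):

    # Before this commit only the base layer was exported from the package root:
    from fairseq.modules import S2TTransformerEncoderLayer

    # After it, the S2-aware variant can be imported the same way:
    from fairseq.modules import S2TTransformerS2EncoderLayer

    # (Hypothetical) building an encoder stack from it; the constructor is
    # assumed to take the usual fairseq `args` namespace, like the base layer:
    # layers = [S2TTransformerS2EncoderLayer(args) for _ in range(args.encoder_layers)]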