Commit 55702466 by xuchen

add the pure ctc arch for pds method

parent bab6c520
-arch: s2t_transformer_s
+arch: s2t_ctc
encoder-type: pds
#arch: pdss2t_transformer_s_8
encoder-embed-dim: 256
pds-stages: 4
ctc-layer: 12
pds-layers: 3_3_3_3
pds-ratios: 2_2_1_2
pds-fusion: True
pds-fusion-method: all_conv
pds-embed-dims: 256_256_256_256
pds-ds-method: conv
pds-embed-norm: True
pds-position-embed: 1_1_1_1
pds-kernel-sizes: 5_5_5_5
pds-ffn-ratios: 8_8_8_8
pds-attn-heads: 4_4_4_4
share-decoder-input-output-embed: True
optimizer: adam
clip-norm: 10.0
lr-scheduler: inverse_sqrt
warmup-init-lr: 1e-7
warmup-updates: 10000
weight-decay: 1e-6
lr: 2e-3
adam_betas: (0.9,0.98)
-criterion: label_smoothed_cross_entropy_with_ctc
+criterion: ctc
label_smoothing: 0.1
subsampling-type: conv1d
subsampling-layers: 2
subsampling-filter: 1024
subsampling-kernel: 5
subsampling-stride: 2
subsampling-norm: none
subsampling-activation: glu
ctc-weight: 0.2
intermedia-ctc-layers: 6,9
intermedia-adapter: league
intermedia-ctc-weight: 0.1
intermedia-drop-prob: 0.5
ctc-self-distill-weight: 0
post-process: sentencepiece
dropout: 0.1
activation-fn: relu
encoder-embed-dim: 256
encoder-ffn-embed-dim: 2048
encoder-layers: 12
decoder-layers: 6
@@ -40,8 +38,3 @@ encoder-attention-heads: 4
decoder-embed-dim: 256
decoder-ffn-embed-dim: 2048
decoder-attention-heads: 4
attention-dropout: 0.1
activation-dropout: 0.1
#load-pretrained-encoder-from:
#load-pretrained-decoder-from:
\ No newline at end of file
arch: s2t_ctc
encoder-type: pds
encoder-embed-dim: 256
pds-stages: 4
ctc-layer: 12
pds-layers: 3_3_3_3
pds-ratios: 2_2_1_2
pds-fusion: True
pds-fusion-method: all_conv
pds-embed-dims: 256_256_256_256
pds-ds-method: conv
pds-embed-norm: True
pds-position-embed: 1_1_1_1
pds-kernel-sizes: 5_5_5_5
pds-ffn-ratios: 8_8_8_8
pds-attn-heads: 4_4_4_4
share-decoder-input-output-embed: True
optimizer: adam
clip-norm: 10.0
lr-scheduler: inverse_sqrt
warmup-init-lr: 1e-7
warmup-updates: 10000
lr: 2e-3
adam_betas: (0.9,0.98)
criterion: ctc
dropout: 0.1
activation-fn: relu
encoder-ffn-embed-dim: 2048
encoder-layers: 12
decoder-layers: 6
encoder-attention-heads: 4
decoder-embed-dim: 256
decoder-ffn-embed-dim: 2048
decoder-attention-heads: 4
#load-pretrained-encoder-from:
#load-pretrained-decoder-from:
\ No newline at end of file
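The `_`-separated fields in this config hold one value per PDS stage. A minimal sketch (not code from this commit) of how they are consumed, mirroring the `split("_")` parsing that `PDSS2TTransformerEncoder` performs further down in this diff:

```python
pds_stages = 4
pds_layers = [int(n) for n in "3_3_3_3".split("_")]   # encoder layers per stage
pds_ratios = [int(n) for n in "2_2_1_2".split("_")]   # down-sampling ratio per stage

assert len(pds_layers) == len(pds_ratios) == pds_stages
assert sum(pds_layers) == 12                           # matches encoder-layers: 12

total_downsampling = 1
for ratio in pds_ratios:
    total_downsampling *= ratio
print(total_downsampling)  # 8: the frame rate is reduced 8x before the final CTC layer
```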
@@ -137,7 +137,8 @@ class AudioDataset(Dataset):
for i, segment in enumerate(seg_group):
    offset = int(float(segment["offset"]) * sample_rate)
    n_frames = int(float(segment["duration"]) * sample_rate)
-    _id = f"{split}_{wav_path.stem}_{i}"
+    # _id = f"{split}_{wav_path.stem}_{i}"
+    _id = f"{wav_path.stem}_{i}"
    item = dict()
    item["audio"] = wav_path.as_posix()
@@ -263,8 +264,12 @@ def process(args):
utt_id = item['id']
features_path = (feature_root / f"{utt_id}.npy").as_posix()
+tag_features_path = (feature_root / f"{split}_{utt_id}.npy").as_posix()
-if os.path.exists(features_path):
+if os.path.exists(tag_features_path):
+    continue
+if os.path.exists(features_path) and not os.path.exists(tag_features_path):
+    shutil.move(features_path, tag_features_path)
    continue
waveform, sample_rate, _ = dataset.get(idx, need_waveform=True)
...
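The change above stops prefixing segment ids with the split name and instead namespaces the cached feature files themselves, migrating features extracted under the old naming. A standalone sketch of that migration step (function name and layout are illustrative, not from the repository):

```python
import os
import shutil
from pathlib import Path


def ensure_tagged_features(feature_root: Path, split: str, utt_id: str) -> bool:
    """Return True if the split-tagged feature file exists (migrating old files if needed)."""
    features_path = (feature_root / f"{utt_id}.npy").as_posix()              # old naming
    tag_features_path = (feature_root / f"{split}_{utt_id}.npy").as_posix()  # new naming
    if os.path.exists(tag_features_path):
        return True
    if os.path.exists(features_path):
        # reuse previously extracted features instead of recomputing them
        shutil.move(features_path, tag_features_path)
        return True
    return False
```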
@@ -20,8 +20,6 @@ from fairseq.tasks import FairseqTask
from fairseq.logging.meters import safe_round
@dataclass
class CtcCriterionConfig(FairseqDataclass):
zero_infinity: bool = field(
@@ -30,7 +28,7 @@ class CtcCriterionConfig(FairseqDataclass):
)
sentence_avg: bool = II("optimization.sentence_avg")
post_process: str = field(
-default="letter",
+default="sentencepiece",
metadata={
"help": "how to post process predictions into words. can be letter, "
"wordpiece, BPE symbols, etc. "
...
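Switching the default `post_process` from `letter` to `sentencepiece` matches the `post-process: sentencepiece` setting in the configs above. fairseq exposes this as `fairseq.data.data_utils.post_process`; roughly, the sentencepiece branch glues the predicted pieces together and turns the `▁` (U+2581) boundary marker back into spaces, as in this small sketch:

```python
def sentencepiece_post_process(sentence: str) -> str:
    # Remove the spaces between predicted pieces, then restore word boundaries
    # that sentencepiece marks with U+2581.
    return sentence.replace(" ", "").replace("\u2581", " ").strip()


print(sentencepiece_post_process("\u2581he llo \u2581wor ld"))  # -> "hello world"
```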
@@ -47,6 +47,8 @@ __all__ = [
"FairseqLanguageModel",
"FairseqModel",
"FairseqMultiModel",
+"register_model",
+"register_model_architecture"
]
...
@@ -19,6 +19,8 @@ from fairseq.modules import (
FairseqDropout,
LayerNorm,
PositionalEmbedding,
+RelPositionalEncoding,
+LegacyRelPositionalEncoding,
PDSTransformerEncoderLayer,
DownSampleConvolutionModule
)
@@ -137,19 +139,6 @@ class PDSS2TTransformerModel(S2TTransformerModel):
@staticmethod
def add_args(parser):
"""Add model-specific arguments to the parser."""
# input
parser.add_argument(
"--conv-kernel-sizes",
type=str,
metavar="N",
help="kernel sizes of Conv1d subsampling layers",
)
parser.add_argument(
"--conv-channels",
type=int,
metavar="N",
help="# of channels in Conv1d subsampling layers",
)
# Transformer
parser.add_argument(
"--activation-fn",
@@ -199,6 +188,10 @@ class PDSS2TTransformerModel(S2TTransformerModel):
"reduced",
"rel_selfattn",
"relative",
+"rel_pos_legacy",
+"rel_pos",
+"rope",
+"abs",
],
help="transformer encoder self-attention layer type"
)
@@ -333,6 +326,14 @@ class PDSS2TTransformerModel(S2TTransformerModel):
help='dropout for history output')
parser.add_argument('--history-window-size', type=int, default='-1',
help='how many past layers are considered. -1 means all')
# CTC
parser.add_argument(
"--ctc-layer",
default=0,
type=int,
help="the position of the ctc loss",
)
# local modeling
parser.add_argument(
'--hard-mask-window',
@@ -358,6 +359,13 @@ class PDSS2TTransformerModel(S2TTransformerModel):
# Conformer setting
parser.add_argument(
"--encoder-activation-fn",
type=str,
default="relu",
choices=utils.get_available_activation_fns(),
help="activation function to use",
)
parser.add_argument(
"--macaron-style", "--macaron-style",
default=False, default=False,
type=bool, type=bool,
...@@ -380,14 +388,7 @@ class PDSS2TTransformerModel(S2TTransformerModel): ...@@ -380,14 +388,7 @@ class PDSS2TTransformerModel(S2TTransformerModel):
"The legacy relative positional encoding will be deprecated in the future." "The legacy relative positional encoding will be deprecated in the future."
"More Details can be found in https://github.com/espnet/espnet/pull/2816.", "More Details can be found in https://github.com/espnet/espnet/pull/2816.",
) )
-# CTC
-parser.add_argument(
-"--ctc-layer",
-default=0,
-type=int,
-help="the position of the ctc loss",
-)
-# Conformer module
+# CNN module
parser.add_argument(
"--use-cnn-module",
default=False,
@@ -443,7 +444,6 @@ class PDSS2TTransformerModel(S2TTransformerModel):
type=str,
help="use the position embedding or not before each encoding",
)
parser.add_argument(
"--pds-attn-heads",
type=str,
@@ -479,6 +479,8 @@ class PDSS2TTransformerModel(S2TTransformerModel):
type=str,
help="use the ctc after each stage",
)
+# intermedia ctc
parser.add_argument(
"--intermedia-ctc-layers",
default=None,
@@ -491,6 +493,18 @@ class PDSS2TTransformerModel(S2TTransformerModel):
type=str,
help="type of intermedia adapter",
)
parser.add_argument(
"--intermedia-distribution-cutoff",
default=-1,
type=int,
help="cutoff of the distribution",
)
parser.add_argument(
"--intermedia-drop-prob",
default=0,
type=float,
help="probability of dropping the followed layers",
)
pass
@classmethod
@@ -504,6 +518,7 @@ class PDSS2TTransformerModel(S2TTransformerModel):
f"loaded pretrained encoder from: "
f"{args.load_pretrained_encoder_from}"
)
return encoder
@@ -535,7 +550,6 @@ class PDSS2TTransformerEncoder(FairseqEncoder):
self.pds_kernel_sizes = [int(n) for n in args.pds_kernel_sizes.split("_")]
self.pds_embed_norm = args.pds_embed_norm
self.pds_position_embed = [int(n) for n in args.pds_position_embed.split("_")]
self.pds_attn_heads = [int(n) for n in args.pds_attn_heads.split("_")]
self.pds_ffn_ratios = [int(n) for n in args.pds_ffn_ratios.split("_")]
if self.attn_type == "reduced":
@@ -596,6 +610,7 @@ class PDSS2TTransformerEncoder(FairseqEncoder):
if args.no_scale_embedding:
self.embed_scale = 1.0
+# down-sampling
downsampling = Downsampling(
self.pds_ds_method,
self.pds_embed_norm,
@@ -605,8 +620,23 @@ class PDSS2TTransformerEncoder(FairseqEncoder):
stride=ds_ratio,
padding=(kernel_size - 1) // 2,
)
+# position encoding
if use_pos_embed:
-    pos_embed = PositionalEmbedding(args.max_source_positions, embed_dim, self.padding_idx)
+    if self.attn_type == "rel_pos":
+        pos_embed = RelPositionalEncoding(
+            args.max_source_positions, args.encoder_embed_dim
+        )
+    elif self.attn_type in ["rel_selfattn", "rel_pos_legacy"]:
+        pos_embed = LegacyRelPositionalEncoding(
+            args.encoder_embed_dim, args.dropout, args.max_source_positions
+        )
+    elif self.attn_type == "rope":
+        self.embed_positions = None
+    else: # Use absolute positional embedding
+        pos_embed = PositionalEmbedding(
+            args.max_source_positions, args.encoder_embed_dim, self.padding_idx
+        )
else:
    pos_embed = None
@@ -614,6 +644,7 @@ class PDSS2TTransformerEncoder(FairseqEncoder):
PDSTransformerEncoderLayer(args, embed_dim, embed_dim * ffn_ratio, num_head, attn_ds_ratio)
for _ in range(num_layers)])
+# representation fusion
fusion_pre_layer_norm = None
fusion_post_layer_norm = None
fusion_downsampling = None
@@ -700,7 +731,8 @@ class PDSS2TTransformerEncoder(FairseqEncoder):
self.fusion_weight.data = self.fusion_weight.data / self.fusion_weight.data.sum(0, keepdim=True)
self.use_ctc = "sate" in args.arch or \
-    (("ctc" in getattr(args, "criterion", False)) and
+    (getattr(args, "criterion", "") == "ctc") or \
+    (("ctc" in getattr(args, "criterion", "")) and
     (getattr(args, "ctc_weight", False) > 0))
if self.use_ctc:
self.ctc_layer = (args.ctc_layer + args.encoder_layers) % args.encoder_layers
@@ -799,9 +831,16 @@ class PDSS2TTransformerEncoder(FairseqEncoder):
# add the position encoding and dropout
if pos_embed:
-    positions = pos_embed(encoder_padding_mask).transpose(0, 1)
-    x += positions
+    if self.attn_type in ["rel_pos", "rel_pos_legacy", "rel_selfattn"]:
+        positions = pos_embed(x)
+        positions = self.dropout(positions)
+    elif self.attn_type == "rope":
+        positions = None
+    else:
+        positions = pos_embed(encoder_padding_mask).transpose(0, 1)
+        x += positions
+        positions = None
else:
    positions = None
...
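The new `--intermedia-distribution-cutoff` and `--intermedia-drop-prob` options pair with the `intermedia-ctc-layers: 6,9` and `intermedia-drop-prob: 0.5` settings in the config at the top of the commit. A hedged sketch of how the layer list is interpreted, distilled from the (now commented-out) encoder code later in this diff:

```python
import torch


def parse_intermedia_ctc_layers(spec: str, encoder_layers: int):
    # "6,9" -> [6, 9]; non-positive indices count back from the last encoder layer.
    layers = []
    for idx in spec.split(","):
        idx = int(idx)
        if idx <= 0:
            idx += encoder_layers
        layers.append(idx)
    return layers


print(parse_intermedia_ctc_layers("6,9", 12))   # [6, 9]
print(parse_intermedia_ctc_layers("0,-1", 12))  # [12, 11]

# intermedia-drop-prob: with this probability the layers after an intermedia CTC
# position are skipped for the current batch, acting as a regularizer.
drop_rest = torch.rand(1).item() < 0.5
```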
@@ -363,6 +363,85 @@ class S2TCTCModel(FairseqEncoderModel):
parser.add_argument('--cl-dropout-strategy',
type=str,
help='interleaved dropout probability')
# pds setting
parser.add_argument(
"--pds-stages",
type=int,
help="the number of the stage",
)
parser.add_argument(
"--pds-layers",
type=str,
help="the number of the encoder layers in each stage",
)
parser.add_argument(
"--pds-ratios",
type=str,
help="the ratio of the down-sampling in each stage",
)
parser.add_argument(
"--pds-ds-method",
type=str,
choices=["glu", "conv", "proj", "fusion"],
help="the down-sampling method",
)
parser.add_argument(
"--pds-embed-dims",
type=str,
help="the embedding dimension in each stage",
)
parser.add_argument(
"--pds-kernel-sizes",
type=str,
help="the kernel size of the down-sampling module in each stage",
)
parser.add_argument(
"--pds-embed-norm",
action="store_true",
help="use layer norm in the down-sampling module",
)
parser.add_argument(
"--pds-position-embed",
type=str,
help="use the position embedding or not before each encoding",
)
parser.add_argument(
"--pds-attn-heads",
type=str,
help="the number of the attention heads in each stage",
)
parser.add_argument(
"--pds-attn-ds-ratio",
type=str,
help="the ratio of the down-sampling in the self attention module",
)
parser.add_argument(
"--pds-ffn-ratios",
type=str,
help="the ratio of the ffn in each stage",
)
parser.add_argument(
"--pds-fusion",
action="store_true",
help="use the representation fusion method",
)
parser.add_argument(
"--pds-fusion-method",
type=str,
help="the fusion method",
)
parser.add_argument(
"--pds-dropout",
type=float,
help="dropout in each stage",
)
parser.add_argument(
"--pds-ctc",
type=str,
help="use the ctc after each stage",
)
# intermedia CTC loss
parser.add_argument(
"--intermedia-ctc-layers",
@@ -388,6 +467,14 @@ class S2TCTCModel(FairseqEncoderModel):
type=float,
help="probability of dropping the followed layers",
)
# encoder
parser.add_argument(
"--encoder-type",
default="transformer",
type=str,
help="encoder type",
)
pass
@classmethod
@@ -452,79 +539,90 @@ class S2TCTCEncoder(FairseqEncoder):
def __init__(self, args, task=None):
super().__init__(None)
encoder_type = getattr(args, "encoder_type", "transformer")
if encoder_type == "transformer":
from .s2t_transformer import S2TTransformerEncoder
self.encoder = S2TTransformerEncoder(args, task)
elif encoder_type == "pds":
from .pdss2t_transformer import PDSS2TTransformerEncoder
self.encoder = PDSS2TTransformerEncoder(args, task)
else:
logger.error("Unsupported architecture: %s." % encoder_type)
return
# dim = args.encoder_embed_dim
# self.dropout_module = FairseqDropout(
# p=args.dropout, module_name=self.__class__.__name__
# )
# self.embed_scale = math.sqrt(dim)
# if args.no_scale_embedding:
# self.embed_scale = 1.0
# self.padding_idx = 1
#
# self.subsample = subsampling(args)
#
# self.attn_type = getattr(args, "encoder_attention_type", "selfattn")
#
# if self.attn_type == "rel_pos":
# self.embed_positions = RelPositionalEncoding(
# args.max_source_positions, args.encoder_embed_dim
# )
# elif self.attn_type in ["rel_selfattn", "rel_pos_legacy"]:
# self.embed_positions = LegacyRelPositionalEncoding(
# args.encoder_embed_dim, args.dropout, args.max_source_positions
# )
# elif self.attn_type == "rope":
# self.embed_positions = None
# else: # Use absolute positional embedding
# self.embed_positions = PositionalEmbedding(
# args.max_source_positions, args.encoder_embed_dim, self.padding_idx
# )
#
# self.layers = nn.ModuleList(
# [S2TTransformerEncoderLayer(args) for _ in range(args.encoder_layers)]
# )
#
# if args.encoder_normalize_before:
# self.layer_norm = LayerNorm(dim)
# else:
# self.layer_norm = None
#
# if args.use_enc_dlcl:
# self.history = DynamicLinearCombination(args, is_encoder=True)
# else:
# self.history = None
#
# self.ctc = CTC(dim,
# dictionary_size=len(task.source_dictionary),
# dropout=args.dropout,
# )
#
# # gather cosine similarity of the representation
# self.gather_cos_sim = getattr(args, "gather_cos_sim", False)
# # self.gather_cos_sim = True
# self.dis = 2
# self.cos_sim = dict()
#
# self.intermedia_ctc_layers = []
#
# if args.intermedia_ctc_layers is not None:
# intermedia_ctc_layers = args.intermedia_ctc_layers.split(",")
# for layer_idx in intermedia_ctc_layers:
# layer_idx = int(layer_idx)
# if layer_idx <= 0:
# layer_idx += args.encoder_layers
# self.intermedia_ctc_layers.append(layer_idx)
#
# logger.info("Intermedia CTC loss in layer %d" % layer_idx)
#
# strategy = None
# if args.intermedia_adapter == "shrink":
# strategy = getattr(args, "ctc_compress_strategy", "avg")
# elif args.intermedia_adapter == "league":
# strategy = getattr(args, "intermedia_distribution_cutoff", -1)
# self.adapter = Adapter(dim, args.intermedia_adapter,
# task.source_dictionary, strategy=strategy)
# self.intermedia_drop_prob = getattr(args, "intermedia_drop_prob", 0)
def add_to_dict(self, x, dis, idx):
sim = 0
@@ -546,102 +644,107 @@ class S2TCTCEncoder(FairseqEncoder):
def forward(self, src_tokens, src_lengths, **kwargs):
return self.encoder(src_tokens, src_lengths, **kwargs)
#
# if self.history is not None:
# self.history.clean()
#
# # gather cosine similarity
# cos_sim_idx = -1
# dis = self.dis
# if self.gather_cos_sim:
# self.add_to_dict(src_tokens.transpose(0, 1), dis, cos_sim_idx)
#
# # down-sampling
# x, input_lengths = self.subsample(src_tokens, src_lengths)
# # (B, T, D) -> (T, B, D)
# x = x.transpose(0, 1)
#
# # embedding scaling
# x = self.embed_scale * x
#
# # padding and position embedding
# encoder_padding_mask = lengths_to_padding_mask(input_lengths)
#
# if self.attn_type in ["rel_selfattn", "rel_pos", "rel_pos_legacy"]:
# positions = self.embed_positions(x)
#
# elif self.attn_type == "rope":
# positions = None
#
# else:
# positions = self.embed_positions(encoder_padding_mask).transpose(0, 1)
# x += positions
# positions = None
#
# x = self.dropout_module(x)
#
# # add emb into history
# if self.history is not None:
# self.history.push(x)
#
# # gather cosine similarity
# cos_sim_idx = (cos_sim_idx + 10) // 10 * 10 - 1
# if self.gather_cos_sim:
# cos_sim_idx += 1
# self.add_to_dict(x, dis, cos_sim_idx)
#
# layer_idx = 0
# intermedia_ctc_logits = []
# for layer in self.layers:
# layer_idx += 1
#
# if self.history is not None:
# x = self.history.pop()
#
# # encoder layer
# x = layer(x, encoder_padding_mask, pos_emb=positions)
#
# # interleave CTC
# if layer_idx in self.intermedia_ctc_layers:
# if self.intermedia_drop_prob > 0:
# p = torch.rand(1).uniform_()
# if p < self.intermedia_drop_prob:
# break
#
# norm_x = self.layer_norm(x)
# logit = self.ctc(norm_x)
# intermedia_ctc_logits.append(logit)
#
# prob = F.softmax(logit, dim=-1, dtype=torch.float32)
# x, encoder_padding_mask = self.adapter([x, prob], encoder_padding_mask)
#
# # gather cosine similarity
# if self.gather_cos_sim:
# cos_sim_idx += 1
# self.add_to_dict(x, dis, cos_sim_idx)
#
# if self.history is not None:
# self.history.push(x)
#
# if self.history is not None:
# x = self.history.pop()
#
# if self.layer_norm is not None:
# x = self.layer_norm(x)
#
# ctc_logit = self.ctc(x)
#
# return {
# "encoder_out": [x], # T x B x C
# "ctc_logit": [ctc_logit], # B x T x C
# "intermedia_ctc_logits": intermedia_ctc_logits, # B x T x C
# "encoder_padding_mask": [encoder_padding_mask], # B x T
# "encoder_embedding": [], # B x T x C
# "encoder_states": [], # List[T x B x C]
# "src_tokens": [],
# "src_lengths": [],
# }
def reorder_encoder_out(self, encoder_out, new_order):
+self.encoder.reorder_encoder_out(encoder_out, new_order)
+return
new_encoder_out = (
[] if len(encoder_out["encoder_out"]) == 0
else [x.index_select(1, new_order) for x in encoder_out["encoder_out"]]
...
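After this change `S2TCTCEncoder` is a thin wrapper that instantiates either `S2TTransformerEncoder` or `PDSS2TTransformerEncoder` (selected by `--encoder-type`) and forwards calls to it. Note that `reorder_encoder_out` above calls the delegate but discards its return value, while fairseq's beam search uses that value; a wrapper would normally be expected to look like the following sketch (my assumption, not code from the commit):

```python
class DelegatingEncoderSketch:
    def __init__(self, encoder):
        # encoder is an S2TTransformerEncoder or PDSS2TTransformerEncoder instance
        self.encoder = encoder

    def forward(self, src_tokens, src_lengths, **kwargs):
        return self.encoder(src_tokens, src_lengths, **kwargs)

    def reorder_encoder_out(self, encoder_out, new_order):
        # propagate the reordered output dictionary back to the caller
        return self.encoder.reorder_encoder_out(encoder_out, new_order)
```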
@@ -401,13 +401,13 @@ class S2TTransformerModel(FairseqEncoderDecoderModel):
def build_encoder(cls, args, task=None, embed_tokens=None):
encoder = S2TTransformerEncoder(args, task, embed_tokens)
if getattr(args, "load_pretrained_encoder_from", None):
+encoder = checkpoint_utils.load_pretrained_component_from_model(
+component=encoder, checkpoint=args.load_pretrained_encoder_from, strict=False
+)
logger.info(
f"loaded pretrained encoder from: "
f"{args.load_pretrained_encoder_from}"
)
-encoder = checkpoint_utils.load_pretrained_component_from_model(
-component=encoder, checkpoint=args.load_pretrained_encoder_from, strict=False
-)
return encoder
@@ -501,7 +501,7 @@ class S2TTransformerEncoder(FairseqEncoder):
self.padding_idx = 1
self.subsample = subsampling(args)
-self.linear = nn.Linear(dim, dim)
+# self.linear = nn.Linear(dim, dim)
self.attn_type = getattr(args, "encoder_attention_type", "selfattn")
@@ -535,6 +535,7 @@ class S2TTransformerEncoder(FairseqEncoder):
self.history = None
self.use_ctc = "sate" in args.arch or \
+(getattr(args, "criterion", "") == "ctc") or \
(("ctc" in getattr(args, "criterion", "")) and (getattr(args, "ctc_weight", 0) > 0))
if self.use_ctc:
self.ctc_layer = args.ctc_layer
@@ -640,7 +641,7 @@ class S2TTransformerEncoder(FairseqEncoder):
x += positions
positions = None
-x = self.linear(x)
+# x = self.linear(x)
x = self.dropout_module(x)
# add emb into history
...
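For reference, a small self-contained sketch of how the extended `use_ctc` check behaves for the two criteria used in the configs above (the attribute names are taken from the code, the example values are illustrative):

```python
from types import SimpleNamespace


def use_ctc(args):
    # Mirrors the updated condition in S2TTransformerEncoder.__init__.
    return "sate" in args.arch or \
           (getattr(args, "criterion", "") == "ctc") or \
           (("ctc" in getattr(args, "criterion", "")) and (getattr(args, "ctc_weight", 0) > 0))


# Pure CTC training: no ctc-weight needs to be set any more.
print(use_ctc(SimpleNamespace(arch="s2t_ctc", criterion="ctc")))  # True

# Joint CE + CTC training: still gated by a positive ctc-weight.
print(use_ctc(SimpleNamespace(arch="s2t_transformer_s",
                              criterion="label_smoothed_cross_entropy_with_ctc",
                              ctc_weight=0.2)))  # True
```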
-# Copyright (c) Facebook, Inc. and its affiliates.
-#
-# This source code is licensed under the MIT license found in the
-# LICENSE file in the root directory of this source tree.
-from typing import Dict, List, Optional
+from typing import Optional
import torch
import torch.nn as nn
-from fairseq import utils
from fairseq.modules import (
LayerNorm,
MultiheadAttention,
-ReducedMultiheadAttention,
RelPositionMultiheadAttention,
RelativeMultiheadAttention,
+ConvolutionModule,
+ESPNETMultiHeadedAttention,
+RelPositionMultiHeadedAttention,
+LegacyRelPositionMultiHeadedAttention,
LocalMultiheadAttention,
-ConvolutionModule
+ReducedMultiheadAttention,
+RotaryPositionMultiHeadedAttention,
)
+from fairseq.modules.s2t_transformer_layer import FeedForwardModule
from fairseq.modules.fairseq_dropout import FairseqDropout
-from fairseq.modules.quant_noise import quant_noise
from torch import Tensor
@@ -40,104 +38,76 @@ class PDSTransformerEncoderLayer(nn.Module):
def __init__(self, args, embed_dim, ffn_embed_dim, num_head, att_sample_ratio=1):
super().__init__()
self.args = args
-self.embed_dim = embed_dim
-self.encoder_ffn_embed_dim = ffn_embed_dim
+embed_dim = embed_dim
+ffn_dim = args.encoder_ffn_embed_dim
+dropout = args.dropout
self.quant_noise = getattr(args, 'quant_noise_pq', 0)
self.quant_noise_block_size = getattr(args, 'quant_noise_pq_block_size', 8) or 8
self.attn_type = getattr(args, "encoder_attention_type", "selfattn")
-self.self_attn = self.build_self_attention(args, self.embed_dim, num_head, att_sample_ratio)
+self.self_attn = self.build_self_attention(args, embed_dim, num_head, att_sample_ratio)
-self.self_attn_layer_norm = LayerNorm(self.embed_dim)
+self.self_attn_layer_norm = LayerNorm(embed_dim)
self.dropout_module = FairseqDropout(
-args.dropout, module_name=self.__class__.__name__
+dropout, module_name=self.__class__.__name__
)
-self.activation_fn = utils.get_activation_fn(
-activation=getattr(args, 'activation_fn', 'relu') or "relu"
-)
-activation_dropout_p = getattr(args, "activation_dropout", 0) or 0
-if activation_dropout_p == 0:
-# for backwards compatibility with models that use args.relu_dropout
-activation_dropout_p = getattr(args, "relu_dropout", 0) or 0
-self.activation_dropout_module = FairseqDropout(
-float(activation_dropout_p), module_name=self.__class__.__name__
-)
+self.normalize_before = args.encoder_normalize_before
+activation = getattr(args, 'encoder_activation_fn', 'relu')
args.macaron_style = getattr(args, "macaron_style", False)
args.use_cnn_module = getattr(args, "use_cnn_module", False)
args.cnn_module_kernel = getattr(args, "cnn_module_kernel", 31)
if args.macaron_style:
-self.macaron_fc1 = self.build_fc1(
-self.embed_dim,
-args.encoder_ffn_embed_dim,
-self.quant_noise,
-self.quant_noise_block_size,
-)
-self.macaron_fc2 = self.build_fc2(
-args.encoder_ffn_embed_dim,
-self.embed_dim,
-self.quant_noise,
-self.quant_noise_block_size,
+self.macaron_ffn = FeedForwardModule(
+embed_dim,
+ffn_dim,
+dropout,
+dropout,
+activation
)
-self.macaron_norm = LayerNorm(self.embed_dim)
+self.macaron_norm = LayerNorm(embed_dim)
self.ffn_scale = 0.5
else:
-self.macaron_fc1 = None
-self.macaron_fc2 = None
+self.macaron_ffn = None
self.macaron_norm = None
self.ffn_scale = 1.0
if args.use_cnn_module:
-self.conv_norm = LayerNorm(self.embed_dim)
+self.conv_norm = LayerNorm(embed_dim)
self.conv_module = ConvolutionModule(
-self.embed_dim,
-args.cnn_module_kernel)
-self.final_norm = LayerNorm(self.embed_dim)
+embed_dim,
+embed_dim,
+depthwise_kernel_size=args.cnn_module_kernel,
+dropout=args.dropout,
+activation_fn=getattr(args, 'activation_fn', 'swish'))
+self.final_norm = LayerNorm(embed_dim)
else:
self.conv_norm = None
self.conv_module = None
self.final_norm = None
-self.normalize_before = args.encoder_normalize_before
-self.fc1 = self.build_fc1(
-self.embed_dim,
-self.encoder_ffn_embed_dim,
-self.quant_noise,
-self.quant_noise_block_size,
+self.ffn = FeedForwardModule(
+embed_dim,
+ffn_dim,
+dropout,
+dropout,
+activation
)
-self.fc2 = self.build_fc2(
-self.encoder_ffn_embed_dim,
-self.embed_dim,
-self.quant_noise,
-self.quant_noise_block_size,
-)
-self.final_layer_norm = LayerNorm(self.embed_dim)
+self.ffn_norm = LayerNorm(embed_dim)
-def build_fc1(self, input_dim, output_dim, q_noise, qn_block_size):
-return quant_noise(
-nn.Linear(input_dim, output_dim), p=q_noise, block_size=qn_block_size
-)
-def build_fc2(self, input_dim, output_dim, q_noise, qn_block_size):
-return quant_noise(
-nn.Linear(input_dim, output_dim), p=q_noise, block_size=qn_block_size
-)
def build_self_attention(self, args, embed_dim, num_head, sample_ratio=1):
-encoder_attention_heads = num_head
+attention_heads = num_head
+dropout = args.dropout
if self.attn_type == "selfattn":
attn_func = MultiheadAttention
elif self.attn_type == "rel_selfattn":
attn_func = RelPositionMultiheadAttention
elif self.attn_type == "relative":
# max_relative_length = getattr(args, "max_encoder_relative_length", -1)
max_relative_length = max(getattr(args, "max_encoder_relative_length", -1),
getattr(args, "max_relative_length", -1))
if max_relative_length != -1:
return RelativeMultiheadAttention(
embed_dim,
-encoder_attention_heads,
+attention_heads,
-dropout=args.attention_dropout,
+dropout=dropout,
self_attention=True,
q_noise=self.quant_noise,
qn_block_size=self.quant_noise_block_size,
@@ -152,8 +122,8 @@ class PDSTransformerEncoderLayer(nn.Module):
init_mask_weight = getattr(args, "init_mask_weight", 0)
return LocalMultiheadAttention(
embed_dim,
-encoder_attention_heads,
+attention_heads,
-dropout=args.attention_dropout,
+dropout=dropout,
self_attention=True,
q_noise=self.quant_noise,
qn_block_size=self.quant_noise_block_size,
@@ -161,11 +131,36 @@ class PDSTransformerEncoderLayer(nn.Module):
gauss_mask_sigma=gauss_mask_sigma,
init_mask_weight=init_mask_weight
)
elif self.attn_type == "rel_pos":
return RelPositionMultiHeadedAttention(
embed_dim,
attention_heads,
dropout=dropout,
)
elif self.attn_type == "rel_pos_legacy":
return LegacyRelPositionMultiHeadedAttention(
embed_dim,
attention_heads,
dropout=dropout,
)
elif self.attn_type == "rope":
return RotaryPositionMultiHeadedAttention(
embed_dim,
attention_heads,
dropout=dropout,
precision=args.fp16
)
elif self.attn_type == "abs":
return ESPNETMultiHeadedAttention(
embed_dim,
attention_heads,
dropout=dropout,
)
elif self.attn_type == "reduced": elif self.attn_type == "reduced":
return ReducedMultiheadAttention( return ReducedMultiheadAttention(
embed_dim, embed_dim,
encoder_attention_heads, attention_heads,
dropout=args.attention_dropout, dropout=dropout,
self_attention=True, self_attention=True,
q_noise=self.quant_noise, q_noise=self.quant_noise,
qn_block_size=self.quant_noise_block_size, qn_block_size=self.quant_noise_block_size,
...@@ -177,8 +172,8 @@ class PDSTransformerEncoderLayer(nn.Module): ...@@ -177,8 +172,8 @@ class PDSTransformerEncoderLayer(nn.Module):
return attn_func( return attn_func(
embed_dim, embed_dim,
encoder_attention_heads, attention_heads,
dropout=args.attention_dropout, dropout=dropout,
self_attention=True, self_attention=True,
q_noise=self.quant_noise, q_noise=self.quant_noise,
qn_block_size=self.quant_noise_block_size, qn_block_size=self.quant_noise_block_size,
@@ -234,15 +229,15 @@ class PDSTransformerEncoderLayer(nn.Module):
residual = x
if self.normalize_before:
x = self.macaron_norm(x)
-x = self.macaron_fc2(self.activation_dropout_module(self.activation_fn(self.macaron_fc1(x))))
+x = self.macaron_ffn(x)
-x = residual + self.ffn_scale * self.dropout_module(x)
+x = residual + self.ffn_scale * x
if not self.normalize_before:
x = self.macaron_norm(x)
residual = x
if self.normalize_before:
x = self.self_attn_layer_norm(x)
-if self.attn_type == "rel_selfattn":
+if self.attn_type in ["rel_pos", "rel_pos_legacy", "rel_selfattn"]:
assert pos_emb is not None, "Positions is necessary for RPE!"
x, _ = self.self_attn(
query=x,
@@ -269,326 +264,28 @@ class PDSTransformerEncoderLayer(nn.Module):
# convolution module
if self.conv_module is not None:
+x = x.transpose(0, 1)
residual = x
-x = x.transpose(0, 1)
if self.normalize_before:
x = self.conv_norm(x)
+x = residual + self.dropout_module(self.conv_module(x, encoder_padding_mask))
-x = self.conv_module(x)
-x = x.transpose(0, 1)
-x = residual + x
if not self.normalize_before:
x = self.conv_norm(x)
+x = x.transpose(0, 1)
residual = x
if self.normalize_before:
-x = self.final_layer_norm(x)
+x = self.ffn_norm(x)
-x = self.activation_fn(self.fc1(x))
+x = self.ffn(x)
-x = self.activation_dropout_module(x)
-x = self.fc2(x)
-x = self.dropout_module(x)
x = self.residual_connection(self.ffn_scale * x, residual)
if not self.normalize_before:
-x = self.final_layer_norm(x)
+x = self.ffn_norm(x)
if self.conv_module is not None:
x = self.final_norm(x)
return x
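The refactor above replaces the explicit `fc1`/activation/`fc2` stack with a `FeedForwardModule` imported from `fairseq.modules.s2t_transformer_layer`. That module is not shown in this diff; judging from the call sites `FeedForwardModule(embed_dim, ffn_dim, dropout, dropout, activation)` and from what the removed code did, it presumably looks roughly like the sketch below (class and argument names are assumptions). The residual connection and `ffn_scale` stay outside the module, as in the `forward` above.

```python
import torch.nn as nn
from fairseq import utils


class FeedForwardModuleSketch(nn.Module):
    """Position-wise FFN: Linear -> activation -> dropout -> Linear -> dropout."""

    def __init__(self, input_dim, hidden_dim, dropout1, dropout2, activation_fn="relu"):
        super().__init__()
        self.fc1 = nn.Linear(input_dim, hidden_dim)
        self.activation = utils.get_activation_fn(activation_fn)
        self.dropout1 = nn.Dropout(dropout1)
        self.fc2 = nn.Linear(hidden_dim, input_dim)
        self.dropout2 = nn.Dropout(dropout2)

    def forward(self, x):
        x = self.dropout1(self.activation(self.fc1(x)))
        return self.dropout2(self.fc2(x))
```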
class TransformerDecoderLayer(nn.Module):
"""Decoder layer block.
In the original paper each operation (multi-head attention, encoder
attention or FFN) is postprocessed with: `dropout -> add residual ->
layernorm`. In the tensor2tensor code they suggest that learning is more
robust when preprocessing each layer with layernorm and postprocessing with:
`dropout -> add residual`. We default to the approach in the paper, but the
tensor2tensor approach can be enabled by setting
*args.decoder_normalize_before* to ``True``.
Args:
args (argparse.Namespace): parsed command-line arguments
no_encoder_attn (bool, optional): whether to attend to encoder outputs
(default: False).
"""
def __init__(
self, args, no_encoder_attn=False, add_bias_kv=False, add_zero_attn=False
):
super().__init__()
self.embed_dim = args.decoder_embed_dim
self.dropout_module = FairseqDropout(
args.dropout, module_name=self.__class__.__name__
)
self.quant_noise = getattr(args, "quant_noise_pq", 0)
self.quant_noise_block_size = getattr(args, "quant_noise_pq_block_size", 8)
self.cross_self_attention = getattr(args, "cross_self_attention", False)
self.attn_type = getattr(args, "decoder_attention_type", "selfattn")
self.self_attn = self.build_self_attention(
self.embed_dim,
args,
add_bias_kv=add_bias_kv,
add_zero_attn=add_zero_attn,
)
self.activation_fn = utils.get_activation_fn(
activation=str(args.activation_fn)
if getattr(args, "activation_fn", None) is not None
else "relu"
)
activation_dropout_p = getattr(args, "activation_dropout", 0) or 0
if activation_dropout_p == 0:
# for backwards compatibility with models that use args.relu_dropout
activation_dropout_p = getattr(args, "relu_dropout", 0) or 0
self.activation_dropout_module = FairseqDropout(
float(activation_dropout_p), module_name=self.__class__.__name__
)
self.normalize_before = args.decoder_normalize_before
# use layerNorm rather than FusedLayerNorm for exporting.
# char_inputs can be used to determint this.
# TODO remove this once we update apex with the fix
export = getattr(args, "char_inputs", False)
self.self_attn_layer_norm = LayerNorm(self.embed_dim, export=export)
if no_encoder_attn:
self.encoder_attn = None
self.encoder_attn_layer_norm = None
else:
self.encoder_attn = self.build_encoder_attention(self.embed_dim, args)
self.encoder_attn_layer_norm = LayerNorm(self.embed_dim, export=export)
self.fc1 = self.build_fc1(
self.embed_dim,
args.decoder_ffn_embed_dim,
self.quant_noise,
self.quant_noise_block_size,
)
self.fc2 = self.build_fc2(
args.decoder_ffn_embed_dim,
self.embed_dim,
self.quant_noise,
self.quant_noise_block_size,
)
self.final_layer_norm = LayerNorm(self.embed_dim, export=export)
self.need_attn = True
self.onnx_trace = False
def build_fc1(self, input_dim, output_dim, q_noise, qn_block_size):
return quant_noise(nn.Linear(input_dim, output_dim), q_noise, qn_block_size)
def build_fc2(self, input_dim, output_dim, q_noise, qn_block_size):
return quant_noise(nn.Linear(input_dim, output_dim), q_noise, qn_block_size)
def build_self_attention(
self, embed_dim, args, add_bias_kv=False, add_zero_attn=False
):
if self.attn_type == "selfattn":
attn_func = MultiheadAttention
elif self.attn_type == "rel_selfattn":
attn_func = RelPositionMultiheadAttention
elif self.attn_type == "relative":
max_relative_length = max(getattr(args, "max_decoder_relative_length", -1), getattr(args, "max_relative_length", -1))
if max_relative_length != -1:
return RelativeMultiheadAttention(
embed_dim,
args.decoder_attention_heads,
dropout=args.attention_dropout,
self_attention=True,
q_noise=self.quant_noise,
qn_block_size=self.quant_noise_block_size,
max_relative_length=max_relative_length,
)
else:
print("The maximum decoder relative length %d can not be -1!" % max_relative_length)
exit(1)
else:
print("The decoder attention type %s is not supported!" % self.attn_type)
exit(1)
return attn_func(
embed_dim,
args.decoder_attention_heads,
dropout=args.attention_dropout,
add_bias_kv=add_bias_kv,
add_zero_attn=add_zero_attn,
self_attention=not getattr(args, "cross_self_attention", False),
q_noise=self.quant_noise,
qn_block_size=self.quant_noise_block_size,
)
def build_encoder_attention(self, embed_dim, args):
return MultiheadAttention(
embed_dim,
args.decoder_attention_heads,
kdim=getattr(args, "encoder_embed_dim", None),
vdim=getattr(args, "encoder_embed_dim", None),
dropout=args.attention_dropout,
encoder_decoder_attention=True,
q_noise=self.quant_noise,
qn_block_size=self.quant_noise_block_size,
)
def prepare_for_onnx_export_(self):
self.onnx_trace = True
def residual_connection(self, x, residual):
return residual + x
def forward(
self,
x,
encoder_out: Optional[torch.Tensor] = None,
encoder_padding_mask: Optional[torch.Tensor] = None,
incremental_state: Optional[Dict[str, Dict[str, Optional[Tensor]]]] = None,
prev_self_attn_state: Optional[List[torch.Tensor]] = None,
prev_attn_state: Optional[List[torch.Tensor]] = None,
self_attn_mask: Optional[torch.Tensor] = None,
self_attn_padding_mask: Optional[torch.Tensor] = None,
need_attn: bool = False,
need_head_weights: bool = False,
pos_emb: Optional[Tensor] = None,
):
"""
Args:
x (Tensor): input to the layer of shape `(seq_len, batch, embed_dim)`
encoder_padding_mask (ByteTensor, optional): binary
ByteTensor of shape `(batch, src_len)` where padding
elements are indicated by ``1``.
need_attn (bool, optional): return attention weights
need_head_weights (bool, optional): return attention weights
for each head (default: return average over heads).
Returns:
encoded output of shape `(seq_len, batch, embed_dim)`
"""
if need_head_weights:
need_attn = True
residual = x
if self.normalize_before:
x = self.self_attn_layer_norm(x)
if prev_self_attn_state is not None:
prev_key, prev_value = prev_self_attn_state[:2]
saved_state: Dict[str, Optional[Tensor]] = {
"prev_key": prev_key,
"prev_value": prev_value,
}
if len(prev_self_attn_state) >= 3:
saved_state["prev_key_padding_mask"] = prev_self_attn_state[2]
assert incremental_state is not None
self.self_attn._set_input_buffer(incremental_state, saved_state)
_self_attn_input_buffer = self.self_attn._get_input_buffer(incremental_state)
if self.cross_self_attention and not (
incremental_state is not None
and _self_attn_input_buffer is not None
and "prev_key" in _self_attn_input_buffer
):
if self_attn_mask is not None:
assert encoder_out is not None
self_attn_mask = torch.cat(
(x.new_zeros(x.size(0), encoder_out.size(0)), self_attn_mask), dim=1
)
if self_attn_padding_mask is not None:
if encoder_padding_mask is None:
assert encoder_out is not None
encoder_padding_mask = self_attn_padding_mask.new_zeros(
encoder_out.size(1), encoder_out.size(0)
)
self_attn_padding_mask = torch.cat(
(encoder_padding_mask, self_attn_padding_mask), dim=1
)
assert encoder_out is not None
y = torch.cat((encoder_out, x), dim=0)
else:
y = x
if self.attn_type == "rel_selfattn":
assert pos_emb is not None, "Positions is necessary for RPE!"
x, attn = self.self_attn(
query=x,
key=y,
value=y,
key_padding_mask=self_attn_padding_mask,
incremental_state=incremental_state,
need_weights=False,
attn_mask=self_attn_mask,
pos_emb=pos_emb
)
else:
x, attn = self.self_attn(
query=x,
key=y,
value=y,
key_padding_mask=self_attn_padding_mask,
incremental_state=incremental_state,
need_weights=False,
attn_mask=self_attn_mask,
)
x = self.dropout_module(x)
x = self.residual_connection(x, residual)
if not self.normalize_before:
x = self.self_attn_layer_norm(x)
if self.encoder_attn is not None and encoder_out is not None:
residual = x
if self.normalize_before:
x = self.encoder_attn_layer_norm(x)
if prev_attn_state is not None:
prev_key, prev_value = prev_attn_state[:2]
saved_state: Dict[str, Optional[Tensor]] = {
"prev_key": prev_key,
"prev_value": prev_value,
}
if len(prev_attn_state) >= 3:
saved_state["prev_key_padding_mask"] = prev_attn_state[2]
assert incremental_state is not None
self.encoder_attn._set_input_buffer(incremental_state, saved_state)
x, attn = self.encoder_attn(
query=x,
key=encoder_out,
value=encoder_out,
key_padding_mask=encoder_padding_mask,
incremental_state=incremental_state,
static_kv=True,
need_weights=need_attn or (not self.training and self.need_attn),
need_head_weights=need_head_weights,
)
x = self.dropout_module(x)
x = self.residual_connection(x, residual)
if not self.normalize_before:
x = self.encoder_attn_layer_norm(x)
residual = x
if self.normalize_before:
x = self.final_layer_norm(x)
x = self.activation_fn(self.fc1(x))
x = self.activation_dropout_module(x)
x = self.fc2(x)
x = self.dropout_module(x)
x = self.residual_connection(x, residual)
if not self.normalize_before:
x = self.final_layer_norm(x)
if self.onnx_trace and incremental_state is not None:
saved_state = self.self_attn._get_input_buffer(incremental_state)
assert saved_state is not None
if self_attn_padding_mask is not None:
self_attn_state = [
saved_state["prev_key"],
saved_state["prev_value"],
saved_state["prev_key_padding_mask"],
]
else:
self_attn_state = [saved_state["prev_key"], saved_state["prev_value"]]
return x, attn, self_attn_state
return x, attn, None
def make_generation_fast_(self, need_attn: bool = False, **kwargs):
self.need_attn = need_attn
@@ -88,11 +88,11 @@ class S2TTransformerEncoderLayer(nn.Module):
embed_dim = args.encoder_embed_dim
ffn_dim = args.encoder_ffn_embed_dim
dropout = args.dropout
-self.embed_dim = args.encoder_embed_dim
+self.embed_dim = embed_dim
self.quant_noise = getattr(args, 'quant_noise_pq', 0)
self.quant_noise_block_size = getattr(args, 'quant_noise_pq_block_size', 8) or 8
self.attn_type = getattr(args, "encoder_attention_type", "selfattn")
-self.self_attn = self.build_self_attention(self.embed_dim, args)
+self.self_attn = self.build_self_attention(args, self.embed_dim)
self.self_attn_layer_norm = LayerNorm(self.embed_dim)
self.dropout_module = FairseqDropout(
dropout, module_name=self.__class__.__name__
@@ -138,7 +138,7 @@ class S2TTransformerEncoderLayer(nn.Module):
)
self.ffn_norm = LayerNorm(self.embed_dim)
-def build_self_attention(self, embed_dim, args):
+def build_self_attention(self, args, embed_dim):
attention_heads = args.encoder_attention_heads
dropout = args.dropout
@@ -147,7 +147,8 @@ class S2TTransformerEncoderLayer(nn.Module):
elif self.attn_type == "rel_selfattn":
attn_func = RelPositionMultiheadAttention
elif self.attn_type == "relative":
-max_relative_length = getattr(args, "max_encoder_relative_length", -1)
+max_relative_length = max(getattr(args, "max_encoder_relative_length", -1),
+getattr(args, "max_relative_length", -1))
if max_relative_length != -1:
return RelativeMultiheadAttention(
embed_dim,
@@ -188,13 +189,13 @@ class S2TTransformerEncoderLayer(nn.Module):
)
else:
attn_func = MultiheadAttention
-print("The attention type %s is not supported!" % self.attn_type)
+print("The encoder attention type %s is not supported!" % self.attn_type)
exit(1)
return attn_func(
embed_dim,
-args.encoder_attention_heads,
+attention_heads,
-dropout=args.attention_dropout,
+dropout=dropout,
self_attention=True,
q_noise=self.quant_noise,
qn_block_size=self.quant_noise_block_size,
...