Commit e96e141d by xuchen

Fix a bug in the relative position encoding: tile the positional embedding across the batch in RelPositionMultiheadAttention, add the legacy (ESPnet-style) LegacyRelPositionalEncoding and LegacyRelPositionMultiHeadedAttention modules, and expose them through the new rel_pos_legacy encoder-attention-type. The example configs are updated accordingly (weight-decay, adam_betas, and the new subsampling options).

parent ebd1be88
@@ -5,8 +5,9 @@ clip-norm: 10.0
 lr-scheduler: inverse_sqrt
 warmup-init-lr: 1e-7
 warmup-updates: 10000
+weight-decay: 1e-6
 lr: 2e-3
-#adam_betas: (0.9,0.98)
+adam_betas: (0.9,0.98)
 criterion: label_smoothed_cross_entropy_with_ctc
 label_smoothing: 0.1
@@ -32,3 +33,6 @@ decoder-ffn-embed-dim: 2048
 decoder-attention-heads: 4
 attention-dropout: 0.1
 activation-dropout: 0.1
+#load-pretrained-encoder-from:
+#load-pretrained-decoder-from:
\ No newline at end of file
@@ -2,3 +2,4 @@ macaron-style: True
 use-cnn-module: True
 cnn-module-kernel: 31
 encoder-attention-type: rel_pos
+encoder-activation-fn: swish
\ No newline at end of file
@@ -37,4 +37,4 @@ macaron-style: True
 use-cnn-module: True
 cnn-module-kernel: 31
 encoder-activation-fn: swish
-encoder-attention-type: rel_pos
+encoder-attention-type: rel_pos_legacy
@@ -4,22 +4,28 @@ clip-norm: 10.0
 lr-scheduler: inverse_sqrt
 warmup-init-lr: 1e-7
 warmup-updates: 10000
+weight-decay: 1e-6
 lr: 2e-3
-#adam_betas: (0.9,0.98)
+adam_betas: (0.9,0.98)
 criterion: ctc
 zero_infinity: True
 post-process: sentencepiece
+label_smoothing: 0.1
-conv-kernel-sizes: 5,5
-conv-channels: 1024
+subsampling-type: conv1d
+subsmapling-layers: 2
+subsampling-filter: 1024
+subsampling-kernel: 5
+subsampling-stride: 2
+subsampling-norm: none
+subsampling-activation: glu
 dropout: 0.1
+attention-dropout: 0.1
+activation-dropout: 0.1
 activation-fn: relu
 encoder-embed-dim: 256
 encoder-ffn-embed-dim: 2048
 encoder-layers: 12
 encoder-attention-heads: 4
-attention-dropout: 0.1
-activation-dropout: 0.1
-encoder-attention-type: rel_pos
+encoder-attention-type: rel_selfattn
 #encoder-attention-type: relative
-#decoder-attention-type: relative
 #max-encoder-relative-length: 100
-#max-decoder-relative-length: 20
@@ -5,6 +5,7 @@ clip-norm: 10.0
 lr-scheduler: inverse_sqrt
 warmup-init-lr: 1e-7
 warmup-updates: 8000
+weight-decay: 1e-6
 lr: 1e-3
 adam_betas: (0.9,0.997)
...
@@ -5,14 +5,21 @@ clip-norm: 10.0
 lr-scheduler: inverse_sqrt
 warmup-init-lr: 1e-7
 warmup-updates: 10000
+weight-decay: 1e-6
 lr: 2e-3
-#adam_betas: (0.9,0.98)
+adam_betas: (0.9,0.98)
 criterion: label_smoothed_cross_entropy_with_ctc
 label_smoothing: 0.1
-conv-kernel-sizes: 5,5
-conv-channels: 1024
+subsampling-type: conv1d
+subsmapling-layers: 2
+subsampling-filter: 1024
+subsampling-kernel: 5
+subsampling-stride: 2
+subsampling-norm: none
+subsampling-activation: glu
 dropout: 0.1
 activation-fn: relu
 encoder-embed-dim: 256
...
 macaron-style: True
 use-cnn-module: True
 cnn-module-kernel: 31
+encoder-attention-type: rel_pos
+encoder-activation-fn: swish
\ No newline at end of file

 ctc-weight: 0.3
+zero_infinity: True
+post-process: sentencepiece
\ No newline at end of file
@@ -14,8 +14,15 @@ label_smoothing: 0.1
 encoder-normalize-before: True
 decoder-normalize-before: True
-conv-kernel-sizes: 5,5
-conv-channels: 1024
+subsampling-type: conv1d
+subsmapling-layers: 2
+subsampling-filter: 1024
+subsampling-kernel: 5
+subsampling-stride: 2
+subsampling-norm: none
+subsampling-activation: glu
 dropout: 0.1
 activation-fn: relu
 encoder-embed-dim: 256
...
 arch: s2t_ctc
 optimizer: adam
-clip-norm: 10.0
+#clip-norm: 10.0
 lr-scheduler: inverse_sqrt
 warmup-init-lr: 1e-7
 warmup-updates: 10000
+weight-decay: 1e-6
 lr: 0.0015
 adam_betas: (0.9,0.98)
 criterion: ctc
 post-process: sentencepiece
-subsampling-type: conv1d
+subsampling-type: conv2d
 subsmapling-layers: 2
-subsampling-filter: 1024
+subsampling-filter: 176
-subsampling-kernel: 5
+subsampling-kernel: 3
 subsampling-stride: 2
-subsampling-norm: none
+subsampling-norm: batch2d
-subsampling-activation: glu
+subsampling-activation: swish
 dropout: 0.1
 activation-fn: relu
...
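
For reference, both the conv1d and conv2d subsampling settings above stack two stride-2 layers, which cuts the acoustic frame rate by roughly a factor of four. A rough, library-free sketch of that length arithmetic (the exact lengths in the repo depend on kernel size and padding, which this sketch ignores):

import math

def subsampled_length(n_frames: int, layers: int = 2, stride: int = 2) -> int:
    # Each stride-2 subsampling layer roughly halves the time dimension,
    # so two layers give about a 4x frame-rate reduction.
    for _ in range(layers):
        n_frames = (n_frames + stride - 1) // stride
    return n_frames

print(subsampled_length(1000))  # -> 250 frames fed to the encoder
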
@@ -19,6 +19,8 @@ from fairseq.modules import (
     FairseqDropout,
     LayerNorm,
     PositionalEmbedding,
+    PositionalEncoding,
+    LegacyRelPositionalEncoding,
     RelPositionalEncoding,
     S2TTransformerEncoderLayer,
     DynamicLinearCombination,
@@ -462,6 +464,10 @@ class S2TCTCEncoder(FairseqEncoder):
             self.embed_positions = RelPositionalEncoding(
                 args.max_source_positions, args.encoder_embed_dim
             )
+        elif self.attn_type == "rel_selfattn":
+            self.embed_positions = LegacyRelPositionalEncoding(
+                args.encoder_embed_dim, args.dropout, args.max_source_positions
+            )
         elif self.attn_type == "rope":
             self.embed_positions = None
         else:  # Use absolute positional embedding
@@ -554,7 +560,7 @@ class S2TCTCEncoder(FairseqEncoder):
         # padding and position embedding
         encoder_padding_mask = lengths_to_padding_mask(input_lengths)
-        if self.attn_type == "rel_pos":
+        if self.attn_type == "rel_pos" or self.attn_type == "rel_selfattn":
             positions = self.embed_positions(x)
         elif self.attn_type == "rope":
@@ -566,7 +572,6 @@ class S2TCTCEncoder(FairseqEncoder):
             positions = None
         x = self.dropout_module(x)
-        # positions = self.dropout_module(positions)
         # add emb into history
         if self.history is not None:
@@ -698,7 +703,7 @@ class CTCDecoder(object):
                 model_path=self.lm_model,
                 alpha=self.lm_weight,
                 beta=0,
-                cutoff_top_n=40,
+                cutoff_top_n=self.vocab_size,
                 cutoff_prob=1.0,
                 beam_width=self.beam_size,
                 num_processes=20,
...
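
The cutoff_top_n change in the last hunk affects the external CTC beam-search decoder. The constructor arguments shown (model_path, alpha, beta, cutoff_top_n, cutoff_prob, beam_width, num_processes) look like ctcdecode's CTCBeamDecoder; under that assumption, passing the full vocabulary size instead of 40 effectively disables the per-step top-n pruning of candidate tokens. A minimal sketch with a toy vocabulary, assuming the ctcdecode package rather than the repo's actual wiring:

import torch
from ctcdecode import CTCBeamDecoder  # assumption: the decoder behind this hunk is ctcdecode

vocab = ["_", "a", "b", "c"]                      # toy vocabulary, index 0 = blank
decoder = CTCBeamDecoder(
    vocab,
    model_path=None,                              # optional KenLM language model
    alpha=0.0, beta=0.0,                          # LM weight / word bonus
    cutoff_top_n=len(vocab),                      # keep every token per step, as in the fix
    cutoff_prob=1.0,
    beam_width=10,
    num_processes=4,
    blank_id=0,
    log_probs_input=False,
)
probs = torch.softmax(torch.randn(1, 50, len(vocab)), dim=-1)   # (batch, time, vocab)
beams, scores, timesteps, out_lens = decoder.decode(probs)
best = beams[0][0][: out_lens[0][0]]              # best hypothesis for the first utterance
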
@@ -19,6 +19,7 @@ from fairseq.modules import (
     FairseqDropout,
     LayerNorm,
     PositionalEmbedding,
+    LegacyRelPositionalEncoding,
     RelPositionalEncoding,
     S2TTransformerEncoderLayer,
     DynamicLinearCombination,
@@ -133,9 +134,10 @@ class S2TTransformerModel(FairseqEncoderDecoderModel):
                 "reduced",
                 "rel_selfattn",
                 "relative",
+                "rel_pos_legacy",
                 "rel_pos",
                 "rope",
-                "abs"
+                "abs",
             ],
             help="transformer encoder self-attention layer type"
         )
@@ -499,6 +501,10 @@ class S2TTransformerEncoder(FairseqEncoder):
             self.embed_positions = RelPositionalEncoding(
                 args.max_source_positions, args.encoder_embed_dim
             )
+        elif self.attn_type in ["rel_selfattn", "rel_pos_legacy"]:
+            self.embed_positions = LegacyRelPositionalEncoding(
+                args.encoder_embed_dim, args.dropout, args.max_source_positions
+            )
         elif self.attn_type == "rope":
             self.embed_positions = None
         else:  # Use absolute positional embedding
@@ -614,7 +620,7 @@ class S2TTransformerEncoder(FairseqEncoder):
         # padding and position embedding
         encoder_padding_mask = lengths_to_padding_mask(input_lengths)
-        if self.attn_type == "rel_pos":
+        if self.attn_type in ["rel_pos", "rel_pos_legacy", "rel_selfattn"]:
             positions = self.embed_positions(x)
         elif self.attn_type == "rope":
@@ -626,7 +632,6 @@ class S2TTransformerEncoder(FairseqEncoder):
             positions = None
         x = self.dropout_module(x)
-        # positions = self.dropout_module(positions)
         # add emb into history
         if self.history is not None:
...
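
Taken together, the model-side hunks above give the following mapping from encoder-attention-type to positional-encoding module. This is a condensed summary, not code from the repo; note that s2t_ctc.py only wires rel_selfattn to the legacy encoding, while s2t_transformer.py accepts rel_pos_legacy as well:

# How encoder-attention-type selects the positional table after this commit.
POSITIONAL_ENCODING_FOR = {
    "rel_pos": "RelPositionalEncoding",               # new style, (2T-1) x 1 x C table
    "rel_pos_legacy": "LegacyRelPositionalEncoding",  # old ESPnet style, T x 1 x C table
    "rel_selfattn": "LegacyRelPositionalEncoding",
    "rope": None,                                     # rotary embedding, no separate table
    "abs": "PositionalEmbedding",                     # absolute positions (default branch)
}
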
@@ -44,15 +44,18 @@ from .transpose_last import TransposeLast
 from .unfold import unfold1d
 from .transformer_layer import TransformerDecoderLayer, TransformerEncoderLayer
 from .vggblock import VGGBlock
+from .rotary_positional_embedding import RotaryPositionalEmbedding
+from .positional_encoding import (
+    PositionalEncoding,
+    LegacyRelPositionalEncoding,
+    RelPositionalEncoding,
+)
 from .espnet_multihead_attention import (
     ESPNETMultiHeadedAttention,
     RelPositionMultiHeadedAttention,
+    LegacyRelPositionMultiHeadedAttention,
     RotaryPositionMultiHeadedAttention,
 )
-from .rotary_positional_embedding import RotaryPositionalEmbedding
-from .positional_encoding import (
-    RelPositionalEncoding,
-)
 from .convolution import ConvolutionModule
 from .s2t_transformer_layer import S2TTransformerEncoderLayer
 from .pds_layer import PDSTransformerEncoderLayer
@@ -87,7 +90,6 @@ __all__ = [
     "LightweightConv",
     "LinearizedConvolution",
     "LocalMultiheadAttention",
     "MultiheadAttention",
     "PositionalEmbedding",
     "PDSTransformerEncoderLayer",
@@ -107,10 +109,13 @@ __all__ = [
     "TransposeLast",
     "VGGBlock",
     "unfold1d",
-    "ESPNETMultiheadedAttention",
+    "ESPNETMultiHeadedAttention",
     "PositionalEmbedding",
     "RelPositionMultiHeadedAttention",
+    "PositionalEncoding",
+    "LegacyRelPositionalEncoding",
     "RelPositionalEncoding",
+    "LegacyRelPositionMultiHeadedAttention",
     "RotaryPositionalEmbedding",
     "RotaryPositionMultiHeadedAttention",
 ]
@@ -156,7 +156,7 @@ class RelPositionMultiHeadedAttention(ESPNETMultiHeadedAttention):
             query: Query tensor T X B X C
             key: Key tensor T X B X C
             value: Value tensor T X B X C
-            pos_emb: Positional embedding tensor B X 2T-1 X C
+            pos_emb: Positional embedding tensor 2T-1 X B(1) X C
             key_padding_mask: Mask tensor T X B
         Returns:
             torch.Tensor: Output tensor T X B X C.
@@ -196,6 +196,43 @@ class RelPositionMultiHeadedAttention(ESPNETMultiHeadedAttention):
         return scores, None


+class LegacyRelPositionMultiHeadedAttention(RelPositionMultiHeadedAttention):
+    """Multi-Head Attention layer with relative position encoding (old version).
+    Details can be found in https://github.com/espnet/espnet/pull/2816.
+    Paper: https://arxiv.org/abs/1901.02860
+    Args:
+        n_head (int): The number of heads.
+        n_feat (int): The number of features.
+        dropout_rate (float): Dropout rate.
+        zero_triu (bool): Whether to zero the upper triangular part of the attention matrix.
+    """
+
+    def __init__(self, n_feat, n_head, dropout, zero_triu=False):
+        """Construct a LegacyRelPositionMultiHeadedAttention object."""
+        super().__init__(n_feat, n_head, dropout, zero_triu)
+
+    def rel_shift(self, x):
+        """Compute relative positional encoding.
+        Args:
+            x: Input tensor B X n_head X T X T
+        Returns:
+            torch.Tensor: Output tensor.
+        """
+        zero_pad = torch.zeros((*x.size()[:3], 1), device=x.device, dtype=x.dtype)
+        x_padded = torch.cat([zero_pad, x], dim=-1)
+
+        x_padded = x_padded.view(*x.size()[:2], x.size(3) + 1, x.size(2))
+        x = x_padded[:, :, 1:].view_as(x)
+
+        if self.zero_triu:
+            ones = torch.ones((x.size(2), x.size(3)), device=x.device)
+            x = x * torch.tril(ones, x.size(3) - x.size(2))[None, None, :, :]
+
+        return x
+
+
 class RotaryPositionMultiHeadedAttention(ESPNETMultiHeadedAttention):
     def __init__(
         self,
...
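
The legacy rel_shift above is the classic Transformer-XL trick: prepend a zero column, reinterpret the (T, positions) block so that each row is read one step further along than the previous one, then view the result back in the original shape. A small self-contained demo of just that shift (toy tensor only; in the class it is applied to the content-position attention scores, and the upper-triangle leftovers are what the optional zero_triu mask removes):

import torch

def legacy_rel_shift(x):
    # x: (batch, heads, T, positions) scores, as in the class above.
    zero_pad = torch.zeros((*x.size()[:3], 1), device=x.device, dtype=x.dtype)
    x_padded = torch.cat([zero_pad, x], dim=-1)                 # prepend a zero column
    x_padded = x_padded.view(*x.size()[:2], x.size(3) + 1, x.size(2))
    return x_padded[:, :, 1:].view_as(x)                        # drop one row, rows come back shifted

scores = torch.arange(1.0, 10.0).view(1, 1, 3, 3)   # T = 3, 3 relative positions
print(legacy_rel_shift(scores))
# tensor([[[[3., 0., 4.],
#           [5., 6., 0.],
#           [7., 8., 9.]]]])
# Row i is shifted left by (T-1-i), so each query row lines up with its own relative offsets.
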
@@ -63,13 +63,50 @@ class PositionalEncoding(nn.Module):
         return self.dropout(x)


+class LegacyRelPositionalEncoding(PositionalEncoding):
+    """Relative positional encoding module (old version).
+    Details can be found in https://github.com/espnet/espnet/pull/2816.
+    See: Appendix B in https://arxiv.org/abs/1901.02860
+    Args:
+        d_model (int): Embedding dimension.
+        dropout_rate (float): Dropout rate.
+        max_len (int): Maximum input length.
+    """
+
+    def __init__(self, d_model, dropout_rate, max_len=5000):
+        """Initialize class."""
+        super().__init__(
+            d_model=d_model,
+            dropout_rate=dropout_rate,
+            max_len=max_len,
+            reverse=True,
+        )
+
+    def forward(self, x):
+        """Add positional encoding.
+        Args:
+            x: Input tensor T X B X C.
+        Returns:
+            torch.Tensor: Encoded tensor T X B X C.
+        """
+        x = x.transpose(0, 1)  # change T X B X C to B X T X C
+        self.extend_pe(x)
+        pos_emb = self.pe[:, : x.size(1)]
+        pos_emb = pos_emb.transpose(0, 1)  # change back to T X B X C
+        return self.dropout(pos_emb)
+
+
 class RelPositionalEncoding(nn.Module):
     """Relative positional encoding module (new implementation).
     Args:
-        d_model: Embedding dimension.
-        dropout_rate: Dropout rate.
         max_len: Maximum input length.
+        d_model: Embedding dimension.
     """

     def __init__(self, max_len, d_model):
...
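
The practical difference between the two encodings is the length of the table handed to the attention layer: the legacy module returns one vector per input frame (T x 1 x C, built from the reversed sinusoid of PositionalEncoding), while RelPositionalEncoding returns one vector per relative offset from -(T-1) to T-1 (2T-1 x 1 x C), which is the shape named in the docstring fix earlier in this commit. A standalone sketch of just the two table lengths, using a plain sinusoid rather than the repo's classes:

import math
import torch

def sinusoid_table(length, d_model):
    # Illustrative sinusoidal table; only the lengths matter for this comparison.
    pos = torch.arange(length, dtype=torch.float32).unsqueeze(1)
    div = torch.exp(torch.arange(0, d_model, 2, dtype=torch.float32) * (-math.log(10000.0) / d_model))
    pe = torch.zeros(length, d_model)
    pe[:, 0::2] = torch.sin(pos * div)
    pe[:, 1::2] = torch.cos(pos * div)
    return pe

T, C = 6, 8
legacy_table = sinusoid_table(T, C)          # legacy: one entry per frame  -> (T, C)
new_table = sinusoid_table(2 * T - 1, C)     # new style: one entry per offset -> (2T-1, C)
print(legacy_table.shape, new_table.shape)   # torch.Size([6, 8]) torch.Size([11, 8])
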
@@ -247,6 +247,7 @@ class RelPositionMultiheadAttention(MultiheadAttention):
             dim=1,
         )
+        pos_emb = pos_emb.repeat(bsz, 1, 1)
         pos_emb = pos_emb.transpose(0, 1)
         p = self.linear_pos(pos_emb).view(bsz, -1, self.num_heads, self.head_dim)
         # p (bsz * num_heads, tgt_len, head_dim)
...
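
This one-line addition is the shape fix behind the commit message: the relative positional table arrives with a singleton batch dimension, so it has to be tiled across the batch before linear_pos projects it and the result is viewed as (bsz, -1, num_heads, head_dim). A minimal, self-contained sketch of the problem; the sizes and the bare nn.Linear below are stand-ins, not the repo's module:

import torch
import torch.nn as nn

bsz, pos_len, embed_dim, num_heads = 4, 7, 16, 2
head_dim = embed_dim // num_heads
linear_pos = nn.Linear(embed_dim, embed_dim, bias=False)

pos_emb = torch.randn(pos_len, 1, embed_dim)   # what the positional encoder emits: (pos_len, 1, C)

# Without the repeat, linear_pos(pos_emb) holds only pos_len * embed_dim values,
# which cannot be viewed as (bsz, -1, num_heads, head_dim) once bsz > 1.
pos_emb = pos_emb.repeat(bsz, 1, 1)            # the added line, mirrored from the diff
pos_emb = pos_emb.transpose(0, 1)
p = linear_pos(pos_emb).view(bsz, -1, num_heads, head_dim)
print(p.shape)                                 # torch.Size([4, 7, 2, 8])
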
@@ -15,11 +15,12 @@ from fairseq.modules import (
     ConvolutionModule,
     ESPNETMultiHeadedAttention,
     RelPositionMultiHeadedAttention,
+    LegacyRelPositionMultiHeadedAttention,
     RotaryPositionMultiHeadedAttention,
 )
 from fairseq.modules.fairseq_dropout import FairseqDropout
 from torch import Tensor
-from fairseq.modules.activations import get_activation_fn, get_activation_class
+from fairseq.modules.activations import get_activation_class


 class FeedForwardModule(torch.nn.Module):
@@ -160,8 +161,8 @@ class S2TTransformerEncoderLayer(nn.Module):
             else:
                 print("The maximum encoder relative length %d can not be -1!" % max_relative_length)
                 exit(1)
-        elif self.attn_type == "rel_pos":
-            return RelPositionMultiHeadedAttention(
+        elif self.attn_type in ["rel_pos", "rel_pos_legacy"]:
+            return LegacyRelPositionMultiHeadedAttention(
                 embed_dim,
                 attention_heads,
                 dropout=dropout,
@@ -225,7 +226,7 @@ class S2TTransformerEncoderLayer(nn.Module):
             `attn_mask[tgt_i, src_j] = 1` means that when calculating the
             embedding for `tgt_i`, we exclude (mask out) `src_j`. This is
             useful for strided self-attention.
-            positions (Tensor): the position embedding for relative position encoding
+            pos_emb (Tensor): the position embedding for relative position encoding
         Returns:
             encoded output of shape `(seq_len, batch, embed_dim)`
@@ -251,7 +252,7 @@ class S2TTransformerEncoderLayer(nn.Module):
         residual = x
         if self.normalize_before:
             x = self.self_attn_layer_norm(x)
-        if self.attn_type == "rel_selfattn" or self.attn_type == "rel_pos":
+        if self.attn_type in ["rel_pos", "rel_pos_legacy", "rel_selfattn"]:
             assert pos_emb is not None, "Positions is necessary for RPE!"
             x, _ = self.self_attn(
                 query=x,
...