Commit e96e141d by xuchen

Fix a bug in the relative position encoding.

parent ebd1be88
@@ -5,8 +5,9 @@ clip-norm: 10.0
lr-scheduler: inverse_sqrt
warmup-init-lr: 1e-7
warmup-updates: 10000
weight-decay: 1e-6
lr: 2e-3
#adam_betas: (0.9,0.98)
adam_betas: (0.9,0.98)
criterion: label_smoothed_cross_entropy_with_ctc
label_smoothing: 0.1
@@ -32,3 +33,6 @@ decoder-ffn-embed-dim: 2048
decoder-attention-heads: 4
attention-dropout: 0.1
activation-dropout: 0.1
#load-pretrained-encoder-from:
#load-pretrained-decoder-from:
\ No newline at end of file
@@ -2,3 +2,4 @@ macaron-style: True
use-cnn-module: True
cnn-module-kernel: 31
encoder-attention-type: rel_pos
encoder-activation-fn: swish
\ No newline at end of file
@@ -37,4 +37,4 @@ macaron-style: True
use-cnn-module: True
cnn-module-kernel: 31
encoder-activation-fn: swish
encoder-attention-type: rel_pos
encoder-attention-type: rel_pos_legacy
@@ -4,22 +4,28 @@ clip-norm: 10.0
lr-scheduler: inverse_sqrt
warmup-init-lr: 1e-7
warmup-updates: 10000
weight-decay: 1e-6
lr: 2e-3
#adam_betas: (0.9,0.98)
adam_betas: (0.9,0.98)
criterion: ctc
zero_infinity: True
post-process: sentencepiece
label_smoothing: 0.1
conv-kernel-sizes: 5,5
conv-channels: 1024
subsampling-type: conv1d
subsmapling-layers: 2
subsampling-filter: 1024
subsampling-kernel: 5
subsampling-stride: 2
subsampling-norm: none
subsampling-activation: glu
dropout: 0.1
attention-dropout: 0.1
activation-dropout: 0.1
activation-fn: relu
encoder-embed-dim: 256
encoder-ffn-embed-dim: 2048
encoder-layers: 12
encoder-attention-heads: 4
\ No newline at end of file
attention-dropout: 0.1
activation-dropout: 0.1
\ No newline at end of file
encoder-attention-type: rel_pos
encoder-attention-type: rel_selfattn
#encoder-attention-type: relative
#decoder-attention-type: relative
#max-encoder-relative-length: 100
#max-decoder-relative-length: 20
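Note: the subsampling-* block above replaces the older conv-kernel-sizes / conv-channels pair. As a rough, hypothetical illustration of what these options describe (a plain stack of strided Conv1d layers with GLU, not the repository's actual subsampler), two conv1d layers with kernel 5 and stride 2 reduce the frame axis by roughly a factor of 4 while projecting to encoder-embed-dim:

```python
import torch
import torch.nn as nn

# Hypothetical sketch of the subsampling-* options above: 2 x Conv1d,
# kernel 5, stride 2, GLU activation, 1024 intermediate filters, no norm.
# Illustration only; the repository's subsampler may be organized differently.
class Conv1dSubsamplerSketch(nn.Module):
    def __init__(self, in_dim=80, filters=1024, out_dim=256,
                 layers=2, kernel=5, stride=2):
        super().__init__()
        dims = [in_dim] + [filters] * (layers - 1) + [out_dim]
        convs = []
        for i in range(layers):
            # GLU halves the channel dimension, so each conv emits 2x channels.
            convs.append(nn.Conv1d(dims[i], dims[i + 1] * 2, kernel,
                                   stride=stride, padding=kernel // 2))
            convs.append(nn.GLU(dim=1))
        self.convs = nn.Sequential(*convs)

    def forward(self, x):          # x: (batch, frames, features)
        x = x.transpose(1, 2)      # -> (batch, features, frames)
        x = self.convs(x)          # every stride-2 layer halves the frame axis
        return x.transpose(1, 2)   # -> (batch, ~frames / 4, out_dim)

feats = torch.randn(8, 100, 80)               # 100 input frames
print(Conv1dSubsamplerSketch()(feats).shape)  # torch.Size([8, 25, 256])
```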
@@ -5,6 +5,7 @@ clip-norm: 10.0
lr-scheduler: inverse_sqrt
warmup-init-lr: 1e-7
warmup-updates: 8000
weight-decay: 1e-6
lr: 1e-3
adam_betas: (0.9,0.997)
......
@@ -5,14 +5,21 @@ clip-norm: 10.0
lr-scheduler: inverse_sqrt
warmup-init-lr: 1e-7
warmup-updates: 10000
weight-decay: 1e-6
lr: 2e-3
#adam_betas: (0.9,0.98)
adam_betas: (0.9,0.98)
criterion: label_smoothed_cross_entropy_with_ctc
label_smoothing: 0.1
conv-kernel-sizes: 5,5
conv-channels: 1024
subsampling-type: conv1d
subsmapling-layers: 2
subsampling-filter: 1024
subsampling-kernel: 5
subsampling-stride: 2
subsampling-norm: none
subsampling-activation: glu
dropout: 0.1
activation-fn: relu
encoder-embed-dim: 256
......
macaron-style: True
use-cnn-module: True
cnn-module-kernel: 31
encoder-attention-type: rel_pos
encoder-activation-fn: swish
\ No newline at end of file
ctc-weight: 0.3
zero_infinity: True
post-process: sentencepiece
\ No newline at end of file
@@ -14,8 +14,15 @@ label_smoothing: 0.1
encoder-normalize-before: True
decoder-normalize-before: True
conv-kernel-sizes: 5,5
conv-channels: 1024
subsampling-type: conv1d
subsmapling-layers: 2
subsampling-filter: 1024
subsampling-kernel: 5
subsampling-stride: 2
subsampling-norm: none
subsampling-activation: glu
dropout: 0.1
activation-fn: relu
encoder-embed-dim: 256
......
arch: s2t_ctc
optimizer: adam
clip-norm: 10.0
#clip-norm: 10.0
lr-scheduler: inverse_sqrt
warmup-init-lr: 1e-7
warmup-updates: 10000
weight-decay: 1e-6
lr: 0.0015
adam_betas: (0.9,0.98)
criterion: ctc
post-process: sentencepiece
subsampling-type: conv1d
subsampling-type: conv2d
subsmapling-layers: 2
subsampling-filter: 1024
subsampling-kernel: 5
subsampling-filter: 176
subsampling-kernel: 3
subsampling-stride: 2
subsampling-norm: none
subsampling-activation: glu
subsampling-norm: batch2d
subsampling-activation: swish
dropout: 0.1
activation-fn: relu
......
@@ -19,6 +19,8 @@ from fairseq.modules import (
FairseqDropout,
LayerNorm,
PositionalEmbedding,
PositionalEncoding,
LegacyRelPositionalEncoding,
RelPositionalEncoding,
S2TTransformerEncoderLayer,
DynamicLinearCombination,
@@ -462,6 +464,10 @@ class S2TCTCEncoder(FairseqEncoder):
self.embed_positions = RelPositionalEncoding(
args.max_source_positions, args.encoder_embed_dim
)
elif self.attn_type == "rel_selfattn":
self.embed_positions = LegacyRelPositionalEncoding(
args.encoder_embed_dim, args.dropout, args.max_source_positions
)
elif self.attn_type == "rope":
self.embed_positions = None
else: # Use absolute positional embedding
@@ -554,7 +560,7 @@ class S2TCTCEncoder(FairseqEncoder):
# padding and position embedding
encoder_padding_mask = lengths_to_padding_mask(input_lengths)
if self.attn_type == "rel_pos":
if self.attn_type == "rel_pos" or self.attn_type == "rel_selfattn":
positions = self.embed_positions(x)
elif self.attn_type == "rope":
@@ -566,7 +572,6 @@ class S2TCTCEncoder(FairseqEncoder):
positions = None
x = self.dropout_module(x)
# positions = self.dropout_module(positions)
# add emb into history
if self.history is not None:
@@ -698,7 +703,7 @@ class CTCDecoder(object):
model_path=self.lm_model,
alpha=self.lm_weight,
beta=0,
cutoff_top_n=40,
cutoff_top_n=self.vocab_size,
cutoff_prob=1.0,
beam_width=self.beam_size,
num_processes=20,
......
@@ -19,6 +19,7 @@ from fairseq.modules import (
FairseqDropout,
LayerNorm,
PositionalEmbedding,
LegacyRelPositionalEncoding,
RelPositionalEncoding,
S2TTransformerEncoderLayer,
DynamicLinearCombination,
@@ -133,9 +134,10 @@ class S2TTransformerModel(FairseqEncoderDecoderModel):
"reduced",
"rel_selfattn",
"relative",
"rel_pos_legacy",
"rel_pos",
"rope",
"abs"
"abs",
],
help="transformer encoder self-attention layer type"
)
@@ -499,6 +501,10 @@ class S2TTransformerEncoder(FairseqEncoder):
self.embed_positions = RelPositionalEncoding(
args.max_source_positions, args.encoder_embed_dim
)
elif self.attn_type in ["rel_selfattn", "rel_pos_legacy"]:
self.embed_positions = LegacyRelPositionalEncoding(
args.encoder_embed_dim, args.dropout, args.max_source_positions
)
elif self.attn_type == "rope":
self.embed_positions = None
else: # Use absolute positional embedding
@@ -614,7 +620,7 @@ class S2TTransformerEncoder(FairseqEncoder):
# padding and position embedding
encoder_padding_mask = lengths_to_padding_mask(input_lengths)
if self.attn_type == "rel_pos":
if self.attn_type in ["rel_pos", "rel_pos_legacy", "rel_selfattn"]:
positions = self.embed_positions(x)
elif self.attn_type == "rope":
@@ -626,7 +632,6 @@ class S2TTransformerEncoder(FairseqEncoder):
positions = None
x = self.dropout_module(x)
# positions = self.dropout_module(positions)
# add emb into history
if self.history is not None:
......
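Both encoder classes now resolve encoder-attention-type to a positional-encoding module in essentially the same way. A condensed, hedged sketch of that branching, following the S2TTransformerEncoder version and using the constructors imported in the hunks above (args is a stand-in namespace, and the padding index passed to PositionalEmbedding is an assumption):

```python
from argparse import Namespace

from fairseq.modules import (
    PositionalEmbedding,
    LegacyRelPositionalEncoding,
    RelPositionalEncoding,
)

def build_encoder_positions(args, attn_type):
    # rel_pos -> new-style relative encoding
    if attn_type == "rel_pos":
        return RelPositionalEncoding(args.max_source_positions, args.encoder_embed_dim)
    # rel_selfattn / rel_pos_legacy -> legacy relative encoding added by this commit
    if attn_type in ["rel_selfattn", "rel_pos_legacy"]:
        return LegacyRelPositionalEncoding(
            args.encoder_embed_dim, args.dropout, args.max_source_positions
        )
    # rope applies rotary positions inside the attention module itself
    if attn_type == "rope":
        return None
    # anything else falls back to absolute positional embeddings
    return PositionalEmbedding(
        args.max_source_positions, args.encoder_embed_dim, padding_idx=1  # padding_idx assumed
    )

args = Namespace(max_source_positions=3000, encoder_embed_dim=256, dropout=0.1)
print(type(build_encoder_positions(args, "rel_pos_legacy")).__name__)
# LegacyRelPositionalEncoding
```

In the forward pass, all three relative variants ("rel_pos", "rel_pos_legacy", "rel_selfattn") then feed the returned positions tensor into the encoder layers, as the conditionals changed in the hunks above show.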
@@ -44,15 +44,18 @@ from .transpose_last import TransposeLast
from .unfold import unfold1d
from .transformer_layer import TransformerDecoderLayer, TransformerEncoderLayer
from .vggblock import VGGBlock
from .rotary_positional_embedding import RotaryPositionalEmbedding
from .positional_encoding import (
PositionalEncoding,
LegacyRelPositionalEncoding,
RelPositionalEncoding,
)
from .espnet_multihead_attention import (
ESPNETMultiHeadedAttention,
RelPositionMultiHeadedAttention,
LegacyRelPositionMultiHeadedAttention,
RotaryPositionMultiHeadedAttention,
)
from .rotary_positional_embedding import RotaryPositionalEmbedding
from .positional_encoding import (
RelPositionalEncoding,
)
from .convolution import ConvolutionModule
from .s2t_transformer_layer import S2TTransformerEncoderLayer
from .pds_layer import PDSTransformerEncoderLayer
@@ -87,7 +90,6 @@ __all__ = [
"LightweightConv",
"LinearizedConvolution",
"LocalMultiheadAttention",
"MultiheadAttention",
"PositionalEmbedding",
"PDSTransformerEncoderLayer",
@@ -107,10 +109,13 @@ __all__ = [
"TransposeLast",
"VGGBlock",
"unfold1d",
"ESPNETMultiheadedAttention",
"ESPNETMultiHeadedAttention",
"PositionalEmbedding",
"RelPositionMultiHeadedAttention",
"PositionalEncoding",
"LegacyRelPositionalEncoding",
"RelPositionalEncoding",
"LegacyRelPositionMultiHeadedAttention",
"RotaryPositionalEmbedding",
"RotaryPositionMultiHeadedAttention",
]
@@ -156,7 +156,7 @@ class RelPositionMultiHeadedAttention(ESPNETMultiHeadedAttention):
query: Query tensor T X B X C
key: Key tensor T X B X C
value: Value tensor T X B X C
pos_emb: Positional embedding tensor B X 2T-1 X C
pos_emb: Positional embedding tensor 2T-1 X B(1) X C
key_padding_mask: Mask tensor T X B
Returns:
torch.Tensor: Output tensor T X B X C.
@@ -196,6 +196,43 @@ class RelPositionMultiHeadedAttention(ESPNETMultiHeadedAttention):
return scores, None
class LegacyRelPositionMultiHeadedAttention(RelPositionMultiHeadedAttention):
"""Multi-Head Attention layer with relative position encoding (old version).
Details can be found in https://github.com/espnet/espnet/pull/2816.
Paper: https://arxiv.org/abs/1901.02860
Args:
n_head (int): The number of heads.
n_feat (int): The number of features.
dropout_rate (float): Dropout rate.
zero_triu (bool): Whether to zero the upper triangular part of attention matrix.
"""
def __init__(self, n_feat, n_head, dropout, zero_triu=False):
"""Construct an RelPositionMultiHeadedAttention object."""
super().__init__(n_feat, n_head, dropout, zero_triu)
def rel_shift(self, x):
"""Compute relative positional encoding.
Args:
x: Input tensor B X n_head X T X 2T-1
Returns:
torch.Tensor: Output tensor.
"""
zero_pad = torch.zeros((*x.size()[:3], 1), device=x.device, dtype=x.dtype)
x_padded = torch.cat([zero_pad, x], dim=-1)
x_padded = x_padded.view(*x.size()[:2], x.size(3) + 1, x.size(2))
x = x_padded[:, :, 1:].view_as(x)
if self.zero_triu:
ones = torch.ones((x.size(2), x.size(3)), device=x.device)
x = x * torch.tril(ones, x.size(3) - x.size(2))[None, None, :, :]
return x
class RotaryPositionMultiHeadedAttention(ESPNETMultiHeadedAttention):
def __init__(
self,
......
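The legacy rel_shift added above is the pad-and-reshape "skew" trick from Transformer-XL: prepending a zero column and then viewing the buffer with the last two dimensions transposed (and one row dropped) shifts each score row one step further left than the row below it. A standalone numerical check of the same operations on a tiny tensor (shapes chosen only for illustration):

```python
import torch

# Reproduce the legacy rel_shift on a tiny (B=1, H=1, T=3, 3) score tensor.
x = torch.arange(1.0, 10.0).view(1, 1, 3, 3)
# [[1., 2., 3.],
#  [4., 5., 6.],
#  [7., 8., 9.]]

zero_pad = torch.zeros((*x.size()[:3], 1), dtype=x.dtype)
x_padded = torch.cat([zero_pad, x], dim=-1)                         # (1, 1, 3, 4)
x_padded = x_padded.view(*x.size()[:2], x.size(3) + 1, x.size(2))   # (1, 1, 4, 3)
shifted = x_padded[:, :, 1:].view_as(x)

print(shifted.squeeze())
# [[3., 0., 4.],
#  [5., 6., 0.],
#  [7., 8., 9.]]
# The bottom row stays in place and each row above it is shifted one step
# further to the left. Entries above the diagonal are artifacts of the shift
# (padding zeros or values wrapped in from the next row), which is exactly
# what zero_triu optionally masks out with torch.tril.
```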
@@ -63,13 +63,50 @@ class PositionalEncoding(nn.Module):
return self.dropout(x)
class LegacyRelPositionalEncoding(PositionalEncoding):
"""Relative positional encoding module (old version).
Details can be found in https://github.com/espnet/espnet/pull/2816.
See : Appendix B in https://arxiv.org/abs/1901.02860
Args:
d_model (int): Embedding dimension.
dropout_rate (float): Dropout rate.
max_len (int): Maximum input length.
"""
def __init__(self, d_model, dropout_rate, max_len=5000):
"""Initialize class."""
super().__init__(
d_model=d_model,
dropout_rate=dropout_rate,
max_len=max_len,
reverse=True,
)
def forward(self, x):
"""Add positional encoding.
Args:
x : Input tensor T X B X C.
Returns:
torch.Tensor: Encoded tensor T X B X C.
"""
x = x.transpose(0, 1) # Change TBC to BTC
self.extend_pe(x)
pos_emb = self.pe[:, : x.size(1)]
pos_emb = pos_emb.transpose(0, 1) # change to TBC
return self.dropout(pos_emb)
class RelPositionalEncoding(nn.Module):
"""Relative positional encoding module (new implementation).
Args:
d_model: Embedding dimension.
dropout_rate: Dropout rate.
max_len: Maximum input length.
d_model: Embedding dimension.
"""
def __init__(self, max_len, d_model):
......
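The practical difference visible in this hunk: LegacyRelPositionalEncoding slices a table with one entry per input position and returns it as T X B(1) X C, while the new RelPositionalEncoding used for rel_pos supplies the 2T-1 X B(1) X C tensor described in the attention docstring above, one entry per relative offset. A hedged sketch of the two table sizes, using a generic sinusoidal table rather than the repository's exact code:

```python
import math
import torch

def sinusoid_table(length, d_model):
    """Generic sinusoidal table of shape (length, d_model)."""
    pos = torch.arange(length, dtype=torch.float32).unsqueeze(1)
    div = torch.exp(torch.arange(0, d_model, 2, dtype=torch.float32)
                    * (-math.log(10000.0) / d_model))
    pe = torch.zeros(length, d_model)
    pe[:, 0::2] = torch.sin(pos * div)
    pe[:, 1::2] = torch.cos(pos * div)
    return pe

T, C = 6, 8
legacy_pos_emb = sinusoid_table(T, C)        # one row per position: (T, C)
new_pos_emb = sinusoid_table(2 * T - 1, C)   # one row per offset -(T-1)..T-1: (2T-1, C)
print(legacy_pos_emb.shape, new_pos_emb.shape)
# torch.Size([6, 8]) torch.Size([11, 8])
```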
@@ -247,6 +247,7 @@ class RelPositionMultiheadAttention(MultiheadAttention):
dim=1,
)
pos_emb = pos_emb.repeat(bsz, 1, 1)
pos_emb = pos_emb.transpose(0, 1)
p = self.linear_pos(pos_emb).view(bsz, -1, self.num_heads, self.head_dim)
# p (bsz * num_heads, tgt_len, head_dim)
......
@@ -15,11 +15,12 @@ from fairseq.modules import (
ConvolutionModule,
ESPNETMultiHeadedAttention,
RelPositionMultiHeadedAttention,
LegacyRelPositionMultiHeadedAttention,
RotaryPositionMultiHeadedAttention,
)
from fairseq.modules.fairseq_dropout import FairseqDropout
from torch import Tensor
from fairseq.modules.activations import get_activation_fn, get_activation_class
from fairseq.modules.activations import get_activation_class
class FeedForwardModule(torch.nn.Module):
@@ -160,8 +161,8 @@ class S2TTransformerEncoderLayer(nn.Module):
else:
print("The maximum encoder relative length %d can not be -1!" % max_relative_length)
exit(1)
elif self.attn_type == "rel_pos":
return RelPositionMultiHeadedAttention(
elif self.attn_type in ["rel_pos", "rel_pos_legacy"]:
return LegacyRelPositionMultiHeadedAttention(
embed_dim,
attention_heads,
dropout=dropout,
@@ -225,7 +226,7 @@ class S2TTransformerEncoderLayer(nn.Module):
`attn_mask[tgt_i, src_j] = 1` means that when calculating the
embedding for `tgt_i`, we exclude (mask out) `src_j`. This is
useful for strided self-attention.
positions (Tensor): the position embedding for relative position encoding
pos_emb (Tensor): the position embedding for relative position encoding
Returns:
encoded output of shape `(seq_len, batch, embed_dim)`
@@ -251,7 +252,7 @@ class S2TTransformerEncoderLayer(nn.Module):
residual = x
if self.normalize_before:
x = self.self_attn_layer_norm(x)
if self.attn_type == "rel_selfattn" or self.attn_type == "rel_pos":
if self.attn_type in ["rel_pos", "rel_pos_legacy", "rel_selfattn"]:
assert pos_emb is not None, "Positions is necessary for RPE!"
x, _ = self.self_attn(
query=x,
......