Skip to content
项目
群组
代码片段
帮助
当前项目
正在载入...
登录 / 注册
切换导航面板
F
Fairseq-S2T
概览
Overview
Details
Activity
Cycle Analytics
版本库
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Charts
问题
0
Issues
0
列表
Board
标记
里程碑
合并请求
0
Merge Requests
0
CI / CD
CI / CD
流水线
作业
日程表
图表
维基
Wiki
代码片段
Snippets
成员
Collapse sidebar
Close sidebar
活动
图像
聊天
创建新问题
作业
提交
Issue Boards
Open sidebar
xuchen
Fairseq-S2T
Commits
e96e141d
Commit
e96e141d
authored
Feb 22, 2022
by
xuchen
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
fix the bug of the relative position encoding.
parent
ebd1be88
显示空白字符变更
内嵌
并排
正在显示
19 个修改的文件
包含
170 行增加
和
42 行删除
+170
-42
egs/libri_trans/asr/conf/base.yaml
+6
-1
egs/libri_trans/asr/conf/basis.yaml
+0
-0
egs/libri_trans/asr/conf/conformer.yaml
+2
-0
egs/libri_trans/asr/conf/debug.yaml
+1
-1
egs/libri_trans/asr/conf/purectc.yaml
+14
-8
egs/libri_trans/asr/conf/rpr.yaml
+3
-1
egs/libri_trans/mt/conf/base.yaml
+1
-0
egs/libri_trans/st/conf/base.yaml
+10
-3
egs/libri_trans/st/conf/conformer.yaml
+3
-0
egs/libri_trans/st/conf/ctc.yaml
+3
-0
egs/libri_trans/st/conf/sate_ctc.yaml
+9
-2
egs/librispeech/asr/conf/ConformerCTCSmall.yaml
+7
-6
fairseq/models/speech_to_text/s2t_ctc.py
+8
-3
fairseq/models/speech_to_text/s2t_transformer.py
+8
-3
fairseq/modules/__init__.py
+11
-6
fairseq/modules/espnet_multihead_attention.py
+38
-1
fairseq/modules/positional_encoding.py
+39
-2
fairseq/modules/rel_position_multihead_attention.py
+1
-0
fairseq/modules/s2t_transformer_layer.py
+6
-5
没有找到文件。
egs/libri_trans/asr/conf/base.yaml
查看文件 @
e96e141d
...
@@ -5,8 +5,9 @@ clip-norm: 10.0
...
@@ -5,8 +5,9 @@ clip-norm: 10.0
lr-scheduler
:
inverse_sqrt
lr-scheduler
:
inverse_sqrt
warmup-init-lr
:
1e-7
warmup-init-lr
:
1e-7
warmup-updates
:
10000
warmup-updates
:
10000
weight-decay
:
1e-6
lr
:
2e-3
lr
:
2e-3
#
adam_betas: (0.9,0.98)
adam_betas
:
(0.9,0.98)
criterion
:
label_smoothed_cross_entropy_with_ctc
criterion
:
label_smoothed_cross_entropy_with_ctc
label_smoothing
:
0.1
label_smoothing
:
0.1
...
@@ -32,3 +33,6 @@ decoder-ffn-embed-dim: 2048
...
@@ -32,3 +33,6 @@ decoder-ffn-embed-dim: 2048
decoder-attention-heads
:
4
decoder-attention-heads
:
4
attention-dropout
:
0.1
attention-dropout
:
0.1
activation-dropout
:
0.1
activation-dropout
:
0.1
#load-pretrained-encoder-from:
#load-pretrained-decoder-from:
\ No newline at end of file
egs/libri_trans/asr/conf/basis.yaml
查看文件 @
e96e141d
egs/libri_trans/asr/conf/conformer.yaml
查看文件 @
e96e141d
...
@@ -2,3 +2,4 @@ macaron-style: True
...
@@ -2,3 +2,4 @@ macaron-style: True
use-cnn-module
:
True
use-cnn-module
:
True
cnn-module-kernel
:
31
cnn-module-kernel
:
31
encoder-attention-type
:
rel_pos
encoder-attention-type
:
rel_pos
encoder-activation-fn
:
swish
\ No newline at end of file
egs/libri_trans/asr/conf/debug.yaml
查看文件 @
e96e141d
...
@@ -37,4 +37,4 @@ macaron-style: True
...
@@ -37,4 +37,4 @@ macaron-style: True
use-cnn-module
:
True
use-cnn-module
:
True
cnn-module-kernel
:
31
cnn-module-kernel
:
31
encoder-activation-fn
:
swish
encoder-activation-fn
:
swish
encoder-attention-type
:
rel_pos
encoder-attention-type
:
rel_pos
_legacy
egs/libri_trans/asr/conf/purectc.yaml
查看文件 @
e96e141d
...
@@ -4,22 +4,28 @@ clip-norm: 10.0
...
@@ -4,22 +4,28 @@ clip-norm: 10.0
lr-scheduler
:
inverse_sqrt
lr-scheduler
:
inverse_sqrt
warmup-init-lr
:
1e-7
warmup-init-lr
:
1e-7
warmup-updates
:
10000
warmup-updates
:
10000
weight-decay
:
1e-6
lr
:
2e-3
lr
:
2e-3
#
adam_betas: (0.9,0.98)
adam_betas
:
(0.9,0.98)
criterion
:
ctc
criterion
:
ctc
zero_infinity
:
True
zero_infinity
:
True
post-process
:
sentencepiece
post-process
:
sentencepiece
label_smoothing
:
0.1
conv-kernel-sizes
:
5,5
subsampling-type
:
conv1d
conv-channels
:
1024
subsmapling-layers
:
2
subsampling-filter
:
1024
subsampling-kernel
:
5
subsampling-stride
:
2
subsampling-norm
:
none
subsampling-activation
:
glu
dropout
:
0.1
dropout
:
0.1
attention-dropout
:
0.1
activation-dropout
:
0.1
activation-fn
:
relu
activation-fn
:
relu
encoder-embed-dim
:
256
encoder-embed-dim
:
256
encoder-ffn-embed-dim
:
2048
encoder-ffn-embed-dim
:
2048
encoder-layers
:
12
encoder-layers
:
12
encoder-attention-heads
:
4
encoder-attention-heads
:
4
\ No newline at end of file
attention-dropout
:
0.1
activation-dropout
:
0.1
\ No newline at end of file
egs/libri_trans/asr/conf/rpr.yaml
查看文件 @
e96e141d
encoder-attention-type
:
rel_
pos
encoder-attention-type
:
rel_
selfattn
#encoder-attention-type: relative
#encoder-attention-type: relative
#decoder-attention-type: relative
#max-encoder-relative-length: 100
#max-encoder-relative-length: 100
#max-decoder-relative-length: 20
egs/libri_trans/mt/conf/base.yaml
查看文件 @
e96e141d
...
@@ -5,6 +5,7 @@ clip-norm: 10.0
...
@@ -5,6 +5,7 @@ clip-norm: 10.0
lr-scheduler
:
inverse_sqrt
lr-scheduler
:
inverse_sqrt
warmup-init-lr
:
1e-7
warmup-init-lr
:
1e-7
warmup-updates
:
8000
warmup-updates
:
8000
weight-decay
:
1e-6
lr
:
1e-3
lr
:
1e-3
adam_betas
:
(0.9,0.997)
adam_betas
:
(0.9,0.997)
...
...
egs/libri_trans/st/conf/base.yaml
查看文件 @
e96e141d
...
@@ -5,14 +5,21 @@ clip-norm: 10.0
...
@@ -5,14 +5,21 @@ clip-norm: 10.0
lr-scheduler
:
inverse_sqrt
lr-scheduler
:
inverse_sqrt
warmup-init-lr
:
1e-7
warmup-init-lr
:
1e-7
warmup-updates
:
10000
warmup-updates
:
10000
weight-decay
:
1e-6
lr
:
2e-3
lr
:
2e-3
#
adam_betas: (0.9,0.98)
adam_betas
:
(0.9,0.98)
criterion
:
label_smoothed_cross_entropy_with_ctc
criterion
:
label_smoothed_cross_entropy_with_ctc
label_smoothing
:
0.1
label_smoothing
:
0.1
conv-kernel-sizes
:
5,5
subsampling-type
:
conv1d
conv-channels
:
1024
subsmapling-layers
:
2
subsampling-filter
:
1024
subsampling-kernel
:
5
subsampling-stride
:
2
subsampling-norm
:
none
subsampling-activation
:
glu
dropout
:
0.1
dropout
:
0.1
activation-fn
:
relu
activation-fn
:
relu
encoder-embed-dim
:
256
encoder-embed-dim
:
256
...
...
egs/libri_trans/st/conf/conformer.yaml
查看文件 @
e96e141d
macaron-style
:
True
macaron-style
:
True
use-cnn-module
:
True
use-cnn-module
:
True
cnn-module-kernel
:
31
cnn-module-kernel
:
31
encoder-attention-type
:
rel_pos
encoder-activation-fn
:
swish
\ No newline at end of file
egs/libri_trans/st/conf/ctc.yaml
查看文件 @
e96e141d
ctc-weight
:
0.3
ctc-weight
:
0.3
zero_infinity
:
True
post-process
:
sentencepiece
\ No newline at end of file
egs/libri_trans/st/conf/sate_ctc.yaml
查看文件 @
e96e141d
...
@@ -14,8 +14,15 @@ label_smoothing: 0.1
...
@@ -14,8 +14,15 @@ label_smoothing: 0.1
encoder-normalize-before
:
True
encoder-normalize-before
:
True
decoder-normalize-before
:
True
decoder-normalize-before
:
True
conv-kernel-sizes
:
5,5
conv-channels
:
1024
subsampling-type
:
conv1d
subsmapling-layers
:
2
subsampling-filter
:
1024
subsampling-kernel
:
5
subsampling-stride
:
2
subsampling-norm
:
none
subsampling-activation
:
glu
dropout
:
0.1
dropout
:
0.1
activation-fn
:
relu
activation-fn
:
relu
encoder-embed-dim
:
256
encoder-embed-dim
:
256
...
...
egs/librispeech/asr/conf/ConformerCTCSmall.yaml
查看文件 @
e96e141d
arch
:
s2t_ctc
arch
:
s2t_ctc
optimizer
:
adam
optimizer
:
adam
clip-norm
:
10.0
#
clip-norm: 10.0
lr-scheduler
:
inverse_sqrt
lr-scheduler
:
inverse_sqrt
warmup-init-lr
:
1e-7
warmup-init-lr
:
1e-7
warmup-updates
:
10000
warmup-updates
:
10000
weight-decay
:
1e-6
lr
:
0.0015
lr
:
0.0015
adam_betas
:
(0.9,0.98)
adam_betas
:
(0.9,0.98)
criterion
:
ctc
criterion
:
ctc
post-process
:
sentencepiece
post-process
:
sentencepiece
subsampling-type
:
conv
1
d
subsampling-type
:
conv
2
d
subsmapling-layers
:
2
subsmapling-layers
:
2
subsampling-filter
:
1
024
subsampling-filter
:
1
76
subsampling-kernel
:
5
subsampling-kernel
:
3
subsampling-stride
:
2
subsampling-stride
:
2
subsampling-norm
:
none
subsampling-norm
:
batch2d
subsampling-activation
:
glu
subsampling-activation
:
swish
dropout
:
0.1
dropout
:
0.1
activation-fn
:
relu
activation-fn
:
relu
...
...
fairseq/models/speech_to_text/s2t_ctc.py
查看文件 @
e96e141d
...
@@ -19,6 +19,8 @@ from fairseq.modules import (
...
@@ -19,6 +19,8 @@ from fairseq.modules import (
FairseqDropout
,
FairseqDropout
,
LayerNorm
,
LayerNorm
,
PositionalEmbedding
,
PositionalEmbedding
,
PositionalEncoding
,
LegacyRelPositionalEncoding
,
RelPositionalEncoding
,
RelPositionalEncoding
,
S2TTransformerEncoderLayer
,
S2TTransformerEncoderLayer
,
DynamicLinearCombination
,
DynamicLinearCombination
,
...
@@ -462,6 +464,10 @@ class S2TCTCEncoder(FairseqEncoder):
...
@@ -462,6 +464,10 @@ class S2TCTCEncoder(FairseqEncoder):
self
.
embed_positions
=
RelPositionalEncoding
(
self
.
embed_positions
=
RelPositionalEncoding
(
args
.
max_source_positions
,
args
.
encoder_embed_dim
args
.
max_source_positions
,
args
.
encoder_embed_dim
)
)
elif
self
.
attn_type
==
"rel_selfattn"
:
self
.
embed_positions
=
LegacyRelPositionalEncoding
(
args
.
encoder_embed_dim
,
args
.
dropout
,
args
.
max_source_positions
)
elif
self
.
attn_type
==
"rope"
:
elif
self
.
attn_type
==
"rope"
:
self
.
embed_positions
=
None
self
.
embed_positions
=
None
else
:
# Use absolute positional embedding
else
:
# Use absolute positional embedding
...
@@ -554,7 +560,7 @@ class S2TCTCEncoder(FairseqEncoder):
...
@@ -554,7 +560,7 @@ class S2TCTCEncoder(FairseqEncoder):
# padding and position embedding
# padding and position embedding
encoder_padding_mask
=
lengths_to_padding_mask
(
input_lengths
)
encoder_padding_mask
=
lengths_to_padding_mask
(
input_lengths
)
if
self
.
attn_type
==
"rel_pos"
:
if
self
.
attn_type
==
"rel_pos"
or
self
.
attn_type
==
"rel_selfattn"
:
positions
=
self
.
embed_positions
(
x
)
positions
=
self
.
embed_positions
(
x
)
elif
self
.
attn_type
==
"rope"
:
elif
self
.
attn_type
==
"rope"
:
...
@@ -566,7 +572,6 @@ class S2TCTCEncoder(FairseqEncoder):
...
@@ -566,7 +572,6 @@ class S2TCTCEncoder(FairseqEncoder):
positions
=
None
positions
=
None
x
=
self
.
dropout_module
(
x
)
x
=
self
.
dropout_module
(
x
)
# positions = self.dropout_module(positions)
# add emb into history
# add emb into history
if
self
.
history
is
not
None
:
if
self
.
history
is
not
None
:
...
@@ -698,7 +703,7 @@ class CTCDecoder(object):
...
@@ -698,7 +703,7 @@ class CTCDecoder(object):
model_path
=
self
.
lm_model
,
model_path
=
self
.
lm_model
,
alpha
=
self
.
lm_weight
,
alpha
=
self
.
lm_weight
,
beta
=
0
,
beta
=
0
,
cutoff_top_n
=
40
,
cutoff_top_n
=
self
.
vocab_size
,
cutoff_prob
=
1.0
,
cutoff_prob
=
1.0
,
beam_width
=
self
.
beam_size
,
beam_width
=
self
.
beam_size
,
num_processes
=
20
,
num_processes
=
20
,
...
...
fairseq/models/speech_to_text/s2t_transformer.py
查看文件 @
e96e141d
...
@@ -19,6 +19,7 @@ from fairseq.modules import (
...
@@ -19,6 +19,7 @@ from fairseq.modules import (
FairseqDropout
,
FairseqDropout
,
LayerNorm
,
LayerNorm
,
PositionalEmbedding
,
PositionalEmbedding
,
LegacyRelPositionalEncoding
,
RelPositionalEncoding
,
RelPositionalEncoding
,
S2TTransformerEncoderLayer
,
S2TTransformerEncoderLayer
,
DynamicLinearCombination
,
DynamicLinearCombination
,
...
@@ -133,9 +134,10 @@ class S2TTransformerModel(FairseqEncoderDecoderModel):
...
@@ -133,9 +134,10 @@ class S2TTransformerModel(FairseqEncoderDecoderModel):
"reduced"
,
"reduced"
,
"rel_selfattn"
,
"rel_selfattn"
,
"relative"
,
"relative"
,
"rel_pos_legacy"
,
"rel_pos"
,
"rel_pos"
,
"rope"
,
"rope"
,
"abs"
"abs"
,
],
],
help
=
"transformer encoder self-attention layer type"
help
=
"transformer encoder self-attention layer type"
)
)
...
@@ -499,6 +501,10 @@ class S2TTransformerEncoder(FairseqEncoder):
...
@@ -499,6 +501,10 @@ class S2TTransformerEncoder(FairseqEncoder):
self
.
embed_positions
=
RelPositionalEncoding
(
self
.
embed_positions
=
RelPositionalEncoding
(
args
.
max_source_positions
,
args
.
encoder_embed_dim
args
.
max_source_positions
,
args
.
encoder_embed_dim
)
)
elif
self
.
attn_type
in
[
"rel_selfattn"
,
"rel_pos_legacy"
]:
self
.
embed_positions
=
LegacyRelPositionalEncoding
(
args
.
encoder_embed_dim
,
args
.
dropout
,
args
.
max_source_positions
)
elif
self
.
attn_type
==
"rope"
:
elif
self
.
attn_type
==
"rope"
:
self
.
embed_positions
=
None
self
.
embed_positions
=
None
else
:
# Use absolute positional embedding
else
:
# Use absolute positional embedding
...
@@ -614,7 +620,7 @@ class S2TTransformerEncoder(FairseqEncoder):
...
@@ -614,7 +620,7 @@ class S2TTransformerEncoder(FairseqEncoder):
# padding and position embedding
# padding and position embedding
encoder_padding_mask
=
lengths_to_padding_mask
(
input_lengths
)
encoder_padding_mask
=
lengths_to_padding_mask
(
input_lengths
)
if
self
.
attn_type
==
"rel_pos"
:
if
self
.
attn_type
in
[
"rel_pos"
,
"rel_pos_legacy"
,
"rel_selfattn"
]
:
positions
=
self
.
embed_positions
(
x
)
positions
=
self
.
embed_positions
(
x
)
elif
self
.
attn_type
==
"rope"
:
elif
self
.
attn_type
==
"rope"
:
...
@@ -626,7 +632,6 @@ class S2TTransformerEncoder(FairseqEncoder):
...
@@ -626,7 +632,6 @@ class S2TTransformerEncoder(FairseqEncoder):
positions
=
None
positions
=
None
x
=
self
.
dropout_module
(
x
)
x
=
self
.
dropout_module
(
x
)
# positions = self.dropout_module(positions)
# add emb into history
# add emb into history
if
self
.
history
is
not
None
:
if
self
.
history
is
not
None
:
...
...
fairseq/modules/__init__.py
查看文件 @
e96e141d
...
@@ -44,15 +44,18 @@ from .transpose_last import TransposeLast
...
@@ -44,15 +44,18 @@ from .transpose_last import TransposeLast
from
.unfold
import
unfold1d
from
.unfold
import
unfold1d
from
.transformer_layer
import
TransformerDecoderLayer
,
TransformerEncoderLayer
from
.transformer_layer
import
TransformerDecoderLayer
,
TransformerEncoderLayer
from
.vggblock
import
VGGBlock
from
.vggblock
import
VGGBlock
from
.rotary_positional_embedding
import
RotaryPositionalEmbedding
from
.positional_encoding
import
(
PositionalEncoding
,
LegacyRelPositionalEncoding
,
RelPositionalEncoding
,
)
from
.espnet_multihead_attention
import
(
from
.espnet_multihead_attention
import
(
ESPNETMultiHeadedAttention
,
ESPNETMultiHeadedAttention
,
RelPositionMultiHeadedAttention
,
RelPositionMultiHeadedAttention
,
LegacyRelPositionMultiHeadedAttention
,
RotaryPositionMultiHeadedAttention
,
RotaryPositionMultiHeadedAttention
,
)
)
from
.rotary_positional_embedding
import
RotaryPositionalEmbedding
from
.positional_encoding
import
(
RelPositionalEncoding
,
)
from
.convolution
import
ConvolutionModule
from
.convolution
import
ConvolutionModule
from
.s2t_transformer_layer
import
S2TTransformerEncoderLayer
from
.s2t_transformer_layer
import
S2TTransformerEncoderLayer
from
.pds_layer
import
PDSTransformerEncoderLayer
from
.pds_layer
import
PDSTransformerEncoderLayer
...
@@ -87,7 +90,6 @@ __all__ = [
...
@@ -87,7 +90,6 @@ __all__ = [
"LightweightConv"
,
"LightweightConv"
,
"LinearizedConvolution"
,
"LinearizedConvolution"
,
"LocalMultiheadAttention"
,
"LocalMultiheadAttention"
,
"MultiheadAttention"
,
"MultiheadAttention"
,
"PositionalEmbedding"
,
"PositionalEmbedding"
,
"PDSTransformerEncoderLayer"
,
"PDSTransformerEncoderLayer"
,
...
@@ -107,10 +109,13 @@ __all__ = [
...
@@ -107,10 +109,13 @@ __all__ = [
"TransposeLast"
,
"TransposeLast"
,
"VGGBlock"
,
"VGGBlock"
,
"unfold1d"
,
"unfold1d"
,
"ESPNETMulti
h
eadedAttention"
,
"ESPNETMulti
H
eadedAttention"
,
"PositionalEmbedding"
,
"PositionalEmbedding"
,
"RelPositionMultiHeadedAttention"
,
"RelPositionMultiHeadedAttention"
,
"PositionalEncoding"
,
"LegacyRelPositionalEncoding"
,
"RelPositionalEncoding"
,
"RelPositionalEncoding"
,
"LegacyRelPositionMultiHeadedAttention"
,
"RotaryPositionalEmbedding"
,
"RotaryPositionalEmbedding"
,
"RotaryPositionMultiHeadedAttention"
,
"RotaryPositionMultiHeadedAttention"
,
]
]
fairseq/modules/espnet_multihead_attention.py
查看文件 @
e96e141d
...
@@ -156,7 +156,7 @@ class RelPositionMultiHeadedAttention(ESPNETMultiHeadedAttention):
...
@@ -156,7 +156,7 @@ class RelPositionMultiHeadedAttention(ESPNETMultiHeadedAttention):
query: Query tensor T X B X C
query: Query tensor T X B X C
key: Key tensor T X B X C
key: Key tensor T X B X C
value: Value tensor T X B X C
value: Value tensor T X B X C
pos_emb: Positional embedding tensor
B X 2T-1
X C
pos_emb: Positional embedding tensor
2T-1 X B(1)
X C
key_padding_mask: Mask tensor T X B
key_padding_mask: Mask tensor T X B
Returns:
Returns:
torch.Tensor: Output tensor T X B X C.
torch.Tensor: Output tensor T X B X C.
...
@@ -196,6 +196,43 @@ class RelPositionMultiHeadedAttention(ESPNETMultiHeadedAttention):
...
@@ -196,6 +196,43 @@ class RelPositionMultiHeadedAttention(ESPNETMultiHeadedAttention):
return
scores
,
None
return
scores
,
None
class
LegacyRelPositionMultiHeadedAttention
(
RelPositionMultiHeadedAttention
):
"""Multi-Head Attention layer with relative position encoding (old version).
Details can be found in https://github.com/espnet/espnet/pull/2816.
Paper: https://arxiv.org/abs/1901.02860
Args:
n_head (int): The number of heads.
n_feat (int): The number of features.
dropout_rate (float): Dropout rate.
zero_triu (bool): Whether to zero the upper triangular part of attention matrix.
"""
def
__init__
(
self
,
n_feat
,
n_head
,
dropout
,
zero_triu
=
False
):
"""Construct an RelPositionMultiHeadedAttention object."""
super
()
.
__init__
(
n_feat
,
n_head
,
dropout
,
zero_triu
)
def
rel_shift
(
self
,
x
):
"""Compute relative positional encoding.
Args:
x: Input tensor B X n_head X T X 2T-1
Returns:
torch.Tensor: Output tensor.
"""
zero_pad
=
torch
.
zeros
((
*
x
.
size
()[:
3
],
1
),
device
=
x
.
device
,
dtype
=
x
.
dtype
)
x_padded
=
torch
.
cat
([
zero_pad
,
x
],
dim
=-
1
)
x_padded
=
x_padded
.
view
(
*
x
.
size
()[:
2
],
x
.
size
(
3
)
+
1
,
x
.
size
(
2
))
x
=
x_padded
[:,
:,
1
:]
.
view_as
(
x
)
if
self
.
zero_triu
:
ones
=
torch
.
ones
((
x
.
size
(
2
),
x
.
size
(
3
)),
device
=
x
.
device
)
x
=
x
*
torch
.
tril
(
ones
,
x
.
size
(
3
)
-
x
.
size
(
2
))[
None
,
None
,
:,
:]
return
x
class
RotaryPositionMultiHeadedAttention
(
ESPNETMultiHeadedAttention
):
class
RotaryPositionMultiHeadedAttention
(
ESPNETMultiHeadedAttention
):
def
__init__
(
def
__init__
(
self
,
self
,
...
...
fairseq/modules/positional_encoding.py
查看文件 @
e96e141d
...
@@ -63,13 +63,50 @@ class PositionalEncoding(nn.Module):
...
@@ -63,13 +63,50 @@ class PositionalEncoding(nn.Module):
return
self
.
dropout
(
x
)
return
self
.
dropout
(
x
)
class
LegacyRelPositionalEncoding
(
PositionalEncoding
):
"""Relative positional encoding module (old version).
Details can be found in https://github.com/espnet/espnet/pull/2816.
See : Appendix B in https://arxiv.org/abs/1901.02860
Args:
d_model (int): Embedding dimension.
dropout_rate (float): Dropout rate.
max_len (int): Maximum input length.
"""
def
__init__
(
self
,
d_model
,
dropout_rate
,
max_len
=
5000
):
"""Initialize class."""
super
()
.
__init__
(
d_model
=
d_model
,
dropout_rate
=
dropout_rate
,
max_len
=
max_len
,
reverse
=
True
,
)
def
forward
(
self
,
x
):
"""Add positional encoding.
Args:
x : Input tensor T X B X C.
Returns:
torch.Tensor: Encoded tensor T X B X C.
"""
x
=
x
.
transpose
(
0
,
1
)
# Change TBC to BTC
self
.
extend_pe
(
x
)
pos_emb
=
self
.
pe
[:,
:
x
.
size
(
1
)]
pos_emb
=
pos_emb
.
transpose
(
0
,
1
)
# change to TBC
return
self
.
dropout
(
pos_emb
)
class
RelPositionalEncoding
(
nn
.
Module
):
class
RelPositionalEncoding
(
nn
.
Module
):
"""Relative positional encoding module (new implementation).
"""Relative positional encoding module (new implementation).
Args:
Args:
d_model: Embedding dimension.
dropout_rate: Dropout rate.
max_len: Maximum input length.
max_len: Maximum input length.
d_model: Embedding dimension.
"""
"""
def
__init__
(
self
,
max_len
,
d_model
):
def
__init__
(
self
,
max_len
,
d_model
):
...
...
fairseq/modules/rel_position_multihead_attention.py
查看文件 @
e96e141d
...
@@ -247,6 +247,7 @@ class RelPositionMultiheadAttention(MultiheadAttention):
...
@@ -247,6 +247,7 @@ class RelPositionMultiheadAttention(MultiheadAttention):
dim
=
1
,
dim
=
1
,
)
)
pos_emb
=
pos_emb
.
repeat
(
bsz
,
1
,
1
)
pos_emb
=
pos_emb
.
transpose
(
0
,
1
)
pos_emb
=
pos_emb
.
transpose
(
0
,
1
)
p
=
self
.
linear_pos
(
pos_emb
)
.
view
(
bsz
,
-
1
,
self
.
num_heads
,
self
.
head_dim
)
p
=
self
.
linear_pos
(
pos_emb
)
.
view
(
bsz
,
-
1
,
self
.
num_heads
,
self
.
head_dim
)
# p (bsz * num_heads, tgt_len, head_dim)
# p (bsz * num_heads, tgt_len, head_dim)
...
...
fairseq/modules/s2t_transformer_layer.py
查看文件 @
e96e141d
...
@@ -15,11 +15,12 @@ from fairseq.modules import (
...
@@ -15,11 +15,12 @@ from fairseq.modules import (
ConvolutionModule
,
ConvolutionModule
,
ESPNETMultiHeadedAttention
,
ESPNETMultiHeadedAttention
,
RelPositionMultiHeadedAttention
,
RelPositionMultiHeadedAttention
,
LegacyRelPositionMultiHeadedAttention
,
RotaryPositionMultiHeadedAttention
,
RotaryPositionMultiHeadedAttention
,
)
)
from
fairseq.modules.fairseq_dropout
import
FairseqDropout
from
fairseq.modules.fairseq_dropout
import
FairseqDropout
from
torch
import
Tensor
from
torch
import
Tensor
from
fairseq.modules.activations
import
get_activation_
fn
,
get_activation_
class
from
fairseq.modules.activations
import
get_activation_class
class
FeedForwardModule
(
torch
.
nn
.
Module
):
class
FeedForwardModule
(
torch
.
nn
.
Module
):
...
@@ -160,8 +161,8 @@ class S2TTransformerEncoderLayer(nn.Module):
...
@@ -160,8 +161,8 @@ class S2TTransformerEncoderLayer(nn.Module):
else
:
else
:
print
(
"The maximum encoder relative length
%
d can not be -1!"
%
max_relative_length
)
print
(
"The maximum encoder relative length
%
d can not be -1!"
%
max_relative_length
)
exit
(
1
)
exit
(
1
)
elif
self
.
attn_type
==
"rel_pos"
:
elif
self
.
attn_type
in
[
"rel_pos"
,
"rel_pos_legacy"
]
:
return
RelPositionMultiHeadedAttention
(
return
Legacy
RelPositionMultiHeadedAttention
(
embed_dim
,
embed_dim
,
attention_heads
,
attention_heads
,
dropout
=
dropout
,
dropout
=
dropout
,
...
@@ -225,7 +226,7 @@ class S2TTransformerEncoderLayer(nn.Module):
...
@@ -225,7 +226,7 @@ class S2TTransformerEncoderLayer(nn.Module):
`attn_mask[tgt_i, src_j] = 1` means that when calculating the
`attn_mask[tgt_i, src_j] = 1` means that when calculating the
embedding for `tgt_i`, we exclude (mask out) `src_j`. This is
embedding for `tgt_i`, we exclude (mask out) `src_j`. This is
useful for strided self-attention.
useful for strided self-attention.
pos
itions
(Tensor): the position embedding for relative position encoding
pos
_emb
(Tensor): the position embedding for relative position encoding
Returns:
Returns:
encoded output of shape `(seq_len, batch, embed_dim)`
encoded output of shape `(seq_len, batch, embed_dim)`
...
@@ -251,7 +252,7 @@ class S2TTransformerEncoderLayer(nn.Module):
...
@@ -251,7 +252,7 @@ class S2TTransformerEncoderLayer(nn.Module):
residual
=
x
residual
=
x
if
self
.
normalize_before
:
if
self
.
normalize_before
:
x
=
self
.
self_attn_layer_norm
(
x
)
x
=
self
.
self_attn_layer_norm
(
x
)
if
self
.
attn_type
==
"rel_selfattn"
or
self
.
attn_type
==
"rel_pos"
:
if
self
.
attn_type
in
[
"rel_pos"
,
"rel_pos_legacy"
,
"rel_selfattn"
]
:
assert
pos_emb
is
not
None
,
"Positions is necessary for RPE!"
assert
pos_emb
is
not
None
,
"Positions is necessary for RPE!"
x
,
_
=
self
.
self_attn
(
x
,
_
=
self
.
self_attn
(
query
=
x
,
query
=
x
,
...
...
编写
预览
Markdown
格式
0%
重试
或
添加新文件
添加附件
取消
您添加了
0
人
到此讨论。请谨慎行事。
请先完成此评论的编辑!
取消
请
注册
或者
登录
后发表评论