xuchen / Fairseq-S2T

Commit 67d8695f, authored Mar 04, 2022 by xuchen
Parent: d4255246

add target ctc
Showing 17 changed files with 447 additions and 146 deletions.
- egs/libri_trans/asr/conf/debug.yaml (+35 −36)
- egs/librispeech/asr/conf/purectc_pds_base_8_compare2.yaml (+3 −3)
- egs/mustc/asr/conf/purectc_pds_base_8_grow.yaml (+50 −0)
- egs/wmt20/mt/local/lower_rm.py (+3 −3)
- egs/wmt20/mt/run.sh (+23 −41)
- examples/speech_to_text/prep_audio_data.py (+5 −2)
- examples/speech_to_text/prep_mt_data.py (+79 −26)
- fairseq/criterions/ctc.py (+80 −19)
- fairseq/models/speech_to_text/modules/ctc.py (+1 −0)
- fairseq/models/speech_to_text/pdss2t_transformer.py (+15 −10)
- fairseq/models/speech_to_text/s2t_ctc.py (+2 −1)
- fairseq/models/speech_to_text/s2t_sate.py (+0 −0)
- fairseq/models/speech_to_text/s2t_transformer.py (+8 −1)
- fairseq/modules/__init__.py (+3 −1)
- fairseq/modules/attention.py (+1 −1)
- fairseq/modules/espnet_multihead_attention.py (+127 −0)
- fairseq/modules/pds_layer.py (+12 −2)
egs/libri_trans/asr/conf/debug.yaml (+35 −36)

(The file is rewritten wholesale; the inline view interleaves the old and new contents, which is why keys such as arch, lr, criterion, and encoder-layers appear twice below.)

```yaml
arch: s2t_ctc
encoder-type: pds
#pds-ctc: 0_1_1_0
#intermedia-adapter: league
#intermedia-ctc-weight: 1
#encoder-attention-type: reduced
#pds-attn-ds-ratios: 4_2_1_1
#attention-reduced-method: pool
#attention-reduced-q: True
encoder-embed-dim: 240
pds-stages: 3
#ctc-layer: 15
pds-layers: 4_5_6
pds-ratios: 2_2_2
pds-fusion: False
pds-fusion-method: all_conv
pds-embed-dims: 120_168_240
pds-ds-method: conv
pds-embed-norm: True
pds-position-embed: 1_1_1
pds-kernel-sizes: 3_3_3
pds-ffn-ratios: 4_4_4
pds-attn-heads: 4_4_4

arch: s2t_sate
share-decoder-input-output-embed: True
optimizer: adam
clip-norm: 10.0
lr-scheduler: inverse_sqrt
warmup-init-lr: 1e-7
warmup-updates: 10000
lr: 0.0015
lr: 2e-3
adam_betas: (0.9,0.98)

criterion: ctc
post-process: sentencepiece
ctc-weight: 0.3
target-ctc-weight: 0.2
target-ctc-layers: 3,6
criterion: label_smoothed_cross_entropy_with_ctc
label_smoothing: 0.1

encoder-normalize-before: True
decoder-normalize-before: True

subsampling-type: conv1d
subsampling-layers: 2
subsampling-filter: 1024
subsampling-kernel: 5
subsampling-stride: 2
subsampling-norm: none
subsampling-activation: glu

dropout: 0.1
activation-fn: relu
encoder-layers: 15
encoder-embed-dim: 256
encoder-ffn-embed-dim: 2048
encoder-layers: 12
text-encoder-layers: 6
decoder-layers: 6
encoder-attention-heads: 4
decoder-embed-dim: 256
decoder-ffn-embed-dim: 2048
decoder-attention-heads: 4

macaron-style: True
use-cnn-module: True
cnn-module-kernel: 15
encoder-activation-fn: swish
encoder-attention-type: rel_pos

acoustic-encoder: transformer
adapter: league

#load-pretrained-encoder-from:
#load-pretrained-acoustic-encoder-from:
#load-pretrained-text-encoder-from:
#load-pretrained-decoder-from:
```
\ No newline at end of file
egs/librispeech/asr/conf/purectc_pds_base_8_compare2.yaml (+3 −3)

```diff
@@ -13,15 +13,15 @@ encoder-type: pds
 encoder-embed-dim: 240
 pds-stages: 3
 #ctc-layer: 15
-pds-layers: 4_5_6
+pds-layers: 5_5_5
 pds-ratios: 2_2_2
-pds-fusion: False
+pds-fusion: True
 pds-fusion-method: all_conv
 pds-embed-dims: 120_168_240
 pds-ds-method: conv
 pds-embed-norm: True
 pds-position-embed: 1_1_1
-pds-kernel-sizes: 3_3_3
+pds-kernel-sizes: 5_5_5
 pds-ffn-ratios: 4_4_4
 pds-attn-heads: 4_4_4
```
egs/mustc/asr/conf/purectc_pds_base_8_grow.yaml (new file, mode 100644; +50 −0)

```yaml
arch: s2t_ctc
encoder-type: pds

#pds-ctc: 0_1_1_0
#intermedia-adapter: league
#intermedia-ctc-weight: 1
#intermedia-temperature: 5

encoder-attention-type: rel_pos
#encoder-attention-type: reduced_rel_pos
#pds-attn-ds-ratios: 4_2_2_1
#attention-reduced-method: pool
#attention-reduced-q: True

encoder-embed-dim: 512
pds-stages: 4
#ctc-layer: 15
encoder-layers: 10
pds-layers: 3_2_2_3
pds-ratios: 2_2_1_2
pds-fusion: True
pds-fusion-method: all_conv
pds-embed-dims: 256_384_384_512
pds-ds-method: conv
pds-embed-norm: True
pds-position-embed: 1_1_1_1
pds-kernel-sizes: 5_5_5_5
pds-ffn-ratios: 8_4_4_4
pds-attn-heads: 4_6_6_8

optimizer: adam
clip-norm: 10.0
lr-scheduler: inverse_sqrt
warmup-init-lr: 1e-7
warmup-updates: 10000
lr: 0.002
adam_betas: (0.9,0.98)

criterion: ctc
post-process: sentencepiece

dropout: 0.1
activation-fn: relu

macaron-style: True
use-cnn-module: True
cnn-module-kernel: 31
encoder-activation-fn: swish

#load-pretrained-encoder-from:
```
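Each `_`-separated value in these pds-* options carries one entry per encoder stage (pds-stages: 4 above). For orientation, a minimal sketch of how such strings are expanded, mirroring the `[int(n) for n in args.pds_layers.split("_")]` pattern visible in the pdss2t_transformer.py diff further down:

```python
# Minimal sketch (not from this commit): expanding "_"-separated per-stage
# options into lists, one value per PDS stage.
pds_layers = [int(n) for n in "3_2_2_3".split("_")]  # -> [3, 2, 2, 3]
pds_ratios = [int(n) for n in "2_2_1_2".split("_")]  # -> [2, 2, 1, 2]
assert len(pds_layers) == len(pds_ratios) == 4       # matches pds-stages: 4
```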
egs/wmt20/mt/local/lower_rm.py (+3 −3)

```diff
@@ -8,7 +8,7 @@ with open(in_file, "r", encoding="utf-8") as f:
     for line in f.readlines():
         line = line.strip().lower()
         for w in string.punctuation:
-            line = line.replace(w, "")
-        line = line.replace("  ", "")
+            if w != "'":
+                line = line.replace(w, "")
+        line = line.replace("  ", " ")
         print(line)
```
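With this change the script keeps apostrophes while stripping punctuation, and collapses the leftover double spaces into a single space rather than deleting them. A self-contained sketch of the resulting behavior (the wrapper function is hypothetical; the real script streams lines from a file):

```python
import string

def lower_rm(line: str) -> str:
    # Lowercase, drop punctuation except the apostrophe, and collapse
    # the double spaces left behind into single spaces.
    line = line.strip().lower()
    for w in string.punctuation:
        if w != "'":
            line = line.replace(w, "")
    return line.replace("  ", " ")

print(lower_rm("Don't stop -- it's fine!"))  # -> don't stop it's fine
```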
egs/wmt20/mt/run.sh (+23 −41)

```diff
@@ -44,10 +44,10 @@ lcrm=1
 tokenizer=1

 use_specific_dict=1
-specific_prefix=asr5k_st10k
-specific_dir=${root_dir}/data/iwslt2022/st_lcrm_asr
-src_vocab_prefix=spm_unigram5000_asr
-tgt_vocab_prefix=spm_unigram10000_st
+specific_prefix=unified
+specific_dir=${root_dir}/data/wmt20/vocab
+src_vocab_prefix=spm_en
+tgt_vocab_prefix=spm_zh

 org_data_dir=${root_dir}/data/${dataset}
 data_dir=${root_dir}/data/${dataset}/mt
```

```diff
@@ -141,6 +141,9 @@ if [ ${stage} -le 0 ] && [ ${stop_stage} -ge 0 ]; then
     if [[ ! -e ${data_dir} ]]; then
         mkdir -p ${data_dir}
     fi
+    if [[ ! -e ${data_dir}/data ]]; then
+        mkdir -p ${data_dir}/data
+    fi

     if [[ ! -f ${data_dir}/${src_vocab_prefix}.txt || ! -f ${data_dir}/${tgt_vocab_prefix}.txt ]]; then
         if [[ ${use_specific_dict} -eq 0 ]]; then
```

```diff
@@ -154,52 +157,31 @@ if [ ${stage} -le 0 ] && [ ${stop_stage} -ge 0 ]; then
                 --tgt-vocab-type ${tgt_vocab_type}
                 --src-vocab-size ${src_vocab_size}
                 --tgt-vocab-size ${tgt_vocab_size}"
             if [[ $share_dict -eq 1 ]]; then
                 cmd="$cmd --share"
             fi
             echo -e "\033[34mRun command: \n${cmd}\033[0m"
             [[ $eval -eq 1 ]] && eval ${cmd}
         else
             cp -r ${specific_dir}/${src_vocab_prefix}.* ${data_dir}
             cp ${specific_dir}/${tgt_vocab_prefix}.* ${data_dir}
         fi
     fi

-    mkdir -p ${data_dir}/data
-    for split in ${train_subset} ${valid_subset} ${trans_subset}; do
-    {
-        if [[ -d ${org_data_dir}/data/${split}/txt ]]; then
-            text_dir=${org_data_dir}/data/${split}/txt
-        else
-            text_dir=${org_data_dir}/data/${split}
-        fi
-        src_text=${text_dir}/${split}.${src_lang}
-        tgt_text=${text_dir}/${split}.${tgt_lang}
-
-        cmd="cat ${src_text}"
-        if [[ ${lcrm} -eq 1 ]]; then
-            cmd="python local/lower_rm.py ${src_text}"
-        fi
-        cmd="${cmd}
-        | spm_encode --model ${data_dir}/${src_vocab_prefix}.model
-        --output_format=piece
-        > ${data_dir}/data/${split}.${src_lang}"
-        echo -e "\033[34mRun command: \n${cmd}\033[0m"
-        [[ $eval -eq 1 ]] && eval ${cmd}
-
-        cmd="spm_encode
-        --model ${data_dir}/${tgt_vocab_prefix}.model
-        --output_format=piece
-        < ${tgt_text}
-        > ${data_dir}/data/${split}.${tgt_lang}"
-        echo -e "\033[34mRun command: \n${cmd}\033[0m"
-        [[ $eval -eq 1 ]] && eval ${cmd}
-    }&
-    done
-    wait
+    cmd="python ${code_dir}/examples/speech_to_text/prep_mt_data.py
+        --data-root ${org_data_dir}
+        --output-root ${data_dir}
+        --splits ${train_subset},${valid_subset},${trans_subset}
+        --src-lang ${src_lang}
+        --tgt-lang ${tgt_lang}
+        --src-vocab-prefix ${src_vocab_prefix}
+        --tgt-vocab-prefix ${tgt_vocab_prefix}"
+    if [[ $share_dict -eq 1 ]]; then
+        cmd="$cmd --share"
+    fi
+    if [[ ${lcrm} -eq 1 ]]; then
+        cmd="$cmd --lowercase-src --rm-punc-src"
+    fi
+    echo -e "\033[34mRun command: \n${cmd}\033[0m"
+    [[ $eval -eq 1 ]] && eval ${cmd}
 fi

 cmd="python ${code_dir}/fairseq_cli/preprocess.py
     --source-lang ${src_lang}
     --target-lang ${tgt_lang}
```
examples/speech_to_text/prep_audio_data.py (+5 −2)

```diff
@@ -296,6 +296,9 @@ def process(args):
                 gen_manifest_flag = True
                 break

+    punctuation_str = string.punctuation
+    punctuation_str.replace("'", "")
+
     train_text = []
     if args.overwrite or gen_manifest_flag:
         if not use_raw:
```

```diff
@@ -340,7 +343,7 @@ def process(args):
                     if args.lowercase_src:
                         src_utt = src_utt.lower()
                     if args.rm_punc_src:
-                        for w in string.punctuation:
+                        for w in punctuation_str:
                             src_utt = src_utt.replace(w, "")
                         src_utt = " ".join(src_utt.split(" "))
                 else:
```

```diff
@@ -414,7 +417,7 @@ def process(args):
                     if args.lowercase_src:
                         src_utt = src_utt.lower()
                     if args.rm_punc_src:
-                        for w in string.punctuation:
+                        for w in punctuation_str:
                             src_utt = src_utt.replace(w, "")
                         src_utt = " ".join(src_utt.split(" "))
                     train_text.append(src_utt)
```
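One detail worth noting in the first hunk: Python strings are immutable and `str.replace` returns a new string, so the bare `punctuation_str.replace("'", "")` discards its result and `punctuation_str` still contains the apostrophe here, unlike the reassigned form used in prep_mt_data.py below. A short demonstration:

```python
import string

punctuation_str = string.punctuation
punctuation_str.replace("'", "")  # return value discarded; no effect
assert "'" in punctuation_str

punctuation_str = punctuation_str.replace("'", "")  # reassigned, as in prep_mt_data.py
assert "'" not in punctuation_str
```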
examples/speech_to_text/prep_mt_data.py (+79 −26)

```diff
@@ -11,6 +11,7 @@ from pathlib import Path
 from tempfile import NamedTemporaryFile
 from typing import Tuple
+import string

 import sentencepiece as spm
 from examples.speech_to_text.data_utils import (
     gen_vocab,
```

```diff
@@ -62,6 +63,8 @@ def process(args):
     splits = args.splits.split(",")
     src_train_text = []
     tgt_train_text = []
+    manifest = {c: [] for c in MANIFEST_COLUMNS}
+    sent_num = [0]

     lang = f"{src_lang}-{tgt_lang}"
     cur_root = Path(args.data_root).absolute() / lang
```

```diff
@@ -70,20 +73,22 @@ def process(args):
     else:
         output_root = Path(args.output_root).absolute()

+    punctuation_str = string.punctuation
+    punctuation_str = punctuation_str.replace("'", "")
+
     # Generate TSV manifest
     print("Generating manifest...")
     for split in splits:
         is_train_split = split.startswith("train")
-        manifest = {c: [] for c in MANIFEST_COLUMNS}
         dataset = MTDataset(args.data_root, src_lang, tgt_lang, split, args.tokenizer)
         for src_text, tgt_text in tqdm(dataset):
             if args.lowercase_src:
                 src_text = src_text.lower()
             if args.rm_punc_src:
-                for w in string.punctuation:
+                for w in punctuation_str:
                     src_text = src_text.replace(w, "")
-                src_text = src_text.replace("  ", "")
+                src_text = src_text.replace("  ", " ")
             manifest["src_text"].append(src_text)
             manifest["tgt_text"].append(tgt_text)
```

```diff
@@ -94,34 +99,50 @@ def process(args):
         if is_train_split:
             src_train_text.extend(manifest["src_text"])
             tgt_train_text.extend(manifest["tgt_text"])
+        sent_num.append(len(manifest["src_text"]))

     # Generate vocab and yaml
     print("Generating vocabulary...")
     tgt_v_size_str = "" if args.tgt_vocab_type == "char" else str(args.tgt_vocab_size)
     tgt_spm_filename_prefix = f"spm_{args.tgt_vocab_type}{tgt_v_size_str}"
     if args.share:
-        tgt_train_text.extend(src_train_text)
-        tgt_spm_filename_prefix = tgt_spm_filename_prefix + "_share"
+        if args.tgt_vocab_prefix is not None:
+            tgt_spm_filename_prefix = args.tgt_vocab_prefix
+        else:
+            tgt_train_text.extend(src_train_text)
+            tgt_spm_filename_prefix = tgt_spm_filename_prefix + "_share"
         src_spm_filename_prefix = tgt_spm_filename_prefix
     else:
-        src_v_size_str = "" if args.src_vocab_type == "char" else str(args.src_vocab_size)
-        src_spm_filename_prefix = f"spm_{args.src_vocab_type}{src_v_size_str}"
-        src_spm_filename_prefix = src_spm_filename_prefix + "_" + src_lang
-        tgt_spm_filename_prefix = tgt_spm_filename_prefix + "_" + tgt_lang
-
-    with NamedTemporaryFile(mode="w") as f:
-        for t in tgt_train_text:
-            f.write(t + "\n")
-        gen_vocab(
-            Path(f.name),
-            output_root / tgt_spm_filename_prefix,
-            args.tgt_vocab_type,
-            args.tgt_vocab_size,
-            normalization_rule_name="identity" if tgt_lang == "zh" else None
-        )
-    if not args.share:
+        if args.tgt_vocab_prefix is not None:
+            tgt_spm_filename_prefix = args.tgt_vocab_prefix
+        else:
+            tgt_spm_filename_prefix = tgt_spm_filename_prefix + "_" + tgt_lang
+        if args.src_vocab_prefix is not None:
+            src_spm_filename_prefix = args.src_vocab_prefix
+        else:
+            src_v_size_str = "" if args.src_vocab_type == "char" else str(args.src_vocab_size)
+            src_spm_filename_prefix = f"spm_{args.src_vocab_type}{src_v_size_str}"
+            src_spm_filename_prefix = src_spm_filename_prefix + "_" + src_lang
+
+    src_spm_model = (output_root / (src_spm_filename_prefix + ".model")).as_posix()
+    tgt_spm_model = (output_root / (tgt_spm_filename_prefix + ".model")).as_posix()
+    if not os.path.exists(tgt_spm_model):
+        with NamedTemporaryFile(mode="w") as f:
+            for t in tgt_train_text:
+                f.write(t + "\n")
+            gen_vocab(
+                Path(f.name),
+                output_root / tgt_spm_filename_prefix,
+                args.tgt_vocab_type,
+                args.tgt_vocab_size,
+                normalization_rule_name="identity" if tgt_lang == "zh" else None
+            )
+    if not args.share and not os.path.exists(src_spm_model):
         with NamedTemporaryFile(mode="w") as f:
             for t in src_train_text:
                 f.write(t + "\n")
```

```diff
@@ -133,6 +154,38 @@ def process(args):
             normalization_rule_name="identity" if tgt_lang == "zh" else None
         )

+    # Generate sentencepiece
+    print("Applying sentencepiece...")
+    tgt_sp = spm.SentencePieceProcessor()
+    tgt_sp.Load(tgt_spm_model)
+    if args.share:
+        src_sp = tgt_sp
+    else:
+        src_sp = spm.SentencePieceProcessor()
+        src_sp.Load(src_spm_model)
+
+    index = 0
+    for split in splits:
+        src_text = manifest["src_text"][sent_num[index]: sent_num[index + 1]]
+        tgt_text = manifest["tgt_text"][sent_num[index]: sent_num[index + 1]]
+        index += 1
+
+        src_spm_name = (output_root / "data" / (split + "." + src_lang)).as_posix()
+        tgt_spm_name = (output_root / "data" / (split + "." + tgt_lang)).as_posix()
+        with open(src_spm_name, 'w') as f:
+            for sentence in src_text:
+                pieces = src_sp.EncodeAsPieces(sentence)
+                result = " ".join(pieces)
+                f.write(result + "\n")
+        with open(tgt_spm_name, 'w') as f:
+            for sentence in tgt_text:
+                pieces = tgt_sp.EncodeAsPieces(sentence)
+                result = " ".join(pieces)
+                f.write(result + "\n")
+
     # Generate config YAML
     yaml_filename = f"config.yaml"
     if args.share:
```

```diff
@@ -162,19 +215,19 @@ def main():
     parser.add_argument("--src-vocab-type", default="unigram", required=True, type=str,
                         choices=["bpe", "unigram", "char"],)
     parser.add_argument("--tgt-vocab-type", default="unigram", required=True, type=str,
                         choices=["bpe", "unigram", "char"],)
     parser.add_argument("--src-vocab-size", default=10000, type=int)
     parser.add_argument("--tgt-vocab-size", default=10000, type=int)
+    parser.add_argument("--src-vocab-prefix", default=None, type=str,
+                        help="prefix of the specific source vocabulary")
+    parser.add_argument("--tgt-vocab-prefix", default=None, type=str,
+                        help="prefix of the specific target vocabulary")
     parser.add_argument("--size", default=-1, type=int)
     parser.add_argument("--splits", default="train,dev,test", type=str)
     parser.add_argument("--lowercase-src", action="store_true",
                         help="lowercase the source text")
```
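The new "Applying sentencepiece..." step pre-tokenizes each split with the trained models instead of leaving that to the shell pipeline in run.sh. A minimal standalone sketch of that encode loop, with placeholder paths and data (spm.model and train.en are hypothetical):

```python
import sentencepiece as spm

sp = spm.SentencePieceProcessor()
sp.Load("spm.model")  # placeholder path to a trained SentencePiece model

sentences = ["hello world"]  # stand-in for manifest["src_text"]
with open("train.en", "w") as f:
    for sentence in sentences:
        pieces = sp.EncodeAsPieces(sentence)  # subword pieces for one sentence
        f.write(" ".join(pieces) + "\n")
```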
fairseq/criterions/ctc.py (+80 −19)

```diff
@@ -9,6 +9,8 @@ from argparse import Namespace
 from dataclasses import dataclass, field
 from omegaconf import II
 from typing import Optional
+import numpy as np
+import logging

 import torch
 import torch.nn.functional as F
```

```diff
@@ -19,6 +21,7 @@ from fairseq.data.data_utils import post_process
 from fairseq.tasks import FairseqTask
 from fairseq.logging.meters import safe_round

+logger = logging.getLogger(__name__)

 @dataclass
 class CtcCriterionConfig(FairseqDataclass):
```

```diff
@@ -31,8 +34,8 @@ class CtcCriterionConfig(FairseqDataclass):
         default="sentencepiece",
         metadata={
             "help": "how to post process predictions into words. can be letter, "
-                    "wordpiece, BPE symbols, etc. "
-                    "See fairseq.data.data_utils.post_process() for full list of options"
+            "wordpiece, BPE symbols, etc. "
+            "See fairseq.data.data_utils.post_process() for full list of options"
         },
     )
     ctc_entropy: float = field(
```

```diff
@@ -43,6 +46,10 @@ class CtcCriterionConfig(FairseqDataclass):
         default=0.0,
         metadata={"help": "weight of intermedia CTC loss"},
     )
+    target_ctc_weight: float = field(
+        default=0.0,
+        metadata={"help": "weight of intermedia CTC loss for target sentence"},
+    )
     ctc_self_distill_weight: float = field(
         default=0.0,
         metadata={"help": "weight of the self distillation CTC loss"},
```

```diff
@@ -116,10 +123,12 @@ class CtcCriterion(FairseqCriterion):
         self.ctc_weight = ctc_weight
         self.intermedia_ctc_weight = cfg.intermedia_ctc_weight
+        self.target_ctc_weight = cfg.target_ctc_weight
         self.ctc_self_distill_weight = cfg.ctc_self_distill_weight
         self.ctc_entropy = cfg.ctc_entropy
-        self.all_ctc_weight = self.ctc_weight + self.intermedia_ctc_weight + self.ctc_self_distill_weight + self.ctc_entropy
+        self.all_ctc_weight = self.ctc_weight + self.intermedia_ctc_weight + self.target_ctc_weight + \
+                              self.ctc_self_distill_weight + self.ctc_entropy

         if self.all_ctc_weight > 0:
             assert getattr(task, "src_dict", None) is not None, "CTC need a source dictionary."
             self.ctc_loss = torch.nn.CTCLoss(blank=self.blank_idx, reduction="sum", zero_infinity=True)
```

```diff
@@ -145,7 +154,7 @@ class CtcCriterion(FairseqCriterion):
             non_padding_mask = ~net_output["ctc_padding_mask"][0]
         else:
             non_padding_mask = ~net_output["encoder_padding_mask"][0]
-        input_lengths = non_padding_mask.long().sum(-1)
+        ctc_input_lengths = input_lengths = non_padding_mask.long().sum(-1)

         pad_mask = (transcript["tokens"] != self.pad_idx) & (
             transcript["tokens"] != self.eos_idx
```

```diff
@@ -215,6 +224,43 @@ class CtcCriterion(FairseqCriterion):
             if lprobs is None:
                 lprobs = inter_lprobs

+        target_ctc_num = 0
+        target_ctc_loss = 0
+        if "target_ctc_logits" in net_output:
+            target_ctc_num = len(net_output["target_ctc_logits"])
+
+        # calculate the target CTC loss
+        if self.target_ctc_weight > 0 and target_ctc_num > 0:
+            target = sample["target"]
+            pad_mask = (target != self.pad_idx) & (target != self.eos_idx)
+            targets_flat = target.masked_select(pad_mask)
+            target_length = pad_mask.sum(-1)
+
+            for i in range(target_ctc_num):
+                out = net_output["target_ctc_logits"][i]
+                if type(out) == list:
+                    inter_ctc_logit = out[0]
+                    padding = ~out[1]
+                    input_lengths = padding.long().sum(-1)
+                else:
+                    inter_ctc_logit = out
+
+                inter_lprobs = model.get_normalized_probs(
+                    [inter_ctc_logit], log_probs=True
+                ).contiguous()  # (T, B, C) from the encoder
+                inter_lprobs.batch_first = False
+
+                with torch.backends.cudnn.flags(enabled=False):
+                    loss = self.ctc_loss(
+                        inter_lprobs,
+                        targets_flat,
+                        ctc_input_lengths,
+                        target_length,
+                    )
+                target_ctc_loss += loss
+            target_ctc_loss /= target_ctc_num
+            logging_output["target_ctc_loss"] = utils.item(target_ctc_loss.data)
+
         # calculate the self distillation CTC loss
         ctc_self_distill_loss = 0
         ctc_self_distill_num = 0
```

```diff
@@ -247,6 +293,7 @@ class CtcCriterion(FairseqCriterion):
         loss = \
             self.ctc_weight * ctc_loss + \
             self.intermedia_ctc_weight * intermedia_ctc_loss + \
+            self.target_ctc_weight * target_ctc_loss + \
             self.ctc_self_distill_weight * ctc_self_distill_loss + \
             self.ctc_entropy * ctc_entropy
```

```diff
@@ -264,9 +311,9 @@ class CtcCriterion(FairseqCriterion):
             w_len = 0
             wv_errs = 0
             for lp, t, inp_l in zip(
-                lprobs_t,
-                sample["transcript"]["tokens"] if "transcript" in sample else sample["target"],
-                input_lengths,
+                    lprobs_t,
+                    sample["transcript"]["tokens"] if "transcript" in sample else sample["target"],
+                    input_lengths,
             ):
                 lp = lp[:inp_l].unsqueeze(0)
```

```diff
@@ -283,7 +330,7 @@ class CtcCriterion(FairseqCriterion):
                 decoded = decoded[0]

                 p = (t != self.task.target_dictionary.pad()) & (
-                    t != self.task.target_dictionary.eos()
+                        t != self.task.target_dictionary.eos()
                 )
                 targ = t[p]
                 targ_units = self.task.target_dictionary.string(targ)
```

```diff
@@ -332,6 +379,9 @@ class CtcCriterion(FairseqCriterion):
         inter_ctc_loss_sum = utils.item(
             sum(log.get("intermedia_ctc_loss", 0) for log in logging_outputs)
         )
+        target_ctc_loss_sum = utils.item(
+            sum(log.get("target_ctc_loss", 0) for log in logging_outputs)
+        )
         ctc_self_distill_loss_sum = utils.item(
             sum(log.get("ctc_self_distill_loss", 0) for log in logging_outputs)
         )
```

```diff
@@ -346,6 +396,9 @@ class CtcCriterion(FairseqCriterion):
         sample_size = utils.item(
             sum(log.get("sample_size", 0) for log in logging_outputs)
         )
+        if np.isnan(all_ctc_loss_sum) or np.isinf(all_ctc_loss_sum) or all_ctc_loss_sum < 0:
+            logger.error("Illegal loss %f!" % all_ctc_loss_sum)
+
         if all_ctc_loss_sum > 0:
             if "loss" not in logging_outputs[0]:
                 metrics.log_scalar(
```

```diff
@@ -383,6 +436,14 @@ class CtcCriterion(FairseqCriterion):
                 sample_size,
                 round=3,
             )
+        if target_ctc_loss_sum > 0:
+            metrics.log_scalar(
+                "target_ctc_loss",
+                target_ctc_loss_sum / sample_size / math.log(2),
+                sample_size,
+                round=3,
+            )
         if ctc_self_distill_loss_sum > 0:
             metrics.log_scalar(
                 "ctc_self_distill_loss",
```

```diff
@@ -404,8 +465,8 @@ class CtcCriterion(FairseqCriterion):
             metrics.log_scalar("_c_total", c_total)
         w_errors = sum(log.get("w_errors", 0) for log in logging_outputs)
         metrics.log_scalar("_w_errors", w_errors)
-        wv_errors = sum(log.get("wv_errors", 0) for log in logging_outputs)
-        metrics.log_scalar("_wv_errors", wv_errors)
+        # wv_errors = sum(log.get("wv_errors", 0) for log in logging_outputs)
+        # metrics.log_scalar("_wv_errors", wv_errors)
         w_total = sum(log.get("w_total", 0) for log in logging_outputs)
         metrics.log_scalar("_w_total", w_total)
```

```diff
@@ -427,14 +488,14 @@ class CtcCriterion(FairseqCriterion):
                 if meters["_w_total"].sum > 0
                 else float("nan"),
             )
-        metrics.log_derived(
-            "raw_wer",
-            lambda meters: safe_round(
-                meters["_wv_errors"].sum * 100.0 / meters["_w_total"].sum, 3
-            )
-            if meters["_w_total"].sum > 0
-            else float("nan"),
-        )
+        # metrics.log_derived(
+        #     "raw_wer",
+        #     lambda meters: safe_round(
+        #         meters["_wv_errors"].sum * 100.0 / meters["_w_total"].sum, 3
+        #     )
+        #     if meters["_w_total"].sum > 0
+        #     else float("nan"),
+        # )

     @staticmethod
     def logging_outputs_can_be_summed() -> bool:
```
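For orientation, the new target CTC branch feeds log-probabilities shaped (T, B, C) into the same torch.nn.CTCLoss instance used for the transcript, together with the flattened non-pad target tokens. A minimal sketch of that call with dummy shapes (all sizes are illustrative, not from the commit):

```python
import torch

T, B, C = 50, 2, 100  # frames, batch, vocabulary size (blank index 0)
log_probs = torch.randn(T, B, C).log_softmax(-1)         # like inter_lprobs: (T, B, C)
targets_flat = torch.randint(1, C, (2 * 20,))            # non-pad tokens, flattened
input_lengths = torch.full((B,), T, dtype=torch.long)    # like ctc_input_lengths
target_lengths = torch.full((B,), 20, dtype=torch.long)  # like target_length

ctc_loss = torch.nn.CTCLoss(blank=0, reduction="sum", zero_infinity=True)
loss = ctc_loss(log_probs, targets_flat, input_lengths, target_lengths)
```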
fairseq/models/speech_to_text/modules/ctc.py (+1 −0)

```diff
@@ -17,6 +17,7 @@ class CTC(nn.Module):
     def __init__(self, embed_dim, dictionary_size, dropout, need_layernorm=False):
         super(CTC, self).__init__()

+        self.embed_dim = embed_dim
         self.ctc_projection = nn.Linear(embed_dim, dictionary_size, bias=False)

         nn.init.normal_(
```
fairseq/models/speech_to_text/pdss2t_transformer.py (+15 −10)

```diff
@@ -232,6 +232,7 @@ class PDSS2TTransformerModel(S2TTransformerModel):
                 "rope",
                 "abs",
                 "transfer",
+                "reduced_rel_pos",
             ],
             help="transformer encoder self-attention layer type")
```

```diff
@@ -579,6 +580,12 @@ class PDSS2TTransformerModel(S2TTransformerModel):
             type=float,
             help="probability of dropping the followed layers",
         )
+        parser.add_argument(
+            "--intermedia-temperature",
+            default=1,
+            type=float,
+            help="temperature of the intermedia ctc probability",
+        )
         pass

     @classmethod
```

```diff
@@ -626,10 +633,7 @@ class PDSS2TTransformerEncoder(FairseqEncoder):
         self.pds_position_embed = [int(n) for n in args.pds_position_embed.split("_")]
         self.pds_attn_heads = [int(n) for n in args.pds_attn_heads.split("_")]
         self.pds_ffn_ratios = [int(n) for n in args.pds_ffn_ratios.split("_")]
-        if self.attn_type == "reduced":
-            self.pds_attn_ds_ratios = [int(n) for n in args.pds_attn_ds_ratios.split("_")]
-        else:
-            self.pds_attn_ds_ratios = None
+        self.pds_attn_ds_ratios = [int(n) for n in args.pds_attn_ds_ratios.split("_")]

         self.pds_conv_strides = [int(n) for n in args.pds_conv_strides.split("_")]
         self.pds_attn_strides = [int(n) for n in args.pds_attn_strides.split("_")]
```

```diff
@@ -674,7 +678,7 @@ class PDSS2TTransformerEncoder(FairseqEncoder):
             ffn_ratio = self.pds_ffn_ratios[i]
             num_head = self.pds_attn_heads[i]
-            attn_ds_ratio = self.pds_attn_ds_ratios[i] if self.attn_type == "reduced" else -1
+            attn_ds_ratio = self.pds_attn_ds_ratios[i]  # if self.attn_type == "reduced" else -1
             conv_stride = self.pds_conv_strides[i]
             attn_stride = self.pds_attn_strides[i]

             if conv_stride != 1 or attn_stride != 1:
```

```diff
@@ -712,7 +716,7 @@ class PDSS2TTransformerEncoder(FairseqEncoder):
             # position encoding
             if use_pos_embed:
-                if self.attn_type == "rel_pos":
+                if self.attn_type in ["rel_pos", "reduced_rel_pos"]:
                     pos_embed = RelPositionalEncoding(
                         args.max_source_positions, embed_dim
                     )
```

```diff
@@ -850,7 +854,7 @@ class PDSS2TTransformerEncoder(FairseqEncoder):
                 if ctc_layer <= 0:
                     embed_dim = self.pds_embed_dims[i]
                     break
-            if inter_ctc_module is None:
+            if inter_ctc_module is None or embed_dim != inter_ctc_module.embed_dim:
                 self.ctc = CTC(embed_dim,
                                dictionary_size=len(task.source_dictionary),
                                dropout=args.dropout,
```

```diff
@@ -866,6 +870,7 @@ class PDSS2TTransformerEncoder(FairseqEncoder):
         else:
             self.layer_norm = None

+        self.intermedia_temperature = getattr(args, "intermedia_temperature", 1)
         self.gather_cos_sim = getattr(args, "gather_cos_sim", False)
         self.dis = 2
         self.cos_sim = dict()
```

```diff
@@ -933,7 +938,7 @@ class PDSS2TTransformerEncoder(FairseqEncoder):
         # add the position encoding and dropout
         if pos_embed:
-            if self.attn_type in ["rel_pos", "rel_pos_legacy", "rel_selfattn"]:
+            if self.attn_type in ["rel_pos", "rel_pos_legacy", "rel_selfattn", "reduced_rel_pos"]:
                 positions = pos_embed(x)
             elif self.attn_type == "rope":
```

```diff
@@ -981,7 +986,7 @@ class PDSS2TTransformerEncoder(FairseqEncoder):
                 logit = ctc(x.clone())
                 intermedia_ctc_logits.append([logit, encoder_padding_mask])

-                prob = utils.softmax(logit, dim=-1)
+                prob = utils.softmax(logit / self.intermedia_temperature, dim=-1)
                 x, encoder_padding_mask = adapter([x, prob], encoder_padding_mask)

         if self.fusion_stages_num != 0:
```

```diff
@@ -1131,9 +1136,9 @@ def base_architecture(args):
     args.pds_position_embed = getattr(args, "pds_position_embed", None)
     args.pds_attn_heads = getattr(args, "pds_attn_heads", None)
-    args.pds_attn_ds_ratios = getattr(args, "pds_attn_ds_ratios", None)
     args.pds_ffn_ratios = getattr(args, "pds_ffn_ratios", None)
+    args.pds_attn_ds_ratios = getattr(args, "pds_attn_ds_ratios", "1_1_1_1")

     args.pds_conv_strides = getattr(args, "pds_conv_strides", "1_1_1_1")
     args.pds_attn_strides = getattr(args, "pds_attn_strides", "1_1_1_1")
```
fairseq/models/speech_to_text/s2t_ctc.py (+2 −1)

```diff
@@ -118,6 +118,7 @@ class S2TCTCModel(FairseqEncoderModel):
                 "rope",
                 "abs",
                 "transfer",
+                "reduced_rel_pos",
             ],
             help="transformer encoder self-attention layer type")
```

```diff
@@ -739,9 +740,9 @@ def base_architecture(args):
     args.pds_position_embed = getattr(args, "pds_position_embed", None)
     args.pds_attn_heads = getattr(args, "pds_attn_heads", None)
-    args.pds_attn_ds_ratios = getattr(args, "pds_attn_ds_ratios", None)
     args.pds_ffn_ratios = getattr(args, "pds_ffn_ratios", None)
+    args.pds_attn_ds_ratios = getattr(args, "pds_attn_ds_ratios", "1_1_1_1")

     args.pds_conv_strides = getattr(args, "pds_conv_strides", "1_1_1_1")
     args.pds_attn_strides = getattr(args, "pds_attn_strides", "1_1_1_1")
```
fairseq/models/speech_to_text/s2t_sate.py

(Diff collapsed; not shown here.)
fairseq/models/speech_to_text/s2t_transformer.py (+8 −1)

```diff
@@ -395,6 +395,12 @@ class S2TTransformerModel(FairseqEncoderDecoderModel):
             type=float,
             help="probability of dropping the followed layers",
         )
+        parser.add_argument(
+            "--intermedia-temperature",
+            default=1,
+            type=float,
+            help="temperature of the intermedia ctc probability",
+        )
         pass

     @classmethod
```

```diff
@@ -585,6 +591,7 @@ class S2TTransformerEncoder(FairseqEncoder):
             self.adapter = Adapter(dim, args.intermedia_adapter,
                                    task.source_dictionary, strategy=strategy)
             self.intermedia_drop_prob = getattr(args, "intermedia_drop_prob", 0)
+            self.intermedia_temperature = getattr(args, "intermedia_temperature", 1)

     @staticmethod
     def pooling_ratio():
```

```diff
@@ -683,7 +690,7 @@ class S2TTransformerEncoder(FairseqEncoder):
                 intermedia_ctc_logits.append(logit)

                 # prob = self.ctc.softmax(norm_x)
-                prob = utils.softmax(logit, dim=-1)
+                prob = utils.softmax(logit / self.intermedia_temperature, dim=-1)
                 x, encoder_padding_mask = self.adapter([x, prob], encoder_padding_mask)

             # gather cosine similarity
```
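The new intermedia-temperature option only divides the intermediate CTC logits by a temperature before the softmax that feeds the adapter; a temperature above 1 flattens the distribution, and the default of 1 reproduces the old behavior. A small illustration with made-up logits:

```python
import torch
import torch.nn.functional as F

logit = torch.tensor([2.0, 1.0, 0.5])
for temperature in (1.0, 5.0):  # --intermedia-temperature
    prob = F.softmax(logit / temperature, dim=-1)
    print(temperature, prob)  # T=1: sharp; T=5: flatter, closer to uniform
```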
fairseq/modules/__init__.py (+3 −1)

```diff
@@ -54,6 +54,7 @@ from .positional_encoding import (
 from .espnet_multihead_attention import (
     ESPNETMultiHeadedAttention,
     RelPositionMultiHeadedAttention,
+    ReducedRelPositionMultiHeadedAttention,
     LegacyRelPositionMultiHeadedAttention,
     RotaryPositionMultiHeadedAttention,
 )
```

```diff
@@ -113,10 +114,11 @@ __all__ = [
     "unfold1d",
     "ESPNETMultiHeadedAttention",
     "PositionalEmbedding",
-    "RelPositionMultiHeadedAttention",
     "PositionalEncoding",
     "LegacyRelPositionalEncoding",
     "RelPositionalEncoding",
+    "RelPositionMultiHeadedAttention",
+    "ReducedRelPositionMultiHeadedAttention",
     "LegacyRelPositionMultiHeadedAttention",
     "RotaryPositionalEmbedding",
     "RotaryPositionMultiHeadedAttention",
```
fairseq/modules/attention.py (+1 −1)

```diff
@@ -1347,7 +1347,7 @@ class MultiHeadSelfAttentionModule(nn.Module):
         # Assert
         assert not (group_size > 1 and kernel_size is not None), "Local grouped attention not implemented"
-        assert not (group_size > 1 and stride > 1 is not None), "Strided grouped attention not implemented"
+        assert not (group_size > 1 and stride > 1), "Strided grouped attention not implemented"
         assert not (linear_att and relative_pos_enc), "Linear attention requires absolute positional encodings"

         # Pre Norm
```
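The rewritten assert is behaviorally identical but far clearer: `stride > 1 is not None` is a chained comparison, equivalent to `(stride > 1) and (1 is not None)`, and the second clause is always True. A quick check of the parse (using ast so the literal `is not` comparison never actually runs):

```python
import ast

# `stride > 1 is not None` parses as ONE chained comparison with two
# operators (Gt, IsNot), not as `stride > (1 is not None)`.
cmp = ast.parse("stride > 1 is not None", mode="eval").body
assert isinstance(cmp, ast.Compare) and len(cmp.ops) == 2
# Semantics: (stride > 1) and (1 is not None); since `1 is not None` is
# always True, the old assert already reduced to just `stride > 1`.
```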
fairseq/modules/espnet_multihead_attention.py (+127 −0)

```diff
@@ -14,6 +14,7 @@ from fairseq.modules.rotary_positional_embedding import (
     RotaryPositionalEmbedding,
     apply_rotary_pos_emb,
 )
+from .layer_norm import LayerNorm


 class ESPNETMultiHeadedAttention(nn.Module):
```

```diff
@@ -72,6 +73,7 @@ class ESPNETMultiHeadedAttention(nn.Module):
         if mask is not None:
             scores = scores.masked_fill(
                 mask.unsqueeze(1).unsqueeze(2).to(bool),
+                # -1e8 if scores.dtype == torch.float32 else -1e4
                 float("-inf"),  # (batch, head, time1, time2)
             )
         self.attn = F.softmax(scores, dim=-1, dtype=torch.float32).type_as(scores)  # (batch, head, time1, time2)
```

The hunk `@@ -195,6 +197,131 @@` inserts a new class between RelPositionMultiHeadedAttention and LegacyRelPositionMultiHeadedAttention (all lines added):

```python
class ReducedRelPositionMultiHeadedAttention(RelPositionMultiHeadedAttention):
    """Multi-Head Attention layer with relative position encoding.
    Paper: https://arxiv.org/abs/1901.02860
    Args:
        n_head: The number of heads.
        n_feat: The number of features.
        dropout: Dropout rate.
        zero_triu: Whether to zero the upper triangular part of attention matrix.
    """

    def __init__(
            self,
            n_feat,
            n_head,
            dropout,
            zero_triu=False,
            sample_ratio=1,
            reduced_method="conv",
            reduced_q=False,
    ):
        """Construct an RelPositionMultiHeadedAttention object."""
        super().__init__(n_feat, n_head, dropout)
        self.zero_triu = zero_triu
        # linear transformation for positional encoding
        self.linear_pos = nn.Linear(n_feat, n_feat, bias=False)
        # these two learnable bias are used in matrix c and matrix d
        # as described in https://arxiv.org/abs/1901.02860 Section 3.3
        self.pos_bias_u = nn.Parameter(torch.Tensor(self.h, self.d_k))
        self.pos_bias_v = nn.Parameter(torch.Tensor(self.h, self.d_k))
        torch.nn.init.xavier_uniform_(self.pos_bias_u)
        torch.nn.init.xavier_uniform_(self.pos_bias_v)

        super().__init__(n_feat, n_head, dropout, zero_triu)
        self.sample_ratio = sample_ratio
        self.reduced_method = reduced_method
        self.reduced_q = reduced_q
        if reduced_q:
            assert self.reduced_method == 'group', "only support grouped method for query reduction"

        if self.sample_ratio > 1:
            if reduced_method == "conv":
                self.sr = nn.Conv1d(
                    n_feat, n_feat,
                    kernel_size=sample_ratio,
                    stride=sample_ratio,
                )
                self.norm = LayerNorm(n_feat)
            elif reduced_method == "pool":
                self.linear = nn.Linear(n_feat, n_feat)
                self.norm = LayerNorm(n_feat)
                self.act = nn.GELU()
            elif reduced_method == "group":
                pass

    def forward(self, query, key, value, pos_emb, key_padding_mask=None, **kwargs):
        """Compute scaled dot product attention.
        Args:
            query: Query tensor T X B X C
            key: Key tensor T X B X C
            value: Value tensor T X B X C
            pos_emb: Positional embedding tensor 2T-1 X B(1) X C
            key_padding_mask: Mask tensor T X B
        Returns:
            torch.Tensor: Output tensor T X B X C.
        """
        # (bsz, seq_len, dim)
        query = query.transpose(0, 1)
        key = key.transpose(0, 1)
        value = value.transpose(0, 1)
        pos_emb = pos_emb.transpose(0, 1)

        tgt_len = query.size(1)
        query_ = query
        if self.sample_ratio > 1:
            assert tgt_len % self.sample_ratio == 0, \
                ("sample ratio %d is mismatched with length %d" % (self.sample_ratio, tgt_len))
            if self.reduced_method == "conv":
                query_ = query.transpose(1, 2)  # bsz, dim, seq_len
                query_ = self.sr(query_).transpose(1, 2)  # bsz, seq_len, dim
                query_ = self.norm(query_)
            elif self.reduced_method == "pool":
                query_ = query.transpose(1, 2)  # bsz, dim, seq_len
                pool_length = int(tgt_len / self.sample_ratio)
                query_ = nn.functional.adaptive_max_pool1d(query_, pool_length).transpose(1, 2)
                query_ = self.act(self.norm(query_))
            key = value = query_

            if key_padding_mask is not None:
                key_padding_mask = key_padding_mask[:, ::self.sample_ratio]

        n_batch = query.size(0)
        q = self.linear_q(query).view(n_batch, -1, self.h, self.d_k)
        k = self.linear_k(key).view(n_batch, -1, self.h, self.d_k)
        v = self.linear_v(value).view(n_batch, -1, self.h, self.d_k)
        q = q.transpose(1, 2)  # (batch, head, time1, d_k)
        k = k.transpose(1, 2)  # (batch, head, time2, d_k)
        v = v.transpose(1, 2)  # (batch, head, time2, d_k)
        # q, k, v = self.forward_qkv(query, key, value)
        q = q.transpose(1, 2)  # (batch, time1, head, d_k)

        n_batch_pos = pos_emb.size(0)
        p = self.linear_pos(pos_emb).view(n_batch_pos, -1, self.h, self.d_k)
        p = p.transpose(1, 2)  # (batch, head, 2*time1-1, d_k)

        # (batch, head, time1, d_k)
        q_with_bias_u = (q + self.pos_bias_u).transpose(1, 2)
        # (batch, head, time1, d_k)
        q_with_bias_v = (q + self.pos_bias_v).transpose(1, 2)

        # compute attention score
        # first compute matrix a and matrix c
        # as described in https://arxiv.org/abs/1901.02860 Section 3.3
        # (batch, head, time1, time2)
        matrix_ac = torch.matmul(q_with_bias_u, k.transpose(-2, -1))

        # compute matrix b and matrix d
        # (batch, head, time1, 2*time1-1)
        matrix_bd = torch.matmul(q_with_bias_v, p.transpose(-2, -1))
        matrix_bd = self.rel_shift(matrix_bd)

        scores = (matrix_ac + matrix_bd) / math.sqrt(self.d_k)  # (batch, head, time1, time2)

        scores = self.forward_attention(v, scores, key_padding_mask)
        scores = scores.transpose(0, 1)
        return scores, None
```
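The core of ReducedRelPositionMultiHeadedAttention is shrinking the key/value sequence by sample_ratio before attention, via a strided Conv1d ("conv") or adaptive max pooling ("pool"), and subsampling the padding mask to match. A shape-only sketch of the two reduction paths (sizes are illustrative):

```python
import torch
import torch.nn as nn

B, T, C, ratio = 4, 100, 256, 2      # T must be divisible by ratio
x = torch.randn(B, T, C)             # (batch, time, feature), batch-first

# "conv" reduction: strided 1-D convolution over time, as in self.sr
sr = nn.Conv1d(C, C, kernel_size=ratio, stride=ratio)
reduced_conv = sr(x.transpose(1, 2)).transpose(1, 2)   # (B, T//ratio, C)

# "pool" reduction: adaptive max pooling down to the target length
reduced_pool = nn.functional.adaptive_max_pool1d(
    x.transpose(1, 2), T // ratio).transpose(1, 2)     # (B, T//ratio, C)

# the key padding mask is subsampled the same way as the sequence
key_padding_mask = torch.zeros(B, T, dtype=torch.bool)
assert key_padding_mask[:, ::ratio].shape == (B, T // ratio)
assert reduced_conv.shape == reduced_pool.shape == (B, T // ratio, C)
```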
fairseq/modules/pds_layer.py (+12 −2)

```diff
@@ -12,6 +12,7 @@ from fairseq.modules import (
     ConvolutionModule,
     ESPNETMultiHeadedAttention,
     RelPositionMultiHeadedAttention,
+    ReducedRelPositionMultiHeadedAttention,
     LegacyRelPositionMultiHeadedAttention,
     LocalMultiheadAttention,
     ReducedMultiheadAttention,
```

```diff
@@ -91,6 +92,7 @@ class PDSTransformerEncoderLayer(nn.Module):
             self.macaron_norm = None
             self.ffn_scale = 1.0

+        self.conv_stride = conv_stride
         if args.use_cnn_module:
             self.conv_norm = LayerNorm(embed_dim)
             self.conv_module = ConvolutionModule(
```

```diff
@@ -104,7 +106,6 @@ class PDSTransformerEncoderLayer(nn.Module):
             self.final_norm = LayerNorm(expand_embed_dim)

         # Convolution Residual
-        self.conv_stride = conv_stride
         self.conv_res = nn.Sequential(
             Permute3D(1, 2, 0),
             nn.Conv1d(embed_dim, expand_embed_dim, kernel_size=1, stride=conv_stride),
```

```diff
@@ -173,6 +174,15 @@ class PDSTransformerEncoderLayer(nn.Module):
                 attention_heads,
                 dropout=dropout,
             )
+        elif self.attn_type == "reduced_rel_pos":
+            return ReducedRelPositionMultiHeadedAttention(
+                embed_dim,
+                attention_heads,
+                dropout=dropout,
+                sample_ratio=sample_ratio,
+                reduced_method=getattr(args, "attention_reduced_method", "conv"),
+                reduced_q=getattr(args, "attention_reduced_q", False)
+            )
         elif self.attn_type == "rel_pos_legacy":
             return LegacyRelPositionMultiHeadedAttention(
                 embed_dim,
```

```diff
@@ -284,7 +294,7 @@ class PDSTransformerEncoderLayer(nn.Module):
         residual = x
         if self.normalize_before:
             x = self.self_attn_layer_norm(x)

-        if self.attn_type in ["rel_pos", "rel_pos_legacy", "rel_selfattn"]:
+        if self.attn_type in ["rel_pos", "rel_pos_legacy", "rel_selfattn", "reduced_rel_pos"]:
             assert pos_emb is not None, "Positions is necessary for RPE!"
             x, _ = self.self_attn(
                 query=x,
```