Commit 67d8695f, authored Mar 04, 2022 by xuchen

add target ctc

parent d4255246
Showing 17 changed files with 526 additions and 255 deletions.
egs/libri_trans/asr/conf/debug.yaml                          +35  -36
egs/librispeech/asr/conf/purectc_pds_base_8_compare2.yaml     +3   -3
egs/mustc/asr/conf/purectc_pds_base_8_grow.yaml              +50   -0
egs/wmt20/mt/local/lower_rm.py                                +3   -3
egs/wmt20/mt/run.sh                                          +23  -41
examples/speech_to_text/prep_audio_data.py                    +5   -2
examples/speech_to_text/prep_mt_data.py                      +79  -26
fairseq/criterions/ctc.py                                    +80  -19
fairseq/models/speech_to_text/modules/ctc.py                  +1   -0
fairseq/models/speech_to_text/pdss2t_transformer.py          +15  -10
fairseq/models/speech_to_text/s2t_ctc.py                      +2   -1
fairseq/models/speech_to_text/s2t_sate.py                    +79 -109
fairseq/models/speech_to_text/s2t_transformer.py              +8   -1
fairseq/modules/__init__.py                                   +3   -1
fairseq/modules/attention.py                                  +1   -1
fairseq/modules/espnet_multihead_attention.py               +127   -0
fairseq/modules/pds_layer.py                                 +12   -2
egs/libri_trans/asr/conf/debug.yaml

-arch: s2t_ctc
-encoder-type: pds
-#pds-ctc: 0_1_1_0
-#intermedia-adapter: league
-#intermedia-ctc-weight: 1
-#encoder-attention-type: reduced
-#pds-attn-ds-ratios: 4_2_1_1
-#attention-reduced-method: pool
-#attention-reduced-q: True
-encoder-embed-dim: 240
-pds-stages: 3
-#ctc-layer: 15
-pds-layers: 4_5_6
-pds-ratios: 2_2_2
-pds-fusion: False
-pds-fusion-method: all_conv
-pds-embed-dims: 120_168_240
-pds-ds-method: conv
-pds-embed-norm: True
-pds-position-embed: 1_1_1
-pds-kernel-sizes: 3_3_3
-pds-ffn-ratios: 4_4_4
-pds-attn-heads: 4_4_4
+arch: s2t_sate
+share-decoder-input-output-embed: True
 optimizer: adam
 clip-norm: 10.0
 lr-scheduler: inverse_sqrt
 warmup-init-lr: 1e-7
 warmup-updates: 10000
-lr: 0.0015
+lr: 2e-3
 adam_betas: (0.9,0.98)
-criterion: ctc
+ctc-weight: 0.3
 post-process: sentencepiece
+target-ctc-weight: 0.2
+target-ctc-layers: 3,6
+criterion: label_smoothed_cross_entropy_with_ctc
+label_smoothing: 0.1
+encoder-normalize-before: True
+decoder-normalize-before: True
+subsampling-type: conv1d
+subsampling-layers: 2
+subsampling-filter: 1024
+subsampling-kernel: 5
+subsampling-stride: 2
+subsampling-norm: none
+subsampling-activation: glu
 dropout: 0.1
 activation-fn: relu
-encoder-layers: 15
+encoder-embed-dim: 256
+encoder-ffn-embed-dim: 2048
+encoder-layers: 12
+text-encoder-layers: 6
+decoder-layers: 6
+encoder-attention-heads: 4
+decoder-embed-dim: 256
+decoder-ffn-embed-dim: 2048
+decoder-attention-heads: 4
+macaron-style: True
+acoustic-encoder: transformer
+use-cnn-module: True
+adapter: league
+cnn-module-kernel: 15
+encoder-activation-fn: swish
+encoder-attention-type: rel_pos
 #load-pretrained-encoder-from:
+#load-pretrained-acoustic-encoder-from:
+#load-pretrained-text-encoder-from:
+#load-pretrained-decoder-from:
\ No newline at end of file
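These egs/*/conf files are flat "key: value" listings whose keys mirror fairseq-train command-line options; the run scripts expand them into flags at launch time. For reference, a minimal sketch of that expansion (the helper below is illustrative only; the repo itself does this in shell):

# Hypothetical sketch: expand a flat "key: value" training config into
# fairseq-train style CLI flags. Commented-out options ("#...") are skipped,
# True becomes a bare flag, and False is simply omitted.
def yaml_to_flags(path):
    flags = []
    with open(path, "r", encoding="utf-8") as f:
        for line in f:
            line = line.strip()
            if not line or line.startswith("#"):
                continue
            key, value = [x.strip() for x in line.split(":", 1)]
            if value == "True":
                flags.append(f"--{key}")
            elif value == "False":
                continue
            else:
                flags.extend([f"--{key}", value])
    return flags

print(" ".join(yaml_to_flags("egs/libri_trans/asr/conf/debug.yaml")))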
egs/librispeech/asr/conf/purectc_pds_base_8_compare2.yaml

@@ -13,15 +13,15 @@ encoder-type: pds
 encoder-embed-dim: 240
 pds-stages: 3
 #ctc-layer: 15
-pds-layers: 4_5_6
+pds-layers: 5_5_5
 pds-ratios: 2_2_2
-pds-fusion: False
+pds-fusion: True
 pds-fusion-method: all_conv
 pds-embed-dims: 120_168_240
 pds-ds-method: conv
 pds-embed-norm: True
 pds-position-embed: 1_1_1
-pds-kernel-sizes: 3_3_3
+pds-kernel-sizes: 5_5_5
 pds-ffn-ratios: 4_4_4
 pds-attn-heads: 4_4_4
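The pds-* keys describe the progressive-downsampling encoder stage by stage: this change evens out the per-stage depth (4_5_6 to 5_5_5), enables fusion, and widens the downsampling kernels, while total depth and frame-rate reduction stay fixed. A small sketch of how those two settings compose (values taken from the new config):

# Sketch: each stage i runs pds-layers[i] Transformer layers, then downsamples
# the sequence by pds-ratios[i]; the overall frame-rate reduction is their product.
layers = [int(n) for n in "5_5_5".split("_")]
ratios = [int(n) for n in "2_2_2".split("_")]

total_ds = 1
for r in ratios:
    total_ds *= r

assert sum(layers) == 15   # encoder depth is still 15 layers
assert total_ds == 8       # 2 * 2 * 2 = 8x temporal downsampling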
egs/mustc/asr/conf/purectc_pds_base_8_grow.yaml (new file, +50 -0)

arch: s2t_ctc
encoder-type: pds
#pds-ctc: 0_1_1_0
#intermedia-adapter: league
#intermedia-ctc-weight: 1
#intermedia-temperature: 5
encoder-attention-type: rel_pos
#encoder-attention-type: reduced_rel_pos
#pds-attn-ds-ratios: 4_2_2_1
#attention-reduced-method: pool
#attention-reduced-q: True
encoder-embed-dim: 512
pds-stages: 4
#ctc-layer: 15
encoder-layers: 10
pds-layers: 3_2_2_3
pds-ratios: 2_2_1_2
pds-fusion: True
pds-fusion-method: all_conv
pds-embed-dims: 256_384_384_512
pds-ds-method: conv
pds-embed-norm: True
pds-position-embed: 1_1_1_1
pds-kernel-sizes: 5_5_5_5
pds-ffn-ratios: 8_4_4_4
pds-attn-heads: 4_6_6_8
optimizer: adam
clip-norm: 10.0
lr-scheduler: inverse_sqrt
warmup-init-lr: 1e-7
warmup-updates: 10000
lr: 0.002
adam_betas: (0.9,0.98)
criterion: ctc
post-process: sentencepiece
dropout: 0.1
activation-fn: relu
macaron-style: True
use-cnn-module: True
cnn-module-kernel: 31
encoder-activation-fn: swish
#load-pretrained-encoder-from:
egs/wmt20/mt/local/lower_rm.py

@@ -8,7 +8,7 @@ with open(in_file, "r", encoding="utf-8") as f:
     for line in f.readlines():
         line = line.strip().lower()
         for w in string.punctuation:
-            line = line.replace(w, "")
-        line = line.replace("  ", "")
+            if w != "'":
+                line = line.replace(w, "")
+        line = line.replace("  ", " ")
         print(line)
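The new loop keeps apostrophes, so contractions survive the lowercase-and-strip-punctuation pass, and the double spaces left behind by punctuation removal now collapse to a single space instead of disappearing. A self-contained sketch of the updated cleaning step (the function name is illustrative):

import string

def lower_rm(line):
    # Mirror of the updated loop: lowercase, strip punctuation except
    # the apostrophe, then collapse double spaces.
    line = line.strip().lower()
    for w in string.punctuation:
        if w != "'":
            line = line.replace(w, "")
    line = line.replace("  ", " ")
    return line

assert lower_rm("Don't stop, believing!") == "don't stop believing"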
egs/wmt20/mt/run.sh

@@ -44,10 +44,10 @@ lcrm=1
 tokenizer=1

 use_specific_dict=1
-specific_prefix=asr5k_st10k
-specific_dir=${root_dir}/data/iwslt2022/st_lcrm_asr
-src_vocab_prefix=spm_unigram5000_asr
-tgt_vocab_prefix=spm_unigram10000_st
+specific_prefix=unified
+specific_dir=${root_dir}/data/wmt20/vocab
+src_vocab_prefix=spm_en
+tgt_vocab_prefix=spm_zh

 org_data_dir=${root_dir}/data/${dataset}
 data_dir=${root_dir}/data/${dataset}/mt

@@ -141,6 +141,9 @@ if [ ${stage} -le 0 ] && [ ${stop_stage} -ge 0 ]; then
    if [[ ! -e ${data_dir} ]]; then
        mkdir -p ${data_dir}
    fi
+    if [[ ! -e ${data_dir}/data ]]; then
+        mkdir -p ${data_dir}/data
+    fi

    if [[ ! -f ${data_dir}/${src_vocab_prefix}.txt || ! -f ${data_dir}/${tgt_vocab_prefix}.txt ]]; then
        if [[ ${use_specific_dict} -eq 0 ]]; then

@@ -154,52 +157,31 @@ if [ ${stage} -le 0 ] && [ ${stop_stage} -ge 0 ]; then
                --tgt-vocab-type ${tgt_vocab_type}
                --src-vocab-size ${src_vocab_size}
                --tgt-vocab-size ${tgt_vocab_size}"
-            if [[ $share_dict -eq 1 ]]; then
-                cmd="$cmd --share"
-            fi
-            echo -e "\033[34mRun command: \n${cmd}\033[0m"
-            [[ $eval -eq 1 ]] && eval ${cmd}
        else
            cp -r ${specific_dir}/${src_vocab_prefix}.* ${data_dir}
            cp ${specific_dir}/${tgt_vocab_prefix}.* ${data_dir}
        fi
+    fi
-
-    for split in ${train_subset} ${valid_subset} ${trans_subset}; do
-    {
-        if [[ -d ${org_data_dir}/data/${split}/txt ]]; then
-            text_dir=${org_data_dir}/data/${split}/txt
-        else
-            text_dir=${org_data_dir}/data/${split}
-        fi
-        src_text=${text_dir}/${split}.${src_lang}
-        tgt_text=${text_dir}/${split}.${tgt_lang}
-
-        cmd="cat ${src_text}"
-        if [[ ${lcrm} -eq 1 ]]; then
-            cmd="python local/lower_rm.py ${src_text}"
-        fi
-        cmd="${cmd} | spm_encode --model ${data_dir}/${src_vocab_prefix}.model --output_format=piece > ${data_dir}/data/${split}.${src_lang}"
-        echo -e "\033[34mRun command: \n${cmd}\033[0m"
-        [[ $eval -eq 1 ]] && eval ${cmd}
-
-        cmd="spm_encode --model ${data_dir}/${tgt_vocab_prefix}.model --output_format=piece < ${tgt_text} > ${data_dir}/data/${split}.${tgt_lang}"
-        echo -e "\033[34mRun command: \n${cmd}\033[0m"
-        [[ $eval -eq 1 ]] && eval ${cmd}
-    } &
-    done
-    wait
+    mkdir -p ${data_dir}/data
+    cmd="python ${code_dir}/examples/speech_to_text/prep_mt_data.py
+        --data-root ${org_data_dir}
+        --output-root ${data_dir}
+        --splits ${train_subset},${valid_subset},${trans_subset}
+        --src-lang ${src_lang}
+        --tgt-lang ${tgt_lang}
+        --src-vocab-prefix ${src_vocab_prefix}
+        --tgt-vocab-prefix ${tgt_vocab_prefix}"
+    if [[ $share_dict -eq 1 ]]; then
+        cmd="$cmd --share"
+    fi
+    if [[ ${lcrm} -eq 1 ]]; then
+        cmd="$cmd --lowercase-src --rm-punc-src"
+    fi
+    echo -e "\033[34mRun command: \n${cmd}\033[0m"
+    [[ $eval -eq 1 ]] && eval ${cmd}

    cmd="python ${code_dir}/fairseq_cli/preprocess.py
        --source-lang ${src_lang} --target-lang ${tgt_lang}
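The per-split "cat | lower_rm | spm_encode" subshell loop is replaced by a single call into prep_mt_data.py, which now receives the vocabulary prefixes and the lowercase/punctuation flags directly. A sketch in Python of the command string the script now assembles (variable values are example placeholders):

# Sketch of the command run.sh builds incrementally; values are illustrative.
code_dir, org_data_dir, data_dir = ".", "data/wmt20", "data/wmt20/mt"
share_dict, lcrm = 0, 1

cmd = (f"python {code_dir}/examples/speech_to_text/prep_mt_data.py"
       f" --data-root {org_data_dir}"
       f" --output-root {data_dir}"
       f" --splits train,dev,test"
       f" --src-lang en --tgt-lang zh"
       f" --src-vocab-prefix spm_en"
       f" --tgt-vocab-prefix spm_zh")
if share_dict:
    cmd += " --share"
if lcrm:
    cmd += " --lowercase-src --rm-punc-src"
print(cmd)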
examples/speech_to_text/prep_audio_data.py

@@ -296,6 +296,9 @@ def process(args):
                 gen_manifest_flag = True
                 break

+    punctuation_str = string.punctuation
+    punctuation_str = punctuation_str.replace("'", "")
+
     train_text = []
     if args.overwrite or gen_manifest_flag:
         if not use_raw:

@@ -340,7 +343,7 @@ def process(args):
                     if args.lowercase_src:
                         src_utt = src_utt.lower()
                     if args.rm_punc_src:
-                        for w in string.punctuation:
+                        for w in punctuation_str:
                             src_utt = src_utt.replace(w, "")
                         src_utt = " ".join(src_utt.split(" "))
                 else:

@@ -414,7 +417,7 @@ def process(args):
                 if args.lowercase_src:
                     src_utt = src_utt.lower()
                 if args.rm_punc_src:
-                    for w in string.punctuation:
+                    for w in punctuation_str:
                         src_utt = src_utt.replace(w, "")
                     src_utt = " ".join(src_utt.split(" "))
                 train_text.append(src_utt)
examples/speech_to_text/prep_mt_data.py

@@ -11,6 +11,7 @@ from pathlib import Path
 from tempfile import NamedTemporaryFile
 from typing import Tuple
 import string
+import sentencepiece as spm

 from examples.speech_to_text.data_utils import (
     gen_vocab,

@@ -62,6 +63,8 @@ def process(args):
     splits = args.splits.split(",")
     src_train_text = []
     tgt_train_text = []
+    manifest = {c: [] for c in MANIFEST_COLUMNS}
+    sent_num = [0]

     lang = f"{src_lang}-{tgt_lang}"
     cur_root = Path(args.data_root).absolute() / lang

@@ -70,20 +73,22 @@ def process(args):
     else:
         output_root = Path(args.output_root).absolute()

+    punctuation_str = string.punctuation
+    punctuation_str = punctuation_str.replace("'", "")
+
     # Generate TSV manifest
     print("Generating manifest...")
     for split in splits:
         is_train_split = split.startswith("train")
-        manifest = {c: [] for c in MANIFEST_COLUMNS}
         dataset = MTDataset(args.data_root, src_lang, tgt_lang, split, args.tokenizer)
         for src_text, tgt_text in tqdm(dataset):
             if args.lowercase_src:
                 src_text = src_text.lower()
             if args.rm_punc_src:
-                for w in string.punctuation:
+                for w in punctuation_str:
                     src_text = src_text.replace(w, "")
-                src_text = src_text.replace("  ", "")
+                src_text = src_text.replace("  ", " ")
             manifest["src_text"].append(src_text)
             manifest["tgt_text"].append(tgt_text)

@@ -94,34 +99,50 @@ def process(args):
         if is_train_split:
             src_train_text.extend(manifest["src_text"])
             tgt_train_text.extend(manifest["tgt_text"])
+        sent_num.append(len(manifest["src_text"]))

     # Generate vocab and yaml
+    print("Generating vocabulary...")
     tgt_v_size_str = "" if args.tgt_vocab_type == "char" else str(args.tgt_vocab_size)
     tgt_spm_filename_prefix = f"spm_{args.tgt_vocab_type}{tgt_v_size_str}"
     if args.share:
-        tgt_train_text.extend(src_train_text)
-        tgt_spm_filename_prefix = tgt_spm_filename_prefix + "_share"
+        if args.tgt_vocab_prefix is not None:
+            tgt_spm_filename_prefix = args.tgt_vocab_prefix
+        else:
+            tgt_train_text.extend(src_train_text)
+            tgt_spm_filename_prefix = tgt_spm_filename_prefix + "_share"
         src_spm_filename_prefix = tgt_spm_filename_prefix
     else:
-        src_v_size_str = "" if args.src_vocab_type == "char" else str(args.src_vocab_size)
-        src_spm_filename_prefix = f"spm_{args.src_vocab_type}{src_v_size_str}"
-        src_spm_filename_prefix = src_spm_filename_prefix + "_" + src_lang
-        tgt_spm_filename_prefix = tgt_spm_filename_prefix + "_" + tgt_lang
+        if args.tgt_vocab_prefix is not None:
+            tgt_spm_filename_prefix = args.tgt_vocab_prefix
+        else:
+            tgt_spm_filename_prefix = tgt_spm_filename_prefix + "_" + tgt_lang
+        if args.src_vocab_prefix is not None:
+            src_spm_filename_prefix = args.src_vocab_prefix
+        else:
+            src_v_size_str = "" if args.src_vocab_type == "char" else str(args.src_vocab_size)
+            src_spm_filename_prefix = f"spm_{args.src_vocab_type}{src_v_size_str}"
+            src_spm_filename_prefix = src_spm_filename_prefix + "_" + src_lang

+    src_spm_model = (output_root / (src_spm_filename_prefix + ".model")).as_posix()
+    tgt_spm_model = (output_root / (tgt_spm_filename_prefix + ".model")).as_posix()
+    if not os.path.exists(tgt_spm_model):
-    with NamedTemporaryFile(mode="w") as f:
-        for t in tgt_train_text:
-            f.write(t + "\n")
-        gen_vocab(
-            Path(f.name),
-            output_root / tgt_spm_filename_prefix,
-            args.tgt_vocab_type,
-            args.tgt_vocab_size,
-            normalization_rule_name="identity" if tgt_lang == "zh" else None
-        )
+        with NamedTemporaryFile(mode="w") as f:
+            for t in tgt_train_text:
+                f.write(t + "\n")
+            gen_vocab(
+                Path(f.name),
+                output_root / tgt_spm_filename_prefix,
+                args.tgt_vocab_type,
+                args.tgt_vocab_size,
+                normalization_rule_name="identity" if tgt_lang == "zh" else None
+            )
-    if not args.share:
+    if not args.share and not os.path.exists(src_spm_model):
         with NamedTemporaryFile(mode="w") as f:
             for t in src_train_text:
                 f.write(t + "\n")

@@ -133,6 +154,38 @@ def process(args):
                 normalization_rule_name="identity" if tgt_lang == "zh" else None
             )

+    # Generate sentencepiece
+    print("Applying sentencepiece...")
+    tgt_sp = spm.SentencePieceProcessor()
+    tgt_sp.Load(tgt_spm_model)
+    if args.share:
+        src_sp = tgt_sp
+    else:
+        src_sp = spm.SentencePieceProcessor()
+        src_sp.Load(src_spm_model)
+
+    index = 0
+    for split in splits:
+        src_text = manifest["src_text"][sent_num[index]: sent_num[index + 1]]
+        tgt_text = manifest["tgt_text"][sent_num[index]: sent_num[index + 1]]
+        index += 1
+        src_spm_name = (output_root / "data" / (split + "." + src_lang)).as_posix()
+        tgt_spm_name = (output_root / "data" / (split + "." + tgt_lang)).as_posix()
+
+        with open(src_spm_name, 'w') as f:
+            for sentence in src_text:
+                pieces = src_sp.EncodeAsPieces(sentence)
+                result = " ".join(pieces)
+                f.write(result + "\n")
+        with open(tgt_spm_name, 'w') as f:
+            for sentence in tgt_text:
+                pieces = tgt_sp.EncodeAsPieces(sentence)
+                result = " ".join(pieces)
+                f.write(result + "\n")
+
     # Generate config YAML
     yaml_filename = f"config.yaml"
     if args.share:

@@ -162,19 +215,19 @@ def main():
     parser.add_argument(
         "--src-vocab-type",
         default="unigram",
-        required=True,
         type=str,
         choices=["bpe", "unigram", "char"],
     )
     parser.add_argument(
         "--tgt-vocab-type",
         default="unigram",
-        required=True,
         type=str,
         choices=["bpe", "unigram", "char"],
     )
     parser.add_argument("--src-vocab-size", default=10000, type=int)
     parser.add_argument("--tgt-vocab-size", default=10000, type=int)
+    parser.add_argument("--src-vocab-prefix", default=None, type=str, help="prefix of the specific source vocabulary")
+    parser.add_argument("--tgt-vocab-prefix", default=None, type=str, help="prefix of the specific target vocabulary")
     parser.add_argument("--size", default=-1, type=int)
     parser.add_argument("--splits", default="train,dev,test", type=str)
     parser.add_argument("--lowercase-src", action="store_true", help="lowercase the source text")
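The new "Applying sentencepiece..." stage encodes each split with the trained (or supplied) models in-process instead of shelling out to spm_encode. A minimal standalone sketch of that step using the same sentencepiece calls as the diff (the model path and sentences are placeholders):

import sentencepiece as spm

# Load a trained model and write space-joined pieces, one sentence per line.
sp = spm.SentencePieceProcessor()
sp.Load("spm_en.model")  # placeholder path, not shipped with the repo

with open("train.en.spm", "w", encoding="utf-8") as f:
    for sentence in ["this is a test", "don't stop"]:
        pieces = sp.EncodeAsPieces(sentence)
        f.write(" ".join(pieces) + "\n")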
fairseq/criterions/ctc.py

@@ -9,6 +9,8 @@ from argparse import Namespace
 from dataclasses import dataclass, field
 from omegaconf import II
 from typing import Optional
+import numpy as np
+import logging

 import torch
 import torch.nn.functional as F

@@ -19,6 +21,7 @@ from fairseq.data.data_utils import post_process
 from fairseq.tasks import FairseqTask
 from fairseq.logging.meters import safe_round

+logger = logging.getLogger(__name__)

 @dataclass
 class CtcCriterionConfig(FairseqDataclass):

@@ -31,8 +34,8 @@ class CtcCriterionConfig(FairseqDataclass):
         default="sentencepiece",
         metadata={
             "help": "how to post process predictions into words. can be letter, "
             "wordpiece, BPE symbols, etc. "
             "See fairseq.data.data_utils.post_process() for full list of options"
         },
     )
     ctc_entropy: float = field(

@@ -43,6 +46,10 @@ class CtcCriterionConfig(FairseqDataclass):
         default=0.0,
         metadata={"help": "weight of intermedia CTC loss"},
     )
+    target_ctc_weight: float = field(
+        default=0.0,
+        metadata={"help": "weight of intermedia CTC loss for target sentence"},
+    )
     ctc_self_distill_weight: float = field(
         default=0.0,
         metadata={"help": "weight of the self distillation CTC loss"},

@@ -116,10 +123,12 @@ class CtcCriterion(FairseqCriterion):
         self.ctc_weight = ctc_weight
         self.intermedia_ctc_weight = cfg.intermedia_ctc_weight
+        self.target_ctc_weight = cfg.target_ctc_weight
         self.ctc_self_distill_weight = cfg.ctc_self_distill_weight
         self.ctc_entropy = cfg.ctc_entropy
-        self.all_ctc_weight = self.ctc_weight + self.intermedia_ctc_weight + self.ctc_self_distill_weight + self.ctc_entropy
+        self.all_ctc_weight = self.ctc_weight + self.intermedia_ctc_weight + self.target_ctc_weight + \
+                              self.ctc_self_distill_weight + self.ctc_entropy

         if self.all_ctc_weight > 0:
             assert getattr(task, "src_dict", None) is not None, "CTC need a source dictionary."
             self.ctc_loss = torch.nn.CTCLoss(blank=self.blank_idx, reduction="sum", zero_infinity=True)

@@ -145,7 +154,7 @@ class CtcCriterion(FairseqCriterion):
             non_padding_mask = ~net_output["ctc_padding_mask"][0]
         else:
             non_padding_mask = ~net_output["encoder_padding_mask"][0]
-        input_lengths = non_padding_mask.long().sum(-1)
+        ctc_input_lengths = input_lengths = non_padding_mask.long().sum(-1)

         pad_mask = (transcript["tokens"] != self.pad_idx) & (
             transcript["tokens"] != self.eos_idx

@@ -215,6 +224,43 @@ class CtcCriterion(FairseqCriterion):
             if lprobs is None:
                 lprobs = inter_lprobs

+        target_ctc_num = 0
+        target_ctc_loss = 0
+        if "target_ctc_logits" in net_output:
+            target_ctc_num = len(net_output["target_ctc_logits"])
+
+        # calculate the target CTC loss
+        if self.target_ctc_weight > 0 and target_ctc_num > 0:
+            target = sample["target"]
+            pad_mask = (target != self.pad_idx) & (target != self.eos_idx)
+            targets_flat = target.masked_select(pad_mask)
+            target_length = pad_mask.sum(-1)
+
+            for i in range(target_ctc_num):
+                out = net_output["target_ctc_logits"][i]
+                if type(out) == list:
+                    inter_ctc_logit = out[0]
+                    padding = ~out[1]
+                    input_lengths = padding.long().sum(-1)
+                else:
+                    inter_ctc_logit = out
+
+                inter_lprobs = model.get_normalized_probs(
+                    [inter_ctc_logit], log_probs=True
+                ).contiguous()  # (T, B, C) from the encoder
+                inter_lprobs.batch_first = False
+
+                with torch.backends.cudnn.flags(enabled=False):
+                    loss = self.ctc_loss(
+                        inter_lprobs,
+                        targets_flat,
+                        ctc_input_lengths,
+                        target_length,
+                    )
+                target_ctc_loss += loss
+            target_ctc_loss /= target_ctc_num
+            logging_output["target_ctc_loss"] = utils.item(target_ctc_loss.data)
+
         # calculate the self distillation CTC loss
         ctc_self_distill_loss = 0
         ctc_self_distill_num = 0

@@ -247,6 +293,7 @@ class CtcCriterion(FairseqCriterion):
         loss = \
             self.ctc_weight * ctc_loss + \
             self.intermedia_ctc_weight * intermedia_ctc_loss + \
+            self.target_ctc_weight * target_ctc_loss + \
             self.ctc_self_distill_weight * ctc_self_distill_loss + \
             self.ctc_entropy * ctc_entropy

@@ -264,9 +311,9 @@ class CtcCriterion(FairseqCriterion):
             w_len = 0
             wv_errs = 0
             for lp, t, inp_l in zip(
                 lprobs_t,
                 sample["transcript"]["tokens"] if "transcript" in sample else sample["target"],
                 input_lengths,
             ):
                 lp = lp[:inp_l].unsqueeze(0)

@@ -283,7 +330,7 @@ class CtcCriterion(FairseqCriterion):
                 decoded = decoded[0]
                 p = (t != self.task.target_dictionary.pad()) & (
                     t != self.task.target_dictionary.eos()
                 )
                 targ = t[p]
                 targ_units = self.task.target_dictionary.string(targ)

@@ -332,6 +379,9 @@ class CtcCriterion(FairseqCriterion):
         inter_ctc_loss_sum = utils.item(
             sum(log.get("intermedia_ctc_loss", 0) for log in logging_outputs)
         )
+        target_ctc_loss_sum = utils.item(
+            sum(log.get("target_ctc_loss", 0) for log in logging_outputs)
+        )
         ctc_self_distill_loss_sum = utils.item(
             sum(log.get("ctc_self_distill_loss", 0) for log in logging_outputs)
         )

@@ -346,6 +396,9 @@ class CtcCriterion(FairseqCriterion):
         sample_size = utils.item(
             sum(log.get("sample_size", 0) for log in logging_outputs)
         )
+        if np.isnan(all_ctc_loss_sum) or np.isinf(all_ctc_loss_sum) or all_ctc_loss_sum < 0:
+            logger.error("Illegal loss %f!" % all_ctc_loss_sum)
         if all_ctc_loss_sum > 0:
             if "loss" not in logging_outputs[0]:
                 metrics.log_scalar(

@@ -383,6 +436,14 @@ class CtcCriterion(FairseqCriterion):
                 sample_size,
                 round=3,
             )
+        if target_ctc_loss_sum > 0:
+            metrics.log_scalar(
+                "target_ctc_loss",
+                target_ctc_loss_sum / sample_size / math.log(2),
+                sample_size,
+                round=3,
+            )
         if ctc_self_distill_loss_sum > 0:
             metrics.log_scalar(
                 "ctc_self_distill_loss",

@@ -404,8 +465,8 @@ class CtcCriterion(FairseqCriterion):
         metrics.log_scalar("_c_total", c_total)
         w_errors = sum(log.get("w_errors", 0) for log in logging_outputs)
         metrics.log_scalar("_w_errors", w_errors)
-        wv_errors = sum(log.get("wv_errors", 0) for log in logging_outputs)
-        metrics.log_scalar("_wv_errors", wv_errors)
+        # wv_errors = sum(log.get("wv_errors", 0) for log in logging_outputs)
+        # metrics.log_scalar("_wv_errors", wv_errors)
         w_total = sum(log.get("w_total", 0) for log in logging_outputs)
         metrics.log_scalar("_w_total", w_total)

@@ -427,14 +488,14 @@ class CtcCriterion(FairseqCriterion):
                 if meters["_w_total"].sum > 0
                 else float("nan"),
             )
-            metrics.log_derived(
-                "raw_wer",
-                lambda meters: safe_round(
-                    meters["_wv_errors"].sum * 100.0 / meters["_w_total"].sum, 3
-                )
-                if meters["_w_total"].sum > 0
-                else float("nan"),
-            )
+            # metrics.log_derived(
+            #     "raw_wer",
+            #     lambda meters: safe_round(
+            #         meters["_wv_errors"].sum * 100.0 / meters["_w_total"].sum, 3
+            #     )
+            #     if meters["_w_total"].sum > 0
+            #     else float("nan"),
+            # )

     @staticmethod
     def logging_outputs_can_be_summed() -> bool:
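The criterion's objective is now a weighted sum of five terms, with the new target CTC computed by the same torch.nn.CTCLoss, applied over encoder-side frames against target-language tokens. A minimal sketch of the shapes and the combination (all sizes and weights are illustrative, not values from the repo):

import torch

T, B, V = 50, 4, 100                                    # frames, batch, target vocab (illustrative)
ctc = torch.nn.CTCLoss(blank=0, reduction="sum", zero_infinity=True)

log_probs = torch.randn(T, B, V).log_softmax(-1)        # (T, B, V), the layout CTCLoss expects
input_lengths = torch.full((B,), T, dtype=torch.long)   # unpadded frames per sentence
targets = torch.randint(1, V, (B, 10))                  # target-language token ids
target_lengths = torch.full((B,), 10, dtype=torch.long)

target_ctc_loss = ctc(log_probs, targets, input_lengths, target_lengths)

# The final objective mirrors the weighted sum in the diff; unused terms are zero here.
ctc_weight, intermedia_w, target_w, distill_w, entropy_w = 0.3, 0.0, 0.2, 0.0, 0.0
zero = torch.tensor(0.0)
loss = (ctc_weight * zero                 # main (source-side) CTC loss
        + intermedia_w * zero             # intermedia CTC loss
        + target_w * target_ctc_loss      # new: target-side CTC loss
        + distill_w * zero                # self-distillation CTC loss
        + entropy_w * zero)               # CTC entropy term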
fairseq/models/speech_to_text/modules/ctc.py

@@ -17,6 +17,7 @@ class CTC(nn.Module):
     def __init__(self, embed_dim, dictionary_size, dropout, need_layernorm=False):
         super(CTC, self).__init__()

+        self.embed_dim = embed_dim
         self.ctc_projection = nn.Linear(embed_dim, dictionary_size, bias=False)

         nn.init.normal_(
fairseq/models/speech_to_text/pdss2t_transformer.py

@@ -232,6 +232,7 @@ class PDSS2TTransformerModel(S2TTransformerModel):
             "rope",
             "abs",
             "transfer",
+            "reduced_rel_pos",
         ],
         help="transformer encoder self-attention layer type"
     )

@@ -579,6 +580,12 @@ class PDSS2TTransformerModel(S2TTransformerModel):
             type=float,
             help="probability of dropping the followed layers",
         )
+        parser.add_argument(
+            "--intermedia-temperature",
+            default=1,
+            type=float,
+            help="temperature of the intermedia ctc probability",
+        )
         pass

     @classmethod

@@ -626,10 +633,7 @@ class PDSS2TTransformerEncoder(FairseqEncoder):
         self.pds_position_embed = [int(n) for n in args.pds_position_embed.split("_")]
         self.pds_attn_heads = [int(n) for n in args.pds_attn_heads.split("_")]
         self.pds_ffn_ratios = [int(n) for n in args.pds_ffn_ratios.split("_")]
-        if self.attn_type == "reduced":
-            self.pds_attn_ds_ratios = [int(n) for n in args.pds_attn_ds_ratios.split("_")]
-        else:
-            self.pds_attn_ds_ratios = None
+        self.pds_attn_ds_ratios = [int(n) for n in args.pds_attn_ds_ratios.split("_")]
         self.pds_conv_strides = [int(n) for n in args.pds_conv_strides.split("_")]
         self.pds_attn_strides = [int(n) for n in args.pds_attn_strides.split("_")]

@@ -674,7 +678,7 @@ class PDSS2TTransformerEncoder(FairseqEncoder):
             ffn_ratio = self.pds_ffn_ratios[i]
             num_head = self.pds_attn_heads[i]
-            attn_ds_ratio = self.pds_attn_ds_ratios[i] if self.attn_type == "reduced" else -1
+            attn_ds_ratio = self.pds_attn_ds_ratios[i]  # if self.attn_type == "reduced" else -1
             conv_stride = self.pds_conv_strides[i]
             attn_stride = self.pds_attn_strides[i]

             if conv_stride != 1 or attn_stride != 1:

@@ -712,7 +716,7 @@ class PDSS2TTransformerEncoder(FairseqEncoder):
             # position encoding
             if use_pos_embed:
-                if self.attn_type == "rel_pos":
+                if self.attn_type in ["rel_pos", "reduced_rel_pos"]:
                     pos_embed = RelPositionalEncoding(
                         args.max_source_positions, embed_dim
                     )

@@ -850,7 +854,7 @@ class PDSS2TTransformerEncoder(FairseqEncoder):
                 if ctc_layer <= 0:
                     embed_dim = self.pds_embed_dims[i]
                     break
-            if inter_ctc_module is None:
+            if inter_ctc_module is None or embed_dim != inter_ctc_module.embed_dim:
                 self.ctc = CTC(embed_dim,
                                dictionary_size=len(task.source_dictionary),
                                dropout=args.dropout,

@@ -866,6 +870,7 @@ class PDSS2TTransformerEncoder(FairseqEncoder):
         else:
             self.layer_norm = None

+        self.intermedia_temperature = getattr(args, "intermedia_temperature", 1)
         self.gather_cos_sim = getattr(args, "gather_cos_sim", False)
         self.dis = 2
         self.cos_sim = dict()

@@ -933,7 +938,7 @@ class PDSS2TTransformerEncoder(FairseqEncoder):
         # add the position encoding and dropout
         if pos_embed:
-            if self.attn_type in ["rel_pos", "rel_pos_legacy", "rel_selfattn"]:
+            if self.attn_type in ["rel_pos", "rel_pos_legacy", "rel_selfattn", "reduced_rel_pos"]:
                 positions = pos_embed(x)
             elif self.attn_type == "rope":

@@ -981,7 +986,7 @@ class PDSS2TTransformerEncoder(FairseqEncoder):
                 logit = ctc(x.clone())
                 intermedia_ctc_logits.append([logit, encoder_padding_mask])
-                prob = utils.softmax(logit, dim=-1)
+                prob = utils.softmax(logit / self.intermedia_temperature, dim=-1)
                 x, encoder_padding_mask = adapter([x, prob], encoder_padding_mask)

         if self.fusion_stages_num != 0:

@@ -1131,9 +1136,9 @@ def base_architecture(args):
     args.pds_position_embed = getattr(args, "pds_position_embed", None)
     args.pds_attn_heads = getattr(args, "pds_attn_heads", None)
-    args.pds_attn_ds_ratios = getattr(args, "pds_attn_ds_ratios", None)
     args.pds_ffn_ratios = getattr(args, "pds_ffn_ratios", None)
+    args.pds_attn_ds_ratios = getattr(args, "pds_attn_ds_ratios", "1_1_1_1")
     args.pds_conv_strides = getattr(args, "pds_conv_strides", "1_1_1_1")
     args.pds_attn_strides = getattr(args, "pds_attn_strides", "1_1_1_1")
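The new --intermedia-temperature flag rescales the intermediate CTC logits before the softmax whose output conditions the adapter; temperatures above 1 flatten that distribution. A two-line illustration of the change in behaviour:

import torch
import torch.nn.functional as F

logit = torch.randn(8, 2, 100)               # (frames, batch, vocab), illustrative

prob_default = F.softmax(logit, dim=-1)      # old behaviour (temperature = 1)
prob_soft = F.softmax(logit / 5.0, dim=-1)   # e.g. intermedia-temperature: 5

# A higher temperature flattens the distribution the adapter conditions on.
assert prob_soft.max() <= prob_default.max() + 1e-6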
fairseq/models/speech_to_text/s2t_ctc.py

@@ -118,6 +118,7 @@ class S2TCTCModel(FairseqEncoderModel):
             "rope",
             "abs",
             "transfer",
+            "reduced_rel_pos",
         ],
         help="transformer encoder self-attention layer type"
     )

@@ -739,9 +740,9 @@ def base_architecture(args):
     args.pds_position_embed = getattr(args, "pds_position_embed", None)
     args.pds_attn_heads = getattr(args, "pds_attn_heads", None)
-    args.pds_attn_ds_ratios = getattr(args, "pds_attn_ds_ratios", None)
     args.pds_ffn_ratios = getattr(args, "pds_ffn_ratios", None)
+    args.pds_attn_ds_ratios = getattr(args, "pds_attn_ds_ratios", "1_1_1_1")
     args.pds_conv_strides = getattr(args, "pds_conv_strides", "1_1_1_1")
     args.pds_attn_strides = getattr(args, "pds_attn_strides", "1_1_1_1")
fairseq/models/speech_to_text/s2t_sate.py

@@ -4,7 +4,7 @@ import math
 import torch
 import torch.nn as nn
 import torch.nn.functional as F

-from fairseq import checkpoint_utils
+from fairseq import checkpoint_utils, utils
 from fairseq.models import (
     FairseqEncoder,
     register_model,

@@ -16,7 +16,7 @@ from fairseq.models.speech_to_text import (
     PDSS2TTransformerModel,
     PDSS2TTransformerEncoder,
 )
-from fairseq.models.speech_to_text.modules import CTCCompressStrategy, Adapter
+from fairseq.models.speech_to_text.modules import Adapter, CTC
 from fairseq.modules import (
     FairseqDropout,
     LayerNorm,

@@ -88,6 +88,12 @@ class S2TSATEModel(S2TTransformerModel):
             help="the architecture of the acoustic encoder",
         )
+        parser.add_argument(
+            "--target-ctc-layers",
+            default=None,
+            type=str,
+            help="ctc layers for target sentence",
+        )
         parser.add_argument(
             "--load-pretrained-acoustic-encoder-from",
             type=str,
             metavar="STR",

@@ -138,113 +144,15 @@ class S2TSATEModel(S2TTransformerModel):
         return encoder

-# class Adapter(nn.Module):
-#     def __init__(self, args, dictionary, embed_tokens):
-#         super().__init__()
-#
-#         embed_dim = args.encoder_embed_dim
-#
-#         self.adapter_type = args.adapter
-#         if self.adapter_type in ["linear", "league", "gated_league", "gated_league2"]:
-#             self.linear_adapter = nn.Sequential(
-#                 nn.Linear(embed_dim, embed_dim),
-#                 LayerNorm(args.encoder_embed_dim),
-#                 nn.ReLU(),
-#             )
-#         elif self.adapter_type == "linear2":
-#             self.linear_adapter = nn.Sequential(
-#                 nn.Linear(embed_dim, embed_dim),
-#             )
-#
-#         if self.adapter_type in ["embed", "context", "league", "gated_league", "gated_league2"]:
-#             if embed_tokens is None:
-#                 num_embeddings = len(dictionary)
-#                 self.embed_adapter = Embedding(num_embeddings, embed_dim, dictionary.pad())
-#             else:
-#                 self.embed_adapter = embed_tokens
-#
-#         if self.adapter_type == "gated_league":
-#             self.gate_linear = nn.Linear(2 * embed_dim, embed_dim)
-#         elif self.adapter_type == "gated_league2":
-#             self.gate_linear1 = nn.Linear(embed_dim, embed_dim)
-#             self.gate_linear2 = nn.Linear(embed_dim, embed_dim)
-#
-#         if self.adapter_type == "shrink":
-#             self.ctc_compress_method = getattr(CTCCompressStrategy, args.ctc_compress_strategy)
-#
-#     def forward(self, x, padding):
-#
-#         representation, distribution = x
-#         batch, seq_len, embed_dim = representation.size()
-#         org_distribution = distribution
-#         if distribution is not None:
-#             distribution = distribution.view(-1, distribution.size(-1))
-#         lengths = (~padding).long().sum(-1)
-#
-#         if self.adapter_type == "linear":
-#             out = self.linear_adapter(representation)
-#
-#         elif self.adapter_type == "context":
-#             out = torch.mm(
-#                 distribution, self.embed_adapter.weight.float()
-#             ).view(batch, seq_len, -1).type_as(representation)
-#
-#         elif self.adapter_type == "league":
-#             linear_out = self.linear_adapter(representation)
-#             soft_out = torch.mm(
-#                 distribution, self.embed_adapter.weight.float()
-#             ).view(batch, seq_len, -1).type_as(linear_out)
-#             out = linear_out + soft_out
-#
-#         elif self.adapter_type == "gated_league":
-#             linear_out = self.linear_adapter(representation)
-#             soft_out = torch.mm(
-#                 distribution, self.embed_adapter.weight.float()
-#             ).view(batch, seq_len, -1).type_as(linear_out)
-#             coef = (self.gate_linear(torch.cat([linear_out, soft_out], dim=-1))).sigmoid()
-#             out = coef * linear_out + (1 - coef) * soft_out
-#
-#         elif self.adapter_type == "none":
-#             out = representation
-#
-#         elif self.adapter_type == "shrink":
-#             from itertools import groupby
-#
-#             with torch.no_grad():
-#                 batch_predicted = []
-#                 prob_ctc = org_distribution.transpose(0, 1)  # T x B x D -> B x T x D
-#                 for b in range(prob_ctc.shape[0]):
-#                     predicted = prob_ctc[b][: lengths[b]].argmax(-1).tolist()
-#                     batch_predicted.append([(p[0], len(list(p[1]))) for p in groupby(predicted)])
-#
-#                 new_lengths = [len(p) for p in batch_predicted]
-#                 weights_matrix = self.ctc_compress_method(prob_ctc, batch_predicted, new_lengths,
-#                                                           prob_ctc.dtype, prob_ctc.device)
-#
-#             # x is T x B x C -> B x C x T; weights_matrix is B x T x T'
-#             data_type = representation.dtype
-#             representation = representation.permute(1, 2, 0).float()
-#             compressed_output = representation.bmm(weights_matrix).type_as(data_type)  # B x C x T'
-#             out = compressed_output.permute(2, 0, 1)
-#
-#             out_lengths = lengths.new(new_lengths)
-#             padding = lengths_to_padding_mask(out_lengths)
-#
-#         else:
-#             out = None
-#             logging.error("Unsupported adapter type: {}.".format(self.adapter_type))
-#
-#         return out, padding

 class TextEncoder(FairseqEncoder):

-    def __init__(self, args, dictionary):
+    def __init__(self, args, dictionary, embed_tokens=None):

         super().__init__(None)

+        self.embed_tokens = None
         embed_dim = args.encoder_embed_dim
+        layer_num = args.text_encoder_layers
+        self.layer_num = layer_num
         self.embed_scale = math.sqrt(embed_dim)
         if args.no_scale_embedding:
             self.embed_scale = 1.0

@@ -259,13 +167,44 @@ class TextEncoder(FairseqEncoder):
         )
         self.layers = nn.ModuleList(
-            [TransformerEncoderLayer(args) for _ in range(args.text_encoder_layers)]
+            [TransformerEncoderLayer(args) for _ in range(layer_num)]
         )
         if args.encoder_normalize_before:
             self.layer_norm = LayerNorm(args.encoder_embed_dim)
         else:
             self.layer_norm = None

+        self.intermedia_ctc_layers = []
+        self.target_ctc_layers = getattr(args, "target_ctc_layers", None)
+        if self.target_ctc_layers is not None:
+            intermedia_ctc_layers = self.target_ctc_layers.split(",")
+            for layer_idx in intermedia_ctc_layers:
+                layer_idx = int(layer_idx)
+                assert layer_idx <= layer_num, (layer_idx, layer_num)
+                if layer_idx <= 0:
+                    layer_idx += layer_num
+                self.intermedia_ctc_layers.append(layer_idx)
+                logger.info("Intermedia CTC loss in layer %d" % layer_idx)
+            self.ctc = CTC(embed_dim,
+                           dictionary_size=len(dictionary),
+                           dropout=args.dropout)
+            if embed_tokens is not None:
+                self.ctc.ctc_projection.weight = embed_tokens.weight
+
+            strategy = None
+            if args.intermedia_adapter == "shrink":
+                strategy = getattr(args, "ctc_compress_strategy", None)
+            elif args.intermedia_adapter == "league":
+                strategy = getattr(args, "intermedia_distribution_cutoff", None)
+            self.adapter = Adapter(embed_dim, args.intermedia_adapter,
+                                   dictionary, strategy=strategy)
+            self.intermedia_drop_prob = getattr(args, "intermedia_drop_prob", 0)
+            self.intermedia_temperature = getattr(args, "intermedia_temperature", 1)

     def forward(self, x, encoder_padding_mask=None, history=None):

         x = self.embed_scale * x

@@ -273,10 +212,28 @@ class TextEncoder(FairseqEncoder):
         x = positions + x
         x = self.dropout_module(x)

+        target_ctc_logits = []
+        layer_idx = 0
         for layer in self.layers:
+            layer_idx += 1
             if history is not None:
                 x = history.pop()
             x = layer(x, encoder_padding_mask, pos_emb=positions)

+            if layer_idx != self.layer_num and layer_idx in self.intermedia_ctc_layers:
+                if self.intermedia_drop_prob > 0:
+                    p = torch.rand(1).uniform_()
+                    if p < self.intermedia_drop_prob:
+                        break
+
+                norm_x = self.layer_norm(x)
+                logit = self.ctc(norm_x)
+                target_ctc_logits.append(logit)
+
+                prob = utils.softmax(logit / self.intermedia_temperature, dim=-1)
+                x, encoder_padding_mask = self.adapter([x, prob], encoder_padding_mask)
+
             if history is not None:
                 history.push(x)

@@ -286,7 +243,11 @@ class TextEncoder(FairseqEncoder):
         if self.layer_norm is not None:
             x = self.layer_norm(x)

-        return x
+        if layer_idx in self.intermedia_ctc_layers:
+            logit = self.ctc(x)
+            target_ctc_logits.append(logit)
+
+        return x, target_ctc_logits

 class S2TSATEEncoder(FairseqEncoder):

@@ -327,7 +288,7 @@ class S2TSATEEncoder(FairseqEncoder):
         args.encoder_attention_type = args.text_attention_type

         # text encoder
-        self.text_encoder = TextEncoder(args, task.source_dictionary)
+        self.text_encoder = TextEncoder(args, task.source_dictionary, embed_tokens)

         args.encoder_attention_type = acoustic_encoder_attention_type

@@ -367,12 +328,13 @@ class S2TSATEEncoder(FairseqEncoder):
             self.history.push(x)

-        x = self.text_encoder(x, encoder_padding_mask, self.history)
+        x, target_ctc_logits = self.text_encoder(x, encoder_padding_mask, self.history)

         return {
             "encoder_out": [x],  # T x B x C
             "ctc_logit": [ctc_logit],  # T x B x C
             "intermedia_ctc_logits": acoustic_encoder_out.get("intermedia_ctc_logits", []),  # B x T x C
+            "target_ctc_logits": target_ctc_logits,  # B x T x C
             "ctc_padding_mask": [ctc_padding_mask],  # B x T
             "encoder_padding_mask": [encoder_padding_mask],  # B x T
             "encoder_embedding": [],  # B x T x C

@@ -490,15 +452,23 @@ def base_architecture(args):
     args.pds_position_embed = getattr(args, "pds_position_embed", None)
     args.pds_attn_heads = getattr(args, "pds_attn_heads", None)
-    args.pds_attn_ds_ratios = getattr(args, "pds_attn_ds_ratios", None)
     args.pds_ffn_ratios = getattr(args, "pds_ffn_ratios", None)
+    args.pds_attn_ds_ratios = getattr(args, "pds_attn_ds_ratios", "1_1_1_1")
+    args.pds_conv_strides = getattr(args, "pds_conv_strides", "1_1_1_1")
+    args.pds_attn_strides = getattr(args, "pds_attn_strides", "1_1_1_1")

     args.ctc_layer = getattr(args, "ctc_layer", 0)
     args.pds_dropout = getattr(args, "pds_dropout", args.dropout)

     args.pds_fusion = getattr(args, "pds_fusion", False)
     args.pds_fusion_method = getattr(args, "pds_fusion_method", "all_conv")

+    # intermedia CTC
+    args.pds_ctc = getattr(args, "pds_ctc", "0_0_0_0")
+    args.intermedia_adapter = getattr(args, "intermedia_adapter", "none")
+    args.intermedia_drop_prob = getattr(args, "intermedia_drop_prob", 0)

 @register_model_architecture("s2t_sate", "s2t_sate_s")
 def s2t_sate_s(args):
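A detail worth noting in the TextEncoder changes: when the decoder's token embedding is passed in, the target-CTC projection is tied to it, so the CTC head and the decoder share one parameterization of the target vocabulary. A minimal sketch of that tying:

import torch.nn as nn

# nn.Embedding(V, D).weight and nn.Linear(D, V).weight are both (V, D),
# so assigning one to the other makes them a single shared parameter.
vocab_size, embed_dim = 10000, 256
embed_tokens = nn.Embedding(vocab_size, embed_dim)
ctc_projection = nn.Linear(embed_dim, vocab_size, bias=False)
ctc_projection.weight = embed_tokens.weight

assert ctc_projection.weight.data_ptr() == embed_tokens.weight.data_ptr()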
fairseq/models/speech_to_text/s2t_transformer.py

@@ -395,6 +395,12 @@ class S2TTransformerModel(FairseqEncoderDecoderModel):
             type=float,
             help="probability of dropping the followed layers",
         )
+        parser.add_argument(
+            "--intermedia-temperature",
+            default=1,
+            type=float,
+            help="temperature of the intermedia ctc probability",
+        )
         pass

     @classmethod

@@ -585,6 +591,7 @@ class S2TTransformerEncoder(FairseqEncoder):
             self.adapter = Adapter(dim, args.intermedia_adapter,
                                    task.source_dictionary, strategy=strategy)
             self.intermedia_drop_prob = getattr(args, "intermedia_drop_prob", 0)
+            self.intermedia_temperature = getattr(args, "intermedia_temperature", 1)

     @staticmethod
     def pooling_ratio():

@@ -683,7 +690,7 @@ class S2TTransformerEncoder(FairseqEncoder):
                 intermedia_ctc_logits.append(logit)
                 # prob = self.ctc.softmax(norm_x)
-                prob = utils.softmax(logit, dim=-1)
+                prob = utils.softmax(logit / self.intermedia_temperature, dim=-1)
                 x, encoder_padding_mask = self.adapter([x, prob], encoder_padding_mask)

         # gather cosine similarity
fairseq/modules/__init__.py

@@ -54,6 +54,7 @@ from .positional_encoding import (
 from .espnet_multihead_attention import (
     ESPNETMultiHeadedAttention,
     RelPositionMultiHeadedAttention,
+    ReducedRelPositionMultiHeadedAttention,
     LegacyRelPositionMultiHeadedAttention,
     RotaryPositionMultiHeadedAttention,
 )

@@ -113,10 +114,11 @@ __all__ = [
     "unfold1d",
     "ESPNETMultiHeadedAttention",
     "PositionalEmbedding",
-    "RelPositionMultiHeadedAttention",
     "PositionalEncoding",
     "LegacyRelPositionalEncoding",
     "RelPositionalEncoding",
+    "RelPositionMultiHeadedAttention",
+    "ReducedRelPositionMultiHeadedAttention",
     "LegacyRelPositionMultiHeadedAttention",
     "RotaryPositionalEmbedding",
     "RotaryPositionMultiHeadedAttention",
fairseq/modules/attention.py

@@ -1347,7 +1347,7 @@ class MultiHeadSelfAttentionModule(nn.Module):
         # Assert
         assert not (group_size > 1 and kernel_size is not None), "Local grouped attention not implemented"
-        assert not (group_size > 1 and stride > 1 is not None), "Strided grouped attention not implemented"
+        assert not (group_size > 1 and stride > 1), "Strided grouped attention not implemented"
         assert not (linear_att and relative_pos_enc), "Linear attention requires absolute positional encodings"

         # Pre Norm
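Note on this fix: the old expression relied on Python's comparison chaining, since stride > 1 is not None parses as (stride > 1) and (1 is not None), which always reduces to stride > 1; the behaviour was therefore correct only by accident, and the rewrite states the intended condition directly.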
fairseq/modules/espnet_multihead_attention.py
查看文件 @
67d8695f
...
@@ -14,6 +14,7 @@ from fairseq.modules.rotary_positional_embedding import (
...
@@ -14,6 +14,7 @@ from fairseq.modules.rotary_positional_embedding import (
RotaryPositionalEmbedding
,
RotaryPositionalEmbedding
,
apply_rotary_pos_emb
,
apply_rotary_pos_emb
,
)
)
from
.layer_norm
import
LayerNorm
class
ESPNETMultiHeadedAttention
(
nn
.
Module
):
class
ESPNETMultiHeadedAttention
(
nn
.
Module
):
...
@@ -72,6 +73,7 @@ class ESPNETMultiHeadedAttention(nn.Module):
         if mask is not None:
             scores = scores.masked_fill(
                 mask.unsqueeze(1).unsqueeze(2).to(bool),
+                # -1e8 if scores.dtype == torch.float32 else -1e4
                 float("-inf"),  # (batch, head, time1, time2)
             )
         self.attn = F.softmax(scores, dim=-1, dtype=torch.float32).type_as(scores)  # (batch, head, time1, time2)
...
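The new comment records the finite mask values (-1e8 for fp32, -1e4 for fp16) that some implementations substitute for float("-inf"): with -inf, a row whose keys are all masked turns into NaNs after the softmax, whereas a large finite negative degrades to a uniform row instead. A standalone illustration (plain PyTorch, not the fairseq code path):

import torch
import torch.nn.functional as F

scores = torch.zeros(1, 3)
mask = torch.tensor([[True, True, True]])  # a fully padded (all-masked) row

print(F.softmax(scores.masked_fill(mask, float("-inf")), dim=-1))  # nan, nan, nan
print(F.softmax(scores.masked_fill(mask, -1e8), dim=-1))           # uniform 1/3 each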
@@ -195,6 +197,131 @@ class RelPositionMultiHeadedAttention(ESPNETMultiHeadedAttention):
...
@@ -195,6 +197,131 @@ class RelPositionMultiHeadedAttention(ESPNETMultiHeadedAttention):
return
scores
,
None
return
scores
,
None
class
ReducedRelPositionMultiHeadedAttention
(
RelPositionMultiHeadedAttention
):
"""Multi-Head Attention layer with relative position encoding.
Paper: https://arxiv.org/abs/1901.02860
Args:
n_head: The number of heads.
n_feat: The number of features.
dropout: Dropout rate.
zero_triu: Whether to zero the upper triangular part of attention matrix.
"""
def
__init__
(
self
,
n_feat
,
n_head
,
dropout
,
zero_triu
=
False
,
sample_ratio
=
1
,
reduced_method
=
"conv"
,
reduced_q
=
False
,
):
"""Construct an RelPositionMultiHeadedAttention object."""
super
()
.
__init__
(
n_feat
,
n_head
,
dropout
)
self
.
zero_triu
=
zero_triu
# linear transformation for positional encoding
self
.
linear_pos
=
nn
.
Linear
(
n_feat
,
n_feat
,
bias
=
False
)
# these two learnable bias are used in matrix c and matrix d
# as described in https://arxiv.org/abs/1901.02860 Section 3.3
self
.
pos_bias_u
=
nn
.
Parameter
(
torch
.
Tensor
(
self
.
h
,
self
.
d_k
))
self
.
pos_bias_v
=
nn
.
Parameter
(
torch
.
Tensor
(
self
.
h
,
self
.
d_k
))
torch
.
nn
.
init
.
xavier_uniform_
(
self
.
pos_bias_u
)
torch
.
nn
.
init
.
xavier_uniform_
(
self
.
pos_bias_v
)
super
()
.
__init__
(
n_feat
,
n_head
,
dropout
,
zero_triu
)
self
.
sample_ratio
=
sample_ratio
self
.
reduced_method
=
reduced_method
self
.
reduced_q
=
reduced_q
if
reduced_q
:
assert
self
.
reduced_method
==
'group'
,
"only support grouped method for query reduction"
if
self
.
sample_ratio
>
1
:
if
reduced_method
==
"conv"
:
self
.
sr
=
nn
.
Conv1d
(
n_feat
,
n_feat
,
kernel_size
=
sample_ratio
,
stride
=
sample_ratio
,
)
self
.
norm
=
LayerNorm
(
n_feat
)
elif
reduced_method
==
"pool"
:
self
.
linear
=
nn
.
Linear
(
n_feat
,
n_feat
)
self
.
norm
=
LayerNorm
(
n_feat
)
self
.
act
=
nn
.
GELU
()
elif
reduced_method
==
"group"
:
pass
def
forward
(
self
,
query
,
key
,
value
,
pos_emb
,
key_padding_mask
=
None
,
**
kwargs
):
"""Compute scaled dot product attention.
Args:
query: Query tensor T X B X C
key: Key tensor T X B X C
value: Value tensor T X B X C
pos_emb: Positional embedding tensor 2T-1 X B(1) X C
key_padding_mask: Mask tensor T X B
Returns:
torch.Tensor: Output tensor T X B X C.
"""
# (bsz, seq_len, dim)
query
=
query
.
transpose
(
0
,
1
)
key
=
key
.
transpose
(
0
,
1
)
value
=
value
.
transpose
(
0
,
1
)
pos_emb
=
pos_emb
.
transpose
(
0
,
1
)
tgt_len
=
query
.
size
(
1
)
query_
=
query
if
self
.
sample_ratio
>
1
:
assert
tgt_len
%
self
.
sample_ratio
==
0
,
\
(
"sample ratio
%
d is mismatched with length
%
d"
%
(
self
.
sample_ratio
,
tgt_len
))
if
self
.
reduced_method
==
"conv"
:
query_
=
query
.
transpose
(
1
,
2
)
# bsz, dim, seq_len
query_
=
self
.
sr
(
query_
)
.
transpose
(
1
,
2
)
# bsz, seq_len, dim
query_
=
self
.
norm
(
query_
)
elif
self
.
reduced_method
==
"pool"
:
query_
=
query
.
transpose
(
1
,
2
)
# bsz, dim, seq_len
pool_length
=
int
(
tgt_len
/
self
.
sample_ratio
)
query_
=
nn
.
functional
.
adaptive_max_pool1d
(
query_
,
pool_length
)
.
transpose
(
1
,
2
)
query_
=
self
.
act
(
self
.
norm
(
query_
))
key
=
value
=
query_
if
key_padding_mask
is
not
None
:
key_padding_mask
=
key_padding_mask
[:,
::
self
.
sample_ratio
]
n_batch
=
query
.
size
(
0
)
q
=
self
.
linear_q
(
query
)
.
view
(
n_batch
,
-
1
,
self
.
h
,
self
.
d_k
)
k
=
self
.
linear_k
(
key
)
.
view
(
n_batch
,
-
1
,
self
.
h
,
self
.
d_k
)
v
=
self
.
linear_v
(
value
)
.
view
(
n_batch
,
-
1
,
self
.
h
,
self
.
d_k
)
q
=
q
.
transpose
(
1
,
2
)
# (batch, head, time1, d_k)
k
=
k
.
transpose
(
1
,
2
)
# (batch, head, time2, d_k)
v
=
v
.
transpose
(
1
,
2
)
# (batch, head, time2, d_k)
# q, k, v = self.forward_qkv(query, key, value)
q
=
q
.
transpose
(
1
,
2
)
# (batch, time1, head, d_k)
n_batch_pos
=
pos_emb
.
size
(
0
)
p
=
self
.
linear_pos
(
pos_emb
)
.
view
(
n_batch_pos
,
-
1
,
self
.
h
,
self
.
d_k
)
p
=
p
.
transpose
(
1
,
2
)
# (batch, head, 2*time1-1, d_k)
# (batch, head, time1, d_k)
q_with_bias_u
=
(
q
+
self
.
pos_bias_u
)
.
transpose
(
1
,
2
)
# (batch, head, time1, d_k)
q_with_bias_v
=
(
q
+
self
.
pos_bias_v
)
.
transpose
(
1
,
2
)
# compute attention score
# first compute matrix a and matrix c
# as described in https://arxiv.org/abs/1901.02860 Section 3.3
# (batch, head, time1, time2)
matrix_ac
=
torch
.
matmul
(
q_with_bias_u
,
k
.
transpose
(
-
2
,
-
1
))
# compute matrix b and matrix d
# (batch, head, time1, 2*time1-1)
matrix_bd
=
torch
.
matmul
(
q_with_bias_v
,
p
.
transpose
(
-
2
,
-
1
))
matrix_bd
=
self
.
rel_shift
(
matrix_bd
)
scores
=
(
matrix_ac
+
matrix_bd
)
/
math
.
sqrt
(
self
.
d_k
)
# (batch, head, time1, time2)
scores
=
self
.
forward_attention
(
v
,
scores
,
key_padding_mask
)
scores
=
scores
.
transpose
(
0
,
1
)
return
scores
,
None
class
LegacyRelPositionMultiHeadedAttention
(
RelPositionMultiHeadedAttention
):
class
LegacyRelPositionMultiHeadedAttention
(
RelPositionMultiHeadedAttention
):
"""Multi-Head Attention layer with relative position encoding (old version).
"""Multi-Head Attention layer with relative position encoding (old version).
...
...
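Two things are worth noting about the new class. First, its constructor runs the parent initialization twice: the copied linear_pos / pos_bias_u / pos_bias_v block is immediately re-created by the second super().__init__(n_feat, n_head, dropout, zero_triu) call, so only the latter's parameters survive. Second, with the conv/pool methods only the keys and values are shortened; the query keeps its full length, so the output does too. A usage sketch under stated assumptions: this Fairseq-S2T checkout is on PYTHONPATH, and the relative-position table is sized to the reduced key length, which is what the shape algebra in rel_shift requires:

import torch
from fairseq.modules import ReducedRelPositionMultiHeadedAttention

T, B, C, heads, ratio = 8, 2, 64, 4, 2
attn = ReducedRelPositionMultiHeadedAttention(
    C, heads, dropout=0.1, sample_ratio=ratio, reduced_method="conv",
)

x = torch.randn(T, B, C)                           # (time, batch, channels)
pos_emb = torch.randn(2 * (T // ratio) - 1, 1, C)  # table over reduced key positions

out, _ = attn(x, x, x, pos_emb)
print(out.shape)  # torch.Size([8, 2, 64]); keys/values were halved internally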
fairseq/modules/pds_layer.py
...
@@ -12,6 +12,7 @@ from fairseq.modules import (
     ConvolutionModule,
     ESPNETMultiHeadedAttention,
     RelPositionMultiHeadedAttention,
+    ReducedRelPositionMultiHeadedAttention,
     LegacyRelPositionMultiHeadedAttention,
     LocalMultiheadAttention,
     ReducedMultiheadAttention,
...
@@ -91,6 +92,7 @@ class PDSTransformerEncoderLayer(nn.Module):
             self.macaron_norm = None
             self.ffn_scale = 1.0

+        self.conv_stride = conv_stride
         if args.use_cnn_module:
             self.conv_norm = LayerNorm(embed_dim)
             self.conv_module = ConvolutionModule(
...
@@ -104,7 +106,6 @@ class PDSTransformerEncoderLayer(nn.Module):
             self.final_norm = LayerNorm(expand_embed_dim)

         # Convolution Residual
-        self.conv_stride = conv_stride
         self.conv_res = nn.Sequential(
             Permute3D(1, 2, 0),
             nn.Conv1d(embed_dim, expand_embed_dim, kernel_size=1, stride=conv_stride),
...
@@ -173,6 +174,15 @@ class PDSTransformerEncoderLayer(nn.Module):
                 attention_heads,
                 dropout=dropout,
             )
+        elif self.attn_type == "reduced_rel_pos":
+            return ReducedRelPositionMultiHeadedAttention(
+                embed_dim,
+                attention_heads,
+                dropout=dropout,
+                sample_ratio=sample_ratio,
+                reduced_method=getattr(args, "attention_reduced_method", "conv"),
+                reduced_q=getattr(args, "attention_reduced_q", False)
+            )
         elif self.attn_type == "rel_pos_legacy":
             return LegacyRelPositionMultiHeadedAttention(
                 embed_dim,
...
@@ -284,7 +294,7 @@ class PDSTransformerEncoderLayer(nn.Module):
         residual = x
         if self.normalize_before:
             x = self.self_attn_layer_norm(x)
-        if self.attn_type in ["rel_pos", "rel_pos_legacy", "rel_selfattn"]:
+        if self.attn_type in ["rel_pos", "rel_pos_legacy", "rel_selfattn", "reduced_rel_pos"]:
             assert pos_emb is not None, "Positions is necessary for RPE!"
             x, _ = self.self_attn(
                 query=x,
...