xuchen / Fairseq-S2T · Commits

Commit 81caa4ca, authored 4 years ago by xuchen
Add speed perturbation for the MuST-C dataset
Parent: 6a2f4065

Showing 3 changed files with 128 additions and 49 deletions:

egs/mustc/st/run.sh (+4, -0)
examples/speech_to_text/prep_mustc_data.py (+124, -48)
fairseq_cli/generate.py (+0, -1)
egs/mustc/st/run.sh (view file @ 81caa4ca)

@@ -41,6 +41,7 @@ share_dict=1
 org_data_dir=/media/data/${dataset}
 data_dir=~/st/data/${dataset}/st
+data_dir=~/st/data/${dataset}/st_perturb_2
 test_subset=(tst-COMMON)
 # exp
@@ -104,6 +105,7 @@ if [ ${stage} -le 0 ] && [ ${stop_stage} -ge 0 ]; then
     if [[ ! -e ${data_dir}/${lang} ]]; then
         mkdir -p ${data_dir}/${lang}
     fi
+    source audio/bin/activate
     cmd="python ${root_dir}/examples/speech_to_text/prep_mustc_data.py
         --data-root ${org_data_dir}
@@ -118,6 +120,7 @@ if [ ${stage} -le 0 ] && [ ${stop_stage} -ge 0 ]; then
     cmd="python ${root_dir}/examples/speech_to_text/prep_mustc_data.py
         --data-root ${org_data_dir}
         --output-root ${data_dir}
+        --speed-perturb
         --task st
         --add-src
         --cmvn-type utterance
@@ -133,6 +136,7 @@ if [ ${stage} -le 0 ] && [ ${stop_stage} -ge 0 ]; then
     echo -e "\033[34mRun command:\n${cmd}\033[0m"
     [[ $eval -eq 1 ]] && eval ${cmd}
+    deactivate
 fi
 data_dir=${data_dir}/${lang}
examples/speech_to_text/prep_mustc_data.py (view file @ 81caa4ca)
@@ -13,6 +13,7 @@ from itertools import groupby
 from tempfile import NamedTemporaryFile
 from typing import Tuple
 import string
+import pickle
 import numpy as np
 import pandas as pd
@@ -28,7 +29,6 @@ from examples.speech_to_text.data_utils import (
     save_df_to_tsv,
     cal_gcmvn_stats,
 )
-from torch import Tensor
 from torch.utils.data import Dataset
 from tqdm import tqdm
@@ -46,14 +46,14 @@ class MUSTC(Dataset):
     utterance_id
     """

-    SPLITS = ["train", "dev", "tst-COMMON", "tst-HE"]
+    SPLITS = ["dev", "tst-COMMON", "tst-HE", "train"]
     LANGUAGES = ["de", "es", "fr", "it", "nl", "pt", "ro", "ru"]

-    def __init__(self, root: str, lang: str, split: str) -> None:
+    def __init__(self, root: str, lang: str, split: str, speed_perturb: bool = False) -> None:
         assert split in self.SPLITS and lang in self.LANGUAGES
         _root = Path(root) / f"en-{lang}" / "data" / split
         wav_root, txt_root = _root / "wav", _root / "txt"
-        assert _root.is_dir() and wav_root.is_dir() and txt_root.is_dir()
+        assert _root.is_dir() and wav_root.is_dir() and txt_root.is_dir(), (_root, wav_root, txt_root)
         # Load audio segments
         try:
             import yaml
@@ -61,6 +61,8 @@ class MUSTC(Dataset):
             print("Please install PyYAML to load the MuST-C YAML files")
         with open(txt_root / f"{split}.yaml") as f:
             segments = yaml.load(f, Loader=yaml.BaseLoader)
+
+        self.speed_perturb = [0.9, 1.0, 1.1] if speed_perturb and split.startswith("train") else None
         # Load source and target utterances
         for _lang in ["en", lang]:
             with open(txt_root / f"{split}.{_lang}") as f:
@@ -72,7 +74,8 @@ class MUSTC(Dataset):
         self.data = []
         for wav_filename, _seg_group in groupby(segments, lambda x: x["wav"]):
             wav_path = wav_root / wav_filename
-            sample_rate = torchaudio.info(wav_path.as_posix())[0].rate
+            # sample_rate = torchaudio.info(wav_path.as_posix())[0].rate
+            sample_rate = torchaudio.info(wav_path.as_posix()).sample_rate
             seg_group = sorted(_seg_group, key=lambda x: x["offset"])
             for i, segment in enumerate(seg_group):
                 offset = int(float(segment["offset"]) * sample_rate)
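The sample-rate change above tracks the torchaudio 0.8 API migration: the old backend returned a (SignalInfo, EncodingInfo) tuple, so the rate lived at index [0].rate, while the new backend returns a single AudioMetaData object. A minimal sketch of the new call style, assuming torchaudio >= 0.8 (the file path is a placeholder):

import torchaudio

# Old API (torchaudio < 0.8): torchaudio.info(path) returned a
# (SignalInfo, EncodingInfo) tuple, hence torchaudio.info(path)[0].rate.

# New API (torchaudio >= 0.8): info() returns one AudioMetaData object.
meta = torchaudio.info("sample.wav")  # "sample.wav" is a hypothetical file
print(meta.sample_rate, meta.num_frames, meta.num_channels)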
@@ -91,10 +94,52 @@ class MUSTC(Dataset):
                     )
                 )

-    def __getitem__(self, n: int) -> Tuple[Tensor, int, str, str, str, str]:
+    def __getitem__(self, n: int):
         wav_path, offset, n_frames, sr, src_utt, tgt_utt, spk_id, utt_id = self.data[n]
-        waveform, _ = torchaudio.load(wav_path, offset=offset, num_frames=n_frames)
-        return waveform, sr, src_utt, tgt_utt, spk_id, utt_id
+        items = []
+        if self.speed_perturb is None:
+            waveform, _ = torchaudio.load(wav_path, frame_offset=offset, num_frames=n_frames)
+            items.append([waveform, sr, src_utt, tgt_utt, spk_id, utt_id])
+        else:
+            for speed in self.speed_perturb:
+                sp_utt_id = f"sp{speed}_" + utt_id
+                if speed == 1.0:
+                    waveform, _ = torchaudio.load(wav_path, frame_offset=offset, num_frames=n_frames)
+                else:
+                    waveform, _ = torchaudio.load(wav_path, frame_offset=offset, num_frames=n_frames)
+                    effects = [
+                        ["speed", f"{speed}"],
+                        ["rate", f"{sr}"]
+                    ]
+                    waveform, _ = torchaudio.sox_effects.apply_effects_tensor(waveform, sr, effects)
+                items.append([waveform, sr, src_utt, tgt_utt, spk_id, sp_utt_id])
+        return items
+
+    def get_fast(self, n: int):
+        wav_path, offset, n_frames, sr, src_utt, tgt_utt, spk_id, utt_id = self.data[n]
+        items = []
+        if self.speed_perturb is None:
+            items.append([wav_path, sr, src_utt, tgt_utt, spk_id, utt_id])
+        else:
+            for speed in self.speed_perturb:
+                sp_utt_id = f"sp{speed}_" + utt_id
+                items.append([wav_path, sr, src_utt, tgt_utt, spk_id, sp_utt_id])
+        return items
+
+    def get_src_text(self):
+        src_text = []
+        for item in self.data:
+            src_text.append(item[4])
+        return src_text
+
+    def get_tgt_text(self):
+        tgt_text = []
+        for item in self.data:
+            tgt_text.append(item[5])
+        return tgt_text

     def __len__(self) -> int:
         return len(self.data)
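This hunk is the core of the commit: with perturbation enabled, each index of the dataset now yields a list of three items (one per speed factor) instead of a single tuple, and the perturbed copies get an "sp{speed}_" id prefix. The transform itself is sox's "speed" effect followed by a "rate" effect that resamples back to the original sample rate. A self-contained sketch of that transform, assuming torchaudio >= 0.8 built with sox support and a hypothetical mono wav file:

import torchaudio

waveform, sr = torchaudio.load("sample.wav")  # placeholder path

for speed in [0.9, 1.0, 1.1]:
    if speed == 1.0:
        perturbed = waveform  # the 1.0 copy is left untouched
    else:
        # "speed" rescales tempo (and pitch); the trailing "rate" effect
        # resamples back to sr, so every copy keeps the same sample rate.
        effects = [["speed", f"{speed}"], ["rate", f"{sr}"]]
        perturbed, _ = torchaudio.sox_effects.apply_effects_tensor(
            waveform, sr, effects
        )
    # speed 0.9 slows the audio down, so it has roughly 1/0.9x the samples
    print(f"sp{speed}: {perturbed.size(1)} samples")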
@@ -116,33 +161,77 @@ def process(args):
     feature_root = output_root / "fbank80"
     feature_root.mkdir(exist_ok=True)
     zip_path = output_root / "fbank80.zip"
-    if args.overwrite or not Path.exists(zip_path):
+    manifest_dict = {}
+    train_text = []
+
+    gen_feature_flag = False
+    if not Path.exists(zip_path):
+        gen_feature_flag = True
+    for split in MUSTC.SPLITS:
+        if not Path.exists(output_root / f"{split}_{args.task}.tsv"):
+            gen_feature_flag = True
+            break
+
+    if args.overwrite or gen_feature_flag:
         for split in MUSTC.SPLITS:
             print(f"Fetching split {split}...")
-            dataset = MUSTC(root.as_posix(), lang, split)
+            dataset = MUSTC(root.as_posix(), lang, split, args.speed_perturb)
+            is_train_split = split.startswith("train")
             print("Extracting log mel filter bank features...")
-            if split == 'train' and args.cmvn_type == "global":
+            if is_train_split and args.cmvn_type == "global":
                 print("And estimating cepstral mean and variance stats...")
                 gcmvn_feature_list = []
-            for waveform, sample_rate, _, _, _, utt_id in tqdm(dataset):
-                features = extract_fbank_features(waveform, sample_rate)
-                np.save(
-                    (feature_root / f"{utt_id}.npy").as_posix(),
-                    features
-                )
-                if split == 'train' and args.cmvn_type == "global":
-                    if len(gcmvn_feature_list) < args.gcmvn_max_num:
-                        gcmvn_feature_list.append(features)
-            if split == 'train' and args.cmvn_type == "global":
+            manifest = {c: [] for c in MANIFEST_COLUMNS}
+            if args.task == "st" and args.add_src:
+                manifest["src_text"] = []
+            for items in tqdm(dataset):
+                for item in items:
+                    # waveform, sample_rate, _, _, _, utt_id = item
+                    waveform, sr, src_utt, tgt_utt, speaker_id, utt_id = item
+                    features_path = (feature_root / f"{utt_id}.npy").as_posix()
+                    features = extract_fbank_features(waveform, sr, Path(features_path))
+                    # np.save(
+                    #     (feature_root / f"{utt_id}.npy").as_posix(),
+                    #     features
+                    # )
+                    manifest["id"].append(utt_id)
+                    duration_ms = int(waveform.size(1) / sr * 1000)
+                    # duration_ms = int(time_dict[utt_id] / sr * 1000)
+                    manifest["n_frames"].append(int(1 + (duration_ms - 25) / 10))
+                    if args.lowercase_src:
+                        src_utt = src_utt.lower()
+                    if args.rm_punc_src:
+                        for w in string.punctuation:
+                            src_utt = src_utt.replace(w, "")
+                    manifest["tgt_text"].append(src_utt if args.task == "asr" else tgt_utt)
+                    if args.task == "st" and args.add_src:
+                        manifest["src_text"].append(src_utt)
+                    manifest["speaker"].append(speaker_id)
+                    if split == 'train' and args.cmvn_type == "global" and not utt_id.startswith("sp"):
+                        if len(gcmvn_feature_list) < args.gcmvn_max_num:
+                            gcmvn_feature_list.append(features)
+                if is_train_split and args.size != -1 and len(manifest["id"]) > args.size:
+                    break
+            if is_train_split:
+                if args.task == "st" and args.add_src and args.share:
+                    train_text.extend(list(set(tuple(manifest["src_text"]))))
+                train_text.extend(dataset.get_tgt_text())
+            if is_train_split and args.cmvn_type == "global":
                 # Estimate and save cmv
                 stats = cal_gcmvn_stats(gcmvn_feature_list)
                 with open(output_root / "gcmvn.npz", "wb") as f:
                     np.savez(f, mean=stats["mean"], std=stats["std"])
+            manifest_dict[split] = manifest
         # Pack features into ZIP
         print("ZIPing features...")
         create_zip(feature_root, zip_path)
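Two details in this hunk are worth noting. Global CMVN statistics are now collected only from utterance ids without the "sp" prefix, so the stats are estimated on unperturbed audio. And the manifest's n_frames value is estimated directly from the waveform length: the formula 1 + (duration_ms - 25) / 10 assumes the standard 25 ms analysis window with a 10 ms hop used for the 80-dim filter bank features. A small sketch of that estimate (my helper name, not part of the commit):

def estimate_n_frames(num_samples: int, sample_rate: int) -> int:
    # 25 ms window, 10 ms hop: one frame plus one per additional 10 ms
    duration_ms = int(num_samples / sample_rate * 1000)
    return int(1 + (duration_ms - 25) / 10)

print(estimate_n_frames(48000, 16000))  # 3 s at 16 kHz -> 298 frames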
@@ -159,33 +248,13 @@ def process(args):
     zip_manifest = get_zip_manifest(zip_path)
     # Generate TSV manifest
     print("Generating manifest...")
-    for split in MUSTC.SPLITS:
+    for split, manifest in manifest_dict.items():
         is_train_split = split.startswith("train")
-        manifest = {c: [] for c in MANIFEST_COLUMNS}
-        if args.task == "st" and args.add_src:
-            manifest["src_text"] = []
-        dataset = MUSTC(args.data_root, lang, split)
-        for wav, sr, src_utt, tgt_utt, speaker_id, utt_id in tqdm(dataset):
-            manifest["id"].append(utt_id)
+        for utt_id in manifest["id"]:
             manifest["audio"].append(zip_manifest[utt_id])
-            duration_ms = int(wav.size(1) / sr * 1000)
-            manifest["n_frames"].append(int(1 + (duration_ms - 25) / 10))
-            if args.lowercase_src:
-                src_utt = src_utt.lower()
-            if args.rm_punc_src:
-                for w in string.punctuation:
-                    src_utt = src_utt.replace(w, "")
-            manifest["tgt_text"].append(src_utt if args.task == "asr" else tgt_utt)
-            if args.task == "st" and args.add_src:
-                manifest["src_text"].append(src_utt)
-            manifest["speaker"].append(speaker_id)
-            if is_train_split and args.size != -1 and len(manifest["id"]) > args.size:
-                break
-        if is_train_split:
-            if args.task == "st" and args.add_src and args.share:
-                train_text.extend(manifest["src_text"])
-            train_text.extend(manifest["tgt_text"])
         df = pd.DataFrame.from_dict(manifest)
         df = filter_manifest_df(df, is_train_split=is_train_split)
         save_df_to_tsv(df, output_root / f"{split}_{args.task}.tsv")
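Manifest generation now reuses the manifest_dict filled during feature extraction instead of re-instantiating each split and re-decoding its audio; only the "audio" column (the location of each utterance inside fbank80.zip) is filled in here. A toy sketch of the resulting TSV layout, using plain pandas in place of fairseq's save_df_to_tsv helper (all values are made up):

import pandas as pd

manifest = {
    "id": ["sp0.9_ted_1_0", "sp1.0_ted_1_0", "sp1.1_ted_1_0"],
    "audio": ["fbank80.zip:1024:5120"] * 3,  # byte offset:length in the zip
    "n_frames": [331, 298, 271],             # 0.9x is longer, 1.1x shorter
    "tgt_text": ["Hallo Welt"] * 3,
    "speaker": ["spk.1"] * 3,
}
df = pd.DataFrame.from_dict(manifest)
df.to_csv("train_st.tsv", sep="\t", index=False)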
@@ -207,7 +276,9 @@ def process(args):
     for split in MUSTC.SPLITS:
         if split.startswith("train"):
             dataset = MUSTC(args.data_root, lang, split)
-            for wav, sr, src_utt, tgt_utt, speaker_id, utt_id in dataset:
+            src_text = dataset.get_src_text()
+            tgt_text = dataset.get_tgt_text()
+            for src_utt, tgt_utt in zip(src_text, tgt_text):
                 if args.task == "st" and args.add_src and args.share:
                     if args.lowercase_src:
                         src_utt = src_utt.lower()
@@ -215,6 +286,7 @@ def process(args):
                     src_utt = src_utt.translate(None, string.punctuation)
                     train_text.append(src_utt)
                 train_text.append(tgt_utt)
+
     with NamedTemporaryFile(mode="w") as f:
         for t in train_text:
             f.write(t + "\n")
@@ -242,8 +314,9 @@ def process(args):
         asr_spm_filename=asr_spm_filename,
         share_src_and_tgt=True if args.task == "asr" else False
     )
     # Clean up
-    shutil.rmtree(feature_root)
+    # shutil.rmtree(feature_root)


 def process_joint(args):
@@ -305,8 +378,11 @@ def main():
     parser.add_argument("--vocab-size", default=8000, type=int)
     parser.add_argument("--task", type=str, choices=["asr", "st"])
     parser.add_argument("--size", default=-1, type=int)
+    parser.add_argument("--speed-perturb", action="store_true", default=False,
+                        help="apply speed perturbation on wave file")
     parser.add_argument("--joint", action="store_true", help="")
-    parser.add_argument("--share", action="store_true", help="share the transcription and translation")
+    parser.add_argument("--share", action="store_true",
+                        help="share the tokenizer and dictionary of the transcription and translation")
     parser.add_argument("--add-src", action="store_true", help="add the src text for st task")
     parser.add_argument("--asr-prefix", type=str, help="prefix of the asr dict")
     parser.add_argument("--lowercase-src", action="store_true", help="lowercase the source text")
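The new --speed-perturb option is a plain store_true switch, which is exactly what the updated egs/mustc/st/run.sh passes to the prep script; a quick sketch of how such a flag parses:

import argparse

parser = argparse.ArgumentParser()
parser.add_argument("--speed-perturb", action="store_true", default=False,
                    help="apply speed perturbation on wave file")

args = parser.parse_args(["--speed-perturb"])
print(args.speed_perturb)  # True; when the flag is omitted it stays False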
fairseq_cli/generate.py (view file @ 81caa4ca)
@@ -81,7 +81,6 @@ def _main(cfg: DictConfig, output_file):
     # Load dataset splits
     task = tasks.setup_task(cfg.task)
-
     # Set dictionaries
     try:
         src_dict = getattr(task, "source_dictionary", None)