Skip to content
项目
群组
代码片段
帮助
当前项目
正在载入...
登录 / 注册
切换导航面板
F
Fairseq-S2T
概览
Overview
Details
Activity
Cycle Analytics
版本库
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Charts
问题
0
Issues
0
列表
Board
标记
里程碑
合并请求
0
Merge Requests
0
CI / CD
CI / CD
流水线
作业
日程表
图表
维基
Wiki
代码片段
Snippets
成员
Collapse sidebar
Close sidebar
活动
图像
聊天
创建新问题
作业
提交
Issue Boards
Open sidebar
xuchen
Fairseq-S2T
Commits
408e2b95
Commit
408e2b95
authored
Apr 06, 2022
by
xuchen
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
fix the bugs during prepare and CTC decoding
parent
8f084189
显示空白字符变更
内嵌
并排
正在显示
6 个修改的文件
包含
18 行增加
和
7 行删除
+18
-7
examples/speech_to_text/prep_audio_data.py
+2
-2
fairseq/criterions/ctc.py
+1
-1
fairseq/models/speech_to_text/pdss2t_transformer.py
+2
-1
fairseq/models/speech_to_text/s2t_ctc.py
+2
-1
fairseq/modules/espnet_multihead_attention.py
+8
-2
fairseq/modules/multihead_attention.py
+3
-0
没有找到文件。
examples/speech_to_text/prep_audio_data.py
查看文件 @
408e2b95
...
@@ -185,7 +185,7 @@ class AudioDataset(Dataset):
...
@@ -185,7 +185,7 @@ class AudioDataset(Dataset):
if
need_waveform
:
if
need_waveform
:
offset
=
item
.
get
(
'offset'
,
False
)
offset
=
item
.
get
(
'offset'
,
False
)
if
offset
:
if
offset
is
not
False
:
waveform
,
sample_rate
=
torchaudio
.
load
(
audio
,
waveform
,
sample_rate
=
torchaudio
.
load
(
audio
,
frame_offset
=
offset
,
frame_offset
=
offset
,
num_frames
=
item
[
"n_frames"
])
num_frames
=
item
[
"n_frames"
])
...
@@ -331,7 +331,7 @@ def process(args):
...
@@ -331,7 +331,7 @@ def process(args):
audio_path
=
item
[
"audio"
]
audio_path
=
item
[
"audio"
]
# add offset and frames info
# add offset and frames info
if
item
.
get
(
"offset"
,
False
):
if
item
.
get
(
"offset"
,
False
)
is
not
False
:
audio_path
=
f
"{audio_path}:{item['offset']}:{n_frames}"
audio_path
=
f
"{audio_path}:{item['offset']}:{n_frames}"
manifest
[
"audio"
]
.
append
(
audio_path
)
manifest
[
"audio"
]
.
append
(
audio_path
)
else
:
else
:
...
...
fairseq/criterions/ctc.py
查看文件 @
408e2b95
...
@@ -27,7 +27,7 @@ logger = logging.getLogger(__name__)
...
@@ -27,7 +27,7 @@ logger = logging.getLogger(__name__)
@dataclass
@dataclass
class
CtcCriterionConfig
(
FairseqDataclass
):
class
CtcCriterionConfig
(
FairseqDataclass
):
zero_infinity
:
bool
=
field
(
zero_infinity
:
bool
=
field
(
default
=
Fals
e
,
default
=
Tru
e
,
metadata
=
{
"help"
:
"zero inf loss when source length <= target length"
},
metadata
=
{
"help"
:
"zero inf loss when source length <= target length"
},
)
)
sentence_avg
:
bool
=
II
(
"optimization.sentence_avg"
)
sentence_avg
:
bool
=
II
(
"optimization.sentence_avg"
)
...
...
fairseq/models/speech_to_text/pdss2t_transformer.py
查看文件 @
408e2b95
...
@@ -882,7 +882,8 @@ class PDSS2TTransformerEncoder(FairseqEncoder):
...
@@ -882,7 +882,8 @@ class PDSS2TTransformerEncoder(FairseqEncoder):
if
self
.
inter_ctc
:
if
self
.
inter_ctc
:
logger
.
info
(
"Intermedia CTC loss in layer
%
d"
%
self
.
ctc_layer
)
logger
.
info
(
"Intermedia CTC loss in layer
%
d"
%
self
.
ctc_layer
)
embed_dim
=
self
.
pds_embed_dims
[
-
1
]
# embed_dim = self.pds_embed_dims[-1]
embed_dim
=
self
.
embed_dim
if
self
.
inter_ctc
:
if
self
.
inter_ctc
:
ctc_layer
=
self
.
ctc_layer
ctc_layer
=
self
.
ctc_layer
for
i
in
range
(
self
.
pds_stages
):
for
i
in
range
(
self
.
pds_stages
):
...
...
fairseq/models/speech_to_text/s2t_ctc.py
查看文件 @
408e2b95
...
@@ -638,8 +638,9 @@ class CTCDecoder(object):
...
@@ -638,8 +638,9 @@ class CTCDecoder(object):
src_lengths
=
src_lengths
)
src_lengths
=
src_lengths
)
ctc_logit
=
encoder_outs
[
"ctc_logit"
][
0
]
.
transpose
(
0
,
1
)
ctc_logit
=
encoder_outs
[
"ctc_logit"
][
0
]
.
transpose
(
0
,
1
)
logit_length
=
(
~
encoder_outs
[
"encoder_padding_mask"
][
0
])
.
long
()
.
sum
(
-
1
)
beam_results
,
beam_scores
,
time_steps
,
out_lens
=
self
.
ctc_decoder
.
decode
(
beam_results
,
beam_scores
,
time_steps
,
out_lens
=
self
.
ctc_decoder
.
decode
(
utils
.
softmax
(
ctc_logit
,
-
1
),
src_lengths
utils
.
softmax
(
ctc_logit
,
-
1
),
logit_length
)
)
finalized
=
[]
finalized
=
[]
...
...
fairseq/modules/espnet_multihead_attention.py
查看文件 @
408e2b95
...
@@ -10,6 +10,7 @@ import math
...
@@ -10,6 +10,7 @@ import math
import
torch
import
torch
from
torch
import
nn
from
torch
import
nn
import
torch.nn.functional
as
F
import
torch.nn.functional
as
F
import
logging
from
fairseq.modules.rotary_positional_embedding
import
(
from
fairseq.modules.rotary_positional_embedding
import
(
RotaryPositionalEmbedding
,
RotaryPositionalEmbedding
,
apply_rotary_pos_emb
,
apply_rotary_pos_emb
,
...
@@ -76,12 +77,17 @@ class ESPNETMultiHeadedAttention(nn.Module):
...
@@ -76,12 +77,17 @@ class ESPNETMultiHeadedAttention(nn.Module):
-
1e8
if
scores
.
dtype
==
torch
.
float32
else
-
1e4
-
1e8
if
scores
.
dtype
==
torch
.
float32
else
-
1e4
# float("-inf"), # (batch, head, time1, time2)
# float("-inf"), # (batch, head, time1, time2)
)
)
# self.attn = torch.softmax(scores, dim=-1) # (batch, head, time1, time2)
scores
=
scores
.
clamp
(
min
=-
1e8
if
scores
.
dtype
==
torch
.
float32
else
-
1e4
,
scores
=
scores
.
clamp
(
min
=-
1e8
if
scores
.
dtype
==
torch
.
float32
else
-
1e4
,
max
=
1e8
if
scores
.
dtype
==
torch
.
float32
else
1e4
)
max
=
1e8
if
scores
.
dtype
==
torch
.
float32
else
1e4
)
self
.
attn
=
F
.
softmax
(
scores
,
dim
=-
1
,
dtype
=
torch
.
float32
)
.
type_as
(
scores
)
# (batch, head, time1, time2)
self
.
attn
=
F
.
softmax
(
scores
,
dim
=-
1
,
dtype
=
torch
.
float32
)
.
type_as
(
scores
)
# (batch, head, time1, time2)
if
torch
.
isnan
(
self
.
attn
)
.
any
():
if
torch
.
isnan
(
self
.
attn
)
.
any
():
import
logging
logging
.
warning
(
"Tensor attention scores has nan."
)
logging
.
error
(
"Tensor attention scores has nan."
)
# torch.save(scores, "scores.pt")
# torch.save(self.attn, "attn.pt")
# exit()
p_attn
=
self
.
dropout
(
self
.
attn
)
p_attn
=
self
.
dropout
(
self
.
attn
)
x
=
torch
.
matmul
(
p_attn
,
value
)
# (batch, head, time1, d_k)
x
=
torch
.
matmul
(
p_attn
,
value
)
# (batch, head, time1, d_k)
...
...
fairseq/modules/multihead_attention.py
查看文件 @
408e2b95
...
@@ -350,6 +350,9 @@ class MultiheadAttention(nn.Module):
...
@@ -350,6 +350,9 @@ class MultiheadAttention(nn.Module):
if
before_softmax
:
if
before_softmax
:
return
attn_weights
,
v
return
attn_weights
,
v
attn_weights
=
attn_weights
.
clamp
(
min
=-
1e8
if
attn_weights
.
dtype
==
torch
.
float32
else
-
1e4
,
max
=
1e8
if
attn_weights
.
dtype
==
torch
.
float32
else
1e4
)
attn_weights_float
=
F
.
softmax
(
attn_weights
,
dim
=-
1
,
dtype
=
torch
.
float32
)
attn_weights_float
=
F
.
softmax
(
attn_weights
,
dim
=-
1
,
dtype
=
torch
.
float32
)
attn_weights
=
attn_weights_float
.
type_as
(
attn_weights
)
attn_weights
=
attn_weights_float
.
type_as
(
attn_weights
)
attn_probs
=
self
.
dropout_module
(
attn_weights
)
attn_probs
=
self
.
dropout_module
(
attn_weights
)
...
...
编写
预览
Markdown
格式
0%
重试
或
添加新文件
添加附件
取消
您添加了
0
人
到此讨论。请谨慎行事。
请先完成此评论的编辑!
取消
请
注册
或者
登录
后发表评论