xuchen / Fairseq-S2T / Commits

Commit b970c7df authored Mar 13, 2022 by xuchen
up-sampling the representation for ctc calculation
parent 8b50c392
Showing 7 changed files with 61 additions and 10 deletions (+61 -10)
egs/iwslt2022/mt/conf/inter.yaml             +15  -0
egs/mustc/mt/conf/debug.yaml                  +1  -1
egs/wmt20/mt/conf/inter.yaml                 +15  -0
examples/speech_to_text/prep_audio_data.py    +1  -0
fairseq/models/speech_to_text/s2t_sate.py     +2  -0
fairseq/models/transformer_ctc.py            +24  -6
fairseq/modules/speech_to_text/adapter.py     +3  -3
egs/iwslt2022/mt/conf/inter.yaml  (new file, mode 0 → 100644)

#ctc-weight: 0.2
intermedia-ctc-weight: 0.3
intermedia-ctc-layers: 2,4
#target-ctc-weight: 0.3
#target-ctc-layer: 6
#target-intermedia-ctc-weight: 0.1
#target-intermedia-ctc-layers: 2,4
intermedia-adapter: league
#intermedia-drop-prob: 0.2
#intermedia-temperature: 5
post-process: sentencepiece
\ No newline at end of file
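In this config (and the near-identical egs/wmt20/mt/conf/inter.yaml further down), intermedia-ctc-weight and intermedia-ctc-layers enable an auxiliary CTC loss at encoder layers 2 and 4, and intermedia-adapter: league selects which adapter combines the resulting distribution with the encoder states (see fairseq/modules/speech_to_text/adapter.py below). As a minimal sketch of how a comma-separated layer option like "2,4" is typically turned into the list of indices the encoder compares against layer_idx (the helper name is illustrative, not this repository's code):

def parse_intermedia_ctc_layers(option: str):
    # "2,4" -> [2, 4]; an empty option means no intermediate CTC layers
    return [int(idx) for idx in option.split(",") if idx.strip()] if option else []

assert parse_intermedia_ctc_layers("2,4") == [2, 4]
assert parse_intermedia_ctc_layers("10,20") == [10, 20]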
egs/mustc/mt/conf/debug.yaml

-arch: transformer
+arch: transformer_ctc
 share-all-embeddings: True
 optimizer: adam
 clip-norm: 10.0
 ...
egs/wmt20/mt/conf/inter.yaml  (new file, mode 0 → 100644)

#ctc-weight: 0.2
intermedia-ctc-weight: 0.3
intermedia-ctc-layers: 10,20
#target-ctc-weight: 0.3
#target-ctc-layer: 6
#target-intermedia-ctc-weight: 0.1
#target-intermedia-ctc-layers: 2,4
intermedia-adapter: league
#intermedia-drop-prob: 0.2
#intermedia-temperature: 5
post-process: sentencepiece
\ No newline at end of file
examples/speech_to_text/prep_audio_data.py

@@ -107,6 +107,7 @@ class AudioDataset(Dataset):
             for idx, u in enumerate(utterances):
                 segments[idx][_lang] = u

+        # split = split.replace("_gen", "")
         # Gather info
         self.data = dict()
         if self.mode == "easy":
fairseq/models/speech_to_text/s2t_sate.py

@@ -156,9 +156,11 @@ class TextEncoder(FairseqEncoder):
         super().__init__(None)

+        self.register_buffer("version", torch.Tensor([3]))  # for consistent
         embed_dim = args.encoder_embed_dim
         layer_num = args.text_encoder_layers
         self.layer_num = layer_num
+        self.embed_tokens = embed_tokens

         self.embed_scale = math.sqrt(embed_dim)
         if args.no_scale_embedding:
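The register_buffer call stores a version tensor on the module without making it a trainable parameter, so it is saved and loaded with state_dict() like the version buffers of the other fairseq encoders (hence the "for consistent" comment). A small standalone illustration of that PyTorch mechanism (not this repository's class):

import torch
import torch.nn as nn

class Dummy(nn.Module):
    def __init__(self):
        super().__init__()
        # registered as a buffer: persisted in state_dict(), ignored by the optimizer
        self.register_buffer("version", torch.Tensor([3]))

m = Dummy()
print(list(m.parameters()))       # [] -> nothing to train
print(m.state_dict()["version"])  # tensor([3.]) -> still checkpointed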
fairseq/models/transformer_ctc.py

@@ -672,6 +672,14 @@ class TransformerCTCEncoder(FairseqEncoder):
             return_all_hiddens,
             token_embeddings)

+    def upsample(self, x, ratio=2):
+        if ratio <= 1:
+            return x
+
+        seq_len, bsz, dim = x.size()
+        x = x.unsqueeze(0).expand(ratio, -1, -1, -1).reshape(-1, bsz, dim)
+        return x
+
     # TorchScript doesn't support super() method so that the scriptable Subclass
     # can't access the base class model in Torchscript.
     # Current workaround is to add a helper function with different name and
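The new upsample helper stretches the time axis by ratio: unsqueeze(0).expand(ratio, -1, -1, -1).reshape(-1, bsz, dim) turns a (seq_len, batch, dim) tensor into (ratio * seq_len, batch, dim) whose two halves (for ratio=2) are identical copies of the input, and this longer representation is what the CTC projection now receives. A quick standalone shape check of exactly that expression (illustrative tensors, outside the encoder):

import torch

seq_len, bsz, dim, ratio = 5, 3, 8, 2
x = torch.randn(seq_len, bsz, dim)

up = x.unsqueeze(0).expand(ratio, -1, -1, -1).reshape(-1, bsz, dim)

print(up.shape)                      # torch.Size([10, 3, 8])
print(torch.equal(up[:seq_len], x))  # True: first copy of the sequence
print(torch.equal(up[seq_len:], x))  # True: second copy of the sequence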
@@ -749,7 +757,7 @@ class TransformerCTCEncoder(FairseqEncoder):
             # CTC
             if self.use_ctc and self.inter_ctc and self.ctc_layer == layer_idx:
-                ctc_logit = self.ctc(x.clone())
+                ctc_logit = self.ctc(self.upsample(x.clone()))

             # Intermedia CTC
             if layer_idx in self.intermedia_ctc_layers:
@@ -759,11 +767,15 @@ class TransformerCTCEncoder(FairseqEncoder):
                         break

                 norm_x = self.layer_norm(x)
-                logit = self.ctc(norm_x)
+                up_x = self.upsample(norm_x)
+                up_logit = self.ctc(up_x)

-                intermedia_ctc_logits.append(logit)
+                intermedia_ctc_logits.append(up_logit)

-                prob = utils.softmax(logit / self.intermedia_temperature, dim=-1)
-                x, encoder_padding_mask = self.adapter([x, prob], encoder_padding_mask)
+                up_prob = utils.softmax(up_logit / self.intermedia_temperature, dim=-1)
+                up_prob = up_prob.permute(1, 2, 0)
+                prob = nn.functional.max_pool1d(up_prob, kernel_size=2, stride=2)
+                prob = prob.permute(2, 0, 1)
+                x, _ = self.adapter([x, prob])

             if self.history is not None:
                 self.history.push(x)
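With the intermediate CTC now computed on the up-sampled states, the distribution has to be brought back to the encoder's length before the adapter consumes it: up_prob of shape (2*T, B, V) is permuted to (B, V, 2*T), max-pooled over time with kernel size and stride 2, and permuted back to (T, B, V). A standalone shape check of that round trip (illustrative tensors only):

import torch
import torch.nn as nn

T, B, V = 6, 2, 10
up_prob = torch.rand(2 * T, B, V).softmax(dim=-1)  # stand-in for softmax(up_logit / temperature)

pooled = nn.functional.max_pool1d(up_prob.permute(1, 2, 0), kernel_size=2, stride=2)
prob = pooled.permute(2, 0, 1)

print(prob.shape)  # torch.Size([6, 2, 10]) -> back to (T, B, V) for self.adapter([x, prob])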
@@ -775,7 +787,12 @@ class TransformerCTCEncoder(FairseqEncoder):
             x = self.layer_norm(x)

         if self.use_ctc and ctc_logit is None:
-            ctc_logit = self.ctc(x)
+            ctc_logit = self.ctc(self.upsample(x))
+
+        ctc_padding_mask = encoder_padding_mask
+        if ctc_logit is not None or len(intermedia_ctc_logits) != 0:
+            bsz = encoder_padding_mask.size(0)
+            ctc_padding_mask = encoder_padding_mask.unsqueeze(-1).expand(-1, -1, 2).reshape(bsz, -1)

         # The Pytorch Mobile lite interpreter does not supports returning NamedTuple in
         # `forward` so we use a dictionary instead.
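Since the CTC logits are now twice as long as the encoder output, the hunk above also builds a matching ctc_padding_mask for the CTC loss: every column of the (B, T) boolean encoder_padding_mask is duplicated via unsqueeze(-1).expand(-1, -1, 2).reshape(bsz, -1), so positions 2t and 2t+1 of the resulting (B, 2*T) mask repeat the value at t. A quick check of that expression:

import torch

mask = torch.tensor([[False, False, True],
                     [False, True,  True]])  # (B, T) padding mask, True = padded

bsz = mask.size(0)
ctc_mask = mask.unsqueeze(-1).expand(-1, -1, 2).reshape(bsz, -1)

print(ctc_mask.shape)  # torch.Size([2, 6])
print(ctc_mask[0])     # tensor([False, False, False, False,  True,  True])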
@@ -784,6 +801,7 @@ class TransformerCTCEncoder(FairseqEncoder):
         return {
             "encoder_out": [x],  # T x B x C
             "ctc_logit": [] if ctc_logit is None else [ctc_logit],  # T x B x C
+            "ctc_padding_mask": [ctc_padding_mask],
             "intermedia_ctc_logits": intermedia_ctc_logits,  # T x B x C
             "encoder_padding_mask": [encoder_padding_mask],  # B x T
             "encoder_embedding": [encoder_embedding],  # B x T x C
fairseq/modules/speech_to_text/adapter.py

@@ -95,14 +95,13 @@ class Adapter(nn.Module):
         if self.distribution_cutoff is not None:
             logger.info("Distribution cutoff: %d" % int(strategy))

-    def forward(self, x, padding):
+    def forward(self, x, padding=None):

         representation, distribution = x
         distribution = distribution.type_as(representation)
         seq_len, bsz, dim = representation.size()
         org_distribution = distribution
-        distribution = distribution.view(-1, distribution.size(-1))
-        lengths = (~padding).long().sum(-1)
+        distribution = distribution.contiguous().view(-1, distribution.size(-1))

         if self.adapter_type == "linear":
             out = self.linear_adapter(representation)

@@ -140,6 +139,7 @@ class Adapter(nn.Module):
         elif self.adapter_type == "shrink":
             from itertools import groupby

+            lengths = (~padding).long().sum(-1)
             with torch.no_grad():
                 batch_predicted = []
                 prob_ctc = org_distribution.transpose(0, 1)  # T x B x D -> B x T x D
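Together these two hunks make the padding mask optional: forward now accepts padding=None, and the lengths computation that dereferences it moves into the shrink branch, the only adapter type that needs it, which is what lets the new call site in transformer_ctc.py pass just [x, prob]. A minimal stand-in (not the repository's Adapter class) showing that calling convention:

import torch

def adapter_forward(x, padding=None, adapter_type="league"):
    # stand-in mirroring the new signature: padding is only consulted when required
    representation, distribution = x
    if adapter_type == "shrink":
        assert padding is not None, "the shrink adapter still needs the padding mask"
        lengths = (~padding).long().sum(-1)  # non-padded frames per sentence
        return representation, lengths
    return representation, padding

T, B, D, V = 4, 2, 8, 6
x = torch.randn(T, B, D)
prob = torch.rand(T, B, V).softmax(dim=-1)

out, _ = adapter_forward([x, prob])  # new call pattern: no padding mask passed
print(out.shape)                     # torch.Size([4, 2, 8])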