xuchen / Fairseq-S2T / Commits / a201a883

Commit a201a883, authored 2 years ago by xuchen
Try more settings of adapter
parent 5d84c743
Showing 8 changed files with 159 additions and 69 deletions.
fairseq/criterions/ctc.py                          +14  -2
fairseq/models/speech_to_text/s2t_ctc.py            +2  -0
fairseq/models/speech_to_text/s2t_sate.py          +28  -11
fairseq/models/speech_to_text/s2t_transformer.py   +33  -12
fairseq/models/transformer_ctc.py                  +42  -27
fairseq/modules/speech_to_text/adapter.py          +28  -13
fairseq/modules/speech_to_text/ctc.py              +11  -3
fairseq_cli/generate.py                             +1  -1
fairseq/criterions/ctc.py

@@ -56,6 +56,11 @@ class CtcCriterionConfig(FairseqDataclass):
        default=0.0,
        metadata={"help": "weight of interleaved CTC loss"},
    )
+   aligned_target_ctc: bool = field(
+       default=False,
+       metadata={"help": "calculate target ctc by aligned text"},
+   )
    target_ctc_weight: float = field(
        default=0.0,
        metadata={"help": "weight of CTC loss for target sentence"},

@@ -157,6 +162,7 @@ class CtcCriterion(FairseqCriterion):
        self.cal_all_ctc = cfg.cal_all_ctc
        self.ctc_weight = ctc_weight
        self.interleaved_ctc_weight = cfg.interleaved_ctc_weight
+       self.aligned_target_ctc = cfg.aligned_target_ctc
        self.target_ctc_weight = cfg.target_ctc_weight
        self.target_interleaved_ctc_weight = cfg.target_interleaved_ctc_weight

@@ -314,6 +320,12 @@ class CtcCriterion(FairseqCriterion):
                ctc_self_distill_num += 1
        return ctc_self_distill_num, ctc_self_distill_loss

+   def get_target_text(self, sample):
+       if self.aligned_target_ctc and "aligned_target" in sample:
+           return sample["aligned_target"]["tokens"]
+       else:
+           return sample["target"]
+
    def compute_ctc_loss(self, model, sample, net_output, logging_output):
        if "transcript" in sample:
            tokens = sample["transcript"]["tokens"]

@@ -405,7 +417,7 @@ class CtcCriterion(FairseqCriterion):
        target_interleaved_ctc_loss = 0
        target_interleaved_ctc_num = 0
        if self.use_target_ctc:
-           target_tokens = sample["target"]
+           target_tokens = self.get_target_text(sample)
            target_pad_mask = (target_tokens != self.pad_idx) & (target_tokens != self.eos_idx)
            target_no_padding_mask = ~target_pad_mask

@@ -557,7 +569,7 @@ class CtcCriterion(FairseqCriterion):
            if target_lprobs is not None:
                target_lprobs_t = target_lprobs.transpose(0, 1).float().contiguous().cpu()
-               target_tokens = sample["target"]
+               target_tokens = self.get_target_text(sample)
                if mixup:
                    idx = mixup_idx1 if mixup_coef > 0.5 else mixup_idx2
                    target_tokens = target_tokens[idx]
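Note: the new aligned_target_ctc option switches which tokens the target CTC loss is computed against. Below is a minimal sketch of that selection logic (not the repository's code), assuming fairseq-style S2T sample dicts that may carry an "aligned_target" entry; the helper name is illustrative.

    # Illustrative helper; name and structure are assumptions for this sketch.
    def pick_target_text(sample, aligned_target_ctc: bool):
        # Prefer the aligned transcription when the flag is set and the batch provides one;
        # otherwise fall back to the ordinary target tokens.
        if aligned_target_ctc and "aligned_target" in sample:
            return sample["aligned_target"]["tokens"]
        return sample["target"]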
fairseq/models/speech_to_text/s2t_ctc.py

@@ -283,6 +283,8 @@ def base_architecture(args):
    args.sae_out_norm = getattr(args, "sae_out_norm", False)
    args.sae_drop_prob = getattr(args, "sae_drop_prob", 0)
    args.sae_distribution_cutoff = getattr(args, "sae_distribution_cutoff", None)
+   args.sae_distribution_hard = getattr(args, "sae_distribution_hard", False)
+   args.sae_gumbel = getattr(args, "sae_gumbel", False)

    # mixup
    args.inter_mixup = getattr(args, "inter_mixup", False)
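Note: base_architecture fills in defaults with getattr so configurations saved before these flags existed still load. A small self-contained sketch of that idiom follows (the function name is illustrative, not part of the repository).

    from argparse import Namespace

    def apply_sae_defaults(args: Namespace) -> Namespace:
        # getattr with a default keeps older configs without the new flags usable
        args.sae_distribution_hard = getattr(args, "sae_distribution_hard", False)
        args.sae_gumbel = getattr(args, "sae_gumbel", False)
        return args

    print(apply_sae_defaults(Namespace()))
    # Namespace(sae_distribution_hard=False, sae_gumbel=False)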
fairseq/models/speech_to_text/s2t_sate.py

import logging
import math
import os

import torch
import torch.nn as nn

@@ -124,6 +125,11 @@ class S2TSATEModel(S2TTransformerModel):
        )
        # target CTC
+       parser.add_argument(
+           "--target-sae-adapter",
+           type=str,
+           help="adapter type of target sae ",
+       )
        parser.add_argument(
            "--target-ctc-layer",
            default=0,
            type=int,

@@ -300,7 +306,6 @@ class TextualEncoder(FairseqEncoder):
                self.ctc.ctc_projection.weight.size() == embed_tokens.weight.size():
            self.ctc.ctc_projection.weight = embed_tokens.weight

        self.interleaved_ctc_temperature = args.interleaved_ctc_temperature
        self.interleaved_ctc_drop_prob = args.interleaved_ctc_drop_prob
        self.interleaved_ctc_layers = []
        self.target_interleaved_ctc_layers = getattr(args, "target_interleaved_ctc_layers", None)

@@ -330,11 +335,14 @@ class TextualEncoder(FairseqEncoder):
                "embed_norm": getattr(args, "sae_embed_norm", False),
                "out_norm": getattr(args, "sae_out_norm", False),
                "ctc_compress_strategy": getattr(args, "ctc_compress_strategy", None),
                "ctc_temperature": getattr(args, "sae_ctc_temperature", 1.0),
                "distribution_cutoff": getattr(args, "sae_distribution_cutoff", None),
                "gumbel": getattr(args, "sae_gumbel", False),
                "distribution_hard": getattr(args, "sae_distribution_hard", None),
                "drop_prob": getattr(args, "sae_drop_prob", 0),
            }
-           self.sae = Adapter(embed_dim, args.sae_adapter,
+           self.sae = Adapter(embed_dim, args.target_sae_adapter,
                               len(dictionary), strategy=strategy)
            if args.share_target_sae_and_ctc and hasattr(self.sae, "embed_adapter"):

@@ -372,7 +380,6 @@ class TextualEncoder(FairseqEncoder):
                norm_x = self.layer_norm(x)
                logit = self.ctc(norm_x, encoder_padding_mask, "Target Layer %d" % layer_idx)
                target_interleaved_ctc_logits.append(logit)

-               prob = utils.softmax(logit / self.interleaved_ctc_temperature, dim=-1)

                # CTC alignment
                oracle = None

@@ -386,7 +393,8 @@ class TextualEncoder(FairseqEncoder):
                            device=oracle.device) < self.sae_ground_truth_ratio).bool()
                    force_emit = best_aligns_pad.masked_fill(~oracle_mask, -1)

-               x, encoder_padding_mask = self.sae([norm_x, prob], encoder_padding_mask, oracle, oracle_mask)
+               if self.sae.adapter_type != "none":
+                   x, encoder_padding_mask = self.sae([norm_x, logit], encoder_padding_mask, oracle, oracle_mask)

            if history is not None:
                history.push(x)

@@ -398,7 +406,7 @@ class TextualEncoder(FairseqEncoder):
            x = self.layer_norm(x)

        if self.use_ctc and target_ctc_logit is None:
-           target_ctc_logit = self.ctc(x, encoder_padding_mask, "Target output")
+           target_ctc_logit = self.ctc(x, encoder_padding_mask, "Target output", is_top=True)

        return x, target_ctc_logit, target_interleaved_ctc_logits

@@ -460,13 +468,19 @@ class S2TSATEEncoder(FairseqEncoder):
        else:
            self.history = None

-   def set_ctc_infer(self, ctc_infer, post_process, src_dict=None, tgt_dict=None):
+   def set_ctc_infer(self, ctc_infer, post_process, src_dict=None, tgt_dict=None, path=None):
        if hasattr(self.acoustic_encoder, "ctc"):
            assert src_dict is not None
-           self.acoustic_encoder.ctc.set_infer(ctc_infer, post_process, src_dict)
+           logger.info("Acoustic Encoder CTC Inference")
+           self.acoustic_encoder.ctc.set_infer(ctc_infer, post_process, src_dict,
+                                               path=path + ".src_ctc" if path is not None else None)
+                                               # path=os.path.join(path, "src_ctc") if path is not None else None)
        if hasattr(self.textual_encoder, "ctc"):
            assert tgt_dict is not None
-           self.textual_encoder.ctc.set_infer(ctc_infer, post_process, tgt_dict)
+           logger.info("Textual Encoder CTC Inference")
+           self.textual_encoder.ctc.set_infer(ctc_infer, post_process, tgt_dict,
+                                              path=path + ".tgt_ctc" if path is not None else None)
+                                              # path=os.path.join(path, "tgt_ctc") if path is not None else None)

    def ctc_valid(self, lprobs, targets, input_lengths, dictionary, lang="source"):
        if lang == "source":

@@ -500,11 +514,11 @@ class S2TSATEEncoder(FairseqEncoder):
        if "ctc_logit" in acoustic_encoder_out and len(acoustic_encoder_out["ctc_logit"]) > 0:
            ctc_logit = acoustic_encoder_out["ctc_logit"][0]
-           ctc_prob = F.softmax(ctc_logit / self.adapter_temperature, dim=-1, dtype=torch.float32)
+           # ctc_prob = F.softmax(ctc_logit / self.adapter_temperature, dim=-1, dtype=torch.float32)
        else:
            ctc_logit = None
-           ctc_prob = None
-       x = (encoder_out, ctc_prob)
+           # ctc_prob = None
+       x = (encoder_out, ctc_logit)

        x, encoder_padding_mask = self.adapter(x, encoder_padding_mask)

@@ -677,10 +691,13 @@ def base_architecture(args):
    # Semantics-augmented Encoding (sae)
    args.sae_adapter = getattr(args, "sae_adapter", "none")
    args.target_sae_adapter = getattr(args, "target_sae_adapter", args.sae_adapter)
    args.share_sae_and_ctc = getattr(args, "share_sae_and_ctc", False)
    args.share_target_sae_and_ctc = getattr(args, "share_target_sae_and_ctc", False)
    args.sae_drop_prob = getattr(args, "sae_drop_prob", 0)
    args.sae_distribution_cutoff = getattr(args, "sae_distribution_cutoff", None)
    args.sae_distribution_hard = getattr(args, "sae_distribution_hard", False)
    args.sae_gumbel = getattr(args, "sae_gumbel", False)

    # mixup
    args.inter_mixup = getattr(args, "inter_mixup", False)
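Note: set_ctc_infer now threads a single output path down to the acoustic and textual CTC modules, with a different suffix for each dump file. A hedged sketch of that path handling (the helper name is illustrative, not the repository's API):

    def ctc_dump_paths(base_path):
        # The SATE encoder derives one dump file per CTC head from one base path.
        if base_path is None:
            return None, None
        return base_path + ".src_ctc", base_path + ".tgt_ctc"

    print(ctc_dump_paths("decode/translation.txt"))
    # ('decode/translation.txt.src_ctc', 'decode/translation.txt.tgt_ctc')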
fairseq/models/speech_to_text/s2t_transformer.py

@@ -415,7 +415,7 @@ class S2TTransformerModel(FairseqEncoderDecoderModel):
            help="the position of interleaved ctc layers, separated by comma ",
        )
        parser.add_argument(
-           "--interleaved-ctc-temperature",
+           "--sae-ctc-temperature",
            default=1,
            type=float,
            help="temperature of the CTC probability in sae",

@@ -447,6 +447,16 @@ class S2TTransformerModel(FairseqEncoderDecoderModel):
            help="cutoff of the distribution in sae",
        )
+       parser.add_argument(
+           "--sae-gumbel",
+           action="store_true",
+           help="use gumbel softmax in sae",
+       )
+       parser.add_argument(
+           "--sae-distribution-hard",
+           action="store_true",
+           help="use hard distribution in sae",
+       )
        parser.add_argument(
            "--sae-ground-truth-ratio",
            default=0,
            type=float,

@@ -643,7 +653,8 @@ class S2TTransformerEncoder(FairseqEncoder):
        else:
            self.history = None

-       self.use_ctc = "sate" in args.arch or getattr(args, "ctc_weight", 0) > 0
+       # self.use_ctc = "sate" in args.arch or getattr(args, "ctc_weight", 0) > 0
+       self.use_ctc = getattr(args, "ctc_weight", 0) > 0
        if self.use_ctc:
            self.ctc_layer = args.ctc_layer
            self.inter_ctc = True if self.ctc_layer != 0 and self.ctc_layer != args.encoder_layers else False

@@ -659,11 +670,12 @@ class S2TTransformerEncoder(FairseqEncoder):
                embed_tokens is not None and dim == embed_tokens.embedding_dim:
            self.ctc.ctc_projection.weight = embed_tokens.weight

        self.interleaved_ctc_temperature = args.interleaved_ctc_temperature
        self.interleaved_ctc_drop_prob = args.interleaved_ctc_drop_prob
        self.sae_ground_truth_ratio = getattr(args, "sae_ground_truth_ratio", 0)
        self.interleaved_ctc_layers = []
        self.use_inter_ctc = False
        if args.interleaved_ctc_layers is not None:
            self.use_inter_ctc = True
            interleaved_ctc_layers = args.interleaved_ctc_layers.split(",")
            for layer_idx in interleaved_ctc_layers:
                layer_idx = int(layer_idx)

@@ -687,7 +699,10 @@ class S2TTransformerEncoder(FairseqEncoder):
                "embed_norm": getattr(args, "sae_embed_norm", False),
                "out_norm": getattr(args, "sae_out_norm", False),
                "ctc_compress_strategy": getattr(args, "ctc_compress_strategy", None),
                "ctc_temperature": getattr(args, "sae_ctc_temperature", 1.0),
                "distribution_cutoff": getattr(args, "sae_distribution_cutoff", None),
                "gumbel": getattr(args, "sae_gumbel", False),
                "distribution_hard": getattr(args, "sae_distribution_hard", None),
                "gt_ratio": self.sae_ground_truth_ratio,
                "drop_prob": getattr(args, "sae_drop_prob", 0),
            }

@@ -720,10 +735,18 @@ class S2TTransformerEncoder(FairseqEncoder):
        # debug the variance
        self.debug_var = False

-   def set_ctc_infer(self, ctc_infer, post_process, src_dict=None, tgt_dict=None):
+       self.update_num = 0
+       self.curr_temp = 0
+
+   def set_num_updates(self, num_updates):
+       super().set_num_updates(num_updates)
+       self.update_num = num_updates
+
+   def set_ctc_infer(self, ctc_infer, post_process, src_dict=None, tgt_dict=None, path=None):
        if hasattr(self, "ctc"):
            assert src_dict is not None
-           self.ctc.set_infer(ctc_infer, post_process, src_dict)
+           self.ctc.set_infer(ctc_infer, post_process, src_dict,
+                              path=path + ".ctc" if path is not None else None)

    def ctc_valid(self, lprobs, targets, input_lengths, dictionary, lang="source"):

@@ -906,13 +929,8 @@ class S2TTransformerEncoder(FairseqEncoder):
                norm_x = self.layer_norm(x)
                logit = self.ctc(norm_x, encoder_padding_mask, "Source Layer %d" % layer_idx)
                interleaved_ctc_logits.append(logit)

-               logit = logit.clamp(min=-1e8 if logit.dtype == torch.float32 else -1e4,
-                                   max=1e8 if logit.dtype == torch.float32 else 1e4)
-               prob = utils.softmax(logit / self.interleaved_ctc_temperature, dim=-1)

                # CTC alignment
                oracle = None
                oracle_mask = None

@@ -925,7 +943,8 @@ class S2TTransformerEncoder(FairseqEncoder):
                            device=oracle.device) < self.sae_ground_truth_ratio).bool()
                    force_emit = best_aligns_pad.masked_fill(~oracle_mask, -1)

-               x, encoder_padding_mask = self.sae([norm_x, prob], encoder_padding_mask, oracle, oracle_mask)
+               if self.sae.adapter_type != "none":
+                   x, encoder_padding_mask = self.sae([norm_x, logit], encoder_padding_mask, oracle, oracle_mask)
                self.show_debug(x, "x after sae")

        # gather cosine similarity

@@ -945,7 +964,7 @@ class S2TTransformerEncoder(FairseqEncoder):
            self.show_debug(x, "x after encoding layer norm")

        if self.use_ctc and ctc_logit is None:
-           ctc_logit = self.ctc(x, encoder_padding_mask, "Source output")
+           ctc_logit = self.ctc(x, encoder_padding_mask, "Source output", is_top=True)
            self.show_debug(x, "x after ctc")

        return {

@@ -1145,6 +1164,8 @@ def base_architecture(args):
    args.sae_out_norm = getattr(args, "sae_out_norm", False)
    args.sae_drop_prob = getattr(args, "sae_drop_prob", 0)
    args.sae_distribution_cutoff = getattr(args, "sae_distribution_cutoff", None)
+   args.sae_distribution_hard = getattr(args, "sae_distribution_hard", False)
+   args.sae_gumbel = getattr(args, "sae_gumbel", False)

    # mixup
    args.inter_mixup = getattr(args, "inter_mixup", False)
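Note: the commit introduces --sae-gumbel and --sae-distribution-hard as boolean flags. Purely for illustration, this is how two store_true flags of that shape behave on a bare argparse parser; the models register them inside add_args in the same way.

    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument("--sae-gumbel", action="store_true",
                        help="use gumbel softmax in sae")
    parser.add_argument("--sae-distribution-hard", action="store_true",
                        help="use hard distribution in sae")

    args = parser.parse_args(["--sae-gumbel"])
    print(args.sae_gumbel, args.sae_distribution_hard)  # True False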
fairseq/models/transformer_ctc.py

@@ -319,7 +319,7 @@ class TransformerCTCModel(FairseqEncoderDecoderModel):
            help="upsampling ratio of the representation for CTC calculation",
        )
        parser.add_argument(
-           "--interleaved-ctc-temperature",
+           "--sae-ctc-temperature",
            default=1,
            type=float,
            help="temperature of the CTC probability in sae",

@@ -351,6 +351,16 @@ class TransformerCTCModel(FairseqEncoderDecoderModel):
            help="cutoff of the distribution in sae",
        )
+       parser.add_argument(
+           "--sae-gumbel",
+           action="store_true",
+           help="use gumbel softmax in sae",
+       )
+       parser.add_argument(
+           "--sae-distribution-hard",
+           action="store_true",
+           help="use hard distribution in sae",
+       )
        parser.add_argument(
            "--share-ctc-and-sae",
            action="store_true",
            help="share the weight of ctc and sae",

@@ -629,7 +639,6 @@ class TransformerCTCEncoder(FairseqEncoder):
            self.ctc.ctc_projection.weight = decoder_embed_tokens.weight

        self.interleaved_ctc_temperature = args.interleaved_ctc_temperature
        self.interleaved_ctc_drop_prob = args.interleaved_ctc_drop_prob
        self.interleaved_ctc_upsampling_ratio = int(args.interleaved_ctc_upsampling_ratio)
        self.interleaved_ctc_layers = []

@@ -661,7 +670,10 @@ class TransformerCTCEncoder(FairseqEncoder):
                "embed_norm": getattr(args, "sae_embed_norm", False),
                "out_norm": getattr(args, "sae_out_norm", False),
                "ctc_compress_strategy": getattr(args, "ctc_compress_strategy", None),
                "ctc_temperature": getattr(args, "sae_ctc_temperature", 1.0),
                "distribution_cutoff": getattr(args, "sae_distribution_cutoff", None),
                "gumbel": getattr(args, "sae_gumbel", False),
                "distribution_hard": getattr(args, "sae_distribution_hard", None),
                "drop_prob": getattr(args, "sae_drop_prob", 0),
                "gt_ratio": self.sae_ground_truth_ratio,
            }

@@ -743,9 +755,6 @@ class TransformerCTCEncoder(FairseqEncoder):
            return x

        if len(x.size()) == 3:
-           # bsz, seq_len, dim = x.size()
-           # up_x = x.unsqueeze(2).expand(-1, -1, ratio, -1).reshape(bsz, -1, dim)
            seq_len, bsz, dim = x.size()
            x = x.permute(1, 2, 0)
            up_x = self.un_sample(x)

@@ -755,20 +764,25 @@ class TransformerCTCEncoder(FairseqEncoder):
            up_x = x.unsqueeze(2).expand(-1, -1, ratio).reshape(bsz, -1)

        up_padding = padding.unsqueeze(-1).expand(-1, -1, int(ratio)).reshape(bsz, -1)

-       # output_length = int(seq_len * ratio * 2/3)
-       # select_matrix = torch.rand(bsz, ratio * seq_len).to(up_x.device)
-       # select_matrix[:, 1::ratio] = 1
-       # mask = select_matrix.sort(dim=-1, descending=True)[1][:, :output_length]
-       # mask = mask.sort(dim=-1)[0]
-       #
-       # if len(x.size()) == 3:
-       #     out_x = torch.gather(up_x, dim=1, index=mask.unsqueeze(-1).expand(-1, -1, dim)).contiguous()
-       # else:
-       #     out_x = torch.gather(up_x, dim=1, index=mask).contiguous()
-       # out_padding = torch.gather(up_padding, dim=1, index=mask).contiguous()
-       out_x = up_x
-       out_padding = up_padding
+       perturb = False
+       if perturb:
+           output_length = int(seq_len * ratio * 2/3)
+           select_matrix = torch.rand(bsz, ratio * seq_len).to(up_x.device)
+           select_matrix[:, 1::ratio] = 1
+           mask = select_matrix.sort(dim=-1, descending=True)[1][:, :output_length]
+           mask = mask.sort(dim=-1)[0]
+
+           if len(x.size()) == 3:
+               up_x = up_x.transpose(0, 1)
+               out_x = torch.gather(up_x, dim=1, index=mask.unsqueeze(-1).expand(-1, -1, dim)).contiguous()
+               out_x = out_x.transpose(0, 1)
+           else:
+               out_x = torch.gather(up_x, dim=1, index=mask).contiguous()
+           out_padding = torch.gather(up_padding, dim=1, index=mask).contiguous()
+       else:
+           out_x = up_x.contiguous()
+           out_padding = up_padding.contiguous()

        return out_x, out_padding

    def set_ctc_infer(self, ctc_infer, post_process, src_dict=None, tgt_dict=None):

@@ -869,8 +883,8 @@ class TransformerCTCEncoder(FairseqEncoder):
            # CTC
            if self.use_ctc and self.inter_ctc and self.ctc_layer == layer_idx:
-               x, ctc_padding_mask = self.upsampling(x, encoder_padding_mask)
-               ctc_logit = self.ctc(x.clone(), ctc_padding_mask)
+               up_x, ctc_padding_mask = self.upsampling(x, encoder_padding_mask)
+               ctc_logit = self.ctc(up_x, ctc_padding_mask)

            # Interleaved CTC
            if layer_idx in self.interleaved_ctc_layers:

@@ -879,12 +893,10 @@ class TransformerCTCEncoder(FairseqEncoder):
                if p < self.interleaved_ctc_drop_prob:
                    break

-               x, ctc_padding_mask = self.upsampling(x, encoder_padding_mask)
-               norm_x = self.layer_norm(x)
+               up_x, ctc_padding_mask = self.upsampling(x, encoder_padding_mask)
+               norm_x = self.layer_norm(up_x)
                logit = self.ctc(norm_x, ctc_padding_mask)
                interleaved_ctc_logits.append(logit)

-               prob = utils.softmax(logit / self.interleaved_ctc_temperature, dim=-1)

                # CTC alignment
                oracle = None

@@ -898,7 +910,7 @@ class TransformerCTCEncoder(FairseqEncoder):
                            device=oracle.device) < self.sae_ground_truth_ratio).bool()
                    force_emit = best_aligns_pad.masked_fill(~oracle_mask, -1)

-               x, _ = self.sae([norm_x, prob], ctc_padding_mask, oracle, oracle_mask)
+               x, _ = self.sae([norm_x, logit], ctc_padding_mask, oracle, oracle_mask)

                x = x.permute(1, 2, 0)
                # x = nn.functional.interpolate(x, scale_factor=1/self.interleaved_ctc_upsampling_ratio, mode="linear")

@@ -915,7 +927,8 @@ class TransformerCTCEncoder(FairseqEncoder):
            x = self.layer_norm(x)

        if self.use_ctc and ctc_logit is None:
-           ctc_logit = self.ctc(x, ctc_padding_mask)
+           up_x, ctc_padding_mask = self.upsampling(x, encoder_padding_mask)
+           ctc_logit = self.ctc(up_x, ctc_padding_mask)

        # The Pytorch Mobile lite interpreter does not supports returning NamedTuple in
        # `forward` so we use a dictionary instead.

@@ -1592,6 +1605,8 @@ def base_architecture(args):
    args.share_ctc_and_sae = getattr(args, "share_ctc_and_sae", False)
    args.sae_drop_prob = getattr(args, "sae_drop_prob", 0)
    args.sae_distribution_cutoff = getattr(args, "sae_distribution_cutoff", None)
+   args.sae_distribution_hard = getattr(args, "sae_distribution_hard", False)
+   args.sae_gumbel = getattr(args, "sae_gumbel", False)


@register_model_architecture("transformer_ctc", "transformer_ctc_relative")
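Note: the interleaved CTC in transformer_ctc.py now keeps the upsampled tensor in a separate up_x so the encoder stream x is left untouched. A small sketch of the repeat-style upsampling pattern, assuming x is (seq_len, batch, dim) and padding is (batch, seq_len); this mirrors the unsqueeze/expand/reshape idiom in the diff but is not the repository's exact code.

    import torch

    def upsample(x: torch.Tensor, padding: torch.Tensor, ratio: int):
        seq_len, bsz, dim = x.size()
        # (T, B, D) -> (B, T*ratio, D): every frame is repeated `ratio` times
        up_x = x.permute(1, 0, 2).unsqueeze(2).expand(-1, -1, ratio, -1).reshape(bsz, -1, dim)
        up_padding = padding.unsqueeze(-1).expand(-1, -1, ratio).reshape(bsz, -1)
        return up_x.permute(1, 0, 2), up_padding

    x = torch.randn(4, 2, 8)                      # T=4, B=2, D=8
    padding = torch.zeros(2, 4, dtype=torch.bool)
    up_x, up_padding = upsample(x, padding, 3)
    print(up_x.shape, up_padding.shape)           # torch.Size([12, 2, 8]) torch.Size([2, 12])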
fairseq/modules/speech_to_text/adapter.py

@@ -100,34 +100,45 @@ class Adapter(nn.Module):
        if self.cal_context:
            self.distribution_cutoff = strategy.get("distribution_cutoff", None)
+           self.distribution_temperature = strategy.get("ctc_temperature", 1.0)
+           self.gumbel = strategy.get("gumbel", False)
+           self.distribution_hard = strategy.get("distribution_hard", False)
+           self.ground_truth_ratio = strategy.get("gt_ratio", 0)
+           self.drop_prob = strategy.get("drop_prob", 0)
+
            if self.distribution_cutoff is not None:
                self.distribution_cutoff = int(self.distribution_cutoff)
                logger.info("Distribution cutoff: %d" % self.distribution_cutoff)
-       self.drop_prob = strategy.get("drop_prob", 0)
+           if self.distribution_temperature != 1.0:
+               logger.info("Temperature: %f" % self.distribution_temperature)
+           if self.gumbel:
+               logger.info("Gumbel softmax.")
+           if self.distribution_hard:
+               logger.info("Hard distribution.")
            if self.drop_prob != 0:
-               logger.info("Adapter drop probability: %f" % self.drop_prob)
+               logger.info("Drop probability: %f" % self.drop_prob)

-       self.ground_truth_ratio = strategy.get("gt_ratio", 0)
        self.out_norm = strategy.get("out_norm", False)
        if self.out_norm:
            self.out_ln = LayerNorm(dim)

    def forward(self, x, padding=None, oracle=None, oracle_mask=None):
-       representation, distribution = x
-       distribution = distribution.type_as(representation)
+       representation, logit = x
        seq_len, bsz, dim = representation.size()
-       org_distribution = distribution
-       vocab_size = distribution.size(-1)
-       distribution = distribution.contiguous().view(-1, vocab_size)

        linear_out = None
        soft_out = None
        if self.cal_linear:
            linear_out = self.linear_adapter(representation)

        if self.cal_context:
+           if self.training and self.gumbel:
+               distribution = F.gumbel_softmax(logit, tau=self.distribution_temperature, hard=self.distribution_hard)
+           else:
+               distribution = F.softmax(logit / self.distribution_temperature, dim=-1)
+           vocab_size = distribution.size(-1)
+           distribution = distribution.contiguous().view(-1, vocab_size)
+           org_distribution = distribution
            if self.distribution_cutoff is not None:
                cutoff = min(int(self.distribution_cutoff), vocab_size - 1)

@@ -184,11 +195,15 @@ class Adapter(nn.Module):
            out = representation
        elif self.adapter_type == "shrink":
+           if self.training and self.gumbel:
+               distribution = F.gumbel_softmax(logit, tau=self.distribution_temperature, hard=self.distribution_hard)
+           else:
+               distribution = F.softmax(logit / self.distribution_temperature, dim=-1)
            lengths = (~padding).long().sum(-1)

            with torch.no_grad():
                batch_predicted = []
-               prob_ctc = org_distribution.transpose(0, 1)  # T x B x D -> B x T x D
+               prob_ctc = distribution.transpose(0, 1)  # T x B x D -> B x T x D
                for b in range(prob_ctc.shape[0]):
                    predicted = prob_ctc[b][: lengths[b]].argmax(-1).tolist()
                    batch_predicted.append([(p[0], len(list(p[1]))) for p in groupby(predicted)])
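Note: the Adapter now receives raw CTC logits and builds the distribution itself, optionally with Gumbel-softmax during training. A hedged sketch of that computation follows; the parameter names mirror the strategy keys in the diff, but the function itself is illustrative.

    import torch
    import torch.nn.functional as F

    def adapter_distribution(logit, temperature=1.0, gumbel=False, hard=False, training=True):
        if training and gumbel:
            # Stochastic but differentiable samples; hard=True snaps them to one-hot
            return F.gumbel_softmax(logit, tau=temperature, hard=hard)
        # Temperature-scaled softmax otherwise (and always at inference)
        return F.softmax(logit / temperature, dim=-1)

    logit = torch.randn(5, 2, 100)  # (T, B, vocab)
    soft = adapter_distribution(logit, temperature=2.0)
    sampled = adapter_distribution(logit, gumbel=True, hard=True)
    print(soft.shape, sampled.sum(-1)[0, 0])  # torch.Size([5, 2, 100]) tensor(1.)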
fairseq/modules/speech_to_text/ctc.py

@@ -39,18 +39,23 @@ class CTC(nn.Module):
        self.post_process = "sentencepiece"
        self.blank_idx = 0

-   def set_infer(self, is_infer, text_post_process, dictionary):
+   def set_infer(self, is_infer, text_post_process, dictionary, path):
        self.infer_decoding = is_infer
        self.post_process = text_post_process
        self.dictionary = dictionary
+       self.path = path
+       if self.path is not None:
+           self.save_stream = open(self.path, "a")
+       else:
+           self.save_stream = None

-   def forward(self, x, padding=None, tag=None):
+   def forward(self, x, padding=None, tag=None, is_top=False):
        if self.need_layernorm:
            x = self.LayerNorm(x)

        x = self.ctc_projection(self.ctc_dropout_module(x))
-       if not self.training and self.infer_decoding:
+       if not self.training and self.infer_decoding and is_top:
            assert self.dictionary is not None
            input_lengths = (~padding).sum(-1)
            self.infer(x.transpose(0, 1).float().contiguous().cpu(), input_lengths, tag)

@@ -79,6 +84,9 @@ class CTC(nn.Module):
        pred_units = self.dictionary.string(pred_units_arr)
        pred_words_raw = post_process(pred_units, self.post_process).split()

+       if self.save_stream is not None:
+           self.save_stream.write(" ".join(pred_words_raw) + "\n")
+
        if tag is not None:
            logger.info("%s CTC prediction: %s" % (tag, " ".join(pred_words_raw)))
        else:
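Note: when set_infer is given a path, the CTC module keeps an append-mode stream and writes one hypothesis per line during inference. A minimal sketch of that pattern (the class below is an illustration, not the repository's CTC module):

    class PredictionWriter:
        def __init__(self, path=None):
            # Append mode, matching the diff's open(self.path, "a")
            self.save_stream = open(path, "a") if path is not None else None

        def write_hypothesis(self, pred_words):
            if self.save_stream is not None:
                self.save_stream.write(" ".join(pred_words) + "\n")

    writer = PredictionWriter("decode.txt.ctc")
    writer.write_hypothesis(["hello", "world"])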
fairseq_cli/generate.py

@@ -108,7 +108,7 @@ def _main(cfg: DictConfig, output_file, translation_path=None):
    for model in models:
        if hasattr(model, "encoder") and hasattr(model.encoder, "set_ctc_infer"):
            model.encoder.set_ctc_infer(cfg.generation.ctc_infer, "sentencepiece",
-                                       src_dict, tgt_dict)
+                                       src_dict, tgt_dict, translation_path)
+                                       # os.path.dirname(translation_path))

    # loading the dataset should happen after the checkpoint has been loaded so we can give it the saved task config
    task.load_dataset(cfg.dataset.gen_subset, task_cfg=saved_cfg.task)