xuchen / Fairseq-S2T · Commits

Commit 31e7c426 authored Jan 10, 2022 by xuchen

fix the bug of the intermedia ctc losses
parent 2215ade0
Showing 7 changed files with 104 additions and 72 deletions.

egs/libri_trans/asr/conf/debug.yaml                             +1   -1
egs/wmt16/mt/local/wmt_en2de_multi_bleu.sh                      +8   -7
fairseq/criterions/label_smoothed_cross_entropy_with_ctc.py    +53  -14
fairseq/models/speech_to_text/__init__.py                       +1   -1
fairseq/models/speech_to_text/ctc.py                            +6   -9
fairseq/models/speech_to_text/inter_ctc_s2t_transformer.py     +34  -39
fairseq/models/speech_to_text/s2t_sate.py                       +1   -1

egs/libri_trans/asr/conf/debug.yaml

 arch: multi_ctc_s2t_transformer_s
-multi-ctc-layers: 6,8,10,12
+intermedia-ctc-layers: 6,8,10
 share-decoder-input-output-embed: True
 optimizer: adam
 clip-norm: 10.0
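
The renamed intermedia-ctc-layers option is a comma-separated list of encoder layers that get an auxiliary CTC head; non-positive indices count back from the last layer. A minimal sketch of how such a spec resolves to absolute layer indices, mirroring the encoder constructor further down (the helper name parse_ctc_layer_spec is illustrative, not part of the repository):

def parse_ctc_layer_spec(spec: str, encoder_layers: int):
    """Resolve a comma-separated layer spec such as "6,8,10" into absolute indices."""
    layers = []
    if spec:
        for token in spec.split(","):
            layer_idx = int(token)
            if layer_idx <= 0:
                # non-positive indices count back from the top encoder layer
                layer_idx += encoder_layers
            layers.append(layer_idx)
    return layers

print(parse_ctc_layer_spec("6,8,10", encoder_layers=12))   # [6, 8, 10]
print(parse_ctc_layer_spec("-1,0", encoder_layers=12))     # [11, 12]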

egs/wmt16/mt/local/wmt_en2de_multi_bleu.sh

@@ -29,21 +29,21 @@ cat $GEN | cut -f 3 > $REF
 cat $GEN | cut -f 4 > $SYS
 #detokenize the decodes file to format the manner to do tokenize
-perl $detokenizer -l de < $SYS > $SYS.dtk
-perl $detokenizer -l de < $REF > $REF.dtk
+$detokenizer -l de < $SYS > $SYS.dtk
+$detokenizer -l de < $REF > $REF.dtk
 #replace unicode
-perl $replace_unicode_punctuation -l de < $SYS.dtk > $SYS.dtk.punc
-perl $replace_unicode_punctuation -l de < $REF.dtk > $REF.dtk.punc
+$replace_unicode_punctuation -l de < $SYS.dtk > $SYS.dtk.punc
+$replace_unicode_punctuation -l de < $REF.dtk > $REF.dtk.punc
 #tokenize the decodes file by moses tokenizer.perl
-perl $tokenizer -l de < $SYS.dtk.punc > $SYS.dtk.punc.tok
-perl $tokenizer -l de < $REF.dtk.punc > $REF.dtk.punc.tok
+$tokenizer -l de < $SYS.dtk.punc > $SYS.dtk.punc.tok
+$tokenizer -l de < $REF.dtk.punc > $REF.dtk.punc.tok
 #"rich-text format" --> rich ##AT##-##AT## text format.
 perl -ple 's{(\S)-(\S)}{$1 ##AT##-##AT## $2}g' < $SYS.dtk.punc.tok > $SYS.dtk.punc.tok.atat
 perl -ple 's{(\S)-(\S)}{$1 ##AT##-##AT## $2}g' < $REF.dtk.punc.tok > $REF.dtk.punc.tok.atat
-perl $multi_bleu $REF.dtk.punc.tok.atat < $SYS.dtk.punc.tok.atat
+$multi_bleu $REF.dtk.punc.tok.atat < $SYS.dtk.punc.tok.atat
+rm -f $SYS.dtk $SYS.dtk.punc $SYS.dtk.punc.tok $REF.dtk $REF.dtk.punc $REF.dtk.punc.tok
\ No newline at end of file
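
For reference, the ##AT##-##AT## substitution kept as context above splits intra-word hyphens into a separate token, matching the convention of the preprocessed WMT16 en-de references. A Python equivalent of that Perl one-liner, shown for illustration only (the script itself keeps using Perl):

import re

# Minimal Python equivalent of the Perl one-liner above:
#   s{(\S)-(\S)}{$1 ##AT##-##AT## $2}g
def at_split_hyphens(line: str) -> str:
    return re.sub(r"(\S)-(\S)", r"\1 ##AT##-##AT## \2", line)

print(at_split_hyphens("rich-text format"))  # -> "rich ##AT##-##AT## text format"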

fairseq/criterions/label_smoothed_cross_entropy_with_ctc.py

@@ -19,7 +19,8 @@ from .label_smoothed_cross_entropy import LabelSmoothedCrossEntropyCriterion
 class LabelSmoothedCrossEntropyCriterionWithCTC(
         LabelSmoothedCrossEntropyCriterion):

-    def __init__(self, task, sentence_avg, label_smoothing, post_process="letter", ctc_weight=0.0):
+    def __init__(self, task, sentence_avg, label_smoothing, post_process="letter", ctc_weight=0.0,
+                 intermedia_ctc_weight=0.0):
         super().__init__(task, sentence_avg, label_smoothing)
         self.blank_idx = task.target_dictionary.index(task.blank_symbol) if hasattr(task, 'blank_symbol') else 0
         self.pad_idx = task.target_dictionary.pad()

@@ -29,7 +30,8 @@ class LabelSmoothedCrossEntropyCriterionWithCTC(
         assert 0 <= ctc_weight
         self.ctc_weight = ctc_weight
-        if self.ctc_weight > 0:
+        self.intermedia_ctc_weight = intermedia_ctc_weight
+        if self.ctc_weight > 0 or self.intermedia_ctc_weight > 0:
             assert getattr(task, "src_dict", None) is not None, "CTC need a source dictionary."
             self.post_process = post_process
             self.ctc_loss = torch.nn.CTCLoss(blank=self.blank_idx, reduction="sum", zero_infinity=True)

@@ -52,6 +54,13 @@ class LabelSmoothedCrossEntropyCriterionWithCTC(
             help="weight of CTC loss",
         )
+        parser.add_argument(
+            "--intermedia-ctc-weight",
+            default=0.0,
+            type=float,
+            metavar="D",
+            help="weight of intermedia CT loss",
+        )
         parser.add_argument(
             "--post-process",
             default="letter",
             type=str,

@@ -91,10 +100,9 @@ class LabelSmoothedCrossEntropyCriterionWithCTC(
             logging_output["n_correct"] = utils.item(n_correct.data)
             logging_output["total"] = utils.item(total.data)

-        if self.ctc_weight > 0:
+        if self.ctc_weight > 0 or self.intermedia_ctc_weight > 0:
             ctc_loss, logging_output = self.compute_ctc_loss(model, sample, encoder_out, logging_output)
-            logging_output["ctc_loss"] = utils.item(ctc_loss.data)
-            loss = (1 - self.ctc_weight) * loss + self.ctc_weight * ctc_loss
+            loss = (1 - self.ctc_weight) * loss + ctc_loss

         logging_output["loss"] = utils.item(loss.data) if reduce else loss.data
         return loss, sample_size, logging_output

@@ -114,10 +122,7 @@ class LabelSmoothedCrossEntropyCriterionWithCTC(
         transcript_lengths = pad_mask.sum(-1)

         ctc_loss = 0
-        ctc_num = len(encoder_out["ctc_logit"])
-        assert ctc_num != 0, "No ctc logit for loss!"
-        for i in range(ctc_num):
-            ctc_logit = encoder_out["ctc_logit"][i]
+        if "ctc_logit" in encoder_out and len(encoder_out["ctc_logit"]) > 0:
+            ctc_logit = encoder_out["ctc_logit"][0]
             lprobs = model.get_normalized_probs(
                 [ctc_logit], log_probs=True

@@ -125,17 +130,41 @@ class LabelSmoothedCrossEntropyCriterionWithCTC(
             lprobs.batch_first = False
             with torch.backends.cudnn.flags(enabled=False):
-                loss = self.ctc_loss(
+                ctc_loss = self.ctc_loss(
                     lprobs,
                     targets_flat,
                     input_lengths,
                     transcript_lengths,
                 )
-            ctc_loss += loss
-        ctc_loss /= ctc_num
+            logging_output["ctc_loss"] = utils.item(ctc_loss.data)

-        if not model.training:
+        intermedia_ctc_num = 0
+        intermedia_ctc_loss = 0
+        if "intermedia_ctc_logit" in encoder_out:
+            intermedia_ctc_num = len(encoder_out["intermedia_ctc_logit"])
+
+        if intermedia_ctc_num > 0:
+            for i in range(intermedia_ctc_num):
+                ctc_logit = encoder_out["intermedia_ctc_logit"][i]
+                inter_lprobs = model.get_normalized_probs(
+                    [ctc_logit], log_probs=True
+                ).contiguous()  # (T, B, C) from the encoder
+                inter_lprobs.batch_first = False
+
+                with torch.backends.cudnn.flags(enabled=False):
+                    loss = self.ctc_loss(
+                        inter_lprobs,
+                        targets_flat,
+                        input_lengths,
+                        transcript_lengths,
+                    )
+                intermedia_ctc_loss += loss
+            intermedia_ctc_loss /= intermedia_ctc_num
+            logging_output["intermedia_ctc_loss"] = utils.item(intermedia_ctc_loss.data)
+
+        loss = self.ctc_weight * ctc_loss + self.intermedia_ctc_weight * intermedia_ctc_loss
+
+        if not model.training and self.ctc_weight > 0:
             import editdistance
             with torch.no_grad():

@@ -189,7 +218,7 @@ class LabelSmoothedCrossEntropyCriterionWithCTC(
                 logging_output["c_errors"] = c_err
                 logging_output["c_total"] = c_len

-        return ctc_loss, logging_output
+        return loss, logging_output

     @staticmethod
     def reduce_metrics(logging_outputs) -> None:

@@ -204,6 +233,9 @@ class LabelSmoothedCrossEntropyCriterionWithCTC(
         ctc_loss_sum = utils.item(
             sum(log.get("ctc_loss", 0) for log in logging_outputs)
         )
+        inter_ctc_loss_sum = utils.item(
+            sum(log.get("intermedia_ctc_loss", 0) for log in logging_outputs)
+        )
         ntokens = utils.item(sum(log.get("ntokens", 0) for log in logging_outputs))
         sample_size = utils.item(
             sum(log.get("sample_size", 0) for log in logging_outputs)

@@ -226,6 +258,13 @@ class LabelSmoothedCrossEntropyCriterionWithCTC(
                 sample_size,
                 round=3,
             )
+        if inter_ctc_loss_sum > 0:
+            metrics.log_scalar(
+                "intermedia_ctc_loss",
+                inter_ctc_loss_sum / sample_size / math.log(2),
+                sample_size,
+                round=3,
+            )

         metrics.log_derived(
             "ppl", lambda meters: utils.get_perplexity(meters["nll_loss"].avg)
         )
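
With this fix the criterion combines three terms: the label-smoothed cross-entropy, the top-layer CTC loss, and the layer-averaged intermediate CTC loss. compute_ctc_loss now returns the already-weighted auxiliary term, and forward adds it to the down-weighted cross-entropy. A minimal sketch of that arithmetic (the helper name combine_losses and the dummy values are illustrative only):

import torch

def combine_losses(ce_loss, ctc_loss, intermedia_ctc_losses,
                   ctc_weight=0.3, intermedia_ctc_weight=0.2):
    # average the per-layer intermediate CTC losses, as compute_ctc_loss now does
    if len(intermedia_ctc_losses) > 0:
        intermedia_ctc_loss = sum(intermedia_ctc_losses) / len(intermedia_ctc_losses)
    else:
        intermedia_ctc_loss = 0.0
    # compute_ctc_loss returns the already-weighted auxiliary term ...
    aux = ctc_weight * ctc_loss + intermedia_ctc_weight * intermedia_ctc_loss
    # ... and forward() adds it to the down-weighted cross-entropy
    return (1 - ctc_weight) * ce_loss + aux

# e.g. with dummy scalar losses:
print(combine_losses(torch.tensor(2.0), torch.tensor(3.0),
                     [torch.tensor(3.5), torch.tensor(3.2), torch.tensor(3.0)]))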

fairseq/models/speech_to_text/__init__.py

@@ -7,7 +7,7 @@ from .berard import *  # noqa
 from .ctc import *  # noqa
 from .convtransformer import *  # noqa
 from .s2t_transformer import *  # noqa
-from .multi_ctc_s2t_transformer import *  # noqa
+from .inter_ctc_s2t_transformer import *  # noqa
 from .s2t_conformer import *  # noqa
 from .pdss2t_transformer import *  # noqa
 from .s2t_sate import *  # noqa

fairseq/models/speech_to_text/ctc.py

@@ -39,17 +39,14 @@ class CTC(nn.Module):
         x = self.ctc_projection(self.ctc_dropout_module(x))
         return x

-    @staticmethod
-    def softmax(ctc_logit, temperature=1.0):
-        return torch.nn.functional.softmax(ctc_logit / temperature, dim=-1)
+    def softmax(self, x, temperature=1.0):
+        return torch.nn.functional.softmax(self.ctc_projection(x) / temperature, dim=-1)

-    @staticmethod
-    def log_softmax(ctc_logit, temperature=1.0):
-        return torch.nn.functional.log_softmax(ctc_logit / temperature, dim=-1)
+    def log_softmax(self, x, temperature=1.0):
+        return torch.nn.functional.log_softmax(self.ctc_projection(x) / temperature, dim=-1)

-    @staticmethod
-    def argmax(ctc_logit):
-        return torch.argmax(ctc_logit, dim=-1)
+    def argmax(self, x):
+        return torch.argmax(self.ctc_projection(x), dim=-1)


 class CTCCompressStrategy:
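
The CTC helper methods become instance methods that run ctc_projection internally, which is why the encoder below now calls ctc.softmax(x) on the raw states instead of ctc.softmax(logit). A small self-contained sketch of that calling convention (ToyCTCHead is illustrative, not the repository's CTC class, whose constructor also takes need_layernorm and other options):

import torch
import torch.nn as nn
import torch.nn.functional as F

class ToyCTCHead(nn.Module):
    """Minimal stand-in for the patched CTC module: softmax() takes encoder
    states and applies the projection itself."""

    def __init__(self, embed_dim=8, dictionary_size=5, dropout=0.1):
        super().__init__()
        self.ctc_projection = nn.Linear(embed_dim, dictionary_size)
        self.ctc_dropout_module = nn.Dropout(p=dropout)

    def forward(self, x):                       # x: (T, B, C) encoder states
        return self.ctc_projection(self.ctc_dropout_module(x))

    def softmax(self, x, temperature=1.0):
        # project first, then normalize; callers pass x, not logits
        return F.softmax(self.ctc_projection(x) / temperature, dim=-1)

head = ToyCTCHead()
x = torch.randn(4, 2, 8)
logit = head(x)                 # used for the CTC loss
prob = head.softmax(x)          # used by the adapter; note the argument is x, not logit
print(logit.shape, prob.shape)  # torch.Size([4, 2, 5]) torch.Size([4, 2, 5])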

fairseq/models/speech_to_text/multi_ctc_s2t_transformer.py → fairseq/models/speech_to_text/inter_ctc_s2t_transformer.py

@@ -25,7 +25,7 @@ logger = logging.getLogger(__name__)
 class Adapter(nn.Module):
-    def __init__(self, args, dictionary, embed_tokens):
+    def __init__(self, args, dictionary, embed_tokens=None):
         super().__init__()

         embed_dim = args.encoder_embed_dim

@@ -45,7 +45,7 @@ class Adapter(nn.Module):
         if self.adapter_type in ["embed", "context", "league", "gated_league", "gated_league2"]:
             if embed_tokens is None:
                 num_embeddings = len(dictionary)
-                self.embed_adapter = Embedding(num_embeddings, embed_dim, self.padding_idx)
+                self.embed_adapter = Embedding(num_embeddings, embed_dim, dictionary.pad())
             else:
                 self.embed_adapter = embed_tokens

@@ -115,9 +115,9 @@ class Adapter(nn.Module):
         return out, padding


-@register_model("multi_ctc_s2t_transformer")
-class MultiCTCS2TTransformerModel(S2TTransformerModel):
-    """Speech-to-Text Transformer with multiple CTC Loss in different layers"""
+@register_model("inter_ctc_s2t_transformer")
+class InterCTCS2TTransformerModel(S2TTransformerModel):
+    """Speech-to-Text Transformer with intermedia CTC Loss in different layers"""

     def __init__(self, encoder, decoder):
         super().__init__(encoder, decoder)

@@ -126,10 +126,10 @@ class MultiCTCS2TTransformerModel(S2TTransformerModel):
     def add_args(parser):
         S2TTransformerModel.add_args(parser)
         parser.add_argument(
-            "--multi-ctc-layers",
+            "--intermedia-ctc-layers",
             default=None,
             type=str,
-            help="the position of the ctc loss, separated by ",
+            help="the position of the ctc loss, separated by comma",
         )
         parser.add_argument(
             "--adapter",

@@ -147,7 +147,7 @@ class MultiCTCS2TTransformerModel(S2TTransformerModel):
     @classmethod
     def build_encoder(cls, args, task=None, embed_tokens=None):
-        encoder = S2TMultiCTCTransformerEncoder(args, task, embed_tokens)
+        encoder = S2TInterCTCTransformerEncoder(args, task, embed_tokens)
         if getattr(args, "load_pretrained_encoder_from", None):
             encoder = checkpoint_utils.load_pretrained_component_from_model(
                 component=encoder, checkpoint=args.load_pretrained_encoder_from, strict=False

@@ -159,40 +159,34 @@ class MultiCTCS2TTransformerModel(S2TTransformerModel):
         return encoder


-class S2TMultiCTCTransformerEncoder(S2TTransformerEncoder):
-    """Speech-to-text Transformer encoder that consists of multiple input subsampler and
-    Conformer encoder."""
+class S2TInterCTCTransformerEncoder(S2TTransformerEncoder):
+    """Speech-to-text Transformer encoder that consists of intermedia ctc losses """

     def __init__(self, args, task=None, embed_tokens=None):
         super().__init__(args, task, embed_tokens)

-        if self.use_ctc:
-            del self.ctc
-
-        self.multi_ctc_layers = []
-        if args.multi_ctc_layers is not None:
-            multi_ctc_layers = args.multi_ctc_layers.split(",")
-            for layer_idx in multi_ctc_layers:
+        self.intermedia_ctc_layers = []
+        if args.intermedia_ctc_layers is not None:
+            intermedia_ctc_layers = args.intermedia_ctc_layers.split(",")
+            for layer_idx in intermedia_ctc_layers:
                 layer_idx = int(layer_idx)
                 if layer_idx <= 0:
                     layer_idx += args.encoder_layers
-                self.multi_ctc_layers.append(layer_idx)
+                self.intermedia_ctc_layers.append(layer_idx)

-                inter_ctc = True if layer_idx != args.encoder_layers else False
-                if inter_ctc:
-                    logger.info("Intermedia CTC loss in layer %d" % layer_idx)
+                logger.info("Intermedia CTC loss in layer %d" % layer_idx)

                 ctc = CTC(args.encoder_embed_dim,
                           dictionary_size=len(task.source_dictionary),
                           dropout=args.dropout,
-                          need_layernorm=inter_ctc)
+                          need_layernorm=True)

-                if task.source_dictionary == task.target_dictionary and embed_tokens is not None:
-                    ctc.ctc_projection.weight = embed_tokens.weight
+                ctc.ctc_projection.weight = self.ctc.ctc_projection.weight
+                ctc.LayerNorm = self.layer_norm

                 setattr(self, f"ctc{layer_idx}", ctc)

-                if inter_ctc:
-                    adapter = Adapter(args, task.source_dictionary, ctc.ctc_projection)
+                adapter = Adapter(args, task.source_dictionary)
+                # adapter = Adapter(args, task.source_dictionary, ctc.ctc_projection)
                 setattr(self, f"adapter{layer_idx}", adapter)

     def forward(self, src_tokens, src_lengths):

@@ -223,7 +217,8 @@ class S2TMultiCTCTransformerEncoder(S2TTransformerEncoder):
             self.history.push(x)

         layer_idx = 0
-        ctc_logit = []
+        ctc_logit = None
+        intermedia_ctc_logit = []
         for layer in self.layers:
             layer_idx += 1

@@ -234,14 +229,14 @@ class S2TMultiCTCTransformerEncoder(S2TTransformerEncoder):
             x = layer(x, encoder_padding_mask, pos_emb=positions)

             # interleave CTC
-            if self.use_ctc and layer_idx in self.multi_ctc_layers and layer_idx != len(self.layers):
+            if layer_idx in self.intermedia_ctc_layers:
                 ctc = getattr(self, f"ctc{layer_idx}")
                 adapter = getattr(self, f"adapter{layer_idx}")
                 logit = ctc(x)
-                prob = ctc.softmax(logit)
+                prob = ctc.softmax(x)
                 x, encoder_padding_mask = adapter([x, prob], encoder_padding_mask)
-                ctc_logit.append(ctc(x))
+                intermedia_ctc_logit.append(logit)

             if layer_idx != len(self.layers) \
                     and self.interleaved_dropout is not None \

@@ -257,13 +252,13 @@ class S2TMultiCTCTransformerEncoder(S2TTransformerEncoder):
         if self.layer_norm is not None:
             x = self.layer_norm(x)

-        if self.use_ctc and len(self.layers) in self.multi_ctc_layers:
-            ctc = getattr(self, f"ctc{len(self.layers)}")
-            ctc_logit.append(ctc(x))
+        if self.use_ctc:
+            ctc_logit = self.ctc(x)

         return {
             "encoder_out": [x],  # T x B x C
-            "ctc_logit": ctc_logit,  # B x T x C
+            "ctc_logit": [] if ctc_logit is None else [ctc_logit],  # B x T x C
+            "intermedia_ctc_logit": intermedia_ctc_logit,  # B x T x C
             "encoder_padding_mask": [encoder_padding_mask],  # B x T
             "encoder_embedding": [],  # B x T x C
             "encoder_states": [],  # List[T x B x C]

@@ -272,7 +267,7 @@ class S2TMultiCTCTransformerEncoder(S2TTransformerEncoder):
         }


-@register_model_architecture(model_name="multi_ctc_s2t_transformer", arch_name="multi_ctc_s2t_transformer")
+@register_model_architecture(model_name="inter_ctc_s2t_transformer", arch_name="inter_ctc_s2t_transformer")
 def base_architecture(args):
     # Convolutional subsampler
     args.conv_kernel_sizes = getattr(args, "conv_kernel_sizes", "5,5")

@@ -321,7 +316,7 @@ def base_architecture(args):
     args.quant_noise_pq = getattr(args, "quant_noise_pq", 0)

     # CTC
-    args.multi_ctc_layers = getattr(args, "multi_ctc_layers", 0)
+    args.intermedia_ctc_layers = getattr(args, "intermedia_ctc_layers", 0)

     # Conformer
     args.macaron_style = getattr(args, "macaron_style", False)

@@ -356,13 +351,13 @@ def base_architecture(args):
     args.cl_dropout_strategy = getattr(args, "cl_dropout_strategy", "linear")


-@register_model_architecture("multi_ctc_s2t_transformer", "multi_ctc_s2t_transformer_s")
-def multi_ctc_s2t_transformer_s(args):
+@register_model_architecture("inter_ctc_s2t_transformer", "inter_ctc_s2t_transformer_s")
+def inter_ctc_s2t_transformer_s(args):
     args.encoder_embed_dim = getattr(args, "encoder_embed_dim", 256)
     args.encoder_ffn_embed_dim = getattr(args, "encoder_ffn_embed_dim", 256 * 8)
     args.encoder_attention_heads = getattr(args, "encoder_attention_heads", 4)
     args.decoder_attention_heads = getattr(args, "decoder_attention_heads", 4)
     args.dropout = getattr(args, "dropout", 0.1)
-    args.multi_ctc_layers = getattr(args, "multi_ctc_layers", None)
+    args.intermedia_ctc_layers = getattr(args, "intermedia_ctc_layers", None)
     base_architecture(args)
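
In the renamed encoder, every layer listed in --intermedia-ctc-layers gets its own CTC head whose posterior is fed back through an Adapter before the next layer, while the final ctc_logit now comes from the shared self.ctc module. A condensed, self-contained sketch of that interleaving pattern (ToyInterCTCEncoder, its plain TransformerEncoderLayer blocks, and the linear "adapter" are simplifications, not the repository classes):

import torch
import torch.nn as nn

class ToyInterCTCEncoder(nn.Module):
    """Condensed illustration of the interleaved intermediate-CTC pattern."""

    def __init__(self, num_layers=4, dim=8, vocab=11, intermedia_ctc_layers=(2, 3)):
        super().__init__()
        self.layers = nn.ModuleList(
            nn.TransformerEncoderLayer(dim, nhead=2, batch_first=False)
            for _ in range(num_layers)
        )
        self.intermedia_ctc_layers = set(intermedia_ctc_layers)
        self.ctc_heads = nn.ModuleDict(
            {str(i): nn.Linear(dim, vocab) for i in self.intermedia_ctc_layers}
        )
        # a toy "adapter": mix the states with a projection of the CTC posterior
        self.adapters = nn.ModuleDict(
            {str(i): nn.Linear(vocab, dim) for i in self.intermedia_ctc_layers}
        )
        self.ctc = nn.Linear(dim, vocab)  # final (top-layer) CTC projection

    def forward(self, x):                      # x: (T, B, C)
        intermedia_ctc_logit = []
        for layer_idx, layer in enumerate(self.layers, start=1):
            x = layer(x)
            if layer_idx in self.intermedia_ctc_layers:
                logit = self.ctc_heads[str(layer_idx)](x)
                prob = logit.softmax(dim=-1)
                # feed the posterior back into the stream, as the Adapter does
                x = x + self.adapters[str(layer_idx)](prob)
                intermedia_ctc_logit.append(logit)
        ctc_logit = self.ctc(x)
        return {"ctc_logit": [ctc_logit], "intermedia_ctc_logit": intermedia_ctc_logit}

enc = ToyInterCTCEncoder()
out = enc(torch.randn(6, 2, 8))
print(len(out["intermedia_ctc_logit"]), out["ctc_logit"][0].shape)  # 2 torch.Size([6, 2, 11])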

fairseq/models/speech_to_text/s2t_sate.py

@@ -170,7 +170,7 @@ class Adapter(nn.Module):
         if self.adapter_type in ["embed", "context", "league", "gated_league", "gated_league2"]:
             if embed_tokens is None:
                 num_embeddings = len(dictionary)
-                self.embed_adapter = Embedding(num_embeddings, embed_dim, self.padding_idx)
+                self.embed_adapter = Embedding(num_embeddings, embed_dim, dictionary.pad())
             else:
                 self.embed_adapter = embed_tokens
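
The one-line fix here mirrors the Adapter change above: the embedding's padding index comes from dictionary.pad() rather than self.padding_idx. For reference, a tiny sketch of what that third argument controls in torch.nn.Embedding, which the Embedding helper used here builds on (the pad index value 1 is an assumption; it is fairseq's usual default):

import torch
import torch.nn as nn

pad_idx = 1                                    # fairseq dictionaries usually reserve index 1 for <pad>
emb = nn.Embedding(num_embeddings=10, embedding_dim=4, padding_idx=pad_idx)
tokens = torch.tensor([[5, 1, 1]])             # a sequence padded with pad_idx
print(emb(tokens)[0, 1])                       # the <pad> row is all zeros, and its gradient stays zero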