xuchen / Fairseq-S2T / Commits / 4cffbd98

Commit 4cffbd98 authored Apr 07, 2021 by xuchen
the fix version of sate

parent 8645e75b
Showing 5 changed files with 429 additions and 15 deletions.
fairseq/criterions/label_smoothed_cross_entropy_with_ctc.py   +3    -0
fairseq/models/speech_to_text/__init__.py                     +1    -0
fairseq/models/speech_to_text/s2t_conformer.py                +4    -4
fairseq/models/speech_to_text/s2t_sate.py                     +404  -0
fairseq/models/speech_to_text/s2t_transformer.py              +17   -11
fairseq/criterions/label_smoothed_cross_entropy_with_ctc.py
@@ -100,6 +100,9 @@ class LabelSmoothedCrossEntropyCriterionWithCTC(

    def compute_ctc_loss(self, model, sample, encoder_out):
        transcript = sample["transcript"]
        if "ctc_logit" in encoder_out:
            ctc_logit = encoder_out["ctc_logit"][0]
        else:
            ctc_logit = model.encoder.compute_ctc_logit(encoder_out)
        lprobs = model.get_normalized_probs(
            [ctc_logit], log_probs=True
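The hunk above converts the encoder's CTC logits into log-probabilities before scoring them against the transcript. For orientation, here is a minimal, self-contained sketch of a CTC loss computed from log-probabilities with plain PyTorch; all shapes, lengths and the blank index are illustrative assumptions, not values taken from this repository.

# --- illustrative sketch (not part of the commit) ---
import torch
import torch.nn.functional as F

T, B, V = 50, 4, 100                       # time steps, batch, vocabulary (blank at index 0)
ctc_logit = torch.randn(T, B, V)           # stands in for the encoder's CTC logits (T x B x V)
lprobs = F.log_softmax(ctc_logit, dim=-1)  # log-probabilities, as compute_ctc_loss uses

transcript = torch.randint(1, V, (B, 12))              # padded target transcripts
input_lengths = torch.full((B,), T, dtype=torch.long)
target_lengths = torch.full((B,), 12, dtype=torch.long)

loss = F.ctc_loss(
    lprobs, transcript, input_lengths, target_lengths,
    blank=0, reduction="sum", zero_infinity=True,
)
print(loss.item())
# --- end sketch ---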
fairseq/models/speech_to_text/__init__.py
@@ -7,3 +7,4 @@ from .berard import *  # noqa
 from .convtransformer import *  # noqa
 from .s2t_transformer import *  # noqa
 from .s2t_conformer import *  # noqa
+from .s2t_sate import *  # noqa
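Adding this import to the package __init__ is what makes the new model visible: importing the module runs the @register_model and @register_model_architecture decorators in s2t_sate.py. A hedged sanity check, assuming fairseq's usual registries are exposed under fairseq.models as in the upstream project:

# --- illustrative sketch (not part of the commit) ---
from fairseq.models import ARCH_MODEL_REGISTRY, MODEL_REGISTRY  # assumed upstream registries
import fairseq.models.speech_to_text  # noqa: importing the package triggers the registrations

assert "s2t_sate" in MODEL_REGISTRY
# the size variants registered at the bottom of s2t_sate.py
print([arch for arch in ARCH_MODEL_REGISTRY if arch.startswith("s2t_sate")])
# --- end sketch ---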
fairseq/models/speech_to_text/s2t_conformer.py
@@ -92,12 +92,12 @@ class S2TConformerEncoder(S2TTransformerEncoder):

     def __init__(self, args, task=None, embed_tokens=None):
         super().__init__(args, task, embed_tokens)
-        self.conformer_layers = nn.ModuleList(
+        del self.layers
+        self.layers = nn.ModuleList(
             [ConformerEncoderLayer(args) for _ in range(args.encoder_layers)]
         )
-        del self.transformer_layers

     def forward(self, src_tokens, src_lengths):
         x, input_lengths = self.subsample(src_tokens, src_lengths)
         x = self.embed_scale * x

@@ -109,7 +109,7 @@ class S2TConformerEncoder(S2TTransformerEncoder):

         x = self.dropout_module(x)
         positions = self.dropout_module(positions)

-        for layer in self.conformer_layers:
+        for layer in self.layers:
             x = layer(x, encoder_padding_mask, pos_emb=positions)

         if self.layer_norm is not None:
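This change renames self.conformer_layers to self.layers (deleting the stack inherited from the parent before rebuilding it), so the conformer subclass stores its layers under the same attribute name as S2TTransformerEncoder and code that iterates over self.layers works for either encoder. A generic sketch of that delete-and-rebuild pattern; the module names and layer types are illustrative only.

# --- illustrative sketch (not part of the commit) ---
import torch.nn as nn

class BaseEncoder(nn.Module):
    def __init__(self, num_layers=2, dim=8):
        super().__init__()
        self.layers = nn.ModuleList(nn.Linear(dim, dim) for _ in range(num_layers))

    def forward(self, x):
        for layer in self.layers:   # shared loop only ever looks at self.layers
            x = layer(x)
        return x

class ConformerLikeEncoder(BaseEncoder):
    def __init__(self, num_layers=2, dim=8):
        super().__init__(num_layers, dim)
        del self.layers                  # drop the parent's stack...
        self.layers = nn.ModuleList(     # ...and rebuild it under the same name
            nn.Sequential(nn.Linear(dim, dim), nn.ReLU()) for _ in range(num_layers)
        )
# --- end sketch ---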
fairseq/models/speech_to_text/s2t_sate.py
new file (mode 100644)
#!/usr/bin/env python3

import logging
import math

import torch
import torch.nn as nn

from fairseq import checkpoint_utils
from fairseq.models import (
    FairseqEncoder,
    register_model,
    register_model_architecture,
)
from fairseq.models.transformer import Embedding, TransformerDecoder
from fairseq.models.speech_to_text import (
    S2TTransformerModel,
    S2TTransformerEncoder,
    S2TConformerEncoder,
    S2TConformerModel
)
from fairseq.models.speech_to_text.s2t_transformer import Conv1dSubsampler
from fairseq.modules import (
    FairseqDropout,
    LayerNorm,
    PositionalEmbedding,
    TransformerEncoderLayer,
)

logger = logging.getLogger(__name__)

@register_model("s2t_sate")
class S2TSATEModel(S2TTransformerModel):
    """Adapted Transformer model (https://arxiv.org/abs/1706.03762) for
    speech-to-text tasks. The Transformer encoder/decoder remains the same.
    A trainable input subsampler is prepended to the Transformer encoder to
    project inputs into the encoder dimension as well as downsample input
    sequence for computational efficiency."""

    def __init__(self, encoder, decoder):
        super().__init__(encoder, decoder)

    @staticmethod
    def add_args(parser):
        """Add model-specific arguments to the parser."""
        S2TConformerModel.add_args(parser)
        parser.add_argument(
            "--text-encoder-layers",
            default=6,
            type=int,
            help="layers of the text encoder",
        )
        parser.add_argument(
            "--adapter",
            default="league",
            type=str,
            help="adapter type",
        )
        parser.add_argument(
            "--acoustic-encoder",
            default="transformer",
            type=str,
            help="the architecture of the acoustic encoder",
        )
        parser.add_argument(
            "--load-pretrained-acoustic-encoder-from",
            type=str,
            metavar="STR",
            help="model to take acoustic encoder weights from (for initialization)",
        )
        parser.add_argument(
            "--load-pretrained-text-encoder-from",
            type=str,
            metavar="STR",
            help="model to take text encoder weights from (for initialization)",
        )
        pass
    @classmethod
    def build_encoder(cls, args, task=None, embed_tokens=None):
        encoder = S2TSATEEncoder(args, task, embed_tokens)

        if getattr(args, "load_pretrained_acoustic_encoder_from", None):
            logger.info(
                f"loaded pretrained acoustic encoder from: "
                f"{args.load_pretrained_acoustic_encoder_from}"
            )
            encoder.acoustic_encoder = checkpoint_utils.load_pretrained_component_from_model(
                component=encoder.acoustic_encoder,
                checkpoint=args.load_pretrained_acoustic_encoder_from,
                strict=False
            )

        if getattr(args, "load_pretrained_text_encoder_from", None):
            logger.info(
                f"loaded pretrained text encoder from: "
                f"{args.load_pretrained_text_encoder_from}"
            )
            encoder.text_encoder = checkpoint_utils.load_pretrained_component_from_model(
                component=encoder.text_encoder,
                checkpoint=args.load_pretrained_text_encoder_from,
                strict=False
            )

        return encoder
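build_encoder initializes the acoustic encoder and the text encoder separately from two different checkpoints, ignoring parameters that do not match (strict=False). For readers unfamiliar with this kind of partial initialization, here is a minimal sketch of the same idea in plain PyTorch; the checkpoint path and key prefix are illustrative assumptions, not values from this repository.

# --- illustrative sketch (not part of the commit) ---
import torch
import torch.nn as nn

def load_submodule_weights(submodule: nn.Module, ckpt_path: str, prefix: str) -> nn.Module:
    """Copy parameters whose names start with `prefix` from a full-model
    checkpoint into `submodule`, tolerating missing or extra keys."""
    state = torch.load(ckpt_path, map_location="cpu")
    state = state.get("model", state)  # fairseq checkpoints usually nest weights under "model"
    filtered = {
        name[len(prefix):]: value
        for name, value in state.items()
        if name.startswith(prefix)
    }
    submodule.load_state_dict(filtered, strict=False)  # mirrors strict=False above
    return submodule

# e.g. load_submodule_weights(encoder.acoustic_encoder, "checkpoint_best.pt", "encoder.")
# --- end sketch ---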

class Adapter(nn.Module):
    def __init__(self, args, dictionary, embed_tokens):
        super().__init__()

        attention_dim = args.encoder_embed_dim
        self.embed_scale = math.sqrt(attention_dim)
        if args.no_scale_embedding:
            self.embed_scale = 1.0
        self.padding_idx = dictionary.pad_index

        self.dropout_module = FairseqDropout(
            p=args.dropout, module_name=self.__class__.__name__
        )

        adapter_type = getattr(args, "adapter", "league")
        self.adapter_type = adapter_type

        if adapter_type in ["linear", "league"]:
            self.linear_adapter = nn.Sequential(
                nn.Linear(attention_dim, attention_dim),
                LayerNorm(args.encoder_embed_dim),
                self.dropout_module,
                nn.ReLU(),
            )
        elif adapter_type == "linear2":
            self.linear_adapter = nn.Sequential(
                nn.Linear(attention_dim, attention_dim),
                self.dropout_module,
            )
        elif adapter_type == "subsample":
            self.subsample_adaptor = Conv1dSubsampler(
                attention_dim,
                args.conv_channels,
                attention_dim,
                [int(k) for k in args.conv_kernel_sizes.split(",")],
            )

        if adapter_type in ["embed", "context", "league", "gated_league"]:
            if embed_tokens is None:
                num_embeddings = len(dictionary)
                self.embed_adapter = Embedding(num_embeddings, attention_dim, self.padding_idx)
            else:
                self.embed_adapter = embed_tokens

        if adapter_type == "gated_league":
            self.gate_linear = nn.Linear(2 * attention_dim, attention_dim)
        elif adapter_type == "gated_league2":
            self.gate_linear1 = nn.Linear(attention_dim, attention_dim)
            self.gate_linear2 = nn.Linear(attention_dim, attention_dim)

        attn_type = getattr(args, "text_encoder_attention_type", "selfattn")
        self.embed_positions = PositionalEmbedding(
            args.max_source_positions,
            args.encoder_embed_dim,
            self.padding_idx,
            pos_emb_type=attn_type
        )
    def forward(self, x, padding):
        representation, distribution = x
        batch, seq_len, embed_dim = distribution.size()
        lengths = (~padding).long().sum(-1)

        if self.adapter_type == "linear":
            out = self.linear_adapter(representation)
        elif self.adapter_type == "context":
            out = torch.mm(
                distribution.view(-1, embed_dim), self.embed_adapter.weight
            ).view(batch, seq_len, -1)
        elif self.adapter_type == "subsample":
            out = self.subsample_adaptor(x, lengths)
        elif self.adapter_type == "league":
            linear_out = self.linear_adapter(representation)
            soft_out = torch.mm(
                distribution.view(-1, embed_dim), self.embed_adapter.weight
            ).view(batch, seq_len, -1)
            out = linear_out + soft_out
        elif self.adapter_type == "gated_league":
            linear_out = self.linear_adapter(representation)
            soft_out = self.embed_adapter(distribution)
            coef = (self.gate_linear(torch.cat([linear_out, soft_out], dim=-1))).sigmoid()
            out = coef * linear_out + (1 - coef) * soft_out
        else:
            out = None
            logging.error("Unsupported adapter type: {}.".format(self.adapter_type))

        out = self.embed_scale * out

        positions = self.embed_positions(padding).transpose(0, 1)
        out = positions + out

        out = self.dropout_module(out)

        return out, positions
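The "league" branch above fuses two views of the acoustic encoder output: a linear projection of the hidden states and a "soft embedding" obtained by multiplying the CTC probability distribution with an embedding matrix. A minimal, self-contained sketch of that fusion; all dimensions and module names here are illustrative, not the repository's.

# --- illustrative sketch (not part of the commit) ---
import torch
import torch.nn as nn

seq_len, batch, dim, vocab = 20, 4, 8, 100

hidden = torch.randn(seq_len, batch, dim)            # acoustic encoder states (T x B x C)
ctc_prob = torch.rand(seq_len, batch, vocab)          # CTC distribution over the vocabulary
ctc_prob = ctc_prob / ctc_prob.sum(-1, keepdim=True)  # normalize to a proper distribution

linear_adapter = nn.Sequential(nn.Linear(dim, dim), nn.LayerNorm(dim), nn.ReLU())
embed = nn.Embedding(vocab, dim)                      # plays the role of embed_adapter

linear_out = linear_adapter(hidden)                                            # projected states
soft_out = torch.mm(ctc_prob.view(-1, vocab), embed.weight).view(seq_len, batch, -1)
league_out = linear_out + soft_out                                             # "league" fusion
print(league_out.shape)  # torch.Size([20, 4, 8])
# --- end sketch ---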

class TextEncoder(FairseqEncoder):
    def __init__(self, args, embed_tokens=None):
        super().__init__(None)

        self.embed_tokens = embed_tokens

        self.layers = nn.ModuleList(
            [TransformerEncoderLayer(args) for _ in range(args.text_encoder_layers)]
        )
        if args.encoder_normalize_before:
            self.layer_norm = LayerNorm(args.encoder_embed_dim)
        else:
            self.layer_norm = None

    def forward(self, x, encoder_padding_mask=None, positions=None):
        for layer in self.layers:
            x = layer(x, encoder_padding_mask, pos_emb=positions)

        if self.layer_norm is not None:
            x = self.layer_norm(x)

        return x

class S2TSATEEncoder(FairseqEncoder):
    """Speech-to-text Conformer encoder that consists of input subsampler and
    Transformer encoder."""

    def __init__(self, args, task=None, embed_tokens=None):
        super().__init__(None)

        # acoustic encoder
        acoustic_encoder_type = getattr(args, "acoustic_encoder", "transformer")
        if acoustic_encoder_type == "transformer":
            self.acoustic_encoder = S2TTransformerEncoder(args, task, embed_tokens)
        elif acoustic_encoder_type == "conformer":
            self.acoustic_encoder = S2TConformerEncoder(args, task, embed_tokens)
        else:
            logging.error("Unsupported model arch {}!".format(acoustic_encoder_type))

        # adapter
        self.adapter = Adapter(args, task.source_dictionary, embed_tokens)

        # text encoder
        self.text_encoder = TextEncoder(args, embed_tokens)
    def forward(self, src_tokens, src_lengths):
        acoustic_encoder_out = self.acoustic_encoder(src_tokens, src_lengths)

        encoder_out = acoustic_encoder_out["encoder_out"][0]
        encoder_padding_mask = acoustic_encoder_out["encoder_padding_mask"][0]

        ctc_logit = self.acoustic_encoder.compute_ctc_logit(encoder_out)
        ctc_prob = self.acoustic_encoder.compute_ctc_prob(encoder_out)

        x = (encoder_out, ctc_prob)
        x, positions = self.adapter(x, encoder_padding_mask)

        x = self.text_encoder(x, encoder_padding_mask, positions)

        return {
            "ctc_logit": [ctc_logit],  # T x B x C
            "encoder_out": [x],  # T x B x C
            "encoder_padding_mask": [encoder_padding_mask],  # B x T
            "encoder_embedding": [],  # B x T x C
            "encoder_states": [],  # List[T x B x C]
            "src_tokens": [],
            "src_lengths": [],
        }
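The encoder forward pass is a three-stage pipeline: acoustic encoder, then the adapter (fed with both the hidden states and the CTC distribution), then the text encoder. A toy sketch of the same data flow with placeholder modules; every shape and module body here is an illustrative assumption, not the repository's classes.

# --- illustrative sketch (not part of the commit) ---
import torch
import torch.nn as nn

class ToySATEEncoder(nn.Module):
    """Stand-in for S2TSATEEncoder's control flow, not its architecture."""

    def __init__(self, dim=8, vocab=100):
        super().__init__()
        self.acoustic = nn.GRU(dim, dim)         # stands in for the acoustic encoder
        self.ctc_head = nn.Linear(dim, vocab)    # stands in for the CTC projection
        self.adapter = nn.Linear(dim + vocab, dim)
        self.text = nn.GRU(dim, dim)             # stands in for the text encoder

    def forward(self, feats):                    # feats: T x B x C
        hidden, _ = self.acoustic(feats)
        ctc_logit = self.ctc_head(hidden)
        ctc_prob = ctc_logit.softmax(dim=-1)
        # the adapter sees both the states and the CTC distribution
        adapted = self.adapter(torch.cat([hidden, ctc_prob], dim=-1))
        out, _ = self.text(adapted)
        return {"ctc_logit": [ctc_logit], "encoder_out": [out]}

toy = ToySATEEncoder()
print(toy(torch.randn(20, 4, 8))["encoder_out"][0].shape)  # torch.Size([20, 4, 8])
# --- end sketch ---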
    def reorder_encoder_out(self, encoder_out, new_order):
        new_encoder_out = (
            []
            if len(encoder_out["encoder_out"]) == 0
            else [x.index_select(1, new_order) for x in encoder_out["encoder_out"]]
        )

        new_encoder_padding_mask = (
            []
            if len(encoder_out["encoder_padding_mask"]) == 0
            else [
                x.index_select(0, new_order)
                for x in encoder_out["encoder_padding_mask"]
            ]
        )

        new_encoder_embedding = (
            []
            if len(encoder_out["encoder_embedding"]) == 0
            else [
                x.index_select(0, new_order)
                for x in encoder_out["encoder_embedding"]
            ]
        )

        encoder_states = encoder_out["encoder_states"]
        if len(encoder_states) > 0:
            for idx, state in enumerate(encoder_states):
                encoder_states[idx] = state.index_select(1, new_order)

        return {
            "encoder_out": new_encoder_out,  # T x B x C
            "encoder_padding_mask": new_encoder_padding_mask,  # B x T
            "encoder_embedding": new_encoder_embedding,  # B x T x C
            "encoder_states": encoder_states,  # List[T x B x C]
            "src_tokens": [],  # B x T
            "src_lengths": [],  # B x 1
        }
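reorder_encoder_out is called by fairseq during beam search to reorder cached encoder state along the batch dimension. Because the "encoder_out" tensors are laid out T x B x C while the padding mask is B x T, the batch dimension is index 1 in the former and index 0 in the latter, which is why the two index_select calls use different dims. A small illustration; the tensor sizes are arbitrary.

# --- illustrative sketch (not part of the commit) ---
import torch

T, B, C = 5, 3, 4
encoder_out = torch.randn(T, B, C)         # T x B x C
padding_mask = torch.zeros(B, T).bool()    # B x T

new_order = torch.tensor([2, 2, 0])        # e.g. beam search duplicates sentence 2
reordered_out = encoder_out.index_select(1, new_order)    # batch dim is 1 here
reordered_mask = padding_mask.index_select(0, new_order)  # batch dim is 0 here
print(reordered_out.shape, reordered_mask.shape)  # torch.Size([5, 3, 4]) torch.Size([3, 5])
# --- end sketch ---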

@register_model_architecture(model_name="s2t_sate", arch_name="s2t_sate")
def base_architecture(args):
    # Convolutional subsampler
    args.conv_kernel_sizes = getattr(args, "conv_kernel_sizes", "5,5")
    args.conv_channels = getattr(args, "conv_channels", 1024)
    # Conformer
    args.macaron_style = getattr(args, "macaron_style", True)
    args.use_cnn_module = getattr(args, "use_cnn_module", True)
    args.cnn_module_kernel = getattr(args, "cnn_module_kernel", 31)

    args.encoder_embed_dim = getattr(args, "encoder_embed_dim", 512)
    args.encoder_ffn_embed_dim = getattr(args, "encoder_ffn_embed_dim", 2048)
    args.encoder_layers = getattr(args, "encoder_layers", 12)
    args.encoder_attention_type = getattr(args, "encoder_attention_type", "selfattn")
    args.encoder_attention_heads = getattr(args, "encoder_attention_heads", 8)
    args.encoder_normalize_before = getattr(args, "encoder_normalize_before", True)
    args.acoustic_encoder = getattr(args, "acoustic_encoder", "transformer")
    args.adapter = getattr(args, "adapter", "league")
    args.text_encoder_layers = getattr(args, "text_encoder_layers", 6)
    args.decoder_embed_dim = getattr(args, "decoder_embed_dim", args.encoder_embed_dim)
    args.decoder_ffn_embed_dim = getattr(args, "decoder_ffn_embed_dim", args.encoder_ffn_embed_dim)
    args.decoder_layers = getattr(args, "decoder_layers", 6)
    args.decoder_attention_type = getattr(args, "decoder_attention_type", "selfattn")
    args.decoder_attention_heads = getattr(args, "decoder_attention_heads", 8)
    args.decoder_normalize_before = getattr(args, "decoder_normalize_before", True)
    args.decoder_learned_pos = getattr(args, "decoder_learned_pos", False)
    args.dropout = getattr(args, "dropout", 0.1)
    args.attention_dropout = getattr(args, "attention_dropout", args.dropout)
    args.activation_dropout = getattr(args, "activation_dropout", args.dropout)
    args.activation_fn = getattr(args, "activation_fn", "relu")
    args.adaptive_softmax_cutoff = getattr(args, "adaptive_softmax_cutoff", None)
    args.adaptive_softmax_dropout = getattr(args, "adaptive_softmax_dropout", 0)
    args.share_decoder_input_output_embed = getattr(args, "share_decoder_input_output_embed", False)
    args.no_token_positional_embeddings = getattr(args, "no_token_positional_embeddings", False)
    args.adaptive_input = getattr(args, "adaptive_input", False)
    args.decoder_layerdrop = getattr(args, "decoder_layerdrop", 0.0)
    args.decoder_output_dim = getattr(args, "decoder_output_dim", args.decoder_embed_dim)
    args.decoder_input_dim = getattr(args, "decoder_input_dim", args.decoder_embed_dim)
    args.no_scale_embedding = getattr(args, "no_scale_embedding", False)
    args.quant_noise_pq = getattr(args, "quant_noise_pq", 0)
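base_architecture and the size variants below all rely on the same idiom: getattr(args, name, default) keeps any value the user already set on the command line and only fills in a default otherwise, so a variant can override a few fields and then delegate to base_architecture for the rest. A minimal sketch of that layering; the namespace fields are illustrative only.

# --- illustrative sketch (not part of the commit) ---
from argparse import Namespace

def base_arch(args):
    args.encoder_embed_dim = getattr(args, "encoder_embed_dim", 512)
    args.dropout = getattr(args, "dropout", 0.1)

def small_arch(args):
    # override first, then fall through to the base defaults
    args.encoder_embed_dim = getattr(args, "encoder_embed_dim", 256)
    base_arch(args)

args = Namespace(dropout=0.3)   # e.g. set explicitly by the user
small_arch(args)
print(args.encoder_embed_dim, args.dropout)  # 256 0.3
# --- end sketch ---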

@register_model_architecture("s2t_sate", "s2t_sate_s")
def s2t_sate_s(args):
    args.encoder_embed_dim = getattr(args, "encoder_embed_dim", 256)
    args.encoder_ffn_embed_dim = getattr(args, "encoder_ffn_embed_dim", 256 * 8)
    args.encoder_attention_heads = getattr(args, "encoder_attention_heads", 4)
    args.decoder_attention_heads = getattr(args, "decoder_attention_heads", 4)
    args.dropout = getattr(args, "dropout", 0.1)
    base_architecture(args)


@register_model_architecture("s2t_sate", "s2t_sate_xs")
def s2t_sate_xs(args):
    args.encoder_layers = getattr(args, "encoder_layers", 6)
    args.decoder_layers = getattr(args, "decoder_layers", 3)
    args.encoder_ffn_embed_dim = getattr(args, "encoder_ffn_embed_dim", 256 * 4)
    args.dropout = getattr(args, "dropout", 0.3)
    s2t_sate_s(args)


@register_model_architecture("s2t_sate", "s2t_sate_sp")
def s2t_sate_sp(args):
    args.encoder_layers = getattr(args, "encoder_layers", 16)
    s2t_sate_s(args)


@register_model_architecture("s2t_sate", "s2t_sate_m")
def s2t_sate_m(args):
    args.encoder_embed_dim = getattr(args, "encoder_embed_dim", 512)
    args.encoder_ffn_embed_dim = getattr(args, "encoder_ffn_embed_dim", 512 * 4)
    args.encoder_attention_heads = getattr(args, "encoder_attention_heads", 8)
    args.decoder_attention_heads = getattr(args, "decoder_attention_heads", 8)
    args.dropout = getattr(args, "dropout", 0.15)
    base_architecture(args)


@register_model_architecture("s2t_sate", "s2t_sate_mp")
def s2t_sate_mp(args):
    args.encoder_layers = getattr(args, "encoder_layers", 16)
    s2t_sate_m(args)


@register_model_architecture("s2t_sate", "s2t_sate_l")
def s2t_sate_l(args):
    args.encoder_embed_dim = getattr(args, "encoder_embed_dim", 1024)
    args.encoder_ffn_embed_dim = getattr(args, "encoder_ffn_embed_dim", 1024 * 4)
    args.encoder_attention_heads = getattr(args, "encoder_attention_heads", 16)
    args.decoder_attention_heads = getattr(args, "decoder_attention_heads", 16)
    args.dropout = getattr(args, "dropout", 0.2)
    base_architecture(args)


@register_model_architecture("s2t_sate", "s2t_sate_lp")
def s2t_sate_lp(args):
    args.encoder_layers = getattr(args, "encoder_layers", 16)
    s2t_sate_l(args)
fairseq/models/speech_to_text/s2t_transformer.py
@@ -247,26 +247,28 @@ class S2TTransformerModel(FairseqEncoderDecoderModel):

     def build_encoder(cls, args, task=None, embed_tokens=None):
         encoder = S2TTransformerEncoder(args, task, embed_tokens)
         if getattr(args, "load_pretrained_encoder_from", None):
-            encoder = checkpoint_utils.load_pretrained_component_from_model(
-                component=encoder, checkpoint=args.load_pretrained_encoder_from, strict=False
-            )
             logger.info(
                 f"loaded pretrained encoder from: "
                 f"{args.load_pretrained_encoder_from}"
             )
+            encoder = checkpoint_utils.load_pretrained_component_from_model(
+                component=encoder, checkpoint=args.load_pretrained_encoder_from, strict=False
+            )
         return encoder

     @classmethod
     def build_decoder(cls, args, task, embed_tokens):
         decoder = TransformerDecoderScriptable(args, task.target_dictionary, embed_tokens)
         if getattr(args, "load_pretrained_decoder_from", None):
-            decoder = checkpoint_utils.load_pretrained_component_from_model(
-                component=decoder, checkpoint=args.load_pretrained_encoder_from, strict=False
-            )
             logger.info(
                 f"loaded pretrained decoder from: "
                 f"{args.load_pretrained_decoder_from}"
             )
+            decoder = checkpoint_utils.load_pretrained_component_from_model(
+                component=decoder, checkpoint=args.load_pretrained_decoder_from, strict=False
+            )
         return decoder

     @classmethod
@@ -346,7 +348,7 @@ class S2TTransformerEncoder(FairseqEncoder):

             args.max_source_positions,
             args.encoder_embed_dim,
             self.padding_idx,
             pos_emb_type=self.attn_type
         )

-        self.transformer_layers = nn.ModuleList(
+        self.layers = nn.ModuleList(
             [TransformerEncoderLayer(args) for _ in range(args.encoder_layers)]
         )

         if args.encoder_normalize_before:
@@ -372,6 +374,7 @@ class S2TTransformerEncoder(FairseqEncoder):

         self.ctc_dropout_module = FairseqDropout(
             p=args.dropout, module_name=self.__class__.__name__
         )
+        self.softmax = nn.Softmax(dim=-1)

     def forward(self, src_tokens, src_lengths):
         x, input_lengths = self.subsample(src_tokens, src_lengths)
@@ -384,7 +387,7 @@ class S2TTransformerEncoder(FairseqEncoder):

         x = self.dropout_module(x)
         positions = self.dropout_module(positions)

-        for layer in self.transformer_layers:
+        for layer in self.layers:
             x = layer(x, encoder_padding_mask, pos_emb=positions)

         if self.layer_norm is not None:
@@ -404,17 +407,20 @@ class S2TTransformerEncoder(FairseqEncoder):

     def compute_ctc_logit(self, encoder_out):
         assert self.use_ctc, "CTC is not available!"

         if isinstance(encoder_out, dict) and "encoder_out" in encoder_out:
             encoder_state = encoder_out["encoder_out"][0]
         else:
             encoder_state = encoder_out
         ctc_logit = self.ctc_projection(self.ctc_dropout_module(encoder_state))
         return ctc_logit

-    def compute_ctc_prob(self, encoder_out):
+    def compute_ctc_prob(self, encoder_out, temperature=1.0):
         assert self.use_ctc, "CTC is not available!"

-        ctc_logit = self.compute_ctc_logit(encoder_out)
+        ctc_logit = self.compute_ctc_logit(encoder_out) / temperature

-        return ctc_logit.Softmax(dim=-1)
+        return self.softmax(ctc_logit)

     def reorder_encoder_out(self, encoder_out, new_order):
         new_encoder_out = (
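This last hunk is the actual "fix" of the commit message: the old code called ctc_logit.Softmax(dim=-1), which is not a Tensor method and would fail at runtime, while the new version divides the logits by a temperature and applies the nn.Softmax module created in __init__. A small sketch of temperature-scaled softmax over CTC logits; the shapes are arbitrary.

# --- illustrative sketch (not part of the commit) ---
import torch
import torch.nn as nn

softmax = nn.Softmax(dim=-1)
ctc_logit = torch.randn(20, 4, 100)   # T x B x V logits from a CTC projection

sharp = softmax(ctc_logit / 0.5)      # temperature < 1 sharpens the distribution
default = softmax(ctc_logit / 1.0)    # temperature = 1 is plain softmax
smooth = softmax(ctc_logit / 2.0)     # temperature > 1 flattens it
print(default.sum(-1)[0, 0])          # each position sums to 1
# --- end sketch ---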