Fairseq-S2T · Commit f190005c
Authored Apr 21, 2021 by xuchen

support the rpe for encoder and decoder respectively
Parent: 28d33ad8
Showing 7 changed files with 110 additions and 36 deletions (+110 −36).
fairseq/models/dlcl_transformer.py                 +5   -2
fairseq/models/speech_to_text/s2t_conformer.py     +4   -2
fairseq/models/speech_to_text/s2t_sate.py          +31  -2
fairseq/models/speech_to_text/s2t_transformer.py   +7   -3
fairseq/models/transformer.py                      +8   -4
fairseq/modules/conformer_layer.py                 +22  -1
fairseq/modules/transformer_layer.py               +33  -22
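In short, this commit replaces the single --max-relative-length option with --max-encoder-relative-length and --max-decoder-relative-length, so relative position encoding (RPE) can be configured separately for the encoder and the decoder. The base architectures default both options to -1 (RPE disabled), the registered *_relative architectures turn them on, and the encoder/decoder layer builders each read their own option. A minimal standalone sketch (a bare argparse parser, not fairseq's option registry) of how the options added in s2t_transformer.py and transformer.py below parse:

# Standalone sketch mirroring the parser.add_argument calls added below; the real
# options are registered on fairseq's model argument parser, not a bare ArgumentParser.
import argparse

parser = argparse.ArgumentParser()
parser.add_argument('--max-encoder-relative-length', type=int, default=-1,
                    help='the max relative length')
parser.add_argument('--max-decoder-relative-length', type=int, default=-1,
                    help='the max relative length')
parser.add_argument('--k-only', default=False, action='store_true',
                    help='select the relative mode to map relative position information')

args = parser.parse_args(['--max-encoder-relative-length', '100',
                          '--max-decoder-relative-length', '20',
                          '--k-only'])
print(args.max_encoder_relative_length, args.max_decoder_relative_length, args.k_only)
# -> 100 20 True

Passing 100/20, as the speech architectures below do, gives the encoder a wider relative window than the decoder.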
fairseq/models/dlcl_transformer.py

@@ -532,13 +532,16 @@ def base_architecture(args):
     args.encoder_integration_type = getattr(args, 'encoder_integration_type', 'avg')
     args.decoder_integration_type = getattr(args, 'decoder_integration_type', 'avg')
-    args.max_relative_length = getattr(args, 'max_relative_length', -1)
+    args.max_encoder_relative_length = getattr(args, 'max_encoder_relative_length', -1)
+    args.max_decoder_relative_length = getattr(args, 'max_decoder_relative_length', -1)
     args.k_only = getattr(args, 'k_only', True)


 @register_model_architecture("dlcl_transformer", "dlcl_transformer_relative")
 def dlcl_transformer_relative(args):
-    args.max_relative_length = 20
+    args.max_encoder_relative_length = 20
+    args.max_decoder_relative_length = 20
     args.k_only = True
     base_architecture(args)
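Every base_architecture() touched by this commit applies the same getattr-based defaulting shown in the hunk above. A toy sketch (a bare Namespace standing in for fairseq's parsed args) of why configurations that never mention the new options still load, with -1 meaning RPE stays disabled:

# Toy sketch of the defaulting idiom used by the base_architecture() functions in
# this commit; the Namespace stands in for fairseq's parsed args.
from argparse import Namespace

def apply_rpe_defaults(args):
    # missing attributes fall back to -1 (RPE disabled) and k_only=True, as in the diff
    args.max_encoder_relative_length = getattr(args, 'max_encoder_relative_length', -1)
    args.max_decoder_relative_length = getattr(args, 'max_decoder_relative_length', -1)
    args.k_only = getattr(args, 'k_only', True)
    return args

old_config = Namespace()                                  # config without any RPE options
print(apply_rpe_defaults(old_config))                     # -1 / -1 / True: RPE stays off
enc_only = Namespace(max_encoder_relative_length=100)     # per-side override
print(apply_rpe_defaults(enc_only).max_decoder_relative_length)  # -> -1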
fairseq/models/speech_to_text/s2t_conformer.py

@@ -185,7 +185,8 @@ def base_architecture(args):
     args.no_scale_embedding = getattr(args, "no_scale_embedding", False)
     args.quant_noise_pq = getattr(args, "quant_noise_pq", 0)
-    args.max_relative_length = getattr(args, 'max_relative_length', -1)
+    args.max_encoder_relative_length = getattr(args, 'max_encoder_relative_length', -1)
+    args.max_decoder_relative_length = getattr(args, 'max_decoder_relative_length', -1)
     args.k_only = getattr(args, 'k_only', True)

@@ -201,7 +202,8 @@ def s2t_conformer_s(args):
 @register_model_architecture("s2t_conformer", "s2t_conformer_s_relative")
 def s2t_conformer_s_relative(args):
-    args.max_relative_length = 20
+    args.max_encoder_relative_length = 100
+    args.max_decoder_relative_length = 20
     args.k_only = True
     s2t_conformer_s(args)
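Note the asymmetric values in s2t_conformer_s_relative: the acoustic encoder gets a relative window of 100 while the decoder keeps 20 (s2t_sate_s_relative below does the same). The sketch that follows is a generic illustration of what a maximum relative length bounds in Shaw-style relative position attention; it is assumed behaviour of RelativeMultiheadAttention, not code taken from this repository. Offsets beyond the window share a bucket, so a wider encoder window preserves more distinct distances for long acoustic sequences.

# Generic illustration (not from this repository) of clipping relative offsets to a
# maximum relative length, as in Shaw et al.-style relative position representations.
import torch

def relative_position_buckets(seq_len, max_relative_length):
    pos = torch.arange(seq_len)
    offsets = pos[None, :] - pos[:, None]                       # j - i for every pair
    clipped = offsets.clamp(-max_relative_length, max_relative_length)
    return clipped + max_relative_length                        # shift into [0, 2 * max_len]

print(relative_position_buckets(6, max_relative_length=2))
# pairs more than 2 positions apart collapse into the outermost buckets (0 or 4)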
fairseq/models/speech_to_text/s2t_sate.py

@@ -6,6 +6,7 @@ import math
 import torch
 import torch.nn as nn
 from fairseq import checkpoint_utils
+from fairseq.data.data_utils import lengths_to_padding_mask
 from fairseq.models import (
     FairseqEncoder,
     register_model,

@@ -82,6 +83,15 @@ class S2TSATEModel(S2TTransformerModel):
     def build_encoder(cls, args, task=None, embed_tokens=None):
         encoder = S2TSATEEncoder(args, task, embed_tokens)

+        if getattr(args, "load_pretrained_encoder_from", None):
+            logger.info(
+                f"loaded pretrained acoustic encoder from: "
+                f"{args.load_pretrained_encoder_from}")
+            encoder = checkpoint_utils.load_pretrained_component_from_model(
+                component=encoder,
+                checkpoint=args.load_pretrained_encoder_from,
+                strict=False)
+
         if getattr(args, "load_pretrained_acoustic_encoder_from", None):
             logger.info(
                 f"loaded pretrained acoustic encoder from: "

@@ -202,6 +212,7 @@ class TextEncoder(FairseqEncoder):
         super().__init__(None)

         self.embed_tokens = embed_tokens
         self.layers = nn.ModuleList(
             [TransformerEncoderLayer(args) for _ in range(args.text_encoder_layers)]
         )

@@ -247,8 +258,19 @@ class S2TSATEEncoder(FairseqEncoder):
         # adapter
         self.adapter = Adapter(args, task.source_dictionary, embed_tokens)

+        # self.length_adapter = Conv1dSubsampler(
+        #     args.encoder_embed_dim,
+        #     args.conv_channels,
+        #     args.encoder_embed_dim,
+        #     [int(k) for k in args.conv_kernel_sizes.split(",")],
+        # )
+
+        # acoustic_encoder_attention_type = args.encoder_attention_type
+        # args.encoder_attention_type = "selfattn"
+
         # text encoder
         self.text_encoder = TextEncoder(args, embed_tokens)
+        # args.encoder_attention_type = acoustic_encoder_attention_type

         if getattr(args, "use_enc_dlcl", False):
             normalize_before = args.encoder_normalize_before

@@ -283,6 +305,11 @@ class S2TSATEEncoder(FairseqEncoder):
             self.history.add(x)

+        # src_lengths = (~encoder_padding_mask).sum(1)
+        # x = x.transpose(0, 1)
+        # x, input_lengths = self.length_adapter(x, src_lengths)
+        # encoder_padding_mask = lengths_to_padding_mask(input_lengths)
+
         x = self.text_encoder(x, encoder_padding_mask, positions, self.history)

         return {

@@ -375,7 +402,8 @@ def base_architecture(args):
     args.no_scale_embedding = getattr(args, "no_scale_embedding", False)
     args.quant_noise_pq = getattr(args, "quant_noise_pq", 0)
-    args.max_relative_length = getattr(args, 'max_relative_length', -1)
+    args.max_encoder_relative_length = getattr(args, 'max_encoder_relative_length', -1)
+    args.max_decoder_relative_length = getattr(args, 'max_decoder_relative_length', -1)
     args.k_only = getattr(args, 'k_only', True)

@@ -391,7 +419,8 @@ def s2t_sate_s(args):
 @register_model_architecture("s2t_sate", "s2t_sate_s_relative")
 def s2t_sate_s_relative(args):
-    args.max_relative_length = 20
+    args.max_encoder_relative_length = 100
+    args.max_decoder_relative_length = 20
     args.k_only = True
     s2t_sate_s(args)
fairseq/models/speech_to_text/s2t_transformer.py

@@ -220,7 +220,9 @@ class S2TTransformerModel(FairseqEncoderDecoderModel):
             action="store_true",
             help="if True, dont scale embeddings",
         )
-        parser.add_argument('--max-relative-length', type=int, default=-1,
+        parser.add_argument('--max-encoder-relative-length', type=int, default=-1,
                             help='the max relative length')
+        parser.add_argument('--max-decoder-relative-length', type=int, default=-1,
+                            help='the max relative length')
         parser.add_argument('--k-only', default=False, action='store_true',
                             help='select the relative mode to map relative position information')

@@ -567,7 +569,8 @@ def base_architecture(args):
     args.no_scale_embedding = getattr(args, "no_scale_embedding", False)
     args.quant_noise_pq = getattr(args, "quant_noise_pq", 0)
-    args.max_relative_length = getattr(args, 'max_relative_length', -1)
+    args.max_encoder_relative_length = getattr(args, 'max_encoder_relative_length', -1)
+    args.max_decoder_relative_length = getattr(args, 'max_decoder_relative_length', -1)
     args.k_only = getattr(args, 'k_only', True)

@@ -583,7 +586,8 @@ def s2t_transformer_s(args):
 @register_model_architecture("s2t_transformer", "s2t_transformer_s_relative")
 def s2t_transformer_s_relative(args):
-    args.max_relative_length = 20
+    args.max_encoder_relative_length = 20
+    args.max_decoder_relative_length = 20
     args.k_only = True
     s2t_transformer_s(args)
fairseq/models/transformer.py

@@ -218,8 +218,10 @@ class TransformerModel(FairseqEncoderDecoderModel):
             ],
             help="transformer decoder self-attention layer type"
         )
-        parser.add_argument('--max-relative-length', type=int, default=-1,
-                            help='the max relative length')
+        parser.add_argument('--max-encoder-relative-length', type=int, default=-1,
+                            help='the max encoder relative length')
+        parser.add_argument('--max-decoder-relative-length', type=int, default=-1,
+                            help='the max decoder relative length')
         parser.add_argument('--k-only', default=False, action='store_true',
                             help='select the relative mode to map relative position information')
         # args for loading pre-trained models

@@ -1182,13 +1184,15 @@ def base_architecture(args):
     args.encoder_attention_type = getattr(args, "encoder_attention_type", "selfattn")
     args.decoder_attention_type = getattr(args, "decoder_attention_type", "selfattn")
-    args.max_relative_length = getattr(args, 'max_relative_length', -1)
+    args.max_encoder_relative_length = getattr(args, 'max_encoder_relative_length', -1)
+    args.max_decoder_relative_length = getattr(args, 'max_decoder_relative_length', -1)
     args.k_only = getattr(args, 'k_only', True)


 @register_model_architecture("transformer", "transformer_relative")
 def transformer_rpr(args):
-    args.max_relative_length = 20
+    args.max_encoder_relative_length = 20
+    args.max_decoder_relative_length = 20
     args.k_only = True
     base_architecture(args)
fairseq/modules/conformer_layer.py

@@ -8,7 +8,13 @@ from typing import Optional
 import torch
 import torch.nn as nn
 from fairseq import utils
-from fairseq.modules import LayerNorm, MultiheadAttention, RelPositionMultiheadAttention, ConvolutionModule
+from fairseq.modules import (
+    LayerNorm,
+    MultiheadAttention,
+    RelPositionMultiheadAttention,
+    RelativeMultiheadAttention,
+    ConvolutionModule
+)
 from fairseq.modules.fairseq_dropout import FairseqDropout
 from fairseq.modules.quant_noise import quant_noise
 from torch import Tensor

@@ -112,6 +118,21 @@ class ConformerEncoderLayer(nn.Module):
             attn_func = MultiheadAttention
         elif self.attn_type == "rel_selfattn":
             attn_func = RelPositionMultiheadAttention
+        elif self.attn_type == "relative":
+            max_relative_length = getattr(args, "max_encoder_relative_length", -1)
+            if max_relative_length != -1:
+                return RelativeMultiheadAttention(
+                    embed_dim,
+                    args.encoder_attention_heads,
+                    dropout=args.attention_dropout,
+                    self_attention=True,
+                    q_noise=self.quant_noise,
+                    qn_block_size=self.quant_noise_block_size,
+                    max_relative_length=max_relative_length,
+                )
+            else:
+                print("The maximum encoder relative length %d can not be -1!" % max_relative_length)
+                exit(1)
         else:
             attn_func = MultiheadAttention
             print("The attention type %s is not supported!" % self.attn_type)
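Unlike the transformer layers in the next file, the conformer layer reads only max_encoder_relative_length (there is no fallback to the legacy max_relative_length) and bails out when it is still -1 while the relative attention type is requested. An isolated sketch of that guard (the helper name is mine, and raising stands in for the print/exit(1) in the diff):

# Isolated sketch of the guard added to ConformerEncoderLayer above; raising here
# stands in for the print / exit(1) in the actual code.
from argparse import Namespace

def encoder_relative_length_or_fail(args):
    max_relative_length = getattr(args, "max_encoder_relative_length", -1)
    if max_relative_length == -1:
        raise ValueError("The maximum encoder relative length can not be -1!")
    return max_relative_length

print(encoder_relative_length_or_fail(Namespace(max_encoder_relative_length=100)))  # -> 100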
fairseq/modules/transformer_layer.py

@@ -87,18 +87,24 @@ class TransformerEncoderLayer(nn.Module):
             attn_func = MultiheadAttention
         elif self.attn_type == "rel_selfattn":
             attn_func = RelPositionMultiheadAttention
-        elif self.attn_type == "relative" or getattr(args, "max_relative_length", -1) != -1:
-            return RelativeMultiheadAttention(
-                embed_dim,
-                args.encoder_attention_heads,
-                dropout=args.attention_dropout,
-                self_attention=True,
-                q_noise=self.quant_noise,
-                qn_block_size=self.quant_noise_block_size,
-                max_relative_length=args.max_relative_length,
-            )
+        elif self.attn_type == "relative":
+            # max_relative_length = getattr(args, "max_encoder_relative_length", -1)
+            max_relative_length = max(getattr(args, "max_encoder_relative_length", -1), getattr(args, "max_relative_length", -1))
+            if max_relative_length != -1:
+                return RelativeMultiheadAttention(
+                    embed_dim,
+                    args.encoder_attention_heads,
+                    dropout=args.attention_dropout,
+                    self_attention=True,
+                    q_noise=self.quant_noise,
+                    qn_block_size=self.quant_noise_block_size,
+                    max_relative_length=max_relative_length,
+                )
+            else:
+                print("The maximum encoder relative length %d can not be -1!" % max_relative_length)
+                exit(1)
         else:
-            print("The attention type %s is not supported!" % self.attn_type)
+            print("The encoder attention type %s is not supported!" % self.attn_type)
             exit(1)

         return attn_func(

@@ -292,18 +298,23 @@ class TransformerDecoderLayer(nn.Module):
             attn_func = MultiheadAttention
         elif self.attn_type == "rel_selfattn":
             attn_func = RelPositionMultiheadAttention
-        elif self.attn_type == "relative" or getattr(args, "max_relative_length", -1) != -1:
-            return RelativeMultiheadAttention(
-                embed_dim,
-                args.encoder_attention_heads,
-                dropout=args.attention_dropout,
-                self_attention=True,
-                q_noise=self.quant_noise,
-                qn_block_size=self.quant_noise_block_size,
-                max_relative_length=args.max_relative_length,
-            )
+        elif self.attn_type == "relative":
+            max_relative_length = max(getattr(args, "max_decoder_relative_length", -1), getattr(args, "max_relative_length", -1))
+            if max_relative_length != -1:
+                return RelativeMultiheadAttention(
+                    embed_dim,
+                    args.decoder_attention_heads,
+                    dropout=args.attention_dropout,
+                    self_attention=True,
+                    q_noise=self.quant_noise,
+                    qn_block_size=self.quant_noise_block_size,
+                    max_relative_length=max_relative_length,
+                )
+            else:
+                print("The maximum decoder relative length %d can not be -1!" % max_relative_length)
+                exit(1)
         else:
-            print("The attention type %s is not supported!" % self.attn_type)
+            print("The decoder attention type %s is not supported!" % self.attn_type)
             exit(1)

         return attn_func(
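The transformer encoder and decoder layers, by contrast, keep the legacy option usable: each side takes the max() of its own option and --max-relative-length, so older configurations that only set the legacy flag still get relative attention, and -1 from both means it stays unavailable. A small sketch of that resolution step (the helper name is mine, not in the commit):

# Sketch of the fallback applied in TransformerEncoderLayer / TransformerDecoderLayer
# above; -1 from both options means relative attention stays unavailable.
from argparse import Namespace

def resolve_max_relative_length(args, side):
    assert side in ("encoder", "decoder")
    per_side = getattr(args, "max_%s_relative_length" % side, -1)
    legacy = getattr(args, "max_relative_length", -1)
    return max(per_side, legacy)

print(resolve_max_relative_length(Namespace(max_relative_length=16), "encoder"))  # -> 16
print(resolve_max_relative_length(Namespace(), "decoder"))                        # -> -1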