xuchen / Fairseq-S2T

Commit f190005c, authored Apr 21, 2021 by xuchen
support the rpe for encoder and decoder respectively
Parent: 28d33ad8
Showing 7 changed files with 95 additions and 21 deletions.
fairseq/models/dlcl_transformer.py                 +5  -2
fairseq/models/speech_to_text/s2t_conformer.py     +4  -2
fairseq/models/speech_to_text/s2t_sate.py          +31 -2
fairseq/models/speech_to_text/s2t_transformer.py   +7  -3
fairseq/models/transformer.py                      +8  -4
fairseq/modules/conformer_layer.py                 +22 -1
fairseq/modules/transformer_layer.py               +18 -7
fairseq/models/dlcl_transformer.py

@@ -532,13 +532,16 @@ def base_architecture(args):
    args.encoder_integration_type = getattr(args, 'encoder_integration_type', 'avg')
    args.decoder_integration_type = getattr(args, 'decoder_integration_type', 'avg')
    args.max_relative_length = getattr(args, 'max_relative_length', -1)
    args.max_encoder_relative_length = getattr(args, 'max_encoder_relative_length', -1)
    args.max_decoder_relative_length = getattr(args, 'max_decoder_relative_length', -1)
    args.k_only = getattr(args, 'k_only', True)


@register_model_architecture("dlcl_transformer", "dlcl_transformer_relative")
def dlcl_transformer_relative(args):
    args.max_relative_length = 20
    args.max_encoder_relative_length = 20
    args.max_decoder_relative_length = 20
    args.k_only = True
    base_architecture(args)
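All of the base_architecture hunks in this commit follow the same fairseq idiom: getattr(args, name, default) fills in a value only when the corresponding flag was not supplied, and the registered *_relative architectures pin concrete windows before deferring to the base function. The following is a minimal sketch of that idiom outside fairseq; SimpleNamespace stands in for the parsed args, and the attribute names are taken from the diff above.

    from types import SimpleNamespace

    def base_architecture(args):
        # Fill in a default only when the attribute is missing (i.e. the flag was not given).
        args.max_relative_length = getattr(args, 'max_relative_length', -1)
        args.max_encoder_relative_length = getattr(args, 'max_encoder_relative_length', -1)
        args.max_decoder_relative_length = getattr(args, 'max_decoder_relative_length', -1)
        args.k_only = getattr(args, 'k_only', True)

    def dlcl_transformer_relative(args):
        # The registered "*_relative" variant pins concrete windows before calling the base.
        args.max_encoder_relative_length = 20
        args.max_decoder_relative_length = 20
        args.k_only = True
        base_architecture(args)

    args = SimpleNamespace(max_encoder_relative_length=40)
    base_architecture(args)
    print(args.max_encoder_relative_length, args.max_decoder_relative_length)  # 40 -1

    args = SimpleNamespace()
    dlcl_transformer_relative(args)
    print(args.max_encoder_relative_length, args.max_decoder_relative_length)  # 20 20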
fairseq/models/speech_to_text/s2t_conformer.py

@@ -185,7 +185,8 @@ def base_architecture(args):
    args.no_scale_embedding = getattr(args, "no_scale_embedding", False)
    args.quant_noise_pq = getattr(args, "quant_noise_pq", 0)
    args.max_relative_length = getattr(args, 'max_relative_length', -1)
    args.max_encoder_relative_length = getattr(args, 'max_encoder_relative_length', -1)
    args.max_decoder_relative_length = getattr(args, 'max_decoder_relative_length', -1)
    args.k_only = getattr(args, 'k_only', True)

@@ -201,7 +202,8 @@ def s2t_conformer_s(args):
@register_model_architecture("s2t_conformer", "s2t_conformer_s_relative")
def s2t_conformer_s_relative(args):
    args.max_relative_length = 20
    args.max_encoder_relative_length = 100
    args.max_decoder_relative_length = 20
    args.k_only = True
    s2t_conformer_s(args)
fairseq/models/speech_to_text/s2t_sate.py

@@ -6,6 +6,7 @@ import math
import torch
import torch.nn as nn

from fairseq import checkpoint_utils
from fairseq.data.data_utils import lengths_to_padding_mask
from fairseq.models import (
    FairseqEncoder,
    register_model,

@@ -82,6 +83,15 @@ class S2TSATEModel(S2TTransformerModel):
    def build_encoder(cls, args, task=None, embed_tokens=None):
        encoder = S2TSATEEncoder(args, task, embed_tokens)

        if getattr(args, "load_pretrained_encoder_from", None):
            logger.info(
                f"loaded pretrained acoustic encoder from: "
                f"{args.load_pretrained_encoder_from}")
            encoder = checkpoint_utils.load_pretrained_component_from_model(
                component=encoder, checkpoint=args.load_pretrained_encoder_from, strict=False
            )

        if getattr(args, "load_pretrained_acoustic_encoder_from", None):
            logger.info(
                f"loaded pretrained acoustic encoder from: "

@@ -202,6 +212,7 @@ class TextEncoder(FairseqEncoder):
        super().__init__(None)

        self.embed_tokens = embed_tokens

        self.layers = nn.ModuleList([TransformerEncoderLayer(args) for _ in range(args.text_encoder_layers)])

@@ -247,8 +258,19 @@ class S2TSATEEncoder(FairseqEncoder):
        # adapter
        self.adapter = Adapter(args, task.source_dictionary, embed_tokens)

        # self.length_adapter = Conv1dSubsampler(
        #     args.encoder_embed_dim,
        #     args.conv_channels,
        #     args.encoder_embed_dim,
        #     [int(k) for k in args.conv_kernel_sizes.split(",")],
        # )

        # acoustic_encoder_attention_type = args.encoder_attention_type
        # args.encoder_attention_type = "selfattn"

        # text encoder
        self.text_encoder = TextEncoder(args, embed_tokens)

        # args.encoder_attention_type = acoustic_encoder_attention_type

        if getattr(args, "use_enc_dlcl", False):
            normalize_before = args.encoder_normalize_before

@@ -283,6 +305,11 @@ class S2TSATEEncoder(FairseqEncoder):
            self.history.add(x)

        # src_lengths = (~encoder_padding_mask).sum(1)
        # x = x.transpose(0, 1)
        # x, input_lengths = self.length_adapter(x, src_lengths)
        # encoder_padding_mask = lengths_to_padding_mask(input_lengths)

        x = self.text_encoder(x, encoder_padding_mask, positions, self.history)

        return {

@@ -375,7 +402,8 @@ def base_architecture(args):
    args.no_scale_embedding = getattr(args, "no_scale_embedding", False)
    args.quant_noise_pq = getattr(args, "quant_noise_pq", 0)
    args.max_relative_length = getattr(args, 'max_relative_length', -1)
    args.max_encoder_relative_length = getattr(args, 'max_encoder_relative_length', -1)
    args.max_decoder_relative_length = getattr(args, 'max_decoder_relative_length', -1)
    args.k_only = getattr(args, 'k_only', True)

@@ -391,7 +419,8 @@ def s2t_sate_s(args):
@register_model_architecture("s2t_sate", "s2t_sate_s_relative")
def s2t_sate_s_relative(args):
    args.max_relative_length = 20
    args.max_encoder_relative_length = 100
    args.max_decoder_relative_length = 20
    args.k_only = True
    s2t_sate_s(args)
fairseq/models/speech_to_text/s2t_transformer.py

@@ -220,7 +220,9 @@ class S2TTransformerModel(FairseqEncoderDecoderModel):
            action="store_true",
            help="if True, dont scale embeddings",
        )
        parser.add_argument('--max-relative-length', type=int, default=-1, help='the max relative length')
        parser.add_argument('--max-encoder-relative-length', type=int, default=-1, help='the max relative length')
        parser.add_argument('--max-decoder-relative-length', type=int, default=-1, help='the max relative length')
        parser.add_argument('--k-only', default=False, action='store_true',
                            help='select the relative mode to map relative position information')

@@ -567,7 +569,8 @@ def base_architecture(args):
    args.no_scale_embedding = getattr(args, "no_scale_embedding", False)
    args.quant_noise_pq = getattr(args, "quant_noise_pq", 0)
    args.max_relative_length = getattr(args, 'max_relative_length', -1)
    args.max_encoder_relative_length = getattr(args, 'max_encoder_relative_length', -1)
    args.max_decoder_relative_length = getattr(args, 'max_decoder_relative_length', -1)
    args.k_only = getattr(args, 'k_only', True)

@@ -583,7 +586,8 @@ def s2t_transformer_s(args):
@register_model_architecture("s2t_transformer", "s2t_transformer_s_relative")
def s2t_transformer_s_relative(args):
    args.max_relative_length = 20
    args.max_encoder_relative_length = 20
    args.max_decoder_relative_length = 20
    args.k_only = True
    s2t_transformer_s(args)
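The options added to add_args above can be mirrored with a plain argparse parser. The sketch below is a standalone illustration, not fairseq's own entry point; it only shows how the flag spellings map to the args attributes (max_encoder_relative_length and friends) that the base_architecture functions and the attention layers read.

    import argparse

    # Standalone mirror of the options added in this commit (not fairseq's own parser).
    parser = argparse.ArgumentParser()
    parser.add_argument('--max-relative-length', type=int, default=-1,
                        help='the max relative length')
    parser.add_argument('--max-encoder-relative-length', type=int, default=-1,
                        help='the max relative length')
    parser.add_argument('--max-decoder-relative-length', type=int, default=-1,
                        help='the max relative length')
    parser.add_argument('--k-only', default=False, action='store_true',
                        help='select the relative mode to map relative position information')

    args = parser.parse_args(['--max-encoder-relative-length', '100',
                              '--max-decoder-relative-length', '20',
                              '--k-only'])
    # argparse turns the dashes into underscores, giving the attribute names used in the diff.
    print(args.max_encoder_relative_length, args.max_decoder_relative_length, args.k_only)
    # -> 100 20 True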
fairseq/models/transformer.py

@@ -218,8 +218,10 @@ class TransformerModel(FairseqEncoderDecoderModel):
                            ],
                            help="transformer decoder self-attention layer type")
        parser.add_argument('--max-relative-length', type=int, default=-1, help='the max relative length')
        parser.add_argument('--max-encoder-relative-length', type=int, default=-1, help='the max encoder relative length')
        parser.add_argument('--max-decoder-relative-length', type=int, default=-1, help='the max decoder relative length')
        parser.add_argument('--k-only', default=False, action='store_true',
                            help='select the relative mode to map relative position information')
        # args for loading pre-trained models

@@ -1182,13 +1184,15 @@ def base_architecture(args):
    args.encoder_attention_type = getattr(args, "encoder_attention_type", "selfattn")
    args.decoder_attention_type = getattr(args, "decoder_attention_type", "selfattn")
    args.max_relative_length = getattr(args, 'max_relative_length', -1)
    args.max_encoder_relative_length = getattr(args, 'max_encoder_relative_length', -1)
    args.max_decoder_relative_length = getattr(args, 'max_decoder_relative_length', -1)
    args.k_only = getattr(args, 'k_only', True)


@register_model_architecture("transformer", "transformer_relative")
def transformer_rpr(args):
    args.max_relative_length = 20
    args.max_encoder_relative_length = 20
    args.max_decoder_relative_length = 20
    args.k_only = True
    base_architecture(args)
fairseq/modules/conformer_layer.py

@@ -8,7 +8,13 @@ from typing import Optional
import torch
import torch.nn as nn

from fairseq import utils
from fairseq.modules import LayerNorm, MultiheadAttention, RelPositionMultiheadAttention, ConvolutionModule
from fairseq.modules import (
    LayerNorm,
    MultiheadAttention,
    RelPositionMultiheadAttention,
    RelativeMultiheadAttention,
    ConvolutionModule
)
from fairseq.modules.fairseq_dropout import FairseqDropout
from fairseq.modules.quant_noise import quant_noise
from torch import Tensor

@@ -112,6 +118,21 @@ class ConformerEncoderLayer(nn.Module):
            attn_func = MultiheadAttention
        elif self.attn_type == "rel_selfattn":
            attn_func = RelPositionMultiheadAttention
        elif self.attn_type == "relative":
            max_relative_length = getattr(args, "max_encoder_relative_length", -1)
            if max_relative_length != -1:
                return RelativeMultiheadAttention(
                    embed_dim, args.encoder_attention_heads,
                    dropout=args.attention_dropout,
                    self_attention=True,
                    q_noise=self.quant_noise,
                    qn_block_size=self.quant_noise_block_size,
                    max_relative_length=max_relative_length,
                )
            else:
                print("The maximum encoder relative length %d can not be -1!" % max_relative_length)
                exit(1)
        else:
            attn_func = MultiheadAttention
            print("The attention type %s is not supported!" % self.attn_type)
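The Conformer layer is encoder-only, so it consults just max_encoder_relative_length and, unlike the Transformer layers below, does not fall back to the global --max-relative-length. The diff does not show RelativeMultiheadAttention itself; assuming the usual Shaw-style relative position encoding, the window simply clips pairwise token distances, which is why a long acoustic sequence can justify a wider encoder window (100 in the s2t *_relative architectures) than the decoder's 20. A small illustrative sketch under that assumption:

    import torch

    def clipped_relative_positions(length: int, max_relative_length: int) -> torch.Tensor:
        # Pairwise relative distances j - i, clipped to [-k, k] and shifted to [0, 2k]
        # so they can index a relative-position embedding table (Shaw et al., 2018).
        pos = torch.arange(length)
        rel = pos[None, :] - pos[:, None]
        rel = rel.clamp(-max_relative_length, max_relative_length)
        return rel + max_relative_length

    # A speech encoder sees long frame sequences even after subsampling, so a
    # wider window keeps distant frames distinguishable; a short text decoder
    # gets by with a much smaller one.
    print(clipped_relative_positions(6, 2))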
fairseq/modules/transformer_layer.py

@@ -87,7 +87,10 @@ class TransformerEncoderLayer(nn.Module):
            attn_func = MultiheadAttention
        elif self.attn_type == "rel_selfattn":
            attn_func = RelPositionMultiheadAttention
        elif self.attn_type == "relative" or getattr(args, "max_relative_length", -1) != -1:
        elif self.attn_type == "relative":
            # max_relative_length = getattr(args, "max_encoder_relative_length", -1)
            max_relative_length = max(getattr(args, "max_encoder_relative_length", -1), getattr(args, "max_relative_length", -1))
            if max_relative_length != -1:
                return RelativeMultiheadAttention(
                    embed_dim, args.encoder_attention_heads,

@@ -95,10 +98,13 @@
                    self_attention=True,
                    q_noise=self.quant_noise,
                    qn_block_size=self.quant_noise_block_size,
                    max_relative_length=args.max_relative_length,
                    max_relative_length=max_relative_length,
                )
            else:
                print("The attention type %s is not supported!" % self.attn_type)
                print("The maximum encoder relative length %d can not be -1!" % max_relative_length)
                exit(1)
        else:
            print("The encoder attention type %s is not supported!" % self.attn_type)
            exit(1)

        return attn_func(

@@ -292,18 +298,23 @@ class TransformerDecoderLayer(nn.Module):
            attn_func = MultiheadAttention
        elif self.attn_type == "rel_selfattn":
            attn_func = RelPositionMultiheadAttention
        elif self.attn_type == "relative" or getattr(args, "max_relative_length", -1) != -1:
        elif self.attn_type == "relative":
            max_relative_length = max(getattr(args, "max_decoder_relative_length", -1), getattr(args, "max_relative_length", -1))
            if max_relative_length != -1:
                return RelativeMultiheadAttention(
                    embed_dim, args.encoder_attention_heads,
                    embed_dim, args.decoder_attention_heads,
                    dropout=args.attention_dropout,
                    self_attention=True,
                    q_noise=self.quant_noise,
                    qn_block_size=self.quant_noise_block_size,
                    max_relative_length=args.max_relative_length,
                    max_relative_length=max_relative_length,
                )
            else:
                print("The attention type %s is not supported!" % self.attn_type)
                print("The maximum decoder relative length %d can not be -1!" % max_relative_length)
                exit(1)
        else:
            print("The decoder attention type %s is not supported!" % self.attn_type)
            exit(1)

        return attn_func(
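The core change is that the encoder and decoder layers now resolve their relative window independently, taking the larger of the side-specific flag and the legacy global --max-relative-length (all default to -1, which leaves relative attention disabled and triggers the error branch). Below is a standalone sketch of that resolution rule; the attribute names come from the diff, SimpleNamespace stands in for the parsed args, and resolve_relative_length is a hypothetical helper written only for illustration.

    from types import SimpleNamespace

    def resolve_relative_length(args, side: str) -> int:
        # Take the larger of the side-specific flag and the legacy global flag
        # (both default to -1), mirroring the lookup in the encoder/decoder layers.
        side_specific = getattr(args, f"max_{side}_relative_length", -1)
        global_length = getattr(args, "max_relative_length", -1)
        return max(side_specific, global_length)

    # Only the legacy global flag is set: both sides pick it up.
    legacy = SimpleNamespace(max_relative_length=20)
    print(resolve_relative_length(legacy, "encoder"),
          resolve_relative_length(legacy, "decoder"))   # 20 20

    # Per-side flags set (e.g. s2t_conformer_s_relative): encoder 100, decoder 20.
    per_side = SimpleNamespace(max_encoder_relative_length=100, max_decoder_relative_length=20)
    print(resolve_relative_length(per_side, "encoder"),
          resolve_relative_length(per_side, "decoder"))  # 100 20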