Fairseq-S2T · Commit b23817e0
authored Sep 03, 2021 by xuchen

implement the pyramid transformer

parent 7802e6f7
Showing 8 changed files with 22 additions and 6 deletions.

fairseq/models/speech_to_text/__init__.py             +1   -0
fairseq/models/speech_to_text/pys2t_transformer.py    +0   -0   (new file)
fairseq/models/speech_to_text/s2t_sate.py             +15  -4
fairseq/models/speech_to_text/s2t_transformer.py      +2   -1
fairseq/modules/__init__.py                           +4   -0
fairseq/modules/local_multihead_attention.py          +0   -1
fairseq/modules/pyramid_layer.py                      +0   -0   (new file)
fairseq/modules/reduced_multihead_attention.py        +0   -0   (new file)
fairseq/models/speech_to_text/__init__.py

@@ -7,4 +7,5 @@ from .berard import *  # noqa
 from .convtransformer import *  # noqa
 from .s2t_transformer import *  # noqa
 from .s2t_conformer import *  # noqa
+from .pys2t_transformer import *  # noqa
 from .s2t_sate import *  # noqa
fairseq/models/speech_to_text/pys2t_transformer.py (new file, mode 100644)

[Diff collapsed. Click to expand.]
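The contents of the new pyramid encoder are collapsed in this view. As a loose, hypothetical sketch of the general idea behind a pyramid transformer encoder for speech (progressively shortening the time axis between blocks of transformer layers, so deeper layers attend over coarser sequences), assuming strided convolution for the downsampling:

```python
# Hypothetical sketch only: the actual pys2t_transformer.py is collapsed in
# this view. A "pyramid" speech encoder shortens the time axis between blocks
# of transformer layers, so deeper layers attend over coarser sequences.
import torch
import torch.nn as nn


class PyramidStage(nn.Module):
    """One stage: strided Conv1d downsampling, then transformer layers."""

    def __init__(self, embed_dim: int = 512, num_layers: int = 4, stride: int = 2):
        super().__init__()
        self.downsample = nn.Conv1d(embed_dim, embed_dim,
                                    kernel_size=stride, stride=stride)
        layer = nn.TransformerEncoderLayer(d_model=embed_dim, nhead=8)
        self.layers = nn.TransformerEncoder(layer, num_layers=num_layers)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # x: (seq_len, batch, dim) -> (seq_len // stride, batch, dim)
        x = self.downsample(x.permute(1, 2, 0)).permute(2, 0, 1)
        return self.layers(x)


stage = PyramidStage()
out = stage(torch.randn(100, 2, 512))   # -> (50, 2, 512)
```

Stacking such stages with stride 2 halves the sequence length each time, which is the usual motivation for pyramid speech encoders: self-attention cost drops quadratically with length.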
fairseq/models/speech_to_text/s2t_sate.py

@@ -17,7 +17,9 @@ from fairseq.models.speech_to_text import (
     S2TTransformerModel,
     S2TTransformerEncoder,
     S2TConformerEncoder,
-    S2TConformerModel
+    S2TConformerModel,
+    PYS2TTransformerModel,
+    PyS2TTransformerEncoder,
 )
 from fairseq.models.speech_to_text.s2t_transformer import Conv1dSubsampler
 from fairseq.modules import (
@@ -46,6 +48,7 @@ class S2TSATEModel(S2TTransformerModel):
     def add_args(parser):
         """Add model-specific arguments to the parser."""
         S2TConformerModel.add_args(parser)
+        PYS2TTransformerModel.add_args(parser)
         parser.add_argument(
             "--text-encoder-layers",
@@ -195,13 +198,16 @@ class Adapter(nn.Module):
             linear_out = self.linear_adapter(representation)
             soft_out = torch.mm(distribution.view(-1, embed_dim), self.embed_adapter.weight).view(batch, seq_len, -1)
             out = linear_out + soft_out
         elif self.adapter_type == "gated_league":
             linear_out = self.linear_adapter(representation)
             soft_out = torch.mm(distribution.view(-1, embed_dim), self.embed_adapter.weight).view(batch, seq_len, -1)
             coef = (self.gate_linear(torch.cat([linear_out, soft_out], dim=-1))).sigmoid()
             out = coef * linear_out + (1 - coef) * soft_out
         elif self.adapter_type == "none":
             out = representation
         else:
             out = None
             logging.error("Unsupported adapter type: {}.".format(self.adapter_type))
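To make the adapter math above concrete, here is a self-contained re-run of the "league" and "gated_league" branches with toy shapes. The module names follow the diff; the sizes, and reading the diff's `embed_dim` as the width of `distribution` (a vocabulary size), are assumptions:

```python
# Self-contained re-run of the adapter math above with toy shapes. The module
# names (linear_adapter, embed_adapter, gate_linear) follow the diff; the
# sizes, and treating the diff's `embed_dim` as the width of `distribution`,
# are assumptions.
import torch
import torch.nn as nn

batch, seq_len, dim, vocab = 2, 5, 16, 10
representation = torch.randn(batch, seq_len, dim)
distribution = torch.softmax(torch.randn(batch, seq_len, vocab), dim=-1)

linear_adapter = nn.Linear(dim, dim)
embed_adapter = nn.Embedding(vocab, dim)      # weight: (vocab, dim)
gate_linear = nn.Linear(2 * dim, dim)

linear_out = linear_adapter(representation)
# Expected embedding under the predicted distribution, as in the diff:
soft_out = torch.mm(distribution.view(-1, vocab),
                    embed_adapter.weight).view(batch, seq_len, -1)

out_league = linear_out + soft_out            # "league"
coef = gate_linear(torch.cat([linear_out, soft_out], dim=-1)).sigmoid()
out_gated = coef * linear_out + (1 - coef) * soft_out   # "gated_league"
```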
@@ -262,6 +268,8 @@ class S2TSATEEncoder(FairseqEncoder):
             self.acoustic_encoder = S2TTransformerEncoder(args, task, embed_tokens)
         elif acoustic_encoder_type == "conformer":
             self.acoustic_encoder = S2TConformerEncoder(args, task, embed_tokens)
+        elif acoustic_encoder_type == "pyramid":
+            self.acoustic_encoder = PyS2TTransformerEncoder(args, task, embed_tokens)
         else:
             logging.error("Unsupported model arch {}!".format(acoustic_encoder_type))
@@ -277,9 +285,9 @@ class S2TSATEEncoder(FairseqEncoder):
         # )
         acoustic_encoder_attention_type = args.encoder_attention_type
-        if acoustic_encoder_attention_type != "selfattn":
-            args.encoder_attention_type = "selfattn"
-            logger.info("Force self attention for text encoder.")
+        # if acoustic_encoder_attention_type != "selfattn":
+        #     args.encoder_attention_type = "selfattn"
+        #     logger.info("Force self attention for text encoder.")
         # text encoder
         self.text_encoder = TextEncoder(args, embed_tokens)
@@ -378,6 +386,9 @@ def base_architecture(args):
     args.use_cnn_module = getattr(args, "use_cnn_module", False)
     args.cnn_module_kernel = getattr(args, "cnn_module_kernel", 31)
+    # Pyramid
+    args.pyramid_layers = getattr(args, "pyramid_layers", None)
     args.encoder_embed_dim = getattr(args, "encoder_embed_dim", 512)
     args.encoder_ffn_embed_dim = getattr(args, "encoder_ffn_embed_dim", 2048)
     args.encoder_layers = getattr(args, "encoder_layers", 12)
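The `getattr(args, name, default)` pattern used throughout `base_architecture` only fills attributes the user did not set, so adding `pyramid_layers` with a `None` default keeps existing configurations valid. A minimal illustration:

```python
# getattr-based defaulting, as used in base_architecture: explicit values are
# kept, missing attributes get the default.
from argparse import Namespace

args = Namespace(encoder_layers=6)                            # set by the user
args.pyramid_layers = getattr(args, "pyramid_layers", None)   # -> None
args.encoder_layers = getattr(args, "encoder_layers", 12)     # -> 6, kept
```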
fairseq/models/speech_to_text/s2t_transformer.py

@@ -147,10 +147,11 @@ class S2TTransformerModel(FairseqEncoderDecoderModel):
             type=str,
             default="selfattn",
             choices=[
+                "local",
                 "selfattn",
+                "reduced",
                 "rel_selfattn",
                 "relative",
-                "local",
             ],
             help="transformer encoder self-attention layer type"
         )
fairseq/modules/__init__.py

@@ -29,6 +29,7 @@ from .linearized_convolution import LinearizedConvolution
 from .local_multihead_attention import LocalMultiheadAttention
 from .multihead_attention import MultiheadAttention
 from .positional_embedding import PositionalEmbedding
+from .reduced_multihead_attention import ReducedMultiheadAttention
 from .rel_position_multihead_attention import RelPositionMultiheadAttention
 from .relative_multihead_attention import RelativeMultiheadAttention
 from .same_pad import SamePad

@@ -41,6 +42,7 @@ from .unfold import unfold1d
 from .transformer_layer import TransformerDecoderLayer, TransformerEncoderLayer
 from .vggblock import VGGBlock
 from .conformer_layer import ConformerEncoderLayer
+from .pyramid_layer import PyramidTransformerEncoderLayer

 __all__ = [
     "AdaptiveInput",

@@ -74,6 +76,8 @@ __all__ = [
     "LocalMultiheadAttention",
     "MultiheadAttention",
     "PositionalEmbedding",
+    "PyramidTransformerEncoderLayer",
+    "ReducedMultiheadAttention",
     "RelPositionMultiheadAttention",
     "RelativeMultiheadAttention",
     "SamePad",
fairseq/modules/local_multihead_attention.py

@@ -325,7 +325,6 @@ class LocalMultiheadAttention(nn.Module):
         multihead_mask_weight = None
         gauss_bias = None
         if self.multihead_gauss_mask_sigma is not None:
-            data_type = attn_weights.dtype
             x1 = torch.arange(-1, src_len - 1, 1).view(-1, 1).to(attn_weights.device)
             x2 = torch.arange(-1, src_len - 1, 1).view(1, -1).to(attn_weights.device)
             dis_square = -(x1 - x2) ** 2 / 2.0
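The removed `data_type = attn_weights.dtype` assignment appears to have been unused. The surrounding lines build a Gaussian distance term, `dis_square[i][j] = -(i - j)^2 / 2`, which, scaled by a sigma, penalizes attention between distant positions. A standalone sketch, where the use of `sigma` is an assumption (only the `x1`/`x2`/`dis_square` lines come from the diff):

```python
# Sketch of the Gaussian locality bias the surrounding code builds. Only the
# x1/x2/dis_square lines come from the diff; dividing by sigma**2 and adding
# the bias to the logits is an assumed continuation.
import torch

src_len, sigma = 6, 2.0
x1 = torch.arange(-1, src_len - 1, 1).view(-1, 1)
x2 = torch.arange(-1, src_len - 1, 1).view(1, -1)
dis_square = -(x1 - x2) ** 2 / 2.0            # as in the diff
gauss_bias = dis_square / (sigma ** 2)        # -(i - j)^2 / (2 * sigma^2)

attn_logits = torch.randn(src_len, src_len) + gauss_bias
attn = torch.softmax(attn_logits, dim=-1)     # mass concentrates near i == j
```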
fairseq/modules/pyramid_layer.py (new file, mode 100644)

[Diff collapsed. Click to expand.]
fairseq/modules/reduced_multihead_attention.py (new file, mode 100644)

[Diff collapsed. Click to expand.]
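This new module's diff is also collapsed, so the following is purely speculative: the name suggests multi-head attention over a length-reduced key/value sequence. One common realization pools keys and values along time before attending, e.g.:

```python
# Purely speculative sketch: reduced_multihead_attention.py is collapsed here.
# The name suggests attending over a length-reduced key/value sequence; one
# common realization average-pools K and V along time before attending.
import torch
import torch.nn.functional as F


def reduced_attention(q, k, v, reduction: int = 2):
    """q, k, v: (batch, seq_len, dim); pools k/v by `reduction` along time."""
    k = F.avg_pool1d(k.transpose(1, 2), reduction).transpose(1, 2)
    v = F.avg_pool1d(v.transpose(1, 2), reduction).transpose(1, 2)
    scores = q @ k.transpose(1, 2) / (q.size(-1) ** 0.5)
    return torch.softmax(scores, dim=-1) @ v


q = k = v = torch.randn(2, 8, 16)
out = reduced_attention(q, k, v)   # -> (2, 8, 16)
```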