xuchen / Fairseq-S2T · Commits · f0e3290f

Commit f0e3290f, authored Apr 10, 2021 by xuchen
add the traditional relative multihead attention
parent a3e0f4c2
Showing 8 changed files with 447 additions and 4 deletions (+447, -4).
fairseq/models/speech_to_text/s2t_conformer.py      +7    -0
fairseq/models/speech_to_text/s2t_sate.py           +7    -0
fairseq/models/speech_to_text/s2t_transformer.py    +14   -0
fairseq/models/transformer.py                       +73   -3
fairseq/modules/__init__.py                         +4    -0
fairseq/modules/layer_history.py                    +316  -0
fairseq/modules/relative_multihead_attention.py     +0    -0
fairseq/modules/transformer_layer.py                +26   -1
fairseq/models/speech_to_text/s2t_conformer.py
@@ -182,6 +182,13 @@ def s2t_conformer_s(args):
     base_architecture(args)
 
 
+@register_model_architecture("s2t_conformer", "s2t_conformer_s_relative")
+def s2t_conformer_s_relative(args):
+    args.max_relative_length = 20
+    args.k_only = True
+    s2t_conformer_s(args)
+
+
 @register_model_architecture("s2t_conformer", "s2t_conformer_xs")
 def s2t_conformer_xs(args):
     args.encoder_layers = getattr(args, "encoder_layers", 6)
...
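All three *_s_relative presets added in this commit (for s2t_conformer here, and for s2t_sate and s2t_transformer below) only pin max_relative_length = 20 and enable k_only. For orientation, not code from this commit: in the usual Shaw-style relative attention the commit title refers to, max_relative_length caps how far apart two positions may be before they share a relative-position embedding, roughly like this clipping step:

import torch

def relative_position_index(length: int, max_relative_length: int) -> torch.Tensor:
    # Map every pairwise offset j - i to an index in [0, 2 * max_relative_length],
    # clamping offsets beyond +/- max_relative_length so that distant positions
    # reuse the edge embedding.
    pos = torch.arange(length)
    rel = pos[None, :] - pos[:, None]
    rel = rel.clamp(-max_relative_length, max_relative_length)
    return rel + max_relative_length

print(relative_position_index(5, 2))  # 5x5 index matrix with values in [0, 4]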
fairseq/models/speech_to_text/s2t_sate.py
@@ -357,6 +357,13 @@ def s2t_sate_s(args):
     base_architecture(args)
 
 
+@register_model_architecture("s2t_sate", "s2t_sate_s_relative")
+def s2t_sate_s_relative(args):
+    args.max_relative_length = 20
+    args.k_only = True
+    s2t_sate_s(args)
+
+
 @register_model_architecture("s2t_sate", "s2t_sate_xs")
 def s2t_sate_xs(args):
     args.encoder_layers = getattr(args, "encoder_layers", 6)
...
fairseq/models/speech_to_text/s2t_transformer.py
@@ -217,6 +217,10 @@ class S2TTransformerModel(FairseqEncoderDecoderModel):
             action="store_true",
             help="if True, dont scale embeddings",
         )
+        parser.add_argument('--max-relative-length', type=int, default=-1,
+                            help='the max relative length')
+        parser.add_argument('--k-only', default=False, action='store_true',
+                            help='select the relative mode to map relative position information')
         parser.add_argument(
             "--load-pretrained-encoder-from",
             type=str,
...
@@ -518,6 +522,9 @@ def base_architecture(args):
     args.no_scale_embedding = getattr(args, "no_scale_embedding", False)
     args.quant_noise_pq = getattr(args, "quant_noise_pq", 0)
+    args.max_relative_length = getattr(args, 'max_relative_length', -1)
+    args.k_only = getattr(args, 'k_only', True)
 
 
 @register_model_architecture("s2t_transformer", "s2t_transformer_s")
 def s2t_transformer_s(args):
...
@@ -529,6 +536,13 @@ def s2t_transformer_s(args):
     base_architecture(args)
 
 
+@register_model_architecture("s2t_transformer", "s2t_transformer_s_relative")
+def s2t_transformer_s_relative(args):
+    args.max_relative_length = 20
+    args.k_only = True
+    s2t_transformer_s(args)
+
+
 @register_model_architecture("s2t_transformer", "s2t_transformer_xs")
 def s2t_transformer_xs(args):
     args.encoder_layers = getattr(args, "encoder_layers", 6)
...
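The two options added to add_args above behave like ordinary argparse flags: --max-relative-length defaults to -1 (relative attention disabled) and --k-only is a store_true switch. A minimal, self-contained illustration using the same names and defaults as the diff (the bare parser is only for demonstration):

import argparse

parser = argparse.ArgumentParser()
parser.add_argument('--max-relative-length', type=int, default=-1,
                    help='the max relative length')
parser.add_argument('--k-only', default=False, action='store_true',
                    help='select the relative mode to map relative position information')

args = parser.parse_args(['--max-relative-length', '20', '--k-only'])
assert args.max_relative_length == 20 and args.k_only is True

# base_architecture() then backfills the same attributes when a preset such as
# s2t_transformer_s_relative has not already set them:
args.max_relative_length = getattr(args, 'max_relative_length', -1)
args.k_only = getattr(args, 'k_only', True)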
fairseq/models/transformer.py
@@ -5,10 +5,11 @@
 
 import math
 from typing import Any, Dict, List, Optional, Tuple
+import logging
 
 import torch
 import torch.nn as nn
-from fairseq import utils
+from fairseq import checkpoint_utils, utils
 from fairseq.distributed import fsdp_wrap
 from fairseq.models import (
     FairseqEncoder,
...
@@ -35,6 +36,8 @@ from torch import Tensor
 DEFAULT_MAX_SOURCE_POSITIONS = 1024
 DEFAULT_MAX_TARGET_POSITIONS = 1024
 
+logger = logging.getLogger(__name__)
+
 
 @register_model("transformer")
 class TransformerModel(FairseqEncoderDecoderModel):
...
@@ -191,6 +194,35 @@ class TransformerModel(FairseqEncoderDecoderModel):
                             help='block size of quantization noise at training time')
         parser.add_argument('--quant-noise-scalar', type=float, metavar='D', default=0,
                             help='scalar quantization noise and scalar quantization at training time')
+        parser.add_argument('--max-relative-length', type=int, default=-1,
+                            help='the max relative length')
+        parser.add_argument('--k-only', default=False, action='store_true',
+                            help='select the relative mode to map relative position information')
+        # args for loading pre-trained models
+        parser.add_argument(
+            "--load-pretrained-encoder-from",
+            type=str,
+            metavar="STR",
+            help="model to take encoder weights from (for initialization)",
+        )
+        parser.add_argument(
+            "--load-pretrained-decoder-from",
+            type=str,
+            metavar="STR",
+            help="model to take decoder weights from (for initialization)",
+        )
+        parser.add_argument(
+            "--encoder-freeze-module",
+            type=str,
+            metavar="STR",
+            help="freeze the module of the encoder",
+        )
+        parser.add_argument(
+            "--decoder-freeze-module",
+            type=str,
+            metavar="STR",
+            help="freeze the module of the decoder",
+        )
         # fmt: on
 
     @classmethod
...
@@ -240,7 +272,15 @@ class TransformerModel(FairseqEncoderDecoderModel):
         if getattr(args, "offload_activations", False):
             args.checkpoint_activations = True  # offloading implies checkpointing
         encoder = cls.build_encoder(args, src_dict, encoder_embed_tokens)
+        if getattr(args, "encoder_freeze_module", None):
+            utils.freeze_parameters(encoder, args.encoder_freeze_module)
+            logging.info("freeze the encoder module: {}".format(args.encoder_freeze_module))
+
         decoder = cls.build_decoder(args, tgt_dict, decoder_embed_tokens)
+        if getattr(args, "decoder_freeze_module", None):
+            utils.freeze_parameters(decoder, args.decoder_freeze_module)
+            logging.info("freeze the decoder module: {}".format(args.decoder_freeze_module))
+
         if not args.share_all_embeddings:
             encoder = fsdp_wrap(encoder, min_num_params=1e8)
             decoder = fsdp_wrap(decoder, min_num_params=1e8)
...
@@ -260,17 +300,38 @@ class TransformerModel(FairseqEncoderDecoderModel):
     @classmethod
     def build_encoder(cls, args, src_dict, embed_tokens):
-        return TransformerEncoder(args, src_dict, embed_tokens)
+        encoder = TransformerEncoder(args, src_dict, embed_tokens)
+        if getattr(args, "load_pretrained_encoder_from", None):
+            logger.info(
+                f"loaded pretrained encoder from: "
+                f"{args.load_pretrained_encoder_from}"
+            )
+            encoder = checkpoint_utils.load_pretrained_component_from_model(
+                component=encoder, checkpoint=args.load_pretrained_encoder_from, strict=False
+            )
+        return encoder
 
     @classmethod
     def build_decoder(cls, args, tgt_dict, embed_tokens):
-        return TransformerDecoder(
+        decoder = TransformerDecoder(
             args,
             tgt_dict,
             embed_tokens,
             no_encoder_attn=getattr(args, "no_cross_attention", False),
         )
+        if getattr(args, "load_pretrained_decoder_from", None):
+            logger.info(
+                f"loaded pretrained decoder from: "
+                f"{args.load_pretrained_decoder_from}"
+            )
+            decoder = checkpoint_utils.load_pretrained_component_from_model(
+                component=decoder, checkpoint=args.load_pretrained_decoder_from, strict=False
+            )
+        return decoder
 
     # TorchScript doesn't support optional arguments with variable length (**kwargs).
     # Current workaround is to add union of all arguments in child classes.
     def forward(
...
@@ -1073,6 +1134,15 @@ def base_architecture(args):
     args.quant_noise_pq = getattr(args, "quant_noise_pq", 0)
     args.quant_noise_pq_block_size = getattr(args, "quant_noise_pq_block_size", 8)
     args.quant_noise_scalar = getattr(args, "quant_noise_scalar", 0)
+    args.max_relative_length = getattr(args, 'max_relative_length', -1)
+    args.k_only = getattr(args, 'k_only', True)
+
+
+@register_model_architecture("transformer", "transformer_relative")
+def transformer_rpr(args):
+    args.max_relative_length = 20
+    args.k_only = True
+    base_architecture(args)
 
 
 @register_model_architecture("transformer", "transformer_iwslt_de_en")
...
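The new --load-pretrained-{encoder,decoder}-from and --{encoder,decoder}-freeze-module options are consumed in build_model, build_encoder, and build_decoder above: pretrained weights are copied in with checkpoint_utils.load_pretrained_component_from_model(..., strict=False), and the named sub-module is frozen with utils.freeze_parameters. As a hedged sketch of the freezing idea only (this is not the repository's utils.freeze_parameters, just the generic pattern of disabling gradients for a named sub-module):

import torch.nn as nn

def freeze_named_submodule(model: nn.Module, name: str) -> None:
    # Illustrative only: turn off gradients for every parameter whose
    # qualified name starts with the given sub-module name.
    for param_name, param in model.named_parameters():
        if param_name.startswith(name):
            param.requires_grad = False

# e.g. freeze_named_submodule(encoder, "embed_tokens") would be analogous to
# passing --encoder-freeze-module embed_tokens (the module name is hypothetical).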
fairseq/modules/__init__.py
@@ -21,6 +21,7 @@ from .grad_multiply import GradMultiply
 from .gumbel_vector_quantizer import GumbelVectorQuantizer
 from .kmeans_vector_quantizer import KmeansVectorQuantizer
 from .layer_drop import LayerDropModuleList
+from .layer_history import CreateLayerHistory
 from .layer_norm import Fp32LayerNorm, LayerNorm
 from .learned_positional_embedding import LearnedPositionalEmbedding
 from .lightweight_convolution import LightweightConv, LightweightConv1dTBC
...
@@ -28,6 +29,7 @@ from .linearized_convolution import LinearizedConvolution
 from .multihead_attention import MultiheadAttention
 from .positional_embedding import PositionalEmbedding
 from .rel_position_multihead_attention import RelPositionMultiheadAttention
+from .relative_multihead_attention import RelativeMultiheadAttention
 from .same_pad import SamePad
 from .scalar_bias import ScalarBias
 from .sinusoidal_positional_embedding import SinusoidalPositionalEmbedding
...
@@ -47,6 +49,7 @@ __all__ = [
     "ConformerEncoderLayer",
     "ConvolutionModule",
     "ConvTBC",
+    "CreateLayerHistory",
     "cross_entropy",
     "DownsampledMultiHeadAttention",
     "DynamicConv1dTBC",
...
@@ -69,6 +72,7 @@ __all__ = [
     "MultiheadAttention",
     "PositionalEmbedding",
     "RelPositionMultiheadAttention",
+    "RelativeMultiheadAttention",
     "SamePad",
     "ScalarBias",
     "SinusoidalPositionalEmbedding",
...
fairseq/modules/layer_history.py
new file (mode 0 → 100644) · diff collapsed on this page
fairseq/modules/relative_multihead_attention.py
new file (mode 0 → 100644) · diff collapsed on this page
fairseq/modules/transformer_layer.py
@@ -8,7 +8,12 @@ from typing import Dict, List, Optional
 import torch
 import torch.nn as nn
 from fairseq import utils
-from fairseq.modules import LayerNorm, MultiheadAttention, RelPositionMultiheadAttention
+from fairseq.modules import (
+    LayerNorm,
+    MultiheadAttention,
+    RelPositionMultiheadAttention,
+    RelativeMultiheadAttention
+)
 from fairseq.modules.fairseq_dropout import FairseqDropout
 from fairseq.modules.quant_noise import quant_noise
 from torch import Tensor
...
@@ -82,6 +87,16 @@ class TransformerEncoderLayer(nn.Module):
             attn_func = MultiheadAttention
         elif self.attn_type == "rel_selfattn":
             attn_func = RelPositionMultiheadAttention
+        elif self.attn_type == "relative" or getattr(args, "max_relative_length", -1) != -1:
+            return RelativeMultiheadAttention(
+                embed_dim,
+                args.encoder_attention_heads,
+                dropout=args.attention_dropout,
+                self_attention=True,
+                q_noise=self.quant_noise,
+                qn_block_size=self.quant_noise_block_size,
+                max_relative_length=args.max_relative_length,
+            )
         else:
             print("The attention type %s is not supported!" % self.attn_type)
             exit(1)
...
@@ -277,6 +292,16 @@ class TransformerDecoderLayer(nn.Module):
             attn_func = MultiheadAttention
         elif self.attn_type == "rel_selfattn":
             attn_func = RelPositionMultiheadAttention
+        elif self.attn_type == "relative" or getattr(args, "max_relative_length", -1) != -1:
+            return RelativeMultiheadAttention(
+                embed_dim,
+                args.encoder_attention_heads,
+                dropout=args.attention_dropout,
+                self_attention=True,
+                q_noise=self.quant_noise,
+                qn_block_size=self.quant_noise_block_size,
+                max_relative_length=args.max_relative_length,
+            )
         else:
             print("The attention type %s is not supported!" % self.attn_type)
             exit(1)
...
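In both TransformerEncoderLayer and TransformerDecoderLayer the new branch fires either when the attention type is explicitly "relative" or whenever --max-relative-length differs from its default of -1, so setting the flag alone switches every self-attention module. A tiny restatement of that predicate:

def uses_relative_attention(attn_type: str, max_relative_length: int = -1) -> bool:
    # Mirrors the condition added above: an explicit "relative" attention type, or
    # any non-default max-relative-length, routes the layer to RelativeMultiheadAttention.
    return attn_type == "relative" or max_relative_length != -1

assert uses_relative_attention("relative")          # explicit opt-in
assert uses_relative_attention("selfattn", 20)      # implicit opt-in via --max-relative-length
assert not uses_relative_attention("selfattn")      # default path: MultiheadAttention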