Skip to content
项目
群组
代码片段
帮助
当前项目
正在载入...
登录 / 注册
切换导航面板
F
Fairseq-S2T
概览
Overview
Details
Activity
Cycle Analytics
版本库
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Charts
问题
0
Issues
0
列表
Board
标记
里程碑
合并请求
0
Merge Requests
0
CI / CD
CI / CD
流水线
作业
日程表
图表
维基
Wiki
代码片段
Snippets
成员
Collapse sidebar
Close sidebar
活动
图像
聊天
创建新问题
作业
提交
Issue Boards
Open sidebar
xuchen
Fairseq-S2T
Commits
090afc4d
Commit
090afc4d
authored
Apr 12, 2021
by
xuchen
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
implement the dlcl for s2t task
parent
db29e01d
隐藏空白字符变更
内嵌
并排
正在显示
7 个修改的文件
包含
154 行增加
和
38 行删除
+154
-38
fairseq/models/dlcl_transformer.py
+2
-0
fairseq/models/speech_to_text/s2t_conformer.py
+14
-0
fairseq/models/speech_to_text/s2t_sate.py
+32
-3
fairseq/models/speech_to_text/s2t_transformer.py
+43
-0
fairseq/models/transformer.py
+21
-0
fairseq/modules/__init__.py
+2
-1
fairseq/modules/layer_history.py
+40
-34
没有找到文件。
fairseq/models/dlcl_transformer.py
查看文件 @
090afc4d
...
@@ -55,8 +55,10 @@ class DLCLTransformerModel(TransformerModel):
...
@@ -55,8 +55,10 @@ class DLCLTransformerModel(TransformerModel):
# dense layer parameters
# dense layer parameters
parser
.
add_argument
(
'--encoder-history-type'
,
parser
.
add_argument
(
'--encoder-history-type'
,
default
=
"learnable_dense"
,
help
=
'encoder layer history type'
)
help
=
'encoder layer history type'
)
parser
.
add_argument
(
'--decoder-history-type'
,
parser
.
add_argument
(
'--decoder-history-type'
,
default
=
"learnable_dense"
,
help
=
'decoder layer history type'
)
help
=
'decoder layer history type'
)
parser
.
add_argument
(
'--encoder-integration-type'
,
choices
=
[
'avg'
,
'sum'
],
parser
.
add_argument
(
'--encoder-integration-type'
,
choices
=
[
'avg'
,
'sum'
],
help
=
'encoder layer integration type'
)
help
=
'encoder layer integration type'
)
...
...
fairseq/models/speech_to_text/s2t_conformer.py
查看文件 @
090afc4d
...
@@ -99,6 +99,9 @@ class S2TConformerEncoder(S2TTransformerEncoder):
...
@@ -99,6 +99,9 @@ class S2TConformerEncoder(S2TTransformerEncoder):
)
)
def
forward
(
self
,
src_tokens
,
src_lengths
):
def
forward
(
self
,
src_tokens
,
src_lengths
):
if
self
.
history
is
not
None
:
self
.
history
.
clean
()
x
,
input_lengths
=
self
.
subsample
(
src_tokens
,
src_lengths
)
x
,
input_lengths
=
self
.
subsample
(
src_tokens
,
src_lengths
)
x
=
self
.
embed_scale
*
x
x
=
self
.
embed_scale
*
x
...
@@ -109,8 +112,19 @@ class S2TConformerEncoder(S2TTransformerEncoder):
...
@@ -109,8 +112,19 @@ class S2TConformerEncoder(S2TTransformerEncoder):
x
=
self
.
dropout_module
(
x
)
x
=
self
.
dropout_module
(
x
)
positions
=
self
.
dropout_module
(
positions
)
positions
=
self
.
dropout_module
(
positions
)
# add emb into history
if
self
.
history
is
not
None
:
self
.
history
.
add
(
x
)
for
layer
in
self
.
layers
:
for
layer
in
self
.
layers
:
if
self
.
history
is
not
None
:
x
=
self
.
history
.
pop
()
x
=
layer
(
x
,
encoder_padding_mask
,
pos_emb
=
positions
)
x
=
layer
(
x
,
encoder_padding_mask
,
pos_emb
=
positions
)
if
self
.
history
is
not
None
:
self
.
history
.
add
(
x
)
if
self
.
history
is
not
None
:
x
=
self
.
history
.
pop
()
if
self
.
layer_norm
is
not
None
:
if
self
.
layer_norm
is
not
None
:
x
=
self
.
layer_norm
(
x
)
x
=
self
.
layer_norm
(
x
)
...
...
fairseq/models/speech_to_text/s2t_sate.py
查看文件 @
090afc4d
...
@@ -16,13 +16,15 @@ from fairseq.models.speech_to_text import (
...
@@ -16,13 +16,15 @@ from fairseq.models.speech_to_text import (
S2TTransformerModel
,
S2TTransformerModel
,
S2TTransformerEncoder
,
S2TTransformerEncoder
,
S2TConformerEncoder
,
S2TConformerEncoder
,
S2TConformerModel
)
S2TConformerModel
)
from
fairseq.models.speech_to_text.s2t_transformer
import
Conv1dSubsampler
from
fairseq.models.speech_to_text.s2t_transformer
import
Conv1dSubsampler
from
fairseq.modules
import
(
from
fairseq.modules
import
(
FairseqDropout
,
FairseqDropout
,
LayerNorm
,
LayerNorm
,
PositionalEmbedding
,
PositionalEmbedding
,
TransformerEncoderLayer
,
TransformerEncoderLayer
,
LearnableDenseLayerHistory
)
)
logger
=
logging
.
getLogger
(
__name__
)
logger
=
logging
.
getLogger
(
__name__
)
...
@@ -208,10 +210,17 @@ class TextEncoder(FairseqEncoder):
...
@@ -208,10 +210,17 @@ class TextEncoder(FairseqEncoder):
else
:
else
:
self
.
layer_norm
=
None
self
.
layer_norm
=
None
def
forward
(
self
,
x
,
encoder_padding_mask
=
None
,
positions
=
None
):
def
forward
(
self
,
x
,
encoder_padding_mask
=
None
,
positions
=
None
,
history
=
None
):
for
layer
in
self
.
layers
:
for
layer
in
self
.
layers
:
if
history
is
not
None
:
x
=
history
.
pop
()
x
=
layer
(
x
,
encoder_padding_mask
,
pos_emb
=
positions
)
x
=
layer
(
x
,
encoder_padding_mask
,
pos_emb
=
positions
)
if
history
is
not
None
:
history
.
add
(
x
)
if
history
is
not
None
:
x
=
history
.
pop
()
if
self
.
layer_norm
is
not
None
:
if
self
.
layer_norm
is
not
None
:
x
=
self
.
layer_norm
(
x
)
x
=
self
.
layer_norm
(
x
)
...
@@ -241,7 +250,16 @@ class S2TSATEEncoder(FairseqEncoder):
...
@@ -241,7 +250,16 @@ class S2TSATEEncoder(FairseqEncoder):
# text encoder
# text encoder
self
.
text_encoder
=
TextEncoder
(
args
,
embed_tokens
)
self
.
text_encoder
=
TextEncoder
(
args
,
embed_tokens
)
if
getattr
(
args
,
"use_enc_dlcl"
,
False
):
normalize_before
=
args
.
encoder_normalize_before
layer_num
=
args
.
encoder_layers
+
args
.
text_encoder_layers
+
1
self
.
history
=
LearnableDenseLayerHistory
(
normalize_before
,
layer_num
,
args
.
encoder_embed_dim
,
True
)
else
:
self
.
history
=
None
def
forward
(
self
,
src_tokens
,
src_lengths
):
def
forward
(
self
,
src_tokens
,
src_lengths
):
if
self
.
history
is
not
None
:
self
.
history
.
clean
()
acoustic_encoder_out
=
self
.
acoustic_encoder
(
src_tokens
,
src_lengths
)
acoustic_encoder_out
=
self
.
acoustic_encoder
(
src_tokens
,
src_lengths
)
...
@@ -254,7 +272,18 @@ class S2TSATEEncoder(FairseqEncoder):
...
@@ -254,7 +272,18 @@ class S2TSATEEncoder(FairseqEncoder):
x
,
positions
=
self
.
adapter
(
x
,
encoder_padding_mask
)
x
,
positions
=
self
.
adapter
(
x
,
encoder_padding_mask
)
x
=
self
.
text_encoder
(
x
,
encoder_padding_mask
,
positions
)
if
self
.
history
is
not
None
:
acoustic_history
=
self
.
acoustic_encoder
.
history
layer_num
=
acoustic_history
.
layer_num
idx
=
torch
.
arange
(
layer_num
)
.
unsqueeze
(
0
)
.
T
.
repeat
(
1
,
layer_num
)
.
to
(
x
.
device
)
self
.
history
.
weight
.
scatter
(
0
,
idx
,
acoustic_history
.
weight
)
self
.
history
.
layers
.
extend
(
acoustic_history
.
layers
)
self
.
history
.
count
=
acoustic_history
.
count
self
.
history
.
sum
=
acoustic_history
.
sum
self
.
history
.
add
(
x
)
x
=
self
.
text_encoder
(
x
,
encoder_padding_mask
,
positions
,
self
.
history
)
return
{
return
{
"ctc_logit"
:
[
ctc_logit
],
# T x B x C
"ctc_logit"
:
[
ctc_logit
],
# T x B x C
...
...
fairseq/models/speech_to_text/s2t_transformer.py
查看文件 @
090afc4d
...
@@ -19,6 +19,7 @@ from fairseq.modules import (
...
@@ -19,6 +19,7 @@ from fairseq.modules import (
LayerNorm
,
LayerNorm
,
PositionalEmbedding
,
PositionalEmbedding
,
TransformerEncoderLayer
,
TransformerEncoderLayer
,
CreateLayerHistory
,
)
)
from
torch
import
Tensor
from
torch
import
Tensor
...
@@ -247,6 +248,28 @@ class S2TTransformerModel(FairseqEncoderDecoderModel):
...
@@ -247,6 +248,28 @@ class S2TTransformerModel(FairseqEncoderDecoderModel):
metavar
=
"STR"
,
metavar
=
"STR"
,
help
=
"freeze the module of the decoder"
,
help
=
"freeze the module of the decoder"
,
)
)
parser
.
add_argument
(
"--use-enc-dlcl"
,
default
=
False
,
action
=
'store_true'
,
help
=
"use dlcl encoder"
,
)
parser
.
add_argument
(
"--use-dec-dlcl"
,
default
=
False
,
action
=
'store_true'
,
help
=
"use dlcl encoder"
,
)
parser
.
add_argument
(
'--encoder-history-type'
,
default
=
"learnable_dense"
,
help
=
'encoder layer history type'
)
parser
.
add_argument
(
'--decoder-history-type'
,
default
=
"learnable_dense"
,
help
=
'decoder layer history type'
)
pass
pass
@classmethod
@classmethod
...
@@ -362,6 +385,11 @@ class S2TTransformerEncoder(FairseqEncoder):
...
@@ -362,6 +385,11 @@ class S2TTransformerEncoder(FairseqEncoder):
else
:
else
:
self
.
layer_norm
=
None
self
.
layer_norm
=
None
if
getattr
(
args
,
"use_enc_dlcl"
,
False
):
self
.
history
=
CreateLayerHistory
(
args
,
is_encoder
=
True
)
else
:
self
.
history
=
None
self
.
use_ctc
=
"sate"
in
args
.
arch
or
\
self
.
use_ctc
=
"sate"
in
args
.
arch
or
\
((
"ctc"
in
getattr
(
args
,
"criterion"
,
False
))
and
\
((
"ctc"
in
getattr
(
args
,
"criterion"
,
False
))
and
\
(
getattr
(
args
,
"ctc_weight"
,
False
)
>
0
))
(
getattr
(
args
,
"ctc_weight"
,
False
)
>
0
))
...
@@ -384,6 +412,10 @@ class S2TTransformerEncoder(FairseqEncoder):
...
@@ -384,6 +412,10 @@ class S2TTransformerEncoder(FairseqEncoder):
self
.
softmax
=
nn
.
Softmax
(
dim
=-
1
)
self
.
softmax
=
nn
.
Softmax
(
dim
=-
1
)
def
forward
(
self
,
src_tokens
,
src_lengths
):
def
forward
(
self
,
src_tokens
,
src_lengths
):
if
self
.
history
is
not
None
:
self
.
history
.
clean
()
x
,
input_lengths
=
self
.
subsample
(
src_tokens
,
src_lengths
)
x
,
input_lengths
=
self
.
subsample
(
src_tokens
,
src_lengths
)
x
=
self
.
embed_scale
*
x
x
=
self
.
embed_scale
*
x
...
@@ -394,8 +426,19 @@ class S2TTransformerEncoder(FairseqEncoder):
...
@@ -394,8 +426,19 @@ class S2TTransformerEncoder(FairseqEncoder):
x
=
self
.
dropout_module
(
x
)
x
=
self
.
dropout_module
(
x
)
positions
=
self
.
dropout_module
(
positions
)
positions
=
self
.
dropout_module
(
positions
)
# add emb into history
if
self
.
history
is
not
None
:
self
.
history
.
add
(
x
)
for
layer
in
self
.
layers
:
for
layer
in
self
.
layers
:
if
self
.
history
is
not
None
:
x
=
self
.
history
.
pop
()
x
=
layer
(
x
,
encoder_padding_mask
,
pos_emb
=
positions
)
x
=
layer
(
x
,
encoder_padding_mask
,
pos_emb
=
positions
)
if
self
.
history
is
not
None
:
self
.
history
.
add
(
x
)
if
self
.
history
is
not
None
:
x
=
self
.
history
.
pop
()
if
self
.
layer_norm
is
not
None
:
if
self
.
layer_norm
is
not
None
:
x
=
self
.
layer_norm
(
x
)
x
=
self
.
layer_norm
(
x
)
...
...
fairseq/models/transformer.py
查看文件 @
090afc4d
...
@@ -27,6 +27,7 @@ from fairseq.modules import (
...
@@ -27,6 +27,7 @@ from fairseq.modules import (
SinusoidalPositionalEmbedding
,
SinusoidalPositionalEmbedding
,
TransformerDecoderLayer
,
TransformerDecoderLayer
,
TransformerEncoderLayer
,
TransformerEncoderLayer
,
CreateLayerHistory
)
)
from
fairseq.modules.checkpoint_activations
import
checkpoint_wrapper
from
fairseq.modules.checkpoint_activations
import
checkpoint_wrapper
from
fairseq.modules.quant_noise
import
quant_noise
as
apply_quant_noise_
from
fairseq.modules.quant_noise
import
quant_noise
as
apply_quant_noise_
...
@@ -778,6 +779,11 @@ class TransformerDecoder(FairseqIncrementalDecoder):
...
@@ -778,6 +779,11 @@ class TransformerDecoder(FairseqIncrementalDecoder):
else
:
else
:
self
.
layer_norm
=
None
self
.
layer_norm
=
None
if
getattr
(
args
,
"use_dec_dlcl"
,
False
):
self
.
history
=
CreateLayerHistory
(
args
,
is_encoder
=
False
)
else
:
self
.
history
=
None
self
.
project_out_dim
=
(
self
.
project_out_dim
=
(
Linear
(
embed_dim
,
self
.
output_embed_dim
,
bias
=
False
)
Linear
(
embed_dim
,
self
.
output_embed_dim
,
bias
=
False
)
if
embed_dim
!=
self
.
output_embed_dim
and
not
args
.
tie_adaptive_weights
if
embed_dim
!=
self
.
output_embed_dim
and
not
args
.
tie_adaptive_weights
...
@@ -913,6 +919,9 @@ class TransformerDecoder(FairseqIncrementalDecoder):
...
@@ -913,6 +919,9 @@ class TransformerDecoder(FairseqIncrementalDecoder):
- the decoder's features of shape `(batch, tgt_len, embed_dim)`
- the decoder's features of shape `(batch, tgt_len, embed_dim)`
- a dictionary with any model-specific outputs
- a dictionary with any model-specific outputs
"""
"""
if
self
.
history
is
not
None
:
self
.
history
.
clean
()
if
alignment_layer
is
None
:
if
alignment_layer
is
None
:
alignment_layer
=
self
.
num_layers
-
1
alignment_layer
=
self
.
num_layers
-
1
...
@@ -948,6 +957,10 @@ class TransformerDecoder(FairseqIncrementalDecoder):
...
@@ -948,6 +957,10 @@ class TransformerDecoder(FairseqIncrementalDecoder):
# B x T x C -> T x B x C
# B x T x C -> T x B x C
x
=
x
.
transpose
(
0
,
1
)
x
=
x
.
transpose
(
0
,
1
)
# add emb into history
if
self
.
history
is
not
None
:
self
.
history
.
add
(
x
)
self_attn_padding_mask
:
Optional
[
Tensor
]
=
None
self_attn_padding_mask
:
Optional
[
Tensor
]
=
None
if
self
.
cross_self_attention
or
prev_output_tokens
.
eq
(
self
.
padding_idx
)
.
any
():
if
self
.
cross_self_attention
or
prev_output_tokens
.
eq
(
self
.
padding_idx
)
.
any
():
self_attn_padding_mask
=
prev_output_tokens
.
eq
(
self
.
padding_idx
)
self_attn_padding_mask
=
prev_output_tokens
.
eq
(
self
.
padding_idx
)
...
@@ -956,6 +969,9 @@ class TransformerDecoder(FairseqIncrementalDecoder):
...
@@ -956,6 +969,9 @@ class TransformerDecoder(FairseqIncrementalDecoder):
attn
:
Optional
[
Tensor
]
=
None
attn
:
Optional
[
Tensor
]
=
None
inner_states
:
List
[
Optional
[
Tensor
]]
=
[
x
]
inner_states
:
List
[
Optional
[
Tensor
]]
=
[
x
]
for
idx
,
layer
in
enumerate
(
self
.
layers
):
for
idx
,
layer
in
enumerate
(
self
.
layers
):
if
self
.
history
is
not
None
:
x
=
self
.
history
.
pop
()
if
incremental_state
is
None
and
not
full_context_alignment
:
if
incremental_state
is
None
and
not
full_context_alignment
:
self_attn_mask
=
self
.
buffered_future_mask
(
x
)
self_attn_mask
=
self
.
buffered_future_mask
(
x
)
else
:
else
:
...
@@ -982,6 +998,8 @@ class TransformerDecoder(FairseqIncrementalDecoder):
...
@@ -982,6 +998,8 @@ class TransformerDecoder(FairseqIncrementalDecoder):
inner_states
.
append
(
x
)
inner_states
.
append
(
x
)
if
layer_attn
is
not
None
and
idx
==
alignment_layer
:
if
layer_attn
is
not
None
and
idx
==
alignment_layer
:
attn
=
layer_attn
.
float
()
.
to
(
x
)
attn
=
layer_attn
.
float
()
.
to
(
x
)
if
self
.
history
is
not
None
:
self
.
history
.
add
(
x
)
if
attn
is
not
None
:
if
attn
is
not
None
:
if
alignment_heads
is
not
None
:
if
alignment_heads
is
not
None
:
...
@@ -990,6 +1008,9 @@ class TransformerDecoder(FairseqIncrementalDecoder):
...
@@ -990,6 +1008,9 @@ class TransformerDecoder(FairseqIncrementalDecoder):
# average probabilities over heads
# average probabilities over heads
attn
=
attn
.
mean
(
dim
=
0
)
attn
=
attn
.
mean
(
dim
=
0
)
if
self
.
history
is
not
None
:
x
=
self
.
history
.
pop
()
if
self
.
layer_norm
is
not
None
:
if
self
.
layer_norm
is
not
None
:
x
=
self
.
layer_norm
(
x
)
x
=
self
.
layer_norm
(
x
)
...
...
fairseq/modules/__init__.py
查看文件 @
090afc4d
...
@@ -21,7 +21,7 @@ from .grad_multiply import GradMultiply
...
@@ -21,7 +21,7 @@ from .grad_multiply import GradMultiply
from
.gumbel_vector_quantizer
import
GumbelVectorQuantizer
from
.gumbel_vector_quantizer
import
GumbelVectorQuantizer
from
.kmeans_vector_quantizer
import
KmeansVectorQuantizer
from
.kmeans_vector_quantizer
import
KmeansVectorQuantizer
from
.layer_drop
import
LayerDropModuleList
from
.layer_drop
import
LayerDropModuleList
from
.layer_history
import
CreateLayerHistory
from
.layer_history
import
CreateLayerHistory
,
LearnableDenseLayerHistory
from
.layer_norm
import
Fp32LayerNorm
,
LayerNorm
from
.layer_norm
import
Fp32LayerNorm
,
LayerNorm
from
.learned_positional_embedding
import
LearnedPositionalEmbedding
from
.learned_positional_embedding
import
LearnedPositionalEmbedding
from
.lightweight_convolution
import
LightweightConv
,
LightweightConv1dTBC
from
.lightweight_convolution
import
LightweightConv
,
LightweightConv1dTBC
...
@@ -65,6 +65,7 @@ __all__ = [
...
@@ -65,6 +65,7 @@ __all__ = [
"KmeansVectorQuantizer"
,
"KmeansVectorQuantizer"
,
"LayerDropModuleList"
,
"LayerDropModuleList"
,
"LayerNorm"
,
"LayerNorm"
,
"LearnableDenseLayerHistory"
,
"LearnedPositionalEmbedding"
,
"LearnedPositionalEmbedding"
,
"LightweightConv1dTBC"
,
"LightweightConv1dTBC"
,
"LightweightConv"
,
"LightweightConv"
,
...
...
fairseq/modules/layer_history.py
查看文件 @
090afc4d
...
@@ -7,35 +7,42 @@ import numpy as np
...
@@ -7,35 +7,42 @@ import numpy as np
def
CreateLayerHistory
(
args
,
is_encoder
):
def
CreateLayerHistory
(
args
,
is_encoder
):
history_type
=
args
.
encoder_history_type
if
is_encoder
else
args
.
decoder_history_type
history_type
=
args
.
encoder_history_type
if
is_encoder
else
args
.
decoder_history_type
normalize_before
=
args
.
encoder_normalize_before
if
is_encoder
else
args
.
decoder_normalize_before
layer_num
=
args
.
encoder_layers
if
is_encoder
else
args
.
decoder_layers
dim
=
args
.
encoder_embed_dim
if
is_encoder
else
args
.
decoder_embed_dim
if
history_type
is
None
:
if
history_type
is
None
:
return
None
return
None
elif
history_type
==
"residual"
:
elif
history_type
==
"residual"
:
return
ResidualLayerHistory
(
args
,
is_encoder
)
return
ResidualLayerHistory
(
normalize_before
,
layer_num
,
dim
,
is_encoder
)
elif
history_type
==
"dense"
:
elif
history_type
==
"dense"
:
return
DenseLayerHistory
(
args
,
is_encoder
)
integration_type
=
getattr
(
args
,
'encoder_integration_type'
,
'avg'
)
if
is_encoder
else
\
getattr
(
args
,
'decoder_integration_type'
,
'avg'
)
windows_size
=
getattr
(
args
,
'encoder_windows_size'
,
-
1
)
if
is_encoder
else
\
getattr
(
args
,
'decoder_windows_size'
,
-
1
)
return
DenseLayerHistory
(
normalize_before
,
layer_num
,
dim
,
is_encoder
,
integration_type
,
windows_size
)
elif
history_type
==
"learnable_dense"
:
elif
history_type
==
"learnable_dense"
:
return
LearnableDenseLayerHistory
(
args
,
is_encoder
)
return
LearnableDenseLayerHistory
(
normalize_before
,
layer_num
,
dim
,
is_encoder
)
elif
history_type
==
"learnable_dense_mask"
:
elif
history_type
==
"learnable_dense_mask"
:
return
LearnableDenseMaskLayerHistory
(
args
,
is_encoder
)
return
LearnableDenseMaskLayerHistory
(
normalize_before
,
layer_num
,
dim
,
is_encoder
)
elif
history_type
==
"learnable_dense_nonorm"
:
elif
history_type
==
"learnable_dense_nonorm"
:
return
LearnableDenseNoNormLayerHistory
(
args
,
is_encoder
)
return
LearnableDenseNoNormLayerHistory
(
normalize_before
,
layer_num
,
dim
,
is_encoder
)
elif
history_type
==
"gru"
:
elif
history_type
==
"gru"
:
return
GruLayerHistory
(
args
,
is_encoder
)
return
GruLayerHistory
(
normalize_before
,
layer_num
,
dim
,
is_encoder
)
else
:
else
:
raise
ValueError
raise
ValueError
class
BaseLayerHistory
(
nn
.
Module
):
class
BaseLayerHistory
(
nn
.
Module
):
def
__init__
(
self
,
args
,
is_encoder
):
def
__init__
(
self
,
normalize_before
,
layer_num
,
dim
,
is_encoder
):
super
(
BaseLayerHistory
,
self
)
.
__init__
()
super
(
BaseLayerHistory
,
self
)
.
__init__
()
self
.
is_encoder
=
is_encoder
self
.
is_encoder
=
is_encoder
self
.
normalize_before
=
args
.
encoder_normalize_before
if
is_encoder
else
args
.
decoder_
normalize_before
self
.
normalize_before
=
normalize_before
# the first layer (aka. embedding layer) does not have layer normalization
# the first layer (aka. embedding layer) does not have layer normalization
layers
=
args
.
encoder_layers
if
is_encoder
else
args
.
decoder_layers
self
.
layer_norms
=
nn
.
ModuleList
(
LayerNorm
(
dim
)
for
_
in
range
(
layer_num
))
dim
=
args
.
encoder_embed_dim
if
is_encoder
else
args
.
decoder_embed_dim
self
.
layer_norms
=
nn
.
ModuleList
(
LayerNorm
(
dim
)
for
_
in
range
(
layers
))
def
add
(
self
,
layer
):
def
add
(
self
,
layer
):
raise
NotImplemented
raise
NotImplemented
...
@@ -52,8 +59,8 @@ class ResidualLayerHistory(BaseLayerHistory):
...
@@ -52,8 +59,8 @@ class ResidualLayerHistory(BaseLayerHistory):
x_n = x_{n-1} + y_{n-1}
x_n = x_{n-1} + y_{n-1}
"""
"""
def
__init__
(
self
,
args
,
is_encoder
):
def
__init__
(
self
,
normalize_before
,
layer_num
,
dim
,
is_encoder
):
super
(
ResidualLayerHistory
,
self
)
.
__init__
(
args
,
is_encoder
)
super
(
ResidualLayerHistory
,
self
)
.
__init__
(
normalize_before
,
layer_num
,
dim
,
is_encoder
)
self
.
count
=
0
self
.
count
=
0
self
.
x
=
None
self
.
x
=
None
self
.
y
=
None
self
.
y
=
None
...
@@ -90,19 +97,17 @@ class DenseLayerHistory(BaseLayerHistory):
...
@@ -90,19 +97,17 @@ class DenseLayerHistory(BaseLayerHistory):
x_n = (x_1 + y_1 + y_2 + ... y_{n-1}) / n
x_n = (x_1 + y_1 + y_2 + ... y_{n-1}) / n
"""
"""
def
__init__
(
self
,
args
,
is_encoder
):
def
__init__
(
self
,
normalize_before
,
layer_num
,
dim
,
is_encoder
,
integration_type
,
windows_size
):
super
(
DenseLayerHistory
,
self
)
.
__init__
(
args
,
is_encoder
)
super
(
DenseLayerHistory
,
self
)
.
__init__
(
normalize_before
,
layer_num
,
dim
,
is_encoder
)
self
.
sum
=
None
self
.
sum
=
None
self
.
count
=
0
self
.
count
=
0
self
.
individuals
=
None
# store past individual value, used for windows_size > 0
self
.
individuals
=
None
# store past individual value, used for windows_size > 0
self
.
integration_type
=
getattr
(
args
,
'encoder_integration_type'
,
'avg'
)
if
is_encoder
else
\
self
.
integration_type
=
integration_type
getattr
(
args
,
'decoder_integration_type'
,
'avg'
)
# windows = 1 means not use residual connection
# windows = 1 means not use residual connection
self
.
windows_size
=
getattr
(
args
,
'encoder_windows_size'
,
-
1
)
if
is_encoder
else
\
self
.
windows_size
=
windows_size
getattr
(
args
,
'decoder_windows_size'
,
-
1
)
if
self
.
windows_size
>
0
:
if
self
.
windows_size
>
0
:
assert
self
.
windows_size
<=
(
args
.
encoder_layers
+
1
)
if
is_encoder
else
(
args
.
decoder_layers
+
1
)
assert
self
.
windows_size
<=
1
+
layer_num
self
.
individuals
=
queue
.
Queue
(
self
.
windows_size
)
self
.
individuals
=
queue
.
Queue
(
self
.
windows_size
)
def
add
(
self
,
layer
):
def
add
(
self
,
layer
):
...
@@ -151,13 +156,14 @@ class LearnableDenseLayerHistory(BaseLayerHistory):
...
@@ -151,13 +156,14 @@ class LearnableDenseLayerHistory(BaseLayerHistory):
x_n = (x_1 + y_1 + y_2 + ... y_{n-1}) / n
x_n = (x_1 + y_1 + y_2 + ... y_{n-1}) / n
"""
"""
def
__init__
(
self
,
args
,
is_encoder
):
def
__init__
(
self
,
normalize_before
,
layer_num
,
dim
,
is_encoder
):
super
(
LearnableDenseLayerHistory
,
self
)
.
__init__
(
args
,
is_encoder
)
super
(
LearnableDenseLayerHistory
,
self
)
.
__init__
(
normalize_before
,
layer_num
,
dim
,
is_encoder
)
self
.
sum
=
None
self
.
sum
=
None
self
.
count
=
0
self
.
count
=
0
self
.
layer_num
=
1
+
(
args
.
encoder_layers
if
is_encoder
else
args
.
decoder_layers
)
self
.
layer_num
=
1
+
layer_num
self
.
weight
=
nn
.
Parameter
(
torch
.
Tensor
(
self
.
layer_num
,
self
.
layer_num
)
.
fill_
(
1.0
)
.
tril
())
self
.
weight
=
nn
.
Parameter
(
torch
.
Tensor
(
self
.
layer_num
,
self
.
layer_num
)
.
fill_
(
1.0
)
.
tril
())
self
.
weight
.
data
=
self
.
weight
.
data
/
self
.
weight
.
data
.
sum
(
1
,
keepdim
=
True
)
self
.
weight
.
data
=
self
.
weight
.
data
/
self
.
weight
.
data
.
sum
(
1
,
keepdim
=
True
)
self
.
layers
=
[]
def
extra_repr
(
self
):
def
extra_repr
(
self
):
return
'n_layers={layer_num}, '
.
format
(
**
self
.
__dict__
)
return
'n_layers={layer_num}, '
.
format
(
**
self
.
__dict__
)
...
@@ -198,11 +204,11 @@ class LearnableDenseMaskLayerHistory(BaseLayerHistory):
...
@@ -198,11 +204,11 @@ class LearnableDenseMaskLayerHistory(BaseLayerHistory):
x_n = (x_1 + y_1 + y_2 + ... y_{n-1}) / n
x_n = (x_1 + y_1 + y_2 + ... y_{n-1}) / n
"""
"""
def
__init__
(
self
,
args
,
is_encoder
):
def
__init__
(
self
,
normalize_before
,
layer_num
,
dim
,
is_encoder
):
super
(
LearnableDenseMaskLayerHistory
,
self
)
.
__init__
(
args
,
is_encoder
)
super
(
LearnableDenseMaskLayerHistory
,
self
)
.
__init__
(
normalize_before
,
layer_num
,
dim
,
is_encoder
)
self
.
sum
=
None
self
.
sum
=
None
self
.
count
=
0
self
.
count
=
0
self
.
layer_num
=
1
+
(
args
.
encoder_layers
if
is_encoder
else
args
.
decoder_layers
)
self
.
layer_num
=
1
+
layer_num
if
is_encoder
:
if
is_encoder
:
self
.
weight_mask
=
np
.
loadtxt
(
"encoder_mask.txt"
,
dtype
=
float
,
delimiter
=
' '
)
self
.
weight_mask
=
np
.
loadtxt
(
"encoder_mask.txt"
,
dtype
=
float
,
delimiter
=
' '
)
else
:
else
:
...
@@ -246,11 +252,11 @@ class LearnableDenseNoNormLayerHistory(BaseLayerHistory):
...
@@ -246,11 +252,11 @@ class LearnableDenseNoNormLayerHistory(BaseLayerHistory):
x_n = (x_1 + y_1 + y_2 + ... y_{n-1}) / n
x_n = (x_1 + y_1 + y_2 + ... y_{n-1}) / n
"""
"""
def
__init__
(
self
,
args
,
is_encoder
):
def
__init__
(
self
,
normalize_before
,
layer_num
,
dim
,
is_encoder
):
super
(
LearnableDenseNoNormLayerHistory
,
self
)
.
__init__
(
args
,
is_encoder
)
super
(
LearnableDenseNoNormLayerHistory
,
self
)
.
__init__
(
normalize_before
,
layer_num
,
dim
,
is_encoder
)
self
.
sum
=
None
self
.
sum
=
None
self
.
count
=
0
self
.
count
=
0
self
.
layer_num
=
1
+
(
args
.
encoder_layers
if
is_encoder
else
args
.
decoder_layers
)
self
.
layer_num
=
1
+
layer_num
self
.
weight
=
nn
.
Parameter
(
torch
.
Tensor
(
self
.
layer_num
,
self
.
layer_num
)
.
fill_
(
1.0
)
.
tril
())
self
.
weight
=
nn
.
Parameter
(
torch
.
Tensor
(
self
.
layer_num
,
self
.
layer_num
)
.
fill_
(
1.0
)
.
tril
())
self
.
weight
.
data
=
self
.
weight
.
data
/
self
.
weight
.
data
.
sum
(
1
,
keepdim
=
True
)
self
.
weight
.
data
=
self
.
weight
.
data
/
self
.
weight
.
data
.
sum
(
1
,
keepdim
=
True
)
self
.
layers
=
[]
self
.
layers
=
[]
...
@@ -286,13 +292,13 @@ class GruLayerHistory(BaseLayerHistory):
...
@@ -286,13 +292,13 @@ class GruLayerHistory(BaseLayerHistory):
x_n = (x_1 + y_1 + y_2 + ... y_{n-1}) / n
x_n = (x_1 + y_1 + y_2 + ... y_{n-1}) / n
"""
"""
def
__init__
(
self
,
args
,
is_encoder
):
def
__init__
(
self
,
normalize_before
,
layer_num
,
dim
,
is_encoder
):
super
(
GruLayerHistory
,
self
)
.
__init__
(
args
,
is_encoder
)
super
(
GruLayerHistory
,
self
)
.
__init__
(
normalize_before
,
layer_num
,
dim
,
is_encoder
)
self
.
count
=
0
self
.
count
=
0
self
.
gru
=
nn
.
GRUCell
(
args
.
encoder_embed_dim
,
args
.
encoder_embed_
dim
)
self
.
gru
=
nn
.
GRUCell
(
dim
)
self
.
gru_cells
=
[]
self
.
gru_cells
=
[]
self
.
layer_norms
=
nn
.
ModuleList
(
LayerNorm
(
args
.
encoder_embed_dim
)
for
_
in
range
(
args
.
decoder_layers
+
1
))
self
.
layer_norms
=
nn
.
ModuleList
(
LayerNorm
(
dim
)
for
_
in
range
(
layer_num
+
1
))
self
.
decoder_layers
=
args
.
decoder_layers
self
.
decoder_layers
=
layer_num
def
compute_gru
(
self
,
layer_output
):
def
compute_gru
(
self
,
layer_output
):
if
len
(
self
.
gru_cells
)
==
0
:
if
len
(
self
.
gru_cells
)
==
0
:
...
...
编写
预览
Markdown
格式
0%
重试
或
添加新文件
添加附件
取消
您添加了
0
人
到此讨论。请谨慎行事。
请先完成此评论的编辑!
取消
请
注册
或者
登录
后发表评论