xuchen / Fairseq-S2T · Commits

Commit 80e64569, authored 3 years ago by xuchen
update the pyramid transformer about block fuse
parent 9fadf1f4
Showing 1 changed file with 3 additions and 4 deletions.

fairseq/models/speech_to_text/pys2t_transformer.py (+3, -4)
fairseq/models/speech_to_text/pys2t_transformer.py
@@ -118,7 +118,7 @@ class BlockFuse(nn.Module):
         super().__init__()
         self.conv = nn.Sequential(
-            nn.Conv1d(prev_embed_dim, embed_dim, kernel_size=1, bias=False),
+            nn.Conv1d(prev_embed_dim, embed_dim, kernel_size=1),
             nn.ReLU()
         )
         self.layer_norm = LayerNorm(embed_dim)
@@ -146,7 +146,6 @@ class BlockFuse(nn.Module):
             # x = self.gate(x, state).view(seq_len, bsz, dim)
             coef = (self.gate_linear(torch.cat([x, state], dim=-1))).sigmoid()
             x = coef * x + (1 - coef) * state
-            x = state + x
         else:
             x = x + state
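The hunk above computes a standard sigmoid-gated fusion: a learned coefficient interpolates between the current representation `x` and the `state` carried over from an earlier block. A minimal self-contained sketch of that computation follows; the tensor shapes and the `gate_linear` layer definition are assumptions based on the snippet, not taken from the full file.

```python
import torch
import torch.nn as nn

seq_len, bsz, dim = 10, 2, 8

# Hypothetical stand-in for self.gate_linear: maps the concatenated
# [x; state] features (2 * dim) down to a per-position gate of size dim.
gate_linear = nn.Linear(2 * dim, dim)

x = torch.randn(seq_len, bsz, dim)      # current block output
state = torch.randn(seq_len, bsz, dim)  # state fused from a previous block

# coef lies in (0, 1) and decides, per feature, how much of x vs. state to keep.
coef = gate_linear(torch.cat([x, state], dim=-1)).sigmoid()
fused = coef * x + (1 - coef) * state

assert fused.shape == (seq_len, bsz, dim)
```

With this convex combination in place, an extra residual `x = state + x` would add `state` a second time, which is consistent with that line being the one deleted in this hunk.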
@@ -345,7 +344,7 @@ class PyS2TTransformerEncoder(FairseqEncoder):
             ppm_pre_layer_norm = LayerNorm(embed_dim)
             ppm_post_layer_norm = LayerNorm(self.embed_dim)
             ppm = nn.Sequential(
-                nn.Conv1d(embed_dim, self.embed_dim, kernel_size=1, bias=False),
+                nn.Conv1d(embed_dim, self.embed_dim, kernel_size=1),
                 nn.BatchNorm1d(self.embed_dim),
                 nn.ReLU(),
             )
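The `ppm` block in this hunk is a 1x1 convolution that projects the channel dimension, followed by batch normalization and a ReLU. A short sketch of the same pattern, with example sizes standing in for the real `embed_dim` values from `args`:

```python
import torch
import torch.nn as nn

embed_dim, out_dim = 16, 32  # example sizes; the real values come from args

# Mirrors the ppm block in the hunk: a kernel_size=1 convolution changes
# only the channel count, then BatchNorm1d normalizes each channel.
ppm = nn.Sequential(
    nn.Conv1d(embed_dim, out_dim, kernel_size=1),
    nn.BatchNorm1d(out_dim),
    nn.ReLU(),
)

# Conv1d expects (batch, channels, time); the time axis is untouched.
x = torch.randn(4, embed_dim, 50)
y = ppm(x)
assert y.shape == (4, out_dim, 50)
```

Worth noting: because `BatchNorm1d` has its own learned shift, a bias on the preceding convolution is redundant, which is a common reason to write `bias=False` there; this commit nevertheless switches to the default `bias=True`, so that reading of the change is an inference, not something stated in the commit message.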
@@ -361,7 +360,7 @@ class PyS2TTransformerEncoder(FairseqEncoder):
             setattr(self, f"block_fuse{i + 1}", block_fuse)
             setattr(self, f"ppm{i + 1}", ppm)
             setattr(self, f"ppm_pre_layer_norm{i + 1}", ppm_pre_layer_norm)
-            setattr(self, f"ppm_layer_norm2{i + 1}", ppm_post_layer_norm)
+            setattr(self, f"ppm_post_layer_norm{i + 1}", ppm_post_layer_norm)
         if args.encoder_normalize_before:
             self.layer_norm = LayerNorm(self.embed_dim)
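The rename in this hunk (`ppm_layer_norm2{i + 1}` to `ppm_post_layer_norm{i + 1}`) relies on `setattr` with f-string names to register one module per pyramid stage. A minimal sketch of that registration pattern; the `Encoder` class and stage count here are illustrative, not from the file:

```python
import torch.nn as nn

class Encoder(nn.Module):
    def __init__(self, num_stages=3, dim=8):
        super().__init__()
        # Mirrors the setattr pattern in the hunk: per-stage modules are
        # registered under numbered attribute names. nn.Module.__setattr__
        # registers each one as a proper submodule.
        for i in range(num_stages):
            setattr(self, f"ppm_post_layer_norm{i + 1}", nn.LayerNorm(dim))

enc = Encoder()
# Retrieval later uses getattr with the same f-string name, so the name on
# both sides must match exactly -- the motivation for a rename like this one.
ln2 = getattr(enc, "ppm_post_layer_norm2")
assert isinstance(ln2, nn.LayerNorm)
```

An `nn.ModuleList` indexed by stage would achieve the same registration without string-built names, but the `setattr`/`getattr` style keeps each stage's modules addressable by a readable attribute name.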