Fairseq-S2T / Commits / cbeb5521

Commit cbeb5521, authored Sep 15, 2022 by xuchen:
update the mixup implementation
Parent commit: c7242ff4
Showing 6 changed files with 62 additions and 63 deletions:
```
egs/mustc/st/conf/mixup.yaml                                   +3   -3
fairseq/criterions/label_smoothed_cross_entropy.py             +3   -0
fairseq/criterions/label_smoothed_cross_entropy_with_ctc.py    +13  -3
fairseq/models/speech_to_text/s2t_transformer.py               +10  -24
fairseq/models/transformer.py                                  +1   -0
fairseq/models/transformer_s2.py                               +32  -33
```
egs/mustc/st/conf/mixup.yaml

```diff
@@ -3,6 +3,6 @@ inter-mixup-layer: -1
 inter-mixup-prob: 1.0
 inter-mixup-ratio: 1.0
 inter-mixup-beta: 0.5
-inter-mixup-keep-org: True
-ctc-mixup-consistent-weight: 1
-mixup-consistent-weight: 1
+inter-mixup-keep-org: False
+ctc-mixup-consistent-weight: 0
+mixup-consistent-weight: 0
```
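These new defaults switch mixup to replace the batch with mixed examples (inter-mixup-keep-org: False) and disable both consistency losses. For orientation, a minimal sketch of what the sampling keys control; the mapping from keys to code is an assumption based on the encoder changes further down:

```python
import torch
from torch.distributions import Beta

# inter-mixup-beta parameterizes a symmetric Beta(alpha, alpha) distribution;
# each mixed pair draws its own interpolation coefficient from it.
alpha = 0.5                           # inter-mixup-beta
lam = Beta(alpha, alpha).sample([8])  # one coefficient per mixed pair

# inter-mixup-ratio sets how many mixed examples are built relative to the
# batch size; inter-mixup-keep-org chooses between appending them to the
# original examples (True) or training on the mixed examples alone (False).
```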
fairseq/criterions/label_smoothed_cross_entropy.py

```diff
@@ -203,6 +203,9 @@ class LabelSmoothedCrossEntropyCriterion(FairseqCriterion):
         else:
             target = target.view(-1)
+        if lprobs.size(0) == 0:
+            return torch.Tensor([0]), torch.Tensor([0])
+
         mask = target.ne(self.padding_idx)
         n_correct = torch.sum(
             lprobs.argmax(1).masked_select(mask).eq(target.masked_select(mask))
```
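The added guard covers batches where no scorable rows remain (possible once mixup reshapes the batch), returning zero counts instead of scoring an empty tensor. A minimal sketch of the surrounding accuracy computation, assuming fairseq's default padding index of 1:

```python
import torch

def compute_accuracy(lprobs, target, padding_idx=1):
    # lprobs: (N, vocab) log-probabilities; target: (N,) gold indices.
    target = target.view(-1)
    if lprobs.size(0) == 0:  # the new guard: nothing left to score
        return torch.Tensor([0]), torch.Tensor([0])
    mask = target.ne(padding_idx)
    n_correct = torch.sum(
        lprobs.argmax(1).masked_select(mask).eq(target.masked_select(mask))
    )
    total = torch.sum(mask)
    return n_correct, total
```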
fairseq/criterions/label_smoothed_cross_entropy_with_ctc.py

```diff
@@ -79,9 +79,19 @@ class LabelSmoothedCrossEntropyCriterionWithCTC(
         n_sentences = sample["target"].size(0)

         if "mixup" in encoder_out and encoder_out["mixup"] is not None:
-            sample_size //= net_output[0].size(0) if self.sentence_avg else encoder_out["mixup"]["ratio"]
-            n_tokens //= encoder_out["mixup"]["ratio"]
-            n_sentences //= net_output[0].size(0)
+            mixup = encoder_out["mixup"]
+            ratio = mixup["ratio"]
+            if mixup["keep_org"]:
+                n_tokens = int(sample_size * (1 + ratio))
+            else:
+                n_tokens = int(sample_size * ratio)
+            if self.sentence_avg:
+                sample_size = net_output[0].size(0)
+            else:
+                sample_size = n_tokens
+            n_sentences = net_output[0].size(0)

         logging_output = {
             "trans_loss": utils.item(loss.data) if reduce else loss.data,
```
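The rewritten branch replaces the old integer divisions with explicit bookkeeping: with keep_org the batch holds the original examples plus ratio times as many mixed ones, so the token count scales by (1 + ratio); without it, by ratio alone. A self-contained sketch of that arithmetic (not the criterion's literal code; batch_out stands for net_output[0].size(0)):

```python
def mixup_counts(sample_size, batch_out, ratio, keep_org, sentence_avg):
    """Effective counts for loss normalization under mixup."""
    # Token count grows with the number of mixed examples.
    n_tokens = int(sample_size * (1 + ratio)) if keep_org else int(sample_size * ratio)
    # Normalizer: sentences if sentence_avg, otherwise tokens.
    sample_size = batch_out if sentence_avg else n_tokens
    # The sentence count is whatever the encoder actually emitted.
    n_sentences = batch_out
    return n_tokens, sample_size, n_sentences

# ratio=1.0, keep_org=True doubles the token count (originals are kept);
# ratio=1.0, keep_org=False leaves it unchanged (originals are replaced).
print(mixup_counts(1000, 16, 1.0, True, False))   # (2000, 2000, 16)
print(mixup_counts(1000, 8, 1.0, False, False))   # (1000, 1000, 8)
```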
fairseq/models/speech_to_text/s2t_transformer.py

```diff
@@ -825,27 +825,6 @@ class S2TTransformerEncoder(FairseqEncoder):
         batch = x.size(1)
         org_indices = np.arange(batch)

-        # indices = np.random.permutation(batch)
-        # if self.mixup_ratio == 1:
-        #     if len(indices) % 2 != 0:
-        #         indices = np.append(indices, (indices[-1]))
-        #     idx1 = indices[0::2]
-        #     idx2 = indices[1::2]
-        #
-        #     if self.mixup_keep_org:
-        #         idx1 = np.append(org_indices, idx1)
-        #         idx2 = np.append(org_indices, idx2)
-        #
-        # else:
-        #     mix_size = int(max(2, batch * self.mixup_ratio // 2 * 2))
-        #     mix_indices = indices[: mix_size]
-        #     if self.mixup_keep_org:
-        #         idx1 = np.append(org_indices, mix_indices[0::2])
-        #         idx2 = np.append(org_indices, mix_indices[1::2])
-        #     else:
-        #         idx1 = np.append(mix_indices[0::2], (indices[mix_size:]))
-        #         idx2 = np.append(mix_indices[1::2], (indices[mix_size:]))
-
         mixup_size = int(batch * self.mixup_ratio)
         if mixup_size <= batch:
             mixup_index1 = np.random.permutation(mixup_size)
@@ -853,6 +832,7 @@ class S2TTransformerEncoder(FairseqEncoder):
         else:
             mixup_index1 = np.random.randint(0, batch, mixup_size)
             mixup_index2 = np.random.randint(0, batch, mixup_size)

         if self.mixup_keep_org:
             idx1 = np.append(org_indices, mixup_index1)
             idx2 = np.append(org_indices, mixup_index2)
@@ -864,15 +844,16 @@ class S2TTransformerEncoder(FairseqEncoder):
             idx1 = np.append(keep_indices, mixup_index1)
             idx2 = np.append(keep_indices, mixup_index2)

-        idx1 = torch.from_numpy(idx1).to(x.device)
-        idx2 = torch.from_numpy(idx2).to(x.device)
+        idx1 = torch.from_numpy(idx1).to(x.device).long()
+        idx2 = torch.from_numpy(idx2).to(x.device).long()

         x1 = x[:, idx1]
         x2 = x[:, idx2]

         coef = self.beta.sample([len(idx1)]).to(x.device).type_as(x).view(-1)
         mixup_coef = coef.view(1, -1, 1)
-        x = (mixup_coef * x1 + (1 - mixup_coef) * x2)
+        x = mixup_coef * x1 + (1 - mixup_coef) * x2
+        x = x.contiguous()

         pad1 = encoder_padding_mask[idx1]
         pad2 = encoder_padding_mask[idx2]
@@ -881,6 +862,7 @@ class S2TTransformerEncoder(FairseqEncoder):
         mixup = {
             "ratio": self.mixup_ratio,
+            "keep_org": self.mixup_keep_org,
             "coef": coef,
             "index1": idx1,
             "index2": idx2,
@@ -1046,6 +1028,10 @@ class S2TTransformerEncoder(FairseqEncoder):
             x = self.layer_norm(x)
             self.show_debug(x, "x after encoding layer norm")

+        if self.training and self.mixup and layer_idx == mixup_layer:
+            if torch.rand(1) < self.mixup_prob:
+                x, encoder_padding_mask, input_lengths, mixup = self.apply_mixup(x, encoder_padding_mask)
+
         if self.use_ctc and ctc_logit is None:
             ctc_logit = self.ctc(x, encoder_padding_mask, "Source output", is_top=True)
         self.show_debug(x, "x after ctc")
```
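Read together, the hunks simplify apply_mixup to: draw two index vectors, optionally prepend the original batch indices, and interpolate hidden states with a Beta-distributed coefficient. The sketch below reassembles that logic as a standalone function; the permutation for mixup_index2 and the combination of the two padding masks are not visible in the hunks, so those lines are assumptions:

```python
import numpy as np
import torch
from torch.distributions import Beta

def apply_mixup(x, encoder_padding_mask, mixup_ratio=1.0, alpha=0.5, keep_org=False):
    """Mix encoder states x (T x B x C) across the batch, following the diff."""
    batch = x.size(1)
    org_indices = np.arange(batch)

    mixup_size = int(batch * mixup_ratio)
    if mixup_size <= batch:
        mixup_index1 = np.random.permutation(mixup_size)
        mixup_index2 = np.random.permutation(mixup_size)  # assumed to mirror index1
    else:
        mixup_index1 = np.random.randint(0, batch, mixup_size)
        mixup_index2 = np.random.randint(0, batch, mixup_size)

    if keep_org:
        idx1 = np.append(org_indices, mixup_index1)
        idx2 = np.append(org_indices, mixup_index2)
    else:
        idx1, idx2 = mixup_index1, mixup_index2

    idx1 = torch.from_numpy(idx1).to(x.device).long()
    idx2 = torch.from_numpy(idx2).to(x.device).long()

    coef = Beta(alpha, alpha).sample([len(idx1)]).to(x.device).type_as(x).view(-1)
    mixup_coef = coef.view(1, -1, 1)
    x = (mixup_coef * x[:, idx1] + (1 - mixup_coef) * x[:, idx2]).contiguous()

    # Assumption: a mixed frame is padding only where both sources are padding.
    encoder_padding_mask = encoder_padding_mask[idx1] & encoder_padding_mask[idx2]

    mixup = {
        "ratio": mixup_ratio,
        "keep_org": keep_org,
        "coef": coef,
        "index1": idx1,
        "index2": idx2,
    }
    return x, encoder_padding_mask, mixup
```

The .long() casts added in the diff guard against index arrays arriving as a narrower integer dtype (np.random.randint can yield int32 on some platforms), since PyTorch expects long tensors for indexing.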
fairseq/models/transformer.py

```diff
@@ -1060,6 +1060,7 @@ class TransformerDecoder(FairseqIncrementalDecoder):
             x2 = x[:, idx2]
             mixup_coef = coef.view(1, -1, 1)
             x = mixup_coef * x1 + (1 - mixup_coef) * x2
+            x = x.contiguous()

             if self_attn_padding_mask is not None:
                 pad1 = self_attn_padding_mask[idx1]
```
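The decoder-side change mirrors the encoder: the interpolated states are made contiguous before later reshapes. Tensor.view requires a contiguous memory layout, and .contiguous() is a no-op when the layout is already flat, so the call is cheap insurance:

```python
import torch

t = torch.arange(6).view(2, 3).t()  # transposing makes t non-contiguous
# t.view(-1) would raise a RuntimeError here; copying into a contiguous
# layout first makes the reshape valid.
flat = t.contiguous().view(-1)
print(flat)  # tensor([0, 3, 1, 4, 2, 5])
```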
fairseq/models/transformer_s2.py

```diff
@@ -36,7 +36,6 @@ from fairseq.modules.checkpoint_activations import checkpoint_wrapper
 from fairseq.modules.quant_noise import quant_noise as apply_quant_noise_
 from torch import Tensor

 DEFAULT_MAX_SOURCE_POSITIONS = 1024
 DEFAULT_MAX_TARGET_POSITIONS = 1024
@@ -66,13 +65,13 @@ class TransformerS2Encoder(TransformerEncoder):
         return layer

     def forward(
         self,
         src_tokens,
         src_lengths: Optional[torch.Tensor] = None,
-        x2=None,
-        x2_encoder_padding_mask=None,
+        s2=None,
+        s2_encoder_padding_mask=None,
         return_all_hiddens: bool = False,
         token_embeddings: Optional[torch.Tensor] = None,
     ):
         """
         Args:
@@ -99,8 +98,8 @@ class TransformerS2Encoder(TransformerEncoder):
         """
         return self.forward_scriptable(src_tokens,
                                        src_lengths,
-                                       x2,
-                                       x2_encoder_padding_mask,
+                                       s2,
+                                       s2_encoder_padding_mask,
                                        return_all_hiddens,
                                        token_embeddings)
@@ -109,13 +108,13 @@ class TransformerS2Encoder(TransformerEncoder):
     # Current workaround is to add a helper function with different name and
     # call the helper function from scriptable Subclass.
     def forward_scriptable(
         self,
         src_tokens,
         src_lengths: Optional[torch.Tensor] = None,
-        x2=None,
-        x2_encoder_padding_mask=None,
+        s2=None,
+        s2_encoder_padding_mask=None,
         return_all_hiddens: bool = False,
         token_embeddings: Optional[torch.Tensor] = None,
     ):
         """
         Args:
@@ -172,7 +171,7 @@ class TransformerS2Encoder(TransformerEncoder):
                 x = layer(
                     x, encoder_padding_mask=encoder_padding_mask if has_pads else None,
-                    x2=x2, x2_encoder_padding_mask=x2_encoder_padding_mask,
+                    s2=s2, s2_encoder_padding_mask=s2_encoder_padding_mask,
                 )
             if return_all_hiddens:
                 assert encoder_states is not None
@@ -194,8 +193,8 @@ class TransformerS2Encoder(TransformerEncoder):
         return {
             "encoder_out": [x],  # T x B x C
             "encoder_padding_mask": [encoder_padding_mask],  # B x T
-            "encoder_out_s2": [x2],  # T x B x C
-            "encoder_padding_mask_s2": [x2_encoder_padding_mask],  # B x T
+            "s2_encoder_out": [s2],  # T x B x C
+            "s2_encoder_padding_mask": [s2_encoder_padding_mask],  # B x T
             "encoder_embedding": [encoder_embedding],  # B x T x C
             "encoder_states": encoder_states,  # List[T x B x C]
             "src_tokens": [],
@@ -229,13 +228,13 @@ class TransformerS2Decoder(TransformerDecoder):
         return layer

     def extract_features_scriptable(
         self,
         prev_output_tokens,
         encoder_out: Optional[Dict[str, List[Tensor]]],
         incremental_state: Optional[Dict[str, Dict[str, Optional[Tensor]]]] = None,
         full_context_alignment: bool = False,
         alignment_layer: Optional[int] = None,
         alignment_heads: Optional[int] = None,
     ):
         """
         Similar to *forward* but only return features.
@@ -339,12 +338,13 @@ class TransformerS2Decoder(TransformerDecoder):
                 else None,
                 encoder_out["encoder_padding_mask"][0]
                 if (
                     encoder_out is not None
                     and len(encoder_out["encoder_padding_mask"]) > 0
                 )
                 else None,
-                encoder_out_s2=encoder_out["s2_encoder_out"][0],
-                encoder_padding_mask_s2=encoder_out["s2_encoder_padding_mask"][0],
+                encoder_out_s2=encoder_out["s2_encoder_out"][0] if len(encoder_out["s2_encoder_out"]) > 0 else None,
+                encoder_padding_mask_s2=encoder_out["s2_encoder_padding_mask"][0] if len(encoder_out["s2_encoder_padding_mask"]) > 0 else None,
                 incremental_state=incremental_state,
                 self_attn_mask=self_attn_mask,
                 self_attn_padding_mask=self_attn_padding_mask,
@@ -411,4 +411,4 @@ def Embedding(num_embeddings, embedding_dim, padding_idx):
     m = nn.Embedding(num_embeddings, embedding_dim, padding_idx=padding_idx)
     nn.init.normal_(m.weight, mean=0, std=embedding_dim ** -0.5)
     nn.init.constant_(m.weight[padding_idx], 0)
-    return m
+    return m
\ No newline at end of file
```
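Beyond the x2 to s2 rename, the decoder call site now guards the renamed keys the same way fairseq guards encoder_padding_mask: index [0] only when the list is non-empty. A small sketch of the pattern, with a hypothetical encoder_out dict:

```python
from typing import Dict, List
import torch

# fairseq encoder outputs use List[Tensor] values so that an absent output
# can be represented by an empty list (a TorchScript-friendly Optional).
encoder_out: Dict[str, List[torch.Tensor]] = {"s2_encoder_out": []}

s2 = (
    encoder_out["s2_encoder_out"][0]
    if len(encoder_out["s2_encoder_out"]) > 0
    else None
)
print(s2)  # None when the second stream is absent
```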