Skip to content
项目
群组
代码片段
帮助
当前项目
正在载入...
登录 / 注册
切换导航面板
F
Fairseq-S2T
概览
Overview
Details
Activity
Cycle Analytics
版本库
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Charts
问题
0
Issues
0
列表
Board
标记
里程碑
合并请求
0
Merge Requests
0
CI / CD
CI / CD
流水线
作业
日程表
图表
维基
Wiki
代码片段
Snippets
成员
Collapse sidebar
Close sidebar
活动
图像
聊天
创建新问题
作业
提交
Issue Boards
Open sidebar
xuchen
Fairseq-S2T
Commits
dd402ec2
Commit
dd402ec2
authored
Mar 16, 2021
by
xuchen
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
modify the implementation of the relative position encoding
parent
306dd6fc
隐藏空白字符变更
内嵌
并排
正在显示
4 个修改的文件
包含
81 行增加
和
48 行删除
+81
-48
fairseq/models/speech_to_text/s2t_transformer.py
+5
-2
fairseq/models/transformer.py
+2
-0
fairseq/modules/rel_position_multihead_attention.py
+17
-10
fairseq/modules/transformer_layer.py
+57
-36
没有找到文件。
fairseq/models/speech_to_text/s2t_transformer.py
查看文件 @
dd402ec2
...
@@ -319,6 +319,7 @@ class S2TTransformerEncoder(FairseqEncoder):
...
@@ -319,6 +319,7 @@ class S2TTransformerEncoder(FairseqEncoder):
else
:
else
:
self
.
layer_norm
=
None
self
.
layer_norm
=
None
self
.
attn_type
=
getattr
(
args
,
"encoder_attention_type"
,
"selfattn"
)
self
.
use_ctc
=
(
"ctc"
in
getattr
(
args
,
"criterion"
,
False
))
and
\
self
.
use_ctc
=
(
"ctc"
in
getattr
(
args
,
"criterion"
,
False
))
and
\
(
getattr
(
args
,
"ctc_weight"
,
False
)
>
0
)
(
getattr
(
args
,
"ctc_weight"
,
False
)
>
0
)
if
self
.
use_ctc
:
if
self
.
use_ctc
:
...
@@ -344,11 +345,13 @@ class S2TTransformerEncoder(FairseqEncoder):
...
@@ -344,11 +345,13 @@ class S2TTransformerEncoder(FairseqEncoder):
encoder_padding_mask
=
lengths_to_padding_mask
(
input_lengths
)
encoder_padding_mask
=
lengths_to_padding_mask
(
input_lengths
)
positions
=
self
.
embed_positions
(
encoder_padding_mask
)
.
transpose
(
0
,
1
)
positions
=
self
.
embed_positions
(
encoder_padding_mask
)
.
transpose
(
0
,
1
)
x
+=
positions
if
self
.
attn_type
!=
"rel_selfattn"
:
x
+=
positions
# x += positions
x
=
self
.
dropout_module
(
x
)
x
=
self
.
dropout_module
(
x
)
for
layer
in
self
.
transformer_layers
:
for
layer
in
self
.
transformer_layers
:
x
=
layer
(
x
,
encoder_padding_mask
)
x
=
layer
(
x
,
encoder_padding_mask
,
pos_emb
=
positions
)
if
self
.
layer_norm
is
not
None
:
if
self
.
layer_norm
is
not
None
:
x
=
self
.
layer_norm
(
x
)
x
=
self
.
layer_norm
(
x
)
...
...
fairseq/models/transformer.py
查看文件 @
dd402ec2
...
@@ -685,6 +685,7 @@ class TransformerDecoder(FairseqIncrementalDecoder):
...
@@ -685,6 +685,7 @@ class TransformerDecoder(FairseqIncrementalDecoder):
]
]
)
)
self
.
num_layers
=
len
(
self
.
layers
)
self
.
num_layers
=
len
(
self
.
layers
)
self
.
attn_type
=
getattr
(
args
,
"decoder_attention_type"
,
"selfattn"
)
if
args
.
decoder_normalize_before
and
not
getattr
(
if
args
.
decoder_normalize_before
and
not
getattr
(
args
,
"no_decoder_final_norm"
,
False
args
,
"no_decoder_final_norm"
,
False
...
@@ -892,6 +893,7 @@ class TransformerDecoder(FairseqIncrementalDecoder):
...
@@ -892,6 +893,7 @@ class TransformerDecoder(FairseqIncrementalDecoder):
self_attn_padding_mask
=
self_attn_padding_mask
,
self_attn_padding_mask
=
self_attn_padding_mask
,
need_attn
=
bool
((
idx
==
alignment_layer
)),
need_attn
=
bool
((
idx
==
alignment_layer
)),
need_head_weights
=
bool
((
idx
==
alignment_layer
)),
need_head_weights
=
bool
((
idx
==
alignment_layer
)),
pos_emb
=
positions
)
)
inner_states
.
append
(
x
)
inner_states
.
append
(
x
)
if
layer_attn
is
not
None
and
idx
==
alignment_layer
:
if
layer_attn
is
not
None
and
idx
==
alignment_layer
:
...
...
fairseq/modules/rel_position_multihead_attention.py
查看文件 @
dd402ec2
...
@@ -55,7 +55,7 @@ class RelPositionMultiheadAttention(MultiheadAttention):
...
@@ -55,7 +55,7 @@ class RelPositionMultiheadAttention(MultiheadAttention):
# linear transformation for positional encoding
# linear transformation for positional encoding
self
.
linear_pos
=
quant_noise
(
self
.
linear_pos
=
quant_noise
(
nn
.
Linear
(
embed_dim
,
embed_dim
,
bias
=
False
),
q_noise
,
qn_block_size
nn
.
Linear
(
embed_dim
,
embed_dim
,
bias
=
bias
),
q_noise
,
qn_block_size
)
)
# these two learnable bias are used in matrix c and matrix d
# these two learnable bias are used in matrix c and matrix d
...
@@ -63,7 +63,7 @@ class RelPositionMultiheadAttention(MultiheadAttention):
...
@@ -63,7 +63,7 @@ class RelPositionMultiheadAttention(MultiheadAttention):
self
.
pos_bias_u
=
Parameter
(
torch
.
Tensor
(
self
.
num_heads
,
self
.
head_dim
))
self
.
pos_bias_u
=
Parameter
(
torch
.
Tensor
(
self
.
num_heads
,
self
.
head_dim
))
self
.
pos_bias_v
=
Parameter
(
torch
.
Tensor
(
self
.
num_heads
,
self
.
head_dim
))
self
.
pos_bias_v
=
Parameter
(
torch
.
Tensor
(
self
.
num_heads
,
self
.
head_dim
))
nn
.
init
.
xavier_uniform_
(
self
.
linear_pos
.
weight
)
#
nn.init.xavier_uniform_(self.linear_pos.weight)
nn
.
init
.
xavier_normal_
(
self
.
pos_bias_u
)
nn
.
init
.
xavier_normal_
(
self
.
pos_bias_u
)
nn
.
init
.
xavier_normal_
(
self
.
pos_bias_v
)
nn
.
init
.
xavier_normal_
(
self
.
pos_bias_v
)
...
@@ -109,6 +109,7 @@ class RelPositionMultiheadAttention(MultiheadAttention):
...
@@ -109,6 +109,7 @@ class RelPositionMultiheadAttention(MultiheadAttention):
assert
list
(
query
.
size
())
==
[
tgt_len
,
bsz
,
embed_dim
]
assert
list
(
query
.
size
())
==
[
tgt_len
,
bsz
,
embed_dim
]
if
(
if
(
False
and
not
self
.
onnx_trace
not
self
.
onnx_trace
and
not
is_tpu
# don't use PyTorch version on TPUs
and
not
is_tpu
# don't use PyTorch version on TPUs
and
incremental_state
is
None
and
incremental_state
is
None
...
@@ -196,6 +197,8 @@ class RelPositionMultiheadAttention(MultiheadAttention):
...
@@ -196,6 +197,8 @@ class RelPositionMultiheadAttention(MultiheadAttention):
# .view(tgt_len, bsz * self.num_heads, self.head_dim)
# .view(tgt_len, bsz * self.num_heads, self.head_dim)
# .transpose(0, 1)
# .transpose(0, 1)
# )
# )
# prepare q for RPE # (tgt_len, bsz, num_heads, head_dim)
q
=
q
.
contiguous
()
.
view
(
tgt_len
,
bsz
,
self
.
num_heads
,
self
.
head_dim
)
if
k
is
not
None
:
if
k
is
not
None
:
k
=
(
k
=
(
k
.
contiguous
()
k
.
contiguous
()
...
@@ -279,18 +282,19 @@ class RelPositionMultiheadAttention(MultiheadAttention):
...
@@ -279,18 +282,19 @@ class RelPositionMultiheadAttention(MultiheadAttention):
dim
=
1
,
dim
=
1
,
)
)
pos_emb
=
pos_emb
.
transpose
(
0
,
1
)
p_rep
=
self
.
linear_pos
(
pos_emb
)
.
view
(
bsz
,
-
1
,
self
.
num_heads
,
self
.
head_dim
)
p_rep
=
self
.
linear_pos
(
pos_emb
)
.
view
(
bsz
,
-
1
,
self
.
num_heads
,
self
.
head_dim
)
p_rep
=
p_rep
.
contiguous
()
.
transpose
(
1
,
2
)
.
view
(
bsz
*
self
.
num_heads
,
-
1
,
self
.
head_dim
)
p_rep
=
p_rep
.
transpose
(
1
,
2
)
.
contiguous
(
)
.
view
(
bsz
*
self
.
num_heads
,
-
1
,
self
.
head_dim
)
# (batch * head, time1, d_k)
# (batch * head, time1, d_k)
q_with_bias_u
=
(
q_with_bias_u
=
(
(
q
+
self
.
pos_bias_u
)
.
contiguous
()
(
q
+
self
.
pos_bias_u
)
.
contiguous
()
.
view
(
tgt_len
,
bsz
*
self
.
num_heads
,
self
.
head_dim
)
.
view
(
tgt_len
,
bsz
*
self
.
num_heads
,
self
.
head_dim
)
.
transpose
(
0
,
1
)
.
transpose
(
0
,
1
)
)
)
# (batch * head, time1, d_k)
# (batch * head, time1, d_k)
q_with_bias_v
=
(
q_with_bias_v
=
(
(
q
+
self
.
pos_bias_v
)
.
contiguous
()
(
q
+
self
.
pos_bias_v
)
.
contiguous
()
.
view
(
tgt_len
,
bsz
*
self
.
num_heads
,
self
.
head_dim
)
.
view
(
tgt_len
,
bsz
*
self
.
num_heads
,
self
.
head_dim
)
.
transpose
(
0
,
1
)
.
transpose
(
0
,
1
)
)
)
...
@@ -298,15 +302,15 @@ class RelPositionMultiheadAttention(MultiheadAttention):
...
@@ -298,15 +302,15 @@ class RelPositionMultiheadAttention(MultiheadAttention):
# compute attention score
# compute attention score
# first compute matrix a and matrix c
# first compute matrix a and matrix c
# as described in https://arxiv.org/abs/1901.02860 Section 3.3
# as described in https://arxiv.org/abs/1901.02860 Section 3.3
# (batch
,
head, time1, time2)
# (batch
*
head, time1, time2)
matrix_ac
=
torch
.
bmm
(
q_with_bias_u
,
k
.
transpose
(
1
,
2
))
matrix_ac
=
torch
.
bmm
(
q_with_bias_u
,
k
.
transpose
(
1
,
2
))
# compute matrix b and matrix d
# compute matrix b and matrix d
# (batch
,
head, time1, time2)
# (batch
*
head, time1, time2)
matrix_bd
=
torch
.
bmm
(
q_with_bias_v
,
p_rep
.
transpose
(
1
,
2
))
matrix_bd
=
torch
.
bmm
(
q_with_bias_v
,
p_rep
.
transpose
(
1
,
2
))
def
rel_shift
(
x
,
zero_triu
=
False
):
def
rel_shift
(
x
,
zero_triu
=
False
):
"""Compute relative positinal encoding.
"""Compute relative positi
o
nal encoding.
:param torch.Tensor x: (batch, time, size)
:param torch.Tensor x: (batch, time, size)
:param bool zero_triu: return the lower triangular part of the matrix
:param bool zero_triu: return the lower triangular part of the matrix
...
@@ -323,8 +327,11 @@ class RelPositionMultiheadAttention(MultiheadAttention):
...
@@ -323,8 +327,11 @@ class RelPositionMultiheadAttention(MultiheadAttention):
return
x
return
x
matrix_bd
=
rel_shift
(
matrix_bd
)
matrix_bd
=
matrix_bd
.
contiguous
()
.
view
(
bsz
,
self
.
num_heads
,
matrix_bd
.
size
(
-
2
),
matrix_bd
.
size
(
-
1
))
attn_weights
=
(
matrix_ac
+
matrix_bd
)
/
self
.
scaling
matrix_bd
=
rel_shift
(
matrix_bd
,
)
.
contiguous
()
.
view
(
bsz
*
self
.
num_heads
,
matrix_bd
.
size
(
-
2
),
matrix_bd
.
size
(
-
1
))
attn_weights
=
(
matrix_ac
+
matrix_bd
)
*
self
.
scaling
# attn_weights = torch.bmm(q, k.transpose(1, 2))
# attn_weights = torch.bmm(q, k.transpose(1, 2))
attn_weights
=
self
.
apply_sparse_mask
(
attn_weights
,
tgt_len
,
src_len
,
bsz
)
attn_weights
=
self
.
apply_sparse_mask
(
attn_weights
,
tgt_len
,
src_len
,
bsz
)
...
...
fairseq/modules/transformer_layer.py
查看文件 @
dd402ec2
...
@@ -35,6 +35,7 @@ class TransformerEncoderLayer(nn.Module):
...
@@ -35,6 +35,7 @@ class TransformerEncoderLayer(nn.Module):
self
.
embed_dim
=
args
.
encoder_embed_dim
self
.
embed_dim
=
args
.
encoder_embed_dim
self
.
quant_noise
=
getattr
(
args
,
'quant_noise_pq'
,
0
)
self
.
quant_noise
=
getattr
(
args
,
'quant_noise_pq'
,
0
)
self
.
quant_noise_block_size
=
getattr
(
args
,
'quant_noise_pq_block_size'
,
8
)
or
8
self
.
quant_noise_block_size
=
getattr
(
args
,
'quant_noise_pq_block_size'
,
8
)
or
8
self
.
attn_type
=
getattr
(
args
,
"encoder_attention_type"
,
"selfattn"
)
self
.
self_attn
=
self
.
build_self_attention
(
self
.
embed_dim
,
args
)
self
.
self_attn
=
self
.
build_self_attention
(
self
.
embed_dim
,
args
)
self
.
self_attn_layer_norm
=
LayerNorm
(
self
.
embed_dim
)
self
.
self_attn_layer_norm
=
LayerNorm
(
self
.
embed_dim
)
self
.
dropout_module
=
FairseqDropout
(
self
.
dropout_module
=
FairseqDropout
(
...
@@ -77,13 +78,12 @@ class TransformerEncoderLayer(nn.Module):
...
@@ -77,13 +78,12 @@ class TransformerEncoderLayer(nn.Module):
)
)
def
build_self_attention
(
self
,
embed_dim
,
args
):
def
build_self_attention
(
self
,
embed_dim
,
args
):
attn_type
=
getattr
(
args
,
"encoder_attention_type"
,
"selfattn"
)
if
self
.
attn_type
==
"selfattn"
:
if
attn_type
==
"selfattn"
:
attn_func
=
MultiheadAttention
attn_func
=
MultiheadAttention
elif
attn_type
==
"rel_selfattn"
:
elif
self
.
attn_type
==
"rel_selfattn"
:
attn_func
=
RelPositionMultiheadAttention
attn_func
=
RelPositionMultiheadAttention
else
:
else
:
print
(
"The attention type
%
s is not supported!"
%
attn_type
)
print
(
"The attention type
%
s is not supported!"
%
self
.
attn_type
)
exit
(
1
)
exit
(
1
)
return
attn_func
(
return
attn_func
(
...
@@ -112,7 +112,10 @@ class TransformerEncoderLayer(nn.Module):
...
@@ -112,7 +112,10 @@ class TransformerEncoderLayer(nn.Module):
state_dict
[
"{}.{}.{}"
.
format
(
name
,
new
,
m
)]
=
state_dict
[
k
]
state_dict
[
"{}.{}.{}"
.
format
(
name
,
new
,
m
)]
=
state_dict
[
k
]
del
state_dict
[
k
]
del
state_dict
[
k
]
def
forward
(
self
,
x
,
encoder_padding_mask
:
Optional
[
Tensor
],
attn_mask
:
Optional
[
Tensor
]
=
None
):
def
forward
(
self
,
x
,
encoder_padding_mask
:
Optional
[
Tensor
],
attn_mask
:
Optional
[
Tensor
]
=
None
,
pos_emb
:
Optional
[
Tensor
]
=
None
):
"""
"""
Args:
Args:
x (Tensor): input to the layer of shape `(seq_len, batch, embed_dim)`
x (Tensor): input to the layer of shape `(seq_len, batch, embed_dim)`
...
@@ -124,6 +127,7 @@ class TransformerEncoderLayer(nn.Module):
...
@@ -124,6 +127,7 @@ class TransformerEncoderLayer(nn.Module):
`attn_mask[tgt_i, src_j] = 1` means that when calculating the
`attn_mask[tgt_i, src_j] = 1` means that when calculating the
embedding for `tgt_i`, we exclude (mask out) `src_j`. This is
embedding for `tgt_i`, we exclude (mask out) `src_j`. This is
useful for strided self-attention.
useful for strided self-attention.
positions (Tensor): the position embedding for relative position encoding
Returns:
Returns:
encoded output of shape `(seq_len, batch, embed_dim)`
encoded output of shape `(seq_len, batch, embed_dim)`
...
@@ -139,14 +143,26 @@ class TransformerEncoderLayer(nn.Module):
...
@@ -139,14 +143,26 @@ class TransformerEncoderLayer(nn.Module):
residual
=
x
residual
=
x
if
self
.
normalize_before
:
if
self
.
normalize_before
:
x
=
self
.
self_attn_layer_norm
(
x
)
x
=
self
.
self_attn_layer_norm
(
x
)
x
,
_
=
self
.
self_attn
(
if
self
.
attn_type
==
"rel_selfattn"
:
query
=
x
,
assert
pos_emb
is
not
None
,
"Positions is necessary for RPE!"
key
=
x
,
x
,
_
=
self
.
self_attn
(
value
=
x
,
query
=
x
,
key_padding_mask
=
encoder_padding_mask
,
key
=
x
,
need_weights
=
False
,
value
=
x
,
attn_mask
=
attn_mask
,
key_padding_mask
=
encoder_padding_mask
,
)
need_weights
=
False
,
attn_mask
=
attn_mask
,
pos_emb
=
pos_emb
)
else
:
x
,
_
=
self
.
self_attn
(
query
=
x
,
key
=
x
,
value
=
x
,
key_padding_mask
=
encoder_padding_mask
,
need_weights
=
False
,
attn_mask
=
attn_mask
,
)
x
=
self
.
dropout_module
(
x
)
x
=
self
.
dropout_module
(
x
)
x
=
self
.
residual_connection
(
x
,
residual
)
x
=
self
.
residual_connection
(
x
,
residual
)
if
not
self
.
normalize_before
:
if
not
self
.
normalize_before
:
...
@@ -195,6 +211,7 @@ class TransformerDecoderLayer(nn.Module):
...
@@ -195,6 +211,7 @@ class TransformerDecoderLayer(nn.Module):
self
.
cross_self_attention
=
getattr
(
args
,
"cross_self_attention"
,
False
)
self
.
cross_self_attention
=
getattr
(
args
,
"cross_self_attention"
,
False
)
self
.
attn_type
=
getattr
(
args
,
"decoder_attention_type"
,
"selfattn"
)
self
.
self_attn
=
self
.
build_self_attention
(
self
.
self_attn
=
self
.
build_self_attention
(
self
.
embed_dim
,
self
.
embed_dim
,
args
,
args
,
...
@@ -256,13 +273,12 @@ class TransformerDecoderLayer(nn.Module):
...
@@ -256,13 +273,12 @@ class TransformerDecoderLayer(nn.Module):
def
build_self_attention
(
def
build_self_attention
(
self
,
embed_dim
,
args
,
add_bias_kv
=
False
,
add_zero_attn
=
False
self
,
embed_dim
,
args
,
add_bias_kv
=
False
,
add_zero_attn
=
False
):
):
attn_type
=
getattr
(
args
,
"decoder_attention_type"
,
"selfattn"
)
if
self
.
attn_type
==
"selfattn"
:
if
attn_type
==
"selfattn"
:
attn_func
=
MultiheadAttention
attn_func
=
MultiheadAttention
elif
attn_type
==
"rel_selfattn"
:
elif
self
.
attn_type
==
"rel_selfattn"
:
attn_func
=
RelPositionMultiheadAttention
attn_func
=
RelPositionMultiheadAttention
else
:
else
:
print
(
"The attention type
%
s is not supported!"
%
attn_type
)
print
(
"The attention type
%
s is not supported!"
%
self
.
attn_type
)
exit
(
1
)
exit
(
1
)
return
attn_func
(
return
attn_func
(
...
@@ -277,16 +293,7 @@ class TransformerDecoderLayer(nn.Module):
...
@@ -277,16 +293,7 @@ class TransformerDecoderLayer(nn.Module):
)
)
def
build_encoder_attention
(
self
,
embed_dim
,
args
):
def
build_encoder_attention
(
self
,
embed_dim
,
args
):
attn_type
=
getattr
(
args
,
"decoder_attention_type"
,
"selfattn"
)
return
MultiheadAttention
(
if
attn_type
==
"selfattn"
:
attn_func
=
MultiheadAttention
elif
attn_type
==
"rel_selfattn"
:
attn_func
=
RelPositionMultiheadAttention
else
:
print
(
"The attention type
%
s is not supported!"
%
attn_type
)
exit
(
1
)
return
attn_func
(
embed_dim
,
embed_dim
,
args
.
decoder_attention_heads
,
args
.
decoder_attention_heads
,
kdim
=
getattr
(
args
,
"encoder_embed_dim"
,
None
),
kdim
=
getattr
(
args
,
"encoder_embed_dim"
,
None
),
...
@@ -315,6 +322,7 @@ class TransformerDecoderLayer(nn.Module):
...
@@ -315,6 +322,7 @@ class TransformerDecoderLayer(nn.Module):
self_attn_padding_mask
:
Optional
[
torch
.
Tensor
]
=
None
,
self_attn_padding_mask
:
Optional
[
torch
.
Tensor
]
=
None
,
need_attn
:
bool
=
False
,
need_attn
:
bool
=
False
,
need_head_weights
:
bool
=
False
,
need_head_weights
:
bool
=
False
,
pos_emb
:
Optional
[
Tensor
]
=
None
,
):
):
"""
"""
Args:
Args:
...
@@ -370,15 +378,28 @@ class TransformerDecoderLayer(nn.Module):
...
@@ -370,15 +378,28 @@ class TransformerDecoderLayer(nn.Module):
else
:
else
:
y
=
x
y
=
x
x
,
attn
=
self
.
self_attn
(
if
self
.
attn_type
==
"rel_selfattn"
:
query
=
x
,
assert
pos_emb
is
not
None
,
"Positions is necessary for RPE!"
key
=
y
,
x
,
attn
=
self
.
self_attn
(
value
=
y
,
query
=
x
,
key_padding_mask
=
self_attn_padding_mask
,
key
=
y
,
incremental_state
=
incremental_state
,
value
=
y
,
need_weights
=
False
,
key_padding_mask
=
self_attn_padding_mask
,
attn_mask
=
self_attn_mask
,
incremental_state
=
incremental_state
,
)
need_weights
=
False
,
attn_mask
=
self_attn_mask
,
pos_emb
=
pos_emb
)
else
:
x
,
attn
=
self
.
self_attn
(
query
=
x
,
key
=
y
,
value
=
y
,
key_padding_mask
=
self_attn_padding_mask
,
incremental_state
=
incremental_state
,
need_weights
=
False
,
attn_mask
=
self_attn_mask
,
)
x
=
self
.
dropout_module
(
x
)
x
=
self
.
dropout_module
(
x
)
x
=
self
.
residual_connection
(
x
,
residual
)
x
=
self
.
residual_connection
(
x
,
residual
)
if
not
self
.
normalize_before
:
if
not
self
.
normalize_before
:
...
...
编写
预览
Markdown
格式
0%
重试
或
添加新文件
添加附件
取消
您添加了
0
人
到此讨论。请谨慎行事。
请先完成此评论的编辑!
取消
请
注册
或者
登录
后发表评论