Commit ad1caa72 by xuchen

fix the implementation of the adapter

parent 9e1dfcf2
......@@ -13,6 +13,7 @@ seed: 1
report-accuracy: True
arch: s2t_transformer_s
arch: s2t_sate
share-decoder-input-output-embed: True
optimizer: adam
clip-norm: 10.0
......@@ -40,6 +41,8 @@ macaron-style: True
use-cnn-module: True
cnn-module-kernel: 31
adapter: subsample
#decoder-embed-dim: 256
#decoder-ffn-embed-dim: 2048
#decoder-attention-heads: 4
......
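These config lines switch the architecture to `s2t_sate` and select the `subsample` adapter. As a side note (not part of the commit), the key has to resolve to `args.adapter`, because the adapter code further down falls back to `"league"` when the attribute is missing; a minimal sketch of that lookup, using an `argparse.Namespace` as a stand-in for the parsed fairseq config:

```python
from argparse import Namespace

# Stand-in for the parsed training config (the real object is fairseq's
# args namespace; the attribute names below mirror the YAML keys).
args = Namespace(arch="s2t_sate", adapter="subsample")

# Same lookup as in Adapter.__init__: a missing or misspelled key would
# silently fall back to the "league" adapter.
adapter_type = getattr(args, "adapter", "league")
assert adapter_type == "subsample"
```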
......@@ -122,7 +122,7 @@ if [ ${stage} -le 0 ] && [ ${stop_stage} -ge 0 ]; then
fi
source ~/tools/audio/bin/activate
cmd="python ${root_dir}/examples/speech_to_text/prep_mustc_data.py
cmd="python ${root_dir}/examples/speech_to_text/prep_asr_data.py
--data-root ${org_data_dir}
--output-root ${data_dir}
--task asr
......@@ -137,7 +137,7 @@ if [ ${stage} -le 0 ] && [ ${stop_stage} -ge 0 ]; then
asr_prefix=spm_${vocab_type}${asr_vocab_size}_asr
echo "stage 0: ST Data Preparation"
cmd="python ${root_dir}/examples/speech_to_text/prep_mustc_data.py
cmd="python ${root_dir}/examples/speech_to_text/prep_st_data.py
--data-root ${org_data_dir}
--output-root ${data_dir}
--task st
......
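The ASR and ST preparation steps now call two dedicated scripts, `prep_asr_data.py` and `prep_st_data.py`, instead of the shared `prep_mustc_data.py`. For illustration only, a sketch of the same two invocations driven from Python (the flag set is copied from the command strings above; all paths are placeholders):

```python
import subprocess

# Placeholder paths; the run script derives these from its own variables.
root_dir = "/path/to/this/repo"
org_data_dir = "/path/to/raw/data"
data_dir = "/path/to/prepared/data"

# One dedicated preparation script per task, as in the updated run script.
for script, task in [("prep_asr_data.py", "asr"), ("prep_st_data.py", "st")]:
    subprocess.run(
        [
            "python", f"{root_dir}/examples/speech_to_text/{script}",
            "--data-root", org_data_dir,
            "--output-root", data_dir,
            "--task", task,
        ],
        check=True,
    )
```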
......@@ -136,7 +136,7 @@ class Adapter(nn.Module):
adapter_type = getattr(args, "adapter", "league")
self.adapter_type = adapter_type
if adapter_type in ["linear", "league"]:
if adapter_type in ["linear", "league", "gated_league", "gated_league2"]:
self.linear_adapter = nn.Sequential(
nn.Linear(attention_dim, attention_dim),
LayerNorm(args.encoder_embed_dim),
......@@ -156,7 +156,7 @@ class Adapter(nn.Module):
[int(k) for k in args.conv_kernel_sizes.split(",")],
)
if adapter_type in ["embed", "context", "league", "gated_league"]:
if adapter_type in ["embed", "context", "league", "gated_league", "gated_league2"]:
if embed_tokens is None:
num_embeddings = len(dictionary)
self.embed_adapter = Embedding(num_embeddings, attention_dim, self.padding_idx)
......@@ -187,7 +187,9 @@ class Adapter(nn.Module):
out = torch.mm(distribution.view(-1, embed_dim), self.embed_adapter.weight).view(batch, seq_len, -1)
elif self.adapter_type == "subsample":
out = self.subsample_adaptor(representation, lengths)
representation = representation.transpose(0, 1)
out, input_lengths = self.subsample_adaptor(representation, lengths)
padding = lengths_to_padding_mask(input_lengths)
elif self.adapter_type == "league":
linear_out = self.linear_adapter(representation)
......@@ -195,7 +197,7 @@ class Adapter(nn.Module):
out = linear_out + soft_out
elif self.adapter_type == "gated_league":
linear_out = self.linear_adapter(representation)
soft_out = self.embed_adapter(distribution)
soft_out = torch.mm(distribution.view(-1, embed_dim), self.embed_adapter.weight).view(batch, seq_len, -1)
coef = (self.gate_linear(torch.cat([linear_out, soft_out], dim=-1))).sigmoid()
out = coef * linear_out + (1 - coef) * soft_out
elif self.adapter_type == "none":
......@@ -211,7 +213,7 @@ class Adapter(nn.Module):
out = self.dropout_module(out)
return out, positions
return out, positions, padding
class TextEncoder(FairseqEncoder):
......@@ -301,7 +303,7 @@ class S2TSATEEncoder(FairseqEncoder):
ctc_prob = self.acoustic_encoder.compute_ctc_prob(encoder_out, self.temperature)
x = (encoder_out, ctc_prob)
x, positions = self.adapter(x, encoder_padding_mask)
x, positions, encoder_padding_mask = self.adapter(x, encoder_padding_mask)
if self.history is not None:
acoustic_history = self.acoustic_encoder.history
......
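The core of the fix is in `Adapter.forward`: the `subsample` branch now transposes the `T x B x C` representation to batch-first before the convolutional subsampler, keeps the subsampled lengths, and rebuilds the padding mask from them; `forward` then returns that mask as a third value so `S2TSATEEncoder` can use it as the new `encoder_padding_mask`. A standalone sketch of the subsample path (import paths and toy dimensions are assumptions; the calls themselves follow the diff):

```python
import torch
from fairseq.data.data_utils import lengths_to_padding_mask
from fairseq.models.speech_to_text.s2t_transformer import Conv1dSubsampler

# Toy dimensions, for illustration only.
seq_len, batch, dim = 20, 2, 256

# A single conv layer with kernel 5 and stride 2; the real adapter builds
# the kernel list from args.conv_kernel_sizes.
subsample_adaptor = Conv1dSubsampler(dim, 2 * dim, dim, [5])

representation = torch.randn(seq_len, batch, dim)  # T x B x C, encoder layout
lengths = torch.tensor([20, 15])                   # valid frames per utterance

# The fixed branch: batch-first input for the subsampler, then a padding
# mask rebuilt from the *subsampled* lengths rather than the original ones.
out, input_lengths = subsample_adaptor(representation.transpose(0, 1), lengths)
padding = lengths_to_padding_mask(input_lengths)

print(out.shape)      # torch.Size([10, 2, 256]) -> back to T' x B x C
print(padding.shape)  # torch.Size([2, 10])
```

The `gated_league` branch gets the matching fix: its soft representation is now computed with the same `torch.mm` projection onto `self.embed_adapter.weight` used a few lines above, instead of calling the embedding layer on a soft distribution, and `gated_league2` is added to the branches that build the linear and embedding sub-adapters.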
......@@ -387,6 +387,18 @@ class S2TTransformerEncoder(FairseqEncoder):
self.layers = nn.ModuleList(
[TransformerEncoderLayer(args) for _ in range(args.encoder_layers)]
)
# self.inter_subsample = []
# for i in range(args.encoder_layers // 4 - 1):
# self.inter_subsample.append(
# Conv1dSubsampler(
# args.encoder_embed_dim,
# args.encoder_ffn_embed_dim,
# args.encoder_embed_dim,
# [5],
# )
# )
if args.encoder_normalize_before:
self.layer_norm = LayerNorm(args.encoder_embed_dim)
else:
......@@ -437,10 +449,19 @@ class S2TTransformerEncoder(FairseqEncoder):
if self.history is not None:
self.history.add(x)
# layer_index = 0
for layer in self.layers:
if self.history is not None:
x = self.history.pop()
x = layer(x, encoder_padding_mask, pos_emb=positions)
# layer_index += 1
# if layer_index % 4 == 0:
# index = layer_index // 4 - 1
# x = x.transpose(0, 1)
# x, input_lengths = self.inter_subsample[index](x, input_lengths)
# encoder_padding_mask = lengths_to_padding_mask(input_lengths)
if self.history is not None:
self.history.add(x)
......
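The commented-out blocks in `S2TTransformerEncoder` sketch a related experiment: insert a `Conv1dSubsampler` after every fourth encoder layer and shrink the padding mask along the way. A self-contained toy version of that idea (layer internals omitted; the `nn.ModuleList`, the bounds check, and the default dimensions are assumptions added here so the sketch runs; the constructor arguments mirror the commented code):

```python
import torch
import torch.nn as nn
from fairseq.data.data_utils import lengths_to_padding_mask
from fairseq.models.speech_to_text.s2t_transformer import Conv1dSubsampler


class InterSubsampleSketch(nn.Module):
    """Toy stand-in for the commented-out idea: downsample the sequence
    after every fourth encoder layer and rebuild the padding mask from the
    new lengths. The transformer layers themselves are omitted."""

    def __init__(self, embed_dim=256, ffn_dim=2048, num_layers=12):
        super().__init__()
        # One subsampler per block of four layers, except the last block,
        # matching range(args.encoder_layers // 4 - 1) in the commented code.
        self.inter_subsample = nn.ModuleList(
            Conv1dSubsampler(embed_dim, ffn_dim, embed_dim, [5])
            for _ in range(num_layers // 4 - 1)
        )
        self.num_layers = num_layers

    def forward(self, x, input_lengths):
        # x: T x B x C, the layout used inside S2TTransformerEncoder.
        encoder_padding_mask = lengths_to_padding_mask(input_lengths)
        for layer_index in range(1, self.num_layers + 1):
            # ... a TransformerEncoderLayer would process x here ...
            if layer_index % 4 == 0 and layer_index // 4 - 1 < len(self.inter_subsample):
                x = x.transpose(0, 1)  # B x T x C for the conv subsampler
                x, input_lengths = self.inter_subsample[layer_index // 4 - 1](
                    x, input_lengths
                )
                encoder_padding_mask = lengths_to_padding_mask(input_lengths)
        return x, encoder_padding_mask


# Quick shape check: 40 frames, two subsampling points -> roughly 40 / 4 frames.
sketch = InterSubsampleSketch()
out, mask = sketch(torch.randn(40, 2, 256), torch.tensor([40, 32]))
print(out.shape, mask.shape)
```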