Commit ad1caa72 by xuchen

Fix the implementation of the adapter

parent 9e1dfcf2
...@@ -13,6 +13,7 @@ seed: 1 ...@@ -13,6 +13,7 @@ seed: 1
report-accuracy: True report-accuracy: True
arch: s2t_transformer_s arch: s2t_transformer_s
arch: s2t_sate
share-decoder-input-output-embed: True share-decoder-input-output-embed: True
optimizer: adam optimizer: adam
clip-norm: 10.0 clip-norm: 10.0
...@@ -40,6 +41,8 @@ macaron-style: True ...@@ -40,6 +41,8 @@ macaron-style: True
use-cnn-module: True use-cnn-module: True
cnn-module-kernel: 31 cnn-module-kernel: 31
adapter: subsample
#decoder-embed-dim: 256 #decoder-embed-dim: 256
#decoder-ffn-embed-dim: 2048 #decoder-ffn-embed-dim: 2048
#decoder-attention-heads: 4 #decoder-attention-heads: 4
......
...@@ -122,7 +122,7 @@ if [ ${stage} -le 0 ] && [ ${stop_stage} -ge 0 ]; then ...@@ -122,7 +122,7 @@ if [ ${stage} -le 0 ] && [ ${stop_stage} -ge 0 ]; then
fi fi
source ~/tools/audio/bin/activate source ~/tools/audio/bin/activate
cmd="python ${root_dir}/examples/speech_to_text/prep_mustc_data.py cmd="python ${root_dir}/examples/speech_to_text/prep_asr_data.py
--data-root ${org_data_dir} --data-root ${org_data_dir}
--output-root ${data_dir} --output-root ${data_dir}
--task asr --task asr
...@@ -137,7 +137,7 @@ if [ ${stage} -le 0 ] && [ ${stop_stage} -ge 0 ]; then ...@@ -137,7 +137,7 @@ if [ ${stage} -le 0 ] && [ ${stop_stage} -ge 0 ]; then
asr_prefix=spm_${vocab_type}${asr_vocab_size}_asr asr_prefix=spm_${vocab_type}${asr_vocab_size}_asr
echo "stage 0: ST Data Preparation" echo "stage 0: ST Data Preparation"
cmd="python ${root_dir}/examples/speech_to_text/prep_mustc_data.py cmd="python ${root_dir}/examples/speech_to_text/prep_st_data.py
--data-root ${org_data_dir} --data-root ${org_data_dir}
--output-root ${data_dir} --output-root ${data_dir}
--task st --task st
......
...@@ -136,7 +136,7 @@ class Adapter(nn.Module): ...@@ -136,7 +136,7 @@ class Adapter(nn.Module):
adapter_type = getattr(args, "adapter", "league") adapter_type = getattr(args, "adapter", "league")
self.adapter_type = adapter_type self.adapter_type = adapter_type
if adapter_type in ["linear", "league"]: if adapter_type in ["linear", "league", "gated_league", "gated_league2"]:
self.linear_adapter = nn.Sequential( self.linear_adapter = nn.Sequential(
nn.Linear(attention_dim, attention_dim), nn.Linear(attention_dim, attention_dim),
LayerNorm(args.encoder_embed_dim), LayerNorm(args.encoder_embed_dim),
...@@ -156,7 +156,7 @@ class Adapter(nn.Module): ...@@ -156,7 +156,7 @@ class Adapter(nn.Module):
[int(k) for k in args.conv_kernel_sizes.split(",")], [int(k) for k in args.conv_kernel_sizes.split(",")],
) )
if adapter_type in ["embed", "context", "league", "gated_league"]: if adapter_type in ["embed", "context", "league", "gated_league", "gated_league2"]:
if embed_tokens is None: if embed_tokens is None:
num_embeddings = len(dictionary) num_embeddings = len(dictionary)
self.embed_adapter = Embedding(num_embeddings, attention_dim, self.padding_idx) self.embed_adapter = Embedding(num_embeddings, attention_dim, self.padding_idx)
...@@ -187,7 +187,9 @@ class Adapter(nn.Module): ...@@ -187,7 +187,9 @@ class Adapter(nn.Module):
out = torch.mm(distribution.view(-1, embed_dim), self.embed_adapter.weight).view(batch, seq_len, -1) out = torch.mm(distribution.view(-1, embed_dim), self.embed_adapter.weight).view(batch, seq_len, -1)
elif self.adapter_type == "subsample": elif self.adapter_type == "subsample":
out = self.subsample_adaptor(representation, lengths) representation = representation.transpose(0, 1)
out, input_lengths = self.subsample_adaptor(representation, lengths)
padding = lengths_to_padding_mask(input_lengths)
elif self.adapter_type == "league": elif self.adapter_type == "league":
linear_out = self.linear_adapter(representation) linear_out = self.linear_adapter(representation)
...@@ -195,7 +197,7 @@ class Adapter(nn.Module): ...@@ -195,7 +197,7 @@ class Adapter(nn.Module):
out = linear_out + soft_out out = linear_out + soft_out
elif self.adapter_type == "gated_league": elif self.adapter_type == "gated_league":
linear_out = self.linear_adapter(representation) linear_out = self.linear_adapter(representation)
soft_out = self.embed_adapter(distribution) soft_out = torch.mm(distribution.view(-1, embed_dim), self.embed_adapter.weight).view(batch, seq_len, -1)
coef = (self.gate_linear(torch.cat([linear_out, soft_out], dim=-1))).sigmoid() coef = (self.gate_linear(torch.cat([linear_out, soft_out], dim=-1))).sigmoid()
out = coef * linear_out + (1 - coef) * soft_out out = coef * linear_out + (1 - coef) * soft_out
elif self.adapter_type == "none": elif self.adapter_type == "none":
...@@ -211,7 +213,7 @@ class Adapter(nn.Module): ...@@ -211,7 +213,7 @@ class Adapter(nn.Module):
out = self.dropout_module(out) out = self.dropout_module(out)
return out, positions return out, positions, padding
class TextEncoder(FairseqEncoder): class TextEncoder(FairseqEncoder):
...@@ -301,7 +303,7 @@ class S2TSATEEncoder(FairseqEncoder): ...@@ -301,7 +303,7 @@ class S2TSATEEncoder(FairseqEncoder):
ctc_prob = self.acoustic_encoder.compute_ctc_prob(encoder_out, self.temperature) ctc_prob = self.acoustic_encoder.compute_ctc_prob(encoder_out, self.temperature)
x = (encoder_out, ctc_prob) x = (encoder_out, ctc_prob)
x, positions = self.adapter(x, encoder_padding_mask) x, positions, encoder_padding_mask = self.adapter(x, encoder_padding_mask)
if self.history is not None: if self.history is not None:
acoustic_history = self.acoustic_encoder.history acoustic_history = self.acoustic_encoder.history
......
...@@ -387,6 +387,18 @@ class S2TTransformerEncoder(FairseqEncoder): ...@@ -387,6 +387,18 @@ class S2TTransformerEncoder(FairseqEncoder):
self.layers = nn.ModuleList( self.layers = nn.ModuleList(
[TransformerEncoderLayer(args) for _ in range(args.encoder_layers)] [TransformerEncoderLayer(args) for _ in range(args.encoder_layers)]
) )
# self.inter_subsample = []
# for i in range(args.encoder_layers // 4 - 1):
# self.inter_subsample.append(
# Conv1dSubsampler(
# args.encoder_embed_dim,
# args.encoder_ffn_embed_dim,
# args.encoder_embed_dim,
# [5],
# )
# )
if args.encoder_normalize_before: if args.encoder_normalize_before:
self.layer_norm = LayerNorm(args.encoder_embed_dim) self.layer_norm = LayerNorm(args.encoder_embed_dim)
else: else:
...@@ -437,10 +449,19 @@ class S2TTransformerEncoder(FairseqEncoder): ...@@ -437,10 +449,19 @@ class S2TTransformerEncoder(FairseqEncoder):
if self.history is not None: if self.history is not None:
self.history.add(x) self.history.add(x)
# layer_index = 0
for layer in self.layers: for layer in self.layers:
if self.history is not None: if self.history is not None:
x = self.history.pop() x = self.history.pop()
x = layer(x, encoder_padding_mask, pos_emb=positions) x = layer(x, encoder_padding_mask, pos_emb=positions)
# layer_index += 1
# if layer_index % 4 == 0:
# index = layer_index // 4 - 1
# x = x.transpose(0, 1)
# x, input_lengths = self.inter_subsample[index](x, input_lengths)
# encoder_padding_mask = lengths_to_padding_mask(input_lengths)
if self.history is not None: if self.history is not None:
self.history.add(x) self.history.add(x)
......
Markdown 格式
0%
您添加了 0 人到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 登录 后发表评论