Commit 9e1dfcf2 by xuchen

control the temperature of the adaptor

parent d3bef363
......@@ -60,6 +60,12 @@ class S2TSATEModel(S2TTransformerModel):
help="adapter type",
)
parser.add_argument(
"--temperature",
default=1.0,
type=float,
help="temperature of the CTC softmax",
)
parser.add_argument(
"--acoustic-encoder",
default="transformer",
type=str,
......@@ -181,7 +187,7 @@ class Adapter(nn.Module):
out = torch.mm(distribution.view(-1, embed_dim), self.embed_adapter.weight).view(batch, seq_len, -1)
elif self.adapter_type == "subsample":
out = self.subsample_adaptor(x, lengths)
out = self.subsample_adaptor(representation, lengths)
elif self.adapter_type == "league":
linear_out = self.linear_adapter(representation)
......@@ -192,6 +198,8 @@ class Adapter(nn.Module):
soft_out = self.embed_adapter(distribution)
coef = (self.gate_linear(torch.cat([linear_out, soft_out], dim=-1))).sigmoid()
out = coef * linear_out + (1 - coef) * soft_out
elif self.adapter_type == "none":
out = representation
else:
out = None
logging.error("Unsupported adapter type: {}.".format(self.adapter_type))
......@@ -256,6 +264,7 @@ class S2TSATEEncoder(FairseqEncoder):
logging.error("Unsupported model arch {}!".format(acoustic_encoder_type))
# adapter
self.temperature = getattr(args, "temperature", 1.0)
self.adapter = Adapter(args, task.source_dictionary, embed_tokens)
# self.length_adapter = Conv1dSubsampler(
......@@ -289,7 +298,7 @@ class S2TSATEEncoder(FairseqEncoder):
encoder_padding_mask = acoustic_encoder_out["encoder_padding_mask"][0]
ctc_logit = self.acoustic_encoder.compute_ctc_logit(encoder_out)
ctc_prob = self.acoustic_encoder.compute_ctc_prob(encoder_out)
ctc_prob = self.acoustic_encoder.compute_ctc_prob(encoder_out, self.temperature)
x = (encoder_out, ctc_prob)
x, positions = self.adapter(x, encoder_padding_mask)
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论