Commit d3bef363 by xuchen

adaptive softmax bug fix: pass `sample` through to `get_normalized_probs` so the
adaptive softmax can see the gold targets, register the missing
--adaptive-softmax-* options on the S2T Transformer, add the corresponding
architecture defaults, and raise max_encoder_relative_length in
s2t_transformer_s_relative from 20 to 100

parent d4a68f26
@@ -95,7 +95,7 @@ class LabelSmoothedCrossEntropyCriterion(FairseqCriterion):
         return loss, sample_size, logging_output

     def get_lprobs_and_target(self, model, net_output, sample):
-        lprobs = model.get_normalized_probs(net_output, log_probs=True)
+        lprobs = model.get_normalized_probs(net_output, log_probs=True, sample=sample)
         target = model.get_targets(sample, net_output)
         if self.ignore_prefix_size > 0:
             if getattr(lprobs, "batch_first", False):
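Why this matters: an adaptive softmax splits the vocabulary into a head and several tail clusters, and during training it uses the gold targets to evaluate (and restrict computation to) the relevant clusters; forwarding `sample` here is what lets the decoder hand `sample["target"]` to the adaptive softmax (see the new method in the third hunk). A minimal, self-contained illustration using PyTorch's built-in `torch.nn.AdaptiveLogSoftmaxWithLoss` as a stand-in (fairseq's `AdaptiveSoftmax` differs in detail, but the target dependence is the same):

import torch
import torch.nn as nn

vocab, dim = 10000, 64
asm = nn.AdaptiveLogSoftmaxWithLoss(dim, vocab, cutoffs=[100, 1000])

feats = torch.randn(8, dim)             # decoder output features
target = torch.randint(0, vocab, (8,))  # gold target tokens

out = asm(feats, target)                # the loss needs the targets
print(out.loss)                         # mean NLL over the batch

lprobs = asm.log_prob(feats)            # full-vocabulary log-probs
print(lprobs.shape)                     # torch.Size([8, 10000])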
@@ -220,6 +220,11 @@ class S2TTransformerModel(FairseqEncoderDecoderModel):
             action="store_true",
             help="if True, dont scale embeddings",
         )
+        parser.add_argument('--adaptive-softmax-cutoff', metavar='EXPR',
+                            help='comma separated list of adaptive softmax cutoff points. '
+                                 'Must be used with adaptive_loss criterion')
+        parser.add_argument('--adaptive-softmax-dropout', type=float, metavar='D',
+                            help='sets adaptive softmax dropout for the tail projections')
         parser.add_argument('--max-encoder-relative-length', type=int, default=-1,
                             help='the max relative length')
         parser.add_argument('--max-decoder-relative-length', type=int, default=-1,
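The two new options mirror the ones fairseq's text Transformer models already expose: the cutoff arrives as a comma-separated string and is split into integers downstream. A quick standalone sketch of how they parse (a hypothetical toy parser mirroring the two add_argument calls above; the example values are illustrative):

import argparse

parser = argparse.ArgumentParser()
parser.add_argument('--adaptive-softmax-cutoff', metavar='EXPR',
                    help='comma separated list of adaptive softmax cutoff points. '
                         'Must be used with adaptive_loss criterion')
parser.add_argument('--adaptive-softmax-dropout', type=float, metavar='D', default=0.0,
                    help='sets adaptive softmax dropout for the tail projections')

args = parser.parse_args(['--adaptive-softmax-cutoff', '5000,20000',
                          '--adaptive-softmax-dropout', '0.1'])
cutoffs = [int(c) for c in args.adaptive_softmax_cutoff.split(',')]
print(cutoffs, args.adaptive_softmax_dropout)   # [5000, 20000] 0.1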
@@ -526,6 +531,30 @@ class TransformerDecoderScriptable(TransformerDecoder):
         )
         return x, None

+    def get_normalized_probs_scriptable(
+        self,
+        net_output: Tuple[Tensor, Optional[Dict[str, List[Optional[Tensor]]]]],
+        log_probs: bool,
+        sample: Optional[Dict[str, Tensor]] = None,
+    ):
+        """Get normalized probabilities (or log probs) from a net's output."""
+
+        if hasattr(self, "adaptive_softmax") and self.adaptive_softmax is not None:
+            if sample is not None:
+                assert "target" in sample
+                target = sample["target"]
+            else:
+                target = None
+            out = self.adaptive_softmax.get_log_prob(net_output[0], target=target)
+            return out.exp_() if not log_probs else out
+
+        logits = net_output[0]
+        if log_probs:
+            return utils.log_softmax(logits, dim=-1, onnx_trace=self.onnx_trace)
+        else:
+            return utils.softmax(logits, dim=-1, onnx_trace=self.onnx_trace)
+

 @register_model_architecture(model_name="s2t_transformer", arch_name="s2t_transformer")
 def base_architecture(args):
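In `get_normalized_probs_scriptable`, `net_output[0]` typically holds raw decoder features when an adaptive softmax is attached (the vocabulary projection is deferred to `get_log_prob`) and full-vocabulary logits otherwise. The fallback branch is an ordinary (log-)softmax; `utils.log_softmax` reduces to the plain PyTorch call when `onnx_trace` is off. A quick sanity check of that branch in pure PyTorch:

import torch
import torch.nn.functional as F

logits = torch.randn(2, 5, 1000)        # (batch, tgt_len, vocab)
lprobs = F.log_softmax(logits, dim=-1)
print(lprobs.exp().sum(dim=-1))         # every row sums to ~1.0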
@@ -554,6 +583,10 @@ def base_architecture(args):
     args.activation_fn = getattr(args, "activation_fn", "relu")
     args.adaptive_softmax_cutoff = getattr(args, "adaptive_softmax_cutoff", None)
     args.adaptive_softmax_dropout = getattr(args, "adaptive_softmax_dropout", 0)
+    args.tie_adaptive_weights = getattr(args, "tie_adaptive_weights", False)
+    args.tie_adaptive_proj = getattr(args, "tie_adaptive_proj", False)
+    args.adaptive_softmax_factor = getattr(args, "adaptive_softmax_factor", 4)
     args.share_decoder_input_output_embed = getattr(
         args, "share_decoder_input_output_embed", False
     )
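These defaults matter only when `--adaptive-softmax-cutoff` is set; otherwise the decoder keeps its ordinary output projection. For reference, upstream fairseq's `TransformerDecoder` consumes them roughly as in this sketch (`build_adaptive_softmax` is a hypothetical free-standing helper; the real logic lives inline in the decoder's `__init__`, so check `fairseq.modules.adaptive_softmax` for the exact signature):

from fairseq import utils
from fairseq.modules import AdaptiveSoftmax

def build_adaptive_softmax(args, dictionary, output_embed_dim, embed_tokens=None):
    # Only build an adaptive softmax when cutoffs were given on the CLI.
    if args.adaptive_softmax_cutoff is None:
        return None
    return AdaptiveSoftmax(
        len(dictionary),
        output_embed_dim,
        utils.eval_str_list(args.adaptive_softmax_cutoff, type=int),
        dropout=args.adaptive_softmax_dropout,
        # Share weights with the (adaptive) input embeddings when tying is on.
        adaptive_inputs=embed_tokens if args.tie_adaptive_weights else None,
        factor=args.adaptive_softmax_factor,
        tie_proj=args.tie_adaptive_proj,
    )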
@@ -586,7 +619,7 @@ def s2t_transformer_s(args):
 @register_model_architecture("s2t_transformer", "s2t_transformer_s_relative")
 def s2t_transformer_s_relative(args):
-    args.max_encoder_relative_length = 20
+    args.max_encoder_relative_length = 100
     args.max_decoder_relative_length = 20
     args.k_only = True
     s2t_transformer_s(args)
@@ -1150,6 +1150,10 @@ def base_architecture(args):
     args.dropout = getattr(args, "dropout", 0.1)
     args.adaptive_softmax_cutoff = getattr(args, "adaptive_softmax_cutoff", None)
     args.adaptive_softmax_dropout = getattr(args, "adaptive_softmax_dropout", 0)
+    args.tie_adaptive_weights = getattr(args, "tie_adaptive_weights", False)
+    args.tie_adaptive_proj = getattr(args, "tie_adaptive_proj", False)
+    args.adaptive_softmax_factor = getattr(args, "adaptive_softmax_factor", 4)
     args.share_decoder_input_output_embed = getattr(
         args, "share_decoder_input_output_embed", False
     )