fairseq_lr_scheduler.py 1.98 KB
Newer Older
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

from argparse import Namespace

from fairseq.dataclass.utils import gen_parser_from_dataclass
from fairseq.optim import FairseqOptimizer


class FairseqLRScheduler(object):
    """Base class for learning-rate schedulers.

    The hooks defined here do not modify the learning rate themselves: the
    base class only stores the config and optimizer and remembers the lowest
    validation loss handed to :meth:`step`. That value is the only state
    exposed through :meth:`state_dict` / :meth:`load_state_dict`.
    """

    def __init__(self, cfg, optimizer):
        """Store the scheduler config and the optimizer it drives.

        Args:
            cfg: scheduler configuration, stored unchanged on ``self.cfg``.
            optimizer: a ``FairseqOptimizer`` instance, or ``None``.

        Raises:
            ValueError: if ``optimizer`` is neither ``None`` nor a
                ``FairseqOptimizer``.
        """
        super().__init__()
        # ``None`` is accepted; the isinstance check only runs for non-None.
        optimizer_ok = optimizer is None or isinstance(
            optimizer, FairseqOptimizer
        )
        if not optimizer_ok:
            raise ValueError("optimizer must be an instance of FairseqOptimizer")
        self.cfg = cfg
        self.optimizer = optimizer
        # Lowest validation loss seen so far; stays None until ``step``
        # receives a non-None loss.
        self.best = None

    @classmethod
    def add_args(cls, parser):
        """Add this scheduler's command-line arguments to ``parser``.

        If the class carries a ``__dataclass`` attribute, an instance of it
        is passed along with ``parser`` to ``gen_parser_from_dataclass``;
        otherwise nothing happens.
        """
        # Looking the attribute up by string avoids private-name mangling,
        # which would turn ``cls.__dataclass`` written inside this class body
        # into ``cls._FairseqLRScheduler__dataclass``.
        config_dataclass = getattr(cls, "__dataclass", None)
        if config_dataclass is None:
            return
        gen_parser_from_dataclass(parser, config_dataclass())

    def state_dict(self):
        """Return the scheduler state: a dict with a single ``"best"`` key."""
        return dict(best=self.best)

    def load_state_dict(self, state_dict):
        """Restore state previously produced by :meth:`state_dict`.

        Raises:
            KeyError: if ``state_dict`` has no ``"best"`` entry.
        """
        restored_best = state_dict["best"]
        self.best = restored_best

    def step_begin_epoch(self, epoch):
        """Hook for the beginning of ``epoch``; a no-op in the base class."""
        return None

    def step(self, epoch, val_loss=None):
        """End-of-epoch hook: fold ``val_loss`` into the running minimum.

        Args:
            epoch: the epoch that just finished (unused by the base class).
            val_loss: validation loss for that epoch; ``None`` leaves
                ``self.best`` untouched.
        """
        if val_loss is None:
            return
        # ``min`` returns its first argument on ties, so an equal loss keeps
        # the existing ``best`` object.
        self.best = val_loss if self.best is None else min(self.best, val_loss)

    def step_update(self, num_updates):
        """Per-update hook; returns the optimizer's current learning rate.

        ``num_updates`` is unused here. This reads ``self.optimizer``, so with
        a ``None`` optimizer it raises ``AttributeError``.
        """
        current_lr = self.optimizer.get_lr()
        return current_lr


class LegacyFairseqLRScheduler(FairseqLRScheduler):
    """LR scheduler variant configured from an argparse ``Namespace``.

    Its ``__init__`` does not call ``FairseqLRScheduler.__init__``: it stores
    ``args`` (no ``cfg`` attribute is set here) and, unlike the base class,
    rejects a ``None`` optimizer.
    """

    def __init__(self, args: Namespace, optimizer):
        """Store the parsed arguments and the optimizer.

        Args:
            args: parsed command-line arguments, kept on ``self.args``.
            optimizer: must be a ``FairseqOptimizer`` instance (``None`` is
                rejected).

        Raises:
            ValueError: if ``optimizer`` is not a ``FairseqOptimizer``.
        """
        if isinstance(optimizer, FairseqOptimizer):
            self.args = args
            self.optimizer = optimizer
            # Lowest validation loss seen so far; None until one is recorded.
            self.best = None
        else:
            raise ValueError("optimizer must be an instance of FairseqOptimizer")