# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

from dataclasses import dataclass, field

import torch
import torch.distributed as dist
from fairseq.dataclass.configs import FairseqBMUFConfig
from fairseq.dataclass.utils import gen_parser_from_dataclass
from fairseq.optim.fairseq_optimizer import FairseqOptimizer


class FairseqBMUF(FairseqOptimizer):
    """
    Implements incremental block distributed data parallelism similar to
    https://ieeexplore.ieee.org/document/7472805

    Paper title: Scalable training of deep learning machines by incremental
    block training with intra-block parallel optimization and blockwise
    model-update filtering

    This class wraps another optimizer. Every call to :meth:`step` runs the
    wrapped optimizer locally and then, depending on the update counter,
    either performs a one-off warmup synchronization or a periodic block
    synchronization across workers.

    Args:
        cfg (FairseqBMUFConfig): BMUF settings. The fields read here are
            ``global_sync_iter``, ``block_momentum``, ``block_lr``,
            ``warmup_iterations``, ``use_nbm``, ``average_sync`` and
            ``distributed_world_size``.
        optimizer: the wrapped optimizer that performs the local updates.
    """

    def __init__(self, cfg: FairseqBMUFConfig, optimizer):
        super().__init__(cfg)
        self._optimizer = optimizer
        self._num_updates = 0
        # Number of updates between two block synchronizations.
        self.sync_iter = cfg.global_sync_iter
        self.block_momentum = cfg.block_momentum
        self.block_lr = cfg.block_lr
        # Snapshot the current parameters as the initial global model.
        self._reset_local_data()
        # Update count at which the one-off warmup synchronization happens.
        self.warmup_iteration = cfg.warmup_iterations
        self.use_nbm = cfg.use_nbm
        # State that is restored at the end of warmup when average_sync is off.
        self.initial_state = self._optimizer.state_dict()
        self.average_sync = self.cfg.average_sync
        self.world_size = self.cfg.distributed_world_size

    @staticmethod
    def add_args(parser):
        """Add optimizer-specific arguments to the parser."""
        gen_parser_from_dataclass(parser, FairseqBMUFConfig())

    @property
    def optimizer(self):
        """Return the ``optimizer`` attribute of the wrapped optimizer."""
        return self._optimizer.optimizer

    @property
    def optimizer_config(self):
        """Return the ``optimizer_config`` of the wrapped optimizer."""
        return self._optimizer.optimizer_config

    def get_lr(self):
        """Return the learning rate reported by the wrapped optimizer."""
        return self._optimizer.get_lr()

    def set_lr(self, lr):
        """Set the learning rate on the wrapped optimizer."""
        self._optimizer.set_lr(lr)

    def state_dict(self):
        """Return the wrapped optimizer's state dict.

        The BMUF buffers and the update counter kept by this class are not
        part of the returned state.
        """
        return self._optimizer.state_dict()

    def load_state_dict(self, state_dict, optimizer_overrides=None):
        """Load *state_dict* into the wrapped optimizer.

        The freshly loaded state also becomes the state that
        :meth:`_warmup_sync` restores when ``average_sync`` is disabled.
        """
        self._optimizer.load_state_dict(state_dict, optimizer_overrides)
        # NOTE(review): if state_dict() returns references to live optimizer
        # tensors, this snapshot will follow later updates -- confirm whether
        # a copy is needed here.
        self.initial_state = self._optimizer.state_dict()

    def multiply_grads(self, c):
        """Multiplies grads by a constant *c*."""
        self._optimizer.multiply_grads(c)

    def clip_grad_norm(self, max_norm, aggregate_norm_fn=None):
        """Clips gradient norm."""
        return self._optimizer.clip_grad_norm(max_norm, aggregate_norm_fn)

    def average_params(self):
        """Delegate to the wrapped optimizer's ``average_params``."""
        self._optimizer.average_params()

    def _block_sync(self):
        """Synchronize the local models of all workers into a global model.

        Does nothing when the configured world size is 1 or less.
        """
        if self.world_size <= 1:
            return
        # Update the global model using local models from all GPUs
        # (Step-1) Calculate grad between previously synced model and
        # current local model
        if self.block_momentum != 0:
            self._calc_grad()

        # (Step-2) Average gradient from all GPUs (with zero block momentum
        # the parameters themselves are averaged instead)
        self._avg_grad_from_all_gpus()

        # (Step-3) Calculate global momentum and update the global model
        if self.block_momentum != 0:
            self._update_global_model()

        # (Step-4) Average local optimizer params
        if self.average_sync:
            self.average_params()

    def _is_warmup_end(self):
        """Return True when the update count equals the warmup iteration."""
        return self.get_num_updates() == self.warmup_iteration

    def _is_bmuf_iter(self):
        """Return True when a block synchronization is due.

        That is the case once warmup is over and the update count is a
        multiple of ``sync_iter``.
        """
        num_updates = self.get_num_updates()
        return (
            num_updates > self.warmup_iteration
            and num_updates % self.sync_iter == 0
        )

    def _warmup_sync(self, root_rank=0):
        """Align all workers with the model held by *root_rank*.

        Does nothing when the configured world size is 1 or less.

        Args:
            root_rank (int): rank whose parameters are broadcast to every
                other worker.
        """
        if self.world_size <= 1:
            return
        # Broadcast the local model to all gpus
        for param in self.params:
            dist.broadcast(param.data, src=root_rank)

        # Update local optimizer state
        if self.average_sync:
            self._optimizer.average_params()
        else:
            self._optimizer.load_state_dict(self.initial_state)

        # The broadcast model is the new global model; restart BMUF buffers.
        self._reset_local_data()

    def step(self, closure=None):
        """Performs a single optimization step."""
        self._optimizer.step(closure)
        self.set_num_updates(self.get_num_updates() + 1)
        # Warmup end and block sync are mutually exclusive for a given step.
        if self._is_warmup_end():
            self._warmup_sync()
        elif self._is_bmuf_iter():
            self._block_sync()

    def zero_grad(self):
        """Clears the gradients of all optimized parameters."""
        self._optimizer.zero_grad()

    def get_num_updates(self):
        """Get the number of parameters updates."""
        return self._num_updates

    def set_num_updates(self, num_updates):
        """Set the number of parameters updates."""
        self._num_updates = num_updates

    @torch.no_grad()
    def _reset_local_data(self):
        """(Re)initialize the BMUF buffers from the current parameters."""
        # (Step-0) Initialize global momentum parameters and store global copy
        # on each gpu. The global copy is a snapshot of the current model and
        # is used for calculating the gradient during bmuf sync.
        self.global_params = [p.data.clone() for p in self.params]
        self.smoothed_grads = [p.data.new_zeros(p.data.size()) for p in self.params]
        self.grads = [p.data.new_zeros(p.data.size()) for p in self.params]

    @torch.no_grad()
    def _calc_grad(self):
        """Store ``global_param - param`` for every parameter in ``self.grads``."""
        # global_params is basically the global copy from the previously finished
        # synchronisation. param.data is local parameter after block_sync_freq
        # for the local gpu. so grad is difference between previously synced
        # model and current local model.
        for param, global_param, grad in zip(
            self.params, self.global_params, self.grads
        ):
            # Write into the preallocated buffer instead of allocating a new
            # tensor on every synchronization.
            torch.sub(global_param, param.data, out=grad)

    @torch.no_grad()
    def _avg_grad_from_all_gpus(self):
        """Average the synchronized tensors across all workers, in place.

        With zero block momentum the parameters themselves are averaged;
        otherwise the block gradients computed by :meth:`_calc_grad` are.
        """
        world_size = float(dist.get_world_size())
        sync_params_directly = self.block_momentum == 0
        for index, param in enumerate(self.params):
            sync_para = param.data if sync_params_directly else self.grads[index]
            # Pre-divide so that the SUM reduction yields the mean.
            sync_para /= world_size
            dist.all_reduce(sync_para, op=dist.ReduceOp.SUM)

    @torch.no_grad()
    def _update_global_model(self):
        """Apply block momentum to the averaged gradients and update the model."""
        for index, (param, global_param, grad) in enumerate(
            zip(self.params, self.global_params, self.grads)
        ):
            # global_param is basically last synchronized parameter. though
            # smoothed_grad is local, all processes will have same value of
            # smoothed_grad (it is always computed on synchronized gradients)
            # and hence param is globally synchronized copy.
            # smoothed_grad(t) = BM * smoothed_grad(t-1) + BM_lr * grad(t)
            smoothed_grad = (
                self.block_momentum * self.smoothed_grads[index]
                + self.block_lr * grad
            )
            param.data.copy_(global_param - smoothed_grad)

            # A Nesterov momentum here is to do a partial weight update before
            # calculating the gradient
            if self.use_nbm:
                param.data.copy_(param.data - self.block_momentum * smoothed_grad)

            # backup for the next synchronization.
            self.smoothed_grads[index] = smoothed_grad
            global_param.copy_(param.data)