# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import hashlib
import logging
import math

import numpy as np
from fairseq.data import SampledMultiDataset

from .sampled_multi_dataset import CollateFormat, default_virtual_size_func


logger = logging.getLogger(__name__)


class SampledMultiEpochDataset(SampledMultiDataset):
    """Samples from multiple sub-datasets according to sampling ratios
       using virtual epoch sizes to speed up dataloading.
    Args:
        datasets (
            List[~torch.utils.data.Dataset]
            or OrderedDict[str, ~torch.utils.data.Dataset]
        ): datasets
        sampling_ratios (List[float]): list of probability of each dataset to be sampled
            (default: None, which corresponds to concating all dataset together).
        seed (int): RNG seed to use (default: 2).
        epoch (int): starting epoch number (default: 1).
        eval_key (str, optional): a key used at evaluation time that causes
            this instance to pass-through batches from *datasets[eval_key]*.
        collate_format (CollateFormat):  collater output format, either CollateFormat.ordered_dict or
            CollateFormat.single (default: CollateFormat.single) where CollateFormat.single configures
            the collater to output batches of data mixed from all sub-datasets,
            and CollateFormat.ordered_dict configures the collater to output a dictionary of batches indexed by keys
            of sub-datasets.
            Note that not all sub-datasets will present in a single batch in both formats.
        virtual_size (int, or callable): the expected virtual size of the dataset (default: default_virtual_size_func).
        split (str): the split of the data, e.g. 'train', 'valid' or 'test'.
        virtual_epoch_size (int): virtual epoch size, the dataset will go through the data by
            this virtual epoch size one by one to speed up data loading, e.g. indicing and filtering
            can be performed whenever a virtual epoch is loaded without waiting for the whole dataset to be loaded.
        shared_collater (bool): whether or not to all sub-datasets have the same collater.
        shard_epoch (int): the real epoch number for shard selection.
        shuffle (bool): whether or not to shuffle data (default: True).
    """

    def __init__(
        self,
        datasets,
        sampling_ratios=None,
        seed=2,
        epoch=1,
        eval_key=None,
        collate_format=CollateFormat.single,
        virtual_size=default_virtual_size_func,
        split="",
        virtual_epoch_size=None,
        shared_collater=False,
        shard_epoch=1,
        shuffle=True,
    ):
        self.virtual_epoch_size = virtual_epoch_size
        self._current_epoch_start_index = None
        self._random_global_indices = None
        self.shard_epoch = shard_epoch if shard_epoch is not None else 1
        self.load_next_shard = None
        self._epoch_sizes = None
        super().__init__(
            datasets=datasets,
            sampling_ratios=sampling_ratios,
            seed=seed,
            epoch=epoch,
            eval_key=eval_key,
            collate_format=collate_format,
            virtual_size=virtual_size,
            split=split,
            shared_collater=shared_collater,
            shuffle=shuffle,
        )

    def _setup(self, epoch):
        self.virtual_epoch_size = (
            self.virtual_epoch_size
            if self.virtual_epoch_size is not None
            else self.virtual_size
        )
        if self.virtual_epoch_size > self.virtual_size:
            logger.warning(
                f"virtual epoch size {self.virtual_epoch_size} "
                f"is greater than virtual dataset size {self.virtual_size}"
            )
            self.virtual_epoch_size = self.virtual_size
        self.num_virtual_epochs = math.ceil(self.virtual_size / self.virtual_epoch_size)
        self._current_epoch_start_index = self._get_epoch_start_index(epoch)
        logger.info(
            f"virtual epoch size {self.virtual_epoch_size}; virtual dataset size {self.virtual_size}"
        )

    def _map_epoch_index_to_global(self, index):
        index = self._current_epoch_start_index + index
        # add randomness
        return self._random_global_indices[index]

    @property
    def sizes(self):
        if self._epoch_sizes is not None:
            return self._epoch_sizes
        _sizes = super().sizes
        indices = self._random_global_indices[
            self._current_epoch_start_index : self._current_epoch_start_index
            + len(self)
        ]
        self._epoch_sizes = _sizes[indices]
        # del super()._sizes to save memory
        del self._sizes
        self._sizes = None
        return self._epoch_sizes

    def _get_dataset_and_index(self, index):
        i = self._map_epoch_index_to_global(index)
        return super()._get_dataset_and_index(i)

    def __len__(self):
        return (
            self.virtual_epoch_size
            if self._current_epoch_start_index + self.virtual_epoch_size
            < self.virtual_size
            else self.virtual_size - self._current_epoch_start_index
        )

    def set_epoch(self, epoch):
        if self._current_epoch_start_index is None:
            # initializing epoch idnices of a virtual dataset
            self._setup(epoch)
            self._next_virtual_epoch(epoch)
        else:
            # working on already intialized epoch indices
            if epoch == self._cur_epoch:
                # re-enter so return
                return
            self._next_virtual_epoch(epoch)

    def _get_epoch_start_index(self, epoch):
        assert epoch >= 1  # fairseq is using 1-based epoch everywhere
        return ((epoch - 1) % self.num_virtual_epochs) * self.virtual_epoch_size

    def _next_global_indices(self, epoch):
        rng = np.random.RandomState(
            [
                int(
                    hashlib.sha1(
                        str(self.__class__.__name__).encode("utf-8")
                    ).hexdigest(),
                    16,
                )
                % (2 ** 32),
                self.seed % (2 ** 32),  # global seed
                epoch,  # epoch index,
            ]
        )
        del self._random_global_indices
        self._random_global_indices = rng.choice(
            self.virtual_size, self.virtual_size, replace=False
        )
        if self.load_next_shard is None:
            self.load_next_shard = False
        else:
            # increase shard epoch for next loading
            self.shard_epoch += 1
            self.load_next_shard = True
            logger.info(
                "to load next epoch/shard in next load_dataset: "
                f"epoch={epoch}/shard_epoch={self.shard_epoch}"
            )

    def _next_virtual_epoch(self, epoch):
        index = self._get_epoch_start_index(epoch)
        if index == 0 or self._random_global_indices is None:
            # need to start from the beginning,
            # so call super().set_epoch(epoch) to establish the global virtual indices
            logger.info(
                "establishing a new set of global virtual indices for "
                f"epoch={epoch}/shard_epoch={self.shard_epoch}"
            )
            super().set_epoch(epoch)
            self._next_global_indices(epoch)
        else:
            self._cur_epoch = epoch

        # reset cache sizes and ordered_indices for the epoch after moving to a new epoch
        self._clean_if_not_none(
            [
                self._epoch_sizes,
            ]
        )
        self._epoch_sizes = None
        self._current_epoch_start_index = index