# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

from collections import OrderedDict

import numpy as np

from fairseq.data import BaseWrapperDataset, FairseqDataset, iterators


class MultiItr(object):
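    """Interleaves multiple iterators, always drawing the next item from the
    iterator that is proportionally least consumed so far."""
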
    def __init__(self, itr):
        self.itr = itr
        self._counts = [0 for _ in itr]

    def __len__(self):
        return sum(len(itr) for itr in self.itr)

    def __iter__(self):
        return self

    def __next__(self):
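        # Pick the iterator with the smallest consumed fraction so that all
        # sub-iterators progress through their epochs at roughly the same rate.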
        ratios = [count / len(itr) for count, itr in zip(self._counts, self.itr)]
        idx = ratios.index(min(ratios))
        self._counts[idx] += 1
        return next(self.itr[idx])


class MultidatasetEpochBatchIterator(iterators.EpochBatchIterating):
    """A wrapper around multiple epoch batch iterators."""

    def __init__(
        self,
        dataset,
        batch_sampler,
        seed=1,
        num_shards=1,
        shard_id=0,
        num_workers=0,
        epoch=1,
    ):

        assert isinstance(dataset, OrderedDict)
        assert len(dataset)
        assert isinstance(dataset[next(iter(dataset))], FairseqDataset)
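        # A hypothetical sketch of the expected inputs:
        #   dataset = OrderedDict([("en-de", dataset_a), ("en-fr", dataset_b)])
        #   batch_sampler = {"en-de": batches_a, "en-fr": batches_b}
        # where each value of ``dataset`` is a FairseqDataset and each value of
        # ``batch_sampler`` is an iterable of batches (lists of example indices).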

        self.iterators = []

        self.epoch = epoch
        for key, dt in dataset.items():
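            # Note: each per-dataset iterator is created with ``num_workers=0``;
            # the ``num_workers`` argument of this wrapper is not forwarded here.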
            epoch_iter = iterators.EpochBatchIterator(
                dataset=dt,
                collate_fn=dt.collater,
                batch_sampler=batch_sampler[key],
                seed=seed,
                num_shards=num_shards,
                shard_id=shard_id,
                num_workers=0,
                epoch=epoch,
            )
            self.iterators.append(epoch_iter)

    def __len__(self):
        return sum(len(itr) for itr in self.iterators)

    def next_epoch_itr(self, shuffle=True, fix_batches_to_gpus=False):
        # `self.epoch += 1` should be handled by underlying `EpochBatchIterator`s.
        return MultiItr(
            [
                itr.next_epoch_itr(
                    shuffle=shuffle, fix_batches_to_gpus=fix_batches_to_gpus
                )
                for itr in self.iterators
            ]
        )

    def end_of_epoch(self):
        return all(itr.end_of_epoch() for itr in self.iterators)

    @property
    def next_epoch_idx(self):
        """Return the epoch index after *next_epoch_itr* is called."""

        epochs = [itr.next_epoch_idx for itr in self.iterators]
        self.epoch = epochs[0]
        assert all(epoch == self.epoch for epoch in epochs)

        return self.epoch

    @property
    def iterations_in_epoch(self):
        return sum(itr.iterations_in_epoch for itr in self.iterators)

    def state_dict(self):
        return {
            "iterators": [it.state_dict() for it in self.iterators],
            "epoch": self.epoch,
        }

    def load_state_dict(self, state_dict):
        self.epoch = state_dict["epoch"]
        for it, d in zip(self.iterators, state_dict["iterators"]):
            it.load_state_dict(d)


class MultitaskDatasetWrapper(BaseWrapperDataset):
    """A wrapper for a multitask dataset."""

    def __init__(self, dataset, target_language_id, sample=1.0, name=""):
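        # ``sample`` is the fraction of examples kept in ``ordered_indices``
        # (1.0 keeps the full dataset); ``name`` is attached to each batch in
        # ``collater`` so downstream code can identify the task it came from.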
        super().__init__(dataset)
        self.target_language_id = target_language_id
        self.sample = sample
        self.name = name

    def collater(self, *args, **kwargs):
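        # Delegate collation to the wrapped dataset, then tag the batch with the
        # target language id and dataset name for multitask training.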
        ans = self.dataset.collater(*args, **kwargs)
        if "net_input" in ans:
            ans["net_input"]["target_language_id"] = self.target_language_id
            ans["net_input"]["dataset_name"] = self.name
        return ans

    def num_tokens(self, *args, **kwargs):
        return self.dataset.num_tokens(*args, **kwargs)

    def ordered_indices(self, *args, **kwargs):
        indices = self.dataset.ordered_indices(*args, **kwargs)
        # Subsample a fraction ``self.sample`` of the examples while preserving
        # the order returned by the wrapped dataset (``np.sort`` keeps the
        # randomly selected positions in their original relative order).
        size = int(self.sample * indices.shape[0])

        return indices.take(np.sort(np.random.permutation(indices.shape[0])[:size]))

    def size(self, index: int):
        return self.dataset.size(index)

    @property
    def supports_prefetch(self):
        """Whether this dataset supports prefetching."""
        return getattr(self.dataset, "supports_prefetch", False)

    def prefetch(self, indices):
        return self.dataset.prefetch(indices)