# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import json
import os

from fairseq import checkpoint_utils, tasks, utils

from . import DEFAULT_EOS, GET, SEND
from .agent import Agent


class SimulTransAgent(Agent):
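    """Base agent for simultaneous translation.

    The agent exchanges text with an evaluation server through two
    actions: GET asks the server for the next source segment, and SEND
    pushes a translated segment (or EOS) back. Concrete subclasses
    implement `load_dictionary`, `update_states`, and `finish_read`.
    """
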
    def __init__(self, args):
        # Load the model
        self.load_model(args)

        # Build the word splitters (subword tokenizers)
        self.build_word_splitter(args)

        self.max_len = args.max_len

        self.eos = DEFAULT_EOS

    @staticmethod
    def add_args(parser):
        # fmt: off
        parser.add_argument("--model-path", type=str, required=True,
                            help="Path to the pretrained model")
        parser.add_argument("--data-bin", type=str, required=True,
                            help="Path of data binary")
        parser.add_argument("--user-dir", type=str, default="example/simultaneous_translation",
                            help="User directory for simultaneous translation")
        parser.add_argument("--src-splitter-type", type=str, default=None,
                            help="Subword splitter type for source text")
        parser.add_argument("--tgt-splitter-type", type=str, default=None,
                            help="Subword splitter type for target text")
        parser.add_argument("--src-splitter-path", type=str, default=None,
                            help="Subword splitter model path for source text")
        parser.add_argument("--tgt-splitter-path", type=str, default=None,
                            help="Subword splitter model path for target text")
        parser.add_argument("--max-len", type=int, default=150,
                            help="Maximum length difference between source and target prediction")
        parser.add_argument("--model-overrides", default="{}", type=str, metavar="DICT",
                            help="A dictionary used to override, at generation time, "
                                 "model args that were set during training")
        # fmt: on
        return parser

    def load_dictionary(self, task):
        raise NotImplementedError

    def load_model(self, args):
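        """Load the checkpoint to CPU, rebuild the task and model from the
        args saved in the checkpoint, and restore the trained weights."""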
        args.user_dir = os.path.join(os.path.dirname(__file__), "..", "..")
        utils.import_user_module(args)
        filename = args.model_path
        if not os.path.exists(filename):
            raise IOError("Model file not found: {}".format(filename))

        state = checkpoint_utils.load_checkpoint_to_cpu(
            filename, json.loads(args.model_overrides)
        )

        saved_args = state["args"]
        saved_args.data = args.data_bin

        task = tasks.setup_task(saved_args)

        # Build the model (a single model, not an ensemble) and load weights
        self.model = task.build_model(saved_args)
        self.model.load_state_dict(state["model"], strict=True)

        # Set dictionary
        self.load_dictionary(task)

    def init_states(self):
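        """Create a fresh per-sentence state.

        `indices` holds vocabulary ids, `tokens` holds subword strings,
        `segments` holds the detokenized full words exchanged with the
        server so far, and `steps` marks how far each side has been
        processed.
        """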
        return {
            "indices": {"src": [], "tgt": []},
            "tokens": {"src": [], "tgt": []},
            "segments": {"src": [], "tgt": []},
            "steps": {"src": 0, "tgt": 0},
            "finished": False,
            "finish_read": False,
            "model_states": {},
        }

    def update_states(self, states, new_state):
        raise NotImplementedError

    def policy(self, states):
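        """Run the read/write policy: loop until the model's decision
        yields an action that can be sent to the server."""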
        # Read and Write policy
        action = None

        while action is None:
            if states["finished"]:
                # Finish the hypothesis by sending EOS to the server
                return self.finish_action()

            # The model makes a READ/WRITE decision given the current states
            decision = self.model.decision_from_states(states)

            if decision == 0 and not self.finish_read(states):
                # READ
                action = self.read_action(states)
            else:
                # WRITE
                action = self.write_action(states)

            # None means no action is sent to the server and the loop makes
            # another decision. This happens when a buffered token is read
            # or when a predicted subword does not yet complete a full word.
        return action

    def finish_read(self, states):
        raise NotImplementedError

    def write_action(self, states):
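        """Predict the next target token; send any newly completed
        detokenized word to the server, otherwise return None so that
        `policy` queries the model again."""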
        token, index = self.model.predict_from_states(states)

        if (
            index == self.dict["tgt"].eos()
            or len(states["tokens"]["tgt"]) > self.max_len
        ):
            # Finish this sentence if EOS is predicted or max length is exceeded
            states["finished"] = True
            end_idx_last_full_word = self._target_length(states)

        else:
            states["tokens"]["tgt"] += [token]
            end_idx_last_full_word = self.word_splitter["tgt"].end_idx_last_full_word(
                states["tokens"]["tgt"]
            )
            self._append_indices(states, [index], "tgt")

        if end_idx_last_full_word > states["steps"]["tgt"]:
            # Only send detokenized full words to the server
            word = self.word_splitter["tgt"].merge(
                states["tokens"]["tgt"][states["steps"]["tgt"] : end_idx_last_full_word]
            )
            states["steps"]["tgt"] = end_idx_last_full_word
            states["segments"]["tgt"] += [word]

            return {"key": SEND, "value": word}
        else:
            return None

    def read_action(self, states):
        return {"key": GET, "value": None}

    def finish_action(self):
        return {"key": SEND, "value": DEFAULT_EOS}

    def reset(self):
        pass

    def finish_eval(self, states, new_state):
        # Evaluation is done when the server sends an empty state and no
        # source tokens have been received for the current sentence
        return len(new_state) == 0 and len(states["indices"]["src"]) == 0

    def _append_indices(self, states, new_indices, key):
        states["indices"][key] += new_indices

    def _target_length(self, states):
        return len(states["tokens"]["tgt"])
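

# A minimal usage sketch (hypothetical: `SimulTransTextAgent` stands in for a
# concrete subclass, and the flag values are placeholders; the real entry
# point lives in the evaluation server/client code):
#
#   import argparse
#
#   parser = argparse.ArgumentParser()
#   SimulTransAgent.add_args(parser)
#   args = parser.parse_args(
#       ["--model-path", "checkpoint.pt", "--data-bin", "data-bin"]
#   )
#   agent = SimulTransTextAgent(args)
#   states = agent.init_states()
#   action = agent.policy(states)  # {"key": GET or SEND, "value": ...}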