# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import argparse
import json
import sys

from scorers import build_scorer
from tornado import ioloop, web


# Network endpoint used by add_args() and start_server() when no explicit
# --hostname / --port (or hostname / port argument) is supplied.
DEFAULT_HOSTNAME: str = "localhost"
DEFAULT_PORT: int = 12_321


class ScorerHandler(web.RequestHandler):
    """Common base for every endpoint of the evaluation server.

    Tornado invokes ``initialize`` with the keyword arguments from the URL
    spec; ``start_server`` registers each route with ``scorer=...``, so all
    subclasses can reach the shared scorer as ``self.scorer``.
    """

    def initialize(self, scorer) -> None:
        # Keep a handle on the injected scorer for the verb methods.
        self.scorer = scorer


class EvalSessionHandler(ScorerHandler):
    """Session endpoint mounted at ``/``.

    ``GET`` replies with ``scorer.get_info()`` serialized to a JSON string;
    ``POST`` resets the scorer and sends an empty body.
    """

    def get(self):
        info = self.scorer.get_info()
        payload = json.dumps(info)
        self.write(payload)

    def post(self):
        self.scorer.reset()


class ResultHandler(ScorerHandler):
    """``GET /result``: reply with ``scorer.score()`` as a JSON string."""

    def get(self):
        scores = self.scorer.score()
        self.write(json.dumps(scores))


class SourceHandler(ScorerHandler):
    """``GET /src``: send source input for one sentence to the client.

    Query arguments:
        sent_id: integer sentence index (required; Tornado's ``get_argument``
            answers a missing one with HTTP 400).
        segment_size: optional integer. An absent or empty value is passed to
            the scorer as ``None``.

    The response body is ``scorer.send_src(sent_id, segment_size)`` encoded
    with ``json.dumps``. A non-integer argument is rejected with HTTP 400
    rather than escaping as an uncaught ``ValueError`` (HTTP 500).
    """

    def get(self):
        sent_id = self._parse_int("sent_id", self.get_argument("sent_id"))

        segment_size = None
        # get_argument() strips whitespace by default, so an empty string here
        # means the argument was given without a value; treat it as absent.
        raw_segment_size = self.get_argument("segment_size", None)
        if raw_segment_size:
            segment_size = self._parse_int("segment_size", raw_segment_size)

        # sent_id is already an int, so no second conversion is needed.
        self.write(json.dumps(self.scorer.send_src(sent_id, segment_size)))

    @staticmethod
    def _parse_int(name, value):
        """Return ``int(value)``, or raise HTTP 400 naming the bad argument."""
        try:
            return int(value)
        except ValueError:
            # Pass the values as format args: HTTPError %-formats log_message.
            raise web.HTTPError(
                400, "Argument %r must be an integer, got %r", name, value
            ) from None


class HypothesisHandler(ScorerHandler):
    """``PUT /hypo``: receive hypothesis tokens for one sentence.

    The ``sent_id`` query argument selects the sentence. The request body is
    decoded as UTF-8 and split on whitespace, and the resulting token list is
    passed to ``scorer.recv_hyp(sent_id, tokens)``.

    Malformed input is a client error and is answered with HTTP 400 instead of
    an uncaught exception (HTTP 500). This covers a non-integer ``sent_id`` and
    a body that is not valid UTF-8.
    """

    def put(self):
        raw_sent_id = self.get_argument("sent_id")
        try:
            sent_id = int(raw_sent_id)
        except ValueError:
            raise web.HTTPError(
                400, "Argument 'sent_id' must be an integer, got %r", raw_sent_id
            ) from None

        try:
            text = self.request.body.decode("utf-8")
        except UnicodeDecodeError:
            raise web.HTTPError(400, "Request body is not valid UTF-8") from None

        # str.split() with no separator already discards leading/trailing
        # whitespace, so a preceding strip() would be redundant.
        self.scorer.recv_hyp(sent_id, text.split())


def add_args(argv=None):
    """Parse the server's command-line options.

    Args:
        argv: list of arguments to parse. ``None`` (the default) parses
            ``sys.argv[1:]``, exactly as before.

    Returns:
        argparse.Namespace with ``hostname``, ``port`` and ``debug``.
        Unrecognised arguments are ignored via ``parse_known_args``. Presumably
        this lets options meant for ``build_scorer`` share the same command
        line (TODO confirm).
    """
    parser = argparse.ArgumentParser()
    # fmt: off
    parser.add_argument('--hostname', type=str, default=DEFAULT_HOSTNAME,
                        help='Server hostname')
    parser.add_argument('--port', type=int, default=DEFAULT_PORT,
                        help='Server port number')
    # start_server() takes a debug flag and the __main__ block passes
    # args.debug; without this option that attribute would not exist.
    parser.add_argument('--debug', action='store_true', default=False,
                        help='Run the Tornado application in debug mode')

    args, _ = parser.parse_known_args(argv)
    # fmt: on
    return args


def start_server(scorer, hostname=DEFAULT_HOSTNAME, port=DEFAULT_PORT, debug=False):
    """Serve *scorer* over HTTP and run the Tornado IO loop (blocks).

    Routes, each handler receiving the same ``scorer``:
        ``/result``  ResultHandler
        ``/src``     SourceHandler
        ``/hypo``    HypothesisHandler
        ``/``        EvalSessionHandler

    Args:
        scorer: object shared by all handlers.
        hostname: address to bind the listening socket to. Previously this
            argument was accepted but ignored, so the server listened on
            every interface.
        port: TCP port to listen on.
        debug: forwarded to ``tornado.web.Application(debug=...)``.

    Does not return until the IO loop is stopped.
    """
    handler_kwargs = dict(scorer=scorer)
    app = web.Application(
        [
            (r"/result", ResultHandler, handler_kwargs),
            (r"/src", SourceHandler, handler_kwargs),
            (r"/hypo", HypothesisHandler, handler_kwargs),
            (r"/", EvalSessionHandler, handler_kwargs),
        ],
        debug=debug,
    )
    # Honour the requested hostname, and raise the HTTP server's
    # per-connection buffer limit to 1 GiB.
    app.listen(port, address=hostname, max_buffer_size=1024 ** 3)
    sys.stdout.write(f"Evaluation Server Started. Listening to port {port}\n")
    # Flush now: start() below blocks indefinitely, so a buffered banner on a
    # piped stdout might otherwise never be written.
    sys.stdout.flush()
    ioloop.IOLoop.current().start()


if __name__ == "__main__":
    args = add_args()
    scorer = build_scorer(args)
    # add_args() may not define a --debug option; default to False rather
    # than failing with AttributeError on args.debug.
    start_server(scorer, args.hostname, args.port, getattr(args, "debug", False))