# Copyright (c) Facebook, Inc. and its affiliates. # # This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this source tree. import argparse from agents import build_agent from client import SimulSTEvaluationService, SimulSTLocalEvaluationService from fairseq.registry import REGISTRIES DEFAULT_HOSTNAME = "localhost" DEFAULT_PORT = 12321 def get_args(): parser = argparse.ArgumentParser() parser.add_argument( "--hostname", type=str, default=DEFAULT_HOSTNAME, help="server hostname" ) parser.add_argument( "--port", type=int, default=DEFAULT_PORT, help="server port number" ) parser.add_argument("--agent-type", default="simul_trans_text", help="Agent type") parser.add_argument("--scorer-type", default="text", help="Scorer type") parser.add_argument( "--start-idx", type=int, default=0, help="Start index of the sentence to evaluate", ) parser.add_argument( "--end-idx", type=int, default=float("inf"), help="End index of the sentence to evaluate", ) parser.add_argument( "--scores", action="store_true", help="Request scores from server" ) parser.add_argument("--reset-server", action="store_true", help="Reset the server") parser.add_argument( "--num-threads", type=int, default=10, help="Number of threads used by agent" ) parser.add_argument( "--local", action="store_true", default=False, help="Local evaluation" ) args, _ = parser.parse_known_args() for registry_name, REGISTRY in REGISTRIES.items(): choice = getattr(args, registry_name, None) if choice is not None: cls = REGISTRY["registry"][choice] if hasattr(cls, "add_args"): cls.add_args(parser) args = parser.parse_args() return args if __name__ == "__main__": args = get_args() if args.local: session = SimulSTLocalEvaluationService(args) else: session = SimulSTEvaluationService(args.hostname, args.port) if args.reset_server: session.new_session() if args.agent_type is not None: agent = build_agent(args) agent.decode(session, args.start_idx, args.end_idx, args.num_threads) if args.scores: session.get_scores() print(session.get_scores())