simul_trans_agent.py
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import json
import os

from fairseq import checkpoint_utils, utils, tasks

from . import DEFAULT_EOS, GET, SEND
from .agent import Agent


class SimulTransAgent(Agent):
    def __init__(self, args):
        # Load Model
        self.load_model(args)

        # Build the word splitter
        self.build_word_splitter(args)

        self.max_len = args.max_len

        self.eos = DEFAULT_EOS

    @staticmethod
    def add_args(parser):
        # fmt: off
        parser.add_argument(
            "--model-path",
            type=str,
            required=True,
            help="path to your pretrained model.",
        )
        parser.add_argument(
            "--data-bin", type=str, required=True, help="Path of data binary"
        )
        parser.add_argument(
            "--user-dir",
            type=str,
            default="example/simultaneous_translation",
            help="User directory for simultaneous translation",
        )
        parser.add_argument(
            "--src-splitter-type",
            type=str,
            default=None,
            help="Subword splitter type for source text",
        )
        parser.add_argument(
            "--tgt-splitter-type",
            type=str,
            default=None,
            help="Subword splitter type for target text",
        )
        parser.add_argument(
            "--src-splitter-path",
            type=str,
            default=None,
            help="Subword splitter model path for source text",
        )
        parser.add_argument(
            "--tgt-splitter-path",
            type=str,
            default=None,
            help="Subword splitter model path for target text",
        )
        parser.add_argument(
            "--max-len",
            type=int,
            default=150,
            help="Maximum length difference between source and target prediction",
        )
        parser.add_argument(
            "--model-overrides",
            default="{}",
            type=str,
            metavar="DICT",
            help="A dictionary used to override model args at generation "
            "that were used during model training",
        )
        # fmt: on
        return parser
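
    # Illustrative usage of the arguments above (not part of the original
    # module); the checkpoint/data paths, splitter type, and override value
    # below are hypothetical, and a harness would typically do something like:
    #
    #   parser = argparse.ArgumentParser()
    #   SimulTransAgent.add_args(parser)
    #   args = parser.parse_args([
    #       "--model-path", "checkpoints/checkpoint_best.pt",
    #       "--data-bin", "data-bin/example_corpus",
    #       "--src-splitter-type", "SentencePiece",
    #       "--tgt-splitter-type", "SentencePiece",
    #       "--model-overrides", '{"max_target_positions": 1024}',
    #   ])
    #   agent = MyTextAgent(args)  # some concrete subclass of SimulTransAgent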

    def load_dictionary(self, task):
        raise NotImplementedError

    def load_model(self, args):
        # Point user_dir at this example's root directory so its custom modules
        # can be imported, regardless of the --user-dir flag.
        args.user_dir = os.path.join(os.path.dirname(__file__), "..", "..")
        utils.import_user_module(args)
        filename = args.model_path
        if not os.path.exists(filename):
            raise IOError("Model file not found: {}".format(filename))

        state = checkpoint_utils.load_checkpoint_to_cpu(
            filename, json.loads(args.model_overrides)
        )

        saved_args = state["args"]
        saved_args.data = args.data_bin

        task = tasks.setup_task(saved_args)

        # Build the model and load the checkpoint weights
        self.model = task.build_model(saved_args)
        self.model.load_state_dict(state["model"], strict=True)

        # Set dictionary
        self.load_dictionary(task)

    def init_states(self):
        return {
            # Vocabulary indices of the subword tokens read / predicted so far
            "indices": {"src": [], "tgt": []},
            # Subword tokens read / predicted so far
            "tokens": {"src": [], "tgt": []},
            # Detokenized full words exchanged with the server
            "segments": {"src": [], "tgt": []},
            # Number of tokens already committed on each side
            "steps": {"src": 0, "tgt": 0},
            # True once EOS is predicted or the hypothesis exceeds max_len
            "finished": False,
            # True once the source stream has been fully read
            "finish_read": False,
            # Incremental model state (e.g. decoder caches)
            "model_states": {},
        }

    def update_states(self, states, new_state):
        raise NotImplementedError

    def policy(self, states):
        # Read and Write policy
        action = None

        while action is None:
            if states["finished"]:
                # Finish the hypothesis by sending EOS to the server
                return self.finish_action()

            # The model makes a decision given the current states
            decision = self.model.decision_from_states(states)

            if decision == 0 and not self.finish_read(states):
                # READ
                action = self.read_action(states)
            else:
                # WRITE
                action = self.write_action(states)

            # None means no action is sent to the server and the decision is
            # made again. This happens when a buffered token is read,
            # or when only a subword (not yet a full word) has been predicted.
        return action

    def finish_read(self, states):
        raise NotImplementedError

    def write_action(self, states):
        token, index = self.model.predict_from_states(states)

        if (
            index == self.dict["tgt"].eos()
            or len(states["tokens"]["tgt"]) > self.max_len
        ):
            # Finish this sentence if EOS is predicted or the hypothesis exceeds max_len
            states["finished"] = True
            end_idx_last_full_word = self._target_length(states)

        else:
            states["tokens"]["tgt"] += [token]
            end_idx_last_full_word = self.word_splitter["tgt"].end_idx_last_full_word(
                states["tokens"]["tgt"]
            )
            self._append_indices(states, [index], "tgt")

        if end_idx_last_full_word > states["steps"]["tgt"]:
            # Only send detokenized full words to the server
            word = self.word_splitter["tgt"].merge(
                states["tokens"]["tgt"][states["steps"]["tgt"] : end_idx_last_full_word]
            )
            states["steps"]["tgt"] = end_idx_last_full_word
            states["segments"]["tgt"] += [word]

            return {"key": SEND, "value": word}
        else:
            return None
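
    # Example of the buffering behavior above (illustrative subword pieces):
    # with BPE pieces ["Mach@@", "ine"], predicting "Mach@@" alone does not
    # complete a full word, so write_action() returns None and the policy loop
    # runs again; once "ine" is predicted, the pieces are merged into "Machine"
    # and sent to the server as a single segment.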

    def read_action(self, states):
        return {"key": GET, "value": None}

    def finish_action(self):
        return {"key": SEND, "value": DEFAULT_EOS}

    def reset(self):
        pass

    def finish_eval(self, states, new_state):
        # Evaluation is over when the server sends nothing new and the states
        # contain no source indices.
        return len(new_state) == 0 and len(states["indices"]["src"]) == 0

    def _append_indices(self, states, new_indices, key):
        states["indices"][key] += new_indices

    def _target_length(self, states):
        return len(states["tokens"]["tgt"])
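

# ---------------------------------------------------------------------------
# Minimal sketch of a concrete agent (illustrative only, not part of the
# original file). It shows how the abstract hooks above might be filled in;
# the class name, the structure of `new_state`, and the splitter's `split`
# call are assumptions. A real agent must also provide `build_word_splitter`
# and `self.word_splitter`, which this base class expects but does not define
# (omitted here for brevity).
# ---------------------------------------------------------------------------
class ExampleSimulTransAgent(SimulTransAgent):
    def load_dictionary(self, task):
        # write_action() looks up EOS through self.dict["tgt"]
        self.dict = {
            "src": task.source_dictionary,
            "tgt": task.target_dictionary,
        }

    def finish_read(self, states):
        # The server marks the end of the source stream in the states
        return states["finish_read"]

    def update_states(self, states, new_state):
        # Hypothetical update: append a newly received source segment, split it
        # into subword tokens, and record the corresponding dictionary indices.
        segment = new_state.get("segment")
        if segment is not None and segment != DEFAULT_EOS:
            states["segments"]["src"] += [segment]
            tokens = self.word_splitter["src"].split(segment)
            states["tokens"]["src"] += tokens
            self._append_indices(
                states, [self.dict["src"].index(t) for t in tokens], "src"
            )
        return states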