File size: 10,261 Bytes
b599481
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
c817ff0
 
 
 
 
 
 
 
 
b599481
c817ff0
 
 
b599481
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
"""Start a Flask server to interact with the model.

Inspired by `script/ask.py`."""

import argparse
import json
import logging
import random
import uuid
from typing import Any, Dict, Optional, Tuple, Union

import openai
from flask import Flask, request, session

from src.model.crs_model import CRSModel
from src.model.utils import get_entity, get_options

logging.basicConfig(
    format="[%(asctime)s] %(levelname)-12s %(message)s",
    handlers=[logging.StreamHandler()],
)
logger = logging.getLogger(__name__)
# basicConfig() above sets no level, so the root logger stays at WARNING and
# this module's logger.info(...) progress messages would be silently dropped.
# Enable INFO for this module's logger only (third-party loggers keep their
# own levels); the --debug flag lowers it further to DEBUG.
logger.setLevel(logging.INFO)


def parse_args() -> argparse.Namespace:
    """Parses command line arguments.

    ``--crs_model`` and ``--kg_dataset`` are required: the model arguments,
    the model itself and the server's entity data all depend on them, so
    omitting either would otherwise only fail later with an obscure error.

    Returns:
        Command line arguments.
    """
    parser = argparse.ArgumentParser(
        prog="serve_model.py",
        description="Start a Flask server to interact with the model.",
    )

    parser.add_argument(
        "--crs_model",
        type=str,
        required=True,
        choices=["kbrd", "barcor", "unicrs", "chatgpt"],
    )

    parser.add_argument(
        "--kg_dataset",
        type=str,
        required=True,
        choices=["redial", "opendialkg"],
    )

    # model_detailed
    parser.add_argument("--hidden_size", type=int)
    parser.add_argument("--entity_hidden_size", type=int)
    parser.add_argument("--num_bases", type=int, default=8)
    parser.add_argument("--context_max_length", type=int)
    parser.add_argument("--entity_max_length", type=int)

    # model
    parser.add_argument("--rec_model", type=str)
    parser.add_argument("--conv_model", type=str)

    # conv
    parser.add_argument("--tokenizer_path", type=str)
    parser.add_argument("--encoder_layers", type=int)
    parser.add_argument("--decoder_layers", type=int)
    parser.add_argument("--text_hidden_size", type=int)
    parser.add_argument("--attn_head", type=int)
    parser.add_argument("--resp_max_length", type=int)

    # prompt
    parser.add_argument("--api_key", type=str)
    parser.add_argument("--model", type=str)
    parser.add_argument("--text_tokenizer_path", type=str)
    parser.add_argument("--text_encoder", type=str)

    # server
    parser.add_argument("--host", type=str, default="127.0.0.1")
    # Parsed as an int so an invalid port is rejected at startup rather
    # than when the server tries to bind.
    parser.add_argument("--port", type=int, default=5005)

    parser.add_argument("--seed", type=int, default=42)
    parser.add_argument("--debug", action="store_true")

    return parser.parse_args()


def get_model_args(
    model_name: str, args: argparse.Namespace
) -> Dict[str, Any]:
    """Builds the keyword arguments used to construct the requested model.

    Every supported model consumes a fixed subset of the parsed command line
    options; those subsets (in the order they are passed on) live in the
    lookup table below. Selecting ``chatgpt`` additionally stores
    ``args.api_key`` on the ``openai`` module as a side effect.

    Args:
        model_name: Model's name.
        args: Command line arguments.

    Raises:
        ValueError: If model is not supported.

    Returns:
        Model's arguments, keyed by option name.
    """
    option_names = {
        "kbrd": (
            "debug",
            "kg_dataset",
            "hidden_size",
            "entity_hidden_size",
            "num_bases",
            "rec_model",
            "conv_model",
            "context_max_length",
            "entity_max_length",
            "tokenizer_path",
            "encoder_layers",
            "decoder_layers",
            "text_hidden_size",
            "attn_head",
            "resp_max_length",
            "seed",
        ),
        "barcor": (
            "debug",
            "kg_dataset",
            "rec_model",
            "conv_model",
            "context_max_length",
            "resp_max_length",
            "tokenizer_path",
            "seed",
        ),
        "unicrs": (
            "debug",
            "seed",
            "kg_dataset",
            "tokenizer_path",
            "context_max_length",
            "entity_max_length",
            "resp_max_length",
            "text_tokenizer_path",
            "rec_model",
            "conv_model",
            "model",
            "num_bases",
            "text_encoder",
        ),
        "chatgpt": ("seed", "debug", "kg_dataset"),
    }.get(model_name)

    if option_names is None:
        raise ValueError(f"Model {model_name} is not supported.")

    # ChatGPT is called through the openai client, which needs the key set
    # globally before the model is built.
    if model_name == "chatgpt":
        openai.api_key = args.api_key

    return {name: getattr(args, name) for name in option_names}


class CRSFlaskServer:
    """Flask server exposing a CRS model through a single ``/`` endpoint.

    Non-POST requests are answered as a liveness check. A POST carries a JSON
    object with ``context`` (list of previous utterances) and ``message``
    (the latest utterance) and is answered with ``{"response": ...}``. The
    state returned by the model is kept in the Flask session between turns.
    """

    def __init__(
        self,
        crs_model: CRSModel,
        kg_dataset: str,
        response_generation_args: Optional[Dict[str, Any]] = None,
    ) -> None:
        """Initializes CRS Flask server.

        Args:
            crs_model: CRS model.
            kg_dataset: Name of knowledge graph dataset; entity data is read
              from ``data/{kg_dataset}/entity2id.json``.
            response_generation_args: Extra keyword arguments forwarded to
              ``crs_model.get_response``. Defaults to None (no extra
              arguments).
        """
        self.crs_model = crs_model

        # Load entity data
        with open(
            f"data/{kg_dataset}/entity2id.json", "r", encoding="utf-8"
        ) as f:
            self.entity2id = json.load(f)

        self.id2entity = {int(v): k for k, v in self.entity2id.items()}
        self.entity_list = list(self.entity2id.keys())

        # Get options
        self.options = get_options(kg_dataset)

        # Copy into a fresh dict: avoids the shared-mutable-default pitfall
        # and keeps later changes to the caller's dict from leaking in.
        self.response_generation_args = dict(response_generation_args or {})

        self.app = Flask(__name__)
        self.app.add_url_rule(
            "/",
            "receive_message",
            self.receive_message,
            methods=["GET", "POST"],
        )
        # A new random key is generated for every server instance, so
        # session cookies issued before a restart are no longer valid.
        self.app.secret_key = uuid.uuid4().hex

    def start(
        self, host: str = "127.0.0.1", port: Union[int, str] = "5005"
    ) -> None:
        """Starts the CRS Flask server (blocks until the server stops).

        Args:
            host: Host address. Defaults to 127.0.0.1.
            port: Port number, as an int or numeric string. Defaults to 5005.
        """
        self._host = host
        self._port = port
        self.app.run(host=host, port=port)

    def receive_message(self) -> Tuple[Union[str, Dict[str, Any]], int]:
        """Receives a message and returns a response.

        Every non-POST request is treated as a health check. This includes
        GET and the HEAD method Werkzeug adds automatically to GET routes.

        Returns:
            A ``(body, status)`` tuple:
            - ``({"response": ...}, 200)`` on success;
            - a plain-text message with 200 for health checks;
            - a plain-text error message with 400 for invalid requests.
        """
        if request.method != "POST":
            return "Model is running.", 200

        # silent=True returns None instead of raising on a missing or
        # malformed JSON body, so such requests reach the 400 path below.
        sender_data = request.get_json(silent=True)
        logger.debug("Received user request:\n%s", sender_data)

        try:
            # Process conversation to create conversation dictionary
            conversation_dict = self._process_sender_data(sender_data)
            state = conversation_dict.pop("state")

            # Get response
            response, new_state = self.crs_model.get_response(
                conversation_dict,
                self.id2entity,
                self.options,
                state,
                **self.response_generation_args,
            )
        except ValueError as e:
            logger.error("Error occurred: %s", e)
            return (
                "An error occurred, make sure you have provided the context"
                " and message.",
                400,
            )

        logger.debug("Generated response: %s", response)
        session["state"] = new_state
        return {"response": response}, 200

    def _process_sender_data(self, sender_data: Any) -> Dict[str, Any]:
        """Processes sender data to create conversation dictionary.

        The conversation dictionary has these keys:
        - context: the previous utterances plus the new message;
        - entity: entities found in those utterances by ``get_entity``;
        - rec, resp, template: always empty, because the model is used for
          inference only; they are kept for compatibility with the models;
        - state: the options state carried over from the session.

        Args:
            sender_data: Decoded JSON body sent by the client (may be None
              or a non-object if the request body was invalid).

        Raises:
            ValueError: If sender data is not a JSON object, lacks context or
              message, or its context is not a list.

        Returns:
            Conversation dictionary.
        """
        if not isinstance(sender_data, dict) or any(
            key not in sender_data for key in ("context", "message")
        ):
            raise ValueError(
                "Invalid sender data. Missing context or message."
            )
        if not isinstance(sender_data["context"], list):
            raise ValueError(
                "Invalid sender data. Context must be a list of utterances."
            )

        context = sender_data["context"] + [sender_data["message"]]

        # Reuse the state stored by the previous turn. Reset it when it is
        # absent or its length no longer matches the number of options.
        # NOTE(review): assumes get_options() returns a sequence whose second
        # item is the list of options — confirm in src.model.utils.
        num_options = len(self.options[1])
        state = session.pop("state", None)
        if state is None or len(state) != num_options:
            state = [0.0] * num_options

        entities = [
            entity
            for utterance in context
            for entity in get_entity(utterance, self.entity_list)
        ]

        return {
            "context": context,
            "entity": entities,
            "rec": [],
            "resp": "",
            "template": [],
            "state": state,
        }


if __name__ == "__main__":
    args = parse_args()
    model_name = args.crs_model

    # Make Python-level randomness reproducible and honour --debug.
    random.seed(args.seed)
    if args.debug:
        logger.setLevel(logging.DEBUG)

    model_args = get_model_args(model_name, args)
    logger.info(f"Loaded arguments for {model_name} model.")
    logger.debug(f"Model arguments:\n{model_args}")

    # Build the model from the selected argument subset.
    crs_model = CRSModel(crs_model=model_name, **model_args)
    logger.info(f"Loaded {model_name} model.")

    # Only UniCRS takes extra generation kwargs: the placeholder token for
    # recommended items, which depends on the knowledge graph dataset.
    if model_name == "unicrs":
        if args.kg_dataset.startswith("redial"):
            movie_token = "<movie>"
        else:
            movie_token = "<mask>"
        response_generation_args = {"movie_token": movie_token}
    else:
        response_generation_args = {}

    # Serve the model over HTTP (blocks until shutdown).
    crs_server = CRSFlaskServer(
        crs_model, args.kg_dataset, response_generation_args
    )
    crs_server.start(args.host, args.port)