File size: 4,103 Bytes
b599481
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
6fe92ef
b599481
1220eda
6fe92ef
 
1220eda
6fe92ef
 
b599481
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
"""CRS Fighter.

This class represents a CRS fighter. A CRS fighter has a fighter id (i.e., 1
or 2), a name (i.e., model name), and a CRS. The CRS is loaded using the
model name and configuration file.
"""

import json
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple

from utils import get_crs_model

from src.model.utils import get_entity, get_options

if TYPE_CHECKING:
    from battle_manager import Message


class CRSFighter:
    """A CRS model wrapped as one of the two fighters.

    On construction the model is loaded with ``get_crs_model``, the
    entity-to-id mapping of the model's ``kg_dataset`` is read from disk, and
    the options are obtained with ``get_options``.
    """

    def __init__(self, fighter_id: int, name: str, config_path: str) -> None:
        """Initializes CRS fighter.

        Args:
            fighter_id: Fighter id (1 or 2).
            name: Model name.
            config_path: Path to the model configuration file.

        Raises:
            ValueError: If fighter_id is not 1 or 2.
        """
        # Validate before doing any loading work.
        if fighter_id not in (1, 2):
            raise ValueError("Fighter id must be 1 or 2.")

        self.fighter_id = fighter_id

        self.name = name
        self.config_path = config_path
        self.model = get_crs_model(self.name, self.config_path)

        # Load entity data (needs self.model to know which dataset to read).
        self._load_entity_data()

        # Load options
        self.options = get_options(self.model.crs_model.kg_dataset)

        # Extra keyword arguments forwarded to model.get_response() by reply().
        self.response_generation_args: Dict[str, Any] = {}
        # Models whose first underscore-separated name segment contains
        # "unicrs" receive an additional movie_token argument.
        if "unicrs" in self.name.split("_")[0]:
            self.response_generation_args.update(
                {
                    "movie_token": "<pad>",
                }
            )

    def _load_entity_data(self) -> None:
        """Loads entity data.

        Reads ``data/<kg_dataset>/entity2id.json`` (path relative to the
        current working directory) and sets three attributes: ``entity2id``
        (the parsed JSON mapping), ``id2entity`` (the inverse mapping with
        integer keys), and ``entity_list`` (the entity names).
        """
        with open(
            f"data/{self.model.crs_model.kg_dataset}/entity2id.json",
            "r",
            encoding="utf-8",
        ) as f:
            self.entity2id = json.load(f)

        # Ids are coerced to int so the reverse lookup is keyed by integers.
        self.id2entity = {int(v): k for k, v in self.entity2id.items()}
        self.entity_list = list(self.entity2id.keys())

    def _process_user_input(
        self, input_message: str, history: List["Message"]
    ) -> Dict[str, Any]:
        """Processes user input.

        The conversation dictionary contains the following keys: context,
        entity, rec, resp, and template. Context is the list of utterances
        taken from the history followed by the new input message, entity is
        the list of entities found in those utterances, rec is the
        recommended items, resp is the response generated by the model, and
        template is the context with masked entities.
        Note that rec, resp, and template are empty as the model is used for
        inference only, they are kept for compatibility with the models.

        Args:
            input_message: User input message.
            history: Conversation history; the "message" item of each entry
              is used as an utterance.

        Returns:
            Processed user input.
        """
        context = [m["message"] for m in history] + [input_message]

        # Entities are gathered utterance by utterance, in conversation
        # order; results of different utterances are not deduplicated here.
        entities = []
        for utterance in context:
            entities.extend(get_entity(utterance, self.entity_list))

        return {
            "context": context,
            "entity": entities,
            "rec": [],
            "resp": "",
            "template": [],
        }

    def reply(
        self,
        input_message: str,
        history: List["Message"],
        options_state: Optional[List[float]],
    ) -> Tuple[str, List[float]]:
        """Generates a reply to the user input.

        Args:
            input_message: User input message.
            history: Conversation history.
            options_state: State of the options. If it is None or its length
              differs from ``len(self.options[1])``, it is replaced by a
              list of zeros of that length.

        Returns:
            Generated response and updated state, as returned by the model's
            ``get_response``.
        """
        # Process conversation to create conversation dictionary
        conversation_dict = self._process_user_input(input_message, history)

        # Start from an all-zero state when none is given or when the given
        # one does not line up with the current options.
        num_options = len(self.options[1])
        if options_state is None or len(options_state) != num_options:
            options_state = [0.0] * num_options

        # Get response
        response, state = self.model.get_response(
            conversation_dict,
            self.id2entity,
            self.options,
            options_state,
            **self.response_generation_args,
        )
        return response, state