"""CRS Fighter.
This class represents a CRS fighter. A CRS fighter has a fighter id (i.e., 1
or 2), a name (i.e., model name), and a CRS. The CRS is loaded using the
model name and configuration file.
"""
import json
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple
from utils import get_crs_model
from src.model.utils import get_entity, get_options
if TYPE_CHECKING:
from battle_manager import Message
class CRSFighter:
def __init__(self, fighter_id: int, name: str, config_path: str) -> None:
"""Initializes CRS fighter.
Args:
fighter_id: Fighter id (1 or 2).
name: Model name.
config: Model configuration file.
Raises:
ValueError: If id is not 1 or 2.
"""
if fighter_id not in [1, 2]:
raise ValueError("Fighter id must be 1 or 2.")
self.fighter_id = fighter_id
self.name = name
self.config_path = config_path
self.model = get_crs_model(self.name, self.config_path)
# Load entity data
self._load_entity_data()
# Load options
self.options = get_options(self.model.crs_model.kg_dataset)
# Generation arguments
        self.response_generation_args: Dict[str, Any] = {}
        # UniCRS-based models take an additional movie token argument for
        # response generation.
        if "unicrs" in self.name.split("_")[0]:
            self.response_generation_args.update(
                {
                    "movie_token": "<pad>",
                }
            )
    def _load_entity_data(self) -> None:
        """Loads entity data.

        Reads the entity-to-id mapping of the model's knowledge graph dataset
        and builds the reverse mapping and the list of entity names.
        """
with open(
f"data/{self.model.crs_model.kg_dataset}/entity2id.json",
"r",
encoding="utf-8",
) as f:
self.entity2id = json.load(f)
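        # Illustrative mapping shape (actual entity names depend on the KG
        # dataset): {"Entity A": 0, "Entity B": 1, ...}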
self.id2entity = {int(v): k for k, v in self.entity2id.items()}
self.entity_list = list(self.entity2id.keys())
def _process_user_input(
self, input_message: str, history: List["Message"]
) -> Dict[str, Any]:
"""Processes user input.
The conversation dictionary contains the following keys: context,
entity, rec, and resp. Context is a list of the previous utterances,
entity is a list of entities mentioned in the conversation, rec is the
recommended items, resp is the response generated by the model, and
template is the context with masked entities.
Note that rec, resp, and template are empty as the model is used for
inference only, they are kept for compatibility with the models.
Args:
input_message: User input message.
history: Conversation history.
Returns:
Processed user input.
"""
context = [m["message"] for m in history] + [input_message]
entities = []
for utterance in context:
utterance_entities = get_entity(utterance, self.entity_list)
entities.extend(utterance_entities)
return {
"context": context,
"entity": entities,
"rec": [],
"resp": "",
"template": [],
}
def reply(
self,
input_message: str,
history: List["Message"],
options_state: Optional[List[float]],
) -> Tuple[str, List[float]]:
"""Generates a reply to the user input.
Args:
input_message: User input message.
history: Conversation history.
options_state: State of the options.
Returns:
Generated response and updated state.
"""
# Process conversation to create conversation dictionary
conversation_dict = self._process_user_input(input_message, history)
        # Reset the options state if it is missing or does not match the
        # current number of options.
        if options_state is None or len(options_state) != len(self.options[1]):
            options_state = [0.0] * len(self.options[1])
# Get response
response, state = self.model.get_response(
conversation_dict,
self.id2entity,
self.options,
options_state,
**self.response_generation_args,
)
return response, state
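

# Minimal usage sketch, assuming a deployed CRS model named "unicrs_redial"
# with a matching configuration file; both names are hypothetical.
if __name__ == "__main__":
    fighter = CRSFighter(
        fighter_id=1,
        name="unicrs_redial",
        config_path="config/unicrs_redial.yaml",
    )
    response, options_state = fighter.reply(
        "Can you recommend a good thriller?",
        history=[],
        options_state=None,
    )
    print(response)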