Last commit not found
import sys | |
from typing import Any, Dict, List, Tuple | |
sys.path.append("..") | |
from src.model.BARCOR import BARCOR | |
from src.model.CHATGPT import CHATGPT | |
from src.model.CRB_CRS import CRBCRSModel | |
from src.model.KBRD import KBRD | |
from src.model.UNICRS import UNICRS | |
# Registry of supported CRS backends: the lowercase model name passed to
# ``CRSModel`` maps to the class that implements it.
name2class: Dict[str, type] = dict(
    kbrd=KBRD,
    barcor=BARCOR,
    unicrs=UNICRS,
    chatgpt=CHATGPT,
    crbcrs=CRBCRSModel,
)
class CRSModel:
    """Uniform facade over the supported conversational recommender models.

    The concrete backend class is looked up by name in ``name2class`` and
    instantiated once; every public method forwards to that instance, which
    is exposed as ``self.crs_model``.
    """

    def __init__(self, crs_model: str, *args: Any, **kwargs: Any) -> None:
        """Instantiates the backend registered under ``crs_model``.

        Args:
            crs_model: Key into ``name2class`` selecting the backend.
            *args: Positional arguments forwarded to the backend constructor.
            **kwargs: Keyword arguments forwarded to the backend constructor.

        Raises:
            KeyError: If ``crs_model`` is not a registered model name.
        """
        try:
            model_class = name2class[crs_model]
        except KeyError:
            # Keep the KeyError type for existing callers, but say which
            # names are accepted instead of echoing only the bad key.
            supported = ", ".join(sorted(name2class))
            raise KeyError(
                f"Unknown CRS model {crs_model!r}; "
                f"supported models: {supported}"
            ) from None
        self.crs_model = model_class(*args, **kwargs)

    def get_rec(self, conv_dict: Dict[str, Any]) -> Any:
        """Generates recommendations given a conversation context.

        Args:
            conv_dict: Conversation context.

        Returns:
            Whatever the backend's ``get_rec`` returns.
        """
        return self.crs_model.get_rec(conv_dict)

    def get_conv(self, conv_dict: Dict[str, Any]) -> Any:
        """Generates utterance given a conversation context.

        Args:
            conv_dict: Conversation context.

        Returns:
            Whatever the backend's ``get_conv`` returns.
        """
        return self.crs_model.get_conv(conv_dict)

    def get_response(
        self,
        conv_dict: Dict[str, Any],
        id2entity: Dict[int, str],
        options: Tuple[str, Dict[str, str]],
        state: List[float],
        **kwargs: Any,
    ) -> Tuple[str, List[float]]:
        """Generates a response given a conversation context.

        Args:
            conv_dict: Conversation context.
            id2entity: Mapping from entity id to entity name.
            options: Prompt with options and dictionary of options.
            state: State of the option choices.
            **kwargs: Extra keyword arguments forwarded to the backend.

        Returns:
            Generated response and updated state.
        """
        return self.crs_model.get_response(
            conv_dict, id2entity, options, state, **kwargs
        )

    def get_choice(
        self,
        gen_inputs: Any,
        option: Any,
        state: Any,
        conv_dict: Any = None,
    ) -> Any:
        """Generates a choice between options given a conversation context.

        All arguments are handed to the backend positionally and unchanged.

        Args:
            gen_inputs: Generation inputs for the backend (format is
                backend-specific; not visible from this wrapper).
            option: Options to choose between.
            state: State of the option choices.
            conv_dict: Optional conversation context; ``None`` is passed
                through when omitted.

        Returns:
            Whatever the backend's ``get_choice`` returns.
        """
        return self.crs_model.get_choice(gen_inputs, option, state, conv_dict)