|
import sys |
|
from typing import Any, Dict, List, Tuple |
|
|
|
sys.path.append("..") |
|
|
|
from src.model.BARCOR import BARCOR |
|
from src.model.CHATGPT import CHATGPT |
|
from src.model.CRB_CRS import CRBCRSModel |
|
from src.model.KBRD import KBRD |
|
from src.model.UNICRS import UNICRS |
|
|
|
# Registry of selectable CRS backends: the lowercase identifier accepted by
# CRSModel(crs_model=...) mapped to the implementation class it instantiates.
name2class = dict(
    kbrd=KBRD,
    barcor=BARCOR,
    unicrs=UNICRS,
    chatgpt=CHATGPT,
    crbcrs=CRBCRSModel,
)
|
|
|
|
|
class CRSModel:
    """Facade that selects a CRS backend by name and delegates to it.

    The concrete backend class is looked up in ``name2class`` and instantiated
    once in ``__init__``. Each public method forwards its arguments unchanged
    to the method of the same name on that instance (``self.crs_model``) and
    returns its result.
    """

    def __init__(self, crs_model: str, *args, **kwargs) -> None:
        """Instantiates the backend registered under ``crs_model``.

        Args:
            crs_model: Key into ``name2class`` (e.g. ``"kbrd"``, ``"chatgpt"``).
            *args: Positional arguments forwarded to the backend constructor.
            **kwargs: Keyword arguments forwarded to the backend constructor.

        Raises:
            KeyError: If ``crs_model`` is not a registered backend name.
        """
        try:
            model_class = name2class[crs_model]
        except KeyError:
            # Keep the same exception type callers may already catch, but make
            # the message actionable by listing the valid backend names.
            raise KeyError(
                f"Unknown CRS model {crs_model!r}; "
                f"available models: {sorted(name2class)}"
            ) from None
        self.crs_model = model_class(*args, **kwargs)

    def get_rec(self, conv_dict: Dict[str, Any]):
        """Generates recommendations given a conversation context.

        Args:
            conv_dict: Conversation context, passed through to the backend.

        Returns:
            Whatever the backend's ``get_rec`` returns.
        """
        return self.crs_model.get_rec(conv_dict)

    def get_conv(self, conv_dict: Dict[str, Any]):
        """Generates an utterance given a conversation context.

        Args:
            conv_dict: Conversation context, passed through to the backend.

        Returns:
            Whatever the backend's ``get_conv`` returns.
        """
        return self.crs_model.get_conv(conv_dict)

    def get_response(
        self,
        conv_dict: Dict[str, Any],
        id2entity: Dict[int, str],
        options: Tuple[str, Dict[str, str]],
        state: List[float],
        **kwargs
    ) -> Tuple[str, List[float]]:
        """Generates a response given a conversation context.

        Args:
            conv_dict: Conversation context.
            id2entity: Mapping from entity id to entity name.
            options: Prompt with options and dictionary of options.
            state: State of the option choices.
            **kwargs: Extra keyword arguments forwarded to the backend.

        Returns:
            Generated response and updated state.
        """
        return self.crs_model.get_response(
            conv_dict, id2entity, options, state, **kwargs
        )

    def get_choice(self, gen_inputs, option, state, conv_dict=None):
        """Generates a choice between options given a conversation context.

        Args:
            gen_inputs: Generation inputs, forwarded as-is to the backend.
            option: Option(s) to choose from, forwarded as-is.
            state: Current choice state, forwarded as-is.
            conv_dict: Optional conversation context; defaults to ``None``.

        Returns:
            Whatever the backend's ``get_choice`` returns.
        """
        # conv_dict is passed positionally, matching the original call shape.
        return self.crs_model.get_choice(gen_inputs, option, state, conv_dict)
|
|