# CRSArena/src/model/crs_model.py
# Nolwenn, "Initial commit", b599481, 1.86 kB
import sys
from typing import Any, Dict, List, Tuple
sys.path.append("..")
from src.model.BARCOR import BARCOR
from src.model.CHATGPT import CHATGPT
from src.model.CRB_CRS import CRBCRSModel
from src.model.KBRD import KBRD
from src.model.UNICRS import UNICRS
# Registry of the CRS back ends that CRSModel can instantiate, keyed by the
# lowercase name passed as its ``crs_model`` argument.
name2class: Dict[str, type] = {
    "kbrd": KBRD,
    "barcor": BARCOR,
    "unicrs": UNICRS,
    "chatgpt": CHATGPT,
    "crbcrs": CRBCRSModel,
}
class CRSModel:
    """Thin wrapper around one of the CRS back ends listed in ``name2class``.

    The wrapper instantiates the selected back end once and forwards every
    method call to it unchanged, so callers can use any registered model
    through the same set of methods.
    """

    def __init__(self, crs_model, *args, **kwargs) -> None:
        """Instantiates the back end registered under ``crs_model``.

        Args:
            crs_model: Key of ``name2class`` selecting the back end.
            *args: Positional arguments forwarded to the back end constructor.
            **kwargs: Keyword arguments forwarded to the back end constructor.

        Raises:
            KeyError: If ``crs_model`` is not a key of ``name2class``.
        """
        try:
            model_class = name2class[crs_model]
        except KeyError:
            # Still a KeyError so existing handlers keep working, but the
            # message now tells the caller which names are registered.
            raise KeyError(
                f"Unknown CRS model {crs_model!r}; "
                f"available models: {list(name2class)}"
            ) from None
        self.crs_model = model_class(*args, **kwargs)

    def get_rec(self, conv_dict: Dict[str, Any]):
        """Generates recommendations given a conversation context.

        Args:
            conv_dict: Conversation context.

        Returns:
            Whatever the wrapped model's ``get_rec`` returns.
        """
        return self.crs_model.get_rec(conv_dict)

    def get_conv(self, conv_dict: Dict[str, Any]):
        """Generates utterance given a conversation context.

        Args:
            conv_dict: Conversation context.

        Returns:
            Whatever the wrapped model's ``get_conv`` returns.
        """
        return self.crs_model.get_conv(conv_dict)

    def get_response(
        self,
        conv_dict: Dict[str, Any],
        id2entity: Dict[int, str],
        options: Tuple[str, Dict[str, str]],
        state: List[float],
        **kwargs
    ) -> Tuple[str, List[float]]:
        """Generates a response given a conversation context.

        Args:
            conv_dict: Conversation context.
            id2entity: Mapping from entity id to entity name.
            options: Prompt with options and dictionary of options.
            state: State of the option choices.
            **kwargs: Extra keyword arguments forwarded to the wrapped model.

        Returns:
            Generated response and updated state.
        """
        return self.crs_model.get_response(
            conv_dict, id2entity, options, state, **kwargs
        )

    def get_choice(self, gen_inputs, option, state, conv_dict=None):
        """Generates a choice between options given a conversation context.

        All four arguments are passed positionally to the wrapped model's
        ``get_choice``; ``conv_dict`` defaults to ``None`` when omitted.
        """
        return self.crs_model.get_choice(gen_inputs, option, state, conv_dict)