|
"""Contextual Retrieval-based CRS. |
|
|
|
This CRS comprises two main components: a retriever and a recommender. The |
|
retriever is responsible for retrieving and ranking a set of responses from a |
|
pre-defined corpus. The recommender is responsible for recommending items |
|
whenever the retrieved responses contains an item placeholder. |
|
|
|
Adapted from original code: |
|
https://github.com/ahtsham58/CRB-CRS/tree/main |
|
""" |
|
|
|
import logging
import os
import re
from typing import Any, Dict, List, Optional, Tuple

import nltk
from nltk.tokenize import word_tokenize

nltk.download("punkt_tab")

from src.model.crb_crs.recommender import Recommender
from src.model.crb_crs.retriever.mle_model import NGramMLE
from src.model.crb_crs.retriever.retriever import Retriever
|
|
class CRBCRSModel:
    """Contextual retrieval-based CRS with a retriever and a recommender."""

    def __init__(
        self,
        dataset: str,
        domain: str,
        corpus_folder: str,
        mle_model_path: str,
        recommender_path: str,
    ) -> None:
        """Initializes the CRB-CRS model.

        Args:
            dataset: Dataset name.
            domain: Domain of application.
            corpus_folder: Path to the folder containing the corpus.
            mle_model_path: Path to the MLE model.
            recommender_path: Path to the recommender model.

        Raises:
            FileNotFoundError: If the MLE model path does not exist.
        """
        if not os.path.exists(mle_model_path):
            raise FileNotFoundError(
                f"MLE model path {mle_model_path} does not exist."
            )
        mle_model = NGramMLE.load(mle_model_path)
        self.kg_dataset = dataset

        self.retriever = Retriever(corpus_folder, mle_model, dataset, domain)
        self.recommender = Recommender.load(recommender_path)
|
    def get_rec(self, conv_dict: Dict[str, Any]):
        """Generates recommendations given a conversation context.

        Not implemented: recommendations are produced within get_response.
        """
        pass

    def get_conv(self, conv_dict: Dict[str, Any]):
        """Generates an utterance given a conversation context.

        Not implemented: response generation is handled by get_response.
        """
        pass
|
    def get_response(
        self,
        conv_dict: Dict[str, Any],
        id2entity: Optional[Dict[int, str]] = None,
        options: Optional[Tuple[str, Dict[str, str]]] = None,
        state: Optional[List[float]] = None,
    ) -> Tuple[str, Optional[List[float]]]:
        """Generates a response given a conversation context.

        Args:
            conv_dict: Conversation context.
            id2entity (not used): Mapping from entity id to entity name.
                Defaults to None.
            options (not used): Prompt with options and dictionary of options.
                Defaults to None.
            state (not used): State of the option choices. Defaults to None.

        Returns:
            Generated response and state (always None for this model).
        """
|
context = conv_dict["context"] |
|
last_user_utterance = context[-1] |
|
last_user_utterance_tokens = word_tokenize(last_user_utterance) |
|
candidate_responses = [] |
|
|
|
|
|
candidate_responses.extend(self._get_candidates([last_user_utterance])) |
|
|
|
if len(context) > 1: |
|
|
|
|
|
candidate_responses.extend(self._get_candidates(context[-2:])) |
|
|
|
if len(context) > 2: |
|
|
|
|
|
candidate_responses.extend(self._get_candidates(context[-3:])) |
|
|
|
if len(context) > 3: |
|
|
|
candidate_responses.extend(self._get_candidates(context)) |
|
|
|
ranked_candidates = self.retriever.rank_candidates( |
|
last_user_utterance_tokens, candidate_responses |
|
) |
|
retrieved_response = ranked_candidates[0] |
|
|
|
|
|
recommended_items = [] |
|
original_item_ids = self.get_item_ids_from_retrieved_response( |
|
retrieved_response |
|
) |
|
if original_item_ids: |
|
recommended_items = self.recommender.get_recommendations(context) |
|
|
|
|
|
response = self.recommender.replace_item_ids_with_recommendations( |
|
retrieved_response, original_item_ids, recommended_items |
|
) |
|
|
|
|
|
try: |
|
response = self.recommender.integrate_domain_metadata( |
|
context, response |
|
) |
|
except Exception as e: |
|
logging.error(f"Error while integrating metadata: {e}") |
|
|
|
response = self.retriever.remove_utterance_prefix(response) |
|
return response, None |
|
|
|
    def get_item_ids_from_retrieved_response(self, response: str) -> List[str]:
        """Extracts item ids from a retrieved response.

        Args:
            response: Retrieved response.

        Returns:
            List of item ids.
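
        Example:
            Illustrative only, assuming a constructed ``model`` and
            ReDial-style "@<item id>" placeholders in the response:

            >>> model.get_item_ids_from_retrieved_response(
            ...     "Have you seen @12345 ?"
            ... )
            ['12345']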
|
""" |
|
if self.kg_dataset == "redial": |
|
return [ |
|
re.sub(r"[@?]", "", id) for id in re.findall(r"@\S+", response) |
|
] |
|
return [] |
|
|
|
    def _get_candidates(self, context: List[str]) -> List[str]:
        """Gets candidate responses based on the conversation context."""
        input_query = self.retriever.build_query(context)
        candidates = self.retriever.retrieve_candidates(context=input_query)
        return self.retriever.filter_outliers_from_candidates(candidates)
|
    def get_choice(self, gen_inputs, option, state, conv_dict=None):
        """Generates a choice between options given a conversation context.

        This method is not implemented in this class because the
        recommendation stage is already conditioned on the conversation
        context.
        """
        pass
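

if __name__ == "__main__":
    # Minimal usage sketch (illustrative only): the paths below are
    # hypothetical placeholders, not artifacts shipped with this module.
    model = CRBCRSModel(
        dataset="redial",
        domain="movies",
        corpus_folder="data/crb_crs/corpus",
        mle_model_path="models/crb_crs/mle_model.pkl",
        recommender_path="models/crb_crs/recommender.pkl",
    )
    # The conversation context is a list of utterances, most recent last.
    conv_dict = {"context": ["Hi!", "I am looking for a good sci-fi movie."]}
    response, _ = model.get_response(conv_dict)
    print(response)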
|
|