import logging
from dataclasses import dataclass, field
from functools import lru_cache

import numpy as np
import pandas as pd
from openai.embeddings_utils import cosine_similarity, get_embedding

from buster.completers import get_completer
from buster.formatter import (
    Response,
    ResponseFormatter,
    Source,
    response_formatter_factory,
)
from buster.retriever import Retriever

logger = logging.getLogger(__name__)
logging.basicConfig(level=logging.INFO)


@dataclass
class BusterConfig:
    """Configuration object for a chatbot.

    documents_file: path to the file containing the documents and their embeddings.
    embedding_model: OpenAI model to use to get embeddings.
    top_k: max number of documents to retrieve, ordered by cosine similarity.
    thresh: cosine similarity threshold a document must exceed to count as a match.
    max_words: maximum number of words the retrieved documents can total; they get truncated otherwise.
    unknown_threshold: similarity to the "I don't know" embedding above which a completion is discarded (set to 0 to deactivate).
    completer_cfg: configuration for the completer, including the text to prompt GPT with before the user prompt (but after the documentation) and the kwargs for the OpenAI completion method.
    separator: the separator to use, can be either "\n" or "<p>" depending on rendering.
    response_format: the type of format to render links with, e.g. slack or markdown.
    unknown_prompt: prompt used to generate the "I don't know" embedding to compare responses to.
    response_footnote: generic footnote to add to the chatbot's reply.
    source: the source of the documents to consider.
    """

    documents_file: str = "buster/data/document_embeddings.tar.gz"
    embedding_model: str = "text-embedding-ada-002"
    top_k: int = 3
    thresh: float = 0.7
    max_words: int = 3000
    unknown_threshold: float = 0.9  # set to 0 to deactivate
    completer_cfg: dict = field(
        # TODO: put all this in its own config with sane defaults?
        default_factory=lambda: {
            "name": "GPT3",
            "text_before_documents": "You are a chatbot answering questions.\n",
            "text_before_prompt": "Answer the following question:\n",
            "completion_kwargs": {
                "engine": "text-davinci-003",
                "max_tokens": 200,
                "temperature": None,
                "top_p": None,
                "frequency_penalty": 1,
                "presence_penalty": 1,
            },
        }
    )
    response_format: str = "slack"
    unknown_prompt: str = "I don't know how to answer your question."
    response_footnote: str = "I'm a bot 🤖 and not always perfect."
    source: str = ""

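# Illustrative only: the field names come from BusterConfig above, the values
# are arbitrary. A config that retrieves more documents with a looser
# similarity cutoff and renders links in markdown could look like:
#
#   cfg = BusterConfig(top_k=5, thresh=0.6, response_format="markdown")
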
class Buster:
    def __init__(self, cfg: BusterConfig, retriever: Retriever):
        self._unk_embedding = None
        self.cfg = cfg
        self.update_cfg(cfg)
        self.retriever = retriever

    @property
    def unk_embedding(self):
        return self._unk_embedding

    @unk_embedding.setter
    def unk_embedding(self, embedding):
        logger.info("Setting new UNK embedding...")
        self._unk_embedding = embedding

    def update_cfg(self, cfg: BusterConfig):
        """Every time a new config is set, update everything that depends on it."""
        logger.info(f"Updating config to {cfg.source}:\n{cfg}")
        self.cfg = cfg
        self.completer = get_completer(cfg.completer_cfg)
        self.unk_embedding = self.get_embedding(self.cfg.unknown_prompt, engine=self.cfg.embedding_model)
        self.response_formatter = response_formatter_factory(
            format=self.cfg.response_format, response_footnote=self.cfg.response_footnote
        )
        logger.info("Config updated.")

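    # Note (illustrative): update_cfg() can also be called on a live instance to
    # hot-swap sources or models; it recomputes the "I don't know" embedding and
    # rebuilds the completer and response formatter, e.g.:
    #
    #   buster.update_cfg(BusterConfig(source="some_source"))
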
    @lru_cache
    def get_embedding(self, query: str, engine: str):
        logger.info("generating embedding")
        return get_embedding(query, engine=engine)

    def rank_documents(
        self,
        query: str,
        top_k: int,
        thresh: float,
        engine: str,
        source: str,
    ) -> pd.DataFrame:
        """Compare the question to the series of documents and return the best matching documents."""
        query_embedding = self.get_embedding(
            query,
            engine=engine,
        )
        matched_documents = self.retriever.retrieve(query_embedding, top_k=top_k, source=source)

        # log matched_documents to the console
        logger.info(f"matched documents before thresh: {matched_documents}")

        # filter out matched_documents below the similarity threshold
        if thresh:
            matched_documents = matched_documents[matched_documents.similarity > thresh]
            logger.info(f"matched documents after thresh: {matched_documents}")

        return matched_documents

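    # Illustrative call (assumes the retriever returns a DataFrame with
    # "content" and "similarity" columns, as used above and in
    # prepare_documents below):
    #
    #   docs = buster.rank_documents(
    #       query="How do I install the library?",
    #       top_k=3,
    #       thresh=0.7,
    #       engine="text-embedding-ada-002",
    #       source="",
    #   )
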
""" query_embedding = self.get_embedding( query, engine=engine, ) matched_documents = self.retriever.retrieve(query_embedding, top_k=top_k, source=source) # log matched_documents to the console logger.info(f"matched documents before thresh: {matched_documents}") # filter out matched_documents using a threshold if thresh: matched_documents = matched_documents[matched_documents.similarity > thresh] logger.info(f"matched documents after thresh: {matched_documents}") return matched_documents def prepare_documents(self, matched_documents: pd.DataFrame, max_words: int) -> str: # gather the documents in one large plaintext variable documents_list = matched_documents.content.to_list() documents_str = "" for idx, doc in enumerate(documents_list): documents_str += f" {doc} <\DOCUMENT>" # truncate the documents to fit # TODO: increase to actual token count word_count = len(documents_str.split(" ")) if word_count > max_words: logger.info("truncating documents to fit...") documents_str = " ".join(documents_str.split(" ")[0:max_words]) logger.info(f"Documents after truncation: {documents_str}") return documents_str def add_sources( self, matched_documents: pd.DataFrame, ): sources = ( Source( source=dct["source"], title=dct["title"], url=dct["url"], question_similarity=dct["similarity"] * 100 ) for dct in matched_documents.to_dict(orient="records") ) return sources def check_response_relevance( self, completion: str, engine: str, unk_embedding: np.array, unk_threshold: float ) -> bool: """Check to see if a response is relevant to the chatbot's knowledge or not. We assume we've prompt-engineered our bot to say a response is unrelated to the context if it isn't relevant. Here, we compare the embedding of the response to the embedding of the prompt-engineered "I don't know" embedding. set the unk_threshold to 0 to essentially turn off this feature. """ response_embedding = self.get_embedding( completion, engine=engine, ) score = cosine_similarity(response_embedding, unk_embedding) logger.info(f"UNK score: {score}") # Likely that the answer is meaningful, add the top sources return score < unk_threshold def process_input(self, user_input: str, formatter: ResponseFormatter = None) -> str: """ Main function to process the input question and generate a formatted output. """ logger.info(f"User Input:\n{user_input}") # We make sure there is always a newline at the end of the question to avoid completing the question. if not user_input.endswith("\n"): user_input += "\n" matched_documents = self.rank_documents( query=user_input, top_k=self.cfg.top_k, thresh=self.cfg.thresh, engine=self.cfg.embedding_model, source=self.cfg.source, ) if len(matched_documents) == 0: response = Response(self.cfg.unknown_prompt) sources = tuple() return self.response_formatter(response, sources) # generate a completion documents: str = self.prepare_documents(matched_documents, max_words=self.cfg.max_words) response: Response = self.completer.generate_response(user_input, documents) logger.info(f"GPT Response:\n{response.text}") sources = self.add_sources(matched_documents) # check for relevance relevant = self.check_response_relevance( completion=response.text, engine=self.cfg.embedding_model, unk_embedding=self.unk_embedding, unk_threshold=self.cfg.unknown_threshold, ) if not relevant: # answer generated was the chatbot saying it doesn't know how to answer # override completion with generic "I don't know" response = Response(text=self.cfg.unknown_prompt) sources = tuple() return self.response_formatter(response, sources)