Spaces:
Running
Running
"""This module contains functions for loading a ConversationalRetrievalChain""" | |
## May 24: moved to langchain_community because the old langchain imports are deprecated
import logging | |
import wandb | |
from langchain.chains import ConversationalRetrievalChain | |
##from langchain.chat_models import ChatOpenAI | |
from langchain_community.chat_models import ChatOpenAI | |
##from langchain.embeddings import OpenAIEmbeddings | |
from langchain_community.embeddings import OpenAIEmbeddings | |
## deprecated: from langchain.vectorstores import Chroma
from langchain_community.vectorstores import Chroma | |
from prompts import load_chat_prompt | |
import pathlib | |
logger = logging.getLogger(__name__) | |
def load_vector_store(wandb_run: wandb.run, openai_api_key: str) -> Chroma:
    """Build a Chroma vector store from a Weights & Biases artifact.

    Args:
        wandb_run (wandb.run): An active Weights & Biases run. Its config
            names the artifact to use under ``vector_store_artifact``.
        openai_api_key (str): The OpenAI API key to use for embedding

    Returns:
        Chroma: A chroma vector store object whose persist directory is the
            downloaded artifact directory
    """
    # Resolve the "search_index" artifact named in the run config, then
    # download it to a local directory.
    artifact = wandb_run.use_artifact(
        wandb_run.config.vector_store_artifact, type="search_index"
    )
    persist_dir = artifact.download()

    # Open the store on the downloaded directory, embedding with OpenAI.
    return Chroma(
        embedding_function=OpenAIEmbeddings(openai_api_key=openai_api_key),
        persist_directory=persist_dir,
    )
def load_chain(
    wandb_run: wandb.run, vector_store: Chroma, openai_api_key: str
) -> ConversationalRetrievalChain:
    """Load a ConversationalQA chain from a config and a vector store.

    Args:
        wandb_run (wandb.run): An active Weights & Biases run. Its config
            supplies ``model_name``, ``chat_temperature``,
            ``max_fallback_retries`` and ``chat_prompt_artifact``.
        vector_store (Chroma): A Chroma vector store object, used as the
            chain's retriever
        openai_api_key (str): The OpenAI API key to use for the chat model

    Returns:
        ConversationalRetrievalChain: A ConversationalRetrievalChain object
            configured to return its source documents
    """
    retriever = vector_store.as_retriever()
    llm = ChatOpenAI(
        openai_api_key=openai_api_key,
        model_name=wandb_run.config.model_name,
        temperature=wandb_run.config.chat_temperature,
        max_retries=wandb_run.config.max_fallback_retries,
    )

    # The prompt template ships as a "prompt" artifact; download it and load
    # the JSON file it contains.
    chat_prompt_dir = wandb_run.use_artifact(
        wandb_run.config.chat_prompt_artifact, type="prompt"
    ).download()
    qa_prompt = load_chat_prompt(f"{chat_prompt_dir}/chat_prompt_massa.json")

    # Debug output. Real newlines are used here: the previous doubled
    # backslashes printed the two characters "\" and "n" instead.
    print("\n===================\nqa_prompt = ", qa_prompt)

    qa_chain = ConversationalRetrievalChain.from_llm(
        llm=llm,
        chain_type="stuff",
        retriever=retriever,
        combine_docs_chain_kwargs={"prompt": qa_prompt},
        return_source_documents=True,
    )
    return qa_chain
def get_answer(
    chain: ConversationalRetrievalChain,
    question: str,
    chat_history: list[tuple[str, str]],
    wandb_run: wandb.run,
) -> str:
    """Get an answer from a ConversationalRetrievalChain.

    The question is logged through the ``logging`` module (configured to
    write to ``user_input.log``) and to Weights & Biases before the chain
    is run.

    Args:
        chain (ConversationalRetrievalChain): A ConversationalRetrievalChain object
        question (str): The question to ask
        chat_history (list[tuple[str, str]]): A list of tuples of (question, answer)
        wandb_run (wandb.run): An active Weights & Biases run. The question
            is logged to it and its config supplies ``log_file``.

    Returns:
        str: The chain's answer, prefixed with "Answer:" and a tab character
    """
    # basicConfig only configures the root logger while it has no handlers,
    # so repeating this call on every question is harmless.
    logging.basicConfig(
        filename="user_input.log",
        level=logging.INFO,
        format="%(asctime)s - %(message)s",
        datefmt="%Y-%m-%d %H:%M:%S",
    )

    # Log user question (lazy %-formatting instead of an eager f-string).
    logging.info("User question: %s", question)
    # Log to the run that was passed in, consistent with the other helpers
    # in this module, rather than to whichever run happens to be global.
    wandb_run.log({"question": question})

    # Run the chain on the question plus the conversation so far.
    result = chain(
        inputs={"question": question, "chat_history": chat_history},
        return_only_outputs=True,
    )
    answer = result["answer"]
    response = f"Answer:\t{answer}"

    # Debug output. f-strings are used so that a non-str log_file or answer
    # cannot raise TypeError after the model call has already been made.
    print(f"file name{wandb_run.config.log_file}")
    # NOTE(review): nothing is written to the configured log file here even
    # though the message says so -- confirm whether file output is still wanted.
    print(f"File writing complete.quest = {question} answer : {answer}")
    return response