Spaces:
Running
Running
File size: 3,037 Bytes
9d615c0 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 |
"""This module contains functions for loading a ConversationalRetrievalChain"""
import logging
import wandb
from langchain.chains import ConversationalRetrievalChain
from langchain.chat_models import ChatOpenAI
from langchain.embeddings import OpenAIEmbeddings
from langchain.vectorstores import Chroma
from prompts import load_chat_prompt
# Module-level logger named after this module (``__name__``).
logger = logging.getLogger(__name__)
def load_vector_store(wandb_run: wandb.run, openai_api_key: str) -> Chroma:
    """Build a Chroma vector store from an artifact attached to a W&B run

    The artifact named by ``wandb_run.config.vector_store_artifact`` (type
    ``"search_index"``) is downloaded and the download location is handed to
    Chroma as its persist directory.

    Args:
        wandb_run (wandb.run): The Weights & Biases run used to fetch the artifact
        openai_api_key (str): The OpenAI API key passed to the embedding function

    Returns:
        Chroma: A Chroma vector store backed by the downloaded artifact
    """
    # download the artifact before creating the embedding client
    index_artifact = wandb_run.use_artifact(
        wandb_run.config.vector_store_artifact, type="search_index"
    )
    persist_dir = index_artifact.download()

    embeddings = OpenAIEmbeddings(openai_api_key=openai_api_key)
    return Chroma(embedding_function=embeddings, persist_directory=persist_dir)
def load_chain(
    wandb_run: wandb.run,
    vector_store: Chroma,
    openai_api_key: str,
    chat_prompt_file: str = "chat_prompt_massa.json",
) -> ConversationalRetrievalChain:
    """Load a ConversationalQA chain from a config and a vector store

    The chat model settings (``model_name``, ``chat_temperature``,
    ``max_fallback_retries``) and the prompt artifact name
    (``chat_prompt_artifact``) are read from ``wandb_run.config``.

    Args:
        wandb_run (wandb.run): An active Weights & Biases run
        vector_store (Chroma): A Chroma vector store object, used as the retriever
        openai_api_key (str): The OpenAI API key passed to the chat model
        chat_prompt_file (str): Name of the prompt JSON file inside the
            downloaded prompt artifact. Defaults to "chat_prompt_massa.json".

    Returns:
        ConversationalRetrievalChain: A chain configured to also return the
            source documents alongside the answer
    """
    retriever = vector_store.as_retriever()
    llm = ChatOpenAI(
        openai_api_key=openai_api_key,
        model_name=wandb_run.config.model_name,
        temperature=wandb_run.config.chat_temperature,
        max_retries=wandb_run.config.max_fallback_retries,
    )
    chat_prompt_dir = wandb_run.use_artifact(
        wandb_run.config.chat_prompt_artifact, type="prompt"
    ).download()
    qa_prompt = load_chat_prompt(f"{chat_prompt_dir}/{chat_prompt_file}")
    # Report the loaded prompt through the module logger instead of printing
    # to stdout, so callers decide whether this diagnostic output is shown.
    logger.debug("qa_prompt = %s", qa_prompt)
    qa_chain = ConversationalRetrievalChain.from_llm(
        llm=llm,
        chain_type="stuff",
        retriever=retriever,
        combine_docs_chain_kwargs={"prompt": qa_prompt},
        return_source_documents=True,
    )
    return qa_chain
def get_answer(
    chain: ConversationalRetrievalChain,
    question: str,
    chat_history: list[tuple[str, str]],
):
    """Ask the chain a question and format its answer

    Args:
        chain (ConversationalRetrievalChain): The chain to invoke
        question (str): The question to ask
        chat_history (list[tuple[str, str]]): Earlier (question, answer) pairs

    Returns:
        str: The chain's "answer" output, prefixed with "Answer:" and a tab
    """
    payload = {"question": question, "chat_history": chat_history}
    # only the chain's outputs are requested; the inputs are not echoed back
    outputs = chain(inputs=payload, return_only_outputs=True)
    answer = outputs["answer"]
    return f"Answer:\t{answer}"
|