"""This module contains functions for loading a ConversationalRetrievalChain"""
## May 24 move to langchain_community because everything deprecated
import logging
import wandb
from langchain.chains import ConversationalRetrievalChain
##from langchain.chat_models import ChatOpenAI
from langchain_community.chat_models import ChatOpenAI
##from langchain.embeddings import OpenAIEmbeddings
from langchain_community.embeddings import OpenAIEmbeddings
## deprectated from langchain.vectorstores import Chroma
from langchain_community.vectorstores import Chroma
from prompts import load_chat_prompt
import pathlib
logger = logging.getLogger(__name__)

def load_vector_store(wandb_run: wandb.run, openai_api_key: str) -> Chroma:
    """Load a vector store from a Weights & Biases artifact

    Args:
        wandb_run (wandb.run): An active Weights & Biases run
        openai_api_key (str): The OpenAI API key to use for embedding

    Returns:
        Chroma: A Chroma vector store object
    """
# load vector store artifact
vector_store_artifact_dir = wandb_run.use_artifact(
wandb_run.config.vector_store_artifact, type="search_index"
).download()
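    # The embeddings used for querying must match those used to build the index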
embedding_fn = OpenAIEmbeddings(openai_api_key=openai_api_key)
# load vector store
vector_store = Chroma(
embedding_function=embedding_fn, persist_directory=vector_store_artifact_dir
)
return vector_store

def load_chain(wandb_run: wandb.run, vector_store: Chroma, openai_api_key: str):
    """Load a ConversationalRetrievalChain from a config and a vector store

    Args:
        wandb_run (wandb.run): An active Weights & Biases run
        vector_store (Chroma): A Chroma vector store object
        openai_api_key (str): The OpenAI API key to use for embedding

    Returns:
        ConversationalRetrievalChain: A ConversationalRetrievalChain object
    """
retriever = vector_store.as_retriever()
llm = ChatOpenAI(
openai_api_key=openai_api_key,
model_name=wandb_run.config.model_name,
temperature=wandb_run.config.chat_temperature,
max_retries=wandb_run.config.max_fallback_retries,
)
chat_prompt_dir = wandb_run.use_artifact(
wandb_run.config.chat_prompt_artifact, type="prompt"
).download()
    qa_prompt = load_chat_prompt(f"{chat_prompt_dir}/chat_prompt_massa.json")
    logger.debug("qa_prompt: %s", qa_prompt)
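    # chain_type="stuff" concatenates all retrieved documents into one prompt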
qa_chain = ConversationalRetrievalChain.from_llm(
llm=llm,
chain_type="stuff",
retriever=retriever,
combine_docs_chain_kwargs={"prompt": qa_prompt},
return_source_documents=True,
)
return qa_chain

def get_answer(
    chain: ConversationalRetrievalChain,
    question: str,
    chat_history: list[tuple[str, str]],
    wandb_run: wandb.run,
):
    """Get an answer from a ConversationalRetrievalChain

    Args:
        chain (ConversationalRetrievalChain): A ConversationalRetrievalChain object
        question (str): The question to ask
        chat_history (list[tuple[str, str]]): A list of tuples of (question, answer)
        wandb_run (wandb.run): An active Weights & Biases run

    Returns:
        str: The answer to the question
    """
    # Configure file logging for user questions
    logging.basicConfig(
        filename="user_input.log",
        level=logging.INFO,
        format="%(asctime)s - %(message)s",
        datefmt="%Y-%m-%d %H:%M:%S",
    )
    # Log the user question locally and to Weights & Biases
    logging.info(f"User question: {question}")
    wandb.log({"question": question})
    # Run the chain on the question and chat history
result = chain(
inputs={"question": question, "chat_history": chat_history},
return_only_outputs=True,
)
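    # result also contains "source_documents" because the chain was built with
    # return_source_documents=True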
response = f"Answer:\t{result['answer']}"
print( "file name"+ wandb_run.config.log_file)
f_name = wandb_run.config.log_file
#if isinstance(f_name, str) and f_name:
# f_name = pathlib.Path(f_name)
# with open(f_name, "w") as file1:
# Writing data to a file
# file1.write("Hello \n")
#if f_name and f_name.is_file():
# ret = f_name.write("r"))
# if f_name and f_name.is_file():
##template = json.load(f_name.open("r"))
print("File writing complete."+"quest = "+question+" answer : "+ result['answer'])
return response
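
# --- Minimal usage sketch (illustrative, not part of the original module) ---
# Assumes an active W&B run whose config supplies the keys referenced above;
# the project name and artifact references below are placeholders.
if __name__ == "__main__":
    import os

    run = wandb.init(
        project="conversational-qa",  # hypothetical project name
        config={
            "vector_store_artifact": "entity/project/vector_store:latest",  # placeholder
            "chat_prompt_artifact": "entity/project/chat_prompt:latest",  # placeholder
            "model_name": "gpt-3.5-turbo",
            "chat_temperature": 0.0,
            "max_fallback_retries": 1,
            "log_file": "qa_history.log",
        },
    )
    api_key = os.environ["OPENAI_API_KEY"]
    store = load_vector_store(run, api_key)
    qa_chain = load_chain(run, store, api_key)
    print(get_answer(qa_chain, "What does this module do?", [], run))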