bruno16 commited on
Commit
9d615c0
1 Parent(s): bfda76e

Upload chain.py

Browse files
Files changed (1) hide show
  1. chain.py +88 -0
chain.py ADDED
@@ -0,0 +1,88 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """This module contains functions for loading a ConversationalRetrievalChain"""
2
+
3
+ import logging
4
+
5
+ import wandb
6
+ from langchain.chains import ConversationalRetrievalChain
7
+ from langchain.chat_models import ChatOpenAI
8
+ from langchain.embeddings import OpenAIEmbeddings
9
+ from langchain.vectorstores import Chroma
10
+ from prompts import load_chat_prompt
11
+
12
+
13
+ logger = logging.getLogger(__name__)
14
+
15
+
16
+ def load_vector_store(wandb_run: wandb.run, openai_api_key: str) -> Chroma:
17
+ """Load a vector store from a Weights & Biases artifact
18
+ Args:
19
+ run (wandb.run): An active Weights & Biases run
20
+ openai_api_key (str): The OpenAI API key to use for embedding
21
+ Returns:
22
+ Chroma: A chroma vector store object
23
+ """
24
+ # load vector store artifact
25
+ vector_store_artifact_dir = wandb_run.use_artifact(
26
+ wandb_run.config.vector_store_artifact, type="search_index"
27
+ ).download()
28
+ embedding_fn = OpenAIEmbeddings(openai_api_key=openai_api_key)
29
+ # load vector store
30
+ vector_store = Chroma(
31
+ embedding_function=embedding_fn, persist_directory=vector_store_artifact_dir
32
+ )
33
+
34
+ return vector_store
35
+
36
+
37
+ def load_chain(wandb_run: wandb.run, vector_store: Chroma, openai_api_key: str):
38
+ """Load a ConversationalQA chain from a config and a vector store
39
+ Args:
40
+ wandb_run (wandb.run): An active Weights & Biases run
41
+ vector_store (Chroma): A Chroma vector store object
42
+ openai_api_key (str): The OpenAI API key to use for embedding
43
+ Returns:
44
+ ConversationalRetrievalChain: A ConversationalRetrievalChain object
45
+ """
46
+ retriever = vector_store.as_retriever()
47
+ llm = ChatOpenAI(
48
+ openai_api_key=openai_api_key,
49
+ model_name=wandb_run.config.model_name,
50
+ temperature=wandb_run.config.chat_temperature,
51
+ max_retries=wandb_run.config.max_fallback_retries,
52
+ )
53
+ chat_prompt_dir = wandb_run.use_artifact(
54
+ wandb_run.config.chat_prompt_artifact, type="prompt"
55
+ ).download()
56
+ qa_prompt = load_chat_prompt(f"{chat_prompt_dir}/chat_prompt_massa.json")
57
+
58
+ print ( '\\n===================\\nqa_prompt = ', qa_prompt)
59
+
60
+ qa_chain = ConversationalRetrievalChain.from_llm(
61
+ llm=llm,
62
+ chain_type="stuff",
63
+ retriever=retriever,
64
+ combine_docs_chain_kwargs={"prompt": qa_prompt},
65
+ return_source_documents=True,
66
+ )
67
+ return qa_chain
68
+
69
+
70
+ def get_answer(
71
+ chain: ConversationalRetrievalChain,
72
+ question: str,
73
+ chat_history: list[tuple[str, str]],
74
+ ):
75
+ """Get an answer from a ConversationalRetrievalChain
76
+ Args:
77
+ chain (ConversationalRetrievalChain): A ConversationalRetrievalChain object
78
+ question (str): The question to ask
79
+ chat_history (list[tuple[str, str]]): A list of tuples of (question, answer)
80
+ Returns:
81
+ str: The answer to the question
82
+ """
83
+ result = chain(
84
+ inputs={"question": question, "chat_history": chat_history},
85
+ return_only_outputs=True,
86
+ )
87
+ response = f"Answer:\t{result['answer']}"
88
+ return response