ashfaq93 committed
Commit d4d3d97 · verified · 1 Parent(s): 77c8a00

Upload Rag_conversation.py

Files changed (1)
  1. Rag_conversation.py +119 -0
Rag_conversation.py ADDED
@@ -0,0 +1,119 @@
import os

from dotenv import load_dotenv
from langchain.chains import create_history_aware_retriever, create_retrieval_chain
from langchain.chains.combine_documents import create_stuff_documents_chain
from langchain_community.vectorstores import Chroma
from langchain_core.messages import AIMessage, HumanMessage
from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder
from langchain_openai import ChatOpenAI, OpenAIEmbeddings


# Load environment variables from .env
load_dotenv()

# Define the persistent directory
current_dir = os.path.dirname(os.path.abspath(__file__))
persistent_directory = os.path.join(current_dir, "db", "chroma_db_with_metadata")

# Define the embedding model
embeddings = OpenAIEmbeddings(model="text-embedding-ada-002")

# Load the existing vector store with the embedding function
db = Chroma(persist_directory=persistent_directory, embedding_function=embeddings)

# Create a retriever for querying the vector store
# `search_type` specifies the type of search (e.g., "similarity" or "mmr")
# `search_kwargs` contains additional arguments for the search (e.g., number of results to return)
# A plain similarity retriever would look like this:
# retriever = db.as_retriever(
#     search_type="similarity",
#     search_kwargs={"k": 4},
# )

retriever = db.as_retriever(
    search_type="mmr",  # Maximal Marginal Relevance (MMR) for diversity
    search_kwargs={"k": 4, "fetch_k": 10},  # Fetch 10 candidates, keep the 4 most diverse
)

# Create a ChatOpenAI model
llm = ChatOpenAI(model="gpt-4o", temperature=0.2)


# Contextualize question prompt
# This system prompt tells the model to reformulate the latest user question,
# using the chat history, into a standalone question
contextualize_q_system_prompt = (
    "Given a chat history and the latest user question "
    "which might reference context in the chat history, "
    "formulate a standalone question which can be understood "
    "without the chat history. Do NOT answer the question, just "
    "reformulate it if needed and otherwise return it as is."
)

# Create a prompt template for contextualizing questions
contextualize_q_prompt = ChatPromptTemplate.from_messages(
    [
        ("system", contextualize_q_system_prompt),
        MessagesPlaceholder("chat_history"),
        ("human", "{input}"),
    ]
)

# Create a history-aware retriever
# This uses the LLM to reformulate the question based on chat history before retrieval
history_aware_retriever = create_history_aware_retriever(
    llm, retriever, contextualize_q_prompt
)

# Answer question prompt
# This system prompt instructs the model to answer from the retrieved context
# and to say so when the answer is not in that context
qa_system_prompt = (
    "You are an assistant for answering questions at Binghamton University. "
    "Use the retrieved context to generate a structured response with bullet points where appropriate."
    "\n\n{context}"
    "\n\nIf you don't know the answer, simply state that fact."
)

# Create a prompt template for answering questions
qa_prompt = ChatPromptTemplate.from_messages(
    [
        ("system", qa_system_prompt),
        MessagesPlaceholder("chat_history"),
        ("human", "{input}"),
    ]
)

# Create a chain to combine documents for question answering
# `create_stuff_documents_chain` feeds all retrieved context into the LLM
question_answer_chain = create_stuff_documents_chain(llm, qa_prompt)

# Create a retrieval chain that combines the history-aware retriever and the question answering chain
rag_chain = create_retrieval_chain(history_aware_retriever, question_answer_chain)


# Function to simulate a continual chat
def continual_chat():
    print("Start chatting with the AI! Type 'exit' to end the conversation.")
    chat_history = []  # Collect chat history here (a sequence of messages)
    while True:
        query = input("You: ")
        if query.lower() == "exit":
            break
        # Process the user's query through the retrieval chain
        result = rag_chain.invoke({"input": query, "chat_history": chat_history})
        # Display the AI's response
        print(f"AI: {result['answer']}")
        # Update the chat history; the model's reply belongs in an AIMessage,
        # not a SystemMessage, so later turns read the history correctly
        chat_history.append(HumanMessage(content=query))
        chat_history.append(AIMessage(content=result["answer"]))


# Start the continual chat when run as a script
if __name__ == "__main__":
    continual_chat()
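
For reference, a minimal sketch of a one-off call outside the chat loop (the question text is only an illustration): rag_chain.invoke takes a dict with "input" and "chat_history" keys and returns a dict whose "answer" key holds the generated reply and whose "context" key holds the retrieved documents.

result = rag_chain.invoke({"input": "What dining options are on campus?", "chat_history": []})
print(result["answer"])        # the generated reply
print(len(result["context"]))  # the retrieved Document objects filled into {context}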
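
Note that the script assumes OPENAI_API_KEY is set in the .env file that load_dotenv() reads, and it only opens an existing vector store; the ingestion step that builds db/chroma_db_with_metadata is not part of this commit. A minimal sketch of what that step could look like, assuming a documents/ folder of .txt files (the folder name and loader choice are hypothetical, not from this repo):

import os

from langchain_community.document_loaders import TextLoader
from langchain_community.vectorstores import Chroma
from langchain_openai import OpenAIEmbeddings
from langchain_text_splitters import RecursiveCharacterTextSplitter

current_dir = os.path.dirname(os.path.abspath(__file__))
docs_dir = os.path.join(current_dir, "documents")  # hypothetical source folder
persistent_directory = os.path.join(current_dir, "db", "chroma_db_with_metadata")

# Load each .txt file; TextLoader records the file path in metadata["source"],
# the per-document metadata the store's name refers to
documents = []
for filename in os.listdir(docs_dir):
    if filename.endswith(".txt"):
        documents.extend(TextLoader(os.path.join(docs_dir, filename)).load())

# Split into overlapping chunks so each embedding covers a focused span of text
splitter = RecursiveCharacterTextSplitter(chunk_size=1000, chunk_overlap=100)
chunks = splitter.split_documents(documents)

# Embed with the same model the chat script uses and persist to the same directory
embeddings = OpenAIEmbeddings(model="text-embedding-ada-002")
Chroma.from_documents(chunks, embeddings, persist_directory=persistent_directory)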