DrishtiSharma commited on
Commit
2685a62
·
verified ·
1 Parent(s): 8e66140

Create interim_v1.py

Browse files
Files changed (1) hide show
  1. interim_v1.py +164 -0
interim_v1.py ADDED
@@ -0,0 +1,164 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import chromadb
3
+ import streamlit as st
4
+ from dotenv import load_dotenv
5
+ from langchain_openai import ChatOpenAI
6
+ from langchain.agents import AgentExecutor, create_openai_tools_agent
7
+ from langchain_core.messages import BaseMessage, HumanMessage
8
+ from langchain_community.tools.tavily_search import TavilySearchResults
9
+ from langchain_experimental.tools import PythonREPLTool
10
+ from langchain_community.document_loaders import DirectoryLoader
11
+ from langchain.text_splitter import RecursiveCharacterTextSplitter
12
+ from langchain_community.vectorstores import Chroma
13
+ from langchain.embeddings import HuggingFaceBgeEmbeddings
14
+ from langchain_core.output_parsers import StrOutputParser
15
+ from langchain_core.runnables import RunnablePassthrough
16
+ from langchain.output_parsers.openai_functions import JsonOutputFunctionsParser
17
+ from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder
18
+ from langgraph.graph import StateGraph, END
19
+ from langchain_core.documents import Document
20
+ from typing import Annotated, Sequence, TypedDict
21
+ import functools
22
+ import operator
23
+ from langchain_core.tools import tool
24
+
25
+
26
+ # Clear ChromaDB cache to fix tenant issue
27
+ chromadb.api.client.SharedSystemClient.clear_system_cache()
28
+
29
+ # Load environment variables
30
+ OPENAI_API_KEY = os.getenv("OPENAI_API_KEY")
31
+ TAVILY_API_KEY = os.getenv("TAVILY_API_KEY")
32
+
33
+ if not OPENAI_API_KEY or not TAVILY_API_KEY:
34
+ st.error("Please set OPENAI_API_KEY and TAVILY_API_KEY in your environment variables.")
35
+ st.stop()
36
+
37
+ # Initialize API keys and LLM
38
+ llm = ChatOpenAI(model="gpt-4-1106-preview", openai_api_key=OPENAI_API_KEY)
39
+
40
+ # Utility Functions
41
+ def create_agent(llm: ChatOpenAI, tools: list, system_prompt: str):
42
+ prompt = ChatPromptTemplate.from_messages([
43
+ ("system", system_prompt),
44
+ MessagesPlaceholder(variable_name="messages"),
45
+ MessagesPlaceholder(variable_name="agent_scratchpad"),
46
+ ])
47
+ agent = create_openai_tools_agent(llm, tools, prompt)
48
+ return AgentExecutor(agent=agent, tools=tools)
49
+
50
+ def agent_node(state, agent, name):
51
+ result = agent.invoke(state)
52
+ return {"messages": [HumanMessage(content=result["output"], name=name)]}
53
+
54
+ @tool
55
+ def RAG(state):
56
+ """Use this tool to execute RAG. If the question is related to Japan or Sports, this tool retrieves the results."""
57
+ st.session_state.outputs.append('-> Calling RAG ->')
58
+ question = state
59
+ template = """Answer the question based only on the following context:\n{context}\nQuestion: {question}"""
60
+ prompt = ChatPromptTemplate.from_template(template)
61
+ retrieval_chain = (
62
+ {"context": retriever, "question": RunnablePassthrough()} |
63
+ prompt |
64
+ llm |
65
+ StrOutputParser()
66
+ )
67
+ result = retrieval_chain.invoke(question)
68
+ return result
69
+
70
+ # Load Tools and Retriever
71
+ tavily_tool = TavilySearchResults(max_results=5, tavily_api_key=TAVILY_API_KEY)
72
+ python_repl_tool = PythonREPLTool()
73
+
74
+ # File Upload Section
75
+ st.title("Multi-Agent Workflow Demonstration")
76
+ uploaded_files = st.file_uploader("Upload your source files (TXT)", accept_multiple_files=True, type=['txt'])
77
+
78
+ if uploaded_files:
79
+ docs = []
80
+ for uploaded_file in uploaded_files:
81
+ content = uploaded_file.read().decode("utf-8")
82
+ docs.append(Document(page_content=content, metadata={"name": uploaded_file.name}))
83
+ text_splitter = RecursiveCharacterTextSplitter(chunk_size=100, chunk_overlap=10, length_function=len)
84
+ new_docs = text_splitter.split_documents(documents=docs)
85
+ embeddings = HuggingFaceBgeEmbeddings(model_name="BAAI/bge-base-en-v1.5", model_kwargs={'device': 'cpu'}, encode_kwargs={'normalize_embeddings': True})
86
+ db = Chroma.from_documents(new_docs, embeddings)
87
+ retriever = db.as_retriever(search_kwargs={"k": 4})
88
+ else:
89
+ retriever = None
90
+ st.warning("Please upload at least one text file to proceed.")
91
+ st.stop()
92
+
93
+ # Create Agents
94
+ research_agent = create_agent(llm, [tavily_tool], "You are a web researcher.")
95
+ code_agent = create_agent(llm, [python_repl_tool], "You may generate safe python code to analyze data and generate charts using matplotlib.")
96
+ RAG_agent = create_agent(llm, [RAG], "Use this tool when questions are related to Japan or Sports category.")
97
+
98
+ research_node = functools.partial(agent_node, agent=research_agent, name="Researcher")
99
+ code_node = functools.partial(agent_node, agent=code_agent, name="Coder")
100
+ rag_node = functools.partial(agent_node, agent=RAG_agent, name="RAG")
101
+
102
+ members = ["RAG", "Researcher", "Coder"]
103
+ system_prompt = (
104
+ "You are a supervisor managing these workers: {members}. Respond with the next worker or FINISH. "
105
+ "Use RAG tool for Japan or Sports questions."
106
+ )
107
+ options = ["FINISH"] + members
108
+ function_def = {
109
+ "name": "route", "description": "Select the next role.",
110
+ "parameters": {
111
+ "title": "routeSchema", "type": "object",
112
+ "properties": {"next": {"anyOf": [{"enum": options}]}}, "required": ["next"]
113
+ }
114
+ }
115
+ prompt = ChatPromptTemplate.from_messages([
116
+ ("system", system_prompt),
117
+ MessagesPlaceholder(variable_name="messages"),
118
+ ("system", "Given the conversation above, who should act next? Select one of: {options}"),
119
+ ]).partial(options=str(options), members=", ".join(members))
120
+
121
+ supervisor_chain = (prompt | llm.bind_functions(functions=[function_def], function_call="route") | JsonOutputFunctionsParser())
122
+
123
+ # Build Workflow
124
+ class AgentState(TypedDict):
125
+ messages: Annotated[Sequence[BaseMessage], operator.add]
126
+ next: str
127
+
128
+ workflow = StateGraph(AgentState)
129
+ workflow.add_node("Researcher", research_node)
130
+ workflow.add_node("Coder", code_node)
131
+ workflow.add_node("RAG", rag_node)
132
+ workflow.add_node("supervisor", supervisor_chain)
133
+
134
+ for member in members:
135
+ workflow.add_edge(member, "supervisor")
136
+ conditional_map = {k: k for k in members}
137
+ conditional_map["FINISH"] = END
138
+ workflow.add_conditional_edges("supervisor", lambda x: x["next"], conditional_map)
139
+ workflow.set_entry_point("supervisor")
140
+ graph = workflow.compile()
141
+
142
+ # Streamlit UI
143
+ if 'outputs' not in st.session_state:
144
+ st.session_state.outputs = []
145
+
146
+ user_input = st.text_area("Enter your task or question:")
147
+
148
+ def run_workflow(task):
149
+ st.session_state.outputs.clear()
150
+ st.session_state.outputs.append(f"User Input: {task}")
151
+ for state in graph.stream({"messages": [HumanMessage(content=task)]}):
152
+ if "__end__" not in state:
153
+ st.session_state.outputs.append(str(state))
154
+ st.session_state.outputs.append("----")
155
+
156
+ if st.button("Run Workflow"):
157
+ if user_input:
158
+ run_workflow(user_input)
159
+ else:
160
+ st.warning("Please enter a task or question.")
161
+
162
+ st.subheader("Workflow Output:")
163
+ for output in st.session_state.outputs:
164
+ st.text(output)