DrishtiSharma committed on
Commit
1b2c3c5
·
verified ·
1 Parent(s): 6c4605d

Create interim.py

Browse files
Files changed (1) hide show
  1. interim.py +157 -0
interim.py ADDED
@@ -0,0 +1,157 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import streamlit as st
3
+ from dotenv import load_dotenv
4
+ from langchain_openai import ChatOpenAI
5
+ from langchain.agents import AgentExecutor, create_openai_tools_agent
6
+ from langchain_core.messages import BaseMessage, HumanMessage
7
+ from langchain_community.tools.tavily_search import TavilySearchResults
8
+ from langchain_experimental.tools import PythonREPLTool
9
+ from langchain_community.document_loaders import DirectoryLoader, TextLoader
10
+ from langchain.text_splitter import RecursiveCharacterTextSplitter
11
+ from langchain_community.vectorstores import Chroma
12
+ from langchain.embeddings import HuggingFaceBgeEmbeddings
13
+ from langchain_core.output_parsers import StrOutputParser
14
+ from langchain_core.runnables import RunnablePassthrough
15
+ from langchain.output_parsers.openai_functions import JsonOutputFunctionsParser
16
+ from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder
17
+ from langgraph.graph import StateGraph, END
18
+ from typing import Annotated, Sequence, TypedDict
19
+ from langchain_core.tools import tool
20
+ import functools
21
+ import operator
22
+
23
# Load environment variables.
# BUG FIX: load_dotenv was imported but never called, so keys placed in a
# local .env file were invisible and only pre-set process environment
# variables were honored. Call it before reading the keys.
load_dotenv()

OPENAI_API_KEY = os.getenv("OPENAI_API_KEY")
TAVILY_API_KEY = os.getenv("TAVILY_API_KEY")

# Fail fast in the UI when either required key is missing; st.stop()
# halts script execution for this Streamlit run.
if not OPENAI_API_KEY or not TAVILY_API_KEY:
    st.error("Please set OPENAI_API_KEY and TAVILY_API_KEY in your environment variables.")
    st.stop()

# Initialize API keys and LLM
# Single ChatOpenAI instance shared by the supervisor and all worker agents.
llm = ChatOpenAI(model="gpt-4-1106-preview", openai_api_key=OPENAI_API_KEY)
33
+
34
# Utility Functions
def create_agent(llm: ChatOpenAI, tools: list, system_prompt: str):
    """Build an AgentExecutor around an OpenAI-tools agent.

    The prompt consists of the role-defining system message followed by the
    running conversation and the agent scratchpad placeholders that
    create_openai_tools_agent expects.
    """
    message_spec = [
        ("system", system_prompt),
        MessagesPlaceholder(variable_name="messages"),
        MessagesPlaceholder(variable_name="agent_scratchpad"),
    ]
    chat_prompt = ChatPromptTemplate.from_messages(message_spec)
    tools_agent = create_openai_tools_agent(llm, tools, chat_prompt)
    return AgentExecutor(agent=tools_agent, tools=tools)
43
+
44
def agent_node(state, agent, name):
    """Invoke *agent* on the graph state and re-wrap its text output.

    The executor's "output" string becomes a HumanMessage attributed to
    *name*, matching the additive `messages` field of AgentState.
    """
    output_text = agent.invoke(state)["output"]
    return {"messages": [HumanMessage(content=output_text, name=name)]}
47
+
48
@tool
def RAG(state):
    """Answer questions related to Japan or the Sports category using the
    uploaded documents via retrieval-augmented generation.

    BUG FIX: LangChain's @tool decorator requires a docstring (it is used as
    the tool's description); without one the decorator raises at import time.
    """
    st.session_state.outputs.append('-> Calling RAG ->')
    question = state
    template = """Answer the question based only on the following context:\n{context}\nQuestion: {question}"""
    prompt = ChatPromptTemplate.from_template(template)
    # `retriever` is a module-level global built after file upload; the app
    # calls st.stop() when no files are uploaded, so it is non-None whenever
    # the workflow (and hence this tool) actually runs.
    retrieval_chain = (
        {"context": retriever, "question": RunnablePassthrough()} |
        prompt |
        llm |
        StrOutputParser()
    )
    result = retrieval_chain.invoke(question)
    return result
62
+
63
# Load Tools and Retriever
tavily_tool = TavilySearchResults(max_results=5, tavily_api_key=TAVILY_API_KEY)
python_repl_tool = PythonREPLTool()

# File Upload Section
st.title("Multi-Agent Workflow Demonstration")
uploaded_files = st.file_uploader("Upload your source files (TXT)", accept_multiple_files=True, type=['txt'])

if uploaded_files:
    # BUG FIX: TextLoader takes a file path and has no `content` keyword, so
    # TextLoader(file_path=None, content=...) raised a TypeError. The uploaded
    # text is already in memory, so build Document objects directly.
    from langchain_core.documents import Document

    docs = []
    for uploaded_file in uploaded_files:
        content = uploaded_file.read().decode("utf-8")
        docs.append(Document(page_content=content, metadata={"source": uploaded_file.name}))
    # Small chunks (100 chars, 10 overlap) keep retrieval granular for short files.
    text_splitter = RecursiveCharacterTextSplitter(chunk_size=100, chunk_overlap=10, length_function=len)
    new_docs = text_splitter.split_documents(documents=docs)
    embeddings = HuggingFaceBgeEmbeddings(model_name="BAAI/bge-base-en-v1.5", model_kwargs={'device': 'cpu'}, encode_kwargs={'normalize_embeddings': True})
    db = Chroma.from_documents(new_docs, embeddings)
    retriever = db.as_retriever(search_kwargs={"k": 4})
else:
    # Without documents the RAG tool cannot work; halt the app until upload.
    retriever = None
    st.warning("Please upload at least one text file to proceed.")
    st.stop()
85
+
86
# Create Agents
# Each worker wraps the shared LLM with one tool and a role-defining system prompt.
research_agent = create_agent(llm, [tavily_tool], "You are a web researcher.")
code_agent = create_agent(llm, [python_repl_tool], "You may generate safe python code to analyze data and generate charts using matplotlib.")
RAG_agent = create_agent(llm, [RAG], "Use this tool when questions are related to Japan or Sports category.")

# Bind each executor into a graph-node callable; `name` labels the
# HumanMessage each node appends so the supervisor can see who spoke.
research_node = functools.partial(agent_node, agent=research_agent, name="Researcher")
code_node = functools.partial(agent_node, agent=code_agent, name="Coder")
rag_node = functools.partial(agent_node, agent=RAG_agent, name="RAG")
94
+
95
# Worker node names the supervisor can route to.
members = ["RAG", "Researcher", "Coder"]
# {members} is filled in via .partial() below.
system_prompt = (
    "You are a supervisor managing these workers: {members}. Respond with the next worker or FINISH. "
    "Use RAG tool for Japan or Sports questions."
)
# "FINISH" terminates the graph; any member name routes to that worker.
options = ["FINISH"] + members
# OpenAI function schema forcing the model to answer with a single
# `next` field constrained to one of `options`.
function_def = {
    "name": "route", "description": "Select the next role.",
    "parameters": {
        "title": "routeSchema", "type": "object",
        "properties": {"next": {"anyOf": [{"enum": options}]}}, "required": ["next"]
    }
}
prompt = ChatPromptTemplate.from_messages([
    ("system", system_prompt),
    MessagesPlaceholder(variable_name="messages"),
    ("system", "Given the conversation above, who should act next? Select one of: {options}"),
]).partial(options=str(options), members=", ".join(members))

# function_call="route" forces the model to call the routing function, and the
# JSON parser yields {"next": <choice>} for the conditional edge below.
supervisor_chain = (prompt | llm.bind_functions(functions=[function_def], function_call="route") | JsonOutputFunctionsParser())
115
+
116
# Build Workflow
class AgentState(TypedDict):
    # `messages` accumulates across nodes (operator.add concatenates lists);
    # `next` carries the supervisor's routing decision.
    messages: Annotated[Sequence[BaseMessage], operator.add]
    next: str

workflow = StateGraph(AgentState)
for node_name, node_fn in (
    ("Researcher", research_node),
    ("Coder", code_node),
    ("RAG", rag_node),
    ("supervisor", supervisor_chain),
):
    workflow.add_node(node_name, node_fn)

# Every worker reports back to the supervisor after acting.
for worker in members:
    workflow.add_edge(worker, "supervisor")

# Supervisor routes to the chosen worker, or ends the run on "FINISH".
routing = {name: name for name in members}
routing["FINISH"] = END
workflow.add_conditional_edges("supervisor", lambda state: state["next"], routing)
workflow.set_entry_point("supervisor")
graph = workflow.compile()
134
+
135
# Streamlit UI
if 'outputs' not in st.session_state:
    st.session_state.outputs = []

user_input = st.text_area("Enter your task or question:")

def run_workflow(task):
    """Stream the compiled graph for *task*, logging every intermediate state.

    Previous outputs are cleared so each run shows only its own trace.
    """
    log = st.session_state.outputs
    log.clear()
    log.append(f"User Input: {task}")
    for state in graph.stream({"messages": [HumanMessage(content=task)]}):
        if "__end__" not in state:
            log.append(str(state))
            log.append("----")

if st.button("Run Workflow"):
    if not user_input:
        st.warning("Please enter a task or question.")
    else:
        run_workflow(user_input)

st.subheader("Workflow Output:")
for output in st.session_state.outputs:
    st.text(output)