DrishtiSharma committed
Commit a7d581f · verified · 1 Parent(s): 2685a62

Delete interim.py

Files changed (1)
  interim.py +0 -184
interim.py DELETED
@@ -1,184 +0,0 @@
- import os
- import streamlit as st
- from dotenv import load_dotenv
- from langchain_openai import ChatOpenAI
- from langchain.agents import AgentExecutor, create_openai_tools_agent
- from langchain_core.messages import BaseMessage, HumanMessage
- from langchain_community.tools.tavily_search import TavilySearchResults
- from langchain_experimental.tools import PythonREPLTool
- from langchain_community.document_loaders import DirectoryLoader, TextLoader
- from langchain.text_splitter import RecursiveCharacterTextSplitter
- from langchain_community.vectorstores import Chroma
- from langchain.embeddings import HuggingFaceBgeEmbeddings
- from langchain_core.output_parsers import StrOutputParser
- from langchain_core.runnables import RunnablePassthrough
- from langchain.output_parsers.openai_functions import JsonOutputFunctionsParser
- from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder
- from langgraph.graph import StateGraph, END
- from langchain_core.documents import Document
- from typing import Annotated, Sequence, TypedDict
- import functools
- import operator
- from langchain_core.tools import tool
- from glob import glob
-
- # Load environment variables
- load_dotenv()
- OPENAI_API_KEY = os.getenv("OPENAI_API_KEY")
- TAVILY_API_KEY = os.getenv("TAVILY_API_KEY")
-
- if not OPENAI_API_KEY or not TAVILY_API_KEY:
-     st.error("Please set OPENAI_API_KEY and TAVILY_API_KEY in your environment variables.")
-     st.stop()
-
- # Initialize the LLM
- llm = ChatOpenAI(model="gpt-4-1106-preview", openai_api_key=OPENAI_API_KEY)
-
- # Utility Functions
- def create_agent(llm: ChatOpenAI, tools: list, system_prompt: str):
-     """Build a tool-calling agent wrapped in an AgentExecutor."""
-     prompt = ChatPromptTemplate.from_messages([
-         ("system", system_prompt),
-         MessagesPlaceholder(variable_name="messages"),
-         MessagesPlaceholder(variable_name="agent_scratchpad"),
-     ])
-     agent = create_openai_tools_agent(llm, tools, prompt)
-     return AgentExecutor(agent=agent, tools=tools)
-
- def agent_node(state, agent, name):
-     """Invoke an agent on the shared state and wrap its output as a named message."""
-     result = agent.invoke(state)
-     return {"messages": [HumanMessage(content=result["output"], name=name)]}
-
- @tool
- def RAG(state):
-     """Use this tool to execute RAG. If the question is related to Japan or Sports, this tool retrieves the results."""
-     st.session_state.outputs.append('-> Calling RAG ->')
-     question = state
-     template = """Answer the question based only on the following context:\n{context}\nQuestion: {question}"""
-     prompt = ChatPromptTemplate.from_template(template)
-     retrieval_chain = (
-         {"context": retriever, "question": RunnablePassthrough()}
-         | prompt
-         | llm
-         | StrOutputParser()
-     )
-     result = retrieval_chain.invoke(question)
-     return result
-
- # Load Tools
- tavily_tool = TavilySearchResults(max_results=5, tavily_api_key=TAVILY_API_KEY)
- python_repl_tool = PythonREPLTool()
-
- # Streamlit UI
- st.title("Multi-Agent with Supervisor")
-
- # Example questions for immediate testing
- example_questions = [
-     "Code hello world and print it to the terminal",
-     "What is James McIlroy aiming for in sports?",
-     "Fetch India's GDP over the past 5 years and draw a line graph.",
-     "Fetch Japan's GDP over the past 4 years from RAG, then draw a line graph.",
- ]
-
- # File Selection Section
- source_files = glob("source/*.txt")
- selected_files = st.multiselect("Select files from the source directory:", source_files, default=source_files[:2])
-
- uploaded_files = st.file_uploader("Or upload your TXT files:", accept_multiple_files=True, type=['txt'])
-
- # Combine Files
- all_docs = []
- if selected_files:
-     for file_path in selected_files:
-         loader = TextLoader(file_path)
-         all_docs.extend(loader.load())
-
- if uploaded_files:
-     for uploaded_file in uploaded_files:
-         content = uploaded_file.read().decode("utf-8")
-         all_docs.append(Document(page_content=content, metadata={"name": uploaded_file.name}))
-
- if not all_docs:
-     st.warning("Please select files from the source directory or upload TXT files.")
-     st.stop()
-
- # Process Documents
- text_splitter = RecursiveCharacterTextSplitter(chunk_size=100, chunk_overlap=10, length_function=len)
- split_docs = text_splitter.split_documents(all_docs)
-
- embeddings = HuggingFaceBgeEmbeddings(model_name="BAAI/bge-base-en-v1.5", model_kwargs={'device': 'cpu'}, encode_kwargs={'normalize_embeddings': True})
- db = Chroma.from_documents(split_docs, embeddings)
- retriever = db.as_retriever(search_kwargs={"k": 4})
-
- # Create Agents
- research_agent = create_agent(llm, [tavily_tool], "You are a web researcher.")
- code_agent = create_agent(llm, [python_repl_tool], "You may generate safe python code to analyze data and generate charts using matplotlib.")
- RAG_agent = create_agent(llm, [RAG], "Use this tool when questions are related to the Japan or Sports categories.")
-
- research_node = functools.partial(agent_node, agent=research_agent, name="Researcher")
- code_node = functools.partial(agent_node, agent=code_agent, name="Coder")
- rag_node = functools.partial(agent_node, agent=RAG_agent, name="RAG")
-
- members = ["RAG", "Researcher", "Coder"]
- system_prompt = (
-     "You are a supervisor managing these workers: {members}. Respond with the next worker or FINISH. "
-     "Use the RAG tool for Japan or Sports questions."
- )
- options = ["FINISH"] + members
- # JSON schema for the "route" function the supervisor is forced to call
- function_def = {
-     "name": "route",
-     "description": "Select the next role.",
-     "parameters": {
-         "title": "routeSchema",
-         "type": "object",
-         "properties": {"next": {"anyOf": [{"enum": options}]}},
-         "required": ["next"],
-     },
- }
- prompt = ChatPromptTemplate.from_messages([
-     ("system", system_prompt),
-     MessagesPlaceholder(variable_name="messages"),
-     ("system", "Given the conversation above, who should act next? Select one of: {options}"),
- ]).partial(options=str(options), members=", ".join(members))
-
- supervisor_chain = (prompt | llm.bind_functions(functions=[function_def], function_call="route") | JsonOutputFunctionsParser())
-
- # Workflow
- class AgentState(TypedDict):
-     messages: Annotated[Sequence[BaseMessage], operator.add]
-     next: str
-
- workflow = StateGraph(AgentState)
- workflow.add_node("Researcher", research_node)
- workflow.add_node("Coder", code_node)
- workflow.add_node("RAG", rag_node)
- workflow.add_node("supervisor", supervisor_chain)
-
- # Every worker reports back to the supervisor, which then routes or finishes
- for member in members:
-     workflow.add_edge(member, "supervisor")
- conditional_map = {k: k for k in members}
- conditional_map["FINISH"] = END
- workflow.add_conditional_edges("supervisor", lambda x: x["next"], conditional_map)
- workflow.set_entry_point("supervisor")
- graph = workflow.compile()
-
- # Workflow Execution
- if 'outputs' not in st.session_state:
-     st.session_state.outputs = []
-
- user_input = st.text_area("Enter your task or question:", placeholder=example_questions[0])
-
- def run_workflow(task):
-     st.session_state.outputs.clear()
-     st.session_state.outputs.append(f"User Input: {task}")
-     for state in graph.stream({"messages": [HumanMessage(content=task)]}):
-         if "__end__" not in state:
-             st.session_state.outputs.append(str(state))
-             st.session_state.outputs.append("----")
-
- if st.button("Run Workflow"):
-     if user_input:
-         run_workflow(user_input)
-     else:
-         st.warning("Please enter a task or question.")
-
- st.subheader("Example Questions:")
- for example in example_questions:
-     st.text(f"- {example}")
-
- st.subheader("Workflow Output:")
- for output in st.session_state.outputs:
-     st.text(output)