DrishtiSharma committed on
Commit
ddb5a15
·
verified ·
1 Parent(s): 95c2fc1

Update interim.py

Browse files
Files changed (1) hide show
  1. interim.py +46 -26
interim.py CHANGED
@@ -1,5 +1,4 @@
1
  import os
2
- import chromadb
3
  import streamlit as st
4
  from dotenv import load_dotenv
5
  from langchain_openai import ChatOpenAI
@@ -7,7 +6,7 @@ from langchain.agents import AgentExecutor, create_openai_tools_agent
7
  from langchain_core.messages import BaseMessage, HumanMessage
8
  from langchain_community.tools.tavily_search import TavilySearchResults
9
  from langchain_experimental.tools import PythonREPLTool
10
- from langchain_community.document_loaders import DirectoryLoader
11
  from langchain.text_splitter import RecursiveCharacterTextSplitter
12
  from langchain_community.vectorstores import Chroma
13
  from langchain.embeddings import HuggingFaceBgeEmbeddings
@@ -21,12 +20,10 @@ from typing import Annotated, Sequence, TypedDict
21
  import functools
22
  import operator
23
  from langchain_core.tools import tool
24
-
25
-
26
- # Clear ChromaDB cache to fix tenant issue
27
- chromadb.api.client.SharedSystemClient.clear_system_cache()
28
 
29
  # Load environment variables
 
30
  OPENAI_API_KEY = os.getenv("OPENAI_API_KEY")
31
  TAVILY_API_KEY = os.getenv("TAVILY_API_KEY")
32
 
@@ -67,29 +64,51 @@ def RAG(state):
67
  result = retrieval_chain.invoke(question)
68
  return result
69
 
70
- # Load Tools and Retriever
71
  tavily_tool = TavilySearchResults(max_results=5, tavily_api_key=TAVILY_API_KEY)
72
  python_repl_tool = PythonREPLTool()
73
 
74
- # File Upload Section
75
  st.title("Multi-Agent Workflow Demonstration")
76
- uploaded_files = st.file_uploader("Upload your source files (TXT)", accept_multiple_files=True, type=['txt'])
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
77
 
78
  if uploaded_files:
79
- docs = []
80
  for uploaded_file in uploaded_files:
81
  content = uploaded_file.read().decode("utf-8")
82
- docs.append(Document(page_content=content, metadata={"name": uploaded_file.name}))
83
- text_splitter = RecursiveCharacterTextSplitter(chunk_size=100, chunk_overlap=10, length_function=len)
84
- new_docs = text_splitter.split_documents(documents=docs)
85
- embeddings = HuggingFaceBgeEmbeddings(model_name="BAAI/bge-base-en-v1.5", model_kwargs={'device': 'cpu'}, encode_kwargs={'normalize_embeddings': True})
86
- db = Chroma.from_documents(new_docs, embeddings)
87
- retriever = db.as_retriever(search_kwargs={"k": 4})
88
- else:
89
- retriever = None
90
- st.warning("Please upload at least one text file to proceed.")
91
  st.stop()
92
 
 
 
 
 
 
 
 
 
93
  # Create Agents
94
  research_agent = create_agent(llm, [tavily_tool], "You are a web researcher.")
95
  code_agent = create_agent(llm, [python_repl_tool], "You may generate safe python code to analyze data and generate charts using matplotlib.")
@@ -107,10 +126,7 @@ system_prompt = (
107
  options = ["FINISH"] + members
108
  function_def = {
109
  "name": "route", "description": "Select the next role.",
110
- "parameters": {
111
- "title": "routeSchema", "type": "object",
112
- "properties": {"next": {"anyOf": [{"enum": options}]}}, "required": ["next"]
113
- }
114
  }
115
  prompt = ChatPromptTemplate.from_messages([
116
  ("system", system_prompt),
@@ -120,7 +136,7 @@ prompt = ChatPromptTemplate.from_messages([
120
 
121
  supervisor_chain = (prompt | llm.bind_functions(functions=[function_def], function_call="route") | JsonOutputFunctionsParser())
122
 
123
- # Build Workflow
124
  class AgentState(TypedDict):
125
  messages: Annotated[Sequence[BaseMessage], operator.add]
126
  next: str
@@ -139,11 +155,11 @@ workflow.add_conditional_edges("supervisor", lambda x: x["next"], conditional_ma
139
  workflow.set_entry_point("supervisor")
140
  graph = workflow.compile()
141
 
142
- # Streamlit UI
143
  if 'outputs' not in st.session_state:
144
  st.session_state.outputs = []
145
 
146
- user_input = st.text_area("Enter your task or question:")
147
 
148
  def run_workflow(task):
149
  st.session_state.outputs.clear()
@@ -159,6 +175,10 @@ if st.button("Run Workflow"):
159
  else:
160
  st.warning("Please enter a task or question.")
161
 
 
 
 
 
162
  st.subheader("Workflow Output:")
163
  for output in st.session_state.outputs:
164
  st.text(output)
 
1
  import os
 
2
  import streamlit as st
3
  from dotenv import load_dotenv
4
  from langchain_openai import ChatOpenAI
 
6
  from langchain_core.messages import BaseMessage, HumanMessage
7
  from langchain_community.tools.tavily_search import TavilySearchResults
8
  from langchain_experimental.tools import PythonREPLTool
9
+ from langchain_community.document_loaders import DirectoryLoader, TextLoader
10
  from langchain.text_splitter import RecursiveCharacterTextSplitter
11
  from langchain_community.vectorstores import Chroma
12
  from langchain.embeddings import HuggingFaceBgeEmbeddings
 
20
  import functools
21
  import operator
22
  from langchain_core.tools import tool
23
+ from glob import glob
 
 
 
24
 
25
  # Load environment variables
26
+
27
  OPENAI_API_KEY = os.getenv("OPENAI_API_KEY")
28
  TAVILY_API_KEY = os.getenv("TAVILY_API_KEY")
29
 
 
64
  result = retrieval_chain.invoke(question)
65
  return result
66
 
67
+ # Load Tools
68
  tavily_tool = TavilySearchResults(max_results=5, tavily_api_key=TAVILY_API_KEY)
69
  python_repl_tool = PythonREPLTool()
70
 
71
+ # Streamlit UI
72
  st.title("Multi-Agent Workflow Demonstration")
73
+
74
+ # Example questions for immediate testing
75
+ example_questions = [
76
+ "Code hello world and print it to the terminal",
77
+ "What is James McIlroy aiming for in sports?",
78
+ "Fetch India's GDP over the past 5 years and draw a line graph.",
79
+ "Fetch Japan's GDP over the past 4 years from RAG, then draw a line graph."
80
+ ]
81
+
82
+ # File Selection Section
83
+ source_files = glob("source/*.txt")
84
+ selected_files = st.multiselect("Select files from the source directory:", source_files, default=source_files[:2])
85
+
86
+ uploaded_files = st.file_uploader("Or upload your TXT files:", accept_multiple_files=True, type=['txt'])
87
+
88
+ # Combine Files
89
+ all_docs = []
90
+ if selected_files:
91
+ for file_path in selected_files:
92
+ loader = TextLoader(file_path)
93
+ all_docs.extend(loader.load())
94
 
95
  if uploaded_files:
 
96
  for uploaded_file in uploaded_files:
97
  content = uploaded_file.read().decode("utf-8")
98
+ all_docs.append(Document(page_content=content, metadata={"name": uploaded_file.name}))
99
+
100
+ if not all_docs:
101
+ st.warning("Please select files from the source directory or upload TXT files.")
 
 
 
 
 
102
  st.stop()
103
 
104
+ # Process Documents
105
+ text_splitter = RecursiveCharacterTextSplitter(chunk_size=100, chunk_overlap=10, length_function=len)
106
+ split_docs = text_splitter.split_documents(all_docs)
107
+
108
+ embeddings = HuggingFaceBgeEmbeddings(model_name="BAAI/bge-base-en-v1.5", model_kwargs={'device': 'cpu'}, encode_kwargs={'normalize_embeddings': True})
109
+ db = Chroma.from_documents(split_docs, embeddings)
110
+ retriever = db.as_retriever(search_kwargs={"k": 4})
111
+
112
  # Create Agents
113
  research_agent = create_agent(llm, [tavily_tool], "You are a web researcher.")
114
  code_agent = create_agent(llm, [python_repl_tool], "You may generate safe python code to analyze data and generate charts using matplotlib.")
 
126
  options = ["FINISH"] + members
127
  function_def = {
128
  "name": "route", "description": "Select the next role.",
129
+ "parameters": {"title": "routeSchema", "type": "object", "properties": {"next": {"anyOf": [{"enum": options}]}}, "required": ["next"]}
 
 
 
130
  }
131
  prompt = ChatPromptTemplate.from_messages([
132
  ("system", system_prompt),
 
136
 
137
  supervisor_chain = (prompt | llm.bind_functions(functions=[function_def], function_call="route") | JsonOutputFunctionsParser())
138
 
139
+ # Workflow
140
  class AgentState(TypedDict):
141
  messages: Annotated[Sequence[BaseMessage], operator.add]
142
  next: str
 
155
  workflow.set_entry_point("supervisor")
156
  graph = workflow.compile()
157
 
158
+ # Workflow Execution
159
  if 'outputs' not in st.session_state:
160
  st.session_state.outputs = []
161
 
162
+ user_input = st.text_area("Enter your task or question:", placeholder=example_questions[0])
163
 
164
  def run_workflow(task):
165
  st.session_state.outputs.clear()
 
175
  else:
176
  st.warning("Please enter a task or question.")
177
 
178
+ st.subheader("Example Questions:")
179
+ for example in example_questions:
180
+ st.text(f"- {example}")
181
+
182
  st.subheader("Workflow Output:")
183
  for output in st.session_state.outputs:
184
  st.text(output)