VanguardAI committed (verified)
Commit c7c3138 · Parent(s): 8318c4a

Update app.py

Files changed (1): app.py (+36 -63)
app.py CHANGED
@@ -8,10 +8,7 @@ from transformers import AutoModel, AutoTokenizer
 from diffusers import StableDiffusion3Pipeline
 from parler_tts import ParlerTTSForConditionalGeneration
 import soundfile as sf
-from langchain.agents import AgentExecutor, create_react_agent
-from langchain.tools import BaseTool
-from langchain_groq import ChatGroq
-from langchain.agents import AgentExecutor, initialize_agent, Tool
+from langchain.agents import AgentExecutor, create_react_agent, initialize_agent, Tool
 from langchain.agents import AgentType
 from langchain_groq import ChatGroq
 from langchain.prompts import PromptTemplate
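A note on the import consolidation: `from langchain.tools import BaseTool` is dropped here, while the tool classes below switch from subclassing `BaseTool` to subclassing `Tool`. In classic pre-0.1 LangChain, `Tool` is designed to be instantiated around a function, whereas `BaseTool` is the documented base class for custom tools; a bare `NumpyCodeCalculator()` that subclasses `Tool` may fail at construction because `Tool.__init__` expects `name`, `func`, and `description`. A minimal sketch of both documented patterns, assuming the same pre-0.1 `langchain` API this file imports (the `word_count` helper is hypothetical):

```python
from langchain.agents import Tool
from langchain.tools import BaseTool

def word_count(query: str) -> str:
    """Hypothetical helper standing in for a real tool function."""
    return str(len(query.split()))

# Pattern 1: wrap a function in Tool -- instantiated, not subclassed.
counter_tool = Tool(
    name="Counter",
    func=word_count,
    description="Useful for counting the words in a text",
)

# Pattern 2: subclass BaseTool and implement _run -- the pattern
# this commit removes.
class CounterTool(BaseTool):
    name = "Counter"
    description = "Useful for counting the words in a text"

    def _run(self, query: str) -> str:
        return word_count(query)
```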
@@ -56,9 +53,9 @@ def play_voice_output(response):
     return "output.wav"
 
 # NumPy Code Calculator Tool
-class NumpyCodeCalculator(BaseTool):
+class NumpyCodeCalculator(Tool):
     name = "Numpy"
-    description = "Useful for performing numpy computations"
+    description = "Useful only for performing numerical computations, not for general searches"
 
     def _run(self, query: str) -> str:
         try:
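The body of `NumpyCodeCalculator._run` is elided by this hunk; the surrounding `try`/`except` suggests it evaluates the query as a numpy expression. If so, narrowing the eval namespace is worth considering. A minimal sketch under that assumption (`safe_numpy_eval` is a hypothetical helper, not part of this commit; an empty `__builtins__` narrows, but does not eliminate, the risks of `eval`):

```python
import numpy as np

def safe_numpy_eval(expression: str) -> str:
    """Evaluate a numpy expression with only `np` in scope."""
    try:
        # Expose numpy alone; block builtins such as open and __import__.
        result = eval(expression, {"__builtins__": {}}, {"np": np})
        return str(result)
    except Exception as e:
        return f"Error: {e}"

# Example: safe_numpy_eval("np.mean(np.arange(10))") returns "4.5"
```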
@@ -70,16 +67,16 @@ class NumpyCodeCalculator(BaseTool):
             return f"Error: {e}"
 
 # Web Search Tool
-class WebSearch(BaseTool):
+class WebSearch(Tool):
     name = "Web"
-    description = "Useful for searching the web for information"
+    description = "Useful for advanced web searching beyond general information"
 
     def _run(self, query: str) -> str:
         answer = tavily_client.qna_search(query=query)
         return answer
 
 # Image Generation Tool
-class ImageGeneration(BaseTool):
+class ImageGeneration(Tool):
     name = "Image"
     description = "Useful for generating images based on text descriptions"
 
@@ -94,7 +91,7 @@ class ImageGeneration(BaseTool):
         return "output.jpg"
 
 # Document Question Answering Tool
-class DocumentQuestionAnswering(BaseTool):
+class DocumentQuestionAnswering(Tool):
     name = "Document"
     description = "Useful for answering questions about a specific document"
 
@@ -122,8 +119,8 @@ class DocumentQuestionAnswering(BaseTool):
         response = self.qa_chain.run(query)
         return str(response)
 
-class DuckDuckGoSearchRun(BaseTool):
-    name = "DuckDuckGo"
+class DuckDuckGoSearchRun(Tool):
+    name = "Search"
     description = "Useful for searching the internet for general information"
 
     def _run(self, query: str) -> str:
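One caveat on this `_run` (its body is visible as context in the next hunk): the DuckDuckGo Instant Answer API leaves `Abstract` empty for many queries, so the tool will often hand the agent an empty string. A hedged sketch of a fallback; the endpoint and fields are the public Instant Answer API, while the fallback order is only a suggestion:

```python
import requests

def duckduckgo_answer(query: str) -> str:
    """Query the DuckDuckGo Instant Answer API with simple fallbacks."""
    response = requests.get(
        "https://api.duckduckgo.com/",
        params={"q": query, "format": "json", "no_html": 1},
        timeout=10,
    )
    data = response.json()
    if data.get("Abstract"):
        return data["Abstract"]
    # Many queries only populate RelatedTopics, a list of dicts with "Text".
    for topic in data.get("RelatedTopics", []):
        if isinstance(topic, dict) and topic.get("Text"):
            return topic["Text"]
    return "No instant answer found."
```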
@@ -136,75 +133,52 @@ class DuckDuckGoSearchRun(BaseTool):
         data = response.json()
         answer = data["Abstract"]
         return answer
-
-# Function to handle different input types and choose the right tool
+
 # Function to handle different input types and choose the right tool
-def handle_input(user_prompt, image=None, audio=None, websearch=False, document=None):
-    # Initialize the search tool
-    search = DuckDuckGoSearchRun()
+def handle_input(user_prompt, image=None, audio=None, voice_only=False, websearch=False, document=None):
+    # Initialize the LLM
+    llm = ChatGroq(model=MODEL, api_key=os.environ.get("GROQ_API_KEY"))
 
+    # Initialize tools
     tools = [
-        Tool(
-            name="Search",
-            func=search.run,
-            description="Useful for searching the internet for general information"
-        ),
-        Tool(
-            name="Image",
-            func=ImageGeneration()._run,
-            description="Useful for generating images based on text descriptions"
-        ),
+        DuckDuckGoSearchRun(),
+        ImageGeneration(),
+        NumpyCodeCalculator(),
     ]
 
-    # Add the numpy tool, but with a more specific description
-    tools.append(Tool(
-        name="Numpy",
-        func=NumpyCodeCalculator()._run,
-        description="Useful only for performing numerical computations, not for general searches"
-    ))
-
     # Add the web search tool only if websearch mode is enabled
     if websearch:
-        tools.append(Tool(
-            name="Web",
-            func=WebSearch()._run,
-            description="Useful for advanced web searching beyond general information"
-        ))
+        tools.append(WebSearch())
 
     # Add the document question answering tool only if a document is provided
     if document:
-        tools.append(Tool(
-            name="Document",
-            func=DocumentQuestionAnswering(document)._run,
-            description="Useful for answering questions about a specific document"
-        ))
+        tools.append(DocumentQuestionAnswering(document))
 
-    llm = ChatGroq(model=MODEL, api_key=os.environ.get("GROQ_API_KEY"))
+    # Handle voice input
+    if voice_only and audio:
+        # TODO: Implement Whisper integration for voice-to-text
+        user_prompt = "Whisper transcription of audio" # Replace with actual transcription
+
+    # Handle image and text input
+    if image and user_prompt:
+        image = Image.open(image).convert('RGB')
+        messages = [{"role": "user", "content": [image, user_prompt]}]
+        response = vqa_model.chat(image=None, msgs=messages, tokenizer=tokenizer)
+        return response
 
     # Check if the input requires any tools
-    requires_tool = False
-    for tool in tools:
-        if tool.name.lower() in user_prompt.lower():
-            requires_tool = True
-            break
-
-    if image or audio or requires_tool:
-        # Initialize the agent
+    requires_tool = any(tool.name.lower() in user_prompt.lower() for tool in tools)
+
+    # Use agent if tools are required, otherwise use LLM directly
+    if requires_tool:
         agent = initialize_agent(
             tools,
             llm,
             agent=AgentType.ZERO_SHOT_REACT_DESCRIPTION,
             verbose=True
         )
+        response = agent.run(user_prompt)
     else:
-        # If no tools are required, use the LLM directly
         response = llm.call(query=user_prompt)
 
     return response
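Two things in the rewritten `handle_input` may warrant a follow-up. First, `requires_tool` only fires when a tool's name literally appears in the prompt, so a prompt like "what is 17 factorial" would bypass the Numpy tool entirely. Second, if `llm` is `langchain_groq.ChatGroq`, it exposes LangChain's Runnable interface rather than a `call(query=...)` method, so the else-branch likely raises `AttributeError`. A minimal sketch of the direct path (the model name is an example; app.py's `MODEL` constant should be used):

```python
import os
from langchain_groq import ChatGroq

# Example model name only; substitute the MODEL constant from app.py.
llm = ChatGroq(model="llama3-70b-8192", api_key=os.environ.get("GROQ_API_KEY"))

# invoke() accepts a prompt string (or message list) and returns an
# AIMessage whose text lives on .content.
response = llm.invoke("Summarize the water cycle in one sentence.")
print(response.content)
```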
@@ -420,7 +394,6 @@ def create_ui():
 
     return demo
 
-# Main interface function
 @spaces.GPU(duration=180)
 def main_interface(user_prompt, image=None, audio=None, voice_only=False, websearch=False, document=None):
     print("Starting main_interface function")
@@ -431,7 +404,7 @@ def main_interface(user_prompt, image=None, audio=None, voice_only=False, websearch=False, document=None):
     print(f"user_prompt: {user_prompt}, image: {image}, audio: {audio}, voice_only: {voice_only}, websearch: {websearch}, document: {document}")
 
     try:
-        response = handle_input(user_prompt, image=image, audio=audio, websearch=websearch, document=document)
+        response = handle_input(user_prompt, image=image, audio=audio, voice_only=voice_only, websearch=websearch, document=document)
         print("handle_input function executed successfully")
     except Exception as e:
         print(f"Error in handle_input: {e}")
 