VanguardAI committed
Commit 2a36ff2 · verified · 1 Parent(s): 79549f2

Update app.py
Files changed (1)
app.py (+11 -5)
app.py CHANGED
@@ -3,7 +3,7 @@ import torch
 import os
 import numpy as np
 from groq import Groq
-import spaces # Import spaces
+import spaces
 from transformers import AutoModel, AutoTokenizer
 from diffusers import StableDiffusionXLPipeline, UNet2DConditionModel, EulerDiscreteScheduler
 from parler_tts import ParlerTTSForConditionalGeneration
@@ -13,7 +13,6 @@ from langchain_community.vectorstores import Chroma
 from langchain.text_splitter import RecursiveCharacterTextSplitter
 from langchain.chains import RetrievalQA, LLMChain
 from langchain.agents import ZeroShotAgent, Tool, AgentExecutor
-from langchain.llms import Groq as GroqLlm # Import GroqLlm
 from PIL import Image
 from decord import VideoReader, cpu
 from tavily import TavilyClient
@@ -24,7 +23,6 @@ from safetensors.torch import load_file
 # Initialize models and clients
 client = Groq(api_key=os.environ.get("GROQ_API_KEY"))
 MODEL = 'llama3-groq-70b-8192-tool-use-preview'
-llm = GroqLlm(client=client, model=MODEL) # Initialize GroqLlm
 
 vqa_model = AutoModel.from_pretrained('openbmb/MiniCPM-V-2', trust_remote_code=True,
                                       device_map="auto", torch_dtype=torch.bfloat16)
@@ -103,7 +101,7 @@ def doc_question_answering(query, file_path):
     return qa.run(query)
 
 # Function to handle different input types and choose the right tool
-def handle_input(user_prompt, image=None, audio=None, doc=None, websearch=False):
+def handle_input(user_prompt, image=None, video=None, audio=None, doc=None, websearch=False):
     # Voice input handling
     if audio:
         # Make sure 'audio' is a file object
@@ -144,8 +142,16 @@ def handle_input(user_prompt, image=None, audio=None, doc=None, websearch=False)
         )
     )
 
+    # Function for the agent's LLM
+    def llm_function(query):
+        response = client.chat.completions.create(
+            model=MODEL,
+            messages=[{"role": "user", "content": query}]
+        )
+        return response.choices[0].message.content
+
     # Initialize agent
-    agent = ZeroShotAgent(llm_chain=LLMChain(llm=llm, prompt=None), tools=tools, verbose=True)
+    agent = ZeroShotAgent(llm_chain=LLMChain(llm=llm_function, prompt=None), tools=tools, verbose=True)
     agent_executor = AgentExecutor.from_agent_and_tools(agent=agent, tools=tools, verbose=True)
 
     # If user uploaded an image and text, use MiniCPM model
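
Note that the new llm_function calls the Groq client directly and returns the completion text, while LLMChain in LangChain expects an LLM object rather than a plain Python callable, and ZeroShotAgent normally needs a real prompt rather than prompt=None. A minimal sketch of one way to bridge the gap is shown below, assuming the classic langchain.llms.base.LLM interface and a GROQ_API_KEY environment variable; the class name GroqChatLLM is hypothetical and does not appear in this commit.

# Sketch only: a LangChain-compatible wrapper around the Groq client.
# Assumes the classic langchain.llms.base.LLM interface; the class name
# GroqChatLLM is hypothetical and is not part of this commit.
import os
from typing import List, Optional

from groq import Groq
from langchain.llms.base import LLM


class GroqChatLLM(LLM):
    # Same model id the commit stores in MODEL
    model: str = "llama3-groq-70b-8192-tool-use-preview"

    @property
    def _llm_type(self) -> str:
        return "groq-chat"

    def _call(self, prompt: str, stop: Optional[List[str]] = None) -> str:
        # Same call pattern as llm_function in the diff above
        client = Groq(api_key=os.environ.get("GROQ_API_KEY"))
        response = client.chat.completions.create(
            model=self.model,
            messages=[{"role": "user", "content": prompt}],
            stop=stop,
        )
        return response.choices[0].message.content

With a wrapper like this, the agent line could read LLMChain(llm=GroqChatLLM(), prompt=ZeroShotAgent.create_prompt(tools)), since ZeroShotAgent.create_prompt builds the ReAct-style prompt the agent expects instead of prompt=None.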