# Source: Hugging Face Space file view of app.py (9.02 kB)
# Committed by: VanguardAI
# Commit: c1e9d2a (verified) - "Update app.py"
import gradio as gr
import torch
import os
import numpy as np
from groq import Groq
import spaces # Import spaces
from transformers import AutoModel, AutoTokenizer
from diffusers import StableDiffusionXLPipeline, UNet2DConditionModel, EulerDiscreteScheduler
from parler_tts import ParlerTTSForConditionalGeneration
import soundfile as sf
from langchain_community.embeddings import OpenAIEmbeddings
from langchain_community.vectorstores import Chroma
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain.chains import RetrievalQA, LLMChain
from langchain.agents import ZeroShotAgent, Tool, AgentExecutor
from langchain.llms import Groq as GroqLlm # Import GroqLlm
from PIL import Image
from decord import VideoReader, cpu
from tavily import TavilyClient
import requests
from huggingface_hub import hf_hub_download
from safetensors.torch import load_file
# Initialize models and clients.
# Everything below runs at import time, so importing this module loads every model.

# Groq API client; the key is read from the GROQ_API_KEY environment variable.
client = Groq(api_key=os.environ.get("GROQ_API_KEY"))
# Model id used for chat completions and for the LangChain LLM wrapper below.
MODEL = 'llama3-groq-70b-8192-tool-use-preview'
# NOTE(review): GroqLlm is `langchain.llms.Groq` per the import at the top of
# the file - confirm that class exists in the installed LangChain version.
llm = GroqLlm(client=client, model=MODEL) # Initialize GroqLlm
# Visual question answering model and its tokenizer (remote code is trusted).
vqa_model = AutoModel.from_pretrained('openbmb/MiniCPM-V-2', trust_remote_code=True,
device_map="auto", torch_dtype=torch.bfloat16)
tokenizer = AutoTokenizer.from_pretrained('openbmb/MiniCPM-V-2', trust_remote_code=True)
# Text-to-speech model and tokenizer; no device or dtype is requested here.
tts_model = ParlerTTSForConditionalGeneration.from_pretrained("parler-tts/parler-tts-large-v1")
tts_tokenizer = AutoTokenizer.from_pretrained("parler-tts/parler-tts-large-v1")
# Image generation model: SDXL base pipeline with its UNet weights replaced by
# the SDXL-Lightning checkpoint named in `ckpt` (the 4-step variant).
base = "stabilityai/stable-diffusion-xl-base-1.0"
repo = "ByteDance/SDXL-Lightning"
ckpt = "sdxl_lightning_4step_unet.safetensors"
# Build the UNet from the base config only, then load the downloaded weights.
unet = UNet2DConditionModel.from_config(base, subfolder="unet")
unet.load_state_dict(load_file(hf_hub_download(repo, ckpt)))
# NOTE(review): the pipeline is requested in float16 while `unet` is created
# above without an explicit dtype - confirm the UNet dtype matches at runtime.
image_pipe = StableDiffusionXLPipeline.from_pretrained(base, unet=unet, torch_dtype=torch.float16, variant="fp16")
# Replace the scheduler with an Euler scheduler using "trailing" timestep spacing.
image_pipe.scheduler = EulerDiscreteScheduler.from_config(image_pipe.scheduler.config, timestep_spacing="trailing")
# Tavily client for web search; the key is read from TAVILY_API_KEY.
tavily_client = TavilyClient(api_key=os.environ.get("TAVILY_API_KEY")) # Corrected API key
# Function to play voice output
def play_voice_output(response):
    """Synthesize ``response`` as speech and return the path of the WAV file.

    The text is spoken with a fixed voice description. Both the description
    and the text are tokenized with ``tts_tokenizer`` and moved to CUDA, so
    ``tts_model`` must already be on the GPU when this is called.

    Args:
        response: Text to be spoken.

    Returns:
        The string ``"output.wav"``; the file is overwritten on every call.
    """
    wav_path = "output.wav"
    voice_style = "Jon's voice is monotone yet slightly fast in delivery, with a very close recording that almost has no background noise."

    def _to_cuda_ids(text):
        # Tokenize to PyTorch tensors and keep only the token ids.
        return tts_tokenizer(text, return_tensors="pt").input_ids.to('cuda')

    style_ids = _to_cuda_ids(voice_style)
    text_ids = _to_cuda_ids(response)
    waveform = tts_model.generate(input_ids=style_ids, prompt_input_ids=text_ids)
    # Bring the samples back to the CPU and drop singleton dimensions before writing.
    samples = waveform.cpu().numpy().squeeze()
    sf.write(wav_path, samples, tts_model.config.sampling_rate)
    return wav_path
# NumPy Code Calculator Tool
def _extract_python_code(text):
    """Return the Python source in an LLM reply, stripped of Markdown code fences."""
    import re  # local import: ``re`` is not imported at module level

    # Match the first fenced block, with or without a language tag on the
    # opening fence. Replies without any fence are used as-is.
    fenced = re.search(r"```(?:[A-Za-z0-9_+-]*\n)?(.*?)```", text, re.DOTALL)
    return (fenced.group(1) if fenced else text).strip()


def numpy_code_calculator(query):
    """Generate NumPy code for ``query`` with the LLM, execute it, and return the result.

    The model is asked to store its answer in a variable named ``result``;
    that variable is read back from the execution namespace.

    Args:
        query: Natural-language description of the calculation.

    Returns:
        ``str(result)`` on success, ``"No result found"`` when the generated
        code did not define ``result``, or ``"Error: ..."`` when the request
        or the generated code raised. Errors are returned rather than raised
        so the calling agent always receives a string.
    """
    try:
        llm_response = client.chat.completions.create(
            model=MODEL,
            messages=[
                {
                    "role": "system",
                    "content": (
                        "Reply with only executable Python code and no prose. "
                        "NumPy is already imported as `np`. "
                        "Store the final answer in a variable named `result`."
                    ),
                },
                {"role": "user", "content": f"Write NumPy code to: {query}"},
            ],
        )
        # Chat models commonly wrap code in Markdown fences, which exec()
        # would reject with a SyntaxError. ``content`` may also be None.
        code = _extract_python_code(llm_response.choices[0].message.content or "")
        print(f"Generated NumPy code:\n{code}")
        # SECURITY: exec() runs model-generated code with full builtins in
        # this process. This is NOT a sandbox; anything the query can talk the
        # model into writing will be executed. Isolate before exposing publicly.
        namespace = {"np": np}
        exec(code, namespace)
        result = namespace.get("result", "No result found")
        return str(result)
    except Exception as e:
        return f"Error: {e}"
# Web Search Tool
def web_search(query):
    """Run ``query`` through Tavily's Q&A search and return its answer unchanged."""
    return tavily_client.qna_search(query=query)
# Image Generation Tool
def image_generation(query):
    """Generate an image for the prompt ``query`` and return the saved file path.

    Args:
        query: Text prompt describing the image.

    Returns:
        The string ``"output.jpg"``; the file is overwritten on every call.
    """
    # The pipeline's UNet is loaded from ``sdxl_lightning_4step_unet.safetensors``.
    # That distilled checkpoint is meant to be sampled with exactly 4 steps
    # and classifier-free guidance disabled. The previous 20 steps at
    # guidance_scale=7.5 do not match the loaded weights.
    result = image_pipe(prompt=query, num_inference_steps=4, guidance_scale=0)
    image = result.images[0]
    image.save("output.jpg")
    return "output.jpg"
# Document Question Answering Tool
def doc_question_answering(query, file_path):
    """Answer ``query`` from the text of the document at ``file_path``.

    The file is read as text, split into 1000-character chunks, embedded into
    a Chroma store, and queried through a "stuff" RetrievalQA chain.

    Args:
        query: Question to answer.
        file_path: Path of a plain-text document.

    Returns:
        The answer produced by the RetrievalQA chain.
    """
    with open(file_path, 'r', encoding="utf-8") as f:
        file_content = f.read()
    text_splitter = RecursiveCharacterTextSplitter(chunk_size=1000, chunk_overlap=0)
    docs = text_splitter.create_documents([file_content])
    # NOTE(review): OpenAIEmbeddings presumably needs OpenAI credentials, which
    # nothing else in this file configures - confirm.
    embeddings = OpenAIEmbeddings()
    # NOTE(review): every call adds chunks under the same persist_directory, so
    # chunks from earlier documents may be retrieved as well - confirm intended.
    db = Chroma.from_documents(docs, embeddings, persist_directory=".chroma_db")
    # Use the module-level Groq-backed ``llm``; the previous ``OpenAI()`` was
    # never imported and raised NameError on every call.
    qa = RetrievalQA.from_chain_type(llm=llm, chain_type="stuff", retriever=db.as_retriever())
    return qa.run(query)
# Function to handle different input types and choose the right tool
def _transcribe_audio(audio):
    """Return the Whisper transcription of ``audio`` (a file path or an open binary file)."""

    def _request(handle):
        transcription = client.audio.transcriptions.create(
            file=(handle.name, handle.read()),
            model="whisper-large-v3",
        )
        return transcription.text

    if isinstance(audio, str):
        # A path was given: open it here and make sure it is closed again.
        with open(audio, "rb") as handle:
            return _request(handle)
    return _request(audio)


def _build_tools(doc=None):
    """Return the agent's tools, adding document Q&A when ``doc`` is provided."""
    tools = [
        Tool(
            name="Numpy Code Calculator",
            func=numpy_code_calculator,
            description="Useful for when you need to perform mathematical calculations using NumPy. Provide the calculation you want to perform.",
        ),
        Tool(
            name="Web Search",
            func=web_search,
            description="Useful for when you need to find information from the real world.",
        ),
        Tool(
            name="Image Generation",
            func=image_generation,
            description="Useful for when you need to generate an image based on a description.",
        ),
    ]
    if doc:
        # The UI's gr.File(type="filepath") passes a plain path string, which
        # has no ``.name``; file-like objects are still accepted.
        doc_path = doc if isinstance(doc, str) else doc.name
        tools.append(
            Tool(
                name="Document Question Answering",
                func=lambda query: doc_question_answering(query, doc_path),
                description="Useful for when you need to answer questions about the uploaded document.",
            )
        )
    return tools


def handle_input(user_prompt, image=None, audio=None, doc=None, websearch=False):
    """Route one user request to the VQA model or to the tool-using agent.

    Args:
        user_prompt: Text typed by the user. Replaced by the transcription
            when ``audio`` is given.
        image: Optional path of an image. When set, the request is answered
            by the MiniCPM-V model and the agent is not used.
        audio: Optional audio file path or open binary file to transcribe.
        doc: Optional document (path string or object with ``.name``) that
            enables the document question answering tool.
        websearch: When true, the prompt is suffixed with a hint to use the
            Web Search tool.

    Returns:
        The model's or agent's text response.
    """
    # Voice input handling: the transcription becomes the prompt.
    if audio:
        user_prompt = _transcribe_audio(audio)

    # Image + text goes straight to the VQA model, so handle it before
    # building an agent that would not be used.
    if image:
        pil_image = Image.open(image).convert('RGB')
        # MiniCPM-V-2's chat() takes the image separately from the text
        # messages, requires ``context``, and returns (answer, context, config).
        answer, _, _ = vqa_model.chat(
            image=pil_image,
            msgs=[{"role": "user", "content": user_prompt}],
            context=None,
            tokenizer=tokenizer,
            sampling=True,
            temperature=0.7,
        )
        return answer

    tools = _build_tools(doc)
    # from_llm_and_tools builds the ReAct prompt from the tool descriptions;
    # the previous LLMChain(prompt=None) had no prompt to run.
    agent = ZeroShotAgent.from_llm_and_tools(llm=llm, tools=tools)
    agent_executor = AgentExecutor.from_agent_and_tools(agent=agent, tools=tools, verbose=True)

    if websearch:
        user_prompt = f"{user_prompt} Use the Web Search tool if necessary."
    return agent_executor.run(user_prompt)
# Gradio UI Setup
def create_ui():
    """Build the Gradio Blocks interface and wire its events.

    Returns:
        The ``gr.Blocks`` app, not yet launched.
    """
    with gr.Blocks() as demo:
        gr.Markdown("# AI Assistant")
        with gr.Row():
            with gr.Column(scale=2):
                user_prompt = gr.Textbox(placeholder="Type your message here...", lines=1)
            with gr.Column(scale=1):
                image_input = gr.Image(type="filepath", label="Upload an image", elem_id="image-icon")
                audio_input = gr.Audio(type="filepath", label="Upload audio", elem_id="mic-icon")
                doc_input = gr.File(type="filepath", label="Upload a document", elem_id="document-icon")
                voice_only_mode = gr.Checkbox(label="Enable Voice Only Mode", elem_id="voice-only-mode")
                websearch_mode = gr.Checkbox(label="Enable Web Search", elem_id="websearch-mode")
            with gr.Column(scale=1):
                submit = gr.Button("Submit")
        output_label = gr.Label(label="Output")
        audio_output = gr.Audio(label="Audio Output", visible=False)

        submit.click(
            fn=main_interface,
            inputs=[user_prompt, image_input, audio_input, doc_input, voice_only_mode, websearch_mode],
            outputs=[output_label, audio_output],
        )

        # Voice-only mode UI.
        # Hidden while voice-only mode is on. The submit button is deliberately
        # not in this list: hiding it left no way to send the recording.
        text_mode_components = [user_prompt, image_input, doc_input, websearch_mode]
        # Shown only while voice-only mode is on. ``audio_output`` starts
        # hidden and must become visible for the spoken reply to be playable.
        voice_mode_components = [audio_input, audio_output]

        def _toggle_voice_only(enabled):
            # A handler with several outputs must return one update per
            # output component, in the same order as ``outputs``.
            hidden = [gr.update(visible=not enabled) for _ in text_mode_components]
            shown = [gr.update(visible=enabled) for _ in voice_mode_components]
            return hidden + shown

        voice_only_mode.change(
            _toggle_voice_only,
            inputs=voice_only_mode,
            outputs=text_mode_components + voice_mode_components,
        )
    return demo
# Main interface function
@spaces.GPU()
def main_interface(user_prompt, image=None, audio=None, doc=None, voice_only=False, websearch=False):
    """Gradio click handler: answer the request and optionally speak the answer.

    All models are moved to CUDA first, then the request is delegated to
    ``handle_input``. When ``voice_only`` is set, the response is also
    synthesized to audio.

    Returns:
        A ``(response, audio_path)`` pair; ``audio_path`` is ``None`` unless
        ``voice_only`` is true.
    """
    # The VQA model is moved with an explicit bfloat16 dtype; the rest keep theirs.
    vqa_model.to(device='cuda', dtype=torch.bfloat16)
    for gpu_module in (tts_model, unet, image_pipe):
        gpu_module.to("cuda")

    response = handle_input(user_prompt, image=image, audio=audio, doc=doc, websearch=websearch)
    audio_file = play_voice_output(response) if voice_only else None
    return response, audio_file
# Launch the app: build the Blocks UI and keep it in the module-level `demo`
# name, then start it with inline display disabled.
demo = create_ui()
demo.launch(inline=False)