import gradio as gr
import torch
import os
import numpy as np
from groq import Groq as GroqClient  # aliased: llama_index's Groq LLM is imported below under the same name
import spaces
from transformers import AutoModel, AutoTokenizer
from diffusers import StableDiffusion3Pipeline
from parler_tts import ParlerTTSForConditionalGeneration
import soundfile as sf
from llama_index.core.agent import ReActAgent
from llama_index.core.tools import FunctionTool
from llama_index.llms.groq import Groq
from PIL import Image
from tavily import TavilyClient
from llama_index.core.chat_engine.types import AgentChatResponse
from llama_index.core import VectorStoreIndex
from llama_parse import LlamaParse
# Initialize models and clients
MODEL = 'llama3-groq-70b-8192-tool-use-preview'
# Raw Groq SDK client (used below for Whisper transcription); it takes only an API key
client = GroqClient(api_key=os.environ.get("GROQ_API_KEY"))
vqa_model = AutoModel.from_pretrained('openbmb/MiniCPM-V-2', trust_remote_code=True,
device_map="auto", torch_dtype=torch.bfloat16)
tokenizer = AutoTokenizer.from_pretrained('openbmb/MiniCPM-V-2', trust_remote_code=True)
tts_model = ParlerTTSForConditionalGeneration.from_pretrained("parler-tts/parler-tts-large-v1")
tts_tokenizer = AutoTokenizer.from_pretrained("parler-tts/parler-tts-large-v1")
# Updated Image generation model
pipe = StableDiffusion3Pipeline.from_pretrained("stabilityai/stable-diffusion-3-medium-diffusers", torch_dtype=torch.float16)
# Note: on ZeroGPU Spaces, CUDA is only available inside @spaces.GPU functions,
# so the move to "cuda" is deferred to main_interface rather than done at import time.
# Tavily Client for web search
tavily_client = TavilyClient(api_key=os.environ.get("TAVILY_API"))
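# Both clients read their keys from the environment: the Space's secrets must define
# GROQ_API_KEY and TAVILY_API (names as used above), or requests will fail at call time.
# LlamaParse (used for document uploads below) additionally expects LLAMA_CLOUD_API_KEY.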
# Function to play voice output
def play_voice_output(response):
description = "Jon's voice is monotone yet slightly fast in delivery, with a very close recording that almost has no background noise."
input_ids = tts_tokenizer(description, return_tensors="pt").input_ids.to('cuda')
prompt_input_ids = tts_tokenizer(response, return_tensors="pt").input_ids.to('cuda')
generation = tts_model.generate(input_ids=input_ids, prompt_input_ids=prompt_input_ids)
audio_arr = generation.cpu().numpy().squeeze()
sf.write("output.wav", audio_arr, tts_model.config.sampling_rate)
return "output.wav"
# NumPy Code Calculator Tool
def numpy_code_calculator(query):
try:
# Assume query is a request for a numpy computation
local_dict = {"np": np}
exec(query, local_dict)
result = local_dict.get("result", "No result found")
return str(result)
except Exception as e:
return f"Error: {e}"
# Web Search Tool
def web_search(query):
answer = tavily_client.qna_search(query=query)
return answer
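# Tavily's qna_search returns a short natural-language answer string, e.g.
#   web_search("Who wrote Dune?")  # -> "Frank Herbert ..." (exact text varies)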
# Image Generation Tool
def image_generation(query):
image = pipe(
query,
negative_prompt="",
num_inference_steps=15,
guidance_scale=7.0,
).images[0]
image.save("output.jpg")
return "output.jpg"
# Document Question Answering Tool
def document_question_answering(query, docs):
index = VectorStoreIndex.from_documents(docs)
query_engine = index.as_query_engine(similarity_top_k=3)
response = query_engine.query(query)
return str(response)
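# Note: this rebuilds an in-memory VectorStoreIndex on every question; fine for a
# demo, but a persisted index would avoid re-embedding the document each time.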
# Function to handle different input types and choose the right tool
def handle_input(user_prompt, image=None, audio=None, websearch=False, document=None):
if audio:
if isinstance(audio, str):
audio = open(audio, "rb")
transcription = client.audio.transcriptions.create(
file=(audio.name, audio.read()),
model="whisper-large-v3"
)
user_prompt = transcription.text
tools = [
FunctionTool.from_defaults(fn=numpy_code_calculator, name="Numpy"),
FunctionTool.from_defaults(fn=image_generation, name="Image"),
]
# Add the web search tool only if websearch mode is enabled
if websearch:
tools.append(FunctionTool.from_defaults(fn=web_search, name="Web"))
# Add the document question answering tool only if a document is provided
    if document:
        docs = LlamaParse(result_type="text").load_data(document)
        # Bind the parsed docs via a closure: FunctionTool.from_defaults has no `docs` kwarg
        tools.append(FunctionTool.from_defaults(
            fn=lambda query: document_question_answering(query, docs),
            name="Document",
            description="Answer questions about the uploaded document",
        ))
llm = Groq(model=MODEL, api_key=os.environ.get("GROQ_API_KEY"))
agent = ReActAgent.from_tools(tools, llm=llm, verbose=True)
    if image:
        image = Image.open(image).convert('RGB')
        # MiniCPM-V-2 expects the image via the `image` argument and plain-text msgs
        messages = [{"role": "user", "content": user_prompt}]
        response = vqa_model.chat(image=image, msgs=messages, tokenizer=tokenizer)
else:
response = agent.chat(user_prompt)
# Extract the content from AgentChatResponse to return as a string
if isinstance(response, AgentChatResponse):
response = response.response
return response
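# Example (hypothetical) direct calls, bypassing the UI:
#   handle_input("What is 12 * 7?")                       # routed to the ReAct agent
#   handle_input("Describe this image", image="cat.jpg")  # routed to MiniCPM-V VQA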
# Gradio UI Setup
def create_ui():
with gr.Blocks() as demo:
gr.Markdown("# AI Assistant")
with gr.Row():
with gr.Column(scale=2):
user_prompt = gr.Textbox(placeholder="Type your message here...", lines=1)
with gr.Column(scale=1):
image_input = gr.Image(type="filepath", label="Upload an image", elem_id="image-icon")
audio_input = gr.Audio(type="filepath", label="Upload audio", elem_id="mic-icon")
                document_input = gr.File(type="filepath", label="Upload a document", elem_id="document-icon")  # "file" is not a valid type in Gradio 4.x
voice_only_mode = gr.Checkbox(label="Enable Voice Only Mode", elem_id="voice-only-mode")
websearch_mode = gr.Checkbox(label="Enable Web Search", elem_id="websearch-mode")
with gr.Column(scale=1):
submit = gr.Button("Submit")
output_label = gr.Label(label="Output")
audio_output = gr.Audio(label="Audio Output", visible=False)
submit.click(
fn=main_interface,
inputs=[user_prompt, image_input, audio_input, voice_only_mode, websearch_mode, document_input],
outputs=[output_label, audio_output]
)
        voice_only_mode.change(
            # With multiple outputs, return one update per output component
            lambda x: [gr.update(visible=not x)] * 5,
            inputs=voice_only_mode,
            outputs=[user_prompt, image_input, websearch_mode, document_input, submit]
        )
        voice_only_mode.change(
            # Show both the mic input and the audio output in voice-only mode
            lambda x: [gr.update(visible=x)] * 2,
            inputs=voice_only_mode,
            outputs=[audio_input, audio_output]
        )
return demo
# Main interface function
@spaces.GPU()
def main_interface(user_prompt, image=None, audio=None, voice_only=False, websearch=False, document=None):
print("Starting main_interface function")
vqa_model.to(device='cuda', dtype=torch.bfloat16)
tts_model.to("cuda")
pipe.to("cuda")
print(f"user_prompt: {user_prompt}, image: {image}, audio: {audio}, voice_only: {voice_only}, websearch: {websearch}, document: {document}")
try:
response = handle_input(user_prompt, image=image, audio=audio, websearch=websearch, document=document)
print("handle_input function executed successfully")
except Exception as e:
print(f"Error in handle_input: {e}")
response = "Error occurred during processing."
if voice_only:
try:
audio_output = play_voice_output(response)
print("play_voice_output function executed successfully")
return "Response generated.", audio_output
except Exception as e:
print(f"Error in play_voice_output: {e}")
return "Error occurred during voice output.", None
else:
return response, None
# Launch the UI
demo = create_ui()
demo.launch()