import gradio as gr
import torch
import os
import numpy as np
from groq import Groq as GroqClient  # Groq SDK client, aliased to avoid clashing with the llama_index Groq LLM below
import spaces
from transformers import AutoModel, AutoTokenizer
from diffusers import StableDiffusion3Pipeline
from parler_tts import ParlerTTSForConditionalGeneration
import soundfile as sf
from llama_index.core.agent import ReActAgent
from llama_index.core.tools import FunctionTool
from llama_index.llms.groq import Groq
from PIL import Image
from tavily import TavilyClient
import requests
from huggingface_hub import hf_hub_download
from safetensors.torch import load_file
from llama_index.core.chat_engine.types import AgentChatResponse

# Initialize models and clients
MODEL = 'llama3-groq-70b-8192-tool-use-preview'
# The Groq SDK client only takes an API key; the model is specified per request
client = GroqClient(api_key=os.environ.get("GROQ_API_KEY"))

vqa_model = AutoModel.from_pretrained('openbmb/MiniCPM-V-2', trust_remote_code=True,
                                      device_map="auto", torch_dtype=torch.bfloat16)
tokenizer = AutoTokenizer.from_pretrained('openbmb/MiniCPM-V-2', trust_remote_code=True)

tts_model = ParlerTTSForConditionalGeneration.from_pretrained("parler-tts/parler-tts-large-v1")
tts_tokenizer = AutoTokenizer.from_pretrained("parler-tts/parler-tts-large-v1")

# Image generation model
pipe = StableDiffusion3Pipeline.from_pretrained("stabilityai/stable-diffusion-3-medium-diffusers",
                                                torch_dtype=torch.float16)
pipe = pipe.to("cuda")

# Tavily client for web search
tavily_client = TavilyClient(api_key=os.environ.get("TAVILY_API"))


# Function to play voice output
def play_voice_output(response):
    description = "Jon's voice is monotone yet slightly fast in delivery, with a very close recording that almost has no background noise."
    input_ids = tts_tokenizer(description, return_tensors="pt").input_ids.to('cuda')
    prompt_input_ids = tts_tokenizer(response, return_tensors="pt").input_ids.to('cuda')
    generation = tts_model.generate(input_ids=input_ids, prompt_input_ids=prompt_input_ids)
    audio_arr = generation.cpu().numpy().squeeze()
    sf.write("output.wav", audio_arr, tts_model.config.sampling_rate)
    return "output.wav"


# NumPy Code Calculator Tool
def numpy_code_calculator(query):
    """Execute NumPy code. The code should assign its final value to a variable named `result`."""
    try:
        local_dict = {"np": np}
        exec(query, local_dict)
        result = local_dict.get("result", "No result found")
        return str(result)
    except Exception as e:
        return f"Error: {e}"


# Web Search Tool
def web_search(query):
    """Answer a question using a Tavily web search."""
    answer = tavily_client.qna_search(query=query)
    return answer


# Image Generation Tool
def image_generation(query):
    """Generate an image from a text prompt and return the path to the saved file."""
    image = pipe(
        query,
        negative_prompt="",
        num_inference_steps=15,
        guidance_scale=7.0,
    ).images[0]
    image.save("output.jpg")
    return "output.jpg"


# Function to handle different input types and choose the right tool
def handle_input(user_prompt, image=None, audio=None, websearch=False):
    # Transcribe audio input with Whisper via the Groq API
    if audio:
        if isinstance(audio, str):
            audio = open(audio, "rb")
        transcription = client.audio.transcriptions.create(
            file=(audio.name, audio.read()),
            model="whisper-large-v3"
        )
        user_prompt = transcription.text

    tools = [
        FunctionTool.from_defaults(fn=numpy_code_calculator, name="Numpy"),
        FunctionTool.from_defaults(fn=image_generation, name="Image"),
    ]

    # Add the web search tool only if websearch mode is enabled
    if websearch:
        tools.append(FunctionTool.from_defaults(fn=web_search, name="Web"))

    llm = Groq(model=MODEL, api_key=os.environ.get("GROQ_API_KEY"))
    agent = ReActAgent.from_tools(tools, llm=llm, verbose=True)

    # Route image questions to the VQA model, everything else to the ReAct agent
    if image:
        image = Image.open(image).convert('RGB')
        messages = [{"role": "user", "content": [image, user_prompt]}]
        response = vqa_model.chat(image=None, msgs=messages, tokenizer=tokenizer)
    else:
        response = agent.chat(user_prompt)
        # Extract the content from AgentChatResponse to return as a string
        if isinstance(response, AgentChatResponse):
            response = response.response

    return response


# Gradio UI setup
def create_ui():
    with gr.Blocks() as demo:
        gr.Markdown("# AI Assistant")
        with gr.Row():
            with gr.Column(scale=2):
                user_prompt = gr.Textbox(placeholder="Type your message here...", lines=1)
            with gr.Column(scale=1):
                image_input = gr.Image(type="filepath", label="Upload an image", elem_id="image-icon")
                audio_input = gr.Audio(type="filepath", label="Upload audio", elem_id="mic-icon")
                voice_only_mode = gr.Checkbox(label="Enable Voice Only Mode", elem_id="voice-only-mode")
                websearch_mode = gr.Checkbox(label="Enable Web Search", elem_id="websearch-mode")
            with gr.Column(scale=1):
                submit = gr.Button("Submit")

        output_label = gr.Label(label="Output")
        audio_output = gr.Audio(label="Audio Output", visible=False)

        submit.click(
            fn=main_interface,
            inputs=[user_prompt, image_input, audio_input, voice_only_mode, websearch_mode],
            outputs=[output_label, audio_output]
        )

        # Toggle component visibility when voice-only mode changes;
        # one update must be returned per output component
        voice_only_mode.change(
            lambda x: [gr.update(visible=not x)] * 4,
            inputs=voice_only_mode,
            outputs=[user_prompt, image_input, websearch_mode, submit]
        )
        voice_only_mode.change(
            lambda x: gr.update(visible=x),
            inputs=voice_only_mode,
            outputs=[audio_input]
        )

    return demo


# Main interface function
@spaces.GPU()
def main_interface(user_prompt, image=None, audio=None, voice_only=False, websearch=False):
    print("Starting main_interface function")
    vqa_model.to(device='cuda', dtype=torch.bfloat16)
    tts_model.to("cuda")
    pipe.to("cuda")
    print(f"user_prompt: {user_prompt}, image: {image}, audio: {audio}, "
          f"voice_only: {voice_only}, websearch: {websearch}")

    try:
        response = handle_input(user_prompt, image=image, audio=audio, websearch=websearch)
        print("handle_input function executed successfully")
    except Exception as e:
        print(f"Error in handle_input: {e}")
        response = "Error occurred during processing."

    if voice_only:
        try:
            audio_output = play_voice_output(response)
            print("play_voice_output function executed successfully")
            return "Response generated.", audio_output
        except Exception as e:
            print(f"Error in play_voice_output: {e}")
            return "Error occurred during voice output.", None
    else:
        return response, None


# Launch the UI
demo = create_ui()
demo.launch()