"""Gradio demo combining Phi-3.5 text chat, Phi-3.5 vision Q&A and Parler-TTS.

The module loads three models at import time and exposes a ``demo``
``gr.Blocks`` app with one tab per capability.
"""

import os
import subprocess
from threading import Thread

import gradio as gr
import soundfile as sf
import spaces
import torch
from parler_tts import ParlerTTSForConditionalGeneration
from PIL import Image
from transformers import (
    AutoModelForCausalLM,
    AutoProcessor,
    AutoTokenizer,
    BitsAndBytesConfig,
    TextIteratorStreamer,
)

# Install flash-attention.
# The variable is merged into the current environment: passing a bare dict as
# ``env`` would drop PATH/HOME and can stop the shell from finding ``pip``.
subprocess.run(
    "pip install flash-attn --no-build-isolation",
    env={**os.environ, "FLASH_ATTENTION_SKIP_CUDA_BUILD": "TRUE"},
    shell=True,
)

# Constants
TITLE = "Phi 3.5 Multimodal (Text + Vision + Speech)"
DESCRIPTION = "# Phi-3.5 Multimodal Demo (Text + Vision + Speech)"

# Model configurations
TEXT_MODEL_ID = "microsoft/Phi-3.5-mini-instruct"
VISION_MODEL_ID = "microsoft/Phi-3.5-vision-instruct"

device = "cuda" if torch.cuda.is_available() else "cpu"

# Quantization config for text model
quantization_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_compute_dtype=torch.bfloat16,
    bnb_4bit_use_double_quant=True,
    bnb_4bit_quant_type="nf4",
)

# Load models and tokenizers
text_tokenizer = AutoTokenizer.from_pretrained(TEXT_MODEL_ID)
text_model = AutoModelForCausalLM.from_pretrained(
    TEXT_MODEL_ID,
    torch_dtype=torch.bfloat16,
    device_map="auto",
    quantization_config=quantization_config,
)

vision_model = AutoModelForCausalLM.from_pretrained(
    VISION_MODEL_ID,
    trust_remote_code=True,
    torch_dtype="auto",
    attn_implementation="flash_attention_2",
).to(device).eval()
vision_processor = AutoProcessor.from_pretrained(VISION_MODEL_ID, trust_remote_code=True)

# Load Parler-TTS model
tts_device = "cuda:0" if torch.cuda.is_available() else "cpu"
tts_model = ParlerTTSForConditionalGeneration.from_pretrained(
    "parler-tts/parler-tts-large-v1"
).to(tts_device)
tts_tokenizer = AutoTokenizer.from_pretrained("parler-tts/parler-tts-large-v1")


# Helper functions
@spaces.GPU
def stream_text_chat(message, history, system_prompt, temperature=0.8,
                     max_new_tokens=1024, top_p=1.0, top_k=20):
    """Stream a chat reply from the text model, then synthesize it as speech.

    Args:
        message: The new user message.
        history: Previous turns as ``[user, assistant]`` pairs; ``None`` is
            treated as an empty history.
        system_prompt: Text placed in the leading system turn.
        temperature: Sampling temperature; ``0`` selects greedy decoding.
        max_new_tokens: Upper bound on generated tokens.
        top_p: Nucleus-sampling threshold (used only when sampling).
        top_k: Top-k cutoff (used only when sampling).

    Yields:
        ``(chat_history, audio_path)`` tuples. ``audio_path`` is ``None``
        while text is streaming; the final tuple carries the value returned
        by ``generate_speech`` for the complete reply.
    """
    history = history or []

    conversation = [{"role": "system", "content": system_prompt}]
    for prompt, answer in history:
        conversation.extend([
            {"role": "user", "content": prompt},
            {"role": "assistant", "content": answer},
        ])
    conversation.append({"role": "user", "content": message})

    input_ids = text_tokenizer.apply_chat_template(
        conversation, add_generation_prompt=True, return_tensors="pt"
    ).to(text_model.device)
    streamer = TextIteratorStreamer(
        text_tokenizer, timeout=60.0, skip_prompt=True, skip_special_tokens=True
    )

    do_sample = temperature > 0
    # No explicit eos_token_id: the previous hard-coded IDs (128001, 128008,
    # 128009) do not correspond to this tokenizer's stop tokens, and passing
    # them overrode the stop tokens from the checkpoint's generation config.
    generate_kwargs = dict(
        input_ids=input_ids,
        max_new_tokens=int(max_new_tokens),
        do_sample=do_sample,
        streamer=streamer,
    )
    if do_sample:
        # Sampling knobs are only meaningful (and only valid) when sampling.
        generate_kwargs.update(
            top_p=top_p,
            top_k=int(top_k),
            temperature=temperature,
        )

    # Generation runs in a worker thread so tokens can be consumed as they
    # arrive. A torch.no_grad() block here would not reach the worker because
    # grad mode is thread-local.
    thread = Thread(target=text_model.generate, kwargs=generate_kwargs)
    thread.start()

    buffer = ""
    for new_text in streamer:
        buffer += new_text
        yield history + [[message, buffer]], None  # Yield None for audio initially
    thread.join()

    # Generate speech for the final response
    audio_path = generate_speech(buffer, "A clear and concise voice reads out the response.")
    yield history + [[message, buffer]], audio_path


@spaces.GPU
def process_vision_query(image, text_input):
    """Answer a question about an image with the vision model.

    Args:
        image: A PIL image (what ``gr.Image(type="pil")`` delivers) or an
            array accepted by ``Image.fromarray``.
        text_input: The user's question about the image.

    Returns:
        The decoded model response, or a short notice when no image was
        provided.
    """
    if image is None:
        return "Please upload an image first."

    prompt = f"<|user|>\n<|image_1|>\n{text_input}<|end|>\n<|assistant|>\n"
    # The UI hands over a PIL image; only array input needs conversion.
    if not isinstance(image, Image.Image):
        image = Image.fromarray(image)
    image = image.convert("RGB")
    inputs = vision_processor(prompt, image, return_tensors="pt").to(device)

    with torch.no_grad():
        generate_ids = vision_model.generate(
            **inputs,
            max_new_tokens=1000,
            eos_token_id=vision_processor.tokenizer.eos_token_id,
        )

    # Drop the prompt tokens so only the newly generated text is decoded.
    generate_ids = generate_ids[:, inputs["input_ids"].shape[1]:]
    response = vision_processor.batch_decode(
        generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False
    )[0]
    return response


@spaces.GPU
def generate_speech(prompt, description):
    """Synthesize ``prompt`` as speech in the voice given by ``description``.

    Args:
        prompt: Text to be spoken.
        description: Natural-language description of the desired voice.

    Returns:
        Path of the WAV file written to the working directory, or ``None``
        when ``prompt`` is empty or whitespace-only.
    """
    if not prompt or not prompt.strip():
        return None

    input_ids = tts_tokenizer(description, return_tensors="pt").input_ids.to(tts_device)
    prompt_input_ids = tts_tokenizer(prompt, return_tensors="pt").input_ids.to(tts_device)

    generation = tts_model.generate(input_ids=input_ids, prompt_input_ids=prompt_input_ids)
    audio_arr = generation.cpu().numpy().squeeze()

    output_path = f"output_audio_{hash(prompt)}.wav"
    sf.write(output_path, audio_arr, tts_model.config.sampling_rate)
    return output_path


# Custom CSS
custom_css = """
body { background-color: #0b0f19; color: #e2e8f0; font-family: 'Arial', sans-serif;}
#custom-header { text-align: center; padding: 20px 0; background-color: #1a202c; margin-bottom: 20px; border-radius: 10px;}
#custom-header h1 { font-size: 2.5rem; margin-bottom: 0.5rem;}
#custom-header h1 .blue { color: #60a5fa;}
#custom-header h1 .pink { color: #f472b6;}
#custom-header h2 { font-size: 1.5rem; color: #94a3b8;}
.suggestions { display: flex; justify-content: center; flex-wrap: wrap; gap: 1rem; margin: 20px 0;}
.suggestion { background-color: #1e293b; border-radius: 0.5rem; padding: 1rem; display: flex; align-items: center; transition: transform 0.3s ease; width: 200px;}
.suggestion:hover { transform: translateY(-5px);}
.suggestion-icon { font-size: 1.5rem; margin-right: 1rem; background-color: #2d3748; padding: 0.5rem; border-radius: 50%;}
.gradio-container { max-width: 100% !important;}
#component-0, #component-1, #component-2 { max-width: 100% !important;}
footer { text-align: center; margin-top: 2rem; color: #64748b;}
"""

# Custom HTML for the header.
# Element ids/tags mirror the selectors in custom_css so the styles apply.
custom_header = """
<div id="custom-header">
    <h1>Phi 3.5 Multimodal Assistant</h1>
    <h2>Text, Vision, and Speech AI at Your Service</h2>
</div>
"""

# Custom HTML for suggestions (classes mirror the selectors in custom_css).
custom_suggestions = """
<div class="suggestions">
    <div class="suggestion">
        <span class="suggestion-icon">💬</span>
        <p>Chat with the Text Model</p>
    </div>
    <div class="suggestion">
        <span class="suggestion-icon">🖼️</span>
        <p>Analyze Images with Vision Model</p>
    </div>
    <div class="suggestion">
        <span class="suggestion-icon">🔊</span>
        <p>Generate Speech with Parler-TTS</p>
    </div>
    <div class="suggestion">
        <span class="suggestion-icon">🔍</span>
        <p>Explore advanced options</p>
    </div>
</div>
"""

# Gradio interface
with gr.Blocks(css=custom_css, theme=gr.themes.Base().set(
    body_background_fill="#0b0f19",
    body_text_color="#e2e8f0",
    button_primary_background_fill="#3b82f6",
    button_primary_background_fill_hover="#2563eb",
    button_primary_text_color="white",
    block_title_text_color="#94a3b8",
    block_label_text_color="#94a3b8",
)) as demo:
    gr.HTML(custom_header)
    gr.HTML(custom_suggestions)

    with gr.Tab("Text Model (Phi-3.5-mini)"):
        chatbot = gr.Chatbot(height=400)
        msg = gr.Textbox(label="Message", placeholder="Type your message here...")
        with gr.Accordion("Advanced Options", open=False):
            system_prompt = gr.Textbox(value="You are a helpful assistant", label="System Prompt")
            temperature = gr.Slider(minimum=0, maximum=1, step=0.1, value=0.8, label="Temperature")
            max_new_tokens = gr.Slider(minimum=128, maximum=8192, step=1, value=1024, label="Max new tokens")
            top_p = gr.Slider(minimum=0.0, maximum=1.0, step=0.1, value=1.0, label="top_p")
            top_k = gr.Slider(minimum=1, maximum=20, step=1, value=20, label="top_k")

        submit_btn = gr.Button("Submit", variant="primary")
        clear_btn = gr.Button("Clear Chat", variant="secondary")
        audio_output = gr.Audio(label="AI Response Audio")

        submit_btn.click(
            stream_text_chat,
            inputs=[msg, chatbot, system_prompt, temperature, max_new_tokens, top_p, top_k],
            outputs=[chatbot, audio_output],
        )
        clear_btn.click(lambda: (None, None), None, [chatbot, audio_output], queue=False)

    with gr.Tab("Vision Model (Phi-3.5-vision)"):
        with gr.Row():
            with gr.Column(scale=1):
                vision_input_img = gr.Image(label="Upload an Image", type="pil")
                vision_text_input = gr.Textbox(label="Ask a question about the image", placeholder="What do you see in this image?")
                vision_submit_btn = gr.Button("Analyze Image", variant="primary")
            with gr.Column(scale=1):
                vision_output_text = gr.Textbox(label="AI Analysis", lines=10)

        vision_submit_btn.click(process_vision_query, [vision_input_img, vision_text_input], [vision_output_text])

    with gr.Tab("Text-to-Speech (Parler-TTS)"):
        with gr.Row():
            with gr.Column(scale=1):
                tts_prompt = gr.Textbox(label="Text to Speak", placeholder="Enter the text you want to convert to speech...")
                tts_description = gr.Textbox(label="Voice Description", value="A female speaker delivers a slightly expressive and animated speech with a moderate speed and pitch. The recording is of very high quality, with the speaker's voice sounding clear and very close up.", lines=3)
                tts_submit_btn = gr.Button("Generate Speech", variant="primary")
            with gr.Column(scale=1):
                tts_output_audio = gr.Audio(label="Generated Speech")

        tts_submit_btn.click(generate_speech, inputs=[tts_prompt, tts_description], outputs=[tts_output_audio])

    gr.HTML("")

if __name__ == "__main__":
    demo.launch()