import os
import subprocess
from threading import Thread

import torch
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    AutoProcessor,
    TextIteratorStreamer,
    BitsAndBytesConfig,
)
import gradio as gr
from PIL import Image
import spaces
from parler_tts import ParlerTTSForConditionalGeneration
import soundfile as sf

# Install flash-attention at startup.
# The environment is *extended* rather than replaced: passing a bare dict as
# `env` would drop PATH, HOME, proxy settings, etc. for the pip subprocess.
subprocess.run(
    "pip install flash-attn --no-build-isolation",
    env={**os.environ, "FLASH_ATTENTION_SKIP_CUDA_BUILD": "TRUE"},
    shell=True,
)

# Constants
# NOTE(review): the HTML tags of TITLE were stripped in the checked-in text
# (leaving an unterminated multi-line string); the markup below is a
# reconstruction around the surviving title text -- confirm.
TITLE = "<h1><center>Phi 3.5 Multimodal (Text + Vision)</center></h1>"
DESCRIPTION = "# Phi-3.5 Multimodal Demo (Text + Vision)"

# Model configurations
TEXT_MODEL_ID = "microsoft/Phi-3.5-mini-instruct"
VISION_MODEL_ID = "microsoft/Phi-3.5-vision-instruct"

device = "cuda" if torch.cuda.is_available() else "cpu"

# Quantization config for text model (4-bit NF4 with double quantization,
# bfloat16 compute).
quantization_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_compute_dtype=torch.bfloat16,
    bnb_4bit_use_double_quant=True,
    bnb_4bit_quant_type="nf4",
)

# Load models and tokenizers
text_tokenizer = AutoTokenizer.from_pretrained(TEXT_MODEL_ID)
text_model = AutoModelForCausalLM.from_pretrained(
    TEXT_MODEL_ID,
    torch_dtype=torch.bfloat16,
    device_map="auto",
    quantization_config=quantization_config,
)

vision_model = AutoModelForCausalLM.from_pretrained(
    VISION_MODEL_ID,
    trust_remote_code=True,
    torch_dtype="auto",
    attn_implementation="flash_attention_2",
).to(device).eval()

vision_processor = AutoProcessor.from_pretrained(VISION_MODEL_ID, trust_remote_code=True)


# Helper functions
@spaces.GPU
def stream_text_chat(message, history, system_prompt, temperature=0.8, max_new_tokens=1024, top_p=1.0, top_k=20):
    """Stream a reply from the text model for a chat-style conversation.

    Args:
        message: The new user message.
        history: Previous turns as an iterable of ``(user, assistant)`` pairs;
            ``None`` is treated as an empty history.
        system_prompt: Content of the system message placed first.
        temperature: Sampling temperature; sampling is enabled only when > 0.
        max_new_tokens: Upper bound on the number of generated tokens.
        top_p: Nucleus-sampling threshold (used only when sampling).
        top_k: Top-k cutoff (used only when sampling).

    Yields:
        ``history + [[message, partial_reply]]`` each time the streamer emits
        new text, with ``partial_reply`` growing as generation proceeds.
    """
    history = history or []

    conversation = [{"role": "system", "content": system_prompt}]
    for prompt, answer in history:
        conversation.extend([
            {"role": "user", "content": prompt},
            {"role": "assistant", "content": answer},
        ])
    conversation.append({"role": "user", "content": message})

    input_ids = text_tokenizer.apply_chat_template(
        conversation, add_generation_prompt=True, return_tensors="pt"
    ).to(text_model.device)
    streamer = TextIteratorStreamer(text_tokenizer, timeout=60.0, skip_prompt=True, skip_special_tokens=True)

    do_sample = temperature > 0
    # No explicit eos_token_id: generate() then falls back to the model's own
    # generation config. The previous hard-coded [128001, 128008, 128009] was
    # not derived from this tokenizer.
    # NOTE(review): those IDs look like Llama-3 values and appear to lie
    # outside Phi-3.5's vocabulary -- confirm.
    generate_kwargs = dict(
        input_ids=input_ids,
        max_new_tokens=int(max_new_tokens),
        do_sample=do_sample,
        streamer=streamer,
    )
    if do_sample:
        # Sampling knobs are only meaningful (and only valid, e.g. a
        # temperature of 0 is not) when sampling is enabled.
        generate_kwargs.update(top_p=top_p, top_k=int(top_k), temperature=temperature)

    # generate() runs in a worker thread so this generator can consume the
    # streamer concurrently. A `with torch.no_grad():` around the Thread
    # creation would not reach that thread (grad mode is thread-local), so it
    # is intentionally not used here.
    thread = Thread(target=text_model.generate, kwargs=generate_kwargs)
    thread.start()

    buffer = ""
    for new_text in streamer:
        buffer += new_text
        yield history + [[message, buffer]]

    # The streamer is exhausted, so generation has finished; reap the worker.
    thread.join()
@spaces.GPU
def process_vision_query(image, text_input):
    """Answer a question about an image with the vision model.

    Args:
        image: Image as an array accepted by ``PIL.Image.fromarray``.
        text_input: The user's question, inserted after the image placeholder
            in the prompt.

    Returns:
        The decoded model response (prompt tokens removed).

    Raises:
        gr.Error: If no image was provided.
    """
    if image is None:
        # Image.fromarray(None) would fail with an opaque traceback; surface a
        # readable message in the UI instead.
        raise gr.Error("Please upload an image first.")

    prompt = f"<|user|>\n<|image_1|>\n{text_input}<|end|>\n<|assistant|>\n"
    image = Image.fromarray(image).convert("RGB")
    inputs = vision_processor(prompt, image, return_tensors="pt").to(device)

    with torch.no_grad():
        generate_ids = vision_model.generate(
            **inputs,
            max_new_tokens=1000,
            eos_token_id=vision_processor.tokenizer.eos_token_id,
        )

    # Drop the prompt tokens so only the newly generated answer is decoded.
    generate_ids = generate_ids[:, inputs['input_ids'].shape[1]:]
    response = vision_processor.batch_decode(
        generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False
    )[0]
    return response


# Load Parler-TTS model
tts_device = "cuda:0" if torch.cuda.is_available() else "cpu"
tts_model = ParlerTTSForConditionalGeneration.from_pretrained("parler-tts/parler-tts-large-v1").to(tts_device)
tts_tokenizer = AutoTokenizer.from_pretrained("parler-tts/parler-tts-large-v1")


@spaces.GPU
def generate_speech(prompt, description):
    """Synthesize speech for ``prompt`` in the voice described by ``description``.

    Args:
        prompt: The text to speak.
        description: Natural-language description of the desired voice.

    Returns:
        Path of a WAV file containing the generated audio.

    Raises:
        gr.Error: If ``prompt`` is empty or whitespace-only.
    """
    # Function-scope import keeps this block self-contained.
    import tempfile

    if not prompt or not prompt.strip():
        raise gr.Error("Please enter some text to convert to speech.")

    input_ids = tts_tokenizer(description, return_tensors="pt").input_ids.to(tts_device)
    prompt_input_ids = tts_tokenizer(prompt, return_tensors="pt").input_ids.to(tts_device)

    # Inference only, consistent with process_vision_query.
    with torch.no_grad():
        generation = tts_model.generate(input_ids=input_ids, prompt_input_ids=prompt_input_ids)
    audio_arr = generation.cpu().numpy().squeeze()

    # A unique file per request: a fixed "output_audio.wav" would be
    # overwritten when two requests run at the same time.
    with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as tmp:
        output_path = tmp.name
    sf.write(output_path, audio_arr, tts_model.config.sampling_rate)
    return output_path


# Custom CSS
custom_css = """
body {
    background-color: #0b0f19;
    color: #e2e8f0;
    font-family: 'Arial', sans-serif;
}
#custom-header {
    text-align: center;
    padding: 20px 0;
    background-color: #1a202c;
    margin-bottom: 20px;
    border-radius: 10px;
}
#custom-header h1 {
    font-size: 2.5rem;
    margin-bottom: 0.5rem;
}
#custom-header h1 .blue {
    color: #60a5fa;
}
#custom-header h1 .pink {
    color: #f472b6;
}
#custom-header h2 {
    font-size: 1.5rem;
    color: #94a3b8;
}
.suggestions {
    display: flex;
    justify-content: center;
    flex-wrap: wrap;
    gap: 1rem;
    margin: 20px 0;
}
.suggestion {
    background-color: #1e293b;
    border-radius: 0.5rem;
    padding: 1rem;
    display: flex;
    align-items: center;
    transition: transform 0.3s ease;
    width: 200px;
}
.suggestion:hover {
    transform: translateY(-5px);
}
.suggestion-icon {
    font-size: 1.5rem;
    margin-right: 1rem;
    background-color: #2d3748;
    padding: 0.5rem;
    border-radius: 50%;
}
.gradio-container {
    max-width: 100% !important;
}
#component-0, #component-1, #component-2 {
    max-width: 100% !important;
}
footer {
    text-align: center;
    margin-top: 2rem;
    color: #64748b;
}
"""

# Custom HTML for the header
# NOTE(review): the tags were stripped from the checked-in text. The markup
# here is rebuilt from the selectors in custom_css (#custom-header, h1 .blue,
# h1 .pink, h2); the exact blue/pink split of the title is a guess -- confirm.
custom_header = """
<div id="custom-header">
    <h1><span class="blue">Phi 3.5</span> <span class="pink">Multimodal Assistant</span></h1>
    <h2>Text and Vision AI at Your Service</h2>
</div>
"""

# Custom HTML for suggestions
# NOTE(review): markup rebuilt from the .suggestions / .suggestion /
# .suggestion-icon selectors in custom_css -- confirm against the design.
custom_suggestions = """
<div class="suggestions">
    <div class="suggestion">
        <span class="suggestion-icon">💬</span>
        <p>Chat with the Text Model</p>
    </div>
    <div class="suggestion">
        <span class="suggestion-icon">🖼️</span>
        <p>Analyze Images with Vision Model</p>
    </div>
    <div class="suggestion">
        <span class="suggestion-icon">🔊</span>
        <p>Generate Speech with Parler-TTS</p>
    </div>
    <div class="suggestion">
        <span class="suggestion-icon">🔍</span>
        <p>Explore advanced options</p>
    </div>
</div>
"""

# Gradio interface
with gr.Blocks(css=custom_css, theme=gr.themes.Base().set(
    body_background_fill="#0b0f19",
    body_text_color="#e2e8f0",
    button_primary_background_fill="#3b82f6",
    button_primary_background_fill_hover="#2563eb",
    button_primary_text_color="white",
    block_title_text_color="#94a3b8",
    block_label_text_color="#94a3b8",
)) as demo:
    gr.HTML(custom_header)
    gr.HTML(custom_suggestions)

    # NOTE(review): the Text and Vision tab bodies were placeholder comments
    # ("previous ... code remains the same"), which is not valid Python and
    # left stream_text_chat / process_vision_query unreachable from the UI.
    # The widgets below are a minimal reconstruction driven by those
    # functions' signatures and defaults -- adjust labels/ranges as needed.
    with gr.Tab("Text Model (Phi-3.5-mini)"):
        chatbot = gr.Chatbot(height=400)
        msg = gr.Textbox(label="Message", placeholder="Type your message here...")
        with gr.Accordion("Advanced Options", open=False):
            system_prompt = gr.Textbox(value="You are a helpful assistant", label="System Prompt")
            temperature = gr.Slider(minimum=0, maximum=1, step=0.1, value=0.8, label="Temperature")
            max_new_tokens = gr.Slider(minimum=128, maximum=8192, step=1, value=1024, label="Max new tokens")
            top_p = gr.Slider(minimum=0.1, maximum=1.0, step=0.1, value=1.0, label="top_p")
            top_k = gr.Slider(minimum=1, maximum=20, step=1, value=20, label="top_k")
        submit_btn = gr.Button("Submit", variant="primary")
        clear_btn = gr.Button("Clear Chat", variant="secondary")

        # Order matches stream_text_chat's positional parameters.
        text_inputs = [msg, chatbot, system_prompt, temperature, max_new_tokens, top_p, top_k]
        submit_btn.click(stream_text_chat, inputs=text_inputs, outputs=[chatbot])
        msg.submit(stream_text_chat, inputs=text_inputs, outputs=[chatbot])
        clear_btn.click(lambda: None, inputs=None, outputs=chatbot, queue=False)

    with gr.Tab("Vision Model (Phi-3.5-vision)"):
        with gr.Row():
            with gr.Column(scale=1):
                # type="numpy" because process_vision_query calls Image.fromarray.
                vision_input_img = gr.Image(label="Upload an Image", type="numpy")
                vision_text_input = gr.Textbox(label="Ask a question about the image", placeholder="What do you see in this image?")
                vision_submit_btn = gr.Button("Analyze Image", variant="primary")
            with gr.Column(scale=1):
                vision_output_text = gr.Textbox(label="AI Analysis", lines=10)

        vision_submit_btn.click(process_vision_query, inputs=[vision_input_img, vision_text_input], outputs=[vision_output_text])

    with gr.Tab("Text-to-Speech (Parler-TTS)"):
        with gr.Row():
            with gr.Column(scale=1):
                tts_prompt = gr.Textbox(label="Text to Speak", placeholder="Enter the text you want to convert to speech...")
                tts_description = gr.Textbox(label="Voice Description", value="A female speaker delivers a slightly expressive and animated speech with a moderate speed and pitch. The recording is of very high quality, with the speaker's voice sounding clear and very close up.", lines=3)
                tts_submit_btn = gr.Button("Generate Speech", variant="primary")
            with gr.Column(scale=1):
                tts_output_audio = gr.Audio(label="Generated Speech")

        tts_submit_btn.click(generate_speech, inputs=[tts_prompt, tts_description], outputs=[tts_output_audio])

    # NOTE(review): custom_css styles a `footer` element, but the markup here
    # is empty in the checked-in text (likely stripped); left unchanged because
    # the intended footer text is not recoverable from this file.
    gr.HTML("")

if __name__ == "__main__":
    demo.launch()