# Multimodal_App / app.py — Hugging Face Space by sagar007
# Commit 1b8f6f0 ("Update app.py", verified), 7.84 kB
import importlib.util
import os
import subprocess
import sys
import tempfile
from threading import Thread

import gradio as gr
import soundfile as sf
import spaces
import torch
from parler_tts import ParlerTTSForConditionalGeneration
from PIL import Image
from transformers import AutoModelForCausalLM, AutoTokenizer, AutoProcessor, TextIteratorStreamer, BitsAndBytesConfig
# Install flash-attention at startup (best effort).
# NOTE(review): FLASH_ATTENTION_SKIP_CUDA_BUILD presumably makes flash-attn's
# setup skip compiling CUDA kernels from source — confirm against its setup.py.
subprocess.run(
    # Use the running interpreter's pip, so the package lands in the same
    # environment this app imports from.
    [sys.executable, "-m", "pip", "install", "flash-attn", "--no-build-isolation"],
    # Extend (rather than replace) the current environment: passing only the
    # flag would strip PATH, HOME, CUDA_HOME, ... from the pip process.
    env={**os.environ, "FLASH_ATTENTION_SKIP_CUDA_BUILD": "TRUE"},
    check=False,  # deliberately best effort: a failed install does not abort startup here
)
# ---------------------------------------------------------------------------
# Constants
# ---------------------------------------------------------------------------
# Title / description markup strings for the page.
TITLE = "<h1><center>Phi 3.5 Multimodal (Text + Vision)</center></h1>"
DESCRIPTION = "# Phi-3.5 Multimodal Demo (Text + Vision)"

# Hugging Face Hub checkpoints for the text and vision models.
TEXT_MODEL_ID = "microsoft/Phi-3.5-mini-instruct"
VISION_MODEL_ID = "microsoft/Phi-3.5-vision-instruct"

# Run on the GPU when PyTorch can see one; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
# bitsandbytes settings used to load the text model in 4-bit:
# NF4 weight format with double quantization, computing in bfloat16.
quantization_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_use_double_quant=True,
    bnb_4bit_compute_dtype=torch.bfloat16,
)
# Load models and tokenizers.

# Text model: Phi-3.5-mini, quantized per `quantization_config` and placed by
# accelerate (`device_map="auto"`).
text_tokenizer = AutoTokenizer.from_pretrained(TEXT_MODEL_ID)
text_model = AutoModelForCausalLM.from_pretrained(
    TEXT_MODEL_ID,
    torch_dtype=torch.bfloat16,
    device_map="auto",
    quantization_config=quantization_config,
)

# Vision model: FlashAttention-2 needs CUDA *and* the flash_attn package, whose
# install at startup is only best effort. Fall back to eager attention instead
# of failing at load time when either is missing.
importlib.invalidate_caches()  # see packages pip installed after interpreter start
_use_flash_attn = device == "cuda" and importlib.util.find_spec("flash_attn") is not None
vision_model = AutoModelForCausalLM.from_pretrained(
    VISION_MODEL_ID,
    trust_remote_code=True,
    torch_dtype="auto",
    attn_implementation="flash_attention_2" if _use_flash_attn else "eager",
).to(device).eval()
vision_processor = AutoProcessor.from_pretrained(VISION_MODEL_ID, trust_remote_code=True)
# Helper functions


def _phi_stop_token_ids(tokenizer):
    """Return the token ids that should end a Phi-3.5 chat turn.

    Combines the tokenizer's EOS token with ``<|end|>`` (the end-of-turn tag in
    the Phi-3.5 prompt format), skipping ids the vocabulary does not know.
    Returns ``None`` (use the model's generation config) if nothing is found.
    """
    candidates = [tokenizer.eos_token_id, tokenizer.convert_tokens_to_ids("<|end|>")]
    unk_id = tokenizer.unk_token_id
    stop_ids = []
    for tok_id in candidates:
        if tok_id is not None and tok_id != unk_id and tok_id not in stop_ids:
            stop_ids.append(tok_id)
    return stop_ids or None


@spaces.GPU
def stream_text_chat(message, history, system_prompt, temperature=0.8, max_new_tokens=1024, top_p=1.0, top_k=20):
    """Stream a Phi-3.5-mini reply to *message*, yielding the updated chat history.

    Args:
        message: The new user message.
        history: Previous turns as ``[user, assistant]`` pairs (Gradio's
            tuple-style chat history). ``None`` is treated as empty.
        system_prompt: System instruction placed first in the conversation.
        temperature: Sampling temperature; ``<= 0`` selects greedy decoding.
        max_new_tokens: Maximum number of tokens to generate.
        top_p: Nucleus-sampling mass (only used when sampling).
        top_k: Top-k cutoff (only used when sampling).

    Yields:
        ``history + [[message, partial_reply]]`` after each streamed chunk.
    """
    history = history or []
    conversation = [{"role": "system", "content": system_prompt}]
    for prompt, answer in history:
        conversation.extend([
            {"role": "user", "content": prompt},
            {"role": "assistant", "content": answer},
        ])
    conversation.append({"role": "user", "content": message})

    input_ids = text_tokenizer.apply_chat_template(
        conversation, add_generation_prompt=True, return_tensors="pt"
    ).to(text_model.device)
    streamer = TextIteratorStreamer(text_tokenizer, timeout=60.0, skip_prompt=True, skip_special_tokens=True)

    do_sample = temperature > 0
    generate_kwargs = dict(
        input_ids=input_ids,
        # Gradio sliders may deliver floats; generate expects an int here.
        max_new_tokens=int(max_new_tokens),
        do_sample=do_sample,
        # Stop on Phi-3.5's own end tokens. The previously hard-coded ids
        # (128001/128008/128009) are Llama-3 ids outside Phi-3.5's vocabulary,
        # so generation never stopped before max_new_tokens.
        eos_token_id=_phi_stop_token_ids(text_tokenizer),
        streamer=streamer,
    )
    if do_sample:
        # Sampling knobs only matter (and are only validated) when sampling.
        generate_kwargs.update(temperature=temperature, top_p=top_p, top_k=int(top_k))

    # `generate` runs in a worker thread so tokens can be yielded as they arrive.
    # It disables autograd itself; a `torch.no_grad()` here would not reach the
    # worker anyway because grad mode is thread-local.
    thread = Thread(target=text_model.generate, kwargs=generate_kwargs)
    thread.start()
    buffer = ""
    for new_text in streamer:
        buffer += new_text
        yield history + [[message, buffer]]
    thread.join()
@spaces.GPU
def process_vision_query(image, text_input):
    """Answer *text_input* about *image* using Phi-3.5-vision.

    Args:
        image: The uploaded image, either a numpy array (as produced by a
            ``gr.Image(type="numpy")`` input) or a ``PIL.Image.Image``.
        text_input: The user's question or instruction about the image.

    Returns:
        The model's decoded reply, with special tokens removed.

    Raises:
        gr.Error: If no image was provided.
    """
    if image is None:
        raise gr.Error("Please upload an image first.")
    # Single-turn Phi-3.5 prompt; <|image_1|> refers to the one image passed to
    # the processor below.
    prompt = f"<|user|>\n<|image_1|>\n{text_input}<|end|>\n<|assistant|>\n"
    if not isinstance(image, Image.Image):
        image = Image.fromarray(image)
    image = image.convert("RGB")
    inputs = vision_processor(prompt, image, return_tensors="pt").to(device)
    with torch.no_grad():
        generate_ids = vision_model.generate(
            **inputs,
            max_new_tokens=1000,
            eos_token_id=vision_processor.tokenizer.eos_token_id,
        )
    # Keep only newly generated tokens (drop the echoed prompt).
    generate_ids = generate_ids[:, inputs["input_ids"].shape[1]:]
    return vision_processor.batch_decode(
        generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False
    )[0]
# Parler-TTS text-to-speech: model and tokenizer come from one Hub checkpoint.
TTS_MODEL_ID = "parler-tts/parler-tts-large-v1"

# First GPU when PyTorch sees one, otherwise CPU.
if torch.cuda.is_available():
    tts_device = "cuda:0"
else:
    tts_device = "cpu"

tts_tokenizer = AutoTokenizer.from_pretrained(TTS_MODEL_ID)
tts_model = ParlerTTSForConditionalGeneration.from_pretrained(TTS_MODEL_ID).to(tts_device)
@spaces.GPU
def generate_speech(prompt, description):
    """Synthesize *prompt* as speech in the voice described by *description*.

    Args:
        prompt: The text to speak.
        description: Free-text description of the desired voice and recording
            quality.

    Returns:
        Path to a newly created WAV file holding the generated audio.

    Raises:
        gr.Error: If *prompt* is empty or whitespace only.
    """
    if not prompt or not prompt.strip():
        raise gr.Error("Please enter some text to convert to speech.")
    # The voice description goes in `input_ids`; the text to speak goes in
    # `prompt_input_ids`.
    input_ids = tts_tokenizer(description, return_tensors="pt").input_ids.to(tts_device)
    prompt_input_ids = tts_tokenizer(prompt, return_tensors="pt").input_ids.to(tts_device)
    generation = tts_model.generate(input_ids=input_ids, prompt_input_ids=prompt_input_ids)
    audio_arr = generation.cpu().numpy().squeeze()
    # Use a unique file per request. A fixed "output_audio.wav" in the working
    # directory is shared by all users, so concurrent requests could overwrite
    # or receive each other's audio.
    with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as tmp:
        output_path = tmp.name
    sf.write(output_path, audio_arr, tts_model.config.sampling_rate)
    return output_path
# Custom CSS for the dark page style, passed to gr.Blocks(css=...).
# Covers: body colours/font, the #custom-header banner (with .blue/.pink title
# spans), the .suggestions card row with a hover lift, full-width overrides, and
# the footer text.
# NOTE(review): #component-0/1/2 look like Gradio's auto-assigned element ids;
# they may shift if the layout changes — confirm they still match.
custom_css = """
body { background-color: #0b0f19; color: #e2e8f0; font-family: 'Arial', sans-serif;}
#custom-header { text-align: center; padding: 20px 0; background-color: #1a202c; margin-bottom: 20px; border-radius: 10px;}
#custom-header h1 { font-size: 2.5rem; margin-bottom: 0.5rem;}
#custom-header h1 .blue { color: #60a5fa;}
#custom-header h1 .pink { color: #f472b6;}
#custom-header h2 { font-size: 1.5rem; color: #94a3b8;}
.suggestions { display: flex; justify-content: center; flex-wrap: wrap; gap: 1rem; margin: 20px 0;}
.suggestion { background-color: #1e293b; border-radius: 0.5rem; padding: 1rem; display: flex; align-items: center; transition: transform 0.3s ease; width: 200px;}
.suggestion:hover { transform: translateY(-5px);}
.suggestion-icon { font-size: 1.5rem; margin-right: 1rem; background-color: #2d3748; padding: 0.5rem; border-radius: 50%;}
.gradio-container { max-width: 100% !important;}
#component-0, #component-1, #component-2 { max-width: 100% !important;}
footer { text-align: center; margin-top: 2rem; color: #64748b;}
"""
# Custom HTML for the page header: a two-colour title (the .blue/.pink spans are
# styled under #custom-header h1 in custom_css) plus a subtitle.
custom_header = """
<div id="custom-header">
<h1><span class="blue">Phi 3.5</span> <span class="pink">Multimodal Assistant</span></h1>
<h2>Text and Vision AI at Your Service</h2>
</div>
"""
# Custom HTML for the row of feature cards under the header (styled by the
# .suggestions / .suggestion / .suggestion-icon rules in custom_css).
# The icons are emoji; they had been corrupted into mis-decoded UTF-8 byte
# sequences and are restored here.
custom_suggestions = """
<div class="suggestions">
<div class="suggestion">
<span class="suggestion-icon">💬</span>
<p>Chat with the Text Model</p>
</div>
<div class="suggestion">
<span class="suggestion-icon">🖼️</span>
<p>Analyze Images with Vision Model</p>
</div>
<div class="suggestion">
<span class="suggestion-icon">🔊</span>
<p>Generate Speech with Parler-TTS</p>
</div>
<div class="suggestion">
<span class="suggestion-icon">🔍</span>
<p>Explore advanced options</p>
</div>
</div>
"""
# Gradio interface: header, feature cards, and one tab per model, on a dark
# variant of the Base theme that matches custom_css.
with gr.Blocks(css=custom_css, theme=gr.themes.Base().set(
    body_background_fill="#0b0f19",
    body_text_color="#e2e8f0",
    button_primary_background_fill="#3b82f6",
    button_primary_background_fill_hover="#2563eb",
    button_primary_text_color="white",
    block_title_text_color="#94a3b8",
    block_label_text_color="#94a3b8",
)) as demo:
    gr.HTML(custom_header)
    gr.HTML(custom_suggestions)

    # The Text and Vision tabs previously held only a "previous code remains the
    # same" placeholder comment, leaving the `with` statements without a body
    # (a SyntaxError). They are rebuilt here around the handlers' signatures.
    with gr.Tab("Text Model (Phi-3.5-mini)"):
        chatbot = gr.Chatbot(label="Chat", height=500)
        with gr.Row():
            chat_input = gr.Textbox(
                show_label=False,
                placeholder="Type your message and press Enter...",
                scale=4,
            )
            chat_submit_btn = gr.Button("Send", variant="primary", scale=1)
        with gr.Accordion("Advanced options", open=False):
            system_prompt = gr.Textbox(label="System Prompt", value="You are a helpful assistant.")
            temperature = gr.Slider(minimum=0.0, maximum=1.0, value=0.8, step=0.05, label="Temperature (0 = greedy)")
            max_new_tokens = gr.Slider(minimum=128, maximum=4096, value=1024, step=1, label="Max new tokens")
            top_p = gr.Slider(minimum=0.05, maximum=1.0, value=1.0, step=0.05, label="Top-p")
            top_k = gr.Slider(minimum=1, maximum=50, value=20, step=1, label="Top-k")
        chat_clear_btn = gr.Button("Clear chat")

        # Order matches stream_text_chat's parameters.
        chat_inputs = [chat_input, chatbot, system_prompt, temperature, max_new_tokens, top_p, top_k]
        # Enter and the Send button both stream a reply, then empty the input box.
        for trigger in (chat_input.submit, chat_submit_btn.click):
            trigger(stream_text_chat, inputs=chat_inputs, outputs=[chatbot]).then(
                lambda: "", inputs=None, outputs=[chat_input]
            )
        chat_clear_btn.click(lambda: [], inputs=None, outputs=[chatbot])

    with gr.Tab("Vision Model (Phi-3.5-vision)"):
        with gr.Row():
            with gr.Column(scale=1):
                # type="numpy" matches process_vision_query's Image.fromarray path.
                vision_input_img = gr.Image(label="Upload an Image", type="numpy")
                vision_text_input = gr.Textbox(label="Question", placeholder="Ask a question about the image...")
                vision_submit_btn = gr.Button("Analyze Image", variant="primary")
            with gr.Column(scale=1):
                vision_output_text = gr.Textbox(label="AI Analysis", lines=10)
        vision_submit_btn.click(
            process_vision_query,
            inputs=[vision_input_img, vision_text_input],
            outputs=[vision_output_text],
        )

    with gr.Tab("Text-to-Speech (Parler-TTS)"):
        with gr.Row():
            with gr.Column(scale=1):
                tts_prompt = gr.Textbox(label="Text to Speak", placeholder="Enter the text you want to convert to speech...")
                tts_description = gr.Textbox(label="Voice Description", value="A female speaker delivers a slightly expressive and animated speech with a moderate speed and pitch. The recording is of very high quality, with the speaker's voice sounding clear and very close up.", lines=3)
                tts_submit_btn = gr.Button("Generate Speech", variant="primary")
            with gr.Column(scale=1):
                tts_output_audio = gr.Audio(label="Generated Speech")
        tts_submit_btn.click(generate_speech, inputs=[tts_prompt, tts_description], outputs=[tts_output_audio])

    gr.HTML("<footer>Powered by Phi 3.5 Multimodal AI and Parler-TTS</footer>")

if __name__ == "__main__":
    demo.launch()