import spaces
import torch
import gradio as gr
from PIL import Image
from transformers import Qwen2_5_VLForConditionalGeneration, AutoProcessor

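# SpaceQwen2.5-VL: a 3B Qwen2.5-VL fine-tune from Remyx AI aimed at spatial
# reasoning over images (distances, sizes, relative positions).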
MODEL_ID = "remyxai/SpaceQwen2.5-VL-3B-Instruct"

def load_model():
    """Load the SpaceQwen2.5-VL model and processor once at startup."""
    print("Loading model and processor...")
    device = "cuda" if torch.cuda.is_available() else "cpu"
    model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
        MODEL_ID,
        torch_dtype=torch.bfloat16 if torch.cuda.is_available() else torch.float32,
    ).to(device)
    processor = AutoProcessor.from_pretrained(MODEL_ID)
    return model, processor


# Loaded once at import time and shared across requests. On ZeroGPU Spaces,
# @spaces.GPU belongs on the function that does GPU work at request time
# (see run_inference below), not on the loader.
model, processor = load_model()

def process_image(image_path_or_obj):
"""Loads, resizes, and preprocesses an image path or Pillow Image."""
if isinstance(image_path_or_obj, str):
# Path on disk or from history
image = Image.open(image_path_or_obj).convert("RGB")
elif isinstance(image_path_or_obj, Image.Image):
image = image_path_or_obj.convert("RGB")
else:
raise ValueError("process_image expects a file path (str) or PIL.Image")
max_width = 512
if image.width > max_width:
aspect_ratio = image.height / image.width
new_height = int(max_width * aspect_ratio)
image = image.resize((max_width, new_height), Image.Resampling.LANCZOS)
print(f"Resized image to: {max_width}x{new_height}")
return image
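
# Note: the Qwen2.5-VL processor also accepts min_pixels/max_pixels to bound the
# number of image tokens, e.g. (a sketch, values from the Qwen model card):
#   AutoProcessor.from_pretrained(MODEL_ID, min_pixels=256 * 28 * 28, max_pixels=1280 * 28 * 28)
# Resizing up front, as above, keeps memory use predictable either way.
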
def get_latest_image(history):
"""
Look from the end to find the last user-uploaded image (stored as (file_path,) ).
Return None if not found.
"""
for user_msg, _assistant_msg in reversed(history):
if isinstance(user_msg, tuple) and len(user_msg) > 0:
return user_msg[0]
return None
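
# Illustrative shape of the history this app maintains (tuples-format Chatbot):
#   [[("warehouse_rgb.jpg",), None],          # user uploaded an image
#    ["How tall is the man in the red hat?",  # user text
#     "He appears to be about 6 feet tall."]] # assistant reply, filled in later
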
def only_assistant_text(full_text: str) -> str:
    """
    Strip the decoded chat-template scaffolding (system and user turns) and
    return only the final assistant answer.
    Adjust this parsing if your model's output format differs.
    """
    # With skip_special_tokens=True the decoded text looks roughly like:
    #   system
    #   ...
    #   user
    #   ...
    #   assistant
    #   The final answer
    # so we split on the first 'assistant' marker and keep everything after it.
    # (Heuristic: this misfires if the user's prompt itself contains 'assistant'.)
    if "assistant" in full_text:
        result = full_text.split("assistant", 1)[-1].strip()
        # Remove any leading punctuation (like a colon)
        return result.lstrip(":").strip()
    return full_text.strip()


# ZeroGPU note: the GPU-backed work (moving tensors to device and generate)
# happens inside this call, so @spaces.GPU goes here rather than on load_model.
@spaces.GPU
def run_inference(image, prompt):
    """Runs Qwen2.5-VL inference on a single image and text prompt."""
system_msg = (
"You are a Vision Language Model specialized in interpreting visual data from images. "
"Your task is to analyze the provided image and respond to queries with concise answers."
)
conversation = [
{
"role": "system",
"content": [{"type": "text", "text": system_msg}],
},
{
"role": "user",
"content": [
{"type": "image", "image": image},
{"type": "text", "text": prompt},
],
},
]
text_input = processor.apply_chat_template(
conversation, tokenize=False, add_generation_prompt=True
)
inputs = processor(text=[text_input], images=[image], return_tensors="pt").to(model.device)
generated_ids = model.generate(**inputs, max_new_tokens=1024)
output_text = processor.batch_decode(generated_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
# Parse out only the final assistant text
return only_assistant_text(output_text)
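
# A more robust alternative, used in the official Qwen2.5-VL examples, is to trim
# the prompt tokens before decoding instead of string-splitting; a sketch:
#   trimmed = [out[len(inp):] for inp, out in zip(inputs.input_ids, generated_ids)]
#   answer = processor.batch_decode(trimmed, skip_special_tokens=True)[0]
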
def add_message(history, user_input):
"""
Step 1 (triggered by user's 'Submit' or 'Send'):
- Save new text or images into `history`.
- The Chatbot display uses pairs: [user_text_or_image, assistant_reply].
"""
if not isinstance(history, list):
history = []
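    # `user_input` arrives from gr.MultimodalTextbox as a dict of the form
    # {"text": <str>, "files": [<path>, ...]}, where paths point at gradio temp files.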
files = user_input.get("files", [])
text = user_input.get("text", "")
# Store images
for f in files:
# Each image is stored as `[(file_path,), None]`
history.append([(f,), None])
# Store text
if text:
history.append([text, None])
return history, gr.MultimodalTextbox(value=None)
def inference_interface(history):
"""
Step 2: Use the most recent text + the most recent image to run Qwen2.5-VL.
Instead of adding another entry, we fill the assistant's answer into
the last user text entry.
"""
if not history:
return history, gr.MultimodalTextbox(value=None)
# 1) Get the user's most recent text
user_text = ""
# We'll search from the end for the first str we find
for idx in range(len(history) - 1, -1, -1):
user_msg, assistant_msg = history[idx]
if isinstance(user_msg, str):
user_text = user_msg
# We'll also keep track of this index so we can fill in the assistant reply
user_idx = idx
break
else:
# No user text found
print("No user text found in history. Skipping inference.")
return history, gr.MultimodalTextbox(value=None)
# 2) Get the latest image from the entire conversation
latest_image = get_latest_image(history)
if not latest_image:
# No image found => can't run the model
print("No image found in history. Skipping inference.")
return history, gr.MultimodalTextbox(value=None)
# 3) Process the image
pil_image = process_image(latest_image)
# 4) Run inference
assistant_reply = run_inference(pil_image, user_text)
# 5) Fill that assistant reply back into the last user text entry
history[user_idx][1] = assistant_reply
return history, gr.MultimodalTextbox(value=None)
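
# Note: only the most recently uploaded image is used for each question; earlier
# images in the conversation are ignored once a newer one arrives.
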
def build_demo():
with gr.Blocks() as demo:
gr.Markdown("# SpaceQwen2.5-VL Image Prompt Chatbot")
chatbot = gr.Chatbot([], line_breaks=True)
chat_input = gr.MultimodalTextbox(
interactive=True,
file_types=["image"],
placeholder="Enter text or upload an image (or both).",
show_label=True
)
# When the user presses Enter in the MultimodalTextbox:
submit_event = chat_input.submit(
fn=add_message, # Step 1: store user data
inputs=[chatbot, chat_input],
outputs=[chatbot, chat_input]
)
# After storing, run inference
submit_event.then(
fn=inference_interface, # Step 2: run Qwen2.5-VL
inputs=[chatbot],
outputs=[chatbot, chat_input]
)
# Same logic for a "Send" button
with gr.Row():
send_button = gr.Button("Send")
clear_button = gr.ClearButton([chatbot, chat_input])
send_click = send_button.click(
fn=add_message,
inputs=[chatbot, chat_input],
outputs=[chatbot, chat_input]
)
send_click.then(
fn=inference_interface,
inputs=[chatbot],
outputs=[chatbot, chat_input]
)
# Example
gr.Examples(
examples=[
{
"text": "Give me the height of the man in the red hat in feet.",
"files": ["./examples/warehouse_rgb.jpg"]
}
],
inputs=[chat_input],
)
return demo
if __name__ == "__main__":
demo = build_demo()
demo.launch(share=True)