import os
import gradio as gr
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from peft import PeftModel, PeftConfig
from PIL import Image
import requests
from io import BytesIO
import torchvision.datasets as datasets
import numpy as np

# Load SigLIP for image embeddings (local module in this repo)
from model.siglip import SigLIPModel
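# NOTE: model/siglip.py is not shown here; the rest of this app only assumes it
# exposes a SigLIPModel class with an encode_image(pil_image) method that
# returns a torch.Tensor embedding for a single image, roughly:
#
#   class SigLIPModel:
#       def encode_image(self, image: Image.Image) -> torch.Tensor: ...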
def get_cifar_examples():
    # Load the CIFAR10 test set
    cifar10_test = datasets.CIFAR10(root='./data', train=False, download=True)

    # CIFAR10 class names, indexed by label
    classes = ['airplane', 'automobile', 'bird', 'cat', 'deer',
               'dog', 'frog', 'horse', 'ship', 'truck']

    # Collect one example image from each class
    os.makedirs("examples", exist_ok=True)
    examples = []
    used_classes = set()
    for idx in range(len(cifar10_test)):
        img, label = cifar10_test[idx]
        if classes[label] not in used_classes:
            # Save the image so Gradio can serve it as an example
            img_path = f"examples/{classes[label]}_example.jpg"
            img.save(img_path)
            examples.append(img_path)
            used_classes.add(classes[label])
        if len(used_classes) == 10:  # One example from each class collected
            break
    return examples
def load_models():
    # Load SigLIP model
    siglip = SigLIPModel()

    # Load base Phi model
    base_model = AutoModelForCausalLM.from_pretrained(
        "microsoft/Phi-3-mini-4k-instruct",
        trust_remote_code=True,
        device_map="auto",
        torch_dtype=torch.float32
    )

    # Load our fine-tuned LoRA adapter
    model = PeftModel.from_pretrained(
        base_model,
        "jatingocodeo/phi-vlm",  # Your uploaded model
        device_map="auto"
    )

    tokenizer = AutoTokenizer.from_pretrained("jatingocodeo/phi-vlm")
    return siglip, model, tokenizer
def generate_description(image, siglip, model, tokenizer):
    # Convert image to RGB if needed
    if image.mode != "RGB":
        image = image.convert("RGB")

    # Resize image to the 32x32 input size the SigLIP wrapper expects (CIFAR10 resolution)
    image = image.resize((32, 32))

    # Get the image embedding from SigLIP
    image_embedding = siglip.encode_image(image)

    # Prepare prompt
    prompt = """Below is an image. Please describe it in detail.
Image: <image>
Description: """

    # Tokenize input
    inputs = tokenizer(
        prompt,
        return_tensors="pt",
        padding=True,
        truncation=True,
        max_length=128
    ).to(model.device)

    # Generate the description; image_embeddings is the custom keyword the
    # fine-tuned model consumes alongside the text tokens
    with torch.no_grad():
        outputs = model.generate(
            **inputs,
            image_embeddings=image_embedding.unsqueeze(0),
            max_new_tokens=100,
            temperature=0.7,
            do_sample=True,
            top_p=0.9
        )

    # Decode and return only the generated description
    generated_text = tokenizer.decode(outputs[0], skip_special_tokens=True)
    return generated_text.split("Description: ")[-1].strip()
# Load models once at startup
print("Loading models...")
siglip, model, tokenizer = load_models()

# Gradio callback: generate a description for the uploaded image
def process_image(image):
    description = generate_description(image, siglip, model, tokenizer)
    return description

# Get CIFAR10 example images
examples = get_cifar_examples()
# Define the interface
iface = gr.Interface(
    fn=process_image,
    inputs=gr.Image(type="pil"),
    outputs=gr.Textbox(label="Generated Description"),
    title="Image Description Generator",
    description="""Upload an image and get a detailed description generated by our fine-tuned VLM model.
    Below are sample images from the CIFAR10 dataset that you can try.""",
    examples=[[ex] for ex in examples]  # Format examples for Gradio
)
# Launch the interface
if __name__ == "__main__":
    iface.launch()
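# Usage note: on a Hugging Face Space this file runs automatically as app.py;
# locally, `python app.py` starts the Gradio server and prints a local URL
# (pass share=True to iface.launch() if a temporary public link is needed).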