# Hugging Face Space source file (page header: status "Running", file size 3,719 bytes, ref 3255105).
import gradio as gr
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from peft import PeftModel, PeftConfig
from PIL import Image
import requests
from io import BytesIO
import torchvision.datasets as datasets
import numpy as np
# Load SigLIP for image embeddings
from model.siglip import SigLIPModel
def get_cifar_examples(root="./data", out_dir="examples"):
    """Save the first CIFAR-10 test image of each class and return their paths.

    The CIFAR-10 test split is downloaded into *root* if it is not already
    present. Images are scanned in dataset order, and the first image seen for
    each class is written to ``<out_dir>/<class>_example.jpg``.

    Args:
        root: Directory passed to ``datasets.CIFAR10`` for the download/cache.
        out_dir: Directory the example JPEGs are written to. It is created if
            it does not already exist.

    Returns:
        list[str]: One image path per class found, in the order the classes
        were first encountered (at most ``len(classes)`` entries).
    """
    import os

    # Load CIFAR10 test set
    cifar10_test = datasets.CIFAR10(root=root, train=False, download=True)
    # CIFAR10 class names indexed by label; assumed to match torchvision's
    # label order.
    classes = ['airplane', 'automobile', 'bird', 'cat', 'deer',
               'dog', 'frog', 'horse', 'ship', 'truck']

    # Image.save does not create missing parent directories, so without this
    # the first save fails on a fresh checkout that has no examples/ folder.
    os.makedirs(out_dir, exist_ok=True)

    examples = []
    used_classes = set()
    for img, label in cifar10_test:
        class_name = classes[label]
        if class_name in used_classes:
            continue
        img_path = f"{out_dir}/{class_name}_example.jpg"
        img.save(img_path)
        examples.append(img_path)
        used_classes.add(class_name)
        # Stop scanning as soon as every class has one example.
        if len(used_classes) == len(classes):
            break
    return examples
def load_models():
    """Build the image encoder, the LoRA-adapted language model and its tokenizer.

    The base ``microsoft/Phi-3-mini-4k-instruct`` checkpoint is loaded in
    float32 with ``device_map="auto"``. The ``jatingocodeo/phi-vlm`` adapter is
    then applied on top of it, and the tokenizer is loaded from that adapter
    repository.

    Returns:
        tuple: ``(siglip, model, tokenizer)``.
    """
    base_repo = "microsoft/Phi-3-mini-4k-instruct"
    adapter_repo = "jatingocodeo/phi-vlm"  # fine-tuned LoRA adapter

    image_encoder = SigLIPModel()

    phi = AutoModelForCausalLM.from_pretrained(
        base_repo,
        trust_remote_code=True,
        device_map="auto",
        torch_dtype=torch.float32,
    )
    lora_model = PeftModel.from_pretrained(phi, adapter_repo, device_map="auto")
    tok = AutoTokenizer.from_pretrained(adapter_repo)

    return image_encoder, lora_model, tok
def generate_description(image, siglip, model, tokenizer, max_new_tokens=100,
                         temperature=0.7, top_p=0.9):
    """Generate a free-text description of *image* with the fine-tuned model.

    Args:
        image: PIL image to describe. ``None`` (the value of an empty Gradio
            image input) returns a short message asking for an upload.
        siglip: Image encoder. Its ``encode_image`` output is passed to the
            language model as ``image_embeddings``.
        model: The (PEFT-wrapped) causal language model used for generation.
        tokenizer: Tokenizer matching *model*.
        max_new_tokens: Upper bound on the number of generated tokens.
        temperature: Sampling temperature.
        top_p: Nucleus-sampling probability mass.

    Returns:
        str: Only the newly generated text, with surrounding whitespace stripped.
    """
    if image is None:
        return "Please upload an image."

    # Convert image to RGB if needed
    if image.mode != "RGB":
        image = image.convert("RGB")
    # NOTE(review): 32x32 is CIFAR-10's native resolution; confirm this is the
    # input size the SigLIP wrapper and the adapter were trained with.
    image = image.resize((32, 32))
    image_embedding = siglip.encode_image(image)

    prompt = """Below is an image. Please describe it in detail.
Image: <image>
Description: """
    inputs = tokenizer(
        prompt,
        return_tensors="pt",
        padding=True,
        truncation=True,
        max_length=128,
    ).to(model.device)

    # Fall back to EOS when the tokenizer defines no pad token, so generate()
    # does not have to guess the padding id.
    pad_token_id = tokenizer.pad_token_id
    if pad_token_id is None:
        pad_token_id = tokenizer.eos_token_id

    # model(...) is a single forward pass. It never samples, does not accept
    # generation settings such as max_new_tokens, and returns logits rather
    # than token ids. generate() runs the actual decoding loop.
    # NOTE(review): image_embeddings is forwarded unchanged; the base Phi-3
    # checkpoint loaded in load_models() does not obviously accept such an
    # input, so verify the model actually consumes it.
    with torch.no_grad():
        output_ids = model.generate(
            **inputs,
            image_embeddings=image_embedding.unsqueeze(0),
            max_new_tokens=max_new_tokens,
            temperature=temperature,
            do_sample=True,
            top_p=top_p,
            pad_token_id=pad_token_id,
        )

    # For a decoder-only model, generate() returns prompt + continuation.
    # Slice off the prompt tokens so only the new text is decoded.
    prompt_length = inputs["input_ids"].shape[1]
    generated_text = tokenizer.decode(
        output_ids[0][prompt_length:], skip_special_tokens=True
    )
    return generated_text.strip()
# Load every model once at import time; each request below reuses them.
print("Loading models...")
siglip, model, tokenizer = load_models()


def process_image(image):
    """Describe *image* using the module-level SigLIP encoder, model and tokenizer.

    Args:
        image: Value supplied by the ``gr.Image(type="pil")`` input.

    Returns:
        str: Whatever ``generate_description`` produces for the image.
    """
    return generate_description(image, siglip, model, tokenizer)
# Write one sample image per CIFAR10 class to disk so the UI has clickable examples.
examples = get_cifar_examples()

# Gradio wants each example row as a list of input values. There is a single
# image input, so every row wraps exactly one path.
iface = gr.Interface(
    fn=process_image,
    inputs=gr.Image(type="pil"),
    outputs=gr.Textbox(label="Generated Description"),
    title="Image Description Generator",
    description="""Upload an image and get a detailed description generated by our fine-tuned VLM model.
Below are sample images from CIFAR10 dataset that you can try.""",
    examples=[[path] for path in examples],
)
# Launch the web UI only when this file is executed as a script, not on import.
if __name__ == "__main__":
    iface.launch()