# phi-vlm / app.py
import os

import gradio as gr
import torch
import torchvision.datasets as datasets
from peft import PeftModel
from PIL import Image
from transformers import AutoModelForCausalLM, AutoTokenizer
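
# Gradio demo: a Phi-3-mini model fine-tuned with a LoRA adapter to generate
# short text descriptions of small (CIFAR-10-sized) images.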

def load_model():
    # Load the base Phi-3 model
    base_model = AutoModelForCausalLM.from_pretrained(
        "microsoft/Phi-3-mini-4k-instruct",
        trust_remote_code=True,
        device_map="auto",
        torch_dtype=torch.float32,
    )
    # Load the fine-tuned LoRA adapter on top of the base model
    model = PeftModel.from_pretrained(
        base_model,
        "jatingocodeo/phi-vlm",  # adapter repo on the Hugging Face Hub
        device_map="auto",
    )
    tokenizer = AutoTokenizer.from_pretrained("jatingocodeo/phi-vlm")
    return model, tokenizer
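
# NOTE: PeftModel keeps the LoRA adapter separate at inference time; calling
# model.merge_and_unload() after loading would fold the adapter weights into
# the base model for slightly faster generation (optional).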

def generate_description(image, model, tokenizer):
    # Convert the image to RGB if needed (e.g. grayscale or RGBA input)
    if image.mode != "RGB":
        image = image.convert("RGB")
    # Resize to the 32x32 resolution used during fine-tuning
    image = image.resize((32, 32))
    # Prepare the prompt. "<image>" is a literal placeholder string matching
    # the training prompt format; the pixel data itself is not tokenized here.
    prompt = """Below is an image. Please describe it in detail.
Image: <image>
Description: """
    # Tokenize the input
    inputs = tokenizer(
        prompt,
        return_tensors="pt",
        padding=True,
        truncation=True,
        max_length=128,
    ).to(model.device)
    # Generate the description
    with torch.no_grad():
        outputs = model.generate(
            **inputs,
            max_new_tokens=100,
            temperature=0.7,
            do_sample=True,
            top_p=0.9,
        )
    # Decode and return only the text after the "Description: " marker
    generated_text = tokenizer.decode(outputs[0], skip_special_tokens=True)
    return generated_text.split("Description: ")[-1].strip()
# Load model
print("Loading model...")
model, tokenizer = load_model()
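
# Quick smoke test (hypothetical test.jpg in the working directory):
#   print(generate_description(Image.open("test.jpg"), model, tokenizer))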

# Get one example image per CIFAR-10 class to show in the interface
def get_cifar_examples():
    cifar10_test = datasets.CIFAR10(root="./data", train=False, download=True)
    classes = ["airplane", "automobile", "bird", "cat", "deer",
               "dog", "frog", "horse", "ship", "truck"]
    os.makedirs("examples", exist_ok=True)  # ensure the output directory exists
    examples = []
    used_classes = set()
    for idx in range(len(cifar10_test)):
        img, label = cifar10_test[idx]
        if classes[label] not in used_classes:
            img_path = f"examples/{classes[label]}_example.jpg"
            img.save(img_path)
            examples.append(img_path)
            used_classes.add(classes[label])
        if len(used_classes) == 10:
            break
    return examples
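
# NOTE: get_cifar_examples() triggers a one-time download of the full
# CIFAR-10 archive (roughly 170 MB), even though only the test split is used.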

# Wrapper passed to the Gradio interface
def process_image(image):
    return generate_description(image, model, tokenizer)
# Get examples
examples = get_cifar_examples()
# Define the interface
iface = gr.Interface(
    fn=process_image,
    inputs=gr.Image(type="pil"),
    outputs=gr.Textbox(label="Generated Description"),
    title="Image Description Generator",
    description="""Upload an image and get a detailed description generated by our fine-tuned VLM.
    Below are sample images from the CIFAR-10 dataset that you can try.""",
    examples=[[ex] for ex in examples],
)

# Launch the interface
if __name__ == "__main__":
    iface.launch()