File size: 3,719 Bytes
3255105
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
import gradio as gr
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from peft import PeftModel, PeftConfig
from PIL import Image
import requests
from io import BytesIO
import torchvision.datasets as datasets
import numpy as np

# Load SigLIP for image embeddings
from model.siglip import SigLIPModel

def get_cifar_examples(root='./data', out_dir='examples'):
    """Save one CIFAR10 test image per class to disk and return their paths.

    Downloads the CIFAR10 test split into *root* (if not already present),
    walks it in order and writes the first image seen for each class to
    ``<out_dir>/<class>_example.jpg``.

    Args:
        root: Directory passed to ``torchvision.datasets.CIFAR10`` for
            storing/downloading the dataset.
        out_dir: Directory the example JPEGs are written to; created if it
            does not exist.

    Returns:
        List of image file paths, in the order their classes were first
        encountered in the test split.
    """
    import os

    cifar10_test = datasets.CIFAR10(root=root, train=False, download=True)

    # Label index -> class name.
    classes = ['airplane', 'automobile', 'bird', 'cat', 'deer',
               'dog', 'frog', 'horse', 'ship', 'truck']

    # img.save() does not create parent directories, so make sure the
    # destination exists before writing the first example.
    os.makedirs(out_dir, exist_ok=True)

    examples = []
    used_classes = set()

    for img, label in cifar10_test:
        name = classes[label]
        if name in used_classes:
            continue

        img_path = f"{out_dir}/{name}_example.jpg"
        img.save(img_path)
        examples.append(img_path)
        used_classes.add(name)

        # Stop scanning as soon as every class has one example.
        if len(used_classes) == len(classes):
            break

    return examples

def load_models():
    """Load the image encoder, the LoRA-adapted Phi-3 model and its tokenizer.

    Returns:
        A ``(siglip, model, tokenizer)`` tuple: the project's ``SigLIPModel``,
        the Phi-3-mini base model wrapped with the fine-tuned PEFT adapter,
        and the tokenizer stored alongside that adapter.
    """
    base_id = "microsoft/Phi-3-mini-4k-instruct"
    adapter_id = "jatingocodeo/phi-vlm"  # fine-tuned LoRA adapter on the Hub

    image_encoder = SigLIPModel()

    # Base weights are loaded in float32; device placement is delegated to
    # device_map="auto".
    phi_base = AutoModelForCausalLM.from_pretrained(
        base_id,
        trust_remote_code=True,
        device_map="auto",
        torch_dtype=torch.float32,
    )

    # Wrap the base model with the fine-tuned adapter weights.
    adapted = PeftModel.from_pretrained(phi_base, adapter_id, device_map="auto")

    phi_tokenizer = AutoTokenizer.from_pretrained(adapter_id)

    return image_encoder, adapted, phi_tokenizer

def generate_description(image, siglip, model, tokenizer):
    """Generate a free-text description of *image* with the fine-tuned model.

    Args:
        image: PIL image in any mode; converted to RGB and resized to 32x32.
        siglip: Image encoder exposing ``encode_image(image)``, which is
            expected to return a torch tensor.
        model: The (LoRA-adapted) causal language model used for generation.
        tokenizer: Tokenizer matching *model*.

    Returns:
        The sampled description text, with the prompt removed and surrounding
        whitespace stripped.
    """
    if image.mode != "RGB":
        image = image.convert("RGB")

    # NOTE(review): 32x32 presumably matches the resolution SigLIPModel was
    # trained on (CIFAR10-sized images); confirm against model/siglip.py.
    image = image.resize((32, 32))

    image_embedding = siglip.encode_image(image)

    prompt = """Below is an image. Please describe it in detail.

Image: <image>
Description: """

    # A single prompt needs no padding (and padding=True can raise when the
    # tokenizer defines no pad token).
    inputs = tokenizer(
        prompt,
        return_tensors="pt",
        truncation=True,
        max_length=128,
    ).to(model.device)

    # Sampling options are generate() arguments; calling model(...) directly
    # performs one forward pass and returns logits rather than token ids.
    with torch.no_grad():
        output_ids = model.generate(
            **inputs,
            # NOTE(review): a stock Phi-3 + LoRA model has no
            # `image_embeddings` input, and generate() is likely to reject it
            # as an unused model kwarg. Confirm how the fine-tuned model is
            # meant to receive the image (e.g. via inputs_embeds).
            image_embeddings=image_embedding.unsqueeze(0).to(model.device),
            max_new_tokens=100,
            temperature=0.7,
            do_sample=True,
            top_p=0.9,
        )

    # For decoder-only models generate() returns prompt + continuation;
    # decode only the newly generated tokens.
    prompt_len = inputs["input_ids"].shape[-1]
    generated_text = tokenizer.decode(output_ids[0][prompt_len:], skip_special_tokens=True)
    return generated_text.split("Description: ")[-1].strip()

# One-time, heavy setup: runs on import so every request reuses the same models.
print("Loading models...")
(siglip, model, tokenizer) = load_models()

# Gradio callback used by the interface
def process_image(image):
    """Describe the image submitted through the Gradio UI.

    Args:
        image: PIL image from the ``gr.Image(type="pil")`` input, or ``None``
            when the user submits without uploading anything.

    Returns:
        The generated description, or a short prompt asking for an image.
    """
    # Gradio passes None for an empty image input; generate_description would
    # otherwise fail on ``image.mode``.
    if image is None:
        return "Please upload an image."
    return generate_description(image, siglip, model, tokenizer)

# One CIFAR10 test image per class, offered as clickable examples in the UI.
examples = get_cifar_examples()

# Web UI: a PIL image goes in, the generated description comes out.
iface = gr.Interface(
    fn=process_image,
    inputs=gr.Image(type="pil"),
    outputs=gr.Textbox(label="Generated Description"),
    title="Image Description Generator",
    description="""Upload an image and get a detailed description generated by our fine-tuned VLM model.
                  Below are sample images from CIFAR10 dataset that you can try.""",
    # Each example row lists one value per input component (here: one image path).
    examples=[[path] for path in examples],
)

if __name__ == "__main__":
    # Serve the UI only when executed as a script, not when imported.
    iface.launch()