jatingocodeo commited on
Commit
3255105
·
verified ·
1 Parent(s): 0d387a1

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +129 -0
app.py ADDED
@@ -0,0 +1,129 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import torch
3
+ from transformers import AutoModelForCausalLM, AutoTokenizer
4
+ from peft import PeftModel, PeftConfig
5
+ from PIL import Image
6
+ import requests
7
+ from io import BytesIO
8
+ import torchvision.datasets as datasets
9
+ import numpy as np
10
+
11
+ # Load SigLIP for image embeddings
12
+ from model.siglip import SigLIPModel
13
+
14
+ def get_cifar_examples():
15
+ # Load CIFAR10 test set
16
+ cifar10_test = datasets.CIFAR10(root='./data', train=False, download=True)
17
+
18
+ # CIFAR10 classes
19
+ classes = ['airplane', 'automobile', 'bird', 'cat', 'deer',
20
+ 'dog', 'frog', 'horse', 'ship', 'truck']
21
+
22
+ # Get one example from each class
23
+ examples = []
24
+ used_classes = set()
25
+
26
+ for idx in range(len(cifar10_test)):
27
+ img, label = cifar10_test[idx]
28
+ if classes[label] not in used_classes:
29
+ # Save the image temporarily
30
+ img_path = f"examples/{classes[label]}_example.jpg"
31
+ img.save(img_path)
32
+ examples.append(img_path)
33
+ used_classes.add(classes[label])
34
+
35
+ if len(used_classes) == 10: # We have one example from each class
36
+ break
37
+
38
+ return examples
39
+
40
+ def load_models():
41
+ # Load SigLIP model
42
+ siglip = SigLIPModel()
43
+
44
+ # Load base Phi model
45
+ base_model = AutoModelForCausalLM.from_pretrained(
46
+ "microsoft/Phi-3-mini-4k-instruct",
47
+ trust_remote_code=True,
48
+ device_map="auto",
49
+ torch_dtype=torch.float32
50
+ )
51
+
52
+ # Load our fine-tuned LoRA adapter
53
+ model = PeftModel.from_pretrained(
54
+ base_model,
55
+ "jatingocodeo/phi-vlm", # Your uploaded model
56
+ device_map="auto"
57
+ )
58
+
59
+ tokenizer = AutoTokenizer.from_pretrained("jatingocodeo/phi-vlm")
60
+
61
+ return siglip, model, tokenizer
62
+
63
+ def generate_description(image, siglip, model, tokenizer):
64
+ # Convert image to RGB if needed
65
+ if image.mode != "RGB":
66
+ image = image.convert("RGB")
67
+
68
+ # Resize image to match SigLIP's expected size
69
+ image = image.resize((32, 32))
70
+
71
+ # Get image embedding from SigLIP
72
+ image_embedding = siglip.encode_image(image)
73
+
74
+ # Prepare prompt
75
+ prompt = """Below is an image. Please describe it in detail.
76
+
77
+ Image: <image>
78
+ Description: """
79
+
80
+ # Tokenize input
81
+ inputs = tokenizer(
82
+ prompt,
83
+ return_tensors="pt",
84
+ padding=True,
85
+ truncation=True,
86
+ max_length=128
87
+ ).to(model.device)
88
+
89
+ # Generate description
90
+ with torch.no_grad():
91
+ outputs = model(
92
+ **inputs,
93
+ image_embeddings=image_embedding.unsqueeze(0),
94
+ max_new_tokens=100,
95
+ temperature=0.7,
96
+ do_sample=True,
97
+ top_p=0.9
98
+ )
99
+
100
+ # Decode and return the generated text
101
+ generated_text = tokenizer.decode(outputs[0], skip_special_tokens=True)
102
+ return generated_text.split("Description: ")[-1].strip()
103
+
104
+ # Load models
105
+ print("Loading models...")
106
+ siglip, model, tokenizer = load_models()
107
+
108
+ # Create Gradio interface
109
+ def process_image(image):
110
+ description = generate_description(image, siglip, model, tokenizer)
111
+ return description
112
+
113
+ # Get CIFAR10 examples
114
+ examples = get_cifar_examples()
115
+
116
+ # Define interface
117
+ iface = gr.Interface(
118
+ fn=process_image,
119
+ inputs=gr.Image(type="pil"),
120
+ outputs=gr.Textbox(label="Generated Description"),
121
+ title="Image Description Generator",
122
+ description="""Upload an image and get a detailed description generated by our fine-tuned VLM model.
123
+ Below are sample images from CIFAR10 dataset that you can try.""",
124
+ examples=[[ex] for ex in examples] # Format examples for Gradio
125
+ )
126
+
127
+ # Launch the interface
128
+ if __name__ == "__main__":
129
+ iface.launch()