reab5555 committed on
Commit
13d7bed
·
verified ·
1 Parent(s): 165cb43

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +223 -0
app.py ADDED
@@ -0,0 +1,223 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from PIL import Image
3
+ import requests
4
+ from openai import OpenAI
5
+ from transformers import (Owlv2Processor, Owlv2ForObjectDetection,
6
+ AutoProcessor, AutoModelForMaskGeneration)
7
+ import matplotlib.pyplot as plt
8
+ import matplotlib.patches as patches
9
+ import base64
10
+ import io
11
+ import numpy as np
12
+ import gradio as gr
13
+ import json
14
+ import os
15
+ from dotenv import load_dotenv
16
+
17
# Pull settings from a local .env file (when present) into the process
# environment, then read the OpenAI credential. The value is None when the
# variable is not set; callers check for that before contacting the API.
load_dotenv()
OPENAI_API_KEY = os.environ.get('OPENAI_API_KEY')
20
+
21
+
22
def encode_image_to_base64(image, image_format="PNG"):
    """Serialize *image* into memory and return the bytes as base64 text.

    Args:
        image: Any object exposing ``save(fp, format=...)``; in this file it
            is a ``PIL.Image.Image``.
        image_format: Encoder name forwarded to ``image.save``. Defaults to
            ``"PNG"``, which is what this function always produced before
            the parameter existed.

    Returns:
        str: Base64 encoding of the serialized image bytes.
    """
    # The buffer only has to live long enough to be read back once.
    with io.BytesIO() as buffered:
        image.save(buffered, format=image_format)
        raw_bytes = buffered.getvalue()
    return base64.b64encode(raw_bytes).decode('utf-8')
26
+
27
+
28
def analyze_image(image):
    """Ask an OpenAI vision model whether *image* is surprising.

    The image is base64-encoded, embedded in a data URL, and sent together
    with a fixed prompt that requests a JSON object with the keys
    ``label``, ``element`` and ``rating``.

    Args:
        image: Image accepted by ``encode_image_to_base64`` (a PIL image in
            this file).

    Returns:
        The raw message content of the first completion choice. With JSON
        mode enabled this is expected to be a JSON string; the caller is
        responsible for parsing it.
    """
    client = OpenAI(api_key=OPENAI_API_KEY)
    base64_image = encode_image_to_base64(image)

    messages = [
        {
            "role": "user",
            "content": [
                {
                    "type": "text",
                    "text": """Your task is to determine if the image is surprising or not surprising.
if the image is surprising, determine which element, figure or object in the image is making the image surprising and write it only in one sentence with no more then 6 words, otherwise, write 'NA'.
Also rate how surprising the image is on a scale of 1-5, where 1 is not surprising at all and 5 is highly surprising.
Provide the response as a JSON with the following structure:
{
"label": "[surprising OR not surprising]",
"element": "[element]",
"rating": [1-5]
}"""
                },
                {
                    "type": "image_url",
                    "image_url": {
                        # encode_image_to_base64 emits PNG bytes, so the data
                        # URL must declare image/png (it used to say jpeg).
                        "url": f"data:image/png;base64,{base64_image}"
                    }
                }
            ]
        }
    ]

    response = client.chat.completions.create(
        # NOTE(review): the previous model, "gpt-4-vision-preview", is
        # documented as not accepting `response_format` and has since been
        # retired, so the JSON-mode request below could not succeed with it.
        # "gpt-4o" accepts both image input and JSON mode; confirm the model
        # choice against the account's available models and budget.
        model="gpt-4o",
        messages=messages,
        max_tokens=100,
        temperature=0.1,
        response_format={
            "type": "json_object"
        }
    )

    return response.choices[0].message.content
69
+
70
+
71
def show_mask(mask, ax, random_color=False):
    """Draw a segmentation mask on *ax* as a semi-transparent RGBA overlay.

    Args:
        mask: Array-like mask. Entries greater than zero are painted, all
            others stay fully transparent. Masks with more than two
            dimensions are reduced to 2-D by repeatedly taking index 0 of
            the leading axis (e.g. batch and candidate axes).
        ax: Object exposing ``imshow``; in this file a Matplotlib ``Axes``.
        random_color: When true, use a random RGB colour with alpha 0.6
            instead of the default half-transparent red.
    """
    if random_color:
        color = np.concatenate([np.random.random(3), np.array([0.6])], axis=0)
    else:
        color = np.array([1.0, 0.0, 0.0, 0.5])

    # Previously only 4-D input was reduced; a 3-D mask then produced a 5-D
    # overlay that imshow cannot draw. Peel leading axes until 2-D instead.
    while len(mask.shape) > 2:
        mask = mask[0]

    mask_image = np.zeros((*mask.shape, 4), dtype=np.float32)
    mask_image[mask > 0] = color

    ax.imshow(mask_image)
84
+
85
+
86
# Loaded processors/models keyed by device string, so the checkpoints are
# read once per process instead of on every request.
_DETECTION_MODEL_CACHE = {}


def _load_detection_models(device):
    """Return (owlv2_processor, owlv2_model, sam_processor, sam_model) for *device*, loading them on first use."""
    if device not in _DETECTION_MODEL_CACHE:
        owlv2_processor = Owlv2Processor.from_pretrained("google/owlv2-large-patch14")
        owlv2_model = Owlv2ForObjectDetection.from_pretrained("google/owlv2-large-patch14").to(device)
        sam_processor = AutoProcessor.from_pretrained("facebook/sam-vit-base")
        sam_model = AutoModelForMaskGeneration.from_pretrained("facebook/sam-vit-base").to(device)
        _DETECTION_MODEL_CACHE[device] = (
            owlv2_processor,
            owlv2_model,
            sam_processor,
            sam_model,
        )
    return _DETECTION_MODEL_CACHE[device]


def process_image_detection(image, target_label, surprise_rating):
    """Locate *target_label* in *image* and render an annotated PNG.

    OWLv2 is queried with ``target_label``. If its best-scoring box exceeds
    0.2, that box is passed to SAM for a mask, and the mask, the box, the
    detection score and a caption are drawn over the image. If nothing
    clears the threshold the image is rendered without annotations.

    Args:
        image: PIL image (``.size`` is read and it is passed to both
            processors and to ``imshow``).
        target_label: Text query describing the element to find.
        surprise_rating: Rating shown in the caption as ``<rating>/5``.

    Returns:
        io.BytesIO: Buffer positioned at 0 containing the rendered PNG.
    """
    device = "cuda" if torch.cuda.is_available() else "cpu"
    owlv2_processor, owlv2_model, sam_processor, sam_model = _load_detection_models(device)

    inputs = owlv2_processor(text=[target_label], images=image, return_tensors="pt").to(device)
    with torch.no_grad():
        outputs = owlv2_model(**inputs)

    # PIL reports (width, height); the post-processor takes (height, width).
    # NOTE(review): OWLv2 pads inputs to a square; depending on the installed
    # transformers version, boxes for non-square images may need the padded
    # size here instead - confirm against the library docs.
    target_sizes = torch.tensor([image.size[::-1]]).to(device)
    results = owlv2_processor.post_process_object_detection(outputs, target_sizes=target_sizes)[0]

    # Use an explicit figure/axes pair rather than pyplot's implicit
    # "current figure", and always close it so a failure mid-way (e.g. in
    # SAM) does not leak the figure.
    fig, ax = plt.subplots(figsize=(10, 10))
    try:
        ax.imshow(image)

        scores = results["scores"]
        if len(scores) > 0:
            max_score_idx = scores.argmax().item()
            max_score = scores[max_score_idx].item()

            if max_score > 0.2:
                box = results["boxes"][max_score_idx].cpu().numpy()

                # SAM takes boxes nested as [image][box][x0, y0, x1, y1];
                # tolist() hands it plain Python floats.
                sam_inputs = sam_processor(
                    image,
                    input_boxes=[[box.tolist()]],
                    return_tensors="pt"
                ).to(device)

                with torch.no_grad():
                    sam_outputs = sam_model(**sam_inputs)

                masks = sam_processor.image_processor.post_process_masks(
                    sam_outputs.pred_masks.cpu(),
                    sam_inputs["original_sizes"].cpu(),
                    sam_inputs["reshaped_input_sizes"].cpu()
                )

                mask = masks[0].numpy() if isinstance(masks[0], torch.Tensor) else masks[0]
                show_mask(mask, ax=ax)

                rect = patches.Rectangle(
                    (box[0], box[1]),
                    box[2] - box[0],
                    box[3] - box[1],
                    linewidth=2,
                    edgecolor='red',
                    facecolor='none'
                )
                ax.add_patch(rect)

                # Detection score just above the box's top-left corner.
                ax.text(
                    box[0], box[1] - 5,
                    f'{max_score:.2f}',
                    color='red'
                )

                # Caption to the right of the box.
                ax.text(
                    box[2] + 5, box[1],
                    f'Unexpected (Rating: {surprise_rating}/5)\n{target_label}',
                    color='red',
                    fontsize=10,
                    verticalalignment='bottom'
                )

        ax.axis('off')

        buf = io.BytesIO()
        fig.savefig(buf, format='png', bbox_inches='tight', pad_inches=0)
        buf.seek(0)
    finally:
        plt.close(fig)

    return buf
166
+
167
+
168
def process_and_analyze(image):
    """Gradio callback: classify an image and highlight its surprising part.

    Args:
        image: Uploaded image as a NumPy array or PIL image, or ``None``
            when nothing was uploaded.

    Returns:
        tuple: ``(image_or_None, text)``. On success the first item is the
        annotated image (or the unchanged input when nothing surprising was
        reported) and the second is a summary of the model's answer. On any
        failure the first item is ``None`` and the text explains the problem.
    """
    if image is None:
        return None, "Please upload an image first."

    # `not` also rejects an empty string, which previously slipped past the
    # `is None` check and only failed later inside the API call.
    if not OPENAI_API_KEY:
        return None, "OpenAI API key not found in environment variables."

    # Convert numpy array to PIL Image
    if isinstance(image, np.ndarray):
        image = Image.fromarray(image)

    try:
        # Analyze image with GPT-4
        gpt_response = analyze_image(image)
        response_data = json.loads(gpt_response)

        # Normalise the fields once: the model may add whitespace or send a
        # JSON null for "element", which used to crash on `.lower()`.
        label = str(response_data['label']).strip()
        raw_element = response_data['element']
        element = "NA" if raw_element is None else str(raw_element).strip()
        rating = response_data['rating']

        analysis_text = f"Label: {label}\nElement: {element}\nRating: {rating}/5"

        has_element = element.lower() not in ("", "na", "n/a")
        if label.lower() == "surprising" and has_element:
            # Process image with detection models
            result_buf = process_image_detection(image, element, rating)
            result_image = Image.open(result_buf)
            return result_image, analysis_text

        return image, f"{analysis_text}\nImage not surprising or no specific element found."

    except Exception as e:
        # UI boundary: report the failure in the text box instead of letting
        # the callback raise.
        return None, f"Error processing image: {str(e)}"
196
+
197
+
198
+ # Create Gradio interface
199
def create_interface():
    """Assemble and return the Gradio Blocks application.

    Layout: a title, then one row with two columns - the upload widget and
    its button on the left, the processed image and the analysis text on
    the right. Clicking the button runs ``process_and_analyze`` on the
    uploaded image and fills both outputs.

    Returns:
        The ``gr.Blocks`` instance, not yet launched.
    """
    with gr.Blocks() as app:
        gr.Markdown("# Image Surprise Analysis")

        with gr.Row():
            # Left column: input side.
            with gr.Column():
                source_image = gr.Image(label="Upload Image")
                run_button = gr.Button("Analyze Image")

            # Right column: results side.
            with gr.Column():
                annotated_image = gr.Image(label="Processed Image")
                summary_box = gr.Textbox(label="Analysis Results")

        result_targets = [annotated_image, summary_box]
        run_button.click(
            fn=process_and_analyze,
            inputs=[source_image],
            outputs=result_targets,
        )

    return app
219
+
220
+
221
if __name__ == "__main__":
    # Script entry point: build the UI, then start serving it. The Blocks
    # object is kept in the module-level name `demo`.
    demo = create_interface()
    demo.launch()