Spaces:
Sleeping
Sleeping
import torch | |
from PIL import Image | |
import requests | |
from openai import OpenAI | |
from transformers import (Owlv2Processor, Owlv2ForObjectDetection, | |
AutoProcessor, AutoModelForMaskGeneration) | |
import matplotlib.pyplot as plt | |
import matplotlib.patches as patches | |
import base64 | |
import io | |
import numpy as np | |
import gradio as gr | |
import json | |
import os | |
from dotenv import load_dotenv | |
# Load variables from a .env file into the process environment (python-dotenv)
# so the OpenAI key does not have to be exported in the shell.
load_dotenv()
# Treat an unset, empty or whitespace-only value as "no key" (None) so the
# `OPENAI_API_KEY is None` guard in process_and_analyze reports it up front
# instead of an empty/blank key being sent to the OpenAI API.
OPENAI_API_KEY = (os.getenv('OPENAI_API_KEY') or '').strip() or None
def encode_image_to_base64(image):
    """Serialise *image* as PNG and return those bytes as base64 text.

    Args:
        image: object with a PIL-style ``save(fp, format=...)`` method.

    Returns:
        str: base64 encoding of the PNG bytes, without any data-URL prefix.
    """
    with io.BytesIO() as png_stream:
        image.save(png_stream, format="PNG")
        png_bytes = png_stream.getvalue()
    encoded = base64.b64encode(png_bytes)
    return encoded.decode('utf-8')
def analyze_image(image):
    """Ask gpt-4o-mini whether *image* is surprising and which element makes it so.

    Args:
        image: PIL image; it is sent inline as a base64 PNG data URL.

    Returns:
        str: the raw message content, requested (via ``response_format``) as a
        JSON object with "label", "element" and "rating" keys. Parsing and
        validation are left to the caller.

    Raises:
        RuntimeError: if the response carries no message content.
    """
    client = OpenAI(api_key=OPENAI_API_KEY)
    base64_image = encode_image_to_base64(image)
    prompt = (
        "Your task is to determine if the image is surprising or not surprising.\n"
        "if the image is surprising, determine which element, figure or object in the image "
        "is making the image surprising and write it only in one sentence with no more then "
        "6 words, otherwise, write 'NA'.\n"
        "Also rate how surprising the image is on a scale of 1-5, where 1 is not surprising "
        "at all and 5 is highly surprising.\n"
        "Provide the response as a JSON with the following structure:\n"
        "{\n"
        '"label": "[surprising OR not surprising]",\n'
        '"element": "[element]",\n'
        '"rating": [1-5]\n'
        "}"
    )
    messages = [
        {
            "role": "user",
            "content": [
                {"type": "text", "text": prompt},
                {
                    "type": "image_url",
                    # encode_image_to_base64 emits PNG bytes, so the data URL
                    # must declare image/png (not image/jpeg) to match them.
                    "image_url": {"url": f"data:image/png;base64,{base64_image}"},
                },
            ],
        }
    ]
    response = client.chat.completions.create(
        model="gpt-4o-mini",
        messages=messages,
        max_tokens=100,
        temperature=0.1,
        response_format={"type": "json_object"},
    )
    content = response.choices[0].message.content
    if content is None:
        # Fail with a clear message instead of a TypeError from json.loads(None)
        # in the caller.
        raise RuntimeError("OpenAI response contained no message content")
    return content
def show_mask(mask, ax, random_color=False):
    """Overlay a segmentation mask on a matplotlib Axes as a translucent colour layer.

    Args:
        mask: array-like mask; entries > 0 (or True) are painted. A 2-D (H, W)
            mask is used directly. For higher-rank input, the first entry
            along every leading axis is taken, e.g. a 4-D mask becomes
            ``mask[0, 0]``.
        ax: object with an ``imshow`` method (normally a matplotlib Axes).
        random_color: if True, paint with a random RGB colour at alpha 0.6;
            otherwise use red at alpha 0.5.

    Raises:
        ValueError: if *mask* has fewer than two dimensions.
    """
    if random_color:
        color = np.concatenate([np.random.random(3), np.array([0.6])], axis=0)
    else:
        color = np.array([1.0, 0.0, 0.0, 0.5])

    mask = np.asarray(mask)
    if mask.ndim < 2:
        raise ValueError(f"mask must be at least 2-D, got shape {mask.shape}")
    # Collapse any leading batch/candidate axes down to a single (H, W) plane.
    # Previously only exactly-4-D input was reduced, so a 3-D mask produced
    # an RGBA array that imshow cannot display.
    while mask.ndim > 2:
        mask = mask[0]

    mask_image = np.zeros((*mask.shape, 4), dtype=np.float32)
    mask_image[mask > 0] = color
    ax.imshow(mask_image)
# Loaded (owlv2_processor, owlv2_model, sam_processor, sam_model) tuples keyed
# by device string. Filled lazily by _load_detection_models so each checkpoint
# is loaded once per process rather than on every request.
_DETECTION_MODELS = {}


def _load_detection_models(device):
    """Return the OWLv2 and SAM processors/models on *device*, loading them on first use."""
    models = _DETECTION_MODELS.get(device)
    if models is None:
        models = (
            Owlv2Processor.from_pretrained("google/owlv2-large-patch14"),
            Owlv2ForObjectDetection.from_pretrained("google/owlv2-large-patch14").to(device),
            AutoProcessor.from_pretrained("facebook/sam-vit-base"),
            AutoModelForMaskGeneration.from_pretrained("facebook/sam-vit-base").to(device),
        )
        _DETECTION_MODELS[device] = models
    return models


def _segment_box(image, box, sam_processor, sam_model, device):
    """Run SAM on *image* prompted with one (x0, y0, x1, y1) *box* and return one 2-D mask.

    SAM proposes several candidate masks per prompt; the candidate with the
    highest predicted IoU score is returned.
    """
    sam_inputs = sam_processor(
        image,
        input_boxes=[[box.tolist()]],
        return_tensors="pt",
    ).to(device)
    with torch.no_grad():
        sam_outputs = sam_model(**sam_inputs)
    masks = sam_processor.image_processor.post_process_masks(
        sam_outputs.pred_masks.cpu(),
        sam_inputs["original_sizes"].cpu(),
        sam_inputs["reshaped_input_sizes"].cpu(),
    )
    # masks[0] holds this image's masks for our single box; flatten the leading
    # axes to (num_candidates, H, W) and pair them with the flattened IoU scores.
    # (Assumes one box prompt, as built above.)
    candidates = masks[0].reshape(-1, *masks[0].shape[-2:])
    scores = sam_outputs.iou_scores[0].reshape(-1)
    best = int(scores.argmax()) if scores.numel() == candidates.shape[0] else 0
    mask = candidates[best]
    return mask.numpy() if isinstance(mask, torch.Tensor) else np.asarray(mask)


def process_image_detection(image, target_label, surprise_rating, score_threshold=0.2):
    """Find *target_label* in *image* with OWLv2, segment it with SAM, and render the result.

    The single highest-scoring OWLv2 detection is drawn only if its score is
    strictly greater than *score_threshold*. It gets a SAM mask, a red
    bounding box, its score, and a caption with *surprise_rating* and
    *target_label*. Otherwise the plain image is rendered.

    Args:
        image: PIL image to analyse.
        target_label: free-text query passed to OWLv2.
        surprise_rating: value shown in the caption as "Rating: {surprise_rating}/5".
        score_threshold: minimum (exclusive) OWLv2 score required to draw a detection.

    Returns:
        io.BytesIO: PNG bytes of the rendered figure, rewound to position 0.
    """
    device = "cuda" if torch.cuda.is_available() else "cpu"
    owlv2_processor, owlv2_model, sam_processor, sam_model = _load_detection_models(device)

    # Normalise to RGB so RGBA/greyscale uploads reach both processors (and the
    # plot) as 3-channel images.
    image = image.convert("RGB")
    width, height = image.size

    inputs = owlv2_processor(text=[target_label], images=image, return_tensors="pt").to(device)
    with torch.no_grad():
        outputs = owlv2_model(**inputs)

    # OWLv2 pads its input to a square before inference, so its normalised boxes
    # are relative to a side of max(width, height). Scaling both axes by that
    # side maps them to original pixel coordinates; scaling by (height, width)
    # would stretch boxes on non-square images.
    side = max(width, height)
    target_sizes = torch.tensor([[side, side]], device=device)
    results = owlv2_processor.post_process_object_detection(outputs, target_sizes=target_sizes)[0]

    # Draw on explicit figure/axes handles and close that figure in `finally`
    # so an exception mid-render does not leak it.
    fig, ax = plt.subplots(figsize=(10, 10))
    try:
        ax.imshow(image)
        scores = results["scores"]
        if len(scores) > 0:
            best_idx = scores.argmax().item()
            best_score = scores[best_idx].item()
            if best_score > score_threshold:
                box = results["boxes"][best_idx].cpu().numpy()
                # Predictions can extend into the padded area; keep them on the image.
                box = np.clip(box, 0, [width, height, width, height])
                show_mask(_segment_box(image, box, sam_processor, sam_model, device), ax=ax)
                x0, y0, x1, y1 = box
                ax.add_patch(patches.Rectangle(
                    (x0, y0),
                    x1 - x0,
                    y1 - y0,
                    linewidth=2,
                    edgecolor='red',
                    facecolor='none',
                ))
                ax.text(x0, y0 - 5, f'{best_score:.2f}', color='red')
                ax.text(
                    x1 + 5, y0,
                    f'Unexpected (Rating: {surprise_rating}/5)\n{target_label}',
                    color='red',
                    fontsize=10,
                    verticalalignment='bottom',
                )
        ax.axis('off')
        buf = io.BytesIO()
        fig.savefig(buf, format='png', bbox_inches='tight', pad_inches=0)
        buf.seek(0)
    finally:
        plt.close(fig)
    return buf
def process_and_analyze(image):
    """Gradio click handler: classify *image* with GPT and annotate the surprising element.

    Args:
        image: uploaded image as a numpy array or PIL image, or None.

    Returns:
        tuple: (image or None, message).
            - Surprising image with a named element: the annotated rendering
              and the label/element/rating summary.
            - Otherwise: the input image and the summary plus a note.
            - On any failure: None and an error message.
    """
    if image is None:
        return None, "Please upload an image first."
    if OPENAI_API_KEY is None:
        return None, "OpenAI API key not found in environment variables."
    if isinstance(image, np.ndarray):
        image = Image.fromarray(image)

    try:
        response_data = json.loads(analyze_image(image))
        # The model's JSON is untrusted. Tolerate missing/null fields and stray
        # whitespace instead of failing with KeyError/AttributeError on
        # e.g. "element": null.
        label = str(response_data.get("label") or "").strip()
        element = str(response_data.get("element") or "").strip()
        rating = response_data.get("rating")
        analysis_text = f"Label: {label}\nElement: {element}\nRating: {rating}/5"

        # "NA" is the prompt's sentinel for "no element"; accept common variants
        # so they are not sent to the detector as a search query.
        has_element = element.lower() not in ("", "na", "n/a")
        if label.lower() == "surprising" and has_element:
            result_buf = process_image_detection(image, element, rating)
            return Image.open(result_buf), analysis_text
        return image, f"{analysis_text}\nImage not surprising or no specific element found."
    except Exception as e:
        # UI boundary: surface any failure in the results box instead of
        # crashing the Gradio callback.
        return None, f"Error processing image: {str(e)}"
# Gradio UI definition.
def create_interface():
    """Assemble the two-column Gradio Blocks app.

    The left column takes an upload and an "Analyze Image" button. The right
    column shows the processed image and the analysis text. The button runs
    process_and_analyze on the uploaded image.

    Returns:
        The constructed ``gr.Blocks`` instance (not yet launched).
    """
    with gr.Blocks() as demo:
        gr.Markdown("# Image Surprise Analysis")
        with gr.Row():
            with gr.Column():
                uploaded = gr.Image(label="Upload Image")
                run_button = gr.Button("Analyze Image")
            with gr.Column():
                annotated = gr.Image(label="Processed Image")
                report = gr.Textbox(label="Analysis Results")
        # Event wiring must happen inside the Blocks context.
        run_button.click(
            fn=process_and_analyze,
            inputs=[uploaded],
            outputs=[annotated, report],
        )
    return demo
if __name__ == "__main__":
    # Script entry point: build the UI, keep it bound to the module-level
    # name `demo`, and start the Gradio server.
    demo = create_interface()
    demo.launch()