# File-viewer header captured with this source: author "obichimav",
# commit 70a1336 ("Create app.py", verified), file size 11.3 kB.
# imports
import os
import json
import base64
from io import BytesIO
from dotenv import load_dotenv
from openai import OpenAI
import gradio as gr
import cv2
import numpy as np
from PIL import Image, ImageDraw
import requests
import torch
from transformers import (
AutoProcessor,
Owlv2ForObjectDetection,
AutoModelForZeroShotObjectDetection
)
# from transformers import AutoProcessor, Owlv2ForObjectDetection
from transformers.utils.constants import OPENAI_CLIP_MEAN, OPENAI_CLIP_STD
# Initialization: pull secrets from a local .env file (if present) into the environment.
load_dotenv()
openai_api_key = os.getenv('OPENAI_API_KEY', 'your-key-here')
# The OpenAI client below is constructed without an explicit key, so the resolved
# value is published to the environment first.
# NOTE(review): when no key is configured the placeholder is used, so the client is
# created but API calls will presumably fail to authenticate.
os.environ['OPENAI_API_KEY'] = openai_api_key
PLANTNET_API_KEY = os.getenv('PLANTNET_API_KEY', 'your-plantnet-key-here')
MODEL = "gpt-4o"  # chat model used for every completion request
openai = OpenAI()

# Run the detectors on the GPU when one is available.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

# Zero-shot detectors; both checkpoints are loaded with from_pretrained at import time.
OWLV2_CHECKPOINT = "google/owlv2-base-patch16"
DINO_CHECKPOINT = "IDEA-Research/grounding-dino-base"
owlv2_processor = AutoProcessor.from_pretrained(OWLV2_CHECKPOINT)
owlv2_model = Owlv2ForObjectDetection.from_pretrained(OWLV2_CHECKPOINT).to(device)
dino_processor = AutoProcessor.from_pretrained(DINO_CHECKPOINT)
dino_model = AutoModelForZeroShotObjectDetection.from_pretrained(DINO_CHECKPOINT).to(device)
# System prompt for the chat model: routes counting/detection requests to the
# detect_objects tool and describes the query format each detector expects.
system_message = """You are an expert in object detection. When users mention:
1. "count [object(s)]" - Use detect_objects with proper format based on model
2. "detect [object(s)]" - Same as count
3. "show [object(s)]" - Same as count
For DINO model: Format queries as "a [object]." (e.g., "a frog.")
For Owlv2 model: Format as [["a photo of [object]", "a photo of [object2]"]]
Always use object detection tool when counting/detecting is mentioned."""
# Start the extra instruction on its own line; without a separator the prompt
# read "...is mentioned.Always be accurate."
system_message += "\nAlways be accurate. If you don't know the answer, say so."
class State:
    """Mutable data shared by all Gradio callbacks in this process."""

    def __init__(self):
        # Detector that detect_objects() runs; "owlv2" until the user switches.
        self.current_model = "owlv2"  # Default model
        # Most recently uploaded numpy image (None before the first upload / after reset).
        self.current_image = None
        # Annotated copy of the image from the latest detection, if any.
        self.last_prediction = None


# Single module-level instance, so it is shared by every Gradio session.
state = State()
def get_preprocessed_image(pixel_values):
    """Undo CLIP mean/std normalization so processor output can be viewed as an image.

    Args:
        pixel_values: tensor whose squeezed shape is (3, H, W) -- presumably the
            ``pixel_values`` of a single-image processor batch, (1, 3, H, W).

    Returns:
        ``uint8`` numpy array of shape (H, W, 3).
    """
    # detach()/cpu(): Tensor.numpy() refuses tensors that require grad or live on the GPU.
    pixels = pixel_values.squeeze().detach().cpu().numpy()
    mean = np.array(OPENAI_CLIP_MEAN)[:, None, None]
    std = np.array(OPENAI_CLIP_STD)[:, None, None]
    unnormalized = pixels * std + mean
    # Clip before the cast: values slightly outside [0, 1] would otherwise wrap
    # around when converted to uint8.
    unnormalized = (np.clip(unnormalized, 0.0, 1.0) * 255).astype(np.uint8)
    # Channels-first (C, H, W) -> channels-last (H, W, C) for PIL / display.
    return np.moveaxis(unnormalized, 0, -1)
def encode_image_to_base64(image_array):
    """Encode a numpy image as a base64 JPEG string (used for the OpenAI image_url payload).

    Args:
        image_array: numpy image accepted by ``PIL.Image.fromarray``, or None.

    Returns:
        The base64-encoded JPEG bytes as a ``str``, or None when ``image_array`` is None.
    """
    if image_array is None:
        return None
    image = Image.fromarray(image_array)
    # JPEG cannot store alpha or palette data (PIL raises OSError for e.g. RGBA),
    # so normalize such images to RGB before encoding.
    if image.mode not in ("RGB", "L"):
        image = image.convert("RGB")
    buffered = BytesIO()
    image.save(buffered, format="JPEG")
    return base64.b64encode(buffered.getvalue()).decode('utf-8')
# Command/filler words that are never treated as object names.
_QUERY_STOP_WORDS = frozenset({'count', 'detect', 'show', 'me', 'the', 'and', 'a', 'an'})


def format_query_for_model(text_input, model_type="owlv2", default_object="object"):
    """Turn a free-form request into a detector text query.

    Every word of ``text_input`` that is not a command/filler word is used as an
    object name verbatim (no singularization), e.g.
    "count frogs and horses" -> ["frogs", "horses"].

    Args:
        text_input: the user's message.
        model_type: "owlv2" for a nested list of "a photo of ..." prompts; any
            other value produces a single DINO-style "a <object>." string.
        default_object: object name used when no usable word remains (e.g. an
            empty message), so the detector never receives an empty query.

    Returns:
        ``[["a photo of <obj>", ...]]`` for OWLv2, otherwise ``"a <obj>."``.
    """
    words = []
    for token in text_input.lower().split():
        word = token.strip('.,?!')
        # Filter after stripping so "and," / "the." count as filler too, and drop
        # tokens that were pure punctuation (they strip to "").
        if word and word not in _QUERY_STOP_WORDS:
            words.append(word)
    if not words:
        words = [default_object]
    if model_type == "owlv2":
        return [["a photo of " + obj for obj in words]]
    # DINO: only the first object is used.
    # NOTE(review): Grounding DINO accepts several "."-terminated phrases
    # ("a cat. a dog."), so this could be widened to all objects.
    return f"a {words[0]}."
def detect_objects(query_text):
    """Run the selected detector on the shared image and draw its boxes.

    Args:
        query_text: prompt in the selected model's format -- a nested list such
            as ``[["a photo of frog"]]`` for OWLv2, or a string such as
            ``"a frog."`` for Grounding DINO.

    Returns:
        ``{"count", "confidence", "message"}`` for the detections, or
        ``{"count": 0, "message": "No image provided"}`` when no image is set.
        The annotated image is stored in ``state.last_prediction``.
    """
    if state.current_image is None:
        return {"count": 0, "message": "No image provided"}

    canvas = Image.fromarray(state.current_image)
    pen = ImageDraw.Draw(canvas)
    height_width = canvas.size[::-1]  # PIL reports (width, height)

    if state.current_model != "owlv2":
        # Grounding DINO
        inputs = dino_processor(images=canvas, text=query_text, return_tensors="pt").to(device)
        with torch.no_grad():
            outputs = dino_model(**inputs)
        detections = dino_processor.post_process_grounded_object_detection(
            outputs, inputs.input_ids, box_threshold=0.1, text_threshold=0.3,
            target_sizes=[height_width],
        )[0]
    else:
        inputs = owlv2_processor(text=query_text, images=canvas, return_tensors="pt").to(device)
        with torch.no_grad():
            outputs = owlv2_model(**inputs)
        # NOTE(review): the HF OWLv2 docs rescale boxes against the padded (square)
        # preprocessed image (see get_preprocessed_image); confirm that using the
        # original (h, w) is correct for non-square images with the installed
        # transformers version.
        detections = owlv2_processor.post_process_object_detection(
            outputs=outputs, threshold=0.2, target_sizes=torch.Tensor([height_width]),
        )[0]

    boxes = detections["boxes"]
    scores = detections["scores"]
    for box, score in zip(boxes, scores):
        x0, y0, x1, y1 = (round(coord) for coord in box.tolist())
        pen.rectangle([x0, y0, x1, y1], outline="red", width=3)
        pen.text((x0, y0), f"Score: {score:.2f}", fill="red")

    state.last_prediction = np.array(canvas)
    total = len(boxes)
    return {
        "count": total,
        "confidence": scores.tolist(),
        "message": f"Detected {total} objects",
    }
def identify_plant(timeout=30):
    """Identify the plant in ``state.current_image`` with the Pl@ntNet API.

    Args:
        timeout: seconds to wait for the HTTP request before giving up; without a
            timeout ``requests`` can block the Gradio worker indefinitely.

    Returns:
        On success, a dict with ``scientific_name``, ``common_names``, ``family``,
        ``genus`` and ``confidence`` (percentage string) for the top match.
        On any failure, a dict with a single ``error`` key.
    """
    if state.current_image is None:
        return {"error": "No image provided"}
    try:
        image = Image.fromarray(state.current_image)
        # JPEG cannot hold alpha/palette data; normalize so save() cannot fail.
        if image.mode not in ("RGB", "L"):
            image = image.convert("RGB")
        buffer = BytesIO()
        image.save(buffer, format='JPEG')
        img_bytes = buffer.getvalue()

        response = requests.post(
            "https://my-api.plantnet.org/v2/identify/all",
            params={"api-key": PLANTNET_API_KEY},  # URL-encoded by requests
            files=[('images', ('image.jpg', img_bytes, 'image/jpeg'))],
            data={'organs': ['leaf']},
            timeout=timeout,
        )
        if response.status_code != 200:
            return {"error": f"API Error: {response.status_code}"}

        matches = response.json().get('results') or []
        if not matches:
            return {"error": "No plant species match found"}
        best_match = matches[0]
        species = best_match['species']
        return {
            "scientific_name": species['scientificName'],
            "common_names": species.get('commonNames', []),
            "family": species['family']['scientificName'],
            "genus": species['genus']['scientificName'],
            "confidence": f"{best_match['score']*100:.1f}%"
        }
    except Exception as e:
        # Tool boundary: report every failure to the chat model as data instead
        # of crashing the UI callback.
        return {"error": f"Error: {str(e)}"}
# Tool definitions: OpenAI function-calling schemas offered to the chat model.
# NOTE(review): detect_objects declares no "required" list, so the model may
# omit query_text; callers should not rely on it being present.
object_detection_function = dict(
    name="detect_objects",
    description="Use this function to detect and count objects in images based on text queries.",
    parameters=dict(
        type="object",
        properties=dict(
            query_text=dict(
                type="array",
                description="List of text queries describing objects to detect",
                items=dict(type="string"),
            ),
        ),
    ),
)
# identify_plant takes no arguments; it always works on the shared current image.
plant_identification_function = dict(
    name="identify_plant",
    description="Use this when asked about plant species identification or botanical classification.",
    parameters=dict(type="object", properties={}, required=[]),
)
tools = [
    {"type": "function", "function": spec}
    for spec in (object_detection_function, plant_identification_function)
]
def format_tool_response(tool_response_content):
    """Render a JSON tool result (from identify_plant / detect_objects) as user-facing text.

    Args:
        tool_response_content: JSON string produced by ``json.dumps`` of a tool result.

    Returns:
        An error line, a multi-line plant identification summary, or a one-line
        object count, depending on which keys the result contains.
    """
    data = json.loads(tool_response_content)
    if "error" in data:
        return f"Error: {data['error']}"
    if "scientific_name" in data:
        common = ", ".join(data.get("common_names") or []) or "Not available"
        # Emoji were previously mis-encoded (UTF-8 bytes shown as "πŸ..." mojibake).
        return f"""📋 Plant Identification Results:
🌿 Scientific Name: {data['scientific_name']}
👥 Common Names: {common}
👪 Family: {data['family']}
🎯 Confidence: {data['confidence']}"""
    return f"I detected {data.get('count', 0)} objects in the image."
def _detection_objects_from_tool_call(tool_call):
    """Extract object names from a detect_objects tool call's JSON arguments.

    Accepts ``query_text`` as a string, a list of strings, or a nested list of
    strings (the system prompt suggests a nested list for OWLv2). It strips prompt
    scaffolding ("a photo of", leading articles, trailing ".") so the names can
    be re-wrapped for whichever detector is selected.

    Returns:
        Lower-cased object names; empty if the arguments are missing or malformed.
    """
    try:
        args = json.loads(tool_call.function.arguments or "{}")
    except (TypeError, ValueError):
        return []
    queries = args.get("query_text") if isinstance(args, dict) else None
    if isinstance(queries, str):
        queries = [queries]
    if not isinstance(queries, list):
        return []
    objects = []
    for entry in queries:
        for text in (entry if isinstance(entry, list) else [entry]):
            if not isinstance(text, str):
                continue
            name = text.strip().rstrip(".").strip().lower()
            # Applied in order, so "a photo of a frog" -> "a frog" -> "frog".
            for prefix in ("a photo of ", "an ", "a ", "the "):
                if name.startswith(prefix):
                    name = name[len(prefix):].strip()
            if name:
                objects.append(name)
    return objects


def _detection_query(objects, model_type):
    """Wrap object names in the prompt format the given detector expects."""
    if model_type == "owlv2":
        return [["a photo of " + obj for obj in objects]]
    # Grounding DINO takes lower-cased phrases, each terminated by ".".
    return " ".join(f"a {obj}." for obj in objects)


def chat(message, image, history):
    """Answer one user turn, letting the chat model call the vision tools.

    Args:
        message: the user's text.
        image: newly uploaded numpy image, or None to keep using the previous one.
        history: earlier turns as (user, assistant) pairs.

    Returns:
        ``(reply_text, annotated_image)``; ``annotated_image`` is the detection
        overlay produced during this turn, or None when no detection ran.
    """
    if image is not None:
        state.current_image = image
    if state.current_image is None:
        return "Please upload an image first.", None
    # Drop any overlay from earlier turns so a stale (possibly different) image
    # is not shown next to this answer.
    state.last_prediction = None

    base64_image = encode_image_to_base64(state.current_image)
    messages = [{"role": "system", "content": system_message}]
    for human, assistant in history:
        messages.append({"role": "user", "content": human})
        messages.append({"role": "assistant", "content": assistant})
    messages.append({
        "role": "user",
        "content": [
            {"type": "text", "text": message},
            {"type": "image_url", "image_url": {"url": f"data:image/jpeg;base64,{base64_image}"}},
        ],
    })

    response = openai.chat.completions.create(
        model=MODEL,
        messages=messages,
        tools=tools,
        max_tokens=300
    )

    if response.choices[0].finish_reason == "tool_calls":
        assistant_message = response.choices[0].message
        messages.append(assistant_message)
        for tool_call in assistant_message.tool_calls:
            if tool_call.function.name == "detect_objects":
                # Prefer the objects the model asked for; the keyword heuristic on the
                # raw text turns every non-filler word ("how", "many", ...) into a
                # query and inflates the count.
                objects = _detection_objects_from_tool_call(tool_call)
                if objects:
                    query = _detection_query(objects, state.current_model)
                else:
                    query = format_query_for_model(message.lower(), state.current_model)
                results = detect_objects(query)
            else:
                results = identify_plant()
            messages.append({
                "role": "tool",
                "content": json.dumps(results),
                "tool_call_id": tool_call.id
            })
        # Second pass (without tools) lets the model phrase the tool results.
        response = openai.chat.completions.create(
            model=MODEL,
            messages=messages,
            max_tokens=300
        )

    return response.choices[0].message.content, state.last_prediction
def update_model(choice):
    """Switch the detector used by detect_objects() and return a status line.

    Args:
        choice: radio-button label such as "Owlv2" or "DINO"; stored lower-cased
            in ``state.current_model``.
    """
    status = f"Model switched to {choice}"
    print(f"Model switched to: {choice}")
    state.current_model = choice.lower()
    return status
# Create Gradio interface
# Layout: controls (model picker, image upload, question box, buttons) in the left
# column; chat transcript and annotated detection image in the right column.
with gr.Blocks() as demo:
    gr.Markdown("# Object Detection and Plant Analysis System")
    with gr.Row():
        with gr.Column():
            # Labels are lower-cased by update_model into state.current_model ("owlv2" / "dino").
            model_choice = gr.Radio(
                choices=["Owlv2", "DINO"],
                value="Owlv2",
                label="Select Detection Model",
                interactive=True
            )
            image_input = gr.Image(type="numpy", label="Upload Image")
            text_input = gr.Textbox(
                label="Ask about the image",
                placeholder="e.g., 'What objects do you see?' or 'What species is this plant?'"
            )
            with gr.Row():
                submit_btn = gr.Button("Analyze")
                reset_btn = gr.Button("Reset")
        with gr.Column():
            # NOTE(review): history is appended as (user, assistant) tuples below;
            # newer Gradio versions prefer type="messages" -- confirm against the installed version.
            chatbot = gr.Chatbot()
            # output_image = gr.Image(label="Detected Objects")
            output_image = gr.Image(type="numpy", label="Detected Objects")

    def process_interaction(message, image, history):
        """Analyze-button callback: run one chat turn and update the UI.

        Returns (cleared textbox value, annotated image or None, updated history).
        """
        response, pred_image = chat(message, image, history)
        history.append((message, response))
        return "", pred_image, history

    def reset_interface():
        """Reset-button callback: clear the shared image/prediction and blank the UI.

        Return values map to image_input, output_image, text_input and chatbot.
        state.current_model is not touched, so the selected detector is kept.
        """
        state.current_image = None
        state.last_prediction = None
        return None, None, None, []

    # The hidden Textbox only exists to receive update_model's status string.
    model_choice.change(fn=update_model, inputs=[model_choice], outputs=[gr.Textbox(visible=False)])
    submit_btn.click(
        fn=process_interaction,
        inputs=[text_input, image_input, chatbot],
        outputs=[text_input, output_image, chatbot]
    )
    reset_btn.click(
        fn=reset_interface,
        inputs=[],
        outputs=[image_input, output_image, text_input, chatbot]
    )
    gr.Markdown("""## Instructions
1. Select the detection model (Owlv2 or DINO)
2. Upload an image
3. Ask specific questions about objects or plants
4. Click Analyze to get results""")
# share=True requests a public share link in addition to the local server.
demo.launch(share=True)