import os
import json
import base64
from io import BytesIO

from dotenv import load_dotenv
from openai import OpenAI
import gradio as gr
import cv2
import numpy as np
from PIL import Image, ImageDraw
import requests
import torch
from transformers import (
    AutoProcessor,
    Owlv2ForObjectDetection,
    AutoModelForZeroShotObjectDetection
)
from transformers.utils.constants import OPENAI_CLIP_MEAN, OPENAI_CLIP_STD

load_dotenv()
os.environ['OPENAI_API_KEY'] = os.getenv('OPENAI_API_KEY', 'your-key-here')
PLANTNET_API_KEY = os.getenv('PLANTNET_API_KEY', 'your-plantnet-key-here')
MODEL = "gpt-4o"
openai = OpenAI()

device = "cuda" if torch.cuda.is_available() else "cpu"

owlv2_processor = AutoProcessor.from_pretrained("google/owlv2-base-patch16")
owlv2_model = Owlv2ForObjectDetection.from_pretrained("google/owlv2-base-patch16").to(device)

dino_processor = AutoProcessor.from_pretrained("IDEA-Research/grounding-dino-base")
dino_model = AutoModelForZeroShotObjectDetection.from_pretrained("IDEA-Research/grounding-dino-base").to(device)
system_message = """You are an expert in object detection. When users mention: |
|
1. "count [object(s)]" - Use detect_objects with proper format based on model |
|
2. "detect [object(s)]" - Same as count |
|
3. "show [object(s)]" - Same as count |
|
|
|
For DINO model: Format queries as "a [object]." (e.g., "a frog.") |
|
For Owlv2 model: Format as [["a photo of [object]", "a photo of [object2]"]] |
|
|
|
Always use object detection tool when counting/detecting is mentioned.""" |
|
|
|
system_message += "Always be accurate. If you don't know the answer, say so." |
|
|
|
|
|


class State:
    def __init__(self):
        self.current_image = None
        self.last_prediction = None
        self.current_model = "owlv2"


state = State()
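# NOTE: a single module-level State instance is shared by every Gradio session,
# so concurrent users of the demo will overwrite each other's image and results.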


def get_preprocessed_image(pixel_values):
    """Undo the CLIP-style normalization applied by the Owlv2 processor so the
    padded model input can be visualized. (Helper only; not called elsewhere in
    this script.)"""
    pixel_values = pixel_values.squeeze().numpy()
    unnormalized_image = (pixel_values * np.array(OPENAI_CLIP_STD)[:, None, None]) + np.array(OPENAI_CLIP_MEAN)[:, None, None]
    unnormalized_image = (unnormalized_image * 255).astype(np.uint8)
    unnormalized_image = np.moveaxis(unnormalized_image, 0, -1)
    return unnormalized_image


def encode_image_to_base64(image_array):
    if image_array is None:
        return None
    # Convert to RGB first so arrays with an alpha channel can still be saved as JPEG.
    image = Image.fromarray(image_array).convert("RGB")
    buffered = BytesIO()
    image.save(buffered, format="JPEG")
    return base64.b64encode(buffered.getvalue()).decode('utf-8')
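# The returned base64 string is later wrapped in a data URL
# (f"data:image/jpeg;base64,{...}") when the image is sent to GPT-4o in chat().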


def format_query_for_model(text_input, model_type="owlv2"):
    """Format the user's request into the query shape each detector expects."""
    text = text_input.lower()
    words = [w.strip('.,?!') for w in text.split()
             if w not in ['count', 'detect', 'show', 'me', 'the', 'and', 'a', 'an']]
    if not words:
        # Fall back to a generic query if every word was filtered out.
        words = ["object"]

    if model_type == "owlv2":
        # Owlv2 takes a list of text queries per image.
        return [["a photo of " + obj for obj in words]]
    else:
        # Grounding DINO takes a single lower-cased phrase ending with a period.
        return f"a {words[0]}."


def detect_objects(query_text):
    if state.current_image is None:
        return {"count": 0, "message": "No image provided"}

    image = Image.fromarray(state.current_image)
    draw = ImageDraw.Draw(image)

    if state.current_model == "owlv2":
        inputs = owlv2_processor(text=query_text, images=image, return_tensors="pt").to(device)
        with torch.no_grad():
            outputs = owlv2_model(**inputs)
        # target_sizes expects (height, width); PIL's image.size is (width, height).
        results = owlv2_processor.post_process_object_detection(
            outputs=outputs, threshold=0.2, target_sizes=torch.Tensor([image.size[::-1]])
        )
    else:
        inputs = dino_processor(images=image, text=query_text, return_tensors="pt").to(device)
        with torch.no_grad():
            outputs = dino_model(**inputs)
        results = dino_processor.post_process_grounded_object_detection(
            outputs, inputs.input_ids, box_threshold=0.1, text_threshold=0.3,
            target_sizes=[image.size[::-1]]
        )

    boxes = results[0]["boxes"]
    scores = results[0]["scores"]

    # Draw each detection onto the image for display in the UI.
    for box, score in zip(boxes, scores):
        box = [round(i) for i in box.tolist()]
        draw.rectangle(box, outline="red", width=3)
        draw.text((box[0], box[1]), f"Score: {score:.2f}", fill="red")

    state.last_prediction = np.array(image)
    return {
        "count": len(boxes),
        "confidence": scores.tolist(),
        "message": f"Detected {len(boxes)} objects"
    }
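# For reference, a successful detection returns a plain dict along these lines
# (numbers invented for illustration):
#   {"count": 2, "confidence": [0.41, 0.27], "message": "Detected 2 objects"}
# The annotated image itself is stored in state.last_prediction rather than returned.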


def identify_plant():
    if state.current_image is None:
        return {"error": "No image provided"}

    image = Image.fromarray(state.current_image)
    img_byte_arr = BytesIO()
    image.save(img_byte_arr, format='JPEG')
    img_byte_arr = img_byte_arr.getvalue()

    api_endpoint = f"https://my-api.plantnet.org/v2/identify/all?api-key={PLANTNET_API_KEY}"
    files = [('images', ('image.jpg', img_byte_arr))]
    data = {'organs': ['leaf']}

    try:
        # A timeout keeps the Gradio UI from hanging if the API is unreachable.
        response = requests.post(api_endpoint, files=files, data=data, timeout=30)
        if response.status_code == 200:
            result = response.json()
            best_match = result['results'][0]
            return {
                "scientific_name": best_match['species']['scientificName'],
                "common_names": best_match['species'].get('commonNames', []),
                "family": best_match['species']['family']['scientificName'],
                "genus": best_match['species']['genus']['scientificName'],
                "confidence": f"{best_match['score']*100:.1f}%"
            }
        else:
            return {"error": f"API Error: {response.status_code}"}
    except Exception as e:
        return {"error": f"Error: {str(e)}"}


object_detection_function = {
    "name": "detect_objects",
    "description": "Use this function to detect and count objects in images based on text queries.",
    "parameters": {
        "type": "object",
        "properties": {
            "query_text": {
                "type": "array",
                "description": "List of text queries describing objects to detect",
                "items": {"type": "string"}
            }
        },
        "required": ["query_text"]
    }
}

plant_identification_function = {
    "name": "identify_plant",
    "description": "Use this when asked about plant species identification or botanical classification.",
    "parameters": {
        "type": "object",
        "properties": {},
        "required": []
    }
}

tools = [
    {"type": "function", "function": object_detection_function},
    {"type": "function", "function": plant_identification_function}
]
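# chat() below relies on the standard OpenAI tool-calling flow: when
# finish_reason == "tool_calls", each entry in message.tool_calls exposes
# .id, .function.name and .function.arguments (a JSON string). Note that the
# argument payload is ignored here; the detection query is rebuilt locally
# from the raw user text instead.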


def format_tool_response(tool_response_content):
    """Render a tool result as user-facing text. (Currently not wired into the Gradio flow.)"""
    data = json.loads(tool_response_content)
    if "error" in data:
        return f"Error: {data['error']}"
    elif "scientific_name" in data:
        return f"""Plant Identification Results:

Scientific Name: {data['scientific_name']}
Common Names: {', '.join(data['common_names']) if data['common_names'] else 'Not available'}
Family: {data['family']}
Confidence: {data['confidence']}"""
    else:
        return f"I detected {data['count']} objects in the image."


def chat(message, image, history):
    if image is not None:
        state.current_image = image

    if state.current_image is None:
        return "Please upload an image first.", None

    base64_image = encode_image_to_base64(state.current_image)
    messages = [{"role": "system", "content": system_message}]

    # Replay the prior conversation turns for context.
    for human, assistant in history:
        messages.append({"role": "user", "content": human})
        messages.append({"role": "assistant", "content": assistant})

    # Pre-compute the detection query from the raw user text; the model's own
    # tool-call arguments are not used.
    objects_to_detect = message.lower()
    formatted_query = format_query_for_model(objects_to_detect, state.current_model)

    messages.append({
        "role": "user",
        "content": [
            {"type": "text", "text": message},
            {"type": "image_url", "image_url": {"url": f"data:image/jpeg;base64,{base64_image}"}}
        ]
    })

    response = openai.chat.completions.create(
        model=MODEL,
        messages=messages,
        tools=tools,
        max_tokens=300
    )

    if response.choices[0].finish_reason == "tool_calls":
        assistant_message = response.choices[0].message
        messages.append(assistant_message)

        for tool_call in assistant_message.tool_calls:
            if tool_call.function.name == "detect_objects":
                results = detect_objects(formatted_query)
            else:
                results = identify_plant()

            tool_response = {
                "role": "tool",
                "content": json.dumps(results),
                "tool_call_id": tool_call.id
            }
            messages.append(tool_response)

        # Second round trip: let the model summarise the tool results.
        response = openai.chat.completions.create(
            model=MODEL,
            messages=messages,
            max_tokens=300
        )

    return response.choices[0].message.content, state.last_prediction


def update_model(choice):
    print(f"Model switched to: {choice}")
    state.current_model = choice.lower()
    return f"Model switched to {choice}"


with gr.Blocks() as demo:
    gr.Markdown("# Object Detection and Plant Analysis System")

    with gr.Row():
        with gr.Column():
            model_choice = gr.Radio(
                choices=["Owlv2", "DINO"],
                value="Owlv2",
                label="Select Detection Model",
                interactive=True
            )
            image_input = gr.Image(type="numpy", label="Upload Image")
            text_input = gr.Textbox(
                label="Ask about the image",
                placeholder="e.g., 'What objects do you see?' or 'What species is this plant?'"
            )
            with gr.Row():
                submit_btn = gr.Button("Analyze")
                reset_btn = gr.Button("Reset")

        with gr.Column():
            chatbot = gr.Chatbot()
            output_image = gr.Image(type="numpy", label="Detected Objects")

    def process_interaction(message, image, history):
        response, pred_image = chat(message, image, history)
        history.append((message, response))
        return "", pred_image, history

    def reset_interface():
        state.current_image = None
        state.last_prediction = None
        return None, None, None, []

    model_choice.change(fn=update_model, inputs=[model_choice], outputs=[gr.Textbox(visible=False)])

    submit_btn.click(
        fn=process_interaction,
        inputs=[text_input, image_input, chatbot],
        outputs=[text_input, output_image, chatbot]
    )

    reset_btn.click(
        fn=reset_interface,
        inputs=[],
        outputs=[image_input, output_image, text_input, chatbot]
    )

    gr.Markdown("""## Instructions
1. Select the detection model (Owlv2 or DINO)
2. Upload an image
3. Ask specific questions about objects or plants
4. Click Analyze to get results""")


demo.launch(share=True)