# # imports

# import os
# import json
# import base64
# from io import BytesIO
# from dotenv import load_dotenv
# from openai import OpenAI
# import gradio as gr
# import numpy as np
# from PIL import Image, ImageDraw
# import requests
# import torch
# from transformers import (
#     AutoProcessor,
#     Owlv2ForObjectDetection,
#     AutoModelForZeroShotObjectDetection
# )
# # from transformers import AutoProcessor, Owlv2ForObjectDetection
# from transformers.utils.constants import OPENAI_CLIP_MEAN, OPENAI_CLIP_STD

# # Initialization
# load_dotenv()
# os.environ['OPENAI_API_KEY'] = os.getenv('OPENAI_API_KEY', 'your-key-here')
# PLANTNET_API_KEY = os.getenv('PLANTNET_API_KEY', 'your-plantnet-key-here')
# MODEL = "gpt-4o"
# openai = OpenAI()

# # Initialize models
# device = "cuda" if torch.cuda.is_available() else "cpu"

# # Owlv2
# owlv2_processor = AutoProcessor.from_pretrained("google/owlv2-base-patch16")
# owlv2_model = Owlv2ForObjectDetection.from_pretrained("google/owlv2-base-patch16").to(device)

# # DINO
# dino_processor = AutoProcessor.from_pretrained("IDEA-Research/grounding-dino-base")
# dino_model = AutoModelForZeroShotObjectDetection.from_pretrained("IDEA-Research/grounding-dino-base").to(device)

# system_message = """You are an expert in object detection. When users mention:
# 1. "count [object(s)]" - Use detect_objects with proper format based on model
# 2. "detect [object(s)]" - Same as count
# 3. "show [object(s)]" - Same as count
# For DINO model: Format queries as "a [object]." (e.g., "a frog.")
# For Owlv2 model: Format as [["a photo of [object]", "a photo of [object2]"]]
# Always use object detection tool when counting/detecting is mentioned."""
# system_message += "\nAlways be accurate. If you don't know the answer, say so."
# class State:
#     def __init__(self):
#         self.current_image = None
#         self.last_prediction = None
#         self.current_model = "owlv2"  # Default model

# state = State()

# def get_preprocessed_image(pixel_values):
#     pixel_values = pixel_values.squeeze().numpy()
#     unnormalized_image = (pixel_values * np.array(OPENAI_CLIP_STD)[:, None, None]) + np.array(OPENAI_CLIP_MEAN)[:, None, None]
#     unnormalized_image = (unnormalized_image * 255).astype(np.uint8)
#     unnormalized_image = np.moveaxis(unnormalized_image, 0, -1)
#     return unnormalized_image

# def encode_image_to_base64(image_array):
#     if image_array is None:
#         return None
#     image = Image.fromarray(image_array)
#     buffered = BytesIO()
#     image.save(buffered, format="JPEG")
#     return base64.b64encode(buffered.getvalue()).decode('utf-8')

# def format_query_for_model(text_input, model_type="owlv2"):
#     """Format query based on model requirements"""
#     # Extract objects (e.g., "detect a lion" -> "lion")
#     text = text_input.lower()
#     words = [w.strip('.,?!') for w in text.split()
#              if w not in ['count', 'detect', 'show', 'me', 'the', 'and', 'a', 'an']]
#     if model_type == "owlv2":
#         # Return just the list of queries for Owlv2, not nested list
#         queries = ["a photo of " + obj for obj in words]
#         print("Owlv2 queries:", queries)
#         return queries
#     else:  # DINO
#         # DINO query format
#         query = f"a {' '.join(words)}."
# print("DINO query:", query) | |
# return query | |
# def detect_objects(query_text): | |
# if state.current_image is None: | |
# return {"count": 0, "message": "No image provided"} | |
# image = Image.fromarray(state.current_image) | |
# draw = ImageDraw.Draw(image) | |
# if state.current_model == "owlv2": | |
# # For Owlv2, pass the text queries directly | |
# inputs = owlv2_processor(text=query_text, images=image, return_tensors="pt").to(device) | |
# with torch.no_grad(): | |
# outputs = owlv2_model(**inputs) | |
# results = owlv2_processor.post_process_object_detection( | |
# outputs=outputs, threshold=0.2, target_sizes=torch.Tensor([image.size[::-1]]) | |
# ) | |
# else: # DINO | |
# # For DINO, pass the single text query | |
# inputs = dino_processor(images=image, text=query_text, return_tensors="pt").to(device) | |
# with torch.no_grad(): | |
# outputs = dino_model(**inputs) | |
# results = dino_processor.post_process_grounded_object_detection( | |
# outputs, inputs.input_ids, box_threshold=0.1, text_threshold=0.3, | |
# target_sizes=[image.size[::-1]] | |
# ) | |
# # Draw detection boxes | |
# boxes = results[0]["boxes"] | |
# scores = results[0]["scores"] | |
# for box, score in zip(boxes, scores): | |
# box = [round(i) for i in box.tolist()] | |
# draw.rectangle(box, outline="red", width=3) | |
# draw.text((box[0], box[1]), f"Score: {score:.2f}", fill="red") | |
# state.last_prediction = np.array(image) | |
# return { | |
# "count": len(boxes), | |
# "confidence": scores.tolist(), | |
# "message": f"Detected {len(boxes)} objects" | |
# } | |
# def identify_plant(): | |
# if state.current_image is None: | |
# return {"error": "No image provided"} | |
# image = Image.fromarray(state.current_image) | |
# img_byte_arr = BytesIO() | |
# image.save(img_byte_arr, format='JPEG') | |
# img_byte_arr = img_byte_arr.getvalue() | |
# api_endpoint = f"https://my-api.plantnet.org/v2/identify/all?api-key={PLANTNET_API_KEY}" | |
# files = [('images', ('image.jpg', img_byte_arr))] | |
# data = {'organs': ['leaf']} | |
# try: | |
# response = requests.post(api_endpoint, files=files, data=data) | |
# if response.status_code == 200: | |
# result = response.json() | |
# best_match = result['results'][0] | |
# return { | |
# "scientific_name": best_match['species']['scientificName'], | |
# "common_names": best_match['species'].get('commonNames', []), | |
# "family": best_match['species']['family']['scientificName'], | |
# "genus": best_match['species']['genus']['scientificName'], | |
# "confidence": f"{best_match['score']*100:.1f}%" | |
# } | |
# else: | |
# return {"error": f"API Error: {response.status_code}"} | |
# except Exception as e: | |
# return {"error": f"Error: {str(e)}"} | |
# # Tool definitions | |
# object_detection_function = { | |
# "name": "detect_objects", | |
# "description": "Use this function to detect and count objects in images based on text queries.", | |
# "parameters": { | |
# "type": "object", | |
# "properties": { | |
# "query_text": { | |
# "type": "array", | |
# "description": "List of text queries describing objects to detect", | |
# "items": {"type": "string"} | |
# } | |
# } | |
# } | |
# } | |
# plant_identification_function = { | |
# "name": "identify_plant", | |
# "description": "Use this when asked about plant species identification or botanical classification.", | |
# "parameters": { | |
# "type": "object", | |
# "properties": {}, | |
# "required": [] | |
# } | |
# } | |
# tools = [ | |
# {"type": "function", "function": object_detection_function}, | |
# {"type": "function", "function": plant_identification_function} | |
# ] | |
# def format_tool_response(tool_response_content): | |
# data = json.loads(tool_response_content) | |
# if "error" in data: | |
# return f"Error: {data['error']}" | |
# elif "scientific_name" in data: | |
# return f"""π Plant Identification Results: | |
# πΏ Scientific Name: {data['scientific_name']} | |
# π₯ Common Names: {', '.join(data['common_names']) if data['common_names'] else 'Not available'} | |
# πͺ Family: {data['family']} | |
# π― Confidence: {data['confidence']}""" | |
# else: | |
# return f"I detected {data['count']} objects in the image." | |
# def chat(message, image, history): | |
# if image is not None: | |
# state.current_image = image | |
# if state.current_image is None: | |
# return "Please upload an image first.", None | |
# base64_image = encode_image_to_base64(state.current_image) | |
# messages = [{"role": "system", "content": system_message}] | |
# for human, assistant in history: | |
# messages.append({"role": "user", "content": human}) | |
# messages.append({"role": "assistant", "content": assistant}) | |
# # Extract objects to detect from user message | |
# # This could be enhanced with better NLP | |
# objects_to_detect = message.lower() | |
# formatted_query = format_query_for_model(objects_to_detect, state.current_model) | |
# messages.append({ | |
# "role": "user", | |
# "content": [ | |
# {"type": "text", "text": message}, | |
# {"type": "image_url", "image_url": {"url": f"data:image/jpeg;base64,{base64_image}"}} | |
# ] | |
# }) | |
# response = openai.chat.completions.create( | |
# model=MODEL, | |
# messages=messages, | |
# tools=tools, | |
# max_tokens=300 | |
# ) | |
# if response.choices[0].finish_reason == "tool_calls": | |
# message = response.choices[0].message | |
# messages.append(message) | |
# for tool_call in message.tool_calls: | |
# if tool_call.function.name == "detect_objects": | |
# results = detect_objects(formatted_query) | |
# else: | |
# results = identify_plant() | |
# tool_response = { | |
# "role": "tool", | |
# "content": json.dumps(results), | |
# "tool_call_id": tool_call.id | |
# } | |
# messages.append(tool_response) | |
# response = openai.chat.completions.create( | |
# model=MODEL, | |
# messages=messages, | |
# max_tokens=300 | |
# ) | |
# return response.choices[0].message.content, state.last_prediction | |
# def update_model(choice): | |
# print(f"Model switched to: {choice}") | |
# state.current_model = choice.lower() | |
# return f"Model switched to {choice}" | |
# # Create Gradio interface | |
# with gr.Blocks() as demo: | |
# gr.Markdown("# Object Detection and Plant Analysis System") | |
# with gr.Row(): | |
# with gr.Column(): | |
# model_choice = gr.Radio( | |
# choices=["Owlv2", "DINO"], | |
# value="Owlv2", | |
# label="Select Detection Model", | |
# interactive=True | |
# ) | |
# image_input = gr.Image(type="numpy", label="Upload Image") | |
# text_input = gr.Textbox( | |
# label="Ask about the image", | |
# placeholder="e.g., 'What objects do you see?' or 'What species is this plant?'" | |
# ) | |
# with gr.Row(): | |
# submit_btn = gr.Button("Analyze") | |
# reset_btn = gr.Button("Reset") | |
# with gr.Column(): | |
# chatbot = gr.Chatbot() | |
# # output_image = gr.Image(label="Detected Objects") | |
# output_image = gr.Image(type="numpy", label="Detected Objects") | |
# def process_interaction(message, image, history): | |
# response, pred_image = chat(message, image, history) | |
# history.append((message, response)) | |
# return "", pred_image, history | |
# def reset_interface(): | |
# state.current_image = None | |
# state.last_prediction = None | |
# return None, None, None, [] | |
# model_choice.change(fn=update_model, inputs=[model_choice], outputs=[gr.Textbox(visible=False)]) | |
# submit_btn.click( | |
# fn=process_interaction, | |
# inputs=[text_input, image_input, chatbot], | |
# outputs=[text_input, output_image, chatbot] | |
# ) | |
# reset_btn.click( | |
# fn=reset_interface, | |
# inputs=[], | |
# outputs=[image_input, output_image, text_input, chatbot] | |
# ) | |
# gr.Markdown("""## Instructions | |
# 1. Select the detection model (Owlv2 or DINO) | |
# 2. Upload an image | |
# 3. Ask specific questions about objects or plants | |
# 4. Click Analyze to get results""") | |
# demo.launch(share=True) | |
# imports
import os
import json
import base64
from io import BytesIO
from dotenv import load_dotenv
from openai import OpenAI
import gradio as gr
import numpy as np
from PIL import Image, ImageDraw
import requests
import matplotlib.pyplot as plt
from vision_agent.agent import VisionAgentCoderV2
from vision_agent.models import AgentMessage
import vision_agent.tools as T

# Initialization
load_dotenv()
os.environ['OPENAI_API_KEY'] = os.getenv('OPENAI_API_KEY', 'your-key-here')
os.environ['ANTHROPIC_API_KEY'] = os.getenv('ANTHROPIC_API_KEY', 'your-anthropic-key-here')
PLANTNET_API_KEY = os.getenv('PLANTNET_API_KEY', 'your-plantnet-key-here')
MODEL = "gpt-4o"
openai = OpenAI()

# Initialize VisionAgent
agent = VisionAgentCoderV2(verbose=False)
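# Note: `agent` is not called directly below; detection goes through vision_agent.tools.
# For reference, a minimal commented-out sketch of how VisionAgentCoderV2 can generate
# detection code from a natural-language request, following the vision_agent V2 examples
# (treat the exact generate_code()/CodeContext fields as an assumption):
#
#   code_context = agent.generate_code([
#       AgentMessage(role="user", content="Count the dogs in this image", media=["dogs.jpg"])
#   ])
#   with open("generated_detection.py", "w") as f:
#       f.write(code_context.code + "\n" + code_context.test)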
system_message = """You are an expert in object detection. When users mention: | |
1. "count [object(s)]" - Use detect_objects to count them | |
2. "detect [object(s)]" - Same as count | |
3. "show [object(s)]" - Same as count | |
Always use object detection tool when counting/detecting is mentioned.""" | |
system_message += "\nAlways be accurate. If you don't know the answer, say so."
class State:
    def __init__(self):
        self.current_image = None
        self.last_prediction = None

state = State()

def encode_image_to_base64(image_array):
    if image_array is None:
        return None
    image = Image.fromarray(image_array)
    buffered = BytesIO()
    image.save(buffered, format="JPEG")
    return base64.b64encode(buffered.getvalue()).decode('utf-8')

def save_temp_image(image_array):
    """Save the image to a temporary file for VisionAgent to process"""
    temp_path = "temp_image.jpg"
    image = Image.fromarray(image_array)
    image.save(temp_path)
    return temp_path

def detect_objects(query_text):
    if state.current_image is None:
        return {"count": 0, "message": "No image provided"}

    # Save the current image to a temporary file
    image_path = save_temp_image(state.current_image)
    try:
        # Use VisionAgent to detect objects
        image = T.load_image(image_path)
        # Clean query text to get the object name
        object_name = query_text[0].replace("a photo of ", "").strip()
        # Detect objects using CountGD
        detections = T.countgd_object_detection(object_name, image)
        # Visualize results
        result_image = T.overlay_bounding_boxes(image, detections)
        # Convert result back to numpy array for display
        state.last_prediction = np.array(result_image)
        return {
            "count": len(detections),
            "confidence": [det["score"] for det in detections],
            "message": f"Detected {len(detections)} {object_name}(s)"
        }
    except Exception as e:
        print(f"Error in detect_objects: {str(e)}")
        return {"count": 0, "message": f"Error: {str(e)}"}
def identify_plant():
    if state.current_image is None:
        return {"error": "No image provided"}
    image = Image.fromarray(state.current_image)
    img_byte_arr = BytesIO()
    image.save(img_byte_arr, format='JPEG')
    img_byte_arr = img_byte_arr.getvalue()

    api_endpoint = f"https://my-api.plantnet.org/v2/identify/all?api-key={PLANTNET_API_KEY}"
    files = [('images', ('image.jpg', img_byte_arr))]
    data = {'organs': ['leaf']}
    try:
        response = requests.post(api_endpoint, files=files, data=data)
        if response.status_code == 200:
            result = response.json()
            best_match = result['results'][0]
            return {
                "scientific_name": best_match['species']['scientificName'],
                "common_names": best_match['species'].get('commonNames', []),
                "family": best_match['species']['family']['scientificName'],
                "genus": best_match['species']['genus']['scientificName'],
                "confidence": f"{best_match['score']*100:.1f}%"
            }
        else:
            return {"error": f"API Error: {response.status_code}"}
    except Exception as e:
        return {"error": f"Error: {str(e)}"}
# Tool definitions
object_detection_function = {
    "name": "detect_objects",
    "description": "Use this function to detect and count objects in images based on text queries.",
    "parameters": {
        "type": "object",
        "properties": {
            "query_text": {
                "type": "array",
                "description": "List of text queries describing objects to detect",
                "items": {"type": "string"}
            }
        }
    }
}

plant_identification_function = {
    "name": "identify_plant",
    "description": "Use this when asked about plant species identification or botanical classification.",
    "parameters": {
        "type": "object",
        "properties": {},
        "required": []
    }
}

tools = [
    {"type": "function", "function": object_detection_function},
    {"type": "function", "function": plant_identification_function}
]
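# When the model decides to call detect_objects, tool_call.function.arguments arrives as a
# JSON string matching the schema above, e.g. (illustrative):
#   '{"query_text": ["a photo of dog", "a photo of cat"]}'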
def format_tool_response(tool_response_content):
    data = json.loads(tool_response_content)
    if "error" in data:
        return f"Error: {data['error']}"
    elif "scientific_name" in data:
        return f"""Plant Identification Results:
Scientific Name: {data['scientific_name']}
Common Names: {', '.join(data['common_names']) if data['common_names'] else 'Not available'}
Family: {data['family']}
Confidence: {data['confidence']}"""
    else:
        return f"I detected {data['count']} objects in the image."
def chat(message, image, history):
    if image is not None:
        state.current_image = image
    if state.current_image is None:
        return "Please upload an image first.", None

    base64_image = encode_image_to_base64(state.current_image)
    messages = [{"role": "system", "content": system_message}]
    for human, assistant in history:
        messages.append({"role": "user", "content": human})
        messages.append({"role": "assistant", "content": assistant})

    # Extract objects to detect from user message
    objects_to_detect = message.lower()
    # Format query for object detection
    query = ["a photo of " + objects_to_detect.replace("count", "").replace("detect", "").replace("show", "").strip()]

    messages.append({
        "role": "user",
        "content": [
            {"type": "text", "text": message},
            {"type": "image_url", "image_url": {"url": f"data:image/jpeg;base64,{base64_image}"}}
        ]
    })
    response = openai.chat.completions.create(
        model=MODEL,
        messages=messages,
        tools=tools,
        max_tokens=300
    )
    if response.choices[0].finish_reason == "tool_calls":
        # Avoid shadowing the `message` parameter with the assistant's reply
        assistant_message = response.choices[0].message
        messages.append(assistant_message)
        for tool_call in assistant_message.tool_calls:
            if tool_call.function.name == "detect_objects":
                # Prefer the query the model supplied; fall back to the locally built one
                arguments = json.loads(tool_call.function.arguments or "{}")
                results = detect_objects(arguments.get("query_text") or query)
            else:
                results = identify_plant()
            tool_response = {
                "role": "tool",
                "content": json.dumps(results),
                "tool_call_id": tool_call.id
            }
            messages.append(tool_response)
        response = openai.chat.completions.create(
            model=MODEL,
            messages=messages,
            max_tokens=300
        )
    return response.choices[0].message.content, state.last_prediction
# Create Gradio interface
with gr.Blocks() as demo:
    gr.Markdown("# Object Detection and Plant Analysis System using VisionAgent")
    with gr.Row():
        with gr.Column():
            image_input = gr.Image(type="numpy", label="Upload Image")
            text_input = gr.Textbox(
                label="Ask about the image",
                placeholder="e.g., 'Count dogs in this image' or 'What species is this plant?'"
            )
            with gr.Row():
                submit_btn = gr.Button("Analyze")
                reset_btn = gr.Button("Reset")
        with gr.Column():
            chatbot = gr.Chatbot()
            output_image = gr.Image(type="numpy", label="Detection Results")

    def process_interaction(message, image, history):
        response, pred_image = chat(message, image, history)
        history.append((message, response))
        return "", pred_image, history

    def reset_interface():
        state.current_image = None
        state.last_prediction = None
        return None, None, None, []

    submit_btn.click(
        fn=process_interaction,
        inputs=[text_input, image_input, chatbot],
        outputs=[text_input, output_image, chatbot]
    )
    reset_btn.click(
        fn=reset_interface,
        inputs=[],
        outputs=[image_input, output_image, text_input, chatbot]
    )
    gr.Markdown("""## Instructions
1. Upload an image
2. Ask specific questions about objects or plants
3. Click Analyze to get results

Examples:
- "Count the number of people in this image"
- "Detect cats and dogs"
- "What species is this plant?"
""")

demo.launch(share=True)