# imports
import os
import json
import base64
from io import BytesIO
from dotenv import load_dotenv
from openai import OpenAI
import gradio as gr
import numpy as np
from PIL import Image, ImageDraw
import requests
from vision_agent.agent import VisionAgentCoderV2
from vision_agent.models import AgentMessage
import vision_agent.tools as T
# Initialization
load_dotenv()
os.environ['OPENAI_API_KEY'] = os.getenv('OPENAI_API_KEY', 'your-key-here')
os.environ['ANTHROPIC_API_KEY'] = os.getenv('ANTHROPIC_API_KEY', 'your-anthropic-key-here')
PLANTNET_API_KEY = os.getenv('PLANTNET_API_KEY', 'your-plantnet-key-here')
MODEL = "gpt-4o"
openai = OpenAI()
# Initialize VisionAgent
agent = VisionAgentCoderV2(verbose=False)
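# Detection below relies on the helpers in vision_agent.tools (imported as T);
# the coder agent instantiated here is not invoked in this flow.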
system_message = """You are an expert in object detection. When users mention:
1. "count [object(s)]" - Use detect_objects to count them
2. "detect [object(s)]" - Same as count
3. "show [object(s)]" - Same as count
Always use object detection tool when counting/detecting is mentioned."""
system_message += " Always be accurate. If you don't know the answer, say so."
class State:
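    """Holds the most recently uploaded image and the latest annotated prediction across chat turns."""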
def __init__(self):
self.current_image = None
self.last_prediction = None
state = State()
def encode_image_to_base64(image_array):
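    """Encode a NumPy image array as a base64 JPEG string for the OpenAI vision input."""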
if image_array is None:
return None
image = Image.fromarray(image_array)
buffered = BytesIO()
image.save(buffered, format="JPEG")
return base64.b64encode(buffered.getvalue()).decode('utf-8')
def save_temp_image(image_array):
"""Save the image to a temporary file for VisionAgent to process"""
temp_path = "temp_image.jpg"
image = Image.fromarray(image_array)
image.save(temp_path)
return temp_path
def detect_objects(query_text):
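    """Detect and count objects matching query_text (a list of text prompts) in the current
    image using vision_agent's CountGD detector. Stores the annotated image in
    state.last_prediction and returns a dict with the count, per-box confidences, and a message."""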
if state.current_image is None:
return {"count": 0, "message": "No image provided"}
# Save the current image to a temporary file
image_path = save_temp_image(state.current_image)
try:
# Use VisionAgent to detect objects
image = T.load_image(image_path)
# Clean query text to get the object name
object_name = query_text[0].replace("a photo of ", "").strip()
# Detect objects using CountGD
detections = T.countgd_object_detection(object_name, image)
# Visualize results
result_image = T.overlay_bounding_boxes(image, detections)
# Convert result back to numpy array for display
state.last_prediction = np.array(result_image)
return {
"count": len(detections),
"confidence": [det["score"] for det in detections],
"message": f"Detected {len(detections)} {object_name}(s)"
}
except Exception as e:
print(f"Error in detect_objects: {str(e)}")
return {"count": 0, "message": f"Error: {str(e)}"}
def identify_plant():
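    """Send the current image to the PlantNet identification API and return the best match
    (scientific name, common names, family, genus, confidence) or an error dict."""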
if state.current_image is None:
return {"error": "No image provided"}
image = Image.fromarray(state.current_image)
img_byte_arr = BytesIO()
image.save(img_byte_arr, format='JPEG')
img_byte_arr = img_byte_arr.getvalue()
api_endpoint = f"https://my-api.plantnet.org/v2/identify/all?api-key={PLANTNET_API_KEY}"
files = [('images', ('image.jpg', img_byte_arr))]
data = {'organs': ['leaf']}
try:
response = requests.post(api_endpoint, files=files, data=data)
if response.status_code == 200:
result = response.json()
best_match = result['results'][0]
return {
"scientific_name": best_match['species']['scientificName'],
"common_names": best_match['species'].get('commonNames', []),
"family": best_match['species']['family']['scientificName'],
"genus": best_match['species']['genus']['scientificName'],
"confidence": f"{best_match['score']*100:.1f}%"
}
else:
return {"error": f"API Error: {response.status_code}"}
except Exception as e:
return {"error": f"Error: {str(e)}"}
# Tool definitions
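# JSON function schemas passed to the OpenAI chat API; the model picks one and chat() executes it locally.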
object_detection_function = {
"name": "detect_objects",
"description": "Use this function to detect and count objects in images based on text queries.",
"parameters": {
"type": "object",
"properties": {
"query_text": {
"type": "array",
"description": "List of text queries describing objects to detect",
"items": {"type": "string"}
}
}
}
}
plant_identification_function = {
"name": "identify_plant",
"description": "Use this when asked about plant species identification or botanical classification.",
"parameters": {
"type": "object",
"properties": {},
"required": []
}
}
tools = [
{"type": "function", "function": object_detection_function},
{"type": "function", "function": plant_identification_function}
]
def format_tool_response(tool_response_content):
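    """Turn a JSON tool result into a human-readable summary string."""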
data = json.loads(tool_response_content)
if "error" in data:
return f"Error: {data['error']}"
elif "scientific_name" in data:
return f"""πŸ“‹ Plant Identification Results:
🌿 Scientific Name: {data['scientific_name']}
πŸ‘₯ Common Names: {', '.join(data['common_names']) if data['common_names'] else 'Not available'}
πŸ‘ͺ Family: {data['family']}
🎯 Confidence: {data['confidence']}"""
else:
return f"I detected {data['count']} objects in the image."
def chat(message, image, history):
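    """Run one conversational turn: send the message and image to GPT-4o with the tool schemas,
    execute any requested tool calls locally, and return the final assistant text together with
    the latest annotated image."""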
if image is not None:
state.current_image = image
if state.current_image is None:
return "Please upload an image first.", None
base64_image = encode_image_to_base64(state.current_image)
messages = [{"role": "system", "content": system_message}]
for human, assistant in history:
messages.append({"role": "user", "content": human})
messages.append({"role": "assistant", "content": assistant})
# Extract objects to detect from user message
objects_to_detect = message.lower()
# Format query for object detection
query = ["a photo of " + objects_to_detect.replace("count", "").replace("detect", "").replace("show", "").strip()]
messages.append({
"role": "user",
"content": [
{"type": "text", "text": message},
{"type": "image_url", "image_url": {"url": f"data:image/jpeg;base64,{base64_image}"}}
]
})
response = openai.chat.completions.create(
model=MODEL,
messages=messages,
tools=tools,
max_tokens=300
)
if response.choices[0].finish_reason == "tool_calls":
message = response.choices[0].message
messages.append(message)
for tool_call in message.tool_calls:
            if tool_call.function.name == "detect_objects":
                # Use the query_text the model supplied in the tool call, falling back to the heuristic query
                tool_args = json.loads(tool_call.function.arguments or "{}")
                results = detect_objects(tool_args.get("query_text") or query)
else:
results = identify_plant()
tool_response = {
"role": "tool",
"content": json.dumps(results),
"tool_call_id": tool_call.id
}
messages.append(tool_response)
response = openai.chat.completions.create(
model=MODEL,
messages=messages,
max_tokens=300
)
return response.choices[0].message.content, state.last_prediction
# Create Gradio interface
with gr.Blocks() as demo:
gr.Markdown("# Object Detection and Plant Analysis System using VisionAgent")
with gr.Row():
with gr.Column():
image_input = gr.Image(type="numpy", label="Upload Image")
text_input = gr.Textbox(
label="Ask about the image",
placeholder="e.g., 'Count dogs in this image' or 'What species is this plant?'"
)
with gr.Row():
submit_btn = gr.Button("Analyze")
reset_btn = gr.Button("Reset")
with gr.Column():
chatbot = gr.Chatbot()
output_image = gr.Image(type="numpy", label="Detection Results")
def process_interaction(message, image, history):
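        """Run one chat turn, append it to the Chatbot history, and clear the textbox."""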
response, pred_image = chat(message, image, history)
history.append((message, response))
return "", pred_image, history
def reset_interface():
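        """Clear the stored image and prediction and reset all interface components."""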
state.current_image = None
state.last_prediction = None
return None, None, None, []
submit_btn.click(
fn=process_interaction,
inputs=[text_input, image_input, chatbot],
outputs=[text_input, output_image, chatbot]
)
reset_btn.click(
fn=reset_interface,
inputs=[],
outputs=[image_input, output_image, text_input, chatbot]
)
gr.Markdown("""## Instructions
1. Upload an image
2. Ask specific questions about objects or plants
3. Click Analyze to get results
Examples:
- "Count the number of people in this image"
- "Detect cats and dogs"
- "What species is this plant?"
""")
demo.launch(share=True)