# # imports
# import os
# import json
# import base64
# from io import BytesIO
# from dotenv import load_dotenv
# from openai import OpenAI
# import gradio as gr
# import numpy as np
# from PIL import Image, ImageDraw
# import requests
# import torch
# from transformers import (
#     AutoProcessor,
#     Owlv2ForObjectDetection,
#     AutoModelForZeroShotObjectDetection
# )
# # from transformers import AutoProcessor, Owlv2ForObjectDetection
# from transformers.utils.constants import OPENAI_CLIP_MEAN, OPENAI_CLIP_STD
#
# # Initialization
# load_dotenv()
# os.environ['OPENAI_API_KEY'] = os.getenv('OPENAI_API_KEY', 'your-key-here')
# PLANTNET_API_KEY = os.getenv('PLANTNET_API_KEY', 'your-plantnet-key-here')
# MODEL = "gpt-4o"
# openai = OpenAI()
#
# # Initialize models
# device = "cuda" if torch.cuda.is_available() else "cpu"
#
# # Owlv2
# owlv2_processor = AutoProcessor.from_pretrained("google/owlv2-base-patch16")
# owlv2_model = Owlv2ForObjectDetection.from_pretrained("google/owlv2-base-patch16").to(device)
#
# # DINO
# dino_processor = AutoProcessor.from_pretrained("IDEA-Research/grounding-dino-base")
# dino_model = AutoModelForZeroShotObjectDetection.from_pretrained("IDEA-Research/grounding-dino-base").to(device)
#
# system_message = """You are an expert in object detection. When users mention:
# 1. "count [object(s)]" - Use detect_objects with proper format based on model
# 2. "detect [object(s)]" - Same as count
# 3. "show [object(s)]" - Same as count
# For DINO model: Format queries as "a [object]." (e.g., "a frog.")
# For Owlv2 model: Format as [["a photo of [object]", "a photo of [object2]"]]
# Always use object detection tool when counting/detecting is mentioned."""
# system_message += "Always be accurate. If you don't know the answer, say so."
#
# class State:
#     def __init__(self):
#         self.current_image = None
#         self.last_prediction = None
#         self.current_model = "owlv2"  # Default model
#
# state = State()
#
# def get_preprocessed_image(pixel_values):
#     pixel_values = pixel_values.squeeze().numpy()
#     unnormalized_image = (pixel_values * np.array(OPENAI_CLIP_STD)[:, None, None]) + np.array(OPENAI_CLIP_MEAN)[:, None, None]
#     unnormalized_image = (unnormalized_image * 255).astype(np.uint8)
#     unnormalized_image = np.moveaxis(unnormalized_image, 0, -1)
#     return unnormalized_image
#
# def encode_image_to_base64(image_array):
#     if image_array is None:
#         return None
#     image = Image.fromarray(image_array)
#     buffered = BytesIO()
#     image.save(buffered, format="JPEG")
#     return base64.b64encode(buffered.getvalue()).decode('utf-8')
#
# def format_query_for_model(text_input, model_type="owlv2"):
#     """Format query based on model requirements"""
#     # Extract objects (e.g., "detect a lion" -> "lion")
#     text = text_input.lower()
#     words = [w.strip('.,?!') for w in text.split()
#              if w not in ['count', 'detect', 'show', 'me', 'the', 'and', 'a', 'an']]
#     if model_type == "owlv2":
#         # Return just the list of queries for Owlv2, not nested list
#         queries = ["a photo of " + obj for obj in words]
#         print("Owlv2 queries:", queries)
#         return queries
#     else:  # DINO
#         # DINO query format
#         query = f"a {words[:]}."
#         print("DINO query:", query)
#         return query
#
# def detect_objects(query_text):
#     if state.current_image is None:
#         return {"count": 0, "message": "No image provided"}
#     image = Image.fromarray(state.current_image)
#     draw = ImageDraw.Draw(image)
#     if state.current_model == "owlv2":
#         # For Owlv2, pass the text queries directly
#         inputs = owlv2_processor(text=query_text, images=image, return_tensors="pt").to(device)
#         with torch.no_grad():
#             outputs = owlv2_model(**inputs)
#         results = owlv2_processor.post_process_object_detection(
#             outputs=outputs, threshold=0.2, target_sizes=torch.Tensor([image.size[::-1]])
#         )
#     else:  # DINO
#         # For DINO, pass the single text query
#         inputs = dino_processor(images=image, text=query_text, return_tensors="pt").to(device)
#         with torch.no_grad():
#             outputs = dino_model(**inputs)
#         results = dino_processor.post_process_grounded_object_detection(
#             outputs, inputs.input_ids, box_threshold=0.1, text_threshold=0.3,
#             target_sizes=[image.size[::-1]]
#         )
#     # Draw detection boxes
#     boxes = results[0]["boxes"]
#     scores = results[0]["scores"]
#     for box, score in zip(boxes, scores):
#         box = [round(i) for i in box.tolist()]
#         draw.rectangle(box, outline="red", width=3)
#         draw.text((box[0], box[1]), f"Score: {score:.2f}", fill="red")
#     state.last_prediction = np.array(image)
#     return {
#         "count": len(boxes),
#         "confidence": scores.tolist(),
#         "message": f"Detected {len(boxes)} objects"
#     }
#
# def identify_plant():
#     if state.current_image is None:
#         return {"error": "No image provided"}
#     image = Image.fromarray(state.current_image)
#     img_byte_arr = BytesIO()
#     image.save(img_byte_arr, format='JPEG')
#     img_byte_arr = img_byte_arr.getvalue()
#     api_endpoint = f"https://my-api.plantnet.org/v2/identify/all?api-key={PLANTNET_API_KEY}"
#     files = [('images', ('image.jpg', img_byte_arr))]
#     data = {'organs': ['leaf']}
#     try:
#         response = requests.post(api_endpoint, files=files, data=data)
#         if response.status_code == 200:
#             result = response.json()
#             best_match = result['results'][0]
#             return {
#                 "scientific_name": best_match['species']['scientificName'],
#                 "common_names": best_match['species'].get('commonNames', []),
#                 "family": best_match['species']['family']['scientificName'],
#                 "genus": best_match['species']['genus']['scientificName'],
#                 "confidence": f"{best_match['score']*100:.1f}%"
#             }
#         else:
#             return {"error": f"API Error: {response.status_code}"}
#     except Exception as e:
#         return {"error": f"Error: {str(e)}"}
#
# # Tool definitions
# object_detection_function = {
#     "name": "detect_objects",
#     "description": "Use this function to detect and count objects in images based on text queries.",
#     "parameters": {
#         "type": "object",
#         "properties": {
#             "query_text": {
#                 "type": "array",
#                 "description": "List of text queries describing objects to detect",
#                 "items": {"type": "string"}
#             }
#         }
#     }
# }
#
# plant_identification_function = {
#     "name": "identify_plant",
#     "description": "Use this when asked about plant species identification or botanical classification.",
#     "parameters": {
#         "type": "object",
#         "properties": {},
#         "required": []
#     }
# }
#
# tools = [
#     {"type": "function", "function": object_detection_function},
#     {"type": "function", "function": plant_identification_function}
# ]
#
# def format_tool_response(tool_response_content):
#     data = json.loads(tool_response_content)
#     if "error" in data:
#         return f"Error: {data['error']}"
#     elif "scientific_name" in data:
#         return f"""📋 Plant Identification Results:
# 🌿 Scientific Name: {data['scientific_name']}
# 👥 Common Names: {', '.join(data['common_names']) if data['common_names'] else 'Not available'}
# 👪 Family: {data['family']}
# 🎯 Confidence: {data['confidence']}"""
#     else:
#         return f"I detected {data['count']} objects in the image."
#
# def chat(message, image, history):
#     if image is not None:
#         state.current_image = image
#     if state.current_image is None:
#         return "Please upload an image first.", None
#     base64_image = encode_image_to_base64(state.current_image)
#     messages = [{"role": "system", "content": system_message}]
#     for human, assistant in history:
#         messages.append({"role": "user", "content": human})
#         messages.append({"role": "assistant", "content": assistant})
#     # Extract objects to detect from user message
#     # This could be enhanced with better NLP
#     objects_to_detect = message.lower()
#     formatted_query = format_query_for_model(objects_to_detect, state.current_model)
#     messages.append({
#         "role": "user",
#         "content": [
#             {"type": "text", "text": message},
#             {"type": "image_url", "image_url": {"url": f"data:image/jpeg;base64,{base64_image}"}}
#         ]
#     })
#     response = openai.chat.completions.create(
#         model=MODEL,
#         messages=messages,
#         tools=tools,
#         max_tokens=300
#     )
#     if response.choices[0].finish_reason == "tool_calls":
#         message = response.choices[0].message
#         messages.append(message)
#         for tool_call in message.tool_calls:
#             if tool_call.function.name == "detect_objects":
#                 results = detect_objects(formatted_query)
#             else:
#                 results = identify_plant()
#             tool_response = {
#                 "role": "tool",
#                 "content": json.dumps(results),
#                 "tool_call_id": tool_call.id
#             }
#             messages.append(tool_response)
#         response = openai.chat.completions.create(
#             model=MODEL,
#             messages=messages,
#             max_tokens=300
#         )
#     return response.choices[0].message.content, state.last_prediction
#
# def update_model(choice):
#     print(f"Model switched to: {choice}")
#     state.current_model = choice.lower()
#     return f"Model switched to {choice}"
#
# # Create Gradio interface
# with gr.Blocks() as demo:
#     gr.Markdown("# Object Detection and Plant Analysis System")
#     with gr.Row():
#         with gr.Column():
#             model_choice = gr.Radio(
#                 choices=["Owlv2", "DINO"],
#                 value="Owlv2",
#                 label="Select Detection Model",
#                 interactive=True
#             )
#             image_input = gr.Image(type="numpy", label="Upload Image")
#             text_input = gr.Textbox(
#                 label="Ask about the image",
#                 placeholder="e.g., 'What objects do you see?' or 'What species is this plant?'"
#             )
#             with gr.Row():
#                 submit_btn = gr.Button("Analyze")
#                 reset_btn = gr.Button("Reset")
#         with gr.Column():
#             chatbot = gr.Chatbot()
#             # output_image = gr.Image(label="Detected Objects")
#             output_image = gr.Image(type="numpy", label="Detected Objects")
#
#     def process_interaction(message, image, history):
#         response, pred_image = chat(message, image, history)
#         history.append((message, response))
#         return "", pred_image, history
#
#     def reset_interface():
#         state.current_image = None
#         state.last_prediction = None
#         return None, None, None, []
#
#     model_choice.change(fn=update_model, inputs=[model_choice], outputs=[gr.Textbox(visible=False)])
#     submit_btn.click(
#         fn=process_interaction,
#         inputs=[text_input, image_input, chatbot],
#         outputs=[text_input, output_image, chatbot]
#     )
#     reset_btn.click(
#         fn=reset_interface,
#         inputs=[],
#         outputs=[image_input, output_image, text_input, chatbot]
#     )
#     gr.Markdown("""## Instructions
#     1. Select the detection model (Owlv2 or DINO)
#     2. Upload an image
#     3. Ask specific questions about objects or plants
#     4. Click Analyze to get results""")
#
# demo.launch(share=True)
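
# The commented-out block above is the earlier transformers-based version (Owlv2 / Grounding DINO).
# The active implementation below replaces it with VisionAgent for detection.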
# imports
import os
import json
import base64
from io import BytesIO
from dotenv import load_dotenv
from openai import OpenAI
import gradio as gr
import numpy as np
from PIL import Image, ImageDraw
import requests
import matplotlib.pyplot as plt
from vision_agent.agent import VisionAgentCoderV2
from vision_agent.models import AgentMessage
import vision_agent.tools as T

# Initialization
load_dotenv()
os.environ['OPENAI_API_KEY'] = os.getenv('OPENAI_API_KEY', 'your-key-here')
os.environ['ANTHROPIC_API_KEY'] = os.getenv('ANTHROPIC_API_KEY', 'your-anthropic-key-here')
PLANTNET_API_KEY = os.getenv('PLANTNET_API_KEY', 'your-plantnet-key-here')
MODEL = "gpt-4o"
openai = OpenAI()

# Initialize VisionAgent
agent = VisionAgentCoderV2(verbose=False)

system_message = """You are an expert in object detection. When users mention:
1. "count [object(s)]" - Use detect_objects to count them
2. "detect [object(s)]" - Same as count
3. "show [object(s)]" - Same as count
Always use object detection tool when counting/detecting is mentioned."""
system_message += " Always be accurate. If you don't know the answer, say so."

class State:
    def __init__(self):
        self.current_image = None
        self.last_prediction = None

state = State()

def encode_image_to_base64(image_array):
    if image_array is None:
        return None
    image = Image.fromarray(image_array)
    buffered = BytesIO()
    image.save(buffered, format="JPEG")
    return base64.b64encode(buffered.getvalue()).decode('utf-8')

def save_temp_image(image_array):
    """Save the image to a temporary file for VisionAgent to process"""
    temp_path = "temp_image.jpg"
    image = Image.fromarray(image_array)
    image.save(temp_path)
    return temp_path

def detect_objects(query_text):
    if state.current_image is None:
        return {"count": 0, "message": "No image provided"}

    # Save the current image to a temporary file
    image_path = save_temp_image(state.current_image)

    try:
        # Clean query text to get the object name
        object_name = query_text[0].replace("a photo of ", "").strip()

        # Let VisionAgent handle the detection with its agent-based approach
        # Create agent message for object detection
        agent_message = [
            AgentMessage(
                role="user",
                content=f"Count the number of {object_name} in this image. Only show detections with high confidence (>0.75).",
                media=[image_path]
            )
        ]
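        # generate_code() asks VisionAgent to plan and write detection code for the request
        # above; that generated code is not executed here -- the steps below call the
        # vision_agent.tools helpers directly instead.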
        # Generate code using VisionAgent
        code_context = agent.generate_code(agent_message)

        # Load the image for visualization
        image = T.load_image(image_path)

        # Use multiple models for detection and get high confidence results
        # First try the specialized detector
        detections = T.countgd_object_detection(object_name, image, conf_threshold=0.75)

        # If no high-confidence detections, try the more general object detector
        if not detections:
            # Try a different model with the same high threshold
            try:
                detections = T.grounding_dino_detection(object_name, image, box_threshold=0.75)
            except Exception:
                pass

        # Only keep high confidence detections
        high_conf_detections = [det for det in detections if det.get("score", 0) > 0.75]

        # Visualize only high confidence results with clear labeling
        result_image = T.overlay_bounding_boxes(
            image,
            high_conf_detections,
            labels=[f"{object_name}: {det['score']:.2f}" for det in high_conf_detections]
        )

        # Convert result back to numpy array for display
        state.last_prediction = np.array(result_image)

        return {
            "count": len(high_conf_detections),
            "confidence": [det["score"] for det in high_conf_detections],
            "message": f"Detected {len(high_conf_detections)} {object_name}(s) with high confidence (>0.75)"
        }
    except Exception as e:
        print(f"Error in detect_objects: {str(e)}")
        return {"count": 0, "message": f"Error: {str(e)}"}

def identify_plant():
    if state.current_image is None:
        return {"error": "No image provided"}

    image = Image.fromarray(state.current_image)
    img_byte_arr = BytesIO()
    image.save(img_byte_arr, format='JPEG')
    img_byte_arr = img_byte_arr.getvalue()

    api_endpoint = f"https://my-api.plantnet.org/v2/identify/all?api-key={PLANTNET_API_KEY}"
    files = [('images', ('image.jpg', img_byte_arr))]
    data = {'organs': ['leaf']}

    try:
        response = requests.post(api_endpoint, files=files, data=data)
        if response.status_code == 200:
            result = response.json()
            best_match = result['results'][0]
            return {
                "scientific_name": best_match['species']['scientificName'],
                "common_names": best_match['species'].get('commonNames', []),
                "family": best_match['species']['family']['scientificName'],
                "genus": best_match['species']['genus']['scientificName'],
                "confidence": f"{best_match['score']*100:.1f}%"
            }
        else:
            return {"error": f"API Error: {response.status_code}"}
    except Exception as e:
        return {"error": f"Error: {str(e)}"}

# Tool definitions
object_detection_function = {
    "name": "detect_objects",
    "description": "Use this function to detect and count objects in images based on text queries.",
    "parameters": {
        "type": "object",
        "properties": {
            "query_text": {
                "type": "array",
                "description": "List of text queries describing objects to detect",
                "items": {"type": "string"}
            }
        }
    }
}

plant_identification_function = {
    "name": "identify_plant",
    "description": "Use this when asked about plant species identification or botanical classification.",
    "parameters": {
        "type": "object",
        "properties": {},
        "required": []
    }
}

tools = [
    {"type": "function", "function": object_detection_function},
    {"type": "function", "function": plant_identification_function}
]

def format_tool_response(tool_response_content):
    data = json.loads(tool_response_content)
    if "error" in data:
        return f"Error: {data['error']}"
    elif "scientific_name" in data:
        return f"""📋 Plant Identification Results:
🌿 Scientific Name: {data['scientific_name']}
👥 Common Names: {', '.join(data['common_names']) if data['common_names'] else 'Not available'}
👪 Family: {data['family']}
🎯 Confidence: {data['confidence']}"""
    else:
        return f"I detected {data['count']} objects in the image."
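# Note: format_tool_response() is kept for reference but is not called below; the second
# chat.completions call lets the model summarize the raw tool output itself.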
return f"I detected {data['count']} objects in the image." def chat(message, image, history): if image is not None: state.current_image = image if state.current_image is None: return "Please upload an image first.", None base64_image = encode_image_to_base64(state.current_image) messages = [{"role": "system", "content": system_message}] for human, assistant in history: messages.append({"role": "user", "content": human}) messages.append({"role": "assistant", "content": assistant}) # Extract objects to detect from user message objects_to_detect = message.lower() # Format query for object detection - keep it simple and direct cleaned_query = objects_to_detect.replace("count", "").replace("detect", "").replace("show", "").strip() query = ["a photo of " + cleaned_query] messages.append({ "role": "user", "content": [ {"type": "text", "text": message}, {"type": "image_url", "image_url": {"url": f"data:image/jpeg;base64,{base64_image}"}} ] }) response = openai.chat.completions.create( model=MODEL, messages=messages, tools=tools, max_tokens=300 ) if response.choices[0].finish_reason == "tool_calls": message = response.choices[0].message messages.append(message) for tool_call in message.tool_calls: if tool_call.function.name == "detect_objects": results = detect_objects(query) else: results = identify_plant() tool_response = { "role": "tool", "content": json.dumps(results), "tool_call_id": tool_call.id } messages.append(tool_response) response = openai.chat.completions.create( model=MODEL, messages=messages, max_tokens=300 ) return response.choices[0].message.content, state.last_prediction # Create Gradio interface with gr.Blocks() as demo: gr.Markdown("# Object Detection and Plant Analysis System using VisionAgent") with gr.Row(): with gr.Column(): image_input = gr.Image(type="numpy", label="Upload Image") text_input = gr.Textbox( label="Ask about the image", placeholder="e.g., 'Count dogs in this image' or 'What species is this plant?'" ) with gr.Row(): submit_btn = gr.Button("Analyze") reset_btn = gr.Button("Reset") with gr.Column(): chatbot = gr.Chatbot() output_image = gr.Image(type="numpy", label="Detection Results") def process_interaction(message, image, history): response, pred_image = chat(message, image, history) history.append((message, response)) return "", pred_image, history def reset_interface(): state.current_image = None state.last_prediction = None return None, None, None, [] submit_btn.click( fn=process_interaction, inputs=[text_input, image_input, chatbot], outputs=[text_input, output_image, chatbot] ) reset_btn.click( fn=reset_interface, inputs=[], outputs=[image_input, output_image, text_input, chatbot] ) gr.Markdown("""## Instructions 1. Upload an image 2. Ask specific questions about objects or plants 3. Click Analyze to get results Examples: - "Count the number of people in this image" - "Detect cats and dogs" - "What species is this plant?" """) demo.launch(share=True)