# imports

import os
import json
import base64
from io import BytesIO
from dotenv import load_dotenv
from openai import OpenAI
import gradio as gr
import numpy as np
from PIL import Image, ImageDraw
import requests
import torch
from transformers import (
    AutoProcessor,
    Owlv2ForObjectDetection,
    AutoModelForZeroShotObjectDetection
)
from transformers.utils.constants import OPENAI_CLIP_MEAN, OPENAI_CLIP_STD

# Initialization

load_dotenv()
os.environ['OPENAI_API_KEY'] = os.getenv('OPENAI_API_KEY', 'your-key-here')
PLANTNET_API_KEY = os.getenv('PLANTNET_API_KEY', 'your-plantnet-key-here')
MODEL = "gpt-4o"
openai = OpenAI()

# Initialize models
device = "cuda" if torch.cuda.is_available() else "cpu"

# Owlv2
owlv2_processor = AutoProcessor.from_pretrained("google/owlv2-base-patch16")
owlv2_model = Owlv2ForObjectDetection.from_pretrained("google/owlv2-base-patch16").to(device)

# DINO
dino_processor = AutoProcessor.from_pretrained("IDEA-Research/grounding-dino-base")
dino_model = AutoModelForZeroShotObjectDetection.from_pretrained("IDEA-Research/grounding-dino-base").to(device)

system_message = """You are an expert in object detection. When users mention:
1. "count [object(s)]" - Use detect_objects with the proper format for the active model
2. "detect [object(s)]" - Same as count
3. "show [object(s)]" - Same as count

For the DINO model: format queries as "a [object]." (e.g., "a frog.")
For the Owlv2 model: format queries as [["a photo of [object]", "a photo of [object2]"]]

Always use the object detection tool when counting/detecting is mentioned."""

system_message += "\nAlways be accurate. If you don't know the answer, say so."


class State:
    """Holds the current image, the last annotated prediction, and the active model."""

    def __init__(self):
        self.current_image = None
        self.last_prediction = None
        self.current_model = "owlv2"  # Default model


state = State()


def get_preprocessed_image(pixel_values):
    """Undo CLIP normalization so the processor's pixel values can be viewed as an image."""
    pixel_values = pixel_values.squeeze().numpy()
    unnormalized_image = (pixel_values * np.array(OPENAI_CLIP_STD)[:, None, None]) + np.array(OPENAI_CLIP_MEAN)[:, None, None]
    unnormalized_image = (unnormalized_image * 255).astype(np.uint8)
    unnormalized_image = np.moveaxis(unnormalized_image, 0, -1)
    return unnormalized_image


def encode_image_to_base64(image_array):
    """Encode a numpy image as base64 JPEG for the OpenAI vision API."""
    if image_array is None:
        return None
    image = Image.fromarray(image_array)
    buffered = BytesIO()
    image.save(buffered, format="JPEG")
    return base64.b64encode(buffered.getvalue()).decode('utf-8')


def format_query_for_model(text_input, model_type="owlv2"):
    """Format a query to match the selected model's expected input."""
    # Extract object words (e.g., "count frogs and horses" -> ["frogs", "horses"]);
    # note this naive filter does not singularize or handle multi-word names.
    text = text_input.lower()
    stop_words = ['count', 'detect', 'show', 'me', 'the', 'and', 'a', 'an']
    words = [w.strip('.,?!') for w in text.split() if w not in stop_words]

    if model_type == "owlv2":
        return [["a photo of " + obj for obj in words]]
    else:  # DINO
        # DINO only works with single-object queries of the form "a object."
        if not words:
            return "an object."  # guard against an empty extraction
        return f"a {words[0]}."
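
# Minimal sanity checks (illustrative sketch, not part of the app flow): trace how
# format_query_for_model rewrites a user request for each backend. As noted above,
# the word filter is naive, so plurals like "frogs" pass through unchanged.
assert format_query_for_model("count frogs and horses", "owlv2") == [["a photo of frogs", "a photo of horses"]]
assert format_query_for_model("detect a frog", "dino") == "a frog."
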
def detect_objects(query_text):
    """Run zero-shot object detection with the active model and annotate the image."""
    if state.current_image is None:
        return {"count": 0, "message": "No image provided"}

    image = Image.fromarray(state.current_image)
    draw = ImageDraw.Draw(image)

    if state.current_model == "owlv2":
        inputs = owlv2_processor(text=query_text, images=image, return_tensors="pt").to(device)
        with torch.no_grad():
            outputs = owlv2_model(**inputs)
        results = owlv2_processor.post_process_object_detection(
            outputs=outputs,
            threshold=0.2,
            target_sizes=torch.Tensor([image.size[::-1]])
        )
    else:  # DINO
        inputs = dino_processor(images=image, text=query_text, return_tensors="pt").to(device)
        with torch.no_grad():
            outputs = dino_model(**inputs)
        results = dino_processor.post_process_grounded_object_detection(
            outputs,
            inputs.input_ids,
            box_threshold=0.1,
            text_threshold=0.3,
            target_sizes=[image.size[::-1]]
        )

    # Draw detection boxes with confidence scores
    boxes = results[0]["boxes"]
    scores = results[0]["scores"]
    for box, score in zip(boxes, scores):
        box = [round(i) for i in box.tolist()]
        draw.rectangle(box, outline="red", width=3)
        draw.text((box[0], box[1]), f"Score: {score:.2f}", fill="red")

    state.last_prediction = np.array(image)
    return {
        "count": len(boxes),
        "confidence": scores.tolist(),
        "message": f"Detected {len(boxes)} objects"
    }


def identify_plant():
    """Identify the plant species in the current image via the PlantNet API."""
    if state.current_image is None:
        return {"error": "No image provided"}

    image = Image.fromarray(state.current_image)
    img_byte_arr = BytesIO()
    image.save(img_byte_arr, format='JPEG')
    img_byte_arr = img_byte_arr.getvalue()

    api_endpoint = f"https://my-api.plantnet.org/v2/identify/all?api-key={PLANTNET_API_KEY}"
    files = [('images', ('image.jpg', img_byte_arr))]
    data = {'organs': ['leaf']}

    try:
        response = requests.post(api_endpoint, files=files, data=data)
        if response.status_code == 200:
            result = response.json()
            best_match = result['results'][0]
            return {
                "scientific_name": best_match['species']['scientificName'],
                "common_names": best_match['species'].get('commonNames', []),
                "family": best_match['species']['family']['scientificName'],
                "genus": best_match['species']['genus']['scientificName'],
                "confidence": f"{best_match['score']*100:.1f}%"
            }
        else:
            return {"error": f"API Error: {response.status_code}"}
    except Exception as e:
        return {"error": f"Error: {str(e)}"}


# Tool definitions

object_detection_function = {
    "name": "detect_objects",
    "description": "Use this function to detect and count objects in images based on text queries.",
    "parameters": {
        "type": "object",
        "properties": {
            "query_text": {
                "type": "array",
                "description": "List of text queries describing objects to detect",
                "items": {"type": "string"}
            }
        },
        "required": ["query_text"]
    }
}

plant_identification_function = {
    "name": "identify_plant",
    "description": "Use this when asked about plant species identification or botanical classification.",
    "parameters": {
        "type": "object",
        "properties": {},
        "required": []
    }
}

tools = [
    {"type": "function", "function": object_detection_function},
    {"type": "function", "function": plant_identification_function}
]


def format_tool_response(tool_response_content):
    """Turn a tool's JSON response into a human-readable summary."""
    data = json.loads(tool_response_content)
    if "error" in data:
        return f"Error: {data['error']}"
    elif "scientific_name" in data:
        return f"""📋 Plant Identification Results:

🌿 Scientific Name: {data['scientific_name']}
👥 Common Names: {', '.join(data['common_names']) if data['common_names'] else 'Not available'}
👪 Family: {data['family']}
🎯 Confidence: {data['confidence']}"""
    else:
        return f"I detected {data['count']} objects in the image."
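
# Optional smoke test (a sketch; assumes a local file "test.jpg" exists, which is
# not shipped with this project). It exercises detect_objects on the Owlv2 path
# end to end, without the LLM or the Gradio UI:
#
#   state.current_image = np.array(Image.open("test.jpg").convert("RGB"))
#   state.current_model = "owlv2"
#   print(detect_objects([["a photo of a dog"]]))
#   Image.fromarray(state.last_prediction).save("annotated.jpg")
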
def chat(message, image, history):
    if image is not None:
        state.current_image = image

    if state.current_image is None:
        return "Please upload an image first.", None

    base64_image = encode_image_to_base64(state.current_image)

    messages = [{"role": "system", "content": system_message}]
    for human, assistant in history:
        messages.append({"role": "user", "content": human})
        messages.append({"role": "assistant", "content": assistant})

    # Pre-format the detection query from the raw user message.
    # This naive extraction could be enhanced with better NLP.
    objects_to_detect = message.lower()
    formatted_query = format_query_for_model(objects_to_detect, state.current_model)

    messages.append({
        "role": "user",
        "content": [
            {"type": "text", "text": message},
            {"type": "image_url", "image_url": {"url": f"data:image/jpeg;base64,{base64_image}"}}
        ]
    })

    response = openai.chat.completions.create(
        model=MODEL,
        messages=messages,
        tools=tools,
        max_tokens=300
    )

    if response.choices[0].finish_reason == "tool_calls":
        assistant_message = response.choices[0].message
        messages.append(assistant_message)
        for tool_call in assistant_message.tool_calls:
            if tool_call.function.name == "detect_objects":
                # Note: the tool call's own arguments are ignored here;
                # the locally pre-formatted query is used instead.
                results = detect_objects(formatted_query)
            else:
                results = identify_plant()
            tool_response = {
                "role": "tool",
                "content": json.dumps(results),
                "tool_call_id": tool_call.id
            }
            messages.append(tool_response)
        response = openai.chat.completions.create(
            model=MODEL,
            messages=messages,
            max_tokens=300
        )

    return response.choices[0].message.content, state.last_prediction


def update_model(choice):
    print(f"Model switched to: {choice}")
    state.current_model = choice.lower()
    return f"Model switched to {choice}"


# Create Gradio interface
with gr.Blocks() as demo:
    gr.Markdown("# Object Detection and Plant Analysis System")

    with gr.Row():
        with gr.Column():
            model_choice = gr.Radio(
                choices=["Owlv2", "DINO"],
                value="Owlv2",
                label="Select Detection Model",
                interactive=True
            )
            image_input = gr.Image(type="numpy", label="Upload Image")
            text_input = gr.Textbox(
                label="Ask about the image",
                placeholder="e.g., 'What objects do you see?' or 'What species is this plant?'"
            )
            with gr.Row():
                submit_btn = gr.Button("Analyze")
                reset_btn = gr.Button("Reset")
        with gr.Column():
            chatbot = gr.Chatbot()
            output_image = gr.Image(type="numpy", label="Detected Objects")

    def process_interaction(message, image, history):
        response, pred_image = chat(message, image, history)
        history.append((message, response))
        return "", pred_image, history

    def reset_interface():
        state.current_image = None
        state.last_prediction = None
        return None, None, "", []

    model_choice.change(
        fn=update_model,
        inputs=[model_choice],
        outputs=[gr.Textbox(visible=False)]
    )

    submit_btn.click(
        fn=process_interaction,
        inputs=[text_input, image_input, chatbot],
        outputs=[text_input, output_image, chatbot]
    )

    reset_btn.click(
        fn=reset_interface,
        inputs=[],
        outputs=[image_input, output_image, text_input, chatbot]
    )

    gr.Markdown("""## Instructions
1. Select the detection model (Owlv2 or DINO)
2. Upload an image
3. Ask specific questions about objects or plants
4. Click Analyze to get results""")

demo.launch(share=True)
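
# Design note (a sketch, not wired in): chat() currently discards the arguments
# the model supplies in its tool call and reuses the locally pre-formatted query.
# To honor the model's own arguments instead, one could parse them inside the
# tool-call loop; the "query_text" key matches the schema declared in
# object_detection_function above, and the reshaping is an assumption about how
# the model fills it in:
#
#   args = json.loads(tool_call.function.arguments)
#   queries = args.get("query_text", [])            # e.g. ["a photo of a frog"]
#   model_query = [queries] if state.current_model == "owlv2" else queries[0]
#   results = detect_objects(model_query)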