kendrickfff commited on
Commit
2101247
·
verified ·
1 Parent(s): 603c4d7

use yolo to handle multiple detection

Browse files
Files changed (1) hide show
  1. app.py +35 -111
app.py CHANGED
@@ -1,148 +1,72 @@
1
  import os
2
  import gradio as gr
3
- from transformers import DetrImageProcessor, DetrForObjectDetection
4
- from langchain_google_genai.chat_models import ChatGoogleGenerativeAI # Import Gemini
5
- from PIL import Image
6
  import torch
7
- import json
8
- import requests
9
 
10
- # Load credentials (stringified JSON) from environment variable for Gemini
11
- credentials_string = os.environ.get("GOOGLE_APPLICATION_CREDENTIALS")
12
- if not credentials_string:
13
- raise ValueError("GOOGLE_APPLICATION_CREDENTIALS is not set in the environment!")
14
 
15
- # Parse the stringified JSON back to a Python dictionary
16
- credentials = json.loads(credentials_string)
17
-
18
- # Save the credentials to a temporary JSON file (required by Google SDKs)
19
- with open("service_account.json", "w") as f:
20
- json.dump(credentials, f)
21
-
22
- # Set the GOOGLE_APPLICATION_CREDENTIALS environment variable to the temporary file
23
- os.environ["GOOGLE_APPLICATION_CREDENTIALS"] = "service_account.json"
24
-
25
- # Initialize Gemini model (chatbot)
26
- llm = ChatGoogleGenerativeAI(model='gemini-1.5-pro')
27
-
28
- # Initialize DETR model and processor for object detection
29
- processor = DetrImageProcessor.from_pretrained("facebook/detr-resnet-50")
30
- model = DetrForObjectDetection.from_pretrained("facebook/detr-resnet-50")
31
-
32
- # Load COCO class labels (from the official COCO dataset)
33
- COCO_CLASSES = [
34
- 'airplane', 'apple', 'backpack', 'banana', 'baseball hat', 'baseball glove', 'bear', 'bed', 'bench', 'bicycle',
35
- 'bird', 'boat', 'book', 'bottle', 'bowl', 'broccoli', 'bus', 'cake', 'car', 'carrot', 'cat', 'cell phone', 'chair',
36
- 'clock', 'couch', 'cow', 'cup', 'dining table', 'dog', 'donut', 'elephant', 'fire hydrant', 'fork', 'frisbee',
37
- 'giraffe', 'hair drier', 'handbag', 'horse', 'hot dog', 'keyboard', 'kite', 'knife', 'laptop', 'microwave',
38
- 'motorcycle', 'mouse', 'orange', 'oven', 'parking meter', 'person', 'pizza', 'potted plant', 'refrigerator',
39
- 'remote', 'sandwich', 'scissors', 'sheep', 'sink', 'skateboard', 'skis', 'snowboard', 'spoon', 'sports ball',
40
- 'stop sign', 'suitcase', 'surfboard', 'teddy bear', 'tennis racket', 'tie', 'toaster', 'toilet', 'toothbrush',
41
- 'traffic light', 'train', 'truck', 'tv', 'umbrella', 'vase', 'wine glass'
42
- ]
43
-
44
- # Global chat history variable
45
  chat_history = []
46
 
47
- # Function for chatting with Gemini
48
- def chat_with_gemini(message):
49
  global chat_history
50
- bot_response = llm.predict(message) # This will interact with the Gemini model
51
  chat_history.append((message, bot_response))
52
  return chat_history
53
 
54
- # Function for analyzing the uploaded image
55
  def analyze_image(image_path):
56
  global chat_history
57
  try:
58
- # Open and preprocess the image
59
  image = Image.open(image_path).convert("RGB")
60
- inputs = processor(images=image, return_tensors="pt")
61
-
62
- # Perform inference
63
- with torch.no_grad():
64
- outputs = model(**inputs)
65
-
66
- # Set a target size for post-processing
67
- target_sizes = torch.tensor([image.size[::-1]]) # (height, width)
68
- results = processor.post_process_object_detection(outputs, target_sizes=target_sizes)[0]
69
-
70
- # Collect detected objects (with no minimum confidence filter)
71
  detected_objects = []
72
- for idx, label in enumerate(results["labels"]):
73
- # Get the object label based on label index
74
- object_name = COCO_CLASSES[label.item()] # Assuming COCO_CLASSES is available
75
- score = results["scores"][idx].item() # Confidence score for this detection
76
-
77
- # Store only objects with a score higher than a threshold (e.g., 0.1)
78
- if score > 0.1:
79
- detected_objects.append(f"{object_name} (score: {score:.2f})")
80
-
 
 
 
81
  if detected_objects:
82
  bot_response = f"Objects detected: {', '.join(detected_objects)}."
83
  else:
84
  bot_response = "No objects detected."
85
-
86
  chat_history.append(("Uploaded an image for analysis", bot_response))
87
- return chat_history
88
  except Exception as e:
89
  error_msg = f"Error processing the image: {str(e)}"
90
  chat_history.append(("Error during image analysis", error_msg))
91
- return chat_history
92
 
93
- # Build the Gradio interface
94
  with gr.Blocks() as demo:
95
  gr.Markdown("# Ken Chatbot")
96
  gr.Markdown("Ask me anything or upload an image for analysis!")
97
 
98
- # Chatbot display without "User" or "Bot" labels
99
  chatbot = gr.Chatbot(elem_id="chatbot")
100
-
101
- # User input components
102
  msg = gr.Textbox(label="Type your message here...", placeholder="Enter your message...", show_label=False)
103
  send_btn = gr.Button("Send")
104
  img_upload = gr.Image(type="filepath", label="Upload an image for analysis")
 
105
 
106
- # Define interactions
107
- def handle_text_message(message):
108
- return chat_with_gemini(message)
109
-
110
- def handle_image_upload(image_path):
111
- return analyze_image(image_path)
112
-
113
- # Set up Gradio components with Enter key for sending
114
- msg.submit(handle_text_message, msg, chatbot)
115
- send_btn.click(handle_text_message, msg, chatbot)
116
  send_btn.click(lambda: "", None, msg) # Clear input field
117
- img_upload.change(handle_image_upload, img_upload, chatbot)
118
-
119
- # Custom CSS for styling without usernames
120
- gr.HTML("""
121
- <style>
122
- #chatbot .message-container {
123
- display: flex;
124
- flex-direction: column;
125
- margin-bottom: 10px;
126
- max-width: 70%;
127
- }
128
- #chatbot .message {
129
- border-radius: 15px;
130
- padding: 10px;
131
- margin: 5px 0;
132
- word-wrap: break-word;
133
- }
134
- #chatbot .message.user {
135
- background-color: #DCF8C6;
136
- margin-left: auto;
137
- text-align: right;
138
- }
139
- #chatbot .message.bot {
140
- background-color: #E1E1E1;
141
- margin-right: auto;
142
- text-align: left;
143
- }
144
- </style>
145
- """)
146
 
147
- # Launch the Gradio interface
148
  demo.launch()
 
1
  import os
2
  import gradio as gr
3
+ from ultralytics import YOLO # Menggunakan YOLOv8 untuk deteksi objek
4
+ from PIL import Image, ImageDraw
 
5
  import torch
 
 
6
 
7
+ # Load model YOLOv8 (pastikan model ini telah di-download)
8
+ model = YOLO("yolov8n.pt") # Bisa diganti dengan model yang lebih besar jika diperlukan
 
 
9
 
10
+ # Global chat history
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
11
  chat_history = []
12
 
13
+ # Fungsi untuk chatting dengan chatbot
14
+ def chat_with_bot(message):
15
  global chat_history
16
+ bot_response = f"Bot: Saya menerima pesan Anda: '{message}'" # Placeholder response
17
  chat_history.append((message, bot_response))
18
  return chat_history
19
 
20
+ # Fungsi untuk menganalisis gambar
21
  def analyze_image(image_path):
22
  global chat_history
23
  try:
24
+ # Load gambar
25
  image = Image.open(image_path).convert("RGB")
26
+
27
+ # Prediksi objek dalam gambar
28
+ results = model(image)
29
+
30
+ # Ambil hasil deteksi
 
 
 
 
 
 
31
  detected_objects = []
32
+ image_draw = image.copy()
33
+ draw = ImageDraw.Draw(image_draw)
34
+
35
+ for result in results:
36
+ for box in result.boxes.data:
37
+ x1, y1, x2, y2, score, class_id = box.tolist()
38
+ if score > 0.5: # Hanya tampilkan objek dengan confidence score > 0.5
39
+ class_name = model.names[int(class_id)]
40
+ detected_objects.append(f"{class_name} (score: {score:.2f})")
41
+ draw.rectangle([x1, y1, x2, y2], outline="red", width=3)
42
+ draw.text((x1, y1), class_name, fill="red")
43
+
44
  if detected_objects:
45
  bot_response = f"Objects detected: {', '.join(detected_objects)}."
46
  else:
47
  bot_response = "No objects detected."
48
+
49
  chat_history.append(("Uploaded an image for analysis", bot_response))
50
+ return image_draw, chat_history
51
  except Exception as e:
52
  error_msg = f"Error processing the image: {str(e)}"
53
  chat_history.append(("Error during image analysis", error_msg))
54
+ return None, chat_history
55
 
56
+ # Bangun antarmuka Gradio
57
  with gr.Blocks() as demo:
58
  gr.Markdown("# Ken Chatbot")
59
  gr.Markdown("Ask me anything or upload an image for analysis!")
60
 
 
61
  chatbot = gr.Chatbot(elem_id="chatbot")
 
 
62
  msg = gr.Textbox(label="Type your message here...", placeholder="Enter your message...", show_label=False)
63
  send_btn = gr.Button("Send")
64
  img_upload = gr.Image(type="filepath", label="Upload an image for analysis")
65
+ img_output = gr.Image(label="Detected Objects")
66
 
67
+ msg.submit(chat_with_bot, msg, chatbot)
68
+ send_btn.click(chat_with_bot, msg, chatbot)
 
 
 
 
 
 
 
 
69
  send_btn.click(lambda: "", None, msg) # Clear input field
70
+ img_upload.change(analyze_image, img_upload, [img_output, chatbot])
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
71
 
 
72
  demo.launch()