muhammadsalmanalfaridzi committed on
Commit
7d539c2
·
verified ·
1 Parent(s): 6f61b92

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +167 -144
app.py CHANGED
@@ -3,10 +3,9 @@ from dotenv import load_dotenv
3
  from roboflow import Roboflow
4
  import tempfile
5
  import os
6
- import requests
7
  import cv2
8
  import numpy as np
9
- import subprocess
10
 
11
  # ========== Konfigurasi ==========
12
  load_dotenv()
@@ -18,10 +17,9 @@ project_name = os.getenv("ROBOFLOW_PROJECT")
18
  model_version = int(os.getenv("ROBOFLOW_MODEL_VERSION"))
19
 
20
  # CountGD Config
21
- # Prompt yang digunakan untuk mendeteksi objek kompetitor
22
- COUNTGD_PROMPT = "beverage . bottle . cans . mixed box"
23
 
24
- # Inisialisasi Model YOLO dari Roboflow
25
  rf = Roboflow(api_key=rf_api_key)
26
  project = rf.workspace(workspace).project(project_name)
27
  yolo_model = project.version(model_version).model
@@ -33,10 +31,10 @@ def detect_combined(image):
33
  temp_path = temp_file.name
34
 
35
  try:
36
- # ========== [1] YOLO: Deteksi Produk Nestlé (Per Class) ==========
37
- yolo_pred = yolo_model.predict(temp_path, confidence=50, overlap=80).json()
38
-
39
- # Hitung per class Nestlé dan simpan bounding box (format: (x_center, y_center, width, height))
40
  nestle_class_count = {}
41
  nestle_boxes = []
42
  for pred in yolo_pred['predictions']:
@@ -46,32 +44,28 @@ def detect_combined(image):
46
 
47
  total_nestle = sum(nestle_class_count.values())
48
 
49
- # ========== [2] COUNTGD: Deteksi Kompetitor ==========
50
- # Mengirimkan request ke endpoint CountGD sesuai dokumentasi:
51
- # https://va.landing.ai/demo/api/Countgd%20Counting
52
- countgd_url = "https://api.landing.ai/v1/tools/text-to-object-detection"
53
- with open(temp_path, "rb") as image_file:
54
- files = {"image": image_file}
55
- data = {
56
- "prompts": [COUNTGD_PROMPT],
57
- "model": "countgd"
58
- }
59
- response = requests.post(countgd_url, files=files, data=data)
60
- # Asumsikan respons JSON mengandung key "predictions" dengan daftar objek
61
- countgd_pred = response.json()
62
-
63
  competitor_class_count = {}
64
  competitor_boxes = []
65
- for obj in countgd_pred.get("predictions", []):
66
- countgd_box = obj["bbox"] # Format: [x1, y1, x2, y2]
67
- # Filter objek yang sudah terdeteksi oleh YOLO (menghindari duplikasi deteksi)
68
- if not is_overlap(countgd_box, nestle_boxes):
69
- class_name = obj["class"].strip().lower()
70
  competitor_class_count[class_name] = competitor_class_count.get(class_name, 0) + 1
71
  competitor_boxes.append({
72
  "class": class_name,
73
- "box": countgd_box,
74
- "confidence": obj["score"]
75
  })
76
 
77
  total_competitor = sum(competitor_class_count.values())
@@ -80,31 +74,30 @@ def detect_combined(image):
80
  result_text = "Product Nestle\n\n"
81
  for class_name, count in nestle_class_count.items():
82
  result_text += f"{class_name}: {count}\n"
83
- result_text += f"\nTotal Products Nestle: {total_nestle}\n\n"
84
 
 
85
  if competitor_class_count:
86
- result_text += f"Total Unclassified Products: {total_competitor}\n"
 
87
  else:
88
- result_text += "No Unclassified Products detected\n"
 
89
 
90
  # ========== [4] Visualisasi ==========
91
- img = cv2.imread(temp_path)
92
-
93
- # Tandai deteksi produk Nestlé (Hijau)
94
  for pred in yolo_pred['predictions']:
95
  x, y, w, h = pred['x'], pred['y'], pred['width'], pred['height']
96
- cv2.rectangle(img, (int(x - w/2), int(y - h/2)), (int(x + w/2), int(y + h/2)), (0, 255, 0), 2)
97
- cv2.putText(img, pred['class'], (int(x - w/2), int(y - h/2 - 10)),
98
- cv2.FONT_HERSHEY_SIMPLEX, 1.0, (0, 255, 0), 3)
99
 
100
- # Tandai deteksi kompetitor (Merah), dengan pengecekan untuk merubah nama kelas menjadi 'unclassified'
101
  for comp in competitor_boxes:
102
- x1, y1, x2, y2 = comp['box']
103
- unclassified_classes = ["beverage", "cans", "bottle", "mixed box"]
104
- display_name = "unclassified" if any(c in comp['class'].lower() for c in unclassified_classes) else comp['class']
105
- cv2.rectangle(img, (int(x1), int(y1)), (int(x2), int(y2)), (0, 0, 255), 2)
106
- cv2.putText(img, f"{display_name} {comp['confidence']:.2f}",
107
- (int(x1), int(y1 - 10)), cv2.FONT_HERSHEY_SIMPLEX, 1.0, (0, 0, 255), 3)
108
 
109
  output_path = "/tmp/combined_output.jpg"
110
  cv2.imwrite(output_path, img)
@@ -116,121 +109,151 @@ def detect_combined(image):
116
  finally:
117
  os.remove(temp_path)
118
 
119
- def is_overlap(box1, boxes2, threshold=0.3):
120
- # Fungsi untuk mendeteksi overlap antara bounding box
121
- x1_min, y1_min, x1_max, y1_max = box1
122
- for b2 in boxes2:
123
- x2, y2, w2, h2 = b2
124
- x2_min = x2 - w2/2
125
- x2_max = x2 + w2/2
126
- y2_min = y2 - h2/2
127
- y2_max = y2 + h2/2
128
-
129
- dx = min(x1_max, x2_max) - max(x1_min, x2_min)
130
- dy = min(y1_max, y2_max) - max(y1_min, y2_min)
131
- if (dx >= 0) and (dy >= 0):
132
- area_overlap = dx * dy
133
- area_box1 = (x1_max - x1_min) * (y1_max - y1_min)
134
- if area_overlap / area_box1 > threshold:
135
- return True
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
136
  return False
137
 
138
- # ========== Fungsi Deteksi Video (tetap menggunakan YOLO) ==========
139
- def convert_video_to_mp4(input_path, output_path):
140
- try:
141
- subprocess.run(['ffmpeg', '-i', input_path, '-vcodec', 'libx264', '-acodec', 'aac', output_path], check=True)
142
- return output_path
143
- except subprocess.CalledProcessError as e:
144
- return None, f"Error converting video: {e}"
145
-
146
  def detect_objects_in_video(video_path):
147
- temp_output_path = "/tmp/output_video.mp4"
148
- temp_frames_dir = tempfile.mkdtemp()
149
- frame_count = 0
150
- previous_detections = {}
151
-
152
  try:
153
- if not video_path.endswith(".mp4"):
154
- video_path, err = convert_video_to_mp4(video_path, temp_output_path)
155
- if not video_path:
156
- return None, f"Video conversion error: {err}"
157
-
158
- video = cv2.VideoCapture(video_path)
159
- frame_rate = int(video.get(cv2.CAP_PROP_FPS))
160
- frame_width = int(video.get(cv2.CAP_PROP_FRAME_WIDTH))
161
- frame_height = int(video.get(cv2.CAP_PROP_FRAME_HEIGHT))
162
- frame_size = (frame_width, frame_height)
163
-
164
  fourcc = cv2.VideoWriter_fourcc(*'mp4v')
165
- output_video = cv2.VideoWriter(temp_output_path, fourcc, frame_rate, frame_size)
166
-
167
- while True:
168
- ret, frame = video.read()
 
 
 
169
  if not ret:
170
  break
171
-
172
- frame_path = os.path.join(temp_frames_dir, f"frame_{frame_count}.jpg")
173
  cv2.imwrite(frame_path, frame)
174
-
175
- predictions = yolo_model.predict(frame_path, confidence=50, overlap=80).json()
176
-
177
- current_detections = {}
178
- for prediction in predictions['predictions']:
179
- class_name = prediction['class']
180
- x, y, w, h = prediction['x'], prediction['y'], prediction['width'], prediction['height']
181
- object_id = f"{class_name}_{x}_{y}_{w}_{h}"
182
- if object_id not in current_detections:
183
- current_detections[object_id] = class_name
184
-
185
- cv2.rectangle(frame, (int(x - w/2), int(y - h/2)), (int(x + w/2), int(y + h/2)), (0, 255, 0), 2)
186
- cv2.putText(frame, class_name, (int(x - w/2), int(y - h/2 - 10)),
187
- cv2.FONT_HERSHEY_SIMPLEX, 0.5, (0, 255, 0), 2)
188
-
189
- object_counts = {}
190
- for detection_id in current_detections.keys():
191
- class_name = current_detections[detection_id]
192
- object_counts[class_name] = object_counts.get(class_name, 0) + 1
193
-
194
- count_text = ""
195
- total_product_count = 0
196
- for class_name, count in object_counts.items():
197
- count_text += f"{class_name}: {count}\n"
198
- total_product_count += count
199
- count_text += f"\nTotal Product: {total_product_count}"
200
-
201
- y_offset = 20
202
- for line in count_text.split("\n"):
203
- cv2.putText(frame, line, (10, y_offset), cv2.FONT_HERSHEY_SIMPLEX, 0.6, (255, 255, 255), 2)
204
- y_offset += 30
205
-
206
- output_video.write(frame)
207
- frame_count += 1
208
- previous_detections = current_detections
209
-
210
- video.release()
211
- output_video.release()
212
-
213
- return temp_output_path
214
-
215
  except Exception as e:
216
- return None, f"An error occurred: {e}"
 
 
 
 
217
 
218
  # ========== Gradio Interface ==========
219
  with gr.Blocks(theme=gr.themes.Base(primary_hue="teal", secondary_hue="teal", neutral_hue="slate")) as iface:
220
- gr.Markdown("""<div style="text-align: center;"><h1>NESTLE - STOCK COUNTING</h1></div>""")
221
-
 
 
 
222
  with gr.Row():
223
  with gr.Column():
224
  input_image = gr.Image(type="pil", label="Input Image")
225
  detect_image_button = gr.Button("Detect Image")
226
- output_image = gr.Image(label="Detect Object")
227
- output_text = gr.Textbox(label="Counting Object")
228
- detect_image_button.click(fn=detect_combined, inputs=input_image, outputs=[output_image, output_text])
229
-
230
  with gr.Column():
231
  input_video = gr.Video(label="Input Video")
232
  detect_video_button = gr.Button("Detect Video")
233
- output_video = gr.Video(label="Output Video")
234
- detect_video_button.click(fn=detect_objects_in_video, inputs=input_video, outputs=[output_video])
 
 
 
 
 
 
 
 
 
 
 
 
 
 
235
 
236
- iface.launch()
 
3
  from roboflow import Roboflow
4
  import tempfile
5
  import os
 
6
  import cv2
7
  import numpy as np
8
+ import vision_agent.tools as T
9
 
10
  # ========== Konfigurasi ==========
11
  load_dotenv()
 
17
  model_version = int(os.getenv("ROBOFLOW_MODEL_VERSION"))
18
 
19
  # CountGD Config
20
+ COUNTGD_PROMPT = "cans . bottle" # Customize sesuai kebutuhan
 
21
 
22
+ # Inisialisasi Model
23
  rf = Roboflow(api_key=rf_api_key)
24
  project = rf.workspace(workspace).project(project_name)
25
  yolo_model = project.version(model_version).model
 
31
  temp_path = temp_file.name
32
 
33
  try:
34
+ # ========== [1] YOLO: Deteksi Produk Nestlé ==========
35
+ yolo_pred = yolo_model.predict(temp_path, confidence=60, overlap=80).json()
36
+
37
+ # Hitung per class Nestlé
38
  nestle_class_count = {}
39
  nestle_boxes = []
40
  for pred in yolo_pred['predictions']:
 
44
 
45
  total_nestle = sum(nestle_class_count.values())
46
 
47
+ # ========== [2] CountGD: Deteksi Kompetitor ==========
48
+ img = cv2.imread(temp_path)
49
+ prompts = [p.strip() for p in COUNTGD_PROMPT.split('.') if p.strip()]
50
+
51
+ competitor_detections = []
52
+ for prompt in prompts:
53
+ dets = T.countgd_object_detection(prompt, img)
54
+ competitor_detections.extend(dets)
55
+
56
+ # Filter & Hitung Kompetitor
 
 
 
 
57
  competitor_class_count = {}
58
  competitor_boxes = []
59
+ for det in competitor_detections:
60
+ bbox = det['bbox']
61
+ class_name = det['class_name']
62
+
63
+ if not is_overlap(bbox, nestle_boxes):
64
  competitor_class_count[class_name] = competitor_class_count.get(class_name, 0) + 1
65
  competitor_boxes.append({
66
  "class": class_name,
67
+ "box": bbox,
68
+ "confidence": det['score']
69
  })
70
 
71
  total_competitor = sum(competitor_class_count.values())
 
74
  result_text = "Product Nestle\n\n"
75
  for class_name, count in nestle_class_count.items():
76
  result_text += f"{class_name}: {count}\n"
77
+ result_text += f"\nTotal Product Nestle: {total_nestle}\n\n"
78
 
79
+ result_text += "Competitor Products\n\n"
80
  if competitor_class_count:
81
+ for class_name, count in competitor_class_count.items():
82
+ result_text += f"{class_name}: {count}\n"
83
  else:
84
+ result_text += "No competitors detected\n"
85
+ result_text += f"\nTotal Competitor: {total_competitor}"
86
 
87
  # ========== [4] Visualisasi ==========
88
+ # Gambar bounding box Nestlé (Hijau)
 
 
89
  for pred in yolo_pred['predictions']:
90
  x, y, w, h = pred['x'], pred['y'], pred['width'], pred['height']
91
+ cv2.rectangle(img, (int(x-w/2), int(y-h/2)), (int(x+w/2), int(y+h/2)), (0,255,0), 2)
92
+ cv2.putText(img, pred['class'], (int(x-w/2), int(y-h/2-10)),
93
+ cv2.FONT_HERSHEY_SIMPLEX, 0.5, (0,255,0), 2)
94
 
95
+ # Gambar bounding box Kompetitor (Merah)
96
  for comp in competitor_boxes:
97
+ x1, y1, x2, y2 = map(int, comp['box'])
98
+ cv2.rectangle(img, (x1, y1), (x2, y2), (0,0,255), 2)
99
+ cv2.putText(img, f"{comp['class']} {comp['confidence']:.2f}",
100
+ (x1, y1-10), cv2.FONT_HERSHEY_SIMPLEX, 0.5, (0,0,255), 2)
 
 
101
 
102
  output_path = "/tmp/combined_output.jpg"
103
  cv2.imwrite(output_path, img)
 
109
  finally:
110
  os.remove(temp_path)
111
 
112
def is_overlap(countgd_bbox, yolo_boxes, iou_threshold=0.3):
    """
    Report whether a CountGD box overlaps any YOLO box beyond a threshold.

    Overlap is measured with Intersection over Union (IoU).

    Args:
        countgd_bbox: Corner-format box ``[x1, y1, x2, y2]``.
        yolo_boxes: Iterable of centre-format boxes
            ``(x_center, y_center, width, height)``.
        iou_threshold: An IoU strictly greater than this value counts as
            an overlap.

    Returns:
        True as soon as one YOLO box exceeds the threshold, otherwise
        False (including when ``yolo_boxes`` is empty).
    """
    # NOTE(review): both inputs must use the same coordinate units for the
    # comparison to be meaningful -- confirm the caller passes CountGD boxes
    # in pixel coordinates rather than normalised ones.
    countgd_x1, countgd_y1, countgd_x2, countgd_y2 = countgd_bbox

    # Loop-invariant: the CountGD area does not depend on the YOLO box.
    countgd_area = (countgd_x2 - countgd_x1) * (countgd_y2 - countgd_y1)

    for x_center, y_center, width, height in yolo_boxes:
        # Convert the centre-format YOLO box to corner format.
        yolo_x1 = x_center - width / 2
        yolo_y1 = y_center - height / 2
        yolo_x2 = x_center + width / 2
        yolo_y2 = y_center + height / 2

        # Corners of the intersection rectangle.
        x_left = max(countgd_x1, yolo_x1)
        y_top = max(countgd_y1, yolo_y1)
        x_right = min(countgd_x2, yolo_x2)
        y_bottom = min(countgd_y2, yolo_y2)

        # Disjoint boxes: nothing to measure.
        if x_right < x_left or y_bottom < y_top:
            continue

        intersection_area = (x_right - x_left) * (y_bottom - y_top)
        yolo_area = (yolo_x2 - yolo_x1) * (yolo_y2 - yolo_y1)
        union_area = countgd_area + yolo_area - intersection_area

        # Degenerate (zero-area) boxes would otherwise divide by zero.
        iou = intersection_area / union_area if union_area > 0 else 0

        if iou > iou_threshold:
            return True
    return False
157
 
158
+ # ========== Fungsi untuk Deteksi Video ==========
 
 
 
 
 
 
 
159
  def detect_objects_in_video(video_path):
160
+ temp_output = "/tmp/output_video.mp4"
161
+ temp_dir = tempfile.mkdtemp()
162
+
 
 
163
  try:
164
+ cap = cv2.VideoCapture(video_path)
165
+ fps = int(cap.get(cv2.CAP_PROP_FPS))
166
+ width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
167
+ height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
168
+
 
 
 
 
 
 
169
  fourcc = cv2.VideoWriter_fourcc(*'mp4v')
170
+ out = cv2.VideoWriter(temp_output, fourcc, fps, (width, height))
171
+
172
+ total_counts = {}
173
+ frame_idx = 0
174
+
175
+ while cap.isOpened():
176
+ ret, frame = cap.read()
177
  if not ret:
178
  break
179
+
180
+ frame_path = os.path.join(temp_dir, f"frame_{frame_idx}.jpg")
181
  cv2.imwrite(frame_path, frame)
182
+
183
+ # Deteksi dengan YOLO
184
+ predictions = yolo_model.predict(frame_path, confidence=60, overlap=80).json()
185
+
186
+ # Update counts dan gambar bounding box
187
+ class_count = {}
188
+ for pred in predictions['predictions']:
189
+ class_name = pred['class']
190
+ class_count[class_name] = class_count.get(class_name, 0) + 1
191
+ total_counts[class_name] = total_counts.get(class_name, 0) + 1
192
+
193
+ x, y, w, h = pred['x'], pred['y'], pred['width'], pred['height']
194
+ cv2.rectangle(frame, (int(x-w/2), int(y-h/2)), (int(x+w/2), int(y+h/2)), (0,255,0), 2)
195
+ cv2.putText(frame, class_name, (int(x-w/2), int(y-h/2-10)),
196
+ cv2.FONT_HERSHEY_SIMPLEX, 0.5, (0,255,0), 2)
197
+
198
+ # Tampilkan counter
199
+ y_pos = 30
200
+ for cls, cnt in class_count.items():
201
+ cv2.putText(frame, f"{cls}: {cnt}", (10, y_pos),
202
+ cv2.FONT_HERSHEY_SIMPLEX, 0.7, (0,0,255), 2)
203
+ y_pos += 30
204
+
205
+ out.write(frame)
206
+ frame_idx += 1
207
+
208
+ cap.release()
209
+ out.release()
210
+
211
+ # Generate report
212
+ result_text = "Final Counts\n\n"
213
+ for cls, cnt in total_counts.items():
214
+ result_text += f"{cls}: {cnt}\n"
215
+ result_text += f"\nTotal: {sum(total_counts.values())}"
216
+
217
+ return temp_output, result_text
218
+
 
 
 
 
219
  except Exception as e:
220
+ return None, f"Error: {str(e)}"
221
+ finally:
222
+ for f in os.listdir(temp_dir):
223
+ os.remove(os.path.join(temp_dir, f))
224
+ os.rmdir(temp_dir)
225
 
226
  # ========== Gradio Interface ==========
227
  with gr.Blocks(theme=gr.themes.Base(primary_hue="teal", secondary_hue="teal", neutral_hue="slate")) as iface:
228
+ gr.Markdown("""
229
+ <div style="text-align: center;">
230
+ <h1>NESTLE - STOCK COUNTING</h1>
231
+ </div>
232
+ """)
233
  with gr.Row():
234
  with gr.Column():
235
  input_image = gr.Image(type="pil", label="Input Image")
236
  detect_image_button = gr.Button("Detect Image")
237
+ output_image = gr.Image(label="Detection Result")
238
+
 
 
239
  with gr.Column():
240
  input_video = gr.Video(label="Input Video")
241
  detect_video_button = gr.Button("Detect Video")
242
+ output_video = gr.Video(label="Video Result")
243
+
244
+ with gr.Column():
245
+ output_text = gr.Textbox(label="Counting Results")
246
+
247
+ detect_image_button.click(
248
+ fn=detect_combined,
249
+ inputs=input_image,
250
+ outputs=[output_image, output_text]
251
+ )
252
+
253
+ detect_video_button.click(
254
+ fn=detect_objects_in_video,
255
+ inputs=input_video,
256
+ outputs=[output_video, output_text]
257
+ )
258
 
259
+ iface.launch()