muhammadsalmanalfaridzi committed on
Commit
98fb533
·
verified ·
1 Parent(s): 7d539c2

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +147 -168
app.py CHANGED
@@ -3,9 +3,10 @@ from dotenv import load_dotenv
3
  from roboflow import Roboflow
4
  import tempfile
5
  import os
 
6
  import cv2
7
  import numpy as np
8
- import vision_agent.tools as T
9
 
10
  # ========== Konfigurasi ==========
11
  load_dotenv()
@@ -17,87 +18,95 @@ project_name = os.getenv("ROBOFLOW_PROJECT")
17
  model_version = int(os.getenv("ROBOFLOW_MODEL_VERSION"))
18
 
19
  # CountGD Config
20
- COUNTGD_PROMPT = "cans . bottle" # Customize sesuai kebutuhan
 
21
 
22
- # Inisialisasi Model
23
  rf = Roboflow(api_key=rf_api_key)
24
  project = rf.workspace(workspace).project(project_name)
25
  yolo_model = project.version(model_version).model
26
 
27
  # ========== Fungsi Deteksi Kombinasi ==========
28
  def detect_combined(image):
 
29
  with tempfile.NamedTemporaryFile(delete=False, suffix=".jpg") as temp_file:
30
  image.save(temp_file, format="JPEG")
31
  temp_path = temp_file.name
32
 
33
  try:
34
- # ========== [1] YOLO: Deteksi Produk Nestlé ==========
35
- yolo_pred = yolo_model.predict(temp_path, confidence=60, overlap=80).json()
36
-
37
- # Hitung per class Nestlé
38
  nestle_class_count = {}
39
  nestle_boxes = []
40
- for pred in yolo_pred['predictions']:
41
  class_name = pred['class']
42
  nestle_class_count[class_name] = nestle_class_count.get(class_name, 0) + 1
43
  nestle_boxes.append((pred['x'], pred['y'], pred['width'], pred['height']))
44
 
45
  total_nestle = sum(nestle_class_count.values())
46
 
47
- # ========== [2] CountGD: Deteksi Kompetitor ==========
48
- img = cv2.imread(temp_path)
49
- prompts = [p.strip() for p in COUNTGD_PROMPT.split('.') if p.strip()]
50
-
51
- competitor_detections = []
52
- for prompt in prompts:
53
- dets = T.countgd_object_detection(prompt, img)
54
- competitor_detections.extend(dets)
55
-
56
- # Filter & Hitung Kompetitor
 
 
 
 
 
57
  competitor_class_count = {}
58
  competitor_boxes = []
59
- for det in competitor_detections:
60
- bbox = det['bbox']
61
- class_name = det['class_name']
62
-
63
- if not is_overlap(bbox, nestle_boxes):
 
64
  competitor_class_count[class_name] = competitor_class_count.get(class_name, 0) + 1
65
  competitor_boxes.append({
66
  "class": class_name,
67
- "box": bbox,
68
- "confidence": det['score']
69
  })
70
 
71
  total_competitor = sum(competitor_class_count.values())
72
 
73
  # ========== [3] Format Output ==========
74
- result_text = "Product Nestle\n\n"
75
  for class_name, count in nestle_class_count.items():
76
  result_text += f"{class_name}: {count}\n"
77
- result_text += f"\nTotal Product Nestle: {total_nestle}\n\n"
78
-
79
- result_text += "Competitor Products\n\n"
80
  if competitor_class_count:
81
- for class_name, count in competitor_class_count.items():
82
- result_text += f"{class_name}: {count}\n"
83
  else:
84
- result_text += "No competitors detected\n"
85
- result_text += f"\nTotal Competitor: {total_competitor}"
86
 
87
  # ========== [4] Visualisasi ==========
88
- # Gambar bounding box Nestlé (Hijau)
89
- for pred in yolo_pred['predictions']:
 
90
  x, y, w, h = pred['x'], pred['y'], pred['width'], pred['height']
91
- cv2.rectangle(img, (int(x-w/2), int(y-h/2)), (int(x+w/2), int(y+h/2)), (0,255,0), 2)
92
- cv2.putText(img, pred['class'], (int(x-w/2), int(y-h/2-10)),
93
- cv2.FONT_HERSHEY_SIMPLEX, 0.5, (0,255,0), 2)
94
 
95
- # Gambar bounding box Kompetitor (Merah)
96
  for comp in competitor_boxes:
97
- x1, y1, x2, y2 = map(int, comp['box'])
98
- cv2.rectangle(img, (x1, y1), (x2, y2), (0,0,255), 2)
99
- cv2.putText(img, f"{comp['class']} {comp['confidence']:.2f}",
100
- (x1, y1-10), cv2.FONT_HERSHEY_SIMPLEX, 0.5, (0,0,255), 2)
 
 
 
101
 
102
  output_path = "/tmp/combined_output.jpg"
103
  cv2.imwrite(output_path, img)
@@ -109,151 +118,121 @@ def detect_combined(image):
109
  finally:
110
  os.remove(temp_path)
111
 
112
- def is_overlap(countgd_bbox, yolo_boxes, iou_threshold=0.3):
113
  """
114
- Deteksi overlap menggunakan Intersection over Union (IoU)
115
- Format bbox:
116
- - CountGD: [x1, y1, x2, y2]
117
- - YOLO: (x_center, y_center, width, height)
118
  """
119
- # Convert YOLO boxes to [x1,y1,x2,y2]
120
- yolo_boxes_converted = []
121
- for yb in yolo_boxes:
122
- x_center, y_center, width, height = yb
123
- x1 = x_center - width/2
124
- y1 = y_center - height/2
125
- x2 = x_center + width/2
126
- y2 = y_center + height/2
127
- yolo_boxes_converted.append((x1, y1, x2, y2))
128
-
129
- # Convert CountGD bbox to [x1,y1,x2,y2]
130
- countgd_x1, countgd_y1, countgd_x2, countgd_y2 = countgd_bbox
131
-
132
- # Hitung IoU dengan semua YOLO boxes
133
- for yolo_bbox in yolo_boxes_converted:
134
- yolo_x1, yolo_y1, yolo_x2, yolo_y2 = yolo_bbox
135
-
136
- # Hitung area intersection
137
- x_left = max(countgd_x1, yolo_x1)
138
- y_top = max(countgd_y1, yolo_y1)
139
- x_right = min(countgd_x2, yolo_x2)
140
- y_bottom = min(countgd_y2, yolo_y2)
141
-
142
- if x_right < x_left or y_bottom < y_top:
143
- continue
144
-
145
- intersection_area = (x_right - x_left) * (y_bottom - y_top)
146
-
147
- # Hitung area union
148
- countgd_area = (countgd_x2 - countgd_x1) * (countgd_y2 - countgd_y1)
149
- yolo_area = (yolo_x2 - yolo_x1) * (yolo_y2 - yolo_y1)
150
- union_area = countgd_area + yolo_area - intersection_area
151
-
152
- iou = intersection_area / union_area if union_area > 0 else 0
153
-
154
- if iou > iou_threshold:
155
- return True
156
  return False
157
 
158
  # ========== Fungsi untuk Deteksi Video ==========
 
 
 
 
 
 
 
159
  def detect_objects_in_video(video_path):
160
- temp_output = "/tmp/output_video.mp4"
161
- temp_dir = tempfile.mkdtemp()
162
-
 
 
163
  try:
164
- cap = cv2.VideoCapture(video_path)
165
- fps = int(cap.get(cv2.CAP_PROP_FPS))
166
- width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
167
- height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
168
-
 
 
 
 
 
 
169
  fourcc = cv2.VideoWriter_fourcc(*'mp4v')
170
- out = cv2.VideoWriter(temp_output, fourcc, fps, (width, height))
171
-
172
- total_counts = {}
173
- frame_idx = 0
174
-
175
- while cap.isOpened():
176
- ret, frame = cap.read()
177
  if not ret:
178
  break
179
-
180
- frame_path = os.path.join(temp_dir, f"frame_{frame_idx}.jpg")
181
  cv2.imwrite(frame_path, frame)
182
-
183
- # Deteksi dengan YOLO
184
- predictions = yolo_model.predict(frame_path, confidence=60, overlap=80).json()
185
-
186
- # Update counts dan gambar bounding box
187
- class_count = {}
188
- for pred in predictions['predictions']:
189
- class_name = pred['class']
190
- class_count[class_name] = class_count.get(class_name, 0) + 1
191
- total_counts[class_name] = total_counts.get(class_name, 0) + 1
192
-
193
- x, y, w, h = pred['x'], pred['y'], pred['width'], pred['height']
194
- cv2.rectangle(frame, (int(x-w/2), int(y-h/2)), (int(x+w/2), int(y+h/2)), (0,255,0), 2)
195
- cv2.putText(frame, class_name, (int(x-w/2), int(y-h/2-10)),
196
- cv2.FONT_HERSHEY_SIMPLEX, 0.5, (0,255,0), 2)
197
-
198
- # Tampilkan counter
199
- y_pos = 30
200
- for cls, cnt in class_count.items():
201
- cv2.putText(frame, f"{cls}: {cnt}", (10, y_pos),
202
- cv2.FONT_HERSHEY_SIMPLEX, 0.7, (0,0,255), 2)
203
- y_pos += 30
204
-
205
- out.write(frame)
206
- frame_idx += 1
207
-
208
- cap.release()
209
- out.release()
210
-
211
- # Generate report
212
- result_text = "Final Counts\n\n"
213
- for cls, cnt in total_counts.items():
214
- result_text += f"{cls}: {cnt}\n"
215
- result_text += f"\nTotal: {sum(total_counts.values())}"
216
-
217
- return temp_output, result_text
218
-
 
 
 
219
  except Exception as e:
220
- return None, f"Error: {str(e)}"
221
- finally:
222
- for f in os.listdir(temp_dir):
223
- os.remove(os.path.join(temp_dir, f))
224
- os.rmdir(temp_dir)
225
 
226
  # ========== Gradio Interface ==========
227
  with gr.Blocks(theme=gr.themes.Base(primary_hue="teal", secondary_hue="teal", neutral_hue="slate")) as iface:
228
- gr.Markdown("""
229
- <div style="text-align: center;">
230
- <h1>NESTLE - STOCK COUNTING</h1>
231
- </div>
232
- """)
233
  with gr.Row():
234
  with gr.Column():
235
  input_image = gr.Image(type="pil", label="Input Image")
236
  detect_image_button = gr.Button("Detect Image")
237
- output_image = gr.Image(label="Detection Result")
238
-
 
239
  with gr.Column():
240
  input_video = gr.Video(label="Input Video")
241
  detect_video_button = gr.Button("Detect Video")
242
- output_video = gr.Video(label="Video Result")
 
243
 
244
- with gr.Column():
245
- output_text = gr.Textbox(label="Counting Results")
246
-
247
- detect_image_button.click(
248
- fn=detect_combined,
249
- inputs=input_image,
250
- outputs=[output_image, output_text]
251
- )
252
-
253
- detect_video_button.click(
254
- fn=detect_objects_in_video,
255
- inputs=input_video,
256
- outputs=[output_video, output_text]
257
- )
258
-
259
- iface.launch()
 
3
  from roboflow import Roboflow
4
  import tempfile
5
  import os
6
+ import requests
7
  import cv2
8
  import numpy as np
9
+ import subprocess
10
 
11
  # ========== Konfigurasi ==========
12
  load_dotenv()
 
18
  model_version = int(os.getenv("ROBOFLOW_MODEL_VERSION"))
19
 
20
  # CountGD Config
21
+ COUNTGD_PROMPT = "beverage . bottle . cans . mixed box" # Sesuaikan prompt sesuai kebutuhan
22
+ COUNTGD_API_KEY = os.getenv("COUNTGD_API_KEY") # API key CountGD
23
 
24
+ # Inisialisasi Model YOLO dari Roboflow
25
  rf = Roboflow(api_key=rf_api_key)
26
  project = rf.workspace(workspace).project(project_name)
27
  yolo_model = project.version(model_version).model
28
 
29
  # ========== Fungsi Deteksi Kombinasi ==========
30
  def detect_combined(image):
31
+ # Simpan gambar ke file temporer
32
  with tempfile.NamedTemporaryFile(delete=False, suffix=".jpg") as temp_file:
33
  image.save(temp_file, format="JPEG")
34
  temp_path = temp_file.name
35
 
36
  try:
37
+ # ========== [1] Deteksi Produk Nestlé dengan YOLO ==========
38
+ yolo_pred = yolo_model.predict(temp_path, confidence=50, overlap=80).json()
39
+
40
+ # Hitung per kelas dan simpan bounding box (format: (x_center, y_center, width, height))
41
  nestle_class_count = {}
42
  nestle_boxes = []
43
+ for pred in yolo_pred.get('predictions', []):
44
  class_name = pred['class']
45
  nestle_class_count[class_name] = nestle_class_count.get(class_name, 0) + 1
46
  nestle_boxes.append((pred['x'], pred['y'], pred['width'], pred['height']))
47
 
48
  total_nestle = sum(nestle_class_count.values())
49
 
50
+ # ========== [2] Deteksi Kompetitor dengan CountGD ==========
51
+ countgd_url = "https://api.landing.ai/v1/tools/text-to-object-detection"
52
+ with open(temp_path, "rb") as image_file:
53
+ files = {"image": image_file}
54
+ data = {
55
+ "prompts": [COUNTGD_PROMPT],
56
+ "model": "countgd"
57
+ }
58
+ headers = {
59
+ "Authorization": f"Basic {COUNTGD_API_KEY}",
60
+ "Content-Type": "multipart/form-data"
61
+ }
62
+ response = requests.post(countgd_url, files=files, data=data, headers=headers)
63
+ countgd_pred = response.json()
64
+
65
  competitor_class_count = {}
66
  competitor_boxes = []
67
+ # Asumsikan respons JSON mengandung key "predictions" berupa daftar objek
68
+ for obj in countgd_pred.get("predictions", []):
69
+ countgd_box = obj.get("bbox") # Format: [x1, y1, x2, y2]
70
+ # Lakukan filter untuk menghindari duplikasi dengan deteksi YOLO
71
+ if not is_overlap(countgd_box, nestle_boxes):
72
+ class_name = obj.get("class", "").strip().lower()
73
  competitor_class_count[class_name] = competitor_class_count.get(class_name, 0) + 1
74
  competitor_boxes.append({
75
  "class": class_name,
76
+ "box": countgd_box,
77
+ "confidence": obj.get("score", 0)
78
  })
79
 
80
  total_competitor = sum(competitor_class_count.values())
81
 
82
  # ========== [3] Format Output ==========
83
+ result_text = "Product Nestlé\n\n"
84
  for class_name, count in nestle_class_count.items():
85
  result_text += f"{class_name}: {count}\n"
86
+ result_text += f"\nTotal Products Nestlé: {total_nestle}\n\n"
 
 
87
  if competitor_class_count:
88
+ result_text += f"Total Unclassified Products: {total_competitor}\n"
 
89
  else:
90
+ result_text += "No Unclassified Products detected\n"
 
91
 
92
  # ========== [4] Visualisasi ==========
93
+ img = cv2.imread(temp_path)
94
+ # Tandai bounding box untuk produk Nestlé (warna hijau)
95
+ for pred in yolo_pred.get('predictions', []):
96
  x, y, w, h = pred['x'], pred['y'], pred['width'], pred['height']
97
+ cv2.rectangle(img, (int(x - w/2), int(y - h/2)), (int(x + w/2), int(y + h/2)), (0, 255, 0), 2)
98
+ cv2.putText(img, pred['class'], (int(x - w/2), int(y - h/2 - 10)),
99
+ cv2.FONT_HERSHEY_SIMPLEX, 1.0, (0, 255, 0), 3)
100
 
101
+ # Tandai bounding box untuk kompetitor (warna merah)
102
  for comp in competitor_boxes:
103
+ x1, y1, x2, y2 = comp['box']
104
+ # Ubah nama kelas menjadi 'unclassified' jika sesuai dengan daftar target
105
+ unclassified_classes = ["beverage", "cans", "bottle", "mixed box"]
106
+ display_name = "unclassified" if any(uc in comp['class'] for uc in unclassified_classes) else comp['class']
107
+ cv2.rectangle(img, (int(x1), int(y1)), (int(x2), int(y2)), (0, 0, 255), 2)
108
+ cv2.putText(img, f"{display_name} {comp['confidence']:.2f}",
109
+ (int(x1), int(y1-10)), cv2.FONT_HERSHEY_SIMPLEX, 1.0, (0, 0, 255), 3)
110
 
111
  output_path = "/tmp/combined_output.jpg"
112
  cv2.imwrite(output_path, img)
 
118
  finally:
119
  os.remove(temp_path)
120
 
121
def is_overlap(box1, boxes2, threshold=0.3):
    """Return True if ``box1`` is substantially covered by any box in ``boxes2``.

    The score is the intersection area divided by the area of ``box1`` (not a
    symmetric IoU), so a small CountGD box lying inside a large YOLO box counts
    as an overlap.

    Args:
        box1: CountGD bounding box as corner coordinates ``[x1, y1, x2, y2]``.
        boxes2: Iterable of YOLO bounding boxes, each given as
            ``(x_center, y_center, width, height)``.
        threshold: Fraction of ``box1``'s area that must be covered by a single
            YOLO box for the pair to count as overlapping (strictly greater
            than).

    Returns:
        True if at least one box in ``boxes2`` covers more than ``threshold``
        of ``box1``; False otherwise, including when ``boxes2`` is empty or
        ``box1`` has no area.
    """
    x1_min, y1_min, x1_max, y1_max = box1

    # The area of box1 does not depend on the YOLO box, so compute it once.
    area_box1 = (x1_max - x1_min) * (y1_max - y1_min)
    if area_box1 <= 0:
        # A degenerate box covers nothing; bail out instead of dividing by zero.
        return False

    for x_center, y_center, width, height in boxes2:
        # Convert the centre/size representation to corner coordinates.
        x2_min = x_center - width / 2
        x2_max = x_center + width / 2
        y2_min = y_center - height / 2
        y2_max = y_center + height / 2

        # Width and height of the intersection; negative means no intersection.
        dx = min(x1_max, x2_max) - max(x1_min, x2_min)
        dy = min(y1_max, y2_max) - max(y1_min, y2_min)
        if dx >= 0 and dy >= 0 and (dx * dy) / area_box1 > threshold:
            return True

    return False
142
 
143
  # ========== Fungsi untuk Deteksi Video ==========
144
def convert_video_to_mp4(input_path, output_path):
    """Transcode a video to H.264/AAC MP4 by invoking ``ffmpeg``.

    Args:
        input_path: Path of the source video.
        output_path: Path the MP4 file is written to.

    Returns:
        A ``(path, error)`` pair: ``(output_path, None)`` on success, or
        ``(None, message)`` when ffmpeg fails or cannot be started. The shape
        is the same on both paths so callers can always unpack two values.
    """
    command = [
        'ffmpeg',
        '-y',  # overwrite an existing output file instead of prompting
        '-i', input_path,
        '-vcodec', 'libx264',
        '-acodec', 'aac',
        output_path,
    ]
    try:
        subprocess.run(command, check=True)
    except (subprocess.CalledProcessError, OSError) as e:
        # OSError covers the case where the ffmpeg executable is not available.
        return None, f"Error converting video: {e}"
    return output_path, None
150
+
151
def detect_objects_in_video(video_path):
    """Annotate every frame of a video with YOLO detections and per-frame counts.

    Each frame is written to a temporary JPEG, sent to ``yolo_model.predict``,
    drawn with green boxes/labels plus a per-class counter, and appended to the
    output video.

    Args:
        video_path: Path of the input video. Files whose name does not end in
            ``.mp4`` are first transcoded with ``convert_video_to_mp4``.

    Returns:
        Path of the annotated video (``/tmp/output_video.mp4``).

    Raises:
        gr.Error: On any failure. The handler is wired to a single video
            output, so errors are raised for the UI to display rather than
            returned as an extra value.
    """
    temp_output_path = "/tmp/output_video.mp4"

    try:
        # The directory (frames and any converted input) is removed on exit.
        with tempfile.TemporaryDirectory() as temp_frames_dir:
            if not video_path.lower().endswith(".mp4"):
                # Convert into the temp dir, NOT onto temp_output_path: the
                # writer below targets that file while this one is being read.
                converted_path = os.path.join(temp_frames_dir, "converted_input.mp4")
                conversion = convert_video_to_mp4(video_path, converted_path)
                # The helper may hand back a bare path or a (path, error) pair.
                if isinstance(conversion, tuple):
                    video_path, err = conversion
                else:
                    video_path, err = conversion, None
                if not video_path:
                    raise gr.Error(f"Video conversion error: {err}")

            video = cv2.VideoCapture(video_path)
            output_video = None
            try:
                if not video.isOpened():
                    raise RuntimeError(f"Cannot open video: {video_path}")

                frame_rate = int(video.get(cv2.CAP_PROP_FPS))
                frame_width = int(video.get(cv2.CAP_PROP_FRAME_WIDTH))
                frame_height = int(video.get(cv2.CAP_PROP_FRAME_HEIGHT))
                frame_size = (frame_width, frame_height)

                fourcc = cv2.VideoWriter_fourcc(*'mp4v')
                output_video = cv2.VideoWriter(temp_output_path, fourcc, frame_rate, frame_size)

                # One scratch file is reused for every frame so disk usage does
                # not grow with the length of the video.
                frame_path = os.path.join(temp_frames_dir, "frame.jpg")

                while True:
                    ret, frame = video.read()
                    if not ret:
                        break

                    cv2.imwrite(frame_path, frame)
                    predictions = yolo_model.predict(frame_path, confidence=50, overlap=80).json()

                    # Key detections by class + geometry so identical duplicate
                    # predictions within a frame are counted once.
                    current_detections = {}
                    for prediction in predictions.get('predictions', []):
                        class_name = prediction['class']
                        x, y, w, h = prediction['x'], prediction['y'], prediction['width'], prediction['height']
                        object_id = f"{class_name}_{x}_{y}_{w}_{h}"
                        if object_id not in current_detections:
                            current_detections[object_id] = class_name

                        # Predictions are centre/size; convert to corners to draw.
                        cv2.rectangle(frame, (int(x - w/2), int(y - h/2)),
                                      (int(x + w/2), int(y + h/2)), (0, 255, 0), 2)
                        cv2.putText(frame, class_name, (int(x - w/2), int(y - h/2 - 10)),
                                    cv2.FONT_HERSHEY_SIMPLEX, 0.5, (0, 255, 0), 2)

                    object_counts = {}
                    for class_name in current_detections.values():
                        object_counts[class_name] = object_counts.get(class_name, 0) + 1

                    count_lines = [f"{class_name}: {count}" for class_name, count in object_counts.items()]
                    total_product_count = sum(object_counts.values())
                    count_text = "".join(line + "\n" for line in count_lines)
                    count_text += f"\nTotal Product: {total_product_count}"

                    # Overlay the counter in the top-left corner, one line per row.
                    y_offset = 20
                    for line in count_text.split("\n"):
                        cv2.putText(frame, line, (10, y_offset),
                                    cv2.FONT_HERSHEY_SIMPLEX, 0.6, (255, 255, 255), 2)
                        y_offset += 30

                    output_video.write(frame)
            finally:
                # Release handles on success and on failure alike.
                video.release()
                if output_video is not None:
                    output_video.release()

        return temp_output_path

    except gr.Error:
        # Already a UI-facing error; do not wrap it a second time.
        raise
    except Exception as e:
        raise gr.Error(f"An error occurred: {e}") from e
 
 
 
 
221
 
222
  # ========== Gradio Interface ==========
223
# Two-column layout: still-image detection on the left, video detection on the right.
with gr.Blocks(
    theme=gr.themes.Base(primary_hue="teal", secondary_hue="teal", neutral_hue="slate")
) as iface:
    gr.Markdown('<div style="text-align: center;"><h1>NESTLE - STOCK COUNTING</h1></div>')

    with gr.Row():
        # Image pipeline widgets.
        with gr.Column():
            input_image = gr.Image(type="pil", label="Input Image")
            detect_image_button = gr.Button("Detect Image")
            output_image = gr.Image(label="Detect Object")
            output_text = gr.Textbox(label="Counting Object")

        # Video pipeline widgets.
        with gr.Column():
            input_video = gr.Video(label="Input Video")
            detect_video_button = gr.Button("Detect Video")
            output_video = gr.Video(label="Output Video")

    # Event wiring is kept together, after all components exist.
    detect_image_button.click(
        fn=detect_combined,
        inputs=input_image,
        outputs=[output_image, output_text],
    )
    detect_video_button.click(
        fn=detect_objects_in_video,
        inputs=input_video,
        outputs=[output_video],
    )

iface.launch()