muhammadsalmanalfaridzi committed on
Commit
bd977a3
·
verified ·
1 Parent(s): 98fb533

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +85 -111
app.py CHANGED
@@ -1,14 +1,14 @@
1
- import gradio as gr
2
- from dotenv import load_dotenv
3
- from roboflow import Roboflow
4
- import tempfile
5
  import os
 
 
6
  import requests
7
  import cv2
8
- import numpy as np
 
 
9
  import subprocess
10
 
11
- # ========== Konfigurasi ==========
12
  load_dotenv()
13
 
14
  # Roboflow Config
@@ -17,97 +17,96 @@ workspace = os.getenv("ROBOFLOW_WORKSPACE")
17
  project_name = os.getenv("ROBOFLOW_PROJECT")
18
  model_version = int(os.getenv("ROBOFLOW_MODEL_VERSION"))
19
 
20
- # CountGD Config
21
- COUNTGD_PROMPT = "beverage . bottle . cans . mixed box" # Sesuaikan prompt sesuai kebutuhan
22
- COUNTGD_API_KEY = os.getenv("COUNTGD_API_KEY") # API key CountGD
23
 
24
- # Inisialisasi Model YOLO dari Roboflow
25
  rf = Roboflow(api_key=rf_api_key)
26
  project = rf.workspace(workspace).project(project_name)
27
  yolo_model = project.version(model_version).model
28
 
29
- # ========== Fungsi Deteksi Kombinasi ==========
30
  def detect_combined(image):
31
- # Simpan gambar ke file temporer
32
  with tempfile.NamedTemporaryFile(delete=False, suffix=".jpg") as temp_file:
33
  image.save(temp_file, format="JPEG")
34
  temp_path = temp_file.name
35
 
36
  try:
37
- # ========== [1] Deteksi Produk Nestlé dengan YOLO ==========
38
  yolo_pred = yolo_model.predict(temp_path, confidence=50, overlap=80).json()
39
 
40
- # Hitung per kelas dan simpan bounding box (format: (x_center, y_center, width, height))
41
  nestle_class_count = {}
42
  nestle_boxes = []
43
- for pred in yolo_pred.get('predictions', []):
44
  class_name = pred['class']
45
  nestle_class_count[class_name] = nestle_class_count.get(class_name, 0) + 1
46
  nestle_boxes.append((pred['x'], pred['y'], pred['width'], pred['height']))
47
 
48
  total_nestle = sum(nestle_class_count.values())
49
 
50
- # ========== [2] Deteksi Kompetitor dengan CountGD ==========
51
- countgd_url = "https://api.landing.ai/v1/tools/text-to-object-detection"
52
- with open(temp_path, "rb") as image_file:
53
- files = {"image": image_file}
54
- data = {
55
- "prompts": [COUNTGD_PROMPT],
56
- "model": "countgd"
57
- }
58
- headers = {
59
- "Authorization": f"Basic {COUNTGD_API_KEY}",
60
- "Content-Type": "multipart/form-data"
61
- }
62
- response = requests.post(countgd_url, files=files, data=data, headers=headers)
63
- countgd_pred = response.json()
 
64
 
 
65
  competitor_class_count = {}
66
  competitor_boxes = []
67
- # Asumsikan respons JSON mengandung key "predictions" berupa daftar objek
68
- for obj in countgd_pred.get("predictions", []):
69
- countgd_box = obj.get("bbox") # Format: [x1, y1, x2, y2]
70
- # Lakukan filter untuk menghindari duplikasi dengan deteksi YOLO
71
- if not is_overlap(countgd_box, nestle_boxes):
72
- class_name = obj.get("class", "").strip().lower()
73
  competitor_class_count[class_name] = competitor_class_count.get(class_name, 0) + 1
74
  competitor_boxes.append({
75
  "class": class_name,
76
- "box": countgd_box,
77
- "confidence": obj.get("score", 0)
78
  })
79
 
80
  total_competitor = sum(competitor_class_count.values())
81
 
82
- # ========== [3] Format Output ==========
83
- result_text = "Product Nestlé\n\n"
84
  for class_name, count in nestle_class_count.items():
85
  result_text += f"{class_name}: {count}\n"
86
- result_text += f"\nTotal Products Nestlé: {total_nestle}\n\n"
 
 
87
  if competitor_class_count:
88
  result_text += f"Total Unclassified Products: {total_competitor}\n"
89
  else:
90
  result_text += "No Unclassified Products detected\n"
91
 
92
- # ========== [4] Visualisasi ==========
93
  img = cv2.imread(temp_path)
94
- # Tandai bounding box untuk produk Nestlé (warna hijau)
95
- for pred in yolo_pred.get('predictions', []):
 
96
  x, y, w, h = pred['x'], pred['y'], pred['width'], pred['height']
97
- cv2.rectangle(img, (int(x - w/2), int(y - h/2)), (int(x + w/2), int(y + h/2)), (0, 255, 0), 2)
98
- cv2.putText(img, pred['class'], (int(x - w/2), int(y - h/2 - 10)),
99
- cv2.FONT_HERSHEY_SIMPLEX, 1.0, (0, 255, 0), 3)
100
 
101
- # Tandai bounding box untuk kompetitor (warna merah)
102
  for comp in competitor_boxes:
103
  x1, y1, x2, y2 = comp['box']
104
- # Ubah nama kelas menjadi 'unclassified' jika sesuai dengan daftar target
105
- unclassified_classes = ["beverage", "cans", "bottle", "mixed box"]
106
- display_name = "unclassified" if any(uc in comp['class'] for uc in unclassified_classes) else comp['class']
107
  cv2.rectangle(img, (int(x1), int(y1)), (int(x2), int(y2)), (0, 0, 255), 2)
108
- cv2.putText(img, f"{display_name} {comp['confidence']:.2f}",
109
- (int(x1), int(y1-10)), cv2.FONT_HERSHEY_SIMPLEX, 1.0, (0, 0, 255), 3)
110
-
111
  output_path = "/tmp/combined_output.jpg"
112
  cv2.imwrite(output_path, img)
113
 
@@ -118,29 +117,8 @@ def detect_combined(image):
118
  finally:
119
  os.remove(temp_path)
120
 
121
- def is_overlap(box1, boxes2, threshold=0.3):
122
- """
123
- Fungsi untuk mendeteksi overlap antara bounding box dari CountGD (format: [x1, y1, x2, y2])
124
- dan bounding box YOLO (format: (x_center, y_center, width, height)).
125
- """
126
- x1_min, y1_min, x1_max, y1_max = box1
127
- for b2 in boxes2:
128
- x2, y2, w2, h2 = b2
129
- x2_min = x2 - w2/2
130
- x2_max = x2 + w2/2
131
- y2_min = y2 - h2/2
132
- y2_max = y2 + h2/2
133
-
134
- dx = min(x1_max, x2_max) - max(x1_min, x2_min)
135
- dy = min(y1_max, y2_max) - max(y1_min, y2_min)
136
- if dx >= 0 and dy >= 0:
137
- area_overlap = dx * dy
138
- area_box1 = (x1_max - x1_min) * (y1_max - y1_min)
139
- if area_overlap / area_box1 > threshold:
140
- return True
141
- return False
142
-
143
- # ========== Fungsi untuk Deteksi Video ==========
144
  def convert_video_to_mp4(input_path, output_path):
145
  try:
146
  subprocess.run(['ffmpeg', '-i', input_path, '-vcodec', 'libx264', '-acodec', 'aac', output_path], check=True)
@@ -152,20 +130,22 @@ def detect_objects_in_video(video_path):
152
  temp_output_path = "/tmp/output_video.mp4"
153
  temp_frames_dir = tempfile.mkdtemp()
154
  frame_count = 0
155
- previous_detections = {}
156
 
157
  try:
 
158
  if not video_path.endswith(".mp4"):
159
  video_path, err = convert_video_to_mp4(video_path, temp_output_path)
160
  if not video_path:
161
  return None, f"Video conversion error: {err}"
162
 
 
163
  video = cv2.VideoCapture(video_path)
164
  frame_rate = int(video.get(cv2.CAP_PROP_FPS))
165
  frame_width = int(video.get(cv2.CAP_PROP_FRAME_WIDTH))
166
  frame_height = int(video.get(cv2.CAP_PROP_FRAME_HEIGHT))
167
  frame_size = (frame_width, frame_height)
168
 
 
169
  fourcc = cv2.VideoWriter_fourcc(*'mp4v')
170
  output_video = cv2.VideoWriter(temp_output_path, fourcc, frame_rate, frame_size)
171
 
@@ -174,54 +154,47 @@ def detect_objects_in_video(video_path):
174
  if not ret:
175
  break
176
 
 
177
  frame_path = os.path.join(temp_frames_dir, f"frame_{frame_count}.jpg")
178
  cv2.imwrite(frame_path, frame)
179
 
180
- predictions = yolo_model.predict(frame_path, confidence=50, overlap=80).json()
181
- current_detections = {}
182
- for prediction in predictions.get('predictions', []):
183
- class_name = prediction['class']
184
- x, y, w, h = prediction['x'], prediction['y'], prediction['width'], prediction['height']
185
- object_id = f"{class_name}_{x}_{y}_{w}_{h}"
186
- if object_id not in current_detections:
187
- current_detections[object_id] = class_name
188
-
189
- cv2.rectangle(frame, (int(x - w/2), int(y - h/2)),
190
- (int(x + w/2), int(y + h/2)), (0, 255, 0), 2)
191
- cv2.putText(frame, class_name, (int(x - w/2), int(y - h/2 - 10)),
192
- cv2.FONT_HERSHEY_SIMPLEX, 0.5, (0, 255, 0), 2)
193
-
194
- object_counts = {}
195
- for detection_id, class_name in current_detections.items():
196
- object_counts[class_name] = object_counts.get(class_name, 0) + 1
197
-
198
- count_text = ""
199
- total_product_count = 0
200
- for class_name, count in object_counts.items():
201
- count_text += f"{class_name}: {count}\n"
202
- total_product_count += count
203
- count_text += f"\nTotal Product: {total_product_count}"
204
-
205
- y_offset = 20
206
- for line in count_text.split("\n"):
207
- cv2.putText(frame, line, (10, y_offset),
208
- cv2.FONT_HERSHEY_SIMPLEX, 0.6, (255, 255, 255), 2)
209
- y_offset += 30
210
-
211
  output_video.write(frame)
212
  frame_count += 1
213
- previous_detections = current_detections
214
 
215
  video.release()
216
  output_video.release()
 
217
  return temp_output_path
218
 
219
  except Exception as e:
220
  return None, f"An error occurred: {e}"
221
 
222
- # ========== Gradio Interface ==========
223
  with gr.Blocks(theme=gr.themes.Base(primary_hue="teal", secondary_hue="teal", neutral_hue="slate")) as iface:
224
  gr.Markdown("""<div style="text-align: center;"><h1>NESTLE - STOCK COUNTING</h1></div>""")
 
225
  with gr.Row():
226
  with gr.Column():
227
  input_image = gr.Image(type="pil", label="Input Image")
@@ -229,10 +202,11 @@ with gr.Blocks(theme=gr.themes.Base(primary_hue="teal", secondary_hue="teal", ne
229
  output_image = gr.Image(label="Detect Object")
230
  output_text = gr.Textbox(label="Counting Object")
231
  detect_image_button.click(fn=detect_combined, inputs=input_image, outputs=[output_image, output_text])
 
232
  with gr.Column():
233
  input_video = gr.Video(label="Input Video")
234
  detect_video_button = gr.Button("Detect Video")
235
  output_video = gr.Video(label="Output Video")
236
  detect_video_button.click(fn=detect_objects_in_video, inputs=input_video, outputs=[output_video])
237
-
238
  iface.launch()
 
 
 
 
 
1
  import os
2
+ import numpy as np
3
+ import tempfile
4
  import requests
5
  import cv2
6
+ import gradio as gr
7
+ from dotenv import load_dotenv
8
+ from roboflow import Roboflow
9
  import subprocess
10
 
11
+ # ========== Konfigurasi ==========
12
  load_dotenv()
13
 
14
  # Roboflow Config
 
17
  project_name = os.getenv("ROBOFLOW_PROJECT")
18
  model_version = int(os.getenv("ROBOFLOW_MODEL_VERSION"))
19
 
20
+ # countgd Model Configuration
21
+ COUNTGD_API_KEY = os.getenv("COUNTGD_API_KEY")
22
+ COUNTGD_MODEL_URL = "https://api.landing.ai/v1/tools/countgd-object-detection" # Replace with the correct API endpoint
23
 
24
+ # Inisialisasi Model
25
  rf = Roboflow(api_key=rf_api_key)
26
  project = rf.workspace(workspace).project(project_name)
27
  yolo_model = project.version(model_version).model
28
 
29
+ # ========== Fungsi Deteksi Kombinasi ==========
30
  def detect_combined(image):
 
31
  with tempfile.NamedTemporaryFile(delete=False, suffix=".jpg") as temp_file:
32
  image.save(temp_file, format="JPEG")
33
  temp_path = temp_file.name
34
 
35
  try:
36
+ # ========== [1] YOLO: Deteksi Produk Nestlé (Per Class) ==========
37
  yolo_pred = yolo_model.predict(temp_path, confidence=50, overlap=80).json()
38
 
39
+ # Hitung per class Nestlé
40
  nestle_class_count = {}
41
  nestle_boxes = []
42
+ for pred in yolo_pred['predictions']:
43
  class_name = pred['class']
44
  nestle_class_count[class_name] = nestle_class_count.get(class_name, 0) + 1
45
  nestle_boxes.append((pred['x'], pred['y'], pred['width'], pred['height']))
46
 
47
  total_nestle = sum(nestle_class_count.values())
48
 
49
+ # ========== [2] countgd: Deteksi Produk dengan countgd Model ==========
50
+ # Make a request to the countgd model API (adjust parameters accordingly)
51
+ with open(temp_path, 'rb') as img_file:
52
+ response = requests.post(
53
+ COUNTGD_MODEL_URL,
54
+ headers={"Authorization": f"Bearer {COUNTGD_API_KEY}"},
55
+ files={"image": img_file},
56
+ data={"prompts": ["water bottle", "beverage can"]}
57
+ )
58
+
59
+ # Handle the response from the countgd model
60
+ if response.status_code == 200:
61
+ countgd_pred = response.json()['detections']
62
+ else:
63
+ return temp_path, f"Error calling countgd API: {response.text}"
64
 
65
+ # Filter & Hitung Kompetitor
66
  competitor_class_count = {}
67
  competitor_boxes = []
68
+ for obj in countgd_pred:
69
+ # Filter and process the detections
70
+ class_name = obj['label']
71
+ if class_name.lower() in ['water bottle', 'beverage can']: # Modify this as needed
 
 
72
  competitor_class_count[class_name] = competitor_class_count.get(class_name, 0) + 1
73
  competitor_boxes.append({
74
  "class": class_name,
75
+ "box": obj['bbox'],
76
+ "confidence": obj['score']
77
  })
78
 
79
  total_competitor = sum(competitor_class_count.values())
80
 
81
+ # ========== [3] Format Output ==========
82
+ result_text = "Product Nestle\n\n"
83
  for class_name, count in nestle_class_count.items():
84
  result_text += f"{class_name}: {count}\n"
85
+ result_text += f"\nTotal Products Nestle: {total_nestle}\n\n"
86
+
87
+ # Unclassified Products (from countgd model)
88
  if competitor_class_count:
89
  result_text += f"Total Unclassified Products: {total_competitor}\n"
90
  else:
91
  result_text += "No Unclassified Products detected\n"
92
 
93
+ # ========== [4] Visualisasi ==========
94
  img = cv2.imread(temp_path)
95
+
96
+ # Nestlé (Hijau)
97
+ for pred in yolo_pred['predictions']:
98
  x, y, w, h = pred['x'], pred['y'], pred['width'], pred['height']
99
+ cv2.rectangle(img, (int(x-w/2), int(y-h/2)), (int(x+w/2), int(y+h/2)), (0,255,0), 2)
100
+ cv2.putText(img, pred['class'], (int(x-w/2), int(y-h/2-10)),
101
+ cv2.FONT_HERSHEY_SIMPLEX, 0.55, (0,255,0), 2)
102
 
103
+ # Kompetitor (Merah) with countgd detections
104
  for comp in competitor_boxes:
105
  x1, y1, x2, y2 = comp['box']
 
 
 
106
  cv2.rectangle(img, (int(x1), int(y1)), (int(x2), int(y2)), (0, 0, 255), 2)
107
+ cv2.putText(img, f"{comp['class']} {comp['confidence']:.2f}",
108
+ (int(x1), int(y1-10)), cv2.FONT_HERSHEY_SIMPLEX, 0.55, (0, 0, 255), 2)
109
+
110
  output_path = "/tmp/combined_output.jpg"
111
  cv2.imwrite(output_path, img)
112
 
 
117
  finally:
118
  os.remove(temp_path)
119
 
120
+ # ========== Fungsi untuk Deteksi Video ==========
121
+
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
122
  def convert_video_to_mp4(input_path, output_path):
123
  try:
124
  subprocess.run(['ffmpeg', '-i', input_path, '-vcodec', 'libx264', '-acodec', 'aac', output_path], check=True)
 
130
  temp_output_path = "/tmp/output_video.mp4"
131
  temp_frames_dir = tempfile.mkdtemp()
132
  frame_count = 0
 
133
 
134
  try:
135
+ # Convert video to MP4 if necessary
136
  if not video_path.endswith(".mp4"):
137
  video_path, err = convert_video_to_mp4(video_path, temp_output_path)
138
  if not video_path:
139
  return None, f"Video conversion error: {err}"
140
 
141
+ # Read video and process frames
142
  video = cv2.VideoCapture(video_path)
143
  frame_rate = int(video.get(cv2.CAP_PROP_FPS))
144
  frame_width = int(video.get(cv2.CAP_PROP_FRAME_WIDTH))
145
  frame_height = int(video.get(cv2.CAP_PROP_FRAME_HEIGHT))
146
  frame_size = (frame_width, frame_height)
147
 
148
+ # VideoWriter for output video
149
  fourcc = cv2.VideoWriter_fourcc(*'mp4v')
150
  output_video = cv2.VideoWriter(temp_output_path, fourcc, frame_rate, frame_size)
151
 
 
154
  if not ret:
155
  break
156
 
157
+ # Process predictions for the current frame using countgd model (same as in image detection)
158
  frame_path = os.path.join(temp_frames_dir, f"frame_{frame_count}.jpg")
159
  cv2.imwrite(frame_path, frame)
160
 
161
+ # Get predictions from countgd (adjust accordingly for video frames)
162
+ response = requests.post(
163
+ COUNTGD_MODEL_URL,
164
+ headers={"Authorization": f"Bearer {COUNTGD_API_KEY}"},
165
+ files={"image": open(frame_path, 'rb')},
166
+ data={"prompts": ["water bottle", "beverage can"]}
167
+ )
168
+
169
+ # Process the response (similarly to what was done for image detection)
170
+ if response.status_code == 200:
171
+ countgd_pred = response.json()['detections']
172
+ else:
173
+ continue
174
+
175
+ # Drawing detections on frames
176
+ for obj in countgd_pred:
177
+ x1, y1, x2, y2 = obj['bbox']
178
+ cv2.rectangle(frame, (int(x1), int(y1)), (int(x2), int(y2)), (0, 0, 255), 2)
179
+ cv2.putText(frame, f"{obj['label']} {obj['score']:.2f}",
180
+ (int(x1), int(y1-10)), cv2.FONT_HERSHEY_SIMPLEX, 0.55, (0, 0, 255), 2)
181
+
182
+ # Write processed frame to output video
 
 
 
 
 
 
 
 
 
183
  output_video.write(frame)
184
  frame_count += 1
 
185
 
186
  video.release()
187
  output_video.release()
188
+
189
  return temp_output_path
190
 
191
  except Exception as e:
192
  return None, f"An error occurred: {e}"
193
 
194
+ # ========== Gradio Interface ==========
195
  with gr.Blocks(theme=gr.themes.Base(primary_hue="teal", secondary_hue="teal", neutral_hue="slate")) as iface:
196
  gr.Markdown("""<div style="text-align: center;"><h1>NESTLE - STOCK COUNTING</h1></div>""")
197
+
198
  with gr.Row():
199
  with gr.Column():
200
  input_image = gr.Image(type="pil", label="Input Image")
 
202
  output_image = gr.Image(label="Detect Object")
203
  output_text = gr.Textbox(label="Counting Object")
204
  detect_image_button.click(fn=detect_combined, inputs=input_image, outputs=[output_image, output_text])
205
+
206
  with gr.Column():
207
  input_video = gr.Video(label="Input Video")
208
  detect_video_button = gr.Button("Detect Video")
209
  output_video = gr.Video(label="Output Video")
210
  detect_video_button.click(fn=detect_objects_in_video, inputs=input_video, outputs=[output_video])
211
+
212
  iface.launch()