muhammadsalmanalfaridzi committed on
Commit
5f4e276
·
verified ·
1 Parent(s): 83eff9f

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +106 -74
app.py CHANGED
@@ -1,76 +1,108 @@
1
  import gradio as gr
2
- import supervision as sv
3
- import numpy as np
4
- import cv2
5
- from inference import get_roboflow_model
6
-
7
- # Replace with your actual Roboflow model ID and API key
8
- model_id = "nescafe-4base/46" # Replace with your Roboflow model ID
9
- api_key = "Otg64Ra6wNOgDyjuhMYU" # Replace with your Roboflow API key
10
-
11
- # Load the Roboflow model using the get_roboflow_model function
12
- model = get_roboflow_model(model_id=model_id, api_key=api_key)
13
-
14
- # Define the callback function for the SAHI slicer
15
- def callback(image_slice: np.ndarray) -> sv.Detections:
16
- # Run inference on the image slice
17
- results = model.infer(image_slice)
18
-
19
- # Check if results are in the expected format and handle accordingly
20
- if isinstance(results, tuple):
21
- results = results[0] # Extract the detections from the tuple if necessary
22
-
23
- # If the results are a list (likely from Roboflow), access them correctly
24
- detections = []
25
- if isinstance(results, list):
26
- for result in results:
27
- # Ensure each result is processed into a Detections object
28
- detections.extend(sv.Detections.from_inference(result))
29
-
30
- # Return the list of detections
31
- return detections
32
-
33
- # Initialize the SAHI Inference Slicer
34
- slicer = sv.InferenceSlicer(callback=callback)
35
-
36
- # Function to handle image processing, inference, and annotation
37
- def process_image(image):
38
- # Convert the PIL image to OpenCV format (BGR)
39
- image = np.array(image)
40
- image = cv2.cvtColor(image, cv2.COLOR_RGB2BGR)
41
-
42
- # Run inference using SAHI (splitting the image into slices)
43
- sliced_detections = slicer(image=image)
44
-
45
- # Annotate the detections with bounding boxes and labels
46
- label_annotator = sv.LabelAnnotator()
47
- box_annotator = sv.BoxAnnotator()
48
-
49
- annotated_image = box_annotator.annotate(scene=image.copy(), detections=sliced_detections)
50
- annotated_image = label_annotator.annotate(scene=annotated_image, detections=sliced_detections)
51
-
52
- # Convert the annotated image back to RGB for display in Gradio
53
- result_image = cv2.cvtColor(annotated_image, cv2.COLOR_BGR2RGB)
54
-
55
- # Count the number of objects detected
56
  class_count = {}
57
- for detection in sliced_detections:
58
- class_name = detection.class_name
59
- class_count[class_name] = class_count.get(class_name, 0) + 1
60
-
61
- total_count = sum(class_count.values())
62
-
63
- return result_image, class_count, total_count
64
-
65
- # Gradio interface
66
- iface = gr.Interface(
67
- fn=process_image,
68
- inputs=gr.Image(type="pil", label="Upload Image"),
69
- outputs=[gr.Image(type="pil", label="Annotated Image"),
70
- gr.JSON(label="Object Count"),
71
- gr.Number(label="Total Objects Detected")],
72
- live=True
73
- )
74
-
75
- # Launch the Gradio interface
76
- iface.launch()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  import gradio as gr
2
+ from dotenv import load_dotenv
3
+ from roboflow import Roboflow
4
+ import tempfile
5
+ import os
6
+ import requests
7
+ from PIL import Image
8
+
9
# Load environment variables from a .env file (Roboflow credentials/config).
load_dotenv()
api_key = os.getenv("ROBOFLOW_API_KEY")
workspace = os.getenv("ROBOFLOW_WORKSPACE")
project_name = os.getenv("ROBOFLOW_PROJECT")
# NOTE(review): int(None) raises TypeError if ROBOFLOW_MODEL_VERSION is unset —
# consider failing with a clearer message.
model_version = int(os.getenv("ROBOFLOW_MODEL_VERSION"))

# Initialise the Roboflow client and resolve the hosted model version
# from the values read out of the secrets above.
rf = Roboflow(api_key=api_key)
project = rf.workspace(workspace).project(project_name)
model = project.version(model_version).model
20
+
21
# Compute crop boxes that tile an image into pieces of at most slice_size px.
def slice_image(image, slice_size=512, overlap=0):
    """Return ``(left, top, right, bottom)`` crop boxes covering *image*.

    Args:
        image: any object with a PIL-style ``size`` attribute ``(width, height)``.
        slice_size: maximum tile edge length, in pixels.
        overlap: pixels shared between adjacent tiles; must be smaller than
            ``slice_size``.

    Returns:
        list[tuple[int, int, int, int]]: crop boxes in PIL ``Image.crop`` order,
        scanning rows top-to-bottom, columns left-to-right. Edge tiles are
        clamped to the image bounds, so they may be smaller than ``slice_size``.

    Raises:
        ValueError: if ``overlap >= slice_size`` — the scan step would be
            zero or negative, which previously made ``range`` either raise
            or yield no tiles at all.
    """
    width, height = image.size
    step = slice_size - overlap
    if step <= 0:
        raise ValueError("overlap must be smaller than slice_size")

    slices = []
    for top in range(0, height, step):
        for left in range(0, width, step):
            right = min(left + slice_size, width)
            bottom = min(top + slice_size, height)
            slices.append((left, top, right, bottom))

    return slices
35
+
36
# Run tiled inference over the uploaded image and count detections per class.
def detect_objects(image):
    """Detect objects on *image* by slicing it into overlapping tiles.

    Each tile is written to a temporary JPEG (the Roboflow client predicts
    from a file path), sent to the hosted model, and the returned
    predictions are shifted back into full-image coordinates and tallied.

    Args:
        image: PIL.Image supplied by the Gradio input component.

    Returns:
        tuple: ``(image, summary_text)`` on success. On failure the image
        slot is ``None`` and the text slot carries the error message —
        previously the order was reversed, which routed the error string
        into the Image output and ``None`` into the Textbox.
    """
    slice_size = 512
    # Overlap tiles so objects straddling a tile border are still seen
    # whole by at least one tile.
    overlap = 50

    slices = slice_image(image, slice_size, overlap)
    results = []

    class_count = {}
    total_count = 0

    for left, top, right, bottom in slices:
        sliced_image = image.crop((left, top, right, bottom))

        # Persist the tile: model.predict() takes a file path, not an image.
        with tempfile.NamedTemporaryFile(delete=False, suffix=".jpg") as temp_file:
            sliced_image.save(temp_file, format="JPEG")
            temp_file_path = temp_file.name

        try:
            # Run inference on this tile.
            predictions = model.predict(temp_file_path, confidence=60, overlap=80).json()

            for prediction in predictions['predictions']:
                # Translate tile-local coordinates into full-image space.
                # NOTE(review): assumes the API returns left/top/right/bottom
                # keys; Roboflow commonly returns x/y/width/height — confirm.
                prediction["left"] += left
                prediction["top"] += top
                prediction["right"] += left
                prediction["bottom"] += top

                results.append(prediction)

                # Tally per-class and overall counts.
                class_name = prediction['class']
                class_count[class_name] = class_count.get(class_name, 0) + 1
                total_count += 1
        except requests.exceptions.HTTPError as http_err:
            # BUGFIX: return (image, text) order so the message reaches the
            # Textbox output instead of the Image output.
            return None, f"HTTP error occurred: {http_err}"
        except Exception as err:
            return None, f"An error occurred: {err}"
        finally:
            # Always delete the temporary tile file, even on failure.
            os.remove(temp_file_path)

    # Assemble the per-class summary text.
    result_text = "Product Nestle\n\n"
    for class_name, count in class_count.items():
        result_text += f"{class_name}: {count}\n"
    result_text += f"\nTotal Product Nestle: {total_count}"

    # NOTE(review): the returned image is the unmodified input — the boxes
    # collected in `results` are never drawn on it.
    return image, result_text
86
+
87
# Build the Gradio interface with a flexible row/column layout:
# input image on the left, annotated image and count summary alongside.
with gr.Blocks() as iface:
    with gr.Row():
        with gr.Column():
            input_image = gr.Image(type="pil", label="Input Image")
        with gr.Column():
            output_image = gr.Image(label="Detect Object")
        with gr.Column():
            output_text = gr.Textbox(label="Counting Object")

    # Button that triggers processing of the uploaded image.
    detect_button = gr.Button("Detect")

    # Wire the button to the detection function; outputs must stay in
    # (image, text) order to match detect_objects' return tuple.
    detect_button.click(
        fn=detect_objects,
        inputs=input_image,
        outputs=[output_image, output_text]
    )

# Launch the interface.
iface.launch()