# Page header carried over from the hosting site's file view (not part of the program):
# app.py - "Update app.py" by muhammadsalmanalfaridzi, commit 52b0cb8 (verified), 4.31 kB
import gradio as gr
import tempfile
import os
import cv2
from sahi import AutoDetectionModel
from sahi.predict import get_sliced_prediction
from PIL import Image
import numpy as np
# Settings for the SAHI-wrapped detection model.
sahi_device = "cuda"  # Switch to 'cpu' when no GPU is used
confidence_threshold = 0.6  # Confidence threshold
model_path = "best.pt"  # Replace with the path to your local YOLO weights

# Load the YOLO model through SAHI.
sahi_model = AutoDetectionModel.from_pretrained(
    model_type="yolov11",  # YOLO model type; adjust if a different YOLO model is used
    model_path=model_path,
    confidence_threshold=confidence_threshold,
    device=sahi_device,
)
def _save_temp_jpeg(pil_image):
    """Write ``pil_image`` to a uniquely named JPEG temp file and return its path.

    A fresh file name per call keeps concurrent requests from overwriting each
    other's result. The file is intentionally left on disk so the caller can
    hand the path to the UI.
    """
    # JPEG cannot store alpha or palette modes, so normalise to RGB first.
    if pil_image.mode != "RGB":
        pil_image = pil_image.convert("RGB")
    handle = tempfile.NamedTemporaryFile(delete=False, suffix=".jpg")
    try:
        with handle:
            pil_image.save(handle, format="JPEG")
    except Exception:
        # Do not leave a half-written file behind when encoding fails.
        os.remove(handle.name)
        raise
    return handle.name


def _format_counts(class_count):
    """Build the per-class / total summary text from a ``{class_name: count}`` dict."""
    lines = ["Detected Objects:", ""]
    lines.extend(f"{name}: {count}" for name, count in class_count.items())
    lines.extend(["", f"Total Objects: {sum(class_count.values())}"])
    return "\n".join(lines)


# Object detection using SAHI
def detect_objects(image):
    """Run sliced (SAHI) detection on ``image``, draw the boxes and count objects.

    Args:
        image: PIL image coming from the input component, or ``None`` when no
            image was supplied.

    Returns:
        A ``(image_path, text)`` tuple.

        * Success: ``image_path`` is a JPEG file with green boxes and
          ``"<class> <score>"`` labels drawn on it; ``text`` lists the count
          per class and the total.
        * Failure: ``text`` is ``"An error occurred: ..."`` and ``image_path``
          is a JPEG copy of the input image, or ``None`` when no image is
          available or it could not be saved.
    """
    if image is None:
        return None, "An error occurred: no input image was provided"

    try:
        # Both the detector input and the JPEG output are expected to be
        # 3-channel; convert once up front.
        if image.mode != "RGB":
            image = image.convert("RGB")

        # Run the sliced prediction on the image.
        results = get_sliced_prediction(
            image=image,
            detection_model=sahi_model,
            slice_height=512,  # slice size (can be tuned)
            slice_width=512,
            overlap_height_ratio=0.2,
            overlap_width_ratio=0.2,
        )

        # np.array() copies, so drawing does not touch the caller's image.
        canvas = np.array(image)
        class_count = {}
        font = cv2.FONT_HERSHEY_SIMPLEX
        font_scale = 0.9
        thickness = 2
        green = (0, 255, 0)

        for prediction in results.object_prediction_list:
            confidence = prediction.score.value
            # Only draw / count predictions at or above the threshold.
            if confidence < confidence_threshold:
                continue

            bbox = prediction.bbox
            class_name = prediction.category.name
            left, top = int(bbox.minx), int(bbox.miny)
            right, bottom = int(bbox.maxx), int(bbox.maxy)
            cv2.rectangle(canvas, (left, top), (right, bottom), green, thickness)

            label = f"{class_name} {confidence:.2f}"
            # Keep the label baseline low enough that the text stays on the
            # canvas when the box touches the top edge.
            (_, text_height), _ = cv2.getTextSize(label, font, font_scale, thickness)
            label_y = max(top - 10, text_height)
            cv2.putText(canvas, label, (left, label_y), font, font_scale, green, thickness)

            class_count[class_name] = class_count.get(class_name, 0) + 1

        result_text = _format_counts(class_count)
        output_image_path = _save_temp_jpeg(Image.fromarray(canvas))
    except Exception as err:
        # UI boundary: report the failure as text instead of crashing the request.
        result_text = f"An error occurred: {err}"
        try:
            # Fall back to showing the original image; the file must still
            # exist when the caller reads it, so it is not removed here.
            output_image_path = _save_temp_jpeg(image)
        except Exception:
            output_image_path = None

    return output_image_path, result_text
# Build the Gradio interface with a flexible layout.
# NOTE(review): the nesting below is reconstructed from the statement order;
# confirm the button is meant to sit below the row rather than inside it.
with gr.Blocks() as iface:
    with gr.Row():
        with gr.Column():
            input_image = gr.Image(type="pil", label="Input Image")
        with gr.Column():
            output_image = gr.Image(label="Detect Object")
        with gr.Column():
            output_text = gr.Textbox(label="Counting Object")

    # Button that triggers processing of the input.
    detect_button = gr.Button("Detect")

    # Connect the button to the detection function.
    detect_button.click(
        detect_objects,
        inputs=input_image,
        outputs=[output_image, output_text],
    )

# Launch the interface.
iface.launch()