File size: 4,258 Bytes
ac831c4
5a61493
b34cc48
 
 
5a61493
f5c0946
c66a6f5
426695e
ac831c4
5a61493
 
 
 
 
 
 
 
 
 
 
ac831c4
7e0a954
ac831c4
5a61493
ac831c4
 
 
 
5a61493
93307f9
 
 
 
5a61493
93307f9
 
 
 
e49837a
93307f9
 
 
5a61493
93307f9
 
c66a6f5
93307f9
426695e
93307f9
426695e
 
93307f9
426695e
 
93307f9
426695e
 
93307f9
426695e
 
93307f9
 
 
 
 
 
 
 
 
 
 
 
 
 
 
5a61493
93307f9
5a61493
93307f9
5a61493
93307f9
5a61493
93307f9
5a61493
93307f9
ac831c4
 
5a61493
5f0c190
93307f9
5a61493
 
 
 
 
93307f9
5a61493
93307f9
5a61493
93307f9
 
5a61493
93307f9
5a61493
 
 
 
 
ac831c4
93307f9
5a61493
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
import os
import tempfile
from collections import Counter

import gradio as gr
import numpy as np  # Import numpy to handle image slices
import requests
import supervision as sv  # For annotating images with results
from dotenv import load_dotenv
from roboflow import Roboflow
from sahi.predict import get_sliced_prediction  # SAHI slicing inference

# Load environment variables from the .env file
load_dotenv()
api_key = os.getenv("ROBOFLOW_API_KEY")
workspace = os.getenv("ROBOFLOW_WORKSPACE")
project_name = os.getenv("ROBOFLOW_PROJECT")

# The model version must be an integer; fail with a readable message instead
# of the opaque TypeError that int(None) raises when the variable is unset.
_raw_model_version = os.getenv("ROBOFLOW_MODEL_VERSION")
if _raw_model_version is None:
    raise ValueError(
        "ROBOFLOW_MODEL_VERSION is not set; define it in the environment "
        "or in the .env file."
    )
try:
    model_version = int(_raw_model_version)
except ValueError as err:
    raise ValueError(
        f"ROBOFLOW_MODEL_VERSION must be an integer, got {_raw_model_version!r}"
    ) from err

# Initialize Roboflow with the values read from the environment/secrets
rf = Roboflow(api_key=api_key)
project = rf.workspace(workspace).project(project_name)
model = project.version(model_version).model

def _detection_labels(detections):
    """Return one display label per detection.

    Uses the ``class_name`` entries stored in ``detections.data`` when there
    is one per detection, otherwise falls back to the numeric ``class_id``
    values. Returns an empty list when neither is available.
    """
    class_names = detections.data.get("class_name")
    if class_names is not None and len(class_names) == len(detections):
        return [str(name) for name in class_names]
    if detections.class_id is not None:
        return [int(class_id) for class_id in detections.class_id]
    return []


def _format_detection_counts(labels, total_count):
    """Build the per-class summary text shown next to the annotated image."""
    per_class = "".join(
        f"{label}: {count}\n" for label, count in Counter(labels).items()
    )
    return (
        "Detected Objects:\n\n"
        f"{per_class}"
        f"\nTotal objects detected: {total_count}"
    )


# Handle an uploaded image: run sliced detection and return image + summary
def detect_objects(image):
    """Run sliced object detection on an uploaded image.

    Args:
        image: PIL image supplied by the Gradio input component, or ``None``
            when nothing has been uploaded.

    Returns:
        A ``(output_image, result_text)`` tuple. On success ``output_image``
        is the annotated image and ``result_text`` lists the number of
        detections per class plus the total. On failure ``output_image`` is
        the unmodified input image and ``result_text`` describes the error.
        When ``image`` is ``None`` the result is ``(None, <message>)``.
    """
    if image is None:
        return None, "No image provided. Please upload an image first."

    try:
        def callback(image_slice: np.ndarray) -> sv.Detections:
            # NOTE(review): `infer` is the API of models from the `inference`
            # package; confirm the Roboflow SDK model created at module level
            # exposes it (the SDK documents `predict`).
            results = model.infer(image_slice)[0]  # Inference on one slice
            return sv.Detections.from_inference(results)

        # Configure the slicer with specific slice dimensions and overlap
        slicer = sv.InferenceSlicer(
            callback=callback,
            slice_wh=(320, 320),  # Adjust slice dimensions as needed
            overlap_wh=(64, 64),  # Overlap in pixels (do not also pass overlap_ratio_wh)
            overlap_filter=sv.OverlapFilter.NON_MAX_SUPPRESSION,  # Merge duplicate boxes
            iou_threshold=0.5,  # Intersection over Union threshold for NMS
        )

        # The slicer crops numpy arrays, so convert the PIL image first.
        # NOTE(review): channels are flipped to BGR, the OpenCV convention
        # assumed for numpy images by the inference stack; confirm for the
        # deployed model.
        rgb_pixels = np.asarray(image.convert("RGB"))
        frame = np.ascontiguousarray(rgb_pixels[:, :, ::-1])
        detections = slicer(frame)

        # Annotate a copy so the caller's image is left untouched
        annotated_image = sv.BoxAnnotator().annotate(
            scene=image.copy(), detections=detections
        )
        annotated_image = sv.LabelAnnotator().annotate(
            scene=annotated_image, detections=detections
        )

        result_text = _format_detection_counts(
            _detection_labels(detections), len(detections)
        )
    except requests.exceptions.HTTPError as http_err:
        # Show the original image together with the HTTP failure
        return image, f"HTTP error occurred: {http_err}"
    except Exception as err:
        # UI boundary: report any other failure to the user instead of crashing
        return image, f"An error occurred: {err}"

    # Return the image object itself: no shared output file on disk, so
    # concurrent requests cannot overwrite each other's results.
    return annotated_image, result_text

# Build the Gradio interface: input image, annotated result and object counts
# are laid out side by side in a single row.
with gr.Blocks() as iface:
    with gr.Row():
        with gr.Column():
            input_image = gr.Image(type="pil", label="Input Image")
        with gr.Column():
            output_image = gr.Image(label="Detected Objects")
        with gr.Column():
            output_text = gr.Textbox(label="Object Count")

    # Button to trigger object detection
    detect_button = gr.Button("Detect Objects")

    # Run detect_objects on the uploaded image and show both of its outputs
    detect_button.click(
        fn=detect_objects,
        inputs=input_image,
        outputs=[output_image, output_text],
    )

# Start the server only when executed as a script, so importing this module
# (e.g. to reuse detect_objects) does not launch the web UI as a side effect.
if __name__ == "__main__":
    iface.launch()