Asma committed on
Commit
6cb29f6
·
1 Parent(s): 8e3f20a

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +78 -0
app.py ADDED
@@ -0,0 +1,78 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import torch
3
+ from sahi.prediction import ObjectPrediction
4
+ from sahi.utils.cv import visualize_object_predictions, read_image
5
+ from ultralyticsplus import YOLO
6
+
7
+
8
def yolov8_inference(
    image: gr.inputs.Image = None,
    model_path: gr.inputs.Dropdown = None,
    image_size: gr.inputs.Slider = 640,
    conf_threshold: gr.inputs.Slider = 0.25,
    iou_threshold: gr.inputs.Slider = 0.25,
):
    """
    Run YOLOv8 object detection on one image and render the predictions.

    Args:
        image: Path to the input image (the gradio Image input is
            declared with type="filepath", so this is a file path).
        model_path: Hub id / path of the YOLOv8 model to load.
        image_size: Inference image size passed to the model.
        conf_threshold: Minimum confidence for a detection to be kept.
        iou_threshold: IOU threshold used by non-max suppression.

    Returns:
        The input image with detections drawn on it (numpy array, the
        'image' entry of sahi's visualize_object_predictions result).
    """
    # NOTE(review): the model is re-loaded on every request, which is slow;
    # consider caching loaded models keyed by model_path if latency matters.
    model = YOLO(model_path)
    model.conf = conf_threshold
    model.iou = iou_threshold
    results = model.predict(image, imgsz=image_size, return_outputs=True)

    object_prediction_list = []
    for image_results in results:  # one entry per input image
        # 'det' rows are [x1, y1, x2, y2, score, class_id] in pixel coords
        for pred in image_results['det']:
            x1, y1, x2, y2 = (
                int(pred[0]),
                int(pred[1]),
                int(pred[2]),
                int(pred[3]),
            )
            # Convert the class id once and reuse it for both the id and
            # the name lookup (the original converted pred[5] twice).
            category_id = int(pred[5])
            object_prediction_list.append(
                ObjectPrediction(
                    bbox=[x1, y1, x2, y2],
                    category_id=category_id,
                    # Cast to a plain Python float: pred[4] is a framework
                    # scalar (numpy/tensor) and sahi serialization expects
                    # a native float.
                    score=float(pred[4]),
                    category_name=model.model.names[category_id],
                )
            )

    image = read_image(image)
    output_image = visualize_object_predictions(
        image=image, object_prediction_list=object_prediction_list
    )
    return output_image['image']
55
+
56
+
57
# Gradio UI wiring for the YOLOv8 demo. Input order must match the
# yolov8_inference signature: image, model_path, image_size,
# conf_threshold, iou_threshold.
inputs = [
    gr.inputs.Image(type="filepath", label="Input Image"),
    gr.inputs.Dropdown(
        ["Asma/GreenHawk_test"],
        default="Asma/GreenHawk_test",
        label="Model",
    ),
    gr.inputs.Slider(minimum=320, maximum=1280, default=640, step=32, label="Image Size"),
    gr.inputs.Slider(minimum=0.0, maximum=1.0, default=0.25, step=0.05, label="Confidence Threshold"),
    gr.inputs.Slider(minimum=0.0, maximum=1.0, default=0.25, step=0.05, label="IOU Threshold"),
]

# yolov8_inference returns a rendered numpy array, so the output component
# must be "numpy" — the original "filepath" mismatched the return value.
outputs = gr.outputs.Image(type="numpy", label="Output Image")
title = "GreenHawk - Visual Pollution Detection"

demo_app = gr.Interface(
    fn=yolov8_inference,
    inputs=inputs,
    outputs=outputs,
    title=title,
    # examples=examples,
    # cache_examples requires a non-empty `examples` list; it was enabled
    # while examples were commented out, which errors at startup in gradio.
    cache_examples=False,
    theme='huggingface',
)
demo_app.launch(enable_queue=True, share=True)