aijack committed
Commit a9ee6d6 · 1 Parent(s): ae4e807

Upload 2 files

Files changed (2)
  1. app.py +128 -0
  2. requirements.txt +5 -0
app.py ADDED
@@ -0,0 +1,128 @@
+ import gradio as gr
+ import sahi.utils.cv
+ import sahi.utils.file
+ from sahi import AutoDetectionModel
+ import sahi.predict
+ import sahi.slicing
+ from PIL import Image
+ import numpy
+
+ IMAGE_SIZE = 640
+
+ # Example image
+ sahi.utils.file.download_from_url(
+     "https://user-images.githubusercontent.com/34196005/142730935-2ace3999-a47b-49bb-83e0-2bdd509f1c90.jpg",
+     "apple_tree.jpg",
+ )
+
+ # Model
+ model = AutoDetectionModel.from_pretrained(
+     model_type="yolov5", model_path="yolov5s6.pt", device="cpu", confidence_threshold=0.5, image_size=IMAGE_SIZE
+ )
+
+
+ def sahi_yolo_inference(
+     image,
+     slice_height=512,
+     slice_width=512,
+     overlap_height_ratio=0.2,
+     overlap_width_ratio=0.2,
+     postprocess_type="NMS",
+     postprocess_match_metric="IOU",
+     postprocess_match_threshold=0.5,
+     postprocess_class_agnostic=False,
+ ):
+     # Guard against slice settings that would produce too many slices
+     # for a CPU-only Space.
+     image_width, image_height = image.size
+     sliced_bboxes = sahi.slicing.get_slice_bboxes(
+         image_height,
+         image_width,
+         slice_height=int(slice_height),
+         slice_width=int(slice_width),
+         auto_slice_resolution=False,
+         overlap_height_ratio=overlap_height_ratio,
+         overlap_width_ratio=overlap_width_ratio,
+     )
+     if len(sliced_bboxes) > 60:
+         raise ValueError(
+             f"{len(sliced_bboxes)} slices are too many for Hugging Face Spaces; try a larger slice size."
+         )
+
+     # Standard (full-image) inference
+     prediction_result_1 = sahi.predict.get_prediction(
+         image=image, detection_model=model
+     )
+     visual_result_1 = sahi.utils.cv.visualize_object_predictions(
+         image=numpy.array(image),
+         object_prediction_list=prediction_result_1.object_prediction_list,
+     )
+     output_1 = Image.fromarray(visual_result_1["image"])
+
+     # Sliced (SAHI) inference
+     prediction_result_2 = sahi.predict.get_sliced_prediction(
+         image=image,
+         detection_model=model,
+         slice_height=int(slice_height),
+         slice_width=int(slice_width),
+         overlap_height_ratio=overlap_height_ratio,
+         overlap_width_ratio=overlap_width_ratio,
+         postprocess_type=postprocess_type,
+         postprocess_match_metric=postprocess_match_metric,
+         postprocess_match_threshold=postprocess_match_threshold,
+         postprocess_class_agnostic=postprocess_class_agnostic,
+     )
+     visual_result_2 = sahi.utils.cv.visualize_object_predictions(
+         image=numpy.array(image),
+         object_prediction_list=prediction_result_2.object_prediction_list,
+     )
+     output_2 = Image.fromarray(visual_result_2["image"])
+
+     return output_1, output_2
+
+
+ inputs = [
+     gr.Image(type="pil", label="Original Image"),
+     gr.Number(value=512, label="slice_height"),
+     gr.Number(value=512, label="slice_width"),
+     gr.Number(value=0.2, label="overlap_height_ratio"),
+     gr.Number(value=0.2, label="overlap_width_ratio"),
+     gr.Dropdown(
+         ["NMS", "GREEDYNMM"],
+         type="value",
+         value="NMS",
+         label="postprocess_type",
+     ),
+     gr.Dropdown(
+         ["IOU", "IOS"], type="value", value="IOU", label="postprocess_match_metric"
+     ),
+     gr.Number(value=0.5, label="postprocess_match_threshold"),
+     gr.Checkbox(value=True, label="postprocess_class_agnostic"),
+ ]
+
+ outputs = [
+     gr.Image(type="pil", label="YOLOv5s"),
+     gr.Image(type="pil", label="YOLOv5s + SAHI"),
+ ]
+
+ title = "Small Object Detection with SAHI + YOLOv5"
+ description = "SAHI + YOLOv5 demo for small object detection. Upload an image or click an example image to run the demo."
+ article = "<p style='text-align: center'><a href='http://claireye.com.tw'>Claireye</a> | 2023</p>"
+ examples = [
+     ["apple_tree.jpg", 256, 256, 0.2, 0.2, "NMS", "IOU", 0.4, True]
+ ]
+
+ gr.Interface(
+     sahi_yolo_inference,
+     inputs,
+     outputs,
+     title=title,
+     description=description,
+     article=article,
+     examples=examples,
+     theme="huggingface",
+     cache_examples=True,
+ ).launch(debug=True, enable_queue=True)
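
For reference, the sliced-inference call at the heart of app.py can also be run outside Gradio. Below is a minimal standalone sketch under the same assumptions as the app (the yolov5s6.pt weights and apple_tree.jpg already downloaded); it illustrates the SAHI API used above and is not part of the commit.

    # Standalone sketch of the sliced-inference path used in app.py.
    # Assumes yolov5s6.pt and apple_tree.jpg are present (see the
    # download_from_url call above).
    from sahi import AutoDetectionModel
    from sahi.predict import get_sliced_prediction

    model = AutoDetectionModel.from_pretrained(
        model_type="yolov5",
        model_path="yolov5s6.pt",
        device="cpu",
        confidence_threshold=0.5,
        image_size=640,
    )

    # SAHI accepts an image path, a PIL image, or a numpy array.
    result = get_sliced_prediction(
        "apple_tree.jpg",
        detection_model=model,
        slice_height=256,
        slice_width=256,
        overlap_height_ratio=0.2,
        overlap_width_ratio=0.2,
    )
    print(f"{len(result.object_prediction_list)} objects detected")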
requirements.txt ADDED
@@ -0,0 +1,5 @@
+ torch==1.10.2+cpu
+ torchvision==0.11.3+cpu
+ -f https://download.pytorch.org/whl/torch_stable.html
+ yolov5==7.0.8
+ sahi==0.11.11
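
Note: the -f line adds the PyTorch wheel index so pip can resolve the +cpu builds of torch and torchvision; option lines in a requirements file apply file-wide, so its position among the pins does not matter. Installing with pip install -r requirements.txt pulls everything in one step.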