Spaces:
Runtime error
Daniel Cerda Escobar
committed on
Commit · 8f2e27c
1 Parent(s): f78ddd2
Define model
utils.py
ADDED
@@ -0,0 +1,54 @@
import numpy
import sahi.predict
import sahi.utils.cv
from PIL import Image

TEMP_DIR = "temp"


def sahi_mmdet_inference(
    image,
    detection_model,
    slice_height=512,
    slice_width=512,
    overlap_height_ratio=0.2,
    overlap_width_ratio=0.2,
    image_size=640,
    postprocess_type="GREEDYNMM",
    postprocess_match_metric="IOS",
    postprocess_match_threshold=0.5,
    postprocess_class_agnostic=False,
):
    """Run standard and sliced (SAHI) inference with an MMDetection model.

    Returns two PIL images: the standard-inference visualization and the
    sliced-inference visualization.
    """
    # standard inference over the full image
    detection_model.image_size = image_size
    prediction_result_1 = sahi.predict.get_prediction(
        image=image, detection_model=detection_model
    )
    visual_result_1 = sahi.utils.cv.visualize_object_predictions(
        image=numpy.array(image),
        object_prediction_list=prediction_result_1.object_prediction_list,
    )
    output_1 = Image.fromarray(visual_result_1["image"])

    # sliced inference: tile the image into overlapping slices and merge
    # the per-slice predictions with the selected postprocess
    prediction_result_2 = sahi.predict.get_sliced_prediction(
        image=image,
        detection_model=detection_model,
        slice_height=slice_height,
        slice_width=slice_width,
        overlap_height_ratio=overlap_height_ratio,
        overlap_width_ratio=overlap_width_ratio,
        postprocess_type=postprocess_type,
        postprocess_match_metric=postprocess_match_metric,
        postprocess_match_threshold=postprocess_match_threshold,
        postprocess_class_agnostic=postprocess_class_agnostic,
    )
    visual_result_2 = sahi.utils.cv.visualize_object_predictions(
        image=numpy.array(image),
        object_prediction_list=prediction_result_2.object_prediction_list,
    )
    output_2 = Image.fromarray(visual_result_2["image"])

    return output_1, output_2
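
For reference, here is a minimal sketch of how this helper could be called from an app script. It assumes a recent SAHI release that exposes AutoDetectionModel.from_pretrained for loading MMDetection weights; the checkpoint, config, and image paths are placeholders rather than files from this Space.

# Minimal usage sketch (assumption: a recent SAHI version providing
# AutoDetectionModel; every file path below is a placeholder).
from PIL import Image
from sahi import AutoDetectionModel

from utils import sahi_mmdet_inference

detection_model = AutoDetectionModel.from_pretrained(
    model_type="mmdet",
    model_path="checkpoints/model.pth",     # placeholder checkpoint path
    config_path="configs/model_config.py",  # placeholder MMDetection config
    confidence_threshold=0.4,
    device="cpu",
)

image = Image.open("demo.jpg")  # placeholder input image

# run both standard and sliced inference, then save the visualizations
standard_vis, sliced_vis = sahi_mmdet_inference(image, detection_model)
standard_vis.save("standard_prediction.png")
sliced_vis.save("sliced_prediction.png")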