Daniel Cerda Escobar commited on
Commit
8f2e27c
·
1 Parent(s): f78ddd2

Define model

Browse files
Files changed (1) hide show
  1. utils.py +54 -0
utils.py ADDED
@@ -0,0 +1,54 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import numpy
2
+ import sahi.predict
3
+ import sahi.utils
4
+ from PIL import Image
5
+
6
+ TEMP_DIR = "temp"
7
+
8
+
9
+ def sahi_mmdet_inference(
10
+ image,
11
+ detection_model,
12
+ slice_height=512,
13
+ slice_width=512,
14
+ overlap_height_ratio=0.2,
15
+ overlap_width_ratio=0.2,
16
+ image_size=640,
17
+ postprocess_type="GREEDYNMM",
18
+ postprocess_match_metric="IOS",
19
+ postprocess_match_threshold=0.5,
20
+ postprocess_class_agnostic=False,
21
+ ):
22
+
23
+ # standard inference
24
+ detection_model.image_size = image_size
25
+ prediction_result_1 = sahi.predict.get_prediction(
26
+ image=image, detection_model=detection_model
27
+ )
28
+ visual_result_1 = sahi.utils.cv.visualize_object_predictions(
29
+ image=numpy.array(image),
30
+ object_prediction_list=prediction_result_1.object_prediction_list,
31
+ )
32
+ output_1 = Image.fromarray(visual_result_1["image"])
33
+
34
+ # sliced inference
35
+ prediction_result_2 = sahi.predict.get_sliced_prediction(
36
+ image=image,
37
+ detection_model=detection_model,
38
+ slice_height=slice_height,
39
+ slice_width=slice_width,
40
+ overlap_height_ratio=overlap_height_ratio,
41
+ overlap_width_ratio=overlap_width_ratio,
42
+ postprocess_type=postprocess_type,
43
+ postprocess_match_metric=postprocess_match_metric,
44
+ postprocess_match_threshold=postprocess_match_threshold,
45
+ postprocess_class_agnostic=postprocess_class_agnostic,
46
+ )
47
+ visual_result_2 = sahi.utils.cv.visualize_object_predictions(
48
+ image=numpy.array(image),
49
+ object_prediction_list=prediction_result_2.object_prediction_list,
50
+ )
51
+
52
+ output_2 = Image.fromarray(visual_result_2["image"])
53
+
54
+ return output_1, output_2