VlaTal commited on
Commit
2c9a482
β€’
1 Parent(s): edb09e9

Added model and threshold choosing

Browse files
app.py CHANGED
@@ -1,30 +1,57 @@
 
1
  import gradio as gr
2
  from ultralytics import YOLO
3
  import numpy as np
4
- import os
5
-
6
- # Load YOLO model
7
- model = YOLO('./best.pt')
8
 
 
 
 
9
  example_list = [["examples/" + example] for example in os.listdir("examples")]
10
 
11
- def process_image(input_image):
12
- if input_image is not None:
13
- results = model(input_image)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
14
 
15
- for r in results:
16
- im_array = r.plot()
17
- im_array = im_array.astype(np.uint8)
18
- return im_array
19
 
20
- # Create Gradio Interface
21
  iface = gr.Interface(
22
  fn=process_image,
23
- inputs=gr.Image(),
24
- outputs=gr.Image(), # Specify output as Gradio Image
 
 
 
 
25
  title="YOLOv8-obb aerial detection",
26
- description="YOLOv8-obb trained on DOTAv1.5",
27
- examples=example_list)
 
28
 
29
- # Launch the Gradio interface
30
- iface.launch()
 
1
+ import os
2
  import gradio as gr
3
  from ultralytics import YOLO
4
  import numpy as np
 
 
 
 
5
 
6
+ model_options = ["yolo-8n-dota.pt", "yolo-8s-dota.pt", "yolo-8m-dota.pt"]
7
+ model_names = ["Nano", "Small", "Medium"]
8
+ models = [YOLO(option) for option in model_options]
9
  example_list = [["examples/" + example] for example in os.listdir("examples")]
10
 
11
+ def process_image(input_image, model_name, conf):
12
+ if input_image is None:
13
+ return None, "No image"
14
+
15
+ if model_name is None:
16
+ model_name = model_names[0]
17
+
18
+ if conf is None:
19
+ conf = 0.6
20
+
21
+ model_index = model_names.index(model_name)
22
+ model = models[model_index]
23
+
24
+ results = model.predict(input_image, conf=conf)
25
+ class_counts = {}
26
+ class_counts_str = "Class Counts:\n"
27
+
28
+ for r in results:
29
+ im_array = r.plot()
30
+ im_array = im_array.astype(np.uint8)
31
+
32
+ if len(r.obb.cls) == 0: # If no objects are detected
33
+ return None, "No objects detected."
34
+
35
+ for cls in r.obb.cls:
36
+ class_name = r.names[cls.item()]
37
+ class_counts[class_name] = class_counts.get(class_name, 0) + 1
38
+
39
+ for cls, count in class_counts.items():
40
+ class_counts_str += f"\n{cls}: {count}"
41
 
42
+ return im_array, class_counts_str
 
 
 
43
 
 
44
  iface = gr.Interface(
45
  fn=process_image,
46
+ inputs=[
47
+ gr.Image(),
48
+ gr.Radio(model_names, label="Choose model", value=model_names[0]),
49
+ gr.Slider(minimum=0.2, maximum=1.0, step=0.1, label="Confidence Threshold", value=0.6)
50
+ ],
51
+ outputs=["image", gr.Textbox(label="More info")],
52
  title="YOLOv8-obb aerial detection",
53
+ description='''YOLOv8-obb trained on DOTAv1.5''',
54
+ examples=example_list
55
+ )
56
 
57
+ iface.launch()
 
flagged/log.csv DELETED
@@ -1,2 +0,0 @@
1
- name,intensity,output,flag,username,timestamp
2
- aasd,99,"Hello, aasd!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!",,,2024-02-20 13:01:37.633520
 
 
 
best.pt β†’ yolo-8m-dota.pt RENAMED
File without changes
yolo-8n-dota.pt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:4ee44108137e10e13a377e2a75175af1476c355e04217cc38e3e8e2f4cb6fd7c
3
+ size 6465538
yolo-8s-dota.pt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:4395070f75af8464e3c4d7e7d83eda61d19d6cd5ea6c62ca1898844a8e0ad54c
3
+ size 23169282