kedimestan commited on
Commit
041db94
1 Parent(s): b48c806

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +46 -34
app.py CHANGED
@@ -3,7 +3,7 @@ import sahi
3
  import torch
4
  from ultralyticsplus import YOLO, render_model_output
5
 
6
- # Download images for the demo
7
  sahi.utils.file.download_from_url(
8
  "https://raw.githubusercontent.com/kadirnar/dethub/main/data/images/highway.jpg",
9
  "highway.jpg",
@@ -17,7 +17,7 @@ sahi.utils.file.download_from_url(
17
  "zidane.jpg",
18
  )
19
 
20
- # Define available YOLOv8 segmentation models
21
  model_names = [
22
  "yolov8n-seg.pt",
23
  "yolov8s-seg.pt",
@@ -26,7 +26,6 @@ model_names = [
26
  "yolov8x-seg.pt",
27
  ]
28
 
29
- # Load the initial YOLOv8 model
30
  current_model_name = "yolov8m-seg.pt"
31
  model = YOLO(current_model_name)
32
 
@@ -38,7 +37,7 @@ def yolov8_inference(
38
  iou_threshold: gr.Slider = 0.45,
39
  ):
40
  """
41
- YOLOv8 inference function
42
  Args:
43
  image: Input image
44
  model_name: Name of the model
@@ -46,68 +45,81 @@ def yolov8_inference(
46
  conf_threshold: Confidence threshold
47
  iou_threshold: IOU threshold
48
  Returns:
49
- Rendered image and mask coordinates with labels
50
  """
51
  global model
52
  global current_model_name
53
- # Switch model if a different one is selected
 
54
  if model_name != current_model_name:
55
  model = YOLO(model_name)
56
  current_model_name = model_name
57
 
58
- # Set model confidence and IOU thresholds
59
  model.overrides["conf"] = conf_threshold
60
  model.overrides["iou"] = iou_threshold
61
 
62
- # Perform inference with the YOLO model
63
  results = model.predict(image, imgsz=image_size, return_outputs=True)
 
 
 
64
 
65
- masks = []
66
  for result in results:
67
- masks.append([result.masks, result.labels])
 
 
 
 
 
 
 
 
 
 
 
 
68
 
69
- renders = []
70
- for image_results in results:
71
- render = render_model_output(
72
- model=model, image=image, model_output=image_results
73
- )
74
- renders.append(render)
75
 
76
- # Return mask coordinates and labels
77
- return masks
78
-
79
- # Gradio app inputs and outputs
80
  inputs = [
81
  gr.Image(type="filepath", label="Input Image"),
82
- gr.Dropdown(model_names, value=current_model_name, label="Model type"),
 
 
 
 
83
  gr.Slider(minimum=320, maximum=1280, value=640, step=32, label="Image Size"),
84
- gr.Slider(minimum=0.0, maximum=1.0, value=0.25, step=0.05, label="Confidence Threshold"),
 
 
85
  gr.Slider(minimum=0.0, maximum=1.0, value=0.45, step=0.05, label="IOU Threshold"),
86
  ]
87
 
88
- outputs = gr.Textbox(label="Mask Coordinates and Labels")
 
 
 
89
 
90
- # Example inputs for the Gradio app
91
  examples = [
92
  ["zidane.jpg", "yolov8m-seg.pt", 640, 0.6, 0.45],
93
  ["highway.jpg", "yolov8m-seg.pt", 640, 0.25, 0.45],
94
  ["small-vehicles1.jpeg", "yolov8m-seg.pt", 640, 0.25, 0.45],
95
  ]
96
 
97
- # Create the Gradio app interface
98
  demo_app = gr.Interface(
99
  fn=yolov8_inference,
100
  inputs=inputs,
101
  outputs=outputs,
102
- title="Ultralytics YOLOv8 Segmentation Demo",
103
  examples=examples,
104
- cache_examples=True,
 
105
  )
106
 
107
- # Launch the Gradio app
108
- demo_app.launch(
109
- debug=True, # Show detailed errors in case of issues
110
- server_name="0.0.0.0", # Host on all IPs
111
- server_port=7860, # Custom port for accessing the app
112
- share=True # To make the app accessible from a URL
113
- )
 
3
  import torch
4
  from ultralyticsplus import YOLO, render_model_output
5
 
6
+ # Download sample images
7
  sahi.utils.file.download_from_url(
8
  "https://raw.githubusercontent.com/kadirnar/dethub/main/data/images/highway.jpg",
9
  "highway.jpg",
 
17
  "zidane.jpg",
18
  )
19
 
20
+ # List of YOLOv8 segmentation models
21
  model_names = [
22
  "yolov8n-seg.pt",
23
  "yolov8s-seg.pt",
 
26
  "yolov8x-seg.pt",
27
  ]
28
 
 
29
  current_model_name = "yolov8m-seg.pt"
30
  model = YOLO(current_model_name)
31
 
 
37
  iou_threshold: gr.Slider = 0.45,
38
  ):
39
  """
40
+ YOLOv8 inference function to return masks and label names for each detected object
41
  Args:
42
  image: Input image
43
  model_name: Name of the model
 
45
  conf_threshold: Confidence threshold
46
  iou_threshold: IOU threshold
47
  Returns:
48
+ Object masks, coordinates, and label names
49
  """
50
  global model
51
  global current_model_name
52
+
53
+ # Check if a new model is selected
54
  if model_name != current_model_name:
55
  model = YOLO(model_name)
56
  current_model_name = model_name
57
 
58
+ # Set the confidence and IOU thresholds
59
  model.overrides["conf"] = conf_threshold
60
  model.overrides["iou"] = iou_threshold
61
 
62
+ # Perform model prediction
63
  results = model.predict(image, imgsz=image_size, return_outputs=True)
64
+
65
+ # Initialize an empty list to store the output
66
+ output = []
67
 
68
+ # Iterate over the results
69
  for result in results:
70
+ # Check if segmentation masks are available
71
+ if 'masks' in result and result['masks'] is not None:
72
+ masks = result['masks']['data']
73
+ for i, (mask, box) in enumerate(zip(masks, result['boxes'])):
74
+ label = model.names[int(result['boxes']['cls'][i])]
75
+ mask_coords = mask.tolist() # Convert mask coordinates to list format
76
+ output.append({"label": label, "mask_coords": mask_coords})
77
+ else:
78
+ # If masks are not available, just extract bounding box information
79
+ for i, box in enumerate(result['boxes']):
80
+ label = model.names[int(result['boxes']['cls'][i])]
81
+ bbox = box['xyxy'].tolist() # Bounding box coordinates
82
+ output.append({"label": label, "bbox_coords": bbox})
83
 
84
+ return output
 
 
 
 
 
85
 
86
+ # Define Gradio interface inputs and outputs
 
 
 
87
  inputs = [
88
  gr.Image(type="filepath", label="Input Image"),
89
+ gr.Dropdown(
90
+ model_names,
91
+ value=current_model_name,
92
+ label="Model type",
93
+ ),
94
  gr.Slider(minimum=320, maximum=1280, value=640, step=32, label="Image Size"),
95
+ gr.Slider(
96
+ minimum=0.0, maximum=1.0, value=0.25, step=0.05, label="Confidence Threshold"
97
+ ),
98
  gr.Slider(minimum=0.0, maximum=1.0, value=0.45, step=0.05, label="IOU Threshold"),
99
  ]
100
 
101
+ # Output is a dictionary containing label names and coordinates of masks or boxes
102
+ outputs = gr.JSON(label="Output Masks and Labels")
103
+
104
+ title = "Ultralytics YOLOv8 Segmentation Demo"
105
 
106
+ # Example images for the interface
107
  examples = [
108
  ["zidane.jpg", "yolov8m-seg.pt", 640, 0.6, 0.45],
109
  ["highway.jpg", "yolov8m-seg.pt", 640, 0.25, 0.45],
110
  ["small-vehicles1.jpeg", "yolov8m-seg.pt", 640, 0.25, 0.45],
111
  ]
112
 
113
+ # Build the Gradio demo app
114
  demo_app = gr.Interface(
115
  fn=yolov8_inference,
116
  inputs=inputs,
117
  outputs=outputs,
118
+ title=title,
119
  examples=examples,
120
+ cache_examples=False, # Set to False to avoid caching issues
121
+ theme="default",
122
  )
123
 
124
+ # Launch the app
125
+ demo_app.queue().launch(debug=True)