kedimestan commited on
Commit
b48c806
1 Parent(s): 42f5968

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +33 -51
app.py CHANGED
@@ -3,7 +3,7 @@ import sahi
3
  import torch
4
  from ultralyticsplus import YOLO, render_model_output
5
 
6
- # Download sample images
7
  sahi.utils.file.download_from_url(
8
  "https://raw.githubusercontent.com/kadirnar/dethub/main/data/images/highway.jpg",
9
  "highway.jpg",
@@ -17,7 +17,7 @@ sahi.utils.file.download_from_url(
17
  "zidane.jpg",
18
  )
19
 
20
- # List of YOLOv8 segmentation models
21
  model_names = [
22
  "yolov8n-seg.pt",
23
  "yolov8s-seg.pt",
@@ -26,18 +26,19 @@ model_names = [
26
  "yolov8x-seg.pt",
27
  ]
28
 
 
29
  current_model_name = "yolov8m-seg.pt"
30
  model = YOLO(current_model_name)
31
 
32
  def yolov8_inference(
33
  image: gr.Image = None,
34
- model_name: str = None,
35
- image_size: int = 640,
36
- conf_threshold: float = 0.25,
37
- iou_threshold: float = 0.45,
38
  ):
39
  """
40
- YOLOv8 inference function to return masks and label names for each detected object
41
  Args:
42
  image: Input image
43
  model_name: Name of the model
@@ -45,85 +46,66 @@ def yolov8_inference(
45
  conf_threshold: Confidence threshold
46
  iou_threshold: IOU threshold
47
  Returns:
48
- Object masks, coordinates, and label names
49
  """
50
  global model
51
  global current_model_name
52
-
53
- # Check if a new model is selected
54
  if model_name != current_model_name:
55
  model = YOLO(model_name)
56
  current_model_name = model_name
57
 
58
- # Set the confidence and IOU thresholds
59
  model.overrides["conf"] = conf_threshold
60
  model.overrides["iou"] = iou_threshold
61
 
62
- # Perform model prediction
63
  results = model.predict(image, imgsz=image_size, return_outputs=True)
64
-
65
- # Initialize an empty list to store the output
66
- output = []
67
 
68
- # Iterate over the results
69
  for result in results:
70
- # Check if segmentation masks are available
71
- if 'masks' in result and result['masks'] is not None:
72
- masks = result['masks']['data']
73
- for i, (mask, box) in enumerate(zip(masks, result['boxes'])):
74
- label = model.names[int(result['boxes']['cls'][i])]
75
- mask_coords = mask.tolist() # Convert mask coordinates to list format
76
- output.append({"label": label, "mask_coords": mask_coords})
77
- else:
78
- # If masks are not available, just extract bounding box information
79
- for i, box in enumerate(result['boxes']):
80
- label = model.names[int(result['boxes']['cls'][i])]
81
- bbox = box['xyxy'].tolist() # Bounding box coordinates
82
- output.append({"label": label, "bbox_coords": bbox})
83
 
84
- return output
 
 
 
 
 
85
 
86
- # Define Gradio interface inputs and outputs
 
 
 
87
  inputs = [
88
  gr.Image(type="filepath", label="Input Image"),
89
- gr.Dropdown(
90
- model_names,
91
- value=current_model_name,
92
- label="Model type",
93
- ),
94
  gr.Slider(minimum=320, maximum=1280, value=640, step=32, label="Image Size"),
95
- gr.Slider(
96
- minimum=0.0, maximum=1.0, value=0.25, step=0.05, label="Confidence Threshold"
97
- ),
98
  gr.Slider(minimum=0.0, maximum=1.0, value=0.45, step=0.05, label="IOU Threshold"),
99
  ]
100
 
101
- # Output is a dictionary containing label names and coordinates of masks or boxes
102
- outputs = gr.JSON(label="Output Masks and Labels")
103
-
104
- title = "Ultralytics YOLOv8 Segmentation Demo"
105
 
106
- # Example images for the interface
107
  examples = [
108
  ["zidane.jpg", "yolov8m-seg.pt", 640, 0.6, 0.45],
109
  ["highway.jpg", "yolov8m-seg.pt", 640, 0.25, 0.45],
110
  ["small-vehicles1.jpeg", "yolov8m-seg.pt", 640, 0.25, 0.45],
111
  ]
112
 
113
- # Build the Gradio demo app with POST functionality
114
  demo_app = gr.Interface(
115
  fn=yolov8_inference,
116
  inputs=inputs,
117
  outputs=outputs,
118
- title=title,
119
  examples=examples,
120
- cache_examples=False, # Set to False to avoid caching issues
121
- theme="default",
122
  )
123
 
124
- # Launch the app with API-enabled functionality
125
- demo_app.queue().launch(
126
- enable_queue=True, # Allow for API-style interactions
127
  debug=True, # Show detailed errors in case of issues
128
  server_name="0.0.0.0", # Host on all IPs
129
  server_port=7860, # Custom port for accessing the app
 
3
  import torch
4
  from ultralyticsplus import YOLO, render_model_output
5
 
6
+ # Download images for the demo
7
  sahi.utils.file.download_from_url(
8
  "https://raw.githubusercontent.com/kadirnar/dethub/main/data/images/highway.jpg",
9
  "highway.jpg",
 
17
  "zidane.jpg",
18
  )
19
 
20
+ # Define available YOLOv8 segmentation models
21
  model_names = [
22
  "yolov8n-seg.pt",
23
  "yolov8s-seg.pt",
 
26
  "yolov8x-seg.pt",
27
  ]
28
 
29
+ # Load the initial YOLOv8 model
30
  current_model_name = "yolov8m-seg.pt"
31
  model = YOLO(current_model_name)
32
 
33
  def yolov8_inference(
34
  image: gr.Image = None,
35
+ model_name: gr.Dropdown = None,
36
+ image_size: gr.Slider = 640,
37
+ conf_threshold: gr.Slider = 0.25,
38
+ iou_threshold: gr.Slider = 0.45,
39
  ):
40
  """
41
+ YOLOv8 inference function
42
  Args:
43
  image: Input image
44
  model_name: Name of the model
 
46
  conf_threshold: Confidence threshold
47
  iou_threshold: IOU threshold
48
  Returns:
49
+ Rendered image and mask coordinates with labels
50
  """
51
  global model
52
  global current_model_name
53
+ # Switch model if a different one is selected
 
54
  if model_name != current_model_name:
55
  model = YOLO(model_name)
56
  current_model_name = model_name
57
 
58
+ # Set model confidence and IOU thresholds
59
  model.overrides["conf"] = conf_threshold
60
  model.overrides["iou"] = iou_threshold
61
 
62
+ # Perform inference with the YOLO model
63
  results = model.predict(image, imgsz=image_size, return_outputs=True)
 
 
 
64
 
65
+ masks = []
66
  for result in results:
67
+ masks.append([result.masks, result.labels])
 
 
 
 
 
 
 
 
 
 
 
 
68
 
69
+ renders = []
70
+ for image_results in results:
71
+ render = render_model_output(
72
+ model=model, image=image, model_output=image_results
73
+ )
74
+ renders.append(render)
75
 
76
+ # Return mask coordinates and labels
77
+ return masks
78
+
79
+ # Gradio app inputs and outputs
80
  inputs = [
81
  gr.Image(type="filepath", label="Input Image"),
82
+ gr.Dropdown(model_names, value=current_model_name, label="Model type"),
 
 
 
 
83
  gr.Slider(minimum=320, maximum=1280, value=640, step=32, label="Image Size"),
84
+ gr.Slider(minimum=0.0, maximum=1.0, value=0.25, step=0.05, label="Confidence Threshold"),
 
 
85
  gr.Slider(minimum=0.0, maximum=1.0, value=0.45, step=0.05, label="IOU Threshold"),
86
  ]
87
 
88
+ outputs = gr.Textbox(label="Mask Coordinates and Labels")
 
 
 
89
 
90
+ # Example inputs for the Gradio app
91
  examples = [
92
  ["zidane.jpg", "yolov8m-seg.pt", 640, 0.6, 0.45],
93
  ["highway.jpg", "yolov8m-seg.pt", 640, 0.25, 0.45],
94
  ["small-vehicles1.jpeg", "yolov8m-seg.pt", 640, 0.25, 0.45],
95
  ]
96
 
97
+ # Create the Gradio app interface
98
  demo_app = gr.Interface(
99
  fn=yolov8_inference,
100
  inputs=inputs,
101
  outputs=outputs,
102
+ title="Ultralytics YOLOv8 Segmentation Demo",
103
  examples=examples,
104
+ cache_examples=True,
 
105
  )
106
 
107
+ # Launch the Gradio app
108
+ demo_app.launch(
 
109
  debug=True, # Show detailed errors in case of issues
110
  server_name="0.0.0.0", # Host on all IPs
111
  server_port=7860, # Custom port for accessing the app