AAAAAAyq committed on
Commit
30e0f74
1 Parent(s): 87c6f54

Update application file

Browse files
Files changed (2) hide show
  1. app.py +14 -16
  2. requirements.txt +12 -11
app.py CHANGED
@@ -1,11 +1,8 @@
1
  from ultralytics import YOLO
2
- from PIL import Image
3
  import numpy as np
4
  import matplotlib.pyplot as plt
5
  import gradio as gr
6
- import io
7
  import torch
8
- # import cv2
9
 
10
  model = YOLO('checkpoints/FastSAM.pt') # load a custom model
11
 
@@ -48,26 +45,27 @@ def show_mask(annotation, ax, random_color=True, bbox=None, points=None):
48
  return mask_image
49
 
50
  def post_process(annotations, image, mask_random_color=True, bbox=None, points=None):
51
- plt.figure(figsize=(10, 10))
52
  plt.imshow(image)
53
  for i, mask in enumerate(annotations):
54
  show_mask(mask, plt.gca(),random_color=mask_random_color,bbox=bbox,points=points)
55
  plt.axis('off')
56
- # create a BytesIO object
57
- buf = io.BytesIO()
58
 
59
- # save plot to buf
60
- plt.savefig(buf, format='png', bbox_inches='tight', pad_inches=0.0)
61
 
62
- # use PIL to open the image
63
- img = Image.open(buf)
64
 
65
- # copy the image data
66
- img_copy = img.copy()
 
67
 
68
- # don't forget to close the buffer
69
- buf.close()
70
- return img_copy
71
 
72
 
73
  # def show_mask(annotation, ax, random_color=False):
@@ -107,7 +105,7 @@ def predict(inp):
107
 
108
  demo = gr.Interface(fn=predict,
109
  inputs=gr.inputs.Image(type='pil'),
110
- outputs=gr.outputs.Image(type='pil'),
111
  examples=[["assets/sa_192.jpg"], ["assets/sa_414.jpg"],
112
  ["assets/sa_561.jpg"], ["assets/sa_862.jpg"],
113
  ["assets/sa_1309.jpg"], ["assets/sa_8776.jpg"],
 
1
  from ultralytics import YOLO
 
2
  import numpy as np
3
  import matplotlib.pyplot as plt
4
  import gradio as gr
 
5
  import torch
 
6
 
7
  model = YOLO('checkpoints/FastSAM.pt') # load a custom model
8
 
 
45
  return mask_image
46
 
47
  def post_process(annotations, image, mask_random_color=True, bbox=None, points=None):
48
+ fig = plt.figure(figsize=(10, 10))
49
  plt.imshow(image)
50
  for i, mask in enumerate(annotations):
51
  show_mask(mask, plt.gca(),random_color=mask_random_color,bbox=bbox,points=points)
52
  plt.axis('off')
53
+ # # create a BytesIO object
54
+ # buf = io.BytesIO()
55
 
56
+ # # save plot to buf
57
+ # plt.savefig(buf, format='png', bbox_inches='tight', pad_inches=0.0)
58
 
59
+ # # use PIL to open the image
60
+ # img = Image.open(buf)
61
 
62
+ # # copy the image data
63
+ # img_copy = img.copy()
64
+ plt.tight_layout()
65
 
66
+ # # don't forget to close the buffer
67
+ # buf.close()
68
+ return fig
69
 
70
 
71
  # def show_mask(annotation, ax, random_color=False):
 
105
 
106
  demo = gr.Interface(fn=predict,
107
  inputs=gr.inputs.Image(type='pil'),
108
+ outputs=['plot'],
109
  examples=[["assets/sa_192.jpg"], ["assets/sa_414.jpg"],
110
  ["assets/sa_561.jpg"], ["assets/sa_862.jpg"],
111
  ["assets/sa_1309.jpg"], ["assets/sa_8776.jpg"],
requirements.txt CHANGED
@@ -1,16 +1,17 @@
1
  # Base-----------------------------------
2
- matplotlib>=3.2.2
3
- opencv-python>=4.6.0
4
- Pillow>=7.1.2
5
- PyYAML>=5.3.1
6
- requests>=2.23.0
7
- scipy>=1.4.1
8
- torch>=1.7.0
9
- torchvision>=0.8.1
10
- tqdm>=4.64.0
 
11
 
12
- pandas>=1.1.4
13
- seaborn>=0.11.0
14
 
15
  # Ultralytics-----------------------------------
16
  ultralytics
 
1
  # Base-----------------------------------
2
+ matplotlib==3.5.2
3
+ numpy==1.23.0
4
+ # opencv-python>=4.6.0
5
+ # Pillow>=7.1.2
6
+ # PyYAML>=5.3.1
7
+ # requests>=2.23.0
8
+ # scipy>=1.4.1
9
+ torch
10
+ torchvision
11
+ # tqdm>=4.64.0
12
 
13
+ # pandas>=1.1.4
14
+ # seaborn>=0.11.0
15
 
16
  # Ultralytics-----------------------------------
17
  ultralytics