petergpt committed on
Commit
07d7c0a
·
verified ·
1 Parent(s): 66a61d0

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +55 -47
app.py CHANGED
@@ -1,4 +1,3 @@
1
-
2
  import cv2
3
  import gradio as gr
4
  import os
@@ -13,6 +12,7 @@ import warnings
13
  import time
14
  warnings.filterwarnings("ignore")
15
 
 
16
  os.system("git clone https://github.com/xuebinqin/DIS")
17
  os.system("mv DIS/IS-Net/* .")
18
 
@@ -36,10 +36,10 @@ class GOSNormalize(object):
36
  self.std = std
37
 
38
  def __call__(self,image):
39
- image = normalize(image,self.mean,self.std)
40
  return image
41
 
42
- transform = transforms.Compose([GOSNormalize([0.5,0.5,0.5],[1.0,1.0,1.0])])
43
 
44
  def load_image(im_path, hypar):
45
  im = im_reader(im_path)
@@ -59,88 +59,94 @@ def build_model(hypar, device):
59
  layer.float()
60
 
61
  net.to(device)
62
-
63
- if(hypar["restore_model"]!=""):
64
- net.load_state_dict(torch.load(hypar["model_path"]+"/"+hypar["restore_model"], map_location=device))
65
  net.to(device)
66
- net.eval()
67
  return net
68
 
69
  def predict(net, inputs_val, shapes_val, hypar, device):
70
  net.eval()
71
 
72
- if(hypar["model_digit"]=="full"):
73
  inputs_val = inputs_val.type(torch.FloatTensor)
74
  else:
75
  inputs_val = inputs_val.type(torch.HalfTensor)
76
 
77
- inputs_val_v = Variable(inputs_val, requires_grad=False).to(device)
78
  ds_val = net(inputs_val_v)[0]
79
- pred_val = ds_val[0][0,:,:,:]
80
  pred_val = torch.squeeze(F.upsample(torch.unsqueeze(pred_val, 0),
81
  (shapes_val[0][0], shapes_val[0][1]),
82
  mode='bilinear'))
83
 
84
  ma = torch.max(pred_val)
85
  mi = torch.min(pred_val)
86
- pred_val = (pred_val - mi) / (ma - mi + 1e-8) # normalize to 0~1, +1e-8 to avoid div by zero
 
87
 
88
- if device == 'cuda':
89
  torch.cuda.empty_cache()
90
  return (pred_val.detach().cpu().numpy() * 255).astype(np.uint8)
91
 
92
  # Parameters
93
- hypar = {}
94
- hypar["model_path"] = "./saved_models"
95
- hypar["restore_model"] = "isnet.pth"
96
- hypar["interm_sup"] = False
97
- hypar["model_digit"] = "full"
98
- hypar["seed"] = 0
99
- hypar["cache_size"] = [1024, 1024]
100
- hypar["input_size"] = [1024, 1024]
101
- hypar["crop_size"] = [1024, 1024]
102
- hypar["model"] = ISNetDIS()
103
-
104
- # Build Model
 
105
  net = build_model(hypar, device)
106
 
107
- def inference(images, logs):
 
 
 
108
  start_time = time.time()
 
109
 
110
- # If user didn't upload images, just return empty
111
- if not images:
 
 
 
112
  return [], logs, logs
113
 
114
  processed_pairs = []
115
- for img_path in images:
116
- image_tensor, orig_size = load_image(img_path, hypar)
117
  mask = predict(net, image_tensor, orig_size, hypar, device)
118
 
119
  pil_mask = Image.fromarray(mask).convert('L')
120
- im_rgb = Image.open(img_path).convert("RGB")
121
  im_rgba = im_rgb.copy()
122
  im_rgba.putalpha(pil_mask)
123
  processed_pairs.append([im_rgba, pil_mask])
124
-
125
  end_time = time.time()
126
  elapsed = round(end_time - start_time, 2)
127
 
128
- # Flatten the list so that we can display all images in a single Gallery
129
  final_images = []
130
  for pair in processed_pairs:
131
  final_images.extend(pair)
132
 
133
- # Update logs
134
- logs = logs or ""
135
- logs += f"Processed {len(processed_pairs)} image(s) in {elapsed} seconds.\n"
136
 
 
137
  return final_images, logs, logs
138
 
139
  title = "Highly Accurate Dichotomous Image Segmentation"
140
  description = (
141
- "This is an unofficial demo for DIS, a model that can remove the background from a given image. "
142
- "To use it, simply upload up to 3 images, or click one of the examples to load them. "
143
- "Read more at the links below.<br>"
144
  "GitHub: https://github.com/xuebinqin/DIS<br>"
145
  "Telegram bot: https://t.me/restoration_photo_bot<br>"
146
  "[![](https://img.shields.io/twitter/follow/DoEvent?label=@DoEvent&style=social)](https://twitter.com/DoEvent)"
@@ -152,22 +158,24 @@ article = (
152
 
153
  interface = gr.Interface(
154
  fn=inference,
155
- inputs=[gr.Image(
156
- type='filepath',
157
- label='Images (up to 3)',
158
- multiple=True,
159
- max_count=3
160
- ),
161
- gr.State()],
162
  outputs=[
163
  gr.Gallery(label="Output (rgba + mask)"),
164
  gr.State(),
165
  gr.Textbox(label="Logs", lines=6)
166
  ],
167
- examples=[['robot.png'], ['ship.png']], # for multi-image examples, pass a list like ['robot.png','ship.png']
 
 
 
168
  title=title,
169
  description=description,
170
  article=article,
171
  flagging_mode="never",
172
- cache_mode="lazy",
173
  ).queue().launch(show_api=True, show_error=True)
 
 
1
  import cv2
2
  import gradio as gr
3
  import os
 
12
  import time
13
  warnings.filterwarnings("ignore")
14
 
15
+ # Clone the DIS repo and move contents (make sure this only happens once per session)
16
  os.system("git clone https://github.com/xuebinqin/DIS")
17
  os.system("mv DIS/IS-Net/* .")
18
 
 
36
  self.std = std
37
 
38
  def __call__(self,image):
39
+ image = normalize(image, self.mean, self.std)
40
  return image
41
 
42
+ transform = transforms.Compose([GOSNormalize([0.5, 0.5, 0.5], [1.0, 1.0, 1.0])])
43
 
44
  def load_image(im_path, hypar):
45
  im = im_reader(im_path)
 
59
  layer.float()
60
 
61
  net.to(device)
62
+ if hypar["restore_model"] != "":
63
+ net.load_state_dict(torch.load(os.path.join(hypar["model_path"], hypar["restore_model"]), map_location=device))
 
64
  net.to(device)
65
+ net.eval()
66
  return net
67
 
68
  def predict(net, inputs_val, shapes_val, hypar, device):
69
  net.eval()
70
 
71
+ if hypar["model_digit"] == "full":
72
  inputs_val = inputs_val.type(torch.FloatTensor)
73
  else:
74
  inputs_val = inputs_val.type(torch.HalfTensor)
75
 
76
+ inputs_val_v = Variable(inputs_val, requires_grad=False).to(device)
77
  ds_val = net(inputs_val_v)[0]
78
+ pred_val = ds_val[0][0, :, :, :]
79
  pred_val = torch.squeeze(F.upsample(torch.unsqueeze(pred_val, 0),
80
  (shapes_val[0][0], shapes_val[0][1]),
81
  mode='bilinear'))
82
 
83
  ma = torch.max(pred_val)
84
  mi = torch.min(pred_val)
85
+ # normalize to [0, 1], add a small epsilon to avoid division by zero
86
+ pred_val = (pred_val - mi) / (ma - mi + 1e-8)
87
 
88
+ if device == 'cuda':
89
  torch.cuda.empty_cache()
90
  return (pred_val.detach().cpu().numpy() * 255).astype(np.uint8)
91
 
92
  # Parameters
93
+ hypar = {
94
+ "model_path": "./saved_models",
95
+ "restore_model": "isnet.pth",
96
+ "interm_sup": False,
97
+ "model_digit": "full",
98
+ "seed": 0,
99
+ "cache_size": [1024, 1024],
100
+ "input_size": [1024, 1024],
101
+ "crop_size": [1024, 1024],
102
+ "model": ISNetDIS()
103
+ }
104
+
105
+ # Build the model
106
  net = build_model(hypar, device)
107
 
108
+ def inference(img1, img2, img3, logs):
109
+ """
110
+ Process up to 3 images in parallel (each can be None if not provided).
111
+ """
112
  start_time = time.time()
113
+ logs = logs or "" # initialize logs if None
114
 
115
+ # Gather images into a list (filter out None)
116
+ image_paths = [i for i in [img1, img2, img3] if i is not None]
117
+ if not image_paths:
118
+ # No images were uploaded
119
+ logs += f"No images to process.\n"
120
  return [], logs, logs
121
 
122
  processed_pairs = []
123
+ for path in image_paths:
124
+ image_tensor, orig_size = load_image(path, hypar)
125
  mask = predict(net, image_tensor, orig_size, hypar, device)
126
 
127
  pil_mask = Image.fromarray(mask).convert('L')
128
+ im_rgb = Image.open(path).convert("RGB")
129
  im_rgba = im_rgb.copy()
130
  im_rgba.putalpha(pil_mask)
131
  processed_pairs.append([im_rgba, pil_mask])
132
+
133
  end_time = time.time()
134
  elapsed = round(end_time - start_time, 2)
135
 
136
+ # Flatten into final gallery list
137
  final_images = []
138
  for pair in processed_pairs:
139
  final_images.extend(pair)
140
 
141
+ logs += f"Processed {len(processed_pairs)} image(s) in {elapsed} second(s).\n"
 
 
142
 
143
+ # Return the flattened gallery, state, and logs text
144
  return final_images, logs, logs
145
 
146
  title = "Highly Accurate Dichotomous Image Segmentation"
147
  description = (
148
+ "This is an unofficial demo for DIS, a model that can remove the background from up to 3 images. "
149
+ "Simply upload 1 to 3 images, or use the example images. "
 
150
  "GitHub: https://github.com/xuebinqin/DIS<br>"
151
  "Telegram bot: https://t.me/restoration_photo_bot<br>"
152
  "[![](https://img.shields.io/twitter/follow/DoEvent?label=@DoEvent&style=social)](https://twitter.com/DoEvent)"
 
158
 
159
  interface = gr.Interface(
160
  fn=inference,
161
+ inputs=[
162
+ gr.Image(type='filepath', label='Image 1'),
163
+ gr.Image(type='filepath', label='Image 2'),
164
+ gr.Image(type='filepath', label='Image 3'),
165
+ gr.State()
166
+ ],
 
167
  outputs=[
168
  gr.Gallery(label="Output (rgba + mask)"),
169
  gr.State(),
170
  gr.Textbox(label="Logs", lines=6)
171
  ],
172
+ examples=[
173
+ ["robot.png", None, None],
174
+ ["robot.png", "ship.png", None],
175
+ ],
176
  title=title,
177
  description=description,
178
  article=article,
179
  flagging_mode="never",
180
+ cache_mode="lazy"
181
  ).queue().launch(show_api=True, show_error=True)