petergpt commited on
Commit
3472d22
·
verified ·
1 Parent(s): 07d7c0a

multiple upload

Browse files
Files changed (1) hide show
  1. app.py +26 -49
app.py CHANGED
@@ -12,7 +12,7 @@ import warnings
12
  import time
13
  warnings.filterwarnings("ignore")
14
 
15
- # Clone the DIS repo and move contents (make sure this only happens once per session)
16
  os.system("git clone https://github.com/xuebinqin/DIS")
17
  os.system("mv DIS/IS-Net/* .")
18
 
@@ -22,22 +22,21 @@ from models import *
22
 
23
  device = 'cuda' if torch.cuda.is_available() else 'cpu'
24
 
25
- # Download official weights
26
  if not os.path.exists("saved_models"):
27
  os.mkdir("saved_models")
28
  os.system("mv isnet.pth saved_models/")
29
 
30
  class GOSNormalize(object):
31
- '''
32
- Normalize the Image using torch.transforms
33
- '''
34
- def __init__(self, mean=[0.485,0.456,0.406], std=[0.229,0.224,0.225]):
35
  self.mean = mean
36
  self.std = std
37
 
38
- def __call__(self,image):
39
- image = normalize(image, self.mean, self.std)
40
- return image
41
 
42
  transform = transforms.Compose([GOSNormalize([0.5, 0.5, 0.5], [1.0, 1.0, 1.0])])
43
 
@@ -50,14 +49,11 @@ def load_image(im_path, hypar):
50
 
51
  def build_model(hypar, device):
52
  net = hypar["model"]
53
-
54
- # convert to half precision if needed
55
- if(hypar["model_digit"]=="half"):
56
  net.half()
57
  for layer in net.modules():
58
- if isinstance(layer, nn.BatchNorm2d):
59
  layer.float()
60
-
61
  net.to(device)
62
  if hypar["restore_model"] != "":
63
  net.load_state_dict(torch.load(os.path.join(hypar["model_path"], hypar["restore_model"]), map_location=device))
@@ -67,24 +63,19 @@ def build_model(hypar, device):
67
 
68
  def predict(net, inputs_val, shapes_val, hypar, device):
69
  net.eval()
70
-
71
  if hypar["model_digit"] == "full":
72
  inputs_val = inputs_val.type(torch.FloatTensor)
73
  else:
74
  inputs_val = inputs_val.type(torch.HalfTensor)
75
-
76
  inputs_val_v = Variable(inputs_val, requires_grad=False).to(device)
77
  ds_val = net(inputs_val_v)[0]
78
  pred_val = ds_val[0][0, :, :, :]
79
  pred_val = torch.squeeze(F.upsample(torch.unsqueeze(pred_val, 0),
80
- (shapes_val[0][0], shapes_val[0][1]),
81
- mode='bilinear'))
82
-
83
  ma = torch.max(pred_val)
84
  mi = torch.min(pred_val)
85
- # normalize to [0, 1], add a small epsilon to avoid division by zero
86
  pred_val = (pred_val - mi) / (ma - mi + 1e-8)
87
-
88
  if device == 'cuda':
89
  torch.cuda.empty_cache()
90
  return (pred_val.detach().cpu().numpy() * 255).astype(np.uint8)
@@ -102,51 +93,39 @@ hypar = {
102
  "model": ISNetDIS()
103
  }
104
 
105
- # Build the model
106
  net = build_model(hypar, device)
107
 
108
- def inference(img1, img2, img3, logs):
109
  """
110
- Process up to 3 images in parallel (each can be None if not provided).
111
  """
112
  start_time = time.time()
113
- logs = logs or "" # initialize logs if None
114
-
115
- # Gather images into a list (filter out None)
116
- image_paths = [i for i in [img1, img2, img3] if i is not None]
117
- if not image_paths:
118
- # No images were uploaded
119
- logs += f"No images to process.\n"
120
  return [], logs, logs
121
 
 
 
122
  processed_pairs = []
123
  for path in image_paths:
124
  image_tensor, orig_size = load_image(path, hypar)
125
  mask = predict(net, image_tensor, orig_size, hypar, device)
126
-
127
  pil_mask = Image.fromarray(mask).convert('L')
128
  im_rgb = Image.open(path).convert("RGB")
129
  im_rgba = im_rgb.copy()
130
  im_rgba.putalpha(pil_mask)
131
  processed_pairs.append([im_rgba, pil_mask])
132
 
133
- end_time = time.time()
134
- elapsed = round(end_time - start_time, 2)
135
-
136
- # Flatten into final gallery list
137
- final_images = []
138
- for pair in processed_pairs:
139
- final_images.extend(pair)
140
-
141
  logs += f"Processed {len(processed_pairs)} image(s) in {elapsed} second(s).\n"
142
-
143
- # Return the flattened gallery, state, and logs text
144
  return final_images, logs, logs
145
 
146
  title = "Highly Accurate Dichotomous Image Segmentation"
147
  description = (
148
- "This is an unofficial demo for DIS, a model that can remove the background from up to 3 images. "
149
- "Simply upload 1 to 3 images, or use the example images. "
150
  "GitHub: https://github.com/xuebinqin/DIS<br>"
151
  "Telegram bot: https://t.me/restoration_photo_bot<br>"
152
  "[![](https://img.shields.io/twitter/follow/DoEvent?label=@DoEvent&style=social)](https://twitter.com/DoEvent)"
@@ -159,9 +138,7 @@ article = (
159
  interface = gr.Interface(
160
  fn=inference,
161
  inputs=[
162
- gr.Image(type='filepath', label='Image 1'),
163
- gr.Image(type='filepath', label='Image 2'),
164
- gr.Image(type='filepath', label='Image 3'),
165
  gr.State()
166
  ],
167
  outputs=[
@@ -170,8 +147,8 @@ interface = gr.Interface(
170
  gr.Textbox(label="Logs", lines=6)
171
  ],
172
  examples=[
173
- ["robot.png", None, None],
174
- ["robot.png", "ship.png", None],
175
  ],
176
  title=title,
177
  description=description,
 
12
  import time
13
  warnings.filterwarnings("ignore")
14
 
15
+ # Clone the DIS repo and move contents (ensure this runs once per session)
16
  os.system("git clone https://github.com/xuebinqin/DIS")
17
  os.system("mv DIS/IS-Net/* .")
18
 
 
22
 
23
  device = 'cuda' if torch.cuda.is_available() else 'cpu'
24
 
25
+ # Download official weights if not already present
26
  if not os.path.exists("saved_models"):
27
  os.mkdir("saved_models")
28
  os.system("mv isnet.pth saved_models/")
29
 
30
  class GOSNormalize(object):
31
+ """
32
+ Normalize the Image using torch.transforms.
33
+ """
34
+ def __init__(self, mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]):
35
  self.mean = mean
36
  self.std = std
37
 
38
+ def __call__(self, image):
39
+ return normalize(image, self.mean, self.std)
 
40
 
41
  transform = transforms.Compose([GOSNormalize([0.5, 0.5, 0.5], [1.0, 1.0, 1.0])])
42
 
 
49
 
50
  def build_model(hypar, device):
51
  net = hypar["model"]
52
+ if hypar["model_digit"] == "half":
 
 
53
  net.half()
54
  for layer in net.modules():
55
+ if isinstance(layer, torch.nn.BatchNorm2d):
56
  layer.float()
 
57
  net.to(device)
58
  if hypar["restore_model"] != "":
59
  net.load_state_dict(torch.load(os.path.join(hypar["model_path"], hypar["restore_model"]), map_location=device))
 
63
 
64
  def predict(net, inputs_val, shapes_val, hypar, device):
65
  net.eval()
 
66
  if hypar["model_digit"] == "full":
67
  inputs_val = inputs_val.type(torch.FloatTensor)
68
  else:
69
  inputs_val = inputs_val.type(torch.HalfTensor)
 
70
  inputs_val_v = Variable(inputs_val, requires_grad=False).to(device)
71
  ds_val = net(inputs_val_v)[0]
72
  pred_val = ds_val[0][0, :, :, :]
73
  pred_val = torch.squeeze(F.upsample(torch.unsqueeze(pred_val, 0),
74
+ (shapes_val[0][0], shapes_val[0][1]),
75
+ mode='bilinear'))
 
76
  ma = torch.max(pred_val)
77
  mi = torch.min(pred_val)
 
78
  pred_val = (pred_val - mi) / (ma - mi + 1e-8)
 
79
  if device == 'cuda':
80
  torch.cuda.empty_cache()
81
  return (pred_val.detach().cpu().numpy() * 255).astype(np.uint8)
 
93
  "model": ISNetDIS()
94
  }
95
 
 
96
  net = build_model(hypar, device)
97
 
98
+ def inference(file_paths, logs):
99
  """
100
+ Process up to 3 images uploaded via the file uploader.
101
  """
102
  start_time = time.time()
103
+ logs = logs or ""
104
+ if not file_paths:
105
+ logs += "No images to process.\n"
 
 
 
 
106
  return [], logs, logs
107
 
108
+ # Limit to a maximum of 3 images
109
+ image_paths = file_paths[:3]
110
  processed_pairs = []
111
  for path in image_paths:
112
  image_tensor, orig_size = load_image(path, hypar)
113
  mask = predict(net, image_tensor, orig_size, hypar, device)
 
114
  pil_mask = Image.fromarray(mask).convert('L')
115
  im_rgb = Image.open(path).convert("RGB")
116
  im_rgba = im_rgb.copy()
117
  im_rgba.putalpha(pil_mask)
118
  processed_pairs.append([im_rgba, pil_mask])
119
 
120
+ elapsed = round(time.time() - start_time, 2)
121
+ final_images = [img for pair in processed_pairs for img in pair]
 
 
 
 
 
 
122
  logs += f"Processed {len(processed_pairs)} image(s) in {elapsed} second(s).\n"
 
 
123
  return final_images, logs, logs
124
 
125
  title = "Highly Accurate Dichotomous Image Segmentation"
126
  description = (
127
+ "This is an unofficial demo for DIS, a model that removes the background from images. "
128
+ "Upload up to 3 images at once using the file uploader below. "
129
  "GitHub: https://github.com/xuebinqin/DIS<br>"
130
  "Telegram bot: https://t.me/restoration_photo_bot<br>"
131
  "[![](https://img.shields.io/twitter/follow/DoEvent?label=@DoEvent&style=social)](https://twitter.com/DoEvent)"
 
138
  interface = gr.Interface(
139
  fn=inference,
140
  inputs=[
141
+ gr.File(file_count="multiple", type="filepath", label="Upload Images (up to 3)"),
 
 
142
  gr.State()
143
  ],
144
  outputs=[
 
147
  gr.Textbox(label="Logs", lines=6)
148
  ],
149
  examples=[
150
+ [["robot.png"], None],
151
+ [["robot.png", "ship.png"], None],
152
  ],
153
  title=title,
154
  description=description,