ghostsInTheMachine commited on
Commit
eeef7f4
·
verified ·
1 Parent(s): 3304489

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +14 -26
app.py CHANGED
@@ -3,7 +3,7 @@ import cv2
3
  import numpy as np
4
  import torch
5
  import gradio as gr
6
- import spaces
7
 
8
  from PIL import Image, ImageOps
9
  from transformers import AutoModelForImageSegmentation
@@ -56,18 +56,19 @@ class ImagePreprocessor():
56
  image = self.transform_image(image)
57
  return image
58
 
59
- birefnet = AutoModelForImageSegmentation.from_pretrained('zhengpeng7/BiRefNet-matting', trust_remote_code=True)
 
60
  birefnet.to(device)
61
  birefnet.eval()
62
 
63
- @spaces.GPU
64
  def remove_background(image):
65
  if image is None:
66
  raise gr.Error("Please upload an image.")
67
 
68
  image_ori = Image.fromarray(image).convert('RGB')
69
  original_size = image_ori.size
70
-
71
  # Preprocess the image
72
  image_preprocessor = ImagePreprocessor(resolution=(1024, 1024))
73
  image_proc = image_preprocessor.proc(image_ori)
@@ -81,10 +82,9 @@ def remove_background(image):
81
  # Process Results
82
  pred_pil = transforms.ToPILImage()(pred)
83
  pred_pil = pred_pil.resize(original_size, Image.BICUBIC) # Resize mask to original size
84
-
85
- # Create reverse mask
86
- reverse_mask = Image.new('L', original_size)
87
- reverse_mask.paste(ImageOps.invert(pred_pil))
88
 
89
  # Create foreground image (object with transparent background)
90
  foreground = image_ori.copy()
@@ -96,29 +96,17 @@ def remove_background(image):
96
 
97
  torch.cuda.empty_cache()
98
 
99
- # Save all images
100
- mask_path = "mask.png"
101
- pred_pil.save(mask_path)
102
-
103
- reverse_mask_path = "reverse_mask.png"
104
- reverse_mask.save(reverse_mask_path)
105
-
106
- foreground_path = "foreground.png"
107
- foreground.save(foreground_path)
108
-
109
- background_path = "background.png"
110
- background.save(background_path)
111
-
112
- return mask_path, reverse_mask_path, foreground_path, background_path
113
 
114
  iface = gr.Interface(
115
  fn=remove_background,
116
  inputs=gr.Image(type="numpy"),
117
  outputs=[
118
- gr.Image(type="filepath", label="Mask"),
119
- gr.Image(type="filepath", label="Reverse Mask"),
120
- gr.Image(type="filepath", label="Foreground"),
121
- gr.Image(type="filepath", label="Background")
122
  ],
123
  allow_flagging="never"
124
  )
 
3
  import numpy as np
4
  import torch
5
  import gradio as gr
6
+ import spaces # Added import for spaces
7
 
8
  from PIL import Image, ImageOps
9
  from transformers import AutoModelForImageSegmentation
 
56
  image = self.transform_image(image)
57
  return image
58
 
59
+ birefnet = AutoModelForImageSegmentation.from_pretrained(
60
+ 'zhengpeng7/BiRefNet-matting', trust_remote_code=True)
61
  birefnet.to(device)
62
  birefnet.eval()
63
 
64
+ @spaces.GPU # Added the @spaces.GPU decorator
65
  def remove_background(image):
66
  if image is None:
67
  raise gr.Error("Please upload an image.")
68
 
69
  image_ori = Image.fromarray(image).convert('RGB')
70
  original_size = image_ori.size
71
+
72
  # Preprocess the image
73
  image_preprocessor = ImagePreprocessor(resolution=(1024, 1024))
74
  image_proc = image_preprocessor.proc(image_ori)
 
82
  # Process Results
83
  pred_pil = transforms.ToPILImage()(pred)
84
  pred_pil = pred_pil.resize(original_size, Image.BICUBIC) # Resize mask to original size
85
+
86
+ # Create reverse mask (background mask)
87
+ reverse_mask = ImageOps.invert(pred_pil)
 
88
 
89
  # Create foreground image (object with transparent background)
90
  foreground = image_ori.copy()
 
96
 
97
  torch.cuda.empty_cache()
98
 
99
+ # Return images in the specified order
100
+ return foreground, background, pred_pil, reverse_mask
 
 
 
 
 
 
 
 
 
 
 
 
101
 
102
  iface = gr.Interface(
103
  fn=remove_background,
104
  inputs=gr.Image(type="numpy"),
105
  outputs=[
106
+ gr.Image(type="pil", label="Foreground"),
107
+ gr.Image(type="pil", label="Background"),
108
+ gr.Image(type="pil", label="Foreground Mask"),
109
+ gr.Image(type="pil", label="Background Mask")
110
  ],
111
  allow_flagging="never"
112
  )