ghostsInTheMachine committed · verified
Commit 6bafd2d · Parent: 4446ce3

Update app.py

Files changed (1):
  1. app.py +54 -28
app.py CHANGED
@@ -3,17 +3,55 @@ import cv2
 import numpy as np
 import torch
 import gradio as gr
-import spaces # Required for @spaces.GPU
-
+import spaces
+from gradio.themes.base import Base
+from gradio.themes.utils import colors, fonts, sizes
 from PIL import Image, ImageOps
 from transformers import AutoModelForImageSegmentation
 from torchvision import transforms
 
+# Custom White Theme with Inter font
+class WhiteTheme(Base):
+    def __init__(
+        self,
+        *,
+        primary_hue: colors.Color | str = colors.orange,
+        font: fonts.Font | str = fonts.GoogleFont("Inter"),
+        font_mono: fonts.Font | str = fonts.GoogleFont("Inter")
+    ):
+        super().__init__(
+            primary_hue=primary_hue,
+            font=font,
+            font_mono=font_mono,
+        )
+
+        self.set(
+            body_background_fill="white",
+            block_background_fill="white",
+            panel_background_fill="white",
+            body_text_color="black",
+            block_label_text_color="black",
+            block_border_color="white",
+            panel_border_color="white",
+            input_border_color="lightgray",
+            button_primary_background_fill="*primary_500",
+            button_primary_background_fill_hover="*primary_600",
+            button_primary_text_color="white",
+            button_secondary_background_fill="white",
+            button_secondary_border_color="lightgray",
+            block_shadow="none",
+            button_shadow="none",
+            input_shadow="none",
+            slider_color="*primary_500",
+            slider_track_color="lightgray",
+        )
+
+# Your existing setup code
 torch.set_float32_matmul_precision('high')
 torch.jit.script = lambda f: f
-
 device = "cuda" if torch.cuda.is_available() else "cpu"
 
+# Keep all your existing functions unchanged
 def refine_foreground(image, mask, r=90):
     if mask.size != image.size:
         mask = mask.resize(image.size)
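Note (illustrative, not part of the commit): `WhiteTheme` follows the standard Gradio pattern of subclassing `gradio.themes.base.Base` and overriding theme variables via `set()`; tokens such as `"*primary_500"` resolve to shade 500 of the configured primary hue. A minimal standalone sketch of the same pattern, with a hypothetical `TinyTheme`:

```python
import gradio as gr
from gradio.themes.base import Base
from gradio.themes.utils import colors

class TinyTheme(Base):
    def __init__(self):
        super().__init__(primary_hue=colors.orange)
        # "*primary_500" points at shade 500 of the primary hue palette
        self.set(
            button_primary_background_fill="*primary_500",
            button_primary_text_color="white",
        )

with gr.Blocks(theme=TinyTheme()) as demo:
    gr.Button("Run", variant="primary")
# demo.launch()
```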
@@ -33,14 +71,11 @@ def FB_blur_fusion_foreground_estimator(image, F, B, alpha, r=90):
     if isinstance(image, Image.Image):
         image = np.array(image) / 255.0
     blurred_alpha = cv2.blur(alpha, (r, r))[:, :, None]
-
     blurred_FA = cv2.blur(F * alpha, (r, r))
     blurred_F = blurred_FA / (blurred_alpha + 1e-5)
-
     blurred_B1A = cv2.blur(B * (1 - alpha), (r, r))
     blurred_B = blurred_B1A / ((1 - blurred_alpha) + 1e-5)
-    F = blurred_F + alpha * \
-        (image - alpha * blurred_F - (1 - alpha) * blurred_B)
+    F = blurred_F + alpha * (image - alpha * blurred_F - (1 - alpha) * blurred_B)
     F = np.clip(F, 0, 1)
     return F, blurred_B
 
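For reference, the reflowed line computes the same blur-fusion foreground update as the old backslash-continued version. With $\bar{F}$ and $\bar{B}$ the alpha-weighted box-blurred foreground/background estimates and $I$ the input image:

```latex
F \approx \bar{F} + \alpha \bigl( I - \alpha \bar{F} - (1 - \alpha)\,\bar{B} \bigr)
```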
@@ -67,42 +102,33 @@ def remove_background_wrapper(image):
     if image is None:
         raise gr.Error("Please upload an image.")
     image_ori = Image.fromarray(image).convert('RGB')
-    # Call the processing function
     foreground, background, pred_pil, reverse_mask = remove_background(image_ori)
     return foreground, background, pred_pil, reverse_mask
 
-@spaces.GPU # Decorate the processing function
+@spaces.GPU
 def remove_background(image_ori):
     original_size = image_ori.size
-
-    # Preprocess the image
     image_preprocessor = ImagePreprocessor(resolution=(1024, 1024))
     image_proc = image_preprocessor.proc(image_ori)
     image_proc = image_proc.unsqueeze(0)
-
-    # Prediction
+
     with torch.no_grad():
         preds = birefnet(image_proc.to(device))[-1].sigmoid().cpu()
     pred = preds[0].squeeze()
-
-    # Process Results
+
     pred_pil = transforms.ToPILImage()(pred)
-    pred_pil = pred_pil.resize(original_size, Image.BICUBIC) # Resize mask to original size
-
-    # Create reverse mask (background mask)
+    pred_pil = pred_pil.resize(original_size, Image.BICUBIC)
+
     reverse_mask = ImageOps.invert(pred_pil)
-
-    # Create foreground image (object with transparent background)
+
     foreground = image_ori.copy()
     foreground.putalpha(pred_pil)
-
-    # Create background image
+
     background = image_ori.copy()
     background.putalpha(reverse_mask)
-
+
     torch.cuda.empty_cache()
-
-    # Return images in the specified order
+
     return foreground, background, pred_pil, reverse_mask
 
 # Custom CSS for button styling
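Note (illustrative): on ZeroGPU Spaces, the bare `@spaces.GPU` decorator requests a GPU only for the duration of each call to the decorated function, which is why it wraps `remove_background` (the function that moves tensors to `device`) rather than the lightweight wrapper. A minimal sketch of the pattern, with a hypothetical `run_on_gpu`:

```python
import spaces
import torch

@spaces.GPU  # GPU is attached only while this call runs (ZeroGPU)
def run_on_gpu(x: torch.Tensor) -> torch.Tensor:
    # move to CUDA, compute, and bring the result back to CPU before returning
    return x.to("cuda").sigmoid().cpu()
```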
@@ -123,11 +149,12 @@ custom_css = """
     animation: gradient-animation 15s ease infinite;
     border-radius: 12px;
     color: black;
+    font-family: 'Inter', sans-serif;
 }
 """
 
-with gr.Blocks(css=custom_css, theme=gr.themes.Default()) as demo:
-    # Interface setup with input and output
+# Create the interface with the custom theme
+with gr.Blocks(css=custom_css, theme=WhiteTheme()) as demo:
     with gr.Row():
         with gr.Column():
             image_input = gr.Image(type="numpy", sources=['upload'], label="Upload Image")
@@ -138,7 +165,6 @@ with gr.Blocks(css=custom_css, theme=gr.themes.Default()) as demo:
             output_foreground_mask = gr.Image(type="pil", label="Foreground Mask")
             output_background_mask = gr.Image(type="pil", label="Background Mask")
 
-    # Link the button to the processing function
     btn.click(fn=remove_background_wrapper, inputs=image_input, outputs=[
         output_foreground, output_background, output_foreground_mask, output_background_mask])
 
 
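Not shown in this diff: the `birefnet` model called inside `remove_background` is presumably loaded elsewhere in app.py via the imported `AutoModelForImageSegmentation`. A hedged sketch of the usual BiRefNet loading pattern (the checkpoint id and flags below are assumptions, not taken from this commit):

```python
# Assumed loading pattern; not part of this commit.
from transformers import AutoModelForImageSegmentation

birefnet = AutoModelForImageSegmentation.from_pretrained(
    "ZhengPeng7/BiRefNet",   # assumed checkpoint id
    trust_remote_code=True,  # BiRefNet ships custom modeling code
)
birefnet.to(device).eval()   # `device` is defined earlier in app.py
```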