ghostsInTheMachine commited on
Commit
c36a9d3
·
verified ·
1 Parent(s): eeef7f4

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +12 -6
app.py CHANGED
@@ -3,7 +3,7 @@ import cv2
3
  import numpy as np
4
  import torch
5
  import gradio as gr
6
- import spaces # Added import for spaces
7
 
8
  from PIL import Image, ImageOps
9
  from transformers import AutoModelForImageSegmentation
@@ -49,24 +49,30 @@ class ImagePreprocessor():
49
  self.transform_image = transforms.Compose([
50
  transforms.Resize(resolution),
51
  transforms.ToTensor(),
52
- transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
 
53
  ])
54
 
55
  def proc(self, image: Image.Image) -> torch.Tensor:
56
  image = self.transform_image(image)
57
  return image
58
 
 
59
  birefnet = AutoModelForImageSegmentation.from_pretrained(
60
  'zhengpeng7/BiRefNet-matting', trust_remote_code=True)
61
  birefnet.to(device)
62
  birefnet.eval()
63
 
64
- @spaces.GPU # Added the @spaces.GPU decorator
65
- def remove_background(image):
66
  if image is None:
67
  raise gr.Error("Please upload an image.")
68
-
69
  image_ori = Image.fromarray(image).convert('RGB')
 
 
 
 
 
 
70
  original_size = image_ori.size
71
 
72
  # Preprocess the image
@@ -100,7 +106,7 @@ def remove_background(image):
100
  return foreground, background, pred_pil, reverse_mask
101
 
102
  iface = gr.Interface(
103
- fn=remove_background,
104
  inputs=gr.Image(type="numpy"),
105
  outputs=[
106
  gr.Image(type="pil", label="Foreground"),
 
3
  import numpy as np
4
  import torch
5
  import gradio as gr
6
+ import spaces # Required for @spaces.GPU
7
 
8
  from PIL import Image, ImageOps
9
  from transformers import AutoModelForImageSegmentation
 
49
  self.transform_image = transforms.Compose([
50
  transforms.Resize(resolution),
51
  transforms.ToTensor(),
52
+ transforms.Normalize([0.485, 0.456, 0.406],
53
+ [0.229, 0.224, 0.225]),
54
  ])
55
 
56
  def proc(self, image: Image.Image) -> torch.Tensor:
57
  image = self.transform_image(image)
58
  return image
59
 
60
+ # Load the model
61
  birefnet = AutoModelForImageSegmentation.from_pretrained(
62
  'zhengpeng7/BiRefNet-matting', trust_remote_code=True)
63
  birefnet.to(device)
64
  birefnet.eval()
65
 
66
+ def remove_background_wrapper(image):
 
67
  if image is None:
68
  raise gr.Error("Please upload an image.")
 
69
  image_ori = Image.fromarray(image).convert('RGB')
70
+ # Call the processing function
71
+ foreground, background, pred_pil, reverse_mask = remove_background(image_ori)
72
+ return foreground, background, pred_pil, reverse_mask
73
+
74
+ @spaces.GPU # Decorate the processing function
75
+ def remove_background(image_ori):
76
  original_size = image_ori.size
77
 
78
  # Preprocess the image
 
106
  return foreground, background, pred_pil, reverse_mask
107
 
108
  iface = gr.Interface(
109
+ fn=remove_background_wrapper,
110
  inputs=gr.Image(type="numpy"),
111
  outputs=[
112
  gr.Image(type="pil", label="Foreground"),