ghostsInTheMachine commited on
Commit
cb61e6f
·
verified ·
1 Parent(s): 4adc5f9

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +5 -7
app.py CHANGED
@@ -60,16 +60,16 @@ birefnet = AutoModelForImageSegmentation.from_pretrained('zhengpeng7/BiRefNet-ma
60
  birefnet.to(device)
61
  birefnet.eval()
62
 
 
63
  def predict(image):
64
  if image is None:
65
  raise gr.Error("Please upload an image.")
66
 
67
- image_ori = Image.fromarray(image)
68
- image = image_ori.convert('RGB')
69
 
70
  # Preprocess the image
71
  image_preprocessor = ImagePreprocessor(resolution=(1024, 1024))
72
- image_proc = image_preprocessor.proc(image)
73
  image_proc = image_proc.unsqueeze(0)
74
 
75
  # Prediction
@@ -79,8 +79,8 @@ def predict(image):
79
 
80
  # Show Results
81
  pred_pil = transforms.ToPILImage()(pred)
82
- image_masked = refine_foreground(image, pred_pil)
83
- image_masked.putalpha(pred_pil.resize(image.size))
84
 
85
  torch.cuda.empty_cache()
86
 
@@ -94,8 +94,6 @@ iface = gr.Interface(
94
  fn=predict,
95
  inputs=gr.Image(type="numpy"),
96
  outputs=gr.Image(type="filepath"),
97
- title="BiRefNet Matting",
98
- description="Upload an image to perform matting using BiRefNet."
99
  )
100
 
101
  if __name__ == "__main__":
 
60
  birefnet.to(device)
61
  birefnet.eval()
62
 
63
+ @spaces.GPU
64
  def predict(image):
65
  if image is None:
66
  raise gr.Error("Please upload an image.")
67
 
68
+ image_ori = Image.fromarray(image).convert('RGB')
 
69
 
70
  # Preprocess the image
71
  image_preprocessor = ImagePreprocessor(resolution=(1024, 1024))
72
+ image_proc = image_preprocessor.proc(image_ori)
73
  image_proc = image_proc.unsqueeze(0)
74
 
75
  # Prediction
 
79
 
80
  # Show Results
81
  pred_pil = transforms.ToPILImage()(pred)
82
+ image_masked = refine_foreground(image_ori, pred_pil)
83
+ image_masked.putalpha(pred_pil.resize(image_ori.size))
84
 
85
  torch.cuda.empty_cache()
86
 
 
94
  fn=predict,
95
  inputs=gr.Image(type="numpy"),
96
  outputs=gr.Image(type="filepath"),
 
 
97
  )
98
 
99
  if __name__ == "__main__":