sidbhasin commited on
Commit
e571436
·
verified ·
1 Parent(s): 7e7a7aa

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +26 -19
app.py CHANGED
@@ -5,31 +5,38 @@ import numpy as np
5
  from PIL import Image
6
  import io
7
 
8
- # First, ensure all required dependencies are installed
9
- try:
10
- import torchvision
11
- import skimage
12
- except ImportError:
13
- print("Installing required packages...")
14
- import subprocess
15
- subprocess.check_call(["pip", "install", "torchvision", "scikit-image"])
16
- import torchvision
17
- import skimage
18
-
19
  def remove_background(input_image):
20
  try:
21
- # Initialize the pipeline with correct parameters and dependencies
 
 
 
 
22
  segmentor = pipeline(
23
- "image-segmentation",
24
  model="briaai/RMBG-1.4",
25
- trust_remote_code=True,
26
- device="cpu",
27
- framework="pt"
 
 
 
 
28
  )
29
 
30
- # Process the image
31
- result = segmentor(input_image)
32
- return result['output_image']
 
 
 
 
 
 
 
 
 
 
33
  except Exception as e:
34
  raise gr.Error(f"Error processing image: {str(e)}")
35
 
 
5
  from PIL import Image
6
  import io
7
 
 
 
 
 
 
 
 
 
 
 
 
8
  def remove_background(input_image):
9
  try:
10
+ # Convert input to PIL Image if it's not already
11
+ if not isinstance(input_image, Image.Image):
12
+ input_image = Image.fromarray(input_image)
13
+
14
+ # Initialize the pipeline
15
  segmentor = pipeline(
16
+ task="image-segmentation",
17
  model="briaai/RMBG-1.4",
18
+ trust_remote_code=True
19
+ )
20
+
21
+ # Process the image and get mask
22
+ result = segmentor(
23
+ input_image,
24
+ return_mask=True
25
  )
26
 
27
+ # Create output image with transparent background
28
+ output_image = Image.new('RGBA', input_image.size, (0, 0, 0, 0))
29
+
30
+ # Convert input to RGBA if it's not already
31
+ if input_image.mode != 'RGBA':
32
+ input_image = input_image.convert('RGBA')
33
+
34
+ # Apply mask to create transparent background
35
+ mask = result['mask'] if isinstance(result, dict) else result
36
+ output_image.paste(input_image, mask=mask)
37
+
38
+ return output_image
39
+
40
  except Exception as e:
41
  raise gr.Error(f"Error processing image: {str(e)}")
42