sidbhasin commited on
Commit
ba2ffd5
1 Parent(s): 8fcbb28

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +19 -16
app.py CHANGED
@@ -5,33 +5,36 @@ import numpy as np
5
  from PIL import Image
6
  import io
7
 
 
 
 
 
 
 
 
 
 
 
 
8
  def remove_background(input_image):
9
  try:
10
- # Initialize the pipeline with correct parameters
11
  segmentor = pipeline(
12
- task="image-segmentation",
13
  model="briaai/RMBG-1.4",
14
- trust_remote_code=True
 
 
15
  )
16
 
17
- # Convert input to PIL Image if it's not already
18
- if not isinstance(input_image, Image.Image):
19
- input_image = Image.fromarray(input_image)
20
-
21
  # Process the image
22
- result = segmentor(input_image, return_mask=True)
23
-
24
- # Create transparent background image
25
- output_image = Image.new('RGBA', input_image.size, (0, 0, 0, 0))
26
- output_image.paste(input_image, mask=result)
27
-
28
- return output_image
29
-
30
  except Exception as e:
31
  raise gr.Error(f"Error processing image: {str(e)}")
32
 
33
  # Create Gradio interface
34
- with gr.Blocks(theme=gr.themes.Default()) as demo:
35
  gr.HTML(
36
  """
37
  <div style="text-align: center; max-width: 800px; margin: 0 auto; padding: 20px;">
 
5
  from PIL import Image
6
  import io
7
 
8
+ # First, ensure all required dependencies are installed
9
+ try:
10
+ import torchvision
11
+ import skimage
12
+ except ImportError:
13
+ print("Installing required packages...")
14
+ import subprocess
15
+ subprocess.check_call(["pip", "install", "torchvision", "scikit-image"])
16
+ import torchvision
17
+ import skimage
18
+
19
  def remove_background(input_image):
20
  try:
21
+ # Initialize the pipeline with correct parameters and dependencies
22
  segmentor = pipeline(
23
+ "image-segmentation",
24
  model="briaai/RMBG-1.4",
25
+ trust_remote_code=True,
26
+ device="cpu",
27
+ framework="pt"
28
  )
29
 
 
 
 
 
30
  # Process the image
31
+ result = segmentor(input_image)
32
+ return result['output_image']
 
 
 
 
 
 
33
  except Exception as e:
34
  raise gr.Error(f"Error processing image: {str(e)}")
35
 
36
  # Create Gradio interface
37
+ with gr.Blocks() as demo:
38
  gr.HTML(
39
  """
40
  <div style="text-align: center; max-width: 800px; margin: 0 auto; padding: 20px;">