user-agent commited on
Commit
c93abb9
·
verified ·
1 Parent(s): 1a1c89b

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +17 -32
app.py CHANGED
@@ -6,23 +6,22 @@ from PIL import Image
6
  from torchvision import transforms
7
  from transformers import AutoModelForImageSegmentation
8
  from typing import Union, List
9
- from loadimg import load_img
10
 
11
  torch.set_float32_matmul_precision("high")
12
 
13
- # Load RMBG v1.4 model
14
- model = AutoModelForImageSegmentation.from_pretrained(
15
- "briaai/RMBG-1.4",
16
- trust_remote_code=True
17
  )
18
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
19
- model.to(device)
20
 
21
- # Transform for RMBG v1.4
22
  transform_image = transforms.Compose([
23
  transforms.Resize((1024, 1024)),
24
  transforms.ToTensor(),
25
- transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
26
  ])
27
 
28
  @spaces.GPU
@@ -31,28 +30,12 @@ def process(image: Image.Image) -> Image.Image:
31
  input_tensor = transform_image(image).unsqueeze(0).to(device)
32
 
33
  with torch.no_grad():
34
- preds = model(input_tensor)
35
-
36
- # Handle list output - extract the tensor from the list
37
- if isinstance(preds, list):
38
- # Usually the mask is the last or first element
39
- pred = preds[-1] if len(preds) > 0 else preds[0]
40
- elif isinstance(preds, tuple):
41
- pred = preds[0]
42
- else:
43
- pred = preds
44
-
45
- # Now apply sigmoid to the tensor
46
- mask = pred.sigmoid().cpu()
47
 
48
- # Process the mask
49
- mask_tensor = mask[0].squeeze()
50
- mask_pil = transforms.ToPILImage()(mask_tensor).resize(image_size).convert("L")
51
-
52
- # Create binary mask with threshold
53
- binary_mask = mask_pil.point(lambda p: 255 if p > 127 else 0)
54
 
55
- # Apply mask with white background
56
  white_bg = Image.new("RGB", image_size, (255, 255, 255))
57
  result = Image.composite(image, white_bg, binary_mask)
58
  return result
@@ -62,6 +45,7 @@ def handler(image=None, image_url=None, batch_urls=None) -> Union[str, List[str]
62
  results = []
63
 
64
  try:
 
65
  if image is not None:
66
  image = image.convert("RGB")
67
  processed = process(image)
@@ -69,6 +53,7 @@ def handler(image=None, image_url=None, batch_urls=None) -> Union[str, List[str]
69
  processed.save(filename)
70
  return filename
71
 
 
72
  if image_url:
73
  im = load_img(image_url, output_type="pil").convert("RGB")
74
  processed = process(im)
@@ -76,6 +61,7 @@ def handler(image=None, image_url=None, batch_urls=None) -> Union[str, List[str]
76
  processed.save(filename)
77
  return filename
78
 
 
79
  if batch_urls:
80
  urls = [u.strip() for u in batch_urls.split(",") if u.strip()]
81
  for url in urls:
@@ -91,11 +77,10 @@ def handler(image=None, image_url=None, batch_urls=None) -> Union[str, List[str]
91
 
92
  except Exception as e:
93
  print("General error:", e)
94
- import traceback
95
- traceback.print_exc()
96
 
97
  return None
98
 
 
99
  demo = gr.Interface(
100
  fn=handler,
101
  inputs=[
@@ -104,9 +89,9 @@ demo = gr.Interface(
104
  gr.Textbox(label="Comma-separated Image URLs (Batch)"),
105
  ],
106
  outputs=gr.File(label="Output File(s)", file_count="multiple"),
107
- title="Background Remover (RMBG v1.4)",
108
  description="Upload an image, paste a URL, or send a batch of URLs to remove the background and replace it with white.",
109
  )
110
 
111
  if __name__ == "__main__":
112
- demo.launch(show_error=True, mcp_server=True)
 
6
  from torchvision import transforms
7
  from transformers import AutoModelForImageSegmentation
8
  from typing import Union, List
9
+ from loadimg import load_img # Your helper to load from URL or file
10
 
11
  torch.set_float32_matmul_precision("high")
12
 
13
+ # Load BiRefNet model
14
+ birefnet = AutoModelForImageSegmentation.from_pretrained(
15
+ "ZhengPeng7/BiRefNet", trust_remote_code=True
 
16
  )
17
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
18
+ birefnet.to(device)
19
 
20
+ # Image transformation
21
  transform_image = transforms.Compose([
22
  transforms.Resize((1024, 1024)),
23
  transforms.ToTensor(),
24
+ transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
25
  ])
26
 
27
  @spaces.GPU
 
30
  input_tensor = transform_image(image).unsqueeze(0).to(device)
31
 
32
  with torch.no_grad():
33
+ preds = birefnet(input_tensor)[-1].sigmoid().cpu()
 
 
 
 
 
 
 
 
 
 
 
 
34
 
35
+ pred = preds[0].squeeze()
36
+ mask = transforms.ToPILImage()(pred).resize(image_size).convert("L")
37
+ binary_mask = mask.point(lambda p: 255 if p > 127 else 0)
 
 
 
38
 
 
39
  white_bg = Image.new("RGB", image_size, (255, 255, 255))
40
  result = Image.composite(image, white_bg, binary_mask)
41
  return result
 
45
  results = []
46
 
47
  try:
48
+ # Single image upload
49
  if image is not None:
50
  image = image.convert("RGB")
51
  processed = process(image)
 
53
  processed.save(filename)
54
  return filename
55
 
56
+ # Single image from URL
57
  if image_url:
58
  im = load_img(image_url, output_type="pil").convert("RGB")
59
  processed = process(im)
 
61
  processed.save(filename)
62
  return filename
63
 
64
+ # Batch of URLs
65
  if batch_urls:
66
  urls = [u.strip() for u in batch_urls.split(",") if u.strip()]
67
  for url in urls:
 
77
 
78
  except Exception as e:
79
  print("General error:", e)
 
 
80
 
81
  return None
82
 
83
+ # Interface
84
  demo = gr.Interface(
85
  fn=handler,
86
  inputs=[
 
89
  gr.Textbox(label="Comma-separated Image URLs (Batch)"),
90
  ],
91
  outputs=gr.File(label="Output File(s)", file_count="multiple"),
92
+ title="Background Remover (White Fill)",
93
  description="Upload an image, paste a URL, or send a batch of URLs to remove the background and replace it with white.",
94
  )
95
 
96
  if __name__ == "__main__":
97
+ demo.launch(show_error=True, mcp_server=True)