user-agent commited on
Commit
3fcc660
·
verified ·
1 Parent(s): ff4155e

created app.py

Browse files
Files changed (1) hide show
  1. app.py +93 -0
app.py ADDED
@@ -0,0 +1,93 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import torch
3
+ import uuid
4
+ from PIL import Image
5
+ from torchvision import transforms
6
+ from transformers import AutoModelForImageSegmentation
7
+ from typing import Union, List
8
+ from loadimg import load_img # Your helper to load from URL or file
9
+
10
+ torch.set_float32_matmul_precision("high")
11
+
12
+ # Load BiRefNet model
13
+ birefnet = AutoModelForImageSegmentation.from_pretrained(
14
+ "ZhengPeng7/BiRefNet", trust_remote_code=True
15
+ )
16
+ birefnet.to("cuda")
17
+
18
+ # Image transformation
19
+ transform_image = transforms.Compose([
20
+ transforms.Resize((1024, 1024)),
21
+ transforms.ToTensor(),
22
+ transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
23
+ ])
24
+
25
+ def process(image: Image.Image) -> Image.Image:
26
+ image_size = image.size
27
+ input_tensor = transform_image(image).unsqueeze(0).to("cuda")
28
+
29
+ with torch.no_grad():
30
+ preds = birefnet(input_tensor)[-1].sigmoid().cpu()
31
+
32
+ pred = preds[0].squeeze()
33
+ mask = transforms.ToPILImage()(pred).resize(image_size).convert("L")
34
+ binary_mask = mask.point(lambda p: 255 if p > 127 else 0)
35
+
36
+ white_bg = Image.new("RGB", image_size, (255, 255, 255))
37
+ result = Image.composite(image, white_bg, binary_mask)
38
+ return result
39
+
40
+ def handler(image=None, image_url=None, batch_urls=None) -> Union[str, List[str], None]:
41
+ results = []
42
+
43
+ try:
44
+ # Single image upload
45
+ if image is not None:
46
+ image = image.convert("RGB")
47
+ processed = process(image)
48
+ filename = f"output_{uuid.uuid4().hex[:8]}.png"
49
+ processed.save(filename)
50
+ return filename
51
+
52
+ # Single image from URL
53
+ if image_url:
54
+ im = load_img(image_url, output_type="pil").convert("RGB")
55
+ processed = process(im)
56
+ filename = f"output_{uuid.uuid4().hex[:8]}.png"
57
+ processed.save(filename)
58
+ return filename
59
+
60
+ # Batch of URLs
61
+ if batch_urls:
62
+ urls = [u.strip() for u in batch_urls.split(",") if u.strip()]
63
+ for url in urls:
64
+ try:
65
+ im = load_img(url, output_type="pil").convert("RGB")
66
+ processed = process(im)
67
+ filename = f"output_{uuid.uuid4().hex[:8]}.png"
68
+ processed.save(filename)
69
+ results.append(filename)
70
+ except Exception as e:
71
+ print(f"Error with {url}: {e}")
72
+ return results if results else None
73
+
74
+ except Exception as e:
75
+ print("General error:", e)
76
+
77
+ return None
78
+
79
+ # Interface
80
+ demo = gr.Interface(
81
+ fn=handler,
82
+ inputs=[
83
+ gr.Image(label="Upload Image", type="pil", optional=True),
84
+ gr.Textbox(label="Paste Image URL", optional=True),
85
+ gr.Textbox(label="Comma-separated Image URLs (Batch)", optional=True),
86
+ ],
87
+ outputs=gr.File(label="Output File(s)", file_count="multiple"),
88
+ title="Background Remover (White Fill)",
89
+ description="Upload an image, paste a URL, or send a batch of URLs to remove the background and replace it with white.",
90
+ )
91
+
92
+ if __name__ == "__main__":
93
+ demo.launch(show_error=True, mcp_server=True)