user-agent commited on
Commit
3d3a8e1
·
verified ·
1 Parent(s): 4775a16

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +115 -5
app.py CHANGED
@@ -1,11 +1,110 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  import gradio as gr
2
  import torch
3
  import uuid
 
4
  from PIL import Image
5
  from torchvision import transforms
6
  from transformers import AutoModelForImageSegmentation
7
  from typing import Union, List
8
  from loadimg import load_img # Your helper to load from URL or file
 
9
 
10
  torch.set_float32_matmul_precision("high")
11
 
@@ -23,6 +122,17 @@ transform_image = transforms.Compose([
23
  transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
24
  ])
25
 
 
 
 
 
 
 
 
 
 
 
 
26
  def process(image: Image.Image) -> Image.Image:
27
  image_size = image.size
28
  input_tensor = transform_image(image).unsqueeze(0).to(device)
@@ -50,20 +160,20 @@ def handler(image=None, image_url=None, batch_urls=None) -> Union[str, List[str]
50
  processed.save(filename)
51
  return filename
52
 
53
- # Single image from URL
54
  if image_url:
55
- im = load_img(image_url, output_type="pil").convert("RGB")
56
  processed = process(im)
57
  filename = f"output_{uuid.uuid4().hex[:8]}.png"
58
  processed.save(filename)
59
  return filename
60
 
61
- # Batch of URLs
62
  if batch_urls:
63
  urls = [u.strip() for u in batch_urls.split(",") if u.strip()]
64
  for url in urls:
65
  try:
66
- im = load_img(url, output_type="pil").convert("RGB")
67
  processed = process(im)
68
  filename = f"output_{uuid.uuid4().hex[:8]}.png"
69
  processed.save(filename)
@@ -91,4 +201,4 @@ demo = gr.Interface(
91
  )
92
 
93
  if __name__ == "__main__":
94
- demo.launch(show_error=True, mcp_server=True)
 
1
+ # import gradio as gr
2
+ # import torch
3
+ # import uuid
4
+ # from PIL import Image
5
+ # from torchvision import transforms
6
+ # from transformers import AutoModelForImageSegmentation
7
+ # from typing import Union, List
8
+ # from loadimg import load_img # Your helper to load from URL or file
9
+
10
+ # torch.set_float32_matmul_precision("high")
11
+
12
+ # # Load BiRefNet model
13
+ # birefnet = AutoModelForImageSegmentation.from_pretrained(
14
+ # "ZhengPeng7/BiRefNet", trust_remote_code=True
15
+ # )
16
+ # device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
17
+ # birefnet.to(device)
18
+
19
+ # # Image transformation
20
+ # transform_image = transforms.Compose([
21
+ # transforms.Resize((1024, 1024)),
22
+ # transforms.ToTensor(),
23
+ # transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
24
+ # ])
25
+
26
+ # def process(image: Image.Image) -> Image.Image:
27
+ # image_size = image.size
28
+ # input_tensor = transform_image(image).unsqueeze(0).to(device)
29
+
30
+ # with torch.no_grad():
31
+ # preds = birefnet(input_tensor)[-1].sigmoid().cpu()
32
+
33
+ # pred = preds[0].squeeze()
34
+ # mask = transforms.ToPILImage()(pred).resize(image_size).convert("L")
35
+ # binary_mask = mask.point(lambda p: 255 if p > 127 else 0)
36
+
37
+ # white_bg = Image.new("RGB", image_size, (255, 255, 255))
38
+ # result = Image.composite(image, white_bg, binary_mask)
39
+ # return result
40
+
41
+ # def handler(image=None, image_url=None, batch_urls=None) -> Union[str, List[str], None]:
42
+ # results = []
43
+
44
+ # try:
45
+ # # Single image upload
46
+ # if image is not None:
47
+ # image = image.convert("RGB")
48
+ # processed = process(image)
49
+ # filename = f"output_{uuid.uuid4().hex[:8]}.png"
50
+ # processed.save(filename)
51
+ # return filename
52
+
53
+ # # Single image from URL
54
+ # if image_url:
55
+ # im = load_img(image_url, output_type="pil").convert("RGB")
56
+ # processed = process(im)
57
+ # filename = f"output_{uuid.uuid4().hex[:8]}.png"
58
+ # processed.save(filename)
59
+ # return filename
60
+
61
+ # # Batch of URLs
62
+ # if batch_urls:
63
+ # urls = [u.strip() for u in batch_urls.split(",") if u.strip()]
64
+ # for url in urls:
65
+ # try:
66
+ # im = load_img(url, output_type="pil").convert("RGB")
67
+ # processed = process(im)
68
+ # filename = f"output_{uuid.uuid4().hex[:8]}.png"
69
+ # processed.save(filename)
70
+ # results.append(filename)
71
+ # except Exception as e:
72
+ # print(f"Error with {url}: {e}")
73
+ # return results if results else None
74
+
75
+ # except Exception as e:
76
+ # print("General error:", e)
77
+
78
+ # return None
79
+
80
+ # # Interface
81
+ # demo = gr.Interface(
82
+ # fn=handler,
83
+ # inputs=[
84
+ # gr.Image(label="Upload Image", type="pil"),
85
+ # gr.Textbox(label="Paste Image URL"),
86
+ # gr.Textbox(label="Comma-separated Image URLs (Batch)"),
87
+ # ],
88
+ # outputs=gr.File(label="Output File(s)", file_count="multiple"),
89
+ # title="Background Remover (White Fill)",
90
+ # description="Upload an image, paste a URL, or send a batch of URLs to remove the background and replace it with white.",
91
+ # )
92
+
93
+ # if __name__ == "__main__":
94
+ # demo.launch(show_error=True, mcp_server=True)
95
+
96
+
97
+
98
  import gradio as gr
99
  import torch
100
  import uuid
101
+ import base64
102
  from PIL import Image
103
  from torchvision import transforms
104
  from transformers import AutoModelForImageSegmentation
105
  from typing import Union, List
106
  from loadimg import load_img # Your helper to load from URL or file
107
+ from io import BytesIO
108
 
109
  torch.set_float32_matmul_precision("high")
110
 
 
122
  transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
123
  ])
124
 
125
+ def load_image_from_data_url(data_url: str) -> Image.Image:
126
+ """Load image from base64 data URL"""
127
+ if data_url.startswith("data:image/"):
128
+ # Extract base64 data after the comma
129
+ header, encoded = data_url.split(",", 1)
130
+ image_data = base64.b64decode(encoded)
131
+ return Image.open(BytesIO(image_data))
132
+ else:
133
+ # Regular URL, use existing load_img function
134
+ return load_img(data_url, output_type="pil")
135
+
136
  def process(image: Image.Image) -> Image.Image:
137
  image_size = image.size
138
  input_tensor = transform_image(image).unsqueeze(0).to(device)
 
160
  processed.save(filename)
161
  return filename
162
 
163
+ # Single image from URL (supports both regular URLs and base64 data URLs)
164
  if image_url:
165
+ im = load_image_from_data_url(image_url).convert("RGB")
166
  processed = process(im)
167
  filename = f"output_{uuid.uuid4().hex[:8]}.png"
168
  processed.save(filename)
169
  return filename
170
 
171
+ # Batch of URLs (supports both regular URLs and base64 data URLs)
172
  if batch_urls:
173
  urls = [u.strip() for u in batch_urls.split(",") if u.strip()]
174
  for url in urls:
175
  try:
176
+ im = load_image_from_data_url(url).convert("RGB")
177
  processed = process(im)
178
  filename = f"output_{uuid.uuid4().hex[:8]}.png"
179
  processed.save(filename)
 
201
  )
202
 
203
  if __name__ == "__main__":
204
+ demo.launch(show_error=True, mcp_server=True)