Imadsarvm commited on
Commit
4c3647c
·
verified ·
1 Parent(s): ebdf062
Files changed (1) hide show
  1. app.py +65 -68
app.py CHANGED
@@ -3,98 +3,95 @@ import torch
3
  import torch.nn.functional as F
4
  from torchvision.transforms.functional import normalize
5
  import gradio as gr
6
- from gradio_imageslider import ImageSlider
7
  from briarmbg import BriaRMBG
8
  import PIL
9
  from PIL import Image
10
  from typing import Tuple
11
-
 
12
 
13
  net = BriaRMBG.from_pretrained("briaai/RMBG-1.4")
14
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
15
  net.to(device)
16
 
17
-
18
  def resize_image(image):
19
  image = image.convert('RGB')
20
  model_input_size = (1024, 1024)
21
  image = image.resize(model_input_size, Image.BILINEAR)
22
  return image
23
 
 
 
 
 
 
 
24
 
25
- def process(image):
26
-
27
- # prepare input
28
- orig_image = Image.fromarray(image)
29
- w,h = orig_im_size = orig_image.size
30
- image = resize_image(orig_image)
31
- im_np = np.array(image)
32
- im_tensor = torch.tensor(im_np, dtype=torch.float32).permute(2,0,1)
33
- im_tensor = torch.unsqueeze(im_tensor,0)
34
- im_tensor = torch.divide(im_tensor,255.0)
35
- im_tensor = normalize(im_tensor,[0.5,0.5,0.5],[1.0,1.0,1.0])
36
- if torch.cuda.is_available():
37
- im_tensor=im_tensor.cuda()
38
-
39
- #inference
40
- result=net(im_tensor)
41
- # post process
42
- result = torch.squeeze(F.interpolate(result[0][0], size=(h,w), mode='bilinear') ,0)
43
- ma = torch.max(result)
44
- mi = torch.min(result)
45
- result = (result-mi)/(ma-mi)
46
- # image to pil
47
- im_array = (result*255).cpu().data.numpy().astype(np.uint8)
48
- pil_im = Image.fromarray(np.squeeze(im_array))
49
- # paste the mask on the original image
50
- new_im = Image.new("RGBA", pil_im.size, (0,0,0,0))
51
- new_im.paste(orig_image, mask=pil_im)
52
- # new_orig_image = orig_image.convert('RGBA')
53
-
54
- return new_im
55
- # return [new_orig_image, new_im]
56
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
57
 
58
- # block = gr.Blocks().queue()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
59
 
60
- # with block:
61
- # gr.Markdown("## BRIA RMBG 1.4")
62
- # gr.HTML('''
63
- # <p style="margin-bottom: 10px; font-size: 94%">
64
- # This is a demo for BRIA RMBG 1.4 that using
65
- # <a href="https://huggingface.co/briaai/RMBG-1.4" target="_blank">BRIA RMBG-1.4 image matting model</a> as backbone.
66
- # </p>
67
- # ''')
68
- # with gr.Row():
69
- # with gr.Column():
70
- # input_image = gr.Image(sources=None, type="pil") # None for upload, ctrl+v and webcam
71
- # # input_image = gr.Image(sources=None, type="numpy") # None for upload, ctrl+v and webcam
72
- # run_button = gr.Button(value="Run")
73
-
74
- # with gr.Column():
75
- # result_gallery = gr.Gallery(label='Output', show_label=False, elem_id="gallery", columns=[1], height='auto')
76
- # ips = [input_image]
77
- # run_button.click(fn=process, inputs=ips, outputs=[result_gallery])
78
-
79
- # block.launch(debug = True)
80
-
81
- # block = gr.Blocks().queue()
82
-
83
- gr.Markdown("## BRIA RMBG 1.4")
84
- gr.HTML('''
85
- <p style="margin-bottom: 10px; font-size: 94%">
86
- This is a demo for BRIA RMBG 1.4 that using
87
- <a href="https://huggingface.co/briaai/RMBG-1.4" target="_blank">BRIA RMBG-1.4 image matting model</a> as backbone.
88
- </p>
89
- ''')
90
  title = "Background Removal"
91
  description = r"""Background removal model developed by <a href='https://BRIA.AI' target='_blank'><b>BRIA.AI</b></a>, trained on a carefully selected dataset and is available as an open-source model for non-commercial use.<br>
92
  For test upload your image and wait. Read more at model card <a href='https://huggingface.co/briaai/RMBG-1.4' target='_blank'><b>briaai/RMBG-1.4</b></a>.<br>
93
  """
94
  examples = [['./input.jpg'],]
95
- # output = ImageSlider(position=0.5,label='Image without background', type="pil", show_download_button=True)
96
- # demo = gr.Interface(fn=process,inputs="image", outputs=output, examples=examples, title=title, description=description)
97
- demo = gr.Interface(fn=process,inputs="image", outputs="image", examples=examples, title=title, description=description)
 
 
 
 
 
 
 
 
 
98
 
99
  if __name__ == "__main__":
100
  demo.launch(share=False)
 
3
  import torch.nn.functional as F
4
  from torchvision.transforms.functional import normalize
5
  import gradio as gr
 
6
  from briarmbg import BriaRMBG
7
  import PIL
8
  from PIL import Image
9
  from typing import Tuple
10
+ import requests
11
+ from io import BytesIO
12
 
13
  net = BriaRMBG.from_pretrained("briaai/RMBG-1.4")
14
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
15
  net.to(device)
16
 
 
17
  def resize_image(image):
18
  image = image.convert('RGB')
19
  model_input_size = (1024, 1024)
20
  image = image.resize(model_input_size, Image.BILINEAR)
21
  return image
22
 
23
+ def get_url_image(url):
24
+ headers = {'User-Agent': 'gradio-app'}
25
+ response = requests.get(url, headers=headers)
26
+ print(f"Response status code: {response.status_code}")
27
+ response.raise_for_status() # Raise an error for bad status codes
28
+ return BytesIO(response.content)
29
 
30
+ def load_image(image_source):
31
+ try:
32
+ if isinstance(image_source, str): # Check if input is a URL
33
+ print(f"Loading image from URL: {image_source}")
34
+ image = Image.open(get_url_image(image_source))
35
+ else:
36
+ print("Loading image from file upload")
37
+ image = Image.fromarray(image_source)
38
+ print("Image loaded successfully")
39
+ return image
40
+ except Exception as e:
41
+ print(f"Error loading image: {e}")
42
+ raise
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
43
 
44
+ def process(image_source):
45
+ try:
46
+ print("Processing image")
47
+ # Load and prepare input
48
+ orig_image = load_image(image_source)
49
+ w, h = orig_im_size = orig_image.size
50
+ image = resize_image(orig_image)
51
+ im_np = np.array(image)
52
+ im_tensor = torch.tensor(im_np, dtype=torch.float32).permute(2, 0, 1)
53
+ im_tensor = torch.unsqueeze(im_tensor, 0)
54
+ im_tensor = torch.divide(im_tensor, 255.0)
55
+ im_tensor = normalize(im_tensor, [0.5, 0.5, 0.5], [1.0, 1.0, 1.0])
56
+ if torch.cuda.is_available():
57
+ im_tensor = im_tensor.cuda()
58
 
59
+ # Inference
60
+ result = net(im_tensor)
61
+ # Post-process
62
+ result = torch.squeeze(F.interpolate(result[0][0], size=(h, w), mode='bilinear'), 0)
63
+ ma = torch.max(result)
64
+ mi = torch.min(result)
65
+ result = (result - mi) / (ma - mi)
66
+ # Image to PIL
67
+ im_array = (result * 255).cpu().data.numpy().astype(np.uint8)
68
+ pil_im = Image.fromarray(np.squeeze(im_array))
69
+ # Paste the mask on the original image
70
+ new_im = Image.new("RGBA", pil_im.size, (0, 0, 0, 0))
71
+ new_im.paste(orig_image, mask=pil_im)
72
+ print("Image processed successfully")
73
+ return new_im
74
+ except Exception as e:
75
+ print(f"Error during processing: {e}")
76
+ return f"Error: {e}"
77
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
78
  title = "Background Removal"
79
  description = r"""Background removal model developed by <a href='https://BRIA.AI' target='_blank'><b>BRIA.AI</b></a>, trained on a carefully selected dataset and is available as an open-source model for non-commercial use.<br>
80
  For test upload your image and wait. Read more at model card <a href='https://huggingface.co/briaai/RMBG-1.4' target='_blank'><b>briaai/RMBG-1.4</b></a>.<br>
81
  """
82
  examples = [['./input.jpg'],]
83
+
84
+ demo = gr.Interface(
85
+ fn=process,
86
+ inputs=[
87
+ gr.Image(type="numpy", label="Upload Image"),
88
+ gr.Textbox(label="Image URL")
89
+ ],
90
+ outputs="image",
91
+ examples=examples,
92
+ title=title,
93
+ description=description
94
+ )
95
 
96
  if __name__ == "__main__":
97
  demo.launch(share=False)