Imadsarvm commited on
Commit
04ba376
·
verified ·
1 Parent(s): 5682b91
Files changed (1) hide show
  1. app.py +37 -27
app.py CHANGED
@@ -20,42 +20,52 @@ def resize_image(image):
20
  image = image.resize(model_input_size, Image.BILINEAR)
21
  return image
22
 
 
 
 
 
 
23
  def load_image(image_source):
24
  if isinstance(image_source, str): # Check if input is a URL
25
- response = requests.get(image_source)
26
- image = Image.open(BytesIO(response.content))
27
  else:
 
28
  image = Image.fromarray(image_source)
29
  return image
30
 
31
  def process(image_source):
32
- # Load and prepare input
33
- orig_image = load_image(image_source)
34
- w, h = orig_im_size = orig_image.size
35
- image = resize_image(orig_image)
36
- im_np = np.array(image)
37
- im_tensor = torch.tensor(im_np, dtype=torch.float32).permute(2, 0, 1)
38
- im_tensor = torch.unsqueeze(im_tensor, 0)
39
- im_tensor = torch.divide(im_tensor, 255.0)
40
- im_tensor = normalize(im_tensor, [0.5, 0.5, 0.5], [1.0, 1.0, 1.0])
41
- if torch.cuda.is_available():
42
- im_tensor = im_tensor.cuda()
 
43
 
44
- # Inference
45
- result = net(im_tensor)
46
- # Post-process
47
- result = torch.squeeze(F.interpolate(result[0][0], size=(h, w), mode='bilinear'), 0)
48
- ma = torch.max(result)
49
- mi = torch.min(result)
50
- result = (result - mi) / (ma - mi)
51
- # Image to PIL
52
- im_array = (result * 255).cpu().data.numpy().astype(np.uint8)
53
- pil_im = Image.fromarray(np.squeeze(im_array))
54
- # Paste the mask on the original image
55
- new_im = Image.new("RGBA", pil_im.size, (0, 0, 0, 0))
56
- new_im.paste(orig_image, mask=pil_im)
57
 
58
- return new_im
 
 
 
59
 
60
  title = "Background Removal"
61
  description = r"""Background removal model developed by <a href='https://BRIA.AI' target='_blank'><b>BRIA.AI</b></a>, trained on a carefully selected dataset and is available as an open-source model for non-commercial use.<br>
 
20
  image = image.resize(model_input_size, Image.BILINEAR)
21
  return image
22
 
23
+ def get_url_image(url):
24
+ headers = {'User-Agent': 'gradio-app'}
25
+ response = requests.get(url, headers=headers)
26
+ return BytesIO(response.content)
27
+
28
  def load_image(image_source):
29
  if isinstance(image_source, str): # Check if input is a URL
30
+ print(f"Loading image from URL: {image_source}")
31
+ image = Image.open(get_url_image(image_source))
32
  else:
33
+ print("Loading image from file upload")
34
  image = Image.fromarray(image_source)
35
  return image
36
 
37
  def process(image_source):
38
+ try:
39
+ # Load and prepare input
40
+ orig_image = load_image(image_source)
41
+ w, h = orig_im_size = orig_image.size
42
+ image = resize_image(orig_image)
43
+ im_np = np.array(image)
44
+ im_tensor = torch.tensor(im_np, dtype=torch.float32).permute(2, 0, 1)
45
+ im_tensor = torch.unsqueeze(im_tensor, 0)
46
+ im_tensor = torch.divide(im_tensor, 255.0)
47
+ im_tensor = normalize(im_tensor, [0.5, 0.5, 0.5], [1.0, 1.0, 1.0])
48
+ if torch.cuda.is_available():
49
+ im_tensor = im_tensor.cuda()
50
 
51
+ # Inference
52
+ result = net(im_tensor)
53
+ # Post-process
54
+ result = torch.squeeze(F.interpolate(result[0][0], size=(h, w), mode='bilinear'), 0)
55
+ ma = torch.max(result)
56
+ mi = torch.min(result)
57
+ result = (result - mi) / (ma - mi)
58
+ # Image to PIL
59
+ im_array = (result * 255).cpu().data.numpy().astype(np.uint8)
60
+ pil_im = Image.fromarray(np.squeeze(im_array))
61
+ # Paste the mask on the original image
62
+ new_im = Image.new("RGBA", pil_im.size, (0, 0, 0, 0))
63
+ new_im.paste(orig_image, mask=pil_im)
64
 
65
+ return new_im
66
+ except Exception as e:
67
+ print(f"Error during processing: {e}")
68
+ return None
69
 
70
  title = "Background Removal"
71
  description = r"""Background removal model developed by <a href='https://BRIA.AI' target='_blank'><b>BRIA.AI</b></a>, trained on a carefully selected dataset and is available as an open-source model for non-commercial use.<br>