DrChamyoung commited on
Commit
4548a80
·
verified ·
1 Parent(s): 16bb0cc

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +7 -10
app.py CHANGED
@@ -1,17 +1,18 @@
1
  import gradio as gr
2
  from gradio_imageslider import ImageSlider
3
  from loadimg import load_img
4
- import spaces
5
  from transformers import AutoModelForImageSegmentation
6
  import torch
7
  from torchvision import transforms
8
 
9
- torch.set_float32_matmul_precision(["high", "highest"][0])
 
10
 
11
  birefnet = AutoModelForImageSegmentation.from_pretrained(
12
  "ZhengPeng7/BiRefNet", trust_remote_code=True
13
  )
14
- birefnet.to("cuda")
 
15
  transform_image = transforms.Compose(
16
  [
17
  transforms.Resize((1024, 1024)),
@@ -20,15 +21,14 @@ transform_image = transforms.Compose(
20
  ]
21
  )
22
 
23
-
24
- @spaces.GPU
25
  def fn(image):
26
  im = load_img(image, output_type="pil")
27
  im = im.convert("RGB")
28
  image_size = im.size
29
  origin = im.copy()
30
  image = load_img(im)
31
- input_images = transform_image(image).unsqueeze(0).to("cuda")
 
32
  # Prediction
33
  with torch.no_grad():
34
  preds = birefnet(input_images)[-1].sigmoid().cpu()
@@ -38,13 +38,11 @@ def fn(image):
38
  image.putalpha(mask)
39
  return (image, origin)
40
 
41
-
42
  slider1 = ImageSlider(label="birefnet", type="pil")
43
  slider2 = ImageSlider(label="birefnet", type="pil")
44
  image = gr.Image(label="Upload an image")
45
  text = gr.Textbox(label="Paste an image URL")
46
 
47
-
48
  chameleon = load_img("chameleon.jpg", output_type="pil")
49
 
50
  url = "https://hips.hearstapps.com/hmg-prod/images/gettyimages-1229892983-square.jpg"
@@ -54,10 +52,9 @@ tab1 = gr.Interface(
54
 
55
  tab2 = gr.Interface(fn, inputs=text, outputs=slider2, examples=[url], api_name="text")
56
 
57
-
58
  demo = gr.TabbedInterface(
59
  [tab1, tab2], ["image", "text"], title="birefnet for background removal"
60
  )
61
 
62
  if __name__ == "__main__":
63
- demo.launch()
 
1
  import gradio as gr
2
  from gradio_imageslider import ImageSlider
3
  from loadimg import load_img
 
4
  from transformers import AutoModelForImageSegmentation
5
  import torch
6
  from torchvision import transforms
7
 
8
+ # Remove this line as it is GPU specific
9
+ # torch.set_float32_matmul_precision(["high", "highest"][0])
10
 
11
  birefnet = AutoModelForImageSegmentation.from_pretrained(
12
  "ZhengPeng7/BiRefNet", trust_remote_code=True
13
  )
14
+ birefnet.to("cpu") # Change to CPU
15
+
16
  transform_image = transforms.Compose(
17
  [
18
  transforms.Resize((1024, 1024)),
 
21
  ]
22
  )
23
 
 
 
24
  def fn(image):
25
  im = load_img(image, output_type="pil")
26
  im = im.convert("RGB")
27
  image_size = im.size
28
  origin = im.copy()
29
  image = load_img(im)
30
+ input_images = transform_image(image).unsqueeze(0).to("cpu") # Change to CPU
31
+
32
  # Prediction
33
  with torch.no_grad():
34
  preds = birefnet(input_images)[-1].sigmoid().cpu()
 
38
  image.putalpha(mask)
39
  return (image, origin)
40
 
 
41
  slider1 = ImageSlider(label="birefnet", type="pil")
42
  slider2 = ImageSlider(label="birefnet", type="pil")
43
  image = gr.Image(label="Upload an image")
44
  text = gr.Textbox(label="Paste an image URL")
45
 
 
46
  chameleon = load_img("chameleon.jpg", output_type="pil")
47
 
48
  url = "https://hips.hearstapps.com/hmg-prod/images/gettyimages-1229892983-square.jpg"
 
52
 
53
  tab2 = gr.Interface(fn, inputs=text, outputs=slider2, examples=[url], api_name="text")
54
 
 
55
  demo = gr.TabbedInterface(
56
  [tab1, tab2], ["image", "text"], title="birefnet for background removal"
57
  )
58
 
59
  if __name__ == "__main__":
60
+ demo.launch()