arxivgpt kim committed
Commit a119d24 · verified · 1 Parent(s): a6a92bb

Update app.py

Files changed (1)
  1. app.py +19 -13
app.py CHANGED
@@ -4,8 +4,11 @@ import torch.nn.functional as F
 from torchvision.transforms.functional import normalize
 from huggingface_hub import hf_hub_download
 import gradio as gr
+from gradio_imageslider import ImageSlider
 from briarmbg import BriaRMBG
+import PIL
 from PIL import Image
+from typing import Tuple
 
 # Initialize and load the model
 net = BriaRMBG()
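The hunk above ends before the lines that actually fetch and load the RMBG checkpoint, so they are not part of this change. For orientation, a minimal sketch of the usual loading pattern with hf_hub_download; the repo id and filename are assumptions, not values read from app.py:

import torch
from huggingface_hub import hf_hub_download

# Assumed checkpoint location; the real repo_id/filename sit outside this diff's context.
model_path = hf_hub_download(repo_id="briaai/RMBG-1.4", filename="model.pth")
net.load_state_dict(torch.load(model_path, map_location="cpu"))
net.eval()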
@@ -24,7 +27,7 @@ def resize_image(image, model_input_size=(1024, 1024)):
 
 def process(image, background_image=None):
     # Prepare the image
-    orig_image = Image.fromarray(image).convert("RGBA")
+    orig_image = Image.fromarray(image).convert("RGB")
     w, h = orig_im_size = orig_image.size
     image = resize_image(orig_image)
     im_np = np.array(image)
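resize_image itself is visible here only through the hunk header, so its body is not touched by this commit. A minimal sketch consistent with the signature shown, offered as an assumption about what it does rather than the file's actual implementation:

def resize_image(image, model_input_size=(1024, 1024)):
    # Hypothetical sketch: scale the PIL image to the model's expected input size.
    image = image.convert('RGB')
    return image.resize(model_input_size, Image.BILINEAR)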
@@ -37,30 +40,33 @@ def process(image, background_image=None):
     with torch.no_grad():
         result = net(im_tensor)
 
-    # Post-processing
+    # Post-processing
     result = torch.squeeze(F.interpolate(result[0][0], size=(h, w), mode='bilinear', align_corners=False), 0)
     result = torch.sigmoid(result)
-    mask = (result * 255).byte().cpu().numpy()  # Convert the mask to values between 0 and 255
+    mask = (result * 255).byte().cpu().numpy()
 
-    # Check that the mask array is 2-D as expected, and adjust it if it is not.
+    # Check that the mask array is 2-D as expected; adjust if not
     if mask.ndim > 2:
-        mask = mask.squeeze()  # Reduce dimensions
+        mask = mask.squeeze()
 
-    # Explicitly convert the mask array to uint8.
+    # Explicitly convert the mask array to uint8
     mask = mask.astype(np.uint8)
 
-    # Convert the mask to a PIL image
-    mask_image = Image.fromarray(mask, 'L')  # 'L' mode represents a grayscale image.
-
-    final_image = Image.new("RGBA", orig_image.size)
-    final_image.paste(orig_image, mask=mask_image)
+    # Create the final image using the mask as the alpha channel
+    orig_image = orig_image.convert("RGBA")
+    final_image = Image.new("RGBA", orig_image.size, (0, 0, 0, 0))
+    mask_image = Image.fromarray(mask, mode='L')
+    foreground_image = Image.composite(orig_image, final_image, mask_image)
 
     # Optional background image handling
-    if background_image is not None:
-        final_image = merge_images(background_image, final_image)
+    if background_image:
+        final_image = merge_images(background_image, foreground_image)
+    else:
+        final_image = foreground_image
 
     return final_image
 
+
 def merge_images(background_image, foreground_image):
     """
     Transparently pastes the background-removed image onto the background image.
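The substantive change in this hunk is how the cutout is built: instead of pasting the original onto a new RGBA canvas with paste(..., mask=...), the new code calls Image.composite, which takes each pixel from the RGBA original where the 'L'-mode mask is high and from a fully transparent canvas where it is low, so the sigmoid mask effectively becomes the alpha channel of the result. A self-contained sketch of the same idea (names are illustrative, not taken from app.py):

import numpy as np
from PIL import Image

def cutout_with_mask(orig_rgb, mask):
    # mask: 2-D uint8 array with the same height/width as orig_rgb;
    # 0 = background, 255 = foreground.
    rgba = orig_rgb.convert("RGBA")
    transparent = Image.new("RGBA", rgba.size, (0, 0, 0, 0))
    alpha = Image.fromarray(mask.astype(np.uint8), mode="L")
    # Per pixel: take rgba where alpha is high, the transparent canvas where it is low.
    return Image.composite(rgba, transparent, alpha)

The body of merge_images is cut off by the diff context; its docstring suggests it alpha-composites the cutout onto the supplied background (for example via Image.alpha_composite after resizing the background to the foreground's size), but that is an assumption rather than the function's actual code.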