Spaces:
fantos
/
Runtime error

arxivgpt kim commited on
Commit
fc9e8ce
·
verified ·
1 Parent(s): b8f4eff

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +37 -119
app.py CHANGED
@@ -4,158 +4,76 @@ import torch.nn.functional as F
4
  from torchvision.transforms.functional import normalize
5
  from huggingface_hub import hf_hub_download
6
  import gradio as gr
7
- from gradio_imageslider import ImageSlider
8
  from briarmbg import BriaRMBG
9
- import PIL
10
  from PIL import Image
11
- from typing import Tuple
12
 
13
- net=BriaRMBG()
14
- # model_path = "./model1.pth"
15
  model_path = hf_hub_download("briaai/RMBG-1.4", 'model.pth')
16
  if torch.cuda.is_available():
17
  net.load_state_dict(torch.load(model_path))
18
- net=net.cuda()
19
  else:
20
- net.load_state_dict(torch.load(model_path,map_location="cpu"))
21
- net.eval()
22
 
23
-
24
- def resize_image(image):
25
  image = image.convert('RGB')
26
- model_input_size = (1024, 1024)
27
  image = image.resize(model_input_size, Image.BILINEAR)
28
  return image
29
 
30
-
31
- def process(image):
32
-
33
- # prepare input
34
- orig_image = Image.fromarray(image)
35
- w,h = orig_im_size = orig_image.size
36
- image = resize_image(orig_image)
37
- im_np = np.array(image)
38
- im_tensor = torch.tensor(im_np, dtype=torch.float32).permute(2,0,1)
39
- im_tensor = torch.unsqueeze(im_tensor,0)
40
- im_tensor = torch.divide(im_tensor,255.0)
41
- im_tensor = normalize(im_tensor,[0.5,0.5,0.5],[1.0,1.0,1.0])
42
  if torch.cuda.is_available():
43
- im_tensor=im_tensor.cuda()
44
-
45
- #inference
46
- result=net(im_tensor)
47
- # post process
48
- result = torch.squeeze(F.interpolate(result[0][0], size=(h,w), mode='bilinear') ,0)
49
- ma = torch.max(result)
50
- mi = torch.min(result)
51
- result = (result-mi)/(ma-mi)
52
- # image to pil
53
- im_array = (result*255).cpu().data.numpy().astype(np.uint8)
54
- pil_im = Image.fromarray(np.squeeze(im_array))
55
- # paste the mask on the original image
56
- new_im = Image.new("RGBA", pil_im.size, (0,0,0,0))
57
- new_im.paste(orig_image, mask=pil_im)
58
- # new_orig_image = orig_image.convert('RGBA')
59
-
60
- return new_im
61
- # return [new_orig_image, new_im]
62
-
63
-
64
- # block = gr.Blocks().queue()
65
-
66
- # with block:
67
- # gr.Markdown("## BRIA RMBG 1.4")
68
- # gr.HTML('''
69
- # <p style="margin-bottom: 10px; font-size: 94%">
70
- # This is a demo for BRIA RMBG 1.4 that using
71
- # <a href="https://huggingface.co/briaai/RMBG-1.4" target="_blank">BRIA RMBG-1.4 image matting model</a> as backbone.
72
- # </p>
73
- # ''')
74
- # with gr.Row():
75
- # with gr.Column():
76
- # input_image = gr.Image(sources=None, type="pil") # None for upload, ctrl+v and webcam
77
- # # input_image = gr.Image(sources=None, type="numpy") # None for upload, ctrl+v and webcam
78
- # run_button = gr.Button(value="Run")
79
-
80
- # with gr.Column():
81
- # result_gallery = gr.Gallery(label='Output', show_label=False, elem_id="gallery", columns=[1], height='auto')
82
- # ips = [input_image]
83
- # run_button.click(fn=process, inputs=ips, outputs=[result_gallery])
84
-
85
- # block.launch(debug = True)
86
-
87
- # block = gr.Blocks().queue()
88
-
89
- gr.Markdown("## BRIA RMBG 1.4")
90
- gr.HTML('''
91
- <p style="margin-bottom: 10px; font-size: 94%">
92
- This is a demo for BRIA RMBG 1.4 that using
93
- <a href="https://huggingface.co/briaai/RMBG-1.4" target="_blank">BRIA RMBG-1.4 image matting model</a> as backbone.
94
- </p>
95
- ''')
96
- title = "Background Removal"
97
- description = r"""Background removal model developed by <a href='https://BRIA.AI' target='_blank'><b>BRIA.AI</b></a>, trained on a carefully selected dataset and is available as an open-source model for non-commercial use.<br>
98
- For test upload your image and wait. Read more at model card <a href='https://huggingface.co/briaai/RMBG-1.4' target='_blank'><b>briaai/RMBG-1.4</b></a>.<br>
99
- """
100
- examples = [['./input.jpg'],]
101
- # output = ImageSlider(position=0.5,label='Image without background', type="pil", show_download_button=True)
102
- # demo = gr.Interface(fn=process,inputs="image", outputs=output, examples=examples, title=title, description=description)
103
- demo = gr.Interface(fn=process,inputs="image", outputs="image", examples=examples, title=title, description=description)
104
 
 
105
 
106
  def merge_images(background_path, foreground_image):
107
- """
108
- 배경 이미지에 배경이 제거된 이미지를 투명하게 삽입합니다.
109
- 배경이 제거된 이미지는 배경 이미지 중앙에 30% 크기로 축소되어 삽입됩니다.
110
- """
111
- # 배경 이미지 로드
112
  background = Image.open(background_path).convert("RGBA")
113
- # 전경(배경이 제거된 이미지) 이미지 로드
114
  foreground = foreground_image.convert("RGBA")
115
 
116
- # 전경 이미지를 배경 이미지의 30% 크기로 조정
117
  scale_factor = 0.3
118
  new_size = (int(background.width * scale_factor), int(foreground.height * background.width / foreground.width * scale_factor))
119
  foreground_resized = foreground.resize(new_size, Image.Resampling.LANCZOS)
120
 
121
- # 전경 이미지를 배경 이미지의 가운데에 위치시키기 위한 좌표 계산
122
  x = (background.width - new_size[0]) // 2
123
  y = (background.height - new_size[1]) // 2
124
 
125
- # 배경 이미지 위에 전경 이미지를 붙임
126
  background.paste(foreground_resized, (x, y), foreground_resized)
127
-
128
  return background
129
 
130
- def process(image, background_image=None):
131
- """
132
- 이미지에서 배경을 제거합니다. 선택적으로, 사용자가 배경 이미지를 제공한 경우
133
- 배경 이미지 위에 배경이 제거된 이미지를 투명하게 삽입합니다.
134
- """
135
- # 배경 제거 로직을 여기에 구현합니다...
136
- # 예를 들어, 배경 제거 처리 후 이미지를 new_im에 할당합니다.
137
- # 이 예제에서는 단순화를 위해 image를 직접 사용합니다.
138
- new_im = image # 실제 구현에서는 여기에 배경 제거 로직의 결과를 할당해야 합니다.
139
-
140
- # 배경 이미지가 제공되었는지 확인
141
- if background_image is not None:
142
- # 배경 이미지에 배경이 제거된 이미지를 삽입하는 로직을 구현합니다.
143
- result_image = merge_images(background_image, new_im)
144
- else:
145
- result_image = new_im
146
-
147
- return result_image
148
-
149
-
150
  # Gradio 인터페이스 정의
151
  demo = gr.Interface(
152
  fn=process,
153
  inputs=[
154
- gr.Image(type="pil", label="Image to remove background"),
155
- gr.Image(type="pil", label="Optional: Background Image") # optional 키워드 인자 제거
156
  ],
157
  outputs="image",
158
- examples=examples,
159
- title=title,
160
- description=description
161
  )
 
 
 
 
4
  from torchvision.transforms.functional import normalize
5
  from huggingface_hub import hf_hub_download
6
  import gradio as gr
 
7
  from briarmbg import BriaRMBG
 
8
  from PIL import Image
 
9
 
10
+ # 모델 초기화 및 로드
11
+ net = BriaRMBG()
12
  model_path = hf_hub_download("briaai/RMBG-1.4", 'model.pth')
13
  if torch.cuda.is_available():
14
  net.load_state_dict(torch.load(model_path))
15
+ net = net.cuda()
16
  else:
17
+ net.load_state_dict(torch.load(model_path, map_location="cpu"))
18
+ net.eval()
19
 
20
+ def resize_image(image, model_input_size=(1024, 1024)):
 
21
  image = image.convert('RGB')
 
22
  image = image.resize(model_input_size, Image.BILINEAR)
23
  return image
24
 
25
+ def process(image, background_image=None):
26
+ # 이미지 전처리
27
+ orig_image = Image.fromarray(image.astype(np.uint8))
28
+ resized_image = resize_image(orig_image)
29
+ im_tensor = torch.tensor(np.array(resized_image), dtype=torch.float32).permute(2, 0, 1).unsqueeze(0) / 255.0
30
+ im_tensor = normalize(im_tensor, [0.5, 0.5, 0.5], [1.0, 1.0, 1.0])
 
 
 
 
 
 
31
  if torch.cuda.is_available():
32
+ im_tensor = im_tensor.cuda()
33
+
34
+ # 배경 제거 모델 추론
35
+ with torch.no_grad():
36
+ output = net(im_tensor)
37
+ output = F.interpolate(output[0][0], size=orig_image.size[::-1], mode='bilinear', align_corners=False)
38
+ output = torch.sigmoid(output).cpu().numpy().squeeze()
39
+ mask = (output * 255).astype(np.uint8)
40
+ new_im = Image.fromarray(mask).convert("L")
41
+
42
+ result_image = Image.new("RGBA", orig_image.size)
43
+ orig_image_rgba = orig_image.convert("RGBA")
44
+ result_image.paste(orig_image_rgba, mask=new_im)
45
+
46
+ # 배경 이미지가 제공된 경우, 합성
47
+ if background_image is not None:
48
+ result_image = merge_images(background_image, result_image)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
49
 
50
+ return result_image
51
 
52
  def merge_images(background_path, foreground_image):
 
 
 
 
 
53
  background = Image.open(background_path).convert("RGBA")
 
54
  foreground = foreground_image.convert("RGBA")
55
 
 
56
  scale_factor = 0.3
57
  new_size = (int(background.width * scale_factor), int(foreground.height * background.width / foreground.width * scale_factor))
58
  foreground_resized = foreground.resize(new_size, Image.Resampling.LANCZOS)
59
 
 
60
  x = (background.width - new_size[0]) // 2
61
  y = (background.height - new_size[1]) // 2
62
 
 
63
  background.paste(foreground_resized, (x, y), foreground_resized)
 
64
  return background
65
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
66
  # Gradio 인터페이스 정의
67
  demo = gr.Interface(
68
  fn=process,
69
  inputs=[
70
+ gr.Image(type="numpy", label="Image to remove background"),
71
+ gr.Image(type="pil", label="Optional: Background Image")
72
  ],
73
  outputs="image",
74
+ title="Background Removal",
75
+ description="This is a demo for BRIA RMBG 1.4 that uses the BRIA RMBG-1.4 image matting model as backbone."
 
76
  )
77
+
78
+ if __name__ == "__main__":
79
+ demo.launch()