Spaces:
fantos
/
Runtime error

arxivgpt kim commited on
Commit
c384edc
·
verified ·
1 Parent(s): 39590b8

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +95 -93
app.py CHANGED
@@ -10,109 +10,111 @@ import PIL
10
  from PIL import Image
11
  from typing import Tuple
12
 
13
- # 모델 초기화 및 로드
14
- net = BriaRMBG()
15
  model_path = hf_hub_download("briaai/RMBG-1.4", 'model.pth')
16
  if torch.cuda.is_available():
17
  net.load_state_dict(torch.load(model_path))
18
- net = net.cuda()
19
  else:
20
- net.load_state_dict(torch.load(model_path, map_location="cpu"))
21
- net.eval()
22
 
23
- def resize_image(image, model_input_size=(1024, 1024)):
 
24
  image = image.convert('RGB')
 
25
  image = image.resize(model_input_size, Image.BILINEAR)
26
  return image
27
 
28
 
29
- def process(image, background_image=None):
30
- # 이미지 준비
31
- orig_image = Image.fromarray(image).convert("RGB")
32
- w, h = orig_image.size
33
- resized_image = resize_image(orig_image)
34
- im_np = np.array(resized_image).astype(np.float32) / 255.0
35
- im_tensor = torch.tensor(im_np).permute(2, 0, 1).unsqueeze(0)
36
- im_tensor = normalize(im_tensor, [0.5, 0.5, 0.5], [1.0, 1.0, 1.0])
37
- if torch.cuda.is_available():
38
- im_tensor = im_tensor.cuda()
39
-
40
- # 추론
41
- with torch.no_grad():
42
- result = net(im_tensor)
43
-
44
- # 후처리
45
- result = torch.squeeze(F.interpolate(result[0][0], size=(h, w), mode='bilinear', align_corners=False), 0)
46
- result = torch.sigmoid(result)
47
- mask = (result * 255).byte().cpu().numpy()
48
-
49
- if mask.ndim > 2:
50
- mask = mask.squeeze()
51
-
52
- mask = mask.astype(np.uint8)
53
-
54
- # 마스크를 알파 채널로 사용하여 최종 이미지 생성
55
- final_image = Image.new("RGBA", orig_image.size)
56
- orig_image.putalpha(Image.fromarray(mask, 'L'))
57
-
58
- if background_image:
59
- # 배경 이미지가 제공된 경우, 배경 이미지 크기 조정
60
- background = background_image.convert("RGBA").resize(orig_image.size)
61
- # 배경과 전경(알파 적용된 원본 이미지) 합성
62
- final_image = Image.alpha_composite(background, orig_image)
63
- else:
64
- # 배경 이미지가 없는 경우, 투명도가 적용된 원본 이미지를 최종 이미지로 사용
65
- final_image = orig_image
66
-
67
- return final_image
68
-
69
-
70
 
71
- def merge_images(background_image, foreground_image):
72
- """
73
- 배경 이미지에 배경이 제거된 이미지를 투명하게 삽입합니다.
74
- 배경이 제거된 이미지는 배경 이미지 중앙에 30% 크기로 축소되어 삽입됩니다.
75
- """
76
- background = background_image.convert("RGBA")
77
- foreground = foreground_image.convert("RGBA")
78
-
79
- # 전경 이미지를 배경 이미지의 30% 크기로 조정
80
- scale_factor = 0.3
81
- foreground_width = int(background.width * scale_factor)
82
- foreground_height = int(foreground.height * foreground_width / foreground.width)
83
- new_size = (foreground_width, foreground_height)
84
- foreground_resized = foreground.resize(new_size, Image.Resampling.LANCZOS)
85
-
86
- # 전경 이미지를 배경 이미지의 가운데에 위치시키기 위한 좌표 계산
87
- x = (background.width - foreground_width) // 2
88
- y = (background.height - foreground_height) // 2
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
89
 
90
- # 배경 이미지 위에 전경 이미지를 붙임
91
- background.paste(foreground_resized, (x, y), foreground_resized)
92
 
93
- return background
94
-
95
-
96
- title = "Background Removal"
97
- description = "This is a demo for BRIA RMBG 1.4 using the BRIA RMBG-1.4 image matting model as backbone."
98
-
99
-
100
- def add_blue_background(image):
101
- # 배경 제거된 이미지에 푸른색 배경을 추가하는 함수
102
- blue_background = Image.new("RGBA", image.size, "blue")
103
- final_image = Image.alpha_composite(blue_background, image.convert("RGBA"))
104
- return final_image
105
-
106
-
107
- inputs = "image" # 이전: gr.inputs.Image(type="pil") -> 변경: "image"
108
- outputs = ["image"] # 이전 방식에서 변경되었을 수 있는 출력 부분
109
-
110
- demo = gr.Interface(fn=process,
111
- inputs=inputs,
112
- outputs=outputs,
113
- title="Your Demo Title",
114
- description="A brief description of your app.")
115
-
116
- if __name__ == "__main__":
117
- demo.launch()
118
-
 
 
 
 
 
10
  from PIL import Image
11
  from typing import Tuple
12
 
13
+ net=BriaRMBG()
14
+ # model_path = "./model1.pth"
15
  model_path = hf_hub_download("briaai/RMBG-1.4", 'model.pth')
16
  if torch.cuda.is_available():
17
  net.load_state_dict(torch.load(model_path))
18
+ net=net.cuda()
19
  else:
20
+ net.load_state_dict(torch.load(model_path,map_location="cpu"))
21
+ net.eval()
22
 
23
+
24
+ def resize_image(image):
25
  image = image.convert('RGB')
26
+ model_input_size = (1024, 1024)
27
  image = image.resize(model_input_size, Image.BILINEAR)
28
  return image
29
 
30
 
31
+ def process(image):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
32
 
33
+ # prepare input
34
+ orig_image = Image.fromarray(image)
35
+ w,h = orig_im_size = orig_image.size
36
+ image = resize_image(orig_image)
37
+ im_np = np.array(image)
38
+ im_tensor = torch.tensor(im_np, dtype=torch.float32).permute(2,0,1)
39
+ im_tensor = torch.unsqueeze(im_tensor,0)
40
+ im_tensor = torch.divide(im_tensor,255.0)
41
+ im_tensor = normalize(im_tensor,[0.5,0.5,0.5],[1.0,1.0,1.0])
42
+ if torch.cuda.is_available():
43
+ im_tensor=im_tensor.cuda()
44
+
45
+ #inference
46
+ result=net(im_tensor)
47
+ # post process
48
+ result = torch.squeeze(F.interpolate(result[0][0], size=(h,w), mode='bilinear') ,0)
49
+ ma = torch.max(result)
50
+ mi = torch.min(result)
51
+ result = (result-mi)/(ma-mi)
52
+ # image to pil
53
+ im_array = (result*255).cpu().data.numpy().astype(np.uint8)
54
+ pil_im = Image.fromarray(np.squeeze(im_array))
55
+ # paste the mask on the original image
56
+ new_im = Image.new("RGBA", pil_im.size, (0,0,0,0))
57
+ new_im.paste(orig_image, mask=pil_im)
58
+ # new_orig_image = orig_image.convert('RGBA')
59
+
60
+ return new_im
61
+ # return [new_orig_image, new_im]
62
+
63
+ def calculate_position(org_size, add_size, position):
64
+ if position == "상단 좌측":
65
+ return (0, 0)
66
+ elif position == "상단 가운데":
67
+ return ((org_size[0] - add_size[0]) // 2, 0)
68
+ elif position == "상단 우측":
69
+ return (org_size[0] - add_size[0], 0)
70
+ elif position == "중앙 좌측":
71
+ return (0, (org_size[1] - add_size[1]) // 2)
72
+ elif position == "중앙 가운데":
73
+ return ((org_size[0] - add_size[0]) // 2, (org_size[1] - add_size[1]) // 2)
74
+ elif position == "중앙 우측":
75
+ return (org_size[0] - add_size[0], (org_size[1] - add_size[1]) // 2)
76
+ elif position == "하단 좌측":
77
+ return (0, org_size[1] - add_size[1])
78
+ elif position == "하단 가운데":
79
+ return ((org_size[0] - add_size[0]) // 2, org_size[1] - add_size[1])
80
+ elif position == "하단 우측":
81
+ return (org_size[0] - add_size[0], org_size[1] - add_size[1])
82
+
83
+ def merge(org_image, add_image, scale, position):
84
+ scale_percentage = scale / 100.0
85
+ new_size = (int(add_image.width * scale_percentage), int(add_image.height * scale_percentage))
86
+ add_image = add_image.resize(new_size, Image.Resampling.LANCZOS)
87
 
88
+ position = calculate_position(org_image.size, add_image.size, position)
89
+ org_image.paste(add_image, position, add_image)
90
 
91
+ return org_image
92
+
93
+
94
+
95
+ with gr.Blocks() as demo:
96
+ with gr.Tab("Background Removal"):
97
+ with gr.Column():
98
+ gr.Markdown("## BRIA RMBG 1.4")
99
+ gr.HTML('''
100
+ <p style="margin-bottom: 10px; font-size: 94%">
101
+ This is a demo for BRIA RMBG 1.4 that using
102
+ <a href="https://huggingface.co/briaai/RMBG-1.4" target="_blank">BRIA RMBG-1.4 image matting model</a> as backbone.
103
+ </p>
104
+ ''')
105
+ input_image = gr.Image(type="pil")
106
+ output_image = gr.Image()
107
+ process_button = gr.Button("Remove Background")
108
+ process_button.click(fn=process, inputs=input_image, outputs=output_image)
109
+
110
+ with gr.Tab("Merge"):
111
+ with gr.Column():
112
+ org_image = gr.Image(label="Background", type='pil', image_mode='RGBA', height="80vh")
113
+ add_image = gr.Image(label="Foreground", type='pil', image_mode='RGBA', height="80vh")
114
+ scale = gr.Slider(minimum=10, maximum=200, step=1, value=100, label="Scale of Foreground Image (%)")
115
+ position = gr.Radio(choices=["중앙 가운데", "상단 좌측", "상단 가운데", "상단 우측", "중앙 좌측", "중앙 우측", "하단 좌측", "하단 가운데", "하단 우측"], value="중앙 가운데", label="Position of Foreground Image")
116
+ merge_button = gr.Button("Merge Images")
117
+ result_merge = gr.Image(height="80vh")
118
+ merge_button.click(fn=merge, inputs=[org_image, add_image, scale, position], outputs=result_merge)
119
+
120
+ demo.launch()