gouravgujariya commited on
Commit
482dea9
·
verified ·
1 Parent(s): 0cbc70a

Upload 5 files

Browse files
Files changed (5) hide show
  1. app (1).py +72 -0
  2. app.py +64 -0
  3. briarmbg.py +456 -0
  4. requirements (1).txt +10 -0
  5. requirements.txt +10 -0
app (1).py ADDED
@@ -0,0 +1,72 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import numpy as np
2
+ import torch
3
+ import torch.nn.functional as F
4
+ from torchvision.transforms.functional import normalize
5
+ import gradio as gr
6
+ from gradio_imageslider import ImageSlider
7
+ from briarmbg import BriaRMBG
8
+ import PIL
9
+ from PIL import Image
10
+ from typing import Tuple
11
+
12
+
13
+ net = BriaRMBG.from_pretrained("briaai/RMBG-1.4")
14
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
15
+ net.to(device)
16
+ net.eval()
17
+
18
+
19
+ def resize_image(image):
20
+ image = image.convert('RGB')
21
+ model_input_size = (1024, 1024)
22
+ image = image.resize(model_input_size, Image.BILINEAR)
23
+ return image
24
+
25
+
26
+ def process(image):
27
+
28
+ # prepare input
29
+ orig_image = Image.fromarray(image)
30
+ w,h = orig_im_size = orig_image.size
31
+ image = resize_image(orig_image)
32
+ im_np = np.array(image)
33
+ im_tensor = torch.tensor(im_np, dtype=torch.float32).permute(2,0,1)
34
+ im_tensor = torch.unsqueeze(im_tensor,0)
35
+ im_tensor = torch.divide(im_tensor,255.0)
36
+ im_tensor = normalize(im_tensor,[0.5,0.5,0.5],[1.0,1.0,1.0])
37
+ if torch.cuda.is_available():
38
+ im_tensor=im_tensor.cuda()
39
+
40
+ #inference
41
+ result=net(im_tensor)
42
+ # post process
43
+ result = torch.squeeze(F.interpolate(result[0][0], size=(h,w), mode='bilinear') ,0)
44
+ ma = torch.max(result)
45
+ mi = torch.min(result)
46
+ result = (result-mi)/(ma-mi)
47
+ # image to pil
48
+ result_array = (result*255).cpu().data.numpy().astype(np.uint8)
49
+ pil_mask = Image.fromarray(np.squeeze(result_array))
50
+ # add the mask on the original image as alpha channel
51
+ new_im = orig_image.copy()
52
+ new_im.putalpha(pil_mask)
53
+ return new_im
54
+ # return [new_orig_image, new_im]
55
+
56
+
57
+ gr.Markdown("## BRIA RMBG 1.4")
58
+ gr.HTML('''
59
+ <p style="margin-bottom: 10px; font-size: 94%">
60
+ This is a demo for BRIA RMBG 1.4 that using
61
+ <a href="https://huggingface.co/briaai/RMBG-1.4" target="_blank">BRIA RMBG-1.4 image matting model</a> as backbone.
62
+ </p>
63
+ ''')
64
+ title = "Background Removal"
65
+ description = r"""Background removal model developed by <a href='https://BRIA.AI' target='_blank'><b>BRIA.AI</b></a>, trained on a carefully selected dataset and is available as an open-source model for non-commercial use.<br>
66
+ For test upload your image and wait. Read more at model card <a href='https://huggingface.co/briaai/RMBG-1.4' target='_blank'><b>briaai/RMBG-1.4</b></a>. To purchase a commercial license, simply click <a href='https://go.bria.ai/3ZCBTLH' target='_blank'><b>Here</b></a>. <br>
67
+ """
68
+ examples = [['./input.jpg'],]
69
+ demo = gr.Interface(fn=process,inputs="image", outputs="image", examples=examples, title=title, description=description)
70
+
71
+ if __name__ == "__main__":
72
+ demo.launch(share=False)
app.py ADDED
@@ -0,0 +1,64 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import threading
3
+ import os
4
+ import torch
5
+
6
+ os.environ["OMP_NUM_THREADS"] = str(os.cpu_count())
7
+ torch.set_num_threads(os.cpu_count())
8
+
9
+ model1 = gr.load("models/prithivMLmods/SD3.5-Turbo-Realism-2.0-LoRA")
10
+ model2 = gr.load("models/Purz/face-projection")
11
+
12
+ stop_event = threading.Event()
13
+
14
+ def generate_images(text, selected_model):
15
+ stop_event.clear()
16
+
17
+ if selected_model == "Model 1 (Turbo Realism)":
18
+ model = model1
19
+ elif selected_model == "Model 2 (Face Projection)":
20
+ model = model2
21
+ else:
22
+ return ["Invalid model selection."] * 3
23
+
24
+ results = []
25
+ for i in range(3):
26
+ if stop_event.is_set():
27
+ return ["Image generation stopped by user."] * 3
28
+
29
+ modified_text = f"{text} variation {i+1}"
30
+ result = model(modified_text)
31
+ results.append(result)
32
+
33
+ return results
34
+
35
+ def stop_generation():
36
+ """Stops the ongoing image generation by setting the stop_event flag."""
37
+ stop_event.set()
38
+ return ["Generation stopped."] * 3
39
+
40
+ with gr.Blocks() as interface:#...
41
+ gr.Markdown(
42
+ "### ⚠ Sorry for the inconvenience. The Space is currently running on the CPU, which might affect performance. We appreciate your understanding."
43
+ )
44
+
45
+ text_input = gr.Textbox(label="Type here your imagination:", placeholder="Type your prompt...")
46
+ model_selector = gr.Radio(
47
+ ["Model 1 (Turbo Realism)", "Model 2 (Face Projection)"],
48
+ label="Select Model",
49
+ value="Model 1 (Turbo Realism)"
50
+ )
51
+
52
+ with gr.Row():
53
+ generate_button = gr.Button("Generate 3 Images 🎨")
54
+ stop_button = gr.Button("Stop Image Generation")
55
+
56
+ with gr.Row():
57
+ output1 = gr.Image(label="Generated Image 1")
58
+ output2 = gr.Image(label="Generated Image 2")
59
+ output3 = gr.Image(label="Generated Image 3")
60
+
61
+ generate_button.click(generate_images, inputs=[text_input, model_selector], outputs=[output1, output2, output3])
62
+ stop_button.click(stop_generation, inputs=[], outputs=[output1, output2, output3])
63
+
64
+ interface.launch()
briarmbg.py ADDED
@@ -0,0 +1,456 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ import torch.nn.functional as F
4
+ from huggingface_hub import PyTorchModelHubMixin
5
+
6
+ class REBNCONV(nn.Module):
7
+ def __init__(self,in_ch=3,out_ch=3,dirate=1,stride=1):
8
+ super(REBNCONV,self).__init__()
9
+
10
+ self.conv_s1 = nn.Conv2d(in_ch,out_ch,3,padding=1*dirate,dilation=1*dirate,stride=stride)
11
+ self.bn_s1 = nn.BatchNorm2d(out_ch)
12
+ self.relu_s1 = nn.ReLU(inplace=True)
13
+
14
+ def forward(self,x):
15
+
16
+ hx = x
17
+ xout = self.relu_s1(self.bn_s1(self.conv_s1(hx)))
18
+
19
+ return xout
20
+
21
+ ## upsample tensor 'src' to have the same spatial size with tensor 'tar'
22
+ def _upsample_like(src,tar):
23
+
24
+ src = F.interpolate(src,size=tar.shape[2:],mode='bilinear')
25
+
26
+ return src
27
+
28
+
29
+ ### RSU-7 ###
30
+ class RSU7(nn.Module):
31
+
32
+ def __init__(self, in_ch=3, mid_ch=12, out_ch=3, img_size=512):
33
+ super(RSU7,self).__init__()
34
+
35
+ self.in_ch = in_ch
36
+ self.mid_ch = mid_ch
37
+ self.out_ch = out_ch
38
+
39
+ self.rebnconvin = REBNCONV(in_ch,out_ch,dirate=1) ## 1 -> 1/2
40
+
41
+ self.rebnconv1 = REBNCONV(out_ch,mid_ch,dirate=1)
42
+ self.pool1 = nn.MaxPool2d(2,stride=2,ceil_mode=True)
43
+
44
+ self.rebnconv2 = REBNCONV(mid_ch,mid_ch,dirate=1)
45
+ self.pool2 = nn.MaxPool2d(2,stride=2,ceil_mode=True)
46
+
47
+ self.rebnconv3 = REBNCONV(mid_ch,mid_ch,dirate=1)
48
+ self.pool3 = nn.MaxPool2d(2,stride=2,ceil_mode=True)
49
+
50
+ self.rebnconv4 = REBNCONV(mid_ch,mid_ch,dirate=1)
51
+ self.pool4 = nn.MaxPool2d(2,stride=2,ceil_mode=True)
52
+
53
+ self.rebnconv5 = REBNCONV(mid_ch,mid_ch,dirate=1)
54
+ self.pool5 = nn.MaxPool2d(2,stride=2,ceil_mode=True)
55
+
56
+ self.rebnconv6 = REBNCONV(mid_ch,mid_ch,dirate=1)
57
+
58
+ self.rebnconv7 = REBNCONV(mid_ch,mid_ch,dirate=2)
59
+
60
+ self.rebnconv6d = REBNCONV(mid_ch*2,mid_ch,dirate=1)
61
+ self.rebnconv5d = REBNCONV(mid_ch*2,mid_ch,dirate=1)
62
+ self.rebnconv4d = REBNCONV(mid_ch*2,mid_ch,dirate=1)
63
+ self.rebnconv3d = REBNCONV(mid_ch*2,mid_ch,dirate=1)
64
+ self.rebnconv2d = REBNCONV(mid_ch*2,mid_ch,dirate=1)
65
+ self.rebnconv1d = REBNCONV(mid_ch*2,out_ch,dirate=1)
66
+
67
+ def forward(self,x):
68
+ b, c, h, w = x.shape
69
+
70
+ hx = x
71
+ hxin = self.rebnconvin(hx)
72
+
73
+ hx1 = self.rebnconv1(hxin)
74
+ hx = self.pool1(hx1)
75
+
76
+ hx2 = self.rebnconv2(hx)
77
+ hx = self.pool2(hx2)
78
+
79
+ hx3 = self.rebnconv3(hx)
80
+ hx = self.pool3(hx3)
81
+
82
+ hx4 = self.rebnconv4(hx)
83
+ hx = self.pool4(hx4)
84
+
85
+ hx5 = self.rebnconv5(hx)
86
+ hx = self.pool5(hx5)
87
+
88
+ hx6 = self.rebnconv6(hx)
89
+
90
+ hx7 = self.rebnconv7(hx6)
91
+
92
+ hx6d = self.rebnconv6d(torch.cat((hx7,hx6),1))
93
+ hx6dup = _upsample_like(hx6d,hx5)
94
+
95
+ hx5d = self.rebnconv5d(torch.cat((hx6dup,hx5),1))
96
+ hx5dup = _upsample_like(hx5d,hx4)
97
+
98
+ hx4d = self.rebnconv4d(torch.cat((hx5dup,hx4),1))
99
+ hx4dup = _upsample_like(hx4d,hx3)
100
+
101
+ hx3d = self.rebnconv3d(torch.cat((hx4dup,hx3),1))
102
+ hx3dup = _upsample_like(hx3d,hx2)
103
+
104
+ hx2d = self.rebnconv2d(torch.cat((hx3dup,hx2),1))
105
+ hx2dup = _upsample_like(hx2d,hx1)
106
+
107
+ hx1d = self.rebnconv1d(torch.cat((hx2dup,hx1),1))
108
+
109
+ return hx1d + hxin
110
+
111
+
112
+ ### RSU-6 ###
113
+ class RSU6(nn.Module):
114
+
115
+ def __init__(self, in_ch=3, mid_ch=12, out_ch=3):
116
+ super(RSU6,self).__init__()
117
+
118
+ self.rebnconvin = REBNCONV(in_ch,out_ch,dirate=1)
119
+
120
+ self.rebnconv1 = REBNCONV(out_ch,mid_ch,dirate=1)
121
+ self.pool1 = nn.MaxPool2d(2,stride=2,ceil_mode=True)
122
+
123
+ self.rebnconv2 = REBNCONV(mid_ch,mid_ch,dirate=1)
124
+ self.pool2 = nn.MaxPool2d(2,stride=2,ceil_mode=True)
125
+
126
+ self.rebnconv3 = REBNCONV(mid_ch,mid_ch,dirate=1)
127
+ self.pool3 = nn.MaxPool2d(2,stride=2,ceil_mode=True)
128
+
129
+ self.rebnconv4 = REBNCONV(mid_ch,mid_ch,dirate=1)
130
+ self.pool4 = nn.MaxPool2d(2,stride=2,ceil_mode=True)
131
+
132
+ self.rebnconv5 = REBNCONV(mid_ch,mid_ch,dirate=1)
133
+
134
+ self.rebnconv6 = REBNCONV(mid_ch,mid_ch,dirate=2)
135
+
136
+ self.rebnconv5d = REBNCONV(mid_ch*2,mid_ch,dirate=1)
137
+ self.rebnconv4d = REBNCONV(mid_ch*2,mid_ch,dirate=1)
138
+ self.rebnconv3d = REBNCONV(mid_ch*2,mid_ch,dirate=1)
139
+ self.rebnconv2d = REBNCONV(mid_ch*2,mid_ch,dirate=1)
140
+ self.rebnconv1d = REBNCONV(mid_ch*2,out_ch,dirate=1)
141
+
142
+ def forward(self,x):
143
+
144
+ hx = x
145
+
146
+ hxin = self.rebnconvin(hx)
147
+
148
+ hx1 = self.rebnconv1(hxin)
149
+ hx = self.pool1(hx1)
150
+
151
+ hx2 = self.rebnconv2(hx)
152
+ hx = self.pool2(hx2)
153
+
154
+ hx3 = self.rebnconv3(hx)
155
+ hx = self.pool3(hx3)
156
+
157
+ hx4 = self.rebnconv4(hx)
158
+ hx = self.pool4(hx4)
159
+
160
+ hx5 = self.rebnconv5(hx)
161
+
162
+ hx6 = self.rebnconv6(hx5)
163
+
164
+
165
+ hx5d = self.rebnconv5d(torch.cat((hx6,hx5),1))
166
+ hx5dup = _upsample_like(hx5d,hx4)
167
+
168
+ hx4d = self.rebnconv4d(torch.cat((hx5dup,hx4),1))
169
+ hx4dup = _upsample_like(hx4d,hx3)
170
+
171
+ hx3d = self.rebnconv3d(torch.cat((hx4dup,hx3),1))
172
+ hx3dup = _upsample_like(hx3d,hx2)
173
+
174
+ hx2d = self.rebnconv2d(torch.cat((hx3dup,hx2),1))
175
+ hx2dup = _upsample_like(hx2d,hx1)
176
+
177
+ hx1d = self.rebnconv1d(torch.cat((hx2dup,hx1),1))
178
+
179
+ return hx1d + hxin
180
+
181
+ ### RSU-5 ###
182
+ class RSU5(nn.Module):
183
+
184
+ def __init__(self, in_ch=3, mid_ch=12, out_ch=3):
185
+ super(RSU5,self).__init__()
186
+
187
+ self.rebnconvin = REBNCONV(in_ch,out_ch,dirate=1)
188
+
189
+ self.rebnconv1 = REBNCONV(out_ch,mid_ch,dirate=1)
190
+ self.pool1 = nn.MaxPool2d(2,stride=2,ceil_mode=True)
191
+
192
+ self.rebnconv2 = REBNCONV(mid_ch,mid_ch,dirate=1)
193
+ self.pool2 = nn.MaxPool2d(2,stride=2,ceil_mode=True)
194
+
195
+ self.rebnconv3 = REBNCONV(mid_ch,mid_ch,dirate=1)
196
+ self.pool3 = nn.MaxPool2d(2,stride=2,ceil_mode=True)
197
+
198
+ self.rebnconv4 = REBNCONV(mid_ch,mid_ch,dirate=1)
199
+
200
+ self.rebnconv5 = REBNCONV(mid_ch,mid_ch,dirate=2)
201
+
202
+ self.rebnconv4d = REBNCONV(mid_ch*2,mid_ch,dirate=1)
203
+ self.rebnconv3d = REBNCONV(mid_ch*2,mid_ch,dirate=1)
204
+ self.rebnconv2d = REBNCONV(mid_ch*2,mid_ch,dirate=1)
205
+ self.rebnconv1d = REBNCONV(mid_ch*2,out_ch,dirate=1)
206
+
207
+ def forward(self,x):
208
+
209
+ hx = x
210
+
211
+ hxin = self.rebnconvin(hx)
212
+
213
+ hx1 = self.rebnconv1(hxin)
214
+ hx = self.pool1(hx1)
215
+
216
+ hx2 = self.rebnconv2(hx)
217
+ hx = self.pool2(hx2)
218
+
219
+ hx3 = self.rebnconv3(hx)
220
+ hx = self.pool3(hx3)
221
+
222
+ hx4 = self.rebnconv4(hx)
223
+
224
+ hx5 = self.rebnconv5(hx4)
225
+
226
+ hx4d = self.rebnconv4d(torch.cat((hx5,hx4),1))
227
+ hx4dup = _upsample_like(hx4d,hx3)
228
+
229
+ hx3d = self.rebnconv3d(torch.cat((hx4dup,hx3),1))
230
+ hx3dup = _upsample_like(hx3d,hx2)
231
+
232
+ hx2d = self.rebnconv2d(torch.cat((hx3dup,hx2),1))
233
+ hx2dup = _upsample_like(hx2d,hx1)
234
+
235
+ hx1d = self.rebnconv1d(torch.cat((hx2dup,hx1),1))
236
+
237
+ return hx1d + hxin
238
+
239
+ ### RSU-4 ###
240
+ class RSU4(nn.Module):
241
+
242
+ def __init__(self, in_ch=3, mid_ch=12, out_ch=3):
243
+ super(RSU4,self).__init__()
244
+
245
+ self.rebnconvin = REBNCONV(in_ch,out_ch,dirate=1)
246
+
247
+ self.rebnconv1 = REBNCONV(out_ch,mid_ch,dirate=1)
248
+ self.pool1 = nn.MaxPool2d(2,stride=2,ceil_mode=True)
249
+
250
+ self.rebnconv2 = REBNCONV(mid_ch,mid_ch,dirate=1)
251
+ self.pool2 = nn.MaxPool2d(2,stride=2,ceil_mode=True)
252
+
253
+ self.rebnconv3 = REBNCONV(mid_ch,mid_ch,dirate=1)
254
+
255
+ self.rebnconv4 = REBNCONV(mid_ch,mid_ch,dirate=2)
256
+
257
+ self.rebnconv3d = REBNCONV(mid_ch*2,mid_ch,dirate=1)
258
+ self.rebnconv2d = REBNCONV(mid_ch*2,mid_ch,dirate=1)
259
+ self.rebnconv1d = REBNCONV(mid_ch*2,out_ch,dirate=1)
260
+
261
+ def forward(self,x):
262
+
263
+ hx = x
264
+
265
+ hxin = self.rebnconvin(hx)
266
+
267
+ hx1 = self.rebnconv1(hxin)
268
+ hx = self.pool1(hx1)
269
+
270
+ hx2 = self.rebnconv2(hx)
271
+ hx = self.pool2(hx2)
272
+
273
+ hx3 = self.rebnconv3(hx)
274
+
275
+ hx4 = self.rebnconv4(hx3)
276
+
277
+ hx3d = self.rebnconv3d(torch.cat((hx4,hx3),1))
278
+ hx3dup = _upsample_like(hx3d,hx2)
279
+
280
+ hx2d = self.rebnconv2d(torch.cat((hx3dup,hx2),1))
281
+ hx2dup = _upsample_like(hx2d,hx1)
282
+
283
+ hx1d = self.rebnconv1d(torch.cat((hx2dup,hx1),1))
284
+
285
+ return hx1d + hxin
286
+
287
+ ### RSU-4F ###
288
+ class RSU4F(nn.Module):
289
+
290
+ def __init__(self, in_ch=3, mid_ch=12, out_ch=3):
291
+ super(RSU4F,self).__init__()
292
+
293
+ self.rebnconvin = REBNCONV(in_ch,out_ch,dirate=1)
294
+
295
+ self.rebnconv1 = REBNCONV(out_ch,mid_ch,dirate=1)
296
+ self.rebnconv2 = REBNCONV(mid_ch,mid_ch,dirate=2)
297
+ self.rebnconv3 = REBNCONV(mid_ch,mid_ch,dirate=4)
298
+
299
+ self.rebnconv4 = REBNCONV(mid_ch,mid_ch,dirate=8)
300
+
301
+ self.rebnconv3d = REBNCONV(mid_ch*2,mid_ch,dirate=4)
302
+ self.rebnconv2d = REBNCONV(mid_ch*2,mid_ch,dirate=2)
303
+ self.rebnconv1d = REBNCONV(mid_ch*2,out_ch,dirate=1)
304
+
305
+ def forward(self,x):
306
+
307
+ hx = x
308
+
309
+ hxin = self.rebnconvin(hx)
310
+
311
+ hx1 = self.rebnconv1(hxin)
312
+ hx2 = self.rebnconv2(hx1)
313
+ hx3 = self.rebnconv3(hx2)
314
+
315
+ hx4 = self.rebnconv4(hx3)
316
+
317
+ hx3d = self.rebnconv3d(torch.cat((hx4,hx3),1))
318
+ hx2d = self.rebnconv2d(torch.cat((hx3d,hx2),1))
319
+ hx1d = self.rebnconv1d(torch.cat((hx2d,hx1),1))
320
+
321
+ return hx1d + hxin
322
+
323
+
324
+ class myrebnconv(nn.Module):
325
+ def __init__(self, in_ch=3,
326
+ out_ch=1,
327
+ kernel_size=3,
328
+ stride=1,
329
+ padding=1,
330
+ dilation=1,
331
+ groups=1):
332
+ super(myrebnconv,self).__init__()
333
+
334
+ self.conv = nn.Conv2d(in_ch,
335
+ out_ch,
336
+ kernel_size=kernel_size,
337
+ stride=stride,
338
+ padding=padding,
339
+ dilation=dilation,
340
+ groups=groups)
341
+ self.bn = nn.BatchNorm2d(out_ch)
342
+ self.rl = nn.ReLU(inplace=True)
343
+
344
+ def forward(self,x):
345
+ return self.rl(self.bn(self.conv(x)))
346
+
347
+
348
+ class BriaRMBG(nn.Module, PyTorchModelHubMixin):
349
+
350
+ def __init__(self,config:dict={"in_ch":3,"out_ch":1}):
351
+ super(BriaRMBG,self).__init__()
352
+ in_ch=config["in_ch"]
353
+ out_ch=config["out_ch"]
354
+ self.conv_in = nn.Conv2d(in_ch,64,3,stride=2,padding=1)
355
+ self.pool_in = nn.MaxPool2d(2,stride=2,ceil_mode=True)
356
+
357
+ self.stage1 = RSU7(64,32,64)
358
+ self.pool12 = nn.MaxPool2d(2,stride=2,ceil_mode=True)
359
+
360
+ self.stage2 = RSU6(64,32,128)
361
+ self.pool23 = nn.MaxPool2d(2,stride=2,ceil_mode=True)
362
+
363
+ self.stage3 = RSU5(128,64,256)
364
+ self.pool34 = nn.MaxPool2d(2,stride=2,ceil_mode=True)
365
+
366
+ self.stage4 = RSU4(256,128,512)
367
+ self.pool45 = nn.MaxPool2d(2,stride=2,ceil_mode=True)
368
+
369
+ self.stage5 = RSU4F(512,256,512)
370
+ self.pool56 = nn.MaxPool2d(2,stride=2,ceil_mode=True)
371
+
372
+ self.stage6 = RSU4F(512,256,512)
373
+
374
+ # decoder
375
+ self.stage5d = RSU4F(1024,256,512)
376
+ self.stage4d = RSU4(1024,128,256)
377
+ self.stage3d = RSU5(512,64,128)
378
+ self.stage2d = RSU6(256,32,64)
379
+ self.stage1d = RSU7(128,16,64)
380
+
381
+ self.side1 = nn.Conv2d(64,out_ch,3,padding=1)
382
+ self.side2 = nn.Conv2d(64,out_ch,3,padding=1)
383
+ self.side3 = nn.Conv2d(128,out_ch,3,padding=1)
384
+ self.side4 = nn.Conv2d(256,out_ch,3,padding=1)
385
+ self.side5 = nn.Conv2d(512,out_ch,3,padding=1)
386
+ self.side6 = nn.Conv2d(512,out_ch,3,padding=1)
387
+
388
+ # self.outconv = nn.Conv2d(6*out_ch,out_ch,1)
389
+
390
+ def forward(self,x):
391
+
392
+ hx = x
393
+
394
+ hxin = self.conv_in(hx)
395
+ #hx = self.pool_in(hxin)
396
+
397
+ #stage 1
398
+ hx1 = self.stage1(hxin)
399
+ hx = self.pool12(hx1)
400
+
401
+ #stage 2
402
+ hx2 = self.stage2(hx)
403
+ hx = self.pool23(hx2)
404
+
405
+ #stage 3
406
+ hx3 = self.stage3(hx)
407
+ hx = self.pool34(hx3)
408
+
409
+ #stage 4
410
+ hx4 = self.stage4(hx)
411
+ hx = self.pool45(hx4)
412
+
413
+ #stage 5
414
+ hx5 = self.stage5(hx)
415
+ hx = self.pool56(hx5)
416
+
417
+ #stage 6
418
+ hx6 = self.stage6(hx)
419
+ hx6up = _upsample_like(hx6,hx5)
420
+
421
+ #-------------------- decoder --------------------
422
+ hx5d = self.stage5d(torch.cat((hx6up,hx5),1))
423
+ hx5dup = _upsample_like(hx5d,hx4)
424
+
425
+ hx4d = self.stage4d(torch.cat((hx5dup,hx4),1))
426
+ hx4dup = _upsample_like(hx4d,hx3)
427
+
428
+ hx3d = self.stage3d(torch.cat((hx4dup,hx3),1))
429
+ hx3dup = _upsample_like(hx3d,hx2)
430
+
431
+ hx2d = self.stage2d(torch.cat((hx3dup,hx2),1))
432
+ hx2dup = _upsample_like(hx2d,hx1)
433
+
434
+ hx1d = self.stage1d(torch.cat((hx2dup,hx1),1))
435
+
436
+
437
+ #side output
438
+ d1 = self.side1(hx1d)
439
+ d1 = _upsample_like(d1,x)
440
+
441
+ d2 = self.side2(hx2d)
442
+ d2 = _upsample_like(d2,x)
443
+
444
+ d3 = self.side3(hx3d)
445
+ d3 = _upsample_like(d3,x)
446
+
447
+ d4 = self.side4(hx4d)
448
+ d4 = _upsample_like(d4,x)
449
+
450
+ d5 = self.side5(hx5d)
451
+ d5 = _upsample_like(d5,x)
452
+
453
+ d6 = self.side6(hx6)
454
+ d6 = _upsample_like(d6,x)
455
+
456
+ return [F.sigmoid(d1), F.sigmoid(d2), F.sigmoid(d3), F.sigmoid(d4), F.sigmoid(d5), F.sigmoid(d6)],[hx1d,hx2d,hx3d,hx4d,hx5d,hx6]
requirements (1).txt ADDED
@@ -0,0 +1,10 @@
 
 
 
 
 
 
 
 
 
 
 
1
+ gradio
2
+ gradio_imageslider
3
+ torch
4
+ torchvision
5
+ pillow
6
+ numpy
7
+ typing
8
+ gitpython
9
+ huggingface_hub
10
+ safetensors
requirements.txt ADDED
@@ -0,0 +1,10 @@
 
 
 
 
 
 
 
 
 
 
 
1
+ gradio
2
+ torch
3
+ torchvision
4
+ pillow
5
+ numpy
6
+ typing
7
+ gitpython
8
+ huggingface_hub
9
+ safetensors
10
+ gradio_imageslider