gowtham58 committed on
Commit
80c08a8
·
verified ·
1 Parent(s): 5fdbdde

Uploading Necessary files

Browse files
Files changed (7) hide show
  1. .gitattributes +1 -0
  2. README.md +6 -6
  3. app.py +110 -1
  4. briarmbg.py +456 -0
  5. input.jpg +0 -0
  6. input.mp4 +3 -0
  7. requirements.txt +10 -0
.gitattributes CHANGED
@@ -33,3 +33,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
 
 
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
36
+ input.mp4 filter=lfs diff=lfs merge=lfs -text
README.md CHANGED
@@ -1,13 +1,13 @@
1
  ---
2
- title: Briaai RMBG 1.4
3
- emoji: 🏆
4
- colorFrom: pink
5
- colorTo: pink
6
  sdk: gradio
7
- sdk_version: 4.21.0
8
  app_file: app.py
9
  pinned: false
10
- license: apache-2.0
11
  ---
12
 
13
  Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
 
1
  ---
2
+ title: BRIA RMBG 1.4
3
+ emoji: 💻
4
+ colorFrom: red
5
+ colorTo: red
6
  sdk: gradio
7
+ sdk_version: 4.16.0
8
  app_file: app.py
9
  pinned: false
10
+ license: other
11
  ---
12
 
13
  Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
app.py CHANGED
@@ -1,3 +1,112 @@
 
 
 
 
 
 
 
1
  import gradio as gr
 
 
 
 
 
2
 
3
- gr.load("models/briaai/RMBG-1.4").launch()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import numpy as np
2
+ import torch.nn.functional as F
3
+ from torchvision.transforms.functional import normalize
4
+ from skimage import io
5
+ import torch, os
6
+ from PIL import Image
7
+ from briarmbg import BriaRMBG
8
  import gradio as gr
9
+ import cv2
10
+ import numpy as np
11
+ import time
12
+ import random
13
+ from PIL import Image
14
 
15
+
16
# Load the pretrained RMBG-1.4 weights from the Hugging Face Hub.
bgrm = BriaRMBG.from_pretrained("briaai/RMBG-1.4")

# Run on the GPU when one is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
bgrm.to(device)

# The network contains BatchNorm layers; this app only does inference, so
# switch to eval mode explicitly rather than relying on the loader to do it.
bgrm.eval()
20
+
21
+
22
def resize_image(image):
    """Return an RGB copy of *image* scaled to the 1024x1024 model input size.

    The image is converted to RGB first and then resized with bilinear
    resampling; the aspect ratio is not preserved.
    """
    target_size = (1024, 1024)
    return image.convert('RGB').resize(target_size, Image.BILINEAR)
27
+
28
+
29
def process(image):
    """Replace the background of a single frame with solid green.

    Args:
        image: Array accepted by ``PIL.Image.fromarray`` (the video loop in
            this file passes an RGB frame).

    Returns:
        An RGBA ``PIL.Image`` of the same size as the input: an opaque green
        canvas onto which the original frame is pasted through the predicted
        foreground mask.
    """
    # prepare input
    orig_image = Image.fromarray(image)
    w, h = orig_image.size
    resized = resize_image(orig_image)
    im_tensor = torch.tensor(np.array(resized), dtype=torch.float32).permute(2, 0, 1)
    im_tensor = torch.unsqueeze(im_tensor, 0)
    im_tensor = torch.divide(im_tensor, 255.0)
    im_tensor = normalize(im_tensor, [0.5, 0.5, 0.5], [1.0, 1.0, 1.0])
    # Use the same device the model was moved to at module level.
    im_tensor = im_tensor.to(device)

    # inference -- no gradients are needed, so do not build an autograd graph
    with torch.no_grad():
        result = bgrm(im_tensor)
        # post process: take the first side output and scale it back to the
        # original frame size (F.interpolate takes size as (height, width))
        result = torch.squeeze(F.interpolate(result[0][0], size=(h, w), mode='bilinear'), 0)

    # Min-max normalise the mask. A constant mask has zero range; dividing
    # would produce NaNs, so in that case the raw values are kept as they are.
    ma = torch.max(result)
    mi = torch.min(result)
    value_range = ma - mi
    if value_range > 0:
        result = (result - mi) / value_range

    # image to pil
    im_array = (result * 255).detach().cpu().numpy().astype(np.uint8)
    pil_im = Image.fromarray(np.squeeze(im_array))

    # paste the original image onto a green canvas through the mask
    new_im = Image.new("RGBA", pil_im.size, (0, 255, 0, 255))
    new_im.paste(orig_image, mask=pil_im)
    return new_im
58
+
59
+
60
+
61
+
62
+
63
def process_video(video, progress=gr.Progress()):
    """Remove the background from every frame of a video.

    Frames are read with OpenCV, passed through ``process`` one by one and
    written to ``output.mp4`` with the ``mp4v`` codec at the source frame rate.
    Processing stops early, keeping the frames written so far, shortly before
    a 20-minute budget is used up.

    Args:
        video: Path of the input video file.
        progress: Gradio progress tracker, updated once per frame.

    Returns:
        The path of the written video file.

    Raises:
        gr.Error: If no frame could be read from the input, or the output
            file could not be opened for writing.
    """
    # NOTE(review): the output name is fixed, so concurrent requests would
    # write to the same file -- confirm the app is only run with one worker.
    tmpname = 'output.mp4'
    time_budget = 20 * 60 - 5  # seconds; stop just before the 20 minute mark

    cap = cv2.VideoCapture(video)
    total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))  # may be 0 if unknown
    writer = None
    processed_frames = 0
    start_time = time.time()

    try:
        while cap.isOpened():
            ret, frame = cap.read()
            if not ret:
                break

            if time.time() - start_time >= time_budget:
                print("GPU Timeout is coming")
                break

            rgb_frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)

            if writer is None:
                # The writer is created lazily so that it can use the size of
                # the first decoded frame; VideoWriter expects (width, height).
                height, width = rgb_frame.shape[:2]
                writer = cv2.VideoWriter(
                    tmpname,
                    cv2.VideoWriter_fourcc(*'mp4v'),
                    cap.get(cv2.CAP_PROP_FPS),
                    (width, height),
                )
                if not writer.isOpened():
                    raise gr.Error("Could not open the output video for writing.")

            processed_frames += 1
            print(f"Processing frame {processed_frames}")
            if total_frames > 0:
                # Clamp: the reported frame count can be lower than the real one.
                progress(min(processed_frames / total_frames, 1.0),
                         desc=f"Processing frame {processed_frames}/{total_frames}")
            else:
                progress(None, desc=f"Processing frame {processed_frames}")

            out = process(rgb_frame)
            # process() returns RGBA; the writer expects 3-channel BGR.
            writer.write(cv2.cvtColor(np.array(out), cv2.COLOR_RGBA2BGR))
    finally:
        # Always release the handles, including when a frame raises.
        cap.release()
        if writer is not None:
            writer.release()

    if writer is None:
        raise gr.Error("No frames could be read from the input video.")
    return tmpname
98
+
99
# Text shown at the top of the Gradio page.
title = "🎞️ Video Background Removal Tool 🎥"
description = """Please note that if your video file is long (has a high number of frames), there is a chance that processing break due to GPU timeout. In this case, consider trying Fast mode."""

# Sample input offered in the UI.
examples = [["./input.mp4"]]

# One video in, one video out, handled by process_video.
iface = gr.Interface(
    process_video,
    ["video"],
    "video",
    examples=examples,
    title=title,
    description=description,
)
iface.launch()
briarmbg.py ADDED
@@ -0,0 +1,456 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ import torch.nn.functional as F
4
+ from huggingface_hub import PyTorchModelHubMixin
5
+
6
class REBNCONV(nn.Module):
    """3x3 convolution followed by batch normalisation and ReLU.

    Args:
        in_ch: Number of input channels.
        out_ch: Number of output channels.
        dirate: Dilation rate; also used as the padding.
        stride: Convolution stride.
    """

    def __init__(self, in_ch=3, out_ch=3, dirate=1, stride=1):
        super().__init__()

        self.conv_s1 = nn.Conv2d(
            in_ch,
            out_ch,
            3,
            padding=1 * dirate,
            dilation=1 * dirate,
            stride=stride,
        )
        self.bn_s1 = nn.BatchNorm2d(out_ch)
        self.relu_s1 = nn.ReLU(inplace=True)

    def forward(self, x):
        """Apply conv -> batch norm -> ReLU to *x*."""
        activations = self.conv_s1(x)
        activations = self.bn_s1(activations)
        return self.relu_s1(activations)
20
+
21
+ ## upsample tensor 'src' to have the same spatial size with tensor 'tar'
22
+ def _upsample_like(src,tar):
23
+
24
+ src = F.interpolate(src,size=tar.shape[2:],mode='bilinear')
25
+
26
+ return src
27
+
28
+
29
### RSU-7 ###
class RSU7(nn.Module):
    """Residual U-shaped block with five pooling steps.

    An input convolution is followed by six conv stages separated by 2x
    max-pooling, a dilated (dirate=2) convolution at the lowest resolution,
    and a decoder that upsamples step by step while concatenating the
    matching encoder feature. The input-convolution output is added to the
    decoder output.

    Args:
        in_ch: Channels of the input tensor.
        mid_ch: Channels used inside the block.
        out_ch: Channels of the output tensor.
        img_size: Accepted but not used by this class.
    """

    def __init__(self, in_ch=3, mid_ch=12, out_ch=3, img_size=512):
        super().__init__()

        self.in_ch = in_ch
        self.mid_ch = mid_ch
        self.out_ch = out_ch

        self.rebnconvin = REBNCONV(in_ch, out_ch, dirate=1)

        # encoder
        self.rebnconv1 = REBNCONV(out_ch, mid_ch, dirate=1)
        self.pool1 = nn.MaxPool2d(2, stride=2, ceil_mode=True)

        self.rebnconv2 = REBNCONV(mid_ch, mid_ch, dirate=1)
        self.pool2 = nn.MaxPool2d(2, stride=2, ceil_mode=True)

        self.rebnconv3 = REBNCONV(mid_ch, mid_ch, dirate=1)
        self.pool3 = nn.MaxPool2d(2, stride=2, ceil_mode=True)

        self.rebnconv4 = REBNCONV(mid_ch, mid_ch, dirate=1)
        self.pool4 = nn.MaxPool2d(2, stride=2, ceil_mode=True)

        self.rebnconv5 = REBNCONV(mid_ch, mid_ch, dirate=1)
        self.pool5 = nn.MaxPool2d(2, stride=2, ceil_mode=True)

        self.rebnconv6 = REBNCONV(mid_ch, mid_ch, dirate=1)

        # dilated convolution at the lowest resolution
        self.rebnconv7 = REBNCONV(mid_ch, mid_ch, dirate=2)

        # decoder: every stage takes two concatenated mid_ch feature maps
        self.rebnconv6d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
        self.rebnconv5d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
        self.rebnconv4d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
        self.rebnconv3d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
        self.rebnconv2d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
        self.rebnconv1d = REBNCONV(mid_ch * 2, out_ch, dirate=1)

    def forward(self, x):
        """Run the block on *x* and return a tensor with ``out_ch`` channels."""
        # The values are unused; the unpacking only succeeds for 4-D input.
        _batch, _channels, _height, _width = x.shape

        residual = self.rebnconvin(x)

        # encoder path
        enc1 = self.rebnconv1(residual)
        enc2 = self.rebnconv2(self.pool1(enc1))
        enc3 = self.rebnconv3(self.pool2(enc2))
        enc4 = self.rebnconv4(self.pool3(enc3))
        enc5 = self.rebnconv5(self.pool4(enc4))
        enc6 = self.rebnconv6(self.pool5(enc5))

        bottom = self.rebnconv7(enc6)

        # decoder path
        dec6 = self.rebnconv6d(torch.cat((bottom, enc6), 1))
        dec6_up = _upsample_like(dec6, enc5)

        dec5 = self.rebnconv5d(torch.cat((dec6_up, enc5), 1))
        dec5_up = _upsample_like(dec5, enc4)

        dec4 = self.rebnconv4d(torch.cat((dec5_up, enc4), 1))
        dec4_up = _upsample_like(dec4, enc3)

        dec3 = self.rebnconv3d(torch.cat((dec4_up, enc3), 1))
        dec3_up = _upsample_like(dec3, enc2)

        dec2 = self.rebnconv2d(torch.cat((dec3_up, enc2), 1))
        dec2_up = _upsample_like(dec2, enc1)

        dec1 = self.rebnconv1d(torch.cat((dec2_up, enc1), 1))

        return dec1 + residual
110
+
111
+
112
### RSU-6 ###
class RSU6(nn.Module):
    """Residual U-shaped block with four pooling steps.

    An input convolution is followed by five conv stages separated by 2x
    max-pooling, a dilated (dirate=2) convolution at the lowest resolution,
    and a decoder that upsamples step by step while concatenating the
    matching encoder feature. The input-convolution output is added to the
    decoder output.

    Args:
        in_ch: Channels of the input tensor.
        mid_ch: Channels used inside the block.
        out_ch: Channels of the output tensor.
    """

    def __init__(self, in_ch=3, mid_ch=12, out_ch=3):
        super().__init__()

        self.rebnconvin = REBNCONV(in_ch, out_ch, dirate=1)

        # encoder
        self.rebnconv1 = REBNCONV(out_ch, mid_ch, dirate=1)
        self.pool1 = nn.MaxPool2d(2, stride=2, ceil_mode=True)

        self.rebnconv2 = REBNCONV(mid_ch, mid_ch, dirate=1)
        self.pool2 = nn.MaxPool2d(2, stride=2, ceil_mode=True)

        self.rebnconv3 = REBNCONV(mid_ch, mid_ch, dirate=1)
        self.pool3 = nn.MaxPool2d(2, stride=2, ceil_mode=True)

        self.rebnconv4 = REBNCONV(mid_ch, mid_ch, dirate=1)
        self.pool4 = nn.MaxPool2d(2, stride=2, ceil_mode=True)

        self.rebnconv5 = REBNCONV(mid_ch, mid_ch, dirate=1)

        # dilated convolution at the lowest resolution
        self.rebnconv6 = REBNCONV(mid_ch, mid_ch, dirate=2)

        # decoder: every stage takes two concatenated mid_ch feature maps
        self.rebnconv5d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
        self.rebnconv4d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
        self.rebnconv3d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
        self.rebnconv2d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
        self.rebnconv1d = REBNCONV(mid_ch * 2, out_ch, dirate=1)

    def forward(self, x):
        """Run the block on *x* and return a tensor with ``out_ch`` channels."""
        residual = self.rebnconvin(x)

        # encoder path
        enc1 = self.rebnconv1(residual)
        enc2 = self.rebnconv2(self.pool1(enc1))
        enc3 = self.rebnconv3(self.pool2(enc2))
        enc4 = self.rebnconv4(self.pool3(enc3))
        enc5 = self.rebnconv5(self.pool4(enc4))

        bottom = self.rebnconv6(enc5)

        # decoder path
        dec5 = self.rebnconv5d(torch.cat((bottom, enc5), 1))
        dec5_up = _upsample_like(dec5, enc4)

        dec4 = self.rebnconv4d(torch.cat((dec5_up, enc4), 1))
        dec4_up = _upsample_like(dec4, enc3)

        dec3 = self.rebnconv3d(torch.cat((dec4_up, enc3), 1))
        dec3_up = _upsample_like(dec3, enc2)

        dec2 = self.rebnconv2d(torch.cat((dec3_up, enc2), 1))
        dec2_up = _upsample_like(dec2, enc1)

        dec1 = self.rebnconv1d(torch.cat((dec2_up, enc1), 1))

        return dec1 + residual
180
+
181
### RSU-5 ###
class RSU5(nn.Module):
    """Residual U-shaped block with three pooling steps.

    An input convolution is followed by four conv stages separated by 2x
    max-pooling, a dilated (dirate=2) convolution at the lowest resolution,
    and a decoder that upsamples step by step while concatenating the
    matching encoder feature. The input-convolution output is added to the
    decoder output.

    Args:
        in_ch: Channels of the input tensor.
        mid_ch: Channels used inside the block.
        out_ch: Channels of the output tensor.
    """

    def __init__(self, in_ch=3, mid_ch=12, out_ch=3):
        super().__init__()

        self.rebnconvin = REBNCONV(in_ch, out_ch, dirate=1)

        # encoder
        self.rebnconv1 = REBNCONV(out_ch, mid_ch, dirate=1)
        self.pool1 = nn.MaxPool2d(2, stride=2, ceil_mode=True)

        self.rebnconv2 = REBNCONV(mid_ch, mid_ch, dirate=1)
        self.pool2 = nn.MaxPool2d(2, stride=2, ceil_mode=True)

        self.rebnconv3 = REBNCONV(mid_ch, mid_ch, dirate=1)
        self.pool3 = nn.MaxPool2d(2, stride=2, ceil_mode=True)

        self.rebnconv4 = REBNCONV(mid_ch, mid_ch, dirate=1)

        # dilated convolution at the lowest resolution
        self.rebnconv5 = REBNCONV(mid_ch, mid_ch, dirate=2)

        # decoder: every stage takes two concatenated mid_ch feature maps
        self.rebnconv4d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
        self.rebnconv3d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
        self.rebnconv2d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
        self.rebnconv1d = REBNCONV(mid_ch * 2, out_ch, dirate=1)

    def forward(self, x):
        """Run the block on *x* and return a tensor with ``out_ch`` channels."""
        residual = self.rebnconvin(x)

        # encoder path
        enc1 = self.rebnconv1(residual)
        enc2 = self.rebnconv2(self.pool1(enc1))
        enc3 = self.rebnconv3(self.pool2(enc2))
        enc4 = self.rebnconv4(self.pool3(enc3))

        bottom = self.rebnconv5(enc4)

        # decoder path
        dec4 = self.rebnconv4d(torch.cat((bottom, enc4), 1))
        dec4_up = _upsample_like(dec4, enc3)

        dec3 = self.rebnconv3d(torch.cat((dec4_up, enc3), 1))
        dec3_up = _upsample_like(dec3, enc2)

        dec2 = self.rebnconv2d(torch.cat((dec3_up, enc2), 1))
        dec2_up = _upsample_like(dec2, enc1)

        dec1 = self.rebnconv1d(torch.cat((dec2_up, enc1), 1))

        return dec1 + residual
238
+
239
### RSU-4 ###
class RSU4(nn.Module):
    """Residual U-shaped block with two pooling steps.

    An input convolution is followed by three conv stages separated by 2x
    max-pooling, a dilated (dirate=2) convolution at the lowest resolution,
    and a decoder that upsamples step by step while concatenating the
    matching encoder feature. The input-convolution output is added to the
    decoder output.

    Args:
        in_ch: Channels of the input tensor.
        mid_ch: Channels used inside the block.
        out_ch: Channels of the output tensor.
    """

    def __init__(self, in_ch=3, mid_ch=12, out_ch=3):
        super().__init__()

        self.rebnconvin = REBNCONV(in_ch, out_ch, dirate=1)

        # encoder
        self.rebnconv1 = REBNCONV(out_ch, mid_ch, dirate=1)
        self.pool1 = nn.MaxPool2d(2, stride=2, ceil_mode=True)

        self.rebnconv2 = REBNCONV(mid_ch, mid_ch, dirate=1)
        self.pool2 = nn.MaxPool2d(2, stride=2, ceil_mode=True)

        self.rebnconv3 = REBNCONV(mid_ch, mid_ch, dirate=1)

        # dilated convolution at the lowest resolution
        self.rebnconv4 = REBNCONV(mid_ch, mid_ch, dirate=2)

        # decoder: every stage takes two concatenated mid_ch feature maps
        self.rebnconv3d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
        self.rebnconv2d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
        self.rebnconv1d = REBNCONV(mid_ch * 2, out_ch, dirate=1)

    def forward(self, x):
        """Run the block on *x* and return a tensor with ``out_ch`` channels."""
        residual = self.rebnconvin(x)

        # encoder path
        enc1 = self.rebnconv1(residual)
        enc2 = self.rebnconv2(self.pool1(enc1))
        enc3 = self.rebnconv3(self.pool2(enc2))

        bottom = self.rebnconv4(enc3)

        # decoder path
        dec3 = self.rebnconv3d(torch.cat((bottom, enc3), 1))
        dec3_up = _upsample_like(dec3, enc2)

        dec2 = self.rebnconv2d(torch.cat((dec3_up, enc2), 1))
        dec2_up = _upsample_like(dec2, enc1)

        dec1 = self.rebnconv1d(torch.cat((dec2_up, enc1), 1))

        return dec1 + residual
286
+
287
### RSU-4F ###
class RSU4F(nn.Module):
    """Residual block that uses dilation instead of pooling.

    No pooling or upsampling is applied. The encoder convolutions use
    dilation rates 1, 2, 4 and 8; the decoder convolutions use 4, 2 and 1,
    each fed with the previous result concatenated with the matching encoder
    feature. The input-convolution output is added to the decoder output.

    Args:
        in_ch: Channels of the input tensor.
        mid_ch: Channels used inside the block.
        out_ch: Channels of the output tensor.
    """

    def __init__(self, in_ch=3, mid_ch=12, out_ch=3):
        super().__init__()

        self.rebnconvin = REBNCONV(in_ch, out_ch, dirate=1)

        # encoder with growing dilation
        self.rebnconv1 = REBNCONV(out_ch, mid_ch, dirate=1)
        self.rebnconv2 = REBNCONV(mid_ch, mid_ch, dirate=2)
        self.rebnconv3 = REBNCONV(mid_ch, mid_ch, dirate=4)

        self.rebnconv4 = REBNCONV(mid_ch, mid_ch, dirate=8)

        # decoder with shrinking dilation
        self.rebnconv3d = REBNCONV(mid_ch * 2, mid_ch, dirate=4)
        self.rebnconv2d = REBNCONV(mid_ch * 2, mid_ch, dirate=2)
        self.rebnconv1d = REBNCONV(mid_ch * 2, out_ch, dirate=1)

    def forward(self, x):
        """Run the block on *x* and return a tensor with ``out_ch`` channels."""
        residual = self.rebnconvin(x)

        enc1 = self.rebnconv1(residual)
        enc2 = self.rebnconv2(enc1)
        enc3 = self.rebnconv3(enc2)

        bottom = self.rebnconv4(enc3)

        dec3 = self.rebnconv3d(torch.cat((bottom, enc3), 1))
        dec2 = self.rebnconv2d(torch.cat((dec3, enc2), 1))
        dec1 = self.rebnconv1d(torch.cat((dec2, enc1), 1))

        return dec1 + residual
322
+
323
+
324
class myrebnconv(nn.Module):
    """Configurable convolution followed by batch normalisation and ReLU.

    All arguments after the channel counts are forwarded unchanged to
    ``nn.Conv2d``.
    """

    def __init__(self, in_ch=3,
                 out_ch=1,
                 kernel_size=3,
                 stride=1,
                 padding=1,
                 dilation=1,
                 groups=1):
        super().__init__()

        conv_options = dict(
            kernel_size=kernel_size,
            stride=stride,
            padding=padding,
            dilation=dilation,
            groups=groups,
        )
        self.conv = nn.Conv2d(in_ch, out_ch, **conv_options)
        self.bn = nn.BatchNorm2d(out_ch)
        self.rl = nn.ReLU(inplace=True)

    def forward(self, x):
        """Apply conv -> batch norm -> ReLU to *x*."""
        convolved = self.conv(x)
        normalised = self.bn(convolved)
        return self.rl(normalised)
346
+
347
+
348
class BriaRMBG(nn.Module, PyTorchModelHubMixin):
    """Six-stage encoder / five-stage decoder network built from RSU blocks.

    Every decoder stage, and the deepest encoder stage, feeds a 3x3 "side"
    convolution that produces an ``out_ch``-channel map; each map is upsampled
    to the input resolution and passed through a sigmoid.

    Args:
        config: Mapping with the keys ``"in_ch"`` (input channels) and
            ``"out_ch"`` (channels of every side output).
    """

    def __init__(self,config:dict={"in_ch":3,"out_ch":1}):
        super().__init__()
        in_ch = config["in_ch"]
        out_ch = config["out_ch"]

        # Stride-2 stem convolution. pool_in is registered here but is not
        # called in forward().
        self.conv_in = nn.Conv2d(in_ch, 64, 3, stride=2, padding=1)
        self.pool_in = nn.MaxPool2d(2, stride=2, ceil_mode=True)

        # encoder
        self.stage1 = RSU7(64, 32, 64)
        self.pool12 = nn.MaxPool2d(2, stride=2, ceil_mode=True)

        self.stage2 = RSU6(64, 32, 128)
        self.pool23 = nn.MaxPool2d(2, stride=2, ceil_mode=True)

        self.stage3 = RSU5(128, 64, 256)
        self.pool34 = nn.MaxPool2d(2, stride=2, ceil_mode=True)

        self.stage4 = RSU4(256, 128, 512)
        self.pool45 = nn.MaxPool2d(2, stride=2, ceil_mode=True)

        self.stage5 = RSU4F(512, 256, 512)
        self.pool56 = nn.MaxPool2d(2, stride=2, ceil_mode=True)

        self.stage6 = RSU4F(512, 256, 512)

        # decoder: each stage takes the concatenation of two feature maps
        self.stage5d = RSU4F(1024, 256, 512)
        self.stage4d = RSU4(1024, 128, 256)
        self.stage3d = RSU5(512, 64, 128)
        self.stage2d = RSU6(256, 32, 64)
        self.stage1d = RSU7(128, 16, 64)

        # side-output heads, one per decoder stage plus the deepest encoder stage
        self.side1 = nn.Conv2d(64, out_ch, 3, padding=1)
        self.side2 = nn.Conv2d(64, out_ch, 3, padding=1)
        self.side3 = nn.Conv2d(128, out_ch, 3, padding=1)
        self.side4 = nn.Conv2d(256, out_ch, 3, padding=1)
        self.side5 = nn.Conv2d(512, out_ch, 3, padding=1)
        self.side6 = nn.Conv2d(512, out_ch, 3, padding=1)

    def forward(self, x):
        """Return ``(side_outputs, features)`` for the input batch *x*.

        ``side_outputs`` is a list of six sigmoid maps, each upsampled to the
        spatial size of *x*, ordered from the shallowest decoder stage to the
        deepest encoder stage. ``features`` is the list of the six tensors
        the side heads were applied to, in the same order.
        """
        stem = self.conv_in(x)

        # -------------------- encoder --------------------
        enc1 = self.stage1(stem)
        enc2 = self.stage2(self.pool12(enc1))
        enc3 = self.stage3(self.pool23(enc2))
        enc4 = self.stage4(self.pool34(enc3))
        enc5 = self.stage5(self.pool45(enc4))
        enc6 = self.stage6(self.pool56(enc5))
        enc6_up = _upsample_like(enc6, enc5)

        # -------------------- decoder --------------------
        dec5 = self.stage5d(torch.cat((enc6_up, enc5), 1))
        dec5_up = _upsample_like(dec5, enc4)

        dec4 = self.stage4d(torch.cat((dec5_up, enc4), 1))
        dec4_up = _upsample_like(dec4, enc3)

        dec3 = self.stage3d(torch.cat((dec4_up, enc3), 1))
        dec3_up = _upsample_like(dec3, enc2)

        dec2 = self.stage2d(torch.cat((dec3_up, enc2), 1))
        dec2_up = _upsample_like(dec2, enc1)

        dec1 = self.stage1d(torch.cat((dec2_up, enc1), 1))

        # -------------------- side outputs --------------------
        features = [dec1, dec2, dec3, dec4, dec5, enc6]
        heads = (self.side1, self.side2, self.side3,
                 self.side4, self.side5, self.side6)
        logits = [_upsample_like(head(feature), x)
                  for head, feature in zip(heads, features)]

        return [F.sigmoid(logit) for logit in logits], features
input.jpg ADDED
input.mp4 ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:f67ef25ad2ef3e72e3b2926bebbb8cfe49ee4ee702e56bae804931e0fb165698
3
+ size 4536473
requirements.txt ADDED
@@ -0,0 +1,10 @@
 
 
 
 
 
 
 
 
 
 
 
1
+ gradio
2
+ gradio_imageslider
3
+ torch
4
+ torchvision
5
+ pillow
6
+ numpy
7
+ typing
8
+ gitpython
9
+ huggingface_hub
10
+ opencv-python