gouravgujariya commited on
Commit
fd7a33e
·
verified ·
1 Parent(s): 5d14af0

Delete briarmbg.py

Browse files
Files changed (1) hide show
  1. briarmbg.py +0 -456
briarmbg.py DELETED
@@ -1,456 +0,0 @@
1
- import torch
2
- import torch.nn as nn
3
- import torch.nn.functional as F
4
- from huggingface_hub import PyTorchModelHubMixin
5
-
6
- class REBNCONV(nn.Module):
7
- def __init__(self,in_ch=3,out_ch=3,dirate=1,stride=1):
8
- super(REBNCONV,self).__init__()
9
-
10
- self.conv_s1 = nn.Conv2d(in_ch,out_ch,3,padding=1*dirate,dilation=1*dirate,stride=stride)
11
- self.bn_s1 = nn.BatchNorm2d(out_ch)
12
- self.relu_s1 = nn.ReLU(inplace=True)
13
-
14
- def forward(self,x):
15
-
16
- hx = x
17
- xout = self.relu_s1(self.bn_s1(self.conv_s1(hx)))
18
-
19
- return xout
20
-
21
- ## upsample tensor 'src' to have the same spatial size with tensor 'tar'
22
- def _upsample_like(src,tar):
23
-
24
- src = F.interpolate(src,size=tar.shape[2:],mode='bilinear')
25
-
26
- return src
27
-
28
-
29
- ### RSU-7 ###
30
- class RSU7(nn.Module):
31
-
32
- def __init__(self, in_ch=3, mid_ch=12, out_ch=3, img_size=512):
33
- super(RSU7,self).__init__()
34
-
35
- self.in_ch = in_ch
36
- self.mid_ch = mid_ch
37
- self.out_ch = out_ch
38
-
39
- self.rebnconvin = REBNCONV(in_ch,out_ch,dirate=1) ## 1 -> 1/2
40
-
41
- self.rebnconv1 = REBNCONV(out_ch,mid_ch,dirate=1)
42
- self.pool1 = nn.MaxPool2d(2,stride=2,ceil_mode=True)
43
-
44
- self.rebnconv2 = REBNCONV(mid_ch,mid_ch,dirate=1)
45
- self.pool2 = nn.MaxPool2d(2,stride=2,ceil_mode=True)
46
-
47
- self.rebnconv3 = REBNCONV(mid_ch,mid_ch,dirate=1)
48
- self.pool3 = nn.MaxPool2d(2,stride=2,ceil_mode=True)
49
-
50
- self.rebnconv4 = REBNCONV(mid_ch,mid_ch,dirate=1)
51
- self.pool4 = nn.MaxPool2d(2,stride=2,ceil_mode=True)
52
-
53
- self.rebnconv5 = REBNCONV(mid_ch,mid_ch,dirate=1)
54
- self.pool5 = nn.MaxPool2d(2,stride=2,ceil_mode=True)
55
-
56
- self.rebnconv6 = REBNCONV(mid_ch,mid_ch,dirate=1)
57
-
58
- self.rebnconv7 = REBNCONV(mid_ch,mid_ch,dirate=2)
59
-
60
- self.rebnconv6d = REBNCONV(mid_ch*2,mid_ch,dirate=1)
61
- self.rebnconv5d = REBNCONV(mid_ch*2,mid_ch,dirate=1)
62
- self.rebnconv4d = REBNCONV(mid_ch*2,mid_ch,dirate=1)
63
- self.rebnconv3d = REBNCONV(mid_ch*2,mid_ch,dirate=1)
64
- self.rebnconv2d = REBNCONV(mid_ch*2,mid_ch,dirate=1)
65
- self.rebnconv1d = REBNCONV(mid_ch*2,out_ch,dirate=1)
66
-
67
- def forward(self,x):
68
- b, c, h, w = x.shape
69
-
70
- hx = x
71
- hxin = self.rebnconvin(hx)
72
-
73
- hx1 = self.rebnconv1(hxin)
74
- hx = self.pool1(hx1)
75
-
76
- hx2 = self.rebnconv2(hx)
77
- hx = self.pool2(hx2)
78
-
79
- hx3 = self.rebnconv3(hx)
80
- hx = self.pool3(hx3)
81
-
82
- hx4 = self.rebnconv4(hx)
83
- hx = self.pool4(hx4)
84
-
85
- hx5 = self.rebnconv5(hx)
86
- hx = self.pool5(hx5)
87
-
88
- hx6 = self.rebnconv6(hx)
89
-
90
- hx7 = self.rebnconv7(hx6)
91
-
92
- hx6d = self.rebnconv6d(torch.cat((hx7,hx6),1))
93
- hx6dup = _upsample_like(hx6d,hx5)
94
-
95
- hx5d = self.rebnconv5d(torch.cat((hx6dup,hx5),1))
96
- hx5dup = _upsample_like(hx5d,hx4)
97
-
98
- hx4d = self.rebnconv4d(torch.cat((hx5dup,hx4),1))
99
- hx4dup = _upsample_like(hx4d,hx3)
100
-
101
- hx3d = self.rebnconv3d(torch.cat((hx4dup,hx3),1))
102
- hx3dup = _upsample_like(hx3d,hx2)
103
-
104
- hx2d = self.rebnconv2d(torch.cat((hx3dup,hx2),1))
105
- hx2dup = _upsample_like(hx2d,hx1)
106
-
107
- hx1d = self.rebnconv1d(torch.cat((hx2dup,hx1),1))
108
-
109
- return hx1d + hxin
110
-
111
-
112
- ### RSU-6 ###
113
- class RSU6(nn.Module):
114
-
115
- def __init__(self, in_ch=3, mid_ch=12, out_ch=3):
116
- super(RSU6,self).__init__()
117
-
118
- self.rebnconvin = REBNCONV(in_ch,out_ch,dirate=1)
119
-
120
- self.rebnconv1 = REBNCONV(out_ch,mid_ch,dirate=1)
121
- self.pool1 = nn.MaxPool2d(2,stride=2,ceil_mode=True)
122
-
123
- self.rebnconv2 = REBNCONV(mid_ch,mid_ch,dirate=1)
124
- self.pool2 = nn.MaxPool2d(2,stride=2,ceil_mode=True)
125
-
126
- self.rebnconv3 = REBNCONV(mid_ch,mid_ch,dirate=1)
127
- self.pool3 = nn.MaxPool2d(2,stride=2,ceil_mode=True)
128
-
129
- self.rebnconv4 = REBNCONV(mid_ch,mid_ch,dirate=1)
130
- self.pool4 = nn.MaxPool2d(2,stride=2,ceil_mode=True)
131
-
132
- self.rebnconv5 = REBNCONV(mid_ch,mid_ch,dirate=1)
133
-
134
- self.rebnconv6 = REBNCONV(mid_ch,mid_ch,dirate=2)
135
-
136
- self.rebnconv5d = REBNCONV(mid_ch*2,mid_ch,dirate=1)
137
- self.rebnconv4d = REBNCONV(mid_ch*2,mid_ch,dirate=1)
138
- self.rebnconv3d = REBNCONV(mid_ch*2,mid_ch,dirate=1)
139
- self.rebnconv2d = REBNCONV(mid_ch*2,mid_ch,dirate=1)
140
- self.rebnconv1d = REBNCONV(mid_ch*2,out_ch,dirate=1)
141
-
142
- def forward(self,x):
143
-
144
- hx = x
145
-
146
- hxin = self.rebnconvin(hx)
147
-
148
- hx1 = self.rebnconv1(hxin)
149
- hx = self.pool1(hx1)
150
-
151
- hx2 = self.rebnconv2(hx)
152
- hx = self.pool2(hx2)
153
-
154
- hx3 = self.rebnconv3(hx)
155
- hx = self.pool3(hx3)
156
-
157
- hx4 = self.rebnconv4(hx)
158
- hx = self.pool4(hx4)
159
-
160
- hx5 = self.rebnconv5(hx)
161
-
162
- hx6 = self.rebnconv6(hx5)
163
-
164
-
165
- hx5d = self.rebnconv5d(torch.cat((hx6,hx5),1))
166
- hx5dup = _upsample_like(hx5d,hx4)
167
-
168
- hx4d = self.rebnconv4d(torch.cat((hx5dup,hx4),1))
169
- hx4dup = _upsample_like(hx4d,hx3)
170
-
171
- hx3d = self.rebnconv3d(torch.cat((hx4dup,hx3),1))
172
- hx3dup = _upsample_like(hx3d,hx2)
173
-
174
- hx2d = self.rebnconv2d(torch.cat((hx3dup,hx2),1))
175
- hx2dup = _upsample_like(hx2d,hx1)
176
-
177
- hx1d = self.rebnconv1d(torch.cat((hx2dup,hx1),1))
178
-
179
- return hx1d + hxin
180
-
181
- ### RSU-5 ###
182
- class RSU5(nn.Module):
183
-
184
- def __init__(self, in_ch=3, mid_ch=12, out_ch=3):
185
- super(RSU5,self).__init__()
186
-
187
- self.rebnconvin = REBNCONV(in_ch,out_ch,dirate=1)
188
-
189
- self.rebnconv1 = REBNCONV(out_ch,mid_ch,dirate=1)
190
- self.pool1 = nn.MaxPool2d(2,stride=2,ceil_mode=True)
191
-
192
- self.rebnconv2 = REBNCONV(mid_ch,mid_ch,dirate=1)
193
- self.pool2 = nn.MaxPool2d(2,stride=2,ceil_mode=True)
194
-
195
- self.rebnconv3 = REBNCONV(mid_ch,mid_ch,dirate=1)
196
- self.pool3 = nn.MaxPool2d(2,stride=2,ceil_mode=True)
197
-
198
- self.rebnconv4 = REBNCONV(mid_ch,mid_ch,dirate=1)
199
-
200
- self.rebnconv5 = REBNCONV(mid_ch,mid_ch,dirate=2)
201
-
202
- self.rebnconv4d = REBNCONV(mid_ch*2,mid_ch,dirate=1)
203
- self.rebnconv3d = REBNCONV(mid_ch*2,mid_ch,dirate=1)
204
- self.rebnconv2d = REBNCONV(mid_ch*2,mid_ch,dirate=1)
205
- self.rebnconv1d = REBNCONV(mid_ch*2,out_ch,dirate=1)
206
-
207
- def forward(self,x):
208
-
209
- hx = x
210
-
211
- hxin = self.rebnconvin(hx)
212
-
213
- hx1 = self.rebnconv1(hxin)
214
- hx = self.pool1(hx1)
215
-
216
- hx2 = self.rebnconv2(hx)
217
- hx = self.pool2(hx2)
218
-
219
- hx3 = self.rebnconv3(hx)
220
- hx = self.pool3(hx3)
221
-
222
- hx4 = self.rebnconv4(hx)
223
-
224
- hx5 = self.rebnconv5(hx4)
225
-
226
- hx4d = self.rebnconv4d(torch.cat((hx5,hx4),1))
227
- hx4dup = _upsample_like(hx4d,hx3)
228
-
229
- hx3d = self.rebnconv3d(torch.cat((hx4dup,hx3),1))
230
- hx3dup = _upsample_like(hx3d,hx2)
231
-
232
- hx2d = self.rebnconv2d(torch.cat((hx3dup,hx2),1))
233
- hx2dup = _upsample_like(hx2d,hx1)
234
-
235
- hx1d = self.rebnconv1d(torch.cat((hx2dup,hx1),1))
236
-
237
- return hx1d + hxin
238
-
239
- ### RSU-4 ###
240
- class RSU4(nn.Module):
241
-
242
- def __init__(self, in_ch=3, mid_ch=12, out_ch=3):
243
- super(RSU4,self).__init__()
244
-
245
- self.rebnconvin = REBNCONV(in_ch,out_ch,dirate=1)
246
-
247
- self.rebnconv1 = REBNCONV(out_ch,mid_ch,dirate=1)
248
- self.pool1 = nn.MaxPool2d(2,stride=2,ceil_mode=True)
249
-
250
- self.rebnconv2 = REBNCONV(mid_ch,mid_ch,dirate=1)
251
- self.pool2 = nn.MaxPool2d(2,stride=2,ceil_mode=True)
252
-
253
- self.rebnconv3 = REBNCONV(mid_ch,mid_ch,dirate=1)
254
-
255
- self.rebnconv4 = REBNCONV(mid_ch,mid_ch,dirate=2)
256
-
257
- self.rebnconv3d = REBNCONV(mid_ch*2,mid_ch,dirate=1)
258
- self.rebnconv2d = REBNCONV(mid_ch*2,mid_ch,dirate=1)
259
- self.rebnconv1d = REBNCONV(mid_ch*2,out_ch,dirate=1)
260
-
261
- def forward(self,x):
262
-
263
- hx = x
264
-
265
- hxin = self.rebnconvin(hx)
266
-
267
- hx1 = self.rebnconv1(hxin)
268
- hx = self.pool1(hx1)
269
-
270
- hx2 = self.rebnconv2(hx)
271
- hx = self.pool2(hx2)
272
-
273
- hx3 = self.rebnconv3(hx)
274
-
275
- hx4 = self.rebnconv4(hx3)
276
-
277
- hx3d = self.rebnconv3d(torch.cat((hx4,hx3),1))
278
- hx3dup = _upsample_like(hx3d,hx2)
279
-
280
- hx2d = self.rebnconv2d(torch.cat((hx3dup,hx2),1))
281
- hx2dup = _upsample_like(hx2d,hx1)
282
-
283
- hx1d = self.rebnconv1d(torch.cat((hx2dup,hx1),1))
284
-
285
- return hx1d + hxin
286
-
287
- ### RSU-4F ###
288
- class RSU4F(nn.Module):
289
-
290
- def __init__(self, in_ch=3, mid_ch=12, out_ch=3):
291
- super(RSU4F,self).__init__()
292
-
293
- self.rebnconvin = REBNCONV(in_ch,out_ch,dirate=1)
294
-
295
- self.rebnconv1 = REBNCONV(out_ch,mid_ch,dirate=1)
296
- self.rebnconv2 = REBNCONV(mid_ch,mid_ch,dirate=2)
297
- self.rebnconv3 = REBNCONV(mid_ch,mid_ch,dirate=4)
298
-
299
- self.rebnconv4 = REBNCONV(mid_ch,mid_ch,dirate=8)
300
-
301
- self.rebnconv3d = REBNCONV(mid_ch*2,mid_ch,dirate=4)
302
- self.rebnconv2d = REBNCONV(mid_ch*2,mid_ch,dirate=2)
303
- self.rebnconv1d = REBNCONV(mid_ch*2,out_ch,dirate=1)
304
-
305
- def forward(self,x):
306
-
307
- hx = x
308
-
309
- hxin = self.rebnconvin(hx)
310
-
311
- hx1 = self.rebnconv1(hxin)
312
- hx2 = self.rebnconv2(hx1)
313
- hx3 = self.rebnconv3(hx2)
314
-
315
- hx4 = self.rebnconv4(hx3)
316
-
317
- hx3d = self.rebnconv3d(torch.cat((hx4,hx3),1))
318
- hx2d = self.rebnconv2d(torch.cat((hx3d,hx2),1))
319
- hx1d = self.rebnconv1d(torch.cat((hx2d,hx1),1))
320
-
321
- return hx1d + hxin
322
-
323
-
324
- class myrebnconv(nn.Module):
325
- def __init__(self, in_ch=3,
326
- out_ch=1,
327
- kernel_size=3,
328
- stride=1,
329
- padding=1,
330
- dilation=1,
331
- groups=1):
332
- super(myrebnconv,self).__init__()
333
-
334
- self.conv = nn.Conv2d(in_ch,
335
- out_ch,
336
- kernel_size=kernel_size,
337
- stride=stride,
338
- padding=padding,
339
- dilation=dilation,
340
- groups=groups)
341
- self.bn = nn.BatchNorm2d(out_ch)
342
- self.rl = nn.ReLU(inplace=True)
343
-
344
- def forward(self,x):
345
- return self.rl(self.bn(self.conv(x)))
346
-
347
-
348
- class BriaRMBG(nn.Module, PyTorchModelHubMixin):
349
-
350
- def __init__(self,config:dict={"in_ch":3,"out_ch":1}):
351
- super(BriaRMBG,self).__init__()
352
- in_ch=config["in_ch"]
353
- out_ch=config["out_ch"]
354
- self.conv_in = nn.Conv2d(in_ch,64,3,stride=2,padding=1)
355
- self.pool_in = nn.MaxPool2d(2,stride=2,ceil_mode=True)
356
-
357
- self.stage1 = RSU7(64,32,64)
358
- self.pool12 = nn.MaxPool2d(2,stride=2,ceil_mode=True)
359
-
360
- self.stage2 = RSU6(64,32,128)
361
- self.pool23 = nn.MaxPool2d(2,stride=2,ceil_mode=True)
362
-
363
- self.stage3 = RSU5(128,64,256)
364
- self.pool34 = nn.MaxPool2d(2,stride=2,ceil_mode=True)
365
-
366
- self.stage4 = RSU4(256,128,512)
367
- self.pool45 = nn.MaxPool2d(2,stride=2,ceil_mode=True)
368
-
369
- self.stage5 = RSU4F(512,256,512)
370
- self.pool56 = nn.MaxPool2d(2,stride=2,ceil_mode=True)
371
-
372
- self.stage6 = RSU4F(512,256,512)
373
-
374
- # decoder
375
- self.stage5d = RSU4F(1024,256,512)
376
- self.stage4d = RSU4(1024,128,256)
377
- self.stage3d = RSU5(512,64,128)
378
- self.stage2d = RSU6(256,32,64)
379
- self.stage1d = RSU7(128,16,64)
380
-
381
- self.side1 = nn.Conv2d(64,out_ch,3,padding=1)
382
- self.side2 = nn.Conv2d(64,out_ch,3,padding=1)
383
- self.side3 = nn.Conv2d(128,out_ch,3,padding=1)
384
- self.side4 = nn.Conv2d(256,out_ch,3,padding=1)
385
- self.side5 = nn.Conv2d(512,out_ch,3,padding=1)
386
- self.side6 = nn.Conv2d(512,out_ch,3,padding=1)
387
-
388
- # self.outconv = nn.Conv2d(6*out_ch,out_ch,1)
389
-
390
- def forward(self,x):
391
-
392
- hx = x
393
-
394
- hxin = self.conv_in(hx)
395
- #hx = self.pool_in(hxin)
396
-
397
- #stage 1
398
- hx1 = self.stage1(hxin)
399
- hx = self.pool12(hx1)
400
-
401
- #stage 2
402
- hx2 = self.stage2(hx)
403
- hx = self.pool23(hx2)
404
-
405
- #stage 3
406
- hx3 = self.stage3(hx)
407
- hx = self.pool34(hx3)
408
-
409
- #stage 4
410
- hx4 = self.stage4(hx)
411
- hx = self.pool45(hx4)
412
-
413
- #stage 5
414
- hx5 = self.stage5(hx)
415
- hx = self.pool56(hx5)
416
-
417
- #stage 6
418
- hx6 = self.stage6(hx)
419
- hx6up = _upsample_like(hx6,hx5)
420
-
421
- #-------------------- decoder --------------------
422
- hx5d = self.stage5d(torch.cat((hx6up,hx5),1))
423
- hx5dup = _upsample_like(hx5d,hx4)
424
-
425
- hx4d = self.stage4d(torch.cat((hx5dup,hx4),1))
426
- hx4dup = _upsample_like(hx4d,hx3)
427
-
428
- hx3d = self.stage3d(torch.cat((hx4dup,hx3),1))
429
- hx3dup = _upsample_like(hx3d,hx2)
430
-
431
- hx2d = self.stage2d(torch.cat((hx3dup,hx2),1))
432
- hx2dup = _upsample_like(hx2d,hx1)
433
-
434
- hx1d = self.stage1d(torch.cat((hx2dup,hx1),1))
435
-
436
-
437
- #side output
438
- d1 = self.side1(hx1d)
439
- d1 = _upsample_like(d1,x)
440
-
441
- d2 = self.side2(hx2d)
442
- d2 = _upsample_like(d2,x)
443
-
444
- d3 = self.side3(hx3d)
445
- d3 = _upsample_like(d3,x)
446
-
447
- d4 = self.side4(hx4d)
448
- d4 = _upsample_like(d4,x)
449
-
450
- d5 = self.side5(hx5d)
451
- d5 = _upsample_like(d5,x)
452
-
453
- d6 = self.side6(hx6)
454
- d6 = _upsample_like(d6,x)
455
-
456
- return [F.sigmoid(d1), F.sigmoid(d2), F.sigmoid(d3), F.sigmoid(d4), F.sigmoid(d5), F.sigmoid(d6)],[hx1d,hx2d,hx3d,hx4d,hx5d,hx6]