habulaj commited on
Commit
dcb31a8
·
verified ·
1 Parent(s): b6f9a9a

Delete briarmbg.py

Browse files
Files changed (1) hide show
  1. briarmbg.py +0 -454
briarmbg.py DELETED
@@ -1,454 +0,0 @@
1
- import torch
2
- import torch.nn as nn
3
- import torch.nn.functional as F
4
-
5
- class REBNCONV(nn.Module):
6
- def __init__(self,in_ch=3,out_ch=3,dirate=1,stride=1):
7
- super(REBNCONV,self).__init__()
8
-
9
- self.conv_s1 = nn.Conv2d(in_ch,out_ch,3,padding=1*dirate,dilation=1*dirate,stride=stride)
10
- self.bn_s1 = nn.BatchNorm2d(out_ch)
11
- self.relu_s1 = nn.ReLU(inplace=True)
12
-
13
- def forward(self,x):
14
-
15
- hx = x
16
- xout = self.relu_s1(self.bn_s1(self.conv_s1(hx)))
17
-
18
- return xout
19
-
20
- ## upsample tensor 'src' to have the same spatial size with tensor 'tar'
21
- def _upsample_like(src,tar):
22
-
23
- src = F.interpolate(src,size=tar.shape[2:],mode='bilinear')
24
-
25
- return src
26
-
27
-
28
- ### RSU-7 ###
29
- class RSU7(nn.Module):
30
-
31
- def __init__(self, in_ch=3, mid_ch=12, out_ch=3, img_size=512):
32
- super(RSU7,self).__init__()
33
-
34
- self.in_ch = in_ch
35
- self.mid_ch = mid_ch
36
- self.out_ch = out_ch
37
-
38
- self.rebnconvin = REBNCONV(in_ch,out_ch,dirate=1) ## 1 -> 1/2
39
-
40
- self.rebnconv1 = REBNCONV(out_ch,mid_ch,dirate=1)
41
- self.pool1 = nn.MaxPool2d(2,stride=2,ceil_mode=True)
42
-
43
- self.rebnconv2 = REBNCONV(mid_ch,mid_ch,dirate=1)
44
- self.pool2 = nn.MaxPool2d(2,stride=2,ceil_mode=True)
45
-
46
- self.rebnconv3 = REBNCONV(mid_ch,mid_ch,dirate=1)
47
- self.pool3 = nn.MaxPool2d(2,stride=2,ceil_mode=True)
48
-
49
- self.rebnconv4 = REBNCONV(mid_ch,mid_ch,dirate=1)
50
- self.pool4 = nn.MaxPool2d(2,stride=2,ceil_mode=True)
51
-
52
- self.rebnconv5 = REBNCONV(mid_ch,mid_ch,dirate=1)
53
- self.pool5 = nn.MaxPool2d(2,stride=2,ceil_mode=True)
54
-
55
- self.rebnconv6 = REBNCONV(mid_ch,mid_ch,dirate=1)
56
-
57
- self.rebnconv7 = REBNCONV(mid_ch,mid_ch,dirate=2)
58
-
59
- self.rebnconv6d = REBNCONV(mid_ch*2,mid_ch,dirate=1)
60
- self.rebnconv5d = REBNCONV(mid_ch*2,mid_ch,dirate=1)
61
- self.rebnconv4d = REBNCONV(mid_ch*2,mid_ch,dirate=1)
62
- self.rebnconv3d = REBNCONV(mid_ch*2,mid_ch,dirate=1)
63
- self.rebnconv2d = REBNCONV(mid_ch*2,mid_ch,dirate=1)
64
- self.rebnconv1d = REBNCONV(mid_ch*2,out_ch,dirate=1)
65
-
66
- def forward(self,x):
67
- b, c, h, w = x.shape
68
-
69
- hx = x
70
- hxin = self.rebnconvin(hx)
71
-
72
- hx1 = self.rebnconv1(hxin)
73
- hx = self.pool1(hx1)
74
-
75
- hx2 = self.rebnconv2(hx)
76
- hx = self.pool2(hx2)
77
-
78
- hx3 = self.rebnconv3(hx)
79
- hx = self.pool3(hx3)
80
-
81
- hx4 = self.rebnconv4(hx)
82
- hx = self.pool4(hx4)
83
-
84
- hx5 = self.rebnconv5(hx)
85
- hx = self.pool5(hx5)
86
-
87
- hx6 = self.rebnconv6(hx)
88
-
89
- hx7 = self.rebnconv7(hx6)
90
-
91
- hx6d = self.rebnconv6d(torch.cat((hx7,hx6),1))
92
- hx6dup = _upsample_like(hx6d,hx5)
93
-
94
- hx5d = self.rebnconv5d(torch.cat((hx6dup,hx5),1))
95
- hx5dup = _upsample_like(hx5d,hx4)
96
-
97
- hx4d = self.rebnconv4d(torch.cat((hx5dup,hx4),1))
98
- hx4dup = _upsample_like(hx4d,hx3)
99
-
100
- hx3d = self.rebnconv3d(torch.cat((hx4dup,hx3),1))
101
- hx3dup = _upsample_like(hx3d,hx2)
102
-
103
- hx2d = self.rebnconv2d(torch.cat((hx3dup,hx2),1))
104
- hx2dup = _upsample_like(hx2d,hx1)
105
-
106
- hx1d = self.rebnconv1d(torch.cat((hx2dup,hx1),1))
107
-
108
- return hx1d + hxin
109
-
110
-
111
- ### RSU-6 ###
112
- class RSU6(nn.Module):
113
-
114
- def __init__(self, in_ch=3, mid_ch=12, out_ch=3):
115
- super(RSU6,self).__init__()
116
-
117
- self.rebnconvin = REBNCONV(in_ch,out_ch,dirate=1)
118
-
119
- self.rebnconv1 = REBNCONV(out_ch,mid_ch,dirate=1)
120
- self.pool1 = nn.MaxPool2d(2,stride=2,ceil_mode=True)
121
-
122
- self.rebnconv2 = REBNCONV(mid_ch,mid_ch,dirate=1)
123
- self.pool2 = nn.MaxPool2d(2,stride=2,ceil_mode=True)
124
-
125
- self.rebnconv3 = REBNCONV(mid_ch,mid_ch,dirate=1)
126
- self.pool3 = nn.MaxPool2d(2,stride=2,ceil_mode=True)
127
-
128
- self.rebnconv4 = REBNCONV(mid_ch,mid_ch,dirate=1)
129
- self.pool4 = nn.MaxPool2d(2,stride=2,ceil_mode=True)
130
-
131
- self.rebnconv5 = REBNCONV(mid_ch,mid_ch,dirate=1)
132
-
133
- self.rebnconv6 = REBNCONV(mid_ch,mid_ch,dirate=2)
134
-
135
- self.rebnconv5d = REBNCONV(mid_ch*2,mid_ch,dirate=1)
136
- self.rebnconv4d = REBNCONV(mid_ch*2,mid_ch,dirate=1)
137
- self.rebnconv3d = REBNCONV(mid_ch*2,mid_ch,dirate=1)
138
- self.rebnconv2d = REBNCONV(mid_ch*2,mid_ch,dirate=1)
139
- self.rebnconv1d = REBNCONV(mid_ch*2,out_ch,dirate=1)
140
-
141
- def forward(self,x):
142
-
143
- hx = x
144
-
145
- hxin = self.rebnconvin(hx)
146
-
147
- hx1 = self.rebnconv1(hxin)
148
- hx = self.pool1(hx1)
149
-
150
- hx2 = self.rebnconv2(hx)
151
- hx = self.pool2(hx2)
152
-
153
- hx3 = self.rebnconv3(hx)
154
- hx = self.pool3(hx3)
155
-
156
- hx4 = self.rebnconv4(hx)
157
- hx = self.pool4(hx4)
158
-
159
- hx5 = self.rebnconv5(hx)
160
-
161
- hx6 = self.rebnconv6(hx5)
162
-
163
-
164
- hx5d = self.rebnconv5d(torch.cat((hx6,hx5),1))
165
- hx5dup = _upsample_like(hx5d,hx4)
166
-
167
- hx4d = self.rebnconv4d(torch.cat((hx5dup,hx4),1))
168
- hx4dup = _upsample_like(hx4d,hx3)
169
-
170
- hx3d = self.rebnconv3d(torch.cat((hx4dup,hx3),1))
171
- hx3dup = _upsample_like(hx3d,hx2)
172
-
173
- hx2d = self.rebnconv2d(torch.cat((hx3dup,hx2),1))
174
- hx2dup = _upsample_like(hx2d,hx1)
175
-
176
- hx1d = self.rebnconv1d(torch.cat((hx2dup,hx1),1))
177
-
178
- return hx1d + hxin
179
-
180
- ### RSU-5 ###
181
- class RSU5(nn.Module):
182
-
183
- def __init__(self, in_ch=3, mid_ch=12, out_ch=3):
184
- super(RSU5,self).__init__()
185
-
186
- self.rebnconvin = REBNCONV(in_ch,out_ch,dirate=1)
187
-
188
- self.rebnconv1 = REBNCONV(out_ch,mid_ch,dirate=1)
189
- self.pool1 = nn.MaxPool2d(2,stride=2,ceil_mode=True)
190
-
191
- self.rebnconv2 = REBNCONV(mid_ch,mid_ch,dirate=1)
192
- self.pool2 = nn.MaxPool2d(2,stride=2,ceil_mode=True)
193
-
194
- self.rebnconv3 = REBNCONV(mid_ch,mid_ch,dirate=1)
195
- self.pool3 = nn.MaxPool2d(2,stride=2,ceil_mode=True)
196
-
197
- self.rebnconv4 = REBNCONV(mid_ch,mid_ch,dirate=1)
198
-
199
- self.rebnconv5 = REBNCONV(mid_ch,mid_ch,dirate=2)
200
-
201
- self.rebnconv4d = REBNCONV(mid_ch*2,mid_ch,dirate=1)
202
- self.rebnconv3d = REBNCONV(mid_ch*2,mid_ch,dirate=1)
203
- self.rebnconv2d = REBNCONV(mid_ch*2,mid_ch,dirate=1)
204
- self.rebnconv1d = REBNCONV(mid_ch*2,out_ch,dirate=1)
205
-
206
- def forward(self,x):
207
-
208
- hx = x
209
-
210
- hxin = self.rebnconvin(hx)
211
-
212
- hx1 = self.rebnconv1(hxin)
213
- hx = self.pool1(hx1)
214
-
215
- hx2 = self.rebnconv2(hx)
216
- hx = self.pool2(hx2)
217
-
218
- hx3 = self.rebnconv3(hx)
219
- hx = self.pool3(hx3)
220
-
221
- hx4 = self.rebnconv4(hx)
222
-
223
- hx5 = self.rebnconv5(hx4)
224
-
225
- hx4d = self.rebnconv4d(torch.cat((hx5,hx4),1))
226
- hx4dup = _upsample_like(hx4d,hx3)
227
-
228
- hx3d = self.rebnconv3d(torch.cat((hx4dup,hx3),1))
229
- hx3dup = _upsample_like(hx3d,hx2)
230
-
231
- hx2d = self.rebnconv2d(torch.cat((hx3dup,hx2),1))
232
- hx2dup = _upsample_like(hx2d,hx1)
233
-
234
- hx1d = self.rebnconv1d(torch.cat((hx2dup,hx1),1))
235
-
236
- return hx1d + hxin
237
-
238
- ### RSU-4 ###
239
- class RSU4(nn.Module):
240
-
241
- def __init__(self, in_ch=3, mid_ch=12, out_ch=3):
242
- super(RSU4,self).__init__()
243
-
244
- self.rebnconvin = REBNCONV(in_ch,out_ch,dirate=1)
245
-
246
- self.rebnconv1 = REBNCONV(out_ch,mid_ch,dirate=1)
247
- self.pool1 = nn.MaxPool2d(2,stride=2,ceil_mode=True)
248
-
249
- self.rebnconv2 = REBNCONV(mid_ch,mid_ch,dirate=1)
250
- self.pool2 = nn.MaxPool2d(2,stride=2,ceil_mode=True)
251
-
252
- self.rebnconv3 = REBNCONV(mid_ch,mid_ch,dirate=1)
253
-
254
- self.rebnconv4 = REBNCONV(mid_ch,mid_ch,dirate=2)
255
-
256
- self.rebnconv3d = REBNCONV(mid_ch*2,mid_ch,dirate=1)
257
- self.rebnconv2d = REBNCONV(mid_ch*2,mid_ch,dirate=1)
258
- self.rebnconv1d = REBNCONV(mid_ch*2,out_ch,dirate=1)
259
-
260
- def forward(self,x):
261
-
262
- hx = x
263
-
264
- hxin = self.rebnconvin(hx)
265
-
266
- hx1 = self.rebnconv1(hxin)
267
- hx = self.pool1(hx1)
268
-
269
- hx2 = self.rebnconv2(hx)
270
- hx = self.pool2(hx2)
271
-
272
- hx3 = self.rebnconv3(hx)
273
-
274
- hx4 = self.rebnconv4(hx3)
275
-
276
- hx3d = self.rebnconv3d(torch.cat((hx4,hx3),1))
277
- hx3dup = _upsample_like(hx3d,hx2)
278
-
279
- hx2d = self.rebnconv2d(torch.cat((hx3dup,hx2),1))
280
- hx2dup = _upsample_like(hx2d,hx1)
281
-
282
- hx1d = self.rebnconv1d(torch.cat((hx2dup,hx1),1))
283
-
284
- return hx1d + hxin
285
-
286
- ### RSU-4F ###
287
- class RSU4F(nn.Module):
288
-
289
- def __init__(self, in_ch=3, mid_ch=12, out_ch=3):
290
- super(RSU4F,self).__init__()
291
-
292
- self.rebnconvin = REBNCONV(in_ch,out_ch,dirate=1)
293
-
294
- self.rebnconv1 = REBNCONV(out_ch,mid_ch,dirate=1)
295
- self.rebnconv2 = REBNCONV(mid_ch,mid_ch,dirate=2)
296
- self.rebnconv3 = REBNCONV(mid_ch,mid_ch,dirate=4)
297
-
298
- self.rebnconv4 = REBNCONV(mid_ch,mid_ch,dirate=8)
299
-
300
- self.rebnconv3d = REBNCONV(mid_ch*2,mid_ch,dirate=4)
301
- self.rebnconv2d = REBNCONV(mid_ch*2,mid_ch,dirate=2)
302
- self.rebnconv1d = REBNCONV(mid_ch*2,out_ch,dirate=1)
303
-
304
- def forward(self,x):
305
-
306
- hx = x
307
-
308
- hxin = self.rebnconvin(hx)
309
-
310
- hx1 = self.rebnconv1(hxin)
311
- hx2 = self.rebnconv2(hx1)
312
- hx3 = self.rebnconv3(hx2)
313
-
314
- hx4 = self.rebnconv4(hx3)
315
-
316
- hx3d = self.rebnconv3d(torch.cat((hx4,hx3),1))
317
- hx2d = self.rebnconv2d(torch.cat((hx3d,hx2),1))
318
- hx1d = self.rebnconv1d(torch.cat((hx2d,hx1),1))
319
-
320
- return hx1d + hxin
321
-
322
-
323
- class myrebnconv(nn.Module):
324
- def __init__(self, in_ch=3,
325
- out_ch=1,
326
- kernel_size=3,
327
- stride=1,
328
- padding=1,
329
- dilation=1,
330
- groups=1):
331
- super(myrebnconv,self).__init__()
332
-
333
- self.conv = nn.Conv2d(in_ch,
334
- out_ch,
335
- kernel_size=kernel_size,
336
- stride=stride,
337
- padding=padding,
338
- dilation=dilation,
339
- groups=groups)
340
- self.bn = nn.BatchNorm2d(out_ch)
341
- self.rl = nn.ReLU(inplace=True)
342
-
343
- def forward(self,x):
344
- return self.rl(self.bn(self.conv(x)))
345
-
346
-
347
- class BriaRMBG(nn.Module):
348
-
349
- def __init__(self,in_ch=3,out_ch=1):
350
- super(BriaRMBG,self).__init__()
351
-
352
- self.conv_in = nn.Conv2d(in_ch,64,3,stride=2,padding=1)
353
- self.pool_in = nn.MaxPool2d(2,stride=2,ceil_mode=True)
354
-
355
- self.stage1 = RSU7(64,32,64)
356
- self.pool12 = nn.MaxPool2d(2,stride=2,ceil_mode=True)
357
-
358
- self.stage2 = RSU6(64,32,128)
359
- self.pool23 = nn.MaxPool2d(2,stride=2,ceil_mode=True)
360
-
361
- self.stage3 = RSU5(128,64,256)
362
- self.pool34 = nn.MaxPool2d(2,stride=2,ceil_mode=True)
363
-
364
- self.stage4 = RSU4(256,128,512)
365
- self.pool45 = nn.MaxPool2d(2,stride=2,ceil_mode=True)
366
-
367
- self.stage5 = RSU4F(512,256,512)
368
- self.pool56 = nn.MaxPool2d(2,stride=2,ceil_mode=True)
369
-
370
- self.stage6 = RSU4F(512,256,512)
371
-
372
- # decoder
373
- self.stage5d = RSU4F(1024,256,512)
374
- self.stage4d = RSU4(1024,128,256)
375
- self.stage3d = RSU5(512,64,128)
376
- self.stage2d = RSU6(256,32,64)
377
- self.stage1d = RSU7(128,16,64)
378
-
379
- self.side1 = nn.Conv2d(64,out_ch,3,padding=1)
380
- self.side2 = nn.Conv2d(64,out_ch,3,padding=1)
381
- self.side3 = nn.Conv2d(128,out_ch,3,padding=1)
382
- self.side4 = nn.Conv2d(256,out_ch,3,padding=1)
383
- self.side5 = nn.Conv2d(512,out_ch,3,padding=1)
384
- self.side6 = nn.Conv2d(512,out_ch,3,padding=1)
385
-
386
- # self.outconv = nn.Conv2d(6*out_ch,out_ch,1)
387
-
388
- def forward(self,x):
389
-
390
- hx = x
391
-
392
- hxin = self.conv_in(hx)
393
- #hx = self.pool_in(hxin)
394
-
395
- #stage 1
396
- hx1 = self.stage1(hxin)
397
- hx = self.pool12(hx1)
398
-
399
- #stage 2
400
- hx2 = self.stage2(hx)
401
- hx = self.pool23(hx2)
402
-
403
- #stage 3
404
- hx3 = self.stage3(hx)
405
- hx = self.pool34(hx3)
406
-
407
- #stage 4
408
- hx4 = self.stage4(hx)
409
- hx = self.pool45(hx4)
410
-
411
- #stage 5
412
- hx5 = self.stage5(hx)
413
- hx = self.pool56(hx5)
414
-
415
- #stage 6
416
- hx6 = self.stage6(hx)
417
- hx6up = _upsample_like(hx6,hx5)
418
-
419
- #-------------------- decoder --------------------
420
- hx5d = self.stage5d(torch.cat((hx6up,hx5),1))
421
- hx5dup = _upsample_like(hx5d,hx4)
422
-
423
- hx4d = self.stage4d(torch.cat((hx5dup,hx4),1))
424
- hx4dup = _upsample_like(hx4d,hx3)
425
-
426
- hx3d = self.stage3d(torch.cat((hx4dup,hx3),1))
427
- hx3dup = _upsample_like(hx3d,hx2)
428
-
429
- hx2d = self.stage2d(torch.cat((hx3dup,hx2),1))
430
- hx2dup = _upsample_like(hx2d,hx1)
431
-
432
- hx1d = self.stage1d(torch.cat((hx2dup,hx1),1))
433
-
434
-
435
- #side output
436
- d1 = self.side1(hx1d)
437
- d1 = _upsample_like(d1,x)
438
-
439
- d2 = self.side2(hx2d)
440
- d2 = _upsample_like(d2,x)
441
-
442
- d3 = self.side3(hx3d)
443
- d3 = _upsample_like(d3,x)
444
-
445
- d4 = self.side4(hx4d)
446
- d4 = _upsample_like(d4,x)
447
-
448
- d5 = self.side5(hx5d)
449
- d5 = _upsample_like(d5,x)
450
-
451
- d6 = self.side6(hx6)
452
- d6 = _upsample_like(d6,x)
453
-
454
- return [F.sigmoid(d1), F.sigmoid(d2), F.sigmoid(d3), F.sigmoid(d4), F.sigmoid(d5), F.sigmoid(d6)],[hx1d,hx2d,hx3d,hx4d,hx5d,hx6]