Spaces:
Build error
Build error
File size: 4,110 Bytes
cd97fb0 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 |
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
class SRMConv2d_simple(nn.Module):
    """Bank of three 5x5 SRM high-pass filters with clipped output.

    The three filters (KB, KV and a horizontal second-order filter) are each
    applied across all ``inc`` input channels, summing over channels, which
    yields a 3-channel residual map. The result is clipped to [-3, 3].

    Args:
        inc: number of input channels.
        learnable: whether the filter weights receive gradients.
    """

    def __init__(self, inc=3, learnable=False):
        super().__init__()
        # Truncate residuals to [-3, 3].
        self.truc = nn.Hardtanh(-3, 3)
        kernel = self._build_kernel(inc)  # (3, inc, 5, 5)
        self.kernel = nn.Parameter(data=kernel, requires_grad=learnable)

    def forward(self, x):
        """Apply the SRM filters to ``x`` and clip the responses.

        Args:
            x: image batch of shape (Batch, inc, H, W); F.conv2d expects
                channels-first input.

        Returns:
            Tensor of shape (Batch, 3, H, W). padding=2 preserves H and W
            for the 5x5 kernels.
        """
        out = F.conv2d(x, self.kernel, stride=1, padding=2)
        return self.truc(out)

    def _build_kernel(self, inc):
        """Return the SRM filter bank as a float32 tensor of shape (3, inc, 5, 5)."""
        # filter1: KB
        filter1 = [[0, 0, 0, 0, 0],
                   [0, -1, 2, -1, 0],
                   [0, 2, -4, 2, 0],
                   [0, -1, 2, -1, 0],
                   [0, 0, 0, 0, 0]]
        # filter2: KV
        filter2 = [[-1, 2, -2, 2, -1],
                   [2, -6, 8, -6, 2],
                   [-2, 8, -12, 8, -2],
                   [2, -6, 8, -6, 2],
                   [-1, 2, -2, 2, -1]]
        # filter3: horizontal 2nd-order
        filter3 = [[0, 0, 0, 0, 0],
                   [0, 0, 0, 0, 0],
                   [0, 1, -2, 1, 0],
                   [0, 0, 0, 0, 0],
                   [0, 0, 0, 0, 0]]
        filter1 = np.asarray(filter1, dtype=float) / 4.
        filter2 = np.asarray(filter2, dtype=float) / 12.
        filter3 = np.asarray(filter3, dtype=float) / 2.
        # Stack the filters as output channels: (3, 1, 5, 5).
        filters = np.stack([filter1, filter2, filter3])[:, np.newaxis]
        # Every output filter sees every input channel: (3, inc, 5, 5).
        filters = np.repeat(filters, inc, axis=1)
        return torch.tensor(filters, dtype=torch.float32)
class SRMConv2d_Separate(nn.Module):
    """Per-channel SRM filtering followed by a 1x1 conv, BatchNorm and ReLU.

    Each input channel is convolved independently (grouped convolution with
    ``groups=inc``) with all three 5x5 SRM filters (KB, KV, horizontal
    2nd-order). This gives ``3 * inc`` residual maps, which are clipped to
    [-3, 3] and projected to ``outc`` channels.

    Args:
        inc: number of input channels.
        outc: number of output channels.
        learnable: whether the SRM filter weights receive gradients.
    """

    def __init__(self, inc, outc, learnable=False):
        super().__init__()
        self.inc = inc
        # Truncate residuals to [-3, 3].
        self.truc = nn.Hardtanh(-3, 3)
        kernel = self._build_kernel(inc)  # (3 * inc, 1, 5, 5)
        self.kernel = nn.Parameter(data=kernel, requires_grad=learnable)
        self.out_conv = nn.Sequential(
            nn.Conv2d(3 * inc, outc, kernel_size=1, stride=1, padding=0,
                      dilation=1, groups=1, bias=False),
            nn.BatchNorm2d(outc),
            nn.ReLU(inplace=True),
        )
        for layer in self.out_conv.children():
            if isinstance(layer, nn.Conv2d):
                nn.init.kaiming_normal_(layer.weight, a=1)

    def forward(self, x):
        """Filter each channel of ``x`` with the SRM bank and project the result.

        Args:
            x: image batch of shape (Batch, inc, H, W); F.conv2d expects
                channels-first input.

        Returns:
            Tensor of shape (Batch, outc, H, W).
        """
        out = F.conv2d(x, self.kernel, stride=1, padding=2, groups=self.inc)
        out = self.truc(out)
        return self.out_conv(out)

    def _build_kernel(self, inc):
        """Return the grouped SRM kernel as a float32 tensor of shape (3 * inc, 1, 5, 5)."""
        # filter1: KB
        filter1 = [[0, 0, 0, 0, 0],
                   [0, -1, 2, -1, 0],
                   [0, 2, -4, 2, 0],
                   [0, -1, 2, -1, 0],
                   [0, 0, 0, 0, 0]]
        # filter2: KV
        filter2 = [[-1, 2, -2, 2, -1],
                   [2, -6, 8, -6, 2],
                   [-2, 8, -12, 8, -2],
                   [2, -6, 8, -6, 2],
                   [-1, 2, -2, 2, -1]]
        # filter3: horizontal 2nd-order
        filter3 = [[0, 0, 0, 0, 0],
                   [0, 0, 0, 0, 0],
                   [0, 1, -2, 1, 0],
                   [0, 0, 0, 0, 0],
                   [0, 0, 0, 0, 0]]
        filter1 = np.asarray(filter1, dtype=float) / 4.
        filter2 = np.asarray(filter2, dtype=float) / 12.
        filter3 = np.asarray(filter3, dtype=float) / 2.
        # Stack the filters as output channels: (3, 1, 5, 5).
        filters = np.stack([filter1, filter2, filter3])[:, np.newaxis]
        # With groups=inc, output channels [3*g, 3*g + 3) belong to input
        # channel g. Tiling gives the order (f1, f2, f3, f1, f2, f3, ...), so
        # each channel gets all three filters. np.repeat(axis=0) would give
        # (f1, f1, ..., f2, f2, ...), i.e. three copies of a single filter
        # per channel.
        filters = np.tile(filters, (inc, 1, 1, 1))  # (3 * inc, 1, 5, 5)
        return torch.tensor(filters, dtype=torch.float32)
|