Upload 12 files

Files changed:
- README.md +4 -8
- app.py +53 -48
- archs/DarkIR.py +154 -0
- archs/__init__.py +231 -3
- archs/arch_model.py +240 -0
- archs/arch_util.py +58 -111
- darkir_v2real+lsrw.pt +3 -0
- examples/0010.png +0 -0
- examples/0087.png +0 -0
- examples/low00733_low.png +0 -0
- examples/r13073518t_low.png +0 -0
- requirements.txt +8 -97
README.md CHANGED
@@ -1,13 +1,9 @@
-
-
-emoji: 📉
+title: DarkIR
+emoji: 🌻
 colorFrom: red
 colorTo: gray
 sdk: gradio
-sdk_version:
+sdk_version: 5.8.0
 app_file: app.py
 pinned: false
-license: mit
----
-
-Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
+license: mit
app.py CHANGED
@@ -1,84 +1,89 @@
 import gradio as gr
 from PIL import Image
-import os
 import torch
-import torch.nn.functional as F
 import torchvision.transforms as transforms
-import
-
-import
-from huggingface_hub import hf_hub_download
+import torch.nn.functional as F
+
+from archs import DarkIR
 
-from archs import Network
-from options.options import parse
 
-path_opt = './options/predict/LOLBlur.yml'
 
-opt = parse(path_opt)
 device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
 #define some auxiliary functions
 pil_to_tensor = transforms.ToTensor()
+tensor_to_pil = transforms.ToPILImage()
 
-
+network = 'DarkIR'
 
-
-               width=opt['network']['width'],
-               middle_blk_num_enc=opt['network']['middle_blk_num_enc'],
-               middle_blk_num_dec=opt['network']['middle_blk_num_dec'],
-               enc_blk_nums=opt['network']['enc_blk_nums'],
-               dec_blk_nums=opt['network']['dec_blk_nums'],
-               dilations=opt['network']['dilations'],
-               extra_depth_wise = opt['network']['extra_depth_wise'])
+PATH_MODEL = './models/darkir_1k_allv2_251205.pt'
+
+model = DarkIR(img_channel=3,
+               width=32,
+               middle_blk_num_enc=2,
+               middle_blk_num_dec=2,
+               enc_blk_nums=[1, 2, 3],
+               dec_blk_nums=[3, 1, 1],
+               dilations=[1, 4, 9],
+               extra_depth_wise=True)
 
-
-
-
+checkpoints = torch.load(PATH_MODEL, map_location=device)
+model.load_state_dict(checkpoints['params'])
 
 model = model.to(device)
 
-def
-    img = Image.open(
-
-
+def path_to_tensor(path):
+    img = Image.open(path).convert('RGB')
+    img = pil_to_tensor(img).unsqueeze(0)
+
+    return img
+def normalize_tensor(tensor):
+
+    max_value = torch.max(tensor)
+    min_value = torch.min(tensor)
+    output = (tensor - min_value)/(max_value)
+    return output
+
+def pad_tensor(tensor, multiple = 8):
+    '''pad the tensor to be multiple of some number'''
+    multiple = multiple
+    _, _, H, W = tensor.shape
+    pad_h = (multiple - H % multiple) % multiple
+    pad_w = (multiple - W % multiple) % multiple
+    tensor = F.pad(tensor, (0, pad_w, 0, pad_h), value = 0)
+
+    return tensor
 
 def process_img(image):
-
-
-
-
+    tensor = path_to_tensor(image).to(device)
+    _, _, H, W = tensor.shape
+
+    tensor = pad_tensor(tensor)
 
     with torch.no_grad():
-
-
-        restored_img = x_hat.squeeze().permute(1,2,0).clamp_(0, 1).cpu().detach().numpy()
-        restored_img = np.clip(restored_img, 0. , 1.)
+        output = model(tensor, side_loss=False)
 
-
-
+    output = torch.clamp(output, 0., 1.)
+    output = output[:,:, :H, :W].squeeze(0)
+    return tensor_to_pil(output)
 
-title = "
-description = ''' ## [Low
+title = "DarkIR ✏️🖼️ 🤗"
+description = ''' ## [ DarkIR: Robust Low-Light Image Restoration](https://github.com/cidautai/DarkIR)
 
 [Daniel Feijoo](https://github.com/danifei)
 
 Fundación Cidaut
 
-This model enhances low light images into normal light conditions ones. It was trained using LOLv2-real, LOLv2-synth and LOLBlur.
-Due to the training on LOLBlur, this network is expected to also reconstruct blurred low light images.
 
 > **Disclaimer:** please remember this is not a product, thus, you will notice some limitations.
-**This demo expects an image with some degradations.**
-Due to the CPU limitations, the model won't return results inmediately <br>.
-Except for the LOLv2-real, the model was trained using mostly synthetic data, thus it might not work great on real-world complex images.
+**This demo expects an image with some Low-Light degradations.**
 
 <br>
 '''
 
-examples = [['examples/
-            ['examples/
-            ['examples/
-            ["examples/
-            ["examples/inputs/0088.png"]]
+examples = [['examples/0010.png'],
+            ['examples/r13073518t_low.png'],
+            ['examples/low00733_low.png'],
+            ["examples/0087.png"]]
 
 css = """
 .image-frame img, .image-container img {
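The padding helper added above is easy to check in isolation: it pads height and width on the bottom/right so both become multiples of the U-Net's downsampling factor. A minimal sketch (the input shape is illustrative, not taken from the Space):

```python
import torch
import torch.nn.functional as F

def pad_tensor(tensor, multiple=8):
    # pad H and W on the bottom/right until both are multiples of `multiple`
    _, _, H, W = tensor.shape
    pad_h = (multiple - H % multiple) % multiple
    pad_w = (multiple - W % multiple) % multiple
    return F.pad(tensor, (0, pad_w, 0, pad_h), value=0)

x = torch.rand(1, 3, 250, 123)   # arbitrary odd-sized input
y = pad_tensor(x)
print(y.shape)                   # torch.Size([1, 3, 256, 128])
assert y.shape[2] % 8 == 0 and y.shape[3] % 8 == 0
```

`process_img` keeps the original `(H, W)` around and crops the model output back to it, so the zero padding never leaks into the returned image.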
archs/DarkIR.py ADDED
@@ -0,0 +1,154 @@
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+try:
+    from arch_model import EBlock, DBlock
+    from arch_util import CustomSequential
+except:
+    from archs.arch_model import EBlock, DBlock
+    from .arch_util import CustomSequential
+
+class DarkIR(nn.Module):
+
+    def __init__(self, img_channel=3,
+                 width=32,
+                 middle_blk_num_enc=2,
+                 middle_blk_num_dec=2,
+                 enc_blk_nums=[1, 2, 3],
+                 dec_blk_nums=[3, 1, 1],
+                 dilations = [1, 4, 9],
+                 extra_depth_wise = True):
+        super(DarkIR, self).__init__()
+
+        self.intro = nn.Conv2d(in_channels=img_channel, out_channels=width, kernel_size=3, padding=1, stride=1, groups=1,
+                               bias=True)
+        self.ending = nn.Conv2d(in_channels=width, out_channels=img_channel, kernel_size=3, padding=1, stride=1, groups=1,
+                                bias=True)
+
+        self.encoders = nn.ModuleList()
+        self.decoders = nn.ModuleList()
+        self.middle_blks = nn.ModuleList()
+        self.ups = nn.ModuleList()
+        self.downs = nn.ModuleList()
+
+        chan = width
+        for num in enc_blk_nums:
+            self.encoders.append(
+                CustomSequential(
+                    *[EBlock(chan, extra_depth_wise=extra_depth_wise) for _ in range(num)]
+                )
+            )
+            self.downs.append(
+                nn.Conv2d(chan, 2*chan, 2, 2)
+            )
+            chan = chan * 2
+
+        self.middle_blks_enc = \
+            CustomSequential(
+                *[EBlock(chan, extra_depth_wise=extra_depth_wise) for _ in range(middle_blk_num_enc)]
+            )
+        self.middle_blks_dec = \
+            CustomSequential(
+                *[DBlock(chan, dilations=dilations, extra_depth_wise=extra_depth_wise) for _ in range(middle_blk_num_dec)]
+            )
+
+        for num in dec_blk_nums:
+            self.ups.append(
+                nn.Sequential(
+                    nn.Conv2d(chan, chan * 2, 1, bias=False),
+                    nn.PixelShuffle(2)
+                )
+            )
+            chan = chan // 2
+            self.decoders.append(
+                CustomSequential(
+                    *[DBlock(chan, dilations=dilations, extra_depth_wise=extra_depth_wise) for _ in range(num)]
+                )
+            )
+        self.padder_size = 2 ** len(self.encoders)
+
+        # this layer is needed for the computing of the middle loss. It isn't necessary for anything else
+        self.side_out = nn.Conv2d(in_channels = width * 2**len(self.encoders), out_channels = img_channel,
+                                  kernel_size = 3, stride=1, padding=1)
+
+    def forward(self, input, side_loss = False, use_adapter = None):
+
+        _, _, H, W = input.shape
+
+        input = self.check_image_size(input)
+        x = self.intro(input)
+
+        skips = []
+        for encoder, down in zip(self.encoders, self.downs):
+            x = encoder(x)
+            skips.append(x)
+            x = down(x)
+
+        # we apply the encoder transforms
+        x_light = self.middle_blks_enc(x)
+
+        if side_loss:
+            out_side = self.side_out(x_light)
+        # apply the decoder transforms
+        x = self.middle_blks_dec(x_light)
+        x = x + x_light
+
+        for decoder, up, skip in zip(self.decoders, self.ups, skips[::-1]):
+            x = up(x)
+            x = x + skip
+            x = decoder(x)
+
+        x = self.ending(x)
+        x = x + input
+        out = x[:, :, :H, :W] # we recover the original size of the image
+        if side_loss:
+            return out_side, out
+        else:
+            return out
+
+    def check_image_size(self, x):
+        _, _, h, w = x.size()
+        mod_pad_h = (self.padder_size - h % self.padder_size) % self.padder_size
+        mod_pad_w = (self.padder_size - w % self.padder_size) % self.padder_size
+        x = F.pad(x, (0, mod_pad_w, 0, mod_pad_h), value = 0)
+        return x
+
+if __name__ == '__main__':
+
+    img_channel = 3
+    width = 64
+
+    enc_blks = [1, 2, 3]
+    middle_blk_num_enc = 2
+    middle_blk_num_dec = 2
+    dec_blks = [3, 1, 1]
+    residual_layers = None
+    dilations = [1, 4, 9]
+    extra_depth_wise = True
+
+    net = DarkIR(img_channel=img_channel,
+                 width=width,
+                 middle_blk_num_enc=middle_blk_num_enc,
+                 middle_blk_num_dec= middle_blk_num_dec,
+                 enc_blk_nums=enc_blks,
+                 dec_blk_nums=dec_blks,
+                 dilations = dilations,
+                 extra_depth_wise = extra_depth_wise)
+
+    new_state_dict = net.state_dict()
+
+    inp_shape = (3, 256, 256)
+
+    net.load_state_dict(new_state_dict)
+
+    from ptflops import get_model_complexity_info
+
+    macs, params = get_model_complexity_info(net, inp_shape, verbose=False, print_per_layer_stat=False)
+
+    print(macs, params)
+
+    weights = net.state_dict()
+    adapter_weights = {k: v for k, v in weights.items() if 'adapter' not in k}
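For a quick smoke test of the new architecture with the exact configuration used in app.py, a random tensor is enough. A sketch, assuming the package layout from this commit and that ptflops is installed (archs/__init__.py imports it at load time):

```python
import torch
from archs import DarkIR  # re-exported by archs/__init__.py in this commit

net = DarkIR(img_channel=3, width=32,
             middle_blk_num_enc=2, middle_blk_num_dec=2,
             enc_blk_nums=[1, 2, 3], dec_blk_nums=[3, 1, 1],
             dilations=[1, 4, 9], extra_depth_wise=True)

x = torch.rand(1, 3, 250, 123)        # deliberately not a multiple of padder_size (8)
with torch.no_grad():
    out = net(x, side_loss=False)     # check_image_size pads internally
print(out.shape)                      # torch.Size([1, 3, 250, 123]): original size restored

# with side_loss=True the model also returns the low-resolution side output
with torch.no_grad():
    side, out = net(x, side_loss=True)
print(side.shape)                     # 8x downsampled bottleneck: torch.Size([1, 3, 32, 16])
```

Note that the last encoder stage is residual: the decoder input is `middle_blks_dec(x_light) + x_light`, and the final output adds the (padded) input back, so the network predicts a correction on top of the dark image.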
archs/__init__.py CHANGED
@@ -1,4 +1,232 @@
-
-from .
+import torch
+from torch.optim.lr_scheduler import CosineAnnealingLR
+from torch.nn.parallel import DistributedDataParallel as DDP
+from ptflops import get_model_complexity_info
 
-
+from .DarkIR import DarkIR
+
+def create_model(opt, rank, adapter = False):
+    '''
+    Creates the model.
+    opt: a dictionary from the yaml config key network
+    '''
+    name = opt['name']
+
+
+    model = DarkIR(img_channel=opt['img_channels'],
+                   width=opt['width'],
+                   middle_blk_num_enc=opt['middle_blk_num_enc'],
+                   middle_blk_num_dec=opt['middle_blk_num_dec'],
+                   enc_blk_nums=opt['enc_blk_nums'],
+                   dec_blk_nums=opt['dec_blk_nums'],
+                   dilations=opt['dilations'],
+                   extra_depth_wise=opt['extra_depth_wise'])
+
+    if rank ==0:
+        print(f'Using {name} network')
+
+        input_size = (3, 256, 256)
+        macs, params = get_model_complexity_info(model, input_size, print_per_layer_stat = False)
+        print(f'Computational complexity at {input_size}: {macs}')
+        print('Number of parameters: ', params)
+    else:
+        macs, params = None, None
+
+    model.to(rank)
+
+    model = DDP(model, device_ids=[rank], find_unused_parameters=adapter)
+
+    return model, macs, params
+
+def create_optim_scheduler(opt, model):
+    '''
+    Returns the optim and its scheduler.
+    opt: a dictionary of the yaml config file with the train key
+    '''
+    optim = torch.optim.AdamW( filter(lambda p: p.requires_grad, model.parameters()) ,
+                               lr = opt['lr_initial'],
+                               weight_decay = opt['weight_decay'],
+                               betas = opt['betas'])
+
+    if opt['lr_scheme'] == 'CosineAnnealing':
+        scheduler = CosineAnnealingLR(optim, T_max=opt['epochs'], eta_min=opt['eta_min'])
+    else:
+        raise NotImplementedError('scheduler not implemented')
+
+    return optim, scheduler
+
+def load_weights(model, old_weights):
+    '''
+    Loads the weights of a pretrained model, picking only the weights that are
+    in the new model.
+    '''
+    new_weights = model.state_dict()
+    new_weights.update({k: v for k, v in old_weights.items() if k in new_weights})
+
+    model.load_state_dict(new_weights)
+    return model
+
+def load_optim(optim, optim_weights):
+    '''
+    Loads the values of the optimizer picking only the weights that are in the new model.
+    '''
+    optim_new_weights = optim.state_dict()
+    # optim_new_weights.load_state_dict(optim_weights)
+    optim_new_weights.update({k:v for k, v in optim_weights.items() if k in optim_new_weights})
+    return optim
+
+def resume_model(model,
+                 optim,
+                 scheduler,
+                 path_model,
+                 rank,resume:str=None):
+    '''
+    Returns the loaded weights of model and optimizer if resume flag is True
+    '''
+    map_location = {'cuda:%d' % 0: 'cuda:%d' % rank}
+    if resume:
+        checkpoints = torch.load(path_model, map_location=map_location, weights_only=False)
+        weights = checkpoints['model_state_dict']
+        model = load_weights(model, old_weights=weights)
+        optim = load_optim(optim, optim_weights = checkpoints['optimizer_state_dict'])
+        scheduler.load_state_dict(checkpoints['scheduler_state_dict'])
+        start_epochs = checkpoints['epoch']
+
+        if rank == 0: print('Loaded weights')
+    else:
+        start_epochs = 0
+        if rank==0: print('Starting from zero the training')
+
+    return model, optim, scheduler, start_epochs
+
+def find_different_keys(dict1, dict2):
+
+    # Finding different keys
+    different_keys = set(dict1.keys()) ^ set(dict2.keys())
+
+    return different_keys
+
+def number_common_keys(dict1, dict2):
+    # Finding common keys
+    common_keys = set(dict1.keys()) & set(dict2.keys())
+
+    # Counting the number of common keys
+    common_keys_count = len(common_keys)
+    return common_keys_count
+
+# # Function to add 'modules_list' prefix after the first numeric index
+# def add_middle_prefix(state_dict, middle_prefix, target_strings):
+#     new_state_dict = {}
+#     for key, value in state_dict.items():
+#         for target in target_strings:
+#             if target in key:
+#                 parts = key.split('.')
+#                 # Find the first numeric index after the target string
+#                 for i, part in enumerate(parts):
+#                     if part == target:
+#                         # Insert the middle prefix after the first numeric index
+#                         if i + 1 < len(parts) and parts[i + 1].isdigit():
+#                             parts.insert(i + 2, middle_prefix)
+#                             break
+#                 new_key = '.'.join(parts)
+#                 new_state_dict[new_key] = value
+#                 break
+#         else:
+#             new_state_dict[key] = value
+#     return new_state_dict
+
+# # Function to adjust keys for 'middle_blks.' prefix
+# def adjust_middle_blks_keys(state_dict, target_prefix, middle_prefix):
+#     new_state_dict = {}
+#     for key, value in state_dict.items():
+#         if target_prefix in key:
+#             parts = key.split('.')
+#             # Find the target prefix and adjust the key
+#             for i, part in enumerate(parts):
+#                 if part == target_prefix.rstrip('.'):
+#                     if i + 1 < len(parts) and parts[i + 1].isdigit():
+#                         # Swap the numerical part and the middle prefix
+#                         new_key = '.'.join(parts[:i + 1] + [middle_prefix] + parts[i + 1:i + 2] + parts[i + 2:])
+#                         new_state_dict[new_key] = value
+#                         break
+#         else:
+#             new_state_dict[key] = value
+#     return new_state_dict
+
+# def resume_nafnet(model,
+#                   optim,
+#                   scheduler,
+#                   path_adapter,
+#                   path_model,
+#                   rank, resume:str=None):
+#     '''
+#     Returns the loaded weights of model and optimizer if resume flag is True
+#     '''
+#     map_location = {'cuda:%d' % 0: 'cuda:%d' % rank}
+#     #first load the model weights
+#     checkpoints = torch.load(path_model, map_location=map_location, weights_only=False)
+#     weights = checkpoints
+#     if rank==0:
+#         print(len(weights), len(model.state_dict().keys()))
+
+#     different_keys = find_different_keys(weights, model.state_dict())
+#     filtered_keys = {item for item in different_keys if 'adapter' not in item}
+#     print(filtered_keys)
+#     print(len(filtered_keys))
+#     model = load_weights(model, old_weights=weights)
+#     #now if needed load the adapter weights
+#     if resume:
+#         checkpoints = torch.load(path_adapter, map_location=map_location, weights_only=False)
+#         weights = checkpoints
+#         model = load_weights(model, old_weights=weights)
+#         # optim = load_optim(optim, optim_weights = checkpoints['optimizer_state_dict'])
+#         scheduler.load_state_dict(checkpoints['scheduler_state_dict'])
+#         start_epochs = checkpoints['epoch']
+
+#     if rank == 0: print('Loaded weights')
+#     else:
+#         start_epochs = 0
+#         if rank == 0: print('Starting from zero the training')
+
+#     return model, optim, scheduler, start_epochs
+
+def save_checkpoint(model, optim, scheduler, metrics_eval, metrics_train, paths, adapter = False, rank = None):
+
+    '''
+    Save the .pt of the model after each epoch.
+    '''
+    best_psnr = metrics_train['best_psnr']
+    if rank!=0:
+        return best_psnr
+
+    if type(next(iter(metrics_eval.values()))) != dict:
+        metrics_eval = {'metrics': metrics_eval}
+
+    weights = model.state_dict()
+
+    # Save the model after every epoch
+    model_to_save = {
+        'epoch': metrics_train['epoch'],
+        'model_state_dict': weights,
+        'optimizer_state_dict': optim.state_dict(),
+        'loss': metrics_train['train_loss'],
+        'scheduler_state_dict': scheduler.state_dict()
+    }
+
+    try:
+        torch.save(model_to_save, paths['new'])
+
+        # Save best model if new valid_psnr is higher than the best one
+        if next(iter(metrics_eval.values()))['valid_psnr'] >= metrics_train['best_psnr']:
+            torch.save(model_to_save, paths['best'])
+            metrics_train['best_psnr'] = next(iter(metrics_eval.values()))['valid_psnr'] # update best psnr
+    except Exception as e:
+        print(f"Error saving model: {e}")
+    return metrics_train['best_psnr']
+
+__all__ = ['create_model', 'resume_model', 'create_optim_scheduler', 'save_checkpoint',
+           'load_optim', 'load_weights']
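These training helpers are driven by plain dictionaries parsed from the repo's YAML configs. A minimal sketch of `create_optim_scheduler` with hypothetical values (the real keys live under the config's train section; the stand-in module below is mine, not from the commit):

```python
import torch.nn as nn
from archs import create_optim_scheduler

train_opt = {                 # hypothetical values; the real ones come from the YAML config
    'lr_initial': 2e-4,
    'weight_decay': 1e-3,
    'betas': [0.9, 0.9],
    'lr_scheme': 'CosineAnnealing',
    'epochs': 500,
    'eta_min': 1e-6,
}
model = nn.Conv2d(3, 3, 3, padding=1)   # any nn.Module works as a stand-in
optim, scheduler = create_optim_scheduler(train_opt, model)
print(type(optim).__name__, type(scheduler).__name__)  # AdamW CosineAnnealingLR
```

Note that `create_model` wraps the network in DistributedDataParallel and moves it to `rank`, so it assumes an initialized process group; the Gradio app bypasses it and builds `DarkIR` directly.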
archs/arch_model.py ADDED
@@ -0,0 +1,240 @@
+import torch
+import torch.nn as nn
+import torch.nn.init as init
+import torch.nn.functional as F
+
+try:
+    from .arch_util import LayerNorm2d
+except:
+    from arch_util import LayerNorm2d
+
+
+class SimpleGate(nn.Module):
+    def forward(self, x):
+        x1, x2 = x.chunk(2, dim=1)
+        return x1 * x2
+
+class Adapter(nn.Module):
+
+    def __init__(self, c, ffn_channel = None):
+        super().__init__()
+        if ffn_channel:
+            ffn_channel = 2
+        else:
+            ffn_channel = c
+        self.conv1 = nn.Conv2d(in_channels=c, out_channels=ffn_channel, kernel_size=1, padding=0, stride=1, groups=1, bias=True)
+        self.conv2 = nn.Conv2d(in_channels=ffn_channel, out_channels=c, kernel_size=1, padding=0, stride=1, groups=1, bias=True)
+        self.depthwise = nn.Conv2d(in_channels=c, out_channels=ffn_channel, kernel_size=3, padding=1, stride=1, groups=c, bias=True, dilation=1)
+
+    def forward(self, input):
+
+        x = self.conv1(input) + self.depthwise(input)
+        x = self.conv2(x)
+
+        return x
+
+class FreMLP(nn.Module):
+
+    def __init__(self, nc, expand = 2):
+        super(FreMLP, self).__init__()
+        self.process1 = nn.Sequential(
+            nn.Conv2d(nc, expand * nc, 1, 1, 0),
+            nn.LeakyReLU(0.1, inplace=True),
+            nn.Conv2d(expand * nc, nc, 1, 1, 0))
+
+    def forward(self, x):
+        _, _, H, W = x.shape
+        x_freq = torch.fft.rfft2(x, norm='backward')
+        mag = torch.abs(x_freq)
+        pha = torch.angle(x_freq)
+        mag = self.process1(mag)
+        real = mag * torch.cos(pha)
+        imag = mag * torch.sin(pha)
+        x_out = torch.complex(real, imag)
+        x_out = torch.fft.irfft2(x_out, s=(H, W), norm='backward')
+        return x_out
+
+class Branch(nn.Module):
+    '''
+    Branch that lasts lonly the dilated convolutions
+    '''
+    def __init__(self, c, DW_Expand, dilation = 1):
+        super().__init__()
+        self.dw_channel = DW_Expand * c
+
+        self.branch = nn.Sequential(
+            nn.Conv2d(in_channels=self.dw_channel, out_channels=self.dw_channel, kernel_size=3, padding=dilation, stride=1, groups=self.dw_channel,
+                      bias=True, dilation = dilation) # the dconv
+        )
+    def forward(self, input):
+        return self.branch(input)
+
+class DBlock(nn.Module):
+    '''
+    Change this block using Branch
+    '''
+
+    def __init__(self, c, DW_Expand=2, FFN_Expand=2, dilations = [1], extra_depth_wise = False):
+        super().__init__()
+        #we define the 2 branches
+        self.dw_channel = DW_Expand * c
+
+        self.conv1 = nn.Conv2d(in_channels=c, out_channels=self.dw_channel, kernel_size=1, padding=0, stride=1, groups=1, bias=True, dilation = 1)
+        self.extra_conv = nn.Conv2d(self.dw_channel, self.dw_channel, kernel_size=3, padding=1, stride=1, groups=c, bias=True, dilation=1) if extra_depth_wise else nn.Identity() #optional extra dw
+        self.branches = nn.ModuleList()
+        for dilation in dilations:
+            self.branches.append(Branch(self.dw_channel, DW_Expand = 1, dilation = dilation))
+
+        assert len(dilations) == len(self.branches)
+        self.dw_channel = DW_Expand * c
+        self.sca = nn.Sequential(
+            nn.AdaptiveAvgPool2d(1),
+            nn.Conv2d(in_channels=self.dw_channel // 2, out_channels=self.dw_channel // 2, kernel_size=1, padding=0, stride=1,
+                      groups=1, bias=True, dilation = 1),
+        )
+        self.sg1 = SimpleGate()
+        self.sg2 = SimpleGate()
+        self.conv3 = nn.Conv2d(in_channels=self.dw_channel // 2, out_channels=c, kernel_size=1, padding=0, stride=1, groups=1, bias=True, dilation = 1)
+        ffn_channel = FFN_Expand * c
+        self.conv4 = nn.Conv2d(in_channels=c, out_channels=ffn_channel, kernel_size=1, padding=0, stride=1, groups=1, bias=True)
+        self.conv5 = nn.Conv2d(in_channels=ffn_channel // 2, out_channels=c, kernel_size=1, padding=0, stride=1, groups=1, bias=True)
+
+        self.norm1 = LayerNorm2d(c)
+        self.norm2 = LayerNorm2d(c)
+
+        self.gamma = nn.Parameter(torch.zeros((1, c, 1, 1)), requires_grad=True)
+        self.beta = nn.Parameter(torch.zeros((1, c, 1, 1)), requires_grad=True)
+
+
+        # self.adapter = Adapter(c, ffn_channel=None)
+
+        # self.use_adapters = False
+
+    # def set_use_adapters(self, use_adapters):
+    #     self.use_adapters = use_adapters
+
+    def forward(self, inp, adapter = None):
+
+        y = inp
+        x = self.norm1(inp)
+        # x = self.conv1(self.extra_conv(x))
+        x = self.extra_conv(self.conv1(x))
+        z = 0
+        for branch in self.branches:
+            z += branch(x)
+
+        z = self.sg1(z)
+        x = self.sca(z) * z
+        x = self.conv3(x)
+        y = inp + self.beta * x
+        #second step
+        x = self.conv4(self.norm2(y)) # size [B, 2*C, H, W]
+        x = self.sg2(x)  # size [B, C, H, W]
+        x = self.conv5(x) # size [B, C, H, W]
+        x = y + x * self.gamma
+
+        # if self.use_adapters:
+        #     return self.adapter(x)
+        # else:
+        return x
+
+class EBlock(nn.Module):
+    '''
+    Change this block using Branch
+    '''
+
+    def __init__(self, c, DW_Expand=2, dilations = [1], extra_depth_wise = False):
+        super().__init__()
+        #we define the 2 branches
+        self.dw_channel = DW_Expand * c
+        self.extra_conv = nn.Conv2d(c, c, kernel_size=3, padding=1, stride=1, groups=c, bias=True, dilation=1) if extra_depth_wise else nn.Identity() #optional extra dw
+        self.conv1 = nn.Conv2d(in_channels=c, out_channels=self.dw_channel, kernel_size=1, padding=0, stride=1, groups=1, bias=True, dilation = 1)
+
+        self.branches = nn.ModuleList()
+        for dilation in dilations:
+            self.branches.append(Branch(c, DW_Expand, dilation = dilation))
+
+        assert len(dilations) == len(self.branches)
+        self.dw_channel = DW_Expand * c
+        self.sca = nn.Sequential(
+            nn.AdaptiveAvgPool2d(1),
+            nn.Conv2d(in_channels=self.dw_channel // 2, out_channels=self.dw_channel // 2, kernel_size=1, padding=0, stride=1,
+                      groups=1, bias=True, dilation = 1),
+        )
+        self.sg1 = SimpleGate()
+        self.conv3 = nn.Conv2d(in_channels=self.dw_channel // 2, out_channels=c, kernel_size=1, padding=0, stride=1, groups=1, bias=True, dilation = 1)
+        # second step
+
+        self.norm1 = LayerNorm2d(c)
+        self.norm2 = LayerNorm2d(c)
+        self.freq = FreMLP(nc = c, expand=2)
+        self.gamma = nn.Parameter(torch.zeros((1, c, 1, 1)), requires_grad=True)
+        self.beta = nn.Parameter(torch.zeros((1, c, 1, 1)), requires_grad=True)
+
+
+        # self.adapter = Adapter(c, ffn_channel=None)
+
+        # self.use_adapters = False
+
+    # def set_use_adapters(self, use_adapters):
+    #     self.use_adapters = use_adapters
+
+    def forward(self, inp):
+        y = inp
+        x = self.norm1(inp)
+        x = self.conv1(self.extra_conv(x))
+        z = 0
+        for branch in self.branches:
+            z += branch(x)
+
+        z = self.sg1(z)
+        x = self.sca(z) * z
+        x = self.conv3(x)
+        y = inp + self.beta * x
+        #second step
+        x_step2 = self.norm2(y) # size [B, 2*C, H, W]
+        x_freq = self.freq(x_step2) # size [B, C, H, W]
+        x = y * x_freq
+        x = y + x * self.gamma
+
+        # if self.use_adapters:
+        #     return self.adapter(x)
+        # else:
+        return x
+
+#----------------------------------------------------------------------------------------------
+if __name__ == '__main__':
+
+    img_channel = 3
+    width = 32
+
+    enc_blks = [1, 2, 3]
+    middle_blk_num = 3
+    dec_blks = [3, 1, 1]
+    dilations = [1, 4, 9]
+    extra_depth_wise = True
+
+    # net = NAFNet(img_channel=img_channel, width=width, middle_blk_num=middle_blk_num,
+    #              enc_blk_nums=enc_blks, dec_blk_nums=dec_blks)
+    net = EBlock(c = img_channel,
+                 dilations = dilations,
+                 extra_depth_wise=extra_depth_wise)
+
+    inp_shape = (3, 256, 256)
+
+    from ptflops import get_model_complexity_info
+
+    macs, params = get_model_complexity_info(net, inp_shape, verbose=False, print_per_layer_stat=False)
+    output = net(torch.randn((4, 3, 256, 256)))
+    # print('Values of EBlock:')
+    print(macs, params)
+
+    channels = 128
+    resol = 32
+    ksize = 5
+
+    # net = FAC(channels=channels, ksize=ksize)
+    # inp_shape = (channels, resol, resol)
+    # macs, params = get_model_complexity_info(net, inp_shape, verbose=False, print_per_layer_stat=True)
+    # print('Values of FAC:')
+    # print(macs, params)
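Two pieces of this file are worth seeing in isolation. FreMLP applies its 1x1 convolutions only to the magnitude of the 2D Fourier spectrum and keeps the phase, so the spatial size is preserved by the inverse transform; SimpleGate halves the channel count by multiplying the two channel halves. A small shape-check sketch (tensor sizes are my own choices):

```python
import torch
from archs.arch_model import FreMLP, SimpleGate

x = torch.rand(2, 16, 64, 48)
freq = FreMLP(nc=16, expand=2)
print(freq(x).shape)      # torch.Size([2, 16, 64, 48]): irfft2(s=(H, W)) restores the size

sg = SimpleGate()
print(sg(torch.rand(2, 16, 8, 8)).shape)   # torch.Size([2, 8, 8, 8]): channels halved
```

This channel halving is why DBlock's `conv3`/`conv5` take `dw_channel // 2` and `ffn_channel // 2` inputs: each SimpleGate sits between an expanding 1x1 convolution and the projection back to `c` channels.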
archs/arch_util.py CHANGED
@@ -1,118 +1,65 @@
 import torch
-import
-import
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
+import numpy as np
+from torch import nn as nn
+from torch.nn import init as init
+import torch.distributed as dist
+from collections import OrderedDict
+
+class LayerNormFunction(torch.autograd.Function):
+
+    @staticmethod
+    def forward(ctx, x, weight, bias, eps):
+        ctx.eps = eps
+        N, C, H, W = x.size()
+        mu = x.mean(1, keepdim=True)
+        var = (x - mu).pow(2).mean(1, keepdim=True)
+        y = (x - mu) / (var + eps).sqrt()
+        ctx.save_for_backward(y, var, weight)
+        y = weight.view(1, C, 1, 1) * y + bias.view(1, C, 1, 1)
+        return y
+
+    @staticmethod
+    def backward(ctx, grad_output):
+        eps = ctx.eps
+
+        N, C, H, W = grad_output.size()
+        y, var, weight = ctx.saved_variables
+        g = grad_output * weight.view(1, C, 1, 1)
+        mean_g = g.mean(dim=1, keepdim=True)
+
+        mean_gy = (g * y).mean(dim=1, keepdim=True)
+        gx = 1. / torch.sqrt(var + eps) * (g - y * mean_gy - mean_g)
+        return gx, (grad_output * y).sum(dim=3).sum(dim=2).sum(dim=0), grad_output.sum(dim=3).sum(dim=2).sum(
+            dim=0), None
+
+class LayerNorm2d(nn.Module):
+
+    def __init__(self, channels, eps=1e-6):
+        super(LayerNorm2d, self).__init__()
+        self.register_parameter('weight', nn.Parameter(torch.ones(channels)))
+        self.register_parameter('bias', nn.Parameter(torch.zeros(channels)))
+        self.eps = eps
+
+    def forward(self, x):
+        return LayerNormFunction.apply(x, self.weight, self.bias, self.eps)
+
+
+class CustomSequential(nn.Module):
     '''
-
+    Similar to nn.Sequential, but it lets us introduce a second argument in the forward method
+    so adaptors can be considered in the inference.
     '''
-
-
-
-        #we define the 2 branches
-
-        self.branches = nn.ModuleList()
-        for dilation in dilations:
-            self.branches.append(Branch(c, DW_Expand, dilation = dilation, extra_depth_wise=extra_depth_wise))
-
-        assert len(dilations) == len(self.branches)
-        self.dw_channel = DW_Expand * c
-        self.sca = nn.Sequential(
-            nn.AdaptiveAvgPool2d(1),
-            nn.Conv2d(in_channels=self.dw_channel // 2, out_channels=self.dw_channel // 2, kernel_size=1, padding=0, stride=1,
-                      groups=1, bias=True, dilation = 1),
-        )
-        self.sg1 = SimpleGate()
-        self.sg2 = SimpleGate()
-        self.conv3 = nn.Conv2d(in_channels=self.dw_channel // 2, out_channels=c, kernel_size=1, padding=0, stride=1, groups=1, bias=True, dilation = 1)
-        ffn_channel = FFN_Expand * c
-        self.conv4 = nn.Conv2d(in_channels=c, out_channels=ffn_channel, kernel_size=1, padding=0, stride=1, groups=1, bias=True)
-        self.conv5 = nn.Conv2d(in_channels=ffn_channel // 2, out_channels=c, kernel_size=1, padding=0, stride=1, groups=1, bias=True)
-
-        self.norm1 = LayerNorm2d(c)
-        self.norm2 = LayerNorm2d(c)
-
-        self.gamma = nn.Parameter(torch.zeros((1, c, 1, 1)), requires_grad=True)
-        self.beta = nn.Parameter(torch.zeros((1, c, 1, 1)), requires_grad=True)
-
-    def forward(self, inp):
+    def __init__(self, *args):
+        super(CustomSequential, self).__init__()
+        self.modules_list = nn.ModuleList(args)
 
-
-
-
-
-
-
-        z = self.sg1(z)
-        x = self.sca(z) * z
-        x = self.conv3(x)
-        y = inp + self.beta * x
-        #second step
-        x = self.conv4(self.norm2(y)) # size [B, 2*C, H, W]
-        x = self.sg2(x) # size [B, C, H, W]
-        x = self.conv5(x) # size [B, C, H, W]
+    def forward(self, x, use_adapter=False):
+        for module in self.modules_list:
+            if hasattr(module, 'set_use_adapters'):
+                module.set_use_adapters(use_adapter)
+            x = module(x)
+        return x
 
-        return y + x * self.gamma
-
-#----------------------------------------------------------------------------------------------
 if __name__ == '__main__':
 
-
-    width = 32
-
-    enc_blks = [1, 2, 3]
-    middle_blk_num = 3
-    dec_blks = [3, 1, 1]
-    dilations = [1, 4, 9]
-    extra_depth_wise = False
-
-    # net = NAFNet(img_channel=img_channel, width=width, middle_blk_num=middle_blk_num,
-    #              enc_blk_nums=enc_blks, dec_blk_nums=dec_blks)
-    net = EBlock(c = img_channel,
-                 dilations = dilations,
-                 extra_depth_wise=extra_depth_wise)
-
-    inp_shape = (3, 256, 256)
-
-    from ptflops import get_model_complexity_info
-
-    # macs, params = get_model_complexity_info(net, inp_shape, verbose=False, print_per_layer_stat=True)
-
-    # print('Values of EBlock:')
-    # print(macs, params)
-
-    channels = 128
-    resol = 32
-    ksize = 5
-
-    net = FAC(channels=channels, ksize=ksize)
-    inp_shape = (channels, resol, resol)
-    macs, params = get_model_complexity_info(net, inp_shape, verbose=False, print_per_layer_stat=True)
-    print('Values of FAC:')
-    print(macs, params)
-
+    pass
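LayerNorm2d normalizes each spatial position over the channel dimension with a handwritten backward pass; it should agree with torch's built-in layer_norm applied channel-last with the same affine parameters. A sanity-check sketch (my comparison, not part of the commit):

```python
import torch
import torch.nn.functional as F
from archs.arch_util import LayerNorm2d

ln = LayerNorm2d(channels=8, eps=1e-6)
x = torch.rand(2, 8, 16, 16)

y = ln(x)
# reference: normalize over C with the same weight/bias, using a channel-last layout
ref = F.layer_norm(x.permute(0, 2, 3, 1), (8,), ln.weight, ln.bias, eps=1e-6)
ref = ref.permute(0, 3, 1, 2)
print(torch.allclose(y, ref, atol=1e-5))   # True: both use the biased variance over C
```

The custom autograd Function exists mainly to keep the backward memory footprint small; `CustomSequential` below it only differs from nn.Sequential by threading the `use_adapter` flag into blocks that expose `set_use_adapters` (currently commented out in arch_model.py).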
darkir_v2real+lsrw.pt ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:2e506a3ef1e60d7223a3d9846ebf4851a86f34f02a927fbcbd9f97edef62a5a6
+size 13421785
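The checkpoint is stored through Git LFS, so a clone without LFS support only yields the three-line pointer above. One way to fetch the actual ~13 MB file is huggingface_hub; a sketch where the repo id is a guess at this Space's id, to be adjusted:

```python
from huggingface_hub import hf_hub_download

# repo_id is hypothetical; point it at the Space that hosts this commit
path = hf_hub_download(repo_id="<user>/DarkIR",
                       filename="darkir_v2real+lsrw.pt",
                       repo_type="space")
print(path)  # local cache path of the resolved LFS file
```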
examples/0010.png ADDED
examples/0087.png ADDED
examples/low00733_low.png ADDED
examples/r13073518t_low.png ADDED
requirements.txt CHANGED
@@ -1,108 +1,19 @@
-
-
-annotated-types==0.7.0
-anyio==4.4.0
-attrs==23.2.0
-certifi==2024.6.2
-charset-normalizer==3.3.2
-click==8.1.7
-contourpy==1.2.1
-cycler==0.12.1
-dnspython==2.6.1
-docker-pycreds==0.4.0
-email_validator==2.2.0
-exceptiongroup==1.2.1
-fastapi==0.111.0
-fastapi-cli==0.0.4
-ffmpy==0.3.2
-filelock==3.15.3
-fonttools==4.53.0
-fsspec==2024.6.0
-gitdb==4.0.11
-GitPython==3.1.43
-gradio==4.36.1
-gradio_client==1.0.1
-h11==0.14.0
-httpcore==1.0.5
-httptools==0.6.1
-httpx==0.27.0
-huggingface-hub==0.23.4
-idna==3.7
-importlib_resources==6.4.0
-Jinja2==3.1.4
-jsonschema==4.22.0
-jsonschema-specifications==2023.12.1
-kiwisolver==1.4.5
+einops==0.8.0
+gradio==5.8.0
 kornia==0.7.2
-kornia_rs==0.1.3
 lpips==0.1.4
-markdown-it-py==3.0.0
-MarkupSafe==2.1.5
-matplotlib==3.9.0
-mdurl==0.1.2
-mpmath==1.3.0
-networkx==3.3
 numpy==2.0.0
-nvidia-cublas-cu12==12.1.3.1
-nvidia-cuda-cupti-cu12==12.1.105
-nvidia-cuda-nvrtc-cu12==12.1.105
-nvidia-cuda-runtime-cu12==12.1.105
-nvidia-cudnn-cu12==8.9.2.26
-nvidia-cufft-cu12==11.0.2.54
-nvidia-curand-cu12==10.3.2.106
-nvidia-cusolver-cu12==11.4.5.107
-nvidia-cusparse-cu12==12.1.0.106
-nvidia-nccl-cu12==2.20.5
-nvidia-nvjitlink-cu12==12.5.40
-nvidia-nvtx-cu12==12.1.105
 opencv-python==4.10.0.84
-orjson==3.10.5
-packaging==24.1
 pandas==2.2.2
 pillow==10.3.0
-platformdirs==4.2.2
-protobuf==5.27.1
-psutil==6.0.0
 ptflops==0.7.3
-
-pydantic_core==2.18.4
-pydub==0.25.1
-Pygments==2.18.0
-pyparsing==3.1.2
-python-dateutil==2.9.0.post0
-python-dotenv==1.0.1
-python-multipart==0.0.9
+pyiqa==0.1.13
 pytorch-msssim==1.0.0
-pytz==2024.1
 PyYAML==6.0.1
-
-requests==2.32.3
-rich==13.7.1
-rpds-py==0.18.1
-ruff==0.4.10
+scikit-image==0.24.0
 scipy==1.13.1
-
-
-
-shellingham==1.5.4
-six==1.16.0
-smmap==5.0.1
-sniffio==1.3.1
-starlette==0.37.2
-sympy==1.12.1
-tomlkit==0.12.0
-toolz==0.12.1
-torch==2.3.1
-torchvision==0.18.1
+torch==2.5.1
+torchaudio==2.5.1
+torchvision==0.20.1
 tqdm==4.66.4
-
-typer==0.12.3
-typing_extensions==4.12.2
-tzdata==2024.1
-ujson==5.10.0
-urllib3==2.2.2
-uvicorn==0.30.1
-uvloop==0.19.0
-wandb==0.17.2
-watchfiles==0.22.0
-websockets==11.0.3
+wandb==0.17.2
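The commit also bumps torch 2.3.1 to 2.5.1 and gradio 4.36.1 to 5.8.0 while dropping the explicitly pinned CUDA wheels and transitive dependencies. After installing, a quick runtime check that the core pins resolved as intended (a sketch, versions per the new file):

```python
# verify the pinned core libraries imported at the expected versions
import torch, torchvision, gradio, numpy
print(torch.__version__)        # expected 2.5.1
print(torchvision.__version__)  # expected 0.20.1
print(gradio.__version__)       # expected 5.8.0
print(numpy.__version__)        # expected 2.0.0
```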