danifei committed
Commit 159cb3e · verified · 1 Parent(s): 36db309

Upload 12 files

README.md CHANGED
@@ -1,13 +1,9 @@
- ---
- title: Low Light Deblurring New
- emoji: 📉
  colorFrom: red
  colorTo: gray
  sdk: gradio
- sdk_version: 4.38.1
  app_file: app.py
  pinned: false
- license: mit
- ---
-
- Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference

+ title: DarkIR
+ emoji: 🌻
  colorFrom: red
  colorTo: gray
  sdk: gradio
+ sdk_version: 5.8.0
  app_file: app.py
  pinned: false
+ license: mit

app.py CHANGED
@@ -1,84 +1,89 @@
  import gradio as gr
  from PIL import Image
- import os
  import torch
- import torch.nn.functional as F
  import torchvision.transforms as transforms
- import torchvision
- import numpy as np
- import yaml
- from huggingface_hub import hf_hub_download

- from archs import Network
- from options.options import parse

- path_opt = './options/predict/LOLBlur.yml'

- opt = parse(path_opt)
  device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
  #define some auxiliary functions
  pil_to_tensor = transforms.ToTensor()

- # PATH_MODEL = opt['save']['best']

- model = Network(img_channel=opt['network']['img_channels'],
-                 width=opt['network']['width'],
-                 middle_blk_num_enc=opt['network']['middle_blk_num_enc'],
-                 middle_blk_num_dec=opt['network']['middle_blk_num_dec'],
-                 enc_blk_nums=opt['network']['enc_blk_nums'],
-                 dec_blk_nums=opt['network']['dec_blk_nums'],
-                 dilations=opt['network']['dilations'],
-                 extra_depth_wise = opt['network']['extra_depth_wise'])

- checkpoints = torch.load('Network_noFAC_LOLBlur.pt', map_location=device)
- # print(checkpoints)
- model.load_state_dict(checkpoints['model_state_dict'])

  model = model.to(device)

- def load_img (filename):
-     img = Image.open(filename).convert("RGB")
-     img_tensor = pil_to_tensor(img)
-     return img_tensor

  def process_img(image):
-     img = np.array(image)
-     img = img / 255.
-     img = img.astype(np.float32)
-     y = torch.tensor(img).permute(2,0,1).unsqueeze(0).to(device)

      with torch.no_grad():
-         x_hat = model(y)
-
-     restored_img = x_hat.squeeze().permute(1,2,0).clamp_(0, 1).cpu().detach().numpy()
-     restored_img = np.clip(restored_img, 0. , 1.)

-     restored_img = (restored_img * 255.0).round().astype(np.uint8) # float32 to uint8
-     return Image.fromarray(restored_img) #(image, Image.fromarray(restored_img))

- title = "Low-Light-Deblurring 🌚🌠🌝"
- description = ''' ## [Low Light Image deblurring enhancement](https://github.com/cidautai/Net-Low-light-Deblurring)

  [Daniel Feijoo](https://github.com/danifei)

  Fundación Cidaut

- This model enhances low light images into normal light conditions ones. It was trained using LOLv2-real, LOLv2-synth and LOLBlur.
- Due to the training on LOLBlur, this network is expected to also reconstruct blurred low light images.

  > **Disclaimer:** please remember this is not a product, thus, you will notice some limitations.
- **This demo expects an image with some degradations.**
- Due to the CPU limitations, the model won't return results inmediately <br>.
- Except for the LOLv2-real, the model was trained using mostly synthetic data, thus it might not work great on real-world complex images.

  <br>
  '''

- examples = [['examples/inputs/0010.png'],
-             ['examples/inputs/0060.png'],
-             ['examples/inputs/0075.png'],
-             ["examples/inputs/0087.png"],
-             ["examples/inputs/0088.png"]]

  css = """
  .image-frame img, .image-container img {

  import gradio as gr
  from PIL import Image
  import torch
  import torchvision.transforms as transforms
+ import torch.nn.functional as F
+
+ from archs import DarkIR


  device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
  #define some auxiliary functions
  pil_to_tensor = transforms.ToTensor()
+ tensor_to_pil = transforms.ToPILImage()

+ network = 'DarkIR'

+ PATH_MODEL = './models/darkir_1k_allv2_251205.pt'

+ model = DarkIR(img_channel=3,
+                width=32,
+                middle_blk_num_enc=2,
+                middle_blk_num_dec=2,
+                enc_blk_nums=[1, 2, 3],
+                dec_blk_nums=[3, 1, 1],
+                dilations=[1, 4, 9],
+                extra_depth_wise=True)
+
+ checkpoints = torch.load(PATH_MODEL, map_location=device)
+ model.load_state_dict(checkpoints['params'])

  model = model.to(device)

+ def path_to_tensor(path):
+     img = Image.open(path).convert('RGB')
+     img = pil_to_tensor(img).unsqueeze(0)
+
+     return img
+ def normalize_tensor(tensor):
+
+     max_value = torch.max(tensor)
+     min_value = torch.min(tensor)
+     output = (tensor - min_value)/(max_value)
+     return output
+
+ def pad_tensor(tensor, multiple = 8):
+     '''pad the tensor to be multiple of some number'''
+     multiple = multiple
+     _, _, H, W = tensor.shape
+     pad_h = (multiple - H % multiple) % multiple
+     pad_w = (multiple - W % multiple) % multiple
+     tensor = F.pad(tensor, (0, pad_w, 0, pad_h), value = 0)
+
+     return tensor

  def process_img(image):
+     tensor = path_to_tensor(image).to(device)
+     _, _, H, W = tensor.shape
+
+     tensor = pad_tensor(tensor)

      with torch.no_grad():
+         output = model(tensor, side_loss=False)

+     output = torch.clamp(output, 0., 1.)
+     output = output[:,:, :H, :W].squeeze(0)
+     return tensor_to_pil(output)

+ title = "DarkIR ✏️🖼️ 🤗"
+ description = ''' ## [ DarkIR: Robust Low-Light Image Restoration](https://github.com/cidautai/DarkIR)

  [Daniel Feijoo](https://github.com/danifei)

  Fundación Cidaut

  > **Disclaimer:** please remember this is not a product, thus, you will notice some limitations.
+ **This demo expects an image with some Low-Light degradations.**

  <br>
  '''

+ examples = [['examples/0010.png'],
+             ['examples/r13073518t_low.png'],
+             ['examples/low00733_low.png'],
+             ["examples/0087.png"]]

  css = """
  .image-frame img, .image-container img {
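The new app.py hunk is cut off inside the `css` string, so the Gradio wiring that consumes `process_img`, `title`, `description`, `examples` and `css` is not visible in this diff. A minimal, hypothetical sketch of how such wiring usually looks in a Gradio 5.x Space (not the Space's actual interface code):

```python
import gradio as gr

# Hypothetical wiring; the Space's real layout is not shown in this diff.
demo = gr.Interface(
    fn=process_img,                                     # takes a file path, returns a PIL image
    inputs=gr.Image(type="filepath", label="Low-light input"),
    outputs=gr.Image(type="pil", label="Restored output"),
    examples=examples,
    title=title,
    description=description,
    css=css,
)

if __name__ == "__main__":
    demo.launch()
```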
archs/DarkIR.py ADDED
@@ -0,0 +1,154 @@
+ import torch
+ import torch.nn as nn
+ import torch.nn.functional as F
+ try:
+     from arch_model import EBlock, DBlock
+     from arch_util import CustomSequential
+ except:
+     from archs.arch_model import EBlock, DBlock
+     from .arch_util import CustomSequential
+
+ class DarkIR(nn.Module):
+
+     def __init__(self, img_channel=3,
+                  width=32,
+                  middle_blk_num_enc=2,
+                  middle_blk_num_dec=2,
+                  enc_blk_nums=[1, 2, 3],
+                  dec_blk_nums=[3, 1, 1],
+                  dilations = [1, 4, 9],
+                  extra_depth_wise = True):
+         super(DarkIR, self).__init__()
+
+         self.intro = nn.Conv2d(in_channels=img_channel, out_channels=width, kernel_size=3, padding=1, stride=1, groups=1,
+                                bias=True)
+         self.ending = nn.Conv2d(in_channels=width, out_channels=img_channel, kernel_size=3, padding=1, stride=1, groups=1,
+                                 bias=True)
+
+         self.encoders = nn.ModuleList()
+         self.decoders = nn.ModuleList()
+         self.middle_blks = nn.ModuleList()
+         self.ups = nn.ModuleList()
+         self.downs = nn.ModuleList()
+
+         chan = width
+         for num in enc_blk_nums:
+             self.encoders.append(
+                 CustomSequential(
+                     *[EBlock(chan, extra_depth_wise=extra_depth_wise) for _ in range(num)]
+                 )
+             )
+             self.downs.append(
+                 nn.Conv2d(chan, 2*chan, 2, 2)
+             )
+             chan = chan * 2
+
+         self.middle_blks_enc = \
+             CustomSequential(
+                 *[EBlock(chan, extra_depth_wise=extra_depth_wise) for _ in range(middle_blk_num_enc)]
+             )
+         self.middle_blks_dec = \
+             CustomSequential(
+                 *[DBlock(chan, dilations=dilations, extra_depth_wise=extra_depth_wise) for _ in range(middle_blk_num_dec)]
+             )
+
+         for num in dec_blk_nums:
+             self.ups.append(
+                 nn.Sequential(
+                     nn.Conv2d(chan, chan * 2, 1, bias=False),
+                     nn.PixelShuffle(2)
+                 )
+             )
+             chan = chan // 2
+             self.decoders.append(
+                 CustomSequential(
+                     *[DBlock(chan, dilations=dilations, extra_depth_wise=extra_depth_wise) for _ in range(num)]
+                 )
+             )
+         self.padder_size = 2 ** len(self.encoders)
+
+         # this layer is needed for the computing of the middle loss. It isn't necessary for anything else
+         self.side_out = nn.Conv2d(in_channels = width * 2**len(self.encoders), out_channels = img_channel,
+                                   kernel_size = 3, stride=1, padding=1)
+
+     def forward(self, input, side_loss = False, use_adapter = None):
+
+         _, _, H, W = input.shape
+
+         input = self.check_image_size(input)
+         x = self.intro(input)
+
+         skips = []
+         for encoder, down in zip(self.encoders, self.downs):
+             x = encoder(x)
+             skips.append(x)
+             x = down(x)
+
+         # we apply the encoder transforms
+         x_light = self.middle_blks_enc(x)
+
+         if side_loss:
+             out_side = self.side_out(x_light)
+         # apply the decoder transforms
+         x = self.middle_blks_dec(x_light)
+         x = x + x_light
+
+         for decoder, up, skip in zip(self.decoders, self.ups, skips[::-1]):
+             x = up(x)
+             x = x + skip
+             x = decoder(x)
+
+         x = self.ending(x)
+         x = x + input
+         out = x[:, :, :H, :W] # we recover the original size of the image
+         if side_loss:
+             return out_side, out
+         else:
+             return out
+
+     def check_image_size(self, x):
+         _, _, h, w = x.size()
+         mod_pad_h = (self.padder_size - h % self.padder_size) % self.padder_size
+         mod_pad_w = (self.padder_size - w % self.padder_size) % self.padder_size
+         x = F.pad(x, (0, mod_pad_w, 0, mod_pad_h), value = 0)
+         return x
+
+ if __name__ == '__main__':
+
+     img_channel = 3
+     width = 64
+
+     enc_blks = [1, 2, 3]
+     middle_blk_num_enc = 2
+     middle_blk_num_dec = 2
+     dec_blks = [3, 1, 1]
+     residual_layers = None
+     dilations = [1, 4, 9]
+     extra_depth_wise = True
+
+     net = DarkIR(img_channel=img_channel,
+                  width=width,
+                  middle_blk_num_enc=middle_blk_num_enc,
+                  middle_blk_num_dec= middle_blk_num_dec,
+                  enc_blk_nums=enc_blks,
+                  dec_blk_nums=dec_blks,
+                  dilations = dilations,
+                  extra_depth_wise = extra_depth_wise)
+
+     new_state_dict = net.state_dict()
+
+     inp_shape = (3, 256, 256)
+
+     net.load_state_dict(new_state_dict)
+
+     from ptflops import get_model_complexity_info
+
+     macs, params = get_model_complexity_info(net, inp_shape, verbose=False, print_per_layer_stat=False)
+
+     print(macs, params)
+
+     weights = net.state_dict()
+     adapter_weights = {k: v for k, v in weights.items() if 'adapter' not in k}
+
+
+
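A minimal usage sketch for the module above, assuming this Space's layout (`archs/DarkIR.py` importable as `archs.DarkIR`). It checks that the internal `check_image_size` padding is cropped back, so the output matches an input whose sides are not multiples of 8, and that `side_loss=True` additionally returns the bottleneck-resolution side prediction:

```python
import torch
from archs.DarkIR import DarkIR

# Same hyperparameters as the demo configuration above.
net = DarkIR(img_channel=3, width=32,
             middle_blk_num_enc=2, middle_blk_num_dec=2,
             enc_blk_nums=[1, 2, 3], dec_blk_nums=[3, 1, 1],
             dilations=[1, 4, 9], extra_depth_wise=True).eval()

x = torch.randn(1, 3, 250, 250)          # deliberately not a multiple of 8
with torch.no_grad():
    out = net(x)                          # padded internally, cropped back to the input size
    side, out2 = net(x, side_loss=True)   # extra low-resolution output from the bottleneck

print(out.shape)    # torch.Size([1, 3, 250, 250])
print(side.shape)   # torch.Size([1, 3, 32, 32]), 1/8 of the padded size
```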
archs/__init__.py CHANGED
@@ -1,4 +1,232 @@
- from .nafnet_utils.arch_model import NAFNet
- from .network import Network

- __all__ = ['NAFNet','Network']
+ import torch
+ from torch.optim.lr_scheduler import CosineAnnealingLR
+ from torch.nn.parallel import DistributedDataParallel as DDP
+ from ptflops import get_model_complexity_info

+ from .DarkIR import DarkIR
+
+ def create_model(opt, rank, adapter = False):
+     '''
+     Creates the model.
+     opt: a dictionary from the yaml config key network
+     '''
+     name = opt['name']
+
+
+     model = DarkIR(img_channel=opt['img_channels'],
+                    width=opt['width'],
+                    middle_blk_num_enc=opt['middle_blk_num_enc'],
+                    middle_blk_num_dec=opt['middle_blk_num_dec'],
+                    enc_blk_nums=opt['enc_blk_nums'],
+                    dec_blk_nums=opt['dec_blk_nums'],
+                    dilations=opt['dilations'],
+                    extra_depth_wise=opt['extra_depth_wise'])
+
+     if rank ==0:
+         print(f'Using {name} network')
+
+         input_size = (3, 256, 256)
+         macs, params = get_model_complexity_info(model, input_size, print_per_layer_stat = False)
+         print(f'Computational complexity at {input_size}: {macs}')
+         print('Number of parameters: ', params)
+     else:
+         macs, params = None, None
+
+     model.to(rank)
+
+     model = DDP(model, device_ids=[rank], find_unused_parameters=adapter)
+
+     return model, macs, params
+
+ def create_optim_scheduler(opt, model):
+     '''
+     Returns the optim and its scheduler.
+     opt: a dictionary of the yaml config file with the train key
+     '''
+     optim = torch.optim.AdamW( filter(lambda p: p.requires_grad, model.parameters()) ,
+                                lr = opt['lr_initial'],
+                                weight_decay = opt['weight_decay'],
+                                betas = opt['betas'])
+
+     if opt['lr_scheme'] == 'CosineAnnealing':
+         scheduler = CosineAnnealingLR(optim, T_max=opt['epochs'], eta_min=opt['eta_min'])
+     else:
+         raise NotImplementedError('scheduler not implemented')
+
+     return optim, scheduler
+
+ def load_weights(model, old_weights):
+     '''
+     Loads the weights of a pretrained model, picking only the weights that are
+     in the new model.
+     '''
+     new_weights = model.state_dict()
+     new_weights.update({k: v for k, v in old_weights.items() if k in new_weights})
+
+     model.load_state_dict(new_weights)
+     return model
+
+ def load_optim(optim, optim_weights):
+     '''
+     Loads the values of the optimizer picking only the weights that are in the new model.
+     '''
+     optim_new_weights = optim.state_dict()
+     # optim_new_weights.load_state_dict(optim_weights)
+     optim_new_weights.update({k:v for k, v in optim_weights.items() if k in optim_new_weights})
+     return optim
+
+ def resume_model(model,
+                  optim,
+                  scheduler,
+                  path_model,
+                  rank, resume:str=None):
+     '''
+     Returns the loaded weights of model and optimizer if resume flag is True
+     '''
+     map_location = {'cuda:%d' % 0: 'cuda:%d' % rank}
+     if resume:
+         checkpoints = torch.load(path_model, map_location=map_location, weights_only=False)
+         weights = checkpoints['model_state_dict']
+         model = load_weights(model, old_weights=weights)
+         optim = load_optim(optim, optim_weights = checkpoints['optimizer_state_dict'])
+         scheduler.load_state_dict(checkpoints['scheduler_state_dict'])
+         start_epochs = checkpoints['epoch']
+
+         if rank == 0: print('Loaded weights')
+     else:
+         start_epochs = 0
+         if rank==0: print('Starting from zero the training')
+
+     return model, optim, scheduler, start_epochs
+
+ def find_different_keys(dict1, dict2):
+
+     # Finding different keys
+     different_keys = set(dict1.keys()) ^ set(dict2.keys())
+
+     return different_keys
+
+ def number_common_keys(dict1, dict2):
+     # Finding common keys
+     common_keys = set(dict1.keys()) & set(dict2.keys())
+
+     # Counting the number of common keys
+     common_keys_count = len(common_keys)
+     return common_keys_count
+
+ # # Function to add 'modules_list' prefix after the first numeric index
+ # def add_middle_prefix(state_dict, middle_prefix, target_strings):
+ #     new_state_dict = {}
+ #     for key, value in state_dict.items():
+ #         for target in target_strings:
+ #             if target in key:
+ #                 parts = key.split('.')
+ #                 # Find the first numeric index after the target string
+ #                 for i, part in enumerate(parts):
+ #                     if part == target:
+ #                         # Insert the middle prefix after the first numeric index
+ #                         if i + 1 < len(parts) and parts[i + 1].isdigit():
+ #                             parts.insert(i + 2, middle_prefix)
+ #                         break
+ #                 new_key = '.'.join(parts)
+ #                 new_state_dict[new_key] = value
+ #                 break
+ #         else:
+ #             new_state_dict[key] = value
+ #     return new_state_dict
+
+ # # Function to adjust keys for 'middle_blks.' prefix
+ # def adjust_middle_blks_keys(state_dict, target_prefix, middle_prefix):
+ #     new_state_dict = {}
+ #     for key, value in state_dict.items():
+ #         if target_prefix in key:
+ #             parts = key.split('.')
+ #             # Find the target prefix and adjust the key
+ #             for i, part in enumerate(parts):
+ #                 if part == target_prefix.rstrip('.'):
+ #                     if i + 1 < len(parts) and parts[i + 1].isdigit():
+ #                         # Swap the numerical part and the middle prefix
+ #                         new_key = '.'.join(parts[:i + 1] + [middle_prefix] + parts[i + 1:i + 2] + parts[i + 2:])
+ #                         new_state_dict[new_key] = value
+ #                     break
+ #         else:
+ #             new_state_dict[key] = value
+ #     return new_state_dict
+
+ # def resume_nafnet(model,
+ #                   optim,
+ #                   scheduler,
+ #                   path_adapter,
+ #                   path_model,
+ #                   rank, resume:str=None):
+ #     '''
+ #     Returns the loaded weights of model and optimizer if resume flag is True
+ #     '''
+ #     map_location = {'cuda:%d' % 0: 'cuda:%d' % rank}
+ #     #first load the model weights
+ #     checkpoints = torch.load(path_model, map_location=map_location, weights_only=False)
+ #     weights = checkpoints
+ #     if rank==0:
+ #         print(len(weights), len(model.state_dict().keys()))
+
+ #     different_keys = find_different_keys(weights, model.state_dict())
+ #     filtered_keys = {item for item in different_keys if 'adapter' not in item}
+ #     print(filtered_keys)
+ #     print(len(filtered_keys))
+ #     model = load_weights(model, old_weights=weights)
+ #     #now if needed load the adapter weights
+ #     if resume:
+ #         checkpoints = torch.load(path_adapter, map_location=map_location, weights_only=False)
+ #         weights = checkpoints
+ #         model = load_weights(model, old_weights=weights)
+ #         # optim = load_optim(optim, optim_weights = checkpoints['optimizer_state_dict'])
+ #         scheduler.load_state_dict(checkpoints['scheduler_state_dict'])
+ #         start_epochs = checkpoints['epoch']
+
+ #         if rank == 0: print('Loaded weights')
+ #     else:
+ #         start_epochs = 0
+ #         if rank == 0: print('Starting from zero the training')
+
+ #     return model, optim, scheduler, start_epochs
+
+ def save_checkpoint(model, optim, scheduler, metrics_eval, metrics_train, paths, adapter = False, rank = None):
+
+     '''
+     Save the .pt of the model after each epoch.
+     '''
+     best_psnr = metrics_train['best_psnr']
+     if rank!=0:
+         return best_psnr
+
+     if type(next(iter(metrics_eval.values()))) != dict:
+         metrics_eval = {'metrics': metrics_eval}
+
+     weights = model.state_dict()
+
+     # Save the model after every epoch
+     model_to_save = {
+         'epoch': metrics_train['epoch'],
+         'model_state_dict': weights,
+         'optimizer_state_dict': optim.state_dict(),
+         'loss': metrics_train['train_loss'],
+         'scheduler_state_dict': scheduler.state_dict()
+     }
+
+     try:
+         torch.save(model_to_save, paths['new'])
+
+         # Save best model if new valid_psnr is higher than the best one
+         if next(iter(metrics_eval.values()))['valid_psnr'] >= metrics_train['best_psnr']:
+             torch.save(model_to_save, paths['best'])
+             metrics_train['best_psnr'] = next(iter(metrics_eval.values()))['valid_psnr'] # update best psnr
+     except Exception as e:
+         print(f"Error saving model: {e}")
+     return metrics_train['best_psnr']
+
+ __all__ = ['create_model', 'resume_model', 'create_optim_scheduler', 'save_checkpoint',
+            'load_optim', 'load_weights']
+
+
+
+
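`create_model` and `resume_model` assume a DistributedDataParallel setup (an initialized process group and a `rank`), so for single-process inference only `load_weights` is needed. A hedged sketch, assuming the checkpoint stores its weights either under a `'params'` key (as app.py expects for its own model file) or as a bare state dict:

```python
import torch
from archs import load_weights
from archs.DarkIR import DarkIR

model = DarkIR()   # defaults mirror the demo configuration

ckpt = torch.load('darkir_v2real+lsrw.pt', map_location='cpu', weights_only=False)
# Assumed key layout: fall back to treating the file as a plain state dict.
state = ckpt['params'] if isinstance(ckpt, dict) and 'params' in ckpt else ckpt

model = load_weights(model, old_weights=state)   # keeps only keys present in the new model
model.eval()
```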
archs/arch_model.py ADDED
@@ -0,0 +1,240 @@
+ import torch
+ import torch.nn as nn
+ import torch.nn.init as init
+ import torch.nn.functional as F
+
+ try:
+     from .arch_util import LayerNorm2d
+ except:
+     from arch_util import LayerNorm2d
+
+
+ class SimpleGate(nn.Module):
+     def forward(self, x):
+         x1, x2 = x.chunk(2, dim=1)
+         return x1 * x2
+
+ class Adapter(nn.Module):
+
+     def __init__(self, c, ffn_channel = None):
+         super().__init__()
+         if ffn_channel:
+             ffn_channel = 2
+         else:
+             ffn_channel = c
+         self.conv1 = nn.Conv2d(in_channels=c, out_channels=ffn_channel, kernel_size=1, padding=0, stride=1, groups=1, bias=True)
+         self.conv2 = nn.Conv2d(in_channels=ffn_channel, out_channels=c, kernel_size=1, padding=0, stride=1, groups=1, bias=True)
+         self.depthwise = nn.Conv2d(in_channels=c, out_channels=ffn_channel, kernel_size=3, padding=1, stride=1, groups=c, bias=True, dilation=1)
+
+     def forward(self, input):
+
+         x = self.conv1(input) + self.depthwise(input)
+         x = self.conv2(x)
+
+         return x
+
+ class FreMLP(nn.Module):
+
+     def __init__(self, nc, expand = 2):
+         super(FreMLP, self).__init__()
+         self.process1 = nn.Sequential(
+             nn.Conv2d(nc, expand * nc, 1, 1, 0),
+             nn.LeakyReLU(0.1, inplace=True),
+             nn.Conv2d(expand * nc, nc, 1, 1, 0))
+
+     def forward(self, x):
+         _, _, H, W = x.shape
+         x_freq = torch.fft.rfft2(x, norm='backward')
+         mag = torch.abs(x_freq)
+         pha = torch.angle(x_freq)
+         mag = self.process1(mag)
+         real = mag * torch.cos(pha)
+         imag = mag * torch.sin(pha)
+         x_out = torch.complex(real, imag)
+         x_out = torch.fft.irfft2(x_out, s=(H, W), norm='backward')
+         return x_out
+
+ class Branch(nn.Module):
+     '''
+     Branch that lasts lonly the dilated convolutions
+     '''
+     def __init__(self, c, DW_Expand, dilation = 1):
+         super().__init__()
+         self.dw_channel = DW_Expand * c
+
+         self.branch = nn.Sequential(
+             nn.Conv2d(in_channels=self.dw_channel, out_channels=self.dw_channel, kernel_size=3, padding=dilation, stride=1, groups=self.dw_channel,
+                       bias=True, dilation = dilation) # the dconv
+         )
+     def forward(self, input):
+         return self.branch(input)
+
+ class DBlock(nn.Module):
+     '''
+     Change this block using Branch
+     '''
+
+     def __init__(self, c, DW_Expand=2, FFN_Expand=2, dilations = [1], extra_depth_wise = False):
+         super().__init__()
+         #we define the 2 branches
+         self.dw_channel = DW_Expand * c
+
+         self.conv1 = nn.Conv2d(in_channels=c, out_channels=self.dw_channel, kernel_size=1, padding=0, stride=1, groups=1, bias=True, dilation = 1)
+         self.extra_conv = nn.Conv2d(self.dw_channel, self.dw_channel, kernel_size=3, padding=1, stride=1, groups=c, bias=True, dilation=1) if extra_depth_wise else nn.Identity() #optional extra dw
+         self.branches = nn.ModuleList()
+         for dilation in dilations:
+             self.branches.append(Branch(self.dw_channel, DW_Expand = 1, dilation = dilation))
+
+         assert len(dilations) == len(self.branches)
+         self.dw_channel = DW_Expand * c
+         self.sca = nn.Sequential(
+             nn.AdaptiveAvgPool2d(1),
+             nn.Conv2d(in_channels=self.dw_channel // 2, out_channels=self.dw_channel // 2, kernel_size=1, padding=0, stride=1,
+                       groups=1, bias=True, dilation = 1),
+         )
+         self.sg1 = SimpleGate()
+         self.sg2 = SimpleGate()
+         self.conv3 = nn.Conv2d(in_channels=self.dw_channel // 2, out_channels=c, kernel_size=1, padding=0, stride=1, groups=1, bias=True, dilation = 1)
+         ffn_channel = FFN_Expand * c
+         self.conv4 = nn.Conv2d(in_channels=c, out_channels=ffn_channel, kernel_size=1, padding=0, stride=1, groups=1, bias=True)
+         self.conv5 = nn.Conv2d(in_channels=ffn_channel // 2, out_channels=c, kernel_size=1, padding=0, stride=1, groups=1, bias=True)
+
+         self.norm1 = LayerNorm2d(c)
+         self.norm2 = LayerNorm2d(c)
+
+         self.gamma = nn.Parameter(torch.zeros((1, c, 1, 1)), requires_grad=True)
+         self.beta = nn.Parameter(torch.zeros((1, c, 1, 1)), requires_grad=True)
+
+
+         # self.adapter = Adapter(c, ffn_channel=None)
+
+         # self.use_adapters = False
+
+     # def set_use_adapters(self, use_adapters):
+     #     self.use_adapters = use_adapters
+
+     def forward(self, inp, adapter = None):
+
+         y = inp
+         x = self.norm1(inp)
+         # x = self.conv1(self.extra_conv(x))
+         x = self.extra_conv(self.conv1(x))
+         z = 0
+         for branch in self.branches:
+             z += branch(x)
+
+         z = self.sg1(z)
+         x = self.sca(z) * z
+         x = self.conv3(x)
+         y = inp + self.beta * x
+         #second step
+         x = self.conv4(self.norm2(y)) # size [B, 2*C, H, W]
+         x = self.sg2(x) # size [B, C, H, W]
+         x = self.conv5(x) # size [B, C, H, W]
+         x = y + x * self.gamma
+
+         # if self.use_adapters:
+         #     return self.adapter(x)
+         # else:
+         return x
+
+ class EBlock(nn.Module):
+     '''
+     Change this block using Branch
+     '''
+
+     def __init__(self, c, DW_Expand=2, dilations = [1], extra_depth_wise = False):
+         super().__init__()
+         #we define the 2 branches
+         self.dw_channel = DW_Expand * c
+         self.extra_conv = nn.Conv2d(c, c, kernel_size=3, padding=1, stride=1, groups=c, bias=True, dilation=1) if extra_depth_wise else nn.Identity() #optional extra dw
+         self.conv1 = nn.Conv2d(in_channels=c, out_channels=self.dw_channel, kernel_size=1, padding=0, stride=1, groups=1, bias=True, dilation = 1)
+
+         self.branches = nn.ModuleList()
+         for dilation in dilations:
+             self.branches.append(Branch(c, DW_Expand, dilation = dilation))
+
+         assert len(dilations) == len(self.branches)
+         self.dw_channel = DW_Expand * c
+         self.sca = nn.Sequential(
+             nn.AdaptiveAvgPool2d(1),
+             nn.Conv2d(in_channels=self.dw_channel // 2, out_channels=self.dw_channel // 2, kernel_size=1, padding=0, stride=1,
+                       groups=1, bias=True, dilation = 1),
+         )
+         self.sg1 = SimpleGate()
+         self.conv3 = nn.Conv2d(in_channels=self.dw_channel // 2, out_channels=c, kernel_size=1, padding=0, stride=1, groups=1, bias=True, dilation = 1)
+         # second step
+
+         self.norm1 = LayerNorm2d(c)
+         self.norm2 = LayerNorm2d(c)
+         self.freq = FreMLP(nc = c, expand=2)
+         self.gamma = nn.Parameter(torch.zeros((1, c, 1, 1)), requires_grad=True)
+         self.beta = nn.Parameter(torch.zeros((1, c, 1, 1)), requires_grad=True)
+
+
+         # self.adapter = Adapter(c, ffn_channel=None)
+
+         # self.use_adapters = False
+
+     # def set_use_adapters(self, use_adapters):
+     #     self.use_adapters = use_adapters
+
+     def forward(self, inp):
+         y = inp
+         x = self.norm1(inp)
+         x = self.conv1(self.extra_conv(x))
+         z = 0
+         for branch in self.branches:
+             z += branch(x)
+
+         z = self.sg1(z)
+         x = self.sca(z) * z
+         x = self.conv3(x)
+         y = inp + self.beta * x
+         #second step
+         x_step2 = self.norm2(y) # size [B, 2*C, H, W]
+         x_freq = self.freq(x_step2) # size [B, C, H, W]
+         x = y * x_freq
+         x = y + x * self.gamma
+
+         # if self.use_adapters:
+         #     return self.adapter(x)
+         # else:
+         return x
+
+ #----------------------------------------------------------------------------------------------
+ if __name__ == '__main__':
+
+     img_channel = 3
+     width = 32
+
+     enc_blks = [1, 2, 3]
+     middle_blk_num = 3
+     dec_blks = [3, 1, 1]
+     dilations = [1, 4, 9]
+     extra_depth_wise = True
+
+     # net = NAFNet(img_channel=img_channel, width=width, middle_blk_num=middle_blk_num,
+     #              enc_blk_nums=enc_blks, dec_blk_nums=dec_blks)
+     net = EBlock(c = img_channel,
+                  dilations = dilations,
+                  extra_depth_wise=extra_depth_wise)
+
+     inp_shape = (3, 256, 256)
+
+     from ptflops import get_model_complexity_info
+
+     macs, params = get_model_complexity_info(net, inp_shape, verbose=False, print_per_layer_stat=False)
+     output = net(torch.randn((4, 3, 256, 256)))
+     # print('Values of EBlock:')
+     print(macs, params)
+
+     channels = 128
+     resol = 32
+     ksize = 5
+
+     # net = FAC(channels=channels, ksize=ksize)
+     # inp_shape = (channels, resol, resol)
+     # macs, params = get_model_complexity_info(net, inp_shape, verbose=False, print_per_layer_stat=True)
+     # print('Values of FAC:')
+     # print(macs, params)
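Both blocks are residual and shape-preserving: `EBlock` refines features with dilated-convolution branches plus a FreMLP frequency stage, while `DBlock` uses the same branches followed by a gated convolutional FFN. A small sketch, assuming the repository layout of this Space:

```python
import torch
from archs.arch_model import EBlock, DBlock

x = torch.randn(1, 32, 64, 64)

enc = EBlock(32, dilations=[1, 4, 9], extra_depth_wise=True)   # second stage uses FreMLP (frequency domain)
dec = DBlock(32, dilations=[1, 4, 9], extra_depth_wise=True)   # second stage uses a SimpleGate conv FFN

with torch.no_grad():
    print(enc(x).shape)   # torch.Size([1, 32, 64, 64]) -- residual, shape preserved
    print(dec(x).shape)   # torch.Size([1, 32, 64, 64])
```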
archs/arch_util.py CHANGED
@@ -1,118 +1,65 @@
  import torch
- import torch.nn as nn
- import torch.nn.init as init
- import torch.nn.functional as F
-
- try:
-     from .nafnet_utils.arch_util import LayerNorm2d
-     from .nafnet_utils.arch_model import SimpleGate
- except:
-     from nafnet_utils.arch_util import LayerNorm2d
-     from nafnet_utils.arch_model import SimpleGate
-
-
- class Branch(nn.Module):
-     '''
-     Branch that lasts lonly the dilated convolutions
-     '''
-     def __init__(self, c, DW_Expand, dilation = 1, extra_depth_wise = False):
-         super().__init__()
-         self.dw_channel = DW_Expand * c
-         self.branch = nn.Sequential(
-             nn.Conv2d(c, c, kernel_size=3, padding=1, stride=1, groups=c, bias=True, dilation=1) if extra_depth_wise else nn.Identity(), #optional extra dw
-             nn.Conv2d(in_channels=c, out_channels=self.dw_channel, kernel_size=1, padding=0, stride=1, groups=1, bias=True, dilation = 1),
-             nn.Conv2d(in_channels=self.dw_channel, out_channels=self.dw_channel, kernel_size=3, padding=dilation, stride=1, groups=self.dw_channel,
-                       bias=True, dilation = dilation) # the dconv
-         )
-     def forward(self, input):
-         return self.branch(input)
-
- class EBlock(nn.Module):
      '''
-     Change this block using Branch
      '''
-
-     def __init__(self, c, DW_Expand=2, FFN_Expand=2, dilations = [1], extra_depth_wise = False):
-         super().__init__()
-         #we define the 2 branches
-
-         self.branches = nn.ModuleList()
-         for dilation in dilations:
-             self.branches.append(Branch(c, DW_Expand, dilation = dilation, extra_depth_wise=extra_depth_wise))
-
-         assert len(dilations) == len(self.branches)
-         self.dw_channel = DW_Expand * c
-         self.sca = nn.Sequential(
-             nn.AdaptiveAvgPool2d(1),
-             nn.Conv2d(in_channels=self.dw_channel // 2, out_channels=self.dw_channel // 2, kernel_size=1, padding=0, stride=1,
-                       groups=1, bias=True, dilation = 1),
-         )
-         self.sg1 = SimpleGate()
-         self.sg2 = SimpleGate()
-         self.conv3 = nn.Conv2d(in_channels=self.dw_channel // 2, out_channels=c, kernel_size=1, padding=0, stride=1, groups=1, bias=True, dilation = 1)
-         ffn_channel = FFN_Expand * c
-         self.conv4 = nn.Conv2d(in_channels=c, out_channels=ffn_channel, kernel_size=1, padding=0, stride=1, groups=1, bias=True)
-         self.conv5 = nn.Conv2d(in_channels=ffn_channel // 2, out_channels=c, kernel_size=1, padding=0, stride=1, groups=1, bias=True)
-
-         self.norm1 = LayerNorm2d(c)
-         self.norm2 = LayerNorm2d(c)
-
-         self.gamma = nn.Parameter(torch.zeros((1, c, 1, 1)), requires_grad=True)
-         self.beta = nn.Parameter(torch.zeros((1, c, 1, 1)), requires_grad=True)
-
-     def forward(self, inp):

-         y = inp
-         x = self.norm1(inp)
-         z = 0
-         for branch in self.branches:
-             z += branch(x)
-
-         z = self.sg1(z)
-         x = self.sca(z) * z
-         x = self.conv3(x)
-         y = inp + self.beta * x
-         #second step
-         x = self.conv4(self.norm2(y)) # size [B, 2*C, H, W]
-         x = self.sg2(x) # size [B, C, H, W]
-         x = self.conv5(x) # size [B, C, H, W]

-         return y + x * self.gamma
-
- #----------------------------------------------------------------------------------------------
  if __name__ == '__main__':

-     img_channel = 3
-     width = 32
-
-     enc_blks = [1, 2, 3]
-     middle_blk_num = 3
-     dec_blks = [3, 1, 1]
-     dilations = [1, 4, 9]
-     extra_depth_wise = False
-
-     # net = NAFNet(img_channel=img_channel, width=width, middle_blk_num=middle_blk_num,
-     #              enc_blk_nums=enc_blks, dec_blk_nums=dec_blks)
-     net = EBlock(c = img_channel,
-                  dilations = dilations,
-                  extra_depth_wise=extra_depth_wise)
-
-     inp_shape = (3, 256, 256)
-
-     from ptflops import get_model_complexity_info
-
-     # macs, params = get_model_complexity_info(net, inp_shape, verbose=False, print_per_layer_stat=True)
-
-     # print('Values of EBlock:')
-     # print(macs, params)
-
-     channels = 128
-     resol = 32
-     ksize = 5
-
-     net = FAC(channels=channels, ksize=ksize)
-     inp_shape = (channels, resol, resol)
-     macs, params = get_model_complexity_info(net, inp_shape, verbose=False, print_per_layer_stat=True)
-     print('Values of FAC:')
-     print(macs, params)
-

  import torch
+ import numpy as np
+ from torch import nn as nn
+ from torch.nn import init as init
+ import torch.distributed as dist
+ from collections import OrderedDict
+
+ class LayerNormFunction(torch.autograd.Function):
+
+     @staticmethod
+     def forward(ctx, x, weight, bias, eps):
+         ctx.eps = eps
+         N, C, H, W = x.size()
+         mu = x.mean(1, keepdim=True)
+         var = (x - mu).pow(2).mean(1, keepdim=True)
+         y = (x - mu) / (var + eps).sqrt()
+         ctx.save_for_backward(y, var, weight)
+         y = weight.view(1, C, 1, 1) * y + bias.view(1, C, 1, 1)
+         return y
+
+     @staticmethod
+     def backward(ctx, grad_output):
+         eps = ctx.eps
+
+         N, C, H, W = grad_output.size()
+         y, var, weight = ctx.saved_variables
+         g = grad_output * weight.view(1, C, 1, 1)
+         mean_g = g.mean(dim=1, keepdim=True)
+
+         mean_gy = (g * y).mean(dim=1, keepdim=True)
+         gx = 1. / torch.sqrt(var + eps) * (g - y * mean_gy - mean_g)
+         return gx, (grad_output * y).sum(dim=3).sum(dim=2).sum(dim=0), grad_output.sum(dim=3).sum(dim=2).sum(
+             dim=0), None
+
+ class LayerNorm2d(nn.Module):
+
+     def __init__(self, channels, eps=1e-6):
+         super(LayerNorm2d, self).__init__()
+         self.register_parameter('weight', nn.Parameter(torch.ones(channels)))
+         self.register_parameter('bias', nn.Parameter(torch.zeros(channels)))
+         self.eps = eps
+
+     def forward(self, x):
+         return LayerNormFunction.apply(x, self.weight, self.bias, self.eps)
+
+
+ class CustomSequential(nn.Module):
      '''
+     Similar to nn.Sequential, but it lets us introduce a second argument in the forward method
+     so adaptors can be considered in the inference.
      '''
+     def __init__(self, *args):
+         super(CustomSequential, self).__init__()
+         self.modules_list = nn.ModuleList(args)

+     def forward(self, x, use_adapter=False):
+         for module in self.modules_list:
+             if hasattr(module, 'set_use_adapters'):
+                 module.set_use_adapters(use_adapter)
+             x = module(x)
+         return x

  if __name__ == '__main__':

+     pass
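`LayerNorm2d` is a channel-wise LayerNorm over NCHW tensors with a hand-written backward, and `CustomSequential` behaves like `nn.Sequential` but forwards a `use_adapter` flag to blocks that expose `set_use_adapters`. A short sketch, assuming this Space's layout:

```python
import torch
from archs.arch_util import LayerNorm2d, CustomSequential

norm = LayerNorm2d(16)
x = torch.randn(2, 16, 32, 32, requires_grad=True)
y = norm(x)                  # normalized over the channel dimension
y.mean().backward()          # gradient comes from the custom autograd Function
print(y.shape, x.grad.shape)

# The flag is simply ignored by modules without set_use_adapters (none are active here).
seq = CustomSequential(norm, torch.nn.Conv2d(16, 16, 3, padding=1))
print(seq(x, use_adapter=False).shape)   # torch.Size([2, 16, 32, 32])
```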
darkir_v2real+lsrw.pt ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:2e506a3ef1e60d7223a3d9846ebf4851a86f34f02a927fbcbd9f97edef62a5a6
+ size 13421785
examples/0010.png ADDED
examples/0087.png ADDED
examples/low00733_low.png ADDED
examples/r13073518t_low.png ADDED
requirements.txt CHANGED
@@ -1,108 +1,19 @@
- aiofiles==23.2.1
- altair==5.3.0
- annotated-types==0.7.0
- anyio==4.4.0
- attrs==23.2.0
- certifi==2024.6.2
- charset-normalizer==3.3.2
- click==8.1.7
- contourpy==1.2.1
- cycler==0.12.1
- dnspython==2.6.1
- docker-pycreds==0.4.0
- email_validator==2.2.0
- exceptiongroup==1.2.1
- fastapi==0.111.0
- fastapi-cli==0.0.4
- ffmpy==0.3.2
- filelock==3.15.3
- fonttools==4.53.0
- fsspec==2024.6.0
- gitdb==4.0.11
- GitPython==3.1.43
- gradio==4.36.1
- gradio_client==1.0.1
- h11==0.14.0
- httpcore==1.0.5
- httptools==0.6.1
- httpx==0.27.0
- huggingface-hub==0.23.4
- idna==3.7
- importlib_resources==6.4.0
- Jinja2==3.1.4
- jsonschema==4.22.0
- jsonschema-specifications==2023.12.1
- kiwisolver==1.4.5
  kornia==0.7.2
- kornia_rs==0.1.3
  lpips==0.1.4
- markdown-it-py==3.0.0
- MarkupSafe==2.1.5
- matplotlib==3.9.0
- mdurl==0.1.2
- mpmath==1.3.0
- networkx==3.3
  numpy==2.0.0
- nvidia-cublas-cu12==12.1.3.1
- nvidia-cuda-cupti-cu12==12.1.105
- nvidia-cuda-nvrtc-cu12==12.1.105
- nvidia-cuda-runtime-cu12==12.1.105
- nvidia-cudnn-cu12==8.9.2.26
- nvidia-cufft-cu12==11.0.2.54
- nvidia-curand-cu12==10.3.2.106
- nvidia-cusolver-cu12==11.4.5.107
- nvidia-cusparse-cu12==12.1.0.106
- nvidia-nccl-cu12==2.20.5
- nvidia-nvjitlink-cu12==12.5.40
- nvidia-nvtx-cu12==12.1.105
  opencv-python==4.10.0.84
- orjson==3.10.5
- packaging==24.1
  pandas==2.2.2
  pillow==10.3.0
- platformdirs==4.2.2
- protobuf==5.27.1
- psutil==6.0.0
  ptflops==0.7.3
- pydantic==2.7.4
- pydantic_core==2.18.4
- pydub==0.25.1
- Pygments==2.18.0
- pyparsing==3.1.2
- python-dateutil==2.9.0.post0
- python-dotenv==1.0.1
- python-multipart==0.0.9
  pytorch-msssim==1.0.0
- pytz==2024.1
  PyYAML==6.0.1
- referencing==0.35.1
- requests==2.32.3
- rich==13.7.1
- rpds-py==0.18.1
- ruff==0.4.10
  scipy==1.13.1
- semantic-version==2.10.0
- sentry-sdk==2.6.0
- setproctitle==1.3.3
- shellingham==1.5.4
- six==1.16.0
- smmap==5.0.1
- sniffio==1.3.1
- starlette==0.37.2
- sympy==1.12.1
- tomlkit==0.12.0
- toolz==0.12.1
- torch==2.3.1
- torchvision==0.18.1
  tqdm==4.66.4
- triton==2.3.1
- typer==0.12.3
- typing_extensions==4.12.2
- tzdata==2024.1
- ujson==5.10.0
- urllib3==2.2.2
- uvicorn==0.30.1
- uvloop==0.19.0
- wandb==0.17.2
- watchfiles==0.22.0
- websockets==11.0.3

+ einops==0.8.0
+ gradio==5.8.0
  kornia==0.7.2
  lpips==0.1.4
  numpy==2.0.0
  opencv-python==4.10.0.84
  pandas==2.2.2
  pillow==10.3.0
  ptflops==0.7.3
+ pyiqa==0.1.13
  pytorch-msssim==1.0.0
  PyYAML==6.0.1
+ scikit-image==0.24.0
  scipy==1.13.1
+ torch==2.5.1
+ torchaudio==2.5.1
+ torchvision==0.20.1
  tqdm==4.66.4
+ wandb==0.17.2