jason4000 committed
Commit 358321e · 1 Parent(s): 3417f0e

Upload folder using huggingface_hub

This view is limited to 50 files because the commit contains too many changes; see the raw diff for the complete list.
Files changed (50)
  1. .gitattributes +3 -0
  2. .gitignore +2 -0
  3. Myloss.py +157 -0
  4. README.md +2 -8
  5. __pycache__/dataloader.cpython-311.pyc +0 -0
  6. __pycache__/model.cpython-311.pyc +0 -0
  7. app.py +81 -0
  8. data/test_data/DICM/01.jpg +0 -0
  9. data/test_data/DICM/02.jpg +0 -0
  10. data/test_data/DICM/03.jpg +0 -0
  11. data/test_data/DICM/04.jpg +0 -0
  12. data/test_data/DICM/05.jpg +0 -0
  13. data/test_data/DICM/06.jpg +0 -0
  14. data/test_data/DICM/07.jpg +0 -0
  15. data/test_data/DICM/08.jpg +0 -0
  16. data/test_data/DICM/09.jpg +0 -0
  17. data/test_data/DICM/10.jpg +0 -0
  18. data/test_data/DICM/11.jpg +0 -0
  19. data/test_data/DICM/12.jpg +0 -0
  20. data/test_data/DICM/13.jpg +0 -0
  21. data/test_data/DICM/14.jpg +0 -0
  22. data/test_data/DICM/15.jpg +0 -0
  23. data/test_data/DICM/16.jpg +0 -0
  24. data/test_data/DICM/17.jpg +0 -0
  25. data/test_data/DICM/18.jpg +0 -0
  26. data/test_data/DICM/19.jpg +0 -0
  27. data/test_data/DICM/20.jpg +0 -0
  28. data/test_data/DICM/21.jpg +0 -0
  29. data/test_data/DICM/22.jpg +0 -0
  30. data/test_data/DICM/25.jpg +0 -0
  31. data/test_data/DICM/26.jpg +0 -0
  32. data/test_data/DICM/27.jpg +0 -0
  33. data/test_data/DICM/28.jpg +0 -0
  34. data/test_data/DICM/29.jpg +0 -0
  35. data/test_data/DICM/30.jpg +0 -0
  36. data/test_data/DICM/31.jpg +0 -0
  37. data/test_data/DICM/32.jpg +0 -0
  38. data/test_data/DICM/33.jpg +0 -0
  39. data/test_data/DICM/34.jpg +0 -0
  40. data/test_data/DICM/35.jpg +0 -0
  41. data/test_data/DICM/36.jpg +0 -0
  42. data/test_data/DICM/37.jpg +0 -0
  43. data/test_data/DICM/38.jpg +0 -0
  44. data/test_data/DICM/39.jpg +0 -0
  45. data/test_data/DICM/40.jpg +0 -0
  46. data/test_data/DICM/41.jpg +0 -0
  47. data/test_data/DICM/42.jpg +0 -0
  48. data/test_data/DICM/43.jpg +0 -0
  49. data/test_data/DICM/44.jpg +0 -0
  50. data/test_data/DICM/45.jpg +0 -0
.gitattributes CHANGED
@@ -33,3 +33,6 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
  *.zip filter=lfs diff=lfs merge=lfs -text
  *.zst filter=lfs diff=lfs merge=lfs -text
  *tfevents* filter=lfs diff=lfs merge=lfs -text
+ data/test_data/LIME/1.bmp filter=lfs diff=lfs merge=lfs -text
+ data/test_data/LIME/10.bmp filter=lfs diff=lfs merge=lfs -text
+ data/test_data/LIME/5.bmp filter=lfs diff=lfs merge=lfs -text
.gitignore ADDED
@@ -0,0 +1,2 @@
+ __pycache__/
+ data/
Myloss.py ADDED
@@ -0,0 +1,157 @@
+ import torch
+ import torch.nn as nn
+ import torch.nn.functional as F
+ import math
+ from torchvision.models.vgg import vgg16
+ import numpy as np
+
+ # NOTE: several modules below create their constants with .cuda(), so they assume a GPU.
+
+ class L_color(nn.Module):
+     # Color constancy loss: penalises divergence between the mean R, G and B channels.
+
+     def __init__(self):
+         super(L_color, self).__init__()
+
+     def forward(self, x):
+         b, c, h, w = x.shape
+         mean_rgb = torch.mean(x, [2, 3], keepdim=True)
+         mr, mg, mb = torch.split(mean_rgb, 1, dim=1)
+         Drg = torch.pow(mr - mg, 2)
+         Drb = torch.pow(mr - mb, 2)
+         Dgb = torch.pow(mb - mg, 2)
+         k = torch.pow(torch.pow(Drg, 2) + torch.pow(Drb, 2) + torch.pow(Dgb, 2), 0.5)
+         return k
+
+
+ class L_spa(nn.Module):
+     # Spatial consistency loss: compares local gradients (4 directions) of the
+     # original and the enhanced image on 4x4-pooled luminance.
+
+     def __init__(self):
+         super(L_spa, self).__init__()
+         kernel_left = torch.FloatTensor([[0, 0, 0], [-1, 1, 0], [0, 0, 0]]).cuda().unsqueeze(0).unsqueeze(0)
+         kernel_right = torch.FloatTensor([[0, 0, 0], [0, 1, -1], [0, 0, 0]]).cuda().unsqueeze(0).unsqueeze(0)
+         kernel_up = torch.FloatTensor([[0, -1, 0], [0, 1, 0], [0, 0, 0]]).cuda().unsqueeze(0).unsqueeze(0)
+         kernel_down = torch.FloatTensor([[0, 0, 0], [0, 1, 0], [0, -1, 0]]).cuda().unsqueeze(0).unsqueeze(0)
+         self.weight_left = nn.Parameter(data=kernel_left, requires_grad=False)
+         self.weight_right = nn.Parameter(data=kernel_right, requires_grad=False)
+         self.weight_up = nn.Parameter(data=kernel_up, requires_grad=False)
+         self.weight_down = nn.Parameter(data=kernel_down, requires_grad=False)
+         self.pool = nn.AvgPool2d(4)
+
+     def forward(self, org, enhance):
+         b, c, h, w = org.shape
+         org_mean = torch.mean(org, 1, keepdim=True)
+         enhance_mean = torch.mean(enhance, 1, keepdim=True)
+         org_pool = self.pool(org_mean)
+         enhance_pool = self.pool(enhance_mean)
+
+         # weight_diff and E_1 are computed but unused in this version
+         weight_diff = torch.max(
+             torch.FloatTensor([1]).cuda() + 10000 * torch.min(
+                 org_pool - torch.FloatTensor([0.3]).cuda(), torch.FloatTensor([0]).cuda()),
+             torch.FloatTensor([0.5]).cuda())
+         E_1 = torch.mul(torch.sign(enhance_pool - torch.FloatTensor([0.5]).cuda()), enhance_pool - org_pool)
+
+         D_org_left = F.conv2d(org_pool, self.weight_left, padding=1)
+         D_org_right = F.conv2d(org_pool, self.weight_right, padding=1)
+         D_org_up = F.conv2d(org_pool, self.weight_up, padding=1)
+         D_org_down = F.conv2d(org_pool, self.weight_down, padding=1)
+
+         D_enhance_left = F.conv2d(enhance_pool, self.weight_left, padding=1)
+         D_enhance_right = F.conv2d(enhance_pool, self.weight_right, padding=1)
+         D_enhance_up = F.conv2d(enhance_pool, self.weight_up, padding=1)
+         D_enhance_down = F.conv2d(enhance_pool, self.weight_down, padding=1)
+
+         D_left = torch.pow(D_org_left - D_enhance_left, 2)
+         D_right = torch.pow(D_org_right - D_enhance_right, 2)
+         D_up = torch.pow(D_org_up - D_enhance_up, 2)
+         D_down = torch.pow(D_org_down - D_enhance_down, 2)
+         E = (D_left + D_right + D_up + D_down)
+         return E
+
+
+ class L_exp(nn.Module):
+     # Exposure control loss: pulls each patch mean toward a target exposure level.
+
+     def __init__(self, patch_size, mean_val):
+         super(L_exp, self).__init__()
+         self.pool = nn.AvgPool2d(patch_size)
+         self.mean_val = mean_val
+
+     def forward(self, x):
+         b, c, h, w = x.shape
+         x = torch.mean(x, 1, keepdim=True)
+         mean = self.pool(x)
+         d = torch.mean(torch.pow(mean - torch.FloatTensor([self.mean_val]).cuda(), 2))
+         return d
+
+
+ class L_TV(nn.Module):
+     # Illumination smoothness (total variation) loss.
+
+     def __init__(self, TVLoss_weight=1):
+         super(L_TV, self).__init__()
+         self.TVLoss_weight = TVLoss_weight
+
+     def forward(self, x):
+         batch_size = x.size()[0]
+         h_x = x.size()[2]
+         w_x = x.size()[3]
+         count_h = (x.size()[2] - 1) * x.size()[3]
+         count_w = x.size()[2] * (x.size()[3] - 1)
+         h_tv = torch.pow((x[:, :, 1:, :] - x[:, :, :h_x - 1, :]), 2).sum()
+         w_tv = torch.pow((x[:, :, :, 1:] - x[:, :, :, :w_x - 1]), 2).sum()
+         return self.TVLoss_weight * 2 * (h_tv / count_h + w_tv / count_w) / batch_size
+
+
+ class Sa_Loss(nn.Module):
+     # Saturation loss: per-pixel distance of each channel from the image mean colour.
+
+     def __init__(self):
+         super(Sa_Loss, self).__init__()
+
+     def forward(self, x):
+         b, c, h, w = x.shape
+         r, g, b = torch.split(x, 1, dim=1)  # note: b here shadows the batch size above
+         mean_rgb = torch.mean(x, [2, 3], keepdim=True)
+         mr, mg, mb = torch.split(mean_rgb, 1, dim=1)
+         Dr = r - mr
+         Dg = g - mg
+         Db = b - mb
+         k = torch.pow(torch.pow(Dr, 2) + torch.pow(Db, 2) + torch.pow(Dg, 2), 0.5)
+         k = torch.mean(k)
+         return k
+
+
+ class perception_loss(nn.Module):
+     # Frozen VGG-16 feature extractor; returns the relu4_3 activations.
+
+     def __init__(self):
+         super(perception_loss, self).__init__()
+         features = vgg16(pretrained=True).features
+         self.to_relu_1_2 = nn.Sequential()
+         self.to_relu_2_2 = nn.Sequential()
+         self.to_relu_3_3 = nn.Sequential()
+         self.to_relu_4_3 = nn.Sequential()
+
+         for x in range(4):
+             self.to_relu_1_2.add_module(str(x), features[x])
+         for x in range(4, 9):
+             self.to_relu_2_2.add_module(str(x), features[x])
+         for x in range(9, 16):
+             self.to_relu_3_3.add_module(str(x), features[x])
+         for x in range(16, 23):
+             self.to_relu_4_3.add_module(str(x), features[x])
+
+         # don't need the gradients, just want the features
+         for param in self.parameters():
+             param.requires_grad = False
+
+     def forward(self, x):
+         h = self.to_relu_1_2(x)
+         h_relu_1_2 = h
+         h = self.to_relu_2_2(h)
+         h_relu_2_2 = h
+         h = self.to_relu_3_3(h)
+         h_relu_3_3 = h
+         h = self.to_relu_4_3(h)
+         h_relu_4_3 = h
+         return h_relu_4_3
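For orientation, here is a minimal sketch of how these four losses are typically combined when training Zero-DCE. The weights (200, 10, 5), the L_exp settings (patch size 16, target exposure 0.6) and the names img_lowlight, enhanced_image and A (the predicted curve-parameter maps) follow the reference Zero-DCE training script and are assumptions here, not part of this commit:

import torch
import Myloss

# Assumed hyper-parameters, taken from the reference Zero-DCE training script.
L_color = Myloss.L_color()
L_spa = Myloss.L_spa()         # builds its kernels with .cuda(), so a GPU is required
L_exp = Myloss.L_exp(16, 0.6)  # 16x16 patches, target mean exposure 0.6 (assumed)
L_TV = Myloss.L_TV()

def total_loss(img_lowlight, enhanced_image, A):
    # A: curve-parameter maps predicted by the network (hypothetical name)
    loss_tv = 200 * L_TV(A)                                     # keep curve maps smooth
    loss_spa = torch.mean(L_spa(img_lowlight, enhanced_image))  # preserve local contrast
    loss_col = 5 * torch.mean(L_color(enhanced_image))          # gray-world colour balance
    loss_exp = 10 * torch.mean(L_exp(enhanced_image))           # pull exposure toward 0.6
    return loss_tv + loss_spa + loss_col + loss_exp

All inputs must already live on the GPU, since the modules above allocate their constants with .cuda().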
README.md CHANGED
@@ -1,12 +1,6 @@
  ---
- title: Zero-DCE Code
- emoji: 📈
- colorFrom: green
- colorTo: red
+ title: Zero-DCE_code
+ app_file: app.py
  sdk: gradio
  sdk_version: 3.35.2
- app_file: app.py
- pinned: false
  ---
-
- Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
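The trimmed front matter keeps only what Spaces needs to launch the app (title, app_file, sdk, sdk_version); the dropped emoji, colorFrom/colorTo and pinned fields are optional and fall back to Hub defaults when omitted.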
 
 
__pycache__/dataloader.cpython-311.pyc ADDED
Binary file (2.48 kB)
 
__pycache__/model.cpython-311.pyc ADDED
Binary file (4.33 kB)
 
app.py ADDED
@@ -0,0 +1,81 @@
+ import torch
+ import torch.nn as nn
+ import torchvision
+ import torch.backends.cudnn as cudnn
+ import torch.optim
+ import os
+ import sys
+ import argparse
+ import time
+ import dataloader
+ import model
+ import numpy as np
+ from torchvision import transforms
+ from PIL import Image
+ import glob
+ import gradio as gr
+
+
+ def lowlight(image_path):
+     # Batch path: enhance a single image from disk and save it under data/result/.
+     os.environ['CUDA_VISIBLE_DEVICES'] = '0'
+     data_lowlight = Image.open(image_path)
+     data_lowlight = (np.asarray(data_lowlight) / 255.0)
+     data_lowlight = torch.from_numpy(data_lowlight).float()
+     data_lowlight = data_lowlight.permute(2, 0, 1)
+     data_lowlight = data_lowlight.cuda().unsqueeze(0)
+
+     DCE_net = model.enhance_net_nopool().cuda()
+     DCE_net.load_state_dict(torch.load('snapshots/Epoch99.pth'))
+     start = time.time()
+     _, enhanced_image, _ = DCE_net(data_lowlight)
+     end_time = (time.time() - start)
+     print(end_time)
+
+     image_path = image_path.replace('test_data', 'result')
+     result_path = image_path
+     if not os.path.exists(os.path.dirname(result_path)):
+         os.makedirs(os.path.dirname(result_path))
+     torchvision.utils.save_image(enhanced_image, result_path)
+
+
+ def predict(img):
+     # Gradio path: enhance the NumPy image handed in by the UI.
+     data_lowlight = (np.asarray(img) / 255.0)
+     data_lowlight = torch.from_numpy(data_lowlight).float()
+     data_lowlight = data_lowlight.permute(2, 0, 1)
+     data_lowlight = data_lowlight.cuda().unsqueeze(0)
+
+     DCE_net = model.enhance_net_nopool().cuda()
+     DCE_net.load_state_dict(torch.load('snapshots/Epoch99.pth'))
+     _, enhanced_image, _ = DCE_net(data_lowlight)
+     return enhanced_image
+
+
+ if __name__ == '__main__':
+     with torch.no_grad():
+         # filePath = 'data/test_data/'
+         # file_list = os.listdir(filePath)
+         # for file_name in file_list:
+         #     test_list = glob.glob(filePath + file_name + "/*")
+         #     for image in test_list:
+         #         print(image)
+         #         lowlight(image)
+
+         interface = gr.Interface(fn=predict, inputs='image', outputs='image')
+         interface.launch()
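Two caveats about predict as committed: it reloads the network on every request, and it returns a 4-D CUDA tensor, whereas a Gradio 'image' output expects a NumPy array or PIL image. A hedged, device-agnostic sketch (checkpoint path and model API taken from the code above; everything else is an assumption):

import numpy as np
import torch
import model  # this repo's model.py

# Load once at startup rather than per request; fall back to CPU when CUDA is absent.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
DCE_net = model.enhance_net_nopool().to(device)
DCE_net.load_state_dict(torch.load('snapshots/Epoch99.pth', map_location=device))
DCE_net.eval()

def predict(img):
    x = torch.from_numpy(np.asarray(img) / 255.0).float()
    x = x.permute(2, 0, 1).unsqueeze(0).to(device)            # HWC -> 1xCxHxW
    with torch.no_grad():
        _, enhanced, _ = DCE_net(x)
    out = enhanced.squeeze(0).permute(1, 2, 0).cpu().numpy()  # back to HxWxC
    return (np.clip(out, 0.0, 1.0) * 255).astype(np.uint8)    # a dtype Gradio can display

The CPU fallback only helps if model.enhance_net_nopool() is itself device-agnostic; if it hard-codes .cuda() internally, the GPU requirement remains.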
data/test_data/DICM/01.jpg ADDED
data/test_data/DICM/02.jpg ADDED
data/test_data/DICM/03.jpg ADDED
data/test_data/DICM/04.jpg ADDED
data/test_data/DICM/05.jpg ADDED
data/test_data/DICM/06.jpg ADDED
data/test_data/DICM/07.jpg ADDED
data/test_data/DICM/08.jpg ADDED
data/test_data/DICM/09.jpg ADDED
data/test_data/DICM/10.jpg ADDED
data/test_data/DICM/11.jpg ADDED
data/test_data/DICM/12.jpg ADDED
data/test_data/DICM/13.jpg ADDED
data/test_data/DICM/14.jpg ADDED
data/test_data/DICM/15.jpg ADDED
data/test_data/DICM/16.jpg ADDED
data/test_data/DICM/17.jpg ADDED
data/test_data/DICM/18.jpg ADDED
data/test_data/DICM/19.jpg ADDED
data/test_data/DICM/20.jpg ADDED
data/test_data/DICM/21.jpg ADDED
data/test_data/DICM/22.jpg ADDED
data/test_data/DICM/25.jpg ADDED
data/test_data/DICM/26.jpg ADDED
data/test_data/DICM/27.jpg ADDED
data/test_data/DICM/28.jpg ADDED
data/test_data/DICM/29.jpg ADDED
data/test_data/DICM/30.jpg ADDED
data/test_data/DICM/31.jpg ADDED
data/test_data/DICM/32.jpg ADDED
data/test_data/DICM/33.jpg ADDED
data/test_data/DICM/34.jpg ADDED
data/test_data/DICM/35.jpg ADDED
data/test_data/DICM/36.jpg ADDED
data/test_data/DICM/37.jpg ADDED
data/test_data/DICM/38.jpg ADDED
data/test_data/DICM/39.jpg ADDED
data/test_data/DICM/40.jpg ADDED
data/test_data/DICM/41.jpg ADDED
data/test_data/DICM/42.jpg ADDED
data/test_data/DICM/43.jpg ADDED
data/test_data/DICM/44.jpg ADDED
data/test_data/DICM/45.jpg ADDED