ledetele commited on
Commit
86c2086
·
1 Parent(s): a624b9e

Create 05_image_Denoiser

Browse files
Files changed (1) hide show
  1. pages/05_image_Denoiser +250 -0
pages/05_image_Denoiser ADDED
@@ -0,0 +1,250 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import streamlit as st
2
+ import cv2
3
+ import numpy
4
+ import os
5
+ import random
6
+ from basicsr.archs.rrdbnet_arch import RRDBNet
7
+ from basicsr.utils.download_util import load_file_from_url
8
+ from PIL import Image
9
+
10
+ from realesrgan import RealESRGANer
11
+ from realesrgan.archs.srvgg_arch import SRVGGNetCompact
12
+
13
+
14
+ last_file = None
15
+ img_mode = "RGBA"
16
+
17
+
18
+ def realesrgan(img, model_name, denoise_strength, face_enhance, outscale):
19
+ """Real-ESRGAN function to restore (and upscale) images.
20
+ """
21
+ if not img:
22
+ return
23
+
24
+ # Define model parameters
25
+ if model_name == 'RealESRGAN_x4plus': # x4 RRDBNet model
26
+ model = RRDBNet(num_in_ch=3, num_out_ch=3, num_feat=64, num_block=23, num_grow_ch=32, scale=4)
27
+ netscale = 4
28
+ file_url = ['https://github.com/xinntao/Real-ESRGAN/releases/download/v0.1.0/RealESRGAN_x4plus.pth']
29
+ elif model_name == 'RealESRNet_x4plus': # x4 RRDBNet model
30
+ model = RRDBNet(num_in_ch=3, num_out_ch=3, num_feat=64, num_block=23, num_grow_ch=32, scale=4)
31
+ netscale = 4
32
+ file_url = ['https://github.com/xinntao/Real-ESRGAN/releases/download/v0.1.1/RealESRNet_x4plus.pth']
33
+ elif model_name == 'RealESRGAN_x4plus_anime_6B': # x4 RRDBNet model with 6 blocks
34
+ model = RRDBNet(num_in_ch=3, num_out_ch=3, num_feat=64, num_block=6, num_grow_ch=32, scale=4)
35
+ netscale = 4
36
+ file_url = ['https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.2.4/RealESRGAN_x4plus_anime_6B.pth']
37
+ elif model_name == 'RealESRGAN_x2plus': # x2 RRDBNet model
38
+ model = RRDBNet(num_in_ch=3, num_out_ch=3, num_feat=64, num_block=23, num_grow_ch=32, scale=2)
39
+ netscale = 2
40
+ file_url = ['https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.1/RealESRGAN_x2plus.pth']
41
+ elif model_name == 'realesr-general-x4v3': # x4 VGG-style model (S size)
42
+ model = SRVGGNetCompact(num_in_ch=3, num_out_ch=3, num_feat=64, num_conv=32, upscale=4, act_type='prelu')
43
+ netscale = 4
44
+ file_url = [
45
+ 'https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.5.0/realesr-general-wdn-x4v3.pth',
46
+ 'https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.5.0/realesr-general-x4v3.pth'
47
+ ]
48
+
49
+ # Determine model paths
50
+ model_path = os.path.join('weights', model_name + '.pth')
51
+ if not os.path.isfile(model_path):
52
+ ROOT_DIR = os.path.dirname(os.path.abspath(__file__))
53
+ for url in file_url:
54
+ # model_path will be updated
55
+ model_path = load_file_from_url(
56
+ url=url, model_dir=os.path.join(ROOT_DIR, 'weights'), progress=True, file_name=None)
57
+
58
+ # Use dni to control the denoise strength
59
+ dni_weight = None
60
+ if model_name == 'realesr-general-x4v3' and denoise_strength != 1:
61
+ wdn_model_path = model_path.replace('realesr-general-x4v3', 'realesr-general-wdn-x4v3')
62
+ model_path = [model_path, wdn_model_path]
63
+ dni_weight = [denoise_strength, 1 - denoise_strength]
64
+
65
+ # Restorer Class
66
+ upsampler = RealESRGANer(
67
+ scale=netscale,
68
+ model_path=model_path,
69
+ dni_weight=dni_weight,
70
+ model=model,
71
+ tile=0,
72
+ tile_pad=10,
73
+ pre_pad=10,
74
+ half=False,
75
+ gpu_id=None
76
+ )
77
+
78
+ # Use GFPGAN for face enhancement
79
+ if face_enhance:
80
+ from gfpgan import GFPGANer
81
+ face_enhancer = GFPGANer(
82
+ model_path='https://github.com/TencentARC/GFPGAN/releases/download/v1.3.0/GFPGANv1.3.pth',
83
+ upscale=outscale,
84
+ arch='clean',
85
+ channel_multiplier=2,
86
+ bg_upsampler=upsampler)
87
+
88
+ # Convert the input PIL image to cv2 image, so that it can be processed by realesrgan
89
+ #cv_img = numpy.array(img.get_value(), dtype = 'uint8')
90
+ cv_img = numpy.array(img)
91
+ #img = cv2.cvtColor(cv2.UMat(imgUMat), cv2.COLOR_RGB2GRAY)
92
+ img = cv2.cvtColor(cv_img, cv2.COLOR_RGBA2BGRA)
93
+
94
+ # Apply restoration
95
+ try:
96
+ if face_enhance:
97
+ _, _, output = face_enhancer.enhance(img, has_aligned=False, only_center_face=False, paste_back=True)
98
+ else:
99
+ output, _ = upsampler.enhance(img, outscale=outscale)
100
+ except RuntimeError as error:
101
+ print('Error', error)
102
+ print('If you encounter CUDA out of memory, try to set --tile with a smaller number.')
103
+ else:
104
+ # Save restored image and return it to the output Image component
105
+ if img_mode == 'RGBA': # RGBA images should be saved in png format
106
+ extension = 'png'
107
+ else:
108
+ extension = 'jpg'
109
+
110
+ out_filename = f"output_{rnd_string(8)}.{extension}"
111
+ cv2.imwrite(out_filename, output)
112
+ global last_file
113
+ last_file = out_filename
114
+ return out_filename
115
+
116
+
117
+ def rnd_string(x):
118
+ """Returns a string of 'x' random characters
119
+ """
120
+ characters = "abcdefghijklmnopqrstuvwxyz_0123456789"
121
+ result = "".join((random.choice(characters)) for i in range(x))
122
+ return result
123
+
124
+
125
+ def reset():
126
+ """Resets the Image components of the Gradio interface and deletes
127
+ the last processed image
128
+ """
129
+ global last_file
130
+ if last_file:
131
+ print(f"Deleting {last_file} ...")
132
+ os.remove(last_file)
133
+ last_file = None
134
+ return gr.update(value=None), gr.update(value=None)
135
+
136
+
137
+ def has_transparency(img):
138
+ """This function works by first checking to see if a "transparency" property is defined
139
+ in the image's info -- if so, we return "True". Then, if the image is using indexed colors
140
+ (such as in GIFs), it gets the index of the transparent color in the palette
141
+ (img.info.get("transparency", -1)) and checks if it's used anywhere in the canvas
142
+ (img.getcolors()). If the image is in RGBA mode, then presumably it has transparency in
143
+ it, but it double-checks by getting the minimum and maximum values of every color channel
144
+ (img.getextrema()), and checks if the alpha channel's smallest value falls below 255.
145
+ https://stackoverflow.com/questions/43864101/python-pil-check-if-image-is-transparent
146
+ """
147
+ if img.info.get("transparency", None) is not None:
148
+ return True
149
+ if img.mode == "P":
150
+ transparent = img.info.get("transparency", -1)
151
+ for _, index in img.getcolors():
152
+ if index == transparent:
153
+ return True
154
+ elif img.mode == "RGBA":
155
+ extrema = img.getextrema()
156
+ if extrema[3][0] < 255:
157
+ return True
158
+ return False
159
+
160
+
161
+ def image_properties(img):
162
+ """Returns the dimensions (width and height) and color mode of the input image and
163
+ also sets the global img_mode variable to be used by the realesrgan function
164
+ """
165
+ global img_mode
166
+ if img:
167
+ if has_transparency(img):
168
+ img_mode = "RGBA"
169
+ else:
170
+ img_mode = "RGB"
171
+ properties = f"Width: {img.size[0]}, Height: {img.size[1]} | Color Mode: {img_mode}"
172
+ return properties
173
+
174
+ def image_properties(image):
175
+ # Function to display image properties
176
+ properties = f"Image Size: {image.size}\nImage Mode: {image.mode}"
177
+ return properties
178
+
179
+ #----------
180
+
181
+ input_folder = '.'
182
+
183
+ @st.cache_resource
184
+ def load_image(image_file):
185
+ img = Image.open(image_file)
186
+ return img
187
+
188
+ def save_image(image_file):
189
+ if image_file is not None:
190
+ filename = image_file.name
191
+ img = load_image(image_file)
192
+ st.image(image=img, width=None)
193
+ with open(os.path.join(input_folder, filename), "wb") as f:
194
+ f.write(image_file.getbuffer())
195
+ st.success("Succesfully uploaded file for processing".format(filename))
196
+
197
+ #------------
198
+
199
+ st.title("Image Denoiser")
200
+ # Saving uploaded image in input folder for processing
201
+
202
+ #with st.expander("Options/Parameters"):
203
+
204
+ input_img = st.file_uploader(
205
+ "Upload Image", type=['png', 'jpeg', 'jpg', 'webp'])
206
+ #save_image(input_img)
207
+
208
+ model_name = "realesr-general-x4v3"
209
+
210
+ denoise_strength = st.slider("Denoise Strength", 0.0, 1.0, 0.5)
211
+
212
+ outscale = 1
213
+
214
+ face_enhance = False
215
+
216
+ if input_img:
217
+ print(input_img)
218
+ input_img = Image.open(input_img)
219
+ # Display image properties
220
+ cols = st.columns(2)
221
+
222
+ cols[0].image(input_img, 'Source Image')
223
+
224
+ #input_properties = get_image_properties(input_img)
225
+ #cols[1].write(input_properties)
226
+
227
+ # Output placeholder
228
+ output_img = st.empty()
229
+
230
+ # Input and output placeholders
231
+ input_img = input_img
232
+ output_img = st.empty()
233
+
234
+ # Buttons
235
+ restore = st.button('Restore')
236
+ reset = st.button('Reset')
237
+
238
+ # Restore clicked
239
+ if restore:
240
+ if input_img is not None:
241
+ output = realesrgan(input_img, model_name, denoise_strength,
242
+ face_enhance, outscale)
243
+ output_img.image(output, 'Restored Image')
244
+ else:
245
+ st.warning('Upload a file', icon="⚠️")
246
+
247
+ # Reset clicked
248
+ if reset:
249
+ output_img.empty()
250
+