Iceclear committed
Commit 1e02395 · verified · 1 Parent(s): f0aa253

Update app.py

Files changed (1)
  1. app.py +75 -94
app.py CHANGED
@@ -1,8 +1,3 @@
-"""
-This file is used for deploying hugging face demo:
-https://huggingface.co/spaces/
-"""
-
 import sys
 sys.path.append('StableSR')
 import os
@@ -25,48 +20,41 @@ from scripts.wavelet_color_fix import wavelet_reconstruction, adaptive_instance_
 from scripts.util_image import ImageSpliterTh
 from basicsr.utils.download_util import load_file_from_url
 from einops import rearrange, repeat
+from itertools import islice
 
-# os.system("pip freeze")
-
+# Download weights
 pretrain_model_url = {
-    'stablesr_512': 'https://huggingface.co/Iceclear/StableSR/resolve/main/stablesr_000117.ckpt',
-    'stablesr_768': 'https://huggingface.co/Iceclear/StableSR/resolve/main/stablesr_768v_000139.ckpt',
-    'CFW': 'https://huggingface.co/Iceclear/StableSR/resolve/main/vqgan_cfw_00011.ckpt',
+    'stablesr_512': 'https://huggingface.co/Iceclear/StableSR/resolve/main/stablesr_000117.ckpt',
+    'stablesr_768': 'https://huggingface.co/Iceclear/StableSR/resolve/main/stablesr_768v_000139.ckpt',
+    'CFW': 'https://huggingface.co/Iceclear/StableSR/resolve/main/vqgan_cfw_00011.ckpt',
 }
-# download weights
-if not os.path.exists('./stablesr_000117.ckpt'):
-    load_file_from_url(url=pretrain_model_url['stablesr_512'], model_dir='./', progress=True, file_name=None)
-if not os.path.exists('./stablesr_768v_000139.ckpt'):
-    load_file_from_url(url=pretrain_model_url['stablesr_768'], model_dir='./', progress=True, file_name=None)
-if not os.path.exists('./vqgan_cfw_00011.ckpt'):
-    load_file_from_url(url=pretrain_model_url['CFW'], model_dir='./', progress=True, file_name=None)
-
-# download images
-torch.hub.download_url_to_file(
-    'https://raw.githubusercontent.com/zsyOAOA/ResShift/master/testdata/RealSet128/Lincoln.png',
-    '01.png')
-torch.hub.download_url_to_file(
-    'https://raw.githubusercontent.com/zsyOAOA/ResShift/master/testdata/RealSet128/oldphoto6.png',
-    '02.png')
-torch.hub.download_url_to_file(
-    'https://raw.githubusercontent.com/zsyOAOA/ResShift/master/testdata/RealSet128/comic2.png',
-    '03.png')
-torch.hub.download_url_to_file(
-    'https://raw.githubusercontent.com/zsyOAOA/ResShift/master/testdata/RealSet128/OST_120.png',
-    '04.png')
-torch.hub.download_url_to_file(
-    'https://raw.githubusercontent.com/zsyOAOA/ResShift/master/testdata/RealSet65/comic3.png',
-    '05.png')
+
+for k, url in pretrain_model_url.items():
+    filename = url.split("/")[-1]
+    if not os.path.exists(f'./{filename}'):
+        load_file_from_url(url=url, model_dir='./', progress=True, file_name=None)
+
+# Download sample images
+image_urls = [
+    ('01.png', 'https://raw.githubusercontent.com/zsyOAOA/ResShift/master/testdata/RealSet128/Lincoln.png'),
+    ('02.png', 'https://raw.githubusercontent.com/zsyOAOA/ResShift/master/testdata/RealSet128/oldphoto6.png'),
+    ('03.png', 'https://raw.githubusercontent.com/zsyOAOA/ResShift/master/testdata/RealSet128/comic2.png'),
+    ('04.png', 'https://raw.githubusercontent.com/zsyOAOA/ResShift/master/testdata/RealSet128/OST_120.png'),
+    ('05.png', 'https://raw.githubusercontent.com/zsyOAOA/ResShift/master/testdata/RealSet65/comic3.png'),
+]
+
+for fname, url in image_urls:
+    torch.hub.download_url_to_file(url, fname)
 
 def load_img(path):
-    image = Image.open(path).convert("RGB")
-    w, h = image.size
-    w, h = map(lambda x: x - x % 32, (w, h)) # resize to integer multiple of 32
-    image = image.resize((w, h), resample=PIL.Image.LANCZOS)
-    image = np.array(image).astype(np.float32) / 255.0
-    image = image[None].transpose(0, 3, 1, 2)
-    image = torch.from_numpy(image)
-    return 2.*image - 1.
+    image = Image.open(path).convert("RGB")
+    w, h = image.size
+    w, h = map(lambda x: x - x % 32, (w, h))
+    image = image.resize((w, h), resample=PIL.Image.LANCZOS)
+    image = np.array(image).astype(np.float32) / 255.0
+    image = image[None].transpose(0, 3, 1, 2)
+    image = torch.from_numpy(image)
+    return 2.*image - 1.
 
 def space_timesteps(num_timesteps, section_counts):
     """
@@ -143,11 +131,13 @@ def load_model_from_config(config, ckpt, verbose=False):
     model.eval()
     return model
 
-# device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
+# Load VQGAN model
 device = torch.device("cuda")
 vqgan_config = OmegaConf.load("StableSR/configs/autoencoder/autoencoder_kl_64x64x4_resi.yaml")
-vq_model = load_model_from_config(vqgan_config, './vqgan_cfw_00011.ckpt')
-vq_model = vq_model.to(device)
+vq_model = instantiate_from_config(vqgan_config.model)
+vq_sd = torch.load('./vqgan_cfw_00011.ckpt', map_location='cpu')['state_dict']
+vq_model.load_state_dict(vq_sd, strict=False)
+vq_model.cuda().eval()
 
 os.makedirs('output', exist_ok=True)
 
@@ -284,6 +274,7 @@ def inference(image, upscale, dec_w, seed, model_type, ddpm_steps, colorfix_type
         print('Global exception', error)
         return None, None
 
+# Gradio UI
 with gr.Blocks(title="Exploiting Diffusion Prior for Real-World Image Super-Resolution") as demo:
     gr.Markdown(
         """
@@ -298,7 +289,7 @@ with gr.Blocks(title="Exploiting Diffusion Prior for Real-World Image Super-Reso
         If StableSR is helpful, please help to ⭐ the <a href='https://github.com/IceClear/StableSR' target='_blank'>Github Repo</a>. Thanks!
         [![GitHub Stars](https://img.shields.io/github/stars/IceClear/StableSR?style=social)](https://github.com/IceClear/StableSR)
         ---
-
+
         📝 **Citation**
         If our work is useful for your research, please consider citing:
 
@@ -319,59 +310,49 @@ with gr.Blocks(title="Exploiting Diffusion Prior for Real-World Image Super-Reso
         If you have any questions, please feel free to reach me out at <b>[email protected]</b>.
 
         <div>
-        🤗 Find Me:
-        <a href="https://twitter.com/Iceclearwjy"><img style="margin-top:0.5em; margin-bottom:0.5em" src="https://img.shields.io/twitter/follow/Iceclearwjy?label=%40Iceclearwjy&style=social" alt="Twitter Follow"></a>
-        <a href="https://github.com/IceClear"><img style="margin-top:0.5em; margin-bottom:2em" src="https://img.shields.io/github/followers/IceClear?style=social" alt="Github Follow"></a>
+        🤗 Find Me:
+        <a href="https://twitter.com/Iceclearwjy"><img style="margin-top:0.5em; margin-bottom:0.5em" src="https://img.shields.io/twitter/follow/Iceclearwjy?label=%40Iceclearwjy&style=social" alt="Twitter Follow"></a>
+        <a href="https://github.com/IceClear"><img style="margin-top:0.5em; margin-bottom:2em" src="https://img.shields.io/github/followers/IceClear?style=social" alt="Github Follow"></a>
         </div>
 
         <center><img src='https://visitor-badge.laobi.icu/badge?page_id=IceClear/StableSR' alt='visitors'></center>
         """
     )
 
-    with gr.Row():
-        with gr.Column():
-            image = gr.Image(type="filepath", label="Input")
-            upscale = gr.Number(value=1, label="Rescaling_Factor (Large images require huge time)")
-            dec_w = gr.Slider(0, 1, value=0.5, step=0.01, label='CFW_Fidelity (0 for better quality, 1 for better identity)')
-            seed = gr.Number(value=42, label="Seeds")
-            model_type = gr.Dropdown(
-                choices=["512", "768v"],
-                value="512",
-                label="Model"
-            )
-            ddpm_steps = gr.Slider(10, 1000, value=200, step=1, label='Sampling timesteps for DDPM')
-            colorfix_type = gr.Dropdown(
-                choices=["none", "adain", "wavelet"],
-                value="adain",
-                label="Color_Correction"
-            )
-            run_btn = gr.Button("Run Inference")
-
-        with gr.Column():
-            output_image = gr.Image(type="numpy", label="Output")
-            output_file = gr.File(label="Download the output")
-
-    # Inference trigger
-    run_btn.click(
-        fn=inference,
-        inputs=[image, upscale, dec_w, seed, model_type, ddpm_steps, colorfix_type],
-        outputs=[output_image, output_file]
-    )
-
-    # Example section
-    gr.Examples(
-        examples=[
-            ['./01.png', 4, 0.5, 42, "512", 200, "adain"],
-            ['./02.png', 4, 0.5, 42, "512", 200, "adain"],
-            ['./03.png', 4, 0.5, 42, "512", 200, "adain"],
-            ['./04.png', 4, 0.5, 42, "512", 200, "adain"],
-            ['./05.png', 4, 0.5, 42, "512", 200, "adain"]
-        ],
-        fn=inference,
-        inputs=[image, upscale, dec_w, seed, model_type, ddpm_steps, colorfix_type],
-        outputs=[output_image, output_file],
-        cache_examples=True
-    )
+    with gr.Row():
+        with gr.Column():
+            image = gr.Image(type="filepath", label="Input")
+            upscale = gr.Number(value=1, label="Rescaling_Factor")
+            dec_w = gr.Slider(0, 1, value=0.5, step=0.01, label='CFW_Fidelity')
+            seed = gr.Number(value=42, label="Seeds")
+            model_type = gr.Dropdown(choices=["512", "768v"], value="512", label="Model")
+            ddpm_steps = gr.Slider(10, 1000, value=200, step=1, label='DDPM Steps')
+            colorfix_type = gr.Dropdown(choices=["none", "adain", "wavelet"], value="adain", label="Color Correction")
+            run_btn = gr.Button("Run Inference")
+
+        with gr.Column():
+            output_image = gr.Image(type="numpy", label="Output")
+            output_file = gr.File(label="Download the output")
+
+    run_btn.click(
+        fn=inference,
+        inputs=[image, upscale, dec_w, seed, model_type, ddpm_steps, colorfix_type],
+        outputs=[output_image, output_file]
+    )
+
+    gr.Examples(
+        examples=[
+            ['01.png', 4, 0.5, 42, "512", 200, "adain"],
+            ['02.png', 4, 0.5, 42, "512", 200, "adain"],
+            ['03.png', 4, 0.5, 42, "512", 200, "adain"],
+            ['04.png', 4, 0.5, 42, "512", 200, "adain"],
+            ['05.png', 4, 0.5, 42, "512", 200, "adain"]
+        ],
+        fn=inference,
+        inputs=[image, upscale, dec_w, seed, model_type, ddpm_steps, colorfix_type],
+        outputs=[output_image, output_file],
+        cache_examples=True
+    )
 
 demo.queue()
-demo.launch()
+demo.launch()