File size: 15,263 Bytes
e571ea9 eec823a e571ea9 7444ebf |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 313 314 315 316 317 318 319 320 321 322 323 324 325 326 327 328 329 330 331 332 333 334 335 336 337 338 339 340 341 342 343 344 345 346 347 348 349 350 351 352 353 354 355 356 357 358 359 360 361 362 363 364 |
"""This file is used for deploying the Hugging Face demo."""
import sys
import os
import cv2
import torch
import torch.nn.functional as F
import gradio as gr
import torchvision
from torchvision.transforms.functional import normalize
from ldm.util import instantiate_from_config
from torch import autocast
import PIL
import numpy as np
from pytorch_lightning import seed_everything
from contextlib import nullcontext
from omegaconf import OmegaConf
from PIL import Image
import copy
from scripts.wavelet_color_fix import wavelet_reconstruction, adaptive_instance_normalization
from scripts.util_image import ImageSpliterTh
from basicsr.utils.download_util import load_file_from_url
from einops import rearrange, repeat
# os.system("pip freeze")
pretrain_model_url = {
    'stablesr_512': '',
    'stablesr_768': '',
    'CFW': '',
}

# download weights
# Each local checkpoint file the demo loads, mapped to its key in pretrain_model_url.
# The file name is passed to load_file_from_url explicitly so the downloaded file
# always matches the name checked here and loaded later, whatever the URL's basename.
_checkpoint_files = {
    'stablesr_000117.ckpt': 'stablesr_512',
    'stablesr_768v_000139.ckpt': 'stablesr_768',
    'vqgan_cfw_00011.ckpt': 'CFW',
}
for _ckpt_name, _url_key in _checkpoint_files.items():
    if not os.path.exists(os.path.join('.', _ckpt_name)):
        load_file_from_url(url=pretrain_model_url[_url_key], model_dir='./', progress=True, file_name=_ckpt_name)
# download images
def load_img(path):
    """Load an image file as a (1, 3, H, W) float32 tensor scaled to [-1, 1].

    The image is converted to RGB and resized with LANCZOS so that width and
    height are multiples of 32 (rounded down, but never below 32).

    Args:
        path: path of the image file to read.

    Returns:
        torch.Tensor of shape (1, 3, H, W) with values in [-1, 1].
    """
    # The context manager closes the file once the RGB copy has been made.
    with Image.open(path) as img:
        image = img.convert("RGB")
    w, h = image.size
    # Round down to a multiple of 32; clamp at 32 so inputs smaller than 32 px
    # are not resized to a zero-sized image (which makes resize() fail).
    w, h = (max(x - x % 32, 32) for x in (w, h))
    image = image.resize((w, h), resample=PIL.Image.LANCZOS)
    image = np.array(image).astype(np.float32) / 255.0
    image = image[None].transpose(0, 3, 1, 2)  # HWC -> NCHW with N=1
    image = torch.from_numpy(image)
    return 2. * image - 1.
def space_timesteps(num_timesteps, section_counts):
    """
    Create a list of timesteps to use from an original diffusion process,
    given the number of timesteps we want to take from equally-sized portions
    of the original process.

    For example, if there's 300 timesteps and the section counts are [10,15,20]
    then the first 100 timesteps are strided to be 10 timesteps, the second 100
    are strided to be 15 timesteps, and the final 100 are strided to be 20.

    If the stride is a string starting with "ddim", then the fixed striding
    from the DDIM paper is used, and only one section is allowed.

    :param num_timesteps: the number of diffusion steps in the original
                          process to divide up.
    :param section_counts: either a list of numbers, or a string containing
                           comma-separated numbers, indicating the step count
                           per section. As a special case, use "ddimN" where N
                           is a number of steps to use the striding from the
                           DDIM paper.
    :return: a set of diffusion steps from the original process to use.
    :raises ValueError: if no integer stride yields exactly N steps ("ddimN"),
                        or a section is smaller than its requested step count.
    """
    if isinstance(section_counts, str):
        if section_counts.startswith("ddim"):
            desired_count = int(section_counts[len("ddim"):])
            # Include stride == num_timesteps so that "ddim1" (just step 0) is reachable.
            for stride in range(1, num_timesteps + 1):
                if len(range(0, num_timesteps, stride)) == desired_count:
                    return set(range(0, num_timesteps, stride))
            raise ValueError(
                f"cannot create exactly {desired_count} steps with an integer stride"
            )
        section_counts = [int(x) for x in section_counts.split(",")]  # e.g. "250" -> [250]
    size_per = num_timesteps // len(section_counts)
    extra = num_timesteps % len(section_counts)
    start_idx = 0
    all_steps = []
    for i, section_count in enumerate(section_counts):
        # The first `extra` sections each absorb one leftover step.
        size = size_per + (1 if i < extra else 0)
        if size < section_count:
            raise ValueError(
                f"cannot divide section of {size} steps into {section_count}"
            )
        if section_count <= 1:
            frac_stride = 1
        else:
            # Fractional stride so both the first and last step of the section are taken.
            frac_stride = (size - 1) / (section_count - 1)
        cur_idx = 0.0
        taken_steps = []
        for _ in range(section_count):
            taken_steps.append(start_idx + round(cur_idx))
            cur_idx += frac_stride
        all_steps += taken_steps
        start_idx += size
    return set(all_steps)
def chunk(it, size):
    """Split an iterable into consecutive tuples of at most ``size`` items.

    The final tuple may be shorter, e.g. ``list(chunk(range(5), 2))`` is
    ``[(0, 1), (2, 3), (4,)]``.

    Args:
        it: any iterable; it is consumed lazily.
        size: maximum number of items per tuple.

    Returns:
        An iterator of tuples.
    """
    # islice is imported here so the helper works on its own.
    from itertools import islice

    it = iter(it)
    # iter(callable, sentinel) stops as soon as islice yields an empty tuple.
    return iter(lambda: tuple(islice(it, size)), ())
def load_model_from_config(config, ckpt, verbose=False):
    """Instantiate ``config.model`` and load checkpoint weights into it.

    Args:
        config: OmegaConf config whose ``model`` node is passed to
            ``instantiate_from_config``.
        ckpt: path of the checkpoint file to read with ``torch.load`` (mapped to CPU).
        verbose: if True, print the missing and unexpected state-dict keys.

    Returns:
        The instantiated model with the weights loaded via
        ``load_state_dict(..., strict=False)``.
    """
    print(f"Loading model from {ckpt}")
    # NOTE: torch.load unpickles the file; only point this at trusted checkpoints.
    pl_sd = torch.load(ckpt, map_location="cpu")
    if "global_step" in pl_sd:
        print(f"Global Step: {pl_sd['global_step']}")
    # Lightning checkpoints wrap the weights in "state_dict"; also accept a bare state dict.
    sd = pl_sd["state_dict"] if "state_dict" in pl_sd else pl_sd
    model = instantiate_from_config(config.model)
    missing, unexpected = model.load_state_dict(sd, strict=False)
    if missing and verbose:
        print("missing keys:")
        print(missing)
    if unexpected and verbose:
        print("unexpected keys:")
        print(unexpected)
    return model
# CUDA is required: inference() runs under autocast("cuda") and feeds float16
# tensors to the models, so there is no CPU fallback here.
device = torch.device("cuda")

# VQGAN autoencoder (CFW checkpoint) shared by every request; inference() sets
# its decoder fusion weight per call.
vqgan_config = OmegaConf.load("StableSR/configs/autoencoder/autoencoder_kl_64x64x4_resi.yaml")
vq_model = load_model_from_config(vqgan_config, './vqgan_cfw_00011.ckpt')
vq_model = vq_model.to(device)

# inference() writes its downloadable result into this directory.
os.makedirs('output', exist_ok=True)
def _load_sr_model(model_type):
    """Load the StableSR diffusion model for ``model_type``.

    Returns ``(model, config, min_size)``: ``'512'`` selects the 512 checkpoint,
    any other value selects the 768v checkpoint.
    """
    if model_type == '512':
        config = OmegaConf.load("StableSR/configs/stableSRNew/v2-finetune_text_T_512.yaml")
        model = load_model_from_config(config, "./stablesr_000117.ckpt")
        min_size = 512
    else:
        config = OmegaConf.load("StableSR/configs/stableSRNew/v2-finetune_text_T_768v.yaml")
        model = load_model_from_config(config, "./stablesr_768v_000139.ckpt")
        min_size = 768
    return model, config, min_size


def _respace_schedule(model, ddpm_steps):
    """Register a ``ddpm_steps``-step respaced linear schedule on ``model``.

    Returns copies of the full 1000-step ``sqrt_alphas_cumprod`` and
    ``sqrt_one_minus_alphas_cumprod``, taken before respacing, which are later
    passed to ``q_sample_respace`` to noise the low-resolution latent.
    """
    model.register_schedule(given_betas=None, beta_schedule="linear", timesteps=1000,
                            linear_start=0.00085, linear_end=0.0120, cosine_s=8e-3)
    model.num_timesteps = 1000

    sqrt_alphas_cumprod = copy.deepcopy(model.sqrt_alphas_cumprod)
    sqrt_one_minus_alphas_cumprod = copy.deepcopy(model.sqrt_one_minus_alphas_cumprod)

    use_timesteps = space_timesteps(1000, [ddpm_steps])
    last_alpha_cumprod = 1.0
    new_betas = []
    for i, alpha_cumprod in enumerate(model.alphas_cumprod):
        if i in use_timesteps:
            # Beta that reproduces the original cumulative alpha at the kept step.
            new_betas.append(1 - alpha_cumprod / last_alpha_cumprod)
            last_alpha_cumprod = alpha_cumprod
    new_betas = [beta.data.cpu().numpy() for beta in new_betas]
    model.register_schedule(given_betas=np.array(new_betas), timesteps=len(new_betas))
    model.num_timesteps = 1000
    # A set has no defined iteration order; keep the original steps ascending.
    model.ori_timesteps = sorted(use_timesteps)
    return sqrt_alphas_cumprod, sqrt_one_minus_alphas_cumprod


def _apply_colorfix(samples, reference, colorfix_type):
    """Apply the selected colour correction ('adain' or 'wavelet') against ``reference``.

    Any other ``colorfix_type`` returns ``samples`` unchanged.
    """
    if colorfix_type == 'adain':
        return adaptive_instance_normalization(samples, reference)
    if colorfix_type == 'wavelet':
        return wavelet_reconstruction(samples, reference)
    return samples


def inference(image, upscale, dec_w, seed, model_type, ddpm_steps, colorfix_type):
    """Run a single prediction on the model.

    Args:
        image: path of the input image (read with ``load_img``).
        upscale: factor the input is bicubically resized by before diffusion.
        dec_w: fusion weight assigned to ``vq_model.decoder.fusion_w``.
        seed: random seed passed to ``seed_everything``.
        model_type: ``'512'`` for the 512 model, anything else for the 768v model.
        ddpm_steps: number of sampling steps kept out of the 1000-step schedule.
        colorfix_type: ``'adain'``, ``'wavelet'``, or anything else for none.

    Returns:
        ``(restored_img, 'output/out.png')``, where ``restored_img`` is an
        (H, W, 3) uint8 array that is also saved to that path, or
        ``(None, None)`` if any exception is raised while processing.
    """
    precision_scope = autocast
    vq_model.decoder.fusion_w = dec_w
    # Gradio number/slider widgets may hand over floats.
    seed_everything(int(seed))
    ddpm_steps = int(ddpm_steps)
    out_path = 'output/out.png'

    model, config, min_size = _load_sr_model(model_type)
    model = model.to(device)
    model.configs = config
    sqrt_alphas_cumprod, sqrt_one_minus_alphas_cumprod = _respace_schedule(model, ddpm_steps)
    # register_schedule re-creates the schedule buffers, so move the model again.
    model = model.to(device)

    try:  # global try: any failure is reported to the UI as empty outputs
        with torch.no_grad(), precision_scope("cuda"), model.ema_scope():
            init_image = load_img(image)
            init_image = F.interpolate(
                init_image,
                size=(int(init_image.size(-2) * upscale),
                      int(init_image.size(-1) * upscale)),
                mode='bicubic',
            )

            # Upsample so the shorter side reaches min_size; undone at the end.
            if init_image.size(-1) < min_size or init_image.size(-2) < min_size:
                ori_size = init_image.size()
                rescale = min_size * 1.0 / min(init_image.size(-2), init_image.size(-1))
                new_h = max(int(ori_size[-2] * rescale), min_size)
                new_w = max(int(ori_size[-1] * rescale), min_size)
                init_template = F.interpolate(
                    init_image,
                    size=(new_h, new_w),
                    mode='bicubic',
                )
            else:
                init_template = init_image
                rescale = 1
            init_template = init_template.clamp(-1, 1)
            assert init_template.size(-1) >= min_size
            assert init_template.size(-2) >= min_size

            init_template = init_template.type(torch.float16).to(device)

            if init_template.size(-1) <= 1280 or init_template.size(-2) <= 1280:
                init_latent_generator, enc_fea_lq = vq_model.encode(init_template)
                init_latent = model.get_first_stage_encoding(init_latent_generator)
                text_init = [''] * init_template.size(0)
                semantic_c = model.cond_stage_model(text_init)

                # Start sampling from the LR latent noised to original step 999.
                noise = torch.randn_like(init_latent)
                t = repeat(torch.tensor([999]), '1 -> b', b=init_image.size(0))
                t = t.to(device).long()
                x_T = model.q_sample_respace(x_start=init_latent, t=t, sqrt_alphas_cumprod=sqrt_alphas_cumprod, sqrt_one_minus_alphas_cumprod=sqrt_one_minus_alphas_cumprod, noise=noise)
                if init_template.size(-1) <= min_size and init_template.size(-2) <= min_size:
                    samples, _ = model.sample(cond=semantic_c, struct_cond=init_latent, batch_size=init_template.size(0), timesteps=ddpm_steps, time_replace=ddpm_steps, x_T=x_T, return_intermediates=True)
                else:
                    samples, _ = model.sample_canvas(cond=semantic_c, struct_cond=init_latent, batch_size=init_template.size(0), timesteps=ddpm_steps, time_replace=ddpm_steps, x_T=x_T, return_intermediates=True, tile_size=int(min_size / 8), tile_overlap=min_size // 16, batch_size_sample=init_template.size(0))
                x_samples = vq_model.decode(samples * 1. / model.scale_factor, enc_fea_lq)
                x_samples = _apply_colorfix(x_samples, init_template, colorfix_type)
                x_samples = torch.clamp((x_samples + 1.0) / 2.0, min=0.0, max=1.0)
            else:
                # Large inputs: process 1280-px patches (stride 1000) and stitch them.
                im_spliter = ImageSpliterTh(init_template, 1280, 1000, sf=1)
                for im_lq_pch, index_infos in im_spliter:
                    init_latent = model.get_first_stage_encoding(model.encode_first_stage(im_lq_pch))  # move to latent space
                    text_init = [''] * init_latent.size(0)
                    semantic_c = model.cond_stage_model(text_init)
                    noise = torch.randn_like(init_latent)
                    # If you would like to start from the intermediate steps, you can add noise to LR to the specific steps.
                    t = repeat(torch.tensor([999]), '1 -> b', b=init_template.size(0))
                    t = t.to(device).long()
                    x_T = model.q_sample_respace(x_start=init_latent, t=t, sqrt_alphas_cumprod=sqrt_alphas_cumprod, sqrt_one_minus_alphas_cumprod=sqrt_one_minus_alphas_cumprod, noise=noise)
                    samples, _ = model.sample_canvas(cond=semantic_c, struct_cond=init_latent, batch_size=im_lq_pch.size(0), timesteps=ddpm_steps, time_replace=ddpm_steps, x_T=x_T, return_intermediates=True, tile_size=int(min_size / 8), tile_overlap=min_size // 16, batch_size_sample=im_lq_pch.size(0))
                    _, enc_fea_lq = vq_model.encode(im_lq_pch)
                    x_samples = vq_model.decode(samples * 1. / model.scale_factor, enc_fea_lq)
                    x_samples = _apply_colorfix(x_samples, im_lq_pch, colorfix_type)
                    im_spliter.update(x_samples, index_infos)
                x_samples = im_spliter.gather()
                x_samples = torch.clamp((x_samples + 1.0) / 2.0, min=0.0, max=1.0)

            # Undo the min_size upsampling so the output matches the upscaled input.
            if rescale > 1:
                x_samples = F.interpolate(
                    x_samples,
                    size=(int(init_image.size(-2)), int(init_image.size(-1))),
                    mode='bicubic',
                )
                x_samples = x_samples.clamp(0, 1)
            x_sample = 255. * rearrange(x_samples[0].cpu().numpy(), 'c h w -> h w c')
            restored_img = x_sample.astype(np.uint8)
            # The returned path feeds the download output, so the file must exist.
            Image.fromarray(restored_img).save(out_path)
        return restored_img, out_path
    except Exception as error:
        print('Global exception', error)
        return None, None
# Page text for the Gradio demo (title, HTML/markdown header and footer).
title = "Exploiting Diffusion Prior for Real-World Image Super-Resolution"
description = r"""<center><img src='' style='height:40px' alt='StableSR logo'></center>
<b>Official Gradio demo</b> for <a href='' target='_blank'><b>Exploiting Diffusion Prior for Real-World Image Super-Resolution</b></a>.<br>
🔥 StableSR is a general image super-resolution algorithm for real-world and AIGC images.<br>
"""
article = r"""
If StableSR is helpful, please help to ⭐ the <a href='' target='_blank'>Github Repo</a>. Thanks!
[](
📝 **Citation**
If our work is useful for your research, please consider citing:
author = {Wang, Jianyi and Yue, Zongsheng and Zhou, Shangchen and Chan, Kelvin C.K. and Loy, Chen Change},
title = {Exploiting Diffusion Prior for Real-World Image Super-Resolution},
article = {International Journal of Computer Vision},
year = {2024}
📋 **License**
This project is licensed under <a rel="license" href="">S-Lab License 1.0</a>.
Redistribution and use for non-commercial purposes should follow this license.
📧 **Contact**
If you have any questions, please feel free to reach me out at <b>[email protected]</b>.
🤗 Find Me:
<a href=""><img style="margin-top:0.5em; margin-bottom:0.5em" src="" alt="Twitter Follow"></a>
<a href=""><img style="margin-top:0.5em; margin-bottom:2em" src="" alt="Github Follow"></a>
<center><img src='' alt='visitors'></center>
"""
# Gradio UI. The input list is meant to map positionally onto inference()'s
# parameters (image, upscale, dec_w, seed, model_type, ddpm_steps, colorfix_type),
# and each example row below lists values in that same order.
# NOTE(review): in this copy the component constructors around the two
# `choices=` lists (presumably gr.Dropdown(...)), the closing brackets, the
# `examples=` keyword and any launch call appear to be missing — restore them
# from the upstream file.
# NOTE(review): gr.inputs / gr.outputs are the legacy component namespaces;
# confirm the pinned Gradio version still provides them.
demo = gr.Interface(
    inference, [
        # image: passed to load_img as a file path
        gr.inputs.Image(type="filepath", label="Input"),
        # upscale: resize factor applied to the input
        gr.inputs.Number(default=1, label="Rescaling_Factor (Large images require huge time)"),
        # dec_w: written to vq_model.decoder.fusion_w
        gr.Slider(0, 1, value=0.5, step=0.01, label='CFW_Fidelity (0 for better quality, 1 for better identity)'),
        # seed
        gr.inputs.Number(default=42, label="Seeds"),
        # model_type: '512' selects the 512 checkpoint, otherwise 768v
        choices=["512", "768v"],
        # ddpm_steps: respaced sampling steps
        gr.Slider(10, 1000, value=200, step=1, label='Sampling timesteps for DDPM'),
        # colorfix_type: colour correction applied after decoding
        choices=["none", "adain", "wavelet"],
    ], [
        # outputs: the restored array and the path of the saved file
        gr.outputs.Image(type="numpy", label="Output"),
        gr.outputs.File(label="Download the output")
        # example rows: (image, upscale, dec_w, seed, model_type, ddpm_steps, colorfix_type)
        ['./01.png', 4, 0.5, 42, "512", 200, "adain"],
        ['./02.png', 4, 0.5, 42, "512", 200, "adain"],
        ['./03.png', 4, 0.5, 42, "512", 200, "adain"],
        ['./04.png', 4, 0.5, 42, "512", 200, "adain"],
        ['./05.png', 4, 0.5, 42, "512", 200, "adain"]