# Spaces: Running on T4  (appears to be a hosting-page status banner captured with the file; not part of the program)
#!/usr/bin/env python | |
# -*- coding:utf-8 -*- | |
# Power by Zongsheng Yue 2022-07-13 16:59:27 | |
import os | |
import random | |
import numpy as np | |
from math import ceil | |
from pathlib import Path | |
from einops import rearrange | |
from omegaconf import OmegaConf | |
from skimage import img_as_ubyte | |
from ResizeRight.resize_right import resize | |
from utils import util_net | |
from utils import util_image | |
from utils import util_common | |
import torch | |
import torch.distributed as dist | |
import torch.multiprocessing as mp | |
from torch.nn.parallel import DistributedDataParallel as DDP | |
from basicsr.utils import img2tensor | |
from basicsr.archs.rrdbnet_arch import RRDBNet | |
from basicsr.utils.realesrgan_utils import RealESRGANer | |
from facelib.utils.face_restoration_helper import FaceRestoreHelper | |
class BaseSampler:
    '''
    Base sampler: picks the device, seeds the random number generators and
    builds the diffusion object and its network from a config.
    '''
    def __init__(self, configs):
        '''
        Input:
            configs: config, see the yaml file in folder ./configs/sample/
        '''
        self.configs = configs
        self.display = configs.display
        self.diffusion_cfg = configs.diffusion

        self.setup_dist()   # sets self.device and self.rank
        self.setup_seed()   # reads self.rank, so it has to run after setup_dist
        self.build_model()

    def setup_seed(self, seed=None):
        '''
        Seed python, numpy and torch (CPU and all CUDA devices).
        Input:
            seed: base seed; configs.seed is used when None.
                  The value actually applied is seed + (rank + 1) * 10000.
        '''
        seed = self.configs.seed if seed is None else seed
        seed += (self.rank + 1) * 10000
        if self.rank == 0 and self.display:
            print(f'Setting random seed {seed}')
        random.seed(seed)
        np.random.seed(seed)
        torch.manual_seed(seed)
        torch.cuda.manual_seed_all(seed)

    def setup_dist(self):
        '''
        Select the compute device (CUDA when available, otherwise CPU) and set
        self.rank to 0.
        '''
        if torch.cuda.is_available():
            self.device = torch.device('cuda')
            print('Runing on GPU...')
        else:
            self.device = torch.device('cpu')
            print('Runing on CPU...')
        self.rank = 0

    def build_model(self):
        '''
        Instantiate self.diffusion and self.model from the targets named in the
        config, load the checkpoint when one is configured, and switch the
        model to eval mode.
        '''
        obj = util_common.get_obj_from_str(self.configs.diffusion.target)
        self.diffusion = obj(**self.configs.diffusion.params)

        obj = util_common.get_obj_from_str(self.configs.model.target)
        model = obj(**self.configs.model.params).to(self.device)
        if self.configs.model.ckpt_path is not None:
            self.load_model(model, self.configs.model.ckpt_path)
        self.model = model
        self.model.eval()

    def load_model(self, model, ckpt_path=None):
        '''
        Load a checkpoint into model; does nothing when ckpt_path is None.
        Input:
            model: network that receives the weights (via util_net.reload_model)
            ckpt_path: path of the checkpoint file
        '''
        if ckpt_path is None:
            return
        if self.rank == 0 and self.display:
            print(f'Loading from {ckpt_path}...')
        # Map onto the device chosen in setup_dist so CPU-only hosts can load
        # checkpoints that were saved from a GPU.
        # NOTE: torch.load unpickles the file; only load trusted checkpoints.
        ckpt = torch.load(ckpt_path, map_location=self.device)
        util_net.reload_model(model, ckpt)
        if self.rank == 0 and self.display:
            print('Loaded Done')

    def reset_diffusion(self, diffusion_cfg):
        '''
        Replace self.diffusion with a new instance.
        Input:
            diffusion_cfg: mapping of keyword arguments passed to the diffusion factory
        '''
        # Resolve the factory from configs.diffusion.target, as build_model
        # does, rather than relying on a module-level name.
        obj = util_common.get_obj_from_str(self.configs.diffusion.target)
        self.diffusion = obj(**diffusion_cfg)
class DifIRSampler(BaseSampler):
    '''
    Sampler that first runs a restoration network on the low-quality input,
    diffuses the result to an intermediate timestep and then denoises it with
    the diffusion model. Works on aligned faces (sample_func_ir_aligned) or on
    whole images with face detection and paste-back (sample_func_bfr_unaligned).
    '''
    def build_model(self):
        '''
        Build the diffusion model (base class) plus the optional components
        selected by the config:
            configs.model_ir: restoration network, stored as self.model_ir
            configs.aligned is false: face helper, stored as self.face_helper
            configs.background_enhance or configs.face_upsample:
                Real-ESRGAN x2 upsampler, stored as self.bg_model
        '''
        super().build_model()

        if self.configs.model_ir is not None:
            obj = util_common.get_obj_from_str(self.configs.model_ir.target)
            # Same device as the diffusion model, so CPU-only hosts work too.
            model_ir = obj(**self.configs.model_ir.params).to(self.device)
            if self.configs.model_ir.ckpt_path is not None:
                self.load_model(model_ir, self.configs.model_ir.ckpt_path)
            self.model_ir = model_ir
            self.model_ir.eval()

        if not self.configs.aligned:
            # face detection model
            self.face_helper = FaceRestoreHelper(
                self.configs.detection.upscale,
                face_size=self.configs.im_size,
                crop_ratio=(1, 1),
                det_model=self.configs.detection.det_model,
                save_ext='png',
                use_parse=True,
                device=self.device,
            )

            # background super-resolution
            if self.configs.background_enhance or self.configs.face_upsample:
                bg_model = RRDBNet(
                    num_in_ch=3,
                    num_out_ch=3,
                    num_feat=64,
                    num_block=23,
                    num_grow_ch=32,
                    scale=2,
                )
                self.bg_model = RealESRGANer(
                    scale=2,
                    model_path='https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.1/RealESRGAN_x2plus.pth',
                    model=bg_model,
                    tile=400,
                    tile_pad=10,
                    pre_pad=0,
                    # half precision has to be off in CPU mode
                    half=self.device.type == 'cuda',
                    device=self.device,
                )

    def sample_func_ir_aligned(
            self,
            y0,
            start_timesteps=None,
            post_fun=None,
            model_kwargs_ir=None,
            need_restoration=True,
            ):
        '''
        Restore aligned image(s): optional restoration network, forward
        diffusion to start_timesteps, then reverse sampling.
        Input:
            y0: n x c x h x w torch tensor, low-quality image, [0, 1], RGB
                or, h x w x c, numpy array, [0, 255], uint8, BGR
            start_timesteps: integer, range [0, num_timesteps-1],
                for accelerated sampling (e.g., 'ddim250'), range [0, 249]
            post_fun: post-processing for the enhanced image
                (default: normalize with mean=0.5, std=0.5)
            model_kwargs_ir: additional parameters for restoration model
            need_restoration: when False the restoration network is skipped
                and y0 itself is diffused
        Output:
            sample: n x c x h x w, torch tensor, [0,1], RGB
            im_hq: output of the restoration step, clamped to [0, 1],
                at resolution configs.im_size
        '''
        if not isinstance(y0, torch.Tensor):
            y0 = img2tensor(y0, bgr2rgb=True, float32=True).unsqueeze(0) / 255.  # 1 x c x h x w, [0,1]

        if start_timesteps is None:
            # NOTE(review): this default equals num_timesteps, which is outside
            # the documented range [0, num_timesteps-1] -- confirm that
            # q_sample accepts it.
            start_timesteps = self.diffusion.num_timesteps

        if post_fun is None:
            post_fun = lambda x: util_image.normalize_th(
                im=x,
                mean=0.5,
                std=0.5,
                reverse=False,
            )

        # basic image restoration
        im_size = self.configs.im_size
        device = next(self.model.parameters()).device
        y0 = y0.to(device=device, dtype=torch.float32)

        h_old, w_old = y0.shape[2:4]
        needs_resize = not (h_old == im_size and w_old == im_size)
        if needs_resize:
            y0 = resize(y0, out_shape=(im_size,) * 2).to(torch.float32)

        if need_restoration:
            with torch.no_grad():
                im_hq = self.model_ir(y0, **(model_kwargs_ir or {}))
        else:
            im_hq = y0
        # Out-of-place clamp: with need_restoration=False im_hq can be the
        # caller's own tensor, which must not be modified.
        im_hq = im_hq.clamp(0.0, 1.0)

        # diffuse for im_hq
        yt = self.diffusion.q_sample(
            x_start=post_fun(im_hq),
            t=torch.tensor([start_timesteps,] * im_hq.shape[0], device=device),
        )

        assert yt.shape[-1] == im_size and yt.shape[-2] == im_size
        if 'ddim' in self.configs.diffusion.params.timestep_respacing:
            sample = self.diffusion.ddim_sample_loop(
                self.model,
                shape=yt.shape,
                noise=yt,
                start_timesteps=start_timesteps,
                clip_denoised=True,
                denoised_fn=None,
                model_kwargs=None,
                device=None,
                progress=False,
                eta=0.0,
            )
        else:
            sample = self.diffusion.p_sample_loop(
                self.model,
                shape=yt.shape,
                noise=yt,
                start_timesteps=start_timesteps,
                clip_denoised=True,
                denoised_fn=None,
                model_kwargs=None,
                device=None,
                progress=False,
            )

        sample = util_image.normalize_th(sample, reverse=True).clamp(0.0, 1.0)

        # back to the input resolution
        if needs_resize:
            sample = resize(sample, out_shape=(h_old, w_old)).clamp(0.0, 1.0)

        return sample, im_hq

    def sample_func_bfr_unaligned(
            self,
            y0,
            bs=16,
            start_timesteps=None,
            post_fun=None,
            model_kwargs_ir=None,
            need_restoration=True,
            only_center_face=False,
            draw_box=False,
            ):
        '''
        Detect, align and restore every face of an unaligned image, then paste
        the restored faces back. When no face is found the restoration loop is
        skipped and the paste-back step runs with no restored faces.
        Input:
            y0: h x w x c numpy array, uint8, BGR
            bs: batch size for face restoration, must be >= 1
            start_timesteps: integer, range [0, num_timesteps-1],
                for accelerated sampling (e.g., 'ddim250'), range [0, 249]
            post_fun: post-processing for the enhanced image
            model_kwargs_ir: additional parameters for restoration model
            need_restoration: forwarded to sample_func_ir_aligned
            only_center_face: forwarded to the face detector
            draw_box: draw a box for each face
        Output:
            restored_img: h x w x c, numpy array, uint8, BGR
            restored_faces: list, h x w x c, numpy array, uint8, BGR
            cropped_faces: list, h x w x c, numpy array, uint8, BGR
        '''
        def _process_batch(cropped_faces_list):
            # Restore one batch of cropped faces; returns b x c x h x w in [0, 1].
            cropped_face_t = np.stack(
                img2tensor(cropped_faces_list, bgr2rgb=True, float32=True),
                axis=0) / 255.
            cropped_face_t = torch.from_numpy(cropped_face_t).to(self.device)
            return self.sample_func_ir_aligned(
                cropped_face_t,
                start_timesteps=start_timesteps,
                post_fun=post_fun,
                model_kwargs_ir=model_kwargs_ir,
                need_restoration=need_restoration,
            )[0]

        assert not self.configs.aligned, 'sample_func_bfr_unaligned requires configs.aligned to be false'
        if bs < 1:
            raise ValueError(f'bs must be a positive integer, got {bs}')

        self.face_helper.clean_all()
        self.face_helper.read_image(y0)
        self.face_helper.get_face_landmarks_5(
            only_center_face=only_center_face,
            resize=640,
            eye_dist_threshold=5,
        )
        # align and warp each face
        self.face_helper.align_warp_face()

        # One loop covers zero faces (no iteration), a single batch and several
        # batches; slicing already truncates the last batch.
        all_cropped = self.face_helper.cropped_faces
        restored_faces = []
        for idx_start in range(0, len(all_cropped), bs):
            current_restored = _process_batch(all_cropped[idx_start:idx_start + bs])
            restored_faces.extend(
                util_image.tensor2img(
                    list(current_restored.split(1, dim=0)),
                    rgb2bgr=True,
                    min_max=(0, 1),
                    out_type=np.uint8,
                )
            )
        for face in restored_faces:
            self.face_helper.add_restored_face(face)

        # paste_back
        if self.configs.background_enhance:
            bg_img = self.bg_model.enhance(y0, outscale=self.configs.detection.upscale)[0]
        else:
            bg_img = None
        self.face_helper.get_inverse_affine(None)

        # paste each restored face to the input image
        paste_kwargs = dict(upsample_img=bg_img, draw_box=draw_box)
        if self.configs.face_upsample:
            paste_kwargs['face_upsampler'] = self.bg_model
        restored_img = self.face_helper.paste_faces_to_input_image(**paste_kwargs)

        cropped_faces = self.face_helper.cropped_faces
        return restored_img, restored_faces, cropped_faces
if __name__ == '__main__':
    import argparse

    # Command-line options as (flag, type, default, help), registered in order.
    cli_options = (
        ("--save_dir", str, "./save_dir", "Folder to save the checkpoints and training log"),
        ("--gpu_id", str, '', "GPU Index, e.g., 025"),
        ("--cfg_path", str, './configs/sample/iddpm_ffhq256.yaml', "Path of config files"),
        ("--bs", int, 32, "Batch size"),
        ("--num_images", int, 3000, "Number of sampled images"),
        ("--timestep_respacing", str, '1000', "Sampling steps for accelerate"),
    )
    parser = argparse.ArgumentParser()
    for flag, arg_type, default_value, help_text in cli_options:
        parser.add_argument(flag, type=arg_type, default=default_value, help=help_text)
    args = parser.parse_args()

    # Load the yaml config and override it with the command-line values.
    configs = OmegaConf.load(args.cfg_path)
    configs.gpu_id = args.gpu_id
    configs.diffusion.params.timestep_respacing = args.timestep_respacing

    # NOTE(review): DiffusionSampler is neither defined nor imported in the
    # visible part of this file -- confirm where it is supposed to come from.
    sampler_dist = DiffusionSampler(configs)
    sampler_dist.sample_func(
        bs=args.bs,
        num_images=args.num_images,
        save_dir=args.save_dir,
    )