|
|
|
|
|
|
|
|
|
import os, sys, math, random |
|
|
|
import cv2 |
|
import numpy as np |
|
from pathlib import Path |
|
from loguru import logger |
|
from omegaconf import OmegaConf |
|
from contextlib import nullcontext |
|
|
|
from utils import util_net |
|
from utils import util_image |
|
from utils import util_common |
|
|
|
import torch |
|
import torch.nn.functional as F |
|
import torch.distributed as dist |
|
import torch.multiprocessing as mp |
|
|
|
from datapipe.datasets import create_dataset |
|
from utils.util_image import ImageSpliterTh |
|
|
|
class BaseSampler:
    '''
    Shared set-up for the diffusion samplers.

    Construction stores the sampling options, records the GPU count and the
    rank of this process, seeds the random number generators and loads the
    diffusion network (plus the optional autoencoder) onto the GPU.
    '''

    def __init__(
            self,
            configs,
            sf=4,
            use_amp=True,
            chop_size=128,
            chop_stride=128,
            chop_bs=1,
            padding_offset=16,
            seed=10000,
            ):
        '''
        Input:
            configs: config, see the yaml file in folder ./configs/
            sf: int, super-resolution scale
            use_amp: bool, stored as self.use_amp (samplers use it to decide
                whether to run under CUDA autocast)
            chop_size: int, stored as self.chop_size (patch size used when an
                input is processed patch by patch)
            chop_stride: int, stored as self.chop_stride (stride between patches)
            chop_bs: int, stored as self.chop_bs (forwarded to the patch splitter)
            padding_offset: int, stored as self.padding_offset (inputs are padded
                so that height and width become multiples of it)
            seed: int, random seed
        '''
        self.configs = configs
        self.sf = sf
        self.chop_size = chop_size
        self.chop_stride = chop_stride
        self.chop_bs = chop_bs
        self.seed = seed
        self.use_amp = use_amp
        self.padding_offset = padding_offset

        # Order matters: build_model() logs through write_log() and loads
        # checkpoints with map_location derived from self.rank, both of which
        # need the attributes set by setup_dist().
        self.setup_dist()

        self.setup_seed()

        self.build_model()

    def setup_seed(self, seed=None):
        '''
        Seed Python's `random`, NumPy and PyTorch (CPU and all CUDA devices).

        Input:
            seed: int or None, None falls back to self.seed
        '''
        seed = self.seed if seed is None else seed
        random.seed(seed)
        np.random.seed(seed)
        torch.manual_seed(seed)
        torch.cuda.manual_seed_all(seed)

    def setup_dist(self, gpu_id=None):
        '''
        Record the number of visible GPUs (self.num_gpus) and the rank of this
        process (self.rank). No process group is initialised here.

        Input:
            gpu_id: unused, kept for interface compatibility
        '''
        num_gpus = torch.cuda.device_count()

        if num_gpus > 1:
            # Every process is pinned to the first device.
            rank = 0
            torch.cuda.set_device(rank)

        self.num_gpus = num_gpus
        # Debug output; the Korean prefix reads "added by me".
        print("내가 추가한 거 num_gpus: ", num_gpus)
        print("내가 추가한 거 self.num_gpus: ", self.num_gpus)

        # LOCAL_RANK is only exported by distributed launchers such as
        # torchrun. Fall back to rank 0 so that a plain `python ...` run on a
        # multi-GPU host does not die with a KeyError.
        self.rank = int(os.environ.get('LOCAL_RANK', 0)) if num_gpus > 1 else 0
        print("내가 추가한 거 self.rank: ", self.rank)

    def write_log(self, log_str):
        '''
        Print `log_str` (flushed immediately) on rank 0 only.
        '''
        if self.rank == 0:
            print(log_str, flush=True)

    def build_model(self):
        '''
        Instantiate the diffusion process, the denoising network and, when
        `configs.autoencoder` is set, the autoencoder; load the checkpoints
        and keep the networks in eval mode on the GPU.

        Sets self.base_diffusion, self.model and self.autoencoder (None when
        no autoencoder is configured).
        '''
        # diffusion model
        log_str = f'Building the diffusion model with length: {self.configs.diffusion.params.steps}...'
        self.write_log(log_str)
        self.base_diffusion = util_common.instantiate_from_config(self.configs.diffusion)
        model = util_common.instantiate_from_config(self.configs.model).cuda()
        ckpt_path = self.configs.model.ckpt_path
        assert ckpt_path is not None, 'configs.model.ckpt_path must be set'
        self.write_log(f'Loading Diffusion model from {ckpt_path}...')
        self.load_model(model, ckpt_path)
        self.freeze_model(model)
        self.model = model.eval()

        # autoencoder model (optional)
        if self.configs.autoencoder is not None:
            ckpt_path = self.configs.autoencoder.ckpt_path
            assert ckpt_path is not None, 'configs.autoencoder.ckpt_path must be set'
            self.write_log(f'Loading AutoEncoder model from {ckpt_path}...')
            autoencoder = util_common.instantiate_from_config(self.configs.autoencoder).cuda()
            self.load_model(autoencoder, ckpt_path)
            autoencoder.eval()
            self.autoencoder = autoencoder
        else:
            self.autoencoder = None

    def load_model(self, model, ckpt_path=None):
        '''
        Load the checkpoint at `ckpt_path` into `model`.

        Input:
            model: network that receives the weights
            ckpt_path: path of the checkpoint; when the loaded object has a
                'state_dict' entry, that entry is used as the weights
        '''
        # NOTE: torch.load unpickles the file; only load checkpoints that come
        # from a trusted source.
        state = torch.load(ckpt_path, map_location=f"cuda:{self.rank}")
        if 'state_dict' in state:
            state = state['state_dict']
        util_net.reload_model(model, state)

    def freeze_model(self, net):
        '''
        Disable gradient tracking for every parameter of `net`.
        '''
        for params in net.parameters():
            params.requires_grad = False
|
|
|
class ResShiftSampler(BaseSampler):
    '''
    Sampler that restores low-quality images with the configured diffusion
    model, either on single tensors (`sample_func`) or on image files and
    folders (`inference`).
    '''

    def sample_func(self, y0, noise_repeat=False, mask=False):
        '''
        Run the reverse diffusion loop on one batch.

        Input:
            y0: n x c x h x w torch tensor, low-quality image, [-1, 1], RGB
            noise_repeat: bool, re-seed the RNGs with self.seed before sampling;
                also forwarded to the diffusion loop
            mask: image mask for inpainting; None (or the legacy default False)
                means that no mask is used
        Output:
            sample: n x c x h x w, torch tensor, [-1, 1], RGB
        '''
        if noise_repeat:
            self.setup_seed()

        # The signature default is False, but the check below is
        # `mask is not None`. Normalise so the default is not forwarded to the
        # model as if it were a mask.
        if mask is False:
            mask = None

        # Pad bottom/right so both spatial sizes are multiples of padding_offset.
        # NOTE(review): the mask is not padded alongside y0 - confirm whether
        # the model expects matching spatial sizes in that case.
        offset = self.padding_offset
        ori_h, ori_w = y0.shape[2:]
        pad_h = math.ceil(ori_h / offset) * offset - ori_h
        pad_w = math.ceil(ori_w / offset) * offset - ori_w
        flag_pad = pad_h > 0 or pad_w > 0
        if flag_pad:
            y0 = F.pad(y0, pad=(0, pad_w, 0, pad_h), mode='reflect')

        if self.configs.model.params.cond_lq:
            model_kwargs = {'lq': y0}
            if mask is not None:
                model_kwargs['mask'] = mask
        else:
            model_kwargs = None

        results = self.base_diffusion.p_sample_loop(
                y=y0,
                model=self.model,
                first_stage_model=self.autoencoder,
                noise=None,
                noise_repeat=noise_repeat,
                clip_denoised=(self.autoencoder is None),
                denoised_fn=None,
                model_kwargs=model_kwargs,
                progress=False,
                )    # This has included the decoding for latent space

        if flag_pad:
            # Drop the region that corresponds to the padding.
            results = results[:, :, :ori_h*self.sf, :ori_w*self.sf]

        return results.clamp_(-1.0, 1.0)

    def inference(self, in_path, out_path, mask_path=None, mask_back=True, bs=1, noise_repeat=False):
        '''
        Inference demo.
        Input:
            in_path: str, folder or image path for LQ image
            out_path: str, folder save the results
            mask_path: image mask for inpainting
            mask_back: bool, when a mask is given, keep the input pixels outside
                the masked area and only paste the restored pixels inside it
            bs: int, default bs=1, bs % num_gpus == 0
            noise_repeat: bool, forwarded to sample_func
        '''
        # This file never initialises a process group itself, so only
        # synchronise / shard across ranks when one really exists.
        dist_ready = self.num_gpus > 1 and dist.is_available() and dist.is_initialized()

        def _sync():
            # Barrier across ranks; no-op for a stand-alone process.
            if dist_ready:
                dist.barrier()

        def _process_per_image(im_lq_tensor, mask=None):
            '''
            Input:
                im_lq_tensor: b x c x h x w, torch tensor, [-1, 1], RGB
                mask: image mask for inpainting, [-1, 1], 1 for unknown area
            Output:
                im_sr: b x c x h x w, torch tensor, [0, 1], RGB
            '''
            context = torch.cuda.amp.autocast if self.use_amp else nullcontext
            if im_lq_tensor.shape[2] > self.chop_size or im_lq_tensor.shape[3] > self.chop_size:
                # Carry the mask through the splitter as an extra last channel.
                # A separate variable is used so im_lq_tensor keeps its original
                # channels for the mask_back blend below.
                if mask is not None:
                    chop_input = torch.cat([im_lq_tensor, mask], dim=1)
                else:
                    chop_input = im_lq_tensor
                im_spliter = ImageSpliterTh(
                        chop_input,
                        self.chop_size,
                        stride=self.chop_stride,
                        sf=self.sf,
                        extra_bs=self.chop_bs,
                        )
                for im_lq_pch, index_infos in im_spliter:
                    if mask is not None:
                        im_lq_pch, mask_pch = im_lq_pch[:, :-1], im_lq_pch[:, -1:]
                    else:
                        mask_pch = None
                    with context():
                        im_sr_pch = self.sample_func(
                                im_lq_pch,
                                noise_repeat=noise_repeat,
                                mask=mask_pch,
                                )     # 1 x c x h x w, [-1, 1]
                    im_spliter.update(im_sr_pch, index_infos)
                im_sr_tensor = im_spliter.gather()
            else:
                with context():
                    im_sr_tensor = self.sample_func(
                            im_lq_tensor,
                            noise_repeat=noise_repeat,
                            mask=mask,
                            )     # 1 x c x h x w, [-1, 1]

            # [-1, 1] -> [0, 1]
            im_sr_tensor = im_sr_tensor * 0.5 + 0.5
            if mask_back and mask is not None:
                mask = mask * 0.5 + 0.5
                im_lq_tensor = im_lq_tensor * 0.5 + 0.5
                im_sr_tensor = im_sr_tensor * mask + im_lq_tensor * (1 - mask)
            return im_sr_tensor

        in_path = Path(in_path)
        out_path = Path(out_path)

        if self.rank == 0:
            assert in_path.exists(), f'Input path does not exist: {in_path}'
            out_path.mkdir(parents=True, exist_ok=True)

        _sync()

        if in_path.is_dir():
            if mask_path is None:
                data_config = {'type': 'base',
                               'params': {'dir_path': str(in_path),
                                          'transform_type': 'default',
                                          'transform_kwargs': {
                                              'mean': 0.5,
                                              'std': 0.5,
                                              },
                                          'need_path': True,
                                          'recursive': True,
                                          'length': None,
                                          }
                               }
            else:
                data_config = {'type': 'inpainting_val',
                               'params': {'lq_path': str(in_path),
                                          'mask_path': mask_path,
                                          'transform_type': 'default',
                                          'transform_kwargs': {
                                              'mean': 0.5,
                                              'std': 0.5,
                                              },
                                          'need_path': True,
                                          'recursive': True,
                                          'im_exts': ['png', 'jpg', 'jpeg', 'JPEG', 'bmp', 'PNG'],
                                          'length': None,
                                          }
                               }
            dataset = create_dataset(data_config)
            self.write_log(f'Find {len(dataset)} images in {in_path}')
            dataloader = torch.utils.data.DataLoader(
                    dataset,
                    batch_size=bs,
                    shuffle=False,
                    drop_last=False,
                    )

            # Each rank handles its own contiguous slice of every batch. When
            # no process group exists there is nobody to pick up the other
            # slices, so the whole batch is one shard (rank 0 processes all of
            # it and any other process gets an empty slice).
            num_shards = max(self.num_gpus, 1) if dist_ready else 1
            micro_batchsize = math.ceil(bs / num_shards)
            ind_start = self.rank * micro_batchsize
            ind_end = ind_start + micro_batchsize

            for data in dataloader:
                micro_data = {key: value[ind_start:ind_end] for key, value in data.items()}

                if micro_data['lq'].shape[0] > 0:
                    results = _process_per_image(
                            micro_data['lq'].cuda(),
                            mask=micro_data['mask'].cuda() if 'mask' in micro_data else None,
                            )    # b x c x h x w, [0, 1], RGB

                    for jj in range(results.shape[0]):
                        im_sr = util_image.tensor2img(results[jj], rgb2bgr=True, min_max=(0.0, 1.0))
                        im_name = Path(micro_data['path'][jj]).stem
                        im_path = out_path / f"{im_name}.png"
                        util_image.imwrite(im_sr, im_path, chn='bgr', dtype_in='uint8')

                _sync()
        else:
            im_lq = util_image.imread(in_path, chn='rgb', dtype='float32')  # h x w x c
            im_lq_tensor = util_image.img2tensor(im_lq).cuda()              # 1 x c x h x w
            if mask_path is not None:
                im_mask = util_image.imread(mask_path, chn='gray', dtype='float32')[:, :, None]  # h x w x 1
                im_mask_tensor = util_image.img2tensor(im_mask).cuda()                           # 1 x 1 x h x w

            # [0, 1] -> [-1, 1] before sampling
            im_sr_tensor = _process_per_image(
                    (im_lq_tensor - 0.5) / 0.5,
                    mask=(im_mask_tensor - 0.5) / 0.5 if mask_path is not None else None,
                    )

            im_sr = util_image.tensor2img(im_sr_tensor, rgb2bgr=True, min_max=(0.0, 1.0))
            im_path = out_path / f"{in_path.stem}.png"
            util_image.imwrite(im_sr, im_path, chn='bgr', dtype_in='uint8')

        self.write_log(f"Processing done, enjoy the results in {str(out_path)}")
|
|
|
# Script entry point: intentionally a no-op, this module only provides the
# sampler classes.
if __name__ == "__main__":
    pass
|
|
|
|