import os, sys, math, random |
import cv2 |
import numpy as np |
from pathlib import Path |
from loguru import logger |
from omegaconf import OmegaConf |
from contextlib import nullcontext |
from utils import util_net |
from utils import util_image |
from utils import util_common |
import torch |
import torch.nn.functional as F |
import torch.distributed as dist |
import torch.multiprocessing as mp |
from datapipe.datasets import create_dataset |
from utils.util_image import ImageSpliterTh |
class BaseSampler:
    '''Shared setup for samplers.

    On construction it configures the device / optional distributed context,
    seeds every RNG and builds the diffusion model plus the optional
    autoencoder described by ``configs``. Subclasses implement sampling.
    '''

    def __init__(
        self,
        configs,
        sf=4,
        use_amp=True,
        chop_size=128,
        chop_stride=128,
        chop_bs=1,
        padding_offset=16,
        seed=10000,
    ):
        '''
        Input:
            configs: config, see the yaml file in folder ./configs/
            sf: int, super-resolution scale
            use_amp: bool, stored as ``self.use_amp`` (mixed-precision flag for subclasses)
            chop_size: int, stored for tiled inference in subclasses
            chop_stride: int, stored for tiled inference in subclasses
            chop_bs: int, stored for tiled inference in subclasses
            padding_offset: int, stored for input padding in subclasses
            seed: int, random seed
        '''
        self.configs = configs
        self.sf = sf
        self.chop_size = chop_size
        self.chop_stride = chop_stride
        self.chop_bs = chop_bs
        self.seed = seed
        self.use_amp = use_amp
        self.padding_offset = padding_offset

        self.setup_dist()
        self.setup_seed()
        self.build_model()

    def setup_seed(self, seed=None):
        '''Seed Python, NumPy and torch (CPU and all CUDA devices).

        seed: int or None; falls back to ``self.seed`` when None.
        '''
        seed = self.seed if seed is None else seed
        random.seed(seed)
        np.random.seed(seed)
        torch.manual_seed(seed)
        torch.cuda.manual_seed_all(seed)

    def setup_dist(self, gpu_id=None):
        '''Configure the CUDA device and the (optional) distributed context.

        A run counts as distributed only when more than one CUDA device is
        visible AND a launcher such as ``torchrun`` exported ``LOCAL_RANK``.
        In that case this process binds to ``LOCAL_RANK % num_gpus`` and joins
        the default process group, initialising it if nobody has done so yet.
        Otherwise the run is a single process with rank 0.

        Sets:
            self.rank: int, local rank of this process (0 when not distributed).
                NOTE(review): assumes a single-node run.
            self.num_gpus: int, number of cooperating processes (the world
                size, 1 when not distributed). ``inference`` splits each batch
                across this many ranks and only calls ``dist.barrier`` when > 1.

        gpu_id: unused; kept for backward compatibility.
        '''
        num_gpus = torch.cuda.device_count()
        local_rank = os.environ.get('LOCAL_RANK')
        if num_gpus > 1 and local_rank is not None:
            rank = int(local_rank)
            torch.cuda.set_device(rank % num_gpus)
            if not dist.is_initialized():
                # env:// reads MASTER_ADDR / MASTER_PORT / RANK / WORLD_SIZE set by the launcher.
                dist.init_process_group(backend='nccl', init_method='env://')
            world_size = dist.get_world_size()
        else:
            # Single process: every image of every batch is handled here, and no
            # collective ops (barriers) may be issued.
            rank = 0
            world_size = 1
            if num_gpus > 1:
                torch.cuda.set_device(0)
        self.num_gpus = world_size
        # Debug output from a local modification (the label reads "added by me").
        print("내가 추가한 거 num_gpus: ", num_gpus)
        print("내가 추가한 거 self.num_gpus: ", self.num_gpus)
        self.rank = rank
        print("내가 추가한 거 self.rank: ", self.rank)

    def write_log(self, log_str):
        '''Print ``log_str`` (flushed) from rank 0 only.'''
        if self.rank == 0:
            print(log_str, flush=True)

    def build_model(self):
        '''Instantiate the diffusion process, the denoising model and the optional autoencoder.

        Both networks are moved to the current CUDA device, loaded from their
        ``ckpt_path``, frozen and put in eval mode.
        '''
        log_str = f'Building the diffusion model with length: {self.configs.diffusion.params.steps}...'
        self.write_log(log_str)
        self.base_diffusion = util_common.instantiate_from_config(self.configs.diffusion)

        model = util_common.instantiate_from_config(self.configs.model).cuda()
        ckpt_path = self.configs.model.ckpt_path
        assert ckpt_path is not None
        self.write_log(f'Loading Diffusion model from {ckpt_path}...')
        self.load_model(model, ckpt_path)
        self.freeze_model(model)
        self.model = model.eval()

        if self.configs.autoencoder is not None:
            ckpt_path = self.configs.autoencoder.ckpt_path
            assert ckpt_path is not None
            self.write_log(f'Loading AutoEncoder model from {ckpt_path}...')
            autoencoder = util_common.instantiate_from_config(self.configs.autoencoder).cuda()
            self.load_model(autoencoder, ckpt_path)
            # Inference only: freeze it like the diffusion model so no autograd
            # state is tracked through the encoder/decoder.
            self.freeze_model(autoencoder)
            autoencoder.eval()
            self.autoencoder = autoencoder
        else:
            self.autoencoder = None

    def load_model(self, model, ckpt_path=None):
        '''Load the checkpoint at ``ckpt_path`` into ``model`` in place.

        A top-level ``'state_dict'`` entry is unwrapped when present.
        NOTE: ``torch.load`` unpickles the file; only load trusted checkpoints.
        '''
        state = torch.load(ckpt_path, map_location=f"cuda:{self.rank}")
        if 'state_dict' in state:
            state = state['state_dict']
        util_net.reload_model(model, state)

    def freeze_model(self, net):
        '''Disable gradients for every parameter of ``net``.'''
        for params in net.parameters():
            params.requires_grad = False
class ResShiftSampler(BaseSampler):
    '''Sampler that runs ``base_diffusion.p_sample_loop`` on low-quality images.'''

    def sample_func(self, y0, noise_repeat=False, mask=False):
        '''
        Run the reverse diffusion process on one batch.

        Input:
            y0: n x c x h x w torch tensor, low-quality image, [-1, 1], RGB
            noise_repeat: bool, re-seed all RNGs with ``self.seed`` before sampling
            mask: image mask for inpainting (presumably n x 1 x h x w, [-1, 1]),
                  or None / False for no mask. Only forwarded to the model when
                  ``configs.model.params.cond_lq`` is set.
        Output:
            sample: n x c x (h*sf) x (w*sf), torch tensor, [-1, 1], RGB
        '''
        if noise_repeat:
            self.setup_seed()
        if mask is False:
            # ``False`` is the historical default and means "no mask"; normalise
            # it so it is never forwarded to the model as ``mask=False``.
            mask = None

        # Pad bottom/right so both spatial dims become multiples of padding_offset.
        offset = self.padding_offset
        ori_h, ori_w = y0.shape[2:]
        pad_h = -ori_h % offset
        pad_w = -ori_w % offset
        flag_pad = pad_h > 0 or pad_w > 0
        if flag_pad:
            padding = (0, pad_w, 0, pad_h)
            y0 = F.pad(y0, pad=padding, mode='reflect')
            if mask is not None:
                # The mask must stay spatially aligned with the padded input.
                mask = F.pad(mask, pad=padding, mode='reflect')

        if self.configs.model.params.cond_lq:
            model_kwargs = {'lq': y0}
            if mask is not None:
                model_kwargs['mask'] = mask
        else:
            model_kwargs = None

        results = self.base_diffusion.p_sample_loop(
            y=y0,
            model=self.model,
            first_stage_model=self.autoencoder,
            noise=None,
            noise_repeat=noise_repeat,
            clip_denoised=(self.autoencoder is None),
            denoised_fn=None,
            model_kwargs=model_kwargs,
            progress=False,
        )
        if flag_pad:
            # Drop the (upscaled) padding again.
            results = results[:, :, :ori_h * self.sf, :ori_w * self.sf]
        return results.clamp_(-1.0, 1.0)

    def inference(self, in_path, out_path, mask_path=None, mask_back=True, bs=1, noise_repeat=False):
        '''
        Inference demo: restore one image or every image of a folder and write PNGs.

        Input:
            in_path: str or Path, folder or image path for LQ image
            out_path: str or Path, folder to save the results (created by rank 0 if missing)
            mask_path: image mask for inpainting (presumably a folder when in_path
                       is a folder, an image path otherwise)
            mask_back: bool, when a mask is given, keep the input pixels wherever
                       the mask marks a known area (mask == 0 after rescaling to [0, 1])
            bs: int, default bs=1, bs % num_gpus == 0 recommended
            noise_repeat: bool, forwarded to ``sample_func``
        '''
        def _process_per_image(im_lq_tensor, mask=None):
            '''
            Input:
                im_lq_tensor: b x c x h x w, torch tensor, [-1, 1], RGB
                mask: image mask for inpainting, [-1, 1], 1 for unknown area
            Output:
                im_sr: b x c x h' x w', torch tensor, [0, 1], RGB
            '''
            context = torch.cuda.amp.autocast if self.use_amp else nullcontext
            # Pure inference: never build an autograd graph.
            with torch.no_grad():
                if im_lq_tensor.shape[2] > self.chop_size or im_lq_tensor.shape[3] > self.chop_size:
                    # Tile large inputs. The mask (if any) rides along as an extra
                    # channel so it is split with exactly the same geometry.
                    # A separate name keeps ``im_lq_tensor`` as the plain image,
                    # which the mask_back blend below relies on.
                    if mask is None:
                        spliter_input = im_lq_tensor
                    else:
                        spliter_input = torch.cat([im_lq_tensor, mask], dim=1)
                    im_spliter = ImageSpliterTh(
                        spliter_input,
                        self.chop_size,
                        stride=self.chop_stride,
                        sf=self.sf,
                        extra_bs=self.chop_bs,
                    )
                    for im_pch, index_infos in im_spliter:
                        if mask is not None:
                            im_lq_pch, mask_pch = im_pch[:, :-1], im_pch[:, -1:]
                        else:
                            im_lq_pch, mask_pch = im_pch, None
                        with context():
                            im_sr_pch = self.sample_func(
                                im_lq_pch,
                                noise_repeat=noise_repeat,
                                mask=mask_pch,
                            )
                        im_spliter.update(im_sr_pch, index_infos)
                    im_sr_tensor = im_spliter.gather()
                else:
                    with context():
                        im_sr_tensor = self.sample_func(
                            im_lq_tensor,
                            noise_repeat=noise_repeat,
                            mask=mask,
                        )

            im_sr_tensor = im_sr_tensor * 0.5 + 0.5
            if mask_back and mask is not None:
                mask = mask * 0.5 + 0.5
                im_lq_tensor = im_lq_tensor * 0.5 + 0.5
                im_sr_tensor = im_sr_tensor * mask + im_lq_tensor * (1 - mask)
            return im_sr_tensor

        in_path = Path(in_path)
        out_path = Path(out_path)

        if self.rank == 0:
            assert in_path.exists()
            out_path.mkdir(parents=True, exist_ok=True)
        # Every rank must reach the barrier, so it sits outside the rank-0 branch.
        if self.num_gpus > 1:
            dist.barrier()

        if in_path.is_dir():
            if mask_path is None:
                data_config = {'type': 'base',
                               'params': {'dir_path': str(in_path),
                                          'transform_type': 'default',
                                          'transform_kwargs': {
                                              'mean': 0.5,
                                              'std': 0.5,
                                          },
                                          'need_path': True,
                                          'recursive': True,
                                          'length': None,
                                          }
                               }
            else:
                data_config = {'type': 'inpainting_val',
                               'params': {'lq_path': str(in_path),
                                          'mask_path': mask_path,
                                          'transform_type': 'default',
                                          'transform_kwargs': {
                                              'mean': 0.5,
                                              'std': 0.5,
                                          },
                                          'need_path': True,
                                          'recursive': True,
                                          'im_exts': ['png', 'jpg', 'jpeg', 'JPEG', 'bmp', 'PNG'],
                                          'length': None,
                                          }
                               }
            dataset = create_dataset(data_config)
            self.write_log(f'Find {len(dataset)} images in {in_path}')
            dataloader = torch.utils.data.DataLoader(
                dataset,
                batch_size=bs,
                shuffle=False,
                drop_last=False,
            )
            # Each rank handles its own contiguous slice of every batch.
            micro_batchsize = math.ceil(bs / self.num_gpus)
            ind_start = self.rank * micro_batchsize
            ind_end = ind_start + micro_batchsize
            for data in dataloader:
                micro_data = {key: value[ind_start:ind_end] for key, value in data.items()}
                if micro_data['lq'].shape[0] == 0:
                    continue
                results = _process_per_image(
                    micro_data['lq'].cuda(),
                    mask=micro_data['mask'].cuda() if 'mask' in micro_data else None,
                )
                for im_sr_one, src_path in zip(results, micro_data['path']):
                    im_sr = util_image.tensor2img(im_sr_one, rgb2bgr=True, min_max=(0.0, 1.0))
                    im_name = Path(src_path).stem
                    im_path = out_path / f"{im_name}.png"
                    util_image.imwrite(im_sr, im_path, chn='bgr', dtype_in='uint8')
            if self.num_gpus > 1:
                dist.barrier()
        else:
            im_lq = util_image.imread(in_path, chn='rgb', dtype='float32')
            im_lq_tensor = util_image.img2tensor(im_lq).cuda()
            if mask_path is not None:
                im_mask = util_image.imread(mask_path, chn='gray', dtype='float32')[:, :, None]
                im_mask_tensor = util_image.img2tensor(im_mask).cuda()
                mask_tensor = (im_mask_tensor - 0.5) / 0.5
            else:
                mask_tensor = None
            im_sr_tensor = _process_per_image(
                (im_lq_tensor - 0.5) / 0.5,
                mask=mask_tensor,
            )
            im_sr = util_image.tensor2img(im_sr_tensor, rgb2bgr=True, min_max=(0.0, 1.0))
            im_path = out_path / f"{in_path.stem}.png"
            util_image.imwrite(im_sr, im_path, chn='bgr', dtype_in='uint8')

        self.write_log(f"Processing done, enjoy the results in {str(out_path)}")
if __name__ == '__main__':
    # Nothing to run when executed directly; this module only defines samplers.
    ...