from typing import Any, List import deepinv as dinv import numpy as np import torch from deepinv.physics.generator import MotionBlurGenerator, SigmaGenerator from torchvision import transforms from datasets import Preprocessed_fastMRI, Preprocessed_LIDCIDRI, LsdirMiniDataset from utils import get_model DEFAULT_MODEL_PARAMS = { "in_channels": [1, 2, 3], "grayscale": False, "conv_type": "base", "pool_type": "base", "layer_scale_init_value": 1e-6, "init_type": "ortho", "gain_init_conv": 1.0, "gain_init_linear": 1.0, "drop_prob": 0.0, "replk": False, "mult_fact": 4, "antialias": "gaussian", "nc_base": 64, "cond_type": "base", "blind": False, "pretrained_pth": None, "N": 2, "c_mult": 2, "depth_encoding": 2, "relu_in_encoding": False, "skip_in_encoding": True } class PhysicsWithGenerator(torch.nn.Module): """Interface between Physics, Generator and Gradio.""" all_physics = ["MotionBlur_easy", "MotionBlur_medium", "MotionBlur_hard", "GaussianBlur", "MRI", "CT"] def __init__(self, physics_name: str, device_str: str = "cpu") -> None: super().__init__() self.name = physics_name if self.name not in self.all_physics: raise ValueError(f"{self.name} is unavailable.") self.sigma_generator = SigmaGenerator(sigma_min=0.001, sigma_max=0.2, device=device_str) if self.name == "MotionBlur_easy": psf_size = 31 self.physics = dinv.physics.Blur(noise_model=dinv.physics.GaussianNoise(sigma=.01), padding="valid", device=device_str) self.physics_generator = MotionBlurGenerator((psf_size, psf_size), l=0.1, sigma=0.1, device=device_str) + SigmaGenerator(sigma_min=0.01, sigma_max=0.01, device=device_str) self.generator = self.physics_generator + SigmaGenerator(sigma_min=0.01, sigma_max=0.01, device=device_str) self.saved_params = {"updatable_params": {"sigma": 0.05}, "updatable_params_converter": {"sigma": float}, "fixed_params": {"noise_sigma_min": 0.01, "noise_sigma_max": 0.01, "psf_size": 31, "motion_gen_l": 0.1, "motion_gen_s": 0.1}} elif self.name == "MotionBlur_medium": psf_size = 31 self.physics = dinv.physics.Blur(noise_model=dinv.physics.GaussianNoise(sigma=.05), padding="valid", device=device_str) self.physics_generator = MotionBlurGenerator((psf_size, psf_size), l=0.6, sigma=0.5, device=device_str) + SigmaGenerator(sigma_min=0.05, sigma_max=0.05, device=device_str) self.generator = self.physics_generator + SigmaGenerator(sigma_min=0.05, sigma_max=0.05, device=device_str) self.saved_params = {"updatable_params": {"sigma": 0.05}, "updatable_params_converter": {"sigma": float}, "fixed_params": {"noise_sigma_min": 0.05, "noise_sigma_max": 0.05, "psf_size": 31, "motion_gen_l": 0.6, "motion_gen_s": 0.5}} elif self.name == "MotionBlur_hard": psf_size = 31 self.physics = dinv.physics.Blur(noise_model=dinv.physics.GaussianNoise(sigma=.1), padding="valid", device=device_str) self.physics_generator = MotionBlurGenerator((psf_size, psf_size), l=1.2, sigma=1.0, device=device_str) + SigmaGenerator(sigma_min=0.1, sigma_max=0.1, device=device_str) self.generator = self.physics_generator + SigmaGenerator(sigma_min=0.1, sigma_max=0.1, device=device_str) self.saved_params = {"updatable_params": {"sigma": 0.05}, "updatable_params_converter": {"sigma": float}, "fixed_params": {"noise_sigma_min": 0.1, "noise_sigma_max": 0.1, "psf_size": 31, "motion_gen_l": 1.2, "motion_gen_s": 1.0}} elif self.name == "GaussianBlur": psf_size = 31 self.physics = dinv.physics.Blur(noise_model=dinv.physics.GaussianNoise(sigma=0.05), padding="valid", device=device_str) self.physics_generator = GaussianBlurGenerator(psf_size=(psf_size, psf_size), num_channels=1, device=device_str) self.generator = self.physics_generator + self.sigma_generator self.saved_params = {"updatable_params": {"sigma": 0.05}, "updatable_params_converter": {"sigma": float}, "fixed_params": {"noise_sigma_min": 0.001, "noise_sigma_max": 0.2, "psf_size": 31, "num_channels": 1}} elif self.name == "MRI": self.physics = dinv.physics.MRI(img_size=(640, 320), noise_model=dinv.physics.GaussianNoise(sigma=.01), device=device_str) self.physics_generator = dinv.physics.generator.RandomMaskGenerator((2, 640, 320), acceleration_factor=4) self.generator = self.physics_generator # + self.sigma_generator self.saved_params = {"updatable_params": {"sigma": 0.05}, "updatable_params_converter": {"sigma": float}, "fixed_params": {"noise_sigma_min": 0.001, "noise_sigma_max": 0.2, "acceleration_factor": 4}} elif self.name == "CT": acceleration_factor = 10 img_h = 480 angles = int(img_h / acceleration_factor) # angles = torch.linspace(0, 180, steps=10) self.physics = dinv.physics.Tomography( img_width=img_h, angles=angles, circle=False, normalize=True, device=device_str, noise_model=dinv.physics.GaussianNoise(sigma=1e-4), max_iter=10, ) self.physics_generator = None self.generator = self.sigma_generator self.saved_params = {"updatable_params": {"sigma": 0.1}, "updatable_params_converter": {"sigma": float}, "fixed_params": {"noise_sigma_min": 0.001, "noise_sigma_max": 0., "angles": angles, "max_iter": 10}} def display_saved_params(self) -> str: """Printable version of saved_params.""" updatable_params_str = "Updatable parameters:\n" for param_name, param_value in self.saved_params["updatable_params"].items(): updatable_params_str += f"\t\t{param_name} = {param_value}" + "\n" fixed_params_str = "Fixed parameters:\n" for param_name, param_value in self.saved_params["fixed_params"].items(): fixed_params_str += f"\t\t{param_name} = {param_value}" + "\n" return updatable_params_str + fixed_params_str def _update_save_params(self, key: str, value: Any) -> None: """Update value of an existing key in save_params.""" if key in list(self.saved_params["updatable_params"].keys()): if type(value) == str: # it may be only a str representation # type: str -> ??? value = self.saved_params["updatable_params_converter"][key](value) elif isinstance(value, torch.Tensor): value = value.item() # type: torch.Tensor -> float value = float(f"{value:.4f}") # keeps only 4 significant digits self.saved_params["updatable_params"][key] = value def update_and_display_params(self, key, value) -> str: """_update_save_params + update physics with saved_params + display_saved_params""" self._update_save_params(key, value) if self.name == "Denoising": self.physics.noise_model.update_parameters(**self.saved_params["updatable_params"]) else: self.physics.update_parameters(**self.saved_params["updatable_params"]) return self.display_saved_params() def update_saved_params_and_physics(self, **kwargs) -> None: """Update save_params and update physics.""" for key, value in kwargs.items(): self._update_save_params(key, value) self.physics.update(**kwargs) def forward(self, x: torch.Tensor, use_gen: bool) -> torch.Tensor: if self.name in ["MotionBlur_easy", "MotionBlur_medium", "MotionBlur_hard", "GaussianBlur"] and not hasattr(self.physics, "filter"): use_gen = True elif self.name in ["MRI"] and not hasattr(self.physics, "mask"): use_gen = True if use_gen: kwargs = self.generator.step(batch_size=x.shape[0]) # generate a set of params for each sample self.update_saved_params_and_physics(**kwargs) return self.physics(x) class EvalModel(torch.nn.Module): """Eval model. Is there a difference with BaselineModel ? -> BaselineModel should be models that are already trained and will have fixed weights. -> Eval model will change depending on differents checkpoints. """ all_models = ["unext_emb_physics_config_C"] def __init__(self, model_name: str, ckpt_pth: str = "", device_str: str = "cpu") -> None: """Load the model we want to evaluate.""" super().__init__() self.base_name = model_name self.ckpt_pth = ckpt_pth self.name = self.base_name if self.base_name not in self.all_models: raise ValueError(f"{self.base_name} is unavailable.") if self.base_name == "unext_emb_physics_config_C": if self.ckpt_pth == "": self.ckpt_pth = "ckpt/ram_ckp_10.pth.tar" self.model = get_model(model_name=self.base_name, device='cpu', **DEFAULT_MODEL_PARAMS) # load model checkpoint state_dict = torch.load(self.ckpt_pth, map_location=lambda storage, loc: storage)[ 'state_dict'] # load on cpu self.model.load_state_dict(state_dict) self.model.to(device_str) self.model.eval() # add epoch in the model name epoch = torch.load(self.ckpt_pth, map_location=lambda storage, loc: storage)['epoch'] self.name = self.name + f"+{epoch}" def forward(self, y: torch.Tensor, physics: torch.nn.Module) -> torch.Tensor: return self.model(y, physics=physics) class BaselineModel(torch.nn.Module): """Baseline model. Is there a difference with EvalModel ? -> BaselineModel should be models that are already trained and will have fixed weights. -> Eval model will change depending on differents checkpoints. """ all_baselines = ["DRUNET", "PnP-PGD-DRUNET", "SWINIRx2", "SWINIRx4", "DPIR", "DPIR_MRI", "DPIR_CT", "PDNET"] def __init__(self, model_name: str, device_str: str = "cpu") -> None: super().__init__() self.base_name = model_name self.ckpt_pth = "" self.name = self.base_name if self.name not in self.all_baselines: raise ValueError(f"{self.name} is unavailable.") elif self.name == "DRUNET": n_channels = 3 ckpt_pth = "ckpt/drunet_deepinv_color_finetune_22k.pth" self.model = dinv.models.DRUNet(in_channels=n_channels, out_channels=n_channels, device=device_str, pretrained=ckpt_pth) self.model.eval() # Set the model to evaluation mode elif self.name == 'PDNET': ckpt_pth = "ckpt/pdnet.pth.tar" self.model = get_model(model_name='pdnet', device=device_str) self.model.eval() self.model.load_state_dict(torch.load(ckpt_pth, map_location=lambda storage, loc: storage)['state_dict']) elif self.name == "SWINIRx2": n_channels = 3 scale = 2 ckpt_pth = "ckpt/001_classicalSR_DF2K_s64w8_SwinIR-M_x2.pth" upsampler = 'nearest+conv' if 'realSR' in ckpt_pth else 'pixelshuffle' self.model = dinv.models.SwinIR(upscale=scale, in_chans=n_channels, img_size=64, window_size=8, img_range=1., depths=[6, 6, 6, 6, 6, 6], embed_dim=180, num_heads=[6, 6, 6, 6, 6, 6], mlp_ratio=2, upsampler=upsampler, resi_connection='1conv', pretrained=ckpt_pth) self.model.to(device_str) self.model.eval() # Set the model to evaluation mode elif self.name == "SWINIRx4": n_channels = 3 scale = 4 ckpt_pth = "ckpt/001_classicalSR_DF2K_s64w8_SwinIR-M_x4.pth" upsampler = 'nearest+conv' if 'realSR' in ckpt_pth else 'pixelshuffle' self.model = dinv.models.SwinIR(upscale=scale, in_chans=n_channels, img_size=64, window_size=8, img_range=1., depths=[6, 6, 6, 6, 6, 6], embed_dim=180, num_heads=[6, 6, 6, 6, 6, 6], mlp_ratio=2, upsampler=upsampler, resi_connection='1conv', pretrained=ckpt_pth) self.model.to(device_str) self.model.eval() # Set the model to evaluation mode elif self.name == "PnP-PGD-DRUNET": n_channels = 3 ckpt_pth = "ckpt/drunet_deepinv_color_finetune_22k.pth" drunet = dinv.models.DRUNet(in_channels=n_channels, out_channels=n_channels, device=device_str, pretrained=ckpt_pth) drunet.eval() # Set the model to evaluation mode self.model = dinv.optim.optim_builder(iteration="PGD", prior=dinv.optim.PnP(drunet).to(device_str), data_fidelity=dinv.optim.L2(), max_iter=20, params_algo={'stepsize': 1., 'g_param': .05}) elif self.name == "DPIR": n_channels = 3 ckpt_pth = "ckpt/drunet_deepinv_color_finetune_22k.pth" drunet = dinv.models.DRUNet(in_channels=n_channels, out_channels=n_channels, device=device_str, pretrained=ckpt_pth) drunet.eval() # Set the model to evaluation mode # Specify the denoising prior self.prior = dinv.optim.prior.PnP(denoiser=drunet) elif self.name == "DPIR_MRI": class ComplexDenoiser(torch.nn.Module): def __init__(self, denoiser): super().__init__() self.denoiser = denoiser def forward(self, x, sigma): noisy_batch = torch.cat((x[:, 0:1, ...], x[:, 1:2, ...]), 0) input_min = noisy_batch.min() denoised_batch = self.denoiser(noisy_batch - input_min, sigma) denoised_batch = denoised_batch + input_min denoised = torch.cat((denoised_batch[0:1, ...], denoised_batch[1:2, ...]), 1) return denoised # Load PnP denoiser backbone n_channels = 1 ckpt_pth = "ckpt/drunet_gray.pth" drunet = dinv.models.DRUNet(in_channels=n_channels, out_channels=n_channels, device=device_str, pretrained=ckpt_pth) complex_drunet = ComplexDenoiser(drunet) complex_drunet.eval() # Specify the denoising prior self.prior = dinv.optim.prior.PnP(denoiser=complex_drunet) elif self.name == "DPIR_CT": class CTDenoiser(torch.nn.Module): def __init__(self, denoiser): super().__init__() self.denoiser = denoiser def forward(self, x, sigma): x = x - x.min() denoised = self.denoiser(x, sigma) denoised = denoised + x.min() return denoised # Load PnP denoiser backbone n_channels = 1 ckpt_pth = "ckpt/drunet_gray.pth" drunet = dinv.models.DRUNet(in_channels=n_channels, out_channels=n_channels, device=device_str, pretrained=ckpt_pth) ct_drunet = CTDenoiser(drunet) ct_drunet.eval() # Specify the denoising prior self.prior = dinv.optim.prior.PnP(denoiser=ct_drunet) def circular_roll(self, tensor, p_h, p_w): return tensor.roll(shifts=(p_h, p_w), dims=(-2, -1)) def get_DPIR_params(self, noise_level_img, max_iter=8): r""" Default parameters for the DPIR Plug-and-Play algorithm. :param float noise_level_img: Noise level of the input image. :return: tuple(list with denoiser noise level per iteration, list with stepsize per iteration, iterations). """ max_iter = 8 s1 = 49.0 / 255.0 s2 = max(noise_level_img, 0.01) sigma_denoiser = np.logspace(np.log10(s1), np.log10(s2), max_iter).astype( np.float32 ) stepsize = (sigma_denoiser / max(0.01, noise_level_img)) ** 2 lamb = 1 / 0.23 return list(sigma_denoiser), list(lamb * stepsize) def get_DPIR_MRI_params(self, noise_level_img: float, max_iter: int = 8): r""" Default parameters for the DPIR Plug-and-Play algorithm. :param float noise_level_img: Noise level of the input image. """ s1 = 49.0 / 255.0 s2 = noise_level_img sigma_denoiser = np.logspace(np.log10(s1), np.log10(s2), max_iter).astype( np.float32 ) stepsize = (sigma_denoiser / max(0.01, noise_level_img)) ** 2 lamb = 1. return lamb, list(sigma_denoiser), list(stepsize), max_iter def get_DPIR_CT_params(self, noise_level_img: float, max_iter: int = 8, lip_cons: float = 1.0): r""" Default parameters for the DPIR Plug-and-Play algorithm. :param float noise_level_img: Noise level of the input image. """ s1 = 49.0 / 255.0 * lip_cons s2 = noise_level_img sigma_denoiser = np.logspace(np.log10(s1), np.log10(s2), max_iter).astype( np.float32 ) stepsize = (sigma_denoiser / max(0.01, noise_level_img)) ** 2 # lamb = 1. return lamb, list(sigma_denoiser), list(stepsize), max_iter def forward(self, y: torch.Tensor, physics: torch.nn.Module) -> torch.Tensor: if self.name == "DRUNET": return self.model(y, sigma=physics.noise_model.sigma) elif self.name == "PnP-PGD-DRUNET": return self.model(y, physics=physics) elif self.name == "DPIR": # Set the DPIR algorithm parameters sigma_float = physics.noise_model.sigma.item() # sigma should be a single value max_iter = 8 sigma_denoiser, stepsize = self.get_DPIR_params(sigma_float, max_iter=max_iter) params_algo = {"stepsize": stepsize, "g_param": sigma_denoiser} early_stop = False # Do not stop algorithm with convergence criteria # instantiate DPIR model = dinv.optim.optim_builder( iteration="HQS", prior=self.prior, data_fidelity=dinv.optim.data_fidelity.L2(), early_stop=early_stop, max_iter=max_iter, verbose=True, params_algo=params_algo, ) return model(y, physics=physics) elif self.name == "DPIR_MRI": sigma_float = max(physics.noise_model.sigma.item(), 0.015) # sigma should be a single value lamb, sigma_denoiser, stepsize, max_iter = self.get_DPIR_MRI_params(sigma_float, max_iter=16) stepsize = [stepsize[0]] * max_iter params_algo = {"stepsize": stepsize, "g_param": sigma_denoiser, "lambda": lamb} early_stop = False # Do not stop algorithm with convergence criteria # Instantiate the algorithm class to solve the IP model = dinv.optim.optim_builder( iteration="HQS", prior=self.prior, data_fidelity=dinv.optim.data_fidelity.L2(), early_stop=early_stop, max_iter=max_iter, verbose=True, params_algo=params_algo, ) return model(y, physics=physics) elif self.name == "DPIR_CT": # Set the DPIR algorithm parameters sigma_float = physics.noise_model.sigma.item() # sigma should be a single value lip_const = physics.compute_norm(physics.A_adjoint(y)) lamb, sigma_denoiser, stepsize, max_iter = self.get_DPIR_CT_params(sigma_float, max_iter=8, lip_cons=lip_const.item()) params_algo = {"stepsize": stepsize, "g_param": sigma_denoiser, "lambda": lamb} early_stop = False # Do not stop algorithm with convergence criteria def custom_init(y, physic_op): x_init = physic_op.prox_l2(physic_op.A_adjoint(y), y, gamma=1e4) return {"est": (x_init, x_init)} # Instantiate the algorithm class to solve the IP algo = dinv.optim.optim_builder( iteration="HQS", prior=self.prior, data_fidelity=dinv.optim.data_fidelity.L2(), early_stop=early_stop, max_iter=max_iter, verbose=True, params_algo=params_algo, custom_init=custom_init ) return algo(y, physics=physics) elif self.name == 'SWINIRx4': window_size = 8 scale = 4 _, _, h_old, w_old = y.size() h_pad = (h_old // window_size + 1) * window_size - h_old w_pad = (w_old // window_size + 1) * window_size - w_old img_lq = torch.cat([y, torch.flip(y, [2])], 2)[:, :, :h_old + h_pad, :] img_lq = torch.cat([img_lq, torch.flip(img_lq, [3])], 3)[:, :, :, :w_old + w_pad] output = self.model(img_lq) output = output[..., :h_old * scale, :w_old * scale] output = self.circular_roll(output, -2, -2) # check shape of adjoint x_adj = physics.A_adjoint(y) output = output[..., :x_adj.size(-2), :x_adj.size(-1)] return output elif self.name == 'SWINIRx2': window_size = 8 scale = 2 _, _, h_old, w_old = y.size() h_pad = (h_old // window_size + 1) * window_size - h_old w_pad = (w_old // window_size + 1) * window_size - w_old img_lq = torch.cat([y, torch.flip(y, [2])], 2)[:, :, :h_old + h_pad, :] img_lq = torch.cat([img_lq, torch.flip(img_lq, [3])], 3)[:, :, :, :w_old + w_pad] output = self.model(img_lq) output = output[..., :h_old * scale, :w_old * scale] output = self.circular_roll(output, -1, -1) # check shape of adjoint x_adj = physics.A_adjoint(y) output = output[..., :x_adj.size(-2), :x_adj.size(-1)] return output else: return self.model(y) class EvalDataset(torch.utils.data.Dataset): """""" all_datasets = ["Natural", "MRI", "CT"] def __init__(self, dataset_name: str, device_str: str = "cpu") -> None: self.name = dataset_name self.device_str = device_str if self.name not in self.all_datasets: raise ValueError(f"{self.name} is unavailable.") if self.name == 'Natural': self.root = 'img_samples/LSDIR_samples' self.transform = transforms.Compose([transforms.ToTensor()]) self.dataset = LsdirMiniDataset(root=self.root, transform=self.transform) elif self.name == 'MRI': self.root = 'img_samples/FastMRI_samples' self.transform = transforms.CenterCrop((640, 320)) # , pad_if_needed=True) self.dataset = Preprocessed_fastMRI(root=self.root, transform=self.transform, preprocess=False) elif self.name == "CT": self.root = 'img_samples/LIDC_IDRI_samples' self.transform = None self.dataset = Preprocessed_LIDCIDRI(root=self.root, transform=self.transform) def __len__(self) -> int: return len(self.dataset) def __getitem__(self, idx: int) -> torch.Tensor: return self.dataset[idx].to(self.device_str) class Metric(): """Metrics and utilities.""" all_metrics = ["PSNR", "SSIM", "LPIPS"] def __init__(self, metric_name: str, device_str: str = "cpu") -> None: self.name = metric_name if self.name not in self.all_metrics: raise ValueError(f"{self.name} is unavailable.") elif self.name == "PSNR": self.metric = dinv.loss.metric.PSNR() elif self.name == "SSIM": self.metric = dinv.loss.metric.SSIM() elif self.name == "LPIPS": self.metric = dinv.loss.metric.LPIPS(device=device_str) def __call__(self, x_net: torch.Tensor, x: torch.Tensor, *args, **kwargs) -> torch.Tensor: # it may happen that x_net and x do not have the same size, in which case we take the minimum size of both if x_net.shape[-1] != x.shape[-1]: min_size = min(x_net.shape[-1], x.shape[-1]) x_net_crop = x_net[..., x_net.shape[-2] // 2 - min_size // 2: x_net.shape[-2] // 2 + min_size // 2, x_net.shape[-1] // 2 - min_size // 2: x_net.shape[-1] // 2 + min_size // 2] x_crop = x[..., x_net.shape[-2] // 2 - min_size // 2: x_net.shape[-2] // 2 + min_size // 2, x_net.shape[-1] // 2 - min_size // 2: x_net.shape[-1] // 2 + min_size // 2] else: x_net_crop = x_net x_crop = x return self.metric(x_net_crop, x_crop) @classmethod def get_list_metrics(cls, metric_names: List[str], device_str: str = "cpu") -> List["Metric"]: l = [] for metric_name in metric_names: l.append(cls(metric_name, device_str=device_str)) return l