# denoising / evals.py
# Last commit 33dc149 by Yonuts: "Bugfix" (file size 27.6 kB).
from typing import Any, List
import deepinv as dinv
import numpy as np
import torch
from deepinv.physics.generator import MotionBlurGenerator, SigmaGenerator
from torchvision import transforms
from datasets import Preprocessed_fastMRI, Preprocessed_LIDCIDRI
from utils import get_model
# Keyword arguments forwarded to ``get_model`` when EvalModel builds the
# "unext_emb_physics_config_C" network (see ``EvalModel.__init__``).
DEFAULT_MODEL_PARAMS = dict(
    in_channels=[1, 2, 3],
    grayscale=False,
    conv_type="base",
    pool_type="base",
    layer_scale_init_value=1e-6,
    init_type="ortho",
    gain_init_conv=1.0,
    gain_init_linear=1.0,
    drop_prob=0.0,
    replk=False,
    mult_fact=4,
    antialias="gaussian",
    nc_base=64,
    cond_type="base",
    blind=False,
    pretrained_pth=None,
    N=2,
    c_mult=2,
    depth_encoding=2,
    relu_in_encoding=False,
    skip_in_encoding=True,
)
class PhysicsWithGenerator(torch.nn.Module):
    """Interface between Physics, Generator and Gradio.

    Attributes set by ``__init__``:
        physics: the deepinv forward operator applied in ``forward``.
        generator: draws random operator/noise parameters (``step``) that are
            pushed into ``physics`` when ``forward`` is called with ``use_gen``.
        saved_params: printable copy of the parameters, split into
            "updatable_params" (with a string converter per key in
            "updatable_params_converter") and "fixed_params".
    """

    all_physics = ["MotionBlur_easy", "MotionBlur_medium", "MotionBlur_hard", "GaussianBlur",
                   "MRI", "CT"]

    # Per-difficulty motion-blur settings:
    # (Gaussian noise sigma, MotionBlurGenerator ``l``, MotionBlurGenerator ``sigma``).
    _MOTION_BLUR_SETTINGS = {
        "MotionBlur_easy": (0.01, 0.1, 0.1),
        "MotionBlur_medium": (0.05, 0.6, 0.5),
        "MotionBlur_hard": (0.1, 1.2, 1.0),
    }

    def __init__(self, physics_name: str, device_str: str = "cpu") -> None:
        """Build the operator and its parameter generator.

        :param physics_name: one of ``all_physics``.
        :param device_str: torch device the operator/generators are created on.
        :raises ValueError: if ``physics_name`` is not in ``all_physics``.
        """
        super().__init__()
        self.name = physics_name
        if self.name not in self.all_physics:
            raise ValueError(f"{self.name} is unavailable.")
        # Shared noise-level generator (used directly by GaussianBlur and CT).
        self.sigma_generator = SigmaGenerator(sigma_min=0.001, sigma_max=0.2, device=device_str)

        if self.name in self._MOTION_BLUR_SETTINGS:
            noise_sigma, motion_l, motion_s = self._MOTION_BLUR_SETTINGS[self.name]
            psf_size = 31
            self.physics = dinv.physics.Blur(noise_model=dinv.physics.GaussianNoise(sigma=noise_sigma),
                                             padding="valid",
                                             device=device_str)
            # NOTE(review): physics_generator already appends a fixed-sigma SigmaGenerator and
            # generator appends a second identical one; the duplicate looks redundant.
            self.physics_generator = MotionBlurGenerator(
                (psf_size, psf_size), l=motion_l, sigma=motion_s, device=device_str
            ) + SigmaGenerator(sigma_min=noise_sigma, sigma_max=noise_sigma, device=device_str)
            self.generator = self.physics_generator + SigmaGenerator(
                sigma_min=noise_sigma, sigma_max=noise_sigma, device=device_str
            )
            self.saved_params = {"updatable_params": {"sigma": 0.05},
                                 "updatable_params_converter": {"sigma": float},
                                 "fixed_params": {"noise_sigma_min": noise_sigma,
                                                  "noise_sigma_max": noise_sigma,
                                                  "psf_size": psf_size,
                                                  "motion_gen_l": motion_l,
                                                  "motion_gen_s": motion_s}}
        elif self.name == "GaussianBlur":
            psf_size = 31
            self.physics = dinv.physics.Blur(noise_model=dinv.physics.GaussianNoise(sigma=0.05),
                                             padding="valid",
                                             device=device_str)
            # NOTE(review): GaussianBlurGenerator is not imported anywhere in this file, so this
            # branch raises NameError as written -- confirm which module should provide it.
            self.physics_generator = GaussianBlurGenerator(psf_size=(psf_size, psf_size), num_channels=1,
                                                           device=device_str)
            self.generator = self.physics_generator + self.sigma_generator
            self.saved_params = {"updatable_params": {"sigma": 0.05},
                                 "updatable_params_converter": {"sigma": float},
                                 "fixed_params": {"noise_sigma_min": 0.001, "noise_sigma_max": 0.2,
                                                  "psf_size": 31, "num_channels": 1}}
        elif self.name == "MRI":
            self.physics = dinv.physics.MRI(img_size=(640, 320), noise_model=dinv.physics.GaussianNoise(sigma=.01),
                                            device=device_str)
            self.physics_generator = dinv.physics.generator.RandomMaskGenerator((2, 640, 320), acceleration_factor=4)
            # Only the mask is randomized; the noise level stays at the operator's own value.
            self.generator = self.physics_generator  # + self.sigma_generator
            self.saved_params = {"updatable_params": {"sigma": 0.05},
                                 "updatable_params_converter": {"sigma": float},
                                 "fixed_params": {"noise_sigma_min": 0.001, "noise_sigma_max": 0.2,
                                                  "acceleration_factor": 4}}
        elif self.name == "CT":
            acceleration_factor = 10
            img_h = 480
            # Number of projection angles: one per `acceleration_factor` image rows.
            angles = int(img_h / acceleration_factor)
            self.physics = dinv.physics.Tomography(
                img_width=img_h,
                angles=angles,
                circle=False,
                normalize=True,
                device=device_str,
                noise_model=dinv.physics.GaussianNoise(sigma=1e-4),
                max_iter=10,
            )
            self.physics_generator = None
            # CT only randomizes the noise level, drawn from sigma_generator (0.001 .. 0.2).
            self.generator = self.sigma_generator
            self.saved_params = {"updatable_params": {"sigma": 0.1},
                                 "updatable_params_converter": {"sigma": float},
                                 "fixed_params": {"noise_sigma_min": 0.001, "noise_sigma_max": 0.2,
                                                  "angles": angles, "max_iter": 10}}

    def display_saved_params(self) -> str:
        """Printable version of saved_params (one tab-indented ``name = value`` line per entry)."""
        lines = ["Updatable parameters:"]
        lines += [f"\t\t{name} = {value}" for name, value in self.saved_params["updatable_params"].items()]
        lines.append("Fixed parameters:")
        lines += [f"\t\t{name} = {value}" for name, value in self.saved_params["fixed_params"].items()]
        return "\n".join(lines) + "\n"

    def _update_save_params(self, key: str, value: Any) -> None:
        """Update value of an existing key in saved_params["updatable_params"].

        Unknown keys are ignored. A string is converted with the key's entry in
        "updatable_params_converter"; a tensor is converted with ``.item()``
        (so it must hold a single element) and rounded to 4 decimal places.
        """
        updatable = self.saved_params["updatable_params"]
        if key not in updatable:
            return
        if isinstance(value, str):  # it may be only a str representation
            value = self.saved_params["updatable_params_converter"][key](value)
        elif isinstance(value, torch.Tensor):
            value = value.item()  # type: torch.Tensor -> float
            value = float(f"{value:.4f}")  # keeps 4 decimal places for display
        updatable[key] = value

    def update_and_display_params(self, key, value) -> str:
        """_update_save_params + update physics with saved_params + display_saved_params."""
        self._update_save_params(key, value)
        # NOTE(review): "Denoising" is not in all_physics, so this first branch is unreachable.
        if self.name == "Denoising":
            self.physics.noise_model.update_parameters(**self.saved_params["updatable_params"])
        else:
            self.physics.update_parameters(**self.saved_params["updatable_params"])
        return self.display_saved_params()

    def update_saved_params_and_physics(self, **kwargs) -> None:
        """Record every kwarg in saved_params (known keys only) and forward all of them to physics.update."""
        for key, value in kwargs.items():
            self._update_save_params(key, value)
        self.physics.update(**kwargs)

    def forward(self, x: torch.Tensor, use_gen: bool) -> torch.Tensor:
        """Apply the operator to ``x``, optionally drawing fresh parameters first.

        Sampling is forced when the operator has not received its filter (blur)
        or mask (MRI) yet, since it cannot be applied without one.
        """
        blur_names = ("MotionBlur_easy", "MotionBlur_medium", "MotionBlur_hard", "GaussianBlur")
        if self.name in blur_names and not hasattr(self.physics, "filter"):
            use_gen = True
        elif self.name == "MRI" and not hasattr(self.physics, "mask"):
            use_gen = True
        if use_gen:
            kwargs = self.generator.step(batch_size=x.shape[0])  # generate a set of params for each sample
            self.update_saved_params_and_physics(**kwargs)
        return self.physics(x)
class EvalModel(torch.nn.Module):
    """Eval model.

    Is there a difference with BaselineModel ?
    -> BaselineModel should be models that are already trained and will have fixed weights.
    -> Eval model will change depending on different checkpoints.
    """

    all_models = ["unext_emb_physics_config_C"]

    def __init__(self, model_name: str, ckpt_pth: str = "", device_str: str = "cpu") -> None:
        """Load the model we want to evaluate.

        :param model_name: one of ``all_models``.
        :param ckpt_pth: checkpoint file; an empty value selects the default checkpoint.
        :param device_str: device the model is moved to after loading on CPU.
        :raises ValueError: if ``model_name`` is not in ``all_models``.
        """
        super().__init__()
        self.base_name = model_name
        self.ckpt_pth = ckpt_pth
        self.name = self.base_name
        if self.base_name not in self.all_models:
            raise ValueError(f"{self.base_name} is unavailable.")
        if self.base_name == "unext_emb_physics_config_C":
            if not self.ckpt_pth:
                self.ckpt_pth = "ckpt/ram_ckp_10.pth.tar"
            self.model = get_model(model_name=self.base_name,
                                   device='cpu',
                                   **DEFAULT_MODEL_PARAMS)
            # Read the checkpoint once (on CPU); it provides both the weights and the epoch.
            checkpoint = torch.load(self.ckpt_pth, map_location=lambda storage, loc: storage)
            self.model.load_state_dict(checkpoint['state_dict'])
            self.model.to(device_str)
            self.model.eval()
            # add epoch in the model name
            self.name = self.name + f"+{checkpoint['epoch']}"

    def forward(self, y: torch.Tensor, physics: torch.nn.Module) -> torch.Tensor:
        """Reconstruct from measurements ``y`` given the operator ``physics``."""
        return self.model(y, physics=physics)
class BaselineModel(torch.nn.Module):
    """Baseline model.

    Is there a difference with EvalModel ?
    -> BaselineModel should be models that are already trained and will have fixed weights.
    -> Eval model will change depending on different checkpoints.

    DRUNET, PnP-PGD-DRUNET, SWINIRx2, SWINIRx4 and PDNET keep a ready-to-run
    network in ``self.model``. The DPIR variants keep only a PnP prior in
    ``self.prior``; their HQS solver is built in ``forward`` because its
    parameters depend on the noise level of the given ``physics``.
    """

    all_baselines = ["DRUNET", "PnP-PGD-DRUNET", "SWINIRx2", "SWINIRx4", "DPIR",
                     "DPIR_MRI", "DPIR_CT", "PDNET"]

    _COLOR_DRUNET_CKPT = "ckpt/drunet_deepinv_color_finetune_22k.pth"
    _GRAY_DRUNET_CKPT = "ckpt/drunet_gray.pth"

    class _ComplexDenoiser(torch.nn.Module):
        """Apply a single-channel denoiser to channels 0 and 1 of a 2-channel input.

        The two channels (presumably real/imaginary parts for MRI) are stacked
        along the batch dimension, shifted to a non-negative range, denoised,
        shifted back, and re-assembled as channels.
        """

        def __init__(self, denoiser):
            super().__init__()
            self.denoiser = denoiser

        def forward(self, x, sigma):
            batch_size = x.shape[0]
            noisy_batch = torch.cat((x[:, 0:1, ...], x[:, 1:2, ...]), 0)
            input_min = noisy_batch.min()
            denoised_batch = self.denoiser(noisy_batch - input_min, sigma) + input_min
            # First `batch_size` entries hold channel 0, the rest channel 1 (works for any batch size).
            return torch.cat((denoised_batch[:batch_size, ...], denoised_batch[batch_size:, ...]), 1)

    class _CTDenoiser(torch.nn.Module):
        """Apply a denoiser on a copy of ``x`` shifted to have minimum 0, then undo the shift."""

        def __init__(self, denoiser):
            super().__init__()
            self.denoiser = denoiser

        def forward(self, x, sigma):
            # Capture the offset before shifting; adding back min of the shifted tensor would add 0.
            x_min = x.min()
            denoised = self.denoiser(x - x_min, sigma)
            return denoised + x_min

    def __init__(self, model_name: str, device_str: str = "cpu") -> None:
        """Build the baseline ``model_name`` on ``device_str``.

        :raises ValueError: if ``model_name`` is not in ``all_baselines``.
        """
        super().__init__()
        self.base_name = model_name
        self.ckpt_pth = ""
        self.name = self.base_name
        if self.name not in self.all_baselines:
            raise ValueError(f"{self.name} is unavailable.")
        elif self.name == "DRUNET":
            self.model = self._load_drunet(3, self._COLOR_DRUNET_CKPT, device_str)
        elif self.name == 'PDNET':
            ckpt_pth = "ckpt/pdnet.pth.tar"
            self.model = get_model(model_name='pdnet', device=device_str)
            self.model.eval()
            checkpoint = torch.load(ckpt_pth, map_location=lambda storage, loc: storage)
            self.model.load_state_dict(checkpoint['state_dict'])
        elif self.name == "SWINIRx2":
            self.model = self._load_swinir(2, "ckpt/001_classicalSR_DF2K_s64w8_SwinIR-M_x2.pth", device_str)
        elif self.name == "SWINIRx4":
            self.model = self._load_swinir(4, "ckpt/001_classicalSR_DF2K_s64w8_SwinIR-M_x4.pth", device_str)
        elif self.name == "PnP-PGD-DRUNET":
            drunet = self._load_drunet(3, self._COLOR_DRUNET_CKPT, device_str)
            self.model = dinv.optim.optim_builder(iteration="PGD",
                                                  prior=dinv.optim.PnP(drunet).to(device_str),
                                                  data_fidelity=dinv.optim.L2(),
                                                  max_iter=20,
                                                  params_algo={'stepsize': 1., 'g_param': .05})
        elif self.name == "DPIR":
            drunet = self._load_drunet(3, self._COLOR_DRUNET_CKPT, device_str)
            # Specify the denoising prior
            self.prior = dinv.optim.prior.PnP(denoiser=drunet)
        elif self.name == "DPIR_MRI":
            complex_drunet = self._ComplexDenoiser(self._load_drunet(1, self._GRAY_DRUNET_CKPT, device_str))
            complex_drunet.eval()
            self.prior = dinv.optim.prior.PnP(denoiser=complex_drunet)
        elif self.name == "DPIR_CT":
            ct_drunet = self._CTDenoiser(self._load_drunet(1, self._GRAY_DRUNET_CKPT, device_str))
            ct_drunet.eval()
            self.prior = dinv.optim.prior.PnP(denoiser=ct_drunet)

    @staticmethod
    def _load_drunet(n_channels, ckpt_pth, device_str):
        """Build a pretrained DRUNet with ``n_channels`` input/output channels, in eval mode."""
        drunet = dinv.models.DRUNet(in_channels=n_channels,
                                    out_channels=n_channels,
                                    device=device_str,
                                    pretrained=ckpt_pth)
        drunet.eval()
        return drunet

    @staticmethod
    def _load_swinir(scale, ckpt_pth, device_str):
        """Build a pretrained 3-channel SwinIR super-resolution model for ``scale``, in eval mode."""
        upsampler = 'nearest+conv' if 'realSR' in ckpt_pth else 'pixelshuffle'
        model = dinv.models.SwinIR(upscale=scale, in_chans=3, img_size=64, window_size=8,
                                   img_range=1., depths=[6, 6, 6, 6, 6, 6], embed_dim=180,
                                   num_heads=[6, 6, 6, 6, 6, 6],
                                   mlp_ratio=2, upsampler=upsampler, resi_connection='1conv',
                                   pretrained=ckpt_pth)
        model.to(device_str)
        model.eval()
        return model

    def circular_roll(self, tensor, p_h, p_w):
        """Roll ``tensor`` by ``p_h`` rows and ``p_w`` columns (last two dims), wrapping around."""
        return tensor.roll(shifts=(p_h, p_w), dims=(-2, -1))

    def get_DPIR_params(self, noise_level_img, max_iter=8):
        r"""
        Default parameters for the DPIR Plug-and-Play algorithm.

        :param float noise_level_img: Noise level of the input image.
        :param int max_iter: Number of iterations (length of both returned lists).
        :return: tuple(list with denoiser noise level per iteration, list with stepsize per iteration).
        """
        s1 = 49.0 / 255.0
        s2 = max(noise_level_img, 0.01)
        # Denoiser strength decays geometrically from s1 down to the (floored) image noise level.
        sigma_denoiser = np.logspace(np.log10(s1), np.log10(s2), max_iter).astype(
            np.float32
        )
        stepsize = (sigma_denoiser / max(0.01, noise_level_img)) ** 2
        lamb = 1 / 0.23
        return list(sigma_denoiser), list(lamb * stepsize)

    def get_DPIR_MRI_params(self, noise_level_img: float, max_iter: int = 8):
        r"""
        Default parameters for the DPIR Plug-and-Play algorithm (MRI variant).

        :param float noise_level_img: Noise level of the input image (must be > 0).
        :return: tuple(lambda, denoiser noise levels, stepsizes, max_iter).
        """
        s1 = 49.0 / 255.0
        s2 = noise_level_img
        sigma_denoiser = np.logspace(np.log10(s1), np.log10(s2), max_iter).astype(
            np.float32
        )
        stepsize = (sigma_denoiser / max(0.01, noise_level_img)) ** 2
        lamb = 1.
        return lamb, list(sigma_denoiser), list(stepsize), max_iter

    def get_DPIR_CT_params(self, noise_level_img: float, max_iter: int = 8, lip_cons: float = 1.0):
        r"""
        Default parameters for the DPIR Plug-and-Play algorithm (CT variant).

        :param float noise_level_img: Noise level of the input image (must be > 0).
        :param float lip_cons: scales the initial denoiser noise level.
        :return: tuple(lambda, denoiser noise levels, stepsizes, max_iter).
        """
        s1 = 49.0 / 255.0 * lip_cons
        s2 = noise_level_img
        sigma_denoiser = np.logspace(np.log10(s1), np.log10(s2), max_iter).astype(
            np.float32
        )
        stepsize = (sigma_denoiser / max(0.01, noise_level_img)) ** 2
        lamb = 1.
        return lamb, list(sigma_denoiser), list(stepsize), max_iter

    def _swinir_forward(self, y, physics, scale, shift):
        """Run SwinIR on ``y`` with mirror padding to the window size, then crop to ``physics.A_adjoint(y)``'s size."""
        window_size = 8
        _, _, h_old, w_old = y.size()
        # Always pads (a full window when already divisible), as in the reference SwinIR test code.
        h_pad = (h_old // window_size + 1) * window_size - h_old
        w_pad = (w_old // window_size + 1) * window_size - w_old
        img_lq = torch.cat([y, torch.flip(y, [2])], 2)[:, :, :h_old + h_pad, :]
        img_lq = torch.cat([img_lq, torch.flip(img_lq, [3])], 3)[:, :, :, :w_old + w_pad]
        output = self.model(img_lq)
        output = output[..., :h_old * scale, :w_old * scale]
        output = self.circular_roll(output, shift, shift)
        # check shape of adjoint
        x_adj = physics.A_adjoint(y)
        return output[..., :x_adj.size(-2), :x_adj.size(-1)]

    def forward(self, y: torch.Tensor, physics: torch.nn.Module) -> torch.Tensor:
        """Reconstruct from measurements ``y`` with the selected baseline."""
        if self.name == "DRUNET":
            return self.model(y, sigma=physics.noise_model.sigma)
        elif self.name == "PnP-PGD-DRUNET":
            return self.model(y, physics=physics)
        elif self.name == "DPIR":
            sigma_float = physics.noise_model.sigma.item()  # sigma should be a single value
            max_iter = 8
            sigma_denoiser, stepsize = self.get_DPIR_params(sigma_float, max_iter=max_iter)
            params_algo = {"stepsize": stepsize, "g_param": sigma_denoiser}
            model = dinv.optim.optim_builder(
                iteration="HQS",
                prior=self.prior,
                data_fidelity=dinv.optim.data_fidelity.L2(),
                early_stop=False,  # Do not stop algorithm with convergence criteria
                max_iter=max_iter,
                verbose=True,
                params_algo=params_algo,
            )
            return model(y, physics=physics)
        elif self.name == "DPIR_MRI":
            # Floor the noise level so the logspace end point stays positive.
            sigma_float = max(physics.noise_model.sigma.item(), 0.015)  # sigma should be a single value
            lamb, sigma_denoiser, stepsize, max_iter = self.get_DPIR_MRI_params(sigma_float, max_iter=16)
            # Use the first (largest) stepsize for every iteration.
            stepsize = [stepsize[0]] * max_iter
            params_algo = {"stepsize": stepsize, "g_param": sigma_denoiser, "lambda": lamb}
            model = dinv.optim.optim_builder(
                iteration="HQS",
                prior=self.prior,
                data_fidelity=dinv.optim.data_fidelity.L2(),
                early_stop=False,  # Do not stop algorithm with convergence criteria
                max_iter=max_iter,
                verbose=True,
                params_algo=params_algo,
            )
            return model(y, physics=physics)
        elif self.name == "DPIR_CT":
            sigma_float = physics.noise_model.sigma.item()  # sigma should be a single value
            lip_const = physics.compute_norm(physics.A_adjoint(y))
            lamb, sigma_denoiser, stepsize, max_iter = self.get_DPIR_CT_params(sigma_float, max_iter=8,
                                                                               lip_cons=lip_const.item())
            params_algo = {"stepsize": stepsize, "g_param": sigma_denoiser, "lambda": lamb}

            def custom_init(y, physic_op):
                # Start from a strongly data-weighted prox of the back-projection.
                x_init = physic_op.prox_l2(physic_op.A_adjoint(y), y, gamma=1e4)
                return {"est": (x_init, x_init)}

            algo = dinv.optim.optim_builder(
                iteration="HQS",
                prior=self.prior,
                data_fidelity=dinv.optim.data_fidelity.L2(),
                early_stop=False,  # Do not stop algorithm with convergence criteria
                max_iter=max_iter,
                verbose=True,
                params_algo=params_algo,
                custom_init=custom_init
            )
            return algo(y, physics=physics)
        elif self.name == 'SWINIRx4':
            return self._swinir_forward(y, physics, scale=4, shift=-2)
        elif self.name == 'SWINIRx2':
            return self._swinir_forward(y, physics, scale=2, shift=-1)
        else:
            return self.model(y)
class EvalDataset(torch.utils.data.Dataset):
    """Evaluation samples selected by name; each item is moved to ``device_str`` on access.

    NOTE(review): the original note says images are expected to be 480x480, but
    the "MRI" branch center-crops to 640x320 -- the size expectation appears
    to depend on the dataset.
    """

    all_datasets = ["Natural", "MRI", "CT"]

    def __init__(self, dataset_name: str, device_str: str = "cpu") -> None:
        self.name = dataset_name
        self.device_str = device_str
        if self.name not in self.all_datasets:
            raise ValueError(f"{self.name} is unavailable.")
        builders = {
            'Natural': self._setup_natural,
            'MRI': self._setup_mri,
            "CT": self._setup_ct,
        }
        setup = builders.get(self.name)
        if setup is not None:
            setup()

    def _setup_natural(self) -> None:
        """LSDIR sample images converted to tensors."""
        self.root = 'img_samples/LSDIR_samples'
        self.transform = transforms.Compose([transforms.ToTensor()])
        self.dataset = dinv.datasets.LsdirHR(root=self.root, download=False, transform=self.transform)

    def _setup_mri(self) -> None:
        """Preprocessed fastMRI samples center-cropped to (640, 320)."""
        self.root = 'img_samples/FastMRI_samples'
        self.transform = transforms.CenterCrop((640, 320))  # , pad_if_needed=True)
        self.dataset = Preprocessed_fastMRI(root=self.root, transform=self.transform, preprocess=False)

    def _setup_ct(self) -> None:
        """Preprocessed LIDC-IDRI samples with no transform."""
        self.root = 'img_samples/LIDC_IDRI_samples'
        self.transform = None
        self.dataset = Preprocessed_LIDCIDRI(root=self.root, transform=self.transform)

    def __len__(self) -> int:
        return len(self.dataset)

    def __getitem__(self, idx: int) -> torch.Tensor:
        sample = self.dataset[idx]
        return sample.to(self.device_str)
class Metric:
    """Metrics and utilities: a named wrapper around a deepinv image-quality metric."""

    all_metrics = ["PSNR", "SSIM", "LPIPS"]

    def __init__(self, metric_name: str, device_str: str = "cpu") -> None:
        """Create the metric ``metric_name``.

        :raises ValueError: if ``metric_name`` is not in ``all_metrics``.
        """
        self.name = metric_name
        if self.name not in self.all_metrics:
            raise ValueError(f"{self.name} is unavailable.")
        elif self.name == "PSNR":
            self.metric = dinv.loss.metric.PSNR()
        elif self.name == "SSIM":
            self.metric = dinv.loss.metric.SSIM()
        elif self.name == "LPIPS":
            self.metric = dinv.loss.metric.LPIPS(device=device_str)

    @staticmethod
    def _center_crop(t: torch.Tensor, height: int, width: int) -> torch.Tensor:
        """Return the centered ``height`` x ``width`` window of the last two dims of ``t``."""
        top = (t.shape[-2] - height) // 2
        left = (t.shape[-1] - width) // 2
        return t[..., top:top + height, left:left + width]

    def __call__(self, x_net: torch.Tensor, x: torch.Tensor, *args, **kwargs) -> torch.Tensor:
        """Evaluate the metric, center-cropping both inputs to their common spatial size first.

        Each tensor is cropped around its own center, so e.g. a "valid"-padded
        output is compared with the matching central region of the reference.
        Extra arguments are accepted and ignored.
        """
        if x_net.shape[-2:] != x.shape[-2:]:
            height = min(x_net.shape[-2], x.shape[-2])
            width = min(x_net.shape[-1], x.shape[-1])
            x_net = self._center_crop(x_net, height, width)
            x = self._center_crop(x, height, width)
        return self.metric(x_net, x)

    @classmethod
    def get_list_metrics(cls, metric_names: List[str], device_str: str = "cpu") -> List["Metric"]:
        """Instantiate one Metric per name, in order."""
        return [cls(metric_name, device_str=device_str) for metric_name in metric_names]