Spaces:
Sleeping
Sleeping
from typing import Any, List | |
import deepinv as dinv | |
import numpy as np | |
import torch | |
from deepinv.physics.generator import MotionBlurGenerator, SigmaGenerator | |
from torchvision import transforms | |
from datasets import Preprocessed_fastMRI, Preprocessed_LIDCIDRI | |
from utils import get_model | |
# Keyword arguments forwarded to ``get_model`` (together with ``model_name`` and
# ``device``) when ``EvalModel`` builds the "unext_emb_physics_config_C" network.
# The meaning of each entry is defined by ``get_model`` / the model constructor,
# which are not visible here.
DEFAULT_MODEL_PARAMS = dict(
    in_channels=[1, 2, 3],
    grayscale=False,
    conv_type="base",
    pool_type="base",
    layer_scale_init_value=1e-6,
    init_type="ortho",
    gain_init_conv=1.0,
    gain_init_linear=1.0,
    drop_prob=0.0,
    replk=False,
    mult_fact=4,
    antialias="gaussian",
    nc_base=64,
    cond_type="base",
    blind=False,
    pretrained_pth=None,
    N=2,
    c_mult=2,
    depth_encoding=2,
    relu_in_encoding=False,
    skip_in_encoding=True,
)
class PhysicsWithGenerator(torch.nn.Module):
    """Interface between Physics, Generator and Gradio.

    Pairs a deepinv physics operator (``self.physics``) with the generator used to
    sample its parameters (``self.generator``), and keeps a ``saved_params`` dict
    with ``updatable_params``, ``updatable_params_converter`` and ``fixed_params``
    that can be displayed and partially updated.

    :param str physics_name: one of :attr:`all_physics`.
    :param str device_str: device passed to the physics operator and generators.
    :raises ValueError: if ``physics_name`` is not in :attr:`all_physics`.
    """

    all_physics = ["MotionBlur_easy", "MotionBlur_medium", "MotionBlur_hard", "GaussianBlur",
                   "MRI", "CT"]

    # Motion-blur presets: name -> (Gaussian noise sigma, MotionBlurGenerator ``l``,
    # MotionBlurGenerator ``sigma``). All presets use a 31x31 PSF.
    _MOTION_BLUR_PRESETS = {
        "MotionBlur_easy": (0.01, 0.1, 0.1),
        "MotionBlur_medium": (0.05, 0.6, 0.5),
        "MotionBlur_hard": (0.1, 1.2, 1.0),
    }

    # Physics whose parameters are forced to be sampled while ``self.physics``
    # has no ``filter`` attribute yet (see ``forward``).
    _BLUR_PHYSICS = ("MotionBlur_easy", "MotionBlur_medium", "MotionBlur_hard", "GaussianBlur")

    def __init__(self, physics_name: str, device_str: str = "cpu") -> None:
        super().__init__()
        self.name = physics_name
        if self.name not in self.all_physics:
            raise ValueError(f"{self.name} is unavailable.")

        # Shared noise-level generator; used directly by GaussianBlur and CT.
        self.sigma_generator = SigmaGenerator(sigma_min=0.001, sigma_max=0.2, device=device_str)

        if self.name in self._MOTION_BLUR_PRESETS:
            noise_sigma, motion_l, motion_s = self._MOTION_BLUR_PRESETS[self.name]
            psf_size = 31
            self.physics = dinv.physics.Blur(noise_model=dinv.physics.GaussianNoise(sigma=noise_sigma),
                                             padding="valid", device=device_str)
            # NOTE(review): a fixed-sigma SigmaGenerator is added both to
            # ``physics_generator`` and again to ``generator``. Kept as-is because
            # other code may step ``physics_generator`` directly and expect ``sigma``.
            self.physics_generator = (
                MotionBlurGenerator((psf_size, psf_size), l=motion_l, sigma=motion_s, device=device_str)
                + SigmaGenerator(sigma_min=noise_sigma, sigma_max=noise_sigma, device=device_str)
            )
            self.generator = self.physics_generator + SigmaGenerator(
                sigma_min=noise_sigma, sigma_max=noise_sigma, device=device_str)
            # NOTE(review): the displayed updatable sigma is 0.05 for every preset,
            # while the physics noise sigma is ``noise_sigma``; confirm this is intended.
            self.saved_params = {"updatable_params": {"sigma": 0.05},
                                 "updatable_params_converter": {"sigma": float},
                                 "fixed_params": {"noise_sigma_min": noise_sigma,
                                                  "noise_sigma_max": noise_sigma,
                                                  "psf_size": psf_size,
                                                  "motion_gen_l": motion_l,
                                                  "motion_gen_s": motion_s}}
        elif self.name == "GaussianBlur":
            # Imported locally: the module-level generator import only provides
            # MotionBlurGenerator and SigmaGenerator, so this name is otherwise undefined.
            from deepinv.physics.generator import GaussianBlurGenerator

            psf_size = 31
            self.physics = dinv.physics.Blur(noise_model=dinv.physics.GaussianNoise(sigma=0.05),
                                             padding="valid", device=device_str)
            self.physics_generator = GaussianBlurGenerator(psf_size=(psf_size, psf_size), num_channels=1,
                                                           device=device_str)
            self.generator = self.physics_generator + self.sigma_generator
            self.saved_params = {"updatable_params": {"sigma": 0.05},
                                 "updatable_params_converter": {"sigma": float},
                                 "fixed_params": {"noise_sigma_min": 0.001, "noise_sigma_max": 0.2,
                                                  "psf_size": psf_size, "num_channels": 1}}
        elif self.name == "MRI":
            self.physics = dinv.physics.MRI(img_size=(640, 320), noise_model=dinv.physics.GaussianNoise(sigma=.01),
                                            device=device_str)
            # Device passed explicitly, like every other generator in this class.
            self.physics_generator = dinv.physics.generator.RandomMaskGenerator(
                (2, 640, 320), acceleration_factor=4, device=device_str)
            # Only the sampling mask is resampled; the noise level is not.
            self.generator = self.physics_generator
            self.saved_params = {"updatable_params": {"sigma": 0.05},
                                 "updatable_params_converter": {"sigma": float},
                                 "fixed_params": {"noise_sigma_min": 0.001, "noise_sigma_max": 0.2,
                                                  "acceleration_factor": 4}}
        elif self.name == "CT":
            acceleration_factor = 10
            img_h = 480
            angles = img_h // acceleration_factor
            self.physics = dinv.physics.Tomography(
                img_width=img_h,
                angles=angles,
                circle=False,
                normalize=True,
                device=device_str,
                noise_model=dinv.physics.GaussianNoise(sigma=1e-4),
                max_iter=10,
            )
            self.physics_generator = None
            # Only the noise level is resampled for CT.
            self.generator = self.sigma_generator
            # noise_sigma_min/max mirror ``self.sigma_generator`` (0.001 / 0.2).
            self.saved_params = {"updatable_params": {"sigma": 0.1},
                                 "updatable_params_converter": {"sigma": float},
                                 "fixed_params": {"noise_sigma_min": 0.001, "noise_sigma_max": 0.2,
                                                  "angles": angles, "max_iter": 10}}

    def display_saved_params(self) -> str:
        """Return a printable, tab-indented listing of the updatable and fixed parameters."""
        lines = ["Updatable parameters:\n"]
        lines += [f"\t\t{name} = {val}\n" for name, val in self.saved_params["updatable_params"].items()]
        lines.append("Fixed parameters:\n")
        lines += [f"\t\t{name} = {val}\n" for name, val in self.saved_params["fixed_params"].items()]
        return "".join(lines)

    def _update_save_params(self, key: str, value: Any) -> None:
        """Store a new value for an existing updatable parameter in ``saved_params``.

        Keys not listed in ``updatable_params`` are ignored. String values are
        converted with the registered converter, and tensors with ``.item()``. The
        stored value is rounded to 4 decimal places.
        """
        updatable = self.saved_params["updatable_params"]
        if key not in updatable:
            return
        if isinstance(value, str):  # may be only a string representation
            value = self.saved_params["updatable_params_converter"][key](value)
        elif isinstance(value, torch.Tensor):
            value = value.item()  # requires a single-element tensor
        updatable[key] = float(f"{value:.4f}")  # rounds to 4 decimal places

    def update_and_display_params(self, key, value) -> str:
        """Update one saved parameter, push the updatable params to the physics, and return the display string."""
        self._update_save_params(key, value)
        updatable = self.saved_params["updatable_params"]
        # NOTE(review): "Denoising" is not in ``all_physics``, so this branch is
        # currently unreachable; kept in case that physics is re-enabled.
        if self.name == "Denoising":
            self.physics.noise_model.update_parameters(**updatable)
        else:
            self.physics.update_parameters(**updatable)
        return self.display_saved_params()

    def update_saved_params_and_physics(self, **kwargs) -> None:
        """Update saved_params from ``kwargs``, then update the physics with all of ``kwargs``."""
        for key, value in kwargs.items():
            self._update_save_params(key, value)
        self.physics.update(**kwargs)

    def forward(self, x: torch.Tensor, use_gen: bool) -> torch.Tensor:
        """Apply the physics to ``x``, optionally sampling new parameters first.

        Sampling is forced when a blur physics has no ``filter`` attribute yet, or
        when the MRI physics has no ``mask`` attribute yet.
        """
        missing_params = (
            (self.name in self._BLUR_PHYSICS and not hasattr(self.physics, "filter"))
            or (self.name == "MRI" and not hasattr(self.physics, "mask"))
        )
        if use_gen or missing_params:
            # generate a set of params for each sample
            kwargs = self.generator.step(batch_size=x.shape[0])
            self.update_saved_params_and_physics(**kwargs)
        return self.physics(x)
class EvalModel(torch.nn.Module):
    """Eval model, restored from a training checkpoint.

    Is there a difference with BaselineModel ?
    -> BaselineModel should be models that are already trained and will have fixed weights.
    -> EvalModel will change depending on different checkpoints.

    :param str model_name: one of :attr:`all_models`.
    :param str ckpt_pth: checkpoint path; an empty string (or ``None``) selects the default checkpoint.
    :param str device_str: device the model is moved to after loading.
    :raises ValueError: if ``model_name`` is not in :attr:`all_models`.
    """

    all_models = ["unext_emb_physics_config_C"]

    def __init__(self, model_name: str, ckpt_pth: str = "", device_str: str = "cpu") -> None:
        """Load the model we want to evaluate."""
        super().__init__()
        self.base_name = model_name
        self.ckpt_pth = ckpt_pth
        self.name = self.base_name
        if self.base_name not in self.all_models:
            raise ValueError(f"{self.base_name} is unavailable.")
        if self.base_name == "unext_emb_physics_config_C":
            if not self.ckpt_pth:
                self.ckpt_pth = "ckpt/ram_ckp_10.pth.tar"
            # Build on CPU so the CPU-mapped checkpoint can be applied before moving devices.
            self.model = get_model(model_name=self.base_name,
                                   device='cpu',
                                   **DEFAULT_MODEL_PARAMS)
            # Read the checkpoint once: it provides both the weights and the epoch.
            checkpoint = torch.load(self.ckpt_pth, map_location="cpu")
            self.model.load_state_dict(checkpoint['state_dict'])
            self.model.to(device_str)
            self.model.eval()
            # add epoch in the model name
            self.name = self.name + f"+{checkpoint['epoch']}"

    def forward(self, y: torch.Tensor, physics: torch.nn.Module) -> torch.Tensor:
        """Reconstruct from measurements ``y`` given the ``physics`` that produced them."""
        return self.model(y, physics=physics)
class BaselineModel(torch.nn.Module): | |
"""Baseline model. | |
Is there a difference with EvalModel ? | |
-> BaselineModel should be models that are already trained and will have fixed weights. | |
-> Eval model will change depending on differents checkpoints. | |
""" | |
all_baselines = ["DRUNET", "PnP-PGD-DRUNET", "SWINIRx2", "SWINIRx4", "DPIR", | |
"DPIR_MRI", "DPIR_CT", "PDNET"] | |
def __init__(self, model_name: str, device_str: str = "cpu") -> None: | |
super().__init__() | |
self.base_name = model_name | |
self.ckpt_pth = "" | |
self.name = self.base_name | |
if self.name not in self.all_baselines: | |
raise ValueError(f"{self.name} is unavailable.") | |
elif self.name == "DRUNET": | |
n_channels = 3 | |
ckpt_pth = "ckpt/drunet_deepinv_color_finetune_22k.pth" | |
self.model = dinv.models.DRUNet(in_channels=n_channels, | |
out_channels=n_channels, | |
device=device_str, | |
pretrained=ckpt_pth) | |
self.model.eval() # Set the model to evaluation mode | |
elif self.name == 'PDNET': | |
ckpt_pth = "ckpt/pdnet.pth.tar" | |
self.model = get_model(model_name='pdnet', | |
device=device_str) | |
self.model.eval() | |
self.model.load_state_dict(torch.load(ckpt_pth, map_location=lambda storage, loc: storage)['state_dict']) | |
elif self.name == "SWINIRx2": | |
n_channels = 3 | |
scale = 2 | |
ckpt_pth = "ckpt/001_classicalSR_DF2K_s64w8_SwinIR-M_x2.pth" | |
upsampler = 'nearest+conv' if 'realSR' in ckpt_pth else 'pixelshuffle' | |
self.model = dinv.models.SwinIR(upscale=scale, in_chans=n_channels, img_size=64, window_size=8, | |
img_range=1., depths=[6, 6, 6, 6, 6, 6], embed_dim=180, | |
num_heads=[6, 6, 6, 6, 6, 6], | |
mlp_ratio=2, upsampler=upsampler, resi_connection='1conv', | |
pretrained=ckpt_pth) | |
self.model.to(device_str) | |
self.model.eval() # Set the model to evaluation mode | |
elif self.name == "SWINIRx4": | |
n_channels = 3 | |
scale = 4 | |
ckpt_pth = "ckpt/001_classicalSR_DF2K_s64w8_SwinIR-M_x4.pth" | |
upsampler = 'nearest+conv' if 'realSR' in ckpt_pth else 'pixelshuffle' | |
self.model = dinv.models.SwinIR(upscale=scale, in_chans=n_channels, img_size=64, window_size=8, | |
img_range=1., depths=[6, 6, 6, 6, 6, 6], embed_dim=180, | |
num_heads=[6, 6, 6, 6, 6, 6], | |
mlp_ratio=2, upsampler=upsampler, resi_connection='1conv', | |
pretrained=ckpt_pth) | |
self.model.to(device_str) | |
self.model.eval() # Set the model to evaluation mode | |
elif self.name == "PnP-PGD-DRUNET": | |
n_channels = 3 | |
ckpt_pth = "ckpt/drunet_deepinv_color_finetune_22k.pth" | |
drunet = dinv.models.DRUNet(in_channels=n_channels, | |
out_channels=n_channels, | |
device=device_str, | |
pretrained=ckpt_pth) | |
drunet.eval() # Set the model to evaluation mode | |
self.model = dinv.optim.optim_builder(iteration="PGD", | |
prior=dinv.optim.PnP(drunet).to(device_str), | |
data_fidelity=dinv.optim.L2(), | |
max_iter=20, | |
params_algo={'stepsize': 1., 'g_param': .05}) | |
elif self.name == "DPIR": | |
n_channels = 3 | |
ckpt_pth = "ckpt/drunet_deepinv_color_finetune_22k.pth" | |
drunet = dinv.models.DRUNet(in_channels=n_channels, | |
out_channels=n_channels, | |
device=device_str, | |
pretrained=ckpt_pth) | |
drunet.eval() # Set the model to evaluation mode | |
# Specify the denoising prior | |
self.prior = dinv.optim.prior.PnP(denoiser=drunet) | |
elif self.name == "DPIR_MRI": | |
class ComplexDenoiser(torch.nn.Module): | |
def __init__(self, denoiser): | |
super().__init__() | |
self.denoiser = denoiser | |
def forward(self, x, sigma): | |
noisy_batch = torch.cat((x[:, 0:1, ...], x[:, 1:2, ...]), 0) | |
input_min = noisy_batch.min() | |
denoised_batch = self.denoiser(noisy_batch - input_min, sigma) | |
denoised_batch = denoised_batch + input_min | |
denoised = torch.cat((denoised_batch[0:1, ...], denoised_batch[1:2, ...]), 1) | |
return denoised | |
# Load PnP denoiser backbone | |
n_channels = 1 | |
ckpt_pth = "ckpt/drunet_gray.pth" | |
drunet = dinv.models.DRUNet(in_channels=n_channels, out_channels=n_channels, device=device_str, | |
pretrained=ckpt_pth) | |
complex_drunet = ComplexDenoiser(drunet) | |
complex_drunet.eval() | |
# Specify the denoising prior | |
self.prior = dinv.optim.prior.PnP(denoiser=complex_drunet) | |
elif self.name == "DPIR_CT": | |
class CTDenoiser(torch.nn.Module): | |
def __init__(self, denoiser): | |
super().__init__() | |
self.denoiser = denoiser | |
def forward(self, x, sigma): | |
x = x - x.min() | |
denoised = self.denoiser(x, sigma) | |
denoised = denoised + x.min() | |
return denoised | |
# Load PnP denoiser backbone | |
n_channels = 1 | |
ckpt_pth = "ckpt/drunet_gray.pth" | |
drunet = dinv.models.DRUNet(in_channels=n_channels, out_channels=n_channels, device=device_str, | |
pretrained=ckpt_pth) | |
ct_drunet = CTDenoiser(drunet) | |
ct_drunet.eval() | |
# Specify the denoising prior | |
self.prior = dinv.optim.prior.PnP(denoiser=ct_drunet) | |
def circular_roll(self, tensor, p_h, p_w): | |
return tensor.roll(shifts=(p_h, p_w), dims=(-2, -1)) | |
def get_DPIR_params(self, noise_level_img, max_iter=8): | |
r""" | |
Default parameters for the DPIR Plug-and-Play algorithm. | |
:param float noise_level_img: Noise level of the input image. | |
:return: tuple(list with denoiser noise level per iteration, list with stepsize per iteration, iterations). | |
""" | |
max_iter = 8 | |
s1 = 49.0 / 255.0 | |
s2 = max(noise_level_img, 0.01) | |
sigma_denoiser = np.logspace(np.log10(s1), np.log10(s2), max_iter).astype( | |
np.float32 | |
) | |
stepsize = (sigma_denoiser / max(0.01, noise_level_img)) ** 2 | |
lamb = 1 / 0.23 | |
return list(sigma_denoiser), list(lamb * stepsize) | |
def get_DPIR_MRI_params(self, noise_level_img: float, max_iter: int = 8): | |
r""" | |
Default parameters for the DPIR Plug-and-Play algorithm. | |
:param float noise_level_img: Noise level of the input image. | |
""" | |
s1 = 49.0 / 255.0 | |
s2 = noise_level_img | |
sigma_denoiser = np.logspace(np.log10(s1), np.log10(s2), max_iter).astype( | |
np.float32 | |
) | |
stepsize = (sigma_denoiser / max(0.01, noise_level_img)) ** 2 | |
lamb = 1. | |
return lamb, list(sigma_denoiser), list(stepsize), max_iter | |
def get_DPIR_CT_params(self, noise_level_img: float, max_iter: int = 8, lip_cons: float = 1.0): | |
r""" | |
Default parameters for the DPIR Plug-and-Play algorithm. | |
:param float noise_level_img: Noise level of the input image. | |
""" | |
s1 = 49.0 / 255.0 * lip_cons | |
s2 = noise_level_img | |
sigma_denoiser = np.logspace(np.log10(s1), np.log10(s2), max_iter).astype( | |
np.float32 | |
) | |
stepsize = (sigma_denoiser / max(0.01, noise_level_img)) ** 2 # | |
lamb = 1. | |
return lamb, list(sigma_denoiser), list(stepsize), max_iter | |
def forward(self, y: torch.Tensor, physics: torch.nn.Module) -> torch.Tensor: | |
if self.name == "DRUNET": | |
return self.model(y, sigma=physics.noise_model.sigma) | |
elif self.name == "PnP-PGD-DRUNET": | |
return self.model(y, physics=physics) | |
elif self.name == "DPIR": | |
# Set the DPIR algorithm parameters | |
sigma_float = physics.noise_model.sigma.item() # sigma should be a single value | |
max_iter = 8 | |
sigma_denoiser, stepsize = self.get_DPIR_params(sigma_float, max_iter=max_iter) | |
params_algo = {"stepsize": stepsize, "g_param": sigma_denoiser} | |
early_stop = False # Do not stop algorithm with convergence criteria | |
# instantiate DPIR | |
model = dinv.optim.optim_builder( | |
iteration="HQS", | |
prior=self.prior, | |
data_fidelity=dinv.optim.data_fidelity.L2(), | |
early_stop=early_stop, | |
max_iter=max_iter, | |
verbose=True, | |
params_algo=params_algo, | |
) | |
return model(y, physics=physics) | |
elif self.name == "DPIR_MRI": | |
sigma_float = max(physics.noise_model.sigma.item(), 0.015) # sigma should be a single value | |
lamb, sigma_denoiser, stepsize, max_iter = self.get_DPIR_MRI_params(sigma_float, max_iter=16) | |
stepsize = [stepsize[0]] * max_iter | |
params_algo = {"stepsize": stepsize, "g_param": sigma_denoiser, "lambda": lamb} | |
early_stop = False # Do not stop algorithm with convergence criteria | |
# Instantiate the algorithm class to solve the IP | |
model = dinv.optim.optim_builder( | |
iteration="HQS", | |
prior=self.prior, | |
data_fidelity=dinv.optim.data_fidelity.L2(), | |
early_stop=early_stop, | |
max_iter=max_iter, | |
verbose=True, | |
params_algo=params_algo, | |
) | |
return model(y, physics=physics) | |
elif self.name == "DPIR_CT": | |
# Set the DPIR algorithm parameters | |
sigma_float = physics.noise_model.sigma.item() # sigma should be a single value | |
lip_const = physics.compute_norm(physics.A_adjoint(y)) | |
lamb, sigma_denoiser, stepsize, max_iter = self.get_DPIR_CT_params(sigma_float, max_iter=8, | |
lip_cons=lip_const.item()) | |
params_algo = {"stepsize": stepsize, "g_param": sigma_denoiser, "lambda": lamb} | |
early_stop = False # Do not stop algorithm with convergence criteria | |
def custom_init(y, physic_op): | |
x_init = physic_op.prox_l2(physic_op.A_adjoint(y), y, gamma=1e4) | |
return {"est": (x_init, x_init)} | |
# Instantiate the algorithm class to solve the IP | |
algo = dinv.optim.optim_builder( | |
iteration="HQS", | |
prior=self.prior, | |
data_fidelity=dinv.optim.data_fidelity.L2(), | |
early_stop=early_stop, | |
max_iter=max_iter, | |
verbose=True, | |
params_algo=params_algo, | |
custom_init=custom_init | |
) | |
return algo(y, physics=physics) | |
elif self.name == 'SWINIRx4': | |
window_size = 8 | |
scale = 4 | |
_, _, h_old, w_old = y.size() | |
h_pad = (h_old // window_size + 1) * window_size - h_old | |
w_pad = (w_old // window_size + 1) * window_size - w_old | |
img_lq = torch.cat([y, torch.flip(y, [2])], 2)[:, :, :h_old + h_pad, :] | |
img_lq = torch.cat([img_lq, torch.flip(img_lq, [3])], 3)[:, :, :, :w_old + w_pad] | |
output = self.model(img_lq) | |
output = output[..., :h_old * scale, :w_old * scale] | |
output = self.circular_roll(output, -2, -2) | |
# check shape of adjoint | |
x_adj = physics.A_adjoint(y) | |
output = output[..., :x_adj.size(-2), :x_adj.size(-1)] | |
return output | |
elif self.name == 'SWINIRx2': | |
window_size = 8 | |
scale = 2 | |
_, _, h_old, w_old = y.size() | |
h_pad = (h_old // window_size + 1) * window_size - h_old | |
w_pad = (w_old // window_size + 1) * window_size - w_old | |
img_lq = torch.cat([y, torch.flip(y, [2])], 2)[:, :, :h_old + h_pad, :] | |
img_lq = torch.cat([img_lq, torch.flip(img_lq, [3])], 3)[:, :, :, :w_old + w_pad] | |
output = self.model(img_lq) | |
output = output[..., :h_old * scale, :w_old * scale] | |
output = self.circular_roll(output, -1, -1) | |
# check shape of adjoint | |
x_adj = physics.A_adjoint(y) | |
output = output[..., :x_adj.size(-2), :x_adj.size(-1)] | |
return output | |
else: | |
return self.model(y) | |
class EvalDataset(torch.utils.data.Dataset):
    """Evaluation samples for one imaging modality, selected by name.

    - ``"Natural"``: LSDIR samples from ``img_samples/LSDIR_samples``, converted to tensors.
    - ``"MRI"``: preprocessed fastMRI samples from ``img_samples/FastMRI_samples``,
      center-cropped to 640x320.
    - ``"CT"``: preprocessed LIDC-IDRI samples from ``img_samples/LIDC_IDRI_samples``,
      with no transform.

    Every item is moved to ``device_str`` when indexed.

    NOTE(review): images were documented as expected to be 480x480, but the MRI
    branch crops to 640x320, so that expectation can only apply to the other
    datasets. Nothing here enforces it.
    """

    all_datasets = ["Natural", "MRI", "CT"]

    def __init__(self, dataset_name: str, device_str: str = "cpu") -> None:
        self.name = dataset_name
        self.device_str = device_str
        if dataset_name not in self.all_datasets:
            raise ValueError(f"{self.name} is unavailable.")

        if dataset_name == "Natural":
            self.root = "img_samples/LSDIR_samples"
            self.transform = transforms.Compose([transforms.ToTensor()])
            self.dataset = dinv.datasets.LsdirHR(
                root=self.root, download=False, transform=self.transform
            )
        elif dataset_name == "MRI":
            self.root = "img_samples/FastMRI_samples"
            self.transform = transforms.CenterCrop((640, 320))
            self.dataset = Preprocessed_fastMRI(
                root=self.root, transform=self.transform, preprocess=False
            )
        elif dataset_name == "CT":
            self.root = "img_samples/LIDC_IDRI_samples"
            self.transform = None
            self.dataset = Preprocessed_LIDCIDRI(root=self.root, transform=self.transform)

    def __len__(self) -> int:
        """Number of samples in the underlying dataset."""
        return len(self.dataset)

    def __getitem__(self, idx: int) -> torch.Tensor:
        """Return sample ``idx`` on ``self.device_str``."""
        sample = self.dataset[idx]
        return sample.to(self.device_str)
class Metric:
    """Metrics and utilities: a named image-quality metric from ``dinv.loss.metric``.

    :param str metric_name: one of :attr:`all_metrics`.
    :param str device_str: device passed to metrics that take one (LPIPS).
    :raises ValueError: if ``metric_name`` is not in :attr:`all_metrics`.
    """

    all_metrics = ["PSNR", "SSIM", "LPIPS"]

    def __init__(self, metric_name: str, device_str: str = "cpu") -> None:
        self.name = metric_name
        if self.name not in self.all_metrics:
            raise ValueError(f"{self.name} is unavailable.")
        elif self.name == "PSNR":
            self.metric = dinv.loss.metric.PSNR()
        elif self.name == "SSIM":
            self.metric = dinv.loss.metric.SSIM()
        elif self.name == "LPIPS":
            self.metric = dinv.loss.metric.LPIPS(device=device_str)

    @staticmethod
    def _center_crop(t: torch.Tensor, height: int, width: int) -> torch.Tensor:
        """Return the centered ``height`` x ``width`` window over the last two dims of ``t``."""
        top = (t.shape[-2] - height) // 2
        left = (t.shape[-1] - width) // 2
        return t[..., top:top + height, left:left + width]

    def __call__(self, x_net: torch.Tensor, x: torch.Tensor, *args, **kwargs) -> torch.Tensor:
        """Evaluate the metric between the estimate ``x_net`` and the reference ``x``.

        If their spatial sizes differ, each tensor is center-cropped around its OWN
        center to the common (min) height and width, so both crops cover the same
        region and have identical shapes. Extra arguments are accepted and ignored.
        """
        if x_net.shape[-2:] != x.shape[-2:]:
            height = min(x_net.shape[-2], x.shape[-2])
            width = min(x_net.shape[-1], x.shape[-1])
            x_net = self._center_crop(x_net, height, width)
            x = self._center_crop(x, height, width)
        return self.metric(x_net, x)

    @classmethod
    def get_list_metrics(cls, metric_names: List[str], device_str: str = "cpu") -> List["Metric"]:
        """Instantiate one metric per name in ``metric_names``, all on ``device_str``."""
        return [cls(metric_name, device_str=device_str) for metric_name in metric_names]