diff --git a/.gitattributes b/.gitattributes index a6344aac8c09253b3b630fb776ae94478aa0275b..948944e4ff4ddbbf2089a6032d70f6a291a79efe 100644 --- a/.gitattributes +++ b/.gitattributes @@ -33,3 +33,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text *.zip filter=lfs diff=lfs merge=lfs -text *.zst filter=lfs diff=lfs merge=lfs -text *tfevents* filter=lfs diff=lfs merge=lfs -text +*.pth.tar filter=lfs diff=lfs merge=lfs -text diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000000000000000000000000000000000000..0015de3e1947955b9b9cc253c712252b63af4ad5 --- /dev/null +++ b/.gitignore @@ -0,0 +1,2 @@ +.ipynb +__pycache__ diff --git a/app.py b/app.py index 7c810fe441750ba0b7c10e563b4e5ea645603a45..a5f8f49cba35ba2ef623bf8beb9f32d15ca651b5 100644 --- a/app.py +++ b/app.py @@ -1,78 +1,269 @@ -import gradio as gr +import json +import os +import random +from functools import partial +from pathlib import Path +from typing import List + import deepinv as dinv +import gradio as gr import torch -import numpy as np -import PIL.Image +from PIL import Image +from torchvision import transforms + +from evals import PhysicsWithGenerator, EvalModel, BaselineModel, EvalDataset, Metric + + +### Gradio Utils +def generate_imgs(dataset: EvalDataset, idx: int, + model: EvalModel, baseline: BaselineModel, + physics: PhysicsWithGenerator, use_gen: bool, + metrics: List[Metric]): + ### Load 1 image + x = dataset[idx] # shape : (3, 256, 256) + x = x.unsqueeze(0) # shape : (1, 3, 256, 256) + with torch.no_grad(): + ### Compute y + y = physics(x, use_gen) # possible reduction in img shape due to Blurring -def pil_to_torch(image, ref_size=512): - image = np.array(image) - image = image.transpose((2, 0, 1)) - image = torch.tensor(image).float() / 255 - image = image.unsqueeze(0) + ### Compute x_hat + out = model(y=y, physics=physics.physics) + out_baseline = baseline(y=y, physics=physics.physics) - if ref_size == 256: - size = (ref_size, ref_size) - elif image.shape[2] > image.shape[3]: - size = (ref_size, ref_size * image.shape[3]//image.shape[2]) + ### Process tensors before metric computation + if "Blur" in physics.name: + w_1, w_2 = (x.shape[2] - y.shape[2]) // 2, (x.shape[2] + y.shape[2]) // 2 + h_1, h_2 = (x.shape[3] - y.shape[3]) // 2, (x.shape[3] + y.shape[3]) // 2 + + x = x[..., w_1:w_2, h_1:h_2] + out = out[..., w_1:w_2, h_1:h_2] + if out_baseline.shape != out.shape: + out_baseline = out_baseline[..., w_1:w_2, h_1:h_2] + + ### Metrics + metrics_y = "" + metrics_out = "" + metrics_out_baseline = "" + for metric in metrics: + if y.shape == x.shape: + metrics_y += f"{metric.name} = {metric(y, x).item():.4f}" + "\n" + metrics_out += f"{metric.name} = {metric(out, x).item():.4f}" + "\n" + metrics_out_baseline += f"{metric.name} = {metric(out_baseline, x).item():.4f}" + "\n" + + ### Process y when y shape is different from x shape + if physics.name == "MRI" or "SR" in physics.name: + y_plot = physics.physics.prox_l2(physics.physics.A_adjoint(y), y, 1e4) else: - size = (ref_size * image.shape[2]//image.shape[3], ref_size) - - image = torch.nn.functional.interpolate(image, size=size, mode='bilinear') - return image - - -def torch_to_pil(image): - image = image.squeeze(0).cpu().detach().numpy() - image = image.transpose((1, 2, 0)) - image = (np.clip(image, 0, 1) * 255).astype(np.uint8) - image = PIL.Image.fromarray(image) - return image - - -def image_mod(image, noise_level, denoiser): - image = pil_to_torch(image, ref_size=256 if denoiser == 'DiffUNet' else 512) - if denoiser == 'DnCNN': - den = 
dinv.models.DnCNN() - sigma0 = 2/255 - denoiser = lambda x, sigma: den(x*sigma0/sigma)*sigma/sigma0 - elif denoiser == 'MedianFilter': - denoiser = dinv.models.MedianFilter(kernel_size=5) - elif denoiser == 'BM3D': - denoiser = dinv.models.BM3D() - elif denoiser == 'TV': - denoiser = dinv.models.TVDenoiser() - elif denoiser == 'TGV': - denoiser = dinv.models.TGVDenoiser() - elif denoiser == 'Wavelets': - denoiser = dinv.models.WaveletPrior() - elif denoiser == 'DiffUNet': - denoiser = dinv.models.DiffUNet() - elif denoiser == 'DRUNet': - denoiser = dinv.models.DRUNet() + y_plot = y.clone() + + ### Processing images for plotting: + # - clip values outside of [0,1] + # - shape (1, C, H, W) -> (C, H, W) + # - torch.Tensor object -> PIL object + process_img = partial(dinv.utils.plotting.preprocess_img, rescale_mode="clip") + to_pil = transforms.ToPILImage() + x = to_pil(process_img(x)[0].to('cpu')) + y = to_pil(process_img(y_plot)[0].to('cpu')) + out = to_pil(process_img(out)[0].to('cpu')) + out_baseline = to_pil(process_img(out_baseline)[0].to('cpu')) + + + return x, y, out, out_baseline, physics.display_saved_params(), metrics_y, metrics_out, metrics_out_baseline + +def update_random_idx_and_generate_imgs(dataset: EvalDataset, + model: EvalModel, + baseline: BaselineModel, + physics: PhysicsWithGenerator, + use_gen: bool, + metrics: List[Metric]): + idx = random.randint(0, len(dataset) - 1) # randint is inclusive on both ends + x, y, out, out_baseline, saved_params_str, metrics_y, metrics_out, metrics_out_baseline = generate_imgs(dataset, + idx, + model, + baseline, + physics, + use_gen, + metrics) + return idx, x, y, out, out_baseline, saved_params_str, metrics_y, metrics_out, metrics_out_baseline + +def save_imgs(dataset: EvalDataset, idx: int, physics: PhysicsWithGenerator, + model_a: EvalModel | BaselineModel, model_b: EvalModel | BaselineModel, + x: Image.Image, y: Image.Image, + out_a: Image.Image, out_b: Image.Image, + y_metrics_str: str, + out_a_metric_str: str, out_b_metric_str: str) -> None: + + ### PROCESS STRINGS + physics_params_str = "" + for param_name, param_value in physics.saved_params["updatable_params"].items(): + physics_params_str += f"{param_name}_{param_value}-" + physics_params_str = physics_params_str[:-1] if physics_params_str.endswith("-") else physics_params_str + y_metrics_str = y_metrics_str.replace(" = ", "_").replace("\n", "-") + y_metrics_str = y_metrics_str[:-1] if y_metrics_str.endswith("-") else y_metrics_str + out_a_metric_str = out_a_metric_str.replace(" = ", "_").replace("\n", "-") + out_a_metric_str = out_a_metric_str[:-1] if out_a_metric_str.endswith("-") else out_a_metric_str + out_b_metric_str = out_b_metric_str.replace(" = ", "_").replace("\n", "-") + out_b_metric_str = out_b_metric_str[:-1] if out_b_metric_str.endswith("-") else out_b_metric_str + + save_path = SAVE_IMG_DIR / f"{dataset.name}+{idx}+{physics.name}+{physics_params_str}+{y_metrics_str}+{model_a.name}+{out_a_metric_str}+{model_b.name}+{out_b_metric_str}.png" + titles = [f"{dataset.name}[{idx}]", + f"y = {physics.name}(x)", + f"{model_a.name}", + f"{model_b.name}"] + + # PIL object -> torch.Tensor + to_tensor = transforms.ToTensor() + x = to_tensor(x) + y = to_tensor(y) + out_a = to_tensor(out_a) + out_b = to_tensor(out_b) + + dinv.utils.plot([x, y, out_a, out_b], titles=titles, show=False, save_fn=save_path) + +# Assumed defaults: DEVICE_STR and SAVE_IMG_DIR are used below but defined nowhere else in this diff. +DEVICE_STR = "cuda" if torch.cuda.is_available() else "cpu" +SAVE_IMG_DIR = Path("saved_imgs") # hypothetical output directory for save_imgs +SAVE_IMG_DIR.mkdir(parents=True, exist_ok=True) + +get_list_metrics_on_DEVICE_STR = partial(Metric.get_list_metrics, device_str=DEVICE_STR) +get_baseline_model_on_DEVICE_STR = partial(BaselineModel, device_str=DEVICE_STR) +get_eval_model_on_DEVICE_STR = partial(EvalModel, device_str=DEVICE_STR) +get_dataset_on_DEVICE_STR = partial(EvalDataset, device_str=DEVICE_STR) +get_physics_generator_on_DEVICE_STR = partial(PhysicsWithGenerator, device_str=DEVICE_STR)
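The helpers above can also be driven outside Gradio, which is handy for debugging. A minimal smoke test, assuming the checkpoints in `ckpt/` and the samples in `img_samples/` shipped with this PR are in place (all names come from `evals.py`; nothing here is part of the diff itself):

```python
# Hypothetical smoke test for the building blocks of generate_imgs; not part of the PR.
import torch
from evals import EvalDataset, EvalModel, PhysicsWithGenerator, Metric

device = "cuda" if torch.cuda.is_available() else "cpu"
dataset = EvalDataset("Natural", device_str=device)
physics = PhysicsWithGenerator("MotionBlur_easy", device_str=device)
model = EvalModel("unext_emb_physics_config_C", ckpt_pth="", device_str=device)  # falls back to ckpt/ram_ckp_10.pth.tar
metrics = Metric.get_list_metrics(["PSNR"], device_str=device)

x = dataset[0].unsqueeze(0)            # (1, C, H, W)
y = physics(x, use_gen=True)           # degrade with freshly generated parameters
with torch.no_grad():
    x_hat = model(y=y, physics=physics.physics)
print({m.name: m(x_hat, x).item() for m in metrics})
```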
+ +def get_model(model_name, ckpt_pth): + if model_name in BaselineModel.all_baselines: + return get_baseline_model_on_DEVICE_STR(model_name) else: - raise ValueError("Invalid denoiser") - noisy = image + torch.randn_like(image) * noise_level - estimated = denoiser(noisy, noise_level) - return torch_to_pil(noisy), torch_to_pil(estimated) + return get_eval_model_on_DEVICE_STR(model_name, ckpt_pth) + + +### Gradio Blocks interface + +# Define custom CSS +custom_css = """ +.fixed-textbox textarea { + height: 90px !important; /* Adjust height to fit exactly 4 lines */ + overflow: scroll; /* Keep overflowing content scrollable */ + resize: none; /* Prevent manual resizing of the textbox */ +} +""" +title = "Inverse problem playground" # displayed on gradio tab and in the gradio page +with gr.Blocks(title=title, css=custom_css) as interface: + gr.Markdown("## " + title) -input_image = gr.Image(label='Input Image') -output_images = gr.Image(label='Denoised Image') -noise_image = gr.Image(label='Noisy Image') -input_image_output = gr.Image(label='Input Image') + # Initial state + model_a_placeholder = gr.State(lambda: get_eval_model_on_DEVICE_STR("unext_emb_physics_config_C", "")) # lambda expression to instantiate a callable in a gr.State + model_b_placeholder = gr.State(lambda: get_baseline_model_on_DEVICE_STR("DRUNET")) # lambda expression to instantiate a callable in a gr.State + dataset_placeholder = gr.State(get_dataset_on_DEVICE_STR("Natural")) # must be a member of EvalDataset.all_datasets + physics_placeholder = gr.State(lambda: get_physics_generator_on_DEVICE_STR("MotionBlur_easy")) # lambda expression to instantiate a callable in a gr.State; must be a member of PhysicsWithGenerator.all_physics + metrics_placeholder = gr.State(get_list_metrics_on_DEVICE_STR(["PSNR"])) -noise_levels = gr.Dropdown(choices=[0.05, 0.1, 0.2, 0.3, 0.5, 1], value=0.1, label='Noise Level') + @gr.render(inputs=[model_a_placeholder, model_b_placeholder, dataset_placeholder, physics_placeholder, metrics_placeholder]) + def dynamic_layout(model_a, model_b, dataset, physics, metrics): + ### LAYOUT + model_a_name = model_a.base_name + model_a_full_name = model_a.name + model_b_name = model_b.base_name + model_b_full_name = model_b.name + dataset_name = dataset.name + physics_name = physics.name + metric_names = [metric.name for metric in metrics] -denoiser = gr.Dropdown(choices=['DnCNN', 'DRUNet', 'DiffUNet', 'BM3D', 'MedianFilter', 'TV', 'TGV', 'Wavelets'], value='DRUNet', label='Denoiser') + # Components: Inputs/Outputs + Load EvalDataset/PhysicsWithGenerator/EvalModel/BaselineModel + with gr.Row(): + with gr.Column(): + with gr.Row(): + with gr.Column(): + clean = gr.Image(label=f"{dataset_name} IMAGE", interactive=False) + physics_params = gr.Textbox(label="Physics parameters", elem_classes=["fixed-textbox"], value=physics.display_saved_params()) + with gr.Column(): + y_image = gr.Image(label=f"{physics_name} IMAGE", interactive=False) + y_metrics = gr.Textbox(label="Metrics(y, x)", elem_classes=["fixed-textbox"]) + with gr.Row(): + choose_dataset = gr.Radio(choices=EvalDataset.all_datasets, + label="List of EvalDataset", + value=dataset_name, + scale=2) + idx_slider = gr.Slider(minimum=0, maximum=len(dataset)-1, step=1, label="Sample index", scale=1) + + choose_physics = gr.Radio(choices=PhysicsWithGenerator.all_physics, + label="List of PhysicsWithGenerator", + value=physics_name) + with gr.Row(): + 
key_selector = gr.Dropdown(choices=list(physics.saved_params["updatable_params"].keys()), + label="Updatable Parameter Key", + scale=2) + value_text = gr.Textbox(label="Update Value", scale=2) + with gr.Column(scale=1): + update_button = gr.Button("Update Param") + use_generator_button = gr.Checkbox(label="Use param generator") + + with gr.Column(): + with gr.Row(): + with gr.Column(): + model_a_out = gr.Image(label=f"{model_a_full_name} OUTPUT", interactive=False) + out_a_metric = gr.Textbox(label="Metrics(model_a(y), x)", elem_classes=["fixed-textbox"]) + load_model_a = gr.Button("Load model A...", scale=1) + with gr.Column(): + model_b_out = gr.Image(label=f"{model_b_full_name} OUTPUT", interactive=False) + out_b_metric = gr.Textbox(label="Metrics(model_b(y), x)", elem_classes=["fixed-textbox"]) + load_model_b = gr.Button("Load model B...", scale=1) + with gr.Row(): + choose_model_a = gr.Dropdown(choices=EvalModel.all_models + BaselineModel.all_baselines, + label="List of Model A", + value=model_a_name, + scale=2) + path_a_str = gr.Textbox(value=model_a.ckpt_pth, label="Checkpoint path", scale=3) + with gr.Row(): + choose_model_b = gr.Dropdown(choices=EvalModel.all_models + BaselineModel.all_baselines, + label="List of Model B", + value=model_b_name, + scale=2) + path_b_str = gr.Textbox(value=model_b.ckpt_pth, label="Checkpoint path", scale=3) + + # Components: Load Metric + Load/Save Buttons + with gr.Row(): + with gr.Column(): + choose_metrics = gr.CheckboxGroup(choices=Metric.all_metrics, + value=metric_names, + label="Choose the metrics you are interested in") + with gr.Column(): + load_button = gr.Button("Load images...") + load_random_button = gr.Button("Load randomly...") + save_button = gr.Button("Save images...") -demo = gr.Interface( - image_mod, - inputs=[input_image, noise_levels, denoiser], - examples=[['https://upload.wikimedia.org/wikipedia/commons/b/b4/Lionel-Messi-Argentina-2022-FIFA-World-Cup_%28cropped%29.jpg', 0.1, 'DRUNet']], - outputs=[noise_image, output_images], - title="Image Denoising with DeepInverse", - description="Denoise an image using a variety of denoisers and noise levels using the deepinverse library (https://deepinv.github.io/). We only include lightweight models like DnCNN and MedianFilter as this example is intended to be run on a CPU. We also automatically resize the input image to 512 pixels to reduce the computation time. 
For more advanced models, please run the code locally.", -) + ### Event listeners + choose_dataset.change(fn=get_dataset_on_DEVICE_STR, + inputs=choose_dataset, + outputs=dataset_placeholder) + choose_physics.change(fn=get_physics_generator_on_DEVICE_STR, + inputs=choose_physics, + outputs=physics_placeholder) + update_button.click(fn=physics.update_and_display_params, inputs=[key_selector, value_text], outputs=physics_params) + load_model_a.click(fn=get_model, + inputs=[choose_model_a, path_a_str], + outputs=model_a_placeholder) + load_model_b.click(fn=get_model, + inputs=[choose_model_b, path_b_str], + outputs=model_b_placeholder) + choose_metrics.change(fn=get_list_metrics_on_DEVICE_STR, + inputs=choose_metrics, + outputs=metrics_placeholder) + load_button.click(fn=generate_imgs, + inputs=[dataset_placeholder, + idx_slider, + model_a_placeholder, + model_b_placeholder, + physics_placeholder, + use_generator_button, + metrics_placeholder], + outputs=[clean, y_image, model_a_out, model_b_out, physics_params, y_metrics, out_a_metric, out_b_metric]) + load_random_button.click(fn=update_random_idx_and_generate_imgs, + inputs=[dataset_placeholder, + model_a_placeholder, + model_b_placeholder, + physics_placeholder, + use_generator_button, + metrics_placeholder], + outputs=[idx_slider, clean, y_image, model_a_out, model_b_out, physics_params, y_metrics, out_a_metric, out_b_metric]) + # Assumed wiring: save_button was defined above but had no handler; the input order follows save_imgs' signature. + save_button.click(fn=save_imgs, + inputs=[dataset_placeholder, + idx_slider, + physics_placeholder, + model_a_placeholder, + model_b_placeholder, + clean, + y_image, + model_a_out, + model_b_out, + y_metrics, + out_a_metric, + out_b_metric]) -demo.launch() \ No newline at end of file +if __name__ == "__main__": + interface.launch()
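The interface above keeps every heavyweight object (models, dataset, physics, metrics) in `gr.State` holders, and `@gr.render` rebuilds the widget tree whenever one of them changes. A minimal, self-contained sketch of that pattern (toy counter example, independent of this PR; requires a recent Gradio with `@gr.render`):

```python
# Toy sketch of the gr.State + @gr.render pattern used by app.py above.
import gradio as gr

with gr.Blocks() as demo:
    counter = gr.State(0)  # the "placeholder" holding the mutable object

    @gr.render(inputs=counter)
    def layout(count):
        # Re-executed every time `counter` changes, like dynamic_layout above.
        gr.Markdown(f"Rendered with count = {count}")
        bump = gr.Button("Increment")
        bump.click(lambda c: c + 1, inputs=counter, outputs=counter)

if __name__ == "__main__":
    demo.launch()
```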
diff --git a/ckpt/001_classicalSR_DF2K_s64w8_SwinIR-M_x2.pth b/ckpt/001_classicalSR_DF2K_s64w8_SwinIR-M_x2.pth new file mode 100644 index 0000000000000000000000000000000000000000..9d90c6a3282aa86dec28aa65291157aaa3d5f705 --- /dev/null +++ b/ckpt/001_classicalSR_DF2K_s64w8_SwinIR-M_x2.pth @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:2032ebf8f401dd3ce2fae5f3852117cb72101ec6ed8358faa64c2a3fa09ed4ac +size 67277475 diff --git a/ckpt/001_classicalSR_DF2K_s64w8_SwinIR-M_x4.pth b/ckpt/001_classicalSR_DF2K_s64w8_SwinIR-M_x4.pth new file mode 100755 index 0000000000000000000000000000000000000000..3747ac5da2cd6af974235d7c6c8b94e53a16a0ff --- /dev/null +++ b/ckpt/001_classicalSR_DF2K_s64w8_SwinIR-M_x4.pth @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:4e78e33f22c1aa8a773db0cf4a7381bae97c2362c717f155439ebc690cbd9215 +size 67869037 diff --git a/ckpt/drunet_deepinv_color_finetune_22k.pth b/ckpt/drunet_deepinv_color_finetune_22k.pth new file mode 100755 index 0000000000000000000000000000000000000000..39a79fd9cf1bed0dd7fc9a4ea1ce7874a87bdfbb --- /dev/null +++ b/ckpt/drunet_deepinv_color_finetune_22k.pth @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:20296845d272d3d786b89ea3c1208d5f2ceb57658a499d4dd28073cbb73508aa +size 130585443 diff --git a/ckpt/drunet_gray.pth b/ckpt/drunet_gray.pth new file mode 100755 index 0000000000000000000000000000000000000000..ef5a7eb2fa84348203608e7e7c0e040287db9dc9 --- /dev/null +++ b/ckpt/drunet_gray.pth @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:e27fb29456732c604c8ee3ac3e92ecadd2a7a4e36a8be675e92ebbf93a240de4 +size 130569961 diff --git a/ckpt/pdnet.pth.tar b/ckpt/pdnet.pth.tar new file mode 100755 index 0000000000000000000000000000000000000000..3eb9d9139a778e18251d6fdd4b7628881cca1d81 --- /dev/null +++ b/ckpt/pdnet.pth.tar @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:d9b898d9d7deac148c3ec722c956231259d13d5329e08d42ba15f212e1967fe0 +size 5756509 diff --git a/ckpt/ram_ckp_10.pth.tar b/ckpt/ram_ckp_10.pth.tar new file mode 100644 index 0000000000000000000000000000000000000000..5dbcf3686dba10cb1b7840ea510bb59ae82c627d --- /dev/null +++ b/ckpt/ram_ckp_10.pth.tar @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:b07c87641c0045112461a11723c4f32b1aa1d351ac070f76b40c0c84c86d2a01 +size 427825566 diff --git a/datasets.py b/datasets.py new file mode 100644 index 0000000000000000000000000000000000000000..0be91f749824e77f9b23659a81aad6cfa074f28b --- /dev/null +++ b/datasets.py @@ -0,0 +1,84 @@ +from pathlib import Path +from typing import Callable, Optional +import os + +import torch +from torch.utils.data import Dataset + + +class Preprocessed_fastMRI(torch.utils.data.Dataset): + """FastMRI from preprocessed data for faster loading.""" + + def __init__( + self, + root: str, + transform: Optional[Callable] = None, + preprocess: bool = False, + ) -> None: + self.root = root + self.transform = transform + self.preprocess = preprocess + + # should contain all the information to load a data sample from the storage + self.sample_identifiers = [] + + # append all filenames in self.root ending with .pt + for root, _, files in os.walk(self.root): + for file in files: + if file.endswith(".pt"): + self.sample_identifiers.append(file) + + def __len__(self) -> int: + return len(self.sample_identifiers) + + def __getitem__(self, idx: int): + fname = self.sample_identifiers[idx] + + tensor = torch.load(os.path.join(self.root, fname), weights_only=True) + img = tensor['data'].float() + + if self.transform is not None: + img = self.transform(img) + + if not self.preprocess: + return img + + else: + # remove extension and prefix from filename + fname = Path(fname).stem + return img, fname + + +class Preprocessed_LIDCIDRI(torch.utils.data.Dataset): + """LIDC-IDRI from preprocessed data for faster loading.""" + + def __init__( + self, + root: str, + transform: Optional[Callable] = None, + ) -> None: + self.root = root + self.transform = transform + + # should contain all the information to load a data sample from the storage + self.sample_identifiers = [] + + # append all filenames in self.root ending with .pt + for root, _, files in os.walk(self.root): + for file in files: + if file.endswith(".pt"): + self.sample_identifiers.append(file) + + def __len__(self) -> int: + return len(self.sample_identifiers) + + def __getitem__(self, idx: int): + fname = self.sample_identifiers[idx] + + tensor = torch.load(os.path.join(self.root, fname), weights_only=True) + img = tensor['data'].float() + + if self.transform is not None: + img = self.transform(img) + + img = img.unsqueeze(0) # add channel dim + return img diff --git a/evals.py b/evals.py new file mode 100644 index 0000000000000000000000000000000000000000..d8d3baeb2cf549c3d2288877106ffaa65de2f752 --- /dev/null +++ b/evals.py @@ -0,0 +1,564 @@ +from typing import Any, List + +import deepinv as dinv +import numpy as np +import torch +from deepinv.physics.generator import MotionBlurGenerator, SigmaGenerator +from torchvision import transforms + +from datasets import Preprocessed_fastMRI, Preprocessed_LIDCIDRI +from utils import get_model + +DEFAULT_MODEL_PARAMS = { + "in_channels": [1, 2, 3], + "grayscale": False, + "conv_type": "base", + "pool_type": "base", + "layer_scale_init_value": 1e-6, + "init_type": "ortho", + "gain_init_conv": 1.0, + "gain_init_linear": 1.0, + "drop_prob": 0.0, + "replk": False, + "mult_fact": 4, + "antialias": "gaussian", + "nc_base": 64, + "cond_type": "base", + "blind": False, + "pretrained_pth": None, + "N": 2, + 
"c_mult": 2, + "depth_encoding": 2, + "relu_in_encoding": False, + "skip_in_encoding": True +} + + +class PhysicsWithGenerator(torch.nn.Module): + """Interface between Physics, Generator and Gradio.""" + all_physics = ["MotionBlur_easy", "MotionBlur_medium", "MotionBlur_hard", "GaussianBlur", + "MRI", "CT"] + + def __init__(self, physics_name: str, device_str: str = "cpu") -> None: + super().__init__() + + self.name = physics_name + if self.name not in self.all_physics: + raise ValueError(f"{self.name} is unavailable.") + + self.sigma_generator = SigmaGenerator(sigma_min=0.001, sigma_max=0.2, device=device_str) + if self.name == "MotionBlur_easy": + psf_size = 31 + self.physics = dinv.physics.Blur(noise_model=dinv.physics.GaussianNoise(sigma=.01), padding="valid", + device=device_str) + self.physics_generator = MotionBlurGenerator((psf_size, psf_size), l=0.1, sigma=0.1, device=device_str) + SigmaGenerator(sigma_min=0.01, sigma_max=0.01, device=device_str) + self.generator = self.physics_generator + SigmaGenerator(sigma_min=0.01, sigma_max=0.01, device=device_str) + self.saved_params = {"updatable_params": {"sigma": 0.05}, + "updatable_params_converter": {"sigma": float}, + "fixed_params": {"noise_sigma_min": 0.01, "noise_sigma_max": 0.01, + "psf_size": 31, "motion_gen_l": 0.1, "motion_gen_s": 0.1}} + elif self.name == "MotionBlur_medium": + psf_size = 31 + self.physics = dinv.physics.Blur(noise_model=dinv.physics.GaussianNoise(sigma=.05), padding="valid", + device=device_str) + self.physics_generator = MotionBlurGenerator((psf_size, psf_size), l=0.6, sigma=0.5, device=device_str) + SigmaGenerator(sigma_min=0.05, sigma_max=0.05, device=device_str) + self.generator = self.physics_generator + SigmaGenerator(sigma_min=0.05, sigma_max=0.05, device=device_str) + self.saved_params = {"updatable_params": {"sigma": 0.05}, + "updatable_params_converter": {"sigma": float}, + "fixed_params": {"noise_sigma_min": 0.05, "noise_sigma_max": 0.05, + "psf_size": 31, "motion_gen_l": 0.6, "motion_gen_s": 0.5}} + elif self.name == "MotionBlur_hard": + psf_size = 31 + self.physics = dinv.physics.Blur(noise_model=dinv.physics.GaussianNoise(sigma=.1), padding="valid", + device=device_str) + self.physics_generator = MotionBlurGenerator((psf_size, psf_size), l=1.2, sigma=1.0, device=device_str) + SigmaGenerator(sigma_min=0.1, sigma_max=0.1, device=device_str) + self.generator = self.physics_generator + SigmaGenerator(sigma_min=0.1, sigma_max=0.1, device=device_str) + self.saved_params = {"updatable_params": {"sigma": 0.05}, + "updatable_params_converter": {"sigma": float}, + "fixed_params": {"noise_sigma_min": 0.1, "noise_sigma_max": 0.1, + "psf_size": 31, "motion_gen_l": 1.2, "motion_gen_s": 1.0}} + elif self.name == "GaussianBlur": + psf_size = 31 + self.physics = dinv.physics.Blur(noise_model=dinv.physics.GaussianNoise(sigma=0.05), padding="valid", + device=device_str) + self.physics_generator = GaussianBlurGenerator(psf_size=(psf_size, psf_size), num_channels=1, + device=device_str) + self.generator = self.physics_generator + self.sigma_generator + self.saved_params = {"updatable_params": {"sigma": 0.05}, + "updatable_params_converter": {"sigma": float}, + "fixed_params": {"noise_sigma_min": 0.001, "noise_sigma_max": 0.2, + "psf_size": 31, "num_channels": 1}} + elif self.name == "MRI": + self.physics = dinv.physics.MRI(img_size=(640, 320), noise_model=dinv.physics.GaussianNoise(sigma=.01), + device=device_str) + self.physics_generator = dinv.physics.generator.RandomMaskGenerator((2, 640, 320), 
acceleration_factor=4) + self.generator = self.physics_generator # + self.sigma_generator + self.saved_params = {"updatable_params": {"sigma": 0.05}, + "updatable_params_converter": {"sigma": float}, + "fixed_params": {"noise_sigma_min": 0.001, "noise_sigma_max": 0.2, + "acceleration_factor": 4}} + elif self.name == "CT": + acceleration_factor = 10 + img_h = 480 + angles = int(img_h / acceleration_factor) + # angles = torch.linspace(0, 180, steps=10) + self.physics = dinv.physics.Tomography( + img_width=img_h, + angles=angles, + circle=False, + normalize=True, + device=device_str, + noise_model=dinv.physics.GaussianNoise(sigma=1e-4), + max_iter=10, + ) + self.physics_generator = None + self.generator = self.sigma_generator + self.saved_params = {"updatable_params": {"sigma": 0.1}, + "updatable_params_converter": {"sigma": float}, + "fixed_params": {"noise_sigma_min": 0.001, "noise_sigma_max": 0., + "angles": angles, "max_iter": 10}} + + def display_saved_params(self) -> str: + """Printable version of saved_params.""" + updatable_params_str = "Updatable parameters:\n" + for param_name, param_value in self.saved_params["updatable_params"].items(): + updatable_params_str += f"\t\t{param_name} = {param_value}" + "\n" + + fixed_params_str = "Fixed parameters:\n" + for param_name, param_value in self.saved_params["fixed_params"].items(): + fixed_params_str += f"\t\t{param_name} = {param_value}" + "\n" + + return updatable_params_str + fixed_params_str + + def _update_save_params(self, key: str, value: Any) -> None: + """Update value of an existing key in save_params.""" + if key in list(self.saved_params["updatable_params"].keys()): + if type(value) == str: # value may arrive as a str representation + # convert the str representation back to the parameter's real type + value = self.saved_params["updatable_params_converter"][key](value) + elif isinstance(value, torch.Tensor): + value = value.item() # type: torch.Tensor -> float + value = float(f"{value:.4f}") # keep only 4 decimal places + self.saved_params["updatable_params"][key] = value + + def update_and_display_params(self, key, value) -> str: + """_update_save_params + update physics with saved_params + display_saved_params""" + self._update_save_params(key, value) + + if self.name == "Denoising": + self.physics.noise_model.update_parameters(**self.saved_params["updatable_params"]) + else: + self.physics.update_parameters(**self.saved_params["updatable_params"]) + + return self.display_saved_params() + + def update_saved_params_and_physics(self, **kwargs) -> None: + """Update save_params and update physics.""" + for key, value in kwargs.items(): + self._update_save_params(key, value) + + self.physics.update(**kwargs) + + def forward(self, x: torch.Tensor, use_gen: bool) -> torch.Tensor: + if self.name in ["MotionBlur_easy", "MotionBlur_medium", "MotionBlur_hard", "GaussianBlur"] and not hasattr(self.physics, "filter"): + use_gen = True + elif self.name in ["MRI"] and not hasattr(self.physics, "mask"): + use_gen = True + + if use_gen: + kwargs = self.generator.step(batch_size=x.shape[0]) # generate a set of params for each sample + self.update_saved_params_and_physics(**kwargs) + + return self.physics(x)
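The calling convention of `forward` is worth spelling out: on a blur or MRI physics that has no filter/mask yet, the first call forces a generator step even when `use_gen` is false; afterwards the stored parameters are reused unless a fresh draw is requested. A short sketch of that behaviour (shapes are illustrative):

```python
# Sketch: degradation with and without fresh parameter generation.
import torch
from evals import PhysicsWithGenerator

physics = PhysicsWithGenerator("MotionBlur_easy", device_str="cpu")
x = torch.rand(1, 3, 256, 256)

y1 = physics(x, use_gen=True)   # draws a new PSF + noise level, recorded in saved_params
y2 = physics(x, use_gen=False)  # reuses the parameters drawn above
print(physics.display_saved_params())
```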
 + + +class EvalModel(torch.nn.Module): + """Eval model. + + Is there a difference with BaselineModel? + -> BaselineModel should be models that are already trained and will have fixed weights. + -> EvalModel will change depending on different checkpoints. + """ + all_models = ["unext_emb_physics_config_C"] + + def __init__(self, model_name: str, ckpt_pth: str = "", device_str: str = "cpu") -> None: + """Load the model we want to evaluate.""" + super().__init__() + self.base_name = model_name + self.ckpt_pth = ckpt_pth + self.name = self.base_name + if self.base_name not in self.all_models: + raise ValueError(f"{self.base_name} is unavailable.") + if self.base_name == "unext_emb_physics_config_C": + if self.ckpt_pth == "": + self.ckpt_pth = "ckpt/ram_ckp_10.pth.tar" + self.model = get_model(model_name=self.base_name, + device='cpu', + **DEFAULT_MODEL_PARAMS) + + # load the checkpoint once on CPU: it holds both the weights and the epoch + ckpt = torch.load(self.ckpt_pth, map_location=lambda storage, loc: storage) + self.model.load_state_dict(ckpt['state_dict']) + self.model.to(device_str) + self.model.eval() + + # add the epoch to the model name + self.name = self.name + f"+{ckpt['epoch']}" + + def forward(self, y: torch.Tensor, physics: torch.nn.Module) -> torch.Tensor: + return self.model(y, physics=physics)
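`EvalModel.__init__` above expects `.pth.tar` checkpoints laid out as a dict with `state_dict` and `epoch` entries (the epoch is appended to the display name). A sketch of writing a compatible file (the helper name is hypothetical):

```python
# Sketch: save a checkpoint in the layout EvalModel.__init__ reads back.
import torch

def save_eval_ckpt(model: torch.nn.Module, epoch: int, path: str = "ckpt/my_model.pth.tar") -> None:
    torch.save({"state_dict": model.state_dict(), "epoch": epoch}, path)
```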
 + + +class BaselineModel(torch.nn.Module): + """Baseline model. + + Is there a difference with EvalModel? + -> BaselineModel should be models that are already trained and will have fixed weights. + -> EvalModel will change depending on different checkpoints. + """ + all_baselines = ["DRUNET", "PnP-PGD-DRUNET", "SWINIRx2", "SWINIRx4", "DPIR", + "DPIR_MRI", "DPIR_CT", "PDNET"] + + def __init__(self, model_name: str, device_str: str = "cpu") -> None: + super().__init__() + self.base_name = model_name + self.ckpt_pth = "" + self.name = self.base_name + if self.name not in self.all_baselines: + raise ValueError(f"{self.name} is unavailable.") + elif self.name == "DRUNET": + n_channels = 3 + ckpt_pth = "ckpt/drunet_deepinv_color_finetune_22k.pth" + self.model = dinv.models.DRUNet(in_channels=n_channels, + out_channels=n_channels, + device=device_str, + pretrained=ckpt_pth) + self.model.eval() # Set the model to evaluation mode + elif self.name == 'PDNET': + ckpt_pth = "ckpt/pdnet.pth.tar" + self.model = get_model(model_name='pdnet', + device=device_str) + self.model.load_state_dict(torch.load(ckpt_pth, map_location=lambda storage, loc: storage)['state_dict']) + self.model.eval() # load the weights before switching to eval mode + elif self.name == "SWINIRx2": + n_channels = 3 + scale = 2 + ckpt_pth = "ckpt/001_classicalSR_DF2K_s64w8_SwinIR-M_x2.pth" + upsampler = 'nearest+conv' if 'realSR' in ckpt_pth else 'pixelshuffle' + self.model = dinv.models.SwinIR(upscale=scale, in_chans=n_channels, img_size=64, window_size=8, + img_range=1., depths=[6, 6, 6, 6, 6, 6], embed_dim=180, + num_heads=[6, 6, 6, 6, 6, 6], + mlp_ratio=2, upsampler=upsampler, resi_connection='1conv', + pretrained=ckpt_pth) + self.model.to(device_str) + self.model.eval() # Set the model to evaluation mode + elif self.name == "SWINIRx4": + n_channels = 3 + scale = 4 + ckpt_pth = "ckpt/001_classicalSR_DF2K_s64w8_SwinIR-M_x4.pth" + upsampler = 'nearest+conv' if 'realSR' in ckpt_pth else 'pixelshuffle' + self.model = dinv.models.SwinIR(upscale=scale, in_chans=n_channels, img_size=64, window_size=8, + img_range=1., depths=[6, 6, 6, 6, 6, 6], embed_dim=180, + num_heads=[6, 6, 6, 6, 6, 6], + mlp_ratio=2, upsampler=upsampler, resi_connection='1conv', + pretrained=ckpt_pth) + self.model.to(device_str) + self.model.eval() # Set the model to evaluation mode + + elif self.name == "PnP-PGD-DRUNET": + n_channels = 3 + ckpt_pth = "ckpt/drunet_deepinv_color_finetune_22k.pth" + drunet = dinv.models.DRUNet(in_channels=n_channels, + out_channels=n_channels, + device=device_str, + pretrained=ckpt_pth) + drunet.eval() # Set the model to evaluation mode + self.model = dinv.optim.optim_builder(iteration="PGD", + prior=dinv.optim.PnP(drunet).to(device_str), + data_fidelity=dinv.optim.L2(), + max_iter=20, + params_algo={'stepsize': 1., 'g_param': .05}) + elif self.name == "DPIR": + n_channels = 3 + ckpt_pth = "ckpt/drunet_deepinv_color_finetune_22k.pth" + drunet = dinv.models.DRUNet(in_channels=n_channels, + out_channels=n_channels, + device=device_str, + pretrained=ckpt_pth) + drunet.eval() # Set the model to evaluation mode + + # Specify the denoising prior + self.prior = dinv.optim.prior.PnP(denoiser=drunet) + elif self.name == "DPIR_MRI": + class ComplexDenoiser(torch.nn.Module): + def __init__(self, denoiser): + super().__init__() + self.denoiser = denoiser + + def forward(self, x, sigma): + noisy_batch = torch.cat((x[:, 0:1, ...], x[:, 1:2, ...]), 0) + input_min = noisy_batch.min() + denoised_batch = self.denoiser(noisy_batch - input_min, sigma) + denoised_batch = denoised_batch + input_min + denoised = torch.cat((denoised_batch[0:1, ...], denoised_batch[1:2, ...]), 1) + return denoised + + # Load PnP denoiser backbone + n_channels = 1 + ckpt_pth = "ckpt/drunet_gray.pth" + drunet = dinv.models.DRUNet(in_channels=n_channels, out_channels=n_channels, device=device_str, + pretrained=ckpt_pth) + complex_drunet = ComplexDenoiser(drunet) + complex_drunet.eval() + + # Specify the denoising prior + self.prior = dinv.optim.prior.PnP(denoiser=complex_drunet) + elif self.name == "DPIR_CT": + class CTDenoiser(torch.nn.Module): + def __init__(self, denoiser): + super().__init__() + self.denoiser = denoiser + + def forward(self, x, sigma): + x = x - x.min() + denoised = self.denoiser(x, sigma) + denoised = denoised + x.min() + return denoised + + # Load PnP denoiser backbone + n_channels = 1 + ckpt_pth = "ckpt/drunet_gray.pth" + drunet = dinv.models.DRUNet(in_channels=n_channels, out_channels=n_channels, device=device_str, + pretrained=ckpt_pth) + ct_drunet = CTDenoiser(drunet) + ct_drunet.eval() + + # Specify the denoising prior + self.prior = dinv.optim.prior.PnP(denoiser=ct_drunet) + + def circular_roll(self, tensor, p_h, p_w): + return tensor.roll(shifts=(p_h, p_w), dims=(-2, -1)) + + def get_DPIR_params(self, noise_level_img, max_iter=8): + r""" + Default parameters for the DPIR Plug-and-Play algorithm. + + :param float noise_level_img: Noise level of the input image. + :return: tuple(list with denoiser noise level per iteration, list with stepsize per iteration). + """ + s1 = 49.0 / 255.0 + s2 = max(noise_level_img, 0.01) + sigma_denoiser = np.logspace(np.log10(s1), np.log10(s2), max_iter).astype( + np.float32 + ) + stepsize = (sigma_denoiser / max(0.01, noise_level_img)) ** 2 + lamb = 1 / 0.23 + return list(sigma_denoiser), list(lamb * stepsize) + + def get_DPIR_MRI_params(self, noise_level_img: float, max_iter: int = 8): + r""" + Default parameters for the DPIR Plug-and-Play algorithm. + + :param float noise_level_img: Noise level of the input image. + """ + s1 = 49.0 / 255.0 + s2 = noise_level_img + sigma_denoiser = np.logspace(np.log10(s1), np.log10(s2), max_iter).astype( + np.float32 + ) + stepsize = (sigma_denoiser / max(0.01, noise_level_img)) ** 2 + lamb = 1. 
+ return lamb, list(sigma_denoiser), list(stepsize), max_iter + + def get_DPIR_CT_params(self, noise_level_img: float, max_iter: int = 8, lip_cons: float = 1.0): + r""" + Default parameters for the DPIR Plug-and-Play algorithm. + + :param float noise_level_img: Noise level of the input image. + """ + s1 = 49.0 / 255.0 * lip_cons + s2 = noise_level_img + sigma_denoiser = np.logspace(np.log10(s1), np.log10(s2), max_iter).astype( + np.float32 + ) + stepsize = (sigma_denoiser / max(0.01, noise_level_img)) ** 2 # + lamb = 1. + return lamb, list(sigma_denoiser), list(stepsize), max_iter + + def forward(self, y: torch.Tensor, physics: torch.nn.Module) -> torch.Tensor: + if self.name == "DRUNET": + return self.model(y, sigma=physics.noise_model.sigma) + elif self.name == "PnP-PGD-DRUNET": + return self.model(y, physics=physics) + elif self.name == "DPIR": + # Set the DPIR algorithm parameters + sigma_float = physics.noise_model.sigma.item() # sigma should be a single value + max_iter = 8 + + sigma_denoiser, stepsize = self.get_DPIR_params(sigma_float, max_iter=max_iter) + params_algo = {"stepsize": stepsize, "g_param": sigma_denoiser} + early_stop = False # Do not stop algorithm with convergence criteria + + # instantiate DPIR + model = dinv.optim.optim_builder( + iteration="HQS", + prior=self.prior, + data_fidelity=dinv.optim.data_fidelity.L2(), + early_stop=early_stop, + max_iter=max_iter, + verbose=True, + params_algo=params_algo, + ) + return model(y, physics=physics) + elif self.name == "DPIR_MRI": + sigma_float = max(physics.noise_model.sigma.item(), 0.015) # sigma should be a single value + lamb, sigma_denoiser, stepsize, max_iter = self.get_DPIR_MRI_params(sigma_float, max_iter=16) + stepsize = [stepsize[0]] * max_iter + params_algo = {"stepsize": stepsize, "g_param": sigma_denoiser, "lambda": lamb} + early_stop = False # Do not stop algorithm with convergence criteria + + # Instantiate the algorithm class to solve the IP + model = dinv.optim.optim_builder( + iteration="HQS", + prior=self.prior, + data_fidelity=dinv.optim.data_fidelity.L2(), + early_stop=early_stop, + max_iter=max_iter, + verbose=True, + params_algo=params_algo, + ) + return model(y, physics=physics) + elif self.name == "DPIR_CT": + # Set the DPIR algorithm parameters + sigma_float = physics.noise_model.sigma.item() # sigma should be a single value + lip_const = physics.compute_norm(physics.A_adjoint(y)) + lamb, sigma_denoiser, stepsize, max_iter = self.get_DPIR_CT_params(sigma_float, max_iter=8, + lip_cons=lip_const.item()) + params_algo = {"stepsize": stepsize, "g_param": sigma_denoiser, "lambda": lamb} + early_stop = False # Do not stop algorithm with convergence criteria + + def custom_init(y, physic_op): + x_init = physic_op.prox_l2(physic_op.A_adjoint(y), y, gamma=1e4) + return {"est": (x_init, x_init)} + + # Instantiate the algorithm class to solve the IP + algo = dinv.optim.optim_builder( + iteration="HQS", + prior=self.prior, + data_fidelity=dinv.optim.data_fidelity.L2(), + early_stop=early_stop, + max_iter=max_iter, + verbose=True, + params_algo=params_algo, + custom_init=custom_init + ) + return algo(y, physics=physics) + elif self.name == 'SWINIRx4': + window_size = 8 + scale = 4 + _, _, h_old, w_old = y.size() + h_pad = (h_old // window_size + 1) * window_size - h_old + w_pad = (w_old // window_size + 1) * window_size - w_old + img_lq = torch.cat([y, torch.flip(y, [2])], 2)[:, :, :h_old + h_pad, :] + img_lq = torch.cat([img_lq, torch.flip(img_lq, [3])], 3)[:, :, :, :w_old + w_pad] + output = self.model(img_lq) 
+ output = output[..., :h_old * scale, :w_old * scale] + output = self.circular_roll(output, -1, -1) + # check shape of adjoint + x_adj = physics.A_adjoint(y) + output = output[..., :x_adj.size(-2), :x_adj.size(-1)] + return output + elif 'UNROLLED_DPIR' in self.name: + return self.model(y, physics=physics) + else: + return self.model(y) + + +class EvalDataset(torch.utils.data.Dataset): + """ + We expect that images are 480x480. + """ + all_datasets = ["Natural", "MRI", "CT"] + + def __init__(self, dataset_name: str, device_str: str = "cpu") -> None: + self.name = dataset_name + self.device_str = device_str + if self.name not in self.all_datasets: + raise ValueError(f"{self.name} is unavailable.") + if self.name == 'Natural': + self.root = 'img_samples/LSDIR_samples' # matches the sample folder committed in this PR + self.transform = transforms.Compose([transforms.ToTensor()]) + self.dataset = dinv.datasets.LsdirHR(root=self.root, + download=False, + transform=self.transform) + elif self.name == 'MRI': + self.root = 'img_samples/FastMRI_samples' + self.transform = transforms.CenterCrop((640, 320)) # , pad_if_needed=True) + self.dataset = Preprocessed_fastMRI(root=self.root, + transform=self.transform, + preprocess=False) + elif self.name == "CT": + self.root = 'img_samples/LIDC-IDRI_samples' + self.transform = None + self.dataset = Preprocessed_LIDCIDRI(root=self.root, + transform=self.transform) + + def __len__(self) -> int: + return len(self.dataset) + + def __getitem__(self, idx: int) -> torch.Tensor: + return self.dataset[idx].to(self.device_str) + + +class Metric: + """Metrics and utilities.""" + all_metrics = ["PSNR", "SSIM", "LPIPS"] + + def __init__(self, metric_name: str, device_str: str = "cpu") -> None: + self.name = metric_name + if self.name not in self.all_metrics: + raise ValueError(f"{self.name} is unavailable.") + elif self.name == "PSNR": + self.metric = dinv.loss.metric.PSNR() + elif self.name == "SSIM": + self.metric = dinv.loss.metric.SSIM() + elif self.name == "LPIPS": + self.metric = dinv.loss.metric.LPIPS(device=device_str) + + def __call__(self, x_net: torch.Tensor, x: torch.Tensor, *args, **kwargs) -> torch.Tensor: + # x_net and x may not have the same size, in which case both are center-cropped to the smaller size + if x_net.shape[-1] != x.shape[-1]: + min_size = min(x_net.shape[-1], x.shape[-1]) + x_net_crop = x_net[..., x_net.shape[-2] // 2 - min_size // 2: x_net.shape[-2] // 2 + min_size // 2, + x_net.shape[-1] // 2 - min_size // 2: x_net.shape[-1] // 2 + min_size // 2] + x_crop = x[..., x.shape[-2] // 2 - min_size // 2: x.shape[-2] // 2 + min_size // 2, + x.shape[-1] // 2 - min_size // 2: x.shape[-1] // 2 + min_size // 2] + else: + x_net_crop = x_net + x_crop = x + return self.metric(x_net_crop, x_crop) + + @classmethod + def get_list_metrics(cls, metric_names: List[str], device_str: str = "cpu") -> List["Metric"]: + metrics = [] + for metric_name in metric_names: + metrics.append(cls(metric_name, device_str=device_str)) + return metrics
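`Metric.__call__` above absorbs the size mismatch introduced by valid-padding blurs by center-cropping both tensors to the smaller spatial size before scoring. A quick illustration with random tensors (sizes match what a 31-pixel valid-padding PSF would produce; values are illustrative only):

```python
# Sketch: Metric tolerates small spatial mismatches by center-cropping.
import torch
from evals import Metric

psnr = Metric.get_list_metrics(["PSNR"])[0]
x = torch.rand(1, 3, 256, 256)      # reference image
x_hat = torch.rand(1, 3, 226, 226)  # e.g. output after a 31-pixel valid-padding blur
print(psnr.name, psnr(x_hat, x))    # both tensors are cropped to 226x226 before PSNR
```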
diff --git a/img_samples/FastMRI_samples/file_brain_AXT1POST_209_6001231_11.pt b/img_samples/FastMRI_samples/file_brain_AXT1POST_209_6001231_11.pt new file mode 100644 index 0000000000000000000000000000000000000000..b4fcd0608c0ab7fb3e961e5b89f68e3e88392a4a --- /dev/null +++ b/img_samples/FastMRI_samples/file_brain_AXT1POST_209_6001231_11.pt @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:3caf7165619d7c5f1e30c6ecca6f5239e318aeb3e070daacc9f8b7343d803fee +size 1639843 diff --git a/img_samples/FastMRI_samples/file_brain_AXT2_205_2050122_7.pt b/img_samples/FastMRI_samples/file_brain_AXT2_205_2050122_7.pt new file mode 100644 index 0000000000000000000000000000000000000000..43362e37808629b148e2c97a5a96277557140555 --- /dev/null +++ b/img_samples/FastMRI_samples/file_brain_AXT2_205_2050122_7.pt @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:f8dfa2385d380ccb22c8cd4a7045ee1262e57ad8ee64aa9a763d0f9fe9404116 +size 1639818 diff --git a/img_samples/FastMRI_samples/file_brain_AXT2_205_2050160_10.pt b/img_samples/FastMRI_samples/file_brain_AXT2_205_2050160_10.pt new file mode 100644 index 0000000000000000000000000000000000000000..3634bc0fd3004a335af85e5409e37d7bb48f6dec --- /dev/null +++ b/img_samples/FastMRI_samples/file_brain_AXT2_205_2050160_10.pt @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:f10dacc30dade778bc255c9c61859887282bf36b4e967d0bb2828baf2bb2a914 +size 1639823 diff --git a/img_samples/FastMRI_samples/file_brain_AXT2_210_6001888_6.pt b/img_samples/FastMRI_samples/file_brain_AXT2_210_6001888_6.pt new file mode 100644 index 0000000000000000000000000000000000000000..e6245defaa0903e0c7340a4143b0ee3c369f5cfe --- /dev/null +++ b/img_samples/FastMRI_samples/file_brain_AXT2_210_6001888_6.pt @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:5cd2d841a101fbd9cc6b37e2c7dea35d868be01977cb08165446e8feb05c22ac +size 2434442 diff --git a/img_samples/FastMRI_samples/file_brain_AXT2_210_6001947_5.pt b/img_samples/FastMRI_samples/file_brain_AXT2_210_6001947_5.pt new file mode 100644 index 0000000000000000000000000000000000000000..17b47426a3c9ef39b79845705387cdc03668c549 --- /dev/null +++ b/img_samples/FastMRI_samples/file_brain_AXT2_210_6001947_5.pt @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:e85534e07a07e07107cec4a256c3112f1ca6283f39280a603a559c182e534abb +size 2434442 diff --git a/img_samples/LIDC-IDRI_samples/LIDC-IDRI-0032_01-01-2000-NA-NA-53482_3000537.000000-NA-91689_1-236.pt b/img_samples/LIDC-IDRI_samples/LIDC-IDRI-0032_01-01-2000-NA-NA-53482_3000537.000000-NA-91689_1-236.pt new file mode 100755 index 0000000000000000000000000000000000000000..a359c79d38c34594b3ca00ee942a6ceb679a2b25 --- /dev/null +++ b/img_samples/LIDC-IDRI_samples/LIDC-IDRI-0032_01-01-2000-NA-NA-53482_3000537.000000-NA-91689_1-236.pt @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:177a0102e65a7fbf97161c431f77722dc842a202dfde0e8a4e9c75dcb9f4ab9a +size 2098888 diff --git a/img_samples/LIDC-IDRI_samples/LIDC-IDRI-0083_01-01-2000-NA-NA-22049_3000646.000000-NA-60532_1-027.pt b/img_samples/LIDC-IDRI_samples/LIDC-IDRI-0083_01-01-2000-NA-NA-22049_3000646.000000-NA-60532_1-027.pt new file mode 100755 index 0000000000000000000000000000000000000000..f6e668f5f00dab131bfcf576bd355c70bae260ee --- /dev/null +++ b/img_samples/LIDC-IDRI_samples/LIDC-IDRI-0083_01-01-2000-NA-NA-22049_3000646.000000-NA-60532_1-027.pt @@ -0,0 +1,3 @@ +version 
https://git-lfs.github.com/spec/v1 +oid sha256:002d727e599d79615f36be14a6136e6c89c6ee20a0c4bd943aff097bb1447270 +size 2098888 diff --git a/img_samples/LIDC-IDRI_samples/LIDC-IDRI-0144_01-01-2000-NA-NA-61308_3000703.000000-NA-75826_1-079.pt b/img_samples/LIDC-IDRI_samples/LIDC-IDRI-0144_01-01-2000-NA-NA-61308_3000703.000000-NA-75826_1-079.pt new file mode 100755 index 0000000000000000000000000000000000000000..8b103b177ada6a652dab6b9f1f9d981917588b6a --- /dev/null +++ b/img_samples/LIDC-IDRI_samples/LIDC-IDRI-0144_01-01-2000-NA-NA-61308_3000703.000000-NA-75826_1-079.pt @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:7a66e185bbc1e5bbd57c12b0c4c55ee3b6572154d968a8511b00564159ef0837 +size 2098888 diff --git a/img_samples/LIDC-IDRI_samples/LIDC-IDRI-0152_01-01-2000-NA-NA-78489_3000696.000000-NA-27171_1-083.pt b/img_samples/LIDC-IDRI_samples/LIDC-IDRI-0152_01-01-2000-NA-NA-78489_3000696.000000-NA-27171_1-083.pt new file mode 100755 index 0000000000000000000000000000000000000000..b34263066ef70b2cce09d6a9d13c8ab34c103444 --- /dev/null +++ b/img_samples/LIDC-IDRI_samples/LIDC-IDRI-0152_01-01-2000-NA-NA-78489_3000696.000000-NA-27171_1-083.pt @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:e0b731e35e6debaf7241ee6ae7a89915163851178b9e4d54c4104cb8a4426076 +size 2098888 diff --git a/img_samples/LIDC-IDRI_samples/LIDC-IDRI-0298_01-01-2000-NA-NA-11572_3000663.000000-NA-48288_1-004.pt b/img_samples/LIDC-IDRI_samples/LIDC-IDRI-0298_01-01-2000-NA-NA-11572_3000663.000000-NA-48288_1-004.pt new file mode 100755 index 0000000000000000000000000000000000000000..1fef5595abaa11a758f975e6ebb9f86ba909c11b --- /dev/null +++ b/img_samples/LIDC-IDRI_samples/LIDC-IDRI-0298_01-01-2000-NA-NA-11572_3000663.000000-NA-48288_1-004.pt @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:6d7b5a7b1bfaf79f203fb492e735434964c71a9408caa9c6df407579f0df6000 +size 2098888 diff --git a/img_samples/LSDIR_samples/0001000/0000007_s005.png b/img_samples/LSDIR_samples/0001000/0000007_s005.png new file mode 100755 index 0000000000000000000000000000000000000000..5c6d14b6c5be71c332fa9cee9bd0a919198c0492 Binary files /dev/null and b/img_samples/LSDIR_samples/0001000/0000007_s005.png differ diff --git a/img_samples/LSDIR_samples/0001000/0000030_s003.png b/img_samples/LSDIR_samples/0001000/0000030_s003.png new file mode 100755 index 0000000000000000000000000000000000000000..260a1eba797f9c4d580807e6f25d5b70a5751523 Binary files /dev/null and b/img_samples/LSDIR_samples/0001000/0000030_s003.png differ diff --git a/img_samples/LSDIR_samples/0001000/0000067_s005.png b/img_samples/LSDIR_samples/0001000/0000067_s005.png new file mode 100755 index 0000000000000000000000000000000000000000..caf89ff88d97f97c7a300a698dc0b11d642326a3 Binary files /dev/null and b/img_samples/LSDIR_samples/0001000/0000067_s005.png differ diff --git a/img_samples/LSDIR_samples/0001000/0000082_s003.png b/img_samples/LSDIR_samples/0001000/0000082_s003.png new file mode 100755 index 0000000000000000000000000000000000000000..cf3d5ebcf95640e80bac2d81bfbea9a610c52729 Binary files /dev/null and b/img_samples/LSDIR_samples/0001000/0000082_s003.png differ diff --git a/img_samples/LSDIR_samples/0001000/0000110_s002.png b/img_samples/LSDIR_samples/0001000/0000110_s002.png new file mode 100755 index 0000000000000000000000000000000000000000..ccfcfd92af347c7bf779974d0571575acb5b15cf Binary files /dev/null and b/img_samples/LSDIR_samples/0001000/0000110_s002.png differ diff --git 
a/img_samples/LSDIR_samples/0001000/0000125_s003.png b/img_samples/LSDIR_samples/0001000/0000125_s003.png new file mode 100755 index 0000000000000000000000000000000000000000..5235ea98bb399c8b426ec79442d4f165fe20c2d7 Binary files /dev/null and b/img_samples/LSDIR_samples/0001000/0000125_s003.png differ diff --git a/img_samples/LSDIR_samples/0001000/0000154_s007.png b/img_samples/LSDIR_samples/0001000/0000154_s007.png new file mode 100755 index 0000000000000000000000000000000000000000..944f7ac0e656134f427d8ecf90406ba77d6626ba Binary files /dev/null and b/img_samples/LSDIR_samples/0001000/0000154_s007.png differ diff --git a/img_samples/LSDIR_samples/0001000/0000247_s007.png b/img_samples/LSDIR_samples/0001000/0000247_s007.png new file mode 100755 index 0000000000000000000000000000000000000000..367cabc86e27804a1a859669fff71e7316c5de3d Binary files /dev/null and b/img_samples/LSDIR_samples/0001000/0000247_s007.png differ diff --git a/img_samples/LSDIR_samples/0001000/0000259_s003.png b/img_samples/LSDIR_samples/0001000/0000259_s003.png new file mode 100755 index 0000000000000000000000000000000000000000..08e2b7fd7b81ea00e2d1ed0867e1a2068198e19a Binary files /dev/null and b/img_samples/LSDIR_samples/0001000/0000259_s003.png differ diff --git a/img_samples/LSDIR_samples/0001000/0000405_s008.png b/img_samples/LSDIR_samples/0001000/0000405_s008.png new file mode 100755 index 0000000000000000000000000000000000000000..3e2a689df15e742e5bc33a17be5351ab40e81adb Binary files /dev/null and b/img_samples/LSDIR_samples/0001000/0000405_s008.png differ diff --git a/img_samples/LSDIR_samples/0001000/0000578_s002.png b/img_samples/LSDIR_samples/0001000/0000578_s002.png new file mode 100755 index 0000000000000000000000000000000000000000..afb1a261118337404c9ed7aef47249bd3d2202c9 Binary files /dev/null and b/img_samples/LSDIR_samples/0001000/0000578_s002.png differ diff --git a/img_samples/LSDIR_samples/0001000/0000669_s010.png b/img_samples/LSDIR_samples/0001000/0000669_s010.png new file mode 100755 index 0000000000000000000000000000000000000000..ed50886758c0077cd36b32450573e44c9212e468 Binary files /dev/null and b/img_samples/LSDIR_samples/0001000/0000669_s010.png differ diff --git a/img_samples/LSDIR_samples/0001000/0000689_s006.png b/img_samples/LSDIR_samples/0001000/0000689_s006.png new file mode 100755 index 0000000000000000000000000000000000000000..35cd0a18755e891c0273314a33fa83e2166471b5 Binary files /dev/null and b/img_samples/LSDIR_samples/0001000/0000689_s006.png differ diff --git a/img_samples/LSDIR_samples/0001000/0000715_s011.png b/img_samples/LSDIR_samples/0001000/0000715_s011.png new file mode 100755 index 0000000000000000000000000000000000000000..66c22abc6f3de01e9ba5769214181cc55030cfcf Binary files /dev/null and b/img_samples/LSDIR_samples/0001000/0000715_s011.png differ diff --git a/img_samples/LSDIR_samples/0001000/0000752_s010.png b/img_samples/LSDIR_samples/0001000/0000752_s010.png new file mode 100755 index 0000000000000000000000000000000000000000..58d7029296f9b1e47859c4ce6a0d065f01ef94a5 Binary files /dev/null and b/img_samples/LSDIR_samples/0001000/0000752_s010.png differ diff --git a/img_samples/LSDIR_samples/0001000/0000803_s012.png b/img_samples/LSDIR_samples/0001000/0000803_s012.png new file mode 100755 index 0000000000000000000000000000000000000000..c039539558bcac0ed678eba3a11a68332e54a268 Binary files /dev/null and b/img_samples/LSDIR_samples/0001000/0000803_s012.png differ diff --git a/img_samples/LSDIR_samples/0001000/0000825_s012.png 
b/img_samples/LSDIR_samples/0001000/0000825_s012.png new file mode 100755 index 0000000000000000000000000000000000000000..6dd933f0ae63570a1e9418708a3a7e50e5784f58 Binary files /dev/null and b/img_samples/LSDIR_samples/0001000/0000825_s012.png differ diff --git a/img_samples/LSDIR_samples/0001000/0000921_s012.png b/img_samples/LSDIR_samples/0001000/0000921_s012.png new file mode 100755 index 0000000000000000000000000000000000000000..165ec7582cb10159edee116e9415d144419d0e47 Binary files /dev/null and b/img_samples/LSDIR_samples/0001000/0000921_s012.png differ diff --git a/img_samples/LSDIR_samples/0001000/0000958_s004.png b/img_samples/LSDIR_samples/0001000/0000958_s004.png new file mode 100755 index 0000000000000000000000000000000000000000..92d3a25a4b020a1ad585cf53ac332e493d22f2c3 Binary files /dev/null and b/img_samples/LSDIR_samples/0001000/0000958_s004.png differ diff --git a/img_samples/LSDIR_samples/0001000/0000994_s021.png b/img_samples/LSDIR_samples/0001000/0000994_s021.png new file mode 100755 index 0000000000000000000000000000000000000000..18c077d955c23ed855ee16b35a94e13c741993bb Binary files /dev/null and b/img_samples/LSDIR_samples/0001000/0000994_s021.png differ diff --git a/img_samples/LSDIR_samples/0009000/0008033_s006.png b/img_samples/LSDIR_samples/0009000/0008033_s006.png new file mode 100755 index 0000000000000000000000000000000000000000..2089c9d25201476291096e437b105b4b03521c5a Binary files /dev/null and b/img_samples/LSDIR_samples/0009000/0008033_s006.png differ diff --git a/img_samples/LSDIR_samples/0009000/0008068_s005.png b/img_samples/LSDIR_samples/0009000/0008068_s005.png new file mode 100755 index 0000000000000000000000000000000000000000..ebcd34a13ff26de4fe5da3cf7ef88b0feb998d9c Binary files /dev/null and b/img_samples/LSDIR_samples/0009000/0008068_s005.png differ diff --git a/img_samples/LSDIR_samples/0009000/0008115_s004.png b/img_samples/LSDIR_samples/0009000/0008115_s004.png new file mode 100755 index 0000000000000000000000000000000000000000..001691c3f2a14544fc2a39198ababbb9af9ec46e Binary files /dev/null and b/img_samples/LSDIR_samples/0009000/0008115_s004.png differ diff --git a/img_samples/LSDIR_samples/0009000/0008217_s002.png b/img_samples/LSDIR_samples/0009000/0008217_s002.png new file mode 100755 index 0000000000000000000000000000000000000000..99b8c659d74848d4475dad0aafa914d2e484ec3c Binary files /dev/null and b/img_samples/LSDIR_samples/0009000/0008217_s002.png differ diff --git a/img_samples/LSDIR_samples/0009000/0008294_s010.png b/img_samples/LSDIR_samples/0009000/0008294_s010.png new file mode 100755 index 0000000000000000000000000000000000000000..55f894728e6bb18278c35e32828ecb9ad471b1ad Binary files /dev/null and b/img_samples/LSDIR_samples/0009000/0008294_s010.png differ diff --git a/img_samples/LSDIR_samples/0009000/0008315_s053.png b/img_samples/LSDIR_samples/0009000/0008315_s053.png new file mode 100755 index 0000000000000000000000000000000000000000..646c86b0046b3593c7f14f144b0bdb07fa175481 Binary files /dev/null and b/img_samples/LSDIR_samples/0009000/0008315_s053.png differ diff --git a/img_samples/LSDIR_samples/0009000/0008340_s015.png b/img_samples/LSDIR_samples/0009000/0008340_s015.png new file mode 100755 index 0000000000000000000000000000000000000000..37a4e1568ee70eb78bdefa3fb0a59f014703347a Binary files /dev/null and b/img_samples/LSDIR_samples/0009000/0008340_s015.png differ diff --git a/img_samples/LSDIR_samples/0009000/0008361_s009.png b/img_samples/LSDIR_samples/0009000/0008361_s009.png new file mode 100755 index 
0000000000000000000000000000000000000000..be04cbac03630bf8445051a199d85dae5ef46a1a Binary files /dev/null and b/img_samples/LSDIR_samples/0009000/0008361_s009.png differ diff --git a/img_samples/LSDIR_samples/0009000/0008386_s007.png b/img_samples/LSDIR_samples/0009000/0008386_s007.png new file mode 100755 index 0000000000000000000000000000000000000000..5473e039a13d4e05789a9f0071346922e796a409 Binary files /dev/null and b/img_samples/LSDIR_samples/0009000/0008386_s007.png differ diff --git a/img_samples/LSDIR_samples/0009000/0008491_s006.png b/img_samples/LSDIR_samples/0009000/0008491_s006.png new file mode 100755 index 0000000000000000000000000000000000000000..2b4b4c97f338cfe273e3e1c3d0d18f3309e42220 Binary files /dev/null and b/img_samples/LSDIR_samples/0009000/0008491_s006.png differ diff --git a/img_samples/LSDIR_samples/0009000/0008528_s007.png b/img_samples/LSDIR_samples/0009000/0008528_s007.png new file mode 100755 index 0000000000000000000000000000000000000000..e3aaf64920f0461ff4fdd9f3c9ae1afef922a12f Binary files /dev/null and b/img_samples/LSDIR_samples/0009000/0008528_s007.png differ diff --git a/img_samples/LSDIR_samples/0009000/0008571_s007.png b/img_samples/LSDIR_samples/0009000/0008571_s007.png new file mode 100755 index 0000000000000000000000000000000000000000..97cc1cfee20674a2cc4e1b5496c7fb325721b249 Binary files /dev/null and b/img_samples/LSDIR_samples/0009000/0008571_s007.png differ diff --git a/img_samples/LSDIR_samples/0009000/0008573_s012.png b/img_samples/LSDIR_samples/0009000/0008573_s012.png new file mode 100755 index 0000000000000000000000000000000000000000..e4be80c7885e1292b5b2967989ce217e1c1540f8 Binary files /dev/null and b/img_samples/LSDIR_samples/0009000/0008573_s012.png differ diff --git a/img_samples/LSDIR_samples/0009000/0008605_s007.png b/img_samples/LSDIR_samples/0009000/0008605_s007.png new file mode 100755 index 0000000000000000000000000000000000000000..ec723ceddcb315d531fe0dcb3796908089cacc6f Binary files /dev/null and b/img_samples/LSDIR_samples/0009000/0008605_s007.png differ diff --git a/img_samples/LSDIR_samples/0009000/0008611_s002.png b/img_samples/LSDIR_samples/0009000/0008611_s002.png new file mode 100755 index 0000000000000000000000000000000000000000..2e807a55ad6b291a086b0d6393e64aac7363ee0c Binary files /dev/null and b/img_samples/LSDIR_samples/0009000/0008611_s002.png differ diff --git a/img_samples/LSDIR_samples/0009000/0008631_s005.png b/img_samples/LSDIR_samples/0009000/0008631_s005.png new file mode 100755 index 0000000000000000000000000000000000000000..89595fe5c55c7558d1de4e1e618f75a5cb7ae82f Binary files /dev/null and b/img_samples/LSDIR_samples/0009000/0008631_s005.png differ diff --git a/img_samples/LSDIR_samples/0009000/0008681_s008.png b/img_samples/LSDIR_samples/0009000/0008681_s008.png new file mode 100755 index 0000000000000000000000000000000000000000..b33e4eb4a70ac58e5a28dd1b4fac9735c9c76cfc Binary files /dev/null and b/img_samples/LSDIR_samples/0009000/0008681_s008.png differ diff --git a/img_samples/LSDIR_samples/0009000/0008703_s013.png b/img_samples/LSDIR_samples/0009000/0008703_s013.png new file mode 100755 index 0000000000000000000000000000000000000000..8f5fa573a6afc818c774e204db336ee33fe0b9a8 Binary files /dev/null and b/img_samples/LSDIR_samples/0009000/0008703_s013.png differ diff --git a/img_samples/LSDIR_samples/0009000/0008714_s010.png b/img_samples/LSDIR_samples/0009000/0008714_s010.png new file mode 100755 index 0000000000000000000000000000000000000000..f9cb1067d9959cbc048f8c07306a01983575eb50 Binary files 
/dev/null and b/img_samples/LSDIR_samples/0009000/0008714_s010.png differ diff --git a/img_samples/LSDIR_samples/0009000/0008774_s004.png b/img_samples/LSDIR_samples/0009000/0008774_s004.png new file mode 100755 index 0000000000000000000000000000000000000000..086ee33dc1457b49f883fc9ffcbb77145cfdf7a5 Binary files /dev/null and b/img_samples/LSDIR_samples/0009000/0008774_s004.png differ diff --git a/img_samples/LSDIR_samples/0023000/0022020_s005.png b/img_samples/LSDIR_samples/0023000/0022020_s005.png new file mode 100755 index 0000000000000000000000000000000000000000..93a8b7954bc27f3ea6b3e4a007277d2dd3c625fb Binary files /dev/null and b/img_samples/LSDIR_samples/0023000/0022020_s005.png differ diff --git a/img_samples/LSDIR_samples/0023000/0022037_s011.png b/img_samples/LSDIR_samples/0023000/0022037_s011.png new file mode 100755 index 0000000000000000000000000000000000000000..d6e1b2306fa11f9fc3298b5ea7d17c38bafd4cd8 Binary files /dev/null and b/img_samples/LSDIR_samples/0023000/0022037_s011.png differ diff --git a/img_samples/LSDIR_samples/0023000/0022059_s008.png b/img_samples/LSDIR_samples/0023000/0022059_s008.png new file mode 100755 index 0000000000000000000000000000000000000000..7700e7ed90eb438743e8aa09a7ca0682f7fe0290 Binary files /dev/null and b/img_samples/LSDIR_samples/0023000/0022059_s008.png differ diff --git a/img_samples/LSDIR_samples/0023000/0022135_s002.png b/img_samples/LSDIR_samples/0023000/0022135_s002.png new file mode 100755 index 0000000000000000000000000000000000000000..37882f0b4741e1bbb0937ccc43006ffa3e58f298 Binary files /dev/null and b/img_samples/LSDIR_samples/0023000/0022135_s002.png differ diff --git a/img_samples/LSDIR_samples/0023000/0022240_s003.png b/img_samples/LSDIR_samples/0023000/0022240_s003.png new file mode 100755 index 0000000000000000000000000000000000000000..60fc51a647993fdae2892c9e62d9a67b61649628 Binary files /dev/null and b/img_samples/LSDIR_samples/0023000/0022240_s003.png differ diff --git a/img_samples/LSDIR_samples/0023000/0022351_s005.png b/img_samples/LSDIR_samples/0023000/0022351_s005.png new file mode 100755 index 0000000000000000000000000000000000000000..e6c43d382a1ce55b7120142526a21b296cb9e20e Binary files /dev/null and b/img_samples/LSDIR_samples/0023000/0022351_s005.png differ diff --git a/img_samples/LSDIR_samples/0023000/0022372_s013.png b/img_samples/LSDIR_samples/0023000/0022372_s013.png new file mode 100755 index 0000000000000000000000000000000000000000..c94a6d4ca4fa086dd8f4b03ed9ecb1e936968208 Binary files /dev/null and b/img_samples/LSDIR_samples/0023000/0022372_s013.png differ diff --git a/img_samples/LSDIR_samples/0023000/0022429_s006.png b/img_samples/LSDIR_samples/0023000/0022429_s006.png new file mode 100755 index 0000000000000000000000000000000000000000..2cf74d17ff8b614f174477c40fbc73399de51956 Binary files /dev/null and b/img_samples/LSDIR_samples/0023000/0022429_s006.png differ diff --git a/img_samples/LSDIR_samples/0023000/0022452_s020.png b/img_samples/LSDIR_samples/0023000/0022452_s020.png new file mode 100755 index 0000000000000000000000000000000000000000..95de3b4ee5685e1a4e84698dc291fd541207eb4f Binary files /dev/null and b/img_samples/LSDIR_samples/0023000/0022452_s020.png differ diff --git a/img_samples/LSDIR_samples/0023000/0022478_s006.png b/img_samples/LSDIR_samples/0023000/0022478_s006.png new file mode 100755 index 0000000000000000000000000000000000000000..958fce3a9eb5dcf0fa0e21b1a981be0a8d21f236 Binary files /dev/null and b/img_samples/LSDIR_samples/0023000/0022478_s006.png differ diff --git 
a/img_samples/LSDIR_samples/0023000/0022480_s007.png b/img_samples/LSDIR_samples/0023000/0022480_s007.png new file mode 100755 index 0000000000000000000000000000000000000000..8026d8511fb18cee82de225f61e3390180f7b50f Binary files /dev/null and b/img_samples/LSDIR_samples/0023000/0022480_s007.png differ diff --git a/img_samples/LSDIR_samples/0023000/0022492_s001.png b/img_samples/LSDIR_samples/0023000/0022492_s001.png new file mode 100755 index 0000000000000000000000000000000000000000..052d8080f26f77f1c060a08661dd245bd80d89af Binary files /dev/null and b/img_samples/LSDIR_samples/0023000/0022492_s001.png differ diff --git a/img_samples/LSDIR_samples/0023000/0022531_s004.png b/img_samples/LSDIR_samples/0023000/0022531_s004.png new file mode 100755 index 0000000000000000000000000000000000000000..37b4a36e8b055cf28bdb86512ada0efdc342a016 Binary files /dev/null and b/img_samples/LSDIR_samples/0023000/0022531_s004.png differ diff --git a/img_samples/LSDIR_samples/0023000/0022666_s012.png b/img_samples/LSDIR_samples/0023000/0022666_s012.png new file mode 100755 index 0000000000000000000000000000000000000000..efaa846ce619d8e1efe41ebdb36d115ca3f37f75 Binary files /dev/null and b/img_samples/LSDIR_samples/0023000/0022666_s012.png differ diff --git a/img_samples/LSDIR_samples/0023000/0022710_s015.png b/img_samples/LSDIR_samples/0023000/0022710_s015.png new file mode 100755 index 0000000000000000000000000000000000000000..d9f0826fe21cb06b696a7406e50a8fd15792c83e Binary files /dev/null and b/img_samples/LSDIR_samples/0023000/0022710_s015.png differ diff --git a/img_samples/LSDIR_samples/0023000/0022765_s011.png b/img_samples/LSDIR_samples/0023000/0022765_s011.png new file mode 100755 index 0000000000000000000000000000000000000000..e9723d85288073986ed0fdce5adcfbebb2f73733 Binary files /dev/null and b/img_samples/LSDIR_samples/0023000/0022765_s011.png differ diff --git a/img_samples/LSDIR_samples/0023000/0022781_s005.png b/img_samples/LSDIR_samples/0023000/0022781_s005.png new file mode 100755 index 0000000000000000000000000000000000000000..5454a82e5374b8d044e1b84d128de8c8e4a9bd19 Binary files /dev/null and b/img_samples/LSDIR_samples/0023000/0022781_s005.png differ diff --git a/img_samples/LSDIR_samples/0023000/0022815_s004.png b/img_samples/LSDIR_samples/0023000/0022815_s004.png new file mode 100755 index 0000000000000000000000000000000000000000..a6f1f3780a8dcaae9feed9ab282525cf8395a22a Binary files /dev/null and b/img_samples/LSDIR_samples/0023000/0022815_s004.png differ diff --git a/img_samples/LSDIR_samples/0023000/0022908_s010.png b/img_samples/LSDIR_samples/0023000/0022908_s010.png new file mode 100755 index 0000000000000000000000000000000000000000..b11dfcb9d2d49b778e507473ccbec986d45dfb15 Binary files /dev/null and b/img_samples/LSDIR_samples/0023000/0022908_s010.png differ diff --git a/img_samples/LSDIR_samples/0023000/0022932_s021.png b/img_samples/LSDIR_samples/0023000/0022932_s021.png new file mode 100755 index 0000000000000000000000000000000000000000..f85ac42d7b285dee58ca9fe63ecfdce73ade7b12 Binary files /dev/null and b/img_samples/LSDIR_samples/0023000/0022932_s021.png differ diff --git a/img_samples/LSDIR_samples/0065000/0064028_s004.png b/img_samples/LSDIR_samples/0065000/0064028_s004.png new file mode 100755 index 0000000000000000000000000000000000000000..2afb3e0da7ad02a7c983528cb2187899b1c520ca Binary files /dev/null and b/img_samples/LSDIR_samples/0065000/0064028_s004.png differ diff --git a/img_samples/LSDIR_samples/0065000/0064049_s008.png 
b/img_samples/LSDIR_samples/0065000/0064049_s008.png new file mode 100755 index 0000000000000000000000000000000000000000..302d270192c264e0c8a5bb98364a47c3792c9d95 Binary files /dev/null and b/img_samples/LSDIR_samples/0065000/0064049_s008.png differ diff --git a/img_samples/LSDIR_samples/0065000/0064051_s003.png b/img_samples/LSDIR_samples/0065000/0064051_s003.png new file mode 100755 index 0000000000000000000000000000000000000000..404af71fbf985ba5a0b908684aa7f9fd459c7808 Binary files /dev/null and b/img_samples/LSDIR_samples/0065000/0064051_s003.png differ diff --git a/img_samples/LSDIR_samples/0065000/0064278_s009.png b/img_samples/LSDIR_samples/0065000/0064278_s009.png new file mode 100755 index 0000000000000000000000000000000000000000..d67506d26c83ef874f5c172e1d65cc6454a9f047 Binary files /dev/null and b/img_samples/LSDIR_samples/0065000/0064278_s009.png differ diff --git a/img_samples/LSDIR_samples/0065000/0064346_s035.png b/img_samples/LSDIR_samples/0065000/0064346_s035.png new file mode 100755 index 0000000000000000000000000000000000000000..c0f7b912998956112f26278b6a8e448a85ea538b Binary files /dev/null and b/img_samples/LSDIR_samples/0065000/0064346_s035.png differ diff --git a/img_samples/LSDIR_samples/0065000/0064479_s012.png b/img_samples/LSDIR_samples/0065000/0064479_s012.png new file mode 100755 index 0000000000000000000000000000000000000000..ec903c6bd9d5408e585868d2edd1402fe850bcd6 Binary files /dev/null and b/img_samples/LSDIR_samples/0065000/0064479_s012.png differ diff --git a/img_samples/LSDIR_samples/0065000/0064507_s011.png b/img_samples/LSDIR_samples/0065000/0064507_s011.png new file mode 100755 index 0000000000000000000000000000000000000000..daba875300509bb52e6bc243ff9bfb920a49adeb Binary files /dev/null and b/img_samples/LSDIR_samples/0065000/0064507_s011.png differ diff --git a/img_samples/LSDIR_samples/0065000/0064559_s007.png b/img_samples/LSDIR_samples/0065000/0064559_s007.png new file mode 100755 index 0000000000000000000000000000000000000000..c63830897ab08d4272a86db39dc0cab1014443d6 Binary files /dev/null and b/img_samples/LSDIR_samples/0065000/0064559_s007.png differ diff --git a/img_samples/LSDIR_samples/0065000/0064652_s007.png b/img_samples/LSDIR_samples/0065000/0064652_s007.png new file mode 100755 index 0000000000000000000000000000000000000000..33d2cb6af3d8425e766d4f71571e5e8048a5c77a Binary files /dev/null and b/img_samples/LSDIR_samples/0065000/0064652_s007.png differ diff --git a/img_samples/LSDIR_samples/0065000/0064674_s008.png b/img_samples/LSDIR_samples/0065000/0064674_s008.png new file mode 100755 index 0000000000000000000000000000000000000000..9019d794eeff8cdc59ca1df59e2def64047718b6 Binary files /dev/null and b/img_samples/LSDIR_samples/0065000/0064674_s008.png differ diff --git a/img_samples/LSDIR_samples/0065000/0064777_s014.png b/img_samples/LSDIR_samples/0065000/0064777_s014.png new file mode 100755 index 0000000000000000000000000000000000000000..cc0fcc26f9df1f1d7d56abde8db1c3d45f2af027 Binary files /dev/null and b/img_samples/LSDIR_samples/0065000/0064777_s014.png differ diff --git a/img_samples/LSDIR_samples/0065000/0064784_s007.png b/img_samples/LSDIR_samples/0065000/0064784_s007.png new file mode 100755 index 0000000000000000000000000000000000000000..b92ef05b1d636a9883fcc70b8ef43f462f560e40 Binary files /dev/null and b/img_samples/LSDIR_samples/0065000/0064784_s007.png differ diff --git a/img_samples/LSDIR_samples/0065000/0064795_s003.png b/img_samples/LSDIR_samples/0065000/0064795_s003.png new file mode 100755 index 
0000000000000000000000000000000000000000..b0c01cbea430a22054e2885290a1ad6312598ae6 Binary files /dev/null and b/img_samples/LSDIR_samples/0065000/0064795_s003.png differ diff --git a/img_samples/LSDIR_samples/0065000/0064812_s008.png b/img_samples/LSDIR_samples/0065000/0064812_s008.png new file mode 100755 index 0000000000000000000000000000000000000000..08cafecdc00f44c9f336a3a4a0f5d10c2d92380c Binary files /dev/null and b/img_samples/LSDIR_samples/0065000/0064812_s008.png differ diff --git a/img_samples/LSDIR_samples/0065000/0064849_s003.png b/img_samples/LSDIR_samples/0065000/0064849_s003.png new file mode 100755 index 0000000000000000000000000000000000000000..31efe8609adad580ec63a1c6bf0f590225b081f0 Binary files /dev/null and b/img_samples/LSDIR_samples/0065000/0064849_s003.png differ diff --git a/img_samples/LSDIR_samples/0065000/0064863_s002.png b/img_samples/LSDIR_samples/0065000/0064863_s002.png new file mode 100755 index 0000000000000000000000000000000000000000..949efc64105234b0ae434bac59db568181b8c8e1 Binary files /dev/null and b/img_samples/LSDIR_samples/0065000/0064863_s002.png differ diff --git a/img_samples/LSDIR_samples/0065000/0064877_s002.png b/img_samples/LSDIR_samples/0065000/0064877_s002.png new file mode 100755 index 0000000000000000000000000000000000000000..54d732804f5a02e8fb535a9c9b799ce3884c614c Binary files /dev/null and b/img_samples/LSDIR_samples/0065000/0064877_s002.png differ diff --git a/img_samples/LSDIR_samples/0065000/0064928_s011.png b/img_samples/LSDIR_samples/0065000/0064928_s011.png new file mode 100755 index 0000000000000000000000000000000000000000..bf479c05bca254c518842c39ce741e990ea2092c Binary files /dev/null and b/img_samples/LSDIR_samples/0065000/0064928_s011.png differ diff --git a/img_samples/LSDIR_samples/0065000/0064980_s019.png b/img_samples/LSDIR_samples/0065000/0064980_s019.png new file mode 100755 index 0000000000000000000000000000000000000000..513839e90ffab7fc538d2213ce819090c479a30e Binary files /dev/null and b/img_samples/LSDIR_samples/0065000/0064980_s019.png differ diff --git a/img_samples/LSDIR_samples/0065000/0064987_s009.png b/img_samples/LSDIR_samples/0065000/0064987_s009.png new file mode 100755 index 0000000000000000000000000000000000000000..31dda1f532fc230cbc3d7c7f5b2ad3a9551e8fd1 Binary files /dev/null and b/img_samples/LSDIR_samples/0065000/0064987_s009.png differ diff --git a/img_samples/LSDIR_samples/0085000/0084061_s010.png b/img_samples/LSDIR_samples/0085000/0084061_s010.png new file mode 100755 index 0000000000000000000000000000000000000000..e307bf3a4e9ef148125d88165cbc3adf50f6f858 Binary files /dev/null and b/img_samples/LSDIR_samples/0085000/0084061_s010.png differ diff --git a/img_samples/LSDIR_samples/0085000/0084136_s006.png b/img_samples/LSDIR_samples/0085000/0084136_s006.png new file mode 100755 index 0000000000000000000000000000000000000000..69e99343201d9193bd671cd7e868e260613989e2 Binary files /dev/null and b/img_samples/LSDIR_samples/0085000/0084136_s006.png differ diff --git a/img_samples/LSDIR_samples/0085000/0084160_s002.png b/img_samples/LSDIR_samples/0085000/0084160_s002.png new file mode 100755 index 0000000000000000000000000000000000000000..8abeea4b775861b20f4797db453003d4a8b425f6 Binary files /dev/null and b/img_samples/LSDIR_samples/0085000/0084160_s002.png differ diff --git a/img_samples/LSDIR_samples/0085000/0084169_s009.png b/img_samples/LSDIR_samples/0085000/0084169_s009.png new file mode 100755 index 0000000000000000000000000000000000000000..5085d0c9783f28c8414e3de02e85a990943a3eac Binary files 
/dev/null and b/img_samples/LSDIR_samples/0085000/0084169_s009.png differ diff --git a/img_samples/LSDIR_samples/0085000/0084185_s011.png b/img_samples/LSDIR_samples/0085000/0084185_s011.png new file mode 100755 index 0000000000000000000000000000000000000000..99afd553c2e83751056eb8b005fbf9a5bd19ee14 Binary files /dev/null and b/img_samples/LSDIR_samples/0085000/0084185_s011.png differ diff --git a/img_samples/LSDIR_samples/0085000/0084205_s001.png b/img_samples/LSDIR_samples/0085000/0084205_s001.png new file mode 100755 index 0000000000000000000000000000000000000000..fdd66a46f44366439504d867941c01b6dd49916a Binary files /dev/null and b/img_samples/LSDIR_samples/0085000/0084205_s001.png differ diff --git a/img_samples/LSDIR_samples/0085000/0084338_s011.png b/img_samples/LSDIR_samples/0085000/0084338_s011.png new file mode 100755 index 0000000000000000000000000000000000000000..0b066c7fa37d607142cc7723cb46ef5d4095e20a Binary files /dev/null and b/img_samples/LSDIR_samples/0085000/0084338_s011.png differ diff --git a/img_samples/LSDIR_samples/0085000/0084346_s005.png b/img_samples/LSDIR_samples/0085000/0084346_s005.png new file mode 100755 index 0000000000000000000000000000000000000000..6c6c2772abd84dd2406148204b0d446cd2c3cdd3 Binary files /dev/null and b/img_samples/LSDIR_samples/0085000/0084346_s005.png differ diff --git a/img_samples/LSDIR_samples/0085000/0084442_s002.png b/img_samples/LSDIR_samples/0085000/0084442_s002.png new file mode 100755 index 0000000000000000000000000000000000000000..61447fd95b6616a56ae19a1a65ccbd92d4b420f7 Binary files /dev/null and b/img_samples/LSDIR_samples/0085000/0084442_s002.png differ diff --git a/img_samples/LSDIR_samples/0085000/0084496_s001.png b/img_samples/LSDIR_samples/0085000/0084496_s001.png new file mode 100755 index 0000000000000000000000000000000000000000..2c9a1162298ad52eab9d7499f7586ccda828ec9e Binary files /dev/null and b/img_samples/LSDIR_samples/0085000/0084496_s001.png differ diff --git a/img_samples/LSDIR_samples/0085000/0084540_s002.png b/img_samples/LSDIR_samples/0085000/0084540_s002.png new file mode 100755 index 0000000000000000000000000000000000000000..91f10d2203f8638d3bbf72e835326053b7a020b5 Binary files /dev/null and b/img_samples/LSDIR_samples/0085000/0084540_s002.png differ diff --git a/img_samples/LSDIR_samples/0085000/0084553_s007.png b/img_samples/LSDIR_samples/0085000/0084553_s007.png new file mode 100755 index 0000000000000000000000000000000000000000..2fcc9b129d4868ec9ad5efce2ccd03ad90116eae Binary files /dev/null and b/img_samples/LSDIR_samples/0085000/0084553_s007.png differ diff --git a/img_samples/LSDIR_samples/0085000/0084593_s011.png b/img_samples/LSDIR_samples/0085000/0084593_s011.png new file mode 100755 index 0000000000000000000000000000000000000000..0a6bac4e804252f69b452ea54f61638d068e0d16 Binary files /dev/null and b/img_samples/LSDIR_samples/0085000/0084593_s011.png differ diff --git a/img_samples/LSDIR_samples/0085000/0084622_s015.png b/img_samples/LSDIR_samples/0085000/0084622_s015.png new file mode 100755 index 0000000000000000000000000000000000000000..937ffe1281ca119d09481a3f73182ef30b8cc635 Binary files /dev/null and b/img_samples/LSDIR_samples/0085000/0084622_s015.png differ diff --git a/img_samples/LSDIR_samples/0085000/0084792_s006.png b/img_samples/LSDIR_samples/0085000/0084792_s006.png new file mode 100755 index 0000000000000000000000000000000000000000..75fa3b4047b5ae4e744db0904db9c1366127208a Binary files /dev/null and b/img_samples/LSDIR_samples/0085000/0084792_s006.png differ diff --git 
a/img_samples/LSDIR_samples/0085000/0084794_s012.png b/img_samples/LSDIR_samples/0085000/0084794_s012.png new file mode 100755 index 0000000000000000000000000000000000000000..4651c3c2c46011e13126449cf7a6ec8aa2880ae1 Binary files /dev/null and b/img_samples/LSDIR_samples/0085000/0084794_s012.png differ diff --git a/img_samples/LSDIR_samples/0085000/0084840_s009.png b/img_samples/LSDIR_samples/0085000/0084840_s009.png new file mode 100755 index 0000000000000000000000000000000000000000..b9e65d5c07d33ac8245cb720a12521dc193c8753 Binary files /dev/null and b/img_samples/LSDIR_samples/0085000/0084840_s009.png differ diff --git a/img_samples/LSDIR_samples/0085000/0084846_s006.png b/img_samples/LSDIR_samples/0085000/0084846_s006.png new file mode 100755 index 0000000000000000000000000000000000000000..84e2a9f0e709ecb46dea44b449b09224d623c607 Binary files /dev/null and b/img_samples/LSDIR_samples/0085000/0084846_s006.png differ diff --git a/img_samples/LSDIR_samples/0085000/0084972_s007.png b/img_samples/LSDIR_samples/0085000/0084972_s007.png new file mode 100755 index 0000000000000000000000000000000000000000..216c59fa32ae02257908ac49336923ece5ddffc0 Binary files /dev/null and b/img_samples/LSDIR_samples/0085000/0084972_s007.png differ diff --git a/img_samples/LSDIR_samples/0085000/0084978_s010.png b/img_samples/LSDIR_samples/0085000/0084978_s010.png new file mode 100755 index 0000000000000000000000000000000000000000..d4e48c412ade226fd00556f176db69480e2500d3 Binary files /dev/null and b/img_samples/LSDIR_samples/0085000/0084978_s010.png differ diff --git a/models/PDNet.py b/models/PDNet.py new file mode 100644 index 0000000000000000000000000000000000000000..1506570fd1812df95b1a1f0b1f1b308c8f61103e --- /dev/null +++ b/models/PDNet.py @@ -0,0 +1,322 @@ +from pathlib import Path + +import torch +from torch.func import vmap +from torch.utils.data import DataLoader +import deepinv as dinv +from deepinv.unfolded import unfolded_builder +from deepinv.utils.phantoms import RandomPhantomDataset, SheppLoganDataset +from deepinv.optim.optim_iterators import CPIteration, fStep, gStep +from deepinv.optim import Prior, DataFidelity +from deepinv.utils.tensorlist import TensorList + +from physics.multiscale import MultiScaleLinearPhysics +from models.heads import Heads, Tails, InHead, OutTail, ConvChannels, SNRModule, EquivConvModule, EquivHeads + + +def get_PDNet_architecture(in_channels=[1, 2, 3], out_channels=[1, 2, 3], n_primal=3, n_dual=3, device='cuda'): + class PDNetIteration(CPIteration): + r"""Single iteration of learned primal dual. + We only redefine the fStep and gStep classes. + The forward method is inherited from the CPIteration class. + """ + + def __init__(self, **kwargs): + super().__init__(**kwargs) + self.g_step = gStepPDNet(**kwargs) + self.f_step = fStepPDNet(**kwargs) + + def forward( + self, X, cur_data_fidelity, cur_prior, cur_params, y, physics, *args, **kwargs + ): + r""" + Single iteration of the Chambolle-Pock algorithm. + + :param dict X: Dictionary containing the current iterate and the estimated cost. + :param deepinv.optim.DataFidelity cur_data_fidelity: Instance of the DataFidelity class defining the current data_fidelity. + :param deepinv.optim.Prior cur_prior: Instance of the Prior class defining the current prior. + :param dict cur_params: dictionary containing the current parameters of the algorithm. + :param torch.Tensor y: Input data. + :param deepinv.physics.Physics physics: Instance of the physics modeling the data-fidelity term. 
+            :return: Dictionary `{"est": (x, z, u), "cost": F}` containing the updated current iterate and the estimated current cost.
+            """
+            x_prev, z_prev, u_prev = X["est"]  # x : primal, z : relaxed primal, u : dual
+            BS, C_primal, H_primal, W_primal = x_prev.shape
+            _, C_dual, H_dual, W_dual = u_prev.shape
+            n_channels = C_primal // n_primal
+            K = lambda x: torch.cat(
+                [physics.A(x[:, i * n_channels:(i + 1) * n_channels, :, :]) for i in range(n_primal)], dim=1)
+            K_adjoint = lambda x: torch.cat(
+                [physics.A_adjoint(x[:, i * n_channels:(i + 1) * n_channels, :, :]) for i in range(n_dual)], dim=1)
+            u = self.f_step(u_prev, K(z_prev), cur_data_fidelity, y, physics, n_channels,
+                            cur_params)  # dual update (data fidelity)
+            x = self.g_step(x_prev, K_adjoint(u), cur_prior, n_channels, cur_params)  # primal update (prior)
+            z = x + cur_params["beta"] * (x - x_prev)
+            F = (
+                self.F_fn(x, cur_data_fidelity, cur_prior, cur_params, y, physics)
+                if self.has_cost
+                else None
+            )
+            return {"est": (x, z, u), "cost": F}
+
+    class fStepPDNet(fStep):
+        r"""
+        Dual update of the PDNet algorithm.
+        We write it as a proximal operator of the data fidelity term.
+        This proximal mapping is to be replaced by a trainable model.
+        """
+
+        def __init__(self, **kwargs):
+            super().__init__(**kwargs)
+
+        def forward(self, x, w, cur_data_fidelity, y, physics, n_channels, *args):
+            r"""
+            :param torch.Tensor x: Current first variable :math:`u`.
+            :param torch.Tensor w: Current second variable :math:`A z`.
+            :param deepinv.optim.data_fidelity cur_data_fidelity: Instance of the DataFidelity class defining the current data fidelity term.
+            :param torch.Tensor y: Input data.
+            """
+            return cur_data_fidelity.prox(x, w, y, n_channels)
+
+    class gStepPDNet(gStep):
+        r"""
+        Primal update of the PDNet algorithm.
+        We write it as a proximal operator of the prior term.
+        This proximal mapping is to be replaced by a trainable model.
+        """
+
+        def __init__(self, **kwargs):
+            super().__init__(**kwargs)
+
+        def forward(self, x, w, cur_prior, n_channels, *args):
+            r"""
+            :param torch.Tensor x: Current first variable :math:`x`.
+            :param torch.Tensor w: Current second variable :math:`A^\top u`.
+            :param deepinv.optim.prior cur_prior: Instance of the Prior class defining the current prior.
+            """
+            return cur_prior.prox(x, w, n_channels)
+
+    # %%
+    # Define the trainable prior and data fidelity terms.
+    # ---------------------------------------------------
+    # Prior and data fidelity are respectively defined as subclasses of :class:`deepinv.optim.Prior`
+    # and :class:`deepinv.optim.DataFidelity`. Their proximal operators are replaced by trainable models.
+
+    class PDNetPrior(Prior):
+        def __init__(self, model, *args, **kwargs):
+            super().__init__(*args, **kwargs)
+            self.model = model
+
+        def prox(self, x, w, n_channels):
+            # the model receives the full primal variable plus the first channel block of the dual
+            dual_cond = w[:, 0:n_channels, :, :]
+            return self.model(x, dual_cond)
+
+    class PDNetDataFid(DataFidelity):
+        def __init__(self, model, *args, **kwargs):
+            super().__init__(*args, **kwargs)
+            self.model = model
+
+        def prox(self, x, w, y, n_channels):
+            # the model receives the full dual variable, the second channel block of the primal, and y,
+            # i.e. n_channels * n_dual + n_channels + n_channels input channels
+            if n_primal > 1:
+                primal_cond = w[:, n_channels:(2 * n_channels), :, :]
+            else:
+                primal_cond = w[:, 0:n_channels, :, :]
+            return self.model(x, primal_cond, y)
+
+    # Unrolled optimization algorithm parameters
+    max_iter = 10
+
+    # Set up the data fidelity term. Each layer has its own data fidelity module
+    # (see the channel-count example below).
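+    # Editor's note (not part of the original patch): a concrete instance of the
+    # channel bookkeeping below. For an RGB input (in_channel = 3) with the default
+    # n_primal = n_dual = 3, each dual block receives
+    #     3 * 3 (full dual) + 3 (one primal channel block) + 3 (y) = 15 channels
+    # and outputs 3 * 3 = 9 dual channels, while each primal block receives
+    #     3 * 3 (full primal) + 3 (one backprojected dual channel block) = 12 channels
+    # and outputs 3 * 3 = 9 primal channels.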
+    in_channels_dual = [in_channel * n_dual + in_channel + in_channel for in_channel in in_channels]
+    out_channels_dual = [in_channel * n_dual for in_channel in in_channels]
+    in_channels_primal = [in_channel * n_primal + in_channel for in_channel in in_channels]
+    out_channels_primal = [in_channel * n_primal for in_channel in in_channels]
+
+    data_fidelity = [
+        PDNetDataFid(model=PDNet_DualBlock(in_channels=in_channels_dual, out_channels=out_channels_dual).to(device))
+        for i in range(max_iter)
+    ]
+
+    # Set up the trainable prior. Each layer has its own prior module.
+    prior = [
+        PDNetPrior(model=PDNet_PrimalBlock(in_channels=in_channels_primal, out_channels=out_channels_primal).to(device))
+        for i in range(max_iter)
+    ]
+
+    # %%
+    # Define the model.
+    # -----------------
+
+    def custom_init(y, physics):
+        x0 = physics.A_dagger(y).repeat(1, n_primal, 1, 1)
+        u0 = (0 * y).repeat(1, n_dual, 1, 1)
+        return {"est": (x0, x0, u0)}
+
+    def custom_output(X):
+        x = X["est"][0]
+        n_channels = x.shape[1] // n_primal
+        if n_primal > 1:
+            return X["est"][0][:, n_channels:(2 * n_channels), :, :]
+        else:
+            return X["est"][0][:, 0:n_channels, :, :]
+
+    # %%
+    # Define the unfolded trainable model.
+    # ------------------------------------
+    # In the original learned primal-dual paper, the authors used the adjoint operator
+    # in the primal update. However, the same authors (among others) found in
+    #
+    # A. Hauptmann, J. Adler, S. Arridge, O. Öktem,
+    # Multi-scale learned iterative reconstruction,
+    # IEEE Transactions on Computational Imaging 6, 843-856, 2020,
+    #
+    # that using a filtered gradient can significantly improve both the training speed and the
+    # reconstruction quality. Following this approach, we use the filtered backprojection instead
+    # of the adjoint operator in the primal step.
+
+    model = unfolded_builder(
+        iteration=PDNetIteration(),
+        params_algo={"beta": 0.0},
+        data_fidelity=data_fidelity,
+        prior=prior,
+        max_iter=max_iter,
+        custom_init=custom_init,
+        get_output=custom_output,
+    )
+
+    return model.to(device)
+
+
+def init_weights(m):
+    if isinstance(m, torch.nn.Linear):
+        torch.nn.init.xavier_uniform_(m.weight)
+        m.bias.data.fill_(0.0)
+
+
+class PDNet_PrimalBlock(torch.nn.Module):
+    r"""
+    Primal block for the Primal-Dual unfolding model.
+
+    From https://arxiv.org/abs/1707.06474.
+
+    Primal variables are images of shape (batch_size, in_channels, height, width). The input of each
+    primal block is the concatenation of the current primal variable and the backprojected dual variable along
+    the channel dimension. The output of each primal block is the updated primal variable.
+
+    :param int, list[int] in_channels: number of input channels (one entry per supported channel count). Default: [1, 2, 3].
+    :param int, list[int] out_channels: number of output channels (one entry per supported channel count). Default: [1, 2, 3].
+    :param int depth: number of convolutional layers in the block. Default: 3.
+    :param bool bias: whether to use bias in convolutional layers. Default: True.
+    :param int nf: number of features in the convolutional layers. Default: 32.
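+
+    When ``in_channels`` / ``out_channels`` are lists, the block instantiates one convolutional
+    head/tail per channel count (:class:`InHead` / :class:`OutTail`) and routes each input to the
+    branch matching its number of channels.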
+ """ + + def __init__(self, in_channels=[1, 2, 3], out_channels=[1, 2, 3], depth=3, bias=True, nf=32): + super(PDNet_PrimalBlock, self).__init__() + + self.separate_head = isinstance(in_channels, list) + self.depth = depth + + self.in_conv = InHead(in_channels, nf, bias=bias) + # self.m_head.apply(init_weights) + + # self.in_conv = torch.nn.Conv2d( + # in_channels, nf, kernel_size=3, stride=1, padding=1, bias=bias + # ) + + self.in_conv.apply(init_weights) + self.conv_list = torch.nn.ModuleList( + [ + torch.nn.Conv2d(nf, nf, kernel_size=3, stride=1, padding=1, bias=bias) + for _ in range(self.depth - 2) + ] + ) + self.conv_list.apply(init_weights) + # self.out_conv = torch.nn.Conv2d( + # nf, out_channels, kernel_size=3, stride=1, padding=1, bias=bias + # ) + self.out_conv = OutTail(nf, out_channels, bias=bias) + self.out_conv.apply(init_weights) + + self.nl_list = torch.nn.ModuleList([torch.nn.PReLU() for _ in range(self.depth - 1)]) + + def forward(self, x, Atu): + r""" + Forward pass of the primal block. + + :param torch.Tensor x: current primal variable. + :param torch.Tensor Atu: backprojected dual variable. + :return: (:class:`torch.Tensor`) the current primal variable. + """ + primal_channels = x.shape[1] + x_in = torch.cat((x, Atu), dim=1) + + x_ = self.in_conv(x_in) + x_ = self.nl_list[0](x_) + + for i in range(self.depth - 2): + x_l = self.conv_list[i](x_) + x_ = self.nl_list[i + 1](x_l) + + return self.out_conv(x_, primal_channels) + x + + +class PDNet_DualBlock(torch.nn.Module): + r""" + Dual block for the Primal-Dual unfolding model. + + From https://arxiv.org/abs/1707.06474. + + Dual variables are images of shape (batch_size, in_channels, height, width). The input of each + primal block is the concatenation of the current dual variable with the projected primal variable and + the measurements. The output of each dual block is the current primal variable. + + :param int in_channels: number of input channels. Default: 7. + :param int out_channels: number of output channels. Default: 5. + :param int depth: number of convolutional layers in the block. Default: 3. + :param bool bias: whether to use bias in convolutional layers. Default: True. + :param int nf: number of features in the convolutional layers. Default: 32. + """ + + def __init__(self, in_channels=[1, 2, 3], out_channels=[6, 2, 3], depth=3, bias=True, nf=32): + super(PDNet_DualBlock, self).__init__() + + self.depth = depth + self.in_conv = InHead(in_channels, nf, bias=bias) + # self.in_conv = torch.nn.Conv2d( + # in_channels, nf, kernel_size=3, stride=1, padding=1, bias=bias + # ) + self.in_conv.apply(init_weights) + self.conv_list = torch.nn.ModuleList( + [ + torch.nn.Conv2d(nf, nf, kernel_size=3, stride=1, padding=1, bias=bias) + for _ in range(self.depth - 2) + ] + ) + self.conv_list.apply(init_weights) + self.out_conv = OutTail(nf, out_channels, bias=bias) + # self.out_conv = torch.nn.Conv2d( + # nf, out_channels, kernel_size=3, stride=1, padding=1, bias=bias + # ) + self.out_conv.apply(init_weights) + + self.nl_list = torch.nn.ModuleList([torch.nn.PReLU() for _ in range(self.depth - 1)]) + + def forward(self, u, Ax_cur, y): + r""" + Forward pass of the dual block. + + :param torch.Tensor u: current dual variable. + :param torch.Tensor Ax_cur: projection of the primal variable. + :param torch.Tensor y: measurements. 
+ """ + dual_channels = u.shape[1] + x_in = torch.cat((u, Ax_cur, y), dim=1) + + x_ = self.in_conv(x_in) + x_ = self.nl_list[0](x_) + + for i in range(self.depth - 2): + x_l = self.conv_list[i](x_) + x_ = self.nl_list[i + 1](x_l) + + return self.out_conv(x_, dual_channels) + u \ No newline at end of file diff --git a/models/__init__.py b/models/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/models/blocks.py b/models/blocks.py new file mode 100644 index 0000000000000000000000000000000000000000..9312e3518711889d69e4526eadc9c652e9f5d20b --- /dev/null +++ b/models/blocks.py @@ -0,0 +1,924 @@ +import math + +import numpy as np +import torch +import torch.nn as nn +import torch.nn.functional as F +from deepinv.models.unet import BFBatchNorm2d +from deepinv.physics.blur import gaussian_blur +from deepinv.physics.functional import conv2d +from deepinv.utils.tensorlist import TensorList + +from timm.models.layers import trunc_normal_, DropPath + + +def normalize(x, dim=None, eps=1e-4): + if dim is None: + dim = list(range(1, x.ndim)) + norm = torch.linalg.vector_norm(x, dim=dim, keepdim=True, dtype=torch.float32) + norm = torch.add(eps, norm, alpha=np.sqrt(norm.numel() / x.numel())) + return x / norm.to(x.dtype) + + +class TimestepEmbedding(nn.Module): + def __init__(self, hidden_size, frequency_embedding_size=256): + super().__init__() + self.mlp = nn.Sequential( + nn.Linear(frequency_embedding_size, hidden_size), + nn.SiLU(), + nn.Linear(hidden_size, hidden_size), + ) + self.frequency_embedding_size = frequency_embedding_size + + @staticmethod + def timestep_embedding(t, dim, max_period=10000): + half = dim // 2 + freqs = torch.exp( + -math.log(max_period) * torch.arange(start=0, end=half) / half + ).to(t.device) + args = t[:, None] * freqs[None] + embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1) + if dim % 2: + embedding = torch.cat( + [embedding, torch.zeros_like(embedding[:, :1])], dim=-1 + ) + return embedding + + def forward(self, t): + t_freq = self.timestep_embedding(t, self.frequency_embedding_size).to( + dtype=next(self.parameters()).dtype + ) + t_emb = self.mlp(t_freq) + return t_emb + + +class MPConv(torch.nn.Module): + def __init__(self, in_channels, out_channels, kernel): + super().__init__() + self.out_channels = out_channels + self.weight = torch.nn.Parameter(torch.randn(out_channels, in_channels, *kernel)) + + def forward(self, x, gain=1): + w = self.weight.to(torch.float32) + if self.training: + with torch.no_grad(): + self.weight.copy_(normalize(w)) # forced weight normalization + w = normalize(w) # traditional weight normalization + w = w * (gain / np.sqrt(w[0].numel())) # magnitude-preserving scaling + w = w.to(x.dtype) + if w.ndim == 2: + return x @ w.t() + assert w.ndim == 4 + return F.conv2d(x, w, padding=(w.shape[-1] // 2,)) + + +# -------------------------------------------------------------------------------------- +def mp_silu(x): + return torch.nn.functional.silu(x) / 0.596 + + +class MPFourier(torch.nn.Module): + def __init__(self, num_channels, bandwidth=1, device="cpu"): + super().__init__() + self.register_buffer( + "freqs", 2 * np.pi * torch.rand(num_channels, device=device) * bandwidth + ) + self.register_buffer( + "phases", 2 * np.pi * torch.rand(num_channels, device=device) + ) + + def forward(self, x): + y = x.to(torch.float32) + y = y.ger(self.freqs.to(torch.float32)) + y = y + self.phases.to(torch.float32) + y = y.cos() * np.sqrt(2) + return y.to(x.dtype) 
+ + +class NoiseEmbedding(torch.nn.Module): + def __init__(self, num_channels=1, emb_channels=512, device="cpu", biasfree=True): + super().__init__() + self.emb_fourier = MPFourier(num_channels, device=device) + self.emb_noise = MPConv(num_channels, emb_channels, kernel=[]) + self.biasfree = biasfree + + def forward(self, y, physics, factor): + if hasattr(physics, "noise_model") and not callable(physics.noise_model): + sigma = getattr(physics.noise_model, "sigma", 0.0) + else: + sigma = 0.0 + + if isinstance(y, TensorList): + sigma = sigma / (y[0].abs().reshape(y[0].size(0),-1).mean(1) + 1e-8) / factor + else: + sigma = sigma / (y.abs().reshape(y.size(0),-1).mean(1) + 1e-8) / factor + emb_four = self.emb_fourier(sigma) + emb = self.emb_noise(emb_four) + if self.biasfree: + emb = F.relu(emb) + else: + emb = mp_silu(emb) + return emb.unsqueeze(-1).unsqueeze(-1) + + +# -------------------------------------------------------------------------------------- +class AffineConv2d(nn.Conv2d): + def __init__( + self, + in_channels, + out_channels, + kernel_size, + mode="affine", + bias=False, + stride=1, + padding=0, + dilation=1, + groups=1, + padding_mode="circular", + blind=True, + ): + if mode == "affine": # f(a*x + 1) = a*f(x) + 1 + bias = False + super().__init__( + in_channels, + out_channels, + kernel_size, + bias=bias, + stride=stride, + padding=padding, + dilation=dilation, + groups=groups, + padding_mode=padding_mode, + ) + self.blind = blind + self.mode = mode + + def affine(self, w): + """returns new kernels that encode affine combinations""" + return ( + w.view(self.out_channels, -1).roll(1, 1).view(w.size()) + - w + + 1 / w[0, ...].numel() + ) + + def forward(self, x): + if self.mode != "affine": + return super().forward(x) + else: + kernel = ( + self.affine(self.weight) + if self.blind + else torch.cat( + (self.affine(self.weight[:, :-1, :, :]), self.weight[:, -1:, :, :]), + dim=1, + ) + ) + padding = tuple( + elt for elt in reversed(self.padding) for _ in range(2) + ) # used to translate padding arg used by Conv module to the ones used by F.pad + padding_mode = ( + self.padding_mode if self.padding_mode != "zeros" else "constant" + ) # used to translate padding_mode arg used by Conv module to the ones used by F.pad + return F.conv2d( + F.pad(x, padding, mode=padding_mode), + kernel, + stride=self.stride, + dilation=self.dilation, + groups=self.groups, + ) + + +# -------------------------------------------------------------------------------------- +def kaiser_window(beta, length): + """Return the Kaiser window of length `length` and shape parameter `beta`.""" + if beta < 0: + raise ValueError("beta must be greater than 0") + if length < 1: + raise ValueError("length must be greater than 0") + if length == 1: + return torch.tensor([1.0]) + half = (length - 1) / 2 + n = torch.arange(length) + beta = torch.tensor(beta) + return torch.i0(beta * torch.sqrt(1 - ((n - half) / half) ** 2)) / torch.i0(beta) + + +def sinc_filter(factor=2, length=11, windowed=True): + r""" + Anti-aliasing sinc filter multiplied by a Kaiser window. + + :param float factor: Downsampling factor. + :param int length: Length of the filter. 
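+    :param bool windowed: whether to multiply the sinc kernel by a Kaiser window. Default: True.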
+ """ + deltaf = 1 / factor + + n = torch.arange(length) - (length - 1) / 2 + filter = torch.sinc(n / factor) + + if windowed: + A = 2.285 * (length - 1) * 3.14 * deltaf + 7.95 + if A <= 21: + beta = 0 + elif A <= 50: + beta = 0.5842 * (A - 21) ** 0.4 + 0.07886 * (A - 21) + else: + beta = 0.1102 * (A - 8.7) + + filter = filter * kaiser_window(beta, length) + + filter = filter.unsqueeze(0) + filter = filter * filter.T + filter = filter.unsqueeze(0).unsqueeze(0) + filter = filter / filter.sum() + return filter + + +class EquivMaxPool(nn.Module): + r""" + Max pooling layer that is equivariant to translations. + + :param int kernel_size: size of the pooling window. + :param int stride: stride of the pooling operation. + :param int padding: padding to apply before pooling. + :param bool circular_padding: circular padding for the convolutional layers. + """ + + def __init__( + self, + antialias="gaussian", + factor=2, + device="cuda", + in_channels=64, + out_channels=64, + bias=False, + padding_mode="circular", + ): + super(EquivMaxPool, self).__init__() + self.antialias = antialias + if antialias == "gaussian": + self.antialias_kernel = gaussian_blur(factor / 3.14).to(device) + elif antialias == "sinc": + self.antialias_kernel = sinc_filter( + factor=factor, length=11, windowed=True + ).to(device) + + self.conv_down = AffineConv2d( + in_channels, + out_channels, + kernel_size=3, + stride=1, + padding=1, + bias=bias, + padding_mode=padding_mode, + groups=1, + ) + + self.conv_up = AffineConv2d( + out_channels, + in_channels, + kernel_size=3, + stride=1, + padding=1, + bias=bias, + padding_mode=padding_mode, + groups=1, + ) + + def forward(self, x): + return self.downscale(x) + + def downscale(self, x): + r""" + Apply the equivariant pooling. + + :param torch.Tensor x: input tensor. + """ + B, C, H, W = x.shape + + x = self.conv_down(x) + + if self.antialias == "gaussian" or self.antialias == "sinc": + x = conv2d(x, self.antialias_kernel, padding="circular") + + x1 = x[:, :, ::2, ::2].unsqueeze(0) + x2 = x[:, :, ::2, 1::2].unsqueeze(0) + x3 = x[:, :, 1::2, ::2].unsqueeze(0) + x4 = x[:, :, 1::2, 1::2].unsqueeze(0) + out = torch.cat([x1, x2, x3, x4], dim=0) # (4, B, C, H/2, W/2) + ind = torch.norm(out, dim=(2, 3, 4), p=2) # (4, B) + ind = torch.argmax(ind, dim=0) # (B) + out = out[ind, torch.arange(B), ...] # (B, C, H/2, W/2) + self.ind = ind + + return out + + def upscale(self, x): + B, C, H, W = x.shape + + out = torch.zeros((B, C, H * 2, W * 2), device=x.device) + out[:, :, ::2, ::2] = x + ind = self.ind + filter = torch.zeros((B, 1, 2, 2), device=x.device) + filter[ind == 0, :, 0, 0] = 1 + filter[ind == 1, :, 0, 1] = 1 + filter[ind == 2, :, 1, 0] = 1 + filter[ind == 3, :, 1, 1] = 1 + out = conv2d(out, filter, padding="constant") + + if self.antialias == "gaussian" or self.antialias == "sinc": + out = conv2d(out, self.antialias_kernel, padding="circular") + + out = self.conv_up(out) + return out + + +# -------------------------------------------------------------------------------------- +class ConvNextBaseBlock(nn.Module): + r""" + ConvNeXt Block mimicking DRUNet base layer (Conv + Relu + Conv) + + Args: + in_channels (int): Number of input channels. + out_channels (int): Number of output channels. + mode (str): Mode for the AffineConv2d (if needed, else ignored). + bias (bool): Whether to use bias in convolutions. Default: False. + ksize (int): Kernel size for the convolutions. Default: 7. + padding_mode (str): Padding mode for convolutions. Default: 'circular'. 
+        mult_fact (int): Multiplier factor for expanding the number of channels.
+        residual (bool): Whether to use a residual connection. Default: False.
+    """
+
+    def __init__(
+        self,
+        in_channels,
+        out_channels,
+        mode="",
+        bias=False,
+        ksize=7,
+        padding_mode="circular",
+        mult_fact=1,
+        residual=False,
+    ):
+        super().__init__()
+
+        ### DEPTHWISE SEPARABLE CONVOLUTION: (N,C,H,W) -> (N,mult_fact*C,H,W)
+        # depthwise conv with big kernel
+        self.dwconv_a = AffineConv2d(
+            in_channels,
+            in_channels,
+            kernel_size=ksize,
+            padding=ksize // 2,
+            groups=in_channels,
+            padding_mode=padding_mode,
+            bias=bias,
+            mode=mode,
+        )
+        # depthwise conv with small kernel
+        self.dwconv_a_small = AffineConv2d(
+            in_channels,
+            in_channels,
+            kernel_size=3,
+            padding=3 // 2,
+            groups=in_channels,
+            padding_mode=padding_mode,
+            bias=bias,
+            mode=mode,
+        )
+        # pointwise conv to change number of channels
+        self.pwconv_a1 = AffineConv2d(
+            in_channels,
+            mult_fact * in_channels,
+            kernel_size=1,
+            stride=1,
+            padding=0,
+            mode=mode,
+            bias=bias,
+            padding_mode=padding_mode,
+            groups=1,
+        )
+
+        ### ACTIVATION
+        self.act_a = nn.ReLU()
+
+        ### POINTWISE CONVOLUTION: (N,mult_fact*C,H,W) -> (N,O,H,W)
+        self.pwconv_a2 = AffineConv2d(
+            mult_fact * in_channels,
+            out_channels,
+            kernel_size=1,
+            stride=1,
+            padding=0,
+            bias=bias,
+            padding_mode=padding_mode,
+            groups=1,
+        )
+
+        ### Needed to match the number of channels: (N,C,H,W) -> (N,O,H,W)
+        self.residual = residual
+        if self.residual:
+            self.residual_conv = AffineConv2d(
+                in_channels,
+                out_channels,
+                kernel_size=1,
+                stride=1,
+                padding=0,
+                groups=1,
+                padding_mode=padding_mode,
+                bias=bias,
+                mode=mode,
+            )
+
+    def forward(self, x_in, stream1=None, stream2=None):
+        """Forward pass, with optional GPU parallelization using two CUDA streams."""
+
+        if stream1 is not None and stream2 is not None:
+            # Use the streams
+            with torch.cuda.stream(stream1):
+                output_a = self.dwconv_a(x_in)  # run the large-kernel depthwise conv in stream1
+
+            with torch.cuda.stream(stream2):
+                output_a_small = self.dwconv_a_small(
+                    x_in
+                )  # run the small-kernel depthwise conv in stream2
+
+            # Ensure the streams are synchronized before adding the results
+            torch.cuda.synchronize()
+            x = self.pwconv_a1(output_a + output_a_small)
+
+        else:
+            x = self.dwconv_a(x_in) + self.dwconv_a_small(x_in)  # RepLK-style: 7x7 kernel with a parallel 3x3 branch
+            x = self.pwconv_a1(x)
+
+        x = self.act_a(x)
+        x = self.pwconv_a2(x)  # (N,O,H,W)
+
+        if self.residual:
+            x = self.residual_conv(x_in) + x
+
+        return x
+
+
+class ConvNextBlock2(nn.Module):
+    r"""
+    ConvNeXt block mimicking the DRUNet base layer (Conv + ReLU + Conv), built from two
+    :class:`ConvNextBaseBlock` layers with a residual connection.
+
+    Args:
+        in_channels (int): Number of input channels.
+        out_channels (int): Number of output channels.
+        mode (str): Mode for the AffineConv2d. Default: "affine".
+        bias (bool): Whether to use bias in convolutions. Default: False.
+        ksize (int): Kernel size of the large depthwise convolutions. Default: 7.
+        padding_mode (str): Padding mode for convolutions. Default: 'circular'.
+        mult_fact (int): Channel expansion factor inside each base block. Default: 4.
+        s1, s2 (torch.cuda.Stream, optional): CUDA streams used to parallelize the base blocks.
+    """
+
+    def __init__(
+        self,
+        in_channels,
+        out_channels,
+        mode="affine",
+        bias=False,
+        ksize=7,
+        padding_mode="circular",
+        mult_fact=4,
+        s1=None,
+        s2=None,
+    ):
+        super().__init__()
+        self.block_0 = ConvNextBaseBlock(
+            in_channels,
+            out_channels,
+            mode=mode,
+            bias=bias,
+            ksize=ksize,
+            padding_mode=padding_mode,
+            mult_fact=mult_fact,
+        )
+        self.block_1 = ConvNextBaseBlock(
+            in_channels,
+            out_channels,
+            mode=mode,
+            bias=bias,
+            ksize=ksize,
+            padding_mode=padding_mode,
+            mult_fact=mult_fact,
+        )
+        # nn.ReLU(inplace=True) caused issues when running the network in FP16,
+        # so the out-of-place version is used instead.
+ self.relu = nn.ReLU() + + # cuda stream to parallelize execution of ConvNextBaseBlock + self.s1 = s1 + self.s2 = s2 + + def forward(self, input, emb_sigma=None): + if self.s1 is not None and self.s2 is not None: + x = self.block_0(input, self.s1, self.s2) + else: + x = self.block_0(input) + + x = self.relu(x) + + if self.s1 is not None and self.s2 is not None: + x = self.block_1(x, self.s1, self.s2) + else: + x = self.block_1(x) + return x + input + + +class CondResBlock(nn.Module): + def __init__( + self, + in_channels=64, + out_channels=64, + kernel_size=3, + stride=1, + padding=1, + bias=False, + emb_channels=512, + ): + super(CondResBlock, self).__init__() + + assert in_channels == out_channels, "Only support in_channels==out_channels." + + self.gain = torch.nn.Parameter(torch.tensor([1.0]), requires_grad=True) + self.emb_linear = MPConv(emb_channels, out_channels, kernel=[3, 3]) + self.conv1 = nn.Conv2d( + in_channels, out_channels, kernel_size, stride, padding, bias=bias + ) + self.conv2 = nn.Conv2d( + out_channels, out_channels, kernel_size, stride, padding, bias=bias + ) + + def forward(self, x, emb_sigma): + # u = self.conv1(mp_silu(x)) + u = self.conv1(F.relu((x))) + c = self.emb_linear(emb_sigma, gain=self.gain) + 1 + # y = mp_silu(u * c.unsqueeze(2).unsqueeze(3).to(u.dtype)) + y = F.relu(u * c.unsqueeze(2).unsqueeze(3).to(u.dtype)) + y = self.conv2(y) + return x + y + + +""" +Functional blocks below +""" +from collections import OrderedDict +import torch +import torch.nn as nn + + +""" +# -------------------------------------------- +# Advanced nn.Sequential +# https://github.com/xinntao/BasicSR +# -------------------------------------------- +""" + + +def sequential(*args): + """Advanced nn.Sequential. + Args: + nn.Sequential, nn.Module + Returns: + nn.Sequential + """ + if len(args) == 1: + if isinstance(args[0], OrderedDict): + raise NotImplementedError("sequential does not support OrderedDict input.") + return args[0] # No sequential is needed. 
+    modules = []
+    for module in args:
+        if isinstance(module, nn.Sequential):
+            for submodule in module.children():
+                modules.append(submodule)
+        elif isinstance(module, nn.Module):
+            modules.append(module)
+    return nn.Sequential(*modules)
+
+
+"""
+# --------------------------------------------
+# Useful blocks
+# https://github.com/xinntao/BasicSR
+# --------------------------------------------
+# conv + normalization + relu (conv)
+# (PixelUnShuffle)
+# (ConditionalBatchNorm2d)
+# concat (ConcatBlock)
+# sum (ShortcutBlock)
+# resblock (ResBlock)
+# Channel Attention (CA) Layer (CALayer)
+# Residual Channel Attention Block (RCABlock)
+# Residual Channel Attention Group (RCAGroup)
+# Residual Dense Block (ResidualDenseBlock_5C)
+# Residual in Residual Dense Block (RRDB)
+# --------------------------------------------
+"""
+
+
+# --------------------------------------------
+# return nn.Sequential of (Conv + BN + ReLU)
+# --------------------------------------------
+def conv(
+    in_channels=64,
+    out_channels=64,
+    kernel_size=3,
+    stride=1,
+    padding=1,
+    bias=True,
+    mode="CBR",
+    negative_slope=0.2,
+):
+    L = []
+    for t in mode:
+        if t == "C":
+            L.append(
+                nn.Conv2d(
+                    in_channels=in_channels,
+                    out_channels=out_channels,
+                    kernel_size=kernel_size,
+                    stride=stride,
+                    padding=padding,
+                    bias=bias,
+                )
+            )
+        elif t == "T":
+            L.append(
+                nn.ConvTranspose2d(
+                    in_channels=in_channels,
+                    out_channels=out_channels,
+                    kernel_size=kernel_size,
+                    stride=stride,
+                    padding=padding,
+                    bias=bias,
+                )
+            )
+        elif t == "B":
+            L.append(nn.BatchNorm2d(out_channels, momentum=0.9, eps=1e-04, affine=True))
+        elif t == "I":
+            L.append(nn.InstanceNorm2d(out_channels, affine=True))
+        elif t == "R":
+            L.append(nn.ReLU(inplace=True))
+        elif t == "r":
+            L.append(nn.ReLU(inplace=False))
+        elif t == "L":
+            L.append(nn.LeakyReLU(negative_slope=negative_slope, inplace=True))
+        elif t == "l":
+            L.append(nn.LeakyReLU(negative_slope=negative_slope, inplace=False))
+        elif t == "E":
+            L.append(nn.ELU(inplace=False))
+        elif t == "s":
+            L.append(nn.Softplus())
+        elif t == "2":
+            L.append(nn.PixelShuffle(upscale_factor=2))
+        elif t == "3":
+            L.append(nn.PixelShuffle(upscale_factor=3))
+        elif t == "4":
+            L.append(nn.PixelShuffle(upscale_factor=4))
+        elif t == "U":
+            L.append(nn.Upsample(scale_factor=2, mode="nearest"))
+        elif t == "u":
+            L.append(nn.Upsample(scale_factor=3, mode="nearest"))
+        elif t == "v":
+            L.append(nn.Upsample(scale_factor=4, mode="nearest"))
+        elif t == "M":
+            L.append(nn.MaxPool2d(kernel_size=kernel_size, stride=stride, padding=0))
+        elif t == "A":
+            L.append(nn.AvgPool2d(kernel_size=kernel_size, stride=stride, padding=0))
+        else:
+            raise NotImplementedError("Undefined type: {}".format(t))
+    return sequential(*L)
+
+
+"""
+# --------------------------------------------
+# Upsampler
+# Kai Zhang, https://github.com/cszn/KAIR
+# --------------------------------------------
+# upsample_pixelshuffle
+# upsample_upconv
+# upsample_convtranspose
+# --------------------------------------------
+"""
+
+
+# --------------------------------------------
+# conv + subp (+ relu)
+# --------------------------------------------
+def upsample_pixelshuffle(
+    in_channels=64,
+    out_channels=3,
+    kernel_size=3,
+    stride=1,
+    padding=1,
+    bias=True,
+    mode="2R",
+    negative_slope=0.2,
+):
+    assert len(mode) < 4 and mode[0] in [
+        "2",
+        "3",
+        "4",
+    ], "mode examples: 2, 2R, 2BR, 3, ..., 4BR."
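+    # Editor's note (not part of the original patch): `conv` treats `mode` as a chain
+    # of layer codes, so the call below with the default mode="2R" builds mode="C2R",
+    # i.e. Conv2d -> PixelShuffle(2) -> ReLU; the channel count is pre-multiplied by
+    # int(mode[0]) ** 2 so that PixelShuffle lands on `out_channels`.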
+ up1 = conv( + in_channels, + out_channels * (int(mode[0]) ** 2), + kernel_size, + stride, + padding, + bias, + mode="C" + mode, + negative_slope=negative_slope, + ) + return up1 + + +# -------------------------------------------- +# nearest_upsample + conv (+ R) +# -------------------------------------------- +def upsample_upconv( + in_channels=64, + out_channels=3, + kernel_size=3, + stride=1, + padding=1, + bias=True, + mode="2R", + negative_slope=0.2, +): + assert len(mode) < 4 and mode[0] in [ + "2", + "3", + "4", + ], "mode examples: 2, 2R, 2BR, 3, ..., 4BR" + if mode[0] == "2": + uc = "UC" + elif mode[0] == "3": + uc = "uC" + elif mode[0] == "4": + uc = "vC" + mode = mode.replace(mode[0], uc) + up1 = conv( + in_channels, + out_channels, + kernel_size, + stride, + padding, + bias, + mode=mode, + negative_slope=negative_slope, + ) + return up1 + + +# -------------------------------------------- +# convTranspose (+ relu) +# -------------------------------------------- +def upsample_convtranspose( + in_channels=64, + out_channels=3, + kernel_size=2, + stride=2, + padding=0, + bias=True, + mode="2R", + negative_slope=0.2, +): + assert len(mode) < 4 and mode[0] in [ + "2", + "3", + "4", + "8", + ], "mode examples: 2, 2R, 2BR, 3, ..., 4BR." + kernel_size = int(mode[0]) + stride = int(mode[0]) + mode = mode.replace(mode[0], "T") + up1 = conv( + in_channels, + out_channels, + kernel_size, + stride, + padding, + bias, + mode, + negative_slope, + ) + return up1 + + +""" +# -------------------------------------------- +# Downsampler +# Kai Zhang, https://github.com/cszn/KAIR +# -------------------------------------------- +# downsample_strideconv +# downsample_maxpool +# downsample_avgpool +# -------------------------------------------- +""" + + +# -------------------------------------------- +# strideconv (+ relu) +# -------------------------------------------- +def downsample_strideconv( + in_channels=64, + out_channels=64, + kernel_size=2, + stride=2, + padding=0, + bias=True, + mode="2R", + negative_slope=0.2, +): + assert len(mode) < 4 and mode[0] in [ + "2", + "3", + "4", + "8", + ], "mode examples: 2, 2R, 2BR, 3, ..., 4BR." + kernel_size = int(mode[0]) + stride = int(mode[0]) + mode = mode.replace(mode[0], "C") + down1 = conv( + in_channels, + out_channels, + kernel_size, + stride, + padding, + bias, + mode, + negative_slope, + ) + return down1 + + +# -------------------------------------------- +# maxpooling + conv (+ relu) +# -------------------------------------------- +def downsample_maxpool( + in_channels=64, + out_channels=64, + kernel_size=3, + stride=1, + padding=0, + bias=True, + mode="2R", + negative_slope=0.2, +): + assert len(mode) < 4 and mode[0] in [ + "2", + "3", + ], "mode examples: 2, 2R, 2BR, 3, ..., 3BR." 
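+    # Editor's note (not part of the original patch): here mode="2R" is decomposed
+    # into a pooling stage and a conv tail: replacing the leading "2" with "MC"
+    # yields mode="MCR", whose first code "M" builds MaxPool2d(kernel_size=2,
+    # stride=2) and whose remainder "CR" builds Conv2d -> ReLU.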
+ kernel_size_pool = int(mode[0]) + stride_pool = int(mode[0]) + mode = mode.replace(mode[0], "MC") + pool = conv( + kernel_size=kernel_size_pool, + stride=stride_pool, + mode=mode[0], + negative_slope=negative_slope, + ) + pool_tail = conv( + in_channels, + out_channels, + kernel_size, + stride, + padding, + bias, + mode=mode[1:], + negative_slope=negative_slope, + ) + return sequential(pool, pool_tail) + + +# -------------------------------------------- +# averagepooling + conv (+ relu) +# -------------------------------------------- +def downsample_avgpool( + in_channels=64, + out_channels=64, + kernel_size=3, + stride=1, + padding=1, + bias=True, + mode="2R", + negative_slope=0.2, +): + assert len(mode) < 4 and mode[0] in [ + "2", + "3", + ], "mode examples: 2, 2R, 2BR, 3, ..., 3BR." + kernel_size_pool = int(mode[0]) + stride_pool = int(mode[0]) + mode = mode.replace(mode[0], "AC") + pool = conv( + kernel_size=kernel_size_pool, + stride=stride_pool, + mode=mode[0], + negative_slope=negative_slope, + ) + pool_tail = conv( + in_channels, + out_channels, + kernel_size, + stride, + padding, + bias, + mode=mode[1:], + negative_slope=negative_slope, + ) + return sequential(pool, pool_tail) \ No newline at end of file diff --git a/models/heads.py b/models/heads.py new file mode 100644 index 0000000000000000000000000000000000000000..b888541fb9465096cc59f76b54707eee1eb9dd43 --- /dev/null +++ b/models/heads.py @@ -0,0 +1,270 @@ +import torch +from models.blocks import AffineConv2d, downsample_strideconv, upsample_convtranspose + +class InHead(torch.nn.Module): + def __init__(self, in_channels_list, out_channels, mode="", bias=False, input_layer=False): + super(InHead, self).__init__() + self.in_channels_list = in_channels_list + self.input_layer = input_layer + for i, in_channels in enumerate(in_channels_list): + conv = AffineConv2d( + in_channels=in_channels, + out_channels=out_channels, + bias=bias, + mode=mode, + kernel_size=3, + stride=1, + padding=1, + padding_mode="zeros", + ) + setattr(self, f"conv{i}", conv) + + def forward(self, x): + in_channels = x.size(1) - 1 if self.input_layer else x.size(1) + + # find index + i = self.in_channels_list.index(in_channels) + x = getattr(self, f"conv{i}")(x) + + return x + +class OutTail(torch.nn.Module): + def __init__(self, in_channels, out_channels_list, mode="", bias=False): + super(OutTail, self).__init__() + self.in_channels = in_channels + self.out_channels_list = out_channels_list + for i, out_channels in enumerate(out_channels_list): + conv = AffineConv2d( + in_channels=in_channels, + out_channels=out_channels, + bias=bias, + mode=mode, + kernel_size=3, + stride=1, + padding=1, + padding_mode="zeros", + ) + setattr(self, f"conv{i}", conv) + + def forward(self, x, out_channels): + i = self.out_channels_list.index(out_channels) + x = getattr(self, f"conv{i}")(x) + + return x + +# TODO: check that the heads are compatible with the old implementation +class Heads(torch.nn.Module): + def __init__(self, in_channels_list, out_channels, depth=2, scale=1, bias=True, mode="bilinear", c_mult=1, c_add=0, relu_in=False, skip_in=False): + super(Heads, self).__init__() + self.in_channels_list = [c * (c_mult + c_add) for c in in_channels_list] + self.scale = scale + self.mode = mode + for i, in_channels in enumerate(self.in_channels_list): + setattr(self, f"head{i}", HeadBlock(in_channels, out_channels, depth=depth, bias=bias, relu_in=relu_in, skip_in=skip_in)) + + if self.mode == "": + self.nl = torch.nn.ReLU(inplace=False) + if self.scale != 1: + for i, 
in_channels in enumerate(in_channels_list):
+                    setattr(self, f"down{i}", downsample_strideconv(in_channels, in_channels, bias=False, mode=str(self.scale)))
+
+    def forward(self, x):
+        in_channels = x.size(1)
+        i = self.in_channels_list.index(in_channels)
+
+        if self.scale != 1:
+            if self.mode == "bilinear":
+                x = torch.nn.functional.interpolate(x, scale_factor=1/self.scale, mode='bilinear', align_corners=False)
+            else:
+                x = getattr(self, f"down{i}")(x)
+                x = self.nl(x)
+
+        x = getattr(self, f"head{i}")(x)
+
+        return x
+
+class Tails(torch.nn.Module):
+    def __init__(self, in_channels, out_channels_list, depth=2, scale=1, bias=True, mode="bilinear", c_mult=1, relu_in=False, skip_in=False):
+        super(Tails, self).__init__()
+        self.out_channels_list = out_channels_list
+        self.scale = scale
+        for i, out_channels in enumerate(out_channels_list):
+            setattr(self, f"tail{i}", HeadBlock(in_channels, out_channels * c_mult, depth=depth, bias=bias, relu_in=relu_in, skip_in=skip_in))
+
+        self.mode = mode
+        if self.mode == "":
+            self.nl = torch.nn.ReLU(inplace=False)
+            if self.scale != 1:
+                for i, out_channels in enumerate(out_channels_list):
+                    setattr(self, f"up{i}", upsample_convtranspose(out_channels * c_mult, out_channels * c_mult, bias=bias, mode=str(self.scale)))
+
+    def forward(self, x, out_channels):
+        i = self.out_channels_list.index(out_channels)
+        x = getattr(self, f"tail{i}")(x)
+        if self.scale != 1:
+            if self.mode == "bilinear":
+                x = torch.nn.functional.interpolate(x, scale_factor=self.scale, mode='bilinear', align_corners=False)
+            else:
+                x = getattr(self, f"up{i}")(x)
+
+        return x
+
+class ConvChannels(torch.nn.Module):
+    """
+    TODO: replace this with convconv
+    A module that applies a fixed conv-ReLU-conv block, selecting the convolutions
+    that match the number of channels of its input.
+ """ + def __init__(self, channels_list, depth=2, bias=False, residual=False): + super(ConvChannels, self).__init__() + self.channels_list = channels_list + self.residual = residual + for i, channels in enumerate(channels_list): + setattr(self, f"conv{i}_1", torch.nn.Conv2d(channels, channels, 3, bias=bias, padding=1)) + setattr(self, f"nl{i}", torch.nn.ReLU()) + setattr(self, f"conv{i}_2", torch.nn.Conv2d(channels, channels, 3, bias=bias, padding=1)) + + def forward(self, x): + i = self.channels_list.index(x.shape[1]) + u = getattr(self, f"conv{i}_1")(x) + u = getattr(self, f"nl{i}")(u) + u = getattr(self, f"conv{i}_2")(u) + if self.residual: + u = x + u + return u + +class HeadBlock(torch.nn.Module): + def __init__(self, in_channels, out_channels, kernel_size=3, bias=True, depth=2, relu_in=False, skip_in=False): + super(HeadBlock, self).__init__() + + padding = kernel_size // 2 + + c = out_channels if depth < 2 else in_channels + + self.convin = torch.nn.Conv2d(in_channels, c, kernel_size, padding=padding, bias=bias) + self.zero_conv_skip = torch.nn.Conv2d(in_channels, c, 1, bias=False) + self.depth = depth + self.nl_1 = torch.nn.ReLU(inplace=False) + self.nl_2 = torch.nn.ReLU(inplace=False) + self.relu_in = relu_in + self.skip_in = skip_in + + for i in range(depth-1): + if i < depth - 2: + c_in, c = in_channels, in_channels + else: + c_in, c = in_channels, out_channels + + setattr(self, f"conv1{i}", torch.nn.Conv2d(c_in, c_in, kernel_size, padding=padding, bias=bias)) + setattr(self, f"conv2{i}", torch.nn.Conv2d(c_in, c, kernel_size, padding=padding, bias=bias)) + setattr(self, f"skipconv{i}", torch.nn.Conv2d(c_in, c, 1, bias=False)) + + + def forward(self, x): + + if self.skip_in and self.relu_in: + x = self.nl_1(self.convin(x)) + self.zero_conv_skip(x) + elif self.skip_in and not self.relu_in: + x = self.convin(x) + self.zero_conv_skip(x) + else: + x = self.convin(x) + + for i in range(self.depth-1): + aux = getattr(self, f"conv1{i}")(x) + aux = self.nl_2(aux) + aux_0 = getattr(self, f"conv2{i}")(aux) + aux_1 = getattr(self, f"skipconv{i}")(x) + x = aux_0 + aux_1 + + return x + + +class SNRModule(torch.nn.Module): + """ + A method that only performs convolutional operations on the appropriate channels dim. + """ + def __init__(self, channels_list, out_channels, bias=False, residual=False, features=64): + super(SNRModule, self).__init__() + self.channels_list = channels_list + self.residual = residual + for i, channels in enumerate(channels_list): + setattr(self, f"conv{i}_1", torch.nn.Conv2d(channels + 1, features, 3, bias=bias, padding=1)) + setattr(self, f"nl{i}", torch.nn.ReLU()) + setattr(self, f"conv{i}_2", torch.nn.Conv2d(features, out_channels, 3, bias=bias, padding=1)) + + def forward(self, x0, sigma): + i = self.channels_list.index(x0.shape[1]) + + noise_level_map = (torch.ones((x0.size(0), 1, x0.size(2), x0.size(3)), device=x0.device) * sigma) + x = torch.cat((x0, noise_level_map), 1) + + u = getattr(self, f"conv{i}_1")(x) + u = getattr(self, f"nl{i}")(u) + u = getattr(self, f"conv{i}_2")(u) + + den = u.pow(2).mean(dim=-1, keepdim=True).mean(dim=-2, keepdim=True).sqrt() + u = u.abs() / (den + 1e-8) + + return u.mean(dim=-1, keepdim=True).mean(dim=-2, keepdim=True) + + +class EquivConvModule(torch.nn.Module): + """ + A method that only performs convolutional operations on the appropriate channels dim. 
+ """ + def __init__(self, channels_list, out_channels, bias=False, residual=False, features=64, N=1): + super(EquivConvModule, self).__init__() + self.channels_list = [c * N for c in channels_list] + self.residual = residual + for i, channels in enumerate(channels_list): + setattr(self, f"conv{i}_1", torch.nn.Conv2d(channels * N, channels * N, 3, bias=bias, padding=1)) + setattr(self, f"nl{i}", torch.nn.ReLU()) + setattr(self, f"conv{i}_2", torch.nn.Conv2d(channels * N, out_channels, 3, bias=bias, padding=1)) + + def forward(self, x): + + i = self.channels_list.index(x.shape[1]) + + u = getattr(self, f"conv{i}_1")(x) + u = getattr(self, f"nl{i}")(u) + u = getattr(self, f"conv{i}_2")(u) + + return u + + +class EquivHeads(torch.nn.Module): + def __init__(self, in_channels_list, out_channels, depth=2, scale=1, bias=True, mode="bilinear"): + super(EquivHeads, self).__init__() + self.in_channels_list = in_channels_list + self.scale = scale + self.mode = mode + for i, in_channels in enumerate(in_channels_list): + setattr(self, f"head{i}", HeadBlock(in_channels + 1, out_channels, depth=depth, bias=bias)) + + if self.mode == "": + self.nl = torch.nn.ReLU(inplace=False) + if self.scale != 1: + for i, in_channels in enumerate(in_channels_list): + setattr(self, f"down{i}", downsample_strideconv(in_channels, in_channels, bias=False, mode=str(self.scale))) + + def forward(self, x, sigma): + in_channels = x.size(1) + i = self.in_channels_list.index(in_channels) + + if self.scale != 1: + if self.mode == "bilinear": + x = torch.nn.functional.interpolate(x, scale_factor=1/self.scale, mode='bilinear', align_corners=False) + else: + x = getattr(self, f"down{i}")(x) + x = self.nl(x) + + # concat noise level map + noise_level_map = (torch.ones((x.size(0), 1, x.size(2), x.size(3)), device=x.device) * sigma) + x = torch.cat((x, noise_level_map), 1) + + # find index + x = getattr(self, f"head{i}")(x) + + return x diff --git a/models/unext_wip.py b/models/unext_wip.py new file mode 100644 index 0000000000000000000000000000000000000000..cc29614a8ffd9eabd4dc897eb3d97aca54c212e3 --- /dev/null +++ b/models/unext_wip.py @@ -0,0 +1,1238 @@ +# Code borrowed from Kai Zhang https://github.com/cszn/DPIR/tree/master/models +import re +import math +import functools + +import deepinv as dinv +from deepinv.utils import plot, TensorList + +import torch +from torch.func import vmap +import torch.nn as nn +import torch.nn.functional as F +from torchvision import transforms +from deepinv.optim.utils import conjugate_gradient + +from physics.multiscale import MultiScaleLinearPhysics, Pad +from models.blocks import EquivMaxPool, AffineConv2d, ConvNextBlock2, NoiseEmbedding, MPConv, TimestepEmbedding, conv, downsample_strideconv, upsample_convtranspose +from models.heads import Heads, Tails, InHead, OutTail, ConvChannels, SNRModule, EquivConvModule, EquivHeads + +cuda = True if torch.cuda.is_available() else False +Tensor = torch.cuda.FloatTensor if cuda else torch.FloatTensor + + +### --------------- MODEL --------------- +class BaseEncBlock(nn.Module): + def __init__( + self, + in_channels, + out_channels, + bias=False, + mode="CRC", + nb=2, + embedding=False, + emb_channels=None, + emb_physics=False, + img_channels=None, + decode_upscale=None, + config='A', + N=4, + c_mult=1, + depth_encoding=1, + relu_in_encoding=False, + skip_in_encoding=True, + ): + super(BaseEncBlock, self).__init__() + self.config = config + self.enc = nn.ModuleList( + [ + ResBlock( + in_channels, + out_channels, + bias=bias, + mode=mode, + 
embedding=embedding,
+                    emb_channels=emb_channels,
+                    emb_physics=emb_physics,
+                    img_channels=img_channels,
+                    decode_upscale=decode_upscale,
+                    config=config,
+                    N=N,
+                    c_mult=c_mult,
+                    depth_encoding=depth_encoding,
+                    relu_in_encoding=relu_in_encoding,
+                    skip_in_encoding=skip_in_encoding,
+                )
+                for _ in range(nb)
+            ]
+        )
+
+    def forward(self, x, emb_sigma=None, physics=None, t=None, y=None, emb_in=None, img_channels=None, scale=0):
+        for i in range(len(self.enc)):
+            x = self.enc[i](x, emb_sigma=emb_sigma, physics=physics, t=t, y=y, img_channels=img_channels, scale=scale)
+        return x
+
+
+class NextEncBlock(nn.Module):
+    def __init__(
+        self, in_channels, out_channels, bias=False, mode="", mult_fact=4, nb=2
+    ):
+        super(NextEncBlock, self).__init__()
+        self.enc = nn.ModuleList(
+            [
+                ConvNextBlock2(
+                    in_channels=in_channels,
+                    out_channels=out_channels,
+                    bias=bias,
+                    mode=mode,
+                    mult_fact=mult_fact,
+                )
+                for _ in range(nb)
+            ]
+        )
+
+    def forward(self, x, emb_sigma=None):
+        for i in range(len(self.enc)):
+            x = self.enc[i](x, emb_sigma)
+        return x
+
+
+class UNeXt(nn.Module):
+    r"""
+    UNeXt denoiser network, based on the DRUNet architecture.
+
+    The network architecture is based on the paper
+    `Learning deep CNN denoiser prior for image restoration `_,
+    and has a U-Net like structure, with convolutional blocks in the encoder and decoder parts.
+
+    The network takes into account the noise level of the input image, which is encoded as an additional input channel.
+
+    A pretrained network for (in_channels=out_channels=1 or in_channels=out_channels=3)
+    can be downloaded via setting ``pretrained='download'``.
+
+    :param int in_channels: number of channels of the input.
+    :param int out_channels: number of channels of the output.
+    :param list nc: number of convolutional layers.
+    :param int nb: number of convolutional blocks per layer.
+    :param int nf: number of channels per convolutional layer.
+    :param str act_mode: activation mode, "R" for ReLU, "L" for LeakyReLU, "E" for ELU and "S" for Softplus.
+    :param str downsample_mode: Downsampling mode, "avgpool" for average pooling, "maxpool" for max pooling, and
+        "strideconv" for convolution with stride 2.
+    :param str upsample_mode: Upsampling mode, "convtranspose" for convolution transpose, "pixelshuffle" for pixel
+        shuffling, and "upconv" for nearest neighbour upsampling with additional convolution.
+    :param str, None pretrained: use a pretrained network. If ``pretrained=None``, the weights will be initialized at random
+        using Pytorch's default initialization. If ``pretrained='download'``, the weights will be downloaded from an
+        online repository (only available for the default architecture with 3 or 1 input/output channels).
+        Finally, ``pretrained`` can also be set as a path to the user's own pretrained weights.
+        See :ref:`pretrained-weights ` for more details.
+    :param bool train: training or testing mode.
+    :param str device: gpu or cpu.
+ + """ + + def __init__( + self, + in_channels=[1, 2, 3], + out_channels=[1, 2, 3], + nc=[64, 128, 256, 512], + nb=4, # 4 in DRUNet but out of memory + conv_type="next", # should be 'base' or 'next' + pool_type="next", # should be 'base' or 'next' + cond_type="base", # conditioning, should be 'base' or 'edm' + device=None, + bias=False, + mode="", + residual=False, + act_mode="R", + layer_scale_init_value=1e-6, + init_type="ortho", + gain_init_conv=1.0, + gain_init_linear=1.0, + drop_prob=0.0, + replk=False, + mult_fact=4, + antialias="gaussian", + emb_physics=False, + config='A', + pretrained_pth=None, + N=4, + c_mult=1, + depth_encoding=1, + relu_in_encoding=False, + skip_in_encoding=True, + ): + super(UNeXt, self).__init__() + + self.residual = residual + self.conv_type = conv_type + self.pool_type = pool_type + self.emb_physics = emb_physics + self.config = config + self.in_channels = in_channels + self.fact_realign = torch.nn.Parameter(torch.tensor([1.0], device=device)) + + self.separate_head = isinstance(in_channels, list) + + assert cond_type in ["base", "edm"], "cond_type should be 'base' or 'edm'" + self.cond_type = cond_type + + if self.cond_type == "base": + if self.config != 'E': + if isinstance(in_channels, list): + in_channels_first = [] + for i in range(len(in_channels)): + in_channels_first.append(in_channels[i] + 2) + else: # old head + in_channels_first = in_channels + 1 + else: + in_channels_first = in_channels + else: + in_channels_first = in_channels + self.noise_embedding = NoiseEmbedding( + num_channels=in_channels, emb_channels=max(nc), device=device + ) + + self.timestep_embedding = lambda x: x + + # check if in_channels is a list + self.m_head = InHead(in_channels_first, nc[0]) + + if conv_type == "next": + self.m_down1 = NextEncBlock( + nc[0], nc[0], bias=bias, mode=mode, mult_fact=mult_fact, nb=nb + ) + self.m_down2 = NextEncBlock( + nc[1], nc[1], bias=bias, mode=mode, mult_fact=mult_fact, nb=nb + ) + self.m_down3 = NextEncBlock( + nc[2], nc[2], bias=bias, mode=mode, mult_fact=mult_fact, nb=nb + ) + self.m_body = NextEncBlock( + nc[3], nc[3], bias=bias, mode=mode, mult_fact=mult_fact, nb=nb + ) + self.m_up3 = NextEncBlock( + nc[2], nc[2], bias=bias, mode=mode, mult_fact=mult_fact, nb=nb + ) + self.m_up2 = NextEncBlock( + nc[1], nc[1], bias=bias, mode=mode, mult_fact=mult_fact, nb=nb + ) + self.m_up1 = NextEncBlock( + nc[0], nc[0], bias=bias, mode=mode, mult_fact=mult_fact, nb=nb + ) + + elif conv_type == "base": + embedding = ( + False if cond_type == "base" else True + ) + emb_channels = max(nc) + self.m_down1 = BaseEncBlock( + nc[0], + nc[0], + bias=False, + mode="CRC", + nb=nb, + embedding=embedding, + emb_channels=emb_channels, + emb_physics=emb_physics, + img_channels=in_channels, + decode_upscale=1, + config=config, + N=N, + c_mult=c_mult, + depth_encoding=depth_encoding, + relu_in_encoding=relu_in_encoding, + skip_in_encoding=skip_in_encoding, + ) + self.m_down2 = BaseEncBlock( + nc[1], + nc[1], + bias=False, + mode="CRC", + nb=nb, + embedding=embedding, + emb_channels=emb_channels, + emb_physics=emb_physics, + img_channels=in_channels, + decode_upscale=2, + config=config, + N=N, + c_mult=c_mult, + depth_encoding=depth_encoding, + relu_in_encoding=relu_in_encoding, + skip_in_encoding=skip_in_encoding, + ) + self.m_down3 = BaseEncBlock( + nc[2], + nc[2], + bias=False, + mode="CRC", + nb=nb, + embedding=embedding, + emb_channels=emb_channels, + emb_physics=emb_physics, + img_channels=in_channels, + decode_upscale=4, + config=config, + N=N, + 
c_mult=c_mult, + depth_encoding=depth_encoding, + relu_in_encoding=relu_in_encoding, + skip_in_encoding=skip_in_encoding, + ) + self.m_body = BaseEncBlock( + nc[3], + nc[3], + bias=False, + mode="CRC", + nb=nb, + embedding=embedding, + emb_channels=emb_channels, + emb_physics=emb_physics, + img_channels=in_channels, + decode_upscale=8, + config=config, + N=N, + c_mult=c_mult, + depth_encoding=depth_encoding, + relu_in_encoding=relu_in_encoding, + skip_in_encoding=skip_in_encoding, + ) + self.m_up3 = BaseEncBlock( + nc[2], + nc[2], + bias=False, + mode="CRC", + nb=nb, + embedding=embedding, + emb_channels=emb_channels, + emb_physics=emb_physics, + img_channels=in_channels, + decode_upscale=4, + config=config, + N=N, + c_mult=c_mult, + depth_encoding=depth_encoding, + relu_in_encoding=relu_in_encoding, + skip_in_encoding=skip_in_encoding, + ) + self.m_up2 = BaseEncBlock( + nc[1], + nc[1], + bias=False, + mode="CRC", + nb=nb, + embedding=embedding, + emb_channels=emb_channels, + emb_physics=emb_physics, + img_channels=in_channels, + decode_upscale=2, + config=config, + N=N, + c_mult=c_mult, + depth_encoding=depth_encoding, + relu_in_encoding=relu_in_encoding, + skip_in_encoding=skip_in_encoding, + ) + self.m_up1 = BaseEncBlock( + nc[0], + nc[0], + bias=False, + mode="CRC", + nb=nb, + embedding=embedding, + emb_channels=emb_channels, + emb_physics=emb_physics, + img_channels=in_channels, + decode_upscale=1, + config=config, + N=N, + c_mult=c_mult, + depth_encoding=depth_encoding, + relu_in_encoding=relu_in_encoding, + skip_in_encoding=skip_in_encoding, + ) + + else: + raise NotImplementedError("conv_type should be 'base' or 'next'") + + if pool_type == "next_max": + self.pool1 = EquivMaxPool( + antialias=antialias, + in_channels=nc[0], + out_channels=nc[1], + device=device, + ) + self.pool2 = EquivMaxPool( + antialias=antialias, + in_channels=nc[1], + out_channels=nc[2], + device=device, + ) + self.pool3 = EquivMaxPool( + antialias=antialias, + in_channels=nc[2], + out_channels=nc[3], + device=device, + ) + elif pool_type == "base": + self.pool1 = downsample_strideconv(nc[0], nc[1], bias=False, mode="2") + self.pool2 = downsample_strideconv(nc[1], nc[2], bias=False, mode="2") + self.pool3 = downsample_strideconv(nc[2], nc[3], bias=False, mode="2") + self.up3 = upsample_convtranspose(nc[3], nc[2], bias=False, mode="2") + self.up2 = upsample_convtranspose(nc[2], nc[1], bias=False, mode="2") + self.up1 = upsample_convtranspose(nc[1], nc[0], bias=False, mode="2") + else: + raise NotImplementedError("pool_type should be 'base' or 'next'") + + self.m_tail = OutTail(nc[0], in_channels) + + if conv_type == "base": + init_func = functools.partial( + weights_init_unext, init_type="ortho", gain_conv=0.2 + ) + self.apply(init_func) + else: + init_func = functools.partial( + weights_init_unext, + init_type=init_type, + gain_conv=gain_init_conv, + gain_linear=gain_init_linear, + ) + self.apply(init_func) + + if pretrained_pth=='jz': + pth = '/lustre/fswork/projects/rech/nyd/commun/mterris/base_checkpoints/drunet_deepinv_color_finetune_22k.pth' + self.load_drunet_weights(pth) + elif pretrained_pth is not None: + self.load_drunet_weights(pretrained_pth) + + if self.config == 'D': + # deactivate grad for layers that do not contain the string "PhysicsBlock" or "gain" or "fact_realign" + for name, param in self.named_parameters(): + if 'PhysicsBlock' not in name and 'gain' not in name and 'fact_realign' not in name and "m_head" not in name and "m_tail" not in name: + param.requires_grad = False + + if device is 
not None: + self.to(device) + + def load_drunet_weights(self, ckpt_pth): + state_dict = torch.load(ckpt_pth, map_location=lambda storage, loc: storage) + + new_state_dict = {} + matched_keys = [] # List to store successfully matched keys + unmatched_keys = [] # List to store keys that were not matched or excluded + excluded_keys = [] # List to store excluded keys + + # Define patterns to exclude + exclude_patterns = ["head", "tail"] + + # Dealing with regular keys + for old_key, value in state_dict.items(): + # Skip keys containing any of the excluded patterns + if any(excluded in old_key for excluded in exclude_patterns): + excluded_keys.append(old_key) + continue # Skip further processing for this key + + new_key = old2new(old_key) + + if new_key is not None: + matched_keys.append((old_key, new_key)) # Record the matched keys + new_state_dict[new_key] = value + else: + unmatched_keys.append(old_key) # Record unmatched keys + + # TODO: clean this + for excluded_key in excluded_keys: + if isinstance(self.in_channels, list): + for i, in_channel in enumerate(self.in_channels): + # print('Dealing with conv ', i) + new_key = f"m_head.conv{i}.weight" + if 'head' in excluded_key: + new_key = f"m_head.conv{i}.weight" + # new_key = f"m_head.head.conv{i}.weight" + if 'tail' in excluded_key: + new_key = f"m_tail.conv{i}.weight" + # DEBUG print all keys of state dict: + # print(state_dict.keys()) + # print(self.state_dict().keys()) + conditioning = 'base' + # if self.config == 'E': + # conditioning = False + new_kv = update_keyvals_headtail(excluded_key, + state_dict[excluded_key], + init_value=self.state_dict()[new_key], + new_key_name=new_key, + conditioning=conditioning) + new_state_dict.update(new_kv) + # print(new_kv.keys()) + else: + new_kv = update_keyvals_headtail(excluded_key, state_dict[excluded_key]) + new_state_dict.update(new_kv) + + # Display matched keys + print("Matched keys:") + for old_key, new_key in matched_keys: + print(f"{old_key} -> {new_key}") + + # Load updated state dict into the model + self.load_state_dict(new_state_dict, strict=False) + + # Display unmatched keys + print("\nUnmatched keys:") + for unmatched_key in unmatched_keys: + print(unmatched_key) + + print("Weights loaded from ", ckpt_pth) + + def constant2map(self, value, x): + if isinstance(value, torch.Tensor): + if value.ndim > 0: + value_map = value.view(x.size(0), 1, 1, 1) + value_map = value_map.expand(-1, 1, x.size(2), x.size(3)) + else: + value_map = torch.ones( + (x.size(0), 1, x.size(2), x.size(3)), device=x.device + ) * value[None, None, None, None].to(x.device) + else: + value_map = ( + torch.ones((x.size(0), 1, x.size(2), x.size(3)), device=x.device) + * value + ) + return value_map + + def base_conditioning(self, x, sigma, gamma): + noise_level_map = self.constant2map(sigma, x) + gamma_map = self.constant2map(gamma, x) + return torch.cat((x, noise_level_map, gamma_map), 1) + + def realign_input(self, x, physics, y): + + if hasattr(physics, "factor"): + f = physics.factor + elif hasattr(physics, "base") and hasattr(physics.base, "factor"): + f = physics.base.factor + elif hasattr(physics, "base") and hasattr(physics.base, "base") and hasattr(physics.base.base, "factor"): + f = physics.base.base.factor + else: + f = 1.0 + + sigma = 1e-6 # default value + if hasattr(physics.noise_model, 'sigma'): + sigma = physics.noise_model.sigma + if hasattr(physics, 'base') and hasattr(physics.base, 'noise_model') and hasattr(physics.base.noise_model, 'sigma'): + sigma = physics.base.noise_model.sigma + if 
hasattr(physics, 'base') and hasattr(physics.base, 'base') and hasattr(physics.base.base, 'noise_model') and hasattr(physics.base.base.noise_model, 'sigma'): + sigma = physics.base.base.noise_model.sigma + + if isinstance(y, TensorList): + num = (y[0].reshape(y[0].shape[0], -1).abs().mean(1)) + else: + num = (y.reshape(y.shape[0], -1).abs().mean(1)) + + snr = num / (sigma + 1e-4) # SNR equivariant + gamma = 1 / (1e-4 + 1 / (snr * f **2 )) # TODO: check square-root / mean / check if we need to add a factor in front ? + gamma = gamma[(...,) + (None,) * (x.dim() - 1)] + model_input = physics.prox_l2(x, y, gamma=gamma * self.fact_realign) + + return model_input + + def forward_unet(self, x0, sigma=None, gamma=None, physics=None, t=None, y=None, img_channels=None): + + # list_values = [] + + if self.cond_type == "base": + # if self.config != 'E': + x0 = self.base_conditioning(x0, sigma, gamma) + emb_sigma = None + else: + emb_sigma = self.noise_embedding( + sigma + ) # This only if the embedding is the non-basic one from drunet + + emb_timestep = self.timestep_embedding(t) + + x1 = self.m_head(x0) # old + # x1 = self.m_head(x0, emb_sigma=emb_sigma, physics=physics, t=emb_timestep, y=y, img_channels=img_channels) + # list_values.append(x1.abs().mean()) + + if self.config == 'G': + x1_, emb1_ = self.m_down1(x1, emb_sigma=emb_sigma, physics=physics, t=emb_timestep, y=y, img_channels=img_channels) + else: + x1_ = self.m_down1(x1, emb_sigma=emb_sigma, physics=physics, t=emb_timestep, y=y, img_channels=img_channels, scale=0) + x2 = self.pool1(x1_) + # list_values.append(x2.abs().mean()) + + if self.config == 'G': + x3_, emb3_ = self.m_down2(x2, emb_sigma=emb_sigma, physics=physics, t=emb_timestep, y=y, img_channels=img_channels) + else: + x3_ = self.m_down2(x2, emb_sigma=emb_sigma, physics=physics, t=emb_timestep, y=y, img_channels=img_channels, scale=1) + x3 = self.pool2(x3_) + + # list_values.append(x3.abs().mean()) + if self.config == 'G': + x4_, emb4_ = self.m_down3(x3, emb_sigma=emb_sigma, physics=physics, t=emb_timestep, y=y, img_channels=img_channels) + else: + x4_ = self.m_down3(x3, emb_sigma=emb_sigma, physics=physics, t=emb_timestep, y=y, img_channels=img_channels, scale=2) + x4 = self.pool3(x4_) + + # issue: https://github.com/matthieutrs/ram_project/issues/1 + # solution 1: using .contiguous() below + # solution 2: using a print statement that magically solves the issue + ###print(x4.is_contiguous()) + + # list_values.append(x4.abs().mean()) + if self.config == 'G': + x, _ = self.m_body(x4, emb_sigma=emb_sigma, physics=physics, t=emb_timestep, y=y, img_channels=img_channels) + else: + x = self.m_body(x4, emb_sigma=emb_sigma, physics=physics, t=emb_timestep, y=y, img_channels=img_channels, scale=3) + + # list_values.append(x.abs().mean()) + if self.pool_type == "next" or self.pool_type == "next_max": + x = self.pool3.upscale(x + x4) + else: + x = self.up3(x + x4) + + if self.config == 'G': + x, _ = self.m_up3(x, emb_sigma=emb_sigma, physics=physics, t=emb_timestep, y=y, emb_in=emb4_, img_channels=img_channels) + else: + x = self.m_up3(x, emb_sigma=emb_sigma, physics=physics, t=emb_timestep, y=y, img_channels=img_channels, scale=2) + + # list_values.append(x.abs().mean()) + if self.pool_type == "next" or self.pool_type == "next_max": + x = self.pool2.upscale(x + x3) + else: + x = self.up2(x + x3) + + if self.config == 'G': + x, _ = self.m_up2(x, emb_sigma=emb_sigma, physics=physics, t=emb_timestep, y=y, emb_in=emb3_, img_channels=img_channels) + else: + x = self.m_up2(x, 
emb_sigma=emb_sigma, physics=physics, t=emb_timestep, y=y, img_channels=img_channels, scale=1)
+
+        # list_values.append(x.abs().mean())
+        if self.pool_type == "next" or self.pool_type == "next_max":
+            x = self.pool1.upscale(x + x2)
+        else:
+            x = self.up1(x + x2)
+
+        if self.config == 'G':
+            x, _ = self.m_up1(x, emb_sigma=emb_sigma, physics=physics, t=emb_timestep, y=y, emb_in=emb1_, img_channels=img_channels)
+        else:
+            x = self.m_up1(x, emb_sigma=emb_sigma, physics=physics, t=emb_timestep, y=y, img_channels=img_channels, scale=0)
+
+        # list_values.append(x.abs().mean())
+        if self.separate_head:
+            x = self.m_tail(x + x1, img_channels)
+        else:
+            x = self.m_tail(x + x1)
+
+        return x
+
+    def forward(self, x, sigma=None, gamma=None, physics=None, t=None, y=None):
+        r"""
+        Run the denoiser on an image with noise level :math:`\sigma`.
+
+        :param torch.Tensor x: noisy image
+        :param float, torch.Tensor sigma: noise level. If ``sigma`` is a float, it is used for all images in the batch.
+            If ``sigma`` is a tensor, it must be of shape ``(batch_size,)``.
+        """
+        img_channels = x.shape[1]
+        if self.emb_physics:
+            physics = MultiScaleLinearPhysics(physics, x.shape[-3:], device=x.device)
+
+        if self.separate_head and img_channels not in self.in_channels:
+            raise ValueError(f"Input image has {img_channels} channels, but the network only has heads for {self.in_channels} channels.")
+
+        if y is not None:
+            x = self.realign_input(x, physics, y)
+
+        x = self.forward_unet(x, sigma=sigma, gamma=gamma, physics=physics, t=t, y=y, img_channels=img_channels)
+
+        return x
+
+
+def krylov_embeddings_old(y, p, factor, v=None, N=4, feat_size=1, x_init=None, img_channels=3):
+
+    if x_init is None:
+        x = p.A_adjoint(y)
+    else:
+        x = x_init[:, :img_channels, ...]
+
+    if feat_size > 1:
+        _, C, _, _ = x.shape
+        if v is None:
+            v = torch.zeros_like(x).repeat(1, N-1, 1, 1)
+        out = x - v[:, :C, ...]
+        norm = factor ** 2
+        A = lambda u: p.A_adjoint(p.A(u)) * norm
+        for i in range(N-1):
+            x = A(x) - v[:, (i+1) * C:(i+2) * C, ...]
+            out = torch.cat([out, x], dim=1)
+    else:
+        if v is None:
+            v = torch.zeros_like(x)
+        out = x - v
+        norm = factor ** 2
+        A = lambda u: p.A_adjoint(p.A(u)) * norm
+        for i in range(N-1):
+            x = A(x) - v
+            out = torch.cat([out, x], dim=1)
+    return out
+
+def krylov_embeddings(y, p, factor, v=None, N=4, x_init=None, img_channels=3):
+    """
+    Krylov subspace embedding computation.
+
+    Args:
+        y (torch.Tensor): The input tensor.
+        p: An object with A and A_adjoint methods (linear operator).
+        factor (float): Scaling factor.
+        v (torch.Tensor, optional): Precomputed values to subtract from the Krylov sequence. Defaults to None.
+        N (int, optional): Number of Krylov iterations. Defaults to 4.
+        x_init (torch.Tensor, optional): Initial guess. Defaults to None.
+        img_channels (int, optional): Number of image channels. Defaults to 3.
+
+    Returns:
+        torch.Tensor: The Krylov embeddings.
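+
+    Example (editor's sketch): for a denoising physics where ``p.A`` and
+    ``p.A_adjoint`` are the identity, ``factor=1`` and ``v=None``, the output
+    is ``p.A_adjoint(y)`` repeated ``N`` times along the channel axis, since
+    ``AtA`` reduces to the identity map.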
+ """ + + if x_init is None: + x = p.A_adjoint(y) + else: + x = x_init.clone() # Extract the first img_channels + + norm = factor ** 2 # Precompute normalization factor + AtA = lambda u: p.A_adjoint(p.A(u)) * norm # Define the linear operator + + v = v if v is not None else torch.zeros_like(x) + + out = x.clone() + # Compute Krylov basis + x_k = x.clone() + for i in range(N-1): + x_k = AtA(x_k) - v + out = torch.cat([out, x_k], dim=1) + + return out + + +def grad_embeddings(y, p, factor, v=None, N=4, feat_size=1): + Aty = p.A_adjoint(y) + if feat_size > 1: + _, C, _, _ = Aty.shape + if v is None: + v = torch.zeros_like(Aty).repeat(1, N-1, 1, 1) + out = v[:, :C, ...] - Aty + norm = factor ** 2 + A = lambda u: p.A_adjoint(p.A(u)) * norm + for i in range(N-1): + x = A(v[:, (i+1) * C:(i+2) * C, ...]) - Aty + out = torch.cat([out, x], dim=1) + else: + if v is None: + v = torch.zeros_like(Aty) + out = v - Aty + norm = factor ** 2 + A = lambda u: p.A_adjoint(p.A(u)) * norm + for i in range(N-1): + x = A(v) - Aty + out = torch.cat([out, x], dim=1) + return out + + +def prox_embeddings(y, p, factor, v=None, N=4): + x = p.A_adjoint(y) + B, C, H, W = x.shape + + if v is None: + v = torch.zeros_like(x) + + v = v.repeat(1, N - 1, 1, 1) + + gamma = torch.logspace(-4, -1, N-1, device=x.device).repeat_interleave(C).unsqueeze(0).unsqueeze(-1).unsqueeze(-1) + norm = factor ** 2 + A_sub = lambda u: torch.cat([p.A_adjoint(p.A(u[:, i * C:(i+1) * C, ...])) * norm for i in range(N-1)], dim=1) + A = lambda u: A_sub(u) + (u - v) * gamma + + u_hat = conjugate_gradient(A, x.repeat(1, N-1, 1, 1), max_iter=3, tol=1e-3) + u_hat = torch.cat([u_hat, x], dim=1) + + return u_hat + +# -------------------------------------------- +# Res Block: x + conv(relu(conv(x))) +# -------------------------------------------- +class MeasCondBlock(nn.Module): + def __init__( + self, + out_channels=64, + img_channels=None, + decode_upscale=None, + config = 'A', + N=4, + depth_encoding=1, + relu_in_encoding=False, + skip_in_encoding=True, + c_mult=1, + ): + super(MeasCondBlock, self).__init__() + + self.separate_head = isinstance(img_channels, list) + self.config = config + + assert img_channels is not None, "decode_dimensions should be provided" + assert decode_upscale is not None, "decode_upscale should be provided" + + # if self.separate_head: + if self.config == 'A': + self.relu_encoding = nn.ReLU(inplace=False) + self.N = N + self.c_mult = c_mult + self.encoding_conv = Heads(img_channels, out_channels, depth=depth_encoding, scale=1, bias=False, c_mult=self.c_mult, relu_in=relu_in_encoding, skip_in=skip_in_encoding) + if self.config == 'B': + self.N = N + self.c_mult = c_mult + self.relu_encoding = nn.ReLU(inplace=False) + self.decoding_conv = Tails(out_channels, img_channels, depth=1, scale=1, bias=False, c_mult=self.c_mult) + self.encoding_conv = Heads(img_channels, out_channels, depth=depth_encoding, scale=1, bias=False, c_mult=self.c_mult, relu_in=relu_in_encoding, skip_in=skip_in_encoding) + if self.config == 'C': + self.N = N + self.c_mult = c_mult + self.relu_encoding = nn.ReLU(inplace=False) + self.decoding_conv = Tails(out_channels, img_channels, depth=1, scale=1, bias=False, c_mult=self.c_mult) + self.encoding_conv = Heads(img_channels, out_channels, depth=depth_encoding, scale=1, bias=False, c_mult=self.c_mult*N, c_add=N, relu_in=relu_in_encoding, skip_in=skip_in_encoding) + elif self.config == 'D': + self.N = N + self.c_mult = c_mult + self.relu_encoding = nn.ReLU(inplace=False) + self.decoding_conv = Tails(out_channels, 
img_channels, depth=1, scale=1, bias=False, c_mult=self.c_mult) + self.encoding_conv = Heads(img_channels, out_channels, depth=depth_encoding, scale=1, bias=False, c_mult=self.c_mult*N, c_add=N, relu_in=relu_in_encoding, skip_in=skip_in_encoding) + + self.gain = torch.nn.Parameter(torch.tensor([1.0]), requires_grad=True) + self.gain_gradx = torch.nn.Parameter(torch.tensor([1e-2]), requires_grad=True) + self.gain_grady = torch.nn.Parameter(torch.tensor([1e-2]), requires_grad=True) + self.gain_pinvx = torch.nn.Parameter(torch.tensor([1e-2]), requires_grad=True) + self.gain_pinvy = torch.nn.Parameter(torch.tensor([1e-2]), requires_grad=True) + + def forward(self, x, y, physics, t, emb_in=None, img_channels=None, scale=1): + if self.config == 'A': + return self.measurement_conditioning_config_A(x, y, physics, img_channels=img_channels, scale=scale) + elif self.config == 'F': + return self.measurement_conditioning_config_F(x, y, physics, img_channels=img_channels, scale=scale) + elif self.config == 'B': + return self.measurement_conditioning_config_B(x, y, physics, img_channels=img_channels, scale=scale) + elif self.config == 'C': + return self.measurement_conditioning_config_C(x, y, physics, img_channels=img_channels, scale=scale) + elif self.config == 'D': + return self.measurement_conditioning_config_D(x, y, physics, img_channels=img_channels, scale=scale) + elif self.config == 'E': + return self.measurement_conditioning_config_E(x, y, physics, img_channels=img_channels, scale=scale) + else: + raise NotImplementedError('Config not implemented') + + def measurement_conditioning_config_A(self, x, y, physics, img_channels, scale=0): + physics.set_scale(scale) + factor = 2**(scale) + meas = krylov_embeddings(y, physics, factor, N=self.N, img_channels=img_channels) + cond = self.encoding_conv(meas) + emb = self.relu_encoding(cond) + return emb + + def measurement_conditioning_config_B(self, x, y, physics, img_channels, scale=0): + physics.set_scale(scale) + dec = self.decoding_conv(x, img_channels) + factor = 2**(scale) + meas = krylov_embeddings(y, physics, factor, v=dec, N=self.N, img_channels=img_channels) + cond = self.encoding_conv(meas) + emb = self.relu_encoding(cond) + return emb # * sigma_emb + + def measurement_conditioning_config_C(self, x, y, physics, img_channels, scale=0): + physics.set_scale(scale) + dec = self.decoding_conv(x, img_channels) + factor = 2**(scale) + meas_y = krylov_embeddings(y, physics, factor, N=self.N, img_channels=img_channels) + meas_dec = krylov_embeddings(y, physics, factor, N=self.N, x_init=dec[:, :img_channels, ...], img_channels=img_channels) + for c in range(1, self.c_mult): + meas_cur = krylov_embeddings(y, physics, factor, N=self.N, x_init=dec[:, img_channels*c:img_channels*(c+1)], + img_channels=img_channels) + meas_dec = torch.cat([meas_dec, meas_cur], dim=1) + meas = torch.cat([meas_y, meas_dec], dim=1) + cond = self.encoding_conv(meas) + emb = self.relu_encoding(cond) + return emb + + def measurement_conditioning_config_D(self, x, y, physics, img_channels, scale=0): + physics.set_scale(scale) + dec = self.decoding_conv(x, img_channels) + factor = 2**(scale) + meas_y = krylov_embeddings(y, physics, factor, N=self.N, img_channels=img_channels) + meas_dec = krylov_embeddings(y, physics, factor, N=self.N, x_init=dec[:, :img_channels, ...], img_channels=img_channels) + for c in range(1, self.c_mult): + meas_cur = krylov_embeddings(y, physics, factor, N=self.N, x_init=dec[:, img_channels*c:img_channels*(c+1)], + img_channels=img_channels) + meas_dec = 
torch.cat([meas_dec, meas_cur], dim=1) + meas = torch.cat([meas_y, meas_dec], dim=1) + cond = self.encoding_conv(meas) + emb = self.relu_encoding(cond) + return cond + + def measurement_conditioning_config_F(self, x, y, physics, img_channels): + dec_large = self.decoding_conv(x, img_channels) # go from shape = (B, C, H, W) to (B, 64, 64, 64) (independent of modality) + dec = self.relu_decoding(dec_large) + + Adec = physics.A(dec) + + grad = physics.A_adjoint(self.gain_gradx ** 2 * Adec - self.gain_grady ** 2 * y) # TODO: check if we need to have L2 (depending on noise nature, can be automated) + + if 'tomography' in physics.__class__.__name__.lower(): # or 'pansharp' in physics.__class__.__name__.lower(): + pinv = physics.prox_l2(dec, self.gain_pinvx ** 2 * Adec - self.gain_pinvy ** 2 * y, gamma=1e9) + else: + pinv = physics.A_dagger(self.gain_pinvx ** 2 * Adec - self.gain_pinvy ** 2 * y) # TODO: do we set this to gain_gradx ? To get 0 during training too?? Better for denoising I guess + + # Mix grad and pinv + emb = grad - pinv # will be 0 in the case of denoising, but also inpainting + im_emb = dec - physics.A_adjoint_A(dec) # will be 0 in the case of denoising, but not inpainting # TODO: add gains here too + grad_large = emb + im_emb + + emb_grad = self.encoding_conv(grad_large) + emb_grad = self.relu_encoding(emb_grad) + return emb_grad + + def measurement_conditioning_config_E(self, x, y, physics, img_channels, scale=1): + dec = self.decoding_conv(x, img_channels) # go from shape = (B, C, H, W) to (B, 64, 64, 64) (independent of modality) + + physics.set_scale(scale) + + # TODO: check things are batched + f = physics.factor if hasattr(physics, "factor") else 1.0 + err = (physics.A_adjoint(physics.A(dec) - y)) + # snr = self.snr_module(err) + snr = dec.reshape(dec.shape[0], -1).abs().mean(dim=1) / (err.reshape(err.shape[0], -1).abs().mean(dim=1) + 1e-4) + + gamma = 1 / (1e-4 + 1 / (snr * f ** 2 + 1)) # TODO: check square-root / mean / check if we need to add a factor in front + gamma_est = gamma[(...,) + (None,) * (dec.dim() - 1)] + + prox = physics.prox_l2(dec, y, gamma=gamma_est * self.fact_prox) + emb = self.fact_prox_skip_1 * prox + self.fact_prox_skip_2 * dec + + emb_grad = self.encoding_conv(emb) + emb_grad = self.relu_encoding(emb_grad) + return emb_grad + + +class ResBlock(nn.Module): + def __init__( + self, + in_channels=64, + out_channels=64, + kernel_size=3, + stride=1, + padding=1, + bias=True, + mode="CRC", + negative_slope=0.2, + embedding=False, + emb_channels=None, + emb_physics=False, + img_channels=None, + decode_upscale=None, + config = 'A', + head=False, + tail=False, + N=4, + c_mult=1, + depth_encoding=1, + relu_in_encoding=False, + skip_in_encoding=True, + ): + super(ResBlock, self).__init__() + + if not head and not tail: + assert in_channels == out_channels, "Only support in_channels==out_channels." 
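+        # (editor's note) the block computes x + conv2(relu(conv1(x)))
+        # (plus an optional physics embedding), hence the channel constraint above.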
+ self.separate_head = isinstance(img_channels, list) + self.config = config + self.is_head = head + self.is_tail = tail + + if self.is_head: + self.head = InHead(img_channels, out_channels, input_layer=True) + + # if self.is_tail: + # self.tail = OutTail(in_channels, out_channels) + + if not self.is_head and not self.is_tail: + self.conv1 = conv( + in_channels, + out_channels, + kernel_size, + stride, + padding, + bias, + "C", + negative_slope, + ) + self.nl = nn.ReLU(inplace=True) + self.conv2 = conv( + out_channels, + out_channels, + kernel_size, + stride, + padding, + bias, + "C", + negative_slope, + ) + + if embedding: + self.gain = torch.nn.Parameter(torch.tensor([1.0]), requires_grad=True) + self.emb_linear = MPConv(emb_channels, out_channels, kernel=[]) + + self.emb_physics = emb_physics + + if self.emb_physics: + self.gain = torch.nn.Parameter(torch.tensor([1.0]), requires_grad=True) + self.PhysicsBlock = MeasCondBlock(out_channels=out_channels, config=config, c_mult=c_mult, + img_channels=img_channels, decode_upscale=decode_upscale, + N=N, depth_encoding=depth_encoding, + relu_in_encoding=relu_in_encoding, skip_in_encoding=skip_in_encoding) + + def forward(self, x, emb_sigma=None, physics=None, t=None, y=None, emb_in=None, img_channels=None, scale=0): + u = self.conv1(x) + u = self.nl(u) + u_2 = self.conv2(u) # Should we sum this with below? + if self.emb_physics: # TODO: add a factor (1+gain) to the emb_meas? that depends on the input snr + emb_grad = self.PhysicsBlock(u, y, physics, t, img_channels=img_channels, scale=scale) + u_1 = self.gain * emb_grad # x - grad (sign does not matter) + else: + u_1 = 0 + return x + u_2 + u_1 + + + + +def calculate_fan_in_and_fan_out(tensor, pytorch_style: bool = True): + """ + from https://github.com/megvii-research/basecls/blob/main/basecls/layers/wrapper.py#L77 + """ + if len(tensor.shape) not in (2, 4, 5): + raise ValueError( + "fan_in and fan_out can only be computed for tensor with 2/4/5 " + "dimensions" + ) + if len(tensor.shape) == 5: + # `GOIKK` to `OIKK` + tensor = tensor.reshape(-1, *tensor.shape[2:]) if pytorch_style else tensor[0] + + num_input_fmaps = tensor.shape[1] + num_output_fmaps = tensor.shape[0] + receptive_field_size = 1 + if len(tensor.shape) > 2: + receptive_field_size = functools.reduce(lambda x, y: x * y, tensor.shape[2:], 1) + fan_in = num_input_fmaps * receptive_field_size + fan_out = num_output_fmaps * receptive_field_size + return fan_in, fan_out + + +def weights_init_unext(m, gain_conv=1.0, gain_linear=1.0, init_type="ortho"): + if hasattr(m, "modules"): + for submodule in m.modules(): + if not 'skip' in str(submodule): + if isinstance(submodule, nn.Conv2d) or isinstance( + submodule, nn.ConvTranspose2d + ): + # nn.init.orthogonal_(submodule.weight.data, gain=1.0) + k_shape = submodule.weight.data.shape[-1] + if k_shape < 4: + nn.init.orthogonal_(submodule.weight.data, gain=0.2) + else: + _, fan_out = calculate_fan_in_and_fan_out(submodule.weight) + std = math.sqrt(2 / fan_out) + nn.init.normal_(submodule.weight, 0, std) + # if init_type == 'ortho': + # nn.init.orthogonal_(submodule.weight.data, gain=gain_conv) + # elif init_type == 'kaiming': + # nn.init.kaiming_normal_(submodule.weight.data, a=0, mode='fan_in') + # elif init_type == 'xavier': + # nn.init.xavier_normal_(submodule.weight.data, gain=gain_conv) + elif isinstance(submodule, nn.Linear): + nn.init.normal_(submodule.weight.data, std=0.01) + elif 'skip' in str(submodule): + if isinstance(submodule, nn.Conv2d) or isinstance( + submodule, 
nn.ConvTranspose2d + ): + nn.init.ones_(submodule.weight.data) + # else: + # classname = submodule.__class__.__name__ + # # print('WARNING: no init for ', classname) + +def old2new(old_key): + """ + Converting old DRUNet keys to new UNExt style keys. + + PATTERNS TO MATCH: + 1. Case of downsampling blocks: + - for residual blocks (non-downsampling): + m_down3.2.res.0.weight -> m_down3.enc.2.conv1.weight + - for downsampling blocks: + m_down3.4.weight -> m_down3.downsample_strideconv.weight + 2. Case of upsampling blocks: + - for upsampling: + m_up3.0.weight -> m_up3.upsample_convtranspose.weight + - for residual blocks: + m_up3.2.res.0.weight -> m_up3.enc.2.conv1.weight + 3. Case for body blocks: + m_body.0.res.2.weight -> m_body.enc.0.conv2.weight + + Args: + old_key (str): The old key from the state dictionary. + + Returns: + str or None: The new key if matched, otherwise None. + """ + # Match keys with the pattern for residual blocks (downsampling) + match_residual = re.search(r"(m_down\d+)\.(\d+)\.res\.(\d+)", old_key) + if match_residual: + prefix = match_residual.group(1) # e.g., "m_down2" + index = match_residual.group(2) # e.g., "3" + conv_index = int(match_residual.group(3)) # e.g., "0" + + # Determine the new conv index: 0 -> 1, 2 -> 2 + new_conv_index = 1 if conv_index == 0 else 2 + # Construct the new key + new_key = f"{prefix}.enc.{index}.conv{new_conv_index}.weight" + return new_key + + match_residual = re.search(r"(m_up\d+)\.(\d+)\.res\.(\d+)", old_key) + if match_residual: + prefix = match_residual.group(1) # e.g., "m_down2" + index = int(match_residual.group(2)) # e.g., "3" + conv_index = int(match_residual.group(3)) # e.g., "0" + + # Determine the new conv index: 0 -> 1, 2 -> 2 + new_conv_index = 1 if conv_index == 0 else 2 + # Construct the new key + new_key = f"{prefix}.enc.{index-1}.conv{new_conv_index}.weight" + return new_key + + match_pool_downsample = re.search(r"m_down(\d+)\.4\.weight", old_key) + if match_pool_downsample: + index = match_pool_downsample.group(1) # e.g., "1" or "2" + # Construct the new key + new_key = f"pool{index}.weight" + return new_key + + # Match keys for upsampling blocks + match_upsample = re.search(r"m_up(\d+)\.0\.weight", old_key) + if match_upsample: + index = match_upsample.group(1) # e.g., "1" or "2" + # Construct the new key + new_key = f"up{index}.weight" + return new_key + + # Match keys for body blocks + match_body = re.search(r"(m_body)\.(\d+)\.res\.(\d+)\.weight", old_key) + if match_body: + prefix = match_body.group(1) # e.g., "m_body" + index = match_body.group(2) # e.g., "0" + conv_index = int(match_body.group(3)) # e.g., "2" + + new_convindex = 1 if conv_index == 0 else 2 + + # Construct the new key + new_key = f"{prefix}.enc.{index}.conv{new_convindex}.weight" + return new_key + + # If no patterns match, return None + return None + +def update_keyvals_headtail(old_key, old_value, init_value=None, new_key_name='m_head.conv0.weight', conditioning='base'): + """ + Converting old DRUNet keys to new UNExt style keys. + + KEYS do not change but weight need to be 0 padded. + + Args: + old_key (str): The old key from the state dictionary. + """ + if 'head' in old_key: + if conditioning == 'base': + c_in = init_value.shape[1] + c_in_old = old_value.shape[1] + # if c_in == c_in_old: + # new_value = old_value.detach() + # elif c_in < c_in_old: + # new_value = torch.zeros_like(init_value.detach()) + # new_value[:, -1:, ...] = old_value[:, -1:, ...] + # new_value[:, :c_in-1, ...] = old_value[:, :c_in-1, ...] 
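+            # (editor's note) the new head takes two conditioning maps (noise
+            # level and gamma) instead of one, so the old noise-map weights are
+            # copied into both of the two last input slots below: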
+            # if c_in == c_in_old:
+            #     new_value = old_value.detach()
+            # elif c_in < c_in_old:
+            new_value = torch.zeros_like(init_value.detach())
+            new_value[:, -2:-1, ...] = old_value[:, -1:, ...]
+            new_value[:, -1:, ...] = old_value[:, -1:, ...]
+            new_value[:, :c_in-2, ...] = old_value[:, :c_in-2, ...]
+            return {new_key_name: new_value}
+        else:
+            c_in = init_value.shape[1]
+            c_in_old = old_value.shape[1]
+            # if c_in == c_in_old - 1:
+            #     new_value = old_value[:, :-1, ...].detach()
+            # elif c_in < c_in_old - 1:
+            #     new_value = torch.zeros_like(init_value.detach())
+            #     new_value[:, -1:, ...] = old_value[:, -1:, ...]
+            #     new_value[:, ...] = old_value[:, :c_in, ...]
+            new_value = torch.zeros_like(init_value.detach())
+            new_value[:, -1:-2, ...] = old_value[:, -1:, ...]
+            new_value[:, -1:, ...] = old_value[:, -1:, ...]
+            new_value[:, ...] = old_value[:, :c_in, ...]
+            return {new_key_name: new_value}
+    elif 'tail' in old_key:
+        c_in = init_value.shape[0]
+        c_in_old = old_value.shape[0]
+        new_value = torch.zeros_like(init_value.detach())
+        if c_in == c_in_old:
+            new_value = old_value.detach()
+        elif c_in < c_in_old:
+            new_value = torch.zeros_like(init_value.detach())
+            new_value[:, ...] = old_value[:c_in, ...]
+        return {new_key_name: new_value}
+    else:
+        print(f"Key {old_key} does not contain 'head' or 'tail'.")
+
+
+
+# test the network
+if __name__ == "__main__":
+    # NOTE (editor's fix): the defaults (conv_type="next", pool_type="next",
+    # gamma=None) do not build and run as-is; 'base' blocks and an explicit
+    # gamma are used here so that this smoke test actually runs, assuming the
+    # building blocks in models.blocks behave as expected.
+    net = UNeXt(conv_type="base", pool_type="base")
+    x = torch.randn(1, 3, 128, 128)
+    y = net(x, sigma=0.1, gamma=0.0)
+    # print(y.shape)
+    # print(y)
+
+
+# Case for diagonal physics
+# IDEA 1: kills signal in the image of A
+# im_emb = dec - physics.A_adjoint_A(dec) # will be 0 in the case of denoising, but not inpainting # TODO: add gains here too
+# IDEA 2: compute norm of signal in ker of A
+# normker = (dec - physics.A_adjoint_A(dec)).norm() / (dec.norm() + 1e-4)
+# im_emb = normker * physics.A_adjoint(self.gain_diag_x * physics.A(dec) - self.gain_diag_y * y) # will be 0 in the case of denoising, but not inpainting # TODO: add gains here too
+# IDEA 3: same as above but add the pinv as well
+# normker = (dec - physics.A_adjoint_A(dec)).norm() / (dec.norm() + 1e-4)
+# grad_term = physics.A_adjoint(self.gain_diag_x * physics.A(dec) - self.gain_diag_y * y)
+# # pinv_term = physics.A_dagger(self.gain_diagpinv_x * physics.A(dec) - self.gain_diagpinv_y * y)
+# if 'tomography' in physics.__class__.__name__.lower(): # or 'pansharp' in physics.__class__.__name__.lower():
+#     pinv_term = physics.prox_l2(dec, self.gain_diagpinv_x ** 2 * Adec - self.gain_diagpinv_y ** 2 * y, gamma=1e9)
+# else:
+#     pinv_term = physics.A_dagger(self.gain_diagpinv_x ** 2 * Adec - self.gain_diagpinv_y ** 2 * y) # TODO: do we set this to gain_gradx ? To get 0 during training too?? Better for denoising I guess
+# im_emb = normker * (grad_term + pinv_term) # will be 0 in the case of denoising, but not inpainting # TODO: add gains here too
+
+# # Mix it
+# if hasattr(physics.noise_model, 'sigma'):
+#     sigma = physics.noise_model.sigma # SNR ? x /= sigma ** 2
+# snr = (y.abs().mean()) / (sigma + 1e-4) # SNR equivariant # TODO: add epsilon
+# snr = snr[(...,) + (None,) * (im_emb.dim() - 1)]
+# else:
+# snr = 1e4
+#
+# grad_large = emb + self.gain_diag * (1 + self.gain_noise / snr) * im_emb
\ No newline at end of file
diff --git a/models/unrolled_dpir.py b/models/unrolled_dpir.py
new file mode 100644
index 0000000000000000000000000000000000000000..1adb982f7553e59e30ba2f4b9842b9d2abfe9eb1
--- /dev/null
+++ b/models/unrolled_dpir.py
@@ -0,0 +1,304 @@
+import numpy as np
+import deepinv
+import torch
+import deepinv as dinv
+from deepinv.optim.data_fidelity import L2
+from deepinv.optim.prior import PnP
+from deepinv.unfolded import unfolded_builder
+import copy
+import deepinv.optim.utils
+
+class PoissonGaussianDistance(dinv.optim.Distance):
+    r"""
+    Implementation of :math:`\distancename` as the weighted :math:`\ell_2` norm
+
+    .. math::
+        f(x) = \frac{1}{2}(x-y)^{T}\Sigma_y^{-1}(x-y)
+
+    with :math:`\Sigma_y=\text{diag}(\gamma y + \sigma^2)`
+
+    :param float sigma: Gaussian noise parameter. Default: 1.
+    :param float gain: Poisson noise parameter. Default 0.
+    """
+
+    def __init__(self, sigma=1.0, gain=0.):
+        super().__init__()
+        self.sigma = sigma
+        self.gain = gain
+
+    def fn(self, x, y, *args, **kwargs):
+        r"""
+        Computes the distance :math:`\distance{x}{y}`, i.e.
+
+        .. math::
+
+            \distance{x}{y} = \frac{1}{2}(x-y)^{T}\Sigma_y^{-1}(x-y)
+
+
+        :param torch.Tensor x: Variable :math:`x` at which the data fidelity is computed.
+        :param torch.Tensor y: Data :math:`y`.
+        :return: (:class:`torch.Tensor`) data fidelity :math:`\datafid{u}{y}` of size `B` with `B` the size of the batch.
+        """
+        norm = 1.0 / (self.sigma**2 + y * self.gain)
+        # weight by the square root of the inverse covariance inside the
+        # squared norm, so that fn is consistent with grad and prox below
+        z = (x - y) * norm.sqrt()
+        d = 0.5 * torch.norm(z.reshape(z.shape[0], -1), p=2, dim=-1) ** 2
+        return d
+
+    def grad(self, x, y, *args, **kwargs):
+        r"""
+        Computes the gradient of :math:`\distancename`, that is :math:`\nabla_{x}\distance{x}{y}`, i.e.
+
+        .. math::
+
+            \nabla_{x}\distance{x}{y} = \Sigma_y^{-1}(x-y)
+
+
+        :param torch.Tensor x: Variable :math:`x` at which the gradient is computed.
+        :param torch.Tensor y: Observation :math:`y`.
+        :return: (:class:`torch.Tensor`) gradient of the distance function :math:`\nabla_{x}\distance{x}{y}`.
+        """
+        norm = 1.0 / (self.sigma**2 + y * self.gain)
+        return (x - y) * norm
+
+    def prox(self, x, y, *args, gamma=1.0, **kwargs):
+        r"""
+        Proximal operator of :math:`\gamma \distance{x}{y} = \frac{\gamma}{2} \|x-y\|^2_{\Sigma_y^{-1}}`.
+
+        Computes :math:`\operatorname{prox}_{\gamma \distancename}`, i.e.
+
+        .. math::
+
+            \operatorname{prox}_{\gamma \distancename} = \underset{u}{\text{argmin}} \frac{\gamma}{2}\|u-y\|^2_{\Sigma_y^{-1}}+\frac{1}{2}\|u-x\|_2^2
+
+
+        :param torch.Tensor x: Variable :math:`x` at which the proximity operator is computed.
+        :param torch.Tensor y: Data :math:`y`.
+        :param float gamma: thresholding parameter.
+        :return: (:class:`torch.Tensor`) proximity operator :math:`\operatorname{prox}_{\gamma \distancename}(x)`.
+        """
+        norm = 1.0 / (self.sigma**2 + y * self.gain)
+        return (x + norm * gamma * y) / (1 + gamma * norm)
+
+
+class PoissonGaussianDataFidelity(dinv.optim.DataFidelity):
+    r"""
+    Implementation of the data-fidelity term as the weighted :math:`\ell_2` norm
+
+    .. math::
+
+        f(x) = \frac{1}{2}\|\forw{x}-y\|^2_{\Sigma_y^{-1}}, \quad \Sigma_y=\text{diag}(\sigma^2 + \gamma y)
+
+    It can be used to define a log-likelihood function associated with Poisson-Gaussian noise
+    by setting an appropriate noise level :math:`\sigma`.
+
+    :param float sigma: Standard deviation of the noise to be used as a normalisation factor.
+ :param float gain: Gain factor of the data-fidelity term. + """ + + def __init__(self, sigma=1.0, gain=0.): + super().__init__() + self.d = PoissonGaussianDistance(sigma=sigma, gain=gain) + self.gain = gain + self.sigma = sigma + + def prox(self, x, y, physics, gamma=1.0, *args, **kwargs): + r""" + Proximal operator of :math:`\gamma \datafid{Ax}{y} = \frac{\gamma}{2\sigma^2}\|Ax-y\|^2`. + + Computes :math:`\operatorname{prox}_{\gamma \datafidname}`, i.e. + + .. math:: + + \operatorname{prox}_{\gamma \datafidname} = \underset{u}{\text{argmin}} \frac{\gamma}{2\sigma^2}\|Au-y\|_2^2+\frac{1}{2}\|u-x\|_2^2 + + + :param torch.Tensor x: Variable :math:`x` at which the proximity operator is computed. + :param torch.Tensor y: Data :math:`y`. + :param deepinv.physics.Physics physics: physics model. + :param float gamma: stepsize of the proximity operator. + :return: (:class:`torch.Tensor`) proximity operator :math:`\operatorname{prox}_{\gamma \datafidname}(x)`. + """ + assert isinstance(physics, dinv.physics.LinearPhysics), "not implemented for non-linear physics" + if isinstance(physics, dinv.physics.StackedPhysics): + device=y[0].device + noise_model = physics[-1].noise_model + else: + device=y.device + noise_model = physics.noise_model + if hasattr(noise_model, "gain"): + self.gain = noise_model.gain.detach().to(device) + if hasattr(noise_model, "sigma"): + self.sigma = noise_model.sigma.detach().to(device) + # Ensure sigma is a tensor and reshape if necessary + if isinstance(self.sigma, float): + self.sigma = torch.tensor([self.sigma], device=device) + if self.sigma.ndim == 0 : + self.sigma = self.sigma.unsqueeze(0).to(device) + # Ensure gain is a tensor and reshape if necessary + if isinstance(self.gain, float): + self.gain = torch.tensor([self.gain], device=device) + if self.gain.ndim == 0 : + self.gain = self.gain.unsqueeze(0).to(device) + if self.gain[0] > 0 : + norm = gamma / (self.sigma[:, None, None, None]**2 + y * self.gain[:, None, None, None]) + else : + norm = gamma / (self.sigma[:, None, None, None]**2) + A = lambda u: physics.A_adjoint(physics.A(u)*norm) + u + b = physics.A_adjoint(norm*y) + x + return deepinv.optim.utils.conjugate_gradient(A, b, init=x, max_iter=3, tol=1e-3) + +from deepinv.optim.optim_iterators import OptimIterator, fStep, gStep + +class myHQSIteration(OptimIterator): + r""" + Single iteration of half-quadratic splitting. + + Class for a single iteration of the Half-Quadratic Splitting (HQS) algorithm for minimising :math:`f(x) + \lambda \regname(x)`. + The iteration is given by + + + .. math:: + \begin{equation*} + \begin{aligned} + u_{k} &= \operatorname{prox}_{\gamma f}(x_k) \\ + x_{k+1} &= \operatorname{prox}_{\sigma \lambda \regname}(u_k). + \end{aligned} + \end{equation*} + + + where :math:`\gamma` and :math:`\sigma` are step-sizes. Note that this algorithm does not converge to + a minimizer of :math:`f(x) + \lambda \regname(x)`, but instead to a minimizer of + :math:`\gamma\, ^1f+\sigma \lambda \regname`, where :math:`^1f` denotes + the Moreau envelope of :math:`f` + + """ + + def __init__(self, **kwargs): + super(myHQSIteration, self).__init__(**kwargs) + self.g_step = mygStepHQS(**kwargs) + self.f_step = myfStepHQS(**kwargs) + self.requires_prox_g = True + +class myfStepHQS(fStep): + r""" + HQS fStep module. + """ + + def __init__(self, **kwargs): + super(myfStepHQS, self).__init__(**kwargs) + + def forward(self, x, cur_data_fidelity, cur_params, y, physics): + r""" + Single proximal step on the data-fidelity term :math:`f`. 
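+
+        (editor's note) Concretely, this computes
+        :math:`u_k = \operatorname{prox}_{\gamma f(\cdot, y)}(x_k)`, with
+        :math:`\gamma` the current ``stepsize`` entry of ``cur_params``.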
+
+        :param torch.Tensor x: Current iterate :math:`x_k`.
+        :param deepinv.optim.DataFidelity cur_data_fidelity: Instance of the DataFidelity class defining the current data_fidelity.
+        :param dict cur_params: Dictionary containing the current parameters of the algorithm.
+        :param torch.Tensor y: Input data.
+        :param deepinv.physics.Physics physics: Instance of the physics modeling the data-fidelity term.
+        """
+        return cur_data_fidelity.prox(x, y, physics, gamma=cur_params["stepsize"])
+
+class mygStepHQS(gStep):
+    r"""
+    HQS gStep module.
+    """
+
+    def __init__(self, **kwargs):
+        super(mygStepHQS, self).__init__(**kwargs)
+
+    def forward(self, x, cur_prior, cur_params):
+        r"""
+        Single proximal step on the prior term :math:`\lambda \regname`.
+
+        :param torch.Tensor x: Current iterate :math:`x_k`.
+        :param dict cur_prior: Class containing the current prior.
+        :param dict cur_params: Dictionary containing the current parameters of the algorithm.
+        """
+        return cur_prior.prox(
+            x,
+            sigma_denoiser=cur_params["g_param"],
+            gain_denoiser=cur_params["gain_param"],
+            gamma=cur_params["lambda"] * cur_params["stepsize"],
+        )
+
+
+def get_unrolled_architecture(gain_param_init=1e-3, weight_tied=True, model=None, device='cpu'):
+
+    # Unrolled optimization algorithm parameters
+    max_iter = 8  # number of unfolded layers
+
+    # Set up the trainable denoising prior.
+    # Here the prior model is common for all iterations.
+    if model is not None:
+        denoiser = model.to(device)
+    else:
+        denoiser = dinv.models.DRUNet(
+            pretrained='/lustre/fswork/projects/rech/nyd/commun/mterris/base_checkpoints/drunet_deepinv_color_finetune_22k.pth',
+        ).to(device)
+
+    class myPnP(PnP):
+        r"""
+        PnP prior whose proximal step is the denoiser itself (inputs are
+        zero-padded to a multiple of 8 at evaluation time).
+        """
+
+        def __init__(self, *args, **kwargs):
+            super().__init__(*args, **kwargs)
+
+        def prox(self, x, sigma_denoiser, gain_denoiser, *args, **kwargs):
+            if not self.training:
+                pad = (-x.size(-2) % 8, -x.size(-1) % 8)
+                x = torch.nn.functional.pad(x, (0, pad[1], 0, pad[0]), mode="constant")
+            out = self.denoiser(x, sigma=sigma_denoiser, gamma=gain_denoiser)
+            if not self.training:
+                out = out[..., : -pad[0] or None, : -pad[1] or None]
+            return out
+
+    # Select the data-fidelity term (Poisson-Gaussian log-likelihood).
+    data_fidelity = PoissonGaussianDataFidelity()
+
+    if not weight_tied:
+        prior = [myPnP(denoiser=copy.deepcopy(denoiser)) for i in range(max_iter)]
+    else:
+        prior = [myPnP(denoiser=denoiser)]
+
+    def get_DPIR_params(noise_level_img, max_iter=8):
+        r"""
+        Default parameters for the DPIR Plug-and-Play algorithm.
+
+        :param float noise_level_img: Noise level of the input image.
+        """
+        s1 = 49.0 / 255.0
+        s2 = noise_level_img
+        sigma_denoiser = np.logspace(np.log10(s1), np.log10(s2), max_iter).astype(
+            np.float32
+        )
+        stepsize = (sigma_denoiser / max(0.01, noise_level_img)) ** 2
+        lamb = 1 / 0.23
+        return list(sigma_denoiser), list(lamb * stepsize)
+
+    sigma_denoiser, stepsize = get_DPIR_params(0.05)
+    stepsize = torch.tensor(stepsize) * (torch.tensor(sigma_denoiser)**2)
+    gain_denoiser = [gain_param_init]*len(sigma_denoiser)
+    params_algo = {"stepsize": stepsize, "g_param": sigma_denoiser, "gain_param": gain_denoiser}
+
+    trainable_params = [
+        "g_param",
+        "gain_param",
+        "stepsize",
+    ]  # define which parameters from 'params_algo' are trainable
+
+    # Define the unfolded trainable model.
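+    # (editor's note) unfolded_builder chains max_iter myHQSIteration steps
+    # (f-step: prox of the Poisson-Gaussian fidelity; g-step: the denoiser)
+    # into a single trainable module, typically called as model(y, physics).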
+
+    sigma_denoiser, stepsize = get_DPIR_params(0.05)
+    stepsize = torch.tensor(stepsize) * (torch.tensor(sigma_denoiser) ** 2)
+    gain_denoiser = [gain_param_init] * len(sigma_denoiser)
+    params_algo = {"stepsize": stepsize, "g_param": sigma_denoiser, "gain_param": gain_denoiser}
+
+    trainable_params = [
+        "g_param",
+        "gain_param",
+        "stepsize",
+    ]  # define which parameters from 'params_algo' are trainable
+
+    # Define the unfolded trainable model.
+    model = unfolded_builder(
+        iteration=myHQSIteration(),
+        params_algo=params_algo.copy(),
+        trainable_params=trainable_params,
+        data_fidelity=data_fidelity,
+        max_iter=max_iter,
+        prior=prior,
+        device=device,
+    )
+
+    return model.to(device)
diff --git a/physics/__init__.py b/physics/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/physics/blur_generator.py b/physics/blur_generator.py
new file mode 100644
index 0000000000000000000000000000000000000000..59943c851cae72bcf7c6dabf06aa15a69ff5779a
--- /dev/null
+++ b/physics/blur_generator.py
@@ -0,0 +1,148 @@
+import torch
+import torchvision
+
+from deepinv.physics.blur import rotate
+from deepinv.physics.generator import PSFGenerator
+
+
+def gaussian_blur_padded(sigma=(1, 1), angle=0, filt_size=None):
+    r"""
+    Padded gaussian blur filter.
+
+    Defined as
+
+    .. math::
+        \begin{equation*}
+        G(x, y) = \frac{1}{2\pi\sigma_x\sigma_y} \exp{\left(-\frac{x'^2}{2\sigma_x^2} - \frac{y'^2}{2\sigma_y^2}\right)}
+        \end{equation*}
+
+    where :math:`x'` and :math:`y'` are the rotated coordinates obtained by rotating :math:`(x, y)` around the origin
+    by an angle :math:`\theta`:
+
+    .. math::
+
+        \begin{align*}
+        x' &= x \cos(\theta) - y \sin(\theta) \\
+        y' &= x \sin(\theta) + y \cos(\theta)
+        \end{align*}
+
+    with :math:`\sigma_x` and :math:`\sigma_y` the standard deviations along the :math:`x'` and :math:`y'` axes.
+
+    :param float, tuple[float] sigma: standard deviation of the gaussian filter. If sigma is a float the filter is isotropic, whereas
+        if sigma is a tuple of floats (sigma_x, sigma_y) the filter is anisotropic.
+    :param float angle: rotation angle of the filter in degrees (only useful for anisotropic filters).
+    :param tuple[int] filt_size: optional spatial size (height, width) to which the kernel is zero-padded.
+    """
+    if isinstance(sigma, (int, float)):
+        sigma = (sigma, sigma)
+    # default to CPU when sigma is not a tensor (e.g. a tuple of floats)
+    device = sigma.device if isinstance(sigma, torch.Tensor) else "cpu"
+
+    s = max(sigma)
+    c = int(s / 0.3 + 1)
+    k_size = 2 * c + 1
+
+    delta = torch.arange(k_size).to(device)
+
+    x, y = torch.meshgrid(delta, delta, indexing="ij")
+    x = x - c
+    y = y - c
+    filt = (x / sigma[0]).pow(2)
+    filt += (y / sigma[1]).pow(2)
+    filt = torch.exp(-filt / 2.0)
+
+    filt = (
+        rotate(
+            filt.unsqueeze(0).unsqueeze(0),
+            angle,
+            interpolation=torchvision.transforms.InterpolationMode.BILINEAR,
+        )
+        .squeeze(0)
+        .squeeze(0)
+    )
+
+    filt = filt / filt.flatten().sum()
+
+    filt = filt.unsqueeze(0).unsqueeze(0)
+
+    if filt_size is not None:
+        # F.pad pads the last dimension first: (left, right, top, bottom)
+        filt = torch.nn.functional.pad(
+            filt,
+            (
+                (filt_size[1] - filt.shape[-1]) // 2,
+                (filt_size[1] - filt.shape[-1] + 1) // 2,
+                (filt_size[0] - filt.shape[-2]) // 2,
+                (filt_size[0] - filt.shape[-2] + 1) // 2,
+            ),
+        )
+
+    return filt
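+
+# Example (a minimal sketch): an anisotropic kernel rotated by 45 degrees and
+# zero-padded to 15x15; the returned filter has shape (1, 1, 15, 15) and sums to one:
+#
+#   filt = gaussian_blur_padded(sigma=(2.0, 0.5), angle=45.0, filt_size=(15, 15))
+#   assert filt.shape == (1, 1, 15, 15)
+#   assert torch.isclose(filt.sum(), torch.tensor(1.0))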
+
+
+class GaussianBlurGenerator(PSFGenerator):
+
+    def __init__(
+        self,
+        psf_size: tuple,
+        num_channels: int = 1,
+        device: str = "cpu",
+        dtype: type = torch.float32,
+        l: float = 0.3,
+        sigma: float = 0.25,
+        sigma_min: float = 0.01,
+        sigma_max: float = 4.0,
+    ) -> None:
+        kwargs = {
+            "l": l,
+            "sigma": sigma,
+            "sigma_min": sigma_min,
+            "sigma_max": sigma_max,
+        }
+        if len(psf_size) != 2:
+            raise ValueError(
+                "psf_size must be 2D. Add channels via the num_channels parameter"
+            )
+        super().__init__(
+            psf_size=psf_size,
+            num_channels=num_channels,
+            device=device,
+            dtype=dtype,
+            **kwargs,
+        )
+
+    def step(self, batch_size: int = 1, sigma: float = None, **kwargs):
+        r"""
+        Generate a batch of random Gaussian blur PSFs, with standard deviations drawn
+        uniformly in :math:`[\sigma_{\min}, \sigma_{\max}]` and rotation angles drawn
+        uniformly in :math:`[0, 180)` degrees.
+
+        :param int batch_size: batch size.
+        :param float sigma: unused; kept for API compatibility with other generators.
+
+        :return: dictionary with key **'filter'**: the generated PSF of shape `(batch_size, num_channels, psf_size[0], psf_size[1])`
+        """
+
+        sigmas = [
+            self.sigma_min
+            + torch.rand(2, **self.factory_kwargs) * (self.sigma_max - self.sigma_min)
+            for _ in range(batch_size)
+        ]
+        angles = [
+            (torch.rand(1, **self.factory_kwargs) * 180.0).item()
+            for _ in range(batch_size)
+        ]
+
+        kernels = [
+            gaussian_blur_padded(sigma, angle, filt_size=self.psf_size)
+            for sigma, angle in zip(sigmas, angles)
+        ]
+        kernel = torch.cat(kernels, dim=0)
+
+        return {
+            "filter": kernel.expand(
+                -1,
+                self.num_channels,
+                -1,
+                -1,
+            )
+        }
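+
+
+if __name__ == "__main__":
+    # Quick smoke test (sketch): draw a batch of random anisotropic PSFs and
+    # check that shapes match and each PSF is normalized.
+    gen = GaussianBlurGenerator(psf_size=(31, 31), num_channels=3)
+    params = gen.step(batch_size=4)
+    assert params["filter"].shape == (4, 3, 31, 31)
+    print(params["filter"].sum(dim=(-2, -1)))  # each PSF sums to ~1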
diff --git a/physics/inpainting_generator.py b/physics/inpainting_generator.py
new file mode 100644
index 0000000000000000000000000000000000000000..0d3d025877881306dfa8741170c54df7e26df75d
--- /dev/null
+++ b/physics/inpainting_generator.py
@@ -0,0 +1,107 @@
+import torch
+from deepinv.physics.generator import PhysicsGenerator
+
+
+class InpaintingMaskGenerator(PhysicsGenerator):
+
+    def __init__(
+        self,
+        mask_shape: tuple,
+        num_channels: int = 1,
+        device: str = "cpu",
+        dtype: type = torch.float32,
+        block_size_ratio=0.1,
+        num_blocks=5,
+    ) -> None:
+        kwargs = {
+            "mask_shape": mask_shape,
+            "block_size_ratio": block_size_ratio,
+            "num_blocks": num_blocks,
+        }
+        if len(mask_shape) != 2:
+            raise ValueError(
+                "mask_shape must be 2D. Add channels via the num_channels parameter"
+            )
+        super().__init__(
+            num_channels=num_channels,
+            device=device,
+            dtype=dtype,
+            **kwargs,
+        )
+
+    def generate_mask(self, image_shape, block_size_ratio, num_blocks):
+        # Create an all-ones tensor which will serve as the initial mask
+        mask = torch.ones(image_shape)
+        batch_size = mask.shape[0]
+
+        # Calculate block size based on the image dimensions and block_size_ratio
+        block_width = int(image_shape[-1] * block_size_ratio)
+        block_height = int(image_shape[-2] * block_size_ratio)
+
+        # Generate random top-left coordinates for each block in each batch
+        x_coords = torch.randint(
+            0, image_shape[-1] - block_width, (batch_size, num_blocks)
+        )
+        y_coords = torch.randint(
+            0, image_shape[-2] - block_height, (batch_size, num_blocks)
+        )
+
+        # Create grids of indices for the block dimensions
+        x_range = torch.arange(block_width).view(1, 1, -1)
+        y_range = torch.arange(block_height).view(1, 1, -1)
+
+        # Expand ranges to match the batch and num_blocks dimensions
+        x_indices = x_coords.unsqueeze(-1) + x_range
+        y_indices = y_coords.unsqueeze(-1) + y_range
+
+        # Expand and flatten the indices for advanced indexing
+        x_indices = x_indices.unsqueeze(2).expand(-1, -1, block_height, -1).reshape(-1)
+        y_indices = y_indices.unsqueeze(3).expand(-1, -1, -1, block_width).reshape(-1)
+
+        # Create batch indices for advanced indexing
+        batch_indices = (
+            torch.arange(batch_size)
+            .view(-1, 1, 1)
+            .expand(-1, num_blocks, block_width * block_height)
+            .reshape(-1)
+        )
+
+        # Zero out the blocks across all channels using advanced indexing
+        mask[batch_indices, :, y_indices, x_indices] = 0
+
+        return mask
+
+    def step(
+        self, batch_size: int = 1, block_size_ratio: float = None, num_blocks=None
+    ):
+        r"""
+        Generate a random block-inpainting mask.
+
+        :param int batch_size: batch size.
+        :param float block_size_ratio: side length of each masked block, as a fraction of the image side.
+        :param int num_blocks: number of masked blocks per image.
+
+        :return: dictionary with key **'mask'**: the generated mask of shape `(batch_size, num_channels, mask_shape[0], mask_shape[1])`
+        """
+
+        # TODO: add randomness
+        block_size_ratio = (
+            self.block_size_ratio if block_size_ratio is None else block_size_ratio
+        )
+        num_blocks = self.num_blocks if num_blocks is None else num_blocks
+        batch_shape = (
+            batch_size,
+            self.num_channels,
+            self.mask_shape[-2],
+            self.mask_shape[-1],
+        )
+
+        mask = self.generate_mask(batch_shape, block_size_ratio, num_blocks)
+
+        return {"mask": mask.to(self.factory_kwargs["device"])}
diff --git a/physics/multiscale.py b/physics/multiscale.py
new file mode 100644
index 0000000000000000000000000000000000000000..afba5853442106d07bab3b6890c8d69ef4b65daf
--- /dev/null
+++ b/physics/multiscale.py
@@ -0,0 +1,84 @@
+import torch
+from deepinv.physics import Physics, LinearPhysics, Downsampling
+
+
+class Upsampling(Downsampling):
+    # Swap the roles of A and A_adjoint of Downsampling to obtain an upsampler
+    def A(self, x, **kwargs):
+        return super().A_adjoint(x, **kwargs)
+
+    def A_adjoint(self, y, **kwargs):
+        return super().A(y, **kwargs)
+
+    def prox_l2(self, z, y, gamma, **kwargs):
+        return super().prox_l2(z, y, gamma, **kwargs)
+
+
+class MultiScalePhysics(Physics):
+    def __init__(self, physics, img_shape, filter="sinc", scales=[2, 4, 8], device='cpu', **kwargs):
+        super().__init__(noise_model=physics.noise_model, **kwargs)
+        self.base = physics
+        self.scales = scales
+        self.img_shape = img_shape
+        self.Upsamplings = [
+            Upsampling(img_size=img_shape, filter=filter, factor=factor, device=device)
+            for factor in scales
+        ]
+        self.scale = 0
+
+    def set_scale(self, scale):
+        if scale is not None:
+            self.scale = scale
+
+    def A(self, x, scale=None, **kwargs):
+        self.set_scale(scale)
+        if self.scale == 0:
+            return self.base.A(x, **kwargs)
+        else:
+            return self.base.A(self.Upsamplings[self.scale - 1].A(x), **kwargs)
+
+    def downsample(self, x, scale=None):
+        self.set_scale(scale)
+        if self.scale == 0:
+            return x
+        else:
+            return self.Upsamplings[self.scale - 1].A_adjoint(x)
+
+    def upsample(self, x, scale=None):
+        self.set_scale(scale)
+        if self.scale == 0:
+            return x
+        else:
+            return self.Upsamplings[self.scale - 1].A(x)
+
+    def update_parameters(self, **kwargs):
+        self.base.update_parameters(**kwargs)
+
+
+class MultiScaleLinearPhysics(MultiScalePhysics, LinearPhysics):
+    def __init__(self, physics, img_shape, filter="sinc", scales=[2, 4, 8], **kwargs):
+        super().__init__(physics=physics, img_shape=img_shape, filter=filter, scales=scales, **kwargs)
+
+    def A_adjoint(self, y, scale=None, **kwargs):
+        self.set_scale(scale)
+        y = self.base.A_adjoint(y, **kwargs)
+        if self.scale == 0:
+            return y
+        else:
+            return self.Upsamplings[self.scale - 1].A_adjoint(y)
+
+
+class Pad(LinearPhysics):
+    # Wrap a physics operator so that A crops `pad` pixels at the top-left
+    # and A_adjoint zero-pads them back.
+    def __init__(self, physics, pad):
+        super().__init__(noise_model=physics.noise_model)
+        self.base = physics
+        self.pad = pad
+
+    def A(self, x):
+        return self.base.A(x[..., self.pad[0]:, self.pad[1]:])
+
+    def A_adjoint(self, y):
+        y = self.base.A_adjoint(y)
+        y = torch.nn.functional.pad(y, (self.pad[1], 0, self.pad[0], 0))
+        return y
+
+    def remove_pad(self, x):
+        return x[..., self.pad[0]:, self.pad[1]:]
+
+    def update_parameters(self, **kwargs):
+        self.base.update_parameters(**kwargs)
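+
+
+if __name__ == "__main__":
+    # Quick smoke test (a sketch, assuming a plain denoising base physics):
+    # at scale 1 the input is upsampled by a factor 2 before the base operator,
+    # so a 128x128 image is mapped to the full 256x256 measurement domain.
+    from deepinv.physics import Denoising, GaussianNoise
+
+    base = Denoising(noise_model=GaussianNoise(sigma=0.05))
+    physics = MultiScaleLinearPhysics(base, img_shape=(3, 256, 256))
+    y = physics.A(torch.rand(1, 3, 128, 128), scale=1)
+    assert y.shape == (1, 3, 256, 256)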
\ No newline at end of file
diff --git a/requirements.txt b/requirements.txt
index 016810d3b2673b72b0634dedbc151c1baec0fb4a..081c20836d5abbd18e49323d90d708f348d54688 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -1,4 +1,2 @@
 deepinv
-bm3d
 timm
-ptwt
\ No newline at end of file
diff --git a/utils.py b/utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..7e2a61391ce3bf8b2e97eff760ff2ef8fa70322e
--- /dev/null
+++ b/utils.py
@@ -0,0 +1,135 @@
+import torch
+import torch.nn as nn
+import deepinv as dinv
+
+from models.unext_wip import UNeXt
+from models.unrolled_dpir import get_unrolled_architecture
+from models.PDNet import get_PDNet_architecture
+from physics.multiscale import Pad
+
+
+class ArtifactRemoval(nn.Module):
+    r"""
+    Reconstruction wrapper that feeds a backbone network with the initial
+    estimate :math:`A^\top y` (or :math:`A^\dagger y` when ``pinv=True``).
+    """
+
+    def __init__(self, backbone_net, pinv=False, ckpt_path=None, device=None, fm_mode=False):
+        super().__init__()
+        self.pinv = pinv
+        self.backbone_net = backbone_net
+        self.fm_mode = fm_mode
+
+        if ckpt_path is not None:
+            self.backbone_net.load_state_dict(torch.load(ckpt_path), strict=True)
+            self.backbone_net.eval()
+
+        # Freeze DPIR-style denoisers, which are used as fixed priors
+        if type(self.backbone_net).__name__ == "UNetRes":
+            for _, v in self.backbone_net.named_parameters():
+                v.requires_grad = False
+            self.backbone_net = self.backbone_net.to(device)
+
+    def forward_basic(self, y=None, physics=None, x_in=None, t=None, **kwargs):
+        if physics is None:
+            physics = dinv.physics.Denoising(noise_model=dinv.physics.GaussianNoise(sigma=0.), device=y.device)
+
+        # At inference, pad the adjoint image to a multiple of 8 for the backbone
+        if not self.training:
+            x_temp = physics.A_adjoint(y)
+            pad = (-x_temp.size(-2) % 8, -x_temp.size(-1) % 8)
+            physics = Pad(physics, pad)
+
+        x_in = physics.A_adjoint(y) if not self.pinv else physics.A_dagger(y)
+
+        sigma = getattr(physics.noise_model, "sigma", 1e-3)
+        gamma = getattr(physics.noise_model, "gain", 1e-3)
+
+        out = self.backbone_net(x_in, physics=physics, y=y, sigma=sigma, gamma=gamma, t=t)
+
+        if not self.training:
+            out = physics.remove_pad(out)
+
+        return out
+
+    def forward(self, y=None, physics=None, x_in=None, **kwargs):
+        return self.forward_basic(physics=physics, y=y, x_in=x_in, **kwargs)
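+
+
+# Usage sketch (assuming a backbone whose forward accepts
+# (x_in, physics=..., y=..., sigma=..., gamma=..., t=...)):
+#
+#   model = get_model("unext_emb_physics_config_C", device="cpu")
+#   model.eval()
+#   physics = dinv.physics.Denoising(noise_model=dinv.physics.GaussianNoise(sigma=0.05))
+#   y = physics(torch.rand(1, 3, 256, 256))
+#   x_hat = model(y=y, physics=physics)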
+
+
+def get_model(
+    model_name="unext_emb_physics_config_C",
+    device="cpu",
+    in_channels=[1, 2, 3],
+    conv_type="base",
+    pool_type="base",
+    layer_scale_init_value=1e-6,
+    init_type="ortho",
+    gain_init_conv=1.0,
+    gain_init_linear=1.0,
+    drop_prob=0.0,
+    replk=False,
+    mult_fact=4,
+    antialias="gaussian",
+    nc_base=64,
+    cond_type="base",
+    pretrained_pth=None,
+    weight_tied=True,
+    N=4,
+    c_mult=1,
+    depth_encoding=1,
+    relu_in_encoding=False,
+    skip_in_encoding=True,
+):
+    model_name = model_name.lower()
+    nc = [nc_base * 2 ** i for i in range(4)]
+
+    if model_name == "pdnet":
+        return get_PDNet_architecture(in_channels=in_channels, out_channels=in_channels, device=device)
+
+    elif model_name == "unrolled_dpir":
+        model = UNeXt(
+            in_channels=in_channels,
+            out_channels=in_channels,
+            device=device,
+            conv_type=conv_type,
+            pool_type=pool_type,
+            layer_scale_init_value=layer_scale_init_value,
+            init_type=init_type,
+            gain_init_conv=gain_init_conv,
+            gain_init_linear=gain_init_linear,
+            drop_prob=drop_prob,
+            replk=replk,
+            mult_fact=mult_fact,
+            antialias=antialias,
+            nc=nc,
+            cond_type=cond_type,
+            emb_physics=False,
+            config=None,
+            pretrained_pth=pretrained_pth,
+        ).to(device)
+        model = get_unrolled_architecture(model=model, weight_tied=weight_tied, device=device)
+        return ArtifactRemoval(model, pinv=True, device=device)
+
+    elif model_name == "unext_emb_physics_config_c":
+        model = UNeXt(
+            in_channels=in_channels,
+            out_channels=in_channels,
+            device=device,
+            conv_type=conv_type,
+            pool_type=pool_type,
+            layer_scale_init_value=layer_scale_init_value,
+            init_type=init_type,
+            gain_init_conv=gain_init_conv,
+            gain_init_linear=gain_init_linear,
+            drop_prob=drop_prob,
+            replk=replk,
+            mult_fact=mult_fact,
+            antialias=antialias,
+            nc=nc,
+            cond_type=cond_type,
+            emb_physics=True,
+            config="C",
+            pretrained_pth=pretrained_pth,
+            N=N,
+            c_mult=c_mult,
+            depth_encoding=depth_encoding,
+            relu_in_encoding=relu_in_encoding,
+            skip_in_encoding=skip_in_encoding,
+        ).to(device)
+        return ArtifactRemoval(model, pinv=False, device=device)
+
+    else:
+        raise ValueError(f"Model {model_name} is not supported.")
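+
+
+if __name__ == "__main__":
+    # Smoke test (a sketch; assumes the default UNeXt/PDNet hyper-parameters
+    # instantiate on CPU): build each supported architecture and count weights.
+    for name in ("pdnet", "unrolled_dpir", "unext_emb_physics_config_C"):
+        net = get_model(model_name=name, device="cpu")
+        print(name, sum(p.numel() for p in net.parameters()))
\ No newline at end of file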