diff --git a/.gitignore b/.gitignore new file mode 100755 index 0000000000000000000000000000000000000000..efc3862f00eef3c340f5b6e184e08b5a4f99bfbb --- /dev/null +++ b/.gitignore @@ -0,0 +1,25 @@ +/**/* +!/**/ +/venv/ +!*.ipynb +!*.gitignore +!*.md +!*.bat +!*.py +!*.yml +!*.ui +!*.yaml + +!requirements.txt +!version.txt + +/docs/build +!/docs/Makefile + +/build/ + + +/results +/scripts/local_trash + +!/media/** \ No newline at end of file diff --git a/README.md b/README.md index cf2fcaa29b79130719fd8b7b20af1d327222ee69..bd1f0f028993606f23f768e2ccff89077b5d4e63 100644 --- a/README.md +++ b/README.md @@ -1,6 +1,6 @@ --- title: Medfusion App -emoji: 🏢 +emoji: 🔬 colorFrom: pink colorTo: gray sdk: streamlit @@ -10,4 +10,64 @@ pinned: false license: mit --- -Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference +Medfusion - Medical Denoising Diffusion Probabilistic Model +============= + +Paper +======= +Please see: [**Diffusion Probabilistic Models beat GANs on Medical 2D Images**]() + +![](media/Medfusion.png) +*Figure: Medfusion* + +![](media/animation_eye.gif) ![](media/animation_histo.gif) ![](media/animation_chest.gif)\ +*Figure: Eye fundus, chest X-ray and colon histology images generated with Medfusion (Warning: color quality is limited by the GIF format)* + +Demo +============= +[Link]() to the Streamlit app. + +Install +============= + +Create a virtual environment and install the packages: \ +`python -m venv venv` \ +`source venv/bin/activate`\ +`pip install -e .` + + +Get Started +============= + +1 Prepare Data +------------- + +* Go to [medical_diffusion/data/datasets/dataset_simple_2d.py](medical_diffusion/data/datasets/dataset_simple_2d.py) and create a new `SimpleDataset2D` or write your own Dataset. + + +2 Train Autoencoder +---------------- +* Go to [scripts/train_latent_embedder_2d.py](scripts/train_latent_embedder_2d.py) and import your Dataset. +* Load your dataset with e.g. `SimpleDataModule` (see the sketch below) +* Customize `VAE` to your needs +* (Optional): Train a `VAEGAN` instead, or load a pre-trained `VAE` and set `start_gan_train_step=-1` to start GAN training immediately. + +2.1 Evaluate Autoencoder +---------------- +* Use [scripts/evaluate_latent_embedder.py](scripts/evaluate_latent_embedder.py) to evaluate the performance of the Autoencoder. + +3 Train Diffusion +---------------- +* Go to [scripts/train_diffusion.py](scripts/train_diffusion.py) and import/load your Dataset as before. +* Load your pre-trained VAE or VAEGAN with `latent_embedder_checkpoint=...` +* Use `cond_embedder = LabelEmbedder` for conditional training, otherwise `cond_embedder = None` + +3.1 Evaluate Diffusion +---------------- +* Go to [scripts/sample.py](scripts/sample.py) to sample a test image. +* Go to [scripts/helpers/sample_dataset.py](scripts/helpers/sample_dataset.py) to sample a larger, more representative set of images. 
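+
+*A minimal sketch of the dataset/datamodule wiring from steps 1 and 2 above (illustrative only, not code from this repository; the folder path, file extension, image size and batch size are placeholders to adapt):*
+
+```python
+from medical_diffusion.data.datasets import SimpleDataset2D
+from medical_diffusion.data.datamodules import SimpleDataModule
+
+# Placeholder folder with training images; adapt path, extension and size.
+ds_train = SimpleDataset2D(
+    path_root='data/my_dataset/train',
+    crawler_ext='jpg',               # file extension the crawler searches for
+    image_resize=(256, 256),
+    augment_horizontal_flip=True,
+)
+
+# get_weights() returns None for the base class, so SimpleDataModule falls
+# back to a plain RandomSampler; datasets such as AIROGSDataset return class
+# weights and trigger a WeightedRandomSampler instead.
+dm = SimpleDataModule(
+    ds_train=ds_train,
+    batch_size=8,
+    num_workers=4,
+    weights=ds_train.get_weights(),
+)
+```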
+* Use [scripts/evaluate_images.py](scripts/evaluate_images.py) to evaluate performance of sample (FID, Precision, Recall) + +Acknowledgment +============= +* Code builds upon https://github.com/lucidrains/denoising-diffusion-pytorch diff --git a/medical_diffusion/data/augmentation/__init__.py b/medical_diffusion/data/augmentation/__init__.py new file mode 100755 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/medical_diffusion/data/augmentation/augmentations_2d.py b/medical_diffusion/data/augmentation/augmentations_2d.py new file mode 100755 index 0000000000000000000000000000000000000000..2154a4898ea659f42aa010ad9c316d7cf6c0453b --- /dev/null +++ b/medical_diffusion/data/augmentation/augmentations_2d.py @@ -0,0 +1,27 @@ + +import torch +import numpy as np + +class ToTensor16bit(object): + """PyTorch can not handle uint16 only int16. First transform to int32. Note, this function also adds a channel-dim""" + def __call__(self, image): + # return torch.as_tensor(np.array(image, dtype=np.int32)[None]) + # return torch.from_numpy(np.array(image, np.int32, copy=True)[None]) + image = np.array(image, np.int32, copy=True) # [H,W,C] or [H,W] + image = np.expand_dims(image, axis=-1) if image.ndim ==2 else image + return torch.from_numpy(np.moveaxis(image, -1, 0)) #[C, H, W] + +class Normalize(object): + """Rescale the image to [0,1] and ensure float32 dtype """ + + def __call__(self, image): + image = image.type(torch.FloatTensor) + return (image-image.min())/(image.max()-image.min()) + + +class RandomBackground(object): + """Fill Background (intensity ==0) with random values""" + + def __call__(self, image): + image[image==0] = torch.rand(*image[image==0].shape) #(image.max()-image.min()) + return image \ No newline at end of file diff --git a/medical_diffusion/data/augmentation/augmentations_3d.py b/medical_diffusion/data/augmentation/augmentations_3d.py new file mode 100755 index 0000000000000000000000000000000000000000..d6b6012d5f50a7d26017daf641eb5eed1c2be639 --- /dev/null +++ b/medical_diffusion/data/augmentation/augmentations_3d.py @@ -0,0 +1,38 @@ +import torchio as tio +from typing import Union, Optional, Sequence +from torchio.typing import TypeTripletInt +from torchio import Subject, Image +from torchio.utils import to_tuple + +class CropOrPad_None(tio.CropOrPad): + def __init__( + self, + target_shape: Union[int, TypeTripletInt, None] = None, + padding_mode: Union[str, float] = 0, + mask_name: Optional[str] = None, + labels: Optional[Sequence[int]] = None, + **kwargs + ): + + # WARNING: Ugly workaround to allow None values + if target_shape is not None: + self.original_target_shape = to_tuple(target_shape, length=3) + target_shape = [1 if t_s is None else t_s for t_s in target_shape] + super().__init__(target_shape, padding_mode, mask_name, labels, **kwargs) + + def apply_transform(self, subject: Subject): + # WARNING: This makes the transformation subject dependent - reverse transformation must be adapted + if self.target_shape is not None: + self.target_shape = [s_s if t_s is None else t_s for t_s, s_s in zip(self.original_target_shape, subject.spatial_shape)] + return super().apply_transform(subject=subject) + + +class SubjectToTensor(object): + """Transforms TorchIO Subjects into a Python dict and changes axes order from TorchIO to Torch""" + def __call__(self, subject: Subject): + return {key: val.data.swapaxes(1,-1) if isinstance(val, Image) else val for key,val in subject.items()} + +class ImageToTensor(object): + 
"""Transforms TorchIO Image into a Numpy/Torch Tensor and changes axes order from TorchIO [B, C, W, H, D] to Torch [B, C, D, H, W]""" + def __call__(self, image: Image): + return image.data.swapaxes(1,-1) \ No newline at end of file diff --git a/medical_diffusion/data/datamodules/__init__.py b/medical_diffusion/data/datamodules/__init__.py new file mode 100755 index 0000000000000000000000000000000000000000..673fee178d4234ce9219704234743d3c934dbed2 --- /dev/null +++ b/medical_diffusion/data/datamodules/__init__.py @@ -0,0 +1 @@ +from .datamodule_simple import SimpleDataModule \ No newline at end of file diff --git a/medical_diffusion/data/datamodules/datamodule_simple.py b/medical_diffusion/data/datamodules/datamodule_simple.py new file mode 100755 index 0000000000000000000000000000000000000000..7d8eda7f5a08905d72ce678b3de9c8d58fa7ef76 --- /dev/null +++ b/medical_diffusion/data/datamodules/datamodule_simple.py @@ -0,0 +1,79 @@ + +import pytorch_lightning as pl +import torch +from torch.utils.data.dataloader import DataLoader +import torch.multiprocessing as mp +from torch.utils.data.sampler import WeightedRandomSampler, RandomSampler + + + +class SimpleDataModule(pl.LightningDataModule): + + def __init__(self, + ds_train: object, + ds_val:object =None, + ds_test:object =None, + batch_size: int = 1, + num_workers: int = mp.cpu_count(), + seed: int = 0, + pin_memory: bool = False, + weights: list = None + ): + super().__init__() + self.hyperparameters = {**locals()} + self.hyperparameters.pop('__class__') + self.hyperparameters.pop('self') + + self.ds_train = ds_train + self.ds_val = ds_val + self.ds_test = ds_test + + self.batch_size = batch_size + self.num_workers = num_workers + self.seed = seed + self.pin_memory = pin_memory + self.weights = weights + + + + def train_dataloader(self): + generator = torch.Generator() + generator.manual_seed(self.seed) + + if self.weights is not None: + sampler = WeightedRandomSampler(self.weights, len(self.weights), generator=generator) + else: + sampler = RandomSampler(self.ds_train, replacement=False, generator=generator) + return DataLoader(self.ds_train, batch_size=self.batch_size, num_workers=self.num_workers, + sampler=sampler, generator=generator, drop_last=True, pin_memory=self.pin_memory) + + + def val_dataloader(self): + generator = torch.Generator() + generator.manual_seed(self.seed) + if self.ds_val is not None: + return DataLoader(self.ds_val, batch_size=self.batch_size, num_workers=self.num_workers, shuffle=False, + generator=generator, drop_last=False, pin_memory=self.pin_memory) + else: + raise AssertionError("A validation set was not initialized.") + + + def test_dataloader(self): + generator = torch.Generator() + generator.manual_seed(self.seed) + if self.ds_test is not None: + return DataLoader(self.ds_test, batch_size=self.batch_size, num_workers=self.num_workers, shuffle=False, + generator = generator, drop_last=False, pin_memory=self.pin_memory) + else: + raise AssertionError("A test test set was not initialized.") + + + + + + + + + + + \ No newline at end of file diff --git a/medical_diffusion/data/datasets/__init__.py b/medical_diffusion/data/datasets/__init__.py new file mode 100755 index 0000000000000000000000000000000000000000..cca9ef494a2d44f0e027cf63edf8c5c6f0357394 --- /dev/null +++ b/medical_diffusion/data/datasets/__init__.py @@ -0,0 +1,2 @@ +from .dataset_simple_2d import * +from .dataset_simple_3d import * \ No newline at end of file diff --git a/medical_diffusion/data/datasets/dataset_simple_2d.py 
b/medical_diffusion/data/datasets/dataset_simple_2d.py new file mode 100755 index 0000000000000000000000000000000000000000..d8d953caf3d8a165aaf8300fda33f33f50132128 --- /dev/null +++ b/medical_diffusion/data/datasets/dataset_simple_2d.py @@ -0,0 +1,198 @@ + +import torch.utils.data as data +import torch +from torch import nn +from pathlib import Path +from torchvision import transforms as T +import pandas as pd + +from PIL import Image + +from medical_diffusion.data.augmentation.augmentations_2d import Normalize, ToTensor16bit + +class SimpleDataset2D(data.Dataset): + def __init__( + self, + path_root, + item_pointers =[], + crawler_ext = 'tif', # other options are ['jpg', 'jpeg', 'png', 'tiff'], + transform = None, + image_resize = None, + augment_horizontal_flip = False, + augment_vertical_flip = False, + image_crop = None, + ): + super().__init__() + self.path_root = Path(path_root) + self.crawler_ext = crawler_ext + if len(item_pointers): + self.item_pointers = item_pointers + else: + self.item_pointers = self.run_item_crawler(self.path_root, self.crawler_ext) + + if transform is None: + self.transform = T.Compose([ + T.Resize(image_resize) if image_resize is not None else nn.Identity(), + T.RandomHorizontalFlip() if augment_horizontal_flip else nn.Identity(), + T.RandomVerticalFlip() if augment_vertical_flip else nn.Identity(), + T.CenterCrop(image_crop) if image_crop is not None else nn.Identity(), + T.ToTensor(), + # T.Lambda(lambda x: torch.cat([x]*3) if x.shape[0]==1 else x), + # ToTensor16bit(), + # Normalize(), # [0, 1.0] + # T.ConvertImageDtype(torch.float), + T.Normalize(mean=0.5, std=0.5) # WARNING: mean and std are not the target values but rather the values to subtract and divide by: [0, 1] -> [0-0.5, 1-0.5]/0.5 -> [-1, 1] + ]) + else: + self.transform = transform + + def __len__(self): + return len(self.item_pointers) + + def __getitem__(self, index): + rel_path_item = self.item_pointers[index] + path_item = self.path_root/rel_path_item + # img = Image.open(path_item) + img = self.load_item(path_item) + return {'uid':rel_path_item.stem, 'source': self.transform(img)} + + def load_item(self, path_item): + return Image.open(path_item).convert('RGB') + # return cv2.imread(str(path_item), cv2.IMREAD_UNCHANGED) # NOTE: Only CV2 supports 16bit RGB images + + @classmethod + def run_item_crawler(cls, path_root, extension, **kwargs): + return [path.relative_to(path_root) for path in Path(path_root).rglob(f'*.{extension}')] + + def get_weights(self): + """Return list of class-weights for WeightedSampling""" + return None + + +class AIROGSDataset(SimpleDataset2D): + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self.labels = pd.read_csv(self.path_root.parent/'train_labels.csv', index_col='challenge_id') + + def __len__(self): + return len(self.labels) + + def __getitem__(self, index): + uid = self.labels.index[index] + path_item = self.path_root/f'{uid}.jpg' + img = self.load_item(path_item) + str_2_int = {'NRG':0, 'RG':1} # RG = 3270, NRG = 98172 + target = str_2_int[self.labels.loc[uid, 'class']] + # return {'uid':uid, 'source': self.transform(img), 'target':target} + return {'source': self.transform(img), 'target':target} + + def get_weights(self): + n_samples = len(self) + weight_per_class = 1/self.labels['class'].value_counts(normalize=True) # {'NRG': 1.03, 'RG': 31.02} + weights = [0] * n_samples + for index in range(n_samples): + target = self.labels.iloc[index]['class'] + weights[index] = weight_per_class[target] + return weights + + @classmethod 
+ def run_item_crawler(cls, path_root, extension, **kwargs): + """Overwrite to speed up as paths are determined by .csv file anyway""" + return [] + +class MSIvsMSS_Dataset(SimpleDataset2D): + # https://doi.org/10.5281/zenodo.2530835 + def __getitem__(self, index): + rel_path_item = self.item_pointers[index] + path_item = self.path_root/rel_path_item + img = self.load_item(path_item) + uid = rel_path_item.stem + str_2_int = {'MSIMUT':0, 'MSS':1} + target = str_2_int[path_item.parent.name] # + return {'uid':uid, 'source': self.transform(img), 'target':target} + + +class MSIvsMSS_2_Dataset(SimpleDataset2D): + # https://doi.org/10.5281/zenodo.3832231 + def __getitem__(self, index): + rel_path_item = self.item_pointers[index] + path_item = self.path_root/rel_path_item + img = self.load_item(path_item) + uid = rel_path_item.stem + str_2_int = {'MSIH':0, 'nonMSIH':1} # patients with MSI-H = MSIH; patients with MSI-L and MSS = NonMSIH) + target = str_2_int[path_item.parent.name] + # return {'uid':uid, 'source': self.transform(img), 'target':target} + return {'source': self.transform(img), 'target':target} + + +class CheXpert_Dataset(SimpleDataset2D): + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + mode = self.path_root.name + labels = pd.read_csv(self.path_root.parent/f'{mode}.csv', index_col='Path') + self.labels = labels.loc[labels['Frontal/Lateral'] == 'Frontal'].copy() + self.labels.index = self.labels.index.str[20:] + self.labels.loc[self.labels['Sex'] == 'Unknown', 'Sex'] = 'Female' # Affects 1 case, must be "female" to match stats in publication + self.labels.fillna(2, inplace=True) # TODO: Find better solution, + str_2_int = {'Sex': {'Male':0, 'Female':1}, 'Frontal/Lateral':{'Frontal':0, 'Lateral':1}, 'AP/PA':{'AP':0, 'PA':1}} + self.labels.replace(str_2_int, inplace=True) + + def __len__(self): + return len(self.labels) + + def __getitem__(self, index): + rel_path_item = self.labels.index[index] + path_item = self.path_root/rel_path_item + img = self.load_item(path_item) + uid = str(rel_path_item) + target = torch.tensor(self.labels.loc[uid, 'Cardiomegaly']+1, dtype=torch.long) # Note Labels are -1=uncertain, 0=negative, 1=positive, NA=not reported -> Map to [0, 2], NA=3 + return {'uid':uid, 'source': self.transform(img), 'target':target} + + + @classmethod + def run_item_crawler(cls, path_root, extension, **kwargs): + """Overwrite to speed up as paths are determined by .csv file anyway""" + return [] + +class CheXpert_2_Dataset(SimpleDataset2D): + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + labels = pd.read_csv(self.path_root/'labels/cheXPert_label.csv', index_col=['Path', 'Image Index']) # Note: 1 and -1 (uncertain) cases count as positives (1), 0 and NA count as negatives (0) + labels = labels.loc[labels['fold']=='train'].copy() + labels = labels.drop(labels='fold', axis=1) + + labels2 = pd.read_csv(self.path_root/'labels/train.csv', index_col='Path') + labels2 = labels2.loc[labels2['Frontal/Lateral'] == 'Frontal'].copy() + labels2 = labels2[['Cardiomegaly',]].copy() + labels2[ (labels2 <0) | labels2.isna()] = 2 # 0 = Negative, 1 = Positive, 2 = Uncertain + labels = labels.join(labels2['Cardiomegaly'], on=["Path",], rsuffix='_true') + # labels = labels[labels['Cardiomegaly_true']!=2] + + self.labels = labels + + def __len__(self): + return len(self.labels) + + def __getitem__(self, index): + path_index, image_index = self.labels.index[index] + path_item = self.path_root/'data'/f'{image_index:06}.png' + img = 
self.load_item(path_item) + uid = image_index + target = int(self.labels.loc[(path_index, image_index), 'Cardiomegaly']) + # return {'uid':uid, 'source': self.transform(img), 'target':target} + return {'source': self.transform(img), 'target':target} + + @classmethod + def run_item_crawler(cls, path_root, extension, **kwargs): + """Overwrite to speed up as paths are determined by .csv file anyway""" + return [] + + def get_weights(self): + n_samples = len(self) + weight_per_class = 1/self.labels['Cardiomegaly'].value_counts(normalize=True) + # weight_per_class = {2.0: 1.2, 1.0: 8.2, 0.0: 24.3} + weights = [0] * n_samples + for index in range(n_samples): + target = self.labels.loc[self.labels.index[index], 'Cardiomegaly'] + weights[index] = weight_per_class[target] + return weights \ No newline at end of file diff --git a/medical_diffusion/data/datasets/dataset_simple_3d.py b/medical_diffusion/data/datasets/dataset_simple_3d.py new file mode 100755 index 0000000000000000000000000000000000000000..2fda25a47cf8d3e85fe13c90e9afc206b3ed7a3a --- /dev/null +++ b/medical_diffusion/data/datasets/dataset_simple_3d.py @@ -0,0 +1,58 @@ + +import torch.utils.data as data +from pathlib import Path +from torchvision import transforms as T + + +import torchio as tio + +from medical_diffusion.data.augmentation.augmentations_3d import ImageToTensor + + +class SimpleDataset3D(data.Dataset): + def __init__( + self, + path_root, + item_pointers =[], + crawler_ext = ['nii'], # other options are ['nii.gz'], + transform = None, + image_resize = None, + flip = False, + image_crop = None, + use_znorm=True, # Use z-Norm for MRI as scale is arbitrary, otherwise scale intensity to [-1, 1] + ): + super().__init__() + self.path_root = path_root + self.crawler_ext = crawler_ext + + if transform is None: + self.transform = T.Compose([ + tio.Resize(image_resize) if image_resize is not None else tio.Lambda(lambda x: x), + tio.RandomFlip((0,1,2)) if flip else tio.Lambda(lambda x: x), + tio.CropOrPad(image_crop) if image_crop is not None else tio.Lambda(lambda x: x), + tio.ZNormalization() if use_znorm else tio.RescaleIntensity((-1,1)), + ImageToTensor() # [C, W, H, D] -> [C, D, H, W] + ]) + else: + self.transform = transform + + if len(item_pointers): + self.item_pointers = item_pointers + else: + self.item_pointers = self.run_item_crawler(self.path_root, self.crawler_ext) + + def __len__(self): + return len(self.item_pointers) + + def __getitem__(self, index): + rel_path_item = self.item_pointers[index] + path_item = self.path_root/rel_path_item + img = self.load_item(path_item) + return {'uid':rel_path_item.stem, 'source': self.transform(img)} + + def load_item(self, path_item): + return tio.ScalarImage(path_item) # Consider to use this or tio.ScalarLabel over SimpleITK (sitk.ReadImage(str(path_item))) + + @classmethod + def run_item_crawler(cls, path_root, extension, **kwargs): + return [path.relative_to(path_root) for path in Path(path_root).rglob(f'*.{extension}')] \ No newline at end of file diff --git a/medical_diffusion/external/diffusers/attention.py b/medical_diffusion/external/diffusers/attention.py new file mode 100755 index 0000000000000000000000000000000000000000..25e1ea28dcf0226defc89fc6c92b5fc3faeac462 --- /dev/null +++ b/medical_diffusion/external/diffusers/attention.py @@ -0,0 +1,347 @@ +import math +from typing import Optional + +import torch +import torch.nn.functional as F +from torch import nn + + +class AttentionBlock(nn.Module): + """ + An attention block that allows spatial positions to attend to each 
other. Originally ported from here, but adapted + to the N-d case. + https://github.com/hojonathanho/diffusion/blob/1e0dceb3b3495bbe19116a5e1b3596cd0706c543/diffusion_tf/models/unet.py#L66. + Uses three q, k, v linear layers to compute attention. + + Parameters: + channels (:obj:`int`): The number of channels in the input and output. + num_head_channels (:obj:`int`, *optional*): + The number of channels in each head. If None, then `num_heads` = 1. + num_groups (:obj:`int`, *optional*, defaults to 32): The number of groups to use for group norm. + rescale_output_factor (:obj:`float`, *optional*, defaults to 1.0): The factor to rescale the output by. + eps (:obj:`float`, *optional*, defaults to 1e-5): The epsilon value to use for group norm. + """ + + def __init__( + self, + channels: int, + num_head_channels: Optional[int] = None, + num_groups: int = 32, + rescale_output_factor: float = 1.0, + eps: float = 1e-5, + ): + super().__init__() + self.channels = channels + + self.num_heads = channels // num_head_channels if num_head_channels is not None else 1 + self.num_head_size = num_head_channels + self.group_norm = nn.GroupNorm(num_channels=channels, num_groups=num_groups, eps=eps, affine=True) + + # define q,k,v as linear layers + self.query = nn.Linear(channels, channels) + self.key = nn.Linear(channels, channels) + self.value = nn.Linear(channels, channels) + + self.rescale_output_factor = rescale_output_factor + self.proj_attn = nn.Linear(channels, channels, 1) + + def transpose_for_scores(self, projection: torch.Tensor) -> torch.Tensor: + new_projection_shape = projection.size()[:-1] + (self.num_heads, -1) + # move heads to 2nd position (B, T, H * D) -> (B, T, H, D) -> (B, H, T, D) + new_projection = projection.view(new_projection_shape).permute(0, 2, 1, 3) + return new_projection + + def forward(self, hidden_states): + residual = hidden_states + batch, channel, height, width = hidden_states.shape + + # norm + hidden_states = self.group_norm(hidden_states) + + hidden_states = hidden_states.view(batch, channel, height * width).transpose(1, 2) + + # proj to q, k, v + query_proj = self.query(hidden_states) + key_proj = self.key(hidden_states) + value_proj = self.value(hidden_states) + + # transpose + query_states = self.transpose_for_scores(query_proj) + key_states = self.transpose_for_scores(key_proj) + value_states = self.transpose_for_scores(value_proj) + + # get scores + scale = 1 / math.sqrt(math.sqrt(self.channels / self.num_heads)) + + attention_scores = torch.matmul(query_states * scale, key_states.transpose(-1, -2) * scale) + attention_probs = torch.softmax(attention_scores.float(), dim=-1).type(attention_scores.dtype) + + # compute attention output + hidden_states = torch.matmul(attention_probs, value_states) + + hidden_states = hidden_states.permute(0, 2, 1, 3).contiguous() + new_hidden_states_shape = hidden_states.size()[:-2] + (self.channels,) + hidden_states = hidden_states.view(new_hidden_states_shape) + + # compute next hidden_states + hidden_states = self.proj_attn(hidden_states) + hidden_states = hidden_states.transpose(-1, -2).reshape(batch, channel, height, width) + + # res connect and rescale + hidden_states = (hidden_states + residual) / self.rescale_output_factor + return hidden_states + + +class SpatialTransformer(nn.Module): + """ + Transformer block for image-like data. First, project the input (aka embedding) and reshape to b, t, d. Then apply + standard transformer action. Finally, reshape to image. 
+ + Parameters: + in_channels (:obj:`int`): The number of channels in the input and output. + n_heads (:obj:`int`): The number of heads to use for multi-head attention. + d_head (:obj:`int`): The number of channels in each head. + depth (:obj:`int`, *optional*, defaults to 1): The number of layers of Transformer blocks to use. + dropout (:obj:`float`, *optional*, defaults to 0.1): The dropout probability to use. + context_dim (:obj:`int`, *optional*): The number of context dimensions to use. + """ + + def __init__( + self, + in_channels: int, + n_heads: int, + d_head: int, + depth: int = 1, + dropout: float = 0.0, + num_groups: int = 32, + context_dim: Optional[int] = None, + ): + super().__init__() + self.n_heads = n_heads + self.d_head = d_head + self.in_channels = in_channels + inner_dim = n_heads * d_head + self.norm = torch.nn.GroupNorm(num_groups=num_groups, num_channels=in_channels, eps=1e-6, affine=True) + + self.proj_in = nn.Conv2d(in_channels, inner_dim, kernel_size=1, stride=1, padding=0) + + self.transformer_blocks = nn.ModuleList( + [ + BasicTransformerBlock(inner_dim, n_heads, d_head, dropout=dropout, context_dim=context_dim) + for d in range(depth) + ] + ) + + self.proj_out = nn.Conv2d(inner_dim, in_channels, kernel_size=1, stride=1, padding=0) + + def _set_attention_slice(self, slice_size): + for block in self.transformer_blocks: + block._set_attention_slice(slice_size) + + def forward(self, hidden_states, context=None): + # note: if no context is given, cross-attention defaults to self-attention + batch, channel, height, weight = hidden_states.shape + residual = hidden_states + hidden_states = self.norm(hidden_states) + hidden_states = self.proj_in(hidden_states) + hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(batch, height * weight, channel) + for block in self.transformer_blocks: + hidden_states = block(hidden_states, context=context) + hidden_states = hidden_states.reshape(batch, height, weight, channel).permute(0, 3, 1, 2) + hidden_states = self.proj_out(hidden_states) + return hidden_states + residual + + +class BasicTransformerBlock(nn.Module): + r""" + A basic Transformer block. + + Parameters: + dim (:obj:`int`): The number of channels in the input and output. + n_heads (:obj:`int`): The number of heads to use for multi-head attention. + d_head (:obj:`int`): The number of channels in each head. + dropout (:obj:`float`, *optional*, defaults to 0.0): The dropout probability to use. + context_dim (:obj:`int`, *optional*): The size of the context vector for cross attention. + gated_ff (:obj:`bool`, *optional*, defaults to :obj:`False`): Whether to use a gated feed-forward network. + checkpoint (:obj:`bool`, *optional*, defaults to :obj:`False`): Whether to use checkpointing. 
+ """ + + def __init__( + self, + dim: int, + n_heads: int, + d_head: int, + dropout=0.0, + context_dim: Optional[int] = None, + gated_ff: bool = True, + checkpoint: bool = True, + ): + super().__init__() + self.attn1 = CrossAttention( + query_dim=dim, heads=n_heads, dim_head=d_head, dropout=dropout + ) # is a self-attention + self.ff = FeedForward(dim, dropout=dropout, glu=gated_ff) + self.attn2 = CrossAttention( + query_dim=dim, context_dim=context_dim, heads=n_heads, dim_head=d_head, dropout=dropout + ) # is self-attn if context is none + self.norm1 = nn.LayerNorm(dim) + self.norm2 = nn.LayerNorm(dim) + self.norm3 = nn.LayerNorm(dim) + self.checkpoint = checkpoint + + def _set_attention_slice(self, slice_size): + self.attn1._slice_size = slice_size + self.attn2._slice_size = slice_size + + def forward(self, hidden_states, context=None): + hidden_states = hidden_states.contiguous() if hidden_states.device.type == "mps" else hidden_states + hidden_states = self.attn1(self.norm1(hidden_states)) + hidden_states + hidden_states = self.attn2(self.norm2(hidden_states), context=context) + hidden_states + hidden_states = self.ff(self.norm3(hidden_states)) + hidden_states + return hidden_states + + +class CrossAttention(nn.Module): + r""" + A cross attention layer. + + Parameters: + query_dim (:obj:`int`): The number of channels in the query. + context_dim (:obj:`int`, *optional*): + The number of channels in the context. If not given, defaults to `query_dim`. + heads (:obj:`int`, *optional*, defaults to 8): The number of heads to use for multi-head attention. + dim_head (:obj:`int`, *optional*, defaults to 64): The number of channels in each head. + dropout (:obj:`float`, *optional*, defaults to 0.0): The dropout probability to use. + """ + + def __init__( + self, query_dim: int, context_dim: Optional[int] = None, heads: int = 8, dim_head: int = 64, dropout: int = 0.0 + ): + super().__init__() + inner_dim = dim_head * heads + context_dim = context_dim if context_dim is not None else query_dim + + self.scale = dim_head**-0.5 + self.heads = heads + # for slice_size > 0 the attention score computation + # is split across the batch axis to save memory + # You can set slice_size with `set_attention_slice` + self._slice_size = None + + self.to_q = nn.Linear(query_dim, inner_dim, bias=False) + self.to_k = nn.Linear(context_dim, inner_dim, bias=False) + self.to_v = nn.Linear(context_dim, inner_dim, bias=False) + + self.to_out = nn.Sequential(nn.Linear(inner_dim, query_dim), nn.Dropout(dropout)) + + def reshape_heads_to_batch_dim(self, tensor): + batch_size, seq_len, dim = tensor.shape + head_size = self.heads + tensor = tensor.reshape(batch_size, seq_len, head_size, dim // head_size) + tensor = tensor.permute(0, 2, 1, 3).reshape(batch_size * head_size, seq_len, dim // head_size) + return tensor + + def reshape_batch_dim_to_heads(self, tensor): + batch_size, seq_len, dim = tensor.shape + head_size = self.heads + tensor = tensor.reshape(batch_size // head_size, head_size, seq_len, dim) + tensor = tensor.permute(0, 2, 1, 3).reshape(batch_size // head_size, seq_len, dim * head_size) + return tensor + + def forward(self, hidden_states, context=None, mask=None): + batch_size, sequence_length, _ = hidden_states.shape + + query = self.to_q(hidden_states) + context = context if context is not None else hidden_states + key = self.to_k(context) + value = self.to_v(context) + + dim = query.shape[-1] + + query = self.reshape_heads_to_batch_dim(query) + key = self.reshape_heads_to_batch_dim(key) + value = 
self.reshape_heads_to_batch_dim(value) + + # TODO(PVP) - mask is currently never used. Remember to re-implement when used + + # attention, what we cannot get enough of + + if self._slice_size is None or query.shape[0] // self._slice_size == 1: + hidden_states = self._attention(query, key, value) + else: + hidden_states = self._sliced_attention(query, key, value, sequence_length, dim) + + return self.to_out(hidden_states) + + def _attention(self, query, key, value): + attention_scores = torch.matmul(query, key.transpose(-1, -2)) * self.scale + attention_probs = attention_scores.softmax(dim=-1) + # compute attention output + hidden_states = torch.matmul(attention_probs, value) + # reshape hidden_states + hidden_states = self.reshape_batch_dim_to_heads(hidden_states) + return hidden_states + + def _sliced_attention(self, query, key, value, sequence_length, dim): + batch_size_attention = query.shape[0] + hidden_states = torch.zeros( + (batch_size_attention, sequence_length, dim // self.heads), device=query.device, dtype=query.dtype + ) + slice_size = self._slice_size if self._slice_size is not None else hidden_states.shape[0] + for i in range(hidden_states.shape[0] // slice_size): + start_idx = i * slice_size + end_idx = (i + 1) * slice_size + attn_slice = torch.matmul(query[start_idx:end_idx], key[start_idx:end_idx].transpose(1, 2)) * self.scale + attn_slice = attn_slice.softmax(dim=-1) + attn_slice = torch.matmul(attn_slice, value[start_idx:end_idx]) + + hidden_states[start_idx:end_idx] = attn_slice + + # reshape hidden_states + hidden_states = self.reshape_batch_dim_to_heads(hidden_states) + return hidden_states + + +class FeedForward(nn.Module): + r""" + A feed-forward layer. + + Parameters: + dim (:obj:`int`): The number of channels in the input. + dim_out (:obj:`int`, *optional*): The number of channels in the output. If not given, defaults to `dim`. + mult (:obj:`int`, *optional*, defaults to 4): The multiplier to use for the hidden dimension. + glu (:obj:`bool`, *optional*, defaults to :obj:`False`): Whether to use GLU activation. + dropout (:obj:`float`, *optional*, defaults to 0.0): The dropout probability to use. + """ + + def __init__( + self, dim: int, dim_out: Optional[int] = None, mult: int = 4, glu: bool = False, dropout: float = 0.0 + ): + super().__init__() + inner_dim = int(dim * mult) + dim_out = dim_out if dim_out is not None else dim + project_in = GEGLU(dim, inner_dim) + + self.net = nn.Sequential(project_in, nn.Dropout(dropout), nn.Linear(inner_dim, dim_out)) + + def forward(self, hidden_states): + return self.net(hidden_states) + + +# feedforward +class GEGLU(nn.Module): + r""" + A variant of the gated linear unit activation function from https://arxiv.org/abs/2002.05202. + + Parameters: + dim_in (:obj:`int`): The number of channels in the input. + dim_out (:obj:`int`): The number of channels in the output. 
+ """ + + def __init__(self, dim_in: int, dim_out: int): + super().__init__() + self.proj = nn.Linear(dim_in, dim_out * 2) + + def forward(self, hidden_states): + hidden_states, gate = self.proj(hidden_states).chunk(2, dim=-1) + return hidden_states * F.gelu(gate) diff --git a/medical_diffusion/external/diffusers/embeddings.py b/medical_diffusion/external/diffusers/embeddings.py new file mode 100755 index 0000000000000000000000000000000000000000..d721bc45a87d7bea0892a7767fafd896f220fb3c --- /dev/null +++ b/medical_diffusion/external/diffusers/embeddings.py @@ -0,0 +1,89 @@ +import math +from pydoc import describe + +import numpy as np +import torch +from torch import nn + + +def get_timestep_embedding( + timesteps: torch.Tensor, + embedding_dim: int, + flip_sin_to_cos: bool = False, + downscale_freq_shift: float = 1, + scale: float = 1, + max_period: int = 10000, +): + """ + This matches the implementation in Denoising Diffusion Probabilistic Models: Create sinusoidal timestep embeddings. + :param timesteps: a 1-D Tensor of N indices, one per batch element. + These may be fractional. + :param embedding_dim: the dimension of the output. :param max_period: controls the minimum frequency of the + embeddings. :return: an [N x dim] Tensor of positional embeddings. + """ + assert len(timesteps.shape) == 1, "Timesteps should be a 1d-array" + + half_dim = embedding_dim // 2 + exponent = -math.log(max_period) * torch.arange( + start=0, end=half_dim, dtype=torch.float32, device=timesteps.device + ) + exponent = exponent / (half_dim - downscale_freq_shift) + + emb = torch.exp(exponent) + emb = timesteps[:, None].float() * emb[None, :] + + # scale embeddings + emb = scale * emb + + # concat sine and cosine embeddings + emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=-1) + + # flip sine and cosine embeddings + if flip_sin_to_cos: + emb = torch.cat([emb[:, half_dim:], emb[:, :half_dim]], dim=-1) + + # zero pad + if embedding_dim % 2 == 1: + emb = torch.nn.functional.pad(emb, (0, 1, 0, 0)) + return emb + +class Timesteps(nn.Module): + def __init__(self, num_channels: int, flip_sin_to_cos: bool, downscale_freq_shift: float): + super().__init__() + self.num_channels = num_channels + self.flip_sin_to_cos = flip_sin_to_cos + self.downscale_freq_shift = downscale_freq_shift + + def forward(self, timesteps): + t_emb = get_timestep_embedding( + timesteps, + self.num_channels, + flip_sin_to_cos=self.flip_sin_to_cos, + downscale_freq_shift=self.downscale_freq_shift, + ) + return t_emb + +class TimeEmbbeding(nn.Module): + def __init__(self, channel: int, time_embed_dim: int, act_fn: str = "silu"): + super().__init__() + + self.temb = Timesteps(channel, flip_sin_to_cos=True, downscale_freq_shift=0) + + self.linear_1 = nn.Linear(channel, time_embed_dim) + self.act = None + if act_fn == "silu": + self.act = nn.SiLU() + self.linear_2 = nn.Linear(time_embed_dim, time_embed_dim) + + def forward(self, sample): + sample = self.temb(sample) + sample = self.linear_1(sample) + + if self.act is not None: + sample = self.act(sample) + + sample = self.linear_2(sample) + return sample + + + diff --git a/medical_diffusion/external/diffusers/resnet.py b/medical_diffusion/external/diffusers/resnet.py new file mode 100755 index 0000000000000000000000000000000000000000..97f3c02a8ccf434e9f7788ba503d64e0395146b0 --- /dev/null +++ b/medical_diffusion/external/diffusers/resnet.py @@ -0,0 +1,479 @@ +from functools import partial + +import torch +import torch.nn as nn +import torch.nn.functional as F + + +class Upsample2D(nn.Module): 
+ """ + An upsampling layer with an optional convolution. + + :param channels: channels in the inputs and outputs. :param use_conv: a bool determining if a convolution is + applied. :param dims: determines if the signal is 1D, 2D, or 3D. If 3D, then + upsampling occurs in the inner-two dimensions. + """ + + def __init__(self, channels, use_conv=False, use_conv_transpose=False, out_channels=None, name="conv"): + super().__init__() + self.channels = channels + self.out_channels = out_channels or channels + self.use_conv = use_conv + self.use_conv_transpose = use_conv_transpose + self.name = name + + conv = None + if use_conv_transpose: + conv = nn.ConvTranspose2d(channels, self.out_channels, 4, 2, 1) + elif use_conv: + conv = nn.Conv2d(self.channels, self.out_channels, 3, padding=1) + + # TODO(Suraj, Patrick) - clean up after weight dicts are correctly renamed + if name == "conv": + self.conv = conv + else: + self.Conv2d_0 = conv + + def forward(self, x): + assert x.shape[1] == self.channels + if self.use_conv_transpose: + return self.conv(x) + + x = F.interpolate(x, scale_factor=2.0, mode="nearest") + + # TODO(Suraj, Patrick) - clean up after weight dicts are correctly renamed + if self.use_conv: + if self.name == "conv": + x = self.conv(x) + else: + x = self.Conv2d_0(x) + + return x + + +class Downsample2D(nn.Module): + """ + A downsampling layer with an optional convolution. + + :param channels: channels in the inputs and outputs. :param use_conv: a bool determining if a convolution is + applied. :param dims: determines if the signal is 1D, 2D, or 3D. If 3D, then + downsampling occurs in the inner-two dimensions. + """ + + def __init__(self, channels, use_conv=False, out_channels=None, padding=1, name="conv"): + super().__init__() + self.channels = channels + self.out_channels = out_channels or channels + self.use_conv = use_conv + self.padding = padding + stride = 2 + self.name = name + + if use_conv: + conv = nn.Conv2d(self.channels, self.out_channels, 3, stride=stride, padding=padding) + else: + assert self.channels == self.out_channels + conv = nn.AvgPool2d(kernel_size=stride, stride=stride) + + # TODO(Suraj, Patrick) - clean up after weight dicts are correctly renamed + if name == "conv": + self.Conv2d_0 = conv + self.conv = conv + elif name == "Conv2d_0": + self.conv = conv + else: + self.conv = conv + + def forward(self, x): + assert x.shape[1] == self.channels + if self.use_conv and self.padding == 0: + pad = (0, 1, 0, 1) + x = F.pad(x, pad, mode="constant", value=0) + + assert x.shape[1] == self.channels + x = self.conv(x) + + return x + + +class FirUpsample2D(nn.Module): + def __init__(self, channels=None, out_channels=None, use_conv=False, fir_kernel=(1, 3, 3, 1)): + super().__init__() + out_channels = out_channels if out_channels else channels + if use_conv: + self.Conv2d_0 = nn.Conv2d(channels, out_channels, kernel_size=3, stride=1, padding=1) + self.use_conv = use_conv + self.fir_kernel = fir_kernel + self.out_channels = out_channels + + def _upsample_2d(self, x, weight=None, kernel=None, factor=2, gain=1): + """Fused `upsample_2d()` followed by `Conv2d()`. + + Args: + Padding is performed only once at the beginning, not between the operations. The fused op is considerably more + efficient than performing the same calculation using standard TensorFlow ops. It supports gradients of arbitrary: + order. + x: Input tensor of the shape `[N, C, H, W]` or `[N, H, W, + C]`. + weight: Weight tensor of the shape `[filterH, filterW, inChannels, + outChannels]`. 
Grouped convolution can be performed by `inChannels = x.shape[0] // numGroups`. + kernel: FIR filter of the shape `[firH, firW]` or `[firN]` + (separable). The default is `[1] * factor`, which corresponds to nearest-neighbor upsampling. + factor: Integer upsampling factor (default: 2). gain: Scaling factor for signal magnitude (default: 1.0). + + Returns: + Tensor of the shape `[N, C, H * factor, W * factor]` or `[N, H * factor, W * factor, C]`, and same datatype as + `x`. + """ + + assert isinstance(factor, int) and factor >= 1 + + # Setup filter kernel. + if kernel is None: + kernel = [1] * factor + + # setup kernel + kernel = torch.tensor(kernel, dtype=torch.float32) + if kernel.ndim == 1: + kernel = torch.outer(kernel, kernel) + kernel /= torch.sum(kernel) + + kernel = kernel * (gain * (factor**2)) + + if self.use_conv: + convH = weight.shape[2] + convW = weight.shape[3] + inC = weight.shape[1] + + p = (kernel.shape[0] - factor) - (convW - 1) + + stride = (factor, factor) + # Determine data dimensions. + output_shape = ((x.shape[2] - 1) * factor + convH, (x.shape[3] - 1) * factor + convW) + output_padding = ( + output_shape[0] - (x.shape[2] - 1) * stride[0] - convH, + output_shape[1] - (x.shape[3] - 1) * stride[1] - convW, + ) + assert output_padding[0] >= 0 and output_padding[1] >= 0 + inC = weight.shape[1] + num_groups = x.shape[1] // inC + + # Transpose weights. + weight = torch.reshape(weight, (num_groups, -1, inC, convH, convW)) + weight = torch.flip(weight, dims=[3, 4]).permute(0, 2, 1, 3, 4) + weight = torch.reshape(weight, (num_groups * inC, -1, convH, convW)) + + x = F.conv_transpose2d(x, weight, stride=stride, output_padding=output_padding, padding=0) + + x = upfirdn2d_native(x, torch.tensor(kernel, device=x.device), pad=((p + 1) // 2 + factor - 1, p // 2 + 1)) + else: + p = kernel.shape[0] - factor + x = upfirdn2d_native( + x, torch.tensor(kernel, device=x.device), up=factor, pad=((p + 1) // 2 + factor - 1, p // 2) + ) + + return x + + def forward(self, x): + if self.use_conv: + height = self._upsample_2d(x, self.Conv2d_0.weight, kernel=self.fir_kernel) + height = height + self.Conv2d_0.bias.reshape(1, -1, 1, 1) + else: + height = self._upsample_2d(x, kernel=self.fir_kernel, factor=2) + + return height + + +class FirDownsample2D(nn.Module): + def __init__(self, channels=None, out_channels=None, use_conv=False, fir_kernel=(1, 3, 3, 1)): + super().__init__() + out_channels = out_channels if out_channels else channels + if use_conv: + self.Conv2d_0 = nn.Conv2d(channels, out_channels, kernel_size=3, stride=1, padding=1) + self.fir_kernel = fir_kernel + self.use_conv = use_conv + self.out_channels = out_channels + + def _downsample_2d(self, x, weight=None, kernel=None, factor=2, gain=1): + """Fused `Conv2d()` followed by `downsample_2d()`. + + Args: + Padding is performed only once at the beginning, not between the operations. The fused op is considerably more + efficient than performing the same calculation using standard TensorFlow ops. It supports gradients of arbitrary: + order. + x: Input tensor of the shape `[N, C, H, W]` or `[N, H, W, C]`. w: Weight tensor of the shape `[filterH, + filterW, inChannels, outChannels]`. Grouped convolution can be performed by `inChannels = x.shape[0] // + numGroups`. k: FIR filter of the shape `[firH, firW]` or `[firN]` (separable). The default is `[1] * + factor`, which corresponds to average pooling. factor: Integer downsampling factor (default: 2). gain: + Scaling factor for signal magnitude (default: 1.0). 
+ + Returns: + Tensor of the shape `[N, C, H // factor, W // factor]` or `[N, H // factor, W // factor, C]`, and same + datatype as `x`. + """ + + assert isinstance(factor, int) and factor >= 1 + if kernel is None: + kernel = [1] * factor + + # setup kernel + kernel = torch.tensor(kernel, dtype=torch.float32) + if kernel.ndim == 1: + kernel = torch.outer(kernel, kernel) + kernel /= torch.sum(kernel) + + kernel = kernel * gain + + if self.use_conv: + _, _, convH, convW = weight.shape + p = (kernel.shape[0] - factor) + (convW - 1) + s = [factor, factor] + x = upfirdn2d_native(x, torch.tensor(kernel, device=x.device), pad=((p + 1) // 2, p // 2)) + x = F.conv2d(x, weight, stride=s, padding=0) + else: + p = kernel.shape[0] - factor + x = upfirdn2d_native(x, torch.tensor(kernel, device=x.device), down=factor, pad=((p + 1) // 2, p // 2)) + + return x + + def forward(self, x): + if self.use_conv: + x = self._downsample_2d(x, weight=self.Conv2d_0.weight, kernel=self.fir_kernel) + x = x + self.Conv2d_0.bias.reshape(1, -1, 1, 1) + else: + x = self._downsample_2d(x, kernel=self.fir_kernel, factor=2) + + return x + + +class ResnetBlock2D(nn.Module): + def __init__( + self, + *, + in_channels, + out_channels=None, + conv_shortcut=False, + dropout=0.0, + temb_channels=512, + groups=32, + groups_out=None, + pre_norm=True, + eps=1e-6, + non_linearity="swish", + time_embedding_norm="default", + kernel=None, + output_scale_factor=1.0, + use_in_shortcut=None, + up=False, + down=False, + ): + super().__init__() + self.pre_norm = pre_norm + self.pre_norm = True + self.in_channels = in_channels + out_channels = in_channels if out_channels is None else out_channels + self.out_channels = out_channels + self.use_conv_shortcut = conv_shortcut + self.time_embedding_norm = time_embedding_norm + self.up = up + self.down = down + self.output_scale_factor = output_scale_factor + + if groups_out is None: + groups_out = groups + + self.norm1 = torch.nn.GroupNorm(num_groups=groups, num_channels=in_channels, eps=eps, affine=True) + + self.conv1 = torch.nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1) + + if temb_channels is not None: + self.time_emb_proj = torch.nn.Linear(temb_channels, out_channels) + else: + self.time_emb_proj = None + + self.norm2 = torch.nn.GroupNorm(num_groups=groups_out, num_channels=out_channels, eps=eps, affine=True) + self.dropout = torch.nn.Dropout(dropout) + self.conv2 = torch.nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1) + + if non_linearity == "swish": + self.nonlinearity = lambda x: F.silu(x) + elif non_linearity == "mish": + self.nonlinearity = Mish() + elif non_linearity == "silu": + self.nonlinearity = nn.SiLU() + + self.upsample = self.downsample = None + if self.up: + if kernel == "fir": + fir_kernel = (1, 3, 3, 1) + self.upsample = lambda x: upsample_2d(x, kernel=fir_kernel) + elif kernel == "sde_vp": + self.upsample = partial(F.interpolate, scale_factor=2.0, mode="nearest") + else: + self.upsample = Upsample2D(in_channels, use_conv=False) + elif self.down: + if kernel == "fir": + fir_kernel = (1, 3, 3, 1) + self.downsample = lambda x: downsample_2d(x, kernel=fir_kernel) + elif kernel == "sde_vp": + self.downsample = partial(F.avg_pool2d, kernel_size=2, stride=2) + else: + self.downsample = Downsample2D(in_channels, use_conv=False, padding=1, name="op") + + self.use_in_shortcut = self.in_channels != self.out_channels if use_in_shortcut is None else use_in_shortcut + + self.conv_shortcut = None + if self.use_in_shortcut: + 
self.conv_shortcut = torch.nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, padding=0) + + def forward(self, x, temb): + hidden_states = x + + # make sure hidden states is in float32 + # when running in half-precision + hidden_states = self.norm1(hidden_states).type(hidden_states.dtype) + hidden_states = self.nonlinearity(hidden_states) + + if self.upsample is not None: + x = self.upsample(x) + hidden_states = self.upsample(hidden_states) + elif self.downsample is not None: + x = self.downsample(x) + hidden_states = self.downsample(hidden_states) + + hidden_states = self.conv1(hidden_states) + + if temb is not None: + temb = self.time_emb_proj(self.nonlinearity(temb))[:, :, None, None] + hidden_states = hidden_states + temb + + # make sure hidden states is in float32 + # when running in half-precision + hidden_states = self.norm2(hidden_states).type(hidden_states.dtype) + hidden_states = self.nonlinearity(hidden_states) + + hidden_states = self.dropout(hidden_states) + hidden_states = self.conv2(hidden_states) + + if self.conv_shortcut is not None: + x = self.conv_shortcut(x) + + out = (x + hidden_states) / self.output_scale_factor + + return out + + +class Mish(torch.nn.Module): + def forward(self, x): + return x * torch.tanh(torch.nn.functional.softplus(x)) + + +def upsample_2d(x, kernel=None, factor=2, gain=1): + r"""Upsample2D a batch of 2D images with the given filter. + + Args: + Accepts a batch of 2D images of the shape `[N, C, H, W]` or `[N, H, W, C]` and upsamples each image with the given + filter. The filter is normalized so that if the input pixels are constant, they will be scaled by the specified + `gain`. Pixels outside the image are assumed to be zero, and the filter is padded with zeros so that its shape is a: + multiple of the upsampling factor. + x: Input tensor of the shape `[N, C, H, W]` or `[N, H, W, + C]`. + k: FIR filter of the shape `[firH, firW]` or `[firN]` + (separable). The default is `[1] * factor`, which corresponds to nearest-neighbor upsampling. + factor: Integer upsampling factor (default: 2). gain: Scaling factor for signal magnitude (default: 1.0). + + Returns: + Tensor of the shape `[N, C, H * factor, W * factor]` + """ + assert isinstance(factor, int) and factor >= 1 + if kernel is None: + kernel = [1] * factor + + kernel = torch.tensor(kernel, dtype=torch.float32) + if kernel.ndim == 1: + kernel = torch.outer(kernel, kernel) + kernel /= torch.sum(kernel) + + kernel = kernel * (gain * (factor**2)) + p = kernel.shape[0] - factor + return upfirdn2d_native(x, kernel.to(device=x.device), up=factor, pad=((p + 1) // 2 + factor - 1, p // 2)) + + +def downsample_2d(x, kernel=None, factor=2, gain=1): + r"""Downsample2D a batch of 2D images with the given filter. + + Args: + Accepts a batch of 2D images of the shape `[N, C, H, W]` or `[N, H, W, C]` and downsamples each image with the + given filter. The filter is normalized so that if the input pixels are constant, they will be scaled by the + specified `gain`. Pixels outside the image are assumed to be zero, and the filter is padded with zeros so that its + shape is a multiple of the downsampling factor. + x: Input tensor of the shape `[N, C, H, W]` or `[N, H, W, + C]`. + kernel: FIR filter of the shape `[firH, firW]` or `[firN]` + (separable). The default is `[1] * factor`, which corresponds to average pooling. + factor: Integer downsampling factor (default: 2). gain: Scaling factor for signal magnitude (default: 1.0). 
+ + Returns: + Tensor of the shape `[N, C, H // factor, W // factor]` + """ + + assert isinstance(factor, int) and factor >= 1 + if kernel is None: + kernel = [1] * factor + + kernel = torch.tensor(kernel, dtype=torch.float32) + if kernel.ndim == 1: + kernel = torch.outer(kernel, kernel) + kernel /= torch.sum(kernel) + + kernel = kernel * gain + p = kernel.shape[0] - factor + return upfirdn2d_native(x, kernel.to(device=x.device), down=factor, pad=((p + 1) // 2, p // 2)) + + +def upfirdn2d_native(input, kernel, up=1, down=1, pad=(0, 0)): + up_x = up_y = up + down_x = down_y = down + pad_x0 = pad_y0 = pad[0] + pad_x1 = pad_y1 = pad[1] + + _, channel, in_h, in_w = input.shape + input = input.reshape(-1, in_h, in_w, 1) + + _, in_h, in_w, minor = input.shape + kernel_h, kernel_w = kernel.shape + + out = input.view(-1, in_h, 1, in_w, 1, minor) + + # Temporary workaround for mps specific issue: https://github.com/pytorch/pytorch/issues/84535 + if input.device.type == "mps": + out = out.to("cpu") + out = F.pad(out, [0, 0, 0, up_x - 1, 0, 0, 0, up_y - 1]) + out = out.view(-1, in_h * up_y, in_w * up_x, minor) + + out = F.pad(out, [0, 0, max(pad_x0, 0), max(pad_x1, 0), max(pad_y0, 0), max(pad_y1, 0)]) + out = out.to(input.device) # Move back to mps if necessary + out = out[ + :, + max(-pad_y0, 0) : out.shape[1] - max(-pad_y1, 0), + max(-pad_x0, 0) : out.shape[2] - max(-pad_x1, 0), + :, + ] + + out = out.permute(0, 3, 1, 2) + out = out.reshape([-1, 1, in_h * up_y + pad_y0 + pad_y1, in_w * up_x + pad_x0 + pad_x1]) + w = torch.flip(kernel, [0, 1]).view(1, 1, kernel_h, kernel_w) + out = F.conv2d(out, w) + out = out.reshape( + -1, + minor, + in_h * up_y + pad_y0 + pad_y1 - kernel_h + 1, + in_w * up_x + pad_x0 + pad_x1 - kernel_w + 1, + ) + out = out.permute(0, 2, 3, 1) + out = out[:, ::down_y, ::down_x, :] + + out_h = (in_h * up_y + pad_y0 + pad_y1 - kernel_h) // down_y + 1 + out_w = (in_w * up_x + pad_x0 + pad_x1 - kernel_w) // down_x + 1 + + return out.view(-1, channel, out_h, out_w) diff --git a/medical_diffusion/external/diffusers/taming_discriminator.py b/medical_diffusion/external/diffusers/taming_discriminator.py new file mode 100755 index 0000000000000000000000000000000000000000..c0e2ba389a330c3e067d758288c4d78a151fdae3 --- /dev/null +++ b/medical_diffusion/external/diffusers/taming_discriminator.py @@ -0,0 +1,57 @@ +import functools +import torch.nn as nn + + + + +class NLayerDiscriminator(nn.Module): + """Defines a PatchGAN discriminator as in Pix2Pix + --> see https://github.com/junyanz/pytorch-CycleGAN-and-pix2pix/blob/master/models/networks.py + """ + def __init__(self, input_nc=3, ndf=64, n_layers=3, use_actnorm=False): + """Construct a PatchGAN discriminator + Parameters: + input_nc (int) -- the number of channels in input images + ndf (int) -- the number of filters in the last conv layer + n_layers (int) -- the number of conv layers in the discriminator + norm_layer -- normalization layer + """ + super(NLayerDiscriminator, self).__init__() + if not use_actnorm: + norm_layer = nn.BatchNorm2d + else: + raise NotImplementedError + if type(norm_layer) == functools.partial: # no need to use bias as BatchNorm2d has affine parameters + use_bias = norm_layer.func != nn.BatchNorm2d + else: + use_bias = norm_layer != nn.BatchNorm2d + + kw = 4 + padw = 1 + sequence = [nn.Conv2d(input_nc, ndf, kernel_size=kw, stride=2, padding=padw), nn.LeakyReLU(0.2, True)] + nf_mult = 1 + nf_mult_prev = 1 + for n in range(1, n_layers): # gradually increase the number of filters + nf_mult_prev = nf_mult + nf_mult = 
min(2 ** n, 8) + sequence += [ + nn.Conv2d(ndf * nf_mult_prev, ndf * nf_mult, kernel_size=kw, stride=2, padding=padw, bias=use_bias), + norm_layer(ndf * nf_mult), + nn.LeakyReLU(0.2, True) + ] + + nf_mult_prev = nf_mult + nf_mult = min(2 ** n_layers, 8) + sequence += [ + nn.Conv2d(ndf * nf_mult_prev, ndf * nf_mult, kernel_size=kw, stride=1, padding=padw, bias=use_bias), + norm_layer(ndf * nf_mult), + nn.LeakyReLU(0.2, True) + ] + + sequence += [ + nn.Conv2d(ndf * nf_mult, 1, kernel_size=kw, stride=1, padding=padw)] # output 1 channel prediction map + self.main = nn.Sequential(*sequence) + + def forward(self, input): + """Standard forward.""" + return self.main(input) \ No newline at end of file diff --git a/medical_diffusion/external/diffusers/unet.py b/medical_diffusion/external/diffusers/unet.py new file mode 100755 index 0000000000000000000000000000000000000000..122b50ac976ef25b6e45735ee966aa4c3cea26a9 --- /dev/null +++ b/medical_diffusion/external/diffusers/unet.py @@ -0,0 +1,257 @@ + + +from typing import Optional, Tuple, Union + +import torch +import torch.nn as nn +import torch.utils.checkpoint + + +from .embeddings import TimeEmbbeding + +from .unet_blocks import ( + CrossAttnDownBlock2D, + CrossAttnUpBlock2D, + DownBlock2D, + UNetMidBlock2DCrossAttn, + UpBlock2D, + get_down_block, + get_up_block, +) + +class TimestepEmbedding(nn.Module): + def __init__(self, channel, time_embed_dim, act_fn="silu"): + super().__init__() + + self.linear_1 = nn.Linear(channel, time_embed_dim) + self.act = None + if act_fn == "silu": + self.act = nn.SiLU() + self.linear_2 = nn.Linear(time_embed_dim, time_embed_dim) + + def forward(self, sample): + sample = self.linear_1(sample) + + if self.act is not None: + sample = self.act(sample) + + sample = self.linear_2(sample) + return sample + + +class UNet2DConditionModel(nn.Module): + r""" + UNet2DConditionModel is a conditional 2D UNet model that takes in a noisy sample, conditional state, and a timestep + and returns sample shaped output. + + + Parameters: + sample_size (`int`, *optional*): The size of the input sample. + in_channels (`int`, *optional*, defaults to 4): The number of channels in the input sample. + out_channels (`int`, *optional*, defaults to 4): The number of channels in the output. + center_input_sample (`bool`, *optional*, defaults to `False`): Whether to center the input sample. + flip_sin_to_cos (`bool`, *optional*, defaults to `False`): + Whether to flip the sin to cos in the time embedding. + freq_shift (`int`, *optional*, defaults to 0): The frequency shift to apply to the time embedding. + down_block_types (`Tuple[str]`, *optional*, defaults to `("CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "DownBlock2D")`): + The tuple of downsample blocks to use. + up_block_types (`Tuple[str]`, *optional*, defaults to `("UpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D",)`): + The tuple of upsample blocks to use. + block_out_channels (`Tuple[int]`, *optional*, defaults to `(320, 640, 1280, 1280)`): + The tuple of output channels for each block. + layers_per_block (`int`, *optional*, defaults to 2): The number of layers per block. + downsample_padding (`int`, *optional*, defaults to 1): The padding to use for the downsampling convolution. + mid_block_scale_factor (`float`, *optional*, defaults to 1.0): The scale factor to use for the mid block. + act_fn (`str`, *optional*, defaults to `"silu"`): The activation function to use. 
+ norm_num_groups (`int`, *optional*, defaults to 32): The number of groups to use for the normalization. + norm_eps (`float`, *optional*, defaults to 1e-5): The epsilon to use for the normalization. + cross_attention_dim (`int`, *optional*, defaults to 1280): The dimension of the cross attention features. + attention_head_dim (`int`, *optional*, defaults to 8): The dimension of the attention heads. + """ + + _supports_gradient_checkpointing = True + + + def __init__( + self, + sample_size: Optional[int] = None, + in_channels: int = 4, + out_channels: int = 4, + center_input_sample: bool = False, + flip_sin_to_cos: bool = True, + freq_shift: int = 0, + down_block_types: Tuple[str] = ( + "CrossAttnDownBlock2D", + "CrossAttnDownBlock2D", + "CrossAttnDownBlock2D", + "DownBlock2D", + ), + up_block_types: Tuple[str] = ("UpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D"), + block_out_channels: Tuple[int] = (320, 640, 1280, 1280), + layers_per_block: int = 2, + downsample_padding: int = 1, + mid_block_scale_factor: float = 1, + act_fn: str = "silu", + norm_num_groups: int = 32, + norm_eps: float = 1e-5, + cross_attention_dim: int = 768, + attention_head_dim: int = 8, + ): + super().__init__() + + self.sample_size = sample_size + time_embed_dim = block_out_channels[0] * 4 + + self.emb = nn.Embedding(2, cross_attention_dim) + + # input + self.conv_in = nn.Conv2d(in_channels, block_out_channels[0], kernel_size=3, padding=(1, 1)) + + # time + self.time_embedding = TimeEmbbeding(block_out_channels[0], time_embed_dim) + + self.down_blocks = nn.ModuleList([]) + self.mid_block = None + self.up_blocks = nn.ModuleList([]) + + # down + output_channel = block_out_channels[0] + for i, down_block_type in enumerate(down_block_types): + input_channel = output_channel + output_channel = block_out_channels[i] + is_final_block = i == len(block_out_channels) - 1 + + down_block = get_down_block( + down_block_type, + num_layers=layers_per_block, + in_channels=input_channel, + out_channels=output_channel, + temb_channels=time_embed_dim, + add_downsample=not is_final_block, + resnet_eps=norm_eps, + resnet_act_fn=act_fn, + resnet_groups=norm_num_groups, + cross_attention_dim=cross_attention_dim, + attn_num_head_channels=attention_head_dim, + downsample_padding=downsample_padding, + ) + self.down_blocks.append(down_block) + + # mid + self.mid_block = UNetMidBlock2DCrossAttn( + in_channels=block_out_channels[-1], + temb_channels=time_embed_dim, + resnet_eps=norm_eps, + resnet_act_fn=act_fn, + output_scale_factor=mid_block_scale_factor, + resnet_time_scale_shift="default", + cross_attention_dim=cross_attention_dim, + attn_num_head_channels=attention_head_dim, + resnet_groups=norm_num_groups, + ) + + # up + reversed_block_out_channels = list(reversed(block_out_channels)) + output_channel = reversed_block_out_channels[0] + for i, up_block_type in enumerate(up_block_types): + prev_output_channel = output_channel + output_channel = reversed_block_out_channels[i] + input_channel = reversed_block_out_channels[min(i + 1, len(block_out_channels) - 1)] + + is_final_block = i == len(block_out_channels) - 1 + + up_block = get_up_block( + up_block_type, + num_layers=layers_per_block + 1, + in_channels=input_channel, + out_channels=output_channel, + prev_output_channel=prev_output_channel, + temb_channels=time_embed_dim, + add_upsample=not is_final_block, + resnet_eps=norm_eps, + resnet_act_fn=act_fn, + resnet_groups=norm_num_groups, + cross_attention_dim=cross_attention_dim, + 
attn_num_head_channels=attention_head_dim, + ) + self.up_blocks.append(up_block) + prev_output_channel = output_channel + + # out + self.conv_norm_out = nn.GroupNorm(num_channels=block_out_channels[0], num_groups=norm_num_groups, eps=norm_eps) + self.conv_act = nn.SiLU() + self.conv_out = nn.Conv2d(block_out_channels[0], out_channels, 3, padding=1) + + + + def forward( + self, + sample: torch.FloatTensor, + t: torch.Tensor, + encoder_hidden_states: torch.Tensor = None, + self_cond: torch.Tensor = None + ): + encoder_hidden_states = self.emb(encoder_hidden_states) + # encoder_hidden_states = None # ------------------------ WARNING Disabled --------------------- + """r + Args: + sample (`torch.FloatTensor`): (batch, channel, height, width) noisy inputs tensor + timestep (`torch.FloatTensor` or `float` or `int`): (batch) timesteps + encoder_hidden_states (`torch.FloatTensor`): (batch, channel, height, width) encoder hidden states + + Returns: + [`~models.unet_2d_condition.UNet2DConditionOutput`] or `tuple`: + [`~models.unet_2d_condition.UNet2DConditionOutput`] if `return_dict` is True, otherwise a `tuple`. When + returning a tuple, the first element is the sample tensor. + """ + # 0. center input if necessary + # if self.config.center_input_sample: + # sample = 2 * sample - 1.0 + + # 1. time + t_emb = self.time_embedding(t) + + # 2. pre-process + sample = self.conv_in(sample) + + # 3. down + down_block_res_samples = (sample,) + for downsample_block in self.down_blocks: + if hasattr(downsample_block, "attentions") and downsample_block.attentions is not None: + sample, res_samples = downsample_block( + hidden_states=sample, + temb=t_emb, + encoder_hidden_states=encoder_hidden_states, + ) + else: + sample, res_samples = downsample_block(hidden_states=sample, temb=t_emb) + + down_block_res_samples += res_samples + + # 4. mid + sample = self.mid_block(sample, t_emb, encoder_hidden_states=encoder_hidden_states) + + # 5. up + for upsample_block in self.up_blocks: + res_samples = down_block_res_samples[-len(upsample_block.resnets) :] + down_block_res_samples = down_block_res_samples[: -len(upsample_block.resnets)] + + if hasattr(upsample_block, "attentions") and upsample_block.attentions is not None: + sample = upsample_block( + hidden_states=sample, + temb=t_emb, + res_hidden_states_tuple=res_samples, + encoder_hidden_states=encoder_hidden_states, + ) + else: + sample = upsample_block(hidden_states=sample, temb=t_emb, res_hidden_states_tuple=res_samples) + + # 6. post-process + # make sure hidden states is in float32 + # when running in half-precision + sample = self.conv_norm_out(sample.float()).type(sample.dtype) + sample = self.conv_act(sample) + sample = self.conv_out(sample) + + + return sample, [] diff --git a/medical_diffusion/external/diffusers/unet_blocks.py b/medical_diffusion/external/diffusers/unet_blocks.py new file mode 100755 index 0000000000000000000000000000000000000000..a895d520afad8e325a5387368f4c21d2c29cf4e5 --- /dev/null +++ b/medical_diffusion/external/diffusers/unet_blocks.py @@ -0,0 +1,1557 @@ +# Copyright 2022 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and + +import numpy as np + +# limitations under the License. +import torch +from torch import nn + +from .attention import AttentionBlock, SpatialTransformer +from .resnet import Downsample2D, FirDownsample2D, FirUpsample2D, ResnetBlock2D, Upsample2D + + +def get_down_block( + down_block_type, + num_layers, + in_channels, + out_channels, + temb_channels, + add_downsample, + resnet_eps, + resnet_act_fn, + attn_num_head_channels, + resnet_groups=None, + cross_attention_dim=None, + downsample_padding=None, +): + down_block_type = down_block_type[7:] if down_block_type.startswith("UNetRes") else down_block_type + if down_block_type == "DownBlock2D": + return DownBlock2D( + num_layers=num_layers, + in_channels=in_channels, + out_channels=out_channels, + temb_channels=temb_channels, + add_downsample=add_downsample, + resnet_eps=resnet_eps, + resnet_act_fn=resnet_act_fn, + resnet_groups=resnet_groups, + downsample_padding=downsample_padding, + ) + elif down_block_type == "AttnDownBlock2D": + return AttnDownBlock2D( + num_layers=num_layers, + in_channels=in_channels, + out_channels=out_channels, + temb_channels=temb_channels, + add_downsample=add_downsample, + resnet_eps=resnet_eps, + resnet_act_fn=resnet_act_fn, + resnet_groups=resnet_groups, + downsample_padding=downsample_padding, + attn_num_head_channels=attn_num_head_channels, + ) + elif down_block_type == "CrossAttnDownBlock2D": + if cross_attention_dim is None: + raise ValueError("cross_attention_dim must be specified for CrossAttnDownBlock2D") + return CrossAttnDownBlock2D( + num_layers=num_layers, + in_channels=in_channels, + out_channels=out_channels, + temb_channels=temb_channels, + add_downsample=add_downsample, + resnet_eps=resnet_eps, + resnet_act_fn=resnet_act_fn, + resnet_groups=resnet_groups, + downsample_padding=downsample_padding, + cross_attention_dim=cross_attention_dim, + attn_num_head_channels=attn_num_head_channels, + ) + elif down_block_type == "SkipDownBlock2D": + return SkipDownBlock2D( + num_layers=num_layers, + in_channels=in_channels, + out_channels=out_channels, + temb_channels=temb_channels, + add_downsample=add_downsample, + resnet_eps=resnet_eps, + resnet_act_fn=resnet_act_fn, + downsample_padding=downsample_padding, + ) + elif down_block_type == "AttnSkipDownBlock2D": + return AttnSkipDownBlock2D( + num_layers=num_layers, + in_channels=in_channels, + out_channels=out_channels, + temb_channels=temb_channels, + add_downsample=add_downsample, + resnet_eps=resnet_eps, + resnet_act_fn=resnet_act_fn, + downsample_padding=downsample_padding, + attn_num_head_channels=attn_num_head_channels, + ) + elif down_block_type == "DownEncoderBlock2D": + return DownEncoderBlock2D( + num_layers=num_layers, + in_channels=in_channels, + out_channels=out_channels, + add_downsample=add_downsample, + resnet_eps=resnet_eps, + resnet_act_fn=resnet_act_fn, + resnet_groups=resnet_groups, + downsample_padding=downsample_padding, + ) + + +def get_up_block( + up_block_type, + num_layers, + in_channels, + out_channels, + prev_output_channel, + temb_channels, + add_upsample, + resnet_eps, + resnet_act_fn, + attn_num_head_channels, + resnet_groups=None, + 
cross_attention_dim=None, +): + up_block_type = up_block_type[7:] if up_block_type.startswith("UNetRes") else up_block_type + if up_block_type == "UpBlock2D": + return UpBlock2D( + num_layers=num_layers, + in_channels=in_channels, + out_channels=out_channels, + prev_output_channel=prev_output_channel, + temb_channels=temb_channels, + add_upsample=add_upsample, + resnet_eps=resnet_eps, + resnet_act_fn=resnet_act_fn, + resnet_groups=resnet_groups, + ) + elif up_block_type == "CrossAttnUpBlock2D": + if cross_attention_dim is None: + raise ValueError("cross_attention_dim must be specified for CrossAttnUpBlock2D") + return CrossAttnUpBlock2D( + num_layers=num_layers, + in_channels=in_channels, + out_channels=out_channels, + prev_output_channel=prev_output_channel, + temb_channels=temb_channels, + add_upsample=add_upsample, + resnet_eps=resnet_eps, + resnet_act_fn=resnet_act_fn, + resnet_groups=resnet_groups, + cross_attention_dim=cross_attention_dim, + attn_num_head_channels=attn_num_head_channels, + ) + elif up_block_type == "AttnUpBlock2D": + return AttnUpBlock2D( + num_layers=num_layers, + in_channels=in_channels, + out_channels=out_channels, + prev_output_channel=prev_output_channel, + temb_channels=temb_channels, + add_upsample=add_upsample, + resnet_eps=resnet_eps, + resnet_act_fn=resnet_act_fn, + resnet_groups=resnet_groups, + attn_num_head_channels=attn_num_head_channels, + ) + elif up_block_type == "SkipUpBlock2D": + return SkipUpBlock2D( + num_layers=num_layers, + in_channels=in_channels, + out_channels=out_channels, + prev_output_channel=prev_output_channel, + temb_channels=temb_channels, + add_upsample=add_upsample, + resnet_eps=resnet_eps, + resnet_act_fn=resnet_act_fn, + ) + elif up_block_type == "AttnSkipUpBlock2D": + return AttnSkipUpBlock2D( + num_layers=num_layers, + in_channels=in_channels, + out_channels=out_channels, + prev_output_channel=prev_output_channel, + temb_channels=temb_channels, + add_upsample=add_upsample, + resnet_eps=resnet_eps, + resnet_act_fn=resnet_act_fn, + attn_num_head_channels=attn_num_head_channels, + ) + elif up_block_type == "UpDecoderBlock2D": + return UpDecoderBlock2D( + num_layers=num_layers, + in_channels=in_channels, + out_channels=out_channels, + add_upsample=add_upsample, + resnet_eps=resnet_eps, + resnet_act_fn=resnet_act_fn, + resnet_groups=resnet_groups, + ) + raise ValueError(f"{up_block_type} does not exist.") + + +class UNetMidBlock2D(nn.Module): + def __init__( + self, + in_channels: int, + temb_channels: int, + dropout: float = 0.0, + num_layers: int = 1, + resnet_eps: float = 1e-6, + resnet_time_scale_shift: str = "default", + resnet_act_fn: str = "swish", + resnet_groups: int = 32, + resnet_pre_norm: bool = True, + attn_num_head_channels=1, + attention_type="default", + output_scale_factor=1.0, + **kwargs, + ): + super().__init__() + + self.attention_type = attention_type + resnet_groups = resnet_groups if resnet_groups is not None else min(in_channels // 4, 32) + + # there is always at least one resnet + resnets = [ + ResnetBlock2D( + in_channels=in_channels, + out_channels=in_channels, + temb_channels=temb_channels, + eps=resnet_eps, + groups=resnet_groups, + dropout=dropout, + time_embedding_norm=resnet_time_scale_shift, + non_linearity=resnet_act_fn, + output_scale_factor=output_scale_factor, + pre_norm=resnet_pre_norm, + ) + ] + attentions = [] + + for _ in range(num_layers): + attentions.append( + AttentionBlock( + in_channels, + num_head_channels=attn_num_head_channels, + rescale_output_factor=output_scale_factor, + 
eps=resnet_eps, + num_groups=resnet_groups, + ) + ) + resnets.append( + ResnetBlock2D( + in_channels=in_channels, + out_channels=in_channels, + temb_channels=temb_channels, + eps=resnet_eps, + groups=resnet_groups, + dropout=dropout, + time_embedding_norm=resnet_time_scale_shift, + non_linearity=resnet_act_fn, + output_scale_factor=output_scale_factor, + pre_norm=resnet_pre_norm, + ) + ) + + self.attentions = nn.ModuleList(attentions) + self.resnets = nn.ModuleList(resnets) + + def forward(self, hidden_states, temb=None, encoder_states=None): + hidden_states = self.resnets[0](hidden_states, temb) + for attn, resnet in zip(self.attentions, self.resnets[1:]): + if self.attention_type == "default": + hidden_states = attn(hidden_states) + else: + hidden_states = attn(hidden_states, encoder_states) + hidden_states = resnet(hidden_states, temb) + + return hidden_states + + +class UNetMidBlock2DCrossAttn(nn.Module): + def __init__( + self, + in_channels: int, + temb_channels: int, + dropout: float = 0.0, + num_layers: int = 1, + resnet_eps: float = 1e-6, + resnet_time_scale_shift: str = "default", + resnet_act_fn: str = "swish", + resnet_groups: int = 32, + resnet_pre_norm: bool = True, + attn_num_head_channels=1, + attention_type="default", + output_scale_factor=1.0, + cross_attention_dim=1280, + **kwargs, + ): + super().__init__() + + self.attention_type = attention_type + self.attn_num_head_channels = attn_num_head_channels + resnet_groups = resnet_groups if resnet_groups is not None else min(in_channels // 4, 32) + + # there is always at least one resnet + resnets = [ + ResnetBlock2D( + in_channels=in_channels, + out_channels=in_channels, + temb_channels=temb_channels, + eps=resnet_eps, + groups=resnet_groups, + dropout=dropout, + time_embedding_norm=resnet_time_scale_shift, + non_linearity=resnet_act_fn, + output_scale_factor=output_scale_factor, + pre_norm=resnet_pre_norm, + ) + ] + attentions = [] + + for _ in range(num_layers): + attentions.append( + SpatialTransformer( + in_channels, + attn_num_head_channels, + in_channels // attn_num_head_channels, + depth=1, + context_dim=cross_attention_dim, + num_groups=resnet_groups, + ) + ) + resnets.append( + ResnetBlock2D( + in_channels=in_channels, + out_channels=in_channels, + temb_channels=temb_channels, + eps=resnet_eps, + groups=resnet_groups, + dropout=dropout, + time_embedding_norm=resnet_time_scale_shift, + non_linearity=resnet_act_fn, + output_scale_factor=output_scale_factor, + pre_norm=resnet_pre_norm, + ) + ) + + self.attentions = nn.ModuleList(attentions) + self.resnets = nn.ModuleList(resnets) + + def set_attention_slice(self, slice_size): + if slice_size is not None and self.attn_num_head_channels % slice_size != 0: + raise ValueError( + f"Make sure slice_size {slice_size} is a divisor of " + f"the number of heads used in cross_attention {self.attn_num_head_channels}" + ) + if slice_size is not None and slice_size > self.attn_num_head_channels: + raise ValueError( + f"Chunk_size {slice_size} has to be smaller or equal to " + f"the number of heads used in cross_attention {self.attn_num_head_channels}" + ) + + for attn in self.attentions: + attn._set_attention_slice(slice_size) + + def forward(self, hidden_states, temb=None, encoder_hidden_states=None): + hidden_states = self.resnets[0](hidden_states, temb) + for attn, resnet in zip(self.attentions, self.resnets[1:]): + hidden_states = attn(hidden_states, encoder_hidden_states) + hidden_states = resnet(hidden_states, temb) + + return hidden_states + + +class 
AttnDownBlock2D(nn.Module): + def __init__( + self, + in_channels: int, + out_channels: int, + temb_channels: int, + dropout: float = 0.0, + num_layers: int = 1, + resnet_eps: float = 1e-6, + resnet_time_scale_shift: str = "default", + resnet_act_fn: str = "swish", + resnet_groups: int = 32, + resnet_pre_norm: bool = True, + attn_num_head_channels=1, + attention_type="default", + output_scale_factor=1.0, + downsample_padding=1, + add_downsample=True, + ): + super().__init__() + resnets = [] + attentions = [] + + self.attention_type = attention_type + + for i in range(num_layers): + in_channels = in_channels if i == 0 else out_channels + resnets.append( + ResnetBlock2D( + in_channels=in_channels, + out_channels=out_channels, + temb_channels=temb_channels, + eps=resnet_eps, + groups=resnet_groups, + dropout=dropout, + time_embedding_norm=resnet_time_scale_shift, + non_linearity=resnet_act_fn, + output_scale_factor=output_scale_factor, + pre_norm=resnet_pre_norm, + ) + ) + attentions.append( + AttentionBlock( + out_channels, + num_head_channels=attn_num_head_channels, + rescale_output_factor=output_scale_factor, + eps=resnet_eps, + num_groups=resnet_groups, + ) + ) + + self.attentions = nn.ModuleList(attentions) + self.resnets = nn.ModuleList(resnets) + + if add_downsample: + self.downsamplers = nn.ModuleList( + [ + Downsample2D( + in_channels, use_conv=True, out_channels=out_channels, padding=downsample_padding, name="op" + ) + ] + ) + else: + self.downsamplers = None + + def forward(self, hidden_states, temb=None): + output_states = () + + for resnet, attn in zip(self.resnets, self.attentions): + hidden_states = resnet(hidden_states, temb) + hidden_states = attn(hidden_states) + output_states += (hidden_states,) + + if self.downsamplers is not None: + for downsampler in self.downsamplers: + hidden_states = downsampler(hidden_states) + + output_states += (hidden_states,) + + return hidden_states, output_states + + +class CrossAttnDownBlock2D(nn.Module): + def __init__( + self, + in_channels: int, + out_channels: int, + temb_channels: int, + dropout: float = 0.0, + num_layers: int = 1, + resnet_eps: float = 1e-6, + resnet_time_scale_shift: str = "default", + resnet_act_fn: str = "swish", + resnet_groups: int = 32, + resnet_pre_norm: bool = True, + attn_num_head_channels=1, + cross_attention_dim=1280, + attention_type="default", + output_scale_factor=1.0, + downsample_padding=1, + add_downsample=True, + ): + super().__init__() + resnets = [] + attentions = [] + + self.attention_type = attention_type + self.attn_num_head_channels = attn_num_head_channels + + for i in range(num_layers): + in_channels = in_channels if i == 0 else out_channels + resnets.append( + ResnetBlock2D( + in_channels=in_channels, + out_channels=out_channels, + temb_channels=temb_channels, + eps=resnet_eps, + groups=resnet_groups, + dropout=dropout, + time_embedding_norm=resnet_time_scale_shift, + non_linearity=resnet_act_fn, + output_scale_factor=output_scale_factor, + pre_norm=resnet_pre_norm, + ) + ) + attentions.append( + SpatialTransformer( + out_channels, + attn_num_head_channels, + out_channels // attn_num_head_channels, + depth=1, + context_dim=cross_attention_dim, + num_groups=resnet_groups, + ) + ) + self.attentions = nn.ModuleList(attentions) + self.resnets = nn.ModuleList(resnets) + + if add_downsample: + self.downsamplers = nn.ModuleList( + [ + Downsample2D( + in_channels, use_conv=True, out_channels=out_channels, padding=downsample_padding, name="op" + ) + ] + ) + else: + self.downsamplers = None + + 
self.gradient_checkpointing = False + + def set_attention_slice(self, slice_size): + if slice_size is not None and self.attn_num_head_channels % slice_size != 0: + raise ValueError( + f"Make sure slice_size {slice_size} is a divisor of " + f"the number of heads used in cross_attention {self.attn_num_head_channels}" + ) + if slice_size is not None and slice_size > self.attn_num_head_channels: + raise ValueError( + f"Chunk_size {slice_size} has to be smaller or equal to " + f"the number of heads used in cross_attention {self.attn_num_head_channels}" + ) + + for attn in self.attentions: + attn._set_attention_slice(slice_size) + + def forward(self, hidden_states, temb=None, encoder_hidden_states=None): + output_states = () + + for resnet, attn in zip(self.resnets, self.attentions): + if self.training and self.gradient_checkpointing: + + def create_custom_forward(module): + def custom_forward(*inputs): + return module(*inputs) + + return custom_forward + + hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(resnet), hidden_states, temb) + hidden_states = torch.utils.checkpoint.checkpoint( + create_custom_forward(attn), hidden_states, encoder_hidden_states + ) + else: + hidden_states = resnet(hidden_states, temb) + hidden_states = attn(hidden_states, context=encoder_hidden_states) + + output_states += (hidden_states,) + + if self.downsamplers is not None: + for downsampler in self.downsamplers: + hidden_states = downsampler(hidden_states) + + output_states += (hidden_states,) + + return hidden_states, output_states + + +class DownBlock2D(nn.Module): + def __init__( + self, + in_channels: int, + out_channels: int, + temb_channels: int, + dropout: float = 0.0, + num_layers: int = 1, + resnet_eps: float = 1e-6, + resnet_time_scale_shift: str = "default", + resnet_act_fn: str = "swish", + resnet_groups: int = 32, + resnet_pre_norm: bool = True, + output_scale_factor=1.0, + add_downsample=True, + downsample_padding=1, + ): + super().__init__() + resnets = [] + + for i in range(num_layers): + in_channels = in_channels if i == 0 else out_channels + resnets.append( + ResnetBlock2D( + in_channels=in_channels, + out_channels=out_channels, + temb_channels=temb_channels, + eps=resnet_eps, + groups=resnet_groups, + dropout=dropout, + time_embedding_norm=resnet_time_scale_shift, + non_linearity=resnet_act_fn, + output_scale_factor=output_scale_factor, + pre_norm=resnet_pre_norm, + ) + ) + + self.resnets = nn.ModuleList(resnets) + + if add_downsample: + self.downsamplers = nn.ModuleList( + [ + Downsample2D( + in_channels, use_conv=True, out_channels=out_channels, padding=downsample_padding, name="op" + ) + ] + ) + else: + self.downsamplers = None + + self.gradient_checkpointing = False + + def forward(self, hidden_states, temb=None): + output_states = () + + for resnet in self.resnets: + if self.training and self.gradient_checkpointing: + + def create_custom_forward(module): + def custom_forward(*inputs): + return module(*inputs) + + return custom_forward + + hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(resnet), hidden_states, temb) + else: + hidden_states = resnet(hidden_states, temb) + + output_states += (hidden_states,) + + if self.downsamplers is not None: + for downsampler in self.downsamplers: + hidden_states = downsampler(hidden_states) + + output_states += (hidden_states,) + + return hidden_states, output_states + + +class DownEncoderBlock2D(nn.Module): + def __init__( + self, + in_channels: int, + out_channels: int, + dropout: float = 0.0, + num_layers: 
int = 1, + resnet_eps: float = 1e-6, + resnet_time_scale_shift: str = "default", + resnet_act_fn: str = "swish", + resnet_groups: int = 32, + resnet_pre_norm: bool = True, + output_scale_factor=1.0, + add_downsample=True, + downsample_padding=1, + ): + super().__init__() + resnets = [] + + for i in range(num_layers): + in_channels = in_channels if i == 0 else out_channels + resnets.append( + ResnetBlock2D( + in_channels=in_channels, + out_channels=out_channels, + temb_channels=None, + eps=resnet_eps, + groups=resnet_groups, + dropout=dropout, + time_embedding_norm=resnet_time_scale_shift, + non_linearity=resnet_act_fn, + output_scale_factor=output_scale_factor, + pre_norm=resnet_pre_norm, + ) + ) + + self.resnets = nn.ModuleList(resnets) + + if add_downsample: + self.downsamplers = nn.ModuleList( + [ + Downsample2D( + out_channels if len(resnets)>0 else in_channels, use_conv=True, out_channels=out_channels, padding=downsample_padding, name="op" + ) + ] + ) + else: + self.downsamplers = None + + def forward(self, hidden_states): + for resnet in self.resnets: + hidden_states = resnet(hidden_states, temb=None) + + if self.downsamplers is not None: + for downsampler in self.downsamplers: + hidden_states = downsampler(hidden_states) + + return hidden_states + + +class AttnDownEncoderBlock2D(nn.Module): + def __init__( + self, + in_channels: int, + out_channels: int, + dropout: float = 0.0, + num_layers: int = 1, + resnet_eps: float = 1e-6, + resnet_time_scale_shift: str = "default", + resnet_act_fn: str = "swish", + resnet_groups: int = 32, + resnet_pre_norm: bool = True, + attn_num_head_channels=1, + output_scale_factor=1.0, + add_downsample=True, + downsample_padding=1, + ): + super().__init__() + resnets = [] + attentions = [] + + for i in range(num_layers): + in_channels = in_channels if i == 0 else out_channels + resnets.append( + ResnetBlock2D( + in_channels=in_channels, + out_channels=out_channels, + temb_channels=None, + eps=resnet_eps, + groups=resnet_groups, + dropout=dropout, + time_embedding_norm=resnet_time_scale_shift, + non_linearity=resnet_act_fn, + output_scale_factor=output_scale_factor, + pre_norm=resnet_pre_norm, + ) + ) + attentions.append( + AttentionBlock( + out_channels, + num_head_channels=attn_num_head_channels, + rescale_output_factor=output_scale_factor, + eps=resnet_eps, + num_groups=resnet_groups, + ) + ) + + self.attentions = nn.ModuleList(attentions) + self.resnets = nn.ModuleList(resnets) + + if add_downsample: + self.downsamplers = nn.ModuleList( + [ + Downsample2D( + in_channels, use_conv=True, out_channels=out_channels, padding=downsample_padding, name="op" + ) + ] + ) + else: + self.downsamplers = None + + def forward(self, hidden_states): + for resnet, attn in zip(self.resnets, self.attentions): + hidden_states = resnet(hidden_states, temb=None) + hidden_states = attn(hidden_states) + + if self.downsamplers is not None: + for downsampler in self.downsamplers: + hidden_states = downsampler(hidden_states) + + return hidden_states + + +class AttnSkipDownBlock2D(nn.Module): + def __init__( + self, + in_channels: int, + out_channels: int, + temb_channels: int, + dropout: float = 0.0, + num_layers: int = 1, + resnet_eps: float = 1e-6, + resnet_time_scale_shift: str = "default", + resnet_act_fn: str = "swish", + resnet_pre_norm: bool = True, + attn_num_head_channels=1, + attention_type="default", + output_scale_factor=np.sqrt(2.0), + downsample_padding=1, + add_downsample=True, + ): + super().__init__() + self.attentions = nn.ModuleList([]) + self.resnets = 
nn.ModuleList([]) + + self.attention_type = attention_type + + for i in range(num_layers): + in_channels = in_channels if i == 0 else out_channels + self.resnets.append( + ResnetBlock2D( + in_channels=in_channels, + out_channels=out_channels, + temb_channels=temb_channels, + eps=resnet_eps, + groups=min(in_channels // 4, 32), + groups_out=min(out_channels // 4, 32), + dropout=dropout, + time_embedding_norm=resnet_time_scale_shift, + non_linearity=resnet_act_fn, + output_scale_factor=output_scale_factor, + pre_norm=resnet_pre_norm, + ) + ) + self.attentions.append( + AttentionBlock( + out_channels, + num_head_channels=attn_num_head_channels, + rescale_output_factor=output_scale_factor, + eps=resnet_eps, + ) + ) + + if add_downsample: + self.resnet_down = ResnetBlock2D( + in_channels=out_channels, + out_channels=out_channels, + temb_channels=temb_channels, + eps=resnet_eps, + groups=min(out_channels // 4, 32), + dropout=dropout, + time_embedding_norm=resnet_time_scale_shift, + non_linearity=resnet_act_fn, + output_scale_factor=output_scale_factor, + pre_norm=resnet_pre_norm, + use_in_shortcut=True, + down=True, + kernel="fir", + ) + self.downsamplers = nn.ModuleList([FirDownsample2D(in_channels, out_channels=out_channels)]) + self.skip_conv = nn.Conv2d(3, out_channels, kernel_size=(1, 1), stride=(1, 1)) + else: + self.resnet_down = None + self.downsamplers = None + self.skip_conv = None + + def forward(self, hidden_states, temb=None, skip_sample=None): + output_states = () + + for resnet, attn in zip(self.resnets, self.attentions): + hidden_states = resnet(hidden_states, temb) + hidden_states = attn(hidden_states) + output_states += (hidden_states,) + + if self.downsamplers is not None: + hidden_states = self.resnet_down(hidden_states, temb) + for downsampler in self.downsamplers: + skip_sample = downsampler(skip_sample) + + hidden_states = self.skip_conv(skip_sample) + hidden_states + + output_states += (hidden_states,) + + return hidden_states, output_states, skip_sample + + +class SkipDownBlock2D(nn.Module): + def __init__( + self, + in_channels: int, + out_channels: int, + temb_channels: int, + dropout: float = 0.0, + num_layers: int = 1, + resnet_eps: float = 1e-6, + resnet_time_scale_shift: str = "default", + resnet_act_fn: str = "swish", + resnet_pre_norm: bool = True, + output_scale_factor=np.sqrt(2.0), + add_downsample=True, + downsample_padding=1, + ): + super().__init__() + self.resnets = nn.ModuleList([]) + + for i in range(num_layers): + in_channels = in_channels if i == 0 else out_channels + self.resnets.append( + ResnetBlock2D( + in_channels=in_channels, + out_channels=out_channels, + temb_channels=temb_channels, + eps=resnet_eps, + groups=min(in_channels // 4, 32), + groups_out=min(out_channels // 4, 32), + dropout=dropout, + time_embedding_norm=resnet_time_scale_shift, + non_linearity=resnet_act_fn, + output_scale_factor=output_scale_factor, + pre_norm=resnet_pre_norm, + ) + ) + + if add_downsample: + self.resnet_down = ResnetBlock2D( + in_channels=out_channels, + out_channels=out_channels, + temb_channels=temb_channels, + eps=resnet_eps, + groups=min(out_channels // 4, 32), + dropout=dropout, + time_embedding_norm=resnet_time_scale_shift, + non_linearity=resnet_act_fn, + output_scale_factor=output_scale_factor, + pre_norm=resnet_pre_norm, + use_in_shortcut=True, + down=True, + kernel="fir", + ) + self.downsamplers = nn.ModuleList([FirDownsample2D(in_channels, out_channels=out_channels)]) + self.skip_conv = nn.Conv2d(3, out_channels, kernel_size=(1, 1), stride=(1, 1)) + 
else: + self.resnet_down = None + self.downsamplers = None + self.skip_conv = None + + def forward(self, hidden_states, temb=None, skip_sample=None): + output_states = () + + for resnet in self.resnets: + hidden_states = resnet(hidden_states, temb) + output_states += (hidden_states,) + + if self.downsamplers is not None: + hidden_states = self.resnet_down(hidden_states, temb) + for downsampler in self.downsamplers: + skip_sample = downsampler(skip_sample) + + hidden_states = self.skip_conv(skip_sample) + hidden_states + + output_states += (hidden_states,) + + return hidden_states, output_states, skip_sample + + +class AttnUpBlock2D(nn.Module): + def __init__( + self, + in_channels: int, + prev_output_channel: int, + out_channels: int, + temb_channels: int, + dropout: float = 0.0, + num_layers: int = 1, + resnet_eps: float = 1e-6, + resnet_time_scale_shift: str = "default", + resnet_act_fn: str = "swish", + resnet_groups: int = 32, + resnet_pre_norm: bool = True, + attention_type="default", + attn_num_head_channels=1, + output_scale_factor=1.0, + add_upsample=True, + ): + super().__init__() + resnets = [] + attentions = [] + + self.attention_type = attention_type + + for i in range(num_layers): + res_skip_channels = in_channels if (i == num_layers - 1) else out_channels + resnet_in_channels = prev_output_channel if i == 0 else out_channels + + resnets.append( + ResnetBlock2D( + in_channels=resnet_in_channels + res_skip_channels, + out_channels=out_channels, + temb_channels=temb_channels, + eps=resnet_eps, + groups=resnet_groups, + dropout=dropout, + time_embedding_norm=resnet_time_scale_shift, + non_linearity=resnet_act_fn, + output_scale_factor=output_scale_factor, + pre_norm=resnet_pre_norm, + ) + ) + attentions.append( + AttentionBlock( + out_channels, + num_head_channels=attn_num_head_channels, + rescale_output_factor=output_scale_factor, + eps=resnet_eps, + num_groups=resnet_groups, + ) + ) + + self.attentions = nn.ModuleList(attentions) + self.resnets = nn.ModuleList(resnets) + + if add_upsample: + self.upsamplers = nn.ModuleList([Upsample2D(out_channels, use_conv=True, out_channels=out_channels)]) + else: + self.upsamplers = None + + def forward(self, hidden_states, res_hidden_states_tuple, temb=None): + for resnet, attn in zip(self.resnets, self.attentions): + # pop res hidden states + res_hidden_states = res_hidden_states_tuple[-1] + res_hidden_states_tuple = res_hidden_states_tuple[:-1] + hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1) + + hidden_states = resnet(hidden_states, temb) + hidden_states = attn(hidden_states) + + if self.upsamplers is not None: + for upsampler in self.upsamplers: + hidden_states = upsampler(hidden_states) + + return hidden_states + + +class CrossAttnUpBlock2D(nn.Module): + def __init__( + self, + in_channels: int, + out_channels: int, + prev_output_channel: int, + temb_channels: int, + dropout: float = 0.0, + num_layers: int = 1, + resnet_eps: float = 1e-6, + resnet_time_scale_shift: str = "default", + resnet_act_fn: str = "swish", + resnet_groups: int = 32, + resnet_pre_norm: bool = True, + attn_num_head_channels=1, + cross_attention_dim=1280, + attention_type="default", + output_scale_factor=1.0, + downsample_padding=1, + add_upsample=True, + ): + super().__init__() + resnets = [] + attentions = [] + + self.attention_type = attention_type + self.attn_num_head_channels = attn_num_head_channels + + for i in range(num_layers): + res_skip_channels = in_channels if (i == num_layers - 1) else out_channels + resnet_in_channels = 
prev_output_channel if i == 0 else out_channels + + resnets.append( + ResnetBlock2D( + in_channels=resnet_in_channels + res_skip_channels, + out_channels=out_channels, + temb_channels=temb_channels, + eps=resnet_eps, + groups=resnet_groups, + dropout=dropout, + time_embedding_norm=resnet_time_scale_shift, + non_linearity=resnet_act_fn, + output_scale_factor=output_scale_factor, + pre_norm=resnet_pre_norm, + ) + ) + attentions.append( + SpatialTransformer( + out_channels, + attn_num_head_channels, + out_channels // attn_num_head_channels, + depth=1, + context_dim=cross_attention_dim, + num_groups=resnet_groups, + ) + ) + self.attentions = nn.ModuleList(attentions) + self.resnets = nn.ModuleList(resnets) + + if add_upsample: + self.upsamplers = nn.ModuleList([Upsample2D(out_channels, use_conv=True, out_channels=out_channels)]) + else: + self.upsamplers = None + + self.gradient_checkpointing = False + + def set_attention_slice(self, slice_size): + if slice_size is not None and self.attn_num_head_channels % slice_size != 0: + raise ValueError( + f"Make sure slice_size {slice_size} is a divisor of " + f"the number of heads used in cross_attention {self.attn_num_head_channels}" + ) + if slice_size is not None and slice_size > self.attn_num_head_channels: + raise ValueError( + f"Chunk_size {slice_size} has to be smaller or equal to " + f"the number of heads used in cross_attention {self.attn_num_head_channels}" + ) + + for attn in self.attentions: + attn._set_attention_slice(slice_size) + + self.gradient_checkpointing = False + + def forward( + self, + hidden_states, + res_hidden_states_tuple, + temb=None, + encoder_hidden_states=None, + ): + for resnet, attn in zip(self.resnets, self.attentions): + # pop res hidden states + res_hidden_states = res_hidden_states_tuple[-1] + res_hidden_states_tuple = res_hidden_states_tuple[:-1] + hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1) + + if self.training and self.gradient_checkpointing: + + def create_custom_forward(module): + def custom_forward(*inputs): + return module(*inputs) + + return custom_forward + + hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(resnet), hidden_states, temb) + hidden_states = torch.utils.checkpoint.checkpoint( + create_custom_forward(attn), hidden_states, encoder_hidden_states + ) + else: + hidden_states = resnet(hidden_states, temb) + hidden_states = attn(hidden_states, context=encoder_hidden_states) + + if self.upsamplers is not None: + for upsampler in self.upsamplers: + hidden_states = upsampler(hidden_states) + + return hidden_states + + +class UpBlock2D(nn.Module): + def __init__( + self, + in_channels: int, + prev_output_channel: int, + out_channels: int, + temb_channels: int, + dropout: float = 0.0, + num_layers: int = 1, + resnet_eps: float = 1e-6, + resnet_time_scale_shift: str = "default", + resnet_act_fn: str = "swish", + resnet_groups: int = 32, + resnet_pre_norm: bool = True, + output_scale_factor=1.0, + add_upsample=True, + ): + super().__init__() + resnets = [] + + for i in range(num_layers): + res_skip_channels = in_channels if (i == num_layers - 1) else out_channels + resnet_in_channels = prev_output_channel if i == 0 else out_channels + + resnets.append( + ResnetBlock2D( + in_channels=resnet_in_channels + res_skip_channels, + out_channels=out_channels, + temb_channels=temb_channels, + eps=resnet_eps, + groups=resnet_groups, + dropout=dropout, + time_embedding_norm=resnet_time_scale_shift, + non_linearity=resnet_act_fn, + output_scale_factor=output_scale_factor, + 
pre_norm=resnet_pre_norm, + ) + ) + + self.resnets = nn.ModuleList(resnets) + + if add_upsample: + self.upsamplers = nn.ModuleList([Upsample2D(out_channels, use_conv=True, out_channels=out_channels)]) + else: + self.upsamplers = None + + self.gradient_checkpointing = False + + def forward(self, hidden_states, res_hidden_states_tuple, temb=None): + for resnet in self.resnets: + # pop res hidden states + res_hidden_states = res_hidden_states_tuple[-1] + res_hidden_states_tuple = res_hidden_states_tuple[:-1] + hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1) + + if self.training and self.gradient_checkpointing: + + def create_custom_forward(module): + def custom_forward(*inputs): + return module(*inputs) + + return custom_forward + + hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(resnet), hidden_states, temb) + else: + hidden_states = resnet(hidden_states, temb) + + if self.upsamplers is not None: + for upsampler in self.upsamplers: + hidden_states = upsampler(hidden_states) + + return hidden_states + + +class UpDecoderBlock2D(nn.Module): + def __init__( + self, + in_channels: int, + out_channels: int, + dropout: float = 0.0, + num_layers: int = 1, + resnet_eps: float = 1e-6, + resnet_time_scale_shift: str = "default", + resnet_act_fn: str = "swish", + resnet_groups: int = 32, + resnet_pre_norm: bool = True, + output_scale_factor=1.0, + add_upsample=True, + ): + super().__init__() + resnets = [] + + for i in range(num_layers): + input_channels = in_channels if i == 0 else out_channels + + resnets.append( + ResnetBlock2D( + in_channels=input_channels, + out_channels=out_channels, + temb_channels=None, + eps=resnet_eps, + groups=resnet_groups, + dropout=dropout, + time_embedding_norm=resnet_time_scale_shift, + non_linearity=resnet_act_fn, + output_scale_factor=output_scale_factor, + pre_norm=resnet_pre_norm, + ) + ) + + self.resnets = nn.ModuleList(resnets) + + if add_upsample: + self.upsamplers = nn.ModuleList([Upsample2D(out_channels, use_conv=True, out_channels=out_channels)]) + else: + self.upsamplers = None + + def forward(self, hidden_states): + for resnet in self.resnets: + hidden_states = resnet(hidden_states, temb=None) + + if self.upsamplers is not None: + for upsampler in self.upsamplers: + hidden_states = upsampler(hidden_states) + + return hidden_states + + +class AttnUpDecoderBlock2D(nn.Module): + def __init__( + self, + in_channels: int, + out_channels: int, + dropout: float = 0.0, + num_layers: int = 1, + resnet_eps: float = 1e-6, + resnet_time_scale_shift: str = "default", + resnet_act_fn: str = "swish", + resnet_groups: int = 32, + resnet_pre_norm: bool = True, + attn_num_head_channels=1, + output_scale_factor=1.0, + add_upsample=True, + ): + super().__init__() + resnets = [] + attentions = [] + + for i in range(num_layers): + input_channels = in_channels if i == 0 else out_channels + + resnets.append( + ResnetBlock2D( + in_channels=input_channels, + out_channels=out_channels, + temb_channels=None, + eps=resnet_eps, + groups=resnet_groups, + dropout=dropout, + time_embedding_norm=resnet_time_scale_shift, + non_linearity=resnet_act_fn, + output_scale_factor=output_scale_factor, + pre_norm=resnet_pre_norm, + ) + ) + attentions.append( + AttentionBlock( + out_channels, + num_head_channels=attn_num_head_channels, + rescale_output_factor=output_scale_factor, + eps=resnet_eps, + num_groups=resnet_groups, + ) + ) + + self.attentions = nn.ModuleList(attentions) + self.resnets = nn.ModuleList(resnets) + + if add_upsample: + self.upsamplers = 
nn.ModuleList([Upsample2D(out_channels, use_conv=True, out_channels=out_channels)]) + else: + self.upsamplers = None + + def forward(self, hidden_states): + for resnet, attn in zip(self.resnets, self.attentions): + hidden_states = resnet(hidden_states, temb=None) + hidden_states = attn(hidden_states) + + if self.upsamplers is not None: + for upsampler in self.upsamplers: + hidden_states = upsampler(hidden_states) + + return hidden_states + + +class AttnSkipUpBlock2D(nn.Module): + def __init__( + self, + in_channels: int, + prev_output_channel: int, + out_channels: int, + temb_channels: int, + dropout: float = 0.0, + num_layers: int = 1, + resnet_eps: float = 1e-6, + resnet_time_scale_shift: str = "default", + resnet_act_fn: str = "swish", + resnet_pre_norm: bool = True, + attn_num_head_channels=1, + attention_type="default", + output_scale_factor=np.sqrt(2.0), + upsample_padding=1, + add_upsample=True, + ): + super().__init__() + self.attentions = nn.ModuleList([]) + self.resnets = nn.ModuleList([]) + + self.attention_type = attention_type + + for i in range(num_layers): + res_skip_channels = in_channels if (i == num_layers - 1) else out_channels + resnet_in_channels = prev_output_channel if i == 0 else out_channels + + self.resnets.append( + ResnetBlock2D( + in_channels=resnet_in_channels + res_skip_channels, + out_channels=out_channels, + temb_channels=temb_channels, + eps=resnet_eps, + groups=min(resnet_in_channels + res_skip_channels // 4, 32), + groups_out=min(out_channels // 4, 32), + dropout=dropout, + time_embedding_norm=resnet_time_scale_shift, + non_linearity=resnet_act_fn, + output_scale_factor=output_scale_factor, + pre_norm=resnet_pre_norm, + ) + ) + + self.attentions.append( + AttentionBlock( + out_channels, + num_head_channels=attn_num_head_channels, + rescale_output_factor=output_scale_factor, + eps=resnet_eps, + ) + ) + + self.upsampler = FirUpsample2D(in_channels, out_channels=out_channels) + if add_upsample: + self.resnet_up = ResnetBlock2D( + in_channels=out_channels, + out_channels=out_channels, + temb_channels=temb_channels, + eps=resnet_eps, + groups=min(out_channels // 4, 32), + groups_out=min(out_channels // 4, 32), + dropout=dropout, + time_embedding_norm=resnet_time_scale_shift, + non_linearity=resnet_act_fn, + output_scale_factor=output_scale_factor, + pre_norm=resnet_pre_norm, + use_in_shortcut=True, + up=True, + kernel="fir", + ) + self.skip_conv = nn.Conv2d(out_channels, 3, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) + self.skip_norm = torch.nn.GroupNorm( + num_groups=min(out_channels // 4, 32), num_channels=out_channels, eps=resnet_eps, affine=True + ) + self.act = nn.SiLU() + else: + self.resnet_up = None + self.skip_conv = None + self.skip_norm = None + self.act = None + + def forward(self, hidden_states, res_hidden_states_tuple, temb=None, skip_sample=None): + for resnet in self.resnets: + # pop res hidden states + res_hidden_states = res_hidden_states_tuple[-1] + res_hidden_states_tuple = res_hidden_states_tuple[:-1] + hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1) + + hidden_states = resnet(hidden_states, temb) + + hidden_states = self.attentions[0](hidden_states) + + if skip_sample is not None: + skip_sample = self.upsampler(skip_sample) + else: + skip_sample = 0 + + if self.resnet_up is not None: + skip_sample_states = self.skip_norm(hidden_states) + skip_sample_states = self.act(skip_sample_states) + skip_sample_states = self.skip_conv(skip_sample_states) + + skip_sample = skip_sample + skip_sample_states + + 
hidden_states = self.resnet_up(hidden_states, temb) + + return hidden_states, skip_sample + + +class SkipUpBlock2D(nn.Module): + def __init__( + self, + in_channels: int, + prev_output_channel: int, + out_channels: int, + temb_channels: int, + dropout: float = 0.0, + num_layers: int = 1, + resnet_eps: float = 1e-6, + resnet_time_scale_shift: str = "default", + resnet_act_fn: str = "swish", + resnet_pre_norm: bool = True, + output_scale_factor=np.sqrt(2.0), + add_upsample=True, + upsample_padding=1, + ): + super().__init__() + self.resnets = nn.ModuleList([]) + + for i in range(num_layers): + res_skip_channels = in_channels if (i == num_layers - 1) else out_channels + resnet_in_channels = prev_output_channel if i == 0 else out_channels + + self.resnets.append( + ResnetBlock2D( + in_channels=resnet_in_channels + res_skip_channels, + out_channels=out_channels, + temb_channels=temb_channels, + eps=resnet_eps, + groups=min((resnet_in_channels + res_skip_channels) // 4, 32), + groups_out=min(out_channels // 4, 32), + dropout=dropout, + time_embedding_norm=resnet_time_scale_shift, + non_linearity=resnet_act_fn, + output_scale_factor=output_scale_factor, + pre_norm=resnet_pre_norm, + ) + ) + + self.upsampler = FirUpsample2D(in_channels, out_channels=out_channels) + if add_upsample: + self.resnet_up = ResnetBlock2D( + in_channels=out_channels, + out_channels=out_channels, + temb_channels=temb_channels, + eps=resnet_eps, + groups=min(out_channels // 4, 32), + groups_out=min(out_channels // 4, 32), + dropout=dropout, + time_embedding_norm=resnet_time_scale_shift, + non_linearity=resnet_act_fn, + output_scale_factor=output_scale_factor, + pre_norm=resnet_pre_norm, + use_in_shortcut=True, + up=True, + kernel="fir", + ) + self.skip_conv = nn.Conv2d(out_channels, 3, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) + self.skip_norm = torch.nn.GroupNorm( + num_groups=min(out_channels // 4, 32), num_channels=out_channels, eps=resnet_eps, affine=True + ) + self.act = nn.SiLU() + else: + self.resnet_up = None + self.skip_conv = None + self.skip_norm = None + self.act = None + + def forward(self, hidden_states, res_hidden_states_tuple, temb=None, skip_sample=None): + for resnet in self.resnets: + # pop res hidden states + res_hidden_states = res_hidden_states_tuple[-1] + res_hidden_states_tuple = res_hidden_states_tuple[:-1] + hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1) + + hidden_states = resnet(hidden_states, temb) + + if skip_sample is not None: + skip_sample = self.upsampler(skip_sample) + else: + skip_sample = 0 + + if self.resnet_up is not None: + skip_sample_states = self.skip_norm(hidden_states) + skip_sample_states = self.act(skip_sample_states) + skip_sample_states = self.skip_conv(skip_sample_states) + + skip_sample = skip_sample + skip_sample_states + + hidden_states = self.resnet_up(hidden_states, temb) + + return hidden_states, skip_sample diff --git a/medical_diffusion/external/diffusers/vae.py b/medical_diffusion/external/diffusers/vae.py new file mode 100755 index 0000000000000000000000000000000000000000..f83b71b82e40451571f5fbdbb3ca66a3cb26c65b --- /dev/null +++ b/medical_diffusion/external/diffusers/vae.py @@ -0,0 +1,857 @@ + + +from typing import Optional, Tuple, Union +from pathlib import Path + +import numpy as np +import torch +import torch.nn as nn +import torch.nn.functional as F +from itertools import chain + +from .unet_blocks import UNetMidBlock2D, get_down_block, get_up_block +from .taming_discriminator import NLayerDiscriminator +from 
medical_diffusion.models import BasicModel +from torchvision.utils import save_image + +from torch.distributions.normal import Normal +from torch.distributions import kl_divergence + +class Encoder(nn.Module): + def __init__( + self, + in_channels=3, + out_channels=3, + down_block_types=("DownEncoderBlock2D",), + block_out_channels=(64), + layers_per_block=2, + norm_num_groups=32, + act_fn="silu", + double_z=True, + ): + super().__init__() + self.layers_per_block = layers_per_block + + self.conv_in = torch.nn.Conv2d(in_channels, block_out_channels[0], kernel_size=3, stride=1, padding=1) + + self.mid_block = None + self.down_blocks = nn.ModuleList([]) + + # down + output_channel = block_out_channels[0] + for i, down_block_type in enumerate(down_block_types): + input_channel = output_channel + output_channel = block_out_channels[i+1] + is_final_block = False #i == len(block_out_channels) - 1 + + down_block = get_down_block( + down_block_type, + num_layers=self.layers_per_block, + in_channels=input_channel, + out_channels=output_channel, + add_downsample=not is_final_block, + resnet_eps=1e-6, + downsample_padding=0, + resnet_act_fn=act_fn, + resnet_groups=norm_num_groups, + attn_num_head_channels=None, + temb_channels=None, + ) + self.down_blocks.append(down_block) + + # mid + self.mid_block = UNetMidBlock2D( + in_channels=block_out_channels[-1], + resnet_eps=1e-6, + resnet_act_fn=act_fn, + output_scale_factor=1, + resnet_time_scale_shift="default", + attn_num_head_channels=None, + resnet_groups=norm_num_groups, + temb_channels=None, + ) + + # out + self.conv_norm_out = nn.GroupNorm(num_channels=block_out_channels[-1], num_groups=norm_num_groups, eps=1e-6) + self.conv_act = nn.SiLU() + + conv_out_channels = 2 * out_channels if double_z else out_channels + self.conv_out = nn.Conv2d(block_out_channels[-1], conv_out_channels, 3, padding=1) + + def forward(self, x): + sample = x + sample = self.conv_in(sample) + + # down + for down_block in self.down_blocks: + sample = down_block(sample) + + # middle + sample = self.mid_block(sample) + + # post-process + sample = self.conv_norm_out(sample) + sample = self.conv_act(sample) + sample = self.conv_out(sample) + + return sample + + +class Decoder(nn.Module): + def __init__( + self, + in_channels=3, + out_channels=3, + up_block_types=("UpDecoderBlock2D",), + block_out_channels=(64,), + layers_per_block=2, + norm_num_groups=32, + act_fn="silu", + ): + super().__init__() + self.layers_per_block = layers_per_block + + self.conv_in = nn.Conv2d(in_channels, block_out_channels[-1], kernel_size=3, stride=1, padding=1) + + self.mid_block = None + self.up_blocks = nn.ModuleList([]) + + # mid + self.mid_block = UNetMidBlock2D( + in_channels=block_out_channels[-1], + resnet_eps=1e-6, + resnet_act_fn=act_fn, + output_scale_factor=1, + resnet_time_scale_shift="default", + attn_num_head_channels=None, + resnet_groups=norm_num_groups, + temb_channels=None, + ) + + # up + reversed_block_out_channels = list(reversed(block_out_channels)) + output_channel = reversed_block_out_channels[0] + for i, up_block_type in enumerate(up_block_types): + prev_output_channel = output_channel + output_channel = reversed_block_out_channels[i+1] + + is_final_block = False # i == len(block_out_channels) - 1 + + up_block = get_up_block( + up_block_type, + num_layers=self.layers_per_block + 1, + in_channels=prev_output_channel, + out_channels=output_channel, + prev_output_channel=None, + add_upsample=not is_final_block, + resnet_eps=1e-6, + resnet_act_fn=act_fn, + 
resnet_groups=norm_num_groups, + attn_num_head_channels=None, + temb_channels=None, + ) + self.up_blocks.append(up_block) + prev_output_channel = output_channel + + # out + self.conv_norm_out = nn.GroupNorm(num_channels=block_out_channels[0], num_groups=norm_num_groups, eps=1e-6) + self.conv_act = nn.SiLU() + self.conv_out = nn.Conv2d(block_out_channels[0], out_channels, 3, padding=1) + + def forward(self, z): + sample = z + sample = self.conv_in(sample) + + # middle + sample = self.mid_block(sample) + + # up + for up_block in self.up_blocks: + sample = up_block(sample) + + # post-process + sample = self.conv_norm_out(sample) + sample = self.conv_act(sample) + sample = self.conv_out(sample) + + return sample + + +class VectorQuantizer(nn.Module): + """ + Improved version over VectorQuantizer, can be used as a drop-in replacement. Mostly avoids costly matrix + multiplications and allows for post-hoc remapping of indices. + """ + + # NOTE: due to a bug the beta term was applied to the wrong term. for + # backwards compatibility we use the buggy version by default, but you can + # specify legacy=False to fix it. + def __init__(self, n_e, e_dim, beta, remap=None, unknown_index="random", sane_index_shape=False, legacy=False): + super().__init__() + self.n_e = n_e + self.e_dim = e_dim + self.beta = beta + self.legacy = legacy + + self.embedding = nn.Embedding(self.n_e, self.e_dim) + self.embedding.weight.data.uniform_(-1.0 / self.n_e, 1.0 / self.n_e) + + self.remap = remap + if self.remap is not None: + self.register_buffer("used", torch.tensor(np.load(self.remap))) + self.re_embed = self.used.shape[0] + self.unknown_index = unknown_index # "random" or "extra" or integer + if self.unknown_index == "extra": + self.unknown_index = self.re_embed + self.re_embed = self.re_embed + 1 + print( + f"Remapping {self.n_e} indices to {self.re_embed} indices. " + f"Using {self.unknown_index} for unknown indices." 
+ ) + else: + self.re_embed = n_e + + self.sane_index_shape = sane_index_shape + + def remap_to_used(self, inds): + ishape = inds.shape + assert len(ishape) > 1 + inds = inds.reshape(ishape[0], -1) + used = self.used.to(inds) + match = (inds[:, :, None] == used[None, None, ...]).long() + new = match.argmax(-1) + unknown = match.sum(2) < 1 + if self.unknown_index == "random": + new[unknown] = torch.randint(0, self.re_embed, size=new[unknown].shape).to(device=new.device) + else: + new[unknown] = self.unknown_index + return new.reshape(ishape) + + def unmap_to_all(self, inds): + ishape = inds.shape + assert len(ishape) > 1 + inds = inds.reshape(ishape[0], -1) + used = self.used.to(inds) + if self.re_embed > self.used.shape[0]: # extra token + inds[inds >= self.used.shape[0]] = 0 # simply set to zero + back = torch.gather(used[None, :][inds.shape[0] * [0], :], 1, inds) + return back.reshape(ishape) + + def forward(self, z): + # reshape z -> (batch, height, width, channel) and flatten + z = z.permute(0, 2, 3, 1).contiguous() + z_flattened = z.view(-1, self.e_dim) + # distances from z to embeddings e_j (z - e)^2 = z^2 + e^2 - 2 e * z + + d = ( + torch.sum(z_flattened**2, dim=1, keepdim=True) + + torch.sum(self.embedding.weight**2, dim=1) + - 2 * torch.einsum("bd,dn->bn", z_flattened, self.embedding.weight.t()) + ) + + min_encoding_indices = torch.argmin(d, dim=1) + z_q = self.embedding(min_encoding_indices).view(z.shape) + perplexity = None + min_encodings = None + + # compute loss for embedding + if not self.legacy: + loss = self.beta * torch.mean((z_q.detach() - z) ** 2) + torch.mean((z_q - z.detach()) ** 2) + else: + loss = torch.mean((z_q.detach() - z) ** 2) + self.beta * torch.mean((z_q - z.detach()) ** 2) + + # preserve gradients + z_q = z + (z_q - z).detach() + + # reshape back to match original input shape + z_q = z_q.permute(0, 3, 1, 2).contiguous() + + if self.remap is not None: + min_encoding_indices = min_encoding_indices.reshape(z.shape[0], -1) # add batch axis + min_encoding_indices = self.remap_to_used(min_encoding_indices) + min_encoding_indices = min_encoding_indices.reshape(-1, 1) # flatten + + if self.sane_index_shape: + min_encoding_indices = min_encoding_indices.reshape(z_q.shape[0], z_q.shape[2], z_q.shape[3]) + + return z_q, loss, (perplexity, min_encodings, min_encoding_indices) + + def get_codebook_entry(self, indices, shape): + # shape specifying (batch, height, width, channel) + if self.remap is not None: + indices = indices.reshape(shape[0], -1) # add batch axis + indices = self.unmap_to_all(indices) + indices = indices.reshape(-1) # flatten again + + # get quantized latent vectors + z_q = self.embedding(indices) + + if shape is not None: + z_q = z_q.view(shape) + # reshape back to match original input shape + z_q = z_q.permute(0, 3, 1, 2).contiguous() + + return z_q + + +class DiagonalGaussianDistribution(object): + def __init__(self, parameters, deterministic=False): + self.batch_size = parameters.shape[0] + self.parameters = parameters + self.mean, self.logvar = torch.chunk(parameters, 2, dim=1) + # self.logvar = torch.clamp(self.logvar, -30.0, 20.0) + self.deterministic = deterministic + self.std = torch.exp(0.5 * self.logvar) + self.var = torch.exp(self.logvar) + if self.deterministic: + self.var = self.std = torch.zeros_like(self.mean).to(device=self.parameters.device) + + def sample(self, generator: Optional[torch.Generator] = None) -> torch.FloatTensor: + device = self.parameters.device + sample_device = "cpu" if device.type == "mps" else device + sample = 
torch.randn(self.mean.shape, generator=generator, device=sample_device).to(device) + x = self.mean + self.std * sample + return x + + def kl(self, other=None): + if self.deterministic: + return torch.Tensor([0.0]) + else: + if other is None: + return 0.5 * torch.sum(torch.pow(self.mean, 2) + self.var - 1.0 - self.logvar)/self.batch_size + else: + return 0.5 * torch.sum( + torch.pow(self.mean - other.mean, 2) / other.var + + self.var / other.var + - 1.0 + - self.logvar + + other.logvar, + )/self.batch_size + + # q_z_x = Normal(self.mean, self.logvar.mul(.5).exp()) + # p_z = Normal(torch.zeros_like(self.mean), torch.ones_like(self.logvar)) + # kl_div = kl_divergence(q_z_x, p_z).sum(1).mean() + # return kl_div + + def nll(self, sample, dims=[1, 2, 3]): + if self.deterministic: + return torch.Tensor([0.0]) + logtwopi = np.log(2.0 * np.pi) + return 0.5 * torch.sum(logtwopi + self.logvar + torch.pow(sample - self.mean, 2) / self.var, dim=dims) + + def mode(self): + return self.mean + + +class VQModel(nn.Module): + r"""VQ-VAE model from the paper Neural Discrete Representation Learning by Aaron van den Oord, Oriol Vinyals and Koray + Kavukcuoglu. + + This model inherits from [`ModelMixin`]. Check the superclass documentation for the generic methods the library + implements for all the model (such as downloading or saving, etc.) + + Parameters: + in_channels (int, *optional*, defaults to 3): Number of channels in the input image. + out_channels (int, *optional*, defaults to 3): Number of channels in the output. + down_block_types (`Tuple[str]`, *optional*, defaults to : + obj:`("DownEncoderBlock2D",)`): Tuple of downsample block types. + up_block_types (`Tuple[str]`, *optional*, defaults to : + obj:`("UpDecoderBlock2D",)`): Tuple of upsample block types. + block_out_channels (`Tuple[int]`, *optional*, defaults to : + obj:`(64,)`): Tuple of block output channels. + act_fn (`str`, *optional*, defaults to `"silu"`): The activation function to use. + latent_channels (`int`, *optional*, defaults to `3`): Number of channels in the latent space. + sample_size (`int`, *optional*, defaults to `32`): TODO + num_vq_embeddings (`int`, *optional*, defaults to `256`): Number of codebook vectors in the VQ-VAE. 
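+ layers_per_block (`int`, *optional*, defaults to `1`): Number of layers per encoder/decoder block.
+ norm_num_groups (`int`, *optional*, defaults to `32`): Number of groups used by the GroupNorm layers.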
+ """ + + + def __init__( + self, + in_channels: int = 3, + out_channels: int = 3, + down_block_types: Tuple[str] = ("DownEncoderBlock2D", "DownEncoderBlock2D", "DownEncoderBlock2D"), + up_block_types: Tuple[str] = ("UpDecoderBlock2D", "UpDecoderBlock2D", "UpDecoderBlock2D"), + block_out_channels: Tuple[int] = (32, 64, 128, 256), + layers_per_block: int = 1, + act_fn: str = "silu", + latent_channels: int = 3, + sample_size: int = 32, + num_vq_embeddings: int = 256, + norm_num_groups: int = 32, + ): + super().__init__() + + # pass init params to Encoder + self.encoder = Encoder( + in_channels=in_channels, + out_channels=latent_channels, + down_block_types=down_block_types, + block_out_channels=block_out_channels, + layers_per_block=layers_per_block, + act_fn=act_fn, + norm_num_groups=norm_num_groups, + double_z=False, + ) + + self.quant_conv = torch.nn.Conv2d(latent_channels, latent_channels, 1) + self.quantize = VectorQuantizer( + num_vq_embeddings, latent_channels, beta=0.25, remap=None, sane_index_shape=False + ) + self.post_quant_conv = torch.nn.Conv2d(latent_channels, latent_channels, 1) + + # pass init params to Decoder + self.decoder = Decoder( + in_channels=latent_channels, + out_channels=out_channels, + up_block_types=up_block_types, + block_out_channels=block_out_channels, + layers_per_block=layers_per_block, + act_fn=act_fn, + norm_num_groups=norm_num_groups, + ) + + # def encode(self, x: torch.FloatTensor): + # z = self.encoder(x) + # z = self.quant_conv(z) + # return z + + def encode(self, x, return_loss=True, force_quantize= True): + z = self.encoder(x) + z = self.quant_conv(z) + + if force_quantize: + z_q, emb_loss, _ = self.quantize(z) + else: + z_q, emb_loss = z, None + + if return_loss: + return z_q, emb_loss + else: + return z_q + + def decode(self, z_q) -> torch.FloatTensor: + z_q = self.post_quant_conv(z_q) + x = self.decoder(z_q) + return x + + # def decode(self, z: torch.FloatTensor, return_loss=True, force_quantize: bool = True) -> torch.FloatTensor: + # if force_quantize: + # z_q, emb_loss, _ = self.quantize(z) + # else: + # z_q, emb_loss = z, None + + # z_q = self.post_quant_conv(z_q) + # x = self.decoder(z_q) + + # if return_loss: + # return x, emb_loss + # else: + # return x + + def forward(self, sample: torch.FloatTensor) -> torch.FloatTensor: + r""" + Args: + sample (`torch.FloatTensor`): Input sample. + """ + # h = self.encode(sample) + h, emb_loss = self.encode(sample) + dec = self.decode(h) + # dec, emb_loss = self.decode(h) + + return dec, emb_loss + + +class AutoencoderKL(nn.Module): + r"""Variational Autoencoder (VAE) model with KL loss from the paper Auto-Encoding Variational Bayes by Diederik P. Kingma + and Max Welling. + + This model inherits from [`ModelMixin`]. Check the superclass documentation for the generic methods the library + implements for all the model (such as downloading or saving, etc.) + + Parameters: + in_channels (int, *optional*, defaults to 3): Number of channels in the input image. + out_channels (int, *optional*, defaults to 3): Number of channels in the output. + down_block_types (`Tuple[str]`, *optional*, defaults to : + obj:`("DownEncoderBlock2D",)`): Tuple of downsample block types. + up_block_types (`Tuple[str]`, *optional*, defaults to : + obj:`("UpDecoderBlock2D",)`): Tuple of upsample block types. + block_out_channels (`Tuple[int]`, *optional*, defaults to : + obj:`(64,)`): Tuple of block output channels. + act_fn (`str`, *optional*, defaults to `"silu"`): The activation function to use. 
+ latent_channels (`int`, *optional*, defaults to `3`): Number of channels in the latent space. + sample_size (`int`, *optional*, defaults to `32`): TODO + """ + + + def __init__( + self, + in_channels: int = 3, + out_channels: int = 3, + down_block_types: Tuple[str] = ("DownEncoderBlock2D", "DownEncoderBlock2D", "DownEncoderBlock2D","DownEncoderBlock2D",), + up_block_types: Tuple[str] = ("UpDecoderBlock2D", "UpDecoderBlock2D", "UpDecoderBlock2D", "UpDecoderBlock2D",), + block_out_channels: Tuple[int] = (32, 64, 128, 128), + layers_per_block: int = 1, + act_fn: str = "silu", + latent_channels: int = 3, + norm_num_groups: int = 32, + sample_size: int = 32, + ): + super().__init__() + + # pass init params to Encoder + self.encoder = Encoder( + in_channels=in_channels, + out_channels=latent_channels, + down_block_types=down_block_types, + block_out_channels=block_out_channels, + layers_per_block=layers_per_block, + act_fn=act_fn, + norm_num_groups=norm_num_groups, + double_z=True, + ) + + # pass init params to Decoder + self.decoder = Decoder( + in_channels=latent_channels, + out_channels=out_channels, + up_block_types=up_block_types, + block_out_channels=block_out_channels, + layers_per_block=layers_per_block, + norm_num_groups=norm_num_groups, + act_fn=act_fn, + ) + + self.quant_conv = torch.nn.Conv2d(2 * latent_channels, 2 * latent_channels, 1) + self.post_quant_conv = torch.nn.Conv2d(latent_channels, latent_channels, 1) + + def encode(self, x: torch.FloatTensor): + h = self.encoder(x) + moments = self.quant_conv(h) + posterior = DiagonalGaussianDistribution(moments) + return posterior + + def decode(self, z: torch.FloatTensor) -> torch.FloatTensor: + z = self.post_quant_conv(z) + dec = self.decoder(z) + return dec + + def forward( + self, + sample: torch.FloatTensor, + sample_posterior: bool = True, + generator: Optional[torch.Generator] = None, + ) -> torch.FloatTensor: + r""" + Args: + sample (`torch.FloatTensor`): Input sample. + sample_posterior (`bool`, *optional*, defaults to `False`): + Whether to sample from the posterior. 
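+ generator (`torch.Generator`, *optional*): Generator used to draw the posterior sample when `sample_posterior` is True.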
+ """ + x = sample + posterior = self.encode(x) + if sample_posterior: + z = posterior.sample(generator=generator) + else: + z = posterior.mode() + kl_loss = posterior.kl() + dec = self.decode(z) + return dec, kl_loss + + + +class VQVAEWrapper(BasicModel): + def __init__( + self, + in_ch: int = 3, + out_ch: int = 3, + down_block_types: Tuple[str] = ("DownEncoderBlock2D", "DownEncoderBlock2D", "DownEncoderBlock2D",), + up_block_types: Tuple[str] = ("UpDecoderBlock2D","UpDecoderBlock2D","UpDecoderBlock2D",), + block_out_channels: Tuple[int] = (32, 64, 128, 256, ), + layers_per_block: int = 1, + act_fn: str = "silu", + latent_channels: int = 3, + sample_size: int = 32, + num_vq_embeddings: int = 64, + norm_num_groups: int = 32, + + optimizer=torch.optim.AdamW, + optimizer_kwargs={}, + lr_scheduler=None, + lr_scheduler_kwargs={}, + loss=torch.nn.MSELoss, + loss_kwargs={} + ): + super().__init__(optimizer, optimizer_kwargs, lr_scheduler, lr_scheduler_kwargs, loss, loss_kwargs) + self.model = VQModel(in_ch, out_ch, down_block_types, up_block_types, block_out_channels, + layers_per_block, act_fn, latent_channels, sample_size, num_vq_embeddings, norm_num_groups) + + def forward(self, sample): + return self.model(sample) + + def encode(self, x): + z = self.model.encode(x, return_loss=False) + return z + + def decode(self, z): + x = self.model.decode(z) + return x + + def _step(self, batch: dict, batch_idx: int, state: str, step: int, optimizer_idx:int): + # ------------------------- Get Source/Target --------------------------- + x = batch['source'] + target = x + + # ------------------------- Run Model --------------------------- + pred, vq_loss = self(x) + + # ------------------------- Compute Loss --------------------------- + loss = self.loss_fct(pred, target) + loss += vq_loss + + # --------------------- Compute Metrics ------------------------------- + results = {'loss':loss} + with torch.no_grad(): + results['L2'] = torch.nn.functional.mse_loss(pred, target) + results['L1'] = torch.nn.functional.l1_loss(pred, target) + + # ----------------- Log Scalars ---------------------- + for metric_name, metric_val in results.items(): + self.log(f"{state}/{metric_name}", metric_val, batch_size=x.shape[0], on_step=True, on_epoch=True) + + # ----------------- Save Image ------------------------------ + if self.global_step != 0 and self.global_step % self.trainer.log_every_n_steps == 0: + def norm(x): + return (x-x.min())/(x.max()-x.min()) + + images = [x, pred] + log_step = self.global_step // self.trainer.log_every_n_steps + path_out = Path(self.logger.log_dir)/'images' + path_out.mkdir(parents=True, exist_ok=True) + images = torch.cat([norm(img) for img in images]) + save_image(images, path_out/f'sample_{log_step}.png') + + return loss + +def hinge_d_loss(logits_real, logits_fake): + loss_real = torch.mean(F.relu(1. - logits_real)) + loss_fake = torch.mean(F.relu(1. 
+ logits_fake)) + d_loss = 0.5 * (loss_real + loss_fake) + return d_loss + +def vanilla_d_loss(logits_real, logits_fake): + d_loss = 0.5 * ( + torch.mean(F.softplus(-logits_real)) + + torch.mean(F.softplus(logits_fake))) + return d_loss + +class VQGAN(BasicModel): + def __init__( + self, + in_ch: int = 3, + out_ch: int = 3, + down_block_types: Tuple[str] = ("DownEncoderBlock2D", "DownEncoderBlock2D", "DownEncoderBlock2D",), + up_block_types: Tuple[str] = ("UpDecoderBlock2D","UpDecoderBlock2D","UpDecoderBlock2D",), + block_out_channels: Tuple[int] = (32, 64, 128, 256, ), + layers_per_block: int = 1, + act_fn: str = "silu", + latent_channels: int = 3, + sample_size: int = 32, + num_vq_embeddings: int = 64, + norm_num_groups: int = 32, + + start_gan_train_step = 50000, # NOTE step increase with each optimizer + gan_loss_weight: float = 1.0, # alias discriminator + perceptual_loss_weight: float = 1.0, + embedding_loss_weight: float = 1.0, + + optimizer=torch.optim.AdamW, + optimizer_kwargs={}, + lr_scheduler=None, + lr_scheduler_kwargs={}, + loss=torch.nn.MSELoss, + loss_kwargs={} + ): + super().__init__(optimizer, optimizer_kwargs, lr_scheduler, lr_scheduler_kwargs, loss, loss_kwargs) + self.vqvae = VQModel(in_ch, out_ch, down_block_types, up_block_types, block_out_channels, layers_per_block, act_fn, + latent_channels, sample_size, num_vq_embeddings, norm_num_groups) + self.discriminator = NLayerDiscriminator(in_ch) + # self.perceiver = ... # Currently not supported, would require another trained NN + + self.start_gan_train_step = start_gan_train_step + self.perceptual_loss_weight = perceptual_loss_weight + self.gan_loss_weight = gan_loss_weight + self.embedding_loss_weight = embedding_loss_weight + + def forward(self, x, condition=None): + return self.vqvae(x) + + def encode(self, x): + z = self.vqvae.encode(x, return_loss=False) + return z + + def decode(self, z): + x = self.vqvae.decode(z) + return x + + + def compute_lambda(self, rec_loss, gan_loss, eps=1e-4): + """Computes adaptive weight as proposed in eq. 
7 of https://arxiv.org/abs/2012.09841""" + last_layer = self.vqvae.decoder.conv_out.weight + rec_grads = torch.autograd.grad(rec_loss, last_layer, retain_graph=True)[0] + gan_grads = torch.autograd.grad(gan_loss, last_layer, retain_graph=True)[0] + d_weight = torch.norm(rec_grads) / (torch.norm(gan_grads) + eps) + d_weight = torch.clamp(d_weight, 0.0, 1e4) + return d_weight.detach() + + + + def _step(self, batch: dict, batch_idx: int, state: str, step: int, optimizer_idx:int): + x = batch['source'] + # condition = batch.get('target', None) + + pred, vq_emb_loss = self.vqvae(x) + + if optimizer_idx == 0: + # ------ VAE ------- + vq_img_loss = F.mse_loss(pred, x) + vq_per_loss = 0.0 #self.perceiver(pred, x) + rec_loss = vq_img_loss+self.perceptual_loss_weight*vq_per_loss + + # ------- GAN ----- + if step > self.start_gan_train_step: + gan_loss = -torch.mean(self.discriminator(pred)) + lambda_weight = self.compute_lambda(rec_loss, gan_loss) + gan_loss = gan_loss*lambda_weight + else: + gan_loss = torch.tensor([0.0], requires_grad=True, device=x.device) + + loss = self.gan_loss_weight*gan_loss+rec_loss+self.embedding_loss_weight*vq_emb_loss + + elif optimizer_idx == 1: + if step > self.start_gan_train_step//2: + logits_real = self.discriminator(x.detach()) + logits_fake = self.discriminator(pred.detach()) + loss = hinge_d_loss(logits_real, logits_fake) + else: + loss = torch.tensor([0.0], requires_grad=True, device=x.device) + + # --------------------- Compute Metrics ------------------------------- + results = {'loss':loss.detach(), f'loss_{optimizer_idx}':loss.detach()} + with torch.no_grad(): + results[f'L2'] = torch.nn.functional.mse_loss(pred, x) + results[f'L1'] = torch.nn.functional.l1_loss(pred, x) + + # ----------------- Log Scalars ---------------------- + for metric_name, metric_val in results.items(): + self.log(f"{state}/{metric_name}", metric_val, batch_size=x.shape[0], on_step=True, on_epoch=True) + + # ----------------- Save Image ------------------------------ + if self.global_step != 0 and self.global_step % self.trainer.log_every_n_steps == 0: # NOTE: step 1 (opt1) , step=2 (opt2), step=3 (opt1), ... 
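+ # Save a grid of inputs and reconstructions (each min-max normalised) every `log_every_n_steps` steps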
+ def norm(x): + return (x-x.min())/(x.max()-x.min()) + + images = torch.cat([x, pred]) + log_step = self.global_step // self.trainer.log_every_n_steps + path_out = Path(self.logger.log_dir)/'images' + path_out.mkdir(parents=True, exist_ok=True) + images = torch.stack([norm(img) for img in images]) + save_image(images, path_out/f'sample_{log_step}.png') + + return loss + + def configure_optimizers(self): + opt_vae = self.optimizer(self.vqvae.parameters(), **self.optimizer_kwargs) + opt_disc = self.optimizer(self.discriminator.parameters(), **self.optimizer_kwargs) + if self.lr_scheduler is not None: + scheduler = [ + { + 'scheduler': self.lr_scheduler(opt_vae, **self.lr_scheduler_kwargs), + 'interval': 'step', + 'frequency': 1 + }, + { + 'scheduler': self.lr_scheduler(opt_disc, **self.lr_scheduler_kwargs), + 'interval': 'step', + 'frequency': 1 + }, + ] + else: + scheduler = [] + + return [opt_vae, opt_disc], scheduler + +class VAEWrapper(BasicModel): + def __init__( + self, + in_ch: int = 3, + out_ch: int = 3, + down_block_types: Tuple[str] = ("DownEncoderBlock2D", "DownEncoderBlock2D", "DownEncoderBlock2D"), # "DownEncoderBlock2D", "DownEncoderBlock2D", + up_block_types: Tuple[str] = ("UpDecoderBlock2D", "UpDecoderBlock2D","UpDecoderBlock2D" ), # "UpDecoderBlock2D", "UpDecoderBlock2D", + block_out_channels: Tuple[int] = (32, 64, 128, 256), # 128, 256 + layers_per_block: int = 1, + act_fn: str = "silu", + latent_channels: int = 3, + norm_num_groups: int = 32, + sample_size: int = 32, + + optimizer=torch.optim.AdamW, + optimizer_kwargs={'lr':1e-4, 'weight_decay':1e-3, 'amsgrad':True}, + lr_scheduler=None, + lr_scheduler_kwargs={}, + # loss=torch.nn.MSELoss, # WARNING: No Effect + # loss_kwargs={'reduction': 'mean'} + ): + super().__init__(optimizer, optimizer_kwargs, lr_scheduler, lr_scheduler_kwargs ) # loss, loss_kwargs + self.model = AutoencoderKL(in_ch, out_ch, down_block_types, up_block_types, block_out_channels, + layers_per_block, act_fn, latent_channels, norm_num_groups, sample_size) + + self.logvar = nn.Parameter(torch.zeros(size=())) # Better weighting between KL and MSE, see (https://arxiv.org/abs/1903.05789), also used by Taming-Transfomer/Stable Diffusion + + def forward(self, sample): + return self.model(sample) + + def encode(self, x): + z = self.model.encode(x) # Latent space but not yet mapped to discrete embedding vectors + return z.sample(generator=None) + + def decode(self, z): + x = self.model.decode(z) + return x + + def _step(self, batch: dict, batch_idx: int, state: str, step: int, optimizer_idx:int): + # ------------------------- Get Source/Target --------------------------- + x = batch['source'] + target = x + HALF_LOG_TWO_PI = 0.91893 # log(2pi)/2 + + # ------------------------- Run Model --------------------------- + pred, kl_loss = self(x) + + # ------------------------- Compute Loss --------------------------- + loss = torch.sum( torch.square(pred-target))/x.shape[0] #torch.sum( torch.square((pred-target)/torch.exp(self.logvar))/2 + self.logvar + HALF_LOG_TWO_PI )/x.shape[0] + loss += kl_loss + + # --------------------- Compute Metrics ------------------------------- + results = {'loss':loss.detach()} + with torch.no_grad(): + results['L2'] = torch.nn.functional.mse_loss(pred, target) + results['L1'] = torch.nn.functional.l1_loss(pred, target) + + # ----------------- Log Scalars ---------------------- + for metric_name, metric_val in results.items(): + self.log(f"{state}/{metric_name}", metric_val, batch_size=x.shape[0], on_step=True, on_epoch=True) + + # 
----------------- Save Image ------------------------------ + if self.global_step != 0 and self.global_step % self.trainer.log_every_n_steps == 0: + def norm(x): + return (x-x.min())/(x.max()-x.min()) + + images = torch.cat([x, pred]) + log_step = self.global_step // self.trainer.log_every_n_steps + path_out = Path(self.logger.log_dir)/'images' + path_out.mkdir(parents=True, exist_ok=True) + images = torch.stack([norm(img) for img in images]) + save_image(images, path_out/f'sample_{log_step}.png') + + return loss \ No newline at end of file diff --git a/medical_diffusion/external/stable_diffusion/attention.py b/medical_diffusion/external/stable_diffusion/attention.py new file mode 100755 index 0000000000000000000000000000000000000000..844d73c23e40b8bb9c2392fd270c8da46f9eb1aa --- /dev/null +++ b/medical_diffusion/external/stable_diffusion/attention.py @@ -0,0 +1,261 @@ +from inspect import isfunction +import math +import torch +import torch.nn.functional as F +from torch import nn, einsum +from einops import rearrange, repeat + +from .util_attention import checkpoint + + +def exists(val): + return val is not None + + +def uniq(arr): + return{el: True for el in arr}.keys() + + +def default(val, d): + if exists(val): + return val + return d() if isfunction(d) else d + + +def max_neg_value(t): + return -torch.finfo(t.dtype).max + + +def init_(tensor): + dim = tensor.shape[-1] + std = 1 / math.sqrt(dim) + tensor.uniform_(-std, std) + return tensor + + +# feedforward +class GEGLU(nn.Module): + def __init__(self, dim_in, dim_out): + super().__init__() + self.proj = nn.Linear(dim_in, dim_out * 2) + + def forward(self, x): + x, gate = self.proj(x).chunk(2, dim=-1) + return x * F.gelu(gate) + + +class FeedForward(nn.Module): + def __init__(self, dim, dim_out=None, mult=4, glu=False, dropout=0.): + super().__init__() + inner_dim = int(dim * mult) + dim_out = default(dim_out, dim) + project_in = nn.Sequential( + nn.Linear(dim, inner_dim), + nn.GELU() + ) if not glu else GEGLU(dim, inner_dim) + + self.net = nn.Sequential( + project_in, + nn.Dropout(dropout), + nn.Linear(inner_dim, dim_out) + ) + + def forward(self, x): + return self.net(x) + + +def zero_module(module): + """ + Zero out the parameters of a module and return it. 
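+ Used below for the SpatialTransformer's output projection so that, at initialisation, the module reduces to an identity (pure residual) mapping.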
+ """ + for p in module.parameters(): + p.detach().zero_() + return module + + +def Normalize(in_channels): + return torch.nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True) + + +class LinearAttention(nn.Module): + def __init__(self, dim, heads=4, dim_head=32): + super().__init__() + self.heads = heads + hidden_dim = dim_head * heads + self.to_qkv = nn.Conv2d(dim, hidden_dim * 3, 1, bias = False) + self.to_out = nn.Conv2d(hidden_dim, dim, 1) + + def forward(self, x): + b, c, h, w = x.shape + qkv = self.to_qkv(x) + q, k, v = rearrange(qkv, 'b (qkv heads c) h w -> qkv b heads c (h w)', heads = self.heads, qkv=3) + k = k.softmax(dim=-1) + context = torch.einsum('bhdn,bhen->bhde', k, v) + out = torch.einsum('bhde,bhdn->bhen', context, q) + out = rearrange(out, 'b heads c (h w) -> b (heads c) h w', heads=self.heads, h=h, w=w) + return self.to_out(out) + + +class SpatialSelfAttention(nn.Module): + def __init__(self, in_channels): + super().__init__() + self.in_channels = in_channels + + self.norm = Normalize(in_channels) + self.q = torch.nn.Conv2d(in_channels, + in_channels, + kernel_size=1, + stride=1, + padding=0) + self.k = torch.nn.Conv2d(in_channels, + in_channels, + kernel_size=1, + stride=1, + padding=0) + self.v = torch.nn.Conv2d(in_channels, + in_channels, + kernel_size=1, + stride=1, + padding=0) + self.proj_out = torch.nn.Conv2d(in_channels, + in_channels, + kernel_size=1, + stride=1, + padding=0) + + def forward(self, x): + h_ = x + h_ = self.norm(h_) + q = self.q(h_) + k = self.k(h_) + v = self.v(h_) + + # compute attention + b,c,h,w = q.shape + q = rearrange(q, 'b c h w -> b (h w) c') + k = rearrange(k, 'b c h w -> b c (h w)') + w_ = torch.einsum('bij,bjk->bik', q, k) + + w_ = w_ * (int(c)**(-0.5)) + w_ = torch.nn.functional.softmax(w_, dim=2) + + # attend to values + v = rearrange(v, 'b c h w -> b c (h w)') + w_ = rearrange(w_, 'b i j -> b j i') + h_ = torch.einsum('bij,bjk->bik', v, w_) + h_ = rearrange(h_, 'b c (h w) -> b c h w', h=h) + h_ = self.proj_out(h_) + + return x+h_ + + +class CrossAttention(nn.Module): + def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0.): + super().__init__() + inner_dim = dim_head * heads + context_dim = default(context_dim, query_dim) + + self.scale = dim_head ** -0.5 + self.heads = heads + + self.to_q = nn.Linear(query_dim, inner_dim, bias=False) + self.to_k = nn.Linear(context_dim, inner_dim, bias=False) + self.to_v = nn.Linear(context_dim, inner_dim, bias=False) + + self.to_out = nn.Sequential( + nn.Linear(inner_dim, query_dim), + nn.Dropout(dropout) + ) + + def forward(self, x, context=None, mask=None): + h = self.heads + + q = self.to_q(x) + context = default(context, x) + k = self.to_k(context) + v = self.to_v(context) + + q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h), (q, k, v)) + + sim = einsum('b i d, b j d -> b i j', q, k) * self.scale + + if exists(mask): + mask = rearrange(mask, 'b ... 
-> b (...)') + max_neg_value = -torch.finfo(sim.dtype).max + mask = repeat(mask, 'b j -> (b h) () j', h=h) + sim.masked_fill_(~mask, max_neg_value) + + # attention, what we cannot get enough of + attn = sim.softmax(dim=-1) + + out = einsum('b i j, b j d -> b i d', attn, v) + out = rearrange(out, '(b h) n d -> b n (h d)', h=h) + return self.to_out(out) + + +class BasicTransformerBlock(nn.Module): + def __init__(self, dim, n_heads, d_head, dropout=0., context_dim=None, gated_ff=True, checkpoint=True): + super().__init__() + self.attn1 = CrossAttention(query_dim=dim, heads=n_heads, dim_head=d_head, dropout=dropout) # is a self-attention + self.ff = FeedForward(dim, dropout=dropout, glu=gated_ff) + self.attn2 = CrossAttention(query_dim=dim, context_dim=context_dim, + heads=n_heads, dim_head=d_head, dropout=dropout) # is self-attn if context is none + self.norm1 = nn.LayerNorm(dim) + self.norm2 = nn.LayerNorm(dim) + self.norm3 = nn.LayerNorm(dim) + self.checkpoint = checkpoint + + def forward(self, x, context=None): + return checkpoint(self._forward, (x, context), self.parameters(), self.checkpoint) + + def _forward(self, x, context=None): + x = self.attn1(self.norm1(x)) + x + x = self.attn2(self.norm2(x), context=context) + x + x = self.ff(self.norm3(x)) + x + return x + + +class SpatialTransformer(nn.Module): + """ + Transformer block for image-like data. + First, project the input (aka embedding) + and reshape to b, t, d. + Then apply standard transformer action. + Finally, reshape to image + """ + def __init__(self, in_channels, n_heads, d_head, + depth=1, dropout=0., context_dim=None): + super().__init__() + self.in_channels = in_channels + inner_dim = n_heads * d_head + self.norm = Normalize(in_channels) + + self.proj_in = nn.Conv2d(in_channels, + inner_dim, + kernel_size=1, + stride=1, + padding=0) + + self.transformer_blocks = nn.ModuleList( + [BasicTransformerBlock(inner_dim, n_heads, d_head, dropout=dropout, context_dim=context_dim) + for d in range(depth)] + ) + + self.proj_out = zero_module(nn.Conv2d(inner_dim, + in_channels, + kernel_size=1, + stride=1, + padding=0)) + + def forward(self, x, context=None): + # note: if no context is given, cross-attention defaults to self-attention + b, c, h, w = x.shape + x_in = x + x = self.norm(x) + x = self.proj_in(x) + x = rearrange(x, 'b c h w -> b (h w) c') + for block in self.transformer_blocks: + x = block(x, context=context) + x = rearrange(x, 'b (h w) c -> b c h w', h=h, w=w) + x = self.proj_out(x) + return x + x_in \ No newline at end of file diff --git a/medical_diffusion/external/stable_diffusion/lr_schedulers.py b/medical_diffusion/external/stable_diffusion/lr_schedulers.py new file mode 100755 index 0000000000000000000000000000000000000000..32ef2e41ce5b2462e2d022795257ebdb3c95e5bb --- /dev/null +++ b/medical_diffusion/external/stable_diffusion/lr_schedulers.py @@ -0,0 +1,33 @@ +import torch + +class LambdaLinearScheduler: + def __init__(self, warm_up_steps=[10000,], f_min=[1.0,], f_max=[1.0,], f_start=[1.e-6], cycle_lengths=[10000000000000], verbosity_interval=0): + assert len(warm_up_steps) == len(f_min) == len(f_max) == len(f_start) == len(cycle_lengths) + self.lr_warm_up_steps = warm_up_steps + self.f_start = f_start + self.f_min = f_min + self.f_max = f_max + self.cycle_lengths = cycle_lengths + self.cum_cycles = torch.cumsum(torch.tensor([0] + list(self.cycle_lengths)), 0) + self.last_f = 0. 
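+ # `last_f` caches the most recent multiplier returned by `schedule`.
+ # Usage sketch (assuming an existing `optimizer`):
+ #   torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=LambdaLinearScheduler().schedule)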
+ self.verbosity_interval = verbosity_interval + + def find_in_interval(self, n): + interval = 0 + for cl in self.cum_cycles[1:]: + if n <= cl: + return interval + interval += 1 + + def schedule(self, n, **kwargs): + cycle = self.find_in_interval(n) + n = n - self.cum_cycles[cycle] + + if n < self.lr_warm_up_steps[cycle]: + f = (self.f_max[cycle] - self.f_start[cycle]) / self.lr_warm_up_steps[cycle] * n + self.f_start[cycle] + self.last_f = f + return f + else: + f = self.f_min[cycle] + (self.f_max[cycle] - self.f_min[cycle]) * (self.cycle_lengths[cycle] - n) / (self.cycle_lengths[cycle]) + self.last_f = f + return f \ No newline at end of file diff --git a/medical_diffusion/external/stable_diffusion/unet_openai.py b/medical_diffusion/external/stable_diffusion/unet_openai.py new file mode 100755 index 0000000000000000000000000000000000000000..2cd4ee7c72d805c02add83b1632f9f3b09d44108 --- /dev/null +++ b/medical_diffusion/external/stable_diffusion/unet_openai.py @@ -0,0 +1,962 @@ +from abc import abstractmethod +from functools import partial +import math +from typing import Iterable + +import numpy as np +import torch as th +import torch.nn as nn +import torch.nn.functional as F + +from .util import ( + checkpoint, + conv_nd, + linear, + avg_pool_nd, + zero_module, + normalization, + timestep_embedding, +) +from .attention import SpatialTransformer + + +# dummy replace +def convert_module_to_f16(x): + pass + +def convert_module_to_f32(x): + pass + + +## go +class AttentionPool2d(nn.Module): + """ + Adapted from CLIP: https://github.com/openai/CLIP/blob/main/clip/model.py + """ + + def __init__( + self, + spacial_dim: int, + embed_dim: int, + num_heads_channels: int, + output_dim: int = None, + ): + super().__init__() + self.positional_embedding = nn.Parameter(th.randn(embed_dim, spacial_dim ** 2 + 1) / embed_dim ** 0.5) + self.qkv_proj = conv_nd(1, embed_dim, 3 * embed_dim, 1) + self.c_proj = conv_nd(1, embed_dim, output_dim or embed_dim, 1) + self.num_heads = embed_dim // num_heads_channels + self.attention = QKVAttention(self.num_heads) + + def forward(self, x): + b, c, *_spatial = x.shape + x = x.reshape(b, c, -1) # NC(HW) + x = th.cat([x.mean(dim=-1, keepdim=True), x], dim=-1) # NC(HW+1) + x = x + self.positional_embedding[None, :, :].to(x.dtype) # NC(HW+1) + x = self.qkv_proj(x) + x = self.attention(x) + x = self.c_proj(x) + return x[:, :, 0] + + +class TimestepBlock(nn.Module): + """ + Any module where forward() takes timestep embeddings as a second argument. + """ + + @abstractmethod + def forward(self, x, emb): + """ + Apply the module to `x` given `emb` timestep embeddings. + """ + + +class TimestepEmbedSequential(nn.Sequential, TimestepBlock): + """ + A sequential module that passes timestep embeddings to the children that + support it as an extra input. + """ + + def forward(self, x, emb, context=None): + for layer in self: + if isinstance(layer, TimestepBlock): + x = layer(x, emb) + elif isinstance(layer, SpatialTransformer): + x = layer(x, context) + else: + x = layer(x) + return x + + +class Upsample(nn.Module): + """ + An upsampling layer with an optional convolution. + :param channels: channels in the inputs and outputs. + :param use_conv: a bool determining if a convolution is applied. + :param dims: determines if the signal is 1D, 2D, or 3D. If 3D, then + upsampling occurs in the inner-two dimensions. 
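+ :param out_channels: if specified, the number of output channels (defaults to `channels`).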
+ """ + + def __init__(self, channels, use_conv, dims=2, out_channels=None, padding=1): + super().__init__() + self.channels = channels + self.out_channels = out_channels or channels + self.use_conv = use_conv + self.dims = dims + if use_conv: + self.conv = conv_nd(dims, self.channels, self.out_channels, 3, padding=padding) + + def forward(self, x): + assert x.shape[1] == self.channels + if self.dims == 3: + x = F.interpolate( + x, (x.shape[2], x.shape[3] * 2, x.shape[4] * 2), mode="nearest" + ) + else: + x = F.interpolate(x, scale_factor=2, mode="nearest") + if self.use_conv: + x = self.conv(x) + return x + +class TransposedUpsample(nn.Module): + 'Learned 2x upsampling without padding' + def __init__(self, channels, out_channels=None, ks=5): + super().__init__() + self.channels = channels + self.out_channels = out_channels or channels + + self.up = nn.ConvTranspose2d(self.channels,self.out_channels,kernel_size=ks,stride=2) + + def forward(self,x): + return self.up(x) + + +class Downsample(nn.Module): + """ + A downsampling layer with an optional convolution. + :param channels: channels in the inputs and outputs. + :param use_conv: a bool determining if a convolution is applied. + :param dims: determines if the signal is 1D, 2D, or 3D. If 3D, then + downsampling occurs in the inner-two dimensions. + """ + + def __init__(self, channels, use_conv, dims=2, out_channels=None,padding=1): + super().__init__() + self.channels = channels + self.out_channels = out_channels or channels + self.use_conv = use_conv + self.dims = dims + stride = 2 if dims != 3 else (1, 2, 2) + if use_conv: + self.op = conv_nd( + dims, self.channels, self.out_channels, 3, stride=stride, padding=padding + ) + else: + assert self.channels == self.out_channels + self.op = avg_pool_nd(dims, kernel_size=stride, stride=stride) + + def forward(self, x): + assert x.shape[1] == self.channels + return self.op(x) + + +class ResBlock(TimestepBlock): + """ + A residual block that can optionally change the number of channels. + :param channels: the number of input channels. + :param emb_channels: the number of timestep embedding channels. + :param dropout: the rate of dropout. + :param out_channels: if specified, the number of out channels. + :param use_conv: if True and out_channels is specified, use a spatial + convolution instead of a smaller 1x1 convolution to change the + channels in the skip connection. + :param dims: determines if the signal is 1D, 2D, or 3D. + :param use_checkpoint: if True, use gradient checkpointing on this module. + :param up: if True, use this block for upsampling. + :param down: if True, use this block for downsampling. 
+ """ + + def __init__( + self, + channels, + emb_channels, + dropout, + out_channels=None, + use_conv=False, + use_scale_shift_norm=False, + dims=2, + use_checkpoint=False, + up=False, + down=False, + ): + super().__init__() + self.channels = channels + self.emb_channels = emb_channels + self.dropout = dropout + self.out_channels = out_channels or channels + self.use_conv = use_conv + self.use_checkpoint = use_checkpoint + self.use_scale_shift_norm = use_scale_shift_norm + + self.in_layers = nn.Sequential( + normalization(channels), + nn.SiLU(), + conv_nd(dims, channels, self.out_channels, 3, padding=1), + ) + + self.updown = up or down + + if up: + self.h_upd = Upsample(channels, False, dims) + self.x_upd = Upsample(channels, False, dims) + elif down: + self.h_upd = Downsample(channels, False, dims) + self.x_upd = Downsample(channels, False, dims) + else: + self.h_upd = self.x_upd = nn.Identity() + + self.emb_layers = nn.Sequential( + nn.SiLU(), + linear( + emb_channels, + 2 * self.out_channels if use_scale_shift_norm else self.out_channels, + ), + ) + self.out_layers = nn.Sequential( + normalization(self.out_channels), + nn.SiLU(), + nn.Dropout(p=dropout), + zero_module( + conv_nd(dims, self.out_channels, self.out_channels, 3, padding=1) + ), + ) + + if self.out_channels == channels: + self.skip_connection = nn.Identity() + elif use_conv: + self.skip_connection = conv_nd( + dims, channels, self.out_channels, 3, padding=1 + ) + else: + self.skip_connection = conv_nd(dims, channels, self.out_channels, 1) + + def forward(self, x, emb): + """ + Apply the block to a Tensor, conditioned on a timestep embedding. + :param x: an [N x C x ...] Tensor of features. + :param emb: an [N x emb_channels] Tensor of timestep embeddings. + :return: an [N x C x ...] Tensor of outputs. + """ + return checkpoint( + self._forward, (x, emb), self.parameters(), self.use_checkpoint + ) + + + def _forward(self, x, emb): + if self.updown: + in_rest, in_conv = self.in_layers[:-1], self.in_layers[-1] + h = in_rest(x) + h = self.h_upd(h) + x = self.x_upd(x) + h = in_conv(h) + else: + h = self.in_layers(x) + emb_out = self.emb_layers(emb).type(h.dtype) + while len(emb_out.shape) < len(h.shape): + emb_out = emb_out[..., None] + if self.use_scale_shift_norm: + out_norm, out_rest = self.out_layers[0], self.out_layers[1:] + scale, shift = th.chunk(emb_out, 2, dim=1) + h = out_norm(h) * (1 + scale) + shift + h = out_rest(h) + else: + h = h + emb_out + h = self.out_layers(h) + return self.skip_connection(x) + h + + +class AttentionBlock(nn.Module): + """ + An attention block that allows spatial positions to attend to each other. + Originally ported from here, but adapted to the N-d case. + https://github.com/hojonathanho/diffusion/blob/1e0dceb3b3495bbe19116a5e1b3596cd0706c543/diffusion_tf/models/unet.py#L66. 
+ """ + + def __init__( + self, + channels, + num_heads=1, + num_head_channels=-1, + use_checkpoint=False, + use_new_attention_order=False, + ): + super().__init__() + self.channels = channels + if num_head_channels == -1: + self.num_heads = num_heads + else: + assert ( + channels % num_head_channels == 0 + ), f"q,k,v channels {channels} is not divisible by num_head_channels {num_head_channels}" + self.num_heads = channels // num_head_channels + self.use_checkpoint = use_checkpoint + self.norm = normalization(channels) + self.qkv = conv_nd(1, channels, channels * 3, 1) + if use_new_attention_order: + # split qkv before split heads + self.attention = QKVAttention(self.num_heads) + else: + # split heads before split qkv + self.attention = QKVAttentionLegacy(self.num_heads) + + self.proj_out = zero_module(conv_nd(1, channels, channels, 1)) + + def forward(self, x): + return checkpoint(self._forward, (x,), self.parameters(), True) # TODO: check checkpoint usage, is True # TODO: fix the .half call!!! + #return pt_checkpoint(self._forward, x) # pytorch + + def _forward(self, x): + b, c, *spatial = x.shape + x = x.reshape(b, c, -1) + qkv = self.qkv(self.norm(x)) + h = self.attention(qkv) + h = self.proj_out(h) + return (x + h).reshape(b, c, *spatial) + + +def count_flops_attn(model, _x, y): + """ + A counter for the `thop` package to count the operations in an + attention operation. + Meant to be used like: + macs, params = thop.profile( + model, + inputs=(inputs, timestamps), + custom_ops={QKVAttention: QKVAttention.count_flops}, + ) + """ + b, c, *spatial = y[0].shape + num_spatial = int(np.prod(spatial)) + # We perform two matmuls with the same number of ops. + # The first computes the weight matrix, the second computes + # the combination of the value vectors. + matmul_ops = 2 * b * (num_spatial ** 2) * c + model.total_ops += th.DoubleTensor([matmul_ops]) + + +class QKVAttentionLegacy(nn.Module): + """ + A module which performs QKV attention. Matches legacy QKVAttention + input/ouput heads shaping + """ + + def __init__(self, n_heads): + super().__init__() + self.n_heads = n_heads + + def forward(self, qkv): + """ + Apply QKV attention. + :param qkv: an [N x (H * 3 * C) x T] tensor of Qs, Ks, and Vs. + :return: an [N x (H * C) x T] tensor after attention. + """ + bs, width, length = qkv.shape + assert width % (3 * self.n_heads) == 0 + ch = width // (3 * self.n_heads) + q, k, v = qkv.reshape(bs * self.n_heads, ch * 3, length).split(ch, dim=1) + scale = 1 / math.sqrt(math.sqrt(ch)) + weight = th.einsum( + "bct,bcs->bts", q * scale, k * scale + ) # More stable with f16 than dividing afterwards + weight = th.softmax(weight.float(), dim=-1).type(weight.dtype) + a = th.einsum("bts,bcs->bct", weight, v) + return a.reshape(bs, -1, length) + + @staticmethod + def count_flops(model, _x, y): + return count_flops_attn(model, _x, y) + + +class QKVAttention(nn.Module): + """ + A module which performs QKV attention and splits in a different order. + """ + + def __init__(self, n_heads): + super().__init__() + self.n_heads = n_heads + + def forward(self, qkv): + """ + Apply QKV attention. + :param qkv: an [N x (3 * H * C) x T] tensor of Qs, Ks, and Vs. + :return: an [N x (H * C) x T] tensor after attention. 
+ """ + bs, width, length = qkv.shape + assert width % (3 * self.n_heads) == 0 + ch = width // (3 * self.n_heads) + q, k, v = qkv.chunk(3, dim=1) + scale = 1 / math.sqrt(math.sqrt(ch)) + weight = th.einsum( + "bct,bcs->bts", + (q * scale).view(bs * self.n_heads, ch, length), + (k * scale).view(bs * self.n_heads, ch, length), + ) # More stable with f16 than dividing afterwards + weight = th.softmax(weight.float(), dim=-1).type(weight.dtype) + a = th.einsum("bts,bcs->bct", weight, v.reshape(bs * self.n_heads, ch, length)) + return a.reshape(bs, -1, length) + + @staticmethod + def count_flops(model, _x, y): + return count_flops_attn(model, _x, y) + + +class UNetModel(nn.Module): + """ + The full UNet model with attention and timestep embedding. + :param in_channels: channels in the input Tensor. + :param model_channels: base channel count for the model. + :param out_channels: channels in the output Tensor. + :param num_res_blocks: number of residual blocks per downsample. + :param attention_resolutions: a collection of downsample rates at which + attention will take place. May be a set, list, or tuple. + For example, if this contains 4, then at 4x downsampling, attention + will be used. + :param dropout: the dropout probability. + :param channel_mult: channel multiplier for each level of the UNet. + :param conv_resample: if True, use learned convolutions for upsampling and + downsampling. + :param dims: determines if the signal is 1D, 2D, or 3D. + :param num_classes: if specified (as an int), then this model will be + class-conditional with `num_classes` classes. + :param use_checkpoint: use gradient checkpointing to reduce memory usage. + :param num_heads: the number of attention heads in each attention layer. + :param num_heads_channels: if specified, ignore num_heads and instead use + a fixed channel width per attention head. + :param num_heads_upsample: works with num_heads to set a different number + of heads for upsampling. Deprecated. + :param use_scale_shift_norm: use a FiLM-like conditioning mechanism. + :param resblock_updown: use residual blocks for up/downsampling. + :param use_new_attention_order: use a different attention pattern for potentially + increased efficiency. + """ + + def __init__( + self, + image_size=32, + in_channels=4, + model_channels=256, + out_channels=4, + num_res_blocks=2, + attention_resolutions=[4,2,1], + dropout=0, + channel_mult=(1, 2, 4), + conv_resample=True, + dims=2, + num_classes=None, + use_checkpoint=False, + use_fp16=False, + num_heads=8, + num_head_channels=-1, + num_heads_upsample=-1, + use_scale_shift_norm=False, + resblock_updown=False, + use_new_attention_order=False, + use_spatial_transformer=False, # custom transformer support + transformer_depth=1, # custom transformer support + context_dim=None, # custom transformer support + n_embed=None, # custom support for prediction of discrete ids into codebook of first stage vq model + legacy=True, + **kwargs + ): + super().__init__() + if use_spatial_transformer: + assert context_dim is not None, 'Fool!! You forgot to include the dimension of your cross-attention conditioning...' + + if context_dim is not None: + assert use_spatial_transformer, 'Fool!! You forgot to use the spatial transformer for your cross-attention conditioning...' 
+ # from omegaconf.listconfig import ListConfig + # if type(context_dim) == ListConfig: + # context_dim = list(context_dim) + + if num_heads_upsample == -1: + num_heads_upsample = num_heads + + if num_heads == -1: + assert num_head_channels != -1, 'Either num_heads or num_head_channels has to be set' + + if num_head_channels == -1: + assert num_heads != -1, 'Either num_heads or num_head_channels has to be set' + + self.image_size = image_size + self.in_channels = in_channels + self.model_channels = model_channels + self.out_channels = out_channels + self.num_res_blocks = num_res_blocks + self.attention_resolutions = attention_resolutions + self.dropout = dropout + self.channel_mult = channel_mult + self.conv_resample = conv_resample + self.num_classes = num_classes + self.use_checkpoint = use_checkpoint + self.dtype = th.float16 if use_fp16 else th.float32 + self.num_heads = num_heads + self.num_head_channels = num_head_channels + self.num_heads_upsample = num_heads_upsample + self.predict_codebook_ids = n_embed is not None + + time_embed_dim = model_channels * 4 + self.time_embed = nn.Sequential( + linear(model_channels, time_embed_dim), + nn.SiLU(), + linear(time_embed_dim, time_embed_dim), + ) + + if self.num_classes is not None: + self.label_emb = nn.Embedding(num_classes, time_embed_dim) + + self.input_blocks = nn.ModuleList( + [ + TimestepEmbedSequential( + conv_nd(dims, in_channels, model_channels, 3, padding=1) + ) + ] + ) + self._feature_size = model_channels + input_block_chans = [model_channels] + ch = model_channels + ds = 1 + for level, mult in enumerate(channel_mult): + for _ in range(num_res_blocks): + layers = [ + ResBlock( + ch, + time_embed_dim, + dropout, + out_channels=mult * model_channels, + dims=dims, + use_checkpoint=use_checkpoint, + use_scale_shift_norm=use_scale_shift_norm, + ) + ] + ch = mult * model_channels + if ds in attention_resolutions: + if num_head_channels == -1: + dim_head = ch // num_heads + else: + num_heads = ch // num_head_channels + dim_head = num_head_channels + if legacy: + #num_heads = 1 + dim_head = ch // num_heads if use_spatial_transformer else num_head_channels + layers.append( + AttentionBlock( + ch, + use_checkpoint=use_checkpoint, + num_heads=num_heads, + num_head_channels=dim_head, + use_new_attention_order=use_new_attention_order, + ) if not use_spatial_transformer else SpatialTransformer( + ch, num_heads, dim_head, depth=transformer_depth, context_dim=context_dim + ) + ) + self.input_blocks.append(TimestepEmbedSequential(*layers)) + self._feature_size += ch + input_block_chans.append(ch) + if level != len(channel_mult) - 1: + out_ch = ch + self.input_blocks.append( + TimestepEmbedSequential( + ResBlock( + ch, + time_embed_dim, + dropout, + out_channels=out_ch, + dims=dims, + use_checkpoint=use_checkpoint, + use_scale_shift_norm=use_scale_shift_norm, + down=True, + ) + if resblock_updown + else Downsample( + ch, conv_resample, dims=dims, out_channels=out_ch + ) + ) + ) + ch = out_ch + input_block_chans.append(ch) + ds *= 2 + self._feature_size += ch + + if num_head_channels == -1: + dim_head = ch // num_heads + else: + num_heads = ch // num_head_channels + dim_head = num_head_channels + if legacy: + #num_heads = 1 + dim_head = ch // num_heads if use_spatial_transformer else num_head_channels + self.middle_block = TimestepEmbedSequential( + ResBlock( + ch, + time_embed_dim, + dropout, + dims=dims, + use_checkpoint=use_checkpoint, + use_scale_shift_norm=use_scale_shift_norm, + ), + AttentionBlock( + ch, + use_checkpoint=use_checkpoint, + 
num_heads=num_heads, + num_head_channels=dim_head, + use_new_attention_order=use_new_attention_order, + ) if not use_spatial_transformer else SpatialTransformer( + ch, num_heads, dim_head, depth=transformer_depth, context_dim=context_dim + ), + ResBlock( + ch, + time_embed_dim, + dropout, + dims=dims, + use_checkpoint=use_checkpoint, + use_scale_shift_norm=use_scale_shift_norm, + ), + ) + self._feature_size += ch + + self.output_blocks = nn.ModuleList([]) + for level, mult in list(enumerate(channel_mult))[::-1]: + for i in range(num_res_blocks + 1): + ich = input_block_chans.pop() + layers = [ + ResBlock( + ch + ich, + time_embed_dim, + dropout, + out_channels=model_channels * mult, + dims=dims, + use_checkpoint=use_checkpoint, + use_scale_shift_norm=use_scale_shift_norm, + ) + ] + ch = model_channels * mult + if ds in attention_resolutions: + if num_head_channels == -1: + dim_head = ch // num_heads + else: + num_heads = ch // num_head_channels + dim_head = num_head_channels + if legacy: + #num_heads = 1 + dim_head = ch // num_heads if use_spatial_transformer else num_head_channels + layers.append( + AttentionBlock( + ch, + use_checkpoint=use_checkpoint, + num_heads=num_heads_upsample, + num_head_channels=dim_head, + use_new_attention_order=use_new_attention_order, + ) if not use_spatial_transformer else SpatialTransformer( + ch, num_heads, dim_head, depth=transformer_depth, context_dim=context_dim + ) + ) + if level and i == num_res_blocks: + out_ch = ch + layers.append( + ResBlock( + ch, + time_embed_dim, + dropout, + out_channels=out_ch, + dims=dims, + use_checkpoint=use_checkpoint, + use_scale_shift_norm=use_scale_shift_norm, + up=True, + ) + if resblock_updown + else Upsample(ch, conv_resample, dims=dims, out_channels=out_ch) + ) + ds //= 2 + self.output_blocks.append(TimestepEmbedSequential(*layers)) + self._feature_size += ch + + self.out = nn.Sequential( + normalization(ch), + nn.SiLU(), + zero_module(conv_nd(dims, model_channels, out_channels, 3, padding=1)), + ) + if self.predict_codebook_ids: + self.id_predictor = nn.Sequential( + normalization(ch), + conv_nd(dims, model_channels, n_embed, 1), + #nn.LogSoftmax(dim=1) # change to cross_entropy and produce non-normalized logits + ) + + def convert_to_fp16(self): + """ + Convert the torso of the model to float16. + """ + self.input_blocks.apply(convert_module_to_f16) + self.middle_block.apply(convert_module_to_f16) + self.output_blocks.apply(convert_module_to_f16) + + def convert_to_fp32(self): + """ + Convert the torso of the model to float32. + """ + self.input_blocks.apply(convert_module_to_f32) + self.middle_block.apply(convert_module_to_f32) + self.output_blocks.apply(convert_module_to_f32) + + def forward(self, x, t=None, condition=None, context=None, **kwargs): + """ + Apply the model to an input batch. + :param x: an [N x C x ...] Tensor of inputs. + :param timesteps: a 1-D batch of timesteps. + :param context: conditioning plugged in via crossattn + :param y: an [N] Tensor of labels, if class-conditional. + :return: an [N x C x ...] Tensor of outputs. 
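+ Note: in this signature the timesteps are passed as `t` and the class labels as `condition`.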
+ """ + condition = None # --------------------- WANRING ONLY for Testing --------------------- + assert (condition is not None) == ( + self.num_classes is not None + ), "must specify y if and only if the model is class-conditional" + hs = [] + t_emb = timestep_embedding(t, self.model_channels, repeat_only=False) + emb = self.time_embed(t_emb) + + if self.num_classes is not None: + assert condition.shape == (x.shape[0],) + emb = emb + self.label_emb(condition) + + h = x.type(self.dtype) + for module in self.input_blocks: + h = module(h, emb, context) + hs.append(h) + h = self.middle_block(h, emb, context) + for module in self.output_blocks: + h = th.cat([h, hs.pop()], dim=1) + h = module(h, emb, context) + h = h.type(x.dtype) + if self.predict_codebook_ids: + return self.id_predictor(h) + else: + return self.out(h), [] + + +class EncoderUNetModel(nn.Module): + """ + The half UNet model with attention and timestep embedding. + For usage, see UNet. + """ + + def __init__( + self, + image_size, + in_channels, + model_channels, + out_channels, + num_res_blocks, + attention_resolutions, + dropout=0, + channel_mult=(1, 2, 4, 8), + conv_resample=True, + dims=2, + use_checkpoint=False, + use_fp16=False, + num_heads=1, + num_head_channels=-1, + num_heads_upsample=-1, + use_scale_shift_norm=False, + resblock_updown=False, + use_new_attention_order=False, + pool="adaptive", + *args, + **kwargs + ): + super().__init__() + + if num_heads_upsample == -1: + num_heads_upsample = num_heads + + self.in_channels = in_channels + self.model_channels = model_channels + self.out_channels = out_channels + self.num_res_blocks = num_res_blocks + self.attention_resolutions = attention_resolutions + self.dropout = dropout + self.channel_mult = channel_mult + self.conv_resample = conv_resample + self.use_checkpoint = use_checkpoint + self.dtype = th.float16 if use_fp16 else th.float32 + self.num_heads = num_heads + self.num_head_channels = num_head_channels + self.num_heads_upsample = num_heads_upsample + + time_embed_dim = model_channels * 4 + self.time_embed = nn.Sequential( + linear(model_channels, time_embed_dim), + nn.SiLU(), + linear(time_embed_dim, time_embed_dim), + ) + + self.input_blocks = nn.ModuleList( + [ + TimestepEmbedSequential( + conv_nd(dims, in_channels, model_channels, 3, padding=1) + ) + ] + ) + self._feature_size = model_channels + input_block_chans = [model_channels] + ch = model_channels + ds = 1 + for level, mult in enumerate(channel_mult): + for _ in range(num_res_blocks): + layers = [ + ResBlock( + ch, + time_embed_dim, + dropout, + out_channels=mult * model_channels, + dims=dims, + use_checkpoint=use_checkpoint, + use_scale_shift_norm=use_scale_shift_norm, + ) + ] + ch = mult * model_channels + if ds in attention_resolutions: + layers.append( + AttentionBlock( + ch, + use_checkpoint=use_checkpoint, + num_heads=num_heads, + num_head_channels=num_head_channels, + use_new_attention_order=use_new_attention_order, + ) + ) + self.input_blocks.append(TimestepEmbedSequential(*layers)) + self._feature_size += ch + input_block_chans.append(ch) + if level != len(channel_mult) - 1: + out_ch = ch + self.input_blocks.append( + TimestepEmbedSequential( + ResBlock( + ch, + time_embed_dim, + dropout, + out_channels=out_ch, + dims=dims, + use_checkpoint=use_checkpoint, + use_scale_shift_norm=use_scale_shift_norm, + down=True, + ) + if resblock_updown + else Downsample( + ch, conv_resample, dims=dims, out_channels=out_ch + ) + ) + ) + ch = out_ch + input_block_chans.append(ch) + ds *= 2 + self._feature_size 
+= ch + + self.middle_block = TimestepEmbedSequential( + ResBlock( + ch, + time_embed_dim, + dropout, + dims=dims, + use_checkpoint=use_checkpoint, + use_scale_shift_norm=use_scale_shift_norm, + ), + AttentionBlock( + ch, + use_checkpoint=use_checkpoint, + num_heads=num_heads, + num_head_channels=num_head_channels, + use_new_attention_order=use_new_attention_order, + ), + ResBlock( + ch, + time_embed_dim, + dropout, + dims=dims, + use_checkpoint=use_checkpoint, + use_scale_shift_norm=use_scale_shift_norm, + ), + ) + self._feature_size += ch + self.pool = pool + if pool == "adaptive": + self.out = nn.Sequential( + normalization(ch), + nn.SiLU(), + nn.AdaptiveAvgPool2d((1, 1)), + zero_module(conv_nd(dims, ch, out_channels, 1)), + nn.Flatten(), + ) + elif pool == "attention": + assert num_head_channels != -1 + self.out = nn.Sequential( + normalization(ch), + nn.SiLU(), + AttentionPool2d( + (image_size // ds), ch, num_head_channels, out_channels + ), + ) + elif pool == "spatial": + self.out = nn.Sequential( + nn.Linear(self._feature_size, 2048), + nn.ReLU(), + nn.Linear(2048, self.out_channels), + ) + elif pool == "spatial_v2": + self.out = nn.Sequential( + nn.Linear(self._feature_size, 2048), + normalization(2048), + nn.SiLU(), + nn.Linear(2048, self.out_channels), + ) + else: + raise NotImplementedError(f"Unexpected {pool} pooling") + + def convert_to_fp16(self): + """ + Convert the torso of the model to float16. + """ + self.input_blocks.apply(convert_module_to_f16) + self.middle_block.apply(convert_module_to_f16) + + def convert_to_fp32(self): + """ + Convert the torso of the model to float32. + """ + self.input_blocks.apply(convert_module_to_f32) + self.middle_block.apply(convert_module_to_f32) + + def forward(self, x, timesteps): + """ + Apply the model to an input batch. + :param x: an [N x C x ...] Tensor of inputs. + :param timesteps: a 1-D batch of timesteps. + :return: an [N x K] Tensor of outputs. + """ + emb = self.time_embed(timestep_embedding(timesteps, self.model_channels)) + + results = [] + h = x.type(self.dtype) + for module in self.input_blocks: + h = module(h, emb) + if self.pool.startswith("spatial"): + results.append(h.type(x.dtype).mean(dim=(2, 3))) + h = self.middle_block(h, emb) + if self.pool.startswith("spatial"): + results.append(h.type(x.dtype).mean(dim=(2, 3))) + h = th.cat(results, axis=-1) + return self.out(h) + else: + h = h.type(x.dtype) + return self.out(h) \ No newline at end of file diff --git a/medical_diffusion/external/stable_diffusion/util.py b/medical_diffusion/external/stable_diffusion/util.py new file mode 100755 index 0000000000000000000000000000000000000000..bf545e77c4d27f01ca0816b32c27c13a2d632205 --- /dev/null +++ b/medical_diffusion/external/stable_diffusion/util.py @@ -0,0 +1,284 @@ +# adopted from +# https://github.com/openai/improved-diffusion/blob/main/improved_diffusion/gaussian_diffusion.py +# and +# https://github.com/lucidrains/denoising-diffusion-pytorch/blob/7706bdfc6f527f58d33f84b7b522e61e6e3164b3/denoising_diffusion_pytorch/denoising_diffusion_pytorch.py +# and +# https://github.com/openai/guided-diffusion/blob/0ba878e517b276c45d1195eb29f6f5f72659a05b/guided_diffusion/nn.py +# +# thanks! 
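+# Shared helpers for the diffusion/UNet code: beta and DDIM schedules, gradient checkpointing,
+# sinusoidal timestep embeddings and GroupNorm/conv/pooling factory functions.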
+ + +import os +import math +import torch +import torch.nn as nn +import numpy as np +from einops import repeat + +#--------------- Added ---------------- +import importlib +def get_obj_from_str(string, reload=False): + module, cls = string.rsplit(".", 1) + if reload: + module_imp = importlib.import_module(module) + importlib.reload(module_imp) + return getattr(importlib.import_module(module, package=None), cls) + +def instantiate_from_config(config): + if not "target" in config: + if config == '__is_first_stage__': + return None + elif config == "__is_unconditional__": + return None + raise KeyError("Expected key `target` to instantiate.") + return get_obj_from_str(config["target"])(**config.get("params", dict())) + +#-------------------------------- + +def make_beta_schedule(schedule, n_timestep, linear_start=1e-4, linear_end=2e-2, cosine_s=8e-3): + if schedule == "linear": + betas = ( + torch.linspace(linear_start ** 0.5, linear_end ** 0.5, n_timestep, dtype=torch.float64) ** 2 + ) + + elif schedule == "cosine": + timesteps = ( + torch.arange(n_timestep + 1, dtype=torch.float64) / n_timestep + cosine_s + ) + alphas = timesteps / (1 + cosine_s) * np.pi / 2 + alphas = torch.cos(alphas).pow(2) + alphas = alphas / alphas[0] + betas = 1 - alphas[1:] / alphas[:-1] + betas = np.clip(betas, a_min=0, a_max=0.999) + + elif schedule == "sqrt_linear": + betas = torch.linspace(linear_start, linear_end, n_timestep, dtype=torch.float64) + elif schedule == "sqrt": + betas = torch.linspace(linear_start, linear_end, n_timestep, dtype=torch.float64) ** 0.5 + else: + raise ValueError(f"schedule '{schedule}' unknown.") + return betas.numpy() + + +def make_ddim_timesteps(ddim_discr_method, num_ddim_timesteps, num_ddpm_timesteps, verbose=True): + if ddim_discr_method == 'uniform': + c = num_ddpm_timesteps // num_ddim_timesteps + ddim_timesteps = np.asarray(list(range(0, num_ddpm_timesteps, c))) + elif ddim_discr_method == 'quad': + ddim_timesteps = ((np.linspace(0, np.sqrt(num_ddpm_timesteps * .8), num_ddim_timesteps)) ** 2).astype(int) + else: + raise NotImplementedError(f'There is no ddim discretization method called "{ddim_discr_method}"') + + # assert ddim_timesteps.shape[0] == num_ddim_timesteps + # add one to get the final alpha values right (the ones from first scale to data during sampling) + steps_out = ddim_timesteps + 1 + if verbose: + print(f'Selected timesteps for ddim sampler: {steps_out}') + return steps_out + + +def make_ddim_sampling_parameters(alphacums, ddim_timesteps, eta, verbose=True): + # select alphas for computing the variance schedule + alphas = alphacums[ddim_timesteps] + alphas_prev = np.asarray([alphacums[0]] + alphacums[ddim_timesteps[:-1]].tolist()) + + # according the the formula provided in https://arxiv.org/abs/2010.02502 + sigmas = eta * np.sqrt((1 - alphas_prev) / (1 - alphas) * (1 - alphas / alphas_prev)) + if verbose: + print(f'Selected alphas for ddim sampler: a_t: {alphas}; a_(t-1): {alphas_prev}') + print(f'For the chosen value of eta, which is {eta}, ' + f'this results in the following sigma_t schedule for ddim sampler {sigmas}') + return sigmas, alphas, alphas_prev + + +def betas_for_alpha_bar(num_diffusion_timesteps, alpha_bar, max_beta=0.999): + """ + Create a beta schedule that discretizes the given alpha_t_bar function, + which defines the cumulative product of (1-beta) over time from t = [0,1]. + :param num_diffusion_timesteps: the number of betas to produce. 
+ :param alpha_bar: a lambda that takes an argument t from 0 to 1 and + produces the cumulative product of (1-beta) up to that + part of the diffusion process. + :param max_beta: the maximum beta to use; use values lower than 1 to + prevent singularities. + """ + betas = [] + for i in range(num_diffusion_timesteps): + t1 = i / num_diffusion_timesteps + t2 = (i + 1) / num_diffusion_timesteps + betas.append(min(1 - alpha_bar(t2) / alpha_bar(t1), max_beta)) + return np.array(betas) + + +def extract_into_tensor(a, t, x_shape): + b, *_ = t.shape + out = a.gather(-1, t) + return out.reshape(b, *((1,) * (len(x_shape) - 1))) + + +def checkpoint(func, inputs, params, flag): + """ + Evaluate a function without caching intermediate activations, allowing for + reduced memory at the expense of extra compute in the backward pass. + :param func: the function to evaluate. + :param inputs: the argument sequence to pass to `func`. + :param params: a sequence of parameters `func` depends on but does not + explicitly take as arguments. + :param flag: if False, disable gradient checkpointing. + """ + if flag: + args = tuple(inputs) + tuple(params) + return CheckpointFunction.apply(func, len(inputs), *args) + else: + return func(*inputs) + + +class CheckpointFunction(torch.autograd.Function): + @staticmethod + def forward(ctx, run_function, length, *args): + ctx.run_function = run_function + ctx.input_tensors = list(args[:length]) + ctx.input_params = list(args[length:]) + + with torch.no_grad(): + output_tensors = ctx.run_function(*ctx.input_tensors) + return output_tensors + + @staticmethod + def backward(ctx, *output_grads): + ctx.input_tensors = [x.detach().requires_grad_(True) for x in ctx.input_tensors] + with torch.enable_grad(): + # Fixes a bug where the first op in run_function modifies the + # Tensor storage in place, which is not allowed for detach()'d + # Tensors. + shallow_copies = [x.view_as(x) for x in ctx.input_tensors] + output_tensors = ctx.run_function(*shallow_copies) + input_grads = torch.autograd.grad( + output_tensors, + ctx.input_tensors + ctx.input_params, + output_grads, + allow_unused=True, + ) + del ctx.input_tensors + del ctx.input_params + del output_tensors + return (None, None) + input_grads + + +def timestep_embedding(timesteps, dim, max_period=10000, repeat_only=False): + """ + Create sinusoidal timestep embeddings. + :param timesteps: a 1-D Tensor of N indices, one per batch element. + These may be fractional. + :param dim: the dimension of the output. + :param max_period: controls the minimum frequency of the embeddings. + :return: an [N x dim] Tensor of positional embeddings. + """ + if not repeat_only: + half = dim // 2 + freqs = torch.exp( + -math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half + ).to(device=timesteps.device) + args = timesteps[:, None].float() * freqs[None] + embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1) + if dim % 2: + embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1) + else: + embedding = repeat(timesteps, 'b -> b d', d=dim) + return embedding + + +def zero_module(module): + """ + Zero out the parameters of a module and return it. + """ + for p in module.parameters(): + p.detach().zero_() + return module + + +def scale_module(module, scale): + """ + Scale the parameters of a module and return it. + """ + for p in module.parameters(): + p.detach().mul_(scale) + return module + + +def mean_flat(tensor): + """ + Take the mean over all non-batch dimensions. 
+ """ + return tensor.mean(dim=list(range(1, len(tensor.shape)))) + + +def normalization(channels): + """ + Make a standard normalization layer. + :param channels: number of input channels. + :return: an nn.Module for normalization. + """ + return GroupNorm32(32, channels) + + +# PyTorch 1.7 has SiLU, but we support PyTorch 1.5. +class SiLU(nn.Module): + def forward(self, x): + return x * torch.sigmoid(x) + + +class GroupNorm32(nn.GroupNorm): + def forward(self, x): + return super().forward(x.float()).type(x.dtype) + +def conv_nd(dims, *args, **kwargs): + """ + Create a 1D, 2D, or 3D convolution module. + """ + if dims == 1: + return nn.Conv1d(*args, **kwargs) + elif dims == 2: + return nn.Conv2d(*args, **kwargs) + elif dims == 3: + return nn.Conv3d(*args, **kwargs) + raise ValueError(f"unsupported dimensions: {dims}") + + +def linear(*args, **kwargs): + """ + Create a linear module. + """ + return nn.Linear(*args, **kwargs) + + +def avg_pool_nd(dims, *args, **kwargs): + """ + Create a 1D, 2D, or 3D average pooling module. + """ + if dims == 1: + return nn.AvgPool1d(*args, **kwargs) + elif dims == 2: + return nn.AvgPool2d(*args, **kwargs) + elif dims == 3: + return nn.AvgPool3d(*args, **kwargs) + raise ValueError(f"unsupported dimensions: {dims}") + + +class HybridConditioner(nn.Module): + + def __init__(self, c_concat_config, c_crossattn_config): + super().__init__() + self.concat_conditioner = instantiate_from_config(c_concat_config) + self.crossattn_conditioner = instantiate_from_config(c_crossattn_config) + + def forward(self, c_concat, c_crossattn): + c_concat = self.concat_conditioner(c_concat) + c_crossattn = self.crossattn_conditioner(c_crossattn) + return {'c_concat': [c_concat], 'c_crossattn': [c_crossattn]} + + +def noise_like(shape, device, repeat=False): + repeat_noise = lambda: torch.randn((1, *shape[1:]), device=device).repeat(shape[0], *((1,) * (len(shape) - 1))) + noise = lambda: torch.randn(shape, device=device) + return repeat_noise() if repeat else noise() \ No newline at end of file diff --git a/medical_diffusion/external/stable_diffusion/util_attention.py b/medical_diffusion/external/stable_diffusion/util_attention.py new file mode 100755 index 0000000000000000000000000000000000000000..dada3a3c45bdd82db4f4b84772a2e4e4abe0ca40 --- /dev/null +++ b/medical_diffusion/external/stable_diffusion/util_attention.py @@ -0,0 +1,56 @@ + +import os +import math +import torch +import torch.nn as nn +import numpy as np +from einops import repeat + +def checkpoint(func, inputs, params, flag): + """ + Evaluate a function without caching intermediate activations, allowing for + reduced memory at the expense of extra compute in the backward pass. + :param func: the function to evaluate. + :param inputs: the argument sequence to pass to `func`. + :param params: a sequence of parameters `func` depends on but does not + explicitly take as arguments. + :param flag: if False, disable gradient checkpointing. 
+ """ + if flag: + args = tuple(inputs) + tuple(params) + return CheckpointFunction.apply(func, len(inputs), *args) + else: + return func(*inputs) + + +class CheckpointFunction(torch.autograd.Function): + @staticmethod + def forward(ctx, run_function, length, *args): + ctx.run_function = run_function + ctx.input_tensors = list(args[:length]) + ctx.input_params = list(args[length:]) + + with torch.no_grad(): + output_tensors = ctx.run_function(*ctx.input_tensors) + return output_tensors + + @staticmethod + def backward(ctx, *output_grads): + ctx.input_tensors = [x.detach().requires_grad_(True) for x in ctx.input_tensors] + with torch.enable_grad(): + # Fixes a bug where the first op in run_function modifies the + # Tensor storage in place, which is not allowed for detach()'d + # Tensors. + shallow_copies = [x.view_as(x) for x in ctx.input_tensors] + output_tensors = ctx.run_function(*shallow_copies) + input_grads = torch.autograd.grad( + output_tensors, + ctx.input_tensors + ctx.input_params, + output_grads, + allow_unused=True, + ) + del ctx.input_tensors + del ctx.input_params + del output_tensors + return (None, None) + input_grads + diff --git a/medical_diffusion/external/unet_lucidrains.py b/medical_diffusion/external/unet_lucidrains.py new file mode 100755 index 0000000000000000000000000000000000000000..7b80507d12cf87e24bbced03c06caa963a45eb43 --- /dev/null +++ b/medical_diffusion/external/unet_lucidrains.py @@ -0,0 +1,332 @@ +from torch import nn, einsum +from einops import rearrange, reduce +import torch +import torch.nn.functional as F +from functools import partial +import math + +# -------------------------------- Embeddings ------------------------------------------------------ +class SinusoidalPosEmb(nn.Module): + def __init__(self, dim): + super().__init__() + self.dim = dim + + def forward(self, x): + device = x.device + half_dim = self.dim // 2 + emb = math.log(10000) / (half_dim - 1) + emb = torch.exp(torch.arange(half_dim, device=device) * -emb) + emb = x[:, None] * emb[None, :] + emb = torch.cat((emb.sin(), emb.cos()), dim=-1) + return emb + +class LearnedSinusoidalPosEmb(nn.Module): + """ following @crowsonkb 's lead with learned sinusoidal pos emb """ + """ https://github.com/crowsonkb/v-diffusion-jax/blob/master/diffusion/models/danbooru_128.py#L8 """ + + def __init__(self, dim): + super().__init__() + assert (dim % 2) == 0 + half_dim = dim // 2 + self.weights = nn.Parameter(torch.randn(half_dim)) + + def forward(self, x): + x = rearrange(x, 'b -> b 1') + freqs = x * rearrange(self.weights, 'd -> 1 d') * 2 * math.pi + fouriered = torch.cat((freqs.sin(), freqs.cos()), dim = -1) + fouriered = torch.cat((x, fouriered), dim = -1) + return fouriered + +# ------------------------------------------------------------------- + +def exists(x): + return x is not None + +def default(val, d): + if exists(val): + return val + return d() if callable(d) else d + +def l2norm(t): + return F.normalize(t, dim = -1) + +class Residual(nn.Module): + def __init__(self, fn): + super().__init__() + self.fn = fn + + def forward(self, x, *args, **kwargs): + return self.fn(x, *args, **kwargs) + x + +def Upsample(dim, dim_out = None): + return nn.Sequential( + nn.Upsample(scale_factor = 2, mode = 'nearest'), + nn.Conv2d(dim, default(dim_out, dim), 3, padding = 1) + ) + +def Downsample(dim, dim_out = None): + return nn.Conv2d(dim, default(dim_out, dim), 4, 2, 1) + +class WeightStandardizedConv2d(nn.Conv2d): + """ + https://arxiv.org/abs/1903.10520 + weight standardization purportedly works 
synergistically with group normalization + """ + def forward(self, x): + eps = 1e-5 if x.dtype == torch.float32 else 1e-3 + + weight = self.weight + mean = reduce(weight, 'o ... -> o 1 1 1', 'mean') + var = reduce(weight, 'o ... -> o 1 1 1', partial(torch.var, unbiased = False)) + normalized_weight = (weight - mean) * (var + eps).rsqrt() + + return F.conv2d(x, normalized_weight, self.bias, self.stride, self.padding, self.dilation, self.groups) + + +class LayerNorm(nn.Module): + def __init__(self, dim): + super().__init__() + self.g = nn.Parameter(torch.ones(1, dim, 1, 1)) + + def forward(self, x): + eps = 1e-5 if x.dtype == torch.float32 else 1e-3 + var = torch.var(x, dim = 1, unbiased = False, keepdim = True) + mean = torch.mean(x, dim = 1, keepdim = True) + return (x - mean) * (var + eps).rsqrt() * self.g + +class PreNorm(nn.Module): + def __init__(self, dim, fn): + super().__init__() + self.fn = fn + self.norm = LayerNorm(dim) + + def forward(self, x): + x = self.norm(x) + return self.fn(x) + +class Block(nn.Module): + def __init__(self, dim, dim_out, groups = 8): + super().__init__() + self.proj = WeightStandardizedConv2d(dim, dim_out, 3, padding = 1) + self.norm = nn.GroupNorm(groups, dim_out) + self.act = nn.SiLU() + + def forward(self, x, scale_shift = None): + x = self.proj(x) + x = self.norm(x) + + if exists(scale_shift): + scale, shift = scale_shift + x = x * (scale + 1) + shift + + x = self.act(x) + return x + +class ResnetBlock(nn.Module): + def __init__(self, dim, dim_out, *, time_emb_dim = None, groups = 8): + super().__init__() + self.mlp = nn.Sequential( + nn.SiLU(), + nn.Linear(time_emb_dim, dim_out * 2) + ) if exists(time_emb_dim) else None + + self.block1 = Block(dim, dim_out, groups = groups) + self.block2 = Block(dim_out, dim_out, groups = groups) + self.res_conv = nn.Conv2d(dim, dim_out, 1) if dim != dim_out else nn.Identity() + + def forward(self, x, time_emb = None): + + scale_shift = None + if exists(self.mlp) and exists(time_emb): + time_emb = self.mlp(time_emb) + time_emb = rearrange(time_emb, 'b c -> b c 1 1') + scale_shift = time_emb.chunk(2, dim = 1) + + h = self.block1(x, scale_shift = scale_shift) + + h = self.block2(h) + + return h + self.res_conv(x) + +class LinearAttention(nn.Module): + def __init__(self, dim, heads = 4, dim_head = 32): + super().__init__() + self.scale = dim_head ** -0.5 + self.heads = heads + hidden_dim = dim_head * heads + self.to_qkv = nn.Conv2d(dim, hidden_dim * 3, 1, bias = False) + + self.to_out = nn.Sequential( + nn.Conv2d(hidden_dim, dim, 1), + LayerNorm(dim) + ) + + def forward(self, x): + b, c, h, w = x.shape + qkv = self.to_qkv(x).chunk(3, dim = 1) + q, k, v = map(lambda t: rearrange(t, 'b (h c) x y -> b h c (x y)', h = self.heads), qkv) + + q = q.softmax(dim = -2) + k = k.softmax(dim = -1) + + q = q * self.scale + v = v / (h * w) + + context = torch.einsum('b h d n, b h e n -> b h d e', k, v) + + out = torch.einsum('b h d e, b h d n -> b h e n', context, q) + out = rearrange(out, 'b h c (x y) -> b (h c) x y', h = self.heads, x = h, y = w) + return self.to_out(out) + +class Attention(nn.Module): + def __init__(self, dim, heads = 4, dim_head = 32, scale = 10): + super().__init__() + self.scale = scale + self.heads = heads + hidden_dim = dim_head * heads + self.to_qkv = nn.Conv2d(dim, hidden_dim * 3, 1, bias = False) + self.to_out = nn.Conv2d(hidden_dim, dim, 1) + + def forward(self, x): + b, c, h, w = x.shape + qkv = self.to_qkv(x).chunk(3, dim = 1) + q, k, v = map(lambda t: rearrange(t, 'b (h c) x y -> b h c (x y)', h = 
self.heads), qkv) + + q, k = map(l2norm, (q, k)) + + sim = einsum('b h d i, b h d j -> b h i j', q, k) * self.scale + attn = sim.softmax(dim = -1) + out = einsum('b h i j, b h d j -> b h i d', attn, v) + out = rearrange(out, 'b h (x y) d -> b (h d) x y', x = h, y = w) + return self.to_out(out) + + + +class UNet(nn.Module): + def __init__( + self, + dim=32, + init_dim = None, + out_dim = None, + dim_mults=(1, 2, 4, 8), + channels = 3, + self_condition = False, + resnet_block_groups = 8, + learned_variance = False, + learned_sinusoidal_cond = False, + learned_sinusoidal_dim = 16, + **kwargs + ): + super().__init__() + + # determine dimensions + + self.channels = channels + self.self_condition = self_condition + input_channels = channels * (2 if self_condition else 1) + + init_dim = default(init_dim, dim) + self.init_conv = nn.Conv2d(input_channels, init_dim, 7, padding = 3) + + dims = [init_dim, *map(lambda m: dim * m, dim_mults)] + in_out = list(zip(dims[:-1], dims[1:])) + + block_klass = partial(ResnetBlock, groups = resnet_block_groups) + + # time embeddings + + time_dim = dim * 4 + + self.learned_sinusoidal_cond = learned_sinusoidal_cond + + if learned_sinusoidal_cond: + sinu_pos_emb = LearnedSinusoidalPosEmb(learned_sinusoidal_dim) + fourier_dim = learned_sinusoidal_dim + 1 + else: + sinu_pos_emb = SinusoidalPosEmb(dim) + fourier_dim = dim + + self.time_mlp = nn.Sequential( + sinu_pos_emb, + nn.Linear(fourier_dim, time_dim), + nn.GELU(), + nn.Linear(time_dim, time_dim) + ) + + # layers + + self.downs = nn.ModuleList([]) + self.ups = nn.ModuleList([]) + num_resolutions = len(in_out) + + for ind, (dim_in, dim_out) in enumerate(in_out): + is_last = ind >= (num_resolutions - 1) + + self.downs.append(nn.ModuleList([ + block_klass(dim_in, dim_in, time_emb_dim = time_dim), + block_klass(dim_in, dim_in, time_emb_dim = time_dim), + Residual(PreNorm(dim_in, LinearAttention(dim_in))), + Downsample(dim_in, dim_out) if not is_last else nn.Conv2d(dim_in, dim_out, 3, padding = 1) + ])) + + mid_dim = dims[-1] + self.mid_block1 = block_klass(mid_dim, mid_dim, time_emb_dim = time_dim) + self.mid_attn = Residual(PreNorm(mid_dim, Attention(mid_dim))) + self.mid_block2 = block_klass(mid_dim, mid_dim, time_emb_dim = time_dim) + + for ind, (dim_in, dim_out) in enumerate(reversed(in_out)): + is_last = ind == (len(in_out) - 1) + + self.ups.append(nn.ModuleList([ + block_klass(dim_out + dim_in, dim_out, time_emb_dim = time_dim), + block_klass(dim_out + dim_in, dim_out, time_emb_dim = time_dim), + Residual(PreNorm(dim_out, LinearAttention(dim_out))), + Upsample(dim_out, dim_in) if not is_last else nn.Conv2d(dim_out, dim_in, 3, padding = 1) + ])) + + default_out_dim = channels * (1 if not learned_variance else 2) + self.out_dim = default(out_dim, default_out_dim) + + self.final_res_block = block_klass(dim * 2, dim, time_emb_dim = time_dim) + self.final_conv = nn.Conv2d(dim, self.out_dim, 1) + + def forward(self, x, time, condition=None, self_cond=None): + if self.self_condition: + x_self_cond = default(self_cond, lambda: torch.zeros_like(x)) + x = torch.cat((x_self_cond, x), dim = 1) + + x = self.init_conv(x) + r = x.clone() + + t = self.time_mlp(time) + + h = [] + + for block1, block2, attn, downsample in self.downs: + x = block1(x, t) + h.append(x) + + x = block2(x, t) + x = attn(x) + h.append(x) + + x = downsample(x) + + x = self.mid_block1(x, t) + x = self.mid_attn(x) + x = self.mid_block2(x, t) + + for block1, block2, attn, upsample in self.ups: + x = torch.cat((x, h.pop()), dim = 1) + x = block1(x, t) + + x 
= torch.cat((x, h.pop()), dim = 1) + x = block2(x, t) + x = attn(x) + + x = upsample(x) + + x = torch.cat((x, r), dim = 1) + + x = self.final_res_block(x, t) + return self.final_conv(x), [] \ No newline at end of file diff --git a/medical_diffusion/loss/gan_losses.py b/medical_diffusion/loss/gan_losses.py new file mode 100755 index 0000000000000000000000000000000000000000..3b7ecb187745408292200b59f5d72ec7f4c95bb2 --- /dev/null +++ b/medical_diffusion/loss/gan_losses.py @@ -0,0 +1,22 @@ + + +import torch +import torch.nn.functional as F + +def exp_d_loss(logits_real, logits_fake): + loss_real = torch.mean(torch.exp(-logits_real)) + loss_fake = torch.mean(torch.exp(logits_fake)) + d_loss = 0.5 * (loss_real + loss_fake) + return d_loss + +def hinge_d_loss(logits_real, logits_fake): + loss_real = torch.mean(F.relu(1. - logits_real)) + loss_fake = torch.mean(F.relu(1. + logits_fake)) + d_loss = 0.5 * (loss_real + loss_fake) + return d_loss + +def vanilla_d_loss(logits_real, logits_fake): + d_loss = 0.5 * ( + torch.mean(F.softplus(-logits_real)) + + torch.mean(F.softplus(logits_fake))) + return d_loss \ No newline at end of file diff --git a/medical_diffusion/loss/perceivers.py b/medical_diffusion/loss/perceivers.py new file mode 100755 index 0000000000000000000000000000000000000000..0b789b40bf36e8876ccd053d98247da5ffdc4b90 --- /dev/null +++ b/medical_diffusion/loss/perceivers.py @@ -0,0 +1,27 @@ + + +import lpips +import torch + +class LPIPS(torch.nn.Module): + """Learned Perceptual Image Patch Similarity (LPIPS)""" + def __init__(self, linear_calibration=False, normalize=False): + super().__init__() + self.loss_fn = lpips.LPIPS(net='vgg', lpips=linear_calibration) # Note: only 'vgg' valid as loss + self.normalize = normalize # If true, normalize [0, 1] to [-1, 1] + + + def forward(self, pred, target): + # No need to do that because ScalingLayer was introduced in version 0.1 which does this indirectly + # if pred.shape[1] == 1: # convert 1-channel gray images to 3-channel RGB + # pred = torch.concat([pred, pred, pred], dim=1) + # if target.shape[1] == 1: # convert 1-channel gray images to 3-channel RGB + # target = torch.concat([target, target, target], dim=1) + + if pred.ndim == 5: # 3D Image: Just use 2D model and compute average over slices + depth = pred.shape[2] + losses = torch.stack([self.loss_fn(pred[:,:,d], target[:,:,d], normalize=self.normalize) for d in range(depth)], dim=2) + return torch.mean(losses, dim=2, keepdim=True) + else: + return self.loss_fn(pred, target, normalize=self.normalize) + \ No newline at end of file diff --git a/medical_diffusion/metrics/__init__.py b/medical_diffusion/metrics/__init__.py new file mode 100755 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/medical_diffusion/metrics/torchmetrics_pr_recall.py b/medical_diffusion/metrics/torchmetrics_pr_recall.py new file mode 100755 index 0000000000000000000000000000000000000000..1b47664d191097e9c599904cd9f05ff6835121c8 --- /dev/null +++ b/medical_diffusion/metrics/torchmetrics_pr_recall.py @@ -0,0 +1,170 @@ +from typing import Optional, List + +import torch +from torch import Tensor +from torchmetrics import Metric +import torchvision.models as models +from torchvision import transforms + + + +from torchmetrics.utilities.imports import _TORCH_FIDELITY_AVAILABLE + +if _TORCH_FIDELITY_AVAILABLE: + from torch_fidelity.feature_extractor_inceptionv3 import FeatureExtractorInceptionV3 +else: + class FeatureExtractorInceptionV3(Module): # type: ignore + pass + 
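# Minimal usage sketch for the improved precision/recall metric defined below
+# (illustrative only; the batch size and image resolution are assumptions, and
+# images must be passed as uint8 tensors, as asserted in `update`):
+#
+#   ipr = ImprovedPrecessionRecall(feature=2048, knn=3)
+#   ipr.update(real_imgs, real=True)    # real_imgs: uint8 tensor, e.g. [B, 3, H, W]
+#   ipr.update(fake_imgs, real=False)
+#   precision, recall = ipr.compute()
+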
__doctest_skip__ = ["ImprovedPrecessionRecall", "IPR"] + +class NoTrainInceptionV3(FeatureExtractorInceptionV3): + def __init__( + self, + name: str, + features_list: List[str], + feature_extractor_weights_path: Optional[str] = None, + ) -> None: + super().__init__(name, features_list, feature_extractor_weights_path) + # put into evaluation mode + self.eval() + + def train(self, mode: bool) -> "NoTrainInceptionV3": + """the inception network should not be able to be switched away from evaluation mode.""" + return super().train(False) + + def forward(self, x: Tensor) -> Tensor: + out = super().forward(x) + return out[0].reshape(x.shape[0], -1) + + +# -------------------------- VGG Trans --------------------------- +# class Normalize(object): +# """Rescale the image from 0-255 (uint8) to [0,1] (float32). +# Note, this doesn't ensure that min=0 and max=1 as a min-max scale would do!""" + +# def __call__(self, image): +# return image/255 + +# # see https://pytorch.org/vision/main/models/generated/torchvision.models.vgg16.html +# VGG_Trans = transforms.Compose([ +# transforms.Resize([224, 224], interpolation=transforms.InterpolationMode.BILINEAR, antialias=True), +# # transforms.Resize([256, 256], interpolation=InterpolationMode.BILINEAR), +# # transforms.CenterCrop(224), +# Normalize(), # scale to [0, 1] +# transforms.Normalize([0.485, 0.456, 0.406],[0.229, 0.224, 0.225]) +# ]) + + + +class ImprovedPrecessionRecall(Metric): + is_differentiable: bool = False + higher_is_better: bool = True + full_state_update: bool = False + + + def __init__(self, feature=2048, knn=3, splits_real=1, splits_fake=5): + super().__init__() + + + # ------------------------- Init Feature Extractor (VGG or Inception) ------------------------------ + # Original VGG: https://github.com/kynkaat/improved-precision-and-recall-metric/blob/b0247eafdead494a5d243bd2efb1b0b124379ae9/utils.py#L40 + # Compare Inception: https://github.com/openai/guided-diffusion/blob/22e0df8183507e13a7813f8d38d51b072ca1e67c/evaluations/evaluator.py#L574 + # TODO: Add option to switch between Inception and VGG feature extractor + # self.vgg_model = models.vgg16(weights='IMAGENET1K_V1').eval() + # self.feature_extractor = transforms.Compose([ + # VGG_Trans, + # self.vgg_model.features, + # transforms.Lambda(lambda x: torch.flatten(x, 1)), + # self.vgg_model.classifier[:4] # [:4] corresponds to 4096 features + # ]) + + if isinstance(feature, int): + if not _TORCH_FIDELITY_AVAILABLE: + raise ModuleNotFoundError( + "FrechetInceptionDistance metric requires that `Torch-fidelity` is installed." + " Either install as `pip install torchmetrics[image]` or `pip install torch-fidelity`." + ) + valid_int_input = [64, 192, 768, 2048] + if feature not in valid_int_input: + raise ValueError( + f"Integer input to argument `feature` must be one of {valid_int_input}, but got {feature}." 
+ ) + + self.feature_extractor = NoTrainInceptionV3(name="inception-v3-compat", features_list=[str(feature)]) + elif isinstance(feature, torch.nn.Module): + self.feature_extractor = feature + else: + raise TypeError("Got unknown input to argument `feature`") + + # --------------------------- End Feature Extractor --------------------------------------------------------------- + + self.knn = knn + self.splits_real = splits_real + self.splits_fake = splits_fake + self.add_state("real_features", [], dist_reduce_fx=None) + self.add_state("fake_features", [], dist_reduce_fx=None) + + + + def update(self, imgs: Tensor, real: bool) -> None: # type: ignore + """Update the state with extracted features. + + Args: + imgs: tensor with images feed to the feature extractor + real: bool indicating if ``imgs`` belong to the real or the fake distribution + """ + assert torch.is_tensor(imgs) and imgs.dtype == torch.uint8, 'Expecting image as torch.Tensor with dtype=torch.uint8' + + features = self.feature_extractor(imgs).view(imgs.shape[0], -1) + + if real: + self.real_features.append(features) + else: + self.fake_features.append(features) + + def compute(self): + real_features = torch.concat(self.real_features) + fake_features = torch.concat(self.fake_features) + + real_distances = _compute_pairwise_distances(real_features, self.splits_real) + real_radii = _distances2radii(real_distances, self.knn) + + fake_distances = _compute_pairwise_distances(fake_features, self.splits_fake) + fake_radii = _distances2radii(fake_distances, self.knn) + + precision = _compute_metric(real_features, real_radii, self.splits_real, fake_features, self.splits_fake) + recall = _compute_metric(fake_features, fake_radii, self.splits_fake, real_features, self.splits_real) + + return precision, recall + +def _compute_metric(ref_features, ref_radii, ref_splits, pred_features, pred_splits): + dist = _compute_pairwise_distances(ref_features, ref_splits, pred_features, pred_splits) + num_feat = pred_features.shape[0] + count = 0 + for i in range(num_feat): + count += (dist[:, i] < ref_radii).any() + return count / num_feat + +def _distances2radii(distances, knn): + return torch.topk(distances, knn+1, dim=1, largest=False)[0].max(dim=1)[0] + +def _compute_pairwise_distances(X, splits_x, Y=None, splits_y=None): + # X = [B, features] + # Y = [B', features] + Y = X if Y is None else Y + # X = X.double() + # Y = Y.double() + splits_y = splits_x if splits_y is None else splits_y + dist = torch.concat([ + torch.concat([ + (torch.sum(X_batch**2, dim=1, keepdim=True) + + torch.sum(Y_batch**2, dim=1, keepdim=True).t() - + 2 * torch.einsum("bd,dn->bn", X_batch, Y_batch.t())) + for Y_batch in Y.chunk(splits_y, dim=0)], dim=1) + for X_batch in X.chunk(splits_x, dim=0)]) + + # dist = torch.maximum(dist, torch.zeros_like(dist)) + dist[dist<0] = 0 + return torch.sqrt(dist) + + \ No newline at end of file diff --git a/medical_diffusion/models/__init__.py b/medical_diffusion/models/__init__.py new file mode 100755 index 0000000000000000000000000000000000000000..ae49315a0fea3449a5fcf5d194426778c95bc364 --- /dev/null +++ b/medical_diffusion/models/__init__.py @@ -0,0 +1 @@ +from .model_base import BasicModel diff --git a/medical_diffusion/models/embedders/__init__.py b/medical_diffusion/models/embedders/__init__.py new file mode 100755 index 0000000000000000000000000000000000000000..32a87795340663cb12b33b59ad64f621e109ed29 --- /dev/null +++ b/medical_diffusion/models/embedders/__init__.py @@ -0,0 +1,2 @@ +from .time_embedder import TimeEmbbeding, 
LearnedSinusoidalPosEmb, SinusoidalPosEmb +from .cond_embedders import LabelEmbedder \ No newline at end of file diff --git a/medical_diffusion/models/embedders/cond_embedders.py b/medical_diffusion/models/embedders/cond_embedders.py new file mode 100755 index 0000000000000000000000000000000000000000..10a8a44211f75b96690c312c58a2420c7591e3c1 --- /dev/null +++ b/medical_diffusion/models/embedders/cond_embedders.py @@ -0,0 +1,27 @@ + +import torch.nn as nn +import torch +from monai.networks.layers.utils import get_act_layer + +class LabelEmbedder(nn.Module): + def __init__(self, emb_dim=32, num_classes=2, act_name=("SWISH", {})): + super().__init__() + self.emb_dim = emb_dim + self.embedding = nn.Embedding(num_classes, emb_dim) + + # self.embedding = nn.Embedding(num_classes, emb_dim//4) + # self.emb_net = nn.Sequential( + # nn.Linear(1, emb_dim), + # get_act_layer(act_name), + # nn.Linear(emb_dim, emb_dim) + # ) + + def forward(self, condition): + c = self.embedding(condition) #[B,] -> [B, C] + # c = self.emb_net(c) + # c = self.emb_net(condition[:,None].float()) + # c = (2*condition-1)[:, None].expand(-1, self.emb_dim).type(torch.float32) + return c + + + diff --git a/medical_diffusion/models/embedders/latent_embedders.py b/medical_diffusion/models/embedders/latent_embedders.py new file mode 100755 index 0000000000000000000000000000000000000000..0adff67df0ec22895f9cd59d05fdb93bea1c1bf9 --- /dev/null +++ b/medical_diffusion/models/embedders/latent_embedders.py @@ -0,0 +1,1065 @@ + +from pathlib import Path + +import torch +import torch.nn as nn +import torch.nn.functional as F +from torchvision.utils import save_image +from monai.networks.blocks import UnetOutBlock + + +from medical_diffusion.models.utils.conv_blocks import DownBlock, UpBlock, BasicBlock, BasicResBlock, UnetResBlock, UnetBasicBlock +from medical_diffusion.loss.gan_losses import hinge_d_loss +from medical_diffusion.loss.perceivers import LPIPS +from medical_diffusion.models.model_base import BasicModel, VeryBasicModel + + +from pytorch_msssim import SSIM, ssim + + +class DiagonalGaussianDistribution(nn.Module): + + def forward(self, x): + mean, logvar = torch.chunk(x, 2, dim=1) + logvar = torch.clamp(logvar, -30.0, 20.0) + std = torch.exp(0.5 * logvar) + sample = torch.randn(mean.shape, generator=None, device=x.device) + z = mean + std * sample + + batch_size = x.shape[0] + var = torch.exp(logvar) + kl = 0.5 * torch.sum(torch.pow(mean, 2) + var - 1.0 - logvar)/batch_size + + return z, kl + + + + + + +class VectorQuantizer(nn.Module): + def __init__(self, num_embeddings, emb_channels, beta=0.25): + super().__init__() + self.num_embeddings = num_embeddings + self.emb_channels = emb_channels + self.beta = beta + + self.embedder = nn.Embedding(num_embeddings, emb_channels) + self.embedder.weight.data.uniform_(-1.0 / self.num_embeddings, 1.0 / self.num_embeddings) + + def forward(self, z): + assert z.shape[1] == self.emb_channels, "Channels of z and codebook don't match" + z_ch = torch.moveaxis(z, 1, -1) # [B, C, *] -> [B, *, C] + z_flattened = z_ch.reshape(-1, self.emb_channels) # [B, *, C] -> [Bx*, C], Note: or use contiguous() and view() + + # distances from z to embeddings e: (z - e)^2 = z^2 + e^2 - 2 e * z + dist = ( torch.sum(z_flattened**2, dim=1, keepdim=True) + + torch.sum(self.embedder.weight**2, dim=1) + -2* torch.einsum("bd,dn->bn", z_flattened, self.embedder.weight.t()) + ) # [Bx*, num_embeddings] + + min_encoding_indices = torch.argmin(dist, dim=1) # [Bx*] + z_q = self.embedder(min_encoding_indices) # [Bx*, C] + z_q 
= z_q.view(z_ch.shape) # [Bx*, C] -> [B, *, C] + z_q = torch.moveaxis(z_q, -1, 1) # [B, *, C] -> [B, C, *] + + # Compute Embedding Loss + loss = self.beta * torch.mean((z_q.detach() - z) ** 2) + torch.mean((z_q - z.detach()) ** 2) + + # preserve gradients + z_q = z + (z_q - z).detach() + + return z_q, loss + + + +class Discriminator(nn.Module): + def __init__(self, + in_channels=1, + spatial_dims = 3, + hid_chs = [32, 64, 128, 256, 512], + kernel_sizes=[(1,3,3), (1,3,3), (1,3,3), 3, 3], + strides = [ 1, (1,2,2), (1,2,2), 2, 2], + act_name=("Swish", {}), + norm_name = ("GROUP", {'num_groups':32, "affine": True}), + dropout=None + ): + super().__init__() + + self.inc = BasicBlock( + spatial_dims=spatial_dims, + in_channels=in_channels, + out_channels=hid_chs[0], + kernel_size=kernel_sizes[0], # 2*pad = kernel-stride -> kernel = 2*pad + stride => 1 = 2*0+1, 3, =2*1+1, 2 = 2*0+2, 4 = 2*1+2 + stride=strides[0], + norm_name=norm_name, + act_name=act_name, + dropout=dropout, + ) + + self.encoder = nn.Sequential(*[ + BasicBlock( + spatial_dims=spatial_dims, + in_channels=hid_chs[i-1], + out_channels=hid_chs[i], + kernel_size=kernel_sizes[i], + stride=strides[i], + act_name=act_name, + norm_name=norm_name, + dropout=dropout) + for i in range(1, len(hid_chs)) + ]) + + + self.outc = BasicBlock( + spatial_dims=spatial_dims, + in_channels=hid_chs[-1], + out_channels=1, + kernel_size=3, + stride=1, + act_name=None, + norm_name=None, + dropout=None, + zero_conv=True + ) + + + + def forward(self, x): + x = self.inc(x) + x = self.encoder(x) + return self.outc(x) + + +class NLayerDiscriminator(nn.Module): + def __init__(self, + in_channels=1, + spatial_dims = 3, + hid_chs = [64, 128, 256, 512, 512], + kernel_sizes=[4, 4, 4, 4, 4], + strides = [2, 2, 2, 1, 1], + act_name=("LeakyReLU", {'negative_slope': 0.2}), + norm_name = ("BATCH", {}), + dropout=None + ): + super().__init__() + + self.inc = BasicBlock( + spatial_dims=spatial_dims, + in_channels=in_channels, + out_channels=hid_chs[0], + kernel_size=kernel_sizes[0], + stride=strides[0], + norm_name=None, + act_name=act_name, + dropout=dropout, + ) + + self.encoder = nn.Sequential(*[ + BasicBlock( + spatial_dims=spatial_dims, + in_channels=hid_chs[i-1], + out_channels=hid_chs[i], + kernel_size=kernel_sizes[i], + stride=strides[i], + act_name=act_name, + norm_name=norm_name, + dropout=dropout) + for i in range(1, len(strides)) + ]) + + + self.outc = BasicBlock( + spatial_dims=spatial_dims, + in_channels=hid_chs[-1], + out_channels=1, + kernel_size=4, + stride=1, + norm_name=None, + act_name=None, + dropout=False, + ) + + def forward(self, x): + x = self.inc(x) + x = self.encoder(x) + return self.outc(x) + + + + +class VQVAE(BasicModel): + def __init__( + self, + in_channels=3, + out_channels=3, + spatial_dims = 2, + emb_channels = 4, + num_embeddings = 8192, + hid_chs = [32, 64, 128, 256], + kernel_sizes=[ 3, 3, 3, 3], + strides = [ 1, 2, 2, 2], + norm_name = ("GROUP", {'num_groups':32, "affine": True}), + act_name=("Swish", {}), + dropout=0.0, + use_res_block=True, + deep_supervision=False, + learnable_interpolation=True, + use_attention='none', + beta = 0.25, + embedding_loss_weight=1.0, + perceiver = LPIPS, + perceiver_kwargs = {}, + perceptual_loss_weight = 1.0, + + + optimizer=torch.optim.Adam, + optimizer_kwargs={'lr':1e-4}, + lr_scheduler= None, + lr_scheduler_kwargs={}, + loss = torch.nn.L1Loss, + loss_kwargs={'reduction': 'none'}, + + sample_every_n_steps = 1000 + + ): + super().__init__( + optimizer=optimizer, + optimizer_kwargs=optimizer_kwargs, + 
lr_scheduler=lr_scheduler, + lr_scheduler_kwargs=lr_scheduler_kwargs + ) + self.sample_every_n_steps=sample_every_n_steps + self.loss_fct = loss(**loss_kwargs) + self.embedding_loss_weight = embedding_loss_weight + self.perceiver = perceiver(**perceiver_kwargs).eval() if perceiver is not None else None + self.perceptual_loss_weight = perceptual_loss_weight + use_attention = use_attention if isinstance(use_attention, list) else [use_attention]*len(strides) + self.depth = len(strides) + self.deep_supervision = deep_supervision + + # ----------- In-Convolution ------------ + ConvBlock = UnetResBlock if use_res_block else UnetBasicBlock + self.inc = ConvBlock(spatial_dims, in_channels, hid_chs[0], kernel_size=kernel_sizes[0], stride=strides[0], + act_name=act_name, norm_name=norm_name) + + # ----------- Encoder ---------------- + self.encoders = nn.ModuleList([ + DownBlock( + spatial_dims, + hid_chs[i-1], + hid_chs[i], + kernel_sizes[i], + strides[i], + kernel_sizes[i], + norm_name, + act_name, + dropout, + use_res_block, + learnable_interpolation, + use_attention[i]) + for i in range(1, self.depth) + ]) + + # ----------- Out-Encoder ------------ + self.out_enc = BasicBlock(spatial_dims, hid_chs[-1], emb_channels, 1) + + + # ----------- Quantizer -------------- + self.quantizer = VectorQuantizer( + num_embeddings=num_embeddings, + emb_channels=emb_channels, + beta=beta + ) + + # ----------- In-Decoder ------------ + self.inc_dec = ConvBlock(spatial_dims, emb_channels, hid_chs[-1], 3, act_name=act_name, norm_name=norm_name) + + # ------------ Decoder ---------- + self.decoders = nn.ModuleList([ + UpBlock( + spatial_dims, + hid_chs[i+1], + hid_chs[i], + kernel_size=kernel_sizes[i+1], + stride=strides[i+1], + upsample_kernel_size=strides[i+1], + norm_name=norm_name, + act_name=act_name, + dropout=dropout, + use_res_block=use_res_block, + learnable_interpolation=learnable_interpolation, + use_attention=use_attention[i], + skip_channels=0) + for i in range(self.depth-1) + ]) + + # --------------- Out-Convolution ---------------- + self.outc = BasicBlock(spatial_dims, hid_chs[0], out_channels, 1, zero_conv=True) + if isinstance(deep_supervision, bool): + deep_supervision = self.depth-1 if deep_supervision else 0 + self.outc_ver = nn.ModuleList([ + BasicBlock(spatial_dims, hid_chs[i], out_channels, 1, zero_conv=True) + for i in range(1, deep_supervision+1) + ]) + + + def encode(self, x): + h = self.inc(x) + for i in range(len(self.encoders)): + h = self.encoders[i](h) + z = self.out_enc(h) + return z + + def decode(self, z): + z, _ = self.quantizer(z) + h = self.inc_dec(z) + for i in range(len(self.decoders), 0, -1): + h = self.decoders[i-1](h) + x = self.outc(h) + return x + + def forward(self, x_in): + # --------- Encoder -------------- + h = self.inc(x_in) + for i in range(len(self.encoders)): + h = self.encoders[i](h) + z = self.out_enc(h) + + # --------- Quantizer -------------- + z_q, emb_loss = self.quantizer(z) + + # -------- Decoder ----------- + out_hor = [] + h = self.inc_dec(z_q) + for i in range(len(self.decoders)-1, -1, -1): + out_hor.append(self.outc_ver[i](h)) if i < len(self.outc_ver) else None + h = self.decoders[i](h) + out = self.outc(h) + + return out, out_hor[::-1], emb_loss + + def perception_loss(self, pred, target, depth=0): + if (self.perceiver is not None) and (depth<2): + self.perceiver.eval() + return self.perceiver(pred, target)*self.perceptual_loss_weight + else: + return 0 + + def ssim_loss(self, pred, target): + return 1-ssim(((pred+1)/2).clamp(0,1), 
(target.type(pred.dtype)+1)/2, data_range=1, size_average=False, + nonnegative_ssim=True).reshape(-1, *[1]*(pred.ndim-1)) + + + def rec_loss(self, pred, pred_vertical, target): + interpolation_mode = 'nearest-exact' + weights = [1/2**i for i in range(1+len(pred_vertical))] # horizontal (equal) + vertical (reducing with every step down) + tot_weight = sum(weights) + weights = [w/tot_weight for w in weights] + + # Loss + loss = 0 + loss += torch.mean(self.loss_fct(pred, target)+self.perception_loss(pred, target)+self.ssim_loss(pred, target))*weights[0] + + for i, pred_i in enumerate(pred_vertical): + target_i = F.interpolate(target, size=pred_i.shape[2:], mode=interpolation_mode, align_corners=None) + loss += torch.mean(self.loss_fct(pred_i, target_i)+self.perception_loss(pred_i, target_i)+self.ssim_loss(pred_i, target_i))*weights[i+1] + + return loss + + def _step(self, batch: dict, batch_idx: int, state: str, step: int, optimizer_idx:int): + # ------------------------- Get Source/Target --------------------------- + x = batch['source'] + target = x + + # ------------------------- Run Model --------------------------- + pred, pred_vertical, emb_loss = self(x) + + # ------------------------- Compute Loss --------------------------- + loss = self.rec_loss(pred, pred_vertical, target) + loss += emb_loss*self.embedding_loss_weight + + # --------------------- Compute Metrics ------------------------------- + with torch.no_grad(): + logging_dict = {'loss':loss, 'emb_loss': emb_loss} + logging_dict['L2'] = torch.nn.functional.mse_loss(pred, target) + logging_dict['L1'] = torch.nn.functional.l1_loss(pred, target) + logging_dict['ssim'] = ssim((pred+1)/2, (target.type(pred.dtype)+1)/2, data_range=1) + + # ----------------- Log Scalars ---------------------- + for metric_name, metric_val in logging_dict.items(): + self.log(f"{state}/{metric_name}", metric_val, batch_size=x.shape[0], on_step=True, on_epoch=True) + + # ----------------- Save Image ------------------------------ + if self.global_step != 0 and self.global_step % self.sample_every_n_steps == 0: + log_step = self.global_step // self.sample_every_n_steps + path_out = Path(self.logger.log_dir)/'images' + path_out.mkdir(parents=True, exist_ok=True) + # for 3D images use depth as batch :[D, C, H, W], never show more than 16+16 =32 images + def depth2batch(image): + return (image if image.ndim<5 else torch.swapaxes(image[0], 0, 1)) + images = torch.cat([depth2batch(img)[:16] for img in (x, pred)]) + save_image(images, path_out/f'sample_{log_step}.png', nrow=x.shape[0], normalize=True) + + return loss + + + +class VQGAN(VeryBasicModel): + def __init__( + self, + in_channels=3, + out_channels=3, + spatial_dims = 2, + emb_channels = 4, + num_embeddings = 8192, + hid_chs = [ 64, 128, 256, 512], + kernel_sizes=[ 3, 3, 3, 3], + strides = [ 1, 2, 2, 2], + norm_name = ("GROUP", {'num_groups':32, "affine": True}), + act_name=("Swish", {}), + dropout=0.0, + use_res_block=True, + deep_supervision=False, + learnable_interpolation=True, + use_attention='none', + beta = 0.25, + embedding_loss_weight=1.0, + perceiver = LPIPS, + perceiver_kwargs = {}, + perceptual_loss_weight: float = 1.0, + + + start_gan_train_step = 50000, # NOTE step increase with each optimizer + gan_loss_weight: float = 1.0, # = discriminator + + optimizer_vqvae=torch.optim.Adam, + optimizer_gan=torch.optim.Adam, + optimizer_vqvae_kwargs={'lr':1e-6}, + optimizer_gan_kwargs={'lr':1e-6}, + lr_scheduler_vqvae= None, + lr_scheduler_vqvae_kwargs={}, + lr_scheduler_gan= None, + 
lr_scheduler_gan_kwargs={}, + + pixel_loss = torch.nn.L1Loss, + pixel_loss_kwargs={'reduction':'none'}, + gan_loss_fct = hinge_d_loss, + + sample_every_n_steps = 1000 + + ): + super().__init__() + self.sample_every_n_steps=sample_every_n_steps + self.start_gan_train_step = start_gan_train_step + self.gan_loss_weight = gan_loss_weight + self.embedding_loss_weight = embedding_loss_weight + + self.optimizer_vqvae = optimizer_vqvae + self.optimizer_gan = optimizer_gan + self.optimizer_vqvae_kwargs = optimizer_vqvae_kwargs + self.optimizer_gan_kwargs = optimizer_gan_kwargs + self.lr_scheduler_vqvae = lr_scheduler_vqvae + self.lr_scheduler_vqvae_kwargs = lr_scheduler_vqvae_kwargs + self.lr_scheduler_gan = lr_scheduler_gan + self.lr_scheduler_gan_kwargs = lr_scheduler_gan_kwargs + + self.pixel_loss_fct = pixel_loss(**pixel_loss_kwargs) + self.gan_loss_fct = gan_loss_fct + + self.vqvae = VQVAE(in_channels, out_channels, spatial_dims, emb_channels, num_embeddings, hid_chs, kernel_sizes, + strides, norm_name, act_name, dropout, use_res_block, deep_supervision, learnable_interpolation, use_attention, + beta, embedding_loss_weight, perceiver, perceiver_kwargs, perceptual_loss_weight) + + self.discriminator = nn.ModuleList([Discriminator(in_channels, spatial_dims, hid_chs, kernel_sizes, strides, + act_name, norm_name, dropout) for i in range(len(self.vqvae.outc_ver)+1)]) + + + # self.discriminator = nn.ModuleList([NLayerDiscriminator(in_channels, spatial_dims) + # for _ in range(len(self.vqvae.decoder.outc_ver)+1)]) + + + + def encode(self, x): + return self.vqvae.encode(x) + + def decode(self, z): + return self.vqvae.decode(z) + + def forward(self, x): + return self.vqvae.forward(x) + + + def vae_img_loss(self, pred, target, dec_out_layer, step, discriminator, depth=0): + # ------ VQVAE ------- + rec_loss = self.vqvae.rec_loss(pred, [], target) + + # ------- GAN ----- + if step > self.start_gan_train_step: + gan_loss = -torch.mean(discriminator[depth](pred)) + lambda_weight = self.compute_lambda(rec_loss, gan_loss, dec_out_layer) + gan_loss = gan_loss*lambda_weight + + with torch.no_grad(): + self.log(f"train/gan_loss_{depth}", gan_loss, on_step=True, on_epoch=True) + self.log(f"train/lambda_{depth}", lambda_weight, on_step=True, on_epoch=True) + else: + gan_loss = 0 #torch.tensor([0.0], requires_grad=True, device=target.device) + + return self.gan_loss_weight*gan_loss+rec_loss + + + def gan_img_loss(self, pred, target, step, discriminators, depth): + if (step > self.start_gan_train_step) and (depth self.start_gan_train_step) and (depth<2): + gan_loss = -torch.sum(discriminator[depth](pred)) # clamp(..., None, 0) => only punish areas that were rated as fake (<0) by discriminator => ensures loss >0 and +- don't cannel out in sum + lambda_weight = self.compute_lambda(rec_loss, gan_loss, dec_out_layer) + gan_loss = gan_loss*lambda_weight + + with torch.no_grad(): + self.log(f"train/gan_loss_{depth}", gan_loss, on_step=True, on_epoch=True) + self.log(f"train/lambda_{depth}", lambda_weight, on_step=True, on_epoch=True) + else: + gan_loss = 0 #torch.tensor([0.0], requires_grad=True, device=target.device) + + + + return self.gan_loss_weight*gan_loss+rec_loss + + def gan_img_loss(self, pred, target, step, discriminators, depth): + if (step > self.start_gan_train_step) and (depth1) and k==0: + seq_list.append( + BasicUp( + spatial_dims=spatial_dims, + in_channels=out_channels, + out_channels=out_channels, + kernel_size=strides[i], + stride=strides[i], + learnable_interpolation=learnable_interpolation + ) + ) + 
+ out_blocks.append(SequentialEmb(*seq_list)) + self.out_blocks = nn.ModuleList(out_blocks) + + + # --------------- Out-Convolution ---------------- + out_ch_hor = out_ch*2 if estimate_variance else out_ch + self.outc = zero_module(UnetOutBlock(spatial_dims, hid_chs[0], out_ch_hor, dropout=None)) + if isinstance(deep_supervision, bool): + deep_supervision = self.depth-2 if deep_supervision else 0 + self.outc_ver = nn.ModuleList([ + zero_module(UnetOutBlock(spatial_dims, hid_chs[i]+hid_chs[i-1], out_ch, dropout=None) ) + for i in range(2, deep_supervision+2) + ]) + + + def forward(self, x_t, t=None, condition=None, self_cond=None): + # x_t [B, C, *] + # t [B,] + # condition [B,] + # self_cond [B, C, *] + + + # -------- Time Embedding (Gloabl) ----------- + if t is None: + time_emb = None + else: + time_emb = self.time_embedder(t) # [B, C] + + # -------- Condition Embedding (Gloabl) ----------- + if (condition is None) or (self.cond_embedder is None): + cond_emb = None + else: + cond_emb = self.cond_embedder(condition) # [B, C] + + emb = save_add(time_emb, cond_emb) + + # ---------- Self-conditioning----------- + if self.use_self_conditioning: + self_cond = torch.zeros_like(x_t) if self_cond is None else x_t + x_t = torch.cat([x_t, self_cond], dim=1) + + # --------- Encoder -------------- + x = [self.in_conv(x_t)] + for i in range(len(self.in_blocks)): + x.append(self.in_blocks[i](x[i], emb)) + + # ---------- Middle -------------- + h = self.middle_block(x[-1], emb) + + # -------- Decoder ----------- + y_ver = [] + for i in range(len(self.out_blocks), 0, -1): + h = torch.cat([h, x.pop()], dim=1) + + depth, j = i//(self.num_res_blocks+1), i%(self.num_res_blocks+1)-1 + y_ver.append(self.outc_ver[depth-1](h)) if (len(self.outc_ver)>=depth>0) and (j==0) else None + + h = self.out_blocks[i-1](h, emb) + + # ---------Out-Convolution ------------ + y = self.outc(h) + + return y, y_ver[::-1] + + + + +if __name__=='__main__': + model = UNet(in_ch=3, use_res_block=False, learnable_interpolation=False) + input = torch.randn((1,3,16,32,32)) + time = torch.randn((1,)) + out_hor, out_ver = model(input, time) + print(out_hor[0].shape) \ No newline at end of file diff --git a/medical_diffusion/models/model_base.py b/medical_diffusion/models/model_base.py new file mode 100755 index 0000000000000000000000000000000000000000..1c3dd87b6d1aeef49afc73354a6ee5f2309429d4 --- /dev/null +++ b/medical_diffusion/models/model_base.py @@ -0,0 +1,114 @@ + +from pathlib import Path +import json + +import torch +import torch.nn as nn +import torch.nn.functional as F +import pytorch_lightning as pl +from pytorch_lightning.utilities.cloud_io import load as pl_load +from pytorch_lightning.utilities.migration import pl_legacy_patch + +class VeryBasicModel(pl.LightningModule): + def __init__(self): + super().__init__() + self.save_hyperparameters() + self._step_train = 0 + self._step_val = 0 + self._step_test = 0 + + + def forward(self, x_in): + raise NotImplementedError + + def _step(self, batch: dict, batch_idx: int, state: str, step: int, optimizer_idx:int): + raise NotImplementedError + + def training_step(self, batch: dict, batch_idx: int, optimizer_idx:int = 0 ): + self._step_train += 1 # =self.global_step + return self._step(batch, batch_idx, "train", self._step_train, optimizer_idx) + + def validation_step(self, batch: dict, batch_idx: int, optimizer_idx:int = 0): + self._step_val += 1 + return self._step(batch, batch_idx, "val", self._step_val, optimizer_idx ) + + def test_step(self, batch: dict, batch_idx: int, 
optimizer_idx:int = 0): + self._step_test += 1 + return self._step(batch, batch_idx, "test", self._step_test, optimizer_idx) + + def _epoch_end(self, outputs: list, state: str): + return + + def training_epoch_end(self, outputs): + self._epoch_end(outputs, "train") + + def validation_epoch_end(self, outputs): + self._epoch_end(outputs, "val") + + def test_epoch_end(self, outputs): + self._epoch_end(outputs, "test") + + @classmethod + def save_best_checkpoint(cls, path_checkpoint_dir, best_model_path): + with open(Path(path_checkpoint_dir) / 'best_checkpoint.json', 'w') as f: + json.dump({'best_model_epoch': Path(best_model_path).name}, f) + + @classmethod + def _get_best_checkpoint_path(cls, path_checkpoint_dir, version=0, **kwargs): + path_version = 'lightning_logs/version_'+str(version) + with open(Path(path_checkpoint_dir) / path_version/ 'best_checkpoint.json', 'r') as f: + path_rel_best_checkpoint = Path(json.load(f)['best_model_epoch']) + return Path(path_checkpoint_dir)/path_rel_best_checkpoint + + @classmethod + def load_best_checkpoint(cls, path_checkpoint_dir, version=0, **kwargs): + path_best_checkpoint = cls._get_best_checkpoint_path(path_checkpoint_dir, version) + return cls.load_from_checkpoint(path_best_checkpoint, **kwargs) + + def load_pretrained(self, checkpoint_path, map_location=None, **kwargs): + if checkpoint_path.is_dir(): + checkpoint_path = self._get_best_checkpoint_path(checkpoint_path, **kwargs) + + with pl_legacy_patch(): + if map_location is not None: + checkpoint = pl_load(checkpoint_path, map_location=map_location) + else: + checkpoint = pl_load(checkpoint_path, map_location=lambda storage, loc: storage) + return self.load_weights(checkpoint["state_dict"], **kwargs) + + def load_weights(self, pretrained_weights, strict=True, **kwargs): + filter = kwargs.get('filter', lambda key:key in pretrained_weights) + init_weights = self.state_dict() + pretrained_weights = {key: value for key, value in pretrained_weights.items() if filter(key)} + init_weights.update(pretrained_weights) + self.load_state_dict(init_weights, strict=strict) + return self + + + + +class BasicModel(VeryBasicModel): + def __init__(self, + optimizer=torch.optim.AdamW, + optimizer_kwargs={'lr':1e-3, 'weight_decay':1e-2}, + lr_scheduler= None, + lr_scheduler_kwargs={}, + ): + super().__init__() + self.save_hyperparameters() + self.optimizer = optimizer + self.optimizer_kwargs = optimizer_kwargs + self.lr_scheduler = lr_scheduler + self.lr_scheduler_kwargs = lr_scheduler_kwargs + + def configure_optimizers(self): + optimizer = self.optimizer(self.parameters(), **self.optimizer_kwargs) + if self.lr_scheduler is not None: + lr_scheduler = self.lr_scheduler(optimizer, **self.lr_scheduler_kwargs) + return [optimizer], [lr_scheduler] + else: + return [optimizer] + + + + \ No newline at end of file diff --git a/medical_diffusion/models/noise_schedulers/__init__.py b/medical_diffusion/models/noise_schedulers/__init__.py new file mode 100755 index 0000000000000000000000000000000000000000..8642182680db629b726918a8fb4703b68c7a7f96 --- /dev/null +++ b/medical_diffusion/models/noise_schedulers/__init__.py @@ -0,0 +1,2 @@ +from .scheduler_base import BasicNoiseScheduler +from .gaussian_scheduler import GaussianNoiseScheduler \ No newline at end of file diff --git a/medical_diffusion/models/noise_schedulers/gaussian_scheduler.py b/medical_diffusion/models/noise_schedulers/gaussian_scheduler.py new file mode 100755 index 0000000000000000000000000000000000000000..fa8b316fb0a0a6e654eddf5b5dff1914febcfa68 --- 
/dev/null +++ b/medical_diffusion/models/noise_schedulers/gaussian_scheduler.py @@ -0,0 +1,154 @@ + +import torch +import torch.nn.functional as F + + +from medical_diffusion.models.noise_schedulers import BasicNoiseScheduler + +class GaussianNoiseScheduler(BasicNoiseScheduler): + def __init__( + self, + timesteps=1000, + T = None, + schedule_strategy='cosine', + beta_start = 0.0001, # default 1e-4, stable-diffusion ~ 1e-3 + beta_end = 0.02, + betas = None, + ): + super().__init__(timesteps, T) + + self.schedule_strategy = schedule_strategy + + if betas is not None: + betas = torch.as_tensor(betas, dtype = torch.float64) + elif schedule_strategy == "linear": + betas = torch.linspace(beta_start, beta_end, timesteps, dtype = torch.float64) + elif schedule_strategy == "scaled_linear": # proposed as "quadratic" in https://arxiv.org/abs/2006.11239, used in stable-diffusion + betas = torch.linspace(beta_start**0.5, beta_end**0.5, timesteps, dtype = torch.float64)**2 + elif schedule_strategy == "cosine": + s = 0.008 + x = torch.linspace(0, timesteps, timesteps + 1, dtype = torch.float64) # [0, T] + alphas_cumprod = torch.cos(((x / timesteps) + s) / (1 + s) * torch.pi * 0.5) ** 2 + alphas_cumprod = alphas_cumprod / alphas_cumprod[0] + betas = 1 - (alphas_cumprod[1:] / alphas_cumprod[:-1]) + betas = torch.clip(betas, 0, 0.999) + else: + raise NotImplementedError(f"{schedule_strategy} does is not implemented for {self.__class__}") + + + alphas = 1-betas + alphas_cumprod = torch.cumprod(alphas, dim=0) + alphas_cumprod_prev = F.pad(alphas_cumprod[:-1], (1, 0), value = 1.) + + + register_buffer = lambda name, val: self.register_buffer(name, val.to(torch.float32)) + + register_buffer('betas', betas) # (0 , 1) + + register_buffer('alphas', alphas) # (1 , 0) + register_buffer('alphas_cumprod', alphas_cumprod) + register_buffer('alphas_cumprod_prev', alphas_cumprod_prev) + register_buffer('sqrt_alphas_cumprod', torch.sqrt(alphas_cumprod)) + register_buffer('sqrt_one_minus_alphas_cumprod', torch.sqrt(1. - alphas_cumprod)) + register_buffer('sqrt_recip_alphas_cumprod', torch.sqrt(1. / alphas_cumprod)) + register_buffer('sqrt_recipm1_alphas_cumprod', torch.sqrt(1. / alphas_cumprod - 1)) + + register_buffer('posterior_mean_coef1', betas * torch.sqrt(alphas_cumprod_prev) / (1. - alphas_cumprod)) + register_buffer('posterior_mean_coef2', (1. - alphas_cumprod_prev) * torch.sqrt(alphas) / (1. - alphas_cumprod)) + register_buffer('posterior_variance', betas * (1. - alphas_cumprod_prev) / (1. 
- alphas_cumprod)) + + + def estimate_x_t(self, x_0, t, x_T=None): + # NOTE: t == 0 means diffused for 1 step (https://github.com/hojonathanho/diffusion/blob/1e0dceb3b3495bbe19116a5e1b3596cd0706c543/diffusion_tf/diffusion_utils.py#L108) + # NOTE: t == 0 means not diffused for cold-diffusion (in contradiction to the above comment) https://github.com/arpitbansal297/Cold-Diffusion-Models/blob/c828140b7047ca22f995b99fbcda360bc30fc25d/denoising-diffusion-pytorch/denoising_diffusion_pytorch/denoising_diffusion_pytorch.py#L361 + x_T = self.x_final(x_0) if x_T is None else x_T + # ndim = x_0.ndim + # x_t = (self.extract(self.sqrt_alphas_cumprod, t, ndim)*x_0 + + # self.extract(self.sqrt_one_minus_alphas_cumprod, t, ndim)*x_T) + def clipper(b): + tb = t[b] + if tb<0: + return x_0[b] + elif tb>=self.T: + return x_T[b] + else: + return self.sqrt_alphas_cumprod[tb]*x_0[b]+self.sqrt_one_minus_alphas_cumprod[tb]*x_T[b] + x_t = torch.stack([clipper(b) for b in range(t.shape[0])]) + return x_t + + + def estimate_x_t_prior_from_x_T(self, x_t, t, x_T, use_log=True, clip_x0=True, var_scale=0, cold_diffusion=False): + x_0 = self.estimate_x_0(x_t, x_T, t, clip_x0) + return self.estimate_x_t_prior_from_x_0(x_t, t, x_0, use_log, clip_x0, var_scale, cold_diffusion) + + + def estimate_x_t_prior_from_x_0(self, x_t, t, x_0, use_log=True, clip_x0=True, var_scale=0, cold_diffusion=False): + x_0 = self._clip_x_0(x_0) if clip_x0 else x_0 + + if cold_diffusion: # see https://arxiv.org/abs/2208.09392 + x_T_est = self.estimate_x_T(x_t, x_0, t) # or use x_T estimated by UNet if available? + x_t_est = self.estimate_x_t(x_0, t, x_T=x_T_est) + x_t_prior = self.estimate_x_t(x_0, t-1, x_T=x_T_est) + noise_t = x_t_est-x_t_prior + x_t_prior = x_t-noise_t + else: + mean = self.estimate_mean_t(x_t, x_0, t) + variance = self.estimate_variance_t(t, x_t.ndim, use_log, var_scale) + std = torch.exp(0.5*variance) if use_log else torch.sqrt(variance) + std[t==0] = 0.0 + x_T = self.x_final(x_t) + x_t_prior = mean+std*x_T + return x_t_prior, x_0 + + + def estimate_mean_t(self, x_t, x_0, t): + ndim = x_t.ndim + return (self.extract(self.posterior_mean_coef1, t, ndim)*x_0+ + self.extract(self.posterior_mean_coef2, t, ndim)*x_t) + + + def estimate_variance_t(self, t, ndim, log=True, var_scale=0, eps=1e-20): + min_variance = self.extract(self.posterior_variance, t, ndim) + max_variance = self.extract(self.betas, t, ndim) + if log: + min_variance = torch.log(min_variance.clamp(min=eps)) + max_variance = torch.log(max_variance.clamp(min=eps)) + return var_scale * max_variance + (1 - var_scale) * min_variance + + + def estimate_x_0(self, x_t, x_T, t, clip_x0=True): + ndim = x_t.ndim + x_0 = (self.extract(self.sqrt_recip_alphas_cumprod, t, ndim)*x_t - + self.extract(self.sqrt_recipm1_alphas_cumprod, t, ndim)*x_T) + x_0 = self._clip_x_0(x_0) if clip_x0 else x_0 + return x_0 + + + def estimate_x_T(self, x_t, x_0, t, clip_x0=True): + ndim = x_t.ndim + x_0 = self._clip_x_0(x_0) if clip_x0 else x_0 + return ((self.extract(self.sqrt_recip_alphas_cumprod, t, ndim)*x_t - x_0)/ + self.extract(self.sqrt_recipm1_alphas_cumprod, t, ndim)) + + + @classmethod + def x_final(cls, x): + return torch.randn_like(x) + + @classmethod + def _clip_x_0(cls, x_0): + # See "static/dynamic thresholding" in Imagen https://arxiv.org/abs/2205.11487 + + # "static thresholding" + m = 1 # Set this to about 4*sigma = 4 if latent diffusion is used + x_0 = x_0.clamp(-m, m) + + # "dynamic thresholding" + # r = torch.stack([torch.quantile(torch.abs(x_0_b), 0.997) for x_0_b in x_0]) + # 
r = torch.maximum(r, torch.full_like(r,m)) + # x_0 = torch.stack([x_0_b.clamp(-r_b, r_b)/r_b*m for x_0_b, r_b in zip(x_0, r) ] ) + + return x_0 + + + diff --git a/medical_diffusion/models/noise_schedulers/scheduler_base.py b/medical_diffusion/models/noise_schedulers/scheduler_base.py new file mode 100755 index 0000000000000000000000000000000000000000..1fbd790bf8e5af079991e91b1f128f902a15802f --- /dev/null +++ b/medical_diffusion/models/noise_schedulers/scheduler_base.py @@ -0,0 +1,49 @@ + + +import torch +import torch.nn as nn + + +class BasicNoiseScheduler(nn.Module): + def __init__( + self, + timesteps=1000, + T=None, + ): + super().__init__() + self.timesteps = timesteps + self.T = timesteps if T is None else T + + self.register_buffer('timesteps_array', torch.linspace(0, self.T-1, self.timesteps, dtype=torch.long)) # NOTE: End is inclusive therefore use -1 to get [0, T-1] + + def __len__(self): + return len(self.timesteps) + + def sample(self, x_0): + """Randomly sample t from [0,T] and return x_t and x_T based on x_0""" + t = torch.randint(0, self.T, (x_0.shape[0],), dtype=torch.long, device=x_0.device) # NOTE: High is exclusive, therefore [0, T-1] + x_T = self.x_final(x_0) + return self.estimate_x_t(x_0, t, x_T), x_T, t + + def estimate_x_t_prior_from_x_T(self, x_T, t, **kwargs): + raise NotImplemented + + def estimate_x_t_prior_from_x_0(self, x_0, t, **kwargs): + raise NotImplemented + + def estimate_x_t(self, x_0, t, x_T=None, **kwargs): + """Get x_t at time t""" + raise NotImplemented + + @classmethod + def x_final(cls, x): + """Get noise that should be obtained for t->T """ + raise NotImplemented + + @staticmethod + def extract(x, t, ndim): + """Extract values from x at t and reshape them to n-dim tensor""" + return x.gather(0, t).reshape(-1, *((1,)*(ndim-1))) + + + diff --git a/medical_diffusion/models/pipelines/__init__.py b/medical_diffusion/models/pipelines/__init__.py new file mode 100755 index 0000000000000000000000000000000000000000..9dcccb7eb9077e438099e5c9196109a6950f5fba --- /dev/null +++ b/medical_diffusion/models/pipelines/__init__.py @@ -0,0 +1 @@ +from .diffusion_pipeline import DiffusionPipeline diff --git a/medical_diffusion/models/pipelines/diffusion_pipeline.py b/medical_diffusion/models/pipelines/diffusion_pipeline.py new file mode 100755 index 0000000000000000000000000000000000000000..0b27d6143faaaf4e71f583cec3b33312adeb170a --- /dev/null +++ b/medical_diffusion/models/pipelines/diffusion_pipeline.py @@ -0,0 +1,348 @@ + + +from pathlib import Path +from tqdm import tqdm + +import torch +import torch.nn.functional as F +from torchvision.utils import save_image +import streamlit as st + +from medical_diffusion.models import BasicModel +from medical_diffusion.utils.train_utils import EMAModel +from medical_diffusion.utils.math_utils import kl_gaussians + + + + + + +class DiffusionPipeline(BasicModel): + def __init__(self, + noise_scheduler, + noise_estimator, + latent_embedder=None, + noise_scheduler_kwargs={}, + noise_estimator_kwargs={}, + latent_embedder_checkpoint='', + estimator_objective = 'x_T', # 'x_T' or 'x_0' + estimate_variance=False, + use_self_conditioning=False, + classifier_free_guidance_dropout=0.5, # Probability to drop condition during training, has only an effect for label-conditioned training + num_samples = 4, + do_input_centering = True, # Only for training + clip_x0=True, # Has only an effect during traing if use_self_conditioning=True, import for inference/sampling + use_ema = False, + ema_kwargs = {}, + optimizer=torch.optim.AdamW, + 
optimizer_kwargs={'lr':1e-4}, # stable-diffusion ~ 1e-4 + lr_scheduler= None, # stable-diffusion - LambdaLR + lr_scheduler_kwargs={}, + loss=torch.nn.L1Loss, + loss_kwargs={}, + sample_every_n_steps = 1000 + ): + # self.save_hyperparameters(ignore=['noise_estimator', 'noise_scheduler']) + super().__init__(optimizer, optimizer_kwargs, lr_scheduler, lr_scheduler_kwargs) + self.loss_fct = loss(**loss_kwargs) + self.sample_every_n_steps=sample_every_n_steps + + noise_estimator_kwargs['estimate_variance'] = estimate_variance + noise_estimator_kwargs['use_self_conditioning'] = use_self_conditioning + + self.noise_scheduler = noise_scheduler(**noise_scheduler_kwargs) + self.noise_estimator = noise_estimator(**noise_estimator_kwargs) + + with torch.no_grad(): + if latent_embedder is not None: + self.latent_embedder = latent_embedder.load_from_checkpoint(latent_embedder_checkpoint) + for param in self.latent_embedder.parameters(): + param.requires_grad = False + else: + self.latent_embedder = None + + self.estimator_objective = estimator_objective + self.use_self_conditioning = use_self_conditioning + self.num_samples = num_samples + self.classifier_free_guidance_dropout = classifier_free_guidance_dropout + self.do_input_centering = do_input_centering + self.estimate_variance = estimate_variance + self.clip_x0 = clip_x0 + + self.use_ema = use_ema + if use_ema: + self.ema_model = EMAModel(self.noise_estimator, **ema_kwargs) + + + + def _step(self, batch: dict, batch_idx: int, state: str, step: int, optimizer_idx:int): + results = {} + x_0 = batch['source'] + condition = batch.get('target', None) + + # Embed into latent space or normalize + if self.latent_embedder is not None: + self.latent_embedder.eval() + with torch.no_grad(): + x_0 = self.latent_embedder.encode(x_0) + + if self.do_input_centering: + x_0 = 2*x_0-1 # [0, 1] -> [-1, 1] + + # if self.clip_x0: + # x_0 = torch.clamp(x_0, -1, 1) + + + # Sample Noise + with torch.no_grad(): + # Randomly selecting t [0,T-1] and compute x_t (noisy version of x_0 at t) + x_t, x_T, t = self.noise_scheduler.sample(x_0) + + # Use EMA Model + if self.use_ema and (state != 'train'): + noise_estimator = self.ema_model.averaged_model + else: + noise_estimator = self.noise_estimator + + # Re-estimate x_T or x_0, self-conditioned on previous estimate + self_cond = None + if self.use_self_conditioning: + with torch.no_grad(): + pred, pred_vertical = noise_estimator(x_t, t, condition, None) + if self.estimate_variance: + pred, _ = pred.chunk(2, dim = 1) # Seperate actual prediction and variance estimation + if self.estimator_objective == "x_T": # self condition on x_0 + self_cond = self.noise_scheduler.estimate_x_0(x_t, pred, t=t, clip_x0=self.clip_x0) + elif self.estimator_objective == "x_0": # self condition on x_T + self_cond = self.noise_scheduler.estimate_x_T(x_t, pred, t=t, clip_x0=self.clip_x0) + else: + raise NotImplementedError(f"Option estimator_target={self.estimator_objective} not supported.") + + # Classifier free guidance + if torch.rand(1) [0, 1] + pred_logvar = self.noise_scheduler.estimate_variance_t(t, x_t.ndim, log=True, var_scale=var_scale) + # pred_logvar = pred_var # If variance is estimated directly + + if self.estimator_objective == 'x_T': + pred_x_0 = self.noise_scheduler.estimate_x_0(x_t, x_T, t, clip_x0=self.clip_x0) + elif self.estimator_objective == "x_0": + pred_x_0 = pred + else: + raise NotImplementedError() + + with torch.no_grad(): + pred_mean = self.noise_scheduler.estimate_mean_t(x_t, pred_x_0, t) + true_mean = 
self.noise_scheduler.estimate_mean_t(x_t, x_0, t) + true_logvar = self.noise_scheduler.estimate_variance_t(t, x_t.ndim, log=True, var_scale=0) + + kl_loss = torch.mean(kl_gaussians(true_mean, true_logvar, pred_mean, pred_logvar), dim=list(range(1, x_0.ndim))) + nnl_loss = torch.mean(F.gaussian_nll_loss(pred_x_0, x_0, torch.exp(pred_logvar), reduction='none'), dim=list(range(1, x_0.ndim))) + var_loss = torch.mean(torch.where(t == 0, nnl_loss, kl_loss)) + loss += var_loss + + results['variance_scale'] = torch.mean(var_scale) + results['variance_loss'] = var_loss + + + # ----------------------------- Deep Supervision ------------------------- + for i, pred_i in enumerate(pred_vertical): + target_i = F.interpolate(target, size=pred_i.shape[2:], mode=interpolation_mode, align_corners=None) + loss += self.loss_fct(pred_i, target_i)*weights[i+1] + results['loss'] = loss + + + + # --------------------- Compute Metrics ------------------------------- + with torch.no_grad(): + results['L2'] = F.mse_loss(pred, target) + results['L1'] = F.l1_loss(pred, target) + # results['SSIM'] = SSIMMetric(data_range=pred.max()-pred.min(), spatial_dims=source.ndim-2)(pred, target) + + # for i, pred_i in enumerate(pred_vertical): + # target_i = F.interpolate(target, size=pred_i.shape[2:], mode=interpolation_mode, align_corners=None) + # results[f'L1_{i}'] = F.l1_loss(pred_i, target_i).detach() + + + + # ----------------- Log Scalars ---------------------- + for metric_name, metric_val in results.items(): + self.log(f"{state}/{metric_name}", metric_val, batch_size=x_0.shape[0], on_step=True, on_epoch=True) + + + #------------------ Log Image ----------------------- + if self.global_step != 0 and self.global_step % self.sample_every_n_steps == 0: + dataformats = 'NHWC' if x_0.ndim == 5 else 'HWC' + def norm(x): + return (x-x.min())/(x.max()-x.min()) + + sample_cond = condition[0:self.num_samples] if condition is not None else None + sample_img = self.sample(num_samples=self.num_samples, img_size=x_0.shape[1:], condition=sample_cond).detach() + + log_step = self.global_step // self.sample_every_n_steps + # self.logger.experiment.add_images("predict_img", norm(torch.moveaxis(pred[0,-1:], 0,-1)), global_step=self.current_epoch, dataformats=dataformats) + # self.logger.experiment.add_images("target_img", norm(torch.moveaxis(target[0,-1:], 0,-1)), global_step=self.current_epoch, dataformats=dataformats) + + # self.logger.experiment.add_images("source_img", norm(torch.moveaxis(x_0[0,-1:], 0,-1)), global_step=log_step, dataformats=dataformats) + # self.logger.experiment.add_images("sample_img", norm(torch.moveaxis(sample_img[0,-1:], 0,-1)), global_step=log_step, dataformats=dataformats) + + path_out = Path(self.logger.log_dir)/'images' + path_out.mkdir(parents=True, exist_ok=True) + # for 3D images use depth as batch :[D, C, H, W], never show more than 32 images + def depth2batch(image): + return (image if image.ndim<5 else torch.swapaxes(image[0], 0, 1)) + images = depth2batch(sample_img)[:32] + save_image(images, path_out/f'sample_{log_step}.png', normalize=True) + + + return loss + + + def forward(self, x_t, t, condition=None, self_cond=None, guidance_scale=1.0, cold_diffusion=False, un_cond=None): + # Note: x_t expected to be in range ~ [-1, 1] + if self.use_ema: + noise_estimator = self.ema_model.averaged_model + else: + noise_estimator = self.noise_estimator + + # Concatenate inputs for guided and unguided diffusion as proposed by classifier-free-guidance + if (condition is not None) and (guidance_scale != 1.0): + # 
Model prediction + pred_uncond, _ = noise_estimator(x_t, t, condition=un_cond, self_cond=self_cond) + pred_cond, _ = noise_estimator(x_t, t, condition=condition, self_cond=self_cond) + pred = pred_uncond + guidance_scale * (pred_cond - pred_uncond) + + if self.estimate_variance: + pred_uncond, pred_var_uncond = pred_uncond.chunk(2, dim = 1) + pred_cond, pred_var_cond = pred_cond.chunk(2, dim = 1) + pred_var = pred_var_uncond + guidance_scale * (pred_var_cond - pred_var_uncond) + else: + pred, _ = noise_estimator(x_t, t, condition=condition, self_cond=self_cond) + if self.estimate_variance: + pred, pred_var = pred.chunk(2, dim = 1) + + if self.estimate_variance: + pred_var_scale = pred_var/2+0.5 # [-1, 1] -> [0, 1] + pred_var_value = pred_var + else: + pred_var_scale = 0 + pred_var_value = None + + # pred_var_scale = pred_var_scale.clamp(0, 1) + + if self.estimator_objective == 'x_0': + x_t_prior, x_0 = self.noise_scheduler.estimate_x_t_prior_from_x_0(x_t, t, pred, clip_x0=self.clip_x0, var_scale=pred_var_scale, cold_diffusion=cold_diffusion) + x_T = self.noise_scheduler.estimate_x_T(x_t, x_0=pred, t=t, clip_x0=self.clip_x0) + self_cond = x_T + elif self.estimator_objective == 'x_T': + x_t_prior, x_0 = self.noise_scheduler.estimate_x_t_prior_from_x_T(x_t, t, pred, clip_x0=self.clip_x0, var_scale=pred_var_scale, cold_diffusion=cold_diffusion) + x_T = pred + self_cond = x_0 + else: + raise ValueError("Unknown Objective") + + return x_t_prior, x_0, x_T, self_cond + + + @torch.no_grad() + def denoise(self, x_t, steps=None, condition=None, use_ddim=True, **kwargs): + self_cond = None + + # ---------- run denoise loop --------------- + if use_ddim: + steps = self.noise_scheduler.timesteps if steps is None else steps + timesteps_array = torch.linspace(0, self.noise_scheduler.T-1, steps, dtype=torch.long, device=x_t.device) # [0, 1, 2, ..., T-1] if steps = T + else: + timesteps_array = self.noise_scheduler.timesteps_array[slice(0, steps)] # [0, ...,T-1] (target time not time of x_t) + + st_prog_bar = st.progress(0) + for i, t in tqdm(enumerate(reversed(timesteps_array))): + st_prog_bar.progress((i+1)/len(timesteps_array)) + + # UNet prediction + x_t, x_0, x_T, self_cond = self(x_t, t.expand(x_t.shape[0]), condition, self_cond=self_cond, **kwargs) + self_cond = self_cond if self.use_self_conditioning else None + + if use_ddim and (steps-i-1>0): + t_next = timesteps_array[steps-i-2] + alpha = self.noise_scheduler.alphas_cumprod[t] + alpha_next = self.noise_scheduler.alphas_cumprod[t_next] + sigma = kwargs.get('eta', 1) * ((1 - alpha / alpha_next) * (1 - alpha_next) / (1 - alpha)).sqrt() + c = (1 - alpha_next - sigma ** 2).sqrt() + noise = torch.randn_like(x_t) + x_t = x_0 * alpha_next.sqrt() + c * x_T + sigma * noise + + # ------ Eventually decode from latent space into image space-------- + if self.latent_embedder is not None: + x_t = self.latent_embedder.decode(x_t) + + return x_t # Should be x_0 in final step (t=0) + + @torch.no_grad() + def sample(self, num_samples, img_size, condition=None, **kwargs): + template = torch.zeros((num_samples, *img_size), device=self.device) + x_T = self.noise_scheduler.x_final(template) + x_0 = self.denoise(x_T, condition=condition, **kwargs) + return x_0 + + + @torch.no_grad() + def interpolate(self, img1, img2, i = None, condition=None, lam = 0.5, **kwargs): + assert img1.shape == img2.shape, "Image 1 and 2 must have equal shape" + + t = self.noise_scheduler.T-1 if i is None else i + t = torch.full(img1.shape[:1], i, device=img1.device) + + img1_t = 
self.noise_scheduler.estimate_x_t(img1, t=t, clip_x0=self.clip_x0) + img2_t = self.noise_scheduler.estimate_x_t(img2, t=t, clip_x0=self.clip_x0) + + img = (1 - lam) * img1_t + lam * img2_t + img = self.denoise(img, i, condition, **kwargs) + return img + + def on_train_batch_end(self, *args, **kwargs): + if self.use_ema: + self.ema_model.step(self.noise_estimator) + + def configure_optimizers(self): + optimizer = self.optimizer(self.noise_estimator.parameters(), **self.optimizer_kwargs) + if self.lr_scheduler is not None: + lr_scheduler = { + 'scheduler': self.lr_scheduler(optimizer, **self.lr_scheduler_kwargs), + 'interval': 'step', + 'frequency': 1 + } + return [optimizer], [lr_scheduler] + else: + return [optimizer] \ No newline at end of file diff --git a/medical_diffusion/models/utils/__init__.py b/medical_diffusion/models/utils/__init__.py new file mode 100755 index 0000000000000000000000000000000000000000..10ea9d8e4f14b8599dd72228142b06e54740d706 --- /dev/null +++ b/medical_diffusion/models/utils/__init__.py @@ -0,0 +1,2 @@ +from .attention_blocks import * +from .conv_blocks import * \ No newline at end of file diff --git a/medical_diffusion/models/utils/attention_blocks.py b/medical_diffusion/models/utils/attention_blocks.py new file mode 100755 index 0000000000000000000000000000000000000000..b609017118cf875bf31cc4c5302ecd4343e47e41 --- /dev/null +++ b/medical_diffusion/models/utils/attention_blocks.py @@ -0,0 +1,335 @@ +import torch.nn.functional as F +import torch.nn as nn +import torch + +from monai.networks.blocks import TransformerBlock +from monai.networks.layers.utils import get_norm_layer, get_dropout_layer +from monai.networks.layers.factories import Conv +from einops import rearrange + + +class GEGLU(nn.Module): + def __init__(self, in_channels, out_channels): + super().__init__() + self.norm = nn.LayerNorm(in_channels) + self.proj = nn.Linear(in_channels, out_channels*2, bias=True) + + def forward(self, x): + # x expected to be [B, C, *] + # Workaround as layer norm can't currently be applied on arbitrary dimension: https://github.com/pytorch/pytorch/issues/71465 + b, c, *spatial = x.shape + x = x.reshape(b, c, -1).transpose(1, 2) # -> [B, C, N] -> [B, N, C] + x = self.norm(x) + x, gate = self.proj(x).chunk(2, dim=-1) + x = x * F.gelu(gate) + return x.transpose(1, 2).reshape(b, -1, *spatial) # -> [B, C, N] -> [B, C, *] + +def zero_module(module): + """ + Zero out the parameters of a module and return it. + """ + for p in module.parameters(): + p.detach().zero_() + return module + +def compute_attention(q,k,v , num_heads, scale): + q, k, v = map(lambda t: rearrange(t, 'b (h d) n -> (b h) d n', h=num_heads), (q, k, v)) # [(BxHeads), Dim_per_head, N] + + attn = (torch.einsum('b d i, b d j -> b i j', q*scale, k*scale)).softmax(dim=-1) # Matrix product = [(BxHeads), Dim_per_head, N] * [(BxHeads), Dim_per_head, N'] =[(BxHeads), N, N'] + + out = torch.einsum('b i j, b d j-> b d i', attn, v) # Matrix product: [(BxHeads), N, N'] * [(BxHeads), Dim_per_head, N'] = [(BxHeads), Dim_per_head, N] + out = rearrange(out, '(b h) d n-> b (h d) n', h=num_heads) # -> [B, (Heads x Dim_per_head), N] + + return out + + +class LinearTransformerNd(nn.Module): + """ Combines multi-head self-attention and multi-head cross-attention. 
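+    Input is expected as [B, C, *spatial]. If no `embedding` is passed to `forward`,
+    keys and values are taken from the normalized input itself, i.e. the layer reduces
+    to plain multi-head self-attention; otherwise the embedding supplies keys/values
+    (cross-attention). Note (an observation from the reshape logic below, not a documented
+    contract): the flatten/reshape steps appear to line up only when num_heads*ch_per_head
+    equals in_channels (and emb_dim), so this Nd variant is effectively restricted to that setting.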
+ + Multi-Head Self-Attention: + Similar to multi-head self-attention (https://arxiv.org/abs/1706.03762) without Norm+MLP (compare Monai TransformerBlock) + Proposed here: https://github.com/hojonathanho/diffusion/blob/1e0dceb3b3495bbe19116a5e1b3596cd0706c543/diffusion_tf/models/unet.py#L66. + Similar to: https://github.com/CompVis/stable-diffusion/blob/69ae4b35e0a0f6ee1af8bb9a5d0016ccb27e36dc/ldm/modules/diffusionmodules/openaimodel.py#L278 + Similar to: https://github.com/CompVis/stable-diffusion/blob/69ae4b35e0a0f6ee1af8bb9a5d0016ccb27e36dc/ldm/modules/attention.py#L80 + Similar to: https://github.com/lucidrains/denoising-diffusion-pytorch/blob/dfbafee555bdae80b55d63a989073836bbfc257e/denoising_diffusion_pytorch/denoising_diffusion_pytorch.py#L209 + Similar to: https://github.com/CompVis/stable-diffusion/blob/21f890f9da3cfbeaba8e2ac3c425ee9e998d5229/ldm/modules/diffusionmodules/model.py#L150 + + CrossAttention: + Proposed here: https://github.com/CompVis/stable-diffusion/blob/69ae4b35e0a0f6ee1af8bb9a5d0016ccb27e36dc/ldm/modules/attention.py#L152 + + """ + def __init__( + self, + spatial_dims, + in_channels, + out_channels, # WARNING: if out_channels != in_channels, skip connection is disabled + num_heads=8, + ch_per_head=32, # rule of thumb: 32 or 64 channels per head (see stable-diffusion / diffusion models beat GANs) + norm_name=("GROUP", {'num_groups':32, "affine": True}), # Or use LayerNorm but be aware of https://github.com/pytorch/pytorch/issues/71465 (=> GroupNorm with num_groups=1) + dropout=None, + emb_dim=None, + ): + super().__init__() + hid_channels = num_heads*ch_per_head + self.num_heads = num_heads + self.scale = ch_per_head**-0.25 # Should be 1/sqrt("queries and keys of dimension"), Note: additional sqrt needed as it follows OpenAI: (q * scale) * (k * scale) instead of (q *k) * scale + + self.norm_x = get_norm_layer(norm_name, spatial_dims=spatial_dims, channels=in_channels) + emb_dim = in_channels if emb_dim is None else emb_dim + + Convolution = Conv["conv", spatial_dims] + self.to_q = Convolution(in_channels, hid_channels, 1) + self.to_k = Convolution(emb_dim, hid_channels, 1) + self.to_v = Convolution(emb_dim, hid_channels, 1) + + self.to_out = nn.Sequential( + zero_module(Convolution(hid_channels, out_channels, 1)), + nn.Identity() if dropout is None else get_dropout_layer(name=dropout, dropout_dim=spatial_dims) + ) + + def forward(self, x, embedding=None): + # x expected to be [B, C, *] and embedding is None or [B, C*] or [B, C*, *] + # if no embedding is given, cross-attention defaults to self-attention + + # Normalize + b, c, *spatial = x.shape + x_n = self.norm_x(x) + + # Attention: embedding (cross-attention) or x (self-attention) + if embedding is None: + embedding = x_n # WARNING: This assumes that emb_dim==in_channels + else: + if embedding.ndim == 2: + embedding = embedding.reshape(*embedding.shape[:2], *[1]*(x.ndim-2)) # [B, C*] -> [B, C*, *] + # Why no normalization for embedding here? 
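+        # Note: only x is normalized (via norm_x above); an externally supplied embedding is
+        # projected to keys/values as-is, without a normalization layer of its own
+        # (if embedding is None, the already-normalized x_n is reused).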
+ + # Convolution + q = self.to_q(x_n) # -> [B, (Heads x Dim_per_head), *] + k = self.to_k(embedding) # -> [B, (Heads x Dim_per_head), *] + v = self.to_v(embedding) # -> [B, (Heads x Dim_per_head), *] + + # Flatten + q = q.reshape(b, c, -1) # -> [B, (Heads x Dim_per_head), N] + k = k.reshape(*embedding.shape[:2], -1) # -> [B, (Heads x Dim_per_head), N'] + v = v.reshape(*embedding.shape[:2], -1) # -> [B, (Heads x Dim_per_head), N'] + + # Apply attention + out = compute_attention(q, k, v, self.num_heads, self.scale) + + out = out.reshape(*out.shape[:2], *spatial) # -> [B, (Heads x Dim_per_head), *] + out = self.to_out(out) # -> [B, C', *] + + + if x.shape == out.shape: + out = x + out + return out # [B, C', *] + + +class LinearTransformer(nn.Module): + """ See LinearTransformer, however this implementation is fixed to Conv1d/Linear""" + def __init__( + self, + spatial_dims, + in_channels, + out_channels, # WARNING: if out_channels != in_channels, skip connection is disabled + num_heads, + ch_per_head=32, # rule of thumb: 32 or 64 channels per head (see stable-diffusion / diffusion models beat GANs) + norm_name=("GROUP", {'num_groups':32, "affine": True}), + dropout=None, + emb_dim=None + ): + super().__init__() + hid_channels = num_heads*ch_per_head + self.num_heads = num_heads + self.scale = ch_per_head**-0.25 # Should be 1/sqrt("queries and keys of dimension"), Note: additional sqrt needed as it follows OpenAI: (q * scale) * (k * scale) instead of (q *k) * scale + + self.norm_x = get_norm_layer(norm_name, spatial_dims=spatial_dims, channels=in_channels) + emb_dim = in_channels if emb_dim is None else emb_dim + + # Note: Conv1d and Linear are interchangeable but order of input changes [B, C, N] <-> [B, N, C] + self.to_q = nn.Conv1d(in_channels, hid_channels, 1) + self.to_k = nn.Conv1d(emb_dim, hid_channels, 1) + self.to_v = nn.Conv1d(emb_dim, hid_channels, 1) + # self.to_qkv = nn.Conv1d(emb_dim, hid_channels*3, 1) + + self.to_out = nn.Sequential( + zero_module(nn.Conv1d(hid_channels, out_channels, 1)), + nn.Identity() if dropout is None else get_dropout_layer(name=dropout, dropout_dim=spatial_dims) + ) + + def forward(self, x, embedding=None): + # x expected to be [B, C, *] and embedding is None or [B, C*] or [B, C*, *] + # if no embedding is given, cross-attention defaults to self-attention + + # Normalize + b, c, *spatial = x.shape + x_n = self.norm_x(x) + + # Attention: embedding (cross-attention) or x (self-attention) + if embedding is None: + embedding = x_n # WARNING: This assumes that emb_dim==in_channels + else: + if embedding.ndim == 2: + embedding = embedding.reshape(*embedding.shape[:2], *[1]*(x.ndim-2)) # [B, C*] -> [B, C*, *] + # Why no normalization for embedding here? 
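+        # Queries are computed from x while keys/values come from `embedding`, so the token
+        # counts N (queries) and N' (keys/values) may differ; the output keeps N tokens.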
+ + # Flatten + x_n = x_n.reshape(b, c, -1) # [B, C, *] -> [B, C, N] + embedding = embedding.reshape(*embedding.shape[:2], -1) # [B, C*, *] -> [B, C*, N'] + + # Convolution + q = self.to_q(x_n) # -> [B, (Heads x Dim_per_head), N] + k = self.to_k(embedding) # -> [B, (Heads x Dim_per_head), N'] + v = self.to_v(embedding) # -> [B, (Heads x Dim_per_head), N'] + # qkv = self.to_qkv(x_n) + # q,k,v = qkv.split(qkv.shape[1]//3, dim=1) + + # Apply attention + out = compute_attention(q, k, v, self.num_heads, self.scale) + + out = self.to_out(out) # -> [B, C', N] + out = out.reshape(*out.shape[:2], *spatial) # -> [B, C', *] + + if x.shape == out.shape: + out = x + out + return out # [B, C', *] + + + + +class BasicTransformerBlock(nn.Module): + def __init__( + self, + spatial_dims, + in_channels, + out_channels, # WARNING: if out_channels != in_channels, skip connection is disabled + num_heads, + ch_per_head=32, + norm_name=("GROUP", {'num_groups':32, "affine": True}), + dropout=None, + emb_dim=None + ): + super().__init__() + self.self_atn = LinearTransformer(spatial_dims, in_channels, in_channels, num_heads, ch_per_head, norm_name, dropout, None) + if emb_dim is not None: + self.cros_atn = LinearTransformer(spatial_dims, in_channels, in_channels, num_heads, ch_per_head, norm_name, dropout, emb_dim) + self.proj_out = nn.Sequential( + GEGLU(in_channels, in_channels*4), + nn.Identity() if dropout is None else get_dropout_layer(name=dropout, dropout_dim=spatial_dims), + Conv["conv", spatial_dims](in_channels*4, out_channels, 1, bias=True) + ) + + + def forward(self, x, embedding=None): + # x expected to be [B, C, *] and embedding is None or [B, C*] or [B, C*, *] + x = self.self_atn(x) + if embedding is not None: + x = self.cros_atn(x, embedding=embedding) + out = self.proj_out(x) + if out.shape[1] == x.shape[1]: + return out + x + return x + +class SpatialTransformer(nn.Module): + """ Proposed here: https://github.com/CompVis/stable-diffusion/blob/69ae4b35e0a0f6ee1af8bb9a5d0016ccb27e36dc/ldm/modules/attention.py#L218 + Unrelated to: https://arxiv.org/abs/1506.02025 + """ + def __init__( + self, + spatial_dims, + in_channels, + out_channels, # WARNING: if out_channels != in_channels, skip connection is disabled + num_heads, + ch_per_head=32, # rule of thumb: 32 or 64 channels per head (see stable-diffusion / diffusion models beat GANs) + norm_name = ("GROUP", {'num_groups':32, "affine": True}), + dropout=None, + emb_dim=None, + depth=1 + ): + super().__init__() + self.in_channels = in_channels + self.norm = get_norm_layer(norm_name, spatial_dims=spatial_dims, channels=in_channels) + conv_class = Conv["conv", spatial_dims] + hid_channels = num_heads*ch_per_head + + self.proj_in = conv_class( + in_channels, + hid_channels, + kernel_size=1, + stride=1, + padding=0, + ) + + self.transformer_blocks = nn.ModuleList([ + BasicTransformerBlock(spatial_dims, hid_channels, hid_channels, num_heads, ch_per_head, norm_name, dropout=dropout, emb_dim=emb_dim) + for _ in range(depth)] + ) + + self.proj_out = conv_class( # Note: zero_module is used in original code + hid_channels, + out_channels, + kernel_size=1, + stride=1, + padding=0, + ) + + def forward(self, x, embedding=None): + # x expected to be [B, C, *] and embedding is None or [B, C*] or [B, C*, *] + # Note: if no embedding is given, cross-attention is disabled + h = self.norm(x) + h = self.proj_in(h) + + for block in self.transformer_blocks: + h = block(h, embedding=embedding) + + h = self.proj_out(h) # -> [B, C'', *] + if h.shape == x.shape: + return h + x + 
return h + + +class Attention(nn.Module): + def __init__( + self, + spatial_dims, + in_channels, + out_channels, + num_heads=8, + ch_per_head=32, # rule of thumb: 32 or 64 channels per head (see stable-diffusion / diffusion models beat GANs) + norm_name = ("GROUP", {'num_groups':32, "affine": True}), + dropout=0, + emb_dim=None, + depth=1, + attention_type='linear' + ) -> None: + super().__init__() + if attention_type == 'spatial': + self.attention = SpatialTransformer( + spatial_dims=spatial_dims, + in_channels=in_channels, + out_channels=out_channels, + num_heads=num_heads, + ch_per_head=ch_per_head, + depth=depth, + norm_name=norm_name, + dropout=dropout, + emb_dim=emb_dim + ) + elif attention_type == 'linear': + self.attention = LinearTransformer( + spatial_dims=spatial_dims, + in_channels=in_channels, + out_channels=out_channels, + num_heads=num_heads, + ch_per_head=ch_per_head, + norm_name=norm_name, + dropout=dropout, + emb_dim=emb_dim + ) + + + def forward(self, x, emb=None): + if hasattr(self, 'attention'): + return self.attention(x, emb) + else: + return x \ No newline at end of file diff --git a/medical_diffusion/models/utils/conv_blocks.py b/medical_diffusion/models/utils/conv_blocks.py new file mode 100755 index 0000000000000000000000000000000000000000..ad87d4937f85c1f9638548ada984634ff5ea75fa --- /dev/null +++ b/medical_diffusion/models/utils/conv_blocks.py @@ -0,0 +1,528 @@ +from typing import Optional, Sequence, Tuple, Union, Type + +import torch +import torch.nn as nn +import torch.nn.functional as F +import numpy as np + + +from monai.networks.blocks.dynunet_block import get_padding, get_output_padding +from monai.networks.layers import Pool, Conv +from monai.networks.layers.utils import get_act_layer, get_norm_layer, get_dropout_layer +from monai.utils.misc import ensure_tuple_rep + +from medical_diffusion.models.utils.attention_blocks import Attention, zero_module + +def save_add(*args): + args = [arg for arg in args if arg is not None] + return sum(args) if len(args)>0 else None + + +class SequentialEmb(nn.Sequential): + def forward(self, input, emb): + for module in self: + input = module(input, emb) + return input + + +class BasicDown(nn.Module): + def __init__( + self, + spatial_dims, + in_channels, + out_channels, + kernel_size=3, + stride=2, + learnable_interpolation=True, + use_res=False + ) -> None: + super().__init__() + + if learnable_interpolation: + Convolution = Conv[Conv.CONV, spatial_dims] + self.down_op = Convolution( + in_channels, + out_channels, + kernel_size=kernel_size, + stride=stride, + padding=get_padding(kernel_size, stride), + dilation=1, + groups=1, + bias=True, + ) + + if use_res: + self.down_skip = nn.PixelUnshuffle(2) # WARNING: Only supports 2D, , out_channels == 4*in_channels + + else: + Pooling = Pool['avg', spatial_dims] + self.down_op = Pooling( + kernel_size=kernel_size, + stride=stride, + padding=get_padding(kernel_size, stride) + ) + + + def forward(self, x, emb=None): + y = self.down_op(x) + if hasattr(self, 'down_skip'): + y = y+self.down_skip(x) + return y + +class BasicUp(nn.Module): + def __init__( + self, + spatial_dims, + in_channels, + out_channels, + kernel_size=2, + stride=2, + learnable_interpolation=True, + use_res=False, + ) -> None: + super().__init__() + self.learnable_interpolation = learnable_interpolation + if learnable_interpolation: + # TransConvolution = Conv[Conv.CONVTRANS, spatial_dims] + # padding = get_padding(kernel_size, stride) + # output_padding = get_output_padding(kernel_size, stride, padding) + # 
self.up_op = TransConvolution( + # in_channels, + # out_channels, + # kernel_size=kernel_size, + # stride=stride, + # padding=padding, + # output_padding=output_padding, + # groups=1, + # bias=True, + # dilation=1 + # ) + + self.calc_shape = lambda x: tuple((np.asarray(x)-1)*np.atleast_1d(stride)+np.atleast_1d(kernel_size) + -2*np.atleast_1d(get_padding(kernel_size, stride))) + Convolution = Conv[Conv.CONV, spatial_dims] + self.up_op = Convolution( + in_channels, + out_channels, + kernel_size=3, + stride=1, + padding=1, + dilation=1, + groups=1, + bias=True, + ) + + if use_res: + self.up_skip = nn.PixelShuffle(2) # WARNING: Only supports 2D, out_channels == in_channels/4 + else: + self.calc_shape = lambda x: tuple((np.asarray(x)-1)*np.atleast_1d(stride)+np.atleast_1d(kernel_size) + -2*np.atleast_1d(get_padding(kernel_size, stride))) + + def forward(self, x, emb=None): + if self.learnable_interpolation: + new_size = self.calc_shape(x.shape[2:]) + x_res = F.interpolate(x, size=new_size, mode='nearest-exact') + y = self.up_op(x_res) + if hasattr(self, 'up_skip'): + y = y+self.up_skip(x) + return y + else: + new_size = self.calc_shape(x.shape[2:]) + return F.interpolate(x, size=new_size, mode='nearest-exact') + + +class BasicBlock(nn.Module): + """ + A block that consists of Conv-Norm-Drop-Act, similar to blocks.Convolution. + + Args: + spatial_dims: number of spatial dimensions. + in_channels: number of input channels. + out_channels: number of output channels. + kernel_size: convolution kernel size. + stride: convolution stride. + norm_name: feature normalization type and arguments. + act_name: activation layer type and arguments. + dropout: dropout probability. + zero_conv: zero out the parameters of the convolution. + """ + + def __init__( + self, + spatial_dims: int, + in_channels: int, + out_channels: int, + kernel_size: Union[Sequence[int], int], + stride: Union[Sequence[int], int]=1, + norm_name: Union[Tuple, str, None]=None, + act_name: Union[Tuple, str, None] = None, + dropout: Optional[Union[Tuple, str, float]] = None, + zero_conv: bool = False, + ): + super().__init__() + Convolution = Conv[Conv.CONV, spatial_dims] + conv = Convolution( + in_channels, + out_channels, + kernel_size=kernel_size, + stride=stride, + padding=get_padding(kernel_size, stride), + dilation=1, + groups=1, + bias=True, + ) + self.conv = zero_module(conv) if zero_conv else conv + + if norm_name is not None: + self.norm = get_norm_layer(name=norm_name, spatial_dims=spatial_dims, channels=out_channels) + if dropout is not None: + self.drop = get_dropout_layer(name=dropout, dropout_dim=spatial_dims) + if act_name is not None: + self.act = get_act_layer(name=act_name) + + + def forward(self, inp): + out = self.conv(inp) + if hasattr(self, "norm"): + out = self.norm(out) + if hasattr(self, 'drop'): + out = self.drop(out) + if hasattr(self, "act"): + out = self.act(out) + return out + +class BasicResBlock(nn.Module): + """ + A block that consists of Conv-Act-Norm + skip. + + Args: + spatial_dims: number of spatial dimensions. + in_channels: number of input channels. + out_channels: number of output channels. + kernel_size: convolution kernel size. + stride: convolution stride. + norm_name: feature normalization type and arguments. + act_name: activation layer type and arguments. + dropout: dropout probability. + zero_conv: zero out the parameters of the convolution. 
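+
+    Note: the residual path uses a 1x1 convolution only when in_channels != out_channels;
+    otherwise it is an identity mapping.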
+ """ + def __init__( + self, + spatial_dims: int, + in_channels: int, + out_channels: int, + kernel_size: Union[Sequence[int], int], + stride: Union[Sequence[int], int]=1, + norm_name: Union[Tuple, str, None]=None, + act_name: Union[Tuple, str, None] = None, + dropout: Optional[Union[Tuple, str, float]] = None, + zero_conv: bool = False + ): + super().__init__() + self.basic_block = BasicBlock(spatial_dims, in_channels, out_channels, kernel_size, stride, norm_name, act_name, dropout, zero_conv) + Convolution = Conv[Conv.CONV, spatial_dims] + self.conv_res = Convolution( + in_channels, + out_channels, + kernel_size=1, + stride=stride, + padding=get_padding(1, stride), + dilation=1, + groups=1, + bias=True, + ) if in_channels != out_channels else nn.Identity() + + + def forward(self, inp): + out = self.basic_block(inp) + residual = self.conv_res(inp) + out = out+residual + return out + + + +class UnetBasicBlock(nn.Module): + """ + A modified version of monai.networks.blocks.UnetBasicBlock with additional embedding + + Args: + spatial_dims: number of spatial dimensions. + in_channels: number of input channels. + out_channels: number of output channels. + kernel_size: convolution kernel size. + stride: convolution stride. + norm_name: feature normalization type and arguments. + act_name: activation layer type and arguments. + dropout: dropout probability. + emb_channels: Number of embedding channels + """ + + def __init__( + self, + spatial_dims: int, + in_channels: int, + out_channels: int, + kernel_size: Union[Sequence[int], int], + stride: Union[Sequence[int], int]=1, + norm_name: Union[Tuple, str]=None, + act_name: Union[Tuple, str]=None, + dropout: Optional[Union[Tuple, str, float]] = None, + emb_channels: int = None, + blocks = 2 + ): + super().__init__() + self.block_seq = nn.ModuleList([ + BasicBlock(spatial_dims, in_channels if i==0 else out_channels, out_channels, kernel_size, stride, norm_name, act_name, dropout, i==blocks-1) + for i in range(blocks) + ]) + + if emb_channels is not None: + self.local_embedder = nn.Sequential( + get_act_layer(name=act_name), + nn.Linear(emb_channels, out_channels), + ) + + def forward(self, x, emb=None): + # ------------ Embedding ---------- + if emb is not None: + emb = self.local_embedder(emb) + b,c, *_ = emb.shape + sp_dim = x.ndim-2 + emb = emb.reshape(b, c, *((1,)*sp_dim) ) + # scale, shift = emb.chunk(2, dim = 1) + # x = x * (scale + 1) + shift + # x = x+emb + + # ----------- Convolution --------- + n_blocks = len(self.block_seq) + for i, block in enumerate(self.block_seq): + x = block(x) + if (emb is not None) and i [-1, 1] + with torch.no_grad(): + imgs_fake_batch = model(imgs_real_batch)[0].clamp(-1, 1) + + # -------------- LPIP ------------------- + calc_lpips.update(imgs_real_batch, imgs_fake_batch) # expect input to be [-1, 1] + + # -------------- MS-SSIM + MSE ------------------- + for img_real, img_fake in zip(imgs_real_batch, imgs_fake_batch): + img_real, img_fake = (img_real+1)/2, (img_fake+1)/2 # [-1, 1] -> [0, 1] + mmssim_list.append(mmssim(img_real[None], img_fake[None], normalize='relu')) + mse_list.append(torch.mean(torch.square(img_real-img_fake))) + + +# -------------- Summary ------------------- +mmssim_list = torch.stack(mmssim_list) +mse_list = torch.stack(mse_list) + +lpips = 1-calc_lpips.compute() +logger.info(f"LPIPS Score: {lpips}") +logger.info(f"MS-SSIM: {torch.mean(mmssim_list)} ± {torch.std(mmssim_list)}") +logger.info(f"MSE: {torch.mean(mse_list)} ± {torch.std(mse_list)}") \ No newline at end of file diff --git 
a/scripts/helpers/dump_discrimnator.py b/scripts/helpers/dump_discrimnator.py new file mode 100755 index 0000000000000000000000000000000000000000..8ed1e69a510f0a8d971b40f68b0c07b26507ca60 --- /dev/null +++ b/scripts/helpers/dump_discrimnator.py @@ -0,0 +1,26 @@ +from pathlib import Path +import torch +from medical_diffusion.models.embedders.latent_embedders import VQVAE, VQGAN, VAE, VAEGAN +from pytorch_lightning.trainer import Trainer +from pytorch_lightning.callbacks import ModelCheckpoint + +path_root = Path('runs/2022_12_01_210017_patho_vaegan') + +# Load model +model = VAEGAN.load_from_checkpoint(path_root/'last.ckpt') +# model = torch.load(path_root/'last.ckpt') + + + +# Save model-part +# torch.save(model.vqvae, path_root/'last_vae.ckpt') # Not working +# ------ Ugly workaround ---------- +checkpointing = ModelCheckpoint() +trainer = Trainer(callbacks=[checkpointing]) +trainer.strategy._lightning_module = model.vqvae +trainer.model = model.vqvae +trainer.save_checkpoint(path_root/'last_vae.ckpt') +# ----------------- + +model = VAE.load_from_checkpoint(path_root/'last_vae.ckpt') +# model = torch.load(path_root/'last_vae.ckpt') # load_state_dict \ No newline at end of file diff --git a/scripts/helpers/export_example_gifs.py b/scripts/helpers/export_example_gifs.py new file mode 100755 index 0000000000000000000000000000000000000000..0c752556ede085c0c017c3a18e628114cc4e9932 --- /dev/null +++ b/scripts/helpers/export_example_gifs.py @@ -0,0 +1,34 @@ + +from pathlib import Path +from PIL import Image +import numpy as np + + + +if __name__ == "__main__": + path_out = Path.cwd()/'media/' + path_out.mkdir(parents=True, exist_ok=True) + + # imgs = [] + # for img_i in range(50): + # for label_a, label_b, label_c in [('NRG', 'No_Cardiomegaly', 'nonMSIH'), ('RG', 'Cardiomegaly', 'MSIH')]: + # img_a = Image.open(f'/mnt/hdd/datasets/eye/AIROGS/data_generated_diffusion/{label_a}/fake_{img_i}.png').quantize(200, 0).convert('RGB') + # img_b = Image.open(f'/mnt/hdd/datasets/chest/CheXpert/ChecXpert-v10/generated_diffusion2_150/{label_b}/fake_{img_i}.png').quantize(50, 0).convert('RGB') + # img_c = Image.open(f'/mnt/hdd/datasets/pathology/kather_msi_mss_2/synthetic_data/diffusion2_150/{label_c}/fake_{img_i}.png').resize((256, 256)).quantize(10, 0).convert('RGB') + + # img = Image.fromarray(np.concatenate([np.array(img_a), np.array(img_b), np.array(img_c)], axis=1), 'RGB').quantize(256, 1) + # imgs.append(img) + + # imgs[0].save(fp=path_out/f'animation.gif', format='GIF', append_images=imgs[1:], optimize=False, save_all=True, duration=500, loop=0) + + imgs = [] + path_root = Path('/mnt/hdd/datasets/pathology/kather_msi_mss_2/synthetic_data/diffusion2_150') + for img_i in range(50): + for path_label in path_root.iterdir(): + img = Image.open(path_label/f'fake_{img_i}.png').resize((256, 256)) + imgs.append(img) + + imgs[0].save(fp=path_out/f'animation_histo.gif', format='GIF', append_images=imgs[1:], optimize=False, save_all=True, duration=500, loop=0) + + + \ No newline at end of file diff --git a/scripts/helpers/export_random_images.py b/scripts/helpers/export_random_images.py new file mode 100755 index 0000000000000000000000000000000000000000..de957da97b66d09f679402589512429238ebb725 --- /dev/null +++ b/scripts/helpers/export_random_images.py @@ -0,0 +1,50 @@ +from pathlib import Path + +import torch +import numpy as np +from PIL import Image +from torchvision.utils import save_image + + + + + +# class_2 = 'RG' +# class_1 = 'NRG' +# path_out = Path().cwd()/'results'/'AIROGS'/'generated_images' +# 
path_root = Path('/mnt/hdd/datasets/eye/AIROGS/data_generated_diffusion/') +# path_root = Path('/mnt/hdd/datasets/eye/AIROGS/data_generated_stylegan3') +# path_root = Path('/mnt/hdd/datasets/eye/AIROGS/data_256x256_ref/') + +class_2 = 'Cardiomegaly' +class_1 = 'No_Cardiomegaly' +path_out = Path().cwd()/'results'/'CheXpert'/'generated_images' +path_root = Path('/mnt/hdd/datasets/chest/CheXpert/ChecXpert-v10/generated_diffusion3_150/') +# path_root = Path('/mnt/hdd/datasets/chest/CheXpert/ChecXpert-v10/generated_progan/') +# path_root = Path('/mnt/hdd/datasets/chest/CheXpert/ChecXpert-v10/reference/') + +# class_2 = 'MSIH' +# class_1 = 'nonMSIH' +# path_out = Path().cwd()/'results'/'MSIvsMSS_2'/'generated_images' +# path_root = Path('/mnt/hdd/datasets/pathology/kather_msi_mss_2/synthetic_data/diffusion2_150/') +# path_root = Path('/mnt/hdd/datasets/pathology/kather_msi_mss_2/synthetic_data/SYNTH-CRC-10K/') +# path_root = Path('/mnt/hdd/datasets/pathology/kather_msi_mss_2/train') + +num = 2 +np.random.seed(2) +a = np.random.randint(0, 1000) +b = np.random.randint(0, 1000) +print(a, b) + +path_out.mkdir(parents=True, exist_ok=True) +paths_class_1 = [path_img for n, path_img in enumerate((path_root/class_1).iterdir()) if a<=n [0, 1] + # diff = torch.abs(images[1]-images[0]) + utils.save_image(diff, path_out/'diff.png', nrow=int(math.sqrt(results.shape[0])), normalize=True, scale_each=True) # For 2D images: [B, C, H, W] + + + \ No newline at end of file diff --git a/scripts/train_diffusion.py b/scripts/train_diffusion.py new file mode 100755 index 0000000000000000000000000000000000000000..6416169bf9c1efe9883ea585738b5b7206452116 --- /dev/null +++ b/scripts/train_diffusion.py @@ -0,0 +1,183 @@ + +from email.mime import audio +from pathlib import Path +from datetime import datetime + +import torch +import torch.nn as nn +from pytorch_lightning.trainer import Trainer +from pytorch_lightning.callbacks import EarlyStopping, ModelCheckpoint +import numpy as np +import torchio as tio + +from medical_diffusion.data.datamodules import SimpleDataModule +from medical_diffusion.data.datasets import AIROGSDataset, MSIvsMSS_2_Dataset, CheXpert_2_Dataset +from medical_diffusion.models.pipelines import DiffusionPipeline +from medical_diffusion.models.estimators import UNet +from medical_diffusion.external.stable_diffusion.unet_openai import UNetModel +from medical_diffusion.models.noise_schedulers import GaussianNoiseScheduler +from medical_diffusion.models.embedders import LabelEmbedder, TimeEmbbeding +from medical_diffusion.models.embedders.latent_embedders import VAE, VAEGAN, VQVAE, VQGAN + +import torch.multiprocessing +torch.multiprocessing.set_sharing_strategy('file_system') + + + +if __name__ == "__main__": + # ------------ Load Data ---------------- + # ds = AIROGSDataset( + # crawler_ext='jpg', + # augment_horizontal_flip = False, + # augment_vertical_flip = False, + # # path_root='/home/gustav/Documents/datasets/AIROGS/data_256x256/', + # path_root='/mnt/hdd/datasets/eye/AIROGS/data_256x256', + # ) + + # ds = MSIvsMSS_2_Dataset( + # crawler_ext='jpg', + # image_resize=None, + # image_crop=None, + # augment_horizontal_flip=False, + # augment_vertical_flip=False, + # # path_root='/home/gustav/Documents/datasets/Kather_2/train', + # path_root='/mnt/hdd/datasets/pathology/kather_msi_mss_2/train/', + # ) + + ds = CheXpert_2_Dataset( # 256x256 + augment_horizontal_flip=False, + augment_vertical_flip=False, + path_root = '/mnt/hdd/datasets/chest/CheXpert/ChecXpert-v10/preprocessed_tianyu' + ) + + dm = 
SimpleDataModule( + ds_train = ds, + batch_size=32, + # num_workers=0, + pin_memory=True, + # weights=ds.get_weights() + ) + + current_time = datetime.now().strftime("%Y_%m_%d_%H%M%S") + path_run_dir = Path.cwd() / 'runs' / str(current_time) + path_run_dir.mkdir(parents=True, exist_ok=True) + accelerator = 'gpu' if torch.cuda.is_available() else 'cpu' + + + + # ------------ Initialize Model ------------ + # cond_embedder = None + cond_embedder = LabelEmbedder + cond_embedder_kwargs = { + 'emb_dim': 1024, + 'num_classes': 2 + } + + + time_embedder = TimeEmbbeding + time_embedder_kwargs ={ + 'emb_dim': 1024 # stable diffusion uses 4*model_channels (model_channels is about 256) + } + + + noise_estimator = UNet + noise_estimator_kwargs = { + 'in_ch':8, + 'out_ch':8, + 'spatial_dims':2, + 'hid_chs': [ 256, 256, 512, 1024], + 'kernel_sizes':[3, 3, 3, 3], + 'strides': [1, 2, 2, 2], + 'time_embedder':time_embedder, + 'time_embedder_kwargs': time_embedder_kwargs, + 'cond_embedder':cond_embedder, + 'cond_embedder_kwargs': cond_embedder_kwargs, + 'deep_supervision': False, + 'use_res_block':True, + 'use_attention':'none', + } + + + # ------------ Initialize Noise ------------ + noise_scheduler = GaussianNoiseScheduler + noise_scheduler_kwargs = { + 'timesteps': 1000, + 'beta_start': 0.002, # 0.0001, 0.0015 + 'beta_end': 0.02, # 0.01, 0.0195 + 'schedule_strategy': 'scaled_linear' + } + + # ------------ Initialize Latent Space ------------ + # latent_embedder = None + # latent_embedder = VQVAE + latent_embedder = VAE + latent_embedder_checkpoint = 'runs/2022_12_12_133315_chest_vaegan/last_vae.ckpt' + + # ------------ Initialize Pipeline ------------ + pipeline = DiffusionPipeline( + noise_estimator=noise_estimator, + noise_estimator_kwargs=noise_estimator_kwargs, + noise_scheduler=noise_scheduler, + noise_scheduler_kwargs = noise_scheduler_kwargs, + latent_embedder=latent_embedder, + latent_embedder_checkpoint = latent_embedder_checkpoint, + estimator_objective='x_T', + estimate_variance=False, + use_self_conditioning=False, + use_ema=False, + classifier_free_guidance_dropout=0.5, # Disable during training by setting to 0 + do_input_centering=False, + clip_x0=False, + sample_every_n_steps=1000 + ) + + # pipeline_old = pipeline.load_from_checkpoint('runs/2022_11_27_085654_chest_diffusion/last.ckpt') + # pipeline.noise_estimator.load_state_dict(pipeline_old.noise_estimator.state_dict(), strict=True) + + # -------------- Training Initialization --------------- + to_monitor = "train/loss" # "pl/val_loss" + min_max = "min" + save_and_sample_every = 100 + + early_stopping = EarlyStopping( + monitor=to_monitor, + min_delta=0.0, # minimum change in the monitored quantity to qualify as an improvement + patience=30, # number of checks with no improvement + mode=min_max + ) + checkpointing = ModelCheckpoint( + dirpath=str(path_run_dir), # dirpath + monitor=to_monitor, + every_n_train_steps=save_and_sample_every, + save_last=True, + save_top_k=2, + mode=min_max, + ) + trainer = Trainer( + accelerator=accelerator, + # devices=[0], + # precision=16, + # amp_backend='apex', + # amp_level='O2', + # gradient_clip_val=0.5, + default_root_dir=str(path_run_dir), + callbacks=[checkpointing], + # callbacks=[checkpointing, early_stopping], + enable_checkpointing=True, + check_val_every_n_epoch=1, + log_every_n_steps=save_and_sample_every, + auto_lr_find=False, + # limit_train_batches=1000, + limit_val_batches=0, # 0 = disable validation - Note: Early Stopping no longer available + min_epochs=100, + max_epochs=1001, + 
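+        # With limit_val_batches=0 above (and no validation dataloader in the datamodule),
+        # no validation runs in practice, so the sanity-check steps below are effectively
+        # inactive; checkpointing and early stopping therefore monitor train/loss.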
num_sanity_val_steps=2, + ) + + # ---------------- Execute Training ---------------- + trainer.fit(pipeline, datamodule=dm) + + # ------------- Save path to best model ------------- + pipeline.save_best_checkpoint(trainer.logger.log_dir, checkpointing.best_model_path) + + diff --git a/scripts/train_latent_embedder_2d.py b/scripts/train_latent_embedder_2d.py new file mode 100755 index 0000000000000000000000000000000000000000..596ab7a1cdd85ebeba0d651d899f2d573c395ad9 --- /dev/null +++ b/scripts/train_latent_embedder_2d.py @@ -0,0 +1,180 @@ + + + + + +from pathlib import Path +from datetime import datetime + +import torch +from torch.utils.data import ConcatDataset +from pytorch_lightning.trainer import Trainer +from pytorch_lightning.callbacks import EarlyStopping, ModelCheckpoint + + +from medical_diffusion.data.datamodules import SimpleDataModule +from medical_diffusion.data.datasets import AIROGSDataset, MSIvsMSS_2_Dataset, CheXpert_2_Dataset +from medical_diffusion.models.embedders.latent_embedders import VQVAE, VQGAN, VAE, VAEGAN + +import torch.multiprocessing +torch.multiprocessing.set_sharing_strategy('file_system') + +if __name__ == "__main__": + + # --------------- Settings -------------------- + current_time = datetime.now().strftime("%Y_%m_%d_%H%M%S") + path_run_dir = Path.cwd() / 'runs' / str(current_time) + path_run_dir.mkdir(parents=True, exist_ok=True) + gpus = [0] if torch.cuda.is_available() else None + + + # ------------ Load Data ---------------- + # ds_1 = AIROGSDataset( # 256x256 + # crawler_ext='jpg', + # augment_horizontal_flip=True, + # augment_vertical_flip=True, + # # path_root='/home/gustav/Documents/datasets/AIROGS/dataset', + # path_root='/mnt/hdd/datasets/eye/AIROGS/data_256x256', + # ) + + # ds_2 = MSIvsMSS_2_Dataset( # 512x512 + # # image_resize=256, + # crawler_ext='jpg', + # augment_horizontal_flip=True, + # augment_vertical_flip=True, + # # path_root='/home/gustav/Documents/datasets/Kather_2/train' + # path_root='/mnt/hdd/datasets/pathology/kather_msi_mss_2/train/' + # ) + + ds_3 = CheXpert_2_Dataset( # 256x256 + # image_resize=128, + augment_horizontal_flip=False, + augment_vertical_flip=False, + # path_root = '/home/gustav/Documents/datasets/CheXpert/preprocessed_tianyu' + path_root = '/mnt/hdd/datasets/chest/CheXpert/ChecXpert-v10/preprocessed_tianyu' + ) + + # ds = ConcatDataset([ds_1, ds_2, ds_3]) + + dm = SimpleDataModule( + ds_train = ds_3, + batch_size=8, + # num_workers=0, + pin_memory=True + ) + + + # ------------ Initialize Model ------------ + model = VAE( + in_channels=3, + out_channels=3, + emb_channels=8, + spatial_dims=2, + hid_chs = [ 64, 128, 256, 512], + kernel_sizes=[ 3, 3, 3, 3], + strides = [ 1, 2, 2, 2], + deep_supervision=1, + use_attention= 'none', + loss = torch.nn.MSELoss, + # optimizer_kwargs={'lr':1e-6}, + embedding_loss_weight=1e-6 + ) + + # model.load_pretrained(Path.cwd()/'runs/2022_12_01_183752_patho_vae/last.ckpt', strict=True) + + # model = VAEGAN( + # in_channels=3, + # out_channels=3, + # emb_channels=8, + # spatial_dims=2, + # hid_chs = [ 64, 128, 256, 512], + # deep_supervision=1, + # use_attention= 'none', + # start_gan_train_step=-1, + # embedding_loss_weight=1e-6 + # ) + + # model.vqvae.load_pretrained(Path.cwd()/'runs/2022_11_25_082209_chest_vae/last.ckpt') + # model.load_pretrained(Path.cwd()/'runs/2022_11_25_232957_patho_vaegan/last.ckpt') + + + # model = VQVAE( + # in_channels=3, + # out_channels=3, + # emb_channels=4, + # num_embeddings = 8192, + # spatial_dims=2, + # hid_chs = [64, 128, 256, 512], + # 
embedding_loss_weight=1, + # beta=1, + # loss = torch.nn.L1Loss, + # deep_supervision=1, + # use_attention = 'none', + # ) + + + # model = VQGAN( + # in_channels=3, + # out_channels=3, + # emb_channels=4, + # num_embeddings = 8192, + # spatial_dims=2, + # hid_chs = [64, 128, 256, 512], + # embedding_loss_weight=1, + # beta=1, + # start_gan_train_step=-1, + # pixel_loss = torch.nn.L1Loss, + # deep_supervision=1, + # use_attention='none', + # ) + + # model.vqvae.load_pretrained(Path.cwd()/'runs/2022_12_13_093727_patho_vqvae/last.ckpt') + + + # -------------- Training Initialization --------------- + to_monitor = "train/L1" # "val/loss" + min_max = "min" + save_and_sample_every = 50 + + early_stopping = EarlyStopping( + monitor=to_monitor, + min_delta=0.0, # minimum change in the monitored quantity to qualify as an improvement + patience=30, # number of checks with no improvement + mode=min_max + ) + checkpointing = ModelCheckpoint( + dirpath=str(path_run_dir), # dirpath + monitor=to_monitor, + every_n_train_steps=save_and_sample_every, + save_last=True, + save_top_k=5, + mode=min_max, + ) + trainer = Trainer( + accelerator='gpu', + devices=[0], + # precision=16, + # amp_backend='apex', + # amp_level='O2', + # gradient_clip_val=0.5, + default_root_dir=str(path_run_dir), + callbacks=[checkpointing], + # callbacks=[checkpointing, early_stopping], + enable_checkpointing=True, + check_val_every_n_epoch=1, + log_every_n_steps=save_and_sample_every, + auto_lr_find=False, + # limit_train_batches=1000, + limit_val_batches=0, # 0 = disable validation - Note: Early Stopping no longer available + min_epochs=100, + max_epochs=1001, + num_sanity_val_steps=2, + ) + + # ---------------- Execute Training ---------------- + trainer.fit(model, datamodule=dm) + + # ------------- Save path to best model ------------- + model.save_best_checkpoint(trainer.logger.log_dir, checkpointing.best_model_path) + + diff --git a/setup.py b/setup.py new file mode 100755 index 0000000000000000000000000000000000000000..2d659285650d9875515d363064a06904a72e2590 --- /dev/null +++ b/setup.py @@ -0,0 +1,20 @@ +from setuptools import setup, find_packages + +with open('README.md', encoding='utf-8') as f: + long_description = f.read() + +with open('requirements.txt', encoding='utf-8') as f: + install_requires = f.read() + + + +setup( + name='Medical Diffusion', + author="", + version="1.0", + description="Diffusion model for medical images", + long_description=long_description, + long_description_content_type="text/markdown", + packages=find_packages(exclude=['contrib', 'docs', 'tests']), + install_requires=install_requires, +) \ No newline at end of file diff --git a/streamlit/pages/chest.py b/streamlit/pages/chest.py new file mode 100755 index 0000000000000000000000000000000000000000..524b38d8ad3cb4ab4bf185c3a24dcc2564c43d70 --- /dev/null +++ b/streamlit/pages/chest.py @@ -0,0 +1,41 @@ +import streamlit as st +import torch +import numpy as np + +from medical_diffusion.models.pipelines import DiffusionPipeline + +st.title("Chest X-ray images", anchor=None) +st.sidebar.markdown("Medfusion for chest X-ray image generation") +st.header('Information') +st.markdown('Medfusion was trained on the [CheXpert](https://stanfordmlgroup.github.io/competitions/chexpert/) dataset') + +st.header('Settings') +n_samples = st.number_input("Samples", min_value=1, max_value=25, value=4) +steps = st.number_input("Sampling steps", min_value=1, max_value=999, value=50) +guidance_scale = st.number_input("Guidance scale", min_value=1, max_value=10, value=1) 
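+# guidance_scale == 1 uses only the conditional prediction; values > 1 additionally amplify
+# the difference to the unconditional prediction (classifier-free guidance, see
+# DiffusionPipeline.forward).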
+seed = st.number_input("Seed", min_value=0, max_value=None, value=1) +cond_str = st.radio("Cardiomegaly", ('Yes', 'No'), index=1, help="Conditioned on 'cardiomegaly' or 'no cardiomegaly'", horizontal=True) +torch.manual_seed(seed) + +device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') + +@st.cache(allow_output_mutation = True) +def init_pipeline(): + pipeline = DiffusionPipeline.load_from_checkpoint('runs/2022_11_27_085654_chest_diffusion/last.ckpt') + return pipeline + +if st.button('Sample'): + cond = {'Yes':1, 'No':0}[cond_str] + condition = torch.tensor([cond]*n_samples, device=device) + un_cond = torch.tensor([1-cond]*n_samples, device=device) + + pipeline = init_pipeline() + pipeline.to(device) + images = pipeline.sample(n_samples, (8, 32, 32), guidance_scale=guidance_scale, condition=condition, un_cond=un_cond, steps=steps, use_ddim=True ) + + images = images.clamp(-1, 1) + images = images.cpu().numpy() # [B, C, H, W] + images = (images+1)/2 # Transform from [-1, 1] to [0, 1] + + images = [np.moveaxis(img, 0, -1) for img in images] + st.image(images, channels="RGB", output_format='png') # expects (w,h,3) \ No newline at end of file diff --git a/streamlit/pages/colon.py b/streamlit/pages/colon.py new file mode 100755 index 0000000000000000000000000000000000000000..b0e0953b0871f0b71acafb6ecf698f40936cc1ae --- /dev/null +++ b/streamlit/pages/colon.py @@ -0,0 +1,43 @@ +import streamlit as st +import torch +import numpy as np + +from medical_diffusion.models.pipelines import DiffusionPipeline + +st.title("Colon histology images", anchor=None) +st.sidebar.markdown("Medfusion for colon histology image generation") +st.header('Information') +st.markdown('Medfusion was trained on the [CRC-DX](https://zenodo.org/record/3832231#.Y29uInbMKbg) dataset') + + + +st.header('Settings') +n_samples = st.number_input("Samples", min_value=1, max_value=25, value=4) +steps = st.number_input("Sampling steps", min_value=1, max_value=999, value=50) +guidance_scale = st.number_input("Guidance scale", min_value=1, max_value=10, value=1) +seed = st.number_input("Seed", min_value=0, max_value=None, value=1) +cond_str = st.radio("Microsatellite stable", ('Yes', 'No'), index=1, help="Conditioned on 'microsatellite stable (MSS)' or 'microsatellite instable (MSI)'", horizontal=True) +torch.manual_seed(seed) +device_str = 'cuda' if torch.cuda.is_available() else 'cpu' +device = torch.device(device_str) + +@st.cache(allow_output_mutation = True) +def init_pipeline(): + pipeline = DiffusionPipeline.load_from_checkpoint('runs/2022_12_02_174623_patho_diffusion/last.ckpt') + return pipeline + +if st.button(f'Sample (using {device_str})'): + cond = {'Yes':1, 'No':0}[cond_str] + condition = torch.tensor([cond]*n_samples, device=device) + un_cond = torch.tensor([1-cond]*n_samples, device=device) + + pipeline = init_pipeline() + pipeline.to(device) + images = pipeline.sample(n_samples, (4, 64, 64), guidance_scale=guidance_scale, condition=condition, un_cond=un_cond, steps=steps, use_ddim=True ) + + images = images.clamp(-1, 1) + images = images.cpu().numpy() # [B, C, H, W] + images = (images+1)/2 # Transform from [-1, 1] to [0, 1] + + images = [np.moveaxis(img, 0, -1) for img in images] + st.image(images, channels="RGB", output_format='png') # expects (w,h,3) \ No newline at end of file diff --git a/streamlit/pages/eye.py b/streamlit/pages/eye.py new file mode 100755 index 0000000000000000000000000000000000000000..ccb7c08506877f4fa16a1d52df818e51fba11f05 --- /dev/null +++ b/streamlit/pages/eye.py @@ -0,0 
+1,41 @@ +import streamlit as st +import torch +import numpy as np + +from medical_diffusion.models.pipelines import DiffusionPipeline + +st.title("Eye fundus images", anchor=None) +st.sidebar.markdown("Medfusion for eye fundus image generation") +st.header('Information') +st.markdown('Medfusion was trained on the [AIROGS](https://airogs.grand-challenge.org/data-and-challenge/) dataset') + + +st.header('Settings') +n_samples = st.number_input("Samples", min_value=1, max_value=25, value=4) +steps = st.number_input("Sampling steps", min_value=1, max_value=999, value=50) +guidance_scale = st.number_input("Guidance scale", min_value=1, max_value=10, value=1) +seed = st.number_input("Seed", min_value=0, max_value=None, value=1) +cond_str = st.radio("Glaucoma", ('Yes', 'No'), index=1, help="Conditioned on 'referable glaucoma' or 'no referable glaucoma'", horizontal=True) +torch.manual_seed(seed) +device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') + +@st.cache(allow_output_mutation = True) +def init_pipeline(): + pipeline = DiffusionPipeline.load_from_checkpoint('runs/2022_11_11_175610_eye_diffusion/last.ckpt') + return pipeline + +if st.button('Sample'): + cond = {'Yes':1, 'No':0}[cond_str] + condition = torch.tensor([cond]*n_samples, device=device) + un_cond = torch.tensor([1-cond]*n_samples, device=device) + + pipeline = init_pipeline() + pipeline.to(device) + images = pipeline.sample(n_samples, (4, 32, 32), guidance_scale=guidance_scale, condition=condition, un_cond=un_cond, steps=steps, use_ddim=True ) + + images = images.clamp(-1, 1) + images = images.cpu().numpy() # [B, C, H, W] + images = (images+1)/2 # Transform from [-1, 1] to [0, 1] + + images = [np.moveaxis(img, 0, -1) for img in images] + st.image(images, channels="RGB", output_format='png') # expects (w,h,3) \ No newline at end of file diff --git a/streamlit/welcome.py b/streamlit/welcome.py new file mode 100755 index 0000000000000000000000000000000000000000..1cc615d506474e26b68bda82bc6bd6f0b2e2dd73 --- /dev/null +++ b/streamlit/welcome.py @@ -0,0 +1,5 @@ +import streamlit as st + +st.title('Welcome to Medfusion') +st.text("A latent Denoising Diffusion Probabilistic Model (DDPM) for Medical Image Synthesis") +st.image('media/Medfusion.png', channels="RGB", output_format='png') \ No newline at end of file diff --git a/tests/dataset/test_dataset.py b/tests/dataset/test_dataset.py new file mode 100755 index 0000000000000000000000000000000000000000..a1d677b53675a0827a801e85c3f905d5ead2a1b5 --- /dev/null +++ b/tests/dataset/test_dataset.py @@ -0,0 +1,27 @@ +from medical_diffusion.data.datasets import SimpleDataset2D + +import matplotlib.pyplot as plt +from pathlib import Path +from torchvision.utils import save_image + +path_out = Path().cwd()/'results'/'test' +path_out.mkdir(parents=True, exist_ok=True) + +# ds = SimpleDataset2D( +# crawler_ext='jpg', +# image_resize=(352, 528), +# image_crop=(192, 288), +# path_root='/home/gustav/Documents/datasets/AIROGS/dataset', +# ) + +ds = SimpleDataset2D( + crawler_ext='tif', + image_resize=None, + image_crop=None, + path_root='/home/gustav/Documents/datasets/BREAST-DIAGNOSIS/dataset_lr2d/' +) + +images = [ds[n]['source'] for n in range(4)] + + +save_image(images, path_out/'test.png') \ No newline at end of file diff --git a/tests/dataset/test_dataset_3d.py b/tests/dataset/test_dataset_3d.py new file mode 100755 index 0000000000000000000000000000000000000000..54166501fca10d85c99d09fe94c36532e93f5713 --- /dev/null +++ b/tests/dataset/test_dataset_3d.py @@ -0,0 +1,25 @@ +from 
medical_diffusion.data.datasets import SimpleDataset3D + +import matplotlib.pyplot as plt +from pathlib import Path +from torchvision.utils import save_image +import torch + +path_out = Path().cwd()/'results'/'test' +path_out.mkdir(parents=True, exist_ok=True) + + +ds = SimpleDataset3D( + crawler_ext='nii.gz', + image_resize=None, + image_crop=None, + path_root='/mnt/hdd/datasets/breast/DUKE/dataset_lr_256_256_32', + use_znorm=False +) + +image = ds[0]['source'] # [C, D, H, W] + +image = image.swapaxes(0, 1) # [D, C, H, W] -> treat D as Batch Dimension +image = image/2+0.5 + +save_image(image, path_out/'test.png') \ No newline at end of file diff --git a/tests/dataset/test_dataset_airogs.py b/tests/dataset/test_dataset_airogs.py new file mode 100755 index 0000000000000000000000000000000000000000..0f0a1c49e0e79a0141598b55e5c4790c44dbe2a8 --- /dev/null +++ b/tests/dataset/test_dataset_airogs.py @@ -0,0 +1,27 @@ +from medical_diffusion.data.datasets import SimpleDataset2D, AIROGSDataset + +import torch.nn.functional as F + +import matplotlib.pyplot as plt +from pathlib import Path +from torchvision.utils import save_image + +path_out = Path().cwd()/'results'/'test' +path_out.mkdir(parents=True, exist_ok=True) + +ds = AIROGSDataset( + crawler_ext='jpg', + image_resize=(256, 256), + image_crop=(256, 256), + path_root='/mnt/hdd/datasets/eye/AIROGS/data/', # '/home/gustav/Documents/datasets/AIROGS/dataset', '/mnt/hdd/datasets/eye/AIROGS/data/' +) + +weights = ds.get_weights() +images = [ds[n]['source'] for n in range(4)] + +interpolation_mode = 'bilinear' +images = [F.interpolate(img[None], size=[128, 128], mode=interpolation_mode, align_corners=None)[0] for img in images] + +images = [img/2+0.5 for img in images] + +save_image(images, path_out/'test.png') \ No newline at end of file diff --git a/tests/dataset/test_dataset_airogs_prep.py b/tests/dataset/test_dataset_airogs_prep.py new file mode 100755 index 0000000000000000000000000000000000000000..b0ef97e45d5de6a5748e91307c3a0245c4f0a0d5 --- /dev/null +++ b/tests/dataset/test_dataset_airogs_prep.py @@ -0,0 +1,23 @@ +from medical_diffusion.data.datasets import SimpleDataset2D, AIROGSDataset + +import torch.nn.functional as F + +import matplotlib.pyplot as plt +from pathlib import Path +from torchvision.utils import save_image + + +path_out = Path('/mnt/hdd/datasets/eye/AIROGS/data_256x256/') +path_out.mkdir(parents=True, exist_ok=True) + +ds = AIROGSDataset( + crawler_ext='jpg', + image_resize=256, + image_crop=(256, 256), + path_root='/mnt/hdd/datasets/eye/AIROGS/data/', # '/home/gustav/Documents/datasets/AIROGS/dataset', '/mnt/hdd/datasets/eye/AIROGS/data/' +) + +weights = ds.get_weights() + +for img in ds: + img['source'].save(path_out/f"{img['uid']}.jpg") \ No newline at end of file diff --git a/tests/dataset/test_dataset_chexpert.py b/tests/dataset/test_dataset_chexpert.py new file mode 100755 index 0000000000000000000000000000000000000000..73019ae362094a2bf80b9c09cf41478585266f16 --- /dev/null +++ b/tests/dataset/test_dataset_chexpert.py @@ -0,0 +1,51 @@ + + +from pathlib import Path +from torchvision.utils import save_image +import pandas as pd +import torch +import torch.nn.functional as F +from medical_diffusion.data.datasets import CheXpert_Dataset +import math + +path_out = Path().cwd()/'results'/'test'/'CheXpert' +path_out.mkdir(parents=True, exist_ok=True) + +# path_root = Path('/mnt/hdd/datasets/chest/CheXpert/ChecXpert-v10/train') +path_root = Path('/media/NAS/Chexpert_dataset/CheXpert-v1.0/train') +mode = path_root.name +labels = 
pd.read_csv(path_root.parent/f'{mode}.csv', index_col='Path') +labels = labels[labels['Frontal/Lateral'] == 'Frontal'] +labels.loc[labels['Sex'] == 'Unknown', 'Sex'] = 'Female' # Must be "female" to match paper data +labels.fillna(3, inplace=True) +str_2_int = {'Sex': {'Male':0, 'Female':1}, 'Frontal/Lateral':{'Frontal':0, 'Lateral':1}, 'AP/PA':{'AP':0, 'PA':1, 'LL':2, 'RL':3}} +labels.replace(str_2_int, inplace=True) + +# Get patients +labels['patient'] = labels.index.str.split('/').str[2] +labels.set_index('patient',drop=True, append=True, inplace=True) + +for c in labels.columns: + print(labels[c].value_counts(dropna=False)) + +ds = CheXpert_Dataset( + crawler_ext='jpg', + image_resize=(256, 256), + # image_crop=(256, 256), + path_root=path_root, +) + + + + +x = torch.stack([ds[n]['source'] for n in range(4)]) +b = x.shape[0] +save_image(x, path_out/'samples_down_0.png', nrow=int(math.sqrt(b)), normalize=True, scale_each=True ) + +size_0 = torch.tensor(x.shape[2:]) + +for i in range(3): + new_size = torch.div(size_0, 2**(i+1), rounding_mode='floor' ) + x_i = F.interpolate(x, size=tuple(new_size), mode='nearest', align_corners=None) + print(x_i.shape) + save_image(x_i, path_out/f'samples_down_{i+1}.png', nrow=int(math.sqrt(b)), normalize=True, scale_each=True) \ No newline at end of file diff --git a/tests/dataset/test_dataset_chexpert_2.py b/tests/dataset/test_dataset_chexpert_2.py new file mode 100755 index 0000000000000000000000000000000000000000..889348b34670cb88a31e75a2b0426e9cc3c06e63 --- /dev/null +++ b/tests/dataset/test_dataset_chexpert_2.py @@ -0,0 +1,42 @@ + + +from pathlib import Path +from torchvision.utils import save_image +import pandas as pd +import torch +import torch.nn.functional as F +from medical_diffusion.data.datasets import CheXpert_Dataset, CheXpert_2_Dataset +import math + +path_out = Path().cwd()/'results'/'test'/'CheXpert_2' +path_out.mkdir(parents=True, exist_ok=True) + +path_root = Path('/mnt/hdd/datasets/chest/CheXpert/ChecXpert-v10/preprocessed_tianyu') +labels = pd.read_csv(path_root/'labels/cheXPert_label.csv', index_col='Path') + + +# Get patients +# labels['patient'] = labels.index.str.split('/').str[2] +# labels.set_index('patient',drop=True, append=True, inplace=True) + +# for c in labels.columns: +# print(labels[c].value_counts(dropna=False)) + +ds = CheXpert_2_Dataset( + path_root=path_root, +) + + +weights = ds.get_weights() + +x = torch.stack([ds[n]['source'] for n in range(4)]) +b = x.shape[0] +save_image(x, path_out/'samples_down_0.png', nrow=int(math.sqrt(b)), normalize=True, scale_each=True ) + +size_0 = torch.tensor(x.shape[2:]) + +for i in range(3): + new_size = torch.div(size_0, 2**(i+1), rounding_mode='floor' ) + x_i = F.interpolate(x, size=tuple(new_size), mode='nearest', align_corners=None) + print(x_i.shape) + save_image(x_i, path_out/f'samples_down_{i+1}.png', nrow=int(math.sqrt(b)), normalize=True, scale_each=True) \ No newline at end of file diff --git a/tests/dataset/test_dataset_duke.py b/tests/dataset/test_dataset_duke.py new file mode 100755 index 0000000000000000000000000000000000000000..188dd70dc7a97337829cde7ecab749836ff3cbb7 --- /dev/null +++ b/tests/dataset/test_dataset_duke.py @@ -0,0 +1,27 @@ +from medical_diffusion.data.datasets import DUKEDataset + +import matplotlib.pyplot as plt +from pathlib import Path +from torchvision.utils import save_image +from pathlib import Path + +path_out = Path().cwd()/'results'/'test' +path_out.mkdir(parents=True, exist_ok=True) + +ids = [int(path_file.stem.split('_')[-1]) for 
path_file in Path('/mnt/hdd/datasets/breast/Diffusion2D/images').glob('*.png')] +print(min(ids), max(ids)) # [0, 53] + +ds = DUKEDataset( + crawler_ext='png', + image_resize=None, + image_crop=None, + path_root='/mnt/hdd/datasets/breast/Diffusion2D/images', +) + +print(ds[0]) +images = [ds[n]['source'] for n in range(4)] + + + + +save_image(images, path_out/'test.png') \ No newline at end of file diff --git a/tests/dataset/test_dataset_pathology.py b/tests/dataset/test_dataset_pathology.py new file mode 100755 index 0000000000000000000000000000000000000000..3d5360e64b1f517deb5a89aa50ed71ff9f8e9165 --- /dev/null +++ b/tests/dataset/test_dataset_pathology.py @@ -0,0 +1,25 @@ +from medical_diffusion.data.datasets import MSIvsMSS_Dataset + +import matplotlib.pyplot as plt +from pathlib import Path +from torchvision.utils import save_image +from pathlib import Path + +path_out = Path().cwd()/'results'/'test' +path_out.mkdir(parents=True, exist_ok=True) + + +ds = MSIvsMSS_Dataset( + crawler_ext='png', + image_resize=None, + image_crop=None, + path_root='/home/gustav/Documents/datasets/Kather/data/CRC/train', +) + +print(ds[0]) +images = [ds[n]['source']/2+0.5 for n in range(4)] + + + + +save_image(images, path_out/'test.png') \ No newline at end of file diff --git a/tests/dataset/test_dataset_pathology_2.py b/tests/dataset/test_dataset_pathology_2.py new file mode 100755 index 0000000000000000000000000000000000000000..e96ee34fb04fad04b5f880603203db213911f90d --- /dev/null +++ b/tests/dataset/test_dataset_pathology_2.py @@ -0,0 +1,26 @@ +from medical_diffusion.data.datasets import MSIvsMSS_2_Dataset + +import matplotlib.pyplot as plt +from pathlib import Path +from torchvision.utils import save_image +from pathlib import Path + +path_out = Path().cwd()/'results'/'test'/'patho2' +path_out.mkdir(parents=True, exist_ok=True) + + +ds = MSIvsMSS_2_Dataset( + crawler_ext='jpg', + image_resize=None, + image_crop=None, + # path_root='/home/gustav/Documents/datasets/Kather_2/train', + path_root='/mnt/hdd/datasets/pathology/kather_msi_mss_2/train/' +) + +print(ds[0]) +images = [ds[n]['source']/2+0.5 for n in range(4)] + + + + +save_image(images, path_out/'test.png') \ No newline at end of file diff --git a/tests/losses/test_ffl.py b/tests/losses/test_ffl.py new file mode 100755 index 0000000000000000000000000000000000000000..0dcdc88bb1ee25ac8f0533096014b1d66541479f --- /dev/null +++ b/tests/losses/test_ffl.py @@ -0,0 +1,9 @@ +from medical_diffusion.loss.ffl_loss import FocalFrequencyLoss as FFL +ffl = FFL(loss_weight=1.0, alpha=1.0) # initialize nn.Module class + +import torch +fake = torch.randn(4, 3, 64, 64) # replace it with the predicted tensor of shape (N, C, H, W) +real = torch.randn(4, 3, 64, 64) # replace it with the target tensor of shape (N, C, H, W) + +loss = ffl(fake, real) # calculate focal frequency loss +print(loss) diff --git a/tests/losses/test_lpips.py b/tests/losses/test_lpips.py new file mode 100755 index 0000000000000000000000000000000000000000..606fbac8362a0a491fc35d3e12a0394a7509829d --- /dev/null +++ b/tests/losses/test_lpips.py @@ -0,0 +1,37 @@ + + +import torch +from medical_diffusion.loss.perceivers import LPIPS +from medical_diffusion.data.datasets import AIROGSDataset, SimpleDataset3D + +loss = LPIPS(normalize=False) +torch.manual_seed(0) + +# input = torch.randn((1, 3, 16, 128, 128)) # 3D - 1 channel +# input = torch.randn((1, 1, 128, 128)) # 2D - 1 channel +# input = torch.randn((1, 3, 128, 128)) # 2D - 3 channel + +# target = input/2 + +# print(loss(input, target)) + + +# ds 
= AIROGSDataset( +# crawler_ext='jpg', +# image_resize=(256, 256), +# image_crop=(256, 256), +# path_root='/mnt/hdd/datasets/eye/AIROGS/data/', # '/home/gustav/Documents/datasets/AIROGS/dataset', '/mnt/hdd/datasets/eye/AIROGS/data/' +# ) +ds = SimpleDataset3D( + crawler_ext='nii.gz', + image_resize=None, + image_crop=None, + flip=True, + path_root='/mnt/hdd/datasets/breast/DUKE/dataset_lr_256_256_32', + use_znorm=True + ) + +input = ds[0]['source'][None] + +target = torch.randn_like(input) +print(loss(input, target)) \ No newline at end of file diff --git a/tests/models/latent_embedders/test_vae.py b/tests/models/latent_embedders/test_vae.py new file mode 100755 index 0000000000000000000000000000000000000000..6e4f0a3cb0bb8c3a640849aab8aca98eb86df706 --- /dev/null +++ b/tests/models/latent_embedders/test_vae.py @@ -0,0 +1,45 @@ +from pathlib import Path +import math + +import torch +from torchvision.utils import save_image + +from medical_diffusion.data.datamodules import SimpleDataModule +from medical_diffusion.data.datasets import AIROGSDataset, SimpleDataset2D +from medical_diffusion.models.embedders.latent_embedders import VQVAE, VQGAN + + +path_out = Path.cwd()/'results/test' +path_out.mkdir(parents=True, exist_ok=True) +device = torch.device('cuda') +torch.manual_seed(0) + +ds = AIROGSDataset( + crawler_ext='jpg', + image_resize=(256, 256), + image_crop=(256, 256), + path_root='/home/gustav/Documents/datasets/AIROGS/dataset', # '/home/gustav/Documents/datasets/AIROGS/dataset', '/mnt/hdd/datasets/eye/AIROGS/data/' +) + +x = ds[0]['source'][None].to(device) # [B, C, H, W] + +# v_min = x.min() +# v_max = x.max() +# x = (x-v_min)/(v_max-v_min) +# x = x*2-1 + +# x = (x+1)/2 +# x = x*(v_max-v_min)+v_min + +embedder = VQVAE.load_from_checkpoint('runs/2022_10_06_233542_vqvae_eye/last.ckpt') +embedder.to(device) + + +with torch.no_grad(): + z = embedder.encode(x) + +x_pred = embedder.decode(z) + + +images = torch.cat([x, x_pred]) +save_image(images, path_out/'test_latent_embedder.png', nrow=int(math.sqrt(images.shape[0])), normalize=True, scale_each=True) \ No newline at end of file diff --git a/tests/models/latent_embedders/test_vae_simple.py b/tests/models/latent_embedders/test_vae_simple.py new file mode 100755 index 0000000000000000000000000000000000000000..e2c8d6f3277c194a41bf8ea55e725f35de8d47eb --- /dev/null +++ b/tests/models/latent_embedders/test_vae_simple.py @@ -0,0 +1,12 @@ +import torch +from medical_diffusion.models.embedders.latent_embedders import VAE + + +input = torch.randn((1, 3, 128, 128)) # [B, C, H, W] + + +model = VAE(in_channels=3, out_channels=3, spatial_dims = 2, deep_supervision=True) +output = model(input) +print(output) + + diff --git a/tests/models/test_unet.py b/tests/models/test_unet.py new file mode 100755 index 0000000000000000000000000000000000000000..915b5f30774a9018cee6617c6a49a7b8f7666416 --- /dev/null +++ b/tests/models/test_unet.py @@ -0,0 +1,38 @@ + +from medical_diffusion.models.estimators import UNet +from medical_diffusion.models.embedders import LabelEmbedder + +import torch + +cond_embedder = LabelEmbedder +cond_embedder_kwargs = { + 'emb_dim': 64, + 'num_classes':2 +} + +noise_estimator = UNet +noise_estimator_kwargs = { + 'in_ch':3, + 'out_ch':3, + 'spatial_dims':2, + 'hid_chs': [32, 64, 128, 256], + 'kernel_sizes': [ 1, 3, 3, 3], + 'strides': [ 1, 2, 2, 2], + # 'kernel_sizes':[(1,3,3), (1,3,3), (1,3,3), 3, 3], + # 'strides':[ 1, (1,2,2), (1,2,2), 2, 2], + # 'kernel_sizes':[3, 3, 3, 3, 3], + # 'strides': [1, 2, 2, 2, 2], + 
'cond_embedder':cond_embedder, + 'cond_embedder_kwargs': cond_embedder_kwargs, + 'use_attention': 'linear', #['none', 'spatial', 'spatial', 'spatial', 'linear'], + } + + +model = UNet(**noise_estimator_kwargs) +# print(model) + +input = torch.randn((1,3,256,256)) +time = torch.randn([1,]) +cond = torch.tensor([0,]) +out_hor, out_ver = model(input, time, cond) +# print(out_hor) \ No newline at end of file diff --git a/tests/models/test_unet_openai.py b/tests/models/test_unet_openai.py new file mode 100755 index 0000000000000000000000000000000000000000..c2a1a27cd2ec78c7056d6314894c5a75c61f45b7 --- /dev/null +++ b/tests/models/test_unet_openai.py @@ -0,0 +1,19 @@ + +from medical_diffusion.external.stable_diffusion.unet_openai import UNetModel +from medical_diffusion.models.embedders import LabelEmbedder + +import torch + + +noise_estimator = UNetModel +noise_estimator_kwargs = {} + + +model = noise_estimator(**noise_estimator_kwargs) +print(model) + +input = torch.randn((1,4,32,32)) +time = torch.randn([1,]) +cond = None #torch.tensor([0,]) +out_hor, out_ver = model(input, time, cond) +print(out_hor) \ No newline at end of file diff --git a/tests/models/test_vae3d.py b/tests/models/test_vae3d.py new file mode 100755 index 0000000000000000000000000000000000000000..c2105ddac294e8597aa535d842c84e547dd971f5 --- /dev/null +++ b/tests/models/test_vae3d.py @@ -0,0 +1,19 @@ +import torch +from medical_diffusion.models.embedders.latent_embedders import VQVAE, VQGAN + + +input = torch.randn((1, 3, 16, 128, 128)) # [B, C, H, W] + + +model = VQVAE(in_channels=3, out_channels=3, spatial_dims = 3, emb_channels=1, deep_supervision=True) +# output = model(input) +# print(output) +loss = model._step({'source':input}, 1, 'train', 1, 1) +print(loss) + + +# model = VQGAN(in_channels=3, out_channels=3, spatial_dims = 3, emb_channels=1, deep_supervision=True) +# # output = model(input) +# # print(output) +# loss = model._step({'source':input}, 1, 'train', 1, 1) +# print(loss) diff --git a/tests/models/test_vae_diffusers.py b/tests/models/test_vae_diffusers.py new file mode 100755 index 0000000000000000000000000000000000000000..5597f1f47f625603657676176c4ba934b8933156 --- /dev/null +++ b/tests/models/test_vae_diffusers.py @@ -0,0 +1,23 @@ + +import torch +from medical_diffusion.external.diffusers.vae import VQModel, VQVAEWrapper, VAEWrapper + + +# model = AutoencoderKL(in_channels=3, out_channels=3) + +input = torch.randn((1, 3, 128, 128)) # [B, C, H, W] + +# model = VQModel(in_channels=3, out_channels=3) +# output = model(input, sample_posterior=True) +# print(output) + +model = VQVAEWrapper(in_ch=3, out_ch=3) +output = model(input) +print(output) + + + + +# model = VAEWrapper(in_ch=3, out_ch=3) +# output = model(input) +# print(output) \ No newline at end of file diff --git a/tests/models/time_embedders/test.py b/tests/models/time_embedders/test.py new file mode 100755 index 0000000000000000000000000000000000000000..c4eb5211f200aba0bacf1d2ac15803e796e7f783 --- /dev/null +++ b/tests/models/time_embedders/test.py @@ -0,0 +1,18 @@ +import torch +from medical_diffusion.models.embedders import TimeEmbbeding, SinusoidalPosEmb, LabelEmbedder + +cond_emb = LabelEmbedder(10, num_classes=2) +c = torch.tensor([[0,], [1,]]) +v = cond_emb(c) +print(v) + + +tim_emb = SinusoidalPosEmb(20, max_period=10) +t = torch.tensor([1,2,3, 1000]) +v = tim_emb(t) +print(v) + +tim_emb = TimeEmbbeding(4*4, SinusoidalPosEmb, {'max_period':10}) +t = torch.tensor([1,2,3, 1000]) +v = tim_emb(t) +print(v) \ No newline at end of file diff --git 
a/tests/noise_schedulers/test.py b/tests/noise_schedulers/test.py new file mode 100755 index 0000000000000000000000000000000000000000..73731301d20dbe582980ee207630dec0684540f1 --- /dev/null +++ b/tests/noise_schedulers/test.py @@ -0,0 +1,43 @@ + + +from medical_diffusion.models.noise_schedulers import GaussianNoiseScheduler +import torch +from pathlib import Path + +from torchvision.utils import save_image + + +device = torch.device('cuda') + +scheduler = GaussianNoiseScheduler() +# scheduler.to(device) +path_out = Path.cwd()/'results/test' + + +# print(scheduler.posterior_mean_coef1) +torch.manual_seed(0) +# x_0 = torch.ones((2, 3, 64, 64)) +x_0 = torch.rand((2, 3, 64, 64)) +noise = torch.randn_like(x_0) +t = torch.tensor([0, 999]) + +x_t = scheduler.estimate_x_t(x_0=x_0, t=t, x_T=noise) + +# x_0_pred = scheduler.estimate_x_t(x_0=x_0, t=torch.full_like(t, 0) , noise=noise) +# assert (x_0_pred == x_0).all(), "For t=0, function should return x_0" +# x_t, noise, t = scheduler.sample(x_0) +# print(x_t) + + +# x_0 = scheduler.estimate_x_0(x_t, noise, t) +# print(x_0) +# print(x_0.shape) + + + +pred = torch.randn_like(x_t) +x_t_prior, _ = scheduler.estimate_x_t_prior_from_x_T(x_t, t, pred, clip_x0=False) +print(x_t_prior) + +# save_image(x_t_prior, path_out/'test2.png') + diff --git a/tests/noise_schedulers/test_data.py b/tests/noise_schedulers/test_data.py new file mode 100755 index 0000000000000000000000000000000000000000..5ed5457c159eadea29ab85284d68a7de37648981 --- /dev/null +++ b/tests/noise_schedulers/test_data.py @@ -0,0 +1,120 @@ + + +from medical_diffusion.models.noise_schedulers import GaussianNoiseScheduler +from medical_diffusion.data.datasets import SimpleDataset2D, AIROGSDataset, CheXpert_Dataset, MSIvsMSS_2_Dataset +from medical_diffusion.models.embedders.latent_embedders import VAE, VAEGAN +import torch +from pathlib import Path +import matplotlib.pyplot as plt +import seaborn as sns +from math import ceil + + +from torchvision.utils import save_image + +# ds = SimpleDataset2D( +# crawler_ext='jpg', +# image_resize=(352, 528), +# image_crop=(192, 288), +# path_root='/home/gustav/Documents/datasets/AIROGS/dataset', +# ) + +# ds = AIROGSDataset( +# crawler_ext='jpg', +# image_resize=(256, 256), +# image_crop=(256, 256), +# path_root='/home/gustav/Documents/datasets/AIROGS/dataset', # '/home/gustav/Documents/datasets/AIROGS/dataset', /mnt/hdd/datasets/eye/AIROGS/data/ +# ) +# ds = CheXpert_Dataset( +# crawler_ext='jpg', +# augment_horizontal_flip=False, +# augment_vertical_flip=False, +# path_root='/mnt/hdd/datasets/chest/CheXpert/ChecXpert-v10/preprocessed/valid', +# ) + +ds = MSIvsMSS_2_Dataset( + crawler_ext='jpg', + image_resize=None, + image_crop=None, + augment_horizontal_flip=False, + augment_vertical_flip=False, + # path_root='/home/gustav/Documents/datasets/Kather_2/train', + path_root='/mnt/hdd/datasets/pathology/kather_msi_mss_2/train/', + ) + +device = torch.device('cuda') + +scheduler = GaussianNoiseScheduler(timesteps=1000, beta_start=1e-4, schedule_strategy='scaled_linear') +# scheduler.to(device) +path_out = Path.cwd()/'results/test/scheduler' +path_out.mkdir(parents=True, exist_ok=True) + + +# print(scheduler.posterior_mean_coef1) +torch.manual_seed(0) +x_0 = ds[0]['source'][None] # [B, C, H, W] + + + +embedder = VAE.load_from_checkpoint('runs/2022_11_25_232957_patho_vaegan/last_vae.ckpt') +with torch.no_grad(): + x_0 = embedder.encode(x_0) + +# x_0 = (x_0-x_0.min())/(x_0.max()-x_0.min()) +# x_0 = x_0*2-1 +# x*2-1 = (x-0.5)*2 + +noise = torch.randn_like(x_0) + 
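+# Assumed behaviour of estimate_x_t (standard DDPM forward process):
+#   x_t = sqrt(alpha_bar_t) * x_0 + sqrt(1 - alpha_bar_t) * noise
+# so the mean/std printed in the loop below should move from the latent statistics towards ~N(0, 1) as t approaches T.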
+x_ts = [] +step=100 + + +for t in range(0, scheduler.T+step, step): + t = torch.tensor([t]) + x_t = scheduler.estimate_x_t(x_0=x_0, t=t, x_T=noise) # [B, C, H, W] + print(t, x_t.mean(), x_t.std()) + x_ts.append(x_t) + +x_ts = torch.cat(x_ts) +# save_image(x_ts, path_out/'scheduler_nosing.png', normalize=True, scale_each=True) + + + + +binrange=(-2.5,2.5) +bins = 50 + +ncols=8 +nelem = (scheduler.T+step)//step+2 +nrows = ceil(nelem/8) +fig, ax = plt.subplots(nrows=nrows, ncols=ncols, figsize=(ncols*3, nrows*3)) +ax_iter = iter(ax.flatten()) + + + +for axis in ax_iter: + axis.spines['top'].set_visible(False) + axis.spines['right'].set_visible(False) + axis.spines['left'].set_visible(False) + axis.axes.get_yaxis().set_visible(False) +ax_iter = iter(ax.flatten()) + +axis = next(ax_iter) +sns.histplot(x=x_0.flatten(), bins=bins, binrange=binrange, ax=axis) + +for t in range(0, scheduler.T+step, step): + print(t) + t = torch.tensor([t]) + x_t = scheduler.estimate_x_t(x_0=x_0, t=t, x_T=noise) # [B, C, H, W] + axis = next(ax_iter) + sns.histplot(x=x_t.flatten(), bins=bins, binrange=binrange, ax=axis) + +axis = next(ax_iter) +sns.histplot(x=noise.flatten(), bins=bins, binrange=binrange, ax=axis) + +fig.tight_layout() +fig.savefig(path_out/'scheduler_nosing_histo.png') + + + diff --git a/tests/noise_schedulers/test_data_qq.py b/tests/noise_schedulers/test_data_qq.py new file mode 100755 index 0000000000000000000000000000000000000000..b049b61668cb01cd98139613874f13914dbee2b2 --- /dev/null +++ b/tests/noise_schedulers/test_data_qq.py @@ -0,0 +1,76 @@ + + +from medical_diffusion.models.noise_schedulers import GaussianNoiseScheduler +from medical_diffusion.data.datasets import SimpleDataset2D, AIROGSDataset +from medical_diffusion.models.embedders.latent_embedders import VQVAE +import torch +from pathlib import Path +import matplotlib.pyplot as plt +import seaborn as sns +from math import ceil + + +import statsmodels.api as sm + + + + + +device = torch.device('cuda') +path_out = Path.cwd()/'results/test' +torch.manual_seed(0) + +ds = AIROGSDataset( + crawler_ext='jpg', + image_resize=(256, 256), + image_crop=(256, 256), + path_root='/home/gustav/Documents/datasets/AIROGS/dataset', # '/home/gustav/Documents/datasets/AIROGS/dataset', /mnt/hdd/datasets/eye/AIROGS/data/ + ) +x_0 = ds[0]['source'][None] # [B, C, H, W] + + +scheduler = GaussianNoiseScheduler(timesteps=500, schedule_strategy='scaled_linear') + + +# embedder = VQVAE.load_from_checkpoint('runs/2022_10_06_233542_vqvae_eye/last.ckpt') +# with torch.no_grad(): +# x_0 = embedder.encode(x_0) + +noise = torch.randn_like(x_0) + +step=100 +binrange=(-2.5,2.5) +bins = 50 + +ncols=8 +nelem = (scheduler.T+step)//step+2 +nrows = ceil(nelem/8) +fig, ax = plt.subplots(nrows=nrows, ncols=ncols, figsize=(ncols*3, nrows*3)) +ax_iter = iter(ax.flatten()) +for axis in ax_iter: + axis.spines['top'].set_visible(False) + axis.spines['right'].set_visible(False) + axis.spines['left'].set_visible(False) + axis.axes.get_yaxis().set_visible(False) +ax_iter = iter(ax.flatten()) + + + +axis = next(ax_iter) +sm.qqplot(x_0.flatten(), line='45', ax=axis) + +for t in range(0, scheduler.T+step, step): + print(t) + t = torch.tensor([t]) + x_t = scheduler.estimate_x_t(x_0=x_0, t=t, x_T=noise) # [B, C, H, W] + axis = next(ax_iter) + sm.qqplot(x_t.flatten(), line='45', ax=axis) + +axis = next(ax_iter) +sm.qqplot(noise.flatten(), line='45', ax=axis) + +fig.tight_layout() +fig.savefig(path_out/'scheduler_nosing_qq.png') + + + diff --git a/tests/noise_schedulers/test_data_reverse.py 
b/tests/noise_schedulers/test_data_reverse.py new file mode 100755 index 0000000000000000000000000000000000000000..3a1c42cef06bb2cdce1ffca21ece3b581f548b9e --- /dev/null +++ b/tests/noise_schedulers/test_data_reverse.py @@ -0,0 +1,59 @@ + + +from medical_diffusion.models.noise_schedulers import GaussianNoiseScheduler +from medical_diffusion.data.datasets import SimpleDataset2D +from medical_diffusion.models.pipelines import DiffusionPipeline +import torch +from pathlib import Path + +from torchvision.utils import save_image + +ds = SimpleDataset2D( + crawler_ext='jpg', + image_resize=(352, 528), + image_crop=(192, 288), + path_root='/home/gustav/Documents/datasets/AIROGS/dataset', +) + +device = torch.device('cuda') + +pipeline = DiffusionPipeline.load_from_checkpoint('runs/2022_09_22_153738/last.ckpt') +pipeline.to(device) + +scheduler = GaussianNoiseScheduler() +scheduler.to(device) + + +path_out = Path.cwd()/'results/test' +torch.manual_seed(0) + + +x_0 = ds[0]['source'][None] # [B, C, H, W] +x_0 = x_0.to(device) +x_0 = x_0*2-1 +noise = torch.randn_like(x_0) + +x_ts = [] +x_0_preds = [] +for t in range(0, 1000, 100): + time = torch.tensor([t], device=device) + x_t = scheduler.estimate_x_t(x_0=x_0, t=time, x_T=noise) # [B, C, H, W] + x_0_pred = pipeline.denoise(x_t, i=t) + x_t = x_t/2+0.5 + x_0_pred = x_0_pred/2+0.5 + x_ts.append(x_t) + x_0_preds.append(x_0_pred) +# print(x_t) +x_ts = torch.cat(x_ts) +save_image(x_ts, path_out/'test2.png') + +x_0_preds = torch.cat(x_0_preds) +save_image(x_0_preds, path_out/'test3.png') + +# x_0 = scheduler.estimate_x_0(x_t, noise, t) +# # print(x_0) + +# x_t_prior = scheduler.estimate_x_t_prior_from_noise(x_t, t, noise, noise=noise) + + + diff --git a/tests/utils/test_attention.py b/tests/utils/test_attention.py new file mode 100755 index 0000000000000000000000000000000000000000..f13374e21a5775bd58158ba0fe1a69586c82fc83 --- /dev/null +++ b/tests/utils/test_attention.py @@ -0,0 +1,21 @@ + +import torch + +from medical_diffusion.models.utils.attention_blocks import LinearTransformer, SpatialTransformer + + +input = torch.randn((1, 32, 16, 64, 64)) # 3D +input = torch.randn((1, 32, 64, 64)) # 2D + +b, ch, *_ = input.shape +dim = input.ndim +# attention = SpatialTransformer(dim-2, in_channels=ch, out_channels=ch, num_heads=8) +# attention(input) + +embedding = input +embedding = None +emb_dim = embedding.shape[1] if embedding is not None else None +attention = LinearTransformer(input.ndim-2, in_channels=ch, out_channels=ch, num_heads=3, emb_dim=emb_dim) +attention = SpatialTransformer(input.ndim-2, in_channels=ch, out_channels=ch, num_heads=3, emb_dim=emb_dim) + +print(attention(input, embedding)) \ No newline at end of file diff --git a/tests/utils/test_attention_vs_sd.py b/tests/utils/test_attention_vs_sd.py new file mode 100755 index 0000000000000000000000000000000000000000..762af104ac9d3bee0d3b082e8e1070c8d7640f70 --- /dev/null +++ b/tests/utils/test_attention_vs_sd.py @@ -0,0 +1,35 @@ + +import torch + +from medical_diffusion.models.utils.attention_blocks import LinearTransformer,LinearTransformerNd, SpatialTransformer + +from medical_diffusion.external.stable_diffusion.unet_openai import AttentionBlock +from medical_diffusion.external.stable_diffusion.attention import SpatialSelfAttention # similar/equal to Attention used in the SD-UNet implementation + + + +torch.manual_seed(0) +input = torch.randn((1, 32, 64, 64)) # 2D + +b, ch, *_ = input.shape +dim = input.ndim +# attention = SpatialTransformer(dim-2, in_channels=ch, out_channels=ch, num_heads=8) +# 
attention(input) + +embedding = input + +torch.manual_seed(0) +attention_a = LinearTransformer(input.ndim-2, in_channels=ch, out_channels=ch, num_heads=1, ch_per_head=ch, emb_dim=None) +torch.manual_seed(0) +attention_a2 = LinearTransformerNd(input.ndim-2, in_channels=ch, out_channels=ch, num_heads=1, ch_per_head=ch, emb_dim=None) +torch.manual_seed(0) +attention_b = SpatialSelfAttention(in_channels=ch) +torch.manual_seed(0) +attention_c = AttentionBlock(ch, num_heads=1, num_head_channels=ch) + +a = attention_a(input) +a2 = attention_a2(input) +b = attention_b(input) +c = attention_c(input) + +print(torch.abs(a-b).max(), torch.abs(a-a2).max(), torch.abs(a-c).max()) \ No newline at end of file