# Source: medfusion-app/scripts/evaluate_latent_embedder.py
# Author: mueller-franzes - commit f85e212 ("init"), 3.96 kB
from pathlib import Path
import logging
from datetime import datetime
from tqdm import tqdm
import numpy as np
import torch
import torchvision.transforms.functional as tF
from torch.utils.data.dataloader import DataLoader
from torchvision.datasets import ImageFolder
from torch.utils.data import TensorDataset, Subset
from torchmetrics.image.lpip import LearnedPerceptualImagePatchSimilarity as LPIPS
from torchmetrics.functional import multiscale_structural_similarity_index_measure as mmssim
from medical_diffusion.models.embedders.latent_embedders import VAE
# ---------------- Settings ----------------
# Evaluation knobs.
batch_size = 100
max_samples = None  # upper bound on evaluated images; None evaluates the whole dataset
target_class = None  # ImageFolder class name to restrict evaluation to; None keeps all classes

# Directory receiving the metrics log (alternatives kept for quick switching):
# path_out = Path.cwd()/'results'/'MSIvsMSS_2'/'metrics'
# path_out = Path.cwd()/'results'/'AIROGS'/'metrics'
path_out = Path.cwd().joinpath('results', 'CheXpert', 'metrics')
path_out.mkdir(parents=True, exist_ok=True)

# Prefer the GPU when one is visible, otherwise run on the CPU.
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'
# ----------------- Logging -----------
# Log to the console (root logger at INFO via basicConfig) and additionally to a
# timestamped file in path_out. The file handler has no formatter, so only the
# bare messages are written to the file.
current_time = datetime.now().strftime("%Y_%m_%d_%H%M%S")
logger = logging.getLogger()
logging.basicConfig(level=logging.INFO)
# Explicit UTF-8: the summary lines contain "±", which a platform-default encoding
# (e.g. an ASCII/C locale) cannot encode, silently dropping those lines from the file.
logger.addHandler(logging.FileHandler(path_out/f'metrics_{current_time}.log', mode='w', encoding='utf-8'))
# -------------- Helpers ---------------------
def pil2torch(x):
    """Convert a PIL image (or array-like) to a channel-first tensor without rescaling.

    In contrast to ``ToTensor()``, values are NOT cast from 0-255 to 0-1, so uint8
    data stays uint8 (the evaluation loop divides by 255 itself).

    Args:
        x: Image convertible with ``np.array``, laid out as (H, W, C) or (H, W).

    Returns:
        torch.Tensor of shape (C, H, W) with the input's dtype; a 2-D (H, W)
        input yields shape (1, H, W).
    """
    arr = np.array(x)
    if arr.ndim == 2:
        # Grayscale: add a channel axis instead of moving the width axis to the front.
        arr = arr[..., None]
    return torch.as_tensor(arr).moveaxis(-1, 0)
# ---------------- Dataset/Dataloader ----------------
# NOTE(review): the active dataset below is the pathology set, while path_out and the
# loaded checkpoint ('..._chest_vaegan') refer to CheXpert/chest data - confirm the
# intended pairing before trusting the written metrics.
ds_real = ImageFolder('/mnt/hdd/datasets/pathology/kather_msi_mss_2/train/', transform=pil2torch)
# ds_real = ImageFolder('/mnt/hdd/datasets/eye/AIROGS/data_256x256_ref/', transform=pil2torch)
# ds_real = ImageFolder('/mnt/hdd/datasets/chest/CheXpert/ChecXpert-v10/reference_test/', transform=pil2torch)
# ---------- Limit Sample Size
# slice(None) keeps every sample, so max_samples=None evaluates the whole folder.
ds_real.samples = ds_real.samples[slice(max_samples)]
# Keep ImageFolder's parallel attributes consistent with the truncated sample list.
ds_real.imgs = ds_real.samples
ds_real.targets = [label for _, label in ds_real.samples]
# --------- Select specific class ------------
if target_class is not None:
    target_idx = ds_real.class_to_idx[target_class]  # KeyError if no such class folder exists
    ds_real = Subset(ds_real, [i for i, label in enumerate(ds_real.targets) if label == target_idx])
dm_real = DataLoader(ds_real, batch_size=batch_size, num_workers=8, shuffle=False, drop_last=False)
logger.info(f"Samples Real: {len(ds_real)}")
# --------------- Load Model ------------------
model = VAE.load_from_checkpoint('runs/2022_12_12_133315_chest_vaegan/last_vae.ckpt')
model.to(device)
# from diffusers import StableDiffusionPipeline
# with open('auth_token.txt', 'r') as file:
#     auth_token = file.read()
# pipe = StableDiffusionPipeline.from_pretrained("CompVis/stable-diffusion-v1-4", torch_dtype=torch.float32, use_auth_token=auth_token)
# model = pipe.vae
# model.to(device)
# Evaluate in inference mode (dropout / batch-norm statistics frozen, if the model has
# them); loading a checkpoint is not guaranteed to switch the module out of training mode.
# Placed after the alternative above so it also applies if that branch is re-enabled.
model.eval()
# ------------- Init Metrics ----------------------
# LPIPS metric with torchmetrics' default AlexNet backbone, moved to the evaluation device.
calc_lpips = LPIPS(net_type='alex')
calc_lpips = calc_lpips.to(device)
# --------------- Start Calculation -----------------
# Per-image MS-SSIM and MSE scores, each stored as a 0-dim tensor (stacked in the summary).
mmssim_list, mse_list = [], []
for real_batch in tqdm(dm_real):
    imgs_real_batch = real_batch[0].to(device)
    imgs_real_batch = tF.normalize(imgs_real_batch/255, 0.5, 0.5) # [0, 255] -> [-1, 1]
    # None of the metrics needs gradients, so keep reconstruction AND metric updates graph-free.
    with torch.no_grad():
        imgs_fake_batch = model(imgs_real_batch)[0].clamp(-1, 1)
        # -------------- LPIP -------------------
        calc_lpips.update(imgs_real_batch, imgs_fake_batch) # expect input to be [-1, 1]
        # -------------- MS-SSIM + MSE -------------------
        imgs_real_01 = (imgs_real_batch+1)/2  # [-1, 1] -> [0, 1]
        imgs_fake_01 = (imgs_fake_batch+1)/2
        # data_range=1.0 because both images are known to lie in [0, 1]. With the default
        # (None) torchmetrics infers max-min from the data of each call, so the SSIM
        # stabilisation constants would change from image to image.
        # reduction='none' yields one score per image; extend() stores them as 0-dim tensors.
        mmssim_list.extend(mmssim(imgs_real_01, imgs_fake_01, data_range=1.0, normalize='relu', reduction='none'))
        # Per-image MSE over all channels and pixels.
        mse_list.extend(torch.square(imgs_real_01-imgs_fake_01).flatten(1).mean(dim=1))
# -------------- Summary -------------------
# Fail with a clear message instead of an opaque torch.stack error on an empty run.
if not mse_list:
    raise RuntimeError("No images were evaluated - check the dataset path, max_samples and target_class.")
mmssim_list = torch.stack(mmssim_list)
mse_list = torch.stack(mse_list)
# The reported value is 1 - LPIPS distance (higher = more similar), not the raw distance.
lpips = 1-calc_lpips.compute()
# .item() logs plain numbers instead of tensor reprs such as "tensor(0.12, device='cuda:0')".
# torch.std is the unbiased estimate, so the spread is NaN when only one image was evaluated.
logger.info(f"LPIPS Score: {lpips.item()}")
logger.info(f"MS-SSIM: {torch.mean(mmssim_list).item()} ± {torch.std(mmssim_list).item()}")
logger.info(f"MSE: {torch.mean(mse_list).item()} ± {torch.std(mse_list).item()}")