Spaces:
Runtime error
Runtime error
from pathlib import Path | |
import logging | |
from datetime import datetime | |
from tqdm import tqdm | |
import numpy as np | |
import torch | |
import torchvision.transforms.functional as tF | |
from torch.utils.data.dataloader import DataLoader | |
from torchvision.datasets import ImageFolder | |
from torch.utils.data import TensorDataset, Subset | |
from torchmetrics.image.lpip import LearnedPerceptualImagePatchSimilarity as LPIPS | |
from torchmetrics.functional import multiscale_structural_similarity_index_measure as mmssim | |
from medical_diffusion.models.embedders.latent_embedders import VAE | |
# ----------------Settings --------------
# Number of images per DataLoader batch.
batch_size = 100
# Upper bound on the number of dataset entries to evaluate; set to None for all.
max_samples = None
# Class name (key of the dataset's `class_to_idx`) to restrict the evaluation to; None for no specific class.
target_class = None

# Directory that receives the metrics log. One line per dataset, enable exactly one.
# path_out = Path.cwd()/'results'/'MSIvsMSS_2'/'metrics'
# path_out = Path.cwd()/'results'/'AIROGS'/'metrics'
path_out = Path.cwd()/'results'/'CheXpert'/'metrics'
path_out.mkdir(parents=True, exist_ok=True)

# Run on the GPU when one is available, otherwise fall back to the CPU.
device = 'cuda' if torch.cuda.is_available() else 'cpu'
# ----------------- Logging -----------
# Timestamp that gives every run its own log file name.
current_time = datetime.now().strftime("%Y_%m_%d_%H%M%S")
logger = logging.getLogger()
logging.basicConfig(level=logging.INFO)
# Mirror all records into a per-run file. The encoding is pinned to UTF-8 because the
# summary messages contain the non-ASCII character "±", which cannot be written when
# the platform's default encoding is ASCII-only.
logger.addHandler(logging.FileHandler(path_out/f'metrics_{current_time}.log', 'w', encoding='utf-8'))
# -------------- Helpers ---------------------
def pil2torch(x) -> torch.Tensor:
    """Convert an image to a tensor with its last axis moved to the front.

    In contrast to ``ToTensor()``, this will not cast 0-255 to 0-1, so a uint8
    image stays uint8 (required later).

    Defined with ``def`` rather than as a lambda so the function has a proper
    name and can be pickled, e.g. when DataLoader workers are started with the
    ``spawn`` method.

    Args:
        x: Anything ``np.array`` accepts, e.g. a PIL image or an ndarray.

    Returns:
        Tensor holding the values of ``x`` with the last axis first
        (``H x W x C`` becomes ``C x H x W``).
    """
    return torch.as_tensor(np.array(x)).moveaxis(-1, 0)
# ---------------- Dataset/Dataloader ----------------
# Reference images; the transform keeps them as raw 0-255 tensors.
# NOTE(review): the active dataset is the pathology set, while `path_out` and the model
# checkpoint refer to CheXpert / chest - confirm that the intended dataset line is enabled.
ds_real = ImageFolder('/mnt/hdd/datasets/pathology/kather_msi_mss_2/train/', transform=pil2torch)
# ds_real = ImageFolder('/mnt/hdd/datasets/eye/AIROGS/data_256x256_ref/', transform=pil2torch)
# ds_real = ImageFolder('/mnt/hdd/datasets/chest/CheXpert/ChecXpert-v10/reference_test/', transform=pil2torch)

# ---------- Limit Sample Size
# slice(None) keeps every entry, slice(n) keeps the first n entries.
# NOTE(review): the limit is applied before the class filter below, so a small
# `max_samples` may leave few or no samples of `target_class` - verify this is intended.
ds_real.samples = ds_real.samples[slice(max_samples)]

# --------- Select specific class ------------
if target_class is not None:
    # Resolve the class index once instead of once per sample.
    target_idx = ds_real.class_to_idx[target_class]
    keep_indices = [i for i, (_, label) in enumerate(ds_real.samples) if label == target_idx]
    ds_real = Subset(ds_real, keep_indices)

# Fixed order and no dropped remainder, so every selected sample is evaluated exactly once.
dm_real = DataLoader(ds_real, batch_size=batch_size, num_workers=8, shuffle=False, drop_last=False)

logger.info(f"Samples Real: {len(ds_real)}")
# --------------- Load Model ------------------
# Latent embedder whose reconstructions are compared against the real images.
model = VAE.load_from_checkpoint('runs/2022_12_12_133315_chest_vaegan/last_vae.ckpt')
model.to(device)

# Alternative: evaluate the VAE of Stable Diffusion instead (reads the token from 'auth_token.txt').
# from diffusers import StableDiffusionPipeline
# with open('auth_token.txt', 'r') as file:
#     auth_token = file.read()
# pipe = StableDiffusionPipeline.from_pretrained("CompVis/stable-diffusion-v1-4", torch_dtype=torch.float32, use_auth_token=auth_token)
# model = pipe.vae
# model.to(device)

# Put whichever model was selected above into evaluation mode, so that train-only layer
# behaviour (dropout, batch-norm statistic updates), if the model has any, cannot
# perturb the metrics.
model.eval()
# ------------- Init Metrics ----------------------
# LPIPS is accumulated over all batches inside the metric object.
calc_lpips = LPIPS().to(device)

# --------------- Start Calculation -----------------
# One entry per image; every entry is a 0-dim tensor.
mmssim_list, mse_list = [], []
for real_batch in tqdm(dm_real):
    # real_batch is what the DataLoader yields; element 0 holds the images.
    imgs_real_batch = real_batch[0].to(device)
    imgs_real_batch = tF.normalize(imgs_real_batch/255, 0.5, 0.5)  # [0, 255] -> [-1, 1]

    # Reconstruct the batch; only the first model output is used here.
    with torch.no_grad():
        imgs_fake_batch = model(imgs_real_batch)[0].clamp(-1, 1)

    # -------------- LPIP -------------------
    calc_lpips.update(imgs_real_batch, imgs_fake_batch)  # expect input to be [-1, 1]

    # -------------- MS-SSIM + MSE -------------------
    # Rescale once per batch instead of once per image: [-1, 1] -> [0, 1]
    imgs_real_unit = (imgs_real_batch+1)/2
    imgs_fake_unit = (imgs_fake_batch+1)/2
    for img_real, img_fake in zip(imgs_real_unit, imgs_fake_unit):
        # data_range is given explicitly: the images were just mapped to [0, 1]. Without it
        # the metric derives the range from each image pair (max - min), which makes the
        # score depend on the contrast of the individual image.
        mmssim_list.append(mmssim(img_real[None], img_fake[None], data_range=1.0, normalize='relu'))
        mse_list.append(torch.mean(torch.square(img_real-img_fake)))
# -------------- Summary -------------------
if not mmssim_list:
    # torch.stack() rejects an empty list with an opaque message; fail with a hint instead.
    raise RuntimeError("No samples were evaluated - check the dataset path, `max_samples` and `target_class`.")

# Turn the per-image 0-dim tensors into 1-D tensors for the statistics below.
mmssim_list = torch.stack(mmssim_list)
mse_list = torch.stack(mse_list)

# NOTE: the value logged as "LPIPS Score" is 1 - LPIPS, i.e. the complement of the
# accumulated LPIPS value, not the raw metric.
lpips = 1-calc_lpips.compute()
logger.info(f"LPIPS Score: {lpips}")
# NOTE: torch.std() uses the unbiased estimator by default and is NaN for a single sample.
logger.info(f"MS-SSIM: {torch.mean(mmssim_list)} ± {torch.std(mmssim_list)}")
logger.info(f"MSE: {torch.mean(mse_list)} ± {torch.std(mse_list)}")