# medassist-liver-cancer / utils/visualization.py
# Uploaded by lingchmao in commit 6ffe23f ("Upload 12 files", verified); file size 4.44 kB.
from matplotlib import pyplot as plt
import math
import numpy as np
def visualize_results(img, mask, pred, n_slices: int=3, slices: list=None, title: str=""):
    """
    Plot image, ground-truth overlay and prediction overlay for several slices.

    Each figure row is one slice taken along the last axis. The three columns
    show the image alone, the image with the ground-truth mask overlaid, and
    the image with the predicted mask overlaid.

    img: tensor [C, H, W, Z]; channel 0 is drawn in grayscale
    mask: tensor [C, H, W, Z]; channels 1 and 2 are overlaid, zeros transparent
    pred: tensor [C, H, W, Z]; channels 1 and 2 are overlaid, zeros transparent
    n_slices: number of evenly spaced slices to visualize (ignored when
        ``slices`` is given)
    slices: explicit list of slice indices to visualize
    title: title of the plot

    Raises ValueError if no slices are requested.
    """
    if slices is not None:
        n_slices = len(slices)
    if n_slices < 1:
        raise ValueError("visualize_results needs at least one slice to plot")

    depth = img.shape[-1]
    # Clamp the step to 1 so that requesting more slices than the volume has
    # does not collapse every row onto slice 0; for n_slices <= depth this
    # yields exactly the original i * (depth // n_slices) indices.
    inc = max(depth // n_slices, 1)

    # Mask out background (value 0) so only labelled voxels are drawn on top.
    mask_masked = np.ma.masked_where(mask == 0, mask)
    pred_masked = np.ma.masked_where(pred == 0, pred)

    # squeeze=False keeps `ax` 2-D even for a single row, so ax[i, c] always works.
    _, ax = plt.subplots(n_slices, 3, figsize=(14, 5 * n_slices), squeeze=False)
    for i in range(n_slices):
        slice_num = min(i * inc, depth - 1) if slices is None else slices[i]
        # Every column uses the grayscale image as its background.
        for c, col_title in enumerate(("image", "ground truth", "predicted")):
            ax[i, c].imshow(img[0, :, :, slice_num], cmap="gray")
            ax[i, c].axis("off")
            ax[i, c].set_title(col_title)
        # Column 1 gets the ground truth, column 2 the prediction; both overlay
        # channel 1 (jet) and then channel 2 (Reds) with identical styling.
        for c, overlay in ((1, mask_masked), (2, pred_masked)):
            ax[i, c].imshow(overlay[1, :, :, slice_num], cmap='jet', vmin=1, vmax=4, interpolation='none', alpha=0.5)
            ax[i, c].imshow(overlay[2, :, :, slice_num], cmap='Reds', vmin=0, vmax=1.3, interpolation='none', alpha=0.8)
    plt.suptitle(title, size=14)
    plt.tight_layout()
    plt.show()
def _volume_plane(volume, channel, slice_num, z_dim_last):
    """Return the 2-D plane `slice_num` of `channel` from a 4-D volume."""
    if z_dim_last:
        return volume[channel, :, :, slice_num]
    return volume[slice_num, channel, :, :]


def _plot_patient_grid(img, overlay, slice_nums, z_dim_last, mask_channel, title):
    """Show one figure: a 3-column grid of image slices, optionally overlaid."""
    n_rows = math.ceil(len(slice_nums) / 3)
    # squeeze=False gives a 2-D axes array for any row count; ravel() walks it
    # row-major, matching the original divmod(i, 3) -> (row, col) placement.
    _, ax = plt.subplots(n_rows, 3, figsize=(14, 5 * n_rows), squeeze=False)
    axes = ax.ravel()
    for axis, slice_num in zip(axes, slice_nums):
        axis.imshow(_volume_plane(img, 0, slice_num, z_dim_last), cmap="gray")
        axis.axis("off")
        axis.set_title(f'slice {slice_num}')
        if overlay is not None:
            axis.imshow(_volume_plane(overlay, mask_channel, slice_num, z_dim_last), cmap='jet', vmin=1, vmax=4, interpolation='none', alpha=0.4)
    # Hide the empty cells of a partially filled last row.
    for axis in axes[len(slice_nums):]:
        axis.axis("off")
    plt.suptitle(title, size=14)
    plt.tight_layout()
    plt.show()


def visualize_patient(img, mask=None, n_slices: int=3, slices: list=None, z_dim_last=True, mask_channel=0, title: str=""):
    """
    Plot slices of a patient volume in a grid with three columns.

    Two figures are shown one after the other. The first overlays
    ``mask[mask_channel]`` on the image when ``mask`` is given. The second
    shows the same slices of the image with no overlay.

    img: tensor [C, H, W, Z] if ``z_dim_last`` else [Z, C, H, W]; channel 0
        is drawn in grayscale
    mask: optional tensor laid out like ``img``; zero voxels are transparent
    n_slices: number of evenly spaced slices to visualize (ignored when
        ``slices`` is given)
    slices: explicit list of slice indices to visualize
    z_dim_last: True if the slice axis is the last axis, False if it is the first
    mask_channel: channel of ``mask`` to overlay
    title: title of both figures

    Raises ValueError if no slices are requested.
    """
    if slices is not None:
        n_slices = len(slices)
    if n_slices < 1:
        raise ValueError("visualize_patient needs at least one slice to plot")

    depth = img.shape[-1] if z_dim_last else img.shape[0]
    # Clamp the step to 1 so that requesting more slices than the volume has
    # does not collapse every panel onto slice 0; for n_slices <= depth this
    # yields exactly the original i * (depth // n_slices) indices.
    inc = max(depth // n_slices, 1)
    if slices is None:
        slice_nums = [min(i * inc, depth - 1) for i in range(n_slices)]
    else:
        slice_nums = list(slices)

    # Only build the masked overlay when there is a mask to draw.
    masked = None if mask is None else np.ma.masked_where(mask == 0, mask)
    _plot_patient_grid(img, masked, slice_nums, z_dim_last, mask_channel, title)
    # Second figure: the same slices with no overlay.
    # NOTE(review): when mask is None this figure is identical to the first —
    # confirm whether the duplicate is intended.
    _plot_patient_grid(img, None, slice_nums, z_dim_last, mask_channel, title)