Spaces:
Sleeping
Sleeping
from matplotlib import pyplot as plt | |
import math | |
import numpy as np | |
def visualize_results(img, mask, pred, n_slices: int=3, slices: list=None, title: str=""):
    """Plot image, ground-truth overlay and prediction overlay for several slices.

    Each figure row is one slice along the last axis.
    Column 0 shows channel 0 of ``img`` in grayscale.
    Columns 1 and 2 show the same image with channels 1 and 2 of ``mask``
    (ground truth) and of ``pred`` (prediction) drawn on top.
    Zero-valued voxels of the overlays are left transparent.

    img: tensor [C, H, W, Z]
    mask: tensor [C, H, W, Z]
    pred: tensor [C, H, W, Z]
    n_slices: number of slices to visualize (ignored when ``slices`` is given);
        slice ``i`` is ``i * (Z // n_slices)``
    slices: explicit list of slice indices along Z to visualize
    title: title of the plot

    Raises ValueError if there are no slices to show.
    """
    if slices is not None:
        n_slices = len(slices)
    if n_slices <= 0:
        raise ValueError("n_slices must be positive (or slices must be non-empty)")
    # squeeze=False keeps ``ax`` 2-D even for a single row, so ax[i, c] is always valid
    # (with the default squeeze=True, n_slices == 1 yields a 1-D array and ax[i, c] fails).
    _, ax = plt.subplots(n_slices, 3, figsize=(14, 5 * n_slices), squeeze=False)
    inc = img.shape[-1] // n_slices
    # Mask zero voxels so only labelled regions are painted over the image.
    mask_masked = np.ma.masked_where(mask == 0, mask)
    pred_masked = np.ma.masked_where(pred == 0, pred)
    column_titles = ("image", "ground truth", "predicted")
    for i in range(n_slices):
        slice_num = i * inc if slices is None else slices[i]
        # Every column gets the grayscale image as background.
        for c in range(3):
            ax[i, c].imshow(img[0, :, :, slice_num], cmap="gray")
            ax[i, c].axis("off")
            ax[i, c].set_title(column_titles[c])
        # Columns 1 and 2 overlay channels 1 and 2 of ground truth / prediction.
        for c, overlay in ((1, mask_masked), (2, pred_masked)):
            ax[i, c].imshow(overlay[1, :, :, slice_num], cmap='jet', vmin=1, vmax=4,
                            interpolation='none', alpha=0.5)
            ax[i, c].imshow(overlay[2, :, :, slice_num], cmap='Reds', vmin=0, vmax=1.3,
                            interpolation='none', alpha=0.8)
    plt.suptitle(title, size=14)
    plt.tight_layout()
    plt.show()
def visualize_patient(img, mask=None, n_slices: int=3, slices: list=None, z_dim_last=True, mask_channel=0, title: str=""):
    """Show slices of a single volume in a 3-column grid, optionally with a label overlay.

    Two figures are drawn.
    The first overlays channel ``mask_channel`` of ``mask`` (when given) on each slice.
    The second repeats the same slices without any overlay.

    img: tensor [C, H, W, Z] when ``z_dim_last`` is True, else [Z, C, H, W];
        channel 0 is displayed in grayscale
    mask: optional tensor with the same layout as ``img``; zero voxels are
        left transparent in the overlay
    n_slices: number of slices to visualize (ignored when ``slices`` is given);
        slice ``i`` is ``i * (depth // n_slices)``
    slices: explicit list of slice indices to visualize
    z_dim_last: True if the slice axis is the last axis, False if it is the first
    mask_channel: channel of ``mask`` to overlay
    title: title of both figures

    Raises ValueError if there are no slices to show.
    """
    if slices is not None:
        n_slices = len(slices)
    if n_slices <= 0:
        raise ValueError("n_slices must be positive (or slices must be non-empty)")
    # Mask zeros so only labelled voxels are painted; nothing to mask without a mask.
    masked = None if mask is None else np.ma.masked_where(mask == 0, mask)
    _plot_slice_grid(img, masked, n_slices, slices, z_dim_last, mask_channel, title)
    # NOTE(review): a second figure of the same slices without the overlay is
    # always shown as well; confirm this duplicate figure is intended.
    _plot_slice_grid(img, None, n_slices, slices, z_dim_last, mask_channel, title)


def _volume_slice(vol, slice_num, channel, z_dim_last):
    """Return the 2-D [H, W] plane ``slice_num`` of ``channel`` from ``vol``."""
    if z_dim_last:
        return vol[channel, :, :, slice_num]
    return vol[slice_num, channel, :, :]


def _plot_slice_grid(img, masked, n_slices, slices, z_dim_last, mask_channel, title):
    """Draw one figure of ``n_slices`` slices, three per row, overlaying ``masked`` if given."""
    n_rows = math.ceil(n_slices / 3)
    # squeeze=False keeps ``ax`` 2-D for any row count, so ax[r, c] is always valid.
    _, ax = plt.subplots(n_rows, 3, figsize=(14, 5 * n_rows), squeeze=False)
    depth = img.shape[-1] if z_dim_last else img.shape[0]
    inc = depth // n_slices
    for i in range(n_slices):
        r, c = divmod(i, 3)
        slice_num = i * inc if slices is None else slices[i]
        cell = ax[r, c]
        cell.imshow(_volume_slice(img, slice_num, 0, z_dim_last), cmap="gray")
        cell.axis("off")
        cell.set_title(f'slice {slice_num}')
        if masked is not None:
            cell.imshow(_volume_slice(masked, slice_num, mask_channel, z_dim_last),
                        cmap='jet', vmin=1, vmax=4, interpolation='none', alpha=0.4)
    # Blank out grid cells left over when n_slices is not a multiple of 3.
    for j in range(n_slices, n_rows * 3):
        ax[divmod(j, 3)].axis("off")
    plt.suptitle(title, size=14)
    plt.tight_layout()
    plt.show()