|
from matplotlib import pyplot as plt |
|
import math |
|
import numpy as np |
|
|
|
|
|
|
|
def visualize_results(img, mask, pred, n_slices: int=3, slices: list=None, title: str=""):
    """Show image slices next to ground-truth and predicted mask overlays.

    Each row is one slice with three panels:
    - the image alone;
    - the image with the ground-truth overlay;
    - the image with the predicted overlay.

    For ``mask``/``pred``, channel 1 is drawn with the 'jet' colormap
    (vmin=1, vmax=4) and channel 2 on top of it with 'Reds' (vmin=0,
    vmax=1.3). Zero-valued voxels are masked out so the image shows through.

    img: tensor [C, H, W, Z]; channel 0 is displayed
    mask: tensor [C, H, W, Z] ground truth; channels 1 and 2 are used
    pred: tensor [C, H, W, Z] prediction; channels 1 and 2 are used
    n_slices: number of evenly spaced slices to visualize (ignored if ``slices`` is given)
    slices: explicit list of slice indices (along the last axis) to visualize
    title: title of the plot

    Raises ValueError if there are no slices to plot.
    """
    if slices is not None:
        n_slices = len(slices)
    if n_slices < 1:
        raise ValueError(f"need at least one slice to plot, got n_slices={n_slices}")

    # squeeze=False keeps `ax` 2-D even when n_slices == 1. Otherwise plt.subplots
    # returns a 1-D array and the ax[i, c] indexing below raises IndexError.
    fig, ax = plt.subplots(n_slices, 3, figsize=(14, 5 * n_slices), squeeze=False)

    # Spacing between evenly spaced slices; only used when `slices` is None.
    inc = img.shape[-1] // n_slices

    # Mask zero voxels so they are transparent in the overlays.
    mask_masked = np.ma.masked_where(mask == 0, mask)
    pred_masked = np.ma.masked_where(pred == 0, pred)

    for i in range(n_slices):
        slice_num = i * inc if slices is None else slices[i]

        # Background image in every panel of the row.
        for c in range(3):
            ax[i, c].imshow(img[0, :, :, slice_num], cmap="gray")
            ax[i, c].axis("off")
        ax[i, 0].set_title('image')

        # Overlays: column 1 = ground truth, column 2 = prediction.
        for col, overlay, label in ((1, mask_masked, 'ground truth'), (2, pred_masked, 'predicted')):
            ax[i, col].imshow(overlay[1, :, :, slice_num], cmap='jet', vmin=1, vmax=4, interpolation='none', alpha=0.5)
            ax[i, col].imshow(overlay[2, :, :, slice_num], cmap='Reds', vmin=0, vmax=1.3, interpolation='none', alpha=0.8)
            ax[i, col].set_title(label)

    fig.suptitle(title, size=14)
    plt.tight_layout()
    plt.show()
|
|
|
|
|
def visualize_patient(img, mask=None, n_slices: int=3, slices: list=None, z_dim_last=True, mask_channel=0, title: str=""):
    """Plot slices of a volume in a 3-column grid, optionally with a mask overlay.

    When ``mask`` is given, two figures are shown. The first overlays the
    non-zero voxels of ``mask[mask_channel]`` on each image slice. The second
    shows the same slices without the overlay, for comparison. Without a mask
    only the plain figure is shown; the second figure would be an identical
    duplicate.

    img: volume indexed as [C, H, W, Z] when ``z_dim_last`` is True, otherwise as
         [Z, C, H, W]; channel 0 is displayed
    mask: optional volume with the same layout as ``img``; zeros are left transparent
    n_slices: number of evenly spaced slices to show (ignored if ``slices`` is given)
    slices: explicit list of slice indices to show
    z_dim_last: True if the slice axis is the last axis, False if it is the first
    mask_channel: channel of ``mask`` to overlay
    title: figure title

    Raises ValueError if there are no slices to plot.
    """
    if slices is not None:
        n_slices = len(slices)
    if n_slices < 1:
        raise ValueError(f"need at least one slice to plot, got n_slices={n_slices}")

    if slices is not None:
        slice_nums = list(slices)
    else:
        depth = img.shape[-1] if z_dim_last else img.shape[0]
        inc = depth // n_slices
        slice_nums = [i * inc for i in range(n_slices)]

    # Mask zero voxels so they are transparent in the overlay.
    masked = np.ma.masked_where(mask == 0, mask) if mask is not None else None

    _plot_patient_slices(img, masked, slice_nums, z_dim_last, mask_channel, title)
    if masked is not None:
        # Same slices without the overlay, for side-by-side comparison.
        _plot_patient_slices(img, None, slice_nums, z_dim_last, mask_channel, title)


def _plot_patient_slices(img, masked, slice_nums, z_dim_last, mask_channel, title):
    """Draw one figure of image slices in a 3-column grid, overlaying ``masked`` if not None."""
    n_rows = math.ceil(len(slice_nums) / 3)
    # squeeze=False keeps `ax` 2-D for a single row too, so every cell is ax[r, c].
    fig, ax = plt.subplots(n_rows, 3, figsize=(14, 5 * n_rows), squeeze=False)

    for i, slice_num in enumerate(slice_nums):
        r, c = divmod(i, 3)
        cell = ax[r, c]
        if z_dim_last:
            cell.imshow(img[0, :, :, slice_num], cmap="gray")
        else:
            cell.imshow(img[slice_num, 0, :, :], cmap="gray")
        cell.axis("off")
        cell.set_title(f'slice {slice_num}')

        if masked is not None:
            if z_dim_last:
                overlay = masked[mask_channel, :, :, slice_num]
            else:
                overlay = masked[slice_num, mask_channel, :, :]
            cell.imshow(overlay, cmap='jet', vmin=1, vmax=4, interpolation='none', alpha=0.4)

    # Hide the empty cells of a partially filled last row.
    for cell in ax.flat[len(slice_nums):]:
        cell.axis("off")

    fig.suptitle(title, size=14)
    plt.tight_layout()
    plt.show()