Spaces:
Sleeping
Sleeping
import os | |
import zipfile | |
from pathlib import Path | |
from time import time | |
from typing import Union | |
import matplotlib.pyplot as plt | |
import dosma | |
import numpy as np | |
import wget | |
import cv2 | |
import scipy.misc | |
from PIL import Image | |
import dicom2nifti | |
import math | |
import pydicom | |
import operator | |
import moviepy.video.io.ImageSequenceClip | |
from tkinter import Tcl | |
import pandas as pd | |
import warnings | |
import numpy as np | |
from skimage.morphology import skeletonize_3d | |
from scipy.spatial.distance import pdist, squareform | |
from scipy.interpolate import splprep, splev | |
import nibabel as nib | |
from nibabel.processing import resample_to_output | |
import matplotlib.pyplot as plt | |
from scipy.interpolate import interp1d | |
from totalsegmentator.libs import ( | |
download_pretrained_weights, | |
nostdout, | |
setup_nnunet, | |
) | |
from comp2comp.inference_class_base import InferenceClass | |
from comp2comp.models.models import Models | |
from comp2comp.spine import spine_utils | |
import nibabel as nib | |
class AortaSegmentation(InferenceClass):
    """Aorta segmentation with the task-253 nnU-Net (TotalSegmentator) model.

    Several method names, log messages and the output file name still say
    "spine"; they are kept unchanged for backward compatibility, but the
    model that is downloaded and run is ``Task253_Aorta``.
    """

    def __init__(self, save=True):
        super().__init__()
        self.model_name = "totalsegmentator"
        self.save_segmentations = save

    def __call__(self, inference_pipeline):
        """Segment ``segmentations/converted_dcm.nii.gz`` and store slices on the pipeline.

        The predicted mask is written to ``segmentations/spine.nii.gz``.
        Afterwards ``inference_pipeline.ct_image`` and
        ``inference_pipeline.axial_masks`` hold lists of 2D arrays (one per
        index along the last axis of the image / mask volume), and
        ``inference_pipeline.output_dir_segmentations`` names the folder
        holding both NIfTI files.

        Returns:
            dict: always empty.
        """
        self.output_dir = inference_pipeline.output_dir
        self.output_dir_segmentations = os.path.join(self.output_dir, "segmentations/")
        os.makedirs(self.output_dir_segmentations, exist_ok=True)
        # The diameter stage looks this folder up on the pipeline object.
        inference_pipeline.output_dir_segmentations = self.output_dir_segmentations

        self.model_dir = inference_pipeline.model_dir

        seg, mv = self.spine_seg(
            os.path.join(self.output_dir_segmentations, "converted_dcm.nii.gz"),
            os.path.join(self.output_dir_segmentations, "spine.nii.gz"),
            inference_pipeline.model_dir,
        )

        seg = seg.get_fdata()
        medical_volume = mv.get_fdata()

        # Split both volumes into 2D slices along their last axis.
        inference_pipeline.ct_image = [
            medical_volume[:, :, i] for i in range(medical_volume.shape[2])
        ]
        inference_pipeline.axial_masks = [seg[:, :, i] for i in range(seg.shape[2])]
        return {}

    def setup_nnunet_c2c(self, model_dir: Union[str, Path]):
        """Create the nnU-Net folder layout under ``model_dir`` and export its env vars.

        Adapted from TotalSegmentator. Sets ``self.weights_dir`` to
        ``<model_dir>/.<model_name>/nnunet/results`` and points the
        ``nnUNet_raw_data_base``, ``nnUNet_preprocessed`` and
        ``RESULTS_FOLDER`` environment variables at it.
        """
        model_dir = Path(model_dir)
        config_dir = model_dir / Path("." + self.model_name)
        (config_dir / "nnunet/results/nnUNet/3d_fullres").mkdir(exist_ok=True, parents=True)
        (config_dir / "nnunet/results/nnUNet/2d").mkdir(exist_ok=True, parents=True)
        weights_dir = config_dir / "nnunet/results"
        self.weights_dir = weights_dir

        # The raw/preprocessed variables are not used for inference; they only
        # need to name an existing directory, so the weights folder is reused.
        os.environ["nnUNet_raw_data_base"] = str(weights_dir)
        os.environ["nnUNet_preprocessed"] = str(weights_dir)
        os.environ["RESULTS_FOLDER"] = str(weights_dir)

    def download_spine_model(self, model_dir: Union[str, Path]):
        """Download the task-253 fold-0 weights and ``plans.pkl`` if missing.

        Must be called after :meth:`setup_nnunet_c2c`, which sets
        ``self.weights_dir``. ``model_dir`` is unused and kept only for
        interface compatibility. Each artifact is checked separately, so a
        previously interrupted download is completed rather than skipped.
        """
        base_url = "https://huggingface.co/AdritRao/aaa_test/resolve/main/"
        download_dir = Path(
            os.path.join(
                self.weights_dir,
                "nnUNet/3d_fullres/Task253_Aorta/nnUNetTrainerV2_ep4000_nomirror__nnUNetPlansv2.1",
            )
        )
        print(download_dir)
        fold_0_path = download_dir / "fold_0"
        plans_path = download_dir / "plans.pkl"

        if fold_0_path.exists() and plans_path.exists():
            print("Spine model already downloaded.")
            return

        download_dir.mkdir(parents=True, exist_ok=True)
        if not fold_0_path.exists():
            zip_path = download_dir / "fold_0.zip"
            # Remove a partial archive left by an interrupted run before
            # downloading again, so the fresh archive is the one extracted.
            if zip_path.exists():
                zip_path.unlink()
            wget.download(base_url + "fold_0.zip", out=str(zip_path))
            with zipfile.ZipFile(zip_path, "r") as zip_ref:
                zip_ref.extractall(download_dir)
            zip_path.unlink()
        if not plans_path.exists():
            wget.download(base_url + "plans.pkl", out=str(plans_path))
        print("Spine model downloaded.")

    def spine_seg(
        self, input_path: Union[str, Path], output_path: Union[str, Path], model_dir
    ):
        """Run the task-253 aorta model on a NIfTI volume.

        Args:
            input_path (Union[str, Path]): NIfTI volume to segment.
            output_path (Union[str, Path]): Where nnU-Net writes the mask.
            model_dir: Root folder for the nnU-Net configuration and weights.

        Returns:
            tuple: ``(segmentation, image)`` where ``segmentation`` is a
            ``nib.Nifti1Image`` rebuilt from the predicted mask and ``image``
            is the image object returned by ``nnUNet_predict_image``.
        """
        print("Segmenting spine...")
        st = time()
        # os.environ only accepts str values; model_dir may be a Path.
        os.environ["SCRATCH"] = str(self.model_dir)
        print(self.model_dir)

        model = "3d_fullres"
        folds = [0]
        trainer = "nnUNetTrainerV2_ep4000_nomirror"
        crop_path = None
        task_id = [253]

        self.setup_nnunet_c2c(model_dir)
        self.download_spine_model(model_dir)

        # Imported lazily, after the nnU-Net environment variables are set.
        from totalsegmentator.nnunet import nnUNet_predict_image

        with nostdout():
            img, seg = nnUNet_predict_image(
                input_path,
                output_path,
                task_id,
                model=model,
                folds=folds,
                trainer=trainer,
                tta=False,
                multilabel_image=True,
                resample=1.5,
                crop=None,
                crop_path=crop_path,
                task_name="total",
                nora_tag="None",
                preview=False,
                nr_threads_resampling=1,
                nr_threads_saving=6,
                quiet=False,
                verbose=False,
                test=0,
            )
        end = time()

        print(f"Total time for spine segmentation: {end-st:.2f}s.")

        seg = nib.Nifti1Image(seg.get_fdata(), seg.affine, seg.header)
        return seg, img
class AortaDiameter(InferenceClass):
    """Measure aorta diameters slice by slice and along a 3D centerline.

    Reads ``inference_pipeline.ct_image`` and ``inference_pipeline.axial_masks``
    (lists of 2D slices) and writes, below ``<output_dir>/images``:

    * ``slices/slice<N>.png``: every slice with a mask contour, overlaid with
      the largest contour, its fitted ellipse, both ellipse axes and the
      area/diameter measurements;
    * ``summary/diameter_graph.png``: bar chart of the diameter per slice;
    * ``summary/out.png``: the widest slice, raw and annotated, side by side;
    * ``summary/aaa.mp4``: video of all annotated slices;
    * ``summary/planar_reconstruction.png``: centerline and widest
      cross-section computed on the 3D mask.

    Sets ``inference_pipeline.max_diameter`` to the largest per-slice
    diameter (in cm).
    """

    # Isotropic voxel size (mm) the NIfTI volumes are resampled to for the
    # 3D centerline analysis.
    RESAMPLE_SPACING_MM = 1.5
    # Intensity window applied to CT slices before scaling them to 0-255.
    WINDOW = (-300, 1800)

    def __init__(self):
        super().__init__()

    def normalize_img(self, img: np.ndarray) -> np.ndarray:
        """Min-max normalize ``img`` to the range [0, 1].

        Args:
            img (np.ndarray): Input image.

        Returns:
            np.ndarray: Normalized image. A constant image (max == min) is
            mapped to all zeros instead of dividing by zero.
        """
        img_min = img.min()
        img_range = img.max() - img_min
        if img_range == 0:
            return np.zeros_like(img, dtype=float)
        return (img - img_min) / img_range

    def __call__(self, inference_pipeline):
        """Run all measurements and write the output images and video.

        Returns:
            dict: always empty.

        Raises:
            ValueError: if no slice has a contour usable for ellipse fitting,
                or no centerline cross-section hits the 3D mask.
        """
        axial_masks = inference_pipeline.axial_masks  # list of 2D mask slices
        ct_img = inference_pipeline.ct_image  # list of 2D CT slices

        output_dir = inference_pipeline.output_dir
        output_dir_slices = os.path.join(output_dir, "images/slices/")
        output_dir_summary = os.path.join(output_dir, "images/summary/")
        os.makedirs(output_dir_slices, exist_ok=True)
        os.makedirs(output_dir_summary, exist_ok=True)

        pixel_spacing = self._read_pixel_spacing(inference_pipeline.dicom_series_path)
        print("Pixel conversion: " + str(pixel_spacing))

        diameter_dict = self._measure_axial_slices(
            ct_img, axial_masks, pixel_spacing, output_dir_slices
        )
        if not diameter_dict:
            raise ValueError("No measurable aorta contour found on any axial slice.")

        self._save_diameter_graph(diameter_dict, output_dir_summary)

        print(diameter_dict)
        # First slice number with the largest diameter (dict insertion order).
        max_slice = max(diameter_dict, key=diameter_dict.get)
        max_diameter = diameter_dict[max_slice]
        print(max_slice)
        print(max_diameter)
        inference_pipeline.max_diameter = max_diameter

        # Slice numbers were assigned as len(ct_img) - index.
        self._save_max_slice_comparison(
            ct_img[len(ct_img) - max_slice],
            output_dir_slices,
            max_slice,
            output_dir_summary,
        )
        self._save_slice_video(output_dir_slices, output_dir_summary)

        # Fall back to the folder the segmentation stage writes into when the
        # pipeline does not carry the attribute.
        output_dir_segmentations = getattr(
            inference_pipeline, "output_dir_segmentations", None
        ) or os.path.join(output_dir, "segmentations/")
        self._planar_reconstruction(output_dir_segmentations, output_dir_summary)
        return {}

    # ------------------------------------------------------------------ #
    # Per-slice 2D measurements
    # ------------------------------------------------------------------ #

    @staticmethod
    def _read_pixel_spacing(dicom_series_path):
        """Return the PixelSpacing attribute of the first file listed in the series folder."""
        first_file = os.listdir(dicom_series_path)[0]
        dicom = pydicom.dcmread(os.path.join(dicom_series_path, first_file))
        return dicom.PixelSpacing

    def _to_rgb(self, ct_slice):
        """Window a CT slice, scale it to 0-255 and replicate it into 3 float channels."""
        windowed = np.clip(ct_slice, self.WINDOW[0], self.WINDOW[1])
        scaled = self.normalize_img(windowed) * 255.0
        return np.tile(scaled[:, :, np.newaxis], (1, 1, 3))

    @staticmethod
    def _rotate_quarter(angle):
        """Shift an angle in degrees by 90 (down when above 90, otherwise up)."""
        return angle - 90 if angle > 90 else angle + 90

    @staticmethod
    def _draw_axis(img, xc, yc, angle, radius, color):
        """Draw a 3px line of half-length ``radius`` through (xc, yc) at ``angle`` degrees."""
        x1 = xc + math.cos(math.radians(angle)) * radius
        y1 = yc + math.sin(math.radians(angle)) * radius
        x2 = xc + math.cos(math.radians(angle + 180)) * radius
        y2 = yc + math.sin(math.radians(angle + 180)) * radius
        cv2.line(img, (int(x1), int(y1)), (int(x2), int(y2)), color, 3)

    @staticmethod
    def _put_labels(img, labels):
        """Write ``labels`` one per row (30px apart) in green at the top-left of ``img``."""
        h, w = img.shape[:2]
        font_scale = min(w, h) / (25 / 0.03)
        for row, label in enumerate(labels):
            cv2.putText(
                img,
                label,
                (10, 40 + 30 * row),
                cv2.FONT_HERSHEY_SIMPLEX,
                font_scale,
                (0, 255, 0),
                2,
            )

    def _measure_axial_slices(self, ct_img, axial_masks, pixel_spacing, output_dir_slices):
        """Fit an ellipse to the largest mask contour of each slice and annotate it.

        Returns:
            dict: slice number (``len(ct_img) - index``) -> ellipse minor-axis
            length in cm. Slices without a contour, or whose largest contour
            has fewer than the 5 points ``cv2.fitEllipse`` needs, are skipped.
        """
        slice_count = len(ct_img)
        mm_per_px = float(pixel_spacing[0])
        # Row spacing x column spacing = area of one pixel in mm^2.
        mm2_per_px = float(pixel_spacing[0]) * float(pixel_spacing[1])
        diameters = {}

        for i, ct_slice in enumerate(ct_img):
            slice_number = slice_count - i
            mask = axial_masks[i].astype("uint8")
            contours, _ = cv2.findContours(mask, cv2.RETR_LIST, cv2.CHAIN_APPROX_NONE)
            if not contours:
                continue
            # Largest contour by area (first one on ties).
            contour = max(contours, key=cv2.contourArea)
            if len(contour) < 5:
                continue

            img = self._to_rgb(ct_slice)
            # Blend a filled copy of the contour over the slice.
            back = img.copy()
            cv2.drawContours(back, [contour], 0, (0, 255, 0), -1)
            alpha = 0.25
            img = cv2.addWeighted(img, 1 - alpha, back, alpha, 0)

            ellipse = cv2.fitEllipse(contour)
            (xc, yc), (d1, d2), angle = ellipse
            cv2.ellipse(img, ellipse, (0, 255, 0), 1)
            cv2.circle(img, (int(xc), int(yc)), 5, (0, 0, 255), -1)
            rmajor = max(d1, d2) / 2
            rminor = min(d1, d2) / 2

            major_angle = self._rotate_quarter(angle)
            self._draw_axis(img, xc, yc, major_angle, rmajor, (0, 0, 255))
            minor_angle = self._rotate_quarter(major_angle)
            self._draw_axis(img, xc, yc, minor_angle, rminor, (255, 0, 0))

            # The reported diameter is the ellipse's minor axis.
            pixel_length = rminor * 2
            area_px = cv2.contourArea(contour)
            area_mm2 = round(area_px * mm2_per_px)
            area_cm2 = area_mm2 / 100
            diameter_mm = round(pixel_length * mm_per_px)
            diameter_cm = diameter_mm / 10
            diameters[slice_number] = diameter_cm

            img = cv2.rotate(img, cv2.ROTATE_90_COUNTERCLOCKWISE)
            self._put_labels(
                img,
                [
                    "Area (mm^2): " + str(area_mm2) + "mm^2",
                    "Area (cm^2): " + str(area_cm2) + "cm^2",
                    "Diameter (mm): " + str(diameter_mm) + "mm",
                    "Diameter (cm): " + str(diameter_cm) + "cm",
                    "Slice: " + str(slice_number),
                ],
            )
            cv2.imwrite(os.path.join(output_dir_slices, "slice" + str(slice_number) + ".png"), img)
        return diameters

    # ------------------------------------------------------------------ #
    # Summary outputs
    # ------------------------------------------------------------------ #

    @staticmethod
    def _save_diameter_graph(diameter_dict, output_dir_summary):
        """Save a bar chart of diameter (cm) per slice number on a fresh figure."""
        fig, ax = plt.subplots()
        ax.bar(list(diameter_dict.keys()), list(diameter_dict.values()), color="b")
        ax.set_title(r"$\bf{Diameter}$" + " " + r"$\bf{Progression}$")
        ax.set_xlabel("Slice Number")
        ax.set_ylabel("Diameter Measurement (cm)")
        fig.savefig(os.path.join(output_dir_summary, "diameter_graph.png"), dpi=500)
        plt.close(fig)

    def _save_max_slice_comparison(self, ct_slice, output_dir_slices, max_slice, output_dir_summary):
        """Save the raw (blue border) and annotated (green border) widest slice side by side."""
        raw = cv2.rotate(self._to_rgb(ct_slice), cv2.ROTATE_90_COUNTERCLOCKWISE)
        annotated = cv2.imread(os.path.join(output_dir_slices, "slice" + str(max_slice) + ".png"))

        border_size = 3
        annotated = cv2.copyMakeBorder(
            annotated,
            top=border_size,
            bottom=border_size,
            left=border_size,
            right=border_size,
            borderType=cv2.BORDER_CONSTANT,
            value=[0, 244, 0],
        )
        raw = cv2.copyMakeBorder(
            raw,
            top=border_size,
            bottom=border_size,
            left=border_size,
            right=border_size,
            borderType=cv2.BORDER_CONSTANT,
            value=[244, 0, 0],
        )
        vis = np.concatenate((raw, annotated), axis=1)
        cv2.imwrite(os.path.join(output_dir_summary, "out.png"), vis)

    @staticmethod
    def _save_slice_video(output_dir_slices, output_dir_summary, fps=20):
        """Combine the slice PNGs, in Tcl dictionary (natural) order, into ``aaa.mp4``."""
        image_files = [
            os.path.join(output_dir_slices, name)
            for name in Tcl().call("lsort", "-dict", os.listdir(output_dir_slices))
            if name.endswith(".png")
        ]
        clip = moviepy.video.io.ImageSequenceClip.ImageSequenceClip(image_files, fps=fps)
        clip.write_videofile(os.path.join(output_dir_summary, "aaa.mp4"))

    # ------------------------------------------------------------------ #
    # 3D centerline analysis
    # ------------------------------------------------------------------ #

    def _planar_reconstruction(self, output_dir_segmentations, output_dir_summary):
        """Fit a centerline to the 3D aorta mask and plot its widest cross-section."""
        spacing = self.RESAMPLE_SPACING_MM
        # converted_dcm.nii.gz is the CT volume fed to the segmentation model;
        # spine.nii.gz is the mask that model wrote.
        image = nib.load(os.path.join(output_dir_segmentations, "converted_dcm.nii.gz"))
        segmentation = nib.load(os.path.join(output_dir_segmentations, "spine.nii.gz"))

        image = resample_to_output(image, (spacing, spacing, spacing))
        # order=0 (nearest neighbour) keeps the label values intact.
        segmentation = resample_to_output(segmentation, (spacing, spacing, spacing), order=0)
        image = image.get_fdata()
        # Every non-zero label is foreground; a boolean mask also keeps
        # skeletonize_3d from rejecting label values greater than 1.
        aorta_mask = segmentation.get_fdata() != 0
        print(aorta_mask.shape)

        centerline_points = self._compute_centerline_3d(aorta_mask)
        tck = self._fit_bspline(centerline_points)
        evaluated_points = self._evaluate_bspline(tck)
        interpolated_points = self._interpolate_points(evaluated_points, image.shape[2])
        planes = self._compute_orthogonal_planes(tck)
        _, _, _, max_intersecting_points = self._compute_maximum_diameter(
            aorta_mask, planes, spacing
        )
        self._plot_2d_planar_reconstruction(
            interpolated_points,
            max_intersecting_points,
            os.path.join(output_dir_summary, "planar_reconstruction.png"),
        )

    @staticmethod
    def _compute_centerline_3d(aorta_segmentation):
        """Skeletonize the mask; return skeleton voxels as rows (axis2, axis1, axis0) sorted by axis2."""
        skeleton = skeletonize_3d(aorta_segmentation)
        z, y, x = np.where(skeleton)
        centerline_points = np.vstack((x, y, z)).T
        return centerline_points[centerline_points[:, 0].argsort()]

    @staticmethod
    def _fit_bspline(centerline_points, smoothness=1e8):
        """Fit a smoothing parametric B-spline through the points; return splprep's ``tck``."""
        x, y, z = centerline_points.T
        tck, _ = splprep([x, y, z], s=smoothness)
        return tck

    @staticmethod
    def _evaluate_bspline(tck, num_points=1000):
        """Sample the spline at ``num_points`` evenly spaced parameters as an (N, 3) array."""
        u = np.linspace(0, 1, num_points)
        return np.vstack(splev(u, tck)).T

    @staticmethod
    def _interpolate_points(data, num_points=32):
        """Resample ``data`` at first-column values 0..num_points-1 (nearest lookup), rounded."""
        x = data[:, 0]
        y = data[:, 1:]
        f_y = interp1d(x, y, kind="nearest", fill_value="extrapolate", axis=0)
        new_x = np.arange(0, num_points)
        return np.round(np.hstack((new_x.reshape(-1, 1), f_y(new_x))))

    @staticmethod
    def _compute_orthogonal_planes(tck, num_points=100):
        """Return ``(unit_normal, offset)`` planes perpendicular to the spline.

        A point ``p`` lies on a plane when ``unit_normal . p + offset == 0``.
        """
        u = np.linspace(0, 1, num_points)
        points = np.vstack(splev(u, tck)).T
        tangents = np.vstack(splev(u, tck, der=1)).T
        normals = tangents / np.linalg.norm(tangents, axis=1)[:, np.newaxis]
        return [(normal, -np.dot(point, normal)) for point, normal in zip(points, normals)]

    @staticmethod
    def _compute_maximum_diameter(aorta_segmentation, planes, voxel_size_mm=1.5):
        """Find the plane whose 1-voxel-thick slab through the mask is widest.

        Returns:
            tuple: ``(diameters_cm, plane_offset, plane_normal, intersecting_points)``.
            ``diameters_cm`` holds the largest pairwise voxel distance, in cm,
            for every plane that hits at least two mask voxels. The other
            entries describe the plane with the largest such distance.

        Raises:
            ValueError: if no plane intersects at least two mask voxels.
        """
        # Same (axis2, axis1, axis0) point order as _compute_centerline_3d.
        z, y, x = np.where(aorta_segmentation)
        aorta_points = np.vstack((x, y, z)).T

        diameters_px, hit_planes, hit_points = [], [], []
        for normal, d in planes:
            distances = aorta_points @ normal + d
            intersecting = aorta_points[np.abs(distances) < 0.5]
            if len(intersecting) < 2:
                continue
            diameters_px.append(pdist(intersecting).max())
            # Track the plane alongside its diameter so indices stay aligned
            # even when some planes are skipped above.
            hit_planes.append((normal, d))
            hit_points.append(intersecting)

        if not diameters_px:
            raise ValueError("No centerline cross-section intersects the aorta mask.")

        best = int(np.argmax(diameters_px))
        max_diameter_in_pixels = diameters_px[best]
        print(f"Maximum Diameter in Pixels: {max_diameter_in_pixels}")
        # The volume was resampled to voxel_size_mm, so voxels (not DICOM
        # pixels) are the unit here.
        print(f"Maximum Diameter in mm: {round(max_diameter_in_pixels * voxel_size_mm)}")

        normal, d = hit_planes[best]
        diameters_cm = np.array(diameters_px) * (voxel_size_mm / 10)
        return diameters_cm, d, normal, hit_points[best]

    @staticmethod
    def _plot_2d_planar_reconstruction(interpolated_points, max_intersecting_points, output_path):
        """Plot two projections of the centerline (red) and widest cross-section (blue)."""
        fig, axs = plt.subplots(nrows=2, ncols=1, figsize=(15, 10))
        axs[0].scatter(
            interpolated_points[:, 1].astype(int),
            interpolated_points[:, 0].astype(int),
            color="red",
            s=1,
        )
        axs[0].plot(
            max_intersecting_points[:, 1].astype(int),
            max_intersecting_points[:, 0].astype(int),
            color="blue",
        )
        axs[1].scatter(
            interpolated_points[:, 2].astype(int),
            interpolated_points[:, 0].astype(int),
            color="red",
            s=1,
        )
        axs[1].plot(
            max_intersecting_points[:, 2].astype(int),
            max_intersecting_points[:, 0].astype(int),
            color="blue",
        )
        fig.savefig(output_path)
        plt.close(fig)
class AortaMetricsSaver(InferenceClass):
    """Save the pipeline's maximum aorta diameter to ``<output_dir>/metrics/aorta_metrics.csv``."""

    def __init__(self):
        super().__init__()

    def __call__(self, inference_pipeline):
        """Read ``max_diameter`` and the series path from the pipeline and write the CSV.

        Returns:
            dict: always empty.
        """
        self.max_diameter = inference_pipeline.max_diameter
        self.dicom_series_path = inference_pipeline.dicom_series_path
        self.output_dir = inference_pipeline.output_dir
        self.csv_output_dir = os.path.join(self.output_dir, "metrics")
        os.makedirs(self.csv_output_dir, exist_ok=True)
        self.save_results()
        return {}

    def save_results(self):
        """Write a one-row CSV: the series folder name and the max diameter as a string."""
        # normpath drops a trailing separator; otherwise the basename would be
        # empty for paths such as "data/series_01/".
        filename = os.path.basename(os.path.normpath(self.dicom_series_path))
        data = [[filename, str(self.max_diameter)]]
        df = pd.DataFrame(data, columns=["Filename", "Max Diameter"])
        df.to_csv(os.path.join(self.csv_output_dir, "aorta_metrics.csv"), index=False)