# NRRD_3Views / app.py
# Author: vincentgao95
# Commit message: Update app.py
# Commit: 6c5575c (verified)
import gradio as gr
import numpy as np
import nrrd
import matplotlib.pyplot as plt
import io
from PIL import Image
# Select Matplotlib's non-interactive 'Agg' backend: figures in this app are
# only written to in-memory PNG buffers, never shown in a GUI window.
plt.switch_backend("Agg")
def _resolve_nrrd_path(file_obj):
    """Turn the value delivered by a Gradio ``File`` component into a path.

    Depending on the Gradio version and the component's ``type`` setting, the
    value may be a plain path (``str``/``bytes``/``os.PathLike``) or a
    file-like wrapper exposing the on-disk path as ``.name``; both are
    accepted. Any other object is passed through unchanged for ``nrrd.read``
    to handle.

    Raises:
        ValueError: if *file_obj* is None (nothing has been uploaded).
    """
    import os

    if file_obj is None:
        raise ValueError("No NRRD file was provided.")
    # Check path types before looking at `.name`: pathlib.Path.name is only
    # the basename, not the full path.
    if isinstance(file_obj, (str, bytes, os.PathLike)):
        return os.fspath(file_obj)
    return getattr(file_obj, "name", file_obj)


def load_nrrd(file_obj):
    """Read an NRRD file and return its data array (the header is discarded).

    Args:
        file_obj: a filesystem path, or the value produced by a Gradio
            ``File`` component (see ``_resolve_nrrd_path``).

    Returns:
        The data array returned by ``nrrd.read``.

    Raises:
        ValueError: if *file_obj* is None.
    """
    # Resolve first, in its own statement, so a missing upload fails with a
    # clear message before the reader is touched.
    path = _resolve_nrrd_path(file_obj)
    data, _header = nrrd.read(path)
    return data
def get_slice(data, slice_index, view):
    """Return the slice of *data* at *slice_index* for the given *view*.

    The view selects which axis is indexed: "Axial" -> axis 2,
    "Coronal" -> axis 1, "Sagittal" -> axis 0. For a 3-D array that is
    ``data[:, :, i]``, ``data[:, i, :]`` and ``data[i, :, :]`` respectively.

    Args:
        data: numpy array to slice (3-D for all three views).
        slice_index: position along the view's axis; converted with ``int()``
            so integral floats (e.g. from a UI slider) are accepted.
        view: one of "Axial", "Coronal", "Sagittal".

    Returns:
        A numpy view of the selected slice.

    Raises:
        ValueError: if *view* is not one of the supported views.
        IndexError: if *data* has too few dimensions for *view*, or
            *slice_index* is out of range for that axis.
    """
    axis_for_view = {"Axial": 2, "Coronal": 1, "Sagittal": 0}
    if view not in axis_for_view:
        raise ValueError(
            f"Unknown view {view!r}; expected one of {list(axis_for_view)}."
        )
    axis = axis_for_view[view]
    if data.ndim <= axis:
        raise IndexError(
            f"{view} view needs at least {axis + 1} dimensions, "
            f"got an array with {data.ndim}."
        )
    # Full slices everywhere except the chosen axis, which gets the index.
    index = [slice(None)] * data.ndim
    index[axis] = int(slice_index)
    return data[tuple(index)]
def _clamp_slice_index(data, slice_index, view):
    """Clamp *slice_index* into the valid range of *view*'s axis of *data*.

    When the view changes, the slider can still hold an index chosen for the
    previous view, which may lie past the end of the new axis. Unknown views,
    or data with too few dimensions, are passed through unchanged so the
    slicing step reports the problem.
    """
    axis = {"Axial": 2, "Coronal": 1, "Sagittal": 0}.get(view)
    if axis is None or data.ndim <= axis:
        return slice_index
    last = data.shape[axis] - 1
    return min(max(int(slice_index), 0), last)


def _orient_for_display(slice_image, view):
    """Reorient a raw 2-D slice for display.

    Every view is rotated 90 degrees clockwise and mirrored left-to-right
    (net effect: a transpose); Coronal and Sagittal slices are then also
    flipped top-to-bottom.
    """
    # k=-1 rotates clockwise; positive k rotates counter-clockwise.
    oriented = np.fliplr(np.rot90(slice_image, k=-1))
    if view in ("Coronal", "Sagittal"):
        oriented = np.flipud(oriented)
    return oriented


def _render_grayscale(slice_image):
    """Render a 2-D array as a grayscale PNG and return it as a PIL image.

    An explicit ``Figure`` is used instead of ``pyplot`` so concurrent
    requests do not share pyplot's global "current figure/axes" state and
    no figure stays registered with pyplot (nothing to close or leak).
    """
    from matplotlib.figure import Figure

    fig = Figure()
    ax = fig.subplots()
    ax.imshow(slice_image, cmap="gray")
    ax.axis("off")
    buf = io.BytesIO()
    fig.savefig(buf, format="png")
    buf.seek(0)
    return Image.open(buf)


def visualize_slice(file_obj, slice_index, view):
    """Render one slice of an uploaded NRRD volume as a PIL image.

    Args:
        file_obj: the uploaded file as delivered by the Gradio ``File``
            component, or None when nothing is uploaded.
        slice_index: slice position along the axis selected by *view*;
            clamped into the valid range for that axis.
        view: "Axial", "Coronal" or "Sagittal".

    Returns:
        A PIL image of the oriented slice, or None when *file_obj* is None.
    """
    if file_obj is None:
        # Upload was cleared (or never made): show no image instead of failing.
        return None
    data = load_nrrd(file_obj)
    index = _clamp_slice_index(data, slice_index, view)
    slice_image = get_slice(data, index, view)
    return _render_grayscale(_orient_for_display(slice_image, view))
def _num_slices(shape, view):
    """Return how many slices *view* offers for a volume of the given *shape*.

    "Axial", "Coronal" and "Sagittal" count along axes 2, 1 and 0.

    Raises:
        ValueError: if *view* is not one of the supported views.
    """
    axis_for_view = {"Axial": 2, "Coronal": 1, "Sagittal": 0}
    if view not in axis_for_view:
        raise ValueError(
            f"Unknown view {view!r}; expected one of {list(axis_for_view)}."
        )
    return shape[axis_for_view[view]]


def update_slider(file_obj, view):
    """Re-range the slice slider for the uploaded volume and selected view.

    Args:
        file_obj: the uploaded file as delivered by the Gradio ``File``
            component, or None.
        view: "Axial", "Coronal" or "Sagittal".

    Returns:
        A ``gr.update`` setting the slider's maximum to the last slice index
        and its value to the middle slice. If *file_obj* is None (upload
        cleared, or view changed before any upload) an empty ``gr.update()``
        is returned so the slider is left as is.
    """
    if file_obj is None:
        return gr.update()
    num_slices = _num_slices(load_nrrd(file_obj).shape, view)
    # Keep the maximum >= 0 (the slider's minimum) even for an empty axis.
    return gr.update(maximum=max(num_slices - 1, 0), value=num_slices // 2)
with gr.Blocks() as app:
    gr.Markdown("## NRRD Slice Visualizer")
    gr.Markdown("Upload an NRRD file and use the slider to select and visualize slices.")
    file_input = gr.File(label="Upload NRRD File")
    view_selector = gr.Dropdown(choices=["Axial", "Coronal", "Sagittal"], value="Axial", label="View Selector")
    slider = gr.Slider(minimum=0, maximum=1, step=1, value=0, label="Slice Selector")
    image_output = gr.Image(type="pil", label="Selected Slice")

    slider_inputs = [file_input, view_selector]
    render_inputs = [file_input, slider, view_selector]

    # A new file or a new view changes the valid slice range, so re-range the
    # slider first and only then render. Chaining with .then() makes the
    # render read the slider value set by update_slider. Registering the two
    # listeners independently let the render read the slider value from
    # before update_slider ran, which could be out of range.
    file_input.change(fn=update_slider, inputs=slider_inputs, outputs=slider).then(
        fn=visualize_slice, inputs=render_inputs, outputs=image_output
    )
    view_selector.change(fn=update_slider, inputs=slider_inputs, outputs=slider).then(
        fn=visualize_slice, inputs=render_inputs, outputs=image_output
    )
    # Moving the slider only changes which slice is shown.
    slider.change(fn=visualize_slice, inputs=render_inputs, outputs=image_output)

app.launch()