File size: 3,138 Bytes
a2df67a
 
 
 
 
 
 
 
 
 
edfe705
 
c2f693f
a2df67a
f83d212
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
d7aaffb
 
 
 
 
 
 
 
 
 
 
 
 
 
f83d212
c2f693f
d7aaffb
f83d212
d7aaffb
 
 
 
1b40516
f83d212
1b40516
 
a2df67a
 
 
f83d212
a2df67a
 
 
 
 
 
1b40516
c2f693f
f83d212
 
a2df67a
 
f83d212
 
a2df67a
 
f83d212
a2df67a
1b40516
a2df67a
1b40516
 
f83d212
 
 
a2df67a
edfe705
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
import gradio as gr
import numpy as np
import nrrd
import matplotlib.pyplot as plt
import io
from PIL import Image

# Render off-screen with the non-interactive Agg backend: figures are only
# ever saved to PNG bytes for Gradio, never shown in a window.
plt.switch_backend("Agg")

def load_nrrd(file_obj):
    """Read an NRRD file and return its voxel array (the header is discarded).

    Args:
        file_obj: Value produced by the ``gr.File`` component. This may be a
            filesystem path (str) or, in some Gradio versions, a temp-file
            wrapper whose ``.name`` holds the path -- both are accepted.

    Returns:
        The data array returned by ``nrrd.read``.

    Raises:
        ValueError: If *file_obj* is None (no file uploaded / upload cleared).
    """
    if file_obj is None:
        raise ValueError("No NRRD file provided")
    # nrrd.read takes a filename, not an open file object, so unwrap
    # file-like values to their on-disk path. Plain strings (and str
    # subclasses) have no ``read`` attribute and are passed through as-is.
    path = file_obj.name if hasattr(file_obj, 'read') else file_obj
    data, _ = nrrd.read(path)
    return data

def get_num_slices(data, view):
    """Return the number of slices of *data* along the axis shown by *view*.

    'Sagittal', 'Coronal' and 'Axial' map to axes 0, 1 and 2 of
    ``data.shape`` respectively; any other view name yields None.
    """
    axis_for_view = {'Sagittal': 0, 'Coronal': 1, 'Axial': 2}
    axis = axis_for_view.get(view)
    if axis is None:
        return None
    return data.shape[axis]

def extract_slice(data, view, slice_index):
    """Return the cross-section of *data* at *slice_index* for *view*.

    'Axial' indexes the third axis, 'Coronal' the second and 'Sagittal' the
    first, using ordinary NumPy indexing; an unrecognised view yields None.
    """
    axis_for_view = {'Sagittal': 0, 'Coronal': 1, 'Axial': 2}
    if view not in axis_for_view:
        return None
    # Build the equivalent of data[:, :, i] / data[:, i, :] / data[i, :, :].
    selector = [slice(None)] * 3
    selector[axis_for_view[view]] = slice_index
    return data[tuple(selector)]

def resize_slice(slice_image, view, data_shape):
    """Orient a 2-D slice for display so the volume's third axis runs vertically.

    Axial slices are returned unchanged. Coronal slices (shape
    ``(data_shape[0], data_shape[2])``) and sagittal slices (shape
    ``(data_shape[1], data_shape[2])``) are transposed, so the third volume
    axis becomes the image rows and the output shape is
    ``(data_shape[2], data_shape[0 or 1])``.

    Args:
        slice_image: 2-D array produced by ``extract_slice``.
        view: One of 'Axial', 'Coronal' or 'Sagittal'.
        data_shape: Shape of the full volume. Kept for interface
            compatibility; a transpose does not need it.

    Returns:
        The slice, unchanged for 'Axial' and transposed otherwise.

    Raises:
        ValueError: If *view* is not a recognised view name.
    """
    if view == 'Axial':
        return slice_image
    if view in ('Coronal', 'Sagittal'):
        # A transpose yields the intended (rows=third axis) layout without
        # resampling. Resizing through PIL to the slice's own dims passed as
        # (width, height) stretched pixels instead of reorienting them, only
        # worked for dtypes Pillow can wrap, and Image.ANTIALIAS no longer
        # exists in Pillow >= 10.
        # NOTE(review): imshow draws row 0 at the top, so the third axis
        # increases downward; add np.flipud (or imshow(origin='lower')) if it
        # should point up. Voxel spacing from the NRRD header is not applied
        # either -- confirm whether anisotropic volumes need rescaling.
        return np.asarray(slice_image).T
    raise ValueError(f"Unknown view: {view!r}")

def visualize_slice(file_obj, view, slice_index):
    """Render one slice of the uploaded NRRD volume as a grayscale PIL image.

    Args:
        file_obj: Value of the ``gr.File`` component, or None when no file is
            loaded (e.g. the upload was cleared).
        view: 'Axial', 'Coronal' or 'Sagittal'.
        slice_index: Slider value; coerced to int and clamped into the valid
            range of slice indices for *view*.

    Returns:
        A PIL image of the rendered matplotlib figure, or None (which clears
        the output image) when *file_obj* is None.
    """
    if file_obj is None:
        return None

    data = load_nrrd(file_obj)
    data_shape = data.shape  # full volume extent, passed on to resize_slice

    # The slider value may arrive as a float, and right after a file or view
    # change it can still hold an index from the previous volume/axis; clamp
    # it so extraction never indexes past the end of the selected axis.
    last_index = get_num_slices(data, view) - 1
    index = min(max(int(slice_index), 0), last_index)

    slice_image = extract_slice(data, view, index)
    slice_image = resize_slice(slice_image, view, data_shape)

    fig, ax = plt.subplots()
    try:
        ax.imshow(slice_image, cmap='gray')
        # Act on this Axes directly rather than pyplot's "current" axes,
        # which is global state shared by every concurrent request.
        ax.axis('off')
        buf = io.BytesIO()
        fig.savefig(buf, format='png')
    finally:
        # Always release the figure; pyplot keeps open figures alive otherwise.
        plt.close(fig)
    buf.seek(0)
    return Image.open(buf)

def update_slider(file_obj, view):
    """Return a Gradio update that fits the slice slider to the loaded volume.

    Args:
        file_obj: Value of the ``gr.File`` component, or None when no file is
            loaded.
        view: 'Axial', 'Coronal' or 'Sagittal'.

    Returns:
        A ``gr.update`` setting the slider's maximum to the last valid slice
        index for *view* and resetting its value to 0. When no file is loaded,
        the slider's initial range (maximum 1) is restored instead.
    """
    if file_obj is None:
        # Upload cleared: restore the initial slider range instead of failing
        # while trying to read a missing file.
        return gr.update(maximum=1, value=0)
    data = load_nrrd(file_obj)
    num_slices = get_num_slices(data, view)
    # Guard against a zero-length axis producing a negative maximum.
    return gr.update(maximum=max(num_slices - 1, 0), value=0)

with gr.Blocks() as app:
    gr.Markdown("## NRRD Slice Visualizer")
    gr.Markdown("Upload an NRRD file, select a view, and use the slider to select and visualize slices.")

    file_input = gr.File(label="Upload NRRD File")
    view_selector = gr.Radio(choices=["Axial", "Coronal", "Sagittal"], label="View Selector", value="Axial")
    slider = gr.Slider(minimum=0, maximum=1, step=1, value=0, label="Slice Selector")
    image_output = gr.Image(type="pil", label="Selected Slice")

    slider_inputs = [file_input, view_selector]
    render_inputs = [file_input, view_selector, slider]

    # Refit the slider first and only then render. Registering both handlers
    # on the same event ran them in parallel, so the render could read the
    # previous slider value, which may be out of range for the new file/view.
    file_input.change(fn=update_slider, inputs=slider_inputs, outputs=slider).then(
        fn=visualize_slice, inputs=render_inputs, outputs=image_output
    )
    view_selector.change(fn=update_slider, inputs=slider_inputs, outputs=slider).then(
        fn=visualize_slice, inputs=render_inputs, outputs=image_output
    )
    slider.change(fn=visualize_slice, inputs=render_inputs, outputs=image_output)

# Launch only when run as a script so importing this module has no side effects.
if __name__ == "__main__":
    app.launch()