Spaces:
Runtime error
Runtime error
File size: 4,308 Bytes
50e082d c989bc3 50e082d c989bc3 50e082d c989bc3 50e082d c989bc3 50e082d c989bc3 50e082d c989bc3 50e082d c989bc3 50e082d 45b3e15 c989bc3 50e082d c989bc3 50e082d |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 |
import gradio as gr
from fastmri.data.subsample import create_mask_for_mask_type
from fastmri.data.transforms import apply_mask, to_tensor, center_crop
import skimage
import fastmri
import numpy as np
import pandas as pd
import torch
import matplotlib.pyplot as plt
import uuid
# st.title('FastMRI Kspace Reconstruction Masks')
# st.write('This app allows you to visualize the masks and their effects on the kspace data.')
# Volume visualised by the app and the size of the displayed reconstructions.
_KSPACE_PATH = "data/prostate1_kspace.npy"
_CROP_SIZE = (320, 320)


def _mask_image(mask: torch.Tensor, height: int) -> np.ndarray:
    """Tile the 1-D column-sampling mask into a (height, width) image for display."""
    # squeeze() drops the singleton axes the mask is broadcast over, leaving one
    # entry per k-space column (assumes only the column axis is > 1).
    return mask.squeeze().unsqueeze(0).repeat(height, 1).cpu().numpy()


def _reconstruct_slice(kspace_slice: torch.Tensor) -> np.ndarray:
    """Return the center-cropped, zero-filled RSS magnitude image of one slice.

    ``kspace_slice`` is one entry of the volume's first axis. The coil axis is
    therefore its axis 0. This matches the original pipeline, which combined
    coils over dim 1 of the full volume before selecting a slice.
    """
    image = fastmri.rss(fastmri.complex_abs(fastmri.ifft2c(kspace_slice)), dim=0)
    # Never request a crop larger than the image; center_crop rejects that.
    crop = (min(_CROP_SIZE[0], image.shape[-2]), min(_CROP_SIZE[1], image.shape[-1]))
    return center_crop(image, crop).cpu().numpy()


def _save_plot(mask_image: np.ndarray, reconstruction: np.ndarray, target: np.ndarray) -> str:
    """Save mask / reconstruction / original side by side and return the PNG path."""
    fig, axes = plt.subplots(1, 3, figsize=(15, 5))
    try:
        panels = (
            (mask_image, "Mask"),
            (reconstruction, "Reconstructed Image"),
            (target, "Original Image"),
        )
        for ax, (image, title) in zip(axes, panels):
            ax.imshow(image, cmap="gray")
            ax.set_title(title)
            ax.axis("off")
        fig.tight_layout()
        plot_filename = f"data/{uuid.uuid4()}.png"
        fig.savefig(plot_filename)
    finally:
        # pyplot keeps every open figure alive; close it so repeated requests
        # in a long-running server do not leak memory.
        plt.close(fig)
    return plot_filename


def main_func(
    mask_name: str,
    mask_center_fractions: float,
    accelerations: int,
    seed: int,
    slice_index: int,
):
    """Undersample a k-space volume, reconstruct one slice, and score it.

    Args:
        mask_name: fastMRI mask type passed to ``create_mask_for_mask_type``
            (the UI offers "random", "equispaced" and "magic").
        mask_center_fractions: Fraction of low-frequency columns to keep;
            must lie in [0, 1).
        accelerations: Undersampling factor, at least 1. Floats are rounded
            to the nearest int.
        seed: Seed for the mask RNG. Floats are rounded to the nearest int.
        slice_index: Index of the slice to reconstruct and display. Floats
            are rounded to the nearest int.

    Returns:
        A tuple ``(df, plot_filename)``. ``df`` is a one-row DataFrame with
        columns "SSIM", "PSNR" and "Num Low Frequencies". ``plot_filename``
        is the path of the saved mask / reconstruction / original figure.

    Raises:
        gr.Error: If any input is out of range, so Gradio shows the message
            to the user instead of a stack trace.
    """
    # Explicit submodule import: a bare ``import skimage`` does not load
    # ``skimage.metrics`` on older scikit-image releases.
    from skimage.metrics import peak_signal_noise_ratio, structural_similarity

    # gr.Number yields floats by default, but the mask RNG seed and tensor
    # indexing both require integers.
    accelerations = int(round(accelerations))
    seed = int(round(seed))
    slice_index = int(round(slice_index))
    if accelerations < 1:
        raise gr.Error(f"Acceleration must be at least 1, got {accelerations}.")
    # A center fraction of 1 leaves no columns to subsample, and the fastMRI
    # mask functions then divide by zero.
    if not 0.0 <= mask_center_fractions < 1.0:
        raise gr.Error(f"Center fraction must be in [0, 1), got {mask_center_fractions}.")

    kspace = to_tensor(np.load(_KSPACE_PATH))
    num_slices = kspace.shape[0]
    if not 0 <= slice_index < num_slices:
        raise gr.Error(f"Slice index must be in [0, {num_slices - 1}], got {slice_index}.")

    mask_func = create_mask_for_mask_type(
        mask_name, center_fractions=[mask_center_fractions], accelerations=[accelerations]
    )
    subsampled_kspace, mask, num_low_frequencies = apply_mask(kspace, mask_func, seed=seed)
    mask_image = _mask_image(mask, subsampled_kspace.shape[-3])

    # Only the requested slice is displayed, so reconstruct just that slice
    # rather than inverse-transforming the entire volume.
    reconstruction = _reconstruct_slice(subsampled_kspace[slice_index])
    target = _reconstruct_slice(kspace[slice_index])

    # Ground truth goes first: skimage's PSNR signature is (image_true, image_test).
    data_range = target.max() - target.min()
    ssim = structural_similarity(target, reconstruction, data_range=data_range)
    psnr = peak_signal_noise_ratio(target, reconstruction, data_range=data_range)
    df = pd.DataFrame(
        {"SSIM": [ssim], "PSNR": [psnr], "Num Low Frequencies": [num_low_frequencies]}
    )
    return df, _save_plot(mask_image, reconstruction, target)
# Gradio UI wiring the mask parameters to main_func.
demo = gr.Interface(
    fn=main_func,
    inputs=[
        gr.Radio(["random", "equispaced", "magic"], label="Mask Type", value="equispaced"),
        gr.Slider(minimum=0.0, maximum=1.0, value=0.4, label="Center Fraction"),
        # precision=0 makes Gradio round the value and pass a Python int;
        # the mask RNG seed and the slice indexing both require integers.
        gr.Number(value=4, label="Acceleration", precision=0),
        gr.Number(value=42, label="Seed", precision=0),
        gr.Number(value=15, label="Slice Index", precision=0),
    ],
    outputs=[
        gr.Dataframe(headers=["SSIM", "PSNR", "Num Low Frequencies"]),
        gr.Image(type="filepath", label="Plot"),
    ],
    title="FastMRI Kspace Reconstruction Masks",
    description="This app allows you to visualize the masks and their effects on the kspace data.",
)

# Launch only when run as a script, so importing this module has no side effects.
if __name__ == "__main__":
    demo.launch()