Spaces:
Runtime error
Runtime error
import gradio as gr | |
from fastmri.data.subsample import create_mask_for_mask_type | |
from fastmri.data.transforms import apply_mask, to_tensor, center_crop | |
from pytorch_msssim import ssim | |
# st.title('FastMRI Kspace Reconstruction Masks') | |
# st.write('This app allows you to visualize the masks and their effects on the kspace data.') | |
def main_func(
    mask_name: str,
    mask_center_fractions: float,
    accelerations: int,
    seed: int,
    input_image: str,
    *,
    load_kspace=None,
):
    """Undersample one k-space example with a fastMRI mask and return display images.

    Args:
        mask_name: Mask type forwarded to ``create_mask_for_mask_type``
            (the UI offers ``"random"`` and ``"equispaced"``).
        mask_center_fractions: Center fraction forwarded to the mask factory.
            The UI slider delivers a float.
        accelerations: Undersampling factor. Gradio ``Number`` delivers a
            float, so it is cast to ``int`` before use.
        seed: Seed forwarded to ``apply_mask``; cast to ``int`` for the same
            reason.
        input_image: Key of ``file_dict`` selecting which example file to use.
        load_kspace: Callable that takes the selected file path and returns a
            complex-valued k-space array whose last axis is the one the mask
            is applied along. Reading the ``.h5`` example files needs an HDF5
            reader this module does not import, so the caller must supply it
            (e.g. ``functools.partial(main_func, load_kspace=my_reader)``).

    Returns:
        ``(masked_kspace_image, mask_image)``, both 2-D float32 arrays in
        [0, 1] suitable for ``gr.Image``:
        - the log-magnitude of the masked k-space, normalized by its peak.
          Leading axes beyond the last two are combined by
          root-sum-of-squares.
        - the sampling mask tiled to the same 2-D shape.

    Raises:
        KeyError: if ``input_image`` is not one of the known examples.
        RuntimeError: if no ``load_kspace`` callable was supplied.
        ValueError: if the loaded k-space is not complex or has fewer than
            two dimensions.
    """
    import numpy as np

    file_dict = {
        "knee 1": "knee_singlecoil_train/file1000002.h5",
        "knee 2": "knee_singlecoil_train/file1000003.h5",
        "brain 1": "brain_axial_train/file1000002.h5",
        "prostate 1": "prostate_t1_tse_train/file1000002.h5",
        "prostate 2": "prostate_t2_tse_train/file1000002.h5",
    }
    input_file = file_dict[input_image]

    if load_kspace is None:
        raise RuntimeError(
            f"No k-space loader configured for {input_file!r}; "
            "pass load_kspace= (e.g. an h5py-based reader)."
        )
    kspace = np.asarray(load_kspace(input_file))
    if kspace.ndim < 2 or not np.iscomplexobj(kspace):
        raise ValueError(
            f"expected complex k-space with >= 2 dims, got dtype={kspace.dtype} "
            f"shape={kspace.shape}"
        )

    mask_func = create_mask_for_mask_type(
        mask_name,
        center_fractions=[float(mask_center_fractions)],
        accelerations=[int(accelerations)],
    )
    # to_tensor stacks real/imag on a new trailing axis. apply_mask returns
    # (masked, mask) in older fastMRI releases and
    # (masked, mask, num_low_frequencies) in newer ones, so keep the first two.
    masked_kspace, mask = apply_mask(to_tensor(kspace), mask_func, seed=int(seed))[:2]

    # Magnitude from the trailing (real, imag) pair, then collapse any extra
    # leading axes (e.g. coils) into a single 2-D plane by root-sum-of-squares.
    magnitude = np.sqrt((np.asarray(masked_kspace) ** 2).sum(axis=-1))
    magnitude = magnitude.reshape(-1, *magnitude.shape[-2:])
    magnitude = np.sqrt((magnitude ** 2).sum(axis=0))

    # Log scaling makes the low-energy outer k-space visible next to the center.
    log_magnitude = np.log1p(magnitude)
    peak = log_magnitude.max()
    kspace_image = (log_magnitude / peak if peak > 0 else log_magnitude).astype(np.float32)

    # The mask varies only along the last k-space axis; tile it into an image.
    mask_row = np.asarray(mask, dtype=np.float32).reshape(-1)
    mask_image = np.broadcast_to(mask_row, kspace_image.shape).copy()
    return kspace_image, mask_image
# Gradio UI wiring. Components come from the top-level ``gr`` namespace; the
# legacy ``gr.inputs`` / ``gr.outputs`` modules are removed in Gradio 4.
# Outputs must match main_func's two return values, in order:
# (masked k-space, mask). The previously listed "Reconstructed Image",
# "Original Image" and Dataframe outputs had no corresponding return value.
# TODO: re-add them once main_func computes a reconstruction and SSIM.
demo = gr.Interface(
    fn=main_func,
    inputs=[
        gr.Radio(["random", "equispaced"], value="random", label="Mask Type"),
        gr.Slider(minimum=0.04, maximum=0.4, value=0.08, label="Center Fraction"),
        # precision=0 makes Gradio deliver ints, as the mask factory and seed expect.
        gr.Number(value=4, precision=0, label="Acceleration"),
        gr.Number(value=0, precision=0, label="Seed"),
        gr.Radio(
            ["knee 1", "knee 2", "brain 1", "prostate 1", "prostate 2"],
            value="knee 1",
            label="Input Image",
        ),
    ],
    outputs=[
        # "numpy" is a valid Image type; "mask"/"kspace" are not and raise.
        gr.Image(type="numpy", label="Masked Kspace"),
        gr.Image(type="numpy", label="Mask"),
    ],
    title="FastMRI Kspace Reconstruction Masks",
    description="This app allows you to visualize the masks and their effects on the kspace data.",
)

if __name__ == "__main__":
    demo.launch()