# Uploaded by osbm ("add script", commit 50e082d, 1.99 kB).
from typing import Callable, Optional

import gradio as gr
import numpy as np
from fastmri.data.subsample import create_mask_for_mask_type
from fastmri.data.transforms import apply_mask, to_tensor, center_crop
from pytorch_msssim import ssim
# st.title('FastMRI Kspace Reconstruction Masks')
# st.write('This app allows you to visualize the masks and their effects on the kspace data.')
# Example volumes offered by the demo, keyed by the label shown in the UI.
# Paths are relative to the directory the app is launched from.
INPUT_FILES = {
    "knee 1": "knee_singlecoil_train/file1000002.h5",
    "knee 2": "knee_singlecoil_train/file1000003.h5",
    "brain 1": "brain_axial_train/file1000002.h5",
    "prostate 1": "prostate_t1_tse_train/file1000002.h5",
    "prostate 2": "prostate_t2_tse_train/file1000002.h5",
}

# Reconstructions are center-cropped to at most this many pixels per side
# before display and SSIM comparison.
_CROP_SIZE = 320


def _ifft2c(kspace: np.ndarray) -> np.ndarray:
    """Return the centered, orthonormal 2D inverse FFT over the last two axes."""
    axes = (-2, -1)
    shifted = np.fft.ifftshift(kspace, axes=axes)
    return np.fft.fftshift(np.fft.ifft2(shifted, axes=axes, norm="ortho"), axes=axes)


def _rss(image: np.ndarray) -> np.ndarray:
    """Return the root-sum-of-squares magnitude over all leading axes.

    For a 2D input this is simply ``abs(image)``; extra leading axes
    (e.g. coils) are combined.
    """
    leading_axes = tuple(range(image.ndim - 2))
    return np.sqrt(np.sum(np.abs(image) ** 2, axis=leading_axes))


def _center_crop_2d(image: np.ndarray, size: int = _CROP_SIZE) -> np.ndarray:
    """Center-crop the last two axes to ``size``, never exceeding the input."""
    rows, cols = image.shape[-2:]
    return center_crop(image, (min(size, rows), min(size, cols)))


def _to_uint8(image: np.ndarray) -> np.ndarray:
    """Linearly rescale a real array to 0-255 uint8 for display."""
    image = np.asarray(image, dtype=np.float64)
    low, high = image.min(), image.max()
    if high <= low:  # constant image: avoid division by zero
        return np.zeros(image.shape, dtype=np.uint8)
    return np.round((image - low) / (high - low) * 255).astype(np.uint8)


def main_func(
    mask_name: str,
    mask_center_fractions: float,
    accelerations: int,
    seed: int,
    input_image: str,
    *,
    load_kspace: Optional[Callable[[str], np.ndarray]] = None,
):
    """Undersample one example k-space and reconstruct it for the demo.

    Args:
        mask_name: Mask type passed to ``create_mask_for_mask_type``
            (the UI offers "random" and "equispaced").
        mask_center_fractions: Center fraction for the mask (the UI's
            "Center Fraction" slider).
        accelerations: Undersampling factor; cast to ``int`` because
            Gradio ``Number`` inputs arrive as floats.
        seed: Seed for the mask generator; cast to ``int``.
        input_image: Key of ``INPUT_FILES`` selecting the example file.
        load_kspace: Callable that takes the file path for ``input_image``
            and returns its complex k-space as an array whose last two
            axes are (rows, columns). Any leading axes are combined by
            root-sum-of-squares. Required until a default HDF5 reader
            is wired in.

    Returns:
        A 5-tuple matching the demo's outputs: the mask image, the
        log-magnitude masked k-space, the zero-filled reconstruction and
        the fully-sampled reconstruction (all uint8 arrays), then a list
        of ``[metric, value]`` rows.

    Raises:
        ValueError: If ``input_image`` is not a key of ``INPUT_FILES``.
        NotImplementedError: If ``load_kspace`` is not provided.
    """
    if input_image not in INPUT_FILES:
        raise ValueError(
            f"Unknown input image {input_image!r}; "
            f"expected one of {sorted(INPUT_FILES)}"
        )
    if load_kspace is None:
        # TODO(review): wire in an HDF5 reader for the files in INPUT_FILES
        # (presumably a "kspace" dataset per file -- confirm).
        raise NotImplementedError(
            "No k-space loader configured; pass load_kspace=<callable "
            "returning complex k-space for a file path>."
        )
    kspace = np.asarray(load_kspace(INPUT_FILES[input_image]), dtype=np.complex128)

    mask_func = create_mask_for_mask_type(
        mask_name,
        center_fractions=[float(mask_center_fractions)],
        accelerations=[int(accelerations)],
    )
    # apply_mask operates on real tensors carrying (real, imag) in the last
    # axis, which is what to_tensor produces for complex input. Depending on
    # the fastmri version it returns 2 or 3 values; only the first two are used.
    result = apply_mask(to_tensor(kspace), mask_func, seed=int(seed))
    masked_pair = result[0].numpy()
    masked_kspace = masked_pair[..., 0] + 1j * masked_pair[..., 1]
    # Every mask axis except the column axis has size 1, so flattening
    # leaves one entry per k-space column.
    column_mask = result[1].numpy().reshape(-1)

    original = _center_crop_2d(_rss(_ifft2c(kspace)))
    zero_filled = _center_crop_2d(_rss(_ifft2c(masked_kspace)))

    def as_batch(image: np.ndarray):
        # pytorch_msssim expects (N, C, H, W) float tensors.
        return to_tensor(np.ascontiguousarray(image, dtype=np.float32))[None, None]

    data_range = float(original.max()) or 1.0
    score = ssim(as_batch(zero_filled), as_batch(original), data_range=data_range)

    mask_image = np.broadcast_to(column_mask, kspace.shape[-2:])
    metrics = [
        ["Sampled fraction", float(column_mask.mean())],
        ["SSIM (zero-filled vs. original)", float(score)],
    ]
    return (
        (mask_image * 255).astype(np.uint8),
        _to_uint8(np.log1p(_rss(masked_kspace))),
        _to_uint8(zero_filled),
        _to_uint8(original),
        metrics,
    )
# Gradio UI: the five inputs map positionally onto main_func's parameters,
# and the five outputs onto the 5-tuple main_func is expected to return.
demo = gr.Interface(
    fn=main_func,
    inputs=[
        # Defaults on the radios avoid passing None on the first submit.
        gr.inputs.Radio(["random", "equispaced"], default="random", label="Mask Type"),
        gr.inputs.Slider(minimum=0.04, maximum=0.4, default=0.08, label="Center Fraction"),
        gr.inputs.Number(default=4, label="Acceleration"),
        gr.inputs.Number(default=0, label="Seed"),
        gr.inputs.Radio(
            ["knee 1", "knee 2", "brain 1", "prostate 1", "prostate 2"],
            default="knee 1",
            label="Input Image",
        ),
    ],
    outputs=[
        # "mask"/"kspace" are not Gradio image types; results are numpy arrays.
        gr.outputs.Image(type="numpy", label="Mask"),
        gr.outputs.Image(type="numpy", label="Masked Kspace"),
        gr.outputs.Image(type="numpy", label="Reconstructed Image"),
        gr.outputs.Image(type="numpy", label="Original Image"),
        gr.outputs.Dataframe(headers=["Metric", "Value"], label="Metrics"),
    ],
    title="FastMRI Kspace Reconstruction Masks",
    description="This app allows you to visualize the masks and their effects on the kspace data."
)

# Launch only when run as a script (e.g. `python app.py`), so importing the
# module does not start a server.
if __name__ == "__main__":
    demo.launch()