"""Gradio demo for visualising fastMRI k-space undersampling masks.

The user picks a mask type, centre fraction, acceleration factor, seed and
one of the bundled example volumes; the app shows the masked k-space next to
the sampling mask that produced it.
"""

import gradio as gr
from fastmri.data.subsample import create_mask_for_mask_type
from fastmri.data.transforms import apply_mask, to_tensor, center_crop
from pytorch_msssim import ssim

# Display label -> path of the fastMRI HDF5 volume it refers to.  The "Input
# Image" radio choices are derived from these keys so the two cannot drift.
FILE_DICT = {
    "knee 1": "knee_singlecoil_train/file1000002.h5",
    "knee 2": "knee_singlecoil_train/file1000003.h5",
    "brain 1": "brain_axial_train/file1000002.h5",
    "prostate 1": "prostate_t1_tse_train/file1000002.h5",
    "prostate 2": "prostate_t2_tse_train/file1000002.h5",
}


def _load_kspace(path: str):
    """Return one 2-D complex k-space slice read from the volume at *path*.

    NOTE(review): not implemented.  The original script looked the path up
    but never opened the file, so there is no loading logic to preserve.
    Reading the ``.h5`` volumes needs an HDF5 reader (for example h5py) that
    this file does not import; wire it in here.  ``main_func`` expects a
    complex-valued NumPy array of shape (rows, cols), so slice (and, for
    multi-coil volumes, coil) selection belongs in this helper too.

    Raises:
        NotImplementedError: always, until a reader is added.
    """
    raise NotImplementedError(
        f"Reading k-space from {path!r} is not implemented yet."
    )


def _to_display(image):
    """Scale a non-negative 2-D tensor into a uint8 array for an Image output."""
    peak = image.max()
    if peak > 0:
        # Normalise to [0, 1] first; an all-zero image is left untouched to
        # avoid dividing by zero.
        image = image / peak
    return (image * 255).byte().numpy()


def main_func(
    mask_name: str,
    mask_center_fractions: float,
    accelerations: int,
    seed: int,
    input_image: str,
):
    """Undersample the selected example and return images of the result.

    Args:
        mask_name: Mask type understood by ``create_mask_for_mask_type``
            (the UI offers ``"random"`` and ``"equispaced"``).
        mask_center_fractions: Fraction of centre (low-frequency) columns
            that are always kept.
        accelerations: Undersampling factor.  Arrives as a float from
            ``gr.inputs.Number`` and is truncated to ``int``.
        seed: Seed for the mask's random number generator; ``None`` lets
            fastMRI draw an unseeded mask.  Also truncated to ``int``.
        input_image: A key of ``FILE_DICT`` naming the example volume.

    Returns:
        ``(masked_kspace, mask)``: two uint8 (rows, cols) arrays, the
        log-magnitude of the masked k-space and the sampling mask, in the
        same order as the ``outputs`` of the interface below.

    Raises:
        KeyError: if *input_image* is not a key of ``FILE_DICT``.
        NotImplementedError: propagated from ``_load_kspace``.
    """
    input_file = FILE_DICT[input_image]

    mask_func = create_mask_for_mask_type(
        mask_name,
        center_fractions=[mask_center_fractions],
        # The Number widget yields floats; the acceleration factor is an
        # integer quantity.
        accelerations=[int(accelerations)],
    )

    # to_tensor turns the complex array into a real tensor whose last
    # dimension holds (real, imag), which is the layout apply_mask expects.
    kspace = to_tensor(_load_kspace(input_file))

    # Some fastMRI releases return a third value (the number of low
    # frequencies) from apply_mask; taking the first two works either way.
    masked_kspace, mask = apply_mask(
        kspace,
        mask_func,
        seed=None if seed is None else int(seed),
    )[:2]

    # k-space magnitudes span many orders of magnitude, so show log(1 + |k|).
    magnitude = masked_kspace.pow(2).sum(dim=-1).sqrt()
    kspace_image = _to_display(magnitude.log1p())

    # The mask comes back broadcastable against the data; expand it to the
    # full grid and drop the real/imag axis so it renders as an image.
    mask_image = _to_display(mask.expand_as(masked_kspace)[..., 0])

    return kspace_image, mask_image


demo = gr.Interface(
    fn=main_func,
    inputs=[
        gr.inputs.Radio(["random", "equispaced"], label="Mask Type"),
        gr.inputs.Slider(
            minimum=0.04, maximum=0.4, default=0.08, label="Center Fraction"
        ),
        gr.inputs.Number(default=4, label="Acceleration"),
        gr.inputs.Number(default=0, label="Seed"),
        gr.inputs.Radio(list(FILE_DICT), label="Input Image"),
    ],
    # One component per value returned by main_func, in the same order.
    # NOTE(review): the "Reconstructed Image", "Original Image" and metrics
    # Dataframe outputs were dropped because main_func produces nothing for
    # them; re-add each one together with the code that computes it.
    outputs=[
        gr.outputs.Image(type="numpy", label="Masked Kspace"),
        gr.outputs.Image(type="numpy", label="Mask"),
    ],
    title="FastMRI Kspace Reconstruction Masks",
    description=(
        "This app allows you to visualize the masks and their effects on "
        "the kspace data."
    ),
)

demo.launch()