"""Streamlit front end for generating synthetic medical images with medigan."""

# Import needed libraries
import streamlit as st
import torch
import torchvision
from medigan import Generators
from torchvision.transforms.functional import to_pil_image
from torchvision.utils import make_grid

# Define the GAN models available in the app
model_ids = [
    "00001_DCGAN_MMG_CALC_ROI",
    "00002_DCGAN_MMG_MASS_ROI",
    "00003_CYCLEGAN_MMG_DENSITY_FULL",
    "00004_PIX2PIX_MMG_MASSES_W_MASKS",
    "00019_PGGAN_CHEST_XRAY",
]


def main():
    """Render the page, read the sidebar settings and trigger generation."""
    # Setup page configuration
    st.set_page_config(page_title="MEDIGAN Generator", layout="wide")

    # Main page title and description
    st.title("🧠 MEDIGAN Medical Image Generator")
    st.markdown("""
    **Generate synthetic medical images using GAN models.**
    🔍 Select model and parameters in the sidebar → Click **Generate Images**
    """)

    # Sidebar controls
    with st.sidebar:
        st.header("⚙️ Settings")
        model_id = st.selectbox("Select GAN Model", model_ids)
        num_images = st.number_input("Number of Images", 1, 7, 1)
        generate_btn = st.button("✨ Generate Images")

    # Main content area
    if generate_btn:
        with st.spinner(f"Generating {num_images} image(s) using {model_id}..."):
            generate_images(num_images, model_id)


def torch_images(num_images, model_id):
    """Generate images with a medigan model and return them as PIL images.

    Each item yielded by the medigan dataloader is a dict; every image-like
    tensor in that dict is converted to an RGB PIL image and the images of
    one item are tiled into a single grid image.

    Args:
        num_images: Number of samples requested from the model.
        model_id: Identifier of the medigan model to use.

    Returns:
        A list of PIL images, one grid per dataloader item that contained at
        least one image-like tensor. May be empty.
    """
    generators = Generators()
    dataloader = generators.get_as_torch_dataloader(
        model_id=model_id,
        install_dependencies=True,
        num_samples=num_images,
        num_workers=0,
        prefetch_factor=None,
    )

    # One transform instance is enough; it holds no per-image state.
    to_tensor = torchvision.transforms.ToTensor()

    images = []
    for data_dict in dataloader:
        batch_images = []
        for value in data_dict.values():
            # Skip entries that cannot be rendered as an image (for example
            # non-tensor values or 1-D tensors). NOTE(review): presumably
            # such entries are labels; confirm against the medigan dataset.
            if not torch.is_tensor(value) or value.dim() not in (2, 3, 4):
                continue
            if value.dim() == 4:
                # Assumes a (1, H, W, C) batch of size one; drop the batch
                # axis and move channels first for to_pil_image.
                # TODO confirm layout for every supported model.
                value = value.squeeze(0).permute(2, 0, 1)
            batch_images.append(to_pil_image(value).convert("RGB"))

        # make_grid cannot handle an empty list, so skip items without images.
        if not batch_images:
            continue

        # Create image grid
        grid_tensor = make_grid(
            [to_tensor(img) for img in batch_images],
            nrow=2 if len(batch_images) > 1 else 1,
        )
        images.append(to_pil_image(grid_tensor))

    return images


def generate_images(num_images, model_id):
    """Generate images and display them side by side on the page.

    Streamlit reruns the script on every interaction, so output from a
    previous run is already gone and needs no explicit clearing here.

    Args:
        num_images: Number of samples requested from the model.
        model_id: Identifier of the medigan model to use.
    """
    # Generate and display new images
    images = torch_images(num_images, model_id)

    # st.columns requires a positive column count, so bail out early when
    # the model produced nothing that could be displayed.
    if not images:
        st.warning("No images were generated.")
        return

    # Create columns for responsive layout
    cols = st.columns(len(images))
    for col, img in zip(cols, images):
        with col:
            st.image(
                img,
                caption=f"Generated by: {model_id}",
                width=300,
            )
            st.markdown("---")


if __name__ == "__main__":
    main()