"""Streamlit app that generates synthetic medical images with medigan models."""

# Import needed libraries
import streamlit as st
import torchvision
from medigan import Generators
from torchvision.transforms.functional import to_pil_image
from torchvision.utils import make_grid

# Define the GAN models available in the app
model_ids = [
    "00001_DCGAN_MMG_CALC_ROI",
    "00002_DCGAN_MMG_MASS_ROI",
    "00003_CYCLEGAN_MMG_DENSITY_FULL",
    "00004_PIX2PIX_MMG_MASSES_W_MASKS",
    "00019_PGGAN_CHEST_XRAY",  # New model
]

# Shared PIL -> float tensor converter, built once instead of once per batch.
_TO_TENSOR = torchvision.transforms.ToTensor()


def main():
    """Render the sidebar controls and generate images when requested."""
    st.title("MEDIGAN Medical Image Data Generator")
    # Add dropdown widget for model selection to the sidebar
    model_id = st.sidebar.selectbox("Select Model ID", model_ids)
    # Add number image selector to the sidebar
    num_images = st.sidebar.number_input(
        "Number of Images", min_value=1, max_value=7, value=1, step=1
    )
    # Add generate button to the sidebar
    if st.sidebar.button("Generate Images"):
        generate_images(num_images, model_id)


def _to_rgb_tensor(tensor):
    """Normalise one generated image tensor to a 3-channel float tensor.

    A 4-D input is squeezed on its first axis and permuted with
    ``(2, 0, 1)``. This assumes a single-item, channels-last batch such as
    ``(1, H, W, C)`` (TODO confirm against medigan's dataloader output).
    The round trip through PIL forces an RGB image, so samples and masks
    can be stacked into one grid. ``ToTensor`` then turns it back into a
    float tensor.
    """
    if tensor.dim() == 4:
        tensor = tensor.squeeze(0).permute(2, 0, 1)
    return _TO_TENSOR(to_pil_image(tensor).convert("RGB"))


# Task 5.4.9: Copy the torch_images function from this notebook to app.py.
def torch_images(num_images, model_id):
    """Generate images with a medigan model and tile each batch into a grid.

    Args:
        num_images: Number of samples requested from the model (forwarded to
            medigan as ``num_samples``).
        model_id: medigan model identifier, e.g. an entry of ``model_ids``.

    Returns:
        A list of PIL images, one per dataloader batch that contained a
        ``"sample"`` and/or a ``"mask"``. Each image is a grid with the sample
        followed by its mask (when present), two tiles per row.
    """
    generators = Generators()
    dataloader = generators.get_as_torch_dataloader(
        model_id=model_id,
        install_dependencies=True,
        num_samples=num_images,
        prefetch_factor=None,
    )
    images = []
    for data_dict in dataloader:
        # Look the keys up directly rather than substring-matching every key.
        # The previous loop depended on key order: a "mask" seen before
        # "sample" used an unbound transform. It also re-appended the tensor
        # for any key that merely contained "sample" or "mask".
        tiles = [
            _to_rgb_tensor(data_dict[key])
            for key in ("sample", "mask")
            if data_dict.get(key) is not None
        ]
        if not tiles:
            # make_grid cannot tile an empty list, so there is nothing to show.
            continue
        # make_grid always returns a 3-D (C, H, W) tensor, so the grid can go
        # straight to PIL without any reshaping.
        grid = make_grid(tiles, nrow=2)
        images.append(to_pil_image(grid))
    return images


def generate_images(num_images, model_id):
    """Generate ``num_images`` images with ``model_id`` and display them."""
    st.subheader("Generated Images:")
    images = torch_images(num_images, model_id)
    if not images:
        st.warning("The selected model did not return any images.")
        return
    for index, image in enumerate(images, start=1):
        st.image(
            image,
            caption=f"Generated Image {index} (Model ID: {model_id})",
            use_container_width=True,
        )


if __name__ == "__main__":
    main()  # Task 5.4.4: Call the main function