Generate_Images / app.py
ahmed792002's picture
Upload app.py
c6d85c0 verified
raw
history blame
3.54 kB
# Import needed libraries
import streamlit as st
import torchvision
from medigan import Generators
from torchvision.transforms.functional import to_pil_image
from torchvision.utils import make_grid
# medigan model identifiers offered in the sidebar's model picker.
model_ids = [
    "00001_DCGAN_MMG_CALC_ROI",
    "00002_DCGAN_MMG_MASS_ROI",
    "00003_CYCLEGAN_MMG_DENSITY_FULL",
    "00004_PIX2PIX_MMG_MASSES_W_MASKS",
    "00019_PGGAN_CHEST_XRAY",
]
def main():
    """Build the page: title, sidebar controls, and on-demand generation.

    The sidebar offers a model picker populated from ``model_ids``, an
    image-count input bounded to 1..7 (default 1), and a "Generate Images"
    button. Pressing the button calls ``generate_images`` with the current
    selections.
    """
    st.title("MEDIGAN Medical Image Data Generator")

    sidebar = st.sidebar
    chosen_model = sidebar.selectbox("Select Model ID", model_ids)
    image_count = sidebar.number_input(
        "Number of Images", min_value=1, max_value=7, value=1, step=1
    )
    generate_clicked = sidebar.button("Generate Images")

    if generate_clicked:
        generate_images(image_count, chosen_model)
def _to_rgb_tensor(tensor):
    """Convert one generated image tensor into a 3-channel float tensor.

    A 4-D input has its leading dimension squeezed and its remaining axes
    reordered with ``permute(2, 0, 1)``. This presumably turns a
    ``(1, H, W, C)`` batch of one into ``(C, H, W)``; TODO confirm the layout
    medigan emits for every model. The image is then round-tripped through
    PIL as RGB, so that grayscale and colour outputs both end up with three
    channels and can be stacked together by ``make_grid``.
    """
    if tensor.dim() == 4:
        # NOTE(review): squeeze(0) is a no-op when the batch holds more than
        # one image, and the 3-axis permute would then fail. This assumes the
        # dataloader yields batches of size 1.
        tensor = tensor.squeeze(0).permute(2, 0, 1)
    rgb_image = to_pil_image(tensor).convert("RGB")
    return torchvision.transforms.functional.to_tensor(rgb_image)


def torch_images(num_images, model_id):
    """Generate images with a medigan model and return them as PIL grids.

    Args:
        num_images: Number of samples to request from the generator.
        model_id: medigan model identifier, e.g. one of ``model_ids``.

    Returns:
        A list of PIL images, one per batch yielded by the dataloader. Each
        one is a grid showing the batch's ``"sample"`` image first and, when
        the batch also has a ``"mask"`` entry, the mask next to it. Batches
        with neither key are skipped.
    """
    generators = Generators()
    dataloader = generators.get_as_torch_dataloader(
        model_id=model_id,
        install_dependencies=True,
        num_samples=num_images,
        # NOTE(review): presumably the loader runs without worker processes,
        # and torch's DataLoader rejects a non-None prefetch_factor then.
        prefetch_factor=None,
    )
    images = []
    # Assumes each batch is a dict of tensors (default collation of a dict
    # dataset); TODO confirm against medigan's dataset.
    for data_dict in dataloader:
        # Fixed key order, so the sample is always the first grid cell and
        # the mask the second, regardless of the dict's iteration order.
        # Exact key matching avoids reading a key that is not present.
        image_list = [
            _to_rgb_tensor(data_dict[key])
            for key in ("sample", "mask")
            if key in data_dict
        ]
        if not image_list:
            # Nothing displayable in this batch; make_grid cannot stack an
            # empty list.
            continue
        # Two cells per row: sample | mask.
        grid = make_grid(image_list, nrow=2)
        # make_grid returns a (C, H, W) tensor, which ToPILImage accepts.
        images.append(torchvision.transforms.ToPILImage()(grid))
    return images
def generate_images(num_images, model_id):
    """Generate images with ``model_id`` and render them on the page.

    Writes a subheader, then one ``st.image`` per grid returned by
    ``torch_images``. Each is captioned with a 1-based index and the model
    ID. When the model yields nothing, an informational message is shown
    instead of leaving the section blank.

    Args:
        num_images: Number of samples to request from the generator.
        model_id: medigan model identifier selected in the sidebar.
    """
    st.subheader("Generated Images:")
    images = torch_images(num_images, model_id)
    if not images:
        st.info("The selected model did not produce any images.")
        return
    for index, image in enumerate(images, start=1):
        st.image(
            image,
            caption=f"Generated Image {index} (Model ID: {model_id})",
            use_container_width=True,
        )
if __name__ == "__main__":
    # Script entry point: build the Streamlit page.
    main()