Spaces:
Running
Running
# Import needed libraries | |
import streamlit as st | |
import torchvision | |
from medigan import Generators | |
from torchvision.transforms.functional import to_pil_image | |
from torchvision.utils import make_grid | |
# medigan model identifiers offered in the sidebar's "Select GAN Model" box.
model_ids = [
    "00001_DCGAN_MMG_CALC_ROI",
    "00002_DCGAN_MMG_MASS_ROI",
    "00003_CYCLEGAN_MMG_DENSITY_FULL",
    "00004_PIX2PIX_MMG_MASSES_W_MASKS",
    "00019_PGGAN_CHEST_XRAY",
]
def main():
    """Render the MEDIGAN Streamlit page and generate images on demand.

    Configures the page, draws the title and usage hint, and builds the
    sidebar controls: model selector, image-count input (1-7, default 1)
    and a generate button. When the button is pressed, it calls
    ``generate_images`` with the chosen count and model id, wrapped in a
    spinner.
    """
    # Kept as the first Streamlit call on the page.
    st.set_page_config(page_title="MEDIGAN Generator", layout="wide")

    # Main page title and description.
    # NOTE(review): the emoji in these UI strings were restored from mojibake
    # (UTF-8 apparently decoded as ISO-8859-7). The gear, sparkles and arrow
    # are unambiguous; the title glyph (brain) and the pointing hand are best
    # guesses - confirm against the upstream app.
    st.title("🧠 MEDIGAN Medical Image Generator")
    st.markdown("""
**Generate synthetic medical images using GAN models.**
👈 Select model and parameters in the sidebar → Click **Generate Images**
""")

    # Sidebar controls.
    with st.sidebar:
        st.header("⚙️ Settings")
        model_id = st.selectbox("Select GAN Model", model_ids)
        num_images = st.number_input(
            "Number of Images", min_value=1, max_value=7, value=1
        )
        generate_btn = st.button("✨ Generate Images")

    # Main content area: only generate on the run where the button was clicked.
    if generate_btn:
        with st.spinner(f"Generating {num_images} image(s) using {model_id}..."):
            generate_images(num_images, model_id)
def torch_images(num_images, model_id):
    """Generate samples with a medigan model and return one PIL grid per item.

    Args:
        num_images: Number of samples requested from the generator
            (forwarded as ``num_samples``).
        model_id: medigan model identifier, e.g. one of ``model_ids``.

    Returns:
        A list of ``PIL.Image`` objects, one per item yielded by the
        dataloader. Each image is a grid (up to two per row) built from every
        tensor in that item's dict, each converted to RGB first, so all
        tensors returned for a sample are shown side by side. Items whose dict
        is empty are skipped.
    """
    # NOTE(review): a fresh Generators() is built on every call; consider
    # caching it if generation latency matters.
    generators = Generators()
    dataloader = generators.get_as_torch_dataloader(
        model_id=model_id,
        install_dependencies=True,
        num_samples=num_images,
        num_workers=0,
        prefetch_factor=None,
    )

    to_tensor = torchvision.transforms.ToTensor()
    images = []
    for data_dict in dataloader:
        batch_images = []
        for tensor in data_dict.values():
            if tensor.dim() == 4:
                # NOTE(review): assumes a channel-last batch of one,
                # (1, H, W, C); drop the batch axis and move channels first,
                # as to_pil_image expects (C, H, W). Confirm for other models.
                tensor = tensor.squeeze(0).permute(2, 0, 1)
            batch_images.append(to_pil_image(tensor).convert("RGB"))

        # make_grid cannot stack zero tensors; skip items with nothing to show.
        if not batch_images:
            continue

        # nrow is a per-row maximum; make_grid caps it at the image count,
        # so a single image still yields a 1x1 grid.
        grid_tensor = make_grid([to_tensor(img) for img in batch_images], nrow=2)
        images.append(to_pil_image(grid_tensor))
    return images
def generate_images(num_images, model_id):
    """Generate images with ``torch_images`` and display them side by side.

    Args:
        num_images: Number of samples to request from the model.
        model_id: medigan model identifier to generate with.

    Each returned grid goes in its own column, shown 300 px wide and
    captioned with the model id, followed by a horizontal rule. If nothing
    was generated, a warning is shown instead.
    """
    # No explicit clearing step: an unused st.empty() placeholder clears
    # nothing, and this output is rebuilt on every script rerun.
    images = torch_images(num_images, model_id)

    # st.columns requires a positive count, so report instead of crashing.
    if not images:
        st.warning(f"No images were generated by {model_id}.")
        return

    # One column per generated grid for a responsive layout.
    for col, img in zip(st.columns(len(images)), images):
        with col:
            st.image(img, caption=f"Generated by: {model_id}", width=300)
    st.markdown("---")
if __name__ == "__main__":
    # Script entry point: build the page when this file runs as __main__.
    main()