# Generate_Images / app.py
# Author: ahmed792002 · last commit afbb2d7 "Update app.py" (verified) · 2.7 kB
# Import needed libraries
import streamlit as st
import torchvision
from medigan import Generators
from torchvision.transforms.functional import to_pil_image
from torchvision.utils import make_grid
# medigan model identifiers offered in the sidebar "Select GAN Model" box;
# whichever one the user picks is passed on unchanged as ``model_id``.
model_ids = [
    "00001_DCGAN_MMG_CALC_ROI",
    "00002_DCGAN_MMG_MASS_ROI",
    "00003_CYCLEGAN_MMG_DENSITY_FULL",
    "00004_PIX2PIX_MMG_MASSES_W_MASKS",
    "00019_PGGAN_CHEST_XRAY",
]
def main():
    """Render the MEDIGAN generator page.

    Configures the page, writes the title and intro text, draws the sidebar
    controls (model selector, image count, generate button) and, when the
    button is clicked, generates and displays images for the chosen model
    via ``generate_images``.
    """
    # Page-wide configuration (browser tab title, wide layout).
    st.set_page_config(page_title="MEDIGAN Generator", layout="wide")

    # Main page title and description.
    st.title("🧠 MEDIGAN Medical Image Generator")
    # The emoji/arrow characters below were previously stored as mojibake
    # (UTF-8 bytes mis-decoded as a Greek code page); restored here.
    st.markdown("""
**Generate synthetic medical images using GAN models.**
🔍 Select model and parameters in the sidebar → Click **Generate Images**
""")

    # Sidebar controls.
    with st.sidebar:
        st.header("⚙️ Settings")
        model_id = st.selectbox("Select GAN Model", model_ids)
        # Keyword form of the original positional (1, 7, 1): min, max, default.
        num_images = st.number_input(
            "Number of Images", min_value=1, max_value=7, value=1
        )
        generate_btn = st.button("✨ Generate Images")

    # Main content area: only generate when the button was clicked this run.
    if generate_btn:
        with st.spinner(f"Generating {num_images} image(s) using {model_id}..."):
            generate_images(num_images, model_id)
def torch_images(num_images, model_id):
    """Generate samples with a medigan model and tile each batch into a grid.

    Args:
        num_images: number of samples requested from the generator
            (forwarded to medigan as ``num_samples``).
        model_id: medigan model identifier, e.g. an entry of ``model_ids``.

    Returns:
        A list of PIL images, one grid per dataloader batch. Every tensor
        value of a batch's dict contributes RGB tile(s) to that batch's grid.
    """
    generators = Generators()
    dataloader = generators.get_as_torch_dataloader(
        model_id=model_id,
        install_dependencies=True,
        num_samples=num_images,
        num_workers=0,
        prefetch_factor=None,
    )
    # Built once instead of once per image.
    to_tensor = torchvision.transforms.ToTensor()

    images = []
    for data_dict in dataloader:
        batch_images = []
        for tensor in data_dict.values():
            if tensor.dim() == 4:
                # Assumed layout (batch, H, W, C) -- TODO confirm against medigan.
                # Iterating the leading dimension (rather than squeeze(0))
                # keeps permute() valid when it holds more than one entry;
                # for a single entry the result is identical.
                frames = [frame.permute(2, 0, 1) for frame in tensor]
            else:
                # Passed through as-is; to_pil_image treats 3-D input as (C, H, W).
                frames = [tensor]
            batch_images.extend(
                to_pil_image(frame).convert("RGB") for frame in frames
            )

        if not batch_images:
            # make_grid needs at least one image to tile.
            continue

        # Tile the batch: two per row when there is more than one image.
        grid_tensor = make_grid(
            [to_tensor(img) for img in batch_images],
            nrow=2 if len(batch_images) > 1 else 1,
        )
        images.append(to_pil_image(grid_tensor))
    return images
def generate_images(num_images, model_id):
    """Generate images with ``model_id`` and display them side by side.

    Args:
        num_images: number of samples to request from the model.
        model_id: medigan model identifier; also used in each image caption.
    """
    images = torch_images(num_images, model_id)

    if not images:
        # st.columns() rejects a column count of 0, so report instead of crashing.
        st.warning(f"{model_id} did not return any images.")
        return

    # One column per generated grid for a responsive, side-by-side layout.
    cols = st.columns(len(images))
    for col, img in zip(cols, images):
        with col:
            st.image(
                img,
                caption=f"Generated by: {model_id}",
                width=300,
            )
    st.markdown("---")
# Entry point: build the page when this file is executed as the main script.
if __name__ == "__main__":
    main()