# Spaces — status: Running
# File size: 2,699 bytes
# Revisions: c6d85c0, afbb2d7
# Import needed libraries
import streamlit as st
import torchvision
from medigan import Generators
from torchvision.transforms.functional import to_pil_image
from torchvision.utils import make_grid
# GAN model identifiers offered in the sidebar "Select GAN Model" picker.
# The chosen ID is passed unchanged to medigan's
# Generators.get_as_torch_dataloader(model_id=...).
model_ids = [
    "00001_DCGAN_MMG_CALC_ROI",
    "00002_DCGAN_MMG_MASS_ROI",
    "00003_CYCLEGAN_MMG_DENSITY_FULL",
    "00004_PIX2PIX_MMG_MASSES_W_MASKS",
    "00019_PGGAN_CHEST_XRAY",
]
def main():
    """Render the MEDIGAN generator page and run generation on demand.

    Configures the page, draws the title and intro text, builds the sidebar
    controls (model picker, image count, generate button) and, only when the
    button was clicked, calls ``generate_images`` while a spinner is shown.
    """
    # Configure the page before drawing anything else.
    st.set_page_config(page_title="MEDIGAN Generator", layout="wide")

    st.title("🧠 MEDIGAN Medical Image Generator")
    intro = (
        "\n"
        "**Generate synthetic medical images using GAN models.**\n"
        "🔍 Select model and parameters in the sidebar → Click **Generate Images**\n"
    )
    st.markdown(intro)

    # Sidebar controls, addressed through the sidebar container directly.
    sidebar = st.sidebar
    sidebar.header("⚙️ Settings")
    chosen_model = sidebar.selectbox("Select GAN Model", model_ids)
    image_count = sidebar.number_input(
        "Number of Images", min_value=1, max_value=7, value=1
    )
    clicked = sidebar.button("✨ Generate Images")

    # Nothing else to draw until the user asks for images.
    if not clicked:
        return

    spinner_text = f"Generating {image_count} image(s) using {chosen_model}..."
    with st.spinner(spinner_text):
        generate_images(image_count, chosen_model)
def _tensor_to_pil_images(tensor):
    """Convert one tensor from a medigan data dict into a list of RGB PIL images.

    A 4-D tensor is split along its first dimension and each sample is
    permuted with ``permute(2, 0, 1)`` before conversion, mirroring the
    original ``squeeze(0).permute(2, 0, 1)`` handling. For a leading dimension
    of 1 the result is identical. Larger leading dimensions, which previously
    crashed in ``permute``, now yield one image per sample. Any other tensor is
    handed to ``to_pil_image`` as-is.
    """
    if tensor.dim() == 4:
        # NOTE(review): assumes (batch, H, W, C) channel-last layout, as the
        # original permute(2, 0, 1) implies — confirm against medigan output.
        return [
            to_pil_image(sample.permute(2, 0, 1)).convert("RGB")
            for sample in tensor
        ]
    return [to_pil_image(tensor).convert("RGB")]


def torch_images(num_images, model_id):
    """Generate synthetic images with a medigan model and tile each batch into a grid.

    Args:
        num_images: Number of samples requested from the model (forwarded as
            ``num_samples``).
        model_id: medigan model identifier, e.g. one of ``model_ids``.

    Returns:
        A list of PIL images, one grid image per dataloader batch. Batches
        whose data dict holds no tensors are skipped.
    """
    generators = Generators()
    dataloader = generators.get_as_torch_dataloader(
        model_id=model_id,
        install_dependencies=True,
        num_samples=num_images,
        # Load in the main process. prefetch_factor must then be None;
        # presumably both are forwarded to torch's DataLoader, which rejects
        # a prefetch_factor when num_workers == 0.
        num_workers=0,
        prefetch_factor=None,
    )

    # Build the (stateless) PIL -> tensor transform once, not per image.
    to_tensor = torchvision.transforms.ToTensor()

    images = []
    for data_dict in dataloader:
        batch_images = []
        for tensor in data_dict.values():
            batch_images.extend(_tensor_to_pil_images(tensor))
        if not batch_images:
            # make_grid cannot tile an empty list; nothing to show here.
            continue
        # Tile every image of this batch into one grid (two per row when
        # there is more than one image).
        grid_tensor = make_grid(
            [to_tensor(img) for img in batch_images],
            nrow=2 if len(batch_images) > 1 else 1,
        )
        images.append(to_pil_image(grid_tensor))
    return images
def generate_images(num_images, model_id):
    """Generate images with ``model_id`` and render them side by side.

    Args:
        num_images: Number of samples to request from the model.
        model_id: medigan model identifier selected in the sidebar.

    Shows a warning instead of an empty column layout when the model
    returns no images.
    """
    # No manual clearing is needed here: an unused st.empty() placeholder
    # clears nothing, and Streamlit re-renders the page on every rerun.
    images = torch_images(num_images, model_id)
    if not images:
        # st.columns rejects a spec of 0, so report the situation instead.
        st.warning(f"{model_id} did not return any images.")
        return

    # One column per generated grid image.
    cols = st.columns(len(images))
    for col, img in zip(cols, images):
        with col:
            st.image(
                img,
                caption=f"Generated by: {model_id}",
                width=300,
            )
    st.markdown("---")
# Script entry point (e.g. `streamlit run` executes the file as __main__).
if __name__ == "__main__":
    main()