# Spaces: Running
# Import required libraries | |
import streamlit as st | |
import torch | |
from medigan import Generators | |
from torchvision.transforms.functional import to_pil_image | |
# Apply critical patch for PGGAN model compatibility
def apply_pggan_patch():
    """Monkey-patch ``BaseGAN.__init__`` to rebuild both optimizers with explicit Adam settings.

    After the original initializer has run, ``optimizerG`` and ``optimizerD``
    are replaced with ``torch.optim.Adam`` instances (``lr=0.0002``,
    ``betas=(0.5, 0.999)``) over the parameters of ``netG`` and ``netD``.

    The wrapper is installed at most once per process. This function is
    invoked for every generation request, and re-wrapping an already wrapped
    ``__init__`` would stack one extra layer (and one extra pair of optimizer
    rebuilds) per call.

    The outcome is reported through Streamlit: a success message when the
    patch is in place, a warning carrying the exception text otherwise.
    """
    try:
        from model19.base_GAN import BaseGAN

        # Sentinel attribute on the wrapper marks the class as already patched.
        if not getattr(BaseGAN.__init__, "_pggan_patched", False):
            # Monkey-patch the BaseGAN class initialization
            original_init = BaseGAN.__init__

            def patched_init(self, dimLatentVector, **kwargs):
                original_init(self, dimLatentVector, **kwargs)
                # Force correct optimizer parameters
                self.optimizerG = torch.optim.Adam(
                    self.netG.parameters(),
                    lr=0.0002,
                    betas=(0.5, 0.999),  # Explicit tuple of floats
                )
                self.optimizerD = torch.optim.Adam(
                    self.netD.parameters(),
                    lr=0.0002,
                    betas=(0.5, 0.999),  # Explicit tuple of floats
                )

            patched_init._pggan_patched = True
            BaseGAN.__init__ = patched_init

        st.success("Applied PGGAN compatibility patch!")
    except Exception as e:
        # Best-effort: a failed patch is surfaced to the user but must not
        # abort the generation attempt that follows.
        st.warning(f"Patching failed: {str(e)}")
# Model configuration: identifiers listed in the sidebar selectbox and passed
# unchanged as ``model_id`` when generating.
MODEL_IDS = [
    "00001_DCGAN_MMG_CALC_ROI",
    "00002_DCGAN_MMG_MASS_ROI",
    "00003_CYCLEGAN_MMG_DENSITY_FULL",
    "00004_PIX2PIX_MMG_MASSES_W_MASKS",
    "00019_PGGAN_CHEST_XRAY",
]
def main():
    """Render the page and, when the button is pressed, run image generation.

    Sets the page configuration and title, draws the sidebar controls (model
    selector, image-count slider, generate button), and on a button press
    calls ``generate_images`` inside a spinner. Any exception escaping that
    call is shown with ``st.error`` instead of propagating.
    """
    st.set_page_config(page_title="MEDIGAN Generator", layout="wide")
    st.title("🧠 Medical Image Generator")

    # Sidebar controls
    with st.sidebar:
        st.header("⚙️ Settings")
        chosen_model = st.selectbox("Select Model", MODEL_IDS)
        image_count = st.slider("Number of Images", 1, 8, 4)
        clicked = st.button("✨ Generate Images")

    # Nothing to render in the main area until the button is pressed.
    if not clicked:
        return

    # Main content area
    with st.spinner(f"Generating {image_count} images..."):
        try:
            generate_images(image_count, chosen_model)
        except Exception as e:
            st.error(f"Generation failed: {e!s}")
def generate_images(num_images, model_id):
    """Generate ``num_images`` samples with medigan and show them in a 4-column grid.

    Args:
        num_images: Number of images to produce; each one is requested with
            its own ``generate`` call so a failure affects only that image.
        model_id: medigan model identifier. When it contains ``"00019"`` the
            PGGAN compatibility patch is applied before generation.

    Each sample's first element is converted to an RGB PIL image and rendered
    in column ``i % 4``. Per-image errors are shown with ``st.error`` and the
    loop continues with the next image.
    """
    # Apply model-specific patches
    if "00019" in model_id:
        apply_pggan_patch()

    # Initialize generator
    generators = Generators()

    # Generate and display images
    cols = st.columns(4)
    for i in range(num_images):
        with cols[i % 4]:
            try:
                # Generate single sample
                sample = generators.generate(
                    model_id=model_id,
                    num_samples=1,
                    install_dependencies=True,
                    # medigan defaults to save_images=True, which stores the
                    # samples on disk instead of returning them; request the
                    # in-memory result so it can be indexed below.
                    save_images=False,
                )
                img = to_pil_image(sample[0]).convert("RGB")
                st.image(img, caption=f"Image {i+1}", use_column_width=True)
                st.markdown("---")
            except Exception as e:
                st.error(f"Error generating image {i+1}: {str(e)}")
# Script entry point: build the UI only when this file is executed directly.
if __name__ == "__main__":
    main()