"""Streamlit front-end for generating synthetic medical images with medigan."""

import importlib
import subprocess
import sys

import numpy as np
import streamlit as st
from medigan import Generators


def install_dependencies(model_id):
    """Install the extra dependencies required by a specific medigan model.

    Runs the ``install_model_dependencies`` module in a subprocess and reports
    the outcome in the Streamlit UI.

    Args:
        model_id: medigan model identifier, e.g. ``"00001_DCGAN_MMG_CALC_ROI"``.

    Returns:
        True if the installer exited successfully, False otherwise.
    """
    # Use the running interpreter (not whichever ``python`` is first on PATH) so
    # packages land in the environment Streamlit actually uses, and pass an
    # argument list instead of a shell string so model_id is never shell-parsed.
    # NOTE(review): the ``src.medigan`` module path only resolves when the app is
    # launched from the medigan repository root — confirm for your deployment.
    cmd = [
        sys.executable,
        "-m",
        "src.medigan.install_model_dependencies",
        "--model_id",
        model_id,
    ]
    try:
        subprocess.run(cmd, check=True, capture_output=True, text=True)
    except subprocess.CalledProcessError as e:
        st.error(f"Failed to install dependencies: {e.stderr}")
        return False
    st.success(f"Dependencies installed for {model_id}")
    return True


# pip distribution name (lower-cased) -> importable module name, for packages
# where the two differ. Importing the distribution name itself (e.g. "pyyaml")
# always fails, which would trigger a reinstall on every Streamlit rerun.
_IMPORT_NAMES = {
    "pyyaml": "yaml",
    "opencv-contrib-python-headless": "cv2",
    "pillow": "PIL",
    "scikit-image": "skimage",
}


def check_and_install(package, import_name=None):
    """Ensure ``package`` is importable, installing it with pip if it is missing.

    Args:
        package: pip requirement string, optionally with extras
            (e.g. ``"scikit-image[optional]"``).
        import_name: module to import to test availability. When omitted it is
            derived from ``package`` via ``_IMPORT_NAMES``, falling back to the
            distribution name with ``-`` replaced by ``_``.

    Returns:
        True if the package is (now) available, False if the pip install failed.
    """
    if import_name is None:
        base = package.split("[")[0].strip()  # drop extras like scikit-image[optional]
        import_name = _IMPORT_NAMES.get(base.lower(), base.replace("-", "_"))
    try:
        importlib.import_module(import_name)
        return True
    except ImportError:
        pass

    st.warning(f"Installing missing dependency: {package}")
    try:
        subprocess.run([sys.executable, "-m", "pip", "install", package], check=True)
    except subprocess.CalledProcessError as e:
        # Report instead of crashing the whole page.
        st.error(f"Failed to install {package} (pip exited with code {e.returncode})")
        return False
    # Refresh import finder caches so this process can see the new package.
    importlib.invalidate_caches()
    st.success(f"Successfully installed {package}")
    return True


# Common dependencies for all models (pip distribution names).
# "pathlib" is intentionally absent: it is part of the standard library, and
# `pip install pathlib` installs an obsolete backport that can shadow it.
COMMON_DEPENDENCIES = [
    "numpy",
    "pyyaml",
    "opencv-contrib-python-headless",
    "torch",
    "torchvision",
    "dominate",
    "visdom",
    "Pillow",
    "imageio",
    "scikit-image",
]


def main():
    """Render the page: dependency check, sidebar settings and generate button."""
    st.set_page_config(page_title="MEDIGAN Generator", layout="wide")
    st.title("🧠 Medical Image Generator")

    # Pre-install common dependencies. Streamlit reruns this script on every
    # interaction; with correct import names this is just cached imports.
    with st.spinner("Checking system dependencies..."):
        for dep in COMMON_DEPENDENCIES:
            check_and_install(dep)

    MODEL_IDS = [
        "00001_DCGAN_MMG_CALC_ROI",
        "00002_DCGAN_MMG_MASS_ROI",
        "00003_CYCLEGAN_MMG_DENSITY_FULL",
        "00004_PIX2PIX_MMG_MASSES_W_MASKS",
        "00019_PGGAN_CHEST_XRAY",
    ]

    with st.sidebar:
        st.header("⚙️ Settings")
        model_id = st.selectbox("Select Model", MODEL_IDS)
        num_images = st.slider("Number of Images", 1, 8, 4)

    if st.button("✨ Generate Images"):
        with st.spinner("Initializing model..."):
            if not install_dependencies(model_id):
                return
            generate_images(num_images, model_id)


def _to_display_array(sample):
    """Convert one generated sample into a uint8 array accepted by ``st.image``.

    Leading singleton batch axes are dropped. A channel-first (C, H, W) layout
    is moved to channel-last when unambiguous. Single-channel images become 2-D
    grayscale. Non-uint8 data is scaled into 0-255: values in [0, 1] are
    multiplied by 255, values already in [0, 255] are kept, and anything else
    is min-max normalised.

    Raises:
        ValueError: if the sample contains no data.
    """
    arr = np.asarray(sample)
    if arr.size == 0:
        raise ValueError("Generated sample is empty")
    while arr.ndim > 3 and arr.shape[0] == 1:
        arr = arr[0]  # strip leading batch axes of size 1
    # NOTE(review): assumes a 3-D array with 1 or 3 leading entries and no
    # channel-like last axis is channel-first — confirm per model.
    if arr.ndim == 3 and arr.shape[0] in (1, 3) and arr.shape[-1] not in (1, 3, 4):
        arr = np.moveaxis(arr, 0, -1)
    if arr.ndim == 3 and arr.shape[-1] == 1:
        arr = arr[..., 0]
    if arr.dtype != np.uint8:
        arr = arr.astype(np.float64)
        lo, hi = float(arr.min()), float(arr.max())
        if lo >= 0.0 and hi <= 1.0:
            arr = arr * 255.0
        elif lo < 0.0 or hi > 255.0:
            arr = (arr - lo) / ((hi - lo) or 1.0) * 255.0
        arr = np.clip(np.rint(arr), 0, 255).astype(np.uint8)
    return arr


def generate_images(num_images, model_id):
    """Generate ``num_images`` samples with ``model_id`` and show them in a 4-column grid.

    Any failure is reported in the UI, together with a manual-install hint.
    """
    try:
        generators = Generators()
        # One batched call instead of one call per image. save_images=False makes
        # medigan return the samples rather than only writing them to disk.
        samples = generators.generate(
            model_id=model_id,
            num_samples=num_images,
            save_images=False,
            install_dependencies=False,  # Already handled by our install_dependencies()
        )
        if samples is None or len(samples) == 0:
            raise RuntimeError("The model returned no images")
        images = [_to_display_array(s) for s in list(samples)[:num_images]]

        cols = st.columns(4)
        for idx, img in enumerate(images):
            with cols[idx % 4]:
                st.image(img, caption=f"Image {idx+1}", use_column_width=True)
        st.markdown("---")

    except Exception as e:  # UI boundary: surface any failure to the user
        st.error(f"Generation failed: {str(e)}")
        st.info("If the error persists, try manual installation:")
        st.code(f"pip install {' '.join(COMMON_DEPENDENCIES)}")


if __name__ == "__main__":
    main()