Spaces:
Running
Running
import streamlit as st | |
import subprocess | |
import sys | |
import importlib | |
from medigan import Generators | |
def install_dependencies(model_id):
    """Install the Python dependencies required by one medigan model.

    Runs the ``src.medigan.install_model_dependencies`` module with the same
    interpreter that is running this app and reports the outcome in the page.

    Args:
        model_id: medigan model identifier, e.g. ``"00001_DCGAN_MMG_CALC_ROI"``.

    Returns:
        True if the installer exited with status 0, False otherwise.
    """
    # An argument list with no shell means model_id is passed verbatim and is
    # never interpreted by a shell. sys.executable pins the installer to this
    # interpreter's environment instead of whichever `python` is on PATH.
    # NOTE(review): `src.medigan...` only resolves when the working directory
    # contains a `src/medigan` package -- confirm for the deployed Space.
    cmd = [
        sys.executable,
        "-m",
        "src.medigan.install_model_dependencies",
        "--model_id",
        str(model_id),
    ]
    try:
        subprocess.run(
            cmd,
            check=True,
            capture_output=True,
            text=True,
            errors="replace",  # never fail on undecodable installer output
        )
    except subprocess.CalledProcessError as e:
        # Prefer stderr, but some installers only report on stdout.
        details = (e.stderr or e.stdout or str(e)).strip()
        st.error(f"Failed to install dependencies: {details}")
        return False
    st.success(f"Dependencies installed for {model_id}")
    return True
def check_and_install(package): | |
"""Check if a package is installed, install if missing""" | |
try: | |
importlib.import_module(package.split("[")[0]) # Handle extras like scikit-image[optional] | |
except ImportError: | |
st.warning(f"Installing missing dependency: {package}") | |
subprocess.run([sys.executable, "-m", "pip", "install", package], check=True) | |
st.success(f"Successfully installed {package}") | |
# pip distributions shared by all models. main() checks each one and installs
# it if missing, and generate_images() echoes the list as a manual-install hint.
# `pathlib` is deliberately absent: it is part of the standard library, and
# `pip install pathlib` pulls an obsolete PyPI backport.
COMMON_DEPENDENCIES = [
    "numpy",
    "pyyaml",
    "opencv-contrib-python-headless",
    "torch",
    "torchvision",
    "dominate",
    "visdom",
    "Pillow",
    "imageio",
    "scikit-image",
]
def main():
    """Render the MEDIGAN Streamlit page.

    Configures the page, ensures the shared dependencies are available and
    builds the sidebar (model choice and image count). When "Generate Images"
    is clicked, it installs the selected model's dependencies and then renders
    the generated samples.
    """
    st.set_page_config(page_title="MEDIGAN Generator", layout="wide")
    st.title("🧠 Medical Image Generator")

    # Streamlit re-executes the script on every widget interaction. Run the
    # potentially slow, pip-invoking dependency check once per session instead
    # of on every rerun. The flag is only set after the loop completes.
    if not st.session_state.get("_common_deps_checked", False):
        with st.spinner("Checking system dependencies..."):
            for dep in COMMON_DEPENDENCIES:
                check_and_install(dep)
        st.session_state["_common_deps_checked"] = True

    model_ids = [
        "00001_DCGAN_MMG_CALC_ROI",
        "00002_DCGAN_MMG_MASS_ROI",
        "00003_CYCLEGAN_MMG_DENSITY_FULL",
        "00004_PIX2PIX_MMG_MASSES_W_MASKS",
        "00019_PGGAN_CHEST_XRAY",
    ]

    with st.sidebar:
        st.header("⚙️ Settings")
        model_id = st.selectbox("Select Model", model_ids)
        num_images = st.slider("Number of Images", 1, 8, 4)

    if st.button("✨ Generate Images"):
        with st.spinner("Initializing model..."):
            # Only run the per-model installer subprocess the first time a
            # model is used in this session.
            installed = st.session_state.setdefault("_models_with_deps", set())
            if model_id not in installed:
                if not install_dependencies(model_id):
                    return
                installed.add(model_id)
            generate_images(num_images, model_id)
def generate_images(num_images, model_id):
    """Generate samples with a medigan model and show them in a 4-column grid.

    Args:
        num_images: Number of samples to request from the model.
        model_id: medigan model identifier, e.g. ``"00001_DCGAN_MMG_CALC_ROI"``.

    Any failure is reported in the page, together with a manual-install hint,
    rather than raised.
    """
    try:
        # Imported lazily: torchvision is one of COMMON_DEPENDENCIES, which
        # main() may only install at runtime. Without this import every
        # generation failed with NameError.
        from torchvision.transforms.functional import to_pil_image

        generators = Generators()
        # save_images=False makes medigan return the samples in memory instead
        # of only writing them to disk. All samples are requested in one call
        # rather than one call per image.
        samples = generators.generate(
            model_id=model_id,
            num_samples=num_images,
            save_images=False,
            install_dependencies=False,  # Already handled by our install_dependencies()
        )
        # len() rather than truthiness: the result may be a numpy array.
        if samples is None or len(samples) == 0:
            st.warning("The model returned no images.")
            return

        images = [to_pil_image(sample).convert("RGB") for sample in samples]

        cols = st.columns(4)
        for idx, img in enumerate(images):
            with cols[idx % 4]:
                st.image(img, caption=f"Image {idx+1}", use_column_width=True)
        st.markdown("---")
    except Exception as e:
        # Top-level UI boundary: surface the error in the page instead of crashing.
        st.error(f"Generation failed: {e}")
        st.info("If the error persists, try manual installation:")
        st.code(f"pip install {' '.join(COMMON_DEPENDENCIES)}")
if __name__ == "__main__":
    # Script entry point: build the page only when this file is executed
    # directly, not when it is imported as a module.
    main()