# Generate_Images / app.py
# Author: ahmed792002 — last commit "Update app.py" (d173e0c, verified)
import streamlit as st
import subprocess
import sys
import importlib
from medigan import Generators
def install_dependencies(model_id):
    """Install the extra Python dependencies required by one medigan model.

    Runs medigan's ``install_model_dependencies`` module in a subprocess and
    reports the outcome in the Streamlit UI.

    Args:
        model_id: medigan model identifier, e.g. ``"00001_DCGAN_MMG_CALC_ROI"``.

    Returns:
        True if the installer exited with status 0, False otherwise.
    """
    # An argument list (no shell) means model_id is never re-parsed by a
    # shell, and sys.executable makes the packages land in the interpreter
    # running this app rather than whatever "python" happens to be on PATH.
    # NOTE(review): the "src." prefix implies running from a medigan source
    # checkout; with a pip-installed medigan the module may instead be
    # "medigan.install_model_dependencies" — confirm for the deployment.
    cmd = [
        sys.executable,
        "-m",
        "src.medigan.install_model_dependencies",
        "--model_id",
        str(model_id),
    ]
    try:
        subprocess.run(cmd, check=True, capture_output=True)
    except subprocess.CalledProcessError as e:
        # stderr is bytes (capture_output without text=True); never let a
        # decoding error mask the real installer failure.
        details = (e.stderr or b"").decode(errors="replace")
        st.error(f"Failed to install dependencies: {details}")
        return False
    except OSError as e:
        st.error(f"Failed to install dependencies: {e}")
        return False
    st.success(f"Dependencies installed for {model_id}")
    return True
def check_and_install(package):
"""Check if a package is installed, install if missing"""
try:
importlib.import_module(package.split("[")[0]) # Handle extras like scikit-image[optional]
except ImportError:
st.warning(f"Installing missing dependency: {package}")
subprocess.run([sys.executable, "-m", "pip", "install", package], check=True)
st.success(f"Successfully installed {package}")
# Common dependencies for all models, as pip distribution names. They are
# checked (and installed if missing) at startup and shown in the manual
# "pip install ..." hint when generation fails.
# "pathlib" is deliberately absent: it is part of the standard library, and
# `pip install pathlib` would pull in an obsolete PyPI backport.
COMMON_DEPENDENCIES = [
    "numpy",
    "pyyaml",
    "opencv-contrib-python-headless",
    "torch",
    "torchvision",
    "dominate",
    "visdom",
    "Pillow",
    "imageio",
    "scikit-image",
]
def main():
    """Render the MEDIGAN generator page.

    Checks the common dependencies once per browser session, lets the user
    pick a model and an image count in the sidebar, and on button click
    installs that model's dependencies (once per model per session) before
    generating and displaying the images.
    """
    st.set_page_config(page_title="MEDIGAN Generator", layout="wide")
    st.title("🧠 Medical Image Generator")

    # Streamlit re-executes this whole script on every widget interaction;
    # remember per session that the dependency scan has already run.
    if "common_deps_checked" not in st.session_state:
        with st.spinner("Checking system dependencies..."):
            for dep in COMMON_DEPENDENCIES:
                check_and_install(dep)
        st.session_state["common_deps_checked"] = True

    model_ids = [
        "00001_DCGAN_MMG_CALC_ROI",
        "00002_DCGAN_MMG_MASS_ROI",
        "00003_CYCLEGAN_MMG_DENSITY_FULL",
        "00004_PIX2PIX_MMG_MASSES_W_MASKS",
        "00019_PGGAN_CHEST_XRAY",
    ]

    with st.sidebar:
        st.header("⚙️ Settings")
        model_id = st.selectbox("Select Model", model_ids)
        num_images = st.slider("Number of Images", 1, 8, 4)

    if st.button("✨ Generate Images"):
        # The per-model installer spawns a subprocess; run it only the first
        # time a given model is used in this session.
        if "models_with_deps" not in st.session_state:
            st.session_state["models_with_deps"] = set()
        installed = st.session_state["models_with_deps"]
        if model_id not in installed:
            with st.spinner("Initializing model..."):
                if not install_dependencies(model_id):
                    return
            installed.add(model_id)
        generate_images(num_images, model_id)
def generate_images(num_images, model_id):
    """Generate samples with a medigan model and show them in a 4-column grid.

    Args:
        num_images: number of samples to request from the model.
        model_id: medigan model identifier to generate from.

    Any exception is reported in the Streamlit UI, together with a manual
    installation hint, instead of propagating.
    """
    try:
        # Imported here rather than at module level because torchvision may
        # only be installed by the dependency check that main() runs first.
        from torchvision.transforms.functional import to_pil_image

        generators = Generators()
        # One call for all samples. save_images=False asks medigan to return
        # the generated samples instead of writing them to disk.
        samples = generators.generate(
            model_id=model_id,
            num_samples=num_images,
            save_images=False,
            install_dependencies=False,  # Already handled by our install_dependencies()
        )
        if samples is None or len(samples) == 0:
            st.warning("The model returned no images.")
            return

        # NOTE(review): assumes each sample is a single image array accepted
        # by to_pil_image (H x W or H x W x C); models that also return masks
        # may need extra unpacking — confirm per model.
        images = [to_pil_image(sample).convert("RGB") for sample in samples]

        cols = st.columns(4)
        for idx, img in enumerate(images):
            with cols[idx % 4]:
                # NOTE(review): newer Streamlit releases deprecate
                # use_column_width; switch once the pinned version allows.
                st.image(img, caption=f"Image {idx+1}", use_column_width=True)
        st.markdown("---")
    except Exception as e:  # UI boundary: surface any failure to the user
        st.error(f"Generation failed: {str(e)}")
        st.info("If the error persists, try manual installation:")
        st.code(f"pip install {' '.join(COMMON_DEPENDENCIES)}")
# Entry point: build the page when this file is executed as a script
# (`streamlit run app.py` executes it as the __main__ module).
if __name__ == "__main__":
    main()