File size: 3,235 Bytes
c6d85c0
d173e0c
 
 
c6d85c0
 
d173e0c
 
078cd45
d173e0c
 
 
 
 
 
 
ef30ad0
d173e0c
 
ef30ad0
078cd45
d173e0c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
c6d85c0
 
 
afbb2d7
078cd45
afbb2d7
d173e0c
 
 
 
 
 
 
 
 
 
 
 
 
afbb2d7
 
078cd45
 
ef30ad0
 
d173e0c
 
 
078cd45
c6d85c0
 
ef30ad0
 
 
 
 
 
 
 
d173e0c
ef30ad0
 
 
 
 
 
 
 
078cd45
ef30ad0
 
 
d173e0c
 
c6d85c0
 
afbb2d7
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
import streamlit as st
import subprocess
import sys
import importlib
from medigan import Generators

def install_dependencies(model_id):
    """Install the Python dependencies required by one medigan model.

    Runs medigan's ``install_model_dependencies`` module in a subprocess and
    reports the outcome in the Streamlit UI.

    Args:
        model_id: medigan model identifier, e.g. ``"00001_DCGAN_MMG_CALC_ROI"``.

    Returns:
        True if the installer exited with status 0, False otherwise.
    """
    # An argument list with the default shell=False means model_id is never
    # parsed by a shell. sys.executable targets the interpreter running this
    # app rather than whichever `python` happens to be first on PATH.
    # NOTE(review): `src.medigan...` only resolves when the working directory
    # is the medigan repository root — confirm this is the intended setup.
    cmd = [
        sys.executable,
        "-m",
        "src.medigan.install_model_dependencies",
        "--model_id",
        str(model_id),
    ]
    try:
        subprocess.run(cmd, check=True, capture_output=True, text=True)
    except subprocess.CalledProcessError as e:
        # stderr may be empty; fall back to the exception text (exit status).
        details = (e.stderr or "").strip() or str(e)
        st.error(f"Failed to install dependencies: {details}")
        return False
    except OSError as e:
        # The interpreter could not be launched at all.
        st.error(f"Failed to install dependencies: {e}")
        return False
    st.success(f"Dependencies installed for {model_id}")
    return True

# pip distribution names whose importable top-level module has a different name.
# Keys are lower-cased because pip treats distribution names case-insensitively.
_IMPORT_NAME_OVERRIDES = {
    "pyyaml": "yaml",
    "pillow": "PIL",
    "scikit-image": "skimage",
    "scikit-learn": "sklearn",
    "opencv-python": "cv2",
    "opencv-python-headless": "cv2",
    "opencv-contrib-python": "cv2",
    "opencv-contrib-python-headless": "cv2",
}


def import_name_for(package):
    """Return the module name to import in order to test for pip requirement *package*.

    Extras (``pkg[extra]``), version specifiers (``pkg>=1.0``), environment
    markers (``pkg; python_version...``) and direct references (``pkg @ url``)
    are stripped first. Known distributions whose module name differs
    (e.g. ``pyyaml`` -> ``yaml``) are mapped explicitly. Any other name has its
    hyphens replaced by underscores, since hyphens are not valid in module names.
    """
    name = package.strip()
    for sep in "[<>=!~;@ ":
        name = name.split(sep, 1)[0]
    return _IMPORT_NAME_OVERRIDES.get(name.lower(), name.replace("-", "_"))


def check_and_install(package):
    """Import *package*'s module, pip-installing the package if the import fails.

    Args:
        package: pip requirement string, e.g. ``"scikit-image"`` or ``"numpy>=1.20"``.

    Raises:
        subprocess.CalledProcessError: if ``pip install`` exits non-zero.
    """
    try:
        # Import by module name, not distribution name. Otherwise e.g. "Pillow"
        # never imports, and pip is re-run on every Streamlit rerun.
        importlib.import_module(import_name_for(package))
    except ImportError:
        st.warning(f"Installing missing dependency: {package}")
        subprocess.run([sys.executable, "-m", "pip", "install", package], check=True)
        # Packages installed after interpreter start-up may be invisible to the
        # import system's cached finders until the caches are invalidated.
        importlib.invalidate_caches()
        st.success(f"Successfully installed {package}")

# pip requirements shared by all models. main() pre-installs these, and they are
# echoed to users as a manual `pip install ...` hint when generation fails.
# `pathlib` is deliberately absent: it is part of the Python 3 standard library,
# and the PyPI package of that name is an obsolete backport that should not be
# installed.
COMMON_DEPENDENCIES = [
    "numpy",
    "pyyaml",
    "opencv-contrib-python-headless",
    "torch",
    "torchvision",
    "dominate",
    "visdom",
    "Pillow",
    "imageio",
    "scikit-image",
]

def main():
    """Build the Streamlit page: dependency check, settings sidebar and image output.

    The sidebar only collects settings and the button click. Generation runs
    afterwards in the main page area.
    """
    # set_page_config must be the first Streamlit command on the page.
    st.set_page_config(page_title="MEDIGAN Generator", layout="wide")
    st.title("🧠 Medical Image Generator")

    # Pre-install common dependencies
    with st.spinner("Checking system dependencies..."):
        for dep in COMMON_DEPENDENCIES:
            check_and_install(dep)

    MODEL_IDS = [
        "00001_DCGAN_MMG_CALC_ROI",
        "00002_DCGAN_MMG_MASS_ROI",
        "00003_CYCLEGAN_MMG_DENSITY_FULL",
        "00004_PIX2PIX_MMG_MASSES_W_MASKS",
        "00019_PGGAN_CHEST_XRAY"
    ]

    with st.sidebar:
        st.header("⚙️ Settings")
        model_id = st.selectbox("Select Model", MODEL_IDS)
        num_images = st.slider("Number of Images", 1, 8, 4)
        generate_clicked = st.button("✨ Generate Images")

    # Generate outside the sidebar context. Otherwise the image grid is written
    # into the narrow sidebar instead of the wide main area.
    if generate_clicked:
        with st.spinner("Initializing model..."):
            if not install_dependencies(model_id):
                return
            generate_images(num_images, model_id)

def to_display_array(sample):
    """Convert one generated sample into a uint8 array that ``st.image`` can render.

    Leading singleton batch axes are dropped. Channel-first ``(C, H, W)`` input
    with C in {1, 3, 4} is moved to channel-last, and a single trailing channel
    is squeezed to plain grayscale. uint8 input is returned as-is. Other dtypes
    are scaled by 255 when already within [0, 1], otherwise min-max normalised
    to [0, 255]. A constant image outside [0, 1] becomes all zeros.
    """
    import numpy as np

    arr = np.asarray(sample)
    # Drop leading singleton batch axes, e.g. (1, H, W, C) -> (H, W, C).
    while arr.ndim > 3 and arr.shape[0] == 1:
        arr = arr[0]
    # Channel-first (C, H, W) -> channel-last (H, W, C).
    if arr.ndim == 3 and arr.shape[0] in (1, 3, 4) and arr.shape[-1] not in (1, 3, 4):
        arr = np.moveaxis(arr, 0, -1)
    # A single trailing channel is just a grayscale image.
    if arr.ndim == 3 and arr.shape[-1] == 1:
        arr = arr[..., 0]
    if arr.dtype == np.uint8:
        return arr
    arr = arr.astype(np.float64)
    lo, hi = float(arr.min()), float(arr.max())
    if lo >= 0.0 and hi <= 1.0:
        scaled = arr * 255.0
    elif hi > lo:
        scaled = (arr - lo) / (hi - lo) * 255.0
    else:
        scaled = np.zeros_like(arr)
    return np.clip(np.rint(scaled), 0, 255).astype(np.uint8)


def generate_images(num_images, model_id):
    """Generate *num_images* samples with medigan model *model_id* and show them in a 4-column grid.

    Any failure is reported in the UI, together with a manual-install hint,
    instead of being raised.
    """
    try:
        generators = Generators()
        # One call for all samples avoids reloading the model per image.
        # save_images=False asks medigan to return the samples instead of only
        # writing them to disk. The previous default (True) left nothing to index.
        samples = generators.generate(
            model_id=model_id,
            num_samples=num_images,
            save_images=False,
            install_dependencies=False  # Already handled by our install_dependencies()
        )
        # NOTE(review): assumes each sample is a single image array; models that
        # return (image, mask) pairs (e.g. the PIX2PIX entry) may need unpacking.
        images = [to_display_array(sample) for sample in samples]

        cols = st.columns(4)
        for idx, img in enumerate(images):
            with cols[idx % 4]:
                st.image(img, caption=f"Image {idx+1}", use_column_width=True)
                st.markdown("---")

    except Exception as e:
        # UI boundary: surface any generation error to the user instead of crashing.
        st.error(f"Generation failed: {str(e)}")
        st.info("If the error persists, try manual installation:")
        st.code(f"pip install {' '.join(COMMON_DEPENDENCIES)}")

if __name__ == "__main__":
    main()