Spaces:
Running
Running
Update app.py
Browse files
app.py
CHANGED
@@ -1,62 +1,74 @@
|
|
1 |
import streamlit as st
|
2 |
-
import
|
3 |
-
import
|
|
|
4 |
from medigan import Generators
|
5 |
-
from torchvision.transforms.functional import to_pil_image
|
6 |
|
7 |
-
def
|
|
|
8 |
try:
|
9 |
-
|
10 |
-
|
11 |
-
|
12 |
-
|
13 |
-
|
14 |
-
|
15 |
-
|
16 |
-
|
17 |
-
# Patch base_GAN.py
|
18 |
-
base_gan_path = os.path.join(model_path, "model19", "base_GAN.py")
|
19 |
-
if os.path.exists(base_gan_path):
|
20 |
-
with open(base_gan_path, "r") as f:
|
21 |
-
content = f.read()
|
22 |
-
|
23 |
-
# Replace problematic optimizer lines
|
24 |
-
content = content.replace(
|
25 |
-
"betas=(self.beta1, self.beta2)",
|
26 |
-
"betas=(0.5, 0.999)" # Force correct beta values
|
27 |
-
)
|
28 |
-
|
29 |
-
with open(base_gan_path, "w") as f:
|
30 |
-
f.write(content)
|
31 |
-
|
32 |
-
st.success("Model successfully patched!")
|
33 |
return True
|
34 |
-
except
|
35 |
-
st.error(f"
|
36 |
return False
|
37 |
|
38 |
-
|
39 |
-
"
|
40 |
-
|
41 |
-
|
42 |
-
|
43 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
44 |
]
|
45 |
|
46 |
def main():
|
47 |
st.set_page_config(page_title="MEDIGAN Generator", layout="wide")
|
48 |
st.title("🧠 Medical Image Generator")
|
49 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
50 |
with st.sidebar:
|
51 |
st.header("⚙️ Settings")
|
52 |
model_id = st.selectbox("Select Model", MODEL_IDS)
|
53 |
num_images = st.slider("Number of Images", 1, 8, 4)
|
54 |
|
55 |
if st.button("✨ Generate Images"):
|
56 |
-
|
57 |
-
|
58 |
-
|
59 |
-
with st.spinner("Generating images..."):
|
60 |
generate_images(num_images, model_id)
|
61 |
|
62 |
def generate_images(num_images, model_id):
|
@@ -68,7 +80,7 @@ def generate_images(num_images, model_id):
|
|
68 |
sample = generators.generate(
|
69 |
model_id=model_id,
|
70 |
num_samples=1,
|
71 |
-
install_dependencies=False
|
72 |
)
|
73 |
img = to_pil_image(sample[0]).convert("RGB")
|
74 |
images.append(img)
|
@@ -81,6 +93,8 @@ def generate_images(num_images, model_id):
|
|
81 |
|
82 |
except Exception as e:
|
83 |
st.error(f"Generation failed: {str(e)}")
|
|
|
|
|
84 |
|
85 |
if __name__ == "__main__":
|
86 |
main()
|
|
|
1 |
import streamlit as st
|
2 |
+
import subprocess
|
3 |
+
import sys
|
4 |
+
import importlib
|
5 |
from medigan import Generators
|
|
|
6 |
|
7 |
+
def install_dependencies(model_id):
|
8 |
+
"""Install dependencies for a specific model"""
|
9 |
try:
|
10 |
+
result = subprocess.run(
|
11 |
+
f"python -m src.medigan.install_model_dependencies --model_id {model_id}",
|
12 |
+
shell=True,
|
13 |
+
check=True,
|
14 |
+
capture_output=True
|
15 |
+
)
|
16 |
+
st.success(f"Dependencies installed for {model_id}")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
17 |
return True
|
18 |
+
except subprocess.CalledProcessError as e:
|
19 |
+
st.error(f"Failed to install dependencies: {e.stderr.decode()}")
|
20 |
return False
|
21 |
|
22 |
+
def check_and_install(package):
|
23 |
+
"""Check if a package is installed, install if missing"""
|
24 |
+
try:
|
25 |
+
importlib.import_module(package.split("[")[0]) # Handle extras like scikit-image[optional]
|
26 |
+
except ImportError:
|
27 |
+
st.warning(f"Installing missing dependency: {package}")
|
28 |
+
subprocess.run([sys.executable, "-m", "pip", "install", package], check=True)
|
29 |
+
st.success(f"Successfully installed {package}")
|
30 |
+
|
31 |
+
# Common dependencies for all models
|
32 |
+
COMMON_DEPENDENCIES = [
|
33 |
+
"numpy",
|
34 |
+
"pyyaml",
|
35 |
+
"opencv-contrib-python-headless",
|
36 |
+
"torch",
|
37 |
+
"torchvision",
|
38 |
+
"dominate",
|
39 |
+
"visdom",
|
40 |
+
"Pillow",
|
41 |
+
"imageio",
|
42 |
+
"scikit-image",
|
43 |
+
"pathlib" # For 'Path' dependency
|
44 |
]
|
45 |
|
46 |
def main():
|
47 |
st.set_page_config(page_title="MEDIGAN Generator", layout="wide")
|
48 |
st.title("🧠 Medical Image Generator")
|
49 |
|
50 |
+
# Pre-install common dependencies
|
51 |
+
with st.spinner("Checking system dependencies..."):
|
52 |
+
for dep in COMMON_DEPENDENCIES:
|
53 |
+
check_and_install(dep)
|
54 |
+
|
55 |
+
MODEL_IDS = [
|
56 |
+
"00001_DCGAN_MMG_CALC_ROI",
|
57 |
+
"00002_DCGAN_MMG_MASS_ROI",
|
58 |
+
"00003_CYCLEGAN_MMG_DENSITY_FULL",
|
59 |
+
"00004_PIX2PIX_MMG_MASSES_W_MASKS",
|
60 |
+
"00019_PGGAN_CHEST_XRAY"
|
61 |
+
]
|
62 |
+
|
63 |
with st.sidebar:
|
64 |
st.header("⚙️ Settings")
|
65 |
model_id = st.selectbox("Select Model", MODEL_IDS)
|
66 |
num_images = st.slider("Number of Images", 1, 8, 4)
|
67 |
|
68 |
if st.button("✨ Generate Images"):
|
69 |
+
with st.spinner("Initializing model..."):
|
70 |
+
if not install_dependencies(model_id):
|
71 |
+
return
|
|
|
72 |
generate_images(num_images, model_id)
|
73 |
|
74 |
def generate_images(num_images, model_id):
|
|
|
80 |
sample = generators.generate(
|
81 |
model_id=model_id,
|
82 |
num_samples=1,
|
83 |
+
install_dependencies=False # Already handled by our install_dependencies()
|
84 |
)
|
85 |
img = to_pil_image(sample[0]).convert("RGB")
|
86 |
images.append(img)
|
|
|
93 |
|
94 |
except Exception as e:
|
95 |
st.error(f"Generation failed: {str(e)}")
|
96 |
+
st.info("If the error persists, try manual installation:")
|
97 |
+
st.code(f"pip install {' '.join(COMMON_DEPENDENCIES)}")
|
98 |
|
99 |
if __name__ == "__main__":
|
100 |
main()
|