ahmed792002 commited on
Commit
d173e0c
·
verified ·
1 Parent(s): ef30ad0

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +55 -41
app.py CHANGED
@@ -1,62 +1,74 @@
1
  import streamlit as st
2
- import torch
3
- import os
 
4
  from medigan import Generators
5
- from torchvision.transforms.functional import to_pil_image
6
 
7
- def install_and_patch_model(model_id):
 
8
  try:
9
- # Install model if not exists
10
- generators = Generators()
11
- if model_id not in generators.get_model_ids():
12
- generators.download_and_install_model(model_id)
13
-
14
- # Locate model files
15
- model_path = os.path.join(os.path.expanduser("~"), ".medigan", "models", model_id)
16
-
17
- # Patch base_GAN.py
18
- base_gan_path = os.path.join(model_path, "model19", "base_GAN.py")
19
- if os.path.exists(base_gan_path):
20
- with open(base_gan_path, "r") as f:
21
- content = f.read()
22
-
23
- # Replace problematic optimizer lines
24
- content = content.replace(
25
- "betas=(self.beta1, self.beta2)",
26
- "betas=(0.5, 0.999)" # Force correct beta values
27
- )
28
-
29
- with open(base_gan_path, "w") as f:
30
- f.write(content)
31
-
32
- st.success("Model successfully patched!")
33
  return True
34
- except Exception as e:
35
- st.error(f"Patching failed: {str(e)}")
36
  return False
37
 
38
- MODEL_IDS = [
39
- "00001_DCGAN_MMG_CALC_ROI",
40
- "00002_DCGAN_MMG_MASS_ROI",
41
- "00003_CYCLEGAN_MMG_DENSITY_FULL",
42
- "00004_PIX2PIX_MMG_MASSES_W_MASKS",
43
- "00019_PGGAN_CHEST_XRAY"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
44
  ]
45
 
46
  def main():
47
  st.set_page_config(page_title="MEDIGAN Generator", layout="wide")
48
  st.title("🧠 Medical Image Generator")
49
 
 
 
 
 
 
 
 
 
 
 
 
 
 
50
  with st.sidebar:
51
  st.header("⚙️ Settings")
52
  model_id = st.selectbox("Select Model", MODEL_IDS)
53
  num_images = st.slider("Number of Images", 1, 8, 4)
54
 
55
  if st.button("✨ Generate Images"):
56
- if "00019" in model_id and not install_and_patch_model(model_id):
57
- return
58
-
59
- with st.spinner("Generating images..."):
60
  generate_images(num_images, model_id)
61
 
62
  def generate_images(num_images, model_id):
@@ -68,7 +80,7 @@ def generate_images(num_images, model_id):
68
  sample = generators.generate(
69
  model_id=model_id,
70
  num_samples=1,
71
- install_dependencies=False
72
  )
73
  img = to_pil_image(sample[0]).convert("RGB")
74
  images.append(img)
@@ -81,6 +93,8 @@ def generate_images(num_images, model_id):
81
 
82
  except Exception as e:
83
  st.error(f"Generation failed: {str(e)}")
 
 
84
 
85
  if __name__ == "__main__":
86
  main()
 
1
  import streamlit as st
2
+ import subprocess
3
+ import sys
4
+ import importlib
5
  from medigan import Generators
 
6
 
7
+ def install_dependencies(model_id):
8
+ """Install dependencies for a specific model"""
9
  try:
10
+ result = subprocess.run(
11
+ f"python -m src.medigan.install_model_dependencies --model_id {model_id}",
12
+ shell=True,
13
+ check=True,
14
+ capture_output=True
15
+ )
16
+ st.success(f"Dependencies installed for {model_id}")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
17
  return True
18
+ except subprocess.CalledProcessError as e:
19
+ st.error(f"Failed to install dependencies: {e.stderr.decode()}")
20
  return False
21
 
22
+ def check_and_install(package):
23
+ """Check if a package is installed, install if missing"""
24
+ try:
25
+ importlib.import_module(package.split("[")[0]) # Handle extras like scikit-image[optional]
26
+ except ImportError:
27
+ st.warning(f"Installing missing dependency: {package}")
28
+ subprocess.run([sys.executable, "-m", "pip", "install", package], check=True)
29
+ st.success(f"Successfully installed {package}")
30
+
31
+ # Common dependencies for all models
32
+ COMMON_DEPENDENCIES = [
33
+ "numpy",
34
+ "pyyaml",
35
+ "opencv-contrib-python-headless",
36
+ "torch",
37
+ "torchvision",
38
+ "dominate",
39
+ "visdom",
40
+ "Pillow",
41
+ "imageio",
42
+ "scikit-image",
43
+ "pathlib" # For 'Path' dependency
44
  ]
45
 
46
  def main():
47
  st.set_page_config(page_title="MEDIGAN Generator", layout="wide")
48
  st.title("🧠 Medical Image Generator")
49
 
50
+ # Pre-install common dependencies
51
+ with st.spinner("Checking system dependencies..."):
52
+ for dep in COMMON_DEPENDENCIES:
53
+ check_and_install(dep)
54
+
55
+ MODEL_IDS = [
56
+ "00001_DCGAN_MMG_CALC_ROI",
57
+ "00002_DCGAN_MMG_MASS_ROI",
58
+ "00003_CYCLEGAN_MMG_DENSITY_FULL",
59
+ "00004_PIX2PIX_MMG_MASSES_W_MASKS",
60
+ "00019_PGGAN_CHEST_XRAY"
61
+ ]
62
+
63
  with st.sidebar:
64
  st.header("⚙️ Settings")
65
  model_id = st.selectbox("Select Model", MODEL_IDS)
66
  num_images = st.slider("Number of Images", 1, 8, 4)
67
 
68
  if st.button("✨ Generate Images"):
69
+ with st.spinner("Initializing model..."):
70
+ if not install_dependencies(model_id):
71
+ return
 
72
  generate_images(num_images, model_id)
73
 
74
  def generate_images(num_images, model_id):
 
80
  sample = generators.generate(
81
  model_id=model_id,
82
  num_samples=1,
83
+ install_dependencies=False # Already handled by our install_dependencies()
84
  )
85
  img = to_pil_image(sample[0]).convert("RGB")
86
  images.append(img)
 
93
 
94
  except Exception as e:
95
  st.error(f"Generation failed: {str(e)}")
96
+ st.info("If the error persists, try manual installation:")
97
+ st.code(f"pip install {' '.join(COMMON_DEPENDENCIES)}")
98
 
99
  if __name__ == "__main__":
100
  main()