ahmed792002 commited on
Commit
ef30ad0
·
verified ·
1 Parent(s): 529661d

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +50 -54
app.py CHANGED
@@ -1,39 +1,40 @@
1
- # Import required libraries
2
  import streamlit as st
3
  import torch
 
4
  from medigan import Generators
5
  from torchvision.transforms.functional import to_pil_image
6
 
7
- # Apply critical patch for PGGAN model compatibility
8
- def apply_pggan_patch():
9
  try:
10
- from model19.base_GAN import BaseGAN
 
 
 
11
 
12
- # Monkey-patch the BaseGAN class initialization
13
- original_init = BaseGAN.__init__
14
 
15
- def patched_init(self, dimLatentVector, **kwargs):
16
- original_init(self, dimLatentVector, **kwargs)
 
 
 
17
 
18
- # Force correct optimizer parameters
19
- self.optimizerG = torch.optim.Adam(
20
- self.netG.parameters(),
21
- lr=0.0002,
22
- betas=(0.5, 0.999) # Explicit tuple of floats
23
  )
24
 
25
- self.optimizerD = torch.optim.Adam(
26
- self.netD.parameters(),
27
- lr=0.0002,
28
- betas=(0.5, 0.999) # Explicit tuple of floats
29
- )
30
-
31
- BaseGAN.__init__ = patched_init
32
- st.success("Applied PGGAN compatibility patch!")
33
  except Exception as e:
34
- st.warning(f"Patching failed: {str(e)}")
 
35
 
36
- # Model configuration
37
  MODEL_IDS = [
38
  "00001_DCGAN_MMG_CALC_ROI",
39
  "00002_DCGAN_MMG_MASS_ROI",
@@ -46,45 +47,40 @@ def main():
46
  st.set_page_config(page_title="MEDIGAN Generator", layout="wide")
47
  st.title("🧠 Medical Image Generator")
48
 
49
- # Sidebar controls
50
  with st.sidebar:
51
  st.header("⚙️ Settings")
52
  model_id = st.selectbox("Select Model", MODEL_IDS)
53
  num_images = st.slider("Number of Images", 1, 8, 4)
54
- generate_btn = st.button("✨ Generate Images")
55
-
56
- # Main content area
57
- if generate_btn:
58
- with st.spinner(f"Generating {num_images} images..."):
59
- try:
60
  generate_images(num_images, model_id)
61
- except Exception as e:
62
- st.error(f"Generation failed: {str(e)}")
63
 
64
  def generate_images(num_images, model_id):
65
- # Apply model-specific patches
66
- if "00019" in model_id:
67
- apply_pggan_patch()
68
-
69
- # Initialize generator
70
- generators = Generators()
71
-
72
- # Generate and display images
73
- cols = st.columns(4)
74
- for i in range(num_images):
75
- with cols[i % 4]:
76
- try:
77
- # Generate single sample
78
- sample = generators.generate(
79
- model_id=model_id,
80
- num_samples=1,
81
- install_dependencies=True
82
- )
83
- img = to_pil_image(sample[0]).convert("RGB")
84
- st.image(img, caption=f"Image {i+1}", use_column_width=True)
85
  st.markdown("---")
86
- except Exception as e:
87
- st.error(f"Error generating image {i+1}: {str(e)}")
 
88
 
89
  if __name__ == "__main__":
90
  main()
 
 
1
  import streamlit as st
2
  import torch
3
+ import os
4
  from medigan import Generators
5
  from torchvision.transforms.functional import to_pil_image
6
 
7
+ def install_and_patch_model(model_id):
 
8
  try:
9
+ # Install model if not exists
10
+ generators = Generators()
11
+ if model_id not in generators.get_model_ids():
12
+ generators.download_and_install_model(model_id)
13
 
14
+ # Locate model files
15
+ model_path = os.path.join(os.path.expanduser("~"), ".medigan", "models", model_id)
16
 
17
+ # Patch base_GAN.py
18
+ base_gan_path = os.path.join(model_path, "model19", "base_GAN.py")
19
+ if os.path.exists(base_gan_path):
20
+ with open(base_gan_path, "r") as f:
21
+ content = f.read()
22
 
23
+ # Replace problematic optimizer lines
24
+ content = content.replace(
25
+ "betas=(self.beta1, self.beta2)",
26
+ "betas=(0.5, 0.999)" # Force correct beta values
 
27
  )
28
 
29
+ with open(base_gan_path, "w") as f:
30
+ f.write(content)
31
+
32
+ st.success("Model successfully patched!")
33
+ return True
 
 
 
34
  except Exception as e:
35
+ st.error(f"Patching failed: {str(e)}")
36
+ return False
37
 
 
38
  MODEL_IDS = [
39
  "00001_DCGAN_MMG_CALC_ROI",
40
  "00002_DCGAN_MMG_MASS_ROI",
 
47
  st.set_page_config(page_title="MEDIGAN Generator", layout="wide")
48
  st.title("🧠 Medical Image Generator")
49
 
 
50
  with st.sidebar:
51
  st.header("⚙️ Settings")
52
  model_id = st.selectbox("Select Model", MODEL_IDS)
53
  num_images = st.slider("Number of Images", 1, 8, 4)
54
+
55
+ if st.button("✨ Generate Images"):
56
+ if "00019" in model_id and not install_and_patch_model(model_id):
57
+ return
58
+
59
+ with st.spinner("Generating images..."):
60
  generate_images(num_images, model_id)
 
 
61
 
62
  def generate_images(num_images, model_id):
63
+ try:
64
+ generators = Generators()
65
+ images = []
66
+
67
+ for i in range(num_images):
68
+ sample = generators.generate(
69
+ model_id=model_id,
70
+ num_samples=1,
71
+ install_dependencies=False
72
+ )
73
+ img = to_pil_image(sample[0]).convert("RGB")
74
+ images.append(img)
75
+
76
+ cols = st.columns(4)
77
+ for idx, img in enumerate(images):
78
+ with cols[idx % 4]:
79
+ st.image(img, caption=f"Image {idx+1}", use_column_width=True)
 
 
 
80
  st.markdown("---")
81
+
82
+ except Exception as e:
83
+ st.error(f"Generation failed: {str(e)}")
84
 
85
  if __name__ == "__main__":
86
  main()