ahmed792002 commited on
Commit
078cd45
·
verified ·
1 Parent(s): 8421ad0

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +62 -61
app.py CHANGED
@@ -1,12 +1,40 @@
1
- # Import needed libraries
2
  import streamlit as st
3
- import torchvision
4
  from medigan import Generators
5
  from torchvision.transforms.functional import to_pil_image
6
- from torchvision.utils import make_grid
7
 
8
- # Define the GAN models available in the app
9
- model_ids = [
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
10
  "00001_DCGAN_MMG_CALC_ROI",
11
  "00002_DCGAN_MMG_MASS_ROI",
12
  "00003_CYCLEGAN_MMG_DENSITY_FULL",
@@ -15,75 +43,48 @@ model_ids = [
15
  ]
16
 
17
  def main():
18
- # Setup page configuration
19
  st.set_page_config(page_title="MEDIGAN Generator", layout="wide")
20
-
21
- # Main page title and description
22
- st.title("🧠 MEDIGAN Medical Image Generator")
23
- st.markdown("""
24
- **Generate synthetic medical images using GAN models.**
25
- 🔍 Select model and parameters in the sidebar → Click **Generate Images**
26
- """)
27
 
28
  # Sidebar controls
29
  with st.sidebar:
30
  st.header("⚙️ Settings")
31
- model_id = st.selectbox("Select GAN Model", model_ids)
32
- num_images = st.number_input("Number of Images", 1, 7, 1)
33
  generate_btn = st.button("✨ Generate Images")
34
 
35
  # Main content area
36
  if generate_btn:
37
- with st.spinner(f"Generating {num_images} image(s) using {model_id}..."):
38
- generate_images(num_images, model_id)
39
-
40
- def torch_images(num_images, model_id):
41
- generators = Generators()
42
- dataloader = generators.get_as_torch_dataloader(
43
- model_id=model_id,
44
- install_dependencies=True,
45
- num_samples=num_images,
46
- num_workers=0,
47
- prefetch_factor=None
48
- )
49
-
50
- images = []
51
- for _, data_dict in enumerate(dataloader):
52
- batch_images = []
53
- for tensor in data_dict.values():
54
- if tensor.dim() == 4:
55
- tensor = tensor.squeeze(0).permute(2, 0, 1)
56
- img = to_pil_image(tensor).convert("RGB")
57
- batch_images.append(img)
58
-
59
- # Create image grid
60
- grid_tensor = make_grid(
61
- [torchvision.transforms.ToTensor()(img) for img in batch_images],
62
- nrow=2 if len(batch_images) > 1 else 1
63
- )
64
- grid_img = to_pil_image(grid_tensor)
65
- images.append(grid_img)
66
-
67
- return images
68
 
69
  def generate_images(num_images, model_id):
70
- # Clear previous results
71
- st.empty()
 
72
 
73
- # Generate and display new images
74
- images = torch_images(num_images, model_id)
75
-
76
- # Create columns for responsive layout
77
- cols = st.columns(len(images))
78
 
79
- for col, img in zip(cols, images):
80
- with col:
81
- st.image(
82
- img,
83
- caption=f"Generated by: {model_id}",
84
- width=300
85
- )
86
- st.markdown("---")
 
 
 
 
 
 
 
 
87
 
88
  if __name__ == "__main__":
89
  main()
 
1
+ # Import required libraries
2
  import streamlit as st
3
+ import torch
4
  from medigan import Generators
5
  from torchvision.transforms.functional import to_pil_image
 
6
 
7
+ # Apply critical patch for PGGAN model compatibility
8
+ def apply_pggan_patch():
9
+ try:
10
+ from model19.base_GAN import BaseGAN
11
+
12
+ # Monkey-patch the BaseGAN class initialization
13
+ original_init = BaseGAN.__init__
14
+
15
+ def patched_init(self, dimLatentVector, **kwargs):
16
+ original_init(self, dimLatentVector, **kwargs)
17
+
18
+ # Force correct optimizer parameters
19
+ self.optimizerG = torch.optim.Adam(
20
+ self.netG.parameters(),
21
+ lr=0.0002,
22
+ betas=(0.5, 0.999) # Explicit tuple of floats
23
+ )
24
+
25
+ self.optimizerD = torch.optim.Adam(
26
+ self.netD.parameters(),
27
+ lr=0.0002,
28
+ betas=(0.5, 0.999) # Explicit tuple of floats
29
+ )
30
+
31
+ BaseGAN.__init__ = patched_init
32
+ st.success("Applied PGGAN compatibility patch!")
33
+ except Exception as e:
34
+ st.warning(f"Patching failed: {str(e)}")
35
+
36
+ # Model configuration
37
+ MODEL_IDS = [
38
  "00001_DCGAN_MMG_CALC_ROI",
39
  "00002_DCGAN_MMG_MASS_ROI",
40
  "00003_CYCLEGAN_MMG_DENSITY_FULL",
 
43
  ]
44
 
45
  def main():
 
46
  st.set_page_config(page_title="MEDIGAN Generator", layout="wide")
47
+ st.title("🧠 Medical Image Generator")
 
 
 
 
 
 
48
 
49
  # Sidebar controls
50
  with st.sidebar:
51
  st.header("⚙️ Settings")
52
+ model_id = st.selectbox("Select Model", MODEL_IDS)
53
+ num_images = st.slider("Number of Images", 1, 8, 4)
54
  generate_btn = st.button("✨ Generate Images")
55
 
56
  # Main content area
57
  if generate_btn:
58
+ with st.spinner(f"Generating {num_images} images..."):
59
+ try:
60
+ generate_images(num_images, model_id)
61
+ except Exception as e:
62
+ st.error(f"Generation failed: {str(e)}")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
63
 
64
  def generate_images(num_images, model_id):
65
+ # Apply model-specific patches
66
+ if "00019" in model_id:
67
+ apply_pggan_patch()
68
 
69
+ # Initialize generator
70
+ generators = Generators()
 
 
 
71
 
72
+ # Generate and display images
73
+ cols = st.columns(4)
74
+ for i in range(num_images):
75
+ with cols[i % 4]:
76
+ try:
77
+ # Generate single sample
78
+ sample = generators.generate(
79
+ model_id=model_id,
80
+ num_samples=1,
81
+ install_dependencies=True
82
+ )
83
+ img = to_pil_image(sample[0]).convert("RGB")
84
+ st.image(img, caption=f"Image {i+1}", use_column_width=True)
85
+ st.markdown("---")
86
+ except Exception as e:
87
+ st.error(f"Error generating image {i+1}: {str(e)}")
88
 
89
  if __name__ == "__main__":
90
  main()