Spaces:
Running
Running
Update app.py
Browse files
app.py
CHANGED
@@ -1,12 +1,40 @@
|
|
1 |
-
# Import
|
2 |
import streamlit as st
|
3 |
-
import
|
4 |
from medigan import Generators
|
5 |
from torchvision.transforms.functional import to_pil_image
|
6 |
-
from torchvision.utils import make_grid
|
7 |
|
8 |
-
#
|
9 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
10 |
"00001_DCGAN_MMG_CALC_ROI",
|
11 |
"00002_DCGAN_MMG_MASS_ROI",
|
12 |
"00003_CYCLEGAN_MMG_DENSITY_FULL",
|
@@ -15,75 +43,48 @@ model_ids = [
|
|
15 |
]
|
16 |
|
17 |
def main():
|
18 |
-
# Setup page configuration
|
19 |
st.set_page_config(page_title="MEDIGAN Generator", layout="wide")
|
20 |
-
|
21 |
-
# Main page title and description
|
22 |
-
st.title("🧠 MEDIGAN Medical Image Generator")
|
23 |
-
st.markdown("""
|
24 |
-
**Generate synthetic medical images using GAN models.**
|
25 |
-
🔍 Select model and parameters in the sidebar → Click **Generate Images**
|
26 |
-
""")
|
27 |
|
28 |
# Sidebar controls
|
29 |
with st.sidebar:
|
30 |
st.header("⚙️ Settings")
|
31 |
-
model_id = st.selectbox("Select
|
32 |
-
num_images = st.
|
33 |
generate_btn = st.button("✨ Generate Images")
|
34 |
|
35 |
# Main content area
|
36 |
if generate_btn:
|
37 |
-
with st.spinner(f"Generating {num_images}
|
38 |
-
|
39 |
-
|
40 |
-
|
41 |
-
|
42 |
-
dataloader = generators.get_as_torch_dataloader(
|
43 |
-
model_id=model_id,
|
44 |
-
install_dependencies=True,
|
45 |
-
num_samples=num_images,
|
46 |
-
num_workers=0,
|
47 |
-
prefetch_factor=None
|
48 |
-
)
|
49 |
-
|
50 |
-
images = []
|
51 |
-
for _, data_dict in enumerate(dataloader):
|
52 |
-
batch_images = []
|
53 |
-
for tensor in data_dict.values():
|
54 |
-
if tensor.dim() == 4:
|
55 |
-
tensor = tensor.squeeze(0).permute(2, 0, 1)
|
56 |
-
img = to_pil_image(tensor).convert("RGB")
|
57 |
-
batch_images.append(img)
|
58 |
-
|
59 |
-
# Create image grid
|
60 |
-
grid_tensor = make_grid(
|
61 |
-
[torchvision.transforms.ToTensor()(img) for img in batch_images],
|
62 |
-
nrow=2 if len(batch_images) > 1 else 1
|
63 |
-
)
|
64 |
-
grid_img = to_pil_image(grid_tensor)
|
65 |
-
images.append(grid_img)
|
66 |
-
|
67 |
-
return images
|
68 |
|
69 |
def generate_images(num_images, model_id):
|
70 |
-
#
|
71 |
-
|
|
|
72 |
|
73 |
-
#
|
74 |
-
|
75 |
-
|
76 |
-
# Create columns for responsive layout
|
77 |
-
cols = st.columns(len(images))
|
78 |
|
79 |
-
|
80 |
-
|
81 |
-
|
82 |
-
|
83 |
-
|
84 |
-
|
85 |
-
|
86 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
87 |
|
88 |
if __name__ == "__main__":
|
89 |
main()
|
|
|
1 |
+
# Import required libraries
|
2 |
import streamlit as st
|
3 |
+
import torch
|
4 |
from medigan import Generators
|
5 |
from torchvision.transforms.functional import to_pil_image
|
|
|
6 |
|
7 |
+
# Apply critical patch for PGGAN model compatibility
|
8 |
+
def apply_pggan_patch():
|
9 |
+
try:
|
10 |
+
from model19.base_GAN import BaseGAN
|
11 |
+
|
12 |
+
# Monkey-patch the BaseGAN class initialization
|
13 |
+
original_init = BaseGAN.__init__
|
14 |
+
|
15 |
+
def patched_init(self, dimLatentVector, **kwargs):
|
16 |
+
original_init(self, dimLatentVector, **kwargs)
|
17 |
+
|
18 |
+
# Force correct optimizer parameters
|
19 |
+
self.optimizerG = torch.optim.Adam(
|
20 |
+
self.netG.parameters(),
|
21 |
+
lr=0.0002,
|
22 |
+
betas=(0.5, 0.999) # Explicit tuple of floats
|
23 |
+
)
|
24 |
+
|
25 |
+
self.optimizerD = torch.optim.Adam(
|
26 |
+
self.netD.parameters(),
|
27 |
+
lr=0.0002,
|
28 |
+
betas=(0.5, 0.999) # Explicit tuple of floats
|
29 |
+
)
|
30 |
+
|
31 |
+
BaseGAN.__init__ = patched_init
|
32 |
+
st.success("Applied PGGAN compatibility patch!")
|
33 |
+
except Exception as e:
|
34 |
+
st.warning(f"Patching failed: {str(e)}")
|
35 |
+
|
36 |
+
# Model configuration
|
37 |
+
MODEL_IDS = [
|
38 |
"00001_DCGAN_MMG_CALC_ROI",
|
39 |
"00002_DCGAN_MMG_MASS_ROI",
|
40 |
"00003_CYCLEGAN_MMG_DENSITY_FULL",
|
|
|
43 |
]
|
44 |
|
45 |
def main():
|
|
|
46 |
st.set_page_config(page_title="MEDIGAN Generator", layout="wide")
|
47 |
+
st.title("🧠 Medical Image Generator")
|
|
|
|
|
|
|
|
|
|
|
|
|
48 |
|
49 |
# Sidebar controls
|
50 |
with st.sidebar:
|
51 |
st.header("⚙️ Settings")
|
52 |
+
model_id = st.selectbox("Select Model", MODEL_IDS)
|
53 |
+
num_images = st.slider("Number of Images", 1, 8, 4)
|
54 |
generate_btn = st.button("✨ Generate Images")
|
55 |
|
56 |
# Main content area
|
57 |
if generate_btn:
|
58 |
+
with st.spinner(f"Generating {num_images} images..."):
|
59 |
+
try:
|
60 |
+
generate_images(num_images, model_id)
|
61 |
+
except Exception as e:
|
62 |
+
st.error(f"Generation failed: {str(e)}")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
63 |
|
64 |
def generate_images(num_images, model_id):
|
65 |
+
# Apply model-specific patches
|
66 |
+
if "00019" in model_id:
|
67 |
+
apply_pggan_patch()
|
68 |
|
69 |
+
# Initialize generator
|
70 |
+
generators = Generators()
|
|
|
|
|
|
|
71 |
|
72 |
+
# Generate and display images
|
73 |
+
cols = st.columns(4)
|
74 |
+
for i in range(num_images):
|
75 |
+
with cols[i % 4]:
|
76 |
+
try:
|
77 |
+
# Generate single sample
|
78 |
+
sample = generators.generate(
|
79 |
+
model_id=model_id,
|
80 |
+
num_samples=1,
|
81 |
+
install_dependencies=True
|
82 |
+
)
|
83 |
+
img = to_pil_image(sample[0]).convert("RGB")
|
84 |
+
st.image(img, caption=f"Image {i+1}", use_column_width=True)
|
85 |
+
st.markdown("---")
|
86 |
+
except Exception as e:
|
87 |
+
st.error(f"Error generating image {i+1}: {str(e)}")
|
88 |
|
89 |
if __name__ == "__main__":
|
90 |
main()
|