File size: 2,912 Bytes
078cd45
c6d85c0
078cd45
c6d85c0
 
 
078cd45
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
c6d85c0
 
 
 
afbb2d7
c6d85c0
 
 
afbb2d7
078cd45
afbb2d7
 
 
 
078cd45
 
afbb2d7
 
 
 
078cd45
 
 
 
 
c6d85c0
 
078cd45
 
 
afbb2d7
078cd45
 
afbb2d7
078cd45
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
c6d85c0
 
afbb2d7
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
# Import required libraries
import streamlit as st
import torch
from medigan import Generators
from torchvision.transforms.functional import to_pil_image

# Apply critical patch for PGGAN model compatibility
def apply_pggan_patch():
    """Monkey-patch ``model19.base_GAN.BaseGAN`` to rebuild its Adam optimizers.

    After the original ``BaseGAN.__init__`` runs, ``optimizerG`` and
    ``optimizerD`` are replaced with ``torch.optim.Adam`` over the
    ``netG`` / ``netD`` parameters, using ``lr=0.0002`` and
    ``betas=(0.5, 0.999)``.

    The patch is installed at most once per process. Repeated calls (one per
    generation request for model 00019) return without wrapping the
    constructor again.

    This is best effort: any failure to import ``BaseGAN`` is shown as a
    Streamlit warning and is never raised.
    """
    import functools

    try:
        # NOTE(review): presumably importable only after medigan has fetched
        # model 00019, and only effective if medigan loads the model through
        # this same module path — confirm.
        from model19.base_GAN import BaseGAN
    except Exception as e:
        st.warning(f"Patching failed: {str(e)}")
        return

    # Idempotence guard. Without it, every call wrapped the already-patched
    # __init__ once more, growing the constructor chain.
    if getattr(BaseGAN.__init__, "_pggan_patched", False):
        return

    original_init = BaseGAN.__init__

    @functools.wraps(original_init)
    def patched_init(self, *args, **kwargs):
        # Forward all arguments unchanged. The previous wrapper accepted only
        # dimLatentVector positionally and rejected any extra positional args.
        original_init(self, *args, **kwargs)

        # Force correct optimizer parameters.
        self.optimizerG = torch.optim.Adam(
            self.netG.parameters(),
            lr=0.0002,
            betas=(0.5, 0.999),  # explicit tuple of floats
        )
        self.optimizerD = torch.optim.Adam(
            self.netD.parameters(),
            lr=0.0002,
            betas=(0.5, 0.999),  # explicit tuple of floats
        )

    # Marker read by the idempotence guard above.
    patched_init._pggan_patched = True
    BaseGAN.__init__ = patched_init
    st.success("Applied PGGAN compatibility patch!")

# Model configuration: medigan model identifiers offered in the sidebar
# selector. The 00019 entry also gets apply_pggan_patch() before generation.
MODEL_IDS = [
    "00001_DCGAN_MMG_CALC_ROI",
    "00002_DCGAN_MMG_MASS_ROI",
    "00003_CYCLEGAN_MMG_DENSITY_FULL",
    "00004_PIX2PIX_MMG_MASSES_W_MASKS",
    "00019_PGGAN_CHEST_XRAY",
]

def main():
    """Render the MEDIGAN page: sidebar settings plus the generated-image area.

    Images are generated only when the sidebar button returns True. Any
    exception raised during generation is shown with ``st.error`` instead of
    propagating.
    """
    st.set_page_config(page_title="MEDIGAN Generator", layout="wide")
    st.title("🧠 Medical Image Generator")

    # Sidebar: model choice, sample count and the trigger button.
    with st.sidebar:
        st.header("⚙️ Settings")
        selected_model = st.selectbox("Select Model", MODEL_IDS)
        image_count = st.slider("Number of Images", 1, 8, 4)
        clicked = st.button("✨ Generate Images")

    # Nothing to render in the main area until the button is pressed.
    if not clicked:
        return

    spinner_text = f"Generating {image_count} images..."
    with st.spinner(spinner_text):
        try:
            generate_images(image_count, selected_model)
        except Exception as err:
            st.error(f"Generation failed: {str(err)}")

def generate_images(num_images, model_id):
    """Generate ``num_images`` samples with a medigan model and show them in a grid.

    Args:
        num_images: Number of samples to request from the model.
        model_id: medigan model identifier (one of ``MODEL_IDS``).

    Raises:
        Exception: Errors from ``Generators()`` / ``Generators.generate``
            propagate to the caller. Errors converting or displaying an
            individual image are reported per image with ``st.error``.
    """
    # Apply model-specific patches.
    if model_id.startswith("00019"):
        apply_pggan_patch()

    generators = Generators()

    # One call for all samples. The old loop invoked the whole generate
    # pipeline (install_dependencies check and presumably model setup)
    # once per image.
    # save_images=False asks medigan to return the images in memory instead
    # of only writing them to disk. The old call used the default and then
    # indexed the return value.
    samples = generators.generate(
        model_id=model_id,
        num_samples=num_images,
        save_images=False,
        install_dependencies=True,
    )
    if samples is None:
        st.error("The model did not return any images.")
        return

    # Four-column grid; image i goes to column i % 4.
    cols = st.columns(4)
    for i, sample in enumerate(samples[:num_images]):
        with cols[i % 4]:
            try:
                # NOTE(review): assumes each sample is an array/tensor that
                # to_pil_image accepts (HxW or HxWxC ndarray, CxHxW tensor)
                # — confirm per model.
                img = to_pil_image(sample).convert("RGB")
                st.image(img, caption=f"Image {i+1}", use_column_width=True)
                st.markdown("---")
            except Exception as e:
                st.error(f"Error generating image {i+1}: {str(e)}")

# Script entry point: build the page when this file is executed as __main__.
# NOTE(review): presumably launched via `streamlit run <file>`, which is
# assumed to execute the script as __main__ — confirm.
if __name__ == "__main__":
    main()