File size: 2,787 Bytes
c6d85c0
078cd45
ef30ad0
c6d85c0
 
 
ef30ad0
078cd45
ef30ad0
 
 
 
078cd45
ef30ad0
 
078cd45
ef30ad0
 
 
 
 
078cd45
ef30ad0
 
 
 
078cd45
 
ef30ad0
 
 
 
 
078cd45
ef30ad0
 
078cd45
 
c6d85c0
 
 
 
afbb2d7
c6d85c0
 
 
afbb2d7
078cd45
afbb2d7
 
 
078cd45
 
ef30ad0
 
 
 
 
 
078cd45
c6d85c0
 
ef30ad0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
078cd45
ef30ad0
 
 
c6d85c0
 
afbb2d7
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
import streamlit as st
import torch
import os
from medigan import Generators
from torchvision.transforms.functional import to_pil_image

def install_and_patch_model(model_id):
    """Install a medigan model when it is not listed, then patch its ``base_GAN.py``.

    The patch rewrites the text ``betas=(self.beta1, self.beta2)`` to
    ``betas=(0.5, 0.999)`` in ``model19/base_GAN.py`` located under
    ``~/.medigan/models/<model_id>``. When that file does not exist nothing
    is patched and the function still returns True.

    Args:
        model_id: Identifier of the medigan model to install and patch.

    Returns:
        True if no exception was raised, otherwise False (the error is
        reported to the user through ``st.error``).
    """
    try:
        # Download the model only when medigan does not already list it.
        generators = Generators()
        if model_id not in generators.get_model_ids():
            generators.download_and_install_model(model_id)

        base_gan_path = os.path.join(
            os.path.expanduser("~"),
            ".medigan",
            "models",
            model_id,
            "model19",
            "base_GAN.py",
        )

        if os.path.exists(base_gan_path):
            # Python source files are UTF-8 by default; be explicit so the
            # read/write round trip does not depend on the platform locale.
            with open(base_gan_path, "r", encoding="utf-8") as f:
                content = f.read()

            # Force fixed beta values in place of the instance attributes.
            patched = content.replace(
                "betas=(self.beta1, self.beta2)",
                "betas=(0.5, 0.999)",
            )

            # Skip the write when there is nothing to change (e.g. the file
            # was already patched by an earlier run).
            if patched != content:
                with open(base_gan_path, "w", encoding="utf-8") as f:
                    f.write(patched)

            st.success("Model successfully patched!")
        return True
    except Exception as e:
        # UI boundary: surface any failure to the user instead of crashing.
        st.error(f"Patching failed: {str(e)}")
        return False

# Model identifiers offered in the "Select Model" dropdown, in display order.
MODEL_IDS = [
    "00001_DCGAN_MMG_CALC_ROI",
    "00002_DCGAN_MMG_MASS_ROI",
    "00003_CYCLEGAN_MMG_DENSITY_FULL",
    "00004_PIX2PIX_MMG_MASSES_W_MASKS",
    "00019_PGGAN_CHEST_XRAY",
]

def main():
    """Build the page: settings in the sidebar, generated images in the main area.

    The sidebar holds the model selector, the image-count slider and the
    generate button. When the button is pressed, models whose id contains
    ``"00019"`` are first installed and patched; if that step fails the
    function returns without generating anything.
    """
    st.set_page_config(page_title="MEDIGAN Generator", layout="wide")
    st.title("🧠 Medical Image Generator")

    with st.sidebar:
        st.header("⚙️ Settings")
        model_id = st.selectbox("Select Model", MODEL_IDS)
        num_images = st.slider("Number of Images", 1, 8, 4)
        generate_clicked = st.button("✨ Generate Images")

    # Everything below runs outside the ``with st.sidebar`` block so that the
    # spinner, status messages and the image grid are rendered in the wide
    # main area instead of being squeezed into the sidebar.
    if not generate_clicked:
        return

    if "00019" in model_id and not install_and_patch_model(model_id):
        return

    with st.spinner("Generating images..."):
        generate_images(num_images, model_id)

def generate_images(num_images, model_id):
    """Generate samples with a medigan model and show them in a 4-column grid.

    Each sample is requested individually, converted to an RGB PIL image and
    placed into one of four columns in round-robin order. Any exception is
    caught and reported through ``st.error``.

    Args:
        num_images: Number of samples to generate and display.
        model_id: Identifier of the medigan model to sample from.
    """
    try:
        generators = Generators()
        images = []

        for _ in range(num_images):
            sample = generators.generate(
                model_id=model_id,
                num_samples=1,
                install_dependencies=False,
                # Request the samples in memory; with medigan's default
                # (save_images=True) the images are written to disk instead
                # of being returned, so ``sample[0]`` below would fail.
                save_images=False,
            )
            images.append(to_pil_image(sample[0]).convert("RGB"))

        cols = st.columns(4)
        for idx, img in enumerate(images):
            # Wrap around so more than four images fill additional rows.
            with cols[idx % 4]:
                st.image(img, caption=f"Image {idx+1}", use_column_width=True)
                st.markdown("---")

    except Exception as e:
        # UI boundary: show the failure to the user instead of crashing.
        st.error(f"Generation failed: {str(e)}")

# Build the page only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()