File size: 2,699 Bytes
c6d85c0
 
 
 
 
 
 
 
 
 
 
 
 
afbb2d7
c6d85c0
 
 
afbb2d7
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
c6d85c0
 
 
 
 
 
 
afbb2d7
 
c6d85c0
 
 
afbb2d7
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
c6d85c0
 
 
afbb2d7
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
c6d85c0
 
afbb2d7
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
# Import needed libraries
import streamlit as st
import torchvision
from medigan import Generators
from torchvision.transforms.functional import to_pil_image
from torchvision.utils import make_grid

# medigan model identifiers offered in the sidebar's "Select GAN Model" box.
# The chosen entry is forwarded unchanged as ``model_id`` to medigan's
# ``get_as_torch_dataloader``.
model_ids = [
    "00001_DCGAN_MMG_CALC_ROI",
    "00002_DCGAN_MMG_MASS_ROI",
    "00003_CYCLEGAN_MMG_DENSITY_FULL",
    "00004_PIX2PIX_MMG_MASSES_W_MASKS",
    "00019_PGGAN_CHEST_XRAY",
]

def main():
    """Render the page and run generation when the sidebar button is clicked.

    Configures the page, draws the title and usage hint, then builds the
    sidebar with a model selector (choices come from ``model_ids``), an
    image-count input bounded to 1..7 (default 1) and a trigger button.
    When the button is clicked, ``generate_images`` is called inside a
    spinner.
    """
    st.set_page_config(page_title="MEDIGAN Generator", layout="wide")

    # Page heading followed by a short usage hint.
    st.title("🧠 MEDIGAN Medical Image Generator")
    st.markdown("""
    **Generate synthetic medical images using GAN models.**  
    🔍 Select model and parameters in the sidebar → Click **Generate Images**
    """)

    # Sidebar widgets, addressed through the sidebar object.
    sidebar = st.sidebar
    sidebar.header("⚙️ Settings")
    model_id = sidebar.selectbox("Select GAN Model", model_ids)
    num_images = sidebar.number_input(
        "Number of Images", min_value=1, max_value=7, value=1
    )
    clicked = sidebar.button("✨ Generate Images")

    # Nothing to render until the user asks for it.
    if not clicked:
        return

    with st.spinner(f"Generating {num_images} image(s) using {model_id}..."):
        generate_images(num_images, model_id)

def torch_images(num_images, model_id):
    """Sample from a medigan model and return one grid image per batch.

    Args:
        num_images: Forwarded to medigan as ``num_samples``.
        model_id: Forwarded to medigan as ``model_id``.

    Returns:
        list: One PIL image per batch yielded by the dataloader. Each image
        is a grid holding every value of that batch's dictionary, laid out
        two per row (or one per row when the batch has a single value).
    """
    loader = Generators().get_as_torch_dataloader(
        model_id=model_id,
        install_dependencies=True,
        num_samples=num_images,
        num_workers=0,
        prefetch_factor=None,
    )
    # The transform is stateless, so one instance serves every image.
    to_tensor = torchvision.transforms.ToTensor()

    grids = []
    for batch in loader:
        tiles = []
        for value in batch.values():
            if value.dim() == 4:
                # Drop dim 0 and move the last axis to the front;
                # presumably (1, H, W, C) -> (C, H, W) — TODO confirm.
                value = value.squeeze(0).permute(2, 0, 1)
            rgb = to_pil_image(value).convert("RGB")
            tiles.append(to_tensor(rgb))

        per_row = 2 if len(tiles) > 1 else 1
        grids.append(to_pil_image(make_grid(tiles, nrow=per_row)))

    return grids

def generate_images(num_images, model_id):
    """Generate images with the selected model and show them side by side.

    Args:
        num_images: Number of samples to request; forwarded to
            ``torch_images``.
        model_id: Model identifier; forwarded to ``torch_images`` and shown
            in every image caption.

    Returns:
        None. All output is rendered into the Streamlit page.
    """
    # No explicit "clear" step is needed: Streamlit rebuilds the page on each
    # script run, and a bare ``st.empty()`` would only add an unused
    # placeholder rather than remove earlier output.
    images = torch_images(num_images, model_id)

    # ``st.columns`` requires a positive column count, so report an empty
    # result to the user instead of letting the layout call raise.
    if not images:
        st.warning("No images were generated.")
        return

    # One column per generated grid image.
    cols = st.columns(len(images))

    for col, img in zip(cols, images):
        with col:
            st.image(
                img,
                caption=f"Generated by: {model_id}",
                width=300,
            )
            st.markdown("---")
if __name__ == "__main__":
    main()