ahmed792002 commited on
Commit
c6d85c0
·
verified ·
1 Parent(s): ba6e10f

Upload app.py

Browse files
Files changed (1) hide show
  1. app.py +102 -0
app.py ADDED
@@ -0,0 +1,102 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Import needed libraries
2
+ import streamlit as st
3
+ import torchvision
4
+ from medigan import Generators
5
+ from torchvision.transforms.functional import to_pil_image
6
+ from torchvision.utils import make_grid
7
+
8
+ # Define the GAN models available in the app
9
+ model_ids = [
10
+ "00001_DCGAN_MMG_CALC_ROI",
11
+ "00002_DCGAN_MMG_MASS_ROI",
12
+ "00003_CYCLEGAN_MMG_DENSITY_FULL",
13
+ "00004_PIX2PIX_MMG_MASSES_W_MASKS",
14
+ "00019_PGGAN_CHEST_XRAY" # New model
15
+ ]
16
+
17
+ def main():
18
+ st.title("MEDIGAN Medical Image Data Generator")
19
+
20
+ # Add dropdown widget for model selection to the sidebar
21
+ model_id = st.sidebar.selectbox("Select Model ID", model_ids)
22
+
23
+ # Add number image selector to the sidebar
24
+ num_images = st.sidebar.number_input(
25
+ "Number of Images", min_value=1, max_value=7, value=1, step=1
26
+ )
27
+
28
+ # Add generate button to the sidebar
29
+ if st.sidebar.button("Generate Images"):
30
+ generate_images(num_images, model_id)
31
+
32
+ # Task 5.4.9: Copy the torch_images function from this notebook to app.py.
33
+ def torch_images(num_images, model_id):
34
+ generators = Generators()
35
+ dataloader = generators.get_as_torch_dataloader(
36
+ model_id=model_id,
37
+ install_dependencies=True,
38
+ num_samples=num_images,
39
+ prefetch_factor=None,
40
+ )
41
+
42
+ images = []
43
+ for batch_idx, data_dict in enumerate(dataloader):
44
+ image_list = []
45
+ for i in data_dict:
46
+ if "sample" in i:
47
+ sample = data_dict.get("sample")
48
+ if sample.dim() == 4:
49
+ sample = sample.squeeze(0).permute(2, 0, 1)
50
+
51
+ sample = to_pil_image(sample).convert("RGB")
52
+ # Convert the image to a PyTorch tensor
53
+ transform = torchvision.transforms.Compose(
54
+ [
55
+ torchvision.transforms.ToTensor(),
56
+ ]
57
+ )
58
+
59
+ # Apply the transform to your PIL image
60
+ sample = transform(sample)
61
+ image_list.append(sample)
62
+
63
+ # Preprocess the mask
64
+ if "mask" in i:
65
+ mask = data_dict.get("mask")
66
+ if mask.dim() == 4:
67
+ mask = mask.squeeze(0).permute(2, 0, 1)
68
+ mask = to_pil_image(mask).convert("RGB")
69
+ mask = transform(mask)
70
+ image_list.append(mask)
71
+
72
+ # Organize the grid to have 'sample' images per row
73
+ Grid = make_grid(image_list, nrow=2)
74
+
75
+ # Change Grid tensor to be a consistent shape
76
+ # The Grid tensor has shape [1, 128, 128, 1] in some models
77
+ if Grid.dim() == 4:
78
+ # Remove the singleton batch dimension
79
+ Grid = Grid.squeeze(0)
80
+ if Grid.size(-1) == 1:
81
+ # Remove the singleton channel dimension (assuming grayscale)
82
+ Grid = Grid.squeeze(-1)
83
+ else:
84
+ raise ValueError("Expected a single channel (grayscale) image.")
85
+
86
+ # Convert the tensor grid to a PIL Image for display
87
+ img = torchvision.transforms.ToPILImage()(Grid)
88
+ images.append(img)
89
+ return images
90
+
91
+ def generate_images(num_images, model_id):
92
+ st.subheader("Generated Images:")
93
+ images = torch_images(num_images, model_id) # Pass the correct parameters
94
+ for i in range(len(images)):
95
+ st.image(
96
+ images[i],
97
+ caption=f"Generated Image {i+1} (Model ID: {model_id})",
98
+ use_container_width=True,
99
+ )
100
+
101
+ if __name__ == "__main__":
102
+ main() # Task 5.4.4: Call the main function