ahmed792002 commited on
Commit
afbb2d7
·
verified ·
1 Parent(s): e182e13

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +58 -71
app.py CHANGED
@@ -11,92 +11,79 @@ model_ids = [
11
  "00002_DCGAN_MMG_MASS_ROI",
12
  "00003_CYCLEGAN_MMG_DENSITY_FULL",
13
  "00004_PIX2PIX_MMG_MASSES_W_MASKS",
14
- "00019_PGGAN_CHEST_XRAY" # New model
15
  ]
16
 
17
  def main():
18
- st.title("MEDIGAN Medical Image Data Generator")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
19
 
20
- # Add dropdown widget for model selection to the sidebar
21
- model_id = st.sidebar.selectbox("Select Model ID", model_ids)
22
-
23
- # Add number image selector to the sidebar
24
- num_images = st.sidebar.number_input(
25
- "Number of Images", min_value=1, max_value=7, value=1, step=1
26
- )
27
-
28
- # Add generate button to the sidebar
29
- if st.sidebar.button("Generate Images"):
30
- generate_images(num_images, model_id)
31
-
32
- # Task 5.4.9: Copy the torch_images function from this notebook to app.py.
33
  def torch_images(num_images, model_id):
34
  generators = Generators()
35
  dataloader = generators.get_as_torch_dataloader(
36
  model_id=model_id,
37
  install_dependencies=True,
38
  num_samples=num_images,
39
- prefetch_factor=None,
 
40
  )
41
 
42
  images = []
43
- for batch_idx, data_dict in enumerate(dataloader):
44
- image_list = []
45
- for i in data_dict:
46
- if "sample" in i:
47
- sample = data_dict.get("sample")
48
- if sample.dim() == 4:
49
- sample = sample.squeeze(0).permute(2, 0, 1)
50
-
51
- sample = to_pil_image(sample).convert("RGB")
52
- # Convert the image to a PyTorch tensor
53
- transform = torchvision.transforms.Compose(
54
- [
55
- torchvision.transforms.ToTensor(),
56
- ]
57
- )
58
-
59
- # Apply the transform to your PIL image
60
- sample = transform(sample)
61
- image_list.append(sample)
62
-
63
- # Preprocess the mask
64
- if "mask" in i:
65
- mask = data_dict.get("mask")
66
- if mask.dim() == 4:
67
- mask = mask.squeeze(0).permute(2, 0, 1)
68
- mask = to_pil_image(mask).convert("RGB")
69
- mask = transform(mask)
70
- image_list.append(mask)
71
-
72
- # Organize the grid to have 'sample' images per row
73
- Grid = make_grid(image_list, nrow=2)
74
-
75
- # Change Grid tensor to be a consistent shape
76
- # The Grid tensor has shape [1, 128, 128, 1] in some models
77
- if Grid.dim() == 4:
78
- # Remove the singleton batch dimension
79
- Grid = Grid.squeeze(0)
80
- if Grid.size(-1) == 1:
81
- # Remove the singleton channel dimension (assuming grayscale)
82
- Grid = Grid.squeeze(-1)
83
- else:
84
- raise ValueError("Expected a single channel (grayscale) image.")
85
-
86
- # Convert the tensor grid to a PIL Image for display
87
- img = torchvision.transforms.ToPILImage()(Grid)
88
- images.append(img)
89
  return images
90
 
91
  def generate_images(num_images, model_id):
92
- st.subheader("Generated Images:")
93
- images = torch_images(num_images, model_id) # Pass the correct parameters
94
- for i in range(len(images)):
95
- st.image(
96
- images[i],
97
- caption=f"Generated Image {i+1} (Model ID: {model_id})",
98
- use_container_width=True,
99
- )
 
 
 
 
 
 
 
 
 
100
 
101
  if __name__ == "__main__":
102
- main() # Task 5.4.4: Call the main function
 
11
  "00002_DCGAN_MMG_MASS_ROI",
12
  "00003_CYCLEGAN_MMG_DENSITY_FULL",
13
  "00004_PIX2PIX_MMG_MASSES_W_MASKS",
14
+ "00019_PGGAN_CHEST_XRAY"
15
  ]
16
 
17
  def main():
18
+ # Setup page configuration
19
+ st.set_page_config(page_title="MEDIGAN Generator", layout="wide")
20
+
21
+ # Main page title and description
22
+ st.title("🧠 MEDIGAN Medical Image Generator")
23
+ st.markdown("""
24
+ **Generate synthetic medical images using GAN models.**
25
+ 🔍 Select model and parameters in the sidebar → Click **Generate Images**
26
+ """)
27
+
28
+ # Sidebar controls
29
+ with st.sidebar:
30
+ st.header("⚙️ Settings")
31
+ model_id = st.selectbox("Select GAN Model", model_ids)
32
+ num_images = st.number_input("Number of Images", 1, 7, 1)
33
+ generate_btn = st.button("✨ Generate Images")
34
+
35
+ # Main content area
36
+ if generate_btn:
37
+ with st.spinner(f"Generating {num_images} image(s) using {model_id}..."):
38
+ generate_images(num_images, model_id)
39
 
 
 
 
 
 
 
 
 
 
 
 
 
 
40
  def torch_images(num_images, model_id):
41
  generators = Generators()
42
  dataloader = generators.get_as_torch_dataloader(
43
  model_id=model_id,
44
  install_dependencies=True,
45
  num_samples=num_images,
46
+ num_workers=0,
47
+ prefetch_factor=None
48
  )
49
 
50
  images = []
51
+ for _, data_dict in enumerate(dataloader):
52
+ batch_images = []
53
+ for tensor in data_dict.values():
54
+ if tensor.dim() == 4:
55
+ tensor = tensor.squeeze(0).permute(2, 0, 1)
56
+ img = to_pil_image(tensor).convert("RGB")
57
+ batch_images.append(img)
58
+
59
+ # Create image grid
60
+ grid_tensor = make_grid(
61
+ [torchvision.transforms.ToTensor()(img) for img in batch_images],
62
+ nrow=2 if len(batch_images) > 1 else 1
63
+ )
64
+ grid_img = to_pil_image(grid_tensor)
65
+ images.append(grid_img)
66
+
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
67
  return images
68
 
69
  def generate_images(num_images, model_id):
70
+ # Clear previous results
71
+ st.empty()
72
+
73
+ # Generate and display new images
74
+ images = torch_images(num_images, model_id)
75
+
76
+ # Create columns for responsive layout
77
+ cols = st.columns(len(images))
78
+
79
+ for col, img in zip(cols, images):
80
+ with col:
81
+ st.image(
82
+ img,
83
+ caption=f"Generated by: {model_id}",
84
+ width=300
85
+ )
86
+ st.markdown("---")
87
 
88
  if __name__ == "__main__":
89
+ main()