drhead commited on
Commit
0d0b3f9
·
verified ·
1 Parent(s): 2d4900f

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +8 -7
app.py CHANGED
@@ -244,13 +244,8 @@ def create_cam_visualization_pil(image_pil, cam, alpha=0.6, vis_threshold=0.2):
244
  w, h = image_pil.size
245
  size = max(w, h)
246
 
247
- # Resize CAM to match image
248
- cam_resized = np.array(Image.fromarray(cam).resize((size, size), resample=Image.Resampling.BILINEAR))
249
-
250
  # Normalize CAM to [0, 1]
251
- cam_norm = (cam_resized - cam_resized.min()) / (np.ptp(cam_resized) + 1e-8)
252
-
253
- cam_norm = transforms.CenterCrop((h, w))(cam_norm)
254
 
255
  # Create heatmap using matplotlib colormap
256
  colormap = cm.get_cmap('jet')
@@ -258,7 +253,13 @@ def create_cam_visualization_pil(image_pil, cam, alpha=0.6, vis_threshold=0.2):
258
  cam_alpha = (cam_norm >= vis_threshold).astype(np.float32) * alpha # Alpha mask
259
 
260
  cam_rgba = np.dstack((cam_colored, cam_alpha)) # Shape: (H, W, 4)
261
- cam_image = Image.fromarray((cam_rgba * 255).astype(np.uint8), mode="RGBA")
 
 
 
 
 
 
262
 
263
  # Composite over original
264
  composite = Image.alpha_composite(image_pil, cam_image)
 
244
  w, h = image_pil.size
245
  size = max(w, h)
246
 
 
 
 
247
  # Normalize CAM to [0, 1]
248
+ cam_norm = (cam - cam.min()) / (np.ptp(cam) + 1e-8)
 
 
249
 
250
  # Create heatmap using matplotlib colormap
251
  colormap = cm.get_cmap('jet')
 
253
  cam_alpha = (cam_norm >= vis_threshold).astype(np.float32) * alpha # Alpha mask
254
 
255
  cam_rgba = np.dstack((cam_colored, cam_alpha)) # Shape: (H, W, 4)
256
+
257
+ # Resize CAM to match image
258
+ cam_resized = np.array(Image.fromarray(cam_rgba).resize((size, size), resample=Image.Resampling.BILINEAR))
259
+
260
+ cam_resized = transforms.CenterCrop((h, w))(torch.tensor(cam_resized)).numpy()
261
+
262
+ cam_image = Image.fromarray((cam_resized * 255).astype(np.uint8), mode="RGBA")
263
 
264
  # Composite over original
265
  composite = Image.alpha_composite(image_pil, cam_image)