clockclock commited on
Commit
2cee47c
·
verified ·
1 Parent(s): c58bef4

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +81 -21
app.py CHANGED
@@ -45,48 +45,108 @@ def generate_heatmap(image_tensor, original_image, target_class_index):
45
 
46
  # Define wrapper function for model forward pass
47
  def model_forward_wrapper(input_tensor):
48
- with torch.no_grad(): # Save memory during attribution
49
- outputs = model(pixel_values=input_tensor)
50
- return outputs.logits
51
 
52
  # Get the target layer for Grad-CAM
53
- # For SWIN transformer, use the layer normalization layer
54
- target_layer = model.swin.layernorm
 
 
 
 
 
 
 
 
 
55
 
56
  # Initialize LayerGradCam with the wrapper function
57
  lgc = LayerGradCam(model_forward_wrapper, target_layer)
58
 
59
- # Generate attributions
60
- with torch.no_grad():
61
- attributions = lgc.attribute(
62
- image_tensor,
63
- target=target_class_index,
64
- relu_attributions=True
65
- )
66
 
67
  # Convert attributions to numpy for visualization
68
- heatmap = np.transpose(
69
- attributions.squeeze(0).cpu().detach().numpy(),
70
- (1, 2, 0)
71
- )
 
 
 
 
 
 
 
 
 
 
 
72
 
73
- # Create visualization
74
  visualized_image, _ = viz.visualize_image_attr(
75
  heatmap,
76
  np.array(original_image),
77
  method="blended_heat_map",
78
- sign="all",
79
  show_colorbar=True,
80
  title="AI Detection Heatmap",
81
- alpha_overlay=0.6
 
 
82
  )
83
 
84
  return visualized_image
85
 
86
  except Exception as e:
87
  print(f"Error generating heatmap: {e}")
88
- # Return original image if heatmap generation fails
89
- return np.array(original_image)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
90
 
91
  # --- 3. Main Prediction Function ---
92
  def predict(image_upload: Image.Image, image_url: str):
 
45
 
46
  # Define wrapper function for model forward pass
47
  def model_forward_wrapper(input_tensor):
48
+ outputs = model(pixel_values=input_tensor)
49
+ return outputs.logits
 
50
 
51
  # Get the target layer for Grad-CAM
52
+ # For SWIN transformer, try different layers for better visualization
53
+ try:
54
+ # Try the encoder's last layer first
55
+ target_layer = model.swin.encoder.layers[-1].blocks[-1].layernorm_after
56
+ except:
57
+ try:
58
+ # Fallback to the main layernorm
59
+ target_layer = model.swin.layernorm
60
+ except:
61
+ # Final fallback to pooler if available
62
+ target_layer = model.swin.pooler.layernorm if hasattr(model.swin, 'pooler') else model.swin.layernorm
63
 
64
  # Initialize LayerGradCam with the wrapper function
65
  lgc = LayerGradCam(model_forward_wrapper, target_layer)
66
 
67
+ # Generate attributions - remove torch.no_grad() to allow gradients
68
+ attributions = lgc.attribute(
69
+ image_tensor,
70
+ target=target_class_index,
71
+ relu_attributions=False # Changed to False to see both positive and negative attributions
72
+ )
 
73
 
74
  # Convert attributions to numpy for visualization
75
+ attr_np = attributions.squeeze(0).cpu().detach().numpy()
76
+
77
+ # Normalize attributions to [0, 1] range for better visualization
78
+ attr_min = attr_np.min()
79
+ attr_max = attr_np.max()
80
+ if attr_max > attr_min:
81
+ attr_np = (attr_np - attr_min) / (attr_max - attr_min)
82
+
83
+ # Transpose for visualization (channels last)
84
+ if len(attr_np.shape) == 3:
85
+ heatmap = np.transpose(attr_np, (1, 2, 0))
86
+ else:
87
+ # If single channel, expand to 3 channels
88
+ heatmap = np.expand_dims(attr_np, axis=-1)
89
+ heatmap = np.repeat(heatmap, 3, axis=-1)
90
 
91
+ # Create visualization with enhanced parameters
92
  visualized_image, _ = viz.visualize_image_attr(
93
  heatmap,
94
  np.array(original_image),
95
  method="blended_heat_map",
96
+ sign="all", # Show both positive and negative attributions
97
  show_colorbar=True,
98
  title="AI Detection Heatmap",
99
+ alpha_overlay=0.5, # Reduced alpha for better visibility
100
+ cmap="RdYlBu_r", # Red-Yellow-Blue colormap (reversed)
101
+ outlier_perc=2 # Remove outliers for better contrast
102
  )
103
 
104
  return visualized_image
105
 
106
  except Exception as e:
107
  print(f"Error generating heatmap: {e}")
108
+ print(f"Attribution shape: {attributions.shape if 'attributions' in locals() else 'Not generated'}")
109
+
110
+ # Create a simple fallback heatmap using GradCAM on a different layer
111
+ try:
112
+ from captum.attr import GradCam
113
+
114
+ # Use GradCAM instead of LayerGradCAM as fallback
115
+ gc = GradCam(model_forward_wrapper, target_layer)
116
+ attributions = gc.attribute(image_tensor, target=target_class_index)
117
+
118
+ # Process the attributions
119
+ attr_np = attributions.squeeze().cpu().detach().numpy()
120
+
121
+ # Normalize
122
+ attr_min = attr_np.min()
123
+ attr_max = attr_np.max()
124
+ if attr_max > attr_min:
125
+ attr_np = (attr_np - attr_min) / (attr_max - attr_min)
126
+
127
+ # Create a simple overlay
128
+ import matplotlib.pyplot as plt
129
+ import matplotlib.cm as cm
130
+
131
+ # Resize attribution to match image size
132
+ from PIL import Image as PILImage
133
+ attr_resized = PILImage.fromarray((attr_np * 255).astype(np.uint8)).resize(original_image.size)
134
+ attr_resized = np.array(attr_resized) / 255.0
135
+
136
+ # Apply colormap
137
+ colored_attr = cm.jet(attr_resized)[:, :, :3] # Remove alpha channel
138
+
139
+ # Blend with original image
140
+ original_np = np.array(original_image) / 255.0
141
+ blended = 0.6 * original_np + 0.4 * colored_attr
142
+ blended = (blended * 255).astype(np.uint8)
143
+
144
+ return blended
145
+
146
+ except Exception as e2:
147
+ print(f"Fallback heatmap also failed: {e2}")
148
+ # Return original image if all heatmap generation fails
149
+ return np.array(original_image)
150
 
151
  # --- 3. Main Prediction Function ---
152
  def predict(image_upload: Image.Image, image_url: str):