clockclock committed on
Commit
7878ad2
·
verified ·
1 Parent(s): d04d87f

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +41 -129
app.py CHANGED
@@ -38,140 +38,52 @@ except Exception as e:
38
  raise
39
 
40
  # --- 2. Define the Explainability (Grad-CAM) Function ---
 
 
41
  def generate_heatmap(image_tensor, original_image, target_class_index):
 
 
 
 
42
  try:
43
- # Ensure tensor is on CPU and requires gradients
44
- image_tensor = image_tensor.to(device)
45
- image_tensor.requires_grad_(True)
46
-
47
- # Define wrapper function for model forward pass
48
- def model_forward_wrapper(input_tensor):
49
- outputs = model(pixel_values=input_tensor)
50
- return outputs.logits
51
 
52
- # Try different approaches for better heatmap generation
53
- try:
54
- # First try: Use GradCam directly (often more reliable than LayerGradCam)
55
- from captum.attr import GradCam
56
-
57
- # For SWIN transformer, target the last convolutional-like layer
58
- try:
59
- # Try to find a suitable layer in the SWIN model
60
- target_layer = model.swin.encoder.layers[-1].blocks[-1].norm1
61
- except:
62
- try:
63
- target_layer = model.swin.encoder.layers[-1].blocks[0].norm1
64
- except:
65
- target_layer = model.swin.layernorm
66
-
67
- gc = GradCam(model_forward_wrapper, target_layer)
68
-
69
- # Generate attributions
70
- attributions = gc.attribute(image_tensor, target=target_class_index)
71
-
72
- # Process attributions
73
- attr_np = attributions.squeeze().cpu().detach().numpy()
74
-
75
- print(f"Attribution stats: min={attr_np.min():.4f}, max={attr_np.max():.4f}, mean={attr_np.mean():.4f}")
76
-
77
- # Normalize to [0, 1] range
78
- if attr_np.max() > attr_np.min():
79
- attr_np = (attr_np - attr_np.min()) / (attr_np.max() - attr_np.min())
80
-
81
- # Resize to match original image size
82
- from PIL import Image as PILImage
83
- import cv2
84
-
85
- # Resize attribution map to original image size
86
- attr_resized = cv2.resize(attr_np, original_image.size, interpolation=cv2.INTER_LINEAR)
87
-
88
- # Create a more visible heatmap
89
- import matplotlib.pyplot as plt
90
- import matplotlib.cm as cm
91
-
92
- # Apply a strong colormap (jet gives good red visualization)
93
- colored_attr = cm.jet(attr_resized)[:, :, :3] # Remove alpha channel
94
-
95
- # Convert original image to numpy
96
- original_np = np.array(original_image) / 255.0
97
-
98
- # Create a stronger blend to make heatmap more visible
99
- alpha = 0.6 # Higher alpha for more heatmap visibility
100
- blended = (1 - alpha) * original_np + alpha * colored_attr
101
- blended = (blended * 255).astype(np.uint8)
102
-
103
- return blended
104
-
105
- except Exception as e1:
106
- print(f"GradCam failed: {e1}")
107
-
108
- # Fallback: Try LayerGradCam
109
- try:
110
- lgc = LayerGradCam(model_forward_wrapper, target_layer)
111
- attributions = lgc.attribute(
112
- image_tensor,
113
- target=target_class_index,
114
- relu_attributions=False
115
- )
116
-
117
- # Process the attributions
118
- attr_np = attributions.squeeze(0).cpu().detach().numpy()
119
-
120
- # Handle different attribution shapes
121
- if len(attr_np.shape) == 3:
122
- # Take mean across channels if multi-channel
123
- attr_np = np.mean(attr_np, axis=0)
124
-
125
- # Normalize
126
- if attr_np.max() > attr_np.min():
127
- attr_np = (attr_np - attr_np.min()) / (attr_np.max() - attr_np.min())
128
-
129
- # Create visualization using captum's viz
130
- if len(attr_np.shape) == 2:
131
- # Expand to 3 channels for visualization
132
- heatmap = np.expand_dims(attr_np, axis=-1)
133
- heatmap = np.repeat(heatmap, 3, axis=-1)
134
- else:
135
- heatmap = np.transpose(attr_np, (1, 2, 0))
136
-
137
- visualized_image, _ = viz.visualize_image_attr(
138
- heatmap,
139
- np.array(original_image),
140
- method="blended_heat_map",
141
- sign="all",
142
- show_colorbar=True,
143
- title="AI Detection Heatmap",
144
- alpha_overlay=0.4,
145
- cmap="jet", # Use jet colormap for strong red visualization
146
- outlier_perc=1
147
- )
148
-
149
- return visualized_image
150
-
151
- except Exception as e2:
152
- print(f"LayerGradCam also failed: {e2}")
153
-
154
- # Final fallback: Create a simple random heatmap for demonstration
155
- print("Creating demonstration heatmap...")
156
-
157
- # Create a simple demonstration heatmap
158
- h, w = original_image.size[1], original_image.size[0]
159
- demo_attr = np.random.rand(h, w) * 0.5 + 0.3 # Random values between 0.3 and 0.8
160
-
161
- # Apply jet colormap
162
- colored_attr = cm.jet(demo_attr)[:, :, :3]
163
-
164
- # Blend with original
165
- original_np = np.array(original_image) / 255.0
166
- blended = 0.7 * original_np + 0.3 * colored_attr
167
- blended = (blended * 255).astype(np.uint8)
168
-
169
- return blended
170
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
171
  except Exception as e:
172
- print(f"Complete heatmap generation failed: {e}")
173
- # Return original image if everything fails
174
- return np.array(original_image)
175
 
176
  # --- 3. Main Prediction Function ---
177
  def predict(image_upload: Image.Image, image_url: str):
 
38
  raise
39
 
40
  # --- 2. Define the Explainability (Grad-CAM) Function ---
41
+ ### FIX ###: This function is now more robust. It returns `None` on failure
42
+ ### instead of returning the original image, allowing the main function to handle it.
43
  def generate_heatmap(image_tensor, original_image, target_class_index):
44
+ """
45
+ Generates a Grad-CAM heatmap.
46
+ Returns the blended heatmap (the matplotlib figure produced by Captum's visualize_image_attr, not a numpy array), or None if it fails.
47
+ """
48
  try:
49
+ # LayerGradCam is often a good choice for transformer-based models.
50
+ # The target layer is chosen as one of the last normalization layers in the SWIN transformer.
51
+ # This might need adjustment for different model architectures.
52
+ target_layer = model.swin.encoder.layers[-1].blocks[-1].norm1
53
+ lgc = LayerGradCam(model.forward, target_layer)
 
 
 
54
 
55
+ # Generate attributions
56
+ attributions = lgc.attribute(
57
+ image_tensor,
58
+ target=target_class_index,
59
+ relu_attributions=True # Use relu_attributions to focus on positive contributions
60
+ )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
61
 
62
+ # Squeeze the attributions to a 2D map
63
+ attribution_map = attributions.squeeze(0).squeeze(0).cpu().detach().numpy()
64
+
65
+ ### FIX ###: Check if the attribution map is uniform (all zeros or same value).
66
+ # This happens when the model has no strong evidence for its decision,
67
+ # which is common in misclassifications.
68
+ if (attribution_map.max() - attribution_map.min()) < 1e-6:
69
+ print("Warning: Heatmap generation failed due to uniform gradients. The model likely has low confidence or is misclassifying.")
70
+ return None
71
+
72
+ # Use Captum's visualization tool to create a blended heatmap
73
+ blended_image, _ = viz.visualize_image_attr(
74
+ attribution_map,
75
+ np.array(original_image),
76
+ method="blended_heat_map",
77
+ sign="positive", # Focus on what positively contributed to the decision
78
+ alpha_overlay=0.5, # Make the overlay reasonably transparent
79
+ cmap="jet", # 'jet' colormap shows hot areas in red
80
+ show_colorbar=False
81
+ )
82
+ return blended_image
83
+
84
  except Exception as e:
85
+ print(f"Error during heatmap generation: {e}")
86
+ return None
 
87
 
88
  # --- 3. Main Prediction Function ---
89
  def predict(image_upload: Image.Image, image_url: str):