hiyata commited on
Commit
2e254a9
·
verified ·
1 Parent(s): d153967

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +34 -17
app.py CHANGED
@@ -207,9 +207,6 @@ def plot_linear_heatmap(shap_means, title="Per-base SHAP Heatmap", start=None, e
207
  - Negative = blue
208
  - 0 = white
209
  - Positive = red
210
- We'll force the range to be symmetrical around 0 by using:
211
- vmin=-extent, vmax=+extent
212
- so 0 is in the middle.
213
  """
214
  if start is not None and end is not None:
215
  local_shap = shap_means[start:end]
@@ -217,11 +214,10 @@ def plot_linear_heatmap(shap_means, title="Per-base SHAP Heatmap", start=None, e
217
  else:
218
  local_shap = shap_means
219
  subtitle = ""
220
-
221
  if len(local_shap) == 0:
222
- # Edge case: no data to plot
223
  local_shap = np.array([0.0])
224
-
225
  # Build 2D array for imshow
226
  heatmap_data = local_shap.reshape(1, -1)
227
 
@@ -233,25 +229,46 @@ def plot_linear_heatmap(shap_means, title="Per-base SHAP Heatmap", start=None, e
233
  # Create custom colormap
234
  custom_cmap = get_zero_centered_cmap()
235
 
236
- fig, ax = plt.subplots(figsize=(12, 2))
 
 
 
237
  cax = ax.imshow(
238
- heatmap_data,
239
- aspect='auto',
240
  cmap=custom_cmap,
241
  vmin=-extent,
242
  vmax=+extent
243
  )
244
 
245
- # Place colorbar below with plenty of margin
246
- cbar = plt.colorbar(cax, orientation='horizontal', pad=0.35)
247
- cbar.set_label('SHAP Contribution (negative=blue, zero=white, positive=red)')
248
-
 
 
 
 
 
 
 
 
 
 
 
 
 
 
249
  ax.set_yticks([])
250
- ax.set_xlabel('Position in Sequence')
251
- ax.set_title(f"{title}{subtitle}")
252
 
253
- # Extra bottom margin so colorbar won't overlap x-axis labels
254
- plt.subplots_adjust(bottom=0.4)
 
 
 
 
255
 
256
  return fig
257
 
 
207
  - Negative = blue
208
  - 0 = white
209
  - Positive = red
 
 
 
210
  """
211
  if start is not None and end is not None:
212
  local_shap = shap_means[start:end]
 
214
  else:
215
  local_shap = shap_means
216
  subtitle = ""
217
+
218
  if len(local_shap) == 0:
 
219
  local_shap = np.array([0.0])
220
+
221
  # Build 2D array for imshow
222
  heatmap_data = local_shap.reshape(1, -1)
223
 
 
229
  # Create custom colormap
230
  custom_cmap = get_zero_centered_cmap()
231
 
232
+ # Create figure with adjusted height ratio
233
+ fig, ax = plt.subplots(figsize=(12, 1.8)) # Reduced height
234
+
235
+ # Plot heatmap
236
  cax = ax.imshow(
237
+ heatmap_data,
238
+ aspect='auto',
239
  cmap=custom_cmap,
240
  vmin=-extent,
241
  vmax=+extent
242
  )
243
 
244
+ # Configure colorbar with more subtle positioning
245
+ cbar = plt.colorbar(
246
+ cax,
247
+ orientation='horizontal',
248
+ pad=0.25, # Reduced padding
249
+ aspect=40, # Make colorbar thinner
250
+ shrink=0.8 # Make colorbar shorter than plot width
251
+ )
252
+
253
+ # Style the colorbar
254
+ cbar.ax.tick_params(labelsize=8) # Smaller tick labels
255
+ cbar.set_label(
256
+ 'SHAP Contribution',
257
+ fontsize=9,
258
+ labelpad=5
259
+ )
260
+
261
+ # Configure main plot
262
  ax.set_yticks([])
263
+ ax.set_xlabel('Position in Sequence', fontsize=10)
264
+ ax.set_title(f"{title}{subtitle}", pad=10)
265
 
266
+ # Fine-tune layout
267
+ plt.subplots_adjust(
268
+ bottom=0.25, # Reduced bottom margin
269
+ left=0.05, # Tighter left margin
270
+ right=0.95 # Tighter right margin
271
+ )
272
 
273
  return fig
274