drhead committed
Commit fbd5ebe · 1 Parent(s): aa163e9

Add attention visualization

Files changed (1)
app.py +179 -41
app.py CHANGED
@@ -1,17 +1,17 @@
- import json
-
- import gradio as gr
  from PIL import Image
- import safetensors.torch
- import spaces
- import timm
- from timm.models import VisionTransformer
+ import numpy as np
+ import matplotlib.cm as cm
+ import msgspec
  import torch
  from torchvision.transforms import transforms
  from torchvision.transforms import InterpolationMode
  import torchvision.transforms.functional as TF
-
- torch.set_grad_enabled(False)
+ import timm
+ from timm.models import VisionTransformer
+ import safetensors.torch
+ import gradio as gr
+ import spaces
+ from huggingface_hub import hf_hub_download

  class Fit(torch.nn.Module):
      def __init__(
@@ -118,68 +118,206 @@ model = timm.create_model(
      num_classes=9083,
  ) # type: VisionTransformer

- safetensors.torch.load_model(model, "JTP_PILOT-e4-vit_so400m_patch14_siglip_384.safetensors")
+ cached_model = hf_hub_download(
+     repo_id="RedRocket/JointTaggerProject",
+     subfolder="JTP_PILOT",
+     filename="JTP_PILOT-e4-vit_so400m_patch14_siglip_384.safetensors"
+ )
+
+ safetensors.torch.load_model(model, cached_model)
  model.eval()

- with open("tagger_tags.json", "r") as file:
-     tags = json.load(file) # type: dict
- allowed_tags = list(tags.keys())
+ with open("tagger_tags.json", "rb") as file:
+     tags = msgspec.json.decode(file.read(), type=dict[str, int])

- for idx, tag in enumerate(allowed_tags):
-     allowed_tags[idx] = tag.replace("_", " ")
+ for tag in list(tags.keys()):
+     tags[tag.replace("_", " ")] = tags.pop(tag)

- sorted_tag_score = {}
+ allowed_tags = list(tags.keys())

  @spaces.GPU(duration=5)
- def run_classifier(image, threshold):
-     global sorted_tag_score
+ def run_classifier(image: Image.Image, threshold):
      img = image.convert('RGBA')
      tensor = transform(img).unsqueeze(0)

      with torch.no_grad():
-         logits = model(tensor)
-         probits = torch.nn.functional.sigmoid(logits[0])
-         values, indices = probits.topk(250)
+         probits = model(tensor)[0] # type: torch.Tensor
+         values, indices = probits.cpu().topk(250)
+
+     tag_score = {allowed_tags[idx.item()]: val.item() for idx, val in zip(indices, values)}

-     tag_score = dict()
-     for i in range(indices.size(0)):
-         tag_score[allowed_tags[indices[i]]] = values[i].item()
      sorted_tag_score = dict(sorted(tag_score.items(), key=lambda item: item[1], reverse=True))

-     return create_tags(threshold)
+     return *create_tags(threshold, sorted_tag_score), img, sorted_tag_score

- def create_tags(threshold):
-     global sorted_tag_score
+ def create_tags(threshold, sorted_tag_score: dict):
      filtered_tag_score = {key: value for key, value in sorted_tag_score.items() if value > threshold}
      text_no_impl = ", ".join(filtered_tag_score.keys())
      return text_no_impl, filtered_tag_score
-

+ def clear_image():
+     return "", {}, None, {}, None
+
+ @spaces.GPU(duration=5)
+ def cam_inference(img, threshold, alpha, evt: gr.SelectData):
+     target_tag_index = tags[evt.value]
+     tensor = transform(img).unsqueeze(0)
+
+     gradients = {}
+     activations = {}
+
+     def hook_forward(module, input, output):
+         activations['value'] = output
+
+     def hook_backward(module, grad_in, grad_out):
+         gradients['value'] = grad_out[0]
+
+     handle_forward = model.norm.register_forward_hook(hook_forward)
+     handle_backward = model.norm.register_full_backward_hook(hook_backward)
+
+     probits = model(tensor)[0]
+
+     model.zero_grad()
+     probits[target_tag_index].backward(retain_graph=True)
+
+     with torch.no_grad():
+         patch_grads = gradients.get('value')
+         patch_acts = activations.get('value')
+
+         weights = torch.mean(patch_grads, dim=1).squeeze(0)
+
+         cam_1d = torch.einsum('pe,e->p', patch_acts.squeeze(0), weights)
+         cam_1d = torch.relu(cam_1d)
+
+         cam = cam_1d.reshape(27, 27).detach().cpu().numpy()
+
+     handle_forward.remove()
+     handle_backward.remove()
+
+     return create_cam_visualization_pil(img, cam, alpha=alpha, vis_threshold=threshold), cam
+
+ def create_cam_visualization_pil(image_pil, cam, alpha=0.6, vis_threshold=0.2):
+     """
+     Overlays CAM on image and returns a PIL image.
+     Args:
+         image_pil: PIL Image (RGBA)
+         cam: 2D numpy array (activation map)
+         alpha: float, blending factor
+         vis_threshold: float, minimum normalized CAM value to show color
+     Returns:
+         PIL.Image.Image with overlay
+     """
+     if cam is None:
+         return image_pil
+     w, h = image_pil.size
+     size = max(w, h)
+
+     # Normalize CAM to [0, 1]
+     cam -= cam.min()
+     cam /= cam.max()
+
+     # Create heatmap using matplotlib colormap
+     colormap = cm.get_cmap('inferno')
+     cam_rgb = colormap(cam)[:, :, :3] # RGB
+
+     # Create alpha channel
+     cam_alpha = (cam >= vis_threshold).astype(np.float32) * alpha # Alpha mask
+     cam_rgba = np.dstack((cam_rgb, cam_alpha)) # Shape: (H, W, 4)
+
+     # Coarse upscale for CAM output -- keeps "blocky" effect that is truer to what is measured
+     cam_pil = Image.fromarray((cam_rgba * 255).astype(np.uint8), mode="RGBA")
+     cam_pil = cam_pil.resize((216, 216), resample=Image.Resampling.NEAREST)
+
+     # Model uses padded image as input, this matches attention map to input image aspect ratio
+     cam_pil = cam_pil.resize((size, size), resample=Image.Resampling.BICUBIC)
+     cam_pil = transforms.CenterCrop((h, w))(cam_pil)
+
+     # Composite over original
+     composite = Image.alpha_composite(image_pil, cam_pil)
+
+     return composite
+
+ custom_css = """
+ .output-class { display: none; }
+ .inferno-slider input[type=range] {
+     background: linear-gradient(to right,
+         #000004, #1b0c41, #4a0c6b, #781c6d,
+         #a52c60, #cf4446, #ed6925, #fb9b06,
+         #f7d13d, #fcffa4
+     ) !important;
+     background-size: 100% 100% !important;
+ }
+ #image_container-image {
+     width: 100%;
+     aspect-ratio: 1 / 1;
+     max-height: 100%;
+ }
+ #image_container img {
+     object-fit: contain !important;
+ }
+ """
+
- with gr.Blocks(css=".output-class { display: none; }") as demo:
-     gr.Markdown("""
-     ## Joint Tagger Project: PILOT Demo
-     This tagger is designed for use on furry images (though may very well work on out-of-distribution images, potentially with funny results). A threshold of 0.2 is recommended. Lower thresholds often turn up more valid tags, but can also result in some amount of hallucinated tags.
-
-     This tagger is the result of joint efforts between members of the RedRocket team. Special thanks to Minotoro at frosting.ai for providing the compute power for this project.
-     """)
+ with gr.Blocks(css=custom_css) as demo:
+     gr.Markdown("## Joint Tagger Project: JTP-PILOT² Demo **BETA**")
+     original_image_state = gr.State() # stash a copy of the input image
+     sorted_tag_score_state = gr.State(value={}) # stash the most recent tag scores
+     cam_state = gr.State()
      with gr.Row():
          with gr.Column():
-             image_input = gr.Image(label="Source", sources=['upload'], type='pil', height=512, show_label=False)
-             threshold_slider = gr.Slider(minimum=0.00, maximum=1.00, step=0.01, value=0.20, label="Threshold")
+             image = gr.Image(label="Source", sources=['upload', 'clipboard'], type='pil', show_label=False, elem_id="image_container")
+             cam_slider = gr.Slider(minimum=0.00, maximum=1.00, step=0.01, value=0.40, label="CAM Threshold", elem_classes="inferno-slider")
+             alpha_slider = gr.Slider(minimum=0.00, maximum=1.00, step=0.01, value=0.60, label="CAM Alpha")
          with gr.Column():
+             threshold_slider = gr.Slider(minimum=0.00, maximum=1.00, step=0.01, value=0.20, label="Tag Threshold")
              tag_string = gr.Textbox(label="Tag String")
              label_box = gr.Label(label="Tag Predictions", num_top_classes=250, show_label=False)

+     gr.Markdown("""
+     This tagger is designed for use on furry images (though may very well work on out-of-distribution images, potentially with funny results). A threshold of 0.2 is recommended. Lower thresholds often turn up more valid tags, but can also result in some amount of hallucinated tags.
+     This tagger is the result of joint efforts between members of the RedRocket team, with distinctions given to Thessalo for creating the foundation for this project with his efforts, RedHotTensors for redesigning the process into a second-order method that models information expectation, and drhead for dataset prep, creation of training code and supervision of training runs.
+     Thanks to metal63 for providing initial code for attention visualization (click a tag in the tag list to try it out!)
+     Special thanks to Minotoro at frosting.ai for providing the compute power for this project.
+     """)
+
-     image_input.upload(
+     image.upload(
          fn=run_classifier,
-         inputs=[image_input, threshold_slider],
-         outputs=[tag_string, label_box]
+         inputs=[image, threshold_slider],
+         outputs=[tag_string, label_box, original_image_state, sorted_tag_score_state],
+         show_progress='minimal'
+     )
+
+     image.clear(
+         fn=clear_image,
+         inputs=[],
+         outputs=[tag_string, label_box, original_image_state, sorted_tag_score_state, cam_state]
      )

      threshold_slider.input(
          fn=create_tags,
-         inputs=[threshold_slider],
-         outputs=[tag_string, label_box]
+         inputs=[threshold_slider, sorted_tag_score_state],
+         outputs=[tag_string, label_box],
+         show_progress='hidden'
+     )
+
+     label_box.select(
+         fn=cam_inference,
+         inputs=[original_image_state, cam_slider, alpha_slider],
+         outputs=[image, cam_state],
+         show_progress='minimal'
+     )
+
+     cam_slider.input(
+         fn=create_cam_visualization_pil,
+         inputs=[original_image_state, cam_state, alpha_slider, cam_slider],
+         outputs=[image],
+         show_progress='hidden'
+     )
+
+     alpha_slider.input(
+         fn=create_cam_visualization_pil,
+         inputs=[original_image_state, cam_state, alpha_slider, cam_slider],
+         outputs=[image],
+         show_progress='hidden'
      )

  if __name__ == "__main__":
 
 
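How the new visualization works: `cam_inference` is a Grad-CAM-style pass. Forward and backward hooks on the ViT's final `norm` layer capture the patch activations and their gradients, the selected tag's logit is backpropagated, each embedding channel is weighted by its gradient averaged over all patches, and the ReLU of the weighted sum yields one score per patch on the model's 27x27 token grid. This is also why the old module-level `torch.set_grad_enabled(False)` had to go: `backward()` needs the autograd graph even at inference time. Below is a minimal, self-contained sketch of the same recipe on a toy stand-in for the ViT; every name in it is illustrative rather than taken from the app.

import torch

class ToyBackbone(torch.nn.Module):
    """Stand-in for the ViT: a 'norm' layer over (batch, patches, channels) features."""
    def __init__(self, channels=64, classes=10):
        super().__init__()
        self.norm = torch.nn.LayerNorm(channels)   # hook target, mirroring model.norm in app.py
        self.head = torch.nn.Linear(channels, classes)

    def forward(self, feats):                      # feats: (B, P, C)
        return self.head(self.norm(feats).mean(dim=1))

model = ToyBackbone()
feats = torch.randn(1, 729, 64)                    # 729 = 27 * 27 patch tokens

acts, grads = {}, {}
fh = model.norm.register_forward_hook(lambda m, i, o: acts.update(value=o))
bh = model.norm.register_full_backward_hook(lambda m, gi, go: grads.update(value=go[0]))

logits = model(feats)
model.zero_grad()
logits[0, 3].backward()                            # backprop one class ("tag") logit

with torch.no_grad():
    w = grads['value'].mean(dim=1).squeeze(0)      # (C,): per-channel weight, averaged over patches
    cam = torch.einsum('pe,e->p', acts['value'].squeeze(0), w)
    cam = torch.relu(cam).reshape(27, 27)          # one score per patch on the 27x27 grid

fh.remove()
bh.remove()
print(cam.shape)                                   # torch.Size([27, 27])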
 
 
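One geometric subtlety in `create_cam_visualization_pil`: the model sees an image padded to a square, so the heatmap has to be mapped back. The 27x27 map is first upscaled with NEAREST to 216x216 (27 x 8) to keep the blocky patch cells visible, then BICUBIC-resized to size x size, where size = max(w, h) is the padded square, and finally center-cropped to the original h x w so each cell lands over the pixels it scored. A quick shape walk-through with toy numbers (PIL's crop here plays the role of torchvision's CenterCrop):

from PIL import Image
import numpy as np

w, h = 1024, 512                        # hypothetical source image, wider than tall
size = max(w, h)                        # 1024: the padded square the model saw

cam = np.random.rand(27, 27)            # stand-in for a normalized CAM
cam_pil = Image.fromarray((cam * 255).astype(np.uint8))

cam_pil = cam_pil.resize((216, 216), resample=Image.Resampling.NEAREST)    # keep patch blocks
cam_pil = cam_pil.resize((size, size), resample=Image.Resampling.BICUBIC)  # match padded square

left, top = (size - w) // 2, (size - h) // 2       # center-crop back to the source aspect ratio
cam_pil = cam_pil.crop((left, top, left + w, top + h))

print(cam_pil.size)                     # (1024, 512): aligned with the original image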
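A design note on the state handling: the old version kept `sorted_tag_score` in a module-level global, which a Gradio Space shares across every concurrent visitor; the rewrite threads the scores through `gr.State`, which is per-session. That is what lets the threshold slider re-filter cached scores without re-running the model, and without one user's upload clobbering another's results. A stripped-down sketch of the pattern, with hypothetical component and function names:

import gradio as gr

def analyze(text):
    # stand-in for the model: score each word, return text plus the scores for the State
    scores = {word: len(word) / 10 for word in text.split()}
    return ", ".join(scores), scores

def refilter(threshold, scores):
    # re-filter cached scores only; no "model" call needed when the slider moves
    return ", ".join(w for w, s in scores.items() if s > threshold)

with gr.Blocks() as demo:
    scores_state = gr.State(value={})            # per-session, unlike a module-level global
    inp = gr.Textbox(label="Input")
    thresh = gr.Slider(0.0, 1.0, value=0.2, label="Threshold")
    out = gr.Textbox(label="Filtered")
    inp.submit(analyze, inputs=[inp], outputs=[out, scores_state])
    thresh.input(refilter, inputs=[thresh, scores_state], outputs=[out])

if __name__ == "__main__":
    demo.launch()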