attempt to make tag vis work
app.py CHANGED
@@ -12,8 +12,6 @@ from torchvision.transforms import InterpolationMode
 import torchvision.transforms.functional as TF
 from huggingface_hub import hf_hub_download
 
-torch.set_grad_enabled(False)
-
 class Fit(torch.nn.Module):
     def __init__(
         self,
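Dropping the global torch.set_grad_enabled(False) is what lets the backward pass in cam_inference below compute gradients; run_classifier still guards its own forward pass with torch.no_grad(). A minimal sketch of an alternative (an assumption, not the commit's approach) keeps the global switch off and re-enables autograd only around the CAM computation:

import torch

torch.set_grad_enabled(False)          # global default stays grad-free

model = torch.nn.Linear(8, 3)
x = torch.randn(1, 8, requires_grad=True)

with torch.enable_grad():              # re-enable autograd locally
    score = model(x)[0, 1]
    score.backward()                   # graph is recorded inside this block

print(x.grad.shape)                    # torch.Size([1, 8])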
@@ -155,11 +153,14 @@ for idx, tag in enumerate(allowed_tags):
     allowed_tags[idx] = tag.replace("_", " ")
 
 sorted_tag_score = {}
+input_image = None
+
 
 @spaces.GPU(duration=5)
 def run_classifier(image, threshold):
-    global sorted_tag_score
-    img = image.convert('RGBA')
+    global sorted_tag_score, input_image
+    input_image = image.convert('RGBA')
+    img = input_image
     tensor = transform(img).unsqueeze(0)
 
     with torch.no_grad():
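run_classifier now caches the uploaded image in a module-level global so the CAM handler can reuse it later. Module-level globals are shared across all concurrent sessions of a Gradio app; a per-session alternative is gr.State. A self-contained sketch of that pattern (the component names here are illustrative, not the demo's):

import gradio as gr

with gr.Blocks() as demo:
    cached = gr.State(None)                 # per-session image cache
    image_input = gr.Image(type="pil")
    status = gr.Textbox()

    def remember(image, _old):
        # return the new cache value alongside the visible output
        return f"cached: {image is not None}", image

    image_input.change(remember, inputs=[image_input, cached],
                       outputs=[status, cached])

if __name__ == "__main__":
    demo.launch()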
@@ -180,10 +181,124 @@ def create_tags(threshold):
     return text_no_impl, filtered_tag_score
 
 def clear_image():
-    global sorted_tag_score
+    global sorted_tag_score, input_image
+    input_image = None
     sorted_tag_score = {}
     return "", {}
 
+target_tag_index = None
+
+# Store hooks and intermediate values
+gradients = {}
+activations = {}
+
+def hook_forward(module, input, output):
+    activations['value'] = output
+
+def hook_backward(module, grad_in, grad_out):
+    gradients['value'] = grad_out[0]
+
+def cam_inference(target_tag, threshold):
+    global input_image, sorted_tag_score, target_tag_index, gradients, activations
+    img = input_image
+    tensor = transform(img).unsqueeze(0)
+
+    gradients = {}
+    activations = {}
+    cam = None
+    target_tag_index = None
+
+    if target_tag:
+        if target_tag not in allowed_tags:
+            print(f"Warning: Target tag '{target_tag}' not found in allowed tags.")
+            target_tag = None
+        else:
+            target_tag_index = allowed_tags.index(target_tag)
+    handle_forward = model.norm.register_forward_hook(hook_forward)
+    handle_backward = model.norm.register_full_backward_hook(hook_backward)
+
+    probits = model(tensor)[0].cpu()
+
+
+    if target_tag is not None and target_tag_index is not None:
+        model.zero_grad()
+        target_score = probits[target_tag_index]
+        target_score.backward(retain_graph=True)
+
+        grads = gradients.get('value')
+        acts = activations.get('value')
+
+        if grads is not None and acts is not None:
+            patch_grads = grads
+            patch_acts = acts
+
+            weights = torch.mean(patch_grads, dim=1).squeeze(0)
+
+            cam_1d = torch.einsum('pe,e->p', patch_acts.squeeze(0), weights)
+            cam_1d = torch.relu(cam_1d)
+
+            cam = cam_1d.reshape(27, 27).detach().cpu().numpy()
+
+
+    handle_forward.remove()
+    handle_backward.remove()
+    gradients = {}
+    activations = {}
+
+    return create_cam_visualization_pil(cam, vis_threshold=threshold)
+
+def create_cam_visualization_pil(cam, alpha=0.6, vis_threshold=0.2):
+    """
+    Overlays CAM on image and returns a PIL image.
+    The base image is the cached global input_image.
+
+    Args:
+        cam: 2D numpy array (activation map), or None to skip
+        alpha: float, blending factor
+        vis_threshold: float, minimum normalized CAM value to show color
+
+    Returns:
+        PIL.Image.Image with overlay
+    """
+    global input_image
+    if cam is None:
+        print("CAM is None, skipping visualization.")
+        return input_image
+    # Convert to RGB (in case RGBA or others)
+    image_pil = input_image.convert("RGB")
+    w, h = image_pil.size
+
+    # Resize CAM to match image
+    cam_resized = np.array(Image.fromarray(cam).resize((w, h), resample=Image.BILINEAR))
+
+    # Normalize CAM to [0, 1]
+    cam_norm = (cam_resized - cam_resized.min()) / (cam_resized.ptp() + 1e-8)
+
+    # Apply threshold mask
+    mask = cam_norm >= vis_threshold
+
+    # Create heatmap using matplotlib colormap
+    colormap = cm.get_cmap('jet')
+    heatmap_rgba = colormap(cam_norm)  # shape: (H, W, 4), values in [0, 1]
+    heatmap_rgb = (heatmap_rgba[:, :, :3] * 255).astype(np.uint8)
+
+    # Convert heatmap to PIL image
+    heatmap_pil = Image.fromarray(heatmap_rgb).convert("RGB")
+
+    # Convert images to NumPy for blending
+    base_np = np.array(image_pil).astype(np.float32)
+    heat_np = np.array(heatmap_pil).astype(np.float32)
+
+    # Blend only where mask is True
+    blended_np = base_np.copy()
+    blended_np[mask] = base_np[mask] * (1 - alpha) + heat_np[mask] * alpha
+    blended_np = np.clip(blended_np, 0, 255).astype(np.uint8)
+
+    # Convert back to PIL image
+    blended_img = Image.fromarray(blended_np)
+    return blended_img
+
+
 with gr.Blocks(css=".output-class { display: none; }") as demo:
     gr.Markdown("""
     ## Joint Tagger Project: JTP-PILOT² Demo **BETA**
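The new cam_inference is Grad-CAM adapted to ViT patch tokens: hooks on model.norm capture the activations and the gradient of the target tag's score, the gradients are averaged over the token axis to get per-channel weights, and the ReLU of the weighted activation sum gives one saliency value per patch, reshaped to the 27x27 patch grid. A standalone sketch of that arithmetic on dummy tensors (the 729-token and 1024-dim shapes are assumptions consistent with the 27x27 reshape, not values read from the model):

import torch

patches, dim = 27 * 27, 1024                  # assumed token/embedding sizes
acts = torch.randn(1, patches, dim)           # forward-hook activations
grads = torch.randn(1, patches, dim)          # backward-hook gradients

weights = torch.mean(grads, dim=1).squeeze(0)               # (dim,) channel weights
cam_1d = torch.einsum('pe,e->p', acts.squeeze(0), weights)  # weighted sum per patch
cam_1d = torch.relu(cam_1d)

# the einsum is just a dot product along the embedding axis:
expected = torch.relu((acts.squeeze(0) * weights).sum(-1))
assert torch.allclose(cam_1d, expected, atol=1e-3)
print(cam_1d.reshape(27, 27).shape)           # torch.Size([27, 27])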
@@ -219,5 +334,11 @@ with gr.Blocks(css=".output-class { display: none; }") as demo:
         outputs=[tag_string, label_box]
     )
 
+    label_box.select(
+        fn=cam_inference,
+        inputs=[threshold_slider],
+        outputs=[image_input]
+    )
+
 if __name__ == "__main__":
     demo.launch()
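As wired, label_box.select passes only the slider value, while cam_inference expects (target_tag, threshold), so the clicked tag never reaches the function. In Gradio, a .select handler receives the selection through a parameter annotated as gr.SelectData, whose .value is the clicked label. A self-contained sketch of that pattern (the stub function and components here are illustrative, not the demo's actual code):

import gradio as gr

def cam_inference(target_tag, threshold):      # stand-in for the diff's function
    return f"CAM for {target_tag!r} at threshold {threshold}"

def on_label_select(threshold, evt: gr.SelectData):
    # Gradio injects evt for parameters annotated with gr.SelectData;
    # evt.value holds the clicked label's name.
    return cam_inference(evt.value, threshold)

with gr.Blocks() as demo:
    threshold_slider = gr.Slider(0, 1, value=0.3)
    label_box = gr.Label(value={"tag_a": 0.9, "tag_b": 0.5})
    out = gr.Textbox()
    label_box.select(on_label_select, inputs=[threshold_slider], outputs=[out])

if __name__ == "__main__":
    demo.launch()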