Update app.py
Browse files
app.py
CHANGED
@@ -154,8 +154,6 @@ allowed_tags = list(tags.keys())
|
|
154 |
for idx, tag in enumerate(allowed_tags):
|
155 |
allowed_tags[idx] = tag.replace("_", " ")
|
156 |
|
157 |
-
|
158 |
-
|
159 |
@spaces.GPU(duration=5)
|
160 |
def run_classifier(image: Image.Image, threshold):
|
161 |
img = image.convert('RGBA')
|
@@ -186,9 +184,6 @@ def cam_inference(img, threshold, alpha, evt: gr.SelectData):
|
|
186 |
|
187 |
gradients = {}
|
188 |
activations = {}
|
189 |
-
cam = None
|
190 |
-
target_tag_index = None
|
191 |
-
|
192 |
|
193 |
def hook_forward(module, input, output):
|
194 |
activations['value'] = output
|
@@ -200,29 +195,24 @@ def cam_inference(img, threshold, alpha, evt: gr.SelectData):
|
|
200 |
handle_forward = model.norm.register_forward_hook(hook_forward)
|
201 |
handle_backward = model.norm.register_full_backward_hook(hook_backward)
|
202 |
|
203 |
-
probits = model(tensor)[0]
|
204 |
|
205 |
model.zero_grad()
|
206 |
-
|
207 |
-
target_score.backward(retain_graph=True)
|
208 |
-
|
209 |
-
grads = gradients.get('value')
|
210 |
-
acts = activations.get('value')
|
211 |
|
212 |
-
|
213 |
-
|
214 |
-
|
215 |
-
|
216 |
-
|
217 |
-
|
218 |
-
|
219 |
-
|
220 |
-
|
|
|
221 |
|
222 |
handle_forward.remove()
|
223 |
handle_backward.remove()
|
224 |
-
gradients = {}
|
225 |
-
activations = {}
|
226 |
|
227 |
return create_cam_visualization_pil(img, cam, alpha=alpha, vis_threshold=threshold), cam
|
228 |
|
@@ -245,26 +235,30 @@ def create_cam_visualization_pil(image_pil, cam, alpha=0.6, vis_threshold=0.2):
|
|
245 |
size = max(w, h)
|
246 |
|
247 |
# Normalize CAM to [0, 1]
|
248 |
-
|
|
|
249 |
|
250 |
# Create heatmap using matplotlib colormap
|
251 |
colormap = cm.get_cmap('inferno')
|
252 |
-
|
253 |
-
cam_alpha = (cam_norm >= vis_threshold).astype(np.float32) * alpha # Alpha mask
|
254 |
|
255 |
-
|
|
|
|
|
256 |
|
257 |
# Resize CAM to match image
|
258 |
-
|
|
|
259 |
|
260 |
-
|
|
|
|
|
261 |
|
262 |
# Composite over original
|
263 |
composite = Image.alpha_composite(image_pil, cam_image)
|
264 |
|
265 |
return composite
|
266 |
|
267 |
-
|
268 |
with gr.Blocks(css=".output-class { display: none; }") as demo:
|
269 |
gr.Markdown("""
|
270 |
## Joint Tagger Project: JTP-PILOT² Demo **BETA**
|
@@ -280,10 +274,20 @@ with gr.Blocks(css=".output-class { display: none; }") as demo:
|
|
280 |
sorted_tag_score_state = gr.State(value={}) # stash a copy of the input image
|
281 |
cam_state = gr.State()
|
282 |
with gr.Row():
|
283 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
284 |
image_input = gr.Image(label="Source", sources=['upload'], type='pil', height=512, show_label=False)
|
285 |
threshold_slider = gr.Slider(minimum=0.00, maximum=1.00, step=0.01, value=0.20, label="Tag Threshold")
|
286 |
-
cam_slider = gr.Slider(minimum=0.00, maximum=1.00, step=0.01, value=0.
|
287 |
alpha_slider = gr.Slider(minimum=0.00, maximum=1.00, step=0.01, value=0.60, label="CAM Alpha")
|
288 |
with gr.Column():
|
289 |
tag_string = gr.Textbox(label="Tag String")
|
|
|
154 |
for idx, tag in enumerate(allowed_tags):
|
155 |
allowed_tags[idx] = tag.replace("_", " ")
|
156 |
|
|
|
|
|
157 |
@spaces.GPU(duration=5)
|
158 |
def run_classifier(image: Image.Image, threshold):
|
159 |
img = image.convert('RGBA')
|
|
|
184 |
|
185 |
gradients = {}
|
186 |
activations = {}
|
|
|
|
|
|
|
187 |
|
188 |
def hook_forward(module, input, output):
|
189 |
activations['value'] = output
|
|
|
195 |
handle_forward = model.norm.register_forward_hook(hook_forward)
|
196 |
handle_backward = model.norm.register_full_backward_hook(hook_backward)
|
197 |
|
198 |
+
probits = model(tensor)[0]
|
199 |
|
200 |
model.zero_grad()
|
201 |
+
probits[target_tag_index].backward(retain_graph=True)
|
|
|
|
|
|
|
|
|
202 |
|
203 |
+
with torch.no_grad():
|
204 |
+
patch_grads = gradients.get('value')
|
205 |
+
patch_acts = activations.get('value')
|
206 |
+
|
207 |
+
weights = torch.mean(patch_grads, dim=1).squeeze(0)
|
208 |
+
|
209 |
+
cam_1d = torch.einsum('pe,e->p', patch_acts.squeeze(0), weights)
|
210 |
+
cam_1d = torch.relu(cam_1d)
|
211 |
+
|
212 |
+
cam = cam_1d.reshape(27, 27).detach().cpu().numpy()
|
213 |
|
214 |
handle_forward.remove()
|
215 |
handle_backward.remove()
|
|
|
|
|
216 |
|
217 |
return create_cam_visualization_pil(img, cam, alpha=alpha, vis_threshold=threshold), cam
|
218 |
|
|
|
235 |
size = max(w, h)
|
236 |
|
237 |
# Normalize CAM to [0, 1]
|
238 |
+
cam -= cam.min()
|
239 |
+
cam /= cam.max()
|
240 |
|
241 |
# Create heatmap using matplotlib colormap
|
242 |
colormap = cm.get_cmap('inferno')
|
243 |
+
cam_rgb = colormap(cam)[:, :, :3] # RGB
|
|
|
244 |
|
245 |
+
# Create alpha channel
|
246 |
+
cam_alpha = (cam >= vis_threshold).astype(np.float32) * alpha # Alpha mask
|
247 |
+
cam_rgba = np.dstack((cam_rgb, cam_alpha)) # Shape: (H, W, 4)
|
248 |
|
249 |
# Resize CAM to match image
|
250 |
+
cam_pil = Image.fromarray((cam_rgba * 255).astype(np.uint8), mode="RGBA")
|
251 |
+
cam_pil = cam_pil.resize((216,216), resample=Image.Resampling.NEAREST)
|
252 |
|
253 |
+
# Model uses padded image as input, this matches attention map to input image aspect ratio
|
254 |
+
cam_pil = cam_pil.resize((size, size), resample=Image.Resampling.BICUBIC)
|
255 |
+
cam_pil = transforms.CenterCrop((h, w))(cam_pil)
|
256 |
|
257 |
# Composite over original
|
258 |
composite = Image.alpha_composite(image_pil, cam_image)
|
259 |
|
260 |
return composite
|
261 |
|
|
|
262 |
with gr.Blocks(css=".output-class { display: none; }") as demo:
|
263 |
gr.Markdown("""
|
264 |
## Joint Tagger Project: JTP-PILOT² Demo **BETA**
|
|
|
274 |
sorted_tag_score_state = gr.State(value={}) # stash a copy of the input image
|
275 |
cam_state = gr.State()
|
276 |
with gr.Row():
|
277 |
+
custom_css = """
|
278 |
+
.inferno-slider input[type=range] {
|
279 |
+
background: linear-gradient(to right,
|
280 |
+
#000004, #1b0c41, #4a0c6b, #781c6d,
|
281 |
+
#a52c60, #cf4446, #ed6925, #fb9b06,
|
282 |
+
#f7d13d, #fcffa4
|
283 |
+
) !important;
|
284 |
+
background-size: 100% 100% !important;
|
285 |
+
}
|
286 |
+
"""
|
287 |
+
with gr.Column(css=custom_css):
|
288 |
image_input = gr.Image(label="Source", sources=['upload'], type='pil', height=512, show_label=False)
|
289 |
threshold_slider = gr.Slider(minimum=0.00, maximum=1.00, step=0.01, value=0.20, label="Tag Threshold")
|
290 |
+
cam_slider = gr.Slider(minimum=0.00, maximum=1.00, step=0.01, value=0.40, label="CAM Threshold", elem_classes="inferno-slider")
|
291 |
alpha_slider = gr.Slider(minimum=0.00, maximum=1.00, step=0.01, value=0.60, label="CAM Alpha")
|
292 |
with gr.Column():
|
293 |
tag_string = gr.Textbox(label="Tag String")
|