sayedM committed
Commit d73e700 · verified · Parent: e17f35c

Update app.py

Files changed (1): app.py (+142, -74)
app.py CHANGED
@@ -1,4 +1,5 @@
 # app.py
+
 import os
 import torch
 import torch.nn.functional as F
@@ -6,47 +7,79 @@ import gradio as gr
 import numpy as np
 from PIL import Image, ImageDraw
 import torchvision.transforms.functional as TF
-from matplotlib import colormaps
-from transformers import AutoModel
+
+# --- Robust colormap import (Matplotlib ≥3.5 and older versions) ---
+try:
+    from matplotlib import colormaps as _mpl_colormaps
+    def _get_cmap(name: str):
+        return _mpl_colormaps[name]
+except Exception:
+    import matplotlib.cm as _cm
+    def _get_cmap(name: str):
+        return _cm.get_cmap(name)
+
+from transformers import AutoModel  # uses trust_remote_code for DINOv3
 
 # ----------------------------
 # Configuration
 # ----------------------------
-MODEL_ID = "facebook/dinov3-vith16plus-pretrain-lvd1689m"
+# Default to the smaller/faster ViT-S/16+; offer ViT-H/16+ as an alternative.
+DEFAULT_MODEL_ID = "facebook/dinov3-vits16plus-pretrain-lvd1689m"
+ALT_MODEL_ID = "facebook/dinov3-vith16plus-pretrain-lvd1689m"
+AVAILABLE_MODELS = [DEFAULT_MODEL_ID, ALT_MODEL_ID]
+
 PATCH_SIZE = 16
 DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
 
+# Normalization constants (standard for ImageNet)
 IMAGENET_MEAN = (0.485, 0.456, 0.406)
 IMAGENET_STD = (0.229, 0.224, 0.225)
 
 # ----------------------------
-# Model Loading (Hugging Face Hub)
+# Model Loading (Hugging Face Hub) with caching
 # ----------------------------
-def load_model_from_hub():
-    """Loads the DINOv3 model from the Hugging Face Hub."""
-    print(f"Loading model '{MODEL_ID}' from Hugging Face Hub...")
+_model_cache = {}
+_current_model_id = None
+model = None  # global reference used by extract_image_features()
+
+def load_model_from_hub(model_id: str):
+    """Loads a DINOv3 model from the Hugging Face Hub."""
+    print(f"Loading model '{model_id}' from Hugging Face Hub...")
     try:
-        token = os.environ.get("HF_TOKEN")
-        model = AutoModel.from_pretrained(MODEL_ID, token=token, trust_remote_code=True)
-        model.to(DEVICE).eval()
-        print(f"✅ Model loaded successfully on device: {DEVICE}")
-        return model
+        token = os.environ.get("HF_TOKEN")  # optional, for gated models
+        mdl = AutoModel.from_pretrained(model_id, token=token, trust_remote_code=True)
+        mdl.to(DEVICE).eval()
+        print(f"✅ Model '{model_id}' loaded successfully on device: {DEVICE}")
+        return mdl
     except Exception as e:
-        print(f"❌ Failed to load model: {e}")
+        print(f"❌ Failed to load model '{model_id}': {e}")
         raise gr.Error(
-            f"Could not load model '{MODEL_ID}'. "
-            "This is a gated model. Please ensure you have accepted the terms on its Hugging Face page "
-            "and set your HF_TOKEN as a secret in your Space settings. "
+            f"Could not load model '{model_id}'. "
+            "If the model is gated, please accept the terms on its Hugging Face page "
+            "and set HF_TOKEN in your environment. "
             f"Original error: {e}"
         )
 
-# Load the model globally when the app starts
-model = load_model_from_hub()
+def get_model(model_id: str):
+    """Return a cached model if available, otherwise load and cache it."""
+    if model_id in _model_cache:
+        return _model_cache[model_id]
+    mdl = load_model_from_hub(model_id)
+    _model_cache[model_id] = mdl
+    return mdl
+
+# Load the default model at startup
+model = get_model(DEFAULT_MODEL_ID)
+_current_model_id = DEFAULT_MODEL_ID
 
 # ----------------------------
 # Helper Functions (resize, viz)
 # ----------------------------
 def resize_to_grid(img: Image.Image, long_side: int, patch: int) -> torch.Tensor:
+    """
+    Resizes so max(h, w) == long_side (keeping aspect ratio), then rounds each
+    side up to a multiple of `patch`. Returns a CHW float tensor in [0, 1].
+    """
     w, h = img.size
     scale = long_side / max(h, w)
     new_h = max(patch, int(round(h * scale)))
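
Review note: get_model() caches every backbone it has ever loaded, so switching between ViT-S/16+ and ViT-H/16+ keeps both resident. On a memory-constrained Space, a single-slot cache may be safer. A minimal sketch (hypothetical get_model_single_slot, reusing load_model_from_hub from this commit):

```python
import torch

# Hypothetical single-slot variant of get_model(): keeps at most one
# backbone in memory instead of caching every selection.
_slot = {"id": None, "model": None}

def get_model_single_slot(model_id: str):
    if _slot["id"] != model_id:
        _slot["model"] = None              # drop the previous backbone
        if torch.cuda.is_available():
            torch.cuda.empty_cache()       # hand freed VRAM back to the driver
        _slot["model"] = load_model_from_hub(model_id)  # defined in this commit
        _slot["id"] = model_id
    return _slot["model"]
```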
@@ -58,13 +91,17 @@ def resize_to_grid(img: Image.Image, long_side: int, patch: int) -> torch.Tensor
 def colorize(sim_map_up: np.ndarray, cmap_name: str = "viridis") -> Image.Image:
     x = sim_map_up.astype(np.float32)
     x = (x - x.min()) / (x.max() - x.min() + 1e-6)
-    rgb = (colormaps[cmap_name](x)[..., :3] * 255).astype(np.uint8)
+    rgb = (_get_cmap(cmap_name)(x)[..., :3] * 255).astype(np.uint8)
     return Image.fromarray(rgb)
 
 def blend(base: Image.Image, heat: Image.Image, alpha: float = 0.55) -> Image.Image:
+    # Put alpha on the heatmap and composite for a crisp overlay
     base = base.convert("RGBA")
     heat = heat.convert("RGBA")
-    return Image.blend(base, heat, alpha=alpha)
+    a = Image.new("L", heat.size, int(255 * alpha))
+    heat.putalpha(a)
+    out = Image.alpha_composite(base, heat)
+    return out.convert("RGB")
 
 def draw_crosshair(img: Image.Image, x: int, y: int, radius: int = None) -> Image.Image:
     r = radius if radius is not None else max(2, PATCH_SIZE // 2)
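
Review note: with the fully opaque RGBA inputs produced here, the new putalpha + alpha_composite path and the old Image.blend should agree on the RGB channels up to rounding; the rewrite mainly makes the alpha handling explicit. A standalone check of that assumption:

```python
import numpy as np
from PIL import Image

# Compare Image.blend against putalpha + alpha_composite on opaque inputs.
rng = np.random.default_rng(0)
base = Image.fromarray(rng.integers(0, 256, (8, 8, 3), dtype=np.uint8)).convert("RGBA")
heat = Image.fromarray(rng.integers(0, 256, (8, 8, 3), dtype=np.uint8)).convert("RGBA")
alpha = 0.55

old_way = Image.blend(base, heat, alpha=alpha).convert("RGB")

heat2 = heat.copy()
heat2.putalpha(Image.new("L", heat2.size, int(255 * alpha)))
new_way = Image.alpha_composite(base, heat2).convert("RGB")

diff = np.abs(np.asarray(old_way, np.int16) - np.asarray(new_way, np.int16))
print("max per-channel difference:", diff.max())  # expected: 0 or 1 (rounding)
```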
@@ -96,26 +133,33 @@ def patch_neighborhood_box(r: int, c: int, Hp: int, Wp: int, rad: int, patch: int
     return (x0, y0, x1, y1)
 
 # ----------------------------
-# Feature Extraction
+# Feature Extraction (using transformers)
 # ----------------------------
 @torch.inference_mode()
 def extract_image_features(image_pil: Image.Image, target_long_side: int):
+    """
+    Extracts patch features from an image using the loaded Hugging Face model.
+    """
     t = resize_to_grid(image_pil, target_long_side, PATCH_SIZE)
     t_norm = TF.normalize(t, IMAGENET_MEAN, IMAGENET_STD).unsqueeze(0).to(DEVICE)
     _, _, H, W = t_norm.shape
     Hp, Wp = H // PATCH_SIZE, W // PATCH_SIZE
-
+
+    # Model output: [CLS] + 4 register tokens + patches
     outputs = model(t_norm)
-
+
+    # Skip the 5 special tokens to keep only patch embeddings
    n_special_tokens = 5
     patch_embeddings = outputs.last_hidden_state.squeeze(0)[n_special_tokens:, :]
-
+
+    # L2-normalize features for cosine similarity
     X = F.normalize(patch_embeddings, p=2, dim=-1)
+
     img_resized = TF.to_pil_image(t)
     return {"X": X, "Hp": Hp, "Wp": Wp, "img": img_resized}
 
 # ----------------------------
-# Similarity Logic
+# Similarity inside the same image
 # ----------------------------
 def click_to_similarity_in_same_image(
     state: dict,
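
Review note: n_special_tokens = 5 encodes this commit's assumption of one [CLS] plus four register tokens. A sketch that derives the count from the sequence length instead of hardcoding it (same model, t_norm, Hp, and Wp as in extract_image_features):

```python
import torch

# Derive the special-token count from the model output rather than
# assuming 5; guards against backbones with a different register count.
with torch.inference_mode():
    tokens = model(t_norm).last_hidden_state  # (1, n_special + Hp*Wp, dim)
n_special_tokens = tokens.shape[1] - Hp * Wp
assert n_special_tokens >= 0, "patch grid larger than token sequence"
patch_embeddings = tokens.squeeze(0)[n_special_tokens:, :]
```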
@@ -128,17 +172,21 @@ def click_to_similarity_in_same_image(
 ):
     if not state:
         return None, None, None, None
+
     X = state["X"]
     Hp, Wp = state["Hp"], state["Wp"]
     base_img = state["img"]
     img_w, img_h = base_img.size
+
     x_pix, y_pix = click_xy
     col = int(np.clip(x_pix // PATCH_SIZE, 0, Wp - 1))
     row = int(np.clip(y_pix // PATCH_SIZE, 0, Hp - 1))
     idx = row * Wp + col
+
     q = X[idx]
     sims = torch.matmul(X, q)
     sim_map = sims.view(Hp, Wp)
+
     if exclude_radius_patches > 0:
         rr, cc = torch.meshgrid(
             torch.arange(Hp, device=sims.device),
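
Review note: since X is row-wise L2-normalized, torch.matmul(X, q) is exactly the cosine similarity of every patch against the clicked one. A toy worked example of the pixel-to-patch mapping and the scoring (hypothetical grid and feature sizes):

```python
import torch
import torch.nn.functional as F

PATCH_SIZE = 16
Hp, Wp = 4, 6                               # toy 4x6 patch grid
X = F.normalize(torch.randn(Hp * Wp, 384), p=2, dim=-1)

x_pix, y_pix = 70, 35                       # a click in pixel space
col = min(x_pix // PATCH_SIZE, Wp - 1)      # 70 // 16 -> column 4
row = min(y_pix // PATCH_SIZE, Hp - 1)      # 35 // 16 -> row 2
idx = row * Wp + col                        # row-major index -> 16

sims = torch.matmul(X, X[idx])              # cosine similarities, shape (24,)
print(round(sims[idx].item(), 4))           # ~1.0: a patch matches itself
```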
@@ -147,14 +195,17 @@ def click_to_similarity_in_same_image(
         )
         mask = (torch.abs(rr - row) <= exclude_radius_patches) & (torch.abs(cc - col) <= exclude_radius_patches)
         sim_map = sim_map.masked_fill(mask, float("-inf"))
+
     sim_up = F.interpolate(
         sim_map.unsqueeze(0).unsqueeze(0),
         size=(img_h, img_w),
         mode="bicubic",
         align_corners=False,
     ).squeeze().detach().cpu().numpy()
+
     heatmap_pil = colorize(sim_up, cmap_name)
     overlay_pil = blend(base_img, heatmap_pil, alpha=alpha)
+
     overlay_boxes_pil = overlay_pil
     if topk and topk > 0:
         flat = sim_map.view(-1)
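
Review note (pre-existing behaviour, not changed by this commit): when exclude_radius_patches > 0, sim_map contains -inf, and those non-finite values flow through F.interpolate into colorize(), where the min-max normalization can produce NaNs. A sketch of a visualization-only guard, assuming the same tensors as above:

```python
import torch

# Use a finite copy for upsampling/colorizing; keep the -inf map for
# the top-k selection so excluded patches still never appear as boxes.
if not torch.isfinite(sim_map).all():
    finite_min = sim_map[torch.isfinite(sim_map)].min()
    viz_map = torch.where(torch.isfinite(sim_map), sim_map, finite_min)
else:
    viz_map = sim_map
# ...then pass viz_map (not sim_map) to F.interpolate / colorize
```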
@@ -171,83 +222,100 @@ def click_to_similarity_in_same_image(
         for r, c in [divmod(j.item(), Wp) for j in top_idx]
     ]
     overlay_boxes_pil = draw_boxes(overlay_pil, boxes, outline="yellow", width=3, labels=True)
+
     marked_ref = draw_crosshair(base_img, x_pix, y_pix, radius=PATCH_SIZE // 2)
     return marked_ref, heatmap_pil, overlay_pil, overlay_boxes_pil
 
 # ----------------------------
-# Gradio UI
+# Gradio UI (+ Start button, + Model dropdown)
 # ----------------------------
-with gr.Blocks(theme=gr.themes.Soft(), title="DINOv3 Patch Similarity") as demo:
-    gr.Markdown("# 🦖 DINOv3: Visualizing Patch Similarity")
-    gr.Markdown(
-        "Upload an image, then **click anywhere** on it to find the most visually similar regions. "
-        "**Note:** If running on a CPU-only Space, feature extraction after uploading an image can take a moment."
-    )
-
+with gr.Blocks(theme=gr.themes.Soft(), title="DINOv3 Single-Image Patch Similarity") as demo:
+    gr.Markdown("# 🦖 DINOv3 Single-Image Patch Similarity")
+    gr.Markdown("## Note: on a CPU-only Space, feature extraction after uploading an image can take a moment")
+    gr.Markdown("Upload one image, then **click anywhere** to highlight the most similar regions in the *same* image.")
+
     app_state = gr.State()
+
     with gr.Row():
-        with gr.Column(scale=2):
+        with gr.Column(scale=1):
             input_image = gr.Image(
                 label="Image (click anywhere)",
                 type="pil",
                 value="https://images.squarespace-cdn.com/content/v1/607f89e638219e13eee71b1e/1684821560422-SD5V37BAG28BURTLIXUQ/michael-sum-LEpfefQf4rU-unsplash.jpg"
             )
-            with gr.Accordion("⚙️ Visualization Controls", open=True):
-                target_long_side = gr.Slider(
-                    minimum=224, maximum=1024, value=768, step=16,
-                    label="Processing Resolution",
-                    info="Higher values = more detail but slower processing",
-                )
-                alpha = gr.Slider(0.0, 1.0, value=0.55, step=0.05, label="Overlay Opacity")
+            target_long_side = gr.Slider(
+                minimum=224, maximum=1024, value=768, step=16,
+                label="Processing Resolution",
+                info="Higher values = more detail but slower processing",
+            )
+            with gr.Row():
+                alpha = gr.Slider(0.0, 1.0, value=0.55, step=0.05, label="Overlay opacity")
                 cmap = gr.Dropdown(
                     ["viridis", "magma", "plasma", "inferno", "turbo", "cividis"],
-                    value="viridis", label="Heatmap Colormap",
+                    value="viridis", label="Colormap",
                 )
-            with gr.Accordion("⚙️ Similarity Controls", open=True):
-                exclude_r = gr.Slider(0, 10, value=0, step=1, label="Exclude Radius (patches)", info="Ignore patches around the click point.")
-                topk = gr.Slider(0, 50, value=10, step=1, label="Top-K Boxes", info="Number of similar regions to highlight.")
-                box_radius = gr.Slider(0, 10, value=1, step=1, label="Box Radius (patches)", info="Size of the highlight box.")
-
-        with gr.Column(scale=3):
-            marked_image = gr.Image(label="Your Click (on processed image)", interactive=False)
-            with gr.Tabs():
-                with gr.TabItem("📦 Bounding Boxes"):
-                    overlay_boxes_output = gr.Image(label="Overlay + Top-K Similar Patches", interactive=False)
-                with gr.TabItem("🔥 Heatmap"):
-                    heatmap_output = gr.Image(label="Similarity Heatmap", interactive=False)
-                with gr.TabItem(" blended"):
-                    overlay_output = gr.Image(label="Blended Overlay (Image + Heatmap)", interactive=False)
-
-    def _on_upload_or_slider_change(img: Image.Image, long_side: int, progress=gr.Progress(track_tqdm=True)):
+            # NEW: Backbone selector (default = smaller/faster ViT-S/16+)
+            model_choice = gr.Dropdown(
+                choices=AVAILABLE_MODELS,
+                value=DEFAULT_MODEL_ID,
+                label="Backbone (DINOv3)",
+                info="ViT-S/16+ is smaller & faster; ViT-H/16+ is larger.",
+            )
+            # Start processing button
+            with gr.Row():
+                start_btn = gr.Button("▶️ Start processing", variant="primary")
+
+        with gr.Column(scale=1):
+            exclude_r = gr.Slider(0, 10, value=0, step=1, label="Exclude radius (patches)")
+            topk = gr.Slider(0, 200, value=20, step=1, label="Top-K boxes")
+            box_radius = gr.Slider(0, 10, value=1, step=1, label="Box radius (patches)")
+
+    with gr.Row():
+        marked_image = gr.Image(label="Click marker", interactive=False)
+        heatmap_output = gr.Image(label="Similarity heatmap", interactive=False)
+    with gr.Row():
+        overlay_output = gr.Image(label="Overlay (image ⊕ heatmap)", interactive=False)
+        overlay_boxes_output = gr.Image(label="Overlay + top-K similar patch boxes", interactive=False)
+
+    def _ensure_model(model_id: str):
+        """Ensure the global 'model' matches the dropdown selection."""
+        global model, _current_model_id
+        if model_id != _current_model_id:
+            model = get_model(model_id)
+            _current_model_id = model_id
+
+    def _on_upload_or_slider_change(img: Image.Image, long_side: int, model_id: str, progress=gr.Progress(track_tqdm=True)):
         if img is None:
             return None, None
-        progress(0, desc="🦖 Extracting DINOv3 features...")
+        _ensure_model(model_id)
+        progress(0, desc="Extracting features...")
         st = extract_image_features(img, int(long_side))
-        progress(1, desc="Done!")
-        # Clear old results when a new image is uploaded
-        return st["img"], st, None, None, None, None
+        progress(1, desc="Done!")
+        return st["img"], st
 
     def _on_click(st, a: float, m: str, excl: int, k: int, box_rad: int, evt: gr.SelectData):
         if not st or evt is None:
-            # Return current state if no click data
-            return st.get("img"), None, None, None
-
-        marked, heat, overlay, boxes = click_to_similarity_in_same_image(
+            return None, None, None, None
+        return click_to_similarity_in_same_image(
             st, click_xy=evt.index, exclude_radius_patches=int(excl),
             topk=int(k), alpha=float(a), cmap_name=m,
             box_radius_patches=int(box_rad),
         )
-        return marked, heat, overlay, boxes
 
     # Wire events
-    inputs_for_update = [input_image, target_long_side]
-    outputs_for_upload = [marked_image, app_state, heatmap_output, overlay_output, overlay_boxes_output, marked_image]
-
-    input_image.upload(_on_upload_or_slider_change, inputs=inputs_for_update, outputs=outputs_for_upload)
-    target_long_side.change(_on_upload_or_slider_change, inputs=inputs_for_update, outputs=outputs_for_upload)
-    demo.load(_on_upload_or_slider_change, inputs=inputs_for_update, outputs=outputs_for_upload)
+    inputs_for_update = [input_image, target_long_side, model_choice]
+    outputs_for_update = [marked_image, app_state]
+
+    # Auto triggers (kept)
+    input_image.upload(_on_upload_or_slider_change, inputs=inputs_for_update, outputs=outputs_for_update)
+    target_long_side.change(_on_upload_or_slider_change, inputs=inputs_for_update, outputs=outputs_for_update)
+    model_choice.change(_on_upload_or_slider_change, inputs=inputs_for_update, outputs=outputs_for_update)
+    demo.load(_on_upload_or_slider_change, inputs=inputs_for_update, outputs=outputs_for_update)  # Process the default image on load
+
+    # Manual trigger via button (kept)
+    start_btn.click(_on_upload_or_slider_change, inputs=inputs_for_update, outputs=outputs_for_update)
 
+    # Click to compute similarities
     marked_image.select(
         _on_click,
         inputs=[app_state, alpha, cmap, exclude_r, topk, box_radius],
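
Review note: the same handler is now wired to five triggers with identical inputs and outputs. In Gradio 4+, the four component events can be registered in a single gr.on call; a sketch using the same components as above:

```python
import gradio as gr

# Consolidate the four component triggers; demo.load stays separate so
# the default image is still processed when the page opens.
gr.on(
    triggers=[input_image.upload, target_long_side.change,
              model_choice.change, start_btn.click],
    fn=_on_upload_or_slider_change,
    inputs=inputs_for_update,
    outputs=outputs_for_update,
)
demo.load(_on_upload_or_slider_change, inputs=inputs_for_update,
          outputs=outputs_for_update)
```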
@@ -255,4 +323,4 @@ with gr.Blocks(theme=gr.themes.Soft(), title="DINOv3 Patch Similarity") as demo:
     )
 
 if __name__ == "__main__":
-    demo.launch()
+    demo.launch()
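
Review note: since the UI itself warns that CPU-only Spaces are slow, a possible follow-up (not part of this commit) is to enable request queuing so concurrent feature extractions don't pile up:

```python
if __name__ == "__main__":
    # Queue requests so heavy feature extraction runs one at a time
    # instead of overlapping on a CPU-only Space.
    demo.queue(max_size=8).launch()
```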
 