sayedM committed
Commit e17f35c · verified · 1 Parent(s): ff51b2a

Update app.py

Files changed (1)
  1. app.py +68 -106
app.py CHANGED
@@ -1,3 +1,4 @@
 
  import os
  import torch
  import torch.nn.functional as F
@@ -5,59 +6,45 @@ import gradio as gr
  import numpy as np
  from PIL import Image, ImageDraw
  import torchvision.transforms.functional as TF
- from matplotlib import colaps
  from transformers import AutoModel

  # ----------------------------
  # Configuration
  # ----------------------------
- # ⭐ Define available models, with the smaller one as default
- MODELS = {
-     "DINOv3 ViT-S+ (Small, Default)": "facebook/dinov3-vits16plus-pretrain-lvd1689m",
-     "DINOv3 ViT-H+ (Huge)": "facebook/dinov3-vith16plus-pretrain-lvd1689m",
- }
- DEFAULT_MODEL_NAME = "DINOv3 ViT-S+ (Small, Default)"
-
  PATCH_SIZE = 16
  DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

- # Normalization constants (standard for ImageNet)
  IMAGENET_MEAN = (0.485, 0.456, 0.406)
  IMAGENET_STD = (0.229, 0.224, 0.225)

- # ⭐ Cache for loaded models to avoid re-downloading
- model_cache = {}
-
  # ----------------------------
  # Model Loading (Hugging Face Hub)
  # ----------------------------
- def load_model_from_hub(model_id: str):
-     """Loads a DINOv3 model from the Hugging Face Hub."""
-     print(f"Loading model '{model_id}' from Hugging Face Hub...")
      try:
          token = os.environ.get("HF_TOKEN")
-         model = AutoModel.from_pretrained(model_id, token=token, trust_remote_code=True)
          model.to(DEVICE).eval()
          print(f"✅ Model loaded successfully on device: {DEVICE}")
          return model
      except Exception as e:
          print(f"❌ Failed to load model: {e}")
          raise gr.Error(
-             f"Could not load model '{model_id}'. "
              "This is a gated model. Please ensure you have accepted the terms on its Hugging Face page "
              "and set your HF_TOKEN as a secret in your Space settings. "
              f"Original error: {e}"
          )

- def get_model(model_name: str):
-     """Gets a model from the cache or loads it if not present."""
-     model_id = MODELS[model_name]
-     if model_id not in model_cache:
-         model_cache[model_id] = load_model_from_hub(model_id)
-     return model_cache[model_id]

  # ----------------------------
- # Helper Functions (resize, viz) - No changes here
  # ----------------------------
  def resize_to_grid(img: Image.Image, long_side: int, patch: int) -> torch.Tensor:
      w, h = img.size
@@ -77,10 +64,7 @@ def colorize(sim_map_up: np.ndarray, cmap_name: str = "viridis") -> Image.Image:
  def blend(base: Image.Image, heat: Image.Image, alpha: float = 0.55) -> Image.Image:
      base = base.convert("RGBA")
      heat = heat.convert("RGBA")
-     a = Image.new("L", heat.size, int(255 * alpha))
-     heat.putalpha(a)
-     out = Image.alpha_composite(base, heat)
-     return out.convert("RGB")

  def draw_crosshair(img: Image.Image, x: int, y: int, radius: int = None) -> Image.Image:
      r = radius if radius is not None else max(2, PATCH_SIZE // 2)
@@ -112,31 +96,26 @@ def patch_neighborhood_box(r: int, c: int, Hp: int, Wp: int, rad: int, patch: in
      return (x0, y0, x1, y1)

  # ----------------------------
- # Feature Extraction (using transformers)
  # ----------------------------
  @torch.inference_mode()
- # ⭐ Pass the model object as an argument
- def extract_image_features(model, image_pil: Image.Image, target_long_side: int):
-     """
-     Extracts patch features from an image using the loaded Hugging Face model.
-     """
      t = resize_to_grid(image_pil, target_long_side, PATCH_SIZE)
      t_norm = TF.normalize(t, IMAGENET_MEAN, IMAGENET_STD).unsqueeze(0).to(DEVICE)
      _, _, H, W = t_norm.shape
      Hp, Wp = H // PATCH_SIZE, W // PATCH_SIZE
-
      outputs = model(t_norm)
-
      n_special_tokens = 5
      patch_embeddings = outputs.last_hidden_state.squeeze(0)[n_special_tokens:, :]
-
      X = F.normalize(patch_embeddings, p=2, dim=-1)
-
      img_resized = TF.to_pil_image(t)
      return {"X": X, "Hp": Hp, "Wp": Wp, "img": img_resized}

  # ----------------------------
- # Similarity inside the same image - No changes here
  # ----------------------------
  def click_to_similarity_in_same_image(
      state: dict,
@@ -149,21 +128,17 @@ def click_to_similarity_in_same_image(
  ):
      if not state:
          return None, None, None, None
-
      X = state["X"]
      Hp, Wp = state["Hp"], state["Wp"]
      base_img = state["img"]
      img_w, img_h = base_img.size
-
      x_pix, y_pix = click_xy
      col = int(np.clip(x_pix // PATCH_SIZE, 0, Wp - 1))
      row = int(np.clip(y_pix // PATCH_SIZE, 0, Hp - 1))
      idx = row * Wp + col
-
      q = X[idx]
      sims = torch.matmul(X, q)
      sim_map = sims.view(Hp, Wp)
-
      if exclude_radius_patches > 0:
          rr, cc = torch.meshgrid(
              torch.arange(Hp, device=sims.device),
@@ -172,17 +147,14 @@ def click_to_similarity_in_same_image(
          )
          mask = (torch.abs(rr - row) <= exclude_radius_patches) & (torch.abs(cc - col) <= exclude_radius_patches)
          sim_map = sim_map.masked_fill(mask, float("-inf"))
-
      sim_up = F.interpolate(
          sim_map.unsqueeze(0).unsqueeze(0),
          size=(img_h, img_w),
          mode="bicubic",
          align_corners=False,
      ).squeeze().detach().cpu().numpy()
-
      heatmap_pil = colorize(sim_up, cmap_name)
      overlay_pil = blend(base_img, heatmap_pil, alpha=alpha)
-
      overlay_boxes_pil = overlay_pil
      if topk and topk > 0:
          flat = sim_map.view(-1)
@@ -199,92 +171,82 @@ def click_to_similarity_in_same_image(
              for r, c in [divmod(j.item(), Wp) for j in top_idx]
          ]
          overlay_boxes_pil = draw_boxes(overlay_pil, boxes, outline="yellow", width=3, labels=True)
-
      marked_ref = draw_crosshair(base_img, x_pix, y_pix, radius=PATCH_SIZE // 2)
      return marked_ref, heatmap_pil, overlay_pil, overlay_boxes_pil

  # ----------------------------
  # Gradio UI
  # ----------------------------
- with gr.Blocks(theme=gr.themes.Soft(), title="DINOv3 Single-Image Patch Similarity") as demo:
-     gr.Markdown("# 🦖 DINOv3 Single-Image Patch Similarity")
-     gr.Markdown("## Running on CPU-only Space, feature extraction can take a moment")
-     gr.Markdown("1. **Choose a model**. 2. Upload an image. 3. Click **Process Image**. 4. **Click anywhere on the processed image** to find similar regions.")
-
      app_state = gr.State()
-
      with gr.Row():
-         with gr.Column(scale=1):
-             # ⭐ ADDED MODEL DROPDOWN
-             model_name_dd = gr.Dropdown(
-                 label="1. Choose a Model",
-                 choices=list(MODELS.keys()),
-                 value=DEFAULT_MODEL_NAME,
-             )
              input_image = gr.Image(
-                 label="2. Upload Image",
                  type="pil",
                  value="https://images.squarespace-cdn.com/content/v1/607f89e638219e13eee71b1e/1684821560422-SD5V37BAG28BURTLIXUQ/michael-sum-LEpfefQf4rU-unsplash.jpg"
              )
-             target_long_side = gr.Slider(
-                 minimum=224, maximum=1024, value=768, step=16,
-                 label="Processing Resolution",
-                 info="Higher values = more detail but slower processing",
-             )
-             process_button = gr.Button("3. Process Image", variant="primary")
-
-             with gr.Row():
-                 alpha = gr.Slider(0.0, 1.0, value=0.55, step=0.05, label="Overlay opacity")
                  cmap = gr.Dropdown(
                      ["viridis", "magma", "plasma", "inferno", "turbo", "cividis"],
-                     value="viridis", label="Colormap",
                  )
-         with gr.Column(scale=1):
-             exclude_r = gr.Slider(0, 10, value=0, step=1, label="Exclude radius (patches)")
-             topk = gr.Slider(0, 200, value=20, step=1, label="Top-K boxes")
-             box_radius = gr.Slider(0, 10, value=1, step=1, label="Box radius (patches)")
-
-     with gr.Row():
-         marked_image = gr.Image(label="4. Click on this image", interactive=True)
-         heatmap_output = gr.Image(label="Similarity heatmap", interactive=False)
-     with gr.Row():
-         overlay_output = gr.Image(label="Overlay (image ⊕ heatmap)", interactive=False)
-         overlay_boxes_output = gr.Image(label="Overlay + top-K similar patch boxes", interactive=False)
-
-     # ⭐ UPDATED to take model_name as input
-     def _process_image(model_name: str, img: Image.Image, long_side: int, progress=gr.Progress(track_tqdm=True)):
          if img is None:
-             gr.Warning("Please upload an image first!")
              return None, None
-
-         progress(0, desc=f"Loading model '{model_name}'...")
-         model = get_model(model_name)
-
-         progress(0.5, desc="Extracting features...")
-         st = extract_image_features(model, img, int(long_side))
-
-         progress(1, desc="Done! You can now click on the image.")
-         return st["img"], st

      def _on_click(st, a: float, m: str, excl: int, k: int, box_rad: int, evt: gr.SelectData):
          if not st or evt is None:
-             gr.Warning("Please process an image before clicking on it.")
-             return None, None, None, None
-         return click_to_similarity_in_same_image(
              st, click_xy=evt.index, exclude_radius_patches=int(excl),
              topk=int(k), alpha=float(a), cmap_name=m,
              box_radius_patches=int(box_rad),
          )

-     # ⭐ UPDATED EVENT WIRING to include the dropdown
-     inputs_for_processing = [model_name_dd, input_image, target_long_side]
-     outputs_for_processing = [marked_image, app_state]

-     process_button.click(
-         _process_image,
-         inputs=inputs_for_processing,
-         outputs=outputs_for_processing
-     )

      marked_image.select(
          _on_click,
 
+ # app.py
  import os
  import torch
  import torch.nn.functional as F
  import gradio as gr
  import numpy as np
  from PIL import Image, ImageDraw
  import torchvision.transforms.functional as TF
+ from matplotlib import colormaps
  from transformers import AutoModel

  # ----------------------------
  # Configuration
  # ----------------------------
+ MODEL_ID = "facebook/dinov3-vith16plus-pretrain-lvd1689m"

  PATCH_SIZE = 16
  DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

  IMAGENET_MEAN = (0.485, 0.456, 0.406)
  IMAGENET_STD = (0.229, 0.224, 0.225)

  # ----------------------------
  # Model Loading (Hugging Face Hub)
  # ----------------------------
+ def load_model_from_hub():
+     """Loads the DINOv3 model from the Hugging Face Hub."""
+     print(f"Loading model '{MODEL_ID}' from Hugging Face Hub...")
      try:
          token = os.environ.get("HF_TOKEN")
+         model = AutoModel.from_pretrained(MODEL_ID, token=token, trust_remote_code=True)
          model.to(DEVICE).eval()
          print(f"✅ Model loaded successfully on device: {DEVICE}")
          return model
      except Exception as e:
          print(f"❌ Failed to load model: {e}")
          raise gr.Error(
+             f"Could not load model '{MODEL_ID}'. "
              "This is a gated model. Please ensure you have accepted the terms on its Hugging Face page "
              "and set your HF_TOKEN as a secret in your Space settings. "
              f"Original error: {e}"
          )

+ # Load the model globally when the app starts
+ model = load_model_from_hub()
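The checkpoint is gated, so `from_pretrained` needs a credential. A minimal sketch of the same load outside the Space, assuming the model terms have been accepted on its Hugging Face page (`token=None` falls back to a credential cached by `huggingface-cli login`):

import os
from transformers import AutoModel

# Same call the app makes; HF_TOKEN may be unset locally if you are logged in.
model = AutoModel.from_pretrained(
    "facebook/dinov3-vith16plus-pretrain-lvd1689m",
    token=os.environ.get("HF_TOKEN"),
    trust_remote_code=True,
)
model.eval()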
 
 
 
 
 
  # ----------------------------
+ # Helper Functions (resize, viz)
  # ----------------------------
  def resize_to_grid(img: Image.Image, long_side: int, patch: int) -> torch.Tensor:
      w, h = img.size
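The body of `resize_to_grid` is unchanged, so the diff elides it. A hypothetical reimplementation with the same signature (an illustration, not the committed code): scale the long side to `long_side`, then snap both dimensions down to multiples of `patch` so the image divides evenly into 16×16 patches.

def resize_to_grid_sketch(img: Image.Image, long_side: int, patch: int) -> torch.Tensor:
    # Hypothetical helper for illustration only.
    w, h = img.size
    scale = long_side / max(w, h)
    # Snap each side down to the nearest multiple of the patch size.
    new_w = max(patch, int(w * scale) // patch * patch)
    new_h = max(patch, int(h * scale) // patch * patch)
    resized = img.convert("RGB").resize((new_w, new_h), Image.BICUBIC)
    return TF.to_tensor(resized)  # float tensor in [0, 1], shape (3, new_h, new_w)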
 
  def blend(base: Image.Image, heat: Image.Image, alpha: float = 0.55) -> Image.Image:
      base = base.convert("RGBA")
      heat = heat.convert("RGBA")
+     return Image.blend(base, heat, alpha=alpha)
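`Image.blend(base, heat, alpha)` computes the uniform mix `(1 - alpha) * base + alpha * heat` in one call, replacing the previous three-step `putalpha` / `alpha_composite` / `convert("RGB")` sequence. One difference: the old version returned RGB, while `Image.blend` on two RGBA inputs returns RGBA. A small equivalence check:

# Blending two opaque 4x4 swatches at alpha=0.55.
base = Image.new("RGBA", (4, 4), (255, 0, 0, 255))
heat = Image.new("RGBA", (4, 4), (0, 0, 255, 255))
mixed = Image.blend(base, heat, alpha=0.55)
print(mixed.getpixel((0, 0)))  # approx (115, 0, 140, 255): 45% red + 55% blue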
 
 
 
  def draw_crosshair(img: Image.Image, x: int, y: int, radius: int = None) -> Image.Image:
      r = radius if radius is not None else max(2, PATCH_SIZE // 2)

      return (x0, y0, x1, y1)

  # ----------------------------
+ # Feature Extraction
  # ----------------------------
  @torch.inference_mode()
+ def extract_image_features(image_pil: Image.Image, target_long_side: int):
      t = resize_to_grid(image_pil, target_long_side, PATCH_SIZE)
      t_norm = TF.normalize(t, IMAGENET_MEAN, IMAGENET_STD).unsqueeze(0).to(DEVICE)
      _, _, H, W = t_norm.shape
      Hp, Wp = H // PATCH_SIZE, W // PATCH_SIZE
+
      outputs = model(t_norm)
+
      n_special_tokens = 5
      patch_embeddings = outputs.last_hidden_state.squeeze(0)[n_special_tokens:, :]
+
      X = F.normalize(patch_embeddings, p=2, dim=-1)

      img_resized = TF.to_pil_image(t)
      return {"X": X, "Hp": Hp, "Wp": Wp, "img": img_resized}
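`last_hidden_state` has shape `(1, n_special + Hp * Wp, hidden_dim)`: DINOv3 prepends one `CLS` token and four register tokens to the patch tokens, hence `n_special_tokens = 5`. A quick sanity check of the token arithmetic, assuming a 768×512 input:

# 768x512 input with 16x16 patches:
Hp, Wp = 512 // 16, 768 // 16          # 32, 48
n_tokens = 5 + Hp * Wp                 # 5 special + 1536 patch tokens
assert n_tokens == 1541
# After slicing, patch_embeddings has shape (Hp * Wp, hidden_dim);
# flat row index r * Wp + c corresponds to grid position (r, c).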

  # ----------------------------
+ # Similarity Logic
  # ----------------------------
  def click_to_similarity_in_same_image(
      state: dict,
 
  ):
      if not state:
          return None, None, None, None

      X = state["X"]
      Hp, Wp = state["Hp"], state["Wp"]
      base_img = state["img"]
      img_w, img_h = base_img.size

      x_pix, y_pix = click_xy
      col = int(np.clip(x_pix // PATCH_SIZE, 0, Wp - 1))
      row = int(np.clip(y_pix // PATCH_SIZE, 0, Hp - 1))
      idx = row * Wp + col

      q = X[idx]
      sims = torch.matmul(X, q)
      sim_map = sims.view(Hp, Wp)

      if exclude_radius_patches > 0:
          rr, cc = torch.meshgrid(
              torch.arange(Hp, device=sims.device),

          )
          mask = (torch.abs(rr - row) <= exclude_radius_patches) & (torch.abs(cc - col) <= exclude_radius_patches)
          sim_map = sim_map.masked_fill(mask, float("-inf"))

      sim_up = F.interpolate(
          sim_map.unsqueeze(0).unsqueeze(0),
          size=(img_h, img_w),
          mode="bicubic",
          align_corners=False,
      ).squeeze().detach().cpu().numpy()

      heatmap_pil = colorize(sim_up, cmap_name)
      overlay_pil = blend(base_img, heatmap_pil, alpha=alpha)

      overlay_boxes_pil = overlay_pil
      if topk and topk > 0:
          flat = sim_map.view(-1)

          for r, c in [divmod(j.item(), Wp) for j in top_idx]
          ]
          overlay_boxes_pil = draw_boxes(overlay_pil, boxes, outline="yellow", width=3, labels=True)

      marked_ref = draw_crosshair(base_img, x_pix, y_pix, radius=PATCH_SIZE // 2)
      return marked_ref, heatmap_pil, overlay_pil, overlay_boxes_pil
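Since the rows of `X` are L2-normalized, `torch.matmul(X, q)` yields the cosine similarity between the clicked patch and every patch in the grid. The click-to-index arithmetic, worked through for an assumed 768×512 processed image (`Hp, Wp = 32, 48`):

import numpy as np

PATCH_SIZE, Hp, Wp = 16, 32, 48
x_pix, y_pix = 200, 90                               # click position in pixels
col = int(np.clip(x_pix // PATCH_SIZE, 0, Wp - 1))   # 200 // 16 = 12
row = int(np.clip(y_pix // PATCH_SIZE, 0, Hp - 1))   # 90 // 16 = 5
idx = row * Wp + col                                 # 5 * 48 + 12 = 252
assert (row, col, idx) == (5, 12, 252)
# divmod inverts the flattening when top-k indices are mapped back to the grid:
assert divmod(idx, Wp) == (5, 12)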

  # ----------------------------
  # Gradio UI
  # ----------------------------
+ with gr.Blocks(theme=gr.themes.Soft(), title="DINOv3 Patch Similarity") as demo:
+     gr.Markdown("# 🦖 DINOv3: Visualizing Patch Similarity")
+     gr.Markdown(
+         "Upload an image, then **click anywhere** on it to find the most visually similar regions. "
+         "**Note:** If running on a CPU-only Space, feature extraction after uploading an image can take a moment."
+     )
+
      app_state = gr.State()
+
      with gr.Row():
+         with gr.Column(scale=2):
              input_image = gr.Image(
+                 label="Image (click anywhere)",
                  type="pil",
                  value="https://images.squarespace-cdn.com/content/v1/607f89e638219e13eee71b1e/1684821560422-SD5V37BAG28BURTLIXUQ/michael-sum-LEpfefQf4rU-unsplash.jpg"
              )
+             with gr.Accordion("⚙️ Visualization Controls", open=True):
+                 target_long_side = gr.Slider(
+                     minimum=224, maximum=1024, value=768, step=16,
+                     label="Processing Resolution",
+                     info="Higher values = more detail but slower processing",
+                 )
+                 alpha = gr.Slider(0.0, 1.0, value=0.55, step=0.05, label="Overlay Opacity")
                  cmap = gr.Dropdown(
                      ["viridis", "magma", "plasma", "inferno", "turbo", "cividis"],
+                     value="viridis", label="Heatmap Colormap",
                  )
+             with gr.Accordion("⚙️ Similarity Controls", open=True):
+                 exclude_r = gr.Slider(0, 10, value=0, step=1, label="Exclude Radius (patches)", info="Ignore patches around the click point.")
+                 topk = gr.Slider(0, 50, value=10, step=1, label="Top-K Boxes", info="Number of similar regions to highlight.")
+                 box_radius = gr.Slider(0, 10, value=1, step=1, label="Box Radius (patches)", info="Size of the highlight box.")
+
+         with gr.Column(scale=3):
+             marked_image = gr.Image(label="Your Click (on processed image)", interactive=False)
+             with gr.Tabs():
+                 with gr.TabItem("📦 Bounding Boxes"):
+                     overlay_boxes_output = gr.Image(label="Overlay + Top-K Similar Patches", interactive=False)
+                 with gr.TabItem("🔥 Heatmap"):
+                     heatmap_output = gr.Image(label="Similarity Heatmap", interactive=False)
+                 with gr.TabItem(" blended"):
+                     overlay_output = gr.Image(label="Blended Overlay (Image + Heatmap)", interactive=False)
+
+     def _on_upload_or_slider_change(img: Image.Image, long_side: int, progress=gr.Progress(track_tqdm=True)):
          if img is None:
              return None, None, None, None, None
+         progress(0, desc="🦖 Extracting DINOv3 features...")
+         st = extract_image_features(img, int(long_side))
+         progress(1, desc="✅ Done!")
+         # Clear old results when a new image is uploaded
+         return st["img"], st, None, None, None

      def _on_click(st, a: float, m: str, excl: int, k: int, box_rad: int, evt: gr.SelectData):
          if not st or evt is None:
+             # Return the current image, if any, when there is no click data
+             return (st or {}).get("img"), None, None, None
+
+         marked, heat, overlay, boxes = click_to_similarity_in_same_image(
              st, click_xy=evt.index, exclude_radius_patches=int(excl),
              topk=int(k), alpha=float(a), cmap_name=m,
              box_radius_patches=int(box_rad),
          )
+         return marked, heat, overlay, boxes

+     # Wire events
+     inputs_for_update = [input_image, target_long_side]
+     outputs_for_upload = [marked_image, app_state, heatmap_output, overlay_output, overlay_boxes_output]

+     input_image.upload(_on_upload_or_slider_change, inputs=inputs_for_update, outputs=outputs_for_upload)
+     target_long_side.change(_on_upload_or_slider_change, inputs=inputs_for_update, outputs=outputs_for_upload)
+     demo.load(_on_upload_or_slider_change, inputs=inputs_for_update, outputs=outputs_for_upload)
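All three triggers (`upload`, the resolution slider's `change`, and `demo.load` for the default image) share one handler, and Gradio maps the returned tuple positionally onto the `outputs` list, so every return path must yield one value per output component. A sketch of that contract for the five outputs wired above:

# One value per entry in outputs_for_upload, on every return path:
#   st["img"] -> marked_image, st -> app_state, then three Nones to clear
#   heatmap_output, overlay_output, overlay_boxes_output.
def handler_contract_sketch(img, long_side):
    if img is None:
        return None, None, None, None, None
    st = extract_image_features(img, int(long_side))
    return st["img"], st, None, None, None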
 
 

      marked_image.select(
          _on_click,