Spaces:
Running
Running
Update app.py
Browse files
app.py
CHANGED
@@ -1,18 +1,21 @@
|
|
1 |
-
|
2 |
import torch
|
|
|
3 |
import gradio as gr
|
4 |
import numpy as np
|
5 |
from PIL import Image
|
6 |
import torchvision.transforms.functional as TF
|
7 |
from matplotlib import colormaps
|
8 |
from transformers import AutoModel
|
9 |
-
import os
|
10 |
|
11 |
# ----------------------------
|
12 |
# Configuration
|
13 |
# ----------------------------
|
14 |
-
#
|
15 |
-
|
|
|
|
|
|
|
16 |
PATCH_SIZE = 16
|
17 |
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
|
18 |
|
@@ -21,32 +24,49 @@ IMAGENET_MEAN = (0.485, 0.456, 0.406)
|
|
21 |
IMAGENET_STD = (0.229, 0.224, 0.225)
|
22 |
|
23 |
# ----------------------------
|
24 |
-
# Model Loading (
|
25 |
# ----------------------------
|
26 |
-
|
27 |
-
|
28 |
-
|
|
|
|
|
|
|
|
|
29 |
try:
|
30 |
-
|
31 |
-
|
32 |
-
|
33 |
-
|
34 |
-
|
35 |
-
print(f"β
Model loaded successfully on device: {DEVICE}")
|
36 |
-
return model
|
37 |
except Exception as e:
|
38 |
-
print(f"β Failed to load model: {e}")
|
39 |
-
# This will display a clear error message in the Gradio interface
|
40 |
raise gr.Error(
|
41 |
-
f"Could not load model '{
|
42 |
-
"
|
43 |
-
"and set
|
44 |
f"Original error: {e}"
|
45 |
)
|
46 |
|
47 |
-
|
48 |
-
model
|
49 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
50 |
# ----------------------------
|
51 |
# Helper Functions
|
52 |
# ----------------------------
|
@@ -85,9 +105,11 @@ def generate_pca_visuals(
|
|
85 |
resolution: int,
|
86 |
cmap_name: str,
|
87 |
overlay_alpha: float,
|
|
|
88 |
progress=gr.Progress(track_tqdm=True)
|
89 |
):
|
90 |
"""Main function to generate PCA visuals."""
|
|
|
91 |
if model is None:
|
92 |
raise gr.Error("DINOv3 model is not available. Check the startup logs.")
|
93 |
if image_pil is None:
|
@@ -105,9 +127,8 @@ def generate_pca_visuals(
|
|
105 |
progress(0.5, desc="π¦ Extracting features with DINOv3...")
|
106 |
outputs = model(t_norm)
|
107 |
|
108 |
-
#
|
109 |
-
|
110 |
-
n_special_tokens = 5 # 1 [CLS] token + 4 register tokens for ViT-H/16+
|
111 |
patch_embeddings = outputs.last_hidden_state.squeeze(0)[n_special_tokens:, :]
|
112 |
|
113 |
# 3. PCA Calculation
|
@@ -115,8 +136,7 @@ def generate_pca_visuals(
|
|
115 |
X_centered = patch_embeddings.float() - patch_embeddings.float().mean(0, keepdim=True)
|
116 |
U, S, V = torch.pca_lowrank(X_centered, q=3, center=False)
|
117 |
|
118 |
-
#
|
119 |
-
# This prevents the colors from randomly inverting on different runs.
|
120 |
for i in range(V.shape[1]):
|
121 |
max_abs_idx = torch.argmax(torch.abs(V[:, i]))
|
122 |
if V[max_abs_idx, i] < 0:
|
@@ -134,7 +154,6 @@ def generate_pca_visuals(
|
|
134 |
)
|
135 |
|
136 |
# 5. Create Visualizations
|
137 |
-
# This part should now work correctly as `scores` has the right shape (Hp*Wp, 3)
|
138 |
pc1_map = scores[:, 0].reshape(Hp, Wp).cpu().numpy()
|
139 |
pc1_image_raw = colorize(pc1_map, cmap_name)
|
140 |
|
@@ -155,10 +174,10 @@ def generate_pca_visuals(
|
|
155 |
# ----------------------------
|
156 |
# Gradio Interface
|
157 |
# ----------------------------
|
158 |
-
with gr.Blocks(theme=gr.themes.Soft(), title="
|
159 |
gr.Markdown(
|
160 |
"""
|
161 |
-
#
|
162 |
Upload an image to visualize the principal components of its patch features.
|
163 |
This reveals the main axes of semantic variation within the image as understood by the model.
|
164 |
"""
|
@@ -166,7 +185,6 @@ with gr.Blocks(theme=gr.themes.Soft(), title="Running on CPU so please wait π¦
|
|
166 |
|
167 |
with gr.Row():
|
168 |
with gr.Column(scale=2):
|
169 |
-
# Added a default image URL for convenience
|
170 |
input_image = gr.Image(type="pil", label="Upload Image", value="https://images.squarespace-cdn.com/content/v1/607f89e638219e13eee71b1e/1684821560422-SD5V37BAG28BURTLIXUQ/michael-sum-LEpfefQf4rU-unsplash.jpg")
|
171 |
|
172 |
with gr.Accordion("βοΈ Visualization Controls", open=True):
|
@@ -175,6 +193,12 @@ with gr.Blocks(theme=gr.themes.Soft(), title="Running on CPU so please wait π¦
|
|
175 |
label="Processing Resolution",
|
176 |
info="Higher values capture more detail but are slower."
|
177 |
)
|
|
|
|
|
|
|
|
|
|
|
|
|
178 |
cmap_dropdown = gr.Dropdown(
|
179 |
['viridis', 'magma', 'inferno', 'plasma', 'cividis', 'jet'],
|
180 |
value='viridis',
|
@@ -201,7 +225,7 @@ with gr.Blocks(theme=gr.themes.Soft(), title="Running on CPU so please wait π¦
|
|
201 |
|
202 |
run_button.click(
|
203 |
fn=generate_pca_visuals,
|
204 |
-
inputs=[input_image, resolution_slider, cmap_dropdown, alpha_slider],
|
205 |
outputs=[output_pc1, output_rgb, output_variance, output_blended, output_processed]
|
206 |
)
|
207 |
|
|
|
1 |
+
import os
|
2 |
import torch
|
3 |
+
import torch.nn.functional as F
|
4 |
import gradio as gr
|
5 |
import numpy as np
|
6 |
from PIL import Image
|
7 |
import torchvision.transforms.functional as TF
|
8 |
from matplotlib import colormaps
|
9 |
from transformers import AutoModel
|
|
|
10 |
|
11 |
# ----------------------------
|
12 |
# Configuration
|
13 |
# ----------------------------
|
14 |
+
# Define available models
|
15 |
+
DEFAULT_MODEL_ID = "facebook/dinov3-vits16plus-pretrain-lvd1689m"
|
16 |
+
ALT_MODEL_ID = "facebook/dinov3-vith16plus-pretrain-lvd1689m"
|
17 |
+
AVAILABLE_MODELS = [DEFAULT_MODEL_ID, ALT_MODEL_ID]
|
18 |
+
|
19 |
PATCH_SIZE = 16
|
20 |
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
|
21 |
|
|
|
24 |
IMAGENET_STD = (0.229, 0.224, 0.225)
|
25 |
|
26 |
# ----------------------------
|
27 |
+
# Model Loading (with caching)
|
28 |
# ----------------------------
|
29 |
+
_model_cache = {}
|
30 |
+
_current_model_id = None
|
31 |
+
model = None # global reference
|
32 |
+
|
33 |
+
def load_model_from_hub(model_id: str):
|
34 |
+
"""Loads a DINOv3 model from the Hugging Face Hub."""
|
35 |
+
print(f"Loading model '{model_id}' from Hugging Face Hub...")
|
36 |
try:
|
37 |
+
token = os.environ.get("HF_TOKEN") # optional, for gated models
|
38 |
+
mdl = AutoModel.from_pretrained(model_id, token=token, trust_remote_code=True)
|
39 |
+
mdl.to(DEVICE).eval()
|
40 |
+
print(f"β
Model '{model_id}' loaded successfully on device: {DEVICE}")
|
41 |
+
return mdl
|
|
|
|
|
42 |
except Exception as e:
|
43 |
+
print(f"β Failed to load model '{model_id}': {e}")
|
|
|
44 |
raise gr.Error(
|
45 |
+
f"Could not load model '{model_id}'. "
|
46 |
+
"If the model is gated, please accept the terms on its Hugging Face page "
|
47 |
+
"and set HF_TOKEN in your environment. "
|
48 |
f"Original error: {e}"
|
49 |
)
|
50 |
|
51 |
+
def get_model(model_id: str):
|
52 |
+
"""Return a cached model if available, otherwise load and cache it."""
|
53 |
+
if model_id in _model_cache:
|
54 |
+
return _model_cache[model_id]
|
55 |
+
mdl = load_model_from_hub(model_id)
|
56 |
+
_model_cache[model_id] = mdl
|
57 |
+
return mdl
|
58 |
+
|
59 |
+
# Load the default model at startup
|
60 |
+
model = get_model(DEFAULT_MODEL_ID)
|
61 |
+
_current_model_id = DEFAULT_MODEL_ID
|
62 |
+
|
63 |
+
def _ensure_model(model_id: str):
|
64 |
+
"""Ensure the global 'model' matches the dropdown selection."""
|
65 |
+
global model, _current_model_id
|
66 |
+
if model_id != _current_model_id:
|
67 |
+
model = get_model(model_id)
|
68 |
+
_current_model_id = model_id
|
69 |
+
|
70 |
# ----------------------------
|
71 |
# Helper Functions
|
72 |
# ----------------------------
|
|
|
105 |
resolution: int,
|
106 |
cmap_name: str,
|
107 |
overlay_alpha: float,
|
108 |
+
model_id: str,
|
109 |
progress=gr.Progress(track_tqdm=True)
|
110 |
):
|
111 |
"""Main function to generate PCA visuals."""
|
112 |
+
_ensure_model(model_id)
|
113 |
if model is None:
|
114 |
raise gr.Error("DINOv3 model is not available. Check the startup logs.")
|
115 |
if image_pil is None:
|
|
|
127 |
progress(0.5, desc="π¦ Extracting features with DINOv3...")
|
128 |
outputs = model(t_norm)
|
129 |
|
130 |
+
# The model output includes a [CLS] token AND 4 register tokens.
|
131 |
+
n_special_tokens = 5
|
|
|
132 |
patch_embeddings = outputs.last_hidden_state.squeeze(0)[n_special_tokens:, :]
|
133 |
|
134 |
# 3. PCA Calculation
|
|
|
136 |
X_centered = patch_embeddings.float() - patch_embeddings.float().mean(0, keepdim=True)
|
137 |
U, S, V = torch.pca_lowrank(X_centered, q=3, center=False)
|
138 |
|
139 |
+
# Stabilize the signs of the eigenvectors for deterministic output.
|
|
|
140 |
for i in range(V.shape[1]):
|
141 |
max_abs_idx = torch.argmax(torch.abs(V[:, i]))
|
142 |
if V[max_abs_idx, i] < 0:
|
|
|
154 |
)
|
155 |
|
156 |
# 5. Create Visualizations
|
|
|
157 |
pc1_map = scores[:, 0].reshape(Hp, Wp).cpu().numpy()
|
158 |
pc1_image_raw = colorize(pc1_map, cmap_name)
|
159 |
|
|
|
174 |
# ----------------------------
|
175 |
# Gradio Interface
|
176 |
# ----------------------------
|
177 |
+
with gr.Blocks(theme=gr.themes.Soft(), title="π¦ DINOv3 PCA Explorer") as demo:
|
178 |
gr.Markdown(
|
179 |
"""
|
180 |
+
# π¦ DINOv3 PCA Explorer
|
181 |
Upload an image to visualize the principal components of its patch features.
|
182 |
This reveals the main axes of semantic variation within the image as understood by the model.
|
183 |
"""
|
|
|
185 |
|
186 |
with gr.Row():
|
187 |
with gr.Column(scale=2):
|
|
|
188 |
input_image = gr.Image(type="pil", label="Upload Image", value="https://images.squarespace-cdn.com/content/v1/607f89e638219e13eee71b1e/1684821560422-SD5V37BAG28BURTLIXUQ/michael-sum-LEpfefQf4rU-unsplash.jpg")
|
189 |
|
190 |
with gr.Accordion("βοΈ Visualization Controls", open=True):
|
|
|
193 |
label="Processing Resolution",
|
194 |
info="Higher values capture more detail but are slower."
|
195 |
)
|
196 |
+
model_choice = gr.Dropdown(
|
197 |
+
choices=AVAILABLE_MODELS,
|
198 |
+
value=DEFAULT_MODEL_ID,
|
199 |
+
label="Backbone (DINOv3)",
|
200 |
+
info="ViT-S/16+ is smaller & faster; ViT-H/16+ is larger.",
|
201 |
+
)
|
202 |
cmap_dropdown = gr.Dropdown(
|
203 |
['viridis', 'magma', 'inferno', 'plasma', 'cividis', 'jet'],
|
204 |
value='viridis',
|
|
|
225 |
|
226 |
run_button.click(
|
227 |
fn=generate_pca_visuals,
|
228 |
+
inputs=[input_image, resolution_slider, cmap_dropdown, alpha_slider, model_choice],
|
229 |
outputs=[output_pc1, output_rgb, output_variance, output_blended, output_processed]
|
230 |
)
|
231 |
|