sayedM committed on
Commit a6e4f18 · verified · 1 Parent(s): 5e20a05

Update app.py

Files changed (1)
  1. app.py +57 -33
app.py CHANGED
@@ -1,18 +1,21 @@
-# app.py
+import os
 import torch
+import torch.nn.functional as F
 import gradio as gr
 import numpy as np
 from PIL import Image
 import torchvision.transforms.functional as TF
 from matplotlib import colormaps
 from transformers import AutoModel
-import os
 
 # ----------------------------
 # Configuration
 # ----------------------------
-# 💡 FIX: Use the full, correct model ID from the Hugging Face Hub.
-MODEL_ID = "facebook/dinov3-vith16plus-pretrain-lvd1689m"
+# Define available models
+DEFAULT_MODEL_ID = "facebook/dinov3-vits16plus-pretrain-lvd1689m"
+ALT_MODEL_ID = "facebook/dinov3-vith16plus-pretrain-lvd1689m"
+AVAILABLE_MODELS = [DEFAULT_MODEL_ID, ALT_MODEL_ID]
+
 PATCH_SIZE = 16
 DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
 
@@ -21,32 +24,49 @@ IMAGENET_MEAN = (0.485, 0.456, 0.406)
 IMAGENET_STD = (0.229, 0.224, 0.225)
 
 # ----------------------------
-# Model Loading (runs once at startup)
+# Model Loading (with caching)
 # ----------------------------
-def load_model_from_hub():
-    """Loads the DINOv3 model from the Hugging Face Hub."""
-    print(f"Loading model '{MODEL_ID}' from Hugging Face Hub...")
+_model_cache = {}
+_current_model_id = None
+model = None  # global reference
+
+def load_model_from_hub(model_id: str):
+    """Loads a DINOv3 model from the Hugging Face Hub."""
+    print(f"Loading model '{model_id}' from Hugging Face Hub...")
     try:
-        # This will use the HF_TOKEN secret if you set it in your Space settings.
-        token = os.environ.get("HF_TOKEN")
-        # trust_remote_code is necessary for DINOv3
-        model = AutoModel.from_pretrained(MODEL_ID, token=token, trust_remote_code=True)
-        model.to(DEVICE).eval()
-        print(f"✅ Model loaded successfully on device: {DEVICE}")
-        return model
+        token = os.environ.get("HF_TOKEN")  # optional, for gated models
+        mdl = AutoModel.from_pretrained(model_id, token=token, trust_remote_code=True)
+        mdl.to(DEVICE).eval()
+        print(f"✅ Model '{model_id}' loaded successfully on device: {DEVICE}")
+        return mdl
     except Exception as e:
-        print(f"❌ Failed to load model: {e}")
-        # This will display a clear error message in the Gradio interface
+        print(f"❌ Failed to load model '{model_id}': {e}")
         raise gr.Error(
-            f"Could not load model '{MODEL_ID}'. "
-            "This is a gated model. Please ensure you have accepted the terms on its Hugging Face page "
-            "and set your HF_TOKEN as a secret in your Space settings. "
+            f"Could not load model '{model_id}'. "
+            "If the model is gated, please accept the terms on its Hugging Face page "
+            "and set HF_TOKEN in your environment. "
             f"Original error: {e}"
         )
 
-# Load the model globally when the app starts
-model = load_model_from_hub()
-
+def get_model(model_id: str):
+    """Return a cached model if available, otherwise load and cache it."""
+    if model_id in _model_cache:
+        return _model_cache[model_id]
+    mdl = load_model_from_hub(model_id)
+    _model_cache[model_id] = mdl
+    return mdl
+
+# Load the default model at startup
+model = get_model(DEFAULT_MODEL_ID)
+_current_model_id = DEFAULT_MODEL_ID
+
+def _ensure_model(model_id: str):
+    """Ensure the global 'model' matches the dropdown selection."""
+    global model, _current_model_id
+    if model_id != _current_model_id:
+        model = get_model(model_id)
+        _current_model_id = model_id
+
 # ----------------------------
 # Helper Functions
 # ----------------------------
@@ -85,9 +105,11 @@ def generate_pca_visuals(
     resolution: int,
     cmap_name: str,
     overlay_alpha: float,
+    model_id: str,
     progress=gr.Progress(track_tqdm=True)
 ):
     """Main function to generate PCA visuals."""
+    _ensure_model(model_id)
     if model is None:
         raise gr.Error("DINOv3 model is not available. Check the startup logs.")
     if image_pil is None:
@@ -105,9 +127,8 @@ def generate_pca_visuals(
     progress(0.5, desc="🦖 Extracting features with DINOv3...")
     outputs = model(t_norm)
 
-    # 💡 FIX: The model output includes a [CLS] token AND 4 register tokens.
-    # We must skip all of them (total 5) to get only the patch embeddings.
-    n_special_tokens = 5  # 1 [CLS] token + 4 register tokens for ViT-H/16+
+    # The model output includes a [CLS] token AND 4 register tokens.
+    n_special_tokens = 5
     patch_embeddings = outputs.last_hidden_state.squeeze(0)[n_special_tokens:, :]
 
     # 3. PCA Calculation
@@ -115,8 +136,7 @@ def generate_pca_visuals(
     X_centered = patch_embeddings.float() - patch_embeddings.float().mean(0, keepdim=True)
     U, S, V = torch.pca_lowrank(X_centered, q=3, center=False)
 
-    # 💡 IMPROVEMENT: Stabilize the signs of the eigenvectors for deterministic output.
-    # This prevents the colors from randomly inverting on different runs.
+    # Stabilize the signs of the eigenvectors for deterministic output.
    for i in range(V.shape[1]):
         max_abs_idx = torch.argmax(torch.abs(V[:, i]))
         if V[max_abs_idx, i] < 0:
@@ -134,7 +154,6 @@ def generate_pca_visuals(
     )
 
     # 5. Create Visualizations
-    # This part should now work correctly as `scores` has the right shape (Hp*Wp, 3)
     pc1_map = scores[:, 0].reshape(Hp, Wp).cpu().numpy()
     pc1_image_raw = colorize(pc1_map, cmap_name)
 
@@ -155,10 +174,10 @@
 # ----------------------------
 # Gradio Interface
 # ----------------------------
-with gr.Blocks(theme=gr.themes.Soft(), title="Running on CPU so please wait 🦖 DINOv3 PCA Explorer") as demo:
+with gr.Blocks(theme=gr.themes.Soft(), title="🦖 DINOv3 PCA Explorer") as demo:
     gr.Markdown(
         """
-        # Running on CPU so please wait 🦖 DINOv3 PCA Explorer
+        # 🦖 DINOv3 PCA Explorer
        Upload an image to visualize the principal components of its patch features.
        This reveals the main axes of semantic variation within the image as understood by the model.
        """
@@ -166,7 +185,6 @@ with gr.Blocks(theme=gr.themes.Soft(), title="Running on CPU so please wait 🦖
 
     with gr.Row():
         with gr.Column(scale=2):
-            # Added a default image URL for convenience
             input_image = gr.Image(type="pil", label="Upload Image", value="https://images.squarespace-cdn.com/content/v1/607f89e638219e13eee71b1e/1684821560422-SD5V37BAG28BURTLIXUQ/michael-sum-LEpfefQf4rU-unsplash.jpg")
 
             with gr.Accordion("⚙️ Visualization Controls", open=True):
@@ -175,6 +193,12 @@ with gr.Blocks(theme=gr.themes.Soft(), title="Running on CPU so please wait 🦖
                     label="Processing Resolution",
                     info="Higher values capture more detail but are slower."
                 )
+                model_choice = gr.Dropdown(
+                    choices=AVAILABLE_MODELS,
+                    value=DEFAULT_MODEL_ID,
+                    label="Backbone (DINOv3)",
+                    info="ViT-S/16+ is smaller & faster; ViT-H/16+ is larger.",
+                )
                 cmap_dropdown = gr.Dropdown(
                     ['viridis', 'magma', 'inferno', 'plasma', 'cividis', 'jet'],
                     value='viridis',
@@ -201,7 +225,7 @@ with gr.Blocks(theme=gr.themes.Soft(), title="Running on CPU so please wait 🦖
 
     run_button.click(
         fn=generate_pca_visuals,
-        inputs=[input_image, resolution_slider, cmap_dropdown, alpha_slider],
+        inputs=[input_image, resolution_slider, cmap_dropdown, alpha_slider, model_choice],
         outputs=[output_pc1, output_rgb, output_variance, output_blended, output_processed]
     )
 
231