sayedM committed on
Commit
364c029
·
verified ·
1 Parent(s): e55e82c

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +194 -0
app.py ADDED
@@ -0,0 +1,194 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import gradio as gr
3
+ import numpy as np
4
+ from PIL import Image
5
+ import torchvision.transforms.functional as TF
6
+ from matplotlib import colormaps
7
+ from transformers import AutoModel
8
+
9
+ # ----------------------------
10
+ # Configuration
11
+ # ----------------------------
12
# Checkpoint identifier; the weights are downloaded from the Hugging Face Hub.
MODEL_ID = "facebook/dinov3-vith16plus-pretrain-lvd1689m"
# Patch size in pixels: image sides are rounded to multiples of this value
# and the patch grid is computed as (height // PATCH_SIZE, width // PATCH_SIZE).
PATCH_SIZE = 16
# Prefer the GPU when one is available, otherwise fall back to the CPU.
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

# Normalization constants (per-channel mean and standard deviation)
IMAGENET_MEAN = (0.485, 0.456, 0.406)
IMAGENET_STD = (0.229, 0.224, 0.225)
20
+
21
+ # ----------------------------
22
+ # Model Loading (runs once at startup)
23
+ # ----------------------------
24
def load_model_from_hub():
    """Load the DINOv3 model from the Hugging Face Hub.

    Returns:
        The model, moved to ``DEVICE`` and switched to eval mode, or ``None``
        when loading fails for any reason. The failure is printed; callers
        must handle a ``None`` result themselves.
    """
    print(f"Loading model '{MODEL_ID}' from Hugging Face Hub...")
    try:
        loaded = AutoModel.from_pretrained(MODEL_ID)
        loaded.to(DEVICE).eval()
    except Exception as e:
        # Startup boundary: report the problem and keep the app importable.
        # A bare ``gr.Error(...)`` expression used to sit here; it only built
        # an exception object without raising it, so it had no effect and
        # has been removed.
        print(f"❌ Failed to load model: {e}")
        return None
    print(f"✅ Model loaded successfully on device: {DEVICE}")
    return loaded


# Load the model globally when the app starts
model = load_model_from_hub()
39
+
40
+ # ----------------------------
41
+ # Helper Functions
42
+ # ----------------------------
43
def resize_to_grid(img: Image.Image, long_side: int, patch: int) -> torch.Tensor:
    """Resize an image to dimensions that are multiples of the patch size.

    The image is scaled so that its longer side is approximately
    ``long_side`` pixels, then each side is rounded *up* to the next multiple
    of ``patch`` (so the result may slightly exceed ``long_side`` and the
    aspect ratio may shift by less than one patch).

    Args:
        img: Source image; it is converted to RGB before resizing.
        long_side: Target length in pixels of the longer image side.
        patch: Patch size in pixels; both output sides are multiples of it.

    Returns:
        The resized image as produced by ``TF.to_tensor``.

    Raises:
        ValueError: If ``long_side`` or ``patch`` is not positive, or the
            image has a zero-length side.
    """
    if long_side <= 0 or patch <= 0:
        raise ValueError("long_side and patch must be positive")

    w, h = img.size
    if w <= 0 or h <= 0:
        # Previously surfaced as an opaque ZeroDivisionError below.
        raise ValueError("image must have non-zero width and height")

    scale = long_side / max(h, w)
    new_h = max(patch, int(round(h * scale)))
    new_w = max(patch, int(round(w * scale)))

    # Round each side up to the nearest multiple of the patch size.
    new_h = ((new_h + patch - 1) // patch) * patch
    new_w = ((new_w + patch - 1) // patch) * patch

    return TF.to_tensor(TF.resize(img.convert("RGB"), (new_h, new_w)))
54
+
55
def colorize(data: np.ndarray, cmap_name: str = 'viridis') -> Image.Image:
    """Convert a 2D numpy array to a colored PIL image.

    Args:
        data: Array of scalar values; it is min-max normalized to [0, 1]
            before the colormap is applied.
        cmap_name: Name of the matplotlib colormap to look up.

    Returns:
        An RGB image with one pixel per array element.
    """
    x = data.astype(np.float32)
    # The small epsilon keeps a constant array from dividing by zero.
    x = (x - x.min()) / (x.max() - x.min() + 1e-8)
    cmap = colormaps.get_cmap(cmap_name)
    # The colormap returns RGBA in [0, 1]; drop alpha and scale to 8-bit.
    rgb = (cmap(x)[..., :3] * 255).astype(np.uint8)
    return Image.fromarray(rgb)
62
+
63
def blend(base: Image.Image, heat: Image.Image, alpha: float) -> Image.Image:
    """Blend a heatmap onto a base image.

    Args:
        base: Background image.
        heat: Heatmap image; it is resized to ``base``'s size when the two
            differ, because ``Image.blend`` rejects mismatched sizes.
        alpha: Interpolation factor passed to ``Image.blend``; 0 yields the
            base image, 1 yields the heatmap.

    Returns:
        The blended RGBA image.
    """
    base = base.convert("RGBA")
    heat = heat.convert("RGBA")
    if heat.size != base.size:
        heat = heat.resize(base.size, Image.Resampling.BICUBIC)
    return Image.blend(base, heat, alpha=alpha)
68
+
69
+ # ----------------------------
70
+ # Core Gradio Function
71
+ # ----------------------------
72
@torch.inference_mode()
def generate_pca_visuals(
    image_pil: Image.Image,
    resolution: int,
    cmap_name: str,
    overlay_alpha: float,
    progress=gr.Progress(track_tqdm=True)
):
    """Generate PCA visualizations of the model's patch features for an image.

    Args:
        image_pil: Input image, or ``None`` when nothing has been uploaded.
        resolution: Target length of the longer image side for processing.
        cmap_name: Matplotlib colormap name used for the PC1 heatmap.
        overlay_alpha: Blend factor of the heatmap over the processed image.
        progress: Gradio progress tracker.

    Returns:
        A 5-tuple ``(pc1_heatmap, pc_rgb_image, variance_markdown,
        blended_overlay, processed_image)``. When ``image_pil`` is ``None``
        every image slot is ``None`` and the text asks for an upload.

    Raises:
        gr.Error: If the global model failed to load, or the model returns
            fewer tokens than there are image patches.
    """
    if model is None:
        raise gr.Error("DINOv3 model could not be loaded. Check the logs.")
    if image_pil is None:
        return None, None, "Please upload an image and click Generate.", None, None

    # 1. Image Preprocessing
    progress(0.2, desc="Resizing and preprocessing image...")
    image_tensor = resize_to_grid(image_pil, resolution, PATCH_SIZE)
    t_norm = TF.normalize(image_tensor, IMAGENET_MEAN, IMAGENET_STD).unsqueeze(0).to(DEVICE)
    original_processed_image = TF.to_pil_image(image_tensor)
    _, _, H, W = t_norm.shape
    Hp, Wp = H // PATCH_SIZE, W // PATCH_SIZE
    num_patches = Hp * Wp

    # 2. Feature Extraction
    progress(0.5, desc="🦖 Extracting features with DINOv3...")
    outputs = model(t_norm)
    tokens = outputs.last_hidden_state.squeeze(0)
    # The sequence starts with non-patch tokens (CLS and, per the DINOv3
    # model documentation, register tokens) followed by the patch tokens.
    # Dropping only the first token leaves the registers in place and breaks
    # the (Hp, Wp) reshape below, so skip however many prefix tokens exist.
    num_prefix = tokens.shape[0] - num_patches
    if num_prefix < 0:
        raise gr.Error(
            f"Model returned {tokens.shape[0]} tokens but {num_patches} patches were expected."
        )
    patch_embeddings = tokens[num_prefix:, :]

    # 3. PCA Calculation
    progress(0.8, desc="🔬 Performing PCA...")
    features = patch_embeddings.float()
    X_centered = features - features.mean(0, keepdim=True)
    # Data is already centered above, hence center=False.
    U, S, V = torch.pca_lowrank(X_centered, q=3, center=False)

    # Stabilize the signs of the eigenvectors for deterministic output:
    # flip each component so its largest-magnitude entry is positive.
    for i in range(V.shape[1]):
        max_abs_idx = torch.argmax(torch.abs(V[:, i]))
        if V[max_abs_idx, i] < 0:
            V[:, i] *= -1

    scores = X_centered @ V[:, :3]

    # 4. Explained Variance
    total_variance = (X_centered ** 2).sum()
    explained_variance = (S ** 2 / total_variance).tolist()
    variance_text = (
        f"**📊 Explained Variance Ratios:**\n\n"
        f"- **PC1:** {explained_variance[0]:.2%}\n"
        f"- **PC2:** {explained_variance[1]:.2%}\n"
        f"- **PC3:** {explained_variance[2]:.2%}"
    )

    # 5. Create Visualizations
    pc1_map = scores[:, 0].reshape(Hp, Wp).cpu().numpy()
    pc1_image_raw = colorize(pc1_map, cmap_name)
    pc_rgb_map = scores.reshape(Hp, Wp, 3).cpu().numpy()
    # Min-max normalize each principal component independently to [0, 1].
    min_vals = pc_rgb_map.reshape(-1, 3).min(axis=0)
    max_vals = pc_rgb_map.reshape(-1, 3).max(axis=0)
    pc_rgb_map = (pc_rgb_map - min_vals) / (max_vals - min_vals + 1e-8)
    pc_rgb_image_raw = Image.fromarray((pc_rgb_map * 255).astype(np.uint8))

    # Upsample the patch-resolution maps to the processed image size.
    target_size = original_processed_image.size
    pc1_image_smooth = pc1_image_raw.resize(target_size, Image.Resampling.BICUBIC)
    pc_rgb_image_smooth = pc_rgb_image_raw.resize(target_size, Image.Resampling.BICUBIC)
    blended_image = blend(original_processed_image, pc1_image_smooth, overlay_alpha)

    progress(1.0, desc="✅ Done!")
    return pc1_image_smooth, pc_rgb_image_smooth, variance_text, blended_image, original_processed_image
139
+
140
+
141
+ # ----------------------------
142
+ # Gradio Interface
143
+ # ----------------------------
144
# Layout: a left column with the input image, controls and the run button,
# and a right column with two tabs of outputs. The button click is wired to
# generate_pca_visuals; the outputs list order matches its return order.
with gr.Blocks(theme=gr.themes.Soft(), title="DINOv3 PCA Explorer") as demo:
    gr.Markdown(
        """
        # 🦖 DINOv3 PCA Explorer
        Upload an image to visualize the principal components of its patch features.
        This reveals the main axes of semantic variation within the image as understood by the model.
        """
    )

    with gr.Row():
        with gr.Column(scale=2):
            input_image = gr.Image(type="pil", label="Upload Image", value="https://picsum.photos/id/1011/800/600")

            with gr.Accordion("⚙️ Visualization Controls", open=True):
                resolution_slider = gr.Slider(
                    minimum=224, maximum=1024, value=512, step=16,
                    label="Processing Resolution",
                    info="Higher values capture more detail but are slower."
                )
                cmap_dropdown = gr.Dropdown(
                    ['viridis', 'magma', 'inferno', 'plasma', 'cividis', 'jet'],
                    value='viridis',
                    label="Heatmap Colormap"
                )
                alpha_slider = gr.Slider(
                    minimum=0, maximum=1, value=0.5,
                    label="Overlay Opacity"
                )

            run_button = gr.Button("🚀 Generate PCA Visuals", variant="primary")

        with gr.Column(scale=3):
            with gr.Tabs():
                with gr.TabItem("🖼️ Overlay"):
                    gr.Markdown("Visualize the main heatmap blended with the original image.")
                    output_blended = gr.Image(label="PC1 Heatmap Overlay")
                    output_processed = gr.Image(label="Original Processed Image (at selected resolution)")
                with gr.TabItem("📊 PCA Outputs"):
                    gr.Markdown("View the raw outputs of the Principal Component Analysis.")
                    output_pc1 = gr.Image(label="PC1 Heatmap (Smoothed)")
                    output_rgb = gr.Image(label="Top 3 PCs as RGB (Smoothed)")
                    output_variance = gr.Markdown(label="Explained Variance")

    run_button.click(
        fn=generate_pca_visuals,
        inputs=[input_image, resolution_slider, cmap_dropdown, alpha_slider],
        outputs=[output_pc1, output_rgb, output_variance, output_blended, output_processed]
    )

if __name__ == "__main__":
    demo.launch()