Update app.py
app.py
CHANGED
@@ -1,634 +1,166 @@
 import torch
 import numpy as np
 import folium
-from folium.plugins import HeatMap
 import gradio as gr
 import os
 import PIL.Image
-from io import BytesIO
-import base64
 import json
 import time
-from typing import
 from pathlib import Path
 from datasets import Dataset, load_dataset, concatenate_datasets
 
 # GeoCLIP dependencies
 from geoclip import GeoCLIP
 from transformers import CLIPTokenizer, CLIPProcessor
-from huggingface_hub import HfApi
-
 
 class GeoCLIPCore:
-    """
-    Vectorized GeoCLIP implementation with HuggingFace Hub integration.
-
-    Implements tensor-optimized inference with persistent dataset storage:
-    1. Text-to-location prediction with confidence scoring
-    2. Image-to-location prediction with metadata extraction
-    3. Coordinate embedding generation for vector analysis
-    4. Cross-modal similarity computation
-    5. Dataset persistence to HuggingFace Hub
-    """
-
-    def __init__(self,
-                 device: Optional[str] = None,
-                 dataset_id: str = "latterworks/geo-metadata",
-                 token: Optional[str] = None) -> None:
-        """
-        Initialize model with optimal compute allocation and dataset connection.
-
-        Args:
-            device: Target compute device (None for auto-detection)
-            dataset_id: HuggingFace dataset identifier
-            token: HuggingFace API token
-        """
         self.device = device or ("cuda" if torch.cuda.is_available() else "cpu")
         self.dataset_id = dataset_id
         self.token = token
-
-        # Initialize HuggingFace API for dataset operations
-        self.api = HfApi(token=token)
-
-        # Load and configure core model components with vectorized execution path
         self._model = GeoCLIP().to(self.device)
         self._tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14")
         self._processor = CLIPProcessor.from_pretrained("openai/clip-vit-large-patch14")
-
-        # Cache frequently accessed components for reduced latency
         self._location_encoder = self._model.location_encoder
-        self.
-        self._gps_gallery = None  # Lazy-loaded on first prediction
 
-        # Initialize
-        self._initialize_dataset()
-
-        print(f"GeoCLIP initialized on {self.device} with Hub dataset: {dataset_id}")
-
-    def _initialize_dataset(self) -> None:
-        """Initialize connection to HuggingFace dataset with atomic transaction handling."""
         try:
-            # Attempt to load existing dataset
             self.dataset = load_dataset(self.dataset_id, split="train", token=self.token)
-
-                "filename": [],
-                "classes": [],
-                "metadata": []
-            })
-
-    def embed_text(self, text: str) -> torch.Tensor:
-        """
-        Generate normalized embedding for text input using vectorized operations.
-
-        Args:
-            text: Text description to encode
-
-        Returns:
-            L2-normalized embedding tensor (shape: [1, 512])
-        """
         with torch.no_grad():
             tokens = self._tokenizer(text, return_tensors="pt", padding=True).to(self.device)
             embedding = self._model.image_encoder.mlp(
                 self._model.image_encoder.CLIP.get_text_features(**tokens)
             )
-
-
-    def embed_image(self, image: Union[str, PIL.Image.Image, np.ndarray]) -> torch.Tensor:
-        """
-        Generate normalized embedding for image input using vectorized operations.
-
-        Args:
-            image: Input image (PIL Image, file path, or numpy array)
-
-        Returns:
-            L2-normalized embedding tensor (shape: [1, 512])
-        """
-        with torch.no_grad():
-            # Process different image input types with type-specific optimizations
-            if isinstance(image, str):
-                # Path to image file
-                image = PIL.Image.open(image).convert("RGB")
-            elif isinstance(image, np.ndarray):
-                # Convert numpy array to PIL Image with optimal memory layout
-                image = PIL.Image.fromarray(np.uint8(image)).convert("RGB")
-
-            # Process image using CLIP processor with tensor allocation
-            inputs = self._processor(images=image, return_tensors="pt").to(self.device)
-            embedding = self._model.image_encoder(inputs.pixel_values)
-            return torch.nn.functional.normalize(embedding, dim=1)
-
-    def embed_coordinates(self, coords: Tuple[float, float]) -> torch.Tensor:
-        """
-        Generate normalized embedding for geographical coordinates.
-
-        Args:
-            coords: Coordinate pair (latitude, longitude)
-
-        Returns:
-            L2-normalized embedding tensor (shape: [1, 512])
-        """
-        with torch.no_grad():
-            coords_tensor = torch.tensor([coords], dtype=torch.float32).to(self.device)
-            embedding = self._location_encoder(coords_tensor)
-            return torch.nn.functional.normalize(embedding, dim=1)
-
-    def _ensure_gps_gallery(self) -> None:
-        """Ensure GPS gallery is loaded and cached for efficient reuse."""
-        if self._gps_gallery is None:
-            self._gps_gallery = self._model.gps_gallery.to(self.device)
-
-    def predict_location(self,
-                         query_embedding: torch.Tensor,
-                         top_k: int = 5) -> List[Dict[str, Any]]:
-        """
-        Execute cosine similarity-based location retrieval against GPS gallery.
-
-        Args:
-            query_embedding: L2-normalized query embedding
-            top_k: Number of top predictions to return
 
-
-        with torch.no_grad():
-            # Ensure GPS gallery is loaded with resource pooling
-            self._ensure_gps_gallery()
 
-            #
             location_embeddings = self._location_encoder(self._gps_gallery)
             location_embeddings = torch.nn.functional.normalize(location_embeddings, dim=1)
-
-            # Calculate similarity with vectorized matrix multiplication
-            similarity = self._model.logit_scale.exp() * (query_embedding @ location_embeddings.T)
             probs = similarity.softmax(dim=-1)
 
-            # Extract
             top_values, top_indices = torch.topk(probs[0], min(top_k, len(self._gps_gallery)))
 
-
-    def text_to_location(self, text: str, top_k: int = 5) -> List[Dict[str, Any]]:
-        """
-        Primary entry point for text-to-location prediction pipeline.
-
-        Args:
-            text: Text description to predict location for
-            top_k: Number of top predictions to return
-
-        Returns:
-            List of prediction dictionaries with coordinates and confidence scores
-        """
-        embedding = self.embed_text(text)
-        return self.predict_location(embedding, top_k)
-
-    def image_to_location(self,
-                          image: Union[str, PIL.Image.Image, np.ndarray],
-                          top_k: int = 5) -> List[Dict[str, Any]]:
-        """
-        Primary entry point for image-to-location prediction pipeline.
-
-        Args:
-            image: Input image (PIL Image, file path, or numpy array)
-            top_k: Number of top predictions to return
-
-        Returns:
-            List of prediction dictionaries with coordinates and confidence scores
-        """
-        embedding = self.embed_image(image)
-        return self.predict_location(embedding, top_k)
-
-    def extract_image_metadata(self, image_path: str) -> Dict[str, Any]:
-        """
-        Extract comprehensive metadata from image file with GPS coordinates.
-
-        Args:
-            image_path: Path to image file
-
-        Returns:
-            Dictionary containing extracted metadata
-        """
-        try:
-            from PIL import Image, ExifTags
-            import piexif
-
-            # Open image and extract EXIF data with efficient memory mapping
-            img = Image.open(image_path)
-            metadata = {"file_name": image_path, "file_size": os.path.getsize(image_path)}
-
-            # Extract basic image properties
-            metadata["format"] = img.format
-            metadata["mode"] = img.mode
-            metadata["size"] = list(img.size)
-
-            if hasattr(img, "_getexif") and img._getexif():
-                exif_dict = {}
-                for tag_id, value in img._getexif().items():
-                    tag = ExifTags.TAGS.get(tag_id, tag_id)
-                    exif_dict[tag.lower()] = value
-
-                # Copy relevant EXIF data to metadata
-                for key, value in exif_dict.items():
-                    if isinstance(value, bytes):
-                        continue
-                    metadata[key] = value
-
-                # Extract GPS data with specialized parsing
-                gps_info = {}
-                if "gpsinfo" in exif_dict:
-                    gps_data = exif_dict["gpsinfo"]
-                    for key, value in gps_data.items():
-                        tag = ExifTags.GPSTAGS.get(key, key)
-                        gps_info[tag] = value
-
-                    # Parse GPS coordinates to decimal format
-                    if "GPSLatitude" in gps_info and "GPSLongitude" in gps_info:
-                        lat = self._convert_to_decimal(
-                            gps_info["GPSLatitude"],
-                            gps_info.get("GPSLatitudeRef", "N")
-                        )
-                        lon = self._convert_to_decimal(
-                            gps_info["GPSLongitude"],
-                            gps_info.get("GPSLongitudeRef", "E")
-                        )
-                        gps_info["Latitude"] = lat
-                        gps_info["Longitude"] = lon
-
-                metadata["gps_info"] = gps_info
-
-            # Add file metadata
-            metadata["file_extension"] = os.path.splitext(image_path)[1]
-            metadata["extraction_timestamp"] = int(time.time())
-
-            return metadata
-        except Exception as e:
-            print(f"Error extracting metadata: {e}")
-            return {"error": str(e), "file_name": image_path}
-
-    def _convert_to_decimal(self, dms_coords, ref) -> float:
-        """
-        Convert GPS DMS (Degree, Minute, Second) to decimal format.
-
-        Args:
-            dms_coords: Tuple of degrees, minutes, seconds
-            ref: Direction reference (N/S/E/W)
-
-        Returns:
-            Decimal coordinate value
-        """
-        degrees = dms_coords[0]
-        minutes = dms_coords[1] / 60.0
-        seconds = dms_coords[2] / 3600.0
-
-        decimal = degrees + minutes + seconds
-
-        # Apply negative value for south or west coordinates
-        if ref in ['S', 'W']:
-            decimal = -decimal
-
-        return decimal
-
-    def add_to_dataset(self,
-                       image_path: str,
-                       classes: Optional[List[str]] = None,
-                       push_to_hub: bool = True) -> Dict[str, Any]:
-        """
-        Process image and add entry to dataset with optional HuggingFace Hub synchronization.
-
-        Args:
-            image_path: Path to image file
-            classes: Optional list of class labels
-            push_to_hub: Whether to push changes to Hub
-
-        Returns:
-            Dictionary containing the added entry
-        """
-        # Extract filename from path
-        filename = os.path.basename(image_path)
-
-        # Extract comprehensive metadata with optimized parser
-        metadata = self.extract_image_metadata(image_path)
-
-        # Prepare new entry
-        new_entry = {
-            "filename": filename,
-            "classes": classes or [],
-            "metadata": metadata
-        }
-
-        # Add to local dataset with optimized append operation
-        self.dataset = concatenate_datasets([
-            self.dataset,
-            Dataset.from_dict({
-                "filename": [new_entry["filename"]],
-                "classes": [new_entry["classes"]],
-                "metadata": [new_entry["metadata"]]
-            })
-        ])
-
-        # Push updates to HuggingFace Hub
-        if push_to_hub:
-            self.push_dataset_to_hub()
-
-        return new_entry
-
-    def push_dataset_to_hub(self) -> None:
-        """Push dataset updates to HuggingFace Hub with atomic transaction."""
-        if self.token:
-            try:
-                self.dataset.push_to_hub(self.dataset_id, token=self.token)
-                print(f"Successfully pushed dataset with {len(self.dataset)} entries to {self.dataset_id}")
-            except Exception as e:
-                print(f"Error pushing to Hub: {e}")
-        else:
-            print("HuggingFace token not provided. Dataset not pushed to Hub.")
-
-    def compute_similarity(self, embed1: torch.Tensor, embed2: torch.Tensor) -> float:
-        """
-        Compute cosine similarity between two embeddings.
-
-        Args:
-            embed1: First embedding tensor
-            embed2: Second embedding tensor
-
-        Returns:
-            Similarity score between 0 and 1
-        """
-        return float(torch.nn.functional.cosine_similarity(embed1, embed2).item())
-
-    def create_map_visualization(self,
-                                 predictions: List[Dict[str, Any]],
-                                 title: str = "",
-                                 cluster: bool = False) -> folium.Map:
-        """
-        Generate geospatial visualization of prediction results.
-
-        Args:
-            predictions: List of prediction dictionaries
-            title: Optional map title
-            cluster: Whether to cluster nearby markers
-
-        Returns:
-            Folium map object with marker and heatmap layers
-        """
-        # Initialize map centered on highest confidence prediction
-        center_coords = predictions[0]["coordinates"]
-        m = folium.Map(location=center_coords, zoom_start=5, tiles="OpenStreetMap")
-
-        # Add title if provided
         if title:
-
-            m.get_root().html.add_child(folium.Element(title_html))
-
-        # Create marker cluster if requested
-        marker_group = MarkerCluster() if cluster else m
 
         # Add markers with confidence metadata
         for i, pred in enumerate(predictions):
             color = 'red' if i == 0 else 'blue' if i < 3 else 'green'
-
             folium.Marker(
                 location=pred["coordinates"],
                 popup=f"Prediction #{i+1}<br>Confidence: {pred['confidence']:.6f}",
                 icon=folium.Icon(color=color)
-            ).add_to(
 
-        # Add
-        if cluster:
-            m.add_child(marker_group)
-
-        # Add heatmap layer for visual density representation
         if len(predictions) >= 3:
             heat_data = [[p["coordinates"][0], p["coordinates"][1], p["confidence"]]
-
             HeatMap(heat_data, radius=15, blur=10).add_to(m)
 
         return m
 
-
-def
-    """
-
 
     geo_core = GeoCLIPCore(token=hf_token)
 
-
-    if
-        return
-
-        # Create HTML representation
-        map_html = m._repr_html_()
-
-        # Format textual results
-        result_text = f"Top predictions for: '{text_query}'\n\n"
-        for i, pred in enumerate(predictions, 1):
-            coords = pred["coordinates"]
-            conf = pred["confidence"]
-            result_text += f"{i}. ({coords[0]:.6f}, {coords[1]:.6f}) - confidence: {conf:.6f}\n"
-
-        return map_html, result_text
-
-    def process_image(image, image_path, save_to_hub, top_k):
-        """
-        Process image for prediction and metadata extraction with Hub integration.
-
-        Returns map visualization, prediction results, and metadata.
-        """
-        if image is None:
-            return None, "Please upload an image.", "{}"
-
-        # Execute prediction pipeline with tensor acceleration
-        predictions = geo_core.image_to_location(image, top_k=int(top_k))
-
-        # Generate map visualization
-        m = geo_core.create_map_visualization(
-            predictions,
-            title="Predictions from Image"
-        )
-
-        # Create HTML representation
-        map_html = m._repr_html_()
-
-        # Format textual results
-        result_text = "Top predictions from image:\n\n"
-        for i, pred in enumerate(predictions, 1):
-            coords = pred["coordinates"]
-            conf = pred["confidence"]
-            result_text += f"{i}. ({coords[0]:.6f}, {coords[1]:.6f}) - confidence: {conf:.6f}\n"
-
-        # Extract metadata if image was uploaded and path is available
-        metadata = {}
-        if image_path:
-            # Add to dataset if requested
-            if save_to_hub:
-                entry = geo_core.add_to_dataset(
-                    image_path,
-                    classes=["location"],
-                    push_to_hub=True
-                )
-                metadata = entry["metadata"]
-            else:
-                # Just extract metadata without saving
-                metadata = geo_core.extract_image_metadata(image_path)
-
-        # Format metadata as JSON
-        metadata_json = json.dumps(metadata, indent=2)
-
-        return map_html, result_text, metadata_json
-
-    def compute_text_similarity(text1, text2):
-        """Compute semantic similarity between two text descriptions."""
-        if not text1.strip() or not text2.strip():
-            return "Please enter both text descriptions."
-
-        embed1 = geo_core.embed_text(text1)
-        embed2 = geo_core.embed_text(text2)
-
-        similarity = geo_core.compute_similarity(embed1, embed2)
-        return f"Similarity between the texts: {similarity:.4f} (range: 0-1)"
-
-    # Create Gradio interface with tabs for different functions
-    with gr.Blocks(title="GeoCLIP Location Intelligence") as demo:
-        gr.Markdown("# GeoCLIP Location Intelligence with Hub Integration")
-        gr.Markdown("Predict locations from text descriptions or images with dataset persistence.")
-
-        with gr.Tabs():
-            with gr.TabItem("Text → Location"):
-                with gr.Row():
-                    with gr.Column():
-                        text_input = gr.Textbox(
-                            lines=3,
-                            placeholder="Enter location description...",
-                            label="Location Description"
-                        )
-                        text_top_k = gr.Slider(
-                            minimum=1,
-                            maximum=20,
-                            value=10,
-                            step=1,
-                            label="Number of Predictions"
-                        )
-                        text_submit = gr.Button("Predict Location")
 
-                    with
-
-                            inputs=text_input
-                        )
-
-                        text_map_output = gr.HTML(label="Map Visualization")
-                        text_result_output = gr.Textbox(label="Prediction Results")
-
-                        text_submit.click(
-                            predict_from_text,
-                            inputs=[text_input, text_top_k],
-                            outputs=[text_map_output, text_result_output]
-                        )
-
-            with gr.TabItem("Image → Location with Hub Integration"):
-                with gr.Row():
-                    with gr.Column():
-                        image_input = gr.Image(type="pil", label="Upload Image")
-                        save_to_hub = gr.Checkbox(
-                            label="Save to HuggingFace Dataset",
-                            value=True
-                        )
-                        image_top_k = gr.Slider(
-                            minimum=1,
-                            maximum=20,
-                            value=10,
-                            step=1,
-                            label="Number of Predictions"
-                        )
-                        image_submit = gr.Button("Process Image")
-
-                        image_map_output = gr.HTML(label="Map Visualization")
-                        image_result_output = gr.Textbox(label="Prediction Results")
-                        metadata_output = gr.JSON(label="Image Metadata")
-
-                        image_submit.click(
-                            process_image,
-                            inputs=[image_input, image_input.upload_path, save_to_hub, image_top_k],
-                            outputs=[image_map_output, image_result_output, metadata_output]
-                        )
 
-                            inputs=[text1_input, text2_input],
-                            outputs=similarity_output
                         )
 
-            with gr.
-
-                            f"Number of entries: {len(geo_core.dataset)}"
-                        )
-
-                        update_status.click(
-                            update_dataset_status,
-                            inputs=[],
-                            outputs=[dataset_info, dataset_count]
-                        )
 
-    demo.launch(share=True)
 
 if __name__ == "__main__":
-    # Read API token from environment variable
     hf_token = os.environ.get("HF_TOKEN")
-
-    launch_gradio_interface(hf_token=hf_token)
 import torch
 import numpy as np
 import folium
+from folium.plugins import HeatMap
 import gradio as gr
 import os
 import PIL.Image
 import json
 import time
+from typing import Dict, Any, Optional, Union
 from pathlib import Path
 from datasets import Dataset, load_dataset, concatenate_datasets
+from huggingface_hub import HfApi
 
 # GeoCLIP dependencies
 from geoclip import GeoCLIP
 from transformers import CLIPTokenizer, CLIPProcessor
 
+# Initialize GeoCLIP core with vectorized execution path
 class GeoCLIPCore:
+    def __init__(self, device=None, dataset_id="latterworks/geo-metadata", token=None):
         self.device = device or ("cuda" if torch.cuda.is_available() else "cpu")
         self.dataset_id = dataset_id
         self.token = token
         self._model = GeoCLIP().to(self.device)
         self._tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14")
         self._processor = CLIPProcessor.from_pretrained("openai/clip-vit-large-patch14")
         self._location_encoder = self._model.location_encoder
+        self._gps_gallery = None  # Lazy-loaded for memory optimization
 
+        # Initialize dataset connection with error handling
         try:
             self.dataset = load_dataset(self.dataset_id, split="train", token=self.token)
+        except Exception:
+            self.dataset = Dataset.from_dict({"filename": [], "classes": [], "metadata": []})
+
+    # Core tensor operations for embedding generation
+    def text_to_location(self, text, top_k=5):
         with torch.no_grad():
             tokens = self._tokenizer(text, return_tensors="pt", padding=True).to(self.device)
             embedding = self._model.image_encoder.mlp(
                 self._model.image_encoder.CLIP.get_text_features(**tokens)
             )
+            embedding = torch.nn.functional.normalize(embedding, dim=1)
 
+            # Ensure gallery is loaded with memory pooling
+            if self._gps_gallery is None:
+                self._gps_gallery = self._model.gps_gallery.to(self.device)
 
+            # Execute vectorized similarity computation
             location_embeddings = self._location_encoder(self._gps_gallery)
             location_embeddings = torch.nn.functional.normalize(location_embeddings, dim=1)
+            similarity = self._model.logit_scale.exp() * (embedding @ location_embeddings.T)
             probs = similarity.softmax(dim=-1)
 
+            # Extract predictions with single tensor operation
             top_values, top_indices = torch.topk(probs[0], min(top_k, len(self._gps_gallery)))
 
+            return [
+                {"coordinates": tuple(self._gps_gallery[idx].cpu().numpy()),
+                 "confidence": float(conf)}
+                for idx, conf in zip(top_indices.cpu().numpy(), top_values.cpu().numpy())
+            ]
+
+    # Generate map visualization with optimized rendering
+    def create_map_visualization(self, predictions, title=""):
+        m = folium.Map(location=predictions[0]["coordinates"], zoom_start=5)
         if title:
+            m.get_root().html.add_child(folium.Element(f'<h3 style="text-align:center">{title}</h3>'))
 
         # Add markers with confidence metadata
         for i, pred in enumerate(predictions):
             color = 'red' if i == 0 else 'blue' if i < 3 else 'green'
             folium.Marker(
                 location=pred["coordinates"],
                 popup=f"Prediction #{i+1}<br>Confidence: {pred['confidence']:.6f}",
                 icon=folium.Icon(color=color)
+            ).add_to(m)
 
+        # Add heatmap for density visualization
         if len(predictions) >= 3:
             heat_data = [[p["coordinates"][0], p["coordinates"][1], p["confidence"]]
+                         for p in predictions]
             HeatMap(heat_data, radius=15, blur=10).add_to(m)
 
         return m
 
+# Initialize GeoCLIP and codebase exemplars
+def initialize_gradio_interface(hf_token=None):
+    python_code = """def fib(n):
+    if n <= 0:
+        return 0
+    elif n == 1:
+        return 1
+    else:
+        return fib(n-1) + fib(n-2)
+"""
+    js_code = """function fib(n) {
+    if (n <= 0) return 0;
+    if (n === 1) return 1;
+    return fib(n - 1) + fib(n - 2);
+}
+"""
+    # Initialize GeoCLIP with optimized resource allocation
     geo_core = GeoCLIPCore(token=hf_token)
 
+    # Message handler with multimodal dispatch logic
+    def chat(message, history):
+        if "python" in message.lower():
+            return "Type Python or JavaScript to see the code.", gr.Code(language="python", value=python_code)
+        elif "javascript" in message.lower():
+            return "Type Python or JavaScript to see the code.", gr.Code(language="javascript", value=js_code)
+        elif any(kw in message.lower() for kw in ["location", "where", "place", "predict"]):
+            # Extract location query with pattern matching
+            for term in ["location", "where", "place", "find", "predict"]:
+                if term in message.lower():
+                    query = message.lower().split(term, 1)[1].strip()
+                    if not query:
+                        return "Please provide a location description.", None
 
+                    # Execute prediction with tensor acceleration
+                    predictions = geo_core.text_to_location(query, top_k=5)
+                    m = geo_core.create_map_visualization(predictions, f"Predictions for: {query}")
+
+                    # Format response with structured data
+                    result = f"Top predictions for: '{query}'\n\n"
+                    for i, pred in enumerate(predictions, 1):
+                        coords = pred["coordinates"]
+                        result += f"{i}. ({coords[0]:.6f}, {coords[1]:.6f}) - conf: {pred['confidence']:.6f}\n"
+
+                    return result, gr.HTML(value=m._repr_html_())
 
+            return "Couldn't process your location query. Please try again.", None
+        else:
+            return "I can show code examples or predict locations. Try 'Where is the Eiffel Tower?'", None
+
+    # Build gradio blocks with structured layout
+    with gr.Blocks() as demo:
+        code = gr.Code(render=False)
+        map_output = gr.HTML(render=False)
+
+        with gr.Row():
+            with gr.Column(scale=1):
+                gr.Markdown("<center><h1>GeoCLIP + Code Examples</h1></center>")
+                chatbot = gr.ChatInterface(
+                    chat,
+                    examples=["Python", "JavaScript", "Where is the Eiffel Tower?"],
+                    additional_outputs=[code, map_output]
                 )
 
+            with gr.Column(scale=1):
+                gr.Markdown("<center><h1>Output Artifacts</h1></center>")
+                with gr.Tab("Code"):
+                    code.render()
+                with gr.Tab("Location Map"):
+                    map_output.render()
+
+        gr.Markdown(f"<center>Connected to dataset: {geo_core.dataset_id}</center>")
 
+    return demo
 
+# Entry point with environmental token acquisition
 if __name__ == "__main__":
     hf_token = os.environ.get("HF_TOKEN")
+    demo = initialize_gradio_interface(hf_token)
+    demo.launch()
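A minimal smoke-test sketch of how the refactored GeoCLIPCore in this revision might be exercised outside the Gradio UI. It is an assumption-laden example, not part of the commit: it assumes the new app.py above is saved in the working directory, that the imported packages (torch, geoclip, transformers, datasets, folium, gradio) are installed, and that HF_TOKEN may be unset since the dataset load already falls back to an empty Dataset on failure.

# Hypothetical local check, not part of the commit: query the core model directly.
import os

from app import GeoCLIPCore  # assumes this revision is saved as app.py

core = GeoCLIPCore(token=os.environ.get("HF_TOKEN"))  # token may be None
predictions = core.text_to_location("a beach near Barcelona", top_k=3)
for rank, pred in enumerate(predictions, 1):
    lat, lon = pred["coordinates"]
    print(f"{rank}. ({lat:.4f}, {lon:.4f}) confidence={pred['confidence']:.6f}")

# Render the folium map for the same predictions to a standalone HTML file.
core.create_map_visualization(predictions, "Smoke test").save("predictions_map.html")

Calling text_to_location directly keeps the check independent of the ChatInterface wiring, so prediction quality can be inspected even if the Space UI itself fails to start.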