latterworks commited on
Commit
3f2c8e3
·
verified ·
1 Parent(s): 49e3fab

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +755 -173
app.py CHANGED
@@ -1,191 +1,773 @@
1
  import gradio as gr
2
- from pathlib import Path
3
- from PIL import Image, ExifTags
4
- import json
5
- import os
6
- import logging
7
- import traceback
8
  import folium
 
 
 
 
9
  import io
10
- from typing import Dict, List, Any, Optional, Tuple, Union
 
 
 
 
 
11
 
12
  # Configure logging
13
- logging.basicConfig(level=logging.INFO)
 
 
 
14
  logger = logging.getLogger(__name__)
15
 
16
- # Supported extensions
17
- SUPPORTED_EXTENSIONS = {'.jpg', '.jpeg', '.png', '.heic', '.tiff', '.bmp', '.webp'}
18
-
19
- def convert_to_degrees(value: tuple) -> Optional[float]:
20
- """Convert GPS coordinates from DMS to decimal degrees."""
21
- try:
22
- d, m, s = value
23
- return float(d) + (float(m) / 60.0) + (float(s) / 3600.0)
24
- except Exception as e:
25
- logger.error(f"GPS conversion error: {e}")
26
- return None
27
-
28
- def extract_gps_info(gps_info: Dict[int, Any]) -> Optional[Dict[str, Any]]:
29
- """Extract GPS data from EXIF."""
30
- if not isinstance(gps_info, dict):
31
- return None
32
-
33
- gps_data = {}
34
- try:
35
- for key, val in gps_info.items():
36
- tag_name = ExifTags.GPSTAGS.get(key, f"unknown_gps_tag_{key}")
37
- gps_data[tag_name] = val
38
-
39
- if 'GPSLatitude' in gps_data and 'GPSLongitude' in gps_data:
40
- lat = convert_to_degrees(gps_data['GPSLatitude'])
41
- lon = convert_to_degrees(gps_data['GPSLongitude'])
42
-
43
- if lat is None or lon is None:
44
- return None
45
-
46
- if gps_data.get('GPSLatitudeRef', 'N') == 'S':
47
- lat = -lat
48
- if gps_data.get('GPSLongitudeRef', 'E') == 'W':
49
- lon = -lon
50
-
51
- gps_data['Latitude'] = lat
52
- gps_data['Longitude'] = lon
53
-
54
- return gps_data
55
- except Exception as e:
56
- logger.error(f"GPS extraction error: {str(e)}")
57
- return None
58
-
59
- def get_image_metadata(image_path: Path) -> Dict[str, Any]:
60
- """Extract metadata from a single image."""
61
- metadata = {"file_name": str(image_path.absolute())}
62
- try:
63
- with Image.open(image_path) as image:
64
- metadata.update({
65
- "format": image.format or "unknown",
66
- "size": list(image.size)
67
- })
68
-
69
- exif_data = None
70
- try:
71
- exif_data = image._getexif()
72
- except Exception:
73
- metadata["exif_error"] = "No EXIF data available"
74
-
75
- if exif_data and isinstance(exif_data, dict):
76
- for tag_id, value in exif_data.items():
77
- try:
78
- tag_name = ExifTags.TAGS.get(tag_id, f"tag_{tag_id}").lower()
79
- if tag_name == "gpsinfo":
80
- gps_info = extract_gps_info(value)
81
- if gps_info:
82
- metadata["gps_info"] = gps_info
83
- except Exception:
84
- pass
85
-
86
- return metadata
87
- except Exception as e:
88
- logger.error(f"Error processing {image_path}: {str(e)}")
89
- return {"file_name": str(image_path.absolute()), "error": str(e)}
90
-
91
- def process_images(files) -> Tuple[str, List[Dict[str, Any]]]:
92
- """Process uploaded image files."""
93
- metadata_list = []
94
-
95
- try:
96
- # Create temp directory for uploads
97
- temp_dir = Path("./temp_uploads")
98
- temp_dir.mkdir(exist_ok=True)
99
-
100
- # Save and process uploaded files
101
- for file in files:
102
- # Handle byte content from Gradio uploads
103
- if hasattr(file, "name") and hasattr(file, "read"):
104
- file_path = temp_dir / file.name
105
- with open(file_path, "wb") as f:
106
- f.write(file.read())
107
-
108
- if file_path.suffix.lower() in SUPPORTED_EXTENSIONS:
109
- metadata = get_image_metadata(file_path)
110
- if metadata:
111
- metadata_list.append(metadata)
112
-
113
- if not metadata_list:
114
- return "No valid images found", []
115
-
116
- # Count geotagged images
117
- geotagged = sum(1 for m in metadata_list if "gps_info" in m)
118
-
119
- return f"Processed {len(metadata_list)} images ({geotagged} geotagged)", metadata_list
120
-
121
- except Exception as e:
122
- logger.error(f"Error processing uploads: {traceback.format_exc()}")
123
- return f"Error: {str(e)}", []
124
-
125
- def create_map(metadata_list: List[Dict[str, Any]]) -> str:
126
- """Create a folium map with markers for geotagged images."""
127
- try:
128
- # Extract coordinates
129
- coords = []
130
- for entry in metadata_list:
131
- gps_info = entry.get("gps_info", {})
132
- if isinstance(gps_info, dict) and "Latitude" in gps_info and "Longitude" in gps_info:
133
- coords.append((
134
- gps_info["Latitude"],
135
- gps_info["Longitude"],
136
- os.path.basename(entry.get("file_name", "Unknown"))
137
  ))
 
 
 
 
 
 
 
 
 
 
 
 
 
138
 
139
- if not coords:
140
- return "No geotagged images found"
 
 
 
 
 
 
 
 
 
 
 
141
 
142
- # Calculate center
143
- center_lat = sum(c[0] for c in coords) / len(coords)
144
- center_lon = sum(c[1] for c in coords) / len(coords)
145
 
146
- # Create map
147
- m = folium.Map(location=[center_lat, center_lon], zoom_start=10)
 
 
 
 
 
148
 
149
- # Add markers
150
- for lat, lon, name in coords:
 
 
 
 
 
 
 
 
 
 
 
 
 
151
  folium.Marker(
152
- location=[lat, lon],
153
- popup=f"<b>{name}</b><br>Location: {lat:.6f}, {lon:.6f}"
 
154
  ).add_to(m)
 
 
 
 
 
 
 
 
155
 
156
- # Convert to HTML
157
- map_html = m._repr_html_()
158
- return map_html
159
-
160
- except Exception as e:
161
- logger.error(f"Error creating map: {str(e)}")
162
- return f"Error creating map: {str(e)}"
163
-
164
- def ui_process_files(files):
165
- """Handle file processing for the UI."""
166
- if not files:
167
- return "No files uploaded", None, "No data to display"
168
-
169
- status, metadata = process_images(files)
170
- map_html = create_map(metadata) if metadata else "No map data available"
171
-
172
- return status, metadata, map_html
173
-
174
- # Create the Gradio interface
175
- demo = gr.Interface(
176
- fn=ui_process_files,
177
- inputs=gr.Files(label="Upload Images"),
178
- outputs=[
179
- gr.Textbox(label="Status"),
180
- gr.JSON(label="Metadata"),
181
- gr.HTML(label="Map")
182
- ],
183
- title="Simple Geotagged Image Analyzer",
184
- description="Upload images to extract EXIF data and view locations on a map.",
185
- examples=[],
186
- cache_examples=False,
187
- allow_flagging=False,
188
- )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
189
 
190
  if __name__ == "__main__":
 
191
  demo.launch()
 
1
  import gradio as gr
2
+ import torch
3
+ import torch.nn.functional as F
4
+ from functools import lru_cache
 
 
 
5
  import folium
6
+ from folium.plugins import HeatMap, MarkerCluster
7
+ from typing import List, Dict, Tuple, Optional, Union, Any
8
+ import numpy as np
9
+ import matplotlib.pyplot as plt
10
  import io
11
+ import base64
12
+ from dataclasses import dataclass
13
+ import logging
14
+ import warnings
15
+ from transformers import CLIPTokenizer
16
+ from geoclip import GeoCLIP
17
 
18
  # Configure logging
19
+ logging.basicConfig(
20
+ level=logging.INFO,
21
+ format="%(asctime)s [%(levelname)s] %(message)s"
22
+ )
23
  logger = logging.getLogger(__name__)
24
 
25
+ # Suppress transformer warnings
26
+ warnings.filterwarnings("ignore", message="weights_only=False")
27
+
28
+ @dataclass
29
+ class LocationPrediction:
30
+ """Structured container for geographic predictions with confidence metrics."""
31
+ coordinates: Tuple[float, float]
32
+ confidence: float
33
+
34
+
35
+ class GeoCLIPAnalyzer:
36
+ """High-performance GeoCLIP analyzer with cached operations and optimized tensor handling."""
37
+
38
+ def __init__(self, cache_enabled: bool = True, cache_size: int = 128):
39
+ """
40
+ Initialize the analyzer with configurable caching.
41
+
42
+ Args:
43
+ cache_enabled: Toggle for LRU caching mechanism
44
+ cache_size: Maximum cache entries per method
45
+ """
46
+ self.device = "cuda" if torch.cuda.is_available() else "cpu"
47
+ logger.info(f"Initializing GeoCLIP on {self.device}")
48
+
49
+ self.model = GeoCLIP().to(self.device)
50
+ self.tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14")
51
+
52
+ # Apply LRU caching to high-compute methods
53
+ if cache_enabled:
54
+ self.predict_location = lru_cache(maxsize=cache_size)(self.predict_location)
55
+ self.analyze_temporal_variations = lru_cache(maxsize=cache_size)(self.analyze_temporal_variations)
56
+ self.find_related_locations = lru_cache(maxsize=cache_size)(self.find_related_locations)
57
+ logger.info(f"Method caching enabled with size {cache_size}")
58
+
59
+ def predict_location(self, text: str, top_k: int = 5) -> List[LocationPrediction]:
60
+ """
61
+ Generate location predictions with confidence metrics for a text query.
62
+
63
+ Args:
64
+ text: Descriptive location text query
65
+ top_k: Number of predictions to return
66
+
67
+ Returns:
68
+ List of LocationPrediction objects sorted by confidence
69
+ """
70
+ with torch.no_grad():
71
+ # Generate text embeddings
72
+ text_inputs = self.tokenizer(text, return_tensors="pt", padding=True).to(self.device)
73
+ text_features = self.model.image_encoder.mlp(
74
+ self.model.image_encoder.CLIP.get_text_features(**text_inputs)
75
+ )
76
+ text_features = F.normalize(text_features, dim=1)
77
+
78
+ # Retrieve and normalize location features
79
+ gps_gallery = self.model.gps_gallery.to(self.device)
80
+ location_features = self.model.location_encoder(gps_gallery)
81
+ location_features = F.normalize(location_features, dim=1)
82
+
83
+ # Compute similarity and extract top predictions
84
+ similarity = self.model.logit_scale.exp() * (text_features @ location_features.T)
85
+ probs = similarity.softmax(dim=-1)
86
+ top_pred = torch.topk(probs[0], top_k)
87
+
88
+ # Convert to Python native types for consistent serialization
89
+ predictions = []
90
+ for coord, conf in zip(
91
+ gps_gallery[top_pred.indices].cpu().numpy(),
92
+ top_pred.values.cpu().numpy()
93
+ ):
94
+ predictions.append(LocationPrediction(
95
+ coordinates=(float(coord[0]), float(coord[1])),
96
+ confidence=float(conf)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
97
  ))
98
+
99
+ return predictions
100
+
101
+ def create_location_map(
102
+ self,
103
+ predictions: List[LocationPrediction],
104
+ title: str,
105
+ zoom: int = 5,
106
+ heatmap: bool = True,
107
+ cluster: bool = False
108
+ ) -> str:
109
+ """
110
+ Generate an interactive folium map visualization from location predictions.
111
 
112
+ Args:
113
+ predictions: List of location predictions
114
+ title: Map title/description
115
+ zoom: Initial zoom level (higher = more zoomed in)
116
+ heatmap: Whether to add heatmap layer
117
+ cluster: Whether to cluster nearby markers
118
+
119
+ Returns:
120
+ HTML string of rendered interactive map
121
+ """
122
+ # Calculate center from prediction distribution
123
+ center_lat = sum(p.coordinates[0] for p in predictions) / len(predictions)
124
+ center_lon = sum(p.coordinates[1] for p in predictions) / len(predictions)
125
 
126
+ # Initialize map
127
+ m = folium.Map(location=[center_lat, center_lon], zoom_start=zoom)
 
128
 
129
+ # Add heatmap if requested
130
+ if heatmap and len(predictions) > 1:
131
+ heat_data = [
132
+ [pred.coordinates[0], pred.coordinates[1], pred.confidence]
133
+ for pred in predictions
134
+ ]
135
+ HeatMap(heat_data, radius=25, blur=15, min_opacity=0.4).add_to(m)
136
 
137
+ # Add markers with clustering option
138
+ if cluster:
139
+ marker_cluster = MarkerCluster().add_to(m)
140
+
141
+ for i, pred in enumerate(predictions):
142
+ popup_text = f"{title} #{i+1}<br>Coordinates: {pred.coordinates[0]:.6f}, {pred.coordinates[1]:.6f}<br>Confidence: {pred.confidence:.4f}"
143
+
144
+ folium.Marker(
145
+ location=pred.coordinates,
146
+ popup=popup_text,
147
+ icon=folium.Icon(color='red' if i == 0 else 'blue')
148
+ ).add_to(marker_cluster)
149
+ else:
150
+ # Add top prediction with distinctive marker
151
+ top_pred = predictions[0]
152
  folium.Marker(
153
+ location=top_pred.coordinates,
154
+ popup=f"{title}<br>Confidence: {top_pred.confidence:.4f}",
155
+ icon=folium.Icon(color='red', icon='info-sign')
156
  ).add_to(m)
157
+
158
+ # Add remaining predictions with standard markers
159
+ for i, pred in enumerate(predictions[1:], 1):
160
+ folium.Marker(
161
+ location=pred.coordinates,
162
+ popup=f"{title} #{i+1}<br>Confidence: {pred.confidence:.4f}",
163
+ icon=folium.Icon(color='blue')
164
+ ).add_to(m)
165
 
166
+ # Return HTML
167
+ return m._repr_html_()
168
+
169
+ def analyze_at_scales(self, text: str) -> Dict[str, str]:
170
+ """
171
+ Perform multi-scale geospatial analysis with optimized visualization parameters.
172
+
173
+ Args:
174
+ text: Location description text
175
+
176
+ Returns:
177
+ Dictionary mapping scale names to HTML map visualizations
178
+ """
179
+ # Generate predictions
180
+ predictions = self.predict_location(text, top_k=10)
181
+
182
+ # Define scale parameters (zoom level, display radius)
183
+ scales = {
184
+ 'Street': {'zoom': 17, 'radius': 10},
185
+ 'Neighborhood': {'zoom': 14, 'radius': 15},
186
+ 'City': {'zoom': 11, 'radius': 20},
187
+ 'Region': {'zoom': 8, 'radius': 25},
188
+ 'Country': {'zoom': 5, 'radius': 30}
189
+ }
190
+
191
+ # Generate visualizations at each scale
192
+ visualizations = {}
193
+ for scale_name, params in scales.items():
194
+ map_html = self.create_location_map(
195
+ predictions=predictions,
196
+ title=f"{scale_name} view of {text}",
197
+ zoom=params['zoom'],
198
+ heatmap=True if scale_name in ['Region', 'Country'] else False
199
+ )
200
+ visualizations[scale_name] = map_html
201
+
202
+ return visualizations
203
+
204
+ def analyze_temporal_variations(
205
+ self,
206
+ base_location: str,
207
+ time_periods: List[str]
208
+ ) -> Tuple[List[Tuple[str, Tuple[float, float], float]], str]:
209
+ """
210
+ Analyze location shifts across time periods with trajectory visualization.
211
+
212
+ Args:
213
+ base_location: Base location descriptor
214
+ time_periods: List of time period identifiers
215
+
216
+ Returns:
217
+ Tuple of (analysis_results, map_html)
218
+ """
219
+ results = []
220
+ m = folium.Map(zoom_start=4)
221
+ colors = ['red', 'blue', 'green', 'purple', 'orange', 'darkred', 'darkblue', 'cadetblue', 'darkgreen']
222
+
223
+ for period, color in zip(time_periods, colors * (1 + len(time_periods) // len(colors))):
224
+ query = f"{base_location} in {period}"
225
+ predictions = self.predict_location(query, top_k=1)
226
+
227
+ if predictions:
228
+ pred = predictions[0]
229
+ coords = pred.coordinates
230
+ conf = pred.confidence
231
+
232
+ # Add marker with period information
233
+ folium.Marker(
234
+ location=coords,
235
+ popup=f"{period}<br>Confidence: {conf:.4f}",
236
+ icon=folium.Icon(color=color, icon='info-sign')
237
+ ).add_to(m)
238
+
239
+ results.append((period, coords, conf))
240
+
241
+ # Connect points chronologically with polyline
242
+ if len(results) > 1:
243
+ points = [coords for _, coords, _ in results]
244
+ folium.PolyLine(points, weight=2, color='gray', opacity=0.8,
245
+ dash_array='5, 5').add_to(m)
246
+
247
+ # Center map on middle point for optimal visualization
248
+ if results:
249
+ center_point = results[len(results)//2][1]
250
+ m.location = center_point
251
+
252
+ return results, m._repr_html_()
253
+
254
+ def find_related_locations(
255
+ self,
256
+ reference: str,
257
+ candidates: List[str]
258
+ ) -> List[Tuple[str, float]]:
259
+ """
260
+ Identify semantically related locations using embedding cosine similarity.
261
+
262
+ Args:
263
+ reference: Reference location or concept
264
+ candidates: List of candidate locations to compare
265
+
266
+ Returns:
267
+ List of (location, similarity_score) tuples sorted by relevance
268
+ """
269
+ with torch.no_grad():
270
+ # Generate reference embedding
271
+ text_inputs = self.tokenizer(reference, return_tensors="pt", padding=True).to(self.device)
272
+ ref_features = self.model.image_encoder.mlp(
273
+ self.model.image_encoder.CLIP.get_text_features(**text_inputs)
274
+ )
275
+ ref_features = F.normalize(ref_features, dim=1)
276
+
277
+ results = []
278
+ for candidate in candidates:
279
+ # Generate candidate embedding
280
+ text_inputs = self.tokenizer(candidate, return_tensors="pt", padding=True).to(self.device)
281
+ cand_features = self.model.image_encoder.mlp(
282
+ self.model.image_encoder.CLIP.get_text_features(**text_inputs)
283
+ )
284
+ cand_features = F.normalize(cand_features, dim=1)
285
+
286
+ # Compute similarity
287
+ similarity = F.cosine_similarity(
288
+ ref_features, cand_features
289
+ ).item()
290
+
291
+ results.append((candidate, similarity))
292
+
293
+ # Sort by similarity (descending)
294
+ return sorted(results, key=lambda x: x[1], reverse=True)
295
+
296
+ def visualize_related_locations(
297
+ self,
298
+ reference: str,
299
+ candidates: List[str]
300
+ ) -> Tuple[List[Tuple[str, float]], str]:
301
+ """
302
+ Visualize semantically related locations with map integration.
303
+
304
+ Args:
305
+ reference: Reference location or concept
306
+ candidates: List of candidate locations
307
+
308
+ Returns:
309
+ Tuple of (similarity_results, map_html)
310
+ """
311
+ # Compute similarities
312
+ related_results = self.find_related_locations(reference, candidates)
313
+
314
+ # Predict coordinates for all locations
315
+ marker_data = []
316
+ ref_predictions = self.predict_location(reference, top_k=1)
317
+
318
+ if ref_predictions:
319
+ ref_coords = ref_predictions[0].coordinates
320
+ ref_conf = ref_predictions[0].confidence
321
+ marker_data.append((reference, ref_coords, ref_conf, 'red', 1.0))
322
+
323
+ # Get coordinates for each candidate
324
+ for candidate, similarity in related_results:
325
+ predictions = self.predict_location(candidate, top_k=1)
326
+ if predictions:
327
+ coords = predictions[0].coordinates
328
+ conf = predictions[0].confidence
329
+ marker_data.append((candidate, coords, conf, 'blue', similarity))
330
+
331
+ # Create map
332
+ m = folium.Map()
333
+
334
+ for name, coords, conf, color, sim in marker_data:
335
+ # Scale marker size by similarity
336
+ radius = 8 + (sim * 10) if name != reference else 15
337
+
338
+ # Add circle marker
339
+ folium.CircleMarker(
340
+ location=coords,
341
+ radius=radius,
342
+ popup=f"{name}<br>Similarity: {sim:.4f}<br>Confidence: {conf:.4f}",
343
+ color=color,
344
+ fill=True,
345
+ fill_color=color
346
+ ).add_to(m)
347
+
348
+ # Connect to reference with line
349
+ if name != reference:
350
+ folium.PolyLine(
351
+ [ref_coords, coords],
352
+ color=color,
353
+ weight=sim * 5, # Scale line weight by similarity
354
+ opacity=0.7
355
+ ).add_to(m)
356
+
357
+ # Fit bounds to include all markers
358
+ if marker_data:
359
+ all_lats = [coords[0] for _, coords, _, _, _ in marker_data]
360
+ all_lons = [coords[1] for _, coords, _, _, _ in marker_data]
361
+ sw = [min(all_lats), min(all_lons)]
362
+ ne = [max(all_lats), max(all_lons)]
363
+ m.fit_bounds([sw, ne])
364
+
365
+ return related_results, m._repr_html_()
366
+
367
+ return related_results, ""
368
+
369
+ def comprehensive_analysis(self, location: str) -> Dict[str, Any]:
370
+ """
371
+ Execute comprehensive multi-faceted location analysis pipeline.
372
+
373
+ Args:
374
+ location: Target location description
375
+
376
+ Returns:
377
+ Dictionary containing all analysis results
378
+ """
379
+ results = {
380
+ "query": location,
381
+ "timestamp": None, # Can be filled with current timestamp
382
+ }
383
+
384
+ # Basic prediction
385
+ predictions = self.predict_location(location, top_k=5)
386
+ results["predictions"] = predictions
387
+
388
+ # Create basic map
389
+ results["basic_map"] = self.create_location_map(
390
+ predictions,
391
+ f"'{location}' Predictions"
392
+ )
393
+
394
+ # Multi-scale analysis
395
+ results["scale_maps"] = self.analyze_at_scales(location)
396
+
397
+ # Temporal analysis
398
+ time_periods = ["ancient times", "middle ages", "19th century", "modern day"]
399
+ temporal_results, temporal_map = self.analyze_temporal_variations(location, time_periods)
400
+ results["temporal_analysis"] = temporal_results
401
+ results["temporal_map"] = temporal_map
402
+
403
+ # Related locations analysis
404
+ candidates = [
405
+ f"{location} business district",
406
+ f"{location} historic center",
407
+ f"{location} tourist area",
408
+ f"{location} downtown",
409
+ f"{location} suburbs"
410
+ ]
411
+ similarity_results, similarity_map = self.visualize_related_locations(
412
+ location, candidates
413
+ )
414
+ results["similarity_analysis"] = similarity_results
415
+ results["similarity_map"] = similarity_map
416
+
417
+ return results
418
+
419
+
420
+ def create_temporal_analysis_ui(analyzer):
421
+ """Create the temporal analysis interface component."""
422
+ with gr.Column():
423
+ gr.Markdown("## Temporal Analysis")
424
+ gr.Markdown("Analyze how a location changes across different time periods.")
425
+
426
+ with gr.Row():
427
+ with gr.Column():
428
+ base_location = gr.Textbox(label="Base Location", placeholder="e.g., Constantinople")
429
+ with gr.Row():
430
+ time_periods = gr.Textbox(
431
+ label="Time Periods (comma-separated)",
432
+ placeholder="ancient times, middle ages, 19th century, modern day",
433
+ value="ancient times, middle ages, 19th century, modern day"
434
+ )
435
+ temporal_btn = gr.Button("Analyze Temporal Variations", variant="primary")
436
+
437
+ with gr.Column():
438
+ temporal_results = gr.Dataframe(
439
+ headers=["Time Period", "Latitude", "Longitude", "Confidence"],
440
+ label="Temporal Analysis Results"
441
+ )
442
+ temporal_map = gr.HTML(label="Temporal Map")
443
+
444
+ def run_temporal_analysis(location, periods_text):
445
+ if not location:
446
+ return None, "Please enter a base location"
447
+
448
+ periods = [p.strip() for p in periods_text.split(",") if p.strip()]
449
+ if not periods:
450
+ return None, "Please enter at least one time period"
451
+
452
+ try:
453
+ # Run analysis
454
+ results, map_html = analyzer.analyze_temporal_variations(location, periods)
455
+
456
+ # Format results for dataframe
457
+ df_data = [
458
+ [period, coords[0], coords[1], conf]
459
+ for period, coords, conf in results
460
+ ]
461
+
462
+ return df_data, map_html
463
+ except Exception as e:
464
+ logger.error(f"Error in temporal analysis: {str(e)}")
465
+ return None, f"Error: {str(e)}"
466
+
467
+ temporal_btn.click(
468
+ fn=run_temporal_analysis,
469
+ inputs=[base_location, time_periods],
470
+ outputs=[temporal_results, temporal_map]
471
+ )
472
+
473
+ return base_location, time_periods, temporal_btn, temporal_results, temporal_map
474
+
475
+
476
+ def create_related_locations_ui(analyzer):
477
+ """Create the related locations interface component."""
478
+ with gr.Column():
479
+ gr.Markdown("## Related Locations Analysis")
480
+ gr.Markdown("Find semantically related locations based on GeoCLIP embeddings.")
481
+
482
+ with gr.Row():
483
+ with gr.Column():
484
+ reference_location = gr.Textbox(
485
+ label="Reference Location/Concept",
486
+ placeholder="e.g., technology hub"
487
+ )
488
+ candidate_locations = gr.Textbox(
489
+ label="Candidate Locations (comma-separated)",
490
+ placeholder="Silicon Valley, Shenzhen China, Bangalore India",
491
+ value="Silicon Valley, Shenzhen China, Bangalore India, Tel Aviv Israel, London financial district"
492
+ )
493
+ related_btn = gr.Button("Find Related Locations", variant="primary")
494
+
495
+ with gr.Column():
496
+ similarity_results = gr.Dataframe(
497
+ headers=["Location", "Similarity Score"],
498
+ label="Similarity Results"
499
+ )
500
+ similarity_map = gr.HTML(label="Similarity Map")
501
+
502
+ def run_similarity_analysis(reference, candidates_text):
503
+ if not reference:
504
+ return None, "Please enter a reference location or concept"
505
+
506
+ candidates = [c.strip() for c in candidates_text.split(",") if c.strip()]
507
+ if not candidates:
508
+ return None, "Please enter at least one candidate location"
509
+
510
+ try:
511
+ # Run analysis
512
+ results, map_html = analyzer.visualize_related_locations(reference, candidates)
513
+
514
+ # Format results for dataframe
515
+ df_data = [
516
+ [location, similarity]
517
+ for location, similarity in results
518
+ ]
519
+
520
+ return df_data, map_html
521
+ except Exception as e:
522
+ logger.error(f"Error in similarity analysis: {str(e)}")
523
+ return None, f"Error: {str(e)}"
524
+
525
+ related_btn.click(
526
+ fn=run_similarity_analysis,
527
+ inputs=[reference_location, candidate_locations],
528
+ outputs=[similarity_results, similarity_map]
529
+ )
530
+
531
+ return reference_location, candidate_locations, related_btn, similarity_results, similarity_map
532
+
533
+
534
+ def create_comprehensive_analysis_ui(analyzer):
535
+ """Create the comprehensive analysis interface component."""
536
+ with gr.Column():
537
+ gr.Markdown("## Comprehensive Analysis")
538
+ gr.Markdown("Perform a full multi-faceted analysis of a location.")
539
+
540
+ with gr.Row():
541
+ with gr.Column(scale=1):
542
+ comp_location = gr.Textbox(
543
+ label="Location",
544
+ placeholder="e.g., Tokyo Japan"
545
+ )
546
+ comp_btn = gr.Button("Run Comprehensive Analysis", variant="primary")
547
+
548
+ with gr.Column(scale=3):
549
+ with gr.Tabs():
550
+ with gr.TabItem("Basic Prediction"):
551
+ basic_results = gr.Dataframe(
552
+ headers=["Rank", "Latitude", "Longitude", "Confidence"],
553
+ label="Top Predictions"
554
+ )
555
+ basic_map = gr.HTML(label="Map")
556
+
557
+ with gr.TabItem("Multi-scale Analysis"):
558
+ with gr.Tabs() as scale_tabs:
559
+ scale_maps = {
560
+ scale: gr.HTML(label=f"{scale} Scale")
561
+ for scale in ["Street", "Neighborhood", "City", "Region", "Country"]
562
+ }
563
+
564
+ with gr.TabItem("Temporal Analysis"):
565
+ comp_temporal_results = gr.Dataframe(
566
+ headers=["Time Period", "Latitude", "Longitude", "Confidence"],
567
+ label="Temporal Analysis"
568
+ )
569
+ comp_temporal_map = gr.HTML(label="Temporal Map")
570
+
571
+ with gr.TabItem("Related Contexts"):
572
+ comp_similarity_results = gr.Dataframe(
573
+ headers=["Context", "Similarity Score"],
574
+ label="Related Contexts"
575
+ )
576
+ comp_similarity_map = gr.HTML(label="Similarity Map")
577
+
578
+ def run_comprehensive_analysis(location):
579
+ if not location:
580
+ return (
581
+ None, "",
582
+ {"Street": "", "Neighborhood": "", "City": "", "Region": "", "Country": ""},
583
+ None, "", None, ""
584
+ )
585
+
586
+ try:
587
+ # Run analysis
588
+ results = analyzer.comprehensive_analysis(location)
589
+
590
+ # Format basic results
591
+ basic_df = [
592
+ [i+1, pred.coordinates[0], pred.coordinates[1], pred.confidence]
593
+ for i, pred in enumerate(results["predictions"])
594
+ ]
595
+
596
+ # Format temporal results
597
+ temporal_df = [
598
+ [period, coords[0], coords[1], conf]
599
+ for period, coords, conf in results["temporal_analysis"]
600
+ ] if "temporal_analysis" in results else None
601
+
602
+ # Format similarity results
603
+ similarity_df = [
604
+ [location, similarity]
605
+ for location, similarity in results["similarity_analysis"]
606
+ ] if "similarity_analysis" in results else None
607
+
608
+ return (
609
+ basic_df,
610
+ results["basic_map"],
611
+ results["scale_maps"],
612
+ temporal_df,
613
+ results["temporal_map"],
614
+ similarity_df,
615
+ results["similarity_map"]
616
+ )
617
+ except Exception as e:
618
+ logger.error(f"Error in comprehensive analysis: {str(e)}")
619
+ return (
620
+ None, f"Error: {str(e)}",
621
+ {"Street": "", "Neighborhood": "", "City": "", "Region": "", "Country": ""},
622
+ None, "", None, ""
623
+ )
624
+
625
+ comp_btn.click(
626
+ fn=run_comprehensive_analysis,
627
+ inputs=[comp_location],
628
+ outputs=[
629
+ basic_results, basic_map,
630
+ gr.Dict(scale_maps),
631
+ comp_temporal_results, comp_temporal_map,
632
+ comp_similarity_results, comp_similarity_map
633
+ ]
634
+ )
635
+
636
+ return comp_location, comp_btn, basic_results, basic_map, scale_maps, comp_temporal_results, comp_temporal_map, comp_similarity_results, comp_similarity_map
637
+
638
+
639
+ def create_interface():
640
+ """Create the Gradio interface for the GeoCLIP Text-to-Location Analyzer."""
641
+ # Initialize the analyzer with caching
642
+ analyzer = GeoCLIPAnalyzer(cache_enabled=True)
643
+
644
+ with gr.Blocks(title="GeoCLIP Text-to-Location Analyzer") as demo:
645
+ gr.Markdown("# 🌍 GeoCLIP Text-to-Location Analyzer")
646
+ gr.Markdown("""
647
+ This interface allows you to analyze geographic locations using GeoCLIP's text-to-location capabilities.
648
+ You can perform basic location predictions, temporal analysis, find related locations, and run comprehensive analyses.
649
+ """)
650
+
651
+ # Basic prediction section
652
+ with gr.Column():
653
+ gr.Markdown("## Basic Location Prediction")
654
+ gr.Markdown("Enter a textual description of a location to get coordinate predictions.")
655
+
656
+ with gr.Row():
657
+ with gr.Column():
658
+ location_input = gr.Textbox(
659
+ label="Location Description",
660
+ placeholder="e.g., Eiffel Tower Paris"
661
+ )
662
+ top_k = gr.Slider(
663
+ minimum=1, maximum=10, value=5, step=1,
664
+ label="Number of Predictions"
665
+ )
666
+ predict_btn = gr.Button("Predict Location", variant="primary")
667
+
668
+ with gr.Column():
669
+ prediction_results = gr.Dataframe(
670
+ headers=["Rank", "Latitude", "Longitude", "Confidence"],
671
+ label="Prediction Results"
672
+ )
673
+ map_output = gr.HTML(label="Map Visualization")
674
+
675
+ # Add tab-based sections for different analyses
676
+ with gr.Tabs():
677
+ with gr.TabItem("Multi-scale Analysis"):
678
+ with gr.Row():
679
+ with gr.Column():
680
+ scale_location = gr.Textbox(
681
+ label="Location Description",
682
+ placeholder="e.g., Central Park New York"
683
+ )
684
+ scale_btn = gr.Button("Analyze at Different Scales", variant="primary")
685
+
686
+ with gr.Column():
687
+ with gr.Tabs() as scale_tabs:
688
+ street_map = gr.HTML(label="Street Level")
689
+ neighborhood_map = gr.HTML(label="Neighborhood Level")
690
+ city_map = gr.HTML(label="City Level")
691
+ region_map = gr.HTML(label="Regional Level")
692
+ country_map = gr.HTML(label="Country Level")
693
+
694
+ with gr.TabItem("Temporal Analysis"):
695
+ base_location, time_periods, temporal_btn, temporal_results, temporal_map = create_temporal_analysis_ui(analyzer)
696
+
697
+ with gr.TabItem("Related Locations"):
698
+ reference_location, candidate_locations, related_btn, similarity_results, similarity_map = create_related_locations_ui(analyzer)
699
+
700
+ with gr.TabItem("Comprehensive Analysis"):
701
+ comp_location, comp_btn, basic_results, basic_map, scale_maps, comp_temporal_results, comp_temporal_map, comp_similarity_results, comp_similarity_map = create_comprehensive_analysis_ui(analyzer)
702
+
703
+ # Basic prediction handler
704
+ def handle_prediction(text, k):
705
+ if not text:
706
+ return None, "Please enter a location description"
707
+
708
+ try:
709
+ predictions = analyzer.predict_location(text, top_k=int(k))
710
+
711
+ # Format for dataframe
712
+ df_data = [
713
+ [i+1, pred.coordinates[0], pred.coordinates[1], pred.confidence]
714
+ for i, pred in enumerate(predictions)
715
+ ]
716
+
717
+ # Create map
718
+ map_html = analyzer.create_location_map(predictions, f"'{text}' Predictions")
719
+
720
+ return df_data, map_html
721
+ except Exception as e:
722
+ logger.error(f"Error in prediction: {str(e)}")
723
+ return None, f"Error: {str(e)}"
724
+
725
+ # Multi-scale analysis handler
726
+ def handle_scale_analysis(text):
727
+ if not text:
728
+ return "", "", "", "", ""
729
+
730
+ try:
731
+ scale_maps = analyzer.analyze_at_scales(text)
732
+ return (
733
+ scale_maps.get("Street", ""),
734
+ scale_maps.get("Neighborhood", ""),
735
+ scale_maps.get("City", ""),
736
+ scale_maps.get("Region", ""),
737
+ scale_maps.get("Country", "")
738
+ )
739
+ except Exception as e:
740
+ logger.error(f"Error in scale analysis: {str(e)}")
741
+ error_msg = f"<div class='error'>Error: {str(e)}</div>"
742
+ return error_msg, error_msg, error_msg, error_msg, error_msg
743
+
744
+ # Set up event handlers
745
+ predict_btn.click(
746
+ fn=handle_prediction,
747
+ inputs=[location_input, top_k],
748
+ outputs=[prediction_results, map_output]
749
+ )
750
+
751
+ scale_btn.click(
752
+ fn=handle_scale_analysis,
753
+ inputs=[scale_location],
754
+ outputs=[street_map, neighborhood_map, city_map, region_map, country_map]
755
+ )
756
+
757
+ gr.Markdown("""
758
+ ## About GeoCLIP
759
+
760
+ GeoCLIP is a CLIP-inspired model that aligns locations with images for effective worldwide geo-localization.
761
+ This interface uses GeoCLIP's text encoder to map textual descriptions to geographic coordinates.
762
+
763
+ All operations use efficient LRU caching for improved performance on repeated queries.
764
+
765
+ **Reference:** [GeoCLIP: Clip-Inspired Alignment between Locations and Images for Effective Worldwide Geo-localization](https://arxiv.org/abs/2309.16020)
766
+ """)
767
+
768
+ return demo
769
+
770
 
771
  if __name__ == "__main__":
772
+ demo = create_interface()
773
  demo.launch()