latterworks committed on
Commit
4d7f662
·
verified ·
1 Parent(s): 3f2c8e3

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +231 -761
app.py CHANGED
@@ -1,773 +1,243 @@
 
 
 
 
 
1
  import gradio as gr
2
- import torch
3
- import torch.nn.functional as F
4
- from functools import lru_cache
5
- import folium
6
- from folium.plugins import HeatMap, MarkerCluster
7
- from typing import List, Dict, Tuple, Optional, Union, Any
8
- import numpy as np
9
- import matplotlib.pyplot as plt
10
- import io
11
- import base64
12
- from dataclasses import dataclass
13
  import logging
14
- import warnings
15
- from transformers import CLIPTokenizer
16
- from geoclip import GeoCLIP
17
 
18
- # Configure logging
19
  logging.basicConfig(
20
  level=logging.INFO,
21
- format="%(asctime)s [%(levelname)s] %(message)s"
 
22
  )
23
  logger = logging.getLogger(__name__)
24
 
25
- # Suppress transformer warnings
26
- warnings.filterwarnings("ignore", message="weights_only=False")
27
-
28
- @dataclass
29
- class LocationPrediction:
30
- """Structured container for geographic predictions with confidence metrics."""
31
- coordinates: Tuple[float, float]
32
- confidence: float
33
-
34
-
35
- class GeoCLIPAnalyzer:
36
- """High-performance GeoCLIP analyzer with cached operations and optimized tensor handling."""
37
-
38
- def __init__(self, cache_enabled: bool = True, cache_size: int = 128):
39
- """
40
- Initialize the analyzer with configurable caching.
41
-
42
- Args:
43
- cache_enabled: Toggle for LRU caching mechanism
44
- cache_size: Maximum cache entries per method
45
- """
46
- self.device = "cuda" if torch.cuda.is_available() else "cpu"
47
- logger.info(f"Initializing GeoCLIP on {self.device}")
48
-
49
- self.model = GeoCLIP().to(self.device)
50
- self.tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14")
51
-
52
- # Apply LRU caching to high-compute methods
53
- if cache_enabled:
54
- self.predict_location = lru_cache(maxsize=cache_size)(self.predict_location)
55
- self.analyze_temporal_variations = lru_cache(maxsize=cache_size)(self.analyze_temporal_variations)
56
- self.find_related_locations = lru_cache(maxsize=cache_size)(self.find_related_locations)
57
- logger.info(f"Method caching enabled with size {cache_size}")
58
-
59
- def predict_location(self, text: str, top_k: int = 5) -> List[LocationPrediction]:
60
- """
61
- Generate location predictions with confidence metrics for a text query.
62
-
63
- Args:
64
- text: Descriptive location text query
65
- top_k: Number of predictions to return
66
-
67
- Returns:
68
- List of LocationPrediction objects sorted by confidence
69
- """
70
- with torch.no_grad():
71
- # Generate text embeddings
72
- text_inputs = self.tokenizer(text, return_tensors="pt", padding=True).to(self.device)
73
- text_features = self.model.image_encoder.mlp(
74
- self.model.image_encoder.CLIP.get_text_features(**text_inputs)
75
- )
76
- text_features = F.normalize(text_features, dim=1)
77
-
78
- # Retrieve and normalize location features
79
- gps_gallery = self.model.gps_gallery.to(self.device)
80
- location_features = self.model.location_encoder(gps_gallery)
81
- location_features = F.normalize(location_features, dim=1)
82
-
83
- # Compute similarity and extract top predictions
84
- similarity = self.model.logit_scale.exp() * (text_features @ location_features.T)
85
- probs = similarity.softmax(dim=-1)
86
- top_pred = torch.topk(probs[0], top_k)
87
-
88
- # Convert to Python native types for consistent serialization
89
- predictions = []
90
- for coord, conf in zip(
91
- gps_gallery[top_pred.indices].cpu().numpy(),
92
- top_pred.values.cpu().numpy()
93
- ):
94
- predictions.append(LocationPrediction(
95
- coordinates=(float(coord[0]), float(coord[1])),
96
- confidence=float(conf)
97
- ))
98
-
99
- return predictions
100
-
101
- def create_location_map(
102
- self,
103
- predictions: List[LocationPrediction],
104
- title: str,
105
- zoom: int = 5,
106
- heatmap: bool = True,
107
- cluster: bool = False
108
- ) -> str:
109
- """
110
- Generate an interactive folium map visualization from location predictions.
111
-
112
- Args:
113
- predictions: List of location predictions
114
- title: Map title/description
115
- zoom: Initial zoom level (higher = more zoomed in)
116
- heatmap: Whether to add heatmap layer
117
- cluster: Whether to cluster nearby markers
118
-
119
- Returns:
120
- HTML string of rendered interactive map
121
- """
122
- # Calculate center from prediction distribution
123
- center_lat = sum(p.coordinates[0] for p in predictions) / len(predictions)
124
- center_lon = sum(p.coordinates[1] for p in predictions) / len(predictions)
125
-
126
- # Initialize map
127
- m = folium.Map(location=[center_lat, center_lon], zoom_start=zoom)
128
-
129
- # Add heatmap if requested
130
- if heatmap and len(predictions) > 1:
131
- heat_data = [
132
- [pred.coordinates[0], pred.coordinates[1], pred.confidence]
133
- for pred in predictions
134
- ]
135
- HeatMap(heat_data, radius=25, blur=15, min_opacity=0.4).add_to(m)
136
-
137
- # Add markers with clustering option
138
- if cluster:
139
- marker_cluster = MarkerCluster().add_to(m)
140
-
141
- for i, pred in enumerate(predictions):
142
- popup_text = f"{title} #{i+1}<br>Coordinates: {pred.coordinates[0]:.6f}, {pred.coordinates[1]:.6f}<br>Confidence: {pred.confidence:.4f}"
143
-
144
- folium.Marker(
145
- location=pred.coordinates,
146
- popup=popup_text,
147
- icon=folium.Icon(color='red' if i == 0 else 'blue')
148
- ).add_to(marker_cluster)
149
- else:
150
- # Add top prediction with distinctive marker
151
- top_pred = predictions[0]
152
- folium.Marker(
153
- location=top_pred.coordinates,
154
- popup=f"{title}<br>Confidence: {top_pred.confidence:.4f}",
155
- icon=folium.Icon(color='red', icon='info-sign')
156
- ).add_to(m)
157
-
158
- # Add remaining predictions with standard markers
159
- for i, pred in enumerate(predictions[1:], 1):
160
- folium.Marker(
161
- location=pred.coordinates,
162
- popup=f"{title} #{i+1}<br>Confidence: {pred.confidence:.4f}",
163
- icon=folium.Icon(color='blue')
164
- ).add_to(m)
165
-
166
- # Return HTML
167
- return m._repr_html_()
168
-
169
- def analyze_at_scales(self, text: str) -> Dict[str, str]:
170
- """
171
- Perform multi-scale geospatial analysis with optimized visualization parameters.
172
-
173
- Args:
174
- text: Location description text
175
-
176
- Returns:
177
- Dictionary mapping scale names to HTML map visualizations
178
- """
179
- # Generate predictions
180
- predictions = self.predict_location(text, top_k=10)
181
-
182
- # Define scale parameters (zoom level, display radius)
183
- scales = {
184
- 'Street': {'zoom': 17, 'radius': 10},
185
- 'Neighborhood': {'zoom': 14, 'radius': 15},
186
- 'City': {'zoom': 11, 'radius': 20},
187
- 'Region': {'zoom': 8, 'radius': 25},
188
- 'Country': {'zoom': 5, 'radius': 30}
189
- }
190
-
191
- # Generate visualizations at each scale
192
- visualizations = {}
193
- for scale_name, params in scales.items():
194
- map_html = self.create_location_map(
195
- predictions=predictions,
196
- title=f"{scale_name} view of {text}",
197
- zoom=params['zoom'],
198
- heatmap=True if scale_name in ['Region', 'Country'] else False
199
- )
200
- visualizations[scale_name] = map_html
201
-
202
- return visualizations
203
-
204
- def analyze_temporal_variations(
205
- self,
206
- base_location: str,
207
- time_periods: List[str]
208
- ) -> Tuple[List[Tuple[str, Tuple[float, float], float]], str]:
209
- """
210
- Analyze location shifts across time periods with trajectory visualization.
211
-
212
- Args:
213
- base_location: Base location descriptor
214
- time_periods: List of time period identifiers
215
-
216
- Returns:
217
- Tuple of (analysis_results, map_html)
218
- """
219
- results = []
220
- m = folium.Map(zoom_start=4)
221
- colors = ['red', 'blue', 'green', 'purple', 'orange', 'darkred', 'darkblue', 'cadetblue', 'darkgreen']
222
-
223
- for period, color in zip(time_periods, colors * (1 + len(time_periods) // len(colors))):
224
- query = f"{base_location} in {period}"
225
- predictions = self.predict_location(query, top_k=1)
226
-
227
- if predictions:
228
- pred = predictions[0]
229
- coords = pred.coordinates
230
- conf = pred.confidence
231
-
232
- # Add marker with period information
233
- folium.Marker(
234
- location=coords,
235
- popup=f"{period}<br>Confidence: {conf:.4f}",
236
- icon=folium.Icon(color=color, icon='info-sign')
237
- ).add_to(m)
238
-
239
- results.append((period, coords, conf))
240
-
241
- # Connect points chronologically with polyline
242
- if len(results) > 1:
243
- points = [coords for _, coords, _ in results]
244
- folium.PolyLine(points, weight=2, color='gray', opacity=0.8,
245
- dash_array='5, 5').add_to(m)
246
-
247
- # Center map on middle point for optimal visualization
248
- if results:
249
- center_point = results[len(results)//2][1]
250
- m.location = center_point
251
-
252
- return results, m._repr_html_()
253
-
254
- def find_related_locations(
255
- self,
256
- reference: str,
257
- candidates: List[str]
258
- ) -> List[Tuple[str, float]]:
259
- """
260
- Identify semantically related locations using embedding cosine similarity.
261
-
262
- Args:
263
- reference: Reference location or concept
264
- candidates: List of candidate locations to compare
265
-
266
- Returns:
267
- List of (location, similarity_score) tuples sorted by relevance
268
- """
269
- with torch.no_grad():
270
- # Generate reference embedding
271
- text_inputs = self.tokenizer(reference, return_tensors="pt", padding=True).to(self.device)
272
- ref_features = self.model.image_encoder.mlp(
273
- self.model.image_encoder.CLIP.get_text_features(**text_inputs)
274
- )
275
- ref_features = F.normalize(ref_features, dim=1)
276
-
277
- results = []
278
- for candidate in candidates:
279
- # Generate candidate embedding
280
- text_inputs = self.tokenizer(candidate, return_tensors="pt", padding=True).to(self.device)
281
- cand_features = self.model.image_encoder.mlp(
282
- self.model.image_encoder.CLIP.get_text_features(**text_inputs)
283
- )
284
- cand_features = F.normalize(cand_features, dim=1)
285
-
286
- # Compute similarity
287
- similarity = F.cosine_similarity(
288
- ref_features, cand_features
289
- ).item()
290
-
291
- results.append((candidate, similarity))
292
-
293
- # Sort by similarity (descending)
294
- return sorted(results, key=lambda x: x[1], reverse=True)
295
-
296
- def visualize_related_locations(
297
- self,
298
- reference: str,
299
- candidates: List[str]
300
- ) -> Tuple[List[Tuple[str, float]], str]:
301
- """
302
- Visualize semantically related locations with map integration.
303
-
304
- Args:
305
- reference: Reference location or concept
306
- candidates: List of candidate locations
307
-
308
- Returns:
309
- Tuple of (similarity_results, map_html)
310
- """
311
- # Compute similarities
312
- related_results = self.find_related_locations(reference, candidates)
313
-
314
- # Predict coordinates for all locations
315
- marker_data = []
316
- ref_predictions = self.predict_location(reference, top_k=1)
317
-
318
- if ref_predictions:
319
- ref_coords = ref_predictions[0].coordinates
320
- ref_conf = ref_predictions[0].confidence
321
- marker_data.append((reference, ref_coords, ref_conf, 'red', 1.0))
322
-
323
- # Get coordinates for each candidate
324
- for candidate, similarity in related_results:
325
- predictions = self.predict_location(candidate, top_k=1)
326
- if predictions:
327
- coords = predictions[0].coordinates
328
- conf = predictions[0].confidence
329
- marker_data.append((candidate, coords, conf, 'blue', similarity))
330
-
331
- # Create map
332
- m = folium.Map()
333
-
334
- for name, coords, conf, color, sim in marker_data:
335
- # Scale marker size by similarity
336
- radius = 8 + (sim * 10) if name != reference else 15
337
-
338
- # Add circle marker
339
- folium.CircleMarker(
340
- location=coords,
341
- radius=radius,
342
- popup=f"{name}<br>Similarity: {sim:.4f}<br>Confidence: {conf:.4f}",
343
- color=color,
344
- fill=True,
345
- fill_color=color
346
- ).add_to(m)
347
-
348
- # Connect to reference with line
349
- if name != reference:
350
- folium.PolyLine(
351
- [ref_coords, coords],
352
- color=color,
353
- weight=sim * 5, # Scale line weight by similarity
354
- opacity=0.7
355
- ).add_to(m)
356
-
357
- # Fit bounds to include all markers
358
- if marker_data:
359
- all_lats = [coords[0] for _, coords, _, _, _ in marker_data]
360
- all_lons = [coords[1] for _, coords, _, _, _ in marker_data]
361
- sw = [min(all_lats), min(all_lons)]
362
- ne = [max(all_lats), max(all_lons)]
363
- m.fit_bounds([sw, ne])
364
-
365
- return related_results, m._repr_html_()
366
-
367
- return related_results, ""
368
-
369
- def comprehensive_analysis(self, location: str) -> Dict[str, Any]:
370
- """
371
- Execute comprehensive multi-faceted location analysis pipeline.
372
-
373
- Args:
374
- location: Target location description
375
-
376
- Returns:
377
- Dictionary containing all analysis results
378
- """
379
- results = {
380
- "query": location,
381
- "timestamp": None, # Can be filled with current timestamp
382
- }
383
-
384
- # Basic prediction
385
- predictions = self.predict_location(location, top_k=5)
386
- results["predictions"] = predictions
387
-
388
- # Create basic map
389
- results["basic_map"] = self.create_location_map(
390
- predictions,
391
- f"'{location}' Predictions"
392
- )
393
-
394
- # Multi-scale analysis
395
- results["scale_maps"] = self.analyze_at_scales(location)
396
-
397
- # Temporal analysis
398
- time_periods = ["ancient times", "middle ages", "19th century", "modern day"]
399
- temporal_results, temporal_map = self.analyze_temporal_variations(location, time_periods)
400
- results["temporal_analysis"] = temporal_results
401
- results["temporal_map"] = temporal_map
402
-
403
- # Related locations analysis
404
- candidates = [
405
- f"{location} business district",
406
- f"{location} historic center",
407
- f"{location} tourist area",
408
- f"{location} downtown",
409
- f"{location} suburbs"
410
- ]
411
- similarity_results, similarity_map = self.visualize_related_locations(
412
- location, candidates
413
- )
414
- results["similarity_analysis"] = similarity_results
415
- results["similarity_map"] = similarity_map
416
-
417
- return results
418
-
419
-
420
- def create_temporal_analysis_ui(analyzer):
421
- """Create the temporal analysis interface component."""
422
- with gr.Column():
423
- gr.Markdown("## Temporal Analysis")
424
- gr.Markdown("Analyze how a location changes across different time periods.")
425
-
426
- with gr.Row():
427
- with gr.Column():
428
- base_location = gr.Textbox(label="Base Location", placeholder="e.g., Constantinople")
429
- with gr.Row():
430
- time_periods = gr.Textbox(
431
- label="Time Periods (comma-separated)",
432
- placeholder="ancient times, middle ages, 19th century, modern day",
433
- value="ancient times, middle ages, 19th century, modern day"
434
- )
435
- temporal_btn = gr.Button("Analyze Temporal Variations", variant="primary")
436
-
437
- with gr.Column():
438
- temporal_results = gr.Dataframe(
439
- headers=["Time Period", "Latitude", "Longitude", "Confidence"],
440
- label="Temporal Analysis Results"
441
- )
442
- temporal_map = gr.HTML(label="Temporal Map")
443
-
444
- def run_temporal_analysis(location, periods_text):
445
- if not location:
446
- return None, "Please enter a base location"
447
-
448
- periods = [p.strip() for p in periods_text.split(",") if p.strip()]
449
- if not periods:
450
- return None, "Please enter at least one time period"
451
-
452
- try:
453
- # Run analysis
454
- results, map_html = analyzer.analyze_temporal_variations(location, periods)
455
-
456
- # Format results for dataframe
457
- df_data = [
458
- [period, coords[0], coords[1], conf]
459
- for period, coords, conf in results
460
- ]
461
-
462
- return df_data, map_html
463
- except Exception as e:
464
- logger.error(f"Error in temporal analysis: {str(e)}")
465
- return None, f"Error: {str(e)}"
466
-
467
- temporal_btn.click(
468
- fn=run_temporal_analysis,
469
- inputs=[base_location, time_periods],
470
- outputs=[temporal_results, temporal_map]
471
- )
472
-
473
- return base_location, time_periods, temporal_btn, temporal_results, temporal_map
474
-
475
-
476
- def create_related_locations_ui(analyzer):
477
- """Create the related locations interface component."""
478
- with gr.Column():
479
- gr.Markdown("## Related Locations Analysis")
480
- gr.Markdown("Find semantically related locations based on GeoCLIP embeddings.")
481
-
482
- with gr.Row():
483
- with gr.Column():
484
- reference_location = gr.Textbox(
485
- label="Reference Location/Concept",
486
- placeholder="e.g., technology hub"
487
- )
488
- candidate_locations = gr.Textbox(
489
- label="Candidate Locations (comma-separated)",
490
- placeholder="Silicon Valley, Shenzhen China, Bangalore India",
491
- value="Silicon Valley, Shenzhen China, Bangalore India, Tel Aviv Israel, London financial district"
492
- )
493
- related_btn = gr.Button("Find Related Locations", variant="primary")
494
-
495
- with gr.Column():
496
- similarity_results = gr.Dataframe(
497
- headers=["Location", "Similarity Score"],
498
- label="Similarity Results"
499
- )
500
- similarity_map = gr.HTML(label="Similarity Map")
501
-
502
- def run_similarity_analysis(reference, candidates_text):
503
- if not reference:
504
- return None, "Please enter a reference location or concept"
505
-
506
- candidates = [c.strip() for c in candidates_text.split(",") if c.strip()]
507
- if not candidates:
508
- return None, "Please enter at least one candidate location"
509
-
510
- try:
511
- # Run analysis
512
- results, map_html = analyzer.visualize_related_locations(reference, candidates)
513
-
514
- # Format results for dataframe
515
- df_data = [
516
- [location, similarity]
517
- for location, similarity in results
518
- ]
519
-
520
- return df_data, map_html
521
- except Exception as e:
522
- logger.error(f"Error in similarity analysis: {str(e)}")
523
- return None, f"Error: {str(e)}"
524
-
525
- related_btn.click(
526
- fn=run_similarity_analysis,
527
- inputs=[reference_location, candidate_locations],
528
- outputs=[similarity_results, similarity_map]
529
- )
530
-
531
- return reference_location, candidate_locations, related_btn, similarity_results, similarity_map
532
-
533
-
534
- def create_comprehensive_analysis_ui(analyzer):
535
- """Create the comprehensive analysis interface component."""
536
- with gr.Column():
537
- gr.Markdown("## Comprehensive Analysis")
538
- gr.Markdown("Perform a full multi-faceted analysis of a location.")
539
-
540
- with gr.Row():
541
- with gr.Column(scale=1):
542
- comp_location = gr.Textbox(
543
- label="Location",
544
- placeholder="e.g., Tokyo Japan"
545
- )
546
- comp_btn = gr.Button("Run Comprehensive Analysis", variant="primary")
547
-
548
- with gr.Column(scale=3):
549
- with gr.Tabs():
550
- with gr.TabItem("Basic Prediction"):
551
- basic_results = gr.Dataframe(
552
- headers=["Rank", "Latitude", "Longitude", "Confidence"],
553
- label="Top Predictions"
554
- )
555
- basic_map = gr.HTML(label="Map")
556
-
557
- with gr.TabItem("Multi-scale Analysis"):
558
- with gr.Tabs() as scale_tabs:
559
- scale_maps = {
560
- scale: gr.HTML(label=f"{scale} Scale")
561
- for scale in ["Street", "Neighborhood", "City", "Region", "Country"]
562
- }
563
-
564
- with gr.TabItem("Temporal Analysis"):
565
- comp_temporal_results = gr.Dataframe(
566
- headers=["Time Period", "Latitude", "Longitude", "Confidence"],
567
- label="Temporal Analysis"
568
- )
569
- comp_temporal_map = gr.HTML(label="Temporal Map")
570
-
571
- with gr.TabItem("Related Contexts"):
572
- comp_similarity_results = gr.Dataframe(
573
- headers=["Context", "Similarity Score"],
574
- label="Related Contexts"
575
- )
576
- comp_similarity_map = gr.HTML(label="Similarity Map")
577
-
578
- def run_comprehensive_analysis(location):
579
- if not location:
580
- return (
581
- None, "",
582
- {"Street": "", "Neighborhood": "", "City": "", "Region": "", "Country": ""},
583
- None, "", None, ""
584
- )
585
-
586
- try:
587
- # Run analysis
588
- results = analyzer.comprehensive_analysis(location)
589
-
590
- # Format basic results
591
- basic_df = [
592
- [i+1, pred.coordinates[0], pred.coordinates[1], pred.confidence]
593
- for i, pred in enumerate(results["predictions"])
594
- ]
595
-
596
- # Format temporal results
597
- temporal_df = [
598
- [period, coords[0], coords[1], conf]
599
- for period, coords, conf in results["temporal_analysis"]
600
- ] if "temporal_analysis" in results else None
601
-
602
- # Format similarity results
603
- similarity_df = [
604
- [location, similarity]
605
- for location, similarity in results["similarity_analysis"]
606
- ] if "similarity_analysis" in results else None
607
-
608
- return (
609
- basic_df,
610
- results["basic_map"],
611
- results["scale_maps"],
612
- temporal_df,
613
- results["temporal_map"],
614
- similarity_df,
615
- results["similarity_map"]
616
- )
617
- except Exception as e:
618
- logger.error(f"Error in comprehensive analysis: {str(e)}")
619
- return (
620
- None, f"Error: {str(e)}",
621
- {"Street": "", "Neighborhood": "", "City": "", "Region": "", "Country": ""},
622
- None, "", None, ""
623
- )
624
-
625
- comp_btn.click(
626
- fn=run_comprehensive_analysis,
627
- inputs=[comp_location],
628
- outputs=[
629
- basic_results, basic_map,
630
- gr.Dict(scale_maps),
631
- comp_temporal_results, comp_temporal_map,
632
- comp_similarity_results, comp_similarity_map
633
- ]
634
- )
635
-
636
- return comp_location, comp_btn, basic_results, basic_map, scale_maps, comp_temporal_results, comp_temporal_map, comp_similarity_results, comp_similarity_map
637
-
638
-
639
- def create_interface():
640
- """Create the Gradio interface for the GeoCLIP Text-to-Location Analyzer."""
641
- # Initialize the analyzer with caching
642
- analyzer = GeoCLIPAnalyzer(cache_enabled=True)
643
-
644
- with gr.Blocks(title="GeoCLIP Text-to-Location Analyzer") as demo:
645
- gr.Markdown("# 🌍 GeoCLIP Text-to-Location Analyzer")
646
- gr.Markdown("""
647
- This interface allows you to analyze geographic locations using GeoCLIP's text-to-location capabilities.
648
- You can perform basic location predictions, temporal analysis, find related locations, and run comprehensive analyses.
649
- """)
650
-
651
- # Basic prediction section
652
- with gr.Column():
653
- gr.Markdown("## Basic Location Prediction")
654
- gr.Markdown("Enter a textual description of a location to get coordinate predictions.")
655
-
656
- with gr.Row():
657
- with gr.Column():
658
- location_input = gr.Textbox(
659
- label="Location Description",
660
- placeholder="e.g., Eiffel Tower Paris"
661
- )
662
- top_k = gr.Slider(
663
- minimum=1, maximum=10, value=5, step=1,
664
- label="Number of Predictions"
665
- )
666
- predict_btn = gr.Button("Predict Location", variant="primary")
667
-
668
- with gr.Column():
669
- prediction_results = gr.Dataframe(
670
- headers=["Rank", "Latitude", "Longitude", "Confidence"],
671
- label="Prediction Results"
672
- )
673
- map_output = gr.HTML(label="Map Visualization")
674
-
675
- # Add tab-based sections for different analyses
676
- with gr.Tabs():
677
- with gr.TabItem("Multi-scale Analysis"):
678
- with gr.Row():
679
- with gr.Column():
680
- scale_location = gr.Textbox(
681
- label="Location Description",
682
- placeholder="e.g., Central Park New York"
683
- )
684
- scale_btn = gr.Button("Analyze at Different Scales", variant="primary")
685
-
686
- with gr.Column():
687
- with gr.Tabs() as scale_tabs:
688
- street_map = gr.HTML(label="Street Level")
689
- neighborhood_map = gr.HTML(label="Neighborhood Level")
690
- city_map = gr.HTML(label="City Level")
691
- region_map = gr.HTML(label="Regional Level")
692
- country_map = gr.HTML(label="Country Level")
693
-
694
- with gr.TabItem("Temporal Analysis"):
695
- base_location, time_periods, temporal_btn, temporal_results, temporal_map = create_temporal_analysis_ui(analyzer)
696
-
697
- with gr.TabItem("Related Locations"):
698
- reference_location, candidate_locations, related_btn, similarity_results, similarity_map = create_related_locations_ui(analyzer)
699
-
700
- with gr.TabItem("Comprehensive Analysis"):
701
- comp_location, comp_btn, basic_results, basic_map, scale_maps, comp_temporal_results, comp_temporal_map, comp_similarity_results, comp_similarity_map = create_comprehensive_analysis_ui(analyzer)
702
-
703
- # Basic prediction handler
704
- def handle_prediction(text, k):
705
- if not text:
706
- return None, "Please enter a location description"
707
-
708
- try:
709
- predictions = analyzer.predict_location(text, top_k=int(k))
710
-
711
- # Format for dataframe
712
- df_data = [
713
- [i+1, pred.coordinates[0], pred.coordinates[1], pred.confidence]
714
- for i, pred in enumerate(predictions)
715
- ]
716
-
717
- # Create map
718
- map_html = analyzer.create_location_map(predictions, f"'{text}' Predictions")
719
-
720
- return df_data, map_html
721
- except Exception as e:
722
- logger.error(f"Error in prediction: {str(e)}")
723
- return None, f"Error: {str(e)}"
724
-
725
- # Multi-scale analysis handler
726
- def handle_scale_analysis(text):
727
- if not text:
728
- return "", "", "", "", ""
729
-
730
  try:
731
- scale_maps = analyzer.analyze_at_scales(text)
732
- return (
733
- scale_maps.get("Street", ""),
734
- scale_maps.get("Neighborhood", ""),
735
- scale_maps.get("City", ""),
736
- scale_maps.get("Region", ""),
737
- scale_maps.get("Country", "")
738
- )
739
  except Exception as e:
740
- logger.error(f"Error in scale analysis: {str(e)}")
741
- error_msg = f"<div class='error'>Error: {str(e)}</div>"
742
- return error_msg, error_msg, error_msg, error_msg, error_msg
743
-
744
- # Set up event handlers
745
- predict_btn.click(
746
- fn=handle_prediction,
747
- inputs=[location_input, top_k],
748
- outputs=[prediction_results, map_output]
749
- )
750
-
751
- scale_btn.click(
752
- fn=handle_scale_analysis,
753
- inputs=[scale_location],
754
- outputs=[street_map, neighborhood_map, city_map, region_map, country_map]
755
- )
756
-
757
- gr.Markdown("""
758
- ## About GeoCLIP
759
-
760
- GeoCLIP is a CLIP-inspired model that aligns locations with images for effective worldwide geo-localization.
761
- This interface uses GeoCLIP's text encoder to map textual descriptions to geographic coordinates.
762
-
763
- All operations use efficient LRU caching for improved performance on repeated queries.
764
-
765
- **Reference:** [GeoCLIP: Clip-Inspired Alignment between Locations and Images for Effective Worldwide Geo-localization](https://arxiv.org/abs/2309.16020)
766
- """)
767
-
768
- return demo
769
-
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
770
 
771
  if __name__ == "__main__":
772
- demo = create_interface()
773
- demo.launch()
 
1
+ from pathlib import Path
2
+ from PIL import Image, ExifTags
3
+ import json
4
+ import sys
5
+ import os
6
  import gradio as gr
 
 
 
 
 
 
 
 
 
 
 
7
  import logging
8
+ from datasets import Dataset
9
+ from typing import Dict, List, Any, Optional
10
+ import traceback
11
 
12
+ # Logging setup
13
  logging.basicConfig(
14
  level=logging.INFO,
15
+ format="%(asctime)s [%(levelname)s] %(message)s",
16
+ handlers=[logging.StreamHandler(sys.stdout)]
17
  )
18
  logger = logging.getLogger(__name__)
19
 
20
+ # Config with defaults (editable via UI or env vars)
21
+ DEFAULT_IMAGE_DIR = Path(os.environ.get("IMAGE_DIR", "./images"))
22
+ DEFAULT_OUTPUT_FILE = Path(os.environ.get("OUTPUT_METADATA_FILE", "./metadata.jsonl"))
23
+ HF_USERNAME = os.environ.get("HF_USERNAME", "latterworks")
24
+ DATASET_NAME = os.environ.get("DATASET_NAME", "geo-metadata")
25
+
26
+ SUPPORTED_EXTENSIONS = {'.jpg', '.jpeg', '.png', '.heic', '.tiff', '.bmp', '.webp'}
27
+
28
+ # Convert GPS coordinates to decimal degrees
29
def _gps_component_to_float(part: Any) -> float:
    """Return one degrees/minutes/seconds component as a float.

    Accepts anything ``float()`` accepts, plus a ``(numerator, denominator)``
    pair, which is how some EXIF readers expose rational values.
    """
    if isinstance(part, (tuple, list)):
        if len(part) != 2:
            raise ValueError("GPS rational must be a (numerator, denominator) pair")
        numerator, denominator = part
        return float(numerator) / float(denominator)
    return float(part)


def convert_to_degrees(value: tuple) -> Optional[float]:
    """Convert a GPS (degrees, minutes, seconds) triple to decimal degrees.

    Args:
        value: A 3-element tuple or list. Each element may be a number, a
            rational object convertible with ``float()``, or a
            ``(numerator, denominator)`` pair.

    Returns:
        ``degrees + minutes / 60 + seconds / 3600`` as a float, or ``None``
        when the input is malformed, contains a zero denominator, or the
        result is outside ``[-180, 180]``. Failures are logged, not raised.
        No hemisphere sign is applied here.
    """
    try:
        if not isinstance(value, (tuple, list)) or len(value) != 3:
            raise ValueError("GPS value must be a tuple of 3 elements")
        d, m, s = (_gps_component_to_float(part) for part in value)
        degrees = d + (m / 60.0) + (s / 3600.0)
        # The chained comparison is False for NaN, so NaN is rejected here too.
        if not -180 <= degrees <= 180:
            raise ValueError("GPS degrees out of valid range")
        return degrees
    except (TypeError, ValueError, ZeroDivisionError) as e:
        logger.error(f"Failed to convert GPS coordinates: {e}")
        return None
41
+
42
+ # Extract and format GPS metadata
43
def extract_gps_info(gps_info: Dict[int, Any]) -> Optional[Dict[str, Any]]:
    """Translate a raw GPS tag dict into named tags plus decimal coordinates.

    Args:
        gps_info: Mapping of numeric GPS tag ids to raw values.

    Returns:
        A dict keyed by the names found in ``ExifTags.GPSTAGS`` (unknown ids
        become ``unknown_gps_tag_<id>``). When both ``GPSLatitude`` and
        ``GPSLongitude`` are present, signed decimal ``Latitude`` and
        ``Longitude`` entries are added. Returns ``None`` when the input is
        not a dict, a coordinate cannot be converted, the latitude exceeds
        90 degrees, or any unexpected error occurs (which is logged).
    """
    if not isinstance(gps_info, dict):
        logger.warning("GPSInfo ain’t a dict, skipping")
        return None

    def _normalize_ref(ref: Any) -> Any:
        # Hemisphere refs may arrive as bytes, lowercase, or NUL/space padded;
        # normalise them so a valid 'S'/'W' is not mistaken for a bad ref.
        if isinstance(ref, bytes):
            ref = ref.decode('ascii', errors='replace')
        if isinstance(ref, str):
            return ref.strip(' \x00').upper()
        return ref

    try:
        gps_data = {
            ExifTags.GPSTAGS.get(key, f"unknown_gps_tag_{key}"): val
            for key, val in gps_info.items()
        }

        if 'GPSLatitude' in gps_data and 'GPSLongitude' in gps_data:
            lat = convert_to_degrees(gps_data['GPSLatitude'])
            lon = convert_to_degrees(gps_data['GPSLongitude'])
            if lat is None or lon is None:
                logger.error("Failed to convert lat/lon, skipping GPS")
                return None
            # convert_to_degrees only enforces the +/-180 longitude bound.
            if abs(lat) > 90:
                logger.error("Latitude out of valid range, skipping GPS")
                return None

            raw_lat_ref = gps_data.get('GPSLatitudeRef', 'N')
            raw_lon_ref = gps_data.get('GPSLongitudeRef', 'E')
            lat_ref = _normalize_ref(raw_lat_ref)
            lon_ref = _normalize_ref(raw_lon_ref)
            if lat_ref not in {'N', 'S'} or lon_ref not in {'E', 'W'}:
                # Unrecognised refs: keep the unsigned values and warn.
                logger.warning(f"Bad GPS ref: {raw_lat_ref}, {raw_lon_ref}")
            else:
                if lat_ref == 'S':
                    lat = -lat
                if lon_ref == 'W':
                    lon = -lon

            gps_data['Latitude'] = lat
            gps_data['Longitude'] = lon

        return gps_data
    except Exception:
        # Best-effort extraction: one malformed GPS block must not abort the
        # caller, so log the full traceback and report "no GPS data".
        logger.error(f"GPS extraction crashed: {traceback.format_exc()}")
        return None
78
+
79
+ # Make stuff JSON-serializable
80
def make_serializable(value: Any) -> Any:
    """Recursively convert a value into something ``json.dumps`` accepts.

    Args:
        value: Any object, typically a raw metadata value.

    Returns:
        * ``None``, ``bool``, ``int`` and ``str`` values unchanged;
        * ``bytes`` decoded as UTF-8 with undecodable bytes replaced;
        * tuples and lists as lists with each item converted;
        * dicts with string keys and converted values;
        * objects exposing ``numerator``/``denominator`` as a float;
        * any other value unchanged if ``json.dumps`` accepts it;
        * ``str(value)`` as a last resort when anything above raises.
    """
    try:
        # int and bool also expose numerator/denominator; return them before
        # the rational check so integers are not turned into floats.
        if value is None or isinstance(value, (bool, int, str)):
            return value
        if isinstance(value, bytes):
            return value.decode('utf-8', errors='replace')
        if isinstance(value, (tuple, list)):
            return [make_serializable(item) for item in value]
        if isinstance(value, dict):
            return {str(k): make_serializable(v) for k, v in value.items()}
        if hasattr(value, 'numerator') and hasattr(value, 'denominator'):
            # A zero denominator raises here and falls through to str(value).
            return float(value.numerator) / float(value.denominator)
        json.dumps(value)  # probe only; the encoded text is discarded
        return value
    except Exception as e:
        logger.warning(f"Serialization failed, stringin’ it: {e}")
        return str(value)
95
+
96
+ # Extract metadata from one image
97
def get_image_metadata(image_path: Path) -> Dict[str, Any]:
    """Collect basic, EXIF, and file-system metadata for a single image.

    Args:
        image_path: Path of the image file to inspect.

    Returns:
        A dict that always contains ``file_name`` (the absolute path). On
        success it also holds ``format``, ``size``, ``mode``, one entry per
        EXIF tag (lower-cased tag name, or ``tag_<id>`` for unknown ids),
        ``gps_info`` when GPS data could be extracted, ``file_size`` and
        ``file_extension``. An EXIF read failure is recorded under
        ``exif_error`` without aborting. Any other failure is logged and a
        dict with only ``file_name`` and ``error`` is returned.
    """
    absolute_name = str(image_path.absolute())
    record: Dict[str, Any] = {"file_name": absolute_name}
    try:
        with Image.open(image_path) as img:
            record["format"] = img.format or "unknown"
            record["size"] = list(img.size)
            record["mode"] = img.mode or "unknown"

            raw_exif = None
            try:
                raw_exif = img._getexif()
            except AttributeError:
                # The opened image object has no _getexif attribute.
                record["exif_error"] = "No EXIF data"
            except Exception as exc:
                record["exif_error"] = f"EXIF crashed: {str(exc)}"

            # Only a non-empty dict of tags is walked; anything else is ignored.
            tags = raw_exif if raw_exif and isinstance(raw_exif, dict) else {}
            for tag_id, raw_value in tags.items():
                key = ExifTags.TAGS.get(tag_id, f"tag_{tag_id}").lower()
                if key != "gpsinfo":
                    record[key] = make_serializable(raw_value)
                    continue
                gps = extract_gps_info(raw_value)
                if gps:
                    record["gps_info"] = make_serializable(gps)

        record["file_size"] = image_path.stat().st_size
        record["file_extension"] = image_path.suffix.lower()
        return record
    except Exception as exc:
        logger.error(f"Image {image_path} crashed: {traceback.format_exc()}")
        return {"file_name": absolute_name, "error": str(exc)}
131
+
132
+ # Process images (single file or directory)
133
def process_images(input_data: str | Path) -> List[Dict[str, Any]]:
    """Extract metadata for one image file or for every image under a directory.

    Args:
        input_data: Path (or path string) of an image file or a directory.

    Returns:
        A list of metadata dicts from ``get_image_metadata``. A directory is
        walked recursively and only files whose lower-cased suffix is in
        ``SUPPORTED_EXTENSIONS`` are processed. If the input is neither a
        supported file nor a directory, a single-element list holding an
        ``error`` entry is returned.
    """
    root = Path(input_data)

    def _is_supported_image(candidate: Path) -> bool:
        # A processable image is a regular file with a supported suffix.
        return candidate.is_file() and candidate.suffix.lower() in SUPPORTED_EXTENSIONS

    if _is_supported_image(root):
        logger.info(f"Processing single image: {root}")
        record = get_image_metadata(root)
        return [record] if record else []

    if not root.is_dir():
        logger.error(f"Invalid input: {input_data}")
        return [{"error": f"Invalid input: {input_data}"}]

    logger.info(f"Processing directory: {root}")
    collected = []
    for candidate in root.rglob("*"):
        if not _is_supported_image(candidate):
            continue
        logger.info(f"Processing: {candidate}")
        record = get_image_metadata(candidate)
        if record:
            collected.append(record)
    return collected
155
+
156
+ # Save to JSONL
157
def save_metadata_to_jsonl(metadata_list: List[Dict[str, Any]], output_file: Path) -> bool:
    """Write each metadata entry as one JSON line to ``output_file``.

    Args:
        metadata_list: Entries to serialize, one per output line.
        output_file: Destination file; missing parent directories are created
            and an existing file is overwritten.

    Returns:
        ``True`` when the file was written, ``False`` if any exception
        occurred (the traceback is logged).
    """
    try:
        output_file.parent.mkdir(parents=True, exist_ok=True)
        with output_file.open('w', encoding='utf-8') as sink:
            sink.writelines(
                json.dumps(record, ensure_ascii=False) + '\n'
                for record in metadata_list
            )
        logger.info(f"Saved {len(metadata_list)} entries to {output_file}")
    except Exception:
        logger.error(f"Save crashed: {traceback.format_exc()}")
        return False
    return True
168
+
169
+ # Upload to Hugging Face
170
def upload_to_huggingface(metadata_file: Path, username: str, dataset_name: str) -> str:
    """Load a JSONL metadata file and push it to the Hugging Face Hub.

    Args:
        metadata_file: JSONL file with one metadata object per line.
        username: Hub namespace that owns the dataset.
        dataset_name: Name of the dataset repository under ``username``.

    Returns:
        A human-readable status message: a success line with the number of
        entries, a notice when the file holds no entries, or an
        ``Upload failed: ...`` message if any step raised (the traceback is
        logged).
    """
    try:
        with metadata_file.open('r', encoding='utf-8') as f:
            # Blank or whitespace-only lines (e.g. a stray trailing newline)
            # are not JSON documents; skip them instead of letting
            # json.loads abort the whole upload.
            metadata_list = [json.loads(line) for line in f if line.strip()]

        if not metadata_list:
            return "No metadata to upload, fam!"

        dataset = Dataset.from_dict({
            "images": [entry.get("file_name") for entry in metadata_list],
            "metadata": metadata_list
        })
        # NOTE: private=False is passed, so the pushed dataset is not private.
        dataset.push_to_hub(f"{username}/{dataset_name}", private=False)
        return f"Uploaded to {username}/{dataset_name} with {len(metadata_list)} entries!"
    except Exception as e:
        logger.error(f"Upload crashed: {traceback.format_exc()}")
        return f"Upload failed: {str(e)}"
189
+
190
+ # Gradio processing function
191
def gradio_process(image_file, dir_path: str, username: str, dataset_name: str) -> str:
    """Gradio callback: extract metadata, save it as JSONL, and upload it.

    Args:
        image_file: Uploaded file from the ``gr.File`` input, or a falsy value
            when nothing was uploaded. Either an object exposing the temp-file
            path as ``.name`` or a plain path string is accepted.
        dir_path: Directory to scan for images; an empty string skips the scan.
        username: Hugging Face namespace passed to ``upload_to_huggingface``.
        dataset_name: Dataset repository name passed to
            ``upload_to_huggingface``.

    Returns:
        The report sections joined by blank lines, or a prompt asking for
        input when there is nothing to report.
    """
    output = []
    metadata_list = []

    # Process single image if uploaded
    if image_file:
        # The upload may arrive as a wrapper with ``.name`` or as a bare path
        # string; fall back to the value itself when ``.name`` is absent.
        image_path = Path(getattr(image_file, "name", image_file))
        metadata_list = process_images(image_path)
        output.append("Single Image Metadata:")
        output.extend(json.dumps(entry, indent=2) for entry in metadata_list)

    # Process directory if provided
    if dir_path:
        directory = Path(dir_path)
        if directory.is_dir():
            # Scan once and reuse the result. Re-running the scan just to
            # compute a slice length doubles the work, and a zero-length
            # result makes ``lst[-0:]`` select the entire list.
            dir_entries = process_images(directory)
            metadata_list.extend(dir_entries)
            output.append("Directory Metadata:")
            output.extend(json.dumps(entry, indent=2) for entry in dir_entries)
        else:
            output.append(f"Error: {directory} ain’t a directory, fam!")

    # Save and upload if we got metadata
    if metadata_list:
        # Relative path: the file lands in the current working directory.
        temp_output_file = Path("temp_metadata.jsonl")
        if save_metadata_to_jsonl(metadata_list, temp_output_file):
            output.append(f"Saved metadata to {temp_output_file}")
            output.append(upload_to_huggingface(temp_output_file, username, dataset_name))
        else:
            output.append("Save failed, dawg!")

    return "\n\n".join(output) if output else "Drop an image or dir, fam!"
225
+
226
+ # Gradio interface
227
# Input widgets, listed in the positional order gradio_process expects:
# uploaded file, directory path, Hub username, dataset name.
_interface_inputs = [
    gr.File(label="Upload Image", file_types=list(SUPPORTED_EXTENSIONS)),
    gr.Textbox(
        label="Image Directory",
        placeholder=str(DEFAULT_IMAGE_DIR),
        value=str(DEFAULT_IMAGE_DIR),
    ),
    gr.Textbox(label="Hugging Face Username", value=HF_USERNAME),
    gr.Textbox(label="Dataset Name", value=DATASET_NAME),
]

# Single text box that shows the report string returned by gradio_process.
_interface_output = gr.Textbox(label="Metadata Output")

demo = gr.Interface(
    fn=gradio_process,
    inputs=_interface_inputs,
    outputs=_interface_output,
    title="Geo-Metadata Extractor",
    description="Upload an image or point to a directory to extract metadata and push to Hugging Face, Bay Area style!",
    allow_flagging="never",
)

if __name__ == "__main__":
    logger.info("Firin’ up the Gradio geo-metadata extractor...")
    # Listen on every interface, port 7860.
    demo.launch(server_name="0.0.0.0", server_port=7860)