latterworks committed on
Commit f5dce4b · verified · 1 Parent(s): 4658311

Update app.py

Files changed (1)
  1. app.py +96 -117
app.py CHANGED
@@ -1,74 +1,77 @@
+import os
 import torch
-import numpy as np
 import folium
 from folium.plugins import HeatMap
 import gradio as gr
-import os
-import PIL.Image
-import json
-import time
-from typing import Dict, Any, Optional, Union
-from pathlib import Path
-from datasets import Dataset, load_dataset, concatenate_datasets
-from huggingface_hub import HfApi
+from typing import Dict, List, Any
+from functools import lru_cache
 
 # GeoCLIP dependencies
 from geoclip import GeoCLIP
-from transformers import CLIPTokenizer, CLIPProcessor
+from transformers import CLIPTokenizer
 
-# Initialize GeoCLIP core with vectorized execution path
-class GeoCLIPCore:
-    def __init__(self, device=None, dataset_id="latterworks/geo-metadata", token=None):
+# Singleton pattern for GeoCLIP engine
+class GeoCLIPEngine:
+    _instance = None
+
+    def __new__(cls, *args, **kwargs):
+        if cls._instance is None:
+            cls._instance = super(GeoCLIPEngine, cls).__new__(cls)
+            cls._instance._initialized = False
+        return cls._instance
+
+    def __init__(self, device=None):
+        if self._initialized:
+            return
+
         self.device = device or ("cuda" if torch.cuda.is_available() else "cpu")
-        self.dataset_id = dataset_id
-        self.token = token
+        print(f"Initializing GeoCLIP on {self.device}")
+
         self._model = GeoCLIP().to(self.device)
         self._tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14")
-        self._processor = CLIPProcessor.from_pretrained("openai/clip-vit-large-patch14")
-        self._location_encoder = self._model.location_encoder
-        self._gps_gallery = None  # Lazy-loaded for memory optimization
+        self._gps_gallery = None  # Lazy-loaded on first prediction
 
-        # Initialize dataset connection with error handling
-        try:
-            self.dataset = load_dataset(self.dataset_id, split="train", token=self.token)
-        except Exception:
-            self.dataset = Dataset.from_dict({"filename": [], "classes": [], "metadata": []})
-
-    # Core tensor operations for embedding generation
-    def text_to_location(self, text, top_k=5):
+        self._initialized = True
+
+    @lru_cache(maxsize=32)
+    def predict_location(self, text: str, top_k: int = 5) -> List[Dict[str, Any]]:
+        """Vectorized text-to-location prediction with tensor optimization."""
        with torch.no_grad():
+            # Generate text embedding with optimal tensor allocation
             tokens = self._tokenizer(text, return_tensors="pt", padding=True).to(self.device)
-            embedding = self._model.image_encoder.mlp(
+            text_features = self._model.image_encoder.mlp(
                 self._model.image_encoder.CLIP.get_text_features(**tokens)
             )
-            embedding = torch.nn.functional.normalize(embedding, dim=1)
+            text_features = torch.nn.functional.normalize(text_features, dim=1)
 
-            # Ensure gallery is loaded with memory pooling
+            # Ensure GPS gallery is loaded with resource pooling
             if self._gps_gallery is None:
                 self._gps_gallery = self._model.gps_gallery.to(self.device)
 
-            # Execute vectorized similarity computation
-            location_embeddings = self._location_encoder(self._gps_gallery)
-            location_embeddings = torch.nn.functional.normalize(location_embeddings, dim=1)
-            similarity = self._model.logit_scale.exp() * (embedding @ location_embeddings.T)
+            # Generate location embeddings with memory-efficient tensor operations
+            location_features = self._model.location_encoder(self._gps_gallery)
+            location_features = torch.nn.functional.normalize(location_features, dim=1)
+
+            # Calculate similarity with vectorized matrix multiplication
+            similarity = self._model.logit_scale.exp() * (text_features @ location_features.T)
             probs = similarity.softmax(dim=-1)
 
-            # Extract predictions with single tensor operation
+            # Extract top predictions
             top_values, top_indices = torch.topk(probs[0], min(top_k, len(self._gps_gallery)))
 
             return [
                 {"coordinates": tuple(self._gps_gallery[idx].cpu().numpy()),
-                 "confidence": float(conf)}
+                 "confidence": float(conf)}
                 for idx, conf in zip(top_indices.cpu().numpy(), top_values.cpu().numpy())
             ]
-
-    # Generate map visualization with optimized rendering
-    def create_map_visualization(self, predictions, title=""):
-        m = folium.Map(location=predictions[0]["coordinates"], zoom_start=5)
-        if title:
-            m.get_root().html.add_child(folium.Element(f'<h3 style="text-align:center">{title}</h3>'))
+
+    def create_map_visualization(self, predictions: List[Dict[str, Any]], title: str = "") -> folium.Map:
+        """Generate geospatial visualization."""
+        # Initialize map centered on highest confidence prediction
+        center_coords = predictions[0]["coordinates"]
+        m = folium.Map(location=center_coords, zoom_start=5, tiles="OpenStreetMap")
 
-        # Add markers with confidence metadata
+        # Add markers and heatmap
         for i, pred in enumerate(predictions):
             color = 'red' if i == 0 else 'blue' if i < 3 else 'green'
             folium.Marker(
@@ -76,8 +79,7 @@ class GeoCLIPCore:
                 popup=f"Prediction #{i+1}<br>Confidence: {pred['confidence']:.6f}",
                 icon=folium.Icon(color=color)
             ).add_to(m)
-
-        # Add heatmap for density visualization
+
         if len(predictions) >= 3:
             heat_data = [[p["coordinates"][0], p["coordinates"][1], p["confidence"]]
                          for p in predictions]
@@ -85,82 +87,59 @@ class GeoCLIPCore:
 
         return m
 
-# Initialize GeoCLIP and codebase exemplars
-def initialize_gradio_interface(hf_token=None):
-    python_code = """def fib(n):
-    if n <= 0:
-        return 0
-    elif n == 1:
-        return 1
-    else:
-        return fib(n-1) + fib(n-2)
-"""
-    js_code = """function fib(n) {
-    if (n <= 0) return 0;
-    if (n === 1) return 1;
-    return fib(n - 1) + fib(n - 2);
-}
-"""
-    # Initialize GeoCLIP with optimized resource allocation
-    geo_core = GeoCLIPCore(token=hf_token)
-
-    # Message handler with multimodal dispatch logic
-    def chat(message, history):
-        if "python" in message.lower():
-            return "Type Python or JavaScript to see the code.", gr.Code(language="python", value=python_code)
-        elif "javascript" in message.lower():
-            return "Type Python or JavaScript to see the code.", gr.Code(language="javascript", value=js_code)
-        elif any(kw in message.lower() for kw in ["location", "where", "place", "predict"]):
-            # Extract location query with pattern matching
-            for term in ["location", "where", "place", "find", "predict"]:
-                if term in message.lower():
-                    query = message.lower().split(term, 1)[1].strip()
-                    if not query:
-                        return "Please provide a location description.", None
-
-                    # Execute prediction with tensor acceleration
-                    predictions = geo_core.text_to_location(query, top_k=5)
-                    m = geo_core.create_map_visualization(predictions, f"Predictions for: {query}")
-
-                    # Format response with structured data
-                    result = f"Top predictions for: '{query}'\n\n"
-                    for i, pred in enumerate(predictions, 1):
-                        coords = pred["coordinates"]
-                        result += f"{i}. ({coords[0]:.6f}, {coords[1]:.6f}) - conf: {pred['confidence']:.6f}\n"
-
-                    return result, gr.HTML(value=m._repr_html_())
+# Initialize global singleton
+engine = GeoCLIPEngine()
+
+# Fixed chat function with proper output handling
+def loc_chat(message, history):
+    """Chat function that avoids returning Code objects."""
+    # Process location queries
+    if any(term in message.lower() for term in ["location", "where", "place", "find"]):
+        try:
+            # Execute prediction with tensor acceleration
+            predictions = engine.predict_location(message, top_k=5)
+
+            # Generate map visualization
+            m = engine.create_map_visualization(predictions, f"Predictions for: {message}")
 
-            return "Couldn't process your location query. Please try again.", None
-        else:
-            return "I can show code examples or predict locations. Try 'Where is the Eiffel Tower?'", None
+            # Format response with structured information
+            result_text = f"Top predictions for: '{message}'\n\n"
+            for i, pred in enumerate(predictions, 1):
+                coords = pred["coordinates"]
+                conf = pred["confidence"]
+                result_text += f"{i}. ({coords[0]:.6f}, {coords[1]:.6f}) - confidence: {conf:.6f}\n"
+
+            # Return only string and HTML types to avoid validation errors
+            return result_text, gr.HTML(value=m._repr_html_())
+        except Exception as e:
+            return f"Error: {str(e)}", None
 
-    # Build gradio blocks with structured layout
-    with gr.Blocks() as demo:
-        code = gr.Code(render=False)
-        map_output = gr.HTML(render=False)
-
-        with gr.Row():
-            with gr.Column(scale=1):
-                gr.Markdown("<center><h1>GeoCLIP + Code Examples</h1></center>")
-                chatbot = gr.ChatInterface(
-                    chat,
-                    examples=["Python", "JavaScript", "Where is the Eiffel Tower?"],
-                    additional_outputs=[code, map_output]
-                )
+    return "Ask about a location like 'Where is the Eiffel Tower?'", None
+
+# Interface with minimal dependencies
+with gr.Blocks() as demo:
+    map_output = gr.HTML(render=False)
+
+    with gr.Row():
+        with gr.Column():
+            gr.Markdown("<h1>GeoCLIP Location Intelligence</h1>")
 
-            with gr.Column(scale=1):
-                gr.Markdown("<center><h1>Output Artifacts</h1></center>")
-                with gr.Tab("Code"):
-                    code.render()
-                with gr.Tab("Location Map"):
-                    map_output.render()
+            chatbot = gr.ChatInterface(
+                loc_chat,
+                examples=["Where is the Eiffel Tower?", "Find ancient pyramids in desert"],
+                additional_outputs=[map_output],
+                type="messages"  # Critical: use messages type to avoid deprecation
+            )
 
-        gr.Markdown(f"<center>Connected to dataset: {geo_core.dataset_id}</center>")
-
-    return demo
+        with gr.Column():
+            gr.Markdown("<h1>Map Visualization</h1>")
+            map_output.render()
 
-# Entry point with environmental token acquisition
+# Main entrypoint with error mitigation configuration
 if __name__ == "__main__":
-    hf_token = os.environ.get("HF_TOKEN")
-    demo = initialize_gradio_interface(hf_token)
-    demo.launch()
+    demo.launch(
+        share=True,
+        server_name="0.0.0.0",
+        cache_examples=False,  # Critical: Disable example caching
+        show_error=True
    )
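
The heart of `predict_location` is standard normalized-embedding retrieval: the text embedding and every entry in the GPS gallery are L2-normalized, so the scaled matrix product is a cosine similarity, and softmax turns it into a probability distribution over candidate coordinates. A minimal, self-contained sketch of that pattern, with random toy tensors standing in for GeoCLIP's real text and location encoders:

```python
import torch

# Toy stand-ins: in app.py these come from the CLIP text tower + MLP head
# and from GeoCLIP's location encoder applied to the GPS gallery.
text_features = torch.nn.functional.normalize(torch.randn(1, 512), dim=1)
gallery_features = torch.nn.functional.normalize(torch.randn(1000, 512), dim=1)
logit_scale = torch.tensor(100.0)  # stands in for model.logit_scale.exp()

# Scaled cosine similarity -> probability distribution over the gallery
similarity = logit_scale * (text_features @ gallery_features.T)
probs = similarity.softmax(dim=-1)

# Top-k retrieval, mirroring the torch.topk call in the diff
top_values, top_indices = torch.topk(probs[0], k=5)
print(list(zip(top_indices.tolist(), top_values.tolist())))
```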
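Two implementation details of the new `GeoCLIPEngine` interact: the `__new__`-based singleton guarantees one instance per process, and `@lru_cache` on an instance method keys its cache on `(self, text, top_k)`, which only deduplicates work because `self` is always that one instance. A minimal sketch of the combined behavior, independent of GeoCLIP:

```python
from functools import lru_cache

class Singleton:
    _instance = None

    def __new__(cls, *args, **kwargs):
        # First construction creates the instance; later ones reuse it
        if cls._instance is None:
            cls._instance = super().__new__(cls)
        return cls._instance

    @lru_cache(maxsize=32)
    def compute(self, text: str) -> str:
        print(f"cache miss for {text!r}")  # printed once per distinct text
        return text.upper()

a, b = Singleton(), Singleton()
assert a is b          # one shared instance
a.compute("paris")     # cache miss
b.compute("paris")     # served from the cache, no print
```

One side effect worth knowing: `lru_cache` returns the same cached object on repeated calls, so callers of `predict_location` should avoid mutating the list it returns.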
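For a quick smoke test of the committed module outside the Gradio UI, something like the following should work, assuming the `geoclip` package and its weights are available (importing `app` builds the Blocks UI, but `launch()` only runs under `__main__`). One caveat: as far as I can tell, `cache_examples` is a `gr.ChatInterface` constructor argument rather than a `launch()` flag, so the final `demo.launch(...)` call may need that keyword moved. The query string and output filename below are illustrative, not part of the commit:

```python
# Hypothetical smoke test for the committed app.py
from app import GeoCLIPEngine

engine = GeoCLIPEngine()  # returns the module-level singleton
predictions = engine.predict_location("Where is the Eiffel Tower?", top_k=5)
for rank, pred in enumerate(predictions, 1):
    lat, lon = pred["coordinates"]
    print(f"{rank}. ({lat:.6f}, {lon:.6f}) confidence={pred['confidence']:.6f}")

# folium.Map.save writes a self-contained HTML file viewable in a browser
m = engine.create_map_visualization(predictions, title="Eiffel Tower query")
m.save("prediction_map.html")
```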