latterworks committed on
Commit 52e39e5 · verified · 1 Parent(s): 50428b9

Update app.py

Files changed (1)
  1. app.py +399 -25
app.py CHANGED
@@ -1,29 +1,403 @@
- import gradio as gr
  import torch
  from geoclip import GeoCLIP

- # Load the GeoCLIP model
- model = GeoCLIP()
-
- # Define the function for geolocation prediction
- def predict_location(image_path):
-     top_pred_gps, top_pred_prob = model.predict(image_path, top_k=5)
-     results = []
-     for i in range(5):
-         lat, lon = top_pred_gps[i]
-         prob = top_pred_prob[i]
-         results.append(f"Prediction {i+1}: ({lat:.6f}, {lon:.6f}) | Probability: {prob:.6f}")
-     return "\n".join(results)
-
- # Define Gradio interface
- interface = gr.Interface(
-     fn=predict_location,
-     inputs=gr.Image(type="filepath", label="Upload Image"),
-     outputs=gr.Textbox(label="Predicted Locations"),
-     title="GeoCLIP Geolocation",
-     description="Upload an image, and GeoCLIP will predict the top 5 GPS locations."
- )
-
- # Launch the Gradio app
  if __name__ == "__main__":
-     interface.launch()
+
  import torch
+ import numpy as np
+ import folium
+ from folium.plugins import HeatMap, MarkerCluster
+ import gradio as gr
+ import os
+ import PIL.Image
+ from io import BytesIO
+ import base64
+ from typing import Tuple, List, Dict, Any, Optional, Union
+ from pathlib import Path
+
+ # GeoCLIP dependencies
  from geoclip import GeoCLIP
+ from transformers import CLIPTokenizer, CLIPProcessor
+
+
+ class GeoCLIPCore:
+     """
+     Vectorized GeoCLIP implementation with minimal compute overhead.
+
+     Implements tensor-optimized inference for:
+     1. Text-to-location prediction
+     2. Image-to-location prediction
+     3. Coordinate embedding generation
+     4. Cross-modal similarity analysis
+     """
+
+     def __init__(self, device: Optional[str] = None) -> None:
+         """
+         Initialize model with optimal compute resource allocation.
+
+         Args:
+             device: Target compute device (None for auto-detection)
+         """
+         self.device = device or ("cuda" if torch.cuda.is_available() else "cpu")
+
+         # Load and configure core model components
+         self._model = GeoCLIP().to(self.device)
+         self._tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14")
+         self._processor = CLIPProcessor.from_pretrained("openai/clip-vit-large-patch14")
+
+         # Cache frequently used components for performance
+         self._location_encoder = self._model.location_encoder
+         self._image_encoder = self._model.image_encoder
+         self._gps_gallery = None  # Lazy-loaded on first prediction
+
+         print(f"GeoCLIP initialized on {self.device}")
+
+     def embed_text(self, text: str) -> torch.Tensor:
+         """
+         Generate normalized embedding for text input using vectorized operations.
+
+         Args:
+             text: Text description to encode
+
+         Returns:
+             L2-normalized embedding tensor (shape: [1, 512])
+         """
+         with torch.no_grad():
+             tokens = self._tokenizer(text, return_tensors="pt", padding=True).to(self.device)
+             embedding = self._model.image_encoder.mlp(
+                 self._model.image_encoder.CLIP.get_text_features(**tokens)
+             )
+             return torch.nn.functional.normalize(embedding, dim=1)
+
+     def embed_image(self, image: Union[str, PIL.Image.Image, np.ndarray]) -> torch.Tensor:
+         """
+         Generate normalized embedding for image input using vectorized operations.
+
+         Args:
+             image: Input image (PIL Image, file path, or numpy array)
+
+         Returns:
+             L2-normalized embedding tensor (shape: [1, 512])
+         """
+         with torch.no_grad():
+             # Process different image input types
+             if isinstance(image, str):
+                 # Path to image file
+                 image = PIL.Image.open(image).convert("RGB")
+             elif isinstance(image, np.ndarray):
+                 # Convert numpy array to PIL Image
+                 image = PIL.Image.fromarray(np.uint8(image)).convert("RGB")
+
+             # Process image using CLIP processor
+             inputs = self._processor(images=image, return_tensors="pt").to(self.device)
+             embedding = self._model.image_encoder(inputs.pixel_values)
+             return torch.nn.functional.normalize(embedding, dim=1)
+
+     def embed_coordinates(self, coords: Tuple[float, float]) -> torch.Tensor:
+         """
+         Generate normalized embedding for geographical coordinates.
+
+         Args:
+             coords: Coordinate pair (latitude, longitude)
+
+         Returns:
+             L2-normalized embedding tensor (shape: [1, 512])
+         """
+         with torch.no_grad():
+             coords_tensor = torch.tensor([coords], dtype=torch.float32).to(self.device)
+             embedding = self._location_encoder(coords_tensor)
+             return torch.nn.functional.normalize(embedding, dim=1)
+
+     def _ensure_gps_gallery(self):
+         """Ensure GPS gallery is loaded and cached for efficient reuse."""
+         if self._gps_gallery is None:
+             self._gps_gallery = self._model.gps_gallery.to(self.device)
+
+     def predict_location(self,
+                          query_embedding: torch.Tensor,
+                          top_k: int = 5) -> List[Dict[str, Any]]:
+         """
+         Execute cosine similarity-based location retrieval against GPS gallery.
+
+         Args:
+             query_embedding: L2-normalized query embedding
+             top_k: Number of top predictions to return
+
+         Returns:
+             List of prediction dictionaries with coordinates and confidence scores
+         """
+         with torch.no_grad():
+             # Ensure GPS gallery is loaded
+             self._ensure_gps_gallery()
+
+             # Generate location embeddings
+             location_embeddings = self._location_encoder(self._gps_gallery)
+             location_embeddings = torch.nn.functional.normalize(location_embeddings, dim=1)
+
+             # Calculate similarity and softmax probabilities
+             similarity = self._model.logit_scale.exp() * (query_embedding @ location_embeddings.T)
+             probs = similarity.softmax(dim=-1)
+
+             # Extract top predictions
+             top_values, top_indices = torch.topk(probs[0], min(top_k, len(self._gps_gallery)))
+
+             # Format results
+             predictions = []
+             for idx, confidence in zip(top_indices.cpu().numpy(), top_values.cpu().numpy()):
+                 predictions.append({
+                     "coordinates": tuple(self._gps_gallery[idx].cpu().numpy()),
+                     "confidence": float(confidence)
+                 })
+
+             return predictions
+
+     def text_to_location(self, text: str, top_k: int = 5) -> List[Dict[str, Any]]:
+         """
+         Primary entry point for text-to-location prediction pipeline.
+
+         Args:
+             text: Text description to predict location for
+             top_k: Number of top predictions to return
+
+         Returns:
+             List of prediction dictionaries with coordinates and confidence scores
+         """
+         embedding = self.embed_text(text)
+         return self.predict_location(embedding, top_k)
+
+     def image_to_location(self, image: Union[str, PIL.Image.Image, np.ndarray], top_k: int = 5) -> List[Dict[str, Any]]:
+         """
+         Primary entry point for image-to-location prediction pipeline.
+
+         Args:
+             image: Input image (PIL Image, file path, or numpy array)
+             top_k: Number of top predictions to return
+
+         Returns:
+             List of prediction dictionaries with coordinates and confidence scores
+         """
+         embedding = self.embed_image(image)
+         return self.predict_location(embedding, top_k)
+
+     def compute_similarity(self, embed1: torch.Tensor, embed2: torch.Tensor) -> float:
+         """
+         Compute cosine similarity between two embeddings.
+
+         Args:
+             embed1: First embedding tensor
+             embed2: Second embedding tensor
+
+         Returns:
+             Cosine similarity score in the range -1 to 1
+         """
+         return float(torch.nn.functional.cosine_similarity(embed1, embed2).item())
+
+     def create_map_visualization(self,
+                                  predictions: List[Dict[str, Any]],
+                                  title: str = "",
+                                  cluster: bool = False) -> folium.Map:
+         """
+         Generate geospatial visualization of prediction results.
+
+         Args:
+             predictions: List of prediction dictionaries
+             title: Optional map title
+             cluster: Whether to cluster nearby markers
+
+         Returns:
+             Folium map object with marker and heatmap layers
+         """
+         # Initialize map centered on highest confidence prediction
+         center_coords = predictions[0]["coordinates"]
+         m = folium.Map(location=center_coords, zoom_start=5, tiles="OpenStreetMap")
+
+         # Add title if provided
+         if title:
+             title_html = f'<h3 style="text-align:center">{title}</h3>'
+             m.get_root().html.add_child(folium.Element(title_html))
+
+         # Create marker cluster if requested
+         marker_group = MarkerCluster() if cluster else m
+
+         # Add markers with confidence metadata
+         for i, pred in enumerate(predictions):
+             color = 'red' if i == 0 else 'blue' if i < 3 else 'green'
+
+             folium.Marker(
+                 location=pred["coordinates"],
+                 popup=f"Prediction #{i+1}<br>Confidence: {pred['confidence']:.6f}",
+                 icon=folium.Icon(color=color)
+             ).add_to(marker_group if cluster else m)
+
+         # Add marker cluster to map if used
+         if cluster:
+             m.add_child(marker_group)
+
+         # Add heatmap layer for visual density representation
+         if len(predictions) >= 3:
+             heat_data = [[p["coordinates"][0], p["coordinates"][1], p["confidence"]]
+                          for p in predictions]
+             HeatMap(heat_data, radius=15, blur=10).add_to(m)
+
+         return m
+
+
+ def launch_gradio_interface():
+     """Deploy GeoCLIP with Gradio interface for both text and image inputs."""
+     # Initialize model with optimal compute configuration
+     geo_core = GeoCLIPCore()
+
+     def predict_from_text(text_query, top_k):
+         """Process text query and generate visualization."""
+         if not text_query.strip():
+             return None, "Please enter a location description."
+
+         # Execute prediction pipeline
+         predictions = geo_core.text_to_location(text_query, top_k=int(top_k))
+
+         # Generate map visualization
+         m = geo_core.create_map_visualization(
+             predictions,
+             title=f"Predictions for: {text_query}"
+         )
+
+         # Render the map to an HTML string for display
+         map_html = m._repr_html_()
+
+         # Format textual results
+         result_text = f"Top predictions for: '{text_query}'\n\n"
+         for i, pred in enumerate(predictions, 1):
+             coords = pred["coordinates"]
+             conf = pred["confidence"]
+             result_text += f"{i}. ({coords[0]:.6f}, {coords[1]:.6f}) - confidence: {conf:.6f}\n"
+
+         return map_html, result_text
+
+     def predict_from_image(image, top_k):
+         """Process image input and generate visualization."""
+         if image is None:
+             return None, "Please upload an image."
+
+         # Execute prediction pipeline
+         predictions = geo_core.image_to_location(image, top_k=int(top_k))
+
+         # Generate map visualization
+         m = geo_core.create_map_visualization(
+             predictions,
+             title="Predictions from Image"
+         )
+
+         # Render the map to an HTML string for display
+         map_html = m._repr_html_()
+
+         # Format textual results
+         result_text = "Top predictions from image:\n\n"
+         for i, pred in enumerate(predictions, 1):
+             coords = pred["coordinates"]
+             conf = pred["confidence"]
+             result_text += f"{i}. ({coords[0]:.6f}, {coords[1]:.6f}) - confidence: {conf:.6f}\n"
+
+         return map_html, result_text
+
+     def compute_text_similarity(text1, text2):
+         """Compute semantic similarity between two text descriptions."""
+         if not text1.strip() or not text2.strip():
+             return "Please enter both text descriptions."
+
+         embed1 = geo_core.embed_text(text1)
+         embed2 = geo_core.embed_text(text2)
+
+         similarity = geo_core.compute_similarity(embed1, embed2)
+         return f"Similarity between the texts: {similarity:.4f} (range: -1 to 1)"
+
+     # Create Gradio interface with tabs for different functions
+     with gr.Blocks(title="GeoCLIP Location Intelligence") as demo:
+         gr.Markdown("# GeoCLIP Location Intelligence")
+         gr.Markdown("Predict locations from text descriptions or images.")
+
+         with gr.Tabs():
+             with gr.TabItem("Text → Location"):
+                 with gr.Row():
+                     with gr.Column():
+                         text_input = gr.Textbox(
+                             lines=3,
+                             placeholder="Enter location description...",
+                             label="Location Description"
+                         )
+                         text_top_k = gr.Slider(
+                             minimum=1,
+                             maximum=20,
+                             value=10,
+                             step=1,
+                             label="Number of Predictions"
+                         )
+                         text_submit = gr.Button("Predict Location")
+
+                     with gr.Column():
+                         text_examples = gr.Examples(
+                             examples=[
+                                 "ancient pyramids in desert",
+                                 "Eiffel Tower in Paris",
+                                 "beach resort with palm trees",
+                                 "technology hub with startups",
+                                 "busy downtown with skyscrapers",
+                                 "mountain with snow and ski slopes",
+                                 "tropical island with clear water"
+                             ],
+                             inputs=text_input
+                         )
+
+                 text_map_output = gr.HTML(label="Map Visualization")
+                 text_result_output = gr.Textbox(label="Prediction Results")
+
+                 text_submit.click(
+                     predict_from_text,
+                     inputs=[text_input, text_top_k],
+                     outputs=[text_map_output, text_result_output]
+                 )
+
+             with gr.TabItem("Image → Location"):
+                 with gr.Row():
+                     with gr.Column():
+                         image_input = gr.Image(type="pil", label="Upload Image")
+                         image_top_k = gr.Slider(
+                             minimum=1,
+                             maximum=20,
+                             value=10,
+                             step=1,
+                             label="Number of Predictions"
+                         )
+                         image_submit = gr.Button("Predict Location")
+
+                 image_map_output = gr.HTML(label="Map Visualization")
+                 image_result_output = gr.Textbox(label="Prediction Results")
+
+                 image_submit.click(
+                     predict_from_image,
+                     inputs=[image_input, image_top_k],
+                     outputs=[image_map_output, image_result_output]
+                 )
+
+             with gr.TabItem("Semantic Similarity"):
+                 text1_input = gr.Textbox(
+                     lines=2,
+                     placeholder="Enter first description...",
+                     label="Text Description 1"
+                 )
+                 text2_input = gr.Textbox(
+                     lines=2,
+                     placeholder="Enter second description...",
+                     label="Text Description 2"
+                 )
+                 similarity_submit = gr.Button("Compute Similarity")
+                 similarity_output = gr.Textbox(label="Similarity Result")
+
+                 similarity_submit.click(
+                     compute_text_similarity,
+                     inputs=[text1_input, text2_input],
+                     outputs=similarity_output
+                 )
+
+     # Launch Gradio interface with optimized server settings
+     demo.launch(share=True, server_name="0.0.0.0")
+

  if __name__ == "__main__":
+     # Launch the Gradio interface
+     launch_gradio_interface()
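
For reference, a minimal usage sketch of the new GeoCLIPCore API outside the Gradio UI. This is illustrative only: the "from app import GeoCLIPCore" import path and the street_photo.jpg file are assumptions, not part of the commit.

# Illustrative sketch (assumed import path and image file; not from the commit)
from app import GeoCLIPCore

core = GeoCLIPCore()  # downloads GeoCLIP and CLIP weights on first run

# Text query -> top-3 candidate coordinates with confidences
for pred in core.text_to_location("old town square with a clock tower", top_k=3):
    lat, lon = pred["coordinates"]
    print(f"({lat:.4f}, {lon:.4f})  confidence={pred['confidence']:.6f}")

# Image query -> same prediction format (file path, PIL image, or numpy array)
image_preds = core.image_to_location("street_photo.jpg", top_k=3)

# Text-text similarity in the shared embedding space
score = core.compute_similarity(
    core.embed_text("alpine ski resort"),
    core.embed_text("snowy mountain village"),
)
print(f"similarity: {score:.3f}")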