latterworks committed
Commit 87b4d97 · verified · 1 Parent(s): f5dce4b

Update app.py

Files changed (1)
  1. app.py +643 -123
app.py CHANGED
@@ -1,145 +1,665 @@
1
  import os
2
- import torch
 
3
  import folium
4
- from folium.plugins import HeatMap
5
- import gradio as gr
6
- from typing import Dict, List, Any
7
- from functools import lru_cache
8
 
9
- # GeoCLIP dependencies
10
- from geoclip import GeoCLIP
11
- from transformers import CLIPTokenizer
12
 
13
- # Singleton pattern for GeoCLIP engine
14
- class GeoCLIPEngine:
15
- _instance = None
16
 
17
- def __new__(cls, *args, **kwargs):
18
- if cls._instance is None:
19
- cls._instance = super(GeoCLIPEngine, cls).__new__(cls)
20
- cls._instance._initialized = False
21
- return cls._instance
22
 
23
- def __init__(self, device=None):
24
- if self._initialized:
25
- return
26
-
27
- self.device = device or ("cuda" if torch.cuda.is_available() else "cpu")
28
- print(f"Initializing GeoCLIP on {self.device}")
29
 
30
- self._model = GeoCLIP().to(self.device)
31
- self._tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14")
32
- self._gps_gallery = None # Lazy-loaded on first prediction
33
 
34
- self._initialized = True
35
 
36
- @lru_cache(maxsize=32)
37
- def predict_location(self, text: str, top_k: int = 5) -> List[Dict[str, Any]]:
38
- """Vectorized text-to-location prediction with tensor optimization."""
39
- with torch.no_grad():
40
- # Generate text embedding with optimal tensor allocation
41
- tokens = self._tokenizer(text, return_tensors="pt", padding=True).to(self.device)
42
- text_features = self._model.image_encoder.mlp(
43
- self._model.image_encoder.CLIP.get_text_features(**tokens)
44
- )
45
- text_features = torch.nn.functional.normalize(text_features, dim=1)
46
-
47
- # Ensure GPS gallery is loaded with resource pooling
48
- if self._gps_gallery is None:
49
- self._gps_gallery = self._model.gps_gallery.to(self.device)
50
-
51
- # Generate location embeddings with memory-efficient tensor operations
52
- location_features = self._model.location_encoder(self._gps_gallery)
53
- location_features = torch.nn.functional.normalize(location_features, dim=1)
 
54
 
55
- # Calculate similarity with vectorized matrix multiplication
56
- similarity = self._model.logit_scale.exp() * (text_features @ location_features.T)
57
- probs = similarity.softmax(dim=-1)
 
58
 
59
- # Extract top predictions
60
- top_values, top_indices = torch.topk(probs[0], min(top_k, len(self._gps_gallery)))
61
 
62
- return [
63
- {"coordinates": tuple(self._gps_gallery[idx].cpu().numpy()),
64
- "confidence": float(conf)}
65
- for idx, conf in zip(top_indices.cpu().numpy(), top_values.cpu().numpy())
66
- ]
67
 
68
- def create_map_visualization(self, predictions: List[Dict[str, Any]], title: str = "") -> folium.Map:
69
- """Generate geospatial visualization."""
70
- # Initialize map centered on highest confidence prediction
71
- center_coords = predictions[0]["coordinates"]
72
- m = folium.Map(location=center_coords, zoom_start=5, tiles="OpenStreetMap")
73
-
74
- # Add markers and heatmap
75
- for i, pred in enumerate(predictions):
76
- color = 'red' if i == 0 else 'blue' if i < 3 else 'green'
77
- folium.Marker(
78
- location=pred["coordinates"],
79
- popup=f"Prediction #{i+1}<br>Confidence: {pred['confidence']:.6f}",
80
- icon=folium.Icon(color=color)
81
- ).add_to(m)
82
-
83
- if len(predictions) >= 3:
84
- heat_data = [[p["coordinates"][0], p["coordinates"][1], p["confidence"]]
85
- for p in predictions]
86
- HeatMap(heat_data, radius=15, blur=10).add_to(m)
87
 
88
- return m
89
 
90
- # Initialize global singleton
91
- engine = GeoCLIPEngine()
92
 
93
- # Fixed chat function with proper output handling
94
- def loc_chat(message, history):
95
- """Chat function that avoids returning Code objects."""
96
- # Process location queries
97
- if any(term in message.lower() for term in ["location", "where", "place", "find"]):
98
- try:
99
- # Execute prediction with tensor acceleration
100
- predictions = engine.predict_location(message, top_k=5)
101
 
102
- # Generate map visualization
103
- m = engine.create_map_visualization(predictions, f"Predictions for: {message}")
104
 
105
- # Format response with structured information
106
- result_text = f"Top predictions for: '{message}'\n\n"
107
- for i, pred in enumerate(predictions, 1):
108
- coords = pred["coordinates"]
109
- conf = pred["confidence"]
110
- result_text += f"{i}. ({coords[0]:.6f}, {coords[1]:.6f}) - confidence: {conf:.6f}\n"
111
 
112
- # Return only string and HTML types to avoid validation errors
113
- return result_text, gr.HTML(value=m._repr_html_())
114
- except Exception as e:
115
- return f"Error: {str(e)}", None
116
 
117
- return "Ask about a location like 'Where is the Eiffel Tower?'", None
118
 
119
- # Interface with minimal dependencies
120
- with gr.Blocks() as demo:
121
- map_output = gr.HTML(render=False)
122
-
123
- with gr.Row():
124
- with gr.Column():
125
- gr.Markdown("<h1>GeoCLIP Location Intelligence</h1>")
126
-
127
- chatbot = gr.ChatInterface(
128
- loc_chat,
129
- examples=["Where is the Eiffel Tower?", "Find ancient pyramids in desert"],
130
- additional_outputs=[map_output],
131
- type="messages" # Critical: use messages type to avoid deprecation
132
- )
133
-
134
- with gr.Column():
135
- gr.Markdown("<h1>Map Visualization</h1>")
136
- map_output.render()
137
-
138
- # Main entrypoint with error mitigation configuration
139
  if __name__ == "__main__":
140
- demo.launch(
141
- share=True,
142
- server_name="0.0.0.0",
143
- cache_examples=False, # Critical: Disable example caching
144
- show_error=True
145
- )
 
1
+ import gradio as gr
2
+ from pathlib import Path
3
+ from PIL import Image, ExifTags
4
+ import json
5
+ import sys
6
  import os
7
+ import shutil
+ import logging
8
+ import traceback
9
  import folium
10
+ from folium.plugins import MarkerCluster
11
+ import pandas as pd
12
+ import io
13
+ import base64
14
+ from typing import Dict, List, Any, Optional, Tuple, Union
15
+ import matplotlib.pyplot as plt
16
+ import numpy as np
17
+ from datasets import Dataset
18
+ from geoclip import LocationEncoder
19
+ import torch
20
+
21
+ # Set up logging to capture all events and errors
22
+ logging.basicConfig(
23
+ level=logging.INFO,
24
+ format="%(asctime)s [%(levelname)s] %(message)s",
25
+ handlers=[logging.StreamHandler(sys.stdout)]
26
+ )
27
+ logger = logging.getLogger(__name__)
28
+
29
+ # Configuration with environment variable fallback
30
+ DEFAULT_IMAGE_DIR = os.environ.get("IMAGE_DIR", "./images")
31
+ OUTPUT_METADATA_FILE = Path(os.environ.get("OUTPUT_METADATA_FILE", "./metadata.jsonl"))
32
+ HF_USERNAME = os.environ.get("HF_USERNAME", "latterworks")
33
+ DATASET_NAME = os.environ.get("DATASET_NAME", "geo-metadata")
34
+
35
+ # Supported image extensions
36
+ SUPPORTED_EXTENSIONS = {'.jpg', '.jpeg', '.png', '.heic', '.tiff', '.bmp', '.webp'}
37
+
38
+ # Convert GPS coordinates to decimal degrees
39
+ def convert_to_degrees(value: tuple) -> Optional[float]:
40
+ try:
41
+ if not isinstance(value, (tuple, list)) or len(value) != 3:
42
+ raise ValueError("GPS value must be a tuple of 3 elements")
43
+ d, m, s = value
44
+ degrees = float(d) + (float(m) / 60.0) + (float(s) / 3600.0)
45
+ if not -180 <= degrees <= 180: # Basic sanity check
46
+ raise ValueError("GPS degrees out of valid range")
47
+ return degrees
48
+ except (TypeError, ValueError) as e:
49
+ logger.error(f"Failed to convert GPS coordinates: {e}")
50
+ return None
51
+
52
+ # Extract and format GPS metadata
53
+ def extract_gps_info(gps_info: Dict[int, Any]) -> Optional[Dict[str, Any]]:
54
+ if not isinstance(gps_info, dict):
55
+ logger.warning("GPSInfo is not a dictionary, skipping")
56
+ return None
57
+
58
+ gps_data = {}
59
+ try:
60
+ for key, val in gps_info.items():
61
+ tag_name = ExifTags.GPSTAGS.get(key, f"unknown_gps_tag_{key}")
62
+ gps_data[tag_name] = val
63
+
64
+ if 'GPSLatitude' in gps_data and 'GPSLongitude' in gps_data:
65
+ lat = convert_to_degrees(gps_data['GPSLatitude'])
66
+ lon = convert_to_degrees(gps_data['GPSLongitude'])
67
+ if lat is None or lon is None:
68
+ logger.error("Failed to convert latitude or longitude, skipping GPS data")
69
+ return None
70
+
71
+ lat_ref = gps_data.get('GPSLatitudeRef', 'N')
72
+ lon_ref = gps_data.get('GPSLongitudeRef', 'E')
73
+ if lat_ref not in {'N', 'S'} or lon_ref not in {'E', 'W'}:
74
+ logger.warning(f"Invalid GPS reference: {lat_ref}, {lon_ref}")
75
+ else:
76
+ if lat_ref == 'S':
77
+ lat = -lat
78
+ if lon_ref == 'W':
79
+ lon = -lon
80
+
81
+ gps_data['Latitude'] = lat
82
+ gps_data['Longitude'] = lon
83
+
84
+ return gps_data
85
+ except Exception as e:
86
+ logger.error(f"Error extracting GPS info: {traceback.format_exc()}")
87
+ return None
88
 
89
+ # Convert non-serializable objects to JSON-serializable types
90
+ def make_serializable(value: Any) -> Any:
91
+ try:
92
+ if hasattr(value, 'numerator') and hasattr(value, 'denominator'): # PIL IFDRational
93
+ return float(value.numerator) / float(value.denominator)
94
+ elif isinstance(value, (tuple, list)):
95
+ return [make_serializable(item) for item in value]
96
+ elif isinstance(value, dict):
97
+ return {str(k): make_serializable(v) for k, v in value.items()}
98
+ elif isinstance(value, bytes):
99
+ return value.decode('utf-8', errors='replace')
100
+ json.dumps(value) # Test serialization
101
+ return value
102
+ except Exception as e:
103
+ logger.warning(f"Converting to string due to serialization failure: {e}")
104
+ return str(value)
105
 
106
+ # Extract metadata from an image
107
+ def get_image_metadata(image_path: Path) -> Dict[str, Any]:
108
+ metadata = {"file_name": str(image_path.absolute())}
109
+ try:
110
+ with Image.open(image_path) as image:
111
+ metadata.update({
112
+ "format": image.format or "unknown",
113
+ "size": list(image.size),
114
+ "mode": image.mode or "unknown"
115
+ })
116
+
117
+ exif_data = None
118
+ try:
119
+ exif_data = image._getexif()
120
+ except AttributeError:
121
+ metadata["exif_error"] = "No EXIF data available"
122
+ except Exception as e:
123
+ metadata["exif_error"] = f"EXIF extraction failed: {str(e)}"
124
+
125
+ if exif_data and isinstance(exif_data, dict):
126
+ for tag_id, value in exif_data.items():
127
+ try:
128
+ tag_name = ExifTags.TAGS.get(tag_id, f"tag_{tag_id}").lower()
129
+ if tag_name == "gpsinfo":
130
+ gps_info = extract_gps_info(value)
131
+ if gps_info:
132
+ metadata["gps_info"] = make_serializable(gps_info)
133
+ else:
134
+ metadata[tag_name] = make_serializable(value)
135
+ except Exception as e:
136
+ metadata[f"error_tag_{tag_id}"] = str(e)
137
+
138
+ metadata["file_size"] = image_path.stat().st_size
139
+ metadata["file_extension"] = image_path.suffix.lower()
140
+
141
+ try:
142
+ json.dumps(metadata)
143
+ except Exception as e:
144
+ logger.error(f"Serialization failed for {image_path}: {e}")
145
+ clean_metadata = {k: v for k, v in metadata.items() if k in {"file_name", "format", "size", "mode", "file_size", "file_extension"}}
146
+ clean_metadata["serialization_error"] = str(e)
147
+ return clean_metadata
148
+
149
+ return metadata
150
+
151
+ except Exception as e:
152
+ logger.error(f"Error processing {image_path}: {traceback.format_exc()}")
153
+ return {"file_name": str(image_path.absolute()), "error": str(e)}
154
+
155
+ # Process all images in the directory
156
+ def process_images(image_dir: Union[str, Path]) -> List[Dict[str, Any]]:
157
+ if isinstance(image_dir, str):
158
+ image_dir = Path(image_dir)
159
 
160
+ if not image_dir.is_dir():
161
+ logger.error(f"Invalid or non-existent directory: {image_dir}")
162
+ return []
163
+
164
+ metadata_list = []
165
+ for image_path in image_dir.rglob("*"): # Recursive search
166
+ if image_path.is_file() and image_path.suffix.lower() in SUPPORTED_EXTENSIONS:
167
+ logger.info(f"Processing: {image_path}")
168
+ try:
169
+ metadata = get_image_metadata(image_path)
170
+ if metadata:
171
+ metadata_list.append(metadata)
172
+ except Exception as e:
173
+ logger.error(f"Unexpected error processing {image_path}: {traceback.format_exc()}")
174
+ metadata_list.append({"file_name": str(image_path.absolute()), "error": str(e)})
175
+
176
+ return metadata_list
177
+
178
+ # Save metadata to JSONL file
179
+ def save_metadata_to_jsonl(metadata_list: List[Dict[str, Any]], output_file: Path) -> bool:
180
+ try:
181
+ with output_file.open('w', encoding='utf-8') as f:
182
+ for entry in metadata_list:
183
+ try:
184
+ f.write(json.dumps(entry, ensure_ascii=False) + '\n')
185
+ except Exception as e:
186
+ logger.error(f"Failed to write entry for {entry.get('file_name', 'unknown')}: {e}")
187
+ f.write(json.dumps({"file_name": entry.get("file_name", "unknown"), "error": str(e)}) + '\n')
188
+ logger.info(f"Metadata saved to {output_file} with {len(metadata_list)} entries")
189
+ return True
190
+ except Exception as e:
191
+ logger.error(f"Failed to save metadata to {output_file}: {traceback.format_exc()}")
192
+ return False
193
+
194
+ # Upload dataset to Hugging Face Hub
195
+ def upload_to_huggingface(metadata_file: Path, username: str, dataset_name: str) -> bool:
196
+ try:
197
+ metadata_list = []
198
+ with metadata_file.open('r', encoding='utf-8') as f:
199
+ for line in f:
200
+ try:
201
+ metadata_list.append(json.loads(line))
202
+ except json.JSONDecodeError as e:
203
+ logger.error(f"Failed to parse line in {metadata_file}: {e}")
204
+
205
+ if not metadata_list:
206
+ logger.error("No valid metadata entries to upload")
207
+ return False
208
+
209
+ image_paths = [entry.get("file_name") for entry in metadata_list if entry.get("file_name")]
210
+ dataset = Dataset.from_dict({
211
+ "images": image_paths,
212
+ "metadata": metadata_list
213
+ })
214
+
215
+ logger.info("Attempting to upload dataset to Hugging Face Hub")
216
+ dataset.push_to_hub(f"{username}/{dataset_name}", private=False)
217
+ logger.info(f"Dataset successfully uploaded to {username}/{dataset_name}")
218
+ return True
219
+
220
+ except Exception as e:
221
+ logger.error(f"Failed to upload to Hugging Face: {traceback.format_exc()}")
222
+ return False
223
+
224
+ # Create a folium map with markers for geotagged images
225
+ def create_geo_map(metadata_list: List[Dict[str, Any]]) -> str:
226
+ try:
227
+ # Filter entries that have GPS coordinates
228
+ geo_entries = []
229
+ for entry in metadata_list:
230
+ gps_info = entry.get("gps_info", {})
231
+ if isinstance(gps_info, dict) and "Latitude" in gps_info and "Longitude" in gps_info:
232
+ geo_entries.append({
233
+ "file_name": entry.get("file_name", "Unknown"),
234
+ "latitude": gps_info["Latitude"],
235
+ "longitude": gps_info["Longitude"],
236
+ "date_time": entry.get("datetime", "Unknown")
237
+ })
238
+
239
+ if not geo_entries:
240
+ return "No geotagged images found"
241
+
242
+ # Create a DataFrame for easier handling
243
+ df = pd.DataFrame(geo_entries)
244
+
245
+ # Calculate map center based on average coordinates
246
+ center_lat = df["latitude"].mean()
247
+ center_lon = df["longitude"].mean()
248
+
249
+ # Create map
250
+ m = folium.Map(location=[center_lat, center_lon], zoom_start=10)
251
+
252
+ # Add marker cluster
253
+ marker_cluster = MarkerCluster().add_to(m)
254
+
255
+ # Add markers for each image
256
+ for _, row in df.iterrows():
257
+ popup_text = f"""
258
+ <strong>File:</strong> {os.path.basename(row['file_name'])}<br>
259
+ <strong>Date:</strong> {row['date_time']}<br>
260
+ <strong>Location:</strong> {row['latitude']:.6f}, {row['longitude']:.6f}
261
+ """
262
+ folium.Marker(
263
+ location=[row['latitude'], row['longitude']],
264
+ popup=folium.Popup(popup_text, max_width=300)
265
+ ).add_to(marker_cluster)
266
+
267
+ # Save map to HTML string
268
+ map_html = m._repr_html_()
269
+ return map_html
270
 
271
+ except Exception as e:
272
+ logger.error(f"Error creating map: {traceback.format_exc()}")
273
+ return f"Error creating map: {str(e)}"
274
+
275
+ # Generate embedding visualization using GeoCLIP's LocationEncoder
276
+ def generate_embedding_visualization(metadata_list: List[Dict[str, Any]]) -> Tuple[str, Optional[str]]:
277
+ try:
278
+ # Filter entries that have GPS coordinates
279
+ geo_entries = []
280
+ for entry in metadata_list:
281
+ gps_info = entry.get("gps_info", {})
282
+ if isinstance(gps_info, dict) and "Latitude" in gps_info and "Longitude" in gps_info:
283
+ geo_entries.append({
284
+ "file_name": os.path.basename(entry.get("file_name", "Unknown")),
285
+ "latitude": gps_info["Latitude"],
286
+ "longitude": gps_info["Longitude"]
287
+ })
288
+
289
+ if len(geo_entries) < 2:
290
+ return "Not enough geotagged images for embedding visualization", None
291
+
292
+ # Create a DataFrame
293
+ df = pd.DataFrame(geo_entries)
294
+
295
+ # Initialize LocationEncoder
296
+ device = "cuda" if torch.cuda.is_available() else "cpu"
297
+ location_encoder = LocationEncoder().to(device)
298
+
299
+ # Generate embeddings
300
+ coords = torch.tensor(df[["latitude", "longitude"]].values, dtype=torch.float32).to(device)
301
+ embeddings = location_encoder(coords).detach().cpu().numpy()
302
+
303
+ # PCA visualization of embeddings
304
+ from sklearn.decomposition import PCA
305
+ pca = PCA(n_components=3)
306
+ pca_result = pca.fit_transform(embeddings)
307
+
308
+ # Create 3D scatter plot
309
+ fig = plt.figure(figsize=(10, 8))
310
+ ax = fig.add_subplot(111, projection='3d')
311
+
312
+ scatter = ax.scatter(
313
+ pca_result[:, 0],
314
+ pca_result[:, 1],
315
+ pca_result[:, 2],
316
+ c=np.arange(len(pca_result)),
317
+ cmap='viridis',
318
+ s=100,
319
+ alpha=0.8
320
+ )
321
+
322
+ # Add labels for each point
323
+ for i, filename in enumerate(df["file_name"]):
324
+ ax.text(pca_result[i, 0], pca_result[i, 1], pca_result[i, 2], filename, size=8)
325
+
326
+ ax.set_title('GeoCLIP Embedding Visualization (PCA)')
327
+ ax.set_xlabel('PCA Component 1')
328
+ ax.set_ylabel('PCA Component 2')
329
+ ax.set_zlabel('PCA Component 3')
330
+
331
+ # Convert plot to image
332
+ buffer = io.BytesIO()
333
+ plt.savefig(buffer, format='png', dpi=100, bbox_inches='tight')
334
+ buffer.seek(0)
335
+
336
+ # Convert to base64 for embedding in HTML
337
+ img_str = base64.b64encode(buffer.read()).decode('utf-8')
338
 
339
+ # Generate code for embedding space exploration
340
+ code_sample = """
341
+ # GeoCLIP Location Encoder Exploration Code
342
+ from geoclip import LocationEncoder
343
+ import torch
344
+ import matplotlib.pyplot as plt
345
+ from sklearn.decomposition import PCA
346
+ import numpy as np
347
+
348
+ # Initialize LocationEncoder
349
+ device = "cuda" if torch.cuda.is_available() else "cpu"
350
+ location_encoder = LocationEncoder().to(device)
351
+
352
+ # Generate embeddings for your coordinates
353
+ coords = torch.tensor([
354
+ [40.7128, -74.0060], # New York
355
+ [34.0522, -118.2437], # Los Angeles
356
+ [51.5074, -0.1278], # London
357
+ [35.6762, 139.6503], # Tokyo
358
+ [28.6139, 77.2090], # Delhi
359
+ ], dtype=torch.float32).to(device)
360
+
361
+ embeddings = location_encoder(coords).detach().cpu().numpy()
362
+
363
+ # Visualize with PCA
364
+ pca = PCA(n_components=2)
365
+ pca_result = pca.fit_transform(embeddings)
366
+
367
+ plt.figure(figsize=(10, 8))
368
+ plt.scatter(pca_result[:, 0], pca_result[:, 1], s=100)
369
+
370
+ locations = ["New York", "Los Angeles", "London", "Tokyo", "Delhi"]
371
+ for i, location in enumerate(locations):
372
+ plt.annotate(location, (pca_result[i, 0], pca_result[i, 1]), fontsize=12)
373
+
374
+ plt.title('GeoCLIP Location Embeddings (PCA)')
375
+ plt.xlabel('PCA Component 1')
376
+ plt.ylabel('PCA Component 2')
377
+ plt.grid(True, alpha=0.3)
378
+ plt.show()
379
+ """
380
 
381
+ return f'<img src="data:image/png;base64,{img_str}" alt="Embedding Visualization">', code_sample
382
 
383
+ except Exception as e:
384
+ logger.error(f"Error generating embedding visualization: {traceback.format_exc()}")
385
+ return f"Error generating embedding visualization: {str(e)}", None
386
+
387
+ # Function to analyze metadata and extract insights
388
+ def analyze_metadata(metadata_list: List[Dict[str, Any]]) -> str:
389
+ try:
390
+ total_images = len(metadata_list)
391
+ if total_images == 0:
392
+ return "No images found in metadata"
393
+
394
+ geotagged_count = sum(1 for entry in metadata_list if "gps_info" in entry and entry["gps_info"].get("Latitude") is not None)
395
+ camera_models = {}
396
+ capture_dates = []
397
+
398
+ for entry in metadata_list:
399
+ # Extract camera model
400
+ model = entry.get("model", "Unknown")
401
+ camera_models[model] = camera_models.get(model, 0) + 1
402
 
403
+ # Extract capture dates
404
+ date_str = entry.get("datetime", "")
405
+ if date_str and isinstance(date_str, str):
406
+ try:
407
+ # Simple extraction of date part (assuming format like "YYYY:MM:DD HH:MM:SS")
408
+ date_part = date_str.split()[0] if " " in date_str else date_str
409
+ capture_dates.append(date_part)
410
+ except:
411
+ pass
412
+
413
+ # Generate HTML report
414
+ html_report = f"""
415
+ <div style="font-family: Arial, sans-serif; padding: 20px; background-color: #f5f5f5; border-radius: 10px;">
416
+ <h2 style="color: #333;">Metadata Analysis Report</h2>
417
 
418
+ <div style="margin: 20px 0; padding: 15px; background-color: #fff; border-radius: 5px; box-shadow: 0 2px 5px rgba(0,0,0,0.1);">
419
+ <h3 style="color: #0066cc;">Summary</h3>
420
+ <ul>
421
+ <li><strong>Total Images:</strong> {total_images}</li>
422
+ <li><strong>Geotagged Images:</strong> {geotagged_count} ({geotagged_count/total_images*100:.1f}%)</li>
423
+ <li><strong>Unique Camera Models:</strong> {len(camera_models)}</li>
424
+ <li><strong>Date Range:</strong> {min(capture_dates) if capture_dates else 'Unknown'} to {max(capture_dates) if capture_dates else 'Unknown'}</li>
425
+ </ul>
426
+ </div>
427
 
428
+ <div style="display: flex; flex-wrap: wrap; gap: 20px;">
429
+ <div style="flex: 1; min-width: 300px; margin: 10px 0; padding: 15px; background-color: #fff; border-radius: 5px; box-shadow: 0 2px 5px rgba(0,0,0,0.1);">
430
+ <h3 style="color: #0066cc;">Camera Models</h3>
431
+ <ul>
432
+ """
433
+
434
+ # Add top 5 camera models
435
+ for model, count in sorted(camera_models.items(), key=lambda x: x[1], reverse=True)[:5]:
436
+ html_report += f'<li><strong>{model}:</strong> {count} images ({count/total_images*100:.1f}%)</li>'
437
+
438
+ html_report += """
439
+ </ul>
440
+ </div>
441
+ </div>
442
+ </div>
443
+ """
444
+
445
+ return html_report
446
 
447
+ except Exception as e:
448
+ logger.error(f"Error analyzing metadata: {traceback.format_exc()}")
449
+ return f"Error analyzing metadata: {str(e)}"
450
+
451
+ # Function to process a batch of uploaded files
452
+ def process_uploaded_files(files) -> Tuple[str, List[Dict[str, Any]], str, str, str, str]:
453
+ try:
454
+ if not files:
455
+ return "No files uploaded", [], "", "", "", ""
456
+
457
+ # Create temporary directory for uploaded files
458
+ temp_dir = Path("./temp_uploads")
459
+ temp_dir.mkdir(exist_ok=True)
460
 
461
+ # Copy uploads into the temp directory; Gradio may supply file paths (str) or tempfile objects, so handle both
+ for file in files:
+ src_path = Path(file if isinstance(file, str) else file.name)
+ shutil.copy(src_path, temp_dir / src_path.name)
466
+
467
+ # Process the images
468
+ metadata_list = process_images(temp_dir)
469
+
470
+ if not metadata_list:
471
+ return "No valid images found in uploads", [], "", "", "", ""
472
+
473
+ # Generate analysis and visualizations
474
+ analysis_html = analyze_metadata(metadata_list)
475
+ map_html = create_geo_map(metadata_list)
476
+ embedding_viz, code_sample = generate_embedding_visualization(metadata_list)
477
+
478
+ # Save metadata to file
479
+ output_file = Path("./uploaded_metadata.jsonl")
480
+ save_metadata_to_jsonl(metadata_list, output_file)
481
+
482
+ return f"Processed {len(metadata_list)} images successfully", metadata_list, analysis_html, map_html, embedding_viz, code_sample
483
+
484
+ except Exception as e:
485
+ logger.error(f"Error processing uploaded files: {traceback.format_exc()}")
486
+ return f"Error: {str(e)}", [], "", "", "", ""
487
 
488
+ # Function to process an existing directory
489
+ def process_directory(directory_path: str) -> Tuple[str, List[Dict[str, Any]], str, str, str, str]:
490
+ try:
491
+ if not directory_path or not os.path.isdir(directory_path):
492
+ return "Invalid directory path", [], "", "", "", ""
493
+
494
+ # Process the images in the directory
495
+ metadata_list = process_images(directory_path)
496
+
497
+ if not metadata_list:
498
+ return "No valid images found in directory", [], "", "", "", ""
499
+
500
+ # Generate analysis and visualizations
501
+ analysis_html = analyze_metadata(metadata_list)
502
+ map_html = create_geo_map(metadata_list)
503
+ embedding_viz, code_sample = generate_embedding_visualization(metadata_list)
504
+
505
+ # Save metadata to file
506
+ output_file = Path("./directory_metadata.jsonl")
507
+ save_metadata_to_jsonl(metadata_list, output_file)
508
+
509
+ return f"Processed {len(metadata_list)} images successfully", metadata_list, analysis_html, map_html, embedding_viz, code_sample
510
+
511
+ except Exception as e:
512
+ logger.error(f"Error processing directory: {traceback.format_exc()}")
513
+ return f"Error: {str(e)}", [], "", "", "", ""
514
 
515
+ # Upload metadata to Hugging Face
516
+ def upload_metadata(metadata_list: List[Dict[str, Any]], username: str, dataset_name: str) -> str:
517
+ try:
518
+ if not metadata_list:
519
+ return "No metadata to upload"
520
+
521
+ # Save metadata to temporary file
522
+ output_file = Path(f"./{dataset_name}_metadata.jsonl")
523
+ save_metadata_to_jsonl(metadata_list, output_file)
524
+
525
+ # Upload to Hugging Face
526
+ success = upload_to_huggingface(output_file, username, dataset_name)
527
+
528
+ if success:
529
+ return f"Successfully uploaded dataset to {username}/{dataset_name}"
530
+ else:
531
+ return "Failed to upload dataset to Hugging Face"
532
+
533
+ except Exception as e:
534
+ logger.error(f"Error uploading metadata: {traceback.format_exc()}")
535
+ return f"Error: {str(e)}"
536
+
537
+ # Create the Gradio interface
538
+ def create_interface():
539
+ with gr.Blocks(title="GeoCLIP Image Metadata Analyzer") as demo:
540
+ gr.Markdown("# 🌍 GeoCLIP Image Metadata Analyzer")
541
+ gr.Markdown("This tool extracts and analyzes EXIF metadata from images, with a focus on geolocation data. It leverages GeoCLIP embeddings to visualize geographic relationships.")
542
+
543
+ with gr.Tabs():
544
+ with gr.TabItem("Upload Files"):
545
+ with gr.Row():
546
+ with gr.Column():
547
+ upload_files = gr.Files(label="Upload Images", file_count="multiple")
548
+ upload_button = gr.Button("Process Uploaded Files")
549
+
550
+ with gr.Column():
551
+ status_output = gr.Textbox(label="Status")
552
+
553
+ with gr.Accordion("Raw Metadata", open=False):
554
+ metadata_json = gr.JSON(label="Extracted Metadata")
555
+
556
+ with gr.Row():
557
+ with gr.Column():
558
+ analysis_html = gr.HTML(label="Analysis Report")
559
+ with gr.Column():
560
+ map_html = gr.HTML(label="Geographic Map")
561
+
562
+ with gr.Row():
563
+ with gr.Column():
564
+ embedding_viz = gr.HTML(label="GeoCLIP Embedding Visualization")
565
+ with gr.Column():
566
+ embedding_code = gr.Code(language="python", label="GeoCLIP Exploration Code", lines=20)
567
 
568
+ with gr.TabItem("Process Directory"):
569
+ with gr.Row():
570
+ with gr.Column():
571
+ dir_path = gr.Textbox(label="Directory Path", placeholder=DEFAULT_IMAGE_DIR)
572
+ dir_button = gr.Button("Process Directory")
573
+
574
+ with gr.Column():
575
+ dir_status = gr.Textbox(label="Status")
576
+
577
+ with gr.Accordion("Raw Metadata", open=False):
578
+ dir_metadata_json = gr.JSON(label="Extracted Metadata")
579
+
580
+ with gr.Row():
581
+ with gr.Column():
582
+ dir_analysis_html = gr.HTML(label="Analysis Report")
583
+ with gr.Column():
584
+ dir_map_html = gr.HTML(label="Geographic Map")
585
+
586
+ with gr.Row():
587
+ with gr.Column():
588
+ dir_embedding_viz = gr.HTML(label="GeoCLIP Embedding Visualization")
589
+ with gr.Column():
590
+ dir_embedding_code = gr.Code(language="python", label="GeoCLIP Exploration Code", lines=20)
591
 
592
+ with gr.TabItem("Upload to HuggingFace"):
593
+ with gr.Row():
594
+ with gr.Column():
595
+ hf_username = gr.Textbox(label="HuggingFace Username", value=HF_USERNAME)
596
+ hf_dataset = gr.Textbox(label="Dataset Name", value=DATASET_NAME)
597
+ hf_source = gr.Radio(["From Uploaded Files", "From Directory"], label="Source", value="From Uploaded Files")
598
+ hf_upload_button = gr.Button("Upload to HuggingFace")
599
+
600
+ with gr.Column():
601
+ hf_status = gr.Textbox(label="Upload Status")
602
+
603
+ # Define event handlers
604
+ upload_button.click(
605
+ fn=process_uploaded_files,
606
+ inputs=[upload_files],
607
+ outputs=[status_output, metadata_json, analysis_html, map_html, embedding_viz, embedding_code]
608
+ )
609
+
610
+ dir_button.click(
611
+ fn=process_directory,
612
+ inputs=[dir_path],
613
+ outputs=[dir_status, dir_metadata_json, dir_analysis_html, dir_map_html, dir_embedding_viz, dir_embedding_code]
614
+ )
615
+
616
+ def handle_hf_upload(username, dataset_name, source):
617
+ if source == "From Uploaded Files":
618
+ metadata_file = Path("./uploaded_metadata.jsonl")
619
+ else:
620
+ metadata_file = Path("./directory_metadata.jsonl")
621
 
622
+ if not metadata_file.exists():
623
+ return "No metadata file found. Please process images first."
624
+
625
+ try:
626
+ metadata_list = []
627
+ with metadata_file.open('r', encoding='utf-8') as f:
628
+ for line in f:
629
+ try:
630
+ metadata_list.append(json.loads(line))
631
+ except json.JSONDecodeError:
632
+ pass
633
+
634
+ return upload_metadata(metadata_list, username, dataset_name)
635
+ except Exception as e:
636
+ return f"Error: {str(e)}"
637
+
638
+ hf_upload_button.click(
639
+ fn=handle_hf_upload,
640
+ inputs=[hf_username, hf_dataset, hf_source],
641
+ outputs=[hf_status]
642
+ )
643
+
644
+ gr.Markdown("""
645
+ ## About this Tool
646
+
647
+ This application integrates **GeoCLIP** location embeddings to analyze and visualize geographic relationships between images.
648
+
649
+ GeoCLIP is a CLIP-inspired model that aligns locations with images for effective worldwide geo-localization.
650
+
651
+ **Features:**
652
+ - Extract EXIF metadata from images, including geolocation data
653
+ - Visualize image locations on an interactive map
654
+ - Generate GeoCLIP embeddings for geographic coordinates
655
+ - Upload processed metadata to Hugging Face datasets
656
+
657
+ **Reference:** [GeoCLIP: Clip-Inspired Alignment between Locations and Images for Effective Worldwide Geo-localization](https://arxiv.org/abs/2309.16020)
658
+ """)
659
 
660
+ return demo
661
 
662
+ # Main entry point
663
  if __name__ == "__main__":
664
+ demo = create_interface()
665
+ demo.launch()
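
For reference, a minimal sketch of consuming the metadata JSONL this version writes (assuming the ./uploaded_metadata.jsonl produced by the Upload Files tab; the Process Directory tab writes ./directory_metadata.jsonl instead):

import json
from pathlib import Path

# Read the JSONL written by save_metadata_to_jsonl(), skipping malformed lines the same way the app does
metadata_file = Path("./uploaded_metadata.jsonl")
entries = []
with metadata_file.open("r", encoding="utf-8") as f:
    for line in f:
        try:
            entries.append(json.loads(line))
        except json.JSONDecodeError:
            continue

# gps_info carries decimal "Latitude"/"Longitude" keys whenever EXIF GPS data was extracted
geotagged = [
    e for e in entries
    if isinstance(e.get("gps_info"), dict)
    and "Latitude" in e["gps_info"]
    and "Longitude" in e["gps_info"]
]
print(f"{len(geotagged)} of {len(entries)} images carry GPS coordinates")

Once pushed from the HuggingFace tab, the same records should also be loadable with datasets.load_dataset("latterworks/geo-metadata"), or whichever username/dataset name was configured.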