Spaces:
Runtime error
Runtime error
from collections import Counter | |
from dataclasses import dataclass | |
import numpy as np | |
from sklearn.cluster import DBSCAN | |
from sklearn.metrics.pairwise import haversine_distances | |
from PIL import Image | |
import pandas as pd | |
from geoguessr_bot.guessr import AbstractGuessr | |
from geoguessr_bot.interfaces import Coordinate | |
from geoguessr_bot.retriever import AbstractImageEmbedder | |
from geoguessr_bot.retriever import Retriever | |
# NOTE(review): the annotated fields and ``__post_init__`` below only take
# effect on a dataclass; confirm AbstractGuessr is compatible with dataclass
# inheritance (field ordering in particular).
@dataclass
class AverageNeighborsEmbedderGuessr(AbstractGuessr):
    """Guess a coordinate from an image embedding, a retriever and DBSCAN.

    The query image is embedded, its ``n_neighbors`` nearest reference images
    are retrieved, the coordinates of those images are clustered with DBSCAN
    using the haversine metric, and the coordinate of the best-matching image
    (smallest retrieval distance) inside the most populated cluster is
    returned.
    """

    # Embedder used to turn the query image into an embedding.
    embedder: AbstractImageEmbedder
    # Retriever returning the nearest reference images and their distances.
    retriever: Retriever
    # Path of a CSV file with "path", "latitude" and "longitude" columns.
    metadata_path: str
    # Number of nearest neighbors requested from the retriever.
    n_neighbors: int = 1000
    # DBSCAN neighborhood radius; the haversine metric works on radians, so
    # this is an angular distance on the unit sphere.
    dbscan_eps: float = 0.05

    def __post_init__(self):
        """Load the metadata and create the DBSCAN clusterer."""
        metadata = pd.read_csv(self.metadata_path)
        # Index coordinates by file name (last component of the "path" column).
        # Presumably the retriever returns these file names; verify against
        # the Retriever implementation.
        self.image_to_coordinate = {
            image.split("/")[-1]: Coordinate(latitude=latitude, longitude=longitude)
            for image, latitude, longitude in zip(
                metadata["path"], metadata["latitude"], metadata["longitude"]
            )
        }
        # DBSCAN is used to keep only the biggest cluster among the N
        # neighbors. The built-in "haversine" metric is used rather than the
        # `haversine_distances` function: a callable metric is invoked on
        # pairs of 1-D vectors, which `haversine_distances` rejects.
        self.dbscan = DBSCAN(
            eps=self.dbscan_eps,
            metric="haversine",
            algorithm="ball_tree",
        )

    def guess(self, image: np.ndarray) -> Coordinate:
        """Guess a coordinate from an image.

        Args:
            image: Image as an array accepted by ``PIL.Image.fromarray``.

        Returns:
            The coordinate of the retrieved image that has the smallest
            retrieval distance within the most populated DBSCAN cluster. When
            DBSCAN labels every neighbor as noise, all neighbors are
            considered instead.

        Raises:
            KeyError: If a retrieved neighbor is missing from the metadata.
        """
        # Embed the image, adding a leading batch axis for the retriever
        image = Image.fromarray(image)
        image_embedding = self.embedder.embed(image)[None, :]

        # Retrieve nearest neighbors (single query, hence index 0)
        nearest_neighbors, distances = self.retriever.retrieve(image_embedding, self.n_neighbors)
        nearest_neighbors = nearest_neighbors[0]
        # Ensure boolean-mask indexing works whatever sequence type is returned
        distances = np.asarray(distances[0])

        # Get the coordinates of the neighbors, in radians as haversine expects
        radian_coordinates = [self.image_to_coordinate[nn].to_radians() for nn in nearest_neighbors]
        neighbors_coordinates = np.array(
            [[coordinate.latitude, coordinate.longitude] for coordinate in radian_coordinates]
        )

        # Cluster the neighbors; DBSCAN labels outliers (noise) with -1
        labels = self.dbscan.fit(neighbors_coordinates).labels_
        cluster_sizes = Counter(label for label in labels.tolist() if label != -1)
        if cluster_sizes:
            # Most populated cluster, not the numerically largest label
            biggest_cluster = cluster_sizes.most_common(1)[0][0]
            in_cluster = labels == biggest_cluster
            neighbors_coordinates = neighbors_coordinates[in_cluster]
            distances = distances[in_cluster]
        # else: every neighbor is noise, so keep them all rather than fail

        # Guess the coordinate of the closest image (retrieval distance) kept
        best_latitude, best_longitude = neighbors_coordinates[np.argmin(distances)]
        return Coordinate.from_radians(best_latitude, best_longitude)