# geoguessr-bot/geoguessr_bot/guessr/average_neighbor_embedder_guessr.py
# Bastien Dechamps
# [ADD] Average embedder
# 1791df2
# raw
# history blame
# 2.63 kB
from collections import Counter
from dataclasses import dataclass
import numpy as np
from sklearn.cluster import DBSCAN
from sklearn.metrics.pairwise import haversine_distances
from PIL import Image
import pandas as pd
from geoguessr_bot.guessr import AbstractGuessr
from geoguessr_bot.interfaces import Coordinate
from geoguessr_bot.retriever import AbstractImageEmbedder
from geoguessr_bot.retriever import Retriever
@dataclass
class AverageNeighborsEmbedderGuessr(AbstractGuessr):
    """Guess a coordinate by image retrieval followed by spatial clustering.

    The query image is embedded, its ``n_neighbors`` nearest reference images
    are retrieved, and the coordinates of those images are clustered with
    DBSCAN under the haversine metric. The guess is the coordinate of the
    best-matching neighbor (smallest retrieval distance) that belongs to the
    most populated cluster.
    """

    embedder: AbstractImageEmbedder
    retriever: Retriever
    # CSV file read in __post_init__; must provide "path", "latitude" and
    # "longitude" columns.
    metadata_path: str
    # Number of neighbors requested from the retriever for each guess.
    n_neighbors: int = 1000
    # DBSCAN neighborhood radius. With the haversine metric this is a
    # great-circle angle in radians on the unit sphere (multiply by the Earth
    # radius to get a ground distance).
    dbscan_eps: float = 0.05

    def __post_init__(self):
        """Load the image metadata and build the clustering estimator."""
        metadata = pd.read_csv(self.metadata_path)
        # Index by bare file name (directories stripped); presumably this is
        # the identifier the retriever returns -- verify against Retriever.
        self.image_to_coordinate = {
            image.split("/")[-1]: Coordinate(latitude=latitude, longitude=longitude)
            for image, latitude, longitude in zip(
                metadata["path"], metadata["latitude"], metadata["longitude"]
            )
        }
        # Use the built-in "haversine" metric name: a callable metric is
        # invoked on pairs of 1-D vectors and must return a scalar, whereas
        # ``haversine_distances`` only accepts 2-D arrays and returns a matrix.
        self.dbscan = DBSCAN(eps=self.dbscan_eps, metric="haversine")

    @staticmethod
    def _biggest_cluster_mask(labels: np.ndarray) -> np.ndarray:
        """Return a boolean mask selecting the most populated DBSCAN cluster.

        Noise points (label ``-1``) are ignored when ranking clusters. If every
        point is noise, all points are selected so a guess is still possible.
        """
        labels = np.asarray(labels)
        cluster_sizes = Counter(label for label in labels.tolist() if label != -1)
        if not cluster_sizes:
            return np.ones(len(labels), dtype=bool)
        biggest_cluster, _ = cluster_sizes.most_common(1)[0]
        return labels == biggest_cluster

    def guess(self, image: np.ndarray) -> Coordinate:
        """Guess a coordinate from an image.

        Args:
            image: Image as an array accepted by ``PIL.Image.fromarray``.

        Returns:
            The coordinate of the retrieved neighbor with the smallest
            retrieval distance among the most populated spatial cluster.

        Raises:
            ValueError: If the retriever returns no neighbors.
            KeyError: If a retrieved neighbor is missing from the metadata.
        """
        # Embed the image; the leading axis turns the embedding into a batch
        # of one query (assumes ``embed`` returns a 1-D array -- TODO confirm).
        pil_image = Image.fromarray(image)
        image_embedding = self.embedder.embed(pil_image)[None, :]

        # Retrieve nearest neighbors; keep the results of the single query.
        nearest_neighbors, distances = self.retriever.retrieve(image_embedding, self.n_neighbors)
        nearest_neighbors = nearest_neighbors[0]
        distances = np.asarray(distances[0])
        if len(nearest_neighbors) == 0:
            raise ValueError("Retriever returned no neighbors for the query image")

        # Neighbor coordinates as [latitude, longitude] in radians, the layout
        # expected by the haversine metric.
        neighbors_in_radians = [self.image_to_coordinate[nn].to_radians() for nn in nearest_neighbors]
        neighbors_coordinates = np.array(
            [[coordinate.latitude, coordinate.longitude] for coordinate in neighbors_in_radians]
        )

        # Keep only the most populated cluster to discard geographic outliers.
        labels = self.dbscan.fit(neighbors_coordinates).labels_
        in_cluster = self._biggest_cluster_mask(labels)
        cluster_coordinates = neighbors_coordinates[in_cluster]
        cluster_distances = distances[in_cluster]

        # Guess the coordinate of the closest image (in retrieval distance)
        # within that cluster.
        latitude, longitude = cluster_coordinates[np.argmin(cluster_distances)]
        return Coordinate.from_radians(latitude, longitude)