from dataclasses import dataclass
from typing import Union

import numpy as np
import pandas as pd
from PIL import Image

from geoguessr_bot.guessr import AbstractGuessr
from geoguessr_bot.interfaces import Coordinate
from geoguessr_bot.retriever import AbstractImageEmbedder
from geoguessr_bot.retriever import Retriever


@dataclass
class GlobalEmbedderGuessr(AbstractGuessr):
    """Guess a coordinate by nearest-neighbour retrieval in embedding space.

    The query image is embedded with ``embedder``. ``retriever`` then returns
    its nearest reference image, and that image's coordinate (read from the CSV
    at ``metadata_path``) is returned as the guess.
    """

    embedder: AbstractImageEmbedder
    retriever: Retriever
    # Path to a CSV read with pandas; must contain "path", "latitude" and
    # "longitude" columns.
    metadata_path: str

    def __post_init__(self):
        """Load the metadata CSV and index coordinates by image file name."""
        metadata = pd.read_csv(self.metadata_path)
        # Keys are the last "/"-separated component of each "path" entry.
        # NOTE(review): presumably the retriever returns neighbours under the
        # same bare-file-name convention -- verify against Retriever.
        self.image_to_coordinate = {
            path.split("/")[-1]: Coordinate(latitude=latitude, longitude=longitude)
            for path, latitude, longitude in zip(
                metadata["path"], metadata["latitude"], metadata["longitude"]
            )
        }

    def guess(self, image: Union[np.ndarray, Image.Image]) -> Coordinate:
        """Guess a coordinate from an image.

        Args:
            image: Either a PIL image or an array accepted by
                ``PIL.Image.fromarray``.

        Returns:
            The coordinate of the nearest reference image found by the retriever.

        Raises:
            KeyError: If the retrieved neighbour has no entry in the metadata.
        """
        # Accept PIL images directly; convert anything else (e.g. numpy arrays).
        if isinstance(image, Image.Image):
            pil_image = image
        else:
            pil_image = Image.fromarray(image)

        # Add a leading batch axis so the retriever receives a batch of one query.
        image_embedding = self.embedder.embed(pil_image)[None, :]

        # Take the first result of the first (only) query.
        # NOTE(review): assumes retrieve() returns per-query, best-first
        # neighbour identifiers -- confirm against Retriever.
        nearest_neighbors = self.retriever.retrieve(image_embedding)
        nearest_neighbor = nearest_neighbors[0][0]

        try:
            return self.image_to_coordinate[nearest_neighbor]
        except KeyError:
            raise KeyError(
                f"Retrieved neighbour {nearest_neighbor!r} has no entry in "
                f"metadata file {self.metadata_path!r}"
            ) from None