Bastien Dechamps commited on
Commit
ed8157d
·
1 Parent(s): 06ee965

[ADD] Random Embedder

Browse files
app.py CHANGED
@@ -6,8 +6,7 @@ import plotly.graph_objects as go
6
 
7
  from geoguessr_bot.guessr import RandomGuessr, AbstractGuessr, NearestNeighborEmbedderGuessr, \
8
  AverageNeighborsEmbedderGuessr
9
- from geoguessr_bot.retriever import DinoV2Embedder, Retriever
10
-
11
 
12
  ALL_GUESSR_CLASS = {
13
  "random": RandomGuessr,
@@ -21,6 +20,7 @@ ALL_GUESSR_ARGS = {
21
  "embedder": DinoV2Embedder(
22
  device="cpu"
23
  ),
 
24
  "retriever": Retriever(
25
  embeddings_path=os.path.join(os.path.dirname(os.path.abspath(__file__)),
26
  "resources/embeddings.npy"),
@@ -32,13 +32,14 @@ ALL_GUESSR_ARGS = {
32
  "embedder": DinoV2Embedder(
33
  device="cpu"
34
  ),
 
35
  "retriever": Retriever(
36
  embeddings_path=os.path.join(os.path.dirname(os.path.abspath(__file__)),
37
  "resources/embeddings.npy"),
38
  ),
39
  "metadata_path": os.path.join(os.path.dirname(os.path.abspath(__file__)),
40
  "resources/metadatav3.csv"),
41
- "n_neighbors": 2000,
42
  "dbscan_eps": 0.5
43
  }
44
  }
 
6
 
7
  from geoguessr_bot.guessr import RandomGuessr, AbstractGuessr, NearestNeighborEmbedderGuessr, \
8
  AverageNeighborsEmbedderGuessr
9
+ from geoguessr_bot.retriever import DinoV2Embedder, Retriever, RandomEmbedder
 
10
 
11
  ALL_GUESSR_CLASS = {
12
  "random": RandomGuessr,
 
20
  "embedder": DinoV2Embedder(
21
  device="cpu"
22
  ),
23
+ # "embedder": RandomEmbedder(n_dim=384),
24
  "retriever": Retriever(
25
  embeddings_path=os.path.join(os.path.dirname(os.path.abspath(__file__)),
26
  "resources/embeddings.npy"),
 
32
  "embedder": DinoV2Embedder(
33
  device="cpu"
34
  ),
35
+ # "embedder": RandomEmbedder(n_dim=384),
36
  "retriever": Retriever(
37
  embeddings_path=os.path.join(os.path.dirname(os.path.abspath(__file__)),
38
  "resources/embeddings.npy"),
39
  ),
40
  "metadata_path": os.path.join(os.path.dirname(os.path.abspath(__file__)),
41
  "resources/metadatav3.csv"),
42
+ "n_neighbors": 100,
43
  "dbscan_eps": 0.5
44
  }
45
  }
geoguessr_bot/guessr/average_neighbor_embedder_guessr.py CHANGED
@@ -13,6 +13,11 @@ from geoguessr_bot.retriever import AbstractImageEmbedder
13
  from geoguessr_bot.retriever import Retriever
14
 
15
 
 
 
 
 
 
16
  @dataclass
17
  class AverageNeighborsEmbedderGuessr(AbstractGuessr):
18
  """Guesses a coordinate using an Embedder and a retriever followed by NN.
@@ -20,7 +25,7 @@ class AverageNeighborsEmbedderGuessr(AbstractGuessr):
20
  embedder: AbstractImageEmbedder
21
  retriever: Retriever
22
  metadata_path: str
23
- n_neighbors: int = 1000
24
  dbscan_eps: float = 0.05
25
 
26
  def __post_init__(self):
@@ -32,7 +37,7 @@ class AverageNeighborsEmbedderGuessr(AbstractGuessr):
32
  for image, latitude, longitude in zip(metadata["path"], metadata["latitude"], metadata["longitude"])
33
  }
34
  # DBSCAN will be used to take the centroid of the biggest cluster among the N neighbors, using Haversine
35
- self.dbscan = DBSCAN(eps=self.dbscan_eps, metric=haversine_distances)
36
 
37
  def guess(self, image: Image) -> Coordinate:
38
  """Guess a coordinate from an image
 
13
  from geoguessr_bot.retriever import Retriever
14
 
15
 
16
+ def haversine_distance(x, y) -> float:
17
+ """Compute the haversine distance between two coordinates
18
+ """
19
+ return haversine_distances(np.array(x).reshape(1, -1), np.array(y).reshape(1, -1))[0][0]
20
+
21
  @dataclass
22
  class AverageNeighborsEmbedderGuessr(AbstractGuessr):
23
  """Guesses a coordinate using an Embedder and a retriever followed by NN.
 
25
  embedder: AbstractImageEmbedder
26
  retriever: Retriever
27
  metadata_path: str
28
+ n_neighbors: int = 50
29
  dbscan_eps: float = 0.05
30
 
31
  def __post_init__(self):
 
37
  for image, latitude, longitude in zip(metadata["path"], metadata["latitude"], metadata["longitude"])
38
  }
39
  # DBSCAN will be used to take the centroid of the biggest cluster among the N neighbors, using Haversine
40
+ self.dbscan = DBSCAN(eps=self.dbscan_eps, metric=haversine_distance)
41
 
42
  def guess(self, image: Image) -> Coordinate:
43
  """Guess a coordinate from an image
geoguessr_bot/retriever/__init__.py CHANGED
@@ -2,3 +2,4 @@ from .retriever import Retriever
2
  from .abstract_embedder import AbstractImageEmbedder
3
  from .dino_embedder import DinoEmbedder
4
  from .dino_v2_embedder import DinoV2Embedder
 
 
2
  from .abstract_embedder import AbstractImageEmbedder
3
  from .dino_embedder import DinoEmbedder
4
  from .dino_v2_embedder import DinoV2Embedder
5
+ from .random_embedder import RandomEmbedder
geoguessr_bot/retriever/abstract_embedder.py CHANGED
@@ -1,4 +1,6 @@
1
  import os
 
 
2
  from PIL import Image
3
  import numpy as np
4
  from tqdm import tqdm
@@ -8,6 +10,7 @@ class AbstractImageEmbedder:
8
  def __init__(self, device: str = "cpu"):
9
  self.device = device
10
 
 
11
  def embed(self, image: Image) -> np.ndarray:
12
  """Embed an image
13
  """
 
1
  import os
2
+ from abc import abstractmethod
3
+
4
  from PIL import Image
5
  import numpy as np
6
  from tqdm import tqdm
 
10
  def __init__(self, device: str = "cpu"):
11
  self.device = device
12
 
13
+ @abstractmethod
14
  def embed(self, image: Image) -> np.ndarray:
15
  """Embed an image
16
  """
geoguessr_bot/retriever/random_embedder.py ADDED
@@ -0,0 +1,16 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from dataclasses import dataclass
2
+
3
+ import numpy as np
4
+ from PIL import Image
5
+
6
+ from geoguessr_bot.retriever import AbstractImageEmbedder
7
+
8
+
9
+ @dataclass
10
+ class RandomEmbedder(AbstractImageEmbedder):
11
+ n_dim: int = 8
12
+
13
+ def embed(self, image: Image) -> np.ndarray:
14
+ """Embed an image
15
+ """
16
+ return np.random.rand(self.n_dim)