Bastien Dechamps commited on
Commit
077dc3f
·
1 Parent(s): 944c93a
configs/embed_folder.yml ADDED
@@ -0,0 +1,7 @@
 
 
 
 
 
 
 
 
1
+ command:
2
+ (): geoguessr_bot.commands.EmbedCommand
3
+ embedder:
4
+ (): geoguessr_bot.retriever.DinoV2Embedder
5
+ device: "cpu"
6
+ images_folder: !path "../data/samples"
7
+ output_path: !path "../data/samples_embedded"
configs/train_classifier.yml ADDED
File without changes
geoguessr_bot/__main__.py ADDED
@@ -0,0 +1,17 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import fire
2
+ from configue import load
3
+
4
+
5
+ class CLI:
6
+ """Regroup all the commands of the CLI
7
+ """
8
+
9
+ @staticmethod
10
+ def run(config_path: str) -> None:
11
+ config = load(config_path)
12
+ command = config["command"]
13
+ command.run()
14
+
15
+
16
+ if __name__ == "__main__":
17
+ fire.Fire(CLI)
geoguessr_bot/classifier/__init__.py ADDED
File without changes
geoguessr_bot/commands/__init__.py CHANGED
@@ -0,0 +1,2 @@
 
 
 
1
+ from .abstract_command import AbstractCommand
2
+ from .embed_command import EmbedCommand
geoguessr_bot/commands/abstract_command.py ADDED
@@ -0,0 +1,9 @@
 
 
 
 
 
 
 
 
 
 
1
+ from abc import abstractmethod
2
+
3
+
4
+ class AbstractCommand:
5
+ """Abstract class for all commands.
6
+ """
7
+ @abstractmethod
8
+ def run(self) -> None:
9
+ raise NotImplementedError
geoguessr_bot/commands/embed_command.py ADDED
@@ -0,0 +1,16 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from dataclasses import dataclass
2
+
3
+ from geoguessr_bot.commands import AbstractCommand
4
+ from geoguessr_bot.retriever import AbstractImageEmbedder
5
+
6
+
7
+ @dataclass
8
+ class EmbedCommand(AbstractCommand):
9
+ """Embed all images in a folder and save them in a .npy file
10
+ """
11
+ embedder: AbstractImageEmbedder
12
+ images_folder: str
13
+ output_path: str
14
+
15
+ def run(self) -> None:
16
+ self.embedder.embed_folder(self.images_folder, self.output_path)
geoguessr_bot/retriever/__init__.py CHANGED
@@ -1,3 +1,4 @@
1
- from .retriever import AbstractRetriever
2
  from .abstract_embedder import AbstractImageEmbedder
3
  from .dino_embedder import DinoEmbedder
 
 
1
+ from .retriever import Retriever
2
  from .abstract_embedder import AbstractImageEmbedder
3
  from .dino_embedder import DinoEmbedder
4
+ from .dino_v2_embedder import DinoV2Embedder
geoguessr_bot/retriever/abstract_embedder.py CHANGED
@@ -1,6 +1,7 @@
1
  import os
2
  from PIL import Image
3
  import numpy as np
 
4
 
5
 
6
  class AbstractImageEmbedder:
@@ -12,13 +13,14 @@ class AbstractImageEmbedder:
12
  """
13
  raise NotImplementedError
14
 
15
- def embed_folder(self, folder_path: str):
16
  """Embed all images in a folder and save them in a .npy file
17
  """
 
18
  embeddings = {}
19
- for image in os.listdir(folder_path):
20
  image_path = os.path.join(folder_path, image)
21
  image = Image.open(image_path)
22
  embedding = self.embed(image)
23
  embeddings[image] = embedding
24
- np.save(folder_path + ".npy", embeddings)
 
1
  import os
2
  from PIL import Image
3
  import numpy as np
4
+ from tqdm import tqdm
5
 
6
 
7
  class AbstractImageEmbedder:
 
13
  """
14
  raise NotImplementedError
15
 
16
+ def embed_folder(self, folder_path: str, output_path: str) -> None:
17
  """Embed all images in a folder and save them in a .npy file
18
  """
19
+ assert output_path.endswith(".npy"), "`output_path` must end with .npy"
20
  embeddings = {}
21
+ for image in tqdm(os.listdir(folder_path)):
22
  image_path = os.path.join(folder_path, image)
23
  image = Image.open(image_path)
24
  embedding = self.embed(image)
25
  embeddings[image] = embedding
26
+ np.save(output_path, embeddings)
geoguessr_bot/retriever/dino_v2_embedder.py ADDED
@@ -0,0 +1,31 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import numpy as np
2
+ import torch
3
+ import torchvision.transforms as T
4
+ from PIL import Image
5
+
6
+ from .abstract_embedder import AbstractImageEmbedder
7
+
8
+
9
+ class DinoV2Embedder(AbstractImageEmbedder):
10
+ def __init__(self, device: str = "cpu"):
11
+ """Embedder using DINOv2 embeddings.
12
+ """
13
+ super().__init__(device)
14
+ self.model = torch.hub.load('facebookresearch/dinov2', 'dinov2_vits14').to(self.device)
15
+ self.model.eval()
16
+ self.transforms = T.Compose([
17
+ T.Resize((256, 256), interpolation=T.InterpolationMode.BICUBIC),
18
+ T.CenterCrop(224),
19
+ T.ToTensor(),
20
+ T.Normalize(
21
+ mean=[0.485, 0.456, 0.406],
22
+ std=[0.229, 0.224, 0.225]
23
+ )
24
+ ])
25
+
26
+ def embed(self, image: Image) -> np.ndarray:
27
+ image = image.convert("RGB")
28
+ image = self.transforms(image).unsqueeze(0).to(self.device)
29
+ with torch.no_grad():
30
+ output = self.model(image)[0].cpu().numpy()
31
+ return output
geoguessr_bot/retriever/retriever.py CHANGED
@@ -20,11 +20,11 @@ class Retriever:
20
 
21
  @staticmethod
22
  def load_embeddings(embeddings_path: str) -> Dict[str, np.ndarray]:
23
- """Load embeddings from a file
24
  """
25
- raise NotImplementedError
26
 
27
- def retrieve(self, queries: np.ndarray, n_neighbors: int = 5) -> List[str]:
28
  """Retrieve nearest neighbors indexes from queries
29
  """
30
  dist, indexes = self.index.search(queries, n_neighbors)
 
20
 
21
  @staticmethod
22
  def load_embeddings(embeddings_path: str) -> Dict[str, np.ndarray]:
23
+ """Load embeddings from a .npy file
24
  """
25
+ return np.load(embeddings_path, allow_pickle=True).item()
26
 
27
+ def retrieve(self, queries: np.ndarray, n_neighbors: int = 5) -> List[List[str]]:
28
  """Retrieve nearest neighbors indexes from queries
29
  """
30
  dist, indexes = self.index.search(queries, n_neighbors)
requirements.txt CHANGED
@@ -7,3 +7,8 @@ Pillow
7
  transformers
8
  requests
9
  faiss-cpu
 
 
 
 
 
 
7
  transformers
8
  requests
9
  faiss-cpu
10
+ torch
11
+ torchvision
12
+ tqdm
13
+ configue
14
+ fire