File size: 3,173 Bytes
d9b35be
dfbe385
 
 
 
1791df2
 
ed8157d
dfbe385
 
4388025
1791df2
 
dfbe385
 
 
bd2d69d
1791df2
bd2d69d
 
 
 
d9b35be
4cf404a
 
 
 
1791df2
 
 
 
 
 
 
 
 
 
2492245
ed8157d
1791df2
bd2d69d
dfbe385
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
06ee965
dfbe385
 
 
2492245
9fcc076
dfbe385
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
import os.path
import numpy as np
import gradio as gr
import plotly.graph_objects as go

from geoguessr_bot.guessr import RandomGuessr, AbstractGuessr, NearestNeighborEmbedderGuessr, \
    AverageNeighborsEmbedderGuessr
from geoguessr_bot.retriever import DinoV2Embedder, Retriever, RandomEmbedder

# Registry of the available guessr implementations. Its keys are the choices
# offered in the "Guessr type" dropdown, and every key must also have an entry
# in ALL_GUESSR_ARGS, which supplies the constructor keyword arguments.
ALL_GUESSR_CLASS = dict(
    random=RandomGuessr,
    nearestNeighborEmbedder=NearestNeighborEmbedderGuessr,
    averageNeighborsEmbedder=AverageNeighborsEmbedderGuessr,
)

# Bundled resources live next to this file. The paths are resolved from
# __file__ so the app works regardless of the current working directory.
# Each path component is a separate join segment so the OS-native separator is
# used everywhere (a literal "/" inside one segment gives mixed separators on Windows).
_RESOURCES_DIR = os.path.join(os.path.dirname(os.path.abspath(__file__)), "resources")
_EMBEDDINGS_PATH = os.path.join(_RESOURCES_DIR, "embeddings.npy")
_METADATA_PATH = os.path.join(_RESOURCES_DIR, "metadatav3.csv")

# Constructor keyword arguments for each guessr, keyed exactly like
# ALL_GUESSR_CLASS (create_map()/guess() do `CLASS[key](**ARGS[key])`).
# NOTE(review): the DinoV2Embedder / Retriever objects below are constructed when
# this module is imported (they are calls inside a module-level literal), so both
# copies are built even if only one guessr is ever selected. That presumably loads
# model weights and the embeddings file twice. Sharing one instance or deferring
# construction could save memory and startup time; confirm these objects hold no
# per-guessr state before changing that.
ALL_GUESSR_ARGS = {
    "random": {},
    "nearestNeighborEmbedder": {
        "embedder": DinoV2Embedder(
            device="cpu"
        ),
        "retriever": Retriever(
            embeddings_path=_EMBEDDINGS_PATH,
        ),
        "metadata_path": _METADATA_PATH,
    },
    "averageNeighborsEmbedder": {
        "embedder": DinoV2Embedder(
            device="cpu"
        ),
        "retriever": Retriever(
            embeddings_path=_EMBEDDINGS_PATH,
        ),
        "metadata_path": _METADATA_PATH,
        "n_neighbors": 100,
        "dbscan_eps": 0.5
    }
}

# Cache of guessr instances keyed by their ALL_GUESSR_CLASS name. It starts empty
# and is filled by create_map() / guess() the first time each guessr type is
# requested, so a guessr object is only built when it is actually needed.
ALL_GUESSR = dict()


def create_map(guessr: str) -> go.Figure:
    """Return the initial map and make sure the selected guessr is built.

    This is wired to ``demo.load``, so it runs when the page opens. The figure
    comes from ``AbstractGuessr.create_map()`` called on the class itself, so it
    does not depend on *guessr*. As a side effect, the guessr named *guessr* is
    constructed (once) and stored in ``ALL_GUESSR`` for later ``guess`` calls.

    :param guessr: a key of ``ALL_GUESSR_CLASS`` / ``ALL_GUESSR_ARGS``.
    :return: the Plotly figure returned by ``AbstractGuessr.create_map()``.
    """
    cache = ALL_GUESSR
    if guessr not in cache:
        guessr_cls = ALL_GUESSR_CLASS[guessr]
        guessr_kwargs = ALL_GUESSR_ARGS[guessr]
        cache[guessr] = guessr_cls(**guessr_kwargs)
    empty_map = AbstractGuessr.create_map()
    return empty_map


def guess(guessr: str, uploaded_image) -> go.Figure:
    """Guess where an uploaded image was taken and plot the guess on a map.

    Callback for the "Guess" button. The selected guessr is built on first use
    and cached in ``ALL_GUESSR``, so later calls reuse the same instance.

    :param guessr: the dropdown value; a key of ``ALL_GUESSR_CLASS`` / ``ALL_GUESSR_ARGS``.
    :param uploaded_image: the value of the ``gr.Image`` input; ``None`` when
        the user has not uploaded anything.
    :return: the figure produced by the guessr's ``create_map`` for the
        guessed coordinate.
    :raises gr.Error: if no valid guessr is selected or no image was uploaded.
        Gradio shows this message to the user instead of a raw traceback.
    """
    # Validate the inputs up front. Without these checks, a cleared dropdown
    # surfaces as a bare KeyError, and a missing upload becomes
    # np.array(None), a 0-d object array that fails deep inside the guessr.
    if guessr not in ALL_GUESSR_CLASS:
        raise gr.Error(f"Unknown guessr type: {guessr!r}. Please select one from the list.")
    if uploaded_image is None:
        raise gr.Error("Please upload an image before clicking 'Guess'.")
    # Instantiate guessr if not already done
    if guessr not in ALL_GUESSR:
        ALL_GUESSR[guessr] = ALL_GUESSR_CLASS[guessr](**ALL_GUESSR_ARGS[guessr])
    model = ALL_GUESSR[guessr]
    # The guessr expects a numpy array
    image = np.array(uploaded_image)
    guess_coordinate = model.guess(image)
    return model.create_map(guess_coordinate)


if __name__ == "__main__":
    # Lay out the UI: the guessr selector, image input and button in a left
    # column, with the map plot beside them in the same row.
    with gr.Blocks() as demo:
        with gr.Row():
            with gr.Column():
                guessr_dropdown = gr.Dropdown(
                    choices=list(ALL_GUESSR_CLASS),
                    value="nearestNeighborEmbedder",
                    label="Guessr type",
                    info="More Guessr types will be added soon!",
                )
                image = gr.Image()
                button = gr.Button(value="Guess")
            interactive_map = gr.Plot()
        # Event wiring: draw the initial map when the page loads, and replace
        # it with the guessed location each time the button is pressed.
        demo.load(create_map, inputs=[guessr_dropdown], outputs=interactive_map)
        button.click(guess, inputs=[guessr_dropdown, image], outputs=interactive_map)
    # Start the web server 🚀
    demo.launch()