File size: 6,002 Bytes
465c443
63711ea
 
 
 
 
 
c501c19
 
465c443
63711ea
 
 
 
7a6ee1c
 
c501c19
7c425bb
63711ea
 
 
 
 
 
 
 
 
 
 
465c443
63711ea
465c443
63711ea
 
 
 
 
 
c501c19
d3015aa
7c425bb
c501c19
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
465c443
63711ea
c501c19
63711ea
c501c19
 
 
 
 
 
 
63711ea
7a6ee1c
c501c19
63711ea
 
c501c19
63711ea
 
c501c19
 
 
 
 
 
 
 
 
 
 
 
 
63711ea
 
 
 
c501c19
 
 
7a6ee1c
 
c501c19
 
7a6ee1c
 
 
c501c19
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
7a6ee1c
c501c19
 
 
7a6ee1c
 
c501c19
 
 
7a6ee1c
 
63711ea
 
c501c19
63711ea
 
7c425bb
63711ea
 
 
 
 
 
7c425bb
63711ea
7c425bb
63711ea
 
 
 
 
7c425bb
 
63711ea
 
 
 
 
465c443
4f33c08
7c425bb
63711ea
7c425bb
 
 
 
 
 
63711ea
4561290
7c425bb
 
4561290
 
 
 
63711ea
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
import base64
import json
import os
import random

import gradio as gr
import numpy as np
from pandas import DataFrame
from PIL import Image

import earthview as ev
import utils

# --- Configuration ---
# Module-level settings read by the functions below.

# EarthView subset name passed to ev.load_dataset / ev.item_to_images and used for map links.
DATASET_SUBSET: str = "satellogic"
# You can adjust this. NOTE(review): not referenced anywhere in the code visible here.
NUM_SAMPLES_TO_LABEL: int = 100
# JSON file holding all labels; read once at startup and rewritten whenever a label is saved.
LABELED_DATA_FILE: str = "labeled_data.json"
# How many "cool" examples to display alongside each new image.
DISPLAY_N_COOL: int = 5
# The seed to use when sampling the dataset for the demo.
# NOTE(review): despite the name, it is passed as `shards=[SAMPLE_SEED]` to
# ev.load_dataset below, so it presumably selects a shard rather than seeding
# an RNG — confirm against the earthview API.
SAMPLE_SEED: int = 10
# NOTE(review): not referenced anywhere in the code visible here.
BATCH_SIZE: int = 10

# --- Load Dataset ---
# Runs at import time. `data_iter` is one module-level iterator that both the
# Labeling and Display tabs advance, so samples are consumed globally rather
# than per tab (and presumably across all browser sessions — confirm).
dataset = ev.load_dataset(DATASET_SUBSET, shards=[SAMPLE_SEED])
data_iter = iter(dataset)

# --- Load Labeled Data (if it exists) ---
def load_labeled_data():
    """Return the labels saved by a previous run, or ``[]`` if there are none.

    The list is read as JSON from ``LABELED_DATA_FILE`` when that file exists.
    """
    if not os.path.exists(LABELED_DATA_FILE):
        return []
    with open(LABELED_DATA_FILE, "r") as handle:
        return json.load(handle)


# In-memory list of every label; the labeling tab appends to it and the
# display tab reads the "cool" entries back out.
labeled_data = load_labeled_data()

# --- Get Next Sample for Labeling ---
def get_next_sample():
    """Advance the shared dataset iterator and convert the item it yields.

    Returns:
        The result of ``ev.item_to_images`` for the next dataset item, or
        ``None`` (after printing a notice) once the iterator is exhausted.
    """
    # `data_iter` is only read here (next() advances it in place), so no
    # `global` declaration is required.
    try:
        raw_item = next(data_iter)
        return ev.item_to_images(DATASET_SUBSET, raw_item)
    except StopIteration:
        print("No more samples in the dataset.")
        return None

def get_images(batch_size, state):
    """Pull up to *batch_size* samples and prepare them for display.

    Stops early when the dataset runs out. Each sample's metadata dict is
    given an HTML "map" link in place before it is collected.

    Args:
        batch_size: Maximum number of samples to pull.
        state: Dict providing the dataset subset name under ``"subset"``.

    Returns:
        ``(images, table)``: the first RGB image of each sample, and a
        DataFrame with one metadata row per image.
    """
    subset = state["subset"]
    images = []
    rows = []
    for _ in range(batch_size):
        sample = get_next_sample()
        if sample is None:
            break

        first_rgb = sample["rgb"][0]
        meta = sample["metadata"]
        link = utils.get_google_map_link(sample, subset)
        meta["map"] = f'<a href="{link}" target="_blank">🧭</a>'

        images.append(first_rgb)
        rows.append(meta)

    return images, DataFrame(rows)

# --- Save Labeled Data ---
def _encode_image(image):
    """Serialize a PIL image into a JSON-safe dict with ``image`` and ``size`` keys.

    Raw ``bytes`` cannot be written by ``json``, so the RGB pixel buffer is
    base64-encoded. The pixel size is stored alongside it because a raw buffer
    cannot be turned back into an image without knowing its dimensions.
    """
    rgb = image.convert("RGB")
    return {
        "image": base64.b64encode(rgb.tobytes()).decode("ascii"),
        "size": list(rgb.size),
    }


def _decode_image(entry):
    """Rebuild the PIL image stored in a labeled-data entry.

    Accepts the base64 string produced by ``_encode_image`` as well as raw
    ``bytes``. Entries without a recorded ``size`` fall back to 384x384, the
    size this app previously assumed for every image.
    """
    data = entry["image"]
    if isinstance(data, str):
        data = base64.b64decode(data)
    width, height = entry.get("size", (384, 384))
    return Image.frombytes("RGB", (width, height), data)


def _sample_display(sample, subset):
    """Return ``(gallery_images, metadata_table)`` for showing *sample* in the UI.

    The metadata is copied before the HTML map link is added, so the sample's
    own metadata (which is what gets saved with a label) stays clean.
    """
    metadata = dict(sample["metadata"])
    metadata["map"] = f'<a href="{utils.get_google_map_link(sample, subset)}" target="_blank">🧭</a>'
    # gr.Gallery expects a list of images, even for a single one.
    return [sample["rgb"][0]], DataFrame([metadata])


def save_labeled_data(label, state):
    """Store *label* for the currently displayed sample and advance to the next one.

    Args:
        label: Label text to record (the UI passes "cool" or "not cool").
        state: Per-session dict with ``"sample"`` (the displayed sample, or
            ``None``) and ``"subset"`` (dataset subset name for map links).
            ``state["sample"]`` is replaced with the next sample.

    Returns:
        ``(debug_message, gallery_images, metadata_table)`` for the Gradio
        outputs. ``gallery_images`` is ``None`` when nothing is left to show.
    """
    sample = state["sample"]
    if sample is None:
        return "No image to label", None, DataFrame()

    subset = state.get("subset", DATASET_SUBSET)
    entry = {
        **_encode_image(sample["rgb"][0]),
        "metadata": sample["metadata"],
        "label": label,
    }

    try:
        # Serialize before touching memory or disk: an unserializable metadata
        # value then neither corrupts the file nor poisons `labeled_data`.
        payload = json.dumps(labeled_data + [entry])
    except (TypeError, ValueError) as err:
        images, table = _sample_display(sample, subset)
        return f"Could not save label: {err}", images, table

    # Write to a temporary file and swap it in atomically, so an interrupted
    # write cannot leave a truncated labels file behind.
    tmp_path = LABELED_DATA_FILE + ".tmp"
    with open(tmp_path, "w") as f:
        f.write(payload)
    os.replace(tmp_path, LABELED_DATA_FILE)
    labeled_data.append(entry)

    new_sample = get_next_sample()
    state["sample"] = new_sample
    if new_sample is None:
        return "Dataset exhausted.", None, DataFrame()

    images, table = _sample_display(new_sample, subset)
    return "", images, table


# --- Gradio Interface ---
# --- Labeling UI ---
def labeling_ui():
    """Build the "Labeling" tab: one image at a time with Cool / Not Cool buttons.

    Must be called inside a ``gr.Blocks`` context. A click stores the label for
    the displayed sample via ``save_labeled_data`` and shows the next sample.
    """
    # The sample kept in state must be the very one that is displayed,
    # otherwise labels are attached to a different image than the user sees.
    first_sample = get_next_sample()
    if first_sample is None:
        first_images, first_table = None, DataFrame()
    else:
        first_images, first_table = _sample_display(first_sample, DATASET_SUBSET)

    state = gr.State({"sample": first_sample, "subset": DATASET_SUBSET})

    with gr.Row():
        with gr.Column():
            # Initial values are passed to the constructors so Gradio applies
            # its normal value processing (assigning `.value` afterwards may skip it).
            gallery = gr.Gallery(
                value=first_images,
                label="Satellite Image",
                interactive=False,
                columns=1,
                object_fit="scale-down",
            )

            with gr.Row():
                cool_button = gr.Button("Cool")
                not_cool_button = gr.Button("Not Cool")

            table = gr.DataFrame(value=first_table, datatype="html")
            debug = gr.Textbox(label="Debug")

            def label_cool(session_state):
                """Record the displayed sample as "cool"."""
                return save_labeled_data("cool", session_state)

            def label_not_cool(session_state):
                """Record the displayed sample as "not cool"."""
                return save_labeled_data("not cool", session_state)

            # One shared Debug box for both buttons.
            cool_button.click(fn=label_cool, inputs=[state], outputs=[debug, gallery, table])
            not_cool_button.click(fn=label_not_cool, inputs=[state], outputs=[debug, gallery, table])


# --- Display UI ---
def display_ui():
    """Build the "Display" tab: a fresh unlabeled image next to random "cool" examples.

    Must be called inside a ``gr.Blocks`` context. The Refresh button pulls the
    next sample from the shared dataset iterator and re-draws the examples.
    """

    def get_random_cool_images(n):
        """Return up to *n* randomly chosen images that were labeled "cool"."""
        cool_samples = [d for d in labeled_data if d["label"] == "cool"]
        picked = random.sample(cool_samples, min(n, len(cool_samples)))
        return [_decode_image(s) for s in picked]

    def get_new_unlabeled_image():
        """Return ``(image, bounds_json)`` for the next sample, or ``(None, None)`` when exhausted."""
        sample = get_next_sample()
        if sample is None:
            return None, None
        return sample["rgb"][0], json.dumps(sample["metadata"]["bounds"])

    def refresh_display():
        """Return ``(debug, image, bounds, cool_gallery)`` for the refresh outputs."""
        new_image, new_metadata = get_new_unlabeled_image()
        if new_image is None:
            return "No more samples", None, "", []
        return "", new_image, new_metadata, get_random_cool_images(DISPLAY_N_COOL)

    initial_debug, initial_image, initial_bounds, initial_cool = refresh_display()

    with gr.Row():
        new_image_component = gr.Image(value=initial_image, label="New Image", type="pil")
        metadata_display = gr.Textbox(value=initial_bounds, label="Metadata (Bounds)")

    with gr.Row():
        cool_images_gallery = gr.Gallery(value=initial_cool, label="Cool Examples", columns=DISPLAY_N_COOL)

    refresh_button = gr.Button("Refresh")
    debug = gr.Textbox(value=initial_debug, label="Debug")
    refresh_button.click(
        fn=refresh_display,
        inputs=[],
        outputs=[debug, new_image_component, metadata_display, cool_images_gallery],
    )

# --- Main Interface ---
# Build the two-tab app. Building runs labeling_ui() and display_ui(), and each
# of them pulls its initial sample from the shared dataset iterator.
with gr.Blocks() as demo:
    gr.Markdown("# TerraNomaly")
    with gr.Tabs():
        with gr.TabItem("Labeling"):
            labeling_ui()
        with gr.TabItem("Display"):
            display_ui()

# Start the (blocking) server only when executed as a script. Importing the
# module (e.g. from tests or tooling that looks up `demo`) then just builds the UI.
if __name__ == "__main__":
    demo.launch(debug=True)