File size: 4,963 Bytes
465c443
63711ea
 
 
 
 
 
465c443
63711ea
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
465c443
63711ea
465c443
63711ea
 
 
 
 
 
 
 
465c443
63711ea
d3015aa
63711ea
465c443
63711ea
 
 
 
 
 
 
 
465c443
63711ea
 
d3015aa
63711ea
465c443
63711ea
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
465c443
4f33c08
63711ea
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
import gradio as gr
import earthview as ev
from PIL import Image
import numpy as np
import random
import os
import json

# --- Configuration ---
# Subset name handed to both ev.load_dataset() and ev.item_to_images().
DATASET_SUBSET = "satellogic"
# NOTE(review): not referenced anywhere else in the visible code.
NUM_SAMPLES_TO_LABEL = 100  # You can adjust this
# JSON file the label records are read from at startup and re-written to on every label.
LABELED_DATA_FILE = "labeled_data.json"
DISPLAY_N_COOL = 5 # How many "cool" examples to display alongside each new image.
# NOTE(review): despite the name, this value is passed as the only entry of
# `shards=` below, i.e. it selects a shard rather than seeding a sampler -- confirm.
SAMPLE_SEED = 10 # The seed to use when sampling the dataset for the demo.
# --- Load Dataset ---
# Runs at import time. `data_iter` is a single module-level iterator that both
# the labeling tab and the display tab advance with next().
dataset = ev.load_dataset(DATASET_SUBSET, shards=[SAMPLE_SEED])
data_iter = iter(dataset)

# --- Load Labeled Data (if it exists) ---
def load_labeled_data():
    """Load previously saved label records from ``LABELED_DATA_FILE``.

    This runs at import time, so an unreadable file must not take the whole
    app down: problems are reported with a warning and treated as "no labels
    yet".

    Returns:
        The list stored in the JSON file, or an empty list when the file does
        not exist, is empty or truncated (for example after an interrupted
        write), is not valid UTF-8/JSON, or does not hold a JSON list.
    """
    import warnings  # local import: only needed on the unreadable-file paths

    try:
        with open(LABELED_DATA_FILE, "r", encoding="utf-8") as f:
            records = json.load(f)
    except FileNotFoundError:
        # Nothing has been labeled yet.
        return []
    except (json.JSONDecodeError, UnicodeDecodeError) as err:
        warnings.warn(f"Ignoring unreadable {LABELED_DATA_FILE}: {err}")
        return []

    # Callers append to the result, so anything but a list is unusable.
    if not isinstance(records, list):
        warnings.warn(
            f"Ignoring {LABELED_DATA_FILE}: expected a JSON list, "
            f"got {type(records).__name__}"
        )
        return []
    return records

# Module-level list of label records: the labeling tab appends to it and writes
# it back to LABELED_DATA_FILE, and the display tab reads the "cool" entries.
labeled_data = load_labeled_data()

# --- Get Next Sample for Labeling ---
def get_next_sample():
    """Advance the shared dataset iterator and return the next sample to label.

    Returns:
        A ``(image, metadata, count)`` tuple. ``image`` is the first entry of
        the converted sample's ``"rgb"`` list, ``metadata`` is its
        ``"metadata"`` entry, and ``count`` is the current number of label
        records. When the iterator is exhausted, ``image`` and ``metadata``
        are ``None`` while ``count`` is still the real record count, so
        callers can keep showing an accurate progress message.
    """
    # Only next() signals exhaustion; keep the try body to that one call.
    try:
        raw_item = next(data_iter)
    except StopIteration:
        return None, None, len(labeled_data)

    sample = ev.item_to_images(DATASET_SUBSET, raw_item)
    return sample["rgb"][0], sample["metadata"], len(labeled_data)

# --- Save Labeled Data ---
def _encode_bytes_for_json(obj):
    """``json`` ``default=`` hook: represent bytes-like values as base64 text."""
    import base64

    if isinstance(obj, (bytes, bytearray)):
        return base64.b64encode(obj).decode("ascii")
    raise TypeError(f"Object of type {type(obj).__name__} is not JSON serializable")


def save_labeled_data(image, metadata, label):
    """Record a label for ``image``, persist all records, and fetch the next sample.

    Args:
        image: Image object exposing ``convert("RGB")`` (the labeling tab passes
            the value of a ``type="pil"`` image component), or ``None`` when
            the component is empty; in that case nothing is recorded.
        metadata: JSON-serializable value stored with the record (the labeling
            tab passes the text of its bounds textbox).
        label: Label string stored with the record.

    Returns:
        A 4-tuple for the outputs ``(debug, image, metadata text, count text)``.
        When no further sample is available the debug message is
        ``"No more samples"`` and the image output is cleared.
    """
    if image is not None:
        rgb_image = image.convert("RGB")
        labeled_data.append({
            # Raw RGB pixel bytes stay in memory as-is ...
            "image": rgb_image.tobytes(),
            # ... and the size is kept so the pixels can be rebuilt later.
            "size": list(rgb_image.size),
            "metadata": metadata,
            "label": label,
        })

        # json cannot serialize bytes, so they are written as base64 text.
        # Serialize fully *before* opening the file: a serialization error must
        # not leave a truncated file behind.
        serialized = json.dumps(labeled_data, default=_encode_bytes_for_json)
        with open(LABELED_DATA_FILE, "w") as f:
            f.write(serialized)

    next_image, next_metadata, _ = get_next_sample()
    # Use the real record count; the count returned alongside an exhausted
    # iterator is not reliable for display.
    status = f"Labeled {len(labeled_data)} samples."

    if next_image is None:
        return "No more samples", gr.update(value=None), "", status

    return "", next_image, str(next_metadata["bounds"]), status

# --- Gradio Interface ---

# --- Labeling UI ---
def labeling_ui():
    """Build the labeling tab: one image, its bounds text, and two label buttons.

    Must be called inside an active ``gr.Blocks`` context. The first sample is
    fetched with ``get_next_sample()`` before any component is created so that
    the initial values can be passed to the component constructors.
    """
    # Gradio documents the constructor's `value=` as the way to give a
    # component its initial content; assigning `.value` afterwards skips
    # whatever the constructor does with it.
    first_image, first_metadata, first_count = get_next_sample()
    if first_image is not None:
        initial_bounds = str(first_metadata["bounds"])
        initial_count = f"Labeled {first_count} samples."
    else:
        initial_bounds = ""
        initial_count = ""

    def make_label_handler(label):
        """Return a click handler that saves the current image with ``label``."""
        def handle_click(image, metadata):
            return save_labeled_data(image, metadata, label)
        return handle_click

    with gr.Row():
        with gr.Column():
            image_component = gr.Image(
                label="Satellite Image", type="pil", value=first_image
            )
            metadata_text = gr.Textbox(
                label="Metadata (Bounds)", value=initial_bounds
            )
            label_count_text = gr.Textbox(label="Label Count", value=initial_count)

            with gr.Row():
                cool_button = gr.Button("Cool")
                not_cool_button = gr.Button("Not Cool")

            # One shared debug box; creating a fresh Textbox inside each
            # click() call would render a duplicate per button.
            debug_text = gr.Textbox(label="Debug")

            handler_inputs = [image_component, metadata_text]
            handler_outputs = [
                debug_text,
                image_component,
                metadata_text,
                label_count_text,
            ]
            cool_button.click(
                fn=make_label_handler("cool"),
                inputs=handler_inputs,
                outputs=handler_outputs,
            )
            not_cool_button.click(
                fn=make_label_handler("not cool"),
                inputs=handler_inputs,
                outputs=handler_outputs,
            )

# --- Display UI ---
def display_ui():
    """Build the display tab: a fresh unlabeled image next to "cool" examples.

    Must be called inside an active ``gr.Blocks`` context. Each refresh pulls
    one item from the module-level ``data_iter`` and shows it, together with
    its bounds text, beside up to ``DISPLAY_N_COOL`` records labeled "cool".

    NOTE(review): ``data_iter`` is the same iterator the labeling tab advances,
    so every image shown here is consumed and will not be offered for labeling.
    """
    import base64

    # Size assumed for records that carry no "size" entry.
    fallback_size = (384, 384)

    def decode_record_image(record):
        """Rebuild an RGB image from a label record's stored pixel data."""
        pixels = record["image"]
        if isinstance(pixels, str):
            # Records reloaded from JSON cannot hold raw bytes; assumes the
            # text form is base64 -- TODO confirm against the writer.
            pixels = base64.b64decode(pixels)
        size = tuple(record.get("size", fallback_size))
        return Image.frombytes("RGB", size, pixels)

    def get_random_cool_images(n):
        """Return images for at most ``n`` records labeled "cool"."""
        cool_samples = [d for d in labeled_data if d["label"] == "cool"]
        if len(cool_samples) < n:
            chosen = cool_samples
        else:
            chosen = random.sample(cool_samples, n)

        images = []
        for record in chosen:
            try:
                images.append(decode_record_image(record))
            except (ValueError, TypeError):
                # A single undecodable record (wrong size, bad encoding)
                # should not break the whole gallery; leave it out.
                continue
        return images

    def get_new_unlabeled_image():
        """Return ``(image, bounds_text)`` for the next item, or ``(None, None)``."""
        try:
            raw_item = next(data_iter)
        except StopIteration:
            return None, None
        sample = ev.item_to_images(DATASET_SUBSET, raw_item)
        return sample["rgb"][0], str(sample["metadata"]["bounds"])

    def refresh_display():
        """Return ``(debug, image, bounds_text, cool_images)`` for the outputs."""
        new_image, new_bounds = get_new_unlabeled_image()
        if new_image is None:
            # Plain empty values clear the image, textbox and gallery.
            return "No more samples", None, "", []
        return "", new_image, new_bounds, get_random_cool_images(DISPLAY_N_COOL)

    # Compute the first view before creating components so the initial values
    # go through the constructors' `value=` instead of a later `.value`
    # assignment.
    first_message, first_image, first_bounds, first_cool = refresh_display()

    with gr.Row():
        new_image_component = gr.Image(
            label="New Image", type="pil", value=first_image
        )
        metadata_display = gr.Textbox(label="Metadata (Bounds)", value=first_bounds)

    with gr.Row():
        cool_images_gallery = gr.Gallery(label="Cool Examples", value=first_cool)
        cool_images_gallery.style(grid=DISPLAY_N_COOL)

    with gr.Row():
        refresh_button = gr.Button("Refresh")

    debug_text = gr.Textbox(label="Debug", value=first_message)

    refresh_button.click(
        fn=refresh_display,
        inputs=[],
        outputs=[
            debug_text,
            new_image_component,
            metadata_display,
            cool_images_gallery,
        ],
    )

# --- Main Interface ---
with gr.Blocks() as demo:
    gr.Markdown("# TerraNomaly")

    # Each tab is a (title, builder) pair; the builder creates the tab's
    # components while that tab's context is active. Order matters: it is the
    # order the tabs appear in and the order the builders consume samples.
    with gr.Tabs():
        for tab_title, build_tab in (
            ("Labeling", labeling_ui),
            ("Display", display_ui),
        ):
            with gr.TabItem(tab_title):
                build_tab()

# Start the app; debug=True is passed through to Gradio's launch().
demo.launch(debug=True)