dropbop commited on
Commit
63711ea
·
verified ·
1 Parent(s): 00944ff

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +133 -62
app.py CHANGED
@@ -1,74 +1,145 @@
1
- import pandas as pd
2
- from PIL import Image
3
- import earthview as ev
4
  import gradio as gr
5
- from datasets import load_dataset
6
-
7
- # Global variables to manage state
8
- current_image = None
9
- current_metadata_index = 0
10
- ratings = []
11
- image_id = 0
12
- bounds = []
13
- timestamps = []
14
-
15
- # Function to transform a metadata sample to bounds and timestamps
16
- def item_to_bounds_timestamps(sample):
17
- bounds_list = sample["metadata"]["bounds"]
18
- timestamp_list = sample["metadata"]["timestamp"]
19
- bounds = []
20
- timestamps = []
21
- for b, t in zip(bounds_list, timestamp_list):
22
- bounds.append(b)
23
- timestamps.append(t)
24
- return bounds, timestamps
25
-
26
- # Load the dataset directly from Hugging Face
27
- data = load_dataset("satellogic/EarthView", "satellogic", split="train", streaming=True)
28
-
29
- # Function to load and display the next image
30
- def load_next_image():
31
- global current_image, image_id, current_metadata_index, bounds, timestamps
32
 
33
- try:
34
- sample = next(iter(data))
35
- bounds_sample, timestamps_sample = item_to_bounds_timestamps(sample)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
36
 
37
- # Use earthview library to convert arrays to PIL images
38
- sample = ev.item_to_images("satellogic", sample)
39
- current_image = sample["rgb"][0] # Get the first image
40
- image_id += 1
41
 
42
- return current_image, f"Image ID: {image_id}", bounds_sample[0]
 
 
 
 
 
 
 
43
 
 
44
  except StopIteration:
45
- return None, "No more images", None
46
 
47
- # Function to handle rating submission
48
- def submit_rating(rating, bounds_str):
49
- global image_id, current_metadata_index, ratings, bounds, timestamps
 
 
 
 
 
50
 
51
- ratings.append(rating)
52
- bounds.append(bounds_str)
53
- timestamps.append("timestamp") # Use a valid timestamp if available
54
 
55
- current_metadata_index += 1
56
 
57
- return load_next_image()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
58
 
59
- # Gradio Interface Layout
60
- with gr.Blocks() as demo:
61
  with gr.Row():
62
- with gr.Column():
63
- image_display = gr.Image(label="Satellite Image", type="pil")
64
- image_id_display = gr.Textbox(label="Image ID")
65
- bounds_display = gr.Textbox(label="Bounds")
66
- load_button = gr.Button("Load Next Image")
67
- load_button.click(fn=load_next_image, outputs=[image_display, image_id_display, bounds_display])
68
- with gr.Column():
69
- rating_radio = gr.Radio(["0", "1"], label="Rating (0 = No, 1 = Yes)")
70
- submit_button = gr.Button("Submit Rating")
71
- submit_button.click(fn=lambda rating, bounds_str: submit_rating(rating, bounds_str), inputs=[rating_radio, bounds_display], outputs=[image_display, image_id_display, bounds_display])
72
-
73
- # Launch the Gradio interface
74
- demo.launch()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  import gradio as gr
2
+ import earthview as ev
3
+ from PIL import Image
4
+ import numpy as np
5
+ import random
6
+ import os
7
+ import json
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
8
 
9
+ # --- Configuration ---
10
+ DATASET_SUBSET = "satellogic"
11
+ NUM_SAMPLES_TO_LABEL = 100 # You can adjust this
12
+ LABELED_DATA_FILE = "labeled_data.json"
13
+ DISPLAY_N_COOL = 5 # How many "cool" examples to display alongside each new image.
14
+ SAMPLE_SEED = 10 # The seed to use when sampling the dataset for the demo.
15
+ # --- Load Dataset ---
16
+ dataset = ev.load_dataset(DATASET_SUBSET, shards=[SAMPLE_SEED])
17
+ data_iter = iter(dataset)
18
+
19
+ # --- Load Labeled Data (if it exists) ---
20
+ def load_labeled_data():
21
+ if os.path.exists(LABELED_DATA_FILE):
22
+ with open(LABELED_DATA_FILE, "r") as f:
23
+ return json.load(f)
24
+ else:
25
+ return []
26
 
27
+ labeled_data = load_labeled_data()
 
 
 
28
 
29
+ # --- Get Next Sample for Labeling ---
30
+ def get_next_sample():
31
+ global data_iter
32
+ try:
33
+ sample = next(data_iter)
34
+ sample = ev.item_to_images(DATASET_SUBSET, sample)
35
+ image = sample["rgb"][0]
36
+ metadata = sample["metadata"]
37
 
38
+ return image, metadata, len(labeled_data)
39
  except StopIteration:
40
+ return None, None, None
41
 
42
+ # --- Save Labeled Data ---
43
+ def save_labeled_data(image, metadata, label):
44
+ global labeled_data
45
+ labeled_data.append({
46
+ "image": image.convert("RGB").tobytes(), # Convert to RGB and then serialize
47
+ "metadata": metadata,
48
+ "label": label
49
+ })
50
 
51
+ with open(LABELED_DATA_FILE, "w") as f:
52
+ json.dump(labeled_data, f)
 
53
 
54
+ image, metadata, count = get_next_sample()
55
 
56
+ if image is None:
57
+ return "No more samples", gr.Image.update(value=None), "", f"Labeled {count} samples."
58
+
59
+ return "", image, str(metadata["bounds"]), f"Labeled {count} samples."
60
+
61
+ # --- Gradio Interface ---
62
+
63
+ # --- Labeling UI ---
64
+ def labeling_ui():
65
+ with gr.Row():
66
+ with gr.Column():
67
+ image_component = gr.Image(label="Satellite Image", type="pil")
68
+ metadata_text = gr.Textbox(label="Metadata (Bounds)")
69
+ label_count_text = gr.Textbox(label="Label Count")
70
+
71
+ with gr.Row():
72
+ cool_button = gr.Button("Cool")
73
+ not_cool_button = gr.Button("Not Cool")
74
+
75
+ # Handle button clicks
76
+ cool_button.click(fn=lambda image, metadata: save_labeled_data(image, metadata, "cool"), inputs=[image_component, metadata_text], outputs=[gr.Textbox(label="Debug"), image_component, metadata_text, label_count_text])
77
+ not_cool_button.click(fn=lambda image, metadata: save_labeled_data(image, metadata, "not cool"), inputs=[image_component, metadata_text], outputs=[gr.Textbox(label="Debug"), image_component, metadata_text, label_count_text])
78
+
79
+ # Initialize with the first sample
80
+ image, metadata, count = get_next_sample()
81
+ if image is not None:
82
+ image_component.value = image
83
+ metadata_text.value = str(metadata["bounds"])
84
+ label_count_text.value = f"Labeled {count} samples."
85
+
86
+ # --- Display UI ---
87
+ def display_ui():
88
+
89
+ def get_random_cool_images(n):
90
+ cool_samples = [d for d in labeled_data if d["label"] == "cool"]
91
+ if len(cool_samples) < n:
92
+ return [Image.frombytes("RGB", (384,384), s["image"]) for s in cool_samples]
93
+
94
+ selected_cool = random.sample(cool_samples, n)
95
+ return [Image.frombytes("RGB", (384,384), s["image"]) for s in selected_cool]
96
+
97
+ def get_new_unlabeled_image():
98
+ global data_iter
99
+ try:
100
+ sample = next(data_iter)
101
+ sample = ev.item_to_images(DATASET_SUBSET, sample)
102
+ image = sample["rgb"][0]
103
+ metadata = sample["metadata"]
104
+ return image, str(metadata["bounds"])
105
+ except StopIteration:
106
+ return None, None
107
+
108
+ def refresh_display():
109
+ new_image, new_metadata = get_new_unlabeled_image()
110
+ if new_image is None:
111
+ return "No more samples", gr.Image.update(value=None), gr.Gallery.update(value=[])
112
+
113
+ cool_images = get_random_cool_images(DISPLAY_N_COOL)
114
+ return "", new_image, cool_images
115
+
116
+ with gr.Row():
117
+ new_image_component = gr.Image(label="New Image", type="pil")
118
+ metadata_display = gr.Textbox(label="Metadata (Bounds)")
119
 
 
 
120
  with gr.Row():
121
+ cool_images_gallery = gr.Gallery(label="Cool Examples", value=[])
122
+ cool_images_gallery.style(grid=DISPLAY_N_COOL)
123
+
124
+ with gr.Row():
125
+ refresh_button = gr.Button("Refresh")
126
+
127
+ refresh_button.click(fn=refresh_display, inputs=[], outputs=[gr.Textbox(label="Debug"), new_image_component, cool_images_gallery])
128
+
129
+ # Initialize
130
+ debug, image, gallery = refresh_display()
131
+ new_image_component.value = image
132
+ cool_images_gallery.value = gallery
133
+
134
+ # --- Main Interface ---
135
+ with gr.Blocks() as demo:
136
+ gr.Markdown("# TerraNomaly")
137
+
138
+ with gr.Tabs():
139
+ with gr.TabItem("Labeling"):
140
+ labeling_ui()
141
+
142
+ with gr.TabItem("Display"):
143
+ display_ui()
144
+
145
+ demo.launch(debug=True)