dropbop commited on
Commit
7d5835f
·
verified ·
1 Parent(s): bdd3e91

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +87 -116
app.py CHANGED
@@ -1,123 +1,94 @@
1
  import gradio as gr
2
  import earthview as ev
3
- from PIL import Image
4
- import numpy as np
5
  import random
 
6
  import os
7
- import json
8
- import utils
9
- from pandas import DataFrame
10
-
11
- # --- Configuration ---
12
- DATASET_SUBSET = "satellogic"
13
- LABELED_DATA_FILE = "labeled_data.json"
14
- SAMPLE_SEED = 10 # The seed to use when sampling the dataset for the demo.
15
 
16
- # --- Load Dataset ---
17
- dataset = ev.load_dataset(DATASET_SUBSET, shards=[SAMPLE_SEED])
18
  data_iter = iter(dataset)
19
 
20
- # --- Load Labeled Data (if it exists) ---
21
- def load_labeled_data():
22
- if os.path.exists(LABELED_DATA_FILE):
23
- with open(LABELED_DATA_FILE, "r") as f:
24
- return json.load(f)
25
- else:
26
- return []
27
-
28
- labeled_data = load_labeled_data()
29
-
30
- # --- Get Next Sample for Labeling ---
31
- def get_next_sample():
32
- global data_iter
33
- try:
34
- sample = next(data_iter)
35
- sample = ev.item_to_images(DATASET_SUBSET, sample)
36
- return sample
37
- except StopIteration:
38
- print("No more samples in the dataset.")
39
- return None
40
-
41
- # --- Save Labeled Data ---
42
- def save_labeled_data(image, label, state):
43
- global labeled_data
44
-
45
- # Convert PIL Image to bytes before saving
46
- if image is not None:
47
- image_bytes = image.convert("RGB").tobytes()
48
- else:
49
- image_bytes = None
50
-
51
- labeled_data.append({
52
- "image": image_bytes,
53
- "metadata": state["metadata"],
54
- "label": label
55
- })
56
-
57
- with open(LABELED_DATA_FILE, "w") as f:
58
- json.dump(labeled_data, f)
59
-
60
- new_sample = get_next_sample()
61
- if new_sample is None:
62
- print("Dataset exhausted.")
63
- return "Dataset exhausted.", None, DataFrame()
64
-
65
- new_image = new_sample["rgb"][0]
66
- new_metadata = new_sample["metadata"]
67
- new_metadata["map"] = f'<a href="{utils.get_google_map_link(new_sample, DATASET_SUBSET)}" target="_blank">🧭</a>'
68
- state["metadata"] = new_metadata # Update metadata in state
69
-
70
- # Convert new PIL Image to bytes for sending to client-side
71
- if new_image is not None:
72
- new_image_bytes = new_image.convert("RGB").tobytes()
73
- else:
74
- new_image_bytes = None
75
-
76
- return "", new_image_bytes, DataFrame([new_metadata])
77
-
78
- # --- Gradio Interface ---
79
- # --- Labeling UI ---
80
- def labeling_ui():
81
- state = gr.State({"sample": None, "metadata": None, "subset": DATASET_SUBSET})
82
-
83
- with gr.Row():
84
- with gr.Column():
85
- image_component = gr.Image(label="Satellite Image", type="pil", interactive=False)
86
- with gr.Row():
87
- cool_button = gr.Button("Cool")
88
- not_cool_button = gr.Button("Not Cool")
89
- table = gr.DataFrame(datatype="html")
90
-
91
- def initialize_labeling_ui():
92
- sample = get_next_sample()
93
- if sample is None:
94
- return {"sample": None, "metadata": None, "subset": DATASET_SUBSET}, None, DataFrame()
95
- image = sample["rgb"][0]
96
- metadata = sample["metadata"]
97
- metadata["map"] = f'<a href="{utils.get_google_map_link(sample, DATASET_SUBSET)}" target="_blank">🧭</a>'
98
- return {"sample": sample, "metadata": metadata, "subset": DATASET_SUBSET}, image, DataFrame([metadata])
99
-
100
- initial_state, initial_image, initial_metadata = initialize_labeling_ui()
101
- image_component.value = initial_image
102
- table.value = initial_metadata
103
- state.value = initial_state
104
-
105
- cool_button.click(
106
- fn=lambda image, label, state: save_labeled_data(image, label, state),
107
- inputs=[image_component, gr.Textbox(visible=False, value="cool"), state],
108
- outputs=[gr.Textbox(label="Debug"), image_component, table]
109
- )
110
- not_cool_button.click(
111
- fn=lambda image, label, state: save_labeled_data(image, label, state),
112
- inputs=[image_component, gr.Textbox(visible=False, value="not cool"), state],
113
- outputs=[gr.Textbox(label="Debug"), image_component, table]
114
- )
115
-
116
- # --- Main Interface ---
117
- with gr.Blocks() as demo:
118
- gr.Markdown("# TerraNomaly")
119
- with gr.Tabs():
120
- with gr.TabItem("Labeling"):
121
- labeling_ui()
122
-
123
- demo.launch(debug=True)
 
1
  import gradio as gr
2
  import earthview as ev
3
+ import utils
 
4
  import random
5
+ import pandas as pd
6
  import os
 
 
 
 
 
 
 
 
7
 
8
+ # Load the Satellogic dataset
9
+ dataset = ev.load_dataset("satellogic", streaming=True).shuffle(seed=42)
10
  data_iter = iter(dataset)
11
 
12
+ # File to store labels (will create if it doesn't exist)
13
+ label_file = "labels.csv"
14
+
15
+ # Initialize a DataFrame to hold labels (or load existing)
16
+ if os.path.exists(label_file):
17
+ labels_df = pd.read_csv(label_file)
18
+ else:
19
+ labels_df = pd.DataFrame(columns=["image_id", "bounds", "rating", "google_maps_link"])
20
+
21
+ def get_next_image():
22
+ global data_iter, labels_df
23
+
24
+ while True: # Keep iterating until we find an unlabeled image
25
+ try:
26
+ sample = next(data_iter)
27
+ except StopIteration:
28
+ #refresh the dataset if we reach the end
29
+ dataset = ev.load_dataset("satellogic", streaming=True).shuffle(seed=random.randint(0, 1000000))
30
+ data_iter = iter(dataset)
31
+ continue
32
+
33
+ sample = ev.item_to_images("satellogic", sample)
34
+ image = sample["rgb"][0] # Get the first RGB image
35
+ metadata = sample["metadata"]
36
+
37
+ bounds = metadata["bounds"]
38
+ google_maps_link = utils.get_google_map_link(sample, "satellogic")
39
+ #generate a unique image ID:
40
+ image_id = (str(bounds))
41
+
42
+ # Check if image is already labeled
43
+ if image_id not in labels_df["image_id"].values:
44
+ return image, image_id, bounds, google_maps_link
45
+
46
+ def rate_image(image_id, bounds, rating, google_maps_link):
47
+ global labels_df
48
+
49
+ # Add the rating to the DataFrame
50
+ new_row = pd.DataFrame({"image_id": [image_id], "bounds": [bounds], "rating": [rating], "google_maps_link": [google_maps_link]})
51
+ labels_df = pd.concat([labels_df, new_row], ignore_index=True)
52
+
53
+ # Save the DataFrame to CSV
54
+ labels_df.to_csv(label_file, index=False)
55
+
56
+ # Get the next image and its details
57
+ next_image, next_image_id, next_bounds, next_google_maps_link = get_next_image()
58
+
59
+ return next_image, next_image_id, next_bounds, next_google_maps_link
60
+
61
+ # Define the Gradio interface
62
+ iface = gr.Interface(
63
+ fn=rate_image,
64
+ inputs=[
65
+ gr.Textbox(label="Image ID", visible=False),
66
+ gr.Textbox(label="Bounds", visible=False),
67
+ gr.Radio(["Cool", "Not Cool"], label="Rating"),
68
+ gr.Textbox(label="Google Maps Link"),
69
+ ],
70
+ outputs=[
71
+ gr.Image(label="Satellite Image"),
72
+ gr.Textbox(label="Image ID", visible=False),
73
+ gr.Textbox(label="Bounds", visible=False),
74
+ gr.Textbox(label="Google Maps Link"),
75
+ ],
76
+ title="TerraNomaly - Satellite Image Labeling",
77
+ description="Rate satellite images as 'Cool' or 'Not Cool'.",
78
+ live=False,
79
+ )
80
+
81
+ # Get the first image and its details
82
+ initial_image, initial_image_id, initial_bounds, initial_google_maps_link = get_next_image()
83
+
84
+ # Set the initial values for the output components
85
+ iface.launch(
86
+ share=True,
87
+ initial_outputs=[
88
+ initial_image,
89
+ initial_image_id,
90
+ initial_bounds,
91
+ initial_google_maps_link,
92
+ ],
93
+
94
+ )