Tomatillo committed on
Commit
b4f72c6
·
verified ·
1 Parent(s): b87411c

Update src/streamlit_app.py

Browse files
Files changed (1) hide show
  1. src/streamlit_app.py +435 -30
src/streamlit_app.py CHANGED
@@ -2,39 +2,444 @@ import altair as alt
2
  import numpy as np
3
  import pandas as pd
4
  import streamlit as st
 
 
 
 
 
 
 
5
 
6
  """
7
- # Welcome to Streamlit!
 
 
 
 
 
 
 
 
 
 
 
 
 
 
8
 
9
- Edit `/streamlit_app.py` to customize this app to your heart's desire :heart:.
10
- If you have any questions, checkout our [documentation](https://docs.streamlit.io) and [community
11
- forums](https://discuss.streamlit.io).
 
 
 
 
12
 
13
- In the meantime, below is an example of what you can do with just a few lines of code:
14
  """
15
 
16
- num_points = st.slider("Number of points in spiral", 1, 10000, 1100)
17
- num_turns = st.slider("Number of turns in spiral", 1, 300, 31)
18
-
19
- indices = np.linspace(0, 1, num_points)
20
- theta = 2 * np.pi * num_turns * indices
21
- radius = indices
22
-
23
- x = radius * np.cos(theta)
24
- y = radius * np.sin(theta)
25
-
26
- df = pd.DataFrame({
27
- "x": x,
28
- "y": y,
29
- "idx": indices,
30
- "rand": np.random.randn(num_points),
31
- })
32
-
33
- st.altair_chart(alt.Chart(df, height=700, width=700)
34
- .mark_point(filled=True)
35
- .encode(
36
- x=alt.X("x", axis=None),
37
- y=alt.Y("y", axis=None),
38
- color=alt.Color("idx", legend=None, scale=alt.Scale()),
39
- size=alt.Size("rand", legend=None, scale=alt.Scale(range=[1, 150])),
40
- ))
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2
  import numpy as np
3
  import pandas as pd
4
  import streamlit as st
5
+ import os
6
+ import io
7
+ import json
8
+ import numpy as np
9
+ from PIL import Image
10
+ import requests
11
+ from segments import SegmentsClient
12
 
13
  """
14
+ # Copy annotations from one frame to other frames in the same sample
15
+
16
+ This HF-application first updates selected frames for a given sample UUID by
17
+ replacing each annotation’s id with its track_id and updating the segmentation bitmap
18
+ accordingly (to avoid potential conflicts). Only the frames specified by the source or target
19
+ frame numbers are processed.
20
+
21
+ Then it copies annotations from one source frame to one or more target frames. When copying:
22
+ - The annotations from the source frame are merged with the target's annotations,
23
+ adding only those from the source that are not already present (based on id).
24
+ - The segmentation bitmap is merged: for each pixel, if the target's r-value is 0, the corresponding
25
+ r-value from the source is used.
26
+ - After merging, the bitmap is scanned for unique r-values and any annotation in the target
27
+ frame whose id is not present in these unique values is deleted.
28
+ Finally, the updated datalabel is uploaded.
29
 
30
+ **Important**
31
+ If a loop is detected in any frame’s id mappings, or if any track_id in the
32
+ selected frames is greater than 255, no changes are made and nothing is uploaded to Segments.ai.
33
+ For each track_id higher than 255, the warning will include:
34
+ - Which track_id is too high and the first frame (1-indexed) where it appears.
35
+ - The lowest available id values (searched across all frames) – one for each offending track_id.
36
+ The user must resolve these issues before updating.
37
 
 
38
  """
39
 
40
+ # ---------------- Utility Functions ----------------
41
+
42
def download_image(url: str, timeout: float = 30.0) -> Image.Image:
    """Download an image from the given URL and return a PIL Image in RGB mode.

    Args:
        url: Location of the image to fetch.
        timeout: Seconds to wait for the server before giving up, so that a
            stalled connection cannot hang the app indefinitely.

    Raises:
        requests.HTTPError: If the server answers with an error status.
        requests.Timeout: If no response arrives within ``timeout`` seconds.
    """
    resp = requests.get(url, timeout=timeout)
    resp.raise_for_status()
    return Image.open(io.BytesIO(resp.content)).convert("RGB")
47
+
48
def topological_sort(mapping: dict) -> list:
    """
    Compute a safe processing order for a set of id remappings.

    ``mapping`` maps an original id to its new id. Whenever a new id is itself
    one of the original ids, that id has to be remapped first; otherwise the
    later rewrite would also catch the values that were just moved onto it.

    The order is built iteratively, so long chains of remappings cannot
    exhaust Python's recursion limit, and ids are visited in the mapping's
    insertion order so the result is deterministic.

    Returns the order as a list of original ids, or None if a cycle is detected.
    """
    nodes = set(mapping)
    # dependents[k] lists the ids that may only be processed after k.
    dependents = {}
    # Ids whose target is itself an original id must wait for that target.
    blocked = set()
    for orig, new_id in mapping.items():
        if new_id in nodes:
            dependents.setdefault(new_id, []).append(orig)
            blocked.add(orig)

    # ``order`` doubles as the work queue: every id waits on exactly one
    # target, so it becomes ready as soon as that target has been emitted.
    order = [orig for orig in mapping if orig not in blocked]
    cursor = 0
    while cursor < len(order):
        order.extend(dependents.get(order[cursor], ()))
        cursor += 1

    # Ids on a cycle (or hanging off one) never become ready.
    if len(order) != len(nodes):
        return None
    return order
85
+
86
def detect_cycle(mapping: dict):
    """
    Look for a loop in the mapping (original id -> new id).

    Each chain is followed from its starting id. If a chain revisits an id,
    the pair ``(id, mapping[id])`` for that revisited id is returned, so the
    reported pair always lies on the loop itself rather than on a chain that
    merely leads into it. Otherwise None is returned.

    Ids already proven loop-free are never walked again, which keeps the
    whole check linear in the size of the mapping.
    """
    cleared = set()
    for orig in mapping:
        if orig in cleared:
            continue
        path = set()
        current = orig
        while current in mapping and current not in cleared:
            if current in path:
                return (current, mapping[current])
            path.add(current)
            current = mapping[current]
        # The chain ended outside the mapping or on a loop-free id.
        cleared.update(path)
    return None
101
+
102
def parse_frame_numbers(frame_str: str):
    """
    Parse a string representing frame numbers.

    Accepts comma-separated values and inclusive ranges (e.g., "2,4-6,8").
    Empty entries are ignored and repeated frame numbers are kept only once,
    in order of first appearance.

    Returns a list of integers (1-indexed).

    Raises:
        ValueError: If an entry is not a number or a "start-end" range, if a
            range runs backwards, or if a frame number is smaller than 1.
    """
    result = []
    seen = set()
    for part in frame_str.split(","):
        part = part.strip()
        if not part:
            continue
        if "-" in part:
            start_txt, _, end_txt = part.partition("-")
            try:
                start, end = int(start_txt), int(end_txt)
            except ValueError:
                raise ValueError(
                    f"Invalid frame range '{part}'; expected 'start-end'."
                ) from None
            if start > end:
                raise ValueError(
                    f"Invalid frame range '{part}'; start is greater than end."
                )
            numbers = range(start, end + 1)
        else:
            try:
                numbers = (int(part),)
            except ValueError:
                raise ValueError(f"Invalid frame number '{part}'.") from None
        for number in numbers:
            if number < 1:
                raise ValueError(
                    f"Frame numbers are 1-indexed; got {number} in '{part}'."
                )
            if number not in seen:
                seen.add(number)
                result.append(number)
    return result
121
+
122
+ # ---------------- Update Ids and Bitmaps ----------------
123
+
124
def update_frame_annotations_and_bitmap(client, frame: dict) -> "tuple[int, bool, tuple | None]":
    """
    For a given frame, update annotations (set id = track_id) and update the segmentation bitmap.

    Only non-trivial mappings (where original id != track_id) are applied to the
    bitmap's R channel; they are applied in the order given by
    ``topological_sort`` so that a value is moved away before another id is
    moved onto it.

    IMPORTANT: The frame is left completely unmodified when
      - a loop is detected in the mapping, or
      - the bitmap cannot be downloaded, rewritten or uploaded. In that case
        a Streamlit warning is shown; rewriting the annotation ids anyway
        would leave them out of sync with the bitmap.

    Args:
        client: Object providing ``upload_asset(file, filename=...)`` whose
            result exposes a ``url`` attribute.
        frame: Frame dict with optional "annotations" and
            "segmentation_bitmap" entries; modified in place.

    Returns a tuple:
      (collision_count, cycle_detected, conflict_pair)
    """
    annotations = frame.get("annotations") or []
    mapping = {}
    for ann in annotations:
        try:
            orig_id = int(ann.get("id"))
            new_id = int(ann.get("track_id"))
        except (ValueError, TypeError):
            # Annotations without usable integer ids take no part in the remap.
            continue
        if orig_id != new_id:
            mapping[orig_id] = new_id

    # Check for a cycle first. If a cycle is detected, do not modify the frame.
    conflict = detect_cycle(mapping)
    if conflict is not None:
        return 0, True, conflict

    # A collision is a new id that is also the original id of another remapped annotation.
    collision_count = sum(1 for new in mapping.values() if new in mapping)

    seg_url = (frame.get("segmentation_bitmap") or {}).get("url")
    if mapping and seg_url:
        order = topological_sort(mapping)
        if order is None:
            # Defensive: detect_cycle above should already have caught this.
            return 0, True, next(iter(mapping.items()))
        try:
            if any(not 0 <= new <= 255 for new in mapping.values()):
                raise ValueError("ids outside 0-255 cannot be stored in the bitmap's R channel")
            arr = np.array(download_image(seg_url))
            r_channel = arr[:, :, 0]  # view into arr, so edits land in arr
            for orig in order:
                r_channel[r_channel == orig] = mapping[orig]
            buf = io.BytesIO()
            Image.fromarray(arr.astype(np.uint8)).save(buf, format="PNG")
            buf.seek(0)
            name, _ = os.path.splitext(os.path.basename(seg_url))
            asset = client.upload_asset(buf, filename=f"{name}_updated.png")
        except Exception as exc:
            st.warning(
                f"Could not update the segmentation bitmap ({exc}); "
                "this frame was left unchanged."
            )
            return 0, False, None
        frame["segmentation_bitmap"]["url"] = asset.url

    # Update all annotations: set id = track_id (never overwrite with a missing track_id).
    for ann in annotations:
        track_id = ann.get("track_id")
        if track_id is not None:
            ann["id"] = track_id
    return collision_count, False, None
185
+
186
def update_datalabel(sample_uuid: str, api_key: str, frames_to_update: set, labelset: str = "ground-truth") -> str:
    """
    Retrieves the label for the given sample UUID, updates only the specified frames by modifying annotations
    and updating the segmentation bitmap (via update_frame_annotations_and_bitmap), then uploads
    the updated datalabel.

    The label is read from and written to the same ``labelset``.

    IMPORTANT: If a loop is detected in any processed frame, or processing a frame fails,
    nothing is uploaded. The user must resolve the problem before updating Segments.ai.

    frames_to_update: a set of 0-indexed frame indices that should be processed.

    Returns a summary of the operation; every failure message starts with "Error".
    """
    client = SegmentsClient(api_key)
    try:
        label = client.get_label(sample_uuid, labelset=labelset)
    except Exception as e:
        return f"Error retrieving label for sample {sample_uuid}: {e}"

    attributes = label.attributes.model_dump()
    frames = attributes.get("frames") or []

    total_collisions = 0
    for i, frame in enumerate(frames):
        if i not in frames_to_update:
            continue
        try:
            collisions, cycle, conflict_pair = update_frame_annotations_and_bitmap(client, frame)
        except Exception as e:
            return (f"Error updating frame {i + 1}: {e}\n"
                    "No changes have been uploaded to Segments.ai.")
        if cycle:
            return (f"Error: Cycle detected in annotation id mappings: original id {conflict_pair[0]} -> new id {conflict_pair[1]}.\n"
                    "Please resolve the loop before updating. No changes have been uploaded to Segments.ai.")
        total_collisions += collisions

    try:
        client.update_label(sample_uuid, labelset=labelset, attributes=attributes)
    except Exception as e:
        return f"Error updating label on Segments.ai: {e}"

    summary = "Updated annotation ids with track_ids and updated bitmap for the specified frames."
    if total_collisions > 0:
        summary += "\nCollisions in ids were detected and resolved using logical processing."
    return summary
234
+
235
+ # ---------------- Copy Annotations ----------------
236
+
237
def copy_annotations_to_frames(client, attributes: dict, sample_uuid: str, source_index: int,
                               target_indexes: list, labelset: str = "ground-truth") -> str:
    """
    Copies annotations from the source frame to each target frame and merges the segmentation bitmap.

    All frame indexes are validated before anything is changed. For each target frame:
      - The annotations from the source frame are merged with the target's annotations.
        Only those annotations from the source that are not already present (based on id) are added,
        as independent copies so frames never share annotation objects.
      - The segmentation bitmap is merged: for each pixel, if the target's r-value is 0, the corresponding
        pixel from the source is used.
      - After merging, only annotations whose id appears in the set of unique r-values in the merged bitmap are kept.
      - If either frame has no bitmap, or the two bitmaps differ in shape, the target receives a copy of
        the source's bitmap entry and keeps the merged, unfiltered annotations.
    After processing all target frames, the updated label is uploaded to ``labelset``.

    If a bitmap cannot be downloaded or the merged bitmap cannot be uploaded, the function stops and
    returns a message starting with "Error" WITHOUT uploading the label, rather than overwriting the
    target's segmentation with the source's. ``attributes`` may have been partially modified in memory
    at that point and should be discarded by the caller.

    Returns a summary of the operation.
    """
    import copy

    frames = attributes.get("frames") or []
    if not 0 <= source_index < len(frames):
        return f"Source frame index {source_index+1} is out of range."
    for tgt in target_indexes:
        if not 0 <= tgt < len(frames):
            return f"Target frame index {tgt+1} is out of range."

    source_frame = frames[source_index]
    source_annotations = source_frame.get("annotations") or []
    source_bitmap = source_frame.get("segmentation_bitmap") or {}
    source_seg_url = source_bitmap.get("url")
    arr_source = None  # downloaded once, on first use

    for tgt in target_indexes:
        target_frame = frames[tgt]

        # Merge annotations: keep existing target annotations and add source annotations not already present.
        merged_annotations = list(target_frame.get("annotations") or [])
        existing_ids = {ann.get("id") for ann in merged_annotations}
        merged_annotations.extend(
            copy.deepcopy(ann) for ann in source_annotations if ann.get("id") not in existing_ids
        )
        target_frame["annotations"] = merged_annotations

        target_seg_url = (target_frame.get("segmentation_bitmap") or {}).get("url")
        if not (source_seg_url and target_seg_url):
            # If one of the bitmaps is missing, use the source's.
            target_frame["segmentation_bitmap"] = copy.deepcopy(source_bitmap)
            continue

        try:
            if arr_source is None:
                arr_source = np.array(download_image(source_seg_url))
            arr_target = np.array(download_image(target_seg_url))
        except Exception as e:
            return (f"Error downloading segmentation bitmaps for frame {tgt+1}: {e}\n"
                    "No changes have been uploaded to Segments.ai.")

        if arr_source.shape != arr_target.shape:
            # If shapes differ, simply use source bitmap.
            target_frame["segmentation_bitmap"] = copy.deepcopy(source_bitmap)
            continue

        # For pixels where the target's r-value is 0, use the source's pixel.
        merged_arr = arr_target.copy()
        empty = merged_arr[:, :, 0] == 0
        merged_arr[empty] = arr_source[empty]

        # Upload the merged image.
        try:
            buf = io.BytesIO()
            Image.fromarray(merged_arr.astype(np.uint8)).save(buf, format="PNG")
            buf.seek(0)
            name, _ = os.path.splitext(os.path.basename(target_seg_url))
            asset = client.upload_asset(buf, filename=f"{name}_merged.png")
        except Exception as e:
            return (f"Error uploading merged bitmap for frame {tgt+1}: {e}\n"
                    "No changes have been uploaded to Segments.ai.")
        target_frame["segmentation_bitmap"]["url"] = asset.url

        # Keep only annotations whose id (as int) still owns pixels in the merged bitmap.
        present_ids = {int(v) for v in np.unique(merged_arr[:, :, 0])} - {0}
        kept = []
        for ann in merged_annotations:
            try:
                ann_id = int(ann.get("id"))
            except (ValueError, TypeError):
                continue
            if ann_id in present_ids:
                kept.append(ann)
        target_frame["annotations"] = kept

    try:
        client.update_label(sample_uuid, labelset=labelset, attributes=attributes)
    except Exception as e:
        return f"Error updating label on Segments.ai during annotation copy: {e}"

    target_frames_str = ", ".join(str(tgt+1) for tgt in target_indexes)
    return (f"Annotations merged from frame {source_index+1} into frames {target_frames_str}.\n"
            f"Bitmap updated with merged r-values and annotations filtered accordingly.")
323
+
324
+ # ---------------------- Main UI ----------------------
325
+
326
st.title("Copy/Merge Annotations to Target Frames")

# Prompt user for API key as the first input; masked because it is a secret.
api_key = st.text_input("API Key", type="password")
sample_uuid = st.text_input("Sample UUID", value="")
source_frame_num = st.number_input("Source Frame Number (1-indexed)", min_value=1, step=1)
target_frames_str = st.text_input("Target Frame Numbers (comma-separated or range, e.g., '2,4-6')", value="2")

# Results live in session state so they survive Streamlit's script reruns.
for _state_key in ("result", "original_label", "new_label"):
    if _state_key not in st.session_state:
        st.session_state[_state_key] = ""


def _as_int(value):
    """Return ``value`` as an int, or None when it cannot be converted."""
    try:
        return int(value)
    except (TypeError, ValueError):
        return None


def _find_high_track_ids(frames: list, frame_indexes) -> dict:
    """Map every track_id above 255 in the given frames to the first (1-indexed) frame showing it."""
    offenders = {}
    for i in sorted(frame_indexes):
        for ann in frames[i].get("annotations") or []:
            t_id = _as_int(ann.get("track_id"))
            if t_id is not None and t_id > 255:
                offenders.setdefault(t_id, i + 1)
    return offenders


def _lowest_free_ids(frames: list, count: int) -> list:
    """Return up to ``count`` of the lowest values in 1..255 unused as id or track_id in any frame."""
    used = set()
    for frame in frames:
        for ann in frame.get("annotations") or []:
            for key in ("track_id", "id"):
                value = _as_int(ann.get(key))
                if value is not None:
                    used.add(value)
    free = []
    for candidate in range(1, 256):
        if len(free) >= count:
            break
        if candidate not in used:
            free.append(candidate)
    return free


def _run_update_and_copy() -> None:
    """Validate the inputs, then update ids/bitmaps and copy annotations, reporting through Streamlit."""
    if not api_key:
        st.error("Please enter your API Key.")
        return
    if not sample_uuid.strip():
        st.error("Please enter a Sample UUID.")
        return

    # Stop on bad input instead of carrying on with an empty target list.
    try:
        target_frames_nums = parse_frame_numbers(target_frames_str)
    except Exception as e:
        st.error(f"Error parsing target frame numbers: {e}")
        return
    if not target_frames_nums:
        st.error("Please enter at least one target frame number.")
        return

    client = SegmentsClient(api_key)
    try:
        orig_label_obj = client.get_label(sample_uuid)
    except Exception as e:
        st.error("Error retrieving original label: " + str(e))
        return

    attributes = orig_label_obj.attributes.model_dump()
    original_label_json = json.dumps(attributes, indent=4)
    frames = attributes.get("frames") or []

    source_index = int(source_frame_num) - 1
    target_indexes = [n - 1 for n in target_frames_nums]

    # Reject out-of-range frames before anything is uploaded.
    out_of_range = [i + 1 for i in [source_index, *target_indexes] if not 0 <= i < len(frames)]
    if out_of_range:
        st.error(f"Frame number(s) {', '.join(map(str, out_of_range))} out of range; "
                 f"the sample has {len(frames)} frame(s).")
        return

    # Only update the frames that are in the source or target list.
    frames_to_update = {source_index, *target_indexes}

    # --- Warning Check: Verify no track_id > 255 in selected frames ---
    track_id_warnings = _find_high_track_ids(frames, frames_to_update)
    if track_id_warnings:
        lowest_available_list = _lowest_free_ids(frames, len(track_id_warnings))
        warning_message = "Warning: The following track_id values exceed 255:\n"
        for t_id, frame_no in sorted(track_id_warnings.items()):
            warning_message += f"  - Track_id {t_id} appears first in frame {frame_no}.\n"
        warning_message += "\nPlease change these values on Segments.ai before proceeding.\n"
        warning_message += f"The lowest available id values (across all frames) are: {', '.join(map(str, lowest_available_list))}."
        st.error(warning_message)
        return

    update_summary = update_datalabel(sample_uuid, api_key, frames_to_update)
    if update_summary.startswith("Error"):
        st.error(update_summary)
        return

    try:
        # Retrieve the label after the update.
        label = client.get_label(sample_uuid)
    except Exception as e:
        st.error("Error retrieving updated label: " + str(e))
        return

    copy_summary = copy_annotations_to_frames(
        client, label.attributes.model_dump(), sample_uuid, source_index, target_indexes
    )
    st.session_state["result"] = update_summary + "\n\n" + copy_summary
    # Keep the pre-change label downloadable even if the final fetch below fails.
    st.session_state["original_label"] = original_label_json
    st.session_state["new_label"] = ""

    # Retrieve the final updated label.
    try:
        label_after = client.get_label(sample_uuid)
    except Exception as e:
        st.error("Error retrieving final updated label: " + str(e))
        return
    st.session_state["new_label"] = json.dumps(label_after.attributes.model_dump(), indent=4)


if st.button("Update and Copy Annotations"):
    _run_update_and_copy()

if st.session_state["result"]:
    st.text_area("Output", value=st.session_state["result"], height=150)
    st.download_button("Download Original Label",
                       data=st.session_state["original_label"],
                       file_name=f"{sample_uuid}_original.json",
                       mime="application/json")
    st.download_button("Download Updated Label",
                       data=st.session_state["new_label"],
                       file_name=f"{sample_uuid}_updated.json",
                       mime="application/json")