File size: 20,460 Bytes
9322901
 
 
 
b4f72c6
 
 
 
 
 
 
9322901
 
b4f72c6
 
 
 
 
 
 
 
 
 
 
 
 
 
 
9322901
b4f72c6
 
 
 
 
 
 
9322901
 
 
b4f72c6
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
import altair as alt
import numpy as np
import pandas as pd
import streamlit as st
import os
import io
import json
import numpy as np
from PIL import Image
import requests
from segments import SegmentsClient

"""
# Copy annotations from one frame to other frames in the same sample

This HF-application first updates selected frames for a given sample UUID by
replacing each annotation’s id with its track_id and updating the segmentation bitmap
accordingly (to avoid potential conflicts). Only the frames specified by the source or target
frame numbers are processed.
    
Then it copies annotations from one source frame to one or more target frames. When copying:
   - The annotations from the source frame are merged with the target's annotations,
     adding only those from the source that are not already present (based on id).
   - The segmentation bitmap is merged: for each pixel, if the target's r-value is 0, the corresponding
     r-value from the source is used.
   - After merging, the bitmap is scanned for unique r-values and any annotation in the target
     frame whose id is not present in these unique values is deleted.
Finally, the updated datalabel is uploaded. 

**Important** 
If a loop is detected in any frame’s id mappings, or if any track_id in the
selected frames is greater than 255, no changes are made and nothing is uploaded to Segments.ai.
For each track_id higher than 255, the warning will include:
   - Which track_id is too high and the first frame (1-indexed) where it appears.
   - The lowest available id values (searched across all frames) – one for each offending track_id.
The user must resolve these issues before updating.

"""

# ---------------- Utility Functions ----------------

def download_image(url: str, timeout: float = 30.0) -> Image.Image:
    """
    Download an image from the given URL and return a PIL Image in RGB mode.

    Parameters:
        url: Location of the image to fetch.
        timeout: Seconds to wait for the server before giving up. Without a
            timeout ``requests`` waits indefinitely, which would freeze the app
            on a stalled connection.

    Raises:
        requests.RequestException: on connection errors, timeouts, or a non-2xx
            HTTP status (via ``raise_for_status``).
    """
    resp = requests.get(url, timeout=timeout)
    resp.raise_for_status()
    return Image.open(io.BytesIO(resp.content)).convert("RGB")

def topological_sort(mapping: dict) -> list:
    """
    Compute the order in which the keys of ``mapping`` (original id -> new id)
    should be remapped.

    Whenever the new id of one entry is itself a key of the mapping, that key is
    placed earlier in the returned order, so its value is moved out of the way
    before another entry is rewritten onto it.

    Returns the ordered list of keys, or None when the mapping contains a cycle.
    """
    sources = set(mapping.keys())

    # dependents[k] holds every key whose new id is k; k must be handled first.
    dependents = {}
    for source, target in mapping.items():
        if target in sources:
            dependents.setdefault(target, set()).add(source)

    state = {}
    postorder = []
    has_cycle = False

    def visit(current):
        nonlocal has_cycle
        if has_cycle:
            return
        status = state.get(current)
        if status is not None:
            # Reaching a key that is still being expanded means a loop.
            if status == "visiting":
                has_cycle = True
            return
        state[current] = "visiting"
        for child in dependents.get(current, ()):
            visit(child)
        state[current] = "visited"
        postorder.append(current)

    for start in sources:
        if start not in state:
            visit(start)

    # Reversed post-order puts each key ahead of the keys that depend on it.
    return None if has_cycle else postorder[::-1]

def detect_cycle(mapping: dict):
    """
    Look for a loop in the mapping (original id -> new id).

    Starting from every key in turn, the chain key -> mapping[key] -> ... is
    followed for as long as the current value is itself a key. If a chain ever
    revisits a key, the pair (starting key, its new id) is returned. When no
    chain loops, None is returned.
    """
    def chain_revisits(start) -> bool:
        # Walk the chain from ``start`` and report whether a key is seen twice.
        seen = set()
        cursor = start
        while cursor in mapping:
            if cursor in seen:
                return True
            seen.add(cursor)
            cursor = mapping[cursor]
        return False

    return next(
        ((start, target) for start, target in mapping.items() if chain_revisits(start)),
        None,
    )

def parse_frame_numbers(frame_str: str):
    """
    Parse a string representing frame numbers.

    Accepts comma-separated values and inclusive ranges (e.g., "2,4-6,8").
    Whitespace around values is ignored and empty entries are skipped.

    Returns a list of integers (1-indexed) in the order they were written.

    Raises:
        ValueError: if an entry is not an integer, if a range is not of the
            form "start-end" (e.g. "1-2-3" or "-3"), or if a range's start is
            greater than its end. Malformed ranges are rejected rather than
            being silently truncated or dropped.
    """
    result = []
    for part in frame_str.split(","):
        part = part.strip()
        if not part:
            continue
        start_text, dash, end_text = part.partition("-")
        if not dash:
            result.append(int(part))
            continue
        try:
            # Anything left after the first dash must be a plain integer, so
            # "1-2-3" fails here instead of being read as "1-2".
            start, end = int(start_text), int(end_text)
        except ValueError:
            raise ValueError(
                f"Invalid frame range '{part}': expected the form 'start-end'."
            ) from None
        if start > end:
            raise ValueError(
                f"Invalid frame range '{part}': start is greater than end."
            )
        result.extend(range(start, end + 1))
    return result

# ---------------- Update Ids and Bitmaps ----------------

def update_frame_annotations_and_bitmap(client, frame: dict) -> "tuple[int, bool, tuple | None]":
    """
    For a given frame, update annotations (set id = track_id) and update the segmentation bitmap.

    Only non-trivial mappings (where original id != track_id) are applied to the
    bitmap: every pixel whose R channel equals an original id is rewritten to the
    matching track_id, and the rewritten image is uploaded as a new asset whose
    URL replaces the frame's bitmap URL.

    IMPORTANT: If a loop is detected in the mapping, the function returns immediately
    without modifying any values in the frame.

    Parameters:
        client: Object providing ``upload_asset(file, filename=...)`` whose result
            has a ``url`` attribute.
        frame: Frame dict; its "annotations" and "segmentation_bitmap" entries are
            read and modified in place.

    Returns a tuple:
       (collision_count, cycle_detected, conflict_pair)
    where collision_count is the number of mappings whose new id is also one of
    the original ids being remapped, and conflict_pair is the (original, new)
    pair reported by ``detect_cycle`` (None when no cycle was found).

    Raises:
        RuntimeError: if the bitmap cannot be downloaded, rewritten, or uploaded.
            The frame is left untouched in that case, so annotation ids never
            get out of sync with the bitmap.
    """
    annotations = frame.get("annotations") or []
    mapping = {}
    for ann in annotations:
        try:
            orig_id = int(ann.get("id"))
            new_id = int(ann.get("track_id"))
        except (ValueError, TypeError):
            # Annotations without numeric id/track_id cannot be remapped.
            continue
        if orig_id != new_id:
            mapping[orig_id] = new_id

    # Check for a cycle first. If a cycle is detected, do not modify the frame.
    conflict = detect_cycle(mapping)
    if conflict is not None:
        return 0, True, conflict

    # A collision is a new id that is also an original id awaiting remapping.
    collision_count = sum(1 for new in mapping.values() if new in mapping)

    seg_url = (frame.get("segmentation_bitmap") or {}).get("url")
    if mapping and seg_url:
        order = topological_sort(mapping)
        if order is None:
            # Defensive: detect_cycle passed, so this is not expected to happen.
            return 0, True, next(iter(mapping.items()))
        try:
            arr = np.array(download_image(seg_url))
            r_channel = arr[:, :, 0]  # view into arr: edits land in arr directly
            # Apply in dependency order so an id is vacated before it is reused.
            for orig in order:
                r_channel[r_channel == orig] = mapping[orig]
            updated_image = Image.fromarray(arr.astype(np.uint8))
            buf = io.BytesIO()
            updated_image.save(buf, format="PNG")
            buf.seek(0)
            name, _ = os.path.splitext(os.path.basename(seg_url))
            asset = client.upload_asset(buf, filename=f"{name}_updated.png")
        except Exception as exc:
            # Rewriting ids without the matching bitmap would corrupt the label.
            raise RuntimeError(
                f"Failed to update segmentation bitmap {seg_url}: {exc}"
            ) from exc
        frame["segmentation_bitmap"]["url"] = asset.url

    # Update annotations: set id = track_id. Annotations that have no track_id
    # keep their current id instead of having it replaced by None.
    for ann in annotations:
        track_id = ann.get("track_id")
        if track_id is not None:
            ann["id"] = track_id
    return collision_count, False, None

def update_datalabel(sample_uuid: str, api_key: str, frames_to_update: set, labelset: str = "ground-truth") -> str:
    """
    Retrieves the label for the given sample UUID, updates only the specified frames by modifying
    annotations and updating the segmentation bitmap (via update_frame_annotations_and_bitmap),
    then uploads the updated datalabel.

    IMPORTANT: If a loop is detected in any processed frame, or a frame cannot be
    processed, no changes are applied and nothing is uploaded. The user must resolve
    the problem before updating Segments.ai.

    Parameters:
        sample_uuid: UUID of the sample whose label is updated.
        api_key: Segments.ai API key used to build the client.
        frames_to_update: a set of 0-indexed frame indices that should be processed.
            Indices beyond the label's frames are ignored.
        labelset: Labelset the label is read from and written back to.

    Returns a summary string. Every failure is reported as a string starting with
    "Error" rather than by raising.
    """
    client = SegmentsClient(api_key)
    try:
        # Read from the same labelset that is written below.
        label = client.get_label(sample_uuid, labelset=labelset)
    except Exception as e:
        return f"Error retrieving label for sample {sample_uuid}: {e}"

    attributes = label.attributes.model_dump()
    frames = attributes.get("frames") or []

    total_collisions = 0
    for i, frame in enumerate(frames):
        if i not in frames_to_update:
            continue
        try:
            collisions, cycle, conflict_pair = update_frame_annotations_and_bitmap(client, frame)
        except Exception as e:
            return (f"Error updating frame {i + 1}: {e}\n"
                    "No changes have been uploaded to Segments.ai.")
        if cycle:
            if conflict_pair is not None:
                detail = f": original id {conflict_pair[0]} -> new id {conflict_pair[1]}"
            else:
                detail = f" in frame {i + 1}"
            return (f"Error: Cycle detected in annotation id mappings{detail}.\n"
                    "Please resolve the loop before updating. No changes have been uploaded to Segments.ai.")
        total_collisions += collisions

    try:
        client.update_label(sample_uuid, labelset=labelset, attributes=attributes)
    except Exception as e:
        return f"Error updating label on Segments.ai: {e}"

    summary = "Updated annotation ids with track_ids and updated bitmap for the specified frames."
    if total_collisions > 0:
        summary += "\nCollisions in ids were detected and resolved using logical processing."
    return summary

# ---------------- Copy Annotations ----------------

def copy_annotations_to_frames(client, attributes: dict, sample_uuid: str, source_index: int,
                               target_indexes: list, labelset: str = "ground-truth") -> str:
    """
    Copies annotations from the source frame to each target frame and merges the segmentation bitmap.

    For each target frame:
      - The annotations from the source frame are merged with the target's annotations.
        Only those annotations from the source that are not already present (based on id)
        are added, as independent copies.
      - The segmentation bitmap is merged: for each pixel, if the target's r-value is 0,
        the corresponding pixel from the source is used.
      - After merging, only annotations whose id appears in the set of unique non-zero
        r-values in the merged bitmap are kept.
      - If either frame has no bitmap URL, or the two bitmaps differ in shape, the target
        receives a copy of the source's bitmap entry and its annotations are not filtered.
    After processing all target frames, the updated label is uploaded.

    Parameters:
        client: Object providing ``upload_asset`` and ``update_label``.
        attributes: Label attributes dict with a "frames" list; modified in place.
        sample_uuid: UUID of the sample whose label is uploaded.
        source_index: 0-indexed source frame.
        target_indexes: 0-indexed target frames.
        labelset: Labelset the merged label is written to.

    Returns a summary string. All indexes are validated before any frame is touched;
    an out-of-range index or a failed bitmap merge/upload returns an explanatory
    message without uploading the label.
    """
    import copy  # local import: only this function needs it

    frames = attributes.get("frames") or []
    if source_index < 0 or source_index >= len(frames):
        return f"Source frame index {source_index+1} is out of range."
    # Validate every target up front so a bad index cannot leave earlier frames
    # half-modified.
    for tgt in target_indexes:
        if tgt < 0 or tgt >= len(frames):
            return f"Target frame index {tgt+1} is out of range."

    source_frame = frames[source_index]
    source_bitmap = source_frame.get("segmentation_bitmap") or {}
    source_seg_url = source_bitmap.get("url")
    arr_source = None  # downloaded lazily, once, and reused for every target

    for tgt in target_indexes:
        target_frame = frames[tgt]

        # Merge annotations: keep existing target annotations and add source
        # annotations not already present. Iterate a snapshot in case the
        # target is the source frame itself.
        target_annotations = target_frame.get("annotations") or []
        existing_ids = {ann.get("id") for ann in target_annotations}
        for ann in list(source_frame.get("annotations") or []):
            ann_id = ann.get("id")
            if ann_id not in existing_ids:
                # Deep copy so frames never share annotation objects.
                target_annotations.append(copy.deepcopy(ann))
                existing_ids.add(ann_id)
        target_frame["annotations"] = target_annotations

        target_seg_url = (target_frame.get("segmentation_bitmap") or {}).get("url")
        if not (source_seg_url and target_seg_url):
            # If one of the bitmaps is missing, use (a copy of) the source's.
            target_frame["segmentation_bitmap"] = dict(source_bitmap)
            continue

        try:
            if arr_source is None:
                arr_source = np.array(download_image(source_seg_url))
            arr_target = np.array(download_image(target_seg_url))
            if arr_source.shape != arr_target.shape:
                # If shapes differ, simply use the source bitmap. A copy keeps
                # later edits to this target from leaking into the source frame.
                target_frame["segmentation_bitmap"] = dict(source_bitmap)
                continue

            # For pixels where the target's r-value is 0, use the source's pixel.
            mask = arr_target[:, :, 0] == 0
            merged_arr = arr_target.copy()
            merged_arr[mask] = arr_source[mask]

            buf = io.BytesIO()
            Image.fromarray(merged_arr.astype(np.uint8)).save(buf, format="PNG")
            buf.seek(0)
            name, _ = os.path.splitext(os.path.basename(target_seg_url))
            asset = client.upload_asset(buf, filename=f"{name}_merged.png")
        except Exception as e:
            # Overwriting the target's bitmap after a failed merge would silently
            # discard its segmentation, so abort without uploading the label.
            return (f"Error merging segmentation bitmap for frame {tgt+1}: {e}\n"
                    "No copied annotations have been uploaded to Segments.ai.")

        target_frame["segmentation_bitmap"]["url"] = asset.url

        # Keep only annotations whose id (as int) is a non-zero r-value of the
        # merged bitmap; annotations without a numeric id are dropped.
        unique_vals = {int(v) for v in np.unique(merged_arr[:, :, 0])} - {0}
        filtered_annotations = []
        for ann in target_frame["annotations"]:
            try:
                ann_id = int(ann.get("id"))
            except (ValueError, TypeError):
                continue
            if ann_id in unique_vals:
                filtered_annotations.append(ann)
        target_frame["annotations"] = filtered_annotations

    try:
        client.update_label(sample_uuid, labelset=labelset, attributes=attributes)
    except Exception as e:
        return f"Error updating label on Segments.ai during annotation copy: {e}"

    target_frames_str = ", ".join(str(tgt+1) for tgt in target_indexes)
    return (f"Annotations merged from frame {source_index+1} into frames {target_frames_str}.\n"
            f"Bitmap updated with merged r-values and annotations filtered accordingly.")

# ---------------------- Main UI ----------------------

st.title("Copy/Merge Annotations to Target Frames")

# The API key is requested first and masked (type="password") in the UI.
api_key = st.text_input("API Key", type="password")
sample_uuid = st.text_input("Sample UUID", value="")
source_frame_num = st.number_input("Source Frame Number (1-indexed)", min_value=1, step=1)
target_frames_str = st.text_input("Target Frame Numbers (comma-separated or range, e.g., '2,4-6')", value="2")

# Results live in session state so they are still available on later reruns
# of the script (the output section below reads them on every run).
for _state_key in ("result", "original_label", "new_label"):
    if _state_key not in st.session_state:
        st.session_state[_state_key] = ""

# Ids are compared against the bitmap's r-values, so they must fit in one byte.
MAX_BITMAP_ID = 255


def _as_int(value):
    """Return ``int(value)``, or None when the value is missing or not numeric."""
    try:
        return int(value)
    except (ValueError, TypeError):
        return None


def _run_update_and_copy() -> None:
    """
    Validate the form inputs, normalise ids in the selected frames, then copy
    annotations from the source frame into the target frames.

    Every problem is reported with ``st.error`` and stops the run before
    anything further is uploaded. On success the summary and label JSON are
    stored in ``st.session_state``.
    """
    if not api_key:
        st.error("Please enter your API Key.")
        return
    if not sample_uuid.strip():
        st.error("Please enter a Sample UUID.")
        return

    try:
        target_frames_nums = parse_frame_numbers(target_frames_str)
    except ValueError as e:
        st.error(f"Error parsing target frame numbers: {e}")
        return
    if not target_frames_nums:
        # Without targets only the source frame would be rewritten and uploaded.
        st.error("Please enter at least one target frame number.")
        return

    client = SegmentsClient(api_key)
    try:
        orig_label_obj = client.get_label(sample_uuid)
    except Exception as e:
        st.error("Error retrieving original label: " + str(e))
        return

    attributes = orig_label_obj.attributes.model_dump()
    original_label_json = json.dumps(attributes, indent=4)
    frames = attributes.get("frames") or []

    source_index = int(source_frame_num) - 1
    target_indexes = [n - 1 for n in target_frames_nums]
    # Only update the frames that are in the source or target list.
    frames_to_update = {source_index, *target_indexes}

    # Reject bad frame numbers before the first upload, so the label is never
    # left half-processed.
    out_of_range = sorted(i + 1 for i in frames_to_update if not 0 <= i < len(frames))
    if out_of_range:
        st.error(f"Frame number(s) {', '.join(map(str, out_of_range))} are out of range; "
                 f"the sample has {len(frames)} frame(s).")
        return

    # --- Warning Check: Verify no track_id > 255 in selected frames ---
    # Maps offending track_id -> first (1-indexed) frame number where it appears.
    track_id_warnings = {}
    for i in sorted(frames_to_update):
        for ann in frames[i].get("annotations") or []:
            t_id = _as_int(ann.get("track_id"))
            if t_id is not None and t_id > MAX_BITMAP_ID and t_id not in track_id_warnings:
                track_id_warnings[t_id] = i + 1

    if track_id_warnings:
        # Collect every id / track_id used in ANY frame to suggest free values.
        all_existing_values = set()
        for frame in frames:
            for ann in frame.get("annotations") or []:
                for field in ("track_id", "id"):
                    value = _as_int(ann.get(field))
                    if value is not None:
                        all_existing_values.add(value)
        # Suggestions are limited to 1..255; larger values would be rejected again.
        free_values = [c for c in range(1, MAX_BITMAP_ID + 1) if c not in all_existing_values]
        lowest_available_list = free_values[:len(track_id_warnings)]

        warning_message = "Warning: The following track_id values exceed 255:\n"
        for t_id, frame_no in sorted(track_id_warnings.items()):
            warning_message += f" - Track_id {t_id} appears first in frame {frame_no}.\n"
        warning_message += "\nPlease change these values on Segments.ai before proceeding.\n"
        if lowest_available_list:
            warning_message += ("The lowest available id values (across all frames) are: "
                                f"{', '.join(map(str, lowest_available_list))}.")
        if len(lowest_available_list) < len(track_id_warnings):
            warning_message += (f"\nOnly {len(lowest_available_list)} unused id value(s) up to "
                                f"{MAX_BITMAP_ID} remain across all frames.")
        st.error(warning_message)
        return

    try:
        update_summary = update_datalabel(sample_uuid, api_key, frames_to_update)
    except Exception as e:
        st.error(f"Error updating annotation ids: {e}")
        return
    if update_summary.startswith("Error"):
        st.error(update_summary)
        return

    try:
        # Retrieve the label after the update.
        label = client.get_label(sample_uuid)
    except Exception as e:
        st.error("Error retrieving updated label: " + str(e))
        return

    copy_summary = copy_annotations_to_frames(
        client, label.attributes.model_dump(), sample_uuid, source_index, target_indexes
    )
    st.session_state["result"] = update_summary + "\n\n" + copy_summary
    # Store the original label together with the result so the two always match.
    st.session_state["original_label"] = original_label_json

    # Retrieve the final updated label.
    try:
        label_after = client.get_label(sample_uuid)
    except Exception as e:
        st.error("Error retrieving final updated label: " + str(e))
        # Clear any label left over from an earlier run.
        st.session_state["new_label"] = ""
        return
    st.session_state["new_label"] = json.dumps(label_after.attributes.model_dump(), indent=4)


if st.button("Update and Copy Annotations"):
    _run_update_and_copy()

if st.session_state["result"]:
    st.text_area("Output", value=st.session_state["result"], height=150)
    st.download_button("Download Original Label",
                       data=st.session_state["original_label"],
                       file_name=f"{sample_uuid}_original.json",
                       mime="application/json")
    # Offered only when the final label was actually retrieved.
    if st.session_state["new_label"]:
        st.download_button("Download Updated Label",
                           data=st.session_state["new_label"],
                           file_name=f"{sample_uuid}_updated.json",
                           mime="application/json")