|
import altair as alt |
|
import numpy as np |
|
import pandas as pd |
|
import streamlit as st |
|
import os |
|
import io |
|
import json |
|
import numpy as np |
|
from PIL import Image |
|
import requests |
|
from segments import SegmentsClient |
|
|
|
""" |
|
# Copy annotations from one frame to other frames in the same sample
|
|
|
This Hugging Face app first updates the selected frames of the sample with the given UUID:
each annotation’s id is replaced with its track_id, and the segmentation bitmap is updated
accordingly (so that ids do not conflict when annotations are copied between frames). Only the
frames given as source or target frame numbers are processed.
|
|
|
Then it copies annotations from one source frame to one or more target frames. When copying: |
|
- The annotations from the source frame are merged with the target's annotations, |
|
adding only those from the source that are not already present (based on id). |
|
- The segmentation bitmap is merged: for each pixel, if the target's r-value is 0, the corresponding |
|
r-value from the source is used. |
|
- After merging, the bitmap is scanned for unique r-values and any annotation in the target |
|
frame whose id is not present in these unique values is deleted. |
|
Finally, the updated datalabel is uploaded. |
|
|
|
**Important** |
|
If a loop is detected in any frame’s id mappings, or if any track_id in the |
|
selected frames is greater than 255, no changes are made and nothing is uploaded to Segments.ai. |
|
For each track_id higher than 255, the warning will include: |
|
- Which track_id is too high and the first frame (1-indexed) where it appears. |
|
- The lowest available id values (searched across all frames) – one for each offending track_id. |
|
The user must resolve these issues before updating. |
|
|
|
""" |
|
|
|
|
|
|
|
def download_image(url: str, timeout: float = 30.0) -> Image.Image:
    """
    Download an image from ``url`` and return it as a PIL Image converted to RGB.

    Any alpha channel is discarded by the RGB conversion.

    Args:
        url: HTTP(S) URL of the image.
        timeout: Seconds to wait for the server (connect and read) before giving up.
            Without a timeout, ``requests`` can block indefinitely on a stalled server.

    Raises:
        requests.HTTPError: if the server answers with an error status code.
        requests.Timeout: if the server does not respond within ``timeout`` seconds.
    """
    resp = requests.get(url, timeout=timeout)
    resp.raise_for_status()
    return Image.open(io.BytesIO(resp.content)).convert("RGB")
|
|
|
def topological_sort(mapping: dict) -> list:
    """
    Order the keys of an id mapping (original id -> new id) so that chained rewrites are safe.

    When a key's new id is itself a key of the mapping (e.g. 1 -> 2 and 2 -> 3), the key it
    points at (2) is placed before the key pointing at it (1). Rewriting values in the returned
    order therefore never lets one change feed into another.

    Returns the ordered list of keys, or None if the mapping contains a loop.
    """
    keys = set(mapping.keys())

    # dependents[y] = keys x with mapping[x] == y, for every y that is itself a key.
    dependents = {}
    for src, dst in mapping.items():
        if dst in keys:
            dependents.setdefault(dst, set()).add(src)

    state = {}      # node -> "visiting" while on the DFS stack, "visited" once finished
    postorder = []

    def visit(node) -> bool:
        """Depth-first visit; return False as soon as a loop is found."""
        status = state.get(node)
        if status == "visiting":
            return False
        if status == "visited":
            return True
        state[node] = "visiting"
        for child in dependents.get(node, set()):
            if not visit(child):
                return False
        state[node] = "visited"
        postorder.append(node)
        return True

    for node in keys:
        if node not in state and not visit(node):
            return None
    # Reverse post-order: every node precedes the dependents that map onto it.
    return postorder[::-1]
|
|
|
def detect_cycle(mapping: dict):
    """
    Look for a loop in an id mapping (original id -> new id).

    The chain of mappings is followed from every starting id. When a chain revisits an id,
    the mapping leaving that id lies on the loop and is returned as ``(original, new)``.
    A chain can also run into a loop from outside (e.g. 1 -> 2 -> 3 -> 2); reporting the
    entry mapping 1 -> 2 would point the user at a mapping that is not part of the loop,
    so the mapping at the revisited id (2 -> 3) is reported instead.

    Returns None when the mapping contains no loop.
    """
    for start in mapping:
        seen = set()
        current = start
        while current in mapping:
            if current in seen:
                # `current` is the first id seen twice, so it lies on the loop.
                return (current, mapping[current])
            seen.add(current)
            current = mapping[current]
    return None
|
|
|
def parse_frame_numbers(frame_str: str):
    """
    Parse a 1-indexed frame specification such as "2,4-6,8" into [2, 4, 5, 6, 8].

    Comma-separated items are either single numbers or inclusive ranges "start-end".
    Whitespace around items and bounds is ignored, empty items (e.g. a trailing comma)
    are skipped, and order and duplicates are preserved.

    Raises:
        ValueError: on non-numeric items, malformed ranges such as "4-6-8", ranges whose
            start exceeds their end (which would otherwise silently select nothing),
            or frame numbers below 1.
    """
    result = []
    for part in frame_str.split(","):
        part = part.strip()
        if not part:
            continue
        if "-" in part:
            bounds = [token.strip() for token in part.split("-")]
            if len(bounds) != 2:
                raise ValueError(f"Invalid frame range {part!r}; expected 'start-end'.")
            start, end = int(bounds[0]), int(bounds[1])
            if start > end:
                raise ValueError(f"Invalid frame range {part!r}; start must not exceed end.")
        else:
            start = end = int(part)
        if start < 1:
            raise ValueError(f"Frame numbers are 1-indexed; got {part!r}.")
        result.extend(range(start, end + 1))
    return result
|
|
|
|
|
|
|
def _remap_bitmap_r_channel(client, seg_url: str, mapping: dict, order: list) -> str:
    """Download the bitmap at seg_url, rewrite its R channel per mapping in `order`, upload it and return the new asset URL."""
    # The bitmap is handled as 8-bit RGB, so larger ids cannot be stored (NumPy would wrap or raise).
    too_large = sorted(new for new in mapping.values() if not 0 <= new <= 255)
    if too_large:
        raise ValueError(f"Track ids {too_large} do not fit in the 8-bit R channel of the segmentation bitmap.")
    try:
        arr = np.array(download_image(seg_url))
        r_channel = arr[:, :, 0]  # a view: assignments below write straight into `arr`
        for orig in order:
            r_channel[r_channel == orig] = mapping[orig]
        buf = io.BytesIO()
        Image.fromarray(arr.astype(np.uint8)).save(buf, format="PNG")
        buf.seek(0)
        name, _ = os.path.splitext(os.path.basename(seg_url))
        asset = client.upload_asset(buf, filename=f"{name}_updated.png")
    except Exception as exc:
        raise RuntimeError(f"Failed to update segmentation bitmap {seg_url}: {exc}") from exc
    return asset.url


def update_frame_annotations_and_bitmap(client, frame: dict) -> "tuple[int, bool, tuple | None]":
    """
    Make every annotation id in ``frame`` equal to its track_id and rewrite the bitmap to match.

    The R channel of the frame's segmentation bitmap holds annotation ids. Every non-trivial
    change (original id != track_id) is therefore also applied to the R channel. Changes are
    applied in the order given by ``topological_sort`` so that chained changes (1 -> 2, 2 -> 3)
    do not merge ids. The rewritten bitmap is uploaded as a new asset, and the frame's bitmap
    URL is updated.

    The frame is modified only when everything succeeds:
      * If the id mapping contains a loop, nothing is modified and the loop is reported.
      * If the bitmap cannot be rewritten, an exception is raised before any annotation id
        is touched, so ids and bitmap never disagree.
    Annotations whose id or track_id is not an integer are left unchanged.

    Args:
        client: Segments client used to upload the rewritten bitmap (``upload_asset``).
        frame: Frame dict with "annotations" and optionally "segmentation_bitmap"; mutated in place.

    Returns:
        ``(collision_count, cycle_detected, conflict_pair)``:
          * collision_count: the number of changes whose new id is also an original id in
            this frame.
          * cycle_detected: whether a loop was found.
          * conflict_pair: the conflicting ``(original, new)`` pair reported by
            ``detect_cycle``, or None when there is no loop.

    Raises:
        ValueError: if the frame has a bitmap and a new id does not fit in the 8-bit R channel.
        RuntimeError: if downloading, rewriting or uploading the bitmap fails.
    """
    annotations = frame.get("annotations") or []
    mapping = {}
    original_ids = set()
    remappable = []  # annotations whose id and track_id both parse as integers
    for ann in annotations:
        try:
            orig_id = int(ann.get("id"))
            new_id = int(ann.get("track_id"))
        except (ValueError, TypeError):
            continue
        remappable.append(ann)
        original_ids.add(orig_id)
        if orig_id != new_id:
            mapping[orig_id] = new_id

    conflict = detect_cycle(mapping)
    if conflict is not None:
        return 0, True, conflict

    # A "collision" is a new id that another annotation in this frame currently uses as its id.
    collision_count = sum(1 for new in mapping.values() if new in original_ids)

    if mapping:
        order = topological_sort(mapping)
        if order is None:
            # Not expected after detect_cycle, but never rewrite ids without a safe order.
            return 0, True, next(iter(mapping.items()))
        seg_url = (frame.get("segmentation_bitmap") or {}).get("url")
        if seg_url:
            new_url = _remap_bitmap_r_channel(client, seg_url, mapping, order)
            frame["segmentation_bitmap"]["url"] = new_url

    for ann in remappable:
        ann["id"] = ann.get("track_id")
    return collision_count, False, None
|
|
|
def update_datalabel(sample_uuid: str, api_key: str, frames_to_update: set, labelset: str = "ground-truth") -> str:
    """
    Rewrite annotation ids to track_ids (and the bitmaps accordingly) for selected frames, then upload.

    The label is read from and written back to the same ``labelset``. Each selected frame is
    processed with ``update_frame_annotations_and_bitmap``. If any frame reports an id loop or
    fails to process, an error message is returned and the label is not uploaded; the user
    must resolve the problem before updating Segments.ai.

    Args:
        sample_uuid: UUID of the sample whose label is updated.
        api_key: Segments.ai API key.
        frames_to_update: 0-indexed frame indices to process; indices outside the label's
            frame list are ignored.
        labelset: Labelset to read from and write to.

    Returns:
        A summary message. Error messages start with "Error" (callers rely on this prefix).
    """
    client = SegmentsClient(api_key)
    try:
        label = client.get_label(sample_uuid, labelset=labelset)
    except Exception as e:
        return f"Error retrieving label for sample {sample_uuid}: {e}"

    attributes = label.attributes.model_dump()
    frames = attributes.get("frames", [])

    total_collisions = 0
    for i, frame in enumerate(frames):
        if i not in frames_to_update:
            continue
        try:
            collisions, cycle, conflict_pair = update_frame_annotations_and_bitmap(client, frame)
        except Exception as e:
            return (f"Error updating frame {i+1}: {e}\n"
                    "No changes have been uploaded to Segments.ai.")
        if cycle:
            where = (f"original id {conflict_pair[0]} -> new id {conflict_pair[1]}"
                     if conflict_pair is not None else f"frame {i+1}")
            return (f"Error: Cycle detected in annotation id mappings: {where}.\n"
                    "Please resolve the loop before updating. No changes have been uploaded to Segments.ai.")
        total_collisions += collisions

    try:
        client.update_label(sample_uuid, labelset=labelset, attributes=attributes)
    except Exception as e:
        return f"Error updating label on Segments.ai: {e}"

    summary = "Updated annotation ids with track_ids and updated bitmap for the specified frames."
    if total_collisions > 0:
        summary += "\nCollisions in ids were detected and resolved using logical processing."
    return summary
|
|
|
|
|
|
|
def _merge_source_annotations(source_frame: dict, target_frame: dict) -> None:
    """Append to the target frame every source annotation whose id the target does not have yet."""
    target_annotations = target_frame.get("annotations") or []
    existing_ids = {ann.get("id") for ann in target_annotations}
    for ann in source_frame.get("annotations") or []:
        if ann.get("id") not in existing_ids:
            target_annotations.append(ann)
    target_frame["annotations"] = target_annotations


def _merge_source_bitmap(client, source_frame: dict, target_frame: dict):
    """
    Merge the source frame's segmentation bitmap into the target frame.

    Returns the set of non-zero R values of the target's resulting bitmap, or None when no
    bitmap was inspected (annotations should then not be filtered). Raises on download or
    upload failure, before the target's bitmap reference is changed.
    """
    source_bitmap = source_frame.get("segmentation_bitmap") or {}
    source_url = source_bitmap.get("url")
    target_url = (target_frame.get("segmentation_bitmap") or {}).get("url")
    if not source_url:
        # Nothing to merge from: keep whatever bitmap the target already has.
        return None
    if not target_url:
        # Copy so several targets never share (and co-mutate) one dict.
        target_frame["segmentation_bitmap"] = dict(source_bitmap)
        return None

    arr_source = np.array(download_image(source_url))
    arr_target = np.array(download_image(target_url))
    if arr_source.shape != arr_target.shape:
        # Cannot merge pixel-wise: fall back to the source bitmap as a whole.
        target_frame["segmentation_bitmap"] = dict(source_bitmap)
        result_arr = arr_source
    else:
        result_arr = arr_target.copy()
        unlabelled = result_arr[:, :, 0] == 0  # R value 0 = no annotation in the target
        result_arr[unlabelled] = arr_source[unlabelled]
        buf = io.BytesIO()
        Image.fromarray(result_arr.astype(np.uint8)).save(buf, format="PNG")
        buf.seek(0)
        name, _ = os.path.splitext(os.path.basename(target_url))
        asset = client.upload_asset(buf, filename=f"{name}_merged.png")
        target_frame["segmentation_bitmap"] = {**target_frame["segmentation_bitmap"], "url": asset.url}
    return {int(v) for v in np.unique(result_arr[:, :, 0])} - {0}


def _filter_annotations(annotations: list, kept_ids: set) -> list:
    """Keep annotations whose integer id is in kept_ids; annotations with non-integer ids are dropped."""
    kept = []
    for ann in annotations:
        try:
            ann_id = int(ann.get("id"))
        except (ValueError, TypeError):
            continue
        if ann_id in kept_ids:
            kept.append(ann)
    return kept


def copy_annotations_to_frames(client, attributes: dict, sample_uuid: str, source_index: int,
                               target_indexes: list, labelset: str = "ground-truth") -> str:
    """
    Merge the source frame's annotations and segmentation bitmap into each target frame, then upload.

    For each target frame:
      - Source annotations whose id is not already present in the target are appended.
      - If both frames have bitmaps of the same shape, they are merged pixel-wise: wherever
        the target's R value is 0, the source pixel is used. The merged bitmap is uploaded
        as a new asset.
      - If the bitmaps differ in shape, the target's bitmap is replaced by the source's.
      - In both cases above, only annotations whose id occurs as a non-zero R value in the
        resulting bitmap are kept.
      - If only the source has a bitmap, the target receives the source's bitmap; if the
        source has none, the target's bitmap is left untouched.
    Afterwards the updated attributes are uploaded to ``labelset``.

    Args:
        client: Segments client (``upload_asset`` and ``update_label`` are used).
        attributes: Label attributes dict containing "frames"; mutated in place.
        sample_uuid: UUID of the sample to update.
        source_index: 0-indexed source frame.
        target_indexes: 0-indexed target frames.
        labelset: Labelset the label is uploaded to.

    Returns:
        A summary message, or an error message; on any error the label is not updated.
    """
    frames = attributes.get("frames", [])
    if not 0 <= source_index < len(frames):
        return f"Source frame index {source_index+1} is out of range."
    # Validate every target first so nothing is modified or uploaded for a rejected request.
    for tgt in target_indexes:
        if not 0 <= tgt < len(frames):
            return f"Target frame index {tgt+1} is out of range."
    source_frame = frames[source_index]

    for tgt in target_indexes:
        target_frame = frames[tgt]
        _merge_source_annotations(source_frame, target_frame)
        try:
            kept_ids = _merge_source_bitmap(client, source_frame, target_frame)
        except Exception as e:
            return (f"Error merging segmentation bitmap into frame {tgt+1}: {e}\n"
                    "The label on Segments.ai has not been updated.")
        if kept_ids is not None:
            target_frame["annotations"] = _filter_annotations(target_frame["annotations"], kept_ids)

    try:
        client.update_label(sample_uuid, labelset=labelset, attributes=attributes)
    except Exception as e:
        return f"Error updating label on Segments.ai during annotation copy: {e}"

    target_frames_str = ", ".join(str(tgt+1) for tgt in target_indexes)
    return (f"Annotations merged from frame {source_index+1} into frames {target_frames_str}.\n"
            f"Bitmap updated with merged r-values and annotations filtered accordingly.")
|
|
|
|
|
|
|
st.title("Copy/Merge Annotations to Target Frames")

# --- User inputs --------------------------------------------------------------
api_key = st.text_input("API Key", type="password")
sample_uuid = st.text_input("Sample UUID", value="")
source_frame_num = st.number_input("Source Frame Number (1-indexed)", min_value=1, step=1)
target_frames_str = st.text_input(
    "Target Frame Numbers (comma-separated or range, e.g., '2,4-6')",
    value="2",
)

# Session-state slots survive Streamlit reruns, so results stay visible after the click.
for _state_key in ("result", "original_label", "new_label"):
    if _state_key not in st.session_state:
        st.session_state[_state_key] = ""
|
|
|
def _run_update_and_copy() -> None:
    """
    Handle a click on "Update and Copy Annotations".

    Each step aborts with an ``st.error`` message on failure:
      1. Fetch the label and validate the source/target frame numbers, before anything is uploaded.
      2. Refuse to continue when a selected frame has a track_id above 255; suggest the
         lowest ids not used in any frame.
      3. Rewrite ids to track_ids (and bitmaps) for the selected frames via ``update_datalabel``.
      4. Re-fetch the label and copy/merge the source frame into the target frames.
      5. Store the summary and the before/after label JSON in ``st.session_state``.
    """
    if not api_key:
        st.error("Please enter your API Key.")
        return

    client = SegmentsClient(api_key)
    try:
        orig_label_obj = client.get_label(sample_uuid)
    except Exception as e:
        st.error("Error retrieving original label: " + str(e))
        return
    attributes = orig_label_obj.attributes.model_dump()
    original_label_json = json.dumps(attributes, indent=4)
    frames = attributes.get("frames", [])

    try:
        target_frames_nums = parse_frame_numbers(target_frames_str)
    except Exception as e:
        # Stop here: continuing with no targets would still rewrite and upload the source frame.
        st.error(f"Error parsing target frame numbers: {e}")
        return
    if not target_frames_nums:
        st.error("Please enter at least one target frame number.")
        return

    source_index = int(source_frame_num) - 1
    target_indexes = [n - 1 for n in target_frames_nums]
    frames_to_update = {source_index, *target_indexes}

    # Reject bad frame numbers before update_datalabel uploads anything.
    out_of_range = sorted(i + 1 for i in frames_to_update if not 0 <= i < len(frames))
    if out_of_range:
        st.error(f"Frame number(s) {', '.join(map(str, out_of_range))} out of range: "
                 f"this sample has {len(frames)} frame(s).")
        return

    # track_id -> first 1-indexed selected frame in which it exceeds 255.
    track_id_warnings = {}
    for i in sorted(frames_to_update):
        for ann in frames[i].get("annotations") or []:
            try:
                t_id = int(ann.get("track_id"))
            except (TypeError, ValueError):
                continue
            if t_id > 255 and t_id not in track_id_warnings:
                track_id_warnings[t_id] = i + 1

    if track_id_warnings:
        # Every id/track_id used anywhere in the sample, to suggest free replacement values.
        used_ids = set()
        for frame in frames:
            for ann in frame.get("annotations") or []:
                for key in ("track_id", "id"):
                    try:
                        used_ids.add(int(ann.get(key)))
                    except (TypeError, ValueError):
                        pass
        lowest_available_list = []
        candidate = 1
        while len(lowest_available_list) < len(track_id_warnings):
            if candidate not in used_ids:
                lowest_available_list.append(candidate)
            candidate += 1

        warning_message = "Warning: The following track_id values exceed 255:\n"
        for t_id, frame_no in sorted(track_id_warnings.items()):
            warning_message += f" - Track_id {t_id} appears first in frame {frame_no}.\n"
        warning_message += "\nPlease change these values on Segments.ai before proceeding.\n"
        warning_message += f"The lowest available id values (across all frames) are: {', '.join(map(str, lowest_available_list))}."
        st.error(warning_message)
        return

    update_summary = update_datalabel(sample_uuid, api_key, frames_to_update)
    if update_summary.startswith("Error"):
        st.error(update_summary)
        return

    try:
        label = client.get_label(sample_uuid)
    except Exception as e:
        st.error("Error retrieving updated label: " + str(e))
        return
    copy_summary = copy_annotations_to_frames(client, label.attributes.model_dump(), sample_uuid,
                                              source_index, target_indexes)
    if copy_summary.startswith("Error"):
        st.error(copy_summary)

    st.session_state["result"] = update_summary + "\n\n" + copy_summary
    # Clear the download slots so a failed re-fetch below cannot pair an older
    # run's label JSON with this run's result.
    st.session_state["original_label"] = ""
    st.session_state["new_label"] = ""
    try:
        label_after = client.get_label(sample_uuid)
    except Exception as e:
        st.error("Error retrieving final updated label: " + str(e))
        return
    st.session_state["original_label"] = original_label_json
    st.session_state["new_label"] = json.dumps(label_after.attributes.model_dump(), indent=4)


if st.button("Update and Copy Annotations"):
    _run_update_and_copy()
|
|
|
# Results live in st.session_state, so they stay visible across Streamlit reruns.
if st.session_state["result"]:
    st.text_area("Output", value=st.session_state["result"], height=150)
    # The label JSON is stored only after the final label was fetched successfully;
    # skip the buttons rather than offering empty files for download.
    if st.session_state["original_label"]:
        st.download_button("Download Original Label",
                           data=st.session_state["original_label"],
                           file_name=f"{sample_uuid}_original.json",
                           mime="application/json")
    if st.session_state["new_label"]:
        st.download_button("Download Updated Label",
                           data=st.session_state["new_label"],
                           file_name=f"{sample_uuid}_updated.json",
                           mime="application/json")
|
|