Tomatillo's picture
Update src/streamlit_app.py
b4f72c6 verified
import altair as alt
import numpy as np
import pandas as pd
import streamlit as st
import os
import io
import json
import numpy as np
from PIL import Image
import requests
from segments import SegmentsClient
"""
# Copy images from one frame to other frames in the same sample
This HF-application first updates selected frames for a given sample UUID by
replacing each annotation’s id with its track_id and updating the segmentation bitmap
accordingly (to avoid potential conflicts). Only the frames specified by the source or target
frame numbers are processed.
Then it copies annotations from one source frame to one or more target frames. When copying:
- The annotations from the source frame are merged with the target's annotations,
adding only those from the source that are not already present (based on id).
- The segmentation bitmap is merged: for each pixel, if the target's r-value is 0, the corresponding
r-value from the source is used.
- After merging, the bitmap is scanned for unique r-values and any annotation in the target
frame whose id is not present in these unique values is deleted.
Finally, the updated datalabel is uploaded.
**Important**
If a loop is detected in any frame’s id mappings, or if any track_id in the
selected frames is greater than 255, no changes are made and nothing is uploaded to Segments.ai.
For each track_id higher than 255, the warning will include:
- Which track_id is too high and the first frame (1-indexed) where it appears.
- The lowest available id values (searched across all frames) – one for each offending track_id.
The user must resolve these issues before updating.
"""
# ---------------- Utility Functions ----------------
def download_image(url: str, timeout: float = 30.0) -> Image.Image:
    """
    Download an image from the given URL and return a PIL Image in RGB mode.

    Args:
        url: Location of the image to fetch.
        timeout: Seconds to wait for the server before giving up. Without a
            timeout, ``requests.get`` may block indefinitely on a stalled
            connection, which would hang the whole app run.

    Returns:
        The decoded image, converted to RGB.

    Raises:
        requests.RequestException: On connection problems, a timeout, or a
            non-success HTTP status (via ``raise_for_status``).
    """
    resp = requests.get(url, timeout=timeout)
    resp.raise_for_status()
    return Image.open(io.BytesIO(resp.content)).convert("RGB")
def topological_sort(mapping: dict) -> list:
    """
    Order the keys of an id remapping (original id -> new id) for safe application.

    Whenever the new id of an entry is itself a key of the mapping, that key is
    placed earlier in the returned list than the entry pointing at it.
    Returns the ordered list of keys, or None when the mapping contains a cycle.
    """
    in_progress, done = "visiting", "visited"
    keys = set(mapping.keys())

    # dependents[k] collects every key whose new id is k; k has to come first.
    dependents = {}
    for source_id, target_id in mapping.items():
        if target_id in keys:
            dependents.setdefault(target_id, set()).add(source_id)

    state = {}
    postorder = []
    has_cycle = False

    def visit(key):
        nonlocal has_cycle
        if has_cycle:
            return
        status = state.get(key)
        if status is not None:
            # Reaching a key that is still being expanded closes a loop.
            if status == in_progress:
                has_cycle = True
            return
        state[key] = in_progress
        for dependent in dependents.get(key, set()):
            visit(dependent)
        state[key] = done
        postorder.append(key)

    for key in keys:
        if key not in state:
            visit(key)

    if has_cycle:
        return None
    # Reverse post-order puts each key ahead of the keys that map onto it.
    postorder.reverse()
    return postorder
def detect_cycle(mapping: dict):
    """
    Look for a loop in the mapping (original id -> new id).

    Starting from every key in turn, the chain of new ids is followed for as
    long as the current id is itself a key. The first starting key whose walk
    revisits an id is reported as the tuple (original, new); None is returned
    when no walk ever repeats an id.
    """
    def walk_repeats(start) -> bool:
        seen = set()
        cursor = start
        while cursor in mapping and cursor not in seen:
            seen.add(cursor)
            cursor = mapping[cursor]
        # The walk stopped either by leaving the mapping (no loop) or by
        # landing on an id it had already seen (loop).
        return cursor in mapping

    return next(
        ((start, mapping[start]) for start in mapping if walk_repeats(start)),
        None,
    )
def parse_frame_numbers(frame_str: str):
    """
    Parse a string representing frame numbers.

    Accepts comma-separated values and ranges (e.g., "2,4-6,8"). Empty items
    are skipped. A descending range such as "6-4" is treated like "4-6"
    instead of silently producing nothing.

    Returns:
        A list of integers (1-indexed) in the order they were written;
        duplicates are kept.

    Raises:
        ValueError: If an item is not an integer, or a range does not consist
            of exactly two non-empty bounds (e.g. "4-", "-3", "1-2-3").
    """
    result = []
    for part in frame_str.split(","):
        part = part.strip()
        if not part:
            continue
        if "-" in part:
            tokens = [token.strip() for token in part.split("-")]
            # Reject malformed ranges explicitly rather than ignoring extra
            # tokens or failing with an opaque int('') message.
            if len(tokens) != 2 or not all(tokens):
                raise ValueError(f"Invalid frame range: '{part}'")
            start, end = int(tokens[0]), int(tokens[1])
            if start > end:
                start, end = end, start
            result.extend(range(start, end + 1))
        else:
            result.append(int(part))
    return result
# ---------------- Update Ids and Bitmaps ----------------
def update_frame_annotations_and_bitmap(client, frame: dict) -> tuple:
    """
    For a given frame, update annotations (set id = track_id) and update the segmentation bitmap.

    Only non-trivial mappings (where original id != track_id) are applied to
    the bitmap's R channel. The remapped bitmap is uploaded through
    ``client.upload_asset`` and the frame's bitmap URL is pointed at the new asset.

    IMPORTANT: The frame is left completely unmodified when
    - a loop is detected in the mapping (reported through the return value), or
    - the bitmap cannot be downloaded, remapped or uploaded (reported by raising).
    Previously such bitmap failures were swallowed and the annotation ids were
    rewritten anyway, leaving ids that no longer matched the bitmap.

    Args:
        client: Object exposing ``upload_asset(file, filename=...)`` whose
            result has a ``url`` attribute.
        frame: Frame dict with optional "annotations" and "segmentation_bitmap" keys.

    Returns:
        A tuple (collision_count, cycle_detected, conflict_pair), where
        collision_count is the number of new ids that are also ids being
        remapped in this frame, and conflict_pair is (original, new) when a
        cycle was detected, otherwise None.

    Raises:
        RuntimeError: If the segmentation bitmap update fails.
    """
    annotations = frame.get("annotations", [])

    # Collect only the ids that actually change; unparsable entries are skipped.
    mapping = {}
    for ann in annotations:
        try:
            orig_id = int(ann.get("id"))
            new_id = int(ann.get("track_id"))
        except (ValueError, TypeError):
            continue
        if orig_id != new_id:
            mapping[orig_id] = new_id

    # Check for a cycle first. If a cycle is detected, do not modify the frame.
    conflict = detect_cycle(mapping)
    if conflict is not None:
        return 0, True, conflict

    collision_count = sum(1 for new_id in mapping.values() if new_id in mapping)

    seg_url = (frame.get("segmentation_bitmap") or {}).get("url")
    if mapping and seg_url:
        order = topological_sort(mapping)
        if order is None:
            # Should be unreachable after detect_cycle; never remap without an order.
            raise RuntimeError("Unexpected cycle in annotation id mapping.")
        try:
            arr = np.array(download_image(seg_url))
            r_channel = arr[:, :, 0]  # view: writes go straight into arr
            # Apply in dependency order so an id is vacated before it is reused.
            for orig in order:
                r_channel[r_channel == orig] = mapping[orig]
            buf = io.BytesIO()
            Image.fromarray(arr.astype(np.uint8)).save(buf, format="PNG")
            buf.seek(0)
            name, _ = os.path.splitext(os.path.basename(seg_url))
            asset = client.upload_asset(buf, filename=f"{name}_updated.png")
        except Exception as exc:
            raise RuntimeError(
                f"Failed to update segmentation bitmap {seg_url}: {exc}"
            ) from exc
        frame.setdefault("segmentation_bitmap", {})["url"] = asset.url

    # Update annotations: set id = track_id. Annotations without a track_id
    # keep their id instead of having it overwritten with None.
    for ann in annotations:
        track_id = ann.get("track_id")
        if track_id is not None:
            ann["id"] = track_id
    return collision_count, False, None
def update_datalabel(sample_uuid: str, api_key: str, frames_to_update: set, labelset: str = "ground-truth") -> str:
    """
    Retrieve the label of a sample, normalise the selected frames and upload the result.

    Each selected frame is passed to update_frame_annotations_and_bitmap, which
    sets annotation ids to their track_ids and rewrites the segmentation bitmap.

    IMPORTANT: If a loop is detected in any processed frame, or a frame fails
    to update, the label is not uploaded. The user must resolve the problem
    before updating Segments.ai.

    Args:
        sample_uuid: UUID of the sample whose label is updated.
        api_key: Segments.ai API key used to build the client.
        frames_to_update: Set of 0-indexed frame indices to process; indices
            that do not exist in the label are ignored.
        labelset: Labelset to read from and write to. The same labelset is now
            used for both, so a non-default value no longer reads one labelset
            and overwrites another.

    Returns:
        A summary string. Every failure is reported as a string starting with "Error".
    """
    client = SegmentsClient(api_key)
    try:
        label = client.get_label(sample_uuid, labelset=labelset)
    except Exception as e:
        return f"Error retrieving label for sample {sample_uuid}: {e}"

    attributes = label.attributes.model_dump()
    frames = attributes.get("frames", [])

    total_collisions = 0
    for index, frame in enumerate(frames):
        if index not in frames_to_update:
            continue
        try:
            collisions, cycle, conflict_pair = update_frame_annotations_and_bitmap(client, frame)
        except Exception as e:
            # Keep the string-based error contract instead of letting the
            # exception escape into the UI script.
            return (f"Error updating frame {index + 1}: {e}\n"
                    "The label on Segments.ai has not been modified.")
        if cycle:
            return (f"Error: Cycle detected in annotation id mappings: original id {conflict_pair[0]} -> new id {conflict_pair[1]}.\n"
                    "Please resolve the loop before updating. No changes have been uploaded to Segments.ai.")
        total_collisions += collisions

    try:
        client.update_label(sample_uuid, labelset=labelset, attributes=attributes)
    except Exception as e:
        return f"Error updating label on Segments.ai: {e}"

    if total_collisions > 0:
        return ("Updated annotation ids with track_ids and updated bitmap for the specified frames.\n"
                "Collisions in ids were detected and resolved using logical processing.")
    return "Updated annotation ids with track_ids and updated bitmap for the specified frames."
# ---------------- Copy Annotations ----------------
def copy_annotations_to_frames(client, attributes: dict, sample_uuid: str, source_index: int,
                               target_indexes: list, labelset: str = "ground-truth") -> str:
    """
    Copies annotations from the source frame to each target frame and merges the segmentation bitmap.

    For each target frame:
    - Source annotations whose id is not already present in the target are appended.
    - If both frames have a bitmap of the same shape, every pixel whose target
      r-value is 0 is taken from the source; the merged bitmap is uploaded and
      only annotations whose id appears among its non-zero r-values are kept.
    - If the shapes differ, or the target has no bitmap, the source's bitmap is used.
    - If the source has no bitmap, the target's bitmap is left untouched
      (it is no longer replaced by an empty entry).

    All indexes are validated before anything is modified. If downloading,
    merging or uploading a bitmap fails, the function stops and returns an
    error without uploading the label, rather than silently replacing the
    target's bitmap with the source's.

    Args:
        client: Segments.ai client providing ``upload_asset`` and ``update_label``.
        attributes: Label attributes dict containing a "frames" list; modified in place.
        sample_uuid: UUID of the sample whose label is uploaded.
        source_index: 0-indexed source frame.
        target_indexes: 0-indexed target frames; duplicates are processed once.
        labelset: Labelset the merged label is uploaded to.

    Returns:
        A summary string, or a message describing why nothing was uploaded.
    """
    frames = attributes.get("frames", [])
    if source_index < 0 or source_index >= len(frames):
        return f"Source frame index {source_index+1} is out of range."
    # Validate every target up front so a bad index cannot leave earlier
    # targets half-processed.
    for tgt in target_indexes:
        if tgt < 0 or tgt >= len(frames):
            return f"Target frame index {tgt+1} is out of range."

    source_frame = frames[source_index]
    source_annotations = source_frame.get("annotations", [])
    source_bitmap = source_frame.get("segmentation_bitmap", {})
    source_seg_url = source_bitmap.get("url")
    arr_source = None  # downloaded on first use and reused for every target

    for tgt in dict.fromkeys(target_indexes):
        target_frame = frames[tgt]

        # Merge annotations: keep existing target annotations and add source annotations not already present.
        target_annotations = target_frame.get("annotations", [])
        existing_ids = {ann.get("id") for ann in target_annotations}
        missing = [dict(ann) for ann in source_annotations if ann.get("id") not in existing_ids]
        target_annotations.extend(missing)
        target_frame["annotations"] = target_annotations

        if not source_seg_url:
            # Nothing to merge from; keep whatever bitmap the target has.
            continue
        target_seg_url = target_frame.get("segmentation_bitmap", {}).get("url")
        if not target_seg_url:
            target_frame["segmentation_bitmap"] = dict(source_bitmap)
            continue

        try:
            if arr_source is None:
                arr_source = np.array(download_image(source_seg_url))
            arr_target = np.array(download_image(target_seg_url))
            if arr_source.shape != arr_target.shape:
                # If shapes differ, simply use source bitmap.
                target_frame["segmentation_bitmap"] = dict(source_bitmap)
                continue
            # For pixels where the target's r-value is 0, use the source's pixel.
            mask = arr_target[:, :, 0] == 0
            merged_arr = arr_target.copy()
            merged_arr[mask] = arr_source[mask]
            buf = io.BytesIO()
            Image.fromarray(merged_arr.astype(np.uint8)).save(buf, format="PNG")
            buf.seek(0)
            name, _ = os.path.splitext(os.path.basename(target_seg_url))
            asset = client.upload_asset(buf, filename=f"{name}_merged.png")
        except Exception as e:
            return (f"Error merging segmentation bitmap for frame {tgt+1}: {e}\n"
                    "The label on Segments.ai has not been modified by the copy step.")

        target_frame.setdefault("segmentation_bitmap", {})["url"] = asset.url

        # Keep only annotations whose id (as int) is present in the merged bitmap.
        unique_vals = {int(v) for v in np.unique(merged_arr[:, :, 0])} - {0}
        filtered_annotations = []
        for ann in target_frame.get("annotations", []):
            try:
                ann_id = int(ann.get("id"))
            except (ValueError, TypeError):
                continue
            if ann_id in unique_vals:
                filtered_annotations.append(ann)
        target_frame["annotations"] = filtered_annotations

    try:
        client.update_label(sample_uuid, labelset=labelset, attributes=attributes)
    except Exception as e:
        return f"Error updating label on Segments.ai during annotation copy: {e}"

    target_frames_str = ", ".join(str(tgt+1) for tgt in target_indexes)
    return (f"Annotations merged from frame {source_index+1} into frames {target_frames_str}.\n"
            f"Bitmap updated with merged r-values and annotations filtered accordingly.")
# ---------------------- Main UI ----------------------
st.title("Copy/Merge Annotations to Target Frames")

# Prompt user for API key as the first input (masked input).
api_key = st.text_input("API Key", type="password")
sample_uuid = st.text_input("Sample UUID", value="")
source_frame_num = st.number_input("Source Frame Number (1-indexed)", min_value=1, step=1)
target_frames_str = st.text_input("Target Frame Numbers (comma-separated or range, e.g., '2,4-6')", value="2")

# Results are kept in session state so they survive the rerun triggered by
# any later widget interaction.
for _state_key in ("result", "original_label", "new_label"):
    if _state_key not in st.session_state:
        st.session_state[_state_key] = ""


def _to_int(value):
    """Return value as an int, or None when it cannot be converted."""
    try:
        return int(value)
    except (ValueError, TypeError):
        return None


def _run_update_and_copy(api_key, sample_uuid, source_frame_num, target_frames_str) -> None:
    """
    Validate the inputs, normalise ids in the selected frames, then copy annotations.

    Every problem is reported with st.error and ends the run before anything
    is uploaded. On success the summary and both label dumps are stored in
    st.session_state.
    """
    if not api_key:
        st.error("Please enter your API Key.")
        return
    if not sample_uuid.strip():
        st.error("Please enter a Sample UUID.")
        return

    # Stop on unparsable or empty target input; continuing with an empty
    # target list would still rewrite and upload the source frame.
    try:
        target_frames_nums = parse_frame_numbers(target_frames_str)
    except Exception as e:
        st.error(f"Error parsing target frame numbers: {e}")
        return
    if not target_frames_nums:
        st.error("Please enter at least one target frame number.")
        return

    client = SegmentsClient(api_key)
    try:
        orig_label_obj = client.get_label(sample_uuid)
    except Exception as e:
        st.error("Error retrieving original label: " + str(e))
        return

    attributes = orig_label_obj.attributes.model_dump()
    original_label_json = json.dumps(attributes, indent=4)
    frames = attributes.get("frames", [])

    source_index = int(source_frame_num) - 1
    target_indexes = [n - 1 for n in target_frames_nums]
    # Only update the frames that are in the source or target list.
    frames_to_update = {source_index, *target_indexes}

    # Range-check before the id update, which uploads changes on its own.
    out_of_range = sorted(i + 1 for i in frames_to_update if i < 0 or i >= len(frames))
    if out_of_range:
        st.error(f"Frame number(s) {', '.join(map(str, out_of_range))} are out of range; "
                 f"the sample has {len(frames)} frame(s).")
        return

    # --- Warning Check: Verify no track_id > 255 in selected frames ---
    # Maps offending track_id -> first (1-indexed) frame number where it appears.
    track_id_warnings = {}
    for i in sorted(frames_to_update):
        for ann in frames[i].get("annotations", []):
            t_id = _to_int(ann.get("track_id"))
            if t_id is not None and t_id > 255:
                track_id_warnings.setdefault(t_id, i + 1)

    if track_id_warnings:
        # Compute available lowest values across ALL frames.
        all_existing_values = set()
        for frame in frames:
            for ann in frame.get("annotations", []):
                for value in (_to_int(ann.get("track_id")), _to_int(ann.get("id"))):
                    if value is not None:
                        all_existing_values.add(value)
        lowest_available_list = []
        candidate = 1
        while len(lowest_available_list) < len(track_id_warnings):
            if candidate not in all_existing_values:
                lowest_available_list.append(candidate)
            candidate += 1

        warning_message = "Warning: The following track_id values exceed 255:\n"
        for t_id, frame_no in sorted(track_id_warnings.items()):
            warning_message += f" - Track_id {t_id} appears first in frame {frame_no}.\n"
        warning_message += "\nPlease change these values on Segments.ai before proceeding.\n"
        warning_message += f"The lowest available id values (across all frames) are: {', '.join(map(str, lowest_available_list))}."
        st.error(warning_message)
        return

    update_summary = update_datalabel(sample_uuid, api_key, frames_to_update)
    if update_summary.startswith("Error"):
        st.error(update_summary)
        return

    try:
        # Retrieve the label after the update.
        label = client.get_label(sample_uuid)
    except Exception as e:
        st.error("Error retrieving updated label: " + str(e))
        return

    copy_summary = copy_annotations_to_frames(
        client, label.attributes.model_dump(), sample_uuid, source_index, target_indexes
    )
    st.session_state["result"] = update_summary + "\n\n" + copy_summary

    # Retrieve the final updated label.
    try:
        label_after = client.get_label(sample_uuid)
    except Exception as e:
        st.error("Error retrieving final updated label: " + str(e))
        return
    st.session_state["original_label"] = original_label_json
    st.session_state["new_label"] = json.dumps(label_after.attributes.model_dump(), indent=4)


if st.button("Update and Copy Annotations"):
    _run_update_and_copy(api_key, sample_uuid, source_frame_num, target_frames_str)

if st.session_state["result"]:
    st.text_area("Output", value=st.session_state["result"], height=150)
    st.download_button("Download Original Label",
                       data=st.session_state["original_label"],
                       file_name=f"{sample_uuid}_original.json",
                       mime="application/json")
    st.download_button("Download Updated Label",
                       data=st.session_state["new_label"],
                       file_name=f"{sample_uuid}_updated.json",
                       mime="application/json")