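"""Streamlit app: per-sample annotation counts by class for a Segments.ai dataset.

Generates a CSV with one row per sample (or per sensor, for multi-sensor
datasets) listing frame and annotation counts, plus who labeled and
reviewed each sample.
"""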
import streamlit as st
import io
import csv
from datetime import datetime
from segments import SegmentsClient
from get_labels_from_samples import (
    get_samples as get_samples_objects,
    export_frames_and_annotations,
    export_sensor_frames_and_annotations,
    export_all_sensor_frames_and_annotations
)


def init_session_state():
    if 'csv_content' not in st.session_state:
        st.session_state.csv_content = None
    if 'error' not in st.session_state:
        st.session_state.error = None


def init_client(api_key: str) -> SegmentsClient:
    """Initialize the Segments.ai API client using the provided API key."""
    return SegmentsClient(api_key)


def parse_classes(input_str: str) -> list:
    """
    Parse user input for class IDs (comma-separated values and ranges).
    Returns a sorted list of unique ints; invalid tokens are skipped.
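
    Examples (doctest-style; behavior of the parser below):
        >>> parse_classes("1,2,5")
        [1, 2, 5]
        >>> parse_classes("1-3")
        [1, 2, 3]
        >>> parse_classes("3, 1-2, 3")
        [1, 2, 3]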
""" | |
classes = [] | |
tokens = input_str.split(',') | |
for token in tokens: | |
token = token.strip() | |
if '-' in token: | |
try: | |
start, end = map(int, token.split('-')) | |
classes.extend(range(start, end + 1)) | |
except ValueError: | |
continue | |
else: | |
try: | |
classes.append(int(token)) | |
except ValueError: | |
continue | |
return sorted(set(classes)) | |


def generate_csv(metrics: list, dataset_identifier: str) -> str:
    """
    Generate CSV content from a list of per-sample metric dicts.
    Columns: name, sample_url, sensor, num_frames, total_annotations,
    matching_annotations, labeled_by, reviewed_by.
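
    A minimal sketch of the expected output (the labelset value
    'ground-truth' here is illustrative):
        >>> row = {'name': 's1', 'uuid': 'u1', 'labelset': 'ground-truth',
        ...        'sensor': '', 'num_frames': 2, 'total_annotations': 5,
        ...        'matching_annotations': 3, 'labeled_by': 'a', 'reviewed_by': 'b'}
        >>> generate_csv([row], 'user/ds').splitlines()[1]
        's1,https://app.segments.ai/user/ds/samples/u1/ground-truth,,2,5,3,a,b'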
""" | |
output = io.StringIO() | |
writer = csv.writer(output) | |
writer.writerow([ | |
'name', 'sample_url', 'sensor', 'num_frames', | |
'total_annotations', 'matching_annotations', | |
'labeled_by', 'reviewed_by' | |
]) | |
for m in metrics: | |
url = f"https://app.segments.ai/{dataset_identifier}/samples/{m['uuid']}/{m['labelset']}" | |
writer.writerow([ | |
m['name'], url, m['sensor'], | |
m['num_frames'], m['total_annotations'], | |
m['matching_annotations'], m['labeled_by'], | |
m['reviewed_by'] | |
]) | |
content = output.getvalue() | |
output.close() | |
return content | |


# ----------------------
# Streamlit UI
# ----------------------
init_session_state()

st.title("Per-Sample Annotation Counts by Class")

api_key = st.text_input("API Key", type="password", key="api_key_input")
dataset_identifier = st.text_input("Dataset Identifier (e.g., username/dataset)", key="dataset_identifier_input")
classes_input = st.text_input("Classes (e.g., 1,2,5 or 1-3)", key="classes_input")
run_button = st.button("Generate CSV", key="run_button")

sensor_names = []
is_multisensor = False
sensor_select = None
samples_objects = []

if api_key and dataset_identifier:
    try:
        client = init_client(api_key)
        samples_objects = get_samples_objects(client, dataset_identifier)
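        # Heuristic (assumption): multi-sensor labels expose a `sensors`
        # collection under label.attributes, while single-sensor labels do
        # not. Inspect the first sample's label to pick the export path.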
        if samples_objects:
            label = client.get_label(samples_objects[0].uuid)
            sensors = getattr(getattr(label, 'attributes', None), 'sensors', None)
            if sensors is not None:
                is_multisensor = True
                sensor_names = [getattr(sensor, 'name', 'Unknown') for sensor in sensors]
    except Exception as e:
        st.warning(f"Could not inspect dataset sensors: {e}")

if is_multisensor:
    sensor_select = st.selectbox("Choose sensor (optional)", options=['All sensors'] + sensor_names)

if run_button:
    st.session_state.csv_content = None
    st.session_state.error = None
    if not api_key:
        st.session_state.error = "API Key is required."
    elif not dataset_identifier:
        st.session_state.error = "Dataset identifier is required."
    elif not classes_input:
        st.session_state.error = "Please specify at least one class."
    elif is_multisensor and not sensor_select:
        st.session_state.error = "Please select a sensor or 'All sensors' before generating CSV."
    else:
        # Show a loader/status message while checking the dataset type and generating the CSV
        status_ctx = None
        try:
            status_ctx = st.status("Checking dataset type...", expanded=True)
        except AttributeError:
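            # st.status is only available in newer Streamlit releases
            # (1.26+, assumption); fall back to a plain info message.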
            st.info("Checking dataset type...")
        try:
            target_classes = parse_classes(classes_input)
            client = init_client(api_key)
            metrics = []
            # Update the loader after the dataset type check
            if status_ctx is not None:
                status_ctx.update(label="Dataset type checked. Processing samples...", state="running")
            for sample in samples_objects:
                try:
                    label = client.get_label(sample.uuid)
                    labelset = getattr(label, 'labelset', '') or ''
                    labeled_by = getattr(label, 'created_by', '') or ''
                    reviewed_by = getattr(label, 'reviewed_by', '') or ''
                    if is_multisensor and sensor_select and sensor_select != 'All sensors':
                        frames_list = export_sensor_frames_and_annotations(label, sensor_select)
                        sensor_val = sensor_select
                        num_frames = len(frames_list)
                        total_annotations = sum(len(f['annotations']) for f in frames_list)
                        matching_annotations = sum(
                            1
                            for f in frames_list
                            for ann in f['annotations']
                            if getattr(ann, 'category_id', None) in target_classes
                        )
                    elif is_multisensor and (not sensor_select or sensor_select == 'All sensors'):
                        all_sensor_frames = export_all_sensor_frames_and_annotations(label)
                        for sensor_name, frames_list in all_sensor_frames.items():
                            num_frames = len(frames_list)
                            total_annotations = sum(len(f['annotations']) for f in frames_list)
                            matching_annotations = sum(
                                1
                                for f in frames_list
                                for ann in f['annotations']
                                if getattr(ann, 'category_id', None) in target_classes
                            )
                            # One CSV row per sensor for this sample
                            metrics.append({
                                'name': getattr(sample, 'name', sample.uuid),
                                'uuid': sample.uuid,
                                'labelset': labelset,
                                'sensor': sensor_name,
                                'num_frames': num_frames,
                                'total_annotations': total_annotations,
                                'matching_annotations': matching_annotations,
                                'labeled_by': labeled_by,
                                'reviewed_by': reviewed_by
                            })
                        continue
                    else:
                        frames_list = export_frames_and_annotations(label)
                        sensor_val = ''
                        num_frames = len(frames_list)
                        total_annotations = sum(len(f['annotations']) for f in frames_list)
                        matching_annotations = sum(
                            1
                            for f in frames_list
                            for ann in f['annotations']
                            if getattr(ann, 'category_id', None) in target_classes
                        )
                    metrics.append({
                        'name': getattr(sample, 'name', sample.uuid),
                        'uuid': sample.uuid,
                        'labelset': labelset,
                        'sensor': sensor_val,
                        'num_frames': num_frames,
                        'total_annotations': total_annotations,
                        'matching_annotations': matching_annotations,
                        'labeled_by': labeled_by,
                        'reviewed_by': reviewed_by
                    })
                except Exception:
                    # Skip samples whose labels cannot be fetched or parsed
                    continue
            if not metrics:
                st.session_state.error = "No metrics could be generated for the dataset."
            else:
                st.session_state.csv_content = generate_csv(metrics, dataset_identifier)
                if status_ctx is not None:
                    status_ctx.update(label="CSV generated!", state="complete")
        except Exception as e:
            st.session_state.error = f"An error occurred: {e}"
            if status_ctx is not None:
                status_ctx.update(label="Error occurred.", state="error")

if st.session_state.error:
    st.error(st.session_state.error)

if st.session_state.csv_content:
    today_str = datetime.now().strftime("%Y%m%d")
    # Replace the slash in "username/dataset" so the download filename is valid
    safe_identifier = dataset_identifier.replace('/', '_')
    filename = f"{today_str}_{safe_identifier}_count-by-class.csv"
    st.download_button(
        "Download CSV",
        data=st.session_state.csv_content,
        file_name=filename,
        mime="text/csv"
    )