"""Streamlit app: per-sample annotation counts by class for a Segments.ai dataset.

The user supplies an API key, a dataset identifier, and a set of class IDs;
the app fetches every sample's label, counts annotations (total and matching
the requested classes) per sample — optionally per sensor for multi-sensor
datasets — and offers the result as a downloadable CSV.
"""

import csv
import io
from datetime import datetime

import streamlit as st

from segments import SegmentsClient
from get_labels_from_samples import (
    get_samples as get_samples_objects,
    export_frames_and_annotations,
    export_sensor_frames_and_annotations,
    export_all_sensor_frames_and_annotations,
)


def init_session_state() -> None:
    """Ensure the session-state keys used by this app exist."""
    if 'csv_content' not in st.session_state:
        st.session_state.csv_content = None
    if 'error' not in st.session_state:
        st.session_state.error = None


def init_client(api_key: str) -> SegmentsClient:
    """Initialize the Segments.ai API client using the provided API key."""
    return SegmentsClient(api_key)


def parse_classes(input_str: str) -> list:
    """
    Parse user input for classes (ranges and comma-separated lists).

    Accepts single IDs ("1,2,5") and inclusive ranges ("1-3"); malformed
    tokens are silently skipped.

    Returns a unique, sorted list of ints.
    """
    classes = []
    for token in input_str.split(','):
        token = token.strip()
        if '-' in token:
            try:
                start, end = map(int, token.split('-'))
            except ValueError:
                continue  # malformed range, e.g. "a-b" or "1-2-3"
            classes.extend(range(start, end + 1))
        else:
            try:
                classes.append(int(token))
            except ValueError:
                continue  # malformed single value
    return sorted(set(classes))


def _frame_metrics(frames_list: list, target_classes: list) -> tuple:
    """Count metrics for one list of frame dicts.

    Returns (num_frames, total_annotations, matching_annotations), where a
    "matching" annotation has a ``category_id`` contained in target_classes.
    """
    num_frames = len(frames_list)
    total_annotations = sum(len(f['annotations']) for f in frames_list)
    matching_annotations = sum(
        1
        for f in frames_list
        for ann in f['annotations']
        if getattr(ann, 'category_id', None) in target_classes
    )
    return num_frames, total_annotations, matching_annotations


def generate_csv(metrics: list, dataset_identifier: str) -> str:
    """
    Generate CSV content from a list of per-sample metric dicts.

    Columns: name, sample_url, sensor, num_frames, total_annotations,
    matching_annotations, labeled_by, reviewed_by
    """
    output = io.StringIO()
    writer = csv.writer(output)
    writer.writerow([
        'name', 'sample_url', 'sensor', 'num_frames', 'total_annotations',
        'matching_annotations', 'labeled_by', 'reviewed_by'
    ])
    for m in metrics:
        # Deep-link to the sample's labelset in the Segments.ai web app.
        url = f"https://app.segments.ai/{dataset_identifier}/samples/{m['uuid']}/{m['labelset']}"
        writer.writerow([
            m['name'], url, m['sensor'], m['num_frames'],
            m['total_annotations'], m['matching_annotations'],
            m['labeled_by'], m['reviewed_by']
        ])
    content = output.getvalue()
    output.close()
    return content


# ----------------------
# Streamlit UI
# ----------------------
init_session_state()

st.title("Per-Sample Annotation Counts by Class")

api_key = st.text_input("API Key", type="password", key="api_key_input")
dataset_identifier = st.text_input(
    "Dataset Identifier (e.g., username/dataset)", key="dataset_identifier_input"
)
classes_input = st.text_input("Classes (e.g., 1,2,5 or 1-3)", key="classes_input")
run_button = st.button("Generate CSV", key="run_button")

# Inspect the dataset up front so a sensor picker can be shown for
# multi-sensor datasets before the user presses "Generate CSV".
sensor_names = []
is_multisensor = False
sensor_select = None
samples_objects = []
if api_key and dataset_identifier:
    try:
        client = init_client(api_key)
        samples_objects = get_samples_objects(client, dataset_identifier)
        if samples_objects:
            # Probe the first sample's label: the presence of an
            # attributes.sensors field marks a multi-sensor dataset.
            label = client.get_label(samples_objects[0].uuid)
            sensors = getattr(getattr(label, 'attributes', None), 'sensors', None)
            if sensors is not None:
                is_multisensor = True
                sensor_names = [getattr(sensor, 'name', 'Unknown') for sensor in sensors]
    except Exception as e:
        st.warning(f"Could not inspect dataset sensors: {e}")

if is_multisensor:
    sensor_select = st.selectbox(
        "Choose sensor (optional)", options=['All sensors'] + sensor_names
    )

if run_button:
    st.session_state.csv_content = None
    st.session_state.error = None
    if not api_key:
        st.session_state.error = "API Key is required."
    elif not dataset_identifier:
        st.session_state.error = "Dataset identifier is required."
    elif not classes_input:
        st.session_state.error = "Please specify at least one class."
    elif is_multisensor and not sensor_select:
        st.session_state.error = "Please select a sensor or 'All sensors' before generating CSV."
    else:
        # st.status only exists on newer Streamlit versions; fall back to a
        # plain info message on older ones.
        status_ctx = None
        try:
            status_ctx = st.status("Checking dataset type...", expanded=True)
        except AttributeError:
            st.info("Checking dataset type...")
        try:
            target_classes = parse_classes(classes_input)
            client = init_client(api_key)
            metrics = []
            skipped = 0  # samples whose label could not be fetched/parsed
            if status_ctx is not None:
                status_ctx.update(
                    label="Dataset type checked. Processing samples...",
                    state="running",
                )
            for sample in samples_objects:
                try:
                    label = client.get_label(sample.uuid)
                    labelset = getattr(label, 'labelset', '') or ''
                    labeled_by = getattr(label, 'created_by', '') or ''
                    reviewed_by = getattr(label, 'reviewed_by', '') or ''
                    # Fields shared by every row this sample produces.
                    common = {
                        'name': getattr(sample, 'name', sample.uuid),
                        'uuid': sample.uuid,
                        'labelset': labelset,
                        'labeled_by': labeled_by,
                        'reviewed_by': reviewed_by,
                    }
                    if is_multisensor and sensor_select and sensor_select != 'All sensors':
                        # One specific sensor of a multi-sensor dataset.
                        frames_list = export_sensor_frames_and_annotations(label, sensor_select)
                        num_frames, total, matching = _frame_metrics(frames_list, target_classes)
                        metrics.append({
                            **common,
                            'sensor': sensor_select,
                            'num_frames': num_frames,
                            'total_annotations': total,
                            'matching_annotations': matching,
                        })
                    elif is_multisensor:
                        # "All sensors": emit one CSV row per sensor.
                        all_sensor_frames = export_all_sensor_frames_and_annotations(label)
                        for sensor_name, frames_list in all_sensor_frames.items():
                            num_frames, total, matching = _frame_metrics(frames_list, target_classes)
                            metrics.append({
                                **common,
                                'sensor': sensor_name,
                                'num_frames': num_frames,
                                'total_annotations': total,
                                'matching_annotations': matching,
                            })
                    else:
                        # Single-sensor dataset: sensor column stays blank.
                        frames_list = export_frames_and_annotations(label)
                        num_frames, total, matching = _frame_metrics(frames_list, target_classes)
                        metrics.append({
                            **common,
                            'sensor': '',
                            'num_frames': num_frames,
                            'total_annotations': total,
                            'matching_annotations': matching,
                        })
                except Exception:
                    # Best-effort: keep going past broken samples, but count
                    # them so the user knows the CSV may be incomplete.
                    skipped += 1
                    continue
            if skipped:
                st.warning(f"Skipped {skipped} sample(s) that could not be processed.")
            if not metrics:
                st.session_state.error = "No metrics could be generated for the dataset."
            else:
                st.session_state.csv_content = generate_csv(metrics, dataset_identifier)
                if status_ctx is not None:
                    status_ctx.update(label="CSV generated!", state="complete")
        except Exception as e:
            st.session_state.error = f"An error occurred: {e}"
            if status_ctx is not None:
                status_ctx.update(label="Error occurred.", state="error")

if st.session_state.error:
    st.error(st.session_state.error)

if st.session_state.csv_content:
    today_str = datetime.now().strftime("%Y%m%d")
    # The identifier contains '/' (username/dataset), which is not a valid
    # filename character — replace it so the suggested download name works.
    safe_identifier = dataset_identifier.replace('/', '_')
    filename = f"{today_str}_{safe_identifier}_count-by-class.csv"
    st.download_button(
        "Download CSV",
        data=st.session_state.csv_content,
        file_name=filename,
        mime="text/csv",
    )