File size: 8,706 Bytes
9e3c899
b6ff680
 
 
 
 
 
 
 
 
9e3c899
b6ff680
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
import streamlit as st
import io
import csv
from segments import SegmentsClient
from get_labels_from_samples import (
    get_samples as get_samples_objects,
    export_frames_and_annotations,
    export_sensor_frames_and_annotations,
    export_all_sensor_frames_and_annotations
)

def init_session_state():
    """Make sure the ``csv_content`` and ``error`` session keys exist.

    A key that is missing from ``st.session_state`` is created with the
    value ``None``; a key that is already present is left untouched.
    """
    for state_key in ('csv_content', 'error'):
        if state_key not in st.session_state:
            st.session_state[state_key] = None


def init_client(api_key: str) -> SegmentsClient:
    """Construct a ``SegmentsClient`` from the given API key and return it."""
    segments_client = SegmentsClient(api_key)
    return segments_client


def parse_classes(input_str: str) -> list:
    """
    Parse a comma-separated list of class IDs and inclusive ranges.

    Each comma-separated token is either a single integer ("5") or a range
    written as "start-end" ("1-3"). Ranges are inclusive on both ends, and a
    descending range ("5-3") is treated the same as its ascending form
    instead of silently producing nothing.

    Tokens that cannot be parsed (empty tokens, non-numeric text, ranges with
    more than one '-') are skipped. Because '-' is the range separator,
    negative IDs are not supported and are skipped as well.

    Args:
        input_str: Raw user input, e.g. "1,2,5" or "1-3, 7".

    Returns:
        The unique class IDs as a sorted list of ints (empty when nothing
        could be parsed).
    """
    classes = set()
    for token in input_str.split(','):
        token = token.strip()
        try:
            if '-' in token:
                # Unpacking raises ValueError for anything but exactly two parts.
                start, end = map(int, token.split('-'))
                if start > end:
                    start, end = end, start
                classes.update(range(start, end + 1))
            else:
                classes.add(int(token))
        except ValueError:
            # Best-effort parsing: ignore the malformed token, keep the rest.
            continue
    return sorted(classes)


def generate_csv(metrics: list, dataset_identifier: str) -> str:
    """
    Render per-sample metrics as CSV text.

    The first row is the header: name, sample_url, sensor, num_frames,
    total_annotations, matching_annotations, labeled_by, reviewed_by.
    Every entry of ``metrics`` is a mapping that supplies those values,
    except ``sample_url``, which is assembled from ``dataset_identifier``
    and the entry's ``uuid`` and ``labelset``.
    """
    header = (
        'name', 'sample_url', 'sensor', 'num_frames',
        'total_annotations', 'matching_annotations',
        'labeled_by', 'reviewed_by',
    )
    with io.StringIO() as buffer:
        rows = csv.writer(buffer)
        rows.writerow(header)
        for entry in metrics:
            sample_url = f"https://app.segments.ai/{dataset_identifier}/samples/{entry['uuid']}/{entry['labelset']}"
            rows.writerow([
                sample_url if column == 'sample_url' else entry[column]
                for column in header
            ])
        # Read the text out before the buffer is closed by the `with` block.
        return buffer.getvalue()

# ----------------------
# Streamlit UI
# ----------------------

init_session_state()
st.title("Per-Sample Annotation Counts by Class")

# Input widgets; their values are re-read on every script run.
api_key = st.text_input("API Key", type="password", key="api_key_input")
dataset_identifier = st.text_input("Dataset Identifier (e.g., username/dataset)", key="dataset_identifier_input")
classes_input = st.text_input("Classes (e.g., 1,2,5 or 1-3)", key="classes_input")
run_button = st.button("Generate CSV", key="run_button")

# Defaults used when the dataset has not been (or could not be) inspected.
samples_objects = []
sensor_names = []
sensor_select = None
is_multisensor = False

if api_key and dataset_identifier:
    try:
        client = init_client(api_key)
        samples_objects = get_samples_objects(client, dataset_identifier)
        if samples_objects:
            # Only the first sample's label is inspected to decide whether
            # the dataset exposes a list of sensors.
            first_label = client.get_label(samples_objects[0].uuid)
            label_attributes = getattr(first_label, 'attributes', None)
            sensors = getattr(label_attributes, 'sensors', None)
            if sensors is not None:
                is_multisensor = True
                sensor_names = [getattr(sensor, 'name', 'Unknown') for sensor in sensors]
    except Exception as e:
        st.warning(f"Could not inspect dataset sensors: {e}")

# The sensor picker is only offered when a sensor list was found above.
if is_multisensor:
    sensor_select = st.selectbox("Choose sensor (optional)", options=['All sensors', *sensor_names])

def _count_annotations(frames_list, target_classes):
    """
    Summarize one list of exported frames.

    Each frame is read as a mapping with an 'annotations' entry; an
    annotation counts as matching when its ``category_id`` attribute is in
    ``target_classes``.

    Returns:
        A ``(num_frames, total_annotations, matching_annotations)`` tuple.
    """
    total_annotations = 0
    matching_annotations = 0
    for frame in frames_list:
        annotations = frame['annotations']
        total_annotations += len(annotations)
        matching_annotations += sum(
            1
            for ann in annotations
            if getattr(ann, 'category_id', None) in target_classes
        )
    return len(frames_list), total_annotations, matching_annotations


if run_button:
    # Reset the outcome of any previous run before validating the inputs.
    st.session_state.csv_content = None
    st.session_state.error = None
    # Parsed up front so that input with no usable IDs (e.g. "abc") is
    # rejected instead of producing a CSV whose match counts are all zero.
    # A frozenset gives constant-time membership tests in the counting loop.
    target_classes = frozenset(parse_classes(classes_input)) if classes_input else frozenset()
    if not api_key:
        st.session_state.error = "API Key is required."
    elif not dataset_identifier:
        st.session_state.error = "Dataset identifier is required."
    elif not classes_input:
        st.session_state.error = "Please specify at least one class."
    elif not target_classes:
        st.session_state.error = "No valid class IDs found in the classes input."
    elif is_multisensor and not sensor_select:
        st.session_state.error = "Please select a sensor or 'All sensors' before generating CSV."
    else:
        with st.spinner("Processing samples..."):
            try:
                client = init_client(api_key)
                metrics = []
                failures = []
                for sample in samples_objects:
                    try:
                        label = client.get_label(sample.uuid)
                        labelset = getattr(label, 'labelset', '') or ''
                        labeled_by = getattr(label, 'created_by', '') or ''
                        reviewed_by = getattr(label, 'reviewed_by', '') or ''
                        # Normalize the three export paths to one mapping of
                        # sensor name -> frames so a single loop builds rows.
                        if not is_multisensor:
                            frames_by_sensor = {'': export_frames_and_annotations(label)}
                        elif sensor_select and sensor_select != 'All sensors':
                            frames_by_sensor = {
                                sensor_select: export_sensor_frames_and_annotations(label, sensor_select)
                            }
                        else:
                            frames_by_sensor = export_all_sensor_frames_and_annotations(label)
                        for sensor_name, frames_list in frames_by_sensor.items():
                            num_frames, total_annotations, matching_annotations = _count_annotations(
                                frames_list, target_classes
                            )
                            metrics.append({
                                'name': getattr(sample, 'name', sample.uuid),
                                'uuid': sample.uuid,
                                'labelset': labelset,
                                'sensor': sensor_name,
                                'num_frames': num_frames,
                                'total_annotations': total_annotations,
                                'matching_annotations': matching_annotations,
                                'labeled_by': labeled_by,
                                'reviewed_by': reviewed_by
                            })
                    except Exception as e:
                        # Keep going with the remaining samples, but remember
                        # what went wrong so it can be reported to the user.
                        sample_ref = getattr(sample, 'name', None) or getattr(sample, 'uuid', '?')
                        failures.append(f"{sample_ref}: {e}")
                        continue
                if failures:
                    # Cap the listing so one bad dataset cannot flood the page.
                    shown = "; ".join(failures[:5])
                    remaining = len(failures) - 5
                    suffix = f" (and {remaining} more)" if remaining > 0 else ""
                    st.warning(
                        f"{len(failures)} sample(s) could not be fully processed: {shown}{suffix}"
                    )
                if not metrics:
                    st.session_state.error = "No metrics could be generated for the dataset."
                else:
                    st.session_state.csv_content = generate_csv(metrics, dataset_identifier)
            except Exception as e:
                st.session_state.error = f"An error occurred: {e}"

# Render whatever is currently stored in session state: an error message
# and/or a download button for the generated CSV.
error_message = st.session_state.error
if error_message:
    st.error(error_message)

csv_payload = st.session_state.csv_content
if csv_payload:
    st.download_button(
        label="Download Metrics CSV",
        data=csv_payload,
        file_name="sample_metrics.csv",
        mime="text/csv",
    )