Tomatillo committed on
Commit
b6ff680
·
verified ·
1 Parent(s): 00c6aac

Update src/streamlit_app.py

Browse files
Files changed (1) hide show
  1. src/streamlit_app.py +199 -38
src/streamlit_app.py CHANGED
@@ -1,40 +1,201 @@
1
- import altair as alt
2
- import numpy as np
3
- import pandas as pd
4
  import streamlit as st
 
 
 
 
 
 
 
 
 
5
 
6
- """
7
- # Welcome to Streamlit!
8
-
9
- Edit `/streamlit_app.py` to customize this app to your heart's desire :heart:.
10
- If you have any questions, checkout our [documentation](https://docs.streamlit.io) and [community
11
- forums](https://discuss.streamlit.io).
12
-
13
- In the meantime, below is an example of what you can do with just a few lines of code:
14
- """
15
-
16
- num_points = st.slider("Number of points in spiral", 1, 10000, 1100)
17
- num_turns = st.slider("Number of turns in spiral", 1, 300, 31)
18
-
19
- indices = np.linspace(0, 1, num_points)
20
- theta = 2 * np.pi * num_turns * indices
21
- radius = indices
22
-
23
- x = radius * np.cos(theta)
24
- y = radius * np.sin(theta)
25
-
26
- df = pd.DataFrame({
27
- "x": x,
28
- "y": y,
29
- "idx": indices,
30
- "rand": np.random.randn(num_points),
31
- })
32
-
33
- st.altair_chart(alt.Chart(df, height=700, width=700)
34
- .mark_point(filled=True)
35
- .encode(
36
- x=alt.X("x", axis=None),
37
- y=alt.Y("y", axis=None),
38
- color=alt.Color("idx", legend=None, scale=alt.Scale()),
39
- size=alt.Size("rand", legend=None, scale=alt.Scale(range=[1, 150])),
40
- ))
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  import streamlit as st
2
+ import io
3
+ import csv
4
+ from segments import SegmentsClient
5
+ from get_labels_from_samples import (
6
+ get_samples as get_samples_objects,
7
+ export_frames_and_annotations,
8
+ export_sensor_frames_and_annotations,
9
+ export_all_sensor_frames_and_annotations
10
+ )
11
 
12
def init_session_state():
    """Seed the Streamlit session state with the keys this app reads.

    Ensures `csv_content` (the generated CSV text, if any) and `error`
    (the last user-facing error message, if any) always exist so later
    code can test them without raising.
    """
    for state_key in ('csv_content', 'error'):
        if state_key not in st.session_state:
            # Both keys default to None, meaning "nothing to show yet".
            st.session_state[state_key] = None
17
+
18
+
19
def init_client(api_key: str) -> SegmentsClient:
    """Create and return a Segments.ai API client authenticated with *api_key*."""
    client = SegmentsClient(api_key)
    return client
22
+
23
+
24
def parse_classes(input_str: str) -> list:
    """Parse a user-entered class specification into class ids.

    Accepts a comma-separated list of tokens, where each token is either a
    single integer (``"5"``) or an inclusive range (``"1-3"``). Malformed
    tokens are silently skipped; a reversed range (e.g. ``"5-3"``) yields
    nothing. Returns the unique class ids in ascending order.
    """
    found = set()
    for raw_token in input_str.split(','):
        piece = raw_token.strip()
        try:
            if '-' in piece:
                # Range form "start-end": both endpoints included.
                # Unpacking raises ValueError for "1-2-3" etc., which we skip.
                lo, hi = (int(part) for part in piece.split('-'))
                found.update(range(lo, hi + 1))
            else:
                found.add(int(piece))
        except ValueError:
            # Non-numeric or malformed token — ignore it, as the UI
            # treats bad fragments as best-effort input.
            continue
    return sorted(found)
44
+
45
+
46
def generate_csv(metrics: list, dataset_identifier: str) -> str:
    """Render per-sample metric dicts as CSV text.

    Each entry in *metrics* must provide the keys used below ('name',
    'uuid', 'labelset', 'sensor', 'num_frames', 'total_annotations',
    'matching_annotations', 'labeled_by', 'reviewed_by'). The sample URL
    is reconstructed from *dataset_identifier* plus the sample uuid and
    labelset. Returns the full CSV content as a string.
    """
    header = [
        'name', 'sample_url', 'sensor', 'num_frames',
        'total_annotations', 'matching_annotations',
        'labeled_by', 'reviewed_by'
    ]
    buffer = io.StringIO()
    try:
        writer = csv.writer(buffer)
        writer.writerow(header)
        for entry in metrics:
            # Deep link to the sample's label view on app.segments.ai.
            sample_url = f"https://app.segments.ai/{dataset_identifier}/samples/{entry['uuid']}/{entry['labelset']}"
            writer.writerow([
                entry['name'], sample_url, entry['sensor'],
                entry['num_frames'], entry['total_annotations'],
                entry['matching_annotations'], entry['labeled_by'],
                entry['reviewed_by']
            ])
        return buffer.getvalue()
    finally:
        buffer.close()
70
+
71
# ----------------------
# Streamlit UI
# ----------------------
# Flat script body: Streamlit re-executes this top-to-bottom on every
# widget interaction, so all state that must survive reruns lives in
# st.session_state (seeded by init_session_state()).

init_session_state()
st.title("Per-Sample Annotation Counts by Class")

# Credential / target inputs. Explicit widget keys keep values stable
# across reruns.
api_key = st.text_input("API Key", type="password", key="api_key_input")
dataset_identifier = st.text_input("Dataset Identifier (e.g., username/dataset)", key="dataset_identifier_input")
classes_input = st.text_input("Classes (e.g., 1,2,5 or 1-3)", key="classes_input")
run_button = st.button("Generate CSV", key="run_button")

sensor_names = []
is_multisensor = False
sensor_select = None
samples_objects = []

# Probe the dataset as soon as credentials are available: fetch the
# samples and inspect the first sample's label to decide whether this is
# a multi-sensor dataset (which adds a sensor-selection widget below).
if api_key and dataset_identifier:
    try:
        client = init_client(api_key)
        samples_objects = get_samples_objects(client, dataset_identifier)
        if samples_objects:
            label = client.get_label(samples_objects[0].uuid)
            # Nested getattr tolerates labels without an `attributes`
            # object; `sensors` is None for single-sensor datasets.
            sensors = getattr(getattr(label, 'attributes', None), 'sensors', None)
            if sensors is not None:
                is_multisensor = True
                sensor_names = [getattr(sensor, 'name', 'Unknown') for sensor in sensors]
    except Exception as e:
        # Probe failure is non-fatal: the user can still press the button
        # and get a concrete error there.
        st.warning(f"Could not inspect dataset sensors: {e}")

if is_multisensor:
    # 'All sensors' emits one CSV row per sensor per sample.
    sensor_select = st.selectbox("Choose sensor (optional)", options=['All sensors'] + sensor_names)

if run_button:
    # Reset any output from a previous run before validating inputs.
    st.session_state.csv_content = None
    st.session_state.error = None
    if not api_key:
        st.session_state.error = "API Key is required."
    elif not dataset_identifier:
        st.session_state.error = "Dataset identifier is required."
    elif not classes_input:
        st.session_state.error = "Please specify at least one class."
    elif is_multisensor and not sensor_select:
        st.session_state.error = "Please select a sensor or 'All sensors' before generating CSV."
    else:
        with st.spinner("Processing samples..."):
            try:
                target_classes = parse_classes(classes_input)
                client = init_client(api_key)
                metrics = []
                for sample in samples_objects:
                    try:
                        label = client.get_label(sample.uuid)
                        # `or ''` normalizes None attributes to empty strings
                        # so the CSV cells are never the literal "None".
                        labelset = getattr(label, 'labelset', '') or ''
                        labeled_by = getattr(label, 'created_by', '') or ''
                        reviewed_by = getattr(label, 'reviewed_by', '') or ''
                        if is_multisensor and sensor_select and sensor_select != 'All sensors':
                            # Single chosen sensor: count frames/annotations
                            # for that sensor only; one metrics row appended
                            # by the shared append below.
                            frames_list = export_sensor_frames_and_annotations(label, sensor_select)
                            sensor_val = sensor_select
                            num_frames = len(frames_list)
                            total_annotations = sum(len(f['annotations']) for f in frames_list)
                            matching_annotations = sum(
                                1
                                for f in frames_list
                                for ann in f['annotations']
                                if getattr(ann, 'category_id', None) in target_classes
                            )
                        elif is_multisensor and (not sensor_select or sensor_select == 'All sensors'):
                            # All sensors: one metrics row per sensor, appended
                            # here, then `continue` skips the shared append.
                            # NOTE(review): helper appears to return a mapping of
                            # sensor name -> frames — confirm against
                            # get_labels_from_samples.
                            all_sensor_frames = export_all_sensor_frames_and_annotations(label)
                            for sensor_name, frames_list in all_sensor_frames.items():
                                num_frames = len(frames_list)
                                total_annotations = sum(len(f['annotations']) for f in frames_list)
                                matching_annotations = sum(
                                    1
                                    for f in frames_list
                                    for ann in f['annotations']
                                    if getattr(ann, 'category_id', None) in target_classes
                                )
                                metrics.append({
                                    'name': getattr(sample, 'name', sample.uuid),
                                    'uuid': sample.uuid,
                                    'labelset': labelset,
                                    'sensor': sensor_name,
                                    'num_frames': num_frames,
                                    'total_annotations': total_annotations,
                                    'matching_annotations': matching_annotations,
                                    'labeled_by': labeled_by,
                                    'reviewed_by': reviewed_by
                                })
                            continue
                        else:
                            # Single-sensor dataset: no sensor column value.
                            frames_list = export_frames_and_annotations(label)
                            sensor_val = ''
                            num_frames = len(frames_list)
                            total_annotations = sum(len(f['annotations']) for f in frames_list)
                            matching_annotations = sum(
                                1
                                for f in frames_list
                                for ann in f['annotations']
                                if getattr(ann, 'category_id', None) in target_classes
                            )
                        # Shared append for the single-sensor and
                        # chosen-sensor branches above.
                        metrics.append({
                            'name': getattr(sample, 'name', sample.uuid),
                            'uuid': sample.uuid,
                            'labelset': labelset,
                            'sensor': sensor_val if is_multisensor else '',
                            'num_frames': num_frames,
                            'total_annotations': total_annotations,
                            'matching_annotations': matching_annotations,
                            'labeled_by': labeled_by,
                            'reviewed_by': reviewed_by
                        })
                    except Exception as e:
                        # Best-effort per sample: a failing sample is skipped
                        # rather than aborting the whole export.
                        continue
                if not metrics:
                    st.session_state.error = "No metrics could be generated for the dataset."
                else:
                    st.session_state.csv_content = generate_csv(metrics, dataset_identifier)
            except Exception as e:
                st.session_state.error = f"An error occurred: {e}"

# Render whatever the last run produced (error and/or downloadable CSV);
# session state keeps these visible across Streamlit reruns.
if st.session_state.error:
    st.error(st.session_state.error)

if st.session_state.csv_content:
    st.download_button(
        label="Download Metrics CSV",
        data=st.session_state.csv_content,
        file_name="sample_metrics.csv",
        mime="text/csv"
    )