ncoria committed on
Commit da69119 · verified · 1 Parent(s): 32bc64a

update bout threshold

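The only functional change in this commit is inside get_smoothed_predictions: the bout_threshold passed to multiclass_merge_and_filter_bouts is raised from 5 to 30 (downsampled) frames, while proximity_threshold stays at 2. The utility itself lives in utils.utils and is not part of this diff; the snippet below is a minimal stand-in, written only to illustrate the assumed behavior (merge same-class bouts separated by short background gaps, then drop bouts shorter than bout_threshold), not the actual implementation.

# Hypothetical stand-in for utils.utils.multiclass_merge_and_filter_bouts,
# used here only to illustrate what the two thresholds are assumed to do.
import numpy as np

def merge_and_filter_bouts_sketch(preds, bout_threshold, proximity_threshold, background=0):
    """Merge same-class bouts separated by short background gaps, then drop short bouts."""
    preds = np.asarray(preds).copy()

    def runs(arr):
        # Yield (start, stop, value) for each constant run in arr.
        start = 0
        for i in range(1, len(arr) + 1):
            if i == len(arr) or arr[i] != arr[start]:
                yield start, i, arr[start]
                start = i

    # 1) Merge: fill short background gaps flanked by the same non-background class.
    segments = list(runs(preds))
    for k in range(1, len(segments) - 1):
        s, e, v = segments[k]
        prev_v, next_v = segments[k - 1][2], segments[k + 1][2]
        if v == background and prev_v == next_v != background and (e - s) <= proximity_threshold:
            preds[s:e] = prev_v

    # 2) Filter: relabel bouts shorter than bout_threshold frames as background.
    for s, e, v in runs(preds):
        if v != background and (e - s) < bout_threshold:
            preds[s:e] = background
    return preds
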
Files changed (1)
  1. apply_model.py +203 -203
apply_model.py CHANGED
The diff re-lists the whole file, but the only content change is on line 46, where bout_threshold goes from 5 to 30; all other lines are shown below as unchanged context.

@@ -1,203 +1,203 @@
import os
import pickle
from random import random
import streamlit as st
import matplotlib.pyplot as plt
from matplotlib.colors import ListedColormap
import numpy as np
import pandas as pd
import torch
from utils.mp4Io import mp4Io_reader
from utils.seqIo import seqIo_reader
import pandas as pd
from PIL import Image
from pathlib import Path
from transformers import AutoProcessor, AutoModel
from tempfile import NamedTemporaryFile
from tqdm import tqdm
from sklearn.metrics import accuracy_score, classification_report
from utils.utils import create_embeddings_csv_io, process_dataset_in_mem, multiclass_merge_and_filter_bouts, generate_embeddings_stream_io

# --server.maxUploadSize 3000

def get_io_reader(uploaded_file):
    if uploaded_file.name[-3:]=='seq':
        with NamedTemporaryFile(suffix="seq", delete=False) as temp:
            temp.write(uploaded_file.getvalue())
            sr = seqIo_reader(temp.name)
    else:
        with NamedTemporaryFile(suffix="mp4", delete=False) as temp:
            temp.write(uploaded_file.getvalue())
            sr = mp4Io_reader(temp.name)
    return sr

def get_unique_labels(label_list: list[str]):
    label_set = set()
    for label in label_list:
        individual_labels = label.split('||')
        for individual_label in individual_labels:
            label_set.add(individual_label)
    return list(label_set)

def get_smoothed_predictions(svm_model, test_embeds):
    test_pred = svm_model.predict(test_embeds)
    test_prob = svm_model.predict_proba(test_embeds)

-    bout_threshold = 5
+    bout_threshold = 30
    proximity_threshold = 2

    predictions = multiclass_merge_and_filter_bouts(test_pred, bout_threshold, proximity_threshold)
    return predictions

if "embeddings_df" not in st.session_state:
    st.session_state.embeddings_df = None

if "smoothed_predictions" not in st.session_state:
    st.session_state.smoothed_predictions = None
    st.session_state.test_labels = []

st.title('batik: frame classifier')

st.text("Upload files to apply trained classifier on.")
with st.form('embedding_generation_settings'):
    seq_file = st.file_uploader("Choose a video file", type=['seq', 'mp4'])
    annot_files = st.file_uploader("Choose an annotation File", type=['annot','csv'], accept_multiple_files=True)
    downsample_rate = st.number_input('Downsample Rate',value=4)
    submit_embed_settings = st.form_submit_button('Create Embeddings', type='secondary')

st.markdown("**(Optional)** Upload embeddings if not generating above.")
embeddings_csv = st.file_uploader("Choose a .csv File", type=['csv'])

if submit_embed_settings and seq_file is not None and annot_files is not None:
    video_embeddings, video_frames = generate_embeddings_stream_io([seq_file],
                                                                   "SLIP",
                                                                   downsample_rate,
                                                                   False)

    fnames = [seq_file.name]
    embeddings_df = create_embeddings_csv_io(out="file",
                                             fnames=fnames,
                                             embeddings=video_embeddings,
                                             frames=video_frames,
                                             annotations=[annot_files],
                                             test_fnames=None,
                                             views=None,
                                             conditions=None,
                                             downsample_rate=downsample_rate)
    st.session_state.embeddings_df = embeddings_df

elif embeddings_csv is not None:
    embeddings_df = pd.read_csv(embeddings_csv)
    st.session_state.embeddings_df = embeddings_df
else:
    st.text('Please upload file(s).')

st.divider()
st.markdown("Upload classifier model.")
pickled_file = st.file_uploader("Choose a .pkl File", type=['pkl'])

if pickled_file is not None:
    with NamedTemporaryFile(suffix='pkl', delete=False) as temp:
        temp.write(pickled_file.getvalue())
        with open(temp.name, 'rb') as pickled_model:
            svm_clf = pickle.load(pickled_model)
else:
    svm_clf = None

st.divider()
if st.session_state.embeddings_df is not None and svm_clf is not None:
    st.subheader("specify dataset labels")
    label_list = st.session_state.embeddings_df['Label'].to_list()
    unique_label_list = get_unique_labels(label_list)

    with st.form('apply_model_settings'):
        st.text("Select label(s):")
        specified_classes = st.multiselect("Label(s) included:", options=unique_label_list)


        apply_model = st.form_submit_button("Apply Model")

    if apply_model:
        if 'Test' in st.session_state.embeddings_df:
            test_videos = True
        else:
            print(f'shape of df: {st.session_state.embeddings_df.shape[0]}')
            test_videos_array = [True for i in range(st.session_state.embeddings_df.shape[0])]
            st.session_state.embeddings_df['Test'] = test_videos_array
            test_videos = True

        kwargs = {'embeddings_df' : st.session_state.embeddings_df,
                  'specified_classes' : specified_classes,
                  'classes_to_remove' : None,
                  'max_class_size' : None,
                  'animal_state' : None,
                  'view' : None,
                  'shuffle_data' : False,
                  'test_videos' : test_videos}
        train_embeds, train_labels, train_images, test_embeds, test_labels, test_images =\
            process_dataset_in_mem(**kwargs)

        # get predictions from embeddings
        with st.spinner("Model application in progress..."):
            smoothed_predictions = get_smoothed_predictions(svm_clf, test_embeds)

        # save variables to state
        st.session_state.smoothed_predictions = smoothed_predictions
        st.session_state.test_labels = test_labels

    if st.session_state.smoothed_predictions is not None:
        # Convert labels to numerical values
        label_to_appear_first = 'other'
        unique_labels = set(st.session_state.test_labels)
        unique_labels.discard(label_to_appear_first)

        label_to_index = {label_to_appear_first: 0}

        label_to_index.update({label: idx + 1 for idx, label in enumerate(unique_labels)})
        index_to_label = {idx: label for label, idx in label_to_index.items()}

        numerical_labels_test = np.array([label_to_index[label] for label in st.session_state.test_labels])
        print("Label Valence: ", label_to_index)

        #smoothed_predictions test labels
        if len(st.session_state.smoothed_predictions) > 0:
            test_accuracy = accuracy_score(numerical_labels_test, st.session_state.smoothed_predictions)
        else:
            test_accuracy = 0 # If no predictions meet the threshold, set accuracy to 0

        # test_accuracy = accuracy_score(numerical_labels_test, test_pred)
        report = classification_report(numerical_labels_test,
                                       st.session_state.smoothed_predictions,
                                       target_names=[index_to_label[idx] for idx in range(len(index_to_label))],
                                       output_dict=True)
        report_df = pd.DataFrame(report).transpose()

        st.text(f"Eval Accuracy: {test_accuracy}")
        st.subheader("Classification Report:")
        st.dataframe(report_df)

        # create figure (behavior raster)
        fig, ax = plt.subplots()
        raster = ax.imshow(st.session_state.smoothed_predictions.reshape((1,st.session_state.smoothed_predictions.size)),
                           aspect='auto',
                           interpolation='nearest',
                           cmap=ListedColormap(['white'] + [(random(),random(),random()) for i in range(len(index_to_label) - 1)]))
        ax.set_yticklabels([])
        ax.set_xlabel('frames')
        cbar = fig.colorbar(raster)
        labels = [label_to_appear_first] + list(unique_labels)
        spacing = (len(labels) - 1)/len(labels)
        start = spacing/2
        ticks = [start] + [start + spacing*i for i in range(1,len(labels))]
        cbar.set_ticks(ticks=ticks, labels = labels)

        st.pyplot(fig)

        # save generated annotations
        annotations = [labels[x] for x in st.session_state.smoothed_predictions]
        annotations_df = pd.DataFrame(annotations, columns=['label'])
        csv = annotations_df.to_csv(header=False).encode("utf-8")
        output_file_name = st.text_input("Output File Name:","output")
        st.download_button("Download annotations as .csv",
                           data=csv,
                           file_name=f"{output_file_name}.csv")
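
A quick way to see the practical effect of this commit is the hypothetical check below. It assumes the utils package from this repo is importable and that multiclass_merge_and_filter_bouts keeps the (predictions, bout_threshold, proximity_threshold) call used in get_smoothed_predictions above; under the assumed smoothing behavior, a predicted bout shorter than 30 downsampled frames is relabeled as the background class with the new setting, whereas the old threshold of 5 would have kept it.

# Hypothetical sanity check of the new threshold; the expected outcome
# (sub-threshold bouts relabeled as background class 0) is an assumption.
import numpy as np
from utils.utils import multiclass_merge_and_filter_bouts

toy_preds = np.zeros(200, dtype=int)
toy_preds[40:52] = 1     # 12-frame bout of class 1
toy_preds[100:160] = 2   # 60-frame bout of class 2

old_smoothed = multiclass_merge_and_filter_bouts(toy_preds, 5, 2)    # previous threshold
new_smoothed = multiclass_merge_and_filter_bouts(toy_preds, 30, 2)   # this commit

# Under the assumed behavior, class 1 survives at threshold 5 but is dropped at 30.
print(np.unique(old_smoothed), np.unique(new_smoothed))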