ncoria committed on
Commit ed29c11 · verified · 1 Parent(s): b8c85bc

add main program files

app.py ADDED
@@ -0,0 +1,21 @@
+ import os
+ import streamlit as st
+
+ pages = {
+     "home": [
+         st.Page("home.py", title="about", icon=":material/home:")
+     ],
+     "generate embeddings": [
+         st.Page("generate_embeddings.py", title="generate", icon=":material/dataset:")
+     ],
+     "annotation": [
+         st.Page("train_model.py", title="train model", icon=":material/model_training:"),
+         st.Page("apply_model.py", title="apply model", icon=":material/grade:"),
+     ],
+     "behavior discovery": [
+         st.Page("explore.py", title="explore", icon=":material/search:")
+     ],
+ }
+
+ pg = st.navigation(pages)
+ pg.run()
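The multipage layout above is driven entirely by st.navigation. A practical note (an assumption, not stated in the commit itself): the page scripts below carry an upload-size override in their comments, so the app is presumably launched from the repository root with something like `streamlit run app.py --server.maxUploadSize 3000`.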
apply_model.py ADDED
@@ -0,0 +1,195 @@
1
+ import os
2
+ import pickle
3
+ from random import random
4
+ import streamlit as st
5
+ import matplotlib.pyplot as plt
6
+ from matplotlib.colors import ListedColormap
7
+ import numpy as np
8
+ import pandas as pd
9
+ import torch
10
+ from utils.mp4Io import mp4Io_reader
11
+ from utils.seqIo import seqIo_reader
12
+ import pandas as pd
13
+ from PIL import Image
14
+ from pathlib import Path
15
+ from transformers import AutoProcessor, AutoModel
16
+ from tempfile import NamedTemporaryFile
17
+ from tqdm import tqdm
18
+ from sklearn.metrics import accuracy_score, classification_report
19
+ from utils.utils import create_embeddings_csv_io, process_dataset_in_mem, multiclass_merge_and_filter_bouts, generate_embeddings_stream_io
20
+
21
+ # --server.maxUploadSize 3000
22
+
23
+ def get_io_reader(uploaded_file):
24
+ if uploaded_file.name[-3:]=='seq':
25
+ with NamedTemporaryFile(suffix="seq", delete=False) as temp:
26
+ temp.write(uploaded_file.getvalue())
27
+ sr = seqIo_reader(temp.name)
28
+ else:
29
+ with NamedTemporaryFile(suffix="mp4", delete=False) as temp:
30
+ temp.write(uploaded_file.getvalue())
31
+ sr = mp4Io_reader(temp.name)
32
+ return sr
33
+
34
+ def get_unique_labels(label_list: list[str]):
35
+ label_set = set()
36
+ for label in label_list:
37
+ individual_labels = label.split('||')
38
+ for individual_label in individual_labels:
39
+ label_set.add(individual_label)
40
+ return list(label_set)
41
+
42
+ def get_smoothed_predictions(svm_model, test_embeds):
43
+ test_pred = svm_model.predict(test_embeds)
44
+ test_prob = svm_model.predict_proba(test_embeds)
45
+
46
+ bout_threshold = 5
47
+ proximity_threshold = 2
48
+
49
+ predictions = multiclass_merge_and_filter_bouts(test_pred, bout_threshold, proximity_threshold)
50
+ return predictions
51
+
52
+ if "embeddings_df" not in st.session_state:
53
+ st.session_state.embeddings_df = None
54
+
55
+ if "smoothed_predictions" not in st.session_state:
56
+ st.session_state.smoothed_predictions = None
57
+ st.session_state.test_labels = []
58
+
59
+ st.title('batik: frame classifier')
60
+
61
+ st.text("Upload files to apply trained classifier on.")
62
+ with st.form('embedding_generation_settings'):
63
+ seq_file = st.file_uploader("Choose a video file", type=['seq', 'mp4'], accept_multiple_files=False)
64
+ annot_files = st.file_uploader("Choose an annotation File", type=['annot','csv'], accept_multiple_files=True)
65
+ downsample_rate = st.number_input('Downsample Rate',value=4)
66
+ submit_embed_settings = st.form_submit_button('Create Embeddings', type='secondary')
67
+
68
+ st.markdown("**(Optional)** Upload embeddings if not generating above.")
69
+ embeddings_csv = st.file_uploader("Choose a .csv File", type=['csv'])
70
+
71
+ if submit_embed_settings and seq_file is not None and annot_files is not None:
72
+ video_embeddings, video_frames = generate_embeddings_stream_io([seq_file],
73
+ "SLIP",
74
+ downsample_rate,
75
+ False)
76
+
77
+ fnames = [seq_file.name]
78
+ embeddings_df = create_embeddings_csv_io(out="file",
79
+ fnames=fnames,
80
+ embeddings=video_embeddings,
81
+ frames=video_frames,
82
+ annotations=[annot_files],
83
+ test_fnames=None,
84
+ views=None,
85
+ conditions=None,
86
+ downsample_rate=downsample_rate)
87
+ st.session_state.embeddings_df = embeddings_df
88
+
89
+ elif embeddings_csv is not None:
90
+ embeddings_df = pd.read_csv(embeddings_csv)
91
+ st.session_state.embeddings_df = embeddings_df
92
+ else:
93
+ st.text('Please upload file(s).')
94
+
95
+ st.divider()
96
+ st.markdown("Upload classifier model.")
97
+ pickled_file = st.file_uploader("Choose a .pkl File", type=['pkl'])
98
+
99
+ if pickled_file is not None:
100
+ with NamedTemporaryFile(suffix='pkl', delete=False) as temp:
101
+ temp.write(pickled_file.getvalue())
102
+ with open(temp.name, 'rb') as pickled_model:
103
+ svm_clf = pickle.load(pickled_model)
104
+ else:
105
+ svm_clf = None
106
+
107
+ st.divider()
108
+ if st.session_state.embeddings_df is not None and svm_clf is not None:
109
+ st.subheader("specify dataset labels")
110
+ label_list = st.session_state.embeddings_df['Label'].to_list()
111
+ unique_label_list = get_unique_labels(label_list)
112
+
113
+ with st.form('apply_model_settings'):
114
+ st.text("Select label(s):")
115
+ specified_classes = st.multiselect("Label(s) included:", options=unique_label_list)
116
+
117
+
118
+ apply_model = st.form_submit_button("Apply Model")
119
+
120
+ if apply_model:
121
+ kwargs = {'embeddings_df' : st.session_state.embeddings_df,
122
+ 'specified_classes' : specified_classes,
123
+ 'classes_to_remove' : None,
124
+ 'max_class_size' : None,
125
+ 'animal_state' : None,
126
+ 'view' : None,
127
+ 'shuffle_data' : False,
128
+ 'test_videos' : list(set(st.session_state.embeddings_df['Source'].to_list()))}
129
+ train_embeds, train_labels, train_images, test_embeds, test_labels, test_images =\
130
+ process_dataset_in_mem(**kwargs)
131
+
132
+ # get predictions from embeddings
133
+ with st.spinner("Model application in progress..."):
134
+ smoothed_predictions = get_smoothed_predictions(svm_clf, test_embeds)
135
+
136
+ # save variables to state
137
+ st.session_state.smoothed_predictions = smoothed_predictions
138
+ st.session_state.test_labels = test_labels
139
+
140
+ if st.session_state.smoothed_predictions is not None:
141
+ # Convert labels to numerical values
142
+ label_to_appear_first = 'other'
143
+ unique_labels = set(st.session_state.test_labels)
144
+ unique_labels.discard(label_to_appear_first)
145
+
146
+ label_to_index = {label_to_appear_first: 0}
147
+
148
+ label_to_index.update({label: idx + 1 for idx, label in enumerate(unique_labels)})
149
+ index_to_label = {idx: label for label, idx in label_to_index.items()}
150
+
151
+ numerical_labels_test = np.array([label_to_index[label] for label in st.session_state.test_labels])
152
+ print("Label Valence: ", label_to_index)
153
+
154
+ # evaluate the smoothed predictions against the test labels
155
+ if len(st.session_state.smoothed_predictions) > 0:
156
+ test_accuracy = accuracy_score(numerical_labels_test, st.session_state.smoothed_predictions)
157
+ else:
158
+ test_accuracy = 0 # If no predictions meet the threshold, set accuracy to 0
159
+
160
+ # test_accuracy = accuracy_score(numerical_labels_test, test_pred)
161
+ report = classification_report(numerical_labels_test,
162
+ st.session_state.smoothed_predictions,
163
+ target_names=[index_to_label[idx] for idx in range(len(index_to_label))],
164
+ output_dict=True)
165
+ report_df = pd.DataFrame(report).transpose()
166
+
167
+ st.text(f"Eval Accuracy: {test_accuracy}")
168
+ st.subheader("Classification Report:")
169
+ st.dataframe(report_df)
170
+
171
+ # create figure (behavior raster)
172
+ fig, ax = plt.subplots()
173
+ raster = ax.imshow(st.session_state.smoothed_predictions.reshape((1,st.session_state.smoothed_predictions.size)),
174
+ aspect='auto',
175
+ interpolation='nearest',
176
+ cmap=ListedColormap(['white'] + [(random(),random(),random()) for i in range(len(index_to_label) - 1)]))
177
+ ax.set_yticklabels([])
178
+ ax.set_xlabel('frames')
179
+ cbar = fig.colorbar(raster)
180
+ labels = [label_to_appear_first] + list(unique_labels)
181
+ spacing = (len(labels) - 1)/len(labels)
182
+ start = spacing/2
183
+ ticks = [start] + [start + spacing*i for i in range(1,len(labels))]
184
+ cbar.set_ticks(ticks=ticks, labels = labels)
185
+
186
+ st.pyplot(fig)
187
+
188
+ # save generated annotations
189
+ annotations = [labels[x] for x in st.session_state.smoothed_predictions]
190
+ annotations_df = pd.DataFrame(annotations, columns=['label'])
191
+ csv = annotations_df.to_csv(header=False).encode("utf-8")
192
+ output_file_name = st.text_input("Output File Name:","output")
193
+ st.download_button("Download annotations as .csv",
194
+ data=csv,
195
+ file_name=f"{output_file_name}.csv")
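For downstream use of the exported annotations, a minimal sketch of reading the file back (assuming the default name "output" chosen above; the CSV is written with header=False, so column 0 is the row index of each prediction and column 1 is the label):

    import pandas as pd

    annotations = pd.read_csv("output.csv", header=None, names=["frame", "label"], index_col="frame")
    print(annotations["label"].value_counts())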
explore.py ADDED
@@ -0,0 +1,337 @@
1
+ import streamlit as st
2
+ import plotly.express as px
3
+ import numpy as np
4
+ import pandas as pd
5
+ import torch
6
+ from utils.mp4Io import mp4Io_reader
7
+ from utils.seqIo import seqIo_reader
8
+ import pandas as pd
9
+ from PIL import Image
10
+ from pathlib import Path
11
+ from transformers import AutoProcessor, AutoModel
12
+ from tempfile import NamedTemporaryFile
13
+ from tqdm import tqdm
14
+ from utils.utils import create_embeddings_csv_io, process_dataset_in_mem, generate_embeddings_stream_io
15
+ from get_llava_response import get_llava_response, load_llava_checkpoint_hf
16
+ from sklearn.manifold import TSNE
17
+ from openai import OpenAI
18
+ import cv2
19
+ import base64
20
+ from hdbscan import HDBSCAN, all_points_membership_vectors
21
+ import random
22
+
23
+ # --server.maxUploadSize 3000
24
+ REPO_NAME = 'ncoria/llava-lora-vicuna-clip-5-epochs-merge'
25
+
26
+ def load_llava_model(hf_token):
27
+ return load_llava_checkpoint_hf(REPO_NAME, hf_token)
28
+
29
+ def get_unique_labels(label_list: list[str]):
30
+ label_set = set()
31
+ for label in label_list:
32
+ individual_labels = label.split('||')
33
+ for individual_label in individual_labels:
34
+ label_set.add(individual_label)
35
+ return list(label_set)
36
+
37
+ SYSTEM_PROMPT = """You are a researcher studying mice interactions from videos of the inside of a resident
38
+ intruder box where there is either just the resident mouse (the black one) or the resident and the intruder mouse (the white one).
39
+ Your job is to answer questions about the behavior of the mice in the image given the context that each image is a frame of a continuous video.
40
+ Thus, you should use the visual information about the mice in the image to try to provide a detailed behavioral description of the image."""
41
+
42
+ @st.cache_resource
43
+ def get_io_reader(uploaded_file):
44
+ if uploaded_file.name[-3:]=='seq':
45
+ with NamedTemporaryFile(suffix="seq", delete=False) as temp:
46
+ temp.write(uploaded_file.getvalue())
47
+ sr = seqIo_reader(temp.name)
48
+ else:
49
+ with NamedTemporaryFile(suffix="mp4", delete=False) as temp:
50
+ temp.write(uploaded_file.getvalue())
51
+ sr = mp4Io_reader(temp.name)
52
+ return sr
53
+
54
+ def get_image(sr, frame_no: int):
55
+ image, _ = sr.getFrame(frame_no)
56
+ return image
57
+
58
+ @st.cache_data
59
+ def get_2d_embedding(embeddings: pd.DataFrame):
60
+ tsne = TSNE(n_jobs=4, n_components=2, random_state=42, perplexity=50)
61
+ embedding_2d = tsne.fit_transform(np.array(embeddings))
62
+ return embedding_2d
63
+
64
+ HDBSCAN_PARAMS = {
65
+ 'min_samples': 1
66
+ }
67
+
68
+ @st.cache_data
69
+ def hdbscan_classification(umap_embeddings, embeddings_2d, cluster_range):
70
+ max_num_clusters = -np.inf
71
+ num_clusters = []
72
+ min_cluster_size = np.linspace(cluster_range[0], cluster_range[1], 4)
73
+ for min_c in min_cluster_size:
74
+ learned_hierarchy = HDBSCAN(
75
+ prediction_data=True, min_cluster_size=int(round(min_c * 0.01 *umap_embeddings.shape[0])),
76
+ cluster_selection_method='leaf' ,
77
+ **HDBSCAN_PARAMS).fit(umap_embeddings)
78
+ num_clusters.append(len(np.unique(learned_hierarchy.labels_)))
79
+ if num_clusters[-1] > max_num_clusters:
80
+ max_num_clusters = num_clusters[-1]
81
+ retained_hierarchy = learned_hierarchy
82
+ assignments = retained_hierarchy.labels_
83
+ assign_prob = all_points_membership_vectors(retained_hierarchy)
84
+ soft_assignments = np.argmax(assign_prob, axis=1)
85
+ retained_hierarchy.fit(embeddings_2d)
86
+ return retained_hierarchy, assignments, assign_prob, soft_assignments
87
+
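As a worked example of the sweep above (illustrative numbers, not from the commit): with 10,000 embeddings and the cluster_range of [4, 6] used later on this page, np.linspace(4, 6, 4) yields candidate sizes of 4.0, about 4.67, about 5.33, and 6.0 percent of the data, i.e. HDBSCAN min_cluster_size values of 400, 467, 533, and 600; the fit that produces the most clusters is the one retained.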
88
+ def upload_image(frame: np.ndarray):
89
+ """Return the frame as a base64-encoded PNG string."""
90
+ _, encoded_image = cv2.imencode('.png', frame)
91
+ return base64.b64encode(encoded_image.tobytes()).decode('utf-8')
92
+
93
+ def ask_question_with_image_gpt(file_id, system_prompt, question, api_key):
94
+ """Asks a question about the uploaded image."""
95
+ client = OpenAI(api_key=api_key)
96
+
97
+ if file_id is not None:
98
+ response = client.chat.completions.create(
99
+ model="gpt-4o",
100
+ messages=[
101
+ {"role": "system", "content": system_prompt},
102
+ {"role": "user", "content": [
103
+ {"type": "text", "text": question},
104
+ {"type": "image_url", "image_url": {"url": f"data:image/png;base64,{file_id}"}}]  # the frame is PNG-encoded by upload_image
105
+ }
106
+ ]
107
+ )
108
+ else:
109
+ response = client.chat.completions.create(
110
+ model="gpt-4o",
111
+ messages=[
112
+ {"role": "system", "content": system_prompt},
113
+ {"role": "user", "content": question}
114
+ ]
115
+ )
116
+ return response.choices[0].message.content
117
+
118
+ def ask_question_with_image_llava(image, system_prompt, question,
119
+ tokenizer, model, image_processor):
120
+ outputs = get_llava_response([question],
121
+ [image],
122
+ system_prompt,
123
+ tokenizer,
124
+ model,
125
+ image_processor,
126
+ REPO_NAME,
127
+ stream_output=False)
128
+ return outputs[0]
129
+
130
+ def ask_summary_question(image_array, label_array, api_key):
131
+ # load llava model
132
+ tokenizer, model, image_processor = load_llava_model(hf_token)
133
+
134
+ # global variable
135
+ system_prompt = SYSTEM_PROMPT
136
+
137
+ # collect responses
138
+ responses = []
139
+
140
+ # create progress bar
141
+ j = 0
142
+ pbar_text = lambda j: f'Creating llava response {j}/{len(label_array)}.'
143
+ pbar = st.progress(0, text=pbar_text(0))
144
+
145
+ for i, image in enumerate(image_array):
146
+ label = label_array[i]
147
+ question = f"The frame is annotated by a human observer with the label: {label}. Give evidence for this label using the posture of the mice and their current behavior. "
148
+ question += "Also, designate a behavioral subtype of the given label that describes the current social interaction based on what you see about the posture of the mice and "\
149
+ "how they are positioned with respect to each other. Usually, the body parts (i.e., tail, genitals, face, body, ears, paws)"\
150
+ "of the mice that are closest to each other will give some clue. Please limit behavioral subtype to a 1-4 word phrase. limit your response to 4 sentences."
151
+ response = ask_question_with_image_llava(image, system_prompt, question,
152
+ tokenizer, model, image_processor)
153
+ responses.append(response)
154
+ # update progress bar
155
+ j += 1
156
+ pbar.progress(j/len(label_array), pbar_text(j))
157
+
158
+ system_prompt_summarize = "You are a researcher studying mice interactions from videos of the inside of a resident "\
159
+ "intruder box where there is either just the resident mouse (the black one) or the resident and the intruder mouse (the white one). "\
160
+ "You will be given a question about a list of descriptions from frames of these videos. "\
161
+ "Your job is to answer the question by focusing on the behaviors of the mice and their postures "\
162
+ "as well as any other aspects of the descriptions that may be relevant to the class label associated with them"
163
+ user_prompt_summarize = "Here are several descriptions of individual frames from a mouse behavior video. Please summarize these descriptions and provide a suggestion for a "\
164
+ "behavior label which captures what is described in the descriptions: \n\n"
165
+ user_prompt_summarize = user_prompt_summarize + '\n'.join(responses)
166
+ summary_response = ask_question_with_image_gpt(None, system_prompt_summarize, user_prompt_summarize, api_key)
167
+ return summary_response
168
+
169
+ if "embeddings_df" not in st.session_state:
170
+ st.session_state.embeddings_df = None
171
+
172
+ st.title('batik: frame classifier')
173
+
174
+ api_key = st.text_input("OpenAI API Key:","")
175
+ hf_token = st.text_input("HuggingFace Token:","")
176
+ st.subheader("generate or import embeddings")
177
+
178
+ st.text("Upload files to generate embeddings.")
179
+ with st.form('embedding_generation_settings'):
180
+ seq_file = st.file_uploader("Choose a video file", type=['seq', 'mp4'], accept_multiple_files=True)
181
+ annot_files = st.file_uploader("Choose an annotation File", type=['annot','csv'], accept_multiple_files=True)
182
+ downsample_rate = st.number_input('Downsample Rate',value=4)
183
+ submit_embed_settings = st.form_submit_button('Create Embeddings', type='secondary')
184
+
185
+ st.markdown("**(Optional)** Upload embeddings.")
186
+ embeddings_csv = st.file_uploader("Choose a .csv File", type=['csv'])
187
+
188
+ if submit_embed_settings and seq_file is not None and annot_files is not None:
189
+ video_embeddings, video_frames = generate_embeddings_stream_io([seq_file],
190
+ "SLIP",
191
+ downsample_rate,
192
+ False)
193
+
194
+ fnames = [seq_file.name]
195
+ embeddings_df = create_embeddings_csv_io(out="file",
196
+ fnames=fnames,
197
+ embeddings=video_embeddings,
198
+ frames=video_frames,
199
+ annotations=[annot_files],
200
+ test_fnames=None,
201
+ views=None,
202
+ conditions=None,
203
+ downsample_rate=downsample_rate)
204
+ st.session_state.embeddings_df = embeddings_df
205
+ elif embeddings_csv is not None:
206
+ embeddings_df = pd.read_csv(embeddings_csv)
207
+ st.session_state.embeddings_df = embeddings_df
208
+ else:
209
+ st.text('Please upload file(s).')
210
+
211
+ st.divider()
212
+ st.subheader("provide video file if not yet already provided")
213
+
214
+ uploaded_file = st.file_uploader("Choose a video file", type=['seq', 'mp4'], accept_multiple_files=True)
215
+
216
+ st.divider()
217
+ if st.session_state.embeddings_df is not None and (uploaded_file is not None or seq_file is not None):
218
+ if seq_file is not None:
219
+ uploaded_file = seq_file
220
+ io_reader = get_io_reader(uploaded_file)
221
+ print("CONVERTED SEQ")
222
+ label_list = st.session_state.embeddings_df['Label'].to_list()
223
+ unique_label_list = get_unique_labels(label_list)
224
+ print(f"unique_labels: {unique_label_list}")
225
+ #unique_label_list = ['check_genital', 'wiggle', 'lordose', 'stay', 'turn', 'top_up', 'dart', 'sniff', 'approach', 'into_male_cage']
226
+ #unique_label_list = ['into_male_cage', 'intromission', 'male_sniff', 'mount']
227
+ kwargs = {'embeddings_df' : st.session_state.embeddings_df,
228
+ 'specified_classes' : unique_label_list,
229
+ 'classes_to_remove' : None,
230
+ 'max_class_size' : None,
231
+ 'animal_state' : None,
232
+ 'view' : None,
233
+ 'shuffle_data' : False,
234
+ 'test_videos' : None}
235
+ train_embeds, train_labels, train_images, _, _, _ = process_dataset_in_mem(**kwargs)
236
+ print("PROCESSED DATASET")
237
+ if "Images" in st.session_state.embeddings_df.keys():
238
+ train_images = [i for i in range(len(train_images))]
239
+ embedding_2d = get_2d_embedding(train_embeds)
240
+ else:
241
+ st.text('Please generate embeddings and provide video file.')
242
+ print("GOT 2D EMBEDS")
243
+
244
+ if uploaded_file is not None and st.session_state.embeddings_df is not None:
245
+ st.subheader("t-SNE Projection")
246
+ option = st.selectbox(
247
+ "Select Color Option",
248
+ ("By Label", "By Time", "By Cluster")
249
+ )
250
+ if embedding_2d is not None:
251
+ if option is not None:
252
+ if option == "By Label":
253
+ color = 'label'
254
+ elif option == "By Time":
255
+ color = 'frame_no'
256
+ else:
257
+ color = 'cluster_label'
258
+
259
+ if option in ["By Label", "By Time"]:
260
+ edf = pd.DataFrame(embedding_2d,columns=['tsne_dim_1', 'tsne_dim_2'])
261
+ edf.insert(2,'frame_no',np.array([int(x) for x in train_images]))
262
+ edf.insert(3, 'label', train_labels)
263
+ fig = px.scatter(
264
+ edf,
265
+ x="tsne_dim_1",
266
+ y="tsne_dim_2",
267
+ color=color,
268
+ hover_data=["frame_no"],
269
+ color_discrete_sequence=px.colors.qualitative.Dark24
270
+ )
271
+ else:
272
+ r, _, _, _ = hdbscan_classification(train_embeds, embedding_2d, [4, 6])
273
+ edf = pd.DataFrame(embedding_2d,columns=['tsne_dim_1', 'tsne_dim_2'])
274
+ edf.insert(2,'frame_no',np.array([int(x) for x in train_images]))
275
+ edf.insert(3, 'label', train_labels)
276
+ edf.insert(4, 'cluster_label', [str(c_id) for c_id in r.labels_.tolist()])
277
+ fig = px.scatter(
278
+ edf,
279
+ x="tsne_dim_1",
280
+ y="tsne_dim_2",
281
+ color=color,
282
+ hover_data=["frame_no"],
283
+ color_discrete_sequence=px.colors.qualitative.Dark24
284
+ )
285
+
286
+ event = st.plotly_chart(fig, key="df", on_select="rerun")
287
+ else:
288
+ st.text("No Color Option Selected")
289
+ else:
290
+ st.text('No Embeddings Loaded')
291
+
292
+ event_dict = event.selection
293
+
294
+ if event_dict is not None:
295
+ custom_data = []
296
+ for point in event_dict['points']:
297
+ data = point["customdata"][0]
298
+ custom_data.append(int(data))
299
+
300
+ if len(custom_data) > 10:
301
+ custom_data = random.sample(custom_data, 10)
302
+ if len(custom_data) > 1:
303
+ col_1, col_2 = st.columns(2)
304
+ with col_1:
305
+ for frame_no in custom_data[::2]:
306
+ st.image(get_image(io_reader, frame_no))
307
+ st.caption(f"Frame {frame_no}, {train_labels[frame_no]}")
308
+ with col_2:
309
+ for frame_no in custom_data[1::2]:
310
+ st.image(get_image(io_reader, frame_no))
311
+ st.caption(f"Frame {frame_no}, {train_labels[frame_no]}")
312
+ elif len(custom_data) == 1:
313
+ frame_no = custom_data[0]
314
+ st.image(get_image(io_reader, frame_no))
315
+ st.caption(f"Frame {frame_no}, {train_labels[frame_no]}")
316
+ else:
317
+ st.text('No Points Selected')
318
+
319
+ if len(custom_data) == 1:
320
+ frame_no = custom_data[0]
321
+ image = get_image(io_reader, frame_no)
322
+ system_prompt = SYSTEM_PROMPT
323
+ label = train_labels[frame_no]
324
+ question = f"The frame is annotated by a human observer with the label: {label}. Give evidence for this label using the posture of the mice and their current behavior. "\
325
+ "Also, designate a behavioral subtype of the given label that describes the current social interaction based on what you see about the posture of the mice and "\
326
+ "how they are positioned with respect to each other. Usually, the body parts (i.e., tail, genitals, face, body, ears, paws)" \
327
+ "of the mice that are closest to each other will give some clue. Please limit behavioral subtype to a 1-4 word phrase. limit your response to 4 sentences."
328
+ tokenizer, model, image_processor = load_llava_model(hf_token)
329
+ response = ask_question_with_image_llava(image, system_prompt, question,
330
+ tokenizer, model, image_processor)
331
+ st.markdown(response)
332
+
333
+ elif len(custom_data) > 1:
334
+ image_array = [get_image(io_reader, f_no) for f_no in custom_data]
335
+ label_array = [train_labels[f_no] for f_no in custom_data]
336
+ response = ask_summary_question(image_array, label_array, api_key)
337
+ st.markdown(response)
generate_embeddings.py ADDED
@@ -0,0 +1,131 @@
+ import streamlit as st
+ import pandas as pd
+ from utils.utils import create_embeddings_csv_io, create_annot_fname_dict_io, generate_embeddings_stream_io
+
+ if "video_embeddings" not in st.session_state:
+     st.session_state.video_embeddings = None
+     st.session_state.video_frames = None
+     st.session_state.fnames = []
+
+ st.title('batik: embedding generator')
+ uploaded_files = st.file_uploader("Choose a video file", type=['seq', 'mp4'], accept_multiple_files=True)
+ with st.form('initial_settings'):
+     st.header('Embedding Generation Options')
+     model_select = st.selectbox('Select Model', ['SLIP', 'CLIP'])
+     downsample_rate = st.number_input('Downsample Rate', value=4)
+     save_csv = st.toggle('Save Individual Results', value=False)
+     submit_initial_settings = st.form_submit_button('Create Embeddings', type='secondary')
+
+ if submit_initial_settings and uploaded_files is not None and len(uploaded_files) > 0:
+     video_embeddings, video_frames = generate_embeddings_stream_io(uploaded_files,
+                                                                    model_select,
+                                                                    downsample_rate,
+                                                                    save_csv)
+     fnames = [vid_file.name for vid_file in uploaded_files]
+     st.session_state.video_embeddings = video_embeddings
+     st.session_state.video_frames = video_frames
+     st.session_state.fnames = fnames
+
+ if st.session_state.video_embeddings is not None:
+     st.header('CSV Configuration Options')
+     st.markdown('If using `.annot` files and multiple files should be grouped together, '
+                 'please ensure that they share a common name and end with a number describing '
+                 'the order of the files. For example:\n\n'
+                 '`mouse_224_file_1.annot`, `mouse_224_file_2.annot`.')
+     annot_files = st.file_uploader("Upload all annotation files", type=['.annot', '.csv'], accept_multiple_files=True)
+
+     annot_options = []
+     if annot_files is not None and len(annot_files) > 0:
+         annot_fnames = [annot_file.name for annot_file in annot_files]
+         annot_fname_dict = create_annot_fname_dict_io(annot_fnames=annot_fnames,
+                                                       annot_files=annot_files)
+         annot_options = [str(key) for key in annot_fname_dict.keys()]
+
+     if len(annot_options) > 0:
+         with st.form('csv_settings'):
+             csv_setting_def = pd.DataFrame(
+                 {
+                     "File Name": st.session_state.fnames,
+                     "Annotations": [
+                         "Upload File" for _ in st.session_state.fnames
+                     ],
+                     "Test": [
+                         False for _ in st.session_state.fnames
+                     ],
+                     "View": [
+                         "Top" for _ in st.session_state.fnames
+                     ],
+                     "Condition": [
+                         "None" for _ in st.session_state.fnames
+                     ]
+
+                 }
+             )
+
+             csv_settings = st.data_editor(
+                 csv_setting_def,
+                 column_config={
+                     "Annotations": st.column_config.SelectboxColumn(
+                         "Annotations",
+                         help="The annotation file(s) to use for the given video file.",
+                         width="medium",
+                         options=annot_options,
+                         required=True
+                     ),
+                     "Test": st.column_config.CheckboxColumn(
+                         "Test",
+                         help="Designate file(s) to use as the test set.",
+                         default=False,
+                         required=True
+                     ),
+                     "View": st.column_config.SelectboxColumn(
+                         "View",
+                         help="The view used within the video (either Top or Front).",
+                         options=["Top", "Front"],
+                         required=True
+                     ),
+                     "Condition": st.column_config.TextColumn(
+                         "Condition",
+                         help="A condition the video has (e.g., Control).",
+                         default="None",
+                         max_chars=30,
+                         validate=r"[a-z]+$",
+                     )
+                 },
+                 hide_index=True
+             )
+             save_csv_bttn = st.form_submit_button("Create CSV")
+
+         if save_csv_bttn and csv_settings is not None:
+             annot_chosen_options = csv_settings['Annotations'].tolist()
+             annot_option = [annot_fname_dict[key] for key in annot_chosen_options]
+             test_chosen_option = csv_settings['Test'].tolist()
+             test_option = [st.session_state.fnames[i] for i, is_test in enumerate(test_chosen_option) if is_test]
+             view_option = csv_settings['View'].tolist()
+             condition_option = csv_settings['Condition'].tolist()
+
+             out_name = st.text_input("Embeddings Output File Name", "out.csv")
+             try:
+                 df = create_embeddings_csv_io(out=out_name,
+                                               fnames=st.session_state.fnames,
+                                               embeddings=st.session_state.video_embeddings,
+                                               frames=st.session_state.video_frames,
+                                               annotations=annot_option,
+                                               test_fnames=test_option,
+                                               views=view_option,
+                                               conditions=condition_option,
+                                               downsample_rate=downsample_rate)
+                 st.success('Created Embeddings File!', icon="✅")
+                 st.download_button(
+                     label="Download CSV",
+                     data=df.to_csv().encode("utf-8"),
+                     file_name=out_name,
+                     mime="text/csv"
+                 )
+             except Exception:
+                 st.error('Something went wrong.')
+
+         else:
+             st.text('Please Upload Files')
+     else:
+         st.text('Please Upload Files')
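A minimal sanity check on a downloaded embeddings CSV (column names follow what the other pages read, namely 'Label' and 'Source'; the exact layout is produced by utils.utils.create_embeddings_csv_io):

    import pandas as pd

    df = pd.read_csv("out.csv")
    assert {"Label", "Source"}.issubset(df.columns)
    print(df.shape, df["Label"].nunique(), "unique labels")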
get_llava_response.py ADDED
@@ -0,0 +1,184 @@
1
+ import argparse
2
+ import torch
3
+ from transformers import AutoTokenizer, AutoModelForCausalLM, AutoConfig, BitsAndBytesConfig
4
+ import numpy as np
5
+
6
+ from huggingface_hub import whoami
7
+
8
+ import llava
9
+ from llava.constants import IMAGE_TOKEN_INDEX, DEFAULT_IMAGE_TOKEN, DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN, DEFAULT_IMAGE_PATCH_TOKEN
10
+ from llava.conversation import conv_templates, SeparatorStyle
11
+ from llava.model.builder import load_pretrained_model
12
+ from llava.utils import disable_torch_init
13
+ from llava.mm_utils import process_images, tokenizer_image_token, get_model_name_from_path
14
+
15
+ from PIL import Image
16
+
17
+ import requests
18
+ from PIL import Image
19
+ from io import BytesIO
20
+ from transformers import TextStreamer
21
+ from tqdm import tqdm
22
+
23
+ import warnings
24
+ warnings.filterwarnings('ignore')
25
+
26
+ REPO_NAME = 'ncoria/llava-lora-vicuna-clip-5-epochs-merge'
27
+
28
+ def load_image(image_file):
29
+ if image_file.startswith('http://') or image_file.startswith('https://'):
30
+ response = requests.get(image_file)
31
+ image = Image.open(BytesIO(response.content)).convert('RGB')
32
+ else:
33
+ image = Image.open(image_file).convert('RGB')
34
+ return image
35
+
36
+ def load_llava_checkpoint(model_path: str):
37
+ model_name = get_model_name_from_path(model_path)
38
+ return load_pretrained_model(model_path, None, model_name, load_4bit=True, device="cuda")
39
+
40
+ def load_llava_checkpoint_hf(model_path, hf_token):
41
+ user = whoami(token=hf_token)
42
+ kwargs = {"device_map": "auto"}
43
+ kwargs['load_in_4bit'] = True
44
+ kwargs['quantization_config'] = BitsAndBytesConfig(
45
+ load_in_4bit=True,
46
+ bnb_4bit_compute_dtype=torch.float16,
47
+ bnb_4bit_use_double_quant=True,
48
+ bnb_4bit_quant_type='nf4'
49
+ )
50
+ tokenizer = AutoTokenizer.from_pretrained(model_path, use_fast=False)
51
+ model = AutoModelForCausalLM.from_pretrained(model_path, low_cpu_mem_usage=True, **kwargs)
52
+ mm_use_im_start_end = getattr(model.config, "mm_use_im_start_end", False)
53
+ mm_use_im_patch_token = getattr(model.config, "mm_use_im_patch_token", True)
54
+ if mm_use_im_patch_token:
55
+ tokenizer.add_tokens([DEFAULT_IMAGE_PATCH_TOKEN], special_tokens=True)
56
+ if mm_use_im_start_end:
57
+ tokenizer.add_tokens([DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN], special_tokens=True)
58
+ model.resize_token_embeddings(len(tokenizer))
59
+
60
+ vision_tower = model.get_vision_tower()
61
+ if not vision_tower.is_loaded:
62
+ vision_tower.load_model(device_map="auto")
63
+ image_processor = vision_tower.image_processor
64
+ return tokenizer, model, image_processor
65
+
66
+ def get_llava_response(user_prompts: list[str],
67
+ images: list,
68
+ sys_prompt: str,
69
+ tokenizer,
70
+ model,
71
+ image_processor,
72
+ model_path = REPO_NAME,
73
+ stream_output = True):
74
+ """
+ Return the model's response for each image. For every image a fresh one-turn
+ conversation is created whose only content is the system prompt and the
+ corresponding user message.
+
+ Parameters
+ ----------
+ user_prompts : list[str]
+ The prompts sent by the user, one per image.
+ images : list
+ The images to describe (PIL Images, or grayscale numpy arrays, which are converted).
+ sys_prompt : str
+ The prompt that sets the tone for the conversation.
+ tokenizer, model, image_processor
+ The objects returned by load_llava_checkpoint_hf.
+ model_path : str
+ The path to the merged checkpoint or base model.
+ stream_output : bool
+ If True, stream generated tokens to stdout via TextStreamer.
+
+ Returns
+ -------
+ list[str]
+ The decoded model output for each image.
+ """
92
+ # set up and load model
93
+ model_name = get_model_name_from_path(model_path)
94
+ temperature = 0.2 # default
95
+ max_new_tokens = 512 # default
96
+
97
+ # determine conversation type
98
+ if "llama-2" in model_name.lower():
99
+ conv_mode = "llava_llama_2"
100
+ elif "mistral" in model_name.lower():
101
+ conv_mode = "mistral_instruct"
102
+ elif "v1.6-34b" in model_name.lower():
103
+ conv_mode = "chatml_direct"
104
+ elif "v1" in model_name.lower():
105
+ conv_mode = "llava_v1"
106
+ elif "mpt" in model_name.lower():
107
+ conv_mode = "mpt"
108
+ else:
109
+ conv_mode = "llava_v0"
110
+
111
+ # run clean conversation for each image
112
+ llm_outputs = []
113
+ for i, img in tqdm(enumerate(images)):
114
+ # set up clean conversation
115
+ conv = conv_templates[conv_mode].copy()
116
+ if "mpt" in model_name.lower():
117
+ roles = ('user', 'assistant')
118
+ else:
119
+ roles = conv.roles
120
+
121
+ conv.system = sys_prompt
122
+
123
+ # load image
124
+ # image = load_image("../images/mouse.png") # previous method
125
+ if isinstance(img, np.ndarray) and len(img.shape) == 2:
126
+ img = Image.fromarray(img, 'L')
127
+
128
+ image = img.convert('RGB')
129
+ image_size = image.size
130
+
131
+ # NOTE: image is simply PIL Image (.convert('RGB')), no need for temp files!
132
+
133
+ # Similar operation in model_worker.py
134
+ image_tensor = process_images([image], image_processor, model.config)
135
+ if type(image_tensor) is list:
136
+ image_tensor = [image.to(model.device, dtype=torch.float16) for image in image_tensor]
137
+ else:
138
+ image_tensor = image_tensor.to(model.device, dtype=torch.float16)
139
+
140
+ # execute conversation
141
+ inp = user_prompts[i]
142
+ if image is not None:
143
+ # first message
144
+ if model.config.mm_use_im_start_end:
145
+ inp = DEFAULT_IM_START_TOKEN + DEFAULT_IMAGE_TOKEN + DEFAULT_IM_END_TOKEN + '\n' + inp
146
+ else:
147
+ inp = DEFAULT_IMAGE_TOKEN + '\n' + inp
148
+ image = None
149
+ conv.append_message(conv.roles[0], inp)
150
+ conv.append_message(conv.roles[1], None)
151
+ prompt = conv.get_prompt()
152
+ input_ids = tokenizer_image_token(prompt,
153
+ tokenizer,
154
+ IMAGE_TOKEN_INDEX,
155
+ return_tensors='pt').unsqueeze(0).to(model.device)
156
+ stop_str = conv.sep if conv.sep_style != SeparatorStyle.TWO else conv.sep2
157
+ keywords = [stop_str]
158
+ if stream_output:
159
+ streamer = TextStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
160
+ else:
161
+ streamer = None
162
+
163
+ with torch.inference_mode():
164
+ output_ids = model.generate(
165
+ input_ids,
166
+ images=image_tensor,
167
+ image_sizes=[image_size],
168
+ do_sample=True if temperature > 0 else False,
169
+ temperature=temperature,
170
+ max_new_tokens=max_new_tokens,
171
+ streamer=streamer,
172
+ use_cache=True)
173
+
174
+ outputs = tokenizer.decode(output_ids[0]).strip()
175
+ llm_outputs.append(outputs)
176
+ return llm_outputs
177
+
178
+
179
+
180
+
181
+
182
+
183
+
184
+
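A minimal usage sketch of this module (it mirrors how explore.py calls it; the token and image path below are placeholders):

    from PIL import Image
    from get_llava_response import REPO_NAME, get_llava_response, load_llava_checkpoint_hf

    tokenizer, model, image_processor = load_llava_checkpoint_hf(REPO_NAME, hf_token="hf_...")
    frame = Image.open("frame.png")  # placeholder: a frame exported from a video
    outputs = get_llava_response(["Describe the behavior of the mice in this frame."],
                                 [frame],
                                 "You are a researcher studying mice interactions.",
                                 tokenizer, model, image_processor,
                                 model_path=REPO_NAME, stream_output=False)
    print(outputs[0])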
home.py ADDED
@@ -0,0 +1,5 @@
+ import streamlit as st
+
+ st.title('batik')
+ st.subheader('Abstract')
+ st.markdown('Quantitative analysis of animal behavior represents a burgeoning frontier in neuroscience and ethology. Recent years have witnessed a proliferation of computational methods aimed at identifying behavioral subtypes, or "syllables," from video data. However, while significant advances have been made in behavior segmentation, comparatively few approaches address the interpretation of these behavior syllables, leaving researchers to spend considerable time curating and interpreting the characteristics of the behavioral subtype. Furthermore, most current techniques rely heavily on pose estimation—a prerequisite that, while useful, can introduce limitations concerning generalization in behavioral classification and discovery. Here, we introduce Batik, a system leveraging pre-trained and fine-tuned multimodal transformers to perform end-to-end behavior analysis directly from raw video. Batik excels at supervised behavior annotation, utilizing lightweight models trained on the transformer-extracted feature space to achieve state-of-the-art performance. By integrating a pre-trained vision transformer with a custom fine-tuned language model, Batik not only discovers behavior syllables but also provides expert-level interpretations of mouse behavior, directly from visual data. This comprehensive platform empowers researchers with automated behavior discovery and interpretation, significantly reducing the time burden on experimentalists. Coupled with an intuitive user interface, Batik offers a transformative tool for the next generation of behavioral analysis, showcasing the potential of what is possible with transformer-based language models for behavior.')
pyproject.toml ADDED
@@ -0,0 +1,43 @@
+ [build-system]
+ requires = ["setuptools>=61.0"]
+ build-backend = "setuptools.build_meta"
+
+ [project]
+ name = "llava"
+ version = "1.2.2.post1"
+ description = "Towards GPT-4 like large language and visual assistant."
+ readme = "README.md"
+ requires-python = ">=3.8"
+ classifiers = [
+     "Programming Language :: Python :: 3",
+     "License :: OSI Approved :: Apache Software License",
+ ]
+ dependencies = [
+     "torch==2.1.2", "torchvision==0.16.2",
+     "transformers==4.37.2", "tokenizers==0.15.1", "sentencepiece==0.1.99", "shortuuid",
+     "accelerate==0.21.0", "peft", "bitsandbytes",
+     "pydantic", "markdown2[all]", "numpy", "scikit-learn==1.2.2",
+     "gradio==4.16.0", "gradio_client==0.8.1",
+     "requests", "httpx==0.24.0", "uvicorn", "fastapi",
+     "einops==0.6.1", "einops-exts==0.0.4", "timm==0.6.13",
+     "protobuf", "timecode", "sortedcontainers", "qtpy", "pyqt5-tools",
+     "scipy", "matplotlib", "colour_demosaicing", "sk-video",
+     "opencv-python", "progressbar", "openai",
+     "clip @ git+https://github.com/openai/CLIP@main",
+     "scikit-learn", "tensorflow", "sentencepiece", "streamlit",
+     "hdbscan", "plotly", "ipywidgets"
+ ]
+
+ [project.optional-dependencies]
+ train = ["deepspeed==0.12.6", "ninja", "wandb"]
+ build = ["build", "twine"]
+
+ [project.urls]
+ "Homepage" = "https://llava-vl.github.io"
+ "Bug Tracker" = "https://github.com/haotian-liu/LLaVA/issues"
+
+ [tool.setuptools.packages.find]
+ exclude = ["assets*", "benchmark*", "docs", "dist*", "playground*", "scripts*", "tests*"]
+
+ [tool.wheel]
+ exclude = ["assets*", "benchmark*", "docs", "dist*", "playground*", "scripts*", "tests*"]
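This pyproject.toml is the LLaVA package configuration bundled with the app; assuming a standard setuptools workflow, the dependencies above would be installed with an editable install such as `pip install -e .` (and `pip install -e ".[train]"` for the optional training extras).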
train_model.py ADDED
@@ -0,0 +1,159 @@
1
+ import os
2
+ import io
3
+ import pickle
4
+ import regex
5
+ import streamlit as st
6
+ import plotly.express as px
7
+ import numpy as np
8
+ import pandas as pd
9
+ import torch
10
+ from utils.seqIo import seqIo_reader
11
+ import pandas as pd
12
+ from PIL import Image
13
+ from pathlib import Path
14
+ from transformers import AutoProcessor, AutoModel
15
+ from tqdm import tqdm
16
+ from sklearn.svm import SVC
17
+ from sklearn.model_selection import train_test_split
18
+ from sklearn.metrics import accuracy_score, classification_report
19
+ from utils.utils import create_embeddings_csv_io, process_dataset_in_mem, generate_embeddings_stream_io
20
+
21
+ # --server.maxUploadSize 3000
22
+
23
+ def get_unique_labels(label_list: list[str]):
24
+ label_set = set()
25
+ for label in label_list:
26
+ individual_labels = label.split('||')
27
+ for individual_label in individual_labels:
28
+ label_set.add(individual_label)
29
+ return list(label_set)
30
+
31
+ @st.cache_data
32
+ def get_train_test_split(train_embeds, numerical_labels, test_size=0.05, random_state=42):
33
+ return train_test_split(train_embeds, numerical_labels, test_size=test_size, random_state=random_state)
34
+
35
+ @st.cache_resource
36
+ def train_model(X_train, y_train, random_state=42):
37
+ # Train SVM Classifier
38
+ svm_clf = SVC(kernel='rbf', random_state=random_state, probability=True)
39
+ svm_clf.fit(X_train, y_train)
40
+ return svm_clf
41
+
42
+ def pickle_model(model):
43
+ pickled = io.BytesIO()
44
+ pickle.dump(model, pickled)
45
+ return pickled
46
+
47
+ if "embeddings_df" not in st.session_state:
48
+ st.session_state.embeddings_df = None
49
+
50
+ if "svm_clf" not in st.session_state:
51
+ st.session_state.svm_clf = None
52
+ st.session_state.report_df = None
53
+ st.session_state.accuracy = None
54
+
55
+ st.title('batik: frame classifier training')
56
+
57
+ st.text("Upload files to train classifier on.")
58
+ with st.form('embedding_generation_settings'):
59
+ seq_file = st.file_uploader("Choose a .seq File", type=['seq'])
60
+ annot_files = st.file_uploader("Choose an annotation File", type=['annot','csv'], accept_multiple_files=True)
61
+ downsample_rate = st.number_input('Downsample Rate',value=4)
62
+ submit_embed_settings = st.form_submit_button('Create Embeddings', type='secondary')
63
+
64
+ st.markdown("**(Optional)** Upload embeddings.")
65
+ embeddings_csv = st.file_uploader("Choose a .csv File", type=['csv'])
66
+
67
+ if submit_embed_settings and seq_file is not None and annot_files is not None:
68
+ video_embeddings, video_frames = generate_embeddings_stream_io([seq_file],
69
+ "SLIP",
70
+ downsample_rate,
71
+ False)
72
+
73
+ fnames = [seq_file.name]
74
+ embeddings_df = create_embeddings_csv_io(out="file",
75
+ fnames=fnames,
76
+ embeddings=video_embeddings,
77
+ frames=video_frames,
78
+ annotations=[annot_files],
79
+ test_fnames=None,
80
+ views=None,
81
+ conditions=None,
82
+ downsample_rate=downsample_rate)
83
+ st.session_state.embeddings_df = embeddings_df
84
+
85
+ elif embeddings_csv is not None:
86
+ embeddings_df = pd.read_csv(embeddings_csv)
87
+ st.session_state.embeddings_df = embeddings_df
88
+ else:
89
+ st.text('Please upload file(s).')
90
+
91
+ st.divider()
92
+
93
+ if st.session_state.embeddings_df is not None:
94
+ st.subheader("specify dataset preprocessing options")
95
+ st.text("Select frames with label(s) to include:")
96
+
97
+ with st.form('train_settings'):
98
+ label_list = st.session_state.embeddings_df['Label'].to_list()
99
+ unique_label_list = get_unique_labels(label_list)
100
+ specified_classes = st.multiselect("Label(s) included:", options=unique_label_list)
101
+
102
+ st.text("Select label(s) that should be removed:")
103
+ classes_to_remove = st.multiselect("Label(s) excluded:", options=unique_label_list)
104
+
105
+ max_class_size = st.number_input("(Optional) Specify max class size:", value=None)
106
+
107
+ shuffle_data = st.toggle("Shuffle data:")
108
+
109
+ train_model_clicked = st.form_submit_button("Train Model")
110
+
111
+ if train_model_clicked:
112
+ kwargs = {'embeddings_df' : st.session_state.embeddings_df,
113
+ 'specified_classes' : specified_classes,
114
+ 'classes_to_remove' : classes_to_remove,
115
+ 'max_class_size' : max_class_size,
116
+ 'animal_state' : None,
117
+ 'view' : None,
118
+ 'shuffle_data' : shuffle_data,
119
+ 'test_videos' : None}
120
+ train_embeds, train_labels, train_images, _, _, _ = process_dataset_in_mem(**kwargs)
121
+ # Convert labels to numerical values
122
+ label_to_appear_first = 'other'
123
+ unique_labels = set(train_labels)
124
+ unique_labels.discard(label_to_appear_first)
125
+
126
+ label_to_index = {label_to_appear_first: 0}
127
+
128
+ label_to_index.update({label: idx + 1 for idx, label in enumerate(unique_labels)})
129
+ index_to_label = {idx: label for label, idx in label_to_index.items()}
130
+ numerical_labels = np.array([label_to_index[label] for label in train_labels])
131
+
132
+ print("Label Valence: ", label_to_index)
133
+ # Split data into train and test sets
134
+ X_train, X_test, y_train, y_test = get_train_test_split(train_embeds, numerical_labels, test_size=0.05, random_state=42)
135
+ with st.spinner("Model training in progress..."):
136
+ svm_clf = train_model(X_train, y_train)
137
+
138
+ # Predict on the test set
139
+ with st.spinner("In progress..."):
140
+ y_pred = svm_clf.predict(X_test)
141
+ accuracy = accuracy_score(y_test, y_pred)
142
+ report = classification_report(y_test, y_pred, target_names=[index_to_label[idx] for idx in range(len(label_to_index))], output_dict=True)
143
+ report_df = pd.DataFrame(report).transpose()
144
+
145
+ # save results to session state
146
+ st.session_state.svm_clf = svm_clf
147
+ st.session_state.report_df = report_df
148
+ st.session_state.accuracy = accuracy
149
+
150
+ if st.session_state.svm_clf is not None:
151
+ pickled_model = pickle_model(st.session_state.svm_clf)
152
+
153
+ st.text(f"Eval Accuracy: {st.session_state.accuracy}")
154
+ st.subheader("Classification Report:")
155
+ st.dataframe(st.session_state.report_df)
156
+
157
+ st.download_button("Download model as .pkl file",
158
+ data=pickled_model,
159
+ file_name=f"{'_'.join(specified_classes)}_classifier.pkl")
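A minimal sketch of reusing a downloaded classifier outside the app (apply_model.py does the equivalent after a file upload; the file name and the embedding matrix below are placeholders):

    import pickle

    import numpy as np

    with open("sniff_classifier.pkl", "rb") as f:   # placeholder file name
        svm_clf = pickle.load(f)

    embeddings = np.load("embeddings.npy")          # rows = frames, columns = embedding dimensions
    predictions = svm_clf.predict(embeddings)
    probabilities = svm_clf.predict_proba(embeddings)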
utils/__init__.py ADDED
File without changes
utils/annot.py ADDED
@@ -0,0 +1,641 @@
1
+ # annot.py
2
+ from io import StringIO
3
+ from random import sample
4
+ from collections import OrderedDict
5
+ import timecode as tc
6
+ from .behavior import Behavior
7
+ from sortedcontainers import SortedKeyList
8
+ from qtpy.QtCore import QObject, QRectF, Signal, Slot
9
+ from qtpy.QtGui import QColor
10
+ from qtpy.QtWidgets import QGraphicsItem
11
+
12
+ class Bout(object):
13
+ """A single behavior bout: a start time, an end time, and the Behavior it belongs to."""
15
+
16
+ def __init__(self, start, end, behavior):
17
+ self._start = start
18
+ self._end = end
19
+ self._behavior = behavior
20
+
21
+ def __lt__(self, b):
22
+ if type(b) is tc.Timecode:
23
+ return self._start.float < b.float
24
+ elif type(b) is Bout:
25
+ return self._start < b._start
26
+ else:
27
+ raise NotImplementedError(f"Comparing Bout with {type(b)} not supported")
28
+
29
+ def __le__(self, b):
30
+ if type(b) is tc.Timecode:
31
+ return self._start.float <= b.float  # compare Timecode values, as in __lt__
32
+ elif type(b) is Bout:
33
+ return self._start <= b._start
34
+ else:
35
+ raise NotImplementedError(f'Comparing Bout with {type(b)} not supported')
36
+
37
+ def is_at(self, t):
38
+ return self._start <= t and self._end >= t
39
+
40
+ def start(self):
41
+ return self._start
42
+
43
+ def set_start(self, start):
44
+ self._start = start
45
+
46
+ def end(self):
47
+ return self._end
48
+
49
+ def set_end(self, end):
50
+ self._end = end
51
+
52
+ def len(self):
53
+ return self._end - self._start + tc.Timecode(self._start.framerate, frames=1)
54
+
55
+ def behavior(self):
56
+ return self._behavior
57
+
58
+ def name(self):
59
+ return self._behavior.get_name()
60
+
61
+ def color(self):
62
+ return self._behavior.get_color()
63
+
64
+ def is_active(self):
65
+ return self._behavior.is_active()
66
+
67
+ def is_visible(self):
68
+ return self._behavior.is_visible() and self._behavior.is_active()
69
+
70
+ def __repr__(self):
71
+ return f"Bout: start = {self.start()}, end = {self.end()}, behavior: {self.behavior()}"
72
+
73
+ class Channel(QGraphicsItem):
74
+ """A timeline of Bouts for one annotation channel, kept sorted by start and by end time."""
76
+
77
+ contentChanged = Signal()
78
+
79
+ def __init__(self, chan = None):
80
+ super().__init__()
81
+ if chan is not None:
82
+ self._bouts_by_start = chan._bouts_by_start
83
+ self._bouts_by_end = chan._bouts_by_end
84
+ self._top = chan._top
85
+ else:
86
+ self._bouts_by_start = SortedKeyList(key=lambda bout: bout.start().float)
87
+ self._bouts_by_end = SortedKeyList(key=lambda bout: bout.end().float)
88
+ self._top = 0.
89
+ self.cur_ix = 0
90
+ self.fakeFirstBout = Bout(
91
+ tc.Timecode('30.0', '0:0:0:0'),
92
+ tc.Timecode('30.0', '0:0:0:0'),
93
+ Behavior('none', '', QColor.fromRgbF(0., 0., 0.), active = True)
94
+ )
95
+ self.fakeLastBout = Bout(
96
+ tc.Timecode('30.0', '23:59:59:29'),
97
+ tc.Timecode('30.0', '23:59:59:29'),
98
+ Behavior('none', '', QColor.fromRgbF(0., 0., 0.), active = True)
99
+ )
100
+
101
+ def add(self, b):
102
+ if isinstance(b, Bout):
103
+ self._bouts_by_start.add(b)
104
+ self._bouts_by_end.add(b)
105
+ else:
106
+ raise TypeError("Can only add Bouts to Channel")
107
+
108
+ def remove(self, b):
109
+ # can raise ValueError if b is not in the channel
110
+ self._bouts_by_start.remove(b)
111
+ self._bouts_by_end.remove(b)
112
+
113
+ def update_start(self, b, new_start):
114
+ """
115
+ Update the starting time of a bout while
116
+ preserving the _bouts_by_start access order
117
+ """
118
+ self._bouts_by_start.remove(b)
119
+ b.set_start(new_start)
120
+ self._bouts_by_start.add(b)
121
+
122
+ def update_end(self, b, new_end):
123
+ """
124
+ Update the ending time of a bout while
125
+ preserving the _bouts_by_end access order
126
+ """
127
+ self._bouts_by_end.remove(b)
128
+ b.set_end(new_end)
129
+ self._bouts_by_end.add(b)
130
+
131
+ def __add__(self, b):
132
+ self.add(b)
133
+
134
+ def _get_next(self, t, sortedlist):
135
+ ix = sortedlist.bisect_key_right(t.float)
136
+ l = len(sortedlist)
137
+ if ix == l:
138
+ return self.fakeLastBout, t.next()
139
+ return sortedlist[ix], t.next()
140
+
141
+ def _get_prev(self, t, sortedlist):
142
+ ix = sortedlist.bisect_key_left(t.float)
143
+ if ix == 0:
144
+ return self.fakeFirstBout, t.back()
145
+ return sortedlist[ix-1], t.back()
146
+
147
+ def _get_inner(self, t, sortedList, op):
148
+ t_local = t + 0 # kludgy copy constructor!
149
+ visible = False
150
+ while not visible:
151
+ # no need to check for the end, because the fake first and last bouts are visible
152
+ bout, t_local = op(t_local, sortedList)
153
+ visible = bout.is_visible()
154
+ return bout
155
+
156
+ def get_next_start(self, t):
157
+ return self._get_inner(t, self._bouts_by_start, self._get_next)
158
+
159
+ def get_next_end(self, t):
160
+ return self._get_inner(t, self._bouts_by_end, self._get_next)
161
+
162
+ def get_prev_start(self, t):
163
+ return self._get_inner(t, self._bouts_by_start, self._get_prev)
164
+
165
+ def get_prev_end(self, t):
166
+ return self._get_inner(t, self._bouts_by_end, self._get_prev)
167
+
168
+ def get_in_range(self, start, end):
169
+ """
170
+ get all bouts that intersect the range [start, end]
171
+ """
172
+ return [bout for bout in self._bouts_by_start
173
+ if bout.start().float <= end.float and bout.end().float >= start.float]
174
+
175
+ def get_at(self, t):
176
+ """
177
+ get all bouts that span time t
178
+ """
179
+ return self.get_in_range(t, t)
180
+
181
+ def __iter__(self):
182
+ return iter(self._bouts_by_start)
183
+
184
+ def irange(self, start_time, end_time):
185
+ if isinstance(start_time, tc.Timecode):
186
+ start_time = start_time.float
187
+ if isinstance(end_time, tc.Timecode):
188
+ end_time = end_time.float
189
+ if not isinstance(start_time, float):
190
+ raise TypeError(f"Can't handle start_time of type {type(start_time)}")
191
+ if not isinstance(end_time, float):
192
+ raise TypeError(f"Can't handle end_time of type {type(end_time)}")
193
+ return self._bouts_by_start.irange_key(start_time, end_time)
194
+
195
+ def top(self):
196
+ return self._top
197
+
198
+ def set_top(self, top):
199
+ self._top = top
200
+
201
+ def boundingRect(self):
202
+ width = self.fakeLastBout.end().float
203
+ return QRectF(0., self.top(), width, 1.)
204
+
205
+ def paint(self, painter, option, widget=None):
206
+ boundingRect = option.rect
207
+ in_range_bouts = self._bouts_by_start.irange_key(boundingRect.left(), boundingRect.right())
208
+ while True:
209
+ try:
210
+ bout = next(in_range_bouts)
211
+ except StopIteration:
212
+ break
213
+ if bout.is_visible():
214
+ painter.fillRect(
215
+ QRectF(bout.start().float, self.top(), bout.len().float, 1.),
216
+ bout.color()
217
+ )
218
+
219
+ def _delete_all_inner(self, predicate):
220
+ to_delete = list()
221
+ # can't alter the bouts within the iterator
222
+ for bout in iter(self):
223
+ if predicate(bout):
224
+ to_delete.append(bout)
225
+ deleted_names = set()
226
+ for bout in to_delete:
227
+ deleted_names.add(bout.name())
228
+ self.remove(bout)
229
+ return deleted_names
230
+
231
+ def delete_bouts_by_name(self, behavior_name):
232
+ return self._delete_all_inner(lambda bout: bout.name() == behavior_name)
233
+
234
+ def delete_inactive_bouts(self):
235
+ return self._delete_all_inner(lambda bout: not bout.is_active())
236
+
237
+ def truncate_or_remove_bouts(self, behavior, start, end, delete_all=False):
238
+ """
239
+ Delete bouts entirely within the range [start, end], and
240
+ truncate bouts that extend outside the range.
241
+ If behavior matches _deleteBehavior, the activity affects
242
+ all bouts. Otherwise, it only affects bouts with matching behavior.
243
+ """
244
+ items = self.get_in_range(start, end)
245
+ for item in items:
246
+ if not delete_all and behavior.get_name() != item.name():
247
+ continue
248
+ # Delete bouts that are entirely within the range
249
+ if item.start() >= start and item.end() < end:
250
+ print(f"removing {item} from active channel")
251
+ self.remove(item)
252
+
253
+ # Truncate and duplicate bouts that extend out both sides of the range
254
+ if item.start() < start and item.end() > end:
255
+ self.add(Bout(end, item.end(), item.behavior()))
256
+ self.update_end(item, start)
257
+
258
+ # Truncate bouts at the start boundary that start before the range
259
+ elif item.start() < start and item.end() <= end:
260
+ self.update_end(item, start)
261
+
262
+ # Truncate bouts at the end boundary that end after the range
263
+ elif item.start() >= start and item.end() > end:
264
+ self.update_start(item, end)
265
+
266
+ else:
267
+ print(f"truncate_or_delete_bouts: Unexpected bout {item}")
268
+
269
+ def coalesce_bouts(self, start, end):
270
+ """
271
+ combine overlapping bouts of the same behavior within [start, end]
272
+ """
273
+ to_delete = []
274
+ items = self.get_in_range(start, end)
275
+ # items will be ordered by start time
276
+ for ix, first in enumerate(items):
277
+ if first in to_delete:
278
+ # previously coalesced
279
+ continue
280
+ if ix == len(items)-1:
281
+ break
282
+ for second in items[ix+1:]:
283
+ if (first.name() == second.name() and
284
+ first.end() >= second.start()):
285
+ if first.end() < second.end():
286
+ self.update_end(first, second.end())
287
+ to_delete.append(second)
288
+ for item in to_delete:
289
+ self.remove(item)
290
+
291
+ class Annotations(QObject):
292
+ """
293
+ """
294
+
295
+ # Signals
296
+ annotations_changed = Signal()
297
+ active_annotations_changed = Signal()
298
+
299
+ def __init__(self, behaviors):
300
+ super().__init__()
301
+ self._channels = OrderedDict()
302
+ self._behaviors = behaviors
303
+ self._movies = []
304
+ self._start_frame = None
305
+ self._end_frame = None
306
+ self._sample_rate = None
307
+ self._stimulus = None
308
+ self._format = None
309
+ self.annotation_names = []
310
+ behaviors.behaviors_changed.connect(self.note_annotations_changed)
311
+
312
+ def read(self, fn):
313
+ with open(fn, "r") as f:
314
+ line = f.readline()
315
+ line = line.strip().lower()
316
+ if line.endswith("annotation file"):
317
+ self._format = 'Caltech'
318
+ self._read_caltech(f)
319
+ elif line.startswith("scorevideo log"):
320
+ self._format = 'Ethovision'
321
+ self._read_ethovision(f)
322
+ else:
323
+ print("Unsupported annotation file format")
324
+
325
+ def read_io(self, uploaded_file):
326
+ text_str = uploaded_file.getvalue().decode("utf-8")
327
+ with StringIO(text_str) as f:
328
+ f.__setattr__('name', uploaded_file.name)
329
+ line = f.readline()
330
+ line = line.strip().lower()
331
+ if line.endswith("annotation file"):
332
+ self._format = 'Caltech'
333
+ self._read_caltech(f)
334
+ elif line.startswith("scorevideo log"):
335
+ self._format = 'Ethovision'
336
+ self._read_ethovision(f)
337
+ else:
338
+ print("Unsupported annotation file format")
339
+
340
+ def _read_caltech(self, f):
341
+ found_movies = False
342
+ found_timecode = False
343
+ found_stimulus = False
344
+ found_channel_names = False
345
+ found_annotation_names = False
346
+ found_all_channels = False
347
+ found_all_annotations = False
348
+ new_behaviors_activated = False
349
+ reading_channel = False
350
+ to_activate = []
351
+ channel_names = []
352
+ current_channel = None
353
+ current_bout = None
354
+
355
+ self._format = 'Caltech'
356
+
357
+ line = f.readline()
358
+ while line:
359
+ if found_annotation_names and not new_behaviors_activated:
360
+ self.ensure_and_activate_behaviors(to_activate)
361
+ new_behaviors_activated = True
362
+
363
+ line = line.strip()
364
+
365
+ if not line:
366
+ if reading_channel:
367
+ reading_channel = False
368
+ current_channel = None
369
+ current_bout = None
370
+ elif line.lower().startswith("movie file"):
371
+ items = line.split()
372
+ for item in items:
373
+ if item.lower().startswith("movie"):
374
+ continue
375
+ if item.lower().startswith("file"):
376
+ continue
377
+ self._movies.append(item)
378
+ found_movies = True
379
+ elif line.lower().startswith("stimulus name"):
380
+ # TODO: do something when we know what one of these looks like
381
+ found_stimulus = True
382
+ elif line.lower().startswith("annotation start frame") or line.lower().startswith("annotation start time"):
383
+ items = line.split()
384
+ if len(items) > 3:
385
+ try:
386
+ self._start_frame = int(items[3])
387
+ except:
388
+ self._start_frame = int(float(items[3]))
389
+ if self._end_frame and self._sample_rate:
390
+ found_timecode = True
391
+ elif line.lower().startswith("annotation stop frame") or line.lower().startswith("annotation stop time"):
392
+ items = line.split()
393
+ if len(items) > 3:
394
+ try:
395
+ self._end_frame = int(items[3])
396
+ except:
397
+ self._end_frame = int(float(items[3]))
398
+ if self._start_frame and self._sample_rate:
399
+ found_timecode = True
400
+ elif line.lower().startswith("annotation framerate"):
401
+ items = line.split()
402
+ if len(items) > 2:
403
+ self._sample_rate = float(items[2])
404
+ if self._start_frame and self._end_frame:
405
+ found_timecode = True
406
+ elif line.lower().startswith("list of channels"):
407
+ line = f.readline()
408
+ while line:
409
+ line = line.strip()
410
+ if not line:
411
+ break # blank line -- end of section
412
+ channel_names.append(line)
413
+ line = f.readline()
414
+ found_channel_names = True
415
+ elif line.lower().startswith("list of annotations"):
416
+ line = f.readline()
417
+ while line:
418
+ line = line.strip()
419
+ if not line:
420
+ break # blank line -- end of section
421
+ to_activate.append(line)
422
+ line = f.readline().strip()
423
+ found_annotation_names = True
424
+ elif line.strip().lower().endswith("---"):
425
+ for ch_name in channel_names:
426
+ if line.startswith(ch_name):
427
+ self._channels[ch_name] = Channel()
428
+ current_channel = ch_name
429
+ reading_channel = True
430
+ break
431
+ if reading_channel:
432
+ reading_annot = False
433
+ line = f.readline()
434
+ while line:
435
+ line = line.strip()
436
+ if not line: # blank line
437
+ if reading_annot:
438
+ reading_annot = False
439
+ current_bout = None
440
+ else:
441
+ # second blank line, so done with channel
442
+ reading_channel = False
443
+ current_channel = None
444
+ break
445
+ elif line.startswith(">"):
446
+ current_bout = line[1:]
447
+ reading_annot = True
448
+ elif line.lower().startswith("start"):
449
+ pass # skip a header line
450
+ else:
451
+ items = line.split()
452
+ is_float = '.' in items[0] or '.' in items[1] or '.' in items[2]
453
+ bout_start = items[0]
454
+ if float(bout_start) < 1:
455
+ bout_start = 1
456
+ bout_end = items[1]
457
+ if float(bout_end) < 1:
458
+ bout_end = 1
459
+ self.add_bout(
460
+ Bout(
461
+ tc.Timecode(self._sample_rate, start_seconds=float(bout_start)) if is_float
462
+ else tc.Timecode(self._sample_rate, frames=int(bout_start)),
463
+ tc.Timecode(self._sample_rate, start_seconds=float(bout_end)) if is_float
464
+ else tc.Timecode(self._sample_rate, frames=int(bout_end)),
465
+ self._behaviors.get(current_bout)),
466
+ current_channel)
467
+ line = f.readline()
468
+ line = f.readline()
469
+ print(f"Done reading Caltech annotation file {f.name}")
470
+ self.note_annotations_changed()
471
+
472
+ def write_caltech(self, f, video_files, stimulus):
473
+ if not f.writable():
474
+ raise RuntimeError("File not writable")
475
+ f.write("Bento annotation file\n")
476
+ f.write("Movie file(s):")
477
+ for file in video_files:
478
+ f.write(' ' + file)
479
+ f.write('\n\n')
480
+
481
+ f.write(f"Stimulus name: {stimulus}\n")
482
+ f.write(f"Annotation start frame: {self._start_frame}\n")
483
+ f.write(f"Annotation stop frame: {self._end_frame}\n")
484
+ f.write(f"Annotation framerate: {self._sample_rate}\n")
485
+ f.write("\n")
486
+
487
+ f.write("List of Channels:\n")
488
+ for ch in self.channel_names():
489
+ f.write(ch + "\n")
490
+ f.write("\n")
491
+
492
+ f.write("List of annotations:\n")
493
+ for annot in self.annotation_names:
494
+ f.write(annot + "\n")
495
+ f.write("\n")
496
+
497
+ for ch in self.channel_names():
498
+ by_name = {}
499
+ f.write(f"{ch}----------\n")
500
+
501
+ for bout in self.channel(ch):
502
+ if not by_name.get(bout.name()):
503
+ by_name[bout.name()] = []
504
+ by_name[bout.name()].append(bout)
505
+
506
+ for annot in by_name:
507
+ f.write(f">{annot}\n")
508
+ f.write("Start\tStop\tDuration\n")
509
+ for bout in by_name[annot]:
510
+ start = bout.start().frames
511
+ end = bout.end().frames
512
+ f.write(f"{start}\t{end}\t{end - start}\n")
513
+ f.write("\n")
514
+
515
+ f.write("\n")
516
+
517
+ f.close()
518
+ print(f"Done writing Caltech annotation file {f.name}")
519
+
520
+ def _read_ethovision(self, f):
521
+ print("Ethovision annotations not yet supported")
522
+
523
+ def clear_channels(self):
524
+ self._channels.clear()
525
+
526
+ def channel_names(self):
527
+ return list(self._channels.keys())
528
+
529
+ def channel(self, ch: str) -> Channel:
530
+ return self._channels[ch]
531
+
532
+ def addEmptyChannel(self, ch: str):
533
+ if ch not in self.channel_names():
534
+ self._channels[ch] = Channel()
535
+
536
+ def add_bout(self, bout, channel):
537
+ if bout.name() not in self.annotation_names:
538
+ self.annotation_names.append(bout.name())
539
+ self._channels[channel].add(bout)
540
+ if bout.end() > self.end_time():
541
+ self.set_end_frame(bout.end())
542
+
543
+ def start_time(self):
544
+ """
545
+ At some point we will need to support a start time distinct from
546
+ frame number, perhaps derived from the OS file modify time
547
+ or the start time of the corresponding video (or other media) file
548
+ """
549
+ if not self._start_frame or not self._sample_rate:
550
+ return tc.Timecode('30.0', '0:0:0:0')
551
+ return tc.Timecode(self._sample_rate, frames=self._start_frame)
552
+
553
+ def start_frame(self):
554
+ return self._start_frame
555
+
556
+ def set_start_frame(self, t):
557
+ if isinstance(t, int):
558
+ self._start_frame = t
559
+ elif isinstance(t, tc.Timecode):
560
+ self._start_frame = t.frames
561
+ else:
562
+ raise TypeError("Expected a frame number or Timecode")
563
+
564
+ def end_time(self):
565
+ if not self._end_frame or not self._sample_rate:
566
+ return tc.Timecode('30.0', '23:59:59:29')
567
+ return tc.Timecode(self._sample_rate, frames=self._end_frame)
568
+
569
+ def end_frame(self):
570
+ return self._end_frame
571
+
572
+ def set_end_frame(self, t):
573
+ if isinstance(t, int):
574
+ self._end_frame = t
575
+ elif isinstance(t, tc.Timecode):
576
+ self._end_frame = t.frames
577
+ else:
578
+ raise TypeError("Expected a frame number or Timecode")
579
+
580
+ def sample_rate(self):
581
+ return self._sample_rate
582
+
583
+ def set_sample_rate(self, sample_rate):
584
+ self._sample_rate = sample_rate
585
+
586
+ def format(self):
587
+ return self._format
588
+
589
+ def set_format(self, format):
590
+ self._format = format
591
+
592
+ def delete_bouts_by_name(self, behavior_name):
593
+ deleted_names = set()
594
+ for chan_name in self.channel_names():
595
+ deleted_names.update(self.channel(chan_name).delete_bouts_by_name(behavior_name))
596
+ for name in deleted_names:
597
+ self.annotation_names.remove(name)
598
+
599
+ def delete_inactive_bouts(self):
600
+ deleted_names = set()
601
+ for chan_name in self.channel_names():
602
+ deleted_names.update(self.channel(chan_name).delete_inactive_bouts())
603
+ for name in deleted_names:
604
+ self.annotation_names.remove(name)
605
+
606
+ def ensure_and_activate_behaviors(self, toActivate):
607
+ behaviorSetUpdated = False
608
+ for behaviorName in toActivate:
609
+ behaviorSetUpdated |= self._behaviors.addIfMissing(behaviorName)
610
+ self.annotation_names.append(behaviorName)
611
+ self._behaviors.get(behaviorName).set_active(True)
612
+ if behaviorSetUpdated:
613
+ self.annotations_changed.emit()
614
+ self.active_annotations_changed.emit()
615
+
616
+ def ensure_active_behaviors(self):
617
+ for behavior in self._behaviors:
618
+ if behavior.is_active() and behavior.get_name() not in self.annotation_names:
619
+ self.annotation_names.append(behavior.get_name())
620
+
621
+ def truncate_or_remove_bouts(self, behavior, start, end, chan):
622
+ """
623
+ Delete bouts entirely within the range [start, end], or
624
+ truncate bouts that extend outside the range.
625
+ If behavior matches _deleteBehavior, the activity affects
626
+ all bouts. Otherwise, it only affects bouts with matching behavior.
627
+ """
628
+ delete_all = (behavior.get_name() == self._behaviors.getDeleteBehavior().get_name())
629
+ self._channels[chan].truncate_or_remove_bouts(behavior, start, end, delete_all)
630
+ self.note_annotations_changed()
631
+
632
+ def coalesce_bouts(self, start, end, chan):
633
+ """
634
+ combine overlapping bouts of the same behavior within [start, end]
635
+ """
636
+ self._channels[chan].coalesce_bouts(start, end)
637
+ self.note_annotations_changed()
638
+
639
+ @Slot()
640
+ def note_annotations_changed(self):
641
+ self.annotations_changed.emit()
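A minimal usage sketch of the Annotations and Channel classes above, mirroring how utils/data_loading.py consumes them. The .annot path is hypothetical, and the utils package is assumed to be importable:

from utils.behavior import Behaviors
from utils.annot import Annotations

behaviors = Behaviors()
annotations = Annotations(behaviors)
annotations.read("example.annot")              # hypothetical Caltech-format file
for ch in annotations.channel_names():
    for bout in annotations.channel(ch):       # Channel iterates bouts by start time
        print(ch, bout.name(), bout.start().frames, bout.end().frames)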
utils/behavior.py ADDED
@@ -0,0 +1,291 @@
1
+ # behavior.py
2
+ """
3
+ Overview comment here
4
+ """
5
+
6
+ from qtpy.QtCore import QAbstractItemModel, QAbstractTableModel, QModelIndex, QObject, Qt, Signal, Slot
7
+ from qtpy.QtGui import QColor
8
+ import os
9
+
10
+ class Behavior(QObject):
11
+ """
12
+ An annotation behavior, which is quite simple. It comprises:
13
+ name The name of the behavior, which is displayed on various UI widgets
14
+ hot_key The case-sensitive key stroke used to start and stop instances of the behavior
15
+ color The color with which to display this behavior
16
+ """
17
+
18
+ def __init__(self, name: str, hot_key: str = '', color: QColor = QColor('gray'), active = False, visible = True):
19
+ super().__init__()
20
+ self._name = name
21
+ self._hot_key = '' if hot_key == '_' else hot_key
22
+ self._color = color
23
+ self._visible = visible
24
+ self._active = active
25
+ self._get_functions = {
26
+ 'hot_key': self.get_hot_key,
27
+ 'name': self.get_name,
28
+ 'color': self.get_color,
29
+ 'active': self.is_active,
30
+ 'visible': self.is_visible
31
+ }
32
+ self._set_functions = {
33
+ 'hot_key': self.set_hot_key,
34
+ 'name': self.set_name,
35
+ 'color': self.set_color,
36
+ 'active': self.set_active,
37
+ 'visible': self.set_visible
38
+ }
39
+
40
+ def __repr__(self):
41
+ return f"Behavior: name={self._name}, hot_key={self._hot_key}, color={self._color}, active={self._active}, visible={self._visible}"
42
+
43
+ def get(self, key):
44
+ # may raise KeyError
45
+ return self._get_functions[key]()
46
+
47
+ def set(self, key, value):
48
+ # may raise KeyError
49
+ self._set_functions[key](value)
50
+
51
+ def get_hot_key(self):
52
+ return self._hot_key
53
+
54
+ def set_hot_key(self, hot_key: str):
55
+ self._hot_key = hot_key
56
+
57
+ def get_color(self):
58
+ return self._color
59
+
60
+ def set_color(self, color: QColor):
61
+ self._color = color
62
+
63
+ def is_active(self):
64
+ return self._active
65
+
66
+ @Slot(bool)
67
+ def set_active(self, active):
68
+ self._active = active
69
+
70
+ def is_visible(self):
71
+ return self._visible
72
+
73
+ @Slot(bool)
74
+ def set_visible(self, visible):
75
+ self._visible = visible
76
+
77
+ def get_name(self):
78
+ return self._name
79
+
80
+ def set_name(self, name):
81
+ self._name = name
82
+
83
+ def toDict(self):
84
+ return {
85
+ 'hot_key': '_' if self._hot_key == '' else self._hot_key,
86
+ 'color': self._color,
87
+ 'name': self._name,
88
+ 'active': self._active,
89
+ 'visible': self._visible
90
+ }
91
+
92
+ class Behaviors(QAbstractTableModel):
93
+ """
94
+ A set of behaviors, which represent "all" possible behaviors.
95
+ The class supports reading from and writing to profile files that specify
96
+ the default hot_key and color for each behavior, along with the name.
97
+
98
+ Derives from QAbstractTableModel so that it can be viewed and edited
99
+ directly in a QTableView widget.
100
+
101
+ Use getattr(name) to get the Behavior instance for a given name.
102
+ Use from_hot_key(key) to get the behavior(s) given the hot key.
103
+ Returns None if the hot_key isn't defined.
104
+ """
105
+
106
+ behaviors_changed = Signal()
107
+ layout_changed = Signal()
108
+
109
+ def __init__(self):
110
+ super().__init__()
111
+ self._items = []
112
+ self._by_name = {}
113
+ self._by_hot_key = {}
114
+ self._header = ['hot_key', 'color', 'name', 'active', 'visible']
115
+ self._searchList = [self._by_hot_key, None, self._by_name, None, None]
116
+ self._delete_behavior = Behavior('_delete', color = QColor('black'))
117
+ self._immutableColumns = set()
118
+ self._booleanColumns = set([self._header.index('active'), self._header.index('visible')])
119
+ self._role_to_str = {
120
+ Qt.DisplayRole: "DisplayRole",
121
+ Qt.DecorationRole: "DecorationRole",
122
+ Qt.EditRole: "EditRole",
123
+ Qt.ToolTipRole: "ToolTipRole",
124
+ Qt.StatusTipRole: "StatusTipRole",
125
+ Qt.WhatsThisRole: "WhatsThisRole",
126
+ Qt.SizeHintRole: "SizeHintRole",
127
+ Qt.FontRole: "FontRole",
128
+ Qt.TextAlignmentRole: "TextAlignmentRole",
129
+ Qt.BackgroundRole: "BackgroundRole",
130
+ Qt.ForegroundRole: "ForegroundRole",
131
+ Qt.CheckStateRole: "CheckStateRole",
132
+ Qt.InitialSortOrderRole: "InitialSortOrderRole",
133
+ Qt.AccessibleTextRole: "AccessibleTextRole",
134
+ Qt.UserRole: "UserRole"
135
+ }
136
+
137
+ def add(self, beh: Behavior, row=-1):
138
+ if row < 0:
139
+ row = self.rowCount()
140
+ self.beginInsertRows(QModelIndex(), row, row)
141
+ self._items.insert(row, beh)
142
+ self._by_name[beh.get_name()] = beh
143
+ hot_key = beh.get_hot_key()
144
+ if hot_key:
145
+ if hot_key not in self._by_hot_key.keys():
146
+ self._by_hot_key[hot_key] = []
147
+ assert(isinstance(self._by_hot_key[hot_key], list))
148
+ self._by_hot_key[hot_key].append(beh)
149
+ self.endInsertRows()
150
+ self.dataChanged.emit(
151
+ self.index(row, 0, QModelIndex()),
152
+ self.index(row, self.columnCount()-1, QModelIndex()),
153
+ [Qt.DisplayRole, Qt.EditRole])
154
+ self.behaviors_changed.emit()
155
+
156
+ def load(self, f):
157
+ line = f.readline()
158
+ while line:
159
+ hot_key, name, r, g, b = line.strip().split(' ')
160
+ if hot_key == '_':
161
+ hot_key = ''
162
+ self.add(Behavior(name, hot_key, QColor.fromRgbF(float(r), float(g), float(b))))
163
+ line = f.readline()
164
+
165
+ def save(self, f):
166
+ for beh in self._items:
167
+ h = beh.get_hot_key()
168
+ if h == '':
169
+ h = '_'
170
+ color = beh.get_color()
171
+ f.write(f"{h} {beh.get_name()} {color.redF()} {color.greenF()} {color.blueF()}" + os.linesep)
172
+
173
+ def get(self, name):
174
+ if name not in self._by_name.keys():
175
+ return None
176
+ return self._by_name[name]
177
+
178
+ def from_hot_key(self, key):
179
+ """
180
+ Return the list of behaviors associated with this hot key, if any
181
+ """
182
+ try:
183
+ return self._by_hot_key[key]
184
+ except KeyError:
185
+ return None
186
+
187
+ def len(self):
188
+ return len(self._items)
189
+
190
+ def header(self):
191
+ return self._header
192
+
193
+ def colorColumns(self):
194
+ return [self._header.index('color')]
195
+
196
+ def __iter__(self):
197
+ return iter(self._items)
198
+
199
+ def getDeleteBehavior(self):
200
+ return self._delete_behavior
201
+
202
+ def addIfMissing(self, nameToAdd):
203
+ if nameToAdd not in self._by_name:
204
+ self.add(Behavior(nameToAdd, '', QColor('gray')))
205
+ return True
206
+ return False
207
+
208
+ def isImmutable(self, index):
209
+ return index.column() in self._immutableColumns
210
+
211
+ def setImmutable(self, column):
212
+ self._immutableColumns.add(column)
213
+
214
+ # QAbstractTableModel API methods
215
+
216
+ def headerData(self, col, orientation, role):
217
+ if orientation == Qt.Horizontal and role == Qt.DisplayRole:
218
+ return self._header[col]
219
+ return None
220
+
221
+ def rowCount(self, parent=None):
222
+ return len(self._items)
223
+
224
+ def columnCount(self, parent=None):
225
+ return len(self._header)
226
+
227
+ def data(self, index, role=Qt.DisplayRole):
228
+ datum = self._items[index.row()].get(self._header[index.column()])
229
+ if isinstance(datum, bool):
230
+ if role in [Qt.CheckStateRole, Qt.EditRole]:
231
+ return Qt.Checked if datum else Qt.Unchecked
232
+ return None
233
+ if role in [Qt.DisplayRole, Qt.EditRole]:
234
+ return self._items[index.row()].get(self._header[index.column()])
235
+ return None
236
+
237
+ def setData(self, index, value, role=Qt.EditRole):
238
+ if not role in [Qt.CheckStateRole, Qt.EditRole]:
239
+ return False
240
+ if role == Qt.CheckStateRole:
241
+ value = bool(value)
242
+ beh = self._items[index.row()]
243
+ key = self._header[index.column()]
244
+ name = beh.get_name()
245
+ hot_key = beh.get_hot_key()
246
+ beh.set(key, value)
247
+ if key == 'hot_key' and value != hot_key:
248
+ # disassociate this behavior from the hot_key
249
+ # and associate with the new hot_key if not ''
250
+ if hot_key != '':
251
+ del(self._by_hot_key[hot_key])
252
+ if value != '':
253
+ if value not in self._by_hot_key.keys():
254
+ self._by_hot_key[value] = []
255
+ assert(isinstance(self._by_hot_key[value], list))
256
+ self._by_hot_key[value].append(beh)
257
+ elif key == 'name' and value != name:
258
+ if name in self._by_name.keys():
259
+ del(self._by_name[name])
260
+ self._by_name[value] = beh
261
+ self.behaviors_changed.emit()
262
+ self.dataChanged.emit(index, index, [role])
263
+ return True
264
+
265
+ def insertRows(self, row, count, parent):
266
+ if count < 1 or row < 0 or row > self.rowCount():
267
+ return False
268
+ self.beginInsertRows(QModelIndex(), row, row)
269
+ for r in range(count):
270
+ self._items.insert(row, Behavior('', active=True))
271
+ self.endInsertRows()
272
+ return True
273
+
274
+ def removeRows(self, row, count, parent=QModelIndex()):
275
+ if count <= 0 or row < 0 or row + count > self.rowCount(parent):
276
+ return False
277
+ self.beginRemoveRows(parent, row, row + count - 1)
278
+ for item in self._items[row:row+count]:
279
+ self._by_name.pop(item.get_name(), None)
280
+ self._by_hot_key.pop(item.get_hot_key(), None)
281
+ for i in range(count):
282
+ self._items.pop(row)
283
+ self.endRemoveRows()
284
+
285
+ def flags(self, index):
286
+ f = super().flags(index)
287
+ if index.column() not in self._immutableColumns:
288
+ f |= Qt.ItemIsEditable
289
+ if index.column() in self._booleanColumns:
290
+ f = (f & ~(Qt.ItemIsSelectable | Qt.ItemIsEditable)) | Qt.ItemIsUserCheckable
291
+ return f
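A sketch of the profile round trip that load() and save() implement: one behavior per line, "<hot_key> <name> <redF> <greenF> <blueF>", with "_" standing in for an empty hot key. The file name and the "attack" behavior are made up for illustration:

from qtpy.QtGui import QColor
from utils.behavior import Behavior, Behaviors

behaviors = Behaviors()
behaviors.add(Behavior("attack", "a", QColor("red")))
with open("behaviors.profile", "w") as f:
    behaviors.save(f)                          # writes roughly: a attack 1.0 0.0 0.0
restored = Behaviors()
with open("behaviors.profile", "r") as f:
    restored.load(f)
print(restored.get("attack").get_hot_key())    # -> 'a'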
utils/data_loading.py ADDED
@@ -0,0 +1,198 @@
1
+ """File for loading data into AnimalEditor"""
2
+ import io
3
+ from random import random
4
+ from os.path import splitext
5
+ from collections import OrderedDict
6
+ import numpy as np
7
+ from tempfile import NamedTemporaryFile
8
+
9
+ from .annot import Annotations
10
+ from .behavior import Behaviors
11
+
12
+ def has_extension(fname:str, extension:str|list[str]) -> bool:
13
+ """
14
+ Checks to see if the passed in file name ends with an expected extension.
15
+ """
16
+ _, ext = splitext(fname)
17
+ if isinstance(extension, str):
18
+ return ext == extension
19
+ elif isinstance(extension, list):
20
+ return ext in extension
21
+
22
+ def _clean_annotations(annotations):
23
+ """
24
+ While reading in behaviors from an .annot file, sometimes channels without normally
25
+ callable keys appear (i.e. keys that are strings which name a behavior), thus this
26
+ code only accepts keys which are strings.
27
+ """
28
+ if not annotations:
29
+ raise ValueError('No annotations found.')
30
+ clean_annot = OrderedDict()
31
+ for channel in annotations.keys():
32
+ channel_dict = OrderedDict()
33
+ for behavior_name in annotations[channel].keys():
34
+ if isinstance(behavior_name, str):
35
+ channel_dict.update({behavior_name : annotations[channel][behavior_name]})
36
+ clean_annot.update({channel: channel_dict})
37
+ return clean_annot
38
+
39
+ def load_annot_sheet_txt(fname, offset = 0):
40
+ """
41
+ Generates a dictionary for retrieving the beginning and end frames of behaviors from
42
+ an .annot file.
43
+
44
+ Note that 0:00:00 is frame 1
45
+
46
+ Args:
47
+ fname - the path to the .annot file to be read (must be Caltech format)\n
48
+ offset - a value which offsets the start and end frame of each bout in
49
+ the sheet, as well as the absolute start and end frame of the file.
50
+ This value is optional, and is set to 0 by default
51
+
52
+ Returns:
53
+ annotations - dictionary of beginning and end frames for behaviors\n
54
+ start_time - the frame the movie started at (0:00:00 is 1)\n
55
+ end_time - the frame the movie ended at (0:00:00 is 1)\n
56
+ sample_rate - the sample rate reported within the file
57
+ """
58
+ # from bento for python
59
+ behaviors = Behaviors()
60
+ annot_sheet = Annotations(behaviors)
61
+ annot_sheet.read(fname)
62
+
63
+ sample_rate = annot_sheet.sample_rate()
64
+ annotations = OrderedDict()
65
+ for key in annot_sheet.channel_names():
66
+ annot_behaviors = OrderedDict()
67
+ bout_names = set()
68
+ for bout in annot_sheet.channel(key): #._bouts_by_start:
69
+ bout_names.add(bout.name())
70
+ for name in bout_names:
71
+ annot_behaviors.update({name : []})
72
+ for bout in annot_sheet.channel(key): #._bouts_by_start:
73
+ start_frame = bout.start().frames + offset
74
+ end_frame = bout.end().frames + offset
75
+ bout_frames = [start_frame, end_frame]
76
+ curr_table = annot_behaviors.get(bout.name())
77
+ curr_table.append(bout_frames)
78
+ annot_behaviors.update({bout.name() : curr_table})
79
+ for name in bout_names:
80
+ curr_table = annot_behaviors.get(name)
81
+ beh_array = np.array(curr_table)
82
+ annot_behaviors.update({name : beh_array})
83
+
84
+ annotations.update({key : annot_behaviors})
85
+ annotations = _clean_annotations(annotations)
86
+ start_time = annot_sheet.start_frame() + offset
87
+ end_time = annot_sheet.end_frame() + offset
88
+ return annotations, start_time, end_time, sample_rate
89
+
90
+ def load_multiple_annotations(fnames):
91
+ """
92
+ Generates a single dictionary given multiple .annot files.
93
+ """
94
+ if not isinstance(fnames, list):
95
+ raise TypeError(f'Expected list[str], got {type(fnames)} instead.')
96
+ if not fnames:
97
+ raise ValueError('No file names passed in.')
98
+ if len(fnames) == 1:
99
+ return load_annot_sheet_txt(fnames[0])
100
+ head_annot, head_start_frame, head_end_frame, sample_rate = load_annot_sheet_txt(fnames[0])
101
+ end_frame = head_end_frame
102
+ for fname in fnames[1:]:
103
+ curr_annot, _, curr_end_frame, _ = load_annot_sheet_txt(fname, end_frame)
104
+ end_frame = curr_end_frame
105
+ for channel in curr_annot.keys():
106
+ if channel not in head_annot:
107
+ channel_dict = {}
108
+ head_annot.update({channel : channel_dict})
109
+ for behavior in curr_annot[channel].keys():
110
+ curr_behavior_bout_array = curr_annot[channel][behavior]
111
+ if channel in head_annot and behavior in head_annot[channel]:
112
+ new_bout_array = np.vstack((head_annot[channel][behavior],
113
+ curr_behavior_bout_array))
114
+ else:
115
+ new_bout_array = curr_behavior_bout_array
116
+ head_annot[channel].update({behavior : new_bout_array})
117
+ return head_annot, head_start_frame, end_frame, sample_rate
118
+
119
+ def load_annot_sheet_txt_io(uploaded_file, offset = 0):
120
+ """
121
+ Generates a dictionary for retrieving the beginning and end frames of behaviors from
122
+ an .annot file.
123
+
124
+ Note that 0:00:00 is frame 1
125
+
126
+ Args:
127
+ uploaded_file - an uploaded .annot file object to be read (must be Caltech format)\n
128
+ offset - a value which offsets the start and end frame of each bout in
129
+ the sheet, as well as the absolute start and end frame of the file.
130
+ This value is optional, and is set to 0 by default
131
+
132
+ Returns:
133
+ annotations - dictionary of beginning and end frames for behaviors\n
134
+ start_time - the frame the movie started at (0:00:00 is 1)\n
135
+ end_time - the frame the movie ended at (0:00:00 is 1)\n
136
+ sample_rate - the sample rate reported within the file
137
+ """
138
+ # from bento for python
139
+ behaviors = Behaviors()
140
+ annot_sheet = Annotations(behaviors)
141
+
142
+ annot_sheet.read_io(uploaded_file)
143
+
144
+ sample_rate = annot_sheet.sample_rate()
145
+ annotations = OrderedDict()
146
+ for key in annot_sheet.channel_names():
147
+ annot_behaviors = OrderedDict()
148
+ bout_names = set()
149
+ for bout in annot_sheet.channel(key): #._bouts_by_start:
150
+ bout_names.add(bout.name())
151
+ for name in bout_names:
152
+ annot_behaviors.update({name : []})
153
+ for bout in annot_sheet.channel(key): #._bouts_by_start:
154
+ start_frame = bout.start().frames + offset
155
+ end_frame = bout.end().frames + offset
156
+ bout_frames = [start_frame, end_frame]
157
+ curr_table = annot_behaviors.get(bout.name())
158
+ curr_table.append(bout_frames)
159
+ annot_behaviors.update({bout.name() : curr_table})
160
+ for name in bout_names:
161
+ curr_table = annot_behaviors.get(name)
162
+ beh_array = np.array(curr_table)
163
+ annot_behaviors.update({name : beh_array})
164
+
165
+ annotations.update({key : annot_behaviors})
166
+ annotations = _clean_annotations(annotations)
167
+ start_time = annot_sheet.start_frame() + offset
168
+ end_time = annot_sheet.end_frame() + offset
169
+ return annotations, start_time, end_time, sample_rate
170
+
171
+ def load_multiple_annotations_io(uploaded_files):
172
+ """
173
+ Generates a single dictionary given multiple .annot files.
174
+ """
175
+ if not isinstance(uploaded_files, list):
176
+ raise TypeError(f'Expected list, got {type(uploaded_files)} instead.')
177
+ if not uploaded_files:
178
+ raise ValueError('No file names passed in.')
179
+ if len(uploaded_files) == 1:
180
+ return load_annot_sheet_txt_io(uploaded_files[0])
181
+ head_annot, head_start_frame, head_end_frame, sample_rate = load_annot_sheet_txt_io(uploaded_files[0])
182
+ end_frame = head_end_frame
183
+ for uploaded_file in uploaded_files[1:]:
184
+ curr_annot, _, curr_end_frame, _ = load_annot_sheet_txt_io(uploaded_file, end_frame)
185
+ end_frame = curr_end_frame
186
+ for channel in curr_annot.keys():
187
+ if channel not in head_annot:
188
+ channel_dict = {}
189
+ head_annot.update({channel : channel_dict})
190
+ for behavior in curr_annot[channel].keys():
191
+ curr_behavior_bout_array = curr_annot[channel][behavior]
192
+ if channel in head_annot and behavior in head_annot[channel]:
193
+ new_bout_array = np.vstack((head_annot[channel][behavior],
194
+ curr_behavior_bout_array))
195
+ else:
196
+ new_bout_array = curr_behavior_bout_array
197
+ head_annot[channel].update({behavior : new_bout_array})
198
+ return head_annot, head_start_frame, end_frame, sample_rate
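A sketch of the intended call pattern for the loaders above; the .annot file names are hypothetical, and each value in the returned dictionary is an N x 2 array of [start, end] frame pairs keyed by channel and behavior:

from utils.data_loading import load_multiple_annotations

annot, start_frame, end_frame, sample_rate = load_multiple_annotations(
    ["session1.annot", "session2.annot"])
for channel, channel_behaviors in annot.items():
    for behavior_name, bouts in channel_behaviors.items():
        print(channel, behavior_name, bouts.shape)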
utils/data_processing.py ADDED
@@ -0,0 +1,384 @@
1
+ """Provides the standard data processing functions performed on CNMFe and annotation data"""
2
+ import numpy as np
3
+ from scipy.signal import correlate
4
+ from scipy.stats import zscore
5
+
6
+ def smooth(data: np.ndarray, window_size=5):
7
+ """
8
+ Returns a smoothed version of response data using a moving average filter.
9
+
10
+ Parameters:
11
+ ----------
12
+ data : np.ndarray
13
+ A numpy 1-D array containing data to be smoothed.
14
+ window_size : int
15
+ Number of data points for calculating the smoothed value. If an even number is
16
+ passed in, window_size is automatically reduced by 1.
17
+
18
+ Returns:
19
+ --------
20
+ smooth_data : np.ndarray
21
+ Smoothed data, returned as a 1-D array of the same size as ``data``.
22
+
23
+ Notes:
24
+ ------
25
+ Implements MATLAB's smooth function.
26
+ """
27
+ if window_size == 0:
28
+ raise ValueError('window_size can not be 0.')
29
+ if window_size == 1:
30
+ return data
31
+ if window_size > data.size:
32
+ window_size = data.size
33
+ if window_size%2 == 0:
34
+ window_size = window_size - 1
35
+ outside_valid_window_size = int((window_size-1)/2)
36
+ start = np.array([np.sum(data[0:(2*k+1)]/(2*k+1)) for k in range(outside_valid_window_size)])
37
+ end = np.array([np.sum(data[-(2*k+1):]/(2*k+1)) for k in range(outside_valid_window_size)])[::-1]
38
+ smoothed_data = np.convolve(data,np.ones(window_size,dtype=int),'valid')/window_size
39
+ return np.hstack((start,smoothed_data,end))
40
+
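To make the endpoint handling concrete, a small worked check with arbitrary values: for window_size=5 the first output sample is data[0], the second is the mean of data[:3], and interior samples use the full 5-point window, matching MATLAB's smooth:

import numpy as np
data = np.array([1., 2., 6., 2., 1., 2., 6.])
out = smooth(data, window_size=5)    # smooth() as defined above
# out[0] == 1.0, out[1] == 3.0 (mean of [1, 2, 6]), out[2] == 2.4 (mean of the first 5 values)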
41
+ def corr(x: np.ndarray, y: np.ndarray):
42
+ """
43
+ Returns the pairwise linear correlation coefficient between the input
44
+ vectors x and y.
45
+
46
+ Parameters:
47
+ -----------
48
+ x : np.ndarray
49
+ Input matrix, specified as an n x k_1 matrix. Its rows correspond to
50
+ observations, and the columns correspond to variables.
51
+ y : np.ndarray
52
+ Input matrix, specified as an n x k_2 matrix. Its rows correspond to
53
+ observations, and the columns correspond to variables.
54
+
55
+ Returns:
56
+ --------
57
+ rho - Pairwise linear correlation coefficient, returned as a scalar.
58
+
59
+ Notes:
60
+ ------
61
+ Implements MATLAB's corr function.
62
+ """
63
+ return np.corrcoef(x,y)[0][1]
64
+
65
+ def autocorr(x:np.ndarray,
66
+ max_lags=10):
67
+ """
68
+ Returns the correlations and associated lags of the univariate time series x.
69
+
70
+ Parameters:
71
+ -----------
72
+ x : np.ndarray
73
+ Observed univariate time series.
74
+ max_lags : int
75
+ Number of lags, specified as a positive integer.
76
+
77
+ Returns:
78
+ acf : np.ndarray
79
+ Correlations, returned as a numeric vector of length ``max_lags`` + 1.
80
+ lags : np.ndarray
81
+ Autocorrelation lags.
82
+
83
+ Notes:
84
+ ------
85
+ Modified version of matplotlib's acorr function.
86
+ """
87
+ Nx = len(x)
88
+
89
+ correls = correlate(x, x, mode="full")
90
+ correls = correls / np.dot(x, x)
91
+
92
+ if max_lags is None:
93
+ max_lags = Nx - 1
94
+
95
+ if max_lags >= Nx or max_lags < 1:
96
+ raise ValueError('maxlags must be None or strictly '
97
+ 'positive < %d' % Nx)
98
+
99
+ lags = np.arange(-max_lags, max_lags + 1)
100
+ acf = correls[Nx - 1 - max_lags:Nx + max_lags]
101
+
102
+ return acf, lags
103
+
104
+ def convert_to_rast(behavior_ts, time_max):
105
+ """
106
+ Converts a list of behavior time stamps to a one-hot vector where 0 indicates no
107
+ presence of the given behavior, and 1 indicates presence of it.
108
+
109
+ Args:
110
+ behavior_ts - a list of time stamps (start and end) for a particular behavior\n
111
+ time_max - the length in frames of the vector
112
+
113
+ Returns:
114
+ behavior_rast - a one-hot vector
115
+ """
116
+ behavior_rast = np.zeros(time_max)
117
+ for time_stamps in behavior_ts:
118
+ start = int(round(time_stamps[0]))
119
+ end = int(round(time_stamps[1] + 1))
120
+ if start > time_max:
121
+ break
122
+ if end > time_max:
123
+ end = time_max
124
+ np.put(behavior_rast,range(start,end),np.ones(end-start))
125
+ return behavior_rast
126
+
127
+ def convert_to_raster(bouts: list,
128
+ neural_activity_sr: float,
129
+ observation_sr: float,
130
+ max_frame: int):
131
+ """
132
+ Converts bouts into a behavior raster, a one hot encoding of a behavior describing
133
+ when it is active.
134
+
135
+ It is often the case that the start and stop timestamps found in ``bouts`` are
136
+ collected at a different sample rate than ``neural_activity``, which are often what
137
+ behavior rasters align to. In order to align the two, a ratio between the sample
138
+ rates of ``neural_activity`` and the bouts of behavior, which are observations,
139
+ is calculated and then multiplied to the timestamps.
140
+
141
+ Parameters:
142
+ -----------
143
+ bouts : np.ndarray
144
+ An array where each element is a pair of integers where the first integer denotes
145
+ the beginning of a bout of behavior, and the second integer denotes the end of
146
+ the bout.
147
+ neural_activity_sr : float
148
+ Sample rate of ``neural_activity``.
149
+ observation_sr : float
150
+ Sample rate for the ``bouts`` used.
151
+ max_frame : int
152
+ The length of the behavior raster, often set to the number of frames of
153
+ ``neural_activity``.
154
+
155
+ Returns:
156
+ --------
157
+ behavior_raster : np.ndarray
158
+ A raster (a one hot encoding) of a behavior, describing when it is active.
159
+ """
160
+ sr_ratio = neural_activity_sr/observation_sr
161
+ behavior_ts_adjusted = bouts*sr_ratio
162
+ behavior_raster = np.zeros(max_frame)
163
+ for time_stamps in behavior_ts_adjusted:
164
+ start = int(round(time_stamps[0]))
165
+ end = int(round(time_stamps[1] + 1))
166
+ if start > max_frame:
167
+ break
168
+ if end > max_frame:
169
+ end = max_frame
170
+ np.put(behavior_raster,range(start,end),np.ones(end-start))
171
+ return behavior_raster
172
+
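As an illustration of the sample-rate ratio described in the docstring, a bout recorded against 30 Hz observations maps onto a 15 Hz neural-activity raster by halving its timestamps (the numbers are made up):

import numpy as np
bouts = np.array([[10, 20]])                   # frames at the 30 Hz observation rate
raster = convert_to_raster(bouts, neural_activity_sr=15.0,
                           observation_sr=30.0, max_frame=20)
# sr_ratio = 0.5, so the bout occupies frames 5 through 10 of the 15 Hz raster
print(np.nonzero(raster)[0])                   # -> [ 5  6  7  8  9 10]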
173
+ def convert_to_bouts(behavior_raster: np.ndarray):
174
+ """
175
+ Converts a behavior raster into behavior bouts, an array where each element is a
176
+ pair of timestamps (int) where the first timestamp denotes the beginning of a bout of
177
+ behavior, and the second timestamp denotes the end of the bout.
178
+
179
+ Parameters:
180
+ -----------
181
+ behavior_raster : np.ndarray
182
+ A raster (a one hot encoding) of a behavior, describing when it is active.
183
+
184
+ Returns:
185
+ --------
186
+ bouts : np.ndarray
187
+ An array where each element is a pair of timestamps (int) where the first
188
+ timestamp denotes the beginning of a bout of behavior, and the second timestamp
189
+ denotes the end of the bout.
190
+ """
191
+ dt = behavior_raster[1:] - behavior_raster[:-1]
192
+ start = np.where(dt==1)[0] + 1
193
+ stop = np.where(dt==-1)[0]
194
+ if behavior_raster[0]:
195
+ start = np.concatenate((np.array([0]),start))
196
+ if behavior_raster[-1]:
197
+ stop = np.concatenate((stop,[behavior_raster.size]))
198
+ bouts = np.hstack((np.reshape(start,(len(start),1)),
199
+ np.reshape(stop,(len(stop),1))))
200
+ return bouts
201
+
202
+ def merge_rasters_down(behavior_raster_array: np.ndarray)-> np.ndarray:
203
+ """
204
+ For a behavior raster, merges down all rasters to one array in such a way that no
205
+ two behaviors are occurring at the same time.
206
+ 
207
+ It determines which behavior should remain 'on top' by determining which behavior
208
+ has the least number of active frames.
209
+
210
+ This method should only be used on behavior rasters where all behaviors come from a
211
+ single channel.
212
+
213
+ Parameters:
214
+ -----------
215
+ behavior_raster_array : np.ndarray
216
+ An array where each row is a behavior raster, a one hot encoding of behaviors,
217
+ describing when that behavior is active. Each row of this array must use a
218
+ different value to indicate that a behavior is active (for example, if one
219
+ row uses 1s, another row must not use 1 as well).
220
+
221
+ Returns:
222
+ --------
223
+ single_track : np.ndarray
224
+ An array which is the length of a behavior raster in ``behavior_raster_array``,
225
+ where each entry is either 0 indicating that no behavior is active, or a value
226
+ indicating that a specific behavior is active.
227
+ """
228
+ # single track
229
+ single_track = np.zeros((1,behavior_raster_array.shape[1]))
230
+
231
+ # determine order to insert row values
232
+ num_active_frames = [np.sum(np.where(row > 0, 1, 0)) for row in behavior_raster_array]
233
+
234
+ for i in range(behavior_raster_array.shape[0]):
235
+ max_i = np.argmax(num_active_frames)
236
+ num_active_frames[max_i] = -1
237
+
238
+ unique_values = np.unique(behavior_raster_array[max_i])
239
+ if len(unique_values) > 1: value = unique_values[1]
240
+ else: value = 0
241
+ active_inds = np.where(behavior_raster_array[max_i] == value)[0]
242
+
243
+ single_track[:,active_inds] = value
244
+ return single_track
245
+
246
+ def separate_tracks(single_track: np.ndarray,
247
+ behavior_values: list):
248
+ """
249
+ For a single track, separates each unique value (except for 0) into its own raster
250
+ within a 2-D array.
251
+
252
+ Parameters:
253
+ -----------
254
+ single_track : np.ndarray
255
+ An array which is the length of a behavior raster in ``behavior_raster_array``,
256
+ where each entry is either 0 indicating that no behavior is active, or a value
257
+ indicating that a specific behavior is active.
258
+ behavior_values : list
259
+ A list of values corresponding to the specific behaviors within ``single_track``.
260
+
261
+ Returns:
262
+ --------
263
+ behavior_raster_array : np.ndarray
264
+ An array where each row is a behavior raster, a one hot encoding of behaviors,
265
+ describing when that behavior is active.
266
+ """
267
+ if len(behavior_values) < np.unique(single_track).size - 1:
268
+ raise KeyError("There are not sufficient values within ``behavior_values`` to "
269
+ "accomodate those present in ``single_track``.")
270
+ tracks = []
271
+ for value in behavior_values:
272
+ tracks.append(np.where(single_track == value, value, 0))
273
+ return np.vstack(tracks)
274
+
275
+ def config_neural_activity(config: dict, neural_activity: np.ndarray):
276
+ """
277
+ Configures `neural_activity` according to parameters set in config.
278
+
279
+ Parameters:
280
+ -----------
281
+ config : dict
282
+ A dictionary which specifies the following parameters: 'smooth_window',
283
+ 'baseline_frame', and 'zscore_method'. 'zscore_method' is one of "All Data",
284
+ "Baseline", or "No Z-Score".
285
+ neural_activity : np.ndarray
286
+ Neural activity being used.
287
+
288
+ Returns:
289
+ --------
290
+ mod_neural_activity : np.ndarray
291
+ Modified `neural_activity`, accodring to `config`.
292
+ """
293
+ smooth_window = config['smooth_window']
294
+ zscore_method = config['zscore_method']
295
+ baseline_frame = config['baseline_frame']
296
+
297
+ # smooth
298
+ if len(neural_activity.shape) > 1:
299
+ neural_data_smooth = np.zeros(neural_activity.shape)
300
+ for i in range(neural_activity.shape[0]):
301
+ neural_data_smooth[i] = smooth(neural_activity[i], int(smooth_window))
302
+ mod_neural_activity = neural_data_smooth
303
+ else:
304
+ mod_neural_activity = smooth(neural_activity, int(smooth_window))
305
+
306
+ # z-score
307
+ if zscore_method == 'Baseline' and baseline_frame is not None and baseline_frame != 0:
308
+ if len(neural_activity.shape)> 1:
309
+ mean = mod_neural_activity[:,:baseline_frame].mean(axis=1,keepdims=True)
310
+ std = mod_neural_activity[:,:baseline_frame].std(axis=1,keepdims=True)
311
+ else:
312
+ mean = mod_neural_activity[:baseline_frame].mean()
313
+ std = mod_neural_activity[:baseline_frame].std()
314
+ mod_neural_activity = (mod_neural_activity - mean) / std
315
+ elif zscore_method == 'No Z-Score':
316
+ mod_neural_activity = mod_neural_activity
317
+ else:
318
+ if len(neural_activity.shape) > 1:
319
+ mod_neural_activity = zscore(mod_neural_activity,axis=1)
320
+ else:
321
+ mod_neural_activity = zscore(mod_neural_activity)
322
+ return mod_neural_activity
323
+
324
+ def compress_annotations(annot: dict, downsample_rate: int, max_frame: int)-> dict:
325
+ """
326
+ Takes in an annotation dictionary and creates a single raster per channel, where the
327
+ raster contains the behaviors from their respective channel.
328
+
329
+ annot : dict
330
+ Dictionary of beginning and end frames for behaviors.
331
+ downsample_rate : int
332
+ The rate at which samples should be taken. Divides bout timing (in frames) by
333
+ value.
334
+ max_frame : int
335
+ The last frame for annotations from `annot`.
336
+ """
337
+ annot_single_track = {}
338
+ channel_behavior_map = {}
339
+ for channel in annot:
340
+ channel_rasters = []
341
+ behavior_map = {}
342
+ behavior_map.update({0: 'None'})
343
+ for i, behavior in enumerate(annot[channel]):
344
+ bouts = annot[channel][behavior]
345
+ raster = convert_to_raster(bouts, 1, downsample_rate, max_frame)
346
+ channel_rasters.append(raster*(i+1))
347
+ behavior_map.update({(i+1) : behavior})
348
+ channel_raster = merge_rasters_down(np.array(channel_rasters))[0]
349
+ annot_single_track.update({channel : channel_raster})
350
+ channel_behavior_map.update({channel : behavior_map})
351
+ return annot_single_track, channel_behavior_map
352
+
353
+ def compress_compressed_annotations(annot_single_track: dict,
354
+ channel_behavior_map: dict,
355
+ max_frame: int):
356
+ """
357
+ Further compresses the results from `compress_annotations` to get a single array
358
+ where each entry is a list of the behaviors present at that frame across all channels.
359
+ """
360
+ labels = []
361
+ for frame in range(max_frame):
362
+ labels_at_frame = []
363
+ for channel in annot_single_track:
364
+ channel_raster = annot_single_track[channel]
365
+ behavior_map = channel_behavior_map[channel]
366
+ behavior_value = int(channel_raster[frame])
367
+ behavior_label = behavior_map.get(behavior_value)
368
+ labels_at_frame.append(behavior_label)
369
+ labels.append('||'.join(labels_at_frame))
370
+ return labels
371
+
372
+ def generate_label_array(annot: dict,
373
+ downsample_rate: int,
374
+ max_frame: int)-> list[str]:
375
+ """
376
+ Generates an array of lists of labels, where each entry is a video frame, and the
377
+ labels come from each channel in `annot`.
378
+ """
379
+ annot_single_track,\
380
+ channel_behavior_map = compress_annotations(annot, downsample_rate, max_frame)
381
+ labels = compress_compressed_annotations(annot_single_track,
382
+ channel_behavior_map,
383
+ max_frame)
384
+ return labels
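A sketch of how the helpers above combine in practice; the .annot file name is hypothetical, and downsample_rate=1 assumes the annotations and video share a frame rate:

from utils.data_loading import load_annot_sheet_txt
from utils.data_processing import generate_label_array

annot, start_frame, end_frame, sample_rate = load_annot_sheet_txt("session1.annot")
labels = generate_label_array(annot, downsample_rate=1, max_frame=end_frame)
print(labels[0])    # one '||'-joined entry per frame across channels, e.g. 'attack||None'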
utils/mp4Io.py ADDED
@@ -0,0 +1,60 @@
1
+ import cv2
2
+ import numpy as np
3
+ import os
4
+ from qtpy.QtGui import QImage, QPixmap
5
+
6
+ class mp4Io_reader():
7
+ def __init__(self, filename, info=[]):
8
+ self.filename = filename
9
+ self.file = cv2.VideoCapture(filename)
10
+
11
+ if not self.file.isOpened():
12
+ print("Error in opening video file.")
13
+
14
+ self.header={}
15
+ if info==[]:
16
+ self.readHeader()
17
+
18
+ def readHeader(self):
19
+
20
+ self.header = {
21
+ 'width': int(self.file.get(cv2.CAP_PROP_FRAME_WIDTH)),
22
+ 'height': int(self.file.get(cv2.CAP_PROP_FRAME_HEIGHT)),
23
+ 'fps': self.file.get(cv2.CAP_PROP_FPS),
24
+ 'numFrames': int(self.file.get(cv2.CAP_PROP_FRAME_COUNT))
25
+ }
26
+
27
+ def seek(self, index):
28
+
29
+ self.file.set(cv2.CAP_PROP_POS_FRAMES, index)
30
+
31
+ def getTs(self,n=None):
32
+ if n is None:
33
+ n = self.header['numFrames']
34
+
35
+ ts = np.zeros(n+1)
36
+ for i in np.arange(1,n+1):
37
+ self.seek(i)
38
+ self.file.read()
39
+ ts[i] = self.file.get(cv2.CAP_PROP_POS_MSEC)/1000.
40
+
41
+ self.ts = ts[1:]
42
+ return self.ts
43
+
44
+ def getFrame(self, index, decode=True):
45
+
46
+ self.seek(index)
47
+ ret, frame = self.file.read()
48
+
49
+ ts = self.file.get(cv2.CAP_PROP_POS_MSEC)/1000.
50
+ return frame, ts
51
+
52
+ def getFrameAsQPixmap(self, index, decode=True):
53
+ image, _ = self.getFrame(index, decode)
54
+ h, w, ch = image.shape
55
+ bytes_per_line = ch * w
56
+ convert_to_Qt_format = QImage(image.data, w, h, bytes_per_line, QImage.Format_BGR888)
57
+ return QPixmap.fromImage(convert_to_Qt_format)
58
+
59
+ def close(self):
60
+ self.file.release()
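A minimal sketch of reading a frame with the class above (the video path is hypothetical):

from utils.mp4Io import mp4Io_reader

reader = mp4Io_reader("video.mp4")
print(reader.header)                 # width, height, fps, numFrames
frame, ts = reader.getFrame(0)       # BGR ndarray plus timestamp in seconds
reader.close()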
utils/seqIo.py ADDED
@@ -0,0 +1,1189 @@
1
+ import os, sys
2
+ import numpy as np
3
+ import PIL
4
+ from PIL import Image
5
+ import io
6
+ from datetime import datetime, timedelta,date
7
+ import time
8
+ from matplotlib.dates import date2num, num2date
9
+ # import colour_demosaicing
10
+ import skvideo.io
11
+ import re
12
+ import pickle
13
+ import cv2
14
+ import progressbar as pb
15
+
16
+ # Create interface sr for reading seq files.
17
+ # sr = seqIo_reader( fName )
18
+ # Create interface sw for writing seq files.
19
+ # sw = seqIo_Writer( fName, header )
20
+ # Crop sub-sequence from seq file.
21
+ # seqIo_crop( fName, 'crop', tName, frames )
22
+ # Extract images from seq file to target directory or array.
23
+ # Is = seqIo_toImgs( fName, tDir=[],skip=1,f0=0,f1=np.inf,ext='' )
24
+ # Create seq file from an array or directory of images or from an AVI file. DONE
25
+ # seqIo_frImgs( fName, fName,header,aviName=[],Is=[],sDir=[],name='I',ndig=5,f0=0,f1=1e6 )
26
+ # Convert seq file by applying imgFun(I) to each frame I.
27
+ # seqIo( fName, 'convert', tName, imgFun, varargin )
28
+ # Replace header of seq file with provided info.
29
+ # seqIo( fName, 'newHeader', info )
30
+ # Create interface sr for reading dual seq files.
31
+ # sr = seqIo( fNames, 'readerDual', [cache] )
32
+
33
+ FRAME_FORMAT_RAW_GRAY = 100 #RAW
34
+ FRAME_FORMAT_RAW_COLOR = 200 #RAW
35
+ FRAME_FORMAT_JPEG_GRAY = 102 #JPG
36
+ FRAME_FORMAT_JPEG_COLOR = 201 #JPG
37
+ FRAME_FORMAT_MONOB = 101 #BRGB8
38
+ FRAME_FORMAT_MONOB_JPEG = 103 #JBRGB
39
+ FRAME_FORMAT_PNG_GRAY = 0x001 #PNG
40
+ FRAME_FORMAT_PNG_COLOR = 0x002 #PNG
41
+
42
+ #matlab equivalent fread
43
+ def fread(fid, nelements, dtype):
44
+
45
+ """Equivalent to Matlab fread function"""
46
+
47
+ if dtype is np.str_:
48
+ dt = np.uint8 # WARNING: assuming 8-bit ASCII for np.str!
49
+ else:
50
+ dt = dtype
51
+
52
+ data_array = np.fromfile(fid, dt, nelements)
53
+ if data_array.size == 1:
54
+ data_array = data_array[0]
55
+ return data_array
56
+
57
+ def fwrite(fid,a,dtype=np.str_):
58
+ # assuming 8-bit ASCII for string
59
+ if dtype is np.str_:
60
+ dt = np.uint8 # WARNING: assuming 8-bit ASCII for np.str!
61
+ else:
62
+ dt = dtype
63
+ if isinstance(a,np.ndarray):
64
+ data_array = a.astype(dt)
65
+ else:
66
+ data_array = np.array(a).astype(dt)
67
+ data_array.tofile(fid)
68
+
69
+ def tsSync(video_path, srTop, srFront):
70
+ #tsSync
71
+ name = video_path.split('/')[-1]
72
+ srDepth=seqIo_reader(video_path + '/' + name + '_DepGr_Raw.seq')
73
+
74
+ #read timestamp of individual frames
75
+ tsTop = srTop.getTs()
76
+ tsFront = srFront.getTs()
77
+ tsDepth = srDepth.getTs()
78
+
79
+ #check the version of the videos
80
+ videoDateStr = re.search('[0-9]+_[0-9]+-[0-9]+-[0-9]+',name).group(0)
81
+ videoDateNum =date2num(datetime.strptime(videoDateStr ,'%Y%m%d_%H-%M-%S'))
82
+ videoRefNum = date2num(datetime.strptime('20150401_00-00-00','%Y%m%d_%H-%M-%S'))
83
+ seqV = 1 if videoDateNum < videoRefNum else 2
84
+
85
+ ##correlate timestamps from one view to another
86
+ mapTs = {}
87
+ if seqV ==1:
88
+ for f in range(len(tsDepth)):
89
+ if tsDepth[f] - np.floor(tsDepth[f])>=.5: # bug in Santiago's acquisition software
90
+ tsDepth[f]-= 1
91
+ tsDepth-= 0.03 # subtract the systematic timeshift
92
+
93
+ #load front and top view and convert from UTC to PST
94
+ hourShift = np.round((tsDepth[0]-tsTop[0])/3600.)*3600
95
+ timeShift = tsDepth[0] - tsTop[0] - .066
96
+ # print('tsDepth[0] - tsTop[0] = timeShift: %s sec'% str(timeShift - hourShift))
97
+ tsTop += timeShift
98
+ tsFront += timeShift
99
+
100
+ srTop.ts = tsTop
101
+ srFront.ts = tsFront
102
+ srDepth.ts = tsDepth
103
+
104
+ # Convert timestamps from left to right
105
+ mapTs['T2F'] = transformTs(tsTop, tsFront)
106
+ mapTs['F2T'] = transformTs(tsFront, tsTop)
107
+ mapTs['T2D'] = transformTs(tsTop, tsDepth)
108
+ mapTs['D2T'] = transformTs(tsDepth, tsTop)
109
+ mapTs['F2D'] = transformTs(tsFront, tsDepth)
110
+ mapTs['D2F'] = transformTs(tsDepth, tsFront)
111
+ else:
112
+ T = len(tsTop)
113
+ F = len(tsFront)
114
+ D = len(tsDepth)
115
+ mapTs['T2F'] = resizeTs(T,F)
116
+ mapTs['F2T'] = resizeTs(F,T)
117
+ mapTs['T2D'] = resizeTs(T,D)
118
+ mapTs['D2T'] = resizeTs(D,T)
119
+ mapTs['F2D'] = resizeTs(F,D)
120
+ mapTs['D2F'] = resizeTs(D,F)
121
+
122
+ #display the time in string format
123
+ # fTp=5
124
+ # fFr = mapTs['T2F'][fTp]
125
+ # fDp = mapTs['T2D'][fTp]
126
+ # tF = ts2str(tsFront[fFr])
127
+ # tT = ts2str(tsTop[fTp])
128
+ # tD = ts2str(tsDepth[fDp])
129
+ #
130
+ # print('mapping to tTop:')
131
+ # print('tTop = ' + str(fTp) + ' - ' + tT)
132
+ # print('tFront = ' + str(fFr) + ' - ' + tF)
133
+ # print('tDepth = ' + str(fDp) + ' - ' + tD)
134
+
135
+ # fDp = np.round(len(tsDepth)/1.5).astype(int)
136
+ # fTp = mapTs['D2T'][fDp]
137
+ # tT = ts2str(tsTop[fTp])
138
+ # tD = ts2str(tsDepth[fDp])
139
+ # print('mapping to tTop:')
140
+ # print('tTop = ' + str(fTp) + '. ' + tT)
141
+ # print('tDepth = ' + str(fDp) + '. ' + tD)
142
+
143
+ return mapTs,srTop,srFront
144
+
145
+ def transformTs(ts1, ts2):
146
+ # map ts1 to ts2 where ts2 is the reference
147
+ rankTs = np.zeros((len(ts1),4))
148
+ for f in range(len(ts1)):
149
+ tsDiff = ts2-ts1[f]
150
+ tsRank = np.sort(abs(tsDiff))
151
+ ind = np.argsort(abs(tsDiff))
152
+ rankTs[f,:] = [ind[0]+1,ind[1]+1,tsRank[0],tsRank[1]]
153
+
154
+ mapTs = np.round(smooth(rankTs[:,0],7)).astype(int)
155
+ mapTs = np.round(smooth(mapTs,7)).astype(int)
156
+ return mapTs
157
+
158
+ def smooth(a, WSZ):
159
+ out0 = np.convolve(a, np.ones(WSZ, dtype=int), 'valid') / WSZ
160
+ r = np.arange(1, WSZ - 1, 2)
161
+ start = np.cumsum(a[:WSZ - 1])[::2] / r
162
+ stop = (np.cumsum(a[:-WSZ:-1])[::2] / r)[::-1]
163
+ return np.concatenate((start, out0, stop))
164
+
165
+ def resizeTs(t1, t2):
166
+ if t1>t2:
167
+ mapTs = np.hstack((np.array(range(t2))+1,np.ones((t1-t2),int)*t2))
168
+ else:
169
+ mapTs = np.array(range(t1))+1
170
+
171
+ return mapTs
172
+
173
+ def ts2str(ts):
174
+ t = ts / 86400. + date.toordinal(date(1971, 1, 2))
175
+ # datetime.fromtimestamp(t)
176
+ str_time = (datetime.fromordinal(int(t)) + timedelta(days=t % 1) - timedelta(days=366)).strftime(
177
+ "%Y-%m-%d %H:%M:%S") + '.%03d' % np.round((ts - np.floor(ts)) * 1000)
178
+ return str_time
179
+
180
+ def parse_ann(f_ann):
181
+ header = 'Caltech Behavior Annotator - Annotation File'
182
+ conf = 'Configuration file:'
183
+ fid = open(f_ann)
184
+ ann = fid.read().splitlines()
185
+ fid.close()
186
+ NFrames = []
187
+ # check the header
188
+ assert ann[0].rstrip() == header
189
+ assert ann[1].rstrip() == ''
190
+ assert ann[2].rstrip() == conf
191
+ # parse action list
192
+ l = 3
193
+ names = [None] * 1000
194
+ keys = [None] * 1000
195
+ types = []
196
+ bnds = []
197
+ k = -1
198
+
199
+ # get config keys and names
200
+ while True:
201
+ ann[l] = ann[l].rstrip()
202
+ if not isinstance(ann[l], str) or not ann[l]:
203
+ l += 1
204
+ break
205
+ values = ann[l].split()
206
+ k += 1
207
+ names[k] = values[0]
208
+ keys[k] = values[1]
209
+ l += 1
210
+ names = names[:k + 1]
211
+ keys = keys[:k + 1]
212
+
213
+ # read in each stream in turn until end of file
214
+ bnds0 = [None] * 10000
215
+ types0 = [None] * 10000
216
+ actions0 = [None] * 10000
217
+ nStrm1 = 0
218
+ while True:
219
+ ann[l] = ann[l].rstrip()
220
+ nStrm1 += 1
221
+ t = ann[l].split(":")
222
+ l += 1
223
+ ann[l] = ann[l].rstrip()
224
+ assert int(t[0][1]) == nStrm1
225
+ assert ann[l] == '-----------------------------'
226
+ l += 1
227
+ bnds1 = np.ones((10000, 2), dtype=int)
228
+ types1 = np.ones(10000, dtype=int) * -1
229
+ actions1 = [None] * 10000
230
+ k = 0
231
+ # start the annotations
232
+ while True:
233
+ ann[l] = ann[l].rstrip()
234
+ t = ann[l]
235
+ if not isinstance(t, str) or not t:
236
+ l += 1
237
+ break
238
+ t = ann[l].split()
239
+ type = [i for i in range(len(names)) if t[2] == names[i]]
240
+ if not type:
241
+ print('undefined behavior: ' + t[2])
242
+ type = type[0]
243
+ if bnds1[k - 1, 1] != int(t[0]) - 1 and k > 0:
244
+ print('%d ~= %d' % (bnds1[k, 1], int(t[0]) - 1))
245
+ bnds1[k, :] = [int(t[0]), int(t[1])]
246
+ types1[k] = type
247
+ actions1[k] = names[type]
248
+ k += 1
249
+ l += 1
250
+ if l == len(ann):
251
+ break
252
+ if nStrm1 == 1:
253
+ nFrames = bnds1[k - 1, 1]
254
+ assert nFrames == bnds1[k - 1, 1]
255
+ bnds0[nStrm1 - 1] = bnds1[:k]
256
+ types0[nStrm1 - 1] = types1[:k]
257
+ actions0[nStrm1 - 1] = actions1[:k]
258
+ if l == len(ann):
259
+ break
260
+ while not ann[l]:
261
+ l += 1
262
+
263
+ bnds = bnds0[:nStrm1]
264
+ types = types0[:nStrm1]
265
+ actions = actions0[:nStrm1]
266
+
267
+ idx = 0
268
+ if len(actions[0]) < len(actions[1]):
269
+ idx = 1
270
+ type_frame = []
271
+ action_frame = []
272
+ len_bnd = []
273
+
274
+ for i in range(len(bnds[idx])):
275
+ numf = bnds[idx][i, 1] - bnds[idx][i, 0] + 1
276
+ len_bnd.append(numf)
277
+ action_frame.extend([actions[idx][i]] * numf)
278
+ type_frame.extend([types[idx][i]] * numf)
279
+
280
+ ann_dict = {
281
+ 'keys': keys,
282
+ 'behs': names,
283
+ 'nstrm': nStrm1,
284
+ 'nFrames': nFrames,
285
+ 'behs_se': bnds,
286
+ 'behs_dur': len_bnd,
287
+ 'behs_bout': actions,
288
+ 'behs_frame': action_frame
289
+ }
290
+
291
+ return ann_dict
292
+
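A short sketch of how `parse_ann` might be used (the file name is hypothetical; it must be a Caltech Behavior Annotator .annot file with the header checked above):

    ann = parse_ann('Mouse156_Top_annotations.annot')   # hypothetical path
    print(ann['nFrames'], 'frames in', ann['nstrm'], 'streams')
    print(ann['behs'])               # behavior names from the config block
    print(ann['behs_frame'][:20])    # one behavior label per frame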
293
+ def parse_ann_dual(f_ann):
294
+ header = 'Caltech Behavior Annotator - Annotation File'
295
+ conf = 'Configuration file:'
296
+ fid = open(f_ann)
297
+ ann = fid.read().splitlines()
298
+ fid.close()
299
+ NFrames = []
300
+ # check the header
301
+ assert ann[0].rstrip() == header
302
+ assert ann[1].rstrip() == ''
303
+ assert ann[2].rstrip() == conf
304
+ # parse action list
305
+ l = 3
306
+ names = [None] * 1000
307
+ keys = [None] * 1000
308
+ types = []
309
+ bnds = []
310
+ k = -1
311
+
312
+ # get config keys and names
313
+ while True:
314
+ ann[l] = ann[l].rstrip()
315
+ if not isinstance(ann[l], str) or not ann[l]:
316
+ l += 1
317
+ break
318
+ values = ann[l].split()
319
+ k += 1
320
+ names[k] = values[0]
321
+ keys[k] = values[1]
322
+ l += 1
323
+ names = names[:k + 1]
324
+ keys = keys[:k + 1]
325
+
326
+ # read in each stream in turn until end of file
327
+ bnds0 = [None] * 10000
328
+ types0 = [None] * 10000
329
+ actions0 = [None] * 10000
330
+ nStrm1 = 0
331
+ while True:
332
+ ann[l] = ann[l].rstrip()
333
+ nStrm1 += 1
334
+ t = ann[l].split(":")
335
+ l += 1
336
+ ann[l] = ann[l].rstrip()
337
+ assert int(t[0][1]) == nStrm1
338
+ assert ann[l] == '-----------------------------'
339
+ l += 1
340
+ bnds1 = np.ones((10000, 2), dtype=int)
341
+ types1 = np.ones(10000, dtype=int) * -1
342
+ actions1 = [None] * 10000
343
+ k = 0
344
+ # start the annotations
345
+ while True:
346
+ ann[l] = ann[l].rstrip()
347
+ t = ann[l]
348
+ if not isinstance(t, str) or not t:
349
+ l += 1
350
+ break
351
+ t = ann[l].split()
352
+ type = [i for i in range(len(names)) if t[2] == names[i]]
353
+ if not type:
354
+ print('undefined behavior: ' + t[2])
355
+ type = type[0]
356
+ if bnds1[k - 1, 1] != int(t[0]) - 1 and k > 0:
357
+ print('%d ~= %d' % (bnds1[k, 1], int(t[0]) - 1))
358
+ bnds1[k, :] = [int(t[0]), int(t[1])]
359
+ types1[k] = type
360
+ actions1[k] = names[type]
361
+ k += 1
362
+ l += 1
363
+ if l == len(ann):
364
+ break
365
+ if nStrm1 == 1:
366
+ nFrames = bnds1[k - 1, 1]
367
+ assert nFrames == bnds1[k - 1, 1]
368
+ bnds0[nStrm1 - 1] = bnds1[:k]
369
+ types0[nStrm1 - 1] = types1[:k]
370
+ actions0[nStrm1 - 1] = actions1[:k]
371
+ if l == len(ann):
372
+ break
373
+ while not ann[l]:
374
+ l += 1
375
+
376
+ bnds = bnds0[:nStrm1]
377
+ types = types0[:nStrm1]
378
+ actions = actions0[:nStrm1]
379
+
380
+ idx = 0
381
+ if len(actions[0]) < len(actions[1]):
382
+ idx = 1
383
+ type_frame = []
384
+ action_frame = []
385
+ len_bnd = []
386
+
387
+
388
+ for i in range(len(bnds[idx])):
389
+ numf = bnds[idx][i, 1] - bnds[idx][i, 0] + 1
390
+ len_bnd.append(numf)
391
+ action_frame.extend([actions[idx][i]] * numf)
392
+ type_frame.extend([types[idx][i]] * numf)
393
+
394
+
395
+ type_frame2 = []
396
+ action_frame2 = []
397
+ len_bnd2 = []
398
+ idx=1 if idx==0 else 0
399
+
400
+ for i in range(len(bnds[idx])):
401
+ numf = bnds[idx][i, 1] - bnds[idx][i, 0] + 1
402
+ len_bnd2.append(numf)
403
+ action_frame2.extend([actions[idx][i]] * numf)
404
+ type_frame2.extend([types[idx][i]] * numf)
405
+
406
+ ann_dict = {
407
+ 'keys': keys,
408
+ 'behs': names,
409
+ 'nstrm': nStrm1,
410
+ 'nFrames': nFrames,
411
+ 'behs_se': bnds,
412
+ 'behs_dur': len_bnd,
413
+ 'behs_bout': actions,
414
+ 'behs_frame': action_frame if 'interaction' not in action_frame else action_frame2,
415
+ 'behs_frame2': action_frame2 if 'interaction' in action_frame2 else action_frame
416
+ }
417
+
418
+ return ann_dict
419
+
420
+ def syncTopFront(f,num_frames,num_framesf):
421
+ return int(round(f / (num_framesf - 1) * (num_frames - 1))) if num_framesf > num_frames else int(round(f / (num_frames - 1) * (num_framesf - 1)))
422
+
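For example, `syncTopFront` maps a frame index between two videos of different lengths (the frame counts below are made up):

    num_frames, num_framesf = 1500, 3000     # top has 1500 frames, front has 3000
    print(syncTopFront(100, num_frames, num_framesf))   # front frame 100 -> top frame 50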
423
+
424
+
425
+ class seqIo_reader():
426
+ def __init__(self,fname,info=[],buildTable=True):
427
+ self.filename = fname
428
+ try:
429
+ self.file=open(fname,'rb')
430
+ except EnvironmentError as e:
431
+ print(os.strerror(e.errno))
432
+ self.header={}
433
+ self.seek_table=None
434
+ self.frames_read=-1
435
+ self.timestamp_length = 10
436
+ if info==[]:
437
+ self.readHeader()
438
+ else:
439
+ info.numFrames=0
440
+ if buildTable:
441
+ print("buildTable was True, so calling buildSeekTable()")
442
+ self.buildSeekTable(False)
443
+
444
+
445
+ def readHeader(self):
446
+ #make sure we do this at the beginning of the file
447
+ assert self.frames_read == -1, "Can only read header from beginning of file"
448
+ self.file.seek(0,0)
449
+ # pdb.set_trace()
450
+
451
+ # Read 1024 bytes (len of header)
452
+ tmp = fread(self.file,1024,np.uint8)
453
+ #check that the header is not all 0's
454
+ n=len(tmp)
455
+ if n<1024:raise ValueError('no header')
456
+ if all(tmp==0): raise ValueError('fully empty header')
457
+ self.file.seek(0,0)
458
+ # first 4 bytes store 0xFEED, next 24 store 'Norpix seq '
459
+ magic_number = fread(self.file,1,np.uint32)
460
+ name = fread(self.file,10,np.uint16)
461
+ name = ''.join(map(chr,name))
462
+ if not '{0:X}'.format(magic_number)=='FEED' or not name=='Norpix seq':raise ValueError('invalid header')
463
+ self.file.seek(4,1)
464
+ #next 8 bytes for version and header size (1024) then 512 for desc
465
+ version = int(fread(self.file,1,np.int32))
466
+ hsize =int(fread(self.file,1,np.uint32))
467
+ assert(hsize)==1024 ,"incorrect header size"
468
+ # d = self.file.read(512)
469
+ descr=fread(self.file,256,np.uint16)
470
+ # descr = ''.join(map(chr,descr))
471
+ # descr = ''.join(map(unichr,descr)).replace('\x00',' ')
472
+ descr = ''.join([chr(x) for x in descr]).replace('\x00',' ')
473
+ # descr = descr.encode('utf-8')
474
+ #read more info
475
+ tmp = fread(self.file,9,np.uint32)
476
+ assert tmp[7]==0, "incorrect origin"
477
+ fps = fread(self.file,1,np.float64)
478
+ codec = 'imageFormat' + '%03d'%tmp[5]
479
+ desc_format = fread(self.file,1,np.uint32)
480
+ padding = fread(self.file,428,np.uint8)
481
+ padding = ''.join(map(chr,padding))
482
+ #store info
483
+ self.header={'magicNumber':magic_number,
484
+ 'name':name,
485
+ 'seqVersion': version,
486
+ 'headerSize':hsize,
487
+ 'descr': descr,
488
+ 'width':int(tmp[0]),
489
+ 'height':int(tmp[1]),
490
+ 'imageBitDepth':int(tmp[2]),
491
+ 'imageBitDepthReal':int(tmp[3]),
492
+ 'imageSizeBytes':int(tmp[4]),
493
+ 'imageFormat':int(tmp[5]),
494
+ 'numFrames':int(tmp[6]),
495
+ 'origin':int(tmp[7]),
496
+ 'trueImageSize':int(tmp[8]),
497
+ 'fps':fps,
498
+ 'codec':codec,
499
+ 'descFormat':desc_format,
500
+ 'padding':padding,
501
+ 'nHiddenFinalFrames':0
502
+ }
503
+ assert(self.header['imageBitDepthReal']==8)
504
+ # seek to end of header
505
+ self.file.seek(432,1)
506
+ self.frames_read += 1
507
+
508
+ self.imageFormat = self.header['imageFormat']
509
+ if self.imageFormat in (100,200): self.ext = 'raw'
510
+ elif self.imageFormat in (102,201): self.ext = 'jpg'
511
+ elif self.imageFormat in(0x001,0x002): self.ext = 'png'
512
+ elif self.imageFormat == 101: self.ext = 'brgb8'
513
+ elif self.imageFormat == 103: self.ext = 'jbrgb'
514
+ else: raise ValueError('unknown format')
515
+
516
+ self.compressed = True if self.ext in ['jpg','jbrgb','png','brgb8'] else False
517
+ self.bit_depth = self.header['imageBitDepth']
518
+
519
+ # My code uses a timestamp_length of 10 bytes, old uses 8. Check if not 10
520
+ if self.bit_depth / 8 * (self.header['height'] * self.header['width']) + self.timestamp_length \
521
+ != self.header['trueImageSize']:
522
+ # If not 10, adjust to actual (likely 8) and print message
523
+ self.timestamp_length = int(self.header['trueImageSize'] \
524
+ - (self.bit_depth / 8 * (self.header['height'] * self.header['width'])))
525
+
526
+ def buildSeekTable(self,memoize=False):
527
+ """Build a seek table containing the offset and frame size for every frame in the video."""
528
+ print("in seqIo_reader.buildSeekTable()")
529
+ pickle_name = os.path.splitext(self.filename)[0] + ".seek"  # note: str.strip('.seq') would drop characters, not the suffix
530
+ if memoize:
531
+ if os.path.isfile(pickle_name):
532
+ self.seek_table = pickle.load(open(pickle_name, 'rb'))
533
+ return
534
+
535
+ # assert self.header['numFrames']>0
536
+ n=self.header['numFrames']
537
+ if n==0: n=int(1e7)
538
+
539
+ seek_table = np.zeros((n)).astype(np.int64)
540
+ seek_table[0]=1024
541
+ extra = 8 # extra bytes after image data , 8 for ts then 0 or 8 empty
542
+ self.file.seek(1024,0)
543
+ #compressed case
544
+
545
+ if self.compressed:
546
+ i=1
547
+ while (True):
548
+ try:
549
+ # size = fread(self.file,1,np.uint32)
550
+ # offset = seek_table[i-1] + size +extra
551
+ # seek_table[i]=offset
552
+ # # seek_table[i-1,1]=size
553
+ # self.file.seek(size-4+extra,1)
554
+
555
+ size = fread(self.file, 1, np.uint32)
556
+ offset = seek_table[i - 1] + size + extra
557
+ # self.file.seek(size-4+extra,1)
558
+ self.file.seek(offset, 0)
559
+ if i == 1:
560
+ if fread(self.file, 1, np.uint32) != 0:
561
+ self.file.seek(-4, 1)
562
+ else:
563
+ extra += 8;
564
+ offset += 8
565
+ self.file.seek(offset, 0)
566
+
567
+ seek_table[i] = offset
568
+ # seek_table[i-1,1]=size
569
+ i+=1
570
+ except Exception as e:
571
+ break
572
+ #most likely EOF
573
+ else:
574
+ #uncompressed case
575
+ assert (self.header['numFrames']>0)
576
+ frames = range(0, self.header["numFrames"])
577
+ offsets = [x * self.header["trueImageSize"] + 1024 for x in frames]
578
+ for i,offset in enumerate(offsets):
579
+ seek_table[i]=offset
580
+ # seek_table[i,1]=self.header["imageSize"]
581
+ if n==1e7:
582
+ n = np.minimum(n,i)
583
+ self.seek_table=seek_table[:n]
584
+ self.header['numFrames']=n
585
+ else:
586
+ self.seek_table=seek_table
587
+ if memoize:
588
+ pickle.dump(seek_table,open(pickle_name,'wb'))
589
+
590
+ #compute frame rate from timestamps as stored fps may be incorrect
591
+ # if n==1: return
592
+ self.getTs()
593
+ # ds = self.ts[1:100]-self.ts[:99]
594
+ # ds = ds[abs(ds-np.median(ds))<.005]
595
+ # if bool(np.prod(ds)): self.header['fps']=1/np.mean(ds)
596
+
597
+ def getTs(self, n=None):
598
+ if n==None: n=self.header['numFrames']
599
+ if self.compressed and self.seek_table is None:
600
+ self.buildSeekTable()
601
+
602
+ ts = np.zeros((n))
603
+ for i in range(n):
604
+ if not self.compressed: #uncompressed
605
+ self.file.seek(1024 + i*self.header['trueImageSize']+self.header['imageSizeBytes'],0)
606
+ else: #compressed
607
+ self.file.seek(self.seek_table[i],0)
608
+ self.file.seek(fread(self.file,1,np.uint32)-4,1)
609
+ # print(i)
610
+ ts[i]=fread(self.file,1,np.uint32)+fread(self.file,1,np.uint16)/1000.
611
+
612
+
613
+ self.ts=ts
614
+ return self.ts
615
+
616
+ def getFrame(self,index,decode=True):
617
+ #get frame image (I) and timestamp (ts) at which frame was recorded
618
+ nch = self.header['imageBitDepth']//8
619
+ if self.ext in ['raw','brgb8']: #read in an uncompressed image( assume imageBitDepthReal==8)
620
+ shape = (self.header['height'], self.header['width'])
621
+ self.file.seek(1024 + index*self.header['trueImageSize'],0)
622
+ I = fread(self.file,self.header['imageSizeBytes'],np.uint8)
623
+
624
+ if decode:
625
+ if nch==1:
626
+ I=np.reshape(I,shape)
627
+ else:
628
+ I=np.reshape(I,shape+(nch,))
629
+ if nch==3:
630
+ t=I[:,:,2]; I[:,:,2]=I[:,:,0]; I[:,:,1]=t
631
+ if self.ext=='brgb8':
632
+ I= cv2.demosaicing(I, code=cv2.COLOR_BAYER_BG2BGR)
633
+ # I= colour_demosaicing.demosaicing_CFA_Bayer_bilinear(I,'BGGR')
634
+
635
+ elif self.ext in ['jpg','jbrgb']:
636
+ self.file.seek(self.seek_table[index],0)
637
+ nBytes = fread(self.file,1,np.uint32)
638
+ data = fread(self.file,nBytes-4,np.uint8)
639
+ if decode:
640
+ I = PIL.Image.open(io.BytesIO(data))
641
+ if self.ext == 'jbrgb':
642
+ I= cv2.demosaicing(I, code=cv2.COLOR_BAYER_BG2BGR)
643
+ # I=colour_demosaicing.demosaicing_CFA_Bayer_bilinear(I,'BGGR')
644
+ else:
645
+ I = data
646
+
647
+ elif self.ext=='png':
648
+ self.file.seek(self.seek_table[index],0)
649
+ nBytes = fread(self.file,1,np.uint32)
650
+ I= fread(self.file,nBytes-4,np.uint8)
651
+ if decode:
652
+ I = np.array(PIL.Image.open(io.BytesIO(I.tobytes())))  # decode the PNG bytes, mirroring the jpg branch
653
+ else: assert(False)
654
+ ts = fread(self.file,1,np.uint32)+fread(self.file,1,np.uint16)/1000.
655
+ return np.array(I), ts
656
+
657
+ # Close the file
658
+ def close(self):
659
+ self.file.close()
660
+
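A minimal sketch of reading a .seq file with the class above (the path is hypothetical):

    sr = seqIo_reader('Mouse156_Top.seq')    # hypothetical recording
    print(sr.header['width'], sr.header['height'], sr.header['numFrames'])
    I, ts = sr.getFrame(0)                   # first frame and its timestamp
    ts_all = sr.getTs()                      # timestamps for every frame
    sr.close()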
661
+ class seqIo_writer():
662
+ def __init__(self,filename,old_header):
663
+ self.file = open(filename,'wb')
664
+ self.file.seek(0,0)
665
+ self.header=old_header
666
+
667
+ #create space for header
668
+ fwrite(self.file,np.zeros(1024).astype(int),np.uint8)
669
+
670
+ assert(set(['width','height','fps','codec']).issubset(self.header.keys()))
671
+
672
+ codec = self.header['codec']
673
+ if codec in ['monoraw', 'imageFormat100']: self.frmt = 100;self.nch = 1;self.ext = 'raw'
674
+ elif codec in ['raw', 'imageFormat200']: self.frmt = 200;self.nch = 3;self.ext = 'raw'
675
+ elif codec in ['monojpg', 'imageFormat102']: self.frmt = 102;self.nch = 1;self.ext = 'jpg'
676
+ elif codec in ['jpg', 'imageFormat201']: self.frmt = 201;self.nch = 3;self.ext = 'jpg'
677
+ elif codec in ['monopng', 'imageFormat001']: self.frmt = 0x001;self.nch = 1;self.ext = 'png'
678
+ elif codec in ['png', 'imageFormat002']: self.frmt = 0x002;self.nch = 3;self.ext = 'png'
679
+ else: raise ValueError('unknown format')
680
+
681
+ self.header['imageFormat']=self.frmt
682
+ self.header['imageBitDepth']=8*self.nch
683
+ self.header['imageBitDepthReal']=8
684
+ nBytes = self.header['width']*self.header['height']*self.nch
685
+ self.header['imageSizeBytes']=nBytes
686
+ self.header['numFrames']=0
687
+ self.header['trueImageSize']=nBytes + 6 +512-np.mod(nBytes+6,512)
688
+
689
+ # Close the file
690
+ def close(self):
691
+ self.writeHeader()
692
+ self.file.close()
693
+
694
+ def writeHeader(self):
695
+ self.file.seek(0,0)
696
+ # first write 4 bytes to store 0xFEED, next 24 store 'Norpix seq '
697
+ fwrite(self.file,int('FEED',16),np.uint32)
698
+ name = np.array(['Norpix seq ']).view(np.uint8)
699
+ fwrite(self.file,name, np.uint16)
700
+ # next 8 bytes for version (3) and header size (1024) then 512 for descr
701
+ fwrite(self.file,[3,1024],np.int32)
702
+ if not 'descr' in self.header.keys() or len(np.array([self.header['descr']]).view(np.uint8))>256: d = np.array(['No Description']).view(np.uint8)
703
+ else: d= np.array([self.header['descr']]).view(np.uint8)
704
+ d = np.concatenate((d[:np.minimum(256,len(d))],np.zeros(256-len(d)).astype(np.uint8)))
705
+ fwrite(self.file,d,np.uint16)
706
+ #write remaining info
707
+ vals= [self.header['width'],self.header['height'],self.header['imageBitDepth'],self.header['imageBitDepthReal'],
708
+ self.header['imageSizeBytes'],self.header['imageFormat'],self.header['numFrames'],0,self.header['trueImageSize']]
709
+ fwrite(self.file,vals,np.uint32)
710
+ # store frame rate and pad with 0s
711
+ fwrite(self.file,self.header['fps'],np.float64)
712
+ fwrite(self.file,np.zeros(432),np.uint8)
713
+
714
+ def addFrame(self,I,ts=0,encode=1):
715
+ nCh = self.header['imageBitDepth']//8
716
+ ext = self.ext
717
+ c = self.header['numFrames']+1
718
+ if encode:
719
+ siz = [self.header['height'],self.header['width'],nCh]
720
+ assert(I.shape[0]==siz[0] and I.shape[1]==siz[1])
721
+ if len(I.shape)==3:
722
+ assert(I.shape[2]==siz[2] or I.shape[2]==self.nch)
723
+ if ext=='raw':
724
+ #write uncompressed image and assume imageBitDepthReal==8
725
+ if not encode : assert(I.size==self.header['imageSizeBytes'])
726
+ else:
727
+ if nCh==3: t=I[:,:,2]; I[:,:,2]=I[:,:,0];I[:,:,0]=t
728
+ if nCh==1: I=I.transpose()
729
+ else: I = np.transpose( np.expand_dims(I, axis=2), (2, 1, 0) )
730
+ # I= I.flat.view(np.uint8)
731
+ I= I.flat
732
+ fwrite(self.file,I,np.uint8)
733
+ pad = self.header['trueImageSize']-self.header['imageSizeBytes']-6
734
+ if ext =='jpg':
735
+ if encode:
736
+ # write frame to a temporary jpg
737
+ cv2.imwrite('tmp.jpg',I, [int(cv2.IMWRITE_JPEG_QUALITY ),80])
738
+ # j=Image.fromarray(I.astype(np.uint8))
739
+ # j.save('tmp.jpg')
740
+ # I=Image.open('tmp.jpg')
741
+ fid = open('tmp.jpg','rb')
742
+ I = fid.read()
743
+ fid.close()
744
+ b=bytearray(I)
745
+ assert (b[0] == 255 and b[1] == 216 and b[-2] == 255 and b[-1] == 217); # JPG
746
+ os.remove('tmp.jpg')
747
+ I = np.array(list(b)).astype(np.uint8)
748
+ nbytes = len(I)+4
749
+ fwrite(self.file,nbytes,np.uint32)
750
+ # self.file.write(I)
751
+ fwrite(self.file,I,np.uint8)
752
+ pad = 10
753
+ if ts==0: ts = (c-1)/self.header['fps']
754
+ s = int(np.floor(ts))
755
+ ms = int(np.round(np.mod(ts,1)*1000))
756
+ fwrite(self.file,s,np.int32)
757
+ fwrite(self.file,ms,np.uint16)
758
+ self.header['numFrames']=c
759
+ if pad>0:
760
+ pad = np.zeros(pad).astype(np.uint8)
761
+ fwrite(self.file,pad,np.uint8)
762
+
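A minimal sketch of copying frames from a reader into a new file with the writer above (paths hypothetical); this mirrors the commented example at the end of this module:

    sr = seqIo_reader('input.seq')
    sw = seqIo_writer('subset.seq', sr.header)
    for f in range(100):                     # copy the first 100 frames
        I, ts = sr.getFrame(f)
        sw.addFrame(I, ts)
    sw.close()
    sr.close()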
763
+ def seqIo_crop(fname, tname, frames):
764
+ """
765
+ Crop sub-sequence from seq file.
766
+
767
+ Frame indices are 0 indexed. frames need not be consecutive and can
768
+ contain duplicates. An index of -1 indicates a blank (all 0) frame. If
769
+ contiguous subset of frames is cropped timestamps are preserved.
770
+
771
+ USAGE
772
+ seqIo( fName, 'crop', tName, frames )
773
+
774
+ INPUTS
775
+ fName - seq file name
776
+ tName - cropped seq file name
777
+ frames - frame indices (0 indexed)
778
+ """
779
+ if not isinstance(frames, np.ndarray): frames=np.array(frames)
780
+ sr = seqIo_reader(fname)
781
+ sw = seqIo_writer(tname,sr.header)
782
+ pad,_= sr.getFrame(0)
783
+ pad = np.zeros(pad.shape).astype(np.uint8)
784
+ kp = (frames>=0) & (frames<sr.header['numFrames'])
785
+ if not np.all(kp): frames = frames[kp]
786
+ print('%i out of bounds frames'% np.sum(~kp))
787
+ ordered = np.all(frames[1:]==frames[:-1]+1)
788
+ n= frames.size
789
+ k=0
790
+ for f in frames:
791
+ if f<0:
792
+ sw.addFrame(pad)
793
+ continue
794
+ I,ts = sr.getFrame(f)
795
+ k+=1
796
+ if ordered:
797
+ sw.addFrame(I,ts)
798
+ else:
799
+ sw.addFrame(I)
800
+ sr.close()
801
+ sw.close()
802
+
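For instance, assuming a file named 'input.seq' exists, the call below would keep the first 1000 frames plus one blank frame:

    seqIo_crop('input.seq', 'input_crop.seq', list(range(1000)) + [-1])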
803
+ def seqIo_toImgs(fName, tDir=[], skip=1, f0=0, f1=np.inf, ext=''):
804
+ """
805
+ Extract images from seq file to target directory or array.
806
+
807
+ USAGE
808
+ Is = seqIo( fName, 'toImgs', [tDir], [skip], [f0], [f1], [ext] )
809
+
810
+ INPUTS
811
+ fName - seq file name
812
+ tDir - [] target directory (if empty extract images to array)
813
+ skip - [1] skip between written frames
814
+ f0 - [0] first frame to write
815
+ f1 - [numFrames-1] last frame to write
816
+ ext - [] optionally save as given type (slow, reconverts)
817
+
818
+ OUTPUTS
819
+ Is - if isempty(tDir) outputs image array (else Is=[])
820
+ """
821
+ sr = seqIo_reader(fName)
822
+ f1 = np.minimum(f1,sr.header['numFrames']-1)
823
+ frames = range(f0,f1,skip)
824
+ n=len(frames)
825
+ k=0
826
+ #output images to array
827
+ if tDir==[]:
828
+ I,_=sr.getFrame(0)
829
+ d = I.shape
830
+ assert(len(d)==2 or len(d)==3)
831
+ try:
832
+ Is = np.zeros((I.shape+(n,))).astype(I.dtype)
833
+ except:
834
+ sr.close()
835
+ raise
836
+ for k in range(n):
837
+ I,ts = sr.getFrame(k)
838
+ if len(d)==2:
839
+ Is[:,:,k]=I
840
+ else:
841
+ Is[:,:,:,k]=I
842
+ print('saved %d' % k)
843
+
844
+ sr.close()
845
+ # output image directory
846
+ if not os.path.exists(tDir):os.makedirs(tDir)
847
+ if not tDir.endswith('/'): tDir += '/'
848
+ Is = np.array([])
849
+ for frame in frames:
850
+ f = tDir + 'I%05d.' % (frame)
851
+ I, ts = sr.getFrame(frame)
852
+ if ext!='':
853
+ cv2.imwrite(f+ext,I)
854
+ else:
855
+ cv2.imwrite(f+sr.ext, I)
856
+ k+=1
857
+ print('saved %d' % frame)
858
+ sr.close()
859
+ return Is
860
+
861
+ def seqIo_frImgs(fName, header=[], aviName=[], Is=[], sDir=[], name='I', ndig=5, f0=0, f1=1e6):
862
+ """
863
+ Create seq file from an array or directory of images or from an AVI file.
864
+
865
+ For info, if converting from array, only codec (e.g., 'jpg') and fps must
866
+ be specified while width and height and determined automatically. If
867
+ converting from AVI, fps is also determined automatically.
868
+
869
+ USAGE
870
+ seqIo( fName, 'frImgs', info, varargin )
871
+
872
+ INPUTS
873
+ fName - seq file name
874
+ info - defines codec, etc, see seqIo>writer
875
+ varargin - additional params (struct or name/value pairs)
876
+ .aviName - [] if specified create seq from avi file
877
+ .Is - [] if specified create seq from image array
878
+ .sDir - [] source directory
879
+ .skip - [1] skip between frames
880
+ .name - ['I'] base name of images
881
+ .nDigits - [5] number of digits for filename index
882
+ .f0 - [0] first frame to read
883
+ .f1 - [10^6] last frame to read
884
+ """
885
+
886
+ if aviName!=[]: #avi movie exists
887
+ vc = cv2.VideoCapture(aviName)
888
+ if vc.isOpened(): rval = True
889
+ else:
890
+ rval = False
891
+ print('video not readable')
892
+ return
893
+ fps = vc.get(cv2.cv.CV_CAP_PROP_FPS)
894
+ NUM_FRAMES = int(vc.get(cv2.cv.CV_CAP_PROP_FRAME_COUNT))
895
+ print(NUM_FRAMES)
896
+ IM_TOP_H = vc.get(cv2.cv.CV_CAP_PROP_FRAME_HEIGHT)
897
+ IM_TOP_W = vc.get(cv2.cv.CV_CAP_PROP_FRAME_WIDTH)
898
+ header['width']=IM_TOP_W
899
+ header['height']=IM_TOP_H
900
+ header['fps']=fps
901
+
902
+ sw = seqIo_writer(fName,header)
903
+ print('creating seq from AVI')
904
+ # initialize timer
905
+ timer = pb.ProgressBar(widgets=['Converting ', pb.Percentage(), ' -- ',
906
+ pb.FormatLabel('Frame %(value)d'), '/',
907
+ pb.FormatLabel('%(max)d'), ' [', pb.Timer(), '] ',
908
+ pb.Bar(), ' (', pb.ETA(), ') '], maxval=NUM_FRAMES)
909
+ for f in range(NUM_FRAMES):
910
+ rval, im = vc.read()
911
+ if rval:
912
+ im= im.astype(np.uint8)
913
+ sw.addFrame(im)
914
+ timer.update(f)
915
+ sw.close()
916
+ timer.finish()
917
+ elif Is==[]:
918
+ assert(os.path.isdir(sDir))
919
+ sw = seqIo_writer(fName,header)
920
+ frmstr = '%s/%s%%0%ii.%s' % (sDir,name,ndig,header.ext)
921
+ for frame in range(f0,f1):
922
+ f = frmstr % frame
923
+ if not os.path.isfile(f):break
924
+ fid = open(f, 'rb')
925
+ # open() raises on failure, so no file-handle check is needed here
926
+ I = fid.read()
927
+ fid.close()
928
+ b = bytearray(I)
929
+ assert (b[0] == 255 and b[1] == 216 and b[-2] == 255 and b[-1] == 217); # JPG
930
+ I = np.array(list(b)).astype(np.uint8)
931
+ sw.addFrame(I,0,0)
932
+ sw.close()
933
+ if frame==f0: print('No images found')
934
+ else:
935
+ nd = len(Is.shape)
936
+ if nd==2: nd=3
937
+ assert(nd<=4)
938
+ nFrm = Is.shape[nd-1]
939
+ header['height']=Is.shape[0]
940
+ header['width']=Is.shape[1]
941
+ sw =seqIo_writer(fName,header)
942
+ if nd==3:
943
+ for f in range(nFrm): sw.addFrame(Is[:,:,f])
944
+ if nd==4:
945
+ for f in range(nFrm): sw.addFrame(Is[:,:,:,f])
946
+ sw.close()
947
+
948
+ def seqIo_convert(fName, tName, imgFun, info=[], skip=1, f0=0, f1=np.inf):
949
+ """
950
+ Convert seq file by applying imgFun(I) to each frame I.
951
+
952
+ USAGE
953
+ seqIo( fName, 'convert', tName, imgFun, varargin )
954
+
955
+ INPUTS
956
+ fName - seq file name
957
+ tName - converted seq file name
958
+ imgFun - function to apply to each image
959
+ varargin - additional params (struct or name/value pairs)
960
+ .info - [] info for target seq file
961
+ .skip - [1] skip between frames
962
+ .f0 - [0] first frame to read
963
+ .f1 - [inf] last frame to read
964
+ """
965
+ assert(fName!=tName)
966
+ sr = seqIo_reader(fName)
967
+ if info==[]: info=sr.header
968
+ n=sr.header['numFrames']
969
+ f1=np.minimum(f1,n-1)
970
+ I,ts=sr.getFrame(0)
971
+ I=imgFun(I)
972
+ info['width']=I.shape[1]
973
+ info['height']=I.shape[0]
974
+ sw =seqIo_writer(tName,info)
975
+ print('converting seq')
976
+ for frame in range(f0,f1,skip):
977
+ I, ts = sr.getFrame(frame)
978
+ I = imgFun(I)
979
+ if skip==1:
980
+ sw.addFrame(I,ts)
981
+ else:
982
+ sw.addFrame(I)
983
+ sw.close()
984
+ sr.close()
985
+
986
+ def seqIo_newHeader(fName, info):
987
+ """
988
+ Replace header of seq file with provided info.
989
+
990
+ Can be used if the file fName has a corrupt header. Automatically tries
991
+ to compute number of frames in fName. No guarantees that it will work.
992
+
993
+ USAGE
994
+ seqIo( fName, 'newHeader', info )
995
+
996
+ INPUTS
997
+ fName - seq file name
998
+ info - info for target seq file
999
+ """
1000
+ d, n = os.path.split(fName)
1001
+ if d=='': d='./'
1002
+ tName=fName[:-4] + '_new' + time.strftime("%d_%m_%Y") + fName[-4:]
1003
+ sr = seqIo_reader(fName)
1004
+ sw = seqIo_writer(tName,info)
1005
+ n=sr.header['numFrames']
1006
+ for f in range(n):
1007
+ I,ts=sr.getFrame(f)
1008
+ sw.addFrame(I,ts)
1009
+ sr.close()
1010
+ sw.close()
1011
+
1012
+ class seqIo_dualReader():
1013
+ """
1014
+ seqIo_dualReader
1015
+ Create interface sr for reading dual seq files.
1016
+
1017
+ Wrapper for two seq files of the same image dims and roughly the same
1018
+ frame counts that are treated as a single reader object. getframe()
1019
+ returns the concatenation of the two frames. For videos of different
1020
+ frame counts, the first video serves as the "dominant" video and the
1021
+ frame count of the second video is adjusted accordingly. Same general
1022
+ usage as in reader, but the only supported operations are: close(),
1023
+ getframe(), getinfo(), and seek().
1024
+
1025
+ USAGE
1026
+ sr = seqIo( fNames, 'readerDual', [cache] )
1027
+
1028
+ INPUTS
1029
+ fNames - two seq file names
1030
+ cache - [0] size of cache (see seqIo>reader)
1031
+
1032
+ OUTPUTS
1033
+ sr - interface for reading seq file
1034
+ """
1035
+ def __init__(self,file1,file2):
1036
+ self.s1 = seqIo_reader(file1)
1037
+ self.s2 = seqIo_reader(file2)
1038
+ self.info = self.s1.header
1039
+ # set the display to be vertically aligned
1040
+ self.info['height']=self.s1.header['height']+self.s2.header['height']
1041
+ self.info['width']=np.maximum(self.s1.header['width'],self.s2.header['width'])
1042
+
1043
+ if self.s1.header['numFrames']!=self.s2.header['numFrames']:
1044
+ print('Two video files have different numbers of frames')
1045
+ print('1st video has %d frames' % self.s1.header['numFrames'])
1046
+ print('2nd video has %d frames' % self.s2.header['numFrames'])
1047
+ print('first video %s is used as annotation reference' % file1)
1048
+
1049
+ def getFrame(self, index):
1050
+ I1,ts = self.s1.getFrame(index)
1051
+ I2,_ = self.s2.getFrame(index)
1052
+
1053
+ w1 = I1.shape[1]
1054
+ w2 = I2.shape[1]
1055
+
1056
+ if w1!=w2:
1057
+ m = np.argmax([w1, w2])
1058
+ if m==0:
1059
+ wl = int(np.floor((w1-w2)/2.))
1060
+ wr = w1-w2-wl
1061
+ nd = len(I2.shape)
1062
+ if nd==2:
1063
+ padl = np.zeros((I2.shape[0],wl)).astype(np.uint8)
1064
+ padr = np.zeros((I2.shape[0],wr)).astype(np.uint8)
1065
+ else:
1066
+ padl = np.zeros((I2.shape[0],wl,I2.shape[2])).astype(np.uint8)
1067
+ padr = np.zeros((I2.shape[0],wr,I2.shape[2])).astype(np.uint8)
1068
+ I2 = np.concatenate((padl,I2,padr),axis=1)
1069
+ else:
1070
+ wl = int(np.floor((w2 - w1) / 2.))
1071
+ wr = w2 - w1 - wl
1072
+ nd = len(I1.shape)
1073
+ if nd == 2:
1074
+ padl = np.zeros((I1.shape[0], wl)).astype(np.uint8)
1075
+ padr = np.zeros((I1.shape[0], wr)).astype(np.uint8)
1076
+ else:
1077
+ padl = np.zeros((I1.shape[0], wl, I1.shape[2])).astype(np.uint8)
1078
+ padr = np.zeros((I1.shape[0], wr, I1.shape[2])).astype(np.uint8)
1079
+ I1 = np.concatenate((padl, I1, padr), axis=1)
1080
+ I = np.hstack((I1,I2))
1081
+ return I,ts
1082
+
1083
+ class seqIo_extractor():
1084
+ """
1085
+ Create new seq files from top and front view and synchronize them if they are not already synchronized.
1086
+ path_vid: video path
1087
+ vid_top: seq top video path and name
1088
+ vid_front: seq front video path and name
1089
+ s: start frame
1090
+ e: end frame
1091
+
1092
+ """
1093
+ def __init__(self,path_vid,vid_top,vid_front,s,e):
1094
+ sr_top = seqIo_reader(path_vid+vid_top)
1095
+ sr_front = seqIo_reader(path_vid+vid_front)
1096
+ num_frames=sr_top.header['numFrames']
1097
+ num_framesf=sr_front.header['numFrames']
1098
+ name = os.path.dirname(vid_top).split('/')[-1]
1099
+
1100
+ if not os.path.exists(path_vid + name + '_%06d_%06d' % (s, e)):
1101
+ os.makedirs(path_vid + name + '_%06d_%06d' % (s, e))
1102
+ newdir = path_vid + name + '_%06d_%06d' % (s, e)
1103
+ video_out_top = newdir + '/' + name + '_%06d_%06d_Top_J85.seq' % (s, e)
1104
+ video_out_front = newdir + '/' + name + '_%06d_%06d_Front_J85.seq' % (s, e)
1105
+
1106
+ sw_top = seqIo_writer(video_out_top, sr_top.header)
1107
+ sw_front = seqIo_writer(video_out_front, sr_front.header)
1108
+
1109
+ for f in range(s - 1, e):
1110
+ if num_framesf > num_frames:
1111
+ I_top, ts = sr_top.getFrame(f2(f))
1112
+ I_front, ts2 = sr_front.getFrame(f)
1113
+ else:
1114
+ I_top, ts = sr_top.getFrame(f)
1115
+ I_front, ts2 = sr_front.getFrame(f2(f))
1116
+ sw_top.addFrame(I_top, ts)
1117
+ sw_front.addFrame(I_front, ts2)
1118
+ print(f)
1119
+ sw_top.close()
1120
+ sw_front.close()
1121
+
1122
+ def f2(f):
1123
+ return int(round(f / (num_framesf - 1) * (num_frames - 1))) if num_framesf > num_frames else int(round(f / (num_frames - 1) * (num_framesf - 1)))
1124
+
1125
+ def seqIo_toVid(fName, ext='avi'):
1126
+ """
1127
+ seqIo_toVid
1128
+ Create seq file to another common used format as avi or mp4.
1129
+
1130
+ USAGE
1131
+ seqIo( fName, ext )
1132
+
1133
+ INPUTS
1134
+ fName - seq file name
1135
+ ext - video extension to convert to
1136
+ """
1137
+
1138
+ assert fName[-3:]=='seq', 'Not a seq file'
1139
+ sr = seqIo_reader(fName)
1140
+ N = sr.header['numFrames']
1141
+ h = sr.header['height']
1142
+ w = sr.header['width']
1143
+ fps = sr.header['fps']
1144
+
1145
+ out = fName[:-3]+ext
1146
+ sw = skvideo.io.FFmpegWriter(out)
1147
+ # sw = cv2.VideoWriter(out, -1, fps, (w, h))
1148
+ timer = pb.ProgressBar(widgets=['Converting ', pb.Percentage(), ' -- ',
1149
+ pb.FormatLabel('Frame %(value)d'), '/',
1150
+ pb.FormatLabel('%(max)d'), ' [', pb.Timer(), '] ',
1151
+ pb.Bar(), ' (', pb.ETA(), ') '], maxval=N)
1152
+
1153
+ for f in range(N):
1154
+ I, ts = sr.getFrame(f)
1155
+ #sw.writeFrame(Image.fromarray(I))
1156
+ sw.write(I)
1157
+ timer.update(f)
1158
+ timer.finish()
1159
+ # cv2.destroyAllWindows()
1160
+ # sw.release()
1161
+ sw.close()
1162
+ sr.close()
1163
+ print(out + ' converted')
1164
+
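A usage sketch for the converter above (the path is hypothetical; scikit-video's FFmpegWriter must be available, as used in the function):

    seqIo_toVid('Mouse156_Top.seq', ext='mp4')   # writes Mouse156_Top.mp4 next to the input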
1165
+
1166
+
1167
+
1168
+ # minimum header
1169
+ # header = {'width': IM_TOP_W,
1170
+ # 'height': IM_TOP_H,
1171
+ # 'fps': fps,
1172
+ # 'codec': 'imageFormat102'}
1173
+ # filename= '/media/cristina/MARS_data/mice_project/teresa/Mouse156_20161017_17-22-09/Mouse156_20161017_17-22-09_Top_J85.seq'
1174
+ # filename_out = filename[:-4] + '_new.seq'
1175
+ # reader = seqIo_reader(filename)
1176
+ # reader.header
1177
+ # Initialize a SEQ writer
1178
+ # writer = seqIo_writer(filename_out,reader.header)
1179
+ # I,ts = reader.getFrame(0)
1180
+ # writer.addFrame(I,ts)
1181
+ # for f in range(8):
1182
+ # I,ts = reader.getFrame(f)
1183
+ # print(writer.file.tell())
1184
+ # writer.addFrame(I,ts)
1185
+ # writer.close()
1186
+ # reader.close()
1187
+
1188
+
1189
+
utils/utils.py ADDED
@@ -0,0 +1,632 @@
1
+ import os
2
+ import io
3
+ import pickle
4
+ import copy
5
+ from collections import Counter
6
+ from pathlib import Path
7
+ from tempfile import NamedTemporaryFile
8
+ import regex as re
9
+ import numpy as np
10
+ import pandas as pd
11
+ from sklearn.manifold import TSNE
12
+ from sklearn.svm import SVC
13
+ from sklearn.model_selection import train_test_split
14
+ from sklearn.metrics import accuracy_score, classification_report
15
+ import torch
16
+ from tqdm import tqdm
17
+ from PIL import Image
18
+ from transformers import AutoProcessor, AutoModel
19
+ import streamlit as st
20
+ from .data_loading import load_multiple_annotations, load_multiple_annotations_io
21
+ from .data_processing import generate_label_array
22
+ from .seqIo import seqIo_reader
23
+ from .mp4Io import mp4Io_reader
24
+
25
+ SLIP_MODEL_ID = "google/siglip-so400m-patch14-384"
26
+ CLIP_MODEL_ID = "openai/clip-vit-base-patch32"
27
+
28
+ def create_annot_fname_dict(annot_fnames: list[str])-> dict:
29
+ fs = re.compile(r'.*(_\d+)$')
30
+
31
+ unique_files = set()
32
+ for file in annot_fnames:
33
+ file_name = os.fsdecode(file)
34
+ base_name, _ = os.path.splitext(file_name)
35
+ if fs.match(base_name):
36
+ ind = len(fs.match(base_name).group(1))
37
+ unique_files.add(base_name[:-ind])
38
+ else:
39
+ unique_files.add(base_name)
40
+
41
+ annot_fname_dict = {}
42
+ for unique_file in unique_files:
43
+ annot_fname_dict.update({unique_file: [file for file in annot_fnames if unique_file in file]})
44
+ return annot_fname_dict
45
+
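For example (file names are hypothetical; key order may vary since a set is used internally):

    fnames = ['mouse1_Top_1.annot', 'mouse1_Top_2.annot', 'mouse2_Top.annot']
    print(create_annot_fname_dict(fnames))
    # {'mouse1_Top': ['mouse1_Top_1.annot', 'mouse1_Top_2.annot'], 'mouse2_Top': ['mouse2_Top.annot']}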
46
+ def create_annot_fname_dict_io(annot_fnames: list[str], annot_files: list)-> dict:
47
+ annot_file_dict = {}
48
+ for file in annot_files:
49
+ annot_file_dict.update({file.name : file})
50
+ fs = re.compile(r'.*(_\d+)$')
51
+
52
+ unique_files = set()
53
+ for file in annot_fnames:
54
+ file_name = os.fsdecode(file)
55
+ base_name, _ = os.path.splitext(file_name)
56
+ if fs.match(base_name):
57
+ ind = len(fs.match(base_name).group(1))
58
+ unique_files.add(base_name[:-ind])
59
+ else:
60
+ unique_files.add(base_name)
61
+
62
+ annot_fname_dict = {}
63
+ for unique_file in unique_files:
64
+ annot_list = [file for file in annot_fnames if unique_file in file]
65
+ annot_list.sort()
66
+ annot_file_list = [annot_file_dict[annot_file_name] for annot_file_name in annot_list]
67
+ annot_fname_dict.update({unique_file: annot_file_list})
68
+ return annot_fname_dict
69
+
70
+ def get_io_reader(uploaded_file):
71
+ assert uploaded_file.name[-3:]=='seq', 'Not a seq file'
72
+ with NamedTemporaryFile(suffix="seq", delete=False) as temp:
73
+ temp.write(uploaded_file.getvalue())
74
+ sr = seqIo_reader(temp.name)
75
+ return sr
76
+
77
+ def load_slip_model(device):
78
+ return AutoModel.from_pretrained(SLIP_MODEL_ID).to(device)
79
+
80
+ def load_slip_preprocessor():
81
+ return AutoProcessor.from_pretrained(SLIP_MODEL_ID)
82
+
83
+ def load_clip_model(device):
84
+ return AutoModel.from_pretrained(CLIP_MODEL_ID).to(device)
85
+
86
+ def load_clip_preprocessor():
87
+ return AutoProcessor.from_pretrained(CLIP_MODEL_ID)
88
+
89
+ def encode_image(image, device, model, processor):
90
+ with torch.no_grad():
91
+ #convert_models_to_fp32(model)
92
+ inputs = processor(images=image, return_tensors="pt").to(device)
93
+ image_features = model.get_image_features(**inputs)
94
+ return image_features.cpu().numpy().flatten()
95
+
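A minimal sketch of embedding a single image with the helpers above (downloads the SigLIP checkpoint on first use; the blank image is a stand-in for a real video frame):

    device = "cuda" if torch.cuda.is_available() else "cpu"
    model = load_slip_model(device)
    processor = load_slip_preprocessor()
    img = Image.new('RGB', (384, 384))       # placeholder frame
    vec = encode_image(img, device, model, processor)
    print(vec.shape)                         # 1-D embedding vector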
96
+ def generate_embeddings_stream(fnames : list[str],
97
+ model = 'SLIP',
98
+ downsample_rate = 4,
99
+ save_csv = False)-> tuple[list, list]:
100
+ # set up model and device
101
+ device = "cuda" if torch.cuda.is_available() else "cpu"
102
+ os.environ['CUDA_VISIBLE_DEVICES'] = '0'
103
+ if model == 'SLIP':
104
+ embed_model = load_slip_model(device)
105
+ processor = load_slip_preprocessor()
106
+ elif model == 'CLIP':
107
+ embed_model = load_clip_model(device)
108
+ processor = load_clip_preprocessor()
109
+
110
+ all_video_embeddings = []
111
+ all_video_frames = []
112
+ for fname in fnames:
113
+ # read in file
114
+ is_seq = False
115
+ if fname[-3:] == 'seq': is_seq = True
116
+
117
+ if is_seq:
118
+ sr = seqIo_reader(fname)
119
+ else:
120
+ sr = mp4Io_reader(fname)
121
+ N = sr.header['numFrames']
122
+
123
+ # set up embeddings and frame arrays
124
+ embeddings = []
125
+ frames = list(range(N))[::downsample_rate]
126
+ print(frames)
127
+
128
+ # create progress bar
129
+ i = 0
130
+ pbar_text = lambda i: f'Creating embeddings for {fname}. {i}/{len(frames)} frames.'
131
+ pbar = st.progress(0, text=pbar_text(0))
132
+
133
+ # convert each frame to embeddings
134
+ for f in tqdm(frames):
135
+ img, _ = sr.getFrame(f)
136
+ img_arr = np.array(img)
137
+ if is_seq:
138
+ img_rgb = Image.fromarray(img_arr, 'L').convert('RGB')
139
+ else:
140
+ img_rgb = Image.fromarray(img_arr).convert('RGB')
141
+
142
+ embeddings.append(encode_image(img_rgb, device, embed_model, processor))
143
+
144
+ # update progress bar
145
+ i += 1
146
+ pbar.progress(i/len(frames), pbar_text(i))
147
+
148
+ # save csv of single file
149
+ if save_csv:
150
+ df = pd.DataFrame(embeddings)
151
+ df['Frame'] = frames
152
+
153
+ # save csv
154
+ basename = Path(fname).stem
155
+ df.to_csv(f'{basename}_embeddings_downsample_{downsample_rate}.csv', index=False)
156
+
157
+ all_video_embeddings.append(np.array(embeddings))
158
+ all_video_frames.append(frames)
159
+ return all_video_embeddings, all_video_frames
160
+
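A usage sketch for the function above (file names are hypothetical; because it drives `st.progress`, it is meant to be called from inside a Streamlit page):

    embeds, frames = generate_embeddings_stream(
        ['mouse1_Top.mp4', 'mouse2_Top.seq'],
        model='SLIP',
        downsample_rate=4)
    print(embeds[0].shape)     # (number of kept frames, embedding dimension)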
161
+ def get_io_reader(uploaded_file):
162
+ if uploaded_file.name[-3:]=='seq':
163
+ with NamedTemporaryFile(suffix="seq", delete=False) as temp:
164
+ temp.write(uploaded_file.getvalue())
165
+ sr = seqIo_reader(temp.name)
166
+ else:
167
+ with NamedTemporaryFile(suffix="mp4", delete=False) as temp:
168
+ temp.write(uploaded_file.getvalue())
169
+ sr = mp4Io_reader(temp.name)
170
+ return sr
171
+
172
+ def generate_embeddings_stream_io(uploaded_files : list,
173
+ model = 'SLIP',
174
+ downsample_rate = 4,
175
+ save_csv = False)-> tuple[list, list]:
176
+ # set up model and device
177
+ device = "cuda" if torch.cuda.is_available() else "cpu"
178
+ os.environ['CUDA_VISIBLE_DEVICES'] = '0'
179
+ if model == 'SLIP':
180
+ embed_model = load_slip_model(device)
181
+ processor = load_slip_preprocessor()
182
+ elif model == 'CLIP':
183
+ embed_model = load_clip_model(device)
184
+ processor = load_clip_preprocessor()
185
+
186
+ all_video_embeddings = []
187
+ all_video_frames = []
188
+ for file in uploaded_files:
189
+ is_seq = False
190
+ if file.name[-3:] == 'seq': is_seq = True
191
+
192
+ # read in file
193
+ sr = get_io_reader(file)
194
+ N = sr.header['numFrames']
195
+
196
+ # set up embeddings and frame arrays
197
+ embeddings = []
198
+ frames = list(range(N))[::downsample_rate]
199
+ print(frames)
200
+
201
+ # create progress bar
202
+ i = 0
203
+ pbar_text = lambda i: f'Creating embeddings for {file.name}. {i}/{len(frames)} frames.'
204
+ pbar = st.progress(0, text=pbar_text(0))
205
+
206
+ # convert each frame to embeddings
207
+ for f in tqdm(frames):
208
+ img, _ = sr.getFrame(f)
209
+ img_arr = np.array(img)
210
+ if is_seq:
211
+ img_rgb = Image.fromarray(img_arr, 'L').convert('RGB')
212
+ else:
213
+ img_rgb = Image.fromarray(img_arr).convert('RGB')
214
+
215
+ embeddings.append(encode_image(img_rgb, device, embed_model, processor))
216
+
217
+ # update progress bar
218
+ i += 1
219
+ pbar.progress(i/len(frames), pbar_text(i))
220
+
221
+ # save csv of single file
222
+ if save_csv:
223
+ df = pd.DataFrame(embeddings)
224
+ df['Frame'] = frames
225
+
226
+ # save csv
227
+ df.to_csv(f'embeddings_downsample_{downsample_rate}_{len(frames)}_frames.csv', index=False)
228
+
229
+ all_video_embeddings.append(np.array(embeddings))
230
+ all_video_frames.append(frames)
231
+ return all_video_embeddings, all_video_frames
232
+
233
+ def create_embeddings_csv(out: str,
234
+ fnames: list[str],
235
+ embeddings: list[np.ndarray],
236
+ frames: list[list[int]],
237
+ annotations: list[list[str]],
238
+ test_fnames: None | list[str],
239
+ views: None | list[str],
240
+ conditions: None | list[str],
241
+ downsample_rate = 4,
242
+ filesystem = None):
243
+ """
244
+ Creates a .csv file containing all of the generated embeddings and provided information.
245
+
246
+ Parameters:
247
+ -----------
248
+ out : str
249
+ The name of the resulting file.
250
+ fnames : list[str]
251
+ Video sources for each of the embedding arrays.
252
+ embeddings : np.ndarray
253
+ The generated embeddings from the images.
254
+ downsample_rate : int
255
+ The downsample_rate used for generating the embeddings.
256
+ """
257
+ assert len(fnames) == len(embeddings)
258
+ assert len(embeddings) == len(frames)
259
+ all_embeddings = np.vstack(embeddings)
260
+ df = pd.DataFrame(all_embeddings)
261
+
262
+ labels = []
263
+ for i, annot_fnames in enumerate(annotations):
264
+ _, ext = os.path.splitext(annot_fnames[0])
265
+ if ext == '.annot':
266
+ annot, _, _, sr = load_multiple_annotations(annot_fnames, filesystem=filesystem)
267
+ annot_labels = generate_label_array(annot, downsample_rate, len(frames[i]))
268
+ elif ext == '.csv':
269
+ if not filesystem:
270
+ annot_df = pd.read_csv(annot_fnames[0], header=None)
271
+ else:
272
+ with filesystem.open(annot_fnames[0], 'r') as csv_file:
273
+ annot_df = pd.read_csv(csv_file, header=None)
274
+ annot_labels = annot_df[0].to_list()[::downsample_rate]
275
+ assert len(annot_labels) == len(frames[i]), "There is a mismatch between the number of frames and number of labels. Make sure that the passed in csv file has no header."
276
+ else:
277
+ raise ValueError(f'Incompatible file for annotations used. Got a file of type "{ext}".')
278
+ assert len(annot_labels) == len(frames[i]), "There is a mismatch between the number of frames and number of labels. Make sure you have passed in the correct files."
279
+ print(annot_labels)
280
+ labels.append(annot_labels)
281
+ all_labels = np.hstack(labels)
282
+ print(len(all_labels))
283
+ df['Label'] = all_labels
284
+
285
+ all_frames = np.hstack(frames)
286
+ df['Frame'] = all_frames
287
+ sources = [[fname for _ in range(len(frames[i]))] for i, fname in enumerate(fnames)]
288
+ all_sources = np.hstack(sources)
289
+ df['Source'] = all_sources
290
+
291
+ if test_fnames:
292
+ t_split = lambda x: True if x in test_fnames else False
293
+ test = [[t_split(fname) for _ in range(len(frames[i]))] for i, fname in enumerate(fnames)]
294
+ else:
295
+ test = [[True for _ in range(len(frames[i]))] for i, _ in enumerate(fnames)]
296
+ all_test = np.hstack(test)
297
+ df['Test'] = all_test
298
+
299
+ if views:
300
+ view = [[views[i] for _ in range(len(frames[i]))] for i in range(len(fnames))]
301
+ else:
302
+ view = [[None for _ in range(len(frames[i]))] for i in range(len(fnames))]
303
+ all_view = np.hstack(view)
304
+ df['View'] = all_view
305
+
306
+ if conditions:
307
+ condition = [[conditions[i] for _ in range(len(frames[i]))] for i in range(len(fnames))]
308
+ else:
309
+ condition = [[None for _ in range(len(frames[i]))] for i in range(len(fnames))]
310
+ all_condition = np.hstack(condition)
311
+ df['Condition'] = all_condition
312
+ return df
313
+
314
+ def create_embeddings_csv_io(out: str,
315
+ fnames: list[str],
316
+ embeddings: list[np.ndarray],
317
+ frames: list[list[int]],
318
+ annotations: list,
319
+ test_fnames: None | list[str],
320
+ views: None | list[str],
321
+ conditions: None | list[str],
322
+ downsample_rate = 4):
323
+ """
324
+ Creates a .csv file containing all of the generated embeddings and provided information.
325
+
326
+ Parameters:
327
+ -----------
328
+ out : str
329
+ The name of the resulting file.
330
+ fnames : list[str]
331
+ Video sources for each of the embedding arrays.
332
+ embeddings : np.ndarray
333
+ The generated embeddings from the images.
334
+ downsample_rate : int
335
+ The downsample_rate used for generating the embeddings.
336
+ """
337
+ assert len(fnames) == len(embeddings)
338
+ assert len(embeddings) == len(frames)
339
+ all_embeddings = np.vstack(embeddings)
340
+ df = pd.DataFrame(all_embeddings)
341
+
342
+ labels = []
343
+ for i, uploaded_annots in enumerate(annotations):
344
+ print(i)
345
+ _, ext = os.path.splitext(uploaded_annots[0].name)
346
+ if ext == '.annot':
347
+ annot, _, _, sr = load_multiple_annotations_io(uploaded_annots)
348
+ annot_labels = generate_label_array(annot, downsample_rate, len(frames[i]))
349
+ elif ext == '.csv':
350
+ annot_df = pd.read_csv(uploaded_annots[0], header=None)
351
+ annot_labels = annot_df[0].to_list()[::downsample_rate]
352
+ assert len(annot_labels) == len(frames[i]), "There is a mismatch between the number of frames and number of labels. Make sure that the passed in csv file has no header."
353
+ else:
354
+ raise ValueError(f'Incompatible file for annotations used. Got a file of type "{ext}".')
355
+ assert len(annot_labels) == len(frames[i]), "There is a mismatch between the number of frames and number of labels. Make sure you have passed in the correct files."
356
+ print(annot_labels)
357
+ labels.append(annot_labels)
358
+ all_labels = np.hstack(labels)
359
+ print(len(all_labels))
360
+ df['Label'] = all_labels
361
+
362
+ all_frames = np.hstack(frames)
363
+ df['Frame'] = all_frames
364
+ sources = [[fname for _ in range(len(frames[i]))] for i, fname in enumerate(fnames)]
365
+ all_sources = np.hstack(sources)
366
+ df['Source'] = all_sources
367
+
368
+ if test_fnames:
369
+ t_split = lambda x: True if x in test_fnames else False
370
+ test = [[t_split(fname) for _ in range(len(frames[i]))] for i, fname in enumerate(fnames)]
371
+ else:
372
+ test = [[True for _ in range(len(frames[i]))] for i, _ in enumerate(fnames)]
373
+ all_test = np.hstack(test)
374
+ df['Test'] = all_test
375
+
376
+ if views:
377
+ view = [[views[i] for _ in range(len(frames[i]))] for i in range(len(fnames))]
378
+ else:
379
+ view = [[None for _ in range(len(frames[i]))] for i in range(len(fnames))]
380
+ all_view = np.hstack(view)
381
+ df['View'] = all_view
382
+
383
+ if conditions:
384
+ condition = [[conditions[i] for _ in range(len(frames[i]))] for i in range(len(fnames))]
385
+ else:
386
+ condition = [[None for _ in range(len(frames[i]))] for i in range(len(fnames))]
387
+ all_condition = np.hstack(condition)
388
+ df['Condition'] = all_condition
389
+ return df
390
+
391
+ def process_dataset_in_mem(embeddings_df: pd.DataFrame,
392
+ specified_classes=None,
393
+ classes_to_remove=None,
394
+ max_class_size=None,
395
+ animal_state=None,
396
+ view=None,
397
+ shuffle_data=False,
398
+ test_videos=None):
399
+ """
400
+ Processes output generated from embeddings paired with images and behavior labels.
401
+
402
+ Parameters:
403
+ -----------
404
+ embeddings_df : pandas.DataFrame
405
+ DataFrame containing the original data. This should contain embeddings,
406
+ a column named `'Label'` and either a column named `'Images'` or `'Frame'`/`'Source'` columns.
407
+ specified_classes : None | list[str]
408
+ An optional input. Defines labels which should be kept as is in the `'Label'`
409
+ column and which should be changed to a default `other` label.
410
+ classes_to_remove : None | list[str]
411
+ An optional input. Drops rows from the dataframe which contain a label in the
412
+ list.
413
+ max_class_size : None | int
414
+ An optional input. Determines the maximum amount of rows a single label can
415
+ appear in for each unique label in the `'Label'` column.
416
+ animal_state : None | str
417
+ An optional input. Drops rows from the dataframe which do not contain a match
418
+ for `animal_state` in the text field within the `'Images'` column.
419
+ view : None | str
420
+ An optional input. Drops rows from the dataframe which do not contain a match
421
+ for `view` in the text field within the `'Images'` column.
422
+ shuffle_data : bool
423
+ Determines wether the dataframe should have its rows shuffled.
424
+ test_videos : None | list[str]
425
+ An optional input. Determines what rows should be in the `test` dataframe, and
426
+ which should be in the `train` dataframe. It drops rows from the respective
427
+ dataframe by keeping or dropping rows which do not contain a match for a `str`
428
+ in `test_videos` in the text field within the `'Images'` column, respectively.
429
+
430
+ Returns:
431
+ --------
432
+ balanced_train_embeddings : pandas.DataFrame
433
+ A processed dataframe whose rows contain the embeddings for each of the images
434
+ at the corresponding index within `balanced_train_images`.
435
+ balanced_train_labels : list[str]
436
+ A list of labels for each of the images at the corresponing index within
437
+ `balanced_train_images`.
438
+ balanced_train_images: list[str]
439
+ A list of paths to images with each image at an index corresponding to a label
440
+ with the same index in `balanced_train_labels` and the same row index within
441
+ `balanced_train_embeddings`.
442
+ test_embeddings : pandas.DataFrame
443
+ A processed dataframe whose rows contain the embeddings for each of the images
444
+ at the corresponding index within `test_images`.
445
+ test_labels : list[str]
446
+ A list of labels for each of the images at the corresponing index within
447
+ `test_images`.
448
+ test_images : list[str]
449
+ A list of paths to images with each image at an index corresponding to a label
450
+ with the same index in `test_labels` and the same row index within
451
+ `test_embeddings`.
452
+ """
453
+ # Convert embeddings, labels, and images to a DataFrame for easy manipulation
454
+ df = copy.deepcopy(embeddings_df)
455
+ df_keys = [str(x) for x in df.keys()]
456
+ #Filter by fed or fasted
457
+ if 'Condition' in df_keys and animal_state:
458
+ df = df[df['Condition'].str.contains(animal_state, na=False)]
459
+
460
+ if 'View' in df_keys and view:
461
+ df = df[df['View'].str.contains(view, na=False)]
462
+
463
+ # Extract unique video names excluding the frame number
464
+ #unique_video_names = df['Images'].apply(lambda x: '_'.join(x.split('_')[:-1])).unique()
465
+ #print("\nUnique video names:\n", unique_video_names)
466
+
467
+ if classes_to_remove:
468
+ df = df[~df['Label'].str.contains('|'.join(classes_to_remove), na=False)]
469
+ elif classes_to_remove and 'all' in classes_to_remove:
470
+ df = df[df['Label'].str.contains('|'.join(classes_to_remove), na=False)]
471
+
472
+ # Further filter to include only specified_classes
473
+ if specified_classes:
474
+ single_match = lambda x: list(set(x.split('||')) & set(specified_classes))[0]
475
+ df['Label'] = df['Label'].apply(lambda x: single_match(x) if not set(x.split('||')).isdisjoint(specified_classes) else 'other')
476
+ specified_classes.append('other')
477
+
478
+ # Separate the DataFrame into test and training sets based on test_videos
479
+ if 'Test' in df_keys and test_videos:
480
+ test_df = df[df['Test']]
481
+ train_df = df[~df['Test']]
482
+ elif test_videos:
483
+ test_df = df[df['Images'].str.contains('|'.join(test_videos), na=False)]
484
+ train_df = df[~df['Images'].str.contains('|'.join(test_videos), na=False)]
485
+ else:
486
+ test_df = pd.DataFrame(columns=df.columns)
487
+ train_df = df
488
+
489
+ # Print the number of frames in each class before balancing
490
+ label_counts = train_df['Label'].value_counts()
491
+ print("\nNumber of training frames in each class before balancing:")
492
+ print(label_counts)
493
+
494
+ if max_class_size:
495
+ balanced_train_df = pd.concat([
496
+ group.sample(n=min(len(group), max_class_size), random_state=1)
497
+ for label, group in train_df.groupby('Label')
498
+ ])
499
+ else:
500
+ balanced_train_df = train_df
501
+
502
+ # Shuffle the training DataFrame
503
+ if shuffle_data:
504
+ balanced_train_df = balanced_train_df.sample(frac=1).reset_index(drop=True)
505
+
506
+ # Convert training set back to numpy array and list
507
+ if not "Images" in df_keys:
508
+ balanced_train_embeddings = balanced_train_df.drop(columns=['Label', 'Frame', 'Source', 'Test','View','Condition']).to_numpy()
509
+ balanced_train_labels = balanced_train_df['Label'].tolist()
510
+ balanced_train_images = balanced_train_df['Frame'].tolist()
511
+
512
+ # Convert test set back to numpy array and list
513
+ test_embeddings = test_df.drop(columns=['Label', 'Frame', 'Source', 'Test','View','Condition']).to_numpy()
514
+ test_labels = test_df['Label'].tolist()
515
+ test_images = test_df['Frame'].tolist()
516
+ else:
517
+ # Convert training set back to numpy array and list
518
+ balanced_train_embeddings = balanced_train_df.drop(columns=['Label', 'Images']).to_numpy()
519
+ balanced_train_labels = balanced_train_df['Label'].tolist()
520
+ balanced_train_images = balanced_train_df['Images'].tolist()
521
+
522
+ # Convert test set back to numpy array and list
523
+ test_embeddings = test_df.drop(columns=['Label', 'Images']).to_numpy()
524
+ test_labels = test_df['Label'].tolist()
525
+ test_images = test_df['Images'].tolist()
526
+
527
+ # Print the number of frames in each class after balancing
528
+ if specified_classes or max_class_size:
529
+ balanced_label_counts = Counter(balanced_train_labels)
530
+ print("\nNumber of training frames in each class after balancing:")
531
+ print(balanced_label_counts)
532
+
533
+ test_label_counts = test_df['Label'].value_counts()
534
+ # print("\nNumber of testing frames in each class:")
535
+ print(test_label_counts)
536
+
537
+ return balanced_train_embeddings, balanced_train_labels, balanced_train_images, test_embeddings, test_labels, test_images
538
+
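A sketch of how the function above might be called on a dataframe built by `create_embeddings_csv_io` (the class names and the test video name are hypothetical):

    (train_X, train_y, train_frames,
     test_X, test_y, test_frames) = process_dataset_in_mem(
        embeddings_df,
        specified_classes=['attack', 'mount'],
        max_class_size=5000,
        shuffle_data=True,
        test_videos=['mouse2_Top'])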
539
+ def multiclass_merge_and_filter_bouts(multiclass_vector, bout_threshold, proximity_threshold):
540
+ # Get the unique labels in the multiclass vector (excluding zero, assuming zero is the background/no label)
541
+ unique_labels = np.unique(multiclass_vector)
542
+ unique_labels = unique_labels[unique_labels != 0]
543
+
544
+ # Initialize a vector to store the merged and filtered multiclass vector
545
+ merged_vector = np.zeros_like(multiclass_vector)
546
+
547
+ for label in unique_labels:
548
+ # Create a binary vector for the current label
549
+ binary_vector = (multiclass_vector == label)
550
+
551
+ # Find the start and end indices of all sequences of 1's for this label
552
+ starts = np.where(np.diff(np.concatenate(([0], binary_vector))) == 1)[0]
553
+ ends = np.where(np.diff(np.concatenate((binary_vector, [0]))) == -1)[0]
554
+
555
+ # Step 1: Merge close short bouts
556
+ i = 0
557
+ while i < len(starts) - 1:
558
+ # Check if the gap between the end of the current bout and the start of the next bout
559
+ # is within the proximity threshold
560
+ if starts[i + 1] - ends[i] <= proximity_threshold:
561
+ # Merge the two bouts by setting all elements between the start of the first
562
+ # and the end of the second bout to 1
563
+ binary_vector[ends[i]:starts[i + 1]] = 1
564
+ # Remove the next bout from consideration
565
+ starts = np.delete(starts, i + 1)
566
+ ends = np.delete(ends, i)
567
+ else:
568
+ i += 1
569
+
570
+ # Update the starts and ends after merging
571
+ starts = np.where(np.diff(np.concatenate(([0], binary_vector))) == 1)[0]
572
+ ends = np.where(np.diff(np.concatenate((binary_vector, [0]))) == -1)[0]
573
+
574
+ # Step 2: Remove standalone short bouts
575
+ for i in range(len(starts)):
576
+ # Check the length of the bout
577
+ length_of_bout = ends[i] - starts[i] + 1
578
+
579
+ # If the length is less than the threshold, set those elements to 0
580
+ if length_of_bout < bout_threshold:
581
+ binary_vector[starts[i]:ends[i] + 1] = 0
582
+
583
+ # Combine the binary vector with the merged_vector, ensuring only the current label is set
584
+ merged_vector[binary_vector] = label
585
+
586
+ # Return the filtered multiclass vector
587
+ return merged_vector
588
+
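A small worked example of the bout smoothing above: with a bout threshold of 3 frames and a proximity threshold of 2 frames, the isolated short bouts are dropped and only the 3-frame bout of label 1 survives.

    v = np.array([0, 1, 1, 0, 0, 1, 1, 1, 0, 2, 0, 0])
    print(multiclass_merge_and_filter_bouts(v, bout_threshold=3, proximity_threshold=2))
    # -> [0 0 0 0 0 1 1 1 0 0 0 0]

If the proximity threshold were raised to 3, the two label-1 bouts would first be merged into a single 7-frame bout and kept.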
589
+ def get_unique_labels(label_list: list[str]):
590
+ label_set = set()
591
+ for label in label_list:
592
+ individual_labels = label.split('||')
593
+ for individual_label in individual_labels:
594
+ label_set.add(individual_label)
595
+ return list(label_set)
596
+
597
+ def get_train_test_split(train_embeds, numerical_labels, test_size=0.05, random_state=42):
598
+ return train_test_split(train_embeds, numerical_labels, test_size=test_size, random_state=random_state)
599
+
600
+ def train_model(X_train, y_train, random_state=42):
601
+ # Train SVM Classifier
602
+ svm_clf = SVC(kernel='rbf', random_state=random_state, probability=True)
603
+ svm_clf.fit(X_train, y_train)
604
+ return svm_clf
605
+
606
+ def pickle_model(model):
607
+ pickled = io.BytesIO()
608
+ pickle.dump(model, pickled)
609
+ return pickled
610
+
611
+ def get_seq_io_reader(uploaded_file):
612
+ assert uploaded_file.name[-3:]=='seq', 'Not a seq file'
613
+ with NamedTemporaryFile(suffix="seq", delete=False) as temp:
614
+ temp.write(uploaded_file.getvalue())
615
+ sr = seqIo_reader(temp.name)
616
+ return sr
617
+
618
+ def seq_to_arr(sr):
619
+ N = sr.header['numFrames']
620
+ images = []
621
+ for f in range(N):
622
+ I, ts = sr.getFrame(f)
623
+ images.append(I)
624
+ return np.array(images)
625
+
626
+ def get_2d_embedding(embeddings: pd.DataFrame):
627
+ tsne = TSNE(n_jobs=4, n_components=2, random_state=42, perplexity=50)
628
+ embedding_2d = tsne.fit_transform(np.array(embeddings))
629
+ return embedding_2d
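A compact end-to-end sketch tying together the helpers above (the labels and random embeddings are synthetic stand-ins; in practice the embeddings come from the CSV utilities earlier in this module):

    labels = ['rest', 'walk', 'rest', 'walk'] * 50
    label_to_id = {l: i for i, l in enumerate(sorted(set(labels)))}
    numerical_labels = np.array([label_to_id[l] for l in labels])
    train_embeds = np.random.rand(len(labels), 16)            # stand-in embeddings
    X_train, X_val, y_train, y_val = get_train_test_split(train_embeds, numerical_labels)
    clf = train_model(X_train, y_train)                       # RBF-kernel SVM
    print(clf.score(X_val, y_val))
    emb2d = get_2d_embedding(pd.DataFrame(train_embeds))      # t-SNE projection for plotting
    pickled = pickle_model(clf)                               # in-memory pickle for download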
630
+
631
+
632
+