Spaces:
Sleeping
Sleeping
add main program files
Browse files- app.py +21 -0
- apply_model.py +195 -0
- explore.py +337 -0
- generate_embeddings.py +131 -0
- get_llava_response.py +184 -0
- home.py +5 -0
- pyproject.toml +43 -0
- train_model.py +159 -0
- utils/__init__.py +0 -0
- utils/annot.py +641 -0
- utils/behavior.py +291 -0
- utils/data_loading.py +198 -0
- utils/data_processing.py +384 -0
- utils/mp4Io.py +60 -0
- utils/seqIo.py +1189 -0
- utils/utils.py +632 -0
app.py
ADDED
@@ -0,0 +1,21 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import streamlit as st
|
3 |
+
|
4 |
+
pages = {
|
5 |
+
"home": [
|
6 |
+
st.Page("home.py",title="about",icon=":material/home:")
|
7 |
+
],
|
8 |
+
"generate embeddings": [
|
9 |
+
st.Page("generate_embeddings.py", title="generate",icon=":material/dataset:")
|
10 |
+
],
|
11 |
+
"annotation": [
|
12 |
+
st.Page("train_model.py", title="train model",icon=":material/model_training:"),
|
13 |
+
st.Page("apply_model.py", title="apply model",icon=":material/grade:"),
|
14 |
+
],
|
15 |
+
"behavior discovery": [
|
16 |
+
st.Page("explore.py", title="explore",icon=":material/search:")
|
17 |
+
],
|
18 |
+
}
|
19 |
+
|
20 |
+
pg = st.navigation(pages)
|
21 |
+
pg.run()
|
apply_model.py
ADDED
@@ -0,0 +1,195 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import pickle
|
3 |
+
from random import random
|
4 |
+
import streamlit as st
|
5 |
+
import matplotlib.pyplot as plt
|
6 |
+
from matplotlib.colors import ListedColormap
|
7 |
+
import numpy as np
|
8 |
+
import pandas as pd
|
9 |
+
import torch
|
10 |
+
from utils.mp4Io import mp4Io_reader
|
11 |
+
from utils.seqIo import seqIo_reader
|
12 |
+
import pandas as pd
|
13 |
+
from PIL import Image
|
14 |
+
from pathlib import Path
|
15 |
+
from transformers import AutoProcessor, AutoModel
|
16 |
+
from tempfile import NamedTemporaryFile
|
17 |
+
from tqdm import tqdm
|
18 |
+
from sklearn.metrics import accuracy_score, classification_report
|
19 |
+
from utils.utils import create_embeddings_csv_io, process_dataset_in_mem, multiclass_merge_and_filter_bouts, generate_embeddings_stream_io
|
20 |
+
|
21 |
+
# --server.maxUploadSize 3000
|
22 |
+
|
23 |
+
def get_io_reader(uploaded_file):
|
24 |
+
if uploaded_file.name[-3:]=='seq':
|
25 |
+
with NamedTemporaryFile(suffix="seq", delete=False) as temp:
|
26 |
+
temp.write(uploaded_file.getvalue())
|
27 |
+
sr = seqIo_reader(temp.name)
|
28 |
+
else:
|
29 |
+
with NamedTemporaryFile(suffix="mp4", delete=False) as temp:
|
30 |
+
temp.write(uploaded_file.getvalue())
|
31 |
+
sr = mp4Io_reader(temp.name)
|
32 |
+
return sr
|
33 |
+
|
34 |
+
def get_unique_labels(label_list: list[str]):
|
35 |
+
label_set = set()
|
36 |
+
for label in label_list:
|
37 |
+
individual_labels = label.split('||')
|
38 |
+
for individual_label in individual_labels:
|
39 |
+
label_set.add(individual_label)
|
40 |
+
return list(label_set)
|
41 |
+
|
42 |
+
def get_smoothed_predictions(svm_model, test_embeds):
|
43 |
+
test_pred = svm_model.predict(test_embeds)
|
44 |
+
test_prob = svm_model.predict_proba(test_embeds)
|
45 |
+
|
46 |
+
bout_threshold = 5
|
47 |
+
proximity_threshold = 2
|
48 |
+
|
49 |
+
predictions = multiclass_merge_and_filter_bouts(test_pred, bout_threshold, proximity_threshold)
|
50 |
+
return predictions
|
51 |
+
|
52 |
+
if "embeddings_df" not in st.session_state:
|
53 |
+
st.session_state.embeddings_df = None
|
54 |
+
|
55 |
+
if "smoothed_predictions" not in st.session_state:
|
56 |
+
st.session_state.smoothed_predictions = None
|
57 |
+
st.session_state.test_labels = []
|
58 |
+
|
59 |
+
st.title('batik: frame classifier')
|
60 |
+
|
61 |
+
st.text("Upload files to apply trained classifier on.")
|
62 |
+
with st.form('embedding_generation_settings'):
|
63 |
+
seq_file = st.file_uploader("Choose a video file", type=['seq', 'mp4'], accept_multiple_files=False)
|
64 |
+
annot_files = st.file_uploader("Choose an annotation File", type=['annot','csv'], accept_multiple_files=True)
|
65 |
+
downsample_rate = st.number_input('Downsample Rate',value=4)
|
66 |
+
submit_embed_settings = st.form_submit_button('Create Embeddings', type='secondary')
|
67 |
+
|
68 |
+
st.markdown("**(Optional)** Upload embeddings if not generating above.")
|
69 |
+
embeddings_csv = st.file_uploader("Choose a .csv File", type=['csv'])
|
70 |
+
|
71 |
+
if submit_embed_settings and seq_file is not None and annot_files is not None:
|
72 |
+
video_embeddings, video_frames = generate_embeddings_stream_io([seq_file],
|
73 |
+
"SLIP",
|
74 |
+
downsample_rate,
|
75 |
+
False)
|
76 |
+
|
77 |
+
fnames = [seq_file.name]
|
78 |
+
embeddings_df = create_embeddings_csv_io(out="file",
|
79 |
+
fnames=fnames,
|
80 |
+
embeddings=video_embeddings,
|
81 |
+
frames=video_frames,
|
82 |
+
annotations=[annot_files],
|
83 |
+
test_fnames=None,
|
84 |
+
views=None,
|
85 |
+
conditions=None,
|
86 |
+
downsample_rate=downsample_rate)
|
87 |
+
st.session_state.embeddings_df = embeddings_df
|
88 |
+
|
89 |
+
elif embeddings_csv is not None:
|
90 |
+
embeddings_df = pd.read_csv(embeddings_csv)
|
91 |
+
st.session_state.embeddings_df = embeddings_df
|
92 |
+
else:
|
93 |
+
st.text('Please upload file(s).')
|
94 |
+
|
95 |
+
st.divider()
|
96 |
+
st.markdown("Upload classifier model.")
|
97 |
+
pickled_file = st.file_uploader("Choose a .pkl File", type=['pkl'])
|
98 |
+
|
99 |
+
if pickled_file is not None:
|
100 |
+
with NamedTemporaryFile(suffix='pkl', delete=False) as temp:
|
101 |
+
temp.write(pickled_file.getvalue())
|
102 |
+
with open(temp.name, 'rb') as pickled_model:
|
103 |
+
svm_clf = pickle.load(pickled_model)
|
104 |
+
else:
|
105 |
+
svm_clf = None
|
106 |
+
|
107 |
+
st.divider()
|
108 |
+
if st.session_state.embeddings_df is not None and svm_clf is not None:
|
109 |
+
st.subheader("specify dataset labels")
|
110 |
+
label_list = st.session_state.embeddings_df['Label'].to_list()
|
111 |
+
unique_label_list = get_unique_labels(label_list)
|
112 |
+
|
113 |
+
with st.form('apply_model_settings'):
|
114 |
+
st.text("Select label(s):")
|
115 |
+
specified_classes = st.multiselect("Label(s) included:", options=unique_label_list)
|
116 |
+
|
117 |
+
|
118 |
+
apply_model = st.form_submit_button("Apply Model")
|
119 |
+
|
120 |
+
if apply_model:
|
121 |
+
kwargs = {'embeddings_df' : st.session_state.embeddings_df,
|
122 |
+
'specified_classes' : specified_classes,
|
123 |
+
'classes_to_remove' : None,
|
124 |
+
'max_class_size' : None,
|
125 |
+
'animal_state' : None,
|
126 |
+
'view' : None,
|
127 |
+
'shuffle_data' : False,
|
128 |
+
'test_videos' : list(set(st.session_state.embeddings_df['Source'].to_list()))}
|
129 |
+
train_embeds, train_labels, train_images, test_embeds, test_labels, test_images =\
|
130 |
+
process_dataset_in_mem(**kwargs)
|
131 |
+
|
132 |
+
# get predictions from embeddings
|
133 |
+
with st.spinner("Model application in progress..."):
|
134 |
+
smoothed_predictions = get_smoothed_predictions(svm_clf, test_embeds)
|
135 |
+
|
136 |
+
# save variables to state
|
137 |
+
st.session_state.smoothed_predictions = smoothed_predictions
|
138 |
+
st.session_state.test_labels = test_labels
|
139 |
+
|
140 |
+
if st.session_state.smoothed_predictions is not None:
|
141 |
+
# Convert labels to numerical values
|
142 |
+
label_to_appear_first = 'other'
|
143 |
+
unique_labels = set(st.session_state.test_labels)
|
144 |
+
unique_labels.discard(label_to_appear_first)
|
145 |
+
|
146 |
+
label_to_index = {label_to_appear_first: 0}
|
147 |
+
|
148 |
+
label_to_index.update({label: idx + 1 for idx, label in enumerate(unique_labels)})
|
149 |
+
index_to_label = {idx: label for label, idx in label_to_index.items()}
|
150 |
+
|
151 |
+
numerical_labels_test = np.array([label_to_index[label] for label in st.session_state.test_labels])
|
152 |
+
print("Label Valence: ", label_to_index)
|
153 |
+
|
154 |
+
#smoothed_predictions test labels
|
155 |
+
if len(st.session_state.smoothed_predictions) > 0:
|
156 |
+
test_accuracy = accuracy_score(numerical_labels_test, st.session_state.smoothed_predictions)
|
157 |
+
else:
|
158 |
+
test_accuracy = 0 # If no predictions meet the threshold, set accuracy to 0
|
159 |
+
|
160 |
+
# test_accuracy = accuracy_score(numerical_labels_test, test_pred)
|
161 |
+
report = classification_report(numerical_labels_test,
|
162 |
+
st.session_state.smoothed_predictions,
|
163 |
+
target_names=[index_to_label[idx] for idx in range(len(index_to_label))],
|
164 |
+
output_dict=True)
|
165 |
+
report_df = pd.DataFrame(report).transpose()
|
166 |
+
|
167 |
+
st.text(f"Eval Accuracy: {test_accuracy}")
|
168 |
+
st.subheader("Classification Report:")
|
169 |
+
st.dataframe(report_df)
|
170 |
+
|
171 |
+
# create figure (behavior raster)
|
172 |
+
fig, ax = plt.subplots()
|
173 |
+
raster = ax.imshow(st.session_state.smoothed_predictions.reshape((1,st.session_state.smoothed_predictions.size)),
|
174 |
+
aspect='auto',
|
175 |
+
interpolation='nearest',
|
176 |
+
cmap=ListedColormap(['white'] + [(random(),random(),random()) for i in range(len(index_to_label) - 1)]))
|
177 |
+
ax.set_yticklabels([])
|
178 |
+
ax.set_xlabel('frames')
|
179 |
+
cbar = fig.colorbar(raster)
|
180 |
+
labels = [label_to_appear_first] + list(unique_labels)
|
181 |
+
spacing = (len(labels) - 1)/len(labels)
|
182 |
+
start = spacing/2
|
183 |
+
ticks = [start] + [start + spacing*i for i in range(1,len(labels))]
|
184 |
+
cbar.set_ticks(ticks=ticks, labels = labels)
|
185 |
+
|
186 |
+
st.pyplot(fig)
|
187 |
+
|
188 |
+
# save generated annotations
|
189 |
+
annotations = [labels[x] for x in st.session_state.smoothed_predictions]
|
190 |
+
annotations_df = pd.DataFrame(annotations, columns=['label'])
|
191 |
+
csv = annotations_df.to_csv(header=False).encode("utf-8")
|
192 |
+
output_file_name = st.text_input("Output File Name:","output")
|
193 |
+
st.download_button("Download annotations as .csv",
|
194 |
+
data=csv,
|
195 |
+
file_name=f"{output_file_name}.csv")
|
explore.py
ADDED
@@ -0,0 +1,337 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import streamlit as st
|
2 |
+
import plotly.express as px
|
3 |
+
import numpy as np
|
4 |
+
import pandas as pd
|
5 |
+
import torch
|
6 |
+
from utils.mp4Io import mp4Io_reader
|
7 |
+
from utils.seqIo import seqIo_reader
|
8 |
+
import pandas as pd
|
9 |
+
from PIL import Image
|
10 |
+
from pathlib import Path
|
11 |
+
from transformers import AutoProcessor, AutoModel
|
12 |
+
from tempfile import NamedTemporaryFile
|
13 |
+
from tqdm import tqdm
|
14 |
+
from utils.utils import create_embeddings_csv_io, process_dataset_in_mem, generate_embeddings_stream_io
|
15 |
+
from get_llava_response import get_llava_response, load_llava_checkpoint_hf
|
16 |
+
from sklearn.manifold import TSNE
|
17 |
+
from openai import OpenAI
|
18 |
+
import cv2
|
19 |
+
import base64
|
20 |
+
from hdbscan import HDBSCAN, all_points_membership_vectors
|
21 |
+
import random
|
22 |
+
|
23 |
+
# --server.maxUploadSize 3000
|
24 |
+
REPO_NAME = 'ncoria/llava-lora-vicuna-clip-5-epochs-merge'
|
25 |
+
|
26 |
+
def load_llava_model(hf_token):
|
27 |
+
return load_llava_checkpoint_hf(REPO_NAME, hf_token)
|
28 |
+
|
29 |
+
def get_unique_labels(label_list: list[str]):
|
30 |
+
label_set = set()
|
31 |
+
for label in label_list:
|
32 |
+
individual_labels = label.split('||')
|
33 |
+
for individual_label in individual_labels:
|
34 |
+
label_set.add(individual_label)
|
35 |
+
return list(label_set)
|
36 |
+
|
37 |
+
SYSTEM_PROMPT = """You are a researcher studying mice interactions from videos of the inside of a resident
|
38 |
+
intruder box where there is either just the resident mouse (the black one) or the resident and the intruder mouse (the white one).
|
39 |
+
Your job is to answer questions about the behavior of the mice in the image given the context that each image is a frame of a continuous video.
|
40 |
+
Thus, you should use the visual information about the mice in the image to try to provide a detailed behavioral description of the image."""
|
41 |
+
|
42 |
+
@st.cache_resource
|
43 |
+
def get_io_reader(uploaded_file):
|
44 |
+
if uploaded_file.name[-3:]=='seq':
|
45 |
+
with NamedTemporaryFile(suffix="seq", delete=False) as temp:
|
46 |
+
temp.write(uploaded_file.getvalue())
|
47 |
+
sr = seqIo_reader(temp.name)
|
48 |
+
else:
|
49 |
+
with NamedTemporaryFile(suffix="mp4", delete=False) as temp:
|
50 |
+
temp.write(uploaded_file.getvalue())
|
51 |
+
sr = mp4Io_reader(temp.name)
|
52 |
+
return sr
|
53 |
+
|
54 |
+
def get_image(sr, frame_no: int):
|
55 |
+
image, _ = sr.getFrame(frame_no)
|
56 |
+
return image
|
57 |
+
|
58 |
+
@st.cache_data
|
59 |
+
def get_2d_embedding(embeddings: pd.DataFrame):
|
60 |
+
tsne = TSNE(n_jobs=4, n_components=2, random_state=42, perplexity=50)
|
61 |
+
embedding_2d = tsne.fit_transform(np.array(embeddings))
|
62 |
+
return embedding_2d
|
63 |
+
|
64 |
+
HDBSCAN_PARAMS = {
|
65 |
+
'min_samples': 1
|
66 |
+
}
|
67 |
+
|
68 |
+
@st.cache_data
|
69 |
+
def hdbscan_classification(umap_embeddings, embeddings_2d, cluster_range):
|
70 |
+
max_num_clusters = -np.infty
|
71 |
+
num_clusters = []
|
72 |
+
min_cluster_size = np.linspace(cluster_range[0], cluster_range[1], 4)
|
73 |
+
for min_c in min_cluster_size:
|
74 |
+
learned_hierarchy = HDBSCAN(
|
75 |
+
prediction_data=True, min_cluster_size=int(round(min_c * 0.01 *umap_embeddings.shape[0])),
|
76 |
+
cluster_selection_method='leaf' ,
|
77 |
+
**HDBSCAN_PARAMS).fit(umap_embeddings)
|
78 |
+
num_clusters.append(len(np.unique(learned_hierarchy.labels_)))
|
79 |
+
if num_clusters[-1] > max_num_clusters:
|
80 |
+
max_num_clusters = num_clusters[-1]
|
81 |
+
retained_hierarchy = learned_hierarchy
|
82 |
+
assignments = retained_hierarchy.labels_
|
83 |
+
assign_prob = all_points_membership_vectors(retained_hierarchy)
|
84 |
+
soft_assignments = np.argmax(assign_prob, axis=1)
|
85 |
+
retained_hierarchy.fit(embeddings_2d)
|
86 |
+
return retained_hierarchy, assignments, assign_prob, soft_assignments
|
87 |
+
|
88 |
+
def upload_image(frame: np.ndarray):
|
89 |
+
"""returns the file ID."""
|
90 |
+
_, encoded_image = cv2.imencode('.png', frame)
|
91 |
+
return base64.b64encode(encoded_image.tobytes()).decode('utf-8')
|
92 |
+
|
93 |
+
def ask_question_with_image_gpt(file_id, system_prompt, question, api_key):
|
94 |
+
"""Asks a question about the uploaded image."""
|
95 |
+
client = OpenAI(api_key=api_key)
|
96 |
+
|
97 |
+
if file_id != None:
|
98 |
+
response = client.chat.completions.create(
|
99 |
+
model="gpt-4o",
|
100 |
+
messages=[
|
101 |
+
{"role": "system", "content": system_prompt},
|
102 |
+
{"role": "user", "content": [
|
103 |
+
{"type": "text", "text": question},
|
104 |
+
{"type": "image_url", "image_url": {"url": f"data:image/jpg:base64, {file_id}"}}]
|
105 |
+
}
|
106 |
+
]
|
107 |
+
)
|
108 |
+
else:
|
109 |
+
response = client.chat.completions.create(
|
110 |
+
model="gpt-4o",
|
111 |
+
messages=[
|
112 |
+
{"role": "system", "content": system_prompt},
|
113 |
+
{"role": "user", "content": question}
|
114 |
+
]
|
115 |
+
)
|
116 |
+
return response.choices[0].message.content
|
117 |
+
|
118 |
+
def ask_question_with_image_llava(image, system_prompt, question,
|
119 |
+
tokenizer, model, image_processor):
|
120 |
+
outputs = get_llava_response([question],
|
121 |
+
[image],
|
122 |
+
system_prompt,
|
123 |
+
tokenizer,
|
124 |
+
model,
|
125 |
+
image_processor,
|
126 |
+
REPO_NAME,
|
127 |
+
stream_output=False)
|
128 |
+
return outputs[0]
|
129 |
+
|
130 |
+
def ask_summary_question(image_array, label_array, api_key):
|
131 |
+
# load llava model
|
132 |
+
tokenizer, model, image_processor = load_llava_model(hf_token)
|
133 |
+
|
134 |
+
# global variable
|
135 |
+
system_prompt = SYSTEM_PROMPT
|
136 |
+
|
137 |
+
# collect responses
|
138 |
+
responses = []
|
139 |
+
|
140 |
+
# create progress bar
|
141 |
+
j = 0
|
142 |
+
pbar_text = lambda j: f'Creating llava response {j}/{len(label_array)}.'
|
143 |
+
pbar = st.progress(0, text=pbar_text(0))
|
144 |
+
|
145 |
+
for i, image in enumerate(image_array):
|
146 |
+
label = label_array[i]
|
147 |
+
question = f"The frame is annotated by a human observer with the label: {label}. Give evidence for this label using the posture of the mice and their current behavior. "
|
148 |
+
question += "Also, designate a behavioral subtype of the given label that describes the current social interaction based on what you see about the posture of the mice and "\
|
149 |
+
"how they are positioned with respect to each other. Usually, the body parts (i.e., tail, genitals, face, body, ears, paws)"\
|
150 |
+
"of the mice that are closest to each other will give some clue. Please limit behavioral subtype to a 1-4 word phrase. limit your response to 4 sentences."
|
151 |
+
response = ask_question_with_image_llava(image, system_prompt, question,
|
152 |
+
tokenizer, model, image_processor)
|
153 |
+
responses.append(response)
|
154 |
+
# update progress bar
|
155 |
+
j += 1
|
156 |
+
pbar.progress(j/len(label_array), pbar_text(j))
|
157 |
+
|
158 |
+
system_prompt_summarize = "You are a researcher studying mice interactions from videos of the inside of a resident "\
|
159 |
+
"intruder box where there is either just the resident mouse (the black one) or the resident and the intruder mouse (the white one). "\
|
160 |
+
"You will be given a question about a list of descriptions from frames of these videos. "\
|
161 |
+
"Your job is to answer the question by focusing on the behaviors of the mice and their postures "\
|
162 |
+
"as well as any other aspects of the descriptions that may be relevant to the class label associated with them"
|
163 |
+
user_prompt_summarize = "Here are several descriptions of individual frames from a mouse behavior video. Please summarize these descriptions and provide a suggestion for a "\
|
164 |
+
"behavior label which captures what is described in the descriptions: \n\n"
|
165 |
+
user_prompt_summarize = user_prompt_summarize + '\n'.join(responses)
|
166 |
+
summary_response = ask_question_with_image_gpt(None, system_prompt_summarize, user_prompt_summarize, api_key)
|
167 |
+
return summary_response
|
168 |
+
|
169 |
+
if "embeddings_df" not in st.session_state:
|
170 |
+
st.session_state.embeddings_df = None
|
171 |
+
|
172 |
+
st.title('batik: frame classifier')
|
173 |
+
|
174 |
+
api_key = st.text_input("OpenAI API Key:","")
|
175 |
+
hf_token = st.text_input("HuggingFace Token:","")
|
176 |
+
st.subheader("generate or import embeddings")
|
177 |
+
|
178 |
+
st.text("Upload files to generate embeddings.")
|
179 |
+
with st.form('embedding_generation_settings'):
|
180 |
+
seq_file = st.file_uploader("Choose a video file", type=['seq', 'mp4'], accept_multiple_files=True)
|
181 |
+
annot_files = st.file_uploader("Choose an annotation File", type=['annot','csv'], accept_multiple_files=True)
|
182 |
+
downsample_rate = st.number_input('Downsample Rate',value=4)
|
183 |
+
submit_embed_settings = st.form_submit_button('Create Embeddings', type='secondary')
|
184 |
+
|
185 |
+
st.markdown("**(Optional)** Upload embeddings.")
|
186 |
+
embeddings_csv = st.file_uploader("Choose a .csv File", type=['csv'])
|
187 |
+
|
188 |
+
if submit_embed_settings and seq_file is not None and annot_files is not None:
|
189 |
+
video_embeddings, video_frames = generate_embeddings_stream_io([seq_file],
|
190 |
+
"SLIP",
|
191 |
+
downsample_rate,
|
192 |
+
False)
|
193 |
+
|
194 |
+
fnames = [seq_file.name]
|
195 |
+
embeddings_df = create_embeddings_csv_io(out="file",
|
196 |
+
fnames=fnames,
|
197 |
+
embeddings=video_embeddings,
|
198 |
+
frames=video_frames,
|
199 |
+
annotations=[annot_files],
|
200 |
+
test_fnames=None,
|
201 |
+
views=None,
|
202 |
+
conditions=None,
|
203 |
+
downsample_rate=downsample_rate)
|
204 |
+
st.session_state.embeddings_df = embeddings_df
|
205 |
+
elif embeddings_csv is not None:
|
206 |
+
embeddings_df = pd.read_csv(embeddings_csv)
|
207 |
+
st.session_state.embeddings_df = embeddings_df
|
208 |
+
else:
|
209 |
+
st.text('Please upload file(s).')
|
210 |
+
|
211 |
+
st.divider()
|
212 |
+
st.subheader("provide video file if not yet already provided")
|
213 |
+
|
214 |
+
uploaded_file = st.file_uploader("Choose a video file", type=['seq', 'mp4'], accept_multiple_files=True)
|
215 |
+
|
216 |
+
st.divider()
|
217 |
+
if st.session_state.embeddings_df is not None and (uploaded_file is not None or seq_file is not None):
|
218 |
+
if seq_file is not None:
|
219 |
+
uploaded_file = seq_file
|
220 |
+
io_reader = get_io_reader(uploaded_file)
|
221 |
+
print("CONVERTED SEQ")
|
222 |
+
label_list = st.session_state.embeddings_df['Label'].to_list()
|
223 |
+
unique_label_list = get_unique_labels(label_list)
|
224 |
+
print(f"unique_labels: {unique_label_list}")
|
225 |
+
#unique_label_list = ['check_genital', 'wiggle', 'lordose', 'stay', 'turn', 'top_up', 'dart', 'sniff', 'approach', 'into_male_cage']
|
226 |
+
#unique_label_list = ['into_male_cage', 'intromission', 'male_sniff', 'mount']
|
227 |
+
kwargs = {'embeddings_df' : st.session_state.embeddings_df,
|
228 |
+
'specified_classes' : unique_label_list,
|
229 |
+
'classes_to_remove' : None,
|
230 |
+
'max_class_size' : None,
|
231 |
+
'animal_state' : None,
|
232 |
+
'view' : None,
|
233 |
+
'shuffle_data' : False,
|
234 |
+
'test_videos' : None}
|
235 |
+
train_embeds, train_labels, train_images, _, _, _ = process_dataset_in_mem(**kwargs)
|
236 |
+
print("PROCESSED DATASET")
|
237 |
+
if "Images" in st.session_state.embeddings_df.keys():
|
238 |
+
train_images = [i for i in range(len(train_images))]
|
239 |
+
embedding_2d = get_2d_embedding(train_embeds)
|
240 |
+
else:
|
241 |
+
st.text('Please generate embeddings and provide video file.')
|
242 |
+
print("GOT 2D EMBEDS")
|
243 |
+
|
244 |
+
if uploaded_file is not None and st.session_state.embeddings_df is not None:
|
245 |
+
st.subheader("t-SNE Projection")
|
246 |
+
option = st.selectbox(
|
247 |
+
"Select Color Option",
|
248 |
+
("By Label", "By Time", "By Cluster")
|
249 |
+
)
|
250 |
+
if embedding_2d is not None:
|
251 |
+
if option is not None:
|
252 |
+
if option == "By Label":
|
253 |
+
color = 'label'
|
254 |
+
elif option == "By Time":
|
255 |
+
color = 'frame_no'
|
256 |
+
else:
|
257 |
+
color = 'cluster_label'
|
258 |
+
|
259 |
+
if option in ["By Label", "By Time"]:
|
260 |
+
edf = pd.DataFrame(embedding_2d,columns=['tsne_dim_1', 'tsne_dim_2'])
|
261 |
+
edf.insert(2,'frame_no',np.array([int(x) for x in train_images]))
|
262 |
+
edf.insert(3, 'label', train_labels)
|
263 |
+
fig = px.scatter(
|
264 |
+
edf,
|
265 |
+
x="tsne_dim_1",
|
266 |
+
y="tsne_dim_2",
|
267 |
+
color=color,
|
268 |
+
hover_data=["frame_no"],
|
269 |
+
color_discrete_sequence=px.colors.qualitative.Dark24
|
270 |
+
)
|
271 |
+
else:
|
272 |
+
r, _, _, _ = hdbscan_classification(train_embeds, embedding_2d, [4, 6])
|
273 |
+
edf = pd.DataFrame(embedding_2d,columns=['tsne_dim_1', 'tsne_dim_2'])
|
274 |
+
edf.insert(2,'frame_no',np.array([int(x) for x in train_images]))
|
275 |
+
edf.insert(3, 'label', train_labels)
|
276 |
+
edf.insert(4, 'cluster_label', [str(c_id) for c_id in r.labels_.tolist()])
|
277 |
+
fig = px.scatter(
|
278 |
+
edf,
|
279 |
+
x="tsne_dim_1",
|
280 |
+
y="tsne_dim_2",
|
281 |
+
color=color,
|
282 |
+
hover_data=["frame_no"],
|
283 |
+
color_discrete_sequence=px.colors.qualitative.Dark24
|
284 |
+
)
|
285 |
+
|
286 |
+
event = st.plotly_chart(fig, key="df", on_select="rerun")
|
287 |
+
else:
|
288 |
+
st.text("No Color Option Selected")
|
289 |
+
else:
|
290 |
+
st.text('No Embeddings Loaded')
|
291 |
+
|
292 |
+
event_dict = event.selection
|
293 |
+
|
294 |
+
if event_dict is not None:
|
295 |
+
custom_data = []
|
296 |
+
for point in event_dict['points']:
|
297 |
+
data = point["customdata"][0]
|
298 |
+
custom_data.append(int(data))
|
299 |
+
|
300 |
+
if len(custom_data) > 10:
|
301 |
+
custom_data = random.sample(custom_data, 10)
|
302 |
+
if len(custom_data) > 1:
|
303 |
+
col_1, col_2 = st.columns(2)
|
304 |
+
with col_1:
|
305 |
+
for frame_no in custom_data[::2]:
|
306 |
+
st.image(get_image(io_reader, frame_no))
|
307 |
+
st.caption(f"Frame {frame_no}, {train_labels[frame_no]}")
|
308 |
+
with col_2:
|
309 |
+
for frame_no in custom_data[1::2]:
|
310 |
+
st.image(get_image(io_reader, frame_no))
|
311 |
+
st.caption(f"Frame {frame_no}, {train_labels[frame_no]}")
|
312 |
+
elif len(custom_data) == 1:
|
313 |
+
frame_no = custom_data[0]
|
314 |
+
st.image(get_image(io_reader, frame_no))
|
315 |
+
st.caption(f"Frame {frame_no}, {train_labels[frame_no]}")
|
316 |
+
else:
|
317 |
+
st.text('No Points Selected')
|
318 |
+
|
319 |
+
if len(custom_data) == 1:
|
320 |
+
frame_no = custom_data[0]
|
321 |
+
image = get_image(io_reader, frame_no)
|
322 |
+
system_prompt = SYSTEM_PROMPT
|
323 |
+
label = train_labels[frame_no]
|
324 |
+
question = f"The frame is annotated by a human observer with the label: {label}. Give evidence for this label using the posture of the mice and their current behavior. "\
|
325 |
+
"Also, designate a behavioral subtype of the given label that describes the current social interaction based on what you see about the posture of the mice and "\
|
326 |
+
"how they are positioned with respect to each other. Usually, the body parts (i.e., tail, genitals, face, body, ears, paws)" \
|
327 |
+
"of the mice that are closest to each other will give some clue. Please limit behavioral subtype to a 1-4 word phrase. limit your response to 4 sentences."
|
328 |
+
tokenizer, model, image_processor = load_llava_model(hf_token)
|
329 |
+
response = ask_question_with_image_llava(image, system_prompt, question,
|
330 |
+
tokenizer, model, image_processor)
|
331 |
+
st.markdown(response)
|
332 |
+
|
333 |
+
elif len(custom_data) > 1:
|
334 |
+
image_array = [get_image(io_reader, f_no) for f_no in custom_data]
|
335 |
+
label_array = [train_labels[f_no] for f_no in custom_data]
|
336 |
+
response = ask_summary_question(image_array, label_array, api_key)
|
337 |
+
st.markdown(response)
|
generate_embeddings.py
ADDED
@@ -0,0 +1,131 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import streamlit as st
|
2 |
+
import pandas as pd
|
3 |
+
from utils.utils import create_embeddings_csv_io, create_annot_fname_dict_io, generate_embeddings_stream_io
|
4 |
+
|
5 |
+
if "video_embeddings" not in st.session_state:
|
6 |
+
st.session_state.video_embeddings = None
|
7 |
+
st.session_state.video_frames = None
|
8 |
+
st.session_state.fnames = []
|
9 |
+
|
10 |
+
st.title('batik: embedding generator')
|
11 |
+
uploaded_files = st.file_uploader("Choose a video file", type=['seq', 'mp4'], accept_multiple_files=True)
|
12 |
+
with st.form('initial_settings'):
|
13 |
+
st.header('Embedding Generation Options')
|
14 |
+
model_select = st.selectbox('Select Model', ['SLIP', 'CLIP'])
|
15 |
+
downsample_rate = st.number_input('Downsample Rate',value=4)
|
16 |
+
save_csv = st.toggle('Save Individual Results', value=False)
|
17 |
+
submit_initial_settings = st.form_submit_button('Create Embeddings', type='secondary')
|
18 |
+
|
19 |
+
if submit_initial_settings and uploaded_files is not None and len(uploaded_files) > 0:
|
20 |
+
video_embeddings, video_frames = generate_embeddings_stream_io(uploaded_files,
|
21 |
+
model_select,
|
22 |
+
downsample_rate,
|
23 |
+
save_csv)
|
24 |
+
fnames = [vid_file.name for vid_file in uploaded_files]
|
25 |
+
st.session_state.video_embeddings = video_embeddings
|
26 |
+
st.session_state.video_frames = video_frames
|
27 |
+
st.session_state.fnames = fnames
|
28 |
+
|
29 |
+
if st.session_state.video_embeddings is not None:
|
30 |
+
st.header('CSV Configuration Options')
|
31 |
+
st.markdown('If using `.annot` files and multiple files should be grouped together, '\
|
32 |
+
'please ensure that they share a common name and end with a number describing '\
|
33 |
+
'the order of the files. For example:\n\n'\
|
34 |
+
'`mouse_224_file_1.annot`, `mouse_224_file_2.annot`.')
|
35 |
+
annot_files = st.file_uploader("Upload all annotation files", type=['.annot','.csv'], accept_multiple_files=True)
|
36 |
+
|
37 |
+
annot_options = []
|
38 |
+
if annot_files is not None and len(annot_files) > 0:
|
39 |
+
annot_fnames = [annot_file.name for annot_file in annot_files]
|
40 |
+
annot_fname_dict = create_annot_fname_dict_io(annot_fnames=annot_fnames,
|
41 |
+
annot_files=annot_files)
|
42 |
+
annot_options = [str(key) for key in annot_fname_dict.keys()]
|
43 |
+
|
44 |
+
if len(annot_options) > 0:
|
45 |
+
with st.form('csv_settings'):
|
46 |
+
csv_setting_def = pd.DataFrame(
|
47 |
+
{
|
48 |
+
"File Name" : st.session_state.fnames,
|
49 |
+
"Annotations" : [
|
50 |
+
"Upload File" for _ in st.session_state.fnames
|
51 |
+
],
|
52 |
+
"Test" : [
|
53 |
+
False for _ in st.session_state.fnames
|
54 |
+
],
|
55 |
+
"View" : [
|
56 |
+
"Top" for _ in st.session_state.fnames
|
57 |
+
],
|
58 |
+
"Condition" : [
|
59 |
+
"None" for _ in st.session_state.fnames
|
60 |
+
]
|
61 |
+
|
62 |
+
}
|
63 |
+
)
|
64 |
+
|
65 |
+
csv_settings = st.data_editor(
|
66 |
+
csv_setting_def,
|
67 |
+
column_config={
|
68 |
+
"Annotations" : st.column_config.SelectboxColumn(
|
69 |
+
"Annotations",
|
70 |
+
help="The annotation file(s) to use for the given video file.",
|
71 |
+
width="medium",
|
72 |
+
options=annot_options,
|
73 |
+
required=True
|
74 |
+
),
|
75 |
+
"Test" : st.column_config.CheckboxColumn(
|
76 |
+
"Test",
|
77 |
+
help="Designate file(s) to use as the test set.",
|
78 |
+
default=False,
|
79 |
+
required=True
|
80 |
+
),
|
81 |
+
"View" : st.column_config.SelectboxColumn(
|
82 |
+
"View",
|
83 |
+
help="The view used within the video (either Top or Front).",
|
84 |
+
options=["Top", "Front"],
|
85 |
+
required=True
|
86 |
+
),
|
87 |
+
"Condition" : st.column_config.TextColumn(
|
88 |
+
"Condition",
|
89 |
+
help="A condition the video has (i.e. Control).",
|
90 |
+
default="None",
|
91 |
+
max_chars=30,
|
92 |
+
validate=r"[a-z]+$",
|
93 |
+
)
|
94 |
+
},
|
95 |
+
hide_index=True
|
96 |
+
)
|
97 |
+
save_csv_bttn = st.form_submit_button("Create CSV")
|
98 |
+
|
99 |
+
if save_csv_bttn and csv_settings is not None:
|
100 |
+
annot_chosen_options = csv_settings['Annotations'].tolist()
|
101 |
+
annot_option = [annot_fname_dict[key] for key in annot_chosen_options]
|
102 |
+
test_chosen_option = csv_settings['Test'].tolist()
|
103 |
+
test_option = [st.session_state.fnames[i] for i, is_test in enumerate(test_chosen_option) if is_test]
|
104 |
+
view_option = csv_settings['View'].tolist()
|
105 |
+
condition_option = csv_settings['Condition'].tolist()
|
106 |
+
|
107 |
+
out_name = st.text_input("Embeddings Outpit File Name", "out.csv")
|
108 |
+
try:
|
109 |
+
df = create_embeddings_csv_io(out=out_name,
|
110 |
+
fnames=st.session_state.fnames,
|
111 |
+
embeddings=st.session_state.video_embeddings,
|
112 |
+
frames=st.session_state.video_frames,
|
113 |
+
annotations=annot_option,
|
114 |
+
test_fnames=test_option,
|
115 |
+
views=view_option,
|
116 |
+
conditions=condition_option,
|
117 |
+
downsample_rate=downsample_rate)
|
118 |
+
st.success('Created Embeddings File!', icon="✅")
|
119 |
+
st.download_button(
|
120 |
+
label="Download CSV",
|
121 |
+
data=df.to_csv().encode("utf-8"),
|
122 |
+
file_name=out_name,
|
123 |
+
mime="text/csv"
|
124 |
+
)
|
125 |
+
except:
|
126 |
+
st.error('Something went wrong.')
|
127 |
+
|
128 |
+
else:
|
129 |
+
st.text('Please Upload Files')
|
130 |
+
else:
|
131 |
+
st.text('Please Upload Files')
|
get_llava_response.py
ADDED
@@ -0,0 +1,184 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import argparse
|
2 |
+
import torch
|
3 |
+
from transformers import AutoTokenizer, AutoModelForCausalLM, AutoConfig, BitsAndBytesConfig
|
4 |
+
import numpy as np
|
5 |
+
|
6 |
+
from huggingface_hub import whoami
|
7 |
+
|
8 |
+
import llava
|
9 |
+
from llava.constants import IMAGE_TOKEN_INDEX, DEFAULT_IMAGE_TOKEN, DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN, DEFAULT_IMAGE_PATCH_TOKEN
|
10 |
+
from llava.conversation import conv_templates, SeparatorStyle
|
11 |
+
from llava.model.builder import load_pretrained_model
|
12 |
+
from llava.utils import disable_torch_init
|
13 |
+
from llava.mm_utils import process_images, tokenizer_image_token, get_model_name_from_path
|
14 |
+
|
15 |
+
from PIL import Image
|
16 |
+
|
17 |
+
import requests
|
18 |
+
from PIL import Image
|
19 |
+
from io import BytesIO
|
20 |
+
from transformers import TextStreamer
|
21 |
+
from tqdm import tqdm
|
22 |
+
|
23 |
+
import warnings
|
24 |
+
warnings.filterwarnings('ignore')
|
25 |
+
|
26 |
+
REPO_NAME = 'ncoria/llava-lora-vicuna-clip-5-epochs-merge'
|
27 |
+
|
28 |
+
def load_image(image_file):
|
29 |
+
if image_file.startswith('http://') or image_file.startswith('https://'):
|
30 |
+
response = requests.get(image_file)
|
31 |
+
image = Image.open(BytesIO(response.content)).convert('RGB')
|
32 |
+
else:
|
33 |
+
image = Image.open(image_file).convert('RGB')
|
34 |
+
return image
|
35 |
+
|
36 |
+
def load_llava_checkpoint(model_path: str):
|
37 |
+
model_name = get_model_name_from_path(model_path)
|
38 |
+
return load_pretrained_model(model_path, None, model_name, load_4bit=True, device="cuda")
|
39 |
+
|
40 |
+
def load_llava_checkpoint_hf(model_path, hf_token):
|
41 |
+
user = whoami(token=hf_token)
|
42 |
+
kwargs = {"device_map": "auto"}
|
43 |
+
kwargs['load_in_4bit'] = True
|
44 |
+
kwargs['quantization_config'] = BitsAndBytesConfig(
|
45 |
+
load_in_4bit=True,
|
46 |
+
bnb_4bit_compute_dtype=torch.float16,
|
47 |
+
bnb_4bit_use_double_quant=True,
|
48 |
+
bnb_4bit_quant_type='nf4'
|
49 |
+
)
|
50 |
+
tokenizer = AutoTokenizer.from_pretrained(model_path, use_fast=False)
|
51 |
+
model = AutoModelForCausalLM.from_pretrained(model_path, low_cpu_mem_usage=True, **kwargs)
|
52 |
+
mm_use_im_start_end = getattr(model.config, "mm_use_im_start_end", False)
|
53 |
+
mm_use_im_patch_token = getattr(model.config, "mm_use_im_patch_token", True)
|
54 |
+
if mm_use_im_patch_token:
|
55 |
+
tokenizer.add_tokens([DEFAULT_IMAGE_PATCH_TOKEN], special_tokens=True)
|
56 |
+
if mm_use_im_start_end:
|
57 |
+
tokenizer.add_tokens([DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN], special_tokens=True)
|
58 |
+
model.resize_token_embeddings(len(tokenizer))
|
59 |
+
|
60 |
+
vision_tower = model.get_vision_tower()
|
61 |
+
if not vision_tower.is_loaded:
|
62 |
+
vision_tower.load_model(device_map="auto")
|
63 |
+
image_processor = vision_tower.image_processor
|
64 |
+
return tokenizer, model, image_processor
|
65 |
+
|
66 |
+
def get_llava_response(user_prompts: list[str],
|
67 |
+
images: list,
|
68 |
+
sys_prompt: str,
|
69 |
+
tokenizer,
|
70 |
+
model,
|
71 |
+
image_processor,
|
72 |
+
model_path = REPO_NAME,
|
73 |
+
stream_output = True):
|
74 |
+
"""
|
75 |
+
This function returns the response from the given model. It creates a one turn conversation in which
|
76 |
+
the only content is a system prompt and the given user message applied to each image.
|
77 |
+
|
78 |
+
Parameters:
|
79 |
+
----------
|
80 |
+
user_prompt : str
|
81 |
+
The prompt sent by the user.
|
82 |
+
images : str
|
83 |
+
List of images from file.
|
84 |
+
sys_prompt : str
|
85 |
+
The prompt that sets the tone for the conversation.
|
86 |
+
model_path : str
|
87 |
+
The path to the merged checkpoint or base model.
|
88 |
+
|
89 |
+
Returns:
|
90 |
+
--------
|
91 |
+
"""
|
92 |
+
# set up and load model
|
93 |
+
model_name = get_model_name_from_path(model_path)
|
94 |
+
temperature = 0.2 # default
|
95 |
+
max_new_tokens = 512 # default
|
96 |
+
|
97 |
+
# determine conversation type
|
98 |
+
if "llama-2" in model_name.lower():
|
99 |
+
conv_mode = "llava_llama_2"
|
100 |
+
elif "mistral" in model_name.lower():
|
101 |
+
conv_mode = "mistral_instruct"
|
102 |
+
elif "v1.6-34b" in model_name.lower():
|
103 |
+
conv_mode = "chatml_direct"
|
104 |
+
elif "v1" in model_name.lower():
|
105 |
+
conv_mode = "llava_v1"
|
106 |
+
elif "mpt" in model_name.lower():
|
107 |
+
conv_mode = "mpt"
|
108 |
+
else:
|
109 |
+
conv_mode = "llava_v0"
|
110 |
+
|
111 |
+
# run clean conversation for each image
|
112 |
+
llm_outputs = []
|
113 |
+
for i, img in tqdm(enumerate(images)):
|
114 |
+
# set up clean conversation
|
115 |
+
conv = conv_templates[conv_mode].copy()
|
116 |
+
if "mpt" in model_name.lower():
|
117 |
+
roles = ('user', 'assistant')
|
118 |
+
else:
|
119 |
+
roles = conv.roles
|
120 |
+
|
121 |
+
conv.system = sys_prompt
|
122 |
+
|
123 |
+
# load image
|
124 |
+
# image = load_image("../images/mouse.png") # previous method
|
125 |
+
if isinstance(img, np.ndarray) and len(img.shape) == 2:
|
126 |
+
img = Image.fromarray(img, 'L')
|
127 |
+
|
128 |
+
image = img.convert('RGB')
|
129 |
+
image_size = image.size
|
130 |
+
|
131 |
+
# NOTE: image is simply PIL Image (.convert('RGB')), no need for temp files!
|
132 |
+
|
133 |
+
# Similar operation in model_worker.py
|
134 |
+
image_tensor = process_images([image], image_processor, model.config)
|
135 |
+
if type(image_tensor) is list:
|
136 |
+
image_tensor = [image.to(model.device, dtype=torch.float16) for image in image_tensor]
|
137 |
+
else:
|
138 |
+
image_tensor = image_tensor.to(model.device, dtype=torch.float16)
|
139 |
+
|
140 |
+
# execute conversation
|
141 |
+
inp = user_prompts[i]
|
142 |
+
if image is not None:
|
143 |
+
# first message
|
144 |
+
if model.config.mm_use_im_start_end:
|
145 |
+
inp = DEFAULT_IM_START_TOKEN + DEFAULT_IMAGE_TOKEN + DEFAULT_IM_END_TOKEN + '\n' + inp
|
146 |
+
else:
|
147 |
+
inp = DEFAULT_IMAGE_TOKEN + '\n' + inp
|
148 |
+
image = None
|
149 |
+
conv.append_message(conv.roles[0], inp)
|
150 |
+
conv.append_message(conv.roles[1], None)
|
151 |
+
prompt = conv.get_prompt()
|
152 |
+
input_ids = tokenizer_image_token(prompt,
|
153 |
+
tokenizer,
|
154 |
+
IMAGE_TOKEN_INDEX,
|
155 |
+
return_tensors='pt').unsqueeze(0).to(model.device)
|
156 |
+
stop_str = conv.sep if conv.sep_style != SeparatorStyle.TWO else conv.sep2
|
157 |
+
keywords = [stop_str]
|
158 |
+
if stream_output:
|
159 |
+
streamer = TextStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
|
160 |
+
else:
|
161 |
+
streamer = None
|
162 |
+
|
163 |
+
with torch.inference_mode():
|
164 |
+
output_ids = model.generate(
|
165 |
+
input_ids,
|
166 |
+
images=image_tensor,
|
167 |
+
image_sizes=[image_size],
|
168 |
+
do_sample=True if temperature > 0 else False,
|
169 |
+
temperature=temperature,
|
170 |
+
max_new_tokens=max_new_tokens,
|
171 |
+
streamer=streamer,
|
172 |
+
use_cache=True)
|
173 |
+
|
174 |
+
outputs = tokenizer.decode(output_ids[0]).strip()
|
175 |
+
llm_outputs.append(outputs)
|
176 |
+
return llm_outputs
|
177 |
+
|
178 |
+
|
179 |
+
|
180 |
+
|
181 |
+
|
182 |
+
|
183 |
+
|
184 |
+
|
home.py
ADDED
@@ -0,0 +1,5 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import streamlit as st
|
2 |
+
|
3 |
+
st.title('batik')
|
4 |
+
st.subheader('Abstract')
|
5 |
+
st.markdown('Quantitative analysis of animal behavior represents a burgeoning frontier in neuroscience and ethology. Recent years have witnessed a proliferation of computational methods aimed at identifying behavioral subtypes, or "syllables," from video data. However, while significant advances have been made in behavior segmentation, comparatively few approaches address the interpretation of these behavior syllables, leaving researchers to spend considerable time curating and interpreting the characteristics of the behavioral subtype. Furthermore, most current techniques rely heavily on pose estimation—a prerequisite that, while useful, can introduce limitations concerning generalization in behavioral classification and discovery. Here, we introduce Batik, a system leveraging pre-trained and fine-tuned multimodal transformers to perform end-to-end behavior analysis directly from raw video. Batik excels at supervised behavior annotation, utilizing lightweight models trained on the transformer-extracted feature space to achieve state-of-the-art performance. By integrating a pre-trained vision transformer with a custom fine-tuned language model, Batik not only discovers behavior syllables but also provides expert-level interpretations of mouse behavior, directly from visual data. This comprehensive platform empowers researchers with automated behavior discovery and interpretation, significantly reducing the time burden on experimentalists. Coupled with an intuitive user interface, Batik offers a transformative tool for the next generation of behavioral analysis, showcasing the potential of what is possible with transformer-based language models for behavior.')
|
pyproject.toml
ADDED
@@ -0,0 +1,43 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
[build-system]
|
2 |
+
requires = ["setuptools>=61.0"]
|
3 |
+
build-backend = "setuptools.build_meta"
|
4 |
+
|
5 |
+
[project]
|
6 |
+
name = "llava"
|
7 |
+
version = "1.2.2.post1"
|
8 |
+
description = "Towards GPT-4 like large language and visual assistant."
|
9 |
+
readme = "README.md"
|
10 |
+
requires-python = ">=3.8"
|
11 |
+
classifiers = [
|
12 |
+
"Programming Language :: Python :: 3",
|
13 |
+
"License :: OSI Approved :: Apache Software License",
|
14 |
+
]
|
15 |
+
dependencies = [
|
16 |
+
"torch==2.1.2", "torchvision==0.16.2",
|
17 |
+
"transformers==4.37.2", "tokenizers==0.15.1", "sentencepiece==0.1.99", "shortuuid",
|
18 |
+
"accelerate==0.21.0", "peft", "bitsandbytes",
|
19 |
+
"pydantic", "markdown2[all]", "numpy", "scikit-learn==1.2.2",
|
20 |
+
"gradio==4.16.0", "gradio_client==0.8.1",
|
21 |
+
"requests", "httpx==0.24.0", "uvicorn", "fastapi",
|
22 |
+
"einops==0.6.1", "einops-exts==0.0.4", "timm==0.6.13",
|
23 |
+
"protobuf", "timecode", "sortedcontainers", "qtpy", "pyqt5-tools",
|
24 |
+
"scipy", "matplotlib", "colour_demosaicing", "sk-video",
|
25 |
+
"opencv-python", "progressbar", "openai",
|
26 |
+
"clip @ git+https://github.com/openai/CLIP@main",
|
27 |
+
"scikit-learn", "tensorflow", "sentencepiece", "streamlit",
|
28 |
+
"hdbscan", "plotly", "ipywidgets"
|
29 |
+
]
|
30 |
+
|
31 |
+
[project.optional-dependencies]
|
32 |
+
train = ["deepspeed==0.12.6", "ninja", "wandb"]
|
33 |
+
build = ["build", "twine"]
|
34 |
+
|
35 |
+
[project.urls]
|
36 |
+
"Homepage" = "https://llava-vl.github.io"
|
37 |
+
"Bug Tracker" = "https://github.com/haotian-liu/LLaVA/issues"
|
38 |
+
|
39 |
+
[tool.setuptools.packages.find]
|
40 |
+
exclude = ["assets*", "benchmark*", "docs", "dist*", "playground*", "scripts*", "tests*"]
|
41 |
+
|
42 |
+
[tool.wheel]
|
43 |
+
exclude = ["assets*", "benchmark*", "docs", "dist*", "playground*", "scripts*", "tests*"]
|
train_model.py
ADDED
@@ -0,0 +1,159 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import io
|
3 |
+
import pickle
|
4 |
+
import regex
|
5 |
+
import streamlit as st
|
6 |
+
import plotly.express as px
|
7 |
+
import numpy as np
|
8 |
+
import pandas as pd
|
9 |
+
import torch
|
10 |
+
from utils.seqIo import seqIo_reader
|
11 |
+
import pandas as pd
|
12 |
+
from PIL import Image
|
13 |
+
from pathlib import Path
|
14 |
+
from transformers import AutoProcessor, AutoModel
|
15 |
+
from tqdm import tqdm
|
16 |
+
from sklearn.svm import SVC
|
17 |
+
from sklearn.model_selection import train_test_split
|
18 |
+
from sklearn.metrics import accuracy_score, classification_report
|
19 |
+
from utils.utils import create_embeddings_csv_io, process_dataset_in_mem, generate_embeddings_stream_io
|
20 |
+
|
21 |
+
# --server.maxUploadSize 3000
|
22 |
+
|
23 |
+
def get_unique_labels(label_list: list[str]):
|
24 |
+
label_set = set()
|
25 |
+
for label in label_list:
|
26 |
+
individual_labels = label.split('||')
|
27 |
+
for individual_label in individual_labels:
|
28 |
+
label_set.add(individual_label)
|
29 |
+
return list(label_set)
|
30 |
+
|
31 |
+
@st.cache_data
|
32 |
+
def get_train_test_split(train_embeds, numerical_labels, test_size=0.05, random_state=42):
|
33 |
+
return train_test_split(train_embeds, numerical_labels, test_size=test_size, random_state=random_state)
|
34 |
+
|
35 |
+
@st.cache_resource
|
36 |
+
def train_model(X_train, y_train, random_state=42):
|
37 |
+
# Train SVM Classifier
|
38 |
+
svm_clf = SVC(kernel='rbf', random_state=random_state, probability=True)
|
39 |
+
svm_clf.fit(X_train, y_train)
|
40 |
+
return svm_clf
|
41 |
+
|
42 |
+
def pickle_model(model):
|
43 |
+
pickled = io.BytesIO()
|
44 |
+
pickle.dump(model, pickled)
|
45 |
+
return pickled
|
46 |
+
|
47 |
+
if "embeddings_df" not in st.session_state:
|
48 |
+
st.session_state.embeddings_df = None
|
49 |
+
|
50 |
+
if "svm_clf" not in st.session_state:
|
51 |
+
st.session_state.svm_clf = None
|
52 |
+
st.session_state.report_df = None
|
53 |
+
st.session_state.accuracy = None
|
54 |
+
|
55 |
+
st.title('batik: frame classifier training')
|
56 |
+
|
57 |
+
st.text("Upload files to train classifier on.")
|
58 |
+
with st.form('embedding_generation_settings'):
|
59 |
+
seq_file = st.file_uploader("Choose a .seq File", type=['seq'])
|
60 |
+
annot_files = st.file_uploader("Choose an annotation File", type=['annot','csv'], accept_multiple_files=True)
|
61 |
+
downsample_rate = st.number_input('Downsample Rate',value=4)
|
62 |
+
submit_embed_settings = st.form_submit_button('Create Embeddings', type='secondary')
|
63 |
+
|
64 |
+
st.markdown("**(Optional)** Upload embeddings.")
|
65 |
+
embeddings_csv = st.file_uploader("Choose a .csv File", type=['csv'])
|
66 |
+
|
67 |
+
if submit_embed_settings and seq_file is not None and annot_files is not None:
|
68 |
+
video_embeddings, video_frames = generate_embeddings_stream_io([seq_file],
|
69 |
+
"SLIP",
|
70 |
+
downsample_rate,
|
71 |
+
False)
|
72 |
+
|
73 |
+
fnames = [seq_file.name]
|
74 |
+
embeddings_df = create_embeddings_csv_io(out="file",
|
75 |
+
fnames=fnames,
|
76 |
+
embeddings=video_embeddings,
|
77 |
+
frames=video_frames,
|
78 |
+
annotations=[annot_files],
|
79 |
+
test_fnames=None,
|
80 |
+
views=None,
|
81 |
+
conditions=None,
|
82 |
+
downsample_rate=downsample_rate)
|
83 |
+
st.session_state.embeddings_df = embeddings_df
|
84 |
+
|
85 |
+
elif embeddings_csv is not None:
|
86 |
+
embeddings_df = pd.read_csv(embeddings_csv)
|
87 |
+
st.session_state.embeddings_df = embeddings_df
|
88 |
+
else:
|
89 |
+
st.text('Please upload file(s).')
|
90 |
+
|
91 |
+
st.divider()
|
92 |
+
|
93 |
+
if st.session_state.embeddings_df is not None:
|
94 |
+
st.subheader("specify dataset preprocessing options")
|
95 |
+
st.text("Select frames with label(s) to include:")
|
96 |
+
|
97 |
+
with st.form('train_settings'):
|
98 |
+
label_list = st.session_state.embeddings_df['Label'].to_list()
|
99 |
+
unique_label_list = get_unique_labels(label_list)
|
100 |
+
specified_classes = st.multiselect("Label(s) included:", options=unique_label_list)
|
101 |
+
|
102 |
+
st.text("Select label(s) that should be removed:")
|
103 |
+
classes_to_remove = st.multiselect("Label(s) excluded:", options=unique_label_list)
|
104 |
+
|
105 |
+
max_class_size = st.number_input("(Optional) Specify max class size:", value=None)
|
106 |
+
|
107 |
+
shuffle_data = st.toggle("Shuffle data:")
|
108 |
+
|
109 |
+
train_model_clicked = st.form_submit_button("Train Model")
|
110 |
+
|
111 |
+
if train_model_clicked:
|
112 |
+
kwargs = {'embeddings_df' : st.session_state.embeddings_df,
|
113 |
+
'specified_classes' : specified_classes,
|
114 |
+
'classes_to_remove' : classes_to_remove,
|
115 |
+
'max_class_size' : max_class_size,
|
116 |
+
'animal_state' : None,
|
117 |
+
'view' : None,
|
118 |
+
'shuffle_data' : shuffle_data,
|
119 |
+
'test_videos' : None}
|
120 |
+
train_embeds, train_labels, train_images, _, _, _ = process_dataset_in_mem(**kwargs)
|
121 |
+
# Convert labels to numerical values
|
122 |
+
label_to_appear_first = 'other'
|
123 |
+
unique_labels = set(train_labels)
|
124 |
+
unique_labels.discard(label_to_appear_first)
|
125 |
+
|
126 |
+
label_to_index = {label_to_appear_first: 0}
|
127 |
+
|
128 |
+
label_to_index.update({label: idx + 1 for idx, label in enumerate(unique_labels)})
|
129 |
+
index_to_label = {idx: label for label, idx in label_to_index.items()}
|
130 |
+
numerical_labels = np.array([label_to_index[label] for label in train_labels])
|
131 |
+
|
132 |
+
print("Label Valence: ", label_to_index)
|
133 |
+
# Split data into train and test sets
|
134 |
+
X_train, X_test, y_train, y_test = get_train_test_split(train_embeds, numerical_labels, test_size=0.05, random_state=42)
|
135 |
+
with st.spinner("Model training in progress..."):
|
136 |
+
svm_clf = train_model(X_train, y_train)
|
137 |
+
|
138 |
+
# Predict on the test set
|
139 |
+
with st.spinner("In progress..."):
|
140 |
+
y_pred = svm_clf.predict(X_test)
|
141 |
+
accuracy = accuracy_score(y_test, y_pred)
|
142 |
+
report = classification_report(y_test, y_pred, target_names=[index_to_label[idx] for idx in range(len(label_to_index))], output_dict=True)
|
143 |
+
report_df = pd.DataFrame(report).transpose()
|
144 |
+
|
145 |
+
# save results to session state
|
146 |
+
st.session_state.svm_clf = svm_clf
|
147 |
+
st.session_state.report_df = report_df
|
148 |
+
st.session_state.accuracy = accuracy
|
149 |
+
|
150 |
+
if st.session_state.svm_clf is not None:
|
151 |
+
pickled_model = pickle_model(st.session_state.svm_clf)
|
152 |
+
|
153 |
+
st.text(f"Eval Accuracy: {st.session_state.accuracy}")
|
154 |
+
st.subheader("Classification Report:")
|
155 |
+
st.dataframe(st.session_state.report_df)
|
156 |
+
|
157 |
+
st.download_button("Download model as .pkl file",
|
158 |
+
data=pickled_model,
|
159 |
+
file_name=f"{'_'.join(specified_classes)}_classifier.pkl")
|
utils/__init__.py
ADDED
File without changes
|
utils/annot.py
ADDED
@@ -0,0 +1,641 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# annot.py
|
2 |
+
from io import StringIO
|
3 |
+
from random import sample
|
4 |
+
from collections import OrderedDict
|
5 |
+
import timecode as tc
|
6 |
+
from .behavior import Behavior
|
7 |
+
from sortedcontainers import SortedKeyList
|
8 |
+
from qtpy.QtCore import QObject, QRectF, Signal, Slot
|
9 |
+
from qtpy.QtGui import QColor
|
10 |
+
from qtpy.QtWidgets import QGraphicsItem
|
11 |
+
|
12 |
+
class Bout(object):
|
13 |
+
"""
|
14 |
+
"""
|
15 |
+
|
16 |
+
def __init__(self, start, end, behavior):
|
17 |
+
self._start = start
|
18 |
+
self._end = end
|
19 |
+
self._behavior = behavior
|
20 |
+
|
21 |
+
def __lt__(self, b):
|
22 |
+
if type(b) is tc.Timecode:
|
23 |
+
return self._start.float < b.float
|
24 |
+
elif type(b) is Bout:
|
25 |
+
return self._start < b._start
|
26 |
+
else:
|
27 |
+
raise NotImplementedError(f"Comparing Bout with {type(b)} not supported")
|
28 |
+
|
29 |
+
def __le__(self, b):
|
30 |
+
if type(b) is tc.Timecode:
|
31 |
+
return self._start.float <= b.is_float
|
32 |
+
elif type(b) is Bout:
|
33 |
+
return self._start <= b._start
|
34 |
+
else:
|
35 |
+
raise NotImplementedError(f'Comparing Bout with {type(b)} not supported')
|
36 |
+
|
37 |
+
def is_at(self, t):
|
38 |
+
return self._start <= t and self._end >= t
|
39 |
+
|
40 |
+
def start(self):
|
41 |
+
return self._start
|
42 |
+
|
43 |
+
def set_start(self, start):
|
44 |
+
self._start = start
|
45 |
+
|
46 |
+
def end(self):
|
47 |
+
return self._end
|
48 |
+
|
49 |
+
def set_end(self, end):
|
50 |
+
self._end = end
|
51 |
+
|
52 |
+
def len(self):
|
53 |
+
return self._end - self._start + tc.Timecode(self._start.framerate, frames=1)
|
54 |
+
|
55 |
+
def behavior(self):
|
56 |
+
return self._behavior
|
57 |
+
|
58 |
+
def name(self):
|
59 |
+
return self._behavior.get_name()
|
60 |
+
|
61 |
+
def color(self):
|
62 |
+
return self._behavior.get_color()
|
63 |
+
|
64 |
+
def is_active(self):
|
65 |
+
return self._behavior.is_active()
|
66 |
+
|
67 |
+
def is_visible(self):
|
68 |
+
return self._behavior.is_visible() and self._behavior.is_active()
|
69 |
+
|
70 |
+
def __repr__(self):
|
71 |
+
return f"Bout: start = {self.start()}, end = {self.end()}, behavior: {self.behavior()}"
|
72 |
+
|
73 |
+
class Channel(QGraphicsItem):
|
74 |
+
"""
|
75 |
+
"""
|
76 |
+
|
77 |
+
contentChanged = Signal()
|
78 |
+
|
79 |
+
def __init__(self, chan = None):
|
80 |
+
super().__init__()
|
81 |
+
if not chan is None:
|
82 |
+
self._bouts_by_start = chan._bouts_by_start
|
83 |
+
self._bouts_by_end = chan._bouts_by_end
|
84 |
+
self._top = chan._top
|
85 |
+
else:
|
86 |
+
self._bouts_by_start = SortedKeyList(key=lambda bout: bout.start().float)
|
87 |
+
self._bouts_by_end = SortedKeyList(key=lambda bout: bout.end().float)
|
88 |
+
self._top = 0.
|
89 |
+
self.cur_ix = 0
|
90 |
+
self.fakeFirstBout = Bout(
|
91 |
+
tc.Timecode('30.0', '0:0:0:0'),
|
92 |
+
tc.Timecode('30.0', '0:0:0:0'),
|
93 |
+
Behavior('none', '', QColor.fromRgbF(0., 0., 0.), active = True)
|
94 |
+
)
|
95 |
+
self.fakeLastBout = Bout(
|
96 |
+
tc.Timecode('30.0', '23:59:59:29'),
|
97 |
+
tc.Timecode('30.0', '23:59:59:29'),
|
98 |
+
Behavior('none', '', QColor.fromRgbF(0., 0., 0.), active = True)
|
99 |
+
)
|
100 |
+
|
101 |
+
def add(self, b):
|
102 |
+
if isinstance(b, Bout):
|
103 |
+
self._bouts_by_start.add(b)
|
104 |
+
self._bouts_by_end.add(b)
|
105 |
+
else:
|
106 |
+
raise TypeError("Can only add Bouts to Channel")
|
107 |
+
|
108 |
+
def remove(self, b):
|
109 |
+
# can raise ValueError if b is not in the channel
|
110 |
+
self._bouts_by_start.remove(b)
|
111 |
+
self._bouts_by_end.remove(b)
|
112 |
+
|
113 |
+
def update_start(self, b, new_start):
|
114 |
+
"""
|
115 |
+
Update the starting time of a bout while
|
116 |
+
preserving the _bouts_by_start access order
|
117 |
+
"""
|
118 |
+
self._bouts_by_start.remove(b)
|
119 |
+
b.set_start(new_start)
|
120 |
+
self._bouts_by_start.add(b)
|
121 |
+
|
122 |
+
def update_end(self, b, new_end):
|
123 |
+
"""
|
124 |
+
Update the ending time of a bout while
|
125 |
+
preserving the _bouts_by_end access order
|
126 |
+
"""
|
127 |
+
self._bouts_by_end.remove(b)
|
128 |
+
b.set_end(new_end)
|
129 |
+
self._bouts_by_end.add(b)
|
130 |
+
|
131 |
+
def __add__(self, b):
|
132 |
+
self.add(b)
|
133 |
+
|
134 |
+
def _get_next(self, t, sortedlist):
|
135 |
+
ix = sortedlist.bisect_key_right(t.float)
|
136 |
+
l = len(sortedlist)
|
137 |
+
if ix == l:
|
138 |
+
return self.fakeLastBout, t.next()
|
139 |
+
return sortedlist[ix], t.next()
|
140 |
+
|
141 |
+
def _get_prev(self, t, sortedlist):
|
142 |
+
ix = sortedlist.bisect_key_left(t.float)
|
143 |
+
if ix == 0:
|
144 |
+
return self.fakeFirstBout, t.back()
|
145 |
+
return sortedlist[ix-1], t.back()
|
146 |
+
|
147 |
+
def _get_inner(self, t, sortedList, op):
|
148 |
+
t_local = t + 0 # kludgy copy constructor!
|
149 |
+
visible = False
|
150 |
+
while not visible:
|
151 |
+
# no need to check for the end, because the fake first and last bouts are visible
|
152 |
+
bout, t_local = op(t_local, sortedList)
|
153 |
+
visible = bout.is_visible()
|
154 |
+
return bout
|
155 |
+
|
156 |
+
def get_next_start(self, t):
|
157 |
+
return self._get_inner(t, self._bouts_by_start, self._get_next)
|
158 |
+
|
159 |
+
def get_next_end(self, t):
|
160 |
+
return self._get_inner(t, self._bouts_by_end, self._get_next)
|
161 |
+
|
162 |
+
def get_prev_start(self, t):
|
163 |
+
return self._get_inner(t, self._bouts_by_start, self._get_prev)
|
164 |
+
|
165 |
+
def get_prev_end(self, t):
|
166 |
+
return self._get_inner(t, self._bouts_by_end, self._get_prev)
|
167 |
+
|
168 |
+
def get_in_range(self, start, end):
|
169 |
+
"""
|
170 |
+
get all bouts that intersect the range [start, end]
|
171 |
+
"""
|
172 |
+
return [bout for bout in self._bouts_by_start
|
173 |
+
if bout.start().float <= end.float and bout.end().float >= start.float]
|
174 |
+
|
175 |
+
def get_at(self, t):
|
176 |
+
"""
|
177 |
+
get all bouts that span time t
|
178 |
+
"""
|
179 |
+
return self.get_in_range(t, t)
|
180 |
+
|
181 |
+
def __iter__(self):
|
182 |
+
return iter(self._bouts_by_start)
|
183 |
+
|
184 |
+
def irange(self, start_time, end_time):
|
185 |
+
if isinstance(start_time, tc.Timecode):
|
186 |
+
start_time = start_time.float
|
187 |
+
if isinstance(end_time, tc.Timecode):
|
188 |
+
end_time = end_time.float
|
189 |
+
if not isinstance(start_time, float):
|
190 |
+
raise TypeError(f"Can't handle start_time of type {type(start_time)}")
|
191 |
+
if not isinstance(end_time, float):
|
192 |
+
raise TypeError(f"Can't handle end_time of type {type(end_time)}")
|
193 |
+
return self._bouts_by_start.irange_key(start_time, end_time)
|
194 |
+
|
195 |
+
def top(self):
|
196 |
+
return self._top
|
197 |
+
|
198 |
+
def set_top(self, top):
|
199 |
+
self._top = top
|
200 |
+
|
201 |
+
def boundingRect(self):
|
202 |
+
width = self.fakeLastBout.end().float
|
203 |
+
return QRectF(0., self.top(), width, 1.)
|
204 |
+
|
205 |
+
def paint(self, painter, option, widget=None):
|
206 |
+
boundingRect = option.rect
|
207 |
+
in_range_bouts = self._bouts_by_start.irange_key(boundingRect.left(), boundingRect.right())
|
208 |
+
while True:
|
209 |
+
try:
|
210 |
+
bout = next(in_range_bouts)
|
211 |
+
except StopIteration:
|
212 |
+
break
|
213 |
+
if bout.is_visible():
|
214 |
+
painter.fillRect(
|
215 |
+
QRectF(bout.start().float, self.top(), bout.len().float, 1.),
|
216 |
+
bout.color()
|
217 |
+
)
|
218 |
+
|
219 |
+
def _delete_all_inner(self, predicate):
|
220 |
+
to_delete = list()
|
221 |
+
# can't alter the bouts within the iterator
|
222 |
+
for bout in iter(self):
|
223 |
+
if predicate(bout):
|
224 |
+
to_delete.append(bout)
|
225 |
+
deleted_names = set()
|
226 |
+
for bout in to_delete:
|
227 |
+
deleted_names.add(bout.name())
|
228 |
+
self.remove(bout)
|
229 |
+
return deleted_names
|
230 |
+
|
231 |
+
def delete_bouts_by_name(self, behavior_name):
|
232 |
+
return self._delete_all_inner(lambda bout: bout.name() == behavior_name)
|
233 |
+
|
234 |
+
def delete_inactive_bouts(self):
|
235 |
+
return self._delete_all_inner(lambda bout: not bout.is_active())
|
236 |
+
|
237 |
+
def truncate_or_remove_bouts(self, behavior, start, end, delete_all=False):
|
238 |
+
"""
|
239 |
+
Delete bouts entirely within the range [start, end], and
|
240 |
+
truncate bouts that extend outside the range.
|
241 |
+
If behavior matches _deleteBehavior, the activity affects
|
242 |
+
all bouts. Otherwise, it only affects bouts with matching behavior.
|
243 |
+
"""
|
244 |
+
items = self.get_in_range(start, end)
|
245 |
+
for item in items:
|
246 |
+
if not delete_all and behavior.get_name() != item.name():
|
247 |
+
continue
|
248 |
+
# Delete bouts that are entirely within the range
|
249 |
+
if item.start() >= start and item.end() < end:
|
250 |
+
print(f"removing {item} from active channel")
|
251 |
+
self.remove(item)
|
252 |
+
|
253 |
+
# Truncate and duplicate bouts that extend out both sides of the range
|
254 |
+
if item.start() < start and item.end() > end:
|
255 |
+
self.add(Bout(end, item.end(), item.behavior()))
|
256 |
+
self.update_end(item, start)
|
257 |
+
|
258 |
+
# Truncate bouts at the start boundary that start before the range
|
259 |
+
elif item.start() < start and item.end() <= end:
|
260 |
+
self.update_end(item, start)
|
261 |
+
|
262 |
+
# Truncate bouts at the end boundary that end after the range
|
263 |
+
elif item.start() >= start and item.end() > end:
|
264 |
+
self.update_start(item, end)
|
265 |
+
|
266 |
+
else:
|
267 |
+
print(f"truncate_or_delete_bouts: Unexpected bout {item}")
|
268 |
+
|
269 |
+
def coalesce_bouts(self, start, end):
|
270 |
+
"""
|
271 |
+
combine overlapping bouts of the same behavior within [start, end]
|
272 |
+
"""
|
273 |
+
to_delete = []
|
274 |
+
items = self.get_in_range(start, end)
|
275 |
+
# items will be ordered by start time
|
276 |
+
for ix, first in enumerate(items):
|
277 |
+
if first in to_delete:
|
278 |
+
# previously coalesced
|
279 |
+
continue
|
280 |
+
if ix == len(items)-1:
|
281 |
+
break
|
282 |
+
for second in items[ix+1:]:
|
283 |
+
if (first.name() == second.name() and
|
284 |
+
first.end() >= second.start()):
|
285 |
+
if first.end() < second.end():
|
286 |
+
self.update_end(first, second.end())
|
287 |
+
to_delete.append(second)
|
288 |
+
for item in to_delete:
|
289 |
+
self.remove(item)
|
290 |
+
|
291 |
+
class Annotations(QObject):
|
292 |
+
"""
|
293 |
+
"""
|
294 |
+
|
295 |
+
# Signals
|
296 |
+
annotations_changed = Signal()
|
297 |
+
active_annotations_changed = Signal()
|
298 |
+
|
299 |
+
def __init__(self, behaviors):
|
300 |
+
super().__init__()
|
301 |
+
self._channels = OrderedDict()
|
302 |
+
self._behaviors = behaviors
|
303 |
+
self._movies = []
|
304 |
+
self._start_frame = None
|
305 |
+
self._end_frame = None
|
306 |
+
self._sample_rate = None
|
307 |
+
self._stimulus = None
|
308 |
+
self._format = None
|
309 |
+
self.annotation_names = []
|
310 |
+
behaviors.behaviors_changed.connect(self.note_annotations_changed)
|
311 |
+
|
312 |
+
def read(self, fn):
|
313 |
+
with open(fn, "r") as f:
|
314 |
+
line = f.readline()
|
315 |
+
line = line.strip().lower()
|
316 |
+
if line.endswith("annotation file"):
|
317 |
+
self._format = 'Caltech'
|
318 |
+
self._read_caltech(f)
|
319 |
+
elif line.startswith("scorevideo log"):
|
320 |
+
self._format = 'Ethovision'
|
321 |
+
self._read_ethovision(f)
|
322 |
+
else:
|
323 |
+
print("Unsupported annotation file format")
|
324 |
+
|
325 |
+
def read_io(self, uploaded_file):
|
326 |
+
text_str = uploaded_file.getvalue().decode("utf-8")
|
327 |
+
with StringIO(text_str) as f:
|
328 |
+
f.__setattr__('name', uploaded_file.name)
|
329 |
+
line = f.readline()
|
330 |
+
line = line.strip().lower()
|
331 |
+
if line.endswith("annotation file"):
|
332 |
+
self._format = 'Caltech'
|
333 |
+
self._read_caltech(f)
|
334 |
+
elif line.startswith("scorevideo log"):
|
335 |
+
self._format = 'Ethovision'
|
336 |
+
self._read_ethovision(f)
|
337 |
+
else:
|
338 |
+
print("Unsupported annotation file format")
|
339 |
+
|
340 |
+
def _read_caltech(self, f):
|
341 |
+
found_movies = False
|
342 |
+
found_timecode = False
|
343 |
+
found_stimulus = False
|
344 |
+
found_channel_names = False
|
345 |
+
found_annotation_names = False
|
346 |
+
found_all_channels = False
|
347 |
+
found_all_annotations = False
|
348 |
+
new_behaviors_activated = False
|
349 |
+
reading_channel = False
|
350 |
+
to_activate = []
|
351 |
+
channel_names = []
|
352 |
+
current_channel = None
|
353 |
+
current_bout = None
|
354 |
+
|
355 |
+
self._format = 'Caltech'
|
356 |
+
|
357 |
+
line = f.readline()
|
358 |
+
while line:
|
359 |
+
if found_annotation_names and not new_behaviors_activated:
|
360 |
+
self.ensure_and_activate_behaviors(to_activate)
|
361 |
+
new_behaviors_activated = True
|
362 |
+
|
363 |
+
line.strip()
|
364 |
+
|
365 |
+
if not line:
|
366 |
+
if reading_channel:
|
367 |
+
reading_channel = False
|
368 |
+
current_channel = None
|
369 |
+
current_bout = None
|
370 |
+
elif line.lower().startswith("movie file"):
|
371 |
+
items = line.split()
|
372 |
+
for item in items:
|
373 |
+
if item.lower().startswith("movie"):
|
374 |
+
continue
|
375 |
+
if item.lower().startswith("file"):
|
376 |
+
continue
|
377 |
+
self._movies.append(item)
|
378 |
+
found_movies = True
|
379 |
+
elif line.lower().startswith("stimulus name"):
|
380 |
+
# TODO: do something when we know what one of these looks like
|
381 |
+
found_stimulus = True
|
382 |
+
elif line.lower().startswith("annotation start frame") or line.lower().startswith("annotation start time"):
|
383 |
+
items = line.split()
|
384 |
+
if len(items) > 3:
|
385 |
+
try:
|
386 |
+
self._start_frame = int(items[3])
|
387 |
+
except:
|
388 |
+
self._start_frame = int(float(items[3]))
|
389 |
+
if self._end_frame and self._sample_rate:
|
390 |
+
found_timecode = True
|
391 |
+
elif line.lower().startswith("annotation stop frame") or line.lower().startswith("annotation stop time"):
|
392 |
+
items = line.split()
|
393 |
+
if len(items) > 3:
|
394 |
+
try:
|
395 |
+
self._end_frame = int(items[3])
|
396 |
+
except:
|
397 |
+
self._end_frame = int(float(items[3]))
|
398 |
+
if self._start_frame and self._sample_rate:
|
399 |
+
found_timecode = True
|
400 |
+
elif line.lower().startswith("annotation framerate"):
|
401 |
+
items = line.split()
|
402 |
+
if len(items) > 2:
|
403 |
+
self._sample_rate = float(items[2])
|
404 |
+
if self._start_frame and self._end_frame:
|
405 |
+
found_timecode = True
|
406 |
+
elif line.lower().startswith("list of channels"):
|
407 |
+
line = f.readline()
|
408 |
+
while line:
|
409 |
+
line = line.strip()
|
410 |
+
if not line:
|
411 |
+
break # blank line -- end of section
|
412 |
+
channel_names.append(line)
|
413 |
+
line = f.readline()
|
414 |
+
found_channel_names = True
|
415 |
+
elif line.lower().startswith("list of annotations"):
|
416 |
+
line = f.readline()
|
417 |
+
while line:
|
418 |
+
line = line.strip()
|
419 |
+
if not line:
|
420 |
+
break # blank line -- end of section
|
421 |
+
to_activate.append(line)
|
422 |
+
line = f.readline().strip()
|
423 |
+
found_annotation_names = True
|
424 |
+
elif line.strip().lower().endswith("---"):
|
425 |
+
for ch_name in channel_names:
|
426 |
+
if line.startswith(ch_name):
|
427 |
+
self._channels[ch_name] = Channel()
|
428 |
+
current_channel = ch_name
|
429 |
+
reading_channel = True
|
430 |
+
break
|
431 |
+
if reading_channel:
|
432 |
+
reading_annot = False
|
433 |
+
line = f.readline()
|
434 |
+
while line:
|
435 |
+
line = line.strip()
|
436 |
+
if not line: # blank line
|
437 |
+
if reading_annot:
|
438 |
+
reading_annot = False
|
439 |
+
current_bout = None
|
440 |
+
else:
|
441 |
+
# second blank line, so done with channel
|
442 |
+
reading_channel = False
|
443 |
+
current_channel = None
|
444 |
+
break
|
445 |
+
elif line.startswith(">"):
|
446 |
+
current_bout = line[1:]
|
447 |
+
reading_annot = True
|
448 |
+
elif line.lower().startswith("start"):
|
449 |
+
pass # skip a header line
|
450 |
+
else:
|
451 |
+
items = line.split()
|
452 |
+
is_float = '.' in items[0] or '.' in items[1] or '.' in items[2]
|
453 |
+
bout_start = items[0]
|
454 |
+
if float(bout_start) < 1:
|
455 |
+
bout_start = 1
|
456 |
+
bout_end = items[1]
|
457 |
+
if float(bout_end) < 1:
|
458 |
+
bout_end = 1
|
459 |
+
self.add_bout(
|
460 |
+
Bout(
|
461 |
+
tc.Timecode(self._sample_rate, start_seconds=float(bout_start)) if is_float
|
462 |
+
else tc.Timecode(self._sample_rate, frames=int(bout_start)),
|
463 |
+
tc.Timecode(self._sample_rate, start_seconds=float(bout_end)) if is_float
|
464 |
+
else tc.Timecode(self._sample_rate, frames=int(bout_end)),
|
465 |
+
self._behaviors.get(current_bout)),
|
466 |
+
current_channel)
|
467 |
+
line = f.readline()
|
468 |
+
line = f.readline()
|
469 |
+
print(f"Done reading Caltech annotation file {f.name}")
|
470 |
+
self.note_annotations_changed()
|
471 |
+
|
472 |
+
def write_caltech(self, f, video_files, stimulus):
|
473 |
+
if not f.writable():
|
474 |
+
raise RuntimeError("File not writable")
|
475 |
+
f.write("Bento annotation file\n")
|
476 |
+
f.write("Movie file(s):")
|
477 |
+
for file in video_files:
|
478 |
+
f.write(' ' + file)
|
479 |
+
f.write('\n\n')
|
480 |
+
|
481 |
+
f.write(f"Stimulus name: {stimulus}\n")
|
482 |
+
f.write(f"Annotation start frame: {self._start_frame}\n")
|
483 |
+
f.write(f"Annotation stop frame: {self._end_frame}\n")
|
484 |
+
f.write(f"Annotation framerate: {self._sample_rate}\n")
|
485 |
+
f.write("\n")
|
486 |
+
|
487 |
+
f.write("List of Channels:\n")
|
488 |
+
for ch in self.channel_names():
|
489 |
+
f.write(ch + "\n")
|
490 |
+
f.write("\n")
|
491 |
+
|
492 |
+
f.write("List of annotations:\n")
|
493 |
+
for annot in self.annotation_names:
|
494 |
+
f.write(annot + "\n")
|
495 |
+
f.write("\n")
|
496 |
+
|
497 |
+
for ch in self.channel_names():
|
498 |
+
by_name = {}
|
499 |
+
f.write(f"{ch}----------\n")
|
500 |
+
|
501 |
+
for bout in self.channel(ch):
|
502 |
+
if not by_name.get(bout.name()):
|
503 |
+
by_name[bout.name()] = []
|
504 |
+
by_name[bout.name()].append(bout)
|
505 |
+
|
506 |
+
for annot in by_name:
|
507 |
+
f.write(f">{annot}\n")
|
508 |
+
f.write("Start\tStop\tDuration\n")
|
509 |
+
for bout in by_name[annot]:
|
510 |
+
start = bout.start().frames
|
511 |
+
end = bout.end().frames
|
512 |
+
f.write(f"{start}\t{end}\t{end - start}\n")
|
513 |
+
f.write("\n")
|
514 |
+
|
515 |
+
f.write("\n")
|
516 |
+
|
517 |
+
f.close()
|
518 |
+
print(f"Done writing Caltech annotation file {f.name}")
|
519 |
+
|
520 |
+
def _read_ethovision(self, f):
|
521 |
+
print("Ethovision annotations not yet supported")
|
522 |
+
|
523 |
+
def clear_channels(self):
|
524 |
+
self._channels.clear()
|
525 |
+
|
526 |
+
def channel_names(self):
|
527 |
+
return list(self._channels.keys())
|
528 |
+
|
529 |
+
def channel(self, ch: str) -> Channel:
|
530 |
+
return self._channels[ch]
|
531 |
+
|
532 |
+
def addEmptyChannel(self, ch: str):
|
533 |
+
if ch not in self.channel_names():
|
534 |
+
self._channels[ch] = Channel()
|
535 |
+
|
536 |
+
def add_bout(self, bout, channel):
|
537 |
+
if bout.name() not in self.annotation_names:
|
538 |
+
self.annotation_names.append(bout.name())
|
539 |
+
self._channels[channel].add(bout)
|
540 |
+
if bout.end() > self.end_time():
|
541 |
+
self.set_end_frame(bout.end())
|
542 |
+
|
543 |
+
def start_time(self):
|
544 |
+
"""
|
545 |
+
At some point we will need to support a start time distinct from
|
546 |
+
frame number, perhaps derived from the OS file modify time
|
547 |
+
or the start time of the corresponding video (or other media) file
|
548 |
+
"""
|
549 |
+
if not self._start_frame or not self._sample_rate:
|
550 |
+
return tc.Timecode('30.0', '0:0:0:0')
|
551 |
+
return tc.Timecode(self._sample_rate, frames=self._start_frame)
|
552 |
+
|
553 |
+
def start_frame(self):
|
554 |
+
return self._start_frame
|
555 |
+
|
556 |
+
def set_start_frame(self, t):
|
557 |
+
if isinstance(t, int):
|
558 |
+
self._start_frame = t
|
559 |
+
elif isinstance(t, tc.Timecode):
|
560 |
+
self._start_frame = t.frames
|
561 |
+
else:
|
562 |
+
raise TypeError("Expected a frame number or Timecode")
|
563 |
+
|
564 |
+
def end_time(self):
|
565 |
+
if not self._end_frame or not self._sample_rate:
|
566 |
+
return tc.Timecode('30.0', '23:59:59:29')
|
567 |
+
return tc.Timecode(self._sample_rate, frames=self._end_frame)
|
568 |
+
|
569 |
+
def end_frame(self):
|
570 |
+
return self._end_frame
|
571 |
+
|
572 |
+
def set_end_frame(self, t):
|
573 |
+
if isinstance(t, int):
|
574 |
+
self._end_frame = t
|
575 |
+
elif isinstance(t, tc.Timecode):
|
576 |
+
self._end_frame = t.frames
|
577 |
+
else:
|
578 |
+
raise TypeError("Expected a frame number or Timecode")
|
579 |
+
|
580 |
+
def sample_rate(self):
|
581 |
+
return self._sample_rate
|
582 |
+
|
583 |
+
def set_sample_rate(self, sample_rate):
|
584 |
+
self._sample_rate = sample_rate
|
585 |
+
|
586 |
+
def format(self):
|
587 |
+
return self._format
|
588 |
+
|
589 |
+
def set_format(self, format):
|
590 |
+
self._format = format
|
591 |
+
|
592 |
+
def delete_bouts_by_name(self, behavior_name):
|
593 |
+
deleted_names = set()
|
594 |
+
for chan_name in self.channel_names():
|
595 |
+
deleted_names.update(self.channel(chan_name).delete_bouts_by_name(behavior_name))
|
596 |
+
for name in deleted_names:
|
597 |
+
self.annotation_names.remove(name)
|
598 |
+
|
599 |
+
def delete_inactive_bouts(self):
|
600 |
+
deleted_names = set()
|
601 |
+
for chan_name in self.channel_names():
|
602 |
+
deleted_names.update(self.channel(chan_name).delete_inactive_bouts())
|
603 |
+
for name in deleted_names:
|
604 |
+
self.annotation_names.remove(name)
|
605 |
+
|
606 |
+
def ensure_and_activate_behaviors(self, toActivate):
|
607 |
+
behaviorSetUpdated = False
|
608 |
+
for behaviorName in toActivate:
|
609 |
+
behaviorSetUpdated |= self._behaviors.addIfMissing(behaviorName)
|
610 |
+
self.annotation_names.append(behaviorName)
|
611 |
+
self._behaviors.get(behaviorName).set_active(True)
|
612 |
+
if behaviorSetUpdated:
|
613 |
+
self.annotations_changed.emit()
|
614 |
+
self.active_annotations_changed.emit()
|
615 |
+
|
616 |
+
def ensure_active_behaviors(self):
|
617 |
+
for behavior in self._behaviors:
|
618 |
+
if behavior.is_active() and behavior.get_name() not in self.annotation_names:
|
619 |
+
self.annotation_names.append(behavior.get_name())
|
620 |
+
|
621 |
+
def truncate_or_remove_bouts(self, behavior, start, end, chan):
|
622 |
+
"""
|
623 |
+
Delete bouts entirely within the range [start, end], or
|
624 |
+
truncate bouts that extend outside the range.
|
625 |
+
If behavior matches _deleteBehavior, the activity affects
|
626 |
+
all bouts. Otherwise, it only affects bouts with matching behavior.
|
627 |
+
"""
|
628 |
+
delete_all = (behavior.get_name() == self._behaviors.getDeleteBehavior().get_name())
|
629 |
+
self._channels[chan].truncate_or_remove_bouts(behavior, start, end, delete_all)
|
630 |
+
self.note_annotations_changed()
|
631 |
+
|
632 |
+
def coalesce_bouts(self, start, end, chan):
|
633 |
+
"""
|
634 |
+
combine overlapping bouts of the same behavior within [start, end]
|
635 |
+
"""
|
636 |
+
self._channels[chan].coalesce_bouts(start, end)
|
637 |
+
self.note_annotations_changed()
|
638 |
+
|
639 |
+
@Slot()
|
640 |
+
def note_annotations_changed(self):
|
641 |
+
self.annotations_changed.emit()
|
utils/behavior.py
ADDED
@@ -0,0 +1,291 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# behavior.py
|
2 |
+
"""
|
3 |
+
Overview comment here
|
4 |
+
"""
|
5 |
+
|
6 |
+
from qtpy.QtCore import QAbstractItemModel, QAbstractTableModel, QModelIndex, QObject, Qt, Signal, Slot
|
7 |
+
from qtpy.QtGui import QColor
|
8 |
+
import os
|
9 |
+
|
10 |
+
class Behavior(QObject):
|
11 |
+
"""
|
12 |
+
An annotation behavior, which is quite simple. It comprises:
|
13 |
+
name The name of the behavior, which is displayed on various UI widgets
|
14 |
+
hot_key The case-sensitive key stroke used to start and stop instances of the behavior
|
15 |
+
color The color with which to display this behavior
|
16 |
+
"""
|
17 |
+
|
18 |
+
def __init__(self, name: str, hot_key: str = '', color: QColor = QColor('gray'), active = False, visible = True):
|
19 |
+
super().__init__()
|
20 |
+
self._name = name
|
21 |
+
self._hot_key = '' if hot_key == '_' else hot_key
|
22 |
+
self._color = color
|
23 |
+
self._visible = visible
|
24 |
+
self._active = active
|
25 |
+
self._get_functions = {
|
26 |
+
'hot_key': self.get_hot_key,
|
27 |
+
'name': self.get_name,
|
28 |
+
'color': self.get_color,
|
29 |
+
'active': self.is_active,
|
30 |
+
'visible': self.is_visible
|
31 |
+
}
|
32 |
+
self._set_functions = {
|
33 |
+
'hot_key': self.set_hot_key,
|
34 |
+
'name': self.set_name,
|
35 |
+
'color': self.set_color,
|
36 |
+
'active': self.set_active,
|
37 |
+
'visible': self.set_visible
|
38 |
+
}
|
39 |
+
|
40 |
+
def __repr__(self):
|
41 |
+
return f"Behavior: name={self._name}, hot_key={self._hot_key}, color={self._color}, active={self._active}, visible={self._visible}"
|
42 |
+
|
43 |
+
def get(self, key):
|
44 |
+
# may raise KeyError
|
45 |
+
return self._get_functions[key]()
|
46 |
+
|
47 |
+
def set(self, key, value):
|
48 |
+
# may raise KeyError
|
49 |
+
self._set_functions[key](value)
|
50 |
+
|
51 |
+
def get_hot_key(self):
|
52 |
+
return self._hot_key
|
53 |
+
|
54 |
+
def set_hot_key(self, hot_key: str):
|
55 |
+
self._hot_key = hot_key
|
56 |
+
|
57 |
+
def get_color(self):
|
58 |
+
return self._color
|
59 |
+
|
60 |
+
def set_color(self, color: QColor):
|
61 |
+
self._color = color
|
62 |
+
|
63 |
+
def is_active(self):
|
64 |
+
return self._active
|
65 |
+
|
66 |
+
@Slot(bool)
|
67 |
+
def set_active(self, active):
|
68 |
+
self._active = active
|
69 |
+
|
70 |
+
def is_visible(self):
|
71 |
+
return self._visible
|
72 |
+
|
73 |
+
@Slot(bool)
|
74 |
+
def set_visible(self, visible):
|
75 |
+
self._visible = visible
|
76 |
+
|
77 |
+
def get_name(self):
|
78 |
+
return self._name
|
79 |
+
|
80 |
+
def set_name(self, name):
|
81 |
+
self._name = name
|
82 |
+
|
83 |
+
def toDict(self):
|
84 |
+
return {
|
85 |
+
'hot_key': '_' if self._hot_key == '' else self._hot_key,
|
86 |
+
'color': self._color,
|
87 |
+
'name': self._name,
|
88 |
+
'active': self._active,
|
89 |
+
'visible': self._visible
|
90 |
+
}
|
91 |
+
|
92 |
+
class Behaviors(QAbstractTableModel):
|
93 |
+
"""
|
94 |
+
A set of behaviors, which represent "all" possible behaviors.
|
95 |
+
The class supports reading from and writing to profile files that specify
|
96 |
+
the default hot_key and color for each behavior, along with the name.
|
97 |
+
|
98 |
+
Derives from QAbstractTableModel so that it can be viewed and edited
|
99 |
+
directly in a QTableView widget.
|
100 |
+
|
101 |
+
Use getattr(name) to get the Behavior instance for a given name.
|
102 |
+
Use from_hot_key(key) to get the behavior(s) given the hot key.
|
103 |
+
Returns None if the hot_key isn't defined.
|
104 |
+
"""
|
105 |
+
|
106 |
+
behaviors_changed = Signal()
|
107 |
+
layout_changed = Signal()
|
108 |
+
|
109 |
+
def __init__(self):
|
110 |
+
super().__init__()
|
111 |
+
self._items = []
|
112 |
+
self._by_name = {}
|
113 |
+
self._by_hot_key = {}
|
114 |
+
self._header = ['hot_key', 'color', 'name', 'active', 'visible']
|
115 |
+
self._searchList = [self._by_hot_key, None, self._by_name, None, None]
|
116 |
+
self._delete_behavior = Behavior('_delete', color = QColor('black'))
|
117 |
+
self._immutableColumns = set()
|
118 |
+
self._booleanColumns = set([self._header.index('active'), self._header.index('visible')])
|
119 |
+
self._role_to_str = {
|
120 |
+
Qt.DisplayRole: "DisplayRole",
|
121 |
+
Qt.DecorationRole: "DecorationRole",
|
122 |
+
Qt.EditRole: "EditRole",
|
123 |
+
Qt.ToolTipRole: "ToolTipRole",
|
124 |
+
Qt.StatusTipRole: "StatusTipRole",
|
125 |
+
Qt.WhatsThisRole: "WhatsThisRole",
|
126 |
+
Qt.SizeHintRole: "SizeHintRole",
|
127 |
+
Qt.FontRole: "FontRole",
|
128 |
+
Qt.TextAlignmentRole: "TextAlignmentRole",
|
129 |
+
Qt.BackgroundRole: "BackgroundRole",
|
130 |
+
Qt.ForegroundRole: "ForegroundRole",
|
131 |
+
Qt.CheckStateRole: "CheckStateRole",
|
132 |
+
Qt.InitialSortOrderRole: "InitialSortOrderRole",
|
133 |
+
Qt.AccessibleTextRole: "AccessibleTextRole",
|
134 |
+
Qt.UserRole: "UserRole"
|
135 |
+
}
|
136 |
+
|
137 |
+
def add(self, beh: Behavior, row=-1):
|
138 |
+
if row < 0:
|
139 |
+
row = self.rowCount()
|
140 |
+
self.beginInsertRows(QModelIndex(), row, row)
|
141 |
+
self._items.insert(row, beh)
|
142 |
+
self._by_name[beh.get_name()] = beh
|
143 |
+
hot_key = beh.get_hot_key()
|
144 |
+
if hot_key:
|
145 |
+
if hot_key not in self._by_hot_key.keys():
|
146 |
+
self._by_hot_key[hot_key] = []
|
147 |
+
assert(isinstance(self._by_hot_key[hot_key], list))
|
148 |
+
self._by_hot_key[hot_key].append(beh)
|
149 |
+
self.endInsertRows()
|
150 |
+
self.dataChanged.emit(
|
151 |
+
self.index(row, 0, QModelIndex()),
|
152 |
+
self.index(row, self.columnCount()-1, QModelIndex()),
|
153 |
+
[Qt.DisplayRole, Qt.EditRole])
|
154 |
+
self.behaviors_changed.emit()
|
155 |
+
|
156 |
+
def load(self, f):
|
157 |
+
line = f.readline()
|
158 |
+
while line:
|
159 |
+
hot_key, name, r, g, b = line.strip().split(' ')
|
160 |
+
if hot_key == '_':
|
161 |
+
hot_key = ''
|
162 |
+
self.add(Behavior(name, hot_key, QColor.fromRgbF(float(r), float(g), float(b))))
|
163 |
+
line = f.readline()
|
164 |
+
|
165 |
+
def save(self, f):
|
166 |
+
for beh in self._items:
|
167 |
+
h = beh.get_hot_key()
|
168 |
+
if h == '':
|
169 |
+
h = '_'
|
170 |
+
color = beh.get_color()
|
171 |
+
f.write(f"{h} {beh.get_name()} {color.redF()} {color.greenF()} {color.blueF()}" + os.linesep)
|
172 |
+
|
173 |
+
def get(self, name):
|
174 |
+
if name not in self._by_name.keys():
|
175 |
+
return None
|
176 |
+
return self._by_name[name]
|
177 |
+
|
178 |
+
def from_hot_key(self, key):
|
179 |
+
"""
|
180 |
+
Return the list of behaviors associated with this hot key, if any
|
181 |
+
"""
|
182 |
+
try:
|
183 |
+
return self._by_hot_key[key]
|
184 |
+
except KeyError:
|
185 |
+
return None
|
186 |
+
|
187 |
+
def len(self):
|
188 |
+
return len(self._items)
|
189 |
+
|
190 |
+
def header(self):
|
191 |
+
return self._header
|
192 |
+
|
193 |
+
def colorColumns(self):
|
194 |
+
return [self._header.index('color')]
|
195 |
+
|
196 |
+
def __iter__(self):
|
197 |
+
return iter(self._items)
|
198 |
+
|
199 |
+
def getDeleteBehavior(self):
|
200 |
+
return self._delete_behavior
|
201 |
+
|
202 |
+
def addIfMissing(self, nameToAdd):
|
203 |
+
if nameToAdd not in self._by_name:
|
204 |
+
self.add(Behavior(nameToAdd, '', QColor('gray')))
|
205 |
+
return True
|
206 |
+
return False
|
207 |
+
|
208 |
+
def isImmutable(self, index):
|
209 |
+
return index.column() in self._immutableColumns
|
210 |
+
|
211 |
+
def setImmutable(self, column):
|
212 |
+
self._immutableColumns.add(column)
|
213 |
+
|
214 |
+
# QAbstractTableModel API methods
|
215 |
+
|
216 |
+
def headerData(self, col, orientation, role):
|
217 |
+
if orientation == Qt.Horizontal and role == Qt.DisplayRole:
|
218 |
+
return self._header[col]
|
219 |
+
return None
|
220 |
+
|
221 |
+
def rowCount(self, parent=None):
|
222 |
+
return len(self._items)
|
223 |
+
|
224 |
+
def columnCount(self, parent=None):
|
225 |
+
return len(self._header)
|
226 |
+
|
227 |
+
def data(self, index, role=Qt.DisplayRole):
|
228 |
+
datum = self._items[index.row()].get(self._header[index.column()])
|
229 |
+
if isinstance(datum, bool):
|
230 |
+
if role in [Qt.CheckStateRole, Qt.EditRole]:
|
231 |
+
return Qt.Checked if datum else Qt.Unchecked
|
232 |
+
return None
|
233 |
+
if role in [Qt.DisplayRole, Qt.EditRole]:
|
234 |
+
return self._items[index.row()].get(self._header[index.column()])
|
235 |
+
return None
|
236 |
+
|
237 |
+
def setData(self, index, value, role=Qt.EditRole):
|
238 |
+
if not role in [Qt.CheckStateRole, Qt.EditRole]:
|
239 |
+
return False
|
240 |
+
if role == Qt.CheckStateRole:
|
241 |
+
value = bool(value)
|
242 |
+
beh = self._items[index.row()]
|
243 |
+
key = self._header[index.column()]
|
244 |
+
name = beh.get_name()
|
245 |
+
hot_key = beh.get_hot_key()
|
246 |
+
beh.set(key, value)
|
247 |
+
if key == 'hot_key' and value != hot_key:
|
248 |
+
# disassociate this behavior from the hot_key
|
249 |
+
# and associate with the new hot_key if not ''
|
250 |
+
if hot_key != '':
|
251 |
+
del(self._by_hot_key[hot_key])
|
252 |
+
if value != '':
|
253 |
+
if value not in self._by_hot_key.keys():
|
254 |
+
self._by_hot_key[value] = []
|
255 |
+
assert(isinstance(self._by_hot_key[value], list))
|
256 |
+
self._by_hot_key[value].append(beh)
|
257 |
+
elif key == 'name' and value != name:
|
258 |
+
if name in self._by_name.keys():
|
259 |
+
del(self._by_name[name])
|
260 |
+
self._by_name[value] = beh
|
261 |
+
self.behaviors_changed.emit()
|
262 |
+
self.dataChanged.emit(index, index, [role])
|
263 |
+
return True
|
264 |
+
|
265 |
+
def insertRows(self, row, count, parent):
|
266 |
+
if count < 1 or row < 0 or row > self.rowCount():
|
267 |
+
return False
|
268 |
+
self.beginInsertRows(QModelIndex(), row, row)
|
269 |
+
for r in range(count):
|
270 |
+
self._items.insert(row, Behavior('', active=True))
|
271 |
+
self.endInsertRows()
|
272 |
+
return True
|
273 |
+
|
274 |
+
def removeRows(self, row, count, parent=QModelIndex()):
|
275 |
+
if count <= 0 or row < 0 or row + count > self.rowCount(parent):
|
276 |
+
return False
|
277 |
+
self.beginRemoveRows(parent, row, row + count - 1)
|
278 |
+
for item in self._items[row:row+count-1]:
|
279 |
+
self._by_name.pop(item.name)
|
280 |
+
self._by_hot_key.pop(item.hot_key)
|
281 |
+
for i in range(count):
|
282 |
+
self._items.pop(row)
|
283 |
+
self.endRemoveRows()
|
284 |
+
|
285 |
+
def flags(self, index):
|
286 |
+
f = super().flags(index)
|
287 |
+
if index.column() not in self._immutableColumns:
|
288 |
+
f |= Qt.ItemIsEditable
|
289 |
+
if index.column() in self._booleanColumns:
|
290 |
+
f = (f & ~(Qt.ItemIsSelectable | Qt.ItemIsEditable)) | Qt.ItemIsUserCheckable
|
291 |
+
return f
|
utils/data_loading.py
ADDED
@@ -0,0 +1,198 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""File for loading data into AnimalEditor"""
|
2 |
+
import io
|
3 |
+
from random import random
|
4 |
+
from os.path import splitext
|
5 |
+
from collections import OrderedDict
|
6 |
+
import numpy as np
|
7 |
+
from tempfile import NamedTemporaryFile
|
8 |
+
|
9 |
+
from .annot import Annotations
|
10 |
+
from .behavior import Behaviors
|
11 |
+
|
12 |
+
def has_extension(fname:str, extension:str|list[str]) -> bool:
|
13 |
+
"""
|
14 |
+
Checks to see if the passed in file name ends with an expected extension.
|
15 |
+
"""
|
16 |
+
_, ext = splitext(fname)
|
17 |
+
if isinstance(extension, str):
|
18 |
+
return ext == extension
|
19 |
+
elif isinstance(extension, list):
|
20 |
+
return ext in extension
|
21 |
+
|
22 |
+
def _clean_annotations(annotations):
|
23 |
+
"""
|
24 |
+
While reading in behaviors from an .annot file, sometimes channels without normally
|
25 |
+
callable keys appear (i.e. keys that are strings which name a behavior), thus this
|
26 |
+
code only accepts keys which are strings.
|
27 |
+
"""
|
28 |
+
if not annotations:
|
29 |
+
raise ValueError('No annotations found.')
|
30 |
+
clean_annot = OrderedDict()
|
31 |
+
for channel in annotations.keys():
|
32 |
+
channel_dict = OrderedDict()
|
33 |
+
for behavior_name in annotations[channel].keys():
|
34 |
+
if isinstance(behavior_name, str):
|
35 |
+
channel_dict.update({behavior_name : annotations[channel][behavior_name]})
|
36 |
+
clean_annot.update({channel: channel_dict})
|
37 |
+
return clean_annot
|
38 |
+
|
39 |
+
def load_annot_sheet_txt(fname, offset = 0):
|
40 |
+
"""
|
41 |
+
Generated a dictionary for retrieving the beginning and end frames of behaviors from
|
42 |
+
an .annot file.
|
43 |
+
|
44 |
+
Note that 0:00:00 is frame 1
|
45 |
+
|
46 |
+
Args:
|
47 |
+
fname - the path to the .annot file to be read (must be Caltech format)\n
|
48 |
+
offset - a value which offsets the start and end frame of each bout in
|
49 |
+
the sheet, as well as the absolute start and end frame of the file.
|
50 |
+
This value is optional, and is set to 0 by default
|
51 |
+
|
52 |
+
Returns:
|
53 |
+
annotations - dictionary of beginning and end frames for behaviors\n
|
54 |
+
start_time - the frame the movie started at (0:00:00 is 1)\n
|
55 |
+
end_time - the frame the movie ended at (0:00:00 is 1)\n
|
56 |
+
sample_rate - the sample rate reported within the file
|
57 |
+
"""
|
58 |
+
# from bento for python
|
59 |
+
behaviors = Behaviors()
|
60 |
+
annot_sheet = Annotations(behaviors)
|
61 |
+
annot_sheet.read(fname)
|
62 |
+
|
63 |
+
sample_rate = annot_sheet.sample_rate()
|
64 |
+
annotations = OrderedDict()
|
65 |
+
for key in annot_sheet.channel_names():
|
66 |
+
annot_behaviors = OrderedDict()
|
67 |
+
bout_names = set()
|
68 |
+
for bout in annot_sheet.channel(key): #._bouts_by_start:
|
69 |
+
bout_names.add(bout.name())
|
70 |
+
for name in bout_names:
|
71 |
+
annot_behaviors.update({name : []})
|
72 |
+
for bout in annot_sheet.channel(key): #._bouts_by_start:
|
73 |
+
start_frame = bout.start().frames + offset
|
74 |
+
end_frame = bout.end().frames + offset
|
75 |
+
bout_frames = [start_frame, end_frame]
|
76 |
+
curr_table = annot_behaviors.get(bout.name())
|
77 |
+
new_table = curr_table.append(bout_frames)
|
78 |
+
annot_behaviors.update({bout.name : new_table})
|
79 |
+
for name in bout_names:
|
80 |
+
curr_table = annot_behaviors.get(name)
|
81 |
+
beh_array = np.array(curr_table)
|
82 |
+
annot_behaviors.update({name : beh_array})
|
83 |
+
|
84 |
+
annotations.update({key : annot_behaviors})
|
85 |
+
annotations = _clean_annotations(annotations)
|
86 |
+
start_time = annot_sheet.start_frame() + offset
|
87 |
+
end_time = annot_sheet.end_frame() + offset
|
88 |
+
return annotations, start_time, end_time, sample_rate
|
89 |
+
|
90 |
+
def load_multiple_annotations(fnames):
|
91 |
+
"""
|
92 |
+
Generates a single dictionary given multiple .annot files.
|
93 |
+
"""
|
94 |
+
if not isinstance(fnames, list):
|
95 |
+
raise TypeError(f'Expected list[str], got {type(fnames)} instead.')
|
96 |
+
if not fnames:
|
97 |
+
raise ValueError('No file names passed in.')
|
98 |
+
if len(fnames) == 1:
|
99 |
+
return load_annot_sheet_txt(fnames[0])
|
100 |
+
head_annot, head_start_frame, head_end_frame, sample_rate = load_annot_sheet_txt(fnames[0])
|
101 |
+
end_frame = head_end_frame
|
102 |
+
for fname in fnames[1:]:
|
103 |
+
curr_annot, _, curr_end_frame, _ = load_annot_sheet_txt(fname, end_frame)
|
104 |
+
end_frame = curr_end_frame
|
105 |
+
for channel in curr_annot.keys():
|
106 |
+
if channel not in head_annot:
|
107 |
+
channel_dict = {}
|
108 |
+
head_annot.update({channel : channel_dict})
|
109 |
+
for behavior in curr_annot[channel].keys():
|
110 |
+
curr_behavior_bout_array = curr_annot[channel][behavior]
|
111 |
+
if channel in head_annot and behavior in head_annot[channel]:
|
112 |
+
new_bout_array = np.vstack((head_annot[channel][behavior],
|
113 |
+
curr_behavior_bout_array))
|
114 |
+
else:
|
115 |
+
new_bout_array = curr_behavior_bout_array
|
116 |
+
head_annot[channel].update({behavior : new_bout_array})
|
117 |
+
return head_annot, head_start_frame, end_frame, sample_rate
|
118 |
+
|
119 |
+
def load_annot_sheet_txt_io(uploaded_file, offset = 0):
|
120 |
+
"""
|
121 |
+
Generated a dictionary for retrieving the beginning and end frames of behaviors from
|
122 |
+
an .annot file.
|
123 |
+
|
124 |
+
Note that 0:00:00 is frame 1
|
125 |
+
|
126 |
+
Args:
|
127 |
+
fname - the path to the .annot file to be read (must be Caltech format)\n
|
128 |
+
offset - a value which offsets the start and end frame of each bout in
|
129 |
+
the sheet, as well as the absolute start and end frame of the file.
|
130 |
+
This value is optional, and is set to 0 by default
|
131 |
+
|
132 |
+
Returns:
|
133 |
+
annotations - dictionary of beginning and end frames for behaviors\n
|
134 |
+
start_time - the frame the movie started at (0:00:00 is 1)\n
|
135 |
+
end_time - the frame the movie ended at (0:00:00 is 1)\n
|
136 |
+
sample_rate - the sample rate reported within the file
|
137 |
+
"""
|
138 |
+
# from bento for python
|
139 |
+
behaviors = Behaviors()
|
140 |
+
annot_sheet = Annotations(behaviors)
|
141 |
+
|
142 |
+
annot_sheet.read_io(uploaded_file)
|
143 |
+
|
144 |
+
sample_rate = annot_sheet.sample_rate()
|
145 |
+
annotations = OrderedDict()
|
146 |
+
for key in annot_sheet.channel_names():
|
147 |
+
annot_behaviors = OrderedDict()
|
148 |
+
bout_names = set()
|
149 |
+
for bout in annot_sheet.channel(key): #._bouts_by_start:
|
150 |
+
bout_names.add(bout.name())
|
151 |
+
for name in bout_names:
|
152 |
+
annot_behaviors.update({name : []})
|
153 |
+
for bout in annot_sheet.channel(key): #._bouts_by_start:
|
154 |
+
start_frame = bout.start().frames + offset
|
155 |
+
end_frame = bout.end().frames + offset
|
156 |
+
bout_frames = [start_frame, end_frame]
|
157 |
+
curr_table = annot_behaviors.get(bout.name())
|
158 |
+
new_table = curr_table.append(bout_frames)
|
159 |
+
annot_behaviors.update({bout.name : new_table})
|
160 |
+
for name in bout_names:
|
161 |
+
curr_table = annot_behaviors.get(name)
|
162 |
+
beh_array = np.array(curr_table)
|
163 |
+
annot_behaviors.update({name : beh_array})
|
164 |
+
|
165 |
+
annotations.update({key : annot_behaviors})
|
166 |
+
annotations = _clean_annotations(annotations)
|
167 |
+
start_time = annot_sheet.start_frame() + offset
|
168 |
+
end_time = annot_sheet.end_frame() + offset
|
169 |
+
return annotations, start_time, end_time, sample_rate
|
170 |
+
|
171 |
+
def load_multiple_annotations_io(uploaded_files):
|
172 |
+
"""
|
173 |
+
Generates a single dictionary given multiple .annot files.
|
174 |
+
"""
|
175 |
+
if not isinstance(uploaded_files, list):
|
176 |
+
raise TypeError(f'Expected list, got {type(uploaded_files)} instead.')
|
177 |
+
if not uploaded_files:
|
178 |
+
raise ValueError('No file names passed in.')
|
179 |
+
if len(uploaded_files) == 1:
|
180 |
+
return load_annot_sheet_txt_io(uploaded_files[0])
|
181 |
+
head_annot, head_start_frame, head_end_frame, sample_rate = load_annot_sheet_txt_io(uploaded_files[0])
|
182 |
+
end_frame = head_end_frame
|
183 |
+
for uploaded_file in uploaded_files[1:]:
|
184 |
+
curr_annot, _, curr_end_frame, _ = load_annot_sheet_txt_io(uploaded_file, end_frame)
|
185 |
+
end_frame = curr_end_frame
|
186 |
+
for channel in curr_annot.keys():
|
187 |
+
if channel not in head_annot:
|
188 |
+
channel_dict = {}
|
189 |
+
head_annot.update({channel : channel_dict})
|
190 |
+
for behavior in curr_annot[channel].keys():
|
191 |
+
curr_behavior_bout_array = curr_annot[channel][behavior]
|
192 |
+
if channel in head_annot and behavior in head_annot[channel]:
|
193 |
+
new_bout_array = np.vstack((head_annot[channel][behavior],
|
194 |
+
curr_behavior_bout_array))
|
195 |
+
else:
|
196 |
+
new_bout_array = curr_behavior_bout_array
|
197 |
+
head_annot[channel].update({behavior : new_bout_array})
|
198 |
+
return head_annot, head_start_frame, end_frame, sample_rate
|
utils/data_processing.py
ADDED
@@ -0,0 +1,384 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""Provides the standard data processing functions performed on CNMFe and annotation data"""
|
2 |
+
import numpy as np
|
3 |
+
from scipy.signal import correlate
|
4 |
+
from scipy.stats import zscore
|
5 |
+
|
6 |
+
def smooth(data: np.ndarray, window_size=5):
|
7 |
+
"""
|
8 |
+
Returns a smoothed version of response data using a moving average filter.
|
9 |
+
|
10 |
+
Parameters:
|
11 |
+
----------
|
12 |
+
data : np.ndarray
|
13 |
+
A numpy 1-D array containing data to be smoothed.
|
14 |
+
window_size : int
|
15 |
+
Number of data points for calculating the smoothed value. If an even number is
|
16 |
+
passed in, window_size is autmoatically reduced by 1.
|
17 |
+
|
18 |
+
Returns:
|
19 |
+
--------
|
20 |
+
smooth_data : np.ndarray
|
21 |
+
Smoothed data, returned as a 1-D array of the same size as ``data``.
|
22 |
+
|
23 |
+
Notes:
|
24 |
+
------
|
25 |
+
Implements MATLAB's smooth function.
|
26 |
+
"""
|
27 |
+
if window_size == 0:
|
28 |
+
raise ValueError('window_size can not be 0.')
|
29 |
+
if window_size == 1:
|
30 |
+
return data
|
31 |
+
if window_size > data.size:
|
32 |
+
window_size = data.size
|
33 |
+
if window_size%2 == 0:
|
34 |
+
window_size = window_size - 1
|
35 |
+
outside_valid_window_size = int((window_size-1)/2)
|
36 |
+
start = np.array([np.sum(data[0:(2*k+1)]/(2*k+1)) for k in range(outside_valid_window_size)])
|
37 |
+
end = np.array([np.sum(data[-(2*k+1):]/(2*k+1)) for k in range(outside_valid_window_size)])[::-1]
|
38 |
+
smoothed_data = np.convolve(data,np.ones(window_size,dtype=int),'valid')/window_size
|
39 |
+
return np.hstack((start,smoothed_data,end))
|
40 |
+
|
41 |
+
def corr(x: np.ndarray, y: np.ndarray):
|
42 |
+
"""
|
43 |
+
Returns a matrix of the pairwise correlation coefficient between each pair of columns
|
44 |
+
in the input matrices x and y.
|
45 |
+
|
46 |
+
Parameters:
|
47 |
+
-----------
|
48 |
+
x : np.ndarray
|
49 |
+
Input matrix, specified as an n x k_1 matrix. Its rows correspond to
|
50 |
+
observations, and the columns correspond to variables.
|
51 |
+
y : np.ndarray
|
52 |
+
Input matrix, specified as an n x k_2 matrix. Its rows correspond to
|
53 |
+
observations, and the columns correspond to variables.
|
54 |
+
|
55 |
+
Returns:
|
56 |
+
--------
|
57 |
+
rho - Pairwise linear correlation coefficient, returned as a matrix.
|
58 |
+
|
59 |
+
Notes:
|
60 |
+
------
|
61 |
+
Implements MATLAB's corr function.
|
62 |
+
"""
|
63 |
+
return np.corrcoef(x,y)[0][1]
|
64 |
+
|
65 |
+
def autocorr(x:np.ndarray,
|
66 |
+
max_lags=10):
|
67 |
+
"""
|
68 |
+
Returns the correlations and associated lags of the univariate time series x.
|
69 |
+
|
70 |
+
Parameters:
|
71 |
+
-----------
|
72 |
+
x : np.ndarray
|
73 |
+
Observed univariate time series.
|
74 |
+
max_lags : int
|
75 |
+
Number of lags, specified as a positive integer.
|
76 |
+
|
77 |
+
Returns:
|
78 |
+
acf : np.ndarray
|
79 |
+
Correlations, returned as a numeric vector of length ``max_lags`` + 1.
|
80 |
+
lags : np.ndarray
|
81 |
+
Autocorrelation lags.
|
82 |
+
|
83 |
+
Notes:
|
84 |
+
------
|
85 |
+
Modified version of matplotlib's acorr function.
|
86 |
+
"""
|
87 |
+
Nx = len(x)
|
88 |
+
|
89 |
+
correls = correlate(x, x, mode="full")
|
90 |
+
correls = correls / np.dot(x, x)
|
91 |
+
|
92 |
+
if max_lags is None:
|
93 |
+
max_lags = Nx - 1
|
94 |
+
|
95 |
+
if max_lags >= Nx or max_lags < 1:
|
96 |
+
raise ValueError('maxlags must be None or strictly '
|
97 |
+
'positive < %d' % Nx)
|
98 |
+
|
99 |
+
lags = np.arange(-max_lags, max_lags + 1)
|
100 |
+
acf = correls[Nx - 1 - max_lags:Nx + max_lags]
|
101 |
+
|
102 |
+
return acf, lags
|
103 |
+
|
104 |
+
def convert_to_rast(behavior_ts, time_max):
|
105 |
+
"""
|
106 |
+
Converts a list of behavior time stamps to a one-hot vector where 0 indicates no
|
107 |
+
presence of the given behavior, and 1 indicates presence of it.
|
108 |
+
|
109 |
+
Args:
|
110 |
+
behavior_ts - a list of time stamps (start and end) for a particular behavior\n
|
111 |
+
time_max - the length in frames of the vector
|
112 |
+
|
113 |
+
Returns:
|
114 |
+
behavior_rast - a one-hot vector
|
115 |
+
"""
|
116 |
+
behavior_rast = np.zeros(time_max)
|
117 |
+
for time_stamps in behavior_ts:
|
118 |
+
start = int(round(time_stamps[0]))
|
119 |
+
end = int(round(time_stamps[1] + 1))
|
120 |
+
if start > time_max:
|
121 |
+
break
|
122 |
+
if end > time_max:
|
123 |
+
end = time_max
|
124 |
+
np.put(behavior_rast,range(start,end),np.ones(end-start))
|
125 |
+
return behavior_rast
|
126 |
+
|
127 |
+
def convert_to_raster(bouts: list,
|
128 |
+
neural_activity_sr: float,
|
129 |
+
observation_sr: float,
|
130 |
+
max_frame: int):
|
131 |
+
"""
|
132 |
+
Converts bouts into a behavior raster, a one hot encoding of a behavior describing
|
133 |
+
when it is active.
|
134 |
+
|
135 |
+
It is often the case that the start and stop timestamps found in ``bouts`` are
|
136 |
+
collected at a different sample rate than ``neural_activity``, which are often what
|
137 |
+
behavior rasters align to. In order to align the two, a ratio between the sample
|
138 |
+
rates of ``neural_activity`` and the bouts of behavior, which are observations,
|
139 |
+
is calculated and then multiplied to the timestamps.
|
140 |
+
|
141 |
+
Parameters:
|
142 |
+
-----------
|
143 |
+
bouts : np.ndarray
|
144 |
+
An array where each element is a pair of integers where the first integer denotes
|
145 |
+
the beginning of a bout of behavior, and the second integer denotes the end of
|
146 |
+
the bout.
|
147 |
+
neural_activity_sr : float
|
148 |
+
Sample rate of ``neural_activity``.
|
149 |
+
observation_sr : float
|
150 |
+
Sample rate for the ``bouts`` used.
|
151 |
+
max_frame : int
|
152 |
+
The length of the behavior raster, often set to the number of frames of
|
153 |
+
``neural_activity``.
|
154 |
+
|
155 |
+
Returns:
|
156 |
+
--------
|
157 |
+
behavior_raster : np.ndarray
|
158 |
+
A raster (a one hot encoding) of a behavior, describing when it is active.
|
159 |
+
"""
|
160 |
+
sr_ratio = neural_activity_sr/observation_sr
|
161 |
+
behavior_ts_adjusted = bouts*sr_ratio
|
162 |
+
behavior_raster = np.zeros(max_frame)
|
163 |
+
for time_stamps in behavior_ts_adjusted:
|
164 |
+
start = int(round(time_stamps[0]))
|
165 |
+
end = int(round(time_stamps[1] + 1))
|
166 |
+
if start > max_frame:
|
167 |
+
break
|
168 |
+
if end > max_frame:
|
169 |
+
end = max_frame
|
170 |
+
np.put(behavior_raster,range(start,end),np.ones(end-start))
|
171 |
+
return behavior_raster
|
172 |
+
|
173 |
+
def convert_to_bouts(behavior_raster: np.ndarray):
|
174 |
+
"""
|
175 |
+
Converts a behavior raster into behavior bouts, an array where each element is a
|
176 |
+
pair of timestamps (int) where the first timestamp denotes the beginning of a bout of
|
177 |
+
behavior, and the second timestamp denotes the end of the bout.
|
178 |
+
|
179 |
+
Parameters:
|
180 |
+
-----------
|
181 |
+
behavior_raster : np.ndarray
|
182 |
+
A raster (a one hot encoding) of a behavior, describing when it is active.
|
183 |
+
|
184 |
+
Returns:
|
185 |
+
--------
|
186 |
+
bouts : np.ndarray
|
187 |
+
An array where each element is a pair of timestamps (int) where the first
|
188 |
+
timestamp denotes the beginning of a bout of behavior, and the second timestamp
|
189 |
+
denotes the end of the bout.
|
190 |
+
"""
|
191 |
+
dt = behavior_raster[1:] - behavior_raster[:-1]
|
192 |
+
start = np.where(dt==1)[0] + 1
|
193 |
+
stop = np.where(dt==-1)[0]
|
194 |
+
if behavior_raster[0]:
|
195 |
+
start = np.concatenate((np.array([0]),start))
|
196 |
+
if behavior_raster[-1]:
|
197 |
+
stop = np.concatenate((stop,[behavior_raster.size]))
|
198 |
+
bouts = np.hstack((np.reshape(start,(len(start),1)),
|
199 |
+
np.reshape(stop,(len(stop),1))))
|
200 |
+
return bouts
|
201 |
+
|
202 |
+
def merge_rasters_down(behavior_raster_array: np.ndarray)-> np.ndarray:
|
203 |
+
"""
|
204 |
+
For a behavior raster, merges down all rasters to one array in such a way that no
|
205 |
+
two behaviors are occuring at the same time.
|
206 |
+
|
207 |
+
It determines which behavior should remain 'on top' by determening which behavior
|
208 |
+
has the least amount of active frames.
|
209 |
+
|
210 |
+
This method should only be used on behavior rasters where all behaviors come from a
|
211 |
+
single channel.
|
212 |
+
|
213 |
+
Parameters:
|
214 |
+
-----------
|
215 |
+
behavior_raster_array : np.ndarray
|
216 |
+
An array where each row is a behavior raster, a one hot encoding of behaviors,
|
217 |
+
describing when that behavior is active. Each row of this array must use a
|
218 |
+
different value to indicate that a behavior is active (for example, if one
|
219 |
+
row uses 1s, another row must not use 1 as well).
|
220 |
+
|
221 |
+
Returns:
|
222 |
+
--------
|
223 |
+
single_track : np.ndarray
|
224 |
+
An array which is the length of a behavior raster in ``behavior_raster_array``,
|
225 |
+
where each entry is either 0 indicating that no behavior is active, or a value
|
226 |
+
indicating that a specific behavior is active.
|
227 |
+
"""
|
228 |
+
# single track
|
229 |
+
single_track = np.zeros((1,behavior_raster_array.shape[1]))
|
230 |
+
|
231 |
+
# determine order to insert row values
|
232 |
+
num_active_frames = [np.sum(np.where(row > 0, 1, 0)) for row in behavior_raster_array]
|
233 |
+
|
234 |
+
for i in range(behavior_raster_array.shape[0]):
|
235 |
+
max_i = np.argmax(num_active_frames)
|
236 |
+
num_active_frames[max_i] = -1
|
237 |
+
|
238 |
+
unique_values = np.unique(behavior_raster_array[max_i])
|
239 |
+
if len(unique_values) > 1: value = unique_values[1]
|
240 |
+
else: value = 0
|
241 |
+
active_inds = np.where(behavior_raster_array[max_i] == value)[0]
|
242 |
+
|
243 |
+
single_track[:,active_inds] = value
|
244 |
+
return single_track
|
245 |
+
|
246 |
+
def separate_tracks(single_track: np.ndarray,
|
247 |
+
behavior_values: list):
|
248 |
+
"""
|
249 |
+
For a single track, separates each unique value (except for 0) into its own raster
|
250 |
+
within a 2-D array.
|
251 |
+
|
252 |
+
Parameters:
|
253 |
+
-----------
|
254 |
+
single_track : np.ndarray
|
255 |
+
An array which is the length of a behavior raster in ``behavior_raster_array``,
|
256 |
+
where each entry is either 0 indicating that no behavior is active, or a value
|
257 |
+
indicating that a specific behavior is active.
|
258 |
+
behavior_values : list
|
259 |
+
A list of values corresponding to the specific behaviors within ``single_track``.
|
260 |
+
|
261 |
+
Returns:
|
262 |
+
--------
|
263 |
+
behavior_raster_array : np.ndarray
|
264 |
+
An array where each row is a behavior raster, a one hot encoding of behaviors,
|
265 |
+
describing when that behavior is active.
|
266 |
+
"""
|
267 |
+
if len(behavior_values) < np.unique(single_track).size - 1:
|
268 |
+
raise KeyError("There are not sufficient values within ``behavior_values`` to "
|
269 |
+
"accomodate those present in ``single_track``.")
|
270 |
+
tracks = []
|
271 |
+
for value in behavior_values:
|
272 |
+
tracks.append(np.where(single_track == value, value, 0))
|
273 |
+
return np.vstack(tracks)
|
274 |
+
|
275 |
+
def config_neural_activity(config: dict, neural_activity: np.ndarray):
|
276 |
+
"""
|
277 |
+
Configures `neural_activity` according to parameters set in config.
|
278 |
+
|
279 |
+
Parameters:
|
280 |
+
-----------
|
281 |
+
config : dict
|
282 |
+
A dictionary which specifies the following parameters: 'smooth_window',
|
283 |
+
'baseline_frame', and 'zscore_method'. 'zscore_method' is one of "All Data",
|
284 |
+
"Baseline", or "No Z-Score".
|
285 |
+
neural_activity : np.ndarray
|
286 |
+
Neural activity being used.
|
287 |
+
|
288 |
+
Returns:
|
289 |
+
--------
|
290 |
+
mod_neural_activity : np.ndarray
|
291 |
+
Modified `neural_activity`, accodring to `config`.
|
292 |
+
"""
|
293 |
+
smooth_window = config['smooth_window']
|
294 |
+
zscore_method = config['zscore_method']
|
295 |
+
baseline_frame = config['baseline_frame']
|
296 |
+
|
297 |
+
# smooth
|
298 |
+
if len(neural_activity.shape) > 1:
|
299 |
+
neural_data_smooth = np.zeros(neural_activity.shape)
|
300 |
+
for i in range(neural_activity.shape[0]):
|
301 |
+
neural_data_smooth[i] = smooth(neural_activity[i], int(smooth_window))
|
302 |
+
mod_neural_activity = neural_data_smooth
|
303 |
+
else:
|
304 |
+
mod_neural_activity = smooth(neural_activity, int(smooth_window))
|
305 |
+
|
306 |
+
# z-score
|
307 |
+
if zscore_method == 'Baseline' and (not baseline_frame is None or baseline_frame == 0):
|
308 |
+
if len(neural_activity.shape)> 1:
|
309 |
+
mean = mod_neural_activity[:,:baseline_frame].mean(axis=1,keepdims=True)
|
310 |
+
std = mod_neural_activity[:,:baseline_frame].std(axis=1,keepdims=True)
|
311 |
+
else:
|
312 |
+
mean = mod_neural_activity[:baseline_frame].mean()
|
313 |
+
std = mod_neural_activity[:baseline_frame].std()
|
314 |
+
mod_neural_activity = (mod_neural_activity - mean) / std
|
315 |
+
elif zscore_method == 'No Z-Score':
|
316 |
+
mod_neural_activity = mod_neural_activity
|
317 |
+
else:
|
318 |
+
if len(neural_activity.shape) > 1:
|
319 |
+
mod_neural_activity = zscore(mod_neural_activity,axis=1)
|
320 |
+
else:
|
321 |
+
mod_neural_activity = zscore(mod_neural_activity)
|
322 |
+
return mod_neural_activity
|
323 |
+
|
324 |
+
def compress_annotations(annot: dict, downsample_rate: int, max_frame: int)-> dict:
|
325 |
+
"""
|
326 |
+
Takes in an annotation dictionary and creates a single raster per channel, where the
|
327 |
+
raster contains the behaviors from their respective channel.
|
328 |
+
|
329 |
+
annot : dict
|
330 |
+
Dictionary of beginning and end frames for behaviors.
|
331 |
+
downsample_rate : int
|
332 |
+
The rate at which samples should be taken. Divides bout timing (in frames) by
|
333 |
+
value.
|
334 |
+
max_frame : int
|
335 |
+
The last frame for annotations from `annot`.
|
336 |
+
"""
|
337 |
+
annot_single_track = {}
|
338 |
+
channel_behavior_map = {}
|
339 |
+
for channel in annot:
|
340 |
+
channel_rasters = []
|
341 |
+
behavior_map = {}
|
342 |
+
behavior_map.update({0: 'None'})
|
343 |
+
for i, behavior in enumerate(annot[channel]):
|
344 |
+
bouts = annot[channel][behavior]
|
345 |
+
raster = convert_to_raster(bouts, 1, downsample_rate, max_frame)
|
346 |
+
channel_rasters.append(raster*(i+1))
|
347 |
+
behavior_map.update({(i+1) : behavior})
|
348 |
+
channel_raster = merge_rasters_down(np.array(channel_rasters))[0]
|
349 |
+
annot_single_track.update({channel : channel_raster})
|
350 |
+
channel_behavior_map.update({channel : behavior_map})
|
351 |
+
return annot_single_track, channel_behavior_map
|
352 |
+
|
353 |
+
def compress_compressed_annotations(annot_single_track: dict,
|
354 |
+
channel_behavior_map: dict,
|
355 |
+
max_frame: int):
|
356 |
+
"""
|
357 |
+
Further compresses the results from `compress_annotations` to get a single array
|
358 |
+
where each entry is a list of the behaviors present at that frame across all channels.
|
359 |
+
"""
|
360 |
+
labels = []
|
361 |
+
for frame in range(max_frame):
|
362 |
+
labels_at_frame = []
|
363 |
+
for channel in annot_single_track:
|
364 |
+
channel_raster = annot_single_track[channel]
|
365 |
+
behavior_map = channel_behavior_map[channel]
|
366 |
+
behavior_value = int(channel_raster[frame])
|
367 |
+
behavior_label = behavior_map.get(behavior_value)
|
368 |
+
labels_at_frame.append(behavior_label)
|
369 |
+
labels.append('||'.join(labels_at_frame))
|
370 |
+
return labels
|
371 |
+
|
372 |
+
def generate_label_array(annot: dict,
|
373 |
+
downsample_rate: int,
|
374 |
+
max_frame: int)-> list[str]:
|
375 |
+
"""
|
376 |
+
Generates an array of lists of labels, where each entry is a video frame, and the
|
377 |
+
labels come from each channel in `annot`.
|
378 |
+
"""
|
379 |
+
annot_single_track,\
|
380 |
+
channel_behavior_map = compress_annotations(annot, downsample_rate, max_frame)
|
381 |
+
labels = compress_compressed_annotations(annot_single_track,
|
382 |
+
channel_behavior_map,
|
383 |
+
max_frame)
|
384 |
+
return labels
|
utils/mp4Io.py
ADDED
@@ -0,0 +1,60 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import cv2
|
2 |
+
import numpy as np
|
3 |
+
import os
|
4 |
+
from qtpy.QtGui import QImage, QPixmap
|
5 |
+
|
6 |
+
class mp4Io_reader():
|
7 |
+
def __init__(self, filename, info=[]):
|
8 |
+
self.filename = filename
|
9 |
+
self.file = cv2.VideoCapture(filename)
|
10 |
+
|
11 |
+
if self.file.isOpened()==False:
|
12 |
+
print("Error in opening video file.")
|
13 |
+
|
14 |
+
self.header={}
|
15 |
+
if info==[]:
|
16 |
+
self.readHeader()
|
17 |
+
|
18 |
+
def readHeader(self):
|
19 |
+
|
20 |
+
self.header = {
|
21 |
+
'width': int(self.file.get(cv2.CAP_PROP_FRAME_WIDTH)),
|
22 |
+
'height': int(self.file.get(cv2.CAP_PROP_FRAME_HEIGHT)),
|
23 |
+
'fps': self.file.get(cv2.CAP_PROP_FPS),
|
24 |
+
'numFrames': int(self.file.get(cv2.CAP_PROP_FRAME_COUNT))
|
25 |
+
}
|
26 |
+
|
27 |
+
def seek(self, index):
|
28 |
+
|
29 |
+
self.file.set(cv2.CAP_PROP_POS_FRAMES, index)
|
30 |
+
|
31 |
+
def getTs(self,n=None):
|
32 |
+
if n==None:
|
33 |
+
n = self.header['numFrames']
|
34 |
+
|
35 |
+
ts = np.zeros(n+1)
|
36 |
+
for i in np.arange(1,n+1):
|
37 |
+
self.seek(i)
|
38 |
+
self.file.read()
|
39 |
+
ts[i] = self.file.get(cv2.CAP_PROP_POS_MSEC)/1000.
|
40 |
+
|
41 |
+
self.ts = ts[1:]
|
42 |
+
return self.ts
|
43 |
+
|
44 |
+
def getFrame(self, index, decode=True):
|
45 |
+
|
46 |
+
self.seek(index)
|
47 |
+
ret, frame = self.file.read()
|
48 |
+
|
49 |
+
ts = self.file.get(cv2.CAP_PROP_POS_MSEC)/1000.
|
50 |
+
return frame, ts
|
51 |
+
|
52 |
+
def getFrameAsQPixmap(self, index, decode=True):
|
53 |
+
image, _ = self.getFrame(index, decode)
|
54 |
+
h, w, ch = image.shape
|
55 |
+
bytes_per_line = ch * w
|
56 |
+
convert_to_Qt_format = QImage(image.data, w, h, bytes_per_line, QImage.Format_BGR888)
|
57 |
+
return QPixmap.fromImage(convert_to_Qt_format)
|
58 |
+
|
59 |
+
def close(self):
|
60 |
+
self.file.release()
|
utils/seqIo.py
ADDED
@@ -0,0 +1,1189 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os, sys
|
2 |
+
import numpy as np
|
3 |
+
import PIL
|
4 |
+
from PIL import Image
|
5 |
+
import io
|
6 |
+
from datetime import datetime, timedelta,date
|
7 |
+
import time
|
8 |
+
from matplotlib.dates import date2num, num2date
|
9 |
+
# import colour_demosaicing
|
10 |
+
import skvideo.io
|
11 |
+
import re
|
12 |
+
import pickle
|
13 |
+
import cv2
|
14 |
+
import progressbar as pb
|
15 |
+
|
16 |
+
# Create interface sr for reading seq files.
|
17 |
+
# sr = seqIo_reader( fName )
|
18 |
+
# Create interface sw for writing seq files.
|
19 |
+
# sw = seqIo_Writer( fName, header )
|
20 |
+
# Crop sub-sequence from seq file.
|
21 |
+
# seqIo_crop( fName, 'crop', tName, frames )
|
22 |
+
# Extract images from seq file to target directory or array.
|
23 |
+
# Is = seqIo_toImgs( fName, tDir=[],skip=1,f0=0,f1=np.inf,ext='' )
|
24 |
+
# Create seq file from an array or directory of images or from an AVI file. DONE
|
25 |
+
# seqIo_frImgs( fName, fName,header,aviName=[],Is=[],sDir=[],name='I',ndig=5,f0=0,f1=1e6 )
|
26 |
+
# Convert seq file by applying imgFun(I) to each frame I.
|
27 |
+
# seqIo( fName, 'convert', tName, imgFun, varargin )
|
28 |
+
# Replace header of seq file with provided info.
|
29 |
+
# seqIo( fName, 'newHeader', info )
|
30 |
+
# Create interface sr for reading dual seq files.
|
31 |
+
# sr = seqIo( fNames, 'readerDual', [cache] )
|
32 |
+
|
33 |
+
FRAME_FORMAT_RAW_GRAY = 100 #RAW
|
34 |
+
FRAME_FORMAT_RAW_COLOR = 200 #RAW
|
35 |
+
FRAME_FORMAT_JPEG_GRAY = 102 #JPG
|
36 |
+
FRAME_FORMAT_JPEG_COLOR = 201 #JPG
|
37 |
+
FRAME_FORMAT_MONOB = 101 #BRGB8
|
38 |
+
FRAME_FORMAT_MONOB_JPEG = 103 #JBRGB
|
39 |
+
FRAME_FORMAT_PNG_GRAY = 0x001 #PNG
|
40 |
+
FRAME_FORMAT_PNG_COLOR = 0x002 #PNG
|
41 |
+
|
42 |
+
#matlab equivalent fread
|
43 |
+
def fread(fid, nelements, dtype):
|
44 |
+
|
45 |
+
"""Equivalent to Matlab fread function"""
|
46 |
+
|
47 |
+
if dtype is np.str_:
|
48 |
+
dt = np.uint8 # WARNING: assuming 8-bit ASCII for np.str!
|
49 |
+
else:
|
50 |
+
dt = dtype
|
51 |
+
|
52 |
+
data_array = np.fromfile(fid, dt, nelements)
|
53 |
+
if data_array.size == 1:
|
54 |
+
data_array = data_array[0]
|
55 |
+
return data_array
|
56 |
+
|
57 |
+
def fwrite(fid,a,dtype=np.str_):
|
58 |
+
# assuming 8but ASCII for string
|
59 |
+
if dtype is np.str_:
|
60 |
+
dt = np.uint8 # WARNING: assuming 8-bit ASCII for np.str!
|
61 |
+
else:
|
62 |
+
dt = dtype
|
63 |
+
if isinstance(a,np.ndarray):
|
64 |
+
data_array = a.astype(dt)
|
65 |
+
else:
|
66 |
+
data_array = np.array(a).astype(dt)
|
67 |
+
data_array.tofile(fid)
|
68 |
+
|
69 |
+
def tsSync(video_path, srTop, srFront):
|
70 |
+
#tsSync
|
71 |
+
name = video_path.split('/')[-1]
|
72 |
+
srDepth=seqIo_reader(video_path + '/' + name + '_DepGr_Raw.seq')
|
73 |
+
|
74 |
+
#read timestamp of individual frames
|
75 |
+
tsTop = srTop.getTs()
|
76 |
+
tsFront = srFront.getTs()
|
77 |
+
tsDepth = srDepth.getTs()
|
78 |
+
|
79 |
+
#check the version of the videos
|
80 |
+
videoDateStr = re.search('[0-9]+_[0-9]+-[0-9]+-[0-9]+',name).group(0)
|
81 |
+
videoDateNum =date2num(datetime.strptime(videoDateStr ,'%Y%m%d_%H-%M-%S'))
|
82 |
+
videoRefNum = date2num(datetime.strptime('20150401_00-00-00','%Y%m%d_%H-%M-%S'))
|
83 |
+
seqV = 1 if videoDateNum < videoRefNum else 2
|
84 |
+
|
85 |
+
##correlate timestamps from one view to another
|
86 |
+
mapTs = {}
|
87 |
+
if seqV ==1:
|
88 |
+
for f in range(len(tsDepth)):
|
89 |
+
if tsDepth[f] - np.floor(tsDepth[f])>=.5: # Santiago's bug acquisistion software
|
90 |
+
tsDepth[f]-= 1
|
91 |
+
tsDepth-= 0.03 # substract the systematic timeshift
|
92 |
+
|
93 |
+
#load front and top view and convert from UTC to PST
|
94 |
+
hourShift = np.round((tsDepth[0]-tsTop[0])/3600.)*3600
|
95 |
+
timeShift = tsDepth[0] - tsTop[0] - .066
|
96 |
+
# print('tsDepth[0] - tsTop[0] = timeShift: %s sec'% str(timeShift - hourShift))
|
97 |
+
tsTop += timeShift
|
98 |
+
tsFront += timeShift
|
99 |
+
|
100 |
+
srTop.ts = tsTop
|
101 |
+
srFront.ts = tsFront
|
102 |
+
srDepth.ts = tsDepth
|
103 |
+
|
104 |
+
# Convert timestamps from left to right
|
105 |
+
mapTs['T2F'] = transformTs(tsTop, tsFront)
|
106 |
+
mapTs['F2T'] = transformTs(tsFront, tsTop)
|
107 |
+
mapTs['T2D'] = transformTs(tsTop, tsDepth)
|
108 |
+
mapTs['D2T'] = transformTs(tsDepth, tsTop)
|
109 |
+
mapTs['F2D'] = transformTs(tsFront, tsDepth)
|
110 |
+
mapTs['D2F'] = transformTs(tsDepth, tsFront)
|
111 |
+
else:
|
112 |
+
T = len(tsTop)
|
113 |
+
F = len(tsFront)
|
114 |
+
D = len(tsDepth)
|
115 |
+
mapTs['T2F'] = resizeTs(T,F)
|
116 |
+
mapTs['F2T'] = resizeTs(F,T)
|
117 |
+
mapTs['T2D'] = resizeTs(T,D)
|
118 |
+
mapTs['D2T'] = resizeTs(D,T)
|
119 |
+
mapTs['F2D'] = resizeTs(F,D)
|
120 |
+
mapTs['D2F'] = resizeTs(D,F)
|
121 |
+
|
122 |
+
#display the time in string format
|
123 |
+
# fTp=5
|
124 |
+
# fFr = mapTs['T2F'][fTp]
|
125 |
+
# fDp = mapTs['T2D'][fTp]
|
126 |
+
# tF = ts2str(tsFront[fFr])
|
127 |
+
# tT = ts2str(tsTop[fTp])
|
128 |
+
# tD = ts2str(tsDepth[fDp])
|
129 |
+
#
|
130 |
+
# print('mapping to tTop:')
|
131 |
+
# print('tTop = ' + str(fTp) + ' - ' + tT)
|
132 |
+
# print('tFront = ' + str(fFr) + ' - ' + tF)
|
133 |
+
# print('tDepth = ' + str(fDp) + ' - ' + tD)
|
134 |
+
|
135 |
+
# fDp = np.round(len(tsDepth)/1.5).astype(int)
|
136 |
+
# fTp = mapTs['D2T'][fDp]
|
137 |
+
# tT = ts2str(tsTop[fTp])
|
138 |
+
# tD = ts2str(tsDepth[fDp])
|
139 |
+
# print('mapping to tTop:')
|
140 |
+
# print('tTop = ' + str(fTp) + '. ' + tT)
|
141 |
+
# print('tDepth = ' + str(fDp) + '. ' + tD)
|
142 |
+
|
143 |
+
return mapTs,srTop,srFront
|
144 |
+
|
145 |
+
def transformTs(ts1, ts2):
|
146 |
+
# map ts1 to ts2 where ts2 is the reference
|
147 |
+
rankTs = np.zeros((len(ts1),4))
|
148 |
+
for f in range(len(ts1)):
|
149 |
+
tsDiff = ts2-ts1[f]
|
150 |
+
tsRank = np.sort(abs(tsDiff))
|
151 |
+
ind = np.argsort(abs(tsDiff))
|
152 |
+
rankTs[f,:] = [ind[0]+1,ind[1]+1,tsRank[0],tsRank[1]]
|
153 |
+
|
154 |
+
mapTs = np.round(smooth(rankTs[:,0],7)).astype(int)
|
155 |
+
mapTs = np.round(smooth(mapTs,7)).astype(int)
|
156 |
+
return mapTs
|
157 |
+
|
158 |
+
def smooth(a, WSZ):
|
159 |
+
out0 = np.convolve(a, np.ones(WSZ, dtype=int), 'valid') / WSZ
|
160 |
+
r = np.arange(1, WSZ - 1, 2)
|
161 |
+
start = np.cumsum(a[:WSZ - 1])[::2] / r
|
162 |
+
stop = (np.cumsum(a[:-WSZ:-1])[::2] / r)[::-1]
|
163 |
+
return np.concatenate((start, out0, stop))
|
164 |
+
|
165 |
+
def resizeTs(t1, t2):
|
166 |
+
if t1>t2:
|
167 |
+
mapTs = np.hstack((np.array(range(t2))+1,np.ones((t1-t2),int)*t2))
|
168 |
+
else:
|
169 |
+
mapTs = np.array(range(t1))+1
|
170 |
+
|
171 |
+
return mapTs
|
172 |
+
|
173 |
+
def ts2str(ts):
|
174 |
+
t = ts / 86400. + date.toordinal(date(1971, 1, 2))
|
175 |
+
# datetime.fromtimestamp(t)
|
176 |
+
str_time = (datetime.fromordinal(int(t)) + timedelta(days=t % 1) - timedelta(days=366)).strftime(
|
177 |
+
"%Y-%m-%d %H:%M:%S") + '.%03d' % np.round((ts - np.floor(ts)) * 1000)
|
178 |
+
return str_time
|
179 |
+
|
180 |
+
def parse_ann(f_ann):
|
181 |
+
header = 'Caltech Behavior Annotator - Annotation File'
|
182 |
+
conf = 'Configuration file:'
|
183 |
+
fid = open(f_ann)
|
184 |
+
ann = fid.read().splitlines()
|
185 |
+
fid.close()
|
186 |
+
NFrames = []
|
187 |
+
# check the header
|
188 |
+
assert ann[0].rstrip() == header
|
189 |
+
assert ann[1].rstrip() == ''
|
190 |
+
assert ann[2].rstrip() == conf
|
191 |
+
# parse action list
|
192 |
+
l = 3
|
193 |
+
names = [None] * 1000
|
194 |
+
keys = [None] * 1000
|
195 |
+
types = []
|
196 |
+
bnds = []
|
197 |
+
k = -1
|
198 |
+
|
199 |
+
# get config keys and names
|
200 |
+
while True:
|
201 |
+
ann[l] = ann[l].rstrip()
|
202 |
+
if not isinstance(ann[l], str) or not ann[l]:
|
203 |
+
l += 1
|
204 |
+
break
|
205 |
+
values = ann[l].split()
|
206 |
+
k += 1
|
207 |
+
names[k] = values[0]
|
208 |
+
keys[k] = values[1]
|
209 |
+
l += 1
|
210 |
+
names = names[:k + 1]
|
211 |
+
keys = keys[:k + 1]
|
212 |
+
|
213 |
+
# read in each stream in turn until end of file
|
214 |
+
bnds0 = [None] * 10000
|
215 |
+
types0 = [None] * 10000
|
216 |
+
actions0 = [None] * 10000
|
217 |
+
nStrm1 = 0
|
218 |
+
while True:
|
219 |
+
ann[l] = ann[l].rstrip()
|
220 |
+
nStrm1 += 1
|
221 |
+
t = ann[l].split(":")
|
222 |
+
l += 1
|
223 |
+
ann[l] = ann[l].rstrip()
|
224 |
+
assert int(t[0][1]) == nStrm1
|
225 |
+
assert ann[l] == '-----------------------------'
|
226 |
+
l += 1
|
227 |
+
bnds1 = np.ones((10000, 2), dtype=int)
|
228 |
+
types1 = np.ones(10000, dtype=int) * -1
|
229 |
+
actions1 = [None] * 10000
|
230 |
+
k = 0
|
231 |
+
# start the annotations
|
232 |
+
while True:
|
233 |
+
ann[l] = ann[l].rstrip()
|
234 |
+
t = ann[l]
|
235 |
+
if not isinstance(t, str) or not t:
|
236 |
+
l += 1
|
237 |
+
break
|
238 |
+
t = ann[l].split()
|
239 |
+
type = [i for i in range(len(names)) if t[2] == names[i]]
|
240 |
+
type = type[0]
|
241 |
+
if type == None:
|
242 |
+
print('undefined behavior' + t[2])
|
243 |
+
if bnds1[k - 1, 1] != int(t[0]) - 1 and k > 0:
|
244 |
+
print('%d ~= %d' % (bnds1[k, 1], int(t[0]) - 1))
|
245 |
+
bnds1[k, :] = [int(t[0]), int(t[1])]
|
246 |
+
types1[k] = type
|
247 |
+
actions1[k] = names[type]
|
248 |
+
k += 1
|
249 |
+
l += 1
|
250 |
+
if l == len(ann):
|
251 |
+
break
|
252 |
+
if nStrm1 == 1:
|
253 |
+
nFrames = bnds1[k - 1, 1]
|
254 |
+
assert nFrames == bnds1[k - 1, 1]
|
255 |
+
bnds0[nStrm1 - 1] = bnds1[:k]
|
256 |
+
types0[nStrm1 - 1] = types1[:k]
|
257 |
+
actions0[nStrm1 - 1] = actions1[:k]
|
258 |
+
if l == len(ann):
|
259 |
+
break
|
260 |
+
while not ann[l]:
|
261 |
+
l += 1
|
262 |
+
|
263 |
+
bnds = bnds0[:nStrm1]
|
264 |
+
types = types0[:nStrm1]
|
265 |
+
actions = actions0[:nStrm1]
|
266 |
+
|
267 |
+
idx = 0
|
268 |
+
if len(actions[0]) < len(actions[1]):
|
269 |
+
idx = 1
|
270 |
+
type_frame = []
|
271 |
+
action_frame = []
|
272 |
+
len_bnd = []
|
273 |
+
|
274 |
+
for i in range(len(bnds[idx])):
|
275 |
+
numf = bnds[idx][i, 1] - bnds[idx][i, 0] + 1
|
276 |
+
len_bnd.append(numf)
|
277 |
+
action_frame.extend([actions[idx][i]] * numf)
|
278 |
+
type_frame.extend([types[idx][i]] * numf)
|
279 |
+
|
280 |
+
ann_dict = {
|
281 |
+
'keys': keys,
|
282 |
+
'behs': names,
|
283 |
+
'nstrm': nStrm1,
|
284 |
+
'nFrames': nFrames,
|
285 |
+
'behs_se': bnds,
|
286 |
+
'behs_dur': len_bnd,
|
287 |
+
'behs_bout': actions,
|
288 |
+
'behs_frame': action_frame
|
289 |
+
}
|
290 |
+
|
291 |
+
return ann_dict
|
292 |
+
|
293 |
+
def parse_ann_dual(f_ann):
|
294 |
+
header = 'Caltech Behavior Annotator - Annotation File'
|
295 |
+
conf = 'Configuration file:'
|
296 |
+
fid = open(f_ann)
|
297 |
+
ann = fid.read().splitlines()
|
298 |
+
fid.close()
|
299 |
+
NFrames = []
|
300 |
+
# check the header
|
301 |
+
assert ann[0].rstrip() == header
|
302 |
+
assert ann[1].rstrip() == ''
|
303 |
+
assert ann[2].rstrip() == conf
|
304 |
+
# parse action list
|
305 |
+
l = 3
|
306 |
+
names = [None] * 1000
|
307 |
+
keys = [None] * 1000
|
308 |
+
types = []
|
309 |
+
bnds = []
|
310 |
+
k = -1
|
311 |
+
|
312 |
+
# get config keys and names
|
313 |
+
while True:
|
314 |
+
ann[l] = ann[l].rstrip()
|
315 |
+
if not isinstance(ann[l], str) or not ann[l]:
|
316 |
+
l += 1
|
317 |
+
break
|
318 |
+
values = ann[l].split()
|
319 |
+
k += 1
|
320 |
+
names[k] = values[0]
|
321 |
+
keys[k] = values[1]
|
322 |
+
l += 1
|
323 |
+
names = names[:k + 1]
|
324 |
+
keys = keys[:k + 1]
|
325 |
+
|
326 |
+
# read in each stream in turn until end of file
|
327 |
+
bnds0 = [None] * 10000
|
328 |
+
types0 = [None] * 10000
|
329 |
+
actions0 = [None] * 10000
|
330 |
+
nStrm1 = 0
|
331 |
+
while True:
|
332 |
+
ann[l] = ann[l].rstrip()
|
333 |
+
nStrm1 += 1
|
334 |
+
t = ann[l].split(":")
|
335 |
+
l += 1
|
336 |
+
ann[l] = ann[l].rstrip()
|
337 |
+
assert int(t[0][1]) == nStrm1
|
338 |
+
assert ann[l] == '-----------------------------'
|
339 |
+
l += 1
|
340 |
+
bnds1 = np.ones((10000, 2), dtype=int)
|
341 |
+
types1 = np.ones(10000, dtype=int) * -1
|
342 |
+
actions1 = [None] * 10000
|
343 |
+
k = 0
|
344 |
+
# start the annotations
|
345 |
+
while True:
|
346 |
+
ann[l] = ann[l].rstrip()
|
347 |
+
t = ann[l]
|
348 |
+
if not isinstance(t, str) or not t:
|
349 |
+
l += 1
|
350 |
+
break
|
351 |
+
t = ann[l].split()
|
352 |
+
type = [i for i in range(len(names)) if t[2] == names[i]]
|
353 |
+
type = type[0]
|
354 |
+
if type == None:
|
355 |
+
print('undefined behavior' + t[2])
|
356 |
+
if bnds1[k - 1, 1] != int(t[0]) - 1 and k > 0:
|
357 |
+
print('%d ~= %d' % (bnds1[k, 1], int(t[0]) - 1))
|
358 |
+
bnds1[k, :] = [int(t[0]), int(t[1])]
|
359 |
+
types1[k] = type
|
360 |
+
actions1[k] = names[type]
|
361 |
+
k += 1
|
362 |
+
l += 1
|
363 |
+
if l == len(ann):
|
364 |
+
break
|
365 |
+
if nStrm1 == 1:
|
366 |
+
nFrames = bnds1[k - 1, 1]
|
367 |
+
assert nFrames == bnds1[k - 1, 1]
|
368 |
+
bnds0[nStrm1 - 1] = bnds1[:k]
|
369 |
+
types0[nStrm1 - 1] = types1[:k]
|
370 |
+
actions0[nStrm1 - 1] = actions1[:k]
|
371 |
+
if l == len(ann):
|
372 |
+
break
|
373 |
+
while not ann[l]:
|
374 |
+
l += 1
|
375 |
+
|
376 |
+
bnds = bnds0[:nStrm1]
|
377 |
+
types = types0[:nStrm1]
|
378 |
+
actions = actions0[:nStrm1]
|
379 |
+
|
380 |
+
idx = 0
|
381 |
+
if len(actions[0]) < len(actions[1]):
|
382 |
+
idx = 1
|
383 |
+
type_frame = []
|
384 |
+
action_frame = []
|
385 |
+
len_bnd = []
|
386 |
+
|
387 |
+
|
388 |
+
for i in range(len(bnds[idx])):
|
389 |
+
numf = bnds[idx][i, 1] - bnds[idx][i, 0] + 1
|
390 |
+
len_bnd.append(numf)
|
391 |
+
action_frame.extend([actions[idx][i]] * numf)
|
392 |
+
type_frame.extend([types[idx][i]] * numf)
|
393 |
+
|
394 |
+
|
395 |
+
type_frame2 = []
|
396 |
+
action_frame2 = []
|
397 |
+
len_bnd2 = []
|
398 |
+
idx=1 if idx==0 else 0
|
399 |
+
|
400 |
+
for i in range(len(bnds[idx])):
|
401 |
+
numf = bnds[idx][i, 1] - bnds[idx][i, 0] + 1
|
402 |
+
len_bnd2.append(numf)
|
403 |
+
action_frame2.extend([actions[idx][i]] * numf)
|
404 |
+
type_frame2.extend([types[idx][i]] * numf)
|
405 |
+
|
406 |
+
ann_dict = {
|
407 |
+
'keys': keys,
|
408 |
+
'behs': names,
|
409 |
+
'nstrm': nStrm1,
|
410 |
+
'nFrames': nFrames,
|
411 |
+
'behs_se': bnds,
|
412 |
+
'behs_dur': len_bnd,
|
413 |
+
'behs_bout': actions,
|
414 |
+
'behs_frame': action_frame if 'interaction' not in action_frame else action_frame2,
|
415 |
+
'behs_frame2': action_frame2 if 'interaction' in action_frame2 else action_frame
|
416 |
+
}
|
417 |
+
|
418 |
+
return ann_dict
|
419 |
+
|
420 |
+
def syncTopFront(f,num_frames,num_framesf):
|
421 |
+
return int(round(f / (num_framesf - 1) * (num_frames - 1))) if num_framesf > num_frames else int(round(f / (num_frames - 1) * (num_framesf - 1)))
|
422 |
+
|
423 |
+
|
424 |
+
|
425 |
+
class seqIo_reader():
|
426 |
+
def __init__(self,fname,info=[],buildTable=True):
|
427 |
+
self.filename = fname
|
428 |
+
try:
|
429 |
+
self.file=open(fname,'rb')
|
430 |
+
except EnvironmentError as e:
|
431 |
+
print(os.strerror(e.errno))
|
432 |
+
self.header={}
|
433 |
+
self.seek_table=None
|
434 |
+
self.frames_read=-1
|
435 |
+
self.timestamp_length = 10
|
436 |
+
if info==[]:
|
437 |
+
self.readHeader()
|
438 |
+
else:
|
439 |
+
info.numFrames=0
|
440 |
+
if buildTable:
|
441 |
+
print("buildTable was True, so calling buildSeekTable()")
|
442 |
+
self.buildSeekTable(False)
|
443 |
+
|
444 |
+
|
445 |
+
def readHeader(self):
|
446 |
+
#make sure we do this at the beginning of the file
|
447 |
+
assert self.frames_read == -1, "Can only read header from beginning of file"
|
448 |
+
self.file.seek(0,0)
|
449 |
+
# pdb.set_trace()
|
450 |
+
|
451 |
+
# Read 1024 bytes (len of header)
|
452 |
+
tmp = fread(self.file,1024,np.uint8)
|
453 |
+
#check that the header is not all 0's
|
454 |
+
n=len(tmp)
|
455 |
+
if n<1024:raise ValueError('no header')
|
456 |
+
if all(tmp==0): raise ValueError('fully empty header')
|
457 |
+
self.file.seek(0,0)
|
458 |
+
#first 4 bytes stor 0XFEED next 24 store 'Norpix seq '
|
459 |
+
magic_number = fread(self.file,1,np.uint32)
|
460 |
+
name = fread(self.file,10,np.uint16)
|
461 |
+
name = ''.join(map(chr,name))
|
462 |
+
if not '{0:X}'.format(magic_number)=='FEED' or not name=='Norpix seq':raise ValueError('invalid header')
|
463 |
+
self.file.seek(4,1)
|
464 |
+
#next 8 bytes for version and header size (1024) then 512 for desc
|
465 |
+
version = int(fread(self.file,1,np.int32))
|
466 |
+
hsize =int(fread(self.file,1,np.uint32))
|
467 |
+
assert(hsize)==1024 ,"incorrect header size"
|
468 |
+
# d = self.file.read(512)
|
469 |
+
descr=fread(self.file,256,np.uint16)
|
470 |
+
# descr = ''.join(map(chr,descr))
|
471 |
+
# descr = ''.join(map(unichr,descr)).replace('\x00',' ')
|
472 |
+
descr = ''.join([chr(x) for x in descr]).replace('\x00',' ')
|
473 |
+
# descr = descr.encode('utf-8')
|
474 |
+
#read more info
|
475 |
+
tmp = fread(self.file,9,np.uint32)
|
476 |
+
assert tmp[7]==0, "incorrect origin"
|
477 |
+
fps = fread(self.file,1,np.float64)
|
478 |
+
codec = 'imageFormat' + '%03d'%tmp[5]
|
479 |
+
desc_format = fread(self.file,1,np.uint32)
|
480 |
+
padding = fread(self.file,428,np.uint8)
|
481 |
+
padding = ''.join(map(chr,padding))
|
482 |
+
#store info
|
483 |
+
self.header={'magicNumber':magic_number,
|
484 |
+
'name':name,
|
485 |
+
'seqVersion': version,
|
486 |
+
'headerSize':hsize,
|
487 |
+
'descr': descr,
|
488 |
+
'width':int(tmp[0]),
|
489 |
+
'height':int(tmp[1]),
|
490 |
+
'imageBitDepth':int(tmp[2]),
|
491 |
+
'imageBitDepthReal':int(tmp[3]),
|
492 |
+
'imageSizeBytes':int(tmp[4]),
|
493 |
+
'imageFormat':int(tmp[5]),
|
494 |
+
'numFrames':int(tmp[6]),
|
495 |
+
'origin':int(tmp[7]),
|
496 |
+
'trueImageSize':int(tmp[8]),
|
497 |
+
'fps':fps,
|
498 |
+
'codec':codec,
|
499 |
+
'descFormat':desc_format,
|
500 |
+
'padding':padding,
|
501 |
+
'nHiddenFinalFrames':0
|
502 |
+
}
|
503 |
+
assert(self.header['imageBitDepthReal']==8)
|
504 |
+
# seek to end fo header
|
505 |
+
self.file.seek(432,1)
|
506 |
+
self.frames_read += 1
|
507 |
+
|
508 |
+
self.imageFormat = self.header['imageFormat']
|
509 |
+
if self.imageFormat in (100,200): self.ext = 'raw'
|
510 |
+
elif self.imageFormat in (102,201): self.ext = 'jpg'
|
511 |
+
elif self.imageFormat in(0x001,0x002): self.ext = 'png'
|
512 |
+
elif self.imageFormat == 101: self.ext = 'brgb8'
|
513 |
+
elif self.imageFormat == 103: self.ext = 'jbrgb'
|
514 |
+
else: raise ValueError('uknown format')
|
515 |
+
|
516 |
+
self.compressed = True if self.ext in ['jpg','jbrgb','png','brgb8'] else False
|
517 |
+
self.bit_depth = self.header['imageBitDepth']
|
518 |
+
|
519 |
+
# My code uses a timestamp_length of 10 bytes, old uses 8. Check if not 10
|
520 |
+
if self.bit_depth / 8 * (self.header['height'] * self.header['width']) + self.timestamp_length \
|
521 |
+
!= self.header['trueImageSize']:
|
522 |
+
# If not 10, adjust to actual (likely 8) and print message
|
523 |
+
self.timestamp_length = int(self.header['trueImageSize'] \
|
524 |
+
- (self.bit_depth / 8 * (self.header['height'] * self.header['width'])))
|
525 |
+
|
526 |
+
def buildSeekTable(self,memoize=False):
|
527 |
+
"""Build a seek table containing the offset and frame size for every frame in the video."""
|
528 |
+
print("in seqIo_reader.buildSeekTable()")
|
529 |
+
pickle_name = self.filename.strip(".seq") + ".seek"
|
530 |
+
if memoize:
|
531 |
+
if os.path.isfile(pickle_name):
|
532 |
+
self.seek_table = pickle.load(open(pickle_name, 'rb'))
|
533 |
+
return
|
534 |
+
|
535 |
+
# assert self.header['numFrames']>0
|
536 |
+
n=self.header['numFrames']
|
537 |
+
if n==0:n=1e7
|
538 |
+
|
539 |
+
seek_table = np.zeros((n)).astype(np.int64)
|
540 |
+
seek_table[0]=1024
|
541 |
+
extra = 8 # extra bytes after image data , 8 for ts then 0 or 8 empty
|
542 |
+
self.file.seek(1024,0)
|
543 |
+
#compressed case
|
544 |
+
|
545 |
+
if self.compressed:
|
546 |
+
i=1
|
547 |
+
while (True):
|
548 |
+
try:
|
549 |
+
# size = fread(self.file,1,np.uint32)
|
550 |
+
# offset = seek_table[i-1] + size +extra
|
551 |
+
# seek_table[i]=offset
|
552 |
+
# # seek_table[i-1,1]=size
|
553 |
+
# self.file.seek(size-4+extra,1)
|
554 |
+
|
555 |
+
size = fread(self.file, 1, np.uint32)
|
556 |
+
offset = seek_table[i - 1] + size + extra
|
557 |
+
# self.file.seek(size-4+extra,1)
|
558 |
+
self.file.seek(offset, 0)
|
559 |
+
if i == 1:
|
560 |
+
if fread(self.file, 1, np.uint32) != 0:
|
561 |
+
self.file.seek(-4, 1)
|
562 |
+
else:
|
563 |
+
extra += 8;
|
564 |
+
offset += 8
|
565 |
+
self.file.seek(offset, 0)
|
566 |
+
|
567 |
+
seek_table[i] = offset
|
568 |
+
# seek_table[i-1,1]=size
|
569 |
+
i+=1
|
570 |
+
except Exception as e:
|
571 |
+
break
|
572 |
+
#most likely EOF
|
573 |
+
else:
|
574 |
+
#uncompressed case
|
575 |
+
assert (self.header['numFrames']>0)
|
576 |
+
frames = range(0, self.header["numFrames"])
|
577 |
+
offsets = [x * self.header["trueImageSize"] + 1024 for x in frames]
|
578 |
+
for i,offset in enumerate(offsets):
|
579 |
+
seek_table[i]=offset
|
580 |
+
# seek_table[i,1]=self.header["imageSize"]
|
581 |
+
if n==1e7:
|
582 |
+
n = np.minimum(n,i)
|
583 |
+
self.seek_table=seek_table[:n]
|
584 |
+
self.header['numFrames']=n
|
585 |
+
else:
|
586 |
+
self.seek_table=seek_table
|
587 |
+
if memoize:
|
588 |
+
pickle.dump(seek_table,open(pickle_name,'wb'))
|
589 |
+
|
590 |
+
#compute frame rate from timestamps as stored fps may be incorrect
|
591 |
+
# if n==1: return
|
592 |
+
self.getTs()
|
593 |
+
# ds = self.ts[1:100]-self.ts[:99]
|
594 |
+
# ds = ds[abs(ds-np.median(ds))<.005]
|
595 |
+
# if bool(np.prod(ds)): self.header['fps']=1/np.mean(ds)
|
596 |
+
|
597 |
+
def getTs(self, n=None):
|
598 |
+
if n==None: n=self.header['numFrames']
|
599 |
+
if self.compressed and self.seek_table is None:
|
600 |
+
self.buildSeekTable()
|
601 |
+
|
602 |
+
ts = np.zeros((n))
|
603 |
+
for i in range(n):
|
604 |
+
if not self.compressed: #uncompressed
|
605 |
+
self.file.seek(1024 + i*self.header['trueImageSize']+self.header['imageSizeBytes'],0)
|
606 |
+
else: #compressed
|
607 |
+
self.file.seek(self.seek_table[i],0)
|
608 |
+
self.file.seek(fread(self.file,1,np.uint32)-4,1)
|
609 |
+
# print(i)
|
610 |
+
ts[i]=fread(self.file,1,np.uint32)+fread(self.file,1,np.uint16)/1000.
|
611 |
+
|
612 |
+
|
613 |
+
self.ts=ts
|
614 |
+
return self.ts
|
615 |
+
|
616 |
+
def getFrame(self,index,decode=True):
|
617 |
+
#get frame image (I) and timestamp (ts) at which frame was recorded
|
618 |
+
nch = self.header['imageBitDepth']/8
|
619 |
+
if self.ext in ['raw','brgb8']: #read in an uncompressed image( assume imageBitDepthReal==8)
|
620 |
+
shape = (self.header['height'], self.header['width'])
|
621 |
+
self.file.seek(1024 + index*self.header['trueImageSize'],0)
|
622 |
+
I = fread(self.file,self.header['imageSizeBytes'],np.uint8)
|
623 |
+
|
624 |
+
if decode:
|
625 |
+
if nch==1:
|
626 |
+
I=np.reshape(I,shape)
|
627 |
+
else:
|
628 |
+
I=np.reshape(I,(shape,nch))
|
629 |
+
if nch==3:
|
630 |
+
t=I[:,:,2]; I[:,:,2]=I[:,:,0]; I[:,:,1]=t
|
631 |
+
if self.ext=='brgb8':
|
632 |
+
I= cv2.demosaicing(I, code=cv2.COLOR_BAYER_BG2BGR)
|
633 |
+
# I= colour_demosaicing.demosaicing_CFA_Bayer_bilinear(I,'BGGR')
|
634 |
+
|
635 |
+
elif self.ext in ['jpg','jbrgb']:
|
636 |
+
self.file.seek(self.seek_table[index],0)
|
637 |
+
nBytes = fread(self.file,1,np.uint32)
|
638 |
+
data = fread(self.file,nBytes-4,np.uint8)
|
639 |
+
if decode:
|
640 |
+
I = PIL.Image.open(io.BytesIO(data))
|
641 |
+
if self.ext == 'jbrgb':
|
642 |
+
I= cv2.demosaicing(I, code=cv2.COLOR_BAYER_BG2BGR)
|
643 |
+
# I=colour_demosaicing.demosaicing_CFA_Bayer_bilinear(I,'BGGR')
|
644 |
+
else:
|
645 |
+
I = data
|
646 |
+
|
647 |
+
elif self.ext=='png':
|
648 |
+
self.file.seek(self.seek_table[index],0)
|
649 |
+
nBytes = fread(self.file,1,np.uint32)
|
650 |
+
I= fread(self.file,nBytes-4,np.uint8)
|
651 |
+
if decode:
|
652 |
+
I= np.array(I).transpose(range(I.shape,-1,-1))
|
653 |
+
else: assert(False)
|
654 |
+
ts = fread(self.file,1,np.uint32)+fread(self.file,1,np.uint16)/1000.
|
655 |
+
return np.array(I), ts
|
656 |
+
|
657 |
+
# Close the file
|
658 |
+
def close(self):
|
659 |
+
self.file.close()
|
660 |
+
|
661 |
+
class seqIo_writer():
|
662 |
+
def __init__(self,filename,old_header):
|
663 |
+
self.file = open(filename,'wb')
|
664 |
+
self.file.seek(0,0)
|
665 |
+
self.header=old_header
|
666 |
+
|
667 |
+
#create space for header
|
668 |
+
fwrite(self.file,np.zeros(1024).astype(int),np.uint8)
|
669 |
+
|
670 |
+
assert(set(['width','height','fps','codec']).issubset(self.header.keys()))
|
671 |
+
|
672 |
+
codec = self.header['codec']
|
673 |
+
if codec in ['monoraw', 'imageFormat100']: self.frmt = 100;self.nch = 1;self.ext = 'raw'
|
674 |
+
elif codec in ['raw', 'imageFormat200']: self.frmt = 200;self.nch = 3;self.ext = 'raw'
|
675 |
+
elif codec in ['monojpg', 'imageFormat102']: self.frmt = 102;self.nch = 1;self.ext = 'jpg'
|
676 |
+
elif codec in ['jpg', 'imageFormat201']: self.frmt = 201;self.nch = 3;self.ext = 'jpg'
|
677 |
+
elif codec in ['monopng', 'imageFormat001']: self.frmt = 0x001;self.nch = 1;self.ext = 'png'
|
678 |
+
elif codec in ['png', 'imageFormat002']: self.frmt = 0x002;self.nch = 3;self.ext = 'png'
|
679 |
+
else: raise ValueError('unknown format')
|
680 |
+
|
681 |
+
self.header['imageFormat']=self.frmt
|
682 |
+
self.header['imageBitDepth']=8*self.nch
|
683 |
+
self.header['imageBitDepthReal']=8
|
684 |
+
nBytes = self.header['width']*self.header['height']*self.nch
|
685 |
+
self.header['imageSizeBytes']=nBytes
|
686 |
+
self.header['numFrames']=0
|
687 |
+
self.header['trueImageSize']=nBytes + 6 +512-np.mod(nBytes+6,512)
|
688 |
+
|
689 |
+
# Close the file
|
690 |
+
def close(self):
|
691 |
+
self.writeHeader()
|
692 |
+
self.file.close()
|
693 |
+
|
694 |
+
def writeHeader(self):
|
695 |
+
self.file.seek(0,0)
|
696 |
+
# first write 4 bytes to store 0XFEED, next 24 store 'Nrpix seq '
|
697 |
+
fwrite(self.file,int('FEED',16),np.uint32)
|
698 |
+
name = np.array(['Norpix seq ']).view(np.uint8)
|
699 |
+
fwrite(self.file,name, np.uint16)
|
700 |
+
# next 8 bytes for version (3) and header size (1024) then 512 for descr
|
701 |
+
fwrite(self.file,[3,1024],np.int32)
|
702 |
+
if not 'descr' in self.header.keys() or len(np.array([self.header['descr']]).view(np.uint8))>256: d = np.array(['No Description']).view(np.uint8)
|
703 |
+
else: d= np.array([self.header['descr']]).view(np.uint8)
|
704 |
+
d = np.concatenate((d[:np.minimum(256,len(d))],np.zeros(256-len(d)).astype(np.uint8)))
|
705 |
+
fwrite(self.file,d,np.uint16)
|
706 |
+
#write remaining info
|
707 |
+
vals= [self.header['width'],self.header['height'],self.header['imageBitDepth'],self.header['imageBitDepthReal'],
|
708 |
+
self.header['imageSizeBytes'],self.header['imageFormat'],self.header['numFrames'],0,self.header['trueImageSize']]
|
709 |
+
fwrite(self.file,vals,np.uint32)
|
710 |
+
#store frame rate nad pad with 0s
|
711 |
+
fwrite(self.file,self.header['fps'],np.float64)
|
712 |
+
fwrite(self.file,np.zeros(432),np.uint8)
|
713 |
+
|
714 |
+
def addFrame(self,I,ts=0,encode=1):
|
715 |
+
nCh = self.header['imageBitDepth']/8
|
716 |
+
ext = self.ext
|
717 |
+
c = self.header['numFrames']+1
|
718 |
+
if encode:
|
719 |
+
siz = [self.header['height'],self.header['width'],nCh]
|
720 |
+
assert(I.shape[0]==siz[0] and I.shape[1]==siz[1])
|
721 |
+
if len(I.shape)==3:
|
722 |
+
assert(I.shape[2]==siz[2] or I.shape[2]==self.nch)
|
723 |
+
if ext=='raw':
|
724 |
+
#write uncompressed image and assume imageBitDepthReal==8
|
725 |
+
if not encode : assert(I.size==self.header['imageSizeBytes'])
|
726 |
+
else:
|
727 |
+
if nCh==3: t=I[:,:,2]; I[:,:,2]=I[:,:,0];I[:,:,0]=t
|
728 |
+
if nCh==1: I=I.transpose()
|
729 |
+
else: I = np.transpose( np.expand_dims(I, axis=2), (2, 1, 0) )
|
730 |
+
# I= I.flat.view(np.uint8)
|
731 |
+
I= I.flat
|
732 |
+
fwrite(self.file,I,np.uint8)
|
733 |
+
pad = self.header['trueImageSize']-self.header['imageSizeBytes']-6
|
734 |
+
if ext =='jpg':
|
735 |
+
if encode:
|
736 |
+
#write red from to temporary jpg
|
737 |
+
cv2.imwrite('tmp.jpg',I, [int(cv2.IMWRITE_JPEG_QUALITY ),80])
|
738 |
+
# j=Image.fromarray(I.astype(np.uint8))
|
739 |
+
# j.save('tmp.jpg')
|
740 |
+
# I=Image.open('tmp.jpg')
|
741 |
+
fid = open('tmp.jpg','r')
|
742 |
+
I = fid.read()
|
743 |
+
fid.close()
|
744 |
+
b=bytearray(I)
|
745 |
+
assert (b[0] == 255 and b[1] == 216 and b[-2] == 255 and b[-1] == 217); # JPG
|
746 |
+
os.remove('tmp.jpg')
|
747 |
+
I = np.array(list(b)).astype(np.uint8)
|
748 |
+
nbytes = len(I)+4
|
749 |
+
fwrite(self.file,nbytes,np.uint32)
|
750 |
+
# self.file.write(I)
|
751 |
+
fwrite(self.file,I,np.uint8)
|
752 |
+
pad = 10
|
753 |
+
if ts==0: ts = (c-1)/self.header['fps']
|
754 |
+
s = int(np.floor(ts))
|
755 |
+
ms = int(np.round(np.mod(ts,1)*1000))
|
756 |
+
fwrite(self.file,s,np.int32)
|
757 |
+
fwrite(self.file,ms,np.uint16)
|
758 |
+
self.header['numFrames']=c
|
759 |
+
if pad>0:
|
760 |
+
pad = np.zeros(pad).astype(np.uint8)
|
761 |
+
fwrite(self.file,pad,np.uint8)
|
762 |
+
|
763 |
+
def seqIo_crop(fname, tname, frames):
|
764 |
+
"""
|
765 |
+
Crop sub-sequence from seq file.
|
766 |
+
|
767 |
+
Frame indices are 0 indexed. frames need not be consecutive and can
|
768 |
+
contain duplicates. An index of -1 indicates a blank (all 0) frame. If
|
769 |
+
contiguous subset of frames is cropped timestamps are preserved.
|
770 |
+
|
771 |
+
USAGE
|
772 |
+
seqIo( fName, 'crop', tName, frames )
|
773 |
+
|
774 |
+
INPUTS
|
775 |
+
fName - seq file name
|
776 |
+
tName - cropped seq file name
|
777 |
+
frames - frame indices (0 indexed)
|
778 |
+
"""
|
779 |
+
if not isinstance(frames, np.ndarray): frames=np.array(frames)
|
780 |
+
sr = seqIo_reader(fname)
|
781 |
+
sw = seqIo_writer(tname,sr.header)
|
782 |
+
pad,_= sr.getFrame(0)
|
783 |
+
pad = np.zeros(pad.size).astype(np.uint8)
|
784 |
+
kp = frames>=0 & frames<sr.header['numFrames']
|
785 |
+
if not np.all(kp): frames = frames[kp]
|
786 |
+
print('%i out of bounds frames'% np.sum(~kp))
|
787 |
+
ordered = np.all(frames[1:]==frames[:-1]+1)
|
788 |
+
n= frames.size
|
789 |
+
k=0
|
790 |
+
for f in frames:
|
791 |
+
if f<0:
|
792 |
+
sw.addFrame(pad)
|
793 |
+
continue
|
794 |
+
I,ts = sr.getFrame(f)
|
795 |
+
k+=1
|
796 |
+
if ordered:
|
797 |
+
sw.addFrame(I,ts)
|
798 |
+
else:
|
799 |
+
sw.addFrame(I)
|
800 |
+
sr.close()
|
801 |
+
sw.close
|
802 |
+
|
803 |
+
def seqIo_toImgs(fName, tDir=[], skip=1, f0=0, f1=np.inf, ext=''):
|
804 |
+
"""
|
805 |
+
Extract images from seq file to target directory or array.
|
806 |
+
|
807 |
+
USAGE
|
808 |
+
Is = seqIo( fName, 'toImgs', [tDir], [skip], [f0], [f1], [ext] )
|
809 |
+
|
810 |
+
INPUTS
|
811 |
+
fName - seq file name
|
812 |
+
tDir - [] target directory (if empty extract images to array)
|
813 |
+
skip - [1] skip between written frames
|
814 |
+
f0 - [0] first frame to write
|
815 |
+
f1 - [numFrames-1] last frame to write
|
816 |
+
ext - [] optionally save as given type (slow, reconverts)
|
817 |
+
|
818 |
+
OUTPUTS
|
819 |
+
Is - if isempty(tDir) outputs image array (else Is=[])
|
820 |
+
"""
|
821 |
+
sr = seqIo_reader(fName)
|
822 |
+
f1 = np.minimum(f1,sr.header['numFrames']-1)
|
823 |
+
frames = range(f0,f1,skip)
|
824 |
+
n=len(frames)
|
825 |
+
k=0
|
826 |
+
#output images to array
|
827 |
+
if tDir==[]:
|
828 |
+
I,_=sr.getFrame(0)
|
829 |
+
d = I.shape
|
830 |
+
assert(len(d)==2 or len(d)==3)
|
831 |
+
try:
|
832 |
+
Is = np.zeros((I.shape+(n,))).astype(I.dtype)
|
833 |
+
except:
|
834 |
+
sr.close()
|
835 |
+
raise
|
836 |
+
for k in range(n):
|
837 |
+
I,ts = sr.getFrame(k)
|
838 |
+
if len(d)==2:
|
839 |
+
Is[:,:,k]=I
|
840 |
+
else:
|
841 |
+
Is[:,:,:,k]=I
|
842 |
+
print('saved %d' % k)
|
843 |
+
|
844 |
+
sr.close()
|
845 |
+
# output image directory
|
846 |
+
if not os.path.exists(tDir):os.makedirs(tDir)
|
847 |
+
if tDir.split('/')[-1]!='/':tDir+'/'
|
848 |
+
Is = np.array([])
|
849 |
+
for frame in frames:
|
850 |
+
f = tDir + 'I%05.' % (frame)
|
851 |
+
I, ts = sr.getFrame(frame)
|
852 |
+
if ext!='':
|
853 |
+
cv2.imwrite(f+ext,I)
|
854 |
+
else:
|
855 |
+
cv2.imwrite(f+sr.ext)
|
856 |
+
k+=1
|
857 |
+
print('saved %d' % frame)
|
858 |
+
sr.close()
|
859 |
+
return Is
|
860 |
+
|
861 |
+
def seqIo_frImgs(fName, header=[], aviName=[], Is=[], sDir=[], name='I', ndig=5, f0=0, f1=1e6):
|
862 |
+
"""
|
863 |
+
Create seq file from an array or directory of images or from an AVI file.
|
864 |
+
|
865 |
+
For info, if converting from array, only codec (e.g., 'jpg') and fps must
|
866 |
+
be specified while width and height and determined automatically. If
|
867 |
+
converting from AVI, fps is also determined automatically.
|
868 |
+
|
869 |
+
USAGE
|
870 |
+
seqIo( fName, 'frImgs', info, varargin )
|
871 |
+
|
872 |
+
INPUTS
|
873 |
+
fName - seq file name
|
874 |
+
info - defines codec, etc, see seqIo>writer
|
875 |
+
varargin - additional params (struct or name/value pairs)
|
876 |
+
.aviName - [] if specified create seq from avi file
|
877 |
+
.Is - [] if specified create seq from image array
|
878 |
+
.sDir - [] source directory
|
879 |
+
.skip - [1] skip between frames
|
880 |
+
.name - ['I'] base name of images
|
881 |
+
.nDigits - [5] number of digits for filename index
|
882 |
+
.f0 - [0] first frame to read
|
883 |
+
.f1 - [10^6] last frame to read
|
884 |
+
"""
|
885 |
+
|
886 |
+
if aviName!=[]: #avi movie exists
|
887 |
+
vc = cv2.VideoCapture(aviName)
|
888 |
+
if vc.isOpened(): rval = True
|
889 |
+
else:
|
890 |
+
rval = False
|
891 |
+
print('video not readable')
|
892 |
+
return
|
893 |
+
fps = vc.get(cv2.cv.CV_CAP_PROP_FPS)
|
894 |
+
NUM_FRAMES = int(vc.get(cv2.cv.CV_CAP_PROP_FRAME_COUNT))
|
895 |
+
print(NUM_FRAMES)
|
896 |
+
IM_TOP_H = vc.get(cv2.cv.CV_CAP_PROP_FRAME_HEIGHT)
|
897 |
+
IM_TOP_W = vc.get(cv2.cv.CV_CAP_PROP_FRAME_WIDTH)
|
898 |
+
header['width']=IM_TOP_W
|
899 |
+
header['height']=IM_TOP_H
|
900 |
+
header['fps']=fps
|
901 |
+
|
902 |
+
sw = seqIo_writer(fName,header)
|
903 |
+
print('creating seq from AVI')
|
904 |
+
# initialize timer
|
905 |
+
timer = pb.ProgressBar(widgets=['Converting ', pb.Percentage(), ' -- ',
|
906 |
+
pb.FormatLabel('Frame %(value)d'), '/',
|
907 |
+
pb.FormatLabel('%(max)d'), ' [', pb.Timer(), '] ',
|
908 |
+
pb.Bar(), ' (', pb.ETA(), ') '], maxval=NUM_FRAMES)
|
909 |
+
for f in range(NUM_FRAMES):
|
910 |
+
rval, im = vc.read()
|
911 |
+
if rval:
|
912 |
+
im= im.astype(np.uint8)
|
913 |
+
sw.addFrame(im)
|
914 |
+
timer.update(f)
|
915 |
+
sw.close()
|
916 |
+
timer.finish()
|
917 |
+
elif Is==[]:
|
918 |
+
assert(os.path.isdir(sDir))
|
919 |
+
sw = seqIo_writer(fName,header)
|
920 |
+
frmstr = '%s/%s%%0%ii.%s' % (sDir,name,ndig,header.ext)
|
921 |
+
for frame in range(f0,f1):
|
922 |
+
f = frmstr % frame
|
923 |
+
if not os.path.isfile(f):break
|
924 |
+
fid = open(f, 'r')
|
925 |
+
if fid<0: sw.close(); assert(False)
|
926 |
+
I = fid.read()
|
927 |
+
fid.close()
|
928 |
+
b = bytearray(I)
|
929 |
+
assert (b[0] == 255 and b[1] == 216 and b[-2] == 255 and b[-1] == 217); # JPG
|
930 |
+
I = np.array(list(b)).astype(np.uint8)
|
931 |
+
sw.addFrame(I,0,0)
|
932 |
+
sw.close()
|
933 |
+
if frame==f0: print('No images found')
|
934 |
+
else:
|
935 |
+
nd = len(Is.shape)
|
936 |
+
if nd==2: nd=3
|
937 |
+
assert(nd<=4)
|
938 |
+
nFrm = Is.shape[nd-1]
|
939 |
+
header['height']=Is.shape[0]
|
940 |
+
header['width']=Is.shape[1]
|
941 |
+
sw =seqIo_writer(fName,header)
|
942 |
+
if nd==3:
|
943 |
+
for f in range(nFrm): sw.addFrame(Is[:,:,f])
|
944 |
+
if nd==4:
|
945 |
+
for f in range(nFrm): sw.addFrame(Is[:,:,:,f])
|
946 |
+
sw.close()
|
947 |
+
|
948 |
+
def seqIo_convert(fName, tName, imgFun, info=[], skip=1, f0=0, f1=np.inf):
|
949 |
+
"""
|
950 |
+
Convert seq file by applying imgFun(I) to each frame I.
|
951 |
+
|
952 |
+
USAGE
|
953 |
+
seqIo( fName, 'convert', tName, imgFun, varargin )
|
954 |
+
|
955 |
+
INPUTS
|
956 |
+
fName - seq file name
|
957 |
+
tName - converted seq file name
|
958 |
+
imgFun - function to apply to each image
|
959 |
+
varargin - additional params (struct or name/value pairs)
|
960 |
+
.info - [] info for target seq file
|
961 |
+
.skip - [1] skip between frames
|
962 |
+
.f0 - [0] first frame to read
|
963 |
+
.f1 - [inf] last frame to read
|
964 |
+
"""
|
965 |
+
assert(fName!=tName)
|
966 |
+
sr = seqIo_reader(fName)
|
967 |
+
if info==[]: info=sr.header
|
968 |
+
n=sr.header['numFrames']
|
969 |
+
f1=np.minimum(f1,n-1)
|
970 |
+
I,ts=sr.getFrame(0)
|
971 |
+
I=imgFun(I)
|
972 |
+
info['width']=I.shape[1]
|
973 |
+
info['height']=I.shape[0]
|
974 |
+
sw =seqIo_writer(tName,info)
|
975 |
+
print('converting seq')
|
976 |
+
for frame in range(f0,f1,skip):
|
977 |
+
I, ts = sr.getFrame(frame)
|
978 |
+
I = imgFun(I)
|
979 |
+
if skip==1:
|
980 |
+
sw.addFrame(I,ts)
|
981 |
+
else:
|
982 |
+
sw.addFrameI
|
983 |
+
sw.close()
|
984 |
+
sr.close()
|
985 |
+
|
986 |
+
def seqIo_newHeader(fName, info):
|
987 |
+
"""
|
988 |
+
Replace header of seq file with provided info.
|
989 |
+
|
990 |
+
Can be used if the file fName has a corrupt header. Automatically tries
|
991 |
+
to compute number of frames in fName. No guarantees that it will work.
|
992 |
+
|
993 |
+
USAGE
|
994 |
+
seqIo( fName, 'newHeader', info )
|
995 |
+
|
996 |
+
INPUTS
|
997 |
+
fName - seq file name
|
998 |
+
info - info for target seq file
|
999 |
+
"""
|
1000 |
+
d, n = os.path.split(fName)
|
1001 |
+
if d==[]:d='./'
|
1002 |
+
tName=fName[:-4] + '_new' + time.strftime("%d_%m_%Y") + fName[-4:]
|
1003 |
+
sr = seqIo_reader(fName)
|
1004 |
+
sw = seqIo_writer(tName,info)
|
1005 |
+
n=sr.header['numFrames']
|
1006 |
+
for f in range(n):
|
1007 |
+
I,ts=sr.getFrame(f)
|
1008 |
+
sw.addFrame(I,ts)
|
1009 |
+
sr.close()
|
1010 |
+
sw.close()
|
1011 |
+
|
1012 |
+
class seqIo_dualReader():
|
1013 |
+
"""
|
1014 |
+
seqIo_dualReader
|
1015 |
+
Create interface sr for reading dual seq files.
|
1016 |
+
|
1017 |
+
Wrapper for two seq files of the same image dims and roughly the same
|
1018 |
+
frame counts that are treated as a single reader object. getframe()
|
1019 |
+
returns the concatentation of the two frames. For videos of different
|
1020 |
+
frame counts, the first video serves as the "dominant" video and the
|
1021 |
+
frame count of the second video is adjusted accordingly. Same general
|
1022 |
+
usage as in reader, but the only supported operations are: close(),
|
1023 |
+
getframe(), getinfo(), and seek().
|
1024 |
+
|
1025 |
+
USAGE
|
1026 |
+
sr = seqIo( fNames, 'readerDual', [cache] )
|
1027 |
+
|
1028 |
+
INPUTS
|
1029 |
+
fNames - two seq file names
|
1030 |
+
cache - [0] size of cache (see seqIo>reader)
|
1031 |
+
|
1032 |
+
OUTPUTS
|
1033 |
+
sr - interface for reading seq file
|
1034 |
+
"""
|
1035 |
+
def __init__(self,file1,file2):
|
1036 |
+
self.s1 = seqIo_reader(file1)
|
1037 |
+
self.s2 = seqIo_reader(file2)
|
1038 |
+
self.info = self.s1.header
|
1039 |
+
#set the display to be vertically align
|
1040 |
+
self.info['height']=self.s1.header['height']+self.s2.header['height']
|
1041 |
+
self.info['width']=np.maximum(self.s1.header['width'],self.s2.header['width'])
|
1042 |
+
|
1043 |
+
if self.s1.header['numFrames']!=self.s2.header['numFrames']:
|
1044 |
+
print('Two videos files have different number of frames')
|
1045 |
+
print('1st video has %d frames' % self.s1.header['numFrames'])
|
1046 |
+
print('2nd video has %d frames' % self.s2.header['numFrames'])
|
1047 |
+
print('first video %s is used as annotation refeence' % file1)
|
1048 |
+
|
1049 |
+
def getFrame(self):
|
1050 |
+
I1,ts = self.s1.getFrame(0)
|
1051 |
+
I2,_ = self.s2.getFrame(0)
|
1052 |
+
|
1053 |
+
w1 = I1.shape[1]
|
1054 |
+
w2 = I2.shape[1]
|
1055 |
+
|
1056 |
+
if w1!=w2:
|
1057 |
+
m=np.argmax(w1,w2)
|
1058 |
+
if m==0:
|
1059 |
+
wl = int(np.floor((w1-w2)/2.))
|
1060 |
+
wr = w1-w2-wl
|
1061 |
+
nd = len(I2.shape)
|
1062 |
+
if nd==2:
|
1063 |
+
padl = np.zeros((I2.shape[0],wl)).astype(np.uint8)
|
1064 |
+
padr = np.zeros((I2.shape[0],wr)).astype(np.uint8)
|
1065 |
+
else:
|
1066 |
+
padl = np.zeros((I2.shape[0],wl,I2.shape[2])).astype(np.uint8)
|
1067 |
+
padr = np.zeros((I2.shape[0],wr,I2.shape[2])).astype(np.uint8)
|
1068 |
+
I2 = np.concatenate((padl,I2,padr),axis=1)
|
1069 |
+
else:
|
1070 |
+
wl = int(np.floor((w2 - w1) / 2.))
|
1071 |
+
wr = w2 - w1 - wl
|
1072 |
+
nd = len(I2.shape)
|
1073 |
+
if nd == 2:
|
1074 |
+
padl = np.zeros((I1.shape[0], wl)).astype(np.uint8)
|
1075 |
+
padr = np.zeros((I1.shape[0], wr)).astype(np.uint8)
|
1076 |
+
else:
|
1077 |
+
padl = np.zeros((I1.shape[0], wl, I1.shape[2])).astype(np.uint8)
|
1078 |
+
padr = np.zeros((I1.shape[0], wr, I1.shape[2])).astype(np.uint8)
|
1079 |
+
I1 = np.concatenate((padl, I1, padr), axis=1)
|
1080 |
+
I = np.hstack((I1,I2))
|
1081 |
+
return I,ts
|
1082 |
+
|
1083 |
+
class seqIo_extractor():
|
1084 |
+
"""
|
1085 |
+
Create new seq files from top and fron view and syncronize them is not
|
1086 |
+
path_vid: video path
|
1087 |
+
vid_top: seq top video path and name
|
1088 |
+
vid_front: seq front video path and name
|
1089 |
+
s: start frame
|
1090 |
+
e: end frame
|
1091 |
+
|
1092 |
+
"""
|
1093 |
+
def __init__(self,path_vid,vid_top,vid_front,s,e):
|
1094 |
+
sr_top = seqIo_reader(path_vid+vid_top)
|
1095 |
+
sr_front = seqIo_reader(path_vid+vid_front)
|
1096 |
+
num_frames=sr_top.header['numFrames']
|
1097 |
+
num_framesf=sr_front.header['numFrames']
|
1098 |
+
name =os.path.dirname(video_top).split('/')[-1]
|
1099 |
+
|
1100 |
+
if not os.path.exists(pathvid + name + '_%06d_%06d' % (s, e)):
|
1101 |
+
os.makedirs(pathvid + name + '_%06d_%06d' % (s, e))
|
1102 |
+
newdir = pathvid + name + '_%06d_%06d' % (s, e)
|
1103 |
+
video_out_top = newdir + '/' + name + '_%06d_%06d_Top_J85.seq' % (s, e)
|
1104 |
+
video_out_front = newdir + '/' + name + '_%06d_%06d_Front_J85.seq' % (s, e)
|
1105 |
+
|
1106 |
+
sw_top = seqIo_writer(video_out_top, sr_top.header)
|
1107 |
+
sw_front = seqIo_writer(video_out_front, sr_front.header)
|
1108 |
+
|
1109 |
+
for f in range(s - 1, e):
|
1110 |
+
if num_framesf > num_frames:
|
1111 |
+
I_top, ts = sr_top.getFrame(f2(f))
|
1112 |
+
I_front, ts2 = sr_front.getFrame(f)
|
1113 |
+
else:
|
1114 |
+
I_top, ts = sr_top.getFrame(f)
|
1115 |
+
I_front, ts2 = sr_front.getFrame(f2(f))
|
1116 |
+
sw_top.addFrame(I_top, ts)
|
1117 |
+
sw_front.addFrame(I_front, ts2)
|
1118 |
+
print(f)
|
1119 |
+
sw_top.close()
|
1120 |
+
sw_front.close()
|
1121 |
+
|
1122 |
+
def f2(f):
|
1123 |
+
return int(round(f / (num_framesf - 1) * (num_frames - 1))) if num_framesf > num_frames else int(round(f / (num_frames - 1) * (num_framesf - 1)))
|
1124 |
+
|
1125 |
+
def seqIo_toVid(fName, ext='avi'):
|
1126 |
+
"""
|
1127 |
+
seqIo_toVid
|
1128 |
+
Create seq file to another common used format as avi or mp4.
|
1129 |
+
|
1130 |
+
USAGE
|
1131 |
+
seqIo( fName, ext )
|
1132 |
+
|
1133 |
+
INPUTS
|
1134 |
+
fName - seq file name
|
1135 |
+
ext - video extension to convert to
|
1136 |
+
"""
|
1137 |
+
|
1138 |
+
assert fName[-3:]=='seq', 'Not a seq file'
|
1139 |
+
sr = seqIo_reader(fName)
|
1140 |
+
N = sr.header['numFrames']
|
1141 |
+
h = sr.header['height']
|
1142 |
+
w = sr.header['width']
|
1143 |
+
fps = sr.header['fps']
|
1144 |
+
|
1145 |
+
out = fName[:-3]+ext
|
1146 |
+
sw = skvideo.io.FFmpegWriter(out)
|
1147 |
+
# sw = cv2.VideoWriter(out, -1, fps, (w, h))
|
1148 |
+
timer = pb.ProgressBar(widgets=['Converting ', pb.Percentage(), ' -- ',
|
1149 |
+
pb.FormatLabel('Frame %(value)d'), '/',
|
1150 |
+
pb.FormatLabel('%(max)d'), ' [', pb.Timer(), '] ',
|
1151 |
+
pb.Bar(), ' (', pb.ETA(), ') '], maxval=N)
|
1152 |
+
|
1153 |
+
for f in range(N):
|
1154 |
+
I, ts = sr.getFrame(f)
|
1155 |
+
#sw.writeFrame(Image.fromarray(I))
|
1156 |
+
sw.write(I)
|
1157 |
+
timer.update(f)
|
1158 |
+
timer.finish()
|
1159 |
+
# cv2.destroyAllWindows()
|
1160 |
+
# sw.release()
|
1161 |
+
sw.close()
|
1162 |
+
sr.close()
|
1163 |
+
print(out + ' converted')
|
1164 |
+
|
1165 |
+
|
1166 |
+
|
1167 |
+
|
1168 |
+
# minimum header
|
1169 |
+
# header = {'width': IM_TOP_W,
|
1170 |
+
# 'height': IM_TOP_H,
|
1171 |
+
# 'fps': fps,
|
1172 |
+
# 'codec': 'imageFormat102'}
|
1173 |
+
# filename= '/media/cristina/MARS_data/mice_project/teresa/Mouse156_20161017_17-22-09/Mouse156_20161017_17-22-09_Top_J85.seq'
|
1174 |
+
# filename_out = filename[:-4] + '_new.seq'
|
1175 |
+
# reader = seqIo_reader(filename)
|
1176 |
+
# reader.header
|
1177 |
+
# Initialize a SEQ writer
|
1178 |
+
# writer = seqIo_writer(filename_out,reader.header)
|
1179 |
+
# I,ts = reader.getFrame(0)
|
1180 |
+
# writer.addFrame(I,ts)
|
1181 |
+
# for f in range(8):
|
1182 |
+
# I,ts = reader.getFrame(f)
|
1183 |
+
# print(writer.file.tell())
|
1184 |
+
# writer.addFrame(I,ts)
|
1185 |
+
# writer.close()
|
1186 |
+
# reader.close()
|
1187 |
+
|
1188 |
+
|
1189 |
+
|
utils/utils.py
ADDED
@@ -0,0 +1,632 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import io
|
3 |
+
import pickle
|
4 |
+
import copy
|
5 |
+
from collections import Counter
|
6 |
+
from pathlib import Path
|
7 |
+
from tempfile import NamedTemporaryFile
|
8 |
+
import regex as re
|
9 |
+
import numpy as np
|
10 |
+
import pandas as pd
|
11 |
+
from sklearn.manifold import TSNE
|
12 |
+
from sklearn.svm import SVC
|
13 |
+
from sklearn.model_selection import train_test_split
|
14 |
+
from sklearn.metrics import accuracy_score, classification_report
|
15 |
+
import torch
|
16 |
+
from tqdm import tqdm
|
17 |
+
from PIL import Image
|
18 |
+
from transformers import AutoProcessor, AutoModel
|
19 |
+
import streamlit as st
|
20 |
+
from .data_loading import load_multiple_annotations, load_multiple_annotations_io
|
21 |
+
from .data_processing import generate_label_array
|
22 |
+
from .seqIo import seqIo_reader
|
23 |
+
from .mp4Io import mp4Io_reader
|
24 |
+
|
25 |
+
SLIP_MODEL_ID = "google/siglip-so400m-patch14-384"
|
26 |
+
CLIP_MODEL_ID = "openai/clip-vit-base-patch32"
|
27 |
+
|
28 |
+
def create_annot_fname_dict(annot_fnames: list[str])-> dict:
|
29 |
+
fs = re.compile(r'.*(_\d+)$')
|
30 |
+
|
31 |
+
unique_files = set()
|
32 |
+
for file in annot_fnames:
|
33 |
+
file_name = os.fsdecode(file)
|
34 |
+
base_name, _ = os.path.splitext(file_name)
|
35 |
+
if fs.match(base_name):
|
36 |
+
ind = len(fs.match(base_name).group(1))
|
37 |
+
unique_files.add(base_name[:-ind])
|
38 |
+
else:
|
39 |
+
unique_files.add(base_name)
|
40 |
+
|
41 |
+
annot_fname_dict = {}
|
42 |
+
for unique_file in unique_files:
|
43 |
+
annot_fname_dict.update({unique_file: [file for file in annot_fnames if unique_file in file]})
|
44 |
+
return annot_fname_dict
|
45 |
+
|
46 |
+
def create_annot_fname_dict_io(annot_fnames: list[str], annot_files: list)-> dict:
|
47 |
+
annot_file_dict = {}
|
48 |
+
for file in annot_files:
|
49 |
+
annot_file_dict.update({file.name : file})
|
50 |
+
fs = re.compile(r'.*(_\d+)$')
|
51 |
+
|
52 |
+
unique_files = set()
|
53 |
+
for file in annot_fnames:
|
54 |
+
file_name = os.fsdecode(file)
|
55 |
+
base_name, _ = os.path.splitext(file_name)
|
56 |
+
if fs.match(base_name):
|
57 |
+
ind = len(fs.match(base_name).group(1))
|
58 |
+
unique_files.add(base_name[:-ind])
|
59 |
+
else:
|
60 |
+
unique_files.add(base_name)
|
61 |
+
|
62 |
+
annot_fname_dict = {}
|
63 |
+
for unique_file in unique_files:
|
64 |
+
annot_list = [file for file in annot_fnames if unique_file in file]
|
65 |
+
annot_list.sort()
|
66 |
+
annot_file_list = [annot_file_dict[annot_file_name] for annot_file_name in annot_list]
|
67 |
+
annot_fname_dict.update({unique_file: annot_file_list})
|
68 |
+
return annot_fname_dict
|
69 |
+
|
70 |
+
def get_io_reader(uploaded_file):
|
71 |
+
assert uploaded_file.name[-3:]=='seq', 'Not a seq file'
|
72 |
+
with NamedTemporaryFile(suffix="seq", delete=False) as temp:
|
73 |
+
temp.write(uploaded_file.getvalue())
|
74 |
+
sr = seqIo_reader(temp.name)
|
75 |
+
return sr
|
76 |
+
|
77 |
+
def load_slip_model(device):
|
78 |
+
return AutoModel.from_pretrained(SLIP_MODEL_ID).to(device)
|
79 |
+
|
80 |
+
def load_slip_preprocessor():
|
81 |
+
return AutoProcessor.from_pretrained(SLIP_MODEL_ID)
|
82 |
+
|
83 |
+
def load_clip_model(device):
|
84 |
+
return AutoModel.from_pretrained(CLIP_MODEL_ID).to(device)
|
85 |
+
|
86 |
+
def load_clip_preprocessor():
|
87 |
+
return AutoProcessor.from_pretrained(CLIP_MODEL_ID)
|
88 |
+
|
89 |
+
def encode_image(image, device, model, processor):
|
90 |
+
with torch.no_grad():
|
91 |
+
#convert_models_to_fp32(model)
|
92 |
+
inputs = processor(images=image, return_tensors="pt").to(device)
|
93 |
+
image_features = model.get_image_features(**inputs)
|
94 |
+
return image_features.cpu().numpy().flatten()
|
95 |
+
|
96 |
+
def generate_embeddings_stream(fnames : list[str],
|
97 |
+
model = 'SLIP',
|
98 |
+
downsample_rate = 4,
|
99 |
+
save_csv = False)-> tuple[list, list, list]:
|
100 |
+
# set up model and device
|
101 |
+
device = "cuda" if torch.cuda.is_available() else "cpu"
|
102 |
+
os.environ['CUDA_VISIBLE_DEVICES'] = '0'
|
103 |
+
if model == 'SLIP':
|
104 |
+
embed_model = load_slip_model(device)
|
105 |
+
processor = load_slip_preprocessor()
|
106 |
+
elif model == 'CLIP':
|
107 |
+
embed_model = load_clip_model(device)
|
108 |
+
processor = load_clip_preprocessor()
|
109 |
+
|
110 |
+
all_video_embeddings = []
|
111 |
+
all_video_frames = []
|
112 |
+
for fname in fnames:
|
113 |
+
# read in file
|
114 |
+
is_seq = False
|
115 |
+
if fname[-3:] == 'seq': is_seq = True
|
116 |
+
|
117 |
+
if is_seq:
|
118 |
+
sr = seqIo_reader(fname)
|
119 |
+
else:
|
120 |
+
sr = mp4Io_reader(fname)
|
121 |
+
N = sr.header['numFrames']
|
122 |
+
|
123 |
+
# set up embeddings and frame arrays
|
124 |
+
embeddings = []
|
125 |
+
frames = list(range(N))[::downsample_rate]
|
126 |
+
print(frames)
|
127 |
+
|
128 |
+
# create progress bar
|
129 |
+
i = 0
|
130 |
+
pbar_text = lambda i: f'Creating embeddings for {fname}. {i}/{len(frames)} frames.'
|
131 |
+
pbar = st.progress(0, text=pbar_text(0))
|
132 |
+
|
133 |
+
# convert each frame to embeddings
|
134 |
+
for f in tqdm(frames):
|
135 |
+
img, _ = sr.getFrame(f)
|
136 |
+
img_arr = np.array(img)
|
137 |
+
if is_seq:
|
138 |
+
img_rgb = Image.fromarray(img_arr, 'L').convert('RGB')
|
139 |
+
else:
|
140 |
+
img_rgb = Image.fromarray(img_arr).convert('RGB')
|
141 |
+
|
142 |
+
embeddings.append(encode_image(img_rgb, device, embed_model, processor))
|
143 |
+
|
144 |
+
# update progress bar
|
145 |
+
i += 1
|
146 |
+
pbar.progress(i/len(frames), pbar_text(i))
|
147 |
+
|
148 |
+
# save csv of single file
|
149 |
+
if save_csv:
|
150 |
+
df = pd.DataFrame(embeddings)
|
151 |
+
df['Frame'] = frames
|
152 |
+
|
153 |
+
# save csv
|
154 |
+
basename = Path(fname).stem
|
155 |
+
df.to_csv(f'{basename}_embeddings_downsample_{downsample_rate}.csv', index=False)
|
156 |
+
|
157 |
+
all_video_embeddings.append(np.array(embeddings))
|
158 |
+
all_video_frames.append(frames)
|
159 |
+
return all_video_embeddings, all_video_frames
|
160 |
+
|
161 |
+
def get_io_reader(uploaded_file):
|
162 |
+
if uploaded_file.name[-3:]=='seq':
|
163 |
+
with NamedTemporaryFile(suffix="seq", delete=False) as temp:
|
164 |
+
temp.write(uploaded_file.getvalue())
|
165 |
+
sr = seqIo_reader(temp.name)
|
166 |
+
else:
|
167 |
+
with NamedTemporaryFile(suffix="mp4", delete=False) as temp:
|
168 |
+
temp.write(uploaded_file.getvalue())
|
169 |
+
sr = mp4Io_reader(temp.name)
|
170 |
+
return sr
|
171 |
+
|
172 |
+
def generate_embeddings_stream_io(uploaded_files : list,
|
173 |
+
model = 'SLIP',
|
174 |
+
downsample_rate = 4,
|
175 |
+
save_csv = False)-> tuple[list, list, list]:
|
176 |
+
# set up model and device
|
177 |
+
device = "cuda" if torch.cuda.is_available() else "cpu"
|
178 |
+
os.environ['CUDA_VISIBLE_DEVICES'] = '0'
|
179 |
+
if model == 'SLIP':
|
180 |
+
embed_model = load_slip_model(device)
|
181 |
+
processor = load_slip_preprocessor()
|
182 |
+
elif model == 'CLIP':
|
183 |
+
embed_model = load_clip_model(device)
|
184 |
+
processor = load_clip_preprocessor()
|
185 |
+
|
186 |
+
all_video_embeddings = []
|
187 |
+
all_video_frames = []
|
188 |
+
for file in uploaded_files:
|
189 |
+
is_seq = False
|
190 |
+
if file.name[-3:] == 'seq': is_seq = True
|
191 |
+
|
192 |
+
# read in file
|
193 |
+
sr = get_io_reader(file)
|
194 |
+
N = sr.header['numFrames']
|
195 |
+
|
196 |
+
# set up embeddings and frame arrays
|
197 |
+
embeddings = []
|
198 |
+
frames = list(range(N))[::downsample_rate]
|
199 |
+
print(frames)
|
200 |
+
|
201 |
+
# create progress bar
|
202 |
+
i = 0
|
203 |
+
pbar_text = lambda i: f'Creating embeddings for {file.name}. {i}/{len(frames)} frames.'
|
204 |
+
pbar = st.progress(0, text=pbar_text(0))
|
205 |
+
|
206 |
+
# convert each frame to embeddings
|
207 |
+
for f in tqdm(frames):
|
208 |
+
img, _ = sr.getFrame(f)
|
209 |
+
img_arr = np.array(img)
|
210 |
+
if is_seq:
|
211 |
+
img_rgb = Image.fromarray(img_arr, 'L').convert('RGB')
|
212 |
+
else:
|
213 |
+
img_rgb = Image.fromarray(img_arr).convert('RGB')
|
214 |
+
|
215 |
+
embeddings.append(encode_image(img_rgb, device, embed_model, processor))
|
216 |
+
|
217 |
+
# update progress bar
|
218 |
+
i += 1
|
219 |
+
pbar.progress(i/len(frames), pbar_text(i))
|
220 |
+
|
221 |
+
# save csv of single file
|
222 |
+
if save_csv:
|
223 |
+
df = pd.DataFrame(embeddings)
|
224 |
+
df['Frame'] = frames
|
225 |
+
|
226 |
+
# save csv
|
227 |
+
df.to_csv(f'embeddings_downsample_{downsample_rate}_{frames}_frames.csv', index=False)
|
228 |
+
|
229 |
+
all_video_embeddings.append(np.array(embeddings))
|
230 |
+
all_video_frames.append(frames)
|
231 |
+
return all_video_embeddings, all_video_frames
|
232 |
+
|
233 |
+
def create_embeddings_csv(out: str,
|
234 |
+
fnames: list[str],
|
235 |
+
embeddings: list[np.ndarray],
|
236 |
+
frames: list[list[int]],
|
237 |
+
annotations: list[list[str]],
|
238 |
+
test_fnames: None | list[str],
|
239 |
+
views: None | list[str],
|
240 |
+
conditions: None | list[str],
|
241 |
+
downsample_rate = 4,
|
242 |
+
filesystem = None):
|
243 |
+
"""
|
244 |
+
Creates a .csv file containing all of the generated embeddings and provived information.
|
245 |
+
|
246 |
+
Parameters:
|
247 |
+
-----------
|
248 |
+
out : str
|
249 |
+
The name of the resulting file.
|
250 |
+
fnames : list[str]
|
251 |
+
Video sources for each of the embedding arrays.
|
252 |
+
embeddings : np.ndarray
|
253 |
+
The generated embeddings from the images.
|
254 |
+
downsample_rate : int
|
255 |
+
The downsample_rate used for generating the embeddings.
|
256 |
+
"""
|
257 |
+
assert len(fnames) == len(embeddings)
|
258 |
+
assert len(embeddings) == len(frames)
|
259 |
+
all_embeddings = np.vstack(embeddings)
|
260 |
+
df = pd.DataFrame(all_embeddings)
|
261 |
+
|
262 |
+
labels = []
|
263 |
+
for i, annot_fnames in enumerate(annotations):
|
264 |
+
_, ext = os.path.splitext(annot_fnames[0])
|
265 |
+
if ext == '.annot':
|
266 |
+
annot, _, _, sr = load_multiple_annotations(annot_fnames, filesystem=filesystem)
|
267 |
+
annot_labels = generate_label_array(annot, downsample_rate, len(frames[i]))
|
268 |
+
elif ext == '.csv':
|
269 |
+
if not filesystem:
|
270 |
+
annot_df = pd.read_csv(annot_fnames[0], header=None)
|
271 |
+
else:
|
272 |
+
with filesystem.open(annot_fnames[0], 'r') as csv_file:
|
273 |
+
annot_df = pd.read_csv(csv_file, header=None)
|
274 |
+
annot_labels = annot_df[0].to_list()[::downsample_rate]
|
275 |
+
assert len(annot_labels) == len(frames[i]), "There is a mismatch between the number of frames and number of labels. Make sure that the passed in csv file has no header."
|
276 |
+
else:
|
277 |
+
raise ValueError(f'Incompatible file for annotations used. Got a file of type "{ext}".')
|
278 |
+
assert len(annot_labels) == len(frames[i]), "There is a mismatch between the number of frames and number of labels. Make sure you have passed in the correct files."
|
279 |
+
print(annot_labels)
|
280 |
+
labels.append(annot_labels)
|
281 |
+
all_labels = np.hstack(labels)
|
282 |
+
print(len(all_labels))
|
283 |
+
df['Label'] = all_labels
|
284 |
+
|
285 |
+
all_frames = np.hstack(frames)
|
286 |
+
df['Frame'] = all_frames
|
287 |
+
sources = [[fname for _ in range(len(frames[i]))] for i, fname in enumerate(fnames)]
|
288 |
+
all_sources = np.hstack(sources)
|
289 |
+
df['Source'] = all_sources
|
290 |
+
|
291 |
+
if test_fnames:
|
292 |
+
t_split = lambda x: True if x in test_fnames else False
|
293 |
+
test = [[t_split(fname) for _ in range(len(frames[i]))] for i, fname in enumerate(fnames)]
|
294 |
+
else:
|
295 |
+
test = [[True for _ in range(len(frames[i]))] for i, _ in enumerate(fnames)]
|
296 |
+
all_test = np.hstack(test)
|
297 |
+
df['Test'] = all_test
|
298 |
+
|
299 |
+
if views:
|
300 |
+
view = [[views[i] for _ in range(len(frames[i]))] for i in range(len(fnames))]
|
301 |
+
else:
|
302 |
+
view = [[None for _ in range(len(frames[i]))] for i in range(len(fnames))]
|
303 |
+
all_view = np.hstack(view)
|
304 |
+
df['View'] = all_view
|
305 |
+
|
306 |
+
if conditions:
|
307 |
+
condition = [[conditions[i] for _ in range(len(frames[i]))] for i in range(len(fnames))]
|
308 |
+
else:
|
309 |
+
condition = [[None for _ in range(len(frames[i]))] for i in range(len(fnames))]
|
310 |
+
all_condition = np.hstack(condition)
|
311 |
+
df['Condition'] = all_condition
|
312 |
+
return df
|
313 |
+
|
314 |
+
def create_embeddings_csv_io(out: str,
|
315 |
+
fnames: list[str],
|
316 |
+
embeddings: list[np.ndarray],
|
317 |
+
frames: list[list[int]],
|
318 |
+
annotations: list,
|
319 |
+
test_fnames: None | list[str],
|
320 |
+
views: None | list[str],
|
321 |
+
conditions: None | list[str],
|
322 |
+
downsample_rate = 4):
|
323 |
+
"""
|
324 |
+
Creates a .csv file containing all of the generated embeddings and provived information.
|
325 |
+
|
326 |
+
Parameters:
|
327 |
+
-----------
|
328 |
+
out : str
|
329 |
+
The name of the resulting file.
|
330 |
+
fnames : list[str]
|
331 |
+
Video sources for each of the embedding arrays.
|
332 |
+
embeddings : np.ndarray
|
333 |
+
The generated embeddings from the images.
|
334 |
+
downsample_rate : int
|
335 |
+
The downsample_rate used for generating the embeddings.
|
336 |
+
"""
|
337 |
+
assert len(fnames) == len(embeddings)
|
338 |
+
assert len(embeddings) == len(frames)
|
339 |
+
all_embeddings = np.vstack(embeddings)
|
340 |
+
df = pd.DataFrame(all_embeddings)
|
341 |
+
|
342 |
+
labels = []
|
343 |
+
for i, uploaded_annots in enumerate(annotations):
|
344 |
+
print(i)
|
345 |
+
_, ext = os.path.splitext(uploaded_annots[0].name)
|
346 |
+
if ext == '.annot':
|
347 |
+
annot, _, _, sr = load_multiple_annotations_io(uploaded_annots)
|
348 |
+
annot_labels = generate_label_array(annot, downsample_rate, len(frames[i]))
|
349 |
+
elif ext == '.csv':
|
350 |
+
annot_df = pd.read_csv(uploaded_annots[0], header=None)
|
351 |
+
annot_labels = annot_df[0].to_list()[::downsample_rate]
|
352 |
+
assert len(annot_labels) == len(frames[i]), "There is a mismatch between the number of frames and number of labels. Make sure that the passed in csv file has no header."
|
353 |
+
else:
|
354 |
+
raise ValueError(f'Incompatible file for annotations used. Got a file of type "{ext}".')
|
355 |
+
assert len(annot_labels) == len(frames[i]), "There is a mismatch between the number of frames and number of labels. Make sure you have passed in the correct files."
|
356 |
+
print(annot_labels)
|
357 |
+
labels.append(annot_labels)
|
358 |
+
all_labels = np.hstack(labels)
|
359 |
+
print(len(all_labels))
|
360 |
+
df['Label'] = all_labels
|
361 |
+
|
362 |
+
all_frames = np.hstack(frames)
|
363 |
+
df['Frame'] = all_frames
|
364 |
+
sources = [[fname for _ in range(len(frames[i]))] for i, fname in enumerate(fnames)]
|
365 |
+
all_sources = np.hstack(sources)
|
366 |
+
df['Source'] = all_sources
|
367 |
+
|
368 |
+
if test_fnames:
|
369 |
+
t_split = lambda x: True if x in test_fnames else False
|
370 |
+
test = [[t_split(fname) for _ in range(len(frames[i]))] for i, fname in enumerate(fnames)]
|
371 |
+
else:
|
372 |
+
test = [[True for _ in range(len(frames[i]))] for i, _ in enumerate(fnames)]
|
373 |
+
all_test = np.hstack(test)
|
374 |
+
df['Test'] = all_test
|
375 |
+
|
376 |
+
if views:
|
377 |
+
view = [[views[i] for _ in range(len(frames[i]))] for i in range(len(fnames))]
|
378 |
+
else:
|
379 |
+
view = [[None for _ in range(len(frames[i]))] for i in range(len(fnames))]
|
380 |
+
all_view = np.hstack(view)
|
381 |
+
df['View'] = all_view
|
382 |
+
|
383 |
+
if conditions:
|
384 |
+
condition = [[conditions[i] for _ in range(len(frames[i]))] for i in range(len(fnames))]
|
385 |
+
else:
|
386 |
+
condition = [[None for _ in range(len(frames[i]))] for i in range(len(fnames))]
|
387 |
+
all_condition = np.hstack(condition)
|
388 |
+
df['Condition'] = all_condition
|
389 |
+
return df
|
390 |
+
|
391 |
+
def process_dataset_in_mem(embeddings_df: pd.DataFrame,
|
392 |
+
specified_classes=None,
|
393 |
+
classes_to_remove=None,
|
394 |
+
max_class_size=None,
|
395 |
+
animal_state=None,
|
396 |
+
view=None,
|
397 |
+
shuffle_data=False,
|
398 |
+
test_videos=None):
|
399 |
+
"""
|
400 |
+
Processes output generated from embeddings paired with images and behavior labels.
|
401 |
+
|
402 |
+
Parameters:
|
403 |
+
-----------
|
404 |
+
csv_path : str
|
405 |
+
Path to the file containing the original data. This should contain embeddings,
|
406 |
+
a column named `'Label'` and a column named `'Images'`.
|
407 |
+
specified_classes : None | list[str]
|
408 |
+
An optional input. Defines labels which should be kept as is in the `'Label'`
|
409 |
+
column and which should be changed to a default `other` label.
|
410 |
+
classes_to_remove : None | list[str]
|
411 |
+
An optional input. Drops rows from the dataframe which contain a label in the
|
412 |
+
list.
|
413 |
+
max_class_size : None | int
|
414 |
+
An optional input. Determines the maximum amount of rows a single label can
|
415 |
+
appear in for each unique label in the `'Label'` column.
|
416 |
+
animal_state : None | str
|
417 |
+
An optional input. Drops rows from the dataframe which do not contain a match
|
418 |
+
for `animal_state` in the text field within the `'Images'` column.
|
419 |
+
view : None | str
|
420 |
+
An optional input. Drops rows from the dataframe which do not contain a match
|
421 |
+
for `view` in the text field within the `'Images'` column.
|
422 |
+
shuffle_data : bool
|
423 |
+
Determines wether the dataframe should have its rows shuffled.
|
424 |
+
test_videos : None | list[str]
|
425 |
+
An optional input. Determines what rows should be in the `test` dataframe, and
|
426 |
+
which should be in the `train` dataframe. It drops rows from the respective
|
427 |
+
dataframe by keeping or dropping rows which do not contain a match for a `str`
|
428 |
+
in `test_videos` in the text field within the `'Images'` column, respectively.
|
429 |
+
|
430 |
+
Returns:
|
431 |
+
--------
|
432 |
+
balanced_train_embeddings : pandas.DataFrame
|
433 |
+
A processed dataframe whose rows contain the embeddings for each of the images
|
434 |
+
at the corresponding index within `balanced_train_images`.
|
435 |
+
balanced_train_labels : list[str]
|
436 |
+
A list of labels for each of the images at the corresponing index within
|
437 |
+
`balanced_train_images`.
|
438 |
+
balanced_train_images: list[str]
|
439 |
+
A list of paths to images with each image at an index corresponding to a label
|
440 |
+
with the same index in `balanced_train_labels` and the same row index within
|
441 |
+
`balanced_train_embeddings`.
|
442 |
+
test_embeddings : pandas.DataFrame
|
443 |
+
A processed dataframe whose rows contain the embeddings for each of the images
|
444 |
+
at the corresponding index within `test_images`.
|
445 |
+
test_labels : list[str]
|
446 |
+
A list of labels for each of the images at the corresponing index within
|
447 |
+
`test_images`.
|
448 |
+
test_images : list[str]
|
449 |
+
A list of paths to images with each image at an index corresponding to a label
|
450 |
+
with the same index in `test_labels` and the same row index within
|
451 |
+
`test_embeddings`.
|
452 |
+
"""
|
453 |
+
# Convert embeddings, labels, and images to a DataFrame for easy manipulation
|
454 |
+
df = copy.deepcopy(embeddings_df)
|
455 |
+
df_keys = [str(x) for x in df.keys()]
|
456 |
+
#Filter by fed or fasted
|
457 |
+
if 'Condition' in df_keys and animal_state:
|
458 |
+
df = df[df['Condition'].str.contains(animal_state, na=False)]
|
459 |
+
|
460 |
+
if 'View' in df_keys and view:
|
461 |
+
df = df[df['View'].str.contains(view, na=False)]
|
462 |
+
|
463 |
+
# Extract unique video names excluding the frame number
|
464 |
+
#unique_video_names = df['Images'].apply(lambda x: '_'.join(x.split('_')[:-1])).unique()
|
465 |
+
#print("\nUnique video names:\n", unique_video_names)
|
466 |
+
|
467 |
+
if classes_to_remove:
|
468 |
+
df = df[~df['Label'].str.contains('|'.join(classes_to_remove), na=False)]
|
469 |
+
elif classes_to_remove and 'all' in classes_to_remove:
|
470 |
+
df = df[df['Label'].str.contains('|'.join(classes_to_remove), na=False)]
|
471 |
+
|
472 |
+
# Further filter to include only specified_classes
|
473 |
+
if specified_classes:
|
474 |
+
single_match = lambda x: list(set(x.split('||')) & set(specified_classes))[0]
|
475 |
+
df['Label'] = df['Label'].apply(lambda x: single_match(x) if not set(x.split('||')).isdisjoint(specified_classes) else 'other')
|
476 |
+
specified_classes.append('other')
|
477 |
+
|
478 |
+
# Separate the DataFrame into test and training sets based on test_videos
|
479 |
+
if 'Test' in df_keys and test_videos:
|
480 |
+
test_df = df[df['Test']]
|
481 |
+
train_df = df[~df['Test']]
|
482 |
+
elif test_videos:
|
483 |
+
test_df = df[df['Images'].str.contains('|'.join(test_videos), na=False)]
|
484 |
+
train_df = df[~df['Images'].str.contains('|'.join(test_videos), na=False)]
|
485 |
+
else:
|
486 |
+
test_df = pd.DataFrame(columns=df.columns)
|
487 |
+
train_df = df
|
488 |
+
|
489 |
+
# Print the number of frames in each class before balancing
|
490 |
+
label_counts = train_df['Label'].value_counts()
|
491 |
+
print("\nNumber of training frames in each class before balancing:")
|
492 |
+
print(label_counts)
|
493 |
+
|
494 |
+
if max_class_size:
|
495 |
+
balanced_train_df = pd.concat([
|
496 |
+
group.sample(n=min(len(group), max_class_size), random_state=1)
|
497 |
+
for label, group in train_df.groupby('Label')
|
498 |
+
])
|
499 |
+
else:
|
500 |
+
balanced_train_df = train_df
|
501 |
+
|
502 |
+
# Shuffle the training DataFrame
|
503 |
+
if shuffle_data:
|
504 |
+
balanced_train_df = balanced_train_df.sample(frac=1).reset_index(drop=True)
|
505 |
+
|
506 |
+
# Convert training set back to numpy array and list
|
507 |
+
if not "Images" in df_keys:
|
508 |
+
balanced_train_embeddings = balanced_train_df.drop(columns=['Label', 'Frame', 'Source', 'Test','View','Condition']).to_numpy()
|
509 |
+
balanced_train_labels = balanced_train_df['Label'].tolist()
|
510 |
+
balanced_train_images = balanced_train_df['Frame'].tolist()
|
511 |
+
|
512 |
+
# Convert test set back to numpy array and list
|
513 |
+
test_embeddings = test_df.drop(columns=['Label', 'Frame', 'Source', 'Test','View','Condition']).to_numpy()
|
514 |
+
test_labels = test_df['Label'].tolist()
|
515 |
+
test_images = test_df['Frame'].tolist()
|
516 |
+
else:
|
517 |
+
# Convert training set back to numpy array and list
|
518 |
+
balanced_train_embeddings = balanced_train_df.drop(columns=['Label', 'Images']).to_numpy()
|
519 |
+
balanced_train_labels = balanced_train_df['Label'].tolist()
|
520 |
+
balanced_train_images = balanced_train_df['Images'].tolist()
|
521 |
+
|
522 |
+
# Convert test set back to numpy array and list
|
523 |
+
test_embeddings = test_df.drop(columns=['Label', 'Images']).to_numpy()
|
524 |
+
test_labels = test_df['Label'].tolist()
|
525 |
+
test_images = test_df['Images'].tolist()
|
526 |
+
|
527 |
+
# Print the number of frames in each class after balancing
|
528 |
+
if specified_classes or max_class_size:
|
529 |
+
balanced_label_counts = Counter(balanced_train_labels)
|
530 |
+
print("\nNumber of training frames in each class after balancing:")
|
531 |
+
print(balanced_label_counts)
|
532 |
+
|
533 |
+
test_label_counts = test_df['Label'].value_counts()
|
534 |
+
# print("\nNumber of testing frames in each class:")
|
535 |
+
print(test_label_counts)
|
536 |
+
|
537 |
+
return balanced_train_embeddings, balanced_train_labels, balanced_train_images, test_embeddings, test_labels, test_images
|
538 |
+
|
539 |
+
def multiclass_merge_and_filter_bouts(multiclass_vector, bout_threshold, proximity_threshold):
|
540 |
+
# Get the unique labels in the multiclass vector (excluding zero, assuming zero is the background/no label)
|
541 |
+
unique_labels = np.unique(multiclass_vector)
|
542 |
+
unique_labels = unique_labels[unique_labels != 0]
|
543 |
+
|
544 |
+
# Initialize a vector to store the merged and filtered multiclass vector
|
545 |
+
merged_vector = np.zeros_like(multiclass_vector)
|
546 |
+
|
547 |
+
for label in unique_labels:
|
548 |
+
# Create a binary vector for the current label
|
549 |
+
binary_vector = (multiclass_vector == label)
|
550 |
+
|
551 |
+
# Find the start and end indices of all sequences of 1's for this label
|
552 |
+
starts = np.where(np.diff(np.concatenate(([0], binary_vector))) == 1)[0]
|
553 |
+
ends = np.where(np.diff(np.concatenate((binary_vector, [0]))) == -1)[0]
|
554 |
+
|
555 |
+
# Step 1: Merge close short bouts
|
556 |
+
i = 0
|
557 |
+
while i < len(starts) - 1:
|
558 |
+
# Check if the gap between the end of the current bout and the start of the next bout
|
559 |
+
# is within the proximity threshold
|
560 |
+
if starts[i + 1] - ends[i] <= proximity_threshold:
|
561 |
+
# Merge the two bouts by setting all elements between the start of the first
|
562 |
+
# and the end of the second bout to 1
|
563 |
+
binary_vector[ends[i]:starts[i + 1]] = 1
|
564 |
+
# Remove the next bout from consideration
|
565 |
+
starts = np.delete(starts, i + 1)
|
566 |
+
ends = np.delete(ends, i)
|
567 |
+
else:
|
568 |
+
i += 1
|
569 |
+
|
570 |
+
# Update the starts and ends after merging
|
571 |
+
starts = np.where(np.diff(np.concatenate(([0], binary_vector))) == 1)[0]
|
572 |
+
ends = np.where(np.diff(np.concatenate((binary_vector, [0]))) == -1)[0]
|
573 |
+
|
574 |
+
# Step 2: Remove standalone short bouts
|
575 |
+
for i in range(len(starts)):
|
576 |
+
# Check the length of the bout
|
577 |
+
length_of_bout = ends[i] - starts[i] + 1
|
578 |
+
|
579 |
+
# If the length is less than the threshold, set those elements to 0
|
580 |
+
if length_of_bout < bout_threshold:
|
581 |
+
binary_vector[starts[i]:ends[i] + 1] = 0
|
582 |
+
|
583 |
+
# Combine the binary vector with the merged_vector, ensuring only the current label is set
|
584 |
+
merged_vector[binary_vector] = label
|
585 |
+
|
586 |
+
# Return the filtered multiclass vector
|
587 |
+
return merged_vector
|
588 |
+
|
589 |
+
def get_unique_labels(label_list: list[str]):
|
590 |
+
label_set = set()
|
591 |
+
for label in label_list:
|
592 |
+
individual_labels = label.split('||')
|
593 |
+
for individual_label in individual_labels:
|
594 |
+
label_set.add(individual_label)
|
595 |
+
return list(label_set)
|
596 |
+
|
597 |
+
def get_train_test_split(train_embeds, numerical_labels, test_size=0.05, random_state=42):
|
598 |
+
return train_test_split(train_embeds, numerical_labels, test_size=test_size, random_state=random_state)
|
599 |
+
|
600 |
+
def train_model(X_train, y_train, random_state=42):
|
601 |
+
# Train SVM Classifier
|
602 |
+
svm_clf = SVC(kernel='rbf', random_state=random_state, probability=True)
|
603 |
+
svm_clf.fit(X_train, y_train)
|
604 |
+
return svm_clf
|
605 |
+
|
606 |
+
def pickle_model(model):
|
607 |
+
pickled = io.BytesIO()
|
608 |
+
pickle.dump(model, pickled)
|
609 |
+
return pickled
|
610 |
+
|
611 |
+
def get_seq_io_reader(uploaded_file):
|
612 |
+
assert uploaded_file.name[-3:]=='seq', 'Not a seq file'
|
613 |
+
with NamedTemporaryFile(suffix="seq", delete=False) as temp:
|
614 |
+
temp.write(uploaded_file.getvalue())
|
615 |
+
sr = seqIo_reader(temp.name)
|
616 |
+
return sr
|
617 |
+
|
618 |
+
def seq_to_arr(sr):
|
619 |
+
N = sr.header['numFrames']
|
620 |
+
images = []
|
621 |
+
for f in range(N):
|
622 |
+
I, ts = sr.getFrame(f)
|
623 |
+
images.append(I)
|
624 |
+
return np.array(images)
|
625 |
+
|
626 |
+
def get_2d_embedding(embeddings: pd.DataFrame):
|
627 |
+
tsne = TSNE(n_jobs=4, n_components=2, random_state=42, perplexity=50)
|
628 |
+
embedding_2d = tsne.fit_transform(np.array(embeddings))
|
629 |
+
return embedding_2d
|
630 |
+
|
631 |
+
|
632 |
+
|