adityanairneuro committed
Commit 92a7fcc · verified · 1 Parent(s): ac9e1fa

Update explore.py

Files changed (1): explore.py (+337 -337)
explore.py CHANGED
@@ -1,337 +1,337 @@
import streamlit as st
import plotly.express as px
import numpy as np
import pandas as pd
import torch
from utils.mp4Io import mp4Io_reader
from utils.seqIo import seqIo_reader
from PIL import Image
from pathlib import Path
from transformers import AutoProcessor, AutoModel
from tempfile import NamedTemporaryFile
from tqdm import tqdm
from utils.utils import create_embeddings_csv_io, process_dataset_in_mem, generate_embeddings_stream_io
from get_llava_response import get_llava_response, load_llava_checkpoint_hf
from sklearn.manifold import TSNE
from openai import OpenAI
import cv2
import base64
from hdbscan import HDBSCAN, all_points_membership_vectors
import random

# --server.maxUploadSize 3000
REPO_NAME = 'ncoria/llava-lora-vicuna-clip-5-epochs-merge'

def load_llava_model(hf_token):
    return load_llava_checkpoint_hf(REPO_NAME, hf_token)

def get_unique_labels(label_list: list[str]):
    label_set = set()
    for label in label_list:
        individual_labels = label.split('||')
        for individual_label in individual_labels:
            label_set.add(individual_label)
    return list(label_set)

SYSTEM_PROMPT = """You are a researcher studying mice interactions from videos of the inside of a resident
intruder box where there is either just the resident mouse (the black one) or the resident and the intruder mouse (the white one).
Your job is to answer questions about the behavior of the mice in the image given the context that each image is a frame of a continuous video.
Thus, you should use the visual information about the mice in the image to try to provide a detailed behavioral description of the image."""

@st.cache_resource
def get_io_reader(uploaded_file):
    if uploaded_file.name[-3:] == 'seq':
        with NamedTemporaryFile(suffix="seq", delete=False) as temp:
            temp.write(uploaded_file.getvalue())
        sr = seqIo_reader(temp.name)
    else:
        with NamedTemporaryFile(suffix="mp4", delete=False) as temp:
            temp.write(uploaded_file.getvalue())
        sr = mp4Io_reader(temp.name)
    return sr

def get_image(sr, frame_no: int):
    image, _ = sr.getFrame(frame_no)
    return image

@st.cache_data
def get_2d_embedding(embeddings: pd.DataFrame):
    tsne = TSNE(n_jobs=4, n_components=2, random_state=42, perplexity=50)
    embedding_2d = tsne.fit_transform(np.array(embeddings))
    return embedding_2d

HDBSCAN_PARAMS = {
    'min_samples': 1
}

@st.cache_data
def hdbscan_classification(umap_embeddings, embeddings_2d, cluster_range):
    # sweep min_cluster_size (as a percentage of the dataset) and keep the
    # clustering that yields the most clusters
    max_num_clusters = -np.inf
    num_clusters = []
    min_cluster_size = np.linspace(cluster_range[0], cluster_range[1], 4)
    for min_c in min_cluster_size:
        learned_hierarchy = HDBSCAN(
            prediction_data=True, min_cluster_size=int(round(min_c * 0.01 * umap_embeddings.shape[0])),
            cluster_selection_method='leaf',
            **HDBSCAN_PARAMS).fit(umap_embeddings)
        num_clusters.append(len(np.unique(learned_hierarchy.labels_)))
        if num_clusters[-1] > max_num_clusters:
            max_num_clusters = num_clusters[-1]
            retained_hierarchy = learned_hierarchy
    assignments = retained_hierarchy.labels_
    assign_prob = all_points_membership_vectors(retained_hierarchy)
    soft_assignments = np.argmax(assign_prob, axis=1)
    retained_hierarchy.fit(embeddings_2d)
    return retained_hierarchy, assignments, assign_prob, soft_assignments

def upload_image(frame: np.ndarray):
    """Encodes a frame as a base64 PNG string for the OpenAI image payload."""
    _, encoded_image = cv2.imencode('.png', frame)
    return base64.b64encode(encoded_image.tobytes()).decode('utf-8')

def ask_question_with_image_gpt(file_id, system_prompt, question, api_key):
    """Asks a question about the uploaded image."""
    client = OpenAI(api_key=api_key)

    if file_id is not None:
        response = client.chat.completions.create(
            model="gpt-4o",
            messages=[
                {"role": "system", "content": system_prompt},
                {"role": "user", "content": [
                    {"type": "text", "text": question},
                    {"type": "image_url", "image_url": {"url": f"data:image/png;base64,{file_id}"}}]
                }
            ]
        )
    else:
        response = client.chat.completions.create(
            model="gpt-4o",
            messages=[
                {"role": "system", "content": system_prompt},
                {"role": "user", "content": question}
            ]
        )
    return response.choices[0].message.content

def ask_question_with_image_llava(image, system_prompt, question,
                                  tokenizer, model, image_processor):
    outputs = get_llava_response([question],
                                 [image],
                                 system_prompt,
                                 tokenizer,
                                 model,
                                 image_processor,
                                 REPO_NAME,
                                 stream_output=False)
    return outputs[0]

def ask_summary_question(image_array, label_array, api_key):
    # load llava model (uses the global hf_token entered in the UI below)
    tokenizer, model, image_processor = load_llava_model(hf_token)

    # global variable
    system_prompt = SYSTEM_PROMPT

    # collect responses
    responses = []

    # create progress bar
    j = 0
    pbar_text = lambda j: f'Creating llava response {j}/{len(label_array)}.'
    pbar = st.progress(0, text=pbar_text(0))

    for i, image in enumerate(image_array):
        label = label_array[i]
        question = f"The frame is annotated by a human observer with the label: {label}. Give evidence for this label using the posture of the mice and their current behavior. "
        question += "Also, designate a behavioral subtype of the given label that describes the current social interaction based on what you see about the posture of the mice and "\
                    "how they are positioned with respect to each other. Usually, the body parts (i.e., tail, genitals, face, body, ears, paws) "\
                    "of the mice that are closest to each other will give some clue. Please limit behavioral subtype to a 1-4 word phrase. Limit your response to 4 sentences."
        response = ask_question_with_image_llava(image, system_prompt, question,
                                                 tokenizer, model, image_processor)
        responses.append(response)
        # update progress bar
        j += 1
        pbar.progress(j/len(label_array), pbar_text(j))

    system_prompt_summarize = "You are a researcher studying mice interactions from videos of the inside of a resident "\
        "intruder box where there is either just the resident mouse (the black one) or the resident and the intruder mouse (the white one). "\
        "You will be given a question about a list of descriptions from frames of these videos. "\
        "Your job is to answer the question by focusing on the behaviors of the mice and their postures "\
        "as well as any other aspects of the descriptions that may be relevant to the class label associated with them."
    user_prompt_summarize = "Here are several descriptions of individual frames from a mouse behavior video. Please summarize these descriptions and provide a suggestion for a "\
        "behavior label which captures what is described in the descriptions: \n\n"
    user_prompt_summarize = user_prompt_summarize + '\n'.join(responses)
    summary_response = ask_question_with_image_gpt(None, system_prompt_summarize, user_prompt_summarize, api_key)
    return summary_response

if "embeddings_df" not in st.session_state:
    st.session_state.embeddings_df = None

- st.title('batik: frame classifier')
+ st.title('batik: behavior discovery and LLM-based interpretation')

api_key = st.text_input("OpenAI API Key:","")
hf_token = st.text_input("HuggingFace Token:","")
st.subheader("generate or import embeddings")

st.text("Upload files to generate embeddings.")
with st.form('embedding_generation_settings'):
    seq_file = st.file_uploader("Choose a video file", type=['seq', 'mp4'])
    annot_files = st.file_uploader("Choose an annotation File", type=['annot','csv'], accept_multiple_files=True)
    downsample_rate = st.number_input('Downsample Rate',value=4)
    submit_embed_settings = st.form_submit_button('Create Embeddings', type='secondary')

st.markdown("**(Optional)** Upload embeddings.")
embeddings_csv = st.file_uploader("Choose a .csv File", type=['csv'])

if submit_embed_settings and seq_file is not None and annot_files is not None:
    video_embeddings, video_frames = generate_embeddings_stream_io([seq_file],
                                                                   "SLIP",
                                                                   downsample_rate,
                                                                   False)

    fnames = [seq_file.name]
    embeddings_df = create_embeddings_csv_io(out="file",
                                             fnames=fnames,
                                             embeddings=video_embeddings,
                                             frames=video_frames,
                                             annotations=[annot_files],
                                             test_fnames=None,
                                             views=None,
                                             conditions=None,
                                             downsample_rate=downsample_rate)
    st.session_state.embeddings_df = embeddings_df
elif embeddings_csv is not None:
    embeddings_df = pd.read_csv(embeddings_csv)
    st.session_state.embeddings_df = embeddings_df
else:
    st.text('Please upload file(s).')

st.divider()
st.subheader("provide video file if not yet already provided")

uploaded_file = st.file_uploader("Choose a video file", type=['seq', 'mp4'])

st.divider()
if st.session_state.embeddings_df is not None and (uploaded_file is not None or seq_file is not None):
    if seq_file is not None:
        uploaded_file = seq_file
    io_reader = get_io_reader(uploaded_file)
    print("CONVERTED SEQ")
    label_list = st.session_state.embeddings_df['Label'].to_list()
    unique_label_list = get_unique_labels(label_list)
    print(f"unique_labels: {unique_label_list}")
    #unique_label_list = ['check_genital', 'wiggle', 'lordose', 'stay', 'turn', 'top_up', 'dart', 'sniff', 'approach', 'into_male_cage']
    #unique_label_list = ['into_male_cage', 'intromission', 'male_sniff', 'mount']
    kwargs = {'embeddings_df' : st.session_state.embeddings_df,
              'specified_classes' : unique_label_list,
              'classes_to_remove' : None,
              'max_class_size' : None,
              'animal_state' : None,
              'view' : None,
              'shuffle_data' : False,
              'test_videos' : None}
    train_embeds, train_labels, train_images, _, _, _ = process_dataset_in_mem(**kwargs)
    print("PROCESSED DATASET")
    if "Images" in st.session_state.embeddings_df.keys():
        train_images = [i for i in range(len(train_images))]
    embedding_2d = get_2d_embedding(train_embeds)
else:
    st.text('Please generate embeddings and provide video file.')
print("GOT 2D EMBEDS")

if uploaded_file is not None and st.session_state.embeddings_df is not None:
    st.subheader("t-SNE Projection")
    option = st.selectbox(
        "Select Color Option",
        ("By Label", "By Time", "By Cluster")
    )
    if embedding_2d is not None:
        if option is not None:
            if option == "By Label":
                color = 'label'
            elif option == "By Time":
                color = 'frame_no'
            else:
                color = 'cluster_label'

            if option in ["By Label", "By Time"]:
                edf = pd.DataFrame(embedding_2d, columns=['tsne_dim_1', 'tsne_dim_2'])
                edf.insert(2, 'frame_no', np.array([int(x) for x in train_images]))
                edf.insert(3, 'label', train_labels)
                fig = px.scatter(
                    edf,
                    x="tsne_dim_1",
                    y="tsne_dim_2",
                    color=color,
                    hover_data=["frame_no"],
                    color_discrete_sequence=px.colors.qualitative.Dark24
                )
            else:
                r, _, _, _ = hdbscan_classification(train_embeds, embedding_2d, [4, 6])
                edf = pd.DataFrame(embedding_2d, columns=['tsne_dim_1', 'tsne_dim_2'])
                edf.insert(2, 'frame_no', np.array([int(x) for x in train_images]))
                edf.insert(3, 'label', train_labels)
                edf.insert(4, 'cluster_label', [str(c_id) for c_id in r.labels_.tolist()])
                fig = px.scatter(
                    edf,
                    x="tsne_dim_1",
                    y="tsne_dim_2",
                    color=color,
                    hover_data=["frame_no"],
                    color_discrete_sequence=px.colors.qualitative.Dark24
                )

            event = st.plotly_chart(fig, key="df", on_select="rerun")
        else:
            st.text("No Color Option Selected")
    else:
        st.text('No Embeddings Loaded')

    event_dict = event.selection

    if event_dict is not None:
        custom_data = []
        for point in event_dict['points']:
            data = point["customdata"][0]
            custom_data.append(int(data))

        if len(custom_data) > 10:
            custom_data = random.sample(custom_data, 10)
        if len(custom_data) > 1:
            col_1, col_2 = st.columns(2)
            with col_1:
                for frame_no in custom_data[::2]:
                    st.image(get_image(io_reader, frame_no))
                    st.caption(f"Frame {frame_no}, {train_labels[frame_no]}")
            with col_2:
                for frame_no in custom_data[1::2]:
                    st.image(get_image(io_reader, frame_no))
                    st.caption(f"Frame {frame_no}, {train_labels[frame_no]}")
        elif len(custom_data) == 1:
            frame_no = custom_data[0]
            st.image(get_image(io_reader, frame_no))
            st.caption(f"Frame {frame_no}, {train_labels[frame_no]}")
        else:
            st.text('No Points Selected')

        if len(custom_data) == 1:
            frame_no = custom_data[0]
            image = get_image(io_reader, frame_no)
            system_prompt = SYSTEM_PROMPT
            label = train_labels[frame_no]
            question = f"The frame is annotated by a human observer with the label: {label}. Give evidence for this label using the posture of the mice and their current behavior. "\
                "Also, designate a behavioral subtype of the given label that describes the current social interaction based on what you see about the posture of the mice and "\
                "how they are positioned with respect to each other. Usually, the body parts (i.e., tail, genitals, face, body, ears, paws) "\
                "of the mice that are closest to each other will give some clue. Please limit behavioral subtype to a 1-4 word phrase. Limit your response to 4 sentences."
            tokenizer, model, image_processor = load_llava_model(hf_token)
            response = ask_question_with_image_llava(image, system_prompt, question,
                                                     tokenizer, model, image_processor)
            st.markdown(response)

        elif len(custom_data) > 1:
            image_array = [get_image(io_reader, f_no) for f_no in custom_data]
            label_array = [train_labels[f_no] for f_no in custom_data]
            response = ask_summary_question(image_array, label_array, api_key)
            st.markdown(response)
 
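The embedding and clustering steps in this file can also be exercised outside the Streamlit app, which is handy for checking the t-SNE projection and the HDBSCAN sweep on saved embeddings. The snippet below is a minimal sketch, not part of this commit: it assumes an embeddings matrix saved as embeddings.npy (an illustrative filename) and mirrors the calls made by get_2d_embedding and hdbscan_classification in explore.py.

import numpy as np
from sklearn.manifold import TSNE
from hdbscan import HDBSCAN, all_points_membership_vectors

# Illustrative input: an (n_frames, embed_dim) array saved ahead of time.
embeds = np.load("embeddings.npy")

# 2D projection with the same settings as get_2d_embedding.
embedding_2d = TSNE(n_components=2, random_state=42, perplexity=50).fit_transform(embeds)

# Same sweep as hdbscan_classification: min_cluster_size is a percentage of the
# dataset size, and the clustering that yields the most clusters is retained.
retained, max_n = None, -np.inf
for pct in np.linspace(4, 6, 4):
    h = HDBSCAN(prediction_data=True,
                min_cluster_size=int(round(pct * 0.01 * embeds.shape[0])),
                min_samples=1,
                cluster_selection_method='leaf').fit(embeds)
    n = len(np.unique(h.labels_))
    if n > max_n:
        retained, max_n = h, n

# Soft cluster assignments, as used for the "By Cluster" coloring in the app.
soft_assignments = np.argmax(all_points_membership_vectors(retained), axis=1)
print(f"{max_n} clusters; first five soft assignments: {soft_assignments[:5]}")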