import streamlit as st
import plotly.express as px
import numpy as np
import pandas as pd
import torch
from utils.mp4Io import mp4Io_reader
from utils.seqIo import seqIo_reader
from PIL import Image
from pathlib import Path
from transformers import AutoProcessor, AutoModel
from tempfile import NamedTemporaryFile
from tqdm import tqdm
from utils.utils import create_embeddings_csv_io, process_dataset_in_mem, generate_embeddings_stream_io
from get_llava_response import get_llava_response, load_llava_checkpoint_hf
from sklearn.manifold import TSNE
from openai import OpenAI
import cv2
import base64
from hdbscan import HDBSCAN, all_points_membership_vectors
import random

# run with: --server.maxUploadSize 3000

REPO_NAME = 'ncoria/llava-lora-vicuna-clip-5-epochs-merge'

def load_llava_model():
    return load_llava_checkpoint_hf(REPO_NAME)

def get_unique_labels(label_list: list[str]):
    """Split '||'-delimited annotation strings and return the unique labels."""
    label_set = set()
    for label in label_list:
        individual_labels = label.split('||')
        for individual_label in individual_labels:
            label_set.add(individual_label)
    return list(label_set)
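
# Example (hypothetical labels): get_unique_labels(['sniff||approach', 'mount'])
# would return ['sniff', 'approach', 'mount'] (order not guaranteed, since a set is used).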

SYSTEM_PROMPT = """You are a researcher studying mice interactions from videos of the inside of a resident
intruder box where there is either just the resident mouse (the black one) or the resident and the intruder mouse (the white one).
Your job is to answer questions about the behavior of the mice in the image given the context that each image is a frame of a continuous video.
Thus, you should use the visual information about the mice in the image to try to provide a detailed behavioral description of the image."""

def get_io_reader(uploaded_file):
    """Write the uploaded video to a named temp file and open the matching reader."""
    if uploaded_file.name[-3:] == 'seq':
        with NamedTemporaryFile(suffix=".seq", delete=False) as temp:
            temp.write(uploaded_file.getvalue())
        sr = seqIo_reader(temp.name)
    else:
        with NamedTemporaryFile(suffix=".mp4", delete=False) as temp:
            temp.write(uploaded_file.getvalue())
        sr = mp4Io_reader(temp.name)
    return sr

def get_image(sr, frame_no: int):
    image, _ = sr.getFrame(frame_no)
    return image

def get_2d_embedding(embeddings: pd.DataFrame):
    tsne = TSNE(n_jobs=4, n_components=2, random_state=42, perplexity=50)
    embedding_2d = tsne.fit_transform(np.array(embeddings))
    return embedding_2d
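
# Note: scikit-learn requires perplexity < n_samples, so this setting assumes
# more than 50 embedded frames; smaller datasets would need a lower perplexity.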

HDBSCAN_PARAMS = {
    'min_samples': 1
}

def hdbscan_classification(umap_embeddings, embeddings_2d, cluster_range):
    max_num_clusters = -np.inf
    num_clusters = []
    # sweep min_cluster_size over the given range (expressed as % of the dataset)
    min_cluster_size = np.linspace(cluster_range[0], cluster_range[1], 4)
    for min_c in min_cluster_size:
        learned_hierarchy = HDBSCAN(
            prediction_data=True, min_cluster_size=int(round(min_c * 0.01 * umap_embeddings.shape[0])),
            cluster_selection_method='leaf',
            **HDBSCAN_PARAMS).fit(umap_embeddings)
        num_clusters.append(len(np.unique(learned_hierarchy.labels_)))
        # retain the fit that yields the most clusters
        if num_clusters[-1] > max_num_clusters:
            max_num_clusters = num_clusters[-1]
            retained_hierarchy = learned_hierarchy
    assignments = retained_hierarchy.labels_
    assign_prob = all_points_membership_vectors(retained_hierarchy)
    soft_assignments = np.argmax(assign_prob, axis=1)
    # re-fit on the 2-D projection (hard/soft assignments above come from the high-dim fit)
    retained_hierarchy.fit(embeddings_2d)
    return retained_hierarchy, assignments, assign_prob, soft_assignments
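
# Example (hypothetical shapes): with embeddings `X` of shape (N, D) and a 2-D
# projection `X_2d`, this sweeps min_cluster_size from 4% to 6% of N:
#   r, hard, prob, soft = hdbscan_classification(X, X_2d, [4, 6])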

def upload_image(frame: np.ndarray):
    """Encode a frame as PNG and return it as a base64 string."""
    _, encoded_image = cv2.imencode('.png', frame)
    return base64.b64encode(encoded_image.tobytes()).decode('utf-8')

def ask_question_with_image_gpt(file_id, system_prompt, question, api_key):
    """Asks a question about the (optionally) attached base64-encoded image."""
    client = OpenAI(api_key=api_key)
    if file_id is not None:
        response = client.chat.completions.create(
            model="gpt-4o",
            messages=[
                {"role": "system", "content": system_prompt},
                {"role": "user", "content": [
                    {"type": "text", "text": question},
                    {"type": "image_url", "image_url": {"url": f"data:image/png;base64,{file_id}"}}]
                }
            ]
        )
    else:
        response = client.chat.completions.create(
            model="gpt-4o",
            messages=[
                {"role": "system", "content": system_prompt},
                {"role": "user", "content": question}
            ]
        )
    return response.choices[0].message.content
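
# Example (hypothetical frame): attach one encoded frame to the question:
#   b64 = upload_image(frame)
#   answer = ask_question_with_image_gpt(b64, SYSTEM_PROMPT, "What are the mice doing?", api_key)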

def ask_question_with_image_llava(image, system_prompt, question,
                                  tokenizer, model, image_processor):
    outputs = get_llava_response([question],
                                 [image],
                                 system_prompt,
                                 tokenizer,
                                 model,
                                 image_processor,
                                 REPO_NAME,
                                 stream_output=False)
    return outputs[0]

def ask_summary_question(image_array, label_array, api_key):
    # load llava model
    with st.spinner("Loading LLaVA model. This can take 10 to 30 minutes. Please wait..."):
        tokenizer, model, image_processor = load_llava_model()

    # global variable
    system_prompt = SYSTEM_PROMPT

    # collect responses
    responses = []

    # create progress bar
    j = 0
    pbar_text = lambda j: f'Creating LLaVA response {j}/{len(label_array)}.'
    pbar = st.progress(0, text=pbar_text(0))

    for i, image in enumerate(image_array):
        label = label_array[i]
        question = f"The frame is annotated by a human observer with the label: {label}. Give evidence for this label using the posture of the mice and their current behavior. "
        question += "Also, designate a behavioral subtype of the given label that describes the current social interaction based on what you see about the posture of the mice and "\
                    "how they are positioned with respect to each other. Usually, the body parts (i.e., tail, genitals, face, body, ears, paws) "\
                    "of the mice that are closest to each other will give some clue. Please limit the behavioral subtype to a 1-4 word phrase and limit your response to 4 sentences."
        response = ask_question_with_image_llava(image, system_prompt, question,
                                                 tokenizer, model, image_processor)
        responses.append(response)

        # update progress bar
        j += 1
        pbar.progress(j/len(label_array), pbar_text(j))

    system_prompt_summarize = "You are a researcher studying mice interactions from videos of the inside of a resident "\
        "intruder box where there is either just the resident mouse (the black one) or the resident and the intruder mouse (the white one). "\
        "You will be given a question about a list of descriptions from frames of these videos. "\
        "Your job is to answer the question by focusing on the behaviors of the mice and their postures "\
        "as well as any other aspects of the descriptions that may be relevant to the class label associated with them."

    user_prompt_summarize = "Here are several descriptions of individual frames from a mouse behavior video. Please summarize these descriptions and provide a suggestion for a "\
        "behavior label which captures what is described in the descriptions:\n\n"
    user_prompt_summarize = user_prompt_summarize + '\n'.join(responses)

    summary_response = ask_question_with_image_gpt(None, system_prompt_summarize, user_prompt_summarize, api_key)
    return summary_response
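
# Sketch of the two-stage flow (hypothetical inputs): LLaVA first describes each
# selected frame, then GPT-4o condenses the descriptions into one suggested label:
#   summary = ask_summary_question([frame_a, frame_b], ['sniff', 'mount'], api_key)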
if "embeddings_df" not in st.session_state: | |
st.session_state.embeddings_df = None | |
st.title('batik: behavior discovery and LLM-based interpretation') | |
api_key = st.text_input("OpenAI API Key:","") | |
st.subheader("generate or import embeddings") | |
st.text("Upload files to generate embeddings.") | |
with st.form('embedding_generation_settings'): | |
seq_file = st.file_uploader("Choose a video file", type=['seq', 'mp4']) | |
annot_files = st.file_uploader("Choose an annotation File", type=['annot','csv'], accept_multiple_files=True) | |
downsample_rate = st.number_input('Downsample Rate',value=4) | |
submit_embed_settings = st.form_submit_button('Create Embeddings', type='secondary') | |
st.markdown("**(Optional)** Upload embeddings.") | |
embeddings_csv = st.file_uploader("Choose a .csv File", type=['csv']) | |

if submit_embed_settings and seq_file is not None and annot_files:
    video_embeddings, video_frames = generate_embeddings_stream_io([seq_file],
                                                                   "SLIP",
                                                                   downsample_rate,
                                                                   False)
    fnames = [seq_file.name]
    embeddings_df = create_embeddings_csv_io(out="file",
                                             fnames=fnames,
                                             embeddings=video_embeddings,
                                             frames=video_frames,
                                             annotations=[annot_files],
                                             test_fnames=None,
                                             views=None,
                                             conditions=None,
                                             downsample_rate=downsample_rate)
    st.session_state.embeddings_df = embeddings_df
elif embeddings_csv is not None:
    embeddings_df = pd.read_csv(embeddings_csv)
    st.session_state.embeddings_df = embeddings_df
else:
    st.text('Please upload file(s).')

st.divider()
st.subheader("provide video file if not already provided")
uploaded_file = st.file_uploader("Choose a video file", type=['seq', 'mp4'])
st.divider()

embedding_2d = None
if st.session_state.embeddings_df is not None and (uploaded_file is not None or seq_file is not None):
    if seq_file is not None:
        uploaded_file = seq_file
    io_reader = get_io_reader(uploaded_file)
    print("CONVERTED SEQ")
    label_list = st.session_state.embeddings_df['Label'].to_list()
    unique_label_list = get_unique_labels(label_list)
    print(f"unique_labels: {unique_label_list}")
    #unique_label_list = ['check_genital', 'wiggle', 'lordose', 'stay', 'turn', 'top_up', 'dart', 'sniff', 'approach', 'into_male_cage']
    #unique_label_list = ['into_male_cage', 'intromission', 'male_sniff', 'mount']
    kwargs = {'embeddings_df': st.session_state.embeddings_df,
              'specified_classes': unique_label_list,
              'classes_to_remove': None,
              'max_class_size': None,
              'animal_state': None,
              'view': None,
              'shuffle_data': False,
              'test_videos': None}
    train_embeds, train_labels, train_images, _, _, _ = process_dataset_in_mem(**kwargs)
    print("PROCESSED DATASET")
    # when images are stored in the dataframe, frame numbers are just row indices
    if "Images" in st.session_state.embeddings_df.keys():
        train_images = [i for i in range(len(train_images))]
    embedding_2d = get_2d_embedding(train_embeds)
    print("GOT 2D EMBEDS")
else:
    st.text('Please generate embeddings and provide video file.')

if uploaded_file is not None and st.session_state.embeddings_df is not None:
    st.subheader("t-SNE Projection")
    option = st.selectbox(
        "Select Color Option",
        ("By Label", "By Time", "By Cluster")
    )
    event = None
    if embedding_2d is not None:
        if option is not None:
            if option == "By Label":
                color = 'label'
            elif option == "By Time":
                color = 'frame_no'
            else:
                color = 'cluster_label'
            if option in ["By Label", "By Time"]:
                edf = pd.DataFrame(embedding_2d, columns=['tsne_dim_1', 'tsne_dim_2'])
                edf.insert(2, 'frame_no', np.array([int(x) for x in train_images]))
                edf.insert(3, 'label', train_labels)
                fig = px.scatter(
                    edf,
                    x="tsne_dim_1",
                    y="tsne_dim_2",
                    color=color,
                    hover_data=["frame_no"],
                    color_discrete_sequence=px.colors.qualitative.Dark24
                )
            else:
                r, _, _, _ = hdbscan_classification(train_embeds, embedding_2d, [4, 6])
                edf = pd.DataFrame(embedding_2d, columns=['tsne_dim_1', 'tsne_dim_2'])
                edf.insert(2, 'frame_no', np.array([int(x) for x in train_images]))
                edf.insert(3, 'label', train_labels)
                edf.insert(4, 'cluster_label', [str(c_id) for c_id in r.labels_.tolist()])
                fig = px.scatter(
                    edf,
                    x="tsne_dim_1",
                    y="tsne_dim_2",
                    color=color,
                    hover_data=["frame_no"],
                    color_discrete_sequence=px.colors.qualitative.Dark24
                )
            event = st.plotly_chart(fig, key="df", on_select="rerun")
        else:
            st.text("No Color Option Selected")
    else:
        st.text('No Embeddings Loaded')
    event_dict = event.selection if event is not None else None
    if event_dict is not None:
        custom_data = []
        for point in event_dict['points']:
            data = point["customdata"][0]
            custom_data.append(int(data))
        # cap displayed frames at 10 random picks from the selection
        if len(custom_data) > 10:
            custom_data = random.sample(custom_data, 10)
        if len(custom_data) > 1:
            # lay the selected frames out across two columns
            col_1, col_2 = st.columns(2)
            with col_1:
                for frame_no in custom_data[::2]:
                    st.image(get_image(io_reader, frame_no))
                    st.caption(f"Frame {frame_no}, {train_labels[frame_no]}")
            with col_2:
                for frame_no in custom_data[1::2]:
                    st.image(get_image(io_reader, frame_no))
                    st.caption(f"Frame {frame_no}, {train_labels[frame_no]}")
        elif len(custom_data) == 1:
            frame_no = custom_data[0]
            st.image(get_image(io_reader, frame_no))
            st.caption(f"Frame {frame_no}, {train_labels[frame_no]}")
        else:
            st.text('No Points Selected')
        if len(custom_data) == 1:
            # single frame: ask LLaVA for evidence and a behavioral subtype
            frame_no = custom_data[0]
            image = get_image(io_reader, frame_no)
            system_prompt = SYSTEM_PROMPT
            label = train_labels[frame_no]
            question = f"The frame is annotated by a human observer with the label: {label}. Give evidence for this label using the posture of the mice and their current behavior. "\
                "Also, designate a behavioral subtype of the given label that describes the current social interaction based on what you see about the posture of the mice and "\
                "how they are positioned with respect to each other. Usually, the body parts (i.e., tail, genitals, face, body, ears, paws) "\
                "of the mice that are closest to each other will give some clue. Please limit the behavioral subtype to a 1-4 word phrase and limit your response to 4 sentences."
            with st.spinner("Loading LLaVA model. This can take 10 to 30 minutes. Please wait..."):
                tokenizer, model, image_processor = load_llava_model()
            response = ask_question_with_image_llava(image, system_prompt, question,
                                                     tokenizer, model, image_processor)
            st.markdown(response)
        elif len(custom_data) > 1:
            # multiple frames: describe each with LLaVA, then summarize with GPT-4o
            image_array = [get_image(io_reader, f_no) for f_no in custom_data]
            label_array = [train_labels[f_no] for f_no in custom_data]
            response = ask_summary_question(image_array, label_array, api_key)
            st.markdown(response)