Spaces:
Sleeping
Sleeping
import os | |
import io | |
import pickle | |
import copy | |
from collections import Counter | |
from pathlib import Path | |
from tempfile import NamedTemporaryFile | |
import regex as re | |
import numpy as np | |
import pandas as pd | |
from sklearn.manifold import TSNE | |
from sklearn.svm import SVC | |
from sklearn.model_selection import train_test_split | |
from sklearn.metrics import accuracy_score, classification_report | |
import torch | |
from tqdm import tqdm | |
from PIL import Image | |
from transformers import AutoProcessor, AutoModel | |
import streamlit as st | |
from .data_loading import load_multiple_annotations, load_multiple_annotations_io | |
from .data_processing import generate_label_array | |
from .seqIo import seqIo_reader | |
from .mp4Io import mp4Io_reader | |
SLIP_MODEL_ID = "google/siglip-so400m-patch14-384" | |
CLIP_MODEL_ID = "openai/clip-vit-base-patch32" | |
def create_annot_fname_dict(annot_fnames: list[str])-> dict: | |
fs = re.compile(r'.*(_\d+)$') | |
unique_files = set() | |
for file in annot_fnames: | |
file_name = os.fsdecode(file) | |
base_name, _ = os.path.splitext(file_name) | |
if fs.match(base_name): | |
ind = len(fs.match(base_name).group(1)) | |
unique_files.add(base_name[:-ind]) | |
else: | |
unique_files.add(base_name) | |
annot_fname_dict = {} | |
for unique_file in unique_files: | |
annot_fname_dict.update({unique_file: [file for file in annot_fnames if unique_file in file]}) | |
return annot_fname_dict | |
def create_annot_fname_dict_io(annot_fnames: list[str], annot_files: list)-> dict: | |
annot_file_dict = {} | |
for file in annot_files: | |
annot_file_dict.update({file.name : file}) | |
fs = re.compile(r'.*(_\d+)$') | |
unique_files = set() | |
for file in annot_fnames: | |
file_name = os.fsdecode(file) | |
base_name, _ = os.path.splitext(file_name) | |
if fs.match(base_name): | |
ind = len(fs.match(base_name).group(1)) | |
unique_files.add(base_name[:-ind]) | |
else: | |
unique_files.add(base_name) | |
annot_fname_dict = {} | |
for unique_file in unique_files: | |
annot_list = [file for file in annot_fnames if unique_file in file] | |
annot_list.sort() | |
annot_file_list = [annot_file_dict[annot_file_name] for annot_file_name in annot_list] | |
annot_fname_dict.update({unique_file: annot_file_list}) | |
return annot_fname_dict | |
def get_io_reader(uploaded_file): | |
assert uploaded_file.name[-3:]=='seq', 'Not a seq file' | |
with NamedTemporaryFile(suffix="seq", delete=False) as temp: | |
temp.write(uploaded_file.getvalue()) | |
sr = seqIo_reader(temp.name) | |
return sr | |
def load_slip_model(device): | |
return AutoModel.from_pretrained(SLIP_MODEL_ID).to(device) | |
def load_slip_preprocessor(): | |
return AutoProcessor.from_pretrained(SLIP_MODEL_ID) | |
def load_clip_model(device): | |
return AutoModel.from_pretrained(CLIP_MODEL_ID).to(device) | |
def load_clip_preprocessor(): | |
return AutoProcessor.from_pretrained(CLIP_MODEL_ID) | |
def encode_image(image, device, model, processor): | |
with torch.no_grad(): | |
#convert_models_to_fp32(model) | |
inputs = processor(images=image, return_tensors="pt").to(device) | |
image_features = model.get_image_features(**inputs) | |
return image_features.cpu().numpy().flatten() | |
def generate_embeddings_stream(fnames : list[str], | |
model = 'SLIP', | |
downsample_rate = 4, | |
save_csv = False)-> tuple[list, list, list]: | |
# set up model and device | |
device = "cuda" if torch.cuda.is_available() else "cpu" | |
os.environ['CUDA_VISIBLE_DEVICES'] = '0' | |
if model == 'SLIP': | |
embed_model = load_slip_model(device) | |
processor = load_slip_preprocessor() | |
elif model == 'CLIP': | |
embed_model = load_clip_model(device) | |
processor = load_clip_preprocessor() | |
all_video_embeddings = [] | |
all_video_frames = [] | |
for fname in fnames: | |
# read in file | |
is_seq = False | |
if fname[-3:] == 'seq': is_seq = True | |
if is_seq: | |
sr = seqIo_reader(fname) | |
else: | |
sr = mp4Io_reader(fname) | |
N = sr.header['numFrames'] | |
# set up embeddings and frame arrays | |
embeddings = [] | |
frames = list(range(N))[::downsample_rate] | |
print(frames) | |
# create progress bar | |
i = 0 | |
pbar_text = lambda i: f'Creating embeddings for {fname}. {i}/{len(frames)} frames.' | |
pbar = st.progress(0, text=pbar_text(0)) | |
# convert each frame to embeddings | |
for f in tqdm(frames): | |
img, _ = sr.getFrame(f) | |
img_arr = np.array(img) | |
if is_seq: | |
img_rgb = Image.fromarray(img_arr, 'L').convert('RGB') | |
else: | |
img_rgb = Image.fromarray(img_arr).convert('RGB') | |
embeddings.append(encode_image(img_rgb, device, embed_model, processor)) | |
# update progress bar | |
i += 1 | |
pbar.progress(i/len(frames), pbar_text(i)) | |
# save csv of single file | |
if save_csv: | |
df = pd.DataFrame(embeddings) | |
df['Frame'] = frames | |
# save csv | |
basename = Path(fname).stem | |
df.to_csv(f'{basename}_embeddings_downsample_{downsample_rate}.csv', index=False) | |
all_video_embeddings.append(np.array(embeddings)) | |
all_video_frames.append(frames) | |
return all_video_embeddings, all_video_frames | |
def get_io_reader(uploaded_file): | |
if uploaded_file.name[-3:]=='seq': | |
with NamedTemporaryFile(suffix="seq", delete=False) as temp: | |
temp.write(uploaded_file.getvalue()) | |
sr = seqIo_reader(temp.name) | |
else: | |
with NamedTemporaryFile(suffix="mp4", delete=False) as temp: | |
temp.write(uploaded_file.getvalue()) | |
sr = mp4Io_reader(temp.name) | |
return sr | |
def generate_embeddings_stream_io(uploaded_files : list, | |
model = 'SLIP', | |
downsample_rate = 4, | |
save_csv = False)-> tuple[list, list, list]: | |
# set up model and device | |
device = "cuda" if torch.cuda.is_available() else "cpu" | |
os.environ['CUDA_VISIBLE_DEVICES'] = '0' | |
with st.spinner('Loading multimodal model...'): | |
if model == 'SLIP': | |
embed_model = load_slip_model(device) | |
processor = load_slip_preprocessor() | |
elif model == 'CLIP': | |
embed_model = load_clip_model(device) | |
processor = load_clip_preprocessor() | |
all_video_embeddings = [] | |
all_video_frames = [] | |
for file in uploaded_files: | |
is_seq = False | |
if file.name[-3:] == 'seq': is_seq = True | |
# read in file | |
sr = get_io_reader(file) | |
N = sr.header['numFrames'] | |
# set up embeddings and frame arrays | |
embeddings = [] | |
frames = list(range(N))[::downsample_rate] | |
print(frames) | |
# create progress bar | |
i = 0 | |
pbar_text = lambda i: f'Creating embeddings for {file.name}. {i}/{len(frames)} frames.' | |
pbar = st.progress(0, text=pbar_text(0)) | |
# convert each frame to embeddings | |
for f in tqdm(frames): | |
img, _ = sr.getFrame(f) | |
img_arr = np.array(img) | |
if is_seq: | |
img_rgb = Image.fromarray(img_arr, 'L').convert('RGB') | |
else: | |
img_rgb = Image.fromarray(img_arr).convert('RGB') | |
embeddings.append(encode_image(img_rgb, device, embed_model, processor)) | |
# update progress bar | |
i += 1 | |
pbar.progress(i/len(frames), pbar_text(i)) | |
# save csv of single file | |
if save_csv: | |
df = pd.DataFrame(embeddings) | |
df['Frame'] = frames | |
# save csv | |
df.to_csv(f'embeddings_downsample_{downsample_rate}_{N}_frames.csv', index=False) | |
all_video_embeddings.append(np.array(embeddings)) | |
all_video_frames.append(frames) | |
return all_video_embeddings, all_video_frames | |
def create_embeddings_csv(out: str, | |
fnames: list[str], | |
embeddings: list[np.ndarray], | |
frames: list[list[int]], | |
annotations: list[list[str]], | |
test_fnames: None | list[str], | |
views: None | list[str], | |
conditions: None | list[str], | |
downsample_rate = 4, | |
filesystem = None): | |
""" | |
Creates a .csv file containing all of the generated embeddings and provived information. | |
Parameters: | |
----------- | |
out : str | |
The name of the resulting file. | |
fnames : list[str] | |
Video sources for each of the embedding arrays. | |
embeddings : np.ndarray | |
The generated embeddings from the images. | |
downsample_rate : int | |
The downsample_rate used for generating the embeddings. | |
""" | |
assert len(fnames) == len(embeddings) | |
assert len(embeddings) == len(frames) | |
all_embeddings = np.vstack(embeddings) | |
df = pd.DataFrame(all_embeddings) | |
labels = [] | |
for i, annot_fnames in enumerate(annotations): | |
_, ext = os.path.splitext(annot_fnames[0]) | |
if ext == '.annot': | |
annot, _, _, sr = load_multiple_annotations(annot_fnames, filesystem=filesystem) | |
annot_labels = generate_label_array(annot, downsample_rate, len(frames[i])) | |
elif ext == '.csv': | |
if not filesystem: | |
annot_df = pd.read_csv(annot_fnames[0], header=None) | |
else: | |
with filesystem.open(annot_fnames[0], 'r') as csv_file: | |
annot_df = pd.read_csv(csv_file, header=None) | |
annot_labels = annot_df[0].to_list()[::downsample_rate] | |
assert len(annot_labels) == len(frames[i]), "There is a mismatch between the number of frames and number of labels. Make sure that the passed in csv file has no header." | |
else: | |
raise ValueError(f'Incompatible file for annotations used. Got a file of type "{ext}".') | |
assert len(annot_labels) == len(frames[i]), "There is a mismatch between the number of frames and number of labels. Make sure you have passed in the correct files." | |
print(annot_labels) | |
labels.append(annot_labels) | |
all_labels = np.hstack(labels) | |
print(len(all_labels)) | |
df['Label'] = all_labels | |
all_frames = np.hstack(frames) | |
df['Frame'] = all_frames | |
sources = [[fname for _ in range(len(frames[i]))] for i, fname in enumerate(fnames)] | |
all_sources = np.hstack(sources) | |
df['Source'] = all_sources | |
if test_fnames: | |
t_split = lambda x: True if x in test_fnames else False | |
test = [[t_split(fname) for _ in range(len(frames[i]))] for i, fname in enumerate(fnames)] | |
else: | |
test = [[True for _ in range(len(frames[i]))] for i, _ in enumerate(fnames)] | |
all_test = np.hstack(test) | |
df['Test'] = all_test | |
if views: | |
view = [[views[i] for _ in range(len(frames[i]))] for i in range(len(fnames))] | |
else: | |
view = [[None for _ in range(len(frames[i]))] for i in range(len(fnames))] | |
all_view = np.hstack(view) | |
df['View'] = all_view | |
if conditions: | |
condition = [[conditions[i] for _ in range(len(frames[i]))] for i in range(len(fnames))] | |
else: | |
condition = [[None for _ in range(len(frames[i]))] for i in range(len(fnames))] | |
all_condition = np.hstack(condition) | |
df['Condition'] = all_condition | |
return df | |
def create_embeddings_csv_io(out: str, | |
fnames: list[str], | |
embeddings: list[np.ndarray], | |
frames: list[list[int]], | |
annotations: list, | |
test_fnames: None | list[str], | |
views: None | list[str], | |
conditions: None | list[str], | |
downsample_rate = 4): | |
""" | |
Creates a .csv file containing all of the generated embeddings and provived information. | |
Parameters: | |
----------- | |
out : str | |
The name of the resulting file. | |
fnames : list[str] | |
Video sources for each of the embedding arrays. | |
embeddings : np.ndarray | |
The generated embeddings from the images. | |
downsample_rate : int | |
The downsample_rate used for generating the embeddings. | |
""" | |
assert len(fnames) == len(embeddings) | |
assert len(embeddings) == len(frames) | |
all_embeddings = np.vstack(embeddings) | |
df = pd.DataFrame(all_embeddings) | |
labels = [] | |
for i, uploaded_annots in enumerate(annotations): | |
print(i) | |
_, ext = os.path.splitext(uploaded_annots[0].name) | |
if ext == '.annot': | |
annot, _, _, sr = load_multiple_annotations_io(uploaded_annots) | |
annot_labels = generate_label_array(annot, downsample_rate, len(frames[i])) | |
elif ext == '.csv': | |
annot_df = pd.read_csv(uploaded_annots[0], header=None) | |
annot_labels = annot_df[0].to_list()[::downsample_rate] | |
assert len(annot_labels) == len(frames[i]), "There is a mismatch between the number of frames and number of labels. Make sure that the passed in csv file has no header." | |
else: | |
raise ValueError(f'Incompatible file for annotations used. Got a file of type "{ext}".') | |
assert len(annot_labels) == len(frames[i]), "There is a mismatch between the number of frames and number of labels. Make sure you have passed in the correct files." | |
print(annot_labels) | |
labels.append(annot_labels) | |
all_labels = np.hstack(labels) | |
print(len(all_labels)) | |
df['Label'] = all_labels | |
all_frames = np.hstack(frames) | |
df['Frame'] = all_frames | |
sources = [[fname for _ in range(len(frames[i]))] for i, fname in enumerate(fnames)] | |
all_sources = np.hstack(sources) | |
df['Source'] = all_sources | |
if test_fnames: | |
t_split = lambda x: True if x in test_fnames else False | |
test = [[t_split(fname) for _ in range(len(frames[i]))] for i, fname in enumerate(fnames)] | |
else: | |
test = [[True for _ in range(len(frames[i]))] for i, _ in enumerate(fnames)] | |
all_test = np.hstack(test) | |
df['Test'] = all_test | |
if views: | |
view = [[views[i] for _ in range(len(frames[i]))] for i in range(len(fnames))] | |
else: | |
view = [[None for _ in range(len(frames[i]))] for i in range(len(fnames))] | |
all_view = np.hstack(view) | |
df['View'] = all_view | |
if conditions: | |
condition = [[conditions[i] for _ in range(len(frames[i]))] for i in range(len(fnames))] | |
else: | |
condition = [[None for _ in range(len(frames[i]))] for i in range(len(fnames))] | |
all_condition = np.hstack(condition) | |
df['Condition'] = all_condition | |
return df | |
def process_dataset_in_mem(embeddings_df: pd.DataFrame, | |
specified_classes=None, | |
classes_to_remove=None, | |
max_class_size=None, | |
animal_state=None, | |
view=None, | |
shuffle_data=False, | |
test_videos=None): | |
""" | |
Processes output generated from embeddings paired with images and behavior labels. | |
Parameters: | |
----------- | |
csv_path : str | |
Path to the file containing the original data. This should contain embeddings, | |
a column named `'Label'` and a column named `'Images'`. | |
specified_classes : None | list[str] | |
An optional input. Defines labels which should be kept as is in the `'Label'` | |
column and which should be changed to a default `other` label. | |
classes_to_remove : None | list[str] | |
An optional input. Drops rows from the dataframe which contain a label in the | |
list. | |
max_class_size : None | int | |
An optional input. Determines the maximum amount of rows a single label can | |
appear in for each unique label in the `'Label'` column. | |
animal_state : None | str | |
An optional input. Drops rows from the dataframe which do not contain a match | |
for `animal_state` in the text field within the `'Images'` column. | |
view : None | str | |
An optional input. Drops rows from the dataframe which do not contain a match | |
for `view` in the text field within the `'Images'` column. | |
shuffle_data : bool | |
Determines wether the dataframe should have its rows shuffled. | |
test_videos : None | list[str] | |
An optional input. Determines what rows should be in the `test` dataframe, and | |
which should be in the `train` dataframe. It drops rows from the respective | |
dataframe by keeping or dropping rows which do not contain a match for a `str` | |
in `test_videos` in the text field within the `'Images'` column, respectively. | |
Returns: | |
-------- | |
balanced_train_embeddings : pandas.DataFrame | |
A processed dataframe whose rows contain the embeddings for each of the images | |
at the corresponding index within `balanced_train_images`. | |
balanced_train_labels : list[str] | |
A list of labels for each of the images at the corresponing index within | |
`balanced_train_images`. | |
balanced_train_images: list[str] | |
A list of paths to images with each image at an index corresponding to a label | |
with the same index in `balanced_train_labels` and the same row index within | |
`balanced_train_embeddings`. | |
test_embeddings : pandas.DataFrame | |
A processed dataframe whose rows contain the embeddings for each of the images | |
at the corresponding index within `test_images`. | |
test_labels : list[str] | |
A list of labels for each of the images at the corresponing index within | |
`test_images`. | |
test_images : list[str] | |
A list of paths to images with each image at an index corresponding to a label | |
with the same index in `test_labels` and the same row index within | |
`test_embeddings`. | |
""" | |
# Convert embeddings, labels, and images to a DataFrame for easy manipulation | |
df = copy.deepcopy(embeddings_df) | |
df_keys = [str(x) for x in df.keys()] | |
#Filter by fed or fasted | |
if 'Condition' in df_keys and animal_state: | |
df = df[df['Condition'].str.contains(animal_state, na=False)] | |
if 'View' in df_keys and view: | |
df = df[df['View'].str.contains(view, na=False)] | |
# Extract unique video names excluding the frame number | |
#unique_video_names = df['Images'].apply(lambda x: '_'.join(x.split('_')[:-1])).unique() | |
#print("\nUnique video names:\n", unique_video_names) | |
if classes_to_remove: | |
df = df[~df['Label'].str.contains('|'.join(classes_to_remove), na=False)] | |
elif classes_to_remove and 'all' in classes_to_remove: | |
df = df[df['Label'].str.contains('|'.join(classes_to_remove), na=False)] | |
# Further filter to include only specified_classes | |
if specified_classes: | |
single_match = lambda x: list(set(x.split('||')) & set(specified_classes))[0] | |
df['Label'] = df['Label'].apply(lambda x: single_match(x) if not set(x.split('||')).isdisjoint(specified_classes) else 'other') | |
specified_classes.append('other') | |
# Separate the DataFrame into test and training sets based on test_videos | |
if 'Test' in df_keys and test_videos: | |
test_df = df[df['Test']] | |
train_df = df[~df['Test']] | |
elif test_videos: | |
test_df = df[df['Images'].str.contains('|'.join(test_videos), na=False)] | |
train_df = df[~df['Images'].str.contains('|'.join(test_videos), na=False)] | |
else: | |
test_df = pd.DataFrame(columns=df.columns) | |
train_df = df | |
# Print the number of frames in each class before balancing | |
label_counts = train_df['Label'].value_counts() | |
print("\nNumber of training frames in each class before balancing:") | |
print(label_counts) | |
if max_class_size: | |
balanced_train_df = pd.concat([ | |
group.sample(n=min(len(group), max_class_size), random_state=1) | |
for label, group in train_df.groupby('Label') | |
]) | |
else: | |
balanced_train_df = train_df | |
# Shuffle the training DataFrame | |
if shuffle_data: | |
balanced_train_df = balanced_train_df.sample(frac=1).reset_index(drop=True) | |
# Convert training set back to numpy array and list | |
if not "Images" in df_keys: | |
balanced_train_embeddings = balanced_train_df.drop(columns=['Label', 'Frame', 'Source', 'Test','View','Condition']).to_numpy() | |
balanced_train_labels = balanced_train_df['Label'].tolist() | |
balanced_train_images = balanced_train_df['Frame'].tolist() | |
# Convert test set back to numpy array and list | |
test_embeddings = test_df.drop(columns=['Label', 'Frame', 'Source', 'Test','View','Condition']).to_numpy() | |
test_labels = test_df['Label'].tolist() | |
test_images = test_df['Frame'].tolist() | |
else: | |
# Convert training set back to numpy array and list | |
balanced_train_embeddings = balanced_train_df.drop(columns=['Label', 'Images']).to_numpy() | |
balanced_train_labels = balanced_train_df['Label'].tolist() | |
balanced_train_images = balanced_train_df['Images'].tolist() | |
# Convert test set back to numpy array and list | |
if 'Test' in test_df: | |
test_embeddings = test_df.drop(columns=['Label', 'Images', 'Test']).to_numpy() | |
else: | |
test_embeddings = test_df.drop(columns=['Label', 'Images']).to_numpy() | |
test_labels = test_df['Label'].tolist() | |
test_images = test_df['Images'].tolist() | |
# Print the number of frames in each class after balancing | |
if specified_classes or max_class_size: | |
balanced_label_counts = Counter(balanced_train_labels) | |
print("\nNumber of training frames in each class after balancing:") | |
print(balanced_label_counts) | |
test_label_counts = test_df['Label'].value_counts() | |
# print("\nNumber of testing frames in each class:") | |
print(test_label_counts) | |
return balanced_train_embeddings, balanced_train_labels, balanced_train_images, test_embeddings, test_labels, test_images | |
def multiclass_merge_and_filter_bouts(multiclass_vector, bout_threshold, proximity_threshold): | |
# Get the unique labels in the multiclass vector (excluding zero, assuming zero is the background/no label) | |
unique_labels = np.unique(multiclass_vector) | |
unique_labels = unique_labels[unique_labels != 0] | |
# Initialize a vector to store the merged and filtered multiclass vector | |
merged_vector = np.zeros_like(multiclass_vector) | |
for label in unique_labels: | |
# Create a binary vector for the current label | |
binary_vector = (multiclass_vector == label) | |
# Find the start and end indices of all sequences of 1's for this label | |
starts = np.where(np.diff(np.concatenate(([0], binary_vector))) == 1)[0] | |
ends = np.where(np.diff(np.concatenate((binary_vector, [0]))) == -1)[0] | |
# Step 1: Merge close short bouts | |
i = 0 | |
while i < len(starts) - 1: | |
# Check if the gap between the end of the current bout and the start of the next bout | |
# is within the proximity threshold | |
if starts[i + 1] - ends[i] <= proximity_threshold: | |
# Merge the two bouts by setting all elements between the start of the first | |
# and the end of the second bout to 1 | |
binary_vector[ends[i]:starts[i + 1]] = 1 | |
# Remove the next bout from consideration | |
starts = np.delete(starts, i + 1) | |
ends = np.delete(ends, i) | |
else: | |
i += 1 | |
# Update the starts and ends after merging | |
starts = np.where(np.diff(np.concatenate(([0], binary_vector))) == 1)[0] | |
ends = np.where(np.diff(np.concatenate((binary_vector, [0]))) == -1)[0] | |
# Step 2: Remove standalone short bouts | |
for i in range(len(starts)): | |
# Check the length of the bout | |
length_of_bout = ends[i] - starts[i] + 1 | |
# If the length is less than the threshold, set those elements to 0 | |
if length_of_bout < bout_threshold: | |
binary_vector[starts[i]:ends[i] + 1] = 0 | |
# Combine the binary vector with the merged_vector, ensuring only the current label is set | |
merged_vector[binary_vector] = label | |
# Return the filtered multiclass vector | |
return merged_vector | |
def get_unique_labels(label_list: list[str]): | |
label_set = set() | |
for label in label_list: | |
individual_labels = label.split('||') | |
for individual_label in individual_labels: | |
label_set.add(individual_label) | |
return list(label_set) | |
def get_train_test_split(train_embeds, numerical_labels, test_size=0.05, random_state=42): | |
return train_test_split(train_embeds, numerical_labels, test_size=test_size, random_state=random_state) | |
def train_model(X_train, y_train, random_state=42): | |
# Train SVM Classifier | |
svm_clf = SVC(kernel='rbf', random_state=random_state, probability=True) | |
svm_clf.fit(X_train, y_train) | |
return svm_clf | |
def pickle_model(model): | |
pickled = io.BytesIO() | |
pickle.dump(model, pickled) | |
return pickled | |
def get_seq_io_reader(uploaded_file): | |
assert uploaded_file.name[-3:]=='seq', 'Not a seq file' | |
with NamedTemporaryFile(suffix="seq", delete=False) as temp: | |
temp.write(uploaded_file.getvalue()) | |
sr = seqIo_reader(temp.name) | |
return sr | |
def seq_to_arr(sr): | |
N = sr.header['numFrames'] | |
images = [] | |
for f in range(N): | |
I, ts = sr.getFrame(f) | |
images.append(I) | |
return np.array(images) | |
def get_2d_embedding(embeddings: pd.DataFrame): | |
tsne = TSNE(n_jobs=4, n_components=2, random_state=42, perplexity=50) | |
embedding_2d = tsne.fit_transform(np.array(embeddings)) | |
return embedding_2d | |