xnetba's picture
Duplicate from ryancahildebrandt/all_in_one_sentence_embeddings
5f33ab8
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Created on Sun Jul 17 06:46:29 PM EDT 2022
author: Ryan Hildebrandt, github.com/ryancahildebrandt
"""
# imports
import numpy as np
import pandas as pd
import plotly.express as px
import random
from cluster import *
# Seed the standard-library RNG so anything drawing from `random` is repeatable.
# NOTE(review): this does not seed NumPy's generator — confirm whether the
# clustering / reduction code relies on that one instead.
random.seed(42)
# Make "plotly" the default template for every Plotly Express figure built below.
px.defaults.template = "plotly"
def d2_plot(in_df):
    """
    Build a 2-D Plotly Express scatter plot from the rows of in_df.

    in_df must provide three columns:
      "d2"      : one coordinate sequence per row; entries 0 and 1 become x and y
      "cluster" : one label per row, used to color the points
      "prep"    : one value per row, shown as the hover title

    Returns the Plotly figure with its legend hidden.
    """
    # Stack the per-row coordinates into a single array once, rather than
    # rebuilding the same array for every axis.
    coords = np.array(in_df["d2"].values.tolist())
    fig = px.scatter(
        x=coords[:, 0],
        y=coords[:, 1],
        # Labels are passed as strings; Plotly Express colors non-numeric
        # values as discrete categories instead of a continuous scale.
        color=[str(label) for label in in_df["cluster"]],
        hover_name=in_df["prep"],
    )
    fig.update_layout(showlegend=False)
    return fig
def d3_plot(in_df):
    """
    Build a 3-D Plotly Express scatter plot from the rows of in_df.

    in_df must provide three columns:
      "d3"      : one coordinate sequence per row; entries 0, 1 and 2 become x, y and z
      "cluster" : one label per row, used to color the points
      "prep"    : one value per row, shown as the hover title

    Returns the Plotly figure with marker size 2 and its legend hidden.
    """
    # Stack the per-row coordinates into a single array once, rather than
    # rebuilding the same array for every axis.
    coords = np.array(in_df["d3"].values.tolist())
    fig = px.scatter_3d(
        x=coords[:, 0],
        y=coords[:, 1],
        z=coords[:, 2],
        # Labels are passed as strings; Plotly Express colors non-numeric
        # values as discrete categories instead of a continuous scale.
        color=[str(label) for label in in_df["cluster"]],
        hover_name=in_df["prep"],
    )
    fig.update_traces(marker={'size': 2})
    fig.update_layout(showlegend=False)
    return fig
def viz_ex(func, cldata, use_embs):
    """
    Assemble one plotting frame and render its 2-D and 3-D scatter plots.

    func     : indexable (it is indexed, never called, despite the name);
               func[0] fills the "d2" column and func[1] fills the "d3" column
    cldata   : values for the "prep" column (used as hover titles by the plots)
    use_embs : embeddings; handed to cluster_hdbscan and stored row-wise in "emb"

    Returns a dict {"d2": <figure from d2_plot>, "d3": <figure from d3_plot>}.
    """
    # cluster_hdbscan arrives via the star-import of the local `cluster` module.
    # NOTE(review): the meaning of the positional arguments 1.0, "euclidean", 5
    # is defined there, not in this file — confirm before changing them.
    labels = cluster_hdbscan(use_embs, 1.0, "euclidean", 5)
    columns = {
        "prep": cldata,
        "cluster": labels,
        "emb": list(use_embs),
        "d2": list(func[0]),
        "d3": list(func[1]),
    }
    frame = pd.DataFrame(columns)
    return {"d2": d2_plot(frame), "d3": d3_plot(frame)}