File size: 1,226 Bytes
5f33ab8
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Created on Sun Jul 17 06:46:29 PM EDT 2022 
author: Ryan Hildebrandt, github.com/ryancahildebrandt
"""
# imports
import numpy as np
import pandas as pd
import plotly.express as px
import random

from cluster import *

# Use plotly's stock "plotly" template for every figure this module builds.
px.defaults.template = "plotly"

# Seed Python's built-in `random` module so any stdlib-random choices repeat.
# NOTE(review): this does not seed NumPy's global RNG; if the clustering or
# dimensionality-reduction code draws from NumPy, it is not covered here — confirm.
random.seed(a=42)

def d2_plot(in_df):
	"""Build a 2-D scatter plot of the rows of ``in_df``, coloured by cluster.

	Parameters
	----------
	in_df : pandas.DataFrame
		Must provide the columns ``"d2"`` (one coordinate sequence per row;
		components 0 and 1 become x and y), ``"cluster"`` (a label per row)
		and ``"prep"`` (hover text per row).

	Returns
	-------
	plotly.graph_objects.Figure
		The scatter figure with its legend hidden.
	"""
	# Stack the per-row coordinate sequences into one 2-D array a single
	# time, rather than rebuilding the same array for each axis.
	coords = np.array(in_df["d2"].values.tolist())
	fig = px.scatter(
		x=coords[:, 0],
		y=coords[:, 1],
		# String labels make plotly use discrete colours per cluster
		# instead of a continuous colour scale.
		color=list(map(str, in_df["cluster"])),
		hover_name=in_df["prep"],
	)
	fig.update_layout(showlegend=False)
	return fig

def d3_plot(in_df):
	"""Build a 3-D scatter plot of the rows of ``in_df``, coloured by cluster.

	Parameters
	----------
	in_df : pandas.DataFrame
		Must provide the columns ``"d3"`` (one coordinate sequence per row;
		components 0, 1 and 2 become x, y and z), ``"cluster"`` (a label
		per row) and ``"prep"`` (hover text per row).

	Returns
	-------
	plotly.graph_objects.Figure
		The 3-D scatter figure with marker size 2 and its legend hidden.
	"""
	# Stack the per-row coordinate sequences into one 2-D array a single
	# time, rather than rebuilding the same array for each of the 3 axes.
	coords = np.array(in_df["d3"].values.tolist())
	fig = px.scatter_3d(
		x=coords[:, 0],
		y=coords[:, 1],
		z=coords[:, 2],
		# String labels make plotly use discrete colours per cluster
		# instead of a continuous colour scale.
		color=list(map(str, in_df["cluster"])),
		hover_name=in_df["prep"],
	)
	# Small markers keep dense 3-D point clouds readable.
	fig.update_traces(marker={'size': 2})
	fig.update_layout(showlegend=False)
	return fig

def viz_ex(func, cldata, use_embs):
	"""Cluster ``use_embs`` and return 2-D and 3-D scatter plots of the result.

	Parameters
	----------
	func : sequence
		Indexed, not called, despite its name. ``func[0]`` holds the per-row
		2-D coordinates and ``func[1]`` the per-row 3-D coordinates.
	cldata : sequence
		Per-row hover text, stored in the ``"prep"`` column.
	use_embs : iterable
		Per-row embeddings. They are clustered with ``cluster_hdbscan`` and
		also stored in the ``"emb"`` column.

	All inputs must describe the same number of rows, because they become
	columns of one DataFrame.

	Returns
	-------
	dict
		``{"d2": <2-D figure>, "d3": <3-D figure>}``.
	"""
	# NOTE(review): cluster_hdbscan presumably comes from the `cluster`
	# star import. The meaning of its fixed positional settings (1.0, 5) is
	# defined there, and "euclidean" looks like the distance metric — confirm.
	labels = cluster_hdbscan(use_embs, 1.0, "euclidean", 5)
	frame = pd.DataFrame(
		{
			"prep": cldata,
			"cluster": labels,
			"emb": list(use_embs),
			"d2": list(func[0]),
			"d3": list(func[1]),
		}
	)
	return dict(d2=d2_plot(frame), d3=d3_plot(frame))