Spaces:
Runtime error
Runtime error
File size: 1,529 Bytes
f98f59d |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 |
import torch
from sklearn.cluster import AgglomerativeClustering, KMeans
from sklearn.manifold import TSNE
import numpy as np
import matplotlib.pyplot as plt
class ClusterEmbeddings():
    """Cluster sentence embeddings and expose the resulting groups.

    The clustering is fitted eagerly in ``__init__`` with either
    scikit-learn's ``AgglomerativeClustering`` or ``KMeans``. The fitted
    labels are then used to group ``sentences`` into per-cluster strings and
    to colour a 2-D t-SNE scatter plot of ``embeddings``.
    """

    def __init__(
        self,
        cluster_estimate,
        cluster_fn,
        embeddings,
        sentences,
        words
    ):
        """Fit the selected clustering algorithm on ``embeddings``.

        Args:
            cluster_estimate: Number of clusters to request (passed as
                ``n_clusters``); also stored as ``num_clusters``.
            cluster_fn: ``"agglo"`` for AgglomerativeClustering or
                ``"kmeans"`` for KMeans.
            embeddings: Feature matrix passed to ``fit``. Presumably
                ``(n_sentences, dim)`` with one row per sentence -- confirm
                against callers.
            sentences: Sentences aligned row-for-row with ``embeddings``
                (a list or a NumPy array).
            words: Stored on the instance; not used by this class.

        Raises:
            ValueError: If ``cluster_fn`` is neither ``"agglo"`` nor
                ``"kmeans"``.
        """
        self.cluster_estimate = cluster_estimate
        self.embeddings = embeddings
        self.sentences = sentences
        self.words = words
        self.cluster_fn = cluster_fn
        if cluster_fn == "agglo":
            self.clustering_algo = AgglomerativeClustering(n_clusters=cluster_estimate)
        elif cluster_fn == "kmeans":
            self.clustering_algo = KMeans(n_clusters=cluster_estimate)
        else:
            # Without this, an unknown name surfaced later as an
            # AttributeError on the missing ``self.clustering_algo``.
            raise ValueError(
                f"cluster_fn must be 'agglo' or 'kmeans', got {cluster_fn!r}"
            )
        self.num_clusters = cluster_estimate
        self.cluster = self.clustering_algo.fit(embeddings)
        self.labels = self.cluster.labels_

    def get_sentence_clusters(self):
        """Concatenate the sentences belonging to each cluster.

        Returns:
            A NumPy array of ``num_clusters`` strings. Entry ``i`` joins every
            sentence labelled ``i`` (in original order), each followed by a
            single space, so non-empty entries end with ``" "``. A cluster
            with no members yields ``""``.
        """
        # np.asarray makes boolean-mask indexing work when ``sentences`` is a
        # plain Python list (list[bool_mask] raises TypeError).
        sentences = np.asarray(self.sentences)
        sent_clusters = [
            "".join(sent + " " for sent in sentences[self.labels == lbl])
            for lbl in range(self.num_clusters)
        ]
        return np.array(sent_clusters)

    def make_plot(self, perplexity=None):
        """Show a 2-D t-SNE projection of the embeddings, coloured by cluster.

        Args:
            perplexity: t-SNE perplexity. When ``None``, uses scikit-learn's
                default of 30, capped at ``n_samples - 1``. TSNE rejects a
                perplexity that is not below the sample count.

        Displays the figure via ``plt.show()`` and returns ``None``.
        """
        if perplexity is None:
            n_samples = np.shape(self.embeddings)[0]
            perplexity = min(30.0, max(n_samples - 1, 1))
        projector = TSNE(
            n_components=2,
            learning_rate="auto",
            init="random",
            perplexity=perplexity,
        )
        proj_embeddings = np.asarray(projector.fit_transform(self.embeddings))
        for lbl in range(self.num_clusters):
            xs = proj_embeddings[self.labels == lbl]
            plt.scatter(xs[:, 0], xs[:, 1], label=f"Cluster {lbl}")
        plt.legend()
        plt.xlabel("x1")
        plt.ylabel("x2")
        plt.show()