# NOTE: page banner captured together with this file ("Spaces: Runtime error");
# it is not part of the script.
import pickle

import faiss
import numpy as np
from tqdm.auto import tqdm
def sample_tf(x, y, ndim: int = 1000) -> np.ndarray:
    """Sample a transfer function on a uniform grid over [0, 1].

    The curve is evaluated as a step function: a grid point ``t`` takes the
    value ``y[k]`` of the last knot ``x[k]`` for which ``t >= x[k]`` holds
    (later knots override earlier ones, so ``x`` is presumably ascending --
    confirm against the data). Grid points below every knot keep the value
    0, which is the fill value ``np.piecewise`` starts from.

    Args:
        x: Knot positions of the transfer function (``tf.x``).
        y: Values paired with the knots (``tf.y``). ``np.piecewise`` also
            accepts one extra trailing entry, used where no knot applies.
        ndim: Number of evenly spaced samples taken on [0, 1].

    Returns:
        A float array of length ``ndim`` with the sampled values.
    """
    t = np.linspace(0, 1, ndim)
    # One boolean mask per knot; np.piecewise applies them in order, so the
    # mask of a later knot overwrites what an earlier one assigned.
    reached = [t >= knot for knot in x]
    return np.piecewise(t, reached, y)
# Cluster the sampled training transfer functions with k-means and store,
# for every centroid, the indices of its nearest training samples.

# NOTE(review): pickle can execute arbitrary code while loading; only point
# this at a trusted file.
with open('./data/trainTF.pkl', 'rb') as fh:  # handle is closed after loading
    tf_train = pickle.load(fh)

# Dimensionality of every vector; passed to sample_tf so that the sampled
# length and the faiss index dimension cannot drift apart.
d = 1000

# Entries are read by position, exactly as before, so this works for any
# container that supports len() and integer indexing.
tf = np.array(
    [
        sample_tf(tf_train[i]['x'], tf_train[i]['y'], ndim=d)
        for i in tqdm(range(len(tf_train)))
    ],
    dtype=np.float32,  # faiss operates on float32 matrices
)

ncentroids = 1000
niter = 200
verbose = True
kmeans = faiss.Kmeans(d, ncentroids, niter=niter, verbose=verbose, gpu=True)
kmeans.train(tf)
centroids = kmeans.centroids

# Exact (brute-force) L2 index over all training vectors.
index = faiss.IndexFlatL2(d)
index.add(tf)

nNN = 1000  # number of neighbours retrieved per centroid
# D holds the distances, I the row indices into tf, one row per centroid.
# NOTE(review): with fewer than nNN training vectors faiss pads the missing
# neighbours with -1 -- confirm that consumers of clusters_train.npy cope.
D, I = index.search(kmeans.centroids, nNN)

np.save('./data/centroids_train.npy', centroids)
np.save('./data/clusters_train.npy', I)