File size: 830 Bytes
06db6e9
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
import pickle
import numpy as np
import faiss
from tqdm.auto import tqdm

def sample_tf(x,y,ndim=1000):
    '''
    Sample a step-shaped transfer function on a uniform grid over [0, 1].

    input:
        x: breakpoints; one condition `grid >= x[i]` is built per entry
        y: values paired with those conditions (np.piecewise also accepts
           one extra trailing entry, used where no condition holds)
        ndim: number of evenly spaced grid points between 0 and 1 inclusive
    return: array of ndim sampled values; when several conditions hold for
        a grid point the one listed last wins, and points matching no
        condition stay 0 unless y supplies the extra default entry
    '''
    grid = np.linspace(0, 1, ndim)

    # One boolean mask per breakpoint, in the same order as x.
    masks = []
    for knot in x:
        masks.append(grid >= knot)

    return np.piecewise(grid, masks, y)

# Script body: load the training transfer functions, sample each one into a
# fixed-length vector, cluster the vectors with k-means, and store both the
# centroids and the nearest training vectors of every centroid.

# NOTE: pickle can execute arbitrary code while loading; only point this at a
# trusted, locally produced file.
# The context manager closes the file handle as soon as loading is done.
with open('./data/trainTF.pkl','rb') as fh:
    tf_train = pickle.load(fh)

# Length of every sampled vector. It is passed to sample_tf explicitly so the
# sampling resolution and the dimensionality given to faiss cannot drift apart.
d = 1000

# Sample each entry's 'x'/'y' pair into a d-dimensional vector.
# Index-based access is kept on purpose: the container type stored in the
# pickle is not visible here, and tf_train[i] is what the data is known to
# support.
tf = [
    sample_tf(tf_train[i]['x'], tf_train[i]['y'], ndim=d)
    for i in tqdm(range(len(tf_train)))
]
# faiss works on float32 matrices, hence the cast.
tf = np.array(tf).astype(np.float32)

ncentroids = 1000
niter = 200
verbose = True

# k-means over the sampled vectors; gpu=True requires a GPU-enabled faiss build.
kmeans = faiss.Kmeans(d, ncentroids, niter=niter, verbose=verbose,gpu=True)
kmeans.train(tf)
centroids = kmeans.centroids

# Exact L2 index over all training vectors, queried with the centroids:
# row c of I lists the positions in tf of the nNN closest vectors to
# centroid c, and D holds the matching distances (not used further).
index = faiss.IndexFlatL2(d)
index.add(tf)
nNN = 1000
D, I = index.search(centroids, nNN)

np.save('./data/centroids_train.npy',centroids)
np.save('./data/clusters_train.npy',I)