zwt123home123 committed
Commit ad59ce5 · verified · 1 Parent(s): eb89a71

Upload cluster_all_layers_nonorm_afterrope_group.py

cluster_all_layers_nonorm_afterrope_group.py ADDED
@@ -0,0 +1,91 @@
+ import os
+ import glob
+ import torch
+ import faiss
+ import numpy as np
+ import random
+ from tqdm import tqdm
+ # Data directory
+ # data_dir = '../feats_offset'
+
+
+ # Parameter settings
+ feature_dim = 5120
+ num_clusters = 1000  # number of clusters per k-means run
+ # batch_size = 10000000
+ batch_size = 1000000
+ # batch_size = 200000
+ niter = 20
+ num_tensor_file = int(batch_size / 64 / 576)  # files per batch, assuming each .pth holds 64*576 key vectors
+ save_folder = f"/sensei-fs/users/wezhao/projects/data/cluster/centroids_faiss_K_c1k_bs1m_iter_{niter}_nonorm_all_layers_afterrope_group"
+ os.makedirs(save_folder, exist_ok=True)
+ for layer_idx in range(40):
+     os.makedirs(os.path.join(save_folder, str(layer_idx)), exist_ok=True)
+     # if layer_idx <= 30:
+     #     continue
+     data_dir = '/sensei-fs/users/wezhao/projects/proj-phu/DenseToken/data/key_states_save_13b_all_layers_after_rope/' + str(layer_idx)
+     # List all .pth files in the directory
+     pt_files = glob.glob(os.path.join(data_dir, '*.pth'))
+
+     print(f"Found {len(pt_files)} .pth files.")
+
+     print("num_tensor_file:", num_tensor_file)
+     tensor_files = [os.path.join(data_dir, f) for f in os.listdir(data_dir) if f.endswith('.pth')]
+     random.shuffle(tensor_files)
+
+     # Initialize GPU resources (not used directly below; faiss.Kmeans(gpu=True) manages its own)
+     print("Initializing GPU resources...")
+     res = faiss.StandardGpuResources()
+
+     print("Setting up clustering parameters...")
+
+
+
+     # Data iterator: yields concatenated batches of roughly batch_size key vectors
+     def data_iterator(tensor_files):
+         # Walk the shuffled file list in chunks of num_tensor_file files
+         for i in range(0, len(tensor_files), num_tensor_file):
+             # Read num_tensor_file tensor files at a time
+             tensors = []
+             # import pdb; pdb.set_trace()
+             for j in range(num_tensor_file):
+                 if i + j < len(tensor_files):
+                     print("loading", i, j, tensor_files[i + j])
+                     tensor = torch.load(tensor_files[i + j])
+                     # import pdb; pdb.set_trace()
+                     tensor = tensor.reshape(-1, feature_dim).cpu().numpy().astype(np.float32)
+                     # import pdb; pdb.set_trace()
+
+                     tensors.append(tensor)
+
+             if tensors:
+                 yield np.concatenate(tensors, axis=0)
+
+     # Run faiss k-means once per (layer, head group) on the first shuffled batch
+     count = 0
+     # import pdb; pdb.set_trace()
+     for data_batch in tqdm(data_iterator(tensor_files), desc="Processing batches"):
+         data_batch = data_batch.reshape(-1, 40, 128)  # split the 5120-dim keys into 40 groups of 128 dims
+         for i in range(40):
+             data = data_batch[:, i, :]
+             # faiss.normalize_L2(data)
+             kmeans = faiss.Kmeans(d=128, k=num_clusters, niter=niter, gpu=True, verbose=True)
+             # print("====")
+             # Train the k-means clustering model on GPU
+             print("Training k-means clustering model on GPU...")
+             data = np.ascontiguousarray(data, dtype=np.float32)
+             # import pdb; pdb.set_trace()
+             # faiss.copy_array_to_vector(np.zeros((100000, 5120), dtype=np.float32).ravel(), kmeans.centroids)
+             # faiss.vector_to_array(kmeans.centroids)
+             # kmeans.train(data_batch, index)
+             kmeans.train(data)
+             print("k-means training completed.")
+             # Extract centroids
+             print("Extracting centroids...")
+             centroids = kmeans.centroids
+
+             np.save(os.path.join(save_folder, str(layer_idx), f"{i}.npy"), centroids)
+             # np.save(f"temp/{count}.npy", centroids)
+             print(f"Centroids saved for layer {layer_idx}, group {i}")
+             count += 1
+         break  # only the first shuffled batch is used per layer
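
For context, a minimal sketch of how the saved centroids might be consumed: loading one layer's per-group codebooks and assigning 128-dim key slices to their nearest centroid with an exact L2 index. The `keys` array and its shape are illustrative assumptions taken from the script above; only the save path layout comes from this commit.

import numpy as np
import faiss

layer_idx = 0
save_folder = "/sensei-fs/users/wezhao/projects/data/cluster/centroids_faiss_K_c1k_bs1m_iter_20_nonorm_all_layers_afterrope_group"

# keys: hypothetical key states for this layer, shape (N, 5120), split into the same 40 groups of 128 dims
keys = np.random.randn(16, 5120).astype(np.float32).reshape(-1, 40, 128)

codes = np.empty((keys.shape[0], 40), dtype=np.int64)
for group_idx in range(40):
    centroids = np.load(f"{save_folder}/{layer_idx}/{group_idx}.npy")  # (1000, 128)
    index = faiss.IndexFlatL2(128)                                     # exact L2 search over the codebook
    index.add(np.ascontiguousarray(centroids, dtype=np.float32))
    _, nearest = index.search(np.ascontiguousarray(keys[:, group_idx, :]), 1)
    codes[:, group_idx] = nearest[:, 0]                                # centroid id per key, per group

Clustering each 128-dim slice independently gives one 1000-entry codebook per (layer, group), which is structurally similar to the sub-quantizers used in product quantization.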