Spaces:
Build error
Build error
use rand::Rng; | |
use rayon::prelude::*; | |
use rayon::ThreadPool; | |
use crate::EncodingError; | |
pub fn kmeans( | |
data: &[f32], | |
centroids_count: usize, | |
dim: usize, | |
max_iterations: usize, | |
max_threads: usize, | |
accuracy: f32, | |
stop_condition: &impl Fn() -> bool, | |
) -> Result<Vec<f32>, EncodingError> { | |
let pool = rayon::ThreadPoolBuilder::new() | |
.thread_name(|idx| format!("kmeans-{idx}")) | |
.num_threads(max_threads) | |
.build() | |
.map_err(|e| { | |
EncodingError::EncodingError(format!("Failed PQ encoding while thread pool init: {e}")) | |
})?; | |
// initial centroids positions are some vectors from data | |
let mut centroids = data[0..centroids_count * dim].to_vec(); | |
let mut centroid_indexes = vec![0u32; data.len() / dim]; | |
for _ in 0..max_iterations { | |
if stop_condition() { | |
return Err(EncodingError::Stopped); | |
} | |
update_indexes(&pool, data, &mut centroid_indexes, ¢roids); | |
if update_centroids( | |
&pool, | |
data, | |
¢roid_indexes, | |
&mut centroids, | |
max_threads, | |
accuracy, | |
) { | |
break; | |
} | |
} | |
update_indexes(&pool, data, &mut centroid_indexes, ¢roids); | |
Ok(centroids) | |
} | |
fn update_centroids( | |
pool: &ThreadPool, | |
data: &[f32], | |
centroid_indexes: &[u32], | |
centroids: &mut [f32], | |
max_threads: usize, | |
accuracy: f32, | |
) -> bool { | |
struct CentroidsCounter { | |
counter: Vec<usize>, | |
acc: Vec<f64>, | |
} | |
let dim = data.len() / centroid_indexes.len(); | |
let centroids_count = centroids.len() / dim; | |
let mut counters = (0..max_threads) | |
.map(|_| CentroidsCounter { | |
counter: vec![0usize; centroids_count], | |
acc: vec![0.0_f64; centroids.len()], | |
}) | |
.collect::<Vec<_>>(); | |
pool.install(|| { | |
counters | |
.par_iter_mut() | |
.enumerate() | |
.for_each(|(i, counter)| { | |
let chunk_size = centroid_indexes.len() / max_threads; | |
let vector_data_range = if i + 1 == max_threads { | |
chunk_size * i..centroid_indexes.len() | |
} else { | |
chunk_size * i..chunk_size * (i + 1) | |
}; | |
for i in vector_data_range { | |
let vector_data = &data[dim * i..dim * (i + 1)]; | |
let centroid_index = centroid_indexes[i] as usize; | |
counter.counter[centroid_index] += 1; | |
let centroid_data = | |
&mut counter.acc[dim * centroid_index..dim * (centroid_index + 1)]; | |
for (c, v) in centroid_data.iter_mut().zip(vector_data.iter()) { | |
*c += f64::from(*v); | |
} | |
} | |
}) | |
}); | |
let mut counter = CentroidsCounter { | |
counter: vec![0usize; centroids_count], | |
acc: vec![0.0_f64; centroids.len()], | |
}; | |
for c in counters { | |
for (dst, src) in counter.counter.iter_mut().zip(c.counter.iter()) { | |
*dst += src; | |
} | |
for (dst, src) in counter.acc.iter_mut().zip(c.acc.iter()) { | |
*dst += src; | |
} | |
} | |
for (centroid_index, centroid_data) in counter.acc.chunks_exact_mut(dim).enumerate() { | |
if counter.counter[centroid_index] == 0 { | |
// the cluster is empty, so we take random vector as centroid | |
let data_index = rand::thread_rng().gen_range(0..centroid_indexes.len()); | |
let vector = &data[dim * data_index..dim * (data_index + 1)]; | |
centroid_data | |
.iter_mut() | |
.zip(vector.iter()) | |
.for_each(|(c, v)| *c = f64::from(*v)); | |
} else { | |
let count = counter.counter[centroid_index] as f64; | |
centroid_data.iter_mut().for_each(|c| *c /= count); | |
} | |
} | |
let diff = centroids | |
.iter_mut() | |
.zip(counter.acc.iter()) | |
.map(|(c, c_acc)| { | |
let c_acc = *c_acc as f32; | |
let value = (*c - c_acc).abs(); | |
*c = c_acc; | |
value | |
}) | |
.sum::<f32>(); | |
diff < accuracy | |
} | |
fn update_indexes( | |
pool: &ThreadPool, | |
data: &[f32], | |
centroid_indexes: &mut [u32], | |
centroids: &[f32], | |
) { | |
let dim = data.len() / centroid_indexes.len(); | |
pool.install(|| { | |
centroid_indexes | |
.par_iter_mut() | |
.enumerate() | |
.for_each(|(i, c)| { | |
let vector_data = &data[dim * i..dim * (i + 1)]; | |
let mut min_distance = f32::MAX; | |
let mut min_centroid_index = 0; | |
for (centroid_index, centroid_data) in centroids.chunks_exact(dim).enumerate() { | |
let distance = vector_data | |
.iter() | |
.zip(centroid_data.iter()) | |
.map(|(a, b)| (a - b).powi(2)) | |
.sum(); | |
if distance < min_distance { | |
min_distance = distance; | |
min_centroid_index = centroid_index; | |
} | |
} | |
*c = min_centroid_index as u32; | |
}) | |
}); | |
} | |