Spaces:
Build error
Build error
File size: 2,165 Bytes
84d2a97 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 |
use std::path::Path;
use common::counter::hardware_counter::HardwareCounterCell;
use serde::{Deserialize, Serialize};
use crate::EncodingError;
/// Metric selector used by [`DistanceType::distance`] to compare two `f32` slices.
#[derive(Serialize, Deserialize, Debug, Clone, Copy, PartialEq, Eq)]
pub enum DistanceType {
/// Dot product: the sum of element-wise products.
Dot,
/// Sum of absolute element-wise differences.
L1,
/// Sum of squared element-wise differences (no square root is taken).
L2,
}
/// Description of a vector collection; `validate_vector_parameters` checks
/// input data against the `dim` and `count` fields.
#[derive(Serialize, Deserialize, Clone)]
pub struct VectorParameters {
/// Length that every vector is required to have.
pub dim: usize,
/// Number of vectors the collection is required to contain.
pub count: usize,
/// Metric associated with the vectors.
pub distance_type: DistanceType,
// NOTE(review): `invert` is not read anywhere in this file; presumably it
// tells encoders to invert/negate scores — confirm against the implementors.
pub invert: bool,
}
/// Interface for an encoded vector collection exposing save/load, query
/// encoding and scoring entry points.
///
/// `TEncodedQuery` is the query representation returned by `encode_query` and
/// accepted by `score_point`.
pub trait EncodedVectors<TEncodedQuery: Sized>: Sized {
/// Takes a data path and a metadata path; failures are reported as
/// `std::io::Error`.
// NOTE(review): presumably writes the encoded data to `data_path` and its
// metadata to `meta_path` — confirm in the implementors.
fn save(&self, data_path: &Path, meta_path: &Path) -> std::io::Result<()>;
/// Builds `Self` from a data path, a metadata path and the
/// `VectorParameters` describing the collection; failures are reported as
/// `std::io::Error`.
fn load(
data_path: &Path,
meta_path: &Path,
vector_parameters: &VectorParameters,
) -> std::io::Result<Self>;
/// Converts a raw `f32` query into the `TEncodedQuery` representation.
fn encode_query(&self, query: &[f32]) -> TEncodedQuery;
/// Returns an `f32` score for an encoded query and the index `i`.
// NOTE(review): `i` presumably indexes a stored vector and the counter cell
// presumably records hardware usage — confirm in the implementors.
// NOTE(review): `hw_couter` looks like a typo for `hw_counter`.
fn score_point(&self, query: &TEncodedQuery, i: u32, hw_couter: &HardwareCounterCell) -> f32;
/// Returns an `f32` score for the pair of indices `i` and `j`.
// NOTE(review): presumably scores two stored vectors against each other —
// confirm in the implementors.
fn score_internal(&self, i: u32, j: u32, hw_couter: &HardwareCounterCell) -> f32;
}
impl DistanceType {
pub fn distance(&self, a: &[f32], b: &[f32]) -> f32 {
match self {
DistanceType::Dot => a.iter().zip(b).map(|(a, b)| a * b).sum(),
DistanceType::L1 => a.iter().zip(b).map(|(a, b)| (a - b).abs()).sum(),
DistanceType::L2 => a.iter().zip(b).map(|(a, b)| (a - b) * (a - b)).sum(),
}
}
}
/// Checks that `data` agrees with `vector_parameters`.
///
/// Every vector must have exactly `vector_parameters.dim` elements, and the
/// total number of vectors must equal `vector_parameters.count`.
///
/// # Errors
///
/// Returns `EncodingError::ArgumentsError` for the first vector whose length
/// differs from `dim`, or, if all lengths match, when the number of vectors
/// differs from `count`.
pub(crate) fn validate_vector_parameters<'a>(
    data: impl Iterator<Item = impl AsRef<[f32]> + 'a> + Clone,
    vector_parameters: &VectorParameters,
) -> Result<(), EncodingError> {
    let mut vectors = data;

    // Walk the vectors once: bail out on the first bad length, otherwise
    // carry the running tally forward.
    let seen = vectors.try_fold(0usize, |seen, item| {
        let len = item.as_ref().len();
        if len == vector_parameters.dim {
            Ok(seen + 1)
        } else {
            Err(EncodingError::ArgumentsError(format!(
                "Vector length {} does not match vector parameters dim {}",
                len, vector_parameters.dim
            )))
        }
    })?;

    if seen == vector_parameters.count {
        Ok(())
    } else {
        Err(EncodingError::ArgumentsError(format!(
            "Vector count {} does not match vector parameters count {}",
            seen, vector_parameters.count
        )))
    }
}
|