use std::borrow::Cow; use common::validation::validate_multi_vector; use validator::{Validate, ValidationError, ValidationErrors}; use super::schema::BatchVectorStruct; use super::{ Batch, ContextInput, Fusion, OrderByInterface, PointVectors, Query, QueryInterface, RecommendInput, Sample, VectorInput, }; use crate::rest::NamedVectorStruct; impl Validate for NamedVectorStruct { fn validate(&self) -> Result<(), validator::ValidationErrors> { match self { NamedVectorStruct::Default(_) => Ok(()), NamedVectorStruct::Dense(_) => Ok(()), NamedVectorStruct::Sparse(v) => v.validate(), } } } impl Validate for QueryInterface { fn validate(&self) -> Result<(), validator::ValidationErrors> { match self { QueryInterface::Nearest(vector) => vector.validate(), QueryInterface::Query(query) => query.validate(), } } } impl Validate for Query { fn validate(&self) -> Result<(), validator::ValidationErrors> { match self { Query::Nearest(vector) => vector.nearest.validate(), Query::Recommend(recommend) => recommend.recommend.validate(), Query::Discover(discover) => discover.discover.validate(), Query::Context(context) => context.context.validate(), Query::Fusion(fusion) => fusion.fusion.validate(), Query::OrderBy(order_by) => order_by.order_by.validate(), Query::Sample(sample) => sample.sample.validate(), } } } impl Validate for VectorInput { fn validate(&self) -> Result<(), validator::ValidationErrors> { match self { VectorInput::Id(_id) => Ok(()), VectorInput::DenseVector(_dense) => Ok(()), VectorInput::SparseVector(sparse) => sparse.validate(), VectorInput::MultiDenseVector(multi) => validate_multi_vector(multi), VectorInput::Document(doc) => doc.validate(), VectorInput::Image(image) => image.validate(), VectorInput::Object(obj) => obj.validate(), } } } impl Validate for RecommendInput { fn validate(&self) -> Result<(), validator::ValidationErrors> { let no_positives = self.positive.as_ref().map(|p| p.is_empty()).unwrap_or(true); let no_negatives = self.negative.as_ref().map(|n| n.is_empty()).unwrap_or(true); if no_positives && no_negatives { let mut errors = validator::ValidationErrors::new(); errors.add( "positives, negatives", ValidationError::new( "At least one positive or negative vector/id must be provided", ), ); return Err(errors); } for item in self.iter() { item.validate()?; } Ok(()) } } impl Validate for ContextInput { fn validate(&self) -> Result<(), validator::ValidationErrors> { for item in self.0.iter().flatten().flat_map(|item| item.iter()) { item.validate()?; } Ok(()) } } impl Validate for Fusion { fn validate(&self) -> Result<(), validator::ValidationErrors> { match self { Fusion::Rrf | Fusion::Dbsf => Ok(()), } } } impl Validate for OrderByInterface { fn validate(&self) -> Result<(), validator::ValidationErrors> { match self { OrderByInterface::Key(_key) => Ok(()), // validated during parsing OrderByInterface::Struct(order_by) => order_by.validate(), } } } impl Validate for Sample { fn validate(&self) -> Result<(), ValidationErrors> { match self { Sample::Random => Ok(()), } } } impl Validate for BatchVectorStruct { fn validate(&self) -> Result<(), ValidationErrors> { match self { BatchVectorStruct::Single(_) => Ok(()), BatchVectorStruct::MultiDense(vectors) => { for vector in vectors { validate_multi_vector(vector)?; } Ok(()) } BatchVectorStruct::Named(v) => { common::validation::validate_iter(v.values().flat_map(|batch| batch.iter())) } BatchVectorStruct::Document(_) => Ok(()), BatchVectorStruct::Image(_) => Ok(()), BatchVectorStruct::Object(_) => Ok(()), } } } impl Validate for Batch { fn validate(&self) -> Result<(), ValidationErrors> { let batch = self; let bad_input_description = |ids: usize, vecs: usize| -> String { format!("number of ids and vectors must be equal ({ids} != {vecs})") }; let create_error = |message: String| -> ValidationErrors { let mut errors = ValidationErrors::new(); errors.add("batch", { let mut error = ValidationError::new("point_insert_operation"); error.message.replace(Cow::from(message)); error }); errors }; self.vectors.validate()?; match &batch.vectors { BatchVectorStruct::Single(vectors) => { if batch.ids.len() != vectors.len() { return Err(create_error(bad_input_description( batch.ids.len(), vectors.len(), ))); } } BatchVectorStruct::MultiDense(vectors) => { if batch.ids.len() != vectors.len() { return Err(create_error(bad_input_description( batch.ids.len(), vectors.len(), ))); } } BatchVectorStruct::Named(named_vectors) => { for vectors in named_vectors.values() { if batch.ids.len() != vectors.len() { return Err(create_error(bad_input_description( batch.ids.len(), vectors.len(), ))); } } } BatchVectorStruct::Document(_) => {} BatchVectorStruct::Image(_) => {} BatchVectorStruct::Object(_) => {} } if let Some(payload_vector) = &batch.payloads { if payload_vector.len() != batch.ids.len() { return Err(create_error(format!( "number of ids and payloads must be equal ({} != {})", batch.ids.len(), payload_vector.len(), ))); } } Ok(()) } } impl Validate for PointVectors { fn validate(&self) -> Result<(), ValidationErrors> { if self.vector.is_empty() { let mut err = ValidationError::new("length"); err.message = Some(Cow::from("must specify vectors to update for point")); err.add_param(Cow::from("min"), &1); let mut errors = ValidationErrors::new(); errors.add("vector", err); Err(errors) } else { self.vector.validate() } } }