//! Conversion of gRPC universal-query requests into internal collection queries,
//! resolving inference-based inputs (documents, images, objects) into vectors.
use api::conversions::json::json_path_from_proto; | |
use api::grpc::qdrant as grpc; | |
use api::grpc::qdrant::query::Variant; | |
use api::grpc::qdrant::RecommendInput; | |
use api::rest; | |
use api::rest::RecommendStrategy; | |
use collection::operations::universal_query::collection_query::{ | |
CollectionPrefetch, CollectionQueryGroupsRequest, CollectionQueryRequest, Query, | |
VectorInputInternal, VectorQuery, | |
}; | |
use collection::operations::universal_query::shard_query::{FusionInternal, SampleInternal}; | |
use segment::data_types::order_by::OrderBy; | |
use segment::data_types::vectors::{VectorInternal, DEFAULT_VECTOR_NAME}; | |
use segment::vector_storage::query::{ContextPair, ContextQuery, DiscoveryQuery, RecoQuery}; | |
use tonic::Status; | |
use crate::common::inference::batch_processing_grpc::{ | |
collect_prefetch, collect_query, BatchAccumGrpc, | |
}; | |
use crate::common::inference::infer_processing::BatchAccumInferred; | |
use crate::common::inference::service::{InferenceData, InferenceType}; | |
/// ToDo: this function is supposed to call an inference endpoint internally | |
pub async fn convert_query_point_groups_from_grpc( | |
query: grpc::QueryPointGroups, | |
) -> Result<CollectionQueryGroupsRequest, Status> { | |
let grpc::QueryPointGroups { | |
collection_name: _, | |
prefetch, | |
query, | |
using, | |
filter, | |
params, | |
score_threshold, | |
with_payload, | |
with_vectors, | |
lookup_from, | |
limit, | |
group_size, | |
group_by, | |
with_lookup, | |
read_consistency: _, | |
timeout: _, | |
shard_key_selector: _, | |
} = query; | |
let mut batch = BatchAccumGrpc::new(); | |
if let Some(q) = &query { | |
collect_query(q, &mut batch)?; | |
} | |
for p in &prefetch { | |
collect_prefetch(p, &mut batch)?; | |
} | |
let BatchAccumGrpc { objects } = batch; | |
let inferred = BatchAccumInferred::from_objects(objects, InferenceType::Search) | |
.await | |
.map_err(|e| Status::internal(format!("Inference error: {e}")))?; | |
let query = if let Some(q) = query { | |
Some(convert_query_with_inferred(q, &inferred)?) | |
} else { | |
None | |
}; | |
let prefetch = prefetch | |
.into_iter() | |
.map(|p| convert_prefetch_with_inferred(p, &inferred)) | |
.collect::<Result<Vec<_>, _>>()?; | |
let request = CollectionQueryGroupsRequest { | |
prefetch, | |
query, | |
using: using.unwrap_or(DEFAULT_VECTOR_NAME.to_string()), | |
filter: filter.map(TryFrom::try_from).transpose()?, | |
score_threshold, | |
with_vector: with_vectors | |
.map(From::from) | |
.unwrap_or(CollectionQueryRequest::DEFAULT_WITH_VECTOR), | |
with_payload: with_payload | |
.map(TryFrom::try_from) | |
.transpose()? | |
.unwrap_or(CollectionQueryRequest::DEFAULT_WITH_PAYLOAD), | |
lookup_from: lookup_from.map(From::from), | |
group_by: json_path_from_proto(&group_by)?, | |
group_size: group_size | |
.map(|s| s as usize) | |
.unwrap_or(CollectionQueryRequest::DEFAULT_GROUP_SIZE), | |
limit: limit | |
.map(|l| l as usize) | |
.unwrap_or(CollectionQueryRequest::DEFAULT_LIMIT), | |
params: params.map(From::from), | |
with_lookup: with_lookup.map(TryFrom::try_from).transpose()?, | |
}; | |
Ok(request) | |
} | |
/// ToDo: this function is supposed to call an inference endpoint internally | |
pub async fn convert_query_points_from_grpc( | |
query: grpc::QueryPoints, | |
) -> Result<CollectionQueryRequest, Status> { | |
let grpc::QueryPoints { | |
collection_name: _, | |
prefetch, | |
query, | |
using, | |
filter, | |
params, | |
score_threshold, | |
limit, | |
offset, | |
with_payload, | |
with_vectors, | |
read_consistency: _, | |
shard_key_selector: _, | |
lookup_from, | |
timeout: _, | |
} = query; | |
let mut batch = BatchAccumGrpc::new(); | |
if let Some(q) = &query { | |
collect_query(q, &mut batch)?; | |
} | |
for p in &prefetch { | |
collect_prefetch(p, &mut batch)?; | |
} | |
let BatchAccumGrpc { objects } = batch; | |
let inferred = BatchAccumInferred::from_objects(objects, InferenceType::Search) | |
.await | |
.map_err(|e| Status::internal(format!("Inference error: {e}")))?; | |
let prefetch = prefetch | |
.into_iter() | |
.map(|p| convert_prefetch_with_inferred(p, &inferred)) | |
.collect::<Result<Vec<_>, _>>()?; | |
let query = query | |
.map(|q| convert_query_with_inferred(q, &inferred)) | |
.transpose()?; | |
Ok(CollectionQueryRequest { | |
prefetch, | |
query, | |
using: using.unwrap_or(DEFAULT_VECTOR_NAME.to_string()), | |
filter: filter.map(TryFrom::try_from).transpose()?, | |
score_threshold, | |
limit: limit | |
.map(|l| l as usize) | |
.unwrap_or(CollectionQueryRequest::DEFAULT_LIMIT), | |
offset: offset | |
.map(|o| o as usize) | |
.unwrap_or(CollectionQueryRequest::DEFAULT_OFFSET), | |
params: params.map(From::from), | |
with_vector: with_vectors | |
.map(From::from) | |
.unwrap_or(CollectionQueryRequest::DEFAULT_WITH_VECTOR), | |
with_payload: with_payload | |
.map(TryFrom::try_from) | |
.transpose()? | |
.unwrap_or(CollectionQueryRequest::DEFAULT_WITH_PAYLOAD), | |
lookup_from: lookup_from.map(From::from), | |
}) | |
} | |
fn convert_prefetch_with_inferred( | |
prefetch: grpc::PrefetchQuery, | |
inferred: &BatchAccumInferred, | |
) -> Result<CollectionPrefetch, Status> { | |
let grpc::PrefetchQuery { | |
prefetch, | |
query, | |
using, | |
filter, | |
params, | |
score_threshold, | |
limit, | |
lookup_from, | |
} = prefetch; | |
let nested_prefetches = prefetch | |
.into_iter() | |
.map(|p| convert_prefetch_with_inferred(p, inferred)) | |
.collect::<Result<Vec<_>, _>>()?; | |
let query = query | |
.map(|q| convert_query_with_inferred(q, inferred)) | |
.transpose()?; | |
Ok(CollectionPrefetch { | |
prefetch: nested_prefetches, | |
query, | |
using: using.unwrap_or(DEFAULT_VECTOR_NAME.to_string()), | |
filter: filter.map(TryFrom::try_from).transpose()?, | |
score_threshold, | |
limit: limit | |
.map(|l| l as usize) | |
.unwrap_or(CollectionQueryRequest::DEFAULT_LIMIT), | |
params: params.map(From::from), | |
lookup_from: lookup_from.map(From::from), | |
}) | |
} | |
fn convert_query_with_inferred( | |
query: grpc::Query, | |
inferred: &BatchAccumInferred, | |
) -> Result<Query, Status> { | |
let variant = query | |
.variant | |
.ok_or_else(|| Status::invalid_argument("Query variant is missing"))?; | |
let query = match variant { | |
Variant::Nearest(nearest) => { | |
let vector = convert_vector_input_with_inferred(nearest, inferred)?; | |
Query::Vector(VectorQuery::Nearest(vector)) | |
} | |
Variant::Recommend(recommend) => { | |
let RecommendInput { | |
positive, | |
negative, | |
strategy, | |
} = recommend; | |
let positives = positive | |
.into_iter() | |
.map(|v| convert_vector_input_with_inferred(v, inferred)) | |
.collect::<Result<Vec<_>, _>>()?; | |
let negatives = negative | |
.into_iter() | |
.map(|v| convert_vector_input_with_inferred(v, inferred)) | |
.collect::<Result<Vec<_>, _>>()?; | |
let reco_query = RecoQuery::new(positives, negatives); | |
let strategy = strategy | |
.and_then(|x| grpc::RecommendStrategy::try_from(x).ok()) | |
.map(RecommendStrategy::from) | |
.unwrap_or_default(); | |
match strategy { | |
RecommendStrategy::AverageVector => { | |
Query::Vector(VectorQuery::RecommendAverageVector(reco_query)) | |
} | |
RecommendStrategy::BestScore => { | |
Query::Vector(VectorQuery::RecommendBestScore(reco_query)) | |
} | |
} | |
} | |
Variant::Discover(discover) => { | |
let grpc::DiscoverInput { target, context } = discover; | |
let target = target | |
.map(|t| convert_vector_input_with_inferred(t, inferred)) | |
.transpose()? | |
.ok_or_else(|| Status::invalid_argument("DiscoverInput target is missing"))?; | |
let grpc::ContextInput { pairs } = context | |
.ok_or_else(|| Status::invalid_argument("DiscoverInput context is missing"))?; | |
let context = pairs | |
.into_iter() | |
.map(|pair| context_pair_from_grpc_with_inferred(pair, inferred)) | |
.collect::<Result<_, _>>()?; | |
Query::Vector(VectorQuery::Discover(DiscoveryQuery::new(target, context))) | |
} | |
Variant::Context(context) => { | |
let context_query = context_query_from_grpc_with_inferred(context, inferred)?; | |
Query::Vector(VectorQuery::Context(context_query)) | |
} | |
Variant::OrderBy(order_by) => Query::OrderBy(OrderBy::try_from(order_by)?), | |
Variant::Fusion(fusion) => Query::Fusion(FusionInternal::try_from(fusion)?), | |
Variant::Sample(sample) => Query::Sample(SampleInternal::try_from(sample)?), | |
}; | |
Ok(query) | |
} | |
fn convert_vector_input_with_inferred( | |
vector: grpc::VectorInput, | |
inferred: &BatchAccumInferred, | |
) -> Result<VectorInputInternal, Status> { | |
use api::grpc::qdrant::vector_input::Variant; | |
let variant = vector | |
.variant | |
.ok_or_else(|| Status::invalid_argument("VectorInput variant is missing"))?; | |
match variant { | |
Variant::Id(id) => Ok(VectorInputInternal::Id(TryFrom::try_from(id)?)), | |
Variant::Dense(dense) => Ok(VectorInputInternal::Vector(VectorInternal::Dense( | |
From::from(dense), | |
))), | |
Variant::Sparse(sparse) => Ok(VectorInputInternal::Vector(VectorInternal::Sparse( | |
From::from(sparse), | |
))), | |
Variant::MultiDense(multi_dense) => Ok(VectorInputInternal::Vector( | |
VectorInternal::MultiDense(From::from(multi_dense)), | |
)), | |
Variant::Document(doc) => { | |
let doc: rest::Document = doc | |
.try_into() | |
.map_err(|e| Status::internal(format!("Document conversion error: {e}")))?; | |
let data = InferenceData::Document(doc); | |
let vector = inferred | |
.get_vector(&data) | |
.ok_or_else(|| Status::internal("Missing inferred vector for document"))?; | |
Ok(VectorInputInternal::Vector(VectorInternal::from( | |
vector.clone(), | |
))) | |
} | |
Variant::Image(img) => { | |
let img: rest::Image = img | |
.try_into() | |
.map_err(|e| Status::internal(format!("Image conversion error: {e}",)))?; | |
let data = InferenceData::Image(img); | |
let vector = inferred | |
.get_vector(&data) | |
.ok_or_else(|| Status::internal("Missing inferred vector for image"))?; | |
Ok(VectorInputInternal::Vector(VectorInternal::from( | |
vector.clone(), | |
))) | |
} | |
Variant::Object(obj) => { | |
let obj: rest::InferenceObject = obj | |
.try_into() | |
.map_err(|e| Status::internal(format!("Object conversion error: {e}")))?; | |
let data = InferenceData::Object(obj); | |
let vector = inferred | |
.get_vector(&data) | |
.ok_or_else(|| Status::internal("Missing inferred vector for object"))?; | |
Ok(VectorInputInternal::Vector(VectorInternal::from( | |
vector.clone(), | |
))) | |
} | |
} | |
} | |
fn context_query_from_grpc_with_inferred( | |
value: grpc::ContextInput, | |
inferred: &BatchAccumInferred, | |
) -> Result<ContextQuery<VectorInputInternal>, Status> { | |
let grpc::ContextInput { pairs } = value; | |
Ok(ContextQuery { | |
pairs: pairs | |
.into_iter() | |
.map(|pair| context_pair_from_grpc_with_inferred(pair, inferred)) | |
.collect::<Result<_, _>>()?, | |
}) | |
} | |
fn context_pair_from_grpc_with_inferred( | |
value: grpc::ContextInputPair, | |
inferred: &BatchAccumInferred, | |
) -> Result<ContextPair<VectorInputInternal>, Status> { | |
let grpc::ContextInputPair { positive, negative } = value; | |
let positive = | |
positive.ok_or_else(|| Status::invalid_argument("ContextPair positive is missing"))?; | |
let negative = | |
negative.ok_or_else(|| Status::invalid_argument("ContextPair negative is missing"))?; | |
Ok(ContextPair { | |
positive: convert_vector_input_with_inferred(positive, inferred)?, | |
negative: convert_vector_input_with_inferred(negative, inferred)?, | |
}) | |
} | |
#[cfg(test)]
mod tests {
    use std::collections::HashMap;

    use api::grpc::qdrant::value::Kind;
    use api::grpc::qdrant::vector_input::Variant;
    use api::grpc::qdrant::Value;
    use collection::operations::point_ops::VectorPersisted;

    use super::*;

    /// Minimal document input used across the inference tests.
    fn create_test_document() -> api::grpc::qdrant::Document {
        api::grpc::qdrant::Document {
            text: "test".to_string(),
            model: "test-model".to_string(),
            options: HashMap::new(),
        }
    }

    /// Minimal image input used across the inference tests.
    fn create_test_image() -> api::grpc::qdrant::Image {
        api::grpc::qdrant::Image {
            image: Some(Value {
                kind: Some(Kind::StringValue("test.jpg".to_string())),
            }),
            model: "test-model".to_string(),
            options: HashMap::new(),
        }
    }

    /// Minimal arbitrary inference object used across the tests.
    fn create_test_object() -> api::grpc::qdrant::InferenceObject {
        api::grpc::qdrant::InferenceObject {
            object: Some(Value {
                kind: Some(Kind::StringValue("test".to_string())),
            }),
            model: "test-model".to_string(),
            options: HashMap::new(),
        }
    }

    /// Builds an inferred batch mapping each test input (document, image,
    /// object) to the dense vector `[1.0, 2.0, 3.0]`.
    fn create_test_inferred_batch() -> BatchAccumInferred {
        let mut objects = HashMap::new();

        let grpc_doc = create_test_document();
        let grpc_img = create_test_image();
        let grpc_obj = create_test_object();

        let doc: rest::Document = grpc_doc.try_into().unwrap();
        let img: rest::Image = grpc_img.try_into().unwrap();
        let obj: rest::InferenceObject = grpc_obj.try_into().unwrap();

        let doc_data = InferenceData::Document(doc);
        let img_data = InferenceData::Image(img);
        let obj_data = InferenceData::Object(obj);

        let dense_vector = vec![1.0, 2.0, 3.0];
        let vector_persisted = VectorPersisted::Dense(dense_vector);

        objects.insert(doc_data, vector_persisted.clone());
        objects.insert(img_data, vector_persisted.clone());
        objects.insert(obj_data, vector_persisted);

        BatchAccumInferred { objects }
    }

    #[test]
    fn test_convert_vector_input_with_inferred_dense() {
        let inferred = create_test_inferred_batch();
        let vector = grpc::VectorInput {
            variant: Some(Variant::Dense(grpc::DenseVector {
                data: vec![1.0, 2.0, 3.0],
            })),
        };

        let result = convert_vector_input_with_inferred(vector, &inferred).unwrap();
        match result {
            VectorInputInternal::Vector(VectorInternal::Dense(values)) => {
                assert_eq!(values, vec![1.0, 2.0, 3.0]);
            }
            _ => panic!("Expected dense vector"),
        }
    }

    #[test]
    fn test_convert_vector_input_with_inferred_document() {
        let inferred = create_test_inferred_batch();
        let doc = create_test_document();
        let vector = grpc::VectorInput {
            variant: Some(Variant::Document(doc)),
        };

        let result = convert_vector_input_with_inferred(vector, &inferred).unwrap();
        match result {
            VectorInputInternal::Vector(VectorInternal::Dense(values)) => {
                assert_eq!(values, vec![1.0, 2.0, 3.0]);
            }
            _ => panic!("Expected dense vector from inference"),
        }
    }

    #[test]
    fn test_convert_vector_input_missing_variant() {
        let inferred = create_test_inferred_batch();
        let vector = grpc::VectorInput { variant: None };

        let result = convert_vector_input_with_inferred(vector, &inferred);
        assert!(result.is_err());
        assert!(result.unwrap_err().message().contains("variant is missing"));
    }

    #[test]
    fn test_context_pair_from_grpc_with_inferred() {
        let inferred = create_test_inferred_batch();
        let pair = grpc::ContextInputPair {
            positive: Some(grpc::VectorInput {
                variant: Some(Variant::Dense(grpc::DenseVector {
                    data: vec![1.0, 2.0, 3.0],
                })),
            }),
            negative: Some(grpc::VectorInput {
                variant: Some(Variant::Document(create_test_document())),
            }),
        };

        let result = context_pair_from_grpc_with_inferred(pair, &inferred).unwrap();
        match (result.positive, result.negative) {
            (
                VectorInputInternal::Vector(VectorInternal::Dense(pos)),
                VectorInputInternal::Vector(VectorInternal::Dense(neg)),
            ) => {
                assert_eq!(pos, vec![1.0, 2.0, 3.0]);
                assert_eq!(neg, vec![1.0, 2.0, 3.0]);
            }
            _ => panic!("Expected dense vectors"),
        }
    }

    #[test]
    fn test_context_pair_missing_vectors() {
        let inferred = create_test_inferred_batch();
        let pair = grpc::ContextInputPair {
            positive: None,
            negative: Some(grpc::VectorInput {
                variant: Some(Variant::Document(create_test_document())),
            }),
        };

        let result = context_pair_from_grpc_with_inferred(pair, &inferred);
        assert!(result.is_err());
        assert!(result
            .unwrap_err()
            .message()
            .contains("positive is missing"));
    }
}