diff --git a/src/actix/actix_telemetry.rs b/src/actix/actix_telemetry.rs
new file mode 100644
index 0000000000000000000000000000000000000000..8c8bcf0e4f0b8a3dc41a9a24c2d3052fb24bb4fe
--- /dev/null
+++ b/src/actix/actix_telemetry.rs
@@ -0,0 +1,90 @@
+use std::future::{ready, Ready};
+use std::sync::Arc;
+
+use actix_web::dev::{Service, ServiceRequest, ServiceResponse, Transform};
+use actix_web::Error;
+use futures_util::future::LocalBoxFuture;
+use parking_lot::Mutex;
+
+use crate::common::telemetry_ops::requests_telemetry::{
+    ActixTelemetryCollector, ActixWorkerTelemetryCollector,
+};
+
+pub struct ActixTelemetryService<S> {
+    service: S,
+    telemetry_data: Arc<Mutex<ActixWorkerTelemetryCollector>>,
+}
+
+pub struct ActixTelemetryTransform {
+    telemetry_collector: Arc<Mutex<ActixTelemetryCollector>>,
+}
+
+/// Actix telemetry service. It hooks every request and looks into response status code.
+///
+/// More about actix service with similar example
+///
+impl<S, B> Service<ServiceRequest> for ActixTelemetryService<S>
+where
+    S: Service<ServiceRequest, Response = ServiceResponse<B>, Error = Error>,
+    S::Future: 'static,
+    B: 'static,
+{
+    type Response = ServiceResponse<B>;
+    type Error = Error;
+    type Future = LocalBoxFuture<'static, Result<Self::Response, Self::Error>>;
+
+    actix_web::dev::forward_ready!(service);
+
+    fn call(&self, request: ServiceRequest) -> Self::Future {
+        let match_pattern = request
+            .match_pattern()
+            .unwrap_or_else(|| "unknown".to_owned());
+        let request_key = format!("{} {}", request.method(), match_pattern);
+        let future = self.service.call(request);
+        let telemetry_data = self.telemetry_data.clone();
+        Box::pin(async move {
+            let instant = std::time::Instant::now();
+            let response = future.await?;
+            let status = response.response().status().as_u16();
+            telemetry_data
+                .lock()
+                .add_response(request_key, status, instant);
+            Ok(response)
+        })
+    }
+}
+
+impl ActixTelemetryTransform {
+    pub fn new(telemetry_collector: Arc<Mutex<ActixTelemetryCollector>>) -> Self {
+        Self {
+            telemetry_collector,
+        }
+    }
+}
+
+/// Actix telemetry transform.
It's a builder for an actix service +/// +/// More about actix transform with similar example +/// +impl Transform for ActixTelemetryTransform +where + S: Service, Error = Error> + 'static, + S::Future: 'static, + B: 'static, +{ + type Response = ServiceResponse; + type Error = Error; + type Transform = ActixTelemetryService; + type InitError = (); + type Future = Ready>; + + fn new_transform(&self, service: S) -> Self::Future { + ready(Ok(ActixTelemetryService { + service, + telemetry_data: self + .telemetry_collector + .lock() + .create_web_worker_telemetry(), + })) + } +} diff --git a/src/actix/api/cluster_api.rs b/src/actix/api/cluster_api.rs new file mode 100644 index 0000000000000000000000000000000000000000..478fc1be395372f17619c98a627b66b8ef58a863 --- /dev/null +++ b/src/actix/api/cluster_api.rs @@ -0,0 +1,189 @@ +use std::future::Future; + +use actix_web::{delete, get, post, put, web, HttpResponse}; +use actix_web_validator::Query; +use collection::operations::verification::new_unchecked_verification_pass; +use schemars::JsonSchema; +use serde::{Deserialize, Serialize}; +use storage::content_manager::consensus_ops::ConsensusOperations; +use storage::content_manager::errors::StorageError; +use storage::dispatcher::Dispatcher; +use storage::rbac::AccessRequirements; +use validator::Validate; + +use crate::actix::auth::ActixAccess; +use crate::actix::helpers; + +#[derive(Debug, Deserialize, Validate)] +struct QueryParams { + #[serde(default)] + force: bool, + #[serde(default)] + #[validate(range(min = 1))] + timeout: Option, +} + +#[derive(Deserialize, Serialize, JsonSchema, Validate)] +pub struct MetadataParams { + #[serde(default)] + pub wait: bool, +} + +#[get("/cluster")] +fn cluster_status( + dispatcher: web::Data, + ActixAccess(access): ActixAccess, +) -> impl Future { + helpers::time(async move { + access.check_global_access(AccessRequirements::new())?; + Ok(dispatcher.cluster_status()) + }) +} + +#[post("/cluster/recover")] +fn recover_current_peer( + dispatcher: web::Data, + ActixAccess(access): ActixAccess, +) -> impl Future { + // Not a collection level request. + let pass = new_unchecked_verification_pass(); + + helpers::time(async move { + access.check_global_access(AccessRequirements::new().manage())?; + dispatcher.toc(&access, &pass).request_snapshot()?; + Ok(true) + }) +} + +#[delete("/cluster/peer/{peer_id}")] +fn remove_peer( + dispatcher: web::Data, + peer_id: web::Path, + Query(params): Query, + ActixAccess(access): ActixAccess, +) -> impl Future { + // Not a collection level request. 
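// Aside on the middleware in actix_telemetry.rs above: `ActixTelemetryTransform` is the
// piece that gets registered on the actix `App`, and `new_transform` runs once per worker,
// handing each worker its own `ActixWorkerTelemetryCollector`. A minimal wiring sketch
// (how the shared collector is obtained is an assumption, not shown in this diff):
//
//     let collector: Arc<Mutex<ActixTelemetryCollector>> = /* shared telemetry state */;
//     HttpServer::new(move || {
//         App::new().wrap(ActixTelemetryTransform::new(collector.clone()))
//     });
//
// Each handled request is then recorded under a key such as "GET /collections/{name}"
// together with its status code and elapsed time.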
+ let pass = new_unchecked_verification_pass(); + + helpers::time(async move { + access.check_global_access(AccessRequirements::new().manage())?; + + let dispatcher = dispatcher.into_inner(); + let toc = dispatcher.toc(&access, &pass); + let peer_id = peer_id.into_inner(); + + let has_shards = toc.peer_has_shards(peer_id).await; + if !params.force && has_shards { + return Err(StorageError::BadRequest { + description: format!("Cannot remove peer {peer_id} as there are shards on it"), + }); + } + + match dispatcher.consensus_state() { + Some(consensus_state) => { + consensus_state + .propose_consensus_op_with_await( + ConsensusOperations::RemovePeer(peer_id), + params.timeout.map(std::time::Duration::from_secs), + ) + .await + } + None => Err(StorageError::BadRequest { + description: "Distributed mode disabled.".to_string(), + }), + } + }) +} + +#[get("/cluster/metadata/keys")] +async fn get_cluster_metadata_keys( + dispatcher: web::Data, + ActixAccess(access): ActixAccess, +) -> HttpResponse { + helpers::time(async move { + access.check_global_access(AccessRequirements::new())?; + + let keys = dispatcher + .consensus_state() + .ok_or_else(|| StorageError::service_error("Qdrant is running in standalone mode"))? + .persistent + .read() + .get_cluster_metadata_keys(); + + Ok(keys) + }) + .await +} + +#[get("/cluster/metadata/keys/{key}")] +async fn get_cluster_metadata_key( + dispatcher: web::Data, + ActixAccess(access): ActixAccess, + key: web::Path, +) -> HttpResponse { + helpers::time(async move { + access.check_global_access(AccessRequirements::new())?; + + let value = dispatcher + .consensus_state() + .ok_or_else(|| StorageError::service_error("Qdrant is running in standalone mode"))? + .persistent + .read() + .get_cluster_metadata_key(key.as_ref()); + + Ok(value) + }) + .await +} + +#[put("/cluster/metadata/keys/{key}")] +async fn update_cluster_metadata_key( + dispatcher: web::Data, + ActixAccess(access): ActixAccess, + key: web::Path, + params: Query, + value: web::Json, +) -> HttpResponse { + // Not a collection level request. + let pass = new_unchecked_verification_pass(); + helpers::time(async move { + let toc = dispatcher.toc(&access, &pass); + access.check_global_access(AccessRequirements::new().write())?; + + toc.update_cluster_metadata(key.into_inner(), value.into_inner(), params.wait) + .await?; + Ok(true) + }) + .await +} + +#[delete("/cluster/metadata/keys/{key}")] +async fn delete_cluster_metadata_key( + dispatcher: web::Data, + ActixAccess(access): ActixAccess, + key: web::Path, + params: Query, +) -> HttpResponse { + // Not a collection level request. 
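// `helpers::time` (from the `helpers` module imported above, not part of this hunk) is the
// wrapper used by these handlers: it awaits a future returning `Result<T, StorageError>`,
// measures the elapsed time, and renders the usual `{"result": ..., "status": "ok",
// "time": ...}` envelope. Roughly this shape (a sketch, not the literal implementation):
//
//     pub async fn time<T, Fut>(future: Fut) -> HttpResponse
//     where
//         T: serde::Serialize,
//         Fut: Future<Output = Result<T, StorageError>>,
//     {
//         let timing = Instant::now();
//         let result = future.await;
//         process_response(result, timing, None)
//     }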
+ let pass = new_unchecked_verification_pass(); + helpers::time(async move { + let toc = dispatcher.toc(&access, &pass); + access.check_global_access(AccessRequirements::new().write())?; + + toc.update_cluster_metadata(key.into_inner(), serde_json::Value::Null, params.wait) + .await?; + Ok(true) + }) + .await +} + +// Configure services +pub fn config_cluster_api(cfg: &mut web::ServiceConfig) { + cfg.service(cluster_status) + .service(remove_peer) + .service(recover_current_peer) + .service(get_cluster_metadata_keys) + .service(get_cluster_metadata_key) + .service(update_cluster_metadata_key) + .service(delete_cluster_metadata_key); +} diff --git a/src/actix/api/collections_api.rs b/src/actix/api/collections_api.rs new file mode 100644 index 0000000000000000000000000000000000000000..2c1739a1e1b2912967bd0b6b0ed6136a7c19adfb --- /dev/null +++ b/src/actix/api/collections_api.rs @@ -0,0 +1,256 @@ +use std::time::Duration; + +use actix_web::rt::time::Instant; +use actix_web::{delete, get, patch, post, put, web, HttpResponse, Responder}; +use actix_web_validator::{Json, Path, Query}; +use collection::operations::cluster_ops::ClusterOperations; +use collection::operations::verification::new_unchecked_verification_pass; +use serde::Deserialize; +use storage::content_manager::collection_meta_ops::{ + ChangeAliasesOperation, CollectionMetaOperations, CreateCollection, CreateCollectionOperation, + DeleteCollectionOperation, UpdateCollection, UpdateCollectionOperation, +}; +use storage::dispatcher::Dispatcher; +use validator::Validate; + +use super::CollectionPath; +use crate::actix::api::StrictCollectionPath; +use crate::actix::auth::ActixAccess; +use crate::actix::helpers::{self, process_response}; +use crate::common::collections::*; + +#[derive(Debug, Deserialize, Validate)] +pub struct WaitTimeout { + #[validate(range(min = 1))] + timeout: Option, +} + +impl WaitTimeout { + pub fn timeout(&self) -> Option { + self.timeout.map(Duration::from_secs) + } +} + +#[get("/collections")] +async fn get_collections( + dispatcher: web::Data, + ActixAccess(access): ActixAccess, +) -> HttpResponse { + // No request to verify + let pass = new_unchecked_verification_pass(); + + helpers::time(do_list_collections(dispatcher.toc(&access, &pass), access)).await +} + +#[get("/aliases")] +async fn get_aliases( + dispatcher: web::Data, + ActixAccess(access): ActixAccess, +) -> HttpResponse { + // No request to verify + let pass = new_unchecked_verification_pass(); + + helpers::time(do_list_aliases(dispatcher.toc(&access, &pass), access)).await +} + +#[get("/collections/{name}")] +async fn get_collection( + dispatcher: web::Data, + collection: Path, + ActixAccess(access): ActixAccess, +) -> HttpResponse { + // No request to verify + let pass = new_unchecked_verification_pass(); + + helpers::time(do_get_collection( + dispatcher.toc(&access, &pass), + access, + &collection.name, + None, + )) + .await +} + +#[get("/collections/{name}/exists")] +async fn get_collection_existence( + dispatcher: web::Data, + collection: Path, + ActixAccess(access): ActixAccess, +) -> HttpResponse { + // No request to verify + let pass = new_unchecked_verification_pass(); + + helpers::time(do_collection_exists( + dispatcher.toc(&access, &pass), + access, + &collection.name, + )) + .await +} + +#[get("/collections/{name}/aliases")] +async fn get_collection_aliases( + dispatcher: web::Data, + collection: Path, + ActixAccess(access): ActixAccess, +) -> HttpResponse { + // No request to verify + let pass = new_unchecked_verification_pass(); + + 
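// A note on the recurring `new_unchecked_verification_pass()` calls: `Dispatcher::toc`
// demands a `VerificationPass` as proof that strict-mode checks were considered. Read-only
// endpoints like the listings in this file have no request body to inspect, so they create
// an unchecked pass; endpoints that do carry a request (search, scroll, count and friends
// later in this diff) obtain their pass from `check_strict_mode` instead.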
helpers::time(do_list_collection_aliases( + dispatcher.toc(&access, &pass), + access, + &collection.name, + )) + .await +} + +#[put("/collections/{name}")] +async fn create_collection( + dispatcher: web::Data, + collection: Path, + operation: Json, + Query(query): Query, + ActixAccess(access): ActixAccess, +) -> HttpResponse { + helpers::time(dispatcher.submit_collection_meta_op( + CollectionMetaOperations::CreateCollection(CreateCollectionOperation::new( + collection.name.clone(), + operation.into_inner(), + )), + access, + query.timeout(), + )) + .await +} + +#[patch("/collections/{name}")] +async fn update_collection( + dispatcher: web::Data, + collection: Path, + operation: Json, + Query(query): Query, + ActixAccess(access): ActixAccess, +) -> impl Responder { + let timing = Instant::now(); + let name = collection.name.clone(); + let response = dispatcher + .submit_collection_meta_op( + CollectionMetaOperations::UpdateCollection(UpdateCollectionOperation::new( + name, + operation.into_inner(), + )), + access, + query.timeout(), + ) + .await; + process_response(response, timing, None) +} + +#[delete("/collections/{name}")] +async fn delete_collection( + dispatcher: web::Data, + collection: Path, + Query(query): Query, + ActixAccess(access): ActixAccess, +) -> impl Responder { + let timing = Instant::now(); + let response = dispatcher + .submit_collection_meta_op( + CollectionMetaOperations::DeleteCollection(DeleteCollectionOperation( + collection.name.clone(), + )), + access, + query.timeout(), + ) + .await; + process_response(response, timing, None) +} + +#[post("/collections/aliases")] +async fn update_aliases( + dispatcher: web::Data, + operation: Json, + Query(query): Query, + ActixAccess(access): ActixAccess, +) -> impl Responder { + let timing = Instant::now(); + let response = dispatcher + .submit_collection_meta_op( + CollectionMetaOperations::ChangeAliases(operation.0), + access, + query.timeout(), + ) + .await; + process_response(response, timing, None) +} + +#[get("/collections/{name}/cluster")] +async fn get_cluster_info( + dispatcher: web::Data, + collection: Path, + ActixAccess(access): ActixAccess, +) -> impl Responder { + // No request to verify + let pass = new_unchecked_verification_pass(); + + helpers::time(do_get_collection_cluster( + dispatcher.toc(&access, &pass), + access, + &collection.name, + )) + .await +} + +#[post("/collections/{name}/cluster")] +async fn update_collection_cluster( + dispatcher: web::Data, + collection: Path, + operation: Json, + Query(query): Query, + ActixAccess(access): ActixAccess, +) -> impl Responder { + let timing = Instant::now(); + let wait_timeout = query.timeout(); + let response = do_update_collection_cluster( + &dispatcher.into_inner(), + collection.name.clone(), + operation.0, + access, + wait_timeout, + ) + .await; + process_response(response, timing, None) +} + +// Configure services +pub fn config_collections_api(cfg: &mut web::ServiceConfig) { + // Ordering of services is important for correct path pattern matching + // See: + cfg.service(update_aliases) + .service(get_collections) + .service(get_collection) + .service(get_collection_existence) + .service(create_collection) + .service(update_collection) + .service(delete_collection) + .service(get_aliases) + .service(get_collection_aliases) + .service(get_cluster_info) + .service(update_collection_cluster); +} + +#[cfg(test)] +mod tests { + use actix_web::web::Query; + + use super::WaitTimeout; + + #[test] + fn timeout_is_deserialized() { + let timeout: WaitTimeout = 
Query::from_query("").unwrap().0; + assert!(timeout.timeout.is_none()); + let timeout: WaitTimeout = Query::from_query("timeout=10").unwrap().0; + assert_eq!(timeout.timeout, Some(10)) + } +} diff --git a/src/actix/api/count_api.rs b/src/actix/api/count_api.rs new file mode 100644 index 0000000000000000000000000000000000000000..a536e82debb3289dce6c0e6ff65e2260cc39fc0f --- /dev/null +++ b/src/actix/api/count_api.rs @@ -0,0 +1,69 @@ +use actix_web::{post, web, Responder}; +use actix_web_validator::{Json, Path, Query}; +use collection::operations::shard_selector_internal::ShardSelectorInternal; +use collection::operations::types::CountRequest; +use storage::content_manager::collection_verification::check_strict_mode; +use storage::dispatcher::Dispatcher; +use tokio::time::Instant; + +use super::CollectionPath; +use crate::actix::api::read_params::ReadParams; +use crate::actix::auth::ActixAccess; +use crate::actix::helpers::{self, get_request_hardware_counter, process_response_error}; +use crate::common::points::do_count_points; +use crate::settings::ServiceConfig; + +#[post("/collections/{name}/points/count")] +async fn count_points( + dispatcher: web::Data, + collection: Path, + request: Json, + params: Query, + service_config: web::Data, + ActixAccess(access): ActixAccess, +) -> impl Responder { + let CountRequest { + count_request, + shard_key, + } = request.into_inner(); + + let pass = match check_strict_mode( + &count_request, + params.timeout_as_secs(), + &collection.name, + &dispatcher, + &access, + ) + .await + { + Ok(pass) => pass, + Err(err) => return process_response_error(err, Instant::now(), None), + }; + + let shard_selector = match shard_key { + None => ShardSelectorInternal::All, + Some(shard_keys) => ShardSelectorInternal::from(shard_keys), + }; + + let request_hw_counter = get_request_hardware_counter( + &dispatcher, + collection.name.clone(), + service_config.hardware_reporting(), + ); + + let timing = Instant::now(); + + let result = do_count_points( + dispatcher.toc(&access, &pass), + &collection.name, + count_request, + params.consistency, + params.timeout(), + shard_selector, + access, + request_hw_counter.get_counter(), + ) + .await; + + helpers::process_response(result, timing, request_hw_counter.to_rest_api()) +} diff --git a/src/actix/api/debug_api.rs b/src/actix/api/debug_api.rs new file mode 100644 index 0000000000000000000000000000000000000000..3e7a1b39ac1b6ff2e49145a25978ad92719aa924 --- /dev/null +++ b/src/actix/api/debug_api.rs @@ -0,0 +1,36 @@ +use actix_web::{get, patch, web, Responder}; +use storage::rbac::AccessRequirements; + +use crate::actix::auth::ActixAccess; +use crate::common::debugger::{DebugConfigPatch, DebuggerState}; + +#[get("/debugger")] +async fn get_debugger_config( + ActixAccess(access): ActixAccess, + debugger_state: web::Data, +) -> impl Responder { + crate::actix::helpers::time(async move { + access.check_global_access(AccessRequirements::new().manage())?; + Ok(debugger_state.get_config()) + }) + .await +} + +#[patch("/debugger")] +async fn update_debugger_config( + ActixAccess(access): ActixAccess, + debugger_state: web::Data, + debug_patch: web::Json, +) -> impl Responder { + crate::actix::helpers::time(async move { + access.check_global_access(AccessRequirements::new().manage())?; + Ok(debugger_state.apply_config_patch(debug_patch.into_inner())) + }) + .await +} + +// Configure services +pub fn config_debugger_api(cfg: &mut web::ServiceConfig) { + cfg.service(get_debugger_config); + cfg.service(update_debugger_config); +} diff --git 
a/src/actix/api/discovery_api.rs b/src/actix/api/discovery_api.rs new file mode 100644 index 0000000000000000000000000000000000000000..c465045c5b33323fb63f37972febe675028d91ea --- /dev/null +++ b/src/actix/api/discovery_api.rs @@ -0,0 +1,140 @@ +use actix_web::{post, web, Responder}; +use actix_web_validator::{Json, Path, Query}; +use collection::operations::shard_selector_internal::ShardSelectorInternal; +use collection::operations::types::{DiscoverRequest, DiscoverRequestBatch}; +use itertools::Itertools; +use storage::content_manager::collection_verification::{ + check_strict_mode, check_strict_mode_batch, +}; +use storage::dispatcher::Dispatcher; +use tokio::time::Instant; + +use crate::actix::api::read_params::ReadParams; +use crate::actix::api::CollectionPath; +use crate::actix::auth::ActixAccess; +use crate::actix::helpers::{self, get_request_hardware_counter, process_response_error}; +use crate::common::points::do_discover_batch_points; +use crate::settings::ServiceConfig; + +#[post("/collections/{name}/points/discover")] +async fn discover_points( + dispatcher: web::Data, + collection: Path, + request: Json, + params: Query, + service_config: web::Data, + ActixAccess(access): ActixAccess, +) -> impl Responder { + let DiscoverRequest { + discover_request, + shard_key, + } = request.into_inner(); + + let pass = match check_strict_mode( + &discover_request, + params.timeout_as_secs(), + &collection.name, + &dispatcher, + &access, + ) + .await + { + Ok(pass) => pass, + Err(err) => return process_response_error(err, Instant::now(), None), + }; + + let shard_selection = match shard_key { + None => ShardSelectorInternal::All, + Some(shard_keys) => shard_keys.into(), + }; + + let request_hw_counter = get_request_hardware_counter( + &dispatcher, + collection.name.clone(), + service_config.hardware_reporting(), + ); + + let timing = Instant::now(); + + let result = dispatcher + .toc(&access, &pass) + .discover( + &collection.name, + discover_request, + params.consistency, + shard_selection, + access, + params.timeout(), + request_hw_counter.get_counter(), + ) + .await + .map(|scored_points| { + scored_points + .into_iter() + .map(api::rest::ScoredPoint::from) + .collect_vec() + }); + + helpers::process_response(result, timing, request_hw_counter.to_rest_api()) +} + +#[post("/collections/{name}/points/discover/batch")] +async fn discover_batch_points( + dispatcher: web::Data, + collection: Path, + request: Json, + params: Query, + service_config: web::Data, + ActixAccess(access): ActixAccess, +) -> impl Responder { + let request = request.into_inner(); + + let pass = match check_strict_mode_batch( + request.searches.iter().map(|i| &i.discover_request), + params.timeout_as_secs(), + &collection.name, + &dispatcher, + &access, + ) + .await + { + Ok(pass) => pass, + Err(err) => return process_response_error(err, Instant::now(), None), + }; + + let request_hw_counter = get_request_hardware_counter( + &dispatcher, + collection.name.clone(), + service_config.hardware_reporting(), + ); + let timing = Instant::now(); + + let result = do_discover_batch_points( + dispatcher.toc(&access, &pass), + &collection.name, + request, + params.consistency, + access, + params.timeout(), + request_hw_counter.get_counter(), + ) + .await + .map(|batch_scored_points| { + batch_scored_points + .into_iter() + .map(|scored_points| { + scored_points + .into_iter() + .map(api::rest::ScoredPoint::from) + .collect_vec() + }) + .collect_vec() + }); + + helpers::process_response(result, timing, 
request_hw_counter.to_rest_api()) +} + +pub fn config_discovery_api(cfg: &mut web::ServiceConfig) { + cfg.service(discover_points); + cfg.service(discover_batch_points); +} diff --git a/src/actix/api/facet_api.rs b/src/actix/api/facet_api.rs new file mode 100644 index 0000000000000000000000000000000000000000..516d21b9a96ad8cc78fe13c8f564d6053af5e1af --- /dev/null +++ b/src/actix/api/facet_api.rs @@ -0,0 +1,77 @@ +use actix_web::{post, web, Responder}; +use actix_web_validator::{Json, Path, Query}; +use api::rest::{FacetRequest, FacetResponse}; +use collection::operations::shard_selector_internal::ShardSelectorInternal; +use storage::content_manager::collection_verification::check_strict_mode; +use storage::dispatcher::Dispatcher; +use tokio::time::Instant; + +use crate::actix::api::read_params::ReadParams; +use crate::actix::api::CollectionPath; +use crate::actix::auth::ActixAccess; +use crate::actix::helpers::{ + get_request_hardware_counter, process_response, process_response_error, +}; +use crate::settings::ServiceConfig; + +#[post("/collections/{name}/facet")] +async fn facet( + dispatcher: web::Data, + collection: Path, + request: Json, + params: Query, + service_config: web::Data, + ActixAccess(access): ActixAccess, +) -> impl Responder { + let timing = Instant::now(); + + let FacetRequest { + facet_request, + shard_key, + } = request.into_inner(); + + let pass = match check_strict_mode( + &facet_request, + params.timeout_as_secs(), + &collection.name, + &dispatcher, + &access, + ) + .await + { + Ok(pass) => pass, + Err(err) => return process_response_error(err, timing, None), + }; + + let facet_params = From::from(facet_request); + + let shard_selection = match shard_key { + None => ShardSelectorInternal::All, + Some(shard_keys) => shard_keys.into(), + }; + + let request_hw_counter = get_request_hardware_counter( + &dispatcher, + collection.name.clone(), + service_config.hardware_reporting(), + ); + + let response = dispatcher + .toc(&access, &pass) + .facet( + &collection.name, + facet_params, + shard_selection, + params.consistency, + access, + params.timeout(), + ) + .await + .map(FacetResponse::from); + + process_response(response, timing, request_hw_counter.to_rest_api()) +} + +pub fn config_facet_api(cfg: &mut web::ServiceConfig) { + cfg.service(facet); +} diff --git a/src/actix/api/issues_api.rs b/src/actix/api/issues_api.rs new file mode 100644 index 0000000000000000000000000000000000000000..0920ee239a20789abbe0a3062dcf2c9b77625a67 --- /dev/null +++ b/src/actix/api/issues_api.rs @@ -0,0 +1,32 @@ +use actix_web::{delete, get, web, Responder}; +use collection::operations::types::IssuesReport; +use storage::rbac::AccessRequirements; + +use crate::actix::auth::ActixAccess; + +#[get("/issues")] +async fn get_issues(ActixAccess(access): ActixAccess) -> impl Responder { + crate::actix::helpers::time(async move { + access.check_global_access(AccessRequirements::new().manage())?; + Ok(IssuesReport { + issues: issues::all_issues(), + }) + }) + .await +} + +#[delete("/issues")] +async fn clear_issues(ActixAccess(access): ActixAccess) -> impl Responder { + crate::actix::helpers::time(async move { + access.check_global_access(AccessRequirements::new().manage())?; + issues::clear(); + Ok(true) + }) + .await +} + +// Configure services +pub fn config_issues_api(cfg: &mut web::ServiceConfig) { + cfg.service(get_issues); + cfg.service(clear_issues); +} diff --git a/src/actix/api/local_shard_api.rs b/src/actix/api/local_shard_api.rs new file mode 100644 index 
0000000000000000000000000000000000000000..e7b027a13cc42b3614cd29d2d3cb4b686d24a5da --- /dev/null +++ b/src/actix/api/local_shard_api.rs @@ -0,0 +1,267 @@ +use std::sync::Arc; + +use actix_web::{post, web, Responder}; +use collection::operations::shard_selector_internal::ShardSelectorInternal; +use collection::operations::types::{ + CountRequestInternal, PointRequestInternal, ScrollRequestInternal, +}; +use collection::operations::verification::{new_unchecked_verification_pass, VerificationPass}; +use collection::shards::shard::ShardId; +use segment::types::{Condition, Filter}; +use storage::content_manager::collection_verification::check_strict_mode; +use storage::content_manager::errors::{StorageError, StorageResult}; +use storage::dispatcher::Dispatcher; +use storage::rbac::{Access, AccessRequirements}; +use tokio::time::Instant; + +use crate::actix::api::read_params::ReadParams; +use crate::actix::auth::ActixAccess; +use crate::actix::helpers::{self, get_request_hardware_counter, process_response_error}; +use crate::common::points; +use crate::settings::ServiceConfig; + +// Configure services +pub fn config_local_shard_api(cfg: &mut web::ServiceConfig) { + cfg.service(get_points) + .service(scroll_points) + .service(count_points) + .service(cleanup_shard); +} + +#[post("/collections/{collection}/shards/{shard}/points")] +async fn get_points( + dispatcher: web::Data, + ActixAccess(access): ActixAccess, + path: web::Path, + request: web::Json, + params: web::Query, +) -> impl Responder { + // No strict mode verification needed + let pass = new_unchecked_verification_pass(); + + helpers::time(async move { + let records = points::do_get_points( + dispatcher.toc(&access, &pass), + &path.collection, + request.into_inner(), + params.consistency, + params.timeout(), + ShardSelectorInternal::ShardId(path.shard), + access, + ) + .await?; + + let records: Vec<_> = records.into_iter().map(api::rest::Record::from).collect(); + Ok(records) + }) + .await +} + +#[post("/collections/{collection}/shards/{shard}/points/scroll")] +async fn scroll_points( + dispatcher: web::Data, + ActixAccess(access): ActixAccess, + path: web::Path, + request: web::Json>, + params: web::Query, +) -> impl Responder { + let WithFilter { + mut request, + hash_ring_filter, + } = request.into_inner(); + + let pass = match check_strict_mode( + &request, + params.timeout_as_secs(), + &path.collection, + &dispatcher, + &access, + ) + .await + { + Ok(pass) => pass, + Err(err) => return process_response_error(err, Instant::now(), None), + }; + + helpers::time(async move { + let hash_ring_filter = match hash_ring_filter { + Some(filter) => get_hash_ring_filter( + &dispatcher, + &access, + &path.collection, + AccessRequirements::new(), + filter.expected_shard_id, + &pass, + ) + .await? 
+ .into(), + + None => None, + }; + + request.filter = merge_with_optional_filter(request.filter.take(), hash_ring_filter); + + dispatcher + .toc(&access, &pass) + .scroll( + &path.collection, + request, + params.consistency, + params.timeout(), + ShardSelectorInternal::ShardId(path.shard), + access, + ) + .await + }) + .await +} + +#[post("/collections/{collection}/shards/{shard}/points/count")] +async fn count_points( + dispatcher: web::Data, + ActixAccess(access): ActixAccess, + path: web::Path, + request: web::Json>, + params: web::Query, + service_config: web::Data, +) -> impl Responder { + let WithFilter { + mut request, + hash_ring_filter, + } = request.into_inner(); + + let pass = match check_strict_mode( + &request, + params.timeout_as_secs(), + &path.collection, + &dispatcher, + &access, + ) + .await + { + Ok(pass) => pass, + Err(err) => return process_response_error(err, Instant::now(), None), + }; + + let request_hw_counter = get_request_hardware_counter( + &dispatcher, + path.collection.clone(), + service_config.hardware_reporting(), + ); + let timing = Instant::now(); + let hw_measurement_acc = request_hw_counter.get_counter(); + + let result = async move { + let hash_ring_filter = match hash_ring_filter { + Some(filter) => get_hash_ring_filter( + &dispatcher, + &access, + &path.collection, + AccessRequirements::new(), + filter.expected_shard_id, + &pass, + ) + .await? + .into(), + + None => None, + }; + + request.filter = merge_with_optional_filter(request.filter.take(), hash_ring_filter); + + points::do_count_points( + dispatcher.toc(&access, &pass), + &path.collection, + request, + params.consistency, + params.timeout(), + ShardSelectorInternal::ShardId(path.shard), + access, + hw_measurement_acc, + ) + .await + } + .await; + + helpers::process_response(result, timing, request_hw_counter.to_rest_api()) +} + +#[post("/collections/{collection}/shards/{shard}/cleanup")] +async fn cleanup_shard( + dispatcher: web::Data, + ActixAccess(access): ActixAccess, + path: web::Path, +) -> impl Responder { + // Nothing to verify here. + let pass = new_unchecked_verification_pass(); + + helpers::time(async move { + let path = path.into_inner(); + dispatcher + .toc(&access, &pass) + .cleanup_local_shard(&path.collection, path.shard, access) + .await + }) + .await +} + +#[derive(serde::Deserialize, validator::Validate)] +struct CollectionShard { + #[validate(length(min = 1, max = 255))] + collection: String, + shard: ShardId, +} + +#[derive(Clone, Debug, serde::Deserialize)] +struct WithFilter { + #[serde(flatten)] + request: T, + #[serde(default)] + hash_ring_filter: Option, +} + +#[derive(Clone, Debug, serde::Deserialize)] +struct SerdeHelper { + expected_shard_id: ShardId, +} + +async fn get_hash_ring_filter( + dispatcher: &Dispatcher, + access: &Access, + collection: &str, + reqs: AccessRequirements, + expected_shard_id: ShardId, + verification_pass: &VerificationPass, +) -> StorageResult { + let pass = access.check_collection_access(collection, reqs)?; + + let shard_holder = dispatcher + .toc(access, verification_pass) + .get_collection(&pass) + .await? 
+ .shards_holder(); + + let hash_ring_filter = shard_holder + .read() + .await + .hash_ring_filter(expected_shard_id) + .ok_or_else(|| { + StorageError::bad_request(format!( + "shard {expected_shard_id} does not exist in collection {collection}" + )) + })?; + + let condition = Condition::CustomIdChecker(Arc::new(hash_ring_filter)); + let filter = Filter::new_must(condition); + + Ok(filter) +} + +fn merge_with_optional_filter(filter: Option, hash_ring: Option) -> Option { + match (filter, hash_ring) { + (Some(filter), Some(hash_ring)) => hash_ring.merge_owned(filter).into(), + (Some(filter), None) => filter.into(), + (None, Some(hash_ring)) => hash_ring.into(), + _ => None, + } +} diff --git a/src/actix/api/mod.rs b/src/actix/api/mod.rs new file mode 100644 index 0000000000000000000000000000000000000000..2c7821a2d69cb3c3f0dcd32801ddf6d721ee9a13 --- /dev/null +++ b/src/actix/api/mod.rs @@ -0,0 +1,46 @@ +use common::validation::validate_collection_name; +use serde::Deserialize; +use validator::Validate; + +pub mod cluster_api; +pub mod collections_api; +pub mod count_api; +pub mod debug_api; +pub mod discovery_api; +pub mod facet_api; +pub mod issues_api; +pub mod local_shard_api; +pub mod query_api; +pub mod read_params; +pub mod recommend_api; +pub mod retrieve_api; +pub mod search_api; +pub mod service_api; +pub mod shards_api; +pub mod snapshot_api; +pub mod update_api; + +/// A collection path with stricter validation +/// +/// Validation for collection paths has been made more strict over time. +/// To prevent breaking changes on existing collections, this is only enforced for newly created +/// collections. Basic validation is enforced everywhere else. +#[derive(Deserialize, Validate)] +struct StrictCollectionPath { + #[validate( + length(min = 1, max = 255), + custom(function = "validate_collection_name") + )] + name: String, +} + +/// A collection path with basic validation +/// +/// Validation for collection paths has been made more strict over time. +/// To prevent breaking changes on existing collections, this is only enforced for newly created +/// collections. Basic validation is enforced everywhere else. 
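// For context on the two path types in this module: `StrictCollectionPath` above adds the
// `validate_collection_name` check (presumably rejecting characters that are unsafe in
// filesystem paths) and is only used where collections are created, while `CollectionPath`
// below keeps the original length-only validation so that existing collections with
// unusual names remain addressable.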
+#[derive(Deserialize, Validate)] +struct CollectionPath { + #[validate(length(min = 1, max = 255))] + name: String, +} diff --git a/src/actix/api/query_api.rs b/src/actix/api/query_api.rs new file mode 100644 index 0000000000000000000000000000000000000000..7cbb2acca022cd614fb0b697024b51483c0ac7fe --- /dev/null +++ b/src/actix/api/query_api.rs @@ -0,0 +1,232 @@ +use actix_web::{post, web, Responder}; +use actix_web_validator::{Json, Path, Query}; +use api::rest::{QueryGroupsRequest, QueryRequest, QueryRequestBatch, QueryResponse}; +use collection::operations::shard_selector_internal::ShardSelectorInternal; +use itertools::Itertools; +use storage::content_manager::collection_verification::{ + check_strict_mode, check_strict_mode_batch, +}; +use storage::content_manager::errors::StorageError; +use storage::dispatcher::Dispatcher; +use tokio::time::Instant; + +use super::read_params::ReadParams; +use super::CollectionPath; +use crate::actix::auth::ActixAccess; +use crate::actix::helpers::{self, get_request_hardware_counter, process_response_error}; +use crate::common::inference::query_requests_rest::{ + convert_query_groups_request_from_rest, convert_query_request_from_rest, +}; +use crate::common::points::do_query_point_groups; +use crate::settings::ServiceConfig; + +#[post("/collections/{name}/points/query")] +async fn query_points( + dispatcher: web::Data, + collection: Path, + request: Json, + params: Query, + service_config: web::Data, + ActixAccess(access): ActixAccess, +) -> impl Responder { + let QueryRequest { + internal: query_request, + shard_key, + } = request.into_inner(); + + let pass = match check_strict_mode( + &query_request, + params.timeout_as_secs(), + &collection.name, + &dispatcher, + &access, + ) + .await + { + Ok(pass) => pass, + Err(err) => return process_response_error(err, Instant::now(), None), + }; + + let request_hw_counter = get_request_hardware_counter( + &dispatcher, + collection.name.clone(), + service_config.hardware_reporting(), + ); + let timing = Instant::now(); + + let shard_selection = match shard_key { + None => ShardSelectorInternal::All, + Some(shard_keys) => shard_keys.into(), + }; + let hw_measurement_acc = request_hw_counter.get_counter(); + + let result = async move { + let request = convert_query_request_from_rest(query_request).await?; + + let points = dispatcher + .toc(&access, &pass) + .query_batch( + &collection.name, + vec![(request, shard_selection)], + params.consistency, + access, + params.timeout(), + hw_measurement_acc, + ) + .await? + .pop() + .ok_or_else(|| { + StorageError::service_error("Expected at least one response for one query") + })? 
+ .into_iter() + .map(api::rest::ScoredPoint::from) + .collect_vec(); + + Ok(QueryResponse { points }) + } + .await; + + helpers::process_response(result, timing, request_hw_counter.to_rest_api()) +} + +#[post("/collections/{name}/points/query/batch")] +async fn query_points_batch( + dispatcher: web::Data, + collection: Path, + request: Json, + params: Query, + service_config: web::Data, + ActixAccess(access): ActixAccess, +) -> impl Responder { + let QueryRequestBatch { searches } = request.into_inner(); + + let pass = match check_strict_mode_batch( + searches.iter().map(|i| &i.internal), + params.timeout_as_secs(), + &collection.name, + &dispatcher, + &access, + ) + .await + { + Ok(pass) => pass, + Err(err) => return process_response_error(err, Instant::now(), None), + }; + + let request_hw_counter = get_request_hardware_counter( + &dispatcher, + collection.name.clone(), + service_config.hardware_reporting(), + ); + let timing = Instant::now(); + let hw_measurement_acc = request_hw_counter.get_counter(); + + let result = async move { + let mut batch = Vec::with_capacity(searches.len()); + for request in searches { + let QueryRequest { + internal, + shard_key, + } = request; + + let request = convert_query_request_from_rest(internal).await?; + let shard_selection = match shard_key { + None => ShardSelectorInternal::All, + Some(shard_keys) => shard_keys.into(), + }; + + batch.push((request, shard_selection)); + } + + let res = dispatcher + .toc(&access, &pass) + .query_batch( + &collection.name, + batch, + params.consistency, + access, + params.timeout(), + hw_measurement_acc, + ) + .await? + .into_iter() + .map(|response| QueryResponse { + points: response + .into_iter() + .map(api::rest::ScoredPoint::from) + .collect_vec(), + }) + .collect_vec(); + Ok(res) + } + .await; + + helpers::process_response(result, timing, request_hw_counter.to_rest_api()) +} + +#[post("/collections/{name}/points/query/groups")] +async fn query_points_groups( + dispatcher: web::Data, + collection: Path, + request: Json, + params: Query, + service_config: web::Data, + ActixAccess(access): ActixAccess, +) -> impl Responder { + let QueryGroupsRequest { + search_group_request, + shard_key, + } = request.into_inner(); + + let pass = match check_strict_mode( + &search_group_request, + params.timeout_as_secs(), + &collection.name, + &dispatcher, + &access, + ) + .await + { + Ok(pass) => pass, + Err(err) => return process_response_error(err, Instant::now(), None), + }; + + let request_hw_counter = get_request_hardware_counter( + &dispatcher, + collection.name.clone(), + service_config.hardware_reporting(), + ); + let timing = Instant::now(); + let hw_measurement_acc = request_hw_counter.get_counter(); + + let result = async move { + let shard_selection = match shard_key { + None => ShardSelectorInternal::All, + Some(shard_keys) => shard_keys.into(), + }; + + let query_group_request = + convert_query_groups_request_from_rest(search_group_request).await?; + + do_query_point_groups( + dispatcher.toc(&access, &pass), + &collection.name, + query_group_request, + params.consistency, + shard_selection, + access, + params.timeout(), + hw_measurement_acc, + ) + .await + } + .await; + + helpers::process_response(result, timing, request_hw_counter.to_rest_api()) +} + +pub fn config_query_api(cfg: &mut web::ServiceConfig) { + cfg.service(query_points); + cfg.service(query_points_batch); + cfg.service(query_points_groups); +} diff --git a/src/actix/api/read_params.rs b/src/actix/api/read_params.rs new file mode 100644 index 
0000000000000000000000000000000000000000..e2ab5ee34727cdbea4e8dadc5704369f6b1b8c67 --- /dev/null +++ b/src/actix/api/read_params.rs @@ -0,0 +1,118 @@ +use std::num::NonZeroU64; +use std::time::Duration; + +use collection::operations::consistency_params::ReadConsistency; +use schemars::JsonSchema; +use serde::Deserialize; +use validator::Validate; + +#[derive(Copy, Clone, Debug, Default, Eq, PartialEq, Deserialize, JsonSchema, Validate)] +pub struct ReadParams { + #[serde(default, deserialize_with = "deserialize_read_consistency")] + #[validate(nested)] + pub consistency: Option, + /// If set, overrides global timeout for this request. Unit is seconds. + pub timeout: Option, +} + +impl ReadParams { + pub fn timeout(&self) -> Option { + self.timeout.map(|num| Duration::from_secs(num.get())) + } + + pub(crate) fn timeout_as_secs(&self) -> Option { + self.timeout.map(|i| i.get() as usize) + } +} + +fn deserialize_read_consistency<'de, D>( + deserializer: D, +) -> Result, D::Error> +where + D: serde::Deserializer<'de>, +{ + #[derive(Deserialize)] + #[serde(untagged)] + enum Helper<'a> { + ReadConsistency(ReadConsistency), + Str(&'a str), + } + + match Helper::deserialize(deserializer)? { + Helper::ReadConsistency(read_consistency) => Ok(Some(read_consistency)), + Helper::Str("") => Ok(None), + _ => Err(serde::de::Error::custom( + "failed to deserialize read consistency query parameter value", + )), + } +} + +#[cfg(test)] +mod test { + use collection::operations::consistency_params::ReadConsistencyType; + + use super::*; + + #[test] + fn deserialize_empty_string() { + test_str("", ReadParams::default()); + } + + #[test] + fn deserialize_empty_value() { + test("", ReadParams::default()); + } + + #[test] + fn deserialize_type() { + test("all", from_type(ReadConsistencyType::All)); + test("majority", from_type(ReadConsistencyType::Majority)); + test("quorum", from_type(ReadConsistencyType::Quorum)); + } + + #[test] + fn deserialize_factor() { + for factor in 1..42 { + test(&factor.to_string(), from_factor(factor)); + } + } + + #[test] + fn try_deserialize_factor_0() { + assert!(try_deserialize(&str("0")).is_err()); + } + + fn test(value: &str, params: ReadParams) { + test_str(&str(value), params); + } + + fn test_str(str: &str, params: ReadParams) { + assert_eq!(deserialize(str), params); + } + + fn deserialize(str: &str) -> ReadParams { + try_deserialize(str).unwrap() + } + + fn try_deserialize(str: &str) -> Result { + serde_urlencoded::from_str(str) + } + + fn str(value: &str) -> String { + format!("consistency={value}") + } + + fn from_type(r#type: ReadConsistencyType) -> ReadParams { + ReadParams { + consistency: Some(ReadConsistency::Type(r#type)), + ..Default::default() + } + } + + fn from_factor(factor: usize) -> ReadParams { + ReadParams { + consistency: Some(ReadConsistency::Factor(factor)), + ..Default::default() + } + } +} diff --git a/src/actix/api/recommend_api.rs b/src/actix/api/recommend_api.rs new file mode 100644 index 0000000000000000000000000000000000000000..fc3164c32182f353de85e4169a1341416b607cac --- /dev/null +++ b/src/actix/api/recommend_api.rs @@ -0,0 +1,235 @@ +use std::time::Duration; + +use actix_web::{post, web, Responder}; +use actix_web_validator::{Json, Path, Query}; +use collection::operations::consistency_params::ReadConsistency; +use collection::operations::shard_selector_internal::ShardSelectorInternal; +use collection::operations::types::{ + RecommendGroupsRequest, RecommendRequest, RecommendRequestBatch, +}; +use 
common::counter::hardware_accumulator::HwMeasurementAcc; +use itertools::Itertools; +use segment::types::ScoredPoint; +use storage::content_manager::collection_verification::{ + check_strict_mode, check_strict_mode_batch, +}; +use storage::content_manager::errors::StorageError; +use storage::content_manager::toc::TableOfContent; +use storage::dispatcher::Dispatcher; +use storage::rbac::Access; +use tokio::time::Instant; + +use super::read_params::ReadParams; +use super::CollectionPath; +use crate::actix::auth::ActixAccess; +use crate::actix::helpers::{self, get_request_hardware_counter, process_response_error}; +use crate::settings::ServiceConfig; + +#[post("/collections/{name}/points/recommend")] +async fn recommend_points( + dispatcher: web::Data, + collection: Path, + request: Json, + params: Query, + service_config: web::Data, + ActixAccess(access): ActixAccess, +) -> impl Responder { + let RecommendRequest { + recommend_request, + shard_key, + } = request.into_inner(); + + let pass = match check_strict_mode( + &recommend_request, + params.timeout_as_secs(), + &collection.name, + &dispatcher, + &access, + ) + .await + { + Ok(pass) => pass, + Err(err) => return process_response_error(err, Instant::now(), None), + }; + + let shard_selection = match shard_key { + None => ShardSelectorInternal::All, + Some(shard_keys) => shard_keys.into(), + }; + + let request_hw_counter = get_request_hardware_counter( + &dispatcher, + collection.name.clone(), + service_config.hardware_reporting(), + ); + + let timing = Instant::now(); + + let result = dispatcher + .toc(&access, &pass) + .recommend( + &collection.name, + recommend_request, + params.consistency, + shard_selection, + access, + params.timeout(), + request_hw_counter.get_counter(), + ) + .await + .map(|scored_points| { + scored_points + .into_iter() + .map(api::rest::ScoredPoint::from) + .collect_vec() + }); + + helpers::process_response(result, timing, request_hw_counter.to_rest_api()) +} + +async fn do_recommend_batch_points( + toc: &TableOfContent, + collection_name: &str, + request: RecommendRequestBatch, + read_consistency: Option, + access: Access, + timeout: Option, + hw_measurement_acc: &HwMeasurementAcc, +) -> Result>, StorageError> { + let requests = request + .searches + .into_iter() + .map(|req| { + let shard_selector = match req.shard_key { + None => ShardSelectorInternal::All, + Some(shard_key) => ShardSelectorInternal::from(shard_key), + }; + + (req.recommend_request, shard_selector) + }) + .collect(); + + toc.recommend_batch( + collection_name, + requests, + read_consistency, + access, + timeout, + hw_measurement_acc, + ) + .await +} + +#[post("/collections/{name}/points/recommend/batch")] +async fn recommend_batch_points( + dispatcher: web::Data, + collection: Path, + request: Json, + params: Query, + service_config: web::Data, + ActixAccess(access): ActixAccess, +) -> impl Responder { + let pass = match check_strict_mode_batch( + request.searches.iter().map(|i| &i.recommend_request), + params.timeout_as_secs(), + &collection.name, + &dispatcher, + &access, + ) + .await + { + Ok(pass) => pass, + Err(err) => return process_response_error(err, Instant::now(), None), + }; + + let request_hw_counter = get_request_hardware_counter( + &dispatcher, + collection.name.clone(), + service_config.hardware_reporting(), + ); + let timing = Instant::now(); + + let result = do_recommend_batch_points( + dispatcher.toc(&access, &pass), + &collection.name, + request.into_inner(), + params.consistency, + access, + params.timeout(), + 
request_hw_counter.get_counter(), + ) + .await + .map(|batch_scored_points| { + batch_scored_points + .into_iter() + .map(|scored_points| { + scored_points + .into_iter() + .map(api::rest::ScoredPoint::from) + .collect_vec() + }) + .collect_vec() + }); + + helpers::process_response(result, timing, request_hw_counter.to_rest_api()) +} + +#[post("/collections/{name}/points/recommend/groups")] +async fn recommend_point_groups( + dispatcher: web::Data, + collection: Path, + request: Json, + params: Query, + service_config: web::Data, + ActixAccess(access): ActixAccess, +) -> impl Responder { + let RecommendGroupsRequest { + recommend_group_request, + shard_key, + } = request.into_inner(); + + let pass = match check_strict_mode( + &recommend_group_request, + params.timeout_as_secs(), + &collection.name, + &dispatcher, + &access, + ) + .await + { + Ok(pass) => pass, + Err(err) => return process_response_error(err, Instant::now(), None), + }; + + let shard_selection = match shard_key { + None => ShardSelectorInternal::All, + Some(shard_keys) => shard_keys.into(), + }; + + let request_hw_counter = get_request_hardware_counter( + &dispatcher, + collection.name.clone(), + service_config.hardware_reporting(), + ); + let timing = Instant::now(); + + let result = crate::common::points::do_recommend_point_groups( + dispatcher.toc(&access, &pass), + &collection.name, + recommend_group_request, + params.consistency, + shard_selection, + access, + params.timeout(), + request_hw_counter.get_counter(), + ) + .await; + + helpers::process_response(result, timing, request_hw_counter.to_rest_api()) +} +// Configure services +pub fn config_recommend_api(cfg: &mut web::ServiceConfig) { + cfg.service(recommend_points) + .service(recommend_batch_points) + .service(recommend_point_groups); +} diff --git a/src/actix/api/retrieve_api.rs b/src/actix/api/retrieve_api.rs new file mode 100644 index 0000000000000000000000000000000000000000..7f476219e5ad921d5108aaa4e5a121679fb1a418 --- /dev/null +++ b/src/actix/api/retrieve_api.rs @@ -0,0 +1,200 @@ +use std::time::Duration; + +use actix_web::{get, post, web, Responder}; +use actix_web_validator::{Json, Path, Query}; +use collection::operations::consistency_params::ReadConsistency; +use collection::operations::shard_selector_internal::ShardSelectorInternal; +use collection::operations::types::{ + PointRequest, PointRequestInternal, RecordInternal, ScrollRequest, +}; +use futures::TryFutureExt; +use itertools::Itertools; +use segment::types::{PointIdType, WithPayloadInterface}; +use serde::Deserialize; +use storage::content_manager::collection_verification::{ + check_strict_mode, check_strict_mode_timeout, +}; +use storage::content_manager::errors::StorageError; +use storage::content_manager::toc::TableOfContent; +use storage::dispatcher::Dispatcher; +use storage::rbac::Access; +use tokio::time::Instant; +use validator::Validate; + +use super::read_params::ReadParams; +use super::CollectionPath; +use crate::actix::auth::ActixAccess; +use crate::actix::helpers::{self, process_response_error}; +use crate::common::points::do_get_points; + +#[derive(Deserialize, Validate)] +struct PointPath { + #[validate(length(min = 1))] + // TODO: validate this is a valid ID type (usize or UUID)? Does currently error on deserialize. 
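// As the TODO above notes, the path segment stays a plain `String` here and is parsed
// later via `point.id.parse::<PointIdType>()` in `get_point`. `PointIdType` accepts either
// an unsigned integer or a UUID, so both of these should resolve (illustrative values):
//
//     let by_number: PointIdType = "42".parse().unwrap();
//     let by_uuid: PointIdType = "550e8400-e29b-41d4-a716-446655440000".parse().unwrap();
//
// Anything else falls through to the `StorageError::BadInput` branch below.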
+ id: String, +} + +async fn do_get_point( + toc: &TableOfContent, + collection_name: &str, + point_id: PointIdType, + read_consistency: Option, + timeout: Option, + access: Access, +) -> Result, StorageError> { + let request = PointRequestInternal { + ids: vec![point_id], + with_payload: Some(WithPayloadInterface::Bool(true)), + with_vector: true.into(), + }; + + let shard_selection = ShardSelectorInternal::All; + + toc.retrieve( + collection_name, + request, + read_consistency, + timeout, + shard_selection, + access, + ) + .await + .map(|points| points.into_iter().next()) +} + +#[get("/collections/{name}/points/{id}")] +async fn get_point( + dispatcher: web::Data, + collection: Path, + point: Path, + params: Query, + ActixAccess(access): ActixAccess, +) -> impl Responder { + let pass = match check_strict_mode_timeout( + params.timeout_as_secs(), + &collection.name, + &dispatcher, + &access, + ) + .await + { + Ok(p) => p, + Err(err) => return process_response_error(err, Instant::now(), None), + }; + + helpers::time(async move { + let point_id: PointIdType = point.id.parse().map_err(|_| StorageError::BadInput { + description: format!("Can not recognize \"{}\" as point id", point.id), + })?; + + let Some(record) = do_get_point( + dispatcher.toc(&access, &pass), + &collection.name, + point_id, + params.consistency, + params.timeout(), + access, + ) + .await? + else { + return Err(StorageError::NotFound { + description: format!("Point with id {point_id} does not exists!"), + }); + }; + + Ok(api::rest::Record::from(record)) + }) + .await +} + +#[post("/collections/{name}/points")] +async fn get_points( + dispatcher: web::Data, + collection: Path, + request: Json, + params: Query, + ActixAccess(access): ActixAccess, +) -> impl Responder { + let pass = match check_strict_mode_timeout( + params.timeout_as_secs(), + &collection.name, + &dispatcher, + &access, + ) + .await + { + Ok(p) => p, + Err(err) => return process_response_error(err, Instant::now(), None), + }; + + let PointRequest { + point_request, + shard_key, + } = request.into_inner(); + + let shard_selection = match shard_key { + None => ShardSelectorInternal::All, + Some(shard_keys) => ShardSelectorInternal::from(shard_keys), + }; + + helpers::time( + do_get_points( + dispatcher.toc(&access, &pass), + &collection.name, + point_request, + params.consistency, + params.timeout(), + shard_selection, + access, + ) + .map_ok(|response| { + response + .into_iter() + .map(api::rest::Record::from) + .collect_vec() + }), + ) + .await +} + +#[post("/collections/{name}/points/scroll")] +async fn scroll_points( + dispatcher: web::Data, + collection: Path, + request: Json, + params: Query, + ActixAccess(access): ActixAccess, +) -> impl Responder { + let ScrollRequest { + scroll_request, + shard_key, + } = request.into_inner(); + + let pass = match check_strict_mode( + &scroll_request, + params.timeout_as_secs(), + &collection.name, + &dispatcher, + &access, + ) + .await + { + Ok(pass) => pass, + Err(err) => return process_response_error(err, Instant::now(), None), + }; + + let shard_selection = match shard_key { + None => ShardSelectorInternal::All, + Some(shard_keys) => ShardSelectorInternal::from(shard_keys), + }; + + helpers::time(dispatcher.toc(&access, &pass).scroll( + &collection.name, + scroll_request, + params.consistency, + params.timeout(), + shard_selection, + access, + )) + .await +} diff --git a/src/actix/api/search_api.rs b/src/actix/api/search_api.rs new file mode 100644 index 
0000000000000000000000000000000000000000..9c8ea43e6a449df324372581bf4201b1c2680d14 --- /dev/null +++ b/src/actix/api/search_api.rs @@ -0,0 +1,333 @@ +use actix_web::{post, web, HttpResponse, Responder}; +use actix_web_validator::{Json, Path, Query}; +use api::rest::{SearchMatrixOffsetsResponse, SearchMatrixPairsResponse, SearchMatrixRequest}; +use collection::collection::distance_matrix::CollectionSearchMatrixRequest; +use collection::operations::shard_selector_internal::ShardSelectorInternal; +use collection::operations::types::{ + CoreSearchRequest, SearchGroupsRequest, SearchRequest, SearchRequestBatch, +}; +use itertools::Itertools; +use storage::content_manager::collection_verification::{ + check_strict_mode, check_strict_mode_batch, +}; +use storage::dispatcher::Dispatcher; +use tokio::time::Instant; + +use super::read_params::ReadParams; +use super::CollectionPath; +use crate::actix::auth::ActixAccess; +use crate::actix::helpers::{ + get_request_hardware_counter, process_response, process_response_error, +}; +use crate::common::points::{ + do_core_search_points, do_search_batch_points, do_search_point_groups, do_search_points_matrix, +}; +use crate::settings::ServiceConfig; + +#[post("/collections/{name}/points/search")] +async fn search_points( + dispatcher: web::Data, + collection: Path, + request: Json, + params: Query, + service_config: web::Data, + ActixAccess(access): ActixAccess, +) -> HttpResponse { + let SearchRequest { + search_request, + shard_key, + } = request.into_inner(); + + let pass = match check_strict_mode( + &search_request, + params.timeout_as_secs(), + &collection.name, + &dispatcher, + &access, + ) + .await + { + Ok(pass) => pass, + Err(err) => return process_response_error(err, Instant::now(), None), + }; + + let shard_selection = match shard_key { + None => ShardSelectorInternal::All, + Some(shard_keys) => shard_keys.into(), + }; + + let request_hw_counter = get_request_hardware_counter( + &dispatcher, + collection.name.clone(), + service_config.hardware_reporting(), + ); + + let timing = Instant::now(); + + let result = do_core_search_points( + dispatcher.toc(&access, &pass), + &collection.name, + search_request.into(), + params.consistency, + shard_selection, + access, + params.timeout(), + request_hw_counter.get_counter(), + ) + .await + .map(|scored_points| { + scored_points + .into_iter() + .map(api::rest::ScoredPoint::from) + .collect_vec() + }); + + process_response(result, timing, request_hw_counter.to_rest_api()) +} + +#[post("/collections/{name}/points/search/batch")] +async fn batch_search_points( + dispatcher: web::Data, + collection: Path, + request: Json, + params: Query, + service_config: web::Data, + ActixAccess(access): ActixAccess, +) -> HttpResponse { + let requests = request + .into_inner() + .searches + .into_iter() + .map(|req| { + let SearchRequest { + search_request, + shard_key, + } = req; + let shard_selection = match shard_key { + None => ShardSelectorInternal::All, + Some(shard_keys) => shard_keys.into(), + }; + let core_request: CoreSearchRequest = search_request.into(); + + (core_request, shard_selection) + }) + .collect::>(); + + let pass = match check_strict_mode_batch( + requests.iter().map(|i| &i.0), + params.timeout_as_secs(), + &collection.name, + &dispatcher, + &access, + ) + .await + { + Ok(pass) => pass, + Err(err) => return process_response_error(err, Instant::now(), None), + }; + + let request_hw_counter = get_request_hardware_counter( + &dispatcher, + collection.name.clone(), + service_config.hardware_reporting(), + 
); + + let timing = Instant::now(); + + let result = do_search_batch_points( + dispatcher.toc(&access, &pass), + &collection.name, + requests, + params.consistency, + access, + params.timeout(), + request_hw_counter.get_counter(), + ) + .await + .map(|batch_scored_points| { + batch_scored_points + .into_iter() + .map(|scored_points| { + scored_points + .into_iter() + .map(api::rest::ScoredPoint::from) + .collect_vec() + }) + .collect_vec() + }); + + process_response(result, timing, request_hw_counter.to_rest_api()) +} + +#[post("/collections/{name}/points/search/groups")] +async fn search_point_groups( + dispatcher: web::Data, + collection: Path, + request: Json, + params: Query, + service_config: web::Data, + ActixAccess(access): ActixAccess, +) -> HttpResponse { + let SearchGroupsRequest { + search_group_request, + shard_key, + } = request.into_inner(); + + let pass = match check_strict_mode( + &search_group_request, + params.timeout_as_secs(), + &collection.name, + &dispatcher, + &access, + ) + .await + { + Ok(pass) => pass, + Err(err) => return process_response_error(err, Instant::now(), None), + }; + + let shard_selection = match shard_key { + None => ShardSelectorInternal::All, + Some(shard_keys) => shard_keys.into(), + }; + + let request_hw_counter = get_request_hardware_counter( + &dispatcher, + collection.name.clone(), + service_config.hardware_reporting(), + ); + let timing = Instant::now(); + + let result = do_search_point_groups( + dispatcher.toc(&access, &pass), + &collection.name, + search_group_request, + params.consistency, + shard_selection, + access, + params.timeout(), + request_hw_counter.get_counter(), + ) + .await; + + process_response(result, timing, request_hw_counter.to_rest_api()) +} + +#[post("/collections/{name}/points/search/matrix/pairs")] +async fn search_points_matrix_pairs( + dispatcher: web::Data, + collection: Path, + request: Json, + params: Query, + service_config: web::Data, + ActixAccess(access): ActixAccess, +) -> impl Responder { + let SearchMatrixRequest { + search_request, + shard_key, + } = request.into_inner(); + + let pass = match check_strict_mode( + &search_request, + params.timeout_as_secs(), + &collection.name, + &dispatcher, + &access, + ) + .await + { + Ok(pass) => pass, + Err(err) => return process_response_error(err, Instant::now(), None), + }; + + let shard_selection = match shard_key { + None => ShardSelectorInternal::All, + Some(shard_keys) => shard_keys.into(), + }; + + let request_hw_counter = get_request_hardware_counter( + &dispatcher, + collection.name.clone(), + service_config.hardware_reporting(), + ); + let timing = Instant::now(); + + let response = do_search_points_matrix( + dispatcher.toc(&access, &pass), + &collection.name, + CollectionSearchMatrixRequest::from(search_request), + params.consistency, + shard_selection, + access, + params.timeout(), + request_hw_counter.get_counter(), + ) + .await + .map(SearchMatrixPairsResponse::from); + + process_response(response, timing, request_hw_counter.to_rest_api()) +} + +#[post("/collections/{name}/points/search/matrix/offsets")] +async fn search_points_matrix_offsets( + dispatcher: web::Data, + collection: Path, + request: Json, + params: Query, + service_config: web::Data, + ActixAccess(access): ActixAccess, +) -> impl Responder { + let SearchMatrixRequest { + search_request, + shard_key, + } = request.into_inner(); + + let pass = match check_strict_mode( + &search_request, + params.timeout_as_secs(), + &collection.name, + &dispatcher, + &access, + ) + .await + { + Ok(pass) => 
pass, + Err(err) => return process_response_error(err, Instant::now(), None), + }; + + let shard_selection = match shard_key { + None => ShardSelectorInternal::All, + Some(shard_keys) => shard_keys.into(), + }; + + let request_hw_counter = get_request_hardware_counter( + &dispatcher, + collection.name.clone(), + service_config.hardware_reporting(), + ); + let timing = Instant::now(); + + let response = do_search_points_matrix( + dispatcher.toc(&access, &pass), + &collection.name, + CollectionSearchMatrixRequest::from(search_request), + params.consistency, + shard_selection, + access, + params.timeout(), + request_hw_counter.get_counter(), + ) + .await + .map(SearchMatrixOffsetsResponse::from); + + process_response(response, timing, request_hw_counter.to_rest_api()) +} + +// Configure services +pub fn config_search_api(cfg: &mut web::ServiceConfig) { + cfg.service(search_points) + .service(batch_search_points) + .service(search_point_groups) + .service(search_points_matrix_pairs) + .service(search_points_matrix_offsets); +} diff --git a/src/actix/api/service_api.rs b/src/actix/api/service_api.rs new file mode 100644 index 0000000000000000000000000000000000000000..64920d1ab456d4211badd0e139b74994c76f6bdf --- /dev/null +++ b/src/actix/api/service_api.rs @@ -0,0 +1,217 @@ +use std::future::Future; +use std::sync::Arc; + +use actix_web::http::header::ContentType; +use actix_web::http::StatusCode; +use actix_web::rt::time::Instant; +use actix_web::web::Query; +use actix_web::{get, post, web, HttpResponse, Responder}; +use actix_web_validator::Json; +use collection::operations::verification::new_unchecked_verification_pass; +use common::types::{DetailsLevel, TelemetryDetail}; +use schemars::JsonSchema; +use segment::common::anonymize::Anonymize; +use serde::{Deserialize, Serialize}; +use storage::content_manager::errors::StorageError; +use storage::dispatcher::Dispatcher; +use storage::rbac::AccessRequirements; +use tokio::sync::Mutex; + +use crate::actix::auth::ActixAccess; +use crate::actix::helpers::{self, process_response_error}; +use crate::common::health; +use crate::common::helpers::LocksOption; +use crate::common::metrics::MetricsData; +use crate::common::stacktrace::get_stack_trace; +use crate::common::telemetry::TelemetryCollector; +use crate::tracing; + +#[derive(Deserialize, Serialize, JsonSchema)] +pub struct TelemetryParam { + pub anonymize: Option, + pub details_level: Option, +} + +#[get("/telemetry")] +fn telemetry( + telemetry_collector: web::Data>, + params: Query, + ActixAccess(access): ActixAccess, +) -> impl Future { + helpers::time(async move { + access.check_global_access(AccessRequirements::new())?; + let anonymize = params.anonymize.unwrap_or(false); + let details_level = params + .details_level + .map_or(DetailsLevel::Level0, Into::into); + let detail = TelemetryDetail { + level: details_level, + histograms: false, + }; + let telemetry_collector = telemetry_collector.lock().await; + let telemetry_data = telemetry_collector.prepare_data(&access, detail).await; + let telemetry_data = if anonymize { + telemetry_data.anonymize() + } else { + telemetry_data + }; + Ok(telemetry_data) + }) +} + +#[derive(Deserialize, Serialize, JsonSchema)] +pub struct MetricsParam { + pub anonymize: Option, +} + +#[get("/metrics")] +async fn metrics( + telemetry_collector: web::Data>, + params: Query, + ActixAccess(access): ActixAccess, +) -> HttpResponse { + if let Err(err) = access.check_global_access(AccessRequirements::new()) { + return process_response_error(err, Instant::now(), None); + 
} + + let anonymize = params.anonymize.unwrap_or(false); + let telemetry_collector = telemetry_collector.lock().await; + let telemetry_data = telemetry_collector + .prepare_data( + &access, + TelemetryDetail { + level: DetailsLevel::Level1, + histograms: true, + }, + ) + .await; + let telemetry_data = if anonymize { + telemetry_data.anonymize() + } else { + telemetry_data + }; + + HttpResponse::Ok() + .content_type(ContentType::plaintext()) + .body(MetricsData::from(telemetry_data).format_metrics()) +} + +#[post("/locks")] +fn put_locks( + dispatcher: web::Data, + locks_option: Json, + ActixAccess(access): ActixAccess, +) -> impl Future { + // Not a collection level request. + let pass = new_unchecked_verification_pass(); + + helpers::time(async move { + let toc = dispatcher.toc(&access, &pass); + access.check_global_access(AccessRequirements::new().manage())?; + let result = LocksOption { + write: toc.is_write_locked(), + error_message: toc.get_lock_error_message(), + }; + toc.set_locks(locks_option.write, locks_option.error_message.clone()); + Ok(result) + }) +} + +#[get("/locks")] +fn get_locks( + dispatcher: web::Data, + ActixAccess(access): ActixAccess, +) -> impl Future { + // Not a collection level request. + let pass = new_unchecked_verification_pass(); + + helpers::time(async move { + access.check_global_access(AccessRequirements::new())?; + let toc = dispatcher.toc(&access, &pass); + let result = LocksOption { + write: toc.is_write_locked(), + error_message: toc.get_lock_error_message(), + }; + Ok(result) + }) +} + +#[get("/stacktrace")] +fn get_stacktrace(ActixAccess(access): ActixAccess) -> impl Future { + helpers::time(async move { + access.check_global_access(AccessRequirements::new().manage())?; + Ok(get_stack_trace()) + }) +} + +#[get("/healthz")] +async fn healthz() -> impl Responder { + kubernetes_healthz() +} + +#[get("/livez")] +async fn livez() -> impl Responder { + kubernetes_healthz() +} + +#[get("/readyz")] +async fn readyz(health_checker: web::Data>>) -> impl Responder { + let is_ready = match health_checker.as_ref() { + Some(health_checker) => health_checker.check_ready().await, + None => true, + }; + + let (status, body) = if is_ready { + (StatusCode::OK, "all shards are ready") + } else { + (StatusCode::SERVICE_UNAVAILABLE, "some shards are not ready") + }; + + HttpResponse::build(status) + .content_type(ContentType::plaintext()) + .body(body) +} + +/// Basic Kubernetes healthz endpoint +fn kubernetes_healthz() -> impl Responder { + HttpResponse::Ok() + .content_type(ContentType::plaintext()) + .body("healthz check passed") +} + +#[get("/logger")] +async fn get_logger_config(handle: web::Data) -> impl Responder { + let timing = Instant::now(); + let result = handle.get_config().await; + helpers::process_response(Ok(result), timing, None) +} + +#[post("/logger")] +async fn update_logger_config( + handle: web::Data, + config: web::Json, +) -> impl Responder { + let timing = Instant::now(); + + let result = handle + .update_config(config.into_inner()) + .await + .map(|_| true) + .map_err(|err| StorageError::service_error(err.to_string())); + + helpers::process_response(result, timing, None) +} + +// Configure services +pub fn config_service_api(cfg: &mut web::ServiceConfig) { + cfg.service(telemetry) + .service(metrics) + .service(put_locks) + .service(get_locks) + .service(get_stacktrace) + .service(healthz) + .service(livez) + .service(readyz) + .service(get_logger_config) + .service(update_logger_config); +} diff --git a/src/actix/api/shards_api.rs 
b/src/actix/api/shards_api.rs new file mode 100644 index 0000000000000000000000000000000000000000..32bb9e9057f691ae9cd83c72da0fd3daab7d3642 --- /dev/null +++ b/src/actix/api/shards_api.rs @@ -0,0 +1,80 @@ +use actix_web::{post, put, web, Responder}; +use actix_web_validator::{Json, Path, Query}; +use collection::operations::cluster_ops::{ + ClusterOperations, CreateShardingKey, CreateShardingKeyOperation, DropShardingKey, + DropShardingKeyOperation, +}; +use storage::dispatcher::Dispatcher; +use tokio::time::Instant; + +use crate::actix::api::collections_api::WaitTimeout; +use crate::actix::api::CollectionPath; +use crate::actix::auth::ActixAccess; +use crate::actix::helpers::process_response; +use crate::common::collections::do_update_collection_cluster; + +// ToDo: introduce API for listing shard keys + +#[put("/collections/{name}/shards")] +async fn create_shard_key( + dispatcher: web::Data, + collection: Path, + request: Json, + Query(query): Query, + ActixAccess(access): ActixAccess, +) -> impl Responder { + let timing = Instant::now(); + let wait_timeout = query.timeout(); + let dispatcher = dispatcher.into_inner(); + + let request = request.into_inner(); + + let operation = ClusterOperations::CreateShardingKey(CreateShardingKeyOperation { + create_sharding_key: request, + }); + + let response = do_update_collection_cluster( + &dispatcher, + collection.name.clone(), + operation, + access, + wait_timeout, + ) + .await; + + process_response(response, timing, None) +} + +#[post("/collections/{name}/shards/delete")] +async fn delete_shard_key( + dispatcher: web::Data, + collection: Path, + request: Json, + Query(query): Query, + ActixAccess(access): ActixAccess, +) -> impl Responder { + let timing = Instant::now(); + let wait_timeout = query.timeout(); + + let dispatcher = dispatcher.into_inner(); + let request = request.into_inner(); + + let operation = ClusterOperations::DropShardingKey(DropShardingKeyOperation { + drop_sharding_key: request, + }); + + let response = do_update_collection_cluster( + &dispatcher, + collection.name.clone(), + operation, + access, + wait_timeout, + ) + .await; + + process_response(response, timing, None) +} + +pub fn config_shards_api(cfg: &mut web::ServiceConfig) { + cfg.service(create_shard_key).service(delete_shard_key); +} diff --git a/src/actix/api/snapshot_api.rs b/src/actix/api/snapshot_api.rs new file mode 100644 index 0000000000000000000000000000000000000000..3048c78e1714ba2e14cdd92fbb83b17db4691094 --- /dev/null +++ b/src/actix/api/snapshot_api.rs @@ -0,0 +1,585 @@ +use std::path::Path; + +use actix_multipart::form::tempfile::TempFile; +use actix_multipart::form::MultipartForm; +use actix_web::{delete, get, post, put, web, Responder, Result}; +use actix_web_validator as valid; +use collection::common::file_utils::move_file; +use collection::common::sha_256::{hash_file, hashes_equal}; +use collection::common::snapshot_stream::SnapshotStream; +use collection::operations::snapshot_ops::{ + ShardSnapshotRecover, SnapshotPriority, SnapshotRecover, +}; +use collection::operations::verification::new_unchecked_verification_pass; +use collection::shards::shard::ShardId; +use futures::{FutureExt as _, TryFutureExt as _}; +use reqwest::Url; +use schemars::JsonSchema; +use serde::{Deserialize, Serialize}; +use storage::content_manager::errors::StorageError; +use storage::content_manager::snapshots::recover::do_recover_from_snapshot; +use storage::content_manager::snapshots::{ + do_create_full_snapshot, do_delete_collection_snapshot, do_delete_full_snapshot, 
+ do_list_full_snapshots, +}; +use storage::content_manager::toc::TableOfContent; +use storage::dispatcher::Dispatcher; +use storage::rbac::{Access, AccessRequirements}; +use uuid::Uuid; +use validator::Validate; + +use super::{CollectionPath, StrictCollectionPath}; +use crate::actix::auth::ActixAccess; +use crate::actix::helpers::{self, HttpError}; +use crate::common; +use crate::common::collections::*; +use crate::common::http_client::HttpClient; + +#[derive(Deserialize, Serialize, JsonSchema, Validate)] +pub struct SnapshotUploadingParam { + pub wait: Option, + pub priority: Option, + + /// Optional SHA256 checksum to verify snapshot integrity before recovery. + #[serde(default)] + #[validate(custom(function = "::common::validation::validate_sha256_hash"))] + pub checksum: Option, +} + +#[derive(Deserialize, Serialize, JsonSchema, Validate)] +pub struct SnapshottingParam { + pub wait: Option, +} + +#[derive(MultipartForm)] +pub struct SnapshottingForm { + snapshot: TempFile, +} + +// Actix specific code +pub async fn do_get_full_snapshot( + toc: &TableOfContent, + access: Access, + snapshot_name: &str, +) -> Result { + access.check_global_access(AccessRequirements::new())?; + let snapshots_storage_manager = toc.get_snapshots_storage_manager()?; + let snapshot_path = + snapshots_storage_manager.get_full_snapshot_path(toc.snapshots_path(), snapshot_name)?; + let snapshot_stream = snapshots_storage_manager + .get_snapshot_stream(&snapshot_path) + .await?; + Ok(snapshot_stream) +} + +pub async fn do_save_uploaded_snapshot( + toc: &TableOfContent, + collection_name: &str, + snapshot: TempFile, +) -> Result { + let filename = snapshot + .file_name + // Sanitize the file name: + // - only take the top level path (no directories such as ../) + // - require the file name to be valid UTF-8 + .and_then(|x| { + Path::new(&x) + .file_name() + .map(|filename| filename.to_owned()) + }) + .and_then(|x| x.to_str().map(|x| x.to_owned())) + .unwrap_or_else(|| Uuid::new_v4().to_string()); + let collection_snapshot_path = toc.snapshots_path_for_collection(collection_name); + if !collection_snapshot_path.exists() { + log::debug!( + "Creating missing collection snapshots directory for {}", + collection_name + ); + toc.create_snapshots_path(collection_name).await?; + } + + let path = collection_snapshot_path.join(filename); + + move_file(snapshot.file.path(), &path).await?; + + let absolute_path = path.canonicalize()?; + + let snapshot_location = Url::from_file_path(&absolute_path).map_err(|_| { + StorageError::service_error(format!( + "Failed to convert path to URL: {}", + absolute_path.display() + )) + })?; + + Ok(snapshot_location) +} + +// Actix specific code +pub async fn do_get_snapshot( + toc: &TableOfContent, + access: Access, + collection_name: &str, + snapshot_name: &str, +) -> Result { + let collection_pass = + access.check_collection_access(collection_name, AccessRequirements::new().whole())?; + let collection: tokio::sync::RwLockReadGuard = + toc.get_collection(&collection_pass).await?; + let snapshot_storage_manager = collection.get_snapshots_storage_manager()?; + let snapshot_path = + snapshot_storage_manager.get_snapshot_path(collection.snapshots_path(), snapshot_name)?; + let snapshot_stream = snapshot_storage_manager + .get_snapshot_stream(&snapshot_path) + .await?; + Ok(snapshot_stream) +} + +#[get("/collections/{name}/snapshots")] +async fn list_snapshots( + dispatcher: web::Data, + path: web::Path, + ActixAccess(access): ActixAccess, +) -> impl Responder { + // Nothing to verify. 
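// (Illustrative note, not part of the patch.) The snapshot endpoints carry no search/update
// payload that strict mode could inspect, so presumably an unchecked verification pass is
// sufficient here to reach the table of contents via `dispatcher.toc(&access, &pass)`;
// access control itself is still enforced through the `Access` checks inside the `do_*` helpers.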
+ let pass = new_unchecked_verification_pass(); + + helpers::time(do_list_snapshots( + dispatcher.toc(&access, &pass), + access, + &path, + )) + .await +} + +#[post("/collections/{name}/snapshots")] +async fn create_snapshot( + dispatcher: web::Data, + path: web::Path, + params: valid::Query, + ActixAccess(access): ActixAccess, +) -> impl Responder { + // Nothing to verify. + let pass = new_unchecked_verification_pass(); + + let collection_name = path.into_inner(); + + let future = async move { + do_create_snapshot( + dispatcher.toc(&access, &pass).clone(), + access, + &collection_name, + ) + .await + }; + + helpers::time_or_accept(future, params.wait.unwrap_or(true)).await +} + +#[post("/collections/{name}/snapshots/upload")] +async fn upload_snapshot( + dispatcher: web::Data, + http_client: web::Data, + collection: valid::Path, + MultipartForm(form): MultipartForm, + params: valid::Query, + ActixAccess(access): ActixAccess, +) -> impl Responder { + let wait = params.wait; + + // Nothing to verify. + let pass = new_unchecked_verification_pass(); + + let future = async move { + let snapshot = form.snapshot; + + access.check_global_access(AccessRequirements::new().manage())?; + + if let Some(checksum) = ¶ms.checksum { + let snapshot_checksum = hash_file(snapshot.file.path()).await?; + if !hashes_equal(snapshot_checksum.as_str(), checksum.as_str()) { + return Err(StorageError::checksum_mismatch(snapshot_checksum, checksum)); + } + } + + let snapshot_location = + do_save_uploaded_snapshot(dispatcher.toc(&access, &pass), &collection.name, snapshot) + .await?; + + // Snapshot is a local file, we do not need an API key for that + let http_client = http_client.client(None)?; + + let snapshot_recover = SnapshotRecover { + location: snapshot_location, + priority: params.priority, + checksum: None, + api_key: None, + }; + + do_recover_from_snapshot( + dispatcher.get_ref(), + &collection.name, + snapshot_recover, + access, + http_client, + ) + .await + }; + + helpers::time_or_accept(future, wait.unwrap_or(true)).await +} + +#[put("/collections/{name}/snapshots/recover")] +async fn recover_from_snapshot( + dispatcher: web::Data, + http_client: web::Data, + collection: valid::Path, + request: valid::Json, + params: valid::Query, + ActixAccess(access): ActixAccess, +) -> impl Responder { + let future = async move { + let snapshot_recover = request.into_inner(); + let http_client = http_client.client(snapshot_recover.api_key.as_deref())?; + + do_recover_from_snapshot( + dispatcher.get_ref(), + &collection.name, + snapshot_recover, + access, + http_client, + ) + .await + }; + + helpers::time_or_accept(future, params.wait.unwrap_or(true)).await +} + +#[get("/collections/{name}/snapshots/{snapshot_name}")] +async fn get_snapshot( + dispatcher: web::Data, + path: web::Path<(String, String)>, + ActixAccess(access): ActixAccess, +) -> impl Responder { + // Nothing to verify. + let pass = new_unchecked_verification_pass(); + + let (collection_name, snapshot_name) = path.into_inner(); + do_get_snapshot( + dispatcher.toc(&access, &pass), + access, + &collection_name, + &snapshot_name, + ) + .await +} + +#[get("/snapshots")] +async fn list_full_snapshots( + dispatcher: web::Data, + ActixAccess(access): ActixAccess, +) -> impl Responder { + // nothing to verify. 
+ let pass = new_unchecked_verification_pass(); + + helpers::time(do_list_full_snapshots( + dispatcher.toc(&access, &pass), + access, + )) + .await +} + +#[post("/snapshots")] +async fn create_full_snapshot( + dispatcher: web::Data, + params: valid::Query, + ActixAccess(access): ActixAccess, +) -> impl Responder { + let future = async move { do_create_full_snapshot(dispatcher.get_ref(), access).await }; + helpers::time_or_accept(future, params.wait.unwrap_or(true)).await +} + +#[get("/snapshots/{snapshot_name}")] +async fn get_full_snapshot( + dispatcher: web::Data, + path: web::Path, + ActixAccess(access): ActixAccess, +) -> impl Responder { + // nothing to verify. + let pass = new_unchecked_verification_pass(); + + let snapshot_name = path.into_inner(); + do_get_full_snapshot(dispatcher.toc(&access, &pass), access, &snapshot_name).await +} + +#[delete("/snapshots/{snapshot_name}")] +async fn delete_full_snapshot( + dispatcher: web::Data, + path: web::Path, + params: valid::Query, + ActixAccess(access): ActixAccess, +) -> impl Responder { + let future = async move { + let snapshot_name = path.into_inner(); + do_delete_full_snapshot(dispatcher.get_ref(), access, &snapshot_name).await + }; + + helpers::time_or_accept(future, params.wait.unwrap_or(true)).await +} + +#[delete("/collections/{name}/snapshots/{snapshot_name}")] +async fn delete_collection_snapshot( + dispatcher: web::Data, + path: web::Path<(String, String)>, + params: valid::Query, + ActixAccess(access): ActixAccess, +) -> impl Responder { + let future = async move { + let (collection_name, snapshot_name) = path.into_inner(); + + do_delete_collection_snapshot( + dispatcher.get_ref(), + access, + &collection_name, + &snapshot_name, + ) + .await + }; + + helpers::time_or_accept(future, params.wait.unwrap_or(true)).await +} + +#[get("/collections/{collection}/shards/{shard}/snapshots")] +async fn list_shard_snapshots( + dispatcher: web::Data, + path: web::Path<(String, ShardId)>, + ActixAccess(access): ActixAccess, +) -> impl Responder { + // nothing to verify. + let pass = new_unchecked_verification_pass(); + + let (collection, shard) = path.into_inner(); + + let future = common::snapshots::list_shard_snapshots( + dispatcher.toc(&access, &pass).clone(), + access, + collection, + shard, + ) + .map_err(Into::into); + + helpers::time(future).await +} + +#[post("/collections/{collection}/shards/{shard}/snapshots")] +async fn create_shard_snapshot( + dispatcher: web::Data, + path: web::Path<(String, ShardId)>, + query: web::Query, + ActixAccess(access): ActixAccess, +) -> impl Responder { + // nothing to verify. + let pass = new_unchecked_verification_pass(); + + let (collection, shard) = path.into_inner(); + let future = common::snapshots::create_shard_snapshot( + dispatcher.toc(&access, &pass).clone(), + access, + collection, + shard, + ); + + helpers::time_or_accept(future, query.wait.unwrap_or(true)).await +} + +#[get("/collections/{collection}/shards/{shard}/snapshot")] +async fn stream_shard_snapshot( + dispatcher: web::Data, + path: web::Path<(String, ShardId)>, + ActixAccess(access): ActixAccess, +) -> Result { + // nothing to verify. + let pass = new_unchecked_verification_pass(); + + let (collection, shard) = path.into_inner(); + Ok(common::snapshots::stream_shard_snapshot( + dispatcher.toc(&access, &pass).clone(), + access, + collection, + shard, + ) + .await?) +} + +// TODO: `PUT` (same as `recover_from_snapshot`) or `POST`!? 
+#[put("/collections/{collection}/shards/{shard}/snapshots/recover")] +async fn recover_shard_snapshot( + dispatcher: web::Data, + http_client: web::Data, + path: web::Path<(String, ShardId)>, + query: web::Query, + web::Json(request): web::Json, + ActixAccess(access): ActixAccess, +) -> impl Responder { + // nothing to verify. + let pass = new_unchecked_verification_pass(); + + let future = async move { + let (collection, shard) = path.into_inner(); + + common::snapshots::recover_shard_snapshot( + dispatcher.toc(&access, &pass).clone(), + access, + collection, + shard, + request.location, + request.priority.unwrap_or_default(), + request.checksum, + http_client.as_ref().clone(), + request.api_key, + ) + .await?; + + Ok(true) + }; + + helpers::time_or_accept(future, query.wait.unwrap_or(true)).await +} + +// TODO: `POST` (same as `upload_snapshot`) or `PUT`!? +#[post("/collections/{collection}/shards/{shard}/snapshots/upload")] +async fn upload_shard_snapshot( + dispatcher: web::Data, + path: web::Path<(String, ShardId)>, + query: web::Query, + MultipartForm(form): MultipartForm, + ActixAccess(access): ActixAccess, +) -> impl Responder { + // nothing to verify. + let pass = new_unchecked_verification_pass(); + + let (collection, shard) = path.into_inner(); + let SnapshotUploadingParam { + wait, + priority, + checksum, + } = query.into_inner(); + + // - `recover_shard_snapshot_impl` is *not* cancel safe + // - but the task is *spawned* on the runtime and won't be cancelled, if request is cancelled + + let future = cancel::future::spawn_cancel_on_drop(move |cancel| async move { + // TODO: Run this check before the multipart blob is uploaded + let collection_pass = access + .check_global_access(AccessRequirements::new().manage())? + .issue_pass(&collection); + + if let Some(checksum) = checksum { + let snapshot_checksum = hash_file(form.snapshot.file.path()).await?; + if !hashes_equal(snapshot_checksum.as_str(), checksum.as_str()) { + return Err(StorageError::checksum_mismatch(snapshot_checksum, checksum)); + } + } + + let future = async { + let collection = dispatcher + .toc(&access, &pass) + .get_collection(&collection_pass) + .await?; + collection.assert_shard_exists(shard).await?; + + Result::<_, StorageError>::Ok(collection) + }; + + let collection = cancel::future::cancel_on_token(cancel.clone(), future).await??; + + // `recover_shard_snapshot_impl` is *not* cancel safe + common::snapshots::recover_shard_snapshot_impl( + dispatcher.toc(&access, &pass), + &collection, + shard, + form.snapshot.file.path(), + priority.unwrap_or_default(), + cancel, + ) + .await?; + + Ok(()) + }) + .map(|x| x.map_err(Into::into).and_then(|x| x)); + + helpers::time_or_accept(future, wait.unwrap_or(true)).await +} + +#[get("/collections/{collection}/shards/{shard}/snapshots/{snapshot}")] +async fn download_shard_snapshot( + dispatcher: web::Data, + path: web::Path<(String, ShardId, String)>, + ActixAccess(access): ActixAccess, +) -> Result { + // nothing to verify. 
+ let pass = new_unchecked_verification_pass(); + + let (collection, shard, snapshot) = path.into_inner(); + let collection_pass = + access.check_collection_access(&collection, AccessRequirements::new().whole())?; + let collection = dispatcher + .toc(&access, &pass) + .get_collection(&collection_pass) + .await?; + let snapshots_storage_manager = collection.get_snapshots_storage_manager()?; + let snapshot_path = collection + .shards_holder() + .read() + .await + .get_shard_snapshot_path(collection.snapshots_path(), shard, &snapshot) + .await?; + let snapshot_stream = snapshots_storage_manager + .get_snapshot_stream(&snapshot_path) + .await?; + Ok(snapshot_stream) +} + +#[delete("/collections/{collection}/shards/{shard}/snapshots/{snapshot}")] +async fn delete_shard_snapshot( + dispatcher: web::Data, + path: web::Path<(String, ShardId, String)>, + query: web::Query, + ActixAccess(access): ActixAccess, +) -> impl Responder { + // nothing to verify. + let pass = new_unchecked_verification_pass(); + + let (collection, shard, snapshot) = path.into_inner(); + let future = common::snapshots::delete_shard_snapshot( + dispatcher.toc(&access, &pass).clone(), + access, + collection, + shard, + snapshot, + ) + .map_ok(|_| true) + .map_err(Into::into); + + helpers::time_or_accept(future, query.wait.unwrap_or(true)).await +} + +// Configure services +pub fn config_snapshots_api(cfg: &mut web::ServiceConfig) { + cfg.service(list_snapshots) + .service(create_snapshot) + .service(upload_snapshot) + .service(recover_from_snapshot) + .service(get_snapshot) + .service(list_full_snapshots) + .service(create_full_snapshot) + .service(get_full_snapshot) + .service(delete_full_snapshot) + .service(delete_collection_snapshot) + .service(list_shard_snapshots) + .service(create_shard_snapshot) + .service(stream_shard_snapshot) + .service(recover_shard_snapshot) + .service(upload_shard_snapshot) + .service(download_shard_snapshot) + .service(delete_shard_snapshot); +} diff --git a/src/actix/api/update_api.rs b/src/actix/api/update_api.rs new file mode 100644 index 0000000000000000000000000000000000000000..ae34d75437dd757575e50efb185920c7681d003e --- /dev/null +++ b/src/actix/api/update_api.rs @@ -0,0 +1,392 @@ +use actix_web::rt::time::Instant; +use actix_web::{delete, post, put, web, Responder}; +use actix_web_validator::{Json, Path, Query}; +use api::rest::schema::PointInsertOperations; +use api::rest::UpdateVectors; +use collection::operations::payload_ops::{DeletePayload, SetPayload}; +use collection::operations::point_ops::{PointsSelector, WriteOrdering}; +use collection::operations::types::UpdateResult; +use collection::operations::vector_ops::DeleteVectors; +use collection::operations::verification::new_unchecked_verification_pass; +use schemars::JsonSchema; +use segment::json_path::JsonPath; +use serde::{Deserialize, Serialize}; +use storage::content_manager::collection_verification::check_strict_mode; +use storage::dispatcher::Dispatcher; +use validator::Validate; + +use super::CollectionPath; +use crate::actix::auth::ActixAccess; +use crate::actix::helpers::{self, process_response, process_response_error}; +use crate::common::points::{ + do_batch_update_points, do_clear_payload, do_create_index, do_delete_index, do_delete_payload, + do_delete_points, do_delete_vectors, do_overwrite_payload, do_set_payload, do_update_vectors, + do_upsert_points, CreateFieldIndex, UpdateOperations, +}; + +#[derive(Deserialize, Validate)] +struct FieldPath { + #[serde(rename = "field_name")] + name: JsonPath, +} + 
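// Illustrative sketch, not part of the patch: most update handlers below follow one shape —
// deserialize the operation, run `check_strict_mode` to obtain a verification pass (or use
// `new_unchecked_verification_pass()` when there is nothing to check), then hand the work to a
// `do_*` helper wrapped in `helpers::time`. `ExampleOperation` and `do_example_operation` are
// hypothetical placeholders standing in for the concrete request types and helpers in this file.
#[post("/collections/{name}/points/example")]
async fn example_update_handler(
    dispatcher: web::Data<Dispatcher>,
    collection: Path<CollectionPath>,
    operation: Json<ExampleOperation>, // hypothetical request body type
    params: Query<UpdateParam>,
    ActixAccess(access): ActixAccess,
) -> impl Responder {
    let operation = operation.into_inner();

    // Strict-mode verification yields the pass that `Dispatcher::toc` requires.
    let pass =
        match check_strict_mode(&operation, None, &collection.name, &dispatcher, &access).await {
            Ok(pass) => pass,
            Err(err) => return process_response_error(err, Instant::now(), None),
        };

    let wait = params.wait.unwrap_or(false);
    let ordering = params.ordering.unwrap_or_default();

    // `helpers::time` measures the elapsed time and wraps the result in an ApiResponse.
    helpers::time(do_example_operation( // hypothetical `do_*` helper
        dispatcher.toc(&access, &pass).clone(),
        collection.into_inner().name,
        operation,
        None,
        None,
        wait,
        ordering,
        access,
    ))
    .await
}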
+#[derive(Deserialize, Serialize, JsonSchema, Validate)] +pub struct UpdateParam { + pub wait: Option, + pub ordering: Option, +} + +#[put("/collections/{name}/points")] +async fn upsert_points( + dispatcher: web::Data, + collection: Path, + operation: Json, + params: Query, + ActixAccess(access): ActixAccess, +) -> impl Responder { + // nothing to verify. + let pass = new_unchecked_verification_pass(); + + let operation = operation.into_inner(); + let wait = params.wait.unwrap_or(false); + let ordering = params.ordering.unwrap_or_default(); + + helpers::time(do_upsert_points( + dispatcher.toc(&access, &pass).clone(), + collection.into_inner().name, + operation, + None, + None, + wait, + ordering, + access, + )) + .await +} + +#[post("/collections/{name}/points/delete")] +async fn delete_points( + dispatcher: web::Data, + collection: Path, + operation: Json, + params: Query, + ActixAccess(access): ActixAccess, +) -> impl Responder { + let operation = operation.into_inner(); + let pass = + match check_strict_mode(&operation, None, &collection.name, &dispatcher, &access).await { + Ok(pass) => pass, + Err(err) => return process_response_error(err, Instant::now(), None), + }; + + let wait = params.wait.unwrap_or(false); + let ordering = params.ordering.unwrap_or_default(); + + helpers::time(do_delete_points( + dispatcher.toc(&access, &pass).clone(), + collection.into_inner().name, + operation, + None, + None, + wait, + ordering, + access, + )) + .await +} + +#[put("/collections/{name}/points/vectors")] +async fn update_vectors( + dispatcher: web::Data, + collection: Path, + operation: Json, + params: Query, + ActixAccess(access): ActixAccess, +) -> impl Responder { + // Nothing to verify here. + let pass = new_unchecked_verification_pass(); + + let operation = operation.into_inner(); + let wait = params.wait.unwrap_or(false); + let ordering = params.ordering.unwrap_or_default(); + + helpers::time(do_update_vectors( + dispatcher.toc(&access, &pass).clone(), + collection.into_inner().name, + operation, + None, + None, + wait, + ordering, + access, + )) + .await +} + +#[post("/collections/{name}/points/vectors/delete")] +async fn delete_vectors( + dispatcher: web::Data, + collection: Path, + operation: Json, + params: Query, + ActixAccess(access): ActixAccess, +) -> impl Responder { + let timing = Instant::now(); + + let operation = operation.into_inner(); + let pass = + match check_strict_mode(&operation, None, &collection.name, &dispatcher, &access).await { + Ok(pass) => pass, + Err(err) => return process_response_error(err, timing, None), + }; + + let wait = params.wait.unwrap_or(false); + let ordering = params.ordering.unwrap_or_default(); + + let response = do_delete_vectors( + dispatcher.toc(&access, &pass).clone(), + collection.into_inner().name, + operation, + None, + None, + wait, + ordering, + access, + ) + .await; + process_response(response, timing, None) +} + +#[post("/collections/{name}/points/payload")] +async fn set_payload( + dispatcher: web::Data, + collection: Path, + operation: Json, + params: Query, + ActixAccess(access): ActixAccess, +) -> impl Responder { + let operation = operation.into_inner(); + + let pass = + match check_strict_mode(&operation, None, &collection.name, &dispatcher, &access).await { + Ok(pass) => pass, + Err(err) => return process_response_error(err, Instant::now(), None), + }; + + let wait = params.wait.unwrap_or(false); + let ordering = params.ordering.unwrap_or_default(); + + helpers::time(do_set_payload( + dispatcher.toc(&access, &pass).clone(), + 
collection.into_inner().name, + operation, + None, + None, + wait, + ordering, + access, + )) + .await +} + +#[put("/collections/{name}/points/payload")] +async fn overwrite_payload( + dispatcher: web::Data, + collection: Path, + operation: Json, + params: Query, + ActixAccess(access): ActixAccess, +) -> impl Responder { + let operation = operation.into_inner(); + let pass = + match check_strict_mode(&operation, None, &collection.name, &dispatcher, &access).await { + Ok(pass) => pass, + Err(err) => return process_response_error(err, Instant::now(), None), + }; + let wait = params.wait.unwrap_or(false); + let ordering = params.ordering.unwrap_or_default(); + + helpers::time(do_overwrite_payload( + dispatcher.toc(&access, &pass).clone(), + collection.into_inner().name, + operation, + None, + None, + wait, + ordering, + access, + )) + .await +} + +#[post("/collections/{name}/points/payload/delete")] +async fn delete_payload( + dispatcher: web::Data, + collection: Path, + operation: Json, + params: Query, + ActixAccess(access): ActixAccess, +) -> impl Responder { + let operation = operation.into_inner(); + let pass = + match check_strict_mode(&operation, None, &collection.name, &dispatcher, &access).await { + Ok(pass) => pass, + Err(err) => return process_response_error(err, Instant::now(), None), + }; + let wait = params.wait.unwrap_or(false); + let ordering = params.ordering.unwrap_or_default(); + + helpers::time(do_delete_payload( + dispatcher.toc(&access, &pass).clone(), + collection.into_inner().name, + operation, + None, + None, + wait, + ordering, + access, + )) + .await +} + +#[post("/collections/{name}/points/payload/clear")] +async fn clear_payload( + dispatcher: web::Data, + collection: Path, + operation: Json, + params: Query, + ActixAccess(access): ActixAccess, +) -> impl Responder { + let operation = operation.into_inner(); + let pass = + match check_strict_mode(&operation, None, &collection.name, &dispatcher, &access).await { + Ok(pass) => pass, + Err(err) => return process_response_error(err, Instant::now(), None), + }; + + let wait = params.wait.unwrap_or(false); + let ordering = params.ordering.unwrap_or_default(); + + helpers::time(do_clear_payload( + dispatcher.toc(&access, &pass).clone(), + collection.into_inner().name, + operation, + None, + None, + wait, + ordering, + access, + )) + .await +} + +#[post("/collections/{name}/points/batch")] +async fn update_batch( + dispatcher: web::Data, + collection: Path, + operations: Json, + params: Query, + ActixAccess(access): ActixAccess, +) -> impl Responder { + let timing = Instant::now(); + let operations = operations.into_inner(); + + let mut vpass = None; + for operation in operations.operations.iter() { + let pass = match check_strict_mode(operation, None, &collection.name, &dispatcher, &access) + .await + { + Ok(pass) => pass, + Err(err) => return process_response_error(err, Instant::now(), None), + }; + vpass = Some(pass); + } + + // vpass == None => No update operation available + let Some(pass) = vpass else { + return process_response::>(Ok(vec![]), timing, None); + }; + + let wait = params.wait.unwrap_or(false); + let ordering = params.ordering.unwrap_or_default(); + + let response = do_batch_update_points( + dispatcher.toc(&access, &pass).clone(), + collection.into_inner().name, + operations.operations, + None, + None, + wait, + ordering, + access, + ) + .await; + process_response(response, timing, None) +} +#[put("/collections/{name}/index")] +async fn create_field_index( + dispatcher: web::Data, + collection: Path, + 
operation: Json, + params: Query, + ActixAccess(access): ActixAccess, +) -> impl Responder { + let timing = Instant::now(); + let operation = operation.into_inner(); + let wait = params.wait.unwrap_or(false); + let ordering = params.ordering.unwrap_or_default(); + + let response = do_create_index( + dispatcher.into_inner(), + collection.into_inner().name, + operation, + None, + None, + wait, + ordering, + access, + ) + .await; + process_response(response, timing, None) +} + +#[delete("/collections/{name}/index/{field_name}")] +async fn delete_field_index( + dispatcher: web::Data, + collection: Path, + field: Path, + params: Query, + ActixAccess(access): ActixAccess, +) -> impl Responder { + let timing = Instant::now(); + let wait = params.wait.unwrap_or(false); + let ordering = params.ordering.unwrap_or_default(); + + let response = do_delete_index( + dispatcher.into_inner(), + collection.into_inner().name, + field.name.clone(), + None, + None, + wait, + ordering, + access, + ) + .await; + process_response(response, timing, None) +} + +// Configure services +pub fn config_update_api(cfg: &mut web::ServiceConfig) { + cfg.service(upsert_points) + .service(delete_points) + .service(update_vectors) + .service(delete_vectors) + .service(set_payload) + .service(overwrite_payload) + .service(delete_payload) + .service(clear_payload) + .service(create_field_index) + .service(delete_field_index) + .service(update_batch); +} diff --git a/src/actix/auth.rs b/src/actix/auth.rs new file mode 100644 index 0000000000000000000000000000000000000000..adf1c42e5b2a543b555ffc732ba77bf8bab370cf --- /dev/null +++ b/src/actix/auth.rs @@ -0,0 +1,160 @@ +use std::convert::Infallible; +use std::future::{ready, Ready}; +use std::sync::Arc; + +use actix_web::body::{BoxBody, EitherBody}; +use actix_web::dev::{forward_ready, Service, ServiceRequest, ServiceResponse, Transform}; +use actix_web::{Error, FromRequest, HttpMessage, HttpResponse, ResponseError}; +use futures_util::future::LocalBoxFuture; +use storage::rbac::Access; + +use super::helpers::HttpError; +use crate::common::auth::{AuthError, AuthKeys}; + +pub struct Auth { + auth_keys: AuthKeys, + whitelist: Vec, +} + +impl Auth { + pub fn new(auth_keys: AuthKeys, whitelist: Vec) -> Self { + Self { + auth_keys, + whitelist, + } + } +} + +impl Transform for Auth +where + S: Service>, Error = Error> + + 'static, + S::Future: 'static, + B: 'static, +{ + type Response = ServiceResponse>; + type Error = Error; + type InitError = (); + type Transform = AuthMiddleware; + type Future = Ready>; + + fn new_transform(&self, service: S) -> Self::Future { + ready(Ok(AuthMiddleware { + auth_keys: Arc::new(self.auth_keys.clone()), + whitelist: self.whitelist.clone(), + service: Arc::new(service), + })) + } +} + +#[derive(Clone, Eq, PartialEq, Hash)] +pub struct WhitelistItem(pub String, pub PathMode); + +impl WhitelistItem { + pub fn exact>(path: S) -> Self { + Self(path.into(), PathMode::Exact) + } + + pub fn prefix>(path: S) -> Self { + Self(path.into(), PathMode::Prefix) + } + + pub fn matches(&self, other: &str) -> bool { + self.1.check(&self.0, other) + } +} + +#[derive(Copy, Clone, Eq, PartialEq, Hash)] +pub enum PathMode { + /// Path must match exactly + Exact, + /// Path must have given prefix + Prefix, +} + +impl PathMode { + fn check(&self, key: &str, other: &str) -> bool { + match self { + Self::Exact => key == other, + Self::Prefix => other.starts_with(key), + } + } +} + +pub struct AuthMiddleware { + auth_keys: Arc, + /// List of items whitelisted from 
authentication. + whitelist: Vec, + service: Arc, +} + +impl AuthMiddleware { + pub fn is_path_whitelisted(&self, path: &str) -> bool { + self.whitelist.iter().any(|item| item.matches(path)) + } +} + +impl Service for AuthMiddleware +where + S: Service>, Error = Error> + + 'static, + S::Future: 'static, + B: 'static, +{ + type Response = ServiceResponse>; + type Error = Error; + type Future = LocalBoxFuture<'static, Result>; + + forward_ready!(service); + + fn call(&self, req: ServiceRequest) -> Self::Future { + let path = req.path(); + + if self.is_path_whitelisted(path) { + return Box::pin(self.service.call(req)); + } + + let auth_keys = self.auth_keys.clone(); + let service = self.service.clone(); + Box::pin(async move { + match auth_keys + .validate_request(|key| req.headers().get(key).and_then(|val| val.to_str().ok())) + .await + { + Ok(access) => { + let previous = req.extensions_mut().insert::(access); + debug_assert!( + previous.is_none(), + "Previous access object should not exist in the request" + ); + service.call(req).await + } + Err(e) => { + let resp = match e { + AuthError::Unauthorized(e) => HttpResponse::Unauthorized().body(e), + AuthError::Forbidden(e) => HttpResponse::Forbidden().body(e), + AuthError::StorageError(e) => HttpError::from(e).error_response(), + }; + Ok(req.into_response(resp).map_into_right_body()) + } + } + }) + } +} + +pub struct ActixAccess(pub Access); + +impl FromRequest for ActixAccess { + type Error = Infallible; + type Future = Ready>; + + fn from_request( + req: &actix_web::HttpRequest, + _payload: &mut actix_web::dev::Payload, + ) -> Self::Future { + let access = req.extensions_mut().remove::().unwrap_or_else(|| { + Access::full("All requests have full by default access when API key is not configured") + }); + ready(Ok(ActixAccess(access))) + } +} diff --git a/src/actix/certificate_helpers.rs b/src/actix/certificate_helpers.rs new file mode 100644 index 0000000000000000000000000000000000000000..0a8e60c9eafc2a9d71c6d347970ef71580037148 --- /dev/null +++ b/src/actix/certificate_helpers.rs @@ -0,0 +1,203 @@ +use std::fmt::Debug; +use std::fs::File; +use std::io::{self, BufRead, BufReader}; +use std::sync::Arc; +use std::time::{Duration, Instant}; + +use parking_lot::RwLock; +use rustls::client::VerifierBuilderError; +use rustls::pki_types::CertificateDer; +use rustls::server::{ClientHello, ResolvesServerCert, WebPkiClientVerifier}; +use rustls::sign::CertifiedKey; +use rustls::{crypto, RootCertStore, ServerConfig}; +use rustls_pemfile::Item; + +use crate::settings::{Settings, TlsConfig}; + +type Result = std::result::Result; + +/// A TTL based rotating server certificate resolver +#[derive(Debug)] +struct RotatingCertificateResolver { + /// TLS configuration used for loading/refreshing certified key + tls_config: TlsConfig, + + /// TTL for each rotation + ttl: Option, + + /// Current certified key + key: RwLock, +} + +impl RotatingCertificateResolver { + pub fn new(tls_config: TlsConfig, ttl: Option) -> Result { + let certified_key = load_certified_key(&tls_config)?; + + Ok(Self { + tls_config, + ttl, + key: RwLock::new(CertifiedKeyWithAge::from(certified_key)), + }) + } + + /// Get certificate key or refresh + /// + /// The key is automatically refreshed when the TTL is reached. + /// If refreshing fails, an error is logged and the old key is persisted. + fn get_key_or_refresh(&self) -> Arc { + // Get read-only lock to the key. If TTL is not configured or is not expired, return key. 
+ let key = self.key.read(); + let ttl = match self.ttl { + Some(ttl) if key.is_expired(ttl) => ttl, + _ => return key.key.clone(), + }; + drop(key); + + // If TTL is expired: + // - get read-write lock to the key + // - *re-check that TTL is expired* (to avoid refreshing the key multiple times from concurrent threads) + // - refresh and return the key + let mut key = self.key.write(); + if key.is_expired(ttl) { + if let Err(err) = key.refresh(&self.tls_config) { + log::error!("Failed to refresh server TLS certificate, keeping current: {err}"); + } + } + + key.key.clone() + } +} + +impl ResolvesServerCert for RotatingCertificateResolver { + fn resolve(&self, _client_hello: ClientHello<'_>) -> Option> { + Some(self.get_key_or_refresh()) + } +} + +#[derive(Debug)] +struct CertifiedKeyWithAge { + /// Last time the certificate was updated/replaced + last_update: Instant, + + /// Current certified key + key: Arc, +} + +impl CertifiedKeyWithAge { + pub fn from(key: Arc) -> Self { + Self { + last_update: Instant::now(), + key, + } + } + + pub fn refresh(&mut self, tls_config: &TlsConfig) -> Result<()> { + *self = Self::from(load_certified_key(tls_config)?); + Ok(()) + } + + pub fn age(&self) -> Duration { + self.last_update.elapsed() + } + + pub fn is_expired(&self, ttl: Duration) -> bool { + self.age() >= ttl + } +} + +/// Load TLS configuration and construct certified key. +fn load_certified_key(tls_config: &TlsConfig) -> Result> { + // Load certificates + let certs: Vec = with_buf_read(&tls_config.cert, |rd| { + rustls_pemfile::read_all(rd).collect::>>() + })? + .into_iter() + .filter_map(|item| match item { + Item::X509Certificate(data) => Some(data), + _ => None, + }) + .collect(); + if certs.is_empty() { + return Err(Error::NoServerCert); + } + + // Load private key + let private_key_item = + with_buf_read(&tls_config.key, rustls_pemfile::read_one)?.ok_or(Error::NoPrivateKey)?; + let private_key = match private_key_item { + Item::Pkcs1Key(pkey) => rustls_pki_types::PrivateKeyDer::from(pkey), + Item::Pkcs8Key(pkey) => rustls_pki_types::PrivateKeyDer::from(pkey), + Item::Sec1Key(pkey) => rustls_pki_types::PrivateKeyDer::from(pkey), + _ => return Err(Error::InvalidPrivateKey), + }; + let signing_key = crypto::ring::sign::any_supported_type(&private_key).map_err(Error::Sign)?; + + // Construct certified key + let certified_key = CertifiedKey::new(certs, signing_key); + Ok(Arc::new(certified_key)) +} + +/// Generate an actix server configuration with TLS +/// +/// Uses TLS settings as configured in configuration by user. 
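// (Illustrative note, not part of the patch.) Judging from the body below, this reads
// `tls.cert`, `tls.key` and `tls.cert_ttl` from the settings, plus `tls.ca_cert` when
// `service.verify_https_client_certificate` is enabled; a `cert_ttl` of 0 or an absent value
// disables certificate rotation.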
+pub fn actix_tls_server_config(settings: &Settings) -> Result { + let config = ServerConfig::builder(); + let tls_config = settings + .tls + .clone() + .ok_or_else(Settings::tls_config_is_undefined_error) + .map_err(Error::Io)?; + + // Verify client CA or not + let config = if settings.service.verify_https_client_certificate { + let mut root_cert_store = RootCertStore::empty(); + let ca_certs: Vec = with_buf_read(&tls_config.ca_cert, |rd| { + rustls_pemfile::certs(rd).collect() + })?; + root_cert_store.add_parsable_certificates(ca_certs); + let client_cert_verifier = WebPkiClientVerifier::builder(root_cert_store.into()) + .build() + .map_err(Error::ClientCertVerifier)?; + config.with_client_cert_verifier(client_cert_verifier) + } else { + config.with_no_client_auth() + }; + + // Configure rotating certificate resolver + let ttl = match tls_config.cert_ttl { + None | Some(0) => None, + Some(seconds) => Some(Duration::from_secs(seconds)), + }; + let cert_resolver = RotatingCertificateResolver::new(tls_config, ttl)?; + let config = config.with_cert_resolver(Arc::new(cert_resolver)); + + Ok(config) +} + +fn with_buf_read(path: &str, f: impl FnOnce(&mut dyn BufRead) -> io::Result) -> Result { + let file = File::open(path).map_err(|err| Error::OpenFile(err, path.into()))?; + let mut reader = BufReader::new(file); + let dyn_reader: &mut dyn BufRead = &mut reader; + f(dyn_reader).map_err(|err| Error::ReadFile(err, path.into())) +} + +/// Actix TLS errors. +#[derive(thiserror::Error, Debug)] +pub enum Error { + #[error("TLS file could not be opened: {1}")] + OpenFile(#[source] io::Error, String), + #[error("TLS file could not be read: {1}")] + ReadFile(#[source] io::Error, String), + #[error("general TLS IO error")] + Io(#[source] io::Error), + #[error("no server certificate found")] + NoServerCert, + #[error("no private key found")] + NoPrivateKey, + #[error("invalid private key")] + InvalidPrivateKey, + #[error("TLS signing error")] + Sign(#[source] rustls::Error), + #[error("client certificate verification")] + ClientCertVerifier(#[source] VerifierBuilderError), +} diff --git a/src/actix/helpers.rs b/src/actix/helpers.rs new file mode 100644 index 0000000000000000000000000000000000000000..2a54230940100111cf815cd6d213a90d08ea4573 --- /dev/null +++ b/src/actix/helpers.rs @@ -0,0 +1,179 @@ +use std::fmt::Debug; +use std::future::Future; + +use actix_web::rt::time::Instant; +use actix_web::{http, HttpResponse, ResponseError}; +use api::rest::models::{ApiResponse, ApiStatus, HardwareUsage}; +use collection::operations::types::CollectionError; +use common::counter::hardware_accumulator::HwMeasurementAcc; +use serde::Serialize; +use storage::content_manager::errors::StorageError; +use storage::content_manager::toc::request_hw_counter::RequestHwCounter; +use storage::dispatcher::Dispatcher; + +pub fn get_request_hardware_counter( + dispatcher: &Dispatcher, + collection_name: String, + report_to_api: bool, +) -> RequestHwCounter { + RequestHwCounter::new( + HwMeasurementAcc::new_with_drain(&dispatcher.get_collection_hw_metrics(collection_name)), + report_to_api, + false, + ) +} + +pub fn accepted_response(timing: Instant, hardware_usage: Option) -> HttpResponse { + HttpResponse::Accepted().json(ApiResponse::<()> { + result: None, + status: ApiStatus::Accepted, + time: timing.elapsed().as_secs_f64(), + usage: hardware_usage, + }) +} + +pub fn process_response( + response: Result, + timing: Instant, + hardware_usage: Option, +) -> HttpResponse +where + T: Serialize, +{ + match response { + Ok(res) => 
HttpResponse::Ok().json(ApiResponse { + result: Some(res), + status: ApiStatus::Ok, + time: timing.elapsed().as_secs_f64(), + usage: hardware_usage, + }), + Err(err) => process_response_error(err, timing, hardware_usage), + } +} + +pub fn process_response_error( + err: StorageError, + timing: Instant, + hardware_usage: Option, +) -> HttpResponse { + log_service_error(&err); + + let error = HttpError::from(err); + + HttpResponse::build(error.status_code()).json(ApiResponse::<()> { + result: None, + status: ApiStatus::Error(error.to_string()), + time: timing.elapsed().as_secs_f64(), + usage: hardware_usage, + }) +} + +/// Response wrapper for a `Future` returning `Result`. +/// +/// # Cancel safety +/// +/// Future must be cancel safe. +pub async fn time(future: Fut) -> HttpResponse +where + Fut: Future>, + T: serde::Serialize, +{ + time_impl(async { future.await.map(Some) }).await +} + +/// Response wrapper for a `Future` returning `Result`. +/// If `wait` is false, returns `202 Accepted` immediately. +pub async fn time_or_accept(future: Fut, wait: bool) -> HttpResponse +where + Fut: Future> + Send + 'static, + T: serde::Serialize + Send + 'static, +{ + let future = async move { + let handle = tokio::task::spawn(async move { + let result = future.await; + + if !wait { + if let Err(err) = &result { + log_service_error(err); + } + } + + result + }); + + if wait { + handle.await?.map(Some) + } else { + Ok(None) + } + }; + + time_impl(future).await +} + +/// # Cancel safety +/// +/// Future must be cancel safe. +async fn time_impl(future: Fut) -> HttpResponse +where + Fut: Future, StorageError>>, + T: serde::Serialize, +{ + let instant = Instant::now(); + match future.await.transpose() { + Some(res) => process_response(res, instant, None), + None => accepted_response(instant, None), + } +} + +fn log_service_error(err: &StorageError) { + if let StorageError::ServiceError { backtrace, .. } = err { + log::error!("Error processing request: {err}"); + + if let Some(backtrace) = backtrace { + log::trace!("Backtrace: {backtrace}"); + } + } +} + +pub type HttpResult = Result; + +#[derive(Clone, Debug, thiserror::Error)] +#[error("{0}")] +pub struct HttpError(StorageError); + +impl ResponseError for HttpError { + fn status_code(&self) -> http::StatusCode { + match &self.0 { + StorageError::BadInput { .. } => http::StatusCode::BAD_REQUEST, + StorageError::NotFound { .. } => http::StatusCode::NOT_FOUND, + StorageError::ServiceError { .. } => http::StatusCode::INTERNAL_SERVER_ERROR, + StorageError::BadRequest { .. } => http::StatusCode::BAD_REQUEST, + StorageError::Locked { .. } => http::StatusCode::FORBIDDEN, + StorageError::Timeout { .. } => http::StatusCode::REQUEST_TIMEOUT, + StorageError::AlreadyExists { .. } => http::StatusCode::CONFLICT, + StorageError::ChecksumMismatch { .. } => http::StatusCode::BAD_REQUEST, + StorageError::Forbidden { .. } => http::StatusCode::FORBIDDEN, + StorageError::PreconditionFailed { .. } => http::StatusCode::INTERNAL_SERVER_ERROR, + StorageError::InferenceError { .. } => http::StatusCode::BAD_REQUEST, + } + } +} + +impl From for HttpError { + fn from(err: StorageError) -> Self { + HttpError(err) + } +} + +impl From for HttpError { + fn from(err: CollectionError) -> Self { + HttpError(err.into()) + } +} + +impl From for HttpError { + fn from(err: std::io::Error) -> Self { + HttpError(err.into()) // TODO: Is this good enough?.. 
🤔 + } +} diff --git a/src/actix/mod.rs b/src/actix/mod.rs new file mode 100644 index 0000000000000000000000000000000000000000..5b38ce71501f40ac54a53a1cf02d9e021ff57321 --- /dev/null +++ b/src/actix/mod.rs @@ -0,0 +1,262 @@ +#[allow(dead_code)] // May contain functions used in different binaries. Not actually dead +pub mod actix_telemetry; +pub mod api; +mod auth; +mod certificate_helpers; +#[allow(dead_code)] // May contain functions used in different binaries. Not actually dead +pub mod helpers; +pub mod web_ui; + +use std::io; +use std::sync::Arc; + +use ::api::rest::models::{ApiResponse, ApiStatus, VersionInfo}; +use actix_cors::Cors; +use actix_multipart::form::tempfile::TempFileConfig; +use actix_multipart::form::MultipartFormConfig; +use actix_web::middleware::{Compress, Condition, Logger}; +use actix_web::{error, get, web, App, HttpRequest, HttpResponse, HttpServer, Responder}; +use actix_web_extras::middleware::Condition as ConditionEx; +use api::facet_api::config_facet_api; +use collection::operations::validation; +use collection::operations::verification::new_unchecked_verification_pass; +use storage::dispatcher::Dispatcher; +use storage::rbac::Access; + +use crate::actix::api::cluster_api::config_cluster_api; +use crate::actix::api::collections_api::config_collections_api; +use crate::actix::api::count_api::count_points; +use crate::actix::api::debug_api::config_debugger_api; +use crate::actix::api::discovery_api::config_discovery_api; +use crate::actix::api::issues_api::config_issues_api; +use crate::actix::api::local_shard_api::config_local_shard_api; +use crate::actix::api::query_api::config_query_api; +use crate::actix::api::recommend_api::config_recommend_api; +use crate::actix::api::retrieve_api::{get_point, get_points, scroll_points}; +use crate::actix::api::search_api::config_search_api; +use crate::actix::api::service_api::config_service_api; +use crate::actix::api::shards_api::config_shards_api; +use crate::actix::api::snapshot_api::config_snapshots_api; +use crate::actix::api::update_api::config_update_api; +use crate::actix::auth::{Auth, WhitelistItem}; +use crate::actix::web_ui::{web_ui_factory, web_ui_folder, WEB_UI_PATH}; +use crate::common::auth::AuthKeys; +use crate::common::debugger::DebuggerState; +use crate::common::health; +use crate::common::http_client::HttpClient; +use crate::common::telemetry::TelemetryCollector; +use crate::settings::{max_web_workers, Settings}; +use crate::tracing::LoggerHandle; + +#[get("/")] +pub async fn index() -> impl Responder { + HttpResponse::Ok().json(VersionInfo::default()) +} + +#[allow(dead_code)] +pub fn init( + dispatcher: Arc, + telemetry_collector: Arc>, + health_checker: Option>, + settings: Settings, + logger_handle: LoggerHandle, +) -> io::Result<()> { + actix_web::rt::System::new().block_on(async { + // Nothing to verify here. 
+ let pass = new_unchecked_verification_pass(); + let auth_keys = AuthKeys::try_create( + &settings.service, + dispatcher + .toc(&Access::full("For JWT validation"), &pass) + .clone(), + ); + let upload_dir = dispatcher + .toc(&Access::full("For upload dir"), &pass) + .upload_dir() + .unwrap(); + let dispatcher_data = web::Data::from(dispatcher); + let actix_telemetry_collector = telemetry_collector + .lock() + .await + .actix_telemetry_collector + .clone(); + let debugger_state = web::Data::new(DebuggerState::from_settings(&settings)); + let telemetry_collector_data = web::Data::from(telemetry_collector); + let logger_handle_data = web::Data::new(logger_handle); + let http_client = web::Data::new(HttpClient::from_settings(&settings)?); + let health_checker = web::Data::new(health_checker); + let web_ui_available = web_ui_folder(&settings); + let service_config = web::Data::new(settings.service.clone()); + + let mut api_key_whitelist = vec![ + WhitelistItem::exact("/"), + WhitelistItem::exact("/healthz"), + WhitelistItem::prefix("/readyz"), + WhitelistItem::prefix("/livez"), + ]; + if web_ui_available.is_some() { + api_key_whitelist.push(WhitelistItem::prefix(WEB_UI_PATH)); + } + + let mut server = HttpServer::new(move || { + let cors = Cors::default() + .allow_any_origin() + .allow_any_method() + .allow_any_header(); + let validate_path_config = actix_web_validator::PathConfig::default() + .error_handler(|err, rec| validation_error_handler("path parameters", err, rec)); + let validate_query_config = actix_web_validator::QueryConfig::default() + .error_handler(|err, rec| validation_error_handler("query parameters", err, rec)); + let validate_json_config = actix_web_validator::JsonConfig::default() + .limit(settings.service.max_request_size_mb * 1024 * 1024) + .error_handler(|err, rec| validation_error_handler("JSON body", err, rec)); + + let mut app = App::new() + .wrap(Compress::default()) // Reads the `Accept-Encoding` header to negotiate which compression codec to use. 
+ // api_key middleware + // note: the last call to `wrap()` or `wrap_fn()` is executed first + .wrap(ConditionEx::from_option(auth_keys.as_ref().map( + |auth_keys| Auth::new(auth_keys.clone(), api_key_whitelist.clone()), + ))) + .wrap(Condition::new(settings.service.enable_cors, cors)) + .wrap( + // Set up logger, but avoid logging hot status endpoints + Logger::default() + .exclude("/") + .exclude("/metrics") + .exclude("/telemetry") + .exclude("/healthz") + .exclude("/readyz") + .exclude("/livez"), + ) + .wrap(actix_telemetry::ActixTelemetryTransform::new( + actix_telemetry_collector.clone(), + )) + .app_data(dispatcher_data.clone()) + .app_data(telemetry_collector_data.clone()) + .app_data(logger_handle_data.clone()) + .app_data(http_client.clone()) + .app_data(debugger_state.clone()) + .app_data(health_checker.clone()) + .app_data(validate_path_config) + .app_data(validate_query_config) + .app_data(validate_json_config) + .app_data(TempFileConfig::default().directory(&upload_dir)) + .app_data(MultipartFormConfig::default().total_limit(usize::MAX)) + .app_data(service_config.clone()) + .service(index) + .configure(config_collections_api) + .configure(config_snapshots_api) + .configure(config_update_api) + .configure(config_cluster_api) + .configure(config_service_api) + .configure(config_search_api) + .configure(config_recommend_api) + .configure(config_discovery_api) + .configure(config_query_api) + .configure(config_facet_api) + .configure(config_shards_api) + .configure(config_issues_api) + .configure(config_debugger_api) + .configure(config_local_shard_api) + // Ordering of services is important for correct path pattern matching + // See: + .service(scroll_points) + .service(count_points) + .service(get_point) + .service(get_points); + + if let Some(static_folder) = web_ui_available.as_deref() { + app = app.service(web_ui_factory(static_folder)); + } + + app + }) + .workers(max_web_workers(&settings)); + + let port = settings.service.http_port; + let bind_addr = format!("{}:{}", settings.service.host, port); + + // With TLS enabled, bind with certificate helper and Rustls, or bind regularly + server = if settings.service.enable_tls { + log::info!( + "TLS enabled for REST API (TTL: {})", + settings + .tls + .as_ref() + .and_then(|tls| tls.cert_ttl) + .map(|ttl| ttl.to_string()) + .unwrap_or_else(|| "none".into()), + ); + + let config = certificate_helpers::actix_tls_server_config(&settings) + .map_err(|err| io::Error::new(io::ErrorKind::Other, err))?; + server.bind_rustls_0_23(bind_addr, config)? + } else { + log::info!("TLS disabled for REST API"); + + server.bind(bind_addr)? 
+ }; + + log::info!("Qdrant HTTP listening on {}", port); + server.run().await + }) +} + +fn validation_error_handler( + name: &str, + err: actix_web_validator::Error, + _req: &HttpRequest, +) -> error::Error { + use actix_web_validator::error::DeserializeErrors; + + // Nicely describe deserialization and validation errors + let msg = match &err { + actix_web_validator::Error::Validate(errs) => { + validation::label_errors(format!("Validation error in {name}"), errs) + } + actix_web_validator::Error::Deserialize(err) => { + format!( + "Deserialize error in {name}: {}", + match err { + DeserializeErrors::DeserializeQuery(err) => err.to_string(), + DeserializeErrors::DeserializeJson(err) => err.to_string(), + DeserializeErrors::DeserializePath(err) => err.to_string(), + } + ) + } + actix_web_validator::Error::JsonPayloadError( + actix_web::error::JsonPayloadError::Deserialize(err), + ) => { + format!("Format error in {name}: {err}",) + } + err => err.to_string(), + }; + + // Build fitting response + let response = match &err { + actix_web_validator::Error::Validate(_) => HttpResponse::UnprocessableEntity(), + _ => HttpResponse::BadRequest(), + } + .json(ApiResponse::<()> { + result: None, + status: ApiStatus::Error(msg), + time: 0.0, + usage: None, + }); + error::InternalError::from_response(err, response).into() +} + +#[cfg(test)] +mod tests { + use ::api::grpc::api_crate_version; + + #[test] + fn test_version() { + assert_eq!( + api_crate_version(), + env!("CARGO_PKG_VERSION"), + "Qdrant and lib/api crate versions are not same" + ); + } +} diff --git a/src/actix/web_ui.rs b/src/actix/web_ui.rs new file mode 100644 index 0000000000000000000000000000000000000000..f962d10183a4ff2ee2a8aa77e3025a6d2558093b --- /dev/null +++ b/src/actix/web_ui.rs @@ -0,0 +1,115 @@ +use std::path::Path; + +use actix_web::dev::HttpServiceFactory; +use actix_web::http::header::HeaderValue; +use actix_web::middleware::DefaultHeaders; +use actix_web::web; + +use crate::settings::Settings; + +const DEFAULT_STATIC_DIR: &str = "./static"; +pub const WEB_UI_PATH: &str = "/dashboard"; + +pub fn web_ui_folder(settings: &Settings) -> Option { + let web_ui_enabled = settings.service.enable_static_content.unwrap_or(true); + + if web_ui_enabled { + let static_folder = settings + .service + .static_content_dir + .clone() + .unwrap_or_else(|| DEFAULT_STATIC_DIR.to_string()); + let static_folder_path = Path::new(&static_folder); + if !static_folder_path.exists() || !static_folder_path.is_dir() { + // enabled BUT folder does not exist + log::warn!( + "Static content folder for Web UI '{}' does not exist", + static_folder_path.display(), + ); + None + } else { + // enabled AND folder exists + Some(static_folder) + } + } else { + // not enabled + None + } +} + +pub fn web_ui_factory(static_folder: &str) -> impl HttpServiceFactory { + web::scope(WEB_UI_PATH) + .wrap(DefaultHeaders::new().add(("X-Frame-Options", HeaderValue::from_static("DENY")))) + .service(actix_files::Files::new("/", static_folder).index_file("index.html")) +} + +#[cfg(test)] +mod tests { + use actix_web::http::header::{self, HeaderMap}; + use actix_web::http::StatusCode; + use actix_web::test::{self, TestRequest}; + use actix_web::App; + + use super::*; + + fn assert_html_custom_headers(headers: &HeaderMap) { + let content_type = header::HeaderValue::from_static("text/html; charset=utf-8"); + assert_eq!(headers.get(header::CONTENT_TYPE), Some(&content_type)); + let x_frame_options = header::HeaderValue::from_static("DENY"); + 
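+        // Every HTML response served from the dashboard scope is expected to carry
+        // `X-Frame-Options: DENY` (added by the `DefaultHeaders` middleware in
+        // `web_ui_factory`), which keeps the UI from being embedded in frames.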
assert_eq!(headers.get(header::X_FRAME_OPTIONS), Some(&x_frame_options),); + } + + #[actix_web::test] + async fn test_web_ui() { + let static_dir = String::from("static"); + let mut settings = Settings::new(None).unwrap(); + settings.service.static_content_dir = Some(static_dir.clone()); + + let maybe_static_folder = web_ui_folder(&settings); + if maybe_static_folder.is_none() { + println!("Skipping test because the static folder was not found."); + return; + } + + let static_folder = maybe_static_folder.unwrap(); + let srv = test::init_service(App::new().service(web_ui_factory(&static_folder))).await; + + // Index path (no trailing slash) + let req = TestRequest::with_uri(WEB_UI_PATH).to_request(); + let res = test::call_service(&srv, req).await; + assert_eq!(res.status(), StatusCode::OK); + let headers = res.headers(); + assert_html_custom_headers(headers); + // Index path (trailing slash) + let req = TestRequest::with_uri(format!("{WEB_UI_PATH}/").as_str()).to_request(); + let res = test::call_service(&srv, req).await; + assert_eq!(res.status(), StatusCode::OK); + let headers = res.headers(); + assert_html_custom_headers(headers); + // Index path (index.html file) + let req = TestRequest::with_uri(format!("{WEB_UI_PATH}/index.html").as_str()).to_request(); + let res = test::call_service(&srv, req).await; + assert_eq!(res.status(), StatusCode::OK); + let headers = res.headers(); + assert_html_custom_headers(headers); + // Static asset (favicon.ico) + let req = TestRequest::with_uri(format!("{WEB_UI_PATH}/favicon.ico").as_str()).to_request(); + let res = test::call_service(&srv, req).await; + assert_eq!(res.status(), StatusCode::OK); + let headers = res.headers(); + assert_eq!( + headers.get(header::CONTENT_TYPE), + Some(&header::HeaderValue::from_static("image/x-icon")), + ); + // Non-existing path (404 Not Found) + let fake_path = uuid::Uuid::new_v4().to_string(); + let srv = test::init_service(App::new().service(web_ui_factory(&fake_path))).await; + + let req = TestRequest::with_uri(WEB_UI_PATH).to_request(); + let res = test::call_service(&srv, req).await; + assert_eq!(res.status(), StatusCode::NOT_FOUND); + let headers = res.headers(); + assert_eq!(headers.get(header::CONTENT_TYPE), None); + assert_eq!(headers.get(header::CONTENT_LENGTH), None); + } +} diff --git a/src/common/auth/claims.rs b/src/common/auth/claims.rs new file mode 100644 index 0000000000000000000000000000000000000000..1bef598ceef22e32e724d7e2a778b79ad7e45f81 --- /dev/null +++ b/src/common/auth/claims.rs @@ -0,0 +1,69 @@ +use segment::json_path::JsonPath; +use segment::types::{Condition, FieldCondition, Filter, Match, ValueVariants}; +use serde::{Deserialize, Serialize}; +use storage::rbac::Access; +use validator::{Validate, ValidationErrors}; + +#[derive(Serialize, Deserialize, PartialEq, Clone, Debug)] +pub struct Claims { + /// Expiration time (seconds since UNIX epoch) + pub exp: Option, + + #[serde(default = "default_access")] + pub access: Access, + + /// Validate this token by looking for a value inside a collection. 
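+    ///
+    /// Illustrative claim shape (example values only; field names come from
+    /// `ValueExists` and `KeyValuePair` below):
+    ///
+    ///   "value_exists": {
+    ///       "collection": "users",
+    ///       "matches": [{ "key": "user_id", "value": "alice" }]
+    ///   }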
+ pub value_exists: Option, +} + +#[derive(Serialize, Deserialize, PartialEq, Clone, Debug)] +pub struct KeyValuePair { + key: JsonPath, + value: ValueVariants, +} + +impl KeyValuePair { + pub fn to_condition(&self) -> Condition { + Condition::Field(FieldCondition::new_match( + self.key.clone(), + Match::new_value(self.value.clone()), + )) + } +} + +#[derive(Serialize, Deserialize, PartialEq, Clone, Debug)] +pub struct ValueExists { + collection: String, + matches: Vec, +} + +fn default_access() -> Access { + Access::full("Give full access when the access field is not present") +} + +impl ValueExists { + pub fn get_collection(&self) -> &str { + &self.collection + } + + pub fn to_filter(&self) -> Filter { + let conditions = self + .matches + .iter() + .map(|pair| pair.to_condition()) + .collect(); + + Filter { + should: None, + min_should: None, + must: Some(conditions), + must_not: None, + } + } +} + +impl Validate for Claims { + fn validate(&self) -> Result<(), ValidationErrors> { + ValidationErrors::merge_all(Ok(()), "access", self.access.validate()) + } +} diff --git a/src/common/auth/jwt_parser.rs b/src/common/auth/jwt_parser.rs new file mode 100644 index 0000000000000000000000000000000000000000..8f5bffe286a90cf4ea81d319beea368fae0ecded --- /dev/null +++ b/src/common/auth/jwt_parser.rs @@ -0,0 +1,155 @@ +use jsonwebtoken::errors::ErrorKind; +use jsonwebtoken::{decode, Algorithm, DecodingKey, Validation}; +use validator::Validate; + +use super::claims::Claims; +use super::AuthError; + +#[derive(Clone)] +pub struct JwtParser { + key: DecodingKey, + validation: Validation, +} + +impl JwtParser { + const ALGORITHM: Algorithm = Algorithm::HS256; + + pub fn new(secret: &str) -> Self { + let key = DecodingKey::from_secret(secret.as_bytes()); + let mut validation = Validation::new(Self::ALGORITHM); + + // Qdrant server is the only audience + validation.validate_aud = false; + + // Expiration time leeway to account for clock skew + validation.leeway = 30; + + // All claims are optional + validation.required_spec_claims = Default::default(); + + JwtParser { key, validation } + } + + /// Decode the token and return the claims, this already validates the `exp` claim with some leeway. + /// Returns None when the token doesn't look like a JWT. 
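+    ///
+    /// Sketch of how a caller can interpret the result:
+    ///
+    ///   match parser.decode(token) {
+    ///       Some(Ok(claims)) => { /* valid JWT, use `claims.access` */ }
+    ///       Some(Err(err))   => { /* JWT rejected: expired, bad signature, or invalid claims */ }
+    ///       None             => { /* token is not parseable as a JWT at all */ }
+    ///   }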
+ pub fn decode(&self, token: &str) -> Option> { + let claims = match decode::(token, &self.key, &self.validation) { + Ok(token_data) => token_data.claims, + Err(e) => { + return match e.kind() { + ErrorKind::ExpiredSignature | ErrorKind::InvalidSignature => { + Some(Err(AuthError::Forbidden(e.to_string()))) + } + _ => None, + } + } + }; + if let Err(e) = claims.validate() { + return Some(Err(AuthError::Unauthorized(e.to_string()))); + } + Some(Ok(claims)) + } +} + +#[cfg(test)] +mod tests { + use segment::types::ValueVariants; + use storage::rbac::{ + Access, CollectionAccess, CollectionAccessList, CollectionAccessMode, GlobalAccessMode, + PayloadConstraint, + }; + + use super::*; + + pub fn create_token(claims: &Claims) -> String { + use jsonwebtoken::{encode, EncodingKey, Header}; + + let key = EncodingKey::from_secret("secret".as_ref()); + let header = Header::new(JwtParser::ALGORITHM); + encode(&header, claims, &key).unwrap() + } + + #[test] + fn test_jwt_parser() { + let exp = std::time::SystemTime::now() + .duration_since(std::time::UNIX_EPOCH) + .expect("Time went backwards") + .as_secs(); + let claims = Claims { + exp: Some(exp), + access: Access::Collection(CollectionAccessList(vec![CollectionAccess { + collection: "collection".to_string(), + access: CollectionAccessMode::ReadWrite, + payload: Some(PayloadConstraint( + vec![ + ( + "field1".parse().unwrap(), + ValueVariants::String("value".to_string()), + ), + ("field2".parse().unwrap(), ValueVariants::Integer(42)), + ("field3".parse().unwrap(), ValueVariants::Bool(true)), + ] + .into_iter() + .collect(), + )), + }])), + value_exists: None, + }; + let token = create_token(&claims); + + let secret = "secret"; + let parser = JwtParser::new(secret); + let decoded_claims = parser.decode(&token).unwrap().unwrap(); + + assert_eq!(claims, decoded_claims); + } + + #[test] + fn test_exp_validation() { + let exp = std::time::SystemTime::now() + .duration_since(std::time::UNIX_EPOCH) + .expect("Time went backwards") + .as_secs() + - 31; // 31 seconds in the past, bigger than the 30 seconds leeway + + let mut claims = Claims { + exp: Some(exp), + access: Access::Global(GlobalAccessMode::Read), + value_exists: None, + }; + + let token = create_token(&claims); + + let secret = "secret"; + let parser = JwtParser::new(secret); + assert!(matches!( + parser.decode(&token), + Some(Err(AuthError::Forbidden(_))) + )); + + // Remove the exp claim and it should work + claims.exp = None; + let token = create_token(&claims); + + let decoded_claims = parser.decode(&token).unwrap().unwrap(); + + assert_eq!(claims, decoded_claims); + } + + #[test] + fn test_invalid_token() { + let claims = Claims { + exp: None, + access: Access::Global(GlobalAccessMode::Read), + value_exists: None, + }; + let token = create_token(&claims); + + assert!(matches!( + JwtParser::new("wrong-secret").decode(&token), + Some(Err(AuthError::Forbidden(_))) + )); + + assert!(JwtParser::new("secret").decode("foo.bar.baz").is_none()); + } +} diff --git a/src/common/auth/mod.rs b/src/common/auth/mod.rs new file mode 100644 index 0000000000000000000000000000000000000000..08f9c8f6b1284a793c7cecb6216b4eac74ad024f --- /dev/null +++ b/src/common/auth/mod.rs @@ -0,0 +1,165 @@ +use std::sync::Arc; + +use collection::operations::shard_selector_internal::ShardSelectorInternal; +use collection::operations::types::ScrollRequestInternal; +use segment::types::{WithPayloadInterface, WithVector}; +use storage::content_manager::errors::StorageError; +use storage::content_manager::toc::TableOfContent; +use 
storage::rbac::Access; + +use self::claims::{Claims, ValueExists}; +use self::jwt_parser::JwtParser; +use super::strings::ct_eq; +use crate::settings::ServiceConfig; + +pub mod claims; +pub mod jwt_parser; + +pub const HTTP_HEADER_API_KEY: &str = "api-key"; + +/// The API keys used for auth +#[derive(Clone)] +pub struct AuthKeys { + /// A key allowing Read or Write operations + read_write: Option, + + /// A key allowing Read operations + read_only: Option, + + /// A JWT parser, based on the read_write key + jwt_parser: Option, + + /// Table of content, needed to do stateful validation of JWT + toc: Arc, +} + +#[derive(Debug)] +pub enum AuthError { + Unauthorized(String), + Forbidden(String), + StorageError(StorageError), +} + +impl AuthKeys { + fn get_jwt_parser(service_config: &ServiceConfig) -> Option { + if service_config.jwt_rbac.unwrap_or_default() { + service_config + .api_key + .as_ref() + .map(|secret| JwtParser::new(secret)) + } else { + None + } + } + + /// Defines the auth scheme given the service config + /// + /// Returns None if no scheme is specified. + pub fn try_create(service_config: &ServiceConfig, toc: Arc) -> Option { + match ( + service_config.api_key.clone(), + service_config.read_only_api_key.clone(), + ) { + (None, None) => None, + (read_write, read_only) => Some(Self { + read_write, + read_only, + jwt_parser: Self::get_jwt_parser(service_config), + toc, + }), + } + } + + /// Validate that the specified request is allowed for given keys. + pub async fn validate_request<'a>( + &self, + get_header: impl Fn(&'a str) -> Option<&'a str>, + ) -> Result { + let Some(key) = get_header(HTTP_HEADER_API_KEY) + .or_else(|| get_header("authorization").and_then(|v| v.strip_prefix("Bearer "))) + else { + return Err(AuthError::Unauthorized( + "Must provide an API key or an Authorization bearer token".to_string(), + )); + }; + + if self.can_write(key) { + return Ok(Access::full("Read-write access by key")); + } + + if self.can_read(key) { + return Ok(Access::full_ro("Read-only access by key")); + } + + if let Some(claims) = self.jwt_parser.as_ref().and_then(|p| p.decode(key)) { + let Claims { + exp: _, // already validated on decoding + access, + value_exists, + } = claims?; + + if let Some(value_exists) = value_exists { + self.validate_value_exists(&value_exists).await?; + } + + return Ok(access); + } + + Err(AuthError::Unauthorized( + "Invalid API key or JWT".to_string(), + )) + } + + async fn validate_value_exists(&self, value_exists: &ValueExists) -> Result<(), AuthError> { + let scroll_req = ScrollRequestInternal { + offset: None, + limit: Some(1), + filter: Some(value_exists.to_filter()), + with_payload: Some(WithPayloadInterface::Bool(false)), + with_vector: WithVector::Bool(false), + order_by: None, + }; + + let res = self + .toc + .scroll( + value_exists.get_collection(), + scroll_req, + None, + None, // no timeout + ShardSelectorInternal::All, + Access::full("JWT stateful validation"), + ) + .await + .map_err(|e| match e { + StorageError::NotFound { .. 
} => { + AuthError::Forbidden("Invalid JWT, stateful validation failed".to_string()) + } + _ => AuthError::StorageError(e), + })?; + + if res.points.is_empty() { + return Err(AuthError::Unauthorized( + "Invalid JWT, stateful validation failed".to_string(), + )); + }; + + Ok(()) + } + + /// Check if a key is allowed to read + #[inline] + fn can_read(&self, key: &str) -> bool { + self.read_only + .as_ref() + .is_some_and(|ro_key| ct_eq(ro_key, key)) + } + + /// Check if a key is allowed to write + #[inline] + fn can_write(&self, key: &str) -> bool { + self.read_write + .as_ref() + .is_some_and(|rw_key| ct_eq(rw_key, key)) + } +} diff --git a/src/common/collections.rs b/src/common/collections.rs new file mode 100644 index 0000000000000000000000000000000000000000..2d535db8eeb418b41ae749ece6b19941f9e1c6f6 --- /dev/null +++ b/src/common/collections.rs @@ -0,0 +1,834 @@ +use std::collections::HashMap; +use std::sync::Arc; +use std::time::Duration; + +use api::grpc::qdrant::CollectionExists; +use api::rest::models::{CollectionDescription, CollectionsResponse}; +use collection::config::ShardingMethod; +use collection::operations::cluster_ops::{ + AbortTransferOperation, ClusterOperations, DropReplicaOperation, MoveShardOperation, + ReplicateShardOperation, ReshardingDirection, RestartTransfer, RestartTransferOperation, + StartResharding, +}; +use collection::operations::shard_selector_internal::ShardSelectorInternal; +use collection::operations::snapshot_ops::SnapshotDescription; +use collection::operations::types::{ + AliasDescription, CollectionClusterInfo, CollectionInfo, CollectionsAliasesResponse, +}; +use collection::operations::verification::new_unchecked_verification_pass; +use collection::shards::replica_set; +use collection::shards::resharding::ReshardKey; +use collection::shards::shard::{PeerId, ShardId, ShardsPlacement}; +use collection::shards::transfer::{ShardTransfer, ShardTransferKey, ShardTransferRestart}; +use itertools::Itertools; +use rand::prelude::SliceRandom; +use rand::seq::IteratorRandom; +use storage::content_manager::collection_meta_ops::ShardTransferOperations::{Abort, Start}; +use storage::content_manager::collection_meta_ops::{ + CollectionMetaOperations, CreateShardKey, DropShardKey, ReshardingOperation, + SetShardReplicaState, ShardTransferOperations, UpdateCollectionOperation, +}; +use storage::content_manager::errors::StorageError; +use storage::content_manager::toc::TableOfContent; +use storage::dispatcher::Dispatcher; +use storage::rbac::{Access, AccessRequirements}; + +pub async fn do_collection_exists( + toc: &TableOfContent, + access: Access, + name: &str, +) -> Result { + let collection_pass = access.check_collection_access(name, AccessRequirements::new())?; + + // if this returns Ok, it means the collection exists. + // if not, we check that the error is NotFound + let Err(error) = toc.get_collection(&collection_pass).await else { + return Ok(CollectionExists { exists: true }); + }; + match error { + StorageError::NotFound { .. 
} => Ok(CollectionExists { exists: false }), + e => Err(e), + } +} + +pub async fn do_get_collection( + toc: &TableOfContent, + access: Access, + name: &str, + shard_selection: Option, +) -> Result { + let collection_pass = + access.check_collection_access(name, AccessRequirements::new().whole())?; + + let collection = toc.get_collection(&collection_pass).await?; + + let shard_selection = match shard_selection { + None => ShardSelectorInternal::All, + Some(shard_id) => ShardSelectorInternal::ShardId(shard_id), + }; + + Ok(collection.info(&shard_selection).await?) +} + +pub async fn do_list_collections( + toc: &TableOfContent, + access: Access, +) -> Result { + let collections = toc + .all_collections(&access) + .await + .into_iter() + .map(|pass| CollectionDescription { + name: pass.name().to_string(), + }) + .collect_vec(); + + Ok(CollectionsResponse { collections }) +} + +/// Construct shards-replicas layout for the shard from the given scope of peers +/// Example: +/// Shards: 3 +/// Replicas: 2 +/// Peers: [A, B, C] +/// +/// Placement: +/// [ +/// [A, B] +/// [B, C] +/// [A, C] +/// ] +fn generate_even_placement( + mut pool: Vec, + shard_number: usize, + replication_factor: usize, +) -> ShardsPlacement { + let mut exact_placement = Vec::new(); + let mut rng = rand::thread_rng(); + pool.shuffle(&mut rng); + let mut loop_iter = pool.iter().cycle(); + + // pool: [1,2,3,4] + // shuf_pool: [2,3,4,1] + // + // loop_iter: [2, 3, 4, 1, 2, 3, 4, 1, 2, 3, 4, 1, 2, 3, 4, 1,...] + // shard_placement: [2, 3, 4][1, 2, 3][4, 1, 2][3, 4, 1][2, 3, 4] + + let max_replication_factor = std::cmp::min(replication_factor, pool.len()); + for _shard in 0..shard_number { + let mut shard_placement = Vec::new(); + for _replica in 0..max_replication_factor { + shard_placement.push(*loop_iter.next().unwrap()); + } + exact_placement.push(shard_placement); + } + exact_placement +} + +pub async fn do_list_collection_aliases( + toc: &TableOfContent, + access: Access, + collection_name: &str, +) -> Result { + let collection_pass = + access.check_collection_access(collection_name, AccessRequirements::new())?; + let aliases: Vec = toc + .collection_aliases(&collection_pass, &access) + .await? + .into_iter() + .map(|alias| AliasDescription { + alias_name: alias, + collection_name: collection_name.to_string(), + }) + .collect(); + Ok(CollectionsAliasesResponse { aliases }) +} + +pub async fn do_list_aliases( + toc: &TableOfContent, + access: Access, +) -> Result { + let aliases = toc.list_aliases(&access).await?; + Ok(CollectionsAliasesResponse { aliases }) +} + +pub async fn do_list_snapshots( + toc: &TableOfContent, + access: Access, + collection_name: &str, +) -> Result, StorageError> { + let collection_pass = + access.check_collection_access(collection_name, AccessRequirements::new().whole())?; + Ok(toc + .get_collection(&collection_pass) + .await? + .list_snapshots() + .await?) +} + +pub async fn do_create_snapshot( + toc: Arc, + access: Access, + collection_name: &str, +) -> Result { + let collection_pass = access + .check_collection_access(collection_name, AccessRequirements::new().write().whole())? 
+ .into_static(); + + let result = tokio::spawn(async move { toc.create_snapshot(&collection_pass).await }).await??; + + Ok(result) +} + +pub async fn do_get_collection_cluster( + toc: &TableOfContent, + access: Access, + name: &str, +) -> Result { + let collection_pass = + access.check_collection_access(name, AccessRequirements::new().whole())?; + let collection = toc.get_collection(&collection_pass).await?; + Ok(collection.cluster_info(toc.this_peer_id).await?) +} + +pub async fn do_update_collection_cluster( + dispatcher: &Dispatcher, + collection_name: String, + operation: ClusterOperations, + access: Access, + wait_timeout: Option, +) -> Result { + let collection_pass = access.check_collection_access( + &collection_name, + AccessRequirements::new().write().manage().whole(), + )?; + + if dispatcher.consensus_state().is_none() { + return Err(StorageError::BadRequest { + description: "Distributed mode disabled".to_string(), + }); + } + let consensus_state = dispatcher.consensus_state().unwrap(); + + let get_all_peer_ids = || { + consensus_state + .persistent + .read() + .peer_address_by_id + .read() + .keys() + .cloned() + .collect_vec() + }; + + let validate_peer_exists = |peer_id| { + let target_peer_exist = consensus_state + .persistent + .read() + .peer_address_by_id + .read() + .contains_key(&peer_id); + if !target_peer_exist { + return Err(StorageError::BadRequest { + description: format!("Peer {peer_id} does not exist"), + }); + } + Ok(()) + }; + + // All checks should've been done at this point. + let pass = new_unchecked_verification_pass(); + + let collection = dispatcher + .toc(&access, &pass) + .get_collection(&collection_pass) + .await?; + + match operation { + ClusterOperations::MoveShard(MoveShardOperation { move_shard }) => { + // validate shard to move + if !collection.contains_shard(move_shard.shard_id).await { + return Err(StorageError::BadRequest { + description: format!( + "Shard {} of {} does not exist", + move_shard.shard_id, collection_name + ), + }); + }; + + // validate target and source peer exists + validate_peer_exists(move_shard.to_peer_id)?; + validate_peer_exists(move_shard.from_peer_id)?; + + // submit operation to consensus + dispatcher + .submit_collection_meta_op( + CollectionMetaOperations::TransferShard( + collection_name, + Start(ShardTransfer { + shard_id: move_shard.shard_id, + to_shard_id: move_shard.to_shard_id, + to: move_shard.to_peer_id, + from: move_shard.from_peer_id, + sync: false, + method: move_shard.method, + }), + ), + access, + wait_timeout, + ) + .await + } + ClusterOperations::ReplicateShard(ReplicateShardOperation { replicate_shard }) => { + // validate shard to move + if !collection.contains_shard(replicate_shard.shard_id).await { + return Err(StorageError::BadRequest { + description: format!( + "Shard {} of {} does not exist", + replicate_shard.shard_id, collection_name + ), + }); + }; + + // validate target peer exists + validate_peer_exists(replicate_shard.to_peer_id)?; + + // validate source peer exists + validate_peer_exists(replicate_shard.from_peer_id)?; + + // submit operation to consensus + dispatcher + .submit_collection_meta_op( + CollectionMetaOperations::TransferShard( + collection_name, + Start(ShardTransfer { + shard_id: replicate_shard.shard_id, + to_shard_id: replicate_shard.to_shard_id, + to: replicate_shard.to_peer_id, + from: replicate_shard.from_peer_id, + sync: true, + method: replicate_shard.method, + }), + ), + access, + wait_timeout, + ) + .await + } + 
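+        // Note: the MoveShard and ReplicateShard arms above submit the same kind of
+        // `ShardTransfer`; the only difference visible here is the `sync` flag
+        // (`false` for a move, `true` for a replication).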
ClusterOperations::AbortTransfer(AbortTransferOperation { abort_transfer }) => { + let transfer = ShardTransferKey { + shard_id: abort_transfer.shard_id, + to_shard_id: abort_transfer.to_shard_id, + to: abort_transfer.to_peer_id, + from: abort_transfer.from_peer_id, + }; + + if !collection.check_transfer_exists(&transfer).await { + return Err(StorageError::NotFound { + description: format!( + "Shard transfer {} -> {} for collection {}:{} does not exist", + transfer.from, transfer.to, collection_name, transfer.shard_id + ), + }); + } + + dispatcher + .submit_collection_meta_op( + CollectionMetaOperations::TransferShard( + collection_name, + Abort { + transfer, + reason: "user request".to_string(), + }, + ), + access, + wait_timeout, + ) + .await + } + ClusterOperations::DropReplica(DropReplicaOperation { drop_replica }) => { + if !collection.contains_shard(drop_replica.shard_id).await { + return Err(StorageError::BadRequest { + description: format!( + "Shard {} of {} does not exist", + drop_replica.shard_id, collection_name + ), + }); + }; + + validate_peer_exists(drop_replica.peer_id)?; + + let mut update_operation = UpdateCollectionOperation::new_empty(collection_name); + + update_operation.set_shard_replica_changes(vec![replica_set::Change::Remove( + drop_replica.shard_id, + drop_replica.peer_id, + )]); + + dispatcher + .submit_collection_meta_op( + CollectionMetaOperations::UpdateCollection(update_operation), + access, + wait_timeout, + ) + .await + } + ClusterOperations::CreateShardingKey(create_sharding_key_op) => { + let create_sharding_key = create_sharding_key_op.create_sharding_key; + + // Validate that: + // - proper sharding method is used + // - key does not exist yet + // + // If placement suggested: + // - Peers exist + + let state = collection.state().await; + + match state.config.params.sharding_method.unwrap_or_default() { + ShardingMethod::Auto => { + return Err(StorageError::bad_request( + "Shard Key cannot be created with Auto sharding method", + )); + } + ShardingMethod::Custom => {} + } + + let shard_number = create_sharding_key + .shards_number + .unwrap_or(state.config.params.shard_number) + .get() as usize; + let replication_factor = create_sharding_key + .replication_factor + .unwrap_or(state.config.params.replication_factor) + .get() as usize; + + let shard_keys_mapping = state.shards_key_mapping; + if shard_keys_mapping.contains_key(&create_sharding_key.shard_key) { + return Err(StorageError::BadRequest { + description: format!( + "Sharding key {} already exists for collection {}", + create_sharding_key.shard_key, collection_name + ), + }); + } + + let peers_pool: Vec<_> = if let Some(placement) = create_sharding_key.placement { + if placement.is_empty() { + return Err(StorageError::BadRequest { + description: format!( + "Sharding key {} placement cannot be empty. 
If you want to use random placement, do not specify placement", + create_sharding_key.shard_key + ), + }); + } + + for peer_id in placement.iter().copied() { + validate_peer_exists(peer_id)?; + } + placement + } else { + get_all_peer_ids() + }; + + let exact_placement = + generate_even_placement(peers_pool, shard_number, replication_factor); + + dispatcher + .submit_collection_meta_op( + CollectionMetaOperations::CreateShardKey(CreateShardKey { + collection_name, + shard_key: create_sharding_key.shard_key, + placement: exact_placement, + }), + access, + wait_timeout, + ) + .await + } + ClusterOperations::DropShardingKey(drop_sharding_key_op) => { + let drop_sharding_key = drop_sharding_key_op.drop_sharding_key; + // Validate that: + // - proper sharding method is used + // - key does exist + + let state = collection.state().await; + + match state.config.params.sharding_method.unwrap_or_default() { + ShardingMethod::Auto => { + return Err(StorageError::bad_request( + "Shard Key cannot be created with Auto sharding method", + )); + } + ShardingMethod::Custom => {} + } + + let shard_keys_mapping = state.shards_key_mapping; + if !shard_keys_mapping.contains_key(&drop_sharding_key.shard_key) { + return Err(StorageError::BadRequest { + description: format!( + "Sharding key {} does not exists for collection {}", + drop_sharding_key.shard_key, collection_name + ), + }); + } + + dispatcher + .submit_collection_meta_op( + CollectionMetaOperations::DropShardKey(DropShardKey { + collection_name, + shard_key: drop_sharding_key.shard_key, + }), + access, + wait_timeout, + ) + .await + } + ClusterOperations::RestartTransfer(RestartTransferOperation { restart_transfer }) => { + // TODO(reshading): Deduplicate resharding operations handling? + + let RestartTransfer { + shard_id, + to_shard_id, + from_peer_id, + to_peer_id, + method, + } = restart_transfer; + + let transfer_key = ShardTransferKey { + shard_id, + to_shard_id, + to: to_peer_id, + from: from_peer_id, + }; + + if !collection.check_transfer_exists(&transfer_key).await { + return Err(StorageError::NotFound { + description: format!( + "Shard transfer {} -> {} for collection {}:{} does not exist", + transfer_key.from, transfer_key.to, collection_name, transfer_key.shard_id + ), + }); + } + + dispatcher + .submit_collection_meta_op( + CollectionMetaOperations::TransferShard( + collection_name, + ShardTransferOperations::Restart(ShardTransferRestart { + shard_id, + to_shard_id, + to: to_peer_id, + from: from_peer_id, + method, + }), + ), + access, + wait_timeout, + ) + .await + } + ClusterOperations::StartResharding(op) => { + let StartResharding { + direction, + peer_id, + shard_key, + } = op.start_resharding; + + let collection_state = collection.state().await; + + if let Some(shard_key) = &shard_key { + if !collection_state.shards_key_mapping.contains_key(shard_key) { + return Err(StorageError::bad_request(format!( + "sharding key {shard_key} does not exists for collection {collection_name}" + ))); + } + } + + let shard_id = match (direction, shard_key.as_ref()) { + // When scaling up, just pick the next shard ID + (ReshardingDirection::Up, _) => { + collection_state + .shards + .keys() + .copied() + .max() + .expect("collection must contain shards") + + 1 + } + // When scaling down without shard keys, pick the last shard ID + (ReshardingDirection::Down, None) => collection_state + .shards + .keys() + .copied() + .max() + .expect("collection must contain shards"), + // When scaling down with shard keys, pick the last shard ID of that key + 
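+                // Illustrative example: with shards {0, 1, 2, 3} where shard key
+                // "eu" maps to shards {2, 3}:
+                //   - Up              -> 4 (max existing ID + 1, a brand new shard)
+                //   - Down, no key    -> 3 (last shard ID overall)
+                //   - Down, key "eu"  -> 3 (last shard ID belonging to that key)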
(ReshardingDirection::Down, Some(shard_key)) => collection_state + .shards_key_mapping + .get(shard_key) + .expect("specified shard key must exist") + .iter() + .copied() + .max() + .expect("collection must contain shards"), + }; + + let peer_id = match (peer_id, direction) { + // Select user specified peer, but make sure it exists + (Some(peer_id), _) => { + validate_peer_exists(peer_id)?; + peer_id + } + + // When scaling up, select peer with least number of shards for this collection + (None, ReshardingDirection::Up) => { + let mut shards_on_peers = collection_state + .shards + .values() + .flat_map(|shard_info| shard_info.replicas.keys()) + .fold(HashMap::new(), |mut counts, peer_id| { + *counts.entry(*peer_id).or_insert(0) += 1; + counts + }); + for peer_id in get_all_peer_ids() { + // Add registered peers not holding any shard yet + shards_on_peers.entry(peer_id).or_insert(0); + } + shards_on_peers + .into_iter() + .min_by_key(|(_, count)| *count) + .map(|(peer_id, _)| peer_id) + .expect("expected at least one peer") + } + + // When scaling down, select random peer that contains the shard we're dropping + // Other peers work, but are less efficient due to remote operations + (None, ReshardingDirection::Down) => collection_state + .shards + .get(&shard_id) + .expect("select shard ID must always exist in collection state") + .replicas + .keys() + .choose(&mut rand::thread_rng()) + .copied() + .unwrap(), + }; + + if let Some(resharding) = &collection_state.resharding { + return Err(StorageError::bad_request(format!( + "resharding {resharding:?} is already in progress \ + for collection {collection_name}" + ))); + } + + dispatcher + .submit_collection_meta_op( + CollectionMetaOperations::Resharding( + collection_name.clone(), + ReshardingOperation::Start(ReshardKey { + direction, + peer_id, + shard_id, + shard_key, + }), + ), + access, + wait_timeout, + ) + .await + } + ClusterOperations::AbortResharding(_) => { + // TODO(reshading): Deduplicate resharding operations handling? + + let Some(state) = collection.resharding_state().await else { + return Err(StorageError::bad_request(format!( + "resharding is not in progress for collection {collection_name}" + ))); + }; + + dispatcher + .submit_collection_meta_op( + CollectionMetaOperations::Resharding( + collection_name.clone(), + ReshardingOperation::Abort(ReshardKey { + direction: state.direction, + peer_id: state.peer_id, + shard_id: state.shard_id, + shard_key: state.shard_key.clone(), + }), + ), + access, + wait_timeout, + ) + .await + } + ClusterOperations::FinishResharding(_) => { + // TODO(resharding): Deduplicate resharding operations handling? + + let Some(state) = collection.resharding_state().await else { + return Err(StorageError::bad_request(format!( + "resharding is not in progress for collection {collection_name}" + ))); + }; + + dispatcher + .submit_collection_meta_op( + CollectionMetaOperations::Resharding( + collection_name.clone(), + ReshardingOperation::Finish(state.key()), + ), + access, + wait_timeout, + ) + .await + } + + ClusterOperations::FinishMigratingPoints(op) => { + // TODO(resharding): Deduplicate resharding operations handling? 
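+            // This operation marks a replica as having finished receiving migrated
+            // points: it flips the replica state from `Resharding` to `Active`.
+            // Shard and peer IDs default to the ones recorded in the resharding
+            // state when scaling up, and must be given explicitly when scaling down.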
+ + let Some(state) = collection.resharding_state().await else { + return Err(StorageError::bad_request(format!( + "resharding is not in progress for collection {collection_name}" + ))); + }; + + let op = op.finish_migrating_points; + + let shard_id = match (op.shard_id, state.direction) { + (Some(shard_id), _) => shard_id, + (None, ReshardingDirection::Up) => state.shard_id, + (None, ReshardingDirection::Down) => { + return Err(StorageError::bad_request( + "shard ID must be specified when resharding down", + )); + } + }; + + let peer_id = match (op.peer_id, state.direction) { + (Some(peer_id), _) => peer_id, + (None, ReshardingDirection::Up) => state.peer_id, + (None, ReshardingDirection::Down) => { + return Err(StorageError::bad_request( + "peer ID must be specified when resharding down", + )); + } + }; + + dispatcher + .submit_collection_meta_op( + CollectionMetaOperations::SetShardReplicaState(SetShardReplicaState { + collection_name: collection_name.clone(), + shard_id, + peer_id, + state: replica_set::ReplicaState::Active, + from_state: Some(replica_set::ReplicaState::Resharding), + }), + access, + wait_timeout, + ) + .await + } + + ClusterOperations::CommitReadHashRing(_) => { + // TODO(reshading): Deduplicate resharding operations handling? + + let Some(state) = collection.resharding_state().await else { + return Err(StorageError::bad_request(format!( + "resharding is not in progress for collection {collection_name}" + ))); + }; + + // TODO(resharding): Add precondition checks? + + dispatcher + .submit_collection_meta_op( + CollectionMetaOperations::Resharding( + collection_name.clone(), + ReshardingOperation::CommitRead(ReshardKey { + direction: state.direction, + peer_id: state.peer_id, + shard_id: state.shard_id, + shard_key: state.shard_key.clone(), + }), + ), + access, + wait_timeout, + ) + .await + } + + ClusterOperations::CommitWriteHashRing(_) => { + // TODO(reshading): Deduplicate resharding operations handling? + + let Some(state) = collection.resharding_state().await else { + return Err(StorageError::bad_request(format!( + "resharding is not in progress for collection {collection_name}" + ))); + }; + + // TODO(resharding): Add precondition checks? 
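+            // Same shape as the read hash ring commit above: the `ReshardKey` is
+            // rebuilt from the current resharding state and submitted to consensus,
+            // this time as a `CommitWrite` operation.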
+ + dispatcher + .submit_collection_meta_op( + CollectionMetaOperations::Resharding( + collection_name.clone(), + ReshardingOperation::CommitWrite(ReshardKey { + direction: state.direction, + peer_id: state.peer_id, + shard_id: state.shard_id, + shard_key: state.shard_key.clone(), + }), + ), + access, + wait_timeout, + ) + .await + } + } +} + +#[cfg(test)] +mod tests { + use std::collections::HashSet; + + use super::*; + + #[test] + fn test_generate_even_placement() { + let pool = vec![1, 2, 3]; + let placement = generate_even_placement(pool, 3, 2); + + assert_eq!(placement.len(), 3); + for shard_placement in placement { + assert_eq!(shard_placement.len(), 2); + assert_ne!(shard_placement[0], shard_placement[1]); + } + + let pool = vec![1, 2, 3]; + let placement = generate_even_placement(pool, 3, 3); + + assert_eq!(placement.len(), 3); + for shard_placement in placement { + assert_eq!(shard_placement.len(), 3); + let set: HashSet<_> = shard_placement.into_iter().collect(); + assert_eq!(set.len(), 3); + } + + let pool = vec![1, 2, 3, 4, 5, 6]; + let placement = generate_even_placement(pool, 3, 2); + + assert_eq!(placement.len(), 3); + let flat_placement: Vec<_> = placement.into_iter().flatten().collect(); + let set: HashSet<_> = flat_placement.into_iter().collect(); + assert_eq!(set.len(), 6); + + let pool = vec![1, 2, 3, 4, 5]; + let placement = generate_even_placement(pool, 3, 10); + + assert_eq!(placement.len(), 3); + for shard_placement in placement { + assert_eq!(shard_placement.len(), 5); + } + } +} diff --git a/src/common/debugger.rs b/src/common/debugger.rs new file mode 100644 index 0000000000000000000000000000000000000000..a6f5d4619541c752d2c2d6c34da4219bada29086 --- /dev/null +++ b/src/common/debugger.rs @@ -0,0 +1,90 @@ +use std::sync::Arc; + +use parking_lot::Mutex; +use schemars::JsonSchema; +use serde::{Deserialize, Serialize}; + +use crate::common::pyroscope_state::pyro::PyroscopeState; +use crate::settings::Settings; + +#[derive(Serialize, JsonSchema, Debug, Deserialize, Clone)] +pub struct PyroscopeConfig { + pub url: String, + pub identifier: String, + pub user: Option, + pub password: Option, + pub sampling_rate: Option, +} + +#[derive(Default, Debug, Serialize, JsonSchema, Deserialize, Clone)] +pub struct DebuggerConfig { + pub pyroscope: Option, +} + +#[derive(Debug, Serialize, JsonSchema, Deserialize, Clone)] +#[serde(rename_all = "snake_case")] +pub enum DebugConfigPatch { + Pyroscope(Option), +} + +pub struct DebuggerState { + #[cfg_attr(not(target_os = "linux"), allow(dead_code))] + pub pyroscope: Arc>>, +} + +impl DebuggerState { + pub fn from_settings(settings: &Settings) -> Self { + let pyroscope_config = settings.debugger.pyroscope.clone(); + Self { + pyroscope: Arc::new(Mutex::new(PyroscopeState::from_config(pyroscope_config))), + } + } + + #[cfg_attr(not(target_os = "linux"), allow(clippy::unused_self))] + pub fn get_config(&self) -> DebuggerConfig { + let pyroscope_config = { + #[cfg(target_os = "linux")] + { + let pyroscope_state_guard = self.pyroscope.lock(); + pyroscope_state_guard.as_ref().map(|s| s.config.clone()) + } + #[cfg(not(target_os = "linux"))] + { + None + } + }; + + DebuggerConfig { + pyroscope: pyroscope_config, + } + } + + #[cfg_attr(not(target_os = "linux"), allow(clippy::unused_self))] + pub fn apply_config_patch(&self, patch: DebugConfigPatch) -> bool { + #[cfg(target_os = "linux")] + { + match patch { + DebugConfigPatch::Pyroscope(new_config) => { + let mut pyroscope_guard = self.pyroscope.lock(); + if let Some(pyroscope_state) = 
pyroscope_guard.as_mut() { + let stopped = pyroscope_state.stop_agent(); + if !stopped { + return false; + } + } + + if let Some(new_config) = new_config { + *pyroscope_guard = PyroscopeState::from_config(Some(new_config)); + } + true + } + } + } + + #[cfg(not(target_os = "linux"))] + { + let _ = patch; // Ignore new_config on non-linux OS + false + } + } +} diff --git a/src/common/error_reporting.rs b/src/common/error_reporting.rs new file mode 100644 index 0000000000000000000000000000000000000000..c49c6e18bf090a74d57c7ed19928c61f5cde9a35 --- /dev/null +++ b/src/common/error_reporting.rs @@ -0,0 +1,31 @@ +use std::time::Duration; + +pub struct ErrorReporter; + +impl ErrorReporter { + fn get_url() -> String { + if cfg!(debug_assertions) { + "https://staging-telemetry.qdrant.io".to_string() + } else { + "https://telemetry.qdrant.io".to_string() + } + } + + pub fn report(error: &str, reporting_id: &str, backtrace: Option<&str>) { + let client = reqwest::blocking::Client::new(); + + let report = serde_json::json!({ + "id": reporting_id, + "error": error, + "backtrace": backtrace.unwrap_or(""), + }); + + let data = serde_json::to_string(&report).unwrap(); + let _resp = client + .post(Self::get_url()) + .body(data) + .header("Content-Type", "application/json") + .timeout(Duration::from_secs(1)) + .send(); + } +} diff --git a/src/common/health.rs b/src/common/health.rs new file mode 100644 index 0000000000000000000000000000000000000000..c4478ce533d2457fa931a81e9b99e56bf32101f6 --- /dev/null +++ b/src/common/health.rs @@ -0,0 +1,372 @@ +use std::collections::HashSet; +use std::future::{self, Future}; +use std::sync::atomic::{self, AtomicBool}; +use std::sync::Arc; +use std::time::Duration; +use std::{panic, thread}; + +use api::grpc::qdrant::qdrant_internal_client::QdrantInternalClient; +use api::grpc::qdrant::{GetConsensusCommitRequest, GetConsensusCommitResponse}; +use api::grpc::transport_channel_pool::{self, TransportChannelPool}; +use collection::shards::shard::ShardId; +use collection::shards::CollectionId; +use common::defaults; +use futures::stream::FuturesUnordered; +use futures::{FutureExt as _, StreamExt as _, TryStreamExt as _}; +use itertools::Itertools; +use storage::content_manager::consensus_manager::ConsensusStateRef; +use storage::content_manager::toc::TableOfContent; +use storage::rbac::Access; +use tokio::{runtime, sync, time}; + +const READY_CHECK_TIMEOUT: Duration = Duration::from_millis(500); +const GET_CONSENSUS_COMMITS_RETRIES: usize = 2; + +/// Structure used to process health checks like `/readyz` endpoints. +pub struct HealthChecker { + // The state of the health checker. + // Once set to `true`, it should not change back to `false`. + // Initially set to `false`. + is_ready: Arc, + // The signal that notifies that state has changed. + // Comes from the health checker task. + is_ready_signal: Arc, + // Signal to the health checker task, that the API was called. + // Used to drive the health checker task and avoid constant polling. 
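+    //
+    // Rough flow of a `/readyz` call (see `check_ready` below):
+    //   is_ready()      -- fast path once the readiness flag has been set
+    //   notify_task()   -- wake the checker task via `check_ready_signal`
+    //   wait_ready()    -- wait up to READY_CHECK_TIMEOUT for `is_ready_signal`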
+ check_ready_signal: Arc, + cancel: cancel::DropGuard, +} + +impl HealthChecker { + pub fn spawn( + toc: Arc, + consensus_state: ConsensusStateRef, + runtime: &runtime::Handle, + wait_for_bootstrap: bool, + ) -> Self { + let task = Task { + toc, + consensus_state, + is_ready: Default::default(), + is_ready_signal: Default::default(), + check_ready_signal: Default::default(), + cancel: Default::default(), + wait_for_bootstrap, + }; + + let health_checker = Self { + is_ready: task.is_ready.clone(), + is_ready_signal: task.is_ready_signal.clone(), + check_ready_signal: task.check_ready_signal.clone(), + cancel: task.cancel.clone().drop_guard(), + }; + + let task = runtime.spawn(task.exec()); + drop(task); // drop `JoinFuture` explicitly to make clippy happy + + health_checker + } + + pub async fn check_ready(&self) -> bool { + if self.is_ready() { + return true; + } + + self.notify_task(); + self.wait_ready().await + } + + pub fn is_ready(&self) -> bool { + self.is_ready.load(atomic::Ordering::Relaxed) + } + + pub fn notify_task(&self) { + self.check_ready_signal.notify_one(); + } + + async fn wait_ready(&self) -> bool { + let is_ready_signal = self.is_ready_signal.notified(); + + if self.is_ready() { + return true; + } + + time::timeout(READY_CHECK_TIMEOUT, is_ready_signal) + .await + .is_ok() + } +} + +pub struct Task { + toc: Arc, + consensus_state: ConsensusStateRef, + // Shared state with the health checker + // Once set to `true`, it should not change back to `false`. + is_ready: Arc, + // Used to notify the health checker service that the state has changed. + is_ready_signal: Arc, + // Driver signal for the health checker task + // Once received, the task should proceed with an attempt to check the state. + // Usually comes from the API call, but can be triggered by the task itself. + check_ready_signal: Arc, + cancel: cancel::CancellationToken, + wait_for_bootstrap: bool, +} + +impl Task { + pub async fn exec(mut self) { + while let Err(err) = self.exec_catch_unwind().await { + let message = common::panic::downcast_str(&err).unwrap_or(""); + let separator = if !message.is_empty() { ": " } else { "" }; + + log::error!("HealthChecker task panicked, retrying{separator}{message}",); + } + } + + async fn exec_catch_unwind(&mut self) -> thread::Result<()> { + panic::AssertUnwindSafe(self.exec_cancel()) + .catch_unwind() + .await + } + + async fn exec_cancel(&mut self) { + let _ = cancel::future::cancel_on_token(self.cancel.clone(), self.exec_impl()).await; + } + + async fn exec_impl(&mut self) { + // Wait until node joins cluster for the first time + // + // If this is a new deployment and `--bootstrap` CLI parameter was specified... + if self.wait_for_bootstrap { + // Check if this is the only node in the cluster + while self.consensus_state.peer_count() <= 1 { + // If cluster is empty, make another attempt to check + // after we receive another call to `/readyz` + // + // Wait for `/readyz` signal + self.check_ready_signal.notified().await; + } + } + + // Artificial simulate signal from `/readyz` endpoint + // as if it was already called by the user. + // This allows to check the happy path without waiting for the first call. + self.check_ready_signal.notify_one(); + + // Get *cluster* commit index, or check if this is the only node in the cluster + let Some(cluster_commit_index) = self.cluster_commit_index().await else { + self.set_ready(); + return; + }; + + // Check if *local* commit index >= *cluster* commit index... 
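+        // For example (illustrative numbers): a node that restarts with its last
+        // applied entry at commit 100 while the rest of the cluster reports commit
+        // 250 stays not-ready in this loop until it has caught up to 250.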
+ while self.commit_index() < cluster_commit_index { + // Wait for `/readyz` signal + self.check_ready_signal.notified().await; + + // If not: + // + // - Check if this is the only node in the cluster + if self.consensus_state.peer_count() <= 1 { + self.set_ready(); + return; + } + + // TODO: Do we want to update `cluster_commit_index` here? + // + // I.e.: + // - If we *don't* update `cluster_commit_index`, then we will only wait till the node + // catch up with the cluster commit index *at the moment the node has been started* + // - If we *do* update `cluster_commit_index`, then we will keep track of cluster + // commit index updates and wait till the node *completely* catch up with the leader, + // which might be hard (if not impossible) in some situations + } + + // Collect "unhealthy" shards list + let mut unhealthy_shards = self.unhealthy_shards().await; + + // Check if all shards are "healthy"... + while !unhealthy_shards.is_empty() { + // If not: + // + // - Wait for `/readyz` signal + self.check_ready_signal.notified().await; + + // - Refresh "unhealthy" shards list + let current_unhealthy_shards = self.unhealthy_shards().await; + + // - Check if any shards "healed" since last check + unhealthy_shards.retain(|shard| current_unhealthy_shards.contains(shard)); + } + + self.set_ready(); + } + + async fn cluster_commit_index(&self) -> Option { + // Wait for `/readyz` signal + self.check_ready_signal.notified().await; + + // Check if there is only 1 node in the cluster + if self.consensus_state.peer_count() <= 1 { + return None; + } + + // Get *cluster* commit index + let peer_address_by_id = self.consensus_state.peer_address_by_id(); + let transport_channel_pool = &self.toc.get_channel_service().channel_pool; + let this_peer_id = self.toc.this_peer_id; + let this_peer_uri = peer_address_by_id.get(&this_peer_id); + + let mut requests = peer_address_by_id + .values() + // Do not get the current commit from ourselves + .filter(|&uri| Some(uri) != this_peer_uri) + // Historic peers might use the same URLs as our current peers, request each URI once + .unique() + .map(|uri| get_consensus_commit(transport_channel_pool, uri)) + .collect::>() + .inspect_err(|err| log::error!("GetConsensusCommit request failed: {err}")) + .filter_map(|res| future::ready(res.ok())); + + // Raft commits consensus operation, after majority of nodes persisted it. + // + // This means, if we check the majority of nodes (e.g., `total nodes / 2 + 1`), at least one + // of these nodes will *always* have an up-to-date commit index. And so, the highest commit + // index among majority of nodes *is* the cluster commit index. + // + // Our current node *is* one of the cluster nodes, so it's enough to query `total nodes / 2` + // *additional* nodes, to get cluster commit index. + // + // The check goes like this: + // - Either at least one of the "additional" nodes return a *higher* commit index, which + // means our node is *not* up-to-date, and we have to wait to reach this commit index + // - Or *all* of them return *lower* commit index, which means current node is *already* + // up-to-date, and `/readyz` check will pass to the next step + // + // Example: + // + // Total nodes: 2 + // Required: 2 / 2 = 1 + // + // Total nodes: 3 + // Required: 3 / 2 = 1 + // + // Total nodes: 4 + // Required: 4 / 2 = 2 + // + // Total nodes: 5 + // Required: 5 / 2 = 2 + let sufficient_commit_indices_count = peer_address_by_id.len() / 2; + + // *Wait* for `total nodex / 2` successful responses... 
+ let mut commit_indices: Vec<_> = (&mut requests) + .take(sufficient_commit_indices_count) + .collect() + .await; + + // ...and also collect any additional responses, that we might have *already* received + while let Ok(Some(resp)) = time::timeout(Duration::ZERO, requests.next()).await { + commit_indices.push(resp); + } + + // Find the maximum commit index among all responses. + // + // Note, that we progress even if most (or even *all*) requests failed (e.g., because all + // other nodes are unavailable or they don't support `GetConsensusCommit` gRPC API). + // + // So this check is not 100% reliable and can give a false-positive result! + let cluster_commit_index = commit_indices + .into_iter() + .map(|resp| resp.into_inner().commit) + .max() + .unwrap_or(0); + + Some(cluster_commit_index as _) + } + + fn commit_index(&self) -> u64 { + // TODO: Blocking call in async context!? + self.consensus_state + .persistent + .read() + .last_applied_entry() + .unwrap_or(0) + } + + /// List shards that are unhealthy, which may undergo automatic recovery. + /// + /// Shards in resharding state are not considered unhealthy and are excluded here. + /// They require an external driver to make them active or to drop them. + async fn unhealthy_shards(&self) -> HashSet { + let this_peer_id = self.toc.this_peer_id; + let collections = self + .toc + .all_collections(&Access::full("For health check")) + .await; + + let mut unhealthy_shards = HashSet::new(); + + for collection_pass in &collections { + let state = match self.toc.get_collection(collection_pass).await { + Ok(collection) => collection.state().await, + Err(_) => continue, + }; + + for (&shard, info) in state.shards.iter() { + let Some(state) = info.replicas.get(&this_peer_id) else { + continue; + }; + + if state.is_active_or_listener_or_resharding() { + continue; + } + + unhealthy_shards.insert(Shard::new(collection_pass.name(), shard)); + } + } + + unhealthy_shards + } + + fn set_ready(&self) { + self.is_ready.store(true, atomic::Ordering::Relaxed); + self.is_ready_signal.notify_waiters(); + } +} + +fn get_consensus_commit<'a>( + transport_channel_pool: &'a TransportChannelPool, + uri: &'a tonic::transport::Uri, +) -> impl Future + 'a { + transport_channel_pool.with_channel_timeout( + uri, + |channel| async { + let mut client = QdrantInternalClient::new(channel); + let mut request = tonic::Request::new(GetConsensusCommitRequest {}); + request.set_timeout(defaults::CONSENSUS_META_OP_WAIT); + client.get_consensus_commit(request).await + }, + Some(defaults::CONSENSUS_META_OP_WAIT), + GET_CONSENSUS_COMMITS_RETRIES, + ) +} + +type GetConsensusCommitResult = Result< + tonic::Response, + transport_channel_pool::RequestError, +>; + +#[derive(Clone, Debug, Eq, PartialEq, Hash)] +struct Shard { + collection: CollectionId, + shard: ShardId, +} + +impl Shard { + pub fn new(collection: impl Into, shard: ShardId) -> Self { + Self { + collection: collection.into(), + shard, + } + } +} diff --git a/src/common/helpers.rs b/src/common/helpers.rs new file mode 100644 index 0000000000000000000000000000000000000000..bc2f105439c5654eb479b220e82b6e505d942e90 --- /dev/null +++ b/src/common/helpers.rs @@ -0,0 +1,151 @@ +use std::cmp::max; +use std::sync::atomic::{AtomicUsize, Ordering}; +use std::{fs, io}; + +use schemars::JsonSchema; +use serde::{Deserialize, Serialize}; +use tokio::runtime; +use tokio::runtime::Runtime; +use tonic::transport::{Certificate, ClientTlsConfig, Identity, ServerTlsConfig}; +use validator::Validate; + +use crate::settings::{Settings, 
TlsConfig}; + +#[derive(Debug, Deserialize, Serialize, JsonSchema, Validate)] +pub struct LocksOption { + pub error_message: Option, + pub write: bool, +} + +pub fn create_search_runtime(max_search_threads: usize) -> io::Result { + let mut search_threads = max_search_threads; + + if search_threads == 0 { + let num_cpu = common::cpu::get_num_cpus(); + // At least one thread, but not more than number of CPUs - 1 if there are more than 2 CPU + // Example: + // Num CPU = 1 -> 1 thread + // Num CPU = 2 -> 2 thread - if we use one thread with 2 cpus, its too much un-utilized resources + // Num CPU = 3 -> 2 thread + // Num CPU = 4 -> 3 thread + // Num CPU = 5 -> 4 thread + search_threads = match num_cpu { + 0 => 1, + 1 => 1, + 2 => 2, + _ => num_cpu - 1, + }; + } + + runtime::Builder::new_multi_thread() + .worker_threads(search_threads) + .max_blocking_threads(search_threads) + .enable_all() + .thread_name_fn(|| { + static ATOMIC_ID: AtomicUsize = AtomicUsize::new(0); + let id = ATOMIC_ID.fetch_add(1, Ordering::SeqCst); + format!("search-{id}") + }) + .build() +} + +pub fn create_update_runtime(max_optimization_threads: usize) -> io::Result { + let mut update_runtime_builder = runtime::Builder::new_multi_thread(); + + update_runtime_builder + .enable_time() + .thread_name_fn(move || { + static ATOMIC_ID: AtomicUsize = AtomicUsize::new(0); + let update_id = ATOMIC_ID.fetch_add(1, Ordering::SeqCst); + format!("update-{update_id}") + }); + + if max_optimization_threads > 0 { + // panics if val is not larger than 0. + update_runtime_builder.max_blocking_threads(max_optimization_threads); + } + update_runtime_builder.build() +} + +pub fn create_general_purpose_runtime() -> io::Result { + runtime::Builder::new_multi_thread() + .enable_time() + .enable_io() + .worker_threads(max(common::cpu::get_num_cpus(), 2)) + .thread_name_fn(|| { + static ATOMIC_ID: AtomicUsize = AtomicUsize::new(0); + let general_id = ATOMIC_ID.fetch_add(1, Ordering::SeqCst); + format!("general-{general_id}") + }) + .build() +} + +/// Load client TLS configuration. +pub fn load_tls_client_config(settings: &Settings) -> io::Result> { + if settings.cluster.p2p.enable_tls { + let tls_config = &settings.tls()?; + Ok(Some( + ClientTlsConfig::new() + .identity(load_identity(tls_config)?) + .ca_certificate(load_ca_certificate(tls_config)?), + )) + } else { + Ok(None) + } +} + +/// Load server TLS configuration for external gRPC +pub fn load_tls_external_server_config(tls_config: &TlsConfig) -> io::Result { + Ok(ServerTlsConfig::new().identity(load_identity(tls_config)?)) +} + +/// Load server TLS configuration for internal gRPC, check client certificate against CA +pub fn load_tls_internal_server_config(tls_config: &TlsConfig) -> io::Result { + Ok(ServerTlsConfig::new() + .identity(load_identity(tls_config)?) 
+ .client_ca_root(load_ca_certificate(tls_config)?)) +} + +fn load_identity(tls_config: &TlsConfig) -> io::Result { + let cert = fs::read_to_string(&tls_config.cert)?; + let key = fs::read_to_string(&tls_config.key)?; + Ok(Identity::from_pem(cert, key)) +} + +fn load_ca_certificate(tls_config: &TlsConfig) -> io::Result { + let pem = fs::read_to_string(&tls_config.ca_cert)?; + Ok(Certificate::from_pem(pem)) +} + +pub fn tonic_error_to_io_error(err: tonic::transport::Error) -> io::Error { + io::Error::new(io::ErrorKind::Other, err) +} + +#[cfg(test)] +mod tests { + use std::sync::Arc; + use std::thread; + use std::thread::sleep; + use std::time::Duration; + + use collection::common::is_ready::IsReady; + + #[test] + fn test_is_ready() { + let is_ready = Arc::new(IsReady::default()); + let is_ready_clone = is_ready.clone(); + let join = thread::spawn(move || { + is_ready_clone.await_ready(); + eprintln!( + "is_ready_clone.check_ready() = {:#?}", + is_ready_clone.check_ready() + ); + }); + + sleep(Duration::from_millis(500)); + eprintln!("Making ready"); + is_ready.make_ready(); + sleep(Duration::from_millis(500)); + join.join().unwrap() + } +} diff --git a/src/common/http_client.rs b/src/common/http_client.rs new file mode 100644 index 0000000000000000000000000000000000000000..212b67f0357edacefe82db7040ae4c309069f5a4 --- /dev/null +++ b/src/common/http_client.rs @@ -0,0 +1,156 @@ +use std::path::Path; +use std::{fs, io, result}; + +use reqwest::header::{HeaderMap, HeaderValue, InvalidHeaderValue}; +use storage::content_manager::errors::StorageError; + +use super::auth::HTTP_HEADER_API_KEY; +use crate::settings::{Settings, TlsConfig}; + +#[derive(Clone)] +pub struct HttpClient { + tls_config: Option, + verify_https_client_certificate: bool, +} + +impl HttpClient { + pub fn from_settings(settings: &Settings) -> Result { + let tls_config = if settings.service.enable_tls { + let Some(tls_config) = settings.tls.clone() else { + return Err(Error::TlsConfigUndefined); + }; + + Some(tls_config) + } else { + None + }; + + let verify_https_client_certificate = settings.service.verify_https_client_certificate; + + let http_client = Self { + tls_config, + verify_https_client_certificate, + }; + + Ok(http_client) + } + + /// Create a new HTTP(S) client + /// + /// An API key can be optionally provided to be used in this HTTP client. It'll send the API + /// key as `Api-key` header in every request. + /// + /// # Warning + /// + /// Setting an API key may leak when the client is used to send a request to a malicious + /// server. This is potentially dangerous if a user has control over what URL is accessed. + /// + /// For this reason the API key is not set by default as provided in the configuration. It must + /// be explicitly provided when creating the HTTP client. 
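+    ///
+    /// Illustrative usage (example key value only):
+    ///
+    ///   let http = HttpClient::from_settings(&settings)?;
+    ///   let with_key = http.client(Some("secret-key"))?;  // sends the `api-key` header
+    ///   let anonymous = http.client(None)?;                // no API key attached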
+ pub fn client(&self, api_key: Option<&str>) -> Result { + https_client( + api_key, + self.tls_config.as_ref(), + self.verify_https_client_certificate, + ) + } +} + +fn https_client( + api_key: Option<&str>, + tls_config: Option<&TlsConfig>, + verify_https_client_certificate: bool, +) -> Result { + let mut builder = reqwest::Client::builder(); + + // Configure TLS root certificate and validation + if let Some(tls_config) = tls_config { + builder = builder.add_root_certificate(https_client_ca_cert(tls_config.ca_cert.as_ref())?); + + if verify_https_client_certificate { + builder = builder.identity(https_client_identity( + tls_config.cert.as_ref(), + tls_config.key.as_ref(), + )?); + } + } + + // Attach API key as sensitive header + if let Some(api_key) = api_key { + let mut headers = HeaderMap::new(); + let mut api_key_value = HeaderValue::from_str(api_key).map_err(Error::MalformedApiKey)?; + api_key_value.set_sensitive(true); + headers.insert(HTTP_HEADER_API_KEY, api_key_value); + builder = builder.default_headers(headers); + } + + let client = builder.build()?; + + Ok(client) +} + +fn https_client_ca_cert(ca_cert: &Path) -> Result { + let ca_cert_pem = + fs::read(ca_cert).map_err(|err| Error::failed_to_read(err, "CA certificate", ca_cert))?; + + let ca_cert = reqwest::Certificate::from_pem(&ca_cert_pem)?; + + Ok(ca_cert) +} + +fn https_client_identity(cert: &Path, key: &Path) -> Result { + let mut identity_pem = + fs::read(cert).map_err(|err| Error::failed_to_read(err, "certificate", cert))?; + + let mut key_file = fs::File::open(key).map_err(|err| Error::failed_to_read(err, "key", key))?; + + // Concatenate certificate and key into a single PEM bytes + io::copy(&mut key_file, &mut identity_pem) + .map_err(|err| Error::failed_to_read(err, "key", key))?; + + let identity = reqwest::Identity::from_pem(&identity_pem)?; + + Ok(identity) +} + +pub type Result = result::Result; + +#[derive(Debug, thiserror::Error)] +pub enum Error { + #[error("TLS config is not defined in the Qdrant config file")] + TlsConfigUndefined, + + #[error("{1}: {0}")] + Io(#[source] io::Error, String), + + #[error("failed to setup HTTPS client: {0}")] + Reqwest(#[from] reqwest::Error), + + #[error("malformed API key")] + MalformedApiKey(#[source] InvalidHeaderValue), +} + +impl Error { + pub fn io(source: io::Error, context: impl Into) -> Self { + Self::Io(source, context.into()) + } + + pub fn failed_to_read(source: io::Error, file: &str, path: &Path) -> Self { + Self::io( + source, + format!("failed to read HTTPS client {file} file {}", path.display()), + ) + } +} + +impl From for StorageError { + fn from(err: Error) -> Self { + StorageError::service_error(format!("failed to initialize HTTP(S) client: {err}")) + } +} + +impl From for io::Error { + fn from(err: Error) -> Self { + io::Error::new(io::ErrorKind::Other, err) + } +} diff --git a/src/common/inference/batch_processing.rs b/src/common/inference/batch_processing.rs new file mode 100644 index 0000000000000000000000000000000000000000..fd4a5051ee8523463553fbf1c32bf59e49263989 --- /dev/null +++ b/src/common/inference/batch_processing.rs @@ -0,0 +1,370 @@ +use std::collections::HashSet; + +use api::rest::{ + ContextInput, ContextPair, DiscoverInput, Prefetch, Query, QueryGroupsRequestInternal, + QueryInterface, QueryRequestInternal, RecommendInput, VectorInput, +}; + +use super::service::{InferenceData, InferenceInput, InferenceRequest}; + +pub struct BatchAccum { + pub(crate) objects: HashSet, +} + +impl BatchAccum { + pub fn new() -> Self { + Self { + objects: 
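Editor's note: the API-key handling above relies on reqwest's sensitive-header support. A minimal self-contained sketch of the same pattern; the header name and helper below are placeholders, not values from this patch:

// Attach a default header and mark it sensitive so it is redacted from Debug output.
use reqwest::header::{HeaderMap, HeaderValue};

fn client_with_api_key(api_key: &str) -> reqwest::Result<reqwest::Client> {
    let mut headers = HeaderMap::new();
    let mut value = HeaderValue::from_str(api_key).expect("API key must be visible ASCII");
    value.set_sensitive(true);
    headers.insert("api-key", value);
    reqwest::Client::builder().default_headers(headers).build()
}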
HashSet::new(), + } + } + + pub fn add(&mut self, data: InferenceData) { + self.objects.insert(data); + } + + pub fn extend(&mut self, other: BatchAccum) { + self.objects.extend(other.objects); + } + + pub fn is_empty(&self) -> bool { + self.objects.is_empty() + } +} + +impl From<&BatchAccum> for InferenceRequest { + fn from(batch: &BatchAccum) -> Self { + Self { + inputs: batch + .objects + .iter() + .cloned() + .map(InferenceInput::from) + .collect(), + inference: None, + token: None, + } + } +} + +fn collect_vector_input(vector: &VectorInput, batch: &mut BatchAccum) { + match vector { + VectorInput::Document(doc) => batch.add(InferenceData::Document(doc.clone())), + VectorInput::Image(img) => batch.add(InferenceData::Image(img.clone())), + VectorInput::Object(obj) => batch.add(InferenceData::Object(obj.clone())), + // types that are not supported in the Inference Service + VectorInput::DenseVector(_) => {} + VectorInput::SparseVector(_) => {} + VectorInput::MultiDenseVector(_) => {} + VectorInput::Id(_) => {} + } +} + +fn collect_context_pair(pair: &ContextPair, batch: &mut BatchAccum) { + collect_vector_input(&pair.positive, batch); + collect_vector_input(&pair.negative, batch); +} + +fn collect_discover_input(discover: &DiscoverInput, batch: &mut BatchAccum) { + collect_vector_input(&discover.target, batch); + if let Some(context) = &discover.context { + for pair in context { + collect_context_pair(pair, batch); + } + } +} + +fn collect_recommend_input(recommend: &RecommendInput, batch: &mut BatchAccum) { + if let Some(positive) = &recommend.positive { + for vector in positive { + collect_vector_input(vector, batch); + } + } + if let Some(negative) = &recommend.negative { + for vector in negative { + collect_vector_input(vector, batch); + } + } +} + +fn collect_query(query: &Query, batch: &mut BatchAccum) { + match query { + Query::Nearest(nearest) => collect_vector_input(&nearest.nearest, batch), + Query::Recommend(recommend) => collect_recommend_input(&recommend.recommend, batch), + Query::Discover(discover) => collect_discover_input(&discover.discover, batch), + Query::Context(context) => { + if let ContextInput(Some(pairs)) = &context.context { + for pair in pairs { + collect_context_pair(pair, batch); + } + } + } + Query::OrderBy(_) | Query::Fusion(_) | Query::Sample(_) => {} + } +} + +fn collect_query_interface(query: &QueryInterface, batch: &mut BatchAccum) { + match query { + QueryInterface::Nearest(vector) => collect_vector_input(vector, batch), + QueryInterface::Query(query) => collect_query(query, batch), + } +} + +fn collect_prefetch(prefetch: &Prefetch, batch: &mut BatchAccum) { + let Prefetch { + prefetch, + query, + using: _, + filter: _, + params: _, + score_threshold: _, + limit: _, + lookup_from: _, + } = prefetch; + + if let Some(query) = query { + collect_query_interface(query, batch); + } + + if let Some(prefetches) = prefetch { + for p in prefetches { + collect_prefetch(p, batch); + } + } +} + +pub fn collect_query_groups_request(request: &QueryGroupsRequestInternal) -> BatchAccum { + let mut batch = BatchAccum::new(); + + let QueryGroupsRequestInternal { + query, + prefetch, + using: _, + filter: _, + params: _, + score_threshold: _, + with_vector: _, + with_payload: _, + lookup_from: _, + group_request: _, + } = request; + + if let Some(query) = query { + collect_query_interface(query, &mut batch); + } + + if let Some(prefetches) = prefetch { + for prefetch in prefetches { + collect_prefetch(prefetch, &mut batch); + } + } + + batch +} + +pub fn 
collect_query_request(request: &QueryRequestInternal) -> BatchAccum { + let mut batch = BatchAccum::new(); + + let QueryRequestInternal { + prefetch, + query, + using: _, + filter: _, + score_threshold: _, + params: _, + limit: _, + offset: _, + with_vector: _, + with_payload: _, + lookup_from: _, + } = request; + + if let Some(query) = query { + collect_query_interface(query, &mut batch); + } + + if let Some(prefetches) = prefetch { + for prefetch in prefetches { + collect_prefetch(prefetch, &mut batch); + } + } + + batch +} + +#[cfg(test)] +mod tests { + use api::rest::schema::{DiscoverQuery, Document, Image, InferenceObject, NearestQuery}; + use api::rest::QueryBaseGroupRequest; + use serde_json::json; + + use super::*; + + fn create_test_document(text: &str) -> Document { + Document { + text: text.to_string(), + model: "test-model".to_string(), + options: Default::default(), + } + } + + fn create_test_image(url: &str) -> Image { + Image { + image: json!({"data": url.to_string()}), + model: "test-model".to_string(), + options: Default::default(), + } + } + + fn create_test_object(data: &str) -> InferenceObject { + InferenceObject { + object: json!({"data": data}), + model: "test-model".to_string(), + options: Default::default(), + } + } + + #[test] + fn test_batch_accum_basic() { + let mut batch = BatchAccum::new(); + assert!(batch.objects.is_empty()); + + let doc = InferenceData::Document(create_test_document("test")); + batch.add(doc.clone()); + assert_eq!(batch.objects.len(), 1); + + batch.add(doc); + assert_eq!(batch.objects.len(), 1); + } + + #[test] + fn test_batch_accum_extend() { + let mut batch1 = BatchAccum::new(); + let mut batch2 = BatchAccum::new(); + + let doc1 = InferenceData::Document(create_test_document("test1")); + let doc2 = InferenceData::Document(create_test_document("test2")); + + batch1.add(doc1); + batch2.add(doc2); + + batch1.extend(batch2); + assert_eq!(batch1.objects.len(), 2); + } + + #[test] + fn test_deduplication() { + let mut batch = BatchAccum::new(); + + let doc1 = InferenceData::Document(create_test_document("same")); + let doc2 = InferenceData::Document(create_test_document("same")); + + batch.add(doc1); + batch.add(doc2); + + assert_eq!(batch.objects.len(), 1); + } + + #[test] + fn test_collect_vector_input() { + let mut batch = BatchAccum::new(); + + let doc_input = VectorInput::Document(create_test_document("test")); + let img_input = VectorInput::Image(create_test_image("test.jpg")); + let obj_input = VectorInput::Object(create_test_object("test")); + + collect_vector_input(&doc_input, &mut batch); + collect_vector_input(&img_input, &mut batch); + collect_vector_input(&obj_input, &mut batch); + + assert_eq!(batch.objects.len(), 3); + } + + #[test] + fn test_collect_prefetch() { + let prefetch = Prefetch { + query: Some(QueryInterface::Nearest(VectorInput::Document( + create_test_document("test"), + ))), + prefetch: Some(vec![Prefetch { + query: Some(QueryInterface::Nearest(VectorInput::Image( + create_test_image("nested.jpg"), + ))), + prefetch: None, + using: None, + filter: None, + params: None, + score_threshold: None, + limit: None, + lookup_from: None, + }]), + using: None, + filter: None, + params: None, + score_threshold: None, + limit: None, + lookup_from: None, + }; + + let mut batch = BatchAccum::new(); + collect_prefetch(&prefetch, &mut batch); + assert_eq!(batch.objects.len(), 2); + } + + #[test] + fn test_collect_query_groups_request() { + let request = QueryGroupsRequestInternal { + query: 
Some(QueryInterface::Query(Query::Nearest(NearestQuery { + nearest: VectorInput::Document(create_test_document("test")), + }))), + prefetch: Some(vec![Prefetch { + query: Some(QueryInterface::Query(Query::Discover(DiscoverQuery { + discover: DiscoverInput { + target: VectorInput::Image(create_test_image("test.jpg")), + context: Some(vec![ContextPair { + positive: VectorInput::Document(create_test_document("pos")), + negative: VectorInput::Image(create_test_image("neg.jpg")), + }]), + }, + }))), + prefetch: None, + using: None, + filter: None, + params: None, + score_threshold: None, + limit: None, + lookup_from: None, + }]), + using: None, + filter: None, + params: None, + score_threshold: None, + with_vector: None, + with_payload: None, + lookup_from: None, + group_request: QueryBaseGroupRequest { + group_by: "test".parse().unwrap(), + group_size: None, + limit: None, + with_lookup: None, + }, + }; + + let batch = collect_query_groups_request(&request); + assert_eq!(batch.objects.len(), 4); + } + + #[test] + fn test_different_model_same_content() { + let mut batch = BatchAccum::new(); + + let mut doc1 = create_test_document("same"); + let mut doc2 = create_test_document("same"); + doc1.model = "model1".to_string(); + doc2.model = "model2".to_string(); + + batch.add(InferenceData::Document(doc1)); + batch.add(InferenceData::Document(doc2)); + + assert_eq!(batch.objects.len(), 2); + } +} diff --git a/src/common/inference/batch_processing_grpc.rs b/src/common/inference/batch_processing_grpc.rs new file mode 100644 index 0000000000000000000000000000000000000000..9081286eff77feff5ac0bc31313bbddf2eaacdee --- /dev/null +++ b/src/common/inference/batch_processing_grpc.rs @@ -0,0 +1,281 @@ +use std::collections::HashSet; + +use api::grpc::qdrant::vector_input::Variant; +use api::grpc::qdrant::{ + query, ContextInput, ContextInputPair, DiscoverInput, PrefetchQuery, Query, RecommendInput, + VectorInput, +}; +use api::rest::schema as rest; +use tonic::Status; + +use super::service::{InferenceData, InferenceInput, InferenceRequest}; + +pub struct BatchAccumGrpc { + pub(crate) objects: HashSet, +} + +impl BatchAccumGrpc { + pub fn new() -> Self { + Self { + objects: HashSet::new(), + } + } + + pub fn add(&mut self, data: InferenceData) { + self.objects.insert(data); + } + + pub fn extend(&mut self, other: BatchAccumGrpc) { + self.objects.extend(other.objects); + } + + pub fn is_empty(&self) -> bool { + self.objects.is_empty() + } +} + +impl From<&BatchAccumGrpc> for InferenceRequest { + fn from(batch: &BatchAccumGrpc) -> Self { + Self { + inputs: batch + .objects + .iter() + .cloned() + .map(InferenceInput::from) + .collect(), + inference: None, + token: None, + } + } +} + +fn collect_vector_input(vector: &VectorInput, batch: &mut BatchAccumGrpc) -> Result<(), Status> { + let Some(variant) = &vector.variant else { + return Ok(()); + }; + + match variant { + Variant::Id(_) => {} + Variant::Dense(_) => {} + Variant::Sparse(_) => {} + Variant::MultiDense(_) => {} + Variant::Document(document) => { + let doc = rest::Document::try_from(document.clone()) + .map_err(|e| Status::internal(format!("Document conversion error: {e:?}")))?; + batch.add(InferenceData::Document(doc)); + } + Variant::Image(image) => { + let img = rest::Image::try_from(image.clone()) + .map_err(|e| Status::internal(format!("Image conversion error: {e:?}")))?; + batch.add(InferenceData::Image(img)); + } + Variant::Object(object) => { + let obj = rest::InferenceObject::try_from(object.clone()) + .map_err(|e| 
Status::internal(format!("Object conversion error: {e:?}")))?; + batch.add(InferenceData::Object(obj)); + } + } + Ok(()) +} + +pub(crate) fn collect_context_input( + context: &ContextInput, + batch: &mut BatchAccumGrpc, +) -> Result<(), Status> { + let ContextInput { pairs } = context; + + for pair in pairs { + collect_context_input_pair(pair, batch)?; + } + + Ok(()) +} + +fn collect_context_input_pair( + pair: &ContextInputPair, + batch: &mut BatchAccumGrpc, +) -> Result<(), Status> { + let ContextInputPair { positive, negative } = pair; + + if let Some(positive) = positive { + collect_vector_input(positive, batch)?; + } + + if let Some(negative) = negative { + collect_vector_input(negative, batch)?; + } + + Ok(()) +} + +pub(crate) fn collect_discover_input( + discover: &DiscoverInput, + batch: &mut BatchAccumGrpc, +) -> Result<(), Status> { + let DiscoverInput { target, context } = discover; + + if let Some(vector) = target { + collect_vector_input(vector, batch)?; + } + + if let Some(context) = context { + for pair in &context.pairs { + collect_context_input_pair(pair, batch)?; + } + } + + Ok(()) +} + +pub(crate) fn collect_recommend_input( + recommend: &RecommendInput, + batch: &mut BatchAccumGrpc, +) -> Result<(), Status> { + let RecommendInput { + positive, + negative, + strategy: _, + } = recommend; + + for vector in positive { + collect_vector_input(vector, batch)?; + } + + for vector in negative { + collect_vector_input(vector, batch)?; + } + + Ok(()) +} + +pub(crate) fn collect_query(query: &Query, batch: &mut BatchAccumGrpc) -> Result<(), Status> { + let Some(variant) = &query.variant else { + return Ok(()); + }; + + match variant { + query::Variant::Nearest(nearest) => collect_vector_input(nearest, batch)?, + query::Variant::Recommend(recommend) => collect_recommend_input(recommend, batch)?, + query::Variant::Discover(discover) => collect_discover_input(discover, batch)?, + query::Variant::Context(context) => collect_context_input(context, batch)?, + query::Variant::OrderBy(_) => {} + query::Variant::Fusion(_) => {} + query::Variant::Sample(_) => {} + } + + Ok(()) +} + +pub(crate) fn collect_prefetch( + prefetch: &PrefetchQuery, + batch: &mut BatchAccumGrpc, +) -> Result<(), Status> { + let PrefetchQuery { + prefetch, + query, + using: _, + filter: _, + params: _, + score_threshold: _, + limit: _, + lookup_from: _, + } = prefetch; + + if let Some(query) = query { + collect_query(query, batch)?; + } + + for p in prefetch { + collect_prefetch(p, batch)?; + } + + Ok(()) +} + +#[cfg(test)] +mod tests { + use api::rest::schema::{Document, Image, InferenceObject}; + use serde_json::json; + + use super::*; + + fn create_test_document(text: &str) -> Document { + Document { + text: text.to_string(), + model: "test-model".to_string(), + options: Default::default(), + } + } + + fn create_test_image(url: &str) -> Image { + Image { + image: json!({"data": url.to_string()}), + model: "test-model".to_string(), + options: Default::default(), + } + } + + fn create_test_object(data: &str) -> InferenceObject { + InferenceObject { + object: json!({"data": data}), + model: "test-model".to_string(), + options: Default::default(), + } + } + + #[test] + fn test_batch_accum_basic() { + let mut batch = BatchAccumGrpc::new(); + assert!(batch.objects.is_empty()); + + let doc = InferenceData::Document(create_test_document("test")); + batch.add(doc.clone()); + assert_eq!(batch.objects.len(), 1); + + batch.add(doc); + assert_eq!(batch.objects.len(), 1); + } + + #[test] + fn test_batch_accum_extend() { + let 
mut batch1 = BatchAccumGrpc::new(); + let mut batch2 = BatchAccumGrpc::new(); + + let doc1 = InferenceData::Document(create_test_document("test1")); + let doc2 = InferenceData::Document(create_test_document("test2")); + + batch1.add(doc1); + batch2.add(doc2); + + batch1.extend(batch2); + assert_eq!(batch1.objects.len(), 2); + } + + #[test] + fn test_deduplication() { + let mut batch = BatchAccumGrpc::new(); + + let doc1 = InferenceData::Document(create_test_document("same")); + let doc2 = InferenceData::Document(create_test_document("same")); + + batch.add(doc1); + batch.add(doc2); + + assert_eq!(batch.objects.len(), 1); + } + + #[test] + fn test_different_model_same_content() { + let mut batch = BatchAccumGrpc::new(); + + let mut doc1 = create_test_document("same"); + let mut doc2 = create_test_document("same"); + doc1.model = "model1".to_string(); + doc2.model = "model2".to_string(); + + batch.add(InferenceData::Document(doc1)); + batch.add(InferenceData::Document(doc2)); + + assert_eq!(batch.objects.len(), 2); + } +} diff --git a/src/common/inference/config.rs b/src/common/inference/config.rs new file mode 100644 index 0000000000000000000000000000000000000000..8bb9c36640e7414b5ea8dad35b4a216fa266a9eb --- /dev/null +++ b/src/common/inference/config.rs @@ -0,0 +1,23 @@ +use serde::{Deserialize, Serialize}; + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct InferenceConfig { + pub address: Option, + #[serde(default = "default_inference_timeout")] + pub timeout: u64, + pub token: Option, +} + +fn default_inference_timeout() -> u64 { + 10 +} + +impl InferenceConfig { + pub fn new(address: Option) -> Self { + Self { + address, + timeout: default_inference_timeout(), + token: None, + } + } +} diff --git a/src/common/inference/infer_processing.rs b/src/common/inference/infer_processing.rs new file mode 100644 index 0000000000000000000000000000000000000000..16bc2a839996622d4b1dcf550ed2cdb9395ac113 --- /dev/null +++ b/src/common/inference/infer_processing.rs @@ -0,0 +1,72 @@ +use std::collections::{HashMap, HashSet}; + +use collection::operations::point_ops::VectorPersisted; +use storage::content_manager::errors::StorageError; + +use super::batch_processing::BatchAccum; +use super::service::{InferenceData, InferenceInput, InferenceService, InferenceType}; + +pub struct BatchAccumInferred { + pub(crate) objects: HashMap, +} + +impl BatchAccumInferred { + pub fn new() -> Self { + Self { + objects: HashMap::new(), + } + } + + pub async fn from_objects( + objects: HashSet, + inference_type: InferenceType, + ) -> Result { + if objects.is_empty() { + return Ok(Self::new()); + } + + let Some(service) = InferenceService::get_global() else { + return Err(StorageError::service_error( + "InferenceService is not initialized. Please check if it was properly configured and initialized during startup." + )); + }; + + service.validate()?; + + let objects_serialized: Vec<_> = objects.into_iter().collect(); + let inference_inputs: Vec<_> = objects_serialized + .iter() + .cloned() + .map(InferenceInput::from) + .collect(); + + let vectors = service + .infer(inference_inputs, inference_type) + .await + .map_err(|e| StorageError::service_error( + format!("Inference request failed. Check if inference service is running and properly configured: {e}") + ))?; + + if vectors.is_empty() { + return Err(StorageError::service_error( + "Inference service returned no vectors. 
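Editor's note: the `#[serde(default = "default_inference_timeout")]` attribute in config.rs above means a config file without an explicit timeout falls back to 10 seconds. A small self-contained illustration; the local struct only mirrors the fields defined above:

use serde::Deserialize;

#[derive(Debug, Deserialize)]
struct InferenceConfigSketch {
    address: Option<String>,
    #[serde(default = "default_inference_timeout")]
    timeout: u64,
    token: Option<String>,
}

fn default_inference_timeout() -> u64 {
    10
}

#[test]
fn timeout_falls_back_to_default_when_omitted() {
    let cfg: InferenceConfigSketch =
        serde_json::from_str(r#"{ "address": "http://localhost:9000" }"#).unwrap();
    assert_eq!(cfg.timeout, 10); // default applied when the field is missing
}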
Check if models are properly loaded.", + )); + } + + let objects = objects_serialized.into_iter().zip(vectors).collect(); + + Ok(Self { objects }) + } + + pub async fn from_batch_accum( + batch: BatchAccum, + inference_type: InferenceType, + ) -> Result { + let BatchAccum { objects } = batch; + Self::from_objects(objects, inference_type).await + } + + pub fn get_vector(&self, data: &InferenceData) -> Option<&VectorPersisted> { + self.objects.get(data) + } +} diff --git a/src/common/inference/mod.rs b/src/common/inference/mod.rs new file mode 100644 index 0000000000000000000000000000000000000000..3e68cc5bc06818808b2a8592c93a18933f9ac679 --- /dev/null +++ b/src/common/inference/mod.rs @@ -0,0 +1,8 @@ +mod batch_processing; +mod batch_processing_grpc; +pub(crate) mod config; +mod infer_processing; +pub mod query_requests_grpc; +pub mod query_requests_rest; +pub mod service; +pub mod update_requests; diff --git a/src/common/inference/query_requests_grpc.rs b/src/common/inference/query_requests_grpc.rs new file mode 100644 index 0000000000000000000000000000000000000000..c3f2765ead9a88b604a00c15fc8a94ef019aac25 --- /dev/null +++ b/src/common/inference/query_requests_grpc.rs @@ -0,0 +1,535 @@ +use api::conversions::json::json_path_from_proto; +use api::grpc::qdrant as grpc; +use api::grpc::qdrant::query::Variant; +use api::grpc::qdrant::RecommendInput; +use api::rest; +use api::rest::RecommendStrategy; +use collection::operations::universal_query::collection_query::{ + CollectionPrefetch, CollectionQueryGroupsRequest, CollectionQueryRequest, Query, + VectorInputInternal, VectorQuery, +}; +use collection::operations::universal_query::shard_query::{FusionInternal, SampleInternal}; +use segment::data_types::order_by::OrderBy; +use segment::data_types::vectors::{VectorInternal, DEFAULT_VECTOR_NAME}; +use segment::vector_storage::query::{ContextPair, ContextQuery, DiscoveryQuery, RecoQuery}; +use tonic::Status; + +use crate::common::inference::batch_processing_grpc::{ + collect_prefetch, collect_query, BatchAccumGrpc, +}; +use crate::common::inference::infer_processing::BatchAccumInferred; +use crate::common::inference::service::{InferenceData, InferenceType}; + +/// ToDo: this function is supposed to call an inference endpoint internally +pub async fn convert_query_point_groups_from_grpc( + query: grpc::QueryPointGroups, +) -> Result { + let grpc::QueryPointGroups { + collection_name: _, + prefetch, + query, + using, + filter, + params, + score_threshold, + with_payload, + with_vectors, + lookup_from, + limit, + group_size, + group_by, + with_lookup, + read_consistency: _, + timeout: _, + shard_key_selector: _, + } = query; + + let mut batch = BatchAccumGrpc::new(); + + if let Some(q) = &query { + collect_query(q, &mut batch)?; + } + + for p in &prefetch { + collect_prefetch(p, &mut batch)?; + } + + let BatchAccumGrpc { objects } = batch; + + let inferred = BatchAccumInferred::from_objects(objects, InferenceType::Search) + .await + .map_err(|e| Status::internal(format!("Inference error: {e}")))?; + + let query = if let Some(q) = query { + Some(convert_query_with_inferred(q, &inferred)?) 
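Editor's note: taken together, the batch-collection and inference modules above are meant to be used in three steps: collect inputs, run one batched inference call, then look vectors back up during conversion. A condensed sketch of that flow, assuming it sits alongside query_requests_rest.rs inside the same inference module (the function name is illustrative):

// Condensed collect -> infer -> look up flow; error handling is abbreviated.
use crate::common::inference::batch_processing::collect_query_request;
use crate::common::inference::infer_processing::BatchAccumInferred;
use crate::common::inference::service::InferenceType;
use storage::content_manager::errors::StorageError;

async fn resolve_vectors(request: &api::rest::QueryRequestInternal) -> Result<(), StorageError> {
    // 1. Walk the request once and gather every Document/Image/Object input.
    let batch = collect_query_request(request);
    // 2. Send a single batched request to the inference service.
    let inferred = BatchAccumInferred::from_batch_accum(batch, InferenceType::Search).await?;
    // 3. During conversion, each original input is swapped for its inferred vector
    //    (convert_query_request_from_rest does this via inferred.get_vector).
    let _ = inferred;
    Ok(())
}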
+ } else { + None + }; + + let prefetch = prefetch + .into_iter() + .map(|p| convert_prefetch_with_inferred(p, &inferred)) + .collect::, _>>()?; + + let request = CollectionQueryGroupsRequest { + prefetch, + query, + using: using.unwrap_or(DEFAULT_VECTOR_NAME.to_string()), + filter: filter.map(TryFrom::try_from).transpose()?, + score_threshold, + with_vector: with_vectors + .map(From::from) + .unwrap_or(CollectionQueryRequest::DEFAULT_WITH_VECTOR), + with_payload: with_payload + .map(TryFrom::try_from) + .transpose()? + .unwrap_or(CollectionQueryRequest::DEFAULT_WITH_PAYLOAD), + lookup_from: lookup_from.map(From::from), + group_by: json_path_from_proto(&group_by)?, + group_size: group_size + .map(|s| s as usize) + .unwrap_or(CollectionQueryRequest::DEFAULT_GROUP_SIZE), + limit: limit + .map(|l| l as usize) + .unwrap_or(CollectionQueryRequest::DEFAULT_LIMIT), + params: params.map(From::from), + with_lookup: with_lookup.map(TryFrom::try_from).transpose()?, + }; + + Ok(request) +} + +/// ToDo: this function is supposed to call an inference endpoint internally +pub async fn convert_query_points_from_grpc( + query: grpc::QueryPoints, +) -> Result { + let grpc::QueryPoints { + collection_name: _, + prefetch, + query, + using, + filter, + params, + score_threshold, + limit, + offset, + with_payload, + with_vectors, + read_consistency: _, + shard_key_selector: _, + lookup_from, + timeout: _, + } = query; + + let mut batch = BatchAccumGrpc::new(); + + if let Some(q) = &query { + collect_query(q, &mut batch)?; + } + + for p in &prefetch { + collect_prefetch(p, &mut batch)?; + } + + let BatchAccumGrpc { objects } = batch; + + let inferred = BatchAccumInferred::from_objects(objects, InferenceType::Search) + .await + .map_err(|e| Status::internal(format!("Inference error: {e}")))?; + + let prefetch = prefetch + .into_iter() + .map(|p| convert_prefetch_with_inferred(p, &inferred)) + .collect::, _>>()?; + + let query = query + .map(|q| convert_query_with_inferred(q, &inferred)) + .transpose()?; + + Ok(CollectionQueryRequest { + prefetch, + query, + using: using.unwrap_or(DEFAULT_VECTOR_NAME.to_string()), + filter: filter.map(TryFrom::try_from).transpose()?, + score_threshold, + limit: limit + .map(|l| l as usize) + .unwrap_or(CollectionQueryRequest::DEFAULT_LIMIT), + offset: offset + .map(|o| o as usize) + .unwrap_or(CollectionQueryRequest::DEFAULT_OFFSET), + params: params.map(From::from), + with_vector: with_vectors + .map(From::from) + .unwrap_or(CollectionQueryRequest::DEFAULT_WITH_VECTOR), + with_payload: with_payload + .map(TryFrom::try_from) + .transpose()? 
+ .unwrap_or(CollectionQueryRequest::DEFAULT_WITH_PAYLOAD), + lookup_from: lookup_from.map(From::from), + }) +} + +fn convert_prefetch_with_inferred( + prefetch: grpc::PrefetchQuery, + inferred: &BatchAccumInferred, +) -> Result { + let grpc::PrefetchQuery { + prefetch, + query, + using, + filter, + params, + score_threshold, + limit, + lookup_from, + } = prefetch; + + let nested_prefetches = prefetch + .into_iter() + .map(|p| convert_prefetch_with_inferred(p, inferred)) + .collect::, _>>()?; + + let query = query + .map(|q| convert_query_with_inferred(q, inferred)) + .transpose()?; + + Ok(CollectionPrefetch { + prefetch: nested_prefetches, + query, + using: using.unwrap_or(DEFAULT_VECTOR_NAME.to_string()), + filter: filter.map(TryFrom::try_from).transpose()?, + score_threshold, + limit: limit + .map(|l| l as usize) + .unwrap_or(CollectionQueryRequest::DEFAULT_LIMIT), + params: params.map(From::from), + lookup_from: lookup_from.map(From::from), + }) +} + +fn convert_query_with_inferred( + query: grpc::Query, + inferred: &BatchAccumInferred, +) -> Result { + let variant = query + .variant + .ok_or_else(|| Status::invalid_argument("Query variant is missing"))?; + + let query = match variant { + Variant::Nearest(nearest) => { + let vector = convert_vector_input_with_inferred(nearest, inferred)?; + Query::Vector(VectorQuery::Nearest(vector)) + } + Variant::Recommend(recommend) => { + let RecommendInput { + positive, + negative, + strategy, + } = recommend; + + let positives = positive + .into_iter() + .map(|v| convert_vector_input_with_inferred(v, inferred)) + .collect::, _>>()?; + + let negatives = negative + .into_iter() + .map(|v| convert_vector_input_with_inferred(v, inferred)) + .collect::, _>>()?; + + let reco_query = RecoQuery::new(positives, negatives); + + let strategy = strategy + .and_then(|x| grpc::RecommendStrategy::try_from(x).ok()) + .map(RecommendStrategy::from) + .unwrap_or_default(); + + match strategy { + RecommendStrategy::AverageVector => { + Query::Vector(VectorQuery::RecommendAverageVector(reco_query)) + } + RecommendStrategy::BestScore => { + Query::Vector(VectorQuery::RecommendBestScore(reco_query)) + } + } + } + Variant::Discover(discover) => { + let grpc::DiscoverInput { target, context } = discover; + + let target = target + .map(|t| convert_vector_input_with_inferred(t, inferred)) + .transpose()? 
+ .ok_or_else(|| Status::invalid_argument("DiscoverInput target is missing"))?; + + let grpc::ContextInput { pairs } = context + .ok_or_else(|| Status::invalid_argument("DiscoverInput context is missing"))?; + + let context = pairs + .into_iter() + .map(|pair| context_pair_from_grpc_with_inferred(pair, inferred)) + .collect::>()?; + + Query::Vector(VectorQuery::Discover(DiscoveryQuery::new(target, context))) + } + Variant::Context(context) => { + let context_query = context_query_from_grpc_with_inferred(context, inferred)?; + Query::Vector(VectorQuery::Context(context_query)) + } + Variant::OrderBy(order_by) => Query::OrderBy(OrderBy::try_from(order_by)?), + Variant::Fusion(fusion) => Query::Fusion(FusionInternal::try_from(fusion)?), + Variant::Sample(sample) => Query::Sample(SampleInternal::try_from(sample)?), + }; + + Ok(query) +} + +fn convert_vector_input_with_inferred( + vector: grpc::VectorInput, + inferred: &BatchAccumInferred, +) -> Result { + use api::grpc::qdrant::vector_input::Variant; + + let variant = vector + .variant + .ok_or_else(|| Status::invalid_argument("VectorInput variant is missing"))?; + + match variant { + Variant::Id(id) => Ok(VectorInputInternal::Id(TryFrom::try_from(id)?)), + Variant::Dense(dense) => Ok(VectorInputInternal::Vector(VectorInternal::Dense( + From::from(dense), + ))), + Variant::Sparse(sparse) => Ok(VectorInputInternal::Vector(VectorInternal::Sparse( + From::from(sparse), + ))), + Variant::MultiDense(multi_dense) => Ok(VectorInputInternal::Vector( + VectorInternal::MultiDense(From::from(multi_dense)), + )), + Variant::Document(doc) => { + let doc: rest::Document = doc + .try_into() + .map_err(|e| Status::internal(format!("Document conversion error: {e}")))?; + let data = InferenceData::Document(doc); + let vector = inferred + .get_vector(&data) + .ok_or_else(|| Status::internal("Missing inferred vector for document"))?; + + Ok(VectorInputInternal::Vector(VectorInternal::from( + vector.clone(), + ))) + } + Variant::Image(img) => { + let img: rest::Image = img + .try_into() + .map_err(|e| Status::internal(format!("Image conversion error: {e}",)))?; + let data = InferenceData::Image(img); + + let vector = inferred + .get_vector(&data) + .ok_or_else(|| Status::internal("Missing inferred vector for image"))?; + + Ok(VectorInputInternal::Vector(VectorInternal::from( + vector.clone(), + ))) + } + Variant::Object(obj) => { + let obj: rest::InferenceObject = obj + .try_into() + .map_err(|e| Status::internal(format!("Object conversion error: {e}")))?; + let data = InferenceData::Object(obj); + let vector = inferred + .get_vector(&data) + .ok_or_else(|| Status::internal("Missing inferred vector for object"))?; + + Ok(VectorInputInternal::Vector(VectorInternal::from( + vector.clone(), + ))) + } + } +} + +fn context_query_from_grpc_with_inferred( + value: grpc::ContextInput, + inferred: &BatchAccumInferred, +) -> Result, Status> { + let grpc::ContextInput { pairs } = value; + + Ok(ContextQuery { + pairs: pairs + .into_iter() + .map(|pair| context_pair_from_grpc_with_inferred(pair, inferred)) + .collect::>()?, + }) +} + +fn context_pair_from_grpc_with_inferred( + value: grpc::ContextInputPair, + inferred: &BatchAccumInferred, +) -> Result, Status> { + let grpc::ContextInputPair { positive, negative } = value; + + let positive = + positive.ok_or_else(|| Status::invalid_argument("ContextPair positive is missing"))?; + let negative = + negative.ok_or_else(|| Status::invalid_argument("ContextPair negative is missing"))?; + + Ok(ContextPair { + positive: 
convert_vector_input_with_inferred(positive, inferred)?, + negative: convert_vector_input_with_inferred(negative, inferred)?, + }) +} + +#[cfg(test)] +mod tests { + use std::collections::HashMap; + + use api::grpc::qdrant::value::Kind; + use api::grpc::qdrant::vector_input::Variant; + use api::grpc::qdrant::Value; + use collection::operations::point_ops::VectorPersisted; + + use super::*; + + fn create_test_document() -> api::grpc::qdrant::Document { + api::grpc::qdrant::Document { + text: "test".to_string(), + model: "test-model".to_string(), + options: HashMap::new(), + } + } + + fn create_test_image() -> api::grpc::qdrant::Image { + api::grpc::qdrant::Image { + image: Some(Value { + kind: Some(Kind::StringValue("test.jpg".to_string())), + }), + model: "test-model".to_string(), + options: HashMap::new(), + } + } + + fn create_test_object() -> api::grpc::qdrant::InferenceObject { + api::grpc::qdrant::InferenceObject { + object: Some(Value { + kind: Some(Kind::StringValue("test".to_string())), + }), + model: "test-model".to_string(), + options: HashMap::new(), + } + } + + fn create_test_inferred_batch() -> BatchAccumInferred { + let mut objects = HashMap::new(); + + let grpc_doc = create_test_document(); + let grpc_img = create_test_image(); + let grpc_obj = create_test_object(); + + let doc: rest::Document = grpc_doc.try_into().unwrap(); + let img: rest::Image = grpc_img.try_into().unwrap(); + let obj: rest::InferenceObject = grpc_obj.try_into().unwrap(); + + let doc_data = InferenceData::Document(doc); + let img_data = InferenceData::Image(img); + let obj_data = InferenceData::Object(obj); + + let dense_vector = vec![1.0, 2.0, 3.0]; + let vector_persisted = VectorPersisted::Dense(dense_vector); + + objects.insert(doc_data, vector_persisted.clone()); + objects.insert(img_data, vector_persisted.clone()); + objects.insert(obj_data, vector_persisted); + + BatchAccumInferred { objects } + } + + #[test] + fn test_convert_vector_input_with_inferred_dense() { + let inferred = create_test_inferred_batch(); + let vector = grpc::VectorInput { + variant: Some(Variant::Dense(grpc::DenseVector { + data: vec![1.0, 2.0, 3.0], + })), + }; + + let result = convert_vector_input_with_inferred(vector, &inferred).unwrap(); + match result { + VectorInputInternal::Vector(VectorInternal::Dense(values)) => { + assert_eq!(values, vec![1.0, 2.0, 3.0]); + } + _ => panic!("Expected dense vector"), + } + } + + #[test] + fn test_convert_vector_input_with_inferred_document() { + let inferred = create_test_inferred_batch(); + let doc = create_test_document(); + let vector = grpc::VectorInput { + variant: Some(Variant::Document(doc)), + }; + + let result = convert_vector_input_with_inferred(vector, &inferred).unwrap(); + match result { + VectorInputInternal::Vector(VectorInternal::Dense(values)) => { + assert_eq!(values, vec![1.0, 2.0, 3.0]); + } + _ => panic!("Expected dense vector from inference"), + } + } + + #[test] + fn test_convert_vector_input_missing_variant() { + let inferred = create_test_inferred_batch(); + let vector = grpc::VectorInput { variant: None }; + + let result = convert_vector_input_with_inferred(vector, &inferred); + assert!(result.is_err()); + assert!(result.unwrap_err().message().contains("variant is missing")); + } + + #[test] + fn test_context_pair_from_grpc_with_inferred() { + let inferred = create_test_inferred_batch(); + let pair = grpc::ContextInputPair { + positive: Some(grpc::VectorInput { + variant: Some(Variant::Dense(grpc::DenseVector { + data: vec![1.0, 2.0, 3.0], + })), + }), + 
negative: Some(grpc::VectorInput { + variant: Some(Variant::Document(create_test_document())), + }), + }; + + let result = context_pair_from_grpc_with_inferred(pair, &inferred).unwrap(); + match (result.positive, result.negative) { + ( + VectorInputInternal::Vector(VectorInternal::Dense(pos)), + VectorInputInternal::Vector(VectorInternal::Dense(neg)), + ) => { + assert_eq!(pos, vec![1.0, 2.0, 3.0]); + assert_eq!(neg, vec![1.0, 2.0, 3.0]); + } + _ => panic!("Expected dense vectors"), + } + } + + #[test] + fn test_context_pair_missing_vectors() { + let inferred = create_test_inferred_batch(); + let pair = grpc::ContextInputPair { + positive: None, + negative: Some(grpc::VectorInput { + variant: Some(Variant::Document(create_test_document())), + }), + }; + + let result = context_pair_from_grpc_with_inferred(pair, &inferred); + assert!(result.is_err()); + assert!(result + .unwrap_err() + .message() + .contains("positive is missing")); + } +} diff --git a/src/common/inference/query_requests_rest.rs b/src/common/inference/query_requests_rest.rs new file mode 100644 index 0000000000000000000000000000000000000000..b3cb59cb9c77e919a2f6567c50a42375c307bc18 --- /dev/null +++ b/src/common/inference/query_requests_rest.rs @@ -0,0 +1,415 @@ +use api::rest::schema as rest; +use collection::lookup::WithLookup; +use collection::operations::universal_query::collection_query::{ + CollectionPrefetch, CollectionQueryGroupsRequest, CollectionQueryRequest, Query, + VectorInputInternal, VectorQuery, +}; +use collection::operations::universal_query::shard_query::{FusionInternal, SampleInternal}; +use segment::data_types::order_by::OrderBy; +use segment::data_types::vectors::{MultiDenseVectorInternal, VectorInternal, DEFAULT_VECTOR_NAME}; +use segment::vector_storage::query::{ContextPair, ContextQuery, DiscoveryQuery, RecoQuery}; +use storage::content_manager::errors::StorageError; + +use crate::common::inference::batch_processing::{ + collect_query_groups_request, collect_query_request, +}; +use crate::common::inference::infer_processing::BatchAccumInferred; +use crate::common::inference::service::{InferenceData, InferenceType}; + +pub async fn convert_query_groups_request_from_rest( + request: rest::QueryGroupsRequestInternal, +) -> Result { + let batch = collect_query_groups_request(&request); + let rest::QueryGroupsRequestInternal { + prefetch, + query, + using, + filter, + score_threshold, + params, + with_vector, + with_payload, + lookup_from, + group_request, + } = request; + + let inferred = BatchAccumInferred::from_batch_accum(batch, InferenceType::Search).await?; + let query = query + .map(|q| convert_query_with_inferred(q, &inferred)) + .transpose()?; + + let prefetch = prefetch + .map(|prefetches| { + prefetches + .into_iter() + .map(|p| convert_prefetch_with_inferred(p, &inferred)) + .collect::, _>>() + }) + .transpose()? 
+ .unwrap_or_default(); + + Ok(CollectionQueryGroupsRequest { + prefetch, + query, + using: using.unwrap_or(DEFAULT_VECTOR_NAME.to_string()), + filter, + score_threshold, + params, + with_vector: with_vector.unwrap_or(CollectionQueryRequest::DEFAULT_WITH_VECTOR), + with_payload: with_payload.unwrap_or(CollectionQueryRequest::DEFAULT_WITH_PAYLOAD), + lookup_from, + limit: group_request + .limit + .unwrap_or(CollectionQueryRequest::DEFAULT_LIMIT), + group_by: group_request.group_by, + group_size: group_request + .group_size + .unwrap_or(CollectionQueryRequest::DEFAULT_GROUP_SIZE), + with_lookup: group_request.with_lookup.map(WithLookup::from), + }) +} + +pub async fn convert_query_request_from_rest( + request: rest::QueryRequestInternal, +) -> Result { + let batch = collect_query_request(&request); + let inferred = BatchAccumInferred::from_batch_accum(batch, InferenceType::Search).await?; + let rest::QueryRequestInternal { + prefetch, + query, + using, + filter, + score_threshold, + params, + limit, + offset, + with_vector, + with_payload, + lookup_from, + } = request; + + let prefetch = prefetch + .map(|prefetches| { + prefetches + .into_iter() + .map(|p| convert_prefetch_with_inferred(p, &inferred)) + .collect::, _>>() + }) + .transpose()? + .unwrap_or_default(); + + let query = query + .map(|q| convert_query_with_inferred(q, &inferred)) + .transpose()?; + + Ok(CollectionQueryRequest { + prefetch, + query, + using: using.unwrap_or(DEFAULT_VECTOR_NAME.to_string()), + filter, + score_threshold, + limit: limit.unwrap_or(CollectionQueryRequest::DEFAULT_LIMIT), + offset: offset.unwrap_or(CollectionQueryRequest::DEFAULT_OFFSET), + params, + with_vector: with_vector.unwrap_or(CollectionQueryRequest::DEFAULT_WITH_VECTOR), + with_payload: with_payload.unwrap_or(CollectionQueryRequest::DEFAULT_WITH_PAYLOAD), + lookup_from, + }) +} + +fn convert_vector_input_with_inferred( + vector: rest::VectorInput, + inferred: &BatchAccumInferred, +) -> Result { + match vector { + rest::VectorInput::Id(id) => Ok(VectorInputInternal::Id(id)), + rest::VectorInput::DenseVector(dense) => { + Ok(VectorInputInternal::Vector(VectorInternal::Dense(dense))) + } + rest::VectorInput::SparseVector(sparse) => { + Ok(VectorInputInternal::Vector(VectorInternal::Sparse(sparse))) + } + rest::VectorInput::MultiDenseVector(multi_dense) => Ok(VectorInputInternal::Vector( + VectorInternal::MultiDense(MultiDenseVectorInternal::new_unchecked(multi_dense)), + )), + rest::VectorInput::Document(doc) => { + let data = InferenceData::Document(doc); + let vector = inferred.get_vector(&data).ok_or_else(|| { + StorageError::inference_error("Missing inferred vector for document") + })?; + Ok(VectorInputInternal::Vector(VectorInternal::from( + vector.clone(), + ))) + } + rest::VectorInput::Image(img) => { + let data = InferenceData::Image(img); + let vector = inferred.get_vector(&data).ok_or_else(|| { + StorageError::inference_error("Missing inferred vector for image") + })?; + Ok(VectorInputInternal::Vector(VectorInternal::from( + vector.clone(), + ))) + } + rest::VectorInput::Object(obj) => { + let data = InferenceData::Object(obj); + let vector = inferred.get_vector(&data).ok_or_else(|| { + StorageError::inference_error("Missing inferred vector for object") + })?; + Ok(VectorInputInternal::Vector(VectorInternal::from( + vector.clone(), + ))) + } + } +} + +fn convert_query_with_inferred( + query: rest::QueryInterface, + inferred: &BatchAccumInferred, +) -> Result { + let query = rest::Query::from(query); + match query { + 
rest::Query::Nearest(nearest) => { + let vector = convert_vector_input_with_inferred(nearest.nearest, inferred)?; + Ok(Query::Vector(VectorQuery::Nearest(vector))) + } + rest::Query::Recommend(recommend) => { + let rest::RecommendInput { + positive, + negative, + strategy, + } = recommend.recommend; + let positives = positive + .into_iter() + .flatten() + .map(|v| convert_vector_input_with_inferred(v, inferred)) + .collect::, _>>()?; + let negatives = negative + .into_iter() + .flatten() + .map(|v| convert_vector_input_with_inferred(v, inferred)) + .collect::, _>>()?; + let reco_query = RecoQuery::new(positives, negatives); + match strategy.unwrap_or_default() { + rest::RecommendStrategy::AverageVector => Ok(Query::Vector( + VectorQuery::RecommendAverageVector(reco_query), + )), + rest::RecommendStrategy::BestScore => { + Ok(Query::Vector(VectorQuery::RecommendBestScore(reco_query))) + } + } + } + rest::Query::Discover(discover) => { + let rest::DiscoverInput { target, context } = discover.discover; + let target = convert_vector_input_with_inferred(target, inferred)?; + let context = context + .into_iter() + .flatten() + .map(|pair| context_pair_from_rest_with_inferred(pair, inferred)) + .collect::, _>>()?; + Ok(Query::Vector(VectorQuery::Discover(DiscoveryQuery::new( + target, context, + )))) + } + rest::Query::Context(context) => { + let rest::ContextInput(context) = context.context; + let context = context + .into_iter() + .flatten() + .map(|pair| context_pair_from_rest_with_inferred(pair, inferred)) + .collect::, _>>()?; + Ok(Query::Vector(VectorQuery::Context(ContextQuery::new( + context, + )))) + } + rest::Query::OrderBy(order_by) => Ok(Query::OrderBy(OrderBy::from(order_by.order_by))), + rest::Query::Fusion(fusion) => Ok(Query::Fusion(FusionInternal::from(fusion.fusion))), + rest::Query::Sample(sample) => Ok(Query::Sample(SampleInternal::from(sample.sample))), + } +} + +fn convert_prefetch_with_inferred( + prefetch: rest::Prefetch, + inferred: &BatchAccumInferred, +) -> Result { + let rest::Prefetch { + prefetch, + query, + using, + filter, + score_threshold, + params, + limit, + lookup_from, + } = prefetch; + + let query = query + .map(|q| convert_query_with_inferred(q, inferred)) + .transpose()?; + let nested_prefetches = prefetch + .map(|prefetches| { + prefetches + .into_iter() + .map(|p| convert_prefetch_with_inferred(p, inferred)) + .collect::, _>>() + }) + .transpose()? 
+ .unwrap_or_default(); + + Ok(CollectionPrefetch { + prefetch: nested_prefetches, + query, + using: using.unwrap_or(DEFAULT_VECTOR_NAME.to_string()), + filter, + score_threshold, + limit: limit.unwrap_or(CollectionQueryRequest::DEFAULT_LIMIT), + params, + lookup_from, + }) +} + +fn context_pair_from_rest_with_inferred( + value: rest::ContextPair, + inferred: &BatchAccumInferred, +) -> Result, StorageError> { + let rest::ContextPair { positive, negative } = value; + Ok(ContextPair { + positive: convert_vector_input_with_inferred(positive, inferred)?, + negative: convert_vector_input_with_inferred(negative, inferred)?, + }) +} + +#[cfg(test)] +mod tests { + use std::collections::HashMap; + + use api::rest::schema::{Document, Image, InferenceObject, NearestQuery}; + use collection::operations::point_ops::VectorPersisted; + use serde_json::json; + + use super::*; + + fn create_test_document(text: &str) -> Document { + Document { + text: text.to_string(), + model: "test-model".to_string(), + options: Default::default(), + } + } + + fn create_test_image(url: &str) -> Image { + Image { + image: json!({"data": url.to_string()}), + model: "test-model".to_string(), + options: Default::default(), + } + } + + fn create_test_object(data: &str) -> InferenceObject { + InferenceObject { + object: json!({"data": data}), + model: "test-model".to_string(), + options: Default::default(), + } + } + + fn create_test_inferred_batch() -> BatchAccumInferred { + let mut objects = HashMap::new(); + + let doc = InferenceData::Document(create_test_document("test")); + let img = InferenceData::Image(create_test_image("test.jpg")); + let obj = InferenceData::Object(create_test_object("test")); + + let dense_vector = vec![1.0, 2.0, 3.0]; + let vector_persisted = VectorPersisted::Dense(dense_vector); + + objects.insert(doc, vector_persisted.clone()); + objects.insert(img, vector_persisted.clone()); + objects.insert(obj, vector_persisted); + + BatchAccumInferred { objects } + } + + #[test] + fn test_convert_vector_input_with_inferred_dense() { + let inferred = create_test_inferred_batch(); + let vector = rest::VectorInput::DenseVector(vec![1.0, 2.0, 3.0]); + + let result = convert_vector_input_with_inferred(vector, &inferred).unwrap(); + match result { + VectorInputInternal::Vector(VectorInternal::Dense(values)) => { + assert_eq!(values, vec![1.0, 2.0, 3.0]); + } + _ => panic!("Expected dense vector"), + } + } + + #[test] + fn test_convert_vector_input_with_inferred_document() { + let inferred = create_test_inferred_batch(); + let doc = create_test_document("test"); + let vector = rest::VectorInput::Document(doc); + + let result = convert_vector_input_with_inferred(vector, &inferred).unwrap(); + match result { + VectorInputInternal::Vector(VectorInternal::Dense(values)) => { + assert_eq!(values, vec![1.0, 2.0, 3.0]); + } + _ => panic!("Expected dense vector from inference"), + } + } + + #[test] + fn test_convert_vector_input_with_inferred_missing() { + let inferred = create_test_inferred_batch(); + let doc = create_test_document("missing"); + let vector = rest::VectorInput::Document(doc); + + let result = convert_vector_input_with_inferred(vector, &inferred); + assert!(result.is_err()); + assert!(result + .unwrap_err() + .to_string() + .contains("Missing inferred vector")); + } + + #[test] + fn test_context_pair_from_rest_with_inferred() { + let inferred = create_test_inferred_batch(); + let pair = rest::ContextPair { + positive: rest::VectorInput::DenseVector(vec![1.0, 2.0, 3.0]), + negative: 
rest::VectorInput::Document(create_test_document("test")), + }; + + let result = context_pair_from_rest_with_inferred(pair, &inferred).unwrap(); + match (result.positive, result.negative) { + ( + VectorInputInternal::Vector(VectorInternal::Dense(pos)), + VectorInputInternal::Vector(VectorInternal::Dense(neg)), + ) => { + assert_eq!(pos, vec![1.0, 2.0, 3.0]); + assert_eq!(neg, vec![1.0, 2.0, 3.0]); + } + _ => panic!("Expected dense vectors"), + } + } + + #[test] + fn test_convert_query_with_inferred_nearest() { + let inferred = create_test_inferred_batch(); + let nearest = NearestQuery { + nearest: rest::VectorInput::Document(create_test_document("test")), + }; + let query = rest::QueryInterface::Query(rest::Query::Nearest(nearest)); + + let result = convert_query_with_inferred(query, &inferred).unwrap(); + match result { + Query::Vector(VectorQuery::Nearest(vector)) => match vector { + VectorInputInternal::Vector(VectorInternal::Dense(values)) => { + assert_eq!(values, vec![1.0, 2.0, 3.0]); + } + _ => panic!("Expected dense vector"), + }, + _ => panic!("Expected nearest query"), + } + } +} diff --git a/src/common/inference/service.rs b/src/common/inference/service.rs new file mode 100644 index 0000000000000000000000000000000000000000..936aa24a16da0675002c8c4723509f2bef87cea3 --- /dev/null +++ b/src/common/inference/service.rs @@ -0,0 +1,266 @@ +use std::collections::HashMap; +use std::fmt::Display; +use std::hash::Hash; +use std::sync::Arc; +use std::time::Duration; + +use api::rest::{Document, Image, InferenceObject}; +use collection::operations::point_ops::VectorPersisted; +use parking_lot::RwLock; +use reqwest::Client; +use serde::{Deserialize, Serialize}; +use serde_json::Value; +use storage::content_manager::errors::StorageError; + +use crate::common::inference::config::InferenceConfig; + +const DOCUMENT_DATA_TYPE: &str = "text"; +const IMAGE_DATA_TYPE: &str = "image"; +const OBJECT_DATA_TYPE: &str = "object"; + +#[derive(Debug, Serialize, Default, Clone, Copy)] +#[serde(rename_all = "lowercase")] +pub enum InferenceType { + #[default] + Update, + Search, +} + +impl Display for InferenceType { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(f, "{}", format!("{self:?}").to_lowercase()) + } +} + +#[derive(Debug, Serialize)] +pub struct InferenceRequest { + pub(crate) inputs: Vec, + pub(crate) inference: Option, + #[serde(default)] + pub(crate) token: Option, +} + +#[derive(Debug, Serialize)] +pub struct InferenceInput { + data: Value, + data_type: String, + model: String, + options: Option>, +} + +#[derive(Debug, Deserialize)] +pub(crate) struct InferenceResponse { + pub(crate) embeddings: Vec, +} + +#[derive(Debug, Serialize, Deserialize, Clone, PartialEq, Eq, Hash)] +pub enum InferenceData { + Document(Document), + Image(Image), + Object(InferenceObject), +} + +impl InferenceData { + pub(crate) fn type_name(&self) -> &'static str { + match self { + InferenceData::Document(_) => "document", + InferenceData::Image(_) => "image", + InferenceData::Object(_) => "object", + } + } +} + +impl From for InferenceInput { + fn from(value: InferenceData) -> Self { + match value { + InferenceData::Document(doc) => { + let Document { + text, + model, + options, + } = doc; + InferenceInput { + data: Value::String(text), + data_type: DOCUMENT_DATA_TYPE.to_string(), + model: model.to_string(), + options: options.options, + } + } + InferenceData::Image(img) => { + let Image { + image, + model, + options, + } = img; + InferenceInput { + data: image, + data_type: 
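Editor's note: based on the Serialize/Deserialize derives on InferenceRequest, InferenceInput and InferenceResponse above, the bodies exchanged with the inference service look roughly like the following. The field values are illustrative, and the dense-embedding encoding is an assumption about how VectorPersisted serializes:

use serde_json::json;

fn wire_shape_sketch() {
    // Approximate request body that infer() POSTs.
    let request_body = json!({
        "inputs": [
            { "data": "some text", "data_type": "text", "model": "my-model", "options": null }
        ],
        "inference": "search",
        "token": null
    });

    // Approximate successful response body the service is expected to return.
    let response_body = json!({
        "embeddings": [[0.1, 0.2, 0.3]]
    });

    println!("{request_body}\n{response_body}");
}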
IMAGE_DATA_TYPE.to_string(), + model: model.to_string(), + options: options.options, + } + } + InferenceData::Object(obj) => { + let InferenceObject { + object, + model, + options, + } = obj; + InferenceInput { + data: object, + data_type: OBJECT_DATA_TYPE.to_string(), + model: model.to_string(), + options: options.options, + } + } + } + } +} + +pub struct InferenceService { + pub(crate) config: InferenceConfig, + pub(crate) client: Client, +} + +static INFERENCE_SERVICE: RwLock>> = RwLock::new(None); + +impl InferenceService { + pub fn new(config: InferenceConfig) -> Self { + let timeout = Duration::from_secs(config.timeout); + Self { + config, + client: Client::builder() + .timeout(timeout) + .build() + .expect("Invalid timeout value for HTTP client"), + } + } + + pub fn init_global(config: InferenceConfig) -> Result<(), StorageError> { + let mut inference_service = INFERENCE_SERVICE.write(); + + if config.token.is_none() { + return Err(StorageError::service_error( + "Cannot initialize InferenceService: token is required but not provided in config", + )); + } + + if config.address.is_none() || config.address.as_ref().unwrap().is_empty() { + return Err(StorageError::service_error( + "Cannot initialize InferenceService: address is required but not provided or empty in config" + )); + } + + *inference_service = Some(Arc::new(Self::new(config))); + Ok(()) + } + + pub fn get_global() -> Option> { + INFERENCE_SERVICE.read().as_ref().cloned() + } + + pub(crate) fn validate(&self) -> Result<(), StorageError> { + if self + .config + .address + .as_ref() + .map_or(true, |url| url.is_empty()) + { + return Err(StorageError::service_error( + "InferenceService configuration error: address is missing or empty", + )); + } + Ok(()) + } + + pub async fn infer( + &self, + inference_inputs: Vec, + inference_type: InferenceType, + ) -> Result, StorageError> { + let request = InferenceRequest { + inputs: inference_inputs, + inference: Some(inference_type), + token: self.config.token.clone(), + }; + + let url = self.config.address.as_ref().ok_or_else(|| { + StorageError::service_error( + "InferenceService URL not configured - please provide valid address in config", + ) + })?; + + let response = self + .client + .post(url) + .json(&request) + .send() + .await + .map_err(|e| { + let error_body = e.to_string(); + StorageError::service_error(format!( + "Failed to send inference request to {url}: {e}, error details: {error_body}", + )) + })?; + + let status = response.status(); + let response_body = response.text().await.map_err(|e| { + StorageError::service_error(format!("Failed to read inference response body: {e}",)) + })?; + + Self::handle_inference_response(status, &response_body) + } + + pub(crate) fn handle_inference_response( + status: reqwest::StatusCode, + response_body: &str, + ) -> Result, StorageError> { + match status { + reqwest::StatusCode::OK => { + let inference_response: InferenceResponse = serde_json::from_str(response_body) + .map_err(|e| { + StorageError::service_error(format!( + "Failed to parse successful inference response: {e}. 
Response body: {response_body}", + )) + })?; + + if inference_response.embeddings.is_empty() { + Err(StorageError::service_error( + "Inference response contained no embeddings - this may indicate an issue with the model or input" + )) + } else { + Ok(inference_response.embeddings) + } + } + reqwest::StatusCode::BAD_REQUEST => { + let error_json: Value = serde_json::from_str(response_body).map_err(|e| { + StorageError::service_error(format!( + "Failed to parse error response: {e}. Raw response: {response_body}", + )) + })?; + + if let Some(error_message) = error_json["error"].as_str() { + Err(StorageError::bad_request(format!( + "Inference request validation failed: {error_message}", + ))) + } else { + Err(StorageError::bad_request(format!( + "Invalid inference request: {response_body}", + ))) + } + } + status @ (reqwest::StatusCode::UNAUTHORIZED | reqwest::StatusCode::FORBIDDEN) => { + Err(StorageError::service_error(format!( + "Authentication failed for inference service ({status}): {response_body}", + ))) + } + status @ (reqwest::StatusCode::INTERNAL_SERVER_ERROR + | reqwest::StatusCode::SERVICE_UNAVAILABLE + | reqwest::StatusCode::GATEWAY_TIMEOUT) => Err(StorageError::service_error(format!( + "Inference service error ({status}): {response_body}", + ))), + _ => Err(StorageError::service_error(format!( + "Unexpected inference service response ({status}): {response_body}" + ))), + } + } +} diff --git a/src/common/inference/update_requests.rs b/src/common/inference/update_requests.rs new file mode 100644 index 0000000000000000000000000000000000000000..83b0f09981b4ea5254c004f5d003a5b4b4441dba --- /dev/null +++ b/src/common/inference/update_requests.rs @@ -0,0 +1,409 @@ +use std::collections::HashMap; + +use api::rest::{Batch, BatchVectorStruct, PointStruct, PointVectors, Vector, VectorStruct}; +use collection::operations::point_ops::{ + BatchPersisted, BatchVectorStructPersisted, PointStructPersisted, VectorPersisted, + VectorStructPersisted, +}; +use collection::operations::vector_ops::PointVectorsPersisted; +use storage::content_manager::errors::StorageError; + +use crate::common::inference::batch_processing::BatchAccum; +use crate::common::inference::infer_processing::BatchAccumInferred; +use crate::common::inference::service::{InferenceData, InferenceType}; + +pub async fn convert_point_struct( + point_structs: Vec, + inference_type: InferenceType, +) -> Result, StorageError> { + let mut batch_accum = BatchAccum::new(); + + for point_struct in &point_structs { + match &point_struct.vector { + VectorStruct::Named(named) => { + for vector in named.values() { + match vector { + Vector::Document(doc) => { + batch_accum.add(InferenceData::Document(doc.clone())) + } + Vector::Image(img) => batch_accum.add(InferenceData::Image(img.clone())), + Vector::Object(obj) => batch_accum.add(InferenceData::Object(obj.clone())), + Vector::Dense(_) | Vector::Sparse(_) | Vector::MultiDense(_) => {} + } + } + } + VectorStruct::Document(doc) => batch_accum.add(InferenceData::Document(doc.clone())), + VectorStruct::Image(img) => batch_accum.add(InferenceData::Image(img.clone())), + VectorStruct::Object(obj) => batch_accum.add(InferenceData::Object(obj.clone())), + VectorStruct::MultiDense(_) | VectorStruct::Single(_) => {} + } + } + + let inferred = if !batch_accum.objects.is_empty() { + Some(BatchAccumInferred::from_batch_accum(batch_accum, inference_type).await?) 
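Editor's note: a brief sketch of how the handler above surfaces a 400 response, assuming it runs inside the crate where `handle_inference_response` is visible; the error text is illustrative:

use reqwest::StatusCode;
use crate::common::inference::service::InferenceService;

#[test]
fn bad_request_is_surfaced_as_validation_error() {
    let body = r#"{ "error": "input text is too long" }"#;
    let result = InferenceService::handle_inference_response(StatusCode::BAD_REQUEST, body);
    // A 400 with an "error" field maps to StorageError::bad_request(...).
    assert!(result.is_err());
}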
+ } else { + None + }; + + let mut converted_points: Vec = Vec::new(); + for point_struct in point_structs { + let PointStruct { + id, + vector, + payload, + } = point_struct; + + let converted_vector_struct = match vector { + VectorStruct::Single(single) => VectorStructPersisted::Single(single), + VectorStruct::MultiDense(multi) => VectorStructPersisted::MultiDense(multi), + VectorStruct::Named(named) => { + let mut named_vectors = HashMap::new(); + for (name, vector) in named { + let converted_vector = match &inferred { + Some(inferred) => convert_vector_with_inferred(vector, inferred)?, + None => match vector { + Vector::Dense(dense) => VectorPersisted::Dense(dense), + Vector::Sparse(sparse) => VectorPersisted::Sparse(sparse), + Vector::MultiDense(multi) => VectorPersisted::MultiDense(multi), + Vector::Document(_) | Vector::Image(_) | Vector::Object(_) => { + return Err(StorageError::inference_error( + "Inference required but service returned no results", + )) + } + }, + }; + named_vectors.insert(name, converted_vector); + } + VectorStructPersisted::Named(named_vectors) + } + VectorStruct::Document(doc) => { + let vector = match &inferred { + Some(inferred) => { + convert_vector_with_inferred(Vector::Document(doc), inferred)? + } + None => { + return Err(StorageError::inference_error( + "Inference required but service returned no results", + )) + } + }; + match vector { + VectorPersisted::Dense(dense) => VectorStructPersisted::Single(dense), + VectorPersisted::Sparse(_) => { + return Err(StorageError::bad_request("Sparse vector should be named")); + } + VectorPersisted::MultiDense(multi) => VectorStructPersisted::MultiDense(multi), + } + } + VectorStruct::Image(img) => { + let vector = match &inferred { + Some(inferred) => convert_vector_with_inferred(Vector::Image(img), inferred)?, + None => { + return Err(StorageError::inference_error( + "Inference required but service returned no results", + )) + } + }; + match vector { + VectorPersisted::Dense(dense) => VectorStructPersisted::Single(dense), + VectorPersisted::Sparse(_) => { + return Err(StorageError::bad_request("Sparse vector should be named")); + } + VectorPersisted::MultiDense(multi) => VectorStructPersisted::MultiDense(multi), + } + } + VectorStruct::Object(obj) => { + let vector = match &inferred { + Some(inferred) => convert_vector_with_inferred(Vector::Object(obj), inferred)?, + None => { + return Err(StorageError::inference_error( + "Inference required but service returned no results", + )) + } + }; + match vector { + VectorPersisted::Dense(dense) => VectorStructPersisted::Single(dense), + VectorPersisted::Sparse(_) => { + return Err(StorageError::bad_request("Sparse vector should be named")); + } + VectorPersisted::MultiDense(multi) => VectorStructPersisted::MultiDense(multi), + } + } + }; + + let converted = PointStructPersisted { + id, + vector: converted_vector_struct, + payload, + }; + + converted_points.push(converted); + } + + Ok(converted_points) +} + +pub async fn convert_batch(batch: Batch) -> Result { + let Batch { + ids, + vectors, + payloads, + } = batch; + + let batch_persisted = BatchPersisted { + ids, + vectors: match vectors { + BatchVectorStruct::Single(single) => BatchVectorStructPersisted::Single(single), + BatchVectorStruct::MultiDense(multi) => BatchVectorStructPersisted::MultiDense(multi), + BatchVectorStruct::Named(named) => { + let mut named_vectors = HashMap::new(); + + for (name, vectors) in named { + let converted_vectors = convert_vectors(vectors, InferenceType::Update).await?; + 
named_vectors.insert(name, converted_vectors); + } + + BatchVectorStructPersisted::Named(named_vectors) + } + BatchVectorStruct::Document(_) => { + return Err(StorageError::inference_error( + "Document processing is not supported in batch operations.", + )) + } + BatchVectorStruct::Image(_) => { + return Err(StorageError::inference_error( + "Image processing is not supported in batch operations.", + )) + } + BatchVectorStruct::Object(_) => { + return Err(StorageError::inference_error( + "Object processing is not supported in batch operations.", + )) + } + }, + payloads, + }; + + Ok(batch_persisted) +} + +pub async fn convert_point_vectors( + point_vectors_list: Vec, + inference_type: InferenceType, +) -> Result, StorageError> { + let mut converted_point_vectors = Vec::new(); + let mut batch_accum = BatchAccum::new(); + + for point_vectors in &point_vectors_list { + if let VectorStruct::Named(named) = &point_vectors.vector { + for vector in named.values() { + match vector { + Vector::Document(doc) => batch_accum.add(InferenceData::Document(doc.clone())), + Vector::Image(img) => batch_accum.add(InferenceData::Image(img.clone())), + Vector::Object(obj) => batch_accum.add(InferenceData::Object(obj.clone())), + Vector::Dense(_) | Vector::Sparse(_) | Vector::MultiDense(_) => {} + } + } + } + } + + let inferred = if !batch_accum.objects.is_empty() { + Some(BatchAccumInferred::from_batch_accum(batch_accum, inference_type).await?) + } else { + None + }; + + for point_vectors in point_vectors_list { + let PointVectors { id, vector } = point_vectors; + + let converted_vector = match vector { + VectorStruct::Single(dense) => VectorStructPersisted::Single(dense), + VectorStruct::MultiDense(multi) => VectorStructPersisted::MultiDense(multi), + VectorStruct::Named(named) => { + let mut converted = HashMap::new(); + + for (name, vec) in named { + let converted_vec = match &inferred { + Some(inferred) => convert_vector_with_inferred(vec, inferred)?, + None => match vec { + Vector::Dense(dense) => VectorPersisted::Dense(dense), + Vector::Sparse(sparse) => VectorPersisted::Sparse(sparse), + Vector::MultiDense(multi) => VectorPersisted::MultiDense(multi), + Vector::Document(_) | Vector::Image(_) | Vector::Object(_) => { + return Err(StorageError::inference_error( + "Inference required but service returned no results", + )) + } + }, + }; + converted.insert(name, converted_vec); + } + + VectorStructPersisted::Named(converted) + } + VectorStruct::Document(_) => { + return Err(StorageError::inference_error( + "Document processing is not supported for point vectors.", + )) + } + VectorStruct::Image(_) => { + return Err(StorageError::inference_error( + "Image processing is not supported for point vectors.", + )) + } + VectorStruct::Object(_) => { + return Err(StorageError::inference_error( + "Object processing is not supported for point vectors.", + )) + } + }; + + let converted_point_vector = PointVectorsPersisted { + id, + vector: converted_vector, + }; + + converted_point_vectors.push(converted_point_vector); + } + + Ok(converted_point_vectors) +} + +fn convert_point_struct_with_inferred( + point_structs: Vec, + inferred: &BatchAccumInferred, +) -> Result, StorageError> { + point_structs + .into_iter() + .map(|point_struct| { + let PointStruct { + id, + vector, + payload, + } = point_struct; + let converted_vector_struct = match vector { + VectorStruct::Single(single) => VectorStructPersisted::Single(single), + VectorStruct::MultiDense(multi) => VectorStructPersisted::MultiDense(multi), + 
VectorStruct::Named(named) => { + let mut named_vectors = HashMap::new(); + for (name, vector) in named { + let converted_vector = convert_vector_with_inferred(vector, inferred)?; + named_vectors.insert(name, converted_vector); + } + VectorStructPersisted::Named(named_vectors) + } + VectorStruct::Document(doc) => { + let vector = convert_vector_with_inferred(Vector::Document(doc), inferred)?; + match vector { + VectorPersisted::Dense(dense) => VectorStructPersisted::Single(dense), + VectorPersisted::Sparse(_) => { + return Err(StorageError::bad_request("Sparse vector should be named")) + } + VectorPersisted::MultiDense(multi) => { + VectorStructPersisted::MultiDense(multi) + } + } + } + VectorStruct::Image(img) => { + let vector = convert_vector_with_inferred(Vector::Image(img), inferred)?; + match vector { + VectorPersisted::Dense(dense) => VectorStructPersisted::Single(dense), + VectorPersisted::Sparse(_) => { + return Err(StorageError::bad_request("Sparse vector should be named")) + } + VectorPersisted::MultiDense(multi) => { + VectorStructPersisted::MultiDense(multi) + } + } + } + VectorStruct::Object(obj) => { + let vector = convert_vector_with_inferred(Vector::Object(obj), inferred)?; + match vector { + VectorPersisted::Dense(dense) => VectorStructPersisted::Single(dense), + VectorPersisted::Sparse(_) => { + return Err(StorageError::bad_request("Sparse vector should be named")) + } + VectorPersisted::MultiDense(multi) => { + VectorStructPersisted::MultiDense(multi) + } + } + } + }; + + Ok(PointStructPersisted { + id, + vector: converted_vector_struct, + payload, + }) + }) + .collect() +} + +pub async fn convert_vectors( + vectors: Vec, + inference_type: InferenceType, +) -> Result, StorageError> { + let mut batch_accum = BatchAccum::new(); + for vector in &vectors { + match vector { + Vector::Document(doc) => batch_accum.add(InferenceData::Document(doc.clone())), + Vector::Image(img) => batch_accum.add(InferenceData::Image(img.clone())), + Vector::Object(obj) => batch_accum.add(InferenceData::Object(obj.clone())), + Vector::Dense(_) | Vector::Sparse(_) | Vector::MultiDense(_) => {} + } + } + + let inferred = if !batch_accum.objects.is_empty() { + Some(BatchAccumInferred::from_batch_accum(batch_accum, inference_type).await?) 
+ } else { + None + }; + + vectors + .into_iter() + .map(|vector| match &inferred { + Some(inferred) => convert_vector_with_inferred(vector, inferred), + None => match vector { + Vector::Dense(dense) => Ok(VectorPersisted::Dense(dense)), + Vector::Sparse(sparse) => Ok(VectorPersisted::Sparse(sparse)), + Vector::MultiDense(multi) => Ok(VectorPersisted::MultiDense(multi)), + Vector::Document(_) | Vector::Image(_) | Vector::Object(_) => { + Err(StorageError::inference_error( + "Inference required but service returned no results", + )) + } + }, + }) + .collect() +} + +fn convert_vector_with_inferred( + vector: Vector, + inferred: &BatchAccumInferred, +) -> Result { + match vector { + Vector::Dense(dense) => Ok(VectorPersisted::Dense(dense)), + Vector::Sparse(sparse) => Ok(VectorPersisted::Sparse(sparse)), + Vector::MultiDense(multi) => Ok(VectorPersisted::MultiDense(multi)), + Vector::Document(doc) => { + let data = InferenceData::Document(doc); + inferred.get_vector(&data).cloned().ok_or_else(|| { + StorageError::inference_error("Missing inferred vector for document") + }) + } + Vector::Image(img) => { + let data = InferenceData::Image(img); + inferred + .get_vector(&data) + .cloned() + .ok_or_else(|| StorageError::inference_error("Missing inferred vector for image")) + } + Vector::Object(obj) => { + let data = InferenceData::Object(obj); + inferred + .get_vector(&data) + .cloned() + .ok_or_else(|| StorageError::inference_error("Missing inferred vector for object")) + } + } +} diff --git a/src/common/metrics.rs b/src/common/metrics.rs new file mode 100644 index 0000000000000000000000000000000000000000..378ea92f7bcf7060f928f94af66827d3889a7bdf --- /dev/null +++ b/src/common/metrics.rs @@ -0,0 +1,505 @@ +use prometheus::proto::{Counter, Gauge, LabelPair, Metric, MetricFamily, MetricType}; +use prometheus::TextEncoder; +use segment::common::operation_time_statistics::OperationDurationStatistics; + +use crate::common::telemetry::TelemetryData; +use crate::common::telemetry_ops::app_telemetry::{AppBuildTelemetry, AppFeaturesTelemetry}; +use crate::common::telemetry_ops::cluster_telemetry::{ClusterStatusTelemetry, ClusterTelemetry}; +use crate::common::telemetry_ops::collections_telemetry::{ + CollectionTelemetryEnum, CollectionsTelemetry, +}; +use crate::common::telemetry_ops::memory_telemetry::MemoryTelemetry; +use crate::common::telemetry_ops::requests_telemetry::{ + GrpcTelemetry, RequestsTelemetry, WebApiTelemetry, +}; + +/// Whitelist for REST endpoints in metrics output. +/// +/// Contains selection of search, recommend, scroll and upsert endpoints. +/// +/// This array *must* be sorted. 
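The "must be sorted" requirement exists because the whitelists are probed with `slice::binary_search` when filtering telemetry endpoints (see the `WebApiTelemetry`/`GrpcTelemetry` impls and the `test_endpoint_whitelists_sorted` test further down in this file). A minimal standalone sketch of that lookup, using a hypothetical abbreviated whitelist rather than the real one:

```rust
fn main() {
    // Hypothetical, abbreviated whitelist for illustration only.
    const WHITELIST: &[&str] = &[
        "/collections/{name}/points",
        "/collections/{name}/points/query",
        "/collections/{name}/points/search",
    ];

    // The same sortedness check as the unit test at the bottom of metrics.rs:
    // binary search is only correct on a sorted slice.
    assert!(WHITELIST.windows(2).all(|w| w[0] <= w[1]));

    // O(log n) membership test, as used to decide whether an endpoint's
    // statistics are reported at all.
    assert!(WHITELIST.binary_search(&"/collections/{name}/points/search").is_ok());
    assert!(WHITELIST.binary_search(&"/collections/{name}/unknown").is_err());
}
```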
+const REST_ENDPOINT_WHITELIST: &[&str] = &[ + "/collections/{name}/index", + "/collections/{name}/points", + "/collections/{name}/points/batch", + "/collections/{name}/points/count", + "/collections/{name}/points/delete", + "/collections/{name}/points/discover", + "/collections/{name}/points/discover/batch", + "/collections/{name}/points/facet", + "/collections/{name}/points/payload", + "/collections/{name}/points/payload/clear", + "/collections/{name}/points/payload/delete", + "/collections/{name}/points/query", + "/collections/{name}/points/query/batch", + "/collections/{name}/points/query/groups", + "/collections/{name}/points/recommend", + "/collections/{name}/points/recommend/batch", + "/collections/{name}/points/recommend/groups", + "/collections/{name}/points/scroll", + "/collections/{name}/points/search", + "/collections/{name}/points/search/batch", + "/collections/{name}/points/search/groups", + "/collections/{name}/points/search/matrix/offsets", + "/collections/{name}/points/search/matrix/pairs", + "/collections/{name}/points/vectors", + "/collections/{name}/points/vectors/delete", +]; + +/// Whitelist for GRPC endpoints in metrics output. +/// +/// Contains selection of search, recommend, scroll and upsert endpoints. +/// +/// This array *must* be sorted. +const GRPC_ENDPOINT_WHITELIST: &[&str] = &[ + "/qdrant.Points/ClearPayload", + "/qdrant.Points/Count", + "/qdrant.Points/Delete", + "/qdrant.Points/DeletePayload", + "/qdrant.Points/Discover", + "/qdrant.Points/DiscoverBatch", + "/qdrant.Points/Facet", + "/qdrant.Points/Get", + "/qdrant.Points/OverwritePayload", + "/qdrant.Points/Query", + "/qdrant.Points/QueryBatch", + "/qdrant.Points/QueryGroups", + "/qdrant.Points/Recommend", + "/qdrant.Points/RecommendBatch", + "/qdrant.Points/RecommendGroups", + "/qdrant.Points/Scroll", + "/qdrant.Points/Search", + "/qdrant.Points/SearchBatch", + "/qdrant.Points/SearchGroups", + "/qdrant.Points/SetPayload", + "/qdrant.Points/UpdateBatch", + "/qdrant.Points/UpdateVectors", + "/qdrant.Points/Upsert", +]; + +/// For REST requests, only report timings when having this HTTP response status. +const REST_TIMINGS_FOR_STATUS: u16 = 200; + +/// Encapsulates metrics data in Prometheus format. +pub struct MetricsData { + metrics: Vec, +} + +impl MetricsData { + pub fn format_metrics(&self) -> String { + TextEncoder::new().encode_to_string(&self.metrics).unwrap() + } +} + +impl From for MetricsData { + fn from(telemetry_data: TelemetryData) -> Self { + let mut metrics = vec![]; + telemetry_data.add_metrics(&mut metrics); + Self { metrics } + } +} + +trait MetricsProvider { + /// Add metrics definitions for this. 
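// Every telemetry sub-structure below implements this trait and appends its
// own `MetricFamily` entries to the shared `metrics` buffer; the top-level
// `TelemetryData` impl fans out to `app`, `collections`, `cluster`, `requests`
// and the optional `memory` section, and `MetricsData::format_metrics` then
// serializes the accumulated families with `TextEncoder`.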
+ fn add_metrics(&self, metrics: &mut Vec); +} + +impl MetricsProvider for TelemetryData { + fn add_metrics(&self, metrics: &mut Vec) { + self.app.add_metrics(metrics); + self.collections.add_metrics(metrics); + self.cluster.add_metrics(metrics); + self.requests.add_metrics(metrics); + if let Some(mem) = &self.memory { + mem.add_metrics(metrics); + } + } +} + +impl MetricsProvider for AppBuildTelemetry { + fn add_metrics(&self, metrics: &mut Vec) { + metrics.push(metric_family( + "app_info", + "information about qdrant server", + MetricType::GAUGE, + vec![gauge( + 1.0, + &[("name", &self.name), ("version", &self.version)], + )], + )); + self.features.iter().for_each(|f| f.add_metrics(metrics)); + } +} + +impl MetricsProvider for AppFeaturesTelemetry { + fn add_metrics(&self, metrics: &mut Vec) { + metrics.push(metric_family( + "app_status_recovery_mode", + "features enabled in qdrant server", + MetricType::GAUGE, + vec![gauge(if self.recovery_mode { 1.0 } else { 0.0 }, &[])], + )) + } +} + +impl MetricsProvider for CollectionsTelemetry { + fn add_metrics(&self, metrics: &mut Vec) { + let vector_count = self + .collections + .iter() + .flatten() + .map(|p| match p { + CollectionTelemetryEnum::Aggregated(a) => a.vectors, + CollectionTelemetryEnum::Full(c) => c.count_vectors(), + }) + .sum::(); + metrics.push(metric_family( + "collections_total", + "number of collections", + MetricType::GAUGE, + vec![gauge(self.number_of_collections as f64, &[])], + )); + metrics.push(metric_family( + "collections_vector_total", + "total number of vectors in all collections", + MetricType::GAUGE, + vec![gauge(vector_count as f64, &[])], + )); + } +} + +impl MetricsProvider for ClusterTelemetry { + fn add_metrics(&self, metrics: &mut Vec) { + let ClusterTelemetry { + enabled, + status, + config: _, + peers: _, + metadata: _, + } = self; + + metrics.push(metric_family( + "cluster_enabled", + "is cluster support enabled", + MetricType::GAUGE, + vec![gauge(if *enabled { 1.0 } else { 0.0 }, &[])], + )); + + if let Some(ref status) = status { + status.add_metrics(metrics); + } + } +} + +impl MetricsProvider for ClusterStatusTelemetry { + fn add_metrics(&self, metrics: &mut Vec) { + metrics.push(metric_family( + "cluster_peers_total", + "total number of cluster peers", + MetricType::GAUGE, + vec![gauge(self.number_of_peers as f64, &[])], + )); + metrics.push(metric_family( + "cluster_term", + "current cluster term", + MetricType::COUNTER, + vec![counter(self.term as f64, &[])], + )); + + if let Some(ref peer_id) = self.peer_id.map(|p| p.to_string()) { + metrics.push(metric_family( + "cluster_commit", + "index of last committed (finalized) operation cluster peer is aware of", + MetricType::COUNTER, + vec![counter(self.commit as f64, &[("peer_id", peer_id)])], + )); + metrics.push(metric_family( + "cluster_pending_operations_total", + "total number of pending operations for cluster peer", + MetricType::GAUGE, + vec![gauge(self.pending_operations as f64, &[])], + )); + metrics.push(metric_family( + "cluster_voter", + "is cluster peer a voter or learner", + MetricType::GAUGE, + vec![gauge(if self.is_voter { 1.0 } else { 0.0 }, &[])], + )); + } + } +} + +impl MetricsProvider for RequestsTelemetry { + fn add_metrics(&self, metrics: &mut Vec) { + self.rest.add_metrics(metrics); + self.grpc.add_metrics(metrics); + } +} + +impl MetricsProvider for WebApiTelemetry { + fn add_metrics(&self, metrics: &mut Vec) { + let mut builder = OperationDurationMetricsBuilder::default(); + for (endpoint, responses) in &self.responses { + 
let Some((method, endpoint)) = endpoint.split_once(' ') else { + continue; + }; + // Endpoint must be whitelisted + if REST_ENDPOINT_WHITELIST.binary_search(&endpoint).is_err() { + continue; + } + for (status, stats) in responses { + builder.add( + stats, + &[ + ("method", method), + ("endpoint", endpoint), + ("status", &status.to_string()), + ], + *status == REST_TIMINGS_FOR_STATUS, + ); + } + } + builder.build("rest", metrics); + } +} + +impl MetricsProvider for GrpcTelemetry { + fn add_metrics(&self, metrics: &mut Vec) { + let mut builder = OperationDurationMetricsBuilder::default(); + for (endpoint, stats) in &self.responses { + // Endpoint must be whitelisted + if GRPC_ENDPOINT_WHITELIST + .binary_search(&endpoint.as_str()) + .is_err() + { + continue; + } + builder.add(stats, &[("endpoint", endpoint.as_str())], true); + } + builder.build("grpc", metrics); + } +} + +impl MetricsProvider for MemoryTelemetry { + fn add_metrics(&self, metrics: &mut Vec) { + metrics.push(metric_family( + "memory_active_bytes", + "Total number of bytes in active pages allocated by the application", + MetricType::GAUGE, + vec![gauge(self.active_bytes as f64, &[])], + )); + metrics.push(metric_family( + "memory_allocated_bytes", + "Total number of bytes allocated by the application", + MetricType::GAUGE, + vec![gauge(self.allocated_bytes as f64, &[])], + )); + metrics.push(metric_family( + "memory_metadata_bytes", + "Total number of bytes dedicated to metadata", + MetricType::GAUGE, + vec![gauge(self.metadata_bytes as f64, &[])], + )); + metrics.push(metric_family( + "memory_resident_bytes", + "Maximum number of bytes in physically resident data pages mapped", + MetricType::GAUGE, + vec![gauge(self.resident_bytes as f64, &[])], + )); + metrics.push(metric_family( + "memory_retained_bytes", + "Total number of bytes in virtual memory mappings", + MetricType::GAUGE, + vec![gauge(self.retained_bytes as f64, &[])], + )); + } +} + +/// A helper struct to build a vector of [`MetricFamily`] out of a collection of +/// [`OperationDurationStatistics`]. +#[derive(Default)] +struct OperationDurationMetricsBuilder { + total: Vec, + fail_total: Vec, + avg_secs: Vec, + min_secs: Vec, + max_secs: Vec, + duration_histogram_secs: Vec, +} + +impl OperationDurationMetricsBuilder { + /// Add metrics for the provided statistics. + /// If `add_timings` is `false`, only the total and fail_total counters will be added. + pub fn add( + &mut self, + stat: &OperationDurationStatistics, + labels: &[(&str, &str)], + add_timings: bool, + ) { + self.total.push(counter(stat.count as f64, labels)); + self.fail_total + .push(counter(stat.fail_count as f64, labels)); + + if !add_timings { + return; + } + + self.avg_secs.push(gauge( + f64::from(stat.avg_duration_micros.unwrap_or(0.0)) / 1_000_000.0, + labels, + )); + self.min_secs.push(gauge( + f64::from(stat.min_duration_micros.unwrap_or(0.0)) / 1_000_000.0, + labels, + )); + self.max_secs.push(gauge( + f64::from(stat.max_duration_micros.unwrap_or(0.0)) / 1_000_000.0, + labels, + )); + self.duration_histogram_secs.push(histogram( + stat.count as u64, + stat.total_duration_micros as f64 / 1_000_000.0, + &stat + .duration_micros_histogram + .iter() + .map(|&(b, c)| (f64::from(b) / 1_000_000.0, c as u64)) + .collect::>(), + labels, + )); + } + + /// Build metrics and add them to the provided vector. 
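`OperationDurationMetricsBuilder::add` above keeps timings in microseconds but exports them in seconds, and its histogram buckets are cumulative `(upper bound, count)` pairs. A small standalone sketch of those conversions, using hypothetical values only:

```rust
fn main() {
    // Microseconds to seconds, as done for the avg/min/max gauges:
    // e.g. an average duration of 2_500 µs is exported as 0.0025 s.
    let avg_micros = 2_500.0_f64;
    assert_eq!(avg_micros / 1_000_000.0, 0.0025);

    // Cumulative histogram buckets: each entry is (upper bound in seconds,
    // number of samples at or below that bound), mirroring how
    // `duration_micros_histogram` is converted before being pushed.
    let buckets_micros: &[(u32, usize)] = &[(1_000, 7), (10_000, 9), (100_000, 10)];
    let buckets_secs: Vec<(f64, u64)> = buckets_micros
        .iter()
        .map(|&(b, c)| (f64::from(b) / 1_000_000.0, c as u64))
        .collect();
    assert_eq!(buckets_secs[0], (0.001, 7));
}
```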
+ pub fn build(self, prefix: &str, metrics: &mut Vec) { + if !self.total.is_empty() { + metrics.push(metric_family( + &format!("{prefix}_responses_total"), + "total number of responses", + MetricType::COUNTER, + self.total, + )); + } + if !self.fail_total.is_empty() { + metrics.push(metric_family( + &format!("{prefix}_responses_fail_total"), + "total number of failed responses", + MetricType::COUNTER, + self.fail_total, + )); + } + if !self.avg_secs.is_empty() { + metrics.push(metric_family( + &format!("{prefix}_responses_avg_duration_seconds"), + "average response duration", + MetricType::GAUGE, + self.avg_secs, + )); + } + if !self.min_secs.is_empty() { + metrics.push(metric_family( + &format!("{prefix}_responses_min_duration_seconds"), + "minimum response duration", + MetricType::GAUGE, + self.min_secs, + )); + } + if !self.max_secs.is_empty() { + metrics.push(metric_family( + &format!("{prefix}_responses_max_duration_seconds"), + "maximum response duration", + MetricType::GAUGE, + self.max_secs, + )); + } + if !self.duration_histogram_secs.is_empty() { + metrics.push(metric_family( + &format!("{prefix}_responses_duration_seconds"), + "response duration histogram", + MetricType::HISTOGRAM, + self.duration_histogram_secs, + )); + } + } +} + +fn metric_family(name: &str, help: &str, r#type: MetricType, metrics: Vec) -> MetricFamily { + let mut metric_family = MetricFamily::default(); + metric_family.set_name(name.into()); + metric_family.set_help(help.into()); + metric_family.set_field_type(r#type); + metric_family.set_metric(metrics); + metric_family +} + +fn counter(value: f64, labels: &[(&str, &str)]) -> Metric { + let mut metric = Metric::default(); + metric.set_label(labels.iter().map(|(n, v)| label_pair(n, v)).collect()); + metric.set_counter({ + let mut counter = Counter::default(); + counter.set_value(value); + counter + }); + metric +} + +fn gauge(value: f64, labels: &[(&str, &str)]) -> Metric { + let mut metric = Metric::default(); + metric.set_label(labels.iter().map(|(n, v)| label_pair(n, v)).collect()); + metric.set_gauge({ + let mut gauge = Gauge::default(); + gauge.set_value(value); + gauge + }); + metric +} + +fn histogram( + sample_count: u64, + sample_sum: f64, + buckets: &[(f64, u64)], + labels: &[(&str, &str)], +) -> Metric { + let mut metric = Metric::default(); + metric.set_label(labels.iter().map(|(n, v)| label_pair(n, v)).collect()); + metric.set_histogram({ + let mut histogram = prometheus::proto::Histogram::default(); + histogram.set_sample_count(sample_count); + histogram.set_sample_sum(sample_sum); + histogram.set_bucket( + buckets + .iter() + .map(|&(upper_bound, cumulative_count)| { + let mut bucket = prometheus::proto::Bucket::default(); + bucket.set_cumulative_count(cumulative_count); + bucket.set_upper_bound(upper_bound); + bucket + }) + .collect(), + ); + histogram + }); + metric +} + +fn label_pair(name: &str, value: &str) -> LabelPair { + let mut label = LabelPair::default(); + label.set_name(name.into()); + label.set_value(value.into()); + label +} + +#[cfg(test)] +mod tests { + #[test] + fn test_endpoint_whitelists_sorted() { + use super::{GRPC_ENDPOINT_WHITELIST, REST_ENDPOINT_WHITELIST}; + + assert!( + REST_ENDPOINT_WHITELIST.windows(2).all(|n| n[0] <= n[1]), + "REST_ENDPOINT_WHITELIST must be sorted in code to allow binary search" + ); + assert!( + GRPC_ENDPOINT_WHITELIST.windows(2).all(|n| n[0] <= n[1]), + "GRPC_ENDPOINT_WHITELIST must be sorted in code to allow binary search" + ); + } +} diff --git a/src/common/mod.rs b/src/common/mod.rs new file 
mode 100644 index 0000000000000000000000000000000000000000..320198ef2cf550a5c3ece497791d261d73f26aa3 --- /dev/null +++ b/src/common/mod.rs @@ -0,0 +1,31 @@ +#[allow(dead_code)] // May contain functions used in different binaries. Not actually dead +pub mod collections; +#[allow(dead_code)] // May contain functions used in different binaries. Not actually dead +pub mod error_reporting; +#[allow(dead_code)] +pub mod health; +#[allow(dead_code)] // May contain functions used in different binaries. Not actually dead +pub mod helpers; +pub mod http_client; +pub mod metrics; +#[allow(dead_code)] // May contain functions used in different binaries. Not actually dead +pub mod points; +pub mod snapshots; +#[allow(dead_code)] // May contain functions used in different binaries. Not actually dead +pub mod stacktrace; +#[allow(dead_code)] // May contain functions used in different binaries. Not actually dead +pub mod telemetry; +pub mod telemetry_ops; +#[allow(dead_code)] // May contain functions used in different binaries. Not actually dead +pub mod telemetry_reporting; + +pub mod auth; + +pub mod strings; + +pub mod debugger; + +#[allow(dead_code)] // May contain functions used in different binaries. Not actually dead +pub mod inference; + +pub mod pyroscope_state; diff --git a/src/common/points.rs b/src/common/points.rs new file mode 100644 index 0000000000000000000000000000000000000000..3504f6c2ecce822b88088dd7dcf7b5a97c3aec6b --- /dev/null +++ b/src/common/points.rs @@ -0,0 +1,1175 @@ +use std::sync::Arc; +use std::time::Duration; + +use api::rest::schema::{PointInsertOperations, PointsBatch, PointsList}; +use api::rest::{SearchGroupsRequestInternal, ShardKeySelector, UpdateVectors}; +use collection::collection::distance_matrix::{ + CollectionSearchMatrixRequest, CollectionSearchMatrixResponse, +}; +use collection::collection::Collection; +use collection::common::batching::batch_requests; +use collection::grouping::group_by::GroupRequest; +use collection::operations::consistency_params::ReadConsistency; +use collection::operations::payload_ops::{ + DeletePayload, DeletePayloadOp, PayloadOps, SetPayload, SetPayloadOp, +}; +use collection::operations::point_ops::{ + FilterSelector, PointIdsList, PointInsertOperationsInternal, PointOperations, PointsSelector, + WriteOrdering, +}; +use collection::operations::shard_selector_internal::ShardSelectorInternal; +use collection::operations::types::{ + CollectionError, CoreSearchRequest, CoreSearchRequestBatch, CountRequestInternal, CountResult, + DiscoverRequestBatch, GroupsResult, PointRequestInternal, RecommendGroupsRequestInternal, + RecordInternal, ScrollRequestInternal, ScrollResult, UpdateResult, +}; +use collection::operations::universal_query::collection_query::{ + CollectionQueryGroupsRequest, CollectionQueryRequest, +}; +use collection::operations::vector_ops::{DeleteVectors, UpdateVectorsOp, VectorOperations}; +use collection::operations::verification::{ + new_unchecked_verification_pass, StrictModeVerification, +}; +use collection::operations::{ + ClockTag, CollectionUpdateOperations, CreateIndex, FieldIndexOperations, OperationWithClockTag, +}; +use collection::shards::shard::ShardId; +use common::counter::hardware_accumulator::HwMeasurementAcc; +use schemars::JsonSchema; +use segment::json_path::JsonPath; +use segment::types::{PayloadFieldSchema, PayloadKeyType, ScoredPoint, StrictModeConfig}; +use serde::{Deserialize, Serialize}; +use storage::content_manager::collection_meta_ops::{ + CollectionMetaOperations, CreatePayloadIndex, 
DropPayloadIndex, +}; +use storage::content_manager::errors::StorageError; +use storage::content_manager::toc::TableOfContent; +use storage::dispatcher::Dispatcher; +use storage::rbac::Access; +use validator::Validate; + +use crate::common::inference::service::InferenceType; +use crate::common::inference::update_requests::{ + convert_batch, convert_point_struct, convert_point_vectors, +}; + +#[derive(Debug, Deserialize, Serialize, JsonSchema, Validate)] +pub struct CreateFieldIndex { + pub field_name: PayloadKeyType, + #[serde(alias = "field_type")] + pub field_schema: Option, +} + +#[derive(Deserialize, Serialize, JsonSchema, Validate)] +pub struct UpsertOperation { + #[validate(nested)] + upsert: PointInsertOperations, +} + +#[derive(Deserialize, Serialize, JsonSchema, Validate)] +pub struct DeleteOperation { + #[validate(nested)] + delete: PointsSelector, +} + +#[derive(Deserialize, Serialize, JsonSchema, Validate)] +pub struct SetPayloadOperation { + #[validate(nested)] + set_payload: SetPayload, +} + +#[derive(Deserialize, Serialize, JsonSchema, Validate)] +pub struct OverwritePayloadOperation { + #[validate(nested)] + overwrite_payload: SetPayload, +} + +#[derive(Deserialize, Serialize, JsonSchema, Validate)] +pub struct DeletePayloadOperation { + #[validate(nested)] + delete_payload: DeletePayload, +} + +#[derive(Deserialize, Serialize, JsonSchema, Validate)] +pub struct ClearPayloadOperation { + #[validate(nested)] + clear_payload: PointsSelector, +} + +#[derive(Deserialize, Serialize, JsonSchema, Validate)] +pub struct UpdateVectorsOperation { + #[validate(nested)] + update_vectors: UpdateVectors, +} + +#[derive(Deserialize, Serialize, JsonSchema, Validate)] +pub struct DeleteVectorsOperation { + #[validate(nested)] + delete_vectors: DeleteVectors, +} + +#[derive(Deserialize, Serialize, JsonSchema)] +#[serde(rename_all = "snake_case")] +#[serde(untagged)] +pub enum UpdateOperation { + Upsert(UpsertOperation), + Delete(DeleteOperation), + SetPayload(SetPayloadOperation), + OverwritePayload(OverwritePayloadOperation), + DeletePayload(DeletePayloadOperation), + ClearPayload(ClearPayloadOperation), + UpdateVectors(UpdateVectorsOperation), + DeleteVectors(DeleteVectorsOperation), +} + +#[derive(Deserialize, Serialize, JsonSchema, Validate)] +pub struct UpdateOperations { + pub operations: Vec, +} + +impl Validate for UpdateOperation { + fn validate(&self) -> Result<(), validator::ValidationErrors> { + match self { + UpdateOperation::Upsert(op) => op.validate(), + UpdateOperation::Delete(op) => op.validate(), + UpdateOperation::SetPayload(op) => op.validate(), + UpdateOperation::OverwritePayload(op) => op.validate(), + UpdateOperation::DeletePayload(op) => op.validate(), + UpdateOperation::ClearPayload(op) => op.validate(), + UpdateOperation::UpdateVectors(op) => op.validate(), + UpdateOperation::DeleteVectors(op) => op.validate(), + } + } +} + +impl StrictModeVerification for UpdateOperation { + fn query_limit(&self) -> Option { + None + } + + fn indexed_filter_read(&self) -> Option<&segment::types::Filter> { + None + } + + fn indexed_filter_write(&self) -> Option<&segment::types::Filter> { + None + } + + fn request_exact(&self) -> Option { + None + } + + fn request_search_params(&self) -> Option<&segment::types::SearchParams> { + None + } + + fn check_strict_mode( + &self, + collection: &Collection, + strict_mode_config: &StrictModeConfig, + ) -> Result<(), CollectionError> { + match self { + UpdateOperation::Delete(delete_op) => delete_op + .delete + .check_strict_mode(collection, 
strict_mode_config), + UpdateOperation::SetPayload(set_payload) => set_payload + .set_payload + .check_strict_mode(collection, strict_mode_config), + UpdateOperation::OverwritePayload(overwrite_payload) => overwrite_payload + .overwrite_payload + .check_strict_mode(collection, strict_mode_config), + UpdateOperation::DeletePayload(delete_payload) => delete_payload + .delete_payload + .check_strict_mode(collection, strict_mode_config), + UpdateOperation::ClearPayload(clear_payload) => clear_payload + .clear_payload + .check_strict_mode(collection, strict_mode_config), + UpdateOperation::DeleteVectors(delete_op) => delete_op + .delete_vectors + .check_strict_mode(collection, strict_mode_config), + UpdateOperation::UpdateVectors(_) | UpdateOperation::Upsert(_) => Ok(()), + } + } +} + +/// Converts a pair of parameters into a shard selector +/// suitable for update operations. +/// +/// The key difference from selector for search operations is that +/// empty shard selector in case of update means default shard, +/// while empty shard selector in case of search means all shards. +/// +/// Parameters: +/// - shard_selection: selection of the exact shard ID, always have priority over shard_key +/// - shard_key: selection of the shard key, can be a single key or a list of keys +/// +/// Returns: +/// - ShardSelectorInternal - resolved shard selector +fn get_shard_selector_for_update( + shard_selection: Option, + shard_key: Option, +) -> ShardSelectorInternal { + match (shard_selection, shard_key) { + (Some(shard_selection), None) => ShardSelectorInternal::ShardId(shard_selection), + (Some(shard_selection), Some(_)) => { + debug_assert!( + false, + "Shard selection and shard key are mutually exclusive" + ); + ShardSelectorInternal::ShardId(shard_selection) + } + (None, Some(shard_key)) => ShardSelectorInternal::from(shard_key), + (None, None) => ShardSelectorInternal::Empty, + } +} + +#[allow(clippy::too_many_arguments)] +pub async fn do_upsert_points( + toc: Arc, + collection_name: String, + operation: PointInsertOperations, + clock_tag: Option, + shard_selection: Option, + wait: bool, + ordering: WriteOrdering, + access: Access, +) -> Result { + let (shard_key, operation) = match operation { + PointInsertOperations::PointsBatch(PointsBatch { batch, shard_key }) => ( + shard_key, + PointInsertOperationsInternal::PointsBatch(convert_batch(batch).await?), + ), + PointInsertOperations::PointsList(PointsList { points, shard_key }) => ( + shard_key, + PointInsertOperationsInternal::PointsList( + convert_point_struct(points, InferenceType::Update).await?, + ), + ), + }; + + let collection_operation = + CollectionUpdateOperations::PointOperation(PointOperations::UpsertPoints(operation)); + + let shard_selector = get_shard_selector_for_update(shard_selection, shard_key); + + toc.update( + &collection_name, + OperationWithClockTag::new(collection_operation, clock_tag), + wait, + ordering, + shard_selector, + access, + ) + .await +} + +#[allow(clippy::too_many_arguments)] +pub async fn do_delete_points( + toc: Arc, + collection_name: String, + points: PointsSelector, + clock_tag: Option, + shard_selection: Option, + wait: bool, + ordering: WriteOrdering, + access: Access, +) -> Result { + let (point_operation, shard_key) = match points { + PointsSelector::PointIdsSelector(PointIdsList { points, shard_key }) => { + (PointOperations::DeletePoints { ids: points }, shard_key) + } + PointsSelector::FilterSelector(FilterSelector { filter, shard_key }) => { + (PointOperations::DeletePointsByFilter(filter), 
shard_key) + } + }; + let collection_operation = CollectionUpdateOperations::PointOperation(point_operation); + let shard_selector = get_shard_selector_for_update(shard_selection, shard_key); + + toc.update( + &collection_name, + OperationWithClockTag::new(collection_operation, clock_tag), + wait, + ordering, + shard_selector, + access, + ) + .await +} + +#[allow(clippy::too_many_arguments)] +pub async fn do_update_vectors( + toc: Arc, + collection_name: String, + operation: UpdateVectors, + clock_tag: Option, + shard_selection: Option, + wait: bool, + ordering: WriteOrdering, + access: Access, +) -> Result { + let UpdateVectors { points, shard_key } = operation; + + let persisted_points = convert_point_vectors(points, InferenceType::Update).await?; + + let collection_operation = CollectionUpdateOperations::VectorOperation( + VectorOperations::UpdateVectors(UpdateVectorsOp { + points: persisted_points, + }), + ); + + let shard_selector = get_shard_selector_for_update(shard_selection, shard_key); + + toc.update( + &collection_name, + OperationWithClockTag::new(collection_operation, clock_tag), + wait, + ordering, + shard_selector, + access, + ) + .await +} + +#[allow(clippy::too_many_arguments)] +pub async fn do_delete_vectors( + toc: Arc, + collection_name: String, + operation: DeleteVectors, + clock_tag: Option, + shard_selection: Option, + wait: bool, + ordering: WriteOrdering, + access: Access, +) -> Result { + // TODO: Is this cancel safe!? + + let DeleteVectors { + vector, + filter, + points, + shard_key, + } = operation; + + let vector_names: Vec<_> = vector.into_iter().collect(); + let shard_selector = get_shard_selector_for_update(shard_selection, shard_key); + + let mut result = None; + + if let Some(filter) = filter { + let vectors_operation = + VectorOperations::DeleteVectorsByFilter(filter, vector_names.clone()); + + let collection_operation = CollectionUpdateOperations::VectorOperation(vectors_operation); + + result = Some( + toc.update( + &collection_name, + OperationWithClockTag::new(collection_operation, clock_tag), + wait, + ordering, + shard_selector.clone(), + access.clone(), + ) + .await?, + ); + } + + if let Some(points) = points { + let vectors_operation = VectorOperations::DeleteVectors(points.into(), vector_names); + let collection_operation = CollectionUpdateOperations::VectorOperation(vectors_operation); + result = Some( + toc.update( + &collection_name, + OperationWithClockTag::new(collection_operation, clock_tag), + wait, + ordering, + shard_selector, + access, + ) + .await?, + ); + } + + result.ok_or_else(|| StorageError::bad_request("No filter or points provided")) +} + +#[allow(clippy::too_many_arguments)] +pub async fn do_set_payload( + toc: Arc, + collection_name: String, + operation: SetPayload, + clock_tag: Option, + shard_selection: Option, + wait: bool, + ordering: WriteOrdering, + access: Access, +) -> Result { + let SetPayload { + points, + payload, + filter, + shard_key, + key, + } = operation; + + let collection_operation = + CollectionUpdateOperations::PayloadOperation(PayloadOps::SetPayload(SetPayloadOp { + payload, + points, + filter, + key, + })); + + let shard_selector = get_shard_selector_for_update(shard_selection, shard_key); + + toc.update( + &collection_name, + OperationWithClockTag::new(collection_operation, clock_tag), + wait, + ordering, + shard_selector, + access, + ) + .await +} + +#[allow(clippy::too_many_arguments)] +pub async fn do_overwrite_payload( + toc: Arc, + collection_name: String, + operation: SetPayload, + clock_tag: 
Option, + shard_selection: Option, + wait: bool, + ordering: WriteOrdering, + access: Access, +) -> Result { + let SetPayload { + points, + payload, + filter, + shard_key, + .. + } = operation; + + let collection_operation = + CollectionUpdateOperations::PayloadOperation(PayloadOps::OverwritePayload(SetPayloadOp { + payload, + points, + filter, + // overwrite operation doesn't support payload selector + key: None, + })); + + let shard_selector = get_shard_selector_for_update(shard_selection, shard_key); + + toc.update( + &collection_name, + OperationWithClockTag::new(collection_operation, clock_tag), + wait, + ordering, + shard_selector, + access, + ) + .await +} + +#[allow(clippy::too_many_arguments)] +pub async fn do_delete_payload( + toc: Arc, + collection_name: String, + operation: DeletePayload, + clock_tag: Option, + shard_selection: Option, + wait: bool, + ordering: WriteOrdering, + access: Access, +) -> Result { + let DeletePayload { + keys, + points, + filter, + shard_key, + } = operation; + + let collection_operation = + CollectionUpdateOperations::PayloadOperation(PayloadOps::DeletePayload(DeletePayloadOp { + keys, + points, + filter, + })); + + let shard_selector = get_shard_selector_for_update(shard_selection, shard_key); + + toc.update( + &collection_name, + OperationWithClockTag::new(collection_operation, clock_tag), + wait, + ordering, + shard_selector, + access, + ) + .await +} + +#[allow(clippy::too_many_arguments)] +pub async fn do_clear_payload( + toc: Arc, + collection_name: String, + points: PointsSelector, + clock_tag: Option, + shard_selection: Option, + wait: bool, + ordering: WriteOrdering, + access: Access, +) -> Result { + let (point_operation, shard_key) = match points { + PointsSelector::PointIdsSelector(PointIdsList { points, shard_key }) => { + (PayloadOps::ClearPayload { points }, shard_key) + } + PointsSelector::FilterSelector(FilterSelector { filter, shard_key }) => { + (PayloadOps::ClearPayloadByFilter(filter), shard_key) + } + }; + + let collection_operation = CollectionUpdateOperations::PayloadOperation(point_operation); + + let shard_selector = get_shard_selector_for_update(shard_selection, shard_key); + + toc.update( + &collection_name, + OperationWithClockTag::new(collection_operation, clock_tag), + wait, + ordering, + shard_selector, + access, + ) + .await +} + +#[allow(clippy::too_many_arguments)] +pub async fn do_batch_update_points( + toc: Arc, + collection_name: String, + operations: Vec, + clock_tag: Option, + shard_selection: Option, + wait: bool, + ordering: WriteOrdering, + access: Access, +) -> Result, StorageError> { + let mut results = Vec::with_capacity(operations.len()); + for operation in operations { + let result = match operation { + UpdateOperation::Upsert(operation) => { + do_upsert_points( + toc.clone(), + collection_name.clone(), + operation.upsert, + clock_tag, + shard_selection, + wait, + ordering, + access.clone(), + ) + .await + } + UpdateOperation::Delete(operation) => { + do_delete_points( + toc.clone(), + collection_name.clone(), + operation.delete, + clock_tag, + shard_selection, + wait, + ordering, + access.clone(), + ) + .await + } + UpdateOperation::SetPayload(operation) => { + do_set_payload( + toc.clone(), + collection_name.clone(), + operation.set_payload, + clock_tag, + shard_selection, + wait, + ordering, + access.clone(), + ) + .await + } + UpdateOperation::OverwritePayload(operation) => { + do_overwrite_payload( + toc.clone(), + collection_name.clone(), + operation.overwrite_payload, + clock_tag, + 
shard_selection, + wait, + ordering, + access.clone(), + ) + .await + } + UpdateOperation::DeletePayload(operation) => { + do_delete_payload( + toc.clone(), + collection_name.clone(), + operation.delete_payload, + clock_tag, + shard_selection, + wait, + ordering, + access.clone(), + ) + .await + } + UpdateOperation::ClearPayload(operation) => { + do_clear_payload( + toc.clone(), + collection_name.clone(), + operation.clear_payload, + clock_tag, + shard_selection, + wait, + ordering, + access.clone(), + ) + .await + } + UpdateOperation::UpdateVectors(operation) => { + do_update_vectors( + toc.clone(), + collection_name.clone(), + operation.update_vectors, + clock_tag, + shard_selection, + wait, + ordering, + access.clone(), + ) + .await + } + UpdateOperation::DeleteVectors(operation) => { + do_delete_vectors( + toc.clone(), + collection_name.clone(), + operation.delete_vectors, + clock_tag, + shard_selection, + wait, + ordering, + access.clone(), + ) + .await + } + }?; + results.push(result); + } + Ok(results) +} + +#[allow(clippy::too_many_arguments)] +pub async fn do_create_index_internal( + toc: Arc, + collection_name: String, + field_name: PayloadKeyType, + field_schema: Option, + clock_tag: Option, + shard_selection: Option, + wait: bool, + ordering: WriteOrdering, +) -> Result { + let collection_operation = CollectionUpdateOperations::FieldIndexOperation( + FieldIndexOperations::CreateIndex(CreateIndex { + field_name, + field_schema, + }), + ); + + let shard_selector = if let Some(shard_selection) = shard_selection { + ShardSelectorInternal::ShardId(shard_selection) + } else { + ShardSelectorInternal::All + }; + + toc.update( + &collection_name, + OperationWithClockTag::new(collection_operation, clock_tag), + wait, + ordering, + shard_selector, + Access::full("Internal API"), + ) + .await +} + +#[allow(clippy::too_many_arguments)] +pub async fn do_create_index( + dispatcher: Arc, + collection_name: String, + operation: CreateFieldIndex, + clock_tag: Option, + shard_selection: Option, + wait: bool, + ordering: WriteOrdering, + access: Access, +) -> Result { + // TODO: Is this cancel safe!? + + let Some(field_schema) = operation.field_schema else { + return Err(StorageError::bad_request( + "Can't auto-detect field type, please specify `field_schema` in the request", + )); + }; + + let consensus_op = CollectionMetaOperations::CreatePayloadIndex(CreatePayloadIndex { + collection_name: collection_name.to_string(), + field_name: operation.field_name.clone(), + field_schema: field_schema.clone(), + }); + + // Default consensus timeout will be used + let wait_timeout = None; // ToDo: make it configurable + + // Nothing to verify here. + let pass = new_unchecked_verification_pass(); + + let toc = dispatcher.toc(&access, &pass).clone(); + + // TODO: Is `submit_collection_meta_op` cancel-safe!? Should be, I think?.. 🤔 + dispatcher + .submit_collection_meta_op(consensus_op, access, wait_timeout) + .await?; + + // This function is required as long as we want to maintain interface compatibility + // for `wait` parameter and return type. 
+ // The idea is to migrate from the point-like interface to consensus-like interface in the next few versions + + do_create_index_internal( + toc, + collection_name, + operation.field_name, + Some(field_schema), + clock_tag, + shard_selection, + wait, + ordering, + ) + .await +} + +#[allow(clippy::too_many_arguments)] +pub async fn do_delete_index_internal( + toc: Arc, + collection_name: String, + index_name: JsonPath, + clock_tag: Option, + shard_selection: Option, + wait: bool, + ordering: WriteOrdering, +) -> Result { + let collection_operation = CollectionUpdateOperations::FieldIndexOperation( + FieldIndexOperations::DeleteIndex(index_name), + ); + + let shard_selector = if let Some(shard_selection) = shard_selection { + ShardSelectorInternal::ShardId(shard_selection) + } else { + ShardSelectorInternal::All + }; + + toc.update( + &collection_name, + OperationWithClockTag::new(collection_operation, clock_tag), + wait, + ordering, + shard_selector, + Access::full("Internal API"), + ) + .await +} + +#[allow(clippy::too_many_arguments)] +pub async fn do_delete_index( + dispatcher: Arc, + collection_name: String, + index_name: JsonPath, + clock_tag: Option, + shard_selection: Option, + wait: bool, + ordering: WriteOrdering, + access: Access, +) -> Result { + // TODO: Is this cancel safe!? + + let consensus_op = CollectionMetaOperations::DropPayloadIndex(DropPayloadIndex { + collection_name: collection_name.to_string(), + field_name: index_name.clone(), + }); + + // Default consensus timeout will be used + let wait_timeout = None; // ToDo: make it configurable + + // Nothing to verify here. + let pass = new_unchecked_verification_pass(); + + let toc = dispatcher.toc(&access, &pass).clone(); + + // TODO: Is `submit_collection_meta_op` cancel-safe!? Should be, I think?.. 
🤔 + dispatcher + .submit_collection_meta_op(consensus_op, access, wait_timeout) + .await?; + + do_delete_index_internal( + toc, + collection_name, + index_name, + clock_tag, + shard_selection, + wait, + ordering, + ) + .await +} + +#[allow(clippy::too_many_arguments)] +pub async fn do_core_search_points( + toc: &TableOfContent, + collection_name: &str, + request: CoreSearchRequest, + read_consistency: Option, + shard_selection: ShardSelectorInternal, + access: Access, + timeout: Option, + hw_measurement_acc: &HwMeasurementAcc, +) -> Result, StorageError> { + let batch_res = do_core_search_batch_points( + toc, + collection_name, + CoreSearchRequestBatch { + searches: vec![request], + }, + read_consistency, + shard_selection, + access, + timeout, + hw_measurement_acc, + ) + .await?; + batch_res + .into_iter() + .next() + .ok_or_else(|| StorageError::service_error("Empty search result")) +} + +pub async fn do_search_batch_points( + toc: &TableOfContent, + collection_name: &str, + requests: Vec<(CoreSearchRequest, ShardSelectorInternal)>, + read_consistency: Option, + access: Access, + timeout: Option, + hw_measurement_acc: &HwMeasurementAcc, +) -> Result>, StorageError> { + let requests = batch_requests::< + (CoreSearchRequest, ShardSelectorInternal), + ShardSelectorInternal, + Vec, + Vec<_>, + >( + requests, + |(_, shard_selector)| shard_selector, + |(request, _), core_reqs| { + core_reqs.push(request); + Ok(()) + }, + |shard_selector, core_requests, res| { + if core_requests.is_empty() { + return Ok(()); + } + + let core_batch = CoreSearchRequestBatch { + searches: core_requests, + }; + + let req = toc.core_search_batch( + collection_name, + core_batch, + read_consistency, + shard_selector, + access.clone(), + timeout, + hw_measurement_acc, + ); + res.push(req); + Ok(()) + }, + )?; + + let results = futures::future::try_join_all(requests).await?; + let flatten_results: Vec> = results.into_iter().flatten().collect(); + Ok(flatten_results) +} + +#[allow(clippy::too_many_arguments)] +pub async fn do_core_search_batch_points( + toc: &TableOfContent, + collection_name: &str, + request: CoreSearchRequestBatch, + read_consistency: Option, + shard_selection: ShardSelectorInternal, + access: Access, + timeout: Option, + hw_measurement_acc: &HwMeasurementAcc, +) -> Result>, StorageError> { + toc.core_search_batch( + collection_name, + request, + read_consistency, + shard_selection, + access, + timeout, + hw_measurement_acc, + ) + .await +} + +#[allow(clippy::too_many_arguments)] +pub async fn do_search_point_groups( + toc: &TableOfContent, + collection_name: &str, + request: SearchGroupsRequestInternal, + read_consistency: Option, + shard_selection: ShardSelectorInternal, + access: Access, + timeout: Option, + hw_measurement_acc: &HwMeasurementAcc, +) -> Result { + toc.group( + collection_name, + GroupRequest::from(request), + read_consistency, + shard_selection, + access, + timeout, + hw_measurement_acc, + ) + .await +} + +#[allow(clippy::too_many_arguments)] +pub async fn do_recommend_point_groups( + toc: &TableOfContent, + collection_name: &str, + request: RecommendGroupsRequestInternal, + read_consistency: Option, + shard_selection: ShardSelectorInternal, + access: Access, + timeout: Option, + hw_measurement_acc: &HwMeasurementAcc, +) -> Result { + toc.group( + collection_name, + GroupRequest::from(request), + read_consistency, + shard_selection, + access, + timeout, + hw_measurement_acc, + ) + .await +} + +pub async fn do_discover_batch_points( + toc: &TableOfContent, + collection_name: &str, + 
request: DiscoverRequestBatch, + read_consistency: Option, + access: Access, + timeout: Option, + hw_measurement_acc: &HwMeasurementAcc, +) -> Result>, StorageError> { + let requests = request + .searches + .into_iter() + .map(|req| { + let shard_selector = match req.shard_key { + None => ShardSelectorInternal::All, + Some(shard_key) => ShardSelectorInternal::from(shard_key), + }; + + (req.discover_request, shard_selector) + }) + .collect(); + + toc.discover_batch( + collection_name, + requests, + read_consistency, + access, + timeout, + hw_measurement_acc, + ) + .await +} + +#[allow(clippy::too_many_arguments)] +pub async fn do_count_points( + toc: &TableOfContent, + collection_name: &str, + request: CountRequestInternal, + read_consistency: Option, + timeout: Option, + shard_selection: ShardSelectorInternal, + access: Access, + hw_measurement_acc: &HwMeasurementAcc, +) -> Result { + toc.count( + collection_name, + request, + read_consistency, + timeout, + shard_selection, + access, + hw_measurement_acc, + ) + .await +} + +pub async fn do_get_points( + toc: &TableOfContent, + collection_name: &str, + request: PointRequestInternal, + read_consistency: Option, + timeout: Option, + shard_selection: ShardSelectorInternal, + access: Access, +) -> Result, StorageError> { + toc.retrieve( + collection_name, + request, + read_consistency, + timeout, + shard_selection, + access, + ) + .await +} + +pub async fn do_scroll_points( + toc: &TableOfContent, + collection_name: &str, + request: ScrollRequestInternal, + read_consistency: Option, + timeout: Option, + shard_selection: ShardSelectorInternal, + access: Access, +) -> Result { + toc.scroll( + collection_name, + request, + read_consistency, + timeout, + shard_selection, + access, + ) + .await +} + +#[allow(clippy::too_many_arguments)] +pub async fn do_query_points( + toc: &TableOfContent, + collection_name: &str, + request: CollectionQueryRequest, + read_consistency: Option, + shard_selection: ShardSelectorInternal, + access: Access, + timeout: Option, + hw_measurement_acc: &HwMeasurementAcc, +) -> Result, StorageError> { + let requests = vec![(request, shard_selection)]; + let batch_res = toc + .query_batch( + collection_name, + requests, + read_consistency, + access, + timeout, + hw_measurement_acc, + ) + .await?; + batch_res + .into_iter() + .next() + .ok_or_else(|| StorageError::service_error("Empty query result")) +} + +#[allow(clippy::too_many_arguments)] +pub async fn do_query_batch_points( + toc: &TableOfContent, + collection_name: &str, + requests: Vec<(CollectionQueryRequest, ShardSelectorInternal)>, + read_consistency: Option, + access: Access, + timeout: Option, + hw_measurement_acc: &HwMeasurementAcc, +) -> Result>, StorageError> { + toc.query_batch( + collection_name, + requests, + read_consistency, + access, + timeout, + hw_measurement_acc, + ) + .await +} + +#[allow(clippy::too_many_arguments)] +pub async fn do_query_point_groups( + toc: &TableOfContent, + collection_name: &str, + request: CollectionQueryGroupsRequest, + read_consistency: Option, + shard_selection: ShardSelectorInternal, + access: Access, + timeout: Option, + hw_measurement_acc: &HwMeasurementAcc, +) -> Result { + toc.group( + collection_name, + GroupRequest::from(request), + read_consistency, + shard_selection, + access, + timeout, + hw_measurement_acc, + ) + .await +} + +#[allow(clippy::too_many_arguments)] +pub async fn do_search_points_matrix( + toc: &TableOfContent, + collection_name: &str, + request: CollectionSearchMatrixRequest, + read_consistency: Option, 
+ shard_selection: ShardSelectorInternal, + access: Access, + timeout: Option, + hw_measurement_acc: &HwMeasurementAcc, +) -> Result { + toc.search_points_matrix( + collection_name, + request, + read_consistency, + shard_selection, + access, + timeout, + hw_measurement_acc, + ) + .await +} diff --git a/src/common/pyroscope_state.rs b/src/common/pyroscope_state.rs new file mode 100644 index 0000000000000000000000000000000000000000..1637052c4830d46ef2b89ef50eb785dd648a6008 --- /dev/null +++ b/src/common/pyroscope_state.rs @@ -0,0 +1,93 @@ +#[cfg(target_os = "linux")] +pub mod pyro { + + use pyroscope::pyroscope::PyroscopeAgentRunning; + use pyroscope::{PyroscopeAgent, PyroscopeError}; + use pyroscope_pprofrs::{pprof_backend, PprofConfig}; + + use crate::common::debugger::PyroscopeConfig; + + pub struct PyroscopeState { + pub config: PyroscopeConfig, + pub agent: Option>, + } + + impl PyroscopeState { + fn build_agent( + config: &PyroscopeConfig, + ) -> Result, PyroscopeError> { + let pprof_config = PprofConfig::new().sample_rate(config.sampling_rate.unwrap_or(100)); + let backend_impl = pprof_backend(pprof_config); + + log::info!( + "Starting pyroscope agent with identifier {}", + &config.identifier + ); + // TODO: Add more tags like peerId and peerUrl + let agent = PyroscopeAgent::builder(config.url.to_string(), "qdrant".to_string()) + .backend(backend_impl) + .tags(vec![("app", "Qdrant"), ("identifier", &config.identifier)]) + .build()?; + let running_agent = agent.start()?; + + Ok(running_agent) + } + + pub fn from_config(config: Option) -> Option { + match config { + Some(pyro_config) => { + let agent = PyroscopeState::build_agent(&pyro_config); + match agent { + Ok(agent) => Some(PyroscopeState { + config: pyro_config, + agent: Some(agent), + }), + Err(err) => { + log::warn!("Pyroscope agent failed to start {}", err); + None + } + } + } + None => None, + } + } + + pub fn stop_agent(&mut self) -> bool { + log::info!("Stopping pyroscope agent"); + if let Some(agent) = self.agent.take() { + match agent.stop() { + Ok(stopped_agent) => { + log::info!("Stopped pyroscope agent. 
Shutting it down"); + stopped_agent.shutdown(); + log::info!("Pyroscope agent shut down completed."); + return true; + } + Err(err) => { + log::warn!("Pyroscope agent failed to stop {}", err); + return false; + } + } + } + true + } + } + + impl Drop for PyroscopeState { + fn drop(&mut self) { + self.stop_agent(); + } + } +} + +#[cfg(not(target_os = "linux"))] +pub mod pyro { + use crate::common::debugger::PyroscopeConfig; + + pub struct PyroscopeState {} + + impl PyroscopeState { + pub fn from_config(_config: Option) -> Option { + None + } + } +} diff --git a/src/common/snapshots.rs b/src/common/snapshots.rs new file mode 100644 index 0000000000000000000000000000000000000000..2f63c0fd86e4a04bac04c38a2797ad9da0d93475 --- /dev/null +++ b/src/common/snapshots.rs @@ -0,0 +1,284 @@ +use std::sync::Arc; + +use collection::collection::Collection; +use collection::common::sha_256::hash_file; +use collection::common::snapshot_stream::SnapshotStream; +use collection::operations::snapshot_ops::{ + ShardSnapshotLocation, SnapshotDescription, SnapshotPriority, +}; +use collection::shards::replica_set::ReplicaState; +use collection::shards::shard::ShardId; +use storage::content_manager::errors::StorageError; +use storage::content_manager::snapshots; +use storage::content_manager::toc::TableOfContent; +use storage::rbac::{Access, AccessRequirements}; + +use super::http_client::HttpClient; + +/// # Cancel safety +/// +/// This function is cancel safe. +pub async fn create_shard_snapshot( + toc: Arc, + access: Access, + collection_name: String, + shard_id: ShardId, +) -> Result { + let collection_pass = access + .check_collection_access(&collection_name, AccessRequirements::new().write().whole())?; + let collection = toc.get_collection(&collection_pass).await?; + + let snapshot = collection + .create_shard_snapshot(shard_id, &toc.optional_temp_or_snapshot_temp_path()?) + .await?; + + Ok(snapshot) +} + +/// # Cancel safety +/// +/// This function is cancel safe. +pub async fn stream_shard_snapshot( + toc: Arc, + access: Access, + collection_name: String, + shard_id: ShardId, +) -> Result { + let collection_pass = access + .check_collection_access(&collection_name, AccessRequirements::new().write().whole())?; + let collection = toc.get_collection(&collection_pass).await?; + + Ok(collection + .stream_shard_snapshot(shard_id, &toc.optional_temp_or_snapshot_temp_path()?) + .await?) +} + +/// # Cancel safety +/// +/// This function is cancel safe. +pub async fn list_shard_snapshots( + toc: Arc, + access: Access, + collection_name: String, + shard_id: ShardId, +) -> Result, StorageError> { + let collection_pass = + access.check_collection_access(&collection_name, AccessRequirements::new().whole())?; + let collection = toc.get_collection(&collection_pass).await?; + let snapshots = collection.list_shard_snapshots(shard_id).await?; + Ok(snapshots) +} + +/// # Cancel safety +/// +/// This function is cancel safe. 
+pub async fn delete_shard_snapshot( + toc: Arc, + access: Access, + collection_name: String, + shard_id: ShardId, + snapshot_name: String, +) -> Result<(), StorageError> { + let collection_pass = access + .check_collection_access(&collection_name, AccessRequirements::new().write().whole())?; + let collection = toc.get_collection(&collection_pass).await?; + let snapshot_manager = collection.get_snapshots_storage_manager()?; + + let snapshot_path = collection + .shards_holder() + .read() + .await + .get_shard_snapshot_path(collection.snapshots_path(), shard_id, &snapshot_name) + .await?; + + tokio::spawn(async move { snapshot_manager.delete_snapshot(&snapshot_path).await }).await??; + + Ok(()) +} + +/// # Cancel safety +/// +/// This function is cancel safe. +#[allow(clippy::too_many_arguments)] +pub async fn recover_shard_snapshot( + toc: Arc, + access: Access, + collection_name: String, + shard_id: ShardId, + snapshot_location: ShardSnapshotLocation, + snapshot_priority: SnapshotPriority, + checksum: Option, + client: HttpClient, + api_key: Option, +) -> Result<(), StorageError> { + let collection_pass = access + .check_global_access(AccessRequirements::new().manage())? + .issue_pass(&collection_name) + .into_static(); + + // - `recover_shard_snapshot_impl` is *not* cancel safe + // - but the task is *spawned* on the runtime and won't be cancelled, if request is cancelled + + cancel::future::spawn_cancel_on_drop(move |cancel| async move { + let future = async { + let collection = toc.get_collection(&collection_pass).await?; + collection.assert_shard_exists(shard_id).await?; + + let download_dir = toc.optional_temp_or_snapshot_temp_path()?; + + let snapshot_path = match snapshot_location { + ShardSnapshotLocation::Url(url) => { + if !matches!(url.scheme(), "http" | "https") { + let description = format!( + "Invalid snapshot URL {url}: URLs with {} scheme are not supported", + url.scheme(), + ); + + return Err(StorageError::bad_input(description)); + } + + let client = client.client(api_key.as_deref())?; + + snapshots::download::download_snapshot(&client, url, &download_dir).await? + } + + ShardSnapshotLocation::Path(snapshot_file_name) => { + let snapshot_path = collection + .shards_holder() + .read() + .await + .get_shard_snapshot_path( + collection.snapshots_path(), + shard_id, + &snapshot_file_name, + ) + .await?; + + collection + .get_snapshots_storage_manager()? + .get_snapshot_file(&snapshot_path, &download_dir) + .await? + } + }; + + if let Some(checksum) = checksum { + let snapshot_checksum = hash_file(&snapshot_path).await?; + if snapshot_checksum != checksum { + return Err(StorageError::bad_input(format!( + "Snapshot checksum mismatch: expected {checksum}, got {snapshot_checksum}" + ))); + } + } + + Result::<_, StorageError>::Ok((collection, snapshot_path)) + }; + + let (collection, snapshot_path) = + cancel::future::cancel_on_token(cancel.clone(), future).await??; + + // `recover_shard_snapshot_impl` is *not* cancel safe + let result = recover_shard_snapshot_impl( + &toc, + &collection, + shard_id, + &snapshot_path, + snapshot_priority, + cancel, + ) + .await; + + // Remove snapshot after recovery if downloaded + if let Err(err) = snapshot_path.close() { + log::error!("Failed to remove downloaded shards snapshot after recovery: {err}"); + } + + result + }) + .await??; + + Ok(()) +} + +/// # Cancel safety +/// +/// This function is *not* cancel safe. 
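An aside before the non-cancel-safe implementation that follows: the cancel-safety notes above all rely on the same pattern of moving work that must not be left half-done onto a spawned task (`tokio::spawn` for snapshot deletion, `spawn_cancel_on_drop` for recovery), so that dropping the request future, for example when the client disconnects, cannot interrupt it part-way. A minimal sketch of the idea with plain tokio, using hypothetical names and assuming the `fs` and `rt` features; it is not part of this patch:

use std::path::PathBuf;

// Hypothetical non-cancel-safe step, standing in for snapshot deletion.
async fn remove_artifact(path: PathBuf) -> std::io::Result<()> {
    tokio::fs::remove_file(&path).await
}

// Cancel-safe wrapper: the work runs on its own task, so it completes even
// if the future returned here is dropped before it resolves.
async fn remove_artifact_cancel_safe(path: PathBuf) -> std::io::Result<()> {
    tokio::spawn(remove_artifact(path))
        .await
        .expect("cleanup task panicked")
}

In the recovery handler above, the same idea is combined with a cancellation token, so the spawned work can still observe cancellation at points where stopping is safe.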
+pub async fn recover_shard_snapshot_impl( + toc: &TableOfContent, + collection: &Collection, + shard: ShardId, + snapshot_path: &std::path::Path, + priority: SnapshotPriority, + cancel: cancel::CancellationToken, +) -> Result<(), StorageError> { + // `Collection::restore_shard_snapshot` and `activate_shard` calls *have to* be executed as a + // single transaction + // + // It is *possible* to make this function to be cancel safe, but it is *extremely tedious* to do so + + // `Collection::restore_shard_snapshot` is *not* cancel safe + // (see `ShardReplicaSet::restore_local_replica_from`) + collection + .restore_shard_snapshot( + shard, + snapshot_path, + toc.this_peer_id, + toc.is_distributed(), + &toc.optional_temp_or_snapshot_temp_path()?, + cancel, + ) + .await?; + + let state = collection.state().await; + let shard_info = state.shards.get(&shard).unwrap(); // TODO: Handle `unwrap`?.. + + // TODO: Unify (and de-duplicate) "recovered shard state notification" logic in `_do_recover_from_snapshot` with this one! + + let other_active_replicas: Vec<_> = shard_info + .replicas + .iter() + .map(|(&peer, &state)| (peer, state)) + .filter(|&(peer, state)| peer != toc.this_peer_id && state == ReplicaState::Active) + .collect(); + + if other_active_replicas.is_empty() { + snapshots::recover::activate_shard(toc, collection, toc.this_peer_id, &shard).await?; + } else { + match priority { + SnapshotPriority::NoSync => { + snapshots::recover::activate_shard(toc, collection, toc.this_peer_id, &shard) + .await?; + } + + SnapshotPriority::Snapshot => { + snapshots::recover::activate_shard(toc, collection, toc.this_peer_id, &shard) + .await?; + + for &(peer, _) in other_active_replicas.iter() { + toc.send_set_replica_state_proposal( + collection.name(), + peer, + shard, + ReplicaState::Dead, + None, + )?; + } + } + + SnapshotPriority::Replica => { + toc.send_set_replica_state_proposal( + collection.name(), + toc.this_peer_id, + shard, + ReplicaState::Dead, + None, + )?; + } + + // `ShardTransfer` is only used during snapshot *shard transfer*. + // State transitions are performed as part of shard transfer *later*, so this simply does *nothing*. 
+ SnapshotPriority::ShardTransfer => (), + } + } + + Ok(()) +} diff --git a/src/common/stacktrace.rs b/src/common/stacktrace.rs new file mode 100644 index 0000000000000000000000000000000000000000..3e0de217550d47749cb46416171834fcfe414777 --- /dev/null +++ b/src/common/stacktrace.rs @@ -0,0 +1,86 @@ +use schemars::JsonSchema; +use serde::{Deserialize, Serialize}; + +#[derive(Deserialize, Serialize, JsonSchema, Debug)] +struct StackTraceSymbol { + name: Option, + file: Option, + line: Option, +} + +#[derive(Deserialize, Serialize, JsonSchema, Debug)] +struct StackTraceFrame { + symbols: Vec, +} + +impl StackTraceFrame { + pub fn render(&self) -> String { + let mut result = String::new(); + for symbol in &self.symbols { + let symbol_string = format!( + "{}:{} - {} ", + symbol.file.as_deref().unwrap_or_default(), + symbol.line.unwrap_or_default(), + symbol.name.as_deref().unwrap_or_default(), + ); + result.push_str(&symbol_string); + } + result + } +} + +#[derive(Deserialize, Serialize, JsonSchema, Debug)] +pub struct ThreadStackTrace { + id: u32, + name: String, + frames: Vec, +} + +#[derive(Deserialize, Serialize, JsonSchema, Debug)] +pub struct StackTrace { + threads: Vec, +} + +pub fn get_stack_trace() -> StackTrace { + #[cfg(not(all(target_os = "linux", feature = "stacktrace")))] + { + StackTrace { threads: vec![] } + } + + #[cfg(all(target_os = "linux", feature = "stacktrace"))] + { + let exe = std::env::current_exe().unwrap(); + let trace = + rstack_self::trace(std::process::Command::new(exe).arg("--stacktrace")).unwrap(); + StackTrace { + threads: trace + .threads() + .iter() + .map(|thread| ThreadStackTrace { + id: thread.id(), + name: thread.name().to_string(), + frames: thread + .frames() + .iter() + .map(|frame| { + let frame = StackTraceFrame { + symbols: frame + .symbols() + .iter() + .map(|symbol| StackTraceSymbol { + name: symbol.name().map(|name| name.to_string()), + file: symbol.file().map(|file| { + file.to_str().unwrap_or_default().to_string() + }), + line: symbol.line(), + }) + .collect(), + }; + frame.render() + }) + .collect(), + }) + .collect(), + } + } +} diff --git a/src/common/strings.rs b/src/common/strings.rs new file mode 100644 index 0000000000000000000000000000000000000000..58d51f03c6662fbdf3521dadc9eb106b3de53fbc --- /dev/null +++ b/src/common/strings.rs @@ -0,0 +1,5 @@ +/// Constant-time equality for String types +#[inline] +pub fn ct_eq(lhs: impl AsRef, rhs: impl AsRef) -> bool { + constant_time_eq::constant_time_eq(lhs.as_ref().as_bytes(), rhs.as_ref().as_bytes()) +} diff --git a/src/common/telemetry.rs b/src/common/telemetry.rs new file mode 100644 index 0000000000000000000000000000000000000000..41518ae986cf1f0de77fa3ad557f05828707ddbd --- /dev/null +++ b/src/common/telemetry.rs @@ -0,0 +1,101 @@ +use std::sync::Arc; + +use collection::operations::verification::new_unchecked_verification_pass; +use common::types::{DetailsLevel, TelemetryDetail}; +use parking_lot::Mutex; +use schemars::JsonSchema; +use segment::common::anonymize::Anonymize; +use serde::Serialize; +use storage::dispatcher::Dispatcher; +use storage::rbac::Access; +use uuid::Uuid; + +use crate::common::telemetry_ops::app_telemetry::{AppBuildTelemetry, AppBuildTelemetryCollector}; +use crate::common::telemetry_ops::cluster_telemetry::ClusterTelemetry; +use crate::common::telemetry_ops::collections_telemetry::CollectionsTelemetry; +use crate::common::telemetry_ops::memory_telemetry::MemoryTelemetry; +use crate::common::telemetry_ops::requests_telemetry::{ + ActixTelemetryCollector, 
RequestsTelemetry, TonicTelemetryCollector, +}; +use crate::settings::Settings; + +pub struct TelemetryCollector { + process_id: Uuid, + settings: Settings, + dispatcher: Arc, + pub app_telemetry_collector: AppBuildTelemetryCollector, + pub actix_telemetry_collector: Arc>, + pub tonic_telemetry_collector: Arc>, +} + +// Whole telemetry data +#[derive(Serialize, Clone, Debug, JsonSchema)] +pub struct TelemetryData { + id: String, + pub(crate) app: AppBuildTelemetry, + pub(crate) collections: CollectionsTelemetry, + pub(crate) cluster: ClusterTelemetry, + pub(crate) requests: RequestsTelemetry, + pub(crate) memory: Option, +} + +impl Anonymize for TelemetryData { + fn anonymize(&self) -> Self { + TelemetryData { + id: self.id.clone(), + app: self.app.anonymize(), + collections: self.collections.anonymize(), + cluster: self.cluster.anonymize(), + requests: self.requests.anonymize(), + memory: self.memory.anonymize(), + } + } +} + +impl TelemetryCollector { + pub fn reporting_id(&self) -> String { + self.process_id.to_string() + } + + pub fn generate_id() -> Uuid { + Uuid::new_v4() + } + + pub fn new(settings: Settings, dispatcher: Arc, id: Uuid) -> Self { + Self { + process_id: id, + settings, + dispatcher, + app_telemetry_collector: AppBuildTelemetryCollector::new(), + actix_telemetry_collector: Arc::new(Mutex::new(ActixTelemetryCollector { + workers: Vec::new(), + })), + tonic_telemetry_collector: Arc::new(Mutex::new(TonicTelemetryCollector { + workers: Vec::new(), + })), + } + } + + pub async fn prepare_data(&self, access: &Access, detail: TelemetryDetail) -> TelemetryData { + TelemetryData { + id: self.process_id.to_string(), + collections: CollectionsTelemetry::collect( + detail, + access, + self.dispatcher + .toc(access, &new_unchecked_verification_pass()), + ) + .await, + app: AppBuildTelemetry::collect(detail, &self.app_telemetry_collector, &self.settings), + cluster: ClusterTelemetry::collect(detail, &self.dispatcher, &self.settings), + requests: RequestsTelemetry::collect( + &self.actix_telemetry_collector.lock(), + &self.tonic_telemetry_collector.lock(), + detail, + ), + memory: (detail.level > DetailsLevel::Level0) + .then(MemoryTelemetry::collect) + .flatten(), + } + } +} diff --git a/src/common/telemetry_ops/app_telemetry.rs b/src/common/telemetry_ops/app_telemetry.rs new file mode 100644 index 0000000000000000000000000000000000000000..8acde6d95070b611e0cd43d2e9bc777485d9d86c --- /dev/null +++ b/src/common/telemetry_ops/app_telemetry.rs @@ -0,0 +1,198 @@ +use std::path::Path; + +use chrono::{DateTime, SubsecRound, Utc}; +use common::types::{DetailsLevel, TelemetryDetail}; +use schemars::JsonSchema; +use segment::common::anonymize::Anonymize; +use serde::Serialize; + +use crate::settings::Settings; + +pub struct AppBuildTelemetryCollector { + pub startup: DateTime, +} + +impl AppBuildTelemetryCollector { + pub fn new() -> Self { + AppBuildTelemetryCollector { + startup: Utc::now().round_subsecs(2), + } + } +} + +#[derive(Serialize, Clone, Debug, JsonSchema)] +pub struct AppFeaturesTelemetry { + pub debug: bool, + pub web_feature: bool, + pub service_debug_feature: bool, + pub recovery_mode: bool, +} + +#[derive(Serialize, Clone, Debug, JsonSchema)] +pub struct RunningEnvironmentTelemetry { + distribution: Option, + distribution_version: Option, + is_docker: bool, + cores: Option, + ram_size: Option, + disk_size: Option, + cpu_flags: String, + #[serde(skip_serializing_if = "Option::is_none")] + cpu_endian: Option, +} + +#[derive(Serialize, Clone, Debug, JsonSchema)] +pub struct 
AppBuildTelemetry { + pub name: String, + pub version: String, + #[serde(skip_serializing_if = "Option::is_none")] + pub features: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub system: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub jwt_rbac: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub hide_jwt_dashboard: Option, + pub startup: DateTime, +} + +impl AppBuildTelemetry { + pub fn collect( + detail: TelemetryDetail, + collector: &AppBuildTelemetryCollector, + settings: &Settings, + ) -> Self { + AppBuildTelemetry { + name: env!("CARGO_PKG_NAME").to_string(), + version: env!("CARGO_PKG_VERSION").to_string(), + features: (detail.level >= DetailsLevel::Level1).then(|| AppFeaturesTelemetry { + debug: cfg!(debug_assertions), + web_feature: cfg!(feature = "web"), + service_debug_feature: cfg!(feature = "service_debug"), + recovery_mode: settings.storage.recovery_mode.is_some(), + }), + system: (detail.level >= DetailsLevel::Level1).then(get_system_data), + jwt_rbac: settings.service.jwt_rbac, + hide_jwt_dashboard: settings.service.hide_jwt_dashboard, + startup: collector.startup, + } + } +} + +fn get_system_data() -> RunningEnvironmentTelemetry { + let distribution = if let Ok(release) = sys_info::linux_os_release() { + release.id + } else { + sys_info::os_type().ok() + }; + let distribution_version = if let Ok(release) = sys_info::linux_os_release() { + release.version_id + } else { + sys_info::os_release().ok() + }; + let mut cpu_flags = vec![]; + #[cfg(any(target_arch = "x86", target_arch = "x86_64"))] + { + if std::arch::is_x86_feature_detected!("sse") { + cpu_flags.push("sse"); + } + if std::arch::is_x86_feature_detected!("sse2") { + cpu_flags.push("sse2"); + } + if std::arch::is_x86_feature_detected!("avx") { + cpu_flags.push("avx"); + } + if std::arch::is_x86_feature_detected!("avx2") { + cpu_flags.push("avx2"); + } + if std::arch::is_x86_feature_detected!("fma") { + cpu_flags.push("fma"); + } + if std::arch::is_x86_feature_detected!("f16c") { + cpu_flags.push("f16c"); + } + if std::arch::is_x86_feature_detected!("avx512f") { + cpu_flags.push("avx512f"); + } + } + #[cfg(all(target_arch = "aarch64", target_feature = "neon"))] + { + if std::arch::is_aarch64_feature_detected!("neon") { + cpu_flags.push("neon"); + } + if std::arch::is_aarch64_feature_detected!("fp16") { + cpu_flags.push("fp16"); + } + } + RunningEnvironmentTelemetry { + distribution, + distribution_version, + is_docker: cfg!(unix) && Path::new("/.dockerenv").exists(), + cores: sys_info::cpu_num().ok().map(|x| x as usize), + ram_size: sys_info::mem_info().ok().map(|x| x.total as usize), + disk_size: sys_info::disk_info().ok().map(|x| x.total as usize), + cpu_flags: cpu_flags.join(","), + cpu_endian: Some(CpuEndian::current()), + } +} + +#[derive(Serialize, Clone, Copy, Debug, JsonSchema)] +#[serde(rename_all = "snake_case")] +pub enum CpuEndian { + Little, + Big, + Other, +} + +impl CpuEndian { + /// Get the current used byte order + pub const fn current() -> Self { + if cfg!(target_endian = "little") { + CpuEndian::Little + } else if cfg!(target_endian = "big") { + CpuEndian::Big + } else { + CpuEndian::Other + } + } +} + +impl Anonymize for AppFeaturesTelemetry { + fn anonymize(&self) -> Self { + AppFeaturesTelemetry { + debug: self.debug, + web_feature: self.web_feature, + service_debug_feature: self.service_debug_feature, + recovery_mode: self.recovery_mode, + } + } +} + +impl Anonymize for AppBuildTelemetry { + fn anonymize(&self) -> Self { + AppBuildTelemetry { + 
name: self.name.clone(), + version: self.version.clone(), + features: self.features.anonymize(), + system: self.system.anonymize(), + jwt_rbac: self.jwt_rbac, + hide_jwt_dashboard: self.hide_jwt_dashboard, + startup: self.startup.anonymize(), + } + } +} + +impl Anonymize for RunningEnvironmentTelemetry { + fn anonymize(&self) -> Self { + RunningEnvironmentTelemetry { + distribution: self.distribution.clone(), + distribution_version: self.distribution_version.clone(), + is_docker: self.is_docker, + cores: self.cores, + ram_size: self.ram_size.anonymize(), + disk_size: self.disk_size.anonymize(), + cpu_flags: self.cpu_flags.clone(), + cpu_endian: self.cpu_endian, + } + } +} diff --git a/src/common/telemetry_ops/cluster_telemetry.rs b/src/common/telemetry_ops/cluster_telemetry.rs new file mode 100644 index 0000000000000000000000000000000000000000..211451d9b9e950b36426eb7bf83f86bf65f60a3b --- /dev/null +++ b/src/common/telemetry_ops/cluster_telemetry.rs @@ -0,0 +1,147 @@ +use std::collections::HashMap; + +use collection::shards::shard::PeerId; +use common::types::{DetailsLevel, TelemetryDetail}; +use schemars::JsonSchema; +use segment::common::anonymize::Anonymize; +use serde::Serialize; +use storage::dispatcher::Dispatcher; +use storage::types::{ClusterStatus, ConsensusThreadStatus, PeerInfo, StateRole}; + +use crate::settings::Settings; + +#[derive(Serialize, Clone, Debug, JsonSchema)] +pub struct P2pConfigTelemetry { + connection_pool_size: usize, +} + +#[derive(Serialize, Clone, Debug, JsonSchema)] +pub struct ConsensusConfigTelemetry { + max_message_queue_size: usize, + tick_period_ms: u64, + bootstrap_timeout_sec: u64, +} + +#[derive(Serialize, Clone, Debug, JsonSchema)] +pub struct ClusterConfigTelemetry { + grpc_timeout_ms: u64, + p2p: P2pConfigTelemetry, + consensus: ConsensusConfigTelemetry, +} + +impl From<&Settings> for ClusterConfigTelemetry { + fn from(settings: &Settings) -> Self { + ClusterConfigTelemetry { + grpc_timeout_ms: settings.cluster.grpc_timeout_ms, + p2p: P2pConfigTelemetry { + connection_pool_size: settings.cluster.p2p.connection_pool_size, + }, + consensus: ConsensusConfigTelemetry { + max_message_queue_size: settings.cluster.consensus.max_message_queue_size, + tick_period_ms: settings.cluster.consensus.tick_period_ms, + bootstrap_timeout_sec: settings.cluster.consensus.bootstrap_timeout_sec, + }, + } + } +} + +#[derive(Serialize, Clone, Debug, JsonSchema)] +pub struct ClusterStatusTelemetry { + pub number_of_peers: usize, + pub term: u64, + pub commit: u64, + pub pending_operations: usize, + pub role: Option, + pub is_voter: bool, + pub peer_id: Option, + pub consensus_thread_status: ConsensusThreadStatus, +} + +#[derive(Serialize, Clone, Debug, JsonSchema)] +pub struct ClusterTelemetry { + pub enabled: bool, + #[serde(skip_serializing_if = "Option::is_none")] + pub status: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub config: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub peers: Option>, + #[serde(skip_serializing_if = "Option::is_none")] + pub metadata: Option>, +} + +impl ClusterTelemetry { + pub fn collect( + detail: TelemetryDetail, + dispatcher: &Dispatcher, + settings: &Settings, + ) -> ClusterTelemetry { + ClusterTelemetry { + enabled: settings.cluster.enabled, + status: (detail.level >= DetailsLevel::Level1) + .then(|| match dispatcher.cluster_status() { + ClusterStatus::Disabled => None, + ClusterStatus::Enabled(cluster_info) => Some(ClusterStatusTelemetry { + number_of_peers: cluster_info.peers.len(), + term: 
cluster_info.raft_info.term, + commit: cluster_info.raft_info.commit, + pending_operations: cluster_info.raft_info.pending_operations, + role: cluster_info.raft_info.role, + is_voter: cluster_info.raft_info.is_voter, + peer_id: Some(cluster_info.peer_id), + consensus_thread_status: cluster_info.consensus_thread_status, + }), + }) + .flatten(), + config: (detail.level >= DetailsLevel::Level2) + .then(|| ClusterConfigTelemetry::from(settings)), + peers: (detail.level >= DetailsLevel::Level2) + .then(|| match dispatcher.cluster_status() { + ClusterStatus::Disabled => None, + ClusterStatus::Enabled(cluster_info) => Some(cluster_info.peers), + }) + .flatten(), + metadata: (detail.level >= DetailsLevel::Level1) + .then(|| { + dispatcher + .consensus_state() + .map(|state| state.persistent.read().cluster_metadata.clone()) + .filter(|metadata| !metadata.is_empty()) + }) + .flatten(), + } + } +} + +impl Anonymize for ClusterTelemetry { + fn anonymize(&self) -> Self { + ClusterTelemetry { + enabled: self.enabled, + status: self.status.clone().map(|x| x.anonymize()), + config: self.config.clone().map(|x| x.anonymize()), + peers: None, + metadata: None, + } + } +} + +impl Anonymize for ClusterStatusTelemetry { + fn anonymize(&self) -> Self { + ClusterStatusTelemetry { + number_of_peers: self.number_of_peers, + term: self.term, + commit: self.commit, + pending_operations: self.pending_operations, + role: self.role, + is_voter: self.is_voter, + peer_id: None, + consensus_thread_status: self.consensus_thread_status.clone(), + } + } +} + +impl Anonymize for ClusterConfigTelemetry { + fn anonymize(&self) -> Self { + self.clone() + } +} diff --git a/src/common/telemetry_ops/collections_telemetry.rs b/src/common/telemetry_ops/collections_telemetry.rs new file mode 100644 index 0000000000000000000000000000000000000000..282ac35be01492352dd6213c1b3172d815552064 --- /dev/null +++ b/src/common/telemetry_ops/collections_telemetry.rs @@ -0,0 +1,108 @@ +use collection::config::CollectionParams; +use collection::operations::types::OptimizersStatus; +use collection::telemetry::CollectionTelemetry; +use common::types::{DetailsLevel, TelemetryDetail}; +use schemars::JsonSchema; +use segment::common::anonymize::Anonymize; +use serde::Serialize; +use storage::content_manager::toc::TableOfContent; +use storage::rbac::Access; + +#[derive(Serialize, Clone, Debug, JsonSchema)] +pub struct CollectionsAggregatedTelemetry { + pub vectors: usize, + pub optimizers_status: OptimizersStatus, + pub params: CollectionParams, +} + +#[derive(Serialize, Clone, Debug, JsonSchema)] +#[serde(untagged)] +pub enum CollectionTelemetryEnum { + Full(CollectionTelemetry), + Aggregated(CollectionsAggregatedTelemetry), +} + +#[derive(Serialize, Clone, Debug, JsonSchema)] +pub struct CollectionsTelemetry { + pub number_of_collections: usize, + #[serde(skip_serializing_if = "Option::is_none")] + pub collections: Option>, +} + +impl From for CollectionsAggregatedTelemetry { + fn from(telemetry: CollectionTelemetry) -> Self { + let optimizers_status = telemetry + .shards + .iter() + .filter_map(|shard| shard.local.as_ref().map(|x| x.optimizations.status.clone())) + .max() + .unwrap_or(OptimizersStatus::Ok); + + CollectionsAggregatedTelemetry { + vectors: telemetry.count_vectors(), + optimizers_status, + params: telemetry.config.params, + } + } +} + +impl CollectionsTelemetry { + pub async fn collect(detail: TelemetryDetail, access: &Access, toc: &TableOfContent) -> Self { + let number_of_collections = toc.all_collections(access).await.len(); + let 
collections = if detail.level >= DetailsLevel::Level1 { + let telemetry_data = toc + .get_telemetry_data(detail, access) + .await + .into_iter() + .map(|telemetry| { + if detail.level >= DetailsLevel::Level2 { + CollectionTelemetryEnum::Full(telemetry) + } else { + CollectionTelemetryEnum::Aggregated(telemetry.into()) + } + }) + .collect(); + + Some(telemetry_data) + } else { + None + }; + + CollectionsTelemetry { + number_of_collections, + collections, + } + } +} + +impl Anonymize for CollectionsTelemetry { + fn anonymize(&self) -> Self { + CollectionsTelemetry { + number_of_collections: self.number_of_collections, + collections: self.collections.anonymize(), + } + } +} + +impl Anonymize for CollectionTelemetryEnum { + fn anonymize(&self) -> Self { + match self { + CollectionTelemetryEnum::Full(telemetry) => { + CollectionTelemetryEnum::Full(telemetry.anonymize()) + } + CollectionTelemetryEnum::Aggregated(telemetry) => { + CollectionTelemetryEnum::Aggregated(telemetry.anonymize()) + } + } + } +} + +impl Anonymize for CollectionsAggregatedTelemetry { + fn anonymize(&self) -> Self { + CollectionsAggregatedTelemetry { + optimizers_status: self.optimizers_status.clone(), + vectors: self.vectors.anonymize(), + params: self.params.anonymize(), + } + } +} diff --git a/src/common/telemetry_ops/memory_telemetry.rs b/src/common/telemetry_ops/memory_telemetry.rs new file mode 100644 index 0000000000000000000000000000000000000000..5c2df2cc64595f7240d256821d68187979edfab2 --- /dev/null +++ b/src/common/telemetry_ops/memory_telemetry.rs @@ -0,0 +1,60 @@ +use schemars::JsonSchema; +use segment::common::anonymize::Anonymize; +use serde::Serialize; +#[cfg(all( + not(target_env = "msvc"), + any(target_arch = "x86_64", target_arch = "aarch64") +))] +use tikv_jemalloc_ctl::{epoch, stats}; + +#[derive(Debug, Clone, Default, JsonSchema, Serialize)] +pub struct MemoryTelemetry { + /// Total number of bytes in active pages allocated by the application + pub active_bytes: usize, + /// Total number of bytes allocated by the application + pub allocated_bytes: usize, + /// Total number of bytes dedicated to metadata + pub metadata_bytes: usize, + /// Maximum number of bytes in physically resident data pages mapped + pub resident_bytes: usize, + /// Total number of bytes in virtual memory mappings + pub retained_bytes: usize, +} + +impl MemoryTelemetry { + #[cfg(all( + not(target_env = "msvc"), + any(target_arch = "x86_64", target_arch = "aarch64") + ))] + pub fn collect() -> Option { + if epoch::advance().is_ok() { + Some(MemoryTelemetry { + active_bytes: stats::active::read().unwrap_or_default(), + allocated_bytes: stats::allocated::read().unwrap_or_default(), + metadata_bytes: stats::metadata::read().unwrap_or_default(), + resident_bytes: stats::resident::read().unwrap_or_default(), + retained_bytes: stats::retained::read().unwrap_or_default(), + }) + } else { + log::info!("Failed to advance Jemalloc stats epoch"); + None + } + } + + #[cfg(target_env = "msvc")] + pub fn collect() -> Option { + None + } +} + +impl Anonymize for MemoryTelemetry { + fn anonymize(&self) -> Self { + MemoryTelemetry { + active_bytes: self.active_bytes, + allocated_bytes: self.allocated_bytes, + metadata_bytes: self.metadata_bytes, + resident_bytes: self.resident_bytes, + retained_bytes: self.retained_bytes, + } + } +} diff --git a/src/common/telemetry_ops/mod.rs b/src/common/telemetry_ops/mod.rs new file mode 100644 index 0000000000000000000000000000000000000000..b274e5f7995ff69f7fc3ece6bb05a312db6b39d5 --- /dev/null +++ 
b/src/common/telemetry_ops/mod.rs @@ -0,0 +1,5 @@ +pub mod app_telemetry; +pub mod cluster_telemetry; +pub mod collections_telemetry; +pub mod memory_telemetry; +pub mod requests_telemetry; diff --git a/src/common/telemetry_ops/requests_telemetry.rs b/src/common/telemetry_ops/requests_telemetry.rs new file mode 100644 index 0000000000000000000000000000000000000000..27e83dcd956fecec89ebb32d7395a534e87459e6 --- /dev/null +++ b/src/common/telemetry_ops/requests_telemetry.rs @@ -0,0 +1,201 @@ +use std::collections::HashMap; +use std::sync::Arc; + +use common::types::TelemetryDetail; +use parking_lot::Mutex; +use schemars::JsonSchema; +use segment::common::anonymize::Anonymize; +use segment::common::operation_time_statistics::{ + OperationDurationStatistics, OperationDurationsAggregator, ScopeDurationMeasurer, +}; +use serde::Serialize; + +pub type HttpStatusCode = u16; + +#[derive(Serialize, Clone, Default, Debug, JsonSchema)] +pub struct WebApiTelemetry { + pub responses: HashMap>, +} + +#[derive(Serialize, Clone, Default, Debug, JsonSchema)] +pub struct GrpcTelemetry { + pub responses: HashMap, +} + +pub struct ActixTelemetryCollector { + pub workers: Vec>>, +} + +#[derive(Default)] +pub struct ActixWorkerTelemetryCollector { + methods: HashMap>>>, +} + +pub struct TonicTelemetryCollector { + pub workers: Vec>>, +} + +#[derive(Default)] +pub struct TonicWorkerTelemetryCollector { + methods: HashMap>>, +} + +impl ActixTelemetryCollector { + pub fn create_web_worker_telemetry(&mut self) -> Arc> { + let worker: Arc> = Default::default(); + self.workers.push(worker.clone()); + worker + } + + pub fn get_telemetry_data(&self, detail: TelemetryDetail) -> WebApiTelemetry { + let mut result = WebApiTelemetry::default(); + for web_data in &self.workers { + let lock = web_data.lock().get_telemetry_data(detail); + result.merge(&lock); + } + result + } +} + +impl TonicTelemetryCollector { + #[allow(dead_code)] + pub fn create_grpc_telemetry_collector(&mut self) -> Arc> { + let worker: Arc> = Default::default(); + self.workers.push(worker.clone()); + worker + } + + pub fn get_telemetry_data(&self, detail: TelemetryDetail) -> GrpcTelemetry { + let mut result = GrpcTelemetry::default(); + for grpc_data in &self.workers { + let lock = grpc_data.lock().get_telemetry_data(detail); + result.merge(&lock); + } + result + } +} + +impl TonicWorkerTelemetryCollector { + #[allow(dead_code)] + pub fn add_response(&mut self, method: String, instant: std::time::Instant) { + let aggregator = self + .methods + .entry(method) + .or_insert_with(OperationDurationsAggregator::new); + ScopeDurationMeasurer::new_with_instant(aggregator, instant); + } + + pub fn get_telemetry_data(&self, detail: TelemetryDetail) -> GrpcTelemetry { + let mut responses = HashMap::new(); + for (method, aggregator) in self.methods.iter() { + responses.insert(method.clone(), aggregator.lock().get_statistics(detail)); + } + GrpcTelemetry { responses } + } +} + +impl ActixWorkerTelemetryCollector { + pub fn add_response( + &mut self, + method: String, + status_code: HttpStatusCode, + instant: std::time::Instant, + ) { + let aggregator = self + .methods + .entry(method) + .or_default() + .entry(status_code) + .or_insert_with(OperationDurationsAggregator::new); + ScopeDurationMeasurer::new_with_instant(aggregator, instant); + } + + pub fn get_telemetry_data(&self, detail: TelemetryDetail) -> WebApiTelemetry { + let mut responses = HashMap::new(); + for (method, status_codes) in &self.methods { + let mut status_codes_map = HashMap::new(); + for 
(status_code, aggregator) in status_codes { + status_codes_map.insert(*status_code, aggregator.lock().get_statistics(detail)); + } + responses.insert(method.clone(), status_codes_map); + } + WebApiTelemetry { responses } + } +} + +impl GrpcTelemetry { + pub fn merge(&mut self, other: &GrpcTelemetry) { + for (method, other_statistics) in &other.responses { + let entry = self.responses.entry(method.clone()).or_default(); + *entry = entry.clone() + other_statistics.clone(); + } + } +} + +impl WebApiTelemetry { + pub fn merge(&mut self, other: &WebApiTelemetry) { + for (method, status_codes) in &other.responses { + let status_codes_map = self.responses.entry(method.clone()).or_default(); + for (status_code, statistics) in status_codes { + let entry = status_codes_map.entry(*status_code).or_default(); + *entry = entry.clone() + statistics.clone(); + } + } + } +} + +#[derive(Serialize, Clone, Debug, JsonSchema)] +pub struct RequestsTelemetry { + pub rest: WebApiTelemetry, + pub grpc: GrpcTelemetry, +} + +impl RequestsTelemetry { + pub fn collect( + actix_collector: &ActixTelemetryCollector, + tonic_collector: &TonicTelemetryCollector, + detail: TelemetryDetail, + ) -> Self { + let rest = actix_collector.get_telemetry_data(detail); + let grpc = tonic_collector.get_telemetry_data(detail); + Self { rest, grpc } + } +} + +impl Anonymize for RequestsTelemetry { + fn anonymize(&self) -> Self { + let rest = self.rest.anonymize(); + let grpc = self.grpc.anonymize(); + Self { rest, grpc } + } +} + +impl Anonymize for WebApiTelemetry { + fn anonymize(&self) -> Self { + let responses = self + .responses + .iter() + .map(|(key, value)| { + let value: HashMap<_, _> = value + .iter() + .map(|(key, value)| (*key, value.anonymize())) + .collect(); + (key.clone(), value) + }) + .collect(); + + WebApiTelemetry { responses } + } +} + +impl Anonymize for GrpcTelemetry { + fn anonymize(&self) -> Self { + let responses = self + .responses + .iter() + .map(|(key, value)| (key.clone(), value.anonymize())) + .collect(); + + GrpcTelemetry { responses } + } +} diff --git a/src/common/telemetry_reporting.rs b/src/common/telemetry_reporting.rs new file mode 100644 index 0000000000000000000000000000000000000000..d7663b20dd361fdc67d44ec36ef7055b34c64142 --- /dev/null +++ b/src/common/telemetry_reporting.rs @@ -0,0 +1,64 @@ +use std::sync::Arc; +use std::time::Duration; + +use common::types::{DetailsLevel, TelemetryDetail}; +use reqwest::Client; +use segment::common::anonymize::Anonymize; +use storage::rbac::Access; +use tokio::sync::Mutex; + +use super::telemetry::TelemetryCollector; + +const DETAIL: TelemetryDetail = TelemetryDetail { + level: DetailsLevel::Level2, + histograms: false, +}; +const REPORTING_INTERVAL: Duration = Duration::from_secs(60 * 60); // One hour + +pub struct TelemetryReporter { + telemetry_url: String, + telemetry: Arc>, +} + +const FULL_ACCESS: Access = Access::full("Telemetry reporter"); + +impl TelemetryReporter { + fn new(telemetry: Arc>) -> Self { + let telemetry_url = if cfg!(debug_assertions) { + "https://staging-telemetry.qdrant.io".to_string() + } else { + "https://telemetry.qdrant.io".to_string() + }; + + Self { + telemetry_url, + telemetry, + } + } + + async fn report(&self, client: &Client) { + let data = self + .telemetry + .lock() + .await + .prepare_data(&FULL_ACCESS, DETAIL) + .await + .anonymize(); + let data = serde_json::to_string(&data).unwrap(); + let _resp = client + .post(&self.telemetry_url) + .body(data) + .header("Content-Type", "application/json") + .send() + .await; + } + + 
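The collectors above keep one mutex-guarded map per actix/tonic worker and only fold the per-worker maps together when telemetry is read, so request handling never contends on a single global lock. A reduced sketch of that merge step, using plain counters in place of `OperationDurationStatistics` (hypothetical names, not part of this patch):

use std::collections::HashMap;

// Per-worker response counts keyed by "METHOD /path".
type WorkerCounts = HashMap<String, u64>;

// Fold one worker's counters into the aggregated view, in the spirit of
// `WebApiTelemetry::merge` / `GrpcTelemetry::merge` above.
fn merge_counts(total: &mut WorkerCounts, worker: &WorkerCounts) {
    for (method, count) in worker {
        *total.entry(method.clone()).or_default() += *count;
    }
}

fn aggregate(workers: &[WorkerCounts]) -> WorkerCounts {
    let mut total = WorkerCounts::new();
    for worker in workers {
        merge_counts(&mut total, worker);
    }
    total
}

The real types sum `OperationDurationStatistics` values with `+` instead of integers, and the REST variant nests a second map keyed by HTTP status code.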
pub async fn run(telemetry: Arc>) { + let reporter = Self::new(telemetry); + let client = Client::new(); + loop { + reporter.report(&client).await; + tokio::time::sleep(REPORTING_INTERVAL).await; + } + } +} diff --git a/src/consensus.rs b/src/consensus.rs new file mode 100644 index 0000000000000000000000000000000000000000..b965b2961732b4cb4478afaecbb4bb6554ea5ac6 --- /dev/null +++ b/src/consensus.rs @@ -0,0 +1,1502 @@ +use std::collections::{HashMap, HashSet}; +use std::str::FromStr; +use std::sync::{mpsc, Arc}; +use std::thread::JoinHandle; +use std::time::{Duration, Instant}; +use std::{fmt, thread}; + +use anyhow::{anyhow, Context as _}; +use api::grpc::dynamic_channel_pool::make_grpc_channel; +use api::grpc::qdrant::raft_client::RaftClient; +use api::grpc::qdrant::{AllPeers, PeerId as GrpcPeerId, RaftMessage as GrpcRaftMessage}; +use api::grpc::transport_channel_pool::TransportChannelPool; +use collection::shards::channel_service::ChannelService; +use collection::shards::shard::PeerId; +#[cfg(target_os = "linux")] +use common::cpu::linux_high_thread_priority; +use common::defaults; +use raft::eraftpb::Message as RaftMessage; +use raft::prelude::*; +use raft::{SoftState, StateRole, INVALID_ID}; +use storage::content_manager::consensus_manager::ConsensusStateRef; +use storage::content_manager::consensus_ops::{ConsensusOperations, SnapshotStatus}; +use storage::content_manager::toc::TableOfContent; +use tokio::runtime::Handle; +use tokio::sync::mpsc::{Receiver, Sender}; +use tokio::sync::watch; +use tokio::time::sleep; +use tonic::transport::{ClientTlsConfig, Uri}; + +use crate::common::helpers; +use crate::common::telemetry_ops::requests_telemetry::TonicTelemetryCollector; +use crate::settings::{ConsensusConfig, Settings}; +use crate::tonic::init_internal; + +type Node = RawNode; + +const RECOVERY_RETRY_TIMEOUT: Duration = Duration::from_secs(1); +const RECOVERY_MAX_RETRY_COUNT: usize = 3; + +pub enum Message { + FromClient(ConsensusOperations), + FromPeer(Box), +} + +/// Aka Consensus Thread +/// Manages proposed changes to consensus state, ensures that everything is ordered properly +pub struct Consensus { + /// Raft structure which handles raft-related state + node: Node, + /// Receives proposals from peers and client for applying in consensus + receiver: Receiver, + /// Runtime for async message sending + runtime: Handle, + /// Uri to some other known peer, used to join the consensus + /// ToDo: Make if many + config: ConsensusConfig, + broker: RaftMessageBroker, +} + +impl Consensus { + /// Create and run consensus node + #[allow(clippy::too_many_arguments)] + pub fn run( + logger: &slog::Logger, + state_ref: ConsensusStateRef, + bootstrap_peer: Option, + uri: Option, + settings: Settings, + channel_service: ChannelService, + propose_receiver: mpsc::Receiver, + telemetry_collector: Arc>, + toc: Arc, + runtime: Handle, + reinit: bool, + ) -> anyhow::Result>> { + let tls_client_config = helpers::load_tls_client_config(&settings)?; + + let p2p_host = settings.service.host.clone(); + let p2p_port = settings.cluster.p2p.port.expect("P2P port is not set"); + let config = settings.cluster.consensus.clone(); + + let (mut consensus, message_sender) = Self::new( + logger, + state_ref.clone(), + bootstrap_peer, + uri, + p2p_port, + config, + tls_client_config, + channel_service, + runtime.clone(), + reinit, + )?; + + let state_ref_clone = state_ref.clone(); + thread::Builder::new() + .name("consensus".to_string()) + .spawn(move || { + // On Linux, try to use high thread priority because 
consensus is important + // Likely fails as we cannot set a higher priority by default due to permissions + #[cfg(target_os = "linux")] + if let Err(err) = linux_high_thread_priority() { + log::debug!( + "Failed to set high thread priority for consensus, ignoring: {err}" + ); + } + + if let Err(err) = consensus.start() { + log::error!("Consensus stopped with error: {err:#}"); + state_ref_clone.on_consensus_thread_err(err); + } else { + log::info!("Consensus stopped"); + state_ref_clone.on_consensus_stopped(); + } + })?; + + let message_sender_moved = message_sender.clone(); + thread::Builder::new() + .name("forward-proposals".to_string()) + .spawn(move || { + // On Linux, try to use high thread priority because consensus is important + // Likely fails as we cannot set a higher priority by default due to permissions + #[cfg(target_os = "linux")] + if let Err(err) = linux_high_thread_priority() { + log::debug!( + "Failed to set high thread priority for consensus, ignoring: {err}" + ); + } + + while let Ok(entry) = propose_receiver.recv() { + if message_sender_moved + .blocking_send(Message::FromClient(entry)) + .is_err() + { + log::error!("Can not forward new entry to consensus as it was stopped."); + break; + } + } + })?; + + let server_tls = if settings.cluster.p2p.enable_tls { + let tls_config = settings + .tls + .clone() + .ok_or_else(Settings::tls_config_is_undefined_error)?; + + Some(helpers::load_tls_internal_server_config(&tls_config)?) + } else { + None + }; + + let handle = thread::Builder::new() + .name("grpc_internal".to_string()) + .spawn(move || { + init_internal( + toc, + state_ref, + telemetry_collector, + settings, + p2p_host, + p2p_port, + server_tls, + message_sender, + runtime, + ) + }) + .unwrap(); + + Ok(handle) + } + + /// If `bootstrap_peer` peer is supplied, then either `uri` or `p2p_port` should be also supplied + #[allow(clippy::too_many_arguments)] + pub fn new( + logger: &slog::Logger, + state_ref: ConsensusStateRef, + bootstrap_peer: Option, + uri: Option, + p2p_port: u16, + config: ConsensusConfig, + tls_config: Option, + channel_service: ChannelService, + runtime: Handle, + reinit: bool, + ) -> anyhow::Result<(Self, Sender)> { + // If we want to re-initialize consensus, we need to prevent other peers + // from re-playing consensus WAL operations, as they should already have them applied. + // Do ensure that we are forcing compacting WAL on the first re-initialized peer, + // which should trigger snapshot transferring instead of replaying WAL. 
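Stated as a pure decision (a sketch of the two flags computed just below, not part of this patch): on re-initialisation, the origin peer, i.e. the one started without a bootstrap URI, compacts its own WAL so that the other peers have to catch up via snapshot transfer, while every bootstrapped peer clears its WAL so that it has to request that snapshot.

// Sketch only: the same decision as the two `let` bindings below,
// written as a standalone function with hypothetical names.
struct ReinitWalPlan {
    force_compact_wal: bool, // origin peer: compact WAL, followers must use snapshots
    clear_wal: bool,         // bootstrapped peer: drop WAL, request a snapshot instead
}

fn reinit_wal_plan(reinit: bool, has_bootstrap_peer: bool) -> ReinitWalPlan {
    ReinitWalPlan {
        force_compact_wal: reinit && !has_bootstrap_peer,
        clear_wal: reinit && has_bootstrap_peer,
    }
}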
+ let force_compact_wal = reinit && bootstrap_peer.is_none(); + + // On the bootstrap-ed peers during reinit of the consensus + // we want to make sure only the bootstrap peer will hold the true state + // Therefore we clear the WAL on the bootstrap peer to force it to request a snapshot + let clear_wal = reinit && bootstrap_peer.is_some(); + + if clear_wal { + log::debug!("Clearing WAL on the bootstrap peer to force snapshot transfer"); + state_ref.clear_wal()?; + } + + // raft will not return entries to the application smaller or equal to `applied` + let last_applied = state_ref.last_applied_entry().unwrap_or_default(); + let raft_config = Config { + id: state_ref.this_peer_id(), + applied: last_applied, + ..Default::default() + }; + raft_config.validate()?; + let op_wait = defaults::CONSENSUS_META_OP_WAIT; + // Commit might take up to 4 ticks as: + // 1 tick - send proposal to leader + // 2 tick - leader sends append entries to peers + // 3 tick - peer answers leader, that entry is persisted + // 4 tick - leader increases commit index and sends it + if 4 * Duration::from_millis(config.tick_period_ms) > op_wait { + log::warn!("With current tick period of {}ms, operation commit time might exceed default wait timeout: {}ms", + config.tick_period_ms, op_wait.as_millis()) + } + // bounded channel for backpressure + let (sender, receiver) = tokio::sync::mpsc::channel(config.max_message_queue_size); + // State might be initialized but the node might be shutdown without actually syncing or committing anything. + if state_ref.is_new_deployment() || reinit { + let leader_established_in_ms = + config.tick_period_ms * raft_config.max_election_tick() as u64; + Self::init( + &state_ref, + bootstrap_peer.clone(), + uri, + p2p_port, + &config, + tls_config.clone(), + &runtime, + leader_established_in_ms, + ) + .map_err(|err| anyhow!("Failed to initialize Consensus for new Raft state: {}", err))?; + } else { + runtime + .block_on(Self::recover( + &state_ref, + uri.clone(), + p2p_port, + &config, + tls_config.clone(), + )) + .map_err(|err| { + anyhow!( + "Failed to recover Consensus from existing Raft state: {}", + err + ) + })?; + + if bootstrap_peer.is_some() || uri.is_some() { + log::debug!("Local raft state found - bootstrap and uri cli arguments were ignored") + } + log::debug!("Local raft state found - skipping initialization"); + }; + + let mut node = Node::new(&raft_config, state_ref.clone(), logger)?; + node.set_batch_append(true); + + // Before consensus has started apply any unapplied committed entries + // They might have not been applied due to unplanned Qdrant shutdown + let _stop_consensus = state_ref.apply_entries(&mut node)?; + + if force_compact_wal { + // Making sure that the WAL will be compacted on start + state_ref.compact_wal(1)?; + } else { + state_ref.compact_wal(config.compact_wal_entries)?; + } + + let broker = RaftMessageBroker::new( + runtime.clone(), + bootstrap_peer, + tls_config, + config.clone(), + node.store().clone(), + channel_service.channel_pool, + ); + + let consensus = Self { + node, + receiver, + runtime, + config, + broker, + }; + + if !state_ref.is_new_deployment() { + state_ref.recover_first_voter()?; + } + + Ok((consensus, sender)) + } + + #[allow(clippy::too_many_arguments)] + fn init( + state_ref: &ConsensusStateRef, + bootstrap_peer: Option, + uri: Option, + p2p_port: u16, + config: &ConsensusConfig, + tls_config: Option, + runtime: &Handle, + leader_established_in_ms: u64, + ) -> anyhow::Result<()> { + if let Some(bootstrap_peer) = bootstrap_peer { + 
log::debug!("Bootstrapping from peer with address: {bootstrap_peer}"); + runtime.block_on(Self::bootstrap( + state_ref, + bootstrap_peer, + uri, + p2p_port, + config, + tls_config, + ))?; + Ok(()) + } else { + log::debug!( + "Bootstrapping is disabled. Assuming this peer is the first in the network" + ); + let tick_period = config.tick_period_ms; + log::info!("With current tick period of {tick_period}ms, leader will be established in approximately {leader_established_in_ms}ms. To avoid rejected operations - add peers and submit operations only after this period."); + // First peer needs to add its own address + state_ref.add_peer( + state_ref.this_peer_id(), + uri.ok_or_else(|| anyhow::anyhow!("First peer should specify its uri."))? + .parse()?, + )?; + Ok(()) + } + } + + async fn add_peer_to_known_for( + this_peer_id: PeerId, + cluster_uri: Uri, + current_uri: Option, + p2p_port: u16, + config: &ConsensusConfig, + tls_config: Option, + ) -> anyhow::Result { + // Use dedicated transport channel for bootstrapping because of specific timeout + let channel = make_grpc_channel( + Duration::from_secs(config.bootstrap_timeout_sec), + Duration::from_secs(config.bootstrap_timeout_sec), + cluster_uri, + tls_config, + ) + .await + .map_err(|err| anyhow!("Failed to create timeout channel: {err}"))?; + let mut client = RaftClient::new(channel); + let all_peers = client + .add_peer_to_known(tonic::Request::new( + api::grpc::qdrant::AddPeerToKnownMessage { + uri: current_uri, + port: Some(u32::from(p2p_port)), + id: this_peer_id, + }, + )) + .await + .map_err(|err| anyhow!("Failed to add peer to known: {err}"))? + .into_inner(); + Ok(all_peers) + } + + // Re-attach peer to the consensus: + // Notifies the cluster(any node) that this node changed its address + async fn recover( + state_ref: &ConsensusStateRef, + uri: Option, + p2p_port: u16, + config: &ConsensusConfig, + tls_config: Option, + ) -> anyhow::Result<()> { + let this_peer_id = state_ref.this_peer_id(); + let mut peer_to_uri = state_ref + .persistent + .read() + .peer_address_by_id + .read() + .clone(); + let this_peer_url = peer_to_uri.remove(&this_peer_id); + // Recover url if a different one is provided + let do_recover = match (&this_peer_url, &uri) { + (Some(this_peer_url), Some(uri)) => this_peer_url != &Uri::from_str(uri)?, + _ => false, + }; + + if do_recover { + let mut tries = RECOVERY_MAX_RETRY_COUNT; + while tries > 0 { + // Try to inform any peer about the change of address + for (peer_id, peer_uri) in &peer_to_uri { + let res = Self::add_peer_to_known_for( + this_peer_id, + peer_uri.clone(), + uri.clone(), + p2p_port, + config, + tls_config.clone(), + ) + .await; + if res.is_err() { + log::warn!( + "Failed to recover from peer with id {} at {} with error {:?}, trying others", + peer_id, + peer_uri, + res + ); + } else { + log::debug!( + "Successfully recovered from peer with id {} at {}", + peer_id, + peer_uri + ); + return Ok(()); + } + } + tries -= 1; + log::warn!( + "Retrying recovering from known peers (retry {})", + RECOVERY_MAX_RETRY_COUNT - tries + ); + let exp_timeout = + RECOVERY_RETRY_TIMEOUT * (RECOVERY_MAX_RETRY_COUNT - tries) as u32; + sleep(exp_timeout).await; + } + return Err(anyhow::anyhow!("Failed to recover from any known peers")); + } + + Ok(()) + } + + /// Add node sequence: + /// + /// 1. Add current node as a learner + /// 2. Start applying entries from consensus + /// 3. Eventually leader submits the promotion proposal + /// 4. 
Learners become voters once they read about the promotion from consensus log + async fn bootstrap( + state_ref: &ConsensusStateRef, + bootstrap_peer: Uri, + uri: Option, + p2p_port: u16, + config: &ConsensusConfig, + tls_config: Option, + ) -> anyhow::Result<()> { + let this_peer_id = state_ref.this_peer_id(); + let all_peers = Self::add_peer_to_known_for( + this_peer_id, + bootstrap_peer, + uri.clone(), + p2p_port, + config, + tls_config, + ) + .await?; + + // Although peer addresses are synchronized with consensus, addresses need to be pre-fetched in the case of a new peer + // or it will not know how to answer the Raft leader + for peer in all_peers.all_peers { + state_ref + .add_peer( + peer.id, + peer.uri + .parse() + .context(format!("Failed to parse peer URI: {}", peer.uri))?, + ) + .map_err(|err| anyhow!("Failed to add peer: {}", err))? + } + // Only first peer has itself as a voter in the initial conf state. + // This needs to be propagated manually to other peers as it is not contained in any log entry. + // So we skip the learner phase for the first peer. + state_ref.set_first_voter(all_peers.first_peer_id)?; + state_ref.set_conf_state(ConfState::from((vec![all_peers.first_peer_id], vec![])))?; + Ok(()) + } + + pub fn start(&mut self) -> anyhow::Result<()> { + // If this is the only peer in the cluster, tick Raft node a few times to instantly + // self-elect itself as Raft leader + if self.node.store().peer_count() == 1 { + while !self.node.has_ready() { + self.node.tick(); + } + } + + let tick_period = Duration::from_millis(self.config.tick_period_ms); + let mut previous_tick = Instant::now(); + + loop { + // Apply in-memory changes to the Raft State Machine + // If updates = None, we need to skip this step due to timing limits + // If updates = Some(0), means we didn't receive any updates explicitly + let updates = self.advance_node(previous_tick, tick_period)?; + + let mut elapsed = previous_tick.elapsed(); + + while elapsed > tick_period { + self.node.tick(); + + previous_tick += tick_period; + elapsed -= tick_period; + } + + if self.node.has_ready() { + // Persist AND apply changes, which were committed in the Raft State Machine + let stop_consensus = self.on_ready()?; + + if stop_consensus { + return Ok(()); + } + } else if updates == Some(0) { + // Assume consensus is up-to-date, we can sync local state + // Which involves resoling inconsistencies and trying to recover data marked as dead + self.try_sync_local_state()?; + } + } + } + + fn advance_node( + &mut self, + previous_tick: Instant, + tick_period: Duration, + ) -> anyhow::Result> { + if previous_tick.elapsed() >= tick_period { + return Ok(None); + } + + match self.try_add_origin() { + // `try_add_origin` is not applicable: + // - either current peer is not an origin peer + // - or cluster is already established + Ok(false) => (), + + // Successfully proposed origin peer to consensus, return to consensus loop to handle `on_ready` + Ok(true) => return Ok(Some(1)), + + // Origin peer is not a leader yet, wait for the next tick and return to consensus loop + // to tick Raft node + Err(err @ TryAddOriginError::NotLeader) => { + log::debug!("{err}"); + + let next_tick = previous_tick + tick_period; + let duration_until_next_tick = next_tick.saturating_duration_since(Instant::now()); + thread::sleep(duration_until_next_tick); + + return Ok(None); + } + + // Failed to propose origin peer ID to consensus (which should never happen!), + // log error and continue regular consensus loop + Err(err) => { + log::error!("{err}"); 
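Zooming back out to `Consensus::start` above: the loop never skips Raft ticks. It counts how much wall-clock time has passed since the last tick and emits one `tick()` per elapsed period, so a slow batch of work is caught up with afterwards. A self-contained, std-only sketch of that driver, with hypothetical names and not part of this patch:

use std::time::{Duration, Instant};

// Emit one tick per elapsed `tick_period`, no matter how long `do_work`
// took, mirroring the catch-up loop in `Consensus::start` above.
fn drive_ticks(
    tick_period: Duration,
    mut do_work: impl FnMut() -> bool, // returns false to stop the loop
    mut tick: impl FnMut(),
) {
    let mut previous_tick = Instant::now();
    loop {
        let keep_running = do_work();
        let mut elapsed = previous_tick.elapsed();
        while elapsed > tick_period {
            tick();
            previous_tick += tick_period;
            elapsed -= tick_period;
        }
        if !keep_running {
            break;
        }
    }
}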
+ } + } + + if self + .try_promote_learner() + .context("failed to promote learner")? + { + return Ok(Some(1)); + } + + let mut updates = 0; + let mut timeout_at = previous_tick + tick_period; + + // We need to limit the batch size, as application of one batch should be limited in time. + const RAFT_BATCH_SIZE: usize = 128; + + let wait_timeout_for_consecutive_messages = tick_period / 10; + + // This loop batches incoming messages, so we would need to "apply" them only once. + // The "Apply" step is expensive, so it is done for performance reasons. + + // But on the other hand, we still want to react to rare + // individual messages as fast as possible. + // To fulfill both requirements, we are going the following way: + // 1. Wait for the first message for full tick period. + // 2. If the message is received, wait for the next message only for 1/10 of the tick period. + loop { + // This queue have 2 types of events: + // - Messages from the leader, like pings, requests to add logs, acks, etc. + // - Messages from users, like requests to start shard transfers, etc. + // + // Timeout defines how long can we wait for the next message. + // Since this thread is sync, we can't wait indefinitely. + // Timeout is set up to be about the time of tick. + let Ok(message) = self.recv_update(timeout_at) else { + break; + }; + + // Those messages should not be batched, so we interrupt the loop if we see them. + // Motivation is: if we change the peer, it should be done immediately, + // otherwise we loose the update on this new peer + let is_conf_change = matches!( + message, + Message::FromClient( + ConsensusOperations::AddPeer { .. } | ConsensusOperations::RemovePeer(_) + ), + ); + + // We put the message in Raft State Machine + // This update will hold update in memory, but will not be persisted yet. + // E.g. if it is a ping, we don't need to persist anything ofr it. + if let Err(err) = self.advance_node_impl(message) { + log::warn!("{err}"); + continue; + } + + updates += 1; + timeout_at = Instant::now() + wait_timeout_for_consecutive_messages; + + if previous_tick.elapsed() >= tick_period + || updates >= RAFT_BATCH_SIZE + || is_conf_change + { + break; + } + } + + Ok(Some(updates)) + } + + fn recv_update(&mut self, timeout_at: Instant) -> Result { + self.runtime.block_on(async { + tokio::select! 
{ + biased; + _ = tokio::time::sleep_until(timeout_at.into()) => Err(TryRecvUpdateError::Timeout), + message = self.receiver.recv() => message.ok_or(TryRecvUpdateError::Closed), + } + }) + } + + fn advance_node_impl(&mut self, message: Message) -> anyhow::Result<()> { + match message { + Message::FromClient(ConsensusOperations::AddPeer { peer_id, uri }) => { + let mut change = ConfChangeV2::default(); + + change.set_changes(vec![raft_proto::new_conf_change_single( + peer_id, + ConfChangeType::AddLearnerNode, + )]); + + log::debug!("Proposing network configuration change: {:?}", change); + self.node + .propose_conf_change(uri.into_bytes(), change) + .context("failed to propose conf change")?; + } + + Message::FromClient(ConsensusOperations::RemovePeer(peer_id)) => { + let mut change = ConfChangeV2::default(); + + change.set_changes(vec![raft_proto::new_conf_change_single( + peer_id, + ConfChangeType::RemoveNode, + )]); + + log::debug!("Proposing network configuration change: {:?}", change); + self.node + .propose_conf_change(vec![], change) + .context("failed to propose conf change")?; + } + + Message::FromClient(ConsensusOperations::RequestSnapshot) => { + self.node + .request_snapshot() + .context("failed to request snapshot")?; + } + + Message::FromClient(ConsensusOperations::ReportSnapshot { peer_id, status }) => { + self.node.report_snapshot(peer_id, status.into()); + } + + Message::FromClient(operation) => { + let data = + serde_cbor::to_vec(&operation).context("failed to serialize operation")?; + + log::trace!("Proposing entry from client with length: {}", data.len()); + self.node + .propose(vec![], data) + .context("failed to propose entry")?; + } + + Message::FromPeer(message) => { + let is_heartbeat = matches!( + message.get_msg_type(), + MessageType::MsgHeartbeat | MessageType::MsgHeartbeatResponse, + ); + + if !is_heartbeat { + log::trace!( + "Received a message from peer with progress: {:?}. Message: {:?}", + self.node.raft.prs().get(message.from), + message, + ); + } + + self.node.step(*message).context("failed to step message")?; + } + } + + Ok(()) + } + + fn try_sync_local_state(&mut self) -> anyhow::Result<()> { + if !self.node.has_ready() { + // No updates to process + let store = self.node.store(); + let pending_operations = store.persistent.read().unapplied_entities_count(); + if pending_operations == 0 && store.is_leader_established.check_ready() { + // If leader is established and there is nothing else to do on this iteration, + // then we can check if there are any un-synchronized local state left. + store.sync_local_state()?; + } + } + Ok(()) + } + + /// Tries to propose "origin peer" (the very first peer, that starts new cluster) to consensus + fn try_add_origin(&mut self) -> Result { + // We can determine origin peer from consensus state: + // - it should be the only peer in the cluster + // - and its commit index should be at 0 or 1 + // + // When we add a new node to existing cluster, we have to bootstrap it from existing cluster + // node, and during bootstrap we explicitly add all current peers to consensus state. So, + // *all* peers added to the cluster after the origin will always have at least two peers. + // + // When origin peer starts new cluster, it self-elects itself as a leader and commits empty + // operation with index 1. It is impossible to commit anything to consensus before this + // operation is committed. And to add another (second/third/etc) peer to the cluster, we + // have to commit a conf-change operation. 
Which means that only origin peer can ever be at + // commit index 0 or 1. + + // Check that we are the only peer in the cluster + if self.node.store().peer_count() > 1 { + return Ok(false); + } + + let status = self.node.status(); + + // Check that we are at index 0 or 1 + if status.hs.commit > 1 { + return Ok(false); + } + + // If we reached this point, we are the origin peer, but it's impossible to propose anything + // to consensus, before leader is elected (`propose_conf_change` will return an error), + // so we have to wait for a few ticks for self-election + if status.ss.raft_state != StateRole::Leader { + return Err(TryAddOriginError::NotLeader); + } + + // Propose origin peer to consensus + let mut change = ConfChangeV2::default(); + + change.set_changes(vec![raft_proto::new_conf_change_single( + status.id, + ConfChangeType::AddNode, + )]); + + let peer_uri = self + .node + .store() + .persistent + .read() + .peer_address_by_id + .read() + .get(&status.id) + .ok_or_else(|| TryAddOriginError::UriNotFound)? + .to_string(); + + self.node.propose_conf_change(peer_uri.into(), change)?; + + Ok(true) + } + + /// Returns `true` if learner promotion was proposed, `false` otherwise. + /// Learner node does not vote on elections, cause it might not have a big picture yet. + /// So consensus should guarantee that learners are promoted one-by-one. + /// Promotions are done by leader and only after it has no pending entries, + /// that guarantees that learner will start voting only after it applies all the changes in the log + fn try_promote_learner(&mut self) -> anyhow::Result { + // Promote only if leader + if self.node.status().ss.raft_state != StateRole::Leader { + return Ok(false); + } + + // Promote only when there are no uncommitted changes. + let store = self.node.store(); + let commit = store.hard_state().commit; + let last_log_entry = store.last_index()?; + + if commit != last_log_entry { + return Ok(false); + } + + let Some(learner) = self.find_learner_to_promote() else { + return Ok(false); + }; + + log::debug!("Proposing promotion for learner {learner} to voter"); + + let mut change = ConfChangeV2::default(); + + change.set_changes(vec![raft_proto::new_conf_change_single( + learner, + ConfChangeType::AddNode, + )]); + + self.node.propose_conf_change(vec![], change)?; + + Ok(true) + } + + fn find_learner_to_promote(&self) -> Option { + let commit = self.node.store().hard_state().commit; + let learners: HashSet<_> = self + .node + .store() + .conf_state() + .learners + .into_iter() + .collect(); + let status = self.node.status(); + status + .progress? + .iter() + .find(|(id, progress)| learners.contains(id) && progress.matched == commit) + .map(|(id, _)| *id) + } + + /// Returns `true` if consensus should be stopped, `false` otherwise. + fn on_ready(&mut self) -> anyhow::Result { + if !self.node.has_ready() { + // No updates to process + return Ok(false); + } + self.store().record_consensus_working(); + // Get the `Ready` with `RawNode::ready` interface. + let ready = self.node.ready(); + + let (Some(light_ready), role_change) = self.process_ready(ready)? else { + // No light ready, so we need to stop consensus. 
+ return Ok(true); + }; + + let result = self.process_light_ready(light_ready)?; + + if let Some(role_change) = role_change { + self.process_role_change(role_change); + } + + self.store().compact_wal(self.config.compact_wal_entries)?; + + Ok(result) + } + + fn process_role_change(&self, role_change: StateRole) { + // Explicit match here for better readability + match role_change { + StateRole::Candidate | StateRole::PreCandidate => { + self.store().is_leader_established.make_not_ready() + } + StateRole::Leader | StateRole::Follower => { + if self.node.raft.leader_id != INVALID_ID { + self.store().is_leader_established.make_ready() + } else { + self.store().is_leader_established.make_not_ready() + } + } + } + } + + /// Tries to process raft's ready state. Happens on each tick. + /// + /// The order of operations in this functions is critical, changing it might lead to bugs. + /// + /// Returns with err on failure to apply the state. + /// If it receives message to stop the consensus - returns None instead of LightReady. + fn process_ready( + &mut self, + mut ready: raft::Ready, + ) -> anyhow::Result<(Option, Option)> { + let store = self.store(); + + if !ready.messages().is_empty() { + log::trace!("Handling {} messages", ready.messages().len()); + self.send_messages(ready.take_messages()); + } + if !ready.snapshot().is_empty() { + // This is a snapshot, we need to apply the snapshot at first. + log::debug!("Applying snapshot"); + + if let Err(err) = store.apply_snapshot(&ready.snapshot().clone())? { + log::error!("Failed to apply snapshot: {err}"); + } + } + if !ready.entries().is_empty() { + // Append entries to the Raft log. + log::debug!("Appending {} entries to raft log", ready.entries().len()); + store + .append_entries(ready.take_entries()) + .map_err(|err| anyhow!("Failed to append entries: {}", err))? + } + if let Some(hs) = ready.hs() { + // Raft HardState changed, and we need to persist it. + log::debug!("Changing hard state. New hard state: {hs:?}"); + store + .set_hard_state(hs.clone()) + .map_err(|err| anyhow!("Failed to set hard state: {}", err))? + } + let role_change = ready.ss().map(|ss| ss.raft_state); + if let Some(ss) = ready.ss() { + log::debug!("Changing soft state. New soft state: {ss:?}"); + self.handle_soft_state(ss); + } + if !ready.persisted_messages().is_empty() { + log::trace!( + "Handling {} persisted messages", + ready.persisted_messages().len() + ); + self.send_messages(ready.take_persisted_messages()); + } + // Should be done after Hard State is saved, so that `applied` index is never bigger than `commit`. + let stop_consensus = + handle_committed_entries(&ready.take_committed_entries(), &store, &mut self.node) + .context("Failed to handle committed entries")?; + if stop_consensus { + return Ok((None, None)); + } + + // Advance the Raft. + let light_rd = self.node.advance(ready); + Ok((Some(light_rd), role_change)) + } + + /// Tries to process raft's light ready state. + /// + /// The order of operations in this functions is critical, changing it might lead to bugs. + /// + /// Returns with err on failure to apply the state. + /// If it receives message to stop the consensus - returns `true`, otherwise `false`. + fn process_light_ready(&mut self, mut light_rd: raft::LightReady) -> anyhow::Result { + let store = self.store(); + // Update commit index. 
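As an aside for readers less familiar with the raft-rs API: the ordering that `process_ready` and `process_light_ready` depend on can be condensed into a short skeleton. The sketch below is illustrative only and not part of this patch; `send`, `persist` and `apply` are placeholder closures standing in for the message broker and the consensus store used in the real code.

```rust
// Illustrative sketch (not part of this patch): one Ready-handling pass in the
// same order as `process_ready` / `process_light_ready` above, with side
// effects stubbed out as closures.
fn ready_pass<S: raft::Storage>(
    node: &mut raft::RawNode<S>,
    mut send: impl FnMut(Vec<raft::eraftpb::Message>),
    mut persist: impl FnMut(&raft::Ready),
    mut apply: impl FnMut(&[raft::eraftpb::Entry]),
) {
    if !node.has_ready() {
        return; // nothing to do on this tick
    }
    let mut ready = node.ready();
    send(ready.take_messages());            // 1. messages that may be sent before persisting
    persist(&ready);                        // 2. snapshot, entries and hard state go to stable storage
    send(ready.take_persisted_messages());  // 3. messages that must only be sent after persisting
    apply(&ready.take_committed_entries()); // 4. apply committed entries to the state machine
    let mut light = node.advance(ready);    // 5. advance produces a LightReady
    // (the LightReady's commit index would be persisted here as well)
    send(light.take_messages());            // 6. remaining messages
    apply(&light.take_committed_entries()); // 7. newly committed entries
    node.advance_apply();                   // 8. tell raft-rs the applied index moved forward
}
```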
+ if let Some(commit) = light_rd.commit_index() { + log::debug!("Updating commit index to {commit}"); + store + .set_commit_index(commit) + .map_err(|err| anyhow!("Failed to set commit index: {}", err))?; + } + self.send_messages(light_rd.take_messages()); + // Apply all committed entries. + let stop_consensus = + handle_committed_entries(&light_rd.take_committed_entries(), &store, &mut self.node) + .context("Failed to apply committed entries")?; + // Advance the apply index. + self.node.advance_apply(); + Ok(stop_consensus) + } + + fn store(&self) -> ConsensusStateRef { + self.node.store().clone() + } + + fn handle_soft_state(&self, state: &SoftState) { + let store = self.node.store(); + store.set_raft_soft_state(state); + } + + fn send_messages(&mut self, messages: Vec) { + self.broker.send(messages); + } +} + +#[derive(Copy, Clone, Debug, thiserror::Error)] +enum TryRecvUpdateError { + #[error("timeout elapsed")] + Timeout, + + #[error("channel closed")] + Closed, +} + +#[derive(Debug, thiserror::Error)] +enum TryAddOriginError { + #[error("origin peer is not a leader")] + NotLeader, + + #[error("origin peer URI not found")] + UriNotFound, + + #[error("failed to propose origin peer URI to consensus: {0}")] + RaftError(#[from] raft::Error), +} + +/// This function actually applies the committed entries to the state machine. +/// Return `true` if consensus should be stopped. +/// `false` otherwise. +fn handle_committed_entries( + entries: &[Entry], + state: &ConsensusStateRef, + raw_node: &mut RawNode, +) -> anyhow::Result { + let mut stop_consensus = false; + if let (Some(first), Some(last)) = (entries.first(), entries.last()) { + state.set_unapplied_entries(first.index, last.index)?; + stop_consensus = state.apply_entries(raw_node)?; + } + Ok(stop_consensus) +} + +struct RaftMessageBroker { + senders: HashMap, + runtime: Handle, + bootstrap_uri: Option, + tls_config: Option, + consensus_config: Arc, + consensus_state: ConsensusStateRef, + transport_channel_pool: Arc, +} + +impl RaftMessageBroker { + pub fn new( + runtime: Handle, + bootstrap_uri: Option, + tls_config: Option, + consensus_config: ConsensusConfig, + consensus_state: ConsensusStateRef, + transport_channel_pool: Arc, + ) -> Self { + Self { + senders: HashMap::new(), + runtime, + bootstrap_uri, + tls_config, + consensus_config: consensus_config.into(), + consensus_state, + transport_channel_pool, + } + } + + pub fn send(&mut self, messages: impl IntoIterator) { + let mut messages = messages.into_iter(); + let mut retry = None; + + while let Some(message) = retry.take().or_else(|| messages.next()) { + let peer_id = message.to; + + let sender = match self.senders.get_mut(&peer_id) { + Some(sender) => sender, + None => { + log::debug!("Spawning message sender task for peer {peer_id}..."); + + let (task, handle) = self.message_sender(); + let future = self.runtime.spawn(task.exec()); + drop(future); // drop `JoinFuture` explicitly to make clippy happy + + self.senders.insert(peer_id, handle); + + self.senders + .get_mut(&peer_id) + .expect("message sender task spawned") + } + }; + + let failed_to_forward = |message: &RaftMessage, description: &str| { + let peer_id = message.to; + + let is_debug = log::max_level() >= log::Level::Debug; + let space = if is_debug { " " } else { "" }; + let message: &dyn fmt::Debug = if is_debug { &message } else { &"" }; // TODO: `fmt::Debug` for `""` prints `""`... 
😒 + + log::error!( + "Failed to forward message{space}{message:?} to message sender task {peer_id}: \ + {description}" + ); + }; + + match sender.send(message) { + Ok(()) => (), + + Err(tokio::sync::mpsc::error::TrySendError::Full((_, message))) => { + failed_to_forward( + &message, + "message sender task queue is full. Message will be dropped.", + ); + } + + Err(tokio::sync::mpsc::error::TrySendError::Closed((_, message))) => { + failed_to_forward( + &message, + "message sender task queue is closed. \ + Message sender task will be restarted and message will be retried.", + ); + + self.senders.remove(&peer_id); + retry = Some(message); + } + } + } + } + + fn message_sender(&self) -> (RaftMessageSender, RaftMessageSenderHandle) { + let (messages_tx, messages_rx) = tokio::sync::mpsc::channel(128); + let (heartbeat_tx, heartbeat_rx) = tokio::sync::watch::channel(Default::default()); + + let task = RaftMessageSender { + messages: messages_rx, + heartbeat: heartbeat_rx, + bootstrap_uri: self.bootstrap_uri.clone(), + tls_config: self.tls_config.clone(), + consensus_config: self.consensus_config.clone(), + consensus_state: self.consensus_state.clone(), + transport_channel_pool: self.transport_channel_pool.clone(), + }; + + let handle = RaftMessageSenderHandle { + messages: messages_tx, + heartbeat: heartbeat_tx, + index: 0, + }; + + (task, handle) + } +} + +#[derive(Debug)] +struct RaftMessageSenderHandle { + messages: Sender<(usize, RaftMessage)>, + heartbeat: watch::Sender<(usize, RaftMessage)>, + index: usize, +} + +impl RaftMessageSenderHandle { + #[allow(clippy::result_large_err)] + pub fn send(&mut self, message: RaftMessage) -> RaftMessageSenderResult<()> { + if !is_heartbeat(&message) { + self.messages.try_send((self.index, message))?; + } else { + self.heartbeat.send((self.index, message)).map_err( + |tokio::sync::watch::error::SendError(message)| { + tokio::sync::mpsc::error::TrySendError::Closed(message) + }, + )?; + } + + self.index += 1; + + Ok(()) + } +} + +type RaftMessageSenderResult = Result; +type RaftMessageSenderError = tokio::sync::mpsc::error::TrySendError<(usize, RaftMessage)>; + +struct RaftMessageSender { + messages: Receiver<(usize, RaftMessage)>, + heartbeat: watch::Receiver<(usize, RaftMessage)>, + bootstrap_uri: Option, + tls_config: Option, + consensus_config: Arc, + consensus_state: ConsensusStateRef, + transport_channel_pool: Arc, +} + +impl RaftMessageSender { + pub async fn exec(mut self) { + // Imagine that `raft` crate put four messages to be sent to some other Raft node into + // `RaftMessageSender`'s queue: + // + // | 4: AppendLog | 3: Heartbeat | 2: Heartbeat | 1: AppendLog | + // + // Heartbeat is the most basic message type in Raft. It only carries common "metadata" + // without any additional "payload". And all other message types in Raft also carry + // the same basic metadata as the heartbeat message. + // + // This way, message `3` instantly "outdates" message `2`: they both carry the same data + // fields, but message `3` was produced more recently, and so it might contain newer values + // of these data fields. + // + // And because all messages carry the same basic data as the heartbeat message, message `4` + // instantly "outdates" both message `2` and `3`. + // + // This way, if there are more than one message queued for the `RaftMessageSender`, + // we can optimize delivery a bit and skip any heartbeat message if there's a more + // recent message scheduled later in the queue. 
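Before the comment continues below, the queueing scheme it describes can be reduced to a small self-contained sketch: non-heartbeat messages go through a bounded mpsc queue, heartbeats go through a watch channel that keeps only the latest value, and a biased `select!` plus a monotonically increasing index drops heartbeats that were overtaken by later messages. This is an illustration only, not part of the patch; `Msg`, `Handle`, `channel` and `sender_loop` are made-up names.

```rust
use tokio::sync::{mpsc, watch};

// Stand-in for `RaftMessage`; only the heartbeat flag matters for the sketch.
#[derive(Clone, Debug, Default)]
struct Msg {
    heartbeat: bool,
}

struct Handle {
    messages: mpsc::Sender<(usize, Msg)>,
    heartbeat: watch::Sender<(usize, Msg)>,
    index: usize,
}

fn channel() -> (Handle, mpsc::Receiver<(usize, Msg)>, watch::Receiver<(usize, Msg)>) {
    let (messages_tx, messages_rx) = mpsc::channel(128);
    let (heartbeat_tx, heartbeat_rx) = watch::channel((0, Msg::default()));
    let handle = Handle { messages: messages_tx, heartbeat: heartbeat_tx, index: 0 };
    (handle, messages_rx, heartbeat_rx)
}

impl Handle {
    fn send(&mut self, msg: Msg) {
        if msg.heartbeat {
            // The watch channel keeps only the most recent heartbeat.
            let _ = self.heartbeat.send((self.index, msg));
        } else {
            // The bounded queue provides back-pressure for everything else.
            let _ = self.messages.try_send((self.index, msg));
        }
        self.index += 1;
    }
}

async fn sender_loop(
    mut messages: mpsc::Receiver<(usize, Msg)>,
    mut heartbeat: watch::Receiver<(usize, Msg)>,
) {
    let mut prev_index = 0;
    loop {
        let (index, msg) = tokio::select! {
            biased; // drain real messages first, fall back to heartbeats
            Some(item) = messages.recv() => item,
            Ok(()) = heartbeat.changed() => heartbeat.borrow_and_update().clone(),
            else => break, // both senders dropped => shutdown
        };
        // Skip heartbeats that were overtaken by a later message.
        if index >= prev_index {
            prev_index = index;
            let _ = msg; // the real implementation would send `msg` over the network here
        }
    }
}
```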
+ //
+ // `RaftMessageSender` has two separate "queues":
+ // - the `messages` queue for non-heartbeat messages
+ // - and the `heartbeat` "watch" channel for heartbeat messages
+ // - "watch" is a special channel in Tokio that only retains the *last* sent value
+ // - so any heartbeat received from the `heartbeat` channel is always the *most recent* one
+ //
+ // We are using `tokio::select` to "simultaneously" check both queues for new messages...
+ // but we are using `tokio::select` in "biased" mode!
+ //
+ // - in this mode select always polls the `messages.recv()` future first
+ // - so even if there are new messages in both queues, it will always return a non-heartbeat
+ // message from the `messages` queue first
+ // - and it will only return a heartbeat message from the `heartbeat` channel if there are no
+ // messages left in the `messages` queue
+ //
+ // There's one special case that we should be careful about with our two queues:
+ //
+ // If we return to the diagram above, and imagine four messages were sent in the same order
+ // into our two queues, then `RaftMessageSender` might pull them from the queues in the
+ // `1`, `4`, `3` order.
+ //
+ // E.g., we pull non-heartbeat messages `1` and `4` first, heartbeat `2` was overwritten
+ // by heartbeat `3` (because of the "watch" channel), so once the `messages` queue is empty
+ // we receive heartbeat `3`, which is now out-of-order.
+ //
+ // To handle this we explicitly enumerate each message and only send a message if its index
+ // is higher than or equal to the index of the previous one. (This check could use either a
+ // strict "higher" or a "higher-or-equal" comparison; the "or-equal" version just reads a
+ // bit better.)
+ //
+ // If either the `messages` queue or the `heartbeat` channel is closed (e.g., `messages.recv()`
+ // returns `None` or `heartbeat.changed()` returns an error), we assume that
+ // `RaftMessageSenderHandle` has been dropped, and treat it as a "shutdown"/"cancellation"
+ // signal (and break from the loop).
+
+ let mut prev_index = 0;
+
+ loop {
+ let (index, message) = tokio::select!
{ + biased; + Some(message) = self.messages.recv() => message, + Ok(()) = self.heartbeat.changed() => self.heartbeat.borrow_and_update().clone(), + else => break, + }; + + if prev_index <= index { + self.send(&message).await; + prev_index = index; + } + } + } + + async fn send(&mut self, message: &RaftMessage) { + if let Err(err) = self.try_send(message).await { + let peer_id = message.to; + + if log::max_level() >= log::Level::Debug { + log::error!("Failed to send Raft message {message:?} to peer {peer_id}: {err}"); + } else { + log::error!("Failed to send Raft message to peer {peer_id}: {err}"); + } + } + } + + async fn try_send(&mut self, message: &RaftMessage) -> anyhow::Result<()> { + let peer_id = message.to; + + let uri = self.uri(peer_id).await?; + + let mut bytes = Vec::new(); + ::encode(message, &mut bytes) + .context("failed to encode Raft message")?; + let grpc_message = GrpcRaftMessage { message: bytes }; + + let timeout = Duration::from_millis( + self.consensus_config.message_timeout_ticks * self.consensus_config.tick_period_ms, + ); + + let res = self + .transport_channel_pool + .with_channel_timeout( + &uri, + |channel| async { + let mut client = RaftClient::new(channel); + let mut request = tonic::Request::new(grpc_message.clone()); + request.set_timeout(timeout); + client.send(request).await + }, + Some(timeout), + 0, + ) + .await; + + if message.msg_type == raft::eraftpb::MessageType::MsgSnapshot as i32 { + let res = self.consensus_state.report_snapshot( + peer_id, + if res.is_ok() { + SnapshotStatus::Finish + } else { + SnapshotStatus::Failure + }, + ); + + // Should we ignore the error? Seems like it will only produce noise. + // + // - `send_message` is only called by the sub-task spawned by the consensus thread. + // - `report_snapshot` sends a message back to the consensus thread. + // - It can only fail, if the "receiver" end of the channel is closed. + // - Which means consensus thread either resolved successfully, or failed. + // - So, if the consensus thread is shutting down, no need to log a misleading error... + // - ...or, if the consensus thread failed, then we should already have an error, + // and it will only produce more noise. + + if let Err(err) = res { + log::error!("{}", err); + } + } + + match res { + Ok(_) => self.consensus_state.record_message_send_success(&uri), + Err(err) => self.consensus_state.record_message_send_failure(&uri, err), + } + + Ok(()) + } + + async fn uri(&mut self, peer_id: PeerId) -> anyhow::Result { + let uri = self + .consensus_state + .peer_address_by_id() + .get(&peer_id) + .cloned(); + + match uri { + Some(uri) => Ok(uri), + None => self.who_is(peer_id).await, + } + } + + async fn who_is(&mut self, peer_id: PeerId) -> anyhow::Result { + let bootstrap_uri = self + .bootstrap_uri + .clone() + .ok_or_else(|| anyhow::format_err!("No bootstrap URI provided"))?; + + let bootstrap_timeout = Duration::from_secs(self.consensus_config.bootstrap_timeout_sec); + + // Use dedicated transport channel for who_is because of specific timeout + let channel = make_grpc_channel( + bootstrap_timeout, + bootstrap_timeout, + bootstrap_uri, + self.tls_config.clone(), + ) + .await + .map_err(|err| anyhow::format_err!("Failed to create who-is channel: {}", err))?; + + let uri = RaftClient::new(channel) + .who_is(tonic::Request::new(GrpcPeerId { id: peer_id })) + .await? 
+ .into_inner() + .uri + .parse()?; + + Ok(uri) + } +} + +fn is_heartbeat(message: &RaftMessage) -> bool { + message.msg_type == raft::eraftpb::MessageType::MsgHeartbeat as i32 + || message.msg_type == raft::eraftpb::MessageType::MsgHeartbeatResponse as i32 +} + +#[cfg(test)] +mod tests { + use std::sync::Arc; + use std::thread; + + use collection::operations::vector_params_builder::VectorParamsBuilder; + use collection::shards::channel_service::ChannelService; + use common::cpu::CpuBudget; + use segment::types::Distance; + use slog::Drain; + use storage::content_manager::collection_meta_ops::{ + CollectionMetaOperations, CreateCollection, CreateCollectionOperation, + }; + use storage::content_manager::consensus::operation_sender::OperationSender; + use storage::content_manager::consensus::persistent::Persistent; + use storage::content_manager::consensus_manager::{ConsensusManager, ConsensusStateRef}; + use storage::content_manager::toc::TableOfContent; + use storage::dispatcher::Dispatcher; + use storage::rbac::Access; + use tempfile::Builder; + + use super::Consensus; + use crate::common::helpers::create_general_purpose_runtime; + use crate::settings::ConsensusConfig; + + #[test] + fn collection_creation_passes_consensus() { + // Given + let storage_dir = Builder::new().prefix("storage").tempdir().unwrap(); + let mut settings = crate::Settings::new(None).expect("Can't read config."); + settings.storage.storage_path = storage_dir.path().to_str().unwrap().to_string(); + tracing_subscriber::fmt::init(); + let search_runtime = + crate::create_search_runtime(settings.storage.performance.max_search_threads) + .expect("Can't create search runtime."); + let update_runtime = + crate::create_update_runtime(settings.storage.performance.max_search_threads) + .expect("Can't create update runtime."); + let general_runtime = + create_general_purpose_runtime().expect("Can't create general purpose runtime."); + let handle = general_runtime.handle().clone(); + let (propose_sender, propose_receiver) = std::sync::mpsc::channel(); + let persistent_state = + Persistent::load_or_init(&settings.storage.storage_path, true, false).unwrap(); + let operation_sender = OperationSender::new(propose_sender); + let toc = TableOfContent::new( + &settings.storage, + search_runtime, + update_runtime, + general_runtime, + CpuBudget::default(), + ChannelService::new(settings.service.http_port, None), + persistent_state.this_peer_id(), + Some(operation_sender.clone()), + ); + let toc_arc = Arc::new(toc); + let storage_path = toc_arc.storage_path(); + let consensus_state: ConsensusStateRef = ConsensusManager::new( + persistent_state, + toc_arc.clone(), + operation_sender, + storage_path, + ) + .into(); + let dispatcher = Dispatcher::new(toc_arc.clone()).with_consensus(consensus_state.clone()); + let slog_logger = slog::Logger::root(slog_stdlog::StdLog.fuse(), slog::o!()); + let (mut consensus, message_sender) = Consensus::new( + &slog_logger, + consensus_state.clone(), + None, + Some("http://127.0.0.1:6335".parse().unwrap()), + 6335, + ConsensusConfig::default(), + None, + ChannelService::new(settings.service.http_port, None), + handle.clone(), + false, + ) + .unwrap(); + + let is_leader_established = consensus_state.is_leader_established.clone(); + thread::spawn(move || consensus.start().unwrap()); + thread::spawn(move || { + while let Ok(entry) = propose_receiver.recv() { + if message_sender + .blocking_send(super::Message::FromClient(entry)) + .is_err() + { + log::error!("Can not forward new entry to consensus as it was 
stopped."); + break; + } + } + }); + // Wait for Raft to establish the leader + is_leader_established.await_ready(); + // Leader election produces a raft log entry, and then origin peer adds itself to consensus + assert_eq!(consensus_state.hard_state().commit, 2); + // Initially there are 0 collections + assert_eq!(toc_arc.all_collections_sync().len(), 0); + + // When + + // New runtime is used as timers need to be enabled. + handle + .block_on( + dispatcher.submit_collection_meta_op( + CollectionMetaOperations::CreateCollection(CreateCollectionOperation::new( + "test".to_string(), + CreateCollection { + vectors: VectorParamsBuilder::new(10, Distance::Cosine) + .build() + .into(), + sparse_vectors: None, + hnsw_config: None, + wal_config: None, + optimizers_config: None, + shard_number: Some(2), + on_disk_payload: None, + replication_factor: None, + write_consistency_factor: None, + init_from: None, + quantization_config: None, + sharding_method: None, + strict_mode_config: None, + uuid: None, + }, + )), + Access::full("For test"), + None, + ), + ) + .unwrap(); + + // Then + assert_eq!(consensus_state.hard_state().commit, 5); // first peer self-election + add first peer + create collection + activate shard x2 + assert_eq!(toc_arc.all_collections_sync(), vec!["test"]); + } +} diff --git a/src/greeting.rs b/src/greeting.rs new file mode 100644 index 0000000000000000000000000000000000000000..607e0290434af4f77bcf53cc8388dd3f062e6e41 --- /dev/null +++ b/src/greeting.rs @@ -0,0 +1,138 @@ +use std::cmp::min; +use std::env; +use std::io::{stdout, IsTerminal}; +use std::net::{IpAddr, Ipv4Addr, Ipv6Addr}; + +use api::rest::models::get_git_commit_id; +use colored::{Color, ColoredString, Colorize}; + +use crate::settings::Settings; + +fn paint_red(text: &str, true_color: bool) -> ColoredString { + if true_color { + text.bold().truecolor(184, 20, 56) + } else { + text.bold().color(Color::Red) + } +} + +fn paint_green(text: &str, true_color: bool) -> ColoredString { + if true_color { + text.truecolor(134, 186, 144) + } else { + text.color(Color::Green) + } +} + +fn paint_blue(text: &str, true_color: bool) -> ColoredString { + if true_color { + text.bold().truecolor(82, 139, 183) + } else { + text.bold().color(Color::Blue) + } +} + +/// Check whether the given IP will be reachable from `localhost` +/// +/// This is a static analysis based on (very) common defaults and doesn't probe the current +/// routing table. 
+fn is_localhost_ip(host: &str) -> bool { + let Ok(ip) = host.parse::() else { + return false; + }; + + // Unspecified IPs bind to all interfaces, so `localhost` always points to it + if ip == IpAddr::V4(Ipv4Addr::UNSPECIFIED) || ip == IpAddr::V6(Ipv6Addr::UNSPECIFIED) { + return true; + } + + // On all tested OSes IPv4 localhost points to `localhost` + if ip == IpAddr::V4(Ipv4Addr::LOCALHOST) { + return true; + } + + // On macOS IPv6 localhost points to `localhost`, on Linux it is `ip6-localhost` + if cfg!(target_os = "macos") && ip == IpAddr::V6(Ipv6Addr::LOCALHOST) { + return true; + } + + false +} + +/// Prints welcome message +pub fn welcome(settings: &Settings) { + if !stdout().is_terminal() { + colored::control::set_override(false); + } + + let mut true_color = true; + + match env::var("COLORTERM") { + Ok(val) => { + if val != "24bit" && val != "truecolor" { + true_color = false; + } + } + Err(_) => true_color = false, + } + + let title = [ + r" _ _ ", + r" __ _ __| |_ __ __ _ _ __ | |_ ", + r" / _` |/ _` | '__/ _` | '_ \| __| ", + r"| (_| | (_| | | | (_| | | | | |_ ", + r" \__, |\__,_|_| \__,_|_| |_|\__| ", + r" |_| ", + ]; + for line in title { + println!("{}", paint_red(line, true_color)); + } + println!(); + + // Print current version and, if available, first 8 characters of the git commit hash + let git_commit_info = get_git_commit_id() + .map(|git_commit| { + format!( + ", {} {}", + paint_green("build:", true_color), + paint_blue(&git_commit[..min(8, git_commit.len())], true_color), + ) + }) + .unwrap_or_default(); + + println!( + "{} {}{}", + paint_green("Version:", true_color), + paint_blue(env!("CARGO_PKG_VERSION"), true_color), + git_commit_info + ); + + // Print link to web UI + let ui_link = format!( + "http{}://{}:{}/dashboard", + if settings.service.enable_tls { "s" } else { "" }, + if is_localhost_ip(&settings.service.host) { + "localhost" + } else { + &settings.service.host + }, + settings.service.http_port + ); + + println!( + "{} {}", + paint_green("Access web UI at", true_color), + paint_blue(&ui_link, true_color).underline() + ); + println!(); +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_welcome() { + welcome(&Settings::new(None).unwrap()); + } +} diff --git a/src/issues_setup.rs b/src/issues_setup.rs new file mode 100644 index 0000000000000000000000000000000000000000..09779bd081ebef7ba92ea91e7e0e4f059753e72d --- /dev/null +++ b/src/issues_setup.rs @@ -0,0 +1,20 @@ +use std::time::Duration; + +use collection::events::{CollectionDeletedEvent, IndexCreatedEvent, SlowQueryEvent}; +use segment::problems::unindexed_field; +use storage::issues_subscribers::UnindexedFieldSubscriber; + +use crate::settings::Settings; + +pub fn setup_subscribers(settings: &Settings) { + settings + .service + .slow_query_secs + .map(|secs| unindexed_field::SLOW_QUERY_THRESHOLD.set(Duration::from_secs_f32(secs))); + + let unindexed_subscriber = UnindexedFieldSubscriber; + + issues::broker::add_subscriber::(Box::new(unindexed_subscriber)); + issues::broker::add_subscriber::(Box::new(unindexed_subscriber)); + issues::broker::add_subscriber::(Box::new(unindexed_subscriber)); +} diff --git a/src/main.rs b/src/main.rs new file mode 100644 index 0000000000000000000000000000000000000000..ba15e98bd89986eee22b0973abc5f318142a0149 --- /dev/null +++ b/src/main.rs @@ -0,0 +1,554 @@ +#[cfg(feature = "web")] +mod actix; +mod common; +mod consensus; +mod greeting; +mod issues_setup; +mod migrations; +mod settings; +mod snapshots; +mod startup; +mod tonic; +mod tracing; + +use 
std::io::Error; +use std::sync::Arc; +use std::thread; +use std::thread::JoinHandle; +use std::time::Duration; + +use ::common::cpu::{get_cpu_budget, CpuBudget}; +use ::tonic::transport::Uri; +use api::grpc::transport_channel_pool::TransportChannelPool; +use clap::Parser; +use collection::shards::channel_service::ChannelService; +use consensus::Consensus; +use slog::Drain; +use startup::setup_panic_hook; +use storage::content_manager::consensus::operation_sender::OperationSender; +use storage::content_manager::consensus::persistent::Persistent; +use storage::content_manager::consensus_manager::{ConsensusManager, ConsensusStateRef}; +use storage::content_manager::toc::dispatcher::TocDispatcher; +use storage::content_manager::toc::TableOfContent; +use storage::dispatcher::Dispatcher; +use storage::rbac::Access; +#[cfg(all( + not(target_env = "msvc"), + any(target_arch = "x86_64", target_arch = "aarch64") +))] +use tikv_jemallocator::Jemalloc; + +use crate::common::helpers::{ + create_general_purpose_runtime, create_search_runtime, create_update_runtime, + load_tls_client_config, +}; +use crate::common::inference::service::InferenceService; +use crate::common::telemetry::TelemetryCollector; +use crate::common::telemetry_reporting::TelemetryReporter; +use crate::greeting::welcome; +use crate::migrations::single_to_cluster::handle_existing_collections; +use crate::settings::Settings; +use crate::snapshots::{recover_full_snapshot, recover_snapshots}; +use crate::startup::{remove_started_file_indicator, touch_started_file_indicator}; + +#[cfg(all( + not(target_env = "msvc"), + any(target_arch = "x86_64", target_arch = "aarch64") +))] +#[global_allocator] +static GLOBAL: Jemalloc = Jemalloc; + +const FULL_ACCESS: Access = Access::full("For main"); + +/// Qdrant (read: quadrant ) is a vector similarity search engine. +/// It provides a production-ready service with a convenient API to store, search, and manage points - vectors with an additional payload. +/// +/// This CLI starts a Qdrant peer/server. +#[derive(Parser, Debug)] +#[command(version, about)] +struct Args { + /// Uri of the peer to bootstrap from in case of multi-peer deployment. + /// If not specified - this peer will be considered as a first in a new deployment. + #[arg(long, value_parser, value_name = "URI")] + bootstrap: Option, + /// Uri of this peer. + /// Other peers should be able to reach it by this uri. + /// + /// This value has to be supplied if this is the first peer in a new deployment. + /// + /// In case this is not the first peer and it bootstraps the value is optional. + /// If not supplied then qdrant will take internal grpc port from config and derive the IP address of this peer on bootstrap peer (receiving side) + #[arg(long, value_parser, value_name = "URI")] + uri: Option, + + /// Force snapshot re-creation + /// If provided - existing collections will be replaced with snapshots. + /// Default is to not recreate from snapshots. + #[arg(short, long, action, default_value_t = false)] + force_snapshot: bool, + + /// List of paths to snapshot files. + /// Format: : + /// + /// WARN: Do not use this option if you are recovering collection in existing distributed cluster. + /// Use `/collections//snapshots/recover` API instead. + #[arg(long, value_name = "PATH:NAME", alias = "collection-snapshot")] + snapshot: Option>, + + /// Path to snapshot of multiple collections. + /// Format: + /// + /// WARN: Do not use this option if you are recovering collection in existing distributed cluster. 
+ /// Use `/collections//snapshots/recover` API instead. + #[arg(long, value_name = "PATH")] + storage_snapshot: Option, + + /// Path to an alternative configuration file. + /// Format: + /// + /// Default path: config/config.yaml + #[arg(long, value_name = "PATH")] + config_path: Option, + + /// Disable telemetry sending to developers + /// If provided - telemetry collection will be disabled. + /// Read more: + #[arg(long, action, default_value_t = false)] + disable_telemetry: bool, + + /// Run stacktrace collector. Used for debugging. + #[arg(long, action, default_value_t = false)] + stacktrace: bool, + + /// Reinit consensus state. + /// When enabled, the service will assume the consensus should be reinitialized. + /// The exact behavior depends on if this current node has bootstrap URI or not. + /// If it has - it'll remove current consensus state and consensus WAL (while keeping peer ID) + /// and will try to receive state from the bootstrap peer. + /// If it doesn't have - it'll remove other peers from voters promote + /// the current peer to the leader and the single member of the cluster. + /// It'll also compact consensus WAL to force snapshot + #[arg(long, action, default_value_t = false)] + reinit: bool, +} + +fn main() -> anyhow::Result<()> { + let args = Args::parse(); + + // Run backtrace collector, expected to used by `rstack` crate + if args.stacktrace { + #[cfg(all(target_os = "linux", feature = "stacktrace"))] + { + let _ = rstack_self::child(); + } + return Ok(()); + } + + remove_started_file_indicator(); + + let settings = Settings::new(args.config_path)?; + + let reporting_enabled = !settings.telemetry_disabled && !args.disable_telemetry; + + let reporting_id = TelemetryCollector::generate_id(); + + let logger_handle = tracing::setup( + settings + .logger + .with_top_level_directive(settings.log_level.clone()), + )?; + + setup_panic_hook(reporting_enabled, reporting_id.to_string()); + + memory::madvise::set_global(settings.storage.mmap_advice); + segment::vector_storage::common::set_async_scorer( + settings + .storage + .performance + .async_scorer + .unwrap_or_default(), + ); + + welcome(&settings); + + if let Some(recovery_warning) = &settings.storage.recovery_mode { + log::warn!("Qdrant is loaded in recovery mode: {}", recovery_warning); + log::warn!( + "Read more: https://qdrant.tech/documentation/guides/administration/#recovery-mode" + ); + } + + // Validate as soon as possible, but we must initialize logging first + settings.validate_and_warn(); + + // Saved state of the consensus. 
+ let persistent_consensus_state = Persistent::load_or_init( + &settings.storage.storage_path, + args.bootstrap.is_none(), + args.reinit, + )?; + + let is_distributed_deployment = settings.cluster.enabled; + + let temp_path = settings.storage.temp_path.as_deref(); + + let restored_collections = if let Some(full_snapshot) = args.storage_snapshot { + recover_full_snapshot( + temp_path, + &full_snapshot, + &settings.storage.storage_path, + args.force_snapshot, + persistent_consensus_state.this_peer_id(), + is_distributed_deployment, + ) + } else if let Some(snapshots) = args.snapshot { + // recover from snapshots + recover_snapshots( + &snapshots, + args.force_snapshot, + temp_path, + &settings.storage.storage_path, + persistent_consensus_state.this_peer_id(), + is_distributed_deployment, + ) + } else { + vec![] + }; + + // Create and own search runtime out of the scope of async context to ensure correct + // destruction of it + let search_runtime = create_search_runtime(settings.storage.performance.max_search_threads) + .expect("Can't search create runtime."); + + let update_runtime = + create_update_runtime(settings.storage.performance.max_optimization_threads) + .expect("Can't optimizer create runtime."); + + let general_runtime = + create_general_purpose_runtime().expect("Can't optimizer general purpose runtime."); + let runtime_handle = general_runtime.handle().clone(); + + // Use global CPU budget for optimizations based on settings + let optimizer_cpu_budget = CpuBudget::new(get_cpu_budget( + settings.storage.performance.optimizer_cpu_budget, + )); + + // Create a signal sender and receiver. It is used to communicate with the consensus thread. + let (propose_sender, propose_receiver) = std::sync::mpsc::channel(); + + let propose_operation_sender = if settings.cluster.enabled { + // High-level channel which could be used to send User-space consensus operations + Some(OperationSender::new(propose_sender)) + } else { + // We don't need sender for the single-node mode + None + }; + + // Channel service is used to manage connections between peers. + // It allocates required number of channels and manages proper reconnection handling + let mut channel_service = + ChannelService::new(settings.service.http_port, settings.service.api_key.clone()); + + if is_distributed_deployment { + // We only need channel_service in case if cluster is enabled. + // So we initialize it with real values here + let p2p_grpc_timeout = Duration::from_millis(settings.cluster.grpc_timeout_ms); + let connection_timeout = Duration::from_millis(settings.cluster.connection_timeout_ms); + + let tls_config = load_tls_client_config(&settings)?; + + channel_service.channel_pool = Arc::new(TransportChannelPool::new( + p2p_grpc_timeout, + connection_timeout, + settings.cluster.p2p.connection_pool_size, + tls_config, + )); + channel_service.id_to_address = persistent_consensus_state.peer_address_by_id.clone(); + channel_service.id_to_metadata = persistent_consensus_state.peer_metadata_by_id.clone(); + } + + // Table of content manages the list of collections. + // It is a main entry point for the storage. + let toc = TableOfContent::new( + &settings.storage, + search_runtime, + update_runtime, + general_runtime, + optimizer_cpu_budget, + channel_service.clone(), + persistent_consensus_state.this_peer_id(), + propose_operation_sender.clone(), + ); + + toc.clear_all_tmp_directories()?; + + // Here we load all stored collections. 
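The dedicated search, update and general-purpose runtimes created above follow a common pattern: build named multi-threaded Tokio runtimes once, outside any async context, and keep them owned by `main` so they are dropped last. The helpers in `crate::common::helpers` are not shown in this diff, so the constructor below is only an assumed sketch of what such a helper might look like (`create_named_runtime` is a made-up name).

```rust
use tokio::runtime::{Builder, Runtime};

// Hypothetical helper: a named, multi-threaded runtime with an optional
// worker-thread cap (None lets Tokio pick based on available CPUs).
fn create_named_runtime(name: &'static str, max_threads: Option<usize>) -> std::io::Result<Runtime> {
    let mut builder = Builder::new_multi_thread();
    builder.enable_all().thread_name(name);
    if let Some(threads) = max_threads {
        builder.worker_threads(threads);
    }
    builder.build()
}
```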
+ runtime_handle.block_on(async { + for collection in toc.all_collections(&FULL_ACCESS).await { + log::debug!("Loaded collection: {collection}"); + } + }); + + let toc_arc = Arc::new(toc); + let storage_path = toc_arc.storage_path(); + + // Holder for all actively running threads of the service: web, gPRC, consensus, etc. + let mut handles: Vec>> = vec![]; + + // Router for external queries. + // It decides if query should go directly to the ToC or through the consensus. + let mut dispatcher = Dispatcher::new(toc_arc.clone()); + + let (telemetry_collector, dispatcher_arc, health_checker) = if is_distributed_deployment { + let consensus_state: ConsensusStateRef = ConsensusManager::new( + persistent_consensus_state, + toc_arc.clone(), + propose_operation_sender.unwrap(), + storage_path, + ) + .into(); + let is_new_deployment = consensus_state.is_new_deployment(); + + dispatcher = dispatcher.with_consensus(consensus_state.clone()); + + let toc_dispatcher = TocDispatcher::new(Arc::downgrade(&toc_arc), consensus_state.clone()); + toc_arc.with_toc_dispatcher(toc_dispatcher); + + let dispatcher_arc = Arc::new(dispatcher); + + // Monitoring and telemetry. + let telemetry_collector = + TelemetryCollector::new(settings.clone(), dispatcher_arc.clone(), reporting_id); + let tonic_telemetry_collector = telemetry_collector.tonic_telemetry_collector.clone(); + + // `raft` crate uses `slog` crate so it is needed to use `slog_stdlog::StdLog` to forward + // logs from it to `log` crate + let slog_logger = slog::Logger::root(slog_stdlog::StdLog.fuse(), slog::o!()); + + // Runs raft consensus in a separate thread. + // Create a pipe `message_sender` to communicate with the consensus + let health_checker = Arc::new(common::health::HealthChecker::spawn( + toc_arc.clone(), + consensus_state.clone(), + &runtime_handle, + // NOTE: `wait_for_bootstrap` should be calculated *before* starting `Consensus` thread + consensus_state.is_new_deployment() && args.bootstrap.is_some(), + )); + + let handle = Consensus::run( + &slog_logger, + consensus_state.clone(), + args.bootstrap, + args.uri.map(|uri| uri.to_string()), + settings.clone(), + channel_service, + propose_receiver, + tonic_telemetry_collector, + toc_arc.clone(), + runtime_handle.clone(), + args.reinit, + ) + .expect("Can't initialize consensus"); + + handles.push(handle); + + let toc_arc_clone = toc_arc.clone(); + let consensus_state_clone = consensus_state.clone(); + let _cancel_transfer_handle = runtime_handle.spawn(async move { + consensus_state_clone.is_leader_established.await_ready(); + match toc_arc_clone + .cancel_outgoing_all_transfers("Source peer restarted") + .await + { + Ok(_) => { + log::debug!("All transfers if any cancelled"); + } + Err(err) => { + log::error!("Can't cancel outgoing transfers: {}", err); + } + } + }); + + // TODO(resharding): Remove resharding driver? 
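The `propose_receiver` passed to `Consensus::run` above is the synchronous end of the std-mpsc channel created earlier, so user-space operations sent through the `OperationSender` have to be forwarded into an async mailbox at some point. The test at the bottom of `consensus.rs` does exactly this with a dedicated thread and `blocking_send`; a generic sketch of that bridge (with made-up names, not the actual internals of `Consensus::run`) looks like:

```rust
use std::sync::mpsc as std_mpsc;
use std::thread;

use tokio::sync::mpsc as tokio_mpsc;

// Bridge a synchronous std channel into an async tokio channel: a plain OS
// thread blocks on `recv()` and uses `blocking_send` to push into the async side.
fn forward_proposals<T: Send + 'static>(
    propose_receiver: std_mpsc::Receiver<T>,
    consensus_mailbox: tokio_mpsc::Sender<T>,
) -> thread::JoinHandle<()> {
    thread::spawn(move || {
        while let Ok(operation) = propose_receiver.recv() {
            if consensus_mailbox.blocking_send(operation).is_err() {
                // Consensus stopped; nothing left to forward.
                break;
            }
        }
    })
}
```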
+ // + // runtime_handle.block_on(async { + // toc_arc.resume_resharding_tasks().await; + // }); + + let collections_to_recover_in_consensus = if is_new_deployment { + let existing_collections = + runtime_handle.block_on(toc_arc.all_collections(&FULL_ACCESS)); + existing_collections + .into_iter() + .map(|pass| pass.name().to_string()) + .collect() + } else { + restored_collections + }; + + if !collections_to_recover_in_consensus.is_empty() { + runtime_handle.block_on(handle_existing_collections( + toc_arc.clone(), + consensus_state.clone(), + dispatcher_arc.clone(), + consensus_state.this_peer_id(), + collections_to_recover_in_consensus, + )); + } + + (telemetry_collector, dispatcher_arc, Some(health_checker)) + } else { + log::info!("Distributed mode disabled"); + let dispatcher_arc = Arc::new(dispatcher); + + // Monitoring and telemetry. + let telemetry_collector = + TelemetryCollector::new(settings.clone(), dispatcher_arc.clone(), reporting_id); + (telemetry_collector, dispatcher_arc, None) + }; + + let tonic_telemetry_collector = telemetry_collector.tonic_telemetry_collector.clone(); + + // + // Telemetry reporting + // + + let reporting_id = telemetry_collector.reporting_id(); + let telemetry_collector = Arc::new(tokio::sync::Mutex::new(telemetry_collector)); + + if reporting_enabled { + log::info!("Telemetry reporting enabled, id: {}", reporting_id); + + runtime_handle.spawn(TelemetryReporter::run(telemetry_collector.clone())); + } else { + log::info!("Telemetry reporting disabled"); + } + + // Setup subscribers to listen for issue-able events + issues_setup::setup_subscribers(&settings); + + // Helper to better log start errors + let log_err_if_any = |server_name, result| match result { + Err(err) => { + log::error!("Error while starting {} server: {}", server_name, err); + Err(err) + } + ok => ok, + }; + + // + // Inference Service + // + if let Some(inference_config) = settings.inference.clone() { + match InferenceService::init_global(inference_config) { + Ok(_) => { + log::info!("Inference service is configured."); + } + Err(err) => { + log::error!("{err}"); + } + } + } else { + log::info!("Inference service is not configured."); + } + + // + // REST API server + // + + #[cfg(feature = "web")] + { + let dispatcher_arc = dispatcher_arc.clone(); + let settings = settings.clone(); + let handle = thread::Builder::new() + .name("web".to_string()) + .spawn(move || { + log_err_if_any( + "REST", + actix::init( + dispatcher_arc.clone(), + telemetry_collector, + health_checker, + settings, + logger_handle, + ), + ) + }) + .unwrap(); + handles.push(handle); + } + + // + // gRPC server + // + + if let Some(grpc_port) = settings.service.grpc_port { + let settings = settings.clone(); + let handle = thread::Builder::new() + .name("grpc".to_string()) + .spawn(move || { + log_err_if_any( + "gRPC", + tonic::init( + dispatcher_arc, + tonic_telemetry_collector, + settings, + grpc_port, + runtime_handle, + ), + ) + }) + .unwrap(); + handles.push(handle); + } else { + log::info!("gRPC endpoint disabled"); + } + + #[cfg(feature = "service_debug")] + { + use std::fmt::Write; + + use parking_lot::deadlock; + + const DEADLOCK_CHECK_PERIOD: Duration = Duration::from_secs(10); + + thread::Builder::new() + .name("deadlock_checker".to_string()) + .spawn(move || loop { + thread::sleep(DEADLOCK_CHECK_PERIOD); + let deadlocks = deadlock::check_deadlock(); + if deadlocks.is_empty() { + continue; + } + + let mut error = format!("{} deadlocks detected\n", deadlocks.len()); + for (i, threads) in 
deadlocks.iter().enumerate() { + writeln!(error, "Deadlock #{i}").expect("fail to writeln!"); + for t in threads { + writeln!( + error, + "Thread Id {:#?}\n{:#?}", + t.thread_id(), + t.backtrace() + ) + .expect("fail to writeln!"); + } + } + log::error!("{}", error); + }) + .unwrap(); + } + + touch_started_file_indicator(); + + for handle in handles { + log::debug!( + "Waiting for thread {} to finish", + handle.thread().name().unwrap() + ); + handle.join().expect("thread is not panicking")?; + } + drop(toc_arc); + drop(settings); + Ok(()) +} diff --git a/src/migrations/mod.rs b/src/migrations/mod.rs new file mode 100644 index 0000000000000000000000000000000000000000..2c4052976a5016e1875dbc18fbb7bfdbd26421f3 --- /dev/null +++ b/src/migrations/mod.rs @@ -0,0 +1 @@ +pub mod single_to_cluster; diff --git a/src/migrations/single_to_cluster.rs b/src/migrations/single_to_cluster.rs new file mode 100644 index 0000000000000000000000000000000000000000..7db4bbe19854b816928d15f2eadadcbd67e54481 --- /dev/null +++ b/src/migrations/single_to_cluster.rs @@ -0,0 +1,142 @@ +use std::sync::Arc; + +use collection::config::ShardingMethod; +use collection::shards::replica_set::ReplicaState; +use collection::shards::shard::PeerId; +use storage::content_manager::collection_meta_ops::{ + CollectionMetaOperations, CreateCollection, CreateCollectionOperation, CreateShardKey, + SetShardReplicaState, +}; +use storage::content_manager::consensus_manager::ConsensusStateRef; +use storage::content_manager::shard_distribution::ShardDistributionProposal; +use storage::content_manager::toc::TableOfContent; +use storage::dispatcher::Dispatcher; +use storage::rbac::{Access, AccessRequirements}; + +/// Processes the existing collections, which were created outside the consensus: +/// - during the migration from single to cluster +/// - during restoring from a backup +pub async fn handle_existing_collections( + toc_arc: Arc, + consensus_state: ConsensusStateRef, + dispatcher_arc: Arc, + this_peer_id: PeerId, + collections: Vec, +) { + let full_access = Access::full("Migration from single to cluster"); + let multipass = full_access + .check_global_access(AccessRequirements::new().manage()) + .expect("Full access should have manage rights"); + + consensus_state.is_leader_established.await_ready(); + for collection_name in collections { + let Ok(collection_obj) = toc_arc + .get_collection(&multipass.issue_pass(&collection_name)) + .await + else { + break; + }; + + let collection_state = collection_obj.state().await; + let shards_number = collection_state.config.params.shard_number.get(); + let sharding_method = collection_state.config.params.sharding_method; + + let mut collection_create_operation = CreateCollectionOperation::new( + collection_name.to_string(), + CreateCollection { + vectors: collection_state.config.params.vectors, + sparse_vectors: collection_state.config.params.sparse_vectors, + shard_number: Some(shards_number), + sharding_method, + replication_factor: Some(collection_state.config.params.replication_factor.get()), + write_consistency_factor: Some( + collection_state + .config + .params + .write_consistency_factor + .get(), + ), + on_disk_payload: Some(collection_state.config.params.on_disk_payload), + hnsw_config: Some(collection_state.config.hnsw_config.into()), + wal_config: Some(collection_state.config.wal_config.into()), + optimizers_config: Some(collection_state.config.optimizer_config.into()), + init_from: None, + quantization_config: collection_state.config.quantization_config, + strict_mode_config: 
collection_state.config.strict_mode_config, + uuid: collection_state.config.uuid, + }, + ); + + let mut consensus_operations = Vec::new(); + + match sharding_method.unwrap_or_default() { + ShardingMethod::Auto => { + collection_create_operation.set_distribution(ShardDistributionProposal { + distribution: collection_state + .shards + .iter() + .filter_map(|(shard_id, shard_info)| { + if shard_info.replicas.contains_key(&this_peer_id) { + Some((*shard_id, vec![this_peer_id])) + } else { + None + } + }) + .collect(), + }); + + consensus_operations.push(CollectionMetaOperations::CreateCollection( + collection_create_operation, + )); + } + ShardingMethod::Custom => { + // We should create additional consensus operations here to set the shard distribution + collection_create_operation.set_distribution(ShardDistributionProposal::empty()); + consensus_operations.push(CollectionMetaOperations::CreateCollection( + collection_create_operation, + )); + + for (shard_key, shards) in &collection_state.shards_key_mapping { + let mut placement = Vec::new(); + + for shard_id in shards { + let shard_info = collection_state.shards.get(shard_id).unwrap(); + placement.push(shard_info.replicas.keys().copied().collect()); + } + + consensus_operations.push(CollectionMetaOperations::CreateShardKey( + CreateShardKey { + collection_name: collection_name.to_string(), + shard_key: shard_key.clone(), + placement, + }, + )) + } + } + } + + for operation in consensus_operations { + let _res = dispatcher_arc + .submit_collection_meta_op(operation, full_access.clone(), None) + .await; + } + + for (shard_id, shard_info) in collection_state.shards { + if shard_info.replicas.contains_key(&this_peer_id) { + let _res = dispatcher_arc + .submit_collection_meta_op( + CollectionMetaOperations::SetShardReplicaState(SetShardReplicaState { + collection_name: collection_name.to_string(), + shard_id, + peer_id: this_peer_id, + state: ReplicaState::Active, + from_state: None, + }), + full_access.clone(), + None, + ) + .await; + } + } + } +} diff --git a/src/schema_generator.rs b/src/schema_generator.rs new file mode 100644 index 0000000000000000000000000000000000000000..c6bc31693702e5d7535cef3ed4ce9b3e8569798c --- /dev/null +++ b/src/schema_generator.rs @@ -0,0 +1,111 @@ +use api::rest::models::{CollectionsResponse, HardwareUsage, VersionInfo}; +use api::rest::schema::PointInsertOperations; +use api::rest::{ + FacetRequest, FacetResponse, QueryGroupsRequest, QueryRequest, QueryRequestBatch, + QueryResponse, Record, ScoredPoint, SearchMatrixOffsetsResponse, SearchMatrixPairsResponse, + SearchMatrixRequest, UpdateVectors, +}; +use collection::operations::cluster_ops::ClusterOperations; +use collection::operations::consistency_params::ReadConsistency; +use collection::operations::payload_ops::{DeletePayload, SetPayload}; +use collection::operations::point_ops::{PointsSelector, WriteOrdering}; +use collection::operations::snapshot_ops::{ + ShardSnapshotRecover, SnapshotDescription, SnapshotRecover, +}; +use collection::operations::types::{ + AliasDescription, CollectionClusterInfo, CollectionExistence, CollectionInfo, + CollectionsAliasesResponse, CountRequest, CountResult, DiscoverRequest, DiscoverRequestBatch, + GroupsResult, PointGroup, PointRequest, RecommendGroupsRequest, RecommendRequest, + RecommendRequestBatch, ScrollRequest, ScrollResult, SearchGroupsRequest, SearchRequest, + SearchRequestBatch, UpdateResult, +}; +use collection::operations::vector_ops::DeleteVectors; +use schemars::gen::SchemaSettings; +use schemars::JsonSchema; 
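In `schema_generator.rs` below, the generic parameters appear to have been stripped by whatever produced this diff (`fn save_schema()` is invoked as `save_schema::();`). Assuming the intended shape, the generator boils down to a single generic function over `JsonSchema` types, roughly:

```rust
use schemars::gen::SchemaSettings;
use schemars::JsonSchema;

// Presumed shape of the helper: generate a draft-07 JSON schema for any type
// implementing `JsonSchema` and print it as pretty JSON.
fn save_schema<T: JsonSchema>() {
    let settings = SchemaSettings::draft07();
    let gen = settings.into_generator();
    let schema = gen.into_root_schema_for::<T>();
    let schema_str = serde_json::to_string_pretty(&schema).unwrap();
    println!("{schema_str}");
}

// fn main() { save_schema::<AllDefinitions>(); }
```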
+use serde::Serialize; +use storage::content_manager::collection_meta_ops::{ + ChangeAliasesOperation, CreateCollection, UpdateCollection, +}; +use storage::types::ClusterStatus; + +use crate::common::helpers::LocksOption; +use crate::common::points::{CreateFieldIndex, UpdateOperations}; +use crate::common::telemetry::TelemetryData; + +mod actix; +mod common; +mod settings; +mod tracing; + +#[derive(Serialize, JsonSchema)] +struct AllDefinitions { + a1: CollectionsResponse, + a2: CollectionInfo, + // a3: CollectionMetaOperations, + a4: PointRequest, + a5: Record, + a6: SearchRequest, + a7: ScoredPoint, + a8: UpdateResult, + // a9: CollectionUpdateOperations, + aa: RecommendRequest, + ab: ScrollRequest, + ac: ScrollResult, + ad: CreateCollection, + ae: UpdateCollection, + af: ChangeAliasesOperation, + ag: CreateFieldIndex, + ah: PointsSelector, + ai: PointInsertOperations, + aj: SetPayload, + ak: DeletePayload, + al: ClusterStatus, + am: SnapshotDescription, + an: CountRequest, + ao: CountResult, + ap: CollectionClusterInfo, + aq: TelemetryData, + ar: ClusterOperations, + at: SearchRequestBatch, + au: RecommendRequestBatch, + av: LocksOption, + aw: SnapshotRecover, + ax: CollectionsAliasesResponse, + ay: AliasDescription, + az: WriteOrdering, + b1: ReadConsistency, + b2: UpdateVectors, + b3: DeleteVectors, + b4: PointGroup, + b5: SearchGroupsRequest, + b6: RecommendGroupsRequest, + b7: GroupsResult, + b8: UpdateOperations, + b9: ShardSnapshotRecover, + ba: DiscoverRequest, + bb: DiscoverRequestBatch, + bc: VersionInfo, + bd: CollectionExistence, + be: QueryRequest, + bf: QueryRequestBatch, + bg: QueryResponse, + bh: QueryGroupsRequest, + bi: SearchMatrixRequest, + bj: SearchMatrixOffsetsResponse, + bk: SearchMatrixPairsResponse, + bl: FacetRequest, + bm: FacetResponse, + bn: HardwareUsage, +} + +fn save_schema() { + let settings = SchemaSettings::draft07(); + let gen = settings.into_generator(); + let schema = gen.into_root_schema_for::(); + let schema_str = serde_json::to_string_pretty(&schema).unwrap(); + println!("{schema_str}") +} + +fn main() { + save_schema::(); +} diff --git a/src/settings.rs b/src/settings.rs new file mode 100644 index 0000000000000000000000000000000000000000..9d9b930f0ee2f09ccc1cb33131e19f285e6f24f1 --- /dev/null +++ b/src/settings.rs @@ -0,0 +1,418 @@ +use std::{env, io}; + +use api::grpc::transport_channel_pool::{ + DEFAULT_CONNECT_TIMEOUT, DEFAULT_GRPC_TIMEOUT, DEFAULT_POOL_SIZE, +}; +use collection::operations::validation; +use config::{Config, ConfigError, Environment, File, FileFormat, Source}; +use serde::Deserialize; +use storage::types::StorageConfig; +use validator::Validate; + +use crate::common::debugger::DebuggerConfig; +use crate::common::inference::config::InferenceConfig; +use crate::tracing; + +const DEFAULT_CONFIG: &str = include_str!("../config/config.yaml"); + +#[derive(Debug, Deserialize, Validate, Clone)] +#[allow(dead_code)] // necessary because some field are only used in main.rs +pub struct ServiceConfig { + #[validate(length(min = 1))] + pub host: String, + pub http_port: u16, + pub grpc_port: Option, // None means that gRPC is disabled + pub max_request_size_mb: usize, + pub max_workers: Option, + #[serde(default = "default_cors")] + pub enable_cors: bool, + #[serde(default)] + pub enable_tls: bool, + #[serde(default)] + pub verify_https_client_certificate: bool, + pub api_key: Option, + pub read_only_api_key: Option, + #[serde(default)] + pub jwt_rbac: Option, + + #[serde(default)] + pub hide_jwt_dashboard: Option, + + /// Directory 
where static files are served from. + /// For example, the Web-UI should be placed here. + #[serde(default)] + pub static_content_dir: Option, + + /// If serving of the static content is enabled. + /// This includes the Web-UI. True by default. + #[serde(default)] + pub enable_static_content: Option, + + /// How much time is considered too long for a query to execute. + pub slow_query_secs: Option, + + /// Whether to enable reporting of measured hardware utilization in API responses. + #[serde(default)] + pub hardware_reporting: Option, +} + +impl ServiceConfig { + pub fn hardware_reporting(&self) -> bool { + self.hardware_reporting.unwrap_or_default() + } +} + +#[derive(Debug, Deserialize, Clone, Default, Validate)] +pub struct ClusterConfig { + pub enabled: bool, // disabled by default + #[serde(default = "default_timeout_ms")] + #[validate(range(min = 1))] + pub grpc_timeout_ms: u64, + #[serde(default = "default_connection_timeout_ms")] + #[validate(range(min = 1))] + pub connection_timeout_ms: u64, + #[serde(default)] + #[validate(nested)] + pub p2p: P2pConfig, + #[serde(default)] + #[validate(nested)] + pub consensus: ConsensusConfig, +} + +#[derive(Debug, Deserialize, Clone, Validate)] +#[allow(dead_code)] // necessary because some field are only used in main.rs +pub struct P2pConfig { + #[serde(default)] + pub port: Option, + #[serde(default = "default_connection_pool_size")] + #[validate(range(min = 1))] + pub connection_pool_size: usize, + #[serde(default)] + pub enable_tls: bool, +} + +impl Default for P2pConfig { + fn default() -> Self { + P2pConfig { + port: None, + connection_pool_size: default_connection_pool_size(), + enable_tls: false, + } + } +} + +#[derive(Debug, Deserialize, Clone, Validate)] +pub struct ConsensusConfig { + #[serde(default = "default_max_message_queue_size")] + pub max_message_queue_size: usize, // controls the back-pressure at the Raft level + #[serde(default = "default_tick_period_ms")] + #[validate(range(min = 1))] + pub tick_period_ms: u64, + #[serde(default = "default_bootstrap_timeout_sec")] + #[validate(range(min = 1))] + pub bootstrap_timeout_sec: u64, + #[validate(range(min = 1))] + #[serde(default = "default_message_timeout_tics")] + pub message_timeout_ticks: u64, + #[allow(dead_code)] // `schema_generator` complains about this 🙄 + #[serde(default)] + pub compact_wal_entries: u64, // compact WAL when it grows to enough applied entries +} + +impl Default for ConsensusConfig { + fn default() -> Self { + ConsensusConfig { + max_message_queue_size: default_max_message_queue_size(), + tick_period_ms: default_tick_period_ms(), + bootstrap_timeout_sec: default_bootstrap_timeout_sec(), + message_timeout_ticks: default_message_timeout_tics(), + compact_wal_entries: 0, + } + } +} + +#[derive(Debug, Deserialize, Clone, Validate)] +pub struct TlsConfig { + pub cert: String, + pub key: String, + pub ca_cert: String, + #[serde(default = "default_tls_cert_ttl")] + #[validate(range(min = 1))] + pub cert_ttl: Option, +} + +#[derive(Debug, Deserialize, Clone, Validate)] +#[allow(dead_code)] // necessary because some field are only used in main.rs +pub struct Settings { + #[serde(default)] + pub log_level: Option, + #[serde(default)] + pub logger: tracing::LoggerConfig, + #[validate(nested)] + pub storage: StorageConfig, + #[validate(nested)] + pub service: ServiceConfig, + #[serde(default)] + #[validate(nested)] + pub cluster: ClusterConfig, + #[serde(default = "default_telemetry_disabled")] + pub telemetry_disabled: bool, + #[validate(nested)] + pub tls: 
Option, + #[serde(default)] + pub debugger: DebuggerConfig, + /// A list of messages for errors that happened during loading the configuration. We collect + /// them and store them here while loading because then our logger is not configured yet. + /// We therefore need to log these messages later, after the logger is ready. + #[serde(default, skip)] + pub load_errors: Vec, + #[serde(default)] + pub inference: Option, +} + +impl Settings { + #[allow(dead_code)] + pub fn new(custom_config_path: Option) -> Result { + let mut load_errors = vec![]; + let config_exists = |path| File::with_name(path).collect().is_ok(); + + // Check if custom config file exists, report error if not + if let Some(ref path) = custom_config_path { + if !config_exists(path) { + load_errors.push(LogMsg::Error(format!( + "Config file via --config-path is not found: {path}" + ))); + } + } + + let env = env::var("RUN_MODE").unwrap_or_else(|_| "development".into()); + let config_path_env = format!("config/{env}"); + + // Report error if main or env config files exist, report warning if not + // Check if main and env configuration file + load_errors.extend( + ["config/config", &config_path_env] + .into_iter() + .filter(|path| !config_exists(path)) + .map(|path| LogMsg::Warn(format!("Config file not found: {path}"))), + ); + + // Configuration builder: define different levels of configuration files + let mut config = Config::builder() + // Start with compile-time base config + .add_source(File::from_str(DEFAULT_CONFIG, FileFormat::Yaml)) + // Merge main config: config/config + .add_source(File::with_name("config/config").required(false)) + // Merge env config: config/{env} + // Uses RUN_MODE, defaults to 'development' + .add_source(File::with_name(&config_path_env).required(false)) + // Merge local config, not tracked in git: config/local + .add_source(File::with_name("config/local").required(false)); + + // Merge user provided config with --config-path + if let Some(path) = custom_config_path { + config = config.add_source(File::with_name(&path).required(false)); + } + + // Merge environment settings + // E.g.: `QDRANT_DEBUG=1 ./target/app` would set `debug=true` + config = config.add_source(Environment::with_prefix("QDRANT").separator("__")); + + // Build and merge config and deserialize into Settings, attach any load errors we had + let mut settings: Settings = config.build()?.try_deserialize()?; + settings.load_errors.extend(load_errors); + Ok(settings) + } + + pub fn tls(&self) -> io::Result<&TlsConfig> { + self.tls + .as_ref() + .ok_or_else(Self::tls_config_is_undefined_error) + } + + pub fn tls_config_is_undefined_error() -> io::Error { + io::Error::new( + io::ErrorKind::Other, + "TLS config is not defined in the Qdrant config file", + ) + } + + #[allow(dead_code)] + pub fn validate_and_warn(&self) { + // + // JWT RBAC + // + // Using HMAC-SHA256, recommended secret size is 32 bytes + const JWT_RECOMMENDED_SECRET_LENGTH: usize = 256 / 8; + + // Log if JWT RBAC is enabled but no API key is set + if self.service.jwt_rbac.unwrap_or_default() { + if self.service.api_key.clone().unwrap_or_default().is_empty() { + log::warn!("JWT RBAC configured but no API key set, JWT RBAC is not enabled") + // Log if JWT RAC is enabled, API key is set but smaller than recommended size for JWT secret + } else if self.service.api_key.clone().unwrap_or_default().len() + < JWT_RECOMMENDED_SECRET_LENGTH + { + log::warn!( + "It is highly recommended to use an API key of {} bytes when JWT RBAC is enabled", + JWT_RECOMMENDED_SECRET_LENGTH + ) + 
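                // For reference only (nothing here invokes it): a secret of the recommended
                // length can be produced with any entropy source, e.g. `openssl rand -hex 32`,
                // which yields 32 random bytes encoded as 64 hex characters, comfortably
                // above the 32-byte minimum checked here.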
} + } + + // Print any load error messages we had + self.load_errors.iter().for_each(LogMsg::log); + + if let Err(ref errs) = self.validate() { + validation::warn_validation_errors("Settings configuration file", errs); + } + } +} + +/// Returns the number of maximum actix workers. +pub fn max_web_workers(settings: &Settings) -> usize { + match settings.service.max_workers { + Some(0) => { + let num_cpu = common::cpu::get_num_cpus(); + std::cmp::max(1, num_cpu - 1) + } + Some(max_workers) => max_workers, + None => settings.storage.performance.max_search_threads, + } +} + +#[derive(Clone, Debug)] +pub enum LogMsg { + Warn(String), + Error(String), +} + +impl LogMsg { + fn log(&self) { + match self { + Self::Warn(msg) => log::warn!("{msg}"), + Self::Error(msg) => log::error!("{msg}"), + } + } +} + +const fn default_telemetry_disabled() -> bool { + false +} + +const fn default_cors() -> bool { + true +} + +const fn default_timeout_ms() -> u64 { + DEFAULT_GRPC_TIMEOUT.as_millis() as u64 +} + +const fn default_connection_timeout_ms() -> u64 { + DEFAULT_CONNECT_TIMEOUT.as_millis() as u64 +} + +const fn default_tick_period_ms() -> u64 { + 100 +} + +// Should not be less than `DEFAULT_META_OP_WAIT` as bootstrapping perform sync. consensus meta operations. +const fn default_bootstrap_timeout_sec() -> u64 { + 15 +} + +const fn default_max_message_queue_size() -> usize { + 100 +} + +const fn default_connection_pool_size() -> usize { + DEFAULT_POOL_SIZE +} + +const fn default_message_timeout_tics() -> u64 { + 10 +} + +#[allow(clippy::unnecessary_wraps)] // Used as serde default +const fn default_tls_cert_ttl() -> Option { + // Default one hour + Some(3600) +} + +#[cfg(test)] +mod tests { + use std::fs; + use std::io::Write; + + use sealed_test::prelude::*; + + use super::*; + + /// Ensure we can successfully deserialize into [`Settings`] with just the default configuration. + #[test] + fn test_default_config() { + Config::builder() + .add_source(File::from_str(DEFAULT_CONFIG, FileFormat::Yaml)) + .build() + .expect("failed to build default config") + .try_deserialize::() + .expect("failed to deserialize default config") + .validate() + .expect("failed to validate default config"); + } + + #[sealed_test(files = ["config/config.yaml", "config/development.yaml"])] + fn test_runtime_development_config() { + env::set_var("RUN_MODE", "development"); + + // `sealed_test` copies files into the same directory as the test runs in. + // We need them in a subdirectory. 
+ std::fs::create_dir("config").expect("failed to create `config` subdirectory."); + std::fs::copy("config.yaml", "config/config.yaml").expect("failed to copy `config.yaml`."); + std::fs::copy("development.yaml", "config/development.yaml") + .expect("failed to copy `development.yaml`."); + + // Read config + let config = Settings::new(None).expect("failed to load development config at runtime"); + + // Validate + config + .validate() + .expect("failed to validate development config at runtime"); + assert!(config.load_errors.is_empty(), "must not have load errors") + } + + #[sealed_test] + fn test_no_config_files() { + let non_existing_config_path = "config/non_existing_config".to_string(); + + // Read config + let config = Settings::new(Some(non_existing_config_path)) + .expect("failed to load with non-existing runtime config"); + + // Validate + config + .validate() + .expect("failed to validate with non-existing runtime config"); + assert!(!config.load_errors.is_empty(), "must have load errors") + } + + #[sealed_test] + fn test_custom_config() { + let path = "config/custom.yaml"; + + // Create custom config file + { + fs::create_dir("config").unwrap(); + let mut custom = fs::File::create(path).unwrap(); + write!(&mut custom, "service:\n http_port: 9999").unwrap(); + custom.flush().unwrap(); + } + + // Load settings with custom config + let config = Settings::new(Some(path.into())).unwrap(); + + // Ensure our custom config is the most important + assert_eq!(config.service.http_port, 9999); + } +} diff --git a/src/snapshots.rs b/src/snapshots.rs new file mode 100644 index 0000000000000000000000000000000000000000..0c64c20f1c0d0271ceff25ec22610eb36981d73d --- /dev/null +++ b/src/snapshots.rs @@ -0,0 +1,141 @@ +use std::fs::{self, remove_dir_all, rename}; +use std::path::{Path, PathBuf}; + +use collection::collection::Collection; +use collection::shards::shard::PeerId; +use log::info; +use segment::common::validate_snapshot_archive::open_snapshot_archive_with_validation; +use storage::content_manager::alias_mapping::AliasPersistence; +use storage::content_manager::snapshots::SnapshotConfig; +use storage::content_manager::toc::{ALIASES_PATH, COLLECTIONS_DIR}; + +/// Recover snapshots from the given arguments +/// +/// # Arguments +/// +/// * `mapping` - `[ : ]` +/// * `force` - if true, allow to overwrite collections from snapshots +/// +/// # Returns +/// +/// * `Vec` - list of collections that were recovered +pub fn recover_snapshots( + mapping: &[String], + force: bool, + temp_dir: Option<&str>, + storage_dir: &str, + this_peer_id: PeerId, + is_distributed: bool, +) -> Vec { + let collection_dir_path = Path::new(storage_dir).join(COLLECTIONS_DIR); + let mut recovered_collections: Vec = vec![]; + + for snapshot_params in mapping { + let mut split = snapshot_params.split(':'); + let path = split + .next() + .unwrap_or_else(|| panic!("Snapshot path is missing: {snapshot_params}")); + + let snapshot_path = Path::new(path); + let collection_name = split + .next() + .unwrap_or_else(|| panic!("Collection name is missing: {snapshot_params}")); + recovered_collections.push(collection_name.to_string()); + assert!( + split.next().is_none(), + "Too many parts in snapshot mapping: {snapshot_params}" + ); + info!("Recovering snapshot {} from {}", collection_name, path); + // check if collection already exists + // if it does, we need to check if we want to overwrite it + // if not, we need to abort + let collection_path = collection_dir_path.join(collection_name); + info!("Collection path: {}", 
collection_path.display()); + if collection_path.exists() { + if !force { + panic!( + "Collection {collection_name} already exists. Use --force-snapshot to overwrite it." + ); + } + info!("Overwriting collection {}", collection_name); + } + let collection_temp_path = temp_dir + .map(PathBuf::from) + .unwrap_or_else(|| collection_path.with_extension("tmp")); + if let Err(err) = Collection::restore_snapshot( + snapshot_path, + &collection_temp_path, + this_peer_id, + is_distributed, + ) { + panic!("Failed to recover snapshot {collection_name}: {err}"); + } + // Remove collection_path directory if exists + if collection_path.exists() { + if let Err(err) = remove_dir_all(&collection_path) { + panic!("Failed to remove collection {collection_name}: {err}"); + } + } + rename(&collection_temp_path, &collection_path).unwrap(); + } + recovered_collections +} + +pub fn recover_full_snapshot( + temp_dir: Option<&str>, + snapshot_path: &str, + storage_dir: &str, + force: bool, + this_peer_id: PeerId, + is_distributed: bool, +) -> Vec { + let snapshot_temp_path = temp_dir + .map(PathBuf::from) + .unwrap_or_else(|| Path::new(storage_dir).join("snapshots_recovery_tmp")); + fs::create_dir_all(&snapshot_temp_path).unwrap(); + + // Un-tar snapshot into temporary directory + let mut ar = open_snapshot_archive_with_validation(Path::new(snapshot_path)).unwrap(); + ar.unpack(&snapshot_temp_path).unwrap(); + + // Read configuration file with snapshot-to-collection mapping + let config_path = snapshot_temp_path.join("config.json"); + let config_file = fs::File::open(config_path).unwrap(); + let config_json: SnapshotConfig = serde_json::from_reader(config_file).unwrap(); + + // Create mapping from the configuration file + let mapping: Vec = config_json + .collections_mapping + .iter() + .map(|(collection_name, snapshot_file)| { + format!( + "{}:{collection_name}", + snapshot_temp_path.join(snapshot_file).to_str().unwrap(), + ) + }) + .collect(); + + // Launch regular recovery of snapshots + let recovered_collection = recover_snapshots( + &mapping, + force, + temp_dir, + storage_dir, + this_peer_id, + is_distributed, + ); + + let alias_path = Path::new(storage_dir).join(ALIASES_PATH); + let mut alias_persistence = + AliasPersistence::open(&alias_path).expect("Can't open database by the provided config"); + for (alias, collection_name) in config_json.collections_aliases { + if alias_persistence.get(&alias).is_some() && !force { + panic!("Alias {alias} already exists. Use --force-snapshot to overwrite it."); + } + alias_persistence.insert(alias, collection_name).unwrap(); + } + + // Remove temporary directory + remove_dir_all(&snapshot_temp_path).unwrap(); + recovered_collection +} diff --git a/src/startup.rs b/src/startup.rs new file mode 100644 index 0000000000000000000000000000000000000000..237a8cf0af6d56908a8a444f28f7cb3ed08c8906 --- /dev/null +++ b/src/startup.rs @@ -0,0 +1,59 @@ +//! Contains a collection of functions that are called at the start of the program. 
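//!
//! A minimal usage sketch (the surrounding `run_server` function is hypothetical; only the
//! two indicator helpers defined below are real):
//!
//! ```ignore
//! fn run_server() {
//!     // Clear any stale indicator first, so a crash during startup is not
//!     // mistaken for a successful start.
//!     remove_started_file_indicator();
//!
//!     // ... initialize storage, consensus and the HTTP/gRPC servers ...
//!
//!     // Mark the instance as successfully started only once everything is up.
//!     touch_started_file_indicator();
//! }
//! ```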
+ +use std::backtrace::Backtrace; +use std::panic; +use std::path::PathBuf; + +use crate::common::error_reporting::ErrorReporter; + +const DEFAULT_INITIALIZED_FILE: &str = ".qdrant-initialized"; + +fn get_init_file_path() -> PathBuf { + std::env::var("QDRANT_INIT_FILE_PATH") + .map(PathBuf::from) + .unwrap_or_else(|_| DEFAULT_INITIALIZED_FILE.into()) +} + +pub fn setup_panic_hook(reporting_enabled: bool, reporting_id: String) { + panic::set_hook(Box::new(move |panic_info| { + let backtrace = Backtrace::force_capture().to_string(); + let loc = if let Some(loc) = panic_info.location() { + format!(" in file {} at line {}", loc.file(), loc.line()) + } else { + String::new() + }; + let message = if let Some(s) = panic_info.payload().downcast_ref::<&str>() { + s + } else if let Some(s) = panic_info.payload().downcast_ref::() { + s + } else { + "Payload not captured as it is not a string." + }; + + log::error!("Panic backtrace: \n{}", backtrace); + log::error!("Panic occurred{loc}: {message}"); + + if reporting_enabled { + ErrorReporter::report(message, &reporting_id, Some(&loc)); + } + })); +} + +/// Creates a file that indicates that the server has been started. +/// This file is used to check if the server has been successfully started before potential kill. +pub fn touch_started_file_indicator() { + if let Err(err) = std::fs::write(get_init_file_path(), "") { + log::warn!("Failed to create init file indicator: {}", err); + } +} + +/// Removes a file that indicates that the server has been started. +/// Use before server initialization to avoid false positives. +pub fn remove_started_file_indicator() { + let path = get_init_file_path(); + if path.exists() { + if let Err(err) = std::fs::remove_file(path) { + log::warn!("Failed to remove init file indicator: {}", err); + } + } +} diff --git a/src/tonic/api/collections_api.rs b/src/tonic/api/collections_api.rs new file mode 100644 index 0000000000000000000000000000000000000000..3b1423582ec133990b1e1e26810ef282d458f22d --- /dev/null +++ b/src/tonic/api/collections_api.rs @@ -0,0 +1,337 @@ +use std::sync::Arc; +use std::time::{Duration, Instant}; + +use api::grpc::qdrant::collections_server::Collections; +use api::grpc::qdrant::{ + ChangeAliases, CollectionClusterInfoRequest, CollectionClusterInfoResponse, + CollectionExistsRequest, CollectionExistsResponse, CollectionOperationResponse, + CreateCollection, CreateShardKeyRequest, CreateShardKeyResponse, DeleteCollection, + DeleteShardKeyRequest, DeleteShardKeyResponse, GetCollectionInfoRequest, + GetCollectionInfoResponse, ListAliasesRequest, ListAliasesResponse, + ListCollectionAliasesRequest, ListCollectionsRequest, ListCollectionsResponse, + UpdateCollection, UpdateCollectionClusterSetupRequest, UpdateCollectionClusterSetupResponse, +}; +use collection::operations::cluster_ops::{ + ClusterOperations, CreateShardingKeyOperation, DropShardingKeyOperation, +}; +use collection::operations::types::CollectionsAliasesResponse; +use collection::operations::verification::new_unchecked_verification_pass; +use storage::dispatcher::Dispatcher; +use tonic::{Request, Response, Status}; + +use super::validate; +use crate::common::collections::*; +use crate::tonic::api::collections_common::get; +use crate::tonic::auth::extract_access; + +pub struct CollectionsService { + dispatcher: Arc, +} + +impl CollectionsService { + pub fn new(dispatcher: Arc) -> Self { + Self { dispatcher } + } + + async fn perform_operation( + &self, + mut request: Request, + ) -> Result, Status> + where + O: WithTimeout + + TryInto< + 
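            // Any collection meta request that (a) exposes an optional wait timeout and
            // (b) converts into `CollectionMetaOperations` can be routed through this
            // helper; the `impl_with_timeout!` macro at the bottom of this file supplies
            // (a) for CreateCollection, UpdateCollection, DeleteCollection, ChangeAliases
            // and UpdateCollectionClusterSetupRequest.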
storage::content_manager::collection_meta_ops::CollectionMetaOperations, + Error = Status, + >, + { + let timing = Instant::now(); + let access = extract_access(&mut request); + let operation = request.into_inner(); + let wait_timeout = operation.wait_timeout(); + let result = self + .dispatcher + .submit_collection_meta_op(operation.try_into()?, access, wait_timeout) + .await?; + + let response = CollectionOperationResponse::from((timing, result)); + Ok(Response::new(response)) + } +} + +#[tonic::async_trait] +impl Collections for CollectionsService { + async fn get( + &self, + mut request: Request, + ) -> Result, Status> { + validate(request.get_ref())?; + let access = extract_access(&mut request); + + // Nothing to verify here. + let pass = new_unchecked_verification_pass(); + + get( + self.dispatcher.toc(&access, &pass), + request.into_inner(), + access, + None, + ) + .await + } + + async fn list( + &self, + mut request: Request, + ) -> Result, Status> { + validate(request.get_ref())?; + let timing = Instant::now(); + let access = extract_access(&mut request); + + // Nothing to verify here. + let pass = new_unchecked_verification_pass(); + + let result = do_list_collections(self.dispatcher.toc(&access, &pass), access).await?; + + let response = ListCollectionsResponse::from((timing, result)); + Ok(Response::new(response)) + } + + async fn create( + &self, + request: Request, + ) -> Result, Status> { + validate(request.get_ref())?; + self.perform_operation(request).await + } + + async fn update( + &self, + request: Request, + ) -> Result, Status> { + validate(request.get_ref())?; + self.perform_operation(request).await + } + + async fn delete( + &self, + request: Request, + ) -> Result, Status> { + validate(request.get_ref())?; + self.perform_operation(request).await + } + + async fn update_aliases( + &self, + request: Request, + ) -> Result, Status> { + validate(request.get_ref())?; + self.perform_operation(request).await + } + + async fn list_collection_aliases( + &self, + mut request: Request, + ) -> Result, Status> { + validate(request.get_ref())?; + let timing = Instant::now(); + let access = extract_access(&mut request); + + // Nothing to verify here. + let pass = new_unchecked_verification_pass(); + + let ListCollectionAliasesRequest { collection_name } = request.into_inner(); + let CollectionsAliasesResponse { aliases } = do_list_collection_aliases( + self.dispatcher.toc(&access, &pass), + access, + &collection_name, + ) + .await?; + let response = ListAliasesResponse { + aliases: aliases.into_iter().map(|alias| alias.into()).collect(), + time: timing.elapsed().as_secs_f64(), + }; + Ok(Response::new(response)) + } + + async fn list_aliases( + &self, + mut request: Request, + ) -> Result, Status> { + validate(request.get_ref())?; + let timing = Instant::now(); + let access = extract_access(&mut request); + + // Nothing to verify here. + let pass = new_unchecked_verification_pass(); + + let CollectionsAliasesResponse { aliases } = + do_list_aliases(self.dispatcher.toc(&access, &pass), access).await?; + let response = ListAliasesResponse { + aliases: aliases.into_iter().map(|alias| alias.into()).collect(), + time: timing.elapsed().as_secs_f64(), + }; + Ok(Response::new(response)) + } + + async fn collection_exists( + &self, + mut request: Request, + ) -> Result, Status> { + let timing = Instant::now(); + validate(request.get_ref())?; + let access = extract_access(&mut request); + + // Nothing to verify here. 
+ let pass = new_unchecked_verification_pass(); + + let CollectionExistsRequest { collection_name } = request.into_inner(); + let result = do_collection_exists( + self.dispatcher.toc(&access, &pass), + access, + &collection_name, + ) + .await?; + let response = CollectionExistsResponse { + result: Some(result), + time: timing.elapsed().as_secs_f64(), + }; + + Ok(Response::new(response)) + } + + async fn collection_cluster_info( + &self, + mut request: Request, + ) -> Result, Status> { + validate(request.get_ref())?; + let access = extract_access(&mut request); + + // Nothing to verify here. + let pass = new_unchecked_verification_pass(); + + let response = do_get_collection_cluster( + self.dispatcher.toc(&access, &pass), + access, + request.into_inner().collection_name.as_str(), + ) + .await? + .into(); + + Ok(Response::new(response)) + } + + async fn update_collection_cluster_setup( + &self, + mut request: Request, + ) -> Result, Status> { + validate(request.get_ref())?; + let access = extract_access(&mut request); + let UpdateCollectionClusterSetupRequest { + collection_name, + operation, + timeout, + .. + } = request.into_inner(); + let result = do_update_collection_cluster( + self.dispatcher.as_ref(), + collection_name, + operation + .ok_or_else(|| Status::new(tonic::Code::InvalidArgument, "empty operation"))? + .try_into()?, + access, + timeout.map(std::time::Duration::from_secs), + ) + .await?; + Ok(Response::new(UpdateCollectionClusterSetupResponse { + result, + })) + } + + async fn create_shard_key( + &self, + mut request: Request, + ) -> Result, Status> { + let access = extract_access(&mut request); + + let CreateShardKeyRequest { + collection_name, + request, + timeout, + } = request.into_inner(); + + let Some(request) = request else { + return Err(Status::new(tonic::Code::InvalidArgument, "empty request")); + }; + + let timeout = timeout.map(std::time::Duration::from_secs); + + let operation = ClusterOperations::CreateShardingKey(CreateShardingKeyOperation { + create_sharding_key: request.try_into()?, + }); + + let result = do_update_collection_cluster( + self.dispatcher.as_ref(), + collection_name, + operation, + access, + timeout, + ) + .await?; + + Ok(Response::new(CreateShardKeyResponse { result })) + } + + async fn delete_shard_key( + &self, + mut request: Request, + ) -> Result, Status> { + let access = extract_access(&mut request); + + let DeleteShardKeyRequest { + collection_name, + request, + timeout, + } = request.into_inner(); + + let Some(request) = request else { + return Err(Status::new(tonic::Code::InvalidArgument, "empty request")); + }; + + let timeout = timeout.map(std::time::Duration::from_secs); + + let operation = ClusterOperations::DropShardingKey(DropShardingKeyOperation { + drop_sharding_key: request.try_into()?, + }); + + let result = do_update_collection_cluster( + self.dispatcher.as_ref(), + collection_name, + operation, + access, + timeout, + ) + .await?; + + Ok(Response::new(DeleteShardKeyResponse { result })) + } +} + +trait WithTimeout { + fn wait_timeout(&self) -> Option; +} + +macro_rules! 
impl_with_timeout { + ($operation:ty) => { + impl WithTimeout for $operation { + fn wait_timeout(&self) -> Option { + self.timeout.map(Duration::from_secs) + } + } + }; +} + +impl_with_timeout!(CreateCollection); +impl_with_timeout!(UpdateCollection); +impl_with_timeout!(DeleteCollection); +impl_with_timeout!(ChangeAliases); +impl_with_timeout!(UpdateCollectionClusterSetupRequest); diff --git a/src/tonic/api/collections_common.rs b/src/tonic/api/collections_common.rs new file mode 100644 index 0000000000000000000000000000000000000000..5d8c54b1da944ee6f7927807386f62886f51cdc7 --- /dev/null +++ b/src/tonic/api/collections_common.rs @@ -0,0 +1,26 @@ +use std::time::Instant; + +use api::grpc::qdrant::{GetCollectionInfoRequest, GetCollectionInfoResponse}; +use collection::shards::shard::ShardId; +use storage::content_manager::toc::TableOfContent; +use storage::rbac::Access; +use tonic::{Response, Status}; + +use crate::common::collections::do_get_collection; + +pub async fn get( + toc: &TableOfContent, + get_collection_info_request: GetCollectionInfoRequest, + access: Access, + shard_selection: Option, +) -> Result, Status> { + let timing = Instant::now(); + let collection_name = get_collection_info_request.collection_name; + let result = do_get_collection(toc, access, &collection_name, shard_selection).await?; + let response = GetCollectionInfoResponse { + result: Some(result.into()), + time: timing.elapsed().as_secs_f64(), + }; + + Ok(Response::new(response)) +} diff --git a/src/tonic/api/collections_internal_api.rs b/src/tonic/api/collections_internal_api.rs new file mode 100644 index 0000000000000000000000000000000000000000..2d834660d5930888a907c7d443a12e5834bcd224 --- /dev/null +++ b/src/tonic/api/collections_internal_api.rs @@ -0,0 +1,208 @@ +use std::sync::Arc; +use std::time::{Duration, Instant}; + +use api::grpc::qdrant::collections_internal_server::CollectionsInternal; +use api::grpc::qdrant::{ + CollectionOperationResponse, GetCollectionInfoRequestInternal, GetCollectionInfoResponse, + GetShardRecoveryPointRequest, GetShardRecoveryPointResponse, InitiateShardTransferRequest, + UpdateShardCutoffPointRequest, WaitForShardStateRequest, +}; +use storage::content_manager::toc::TableOfContent; +use storage::rbac::{Access, AccessRequirements, CollectionPass}; +use tonic::{Request, Response, Status}; + +use super::validate_and_log; +use crate::tonic::api::collections_common::get; + +const FULL_ACCESS: Access = Access::full("Internal API"); + +fn full_access_pass(collection_name: &str) -> Result, Status> { + FULL_ACCESS + .check_collection_access(collection_name, AccessRequirements::new()) + .map_err(Status::from) +} + +pub struct CollectionsInternalService { + toc: Arc, +} + +impl CollectionsInternalService { + pub fn new(toc: Arc) -> Self { + Self { toc } + } +} + +#[tonic::async_trait] +impl CollectionsInternal for CollectionsInternalService { + async fn get( + &self, + request: Request, + ) -> Result, Status> { + validate_and_log(request.get_ref()); + let GetCollectionInfoRequestInternal { + get_collection_info_request, + shard_id, + } = request.into_inner(); + + let get_collection_info_request = get_collection_info_request + .ok_or_else(|| Status::invalid_argument("GetCollectionInfoRequest is missing"))?; + + get( + self.toc.as_ref(), + get_collection_info_request, + FULL_ACCESS.clone(), + Some(shard_id), + ) + .await + } + + async fn initiate( + &self, + request: Request, + ) -> Result, Status> { + // TODO: Ensure cancel safety! 
+ + validate_and_log(request.get_ref()); + let timing = Instant::now(); + let InitiateShardTransferRequest { + collection_name, + shard_id, + } = request.into_inner(); + + // TODO: Ensure cancel safety! + self.toc + .initiate_receiving_shard(collection_name, shard_id) + .await?; + + let response = CollectionOperationResponse { + result: true, + time: timing.elapsed().as_secs_f64(), + }; + Ok(Response::new(response)) + } + + async fn wait_for_shard_state( + &self, + request: Request, + ) -> Result, Status> { + let request = request.into_inner(); + validate_and_log(&request); + + let timing = Instant::now(); + let WaitForShardStateRequest { + collection_name, + shard_id, + state, + timeout, + } = request; + let state = state.try_into()?; + let timeout = Duration::from_secs(timeout); + + let collection_read = self + .toc + .get_collection(&full_access_pass(&collection_name)?) + .await + .map_err(|err| { + Status::not_found(format!( + "Collection {collection_name} could not be found: {err}" + )) + })?; + + // Wait for replica state + collection_read + .wait_local_shard_replica_state(shard_id, state, timeout) + .await + .map_err(|err| { + Status::aborted(format!( + "Failed to wait for shard {shard_id} to get into {state:?} state: {err}" + )) + })?; + + let response = CollectionOperationResponse { + result: true, + time: timing.elapsed().as_secs_f64(), + }; + Ok(Response::new(response)) + } + + async fn get_shard_recovery_point( + &self, + request: Request, + ) -> Result, Status> { + validate_and_log(request.get_ref()); + + let timing = Instant::now(); + let GetShardRecoveryPointRequest { + collection_name, + shard_id, + } = request.into_inner(); + + let collection_read = self + .toc + .get_collection(&full_access_pass(&collection_name)?) + .await + .map_err(|err| { + Status::not_found(format!( + "Collection {collection_name} could not be found: {err}" + )) + })?; + + // Get shard recovery point + let recovery_point = collection_read + .shard_recovery_point(shard_id) + .await + .map_err(|err| { + Status::internal(format!( + "Failed to get recovery point for shard {shard_id}: {err}" + )) + })?; + + let response = GetShardRecoveryPointResponse { + recovery_point: Some(recovery_point.into()), + time: timing.elapsed().as_secs_f64(), + }; + Ok(Response::new(response)) + } + + async fn update_shard_cutoff_point( + &self, + request: Request, + ) -> Result, Status> { + validate_and_log(request.get_ref()); + + let timing = Instant::now(); + let UpdateShardCutoffPointRequest { + collection_name, + shard_id, + cutoff, + } = request.into_inner(); + + let cutoff = cutoff.ok_or_else(|| Status::invalid_argument("Missing cutoff point"))?; + + let collection_read = self + .toc + .get_collection(&full_access_pass(&collection_name)?) + .await + .map_err(|err| { + Status::not_found(format!( + "Collection {collection_name} could not be found: {err}" + )) + })?; + + // Set the shard cutoff point + collection_read + .update_shard_cutoff_point(shard_id, &cutoff.try_into()?) 
+ .await + .map_err(|err| { + Status::internal(format!( + "Failed to set shard cutoff point for shard {shard_id}: {err}" + )) + })?; + + let response = CollectionOperationResponse { + result: true, + time: timing.elapsed().as_secs_f64(), + }; + Ok(Response::new(response)) + } +} diff --git a/src/tonic/api/mod.rs b/src/tonic/api/mod.rs new file mode 100644 index 0000000000000000000000000000000000000000..6c102e727e07f6805a46bcfbc2f8426848a9d2b0 --- /dev/null +++ b/src/tonic/api/mod.rs @@ -0,0 +1,64 @@ +pub mod collections_api; +mod collections_common; +pub mod collections_internal_api; +pub mod points_api; +mod points_common; +pub mod points_internal_api; +pub mod raft_api; +pub mod snapshots_api; + +use collection::operations::validation; +use tonic::Status; +use validator::Validate; + +/// Validate the given request and fail on error. +/// +/// Returns validation error on failure. +fn validate(request: &impl Validate) -> Result<(), Status> { + request.validate().map_err(|ref err| { + Status::invalid_argument(validation::label_errors("Validation error in body", err)) + }) +} + +/// Validate the given request. Returns validation error on failure. +fn validate_and_log(request: &impl Validate) { + if let Err(ref err) = request.validate() { + validation::warn_validation_errors("Internal gRPC", err); + } +} + +#[cfg(test)] +mod tests { + use validator::Validate; + + use super::*; + + #[derive(Validate, Debug)] + struct SomeThing { + #[validate(range(min = 1))] + pub idx: usize, + } + + #[derive(Validate, Debug)] + struct OtherThing { + #[validate(nested)] + pub things: Vec, + } + + #[test] + fn test_validation() { + use tonic::Code; + + let bad_config = OtherThing { + things: vec![SomeThing { idx: 0 }], + }; + + let validation = + validate(&bad_config).expect_err("validation of bad request payload should fail"); + assert_eq!(validation.code(), Code::InvalidArgument); + assert_eq!( + validation.message(), + "Validation error in body: [things[0].idx: value 0 invalid, must be 1 or larger]" + ) + } +} diff --git a/src/tonic/api/points_api.rs b/src/tonic/api/points_api.rs new file mode 100644 index 0000000000000000000000000000000000000000..f4e4f1b0c44a7e198b94a457bbe6c3b93f4d72ff --- /dev/null +++ b/src/tonic/api/points_api.rs @@ -0,0 +1,656 @@ +use std::sync::Arc; +use std::time::{Duration, Instant}; + +use api::grpc::qdrant::points_server::Points; +use api::grpc::qdrant::{ + ClearPayloadPoints, CountPoints, CountResponse, CreateFieldIndexCollection, + DeleteFieldIndexCollection, DeletePayloadPoints, DeletePointVectors, DeletePoints, + DiscoverBatchPoints, DiscoverBatchResponse, DiscoverPoints, DiscoverResponse, FacetCounts, + FacetResponse, GetPoints, GetResponse, PointsOperationResponse, QueryBatchPoints, + QueryBatchResponse, QueryGroupsResponse, QueryPointGroups, QueryPoints, QueryResponse, + RecommendBatchPoints, RecommendBatchResponse, RecommendGroupsResponse, RecommendPointGroups, + RecommendPoints, RecommendResponse, ScrollPoints, ScrollResponse, SearchBatchPoints, + SearchBatchResponse, SearchGroupsResponse, SearchMatrixOffsets, SearchMatrixOffsetsResponse, + SearchMatrixPairs, SearchMatrixPairsResponse, SearchMatrixPoints, SearchPointGroups, + SearchPoints, SearchResponse, SetPayloadPoints, UpdateBatchPoints, UpdateBatchResponse, + UpdatePointVectors, UpsertPoints, +}; +use collection::operations::types::CoreSearchRequest; +use collection::operations::verification::new_unchecked_verification_pass; +use common::counter::hardware_accumulator::HwMeasurementAcc; +use 
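// Note: handlers in this service call the shared `validate` helper (re-exported from
// `tonic/api/mod.rs` via `super`) before doing any work; it rejects a malformed request
// with `InvalidArgument`, whereas the internal services use `validate_and_log`, which
// only logs a warning and lets the request proceed.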
storage::content_manager::toc::request_hw_counter::RequestHwCounter; +use storage::dispatcher::Dispatcher; +use tonic::{Request, Response, Status}; + +use super::points_common::{ + delete_vectors, discover, discover_batch, facet, query, query_batch, query_groups, + recommend_groups, scroll, search_groups, search_points_matrix, update_batch, update_vectors, +}; +use super::validate; +use crate::settings::ServiceConfig; +use crate::tonic::api::points_common::{ + clear_payload, convert_shard_selector_for_read, core_search_batch, count, create_field_index, + delete, delete_field_index, delete_payload, get, overwrite_payload, recommend, recommend_batch, + search, set_payload, upsert, +}; +use crate::tonic::auth::extract_access; +use crate::tonic::verification::StrictModeCheckedTocProvider; + +pub struct PointsService { + dispatcher: Arc, + service_config: ServiceConfig, +} + +impl PointsService { + pub fn new(dispatcher: Arc, service_config: ServiceConfig) -> Self { + Self { + dispatcher, + service_config, + } + } + + fn get_request_collection_hw_usage_counter(&self, collection_name: String) -> RequestHwCounter { + let counter = HwMeasurementAcc::new_with_drain( + &self.dispatcher.get_collection_hw_metrics(collection_name), + ); + + RequestHwCounter::new(counter, self.service_config.hardware_reporting(), false) + } +} + +#[tonic::async_trait] +impl Points for PointsService { + async fn upsert( + &self, + mut request: Request, + ) -> Result, Status> { + validate(request.get_ref())?; + + // Nothing to verify here. + let pass = new_unchecked_verification_pass(); + + let access = extract_access(&mut request); + + upsert( + self.dispatcher.toc(&access, &pass).clone(), + request.into_inner(), + None, + None, + access, + ) + .await + .map(|resp| resp.map(Into::into)) + } + + async fn delete( + &self, + mut request: Request, + ) -> Result, Status> { + validate(request.get_ref())?; + + let access = extract_access(&mut request); + + delete( + StrictModeCheckedTocProvider::new(&self.dispatcher), + request.into_inner(), + None, + None, + access, + ) + .await + .map(|resp| resp.map(Into::into)) + } + + async fn get(&self, mut request: Request) -> Result, Status> { + validate(request.get_ref())?; + + let access = extract_access(&mut request); + + get( + StrictModeCheckedTocProvider::new(&self.dispatcher), + request.into_inner(), + None, + access, + ) + .await + } + + async fn update_vectors( + &self, + mut request: Request, + ) -> Result, Status> { + validate(request.get_ref())?; + + // Nothing to verify here. 
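        // (Throughout this service, `upsert` and `update_vectors` take the table of
        // contents directly with an unchecked verification pass, while most other
        // handlers go through `StrictModeCheckedTocProvider`, which runs the
        // strict-mode checks for the target collection before handing out the ToC.)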
+ let pass = new_unchecked_verification_pass(); + + let access = extract_access(&mut request); + + update_vectors( + self.dispatcher.toc(&access, &pass).clone(), + request.into_inner(), + None, + None, + access, + ) + .await + .map(|resp| resp.map(Into::into)) + } + + async fn delete_vectors( + &self, + mut request: Request, + ) -> Result, Status> { + validate(request.get_ref())?; + + let access = extract_access(&mut request); + + delete_vectors( + StrictModeCheckedTocProvider::new(&self.dispatcher), + request.into_inner(), + None, + None, + access, + ) + .await + .map(|resp| resp.map(Into::into)) + } + + async fn set_payload( + &self, + mut request: Request, + ) -> Result, Status> { + validate(request.get_ref())?; + + let access = extract_access(&mut request); + + set_payload( + StrictModeCheckedTocProvider::new(&self.dispatcher), + request.into_inner(), + None, + None, + access, + ) + .await + .map(|resp| resp.map(Into::into)) + } + + async fn overwrite_payload( + &self, + mut request: Request, + ) -> Result, Status> { + validate(request.get_ref())?; + + let access = extract_access(&mut request); + + overwrite_payload( + StrictModeCheckedTocProvider::new(&self.dispatcher), + request.into_inner(), + None, + None, + access, + ) + .await + .map(|resp| resp.map(Into::into)) + } + + async fn delete_payload( + &self, + mut request: Request, + ) -> Result, Status> { + validate(request.get_ref())?; + + let access = extract_access(&mut request); + + delete_payload( + StrictModeCheckedTocProvider::new(&self.dispatcher), + request.into_inner(), + None, + None, + access, + ) + .await + .map(|resp| resp.map(Into::into)) + } + + async fn clear_payload( + &self, + mut request: Request, + ) -> Result, Status> { + validate(request.get_ref())?; + + let access = extract_access(&mut request); + + clear_payload( + StrictModeCheckedTocProvider::new(&self.dispatcher), + request.into_inner(), + None, + None, + access, + ) + .await + .map(|resp| resp.map(Into::into)) + } + + async fn update_batch( + &self, + mut request: Request, + ) -> Result, Status> { + validate(request.get_ref())?; + + let access = extract_access(&mut request); + + update_batch(&self.dispatcher, request.into_inner(), None, None, access).await + } + + async fn create_field_index( + &self, + mut request: Request, + ) -> Result, Status> { + validate(request.get_ref())?; + + let access = extract_access(&mut request); + + create_field_index( + self.dispatcher.clone(), + request.into_inner(), + None, + None, + access, + ) + .await + .map(|resp| resp.map(Into::into)) + } + + async fn delete_field_index( + &self, + mut request: Request, + ) -> Result, Status> { + validate(request.get_ref())?; + + let access = extract_access(&mut request); + + delete_field_index( + self.dispatcher.clone(), + request.into_inner(), + None, + None, + access, + ) + .await + .map(|resp| resp.map(Into::into)) + } + + async fn search( + &self, + mut request: Request, + ) -> Result, Status> { + validate(request.get_ref())?; + let access = extract_access(&mut request); + let collection_name = request.get_ref().collection_name.clone(); + let hw_metrics = self.get_request_collection_hw_usage_counter(collection_name); + + let res = search( + StrictModeCheckedTocProvider::new(&self.dispatcher), + request.into_inner(), + None, + access, + hw_metrics, + ) + .await?; + + Ok(res) + } + + async fn search_batch( + &self, + mut request: Request, + ) -> Result, Status> { + validate(request.get_ref())?; + + let access = extract_access(&mut request); + + let SearchBatchPoints { + 
collection_name, + search_points, + read_consistency, + timeout, + } = request.into_inner(); + + let timeout = timeout.map(Duration::from_secs); + + let mut requests = Vec::new(); + + for mut search_point in search_points { + let shard_key = search_point.shard_key_selector.take(); + + let shard_selector = convert_shard_selector_for_read(None, shard_key); + let core_search_request = CoreSearchRequest::try_from(search_point)?; + + requests.push((core_search_request, shard_selector)); + } + + let hw_metrics = self.get_request_collection_hw_usage_counter(collection_name.clone()); + + let res = core_search_batch( + StrictModeCheckedTocProvider::new(&self.dispatcher), + &collection_name, + requests, + read_consistency, + access, + timeout, + hw_metrics, + ) + .await?; + + Ok(res) + } + + async fn search_groups( + &self, + mut request: Request, + ) -> Result, Status> { + validate(request.get_ref())?; + let access = extract_access(&mut request); + let collection_name = request.get_ref().collection_name.clone(); + let hw_metrics = self.get_request_collection_hw_usage_counter(collection_name); + let res = search_groups( + StrictModeCheckedTocProvider::new(&self.dispatcher), + request.into_inner(), + None, + access, + hw_metrics, + ) + .await?; + + Ok(res) + } + + async fn scroll( + &self, + mut request: Request, + ) -> Result, Status> { + validate(request.get_ref())?; + + let access = extract_access(&mut request); + scroll( + StrictModeCheckedTocProvider::new(&self.dispatcher), + request.into_inner(), + None, + access, + ) + .await + } + + async fn recommend( + &self, + mut request: Request, + ) -> Result, Status> { + validate(request.get_ref())?; + let access = extract_access(&mut request); + let collection_name = request.get_ref().collection_name.clone(); + let hw_metrics = self.get_request_collection_hw_usage_counter(collection_name); + let res = recommend( + StrictModeCheckedTocProvider::new(&self.dispatcher), + request.into_inner(), + access, + hw_metrics, + ) + .await?; + + Ok(res) + } + + async fn recommend_batch( + &self, + mut request: Request, + ) -> Result, Status> { + validate(request.get_ref())?; + let access = extract_access(&mut request); + let RecommendBatchPoints { + collection_name, + recommend_points, + read_consistency, + timeout, + } = request.into_inner(); + + let hw_metrics = self.get_request_collection_hw_usage_counter(collection_name.clone()); + + let res = recommend_batch( + StrictModeCheckedTocProvider::new(&self.dispatcher), + &collection_name, + recommend_points, + read_consistency, + access, + timeout.map(Duration::from_secs), + hw_metrics, + ) + .await?; + + Ok(res) + } + + async fn recommend_groups( + &self, + mut request: Request, + ) -> Result, Status> { + validate(request.get_ref())?; + let access = extract_access(&mut request); + let collection_name = request.get_ref().collection_name.clone(); + let hw_metrics = self.get_request_collection_hw_usage_counter(collection_name); + + let res = recommend_groups( + StrictModeCheckedTocProvider::new(&self.dispatcher), + request.into_inner(), + access, + hw_metrics, + ) + .await?; + + Ok(res) + } + + async fn discover( + &self, + mut request: Request, + ) -> Result, Status> { + validate(request.get_ref())?; + let access = extract_access(&mut request); + let collection_name = request.get_ref().collection_name.clone(); + + let hw_metrics = self.get_request_collection_hw_usage_counter(collection_name); + let res = discover( + StrictModeCheckedTocProvider::new(&self.dispatcher), + request.into_inner(), + access, + hw_metrics, + 
) + .await?; + + Ok(res) + } + + async fn discover_batch( + &self, + mut request: Request, + ) -> Result, Status> { + validate(request.get_ref())?; + let access = extract_access(&mut request); + let DiscoverBatchPoints { + collection_name, + discover_points, + read_consistency, + timeout, + } = request.into_inner(); + + let hw_metrics = self.get_request_collection_hw_usage_counter(collection_name.clone()); + let res = discover_batch( + StrictModeCheckedTocProvider::new(&self.dispatcher), + &collection_name, + discover_points, + read_consistency, + access, + timeout.map(Duration::from_secs), + hw_metrics, + ) + .await?; + + Ok(res) + } + + async fn count( + &self, + mut request: Request, + ) -> Result, Status> { + validate(request.get_ref())?; + + let access = extract_access(&mut request); + let collection_name = request.get_ref().collection_name.clone(); + let hw_metrics = self.get_request_collection_hw_usage_counter(collection_name); + let res = count( + StrictModeCheckedTocProvider::new(&self.dispatcher), + request.into_inner(), + None, + &access, + hw_metrics, + ) + .await?; + + Ok(res) + } + + async fn query( + &self, + mut request: Request, + ) -> Result, Status> { + validate(request.get_ref())?; + let access = extract_access(&mut request); + let collection_name = request.get_ref().collection_name.clone(); + let hw_metrics = self.get_request_collection_hw_usage_counter(collection_name); + + let res = query( + StrictModeCheckedTocProvider::new(&self.dispatcher), + request.into_inner(), + None, + access, + hw_metrics, + ) + .await?; + + Ok(res) + } + + async fn query_batch( + &self, + mut request: Request, + ) -> Result, Status> { + validate(request.get_ref())?; + let access = extract_access(&mut request); + let request = request.into_inner(); + let QueryBatchPoints { + collection_name, + query_points, + read_consistency, + timeout, + } = request; + let timeout = timeout.map(Duration::from_secs); + let hw_metrics = self.get_request_collection_hw_usage_counter(collection_name.clone()); + let res = query_batch( + StrictModeCheckedTocProvider::new(&self.dispatcher), + &collection_name, + query_points, + read_consistency, + access, + timeout, + hw_metrics, + ) + .await?; + + Ok(res) + } + + async fn query_groups( + &self, + mut request: Request, + ) -> Result, Status> { + let access = extract_access(&mut request); + let collection_name = request.get_ref().collection_name.clone(); + let hw_metrics = self.get_request_collection_hw_usage_counter(collection_name); + + let res = query_groups( + StrictModeCheckedTocProvider::new(&self.dispatcher), + request.into_inner(), + None, + access, + hw_metrics, + ) + .await?; + + Ok(res) + } + async fn facet( + &self, + mut request: Request, + ) -> Result, Status> { + validate(request.get_ref())?; + let access = extract_access(&mut request); + facet( + StrictModeCheckedTocProvider::new(&self.dispatcher), + request.into_inner(), + access, + ) + .await + } + + async fn search_matrix_pairs( + &self, + mut request: Request, + ) -> Result, Status> { + validate(request.get_ref())?; + let access = extract_access(&mut request); + let timing = Instant::now(); + let collection_name = request.get_ref().collection_name.clone(); + let hw_metrics = self.get_request_collection_hw_usage_counter(collection_name); + let search_matrix_response = search_points_matrix( + StrictModeCheckedTocProvider::new(&self.dispatcher), + request.into_inner(), + access, + hw_metrics.get_counter(), + ) + .await?; + + let pairs_response = SearchMatrixPairsResponse { + result: 
Some(SearchMatrixPairs::from(search_matrix_response)), + time: timing.elapsed().as_secs_f64(), + usage: hw_metrics.to_grpc_api(), + }; + + Ok(Response::new(pairs_response)) + } + + async fn search_matrix_offsets( + &self, + mut request: Request, + ) -> Result, Status> { + validate(request.get_ref())?; + let access = extract_access(&mut request); + let timing = Instant::now(); + let collection_name = request.get_ref().collection_name.clone(); + let hw_metrics = self.get_request_collection_hw_usage_counter(collection_name); + let search_matrix_response = search_points_matrix( + StrictModeCheckedTocProvider::new(&self.dispatcher), + request.into_inner(), + access, + hw_metrics.get_counter(), + ) + .await?; + + let offsets_response = SearchMatrixOffsetsResponse { + result: Some(SearchMatrixOffsets::from(search_matrix_response)), + time: timing.elapsed().as_secs_f64(), + usage: hw_metrics.to_grpc_api(), + }; + + Ok(Response::new(offsets_response)) + } +} diff --git a/src/tonic/api/points_common.rs b/src/tonic/api/points_common.rs new file mode 100644 index 0000000000000000000000000000000000000000..fc9b321542aa1ff04c23c51d44fd6bbdde5a8356 --- /dev/null +++ b/src/tonic/api/points_common.rs @@ -0,0 +1,2044 @@ +use std::sync::Arc; +use std::time::{Duration, Instant}; + +use api::conversions::json::{json_path_from_proto, proto_to_payloads}; +use api::grpc::qdrant::payload_index_params::IndexParams; +use api::grpc::qdrant::points_update_operation::{ClearPayload, Operation, PointStructList}; +use api::grpc::qdrant::{ + points_update_operation, BatchResult, ClearPayloadPoints, CoreSearchPoints, CountPoints, + CountResponse, CreateFieldIndexCollection, DeleteFieldIndexCollection, DeletePayloadPoints, + DeletePointVectors, DeletePoints, DiscoverBatchResponse, DiscoverPoints, DiscoverResponse, + FacetCounts, FacetResponse, FieldType, GetPoints, GetResponse, GroupsResult, + PayloadIndexParams, PointsOperationResponseInternal, PointsSelector, QueryBatchResponse, + QueryGroupsResponse, QueryPointGroups, QueryPoints, QueryResponse, + ReadConsistency as ReadConsistencyGrpc, RecommendBatchResponse, RecommendGroupsResponse, + RecommendPointGroups, RecommendPoints, RecommendResponse, ScrollPoints, ScrollResponse, + SearchBatchResponse, SearchGroupsResponse, SearchMatrixPoints, SearchPointGroups, SearchPoints, + SearchResponse, SetPayloadPoints, SyncPoints, UpdateBatchPoints, UpdateBatchResponse, + UpdatePointVectors, UpsertPoints, +}; +use api::rest::schema::{PointInsertOperations, PointsList}; +use api::rest::{ + OrderByInterface, PointStruct, PointVectors, ShardKeySelector, UpdateVectors, VectorStruct, +}; +use collection::collection::distance_matrix::{ + CollectionSearchMatrixRequest, CollectionSearchMatrixResponse, +}; +use collection::operations::consistency_params::ReadConsistency; +use collection::operations::conversions::{ + try_discover_request_from_grpc, try_points_selector_from_grpc, write_ordering_from_proto, +}; +use collection::operations::payload_ops::DeletePayload; +use collection::operations::point_ops::{self, PointOperations, PointSyncOperation}; +use collection::operations::query_enum::QueryEnum; +use collection::operations::shard_selector_internal::ShardSelectorInternal; +use collection::operations::types::{ + default_exact_count, CoreSearchRequest, CoreSearchRequestBatch, PointRequestInternal, + RecommendExample, ScrollRequestInternal, +}; +use collection::operations::vector_ops::DeleteVectors; +use collection::operations::verification::new_unchecked_verification_pass; +use 
collection::operations::{ClockTag, CollectionUpdateOperations, OperationWithClockTag}; +use collection::shards::shard::ShardId; +use common::counter::hardware_accumulator::HwMeasurementAcc; +use itertools::Itertools; +use segment::data_types::facets::FacetParams; +use segment::data_types::order_by::OrderBy; +use segment::data_types::vectors::DEFAULT_VECTOR_NAME; +use segment::types::{ + ExtendedPointId, Filter, PayloadFieldSchema, PayloadSchemaParams, PayloadSchemaType, +}; +use storage::content_manager::toc::request_hw_counter::RequestHwCounter; +use storage::content_manager::toc::TableOfContent; +use storage::dispatcher::Dispatcher; +use storage::rbac::Access; +use tonic::{Response, Status}; + +use crate::common::inference::query_requests_grpc::{ + convert_query_point_groups_from_grpc, convert_query_points_from_grpc, +}; +use crate::common::inference::service::InferenceType; +use crate::common::inference::update_requests::convert_point_struct; +use crate::common::points::{ + do_clear_payload, do_core_search_points, do_count_points, do_create_index, + do_create_index_internal, do_delete_index, do_delete_index_internal, do_delete_payload, + do_delete_points, do_delete_vectors, do_get_points, do_overwrite_payload, + do_query_batch_points, do_query_point_groups, do_query_points, do_scroll_points, + do_search_batch_points, do_set_payload, do_update_vectors, do_upsert_points, CreateFieldIndex, +}; +use crate::tonic::verification::{CheckedTocProvider, StrictModeCheckedTocProvider}; + +fn extract_points_selector( + points_selector: Option, +) -> Result<(Option>, Option), Status> { + let (points, filter) = if let Some(points_selector) = points_selector { + let points_selector = try_points_selector_from_grpc(points_selector, None)?; + match points_selector { + point_ops::PointsSelector::PointIdsSelector(points) => (Some(points.points), None), + point_ops::PointsSelector::FilterSelector(filter) => (None, Some(filter.filter)), + } + } else { + return Err(Status::invalid_argument("points_selector is expected")); + }; + Ok((points, filter)) +} + +pub fn points_operation_response_internal( + timing: Instant, + update_result: collection::operations::types::UpdateResult, +) -> PointsOperationResponseInternal { + PointsOperationResponseInternal { + result: Some(update_result.into()), + time: timing.elapsed().as_secs_f64(), + } +} + +pub(crate) fn convert_shard_selector_for_read( + shard_id_selector: Option, + shard_key_selector: Option, +) -> ShardSelectorInternal { + match (shard_id_selector, shard_key_selector) { + (Some(shard_id), None) => ShardSelectorInternal::ShardId(shard_id), + (None, Some(shard_key_selector)) => ShardSelectorInternal::from(shard_key_selector), + (None, None) => ShardSelectorInternal::All, + (Some(shard_id), Some(_)) => { + debug_assert!( + false, + "Shard selection and shard key selector are mutually exclusive" + ); + ShardSelectorInternal::ShardId(shard_id) + } + } +} + +pub async fn upsert( + toc: Arc, + upsert_points: UpsertPoints, + clock_tag: Option, + shard_selection: Option, + access: Access, +) -> Result, Status> { + let UpsertPoints { + collection_name, + wait, + points, + ordering, + shard_key_selector, + } = upsert_points; + + let points: Result<_, _> = points.into_iter().map(PointStruct::try_from).collect(); + + let operation = PointInsertOperations::PointsList(PointsList { + points: points?, + shard_key: shard_key_selector.map(ShardKeySelector::from), + }); + let timing = Instant::now(); + let result = do_upsert_points( + toc, + collection_name, + operation, + 
clock_tag, + shard_selection, + wait.unwrap_or(false), + write_ordering_from_proto(ordering)?, + access, + ) + .await?; + + let response = points_operation_response_internal(timing, result); + Ok(Response::new(response)) +} + +pub async fn sync( + toc: Arc, + sync_points: SyncPoints, + clock_tag: Option, + shard_selection: Option, + access: Access, +) -> Result, Status> { + let SyncPoints { + collection_name, + wait, + points, + from_id, + to_id, + ordering, + } = sync_points; + + let timing = Instant::now(); + + let point_structs: Result<_, _> = points.into_iter().map(PointStruct::try_from).collect(); + + // No actual inference should happen here, as we are just syncing existing points + // So this function is used for consistency only + let points = convert_point_struct(point_structs?, InferenceType::Update).await?; + + let operation = PointSyncOperation { + points, + from_id: from_id.map(|x| x.try_into()).transpose()?, + to_id: to_id.map(|x| x.try_into()).transpose()?, + }; + let collection_operation = + CollectionUpdateOperations::PointOperation(PointOperations::SyncPoints(operation)); + + let shard_selector = if let Some(shard_selection) = shard_selection { + ShardSelectorInternal::ShardId(shard_selection) + } else { + debug_assert!(false, "Sync operation is supposed to select shard directly"); + ShardSelectorInternal::Empty + }; + + let result = toc + .update( + &collection_name, + OperationWithClockTag::new(collection_operation, clock_tag), + wait.unwrap_or(false), + write_ordering_from_proto(ordering)?, + shard_selector, + access, + ) + .await?; + + let response = points_operation_response_internal(timing, result); + Ok(Response::new(response)) +} + +pub async fn delete( + toc_provider: impl CheckedTocProvider, + delete_points: DeletePoints, + clock_tag: Option, + shard_selection: Option, + access: Access, +) -> Result, Status> { + let DeletePoints { + collection_name, + wait, + points, + ordering, + shard_key_selector, + } = delete_points; + + let points_selector = match points { + None => return Err(Status::invalid_argument("PointSelector is missing")), + Some(p) => try_points_selector_from_grpc(p, shard_key_selector)?, + }; + + let toc = toc_provider + .check_strict_mode(&points_selector, &collection_name, None, &access) + .await?; + + let timing = Instant::now(); + let result = do_delete_points( + toc.clone(), + collection_name, + points_selector, + clock_tag, + shard_selection, + wait.unwrap_or(false), + write_ordering_from_proto(ordering)?, + access, + ) + .await?; + + let response = points_operation_response_internal(timing, result); + Ok(Response::new(response)) +} + +pub async fn update_vectors( + toc: Arc, + update_point_vectors: UpdatePointVectors, + clock_tag: Option, + shard_selection: Option, + access: Access, +) -> Result, Status> { + let UpdatePointVectors { + collection_name, + wait, + points, + ordering, + shard_key_selector, + } = update_point_vectors; + + // Build list of operation points + let mut op_points = Vec::with_capacity(points.len()); + for point in points { + let id = match point.id { + Some(id) => id.try_into()?, + None => return Err(Status::invalid_argument("id is expected")), + }; + let vector = match point.vectors { + Some(vectors) => VectorStruct::try_from(vectors)?, + None => return Err(Status::invalid_argument("vectors is expected")), + }; + op_points.push(PointVectors { id, vector }); + } + + let operation = UpdateVectors { + points: op_points, + shard_key: shard_key_selector.map(ShardKeySelector::from), + }; + + let timing = Instant::now(); + 
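    // `wait` defaults to `false` when the client omits it; in that case the handler
    // responds as soon as the update is accepted instead of waiting for it to be fully
    // applied, and the returned update result carries the corresponding status.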
let result = do_update_vectors( + toc, + collection_name, + operation, + clock_tag, + shard_selection, + wait.unwrap_or(false), + write_ordering_from_proto(ordering)?, + access, + ) + .await?; + + let response = points_operation_response_internal(timing, result); + Ok(Response::new(response)) +} + +pub async fn delete_vectors( + toc_provider: impl CheckedTocProvider, + delete_point_vectors: DeletePointVectors, + clock_tag: Option, + shard_selection: Option, + access: Access, +) -> Result, Status> { + let DeletePointVectors { + collection_name, + wait, + points_selector, + vectors, + ordering, + shard_key_selector, + } = delete_point_vectors; + + let (points, filter) = extract_points_selector(points_selector)?; + let vector_names = match vectors { + Some(vectors) => vectors.names, + None => return Err(Status::invalid_argument("vectors is expected")), + }; + + let operation = DeleteVectors { + points, + filter, + vector: vector_names.into_iter().collect(), + shard_key: shard_key_selector.map(ShardKeySelector::from), + }; + + let toc = toc_provider + .check_strict_mode(&operation, &collection_name, None, &access) + .await?; + + let timing = Instant::now(); + let result = do_delete_vectors( + toc.clone(), + collection_name, + operation, + clock_tag, + shard_selection, + wait.unwrap_or(false), + write_ordering_from_proto(ordering)?, + access, + ) + .await?; + + let response = points_operation_response_internal(timing, result); + Ok(Response::new(response)) +} + +pub async fn set_payload( + toc_provider: impl CheckedTocProvider, + set_payload_points: SetPayloadPoints, + clock_tag: Option, + shard_selection: Option, + access: Access, +) -> Result, Status> { + let SetPayloadPoints { + collection_name, + wait, + payload, + points_selector, + ordering, + shard_key_selector, + key, + } = set_payload_points; + let key = key.map(|k| json_path_from_proto(&k)).transpose()?; + + let (points, filter) = extract_points_selector(points_selector)?; + let operation = collection::operations::payload_ops::SetPayload { + payload: proto_to_payloads(payload)?, + points, + filter, + shard_key: shard_key_selector.map(ShardKeySelector::from), + key, + }; + + let toc = toc_provider + .check_strict_mode(&operation, &collection_name, None, &access) + .await?; + + let timing = Instant::now(); + let result = do_set_payload( + toc.clone(), + collection_name, + operation, + clock_tag, + shard_selection, + wait.unwrap_or(false), + write_ordering_from_proto(ordering)?, + access, + ) + .await?; + + let response = points_operation_response_internal(timing, result); + Ok(Response::new(response)) +} + +pub async fn overwrite_payload( + toc_provider: impl CheckedTocProvider, + set_payload_points: SetPayloadPoints, + clock_tag: Option, + shard_selection: Option, + access: Access, +) -> Result, Status> { + let SetPayloadPoints { + collection_name, + wait, + payload, + points_selector, + ordering, + shard_key_selector, + .. 
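        // (`key` is deliberately left out of this destructuring: unlike `set_payload`
        // above, the overwrite operation replaces the whole payload, hence `key: None`
        // below.)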
+ } = set_payload_points; + + let (points, filter) = extract_points_selector(points_selector)?; + let operation = collection::operations::payload_ops::SetPayload { + payload: proto_to_payloads(payload)?, + points, + filter, + shard_key: shard_key_selector.map(ShardKeySelector::from), + // overwrite operation don't support indicate path of property + key: None, + }; + + let toc = toc_provider + .check_strict_mode(&operation, &collection_name, None, &access) + .await?; + + let timing = Instant::now(); + let result = do_overwrite_payload( + toc.clone(), + collection_name, + operation, + clock_tag, + shard_selection, + wait.unwrap_or(false), + write_ordering_from_proto(ordering)?, + access, + ) + .await?; + + let response = points_operation_response_internal(timing, result); + Ok(Response::new(response)) +} + +pub async fn delete_payload( + toc_provider: impl CheckedTocProvider, + delete_payload_points: DeletePayloadPoints, + clock_tag: Option, + shard_selection: Option, + access: Access, +) -> Result, Status> { + let DeletePayloadPoints { + collection_name, + wait, + keys, + points_selector, + ordering, + shard_key_selector, + } = delete_payload_points; + let keys = keys.iter().map(|k| json_path_from_proto(k)).try_collect()?; + + let (points, filter) = extract_points_selector(points_selector)?; + let operation = DeletePayload { + keys, + points, + filter, + shard_key: shard_key_selector.map(ShardKeySelector::from), + }; + + let toc = toc_provider + .check_strict_mode(&operation, &collection_name, None, &access) + .await?; + + let timing = Instant::now(); + let result = do_delete_payload( + toc.clone(), + collection_name, + operation, + clock_tag, + shard_selection, + wait.unwrap_or(false), + write_ordering_from_proto(ordering)?, + access, + ) + .await?; + + let response = points_operation_response_internal(timing, result); + Ok(Response::new(response)) +} + +pub async fn clear_payload( + toc_provider: impl CheckedTocProvider, + clear_payload_points: ClearPayloadPoints, + clock_tag: Option, + shard_selection: Option, + access: Access, +) -> Result, Status> { + let ClearPayloadPoints { + collection_name, + wait, + points, + ordering, + shard_key_selector, + } = clear_payload_points; + + let points_selector = match points { + None => return Err(Status::invalid_argument("PointSelector is missing")), + Some(p) => try_points_selector_from_grpc(p, shard_key_selector)?, + }; + + let toc = toc_provider + .check_strict_mode(&points_selector, &collection_name, None, &access) + .await?; + + let timing = Instant::now(); + let result = do_clear_payload( + toc.clone(), + collection_name, + points_selector, + clock_tag, + shard_selection, + wait.unwrap_or(false), + write_ordering_from_proto(ordering)?, + access, + ) + .await?; + + let response = points_operation_response_internal(timing, result); + Ok(Response::new(response)) +} + +pub async fn update_batch( + dispatcher: &Dispatcher, + update_batch_points: UpdateBatchPoints, + clock_tag: Option, + shard_selection: Option, + access: Access, +) -> Result, Status> { + let UpdateBatchPoints { + collection_name, + wait, + operations, + ordering, + } = update_batch_points; + + let timing = Instant::now(); + let mut results = Vec::with_capacity(operations.len()); + for op in operations { + let operation = op + .operation + .ok_or_else(|| Status::invalid_argument("Operation is missing"))?; + let collection_name = collection_name.clone(); + let ordering = ordering.clone(); + let result = match operation { + points_update_operation::Operation::Upsert(PointStructList 
{ + points, + shard_key_selector, + }) => { + // We don't need strict mode checks for upsert! + let toc = dispatcher.toc(&access, &new_unchecked_verification_pass()); + upsert( + toc.clone(), + UpsertPoints { + collection_name, + wait, + points, + ordering, + shard_key_selector, + }, + clock_tag, + shard_selection, + access.clone(), + ) + .await + } + points_update_operation::Operation::DeleteDeprecated(points) => { + delete( + StrictModeCheckedTocProvider::new(dispatcher), + DeletePoints { + collection_name, + wait, + points: Some(points), + ordering, + shard_key_selector: None, + }, + clock_tag, + shard_selection, + access.clone(), + ) + .await + } + points_update_operation::Operation::SetPayload( + points_update_operation::SetPayload { + payload, + points_selector, + shard_key_selector, + key, + }, + ) => { + set_payload( + StrictModeCheckedTocProvider::new(dispatcher), + SetPayloadPoints { + collection_name, + wait, + payload, + points_selector, + ordering, + shard_key_selector, + key, + }, + clock_tag, + shard_selection, + access.clone(), + ) + .await + } + points_update_operation::Operation::OverwritePayload( + points_update_operation::OverwritePayload { + payload, + points_selector, + shard_key_selector, + .. + }, + ) => { + overwrite_payload( + StrictModeCheckedTocProvider::new(dispatcher), + SetPayloadPoints { + collection_name, + wait, + payload, + points_selector, + ordering, + shard_key_selector, + // overwrite operation don't support it + key: None, + }, + clock_tag, + shard_selection, + access.clone(), + ) + .await + } + points_update_operation::Operation::DeletePayload( + points_update_operation::DeletePayload { + keys, + points_selector, + shard_key_selector, + }, + ) => { + delete_payload( + StrictModeCheckedTocProvider::new(dispatcher), + DeletePayloadPoints { + collection_name, + wait, + keys, + points_selector, + ordering, + shard_key_selector, + }, + clock_tag, + shard_selection, + access.clone(), + ) + .await + } + points_update_operation::Operation::ClearPayload(ClearPayload { + points, + shard_key_selector, + }) => { + clear_payload( + StrictModeCheckedTocProvider::new(dispatcher), + ClearPayloadPoints { + collection_name, + wait, + points, + ordering, + shard_key_selector, + }, + clock_tag, + shard_selection, + access.clone(), + ) + .await + } + points_update_operation::Operation::UpdateVectors( + points_update_operation::UpdateVectors { + points, + shard_key_selector, + }, + ) => { + // We don't need strict mode checks for vector updates! 
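// (Same rationale as the `Upsert` arm above: upserts and vector updates use an
// unchecked verification pass, whereas the delete, payload and clear arms of this
// match are routed through `StrictModeCheckedTocProvider::new(dispatcher)`.)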
+ let toc = dispatcher.toc(&access, &new_unchecked_verification_pass()); + update_vectors( + toc.clone(), + UpdatePointVectors { + collection_name, + wait, + points, + ordering, + shard_key_selector, + }, + clock_tag, + shard_selection, + access.clone(), + ) + .await + } + points_update_operation::Operation::DeleteVectors( + points_update_operation::DeleteVectors { + points_selector, + vectors, + shard_key_selector, + }, + ) => { + delete_vectors( + StrictModeCheckedTocProvider::new(dispatcher), + DeletePointVectors { + collection_name, + wait, + points_selector, + vectors, + ordering, + shard_key_selector, + }, + clock_tag, + shard_selection, + access.clone(), + ) + .await + } + Operation::ClearPayloadDeprecated(selector) => { + clear_payload( + StrictModeCheckedTocProvider::new(dispatcher), + ClearPayloadPoints { + collection_name, + wait, + points: Some(selector), + ordering, + shard_key_selector: None, + }, + clock_tag, + shard_selection, + access.clone(), + ) + .await + } + Operation::DeletePoints(points_update_operation::DeletePoints { + points, + shard_key_selector, + }) => { + delete( + StrictModeCheckedTocProvider::new(dispatcher), + DeletePoints { + collection_name, + wait, + points, + ordering, + shard_key_selector, + }, + clock_tag, + shard_selection, + access.clone(), + ) + .await + } + }?; + results.push(result); + } + Ok(Response::new(UpdateBatchResponse { + result: results + .into_iter() + .map(|response| response.into_inner().result.unwrap().into()) + .collect(), + time: timing.elapsed().as_secs_f64(), + })) +} + +fn convert_field_type( + field_type: Option, + field_index_params: Option, +) -> Result, Status> { + let field_type_parsed = field_type + .map(|x| FieldType::try_from(x).ok()) + .ok_or_else(|| Status::invalid_argument("cannot convert field_type"))?; + + let field_schema = match (field_type_parsed, field_index_params) { + ( + Some(field_type), + Some(PayloadIndexParams { + index_params: Some(index_params), + }), + ) => { + let schema_params = match index_params { + // Parameterized keyword type + IndexParams::KeywordIndexParams(keyword_index_params) => { + matches!(field_type, FieldType::Keyword).then(|| { + TryFrom::try_from(keyword_index_params).map(PayloadSchemaParams::Keyword) + }) + } + IndexParams::IntegerIndexParams(integer_index_params) => { + matches!(field_type, FieldType::Integer).then(|| { + TryFrom::try_from(integer_index_params).map(PayloadSchemaParams::Integer) + }) + } + // Parameterized float type + IndexParams::FloatIndexParams(float_index_params) => { + matches!(field_type, FieldType::Float).then(|| { + TryFrom::try_from(float_index_params).map(PayloadSchemaParams::Float) + }) + } + IndexParams::GeoIndexParams(geo_index_params) => { + matches!(field_type, FieldType::Geo) + .then(|| TryFrom::try_from(geo_index_params).map(PayloadSchemaParams::Geo)) + } + // Parameterized text type + IndexParams::TextIndexParams(text_index_params) => { + matches!(field_type, FieldType::Text).then(|| { + TryFrom::try_from(text_index_params).map(PayloadSchemaParams::Text) + }) + } + // Parameterized bool type + IndexParams::BoolIndexParams(bool_index_params) => { + matches!(field_type, FieldType::Bool).then(|| { + TryFrom::try_from(bool_index_params).map(PayloadSchemaParams::Bool) + }) + } + // Parameterized Datetime type + IndexParams::DatetimeIndexParams(datetime_index_params) => { + matches!(field_type, FieldType::Datetime).then(|| { + TryFrom::try_from(datetime_index_params).map(PayloadSchemaParams::Datetime) + }) + } + // Parameterized Uuid type + 
IndexParams::UuidIndexParams(uuid_index_params) => { + matches!(field_type, FieldType::Uuid).then(|| { + TryFrom::try_from(uuid_index_params).map(PayloadSchemaParams::Uuid) + }) + } + } + .ok_or_else(|| { + Status::invalid_argument(format!( + "field_type ({field_type:?}) and field_index_params do not match" + )) + })??; + + Some(PayloadFieldSchema::FieldParams(schema_params)) + } + // Regular field types + (Some(v), None | Some(PayloadIndexParams { index_params: None })) => match v { + FieldType::Keyword => Some(PayloadSchemaType::Keyword.into()), + FieldType::Integer => Some(PayloadSchemaType::Integer.into()), + FieldType::Float => Some(PayloadSchemaType::Float.into()), + FieldType::Geo => Some(PayloadSchemaType::Geo.into()), + FieldType::Text => Some(PayloadSchemaType::Text.into()), + FieldType::Bool => Some(PayloadSchemaType::Bool.into()), + FieldType::Datetime => Some(PayloadSchemaType::Datetime.into()), + FieldType::Uuid => Some(PayloadSchemaType::Uuid.into()), + }, + (None, Some(_)) => return Err(Status::invalid_argument("field type is missing")), + (None, None) => None, + }; + + Ok(field_schema) +} + +pub async fn create_field_index( + dispatcher: Arc, + create_field_index_collection: CreateFieldIndexCollection, + clock_tag: Option, + shard_selection: Option, + access: Access, +) -> Result, Status> { + let CreateFieldIndexCollection { + collection_name, + wait, + field_name, + field_type, + field_index_params, + ordering, + } = create_field_index_collection; + + let field_name = json_path_from_proto(&field_name)?; + let field_schema = convert_field_type(field_type, field_index_params)?; + + let operation = CreateFieldIndex { + field_name, + field_schema, + }; + + let timing = Instant::now(); + let result = do_create_index( + dispatcher, + collection_name, + operation, + clock_tag, + shard_selection, + wait.unwrap_or(false), + write_ordering_from_proto(ordering)?, + access, + ) + .await?; + + let response = points_operation_response_internal(timing, result); + Ok(Response::new(response)) +} + +pub async fn create_field_index_internal( + toc: Arc, + create_field_index_collection: CreateFieldIndexCollection, + clock_tag: Option, + shard_selection: Option, +) -> Result, Status> { + let CreateFieldIndexCollection { + collection_name, + wait, + field_name, + field_type, + field_index_params, + ordering, + } = create_field_index_collection; + + let field_name = json_path_from_proto(&field_name)?; + let field_schema = convert_field_type(field_type, field_index_params)?; + + let timing = Instant::now(); + let result = do_create_index_internal( + toc, + collection_name, + field_name, + field_schema, + clock_tag, + shard_selection, + wait.unwrap_or(false), + write_ordering_from_proto(ordering)?, + ) + .await?; + + let response = points_operation_response_internal(timing, result); + Ok(Response::new(response)) +} + +pub async fn delete_field_index( + dispatcher: Arc, + delete_field_index_collection: DeleteFieldIndexCollection, + clock_tag: Option, + shard_selection: Option, + access: Access, +) -> Result, Status> { + let DeleteFieldIndexCollection { + collection_name, + wait, + field_name, + ordering, + } = delete_field_index_collection; + + let field_name = json_path_from_proto(&field_name)?; + + let timing = Instant::now(); + let result = do_delete_index( + dispatcher, + collection_name, + field_name, + clock_tag, + shard_selection, + wait.unwrap_or(false), + write_ordering_from_proto(ordering)?, + access, + ) + .await?; + + let response = points_operation_response_internal(timing, result); + 
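// How `convert_field_type` (defined above) resolves the supported combinations,
// shown with hypothetical inputs for illustration:
//
//     field_type: Keyword, params: KeywordIndexParams  -> FieldParams(Keyword)            (parameterized index)
//     field_type: Keyword, params: IntegerIndexParams  -> invalid_argument "field_type (...) and field_index_params do not match"
//     field_type: Keyword, params: None                -> plain PayloadSchemaType::Keyword
//
// i.e. parameterized index params are only accepted when their variant matches the
// declared field type; a bare field type falls back to the regular schema type.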
Ok(Response::new(response)) +} + +pub async fn delete_field_index_internal( + toc: Arc, + delete_field_index_collection: DeleteFieldIndexCollection, + clock_tag: Option, + shard_selection: Option, +) -> Result, Status> { + let DeleteFieldIndexCollection { + collection_name, + wait, + field_name, + ordering, + } = delete_field_index_collection; + + let field_name = json_path_from_proto(&field_name)?; + + let timing = Instant::now(); + let result = do_delete_index_internal( + toc, + collection_name, + field_name, + clock_tag, + shard_selection, + wait.unwrap_or(false), + write_ordering_from_proto(ordering)?, + ) + .await?; + + let response = points_operation_response_internal(timing, result); + Ok(Response::new(response)) +} + +pub async fn search( + toc_provider: impl CheckedTocProvider, + search_points: SearchPoints, + shard_selection: Option, + access: Access, + hw_measurement_acc: RequestHwCounter, +) -> Result, Status> { + let SearchPoints { + collection_name, + vector, + filter, + limit, + offset, + with_payload, + params, + score_threshold, + vector_name, + with_vectors, + read_consistency, + timeout, + shard_key_selector, + sparse_indices, + } = search_points; + + let vector_struct = + api::grpc::conversions::into_named_vector_struct(vector_name, vector, sparse_indices)?; + + let shard_selector = convert_shard_selector_for_read(shard_selection, shard_key_selector); + + let search_request = CoreSearchRequest { + query: QueryEnum::Nearest(vector_struct), + filter: filter.map(|f| f.try_into()).transpose()?, + params: params.map(|p| p.into()), + limit: limit as usize, + offset: offset.unwrap_or_default() as usize, + with_payload: with_payload.map(|wp| wp.try_into()).transpose()?, + with_vector: Some( + with_vectors + .map(|selector| selector.into()) + .unwrap_or_default(), + ), + score_threshold, + }; + + let toc = toc_provider + .check_strict_mode( + &search_request, + &collection_name, + timeout.map(|i| i as usize), + &access, + ) + .await?; + + let read_consistency = ReadConsistency::try_from_optional(read_consistency)?; + + let timing = Instant::now(); + let scored_points = do_core_search_points( + toc, + &collection_name, + search_request, + read_consistency, + shard_selector, + access, + timeout.map(Duration::from_secs), + hw_measurement_acc.get_counter(), + ) + .await?; + + let response = SearchResponse { + result: scored_points + .into_iter() + .map(|point| point.into()) + .collect(), + time: timing.elapsed().as_secs_f64(), + usage: hw_measurement_acc.to_grpc_api(), + }; + + Ok(Response::new(response)) +} + +pub async fn core_search_batch( + toc_provider: impl CheckedTocProvider, + collection_name: &str, + requests: Vec<(CoreSearchRequest, ShardSelectorInternal)>, + read_consistency: Option, + access: Access, + timeout: Option, + request_hw_counter: RequestHwCounter, +) -> Result, Status> { + let toc = toc_provider + .check_strict_mode_batch( + &requests, + |i| &i.0, + collection_name, + timeout.map(|i| i.as_secs() as usize), + &access, + ) + .await?; + + let read_consistency = ReadConsistency::try_from_optional(read_consistency)?; + + let timing = Instant::now(); + + let scored_points = do_search_batch_points( + toc, + collection_name, + requests, + read_consistency, + access, + timeout, + request_hw_counter.get_counter(), + ) + .await?; + + let response = SearchBatchResponse { + result: scored_points + .into_iter() + .map(|points| BatchResult { + result: points.into_iter().map(|p| p.into()).collect(), + }) + .collect(), + time: timing.elapsed().as_secs_f64(), + usage: 
request_hw_counter.to_grpc_api(), + }; + + Ok(Response::new(response)) +} + +#[allow(clippy::too_many_arguments)] +pub async fn core_search_list( + toc: &TableOfContent, + collection_name: String, + search_points: Vec, + read_consistency: Option, + shard_selection: Option, + access: Access, + timeout: Option, + request_hw_counter: RequestHwCounter, +) -> Result, Status> { + let searches: Result, Status> = + search_points.into_iter().map(TryInto::try_into).collect(); + + let request = CoreSearchRequestBatch { + searches: searches?, + }; + + let timing = Instant::now(); + + // As this function is handling an internal request, + // we can assume that shard_key is already resolved + let shard_selection = match shard_selection { + None => { + debug_assert!(false, "Shard selection is expected for internal request"); + ShardSelectorInternal::All + } + Some(shard_id) => ShardSelectorInternal::ShardId(shard_id), + }; + + let read_consistency = ReadConsistency::try_from_optional(read_consistency)?; + + let scored_points = toc + .core_search_batch( + &collection_name, + request, + read_consistency, + shard_selection, + access, + timeout, + request_hw_counter.get_counter(), + ) + .await?; + + let response = SearchBatchResponse { + result: scored_points + .into_iter() + .map(|points| BatchResult { + result: points.into_iter().map(|p| p.into()).collect(), + }) + .collect(), + time: timing.elapsed().as_secs_f64(), + usage: request_hw_counter.to_grpc_api(), + }; + + Ok(Response::new(response)) +} + +pub async fn search_groups( + toc_provider: impl CheckedTocProvider, + search_point_groups: SearchPointGroups, + shard_selection: Option, + access: Access, + request_hw_counter: RequestHwCounter, +) -> Result, Status> { + let search_groups_request = search_point_groups.clone().try_into()?; + + let SearchPointGroups { + collection_name, + read_consistency, + timeout, + shard_key_selector, + .. 
+ } = search_point_groups; + + let toc = toc_provider + .check_strict_mode( + &search_groups_request, + &collection_name, + timeout.map(|i| i as usize), + &access, + ) + .await?; + + let read_consistency = ReadConsistency::try_from_optional(read_consistency)?; + + let shard_selector = convert_shard_selector_for_read(shard_selection, shard_key_selector); + + let timing = Instant::now(); + let groups_result = crate::common::points::do_search_point_groups( + toc, + &collection_name, + search_groups_request, + read_consistency, + shard_selector, + access, + timeout.map(Duration::from_secs), + request_hw_counter.get_counter(), + ) + .await?; + + let groups_result = GroupsResult::try_from(groups_result) + .map_err(|e| Status::internal(format!("Failed to convert groups result: {e}")))?; + + let response = SearchGroupsResponse { + result: Some(groups_result), + time: timing.elapsed().as_secs_f64(), + usage: request_hw_counter.to_grpc_api(), + }; + + Ok(Response::new(response)) +} + +pub async fn recommend( + toc_provider: impl CheckedTocProvider, + recommend_points: RecommendPoints, + access: Access, + request_hw_counter: RequestHwCounter, +) -> Result, Status> { + // TODO(luis): check if we can make this into a From impl + let RecommendPoints { + collection_name, + positive, + negative, + positive_vectors, + negative_vectors, + strategy, + filter, + limit, + offset, + with_payload, + params, + score_threshold, + using, + with_vectors, + lookup_from, + read_consistency, + timeout, + shard_key_selector, + } = recommend_points; + + let positive_ids = positive + .into_iter() + .map(TryInto::try_into) + .collect::, Status>>()?; + let positive_vectors = positive_vectors + .into_iter() + .map(TryInto::try_into) + .collect::>()?; + let positive = [positive_ids, positive_vectors].concat(); + + let negative_ids = negative + .into_iter() + .map(TryInto::try_into) + .collect::, Status>>()?; + let negative_vectors = negative_vectors + .into_iter() + .map(|v| RecommendExample::Dense(v.data)) + .collect(); + let negative = [negative_ids, negative_vectors].concat(); + + let request = collection::operations::types::RecommendRequestInternal { + positive, + negative, + strategy: strategy.map(|s| s.try_into()).transpose()?, + filter: filter.map(|f| f.try_into()).transpose()?, + params: params.map(|p| p.into()), + limit: limit as usize, + offset: offset.map(|x| x as usize), + with_payload: with_payload.map(|wp| wp.try_into()).transpose()?, + with_vector: Some( + with_vectors + .map(|selector| selector.into()) + .unwrap_or_default(), + ), + score_threshold, + using: using.map(|u| u.into()), + lookup_from: lookup_from.map(|l| l.into()), + }; + + let toc = toc_provider + .check_strict_mode( + &request, + &collection_name, + timeout.map(|i| i as usize), + &access, + ) + .await?; + + let read_consistency = ReadConsistency::try_from_optional(read_consistency)?; + let shard_selector = convert_shard_selector_for_read(None, shard_key_selector); + let timeout = timeout.map(Duration::from_secs); + + let timing = Instant::now(); + let recommended_points = toc + .recommend( + &collection_name, + request, + read_consistency, + shard_selector, + access, + timeout, + request_hw_counter.get_counter(), + ) + .await?; + + let response = RecommendResponse { + result: recommended_points + .into_iter() + .map(|point| point.into()) + .collect(), + time: timing.elapsed().as_secs_f64(), + usage: request_hw_counter.to_grpc_api(), + }; + + Ok(Response::new(response)) +} + +pub async fn recommend_batch( + toc_provider: impl CheckedTocProvider, 
+ collection_name: &str, + recommend_points: Vec, + read_consistency: Option, + access: Access, + timeout: Option, + request_hw_counter: RequestHwCounter, +) -> Result, Status> { + let mut requests = Vec::with_capacity(recommend_points.len()); + + for mut request in recommend_points { + let shard_selector = + convert_shard_selector_for_read(None, request.shard_key_selector.take()); + let internal_request: collection::operations::types::RecommendRequestInternal = + request.try_into()?; + requests.push((internal_request, shard_selector)); + } + + let toc = toc_provider + .check_strict_mode_batch( + &requests, + |i| &i.0, + collection_name, + timeout.map(|i| i.as_secs() as usize), + &access, + ) + .await?; + + let read_consistency = ReadConsistency::try_from_optional(read_consistency)?; + + let timing = Instant::now(); + let scored_points = toc + .recommend_batch( + collection_name, + requests, + read_consistency, + access, + timeout, + request_hw_counter.get_counter(), + ) + .await?; + + let response = RecommendBatchResponse { + result: scored_points + .into_iter() + .map(|points| BatchResult { + result: points.into_iter().map(|p| p.into()).collect(), + }) + .collect(), + time: timing.elapsed().as_secs_f64(), + usage: request_hw_counter.to_grpc_api(), + }; + + Ok(Response::new(response)) +} + +pub async fn recommend_groups( + toc_provider: impl CheckedTocProvider, + recommend_point_groups: RecommendPointGroups, + access: Access, + request_hw_counter: RequestHwCounter, +) -> Result, Status> { + let recommend_groups_request = recommend_point_groups.clone().try_into()?; + + let RecommendPointGroups { + collection_name, + read_consistency, + timeout, + shard_key_selector, + .. + } = recommend_point_groups; + + let toc = toc_provider + .check_strict_mode( + &recommend_groups_request, + &collection_name, + timeout.map(|i| i as usize), + &access, + ) + .await?; + + let read_consistency = ReadConsistency::try_from_optional(read_consistency)?; + + let shard_selector = convert_shard_selector_for_read(None, shard_key_selector); + + let timing = Instant::now(); + let groups_result = crate::common::points::do_recommend_point_groups( + toc, + &collection_name, + recommend_groups_request, + read_consistency, + shard_selector, + access, + timeout.map(Duration::from_secs), + request_hw_counter.get_counter(), + ) + .await?; + + let groups_result = GroupsResult::try_from(groups_result) + .map_err(|e| Status::internal(format!("Failed to convert groups result: {e}")))?; + + let response = RecommendGroupsResponse { + result: Some(groups_result), + time: timing.elapsed().as_secs_f64(), + usage: request_hw_counter.to_grpc_api(), + }; + + Ok(Response::new(response)) +} + +pub async fn discover( + toc_provider: impl CheckedTocProvider, + discover_points: DiscoverPoints, + access: Access, + request_hw_counter: RequestHwCounter, +) -> Result, Status> { + let (request, collection_name, read_consistency, timeout, shard_key_selector) = + try_discover_request_from_grpc(discover_points)?; + + let toc = toc_provider + .check_strict_mode( + &request, + &collection_name, + timeout.map(|i| i.as_secs() as usize), + &access, + ) + .await?; + + let timing = Instant::now(); + + let shard_selector = convert_shard_selector_for_read(None, shard_key_selector); + + let discovered_points = toc + .discover( + &collection_name, + request, + read_consistency, + shard_selector, + access, + timeout, + request_hw_counter.get_counter(), + ) + .await?; + + let response = DiscoverResponse { + result: discovered_points + .into_iter() + 
.map(|point| point.into()) + .collect(), + time: timing.elapsed().as_secs_f64(), + usage: request_hw_counter.to_grpc_api(), + }; + + Ok(Response::new(response)) +} + +pub async fn discover_batch( + toc_provider: impl CheckedTocProvider, + collection_name: &str, + discover_points: Vec, + read_consistency: Option, + access: Access, + timeout: Option, + request_hw_counter: RequestHwCounter, +) -> Result, Status> { + let mut requests = Vec::with_capacity(discover_points.len()); + + for discovery_request in discover_points { + let (internal_request, _collection_name, _consistency, _timeout, shard_key_selector) = + try_discover_request_from_grpc(discovery_request)?; + let shard_selector = convert_shard_selector_for_read(None, shard_key_selector); + requests.push((internal_request, shard_selector)); + } + + let read_consistency = ReadConsistency::try_from_optional(read_consistency)?; + + let toc = toc_provider + .check_strict_mode_batch( + &requests, + |i| &i.0, + collection_name, + timeout.map(|i| i.as_secs() as usize), + &access, + ) + .await?; + + let timing = Instant::now(); + let scored_points = toc + .discover_batch( + collection_name, + requests, + read_consistency, + access, + timeout, + request_hw_counter.get_counter(), + ) + .await?; + + let response = DiscoverBatchResponse { + result: scored_points + .into_iter() + .map(|points| BatchResult { + result: points.into_iter().map(|p| p.into()).collect(), + }) + .collect(), + time: timing.elapsed().as_secs_f64(), + usage: request_hw_counter.to_grpc_api(), + }; + + Ok(Response::new(response)) +} + +pub async fn scroll( + toc_provider: impl CheckedTocProvider, + scroll_points: ScrollPoints, + shard_selection: Option, + access: Access, +) -> Result, Status> { + let ScrollPoints { + collection_name, + filter, + offset, + limit, + with_payload, + with_vectors, + read_consistency, + shard_key_selector, + order_by, + timeout, + } = scroll_points; + + let scroll_request = ScrollRequestInternal { + offset: offset.map(|o| o.try_into()).transpose()?, + limit: limit.map(|l| l as usize), + filter: filter.map(|f| f.try_into()).transpose()?, + with_payload: with_payload.map(|wp| wp.try_into()).transpose()?, + with_vector: with_vectors + .map(|selector| selector.into()) + .unwrap_or_default(), + order_by: order_by + .map(OrderBy::try_from) + .transpose()? 
+ .map(OrderByInterface::Struct), + }; + + let toc = toc_provider + .check_strict_mode( + &scroll_request, + &collection_name, + timeout.map(|i| i as usize), + &access, + ) + .await?; + + let timeout = timeout.map(Duration::from_secs); + let read_consistency = ReadConsistency::try_from_optional(read_consistency)?; + + let shard_selector = convert_shard_selector_for_read(shard_selection, shard_key_selector); + + let timing = Instant::now(); + let scrolled_points = do_scroll_points( + toc, + &collection_name, + scroll_request, + read_consistency, + timeout, + shard_selector, + access, + ) + .await?; + + let points: Result<_, _> = scrolled_points + .points + .into_iter() + .map(api::grpc::qdrant::RetrievedPoint::try_from) + .collect(); + + let points = points.map_err(|e| Status::internal(format!("Failed to convert points: {e}")))?; + + let response = ScrollResponse { + next_page_offset: scrolled_points.next_page_offset.map(|n| n.into()), + result: points, + time: timing.elapsed().as_secs_f64(), + }; + + Ok(Response::new(response)) +} + +pub async fn count( + toc_provider: impl CheckedTocProvider, + count_points: CountPoints, + shard_selection: Option, + access: &Access, + request_hw_counter: RequestHwCounter, +) -> Result, Status> { + let CountPoints { + collection_name, + filter, + exact, + read_consistency, + shard_key_selector, + timeout, + } = count_points; + + let count_request = collection::operations::types::CountRequestInternal { + filter: filter.map(|f| f.try_into()).transpose()?, + exact: exact.unwrap_or_else(default_exact_count), + }; + + let toc = toc_provider + .check_strict_mode( + &count_request, + &collection_name, + timeout.map(|i| i as usize), + access, + ) + .await?; + + let timeout = timeout.map(Duration::from_secs); + let read_consistency = ReadConsistency::try_from_optional(read_consistency)?; + + let shard_selector = convert_shard_selector_for_read(shard_selection, shard_key_selector); + + let timing = Instant::now(); + + let count_result = do_count_points( + toc, + &collection_name, + count_request, + read_consistency, + timeout, + shard_selector, + access.clone(), + request_hw_counter.get_counter(), + ) + .await?; + + let response = CountResponse { + result: Some(count_result.into()), + time: timing.elapsed().as_secs_f64(), + usage: request_hw_counter.to_grpc_api(), + }; + + Ok(Response::new(response)) +} + +pub async fn get( + toc_provider: impl CheckedTocProvider, + get_points: GetPoints, + shard_selection: Option, + access: Access, +) -> Result, Status> { + let GetPoints { + collection_name, + ids, + with_payload, + with_vectors, + read_consistency, + shard_key_selector, + timeout, + } = get_points; + + let point_request = PointRequestInternal { + ids: ids + .into_iter() + .map(|p| p.try_into()) + .collect::>()?, + with_payload: with_payload.map(|wp| wp.try_into()).transpose()?, + with_vector: with_vectors + .map(|selector| selector.into()) + .unwrap_or_default(), + }; + let read_consistency = ReadConsistency::try_from_optional(read_consistency)?; + + let shard_selector = convert_shard_selector_for_read(shard_selection, shard_key_selector); + + let timing = Instant::now(); + + let toc = toc_provider + .check_strict_mode( + &point_request, + &collection_name, + timeout.map(|i| i as usize), + &access, + ) + .await?; + + let timeout = timeout.map(Duration::from_secs); + + let records = do_get_points( + toc, + &collection_name, + point_request, + read_consistency, + timeout, + shard_selector, + access, + ) + .await?; + + let response = GetResponse { + result: 
records.into_iter().map(|point| point.into()).collect(), + time: timing.elapsed().as_secs_f64(), + }; + + Ok(Response::new(response)) +} + +pub async fn query( + toc_provider: impl CheckedTocProvider, + query_points: QueryPoints, + shard_selection: Option, + access: Access, + request_hw_counter: RequestHwCounter, +) -> Result, Status> { + let shard_key_selector = query_points.shard_key_selector.clone(); + let shard_selector = convert_shard_selector_for_read(shard_selection, shard_key_selector); + let read_consistency = query_points + .read_consistency + .clone() + .map(TryFrom::try_from) + .transpose()?; + let collection_name = query_points.collection_name.clone(); + let timeout = query_points.timeout; + let request = convert_query_points_from_grpc(query_points).await?; + + let toc = toc_provider + .check_strict_mode( + &request, + &collection_name, + timeout.map(|i| i as usize), + &access, + ) + .await?; + + let timeout = timeout.map(Duration::from_secs); + + let timing = Instant::now(); + let scored_points = do_query_points( + toc, + &collection_name, + request, + read_consistency, + shard_selector, + access, + timeout, + request_hw_counter.get_counter(), + ) + .await?; + + let response = QueryResponse { + result: scored_points + .into_iter() + .map(|point| point.into()) + .collect(), + time: timing.elapsed().as_secs_f64(), + usage: request_hw_counter.to_grpc_api(), + }; + + Ok(Response::new(response)) +} + +pub async fn query_batch( + toc_provider: impl CheckedTocProvider, + collection_name: &str, + points: Vec, + read_consistency: Option, + access: Access, + timeout: Option, + request_hw_counter: RequestHwCounter, +) -> Result, Status> { + let read_consistency = ReadConsistency::try_from_optional(read_consistency)?; + let mut requests = Vec::with_capacity(points.len()); + for query_points in points { + let shard_key_selector = query_points.shard_key_selector.clone(); + let shard_selector = convert_shard_selector_for_read(None, shard_key_selector); + let request = convert_query_points_from_grpc(query_points).await?; + requests.push((request, shard_selector)); + } + + let toc = toc_provider + .check_strict_mode_batch( + &requests, + |i| &i.0, + collection_name, + timeout.map(|i| i.as_secs() as usize), + &access, + ) + .await?; + + let timing = Instant::now(); + let scored_points = do_query_batch_points( + toc, + collection_name, + requests, + read_consistency, + access, + timeout, + request_hw_counter.get_counter(), + ) + .await?; + + let response = QueryBatchResponse { + result: scored_points + .into_iter() + .map(|points| BatchResult { + result: points.into_iter().map(|p| p.into()).collect(), + }) + .collect(), + time: timing.elapsed().as_secs_f64(), + usage: request_hw_counter.to_grpc_api(), + }; + + Ok(Response::new(response)) +} + +pub async fn query_groups( + toc_provider: impl CheckedTocProvider, + query_points: QueryPointGroups, + shard_selection: Option, + access: Access, + request_hw_counter: RequestHwCounter, +) -> Result, Status> { + let shard_key_selector = query_points.shard_key_selector.clone(); + let shard_selector = convert_shard_selector_for_read(shard_selection, shard_key_selector); + let read_consistency = query_points + .read_consistency + .clone() + .map(TryFrom::try_from) + .transpose()?; + let timeout = query_points.timeout; + let collection_name = query_points.collection_name.clone(); + let request = convert_query_point_groups_from_grpc(query_points).await?; + + let toc = toc_provider + .check_strict_mode( + &request, + &collection_name, + timeout.map(|i| i as 
usize), + &access, + ) + .await?; + + let timeout = timeout.map(Duration::from_secs); + let timing = Instant::now(); + + let groups_result = do_query_point_groups( + toc, + &collection_name, + request, + read_consistency, + shard_selector, + access, + timeout, + request_hw_counter.get_counter(), + ) + .await?; + + let grpc_group_result = GroupsResult::try_from(groups_result) + .map_err(|err| Status::internal(format!("failed to convert result: {err}")))?; + + let response = QueryGroupsResponse { + result: Some(grpc_group_result), + time: timing.elapsed().as_secs_f64(), + usage: request_hw_counter.to_grpc_api(), + }; + + Ok(Response::new(response)) +} + +pub async fn facet( + toc_provider: impl CheckedTocProvider, + facet_counts: FacetCounts, + access: Access, +) -> Result, Status> { + let FacetCounts { + collection_name, + key, + filter, + exact, + limit, + read_consistency, + shard_key_selector, + timeout, + } = facet_counts; + + let facet_request = FacetParams { + key: json_path_from_proto(&key)?, + filter: filter.map(TryInto::try_into).transpose()?, + limit: limit + .map(usize::try_from) + .transpose() + .map_err(|_| Status::invalid_argument("could not parse limit param into usize"))? + .unwrap_or(FacetParams::DEFAULT_LIMIT), + exact: exact.unwrap_or(FacetParams::DEFAULT_EXACT), + }; + + let toc = toc_provider + .check_strict_mode( + &facet_request, + &collection_name, + timeout.map(|i| i as usize), + &access, + ) + .await?; + + let timeout = timeout.map(Duration::from_secs); + let read_consistency = ReadConsistency::try_from_optional(read_consistency)?; + + let shard_selector = convert_shard_selector_for_read(None, shard_key_selector); + + let timing = Instant::now(); + let facet_response = toc + .facet( + &collection_name, + facet_request, + shard_selector, + read_consistency, + access, + timeout, + ) + .await?; + + let segment::data_types::facets::FacetResponse { hits } = facet_response; + + let response = FacetResponse { + hits: hits.into_iter().map(From::from).collect(), + time: timing.elapsed().as_secs_f64(), + }; + + Ok(Response::new(response)) +} + +pub async fn search_points_matrix( + toc_provider: impl CheckedTocProvider, + search_matrix_points: SearchMatrixPoints, + access: Access, + hw_measurement_acc: &HwMeasurementAcc, +) -> Result { + let SearchMatrixPoints { + collection_name, + filter, + sample, + limit, + using, + read_consistency, + shard_key_selector, + timeout, + } = search_matrix_points; + + let search_matrix_request = CollectionSearchMatrixRequest { + filter: filter.map(TryInto::try_into).transpose()?, + sample_size: sample + .map(usize::try_from) + .transpose() + .map_err(|_| Status::invalid_argument("could not parse 'sample' param into usize"))? + .unwrap_or(CollectionSearchMatrixRequest::DEFAULT_SAMPLE), + limit_per_sample: limit + .map(usize::try_from) + .transpose() + .map_err(|_| Status::invalid_argument("could not parse 'limit' param into usize"))? 
+ .unwrap_or(CollectionSearchMatrixRequest::DEFAULT_LIMIT_PER_SAMPLE), + using: using.unwrap_or(DEFAULT_VECTOR_NAME.to_string()), + }; + + let toc = toc_provider + .check_strict_mode( + &search_matrix_request, + &collection_name, + timeout.map(|i| i as usize), + &access, + ) + .await?; + + let timeout = timeout.map(Duration::from_secs); + let read_consistency = ReadConsistency::try_from_optional(read_consistency)?; + + let shard_selector = convert_shard_selector_for_read(None, shard_key_selector); + + let search_matrix_response = toc + .search_points_matrix( + &collection_name, + search_matrix_request, + read_consistency, + shard_selector, + access, + timeout, + hw_measurement_acc, + ) + .await?; + + Ok(search_matrix_response) +} diff --git a/src/tonic/api/points_internal_api.rs b/src/tonic/api/points_internal_api.rs new file mode 100644 index 0000000000000000000000000000000000000000..167ce523b2a849b8a9cdd7697505871b7d4b2ec2 --- /dev/null +++ b/src/tonic/api/points_internal_api.rs @@ -0,0 +1,629 @@ +use std::str::FromStr; +use std::sync::Arc; +use std::time::{Duration, Instant}; + +use api::grpc::qdrant::points_internal_server::PointsInternal; +use api::grpc::qdrant::{ + ClearPayloadPointsInternal, CoreSearchBatchPointsInternal, CountPointsInternal, CountResponse, + CreateFieldIndexCollectionInternal, DeleteFieldIndexCollectionInternal, + DeletePayloadPointsInternal, DeletePointsInternal, DeleteVectorsInternal, FacetCountsInternal, + FacetResponseInternal, GetPointsInternal, GetResponse, IntermediateResult, + PointsOperationResponseInternal, QueryBatchPointsInternal, QueryBatchResponseInternal, + QueryResultInternal, QueryShardPoints, RecommendPointsInternal, RecommendResponse, + ScrollPointsInternal, ScrollResponse, SearchBatchResponse, SetPayloadPointsInternal, + SyncPointsInternal, UpdateVectorsInternal, UpsertPointsInternal, +}; +use collection::operations::shard_selector_internal::ShardSelectorInternal; +use collection::operations::universal_query::shard_query::ShardQueryRequest; +use collection::shards::shard::ShardId; +use common::counter::hardware_accumulator::HwMeasurementAcc; +use itertools::Itertools; +use segment::data_types::facets::{FacetParams, FacetResponse}; +use segment::json_path::JsonPath; +use segment::types::Filter; +use storage::content_manager::toc::request_hw_counter::RequestHwCounter; +use storage::content_manager::toc::TableOfContent; +use storage::rbac::Access; +use tonic::{Request, Response, Status}; + +use super::points_common::{core_search_list, scroll}; +use super::validate_and_log; +use crate::settings::ServiceConfig; +use crate::tonic::api::points_common::{ + clear_payload, count, create_field_index_internal, delete, delete_field_index_internal, + delete_payload, delete_vectors, get, overwrite_payload, recommend, set_payload, sync, + update_vectors, upsert, +}; +use crate::tonic::verification::UncheckedTocProvider; + +const FULL_ACCESS: Access = Access::full("Internal API"); + +/// This API is intended for P2P communication within a distributed deployment. 
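// A rough sketch of how this service is typically mounted on the node's internal
// gRPC server. The tonic-generated `PointsInternalServer` wrapper lives next to the
// `PointsInternal` trait imported above; the `register_points_internal` helper and
// the builder wiring here are illustrative assumptions, not part of this file.
//
//     use api::grpc::qdrant::points_internal_server::PointsInternalServer;
//
//     fn register_points_internal(
//         toc: Arc<TableOfContent>,
//         service_config: ServiceConfig,
//     ) -> tonic::transport::server::Router {
//         tonic::transport::Server::builder()
//             .add_service(PointsInternalServer::new(PointsInternalService::new(
//                 toc,
//                 service_config,
//             )))
//     }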
+pub struct PointsInternalService { + toc: Arc, + service_config: ServiceConfig, +} + +impl PointsInternalService { + pub fn new(toc: Arc, service_config: ServiceConfig) -> Self { + Self { + toc, + service_config, + } + } +} + +pub async fn query_batch_internal( + toc: &TableOfContent, + collection_name: String, + query_points: Vec, + shard_selection: Option, + timeout: Option, + request_hw_data: RequestHwCounter, +) -> Result, Status> { + let batch_requests: Vec<_> = query_points + .into_iter() + .map(ShardQueryRequest::try_from) + .try_collect()?; + + let timing = Instant::now(); + + // As this function is handling an internal request, + // we can assume that shard_key is already resolved + let shard_selection = match shard_selection { + None => { + debug_assert!(false, "Shard selection is expected for internal request"); + ShardSelectorInternal::All + } + Some(shard_id) => ShardSelectorInternal::ShardId(shard_id), + }; + + let batch_response = toc + .query_batch_internal( + &collection_name, + batch_requests, + shard_selection, + timeout, + request_hw_data.get_counter(), + ) + .await?; + + let response = QueryBatchResponseInternal { + results: batch_response + .into_iter() + .map(|response| QueryResultInternal { + intermediate_results: response + .into_iter() + .map(|intermediate| IntermediateResult { + result: intermediate.into_iter().map(From::from).collect_vec(), + }) + .collect_vec(), + }) + .collect(), + time: timing.elapsed().as_secs_f64(), + usage: request_hw_data.to_grpc_api(), + }; + + Ok(Response::new(response)) +} + +async fn facet_counts_internal( + toc: &TableOfContent, + request: FacetCountsInternal, +) -> Result, Status> { + let timing = Instant::now(); + + let FacetCountsInternal { + collection_name, + key, + filter, + limit, + exact, + shard_id, + timeout, + } = request; + + let shard_selection = ShardSelectorInternal::ShardId(shard_id); + + let request = FacetParams { + key: JsonPath::from_str(&key) + .map_err(|_| Status::invalid_argument("Failed to parse facet key"))?, + limit: limit as usize, + filter: filter.map(Filter::try_from).transpose()?, + exact, + }; + + let response = toc + .facet_internal( + &collection_name, + request, + shard_selection, + timeout.map(Duration::from_secs), + ) + .await?; + + let FacetResponse { hits } = response; + + let response = FacetResponseInternal { + hits: hits.into_iter().map(From::from).collect_vec(), + time: timing.elapsed().as_secs_f64(), + }; + + Ok(Response::new(response)) +} + +impl PointsInternalService { + /// Generates a new `RequestHwCounter` for the request. + /// This counter is indented to be used for internal requests. + /// + /// So, it collects the hardware usage to the collection's counter ONLY if it was not + /// converted to a response. 
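// In other words, hardware usage measured for an internal request is not lost:
// whatever is not echoed back in the gRPC response is drained into the collection's
// accumulated metrics via the `HwMeasurementAcc::new_with_drain` call below.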
+ fn get_request_collection_hw_usage_counter_for_internal( + &self, + collection_name: String, + ) -> RequestHwCounter { + let counter = + HwMeasurementAcc::new_with_drain(&self.toc.get_collection_hw_metrics(collection_name)); + + RequestHwCounter::new(counter, self.service_config.hardware_reporting(), true) + } +} + +#[tonic::async_trait] +impl PointsInternal for PointsInternalService { + async fn upsert( + &self, + request: Request, + ) -> Result, Status> { + validate_and_log(request.get_ref()); + + let UpsertPointsInternal { + upsert_points, + shard_id, + clock_tag, + } = request.into_inner(); + + let upsert_points = + upsert_points.ok_or_else(|| Status::invalid_argument("UpsertPoints is missing"))?; + + upsert( + self.toc.clone(), + upsert_points, + clock_tag.map(Into::into), + shard_id, + FULL_ACCESS.clone(), + ) + .await + } + + async fn delete( + &self, + request: Request, + ) -> Result, Status> { + validate_and_log(request.get_ref()); + + let DeletePointsInternal { + delete_points, + shard_id, + clock_tag, + } = request.into_inner(); + + let delete_points = + delete_points.ok_or_else(|| Status::invalid_argument("DeletePoints is missing"))?; + + delete( + UncheckedTocProvider::new_unchecked(&self.toc), + delete_points, + clock_tag.map(Into::into), + shard_id, + FULL_ACCESS.clone(), + ) + .await + } + + async fn update_vectors( + &self, + request: Request, + ) -> Result, Status> { + validate_and_log(request.get_ref()); + + let request = request.into_inner(); + + let shard_id = request.shard_id; + let clock_tag = request.clock_tag; + + let update_point_vectors = request + .update_vectors + .ok_or_else(|| Status::invalid_argument("UpdateVectors is missing"))?; + + update_vectors( + self.toc.clone(), + update_point_vectors, + clock_tag.map(Into::into), + shard_id, + FULL_ACCESS.clone(), + ) + .await + } + + async fn delete_vectors( + &self, + request: Request, + ) -> Result, Status> { + validate_and_log(request.get_ref()); + + let request = request.into_inner(); + + let shard_id = request.shard_id; + let clock_tag = request.clock_tag; + + let delete_point_vectors = request + .delete_vectors + .ok_or_else(|| Status::invalid_argument("DeleteVectors is missing"))?; + + delete_vectors( + UncheckedTocProvider::new_unchecked(&self.toc), + delete_point_vectors, + clock_tag.map(Into::into), + shard_id, + FULL_ACCESS.clone(), + ) + .await + } + + async fn set_payload( + &self, + request: Request, + ) -> Result, Status> { + validate_and_log(request.get_ref()); + + let SetPayloadPointsInternal { + set_payload_points, + shard_id, + clock_tag, + } = request.into_inner(); + + let set_payload_points = set_payload_points + .ok_or_else(|| Status::invalid_argument("SetPayloadPoints is missing"))?; + + set_payload( + UncheckedTocProvider::new_unchecked(&self.toc), + set_payload_points, + clock_tag.map(Into::into), + shard_id, + FULL_ACCESS.clone(), + ) + .await + } + + async fn overwrite_payload( + &self, + request: Request, + ) -> Result, Status> { + validate_and_log(request.get_ref()); + + let SetPayloadPointsInternal { + set_payload_points, + shard_id, + clock_tag, + } = request.into_inner(); + + let set_payload_points = set_payload_points + .ok_or_else(|| Status::invalid_argument("SetPayloadPoints is missing"))?; + + overwrite_payload( + UncheckedTocProvider::new_unchecked(&self.toc), + set_payload_points, + clock_tag.map(Into::into), + shard_id, + FULL_ACCESS.clone(), + ) + .await + } + + async fn delete_payload( + &self, + request: Request, + ) -> Result, Status> { + 
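// The write handlers in this impl all follow the same shape: validate and log the
// envelope, unwrap the inner request (a missing body becomes an `invalid_argument`
// error), then delegate to the shared helper from `points_common`, passing the
// propagated `clock_tag` and `shard_id` (and, where the helper checks access, the
// blanket `FULL_ACCESS`).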
validate_and_log(request.get_ref()); + + let DeletePayloadPointsInternal { + delete_payload_points, + shard_id, + clock_tag, + } = request.into_inner(); + + let delete_payload_points = delete_payload_points + .ok_or_else(|| Status::invalid_argument("DeletePayloadPoints is missing"))?; + + delete_payload( + UncheckedTocProvider::new_unchecked(&self.toc), + delete_payload_points, + clock_tag.map(Into::into), + shard_id, + FULL_ACCESS.clone(), + ) + .await + } + + async fn clear_payload( + &self, + request: Request, + ) -> Result, Status> { + validate_and_log(request.get_ref()); + + let ClearPayloadPointsInternal { + clear_payload_points, + shard_id, + clock_tag, + } = request.into_inner(); + + let clear_payload_points = clear_payload_points + .ok_or_else(|| Status::invalid_argument("ClearPayloadPoints is missing"))?; + + clear_payload( + UncheckedTocProvider::new_unchecked(&self.toc), + clear_payload_points, + clock_tag.map(Into::into), + shard_id, + FULL_ACCESS.clone(), + ) + .await + } + + async fn create_field_index( + &self, + request: Request, + ) -> Result, Status> { + validate_and_log(request.get_ref()); + + let CreateFieldIndexCollectionInternal { + create_field_index_collection, + shard_id, + clock_tag, + } = request.into_inner(); + + let create_field_index_collection = create_field_index_collection + .ok_or_else(|| Status::invalid_argument("CreateFieldIndexCollection is missing"))?; + + create_field_index_internal( + self.toc.clone(), + create_field_index_collection, + clock_tag.map(Into::into), + shard_id, + ) + .await + } + + async fn delete_field_index( + &self, + request: Request, + ) -> Result, Status> { + validate_and_log(request.get_ref()); + + let DeleteFieldIndexCollectionInternal { + delete_field_index_collection, + shard_id, + clock_tag, + } = request.into_inner(); + + let delete_field_index_collection = delete_field_index_collection + .ok_or_else(|| Status::invalid_argument("DeleteFieldIndexCollection is missing"))?; + + delete_field_index_internal( + self.toc.clone(), + delete_field_index_collection, + clock_tag.map(Into::into), + shard_id, + ) + .await + } + + async fn core_search_batch( + &self, + request: Request, + ) -> Result, Status> { + validate_and_log(request.get_ref()); + + let CoreSearchBatchPointsInternal { + collection_name, + search_points, + shard_id, + timeout, + } = request.into_inner(); + + let timeout = timeout.map(Duration::from_secs); + + // Individual `read_consistency` values are ignored by `core_search_batch`... + // + // search_points + // .iter_mut() + // .for_each(|search_points| search_points.read_consistency = None); + + let hw_data = + self.get_request_collection_hw_usage_counter_for_internal(collection_name.clone()); + let res = core_search_list( + self.toc.as_ref(), + collection_name, + search_points, + None, // *Has* to be `None`! + shard_id, + FULL_ACCESS.clone(), + timeout, + hw_data, + ) + .await?; + + Ok(res) + } + + async fn recommend( + &self, + request: Request, + ) -> Result, Status> { + validate_and_log(request.get_ref()); + + let RecommendPointsInternal { + recommend_points, + .. // shard_id - is not used in internal API, + // because it is transformed into regular search requests on the first node + } = request.into_inner(); + + let mut recommend_points = recommend_points + .ok_or_else(|| Status::invalid_argument("RecommendPoints is missing"))?; + + recommend_points.read_consistency = None; // *Have* to be `None`! 
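// Read consistency is forced to `None` here (and in `scroll`/`get` below) because
// an internal request is already the product of a consistency-resolved fan-out on
// the originating peer; the receiving node should execute it locally rather than
// apply read consistency a second time.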
+ + let collection_name = recommend_points.collection_name.clone(); + + let hw_data = self.get_request_collection_hw_usage_counter_for_internal(collection_name); + let res = recommend( + UncheckedTocProvider::new_unchecked(&self.toc), + recommend_points, + FULL_ACCESS.clone(), + hw_data, + ) + .await?; + + Ok(res) + } + + async fn scroll( + &self, + request: Request, + ) -> Result, Status> { + validate_and_log(request.get_ref()); + + let ScrollPointsInternal { + scroll_points, + shard_id, + } = request.into_inner(); + + let mut scroll_points = + scroll_points.ok_or_else(|| Status::invalid_argument("ScrollPoints is missing"))?; + + scroll_points.read_consistency = None; // *Have* to be `None`! + + scroll( + UncheckedTocProvider::new_unchecked(&self.toc), + scroll_points, + shard_id, + FULL_ACCESS.clone(), + ) + .await + } + + async fn get( + &self, + request: Request, + ) -> Result, Status> { + validate_and_log(request.get_ref()); + + let GetPointsInternal { + get_points, + shard_id, + } = request.into_inner(); + + let mut get_points = + get_points.ok_or_else(|| Status::invalid_argument("GetPoints is missing"))?; + + get_points.read_consistency = None; // *Have* to be `None`! + + get( + UncheckedTocProvider::new_unchecked(&self.toc), + get_points, + shard_id, + FULL_ACCESS.clone(), + ) + .await + } + + async fn count( + &self, + request: Request, + ) -> Result, Status> { + validate_and_log(request.get_ref()); + + let CountPointsInternal { + count_points, + shard_id, + } = request.into_inner(); + + let count_points = + count_points.ok_or_else(|| Status::invalid_argument("CountPoints is missing"))?; + let hw_data = self.get_request_collection_hw_usage_counter_for_internal( + count_points.collection_name.clone(), + ); + let res = count( + UncheckedTocProvider::new_unchecked(&self.toc), + count_points, + shard_id, + &FULL_ACCESS, + hw_data, + ) + .await?; + Ok(res) + } + + async fn sync( + &self, + request: Request, + ) -> Result, Status> { + validate_and_log(request.get_ref()); + + let SyncPointsInternal { + sync_points, + shard_id, + clock_tag, + } = request.into_inner(); + + let sync_points = + sync_points.ok_or_else(|| Status::invalid_argument("SyncPoints is missing"))?; + sync( + self.toc.clone(), + sync_points, + clock_tag.map(Into::into), + shard_id, + FULL_ACCESS.clone(), + ) + .await + } + + async fn query_batch( + &self, + request: Request, + ) -> Result, Status> { + validate_and_log(request.get_ref()); + + let QueryBatchPointsInternal { + collection_name, + shard_id, + query_points, + timeout, + } = request.into_inner(); + + let timeout = timeout.map(Duration::from_secs); + + let hw_data = + self.get_request_collection_hw_usage_counter_for_internal(collection_name.clone()); + + query_batch_internal( + self.toc.as_ref(), + collection_name, + query_points, + shard_id, + timeout, + hw_data, + ) + .await + } + + async fn facet( + &self, + request: Request, + ) -> Result, Status> { + validate_and_log(request.get_ref()); + + facet_counts_internal(self.toc.as_ref(), request.into_inner()).await + } +} diff --git a/src/tonic/api/raft_api.rs b/src/tonic/api/raft_api.rs new file mode 100644 index 0000000000000000000000000000000000000000..36be81d61216bca499efaeae16fc7785ba847bb2 --- /dev/null +++ b/src/tonic/api/raft_api.rs @@ -0,0 +1,139 @@ +use api::grpc::qdrant::raft_server::Raft; +use api::grpc::qdrant::{ + AddPeerToKnownMessage, AllPeers, Peer, PeerId, RaftMessage as RaftMessageBytes, Uri as UriStr, +}; +use itertools::Itertools; +use raft::eraftpb::Message as RaftMessage; +use 
storage::content_manager::consensus_manager::ConsensusStateRef; +use storage::content_manager::consensus_ops::ConsensusOperations; +use tokio::sync::mpsc::Sender; +use tonic::transport::Uri; +use tonic::{async_trait, Request, Response, Status}; + +use super::validate; +use crate::consensus; + +pub struct RaftService { + message_sender: Sender, + consensus_state: ConsensusStateRef, +} + +impl RaftService { + pub fn new(sender: Sender, consensus_state: ConsensusStateRef) -> Self { + Self { + message_sender: sender, + consensus_state, + } + } +} + +#[async_trait] +impl Raft for RaftService { + async fn send(&self, mut request: Request) -> Result, Status> { + let message = + ::decode(&request.get_mut().message[..]) + .map_err(|err| { + Status::invalid_argument(format!("Failed to parse raft message: {err}")) + })?; + self.message_sender + .send(consensus::Message::FromPeer(Box::new(message))) + .await + .map_err(|_| Status::internal("Can't send Raft message over channel"))?; + Ok(Response::new(())) + } + + async fn who_is( + &self, + request: tonic::Request, + ) -> Result, tonic::Status> { + let addresses = self.consensus_state.peer_address_by_id(); + let uri = addresses + .get(&request.get_ref().id) + .ok_or_else(|| Status::internal("Peer not found"))?; + Ok(Response::new(UriStr { + uri: uri.to_string(), + })) + } + + async fn add_peer_to_known( + &self, + request: tonic::Request, + ) -> Result, tonic::Status> { + validate(request.get_ref())?; + let peer = request.get_ref(); + let uri_string = if let Some(uri) = &peer.uri { + uri.clone() + } else { + let ip = request + .remote_addr() + .ok_or_else(|| { + Status::failed_precondition("Remote address unavailable due to the used IO") + })? + .ip(); + let port = peer + .port + .ok_or_else(|| Status::invalid_argument("URI or port should be supplied"))?; + format!("http://{ip}:{port}") + }; + let uri: Uri = uri_string + .parse() + .map_err(|err| Status::internal(format!("Failed to parse uri: {err}")))?; + let peer = request.into_inner(); + + // the consensus operation can take up to DEFAULT_META_OP_WAIT + self.consensus_state + .propose_consensus_op_with_await( + ConsensusOperations::AddPeer { + peer_id: peer.id, + uri: uri.to_string(), + }, + None, + ) + .await + .map_err(|err| Status::internal(format!("Failed to add peer: {err}")))?; + + let mut addresses = self.consensus_state.peer_address_by_id(); + + // Make sure that the new peer is now present in the known addresses + if !addresses.values().contains(&uri) { + return Err(Status::internal(format!( + "Failed to add peer after consensus: {uri}" + ))); + } + + let first_peer_id = self.consensus_state.first_voter(); + + // If `first_peer_id` is not present in the list of peers, it means it was removed from + // cluster at some point. + // + // Before Qdrant version 1.11.6 origin peer was not committed to consensus, so if it was + // removed from cluster, any node added to the cluster after this would not recognize it as + // being part of the cluster in the past and will end up with a broken consensus state. + // + // To prevent this, we add `first_peer_id` (with a fake URI) to the list of peers. + // + // `add_peer_to_known` is used to add new peers to the cluster, and so `first_peer_id` (and + // its fake URI) would be removed from new peer's state shortly, while it will be synchronizing + // and applying past Raft log. 
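// `or_default()` below only inserts a placeholder entry when `first_peer_id` is not
// already known; the placeholder `Uri::default()` is the "fake URI" referred to in
// the comment above.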
+ addresses.entry(first_peer_id).or_default(); + + Ok(Response::new(AllPeers { + all_peers: addresses + .into_iter() + .map(|(id, uri)| Peer { + id, + uri: uri.to_string(), + }) + .collect(), + first_peer_id, + })) + } + + // Left for compatibility - does nothing + async fn add_peer_as_participant( + &self, + _request: tonic::Request, + ) -> Result, tonic::Status> { + Ok(Response::new(())) + } +} diff --git a/src/tonic/api/snapshots_api.rs b/src/tonic/api/snapshots_api.rs new file mode 100644 index 0000000000000000000000000000000000000000..8318a94e9f148a5bee6a3e1257291048f91ff60d --- /dev/null +++ b/src/tonic/api/snapshots_api.rs @@ -0,0 +1,283 @@ +use std::sync::Arc; +use std::time::Instant; + +use api::grpc::qdrant::shard_snapshots_server::ShardSnapshots; +use api::grpc::qdrant::snapshots_server::Snapshots; +use api::grpc::qdrant::{ + CreateFullSnapshotRequest, CreateShardSnapshotRequest, CreateSnapshotRequest, + CreateSnapshotResponse, DeleteFullSnapshotRequest, DeleteShardSnapshotRequest, + DeleteSnapshotRequest, DeleteSnapshotResponse, ListFullSnapshotsRequest, + ListShardSnapshotsRequest, ListSnapshotsRequest, ListSnapshotsResponse, + RecoverShardSnapshotRequest, RecoverSnapshotResponse, +}; +use collection::operations::verification::new_unchecked_verification_pass; +use storage::content_manager::snapshots::{ + do_create_full_snapshot, do_delete_collection_snapshot, do_delete_full_snapshot, + do_list_full_snapshots, +}; +use storage::content_manager::toc::TableOfContent; +use storage::dispatcher::Dispatcher; +use tonic::{async_trait, Request, Response, Status}; + +use super::{validate, validate_and_log}; +use crate::common; +use crate::common::collections::{do_create_snapshot, do_list_snapshots}; +use crate::common::http_client::HttpClient; +use crate::tonic::auth::extract_access; + +pub struct SnapshotsService { + dispatcher: Arc, +} + +impl SnapshotsService { + pub fn new(dispatcher: Arc) -> Self { + Self { dispatcher } + } +} + +#[async_trait] +impl Snapshots for SnapshotsService { + async fn create( + &self, + mut request: Request, + ) -> Result, Status> { + validate(request.get_ref())?; + let access = extract_access(&mut request); + let collection_name = request.into_inner().collection_name; + let timing = Instant::now(); + let dispatcher = self.dispatcher.clone(); + + // Nothing to verify here. + let pass = new_unchecked_verification_pass(); + + let response = do_create_snapshot( + Arc::clone(dispatcher.toc(&access, &pass)), + access, + &collection_name, + ) + .await?; + + Ok(Response::new(CreateSnapshotResponse { + snapshot_description: Some(response.into()), + time: timing.elapsed().as_secs_f64(), + })) + } + + async fn list( + &self, + mut request: Request, + ) -> Result, Status> { + validate(request.get_ref())?; + + let timing = Instant::now(); + let access = extract_access(&mut request); + let ListSnapshotsRequest { collection_name } = request.into_inner(); + + // Nothing to verify here. 
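+        // (A verification pass is still required by `Dispatcher::toc()` to hand out the
+        // `TableOfContent`; listing snapshots itself has no strict-mode limits to enforce,
+        // hence the unchecked pass below.)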
+ let pass = new_unchecked_verification_pass(); + + let snapshots = do_list_snapshots( + self.dispatcher.toc(&access, &pass), + access, + &collection_name, + ) + .await?; + + Ok(Response::new(ListSnapshotsResponse { + snapshot_descriptions: snapshots.into_iter().map(|s| s.into()).collect(), + time: timing.elapsed().as_secs_f64(), + })) + } + + async fn delete( + &self, + mut request: Request, + ) -> Result, Status> { + validate(request.get_ref())?; + + let timing = Instant::now(); + let access = extract_access(&mut request); + let DeleteSnapshotRequest { + collection_name, + snapshot_name, + } = request.into_inner(); + + let _response = do_delete_collection_snapshot( + &self.dispatcher, + access, + &collection_name, + &snapshot_name, + ) + .await?; + + Ok(Response::new(DeleteSnapshotResponse { + time: timing.elapsed().as_secs_f64(), + })) + } + + async fn create_full( + &self, + mut request: Request, + ) -> Result, Status> { + validate(request.get_ref())?; + + let timing = Instant::now(); + let access = extract_access(&mut request); + + let response = do_create_full_snapshot(&self.dispatcher, access).await?; + + Ok(Response::new(CreateSnapshotResponse { + snapshot_description: Some(response.into()), + time: timing.elapsed().as_secs_f64(), + })) + } + + async fn list_full( + &self, + mut request: Request, + ) -> Result, Status> { + validate(request.get_ref())?; + let timing = Instant::now(); + let access = extract_access(&mut request); + + // Nothing to verify here. + let pass = new_unchecked_verification_pass(); + + let snapshots = do_list_full_snapshots(self.dispatcher.toc(&access, &pass), access).await?; + Ok(Response::new(ListSnapshotsResponse { + snapshot_descriptions: snapshots.into_iter().map(|s| s.into()).collect(), + time: timing.elapsed().as_secs_f64(), + })) + } + + async fn delete_full( + &self, + mut request: Request, + ) -> Result, Status> { + validate(request.get_ref())?; + + let timing = Instant::now(); + let access = extract_access(&mut request); + let snapshot_name = request.into_inner().snapshot_name; + + let _response = do_delete_full_snapshot(&self.dispatcher, access, &snapshot_name).await?; + + Ok(Response::new(DeleteSnapshotResponse { + time: timing.elapsed().as_secs_f64(), + })) + } +} + +pub struct ShardSnapshotsService { + toc: Arc, + http_client: HttpClient, +} + +impl ShardSnapshotsService { + pub fn new(toc: Arc, http_client: HttpClient) -> Self { + Self { toc, http_client } + } +} + +#[async_trait] +impl ShardSnapshots for ShardSnapshotsService { + async fn create( + &self, + mut request: Request, + ) -> Result, Status> { + let access = extract_access(&mut request); + let request = request.into_inner(); + validate_and_log(&request); + + let timing = Instant::now(); + + let snapshot_description = common::snapshots::create_shard_snapshot( + self.toc.clone(), + access, + request.collection_name, + request.shard_id, + ) + .await?; + + Ok(Response::new(CreateSnapshotResponse { + snapshot_description: Some(snapshot_description.into()), + time: timing.elapsed().as_secs_f64(), + })) + } + + async fn list( + &self, + mut request: Request, + ) -> Result, Status> { + let access = extract_access(&mut request); + let request = request.into_inner(); + validate_and_log(&request); + + let timing = Instant::now(); + + let snapshot_descriptions = common::snapshots::list_shard_snapshots( + self.toc.clone(), + access, + request.collection_name, + request.shard_id, + ) + .await?; + + Ok(Response::new(ListSnapshotsResponse { + snapshot_descriptions: 
snapshot_descriptions.into_iter().map(Into::into).collect(), + time: timing.elapsed().as_secs_f64(), + })) + } + + async fn delete( + &self, + mut request: Request, + ) -> Result, Status> { + let access = extract_access(&mut request); + let request = request.into_inner(); + validate_and_log(&request); + + let timing = Instant::now(); + + common::snapshots::delete_shard_snapshot( + self.toc.clone(), + access, + request.collection_name, + request.shard_id, + request.snapshot_name, + ) + .await?; + + Ok(Response::new(DeleteSnapshotResponse { + time: timing.elapsed().as_secs_f64(), + })) + } + + async fn recover( + &self, + mut request: Request, + ) -> Result, Status> { + let access = extract_access(&mut request); + let request = request.into_inner(); + validate_and_log(&request); + + let timing = Instant::now(); + + common::snapshots::recover_shard_snapshot( + self.toc.clone(), + access, + request.collection_name, + request.shard_id, + request.snapshot_location.try_into()?, + request.snapshot_priority.try_into()?, + request.checksum, + self.http_client.clone(), + request.api_key, + ) + .await?; + + Ok(Response::new(RecoverSnapshotResponse { + time: timing.elapsed().as_secs_f64(), + })) + } +} diff --git a/src/tonic/auth.rs b/src/tonic/auth.rs new file mode 100644 index 0000000000000000000000000000000000000000..c434d54e7310184f6d1eb335ca93c8bcd61728c6 --- /dev/null +++ b/src/tonic/auth.rs @@ -0,0 +1,93 @@ +use std::sync::Arc; +use std::task::{Context, Poll}; + +use futures::future::BoxFuture; +use storage::rbac::Access; +use tonic::body::BoxBody; +use tonic::Status; +use tower::{Layer, Service}; + +use crate::common::auth::{AuthError, AuthKeys}; + +type Request = tonic::codegen::http::Request; +type Response = tonic::codegen::http::Response; + +#[derive(Clone)] +pub struct AuthMiddleware { + auth_keys: Arc, + service: S, +} + +async fn check(auth_keys: Arc, mut req: Request) -> Result { + let access = auth_keys + .validate_request(|key| req.headers().get(key).and_then(|val| val.to_str().ok())) + .await + .map_err(|e| match e { + AuthError::Unauthorized(e) => Status::unauthenticated(e), + AuthError::Forbidden(e) => Status::permission_denied(e), + AuthError::StorageError(e) => Status::from(e), + })?; + + let previous = req.extensions_mut().insert::(access); + debug_assert!( + previous.is_none(), + "Previous access object should not exist in the request" + ); + + Ok(req) +} + +impl Service for AuthMiddleware +where + S: Service + Clone + Send + 'static, + S::Future: Send + 'static, +{ + type Response = S::Response; + type Error = S::Error; + type Future = BoxFuture<'static, Result>; + + fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll> { + self.service.poll_ready(cx) + } + + fn call(&mut self, request: Request) -> Self::Future { + let auth_keys = self.auth_keys.clone(); + let mut service = self.service.clone(); + Box::pin(async move { + match check(auth_keys, request).await { + Ok(req) => service.call(req).await, + Err(e) => Ok(e.to_http()), + } + }) + } +} + +#[derive(Clone)] +pub struct AuthLayer { + auth_keys: Arc, +} + +impl AuthLayer { + pub fn new(auth_keys: AuthKeys) -> Self { + Self { + auth_keys: Arc::new(auth_keys), + } + } +} + +impl Layer for AuthLayer { + type Service = AuthMiddleware; + + fn layer(&self, service: S) -> Self::Service { + Self::Service { + auth_keys: self.auth_keys.clone(), + service, + } + } +} + +pub fn extract_access(req: &mut tonic::Request) -> Access { + req.extensions_mut().remove::().unwrap_or_else(|| { + Access::full("All requests have full by default 
access when API key is not configured") + }) +} diff --git a/src/tonic/logging.rs b/src/tonic/logging.rs new file mode 100644 index 0000000000000000000000000000000000000000..c5441960a1d0c9a90f53cbaaf127de12e9fc9b29 --- /dev/null +++ b/src/tonic/logging.rs @@ -0,0 +1,113 @@ +use std::task::{Context, Poll}; + +use futures_util::future::BoxFuture; +use tonic::body::BoxBody; +use tonic::codegen::http::Response; +use tonic::Code; +use tower::Service; +use tower_layer::Layer; + +#[derive(Clone)] +pub struct LoggingMiddleware { + inner: T, +} + +#[derive(Clone)] +pub struct LoggingMiddlewareLayer; + +impl LoggingMiddlewareLayer { + pub fn new() -> Self { + Self {} + } +} + +impl Service> for LoggingMiddleware +where + S: Service, Response = Response> + + Clone, + S::Future: Send + 'static, +{ + type Response = S::Response; + type Error = S::Error; + type Future = BoxFuture<'static, Result>; + + fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll> { + self.inner.poll_ready(cx) + } + + fn call( + &mut self, + request: tonic::codegen::http::Request, + ) -> Self::Future { + let clone = self.inner.clone(); + let mut inner = std::mem::replace(&mut self.inner, clone); + + let method_name = request.uri().path().to_string(); + let instant = std::time::Instant::now(); + let future = inner.call(request); + Box::pin(async move { + let response = future.await; + let elapsed_sec = instant.elapsed().as_secs_f32(); + match response { + Err(error) => { + log::error!("gGRPC request error {}", method_name); + Err(error) + } + Ok(response_tonic) => { + let grpc_status = tonic::Status::from_header_map(response_tonic.headers()); + if let Some(grpc_status) = grpc_status { + match grpc_status.code() { + Code::Ok => { + log::trace!("gRPC {} Ok {:.6}", method_name, elapsed_sec); + } + Code::Cancelled => { + // cluster mode generates a large amount of `stream error received: stream no longer needed` + log::trace!("gRPC {} {:.6}", method_name, elapsed_sec); + } + Code::DeadlineExceeded + | Code::Aborted + | Code::OutOfRange + | Code::ResourceExhausted + | Code::NotFound + | Code::InvalidArgument + | Code::AlreadyExists + | Code::FailedPrecondition + | Code::PermissionDenied + | Code::Unauthenticated => { + log::info!( + "gRPC {} failed with {} {:?} {:.6}", + method_name, + grpc_status.code(), + grpc_status.message(), + elapsed_sec, + ); + } + Code::Internal + | Code::Unimplemented + | Code::Unavailable + | Code::DataLoss + | Code::Unknown => log::error!( + "gRPC {} unexpectedly failed with {} {:?} {:.6}", + method_name, + grpc_status.code(), + grpc_status.message(), + elapsed_sec, + ), + }; + } else { + log::trace!("gRPC {} Ok {:.6}", method_name, elapsed_sec); + } + Ok(response_tonic) + } + } + }) + } +} + +impl Layer for LoggingMiddlewareLayer { + type Service = LoggingMiddleware; + + fn layer(&self, service: S) -> Self::Service { + LoggingMiddleware { inner: service } + } +} diff --git a/src/tonic/mod.rs b/src/tonic/mod.rs new file mode 100644 index 0000000000000000000000000000000000000000..45df86b4eda7714f7c59e1f871ed6c4d02564514 --- /dev/null +++ b/src/tonic/mod.rs @@ -0,0 +1,360 @@ +mod api; +mod auth; +mod logging; +mod tonic_telemetry; +pub(super) mod verification; + +use std::io; +use std::net::{IpAddr, SocketAddr}; +use std::sync::Arc; +use std::time::Duration; + +use ::api::grpc::grpc_health_v1::health_check_response::ServingStatus; +use ::api::grpc::grpc_health_v1::health_server::{Health, HealthServer}; +use ::api::grpc::grpc_health_v1::{ + HealthCheckRequest as ProtocolHealthCheckRequest, + 
HealthCheckResponse as ProtocolHealthCheckResponse, +}; +use ::api::grpc::qdrant::collections_internal_server::CollectionsInternalServer; +use ::api::grpc::qdrant::collections_server::CollectionsServer; +use ::api::grpc::qdrant::points_internal_server::PointsInternalServer; +use ::api::grpc::qdrant::points_server::PointsServer; +use ::api::grpc::qdrant::qdrant_internal_server::{QdrantInternal, QdrantInternalServer}; +use ::api::grpc::qdrant::qdrant_server::{Qdrant, QdrantServer}; +use ::api::grpc::qdrant::shard_snapshots_server::ShardSnapshotsServer; +use ::api::grpc::qdrant::snapshots_server::SnapshotsServer; +use ::api::grpc::qdrant::{ + GetConsensusCommitRequest, GetConsensusCommitResponse, HealthCheckReply, HealthCheckRequest, + WaitOnConsensusCommitRequest, WaitOnConsensusCommitResponse, +}; +use ::api::grpc::QDRANT_DESCRIPTOR_SET; +use ::api::rest::models::VersionInfo; +use collection::operations::verification::new_unchecked_verification_pass; +use storage::content_manager::consensus_manager::ConsensusStateRef; +use storage::content_manager::toc::TableOfContent; +use storage::dispatcher::Dispatcher; +use storage::rbac::Access; +use tokio::runtime::Handle; +use tokio::signal; +use tonic::codec::CompressionEncoding; +use tonic::transport::{Server, ServerTlsConfig}; +use tonic::{Request, Response, Status}; + +use crate::common::auth::AuthKeys; +use crate::common::helpers; +use crate::common::http_client::HttpClient; +use crate::common::telemetry_ops::requests_telemetry::TonicTelemetryCollector; +use crate::settings::Settings; +use crate::tonic::api::collections_api::CollectionsService; +use crate::tonic::api::collections_internal_api::CollectionsInternalService; +use crate::tonic::api::points_api::PointsService; +use crate::tonic::api::points_internal_api::PointsInternalService; +use crate::tonic::api::snapshots_api::{ShardSnapshotsService, SnapshotsService}; + +#[derive(Default)] +pub struct QdrantService {} + +#[tonic::async_trait] +impl Qdrant for QdrantService { + async fn health_check( + &self, + _request: Request, + ) -> Result, Status> { + Ok(Response::new(VersionInfo::default().into())) + } +} + +// Additional health check service that follows gRPC health check protocol as described in #2614 +#[derive(Default)] +pub struct HealthService {} + +#[tonic::async_trait] +impl Health for HealthService { + async fn check( + &self, + _request: Request, + ) -> Result, Status> { + let response = ProtocolHealthCheckResponse { + status: ServingStatus::Serving as i32, + }; + + Ok(Response::new(response)) + } +} + +pub struct QdrantInternalService { + /// Qdrant settings + settings: Settings, + /// Consensus state + consensus_state: ConsensusStateRef, +} + +impl QdrantInternalService { + fn new(settings: Settings, consensus_state: ConsensusStateRef) -> Self { + Self { + settings, + consensus_state, + } + } +} + +#[tonic::async_trait] +impl QdrantInternal for QdrantInternalService { + async fn get_consensus_commit( + &self, + _: tonic::Request, + ) -> Result, Status> { + let persistent = self.consensus_state.persistent.read(); + let commit = persistent.state.hard_state.commit as _; + let term = persistent.state.hard_state.term as _; + Ok(Response::new(GetConsensusCommitResponse { commit, term })) + } + + async fn wait_on_consensus_commit( + &self, + request: Request, + ) -> Result, Status> { + let request = request.into_inner(); + let commit = request.commit as u64; + let term = request.term as u64; + let timeout = Duration::from_secs(request.timeout as u64); + let consensus_tick = 
Duration::from_millis(self.settings.cluster.consensus.tick_period_ms); + let ok = self + .consensus_state + .wait_for_consensus_commit(commit, term, consensus_tick, timeout) + .await + .is_ok(); + Ok(Response::new(WaitOnConsensusCommitResponse { ok })) + } +} + +#[cfg(not(unix))] +async fn wait_stop_signal(for_what: &str) { + signal::ctrl_c().await.unwrap(); + log::debug!("Stopping {for_what} on SIGINT"); +} + +#[cfg(unix)] +async fn wait_stop_signal(for_what: &str) { + let mut term = signal::unix::signal(signal::unix::SignalKind::terminate()).unwrap(); + let mut inrt = signal::unix::signal(signal::unix::SignalKind::interrupt()).unwrap(); + + tokio::select! { + _ = term.recv() => log::debug!("Stopping {for_what} on SIGTERM"), + _ = inrt.recv() => log::debug!("Stopping {for_what} on SIGINT"), + } +} + +pub fn init( + dispatcher: Arc, + telemetry_collector: Arc>, + settings: Settings, + grpc_port: u16, + runtime: Handle, +) -> io::Result<()> { + runtime.block_on(async { + let socket = + SocketAddr::from((settings.service.host.parse::().unwrap(), grpc_port)); + + let qdrant_service = QdrantService::default(); + let health_service = HealthService::default(); + let collections_service = CollectionsService::new(dispatcher.clone()); + let points_service = PointsService::new(dispatcher.clone(), settings.service.clone()); + let snapshot_service = SnapshotsService::new(dispatcher.clone()); + + // Only advertise the public services. By default, all services in QDRANT_DESCRIPTOR_SET + // will be advertised, so explicitly list the services to be included. + let reflection_service = tonic_reflection::server::Builder::configure() + .register_encoded_file_descriptor_set(QDRANT_DESCRIPTOR_SET) + .with_service_name("qdrant.Collections") + .with_service_name("qdrant.Points") + .with_service_name("qdrant.Snapshots") + .with_service_name("qdrant.Qdrant") + .with_service_name("grpc.health.v1.Health") + .build() + .unwrap(); + + log::info!("Qdrant gRPC listening on {}", grpc_port); + + let mut server = Server::builder(); + + if settings.service.enable_tls { + log::info!("TLS enabled for gRPC API (TTL not supported)"); + + let tls_server_config = helpers::load_tls_external_server_config(settings.tls()?)?; + + server = server + .tls_config(tls_server_config) + .map_err(helpers::tonic_error_to_io_error)?; + } else { + log::info!("TLS disabled for gRPC API"); + } + + // The stack of middleware that our service will be wrapped in + let middleware_layer = tower::ServiceBuilder::new() + .layer(logging::LoggingMiddlewareLayer::new()) + .layer(tonic_telemetry::TonicTelemetryLayer::new( + telemetry_collector, + )) + .option_layer({ + AuthKeys::try_create( + &settings.service, + dispatcher + .toc( + &Access::full("For tonic auth middleware"), + &new_unchecked_verification_pass(), + ) + .clone(), + ) + .map(auth::AuthLayer::new) + }) + .into_inner(); + + server + .layer(middleware_layer) + .add_service(reflection_service) + .add_service( + QdrantServer::new(qdrant_service) + .send_compressed(CompressionEncoding::Gzip) + .accept_compressed(CompressionEncoding::Gzip) + .max_decoding_message_size(usize::MAX), + ) + .add_service( + CollectionsServer::new(collections_service) + .send_compressed(CompressionEncoding::Gzip) + .accept_compressed(CompressionEncoding::Gzip) + .max_decoding_message_size(usize::MAX), + ) + .add_service( + PointsServer::new(points_service) + .send_compressed(CompressionEncoding::Gzip) + .accept_compressed(CompressionEncoding::Gzip) + .max_decoding_message_size(usize::MAX), + ) + .add_service( + 
SnapshotsServer::new(snapshot_service) + .send_compressed(CompressionEncoding::Gzip) + .accept_compressed(CompressionEncoding::Gzip) + .max_decoding_message_size(usize::MAX), + ) + .add_service( + HealthServer::new(health_service) + .send_compressed(CompressionEncoding::Gzip) + .accept_compressed(CompressionEncoding::Gzip) + .max_decoding_message_size(usize::MAX), + ) + .serve_with_shutdown(socket, async { + wait_stop_signal("gRPC service").await; + }) + .await + .map_err(helpers::tonic_error_to_io_error) + })?; + + Ok(()) +} + +#[allow(clippy::too_many_arguments)] +pub fn init_internal( + toc: Arc, + consensus_state: ConsensusStateRef, + telemetry_collector: Arc>, + settings: Settings, + host: String, + internal_grpc_port: u16, + tls_config: Option, + to_consensus: tokio::sync::mpsc::Sender, + runtime: Handle, +) -> std::io::Result<()> { + use ::api::grpc::qdrant::raft_server::RaftServer; + + use crate::tonic::api::raft_api::RaftService; + + let http_client = HttpClient::from_settings(&settings)?; + + runtime + .block_on(async { + let socket = SocketAddr::from((host.parse::().unwrap(), internal_grpc_port)); + + let qdrant_service = QdrantService::default(); + let points_internal_service = + PointsInternalService::new(toc.clone(), settings.service.clone()); + let qdrant_internal_service = + QdrantInternalService::new(settings, consensus_state.clone()); + let collections_internal_service = CollectionsInternalService::new(toc.clone()); + let shard_snapshots_service = ShardSnapshotsService::new(toc.clone(), http_client); + let raft_service = RaftService::new(to_consensus, consensus_state); + + log::debug!("Qdrant internal gRPC listening on {}", internal_grpc_port); + + let mut server = Server::builder() + // Internally use a high limit for pending accept streams. + // We can have a huge number of reset/dropped HTTP2 streams in our internal + // communication when there are a lot of clients dropping connections. This + // internally causes an GOAWAY/ENHANCE_YOUR_CALM error breaking cluster consensus. + // We prefer to keep more pending reset streams even though this may be expensive, + // versus an internal error that is very hard to handle. 
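+            // (ENHANCE_YOUR_CALM is the HTTP/2 error code a peer puts into a GOAWAY frame when
+            // it considers the connection's behaviour excessive, for example when too many
+            // streams are reset in a short period of time.)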
+ // More info: + .http2_max_pending_accept_reset_streams(Some(1024)); + + if let Some(config) = tls_config { + log::info!("TLS enabled for internal gRPC API (TTL not supported)"); + + server = server.tls_config(config)?; + } else { + log::info!("TLS disabled for internal gRPC API"); + }; + + // The stack of middleware that our service will be wrapped in + let middleware_layer = tower::ServiceBuilder::new() + .layer(logging::LoggingMiddlewareLayer::new()) + .layer(tonic_telemetry::TonicTelemetryLayer::new( + telemetry_collector, + )) + .into_inner(); + + server + .layer(middleware_layer) + .add_service( + QdrantServer::new(qdrant_service) + .send_compressed(CompressionEncoding::Gzip) + .accept_compressed(CompressionEncoding::Gzip) + .max_decoding_message_size(usize::MAX), + ) + .add_service( + QdrantInternalServer::new(qdrant_internal_service) + .send_compressed(CompressionEncoding::Gzip) + .accept_compressed(CompressionEncoding::Gzip) + .max_decoding_message_size(usize::MAX), + ) + .add_service( + CollectionsInternalServer::new(collections_internal_service) + .send_compressed(CompressionEncoding::Gzip) + .accept_compressed(CompressionEncoding::Gzip) + .max_decoding_message_size(usize::MAX), + ) + .add_service( + PointsInternalServer::new(points_internal_service) + .send_compressed(CompressionEncoding::Gzip) + .accept_compressed(CompressionEncoding::Gzip) + .max_decoding_message_size(usize::MAX), + ) + .add_service( + ShardSnapshotsServer::new(shard_snapshots_service) + .send_compressed(CompressionEncoding::Gzip) + .accept_compressed(CompressionEncoding::Gzip) + .max_decoding_message_size(usize::MAX), + ) + .add_service( + RaftServer::new(raft_service) + .send_compressed(CompressionEncoding::Gzip) + .accept_compressed(CompressionEncoding::Gzip) + .max_decoding_message_size(usize::MAX), + ) + .serve_with_shutdown(socket, async { + wait_stop_signal("internal gRPC").await; + }) + .await + }) + .unwrap(); + Ok(()) +} diff --git a/src/tonic/tonic_telemetry.rs b/src/tonic/tonic_telemetry.rs new file mode 100644 index 0000000000000000000000000000000000000000..239abebcb7486d773de990226d4784b14ee08e44 --- /dev/null +++ b/src/tonic/tonic_telemetry.rs @@ -0,0 +1,74 @@ +use std::sync::Arc; +use std::task::{Context, Poll}; + +use futures_util::future::BoxFuture; +use tower::Service; +use tower_layer::Layer; + +use crate::common::telemetry_ops::requests_telemetry::{ + TonicTelemetryCollector, TonicWorkerTelemetryCollector, +}; + +#[derive(Clone)] +pub struct TonicTelemetryService { + service: T, + telemetry_data: Arc>, +} + +#[derive(Clone)] +pub struct TonicTelemetryLayer { + telemetry_collector: Arc>, +} + +impl Service> for TonicTelemetryService +where + S: Service>, + S::Future: Send + 'static, +{ + type Response = S::Response; + type Error = S::Error; + type Future = BoxFuture<'static, Result>; + + fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll> { + self.service.poll_ready(cx) + } + + fn call( + &mut self, + request: tonic::codegen::http::Request, + ) -> Self::Future { + let method_name = request.uri().path().to_string(); + let future = self.service.call(request); + let telemetry_data = self.telemetry_data.clone(); + Box::pin(async move { + let instant = std::time::Instant::now(); + let response = future.await?; + telemetry_data.lock().add_response(method_name, instant); + Ok(response) + }) + } +} + +impl TonicTelemetryLayer { + pub fn new( + telemetry_collector: Arc>, + ) -> TonicTelemetryLayer { + Self { + telemetry_collector, + } + } +} + +impl Layer for TonicTelemetryLayer { + type 
Service = TonicTelemetryService; + + fn layer(&self, service: S) -> Self::Service { + TonicTelemetryService { + service, + telemetry_data: self + .telemetry_collector + .lock() + .create_grpc_telemetry_collector(), + } + } +} diff --git a/src/tonic/verification.rs b/src/tonic/verification.rs new file mode 100644 index 0000000000000000000000000000000000000000..2382a51e83e1e9f0eaa22809161334259c4ba40c --- /dev/null +++ b/src/tonic/verification.rs @@ -0,0 +1,118 @@ +use std::sync::Arc; + +use collection::operations::verification::StrictModeVerification; +use storage::content_manager::collection_verification::{ + check_strict_mode, check_strict_mode_batch, +}; +use storage::content_manager::toc::TableOfContent; +use storage::dispatcher::Dispatcher; +use storage::rbac::Access; +use tonic::Status; + +/// Trait for different ways of providing something with `toc` that may do additional checks eg. for Strict mode. +pub trait CheckedTocProvider { + async fn check_strict_mode<'b>( + &'b self, + request: &impl StrictModeVerification, + collection_name: &str, + timeout: Option, + access: &Access, + ) -> Result<&'b Arc, Status>; + + async fn check_strict_mode_batch<'b, I, R>( + &'b self, + requests: &[I], + conv: impl Fn(&I) -> &R, + collection_name: &str, + timeout: Option, + access: &Access, + ) -> Result<&'b Arc, Status> + where + R: StrictModeVerification; +} + +/// Simple provider for TableOfContent that doesn't do any checks. +pub struct UncheckedTocProvider<'a> { + toc: &'a Arc, +} + +impl<'a> UncheckedTocProvider<'a> { + pub fn new_unchecked(toc: &'a Arc) -> Self { + Self { toc } + } +} + +impl<'a> CheckedTocProvider for UncheckedTocProvider<'a> { + async fn check_strict_mode<'b>( + &'b self, + _request: &impl StrictModeVerification, + _collection_name: &str, + _timeout: Option, + _access: &Access, + ) -> Result<&'b Arc, Status> { + // No checks here + Ok(self.toc) + } + + async fn check_strict_mode_batch<'b, I, R>( + &'b self, + _requests: &[I], + _conv: impl Fn(&I) -> &R, + _collection_name: &str, + _timeout: Option, + _access: &Access, + ) -> Result<&'b Arc, Status> + where + R: StrictModeVerification, + { + // No checks here + Ok(self.toc) + } +} + +/// Provider for TableOfContent that requires Strict mode to be checked. 
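+///
+/// A rough usage sketch from a handler's point of view (the `dispatcher`, `request`,
+/// `collection_name`, `timeout` and `access` bindings are illustrative, not taken from a
+/// specific call site):
+///
+/// ```ignore
+/// let provider = StrictModeCheckedTocProvider::new(&dispatcher);
+/// let toc = provider
+///     .check_strict_mode(&request, &collection_name, timeout, &access)
+///     .await?;
+/// // `toc` is then used exactly like the one returned by `UncheckedTocProvider`, but the
+/// // collection's strict-mode limits have already been enforced for this request.
+/// ```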
+pub struct StrictModeCheckedTocProvider<'a> { + dispatcher: &'a Dispatcher, +} + +impl<'a> StrictModeCheckedTocProvider<'a> { + pub fn new(dispatcher: &'a Dispatcher) -> Self { + Self { dispatcher } + } +} + +impl<'a> CheckedTocProvider for StrictModeCheckedTocProvider<'a> { + async fn check_strict_mode( + &self, + request: &impl StrictModeVerification, + collection_name: &str, + timeout: Option, + access: &Access, + ) -> Result<&Arc, Status> { + let pass = + check_strict_mode(request, timeout, collection_name, self.dispatcher, access).await?; + Ok(self.dispatcher.toc(access, &pass)) + } + + async fn check_strict_mode_batch<'b, I, R>( + &'b self, + requests: &[I], + conv: impl Fn(&I) -> &R, + collection_name: &str, + timeout: Option, + access: &Access, + ) -> Result<&'b Arc, Status> + where + R: StrictModeVerification, + { + let pass = check_strict_mode_batch( + requests.iter().map(conv), + timeout, + collection_name, + self.dispatcher, + access, + ) + .await?; + Ok(self.dispatcher.toc(access, &pass)) + } +} diff --git a/src/tracing/config.rs b/src/tracing/config.rs new file mode 100644 index 0000000000000000000000000000000000000000..478e66d3254d012407a0a3b962a17d470cec0d01 --- /dev/null +++ b/src/tracing/config.rs @@ -0,0 +1,91 @@ +use std::collections::HashSet; + +use serde::{Deserialize, Serialize}; +use tracing_subscriber::fmt; + +use super::*; + +#[derive(Clone, Debug, Default, Eq, PartialEq, Deserialize, Serialize)] +#[serde(default)] +pub struct LoggerConfig { + #[serde(flatten)] + pub default: default::Config, + #[serde(default)] + pub on_disk: on_disk::Config, +} + +impl LoggerConfig { + pub fn with_top_level_directive(&self, log_level: Option) -> Self { + let mut logger_config = self.clone(); + + if logger_config.default.log_level.is_some() && log_level.is_some() { + eprintln!( + "Both top-level `log_level` and `logger.log_level` config directives are used. \ + `logger.log_level` takes priority, so top-level `log_level` will be ignored." 
+ ); + } + + if logger_config.default.log_level.is_none() { + logger_config.default.log_level = log_level; + } + + logger_config + } + + pub fn merge(&mut self, other: Self) { + self.default.merge(other.default); + self.on_disk.merge(other.on_disk); + } +} + +#[derive(Copy, Clone, Debug, PartialEq, Eq, Deserialize, Serialize, Hash)] +#[serde(rename_all = "lowercase")] +pub enum SpanEvent { + New, + Enter, + Exit, + Close, +} + +impl SpanEvent { + pub fn unwrap_or_default_config(events: &Option>) -> fmt::format::FmtSpan { + Self::into_fmt_span(events.as_ref().unwrap_or(&HashSet::new()).iter().copied()) + } + + pub fn into_fmt_span(events: impl IntoIterator) -> fmt::format::FmtSpan { + events + .into_iter() + .fold(fmt::format::FmtSpan::NONE, |events, event| { + events | event.into() + }) + } +} + +impl From for fmt::format::FmtSpan { + fn from(event: SpanEvent) -> Self { + match event { + SpanEvent::New => fmt::format::FmtSpan::NEW, + SpanEvent::Enter => fmt::format::FmtSpan::ENTER, + SpanEvent::Exit => fmt::format::FmtSpan::EXIT, + SpanEvent::Close => fmt::format::FmtSpan::CLOSE, + } + } +} + +#[derive(Copy, Clone, Debug, Eq, PartialEq, Deserialize, Serialize, Default)] +#[serde(rename_all = "lowercase")] +pub enum Color { + #[default] + Auto, + #[serde(untagged)] + Explicit(bool), +} + +impl Color { + pub fn to_bool(self) -> bool { + match self { + Self::Auto => colored::control::SHOULD_COLORIZE.should_colorize(), + Self::Explicit(bool) => bool, + } + } +} diff --git a/src/tracing/default.rs b/src/tracing/default.rs new file mode 100644 index 0000000000000000000000000000000000000000..eef31444a1b0d24dc36e72bd386a9028a56f625d --- /dev/null +++ b/src/tracing/default.rs @@ -0,0 +1,61 @@ +use std::collections::HashSet; + +use common::ext::OptionExt; +use serde::{Deserialize, Serialize}; +use tracing_subscriber::prelude::*; +use tracing_subscriber::{filter, fmt, registry}; + +use super::*; + +#[derive(Clone, Debug, Default, Eq, PartialEq, Deserialize, Serialize)] +#[serde(default)] +pub struct Config { + pub log_level: Option, + pub span_events: Option>, + pub color: Option, +} + +impl Config { + pub fn merge(&mut self, other: Self) { + let Self { + log_level, + span_events, + color, + } = other; + + self.log_level.replace_if_some(log_level); + self.span_events.replace_if_some(span_events); + self.color.replace_if_some(color); + } +} + +#[rustfmt::skip] // `rustfmt` formats this into unreadable single line +pub type Logger = filter::Filtered< + Option>, + filter::EnvFilter, + S, +>; + +pub fn new_logger(config: &Config) -> Logger +where + S: tracing::Subscriber + for<'span> registry::LookupSpan<'span>, +{ + let layer = new_layer(config); + let filter = new_filter(config); + Some(layer).with_filter(filter) +} + +pub fn new_layer(config: &Config) -> fmt::Layer +where + S: tracing::Subscriber + for<'span> registry::LookupSpan<'span>, +{ + fmt::Layer::default() + .with_span_events(config::SpanEvent::unwrap_or_default_config( + &config.span_events, + )) + .with_ansi(config.color.unwrap_or_default().to_bool()) +} + +pub fn new_filter(config: &Config) -> filter::EnvFilter { + filter(config.log_level.as_deref().unwrap_or("")) +} diff --git a/src/tracing/handle.rs b/src/tracing/handle.rs new file mode 100644 index 0000000000000000000000000000000000000000..375fe572202123d5533db95ddc3587651e89e903 --- /dev/null +++ b/src/tracing/handle.rs @@ -0,0 +1,94 @@ +use std::sync::Arc; + +use tokio::sync::RwLock; +use tracing_subscriber::{layer, reload, Registry}; + +use super::*; + +#[derive(Clone)] +pub struct 
LoggerHandle { + config: Arc>, + default: DefaultLoggerReloadHandle, + on_disk: OnDiskLoggerReloadHandle, +} + +#[rustfmt::skip] // `rustfmt` formats this into unreadable single line +type DefaultLoggerReloadHandle = reload::Handle< + default::Logger, + S, +>; + +#[rustfmt::skip] // `rustfmt` formats this into unreadable single line +type DefaultLoggerSubscriber = layer::Layered< + reload::Layer, S>, + S, +>; + +#[rustfmt::skip] // `rustfmt` formats this into unreadable single line +type OnDiskLoggerReloadHandle = reload::Handle< + on_disk::Logger, + S, +>; + +impl LoggerHandle { + pub fn new( + config: config::LoggerConfig, + default: DefaultLoggerReloadHandle, + on_disk: OnDiskLoggerReloadHandle, + ) -> Self { + Self { + config: Arc::new(RwLock::new(config)), + default, + on_disk, + } + } + + pub async fn get_config(&self) -> config::LoggerConfig { + self.config.read().await.clone() + } + + pub async fn update_config(&self, new_config: config::LoggerConfig) -> anyhow::Result<()> { + let mut config = self.config.write().await; + + // `tracing-subscriber` does not support `reload`ing `Filtered` layers, so we *have to* use + // `modify`. However, `modify` would *deadlock* if provided closure logs anything or produce + // any `tracing` event. + // + // So, we structure `update_config` to only do an absolute minimum of changes and only use + // the most trivial operations during `modify`, to guarantee we won't deadlock. + // + // See: + // - https://docs.rs/tracing-subscriber/latest/tracing_subscriber/reload/struct.Handle.html#method.reload + // - https://github.com/tokio-rs/tracing/issues/1629 + // - https://github.com/tokio-rs/tracing/pull/2657 + + let mut merged_config = config.clone(); + merged_config.merge(new_config); + + if merged_config.on_disk != config.on_disk { + let new_layer = on_disk::new_layer(&merged_config.on_disk)?; + let new_filter = on_disk::new_filter(&merged_config.on_disk); + + self.on_disk.modify(move |logger| { + *logger.inner_mut() = new_layer; + *logger.filter_mut() = new_filter; + })?; + + config.on_disk = merged_config.on_disk; + } + + if merged_config.default != config.default { + let new_layer = default::new_layer(&merged_config.default); + let new_filter = default::new_filter(&merged_config.default); + + self.default.modify(move |logger| { + *logger.inner_mut() = Some(new_layer); + *logger.filter_mut() = new_filter; + })?; + + config.default = merged_config.default; + } + + Ok(()) + } +} diff --git a/src/tracing/mod.rs b/src/tracing/mod.rs new file mode 100644 index 0000000000000000000000000000000000000000..e96a2099d627b5df7b5ca66db73c579c8bee9ce8 --- /dev/null +++ b/src/tracing/mod.rs @@ -0,0 +1,110 @@ +#![allow(dead_code)] // `schema_generator` binary target produce warnings + +pub mod config; +pub mod default; +pub mod handle; +pub mod on_disk; + +#[cfg(test)] +mod test; + +use std::fmt::Write as _; +use std::str::FromStr as _; + +use tracing_subscriber::prelude::*; +use tracing_subscriber::{filter, reload}; + +pub use self::config::LoggerConfig; +pub use self::handle::LoggerHandle; + +const DEFAULT_LOG_LEVEL: log::LevelFilter = log::LevelFilter::Info; + +const DEFAULT_FILTERS: &[(&str, log::LevelFilter)] = &[ + ("hyper", log::LevelFilter::Info), + ("h2", log::LevelFilter::Error), + ("tower", log::LevelFilter::Warn), + ("rustls", log::LevelFilter::Info), + ("wal", log::LevelFilter::Warn), + ("raft", log::LevelFilter::Warn), +]; + +pub fn setup(mut config: config::LoggerConfig) -> anyhow::Result { + // Note that on-disk logger *have* to be initialized 
*before* default logger! + // + // If default logger is initialized before on-disk logger, then ANSI escape-sequences (that are + // used to apply color and formatting in the terminal, but looks like corrupted text in the text + // editor) might appear in the on-disk log-file. + // + // This happens because when multiple `fmt::Layer`s are initialized in the same subscriber, + // the top-level `fmt::Layer` would cache pre-formatted fragments of the log-line + // for the next `fmt::Layer`s to reuse. + // + // And default logger outputs colored log-lines, which on-disk logger reuse even if colors are + // disabled for the on-disk logger. :/ + + let on_disk_logger = on_disk::new_logger(&mut config.on_disk); + let (on_disk_logger, on_disk_logger_handle) = reload::Layer::new(on_disk_logger); + let reg = tracing_subscriber::registry().with(on_disk_logger); + + let default_logger = default::new_logger(&config.default); + let (default_logger, default_logger_handle) = reload::Layer::new(default_logger); + let reg = reg.with(default_logger); + + let logger_handle = LoggerHandle::new(config, default_logger_handle, on_disk_logger_handle); + + // Use `console` or `console-subscriber` feature to enable `console-subscriber` + // + // Note, that `console-subscriber` requires manually enabling + // `--cfg tokio_unstable` rust flags during compilation! + // + // Otherwise `console_subscriber::spawn` call panics! + // + // See https://docs.rs/tokio/latest/tokio/#unstable-features + #[cfg(all(feature = "console-subscriber", tokio_unstable))] + let reg = reg.with(console_subscriber::spawn()); + + #[cfg(all(feature = "console-subscriber", not(tokio_unstable)))] + eprintln!( + "`console-subscriber` requires manually enabling \ + `--cfg tokio_unstable` rust flags during compilation!" 
+ ); + + // Use `tracy` or `tracing-tracy` feature to enable `tracing-tracy` + #[cfg(feature = "tracing-tracy")] + let reg = reg.with( + tracing_tracy::TracyLayer::new(tracing_tracy::DefaultConfig::default()).with_filter( + tracing_subscriber::filter::filter_fn(|metadata| metadata.is_span()), + ), + ); + + tracing::subscriber::set_global_default(reg)?; + tracing_log::LogTracer::init()?; + + Ok(logger_handle) +} + +fn filter(user_filters: &str) -> filter::EnvFilter { + let mut filter = String::new(); + + let user_log_level = user_filters + .rsplit(',') + .find_map(|dir| log::LevelFilter::from_str(dir).ok()); + + if user_log_level.is_none() { + write!(&mut filter, "{DEFAULT_LOG_LEVEL}").unwrap(); // Writing into `String` never fails + } + + for &(target, log_level) in DEFAULT_FILTERS { + if user_log_level.unwrap_or(DEFAULT_LOG_LEVEL) > log_level { + let comma = if filter.is_empty() { "" } else { "," }; + write!(&mut filter, "{comma}{target}={log_level}").unwrap(); // Writing into `String` never fails + } + } + + let comma = if filter.is_empty() { "" } else { "," }; + write!(&mut filter, "{comma}{user_filters}").unwrap(); // Writing into `String` never fails + + filter::EnvFilter::builder() + .with_regex(false) + .parse_lossy(filter) +} diff --git a/src/tracing/on_disk.rs b/src/tracing/on_disk.rs new file mode 100644 index 0000000000000000000000000000000000000000..1c29bfa4e2340ef21ab246b488f2ee5ca7710d7c --- /dev/null +++ b/src/tracing/on_disk.rs @@ -0,0 +1,106 @@ +use std::collections::HashSet; +use std::sync::Mutex; +use std::{fs, io}; + +use anyhow::Context as _; +use common::ext::OptionExt; +use serde::{Deserialize, Serialize}; +use tracing_subscriber::prelude::*; +use tracing_subscriber::{filter, fmt, registry}; + +use super::*; + +#[derive(Clone, Debug, Default, Eq, PartialEq, Deserialize, Serialize)] +#[serde(default)] +pub struct Config { + pub enabled: Option, + pub log_file: Option, + pub log_level: Option, + pub span_events: Option>, +} + +impl Config { + pub fn merge(&mut self, other: Self) { + let Self { + enabled, + log_file, + log_level, + span_events, + } = other; + + self.enabled.replace_if_some(enabled); + self.log_file.replace_if_some(log_file); + self.log_level.replace_if_some(log_level); + self.span_events.replace_if_some(span_events); + } +} + +#[rustfmt::skip] // `rustfmt` formats this into unreadable single line :/ +pub type Logger = filter::Filtered< + Option>, + filter::EnvFilter, + S, +>; + +#[rustfmt::skip] // `rustfmt` formats this into unreadable single line :/ +pub type Layer = fmt::Layer< + S, + fmt::format::DefaultFields, + fmt::format::Format, + MakeWriter, +>; + +pub type MakeWriter = Mutex>; + +pub fn new_logger(config: &mut Config) -> Logger +where + S: tracing::Subscriber + for<'span> registry::LookupSpan<'span>, +{ + let layer = match new_layer(config) { + Ok(layer) => layer, + Err(err) => { + eprintln!( + "failed to enable logging into {} log-file: {err}", + config.log_file.as_deref().unwrap_or(""), + ); + + config.enabled = Some(false); + None + } + }; + + let filter = new_filter(config); + layer.with_filter(filter) +} + +pub fn new_layer(config: &Config) -> anyhow::Result>> +where + S: tracing::Subscriber + for<'span> registry::LookupSpan<'span>, +{ + if !config.enabled.unwrap_or_default() { + return Ok(None); + } + + let Some(log_file) = &config.log_file else { + return Err(anyhow::format_err!("log file is not specified")); + }; + + let writer = fs::OpenOptions::new() + .create(true) + .append(true) + .open(log_file) + .with_context(|| 
format!("failed to open {log_file} log-file"))?; + + let layer = fmt::Layer::default() + .with_writer(Mutex::new(io::BufWriter::new(writer))) + .with_span_events(config::SpanEvent::unwrap_or_default_config( + &config.span_events, + )) + .with_ansi(false); + + Ok(Some(layer)) +} + +pub fn new_filter(config: &Config) -> filter::EnvFilter { + filter(config.log_level.as_deref().unwrap_or("")) +} diff --git a/src/tracing/test.rs b/src/tracing/test.rs new file mode 100644 index 0000000000000000000000000000000000000000..072ec6cb4bb1e8af912a5b55273ea51b85cf2298 --- /dev/null +++ b/src/tracing/test.rs @@ -0,0 +1,81 @@ +use std::collections::HashSet; + +use serde_json::json; + +use super::*; + +#[test] +fn deseriailze_logger_config() { + let json = json!({ + "log_level": "debug", + "span_events": ["new", "close"], + "color": true, + + "on_disk": { + "enabled": true, + "log_file": "/logs/qdrant", + "log_level": "tracing", + "span_events": ["new", "close"], + } + }); + + let config = deserialize_config(json); + + let expected = LoggerConfig { + default: default::Config { + log_level: Some("debug".into()), + span_events: Some(HashSet::from([ + config::SpanEvent::New, + config::SpanEvent::Close, + ])), + color: Some(config::Color::Explicit(true)), + }, + + on_disk: on_disk::Config { + enabled: Some(true), + log_file: Some("/logs/qdrant".into()), + log_level: Some("tracing".into()), + span_events: Some(HashSet::from([ + config::SpanEvent::New, + config::SpanEvent::Close, + ])), + }, + }; + + assert_eq!(config, expected); +} + +#[test] +fn deserialize_empty_config() { + let config = deserialize_config(json!({})); + assert_eq!(config, LoggerConfig::default()); +} + +#[test] +fn deserialize_config_with_empty_on_disk() { + let config = deserialize_config(json!({ "on_disk": {} })); + assert_eq!(config, LoggerConfig::default()); +} + +#[test] +fn deseriailze_config_with_explicit_nulls() { + let json = json!({ + "log_level": null, + "span_events": null, + "color": null, + + "on_disk": { + "enabled": null, + "log_file": null, + "log_level": null, + "span_events": null, + } + }); + + let config = deserialize_config(json); + assert_eq!(config, LoggerConfig::default()); +} + +fn deserialize_config(json: serde_json::Value) -> LoggerConfig { + serde_json::from_value(json).unwrap() +} diff --git a/src/wal_inspector.rs b/src/wal_inspector.rs new file mode 100644 index 0000000000000000000000000000000000000000..072e7b8816f34ed5eea1c12cce2fdb179ca991ce --- /dev/null +++ b/src/wal_inspector.rs @@ -0,0 +1,83 @@ +use std::env; +use std::path::Path; + +use collection::operations::OperationWithClockTag; +use collection::wal::SerdeWal; +use storage::content_manager::consensus::consensus_wal::ConsensusOpWal; +use storage::content_manager::consensus_ops::ConsensusOperations; +use wal::WalOptions; + +/// Executable to inspect the content of a write ahead log folder (collection OR consensus WAL). 
+/// e.g: +/// `cargo run --bin wal_inspector storage/collections/test-collection/0/wal/ collection` +/// `cargo run --bin wal_inspector -- storage/node4/wal/ consensus` (expects `collections_meta_wal` folder as first child) +fn main() { + let args: Vec = env::args().collect(); + let wal_path = Path::new(&args[1]); + let wal_type = args[2].as_str(); + match wal_type { + "collection" => print_collection_wal(wal_path), + "consensus" => print_consensus_wal(wal_path), + _ => eprintln!("Unknown wal type: {wal_type}"), + } +} + +fn print_consensus_wal(wal_path: &Path) { + // must live within a folder named `collections_meta_wal` + let wal = ConsensusOpWal::new(wal_path.to_str().unwrap()); + println!("=========================="); + let first_index = wal.first_entry().unwrap(); + println!("First entry: {first_index:?}"); + let last_index = wal.last_entry().unwrap(); + println!("Last entry: {last_index:?}"); + println!( + "Offset of first entry: {:?}", + wal.index_offset().unwrap().wal_to_raft_offset + ); + let entries = wal + .entries( + first_index.map(|f| f.index).unwrap_or(1), + last_index.map(|f| f.index).unwrap_or(0) + 1, + None, + ) + .unwrap(); + for entry in entries { + println!("=========================="); + let command = ConsensusOperations::try_from(&entry); + let data = match command { + Ok(command) => format!("{command:?}"), + Err(_) => format!("{:?}", entry.data), + }; + println!( + "Entry ID:{}\nterm:{}\nentry_type:{}\ndata:{:?}", + entry.index, entry.term, entry.entry_type, data + ) + } +} + +fn print_collection_wal(wal_path: &Path) { + let wal: Result, _> = + SerdeWal::new(wal_path.to_str().unwrap(), WalOptions::default()); + + match wal { + Err(error) => { + eprintln!("Unable to open write ahead log in directory {wal_path:?}: {error}."); + } + Ok(wal) => { + // print all entries + let mut count = 0; + for (idx, op) in wal.read_all(false) { + println!("=========================="); + println!("Entry: {idx}"); + println!("Operation: {:?}", op.operation); + if let Some(clock_tag) = op.clock_tag { + println!("Clock: {clock_tag:?}"); + } + count += 1; + } + println!("=========================="); + println!("End of WAL."); + println!("Found {count} entries."); + } + } +}
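+
+// Illustrative output of `print_collection_wal` (entry numbers and payloads are made up; the
+// shape follows the `println!` calls above):
+//
+//   ==========================
+//   Entry: 7
+//   Operation: <debug representation of the stored operation>
+//   Clock: <printed only if the operation carries a clock tag>
+//   ==========================
+//   End of WAL.
+//   Found 8 entries.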