Spaces:
Build error
Build error
Gouzi Mohaled
commited on
Commit
·
d8435ba
1
Parent(s):
3932407
Ajout du dossier src
Browse filesThis view is limited to 50 files because it contains too many changes.
See raw diff
- src/actix/actix_telemetry.rs +90 -0
- src/actix/api/cluster_api.rs +189 -0
- src/actix/api/collections_api.rs +256 -0
- src/actix/api/count_api.rs +69 -0
- src/actix/api/debug_api.rs +36 -0
- src/actix/api/discovery_api.rs +140 -0
- src/actix/api/facet_api.rs +77 -0
- src/actix/api/issues_api.rs +32 -0
- src/actix/api/local_shard_api.rs +267 -0
- src/actix/api/mod.rs +46 -0
- src/actix/api/query_api.rs +232 -0
- src/actix/api/read_params.rs +118 -0
- src/actix/api/recommend_api.rs +235 -0
- src/actix/api/retrieve_api.rs +200 -0
- src/actix/api/search_api.rs +333 -0
- src/actix/api/service_api.rs +217 -0
- src/actix/api/shards_api.rs +80 -0
- src/actix/api/snapshot_api.rs +585 -0
- src/actix/api/update_api.rs +392 -0
- src/actix/auth.rs +160 -0
- src/actix/certificate_helpers.rs +203 -0
- src/actix/helpers.rs +179 -0
- src/actix/mod.rs +262 -0
- src/actix/web_ui.rs +115 -0
- src/common/auth/claims.rs +69 -0
- src/common/auth/jwt_parser.rs +155 -0
- src/common/auth/mod.rs +165 -0
- src/common/collections.rs +834 -0
- src/common/debugger.rs +90 -0
- src/common/error_reporting.rs +31 -0
- src/common/health.rs +372 -0
- src/common/helpers.rs +151 -0
- src/common/http_client.rs +156 -0
- src/common/inference/batch_processing.rs +370 -0
- src/common/inference/batch_processing_grpc.rs +281 -0
- src/common/inference/config.rs +23 -0
- src/common/inference/infer_processing.rs +72 -0
- src/common/inference/mod.rs +8 -0
- src/common/inference/query_requests_grpc.rs +535 -0
- src/common/inference/query_requests_rest.rs +415 -0
- src/common/inference/service.rs +266 -0
- src/common/inference/update_requests.rs +409 -0
- src/common/metrics.rs +505 -0
- src/common/mod.rs +31 -0
- src/common/points.rs +1175 -0
- src/common/pyroscope_state.rs +93 -0
- src/common/snapshots.rs +284 -0
- src/common/stacktrace.rs +86 -0
- src/common/strings.rs +5 -0
- src/common/telemetry.rs +101 -0
src/actix/actix_telemetry.rs
ADDED
@@ -0,0 +1,90 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
use std::future::{ready, Ready};
|
2 |
+
use std::sync::Arc;
|
3 |
+
|
4 |
+
use actix_web::dev::{Service, ServiceRequest, ServiceResponse, Transform};
|
5 |
+
use actix_web::Error;
|
6 |
+
use futures_util::future::LocalBoxFuture;
|
7 |
+
use parking_lot::Mutex;
|
8 |
+
|
9 |
+
use crate::common::telemetry_ops::requests_telemetry::{
|
10 |
+
ActixTelemetryCollector, ActixWorkerTelemetryCollector,
|
11 |
+
};
|
12 |
+
|
13 |
+
pub struct ActixTelemetryService<S> {
|
14 |
+
service: S,
|
15 |
+
telemetry_data: Arc<Mutex<ActixWorkerTelemetryCollector>>,
|
16 |
+
}
|
17 |
+
|
18 |
+
pub struct ActixTelemetryTransform {
|
19 |
+
telemetry_collector: Arc<Mutex<ActixTelemetryCollector>>,
|
20 |
+
}
|
21 |
+
|
22 |
+
/// Actix telemetry service. It hooks every request and looks into response status code.
|
23 |
+
///
|
24 |
+
/// More about actix service with similar example
|
25 |
+
/// <https://actix.rs/docs/middleware/>
|
26 |
+
impl<S, B> Service<ServiceRequest> for ActixTelemetryService<S>
|
27 |
+
where
|
28 |
+
S: Service<ServiceRequest, Response = ServiceResponse<B>, Error = Error>,
|
29 |
+
S::Future: 'static,
|
30 |
+
B: 'static,
|
31 |
+
{
|
32 |
+
type Response = ServiceResponse<B>;
|
33 |
+
type Error = Error;
|
34 |
+
type Future = LocalBoxFuture<'static, Result<Self::Response, Self::Error>>;
|
35 |
+
|
36 |
+
actix_web::dev::forward_ready!(service);
|
37 |
+
|
38 |
+
fn call(&self, request: ServiceRequest) -> Self::Future {
|
39 |
+
let match_pattern = request
|
40 |
+
.match_pattern()
|
41 |
+
.unwrap_or_else(|| "unknown".to_owned());
|
42 |
+
let request_key = format!("{} {}", request.method(), match_pattern);
|
43 |
+
let future = self.service.call(request);
|
44 |
+
let telemetry_data = self.telemetry_data.clone();
|
45 |
+
Box::pin(async move {
|
46 |
+
let instant = std::time::Instant::now();
|
47 |
+
let response = future.await?;
|
48 |
+
let status = response.response().status().as_u16();
|
49 |
+
telemetry_data
|
50 |
+
.lock()
|
51 |
+
.add_response(request_key, status, instant);
|
52 |
+
Ok(response)
|
53 |
+
})
|
54 |
+
}
|
55 |
+
}
|
56 |
+
|
57 |
+
impl ActixTelemetryTransform {
|
58 |
+
pub fn new(telemetry_collector: Arc<Mutex<ActixTelemetryCollector>>) -> Self {
|
59 |
+
Self {
|
60 |
+
telemetry_collector,
|
61 |
+
}
|
62 |
+
}
|
63 |
+
}
|
64 |
+
|
65 |
+
/// Actix telemetry transform. It's a builder for an actix service
|
66 |
+
///
|
67 |
+
/// More about actix transform with similar example
|
68 |
+
/// <https://actix.rs/docs/middleware/>
|
69 |
+
impl<S, B> Transform<S, ServiceRequest> for ActixTelemetryTransform
|
70 |
+
where
|
71 |
+
S: Service<ServiceRequest, Response = ServiceResponse<B>, Error = Error> + 'static,
|
72 |
+
S::Future: 'static,
|
73 |
+
B: 'static,
|
74 |
+
{
|
75 |
+
type Response = ServiceResponse<B>;
|
76 |
+
type Error = Error;
|
77 |
+
type Transform = ActixTelemetryService<S>;
|
78 |
+
type InitError = ();
|
79 |
+
type Future = Ready<Result<Self::Transform, Self::InitError>>;
|
80 |
+
|
81 |
+
fn new_transform(&self, service: S) -> Self::Future {
|
82 |
+
ready(Ok(ActixTelemetryService {
|
83 |
+
service,
|
84 |
+
telemetry_data: self
|
85 |
+
.telemetry_collector
|
86 |
+
.lock()
|
87 |
+
.create_web_worker_telemetry(),
|
88 |
+
}))
|
89 |
+
}
|
90 |
+
}
|
src/actix/api/cluster_api.rs
ADDED
@@ -0,0 +1,189 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
use std::future::Future;
|
2 |
+
|
3 |
+
use actix_web::{delete, get, post, put, web, HttpResponse};
|
4 |
+
use actix_web_validator::Query;
|
5 |
+
use collection::operations::verification::new_unchecked_verification_pass;
|
6 |
+
use schemars::JsonSchema;
|
7 |
+
use serde::{Deserialize, Serialize};
|
8 |
+
use storage::content_manager::consensus_ops::ConsensusOperations;
|
9 |
+
use storage::content_manager::errors::StorageError;
|
10 |
+
use storage::dispatcher::Dispatcher;
|
11 |
+
use storage::rbac::AccessRequirements;
|
12 |
+
use validator::Validate;
|
13 |
+
|
14 |
+
use crate::actix::auth::ActixAccess;
|
15 |
+
use crate::actix::helpers;
|
16 |
+
|
17 |
+
#[derive(Debug, Deserialize, Validate)]
|
18 |
+
struct QueryParams {
|
19 |
+
#[serde(default)]
|
20 |
+
force: bool,
|
21 |
+
#[serde(default)]
|
22 |
+
#[validate(range(min = 1))]
|
23 |
+
timeout: Option<u64>,
|
24 |
+
}
|
25 |
+
|
26 |
+
#[derive(Deserialize, Serialize, JsonSchema, Validate)]
|
27 |
+
pub struct MetadataParams {
|
28 |
+
#[serde(default)]
|
29 |
+
pub wait: bool,
|
30 |
+
}
|
31 |
+
|
32 |
+
#[get("/cluster")]
|
33 |
+
fn cluster_status(
|
34 |
+
dispatcher: web::Data<Dispatcher>,
|
35 |
+
ActixAccess(access): ActixAccess,
|
36 |
+
) -> impl Future<Output = HttpResponse> {
|
37 |
+
helpers::time(async move {
|
38 |
+
access.check_global_access(AccessRequirements::new())?;
|
39 |
+
Ok(dispatcher.cluster_status())
|
40 |
+
})
|
41 |
+
}
|
42 |
+
|
43 |
+
#[post("/cluster/recover")]
|
44 |
+
fn recover_current_peer(
|
45 |
+
dispatcher: web::Data<Dispatcher>,
|
46 |
+
ActixAccess(access): ActixAccess,
|
47 |
+
) -> impl Future<Output = HttpResponse> {
|
48 |
+
// Not a collection level request.
|
49 |
+
let pass = new_unchecked_verification_pass();
|
50 |
+
|
51 |
+
helpers::time(async move {
|
52 |
+
access.check_global_access(AccessRequirements::new().manage())?;
|
53 |
+
dispatcher.toc(&access, &pass).request_snapshot()?;
|
54 |
+
Ok(true)
|
55 |
+
})
|
56 |
+
}
|
57 |
+
|
58 |
+
#[delete("/cluster/peer/{peer_id}")]
|
59 |
+
fn remove_peer(
|
60 |
+
dispatcher: web::Data<Dispatcher>,
|
61 |
+
peer_id: web::Path<u64>,
|
62 |
+
Query(params): Query<QueryParams>,
|
63 |
+
ActixAccess(access): ActixAccess,
|
64 |
+
) -> impl Future<Output = HttpResponse> {
|
65 |
+
// Not a collection level request.
|
66 |
+
let pass = new_unchecked_verification_pass();
|
67 |
+
|
68 |
+
helpers::time(async move {
|
69 |
+
access.check_global_access(AccessRequirements::new().manage())?;
|
70 |
+
|
71 |
+
let dispatcher = dispatcher.into_inner();
|
72 |
+
let toc = dispatcher.toc(&access, &pass);
|
73 |
+
let peer_id = peer_id.into_inner();
|
74 |
+
|
75 |
+
let has_shards = toc.peer_has_shards(peer_id).await;
|
76 |
+
if !params.force && has_shards {
|
77 |
+
return Err(StorageError::BadRequest {
|
78 |
+
description: format!("Cannot remove peer {peer_id} as there are shards on it"),
|
79 |
+
});
|
80 |
+
}
|
81 |
+
|
82 |
+
match dispatcher.consensus_state() {
|
83 |
+
Some(consensus_state) => {
|
84 |
+
consensus_state
|
85 |
+
.propose_consensus_op_with_await(
|
86 |
+
ConsensusOperations::RemovePeer(peer_id),
|
87 |
+
params.timeout.map(std::time::Duration::from_secs),
|
88 |
+
)
|
89 |
+
.await
|
90 |
+
}
|
91 |
+
None => Err(StorageError::BadRequest {
|
92 |
+
description: "Distributed mode disabled.".to_string(),
|
93 |
+
}),
|
94 |
+
}
|
95 |
+
})
|
96 |
+
}
|
97 |
+
|
98 |
+
#[get("/cluster/metadata/keys")]
|
99 |
+
async fn get_cluster_metadata_keys(
|
100 |
+
dispatcher: web::Data<Dispatcher>,
|
101 |
+
ActixAccess(access): ActixAccess,
|
102 |
+
) -> HttpResponse {
|
103 |
+
helpers::time(async move {
|
104 |
+
access.check_global_access(AccessRequirements::new())?;
|
105 |
+
|
106 |
+
let keys = dispatcher
|
107 |
+
.consensus_state()
|
108 |
+
.ok_or_else(|| StorageError::service_error("Qdrant is running in standalone mode"))?
|
109 |
+
.persistent
|
110 |
+
.read()
|
111 |
+
.get_cluster_metadata_keys();
|
112 |
+
|
113 |
+
Ok(keys)
|
114 |
+
})
|
115 |
+
.await
|
116 |
+
}
|
117 |
+
|
118 |
+
#[get("/cluster/metadata/keys/{key}")]
|
119 |
+
async fn get_cluster_metadata_key(
|
120 |
+
dispatcher: web::Data<Dispatcher>,
|
121 |
+
ActixAccess(access): ActixAccess,
|
122 |
+
key: web::Path<String>,
|
123 |
+
) -> HttpResponse {
|
124 |
+
helpers::time(async move {
|
125 |
+
access.check_global_access(AccessRequirements::new())?;
|
126 |
+
|
127 |
+
let value = dispatcher
|
128 |
+
.consensus_state()
|
129 |
+
.ok_or_else(|| StorageError::service_error("Qdrant is running in standalone mode"))?
|
130 |
+
.persistent
|
131 |
+
.read()
|
132 |
+
.get_cluster_metadata_key(key.as_ref());
|
133 |
+
|
134 |
+
Ok(value)
|
135 |
+
})
|
136 |
+
.await
|
137 |
+
}
|
138 |
+
|
139 |
+
#[put("/cluster/metadata/keys/{key}")]
|
140 |
+
async fn update_cluster_metadata_key(
|
141 |
+
dispatcher: web::Data<Dispatcher>,
|
142 |
+
ActixAccess(access): ActixAccess,
|
143 |
+
key: web::Path<String>,
|
144 |
+
params: Query<MetadataParams>,
|
145 |
+
value: web::Json<serde_json::Value>,
|
146 |
+
) -> HttpResponse {
|
147 |
+
// Not a collection level request.
|
148 |
+
let pass = new_unchecked_verification_pass();
|
149 |
+
helpers::time(async move {
|
150 |
+
let toc = dispatcher.toc(&access, &pass);
|
151 |
+
access.check_global_access(AccessRequirements::new().write())?;
|
152 |
+
|
153 |
+
toc.update_cluster_metadata(key.into_inner(), value.into_inner(), params.wait)
|
154 |
+
.await?;
|
155 |
+
Ok(true)
|
156 |
+
})
|
157 |
+
.await
|
158 |
+
}
|
159 |
+
|
160 |
+
#[delete("/cluster/metadata/keys/{key}")]
|
161 |
+
async fn delete_cluster_metadata_key(
|
162 |
+
dispatcher: web::Data<Dispatcher>,
|
163 |
+
ActixAccess(access): ActixAccess,
|
164 |
+
key: web::Path<String>,
|
165 |
+
params: Query<MetadataParams>,
|
166 |
+
) -> HttpResponse {
|
167 |
+
// Not a collection level request.
|
168 |
+
let pass = new_unchecked_verification_pass();
|
169 |
+
helpers::time(async move {
|
170 |
+
let toc = dispatcher.toc(&access, &pass);
|
171 |
+
access.check_global_access(AccessRequirements::new().write())?;
|
172 |
+
|
173 |
+
toc.update_cluster_metadata(key.into_inner(), serde_json::Value::Null, params.wait)
|
174 |
+
.await?;
|
175 |
+
Ok(true)
|
176 |
+
})
|
177 |
+
.await
|
178 |
+
}
|
179 |
+
|
180 |
+
// Configure services
|
181 |
+
pub fn config_cluster_api(cfg: &mut web::ServiceConfig) {
|
182 |
+
cfg.service(cluster_status)
|
183 |
+
.service(remove_peer)
|
184 |
+
.service(recover_current_peer)
|
185 |
+
.service(get_cluster_metadata_keys)
|
186 |
+
.service(get_cluster_metadata_key)
|
187 |
+
.service(update_cluster_metadata_key)
|
188 |
+
.service(delete_cluster_metadata_key);
|
189 |
+
}
|
src/actix/api/collections_api.rs
ADDED
@@ -0,0 +1,256 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
use std::time::Duration;
|
2 |
+
|
3 |
+
use actix_web::rt::time::Instant;
|
4 |
+
use actix_web::{delete, get, patch, post, put, web, HttpResponse, Responder};
|
5 |
+
use actix_web_validator::{Json, Path, Query};
|
6 |
+
use collection::operations::cluster_ops::ClusterOperations;
|
7 |
+
use collection::operations::verification::new_unchecked_verification_pass;
|
8 |
+
use serde::Deserialize;
|
9 |
+
use storage::content_manager::collection_meta_ops::{
|
10 |
+
ChangeAliasesOperation, CollectionMetaOperations, CreateCollection, CreateCollectionOperation,
|
11 |
+
DeleteCollectionOperation, UpdateCollection, UpdateCollectionOperation,
|
12 |
+
};
|
13 |
+
use storage::dispatcher::Dispatcher;
|
14 |
+
use validator::Validate;
|
15 |
+
|
16 |
+
use super::CollectionPath;
|
17 |
+
use crate::actix::api::StrictCollectionPath;
|
18 |
+
use crate::actix::auth::ActixAccess;
|
19 |
+
use crate::actix::helpers::{self, process_response};
|
20 |
+
use crate::common::collections::*;
|
21 |
+
|
22 |
+
#[derive(Debug, Deserialize, Validate)]
|
23 |
+
pub struct WaitTimeout {
|
24 |
+
#[validate(range(min = 1))]
|
25 |
+
timeout: Option<u64>,
|
26 |
+
}
|
27 |
+
|
28 |
+
impl WaitTimeout {
|
29 |
+
pub fn timeout(&self) -> Option<Duration> {
|
30 |
+
self.timeout.map(Duration::from_secs)
|
31 |
+
}
|
32 |
+
}
|
33 |
+
|
34 |
+
#[get("/collections")]
|
35 |
+
async fn get_collections(
|
36 |
+
dispatcher: web::Data<Dispatcher>,
|
37 |
+
ActixAccess(access): ActixAccess,
|
38 |
+
) -> HttpResponse {
|
39 |
+
// No request to verify
|
40 |
+
let pass = new_unchecked_verification_pass();
|
41 |
+
|
42 |
+
helpers::time(do_list_collections(dispatcher.toc(&access, &pass), access)).await
|
43 |
+
}
|
44 |
+
|
45 |
+
#[get("/aliases")]
|
46 |
+
async fn get_aliases(
|
47 |
+
dispatcher: web::Data<Dispatcher>,
|
48 |
+
ActixAccess(access): ActixAccess,
|
49 |
+
) -> HttpResponse {
|
50 |
+
// No request to verify
|
51 |
+
let pass = new_unchecked_verification_pass();
|
52 |
+
|
53 |
+
helpers::time(do_list_aliases(dispatcher.toc(&access, &pass), access)).await
|
54 |
+
}
|
55 |
+
|
56 |
+
#[get("/collections/{name}")]
|
57 |
+
async fn get_collection(
|
58 |
+
dispatcher: web::Data<Dispatcher>,
|
59 |
+
collection: Path<CollectionPath>,
|
60 |
+
ActixAccess(access): ActixAccess,
|
61 |
+
) -> HttpResponse {
|
62 |
+
// No request to verify
|
63 |
+
let pass = new_unchecked_verification_pass();
|
64 |
+
|
65 |
+
helpers::time(do_get_collection(
|
66 |
+
dispatcher.toc(&access, &pass),
|
67 |
+
access,
|
68 |
+
&collection.name,
|
69 |
+
None,
|
70 |
+
))
|
71 |
+
.await
|
72 |
+
}
|
73 |
+
|
74 |
+
#[get("/collections/{name}/exists")]
|
75 |
+
async fn get_collection_existence(
|
76 |
+
dispatcher: web::Data<Dispatcher>,
|
77 |
+
collection: Path<CollectionPath>,
|
78 |
+
ActixAccess(access): ActixAccess,
|
79 |
+
) -> HttpResponse {
|
80 |
+
// No request to verify
|
81 |
+
let pass = new_unchecked_verification_pass();
|
82 |
+
|
83 |
+
helpers::time(do_collection_exists(
|
84 |
+
dispatcher.toc(&access, &pass),
|
85 |
+
access,
|
86 |
+
&collection.name,
|
87 |
+
))
|
88 |
+
.await
|
89 |
+
}
|
90 |
+
|
91 |
+
#[get("/collections/{name}/aliases")]
|
92 |
+
async fn get_collection_aliases(
|
93 |
+
dispatcher: web::Data<Dispatcher>,
|
94 |
+
collection: Path<CollectionPath>,
|
95 |
+
ActixAccess(access): ActixAccess,
|
96 |
+
) -> HttpResponse {
|
97 |
+
// No request to verify
|
98 |
+
let pass = new_unchecked_verification_pass();
|
99 |
+
|
100 |
+
helpers::time(do_list_collection_aliases(
|
101 |
+
dispatcher.toc(&access, &pass),
|
102 |
+
access,
|
103 |
+
&collection.name,
|
104 |
+
))
|
105 |
+
.await
|
106 |
+
}
|
107 |
+
|
108 |
+
#[put("/collections/{name}")]
|
109 |
+
async fn create_collection(
|
110 |
+
dispatcher: web::Data<Dispatcher>,
|
111 |
+
collection: Path<StrictCollectionPath>,
|
112 |
+
operation: Json<CreateCollection>,
|
113 |
+
Query(query): Query<WaitTimeout>,
|
114 |
+
ActixAccess(access): ActixAccess,
|
115 |
+
) -> HttpResponse {
|
116 |
+
helpers::time(dispatcher.submit_collection_meta_op(
|
117 |
+
CollectionMetaOperations::CreateCollection(CreateCollectionOperation::new(
|
118 |
+
collection.name.clone(),
|
119 |
+
operation.into_inner(),
|
120 |
+
)),
|
121 |
+
access,
|
122 |
+
query.timeout(),
|
123 |
+
))
|
124 |
+
.await
|
125 |
+
}
|
126 |
+
|
127 |
+
#[patch("/collections/{name}")]
|
128 |
+
async fn update_collection(
|
129 |
+
dispatcher: web::Data<Dispatcher>,
|
130 |
+
collection: Path<CollectionPath>,
|
131 |
+
operation: Json<UpdateCollection>,
|
132 |
+
Query(query): Query<WaitTimeout>,
|
133 |
+
ActixAccess(access): ActixAccess,
|
134 |
+
) -> impl Responder {
|
135 |
+
let timing = Instant::now();
|
136 |
+
let name = collection.name.clone();
|
137 |
+
let response = dispatcher
|
138 |
+
.submit_collection_meta_op(
|
139 |
+
CollectionMetaOperations::UpdateCollection(UpdateCollectionOperation::new(
|
140 |
+
name,
|
141 |
+
operation.into_inner(),
|
142 |
+
)),
|
143 |
+
access,
|
144 |
+
query.timeout(),
|
145 |
+
)
|
146 |
+
.await;
|
147 |
+
process_response(response, timing, None)
|
148 |
+
}
|
149 |
+
|
150 |
+
#[delete("/collections/{name}")]
|
151 |
+
async fn delete_collection(
|
152 |
+
dispatcher: web::Data<Dispatcher>,
|
153 |
+
collection: Path<CollectionPath>,
|
154 |
+
Query(query): Query<WaitTimeout>,
|
155 |
+
ActixAccess(access): ActixAccess,
|
156 |
+
) -> impl Responder {
|
157 |
+
let timing = Instant::now();
|
158 |
+
let response = dispatcher
|
159 |
+
.submit_collection_meta_op(
|
160 |
+
CollectionMetaOperations::DeleteCollection(DeleteCollectionOperation(
|
161 |
+
collection.name.clone(),
|
162 |
+
)),
|
163 |
+
access,
|
164 |
+
query.timeout(),
|
165 |
+
)
|
166 |
+
.await;
|
167 |
+
process_response(response, timing, None)
|
168 |
+
}
|
169 |
+
|
170 |
+
#[post("/collections/aliases")]
|
171 |
+
async fn update_aliases(
|
172 |
+
dispatcher: web::Data<Dispatcher>,
|
173 |
+
operation: Json<ChangeAliasesOperation>,
|
174 |
+
Query(query): Query<WaitTimeout>,
|
175 |
+
ActixAccess(access): ActixAccess,
|
176 |
+
) -> impl Responder {
|
177 |
+
let timing = Instant::now();
|
178 |
+
let response = dispatcher
|
179 |
+
.submit_collection_meta_op(
|
180 |
+
CollectionMetaOperations::ChangeAliases(operation.0),
|
181 |
+
access,
|
182 |
+
query.timeout(),
|
183 |
+
)
|
184 |
+
.await;
|
185 |
+
process_response(response, timing, None)
|
186 |
+
}
|
187 |
+
|
188 |
+
#[get("/collections/{name}/cluster")]
|
189 |
+
async fn get_cluster_info(
|
190 |
+
dispatcher: web::Data<Dispatcher>,
|
191 |
+
collection: Path<CollectionPath>,
|
192 |
+
ActixAccess(access): ActixAccess,
|
193 |
+
) -> impl Responder {
|
194 |
+
// No request to verify
|
195 |
+
let pass = new_unchecked_verification_pass();
|
196 |
+
|
197 |
+
helpers::time(do_get_collection_cluster(
|
198 |
+
dispatcher.toc(&access, &pass),
|
199 |
+
access,
|
200 |
+
&collection.name,
|
201 |
+
))
|
202 |
+
.await
|
203 |
+
}
|
204 |
+
|
205 |
+
#[post("/collections/{name}/cluster")]
|
206 |
+
async fn update_collection_cluster(
|
207 |
+
dispatcher: web::Data<Dispatcher>,
|
208 |
+
collection: Path<CollectionPath>,
|
209 |
+
operation: Json<ClusterOperations>,
|
210 |
+
Query(query): Query<WaitTimeout>,
|
211 |
+
ActixAccess(access): ActixAccess,
|
212 |
+
) -> impl Responder {
|
213 |
+
let timing = Instant::now();
|
214 |
+
let wait_timeout = query.timeout();
|
215 |
+
let response = do_update_collection_cluster(
|
216 |
+
&dispatcher.into_inner(),
|
217 |
+
collection.name.clone(),
|
218 |
+
operation.0,
|
219 |
+
access,
|
220 |
+
wait_timeout,
|
221 |
+
)
|
222 |
+
.await;
|
223 |
+
process_response(response, timing, None)
|
224 |
+
}
|
225 |
+
|
226 |
+
// Configure services
|
227 |
+
pub fn config_collections_api(cfg: &mut web::ServiceConfig) {
|
228 |
+
// Ordering of services is important for correct path pattern matching
|
229 |
+
// See: <https://github.com/qdrant/qdrant/issues/3543>
|
230 |
+
cfg.service(update_aliases)
|
231 |
+
.service(get_collections)
|
232 |
+
.service(get_collection)
|
233 |
+
.service(get_collection_existence)
|
234 |
+
.service(create_collection)
|
235 |
+
.service(update_collection)
|
236 |
+
.service(delete_collection)
|
237 |
+
.service(get_aliases)
|
238 |
+
.service(get_collection_aliases)
|
239 |
+
.service(get_cluster_info)
|
240 |
+
.service(update_collection_cluster);
|
241 |
+
}
|
242 |
+
|
243 |
+
#[cfg(test)]
|
244 |
+
mod tests {
|
245 |
+
use actix_web::web::Query;
|
246 |
+
|
247 |
+
use super::WaitTimeout;
|
248 |
+
|
249 |
+
#[test]
|
250 |
+
fn timeout_is_deserialized() {
|
251 |
+
let timeout: WaitTimeout = Query::from_query("").unwrap().0;
|
252 |
+
assert!(timeout.timeout.is_none());
|
253 |
+
let timeout: WaitTimeout = Query::from_query("timeout=10").unwrap().0;
|
254 |
+
assert_eq!(timeout.timeout, Some(10))
|
255 |
+
}
|
256 |
+
}
|
src/actix/api/count_api.rs
ADDED
@@ -0,0 +1,69 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
use actix_web::{post, web, Responder};
|
2 |
+
use actix_web_validator::{Json, Path, Query};
|
3 |
+
use collection::operations::shard_selector_internal::ShardSelectorInternal;
|
4 |
+
use collection::operations::types::CountRequest;
|
5 |
+
use storage::content_manager::collection_verification::check_strict_mode;
|
6 |
+
use storage::dispatcher::Dispatcher;
|
7 |
+
use tokio::time::Instant;
|
8 |
+
|
9 |
+
use super::CollectionPath;
|
10 |
+
use crate::actix::api::read_params::ReadParams;
|
11 |
+
use crate::actix::auth::ActixAccess;
|
12 |
+
use crate::actix::helpers::{self, get_request_hardware_counter, process_response_error};
|
13 |
+
use crate::common::points::do_count_points;
|
14 |
+
use crate::settings::ServiceConfig;
|
15 |
+
|
16 |
+
#[post("/collections/{name}/points/count")]
|
17 |
+
async fn count_points(
|
18 |
+
dispatcher: web::Data<Dispatcher>,
|
19 |
+
collection: Path<CollectionPath>,
|
20 |
+
request: Json<CountRequest>,
|
21 |
+
params: Query<ReadParams>,
|
22 |
+
service_config: web::Data<ServiceConfig>,
|
23 |
+
ActixAccess(access): ActixAccess,
|
24 |
+
) -> impl Responder {
|
25 |
+
let CountRequest {
|
26 |
+
count_request,
|
27 |
+
shard_key,
|
28 |
+
} = request.into_inner();
|
29 |
+
|
30 |
+
let pass = match check_strict_mode(
|
31 |
+
&count_request,
|
32 |
+
params.timeout_as_secs(),
|
33 |
+
&collection.name,
|
34 |
+
&dispatcher,
|
35 |
+
&access,
|
36 |
+
)
|
37 |
+
.await
|
38 |
+
{
|
39 |
+
Ok(pass) => pass,
|
40 |
+
Err(err) => return process_response_error(err, Instant::now(), None),
|
41 |
+
};
|
42 |
+
|
43 |
+
let shard_selector = match shard_key {
|
44 |
+
None => ShardSelectorInternal::All,
|
45 |
+
Some(shard_keys) => ShardSelectorInternal::from(shard_keys),
|
46 |
+
};
|
47 |
+
|
48 |
+
let request_hw_counter = get_request_hardware_counter(
|
49 |
+
&dispatcher,
|
50 |
+
collection.name.clone(),
|
51 |
+
service_config.hardware_reporting(),
|
52 |
+
);
|
53 |
+
|
54 |
+
let timing = Instant::now();
|
55 |
+
|
56 |
+
let result = do_count_points(
|
57 |
+
dispatcher.toc(&access, &pass),
|
58 |
+
&collection.name,
|
59 |
+
count_request,
|
60 |
+
params.consistency,
|
61 |
+
params.timeout(),
|
62 |
+
shard_selector,
|
63 |
+
access,
|
64 |
+
request_hw_counter.get_counter(),
|
65 |
+
)
|
66 |
+
.await;
|
67 |
+
|
68 |
+
helpers::process_response(result, timing, request_hw_counter.to_rest_api())
|
69 |
+
}
|
src/actix/api/debug_api.rs
ADDED
@@ -0,0 +1,36 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
use actix_web::{get, patch, web, Responder};
|
2 |
+
use storage::rbac::AccessRequirements;
|
3 |
+
|
4 |
+
use crate::actix::auth::ActixAccess;
|
5 |
+
use crate::common::debugger::{DebugConfigPatch, DebuggerState};
|
6 |
+
|
7 |
+
#[get("/debugger")]
|
8 |
+
async fn get_debugger_config(
|
9 |
+
ActixAccess(access): ActixAccess,
|
10 |
+
debugger_state: web::Data<DebuggerState>,
|
11 |
+
) -> impl Responder {
|
12 |
+
crate::actix::helpers::time(async move {
|
13 |
+
access.check_global_access(AccessRequirements::new().manage())?;
|
14 |
+
Ok(debugger_state.get_config())
|
15 |
+
})
|
16 |
+
.await
|
17 |
+
}
|
18 |
+
|
19 |
+
#[patch("/debugger")]
|
20 |
+
async fn update_debugger_config(
|
21 |
+
ActixAccess(access): ActixAccess,
|
22 |
+
debugger_state: web::Data<DebuggerState>,
|
23 |
+
debug_patch: web::Json<DebugConfigPatch>,
|
24 |
+
) -> impl Responder {
|
25 |
+
crate::actix::helpers::time(async move {
|
26 |
+
access.check_global_access(AccessRequirements::new().manage())?;
|
27 |
+
Ok(debugger_state.apply_config_patch(debug_patch.into_inner()))
|
28 |
+
})
|
29 |
+
.await
|
30 |
+
}
|
31 |
+
|
32 |
+
// Configure services
|
33 |
+
pub fn config_debugger_api(cfg: &mut web::ServiceConfig) {
|
34 |
+
cfg.service(get_debugger_config);
|
35 |
+
cfg.service(update_debugger_config);
|
36 |
+
}
|
src/actix/api/discovery_api.rs
ADDED
@@ -0,0 +1,140 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
use actix_web::{post, web, Responder};
|
2 |
+
use actix_web_validator::{Json, Path, Query};
|
3 |
+
use collection::operations::shard_selector_internal::ShardSelectorInternal;
|
4 |
+
use collection::operations::types::{DiscoverRequest, DiscoverRequestBatch};
|
5 |
+
use itertools::Itertools;
|
6 |
+
use storage::content_manager::collection_verification::{
|
7 |
+
check_strict_mode, check_strict_mode_batch,
|
8 |
+
};
|
9 |
+
use storage::dispatcher::Dispatcher;
|
10 |
+
use tokio::time::Instant;
|
11 |
+
|
12 |
+
use crate::actix::api::read_params::ReadParams;
|
13 |
+
use crate::actix::api::CollectionPath;
|
14 |
+
use crate::actix::auth::ActixAccess;
|
15 |
+
use crate::actix::helpers::{self, get_request_hardware_counter, process_response_error};
|
16 |
+
use crate::common::points::do_discover_batch_points;
|
17 |
+
use crate::settings::ServiceConfig;
|
18 |
+
|
19 |
+
#[post("/collections/{name}/points/discover")]
|
20 |
+
async fn discover_points(
|
21 |
+
dispatcher: web::Data<Dispatcher>,
|
22 |
+
collection: Path<CollectionPath>,
|
23 |
+
request: Json<DiscoverRequest>,
|
24 |
+
params: Query<ReadParams>,
|
25 |
+
service_config: web::Data<ServiceConfig>,
|
26 |
+
ActixAccess(access): ActixAccess,
|
27 |
+
) -> impl Responder {
|
28 |
+
let DiscoverRequest {
|
29 |
+
discover_request,
|
30 |
+
shard_key,
|
31 |
+
} = request.into_inner();
|
32 |
+
|
33 |
+
let pass = match check_strict_mode(
|
34 |
+
&discover_request,
|
35 |
+
params.timeout_as_secs(),
|
36 |
+
&collection.name,
|
37 |
+
&dispatcher,
|
38 |
+
&access,
|
39 |
+
)
|
40 |
+
.await
|
41 |
+
{
|
42 |
+
Ok(pass) => pass,
|
43 |
+
Err(err) => return process_response_error(err, Instant::now(), None),
|
44 |
+
};
|
45 |
+
|
46 |
+
let shard_selection = match shard_key {
|
47 |
+
None => ShardSelectorInternal::All,
|
48 |
+
Some(shard_keys) => shard_keys.into(),
|
49 |
+
};
|
50 |
+
|
51 |
+
let request_hw_counter = get_request_hardware_counter(
|
52 |
+
&dispatcher,
|
53 |
+
collection.name.clone(),
|
54 |
+
service_config.hardware_reporting(),
|
55 |
+
);
|
56 |
+
|
57 |
+
let timing = Instant::now();
|
58 |
+
|
59 |
+
let result = dispatcher
|
60 |
+
.toc(&access, &pass)
|
61 |
+
.discover(
|
62 |
+
&collection.name,
|
63 |
+
discover_request,
|
64 |
+
params.consistency,
|
65 |
+
shard_selection,
|
66 |
+
access,
|
67 |
+
params.timeout(),
|
68 |
+
request_hw_counter.get_counter(),
|
69 |
+
)
|
70 |
+
.await
|
71 |
+
.map(|scored_points| {
|
72 |
+
scored_points
|
73 |
+
.into_iter()
|
74 |
+
.map(api::rest::ScoredPoint::from)
|
75 |
+
.collect_vec()
|
76 |
+
});
|
77 |
+
|
78 |
+
helpers::process_response(result, timing, request_hw_counter.to_rest_api())
|
79 |
+
}
|
80 |
+
|
81 |
+
#[post("/collections/{name}/points/discover/batch")]
|
82 |
+
async fn discover_batch_points(
|
83 |
+
dispatcher: web::Data<Dispatcher>,
|
84 |
+
collection: Path<CollectionPath>,
|
85 |
+
request: Json<DiscoverRequestBatch>,
|
86 |
+
params: Query<ReadParams>,
|
87 |
+
service_config: web::Data<ServiceConfig>,
|
88 |
+
ActixAccess(access): ActixAccess,
|
89 |
+
) -> impl Responder {
|
90 |
+
let request = request.into_inner();
|
91 |
+
|
92 |
+
let pass = match check_strict_mode_batch(
|
93 |
+
request.searches.iter().map(|i| &i.discover_request),
|
94 |
+
params.timeout_as_secs(),
|
95 |
+
&collection.name,
|
96 |
+
&dispatcher,
|
97 |
+
&access,
|
98 |
+
)
|
99 |
+
.await
|
100 |
+
{
|
101 |
+
Ok(pass) => pass,
|
102 |
+
Err(err) => return process_response_error(err, Instant::now(), None),
|
103 |
+
};
|
104 |
+
|
105 |
+
let request_hw_counter = get_request_hardware_counter(
|
106 |
+
&dispatcher,
|
107 |
+
collection.name.clone(),
|
108 |
+
service_config.hardware_reporting(),
|
109 |
+
);
|
110 |
+
let timing = Instant::now();
|
111 |
+
|
112 |
+
let result = do_discover_batch_points(
|
113 |
+
dispatcher.toc(&access, &pass),
|
114 |
+
&collection.name,
|
115 |
+
request,
|
116 |
+
params.consistency,
|
117 |
+
access,
|
118 |
+
params.timeout(),
|
119 |
+
request_hw_counter.get_counter(),
|
120 |
+
)
|
121 |
+
.await
|
122 |
+
.map(|batch_scored_points| {
|
123 |
+
batch_scored_points
|
124 |
+
.into_iter()
|
125 |
+
.map(|scored_points| {
|
126 |
+
scored_points
|
127 |
+
.into_iter()
|
128 |
+
.map(api::rest::ScoredPoint::from)
|
129 |
+
.collect_vec()
|
130 |
+
})
|
131 |
+
.collect_vec()
|
132 |
+
});
|
133 |
+
|
134 |
+
helpers::process_response(result, timing, request_hw_counter.to_rest_api())
|
135 |
+
}
|
136 |
+
|
137 |
+
pub fn config_discovery_api(cfg: &mut web::ServiceConfig) {
|
138 |
+
cfg.service(discover_points);
|
139 |
+
cfg.service(discover_batch_points);
|
140 |
+
}
|
src/actix/api/facet_api.rs
ADDED
@@ -0,0 +1,77 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
use actix_web::{post, web, Responder};
|
2 |
+
use actix_web_validator::{Json, Path, Query};
|
3 |
+
use api::rest::{FacetRequest, FacetResponse};
|
4 |
+
use collection::operations::shard_selector_internal::ShardSelectorInternal;
|
5 |
+
use storage::content_manager::collection_verification::check_strict_mode;
|
6 |
+
use storage::dispatcher::Dispatcher;
|
7 |
+
use tokio::time::Instant;
|
8 |
+
|
9 |
+
use crate::actix::api::read_params::ReadParams;
|
10 |
+
use crate::actix::api::CollectionPath;
|
11 |
+
use crate::actix::auth::ActixAccess;
|
12 |
+
use crate::actix::helpers::{
|
13 |
+
get_request_hardware_counter, process_response, process_response_error,
|
14 |
+
};
|
15 |
+
use crate::settings::ServiceConfig;
|
16 |
+
|
17 |
+
#[post("/collections/{name}/facet")]
|
18 |
+
async fn facet(
|
19 |
+
dispatcher: web::Data<Dispatcher>,
|
20 |
+
collection: Path<CollectionPath>,
|
21 |
+
request: Json<FacetRequest>,
|
22 |
+
params: Query<ReadParams>,
|
23 |
+
service_config: web::Data<ServiceConfig>,
|
24 |
+
ActixAccess(access): ActixAccess,
|
25 |
+
) -> impl Responder {
|
26 |
+
let timing = Instant::now();
|
27 |
+
|
28 |
+
let FacetRequest {
|
29 |
+
facet_request,
|
30 |
+
shard_key,
|
31 |
+
} = request.into_inner();
|
32 |
+
|
33 |
+
let pass = match check_strict_mode(
|
34 |
+
&facet_request,
|
35 |
+
params.timeout_as_secs(),
|
36 |
+
&collection.name,
|
37 |
+
&dispatcher,
|
38 |
+
&access,
|
39 |
+
)
|
40 |
+
.await
|
41 |
+
{
|
42 |
+
Ok(pass) => pass,
|
43 |
+
Err(err) => return process_response_error(err, timing, None),
|
44 |
+
};
|
45 |
+
|
46 |
+
let facet_params = From::from(facet_request);
|
47 |
+
|
48 |
+
let shard_selection = match shard_key {
|
49 |
+
None => ShardSelectorInternal::All,
|
50 |
+
Some(shard_keys) => shard_keys.into(),
|
51 |
+
};
|
52 |
+
|
53 |
+
let request_hw_counter = get_request_hardware_counter(
|
54 |
+
&dispatcher,
|
55 |
+
collection.name.clone(),
|
56 |
+
service_config.hardware_reporting(),
|
57 |
+
);
|
58 |
+
|
59 |
+
let response = dispatcher
|
60 |
+
.toc(&access, &pass)
|
61 |
+
.facet(
|
62 |
+
&collection.name,
|
63 |
+
facet_params,
|
64 |
+
shard_selection,
|
65 |
+
params.consistency,
|
66 |
+
access,
|
67 |
+
params.timeout(),
|
68 |
+
)
|
69 |
+
.await
|
70 |
+
.map(FacetResponse::from);
|
71 |
+
|
72 |
+
process_response(response, timing, request_hw_counter.to_rest_api())
|
73 |
+
}
|
74 |
+
|
75 |
+
pub fn config_facet_api(cfg: &mut web::ServiceConfig) {
|
76 |
+
cfg.service(facet);
|
77 |
+
}
|
src/actix/api/issues_api.rs
ADDED
@@ -0,0 +1,32 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
use actix_web::{delete, get, web, Responder};
|
2 |
+
use collection::operations::types::IssuesReport;
|
3 |
+
use storage::rbac::AccessRequirements;
|
4 |
+
|
5 |
+
use crate::actix::auth::ActixAccess;
|
6 |
+
|
7 |
+
#[get("/issues")]
|
8 |
+
async fn get_issues(ActixAccess(access): ActixAccess) -> impl Responder {
|
9 |
+
crate::actix::helpers::time(async move {
|
10 |
+
access.check_global_access(AccessRequirements::new().manage())?;
|
11 |
+
Ok(IssuesReport {
|
12 |
+
issues: issues::all_issues(),
|
13 |
+
})
|
14 |
+
})
|
15 |
+
.await
|
16 |
+
}
|
17 |
+
|
18 |
+
#[delete("/issues")]
|
19 |
+
async fn clear_issues(ActixAccess(access): ActixAccess) -> impl Responder {
|
20 |
+
crate::actix::helpers::time(async move {
|
21 |
+
access.check_global_access(AccessRequirements::new().manage())?;
|
22 |
+
issues::clear();
|
23 |
+
Ok(true)
|
24 |
+
})
|
25 |
+
.await
|
26 |
+
}
|
27 |
+
|
28 |
+
// Configure services
|
29 |
+
pub fn config_issues_api(cfg: &mut web::ServiceConfig) {
|
30 |
+
cfg.service(get_issues);
|
31 |
+
cfg.service(clear_issues);
|
32 |
+
}
|
src/actix/api/local_shard_api.rs
ADDED
@@ -0,0 +1,267 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
use std::sync::Arc;
|
2 |
+
|
3 |
+
use actix_web::{post, web, Responder};
|
4 |
+
use collection::operations::shard_selector_internal::ShardSelectorInternal;
|
5 |
+
use collection::operations::types::{
|
6 |
+
CountRequestInternal, PointRequestInternal, ScrollRequestInternal,
|
7 |
+
};
|
8 |
+
use collection::operations::verification::{new_unchecked_verification_pass, VerificationPass};
|
9 |
+
use collection::shards::shard::ShardId;
|
10 |
+
use segment::types::{Condition, Filter};
|
11 |
+
use storage::content_manager::collection_verification::check_strict_mode;
|
12 |
+
use storage::content_manager::errors::{StorageError, StorageResult};
|
13 |
+
use storage::dispatcher::Dispatcher;
|
14 |
+
use storage::rbac::{Access, AccessRequirements};
|
15 |
+
use tokio::time::Instant;
|
16 |
+
|
17 |
+
use crate::actix::api::read_params::ReadParams;
|
18 |
+
use crate::actix::auth::ActixAccess;
|
19 |
+
use crate::actix::helpers::{self, get_request_hardware_counter, process_response_error};
|
20 |
+
use crate::common::points;
|
21 |
+
use crate::settings::ServiceConfig;
|
22 |
+
|
23 |
+
// Configure services
|
24 |
+
pub fn config_local_shard_api(cfg: &mut web::ServiceConfig) {
|
25 |
+
cfg.service(get_points)
|
26 |
+
.service(scroll_points)
|
27 |
+
.service(count_points)
|
28 |
+
.service(cleanup_shard);
|
29 |
+
}
|
30 |
+
|
31 |
+
#[post("/collections/{collection}/shards/{shard}/points")]
|
32 |
+
async fn get_points(
|
33 |
+
dispatcher: web::Data<Dispatcher>,
|
34 |
+
ActixAccess(access): ActixAccess,
|
35 |
+
path: web::Path<CollectionShard>,
|
36 |
+
request: web::Json<PointRequestInternal>,
|
37 |
+
params: web::Query<ReadParams>,
|
38 |
+
) -> impl Responder {
|
39 |
+
// No strict mode verification needed
|
40 |
+
let pass = new_unchecked_verification_pass();
|
41 |
+
|
42 |
+
helpers::time(async move {
|
43 |
+
let records = points::do_get_points(
|
44 |
+
dispatcher.toc(&access, &pass),
|
45 |
+
&path.collection,
|
46 |
+
request.into_inner(),
|
47 |
+
params.consistency,
|
48 |
+
params.timeout(),
|
49 |
+
ShardSelectorInternal::ShardId(path.shard),
|
50 |
+
access,
|
51 |
+
)
|
52 |
+
.await?;
|
53 |
+
|
54 |
+
let records: Vec<_> = records.into_iter().map(api::rest::Record::from).collect();
|
55 |
+
Ok(records)
|
56 |
+
})
|
57 |
+
.await
|
58 |
+
}
|
59 |
+
|
60 |
+
#[post("/collections/{collection}/shards/{shard}/points/scroll")]
|
61 |
+
async fn scroll_points(
|
62 |
+
dispatcher: web::Data<Dispatcher>,
|
63 |
+
ActixAccess(access): ActixAccess,
|
64 |
+
path: web::Path<CollectionShard>,
|
65 |
+
request: web::Json<WithFilter<ScrollRequestInternal>>,
|
66 |
+
params: web::Query<ReadParams>,
|
67 |
+
) -> impl Responder {
|
68 |
+
let WithFilter {
|
69 |
+
mut request,
|
70 |
+
hash_ring_filter,
|
71 |
+
} = request.into_inner();
|
72 |
+
|
73 |
+
let pass = match check_strict_mode(
|
74 |
+
&request,
|
75 |
+
params.timeout_as_secs(),
|
76 |
+
&path.collection,
|
77 |
+
&dispatcher,
|
78 |
+
&access,
|
79 |
+
)
|
80 |
+
.await
|
81 |
+
{
|
82 |
+
Ok(pass) => pass,
|
83 |
+
Err(err) => return process_response_error(err, Instant::now(), None),
|
84 |
+
};
|
85 |
+
|
86 |
+
helpers::time(async move {
|
87 |
+
let hash_ring_filter = match hash_ring_filter {
|
88 |
+
Some(filter) => get_hash_ring_filter(
|
89 |
+
&dispatcher,
|
90 |
+
&access,
|
91 |
+
&path.collection,
|
92 |
+
AccessRequirements::new(),
|
93 |
+
filter.expected_shard_id,
|
94 |
+
&pass,
|
95 |
+
)
|
96 |
+
.await?
|
97 |
+
.into(),
|
98 |
+
|
99 |
+
None => None,
|
100 |
+
};
|
101 |
+
|
102 |
+
request.filter = merge_with_optional_filter(request.filter.take(), hash_ring_filter);
|
103 |
+
|
104 |
+
dispatcher
|
105 |
+
.toc(&access, &pass)
|
106 |
+
.scroll(
|
107 |
+
&path.collection,
|
108 |
+
request,
|
109 |
+
params.consistency,
|
110 |
+
params.timeout(),
|
111 |
+
ShardSelectorInternal::ShardId(path.shard),
|
112 |
+
access,
|
113 |
+
)
|
114 |
+
.await
|
115 |
+
})
|
116 |
+
.await
|
117 |
+
}
|
118 |
+
|
119 |
+
#[post("/collections/{collection}/shards/{shard}/points/count")]
|
120 |
+
async fn count_points(
|
121 |
+
dispatcher: web::Data<Dispatcher>,
|
122 |
+
ActixAccess(access): ActixAccess,
|
123 |
+
path: web::Path<CollectionShard>,
|
124 |
+
request: web::Json<WithFilter<CountRequestInternal>>,
|
125 |
+
params: web::Query<ReadParams>,
|
126 |
+
service_config: web::Data<ServiceConfig>,
|
127 |
+
) -> impl Responder {
|
128 |
+
let WithFilter {
|
129 |
+
mut request,
|
130 |
+
hash_ring_filter,
|
131 |
+
} = request.into_inner();
|
132 |
+
|
133 |
+
let pass = match check_strict_mode(
|
134 |
+
&request,
|
135 |
+
params.timeout_as_secs(),
|
136 |
+
&path.collection,
|
137 |
+
&dispatcher,
|
138 |
+
&access,
|
139 |
+
)
|
140 |
+
.await
|
141 |
+
{
|
142 |
+
Ok(pass) => pass,
|
143 |
+
Err(err) => return process_response_error(err, Instant::now(), None),
|
144 |
+
};
|
145 |
+
|
146 |
+
let request_hw_counter = get_request_hardware_counter(
|
147 |
+
&dispatcher,
|
148 |
+
path.collection.clone(),
|
149 |
+
service_config.hardware_reporting(),
|
150 |
+
);
|
151 |
+
let timing = Instant::now();
|
152 |
+
let hw_measurement_acc = request_hw_counter.get_counter();
|
153 |
+
|
154 |
+
let result = async move {
|
155 |
+
let hash_ring_filter = match hash_ring_filter {
|
156 |
+
Some(filter) => get_hash_ring_filter(
|
157 |
+
&dispatcher,
|
158 |
+
&access,
|
159 |
+
&path.collection,
|
160 |
+
AccessRequirements::new(),
|
161 |
+
filter.expected_shard_id,
|
162 |
+
&pass,
|
163 |
+
)
|
164 |
+
.await?
|
165 |
+
.into(),
|
166 |
+
|
167 |
+
None => None,
|
168 |
+
};
|
169 |
+
|
170 |
+
request.filter = merge_with_optional_filter(request.filter.take(), hash_ring_filter);
|
171 |
+
|
172 |
+
points::do_count_points(
|
173 |
+
dispatcher.toc(&access, &pass),
|
174 |
+
&path.collection,
|
175 |
+
request,
|
176 |
+
params.consistency,
|
177 |
+
params.timeout(),
|
178 |
+
ShardSelectorInternal::ShardId(path.shard),
|
179 |
+
access,
|
180 |
+
hw_measurement_acc,
|
181 |
+
)
|
182 |
+
.await
|
183 |
+
}
|
184 |
+
.await;
|
185 |
+
|
186 |
+
helpers::process_response(result, timing, request_hw_counter.to_rest_api())
|
187 |
+
}
|
188 |
+
|
189 |
+
#[post("/collections/{collection}/shards/{shard}/cleanup")]
|
190 |
+
async fn cleanup_shard(
|
191 |
+
dispatcher: web::Data<Dispatcher>,
|
192 |
+
ActixAccess(access): ActixAccess,
|
193 |
+
path: web::Path<CollectionShard>,
|
194 |
+
) -> impl Responder {
|
195 |
+
// Nothing to verify here.
|
196 |
+
let pass = new_unchecked_verification_pass();
|
197 |
+
|
198 |
+
helpers::time(async move {
|
199 |
+
let path = path.into_inner();
|
200 |
+
dispatcher
|
201 |
+
.toc(&access, &pass)
|
202 |
+
.cleanup_local_shard(&path.collection, path.shard, access)
|
203 |
+
.await
|
204 |
+
})
|
205 |
+
.await
|
206 |
+
}
|
207 |
+
|
208 |
+
/// Path parameters identifying a single shard within a collection.
#[derive(serde::Deserialize, validator::Validate)]
struct CollectionShard {
    // Collection names are limited to 1..=255 characters.
    #[validate(length(min = 1, max = 255))]
    collection: String,
    shard: ShardId,
}
|
214 |
+
|
215 |
+
/// Request-body wrapper pairing an inner request with an optional hash-ring
/// filter parameter (resolved via `get_hash_ring_filter` before dispatch).
#[derive(Clone, Debug, serde::Deserialize)]
struct WithFilter<T> {
    // Flattened, so the wrapped request's fields sit at the top level of the body.
    #[serde(flatten)]
    request: T,
    // Optional; absent by default.
    #[serde(default)]
    hash_ring_filter: Option<SerdeHelper>,
}
|
222 |
+
|
223 |
+
/// Deserialization helper for the `hash_ring_filter` request field.
#[derive(Clone, Debug, serde::Deserialize)]
struct SerdeHelper {
    // Shard expected to own the points selected by the hash-ring filter.
    expected_shard_id: ShardId,
}
|
227 |
+
|
228 |
+
async fn get_hash_ring_filter(
|
229 |
+
dispatcher: &Dispatcher,
|
230 |
+
access: &Access,
|
231 |
+
collection: &str,
|
232 |
+
reqs: AccessRequirements,
|
233 |
+
expected_shard_id: ShardId,
|
234 |
+
verification_pass: &VerificationPass,
|
235 |
+
) -> StorageResult<Filter> {
|
236 |
+
let pass = access.check_collection_access(collection, reqs)?;
|
237 |
+
|
238 |
+
let shard_holder = dispatcher
|
239 |
+
.toc(access, verification_pass)
|
240 |
+
.get_collection(&pass)
|
241 |
+
.await?
|
242 |
+
.shards_holder();
|
243 |
+
|
244 |
+
let hash_ring_filter = shard_holder
|
245 |
+
.read()
|
246 |
+
.await
|
247 |
+
.hash_ring_filter(expected_shard_id)
|
248 |
+
.ok_or_else(|| {
|
249 |
+
StorageError::bad_request(format!(
|
250 |
+
"shard {expected_shard_id} does not exist in collection {collection}"
|
251 |
+
))
|
252 |
+
})?;
|
253 |
+
|
254 |
+
let condition = Condition::CustomIdChecker(Arc::new(hash_ring_filter));
|
255 |
+
let filter = Filter::new_must(condition);
|
256 |
+
|
257 |
+
Ok(filter)
|
258 |
+
}
|
259 |
+
|
260 |
+
/// Merge a user filter with an optional hash-ring filter.
///
/// When both are present, the hash-ring filter absorbs the user filter;
/// otherwise whichever one is present (if any) is returned as-is.
fn merge_with_optional_filter(filter: Option<Filter>, hash_ring: Option<Filter>) -> Option<Filter> {
    match (filter, hash_ring) {
        (Some(filter), Some(hash_ring)) => Some(hash_ring.merge_owned(filter)),
        (filter, hash_ring) => filter.or(hash_ring),
    }
}
|
src/actix/api/mod.rs
ADDED
@@ -0,0 +1,46 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
use common::validation::validate_collection_name;
|
2 |
+
use serde::Deserialize;
|
3 |
+
use validator::Validate;
|
4 |
+
|
5 |
+
pub mod cluster_api;
|
6 |
+
pub mod collections_api;
|
7 |
+
pub mod count_api;
|
8 |
+
pub mod debug_api;
|
9 |
+
pub mod discovery_api;
|
10 |
+
pub mod facet_api;
|
11 |
+
pub mod issues_api;
|
12 |
+
pub mod local_shard_api;
|
13 |
+
pub mod query_api;
|
14 |
+
pub mod read_params;
|
15 |
+
pub mod recommend_api;
|
16 |
+
pub mod retrieve_api;
|
17 |
+
pub mod search_api;
|
18 |
+
pub mod service_api;
|
19 |
+
pub mod shards_api;
|
20 |
+
pub mod snapshot_api;
|
21 |
+
pub mod update_api;
|
22 |
+
|
23 |
+
/// A collection path with stricter validation
|
24 |
+
///
|
25 |
+
/// Validation for collection paths has been made more strict over time.
|
26 |
+
/// To prevent breaking changes on existing collections, this is only enforced for newly created
|
27 |
+
/// collections. Basic validation is enforced everywhere else.
|
28 |
+
#[derive(Deserialize, Validate)]
struct StrictCollectionPath {
    // Enforces the length limit plus the stricter character rules applied by
    // `validate_collection_name` (only used for newly created collections).
    #[validate(
        length(min = 1, max = 255),
        custom(function = "validate_collection_name")
    )]
    name: String,
}
|
36 |
+
|
37 |
+
/// A collection path with basic validation
|
38 |
+
///
|
39 |
+
/// Validation for collection paths has been made more strict over time.
|
40 |
+
/// To prevent breaking changes on existing collections, this is only enforced for newly created
|
41 |
+
/// collections. Basic validation is enforced everywhere else.
|
42 |
+
#[derive(Deserialize, Validate)]
struct CollectionPath {
    // Only the length limit is enforced here, so pre-existing collections with
    // legacy names remain addressable.
    #[validate(length(min = 1, max = 255))]
    name: String,
}
|
src/actix/api/query_api.rs
ADDED
@@ -0,0 +1,232 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
use actix_web::{post, web, Responder};
|
2 |
+
use actix_web_validator::{Json, Path, Query};
|
3 |
+
use api::rest::{QueryGroupsRequest, QueryRequest, QueryRequestBatch, QueryResponse};
|
4 |
+
use collection::operations::shard_selector_internal::ShardSelectorInternal;
|
5 |
+
use itertools::Itertools;
|
6 |
+
use storage::content_manager::collection_verification::{
|
7 |
+
check_strict_mode, check_strict_mode_batch,
|
8 |
+
};
|
9 |
+
use storage::content_manager::errors::StorageError;
|
10 |
+
use storage::dispatcher::Dispatcher;
|
11 |
+
use tokio::time::Instant;
|
12 |
+
|
13 |
+
use super::read_params::ReadParams;
|
14 |
+
use super::CollectionPath;
|
15 |
+
use crate::actix::auth::ActixAccess;
|
16 |
+
use crate::actix::helpers::{self, get_request_hardware_counter, process_response_error};
|
17 |
+
use crate::common::inference::query_requests_rest::{
|
18 |
+
convert_query_groups_request_from_rest, convert_query_request_from_rest,
|
19 |
+
};
|
20 |
+
use crate::common::points::do_query_point_groups;
|
21 |
+
use crate::settings::ServiceConfig;
|
22 |
+
|
23 |
+
#[post("/collections/{name}/points/query")]
|
24 |
+
async fn query_points(
|
25 |
+
dispatcher: web::Data<Dispatcher>,
|
26 |
+
collection: Path<CollectionPath>,
|
27 |
+
request: Json<QueryRequest>,
|
28 |
+
params: Query<ReadParams>,
|
29 |
+
service_config: web::Data<ServiceConfig>,
|
30 |
+
ActixAccess(access): ActixAccess,
|
31 |
+
) -> impl Responder {
|
32 |
+
let QueryRequest {
|
33 |
+
internal: query_request,
|
34 |
+
shard_key,
|
35 |
+
} = request.into_inner();
|
36 |
+
|
37 |
+
let pass = match check_strict_mode(
|
38 |
+
&query_request,
|
39 |
+
params.timeout_as_secs(),
|
40 |
+
&collection.name,
|
41 |
+
&dispatcher,
|
42 |
+
&access,
|
43 |
+
)
|
44 |
+
.await
|
45 |
+
{
|
46 |
+
Ok(pass) => pass,
|
47 |
+
Err(err) => return process_response_error(err, Instant::now(), None),
|
48 |
+
};
|
49 |
+
|
50 |
+
let request_hw_counter = get_request_hardware_counter(
|
51 |
+
&dispatcher,
|
52 |
+
collection.name.clone(),
|
53 |
+
service_config.hardware_reporting(),
|
54 |
+
);
|
55 |
+
let timing = Instant::now();
|
56 |
+
|
57 |
+
let shard_selection = match shard_key {
|
58 |
+
None => ShardSelectorInternal::All,
|
59 |
+
Some(shard_keys) => shard_keys.into(),
|
60 |
+
};
|
61 |
+
let hw_measurement_acc = request_hw_counter.get_counter();
|
62 |
+
|
63 |
+
let result = async move {
|
64 |
+
let request = convert_query_request_from_rest(query_request).await?;
|
65 |
+
|
66 |
+
let points = dispatcher
|
67 |
+
.toc(&access, &pass)
|
68 |
+
.query_batch(
|
69 |
+
&collection.name,
|
70 |
+
vec![(request, shard_selection)],
|
71 |
+
params.consistency,
|
72 |
+
access,
|
73 |
+
params.timeout(),
|
74 |
+
hw_measurement_acc,
|
75 |
+
)
|
76 |
+
.await?
|
77 |
+
.pop()
|
78 |
+
.ok_or_else(|| {
|
79 |
+
StorageError::service_error("Expected at least one response for one query")
|
80 |
+
})?
|
81 |
+
.into_iter()
|
82 |
+
.map(api::rest::ScoredPoint::from)
|
83 |
+
.collect_vec();
|
84 |
+
|
85 |
+
Ok(QueryResponse { points })
|
86 |
+
}
|
87 |
+
.await;
|
88 |
+
|
89 |
+
helpers::process_response(result, timing, request_hw_counter.to_rest_api())
|
90 |
+
}
|
91 |
+
|
92 |
+
#[post("/collections/{name}/points/query/batch")]
|
93 |
+
async fn query_points_batch(
|
94 |
+
dispatcher: web::Data<Dispatcher>,
|
95 |
+
collection: Path<CollectionPath>,
|
96 |
+
request: Json<QueryRequestBatch>,
|
97 |
+
params: Query<ReadParams>,
|
98 |
+
service_config: web::Data<ServiceConfig>,
|
99 |
+
ActixAccess(access): ActixAccess,
|
100 |
+
) -> impl Responder {
|
101 |
+
let QueryRequestBatch { searches } = request.into_inner();
|
102 |
+
|
103 |
+
let pass = match check_strict_mode_batch(
|
104 |
+
searches.iter().map(|i| &i.internal),
|
105 |
+
params.timeout_as_secs(),
|
106 |
+
&collection.name,
|
107 |
+
&dispatcher,
|
108 |
+
&access,
|
109 |
+
)
|
110 |
+
.await
|
111 |
+
{
|
112 |
+
Ok(pass) => pass,
|
113 |
+
Err(err) => return process_response_error(err, Instant::now(), None),
|
114 |
+
};
|
115 |
+
|
116 |
+
let request_hw_counter = get_request_hardware_counter(
|
117 |
+
&dispatcher,
|
118 |
+
collection.name.clone(),
|
119 |
+
service_config.hardware_reporting(),
|
120 |
+
);
|
121 |
+
let timing = Instant::now();
|
122 |
+
let hw_measurement_acc = request_hw_counter.get_counter();
|
123 |
+
|
124 |
+
let result = async move {
|
125 |
+
let mut batch = Vec::with_capacity(searches.len());
|
126 |
+
for request in searches {
|
127 |
+
let QueryRequest {
|
128 |
+
internal,
|
129 |
+
shard_key,
|
130 |
+
} = request;
|
131 |
+
|
132 |
+
let request = convert_query_request_from_rest(internal).await?;
|
133 |
+
let shard_selection = match shard_key {
|
134 |
+
None => ShardSelectorInternal::All,
|
135 |
+
Some(shard_keys) => shard_keys.into(),
|
136 |
+
};
|
137 |
+
|
138 |
+
batch.push((request, shard_selection));
|
139 |
+
}
|
140 |
+
|
141 |
+
let res = dispatcher
|
142 |
+
.toc(&access, &pass)
|
143 |
+
.query_batch(
|
144 |
+
&collection.name,
|
145 |
+
batch,
|
146 |
+
params.consistency,
|
147 |
+
access,
|
148 |
+
params.timeout(),
|
149 |
+
hw_measurement_acc,
|
150 |
+
)
|
151 |
+
.await?
|
152 |
+
.into_iter()
|
153 |
+
.map(|response| QueryResponse {
|
154 |
+
points: response
|
155 |
+
.into_iter()
|
156 |
+
.map(api::rest::ScoredPoint::from)
|
157 |
+
.collect_vec(),
|
158 |
+
})
|
159 |
+
.collect_vec();
|
160 |
+
Ok(res)
|
161 |
+
}
|
162 |
+
.await;
|
163 |
+
|
164 |
+
helpers::process_response(result, timing, request_hw_counter.to_rest_api())
|
165 |
+
}
|
166 |
+
|
167 |
+
#[post("/collections/{name}/points/query/groups")]
|
168 |
+
async fn query_points_groups(
|
169 |
+
dispatcher: web::Data<Dispatcher>,
|
170 |
+
collection: Path<CollectionPath>,
|
171 |
+
request: Json<QueryGroupsRequest>,
|
172 |
+
params: Query<ReadParams>,
|
173 |
+
service_config: web::Data<ServiceConfig>,
|
174 |
+
ActixAccess(access): ActixAccess,
|
175 |
+
) -> impl Responder {
|
176 |
+
let QueryGroupsRequest {
|
177 |
+
search_group_request,
|
178 |
+
shard_key,
|
179 |
+
} = request.into_inner();
|
180 |
+
|
181 |
+
let pass = match check_strict_mode(
|
182 |
+
&search_group_request,
|
183 |
+
params.timeout_as_secs(),
|
184 |
+
&collection.name,
|
185 |
+
&dispatcher,
|
186 |
+
&access,
|
187 |
+
)
|
188 |
+
.await
|
189 |
+
{
|
190 |
+
Ok(pass) => pass,
|
191 |
+
Err(err) => return process_response_error(err, Instant::now(), None),
|
192 |
+
};
|
193 |
+
|
194 |
+
let request_hw_counter = get_request_hardware_counter(
|
195 |
+
&dispatcher,
|
196 |
+
collection.name.clone(),
|
197 |
+
service_config.hardware_reporting(),
|
198 |
+
);
|
199 |
+
let timing = Instant::now();
|
200 |
+
let hw_measurement_acc = request_hw_counter.get_counter();
|
201 |
+
|
202 |
+
let result = async move {
|
203 |
+
let shard_selection = match shard_key {
|
204 |
+
None => ShardSelectorInternal::All,
|
205 |
+
Some(shard_keys) => shard_keys.into(),
|
206 |
+
};
|
207 |
+
|
208 |
+
let query_group_request =
|
209 |
+
convert_query_groups_request_from_rest(search_group_request).await?;
|
210 |
+
|
211 |
+
do_query_point_groups(
|
212 |
+
dispatcher.toc(&access, &pass),
|
213 |
+
&collection.name,
|
214 |
+
query_group_request,
|
215 |
+
params.consistency,
|
216 |
+
shard_selection,
|
217 |
+
access,
|
218 |
+
params.timeout(),
|
219 |
+
hw_measurement_acc,
|
220 |
+
)
|
221 |
+
.await
|
222 |
+
}
|
223 |
+
.await;
|
224 |
+
|
225 |
+
helpers::process_response(result, timing, request_hw_counter.to_rest_api())
|
226 |
+
}
|
227 |
+
|
228 |
+
pub fn config_query_api(cfg: &mut web::ServiceConfig) {
|
229 |
+
cfg.service(query_points);
|
230 |
+
cfg.service(query_points_batch);
|
231 |
+
cfg.service(query_points_groups);
|
232 |
+
}
|
src/actix/api/read_params.rs
ADDED
@@ -0,0 +1,118 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
use std::num::NonZeroU64;
|
2 |
+
use std::time::Duration;
|
3 |
+
|
4 |
+
use collection::operations::consistency_params::ReadConsistency;
|
5 |
+
use schemars::JsonSchema;
|
6 |
+
use serde::Deserialize;
|
7 |
+
use validator::Validate;
|
8 |
+
|
9 |
+
#[derive(Copy, Clone, Debug, Default, Eq, PartialEq, Deserialize, JsonSchema, Validate)]
pub struct ReadParams {
    /// Read consistency for this request; an empty query-string value is
    /// treated as "unset" (see `deserialize_read_consistency`).
    #[serde(default, deserialize_with = "deserialize_read_consistency")]
    #[validate(nested)]
    pub consistency: Option<ReadConsistency>,
    /// If set, overrides global timeout for this request. Unit is seconds.
    pub timeout: Option<NonZeroU64>,
}
|
17 |
+
|
18 |
+
impl ReadParams {
|
19 |
+
pub fn timeout(&self) -> Option<Duration> {
|
20 |
+
self.timeout.map(|num| Duration::from_secs(num.get()))
|
21 |
+
}
|
22 |
+
|
23 |
+
pub(crate) fn timeout_as_secs(&self) -> Option<usize> {
|
24 |
+
self.timeout.map(|i| i.get() as usize)
|
25 |
+
}
|
26 |
+
}
|
27 |
+
|
28 |
+
/// Deserialize the `consistency` query parameter.
///
/// Accepts any valid `ReadConsistency` value, and additionally maps an empty
/// string (`consistency=`) to `None` instead of a deserialization error.
fn deserialize_read_consistency<'de, D>(
    deserializer: D,
) -> Result<Option<ReadConsistency>, D::Error>
where
    D: serde::Deserializer<'de>,
{
    // Untagged: try to parse a real `ReadConsistency` first, then fall back
    // to capturing the raw string so the empty-value case can be detected.
    #[derive(Deserialize)]
    #[serde(untagged)]
    enum Helper<'a> {
        ReadConsistency(ReadConsistency),
        Str(&'a str),
    }

    match Helper::deserialize(deserializer)? {
        Helper::ReadConsistency(read_consistency) => Ok(Some(read_consistency)),
        // Empty value means "parameter present but unset".
        Helper::Str("") => Ok(None),
        // Any other non-parseable string is an error.
        _ => Err(serde::de::Error::custom(
            "failed to deserialize read consistency query parameter value",
        )),
    }
}
|
49 |
+
|
50 |
+
#[cfg(test)]
mod test {
    use collection::operations::consistency_params::ReadConsistencyType;

    use super::*;

    #[test]
    fn deserialize_empty_string() {
        // An entirely empty query string yields the default params.
        test_str("", ReadParams::default());
    }

    #[test]
    fn deserialize_empty_value() {
        // `consistency=` (present but empty) also yields the default params.
        test("", ReadParams::default());
    }

    #[test]
    fn deserialize_type() {
        test("all", from_type(ReadConsistencyType::All));
        test("majority", from_type(ReadConsistencyType::Majority));
        test("quorum", from_type(ReadConsistencyType::Quorum));
    }

    #[test]
    fn deserialize_factor() {
        for factor in 1..42 {
            test(&factor.to_string(), from_factor(factor));
        }
    }

    #[test]
    fn try_deserialize_factor_0() {
        // A zero factor is rejected rather than parsed.
        assert!(try_deserialize(&str("0")).is_err());
    }

    // Round-trip `consistency=<value>` and compare to the expected params.
    fn test(value: &str, params: ReadParams) {
        test_str(&str(value), params);
    }

    fn test_str(str: &str, params: ReadParams) {
        assert_eq!(deserialize(str), params);
    }

    fn deserialize(str: &str) -> ReadParams {
        try_deserialize(str).unwrap()
    }

    fn try_deserialize(str: &str) -> Result<ReadParams, serde_urlencoded::de::Error> {
        serde_urlencoded::from_str(str)
    }

    // Build a `consistency=<value>` query string.
    fn str(value: &str) -> String {
        format!("consistency={value}")
    }

    fn from_type(r#type: ReadConsistencyType) -> ReadParams {
        ReadParams {
            consistency: Some(ReadConsistency::Type(r#type)),
            ..Default::default()
        }
    }

    fn from_factor(factor: usize) -> ReadParams {
        ReadParams {
            consistency: Some(ReadConsistency::Factor(factor)),
            ..Default::default()
        }
    }
}
|
src/actix/api/recommend_api.rs
ADDED
@@ -0,0 +1,235 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
use std::time::Duration;
|
2 |
+
|
3 |
+
use actix_web::{post, web, Responder};
|
4 |
+
use actix_web_validator::{Json, Path, Query};
|
5 |
+
use collection::operations::consistency_params::ReadConsistency;
|
6 |
+
use collection::operations::shard_selector_internal::ShardSelectorInternal;
|
7 |
+
use collection::operations::types::{
|
8 |
+
RecommendGroupsRequest, RecommendRequest, RecommendRequestBatch,
|
9 |
+
};
|
10 |
+
use common::counter::hardware_accumulator::HwMeasurementAcc;
|
11 |
+
use itertools::Itertools;
|
12 |
+
use segment::types::ScoredPoint;
|
13 |
+
use storage::content_manager::collection_verification::{
|
14 |
+
check_strict_mode, check_strict_mode_batch,
|
15 |
+
};
|
16 |
+
use storage::content_manager::errors::StorageError;
|
17 |
+
use storage::content_manager::toc::TableOfContent;
|
18 |
+
use storage::dispatcher::Dispatcher;
|
19 |
+
use storage::rbac::Access;
|
20 |
+
use tokio::time::Instant;
|
21 |
+
|
22 |
+
use super::read_params::ReadParams;
|
23 |
+
use super::CollectionPath;
|
24 |
+
use crate::actix::auth::ActixAccess;
|
25 |
+
use crate::actix::helpers::{self, get_request_hardware_counter, process_response_error};
|
26 |
+
use crate::settings::ServiceConfig;
|
27 |
+
|
28 |
+
/// REST handler: recommend points similar/dissimilar to given examples.
///
/// Flow: strict-mode check -> shard selection -> hardware metering ->
/// dispatch to the collection -> convert internal scored points to REST types.
#[post("/collections/{name}/points/recommend")]
async fn recommend_points(
    dispatcher: web::Data<Dispatcher>,
    collection: Path<CollectionPath>,
    request: Json<RecommendRequest>,
    params: Query<ReadParams>,
    service_config: web::Data<ServiceConfig>,
    ActixAccess(access): ActixAccess,
) -> impl Responder {
    let RecommendRequest {
        recommend_request,
        shard_key,
    } = request.into_inner();

    // Strict-mode verification must succeed before dispatch; a rejected
    // request is answered immediately with the verification error.
    let pass = match check_strict_mode(
        &recommend_request,
        params.timeout_as_secs(),
        &collection.name,
        &dispatcher,
        &access,
    )
    .await
    {
        Ok(pass) => pass,
        Err(err) => return process_response_error(err, Instant::now(), None),
    };

    // No shard key means the request fans out to all shards.
    let shard_selection = match shard_key {
        None => ShardSelectorInternal::All,
        Some(shard_keys) => shard_keys.into(),
    };

    // Hardware usage metering; included in the REST response when hardware
    // reporting is enabled in the service config.
    let request_hw_counter = get_request_hardware_counter(
        &dispatcher,
        collection.name.clone(),
        service_config.hardware_reporting(),
    );

    let timing = Instant::now();

    let result = dispatcher
        .toc(&access, &pass)
        .recommend(
            &collection.name,
            recommend_request,
            params.consistency,
            shard_selection,
            access,
            params.timeout(),
            request_hw_counter.get_counter(),
        )
        .await
        .map(|scored_points| {
            // Convert internal scored points into the public REST representation.
            scored_points
                .into_iter()
                .map(api::rest::ScoredPoint::from)
                .collect_vec()
        });

    helpers::process_response(result, timing, request_hw_counter.to_rest_api())
}
|
89 |
+
|
90 |
+
async fn do_recommend_batch_points(
|
91 |
+
toc: &TableOfContent,
|
92 |
+
collection_name: &str,
|
93 |
+
request: RecommendRequestBatch,
|
94 |
+
read_consistency: Option<ReadConsistency>,
|
95 |
+
access: Access,
|
96 |
+
timeout: Option<Duration>,
|
97 |
+
hw_measurement_acc: &HwMeasurementAcc,
|
98 |
+
) -> Result<Vec<Vec<ScoredPoint>>, StorageError> {
|
99 |
+
let requests = request
|
100 |
+
.searches
|
101 |
+
.into_iter()
|
102 |
+
.map(|req| {
|
103 |
+
let shard_selector = match req.shard_key {
|
104 |
+
None => ShardSelectorInternal::All,
|
105 |
+
Some(shard_key) => ShardSelectorInternal::from(shard_key),
|
106 |
+
};
|
107 |
+
|
108 |
+
(req.recommend_request, shard_selector)
|
109 |
+
})
|
110 |
+
.collect();
|
111 |
+
|
112 |
+
toc.recommend_batch(
|
113 |
+
collection_name,
|
114 |
+
requests,
|
115 |
+
read_consistency,
|
116 |
+
access,
|
117 |
+
timeout,
|
118 |
+
hw_measurement_acc,
|
119 |
+
)
|
120 |
+
.await
|
121 |
+
}
|
122 |
+
|
123 |
+
/// REST handler: run several recommendation requests against one collection
/// in a single round trip.
#[post("/collections/{name}/points/recommend/batch")]
async fn recommend_batch_points(
    dispatcher: web::Data<Dispatcher>,
    collection: Path<CollectionPath>,
    request: Json<RecommendRequestBatch>,
    params: Query<ReadParams>,
    service_config: web::Data<ServiceConfig>,
    ActixAccess(access): ActixAccess,
) -> impl Responder {
    // Strict mode is verified over every request of the batch at once.
    let pass = match check_strict_mode_batch(
        request.searches.iter().map(|i| &i.recommend_request),
        params.timeout_as_secs(),
        &collection.name,
        &dispatcher,
        &access,
    )
    .await
    {
        Ok(pass) => pass,
        Err(err) => return process_response_error(err, Instant::now(), None),
    };

    // Hardware usage metering; reported in the REST response when enabled.
    let request_hw_counter = get_request_hardware_counter(
        &dispatcher,
        collection.name.clone(),
        service_config.hardware_reporting(),
    );
    let timing = Instant::now();

    let result = do_recommend_batch_points(
        dispatcher.toc(&access, &pass),
        &collection.name,
        request.into_inner(),
        params.consistency,
        access,
        params.timeout(),
        request_hw_counter.get_counter(),
    )
    .await
    .map(|batch_scored_points| {
        // Convert each sub-result into the public REST representation.
        batch_scored_points
            .into_iter()
            .map(|scored_points| {
                scored_points
                    .into_iter()
                    .map(api::rest::ScoredPoint::from)
                    .collect_vec()
            })
            .collect_vec()
    });

    helpers::process_response(result, timing, request_hw_counter.to_rest_api())
}
|
176 |
+
|
177 |
+
/// REST handler: recommendation with results grouped by a payload field.
#[post("/collections/{name}/points/recommend/groups")]
async fn recommend_point_groups(
    dispatcher: web::Data<Dispatcher>,
    collection: Path<CollectionPath>,
    request: Json<RecommendGroupsRequest>,
    params: Query<ReadParams>,
    service_config: web::Data<ServiceConfig>,
    ActixAccess(access): ActixAccess,
) -> impl Responder {
    let RecommendGroupsRequest {
        recommend_group_request,
        shard_key,
    } = request.into_inner();

    // Strict-mode verification must succeed before dispatch.
    let pass = match check_strict_mode(
        &recommend_group_request,
        params.timeout_as_secs(),
        &collection.name,
        &dispatcher,
        &access,
    )
    .await
    {
        Ok(pass) => pass,
        Err(err) => return process_response_error(err, Instant::now(), None),
    };

    // No shard key means the request fans out to all shards.
    let shard_selection = match shard_key {
        None => ShardSelectorInternal::All,
        Some(shard_keys) => shard_keys.into(),
    };

    // Hardware usage metering; reported in the REST response when enabled.
    let request_hw_counter = get_request_hardware_counter(
        &dispatcher,
        collection.name.clone(),
        service_config.hardware_reporting(),
    );
    let timing = Instant::now();

    let result = crate::common::points::do_recommend_point_groups(
        dispatcher.toc(&access, &pass),
        &collection.name,
        recommend_group_request,
        params.consistency,
        shard_selection,
        access,
        params.timeout(),
        request_hw_counter.get_counter(),
    )
    .await;

    helpers::process_response(result, timing, request_hw_counter.to_rest_api())
}
|
230 |
+
// Configure services
|
231 |
+
pub fn config_recommend_api(cfg: &mut web::ServiceConfig) {
|
232 |
+
cfg.service(recommend_points)
|
233 |
+
.service(recommend_batch_points)
|
234 |
+
.service(recommend_point_groups);
|
235 |
+
}
|
src/actix/api/retrieve_api.rs
ADDED
@@ -0,0 +1,200 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
use std::time::Duration;
|
2 |
+
|
3 |
+
use actix_web::{get, post, web, Responder};
|
4 |
+
use actix_web_validator::{Json, Path, Query};
|
5 |
+
use collection::operations::consistency_params::ReadConsistency;
|
6 |
+
use collection::operations::shard_selector_internal::ShardSelectorInternal;
|
7 |
+
use collection::operations::types::{
|
8 |
+
PointRequest, PointRequestInternal, RecordInternal, ScrollRequest,
|
9 |
+
};
|
10 |
+
use futures::TryFutureExt;
|
11 |
+
use itertools::Itertools;
|
12 |
+
use segment::types::{PointIdType, WithPayloadInterface};
|
13 |
+
use serde::Deserialize;
|
14 |
+
use storage::content_manager::collection_verification::{
|
15 |
+
check_strict_mode, check_strict_mode_timeout,
|
16 |
+
};
|
17 |
+
use storage::content_manager::errors::StorageError;
|
18 |
+
use storage::content_manager::toc::TableOfContent;
|
19 |
+
use storage::dispatcher::Dispatcher;
|
20 |
+
use storage::rbac::Access;
|
21 |
+
use tokio::time::Instant;
|
22 |
+
use validator::Validate;
|
23 |
+
|
24 |
+
use super::read_params::ReadParams;
|
25 |
+
use super::CollectionPath;
|
26 |
+
use crate::actix::auth::ActixAccess;
|
27 |
+
use crate::actix::helpers::{self, process_response_error};
|
28 |
+
use crate::common::points::do_get_points;
|
29 |
+
|
30 |
+
/// Path segment carrying a single point id as raw text; parsed into a
/// `PointIdType` later by the handler, not during extraction.
#[derive(Deserialize, Validate)]
struct PointPath {
    #[validate(length(min = 1))]
    // TODO: validate this is a valid ID type (usize or UUID)? Does currently error on deserialize.
    id: String,
}
|
36 |
+
|
37 |
+
async fn do_get_point(
|
38 |
+
toc: &TableOfContent,
|
39 |
+
collection_name: &str,
|
40 |
+
point_id: PointIdType,
|
41 |
+
read_consistency: Option<ReadConsistency>,
|
42 |
+
timeout: Option<Duration>,
|
43 |
+
access: Access,
|
44 |
+
) -> Result<Option<RecordInternal>, StorageError> {
|
45 |
+
let request = PointRequestInternal {
|
46 |
+
ids: vec![point_id],
|
47 |
+
with_payload: Some(WithPayloadInterface::Bool(true)),
|
48 |
+
with_vector: true.into(),
|
49 |
+
};
|
50 |
+
|
51 |
+
let shard_selection = ShardSelectorInternal::All;
|
52 |
+
|
53 |
+
toc.retrieve(
|
54 |
+
collection_name,
|
55 |
+
request,
|
56 |
+
read_consistency,
|
57 |
+
timeout,
|
58 |
+
shard_selection,
|
59 |
+
access,
|
60 |
+
)
|
61 |
+
.await
|
62 |
+
.map(|points| points.into_iter().next())
|
63 |
+
}
|
64 |
+
|
65 |
+
#[get("/collections/{name}/points/{id}")]
|
66 |
+
async fn get_point(
|
67 |
+
dispatcher: web::Data<Dispatcher>,
|
68 |
+
collection: Path<CollectionPath>,
|
69 |
+
point: Path<PointPath>,
|
70 |
+
params: Query<ReadParams>,
|
71 |
+
ActixAccess(access): ActixAccess,
|
72 |
+
) -> impl Responder {
|
73 |
+
let pass = match check_strict_mode_timeout(
|
74 |
+
params.timeout_as_secs(),
|
75 |
+
&collection.name,
|
76 |
+
&dispatcher,
|
77 |
+
&access,
|
78 |
+
)
|
79 |
+
.await
|
80 |
+
{
|
81 |
+
Ok(p) => p,
|
82 |
+
Err(err) => return process_response_error(err, Instant::now(), None),
|
83 |
+
};
|
84 |
+
|
85 |
+
helpers::time(async move {
|
86 |
+
let point_id: PointIdType = point.id.parse().map_err(|_| StorageError::BadInput {
|
87 |
+
description: format!("Can not recognize \"{}\" as point id", point.id),
|
88 |
+
})?;
|
89 |
+
|
90 |
+
let Some(record) = do_get_point(
|
91 |
+
dispatcher.toc(&access, &pass),
|
92 |
+
&collection.name,
|
93 |
+
point_id,
|
94 |
+
params.consistency,
|
95 |
+
params.timeout(),
|
96 |
+
access,
|
97 |
+
)
|
98 |
+
.await?
|
99 |
+
else {
|
100 |
+
return Err(StorageError::NotFound {
|
101 |
+
description: format!("Point with id {point_id} does not exists!"),
|
102 |
+
});
|
103 |
+
};
|
104 |
+
|
105 |
+
Ok(api::rest::Record::from(record))
|
106 |
+
})
|
107 |
+
.await
|
108 |
+
}
|
109 |
+
|
110 |
+
/// REST handler: retrieve multiple points by id with configurable
/// payload/vector selection.
#[post("/collections/{name}/points")]
async fn get_points(
    dispatcher: web::Data<Dispatcher>,
    collection: Path<CollectionPath>,
    request: Json<PointRequest>,
    params: Query<ReadParams>,
    ActixAccess(access): ActixAccess,
) -> impl Responder {
    // Only the timeout is strict-mode-relevant for plain retrieval.
    let pass = match check_strict_mode_timeout(
        params.timeout_as_secs(),
        &collection.name,
        &dispatcher,
        &access,
    )
    .await
    {
        Ok(p) => p,
        Err(err) => return process_response_error(err, Instant::now(), None),
    };

    let PointRequest {
        point_request,
        shard_key,
    } = request.into_inner();

    // No shard key means the request fans out to all shards.
    let shard_selection = match shard_key {
        None => ShardSelectorInternal::All,
        Some(shard_keys) => ShardSelectorInternal::from(shard_keys),
    };

    // `helpers::time` measures execution and wraps the result in a REST response.
    helpers::time(
        do_get_points(
            dispatcher.toc(&access, &pass),
            &collection.name,
            point_request,
            params.consistency,
            params.timeout(),
            shard_selection,
            access,
        )
        .map_ok(|response| {
            response
                .into_iter()
                .map(api::rest::Record::from)
                .collect_vec()
        }),
    )
    .await
}
|
159 |
+
|
160 |
+
/// REST handler: scroll (paginate) through points of a collection,
/// optionally filtered.
#[post("/collections/{name}/points/scroll")]
async fn scroll_points(
    dispatcher: web::Data<Dispatcher>,
    collection: Path<CollectionPath>,
    request: Json<ScrollRequest>,
    params: Query<ReadParams>,
    ActixAccess(access): ActixAccess,
) -> impl Responder {
    let ScrollRequest {
        scroll_request,
        shard_key,
    } = request.into_inner();

    // Strict mode may reject the request before dispatch.
    let pass = match check_strict_mode(
        &scroll_request,
        params.timeout_as_secs(),
        &collection.name,
        &dispatcher,
        &access,
    )
    .await
    {
        Ok(pass) => pass,
        Err(err) => return process_response_error(err, Instant::now(), None),
    };

    // No shard key means the request fans out to all shards.
    let shard_selection = match shard_key {
        None => ShardSelectorInternal::All,
        Some(shard_keys) => ShardSelectorInternal::from(shard_keys),
    };

    // `helpers::time` measures execution and wraps the result in a REST response.
    helpers::time(dispatcher.toc(&access, &pass).scroll(
        &collection.name,
        scroll_request,
        params.consistency,
        params.timeout(),
        shard_selection,
        access,
    ))
    .await
}
|
src/actix/api/search_api.rs
ADDED
@@ -0,0 +1,333 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
use actix_web::{post, web, HttpResponse, Responder};
|
2 |
+
use actix_web_validator::{Json, Path, Query};
|
3 |
+
use api::rest::{SearchMatrixOffsetsResponse, SearchMatrixPairsResponse, SearchMatrixRequest};
|
4 |
+
use collection::collection::distance_matrix::CollectionSearchMatrixRequest;
|
5 |
+
use collection::operations::shard_selector_internal::ShardSelectorInternal;
|
6 |
+
use collection::operations::types::{
|
7 |
+
CoreSearchRequest, SearchGroupsRequest, SearchRequest, SearchRequestBatch,
|
8 |
+
};
|
9 |
+
use itertools::Itertools;
|
10 |
+
use storage::content_manager::collection_verification::{
|
11 |
+
check_strict_mode, check_strict_mode_batch,
|
12 |
+
};
|
13 |
+
use storage::dispatcher::Dispatcher;
|
14 |
+
use tokio::time::Instant;
|
15 |
+
|
16 |
+
use super::read_params::ReadParams;
|
17 |
+
use super::CollectionPath;
|
18 |
+
use crate::actix::auth::ActixAccess;
|
19 |
+
use crate::actix::helpers::{
|
20 |
+
get_request_hardware_counter, process_response, process_response_error,
|
21 |
+
};
|
22 |
+
use crate::common::points::{
|
23 |
+
do_core_search_points, do_search_batch_points, do_search_point_groups, do_search_points_matrix,
|
24 |
+
};
|
25 |
+
use crate::settings::ServiceConfig;
|
26 |
+
|
27 |
+
/// REST handler: nearest-neighbor search in a collection.
#[post("/collections/{name}/points/search")]
async fn search_points(
    dispatcher: web::Data<Dispatcher>,
    collection: Path<CollectionPath>,
    request: Json<SearchRequest>,
    params: Query<ReadParams>,
    service_config: web::Data<ServiceConfig>,
    ActixAccess(access): ActixAccess,
) -> HttpResponse {
    let SearchRequest {
        search_request,
        shard_key,
    } = request.into_inner();

    // Strict-mode verification must succeed before dispatch.
    let pass = match check_strict_mode(
        &search_request,
        params.timeout_as_secs(),
        &collection.name,
        &dispatcher,
        &access,
    )
    .await
    {
        Ok(pass) => pass,
        Err(err) => return process_response_error(err, Instant::now(), None),
    };

    // No shard key means the request fans out to all shards.
    let shard_selection = match shard_key {
        None => ShardSelectorInternal::All,
        Some(shard_keys) => shard_keys.into(),
    };

    // Hardware usage metering; reported in the REST response when enabled.
    let request_hw_counter = get_request_hardware_counter(
        &dispatcher,
        collection.name.clone(),
        service_config.hardware_reporting(),
    );

    let timing = Instant::now();

    // The REST request is converted into a core search request before dispatch.
    let result = do_core_search_points(
        dispatcher.toc(&access, &pass),
        &collection.name,
        search_request.into(),
        params.consistency,
        shard_selection,
        access,
        params.timeout(),
        request_hw_counter.get_counter(),
    )
    .await
    .map(|scored_points| {
        scored_points
            .into_iter()
            .map(api::rest::ScoredPoint::from)
            .collect_vec()
    });

    process_response(result, timing, request_hw_counter.to_rest_api())
}
|
87 |
+
|
88 |
+
/// REST handler: run several searches against one collection in a single
/// round trip; each sub-request carries its own shard selection.
#[post("/collections/{name}/points/search/batch")]
async fn batch_search_points(
    dispatcher: web::Data<Dispatcher>,
    collection: Path<CollectionPath>,
    request: Json<SearchRequestBatch>,
    params: Query<ReadParams>,
    service_config: web::Data<ServiceConfig>,
    ActixAccess(access): ActixAccess,
) -> HttpResponse {
    // Convert every REST sub-request into (core request, shard selector)
    // pairs up front, so strict mode can inspect the core requests.
    let requests = request
        .into_inner()
        .searches
        .into_iter()
        .map(|req| {
            let SearchRequest {
                search_request,
                shard_key,
            } = req;
            let shard_selection = match shard_key {
                None => ShardSelectorInternal::All,
                Some(shard_keys) => shard_keys.into(),
            };
            let core_request: CoreSearchRequest = search_request.into();

            (core_request, shard_selection)
        })
        .collect::<Vec<_>>();

    // Strict mode is verified over every request of the batch at once.
    let pass = match check_strict_mode_batch(
        requests.iter().map(|i| &i.0),
        params.timeout_as_secs(),
        &collection.name,
        &dispatcher,
        &access,
    )
    .await
    {
        Ok(pass) => pass,
        Err(err) => return process_response_error(err, Instant::now(), None),
    };

    // Hardware usage metering; reported in the REST response when enabled.
    let request_hw_counter = get_request_hardware_counter(
        &dispatcher,
        collection.name.clone(),
        service_config.hardware_reporting(),
    );

    let timing = Instant::now();

    let result = do_search_batch_points(
        dispatcher.toc(&access, &pass),
        &collection.name,
        requests,
        params.consistency,
        access,
        params.timeout(),
        request_hw_counter.get_counter(),
    )
    .await
    .map(|batch_scored_points| {
        // Convert each sub-result into the public REST representation.
        batch_scored_points
            .into_iter()
            .map(|scored_points| {
                scored_points
                    .into_iter()
                    .map(api::rest::ScoredPoint::from)
                    .collect_vec()
            })
            .collect_vec()
    });

    process_response(result, timing, request_hw_counter.to_rest_api())
}
|
161 |
+
|
162 |
+
/// REST handler: search with results grouped by a payload field.
#[post("/collections/{name}/points/search/groups")]
async fn search_point_groups(
    dispatcher: web::Data<Dispatcher>,
    collection: Path<CollectionPath>,
    request: Json<SearchGroupsRequest>,
    params: Query<ReadParams>,
    service_config: web::Data<ServiceConfig>,
    ActixAccess(access): ActixAccess,
) -> HttpResponse {
    let SearchGroupsRequest {
        search_group_request,
        shard_key,
    } = request.into_inner();

    // Strict-mode verification must succeed before dispatch.
    let pass = match check_strict_mode(
        &search_group_request,
        params.timeout_as_secs(),
        &collection.name,
        &dispatcher,
        &access,
    )
    .await
    {
        Ok(pass) => pass,
        Err(err) => return process_response_error(err, Instant::now(), None),
    };

    // No shard key means the request fans out to all shards.
    let shard_selection = match shard_key {
        None => ShardSelectorInternal::All,
        Some(shard_keys) => shard_keys.into(),
    };

    // Hardware usage metering; reported in the REST response when enabled.
    let request_hw_counter = get_request_hardware_counter(
        &dispatcher,
        collection.name.clone(),
        service_config.hardware_reporting(),
    );
    let timing = Instant::now();

    let result = do_search_point_groups(
        dispatcher.toc(&access, &pass),
        &collection.name,
        search_group_request,
        params.consistency,
        shard_selection,
        access,
        params.timeout(),
        request_hw_counter.get_counter(),
    )
    .await;

    process_response(result, timing, request_hw_counter.to_rest_api())
}
|
215 |
+
|
216 |
+
/// REST handler: pairwise distance matrix between sampled points,
/// returned in "pairs" form.
#[post("/collections/{name}/points/search/matrix/pairs")]
async fn search_points_matrix_pairs(
    dispatcher: web::Data<Dispatcher>,
    collection: Path<CollectionPath>,
    request: Json<SearchMatrixRequest>,
    params: Query<ReadParams>,
    service_config: web::Data<ServiceConfig>,
    ActixAccess(access): ActixAccess,
) -> impl Responder {
    let SearchMatrixRequest {
        search_request,
        shard_key,
    } = request.into_inner();

    // Strict-mode verification must succeed before dispatch.
    let pass = match check_strict_mode(
        &search_request,
        params.timeout_as_secs(),
        &collection.name,
        &dispatcher,
        &access,
    )
    .await
    {
        Ok(pass) => pass,
        Err(err) => return process_response_error(err, Instant::now(), None),
    };

    // No shard key means the request fans out to all shards.
    let shard_selection = match shard_key {
        None => ShardSelectorInternal::All,
        Some(shard_keys) => shard_keys.into(),
    };

    // Hardware usage metering; reported in the REST response when enabled.
    let request_hw_counter = get_request_hardware_counter(
        &dispatcher,
        collection.name.clone(),
        service_config.hardware_reporting(),
    );
    let timing = Instant::now();

    let response = do_search_points_matrix(
        dispatcher.toc(&access, &pass),
        &collection.name,
        CollectionSearchMatrixRequest::from(search_request),
        params.consistency,
        shard_selection,
        access,
        params.timeout(),
        request_hw_counter.get_counter(),
    )
    .await
    // Same computation as the "offsets" endpoint; only the response shape differs.
    .map(SearchMatrixPairsResponse::from);

    process_response(response, timing, request_hw_counter.to_rest_api())
}
|
270 |
+
|
271 |
+
/// REST handler: pairwise distance matrix between sampled points,
/// returned in compact "offsets" form.
#[post("/collections/{name}/points/search/matrix/offsets")]
async fn search_points_matrix_offsets(
    dispatcher: web::Data<Dispatcher>,
    collection: Path<CollectionPath>,
    request: Json<SearchMatrixRequest>,
    params: Query<ReadParams>,
    service_config: web::Data<ServiceConfig>,
    ActixAccess(access): ActixAccess,
) -> impl Responder {
    let SearchMatrixRequest {
        search_request,
        shard_key,
    } = request.into_inner();

    // Strict-mode verification must succeed before dispatch.
    let pass = match check_strict_mode(
        &search_request,
        params.timeout_as_secs(),
        &collection.name,
        &dispatcher,
        &access,
    )
    .await
    {
        Ok(pass) => pass,
        Err(err) => return process_response_error(err, Instant::now(), None),
    };

    // No shard key means the request fans out to all shards.
    let shard_selection = match shard_key {
        None => ShardSelectorInternal::All,
        Some(shard_keys) => shard_keys.into(),
    };

    // Hardware usage metering; reported in the REST response when enabled.
    let request_hw_counter = get_request_hardware_counter(
        &dispatcher,
        collection.name.clone(),
        service_config.hardware_reporting(),
    );
    let timing = Instant::now();

    let response = do_search_points_matrix(
        dispatcher.toc(&access, &pass),
        &collection.name,
        CollectionSearchMatrixRequest::from(search_request),
        params.consistency,
        shard_selection,
        access,
        params.timeout(),
        request_hw_counter.get_counter(),
    )
    .await
    // Same computation as the "pairs" endpoint; only the response shape differs.
    .map(SearchMatrixOffsetsResponse::from);

    process_response(response, timing, request_hw_counter.to_rest_api())
}
|
325 |
+
|
326 |
+
// Configure services
|
327 |
+
pub fn config_search_api(cfg: &mut web::ServiceConfig) {
|
328 |
+
cfg.service(search_points)
|
329 |
+
.service(batch_search_points)
|
330 |
+
.service(search_point_groups)
|
331 |
+
.service(search_points_matrix_pairs)
|
332 |
+
.service(search_points_matrix_offsets);
|
333 |
+
}
|
src/actix/api/service_api.rs
ADDED
@@ -0,0 +1,217 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
use std::future::Future;
|
2 |
+
use std::sync::Arc;
|
3 |
+
|
4 |
+
use actix_web::http::header::ContentType;
|
5 |
+
use actix_web::http::StatusCode;
|
6 |
+
use actix_web::rt::time::Instant;
|
7 |
+
use actix_web::web::Query;
|
8 |
+
use actix_web::{get, post, web, HttpResponse, Responder};
|
9 |
+
use actix_web_validator::Json;
|
10 |
+
use collection::operations::verification::new_unchecked_verification_pass;
|
11 |
+
use common::types::{DetailsLevel, TelemetryDetail};
|
12 |
+
use schemars::JsonSchema;
|
13 |
+
use segment::common::anonymize::Anonymize;
|
14 |
+
use serde::{Deserialize, Serialize};
|
15 |
+
use storage::content_manager::errors::StorageError;
|
16 |
+
use storage::dispatcher::Dispatcher;
|
17 |
+
use storage::rbac::AccessRequirements;
|
18 |
+
use tokio::sync::Mutex;
|
19 |
+
|
20 |
+
use crate::actix::auth::ActixAccess;
|
21 |
+
use crate::actix::helpers::{self, process_response_error};
|
22 |
+
use crate::common::health;
|
23 |
+
use crate::common::helpers::LocksOption;
|
24 |
+
use crate::common::metrics::MetricsData;
|
25 |
+
use crate::common::stacktrace::get_stack_trace;
|
26 |
+
use crate::common::telemetry::TelemetryCollector;
|
27 |
+
use crate::tracing;
|
28 |
+
|
29 |
+
/// Query parameters of the `/telemetry` endpoint.
#[derive(Deserialize, Serialize, JsonSchema)]
pub struct TelemetryParam {
    // When true, identifying details are stripped from the report.
    pub anonymize: Option<bool>,
    // Verbosity level; absent value maps to the most coarse level.
    pub details_level: Option<usize>,
}
|
34 |
+
|
35 |
+
/// REST handler: report service telemetry at the requested detail level,
/// optionally anonymized.
#[get("/telemetry")]
fn telemetry(
    telemetry_collector: web::Data<Mutex<TelemetryCollector>>,
    params: Query<TelemetryParam>,
    ActixAccess(access): ActixAccess,
) -> impl Future<Output = HttpResponse> {
    helpers::time(async move {
        // Any global access is sufficient to read telemetry.
        access.check_global_access(AccessRequirements::new())?;
        let anonymize = params.anonymize.unwrap_or(false);
        // Absent details level falls back to the most coarse level.
        let details_level = params
            .details_level
            .map_or(DetailsLevel::Level0, Into::into);
        let detail = TelemetryDetail {
            level: details_level,
            histograms: false,
        };
        let telemetry_collector = telemetry_collector.lock().await;
        let telemetry_data = telemetry_collector.prepare_data(&access, detail).await;
        // Strip identifying details when requested.
        let telemetry_data = if anonymize {
            telemetry_data.anonymize()
        } else {
            telemetry_data
        };
        Ok(telemetry_data)
    })
}
|
61 |
+
|
62 |
+
/// Query parameters of the `/metrics` endpoint.
#[derive(Deserialize, Serialize, JsonSchema)]
pub struct MetricsParam {
    // When true, identifying details are stripped before formatting.
    pub anonymize: Option<bool>,
}
|
66 |
+
|
67 |
+
/// REST handler: Prometheus-style metrics built from level-1 telemetry
/// with histograms enabled.
#[get("/metrics")]
async fn metrics(
    telemetry_collector: web::Data<Mutex<TelemetryCollector>>,
    params: Query<MetricsParam>,
    ActixAccess(access): ActixAccess,
) -> HttpResponse {
    // Any global access is sufficient to read metrics.
    if let Err(err) = access.check_global_access(AccessRequirements::new()) {
        return process_response_error(err, Instant::now(), None);
    }

    let anonymize = params.anonymize.unwrap_or(false);
    let telemetry_collector = telemetry_collector.lock().await;
    let telemetry_data = telemetry_collector
        .prepare_data(
            &access,
            TelemetryDetail {
                level: DetailsLevel::Level1,
                histograms: true,
            },
        )
        .await;
    // Strip identifying details when requested.
    let telemetry_data = if anonymize {
        telemetry_data.anonymize()
    } else {
        telemetry_data
    };

    HttpResponse::Ok()
        .content_type(ContentType::plaintext())
        .body(MetricsData::from(telemetry_data).format_metrics())
}
|
98 |
+
|
99 |
+
/// REST handler: update the global write-lock settings and return the
/// settings that were in effect before the update.
#[post("/locks")]
fn put_locks(
    dispatcher: web::Data<Dispatcher>,
    locks_option: Json<LocksOption>,
    ActixAccess(access): ActixAccess,
) -> impl Future<Output = HttpResponse> {
    // Not a collection level request.
    let pass = new_unchecked_verification_pass();

    helpers::time(async move {
        let toc = dispatcher.toc(&access, &pass);
        // Changing locks requires manage-level global access.
        access.check_global_access(AccessRequirements::new().manage())?;
        // Snapshot the current state *before* applying the new settings,
        // so the response reports the previous values.
        let result = LocksOption {
            write: toc.is_write_locked(),
            error_message: toc.get_lock_error_message(),
        };
        toc.set_locks(locks_option.write, locks_option.error_message.clone());
        Ok(result)
    })
}
|
119 |
+
|
120 |
+
#[get("/locks")]
|
121 |
+
fn get_locks(
|
122 |
+
dispatcher: web::Data<Dispatcher>,
|
123 |
+
ActixAccess(access): ActixAccess,
|
124 |
+
) -> impl Future<Output = HttpResponse> {
|
125 |
+
// Not a collection level request.
|
126 |
+
let pass = new_unchecked_verification_pass();
|
127 |
+
|
128 |
+
helpers::time(async move {
|
129 |
+
access.check_global_access(AccessRequirements::new())?;
|
130 |
+
let toc = dispatcher.toc(&access, &pass);
|
131 |
+
let result = LocksOption {
|
132 |
+
write: toc.is_write_locked(),
|
133 |
+
error_message: toc.get_lock_error_message(),
|
134 |
+
};
|
135 |
+
Ok(result)
|
136 |
+
})
|
137 |
+
}
|
138 |
+
|
139 |
+
#[get("/stacktrace")]
|
140 |
+
fn get_stacktrace(ActixAccess(access): ActixAccess) -> impl Future<Output = HttpResponse> {
|
141 |
+
helpers::time(async move {
|
142 |
+
access.check_global_access(AccessRequirements::new().manage())?;
|
143 |
+
Ok(get_stack_trace())
|
144 |
+
})
|
145 |
+
}
|
146 |
+
|
147 |
+
#[get("/healthz")]
|
148 |
+
async fn healthz() -> impl Responder {
|
149 |
+
kubernetes_healthz()
|
150 |
+
}
|
151 |
+
|
152 |
+
#[get("/livez")]
|
153 |
+
async fn livez() -> impl Responder {
|
154 |
+
kubernetes_healthz()
|
155 |
+
}
|
156 |
+
|
157 |
+
#[get("/readyz")]
|
158 |
+
async fn readyz(health_checker: web::Data<Option<Arc<health::HealthChecker>>>) -> impl Responder {
|
159 |
+
let is_ready = match health_checker.as_ref() {
|
160 |
+
Some(health_checker) => health_checker.check_ready().await,
|
161 |
+
None => true,
|
162 |
+
};
|
163 |
+
|
164 |
+
let (status, body) = if is_ready {
|
165 |
+
(StatusCode::OK, "all shards are ready")
|
166 |
+
} else {
|
167 |
+
(StatusCode::SERVICE_UNAVAILABLE, "some shards are not ready")
|
168 |
+
};
|
169 |
+
|
170 |
+
HttpResponse::build(status)
|
171 |
+
.content_type(ContentType::plaintext())
|
172 |
+
.body(body)
|
173 |
+
}
|
174 |
+
|
175 |
+
/// Basic Kubernetes healthz endpoint
|
176 |
+
fn kubernetes_healthz() -> impl Responder {
|
177 |
+
HttpResponse::Ok()
|
178 |
+
.content_type(ContentType::plaintext())
|
179 |
+
.body("healthz check passed")
|
180 |
+
}
|
181 |
+
|
182 |
+
#[get("/logger")]
|
183 |
+
async fn get_logger_config(handle: web::Data<tracing::LoggerHandle>) -> impl Responder {
|
184 |
+
let timing = Instant::now();
|
185 |
+
let result = handle.get_config().await;
|
186 |
+
helpers::process_response(Ok(result), timing, None)
|
187 |
+
}
|
188 |
+
|
189 |
+
#[post("/logger")]
|
190 |
+
async fn update_logger_config(
|
191 |
+
handle: web::Data<tracing::LoggerHandle>,
|
192 |
+
config: web::Json<tracing::LoggerConfig>,
|
193 |
+
) -> impl Responder {
|
194 |
+
let timing = Instant::now();
|
195 |
+
|
196 |
+
let result = handle
|
197 |
+
.update_config(config.into_inner())
|
198 |
+
.await
|
199 |
+
.map(|_| true)
|
200 |
+
.map_err(|err| StorageError::service_error(err.to_string()));
|
201 |
+
|
202 |
+
helpers::process_response(result, timing, None)
|
203 |
+
}
|
204 |
+
|
205 |
+
// Configure services
|
206 |
+
pub fn config_service_api(cfg: &mut web::ServiceConfig) {
|
207 |
+
cfg.service(telemetry)
|
208 |
+
.service(metrics)
|
209 |
+
.service(put_locks)
|
210 |
+
.service(get_locks)
|
211 |
+
.service(get_stacktrace)
|
212 |
+
.service(healthz)
|
213 |
+
.service(livez)
|
214 |
+
.service(readyz)
|
215 |
+
.service(get_logger_config)
|
216 |
+
.service(update_logger_config);
|
217 |
+
}
|
src/actix/api/shards_api.rs
ADDED
@@ -0,0 +1,80 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
use actix_web::{post, put, web, Responder};
|
2 |
+
use actix_web_validator::{Json, Path, Query};
|
3 |
+
use collection::operations::cluster_ops::{
|
4 |
+
ClusterOperations, CreateShardingKey, CreateShardingKeyOperation, DropShardingKey,
|
5 |
+
DropShardingKeyOperation,
|
6 |
+
};
|
7 |
+
use storage::dispatcher::Dispatcher;
|
8 |
+
use tokio::time::Instant;
|
9 |
+
|
10 |
+
use crate::actix::api::collections_api::WaitTimeout;
|
11 |
+
use crate::actix::api::CollectionPath;
|
12 |
+
use crate::actix::auth::ActixAccess;
|
13 |
+
use crate::actix::helpers::process_response;
|
14 |
+
use crate::common::collections::do_update_collection_cluster;
|
15 |
+
|
16 |
+
// ToDo: introduce API for listing shard keys
|
17 |
+
|
18 |
+
#[put("/collections/{name}/shards")]
|
19 |
+
async fn create_shard_key(
|
20 |
+
dispatcher: web::Data<Dispatcher>,
|
21 |
+
collection: Path<CollectionPath>,
|
22 |
+
request: Json<CreateShardingKey>,
|
23 |
+
Query(query): Query<WaitTimeout>,
|
24 |
+
ActixAccess(access): ActixAccess,
|
25 |
+
) -> impl Responder {
|
26 |
+
let timing = Instant::now();
|
27 |
+
let wait_timeout = query.timeout();
|
28 |
+
let dispatcher = dispatcher.into_inner();
|
29 |
+
|
30 |
+
let request = request.into_inner();
|
31 |
+
|
32 |
+
let operation = ClusterOperations::CreateShardingKey(CreateShardingKeyOperation {
|
33 |
+
create_sharding_key: request,
|
34 |
+
});
|
35 |
+
|
36 |
+
let response = do_update_collection_cluster(
|
37 |
+
&dispatcher,
|
38 |
+
collection.name.clone(),
|
39 |
+
operation,
|
40 |
+
access,
|
41 |
+
wait_timeout,
|
42 |
+
)
|
43 |
+
.await;
|
44 |
+
|
45 |
+
process_response(response, timing, None)
|
46 |
+
}
|
47 |
+
|
48 |
+
#[post("/collections/{name}/shards/delete")]
|
49 |
+
async fn delete_shard_key(
|
50 |
+
dispatcher: web::Data<Dispatcher>,
|
51 |
+
collection: Path<CollectionPath>,
|
52 |
+
request: Json<DropShardingKey>,
|
53 |
+
Query(query): Query<WaitTimeout>,
|
54 |
+
ActixAccess(access): ActixAccess,
|
55 |
+
) -> impl Responder {
|
56 |
+
let timing = Instant::now();
|
57 |
+
let wait_timeout = query.timeout();
|
58 |
+
|
59 |
+
let dispatcher = dispatcher.into_inner();
|
60 |
+
let request = request.into_inner();
|
61 |
+
|
62 |
+
let operation = ClusterOperations::DropShardingKey(DropShardingKeyOperation {
|
63 |
+
drop_sharding_key: request,
|
64 |
+
});
|
65 |
+
|
66 |
+
let response = do_update_collection_cluster(
|
67 |
+
&dispatcher,
|
68 |
+
collection.name.clone(),
|
69 |
+
operation,
|
70 |
+
access,
|
71 |
+
wait_timeout,
|
72 |
+
)
|
73 |
+
.await;
|
74 |
+
|
75 |
+
process_response(response, timing, None)
|
76 |
+
}
|
77 |
+
|
78 |
+
pub fn config_shards_api(cfg: &mut web::ServiceConfig) {
|
79 |
+
cfg.service(create_shard_key).service(delete_shard_key);
|
80 |
+
}
|
src/actix/api/snapshot_api.rs
ADDED
@@ -0,0 +1,585 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
use std::path::Path;
|
2 |
+
|
3 |
+
use actix_multipart::form::tempfile::TempFile;
|
4 |
+
use actix_multipart::form::MultipartForm;
|
5 |
+
use actix_web::{delete, get, post, put, web, Responder, Result};
|
6 |
+
use actix_web_validator as valid;
|
7 |
+
use collection::common::file_utils::move_file;
|
8 |
+
use collection::common::sha_256::{hash_file, hashes_equal};
|
9 |
+
use collection::common::snapshot_stream::SnapshotStream;
|
10 |
+
use collection::operations::snapshot_ops::{
|
11 |
+
ShardSnapshotRecover, SnapshotPriority, SnapshotRecover,
|
12 |
+
};
|
13 |
+
use collection::operations::verification::new_unchecked_verification_pass;
|
14 |
+
use collection::shards::shard::ShardId;
|
15 |
+
use futures::{FutureExt as _, TryFutureExt as _};
|
16 |
+
use reqwest::Url;
|
17 |
+
use schemars::JsonSchema;
|
18 |
+
use serde::{Deserialize, Serialize};
|
19 |
+
use storage::content_manager::errors::StorageError;
|
20 |
+
use storage::content_manager::snapshots::recover::do_recover_from_snapshot;
|
21 |
+
use storage::content_manager::snapshots::{
|
22 |
+
do_create_full_snapshot, do_delete_collection_snapshot, do_delete_full_snapshot,
|
23 |
+
do_list_full_snapshots,
|
24 |
+
};
|
25 |
+
use storage::content_manager::toc::TableOfContent;
|
26 |
+
use storage::dispatcher::Dispatcher;
|
27 |
+
use storage::rbac::{Access, AccessRequirements};
|
28 |
+
use uuid::Uuid;
|
29 |
+
use validator::Validate;
|
30 |
+
|
31 |
+
use super::{CollectionPath, StrictCollectionPath};
|
32 |
+
use crate::actix::auth::ActixAccess;
|
33 |
+
use crate::actix::helpers::{self, HttpError};
|
34 |
+
use crate::common;
|
35 |
+
use crate::common::collections::*;
|
36 |
+
use crate::common::http_client::HttpClient;
|
37 |
+
|
38 |
+
#[derive(Deserialize, Serialize, JsonSchema, Validate)]
|
39 |
+
pub struct SnapshotUploadingParam {
|
40 |
+
pub wait: Option<bool>,
|
41 |
+
pub priority: Option<SnapshotPriority>,
|
42 |
+
|
43 |
+
/// Optional SHA256 checksum to verify snapshot integrity before recovery.
|
44 |
+
#[serde(default)]
|
45 |
+
#[validate(custom(function = "::common::validation::validate_sha256_hash"))]
|
46 |
+
pub checksum: Option<String>,
|
47 |
+
}
|
48 |
+
|
49 |
+
#[derive(Deserialize, Serialize, JsonSchema, Validate)]
|
50 |
+
pub struct SnapshottingParam {
|
51 |
+
pub wait: Option<bool>,
|
52 |
+
}
|
53 |
+
|
54 |
+
#[derive(MultipartForm)]
|
55 |
+
pub struct SnapshottingForm {
|
56 |
+
snapshot: TempFile,
|
57 |
+
}
|
58 |
+
|
59 |
+
// Actix specific code
|
60 |
+
pub async fn do_get_full_snapshot(
|
61 |
+
toc: &TableOfContent,
|
62 |
+
access: Access,
|
63 |
+
snapshot_name: &str,
|
64 |
+
) -> Result<SnapshotStream, HttpError> {
|
65 |
+
access.check_global_access(AccessRequirements::new())?;
|
66 |
+
let snapshots_storage_manager = toc.get_snapshots_storage_manager()?;
|
67 |
+
let snapshot_path =
|
68 |
+
snapshots_storage_manager.get_full_snapshot_path(toc.snapshots_path(), snapshot_name)?;
|
69 |
+
let snapshot_stream = snapshots_storage_manager
|
70 |
+
.get_snapshot_stream(&snapshot_path)
|
71 |
+
.await?;
|
72 |
+
Ok(snapshot_stream)
|
73 |
+
}
|
74 |
+
|
75 |
+
pub async fn do_save_uploaded_snapshot(
|
76 |
+
toc: &TableOfContent,
|
77 |
+
collection_name: &str,
|
78 |
+
snapshot: TempFile,
|
79 |
+
) -> Result<Url, StorageError> {
|
80 |
+
let filename = snapshot
|
81 |
+
.file_name
|
82 |
+
// Sanitize the file name:
|
83 |
+
// - only take the top level path (no directories such as ../)
|
84 |
+
// - require the file name to be valid UTF-8
|
85 |
+
.and_then(|x| {
|
86 |
+
Path::new(&x)
|
87 |
+
.file_name()
|
88 |
+
.map(|filename| filename.to_owned())
|
89 |
+
})
|
90 |
+
.and_then(|x| x.to_str().map(|x| x.to_owned()))
|
91 |
+
.unwrap_or_else(|| Uuid::new_v4().to_string());
|
92 |
+
let collection_snapshot_path = toc.snapshots_path_for_collection(collection_name);
|
93 |
+
if !collection_snapshot_path.exists() {
|
94 |
+
log::debug!(
|
95 |
+
"Creating missing collection snapshots directory for {}",
|
96 |
+
collection_name
|
97 |
+
);
|
98 |
+
toc.create_snapshots_path(collection_name).await?;
|
99 |
+
}
|
100 |
+
|
101 |
+
let path = collection_snapshot_path.join(filename);
|
102 |
+
|
103 |
+
move_file(snapshot.file.path(), &path).await?;
|
104 |
+
|
105 |
+
let absolute_path = path.canonicalize()?;
|
106 |
+
|
107 |
+
let snapshot_location = Url::from_file_path(&absolute_path).map_err(|_| {
|
108 |
+
StorageError::service_error(format!(
|
109 |
+
"Failed to convert path to URL: {}",
|
110 |
+
absolute_path.display()
|
111 |
+
))
|
112 |
+
})?;
|
113 |
+
|
114 |
+
Ok(snapshot_location)
|
115 |
+
}
|
116 |
+
|
117 |
+
// Actix specific code
|
118 |
+
pub async fn do_get_snapshot(
|
119 |
+
toc: &TableOfContent,
|
120 |
+
access: Access,
|
121 |
+
collection_name: &str,
|
122 |
+
snapshot_name: &str,
|
123 |
+
) -> Result<SnapshotStream, HttpError> {
|
124 |
+
let collection_pass =
|
125 |
+
access.check_collection_access(collection_name, AccessRequirements::new().whole())?;
|
126 |
+
let collection: tokio::sync::RwLockReadGuard<collection::collection::Collection> =
|
127 |
+
toc.get_collection(&collection_pass).await?;
|
128 |
+
let snapshot_storage_manager = collection.get_snapshots_storage_manager()?;
|
129 |
+
let snapshot_path =
|
130 |
+
snapshot_storage_manager.get_snapshot_path(collection.snapshots_path(), snapshot_name)?;
|
131 |
+
let snapshot_stream = snapshot_storage_manager
|
132 |
+
.get_snapshot_stream(&snapshot_path)
|
133 |
+
.await?;
|
134 |
+
Ok(snapshot_stream)
|
135 |
+
}
|
136 |
+
|
137 |
+
#[get("/collections/{name}/snapshots")]
|
138 |
+
async fn list_snapshots(
|
139 |
+
dispatcher: web::Data<Dispatcher>,
|
140 |
+
path: web::Path<String>,
|
141 |
+
ActixAccess(access): ActixAccess,
|
142 |
+
) -> impl Responder {
|
143 |
+
// Nothing to verify.
|
144 |
+
let pass = new_unchecked_verification_pass();
|
145 |
+
|
146 |
+
helpers::time(do_list_snapshots(
|
147 |
+
dispatcher.toc(&access, &pass),
|
148 |
+
access,
|
149 |
+
&path,
|
150 |
+
))
|
151 |
+
.await
|
152 |
+
}
|
153 |
+
|
154 |
+
#[post("/collections/{name}/snapshots")]
|
155 |
+
async fn create_snapshot(
|
156 |
+
dispatcher: web::Data<Dispatcher>,
|
157 |
+
path: web::Path<String>,
|
158 |
+
params: valid::Query<SnapshottingParam>,
|
159 |
+
ActixAccess(access): ActixAccess,
|
160 |
+
) -> impl Responder {
|
161 |
+
// Nothing to verify.
|
162 |
+
let pass = new_unchecked_verification_pass();
|
163 |
+
|
164 |
+
let collection_name = path.into_inner();
|
165 |
+
|
166 |
+
let future = async move {
|
167 |
+
do_create_snapshot(
|
168 |
+
dispatcher.toc(&access, &pass).clone(),
|
169 |
+
access,
|
170 |
+
&collection_name,
|
171 |
+
)
|
172 |
+
.await
|
173 |
+
};
|
174 |
+
|
175 |
+
helpers::time_or_accept(future, params.wait.unwrap_or(true)).await
|
176 |
+
}
|
177 |
+
|
178 |
+
#[post("/collections/{name}/snapshots/upload")]
|
179 |
+
async fn upload_snapshot(
|
180 |
+
dispatcher: web::Data<Dispatcher>,
|
181 |
+
http_client: web::Data<HttpClient>,
|
182 |
+
collection: valid::Path<StrictCollectionPath>,
|
183 |
+
MultipartForm(form): MultipartForm<SnapshottingForm>,
|
184 |
+
params: valid::Query<SnapshotUploadingParam>,
|
185 |
+
ActixAccess(access): ActixAccess,
|
186 |
+
) -> impl Responder {
|
187 |
+
let wait = params.wait;
|
188 |
+
|
189 |
+
// Nothing to verify.
|
190 |
+
let pass = new_unchecked_verification_pass();
|
191 |
+
|
192 |
+
let future = async move {
|
193 |
+
let snapshot = form.snapshot;
|
194 |
+
|
195 |
+
access.check_global_access(AccessRequirements::new().manage())?;
|
196 |
+
|
197 |
+
if let Some(checksum) = ¶ms.checksum {
|
198 |
+
let snapshot_checksum = hash_file(snapshot.file.path()).await?;
|
199 |
+
if !hashes_equal(snapshot_checksum.as_str(), checksum.as_str()) {
|
200 |
+
return Err(StorageError::checksum_mismatch(snapshot_checksum, checksum));
|
201 |
+
}
|
202 |
+
}
|
203 |
+
|
204 |
+
let snapshot_location =
|
205 |
+
do_save_uploaded_snapshot(dispatcher.toc(&access, &pass), &collection.name, snapshot)
|
206 |
+
.await?;
|
207 |
+
|
208 |
+
// Snapshot is a local file, we do not need an API key for that
|
209 |
+
let http_client = http_client.client(None)?;
|
210 |
+
|
211 |
+
let snapshot_recover = SnapshotRecover {
|
212 |
+
location: snapshot_location,
|
213 |
+
priority: params.priority,
|
214 |
+
checksum: None,
|
215 |
+
api_key: None,
|
216 |
+
};
|
217 |
+
|
218 |
+
do_recover_from_snapshot(
|
219 |
+
dispatcher.get_ref(),
|
220 |
+
&collection.name,
|
221 |
+
snapshot_recover,
|
222 |
+
access,
|
223 |
+
http_client,
|
224 |
+
)
|
225 |
+
.await
|
226 |
+
};
|
227 |
+
|
228 |
+
helpers::time_or_accept(future, wait.unwrap_or(true)).await
|
229 |
+
}
|
230 |
+
|
231 |
+
#[put("/collections/{name}/snapshots/recover")]
|
232 |
+
async fn recover_from_snapshot(
|
233 |
+
dispatcher: web::Data<Dispatcher>,
|
234 |
+
http_client: web::Data<HttpClient>,
|
235 |
+
collection: valid::Path<CollectionPath>,
|
236 |
+
request: valid::Json<SnapshotRecover>,
|
237 |
+
params: valid::Query<SnapshottingParam>,
|
238 |
+
ActixAccess(access): ActixAccess,
|
239 |
+
) -> impl Responder {
|
240 |
+
let future = async move {
|
241 |
+
let snapshot_recover = request.into_inner();
|
242 |
+
let http_client = http_client.client(snapshot_recover.api_key.as_deref())?;
|
243 |
+
|
244 |
+
do_recover_from_snapshot(
|
245 |
+
dispatcher.get_ref(),
|
246 |
+
&collection.name,
|
247 |
+
snapshot_recover,
|
248 |
+
access,
|
249 |
+
http_client,
|
250 |
+
)
|
251 |
+
.await
|
252 |
+
};
|
253 |
+
|
254 |
+
helpers::time_or_accept(future, params.wait.unwrap_or(true)).await
|
255 |
+
}
|
256 |
+
|
257 |
+
#[get("/collections/{name}/snapshots/{snapshot_name}")]
|
258 |
+
async fn get_snapshot(
|
259 |
+
dispatcher: web::Data<Dispatcher>,
|
260 |
+
path: web::Path<(String, String)>,
|
261 |
+
ActixAccess(access): ActixAccess,
|
262 |
+
) -> impl Responder {
|
263 |
+
// Nothing to verify.
|
264 |
+
let pass = new_unchecked_verification_pass();
|
265 |
+
|
266 |
+
let (collection_name, snapshot_name) = path.into_inner();
|
267 |
+
do_get_snapshot(
|
268 |
+
dispatcher.toc(&access, &pass),
|
269 |
+
access,
|
270 |
+
&collection_name,
|
271 |
+
&snapshot_name,
|
272 |
+
)
|
273 |
+
.await
|
274 |
+
}
|
275 |
+
|
276 |
+
#[get("/snapshots")]
|
277 |
+
async fn list_full_snapshots(
|
278 |
+
dispatcher: web::Data<Dispatcher>,
|
279 |
+
ActixAccess(access): ActixAccess,
|
280 |
+
) -> impl Responder {
|
281 |
+
// nothing to verify.
|
282 |
+
let pass = new_unchecked_verification_pass();
|
283 |
+
|
284 |
+
helpers::time(do_list_full_snapshots(
|
285 |
+
dispatcher.toc(&access, &pass),
|
286 |
+
access,
|
287 |
+
))
|
288 |
+
.await
|
289 |
+
}
|
290 |
+
|
291 |
+
#[post("/snapshots")]
|
292 |
+
async fn create_full_snapshot(
|
293 |
+
dispatcher: web::Data<Dispatcher>,
|
294 |
+
params: valid::Query<SnapshottingParam>,
|
295 |
+
ActixAccess(access): ActixAccess,
|
296 |
+
) -> impl Responder {
|
297 |
+
let future = async move { do_create_full_snapshot(dispatcher.get_ref(), access).await };
|
298 |
+
helpers::time_or_accept(future, params.wait.unwrap_or(true)).await
|
299 |
+
}
|
300 |
+
|
301 |
+
#[get("/snapshots/{snapshot_name}")]
|
302 |
+
async fn get_full_snapshot(
|
303 |
+
dispatcher: web::Data<Dispatcher>,
|
304 |
+
path: web::Path<String>,
|
305 |
+
ActixAccess(access): ActixAccess,
|
306 |
+
) -> impl Responder {
|
307 |
+
// nothing to verify.
|
308 |
+
let pass = new_unchecked_verification_pass();
|
309 |
+
|
310 |
+
let snapshot_name = path.into_inner();
|
311 |
+
do_get_full_snapshot(dispatcher.toc(&access, &pass), access, &snapshot_name).await
|
312 |
+
}
|
313 |
+
|
314 |
+
#[delete("/snapshots/{snapshot_name}")]
|
315 |
+
async fn delete_full_snapshot(
|
316 |
+
dispatcher: web::Data<Dispatcher>,
|
317 |
+
path: web::Path<String>,
|
318 |
+
params: valid::Query<SnapshottingParam>,
|
319 |
+
ActixAccess(access): ActixAccess,
|
320 |
+
) -> impl Responder {
|
321 |
+
let future = async move {
|
322 |
+
let snapshot_name = path.into_inner();
|
323 |
+
do_delete_full_snapshot(dispatcher.get_ref(), access, &snapshot_name).await
|
324 |
+
};
|
325 |
+
|
326 |
+
helpers::time_or_accept(future, params.wait.unwrap_or(true)).await
|
327 |
+
}
|
328 |
+
|
329 |
+
#[delete("/collections/{name}/snapshots/{snapshot_name}")]
|
330 |
+
async fn delete_collection_snapshot(
|
331 |
+
dispatcher: web::Data<Dispatcher>,
|
332 |
+
path: web::Path<(String, String)>,
|
333 |
+
params: valid::Query<SnapshottingParam>,
|
334 |
+
ActixAccess(access): ActixAccess,
|
335 |
+
) -> impl Responder {
|
336 |
+
let future = async move {
|
337 |
+
let (collection_name, snapshot_name) = path.into_inner();
|
338 |
+
|
339 |
+
do_delete_collection_snapshot(
|
340 |
+
dispatcher.get_ref(),
|
341 |
+
access,
|
342 |
+
&collection_name,
|
343 |
+
&snapshot_name,
|
344 |
+
)
|
345 |
+
.await
|
346 |
+
};
|
347 |
+
|
348 |
+
helpers::time_or_accept(future, params.wait.unwrap_or(true)).await
|
349 |
+
}
|
350 |
+
|
351 |
+
#[get("/collections/{collection}/shards/{shard}/snapshots")]
|
352 |
+
async fn list_shard_snapshots(
|
353 |
+
dispatcher: web::Data<Dispatcher>,
|
354 |
+
path: web::Path<(String, ShardId)>,
|
355 |
+
ActixAccess(access): ActixAccess,
|
356 |
+
) -> impl Responder {
|
357 |
+
// nothing to verify.
|
358 |
+
let pass = new_unchecked_verification_pass();
|
359 |
+
|
360 |
+
let (collection, shard) = path.into_inner();
|
361 |
+
|
362 |
+
let future = common::snapshots::list_shard_snapshots(
|
363 |
+
dispatcher.toc(&access, &pass).clone(),
|
364 |
+
access,
|
365 |
+
collection,
|
366 |
+
shard,
|
367 |
+
)
|
368 |
+
.map_err(Into::into);
|
369 |
+
|
370 |
+
helpers::time(future).await
|
371 |
+
}
|
372 |
+
|
373 |
+
#[post("/collections/{collection}/shards/{shard}/snapshots")]
|
374 |
+
async fn create_shard_snapshot(
|
375 |
+
dispatcher: web::Data<Dispatcher>,
|
376 |
+
path: web::Path<(String, ShardId)>,
|
377 |
+
query: web::Query<SnapshottingParam>,
|
378 |
+
ActixAccess(access): ActixAccess,
|
379 |
+
) -> impl Responder {
|
380 |
+
// nothing to verify.
|
381 |
+
let pass = new_unchecked_verification_pass();
|
382 |
+
|
383 |
+
let (collection, shard) = path.into_inner();
|
384 |
+
let future = common::snapshots::create_shard_snapshot(
|
385 |
+
dispatcher.toc(&access, &pass).clone(),
|
386 |
+
access,
|
387 |
+
collection,
|
388 |
+
shard,
|
389 |
+
);
|
390 |
+
|
391 |
+
helpers::time_or_accept(future, query.wait.unwrap_or(true)).await
|
392 |
+
}
|
393 |
+
|
394 |
+
#[get("/collections/{collection}/shards/{shard}/snapshot")]
|
395 |
+
async fn stream_shard_snapshot(
|
396 |
+
dispatcher: web::Data<Dispatcher>,
|
397 |
+
path: web::Path<(String, ShardId)>,
|
398 |
+
ActixAccess(access): ActixAccess,
|
399 |
+
) -> Result<SnapshotStream, HttpError> {
|
400 |
+
// nothing to verify.
|
401 |
+
let pass = new_unchecked_verification_pass();
|
402 |
+
|
403 |
+
let (collection, shard) = path.into_inner();
|
404 |
+
Ok(common::snapshots::stream_shard_snapshot(
|
405 |
+
dispatcher.toc(&access, &pass).clone(),
|
406 |
+
access,
|
407 |
+
collection,
|
408 |
+
shard,
|
409 |
+
)
|
410 |
+
.await?)
|
411 |
+
}
|
412 |
+
|
413 |
+
// TODO: `PUT` (same as `recover_from_snapshot`) or `POST`!?
|
414 |
+
#[put("/collections/{collection}/shards/{shard}/snapshots/recover")]
|
415 |
+
async fn recover_shard_snapshot(
|
416 |
+
dispatcher: web::Data<Dispatcher>,
|
417 |
+
http_client: web::Data<HttpClient>,
|
418 |
+
path: web::Path<(String, ShardId)>,
|
419 |
+
query: web::Query<SnapshottingParam>,
|
420 |
+
web::Json(request): web::Json<ShardSnapshotRecover>,
|
421 |
+
ActixAccess(access): ActixAccess,
|
422 |
+
) -> impl Responder {
|
423 |
+
// nothing to verify.
|
424 |
+
let pass = new_unchecked_verification_pass();
|
425 |
+
|
426 |
+
let future = async move {
|
427 |
+
let (collection, shard) = path.into_inner();
|
428 |
+
|
429 |
+
common::snapshots::recover_shard_snapshot(
|
430 |
+
dispatcher.toc(&access, &pass).clone(),
|
431 |
+
access,
|
432 |
+
collection,
|
433 |
+
shard,
|
434 |
+
request.location,
|
435 |
+
request.priority.unwrap_or_default(),
|
436 |
+
request.checksum,
|
437 |
+
http_client.as_ref().clone(),
|
438 |
+
request.api_key,
|
439 |
+
)
|
440 |
+
.await?;
|
441 |
+
|
442 |
+
Ok(true)
|
443 |
+
};
|
444 |
+
|
445 |
+
helpers::time_or_accept(future, query.wait.unwrap_or(true)).await
|
446 |
+
}
|
447 |
+
|
448 |
+
// TODO: `POST` (same as `upload_snapshot`) or `PUT`!?
|
449 |
+
#[post("/collections/{collection}/shards/{shard}/snapshots/upload")]
|
450 |
+
async fn upload_shard_snapshot(
|
451 |
+
dispatcher: web::Data<Dispatcher>,
|
452 |
+
path: web::Path<(String, ShardId)>,
|
453 |
+
query: web::Query<SnapshotUploadingParam>,
|
454 |
+
MultipartForm(form): MultipartForm<SnapshottingForm>,
|
455 |
+
ActixAccess(access): ActixAccess,
|
456 |
+
) -> impl Responder {
|
457 |
+
// nothing to verify.
|
458 |
+
let pass = new_unchecked_verification_pass();
|
459 |
+
|
460 |
+
let (collection, shard) = path.into_inner();
|
461 |
+
let SnapshotUploadingParam {
|
462 |
+
wait,
|
463 |
+
priority,
|
464 |
+
checksum,
|
465 |
+
} = query.into_inner();
|
466 |
+
|
467 |
+
// - `recover_shard_snapshot_impl` is *not* cancel safe
|
468 |
+
// - but the task is *spawned* on the runtime and won't be cancelled, if request is cancelled
|
469 |
+
|
470 |
+
let future = cancel::future::spawn_cancel_on_drop(move |cancel| async move {
|
471 |
+
// TODO: Run this check before the multipart blob is uploaded
|
472 |
+
let collection_pass = access
|
473 |
+
.check_global_access(AccessRequirements::new().manage())?
|
474 |
+
.issue_pass(&collection);
|
475 |
+
|
476 |
+
if let Some(checksum) = checksum {
|
477 |
+
let snapshot_checksum = hash_file(form.snapshot.file.path()).await?;
|
478 |
+
if !hashes_equal(snapshot_checksum.as_str(), checksum.as_str()) {
|
479 |
+
return Err(StorageError::checksum_mismatch(snapshot_checksum, checksum));
|
480 |
+
}
|
481 |
+
}
|
482 |
+
|
483 |
+
let future = async {
|
484 |
+
let collection = dispatcher
|
485 |
+
.toc(&access, &pass)
|
486 |
+
.get_collection(&collection_pass)
|
487 |
+
.await?;
|
488 |
+
collection.assert_shard_exists(shard).await?;
|
489 |
+
|
490 |
+
Result::<_, StorageError>::Ok(collection)
|
491 |
+
};
|
492 |
+
|
493 |
+
let collection = cancel::future::cancel_on_token(cancel.clone(), future).await??;
|
494 |
+
|
495 |
+
// `recover_shard_snapshot_impl` is *not* cancel safe
|
496 |
+
common::snapshots::recover_shard_snapshot_impl(
|
497 |
+
dispatcher.toc(&access, &pass),
|
498 |
+
&collection,
|
499 |
+
shard,
|
500 |
+
form.snapshot.file.path(),
|
501 |
+
priority.unwrap_or_default(),
|
502 |
+
cancel,
|
503 |
+
)
|
504 |
+
.await?;
|
505 |
+
|
506 |
+
Ok(())
|
507 |
+
})
|
508 |
+
.map(|x| x.map_err(Into::into).and_then(|x| x));
|
509 |
+
|
510 |
+
helpers::time_or_accept(future, wait.unwrap_or(true)).await
|
511 |
+
}
|
512 |
+
|
513 |
+
#[get("/collections/{collection}/shards/{shard}/snapshots/{snapshot}")]
|
514 |
+
async fn download_shard_snapshot(
|
515 |
+
dispatcher: web::Data<Dispatcher>,
|
516 |
+
path: web::Path<(String, ShardId, String)>,
|
517 |
+
ActixAccess(access): ActixAccess,
|
518 |
+
) -> Result<impl Responder, HttpError> {
|
519 |
+
// nothing to verify.
|
520 |
+
let pass = new_unchecked_verification_pass();
|
521 |
+
|
522 |
+
let (collection, shard, snapshot) = path.into_inner();
|
523 |
+
let collection_pass =
|
524 |
+
access.check_collection_access(&collection, AccessRequirements::new().whole())?;
|
525 |
+
let collection = dispatcher
|
526 |
+
.toc(&access, &pass)
|
527 |
+
.get_collection(&collection_pass)
|
528 |
+
.await?;
|
529 |
+
let snapshots_storage_manager = collection.get_snapshots_storage_manager()?;
|
530 |
+
let snapshot_path = collection
|
531 |
+
.shards_holder()
|
532 |
+
.read()
|
533 |
+
.await
|
534 |
+
.get_shard_snapshot_path(collection.snapshots_path(), shard, &snapshot)
|
535 |
+
.await?;
|
536 |
+
let snapshot_stream = snapshots_storage_manager
|
537 |
+
.get_snapshot_stream(&snapshot_path)
|
538 |
+
.await?;
|
539 |
+
Ok(snapshot_stream)
|
540 |
+
}
|
541 |
+
|
542 |
+
#[delete("/collections/{collection}/shards/{shard}/snapshots/{snapshot}")]
|
543 |
+
async fn delete_shard_snapshot(
|
544 |
+
dispatcher: web::Data<Dispatcher>,
|
545 |
+
path: web::Path<(String, ShardId, String)>,
|
546 |
+
query: web::Query<SnapshottingParam>,
|
547 |
+
ActixAccess(access): ActixAccess,
|
548 |
+
) -> impl Responder {
|
549 |
+
// nothing to verify.
|
550 |
+
let pass = new_unchecked_verification_pass();
|
551 |
+
|
552 |
+
let (collection, shard, snapshot) = path.into_inner();
|
553 |
+
let future = common::snapshots::delete_shard_snapshot(
|
554 |
+
dispatcher.toc(&access, &pass).clone(),
|
555 |
+
access,
|
556 |
+
collection,
|
557 |
+
shard,
|
558 |
+
snapshot,
|
559 |
+
)
|
560 |
+
.map_ok(|_| true)
|
561 |
+
.map_err(Into::into);
|
562 |
+
|
563 |
+
helpers::time_or_accept(future, query.wait.unwrap_or(true)).await
|
564 |
+
}
|
565 |
+
|
566 |
+
// Configure services
|
567 |
+
pub fn config_snapshots_api(cfg: &mut web::ServiceConfig) {
|
568 |
+
cfg.service(list_snapshots)
|
569 |
+
.service(create_snapshot)
|
570 |
+
.service(upload_snapshot)
|
571 |
+
.service(recover_from_snapshot)
|
572 |
+
.service(get_snapshot)
|
573 |
+
.service(list_full_snapshots)
|
574 |
+
.service(create_full_snapshot)
|
575 |
+
.service(get_full_snapshot)
|
576 |
+
.service(delete_full_snapshot)
|
577 |
+
.service(delete_collection_snapshot)
|
578 |
+
.service(list_shard_snapshots)
|
579 |
+
.service(create_shard_snapshot)
|
580 |
+
.service(stream_shard_snapshot)
|
581 |
+
.service(recover_shard_snapshot)
|
582 |
+
.service(upload_shard_snapshot)
|
583 |
+
.service(download_shard_snapshot)
|
584 |
+
.service(delete_shard_snapshot);
|
585 |
+
}
|
src/actix/api/update_api.rs
ADDED
@@ -0,0 +1,392 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
use actix_web::rt::time::Instant;
|
2 |
+
use actix_web::{delete, post, put, web, Responder};
|
3 |
+
use actix_web_validator::{Json, Path, Query};
|
4 |
+
use api::rest::schema::PointInsertOperations;
|
5 |
+
use api::rest::UpdateVectors;
|
6 |
+
use collection::operations::payload_ops::{DeletePayload, SetPayload};
|
7 |
+
use collection::operations::point_ops::{PointsSelector, WriteOrdering};
|
8 |
+
use collection::operations::types::UpdateResult;
|
9 |
+
use collection::operations::vector_ops::DeleteVectors;
|
10 |
+
use collection::operations::verification::new_unchecked_verification_pass;
|
11 |
+
use schemars::JsonSchema;
|
12 |
+
use segment::json_path::JsonPath;
|
13 |
+
use serde::{Deserialize, Serialize};
|
14 |
+
use storage::content_manager::collection_verification::check_strict_mode;
|
15 |
+
use storage::dispatcher::Dispatcher;
|
16 |
+
use validator::Validate;
|
17 |
+
|
18 |
+
use super::CollectionPath;
|
19 |
+
use crate::actix::auth::ActixAccess;
|
20 |
+
use crate::actix::helpers::{self, process_response, process_response_error};
|
21 |
+
use crate::common::points::{
|
22 |
+
do_batch_update_points, do_clear_payload, do_create_index, do_delete_index, do_delete_payload,
|
23 |
+
do_delete_points, do_delete_vectors, do_overwrite_payload, do_set_payload, do_update_vectors,
|
24 |
+
do_upsert_points, CreateFieldIndex, UpdateOperations,
|
25 |
+
};
|
26 |
+
|
27 |
+
/// Path parameters for endpoints that address a single payload-index field.
#[derive(Deserialize, Validate)]
struct FieldPath {
    // Populated from the `{field_name}` path segment of the route.
    #[serde(rename = "field_name")]
    name: JsonPath,
}

/// Common query parameters shared by all point/payload update endpoints.
#[derive(Deserialize, Serialize, JsonSchema, Validate)]
pub struct UpdateParam {
    /// Wait for the operation to be applied before responding (default: false,
    /// which yields `202 Accepted` immediately).
    pub wait: Option<bool>,
    /// Write ordering guarantee for the operation (default: `WriteOrdering::default()`).
    pub ordering: Option<WriteOrdering>,
}
|
38 |
+
|
39 |
+
#[put("/collections/{name}/points")]
|
40 |
+
async fn upsert_points(
|
41 |
+
dispatcher: web::Data<Dispatcher>,
|
42 |
+
collection: Path<CollectionPath>,
|
43 |
+
operation: Json<PointInsertOperations>,
|
44 |
+
params: Query<UpdateParam>,
|
45 |
+
ActixAccess(access): ActixAccess,
|
46 |
+
) -> impl Responder {
|
47 |
+
// nothing to verify.
|
48 |
+
let pass = new_unchecked_verification_pass();
|
49 |
+
|
50 |
+
let operation = operation.into_inner();
|
51 |
+
let wait = params.wait.unwrap_or(false);
|
52 |
+
let ordering = params.ordering.unwrap_or_default();
|
53 |
+
|
54 |
+
helpers::time(do_upsert_points(
|
55 |
+
dispatcher.toc(&access, &pass).clone(),
|
56 |
+
collection.into_inner().name,
|
57 |
+
operation,
|
58 |
+
None,
|
59 |
+
None,
|
60 |
+
wait,
|
61 |
+
ordering,
|
62 |
+
access,
|
63 |
+
))
|
64 |
+
.await
|
65 |
+
}
|
66 |
+
|
67 |
+
#[post("/collections/{name}/points/delete")]
|
68 |
+
async fn delete_points(
|
69 |
+
dispatcher: web::Data<Dispatcher>,
|
70 |
+
collection: Path<CollectionPath>,
|
71 |
+
operation: Json<PointsSelector>,
|
72 |
+
params: Query<UpdateParam>,
|
73 |
+
ActixAccess(access): ActixAccess,
|
74 |
+
) -> impl Responder {
|
75 |
+
let operation = operation.into_inner();
|
76 |
+
let pass =
|
77 |
+
match check_strict_mode(&operation, None, &collection.name, &dispatcher, &access).await {
|
78 |
+
Ok(pass) => pass,
|
79 |
+
Err(err) => return process_response_error(err, Instant::now(), None),
|
80 |
+
};
|
81 |
+
|
82 |
+
let wait = params.wait.unwrap_or(false);
|
83 |
+
let ordering = params.ordering.unwrap_or_default();
|
84 |
+
|
85 |
+
helpers::time(do_delete_points(
|
86 |
+
dispatcher.toc(&access, &pass).clone(),
|
87 |
+
collection.into_inner().name,
|
88 |
+
operation,
|
89 |
+
None,
|
90 |
+
None,
|
91 |
+
wait,
|
92 |
+
ordering,
|
93 |
+
access,
|
94 |
+
))
|
95 |
+
.await
|
96 |
+
}
|
97 |
+
|
98 |
+
#[put("/collections/{name}/points/vectors")]
|
99 |
+
async fn update_vectors(
|
100 |
+
dispatcher: web::Data<Dispatcher>,
|
101 |
+
collection: Path<CollectionPath>,
|
102 |
+
operation: Json<UpdateVectors>,
|
103 |
+
params: Query<UpdateParam>,
|
104 |
+
ActixAccess(access): ActixAccess,
|
105 |
+
) -> impl Responder {
|
106 |
+
// Nothing to verify here.
|
107 |
+
let pass = new_unchecked_verification_pass();
|
108 |
+
|
109 |
+
let operation = operation.into_inner();
|
110 |
+
let wait = params.wait.unwrap_or(false);
|
111 |
+
let ordering = params.ordering.unwrap_or_default();
|
112 |
+
|
113 |
+
helpers::time(do_update_vectors(
|
114 |
+
dispatcher.toc(&access, &pass).clone(),
|
115 |
+
collection.into_inner().name,
|
116 |
+
operation,
|
117 |
+
None,
|
118 |
+
None,
|
119 |
+
wait,
|
120 |
+
ordering,
|
121 |
+
access,
|
122 |
+
))
|
123 |
+
.await
|
124 |
+
}
|
125 |
+
|
126 |
+
#[post("/collections/{name}/points/vectors/delete")]
|
127 |
+
async fn delete_vectors(
|
128 |
+
dispatcher: web::Data<Dispatcher>,
|
129 |
+
collection: Path<CollectionPath>,
|
130 |
+
operation: Json<DeleteVectors>,
|
131 |
+
params: Query<UpdateParam>,
|
132 |
+
ActixAccess(access): ActixAccess,
|
133 |
+
) -> impl Responder {
|
134 |
+
let timing = Instant::now();
|
135 |
+
|
136 |
+
let operation = operation.into_inner();
|
137 |
+
let pass =
|
138 |
+
match check_strict_mode(&operation, None, &collection.name, &dispatcher, &access).await {
|
139 |
+
Ok(pass) => pass,
|
140 |
+
Err(err) => return process_response_error(err, timing, None),
|
141 |
+
};
|
142 |
+
|
143 |
+
let wait = params.wait.unwrap_or(false);
|
144 |
+
let ordering = params.ordering.unwrap_or_default();
|
145 |
+
|
146 |
+
let response = do_delete_vectors(
|
147 |
+
dispatcher.toc(&access, &pass).clone(),
|
148 |
+
collection.into_inner().name,
|
149 |
+
operation,
|
150 |
+
None,
|
151 |
+
None,
|
152 |
+
wait,
|
153 |
+
ordering,
|
154 |
+
access,
|
155 |
+
)
|
156 |
+
.await;
|
157 |
+
process_response(response, timing, None)
|
158 |
+
}
|
159 |
+
|
160 |
+
#[post("/collections/{name}/points/payload")]
|
161 |
+
async fn set_payload(
|
162 |
+
dispatcher: web::Data<Dispatcher>,
|
163 |
+
collection: Path<CollectionPath>,
|
164 |
+
operation: Json<SetPayload>,
|
165 |
+
params: Query<UpdateParam>,
|
166 |
+
ActixAccess(access): ActixAccess,
|
167 |
+
) -> impl Responder {
|
168 |
+
let operation = operation.into_inner();
|
169 |
+
|
170 |
+
let pass =
|
171 |
+
match check_strict_mode(&operation, None, &collection.name, &dispatcher, &access).await {
|
172 |
+
Ok(pass) => pass,
|
173 |
+
Err(err) => return process_response_error(err, Instant::now(), None),
|
174 |
+
};
|
175 |
+
|
176 |
+
let wait = params.wait.unwrap_or(false);
|
177 |
+
let ordering = params.ordering.unwrap_or_default();
|
178 |
+
|
179 |
+
helpers::time(do_set_payload(
|
180 |
+
dispatcher.toc(&access, &pass).clone(),
|
181 |
+
collection.into_inner().name,
|
182 |
+
operation,
|
183 |
+
None,
|
184 |
+
None,
|
185 |
+
wait,
|
186 |
+
ordering,
|
187 |
+
access,
|
188 |
+
))
|
189 |
+
.await
|
190 |
+
}
|
191 |
+
|
192 |
+
#[put("/collections/{name}/points/payload")]
|
193 |
+
async fn overwrite_payload(
|
194 |
+
dispatcher: web::Data<Dispatcher>,
|
195 |
+
collection: Path<CollectionPath>,
|
196 |
+
operation: Json<SetPayload>,
|
197 |
+
params: Query<UpdateParam>,
|
198 |
+
ActixAccess(access): ActixAccess,
|
199 |
+
) -> impl Responder {
|
200 |
+
let operation = operation.into_inner();
|
201 |
+
let pass =
|
202 |
+
match check_strict_mode(&operation, None, &collection.name, &dispatcher, &access).await {
|
203 |
+
Ok(pass) => pass,
|
204 |
+
Err(err) => return process_response_error(err, Instant::now(), None),
|
205 |
+
};
|
206 |
+
let wait = params.wait.unwrap_or(false);
|
207 |
+
let ordering = params.ordering.unwrap_or_default();
|
208 |
+
|
209 |
+
helpers::time(do_overwrite_payload(
|
210 |
+
dispatcher.toc(&access, &pass).clone(),
|
211 |
+
collection.into_inner().name,
|
212 |
+
operation,
|
213 |
+
None,
|
214 |
+
None,
|
215 |
+
wait,
|
216 |
+
ordering,
|
217 |
+
access,
|
218 |
+
))
|
219 |
+
.await
|
220 |
+
}
|
221 |
+
|
222 |
+
#[post("/collections/{name}/points/payload/delete")]
|
223 |
+
async fn delete_payload(
|
224 |
+
dispatcher: web::Data<Dispatcher>,
|
225 |
+
collection: Path<CollectionPath>,
|
226 |
+
operation: Json<DeletePayload>,
|
227 |
+
params: Query<UpdateParam>,
|
228 |
+
ActixAccess(access): ActixAccess,
|
229 |
+
) -> impl Responder {
|
230 |
+
let operation = operation.into_inner();
|
231 |
+
let pass =
|
232 |
+
match check_strict_mode(&operation, None, &collection.name, &dispatcher, &access).await {
|
233 |
+
Ok(pass) => pass,
|
234 |
+
Err(err) => return process_response_error(err, Instant::now(), None),
|
235 |
+
};
|
236 |
+
let wait = params.wait.unwrap_or(false);
|
237 |
+
let ordering = params.ordering.unwrap_or_default();
|
238 |
+
|
239 |
+
helpers::time(do_delete_payload(
|
240 |
+
dispatcher.toc(&access, &pass).clone(),
|
241 |
+
collection.into_inner().name,
|
242 |
+
operation,
|
243 |
+
None,
|
244 |
+
None,
|
245 |
+
wait,
|
246 |
+
ordering,
|
247 |
+
access,
|
248 |
+
))
|
249 |
+
.await
|
250 |
+
}
|
251 |
+
|
252 |
+
#[post("/collections/{name}/points/payload/clear")]
|
253 |
+
async fn clear_payload(
|
254 |
+
dispatcher: web::Data<Dispatcher>,
|
255 |
+
collection: Path<CollectionPath>,
|
256 |
+
operation: Json<PointsSelector>,
|
257 |
+
params: Query<UpdateParam>,
|
258 |
+
ActixAccess(access): ActixAccess,
|
259 |
+
) -> impl Responder {
|
260 |
+
let operation = operation.into_inner();
|
261 |
+
let pass =
|
262 |
+
match check_strict_mode(&operation, None, &collection.name, &dispatcher, &access).await {
|
263 |
+
Ok(pass) => pass,
|
264 |
+
Err(err) => return process_response_error(err, Instant::now(), None),
|
265 |
+
};
|
266 |
+
|
267 |
+
let wait = params.wait.unwrap_or(false);
|
268 |
+
let ordering = params.ordering.unwrap_or_default();
|
269 |
+
|
270 |
+
helpers::time(do_clear_payload(
|
271 |
+
dispatcher.toc(&access, &pass).clone(),
|
272 |
+
collection.into_inner().name,
|
273 |
+
operation,
|
274 |
+
None,
|
275 |
+
None,
|
276 |
+
wait,
|
277 |
+
ordering,
|
278 |
+
access,
|
279 |
+
))
|
280 |
+
.await
|
281 |
+
}
|
282 |
+
|
283 |
+
#[post("/collections/{name}/points/batch")]
|
284 |
+
async fn update_batch(
|
285 |
+
dispatcher: web::Data<Dispatcher>,
|
286 |
+
collection: Path<CollectionPath>,
|
287 |
+
operations: Json<UpdateOperations>,
|
288 |
+
params: Query<UpdateParam>,
|
289 |
+
ActixAccess(access): ActixAccess,
|
290 |
+
) -> impl Responder {
|
291 |
+
let timing = Instant::now();
|
292 |
+
let operations = operations.into_inner();
|
293 |
+
|
294 |
+
let mut vpass = None;
|
295 |
+
for operation in operations.operations.iter() {
|
296 |
+
let pass = match check_strict_mode(operation, None, &collection.name, &dispatcher, &access)
|
297 |
+
.await
|
298 |
+
{
|
299 |
+
Ok(pass) => pass,
|
300 |
+
Err(err) => return process_response_error(err, Instant::now(), None),
|
301 |
+
};
|
302 |
+
vpass = Some(pass);
|
303 |
+
}
|
304 |
+
|
305 |
+
// vpass == None => No update operation available
|
306 |
+
let Some(pass) = vpass else {
|
307 |
+
return process_response::<Vec<UpdateResult>>(Ok(vec![]), timing, None);
|
308 |
+
};
|
309 |
+
|
310 |
+
let wait = params.wait.unwrap_or(false);
|
311 |
+
let ordering = params.ordering.unwrap_or_default();
|
312 |
+
|
313 |
+
let response = do_batch_update_points(
|
314 |
+
dispatcher.toc(&access, &pass).clone(),
|
315 |
+
collection.into_inner().name,
|
316 |
+
operations.operations,
|
317 |
+
None,
|
318 |
+
None,
|
319 |
+
wait,
|
320 |
+
ordering,
|
321 |
+
access,
|
322 |
+
)
|
323 |
+
.await;
|
324 |
+
process_response(response, timing, None)
|
325 |
+
}
|
326 |
+
#[put("/collections/{name}/index")]
|
327 |
+
async fn create_field_index(
|
328 |
+
dispatcher: web::Data<Dispatcher>,
|
329 |
+
collection: Path<CollectionPath>,
|
330 |
+
operation: Json<CreateFieldIndex>,
|
331 |
+
params: Query<UpdateParam>,
|
332 |
+
ActixAccess(access): ActixAccess,
|
333 |
+
) -> impl Responder {
|
334 |
+
let timing = Instant::now();
|
335 |
+
let operation = operation.into_inner();
|
336 |
+
let wait = params.wait.unwrap_or(false);
|
337 |
+
let ordering = params.ordering.unwrap_or_default();
|
338 |
+
|
339 |
+
let response = do_create_index(
|
340 |
+
dispatcher.into_inner(),
|
341 |
+
collection.into_inner().name,
|
342 |
+
operation,
|
343 |
+
None,
|
344 |
+
None,
|
345 |
+
wait,
|
346 |
+
ordering,
|
347 |
+
access,
|
348 |
+
)
|
349 |
+
.await;
|
350 |
+
process_response(response, timing, None)
|
351 |
+
}
|
352 |
+
|
353 |
+
#[delete("/collections/{name}/index/{field_name}")]
|
354 |
+
async fn delete_field_index(
|
355 |
+
dispatcher: web::Data<Dispatcher>,
|
356 |
+
collection: Path<CollectionPath>,
|
357 |
+
field: Path<FieldPath>,
|
358 |
+
params: Query<UpdateParam>,
|
359 |
+
ActixAccess(access): ActixAccess,
|
360 |
+
) -> impl Responder {
|
361 |
+
let timing = Instant::now();
|
362 |
+
let wait = params.wait.unwrap_or(false);
|
363 |
+
let ordering = params.ordering.unwrap_or_default();
|
364 |
+
|
365 |
+
let response = do_delete_index(
|
366 |
+
dispatcher.into_inner(),
|
367 |
+
collection.into_inner().name,
|
368 |
+
field.name.clone(),
|
369 |
+
None,
|
370 |
+
None,
|
371 |
+
wait,
|
372 |
+
ordering,
|
373 |
+
access,
|
374 |
+
)
|
375 |
+
.await;
|
376 |
+
process_response(response, timing, None)
|
377 |
+
}
|
378 |
+
|
379 |
+
// Configure services
|
380 |
+
pub fn config_update_api(cfg: &mut web::ServiceConfig) {
|
381 |
+
cfg.service(upsert_points)
|
382 |
+
.service(delete_points)
|
383 |
+
.service(update_vectors)
|
384 |
+
.service(delete_vectors)
|
385 |
+
.service(set_payload)
|
386 |
+
.service(overwrite_payload)
|
387 |
+
.service(delete_payload)
|
388 |
+
.service(clear_payload)
|
389 |
+
.service(create_field_index)
|
390 |
+
.service(delete_field_index)
|
391 |
+
.service(update_batch);
|
392 |
+
}
|
src/actix/auth.rs
ADDED
@@ -0,0 +1,160 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
use std::convert::Infallible;
|
2 |
+
use std::future::{ready, Ready};
|
3 |
+
use std::sync::Arc;
|
4 |
+
|
5 |
+
use actix_web::body::{BoxBody, EitherBody};
|
6 |
+
use actix_web::dev::{forward_ready, Service, ServiceRequest, ServiceResponse, Transform};
|
7 |
+
use actix_web::{Error, FromRequest, HttpMessage, HttpResponse, ResponseError};
|
8 |
+
use futures_util::future::LocalBoxFuture;
|
9 |
+
use storage::rbac::Access;
|
10 |
+
|
11 |
+
use super::helpers::HttpError;
|
12 |
+
use crate::common::auth::{AuthError, AuthKeys};
|
13 |
+
|
14 |
+
/// Actix middleware factory enforcing API-key / JWT authentication.
pub struct Auth {
    // Configured keys used to validate incoming requests.
    auth_keys: AuthKeys,
    // Paths exempt from authentication (e.g. health checks, static UI assets).
    whitelist: Vec<WhitelistItem>,
}

impl Auth {
    /// Create the middleware factory from validated keys and a whitelist.
    pub fn new(auth_keys: AuthKeys, whitelist: Vec<WhitelistItem>) -> Self {
        Self {
            auth_keys,
            whitelist,
        }
    }
}
|
27 |
+
|
28 |
+
// Middleware factory: wraps the inner service `S` into an `AuthMiddleware<S>`
// once, when the app is built.
impl<S, B> Transform<S, ServiceRequest> for Auth
where
    S: Service<ServiceRequest, Response = ServiceResponse<EitherBody<B, BoxBody>>, Error = Error>
        + 'static,
    S::Future: 'static,
    B: 'static,
{
    type Response = ServiceResponse<EitherBody<B, BoxBody>>;
    type Error = Error;
    type InitError = ();
    type Transform = AuthMiddleware<S>;
    type Future = Ready<Result<Self::Transform, Self::InitError>>;

    fn new_transform(&self, service: S) -> Self::Future {
        // Construction is infallible, hence the immediately-ready future.
        ready(Ok(AuthMiddleware {
            auth_keys: Arc::new(self.auth_keys.clone()),
            whitelist: self.whitelist.clone(),
            service: Arc::new(service),
        }))
    }
}
|
49 |
+
|
50 |
+
/// A single authentication-whitelist entry: a path plus how to compare it.
#[derive(Clone, Eq, PartialEq, Hash)]
pub struct WhitelistItem(pub String, pub PathMode);

impl WhitelistItem {
    /// Entry that matches only this exact path.
    pub fn exact<S: Into<String>>(path: S) -> Self {
        Self(path.into(), PathMode::Exact)
    }

    /// Entry that matches this path and anything beneath it.
    pub fn prefix<S: Into<String>>(path: S) -> Self {
        Self(path.into(), PathMode::Prefix)
    }

    /// Whether the request path `other` is covered by this entry.
    pub fn matches(&self, other: &str) -> bool {
        let WhitelistItem(path, mode) = self;
        mode.check(path, other)
    }
}

/// How a whitelisted path is compared against a request path.
#[derive(Copy, Clone, Eq, PartialEq, Hash)]
pub enum PathMode {
    /// Path must match exactly
    Exact,
    /// Path must have given prefix
    Prefix,
}

impl PathMode {
    /// Compare the configured `key` against the request path `other`.
    fn check(&self, key: &str, other: &str) -> bool {
        match self {
            PathMode::Exact => other == key,
            PathMode::Prefix => other.starts_with(key),
        }
    }
}
|
83 |
+
|
84 |
+
/// Per-request authentication middleware wrapping an inner service.
pub struct AuthMiddleware<S> {
    // Shared so the async `call` future can outlive `&self`.
    auth_keys: Arc<AuthKeys>,
    /// List of items whitelisted from authentication.
    whitelist: Vec<WhitelistItem>,
    // Shared for the same reason as `auth_keys`.
    service: Arc<S>,
}

impl<S> AuthMiddleware<S> {
    /// Whether `path` matches any whitelist entry (exact or prefix match).
    pub fn is_path_whitelisted(&self, path: &str) -> bool {
        self.whitelist.iter().any(|item| item.matches(path))
    }
}
|
96 |
+
|
97 |
+
impl<S, B> Service<ServiceRequest> for AuthMiddleware<S>
where
    S: Service<ServiceRequest, Response = ServiceResponse<EitherBody<B, BoxBody>>, Error = Error>
        + 'static,
    S::Future: 'static,
    B: 'static,
{
    type Response = ServiceResponse<EitherBody<B, BoxBody>>;
    type Error = Error;
    type Future = LocalBoxFuture<'static, Result<Self::Response, Self::Error>>;

    forward_ready!(service);

    fn call(&self, req: ServiceRequest) -> Self::Future {
        let path = req.path();

        // Whitelisted paths bypass authentication entirely.
        if self.is_path_whitelisted(path) {
            return Box::pin(self.service.call(req));
        }

        // Clone the Arcs so the 'static future does not borrow `self`.
        let auth_keys = self.auth_keys.clone();
        let service = self.service.clone();
        Box::pin(async move {
            // Validate using a header lookup closure; `validate_request` asks
            // for headers by name (API key or bearer token).
            match auth_keys
                .validate_request(|key| req.headers().get(key).and_then(|val| val.to_str().ok()))
                .await
            {
                Ok(access) => {
                    // Stash the resolved access rights for later extraction
                    // via the `ActixAccess` extractor.
                    let previous = req.extensions_mut().insert::<Access>(access);
                    debug_assert!(
                        previous.is_none(),
                        "Previous access object should not exist in the request"
                    );
                    service.call(req).await
                }
                Err(e) => {
                    // Authentication failed: short-circuit with the matching
                    // HTTP status instead of invoking the inner service.
                    let resp = match e {
                        AuthError::Unauthorized(e) => HttpResponse::Unauthorized().body(e),
                        AuthError::Forbidden(e) => HttpResponse::Forbidden().body(e),
                        AuthError::StorageError(e) => HttpError::from(e).error_response(),
                    };
                    Ok(req.into_response(resp).map_into_right_body())
                }
            }
        })
    }
}
|
144 |
+
|
145 |
+
/// Extractor yielding the `Access` rights resolved by `AuthMiddleware`.
pub struct ActixAccess(pub Access);

impl FromRequest for ActixAccess {
    type Error = Infallible;
    type Future = Ready<Result<Self, Self::Error>>;

    fn from_request(
        req: &actix_web::HttpRequest,
        _payload: &mut actix_web::dev::Payload,
    ) -> Self::Future {
        // Take ownership of the access object stored by the auth middleware.
        // When no API key is configured the middleware never runs, so fall
        // back to full access.
        let access = req.extensions_mut().remove::<Access>().unwrap_or_else(|| {
            Access::full("All requests have full by default access when API key is not configured")
        });
        ready(Ok(ActixAccess(access)))
    }
}
|
src/actix/certificate_helpers.rs
ADDED
@@ -0,0 +1,203 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
use std::fmt::Debug;
|
2 |
+
use std::fs::File;
|
3 |
+
use std::io::{self, BufRead, BufReader};
|
4 |
+
use std::sync::Arc;
|
5 |
+
use std::time::{Duration, Instant};
|
6 |
+
|
7 |
+
use parking_lot::RwLock;
|
8 |
+
use rustls::client::VerifierBuilderError;
|
9 |
+
use rustls::pki_types::CertificateDer;
|
10 |
+
use rustls::server::{ClientHello, ResolvesServerCert, WebPkiClientVerifier};
|
11 |
+
use rustls::sign::CertifiedKey;
|
12 |
+
use rustls::{crypto, RootCertStore, ServerConfig};
|
13 |
+
use rustls_pemfile::Item;
|
14 |
+
|
15 |
+
use crate::settings::{Settings, TlsConfig};
|
16 |
+
|
17 |
+
type Result<T> = std::result::Result<T, Error>;
|
18 |
+
|
19 |
+
/// A TTL based rotating server certificate resolver
#[derive(Debug)]
struct RotatingCertificateResolver {
    /// TLS configuration used for loading/refreshing certified key
    tls_config: TlsConfig,

    /// TTL for each rotation
    // `None` disables rotation: the initially loaded key is served forever.
    ttl: Option<Duration>,

    /// Current certified key
    key: RwLock<CertifiedKeyWithAge>,
}
|
31 |
+
|
32 |
+
impl RotatingCertificateResolver {
    /// Build a resolver, eagerly loading the initial certified key.
    /// Fails if certificate or private key cannot be read from disk.
    pub fn new(tls_config: TlsConfig, ttl: Option<Duration>) -> Result<Self> {
        let certified_key = load_certified_key(&tls_config)?;

        Ok(Self {
            tls_config,
            ttl,
            key: RwLock::new(CertifiedKeyWithAge::from(certified_key)),
        })
    }

    /// Get certificate key or refresh
    ///
    /// The key is automatically refreshed when the TTL is reached.
    /// If refreshing fails, an error is logged and the old key is persisted.
    fn get_key_or_refresh(&self) -> Arc<CertifiedKey> {
        // Get read-only lock to the key. If TTL is not configured or is not expired, return key.
        let key = self.key.read();
        let ttl = match self.ttl {
            Some(ttl) if key.is_expired(ttl) => ttl,
            _ => return key.key.clone(),
        };
        drop(key);

        // If TTL is expired:
        // - get read-write lock to the key
        // - *re-check that TTL is expired* (to avoid refreshing the key multiple times from concurrent threads)
        // - refresh and return the key
        let mut key = self.key.write();
        if key.is_expired(ttl) {
            if let Err(err) = key.refresh(&self.tls_config) {
                log::error!("Failed to refresh server TLS certificate, keeping current: {err}");
            }
        }

        key.key.clone()
    }
}
|
70 |
+
|
71 |
+
impl ResolvesServerCert for RotatingCertificateResolver {
    // Invoked by rustls during handshakes; may trigger a TTL-based key refresh.
    fn resolve(&self, _client_hello: ClientHello<'_>) -> Option<Arc<CertifiedKey>> {
        Some(self.get_key_or_refresh())
    }
}
|
76 |
+
|
77 |
+
// A certified key stamped with its load time, enabling TTL-based expiry checks.
#[derive(Debug)]
struct CertifiedKeyWithAge {
    /// Last time the certificate was updated/replaced
    last_update: Instant,

    /// Current certified key
    key: Arc<CertifiedKey>,
}
|
85 |
+
|
86 |
+
impl CertifiedKeyWithAge {
    /// Wrap a freshly loaded key, stamping it with the current time.
    pub fn from(key: Arc<CertifiedKey>) -> Self {
        Self {
            last_update: Instant::now(),
            key,
        }
    }

    /// Reload the key from TLS config; on success the age resets to zero.
    /// On failure, `self` is left untouched.
    pub fn refresh(&mut self, tls_config: &TlsConfig) -> Result<()> {
        *self = Self::from(load_certified_key(tls_config)?);
        Ok(())
    }

    /// Time elapsed since the key was (re)loaded.
    pub fn age(&self) -> Duration {
        self.last_update.elapsed()
    }

    /// Whether the key has outlived the given TTL.
    pub fn is_expired(&self, ttl: Duration) -> bool {
        self.age() >= ttl
    }
}
|
107 |
+
|
108 |
+
/// Load TLS configuration and construct certified key.
fn load_certified_key(tls_config: &TlsConfig) -> Result<Arc<CertifiedKey>> {
    // Load certificates
    // Keep only X.509 certificate items; other PEM entries are silently ignored.
    let certs: Vec<CertificateDer> = with_buf_read(&tls_config.cert, |rd| {
        rustls_pemfile::read_all(rd).collect::<io::Result<Vec<_>>>()
    })?
    .into_iter()
    .filter_map(|item| match item {
        Item::X509Certificate(data) => Some(data),
        _ => None,
    })
    .collect();
    if certs.is_empty() {
        return Err(Error::NoServerCert);
    }

    // Load private key
    // Only the first PEM item of the key file is considered.
    let private_key_item =
        with_buf_read(&tls_config.key, rustls_pemfile::read_one)?.ok_or(Error::NoPrivateKey)?;
    let private_key = match private_key_item {
        // Accept RSA (PKCS#1), PKCS#8 and SEC1 (EC) key encodings.
        Item::Pkcs1Key(pkey) => rustls_pki_types::PrivateKeyDer::from(pkey),
        Item::Pkcs8Key(pkey) => rustls_pki_types::PrivateKeyDer::from(pkey),
        Item::Sec1Key(pkey) => rustls_pki_types::PrivateKeyDer::from(pkey),
        _ => return Err(Error::InvalidPrivateKey),
    };
    let signing_key = crypto::ring::sign::any_supported_type(&private_key).map_err(Error::Sign)?;

    // Construct certified key
    let certified_key = CertifiedKey::new(certs, signing_key);
    Ok(Arc::new(certified_key))
}
|
139 |
+
|
140 |
+
/// Generate an actix server configuration with TLS
///
/// Uses TLS settings as configured in configuration by user.
pub fn actix_tls_server_config(settings: &Settings) -> Result<ServerConfig> {
    let config = ServerConfig::builder();
    // The `tls` settings section must be present to enable HTTPS.
    let tls_config = settings
        .tls
        .clone()
        .ok_or_else(Settings::tls_config_is_undefined_error)
        .map_err(Error::Io)?;

    // Verify client CA or not
    let config = if settings.service.verify_https_client_certificate {
        // Mutual TLS: build a client verifier from the configured CA bundle.
        let mut root_cert_store = RootCertStore::empty();
        let ca_certs: Vec<CertificateDer> = with_buf_read(&tls_config.ca_cert, |rd| {
            rustls_pemfile::certs(rd).collect()
        })?;
        root_cert_store.add_parsable_certificates(ca_certs);
        let client_cert_verifier = WebPkiClientVerifier::builder(root_cert_store.into())
            .build()
            .map_err(Error::ClientCertVerifier)?;
        config.with_client_cert_verifier(client_cert_verifier)
    } else {
        config.with_no_client_auth()
    };

    // Configure rotating certificate resolver
    // A TTL of 0 (or absent) disables certificate rotation.
    let ttl = match tls_config.cert_ttl {
        None | Some(0) => None,
        Some(seconds) => Some(Duration::from_secs(seconds)),
    };
    let cert_resolver = RotatingCertificateResolver::new(tls_config, ttl)?;
    let config = config.with_cert_resolver(Arc::new(cert_resolver));

    Ok(config)
}
|
176 |
+
|
177 |
+
fn with_buf_read<T>(path: &str, f: impl FnOnce(&mut dyn BufRead) -> io::Result<T>) -> Result<T> {
|
178 |
+
let file = File::open(path).map_err(|err| Error::OpenFile(err, path.into()))?;
|
179 |
+
let mut reader = BufReader::new(file);
|
180 |
+
let dyn_reader: &mut dyn BufRead = &mut reader;
|
181 |
+
f(dyn_reader).map_err(|err| Error::ReadFile(err, path.into()))
|
182 |
+
}
|
183 |
+
|
184 |
+
/// Actix TLS errors.
// The `String` payloads of the file variants carry the offending file path.
#[derive(thiserror::Error, Debug)]
pub enum Error {
    #[error("TLS file could not be opened: {1}")]
    OpenFile(#[source] io::Error, String),
    #[error("TLS file could not be read: {1}")]
    ReadFile(#[source] io::Error, String),
    #[error("general TLS IO error")]
    Io(#[source] io::Error),
    #[error("no server certificate found")]
    NoServerCert,
    #[error("no private key found")]
    NoPrivateKey,
    // Found a PEM item, but not a supported key encoding (PKCS#1/PKCS#8/SEC1).
    #[error("invalid private key")]
    InvalidPrivateKey,
    #[error("TLS signing error")]
    Sign(#[source] rustls::Error),
    #[error("client certificate verification")]
    ClientCertVerifier(#[source] VerifierBuilderError),
}
|
src/actix/helpers.rs
ADDED
@@ -0,0 +1,179 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
use std::fmt::Debug;
|
2 |
+
use std::future::Future;
|
3 |
+
|
4 |
+
use actix_web::rt::time::Instant;
|
5 |
+
use actix_web::{http, HttpResponse, ResponseError};
|
6 |
+
use api::rest::models::{ApiResponse, ApiStatus, HardwareUsage};
|
7 |
+
use collection::operations::types::CollectionError;
|
8 |
+
use common::counter::hardware_accumulator::HwMeasurementAcc;
|
9 |
+
use serde::Serialize;
|
10 |
+
use storage::content_manager::errors::StorageError;
|
11 |
+
use storage::content_manager::toc::request_hw_counter::RequestHwCounter;
|
12 |
+
use storage::dispatcher::Dispatcher;
|
13 |
+
|
14 |
+
/// Build a per-request hardware usage counter for `collection_name`.
///
/// `report_to_api` controls whether the measured usage is echoed back in the
/// API response. The accumulator is created with the collection's hardware
/// metrics handle — presumably draining into it (per `new_with_drain`);
/// confirm against `HwMeasurementAcc` docs.
pub fn get_request_hardware_counter(
    dispatcher: &Dispatcher,
    collection_name: String,
    report_to_api: bool,
) -> RequestHwCounter {
    RequestHwCounter::new(
        HwMeasurementAcc::new_with_drain(&dispatcher.get_collection_hw_metrics(collection_name)),
        report_to_api,
        // Third flag's meaning is not visible here — see RequestHwCounter::new.
        false,
    )
}
|
25 |
+
|
26 |
+
/// Build a `202 Accepted` API envelope with no result payload, reporting the
/// elapsed time since `timing`.
pub fn accepted_response(timing: Instant, hardware_usage: Option<HardwareUsage>) -> HttpResponse {
    HttpResponse::Accepted().json(ApiResponse::<()> {
        result: None,
        status: ApiStatus::Accepted,
        time: timing.elapsed().as_secs_f64(),
        usage: hardware_usage,
    })
}
|
34 |
+
|
35 |
+
pub fn process_response<T>(
|
36 |
+
response: Result<T, StorageError>,
|
37 |
+
timing: Instant,
|
38 |
+
hardware_usage: Option<HardwareUsage>,
|
39 |
+
) -> HttpResponse
|
40 |
+
where
|
41 |
+
T: Serialize,
|
42 |
+
{
|
43 |
+
match response {
|
44 |
+
Ok(res) => HttpResponse::Ok().json(ApiResponse {
|
45 |
+
result: Some(res),
|
46 |
+
status: ApiStatus::Ok,
|
47 |
+
time: timing.elapsed().as_secs_f64(),
|
48 |
+
usage: hardware_usage,
|
49 |
+
}),
|
50 |
+
Err(err) => process_response_error(err, timing, hardware_usage),
|
51 |
+
}
|
52 |
+
}
|
53 |
+
|
54 |
+
pub fn process_response_error(
|
55 |
+
err: StorageError,
|
56 |
+
timing: Instant,
|
57 |
+
hardware_usage: Option<HardwareUsage>,
|
58 |
+
) -> HttpResponse {
|
59 |
+
log_service_error(&err);
|
60 |
+
|
61 |
+
let error = HttpError::from(err);
|
62 |
+
|
63 |
+
HttpResponse::build(error.status_code()).json(ApiResponse::<()> {
|
64 |
+
result: None,
|
65 |
+
status: ApiStatus::Error(error.to_string()),
|
66 |
+
time: timing.elapsed().as_secs_f64(),
|
67 |
+
usage: hardware_usage,
|
68 |
+
})
|
69 |
+
}
|
70 |
+
|
71 |
+
/// Response wrapper for a `Future` returning `Result`.
|
72 |
+
///
|
73 |
+
/// # Cancel safety
|
74 |
+
///
|
75 |
+
/// Future must be cancel safe.
|
76 |
+
pub async fn time<T, Fut>(future: Fut) -> HttpResponse
|
77 |
+
where
|
78 |
+
Fut: Future<Output = Result<T, StorageError>>,
|
79 |
+
T: serde::Serialize,
|
80 |
+
{
|
81 |
+
time_impl(async { future.await.map(Some) }).await
|
82 |
+
}
|
83 |
+
|
84 |
+
/// Response wrapper for a `Future` returning `Result`.
/// If `wait` is false, returns `202 Accepted` immediately.
///
/// The work is spawned onto the runtime, so it runs to completion even when
/// the client does not wait for (or abandons) the response.
pub async fn time_or_accept<T, Fut>(future: Fut, wait: bool) -> HttpResponse
where
    Fut: Future<Output = Result<T, StorageError>> + Send + 'static,
    T: serde::Serialize + Send + 'static,
{
    let future = async move {
        let handle = tokio::task::spawn(async move {
            let result = future.await;

            // When nobody waits, nobody observes the result — log failures
            // from inside the task instead.
            if !wait {
                if let Err(err) = &result {
                    log_service_error(err);
                }
            }

            result
        });

        if wait {
            // `?` converts a JoinError (panicked/cancelled task) into StorageError.
            handle.await?.map(Some)
        } else {
            // Detach the task; `None` makes `time_impl` answer 202 Accepted.
            Ok(None)
        }
    };

    time_impl(future).await
}
|
113 |
+
|
114 |
+
/// # Cancel safety
|
115 |
+
///
|
116 |
+
/// Future must be cancel safe.
|
117 |
+
async fn time_impl<T, Fut>(future: Fut) -> HttpResponse
|
118 |
+
where
|
119 |
+
Fut: Future<Output = Result<Option<T>, StorageError>>,
|
120 |
+
T: serde::Serialize,
|
121 |
+
{
|
122 |
+
let instant = Instant::now();
|
123 |
+
match future.await.transpose() {
|
124 |
+
Some(res) => process_response(res, instant, None),
|
125 |
+
None => accepted_response(instant, None),
|
126 |
+
}
|
127 |
+
}
|
128 |
+
|
129 |
+
fn log_service_error(err: &StorageError) {
|
130 |
+
if let StorageError::ServiceError { backtrace, .. } = err {
|
131 |
+
log::error!("Error processing request: {err}");
|
132 |
+
|
133 |
+
if let Some(backtrace) = backtrace {
|
134 |
+
log::trace!("Backtrace: {backtrace}");
|
135 |
+
}
|
136 |
+
}
|
137 |
+
}
|
138 |
+
|
139 |
+
/// Convenience alias for handler results carrying an [`HttpError`].
pub type HttpResult<T, E = HttpError> = Result<T, E>;

/// Newtype over [`StorageError`] that implements actix's `ResponseError`,
/// so storage failures can be returned directly from request handlers.
#[derive(Clone, Debug, thiserror::Error)]
#[error("{0}")]
pub struct HttpError(StorageError);
|
144 |
+
|
145 |
+
impl ResponseError for HttpError {
|
146 |
+
fn status_code(&self) -> http::StatusCode {
|
147 |
+
match &self.0 {
|
148 |
+
StorageError::BadInput { .. } => http::StatusCode::BAD_REQUEST,
|
149 |
+
StorageError::NotFound { .. } => http::StatusCode::NOT_FOUND,
|
150 |
+
StorageError::ServiceError { .. } => http::StatusCode::INTERNAL_SERVER_ERROR,
|
151 |
+
StorageError::BadRequest { .. } => http::StatusCode::BAD_REQUEST,
|
152 |
+
StorageError::Locked { .. } => http::StatusCode::FORBIDDEN,
|
153 |
+
StorageError::Timeout { .. } => http::StatusCode::REQUEST_TIMEOUT,
|
154 |
+
StorageError::AlreadyExists { .. } => http::StatusCode::CONFLICT,
|
155 |
+
StorageError::ChecksumMismatch { .. } => http::StatusCode::BAD_REQUEST,
|
156 |
+
StorageError::Forbidden { .. } => http::StatusCode::FORBIDDEN,
|
157 |
+
StorageError::PreconditionFailed { .. } => http::StatusCode::INTERNAL_SERVER_ERROR,
|
158 |
+
StorageError::InferenceError { .. } => http::StatusCode::BAD_REQUEST,
|
159 |
+
}
|
160 |
+
}
|
161 |
+
}
|
162 |
+
|
163 |
+
impl From<StorageError> for HttpError {
|
164 |
+
fn from(err: StorageError) -> Self {
|
165 |
+
HttpError(err)
|
166 |
+
}
|
167 |
+
}
|
168 |
+
|
169 |
+
impl From<CollectionError> for HttpError {
|
170 |
+
fn from(err: CollectionError) -> Self {
|
171 |
+
HttpError(err.into())
|
172 |
+
}
|
173 |
+
}
|
174 |
+
|
175 |
+
impl From<std::io::Error> for HttpError {
|
176 |
+
fn from(err: std::io::Error) -> Self {
|
177 |
+
HttpError(err.into()) // TODO: Is this good enough?.. 🤔
|
178 |
+
}
|
179 |
+
}
|
src/actix/mod.rs
ADDED
@@ -0,0 +1,262 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#[allow(dead_code)] // May contain functions used in different binaries. Not actually dead
|
2 |
+
pub mod actix_telemetry;
|
3 |
+
pub mod api;
|
4 |
+
mod auth;
|
5 |
+
mod certificate_helpers;
|
6 |
+
#[allow(dead_code)] // May contain functions used in different binaries. Not actually dead
|
7 |
+
pub mod helpers;
|
8 |
+
pub mod web_ui;
|
9 |
+
|
10 |
+
use std::io;
|
11 |
+
use std::sync::Arc;
|
12 |
+
|
13 |
+
use ::api::rest::models::{ApiResponse, ApiStatus, VersionInfo};
|
14 |
+
use actix_cors::Cors;
|
15 |
+
use actix_multipart::form::tempfile::TempFileConfig;
|
16 |
+
use actix_multipart::form::MultipartFormConfig;
|
17 |
+
use actix_web::middleware::{Compress, Condition, Logger};
|
18 |
+
use actix_web::{error, get, web, App, HttpRequest, HttpResponse, HttpServer, Responder};
|
19 |
+
use actix_web_extras::middleware::Condition as ConditionEx;
|
20 |
+
use api::facet_api::config_facet_api;
|
21 |
+
use collection::operations::validation;
|
22 |
+
use collection::operations::verification::new_unchecked_verification_pass;
|
23 |
+
use storage::dispatcher::Dispatcher;
|
24 |
+
use storage::rbac::Access;
|
25 |
+
|
26 |
+
use crate::actix::api::cluster_api::config_cluster_api;
|
27 |
+
use crate::actix::api::collections_api::config_collections_api;
|
28 |
+
use crate::actix::api::count_api::count_points;
|
29 |
+
use crate::actix::api::debug_api::config_debugger_api;
|
30 |
+
use crate::actix::api::discovery_api::config_discovery_api;
|
31 |
+
use crate::actix::api::issues_api::config_issues_api;
|
32 |
+
use crate::actix::api::local_shard_api::config_local_shard_api;
|
33 |
+
use crate::actix::api::query_api::config_query_api;
|
34 |
+
use crate::actix::api::recommend_api::config_recommend_api;
|
35 |
+
use crate::actix::api::retrieve_api::{get_point, get_points, scroll_points};
|
36 |
+
use crate::actix::api::search_api::config_search_api;
|
37 |
+
use crate::actix::api::service_api::config_service_api;
|
38 |
+
use crate::actix::api::shards_api::config_shards_api;
|
39 |
+
use crate::actix::api::snapshot_api::config_snapshots_api;
|
40 |
+
use crate::actix::api::update_api::config_update_api;
|
41 |
+
use crate::actix::auth::{Auth, WhitelistItem};
|
42 |
+
use crate::actix::web_ui::{web_ui_factory, web_ui_folder, WEB_UI_PATH};
|
43 |
+
use crate::common::auth::AuthKeys;
|
44 |
+
use crate::common::debugger::DebuggerState;
|
45 |
+
use crate::common::health;
|
46 |
+
use crate::common::http_client::HttpClient;
|
47 |
+
use crate::common::telemetry::TelemetryCollector;
|
48 |
+
use crate::settings::{max_web_workers, Settings};
|
49 |
+
use crate::tracing::LoggerHandle;
|
50 |
+
|
51 |
+
/// Root endpoint: reports the service name/version as JSON.
/// Whitelisted from API-key auth and excluded from request logging.
#[get("/")]
pub async fn index() -> impl Responder {
    HttpResponse::Ok().json(VersionInfo::default())
}
|
55 |
+
|
56 |
+
/// Build and run the REST API server; blocks the calling thread until the
/// server shuts down.
///
/// Wires up auth, CORS, logging, telemetry, request validation and all API
/// route groups, then binds with or without TLS per `settings`.
#[allow(dead_code)]
pub fn init(
    dispatcher: Arc<Dispatcher>,
    telemetry_collector: Arc<tokio::sync::Mutex<TelemetryCollector>>,
    health_checker: Option<Arc<health::HealthChecker>>,
    settings: Settings,
    logger_handle: LoggerHandle,
) -> io::Result<()> {
    actix_web::rt::System::new().block_on(async {
        // Nothing to verify here.
        let pass = new_unchecked_verification_pass();
        let auth_keys = AuthKeys::try_create(
            &settings.service,
            dispatcher
                .toc(&Access::full("For JWT validation"), &pass)
                .clone(),
        );
        let upload_dir = dispatcher
            .toc(&Access::full("For upload dir"), &pass)
            .upload_dir()
            .unwrap();
        let dispatcher_data = web::Data::from(dispatcher);
        let actix_telemetry_collector = telemetry_collector
            .lock()
            .await
            .actix_telemetry_collector
            .clone();
        let debugger_state = web::Data::new(DebuggerState::from_settings(&settings));
        let telemetry_collector_data = web::Data::from(telemetry_collector);
        let logger_handle_data = web::Data::new(logger_handle);
        let http_client = web::Data::new(HttpClient::from_settings(&settings)?);
        let health_checker = web::Data::new(health_checker);
        let web_ui_available = web_ui_folder(&settings);
        let service_config = web::Data::new(settings.service.clone());

        // Paths that must stay reachable without an API key (probes, root).
        let mut api_key_whitelist = vec![
            WhitelistItem::exact("/"),
            WhitelistItem::exact("/healthz"),
            WhitelistItem::prefix("/readyz"),
            WhitelistItem::prefix("/livez"),
        ];
        if web_ui_available.is_some() {
            api_key_whitelist.push(WhitelistItem::prefix(WEB_UI_PATH));
        }

        // This closure runs once per worker thread, hence all the clones.
        let mut server = HttpServer::new(move || {
            let cors = Cors::default()
                .allow_any_origin()
                .allow_any_method()
                .allow_any_header();
            let validate_path_config = actix_web_validator::PathConfig::default()
                .error_handler(|err, rec| validation_error_handler("path parameters", err, rec));
            let validate_query_config = actix_web_validator::QueryConfig::default()
                .error_handler(|err, rec| validation_error_handler("query parameters", err, rec));
            let validate_json_config = actix_web_validator::JsonConfig::default()
                .limit(settings.service.max_request_size_mb * 1024 * 1024)
                .error_handler(|err, rec| validation_error_handler("JSON body", err, rec));

            let mut app = App::new()
                .wrap(Compress::default()) // Reads the `Accept-Encoding` header to negotiate which compression codec to use.
                // api_key middleware
                // note: the last call to `wrap()` or `wrap_fn()` is executed first
                .wrap(ConditionEx::from_option(auth_keys.as_ref().map(
                    |auth_keys| Auth::new(auth_keys.clone(), api_key_whitelist.clone()),
                )))
                .wrap(Condition::new(settings.service.enable_cors, cors))
                .wrap(
                    // Set up logger, but avoid logging hot status endpoints
                    Logger::default()
                        .exclude("/")
                        .exclude("/metrics")
                        .exclude("/telemetry")
                        .exclude("/healthz")
                        .exclude("/readyz")
                        .exclude("/livez"),
                )
                .wrap(actix_telemetry::ActixTelemetryTransform::new(
                    actix_telemetry_collector.clone(),
                ))
                .app_data(dispatcher_data.clone())
                .app_data(telemetry_collector_data.clone())
                .app_data(logger_handle_data.clone())
                .app_data(http_client.clone())
                .app_data(debugger_state.clone())
                .app_data(health_checker.clone())
                .app_data(validate_path_config)
                .app_data(validate_query_config)
                .app_data(validate_json_config)
                .app_data(TempFileConfig::default().directory(&upload_dir))
                .app_data(MultipartFormConfig::default().total_limit(usize::MAX))
                .app_data(service_config.clone())
                .service(index)
                .configure(config_collections_api)
                .configure(config_snapshots_api)
                .configure(config_update_api)
                .configure(config_cluster_api)
                .configure(config_service_api)
                .configure(config_search_api)
                .configure(config_recommend_api)
                .configure(config_discovery_api)
                .configure(config_query_api)
                .configure(config_facet_api)
                .configure(config_shards_api)
                .configure(config_issues_api)
                .configure(config_debugger_api)
                .configure(config_local_shard_api)
                // Ordering of services is important for correct path pattern matching
                // See: <https://github.com/qdrant/qdrant/issues/3543>
                .service(scroll_points)
                .service(count_points)
                .service(get_point)
                .service(get_points);

            if let Some(static_folder) = web_ui_available.as_deref() {
                app = app.service(web_ui_factory(static_folder));
            }

            app
        })
        .workers(max_web_workers(&settings));

        let port = settings.service.http_port;
        let bind_addr = format!("{}:{}", settings.service.host, port);

        // With TLS enabled, bind with certificate helper and Rustls, or bind regularly
        server = if settings.service.enable_tls {
            log::info!(
                "TLS enabled for REST API (TTL: {})",
                settings
                    .tls
                    .as_ref()
                    .and_then(|tls| tls.cert_ttl)
                    .map(|ttl| ttl.to_string())
                    .unwrap_or_else(|| "none".into()),
            );

            let config = certificate_helpers::actix_tls_server_config(&settings)
                .map_err(|err| io::Error::new(io::ErrorKind::Other, err))?;
            server.bind_rustls_0_23(bind_addr, config)?
        } else {
            log::info!("TLS disabled for REST API");

            server.bind(bind_addr)?
        };

        log::info!("Qdrant HTTP listening on {}", port);
        server.run().await
    })
}
|
205 |
+
|
206 |
+
/// Convert an `actix-web-validator` error into an HTTP error response.
///
/// `name` describes which part of the request failed (e.g. "JSON body" or
/// "query parameters"), so the client-facing message points at the
/// offending location.
fn validation_error_handler(
    name: &str,
    err: actix_web_validator::Error,
    _req: &HttpRequest,
) -> error::Error {
    use actix_web_validator::error::DeserializeErrors;

    // Nicely describe deserialization and validation errors
    let msg = match &err {
        actix_web_validator::Error::Validate(errs) => {
            validation::label_errors(format!("Validation error in {name}"), errs)
        }
        actix_web_validator::Error::Deserialize(err) => {
            format!(
                "Deserialize error in {name}: {}",
                match err {
                    DeserializeErrors::DeserializeQuery(err) => err.to_string(),
                    DeserializeErrors::DeserializeJson(err) => err.to_string(),
                    DeserializeErrors::DeserializePath(err) => err.to_string(),
                }
            )
        }
        actix_web_validator::Error::JsonPayloadError(
            actix_web::error::JsonPayloadError::Deserialize(err),
        ) => {
            format!("Format error in {name}: {err}",)
        }
        err => err.to_string(),
    };

    // Build fitting response
    // Validation failures are semantic errors (422 Unprocessable Entity);
    // everything else is a malformed request (400 Bad Request).
    let response = match &err {
        actix_web_validator::Error::Validate(_) => HttpResponse::UnprocessableEntity(),
        _ => HttpResponse::BadRequest(),
    }
    .json(ApiResponse::<()> {
        result: None,
        status: ApiStatus::Error(msg),
        time: 0.0,
        usage: None,
    });
    error::InternalError::from_response(err, response).into()
}
|
249 |
+
|
250 |
+
#[cfg(test)]
mod tests {
    use ::api::grpc::api_crate_version;

    /// Guard against the lib/api crate version drifting away from the
    /// binary crate's version.
    #[test]
    fn test_version() {
        assert_eq!(
            api_crate_version(),
            env!("CARGO_PKG_VERSION"),
            "Qdrant and lib/api crate versions are not same"
        );
    }
}
|
src/actix/web_ui.rs
ADDED
@@ -0,0 +1,115 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
use std::path::Path;
|
2 |
+
|
3 |
+
use actix_web::dev::HttpServiceFactory;
|
4 |
+
use actix_web::http::header::HeaderValue;
|
5 |
+
use actix_web::middleware::DefaultHeaders;
|
6 |
+
use actix_web::web;
|
7 |
+
|
8 |
+
use crate::settings::Settings;
|
9 |
+
|
10 |
+
// Fallback location of the bundled dashboard assets, relative to the CWD.
const DEFAULT_STATIC_DIR: &str = "./static";
// URL path prefix under which the web UI is mounted.
pub const WEB_UI_PATH: &str = "/dashboard";
|
12 |
+
|
13 |
+
pub fn web_ui_folder(settings: &Settings) -> Option<String> {
|
14 |
+
let web_ui_enabled = settings.service.enable_static_content.unwrap_or(true);
|
15 |
+
|
16 |
+
if web_ui_enabled {
|
17 |
+
let static_folder = settings
|
18 |
+
.service
|
19 |
+
.static_content_dir
|
20 |
+
.clone()
|
21 |
+
.unwrap_or_else(|| DEFAULT_STATIC_DIR.to_string());
|
22 |
+
let static_folder_path = Path::new(&static_folder);
|
23 |
+
if !static_folder_path.exists() || !static_folder_path.is_dir() {
|
24 |
+
// enabled BUT folder does not exist
|
25 |
+
log::warn!(
|
26 |
+
"Static content folder for Web UI '{}' does not exist",
|
27 |
+
static_folder_path.display(),
|
28 |
+
);
|
29 |
+
None
|
30 |
+
} else {
|
31 |
+
// enabled AND folder exists
|
32 |
+
Some(static_folder)
|
33 |
+
}
|
34 |
+
} else {
|
35 |
+
// not enabled
|
36 |
+
None
|
37 |
+
}
|
38 |
+
}
|
39 |
+
|
40 |
+
pub fn web_ui_factory(static_folder: &str) -> impl HttpServiceFactory {
|
41 |
+
web::scope(WEB_UI_PATH)
|
42 |
+
.wrap(DefaultHeaders::new().add(("X-Frame-Options", HeaderValue::from_static("DENY"))))
|
43 |
+
.service(actix_files::Files::new("/", static_folder).index_file("index.html"))
|
44 |
+
}
|
45 |
+
|
46 |
+
#[cfg(test)]
mod tests {
    use actix_web::http::header::{self, HeaderMap};
    use actix_web::http::StatusCode;
    use actix_web::test::{self, TestRequest};
    use actix_web::App;

    use super::*;

    /// Assert the headers expected on every dashboard HTML page:
    /// HTML content type and the anti-clickjacking `X-Frame-Options: DENY`.
    fn assert_html_custom_headers(headers: &HeaderMap) {
        let content_type = header::HeaderValue::from_static("text/html; charset=utf-8");
        assert_eq!(headers.get(header::CONTENT_TYPE), Some(&content_type));
        let x_frame_options = header::HeaderValue::from_static("DENY");
        assert_eq!(headers.get(header::X_FRAME_OPTIONS), Some(&x_frame_options),);
    }

    /// End-to-end smoke test of the dashboard routes: index variants,
    /// a static asset, and the 404 path when the folder is missing.
    #[actix_web::test]
    async fn test_web_ui() {
        let static_dir = String::from("static");
        let mut settings = Settings::new(None).unwrap();
        settings.service.static_content_dir = Some(static_dir.clone());

        // The static folder only exists in full builds; skip gracefully otherwise.
        let maybe_static_folder = web_ui_folder(&settings);
        if maybe_static_folder.is_none() {
            println!("Skipping test because the static folder was not found.");
            return;
        }

        let static_folder = maybe_static_folder.unwrap();
        let srv = test::init_service(App::new().service(web_ui_factory(&static_folder))).await;

        // Index path (no trailing slash)
        let req = TestRequest::with_uri(WEB_UI_PATH).to_request();
        let res = test::call_service(&srv, req).await;
        assert_eq!(res.status(), StatusCode::OK);
        let headers = res.headers();
        assert_html_custom_headers(headers);
        // Index path (trailing slash)
        let req = TestRequest::with_uri(format!("{WEB_UI_PATH}/").as_str()).to_request();
        let res = test::call_service(&srv, req).await;
        assert_eq!(res.status(), StatusCode::OK);
        let headers = res.headers();
        assert_html_custom_headers(headers);
        // Index path (index.html file)
        let req = TestRequest::with_uri(format!("{WEB_UI_PATH}/index.html").as_str()).to_request();
        let res = test::call_service(&srv, req).await;
        assert_eq!(res.status(), StatusCode::OK);
        let headers = res.headers();
        assert_html_custom_headers(headers);
        // Static asset (favicon.ico)
        let req = TestRequest::with_uri(format!("{WEB_UI_PATH}/favicon.ico").as_str()).to_request();
        let res = test::call_service(&srv, req).await;
        assert_eq!(res.status(), StatusCode::OK);
        let headers = res.headers();
        assert_eq!(
            headers.get(header::CONTENT_TYPE),
            Some(&header::HeaderValue::from_static("image/x-icon")),
        );
        // Non-existing path (404 Not Found)
        let fake_path = uuid::Uuid::new_v4().to_string();
        let srv = test::init_service(App::new().service(web_ui_factory(&fake_path))).await;

        let req = TestRequest::with_uri(WEB_UI_PATH).to_request();
        let res = test::call_service(&srv, req).await;
        assert_eq!(res.status(), StatusCode::NOT_FOUND);
        let headers = res.headers();
        assert_eq!(headers.get(header::CONTENT_TYPE), None);
        assert_eq!(headers.get(header::CONTENT_LENGTH), None);
    }
}
|
src/common/auth/claims.rs
ADDED
@@ -0,0 +1,69 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
use segment::json_path::JsonPath;
|
2 |
+
use segment::types::{Condition, FieldCondition, Filter, Match, ValueVariants};
|
3 |
+
use serde::{Deserialize, Serialize};
|
4 |
+
use storage::rbac::Access;
|
5 |
+
use validator::{Validate, ValidationErrors};
|
6 |
+
|
7 |
+
/// Claims carried by a Qdrant RBAC JWT.
#[derive(Serialize, Deserialize, PartialEq, Clone, Debug)]
pub struct Claims {
    /// Expiration time (seconds since UNIX epoch)
    pub exp: Option<u64>,

    /// Access level granted by the token; defaults to full access when
    /// the claim is absent (see [`default_access`]).
    #[serde(default = "default_access")]
    pub access: Access,

    /// Validate this token by looking for a value inside a collection.
    pub value_exists: Option<ValueExists>,
}
|
18 |
+
|
19 |
+
/// A payload key together with the value it must match.
#[derive(Serialize, Deserialize, PartialEq, Clone, Debug)]
pub struct KeyValuePair {
    // JSON path into the point payload.
    key: JsonPath,
    // Value the payload must hold at `key`.
    value: ValueVariants,
}
|
24 |
+
|
25 |
+
impl KeyValuePair {
|
26 |
+
pub fn to_condition(&self) -> Condition {
|
27 |
+
Condition::Field(FieldCondition::new_match(
|
28 |
+
self.key.clone(),
|
29 |
+
Match::new_value(self.value.clone()),
|
30 |
+
))
|
31 |
+
}
|
32 |
+
}
|
33 |
+
|
34 |
+
/// Stateful token check: the token is only considered valid while a point
/// matching all `matches` pairs exists in `collection`.
#[derive(Serialize, Deserialize, PartialEq, Clone, Debug)]
pub struct ValueExists {
    // Collection to search for the matching point.
    collection: String,
    // All pairs must match simultaneously (combined with logical AND).
    matches: Vec<KeyValuePair>,
}
|
39 |
+
|
40 |
+
/// Serde default for [`Claims::access`]: tokens without an `access`
/// claim are granted unrestricted access.
fn default_access() -> Access {
    Access::full("Give full access when the access field is not present")
}
|
43 |
+
|
44 |
+
impl ValueExists {
|
45 |
+
pub fn get_collection(&self) -> &str {
|
46 |
+
&self.collection
|
47 |
+
}
|
48 |
+
|
49 |
+
pub fn to_filter(&self) -> Filter {
|
50 |
+
let conditions = self
|
51 |
+
.matches
|
52 |
+
.iter()
|
53 |
+
.map(|pair| pair.to_condition())
|
54 |
+
.collect();
|
55 |
+
|
56 |
+
Filter {
|
57 |
+
should: None,
|
58 |
+
min_should: None,
|
59 |
+
must: Some(conditions),
|
60 |
+
must_not: None,
|
61 |
+
}
|
62 |
+
}
|
63 |
+
}
|
64 |
+
|
65 |
+
impl Validate for Claims {
    // Only the nested `access` claim carries validation rules; any errors
    // from it are reported under the "access" field label.
    fn validate(&self) -> Result<(), ValidationErrors> {
        ValidationErrors::merge_all(Ok(()), "access", self.access.validate())
    }
}
|
src/common/auth/jwt_parser.rs
ADDED
@@ -0,0 +1,155 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
use jsonwebtoken::errors::ErrorKind;
|
2 |
+
use jsonwebtoken::{decode, Algorithm, DecodingKey, Validation};
|
3 |
+
use validator::Validate;
|
4 |
+
|
5 |
+
use super::claims::Claims;
|
6 |
+
use super::AuthError;
|
7 |
+
|
8 |
+
/// Decodes and validates JWTs signed with the configured secret.
#[derive(Clone)]
pub struct JwtParser {
    // Secret-derived key used to verify token signatures.
    key: DecodingKey,
    // Validation parameters (algorithm, leeway, required claims).
    validation: Validation,
}
|
13 |
+
|
14 |
+
impl JwtParser {
    const ALGORITHM: Algorithm = Algorithm::HS256;

    /// Build a parser that verifies HS256 signatures with `secret`.
    pub fn new(secret: &str) -> Self {
        let key = DecodingKey::from_secret(secret.as_bytes());
        let mut validation = Validation::new(Self::ALGORITHM);

        // Don't check the `aud` claim: Qdrant server is the only audience
        validation.validate_aud = false;

        // Expiration time leeway to account for clock skew
        validation.leeway = 30;

        // All claims are optional
        validation.required_spec_claims = Default::default();

        JwtParser { key, validation }
    }

    /// Decode the token and return the claims, this already validates the `exp` claim with some leeway.
    /// Returns None when the token doesn't look like a JWT.
    pub fn decode(&self, token: &str) -> Option<Result<Claims, AuthError>> {
        let claims = match decode::<Claims>(token, &self.key, &self.validation) {
            Ok(token_data) => token_data.claims,
            Err(e) => {
                return match e.kind() {
                    // A structurally valid JWT that is expired or tampered
                    // with: hard failure.
                    ErrorKind::ExpiredSignature | ErrorKind::InvalidSignature => {
                        Some(Err(AuthError::Forbidden(e.to_string())))
                    }
                    // Anything else (e.g. not JWT-shaped at all): signal
                    // `None` so other auth schemes may be tried.
                    _ => None,
                }
            }
        };
        // Structural validation of the decoded claims (delegates to the
        // `Validate` impl on `Claims`).
        if let Err(e) = claims.validate() {
            return Some(Err(AuthError::Unauthorized(e.to_string())));
        }
        Some(Ok(claims))
    }
}
|
53 |
+
|
54 |
+
#[cfg(test)]
mod tests {
    use segment::types::ValueVariants;
    use storage::rbac::{
        Access, CollectionAccess, CollectionAccessList, CollectionAccessMode, GlobalAccessMode,
        PayloadConstraint,
    };

    use super::*;

    /// Sign `claims` with the well-known test secret `"secret"`.
    pub fn create_token(claims: &Claims) -> String {
        use jsonwebtoken::{encode, EncodingKey, Header};

        let key = EncodingKey::from_secret("secret".as_ref());
        let header = Header::new(JwtParser::ALGORITHM);
        encode(&header, claims, &key).unwrap()
    }

    /// Round-trip: encoding then decoding must preserve every claim.
    #[test]
    fn test_jwt_parser() {
        let exp = std::time::SystemTime::now()
            .duration_since(std::time::UNIX_EPOCH)
            .expect("Time went backwards")
            .as_secs();
        let claims = Claims {
            exp: Some(exp),
            access: Access::Collection(CollectionAccessList(vec![CollectionAccess {
                collection: "collection".to_string(),
                access: CollectionAccessMode::ReadWrite,
                payload: Some(PayloadConstraint(
                    vec![
                        (
                            "field1".parse().unwrap(),
                            ValueVariants::String("value".to_string()),
                        ),
                        ("field2".parse().unwrap(), ValueVariants::Integer(42)),
                        ("field3".parse().unwrap(), ValueVariants::Bool(true)),
                    ]
                    .into_iter()
                    .collect(),
                )),
            }])),
            value_exists: None,
        };
        let token = create_token(&claims);

        let secret = "secret";
        let parser = JwtParser::new(secret);
        let decoded_claims = parser.decode(&token).unwrap().unwrap();

        assert_eq!(claims, decoded_claims);
    }

    /// Tokens expired beyond the 30 s leeway must be rejected; tokens
    /// without an `exp` claim never expire.
    #[test]
    fn test_exp_validation() {
        let exp = std::time::SystemTime::now()
            .duration_since(std::time::UNIX_EPOCH)
            .expect("Time went backwards")
            .as_secs()
            - 31; // 31 seconds in the past, bigger than the 30 seconds leeway

        let mut claims = Claims {
            exp: Some(exp),
            access: Access::Global(GlobalAccessMode::Read),
            value_exists: None,
        };

        let token = create_token(&claims);

        let secret = "secret";
        let parser = JwtParser::new(secret);
        assert!(matches!(
            parser.decode(&token),
            Some(Err(AuthError::Forbidden(_)))
        ));

        // Remove the exp claim and it should work
        claims.exp = None;
        let token = create_token(&claims);

        let decoded_claims = parser.decode(&token).unwrap().unwrap();

        assert_eq!(claims, decoded_claims);
    }

    /// Wrong secret → Forbidden; non-JWT garbage → None (not a JWT).
    #[test]
    fn test_invalid_token() {
        let claims = Claims {
            exp: None,
            access: Access::Global(GlobalAccessMode::Read),
            value_exists: None,
        };
        let token = create_token(&claims);

        assert!(matches!(
            JwtParser::new("wrong-secret").decode(&token),
            Some(Err(AuthError::Forbidden(_)))
        ));

        assert!(JwtParser::new("secret").decode("foo.bar.baz").is_none());
    }
}
|
src/common/auth/mod.rs
ADDED
@@ -0,0 +1,165 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
use std::sync::Arc;
|
2 |
+
|
3 |
+
use collection::operations::shard_selector_internal::ShardSelectorInternal;
|
4 |
+
use collection::operations::types::ScrollRequestInternal;
|
5 |
+
use segment::types::{WithPayloadInterface, WithVector};
|
6 |
+
use storage::content_manager::errors::StorageError;
|
7 |
+
use storage::content_manager::toc::TableOfContent;
|
8 |
+
use storage::rbac::Access;
|
9 |
+
|
10 |
+
use self::claims::{Claims, ValueExists};
|
11 |
+
use self::jwt_parser::JwtParser;
|
12 |
+
use super::strings::ct_eq;
|
13 |
+
use crate::settings::ServiceConfig;
|
14 |
+
|
15 |
+
pub mod claims;
pub mod jwt_parser;

// Name of the HTTP header carrying the plain API key.
pub const HTTP_HEADER_API_KEY: &str = "api-key";
|
19 |
+
|
20 |
+
/// The API keys used for auth
///
/// Holds the static read-write/read-only keys, plus an optional JWT parser
/// for RBAC tokens signed with the read-write key.
#[derive(Clone)]
pub struct AuthKeys {
    /// A key allowing Read or Write operations
    read_write: Option<String>,

    /// A key allowing Read operations
    read_only: Option<String>,

    /// A JWT parser, based on the read_write key
    jwt_parser: Option<JwtParser>,

    /// Table of content, needed to do stateful validation of JWT
    toc: Arc<TableOfContent>,
}
|
35 |
+
|
36 |
+
/// Reasons why authentication of a request can fail.
#[derive(Debug)]
pub enum AuthError {
    /// Credentials missing or not recognized.
    Unauthorized(String),
    /// Credentials recognized but the operation is not allowed.
    Forbidden(String),
    /// Underlying storage failure during stateful JWT validation.
    StorageError(StorageError),
}
|
42 |
+
|
43 |
+
impl AuthKeys {
|
44 |
+
fn get_jwt_parser(service_config: &ServiceConfig) -> Option<JwtParser> {
|
45 |
+
if service_config.jwt_rbac.unwrap_or_default() {
|
46 |
+
service_config
|
47 |
+
.api_key
|
48 |
+
.as_ref()
|
49 |
+
.map(|secret| JwtParser::new(secret))
|
50 |
+
} else {
|
51 |
+
None
|
52 |
+
}
|
53 |
+
}
|
54 |
+
|
55 |
+
/// Defines the auth scheme given the service config
|
56 |
+
///
|
57 |
+
/// Returns None if no scheme is specified.
|
58 |
+
pub fn try_create(service_config: &ServiceConfig, toc: Arc<TableOfContent>) -> Option<Self> {
|
59 |
+
match (
|
60 |
+
service_config.api_key.clone(),
|
61 |
+
service_config.read_only_api_key.clone(),
|
62 |
+
) {
|
63 |
+
(None, None) => None,
|
64 |
+
(read_write, read_only) => Some(Self {
|
65 |
+
read_write,
|
66 |
+
read_only,
|
67 |
+
jwt_parser: Self::get_jwt_parser(service_config),
|
68 |
+
toc,
|
69 |
+
}),
|
70 |
+
}
|
71 |
+
}
|
72 |
+
|
73 |
+
/// Validate that the specified request is allowed for given keys.
|
74 |
+
pub async fn validate_request<'a>(
|
75 |
+
&self,
|
76 |
+
get_header: impl Fn(&'a str) -> Option<&'a str>,
|
77 |
+
) -> Result<Access, AuthError> {
|
78 |
+
let Some(key) = get_header(HTTP_HEADER_API_KEY)
|
79 |
+
.or_else(|| get_header("authorization").and_then(|v| v.strip_prefix("Bearer ")))
|
80 |
+
else {
|
81 |
+
return Err(AuthError::Unauthorized(
|
82 |
+
"Must provide an API key or an Authorization bearer token".to_string(),
|
83 |
+
));
|
84 |
+
};
|
85 |
+
|
86 |
+
if self.can_write(key) {
|
87 |
+
return Ok(Access::full("Read-write access by key"));
|
88 |
+
}
|
89 |
+
|
90 |
+
if self.can_read(key) {
|
91 |
+
return Ok(Access::full_ro("Read-only access by key"));
|
92 |
+
}
|
93 |
+
|
94 |
+
if let Some(claims) = self.jwt_parser.as_ref().and_then(|p| p.decode(key)) {
|
95 |
+
let Claims {
|
96 |
+
exp: _, // already validated on decoding
|
97 |
+
access,
|
98 |
+
value_exists,
|
99 |
+
} = claims?;
|
100 |
+
|
101 |
+
if let Some(value_exists) = value_exists {
|
102 |
+
self.validate_value_exists(&value_exists).await?;
|
103 |
+
}
|
104 |
+
|
105 |
+
return Ok(access);
|
106 |
+
}
|
107 |
+
|
108 |
+
Err(AuthError::Unauthorized(
|
109 |
+
"Invalid API key or JWT".to_string(),
|
110 |
+
))
|
111 |
+
}
|
112 |
+
|
113 |
+
async fn validate_value_exists(&self, value_exists: &ValueExists) -> Result<(), AuthError> {
|
114 |
+
let scroll_req = ScrollRequestInternal {
|
115 |
+
offset: None,
|
116 |
+
limit: Some(1),
|
117 |
+
filter: Some(value_exists.to_filter()),
|
118 |
+
with_payload: Some(WithPayloadInterface::Bool(false)),
|
119 |
+
with_vector: WithVector::Bool(false),
|
120 |
+
order_by: None,
|
121 |
+
};
|
122 |
+
|
123 |
+
let res = self
|
124 |
+
.toc
|
125 |
+
.scroll(
|
126 |
+
value_exists.get_collection(),
|
127 |
+
scroll_req,
|
128 |
+
None,
|
129 |
+
None, // no timeout
|
130 |
+
ShardSelectorInternal::All,
|
131 |
+
Access::full("JWT stateful validation"),
|
132 |
+
)
|
133 |
+
.await
|
134 |
+
.map_err(|e| match e {
|
135 |
+
StorageError::NotFound { .. } => {
|
136 |
+
AuthError::Forbidden("Invalid JWT, stateful validation failed".to_string())
|
137 |
+
}
|
138 |
+
_ => AuthError::StorageError(e),
|
139 |
+
})?;
|
140 |
+
|
141 |
+
if res.points.is_empty() {
|
142 |
+
return Err(AuthError::Unauthorized(
|
143 |
+
"Invalid JWT, stateful validation failed".to_string(),
|
144 |
+
));
|
145 |
+
};
|
146 |
+
|
147 |
+
Ok(())
|
148 |
+
}
|
149 |
+
|
150 |
+
/// Check if a key is allowed to read
|
151 |
+
#[inline]
|
152 |
+
fn can_read(&self, key: &str) -> bool {
|
153 |
+
self.read_only
|
154 |
+
.as_ref()
|
155 |
+
.is_some_and(|ro_key| ct_eq(ro_key, key))
|
156 |
+
}
|
157 |
+
|
158 |
+
/// Check if a key is allowed to write
|
159 |
+
#[inline]
|
160 |
+
fn can_write(&self, key: &str) -> bool {
|
161 |
+
self.read_write
|
162 |
+
.as_ref()
|
163 |
+
.is_some_and(|rw_key| ct_eq(rw_key, key))
|
164 |
+
}
|
165 |
+
}
|
src/common/collections.rs
ADDED
@@ -0,0 +1,834 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
use std::collections::HashMap;
|
2 |
+
use std::sync::Arc;
|
3 |
+
use std::time::Duration;
|
4 |
+
|
5 |
+
use api::grpc::qdrant::CollectionExists;
|
6 |
+
use api::rest::models::{CollectionDescription, CollectionsResponse};
|
7 |
+
use collection::config::ShardingMethod;
|
8 |
+
use collection::operations::cluster_ops::{
|
9 |
+
AbortTransferOperation, ClusterOperations, DropReplicaOperation, MoveShardOperation,
|
10 |
+
ReplicateShardOperation, ReshardingDirection, RestartTransfer, RestartTransferOperation,
|
11 |
+
StartResharding,
|
12 |
+
};
|
13 |
+
use collection::operations::shard_selector_internal::ShardSelectorInternal;
|
14 |
+
use collection::operations::snapshot_ops::SnapshotDescription;
|
15 |
+
use collection::operations::types::{
|
16 |
+
AliasDescription, CollectionClusterInfo, CollectionInfo, CollectionsAliasesResponse,
|
17 |
+
};
|
18 |
+
use collection::operations::verification::new_unchecked_verification_pass;
|
19 |
+
use collection::shards::replica_set;
|
20 |
+
use collection::shards::resharding::ReshardKey;
|
21 |
+
use collection::shards::shard::{PeerId, ShardId, ShardsPlacement};
|
22 |
+
use collection::shards::transfer::{ShardTransfer, ShardTransferKey, ShardTransferRestart};
|
23 |
+
use itertools::Itertools;
|
24 |
+
use rand::prelude::SliceRandom;
|
25 |
+
use rand::seq::IteratorRandom;
|
26 |
+
use storage::content_manager::collection_meta_ops::ShardTransferOperations::{Abort, Start};
|
27 |
+
use storage::content_manager::collection_meta_ops::{
|
28 |
+
CollectionMetaOperations, CreateShardKey, DropShardKey, ReshardingOperation,
|
29 |
+
SetShardReplicaState, ShardTransferOperations, UpdateCollectionOperation,
|
30 |
+
};
|
31 |
+
use storage::content_manager::errors::StorageError;
|
32 |
+
use storage::content_manager::toc::TableOfContent;
|
33 |
+
use storage::dispatcher::Dispatcher;
|
34 |
+
use storage::rbac::{Access, AccessRequirements};
|
35 |
+
|
36 |
+
pub async fn do_collection_exists(
|
37 |
+
toc: &TableOfContent,
|
38 |
+
access: Access,
|
39 |
+
name: &str,
|
40 |
+
) -> Result<CollectionExists, StorageError> {
|
41 |
+
let collection_pass = access.check_collection_access(name, AccessRequirements::new())?;
|
42 |
+
|
43 |
+
// if this returns Ok, it means the collection exists.
|
44 |
+
// if not, we check that the error is NotFound
|
45 |
+
let Err(error) = toc.get_collection(&collection_pass).await else {
|
46 |
+
return Ok(CollectionExists { exists: true });
|
47 |
+
};
|
48 |
+
match error {
|
49 |
+
StorageError::NotFound { .. } => Ok(CollectionExists { exists: false }),
|
50 |
+
e => Err(e),
|
51 |
+
}
|
52 |
+
}
|
53 |
+
|
54 |
+
pub async fn do_get_collection(
|
55 |
+
toc: &TableOfContent,
|
56 |
+
access: Access,
|
57 |
+
name: &str,
|
58 |
+
shard_selection: Option<ShardId>,
|
59 |
+
) -> Result<CollectionInfo, StorageError> {
|
60 |
+
let collection_pass =
|
61 |
+
access.check_collection_access(name, AccessRequirements::new().whole())?;
|
62 |
+
|
63 |
+
let collection = toc.get_collection(&collection_pass).await?;
|
64 |
+
|
65 |
+
let shard_selection = match shard_selection {
|
66 |
+
None => ShardSelectorInternal::All,
|
67 |
+
Some(shard_id) => ShardSelectorInternal::ShardId(shard_id),
|
68 |
+
};
|
69 |
+
|
70 |
+
Ok(collection.info(&shard_selection).await?)
|
71 |
+
}
|
72 |
+
|
73 |
+
pub async fn do_list_collections(
|
74 |
+
toc: &TableOfContent,
|
75 |
+
access: Access,
|
76 |
+
) -> Result<CollectionsResponse, StorageError> {
|
77 |
+
let collections = toc
|
78 |
+
.all_collections(&access)
|
79 |
+
.await
|
80 |
+
.into_iter()
|
81 |
+
.map(|pass| CollectionDescription {
|
82 |
+
name: pass.name().to_string(),
|
83 |
+
})
|
84 |
+
.collect_vec();
|
85 |
+
|
86 |
+
Ok(CollectionsResponse { collections })
|
87 |
+
}
|
88 |
+
|
89 |
+
/// Construct shards-replicas layout for the shard from the given scope of peers
|
90 |
+
/// Example:
|
91 |
+
/// Shards: 3
|
92 |
+
/// Replicas: 2
|
93 |
+
/// Peers: [A, B, C]
|
94 |
+
///
|
95 |
+
/// Placement:
|
96 |
+
/// [
|
97 |
+
/// [A, B]
|
98 |
+
/// [B, C]
|
99 |
+
/// [A, C]
|
100 |
+
/// ]
|
101 |
+
fn generate_even_placement(
|
102 |
+
mut pool: Vec<PeerId>,
|
103 |
+
shard_number: usize,
|
104 |
+
replication_factor: usize,
|
105 |
+
) -> ShardsPlacement {
|
106 |
+
let mut exact_placement = Vec::new();
|
107 |
+
let mut rng = rand::thread_rng();
|
108 |
+
pool.shuffle(&mut rng);
|
109 |
+
let mut loop_iter = pool.iter().cycle();
|
110 |
+
|
111 |
+
// pool: [1,2,3,4]
|
112 |
+
// shuf_pool: [2,3,4,1]
|
113 |
+
//
|
114 |
+
// loop_iter: [2, 3, 4, 1, 2, 3, 4, 1, 2, 3, 4, 1, 2, 3, 4, 1,...]
|
115 |
+
// shard_placement: [2, 3, 4][1, 2, 3][4, 1, 2][3, 4, 1][2, 3, 4]
|
116 |
+
|
117 |
+
let max_replication_factor = std::cmp::min(replication_factor, pool.len());
|
118 |
+
for _shard in 0..shard_number {
|
119 |
+
let mut shard_placement = Vec::new();
|
120 |
+
for _replica in 0..max_replication_factor {
|
121 |
+
shard_placement.push(*loop_iter.next().unwrap());
|
122 |
+
}
|
123 |
+
exact_placement.push(shard_placement);
|
124 |
+
}
|
125 |
+
exact_placement
|
126 |
+
}
|
127 |
+
|
128 |
+
pub async fn do_list_collection_aliases(
|
129 |
+
toc: &TableOfContent,
|
130 |
+
access: Access,
|
131 |
+
collection_name: &str,
|
132 |
+
) -> Result<CollectionsAliasesResponse, StorageError> {
|
133 |
+
let collection_pass =
|
134 |
+
access.check_collection_access(collection_name, AccessRequirements::new())?;
|
135 |
+
let aliases: Vec<AliasDescription> = toc
|
136 |
+
.collection_aliases(&collection_pass, &access)
|
137 |
+
.await?
|
138 |
+
.into_iter()
|
139 |
+
.map(|alias| AliasDescription {
|
140 |
+
alias_name: alias,
|
141 |
+
collection_name: collection_name.to_string(),
|
142 |
+
})
|
143 |
+
.collect();
|
144 |
+
Ok(CollectionsAliasesResponse { aliases })
|
145 |
+
}
|
146 |
+
|
147 |
+
pub async fn do_list_aliases(
|
148 |
+
toc: &TableOfContent,
|
149 |
+
access: Access,
|
150 |
+
) -> Result<CollectionsAliasesResponse, StorageError> {
|
151 |
+
let aliases = toc.list_aliases(&access).await?;
|
152 |
+
Ok(CollectionsAliasesResponse { aliases })
|
153 |
+
}
|
154 |
+
|
155 |
+
pub async fn do_list_snapshots(
|
156 |
+
toc: &TableOfContent,
|
157 |
+
access: Access,
|
158 |
+
collection_name: &str,
|
159 |
+
) -> Result<Vec<SnapshotDescription>, StorageError> {
|
160 |
+
let collection_pass =
|
161 |
+
access.check_collection_access(collection_name, AccessRequirements::new().whole())?;
|
162 |
+
Ok(toc
|
163 |
+
.get_collection(&collection_pass)
|
164 |
+
.await?
|
165 |
+
.list_snapshots()
|
166 |
+
.await?)
|
167 |
+
}
|
168 |
+
|
169 |
+
pub async fn do_create_snapshot(
|
170 |
+
toc: Arc<TableOfContent>,
|
171 |
+
access: Access,
|
172 |
+
collection_name: &str,
|
173 |
+
) -> Result<SnapshotDescription, StorageError> {
|
174 |
+
let collection_pass = access
|
175 |
+
.check_collection_access(collection_name, AccessRequirements::new().write().whole())?
|
176 |
+
.into_static();
|
177 |
+
|
178 |
+
let result = tokio::spawn(async move { toc.create_snapshot(&collection_pass).await }).await??;
|
179 |
+
|
180 |
+
Ok(result)
|
181 |
+
}
|
182 |
+
|
183 |
+
pub async fn do_get_collection_cluster(
|
184 |
+
toc: &TableOfContent,
|
185 |
+
access: Access,
|
186 |
+
name: &str,
|
187 |
+
) -> Result<CollectionClusterInfo, StorageError> {
|
188 |
+
let collection_pass =
|
189 |
+
access.check_collection_access(name, AccessRequirements::new().whole())?;
|
190 |
+
let collection = toc.get_collection(&collection_pass).await?;
|
191 |
+
Ok(collection.cluster_info(toc.this_peer_id).await?)
|
192 |
+
}
|
193 |
+
|
194 |
+
pub async fn do_update_collection_cluster(
|
195 |
+
dispatcher: &Dispatcher,
|
196 |
+
collection_name: String,
|
197 |
+
operation: ClusterOperations,
|
198 |
+
access: Access,
|
199 |
+
wait_timeout: Option<Duration>,
|
200 |
+
) -> Result<bool, StorageError> {
|
201 |
+
let collection_pass = access.check_collection_access(
|
202 |
+
&collection_name,
|
203 |
+
AccessRequirements::new().write().manage().whole(),
|
204 |
+
)?;
|
205 |
+
|
206 |
+
if dispatcher.consensus_state().is_none() {
|
207 |
+
return Err(StorageError::BadRequest {
|
208 |
+
description: "Distributed mode disabled".to_string(),
|
209 |
+
});
|
210 |
+
}
|
211 |
+
let consensus_state = dispatcher.consensus_state().unwrap();
|
212 |
+
|
213 |
+
let get_all_peer_ids = || {
|
214 |
+
consensus_state
|
215 |
+
.persistent
|
216 |
+
.read()
|
217 |
+
.peer_address_by_id
|
218 |
+
.read()
|
219 |
+
.keys()
|
220 |
+
.cloned()
|
221 |
+
.collect_vec()
|
222 |
+
};
|
223 |
+
|
224 |
+
let validate_peer_exists = |peer_id| {
|
225 |
+
let target_peer_exist = consensus_state
|
226 |
+
.persistent
|
227 |
+
.read()
|
228 |
+
.peer_address_by_id
|
229 |
+
.read()
|
230 |
+
.contains_key(&peer_id);
|
231 |
+
if !target_peer_exist {
|
232 |
+
return Err(StorageError::BadRequest {
|
233 |
+
description: format!("Peer {peer_id} does not exist"),
|
234 |
+
});
|
235 |
+
}
|
236 |
+
Ok(())
|
237 |
+
};
|
238 |
+
|
239 |
+
// All checks should've been done at this point.
|
240 |
+
let pass = new_unchecked_verification_pass();
|
241 |
+
|
242 |
+
let collection = dispatcher
|
243 |
+
.toc(&access, &pass)
|
244 |
+
.get_collection(&collection_pass)
|
245 |
+
.await?;
|
246 |
+
|
247 |
+
match operation {
|
248 |
+
ClusterOperations::MoveShard(MoveShardOperation { move_shard }) => {
|
249 |
+
// validate shard to move
|
250 |
+
if !collection.contains_shard(move_shard.shard_id).await {
|
251 |
+
return Err(StorageError::BadRequest {
|
252 |
+
description: format!(
|
253 |
+
"Shard {} of {} does not exist",
|
254 |
+
move_shard.shard_id, collection_name
|
255 |
+
),
|
256 |
+
});
|
257 |
+
};
|
258 |
+
|
259 |
+
// validate target and source peer exists
|
260 |
+
validate_peer_exists(move_shard.to_peer_id)?;
|
261 |
+
validate_peer_exists(move_shard.from_peer_id)?;
|
262 |
+
|
263 |
+
// submit operation to consensus
|
264 |
+
dispatcher
|
265 |
+
.submit_collection_meta_op(
|
266 |
+
CollectionMetaOperations::TransferShard(
|
267 |
+
collection_name,
|
268 |
+
Start(ShardTransfer {
|
269 |
+
shard_id: move_shard.shard_id,
|
270 |
+
to_shard_id: move_shard.to_shard_id,
|
271 |
+
to: move_shard.to_peer_id,
|
272 |
+
from: move_shard.from_peer_id,
|
273 |
+
sync: false,
|
274 |
+
method: move_shard.method,
|
275 |
+
}),
|
276 |
+
),
|
277 |
+
access,
|
278 |
+
wait_timeout,
|
279 |
+
)
|
280 |
+
.await
|
281 |
+
}
|
282 |
+
ClusterOperations::ReplicateShard(ReplicateShardOperation { replicate_shard }) => {
|
283 |
+
// validate shard to move
|
284 |
+
if !collection.contains_shard(replicate_shard.shard_id).await {
|
285 |
+
return Err(StorageError::BadRequest {
|
286 |
+
description: format!(
|
287 |
+
"Shard {} of {} does not exist",
|
288 |
+
replicate_shard.shard_id, collection_name
|
289 |
+
),
|
290 |
+
});
|
291 |
+
};
|
292 |
+
|
293 |
+
// validate target peer exists
|
294 |
+
validate_peer_exists(replicate_shard.to_peer_id)?;
|
295 |
+
|
296 |
+
// validate source peer exists
|
297 |
+
validate_peer_exists(replicate_shard.from_peer_id)?;
|
298 |
+
|
299 |
+
// submit operation to consensus
|
300 |
+
dispatcher
|
301 |
+
.submit_collection_meta_op(
|
302 |
+
CollectionMetaOperations::TransferShard(
|
303 |
+
collection_name,
|
304 |
+
Start(ShardTransfer {
|
305 |
+
shard_id: replicate_shard.shard_id,
|
306 |
+
to_shard_id: replicate_shard.to_shard_id,
|
307 |
+
to: replicate_shard.to_peer_id,
|
308 |
+
from: replicate_shard.from_peer_id,
|
309 |
+
sync: true,
|
310 |
+
method: replicate_shard.method,
|
311 |
+
}),
|
312 |
+
),
|
313 |
+
access,
|
314 |
+
wait_timeout,
|
315 |
+
)
|
316 |
+
.await
|
317 |
+
}
|
318 |
+
ClusterOperations::AbortTransfer(AbortTransferOperation { abort_transfer }) => {
|
319 |
+
let transfer = ShardTransferKey {
|
320 |
+
shard_id: abort_transfer.shard_id,
|
321 |
+
to_shard_id: abort_transfer.to_shard_id,
|
322 |
+
to: abort_transfer.to_peer_id,
|
323 |
+
from: abort_transfer.from_peer_id,
|
324 |
+
};
|
325 |
+
|
326 |
+
if !collection.check_transfer_exists(&transfer).await {
|
327 |
+
return Err(StorageError::NotFound {
|
328 |
+
description: format!(
|
329 |
+
"Shard transfer {} -> {} for collection {}:{} does not exist",
|
330 |
+
transfer.from, transfer.to, collection_name, transfer.shard_id
|
331 |
+
),
|
332 |
+
});
|
333 |
+
}
|
334 |
+
|
335 |
+
dispatcher
|
336 |
+
.submit_collection_meta_op(
|
337 |
+
CollectionMetaOperations::TransferShard(
|
338 |
+
collection_name,
|
339 |
+
Abort {
|
340 |
+
transfer,
|
341 |
+
reason: "user request".to_string(),
|
342 |
+
},
|
343 |
+
),
|
344 |
+
access,
|
345 |
+
wait_timeout,
|
346 |
+
)
|
347 |
+
.await
|
348 |
+
}
|
349 |
+
ClusterOperations::DropReplica(DropReplicaOperation { drop_replica }) => {
|
350 |
+
if !collection.contains_shard(drop_replica.shard_id).await {
|
351 |
+
return Err(StorageError::BadRequest {
|
352 |
+
description: format!(
|
353 |
+
"Shard {} of {} does not exist",
|
354 |
+
drop_replica.shard_id, collection_name
|
355 |
+
),
|
356 |
+
});
|
357 |
+
};
|
358 |
+
|
359 |
+
validate_peer_exists(drop_replica.peer_id)?;
|
360 |
+
|
361 |
+
let mut update_operation = UpdateCollectionOperation::new_empty(collection_name);
|
362 |
+
|
363 |
+
update_operation.set_shard_replica_changes(vec![replica_set::Change::Remove(
|
364 |
+
drop_replica.shard_id,
|
365 |
+
drop_replica.peer_id,
|
366 |
+
)]);
|
367 |
+
|
368 |
+
dispatcher
|
369 |
+
.submit_collection_meta_op(
|
370 |
+
CollectionMetaOperations::UpdateCollection(update_operation),
|
371 |
+
access,
|
372 |
+
wait_timeout,
|
373 |
+
)
|
374 |
+
.await
|
375 |
+
}
|
376 |
+
ClusterOperations::CreateShardingKey(create_sharding_key_op) => {
|
377 |
+
let create_sharding_key = create_sharding_key_op.create_sharding_key;
|
378 |
+
|
379 |
+
// Validate that:
|
380 |
+
// - proper sharding method is used
|
381 |
+
// - key does not exist yet
|
382 |
+
//
|
383 |
+
// If placement suggested:
|
384 |
+
// - Peers exist
|
385 |
+
|
386 |
+
let state = collection.state().await;
|
387 |
+
|
388 |
+
match state.config.params.sharding_method.unwrap_or_default() {
|
389 |
+
ShardingMethod::Auto => {
|
390 |
+
return Err(StorageError::bad_request(
|
391 |
+
"Shard Key cannot be created with Auto sharding method",
|
392 |
+
));
|
393 |
+
}
|
394 |
+
ShardingMethod::Custom => {}
|
395 |
+
}
|
396 |
+
|
397 |
+
let shard_number = create_sharding_key
|
398 |
+
.shards_number
|
399 |
+
.unwrap_or(state.config.params.shard_number)
|
400 |
+
.get() as usize;
|
401 |
+
let replication_factor = create_sharding_key
|
402 |
+
.replication_factor
|
403 |
+
.unwrap_or(state.config.params.replication_factor)
|
404 |
+
.get() as usize;
|
405 |
+
|
406 |
+
let shard_keys_mapping = state.shards_key_mapping;
|
407 |
+
if shard_keys_mapping.contains_key(&create_sharding_key.shard_key) {
|
408 |
+
return Err(StorageError::BadRequest {
|
409 |
+
description: format!(
|
410 |
+
"Sharding key {} already exists for collection {}",
|
411 |
+
create_sharding_key.shard_key, collection_name
|
412 |
+
),
|
413 |
+
});
|
414 |
+
}
|
415 |
+
|
416 |
+
let peers_pool: Vec<_> = if let Some(placement) = create_sharding_key.placement {
|
417 |
+
if placement.is_empty() {
|
418 |
+
return Err(StorageError::BadRequest {
|
419 |
+
description: format!(
|
420 |
+
"Sharding key {} placement cannot be empty. If you want to use random placement, do not specify placement",
|
421 |
+
create_sharding_key.shard_key
|
422 |
+
),
|
423 |
+
});
|
424 |
+
}
|
425 |
+
|
426 |
+
for peer_id in placement.iter().copied() {
|
427 |
+
validate_peer_exists(peer_id)?;
|
428 |
+
}
|
429 |
+
placement
|
430 |
+
} else {
|
431 |
+
get_all_peer_ids()
|
432 |
+
};
|
433 |
+
|
434 |
+
let exact_placement =
|
435 |
+
generate_even_placement(peers_pool, shard_number, replication_factor);
|
436 |
+
|
437 |
+
dispatcher
|
438 |
+
.submit_collection_meta_op(
|
439 |
+
CollectionMetaOperations::CreateShardKey(CreateShardKey {
|
440 |
+
collection_name,
|
441 |
+
shard_key: create_sharding_key.shard_key,
|
442 |
+
placement: exact_placement,
|
443 |
+
}),
|
444 |
+
access,
|
445 |
+
wait_timeout,
|
446 |
+
)
|
447 |
+
.await
|
448 |
+
}
|
449 |
+
ClusterOperations::DropShardingKey(drop_sharding_key_op) => {
|
450 |
+
let drop_sharding_key = drop_sharding_key_op.drop_sharding_key;
|
451 |
+
// Validate that:
|
452 |
+
// - proper sharding method is used
|
453 |
+
// - key does exist
|
454 |
+
|
455 |
+
let state = collection.state().await;
|
456 |
+
|
457 |
+
match state.config.params.sharding_method.unwrap_or_default() {
|
458 |
+
ShardingMethod::Auto => {
|
459 |
+
return Err(StorageError::bad_request(
|
460 |
+
"Shard Key cannot be created with Auto sharding method",
|
461 |
+
));
|
462 |
+
}
|
463 |
+
ShardingMethod::Custom => {}
|
464 |
+
}
|
465 |
+
|
466 |
+
let shard_keys_mapping = state.shards_key_mapping;
|
467 |
+
if !shard_keys_mapping.contains_key(&drop_sharding_key.shard_key) {
|
468 |
+
return Err(StorageError::BadRequest {
|
469 |
+
description: format!(
|
470 |
+
"Sharding key {} does not exists for collection {}",
|
471 |
+
drop_sharding_key.shard_key, collection_name
|
472 |
+
),
|
473 |
+
});
|
474 |
+
}
|
475 |
+
|
476 |
+
dispatcher
|
477 |
+
.submit_collection_meta_op(
|
478 |
+
CollectionMetaOperations::DropShardKey(DropShardKey {
|
479 |
+
collection_name,
|
480 |
+
shard_key: drop_sharding_key.shard_key,
|
481 |
+
}),
|
482 |
+
access,
|
483 |
+
wait_timeout,
|
484 |
+
)
|
485 |
+
.await
|
486 |
+
}
|
487 |
+
ClusterOperations::RestartTransfer(RestartTransferOperation { restart_transfer }) => {
|
488 |
+
// TODO(reshading): Deduplicate resharding operations handling?
|
489 |
+
|
490 |
+
let RestartTransfer {
|
491 |
+
shard_id,
|
492 |
+
to_shard_id,
|
493 |
+
from_peer_id,
|
494 |
+
to_peer_id,
|
495 |
+
method,
|
496 |
+
} = restart_transfer;
|
497 |
+
|
498 |
+
let transfer_key = ShardTransferKey {
|
499 |
+
shard_id,
|
500 |
+
to_shard_id,
|
501 |
+
to: to_peer_id,
|
502 |
+
from: from_peer_id,
|
503 |
+
};
|
504 |
+
|
505 |
+
if !collection.check_transfer_exists(&transfer_key).await {
|
506 |
+
return Err(StorageError::NotFound {
|
507 |
+
description: format!(
|
508 |
+
"Shard transfer {} -> {} for collection {}:{} does not exist",
|
509 |
+
transfer_key.from, transfer_key.to, collection_name, transfer_key.shard_id
|
510 |
+
),
|
511 |
+
});
|
512 |
+
}
|
513 |
+
|
514 |
+
dispatcher
|
515 |
+
.submit_collection_meta_op(
|
516 |
+
CollectionMetaOperations::TransferShard(
|
517 |
+
collection_name,
|
518 |
+
ShardTransferOperations::Restart(ShardTransferRestart {
|
519 |
+
shard_id,
|
520 |
+
to_shard_id,
|
521 |
+
to: to_peer_id,
|
522 |
+
from: from_peer_id,
|
523 |
+
method,
|
524 |
+
}),
|
525 |
+
),
|
526 |
+
access,
|
527 |
+
wait_timeout,
|
528 |
+
)
|
529 |
+
.await
|
530 |
+
}
|
531 |
+
ClusterOperations::StartResharding(op) => {
|
532 |
+
let StartResharding {
|
533 |
+
direction,
|
534 |
+
peer_id,
|
535 |
+
shard_key,
|
536 |
+
} = op.start_resharding;
|
537 |
+
|
538 |
+
let collection_state = collection.state().await;
|
539 |
+
|
540 |
+
if let Some(shard_key) = &shard_key {
|
541 |
+
if !collection_state.shards_key_mapping.contains_key(shard_key) {
|
542 |
+
return Err(StorageError::bad_request(format!(
|
543 |
+
"sharding key {shard_key} does not exists for collection {collection_name}"
|
544 |
+
)));
|
545 |
+
}
|
546 |
+
}
|
547 |
+
|
548 |
+
let shard_id = match (direction, shard_key.as_ref()) {
|
549 |
+
// When scaling up, just pick the next shard ID
|
550 |
+
(ReshardingDirection::Up, _) => {
|
551 |
+
collection_state
|
552 |
+
.shards
|
553 |
+
.keys()
|
554 |
+
.copied()
|
555 |
+
.max()
|
556 |
+
.expect("collection must contain shards")
|
557 |
+
+ 1
|
558 |
+
}
|
559 |
+
// When scaling down without shard keys, pick the last shard ID
|
560 |
+
(ReshardingDirection::Down, None) => collection_state
|
561 |
+
.shards
|
562 |
+
.keys()
|
563 |
+
.copied()
|
564 |
+
.max()
|
565 |
+
.expect("collection must contain shards"),
|
566 |
+
// When scaling down with shard keys, pick the last shard ID of that key
|
567 |
+
(ReshardingDirection::Down, Some(shard_key)) => collection_state
|
568 |
+
.shards_key_mapping
|
569 |
+
.get(shard_key)
|
570 |
+
.expect("specified shard key must exist")
|
571 |
+
.iter()
|
572 |
+
.copied()
|
573 |
+
.max()
|
574 |
+
.expect("collection must contain shards"),
|
575 |
+
};
|
576 |
+
|
577 |
+
let peer_id = match (peer_id, direction) {
|
578 |
+
// Select user specified peer, but make sure it exists
|
579 |
+
(Some(peer_id), _) => {
|
580 |
+
validate_peer_exists(peer_id)?;
|
581 |
+
peer_id
|
582 |
+
}
|
583 |
+
|
584 |
+
// When scaling up, select peer with least number of shards for this collection
|
585 |
+
(None, ReshardingDirection::Up) => {
|
586 |
+
let mut shards_on_peers = collection_state
|
587 |
+
.shards
|
588 |
+
.values()
|
589 |
+
.flat_map(|shard_info| shard_info.replicas.keys())
|
590 |
+
.fold(HashMap::new(), |mut counts, peer_id| {
|
591 |
+
*counts.entry(*peer_id).or_insert(0) += 1;
|
592 |
+
counts
|
593 |
+
});
|
594 |
+
for peer_id in get_all_peer_ids() {
|
595 |
+
// Add registered peers not holding any shard yet
|
596 |
+
shards_on_peers.entry(peer_id).or_insert(0);
|
597 |
+
}
|
598 |
+
shards_on_peers
|
599 |
+
.into_iter()
|
600 |
+
.min_by_key(|(_, count)| *count)
|
601 |
+
.map(|(peer_id, _)| peer_id)
|
602 |
+
.expect("expected at least one peer")
|
603 |
+
}
|
604 |
+
|
605 |
+
// When scaling down, select random peer that contains the shard we're dropping
|
606 |
+
// Other peers work, but are less efficient due to remote operations
|
607 |
+
(None, ReshardingDirection::Down) => collection_state
|
608 |
+
.shards
|
609 |
+
.get(&shard_id)
|
610 |
+
.expect("select shard ID must always exist in collection state")
|
611 |
+
.replicas
|
612 |
+
.keys()
|
613 |
+
.choose(&mut rand::thread_rng())
|
614 |
+
.copied()
|
615 |
+
.unwrap(),
|
616 |
+
};
|
617 |
+
|
618 |
+
if let Some(resharding) = &collection_state.resharding {
|
619 |
+
return Err(StorageError::bad_request(format!(
|
620 |
+
"resharding {resharding:?} is already in progress \
|
621 |
+
for collection {collection_name}"
|
622 |
+
)));
|
623 |
+
}
|
624 |
+
|
625 |
+
dispatcher
|
626 |
+
.submit_collection_meta_op(
|
627 |
+
CollectionMetaOperations::Resharding(
|
628 |
+
collection_name.clone(),
|
629 |
+
ReshardingOperation::Start(ReshardKey {
|
630 |
+
direction,
|
631 |
+
peer_id,
|
632 |
+
shard_id,
|
633 |
+
shard_key,
|
634 |
+
}),
|
635 |
+
),
|
636 |
+
access,
|
637 |
+
wait_timeout,
|
638 |
+
)
|
639 |
+
.await
|
640 |
+
}
|
641 |
+
ClusterOperations::AbortResharding(_) => {
|
642 |
+
// TODO(reshading): Deduplicate resharding operations handling?
|
643 |
+
|
644 |
+
let Some(state) = collection.resharding_state().await else {
|
645 |
+
return Err(StorageError::bad_request(format!(
|
646 |
+
"resharding is not in progress for collection {collection_name}"
|
647 |
+
)));
|
648 |
+
};
|
649 |
+
|
650 |
+
dispatcher
|
651 |
+
.submit_collection_meta_op(
|
652 |
+
CollectionMetaOperations::Resharding(
|
653 |
+
collection_name.clone(),
|
654 |
+
ReshardingOperation::Abort(ReshardKey {
|
655 |
+
direction: state.direction,
|
656 |
+
peer_id: state.peer_id,
|
657 |
+
shard_id: state.shard_id,
|
658 |
+
shard_key: state.shard_key.clone(),
|
659 |
+
}),
|
660 |
+
),
|
661 |
+
access,
|
662 |
+
wait_timeout,
|
663 |
+
)
|
664 |
+
.await
|
665 |
+
}
|
666 |
+
ClusterOperations::FinishResharding(_) => {
|
667 |
+
// TODO(resharding): Deduplicate resharding operations handling?
|
668 |
+
|
669 |
+
let Some(state) = collection.resharding_state().await else {
|
670 |
+
return Err(StorageError::bad_request(format!(
|
671 |
+
"resharding is not in progress for collection {collection_name}"
|
672 |
+
)));
|
673 |
+
};
|
674 |
+
|
675 |
+
dispatcher
|
676 |
+
.submit_collection_meta_op(
|
677 |
+
CollectionMetaOperations::Resharding(
|
678 |
+
collection_name.clone(),
|
679 |
+
ReshardingOperation::Finish(state.key()),
|
680 |
+
),
|
681 |
+
access,
|
682 |
+
wait_timeout,
|
683 |
+
)
|
684 |
+
.await
|
685 |
+
}
|
686 |
+
|
687 |
+
ClusterOperations::FinishMigratingPoints(op) => {
|
688 |
+
// TODO(resharding): Deduplicate resharding operations handling?
|
689 |
+
|
690 |
+
let Some(state) = collection.resharding_state().await else {
|
691 |
+
return Err(StorageError::bad_request(format!(
|
692 |
+
"resharding is not in progress for collection {collection_name}"
|
693 |
+
)));
|
694 |
+
};
|
695 |
+
|
696 |
+
let op = op.finish_migrating_points;
|
697 |
+
|
698 |
+
let shard_id = match (op.shard_id, state.direction) {
|
699 |
+
(Some(shard_id), _) => shard_id,
|
700 |
+
(None, ReshardingDirection::Up) => state.shard_id,
|
701 |
+
(None, ReshardingDirection::Down) => {
|
702 |
+
return Err(StorageError::bad_request(
|
703 |
+
"shard ID must be specified when resharding down",
|
704 |
+
));
|
705 |
+
}
|
706 |
+
};
|
707 |
+
|
708 |
+
let peer_id = match (op.peer_id, state.direction) {
|
709 |
+
(Some(peer_id), _) => peer_id,
|
710 |
+
(None, ReshardingDirection::Up) => state.peer_id,
|
711 |
+
(None, ReshardingDirection::Down) => {
|
712 |
+
return Err(StorageError::bad_request(
|
713 |
+
"peer ID must be specified when resharding down",
|
714 |
+
));
|
715 |
+
}
|
716 |
+
};
|
717 |
+
|
718 |
+
dispatcher
|
719 |
+
.submit_collection_meta_op(
|
720 |
+
CollectionMetaOperations::SetShardReplicaState(SetShardReplicaState {
|
721 |
+
collection_name: collection_name.clone(),
|
722 |
+
shard_id,
|
723 |
+
peer_id,
|
724 |
+
state: replica_set::ReplicaState::Active,
|
725 |
+
from_state: Some(replica_set::ReplicaState::Resharding),
|
726 |
+
}),
|
727 |
+
access,
|
728 |
+
wait_timeout,
|
729 |
+
)
|
730 |
+
.await
|
731 |
+
}
|
732 |
+
|
733 |
+
ClusterOperations::CommitReadHashRing(_) => {
|
734 |
+
// TODO(reshading): Deduplicate resharding operations handling?
|
735 |
+
|
736 |
+
let Some(state) = collection.resharding_state().await else {
|
737 |
+
return Err(StorageError::bad_request(format!(
|
738 |
+
"resharding is not in progress for collection {collection_name}"
|
739 |
+
)));
|
740 |
+
};
|
741 |
+
|
742 |
+
// TODO(resharding): Add precondition checks?
|
743 |
+
|
744 |
+
dispatcher
|
745 |
+
.submit_collection_meta_op(
|
746 |
+
CollectionMetaOperations::Resharding(
|
747 |
+
collection_name.clone(),
|
748 |
+
ReshardingOperation::CommitRead(ReshardKey {
|
749 |
+
direction: state.direction,
|
750 |
+
peer_id: state.peer_id,
|
751 |
+
shard_id: state.shard_id,
|
752 |
+
shard_key: state.shard_key.clone(),
|
753 |
+
}),
|
754 |
+
),
|
755 |
+
access,
|
756 |
+
wait_timeout,
|
757 |
+
)
|
758 |
+
.await
|
759 |
+
}
|
760 |
+
|
761 |
+
ClusterOperations::CommitWriteHashRing(_) => {
|
762 |
+
// TODO(reshading): Deduplicate resharding operations handling?
|
763 |
+
|
764 |
+
let Some(state) = collection.resharding_state().await else {
|
765 |
+
return Err(StorageError::bad_request(format!(
|
766 |
+
"resharding is not in progress for collection {collection_name}"
|
767 |
+
)));
|
768 |
+
};
|
769 |
+
|
770 |
+
// TODO(resharding): Add precondition checks?
|
771 |
+
|
772 |
+
dispatcher
|
773 |
+
.submit_collection_meta_op(
|
774 |
+
CollectionMetaOperations::Resharding(
|
775 |
+
collection_name.clone(),
|
776 |
+
ReshardingOperation::CommitWrite(ReshardKey {
|
777 |
+
direction: state.direction,
|
778 |
+
peer_id: state.peer_id,
|
779 |
+
shard_id: state.shard_id,
|
780 |
+
shard_key: state.shard_key.clone(),
|
781 |
+
}),
|
782 |
+
),
|
783 |
+
access,
|
784 |
+
wait_timeout,
|
785 |
+
)
|
786 |
+
.await
|
787 |
+
}
|
788 |
+
}
|
789 |
+
}
|
790 |
+
|
791 |
+
#[cfg(test)]
mod tests {
    use std::collections::HashSet;

    use super::*;

    #[test]
    fn test_generate_even_placement() {
        // Replication factor smaller than pool: each shard gets distinct peers.
        let peers = vec![1, 2, 3];
        let placement = generate_even_placement(peers, 3, 2);
        assert_eq!(placement.len(), 3);
        for peers_for_shard in placement {
            assert_eq!(peers_for_shard.len(), 2);
            assert_ne!(peers_for_shard[0], peers_for_shard[1]);
        }

        // Replication factor equal to pool size: every peer is used exactly once
        // per shard.
        let peers = vec![1, 2, 3];
        let placement = generate_even_placement(peers, 3, 3);
        assert_eq!(placement.len(), 3);
        for peers_for_shard in placement {
            assert_eq!(peers_for_shard.len(), 3);
            let distinct: HashSet<_> = peers_for_shard.into_iter().collect();
            assert_eq!(distinct.len(), 3);
        }

        // Pool exactly covers shards * replication factor: all peers used once.
        let peers = vec![1, 2, 3, 4, 5, 6];
        let placement = generate_even_placement(peers, 3, 2);
        assert_eq!(placement.len(), 3);
        let all_used: HashSet<_> = placement.into_iter().flatten().collect();
        assert_eq!(all_used.len(), 6);

        // Replication factor above pool size: capped at the pool size.
        let peers = vec![1, 2, 3, 4, 5];
        let placement = generate_even_placement(peers, 3, 10);
        assert_eq!(placement.len(), 3);
        for peers_for_shard in placement {
            assert_eq!(peers_for_shard.len(), 5);
        }
    }
}
|
src/common/debugger.rs
ADDED
@@ -0,0 +1,90 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
use std::sync::Arc;
|
2 |
+
|
3 |
+
use parking_lot::Mutex;
|
4 |
+
use schemars::JsonSchema;
|
5 |
+
use serde::{Deserialize, Serialize};
|
6 |
+
|
7 |
+
use crate::common::pyroscope_state::pyro::PyroscopeState;
|
8 |
+
use crate::settings::Settings;
|
9 |
+
|
10 |
+
/// Configuration of the Pyroscope continuous-profiling integration.
#[derive(Serialize, JsonSchema, Debug, Deserialize, Clone)]
pub struct PyroscopeConfig {
    // URL of the Pyroscope server profiles are pushed to.
    pub url: String,
    // Identifier reported along with the profiles.
    pub identifier: String,
    // Presumably credentials for the Pyroscope server — TODO confirm against
    // `PyroscopeState::from_config`.
    pub user: Option<String>,
    pub password: Option<String>,
    // Profiler sampling rate; agent-side default applies when `None` —
    // TODO confirm units/default in `PyroscopeState`.
    pub sampling_rate: Option<u32>,
}

/// Debugger-related configuration; currently only the Pyroscope section.
#[derive(Default, Debug, Serialize, JsonSchema, Deserialize, Clone)]
pub struct DebuggerConfig {
    // `None` means profiling is not configured.
    pub pyroscope: Option<PyroscopeConfig>,
}

/// Partial update to the debugger configuration.
#[derive(Debug, Serialize, JsonSchema, Deserialize, Clone)]
#[serde(rename_all = "snake_case")]
pub enum DebugConfigPatch {
    // Replace the Pyroscope configuration; `None` stops configuring a new agent.
    Pyroscope(Option<PyroscopeConfig>),
}

/// Shared, lockable debugger state held by the server.
pub struct DebuggerState {
    // Live Pyroscope agent state; unused on non-Linux targets (see `cfg_attr`).
    #[cfg_attr(not(target_os = "linux"), allow(dead_code))]
    pub pyroscope: Arc<Mutex<Option<PyroscopeState>>>,
}
|
34 |
+
|
35 |
+
impl DebuggerState {
    /// Build the initial debugger state from the server settings.
    pub fn from_settings(settings: &Settings) -> Self {
        let pyroscope_config = settings.debugger.pyroscope.clone();
        Self {
            pyroscope: Arc::new(Mutex::new(PyroscopeState::from_config(pyroscope_config))),
        }
    }

    /// Snapshot of the currently active debugger configuration.
    ///
    /// On non-Linux targets profiling is unavailable, so the Pyroscope section
    /// is always reported as `None` there.
    #[cfg_attr(not(target_os = "linux"), allow(clippy::unused_self))]
    pub fn get_config(&self) -> DebuggerConfig {
        let pyroscope_config = {
            #[cfg(target_os = "linux")]
            {
                let pyroscope_state_guard = self.pyroscope.lock();
                pyroscope_state_guard.as_ref().map(|s| s.config.clone())
            }
            #[cfg(not(target_os = "linux"))]
            {
                None
            }
        };

        DebuggerConfig {
            pyroscope: pyroscope_config,
        }
    }

    /// Apply a configuration patch, stopping the running agent first and
    /// starting a new one when a new config is supplied.
    ///
    /// Returns `true` when the patch was applied; `false` when the running
    /// agent could not be stopped, or always on non-Linux targets.
    #[cfg_attr(not(target_os = "linux"), allow(clippy::unused_self))]
    pub fn apply_config_patch(&self, patch: DebugConfigPatch) -> bool {
        #[cfg(target_os = "linux")]
        {
            match patch {
                DebugConfigPatch::Pyroscope(new_config) => {
                    // Stop any currently running agent before swapping config.
                    let mut pyroscope_guard = self.pyroscope.lock();
                    if let Some(pyroscope_state) = pyroscope_guard.as_mut() {
                        let stopped = pyroscope_state.stop_agent();
                        if !stopped {
                            return false;
                        }
                    }

                    // NOTE(review): when `new_config` is `None`, the old
                    // (now stopped) state is left in the guard rather than
                    // cleared, so `get_config` keeps reporting the old config —
                    // confirm this is intended.
                    if let Some(new_config) = new_config {
                        *pyroscope_guard = PyroscopeState::from_config(Some(new_config));
                    }
                    true
                }
            }
        }

        #[cfg(not(target_os = "linux"))]
        {
            let _ = patch; // Ignore new_config on non-linux OS
            false
        }
    }
}
|
src/common/error_reporting.rs
ADDED
@@ -0,0 +1,31 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
use std::time::Duration;
|
2 |
+
|
3 |
+
pub struct ErrorReporter;
|
4 |
+
|
5 |
+
impl ErrorReporter {
|
6 |
+
fn get_url() -> String {
|
7 |
+
if cfg!(debug_assertions) {
|
8 |
+
"https://staging-telemetry.qdrant.io".to_string()
|
9 |
+
} else {
|
10 |
+
"https://telemetry.qdrant.io".to_string()
|
11 |
+
}
|
12 |
+
}
|
13 |
+
|
14 |
+
pub fn report(error: &str, reporting_id: &str, backtrace: Option<&str>) {
|
15 |
+
let client = reqwest::blocking::Client::new();
|
16 |
+
|
17 |
+
let report = serde_json::json!({
|
18 |
+
"id": reporting_id,
|
19 |
+
"error": error,
|
20 |
+
"backtrace": backtrace.unwrap_or(""),
|
21 |
+
});
|
22 |
+
|
23 |
+
let data = serde_json::to_string(&report).unwrap();
|
24 |
+
let _resp = client
|
25 |
+
.post(Self::get_url())
|
26 |
+
.body(data)
|
27 |
+
.header("Content-Type", "application/json")
|
28 |
+
.timeout(Duration::from_secs(1))
|
29 |
+
.send();
|
30 |
+
}
|
31 |
+
}
|
src/common/health.rs
ADDED
@@ -0,0 +1,372 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
use std::collections::HashSet;
|
2 |
+
use std::future::{self, Future};
|
3 |
+
use std::sync::atomic::{self, AtomicBool};
|
4 |
+
use std::sync::Arc;
|
5 |
+
use std::time::Duration;
|
6 |
+
use std::{panic, thread};
|
7 |
+
|
8 |
+
use api::grpc::qdrant::qdrant_internal_client::QdrantInternalClient;
|
9 |
+
use api::grpc::qdrant::{GetConsensusCommitRequest, GetConsensusCommitResponse};
|
10 |
+
use api::grpc::transport_channel_pool::{self, TransportChannelPool};
|
11 |
+
use collection::shards::shard::ShardId;
|
12 |
+
use collection::shards::CollectionId;
|
13 |
+
use common::defaults;
|
14 |
+
use futures::stream::FuturesUnordered;
|
15 |
+
use futures::{FutureExt as _, StreamExt as _, TryStreamExt as _};
|
16 |
+
use itertools::Itertools;
|
17 |
+
use storage::content_manager::consensus_manager::ConsensusStateRef;
|
18 |
+
use storage::content_manager::toc::TableOfContent;
|
19 |
+
use storage::rbac::Access;
|
20 |
+
use tokio::{runtime, sync, time};
|
21 |
+
|
22 |
+
/// How long `check_ready` waits for the readiness signal before reporting not-ready.
const READY_CHECK_TIMEOUT: Duration = Duration::from_millis(500);
/// Retry count for the internal `GetConsensusCommit` gRPC request.
const GET_CONSENSUS_COMMITS_RETRIES: usize = 2;
|
24 |
+
|
25 |
+
/// Structure used to process health checks like `/readyz` endpoints.
pub struct HealthChecker {
    // The state of the health checker.
    // Once set to `true`, it should not change back to `false`.
    // Initially set to `false`.
    is_ready: Arc<AtomicBool>,
    // The signal that notifies that state has changed.
    // Comes from the health checker task.
    is_ready_signal: Arc<sync::Notify>,
    // Signal to the health checker task, that the API was called.
    // Used to drive the health checker task and avoid constant polling.
    check_ready_signal: Arc<sync::Notify>,
    // Cancels the background task when this checker is dropped.
    cancel: cancel::DropGuard,
}
|
39 |
+
|
40 |
+
impl HealthChecker {
    /// Spawn the background readiness task on `runtime` and return a handle
    /// that shares its state.
    pub fn spawn(
        toc: Arc<TableOfContent>,
        consensus_state: ConsensusStateRef,
        runtime: &runtime::Handle,
        wait_for_bootstrap: bool,
    ) -> Self {
        let task = Task {
            toc,
            consensus_state,
            is_ready: Default::default(),
            is_ready_signal: Default::default(),
            check_ready_signal: Default::default(),
            cancel: Default::default(),
            wait_for_bootstrap,
        };

        let health_checker = Self {
            is_ready: task.is_ready.clone(),
            is_ready_signal: task.is_ready_signal.clone(),
            check_ready_signal: task.check_ready_signal.clone(),
            // Dropping this guard cancels the background task.
            cancel: task.cancel.clone().drop_guard(),
        };

        let task = runtime.spawn(task.exec());
        drop(task); // drop `JoinFuture` explicitly to make clippy happy

        health_checker
    }

    /// Fast path if already ready; otherwise poke the background task and
    /// wait (bounded by `READY_CHECK_TIMEOUT`) for it to flip the state.
    pub async fn check_ready(&self) -> bool {
        if self.is_ready() {
            return true;
        }

        self.notify_task();
        self.wait_ready().await
    }

    /// Non-blocking readiness probe.
    pub fn is_ready(&self) -> bool {
        self.is_ready.load(atomic::Ordering::Relaxed)
    }

    /// Wake the background task so it re-evaluates readiness.
    pub fn notify_task(&self) {
        self.check_ready_signal.notify_one();
    }

    async fn wait_ready(&self) -> bool {
        // Register for the signal *before* re-checking the flag, so a state
        // change happening between the check and the await cannot be missed.
        let is_ready_signal = self.is_ready_signal.notified();

        if self.is_ready() {
            return true;
        }

        time::timeout(READY_CHECK_TIMEOUT, is_ready_signal)
            .await
            .is_ok()
    }
}
|
99 |
+
|
100 |
+
/// Background task that computes node readiness for `HealthChecker`.
pub struct Task {
    toc: Arc<TableOfContent>,
    consensus_state: ConsensusStateRef,
    // Shared state with the health checker
    // Once set to `true`, it should not change back to `false`.
    is_ready: Arc<AtomicBool>,
    // Used to notify the health checker service that the state has changed.
    is_ready_signal: Arc<sync::Notify>,
    // Driver signal for the health checker task
    // Once received, the task should proceed with an attempt to check the state.
    // Usually comes from the API call, but can be triggered by the task itself.
    check_ready_signal: Arc<sync::Notify>,
    // Cooperative cancellation; triggered by the checker's drop guard.
    cancel: cancel::CancellationToken,
    // Whether to wait for the node to join a cluster before checking readiness.
    wait_for_bootstrap: bool,
}
|
115 |
+
|
116 |
+
impl Task {
    /// Run the readiness check, restarting it if the inner future panics.
    pub async fn exec(mut self) {
        while let Err(err) = self.exec_catch_unwind().await {
            let message = common::panic::downcast_str(&err).unwrap_or("");
            let separator = if !message.is_empty() { ": " } else { "" };

            log::error!("HealthChecker task panicked, retrying{separator}{message}",);
        }
    }

    // Run the check and convert any panic into an `Err` instead of unwinding.
    async fn exec_catch_unwind(&mut self) -> thread::Result<()> {
        panic::AssertUnwindSafe(self.exec_cancel())
            .catch_unwind()
            .await
    }

    // Run the check until completion or until the cancellation token fires.
    async fn exec_cancel(&mut self) {
        let _ = cancel::future::cancel_on_token(self.cancel.clone(), self.exec_impl()).await;
    }

    /// Core readiness state machine; sets `is_ready` once the node has caught
    /// up with the cluster commit index and all local shards are healthy.
    async fn exec_impl(&mut self) {
        // Wait until node joins cluster for the first time
        //
        // If this is a new deployment and `--bootstrap` CLI parameter was specified...
        if self.wait_for_bootstrap {
            // Check if this is the only node in the cluster
            while self.consensus_state.peer_count() <= 1 {
                // If cluster is empty, make another attempt to check
                // after we receive another call to `/readyz`
                //
                // Wait for `/readyz` signal
                self.check_ready_signal.notified().await;
            }
        }

        // Artificially simulate a signal from the `/readyz` endpoint
        // as if it was already called by the user.
        // This allows to check the happy path without waiting for the first call.
        self.check_ready_signal.notify_one();

        // Get *cluster* commit index, or check if this is the only node in the cluster
        let Some(cluster_commit_index) = self.cluster_commit_index().await else {
            self.set_ready();
            return;
        };

        // Check if *local* commit index >= *cluster* commit index...
        while self.commit_index() < cluster_commit_index {
            // Wait for `/readyz` signal
            self.check_ready_signal.notified().await;

            // If not:
            //
            // - Check if this is the only node in the cluster
            if self.consensus_state.peer_count() <= 1 {
                self.set_ready();
                return;
            }

            // TODO: Do we want to update `cluster_commit_index` here?
            //
            // I.e.:
            // - If we *don't* update `cluster_commit_index`, then we will only wait till the node
            //   catch up with the cluster commit index *at the moment the node has been started*
            // - If we *do* update `cluster_commit_index`, then we will keep track of cluster
            //   commit index updates and wait till the node *completely* catch up with the leader,
            //   which might be hard (if not impossible) in some situations
        }

        // Collect "unhealthy" shards list
        let mut unhealthy_shards = self.unhealthy_shards().await;

        // Check if all shards are "healthy"...
        while !unhealthy_shards.is_empty() {
            // If not:
            //
            // - Wait for `/readyz` signal
            self.check_ready_signal.notified().await;

            // - Refresh "unhealthy" shards list
            let current_unhealthy_shards = self.unhealthy_shards().await;

            // - Check if any shards "healed" since last check
            unhealthy_shards.retain(|shard| current_unhealthy_shards.contains(shard));
        }

        self.set_ready();
    }

    /// Determine the cluster-wide commit index by querying other peers.
    ///
    /// Returns `None` when this is the only node in the cluster.
    async fn cluster_commit_index(&self) -> Option<u64> {
        // Wait for `/readyz` signal
        self.check_ready_signal.notified().await;

        // Check if there is only 1 node in the cluster
        if self.consensus_state.peer_count() <= 1 {
            return None;
        }

        // Get *cluster* commit index
        let peer_address_by_id = self.consensus_state.peer_address_by_id();
        let transport_channel_pool = &self.toc.get_channel_service().channel_pool;
        let this_peer_id = self.toc.this_peer_id;
        let this_peer_uri = peer_address_by_id.get(&this_peer_id);

        let mut requests = peer_address_by_id
            .values()
            // Do not get the current commit from ourselves
            .filter(|&uri| Some(uri) != this_peer_uri)
            // Historic peers might use the same URLs as our current peers, request each URI once
            .unique()
            .map(|uri| get_consensus_commit(transport_channel_pool, uri))
            .collect::<FuturesUnordered<_>>()
            .inspect_err(|err| log::error!("GetConsensusCommit request failed: {err}"))
            .filter_map(|res| future::ready(res.ok()));

        // Raft commits consensus operation, after majority of nodes persisted it.
        //
        // This means, if we check the majority of nodes (e.g., `total nodes / 2 + 1`), at least one
        // of these nodes will *always* have an up-to-date commit index. And so, the highest commit
        // index among majority of nodes *is* the cluster commit index.
        //
        // Our current node *is* one of the cluster nodes, so it's enough to query `total nodes / 2`
        // *additional* nodes, to get cluster commit index.
        //
        // The check goes like this:
        // - Either at least one of the "additional" nodes return a *higher* commit index, which
        //   means our node is *not* up-to-date, and we have to wait to reach this commit index
        // - Or *all* of them return *lower* commit index, which means current node is *already*
        //   up-to-date, and `/readyz` check will pass to the next step
        //
        // Example:
        //
        // Total nodes: 2
        // Required: 2 / 2 = 1
        //
        // Total nodes: 3
        // Required: 3 / 2 = 1
        //
        // Total nodes: 4
        // Required: 4 / 2 = 2
        //
        // Total nodes: 5
        // Required: 5 / 2 = 2
        let sufficient_commit_indices_count = peer_address_by_id.len() / 2;

        // *Wait* for `total nodes / 2` successful responses...
        let mut commit_indices: Vec<_> = (&mut requests)
            .take(sufficient_commit_indices_count)
            .collect()
            .await;

        // ...and also collect any additional responses, that we might have *already* received
        while let Ok(Some(resp)) = time::timeout(Duration::ZERO, requests.next()).await {
            commit_indices.push(resp);
        }

        // Find the maximum commit index among all responses.
        //
        // Note, that we progress even if most (or even *all*) requests failed (e.g., because all
        // other nodes are unavailable or they don't support `GetConsensusCommit` gRPC API).
        //
        // So this check is not 100% reliable and can give a false-positive result!
        let cluster_commit_index = commit_indices
            .into_iter()
            .map(|resp| resp.into_inner().commit)
            .max()
            .unwrap_or(0);

        Some(cluster_commit_index as _)
    }

    /// Last applied consensus entry on this node (0 when nothing applied yet).
    fn commit_index(&self) -> u64 {
        // TODO: Blocking call in async context!?
        self.consensus_state
            .persistent
            .read()
            .last_applied_entry()
            .unwrap_or(0)
    }

    /// List shards that are unhealthy, which may undergo automatic recovery.
    ///
    /// Shards in resharding state are not considered unhealthy and are excluded here.
    /// They require an external driver to make them active or to drop them.
    async fn unhealthy_shards(&self) -> HashSet<Shard> {
        let this_peer_id = self.toc.this_peer_id;
        let collections = self
            .toc
            .all_collections(&Access::full("For health check"))
            .await;

        let mut unhealthy_shards = HashSet::new();

        for collection_pass in &collections {
            // Collections may disappear concurrently; skip ones we cannot read.
            let state = match self.toc.get_collection(collection_pass).await {
                Ok(collection) => collection.state().await,
                Err(_) => continue,
            };

            for (&shard, info) in state.shards.iter() {
                // Only replicas hosted on this node are relevant.
                let Some(state) = info.replicas.get(&this_peer_id) else {
                    continue;
                };

                if state.is_active_or_listener_or_resharding() {
                    continue;
                }

                unhealthy_shards.insert(Shard::new(collection_pass.name(), shard));
            }
        }

        unhealthy_shards
    }

    // Flip the shared flag and wake everyone waiting on the readiness signal.
    fn set_ready(&self) {
        self.is_ready.store(true, atomic::Ordering::Relaxed);
        self.is_ready_signal.notify_waiters();
    }
}
|
336 |
+
|
337 |
+
/// Request the current consensus commit index from the peer at `uri`,
/// bounded by `CONSENSUS_META_OP_WAIT` and retried a fixed number of times.
fn get_consensus_commit<'a>(
    transport_channel_pool: &'a TransportChannelPool,
    uri: &'a tonic::transport::Uri,
) -> impl Future<Output = GetConsensusCommitResult> + 'a {
    transport_channel_pool.with_channel_timeout(
        uri,
        |channel| async {
            let mut client = QdrantInternalClient::new(channel);
            let mut request = tonic::Request::new(GetConsensusCommitRequest {});
            request.set_timeout(defaults::CONSENSUS_META_OP_WAIT);
            client.get_consensus_commit(request).await
        },
        Some(defaults::CONSENSUS_META_OP_WAIT),
        GET_CONSENSUS_COMMITS_RETRIES,
    )
}

/// Outcome of a single `GetConsensusCommit` request to a peer.
type GetConsensusCommitResult = Result<
    tonic::Response<GetConsensusCommitResponse>,
    transport_channel_pool::RequestError<tonic::Status>,
>;
|
358 |
+
|
359 |
+
#[derive(Clone, Debug, Eq, PartialEq, Hash)]
|
360 |
+
struct Shard {
|
361 |
+
collection: CollectionId,
|
362 |
+
shard: ShardId,
|
363 |
+
}
|
364 |
+
|
365 |
+
impl Shard {
|
366 |
+
pub fn new(collection: impl Into<CollectionId>, shard: ShardId) -> Self {
|
367 |
+
Self {
|
368 |
+
collection: collection.into(),
|
369 |
+
shard,
|
370 |
+
}
|
371 |
+
}
|
372 |
+
}
|
src/common/helpers.rs
ADDED
@@ -0,0 +1,151 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
use std::cmp::max;
|
2 |
+
use std::sync::atomic::{AtomicUsize, Ordering};
|
3 |
+
use std::{fs, io};
|
4 |
+
|
5 |
+
use schemars::JsonSchema;
|
6 |
+
use serde::{Deserialize, Serialize};
|
7 |
+
use tokio::runtime;
|
8 |
+
use tokio::runtime::Runtime;
|
9 |
+
use tonic::transport::{Certificate, ClientTlsConfig, Identity, ServerTlsConfig};
|
10 |
+
use validator::Validate;
|
11 |
+
|
12 |
+
use crate::settings::{Settings, TlsConfig};
|
13 |
+
|
14 |
+
/// API payload toggling the storage write lock.
#[derive(Debug, Deserialize, Serialize, JsonSchema, Validate)]
pub struct LocksOption {
    // Message reported to clients while the lock is held — TODO confirm
    // against the lock endpoint handler.
    pub error_message: Option<String>,
    // Presumably `true` blocks write operations — confirm against usage.
    pub write: bool,
}
|
19 |
+
|
20 |
+
pub fn create_search_runtime(max_search_threads: usize) -> io::Result<Runtime> {
|
21 |
+
let mut search_threads = max_search_threads;
|
22 |
+
|
23 |
+
if search_threads == 0 {
|
24 |
+
let num_cpu = common::cpu::get_num_cpus();
|
25 |
+
// At least one thread, but not more than number of CPUs - 1 if there are more than 2 CPU
|
26 |
+
// Example:
|
27 |
+
// Num CPU = 1 -> 1 thread
|
28 |
+
// Num CPU = 2 -> 2 thread - if we use one thread with 2 cpus, its too much un-utilized resources
|
29 |
+
// Num CPU = 3 -> 2 thread
|
30 |
+
// Num CPU = 4 -> 3 thread
|
31 |
+
// Num CPU = 5 -> 4 thread
|
32 |
+
search_threads = match num_cpu {
|
33 |
+
0 => 1,
|
34 |
+
1 => 1,
|
35 |
+
2 => 2,
|
36 |
+
_ => num_cpu - 1,
|
37 |
+
};
|
38 |
+
}
|
39 |
+
|
40 |
+
runtime::Builder::new_multi_thread()
|
41 |
+
.worker_threads(search_threads)
|
42 |
+
.max_blocking_threads(search_threads)
|
43 |
+
.enable_all()
|
44 |
+
.thread_name_fn(|| {
|
45 |
+
static ATOMIC_ID: AtomicUsize = AtomicUsize::new(0);
|
46 |
+
let id = ATOMIC_ID.fetch_add(1, Ordering::SeqCst);
|
47 |
+
format!("search-{id}")
|
48 |
+
})
|
49 |
+
.build()
|
50 |
+
}
|
51 |
+
|
52 |
+
pub fn create_update_runtime(max_optimization_threads: usize) -> io::Result<Runtime> {
|
53 |
+
let mut update_runtime_builder = runtime::Builder::new_multi_thread();
|
54 |
+
|
55 |
+
update_runtime_builder
|
56 |
+
.enable_time()
|
57 |
+
.thread_name_fn(move || {
|
58 |
+
static ATOMIC_ID: AtomicUsize = AtomicUsize::new(0);
|
59 |
+
let update_id = ATOMIC_ID.fetch_add(1, Ordering::SeqCst);
|
60 |
+
format!("update-{update_id}")
|
61 |
+
});
|
62 |
+
|
63 |
+
if max_optimization_threads > 0 {
|
64 |
+
// panics if val is not larger than 0.
|
65 |
+
update_runtime_builder.max_blocking_threads(max_optimization_threads);
|
66 |
+
}
|
67 |
+
update_runtime_builder.build()
|
68 |
+
}
|
69 |
+
|
70 |
+
pub fn create_general_purpose_runtime() -> io::Result<Runtime> {
|
71 |
+
runtime::Builder::new_multi_thread()
|
72 |
+
.enable_time()
|
73 |
+
.enable_io()
|
74 |
+
.worker_threads(max(common::cpu::get_num_cpus(), 2))
|
75 |
+
.thread_name_fn(|| {
|
76 |
+
static ATOMIC_ID: AtomicUsize = AtomicUsize::new(0);
|
77 |
+
let general_id = ATOMIC_ID.fetch_add(1, Ordering::SeqCst);
|
78 |
+
format!("general-{general_id}")
|
79 |
+
})
|
80 |
+
.build()
|
81 |
+
}
|
82 |
+
|
83 |
+
/// Load client TLS configuration.
///
/// Returns `None` when p2p TLS is disabled; otherwise builds a client config
/// carrying our identity and the configured CA certificate.
pub fn load_tls_client_config(settings: &Settings) -> io::Result<Option<ClientTlsConfig>> {
    if settings.cluster.p2p.enable_tls {
        let tls_config = &settings.tls()?;
        Ok(Some(
            ClientTlsConfig::new()
                .identity(load_identity(tls_config)?)
                .ca_certificate(load_ca_certificate(tls_config)?),
        ))
    } else {
        Ok(None)
    }
}
|
96 |
+
|
97 |
+
/// Load server TLS configuration for external gRPC
/// (server identity only, no client-certificate verification).
pub fn load_tls_external_server_config(tls_config: &TlsConfig) -> io::Result<ServerTlsConfig> {
    Ok(ServerTlsConfig::new().identity(load_identity(tls_config)?))
}
|
101 |
+
|
102 |
+
/// Load server TLS configuration for internal gRPC, check client certificate against CA
/// (mutual TLS between cluster peers).
pub fn load_tls_internal_server_config(tls_config: &TlsConfig) -> io::Result<ServerTlsConfig> {
    Ok(ServerTlsConfig::new()
        .identity(load_identity(tls_config)?)
        .client_ca_root(load_ca_certificate(tls_config)?))
}
|
108 |
+
|
109 |
+
// Read the PEM certificate/key pair from disk and build a TLS identity.
fn load_identity(tls_config: &TlsConfig) -> io::Result<Identity> {
    let cert = fs::read_to_string(&tls_config.cert)?;
    let key = fs::read_to_string(&tls_config.key)?;
    Ok(Identity::from_pem(cert, key))
}
|
114 |
+
|
115 |
+
// Read the configured CA certificate (PEM) from disk.
fn load_ca_certificate(tls_config: &TlsConfig) -> io::Result<Certificate> {
    let pem = fs::read_to_string(&tls_config.ca_cert)?;
    Ok(Certificate::from_pem(pem))
}
|
119 |
+
|
120 |
+
/// Wrap a tonic transport error into a generic `io::Error`.
pub fn tonic_error_to_io_error(err: tonic::transport::Error) -> io::Error {
    io::Error::new(io::ErrorKind::Other, err)
}
|
123 |
+
|
124 |
+
#[cfg(test)]
mod tests {
    use std::sync::Arc;
    use std::thread;
    use std::thread::sleep;
    use std::time::Duration;

    use collection::common::is_ready::IsReady;

    /// Smoke test: a thread blocked on `await_ready` is released once
    /// `make_ready` is called from another thread.
    #[test]
    fn test_is_ready() {
        let is_ready = Arc::new(IsReady::default());
        let is_ready_clone = is_ready.clone();
        let join = thread::spawn(move || {
            // Blocks until `make_ready` is called below.
            is_ready_clone.await_ready();
            eprintln!(
                "is_ready_clone.check_ready() = {:#?}",
                is_ready_clone.check_ready()
            );
        });

        // Give the spawned thread time to start waiting before flipping state.
        sleep(Duration::from_millis(500));
        eprintln!("Making ready");
        is_ready.make_ready();
        sleep(Duration::from_millis(500));
        join.join().unwrap()
    }
}
|
src/common/http_client.rs
ADDED
@@ -0,0 +1,156 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
use std::path::Path;
|
2 |
+
use std::{fs, io, result};
|
3 |
+
|
4 |
+
use reqwest::header::{HeaderMap, HeaderValue, InvalidHeaderValue};
|
5 |
+
use storage::content_manager::errors::StorageError;
|
6 |
+
|
7 |
+
use super::auth::HTTP_HEADER_API_KEY;
|
8 |
+
use crate::settings::{Settings, TlsConfig};
|
9 |
+
|
10 |
+
/// Factory for outgoing HTTP(S) clients, carrying the service TLS settings.
#[derive(Clone)]
pub struct HttpClient {
    // TLS settings; `None` when the service runs without TLS.
    tls_config: Option<TlsConfig>,
    // Whether outgoing clients present a client certificate (mutual TLS).
    verify_https_client_certificate: bool,
}
|
15 |
+
|
16 |
+
impl HttpClient {
    /// Capture the TLS-related settings needed to build clients later.
    ///
    /// Fails with `Error::TlsConfigUndefined` when TLS is enabled but no TLS
    /// section is present in the settings.
    pub fn from_settings(settings: &Settings) -> Result<Self> {
        let tls_config = if settings.service.enable_tls {
            let Some(tls_config) = settings.tls.clone() else {
                return Err(Error::TlsConfigUndefined);
            };

            Some(tls_config)
        } else {
            None
        };

        let verify_https_client_certificate = settings.service.verify_https_client_certificate;

        let http_client = Self {
            tls_config,
            verify_https_client_certificate,
        };

        Ok(http_client)
    }

    /// Create a new HTTP(S) client
    ///
    /// An API key can be optionally provided to be used in this HTTP client. It'll send the API
    /// key as `Api-key` header in every request.
    ///
    /// # Warning
    ///
    /// Setting an API key may leak when the client is used to send a request to a malicious
    /// server. This is potentially dangerous if a user has control over what URL is accessed.
    ///
    /// For this reason the API key is not set by default as provided in the configuration. It must
    /// be explicitly provided when creating the HTTP client.
    pub fn client(&self, api_key: Option<&str>) -> Result<reqwest::Client> {
        https_client(
            api_key,
            self.tls_config.as_ref(),
            self.verify_https_client_certificate,
        )
    }
}
|
58 |
+
|
59 |
+
fn https_client(
|
60 |
+
api_key: Option<&str>,
|
61 |
+
tls_config: Option<&TlsConfig>,
|
62 |
+
verify_https_client_certificate: bool,
|
63 |
+
) -> Result<reqwest::Client> {
|
64 |
+
let mut builder = reqwest::Client::builder();
|
65 |
+
|
66 |
+
// Configure TLS root certificate and validation
|
67 |
+
if let Some(tls_config) = tls_config {
|
68 |
+
builder = builder.add_root_certificate(https_client_ca_cert(tls_config.ca_cert.as_ref())?);
|
69 |
+
|
70 |
+
if verify_https_client_certificate {
|
71 |
+
builder = builder.identity(https_client_identity(
|
72 |
+
tls_config.cert.as_ref(),
|
73 |
+
tls_config.key.as_ref(),
|
74 |
+
)?);
|
75 |
+
}
|
76 |
+
}
|
77 |
+
|
78 |
+
// Attach API key as sensitive header
|
79 |
+
if let Some(api_key) = api_key {
|
80 |
+
let mut headers = HeaderMap::new();
|
81 |
+
let mut api_key_value = HeaderValue::from_str(api_key).map_err(Error::MalformedApiKey)?;
|
82 |
+
api_key_value.set_sensitive(true);
|
83 |
+
headers.insert(HTTP_HEADER_API_KEY, api_key_value);
|
84 |
+
builder = builder.default_headers(headers);
|
85 |
+
}
|
86 |
+
|
87 |
+
let client = builder.build()?;
|
88 |
+
|
89 |
+
Ok(client)
|
90 |
+
}
|
91 |
+
|
92 |
+
fn https_client_ca_cert(ca_cert: &Path) -> Result<reqwest::tls::Certificate> {
|
93 |
+
let ca_cert_pem =
|
94 |
+
fs::read(ca_cert).map_err(|err| Error::failed_to_read(err, "CA certificate", ca_cert))?;
|
95 |
+
|
96 |
+
let ca_cert = reqwest::Certificate::from_pem(&ca_cert_pem)?;
|
97 |
+
|
98 |
+
Ok(ca_cert)
|
99 |
+
}
|
100 |
+
|
101 |
+
fn https_client_identity(cert: &Path, key: &Path) -> Result<reqwest::tls::Identity> {
|
102 |
+
let mut identity_pem =
|
103 |
+
fs::read(cert).map_err(|err| Error::failed_to_read(err, "certificate", cert))?;
|
104 |
+
|
105 |
+
let mut key_file = fs::File::open(key).map_err(|err| Error::failed_to_read(err, "key", key))?;
|
106 |
+
|
107 |
+
// Concatenate certificate and key into a single PEM bytes
|
108 |
+
io::copy(&mut key_file, &mut identity_pem)
|
109 |
+
.map_err(|err| Error::failed_to_read(err, "key", key))?;
|
110 |
+
|
111 |
+
let identity = reqwest::Identity::from_pem(&identity_pem)?;
|
112 |
+
|
113 |
+
Ok(identity)
|
114 |
+
}
|
115 |
+
|
116 |
+
pub type Result<T, E = Error> = result::Result<T, E>;
|
117 |
+
|
118 |
+
#[derive(Debug, thiserror::Error)]
|
119 |
+
pub enum Error {
|
120 |
+
#[error("TLS config is not defined in the Qdrant config file")]
|
121 |
+
TlsConfigUndefined,
|
122 |
+
|
123 |
+
#[error("{1}: {0}")]
|
124 |
+
Io(#[source] io::Error, String),
|
125 |
+
|
126 |
+
#[error("failed to setup HTTPS client: {0}")]
|
127 |
+
Reqwest(#[from] reqwest::Error),
|
128 |
+
|
129 |
+
#[error("malformed API key")]
|
130 |
+
MalformedApiKey(#[source] InvalidHeaderValue),
|
131 |
+
}
|
132 |
+
|
133 |
+
impl Error {
|
134 |
+
pub fn io(source: io::Error, context: impl Into<String>) -> Self {
|
135 |
+
Self::Io(source, context.into())
|
136 |
+
}
|
137 |
+
|
138 |
+
pub fn failed_to_read(source: io::Error, file: &str, path: &Path) -> Self {
|
139 |
+
Self::io(
|
140 |
+
source,
|
141 |
+
format!("failed to read HTTPS client {file} file {}", path.display()),
|
142 |
+
)
|
143 |
+
}
|
144 |
+
}
|
145 |
+
|
146 |
+
// Surface client-construction failures as generic storage service errors.
impl From<Error> for StorageError {
    fn from(err: Error) -> Self {
        StorageError::service_error(format!("failed to initialize HTTP(S) client: {err}"))
    }
}

// Allow propagating this error through `io::Result`-returning code paths.
impl From<Error> for io::Error {
    fn from(err: Error) -> Self {
        io::Error::new(io::ErrorKind::Other, err)
    }
}
|
src/common/inference/batch_processing.rs
ADDED
@@ -0,0 +1,370 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
use std::collections::HashSet;
|
2 |
+
|
3 |
+
use api::rest::{
|
4 |
+
ContextInput, ContextPair, DiscoverInput, Prefetch, Query, QueryGroupsRequestInternal,
|
5 |
+
QueryInterface, QueryRequestInternal, RecommendInput, VectorInput,
|
6 |
+
};
|
7 |
+
|
8 |
+
use super::service::{InferenceData, InferenceInput, InferenceRequest};
|
9 |
+
|
10 |
+
pub struct BatchAccum {
|
11 |
+
pub(crate) objects: HashSet<InferenceData>,
|
12 |
+
}
|
13 |
+
|
14 |
+
impl BatchAccum {
|
15 |
+
pub fn new() -> Self {
|
16 |
+
Self {
|
17 |
+
objects: HashSet::new(),
|
18 |
+
}
|
19 |
+
}
|
20 |
+
|
21 |
+
pub fn add(&mut self, data: InferenceData) {
|
22 |
+
self.objects.insert(data);
|
23 |
+
}
|
24 |
+
|
25 |
+
pub fn extend(&mut self, other: BatchAccum) {
|
26 |
+
self.objects.extend(other.objects);
|
27 |
+
}
|
28 |
+
|
29 |
+
pub fn is_empty(&self) -> bool {
|
30 |
+
self.objects.is_empty()
|
31 |
+
}
|
32 |
+
}
|
33 |
+
|
34 |
+
impl From<&BatchAccum> for InferenceRequest {
|
35 |
+
fn from(batch: &BatchAccum) -> Self {
|
36 |
+
Self {
|
37 |
+
inputs: batch
|
38 |
+
.objects
|
39 |
+
.iter()
|
40 |
+
.cloned()
|
41 |
+
.map(InferenceInput::from)
|
42 |
+
.collect(),
|
43 |
+
inference: None,
|
44 |
+
token: None,
|
45 |
+
}
|
46 |
+
}
|
47 |
+
}
|
48 |
+
|
49 |
+
fn collect_vector_input(vector: &VectorInput, batch: &mut BatchAccum) {
|
50 |
+
match vector {
|
51 |
+
VectorInput::Document(doc) => batch.add(InferenceData::Document(doc.clone())),
|
52 |
+
VectorInput::Image(img) => batch.add(InferenceData::Image(img.clone())),
|
53 |
+
VectorInput::Object(obj) => batch.add(InferenceData::Object(obj.clone())),
|
54 |
+
// types that are not supported in the Inference Service
|
55 |
+
VectorInput::DenseVector(_) => {}
|
56 |
+
VectorInput::SparseVector(_) => {}
|
57 |
+
VectorInput::MultiDenseVector(_) => {}
|
58 |
+
VectorInput::Id(_) => {}
|
59 |
+
}
|
60 |
+
}
|
61 |
+
|
62 |
+
fn collect_context_pair(pair: &ContextPair, batch: &mut BatchAccum) {
|
63 |
+
collect_vector_input(&pair.positive, batch);
|
64 |
+
collect_vector_input(&pair.negative, batch);
|
65 |
+
}
|
66 |
+
|
67 |
+
fn collect_discover_input(discover: &DiscoverInput, batch: &mut BatchAccum) {
|
68 |
+
collect_vector_input(&discover.target, batch);
|
69 |
+
if let Some(context) = &discover.context {
|
70 |
+
for pair in context {
|
71 |
+
collect_context_pair(pair, batch);
|
72 |
+
}
|
73 |
+
}
|
74 |
+
}
|
75 |
+
|
76 |
+
fn collect_recommend_input(recommend: &RecommendInput, batch: &mut BatchAccum) {
|
77 |
+
if let Some(positive) = &recommend.positive {
|
78 |
+
for vector in positive {
|
79 |
+
collect_vector_input(vector, batch);
|
80 |
+
}
|
81 |
+
}
|
82 |
+
if let Some(negative) = &recommend.negative {
|
83 |
+
for vector in negative {
|
84 |
+
collect_vector_input(vector, batch);
|
85 |
+
}
|
86 |
+
}
|
87 |
+
}
|
88 |
+
|
89 |
+
fn collect_query(query: &Query, batch: &mut BatchAccum) {
|
90 |
+
match query {
|
91 |
+
Query::Nearest(nearest) => collect_vector_input(&nearest.nearest, batch),
|
92 |
+
Query::Recommend(recommend) => collect_recommend_input(&recommend.recommend, batch),
|
93 |
+
Query::Discover(discover) => collect_discover_input(&discover.discover, batch),
|
94 |
+
Query::Context(context) => {
|
95 |
+
if let ContextInput(Some(pairs)) = &context.context {
|
96 |
+
for pair in pairs {
|
97 |
+
collect_context_pair(pair, batch);
|
98 |
+
}
|
99 |
+
}
|
100 |
+
}
|
101 |
+
Query::OrderBy(_) | Query::Fusion(_) | Query::Sample(_) => {}
|
102 |
+
}
|
103 |
+
}
|
104 |
+
|
105 |
+
fn collect_query_interface(query: &QueryInterface, batch: &mut BatchAccum) {
|
106 |
+
match query {
|
107 |
+
QueryInterface::Nearest(vector) => collect_vector_input(vector, batch),
|
108 |
+
QueryInterface::Query(query) => collect_query(query, batch),
|
109 |
+
}
|
110 |
+
}
|
111 |
+
|
112 |
+
fn collect_prefetch(prefetch: &Prefetch, batch: &mut BatchAccum) {
|
113 |
+
let Prefetch {
|
114 |
+
prefetch,
|
115 |
+
query,
|
116 |
+
using: _,
|
117 |
+
filter: _,
|
118 |
+
params: _,
|
119 |
+
score_threshold: _,
|
120 |
+
limit: _,
|
121 |
+
lookup_from: _,
|
122 |
+
} = prefetch;
|
123 |
+
|
124 |
+
if let Some(query) = query {
|
125 |
+
collect_query_interface(query, batch);
|
126 |
+
}
|
127 |
+
|
128 |
+
if let Some(prefetches) = prefetch {
|
129 |
+
for p in prefetches {
|
130 |
+
collect_prefetch(p, batch);
|
131 |
+
}
|
132 |
+
}
|
133 |
+
}
|
134 |
+
|
135 |
+
pub fn collect_query_groups_request(request: &QueryGroupsRequestInternal) -> BatchAccum {
|
136 |
+
let mut batch = BatchAccum::new();
|
137 |
+
|
138 |
+
let QueryGroupsRequestInternal {
|
139 |
+
query,
|
140 |
+
prefetch,
|
141 |
+
using: _,
|
142 |
+
filter: _,
|
143 |
+
params: _,
|
144 |
+
score_threshold: _,
|
145 |
+
with_vector: _,
|
146 |
+
with_payload: _,
|
147 |
+
lookup_from: _,
|
148 |
+
group_request: _,
|
149 |
+
} = request;
|
150 |
+
|
151 |
+
if let Some(query) = query {
|
152 |
+
collect_query_interface(query, &mut batch);
|
153 |
+
}
|
154 |
+
|
155 |
+
if let Some(prefetches) = prefetch {
|
156 |
+
for prefetch in prefetches {
|
157 |
+
collect_prefetch(prefetch, &mut batch);
|
158 |
+
}
|
159 |
+
}
|
160 |
+
|
161 |
+
batch
|
162 |
+
}
|
163 |
+
|
164 |
+
pub fn collect_query_request(request: &QueryRequestInternal) -> BatchAccum {
|
165 |
+
let mut batch = BatchAccum::new();
|
166 |
+
|
167 |
+
let QueryRequestInternal {
|
168 |
+
prefetch,
|
169 |
+
query,
|
170 |
+
using: _,
|
171 |
+
filter: _,
|
172 |
+
score_threshold: _,
|
173 |
+
params: _,
|
174 |
+
limit: _,
|
175 |
+
offset: _,
|
176 |
+
with_vector: _,
|
177 |
+
with_payload: _,
|
178 |
+
lookup_from: _,
|
179 |
+
} = request;
|
180 |
+
|
181 |
+
if let Some(query) = query {
|
182 |
+
collect_query_interface(query, &mut batch);
|
183 |
+
}
|
184 |
+
|
185 |
+
if let Some(prefetches) = prefetch {
|
186 |
+
for prefetch in prefetches {
|
187 |
+
collect_prefetch(prefetch, &mut batch);
|
188 |
+
}
|
189 |
+
}
|
190 |
+
|
191 |
+
batch
|
192 |
+
}
|
193 |
+
|
194 |
+
#[cfg(test)]
mod tests {
    use api::rest::schema::{DiscoverQuery, Document, Image, InferenceObject, NearestQuery};
    use api::rest::QueryBaseGroupRequest;
    use serde_json::json;

    use super::*;

    // Helpers building minimal inference payloads with a fixed model name.
    fn create_test_document(text: &str) -> Document {
        Document {
            text: text.to_string(),
            model: "test-model".to_string(),
            options: Default::default(),
        }
    }

    fn create_test_image(url: &str) -> Image {
        Image {
            image: json!({"data": url.to_string()}),
            model: "test-model".to_string(),
            options: Default::default(),
        }
    }

    fn create_test_object(data: &str) -> InferenceObject {
        InferenceObject {
            object: json!({"data": data}),
            model: "test-model".to_string(),
            options: Default::default(),
        }
    }

    #[test]
    fn test_batch_accum_basic() {
        let mut batch = BatchAccum::new();
        assert!(batch.objects.is_empty());

        let doc = InferenceData::Document(create_test_document("test"));
        batch.add(doc.clone());
        assert_eq!(batch.objects.len(), 1);

        // Re-adding an identical payload must not grow the set.
        batch.add(doc);
        assert_eq!(batch.objects.len(), 1);
    }

    #[test]
    fn test_batch_accum_extend() {
        let mut batch1 = BatchAccum::new();
        let mut batch2 = BatchAccum::new();

        let doc1 = InferenceData::Document(create_test_document("test1"));
        let doc2 = InferenceData::Document(create_test_document("test2"));

        batch1.add(doc1);
        batch2.add(doc2);

        batch1.extend(batch2);
        assert_eq!(batch1.objects.len(), 2);
    }

    #[test]
    fn test_deduplication() {
        let mut batch = BatchAccum::new();

        // Two separately constructed but equal payloads collapse to one entry.
        let doc1 = InferenceData::Document(create_test_document("same"));
        let doc2 = InferenceData::Document(create_test_document("same"));

        batch.add(doc1);
        batch.add(doc2);

        assert_eq!(batch.objects.len(), 1);
    }

    #[test]
    fn test_collect_vector_input() {
        let mut batch = BatchAccum::new();

        let doc_input = VectorInput::Document(create_test_document("test"));
        let img_input = VectorInput::Image(create_test_image("test.jpg"));
        let obj_input = VectorInput::Object(create_test_object("test"));

        collect_vector_input(&doc_input, &mut batch);
        collect_vector_input(&img_input, &mut batch);
        collect_vector_input(&obj_input, &mut batch);

        assert_eq!(batch.objects.len(), 3);
    }

    #[test]
    fn test_collect_prefetch() {
        // One payload at the top level, one in a nested prefetch -> 2 total.
        let prefetch = Prefetch {
            query: Some(QueryInterface::Nearest(VectorInput::Document(
                create_test_document("test"),
            ))),
            prefetch: Some(vec![Prefetch {
                query: Some(QueryInterface::Nearest(VectorInput::Image(
                    create_test_image("nested.jpg"),
                ))),
                prefetch: None,
                using: None,
                filter: None,
                params: None,
                score_threshold: None,
                limit: None,
                lookup_from: None,
            }]),
            using: None,
            filter: None,
            params: None,
            score_threshold: None,
            limit: None,
            lookup_from: None,
        };

        let mut batch = BatchAccum::new();
        collect_prefetch(&prefetch, &mut batch);
        assert_eq!(batch.objects.len(), 2);
    }

    #[test]
    fn test_collect_query_groups_request() {
        // One top-level payload plus three inside the prefetched discover
        // query (target + context pair) -> 4 total.
        let request = QueryGroupsRequestInternal {
            query: Some(QueryInterface::Query(Query::Nearest(NearestQuery {
                nearest: VectorInput::Document(create_test_document("test")),
            }))),
            prefetch: Some(vec![Prefetch {
                query: Some(QueryInterface::Query(Query::Discover(DiscoverQuery {
                    discover: DiscoverInput {
                        target: VectorInput::Image(create_test_image("test.jpg")),
                        context: Some(vec![ContextPair {
                            positive: VectorInput::Document(create_test_document("pos")),
                            negative: VectorInput::Image(create_test_image("neg.jpg")),
                        }]),
                    },
                }))),
                prefetch: None,
                using: None,
                filter: None,
                params: None,
                score_threshold: None,
                limit: None,
                lookup_from: None,
            }]),
            using: None,
            filter: None,
            params: None,
            score_threshold: None,
            with_vector: None,
            with_payload: None,
            lookup_from: None,
            group_request: QueryBaseGroupRequest {
                group_by: "test".parse().unwrap(),
                group_size: None,
                limit: None,
                with_lookup: None,
            },
        };

        let batch = collect_query_groups_request(&request);
        assert_eq!(batch.objects.len(), 4);
    }

    #[test]
    fn test_different_model_same_content() {
        let mut batch = BatchAccum::new();

        // Same text but different model names are distinct payloads.
        let mut doc1 = create_test_document("same");
        let mut doc2 = create_test_document("same");
        doc1.model = "model1".to_string();
        doc2.model = "model2".to_string();

        batch.add(InferenceData::Document(doc1));
        batch.add(InferenceData::Document(doc2));

        assert_eq!(batch.objects.len(), 2);
    }
}
|
src/common/inference/batch_processing_grpc.rs
ADDED
@@ -0,0 +1,281 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
use std::collections::HashSet;
|
2 |
+
|
3 |
+
use api::grpc::qdrant::vector_input::Variant;
|
4 |
+
use api::grpc::qdrant::{
|
5 |
+
query, ContextInput, ContextInputPair, DiscoverInput, PrefetchQuery, Query, RecommendInput,
|
6 |
+
VectorInput,
|
7 |
+
};
|
8 |
+
use api::rest::schema as rest;
|
9 |
+
use tonic::Status;
|
10 |
+
|
11 |
+
use super::service::{InferenceData, InferenceInput, InferenceRequest};
|
12 |
+
|
13 |
+
pub struct BatchAccumGrpc {
|
14 |
+
pub(crate) objects: HashSet<InferenceData>,
|
15 |
+
}
|
16 |
+
|
17 |
+
impl BatchAccumGrpc {
|
18 |
+
pub fn new() -> Self {
|
19 |
+
Self {
|
20 |
+
objects: HashSet::new(),
|
21 |
+
}
|
22 |
+
}
|
23 |
+
|
24 |
+
pub fn add(&mut self, data: InferenceData) {
|
25 |
+
self.objects.insert(data);
|
26 |
+
}
|
27 |
+
|
28 |
+
pub fn extend(&mut self, other: BatchAccumGrpc) {
|
29 |
+
self.objects.extend(other.objects);
|
30 |
+
}
|
31 |
+
|
32 |
+
pub fn is_empty(&self) -> bool {
|
33 |
+
self.objects.is_empty()
|
34 |
+
}
|
35 |
+
}
|
36 |
+
|
37 |
+
impl From<&BatchAccumGrpc> for InferenceRequest {
|
38 |
+
fn from(batch: &BatchAccumGrpc) -> Self {
|
39 |
+
Self {
|
40 |
+
inputs: batch
|
41 |
+
.objects
|
42 |
+
.iter()
|
43 |
+
.cloned()
|
44 |
+
.map(InferenceInput::from)
|
45 |
+
.collect(),
|
46 |
+
inference: None,
|
47 |
+
token: None,
|
48 |
+
}
|
49 |
+
}
|
50 |
+
}
|
51 |
+
|
52 |
+
fn collect_vector_input(vector: &VectorInput, batch: &mut BatchAccumGrpc) -> Result<(), Status> {
|
53 |
+
let Some(variant) = &vector.variant else {
|
54 |
+
return Ok(());
|
55 |
+
};
|
56 |
+
|
57 |
+
match variant {
|
58 |
+
Variant::Id(_) => {}
|
59 |
+
Variant::Dense(_) => {}
|
60 |
+
Variant::Sparse(_) => {}
|
61 |
+
Variant::MultiDense(_) => {}
|
62 |
+
Variant::Document(document) => {
|
63 |
+
let doc = rest::Document::try_from(document.clone())
|
64 |
+
.map_err(|e| Status::internal(format!("Document conversion error: {e:?}")))?;
|
65 |
+
batch.add(InferenceData::Document(doc));
|
66 |
+
}
|
67 |
+
Variant::Image(image) => {
|
68 |
+
let img = rest::Image::try_from(image.clone())
|
69 |
+
.map_err(|e| Status::internal(format!("Image conversion error: {e:?}")))?;
|
70 |
+
batch.add(InferenceData::Image(img));
|
71 |
+
}
|
72 |
+
Variant::Object(object) => {
|
73 |
+
let obj = rest::InferenceObject::try_from(object.clone())
|
74 |
+
.map_err(|e| Status::internal(format!("Object conversion error: {e:?}")))?;
|
75 |
+
batch.add(InferenceData::Object(obj));
|
76 |
+
}
|
77 |
+
}
|
78 |
+
Ok(())
|
79 |
+
}
|
80 |
+
|
81 |
+
pub(crate) fn collect_context_input(
|
82 |
+
context: &ContextInput,
|
83 |
+
batch: &mut BatchAccumGrpc,
|
84 |
+
) -> Result<(), Status> {
|
85 |
+
let ContextInput { pairs } = context;
|
86 |
+
|
87 |
+
for pair in pairs {
|
88 |
+
collect_context_input_pair(pair, batch)?;
|
89 |
+
}
|
90 |
+
|
91 |
+
Ok(())
|
92 |
+
}
|
93 |
+
|
94 |
+
fn collect_context_input_pair(
|
95 |
+
pair: &ContextInputPair,
|
96 |
+
batch: &mut BatchAccumGrpc,
|
97 |
+
) -> Result<(), Status> {
|
98 |
+
let ContextInputPair { positive, negative } = pair;
|
99 |
+
|
100 |
+
if let Some(positive) = positive {
|
101 |
+
collect_vector_input(positive, batch)?;
|
102 |
+
}
|
103 |
+
|
104 |
+
if let Some(negative) = negative {
|
105 |
+
collect_vector_input(negative, batch)?;
|
106 |
+
}
|
107 |
+
|
108 |
+
Ok(())
|
109 |
+
}
|
110 |
+
|
111 |
+
pub(crate) fn collect_discover_input(
|
112 |
+
discover: &DiscoverInput,
|
113 |
+
batch: &mut BatchAccumGrpc,
|
114 |
+
) -> Result<(), Status> {
|
115 |
+
let DiscoverInput { target, context } = discover;
|
116 |
+
|
117 |
+
if let Some(vector) = target {
|
118 |
+
collect_vector_input(vector, batch)?;
|
119 |
+
}
|
120 |
+
|
121 |
+
if let Some(context) = context {
|
122 |
+
for pair in &context.pairs {
|
123 |
+
collect_context_input_pair(pair, batch)?;
|
124 |
+
}
|
125 |
+
}
|
126 |
+
|
127 |
+
Ok(())
|
128 |
+
}
|
129 |
+
|
130 |
+
pub(crate) fn collect_recommend_input(
|
131 |
+
recommend: &RecommendInput,
|
132 |
+
batch: &mut BatchAccumGrpc,
|
133 |
+
) -> Result<(), Status> {
|
134 |
+
let RecommendInput {
|
135 |
+
positive,
|
136 |
+
negative,
|
137 |
+
strategy: _,
|
138 |
+
} = recommend;
|
139 |
+
|
140 |
+
for vector in positive {
|
141 |
+
collect_vector_input(vector, batch)?;
|
142 |
+
}
|
143 |
+
|
144 |
+
for vector in negative {
|
145 |
+
collect_vector_input(vector, batch)?;
|
146 |
+
}
|
147 |
+
|
148 |
+
Ok(())
|
149 |
+
}
|
150 |
+
|
151 |
+
pub(crate) fn collect_query(query: &Query, batch: &mut BatchAccumGrpc) -> Result<(), Status> {
|
152 |
+
let Some(variant) = &query.variant else {
|
153 |
+
return Ok(());
|
154 |
+
};
|
155 |
+
|
156 |
+
match variant {
|
157 |
+
query::Variant::Nearest(nearest) => collect_vector_input(nearest, batch)?,
|
158 |
+
query::Variant::Recommend(recommend) => collect_recommend_input(recommend, batch)?,
|
159 |
+
query::Variant::Discover(discover) => collect_discover_input(discover, batch)?,
|
160 |
+
query::Variant::Context(context) => collect_context_input(context, batch)?,
|
161 |
+
query::Variant::OrderBy(_) => {}
|
162 |
+
query::Variant::Fusion(_) => {}
|
163 |
+
query::Variant::Sample(_) => {}
|
164 |
+
}
|
165 |
+
|
166 |
+
Ok(())
|
167 |
+
}
|
168 |
+
|
169 |
+
pub(crate) fn collect_prefetch(
|
170 |
+
prefetch: &PrefetchQuery,
|
171 |
+
batch: &mut BatchAccumGrpc,
|
172 |
+
) -> Result<(), Status> {
|
173 |
+
let PrefetchQuery {
|
174 |
+
prefetch,
|
175 |
+
query,
|
176 |
+
using: _,
|
177 |
+
filter: _,
|
178 |
+
params: _,
|
179 |
+
score_threshold: _,
|
180 |
+
limit: _,
|
181 |
+
lookup_from: _,
|
182 |
+
} = prefetch;
|
183 |
+
|
184 |
+
if let Some(query) = query {
|
185 |
+
collect_query(query, batch)?;
|
186 |
+
}
|
187 |
+
|
188 |
+
for p in prefetch {
|
189 |
+
collect_prefetch(p, batch)?;
|
190 |
+
}
|
191 |
+
|
192 |
+
Ok(())
|
193 |
+
}
|
194 |
+
|
195 |
+
#[cfg(test)]
mod tests {
    use api::rest::schema::{Document, Image, InferenceObject};
    use serde_json::json;

    use super::*;

    // Helpers building minimal inference payloads with a fixed model name.
    fn create_test_document(text: &str) -> Document {
        Document {
            text: text.to_string(),
            model: "test-model".to_string(),
            options: Default::default(),
        }
    }

    fn create_test_image(url: &str) -> Image {
        Image {
            image: json!({"data": url.to_string()}),
            model: "test-model".to_string(),
            options: Default::default(),
        }
    }

    fn create_test_object(data: &str) -> InferenceObject {
        InferenceObject {
            object: json!({"data": data}),
            model: "test-model".to_string(),
            options: Default::default(),
        }
    }

    #[test]
    fn test_batch_accum_basic() {
        let mut batch = BatchAccumGrpc::new();
        assert!(batch.objects.is_empty());

        let doc = InferenceData::Document(create_test_document("test"));
        batch.add(doc.clone());
        assert_eq!(batch.objects.len(), 1);

        // Re-adding an identical payload must not grow the set.
        batch.add(doc);
        assert_eq!(batch.objects.len(), 1);
    }

    #[test]
    fn test_batch_accum_extend() {
        let mut batch1 = BatchAccumGrpc::new();
        let mut batch2 = BatchAccumGrpc::new();

        let doc1 = InferenceData::Document(create_test_document("test1"));
        let doc2 = InferenceData::Document(create_test_document("test2"));

        batch1.add(doc1);
        batch2.add(doc2);

        batch1.extend(batch2);
        assert_eq!(batch1.objects.len(), 2);
    }

    #[test]
    fn test_deduplication() {
        let mut batch = BatchAccumGrpc::new();

        // Two separately constructed but equal payloads collapse to one entry.
        let doc1 = InferenceData::Document(create_test_document("same"));
        let doc2 = InferenceData::Document(create_test_document("same"));

        batch.add(doc1);
        batch.add(doc2);

        assert_eq!(batch.objects.len(), 1);
    }

    #[test]
    fn test_different_model_same_content() {
        let mut batch = BatchAccumGrpc::new();

        // Same text but different model names are distinct payloads.
        let mut doc1 = create_test_document("same");
        let mut doc2 = create_test_document("same");
        doc1.model = "model1".to_string();
        doc2.model = "model2".to_string();

        batch.add(InferenceData::Document(doc1));
        batch.add(InferenceData::Document(doc2));

        assert_eq!(batch.objects.len(), 2);
    }
}
|
src/common/inference/config.rs
ADDED
@@ -0,0 +1,23 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
use serde::{Deserialize, Serialize};
|
2 |
+
|
3 |
+
/// Configuration for the external inference service.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct InferenceConfig {
    // Base URL of the inference service; `None` means inference is not configured.
    pub address: Option<String>,
    // Request timeout; default 10 (presumably seconds — confirm at the call site).
    #[serde(default = "default_inference_timeout")]
    pub timeout: u64,
    // Optional authentication token forwarded to the inference service.
    pub token: Option<String>,
}

// Serde default for `timeout` when the field is absent from the config file.
fn default_inference_timeout() -> u64 {
    10
}

impl InferenceConfig {
    /// Build a config with the given address, the default timeout, and no token.
    pub fn new(address: Option<String>) -> Self {
        Self {
            address,
            timeout: default_inference_timeout(),
            token: None,
        }
    }
}
|
src/common/inference/infer_processing.rs
ADDED
@@ -0,0 +1,72 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
use std::collections::{HashMap, HashSet};
|
2 |
+
|
3 |
+
use collection::operations::point_ops::VectorPersisted;
|
4 |
+
use storage::content_manager::errors::StorageError;
|
5 |
+
|
6 |
+
use super::batch_processing::BatchAccum;
|
7 |
+
use super::service::{InferenceData, InferenceInput, InferenceService, InferenceType};
|
8 |
+
|
9 |
+
pub struct BatchAccumInferred {
|
10 |
+
pub(crate) objects: HashMap<InferenceData, VectorPersisted>,
|
11 |
+
}
|
12 |
+
|
13 |
+
impl BatchAccumInferred {
|
14 |
+
pub fn new() -> Self {
|
15 |
+
Self {
|
16 |
+
objects: HashMap::new(),
|
17 |
+
}
|
18 |
+
}
|
19 |
+
|
20 |
+
pub async fn from_objects(
|
21 |
+
objects: HashSet<InferenceData>,
|
22 |
+
inference_type: InferenceType,
|
23 |
+
) -> Result<Self, StorageError> {
|
24 |
+
if objects.is_empty() {
|
25 |
+
return Ok(Self::new());
|
26 |
+
}
|
27 |
+
|
28 |
+
let Some(service) = InferenceService::get_global() else {
|
29 |
+
return Err(StorageError::service_error(
|
30 |
+
"InferenceService is not initialized. Please check if it was properly configured and initialized during startup."
|
31 |
+
));
|
32 |
+
};
|
33 |
+
|
34 |
+
service.validate()?;
|
35 |
+
|
36 |
+
let objects_serialized: Vec<_> = objects.into_iter().collect();
|
37 |
+
let inference_inputs: Vec<_> = objects_serialized
|
38 |
+
.iter()
|
39 |
+
.cloned()
|
40 |
+
.map(InferenceInput::from)
|
41 |
+
.collect();
|
42 |
+
|
43 |
+
let vectors = service
|
44 |
+
.infer(inference_inputs, inference_type)
|
45 |
+
.await
|
46 |
+
.map_err(|e| StorageError::service_error(
|
47 |
+
format!("Inference request failed. Check if inference service is running and properly configured: {e}")
|
48 |
+
))?;
|
49 |
+
|
50 |
+
if vectors.is_empty() {
|
51 |
+
return Err(StorageError::service_error(
|
52 |
+
"Inference service returned no vectors. Check if models are properly loaded.",
|
53 |
+
));
|
54 |
+
}
|
55 |
+
|
56 |
+
let objects = objects_serialized.into_iter().zip(vectors).collect();
|
57 |
+
|
58 |
+
Ok(Self { objects })
|
59 |
+
}
|
60 |
+
|
61 |
+
pub async fn from_batch_accum(
|
62 |
+
batch: BatchAccum,
|
63 |
+
inference_type: InferenceType,
|
64 |
+
) -> Result<Self, StorageError> {
|
65 |
+
let BatchAccum { objects } = batch;
|
66 |
+
Self::from_objects(objects, inference_type).await
|
67 |
+
}
|
68 |
+
|
69 |
+
pub fn get_vector(&self, data: &InferenceData) -> Option<&VectorPersisted> {
|
70 |
+
self.objects.get(data)
|
71 |
+
}
|
72 |
+
}
|
src/common/inference/mod.rs
ADDED
@@ -0,0 +1,8 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
// Inference integration: collecting embeddable payloads (documents, images,
// objects) from REST/gRPC requests and resolving them through the external
// inference service.
mod batch_processing;
mod batch_processing_grpc;
pub(crate) mod config;
mod infer_processing;
pub mod query_requests_grpc;
pub mod query_requests_rest;
pub mod service;
pub mod update_requests;
|
src/common/inference/query_requests_grpc.rs
ADDED
@@ -0,0 +1,535 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
use api::conversions::json::json_path_from_proto;
|
2 |
+
use api::grpc::qdrant as grpc;
|
3 |
+
use api::grpc::qdrant::query::Variant;
|
4 |
+
use api::grpc::qdrant::RecommendInput;
|
5 |
+
use api::rest;
|
6 |
+
use api::rest::RecommendStrategy;
|
7 |
+
use collection::operations::universal_query::collection_query::{
|
8 |
+
CollectionPrefetch, CollectionQueryGroupsRequest, CollectionQueryRequest, Query,
|
9 |
+
VectorInputInternal, VectorQuery,
|
10 |
+
};
|
11 |
+
use collection::operations::universal_query::shard_query::{FusionInternal, SampleInternal};
|
12 |
+
use segment::data_types::order_by::OrderBy;
|
13 |
+
use segment::data_types::vectors::{VectorInternal, DEFAULT_VECTOR_NAME};
|
14 |
+
use segment::vector_storage::query::{ContextPair, ContextQuery, DiscoveryQuery, RecoQuery};
|
15 |
+
use tonic::Status;
|
16 |
+
|
17 |
+
use crate::common::inference::batch_processing_grpc::{
|
18 |
+
collect_prefetch, collect_query, BatchAccumGrpc,
|
19 |
+
};
|
20 |
+
use crate::common::inference::infer_processing::BatchAccumInferred;
|
21 |
+
use crate::common::inference::service::{InferenceData, InferenceType};
|
22 |
+
|
23 |
+
/// Convert a gRPC [`grpc::QueryPointGroups`] request into the internal
/// [`CollectionQueryGroupsRequest`], resolving document/image/object inputs
/// into vectors via the inference service.
///
/// # Errors
/// Returns a [`Status`] if inference fails or any sub-structure is invalid.
pub async fn convert_query_point_groups_from_grpc(
    query: grpc::QueryPointGroups,
) -> Result<CollectionQueryGroupsRequest, Status> {
    let grpc::QueryPointGroups {
        collection_name: _,
        prefetch,
        query,
        using,
        filter,
        params,
        score_threshold,
        with_payload,
        with_vectors,
        lookup_from,
        limit,
        group_size,
        group_by,
        with_lookup,
        read_consistency: _,
        timeout: _,
        shard_key_selector: _,
    } = query;

    // Gather every inference-requiring input (documents, images, objects)
    // from the query and all prefetches into a single batch.
    let mut batch = BatchAccumGrpc::new();

    if let Some(q) = &query {
        collect_query(q, &mut batch)?;
    }

    for p in &prefetch {
        collect_prefetch(p, &mut batch)?;
    }

    let BatchAccumGrpc { objects } = batch;

    // One inference round-trip for the whole request.
    let inferred = BatchAccumInferred::from_objects(objects, InferenceType::Search)
        .await
        .map_err(|e| Status::internal(format!("Inference error: {e}")))?;

    // Substitute inferred vectors back into the query and prefetches.
    let query = if let Some(q) = query {
        Some(convert_query_with_inferred(q, &inferred)?)
    } else {
        None
    };

    let prefetch = prefetch
        .into_iter()
        .map(|p| convert_prefetch_with_inferred(p, &inferred))
        .collect::<Result<Vec<_>, _>>()?;

    let request = CollectionQueryGroupsRequest {
        prefetch,
        query,
        using: using.unwrap_or(DEFAULT_VECTOR_NAME.to_string()),
        filter: filter.map(TryFrom::try_from).transpose()?,
        score_threshold,
        with_vector: with_vectors
            .map(From::from)
            .unwrap_or(CollectionQueryRequest::DEFAULT_WITH_VECTOR),
        with_payload: with_payload
            .map(TryFrom::try_from)
            .transpose()?
            .unwrap_or(CollectionQueryRequest::DEFAULT_WITH_PAYLOAD),
        lookup_from: lookup_from.map(From::from),
        group_by: json_path_from_proto(&group_by)?,
        group_size: group_size
            .map(|s| s as usize)
            .unwrap_or(CollectionQueryRequest::DEFAULT_GROUP_SIZE),
        limit: limit
            .map(|l| l as usize)
            .unwrap_or(CollectionQueryRequest::DEFAULT_LIMIT),
        params: params.map(From::from),
        with_lookup: with_lookup.map(TryFrom::try_from).transpose()?,
    };

    Ok(request)
}
|
101 |
+
|
102 |
+
/// Convert a gRPC [`grpc::QueryPoints`] request into the internal
/// [`CollectionQueryRequest`], resolving document/image/object inputs into
/// vectors via the inference service.
///
/// # Errors
/// Returns a [`Status`] if inference fails or any sub-structure is invalid.
pub async fn convert_query_points_from_grpc(
    query: grpc::QueryPoints,
) -> Result<CollectionQueryRequest, Status> {
    let grpc::QueryPoints {
        collection_name: _,
        prefetch,
        query,
        using,
        filter,
        params,
        score_threshold,
        limit,
        offset,
        with_payload,
        with_vectors,
        read_consistency: _,
        shard_key_selector: _,
        lookup_from,
        timeout: _,
    } = query;

    // Accumulate all inference-requiring inputs from the query and prefetches.
    let mut batch = BatchAccumGrpc::new();

    if let Some(q) = &query {
        collect_query(q, &mut batch)?;
    }

    for p in &prefetch {
        collect_prefetch(p, &mut batch)?;
    }

    let BatchAccumGrpc { objects } = batch;

    // One inference round-trip for the whole request.
    let inferred = BatchAccumInferred::from_objects(objects, InferenceType::Search)
        .await
        .map_err(|e| Status::internal(format!("Inference error: {e}")))?;

    let prefetch = prefetch
        .into_iter()
        .map(|p| convert_prefetch_with_inferred(p, &inferred))
        .collect::<Result<Vec<_>, _>>()?;

    let query = query
        .map(|q| convert_query_with_inferred(q, &inferred))
        .transpose()?;

    Ok(CollectionQueryRequest {
        prefetch,
        query,
        using: using.unwrap_or(DEFAULT_VECTOR_NAME.to_string()),
        filter: filter.map(TryFrom::try_from).transpose()?,
        score_threshold,
        limit: limit
            .map(|l| l as usize)
            .unwrap_or(CollectionQueryRequest::DEFAULT_LIMIT),
        offset: offset
            .map(|o| o as usize)
            .unwrap_or(CollectionQueryRequest::DEFAULT_OFFSET),
        params: params.map(From::from),
        with_vector: with_vectors
            .map(From::from)
            .unwrap_or(CollectionQueryRequest::DEFAULT_WITH_VECTOR),
        with_payload: with_payload
            .map(TryFrom::try_from)
            .transpose()?
            .unwrap_or(CollectionQueryRequest::DEFAULT_WITH_PAYLOAD),
        lookup_from: lookup_from.map(From::from),
    })
}
|
172 |
+
|
173 |
+
fn convert_prefetch_with_inferred(
|
174 |
+
prefetch: grpc::PrefetchQuery,
|
175 |
+
inferred: &BatchAccumInferred,
|
176 |
+
) -> Result<CollectionPrefetch, Status> {
|
177 |
+
let grpc::PrefetchQuery {
|
178 |
+
prefetch,
|
179 |
+
query,
|
180 |
+
using,
|
181 |
+
filter,
|
182 |
+
params,
|
183 |
+
score_threshold,
|
184 |
+
limit,
|
185 |
+
lookup_from,
|
186 |
+
} = prefetch;
|
187 |
+
|
188 |
+
let nested_prefetches = prefetch
|
189 |
+
.into_iter()
|
190 |
+
.map(|p| convert_prefetch_with_inferred(p, inferred))
|
191 |
+
.collect::<Result<Vec<_>, _>>()?;
|
192 |
+
|
193 |
+
let query = query
|
194 |
+
.map(|q| convert_query_with_inferred(q, inferred))
|
195 |
+
.transpose()?;
|
196 |
+
|
197 |
+
Ok(CollectionPrefetch {
|
198 |
+
prefetch: nested_prefetches,
|
199 |
+
query,
|
200 |
+
using: using.unwrap_or(DEFAULT_VECTOR_NAME.to_string()),
|
201 |
+
filter: filter.map(TryFrom::try_from).transpose()?,
|
202 |
+
score_threshold,
|
203 |
+
limit: limit
|
204 |
+
.map(|l| l as usize)
|
205 |
+
.unwrap_or(CollectionQueryRequest::DEFAULT_LIMIT),
|
206 |
+
params: params.map(From::from),
|
207 |
+
lookup_from: lookup_from.map(From::from),
|
208 |
+
})
|
209 |
+
}
|
210 |
+
|
211 |
+
/// Convert a gRPC [`grpc::Query`] into the internal [`Query`], substituting
/// pre-computed inference vectors for document/image/object inputs.
///
/// # Errors
/// Returns an `invalid_argument` [`Status`] when required parts of a variant
/// are missing, and propagates conversion errors from nested structures.
fn convert_query_with_inferred(
    query: grpc::Query,
    inferred: &BatchAccumInferred,
) -> Result<Query, Status> {
    let variant = query
        .variant
        .ok_or_else(|| Status::invalid_argument("Query variant is missing"))?;

    let query = match variant {
        Variant::Nearest(nearest) => {
            let vector = convert_vector_input_with_inferred(nearest, inferred)?;
            Query::Vector(VectorQuery::Nearest(vector))
        }
        Variant::Recommend(recommend) => {
            let RecommendInput {
                positive,
                negative,
                strategy,
            } = recommend;

            let positives = positive
                .into_iter()
                .map(|v| convert_vector_input_with_inferred(v, inferred))
                .collect::<Result<Vec<_>, _>>()?;

            let negatives = negative
                .into_iter()
                .map(|v| convert_vector_input_with_inferred(v, inferred))
                .collect::<Result<Vec<_>, _>>()?;

            let reco_query = RecoQuery::new(positives, negatives);

            // An absent or unrecognized strategy silently falls back to the
            // default recommend strategy.
            let strategy = strategy
                .and_then(|x| grpc::RecommendStrategy::try_from(x).ok())
                .map(RecommendStrategy::from)
                .unwrap_or_default();

            match strategy {
                RecommendStrategy::AverageVector => {
                    Query::Vector(VectorQuery::RecommendAverageVector(reco_query))
                }
                RecommendStrategy::BestScore => {
                    Query::Vector(VectorQuery::RecommendBestScore(reco_query))
                }
            }
        }
        Variant::Discover(discover) => {
            let grpc::DiscoverInput { target, context } = discover;

            // Both target and context are required for a discovery query.
            let target = target
                .map(|t| convert_vector_input_with_inferred(t, inferred))
                .transpose()?
                .ok_or_else(|| Status::invalid_argument("DiscoverInput target is missing"))?;

            let grpc::ContextInput { pairs } = context
                .ok_or_else(|| Status::invalid_argument("DiscoverInput context is missing"))?;

            let context = pairs
                .into_iter()
                .map(|pair| context_pair_from_grpc_with_inferred(pair, inferred))
                .collect::<Result<_, _>>()?;

            Query::Vector(VectorQuery::Discover(DiscoveryQuery::new(target, context)))
        }
        Variant::Context(context) => {
            let context_query = context_query_from_grpc_with_inferred(context, inferred)?;
            Query::Vector(VectorQuery::Context(context_query))
        }
        Variant::OrderBy(order_by) => Query::OrderBy(OrderBy::try_from(order_by)?),
        Variant::Fusion(fusion) => Query::Fusion(FusionInternal::try_from(fusion)?),
        Variant::Sample(sample) => Query::Sample(SampleInternal::try_from(sample)?),
    };

    Ok(query)
}
|
286 |
+
|
287 |
+
fn convert_vector_input_with_inferred(
|
288 |
+
vector: grpc::VectorInput,
|
289 |
+
inferred: &BatchAccumInferred,
|
290 |
+
) -> Result<VectorInputInternal, Status> {
|
291 |
+
use api::grpc::qdrant::vector_input::Variant;
|
292 |
+
|
293 |
+
let variant = vector
|
294 |
+
.variant
|
295 |
+
.ok_or_else(|| Status::invalid_argument("VectorInput variant is missing"))?;
|
296 |
+
|
297 |
+
match variant {
|
298 |
+
Variant::Id(id) => Ok(VectorInputInternal::Id(TryFrom::try_from(id)?)),
|
299 |
+
Variant::Dense(dense) => Ok(VectorInputInternal::Vector(VectorInternal::Dense(
|
300 |
+
From::from(dense),
|
301 |
+
))),
|
302 |
+
Variant::Sparse(sparse) => Ok(VectorInputInternal::Vector(VectorInternal::Sparse(
|
303 |
+
From::from(sparse),
|
304 |
+
))),
|
305 |
+
Variant::MultiDense(multi_dense) => Ok(VectorInputInternal::Vector(
|
306 |
+
VectorInternal::MultiDense(From::from(multi_dense)),
|
307 |
+
)),
|
308 |
+
Variant::Document(doc) => {
|
309 |
+
let doc: rest::Document = doc
|
310 |
+
.try_into()
|
311 |
+
.map_err(|e| Status::internal(format!("Document conversion error: {e}")))?;
|
312 |
+
let data = InferenceData::Document(doc);
|
313 |
+
let vector = inferred
|
314 |
+
.get_vector(&data)
|
315 |
+
.ok_or_else(|| Status::internal("Missing inferred vector for document"))?;
|
316 |
+
|
317 |
+
Ok(VectorInputInternal::Vector(VectorInternal::from(
|
318 |
+
vector.clone(),
|
319 |
+
)))
|
320 |
+
}
|
321 |
+
Variant::Image(img) => {
|
322 |
+
let img: rest::Image = img
|
323 |
+
.try_into()
|
324 |
+
.map_err(|e| Status::internal(format!("Image conversion error: {e}",)))?;
|
325 |
+
let data = InferenceData::Image(img);
|
326 |
+
|
327 |
+
let vector = inferred
|
328 |
+
.get_vector(&data)
|
329 |
+
.ok_or_else(|| Status::internal("Missing inferred vector for image"))?;
|
330 |
+
|
331 |
+
Ok(VectorInputInternal::Vector(VectorInternal::from(
|
332 |
+
vector.clone(),
|
333 |
+
)))
|
334 |
+
}
|
335 |
+
Variant::Object(obj) => {
|
336 |
+
let obj: rest::InferenceObject = obj
|
337 |
+
.try_into()
|
338 |
+
.map_err(|e| Status::internal(format!("Object conversion error: {e}")))?;
|
339 |
+
let data = InferenceData::Object(obj);
|
340 |
+
let vector = inferred
|
341 |
+
.get_vector(&data)
|
342 |
+
.ok_or_else(|| Status::internal("Missing inferred vector for object"))?;
|
343 |
+
|
344 |
+
Ok(VectorInputInternal::Vector(VectorInternal::from(
|
345 |
+
vector.clone(),
|
346 |
+
)))
|
347 |
+
}
|
348 |
+
}
|
349 |
+
}
|
350 |
+
|
351 |
+
fn context_query_from_grpc_with_inferred(
|
352 |
+
value: grpc::ContextInput,
|
353 |
+
inferred: &BatchAccumInferred,
|
354 |
+
) -> Result<ContextQuery<VectorInputInternal>, Status> {
|
355 |
+
let grpc::ContextInput { pairs } = value;
|
356 |
+
|
357 |
+
Ok(ContextQuery {
|
358 |
+
pairs: pairs
|
359 |
+
.into_iter()
|
360 |
+
.map(|pair| context_pair_from_grpc_with_inferred(pair, inferred))
|
361 |
+
.collect::<Result<_, _>>()?,
|
362 |
+
})
|
363 |
+
}
|
364 |
+
|
365 |
+
fn context_pair_from_grpc_with_inferred(
|
366 |
+
value: grpc::ContextInputPair,
|
367 |
+
inferred: &BatchAccumInferred,
|
368 |
+
) -> Result<ContextPair<VectorInputInternal>, Status> {
|
369 |
+
let grpc::ContextInputPair { positive, negative } = value;
|
370 |
+
|
371 |
+
let positive =
|
372 |
+
positive.ok_or_else(|| Status::invalid_argument("ContextPair positive is missing"))?;
|
373 |
+
let negative =
|
374 |
+
negative.ok_or_else(|| Status::invalid_argument("ContextPair negative is missing"))?;
|
375 |
+
|
376 |
+
Ok(ContextPair {
|
377 |
+
positive: convert_vector_input_with_inferred(positive, inferred)?,
|
378 |
+
negative: convert_vector_input_with_inferred(negative, inferred)?,
|
379 |
+
})
|
380 |
+
}
|
381 |
+
|
382 |
+
#[cfg(test)]
mod tests {
    use std::collections::HashMap;

    use api::grpc::qdrant::value::Kind;
    use api::grpc::qdrant::vector_input::Variant;
    use api::grpc::qdrant::Value;
    use collection::operations::point_ops::VectorPersisted;

    use super::*;

    /// Minimal gRPC document fixture.
    fn create_test_document() -> api::grpc::qdrant::Document {
        api::grpc::qdrant::Document {
            text: "test".to_string(),
            model: "test-model".to_string(),
            options: HashMap::new(),
        }
    }

    /// Minimal gRPC image fixture.
    fn create_test_image() -> api::grpc::qdrant::Image {
        api::grpc::qdrant::Image {
            image: Some(Value {
                kind: Some(Kind::StringValue("test.jpg".to_string())),
            }),
            model: "test-model".to_string(),
            options: HashMap::new(),
        }
    }

    /// Minimal gRPC inference-object fixture.
    fn create_test_object() -> api::grpc::qdrant::InferenceObject {
        api::grpc::qdrant::InferenceObject {
            object: Some(Value {
                kind: Some(Kind::StringValue("test".to_string())),
            }),
            model: "test-model".to_string(),
            options: HashMap::new(),
        }
    }

    /// Build an inferred batch that maps the document, image, and object
    /// fixtures above to the same dense vector [1.0, 2.0, 3.0].
    fn create_test_inferred_batch() -> BatchAccumInferred {
        let mut objects = HashMap::new();

        let grpc_doc = create_test_document();
        let grpc_img = create_test_image();
        let grpc_obj = create_test_object();

        let doc: rest::Document = grpc_doc.try_into().unwrap();
        let img: rest::Image = grpc_img.try_into().unwrap();
        let obj: rest::InferenceObject = grpc_obj.try_into().unwrap();

        let doc_data = InferenceData::Document(doc);
        let img_data = InferenceData::Image(img);
        let obj_data = InferenceData::Object(obj);

        let dense_vector = vec![1.0, 2.0, 3.0];
        let vector_persisted = VectorPersisted::Dense(dense_vector);

        objects.insert(doc_data, vector_persisted.clone());
        objects.insert(img_data, vector_persisted.clone());
        objects.insert(obj_data, vector_persisted);

        BatchAccumInferred { objects }
    }

    // A plain dense input must pass through unchanged, without touching
    // the inferred batch.
    #[test]
    fn test_convert_vector_input_with_inferred_dense() {
        let inferred = create_test_inferred_batch();
        let vector = grpc::VectorInput {
            variant: Some(Variant::Dense(grpc::DenseVector {
                data: vec![1.0, 2.0, 3.0],
            })),
        };

        let result = convert_vector_input_with_inferred(vector, &inferred).unwrap();
        match result {
            VectorInputInternal::Vector(VectorInternal::Dense(values)) => {
                assert_eq!(values, vec![1.0, 2.0, 3.0]);
            }
            _ => panic!("Expected dense vector"),
        }
    }

    // A document input must be replaced by its pre-computed inferred vector.
    #[test]
    fn test_convert_vector_input_with_inferred_document() {
        let inferred = create_test_inferred_batch();
        let doc = create_test_document();
        let vector = grpc::VectorInput {
            variant: Some(Variant::Document(doc)),
        };

        let result = convert_vector_input_with_inferred(vector, &inferred).unwrap();
        match result {
            VectorInputInternal::Vector(VectorInternal::Dense(values)) => {
                assert_eq!(values, vec![1.0, 2.0, 3.0]);
            }
            _ => panic!("Expected dense vector from inference"),
        }
    }

    // A missing variant is an invalid-argument error, not a panic.
    #[test]
    fn test_convert_vector_input_missing_variant() {
        let inferred = create_test_inferred_batch();
        let vector = grpc::VectorInput { variant: None };

        let result = convert_vector_input_with_inferred(vector, &inferred);
        assert!(result.is_err());
        assert!(result.unwrap_err().message().contains("variant is missing"));
    }

    // Mixed pair: a raw dense positive and an inference-backed negative
    // should both resolve to dense vectors.
    #[test]
    fn test_context_pair_from_grpc_with_inferred() {
        let inferred = create_test_inferred_batch();
        let pair = grpc::ContextInputPair {
            positive: Some(grpc::VectorInput {
                variant: Some(Variant::Dense(grpc::DenseVector {
                    data: vec![1.0, 2.0, 3.0],
                })),
            }),
            negative: Some(grpc::VectorInput {
                variant: Some(Variant::Document(create_test_document())),
            }),
        };

        let result = context_pair_from_grpc_with_inferred(pair, &inferred).unwrap();
        match (result.positive, result.negative) {
            (
                VectorInputInternal::Vector(VectorInternal::Dense(pos)),
                VectorInputInternal::Vector(VectorInternal::Dense(neg)),
            ) => {
                assert_eq!(pos, vec![1.0, 2.0, 3.0]);
                assert_eq!(neg, vec![1.0, 2.0, 3.0]);
            }
            _ => panic!("Expected dense vectors"),
        }
    }

    // A pair with a missing side must surface which side is missing.
    #[test]
    fn test_context_pair_missing_vectors() {
        let inferred = create_test_inferred_batch();
        let pair = grpc::ContextInputPair {
            positive: None,
            negative: Some(grpc::VectorInput {
                variant: Some(Variant::Document(create_test_document())),
            }),
        };

        let result = context_pair_from_grpc_with_inferred(pair, &inferred);
        assert!(result.is_err());
        assert!(result
            .unwrap_err()
            .message()
            .contains("positive is missing"));
    }
}
|
src/common/inference/query_requests_rest.rs
ADDED
@@ -0,0 +1,415 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
use api::rest::schema as rest;
|
2 |
+
use collection::lookup::WithLookup;
|
3 |
+
use collection::operations::universal_query::collection_query::{
|
4 |
+
CollectionPrefetch, CollectionQueryGroupsRequest, CollectionQueryRequest, Query,
|
5 |
+
VectorInputInternal, VectorQuery,
|
6 |
+
};
|
7 |
+
use collection::operations::universal_query::shard_query::{FusionInternal, SampleInternal};
|
8 |
+
use segment::data_types::order_by::OrderBy;
|
9 |
+
use segment::data_types::vectors::{MultiDenseVectorInternal, VectorInternal, DEFAULT_VECTOR_NAME};
|
10 |
+
use segment::vector_storage::query::{ContextPair, ContextQuery, DiscoveryQuery, RecoQuery};
|
11 |
+
use storage::content_manager::errors::StorageError;
|
12 |
+
|
13 |
+
use crate::common::inference::batch_processing::{
|
14 |
+
collect_query_groups_request, collect_query_request,
|
15 |
+
};
|
16 |
+
use crate::common::inference::infer_processing::BatchAccumInferred;
|
17 |
+
use crate::common::inference::service::{InferenceData, InferenceType};
|
18 |
+
|
19 |
+
pub async fn convert_query_groups_request_from_rest(
|
20 |
+
request: rest::QueryGroupsRequestInternal,
|
21 |
+
) -> Result<CollectionQueryGroupsRequest, StorageError> {
|
22 |
+
let batch = collect_query_groups_request(&request);
|
23 |
+
let rest::QueryGroupsRequestInternal {
|
24 |
+
prefetch,
|
25 |
+
query,
|
26 |
+
using,
|
27 |
+
filter,
|
28 |
+
score_threshold,
|
29 |
+
params,
|
30 |
+
with_vector,
|
31 |
+
with_payload,
|
32 |
+
lookup_from,
|
33 |
+
group_request,
|
34 |
+
} = request;
|
35 |
+
|
36 |
+
let inferred = BatchAccumInferred::from_batch_accum(batch, InferenceType::Search).await?;
|
37 |
+
let query = query
|
38 |
+
.map(|q| convert_query_with_inferred(q, &inferred))
|
39 |
+
.transpose()?;
|
40 |
+
|
41 |
+
let prefetch = prefetch
|
42 |
+
.map(|prefetches| {
|
43 |
+
prefetches
|
44 |
+
.into_iter()
|
45 |
+
.map(|p| convert_prefetch_with_inferred(p, &inferred))
|
46 |
+
.collect::<Result<Vec<_>, _>>()
|
47 |
+
})
|
48 |
+
.transpose()?
|
49 |
+
.unwrap_or_default();
|
50 |
+
|
51 |
+
Ok(CollectionQueryGroupsRequest {
|
52 |
+
prefetch,
|
53 |
+
query,
|
54 |
+
using: using.unwrap_or(DEFAULT_VECTOR_NAME.to_string()),
|
55 |
+
filter,
|
56 |
+
score_threshold,
|
57 |
+
params,
|
58 |
+
with_vector: with_vector.unwrap_or(CollectionQueryRequest::DEFAULT_WITH_VECTOR),
|
59 |
+
with_payload: with_payload.unwrap_or(CollectionQueryRequest::DEFAULT_WITH_PAYLOAD),
|
60 |
+
lookup_from,
|
61 |
+
limit: group_request
|
62 |
+
.limit
|
63 |
+
.unwrap_or(CollectionQueryRequest::DEFAULT_LIMIT),
|
64 |
+
group_by: group_request.group_by,
|
65 |
+
group_size: group_request
|
66 |
+
.group_size
|
67 |
+
.unwrap_or(CollectionQueryRequest::DEFAULT_GROUP_SIZE),
|
68 |
+
with_lookup: group_request.with_lookup.map(WithLookup::from),
|
69 |
+
})
|
70 |
+
}
|
71 |
+
|
72 |
+
/// Convert a REST query into the internal [`CollectionQueryRequest`],
/// running inference for any document/image/object inputs first.
///
/// # Errors
/// Propagates [`StorageError`] from the inference service or from
/// converting nested queries/prefetches.
pub async fn convert_query_request_from_rest(
    request: rest::QueryRequestInternal,
) -> Result<CollectionQueryRequest, StorageError> {
    // Collect inference inputs before destructuring consumes the request,
    // then run a single inference round-trip.
    let batch = collect_query_request(&request);
    let inferred = BatchAccumInferred::from_batch_accum(batch, InferenceType::Search).await?;
    let rest::QueryRequestInternal {
        prefetch,
        query,
        using,
        filter,
        score_threshold,
        params,
        limit,
        offset,
        with_vector,
        with_payload,
        lookup_from,
    } = request;

    let prefetch = prefetch
        .map(|prefetches| {
            prefetches
                .into_iter()
                .map(|p| convert_prefetch_with_inferred(p, &inferred))
                .collect::<Result<Vec<_>, _>>()
        })
        .transpose()?
        .unwrap_or_default();

    let query = query
        .map(|q| convert_query_with_inferred(q, &inferred))
        .transpose()?;

    Ok(CollectionQueryRequest {
        prefetch,
        query,
        using: using.unwrap_or(DEFAULT_VECTOR_NAME.to_string()),
        filter,
        score_threshold,
        limit: limit.unwrap_or(CollectionQueryRequest::DEFAULT_LIMIT),
        offset: offset.unwrap_or(CollectionQueryRequest::DEFAULT_OFFSET),
        params,
        with_vector: with_vector.unwrap_or(CollectionQueryRequest::DEFAULT_WITH_VECTOR),
        with_payload: with_payload.unwrap_or(CollectionQueryRequest::DEFAULT_WITH_PAYLOAD),
        lookup_from,
    })
}
|
119 |
+
|
120 |
+
fn convert_vector_input_with_inferred(
|
121 |
+
vector: rest::VectorInput,
|
122 |
+
inferred: &BatchAccumInferred,
|
123 |
+
) -> Result<VectorInputInternal, StorageError> {
|
124 |
+
match vector {
|
125 |
+
rest::VectorInput::Id(id) => Ok(VectorInputInternal::Id(id)),
|
126 |
+
rest::VectorInput::DenseVector(dense) => {
|
127 |
+
Ok(VectorInputInternal::Vector(VectorInternal::Dense(dense)))
|
128 |
+
}
|
129 |
+
rest::VectorInput::SparseVector(sparse) => {
|
130 |
+
Ok(VectorInputInternal::Vector(VectorInternal::Sparse(sparse)))
|
131 |
+
}
|
132 |
+
rest::VectorInput::MultiDenseVector(multi_dense) => Ok(VectorInputInternal::Vector(
|
133 |
+
VectorInternal::MultiDense(MultiDenseVectorInternal::new_unchecked(multi_dense)),
|
134 |
+
)),
|
135 |
+
rest::VectorInput::Document(doc) => {
|
136 |
+
let data = InferenceData::Document(doc);
|
137 |
+
let vector = inferred.get_vector(&data).ok_or_else(|| {
|
138 |
+
StorageError::inference_error("Missing inferred vector for document")
|
139 |
+
})?;
|
140 |
+
Ok(VectorInputInternal::Vector(VectorInternal::from(
|
141 |
+
vector.clone(),
|
142 |
+
)))
|
143 |
+
}
|
144 |
+
rest::VectorInput::Image(img) => {
|
145 |
+
let data = InferenceData::Image(img);
|
146 |
+
let vector = inferred.get_vector(&data).ok_or_else(|| {
|
147 |
+
StorageError::inference_error("Missing inferred vector for image")
|
148 |
+
})?;
|
149 |
+
Ok(VectorInputInternal::Vector(VectorInternal::from(
|
150 |
+
vector.clone(),
|
151 |
+
)))
|
152 |
+
}
|
153 |
+
rest::VectorInput::Object(obj) => {
|
154 |
+
let data = InferenceData::Object(obj);
|
155 |
+
let vector = inferred.get_vector(&data).ok_or_else(|| {
|
156 |
+
StorageError::inference_error("Missing inferred vector for object")
|
157 |
+
})?;
|
158 |
+
Ok(VectorInputInternal::Vector(VectorInternal::from(
|
159 |
+
vector.clone(),
|
160 |
+
)))
|
161 |
+
}
|
162 |
+
}
|
163 |
+
}
|
164 |
+
|
165 |
+
/// Convert a REST [`rest::QueryInterface`] into the internal [`Query`],
/// substituting pre-computed inference vectors for document/image/object
/// inputs.
fn convert_query_with_inferred(
    query: rest::QueryInterface,
    inferred: &BatchAccumInferred,
) -> Result<Query, StorageError> {
    // Normalize the shorthand interface into the full query enum first.
    let query = rest::Query::from(query);
    match query {
        rest::Query::Nearest(nearest) => {
            let vector = convert_vector_input_with_inferred(nearest.nearest, inferred)?;
            Ok(Query::Vector(VectorQuery::Nearest(vector)))
        }
        rest::Query::Recommend(recommend) => {
            let rest::RecommendInput {
                positive,
                negative,
                strategy,
            } = recommend.recommend;
            // `positive`/`negative` are optional lists; `flatten` treats
            // an absent list as empty.
            let positives = positive
                .into_iter()
                .flatten()
                .map(|v| convert_vector_input_with_inferred(v, inferred))
                .collect::<Result<Vec<_>, _>>()?;
            let negatives = negative
                .into_iter()
                .flatten()
                .map(|v| convert_vector_input_with_inferred(v, inferred))
                .collect::<Result<Vec<_>, _>>()?;
            let reco_query = RecoQuery::new(positives, negatives);
            match strategy.unwrap_or_default() {
                rest::RecommendStrategy::AverageVector => Ok(Query::Vector(
                    VectorQuery::RecommendAverageVector(reco_query),
                )),
                rest::RecommendStrategy::BestScore => {
                    Ok(Query::Vector(VectorQuery::RecommendBestScore(reco_query)))
                }
            }
        }
        rest::Query::Discover(discover) => {
            let rest::DiscoverInput { target, context } = discover.discover;
            let target = convert_vector_input_with_inferred(target, inferred)?;
            let context = context
                .into_iter()
                .flatten()
                .map(|pair| context_pair_from_rest_with_inferred(pair, inferred))
                .collect::<Result<Vec<_>, _>>()?;
            Ok(Query::Vector(VectorQuery::Discover(DiscoveryQuery::new(
                target, context,
            ))))
        }
        rest::Query::Context(context) => {
            let rest::ContextInput(context) = context.context;
            let context = context
                .into_iter()
                .flatten()
                .map(|pair| context_pair_from_rest_with_inferred(pair, inferred))
                .collect::<Result<Vec<_>, _>>()?;
            Ok(Query::Vector(VectorQuery::Context(ContextQuery::new(
                context,
            ))))
        }
        rest::Query::OrderBy(order_by) => Ok(Query::OrderBy(OrderBy::from(order_by.order_by))),
        rest::Query::Fusion(fusion) => Ok(Query::Fusion(FusionInternal::from(fusion.fusion))),
        rest::Query::Sample(sample) => Ok(Query::Sample(SampleInternal::from(sample.sample))),
    }
}
|
229 |
+
|
230 |
+
fn convert_prefetch_with_inferred(
|
231 |
+
prefetch: rest::Prefetch,
|
232 |
+
inferred: &BatchAccumInferred,
|
233 |
+
) -> Result<CollectionPrefetch, StorageError> {
|
234 |
+
let rest::Prefetch {
|
235 |
+
prefetch,
|
236 |
+
query,
|
237 |
+
using,
|
238 |
+
filter,
|
239 |
+
score_threshold,
|
240 |
+
params,
|
241 |
+
limit,
|
242 |
+
lookup_from,
|
243 |
+
} = prefetch;
|
244 |
+
|
245 |
+
let query = query
|
246 |
+
.map(|q| convert_query_with_inferred(q, inferred))
|
247 |
+
.transpose()?;
|
248 |
+
let nested_prefetches = prefetch
|
249 |
+
.map(|prefetches| {
|
250 |
+
prefetches
|
251 |
+
.into_iter()
|
252 |
+
.map(|p| convert_prefetch_with_inferred(p, inferred))
|
253 |
+
.collect::<Result<Vec<_>, _>>()
|
254 |
+
})
|
255 |
+
.transpose()?
|
256 |
+
.unwrap_or_default();
|
257 |
+
|
258 |
+
Ok(CollectionPrefetch {
|
259 |
+
prefetch: nested_prefetches,
|
260 |
+
query,
|
261 |
+
using: using.unwrap_or(DEFAULT_VECTOR_NAME.to_string()),
|
262 |
+
filter,
|
263 |
+
score_threshold,
|
264 |
+
limit: limit.unwrap_or(CollectionQueryRequest::DEFAULT_LIMIT),
|
265 |
+
params,
|
266 |
+
lookup_from,
|
267 |
+
})
|
268 |
+
}
|
269 |
+
|
270 |
+
fn context_pair_from_rest_with_inferred(
|
271 |
+
value: rest::ContextPair,
|
272 |
+
inferred: &BatchAccumInferred,
|
273 |
+
) -> Result<ContextPair<VectorInputInternal>, StorageError> {
|
274 |
+
let rest::ContextPair { positive, negative } = value;
|
275 |
+
Ok(ContextPair {
|
276 |
+
positive: convert_vector_input_with_inferred(positive, inferred)?,
|
277 |
+
negative: convert_vector_input_with_inferred(negative, inferred)?,
|
278 |
+
})
|
279 |
+
}
|
280 |
+
|
281 |
+
#[cfg(test)]
mod tests {
    use std::collections::HashMap;

    use api::rest::schema::{Document, Image, InferenceObject, NearestQuery};
    use collection::operations::point_ops::VectorPersisted;
    use serde_json::json;

    use super::*;

    /// Document fixture with a fixed test model name.
    fn create_test_document(text: &str) -> Document {
        Document {
            text: text.to_string(),
            model: "test-model".to_string(),
            options: Default::default(),
        }
    }

    /// Image fixture wrapping the URL in a JSON payload.
    fn create_test_image(url: &str) -> Image {
        Image {
            image: json!({"data": url.to_string()}),
            model: "test-model".to_string(),
            options: Default::default(),
        }
    }

    /// Arbitrary-object fixture.
    fn create_test_object(data: &str) -> InferenceObject {
        InferenceObject {
            object: json!({"data": data}),
            model: "test-model".to_string(),
            options: Default::default(),
        }
    }

    /// Pre-inferred batch mapping one document, image, and object (all keyed
    /// on "test" inputs) to the same dense embedding [1, 2, 3].
    fn create_test_inferred_batch() -> BatchAccumInferred {
        let embedding = VectorPersisted::Dense(vec![1.0, 2.0, 3.0]);

        let objects: HashMap<_, _> = [
            InferenceData::Document(create_test_document("test")),
            InferenceData::Image(create_test_image("test.jpg")),
            InferenceData::Object(create_test_object("test")),
        ]
        .into_iter()
        .map(|key| (key, embedding.clone()))
        .collect();

        BatchAccumInferred { objects }
    }

    #[test]
    fn test_convert_vector_input_with_inferred_dense() {
        let inferred = create_test_inferred_batch();
        let input = rest::VectorInput::DenseVector(vec![1.0, 2.0, 3.0]);

        // Plain dense vectors bypass inference and pass through unchanged.
        match convert_vector_input_with_inferred(input, &inferred).unwrap() {
            VectorInputInternal::Vector(VectorInternal::Dense(values)) => {
                assert_eq!(values, vec![1.0, 2.0, 3.0]);
            }
            _ => panic!("Expected dense vector"),
        }
    }

    #[test]
    fn test_convert_vector_input_with_inferred_document() {
        let inferred = create_test_inferred_batch();
        let input = rest::VectorInput::Document(create_test_document("test"));

        // A document input must be replaced by its inferred embedding.
        match convert_vector_input_with_inferred(input, &inferred).unwrap() {
            VectorInputInternal::Vector(VectorInternal::Dense(values)) => {
                assert_eq!(values, vec![1.0, 2.0, 3.0]);
            }
            _ => panic!("Expected dense vector from inference"),
        }
    }

    #[test]
    fn test_convert_vector_input_with_inferred_missing() {
        let inferred = create_test_inferred_batch();
        let input = rest::VectorInput::Document(create_test_document("missing"));

        // An input absent from the inferred batch is an error, not a silent
        // fallback.
        let err = convert_vector_input_with_inferred(input, &inferred).unwrap_err();
        assert!(err.to_string().contains("Missing inferred vector"));
    }

    #[test]
    fn test_context_pair_from_rest_with_inferred() {
        let inferred = create_test_inferred_batch();
        let pair = rest::ContextPair {
            positive: rest::VectorInput::DenseVector(vec![1.0, 2.0, 3.0]),
            negative: rest::VectorInput::Document(create_test_document("test")),
        };

        // Mixed pair: dense passes through, document resolves via inference;
        // both end up as the same [1, 2, 3] embedding here.
        let result = context_pair_from_rest_with_inferred(pair, &inferred).unwrap();
        match (result.positive, result.negative) {
            (
                VectorInputInternal::Vector(VectorInternal::Dense(pos)),
                VectorInputInternal::Vector(VectorInternal::Dense(neg)),
            ) => {
                assert_eq!(pos, vec![1.0, 2.0, 3.0]);
                assert_eq!(neg, vec![1.0, 2.0, 3.0]);
            }
            _ => panic!("Expected dense vectors"),
        }
    }

    #[test]
    fn test_convert_query_with_inferred_nearest() {
        let inferred = create_test_inferred_batch();
        let nearest = NearestQuery {
            nearest: rest::VectorInput::Document(create_test_document("test")),
        };
        let query = rest::QueryInterface::Query(rest::Query::Nearest(nearest));

        // A nearest query over a document must come back as a nearest query
        // over the inferred dense vector.
        match convert_query_with_inferred(query, &inferred).unwrap() {
            Query::Vector(VectorQuery::Nearest(VectorInputInternal::Vector(
                VectorInternal::Dense(values),
            ))) => {
                assert_eq!(values, vec![1.0, 2.0, 3.0]);
            }
            _ => panic!("Expected nearest query"),
        }
    }
}
|
src/common/inference/service.rs
ADDED
@@ -0,0 +1,266 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
use std::collections::HashMap;
|
2 |
+
use std::fmt::Display;
|
3 |
+
use std::hash::Hash;
|
4 |
+
use std::sync::Arc;
|
5 |
+
use std::time::Duration;
|
6 |
+
|
7 |
+
use api::rest::{Document, Image, InferenceObject};
|
8 |
+
use collection::operations::point_ops::VectorPersisted;
|
9 |
+
use parking_lot::RwLock;
|
10 |
+
use reqwest::Client;
|
11 |
+
use serde::{Deserialize, Serialize};
|
12 |
+
use serde_json::Value;
|
13 |
+
use storage::content_manager::errors::StorageError;
|
14 |
+
|
15 |
+
use crate::common::inference::config::InferenceConfig;
|
16 |
+
|
17 |
+
/// Wire value of `data_type` for text documents sent to the inference service.
const DOCUMENT_DATA_TYPE: &str = "text";
/// Wire value of `data_type` for image inputs.
const IMAGE_DATA_TYPE: &str = "image";
/// Wire value of `data_type` for arbitrary JSON-object inputs.
const OBJECT_DATA_TYPE: &str = "object";
|
20 |
+
|
21 |
+
#[derive(Debug, Serialize, Default, Clone, Copy)]
|
22 |
+
#[serde(rename_all = "lowercase")]
|
23 |
+
pub enum InferenceType {
|
24 |
+
#[default]
|
25 |
+
Update,
|
26 |
+
Search,
|
27 |
+
}
|
28 |
+
|
29 |
+
impl Display for InferenceType {
|
30 |
+
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
31 |
+
write!(f, "{}", format!("{self:?}").to_lowercase())
|
32 |
+
}
|
33 |
+
}
|
34 |
+
|
35 |
+
#[derive(Debug, Serialize)]
|
36 |
+
pub struct InferenceRequest {
|
37 |
+
pub(crate) inputs: Vec<InferenceInput>,
|
38 |
+
pub(crate) inference: Option<InferenceType>,
|
39 |
+
#[serde(default)]
|
40 |
+
pub(crate) token: Option<String>,
|
41 |
+
}
|
42 |
+
|
43 |
+
#[derive(Debug, Serialize)]
|
44 |
+
pub struct InferenceInput {
|
45 |
+
data: Value,
|
46 |
+
data_type: String,
|
47 |
+
model: String,
|
48 |
+
options: Option<HashMap<String, Value>>,
|
49 |
+
}
|
50 |
+
|
51 |
+
#[derive(Debug, Deserialize)]
|
52 |
+
pub(crate) struct InferenceResponse {
|
53 |
+
pub(crate) embeddings: Vec<VectorPersisted>,
|
54 |
+
}
|
55 |
+
|
56 |
+
#[derive(Debug, Serialize, Deserialize, Clone, PartialEq, Eq, Hash)]
|
57 |
+
pub enum InferenceData {
|
58 |
+
Document(Document),
|
59 |
+
Image(Image),
|
60 |
+
Object(InferenceObject),
|
61 |
+
}
|
62 |
+
|
63 |
+
impl InferenceData {
|
64 |
+
pub(crate) fn type_name(&self) -> &'static str {
|
65 |
+
match self {
|
66 |
+
InferenceData::Document(_) => "document",
|
67 |
+
InferenceData::Image(_) => "image",
|
68 |
+
InferenceData::Object(_) => "object",
|
69 |
+
}
|
70 |
+
}
|
71 |
+
}
|
72 |
+
|
73 |
+
impl From<InferenceData> for InferenceInput {
|
74 |
+
fn from(value: InferenceData) -> Self {
|
75 |
+
match value {
|
76 |
+
InferenceData::Document(doc) => {
|
77 |
+
let Document {
|
78 |
+
text,
|
79 |
+
model,
|
80 |
+
options,
|
81 |
+
} = doc;
|
82 |
+
InferenceInput {
|
83 |
+
data: Value::String(text),
|
84 |
+
data_type: DOCUMENT_DATA_TYPE.to_string(),
|
85 |
+
model: model.to_string(),
|
86 |
+
options: options.options,
|
87 |
+
}
|
88 |
+
}
|
89 |
+
InferenceData::Image(img) => {
|
90 |
+
let Image {
|
91 |
+
image,
|
92 |
+
model,
|
93 |
+
options,
|
94 |
+
} = img;
|
95 |
+
InferenceInput {
|
96 |
+
data: image,
|
97 |
+
data_type: IMAGE_DATA_TYPE.to_string(),
|
98 |
+
model: model.to_string(),
|
99 |
+
options: options.options,
|
100 |
+
}
|
101 |
+
}
|
102 |
+
InferenceData::Object(obj) => {
|
103 |
+
let InferenceObject {
|
104 |
+
object,
|
105 |
+
model,
|
106 |
+
options,
|
107 |
+
} = obj;
|
108 |
+
InferenceInput {
|
109 |
+
data: object,
|
110 |
+
data_type: OBJECT_DATA_TYPE.to_string(),
|
111 |
+
model: model.to_string(),
|
112 |
+
options: options.options,
|
113 |
+
}
|
114 |
+
}
|
115 |
+
}
|
116 |
+
}
|
117 |
+
}
|
118 |
+
|
119 |
+
pub struct InferenceService {
|
120 |
+
pub(crate) config: InferenceConfig,
|
121 |
+
pub(crate) client: Client,
|
122 |
+
}
|
123 |
+
|
124 |
+
static INFERENCE_SERVICE: RwLock<Option<Arc<InferenceService>>> = RwLock::new(None);
|
125 |
+
|
126 |
+
impl InferenceService {
|
127 |
+
pub fn new(config: InferenceConfig) -> Self {
|
128 |
+
let timeout = Duration::from_secs(config.timeout);
|
129 |
+
Self {
|
130 |
+
config,
|
131 |
+
client: Client::builder()
|
132 |
+
.timeout(timeout)
|
133 |
+
.build()
|
134 |
+
.expect("Invalid timeout value for HTTP client"),
|
135 |
+
}
|
136 |
+
}
|
137 |
+
|
138 |
+
pub fn init_global(config: InferenceConfig) -> Result<(), StorageError> {
|
139 |
+
let mut inference_service = INFERENCE_SERVICE.write();
|
140 |
+
|
141 |
+
if config.token.is_none() {
|
142 |
+
return Err(StorageError::service_error(
|
143 |
+
"Cannot initialize InferenceService: token is required but not provided in config",
|
144 |
+
));
|
145 |
+
}
|
146 |
+
|
147 |
+
if config.address.is_none() || config.address.as_ref().unwrap().is_empty() {
|
148 |
+
return Err(StorageError::service_error(
|
149 |
+
"Cannot initialize InferenceService: address is required but not provided or empty in config"
|
150 |
+
));
|
151 |
+
}
|
152 |
+
|
153 |
+
*inference_service = Some(Arc::new(Self::new(config)));
|
154 |
+
Ok(())
|
155 |
+
}
|
156 |
+
|
157 |
+
pub fn get_global() -> Option<Arc<InferenceService>> {
|
158 |
+
INFERENCE_SERVICE.read().as_ref().cloned()
|
159 |
+
}
|
160 |
+
|
161 |
+
pub(crate) fn validate(&self) -> Result<(), StorageError> {
|
162 |
+
if self
|
163 |
+
.config
|
164 |
+
.address
|
165 |
+
.as_ref()
|
166 |
+
.map_or(true, |url| url.is_empty())
|
167 |
+
{
|
168 |
+
return Err(StorageError::service_error(
|
169 |
+
"InferenceService configuration error: address is missing or empty",
|
170 |
+
));
|
171 |
+
}
|
172 |
+
Ok(())
|
173 |
+
}
|
174 |
+
|
175 |
+
pub async fn infer(
|
176 |
+
&self,
|
177 |
+
inference_inputs: Vec<InferenceInput>,
|
178 |
+
inference_type: InferenceType,
|
179 |
+
) -> Result<Vec<VectorPersisted>, StorageError> {
|
180 |
+
let request = InferenceRequest {
|
181 |
+
inputs: inference_inputs,
|
182 |
+
inference: Some(inference_type),
|
183 |
+
token: self.config.token.clone(),
|
184 |
+
};
|
185 |
+
|
186 |
+
let url = self.config.address.as_ref().ok_or_else(|| {
|
187 |
+
StorageError::service_error(
|
188 |
+
"InferenceService URL not configured - please provide valid address in config",
|
189 |
+
)
|
190 |
+
})?;
|
191 |
+
|
192 |
+
let response = self
|
193 |
+
.client
|
194 |
+
.post(url)
|
195 |
+
.json(&request)
|
196 |
+
.send()
|
197 |
+
.await
|
198 |
+
.map_err(|e| {
|
199 |
+
let error_body = e.to_string();
|
200 |
+
StorageError::service_error(format!(
|
201 |
+
"Failed to send inference request to {url}: {e}, error details: {error_body}",
|
202 |
+
))
|
203 |
+
})?;
|
204 |
+
|
205 |
+
let status = response.status();
|
206 |
+
let response_body = response.text().await.map_err(|e| {
|
207 |
+
StorageError::service_error(format!("Failed to read inference response body: {e}",))
|
208 |
+
})?;
|
209 |
+
|
210 |
+
Self::handle_inference_response(status, &response_body)
|
211 |
+
}
|
212 |
+
|
213 |
+
pub(crate) fn handle_inference_response(
|
214 |
+
status: reqwest::StatusCode,
|
215 |
+
response_body: &str,
|
216 |
+
) -> Result<Vec<VectorPersisted>, StorageError> {
|
217 |
+
match status {
|
218 |
+
reqwest::StatusCode::OK => {
|
219 |
+
let inference_response: InferenceResponse = serde_json::from_str(response_body)
|
220 |
+
.map_err(|e| {
|
221 |
+
StorageError::service_error(format!(
|
222 |
+
"Failed to parse successful inference response: {e}. Response body: {response_body}",
|
223 |
+
))
|
224 |
+
})?;
|
225 |
+
|
226 |
+
if inference_response.embeddings.is_empty() {
|
227 |
+
Err(StorageError::service_error(
|
228 |
+
"Inference response contained no embeddings - this may indicate an issue with the model or input"
|
229 |
+
))
|
230 |
+
} else {
|
231 |
+
Ok(inference_response.embeddings)
|
232 |
+
}
|
233 |
+
}
|
234 |
+
reqwest::StatusCode::BAD_REQUEST => {
|
235 |
+
let error_json: Value = serde_json::from_str(response_body).map_err(|e| {
|
236 |
+
StorageError::service_error(format!(
|
237 |
+
"Failed to parse error response: {e}. Raw response: {response_body}",
|
238 |
+
))
|
239 |
+
})?;
|
240 |
+
|
241 |
+
if let Some(error_message) = error_json["error"].as_str() {
|
242 |
+
Err(StorageError::bad_request(format!(
|
243 |
+
"Inference request validation failed: {error_message}",
|
244 |
+
)))
|
245 |
+
} else {
|
246 |
+
Err(StorageError::bad_request(format!(
|
247 |
+
"Invalid inference request: {response_body}",
|
248 |
+
)))
|
249 |
+
}
|
250 |
+
}
|
251 |
+
status @ (reqwest::StatusCode::UNAUTHORIZED | reqwest::StatusCode::FORBIDDEN) => {
|
252 |
+
Err(StorageError::service_error(format!(
|
253 |
+
"Authentication failed for inference service ({status}): {response_body}",
|
254 |
+
)))
|
255 |
+
}
|
256 |
+
status @ (reqwest::StatusCode::INTERNAL_SERVER_ERROR
|
257 |
+
| reqwest::StatusCode::SERVICE_UNAVAILABLE
|
258 |
+
| reqwest::StatusCode::GATEWAY_TIMEOUT) => Err(StorageError::service_error(format!(
|
259 |
+
"Inference service error ({status}): {response_body}",
|
260 |
+
))),
|
261 |
+
_ => Err(StorageError::service_error(format!(
|
262 |
+
"Unexpected inference service response ({status}): {response_body}"
|
263 |
+
))),
|
264 |
+
}
|
265 |
+
}
|
266 |
+
}
|
src/common/inference/update_requests.rs
ADDED
@@ -0,0 +1,409 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
use std::collections::HashMap;
|
2 |
+
|
3 |
+
use api::rest::{Batch, BatchVectorStruct, PointStruct, PointVectors, Vector, VectorStruct};
|
4 |
+
use collection::operations::point_ops::{
|
5 |
+
BatchPersisted, BatchVectorStructPersisted, PointStructPersisted, VectorPersisted,
|
6 |
+
VectorStructPersisted,
|
7 |
+
};
|
8 |
+
use collection::operations::vector_ops::PointVectorsPersisted;
|
9 |
+
use storage::content_manager::errors::StorageError;
|
10 |
+
|
11 |
+
use crate::common::inference::batch_processing::BatchAccum;
|
12 |
+
use crate::common::inference::infer_processing::BatchAccumInferred;
|
13 |
+
use crate::common::inference::service::{InferenceData, InferenceType};
|
14 |
+
|
15 |
+
pub async fn convert_point_struct(
|
16 |
+
point_structs: Vec<PointStruct>,
|
17 |
+
inference_type: InferenceType,
|
18 |
+
) -> Result<Vec<PointStructPersisted>, StorageError> {
|
19 |
+
let mut batch_accum = BatchAccum::new();
|
20 |
+
|
21 |
+
for point_struct in &point_structs {
|
22 |
+
match &point_struct.vector {
|
23 |
+
VectorStruct::Named(named) => {
|
24 |
+
for vector in named.values() {
|
25 |
+
match vector {
|
26 |
+
Vector::Document(doc) => {
|
27 |
+
batch_accum.add(InferenceData::Document(doc.clone()))
|
28 |
+
}
|
29 |
+
Vector::Image(img) => batch_accum.add(InferenceData::Image(img.clone())),
|
30 |
+
Vector::Object(obj) => batch_accum.add(InferenceData::Object(obj.clone())),
|
31 |
+
Vector::Dense(_) | Vector::Sparse(_) | Vector::MultiDense(_) => {}
|
32 |
+
}
|
33 |
+
}
|
34 |
+
}
|
35 |
+
VectorStruct::Document(doc) => batch_accum.add(InferenceData::Document(doc.clone())),
|
36 |
+
VectorStruct::Image(img) => batch_accum.add(InferenceData::Image(img.clone())),
|
37 |
+
VectorStruct::Object(obj) => batch_accum.add(InferenceData::Object(obj.clone())),
|
38 |
+
VectorStruct::MultiDense(_) | VectorStruct::Single(_) => {}
|
39 |
+
}
|
40 |
+
}
|
41 |
+
|
42 |
+
let inferred = if !batch_accum.objects.is_empty() {
|
43 |
+
Some(BatchAccumInferred::from_batch_accum(batch_accum, inference_type).await?)
|
44 |
+
} else {
|
45 |
+
None
|
46 |
+
};
|
47 |
+
|
48 |
+
let mut converted_points: Vec<PointStructPersisted> = Vec::new();
|
49 |
+
for point_struct in point_structs {
|
50 |
+
let PointStruct {
|
51 |
+
id,
|
52 |
+
vector,
|
53 |
+
payload,
|
54 |
+
} = point_struct;
|
55 |
+
|
56 |
+
let converted_vector_struct = match vector {
|
57 |
+
VectorStruct::Single(single) => VectorStructPersisted::Single(single),
|
58 |
+
VectorStruct::MultiDense(multi) => VectorStructPersisted::MultiDense(multi),
|
59 |
+
VectorStruct::Named(named) => {
|
60 |
+
let mut named_vectors = HashMap::new();
|
61 |
+
for (name, vector) in named {
|
62 |
+
let converted_vector = match &inferred {
|
63 |
+
Some(inferred) => convert_vector_with_inferred(vector, inferred)?,
|
64 |
+
None => match vector {
|
65 |
+
Vector::Dense(dense) => VectorPersisted::Dense(dense),
|
66 |
+
Vector::Sparse(sparse) => VectorPersisted::Sparse(sparse),
|
67 |
+
Vector::MultiDense(multi) => VectorPersisted::MultiDense(multi),
|
68 |
+
Vector::Document(_) | Vector::Image(_) | Vector::Object(_) => {
|
69 |
+
return Err(StorageError::inference_error(
|
70 |
+
"Inference required but service returned no results",
|
71 |
+
))
|
72 |
+
}
|
73 |
+
},
|
74 |
+
};
|
75 |
+
named_vectors.insert(name, converted_vector);
|
76 |
+
}
|
77 |
+
VectorStructPersisted::Named(named_vectors)
|
78 |
+
}
|
79 |
+
VectorStruct::Document(doc) => {
|
80 |
+
let vector = match &inferred {
|
81 |
+
Some(inferred) => {
|
82 |
+
convert_vector_with_inferred(Vector::Document(doc), inferred)?
|
83 |
+
}
|
84 |
+
None => {
|
85 |
+
return Err(StorageError::inference_error(
|
86 |
+
"Inference required but service returned no results",
|
87 |
+
))
|
88 |
+
}
|
89 |
+
};
|
90 |
+
match vector {
|
91 |
+
VectorPersisted::Dense(dense) => VectorStructPersisted::Single(dense),
|
92 |
+
VectorPersisted::Sparse(_) => {
|
93 |
+
return Err(StorageError::bad_request("Sparse vector should be named"));
|
94 |
+
}
|
95 |
+
VectorPersisted::MultiDense(multi) => VectorStructPersisted::MultiDense(multi),
|
96 |
+
}
|
97 |
+
}
|
98 |
+
VectorStruct::Image(img) => {
|
99 |
+
let vector = match &inferred {
|
100 |
+
Some(inferred) => convert_vector_with_inferred(Vector::Image(img), inferred)?,
|
101 |
+
None => {
|
102 |
+
return Err(StorageError::inference_error(
|
103 |
+
"Inference required but service returned no results",
|
104 |
+
))
|
105 |
+
}
|
106 |
+
};
|
107 |
+
match vector {
|
108 |
+
VectorPersisted::Dense(dense) => VectorStructPersisted::Single(dense),
|
109 |
+
VectorPersisted::Sparse(_) => {
|
110 |
+
return Err(StorageError::bad_request("Sparse vector should be named"));
|
111 |
+
}
|
112 |
+
VectorPersisted::MultiDense(multi) => VectorStructPersisted::MultiDense(multi),
|
113 |
+
}
|
114 |
+
}
|
115 |
+
VectorStruct::Object(obj) => {
|
116 |
+
let vector = match &inferred {
|
117 |
+
Some(inferred) => convert_vector_with_inferred(Vector::Object(obj), inferred)?,
|
118 |
+
None => {
|
119 |
+
return Err(StorageError::inference_error(
|
120 |
+
"Inference required but service returned no results",
|
121 |
+
))
|
122 |
+
}
|
123 |
+
};
|
124 |
+
match vector {
|
125 |
+
VectorPersisted::Dense(dense) => VectorStructPersisted::Single(dense),
|
126 |
+
VectorPersisted::Sparse(_) => {
|
127 |
+
return Err(StorageError::bad_request("Sparse vector should be named"));
|
128 |
+
}
|
129 |
+
VectorPersisted::MultiDense(multi) => VectorStructPersisted::MultiDense(multi),
|
130 |
+
}
|
131 |
+
}
|
132 |
+
};
|
133 |
+
|
134 |
+
let converted = PointStructPersisted {
|
135 |
+
id,
|
136 |
+
vector: converted_vector_struct,
|
137 |
+
payload,
|
138 |
+
};
|
139 |
+
|
140 |
+
converted_points.push(converted);
|
141 |
+
}
|
142 |
+
|
143 |
+
Ok(converted_points)
|
144 |
+
}
|
145 |
+
|
146 |
+
pub async fn convert_batch(batch: Batch) -> Result<BatchPersisted, StorageError> {
|
147 |
+
let Batch {
|
148 |
+
ids,
|
149 |
+
vectors,
|
150 |
+
payloads,
|
151 |
+
} = batch;
|
152 |
+
|
153 |
+
let batch_persisted = BatchPersisted {
|
154 |
+
ids,
|
155 |
+
vectors: match vectors {
|
156 |
+
BatchVectorStruct::Single(single) => BatchVectorStructPersisted::Single(single),
|
157 |
+
BatchVectorStruct::MultiDense(multi) => BatchVectorStructPersisted::MultiDense(multi),
|
158 |
+
BatchVectorStruct::Named(named) => {
|
159 |
+
let mut named_vectors = HashMap::new();
|
160 |
+
|
161 |
+
for (name, vectors) in named {
|
162 |
+
let converted_vectors = convert_vectors(vectors, InferenceType::Update).await?;
|
163 |
+
named_vectors.insert(name, converted_vectors);
|
164 |
+
}
|
165 |
+
|
166 |
+
BatchVectorStructPersisted::Named(named_vectors)
|
167 |
+
}
|
168 |
+
BatchVectorStruct::Document(_) => {
|
169 |
+
return Err(StorageError::inference_error(
|
170 |
+
"Document processing is not supported in batch operations.",
|
171 |
+
))
|
172 |
+
}
|
173 |
+
BatchVectorStruct::Image(_) => {
|
174 |
+
return Err(StorageError::inference_error(
|
175 |
+
"Image processing is not supported in batch operations.",
|
176 |
+
))
|
177 |
+
}
|
178 |
+
BatchVectorStruct::Object(_) => {
|
179 |
+
return Err(StorageError::inference_error(
|
180 |
+
"Object processing is not supported in batch operations.",
|
181 |
+
))
|
182 |
+
}
|
183 |
+
},
|
184 |
+
payloads,
|
185 |
+
};
|
186 |
+
|
187 |
+
Ok(batch_persisted)
|
188 |
+
}
|
189 |
+
|
190 |
+
pub async fn convert_point_vectors(
|
191 |
+
point_vectors_list: Vec<PointVectors>,
|
192 |
+
inference_type: InferenceType,
|
193 |
+
) -> Result<Vec<PointVectorsPersisted>, StorageError> {
|
194 |
+
let mut converted_point_vectors = Vec::new();
|
195 |
+
let mut batch_accum = BatchAccum::new();
|
196 |
+
|
197 |
+
for point_vectors in &point_vectors_list {
|
198 |
+
if let VectorStruct::Named(named) = &point_vectors.vector {
|
199 |
+
for vector in named.values() {
|
200 |
+
match vector {
|
201 |
+
Vector::Document(doc) => batch_accum.add(InferenceData::Document(doc.clone())),
|
202 |
+
Vector::Image(img) => batch_accum.add(InferenceData::Image(img.clone())),
|
203 |
+
Vector::Object(obj) => batch_accum.add(InferenceData::Object(obj.clone())),
|
204 |
+
Vector::Dense(_) | Vector::Sparse(_) | Vector::MultiDense(_) => {}
|
205 |
+
}
|
206 |
+
}
|
207 |
+
}
|
208 |
+
}
|
209 |
+
|
210 |
+
let inferred = if !batch_accum.objects.is_empty() {
|
211 |
+
Some(BatchAccumInferred::from_batch_accum(batch_accum, inference_type).await?)
|
212 |
+
} else {
|
213 |
+
None
|
214 |
+
};
|
215 |
+
|
216 |
+
for point_vectors in point_vectors_list {
|
217 |
+
let PointVectors { id, vector } = point_vectors;
|
218 |
+
|
219 |
+
let converted_vector = match vector {
|
220 |
+
VectorStruct::Single(dense) => VectorStructPersisted::Single(dense),
|
221 |
+
VectorStruct::MultiDense(multi) => VectorStructPersisted::MultiDense(multi),
|
222 |
+
VectorStruct::Named(named) => {
|
223 |
+
let mut converted = HashMap::new();
|
224 |
+
|
225 |
+
for (name, vec) in named {
|
226 |
+
let converted_vec = match &inferred {
|
227 |
+
Some(inferred) => convert_vector_with_inferred(vec, inferred)?,
|
228 |
+
None => match vec {
|
229 |
+
Vector::Dense(dense) => VectorPersisted::Dense(dense),
|
230 |
+
Vector::Sparse(sparse) => VectorPersisted::Sparse(sparse),
|
231 |
+
Vector::MultiDense(multi) => VectorPersisted::MultiDense(multi),
|
232 |
+
Vector::Document(_) | Vector::Image(_) | Vector::Object(_) => {
|
233 |
+
return Err(StorageError::inference_error(
|
234 |
+
"Inference required but service returned no results",
|
235 |
+
))
|
236 |
+
}
|
237 |
+
},
|
238 |
+
};
|
239 |
+
converted.insert(name, converted_vec);
|
240 |
+
}
|
241 |
+
|
242 |
+
VectorStructPersisted::Named(converted)
|
243 |
+
}
|
244 |
+
VectorStruct::Document(_) => {
|
245 |
+
return Err(StorageError::inference_error(
|
246 |
+
"Document processing is not supported for point vectors.",
|
247 |
+
))
|
248 |
+
}
|
249 |
+
VectorStruct::Image(_) => {
|
250 |
+
return Err(StorageError::inference_error(
|
251 |
+
"Image processing is not supported for point vectors.",
|
252 |
+
))
|
253 |
+
}
|
254 |
+
VectorStruct::Object(_) => {
|
255 |
+
return Err(StorageError::inference_error(
|
256 |
+
"Object processing is not supported for point vectors.",
|
257 |
+
))
|
258 |
+
}
|
259 |
+
};
|
260 |
+
|
261 |
+
let converted_point_vector = PointVectorsPersisted {
|
262 |
+
id,
|
263 |
+
vector: converted_vector,
|
264 |
+
};
|
265 |
+
|
266 |
+
converted_point_vectors.push(converted_point_vector);
|
267 |
+
}
|
268 |
+
|
269 |
+
Ok(converted_point_vectors)
|
270 |
+
}
|
271 |
+
|
272 |
+
fn convert_point_struct_with_inferred(
|
273 |
+
point_structs: Vec<PointStruct>,
|
274 |
+
inferred: &BatchAccumInferred,
|
275 |
+
) -> Result<Vec<PointStructPersisted>, StorageError> {
|
276 |
+
point_structs
|
277 |
+
.into_iter()
|
278 |
+
.map(|point_struct| {
|
279 |
+
let PointStruct {
|
280 |
+
id,
|
281 |
+
vector,
|
282 |
+
payload,
|
283 |
+
} = point_struct;
|
284 |
+
let converted_vector_struct = match vector {
|
285 |
+
VectorStruct::Single(single) => VectorStructPersisted::Single(single),
|
286 |
+
VectorStruct::MultiDense(multi) => VectorStructPersisted::MultiDense(multi),
|
287 |
+
VectorStruct::Named(named) => {
|
288 |
+
let mut named_vectors = HashMap::new();
|
289 |
+
for (name, vector) in named {
|
290 |
+
let converted_vector = convert_vector_with_inferred(vector, inferred)?;
|
291 |
+
named_vectors.insert(name, converted_vector);
|
292 |
+
}
|
293 |
+
VectorStructPersisted::Named(named_vectors)
|
294 |
+
}
|
295 |
+
VectorStruct::Document(doc) => {
|
296 |
+
let vector = convert_vector_with_inferred(Vector::Document(doc), inferred)?;
|
297 |
+
match vector {
|
298 |
+
VectorPersisted::Dense(dense) => VectorStructPersisted::Single(dense),
|
299 |
+
VectorPersisted::Sparse(_) => {
|
300 |
+
return Err(StorageError::bad_request("Sparse vector should be named"))
|
301 |
+
}
|
302 |
+
VectorPersisted::MultiDense(multi) => {
|
303 |
+
VectorStructPersisted::MultiDense(multi)
|
304 |
+
}
|
305 |
+
}
|
306 |
+
}
|
307 |
+
VectorStruct::Image(img) => {
|
308 |
+
let vector = convert_vector_with_inferred(Vector::Image(img), inferred)?;
|
309 |
+
match vector {
|
310 |
+
VectorPersisted::Dense(dense) => VectorStructPersisted::Single(dense),
|
311 |
+
VectorPersisted::Sparse(_) => {
|
312 |
+
return Err(StorageError::bad_request("Sparse vector should be named"))
|
313 |
+
}
|
314 |
+
VectorPersisted::MultiDense(multi) => {
|
315 |
+
VectorStructPersisted::MultiDense(multi)
|
316 |
+
}
|
317 |
+
}
|
318 |
+
}
|
319 |
+
VectorStruct::Object(obj) => {
|
320 |
+
let vector = convert_vector_with_inferred(Vector::Object(obj), inferred)?;
|
321 |
+
match vector {
|
322 |
+
VectorPersisted::Dense(dense) => VectorStructPersisted::Single(dense),
|
323 |
+
VectorPersisted::Sparse(_) => {
|
324 |
+
return Err(StorageError::bad_request("Sparse vector should be named"))
|
325 |
+
}
|
326 |
+
VectorPersisted::MultiDense(multi) => {
|
327 |
+
VectorStructPersisted::MultiDense(multi)
|
328 |
+
}
|
329 |
+
}
|
330 |
+
}
|
331 |
+
};
|
332 |
+
|
333 |
+
Ok(PointStructPersisted {
|
334 |
+
id,
|
335 |
+
vector: converted_vector_struct,
|
336 |
+
payload,
|
337 |
+
})
|
338 |
+
})
|
339 |
+
.collect()
|
340 |
+
}
|
341 |
+
|
342 |
+
pub async fn convert_vectors(
|
343 |
+
vectors: Vec<Vector>,
|
344 |
+
inference_type: InferenceType,
|
345 |
+
) -> Result<Vec<VectorPersisted>, StorageError> {
|
346 |
+
let mut batch_accum = BatchAccum::new();
|
347 |
+
for vector in &vectors {
|
348 |
+
match vector {
|
349 |
+
Vector::Document(doc) => batch_accum.add(InferenceData::Document(doc.clone())),
|
350 |
+
Vector::Image(img) => batch_accum.add(InferenceData::Image(img.clone())),
|
351 |
+
Vector::Object(obj) => batch_accum.add(InferenceData::Object(obj.clone())),
|
352 |
+
Vector::Dense(_) | Vector::Sparse(_) | Vector::MultiDense(_) => {}
|
353 |
+
}
|
354 |
+
}
|
355 |
+
|
356 |
+
let inferred = if !batch_accum.objects.is_empty() {
|
357 |
+
Some(BatchAccumInferred::from_batch_accum(batch_accum, inference_type).await?)
|
358 |
+
} else {
|
359 |
+
None
|
360 |
+
};
|
361 |
+
|
362 |
+
vectors
|
363 |
+
.into_iter()
|
364 |
+
.map(|vector| match &inferred {
|
365 |
+
Some(inferred) => convert_vector_with_inferred(vector, inferred),
|
366 |
+
None => match vector {
|
367 |
+
Vector::Dense(dense) => Ok(VectorPersisted::Dense(dense)),
|
368 |
+
Vector::Sparse(sparse) => Ok(VectorPersisted::Sparse(sparse)),
|
369 |
+
Vector::MultiDense(multi) => Ok(VectorPersisted::MultiDense(multi)),
|
370 |
+
Vector::Document(_) | Vector::Image(_) | Vector::Object(_) => {
|
371 |
+
Err(StorageError::inference_error(
|
372 |
+
"Inference required but service returned no results",
|
373 |
+
))
|
374 |
+
}
|
375 |
+
},
|
376 |
+
})
|
377 |
+
.collect()
|
378 |
+
}
|
379 |
+
|
380 |
+
fn convert_vector_with_inferred(
|
381 |
+
vector: Vector,
|
382 |
+
inferred: &BatchAccumInferred,
|
383 |
+
) -> Result<VectorPersisted, StorageError> {
|
384 |
+
match vector {
|
385 |
+
Vector::Dense(dense) => Ok(VectorPersisted::Dense(dense)),
|
386 |
+
Vector::Sparse(sparse) => Ok(VectorPersisted::Sparse(sparse)),
|
387 |
+
Vector::MultiDense(multi) => Ok(VectorPersisted::MultiDense(multi)),
|
388 |
+
Vector::Document(doc) => {
|
389 |
+
let data = InferenceData::Document(doc);
|
390 |
+
inferred.get_vector(&data).cloned().ok_or_else(|| {
|
391 |
+
StorageError::inference_error("Missing inferred vector for document")
|
392 |
+
})
|
393 |
+
}
|
394 |
+
Vector::Image(img) => {
|
395 |
+
let data = InferenceData::Image(img);
|
396 |
+
inferred
|
397 |
+
.get_vector(&data)
|
398 |
+
.cloned()
|
399 |
+
.ok_or_else(|| StorageError::inference_error("Missing inferred vector for image"))
|
400 |
+
}
|
401 |
+
Vector::Object(obj) => {
|
402 |
+
let data = InferenceData::Object(obj);
|
403 |
+
inferred
|
404 |
+
.get_vector(&data)
|
405 |
+
.cloned()
|
406 |
+
.ok_or_else(|| StorageError::inference_error("Missing inferred vector for object"))
|
407 |
+
}
|
408 |
+
}
|
409 |
+
}
|
src/common/metrics.rs
ADDED
@@ -0,0 +1,505 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
use prometheus::proto::{Counter, Gauge, LabelPair, Metric, MetricFamily, MetricType};
|
2 |
+
use prometheus::TextEncoder;
|
3 |
+
use segment::common::operation_time_statistics::OperationDurationStatistics;
|
4 |
+
|
5 |
+
use crate::common::telemetry::TelemetryData;
|
6 |
+
use crate::common::telemetry_ops::app_telemetry::{AppBuildTelemetry, AppFeaturesTelemetry};
|
7 |
+
use crate::common::telemetry_ops::cluster_telemetry::{ClusterStatusTelemetry, ClusterTelemetry};
|
8 |
+
use crate::common::telemetry_ops::collections_telemetry::{
|
9 |
+
CollectionTelemetryEnum, CollectionsTelemetry,
|
10 |
+
};
|
11 |
+
use crate::common::telemetry_ops::memory_telemetry::MemoryTelemetry;
|
12 |
+
use crate::common::telemetry_ops::requests_telemetry::{
|
13 |
+
GrpcTelemetry, RequestsTelemetry, WebApiTelemetry,
|
14 |
+
};
|
15 |
+
|
16 |
+
/// Whitelist for REST endpoints in metrics output.
///
/// Contains selection of search, recommend, scroll and upsert endpoints.
///
/// This array *must* be sorted.
// Sorted order is required because lookups use `binary_search`; the
// `test_endpoint_whitelists_sorted` unit test in this file enforces it.
const REST_ENDPOINT_WHITELIST: &[&str] = &[
    "/collections/{name}/index",
    "/collections/{name}/points",
    "/collections/{name}/points/batch",
    "/collections/{name}/points/count",
    "/collections/{name}/points/delete",
    "/collections/{name}/points/discover",
    "/collections/{name}/points/discover/batch",
    "/collections/{name}/points/facet",
    "/collections/{name}/points/payload",
    "/collections/{name}/points/payload/clear",
    "/collections/{name}/points/payload/delete",
    "/collections/{name}/points/query",
    "/collections/{name}/points/query/batch",
    "/collections/{name}/points/query/groups",
    "/collections/{name}/points/recommend",
    "/collections/{name}/points/recommend/batch",
    "/collections/{name}/points/recommend/groups",
    "/collections/{name}/points/scroll",
    "/collections/{name}/points/search",
    "/collections/{name}/points/search/batch",
    "/collections/{name}/points/search/groups",
    "/collections/{name}/points/search/matrix/offsets",
    "/collections/{name}/points/search/matrix/pairs",
    "/collections/{name}/points/vectors",
    "/collections/{name}/points/vectors/delete",
];

/// Whitelist for GRPC endpoints in metrics output.
///
/// Contains selection of search, recommend, scroll and upsert endpoints.
///
/// This array *must* be sorted.
// Same binary-search requirement as `REST_ENDPOINT_WHITELIST` above.
const GRPC_ENDPOINT_WHITELIST: &[&str] = &[
    "/qdrant.Points/ClearPayload",
    "/qdrant.Points/Count",
    "/qdrant.Points/Delete",
    "/qdrant.Points/DeletePayload",
    "/qdrant.Points/Discover",
    "/qdrant.Points/DiscoverBatch",
    "/qdrant.Points/Facet",
    "/qdrant.Points/Get",
    "/qdrant.Points/OverwritePayload",
    "/qdrant.Points/Query",
    "/qdrant.Points/QueryBatch",
    "/qdrant.Points/QueryGroups",
    "/qdrant.Points/Recommend",
    "/qdrant.Points/RecommendBatch",
    "/qdrant.Points/RecommendGroups",
    "/qdrant.Points/Scroll",
    "/qdrant.Points/Search",
    "/qdrant.Points/SearchBatch",
    "/qdrant.Points/SearchGroups",
    "/qdrant.Points/SetPayload",
    "/qdrant.Points/UpdateBatch",
    "/qdrant.Points/UpdateVectors",
    "/qdrant.Points/Upsert",
];

/// For REST requests, only report timings when having this HTTP response status.
const REST_TIMINGS_FOR_STATUS: u16 = 200;
|
82 |
+
|
83 |
+
/// Encapsulates metrics data in Prometheus format.
|
84 |
+
pub struct MetricsData {
|
85 |
+
metrics: Vec<MetricFamily>,
|
86 |
+
}
|
87 |
+
|
88 |
+
impl MetricsData {
|
89 |
+
pub fn format_metrics(&self) -> String {
|
90 |
+
TextEncoder::new().encode_to_string(&self.metrics).unwrap()
|
91 |
+
}
|
92 |
+
}
|
93 |
+
|
94 |
+
impl From<TelemetryData> for MetricsData {
|
95 |
+
fn from(telemetry_data: TelemetryData) -> Self {
|
96 |
+
let mut metrics = vec![];
|
97 |
+
telemetry_data.add_metrics(&mut metrics);
|
98 |
+
Self { metrics }
|
99 |
+
}
|
100 |
+
}
|
101 |
+
|
102 |
+
/// Types that can contribute Prometheus metric families to the output.
trait MetricsProvider {
    /// Add metrics definitions for this.
    fn add_metrics(&self, metrics: &mut Vec<MetricFamily>);
}
|
106 |
+
|
107 |
+
impl MetricsProvider for TelemetryData {
|
108 |
+
fn add_metrics(&self, metrics: &mut Vec<MetricFamily>) {
|
109 |
+
self.app.add_metrics(metrics);
|
110 |
+
self.collections.add_metrics(metrics);
|
111 |
+
self.cluster.add_metrics(metrics);
|
112 |
+
self.requests.add_metrics(metrics);
|
113 |
+
if let Some(mem) = &self.memory {
|
114 |
+
mem.add_metrics(metrics);
|
115 |
+
}
|
116 |
+
}
|
117 |
+
}
|
118 |
+
|
119 |
+
impl MetricsProvider for AppBuildTelemetry {
|
120 |
+
fn add_metrics(&self, metrics: &mut Vec<MetricFamily>) {
|
121 |
+
metrics.push(metric_family(
|
122 |
+
"app_info",
|
123 |
+
"information about qdrant server",
|
124 |
+
MetricType::GAUGE,
|
125 |
+
vec![gauge(
|
126 |
+
1.0,
|
127 |
+
&[("name", &self.name), ("version", &self.version)],
|
128 |
+
)],
|
129 |
+
));
|
130 |
+
self.features.iter().for_each(|f| f.add_metrics(metrics));
|
131 |
+
}
|
132 |
+
}
|
133 |
+
|
134 |
+
impl MetricsProvider for AppFeaturesTelemetry {
|
135 |
+
fn add_metrics(&self, metrics: &mut Vec<MetricFamily>) {
|
136 |
+
metrics.push(metric_family(
|
137 |
+
"app_status_recovery_mode",
|
138 |
+
"features enabled in qdrant server",
|
139 |
+
MetricType::GAUGE,
|
140 |
+
vec![gauge(if self.recovery_mode { 1.0 } else { 0.0 }, &[])],
|
141 |
+
))
|
142 |
+
}
|
143 |
+
}
|
144 |
+
|
145 |
+
impl MetricsProvider for CollectionsTelemetry {
|
146 |
+
fn add_metrics(&self, metrics: &mut Vec<MetricFamily>) {
|
147 |
+
let vector_count = self
|
148 |
+
.collections
|
149 |
+
.iter()
|
150 |
+
.flatten()
|
151 |
+
.map(|p| match p {
|
152 |
+
CollectionTelemetryEnum::Aggregated(a) => a.vectors,
|
153 |
+
CollectionTelemetryEnum::Full(c) => c.count_vectors(),
|
154 |
+
})
|
155 |
+
.sum::<usize>();
|
156 |
+
metrics.push(metric_family(
|
157 |
+
"collections_total",
|
158 |
+
"number of collections",
|
159 |
+
MetricType::GAUGE,
|
160 |
+
vec![gauge(self.number_of_collections as f64, &[])],
|
161 |
+
));
|
162 |
+
metrics.push(metric_family(
|
163 |
+
"collections_vector_total",
|
164 |
+
"total number of vectors in all collections",
|
165 |
+
MetricType::GAUGE,
|
166 |
+
vec![gauge(vector_count as f64, &[])],
|
167 |
+
));
|
168 |
+
}
|
169 |
+
}
|
170 |
+
|
171 |
+
impl MetricsProvider for ClusterTelemetry {
|
172 |
+
fn add_metrics(&self, metrics: &mut Vec<MetricFamily>) {
|
173 |
+
let ClusterTelemetry {
|
174 |
+
enabled,
|
175 |
+
status,
|
176 |
+
config: _,
|
177 |
+
peers: _,
|
178 |
+
metadata: _,
|
179 |
+
} = self;
|
180 |
+
|
181 |
+
metrics.push(metric_family(
|
182 |
+
"cluster_enabled",
|
183 |
+
"is cluster support enabled",
|
184 |
+
MetricType::GAUGE,
|
185 |
+
vec![gauge(if *enabled { 1.0 } else { 0.0 }, &[])],
|
186 |
+
));
|
187 |
+
|
188 |
+
if let Some(ref status) = status {
|
189 |
+
status.add_metrics(metrics);
|
190 |
+
}
|
191 |
+
}
|
192 |
+
}
|
193 |
+
|
194 |
+
impl MetricsProvider for ClusterStatusTelemetry {
|
195 |
+
fn add_metrics(&self, metrics: &mut Vec<MetricFamily>) {
|
196 |
+
metrics.push(metric_family(
|
197 |
+
"cluster_peers_total",
|
198 |
+
"total number of cluster peers",
|
199 |
+
MetricType::GAUGE,
|
200 |
+
vec![gauge(self.number_of_peers as f64, &[])],
|
201 |
+
));
|
202 |
+
metrics.push(metric_family(
|
203 |
+
"cluster_term",
|
204 |
+
"current cluster term",
|
205 |
+
MetricType::COUNTER,
|
206 |
+
vec![counter(self.term as f64, &[])],
|
207 |
+
));
|
208 |
+
|
209 |
+
if let Some(ref peer_id) = self.peer_id.map(|p| p.to_string()) {
|
210 |
+
metrics.push(metric_family(
|
211 |
+
"cluster_commit",
|
212 |
+
"index of last committed (finalized) operation cluster peer is aware of",
|
213 |
+
MetricType::COUNTER,
|
214 |
+
vec![counter(self.commit as f64, &[("peer_id", peer_id)])],
|
215 |
+
));
|
216 |
+
metrics.push(metric_family(
|
217 |
+
"cluster_pending_operations_total",
|
218 |
+
"total number of pending operations for cluster peer",
|
219 |
+
MetricType::GAUGE,
|
220 |
+
vec![gauge(self.pending_operations as f64, &[])],
|
221 |
+
));
|
222 |
+
metrics.push(metric_family(
|
223 |
+
"cluster_voter",
|
224 |
+
"is cluster peer a voter or learner",
|
225 |
+
MetricType::GAUGE,
|
226 |
+
vec![gauge(if self.is_voter { 1.0 } else { 0.0 }, &[])],
|
227 |
+
));
|
228 |
+
}
|
229 |
+
}
|
230 |
+
}
|
231 |
+
|
232 |
+
impl MetricsProvider for RequestsTelemetry {
    /// Forward to both the REST and gRPC request telemetry providers.
    fn add_metrics(&self, metrics: &mut Vec<MetricFamily>) {
        self.rest.add_metrics(metrics);
        self.grpc.add_metrics(metrics);
    }
}
|
238 |
+
|
239 |
+
impl MetricsProvider for WebApiTelemetry {
|
240 |
+
fn add_metrics(&self, metrics: &mut Vec<MetricFamily>) {
|
241 |
+
let mut builder = OperationDurationMetricsBuilder::default();
|
242 |
+
for (endpoint, responses) in &self.responses {
|
243 |
+
let Some((method, endpoint)) = endpoint.split_once(' ') else {
|
244 |
+
continue;
|
245 |
+
};
|
246 |
+
// Endpoint must be whitelisted
|
247 |
+
if REST_ENDPOINT_WHITELIST.binary_search(&endpoint).is_err() {
|
248 |
+
continue;
|
249 |
+
}
|
250 |
+
for (status, stats) in responses {
|
251 |
+
builder.add(
|
252 |
+
stats,
|
253 |
+
&[
|
254 |
+
("method", method),
|
255 |
+
("endpoint", endpoint),
|
256 |
+
("status", &status.to_string()),
|
257 |
+
],
|
258 |
+
*status == REST_TIMINGS_FOR_STATUS,
|
259 |
+
);
|
260 |
+
}
|
261 |
+
}
|
262 |
+
builder.build("rest", metrics);
|
263 |
+
}
|
264 |
+
}
|
265 |
+
|
266 |
+
impl MetricsProvider for GrpcTelemetry {
|
267 |
+
fn add_metrics(&self, metrics: &mut Vec<MetricFamily>) {
|
268 |
+
let mut builder = OperationDurationMetricsBuilder::default();
|
269 |
+
for (endpoint, stats) in &self.responses {
|
270 |
+
// Endpoint must be whitelisted
|
271 |
+
if GRPC_ENDPOINT_WHITELIST
|
272 |
+
.binary_search(&endpoint.as_str())
|
273 |
+
.is_err()
|
274 |
+
{
|
275 |
+
continue;
|
276 |
+
}
|
277 |
+
builder.add(stats, &[("endpoint", endpoint.as_str())], true);
|
278 |
+
}
|
279 |
+
builder.build("grpc", metrics);
|
280 |
+
}
|
281 |
+
}
|
282 |
+
|
283 |
+
impl MetricsProvider for MemoryTelemetry {
|
284 |
+
fn add_metrics(&self, metrics: &mut Vec<MetricFamily>) {
|
285 |
+
metrics.push(metric_family(
|
286 |
+
"memory_active_bytes",
|
287 |
+
"Total number of bytes in active pages allocated by the application",
|
288 |
+
MetricType::GAUGE,
|
289 |
+
vec![gauge(self.active_bytes as f64, &[])],
|
290 |
+
));
|
291 |
+
metrics.push(metric_family(
|
292 |
+
"memory_allocated_bytes",
|
293 |
+
"Total number of bytes allocated by the application",
|
294 |
+
MetricType::GAUGE,
|
295 |
+
vec![gauge(self.allocated_bytes as f64, &[])],
|
296 |
+
));
|
297 |
+
metrics.push(metric_family(
|
298 |
+
"memory_metadata_bytes",
|
299 |
+
"Total number of bytes dedicated to metadata",
|
300 |
+
MetricType::GAUGE,
|
301 |
+
vec![gauge(self.metadata_bytes as f64, &[])],
|
302 |
+
));
|
303 |
+
metrics.push(metric_family(
|
304 |
+
"memory_resident_bytes",
|
305 |
+
"Maximum number of bytes in physically resident data pages mapped",
|
306 |
+
MetricType::GAUGE,
|
307 |
+
vec![gauge(self.resident_bytes as f64, &[])],
|
308 |
+
));
|
309 |
+
metrics.push(metric_family(
|
310 |
+
"memory_retained_bytes",
|
311 |
+
"Total number of bytes in virtual memory mappings",
|
312 |
+
MetricType::GAUGE,
|
313 |
+
vec![gauge(self.retained_bytes as f64, &[])],
|
314 |
+
));
|
315 |
+
}
|
316 |
+
}
|
317 |
+
|
318 |
+
/// A helper struct to build a vector of [`MetricFamily`] out of a collection of
/// [`OperationDurationStatistics`].
#[derive(Default)]
struct OperationDurationMetricsBuilder {
    // Per-label-set total response counters.
    total: Vec<Metric>,
    // Per-label-set failed response counters.
    fail_total: Vec<Metric>,
    // Average / minimum / maximum response durations, in seconds.
    avg_secs: Vec<Metric>,
    min_secs: Vec<Metric>,
    max_secs: Vec<Metric>,
    // Full response-duration histograms, in seconds.
    duration_histogram_secs: Vec<Metric>,
}
|
329 |
+
|
330 |
+
impl OperationDurationMetricsBuilder {
|
331 |
+
/// Add metrics for the provided statistics.
|
332 |
+
/// If `add_timings` is `false`, only the total and fail_total counters will be added.
|
333 |
+
pub fn add(
|
334 |
+
&mut self,
|
335 |
+
stat: &OperationDurationStatistics,
|
336 |
+
labels: &[(&str, &str)],
|
337 |
+
add_timings: bool,
|
338 |
+
) {
|
339 |
+
self.total.push(counter(stat.count as f64, labels));
|
340 |
+
self.fail_total
|
341 |
+
.push(counter(stat.fail_count as f64, labels));
|
342 |
+
|
343 |
+
if !add_timings {
|
344 |
+
return;
|
345 |
+
}
|
346 |
+
|
347 |
+
self.avg_secs.push(gauge(
|
348 |
+
f64::from(stat.avg_duration_micros.unwrap_or(0.0)) / 1_000_000.0,
|
349 |
+
labels,
|
350 |
+
));
|
351 |
+
self.min_secs.push(gauge(
|
352 |
+
f64::from(stat.min_duration_micros.unwrap_or(0.0)) / 1_000_000.0,
|
353 |
+
labels,
|
354 |
+
));
|
355 |
+
self.max_secs.push(gauge(
|
356 |
+
f64::from(stat.max_duration_micros.unwrap_or(0.0)) / 1_000_000.0,
|
357 |
+
labels,
|
358 |
+
));
|
359 |
+
self.duration_histogram_secs.push(histogram(
|
360 |
+
stat.count as u64,
|
361 |
+
stat.total_duration_micros as f64 / 1_000_000.0,
|
362 |
+
&stat
|
363 |
+
.duration_micros_histogram
|
364 |
+
.iter()
|
365 |
+
.map(|&(b, c)| (f64::from(b) / 1_000_000.0, c as u64))
|
366 |
+
.collect::<Vec<_>>(),
|
367 |
+
labels,
|
368 |
+
));
|
369 |
+
}
|
370 |
+
|
371 |
+
/// Build metrics and add them to the provided vector.
|
372 |
+
pub fn build(self, prefix: &str, metrics: &mut Vec<MetricFamily>) {
|
373 |
+
if !self.total.is_empty() {
|
374 |
+
metrics.push(metric_family(
|
375 |
+
&format!("{prefix}_responses_total"),
|
376 |
+
"total number of responses",
|
377 |
+
MetricType::COUNTER,
|
378 |
+
self.total,
|
379 |
+
));
|
380 |
+
}
|
381 |
+
if !self.fail_total.is_empty() {
|
382 |
+
metrics.push(metric_family(
|
383 |
+
&format!("{prefix}_responses_fail_total"),
|
384 |
+
"total number of failed responses",
|
385 |
+
MetricType::COUNTER,
|
386 |
+
self.fail_total,
|
387 |
+
));
|
388 |
+
}
|
389 |
+
if !self.avg_secs.is_empty() {
|
390 |
+
metrics.push(metric_family(
|
391 |
+
&format!("{prefix}_responses_avg_duration_seconds"),
|
392 |
+
"average response duration",
|
393 |
+
MetricType::GAUGE,
|
394 |
+
self.avg_secs,
|
395 |
+
));
|
396 |
+
}
|
397 |
+
if !self.min_secs.is_empty() {
|
398 |
+
metrics.push(metric_family(
|
399 |
+
&format!("{prefix}_responses_min_duration_seconds"),
|
400 |
+
"minimum response duration",
|
401 |
+
MetricType::GAUGE,
|
402 |
+
self.min_secs,
|
403 |
+
));
|
404 |
+
}
|
405 |
+
if !self.max_secs.is_empty() {
|
406 |
+
metrics.push(metric_family(
|
407 |
+
&format!("{prefix}_responses_max_duration_seconds"),
|
408 |
+
"maximum response duration",
|
409 |
+
MetricType::GAUGE,
|
410 |
+
self.max_secs,
|
411 |
+
));
|
412 |
+
}
|
413 |
+
if !self.duration_histogram_secs.is_empty() {
|
414 |
+
metrics.push(metric_family(
|
415 |
+
&format!("{prefix}_responses_duration_seconds"),
|
416 |
+
"response duration histogram",
|
417 |
+
MetricType::HISTOGRAM,
|
418 |
+
self.duration_histogram_secs,
|
419 |
+
));
|
420 |
+
}
|
421 |
+
}
|
422 |
+
}
|
423 |
+
|
424 |
+
fn metric_family(name: &str, help: &str, r#type: MetricType, metrics: Vec<Metric>) -> MetricFamily {
|
425 |
+
let mut metric_family = MetricFamily::default();
|
426 |
+
metric_family.set_name(name.into());
|
427 |
+
metric_family.set_help(help.into());
|
428 |
+
metric_family.set_field_type(r#type);
|
429 |
+
metric_family.set_metric(metrics);
|
430 |
+
metric_family
|
431 |
+
}
|
432 |
+
|
433 |
+
fn counter(value: f64, labels: &[(&str, &str)]) -> Metric {
|
434 |
+
let mut metric = Metric::default();
|
435 |
+
metric.set_label(labels.iter().map(|(n, v)| label_pair(n, v)).collect());
|
436 |
+
metric.set_counter({
|
437 |
+
let mut counter = Counter::default();
|
438 |
+
counter.set_value(value);
|
439 |
+
counter
|
440 |
+
});
|
441 |
+
metric
|
442 |
+
}
|
443 |
+
|
444 |
+
fn gauge(value: f64, labels: &[(&str, &str)]) -> Metric {
|
445 |
+
let mut metric = Metric::default();
|
446 |
+
metric.set_label(labels.iter().map(|(n, v)| label_pair(n, v)).collect());
|
447 |
+
metric.set_gauge({
|
448 |
+
let mut gauge = Gauge::default();
|
449 |
+
gauge.set_value(value);
|
450 |
+
gauge
|
451 |
+
});
|
452 |
+
metric
|
453 |
+
}
|
454 |
+
|
455 |
+
fn histogram(
|
456 |
+
sample_count: u64,
|
457 |
+
sample_sum: f64,
|
458 |
+
buckets: &[(f64, u64)],
|
459 |
+
labels: &[(&str, &str)],
|
460 |
+
) -> Metric {
|
461 |
+
let mut metric = Metric::default();
|
462 |
+
metric.set_label(labels.iter().map(|(n, v)| label_pair(n, v)).collect());
|
463 |
+
metric.set_histogram({
|
464 |
+
let mut histogram = prometheus::proto::Histogram::default();
|
465 |
+
histogram.set_sample_count(sample_count);
|
466 |
+
histogram.set_sample_sum(sample_sum);
|
467 |
+
histogram.set_bucket(
|
468 |
+
buckets
|
469 |
+
.iter()
|
470 |
+
.map(|&(upper_bound, cumulative_count)| {
|
471 |
+
let mut bucket = prometheus::proto::Bucket::default();
|
472 |
+
bucket.set_cumulative_count(cumulative_count);
|
473 |
+
bucket.set_upper_bound(upper_bound);
|
474 |
+
bucket
|
475 |
+
})
|
476 |
+
.collect(),
|
477 |
+
);
|
478 |
+
histogram
|
479 |
+
});
|
480 |
+
metric
|
481 |
+
}
|
482 |
+
|
483 |
+
fn label_pair(name: &str, value: &str) -> LabelPair {
|
484 |
+
let mut label = LabelPair::default();
|
485 |
+
label.set_name(name.into());
|
486 |
+
label.set_value(value.into());
|
487 |
+
label
|
488 |
+
}
|
489 |
+
|
490 |
+
#[cfg(test)]
mod tests {
    #[test]
    fn test_endpoint_whitelists_sorted() {
        use super::{GRPC_ENDPOINT_WHITELIST, REST_ENDPOINT_WHITELIST};

        // `binary_search` is only correct on sorted slices, so enforce the
        // ordering of both whitelists here.
        fn is_sorted(list: &[&str]) -> bool {
            list.windows(2).all(|pair| pair[0] <= pair[1])
        }

        assert!(
            is_sorted(REST_ENDPOINT_WHITELIST),
            "REST_ENDPOINT_WHITELIST must be sorted in code to allow binary search"
        );
        assert!(
            is_sorted(GRPC_ENDPOINT_WHITELIST),
            "GRPC_ENDPOINT_WHITELIST must be sorted in code to allow binary search"
        );
    }
}
|
src/common/mod.rs
ADDED
@@ -0,0 +1,31 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#[allow(dead_code)] // May contain functions used in different binaries. Not actually dead
pub mod collections;
#[allow(dead_code)] // May contain functions used in different binaries. Not actually dead
pub mod error_reporting;
#[allow(dead_code)] // May contain functions used in different binaries. Not actually dead
pub mod health;
#[allow(dead_code)] // May contain functions used in different binaries. Not actually dead
pub mod helpers;
pub mod http_client;
pub mod metrics;
#[allow(dead_code)] // May contain functions used in different binaries. Not actually dead
pub mod points;
pub mod snapshots;
#[allow(dead_code)] // May contain functions used in different binaries. Not actually dead
pub mod stacktrace;
#[allow(dead_code)] // May contain functions used in different binaries. Not actually dead
pub mod telemetry;
pub mod telemetry_ops;
#[allow(dead_code)] // May contain functions used in different binaries. Not actually dead
pub mod telemetry_reporting;

pub mod auth;

pub mod strings;

pub mod debugger;

#[allow(dead_code)] // May contain functions used in different binaries. Not actually dead
pub mod inference;

pub mod pyroscope_state;
|
src/common/points.rs
ADDED
@@ -0,0 +1,1175 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
use std::sync::Arc;
|
2 |
+
use std::time::Duration;
|
3 |
+
|
4 |
+
use api::rest::schema::{PointInsertOperations, PointsBatch, PointsList};
|
5 |
+
use api::rest::{SearchGroupsRequestInternal, ShardKeySelector, UpdateVectors};
|
6 |
+
use collection::collection::distance_matrix::{
|
7 |
+
CollectionSearchMatrixRequest, CollectionSearchMatrixResponse,
|
8 |
+
};
|
9 |
+
use collection::collection::Collection;
|
10 |
+
use collection::common::batching::batch_requests;
|
11 |
+
use collection::grouping::group_by::GroupRequest;
|
12 |
+
use collection::operations::consistency_params::ReadConsistency;
|
13 |
+
use collection::operations::payload_ops::{
|
14 |
+
DeletePayload, DeletePayloadOp, PayloadOps, SetPayload, SetPayloadOp,
|
15 |
+
};
|
16 |
+
use collection::operations::point_ops::{
|
17 |
+
FilterSelector, PointIdsList, PointInsertOperationsInternal, PointOperations, PointsSelector,
|
18 |
+
WriteOrdering,
|
19 |
+
};
|
20 |
+
use collection::operations::shard_selector_internal::ShardSelectorInternal;
|
21 |
+
use collection::operations::types::{
|
22 |
+
CollectionError, CoreSearchRequest, CoreSearchRequestBatch, CountRequestInternal, CountResult,
|
23 |
+
DiscoverRequestBatch, GroupsResult, PointRequestInternal, RecommendGroupsRequestInternal,
|
24 |
+
RecordInternal, ScrollRequestInternal, ScrollResult, UpdateResult,
|
25 |
+
};
|
26 |
+
use collection::operations::universal_query::collection_query::{
|
27 |
+
CollectionQueryGroupsRequest, CollectionQueryRequest,
|
28 |
+
};
|
29 |
+
use collection::operations::vector_ops::{DeleteVectors, UpdateVectorsOp, VectorOperations};
|
30 |
+
use collection::operations::verification::{
|
31 |
+
new_unchecked_verification_pass, StrictModeVerification,
|
32 |
+
};
|
33 |
+
use collection::operations::{
|
34 |
+
ClockTag, CollectionUpdateOperations, CreateIndex, FieldIndexOperations, OperationWithClockTag,
|
35 |
+
};
|
36 |
+
use collection::shards::shard::ShardId;
|
37 |
+
use common::counter::hardware_accumulator::HwMeasurementAcc;
|
38 |
+
use schemars::JsonSchema;
|
39 |
+
use segment::json_path::JsonPath;
|
40 |
+
use segment::types::{PayloadFieldSchema, PayloadKeyType, ScoredPoint, StrictModeConfig};
|
41 |
+
use serde::{Deserialize, Serialize};
|
42 |
+
use storage::content_manager::collection_meta_ops::{
|
43 |
+
CollectionMetaOperations, CreatePayloadIndex, DropPayloadIndex,
|
44 |
+
};
|
45 |
+
use storage::content_manager::errors::StorageError;
|
46 |
+
use storage::content_manager::toc::TableOfContent;
|
47 |
+
use storage::dispatcher::Dispatcher;
|
48 |
+
use storage::rbac::Access;
|
49 |
+
use validator::Validate;
|
50 |
+
|
51 |
+
use crate::common::inference::service::InferenceType;
|
52 |
+
use crate::common::inference::update_requests::{
|
53 |
+
convert_batch, convert_point_struct, convert_point_vectors,
|
54 |
+
};
|
55 |
+
|
56 |
+
/// REST payload for creating an index on a payload field.
///
/// `field_schema` may be omitted by older clients; the legacy key
/// `field_type` is still accepted via the serde alias below.
#[derive(Debug, Deserialize, Serialize, JsonSchema, Validate)]
pub struct CreateFieldIndex {
    // Payload key to build the index on.
    pub field_name: PayloadKeyType,
    // Schema of the indexed field; `None` means the caller did not specify one.
    #[serde(alias = "field_type")]
    pub field_schema: Option<PayloadFieldSchema>,
}
|
62 |
+
|
63 |
+
/// Batch-update entry: upsert points (insert or overwrite by id).
#[derive(Deserialize, Serialize, JsonSchema, Validate)]
pub struct UpsertOperation {
    #[validate(nested)]
    upsert: PointInsertOperations,
}

/// Batch-update entry: delete points selected by ids or filter.
#[derive(Deserialize, Serialize, JsonSchema, Validate)]
pub struct DeleteOperation {
    #[validate(nested)]
    delete: PointsSelector,
}

/// Batch-update entry: merge the given payload into selected points.
#[derive(Deserialize, Serialize, JsonSchema, Validate)]
pub struct SetPayloadOperation {
    #[validate(nested)]
    set_payload: SetPayload,
}

/// Batch-update entry: replace the whole payload of selected points.
#[derive(Deserialize, Serialize, JsonSchema, Validate)]
pub struct OverwritePayloadOperation {
    #[validate(nested)]
    overwrite_payload: SetPayload,
}

/// Batch-update entry: delete specific payload keys from selected points.
#[derive(Deserialize, Serialize, JsonSchema, Validate)]
pub struct DeletePayloadOperation {
    #[validate(nested)]
    delete_payload: DeletePayload,
}

/// Batch-update entry: remove the entire payload of selected points.
#[derive(Deserialize, Serialize, JsonSchema, Validate)]
pub struct ClearPayloadOperation {
    #[validate(nested)]
    clear_payload: PointsSelector,
}

/// Batch-update entry: overwrite named vectors of existing points.
#[derive(Deserialize, Serialize, JsonSchema, Validate)]
pub struct UpdateVectorsOperation {
    #[validate(nested)]
    update_vectors: UpdateVectors,
}

/// Batch-update entry: delete named vectors from selected points.
#[derive(Deserialize, Serialize, JsonSchema, Validate)]
pub struct DeleteVectorsOperation {
    #[validate(nested)]
    delete_vectors: DeleteVectors,
}
|
110 |
+
|
111 |
+
/// One entry of a batch update request.
///
/// `untagged` means the variant is selected by whichever single wrapper key
/// is present in the JSON object (e.g. `{"upsert": …}`, `{"delete": …}`).
// NOTE(review): `rename_all` appears to have no observable effect combined
// with `untagged` (variant names are never serialized) — presumably kept for
// stylistic consistency; confirm before removing.
#[derive(Deserialize, Serialize, JsonSchema)]
#[serde(rename_all = "snake_case")]
#[serde(untagged)]
pub enum UpdateOperation {
    Upsert(UpsertOperation),
    Delete(DeleteOperation),
    SetPayload(SetPayloadOperation),
    OverwritePayload(OverwritePayloadOperation),
    DeletePayload(DeletePayloadOperation),
    ClearPayload(ClearPayloadOperation),
    UpdateVectors(UpdateVectorsOperation),
    DeleteVectors(DeleteVectorsOperation),
}
|
124 |
+
|
125 |
+
/// Body of the batch-update endpoint: a sequence of operations that are
/// applied in request order.
#[derive(Deserialize, Serialize, JsonSchema, Validate)]
pub struct UpdateOperations {
    // NOTE(review): the field has no `#[validate(nested)]`, so the derived
    // `Validate` impl does not recurse into the individual operations here;
    // presumably each operation is validated separately by the handler —
    // confirm against the call sites.
    pub operations: Vec<UpdateOperation>,
}
|
129 |
+
|
130 |
+
/// Manual `Validate` impl: an untagged enum cannot use the derive, so each
/// variant simply delegates to the derived impl of its wrapper struct.
impl Validate for UpdateOperation {
    fn validate(&self) -> Result<(), validator::ValidationErrors> {
        match self {
            UpdateOperation::Upsert(op) => op.validate(),
            UpdateOperation::Delete(op) => op.validate(),
            UpdateOperation::SetPayload(op) => op.validate(),
            UpdateOperation::OverwritePayload(op) => op.validate(),
            UpdateOperation::DeletePayload(op) => op.validate(),
            UpdateOperation::ClearPayload(op) => op.validate(),
            UpdateOperation::UpdateVectors(op) => op.validate(),
            UpdateOperation::DeleteVectors(op) => op.validate(),
        }
    }
}
|
144 |
+
|
145 |
+
impl StrictModeVerification for UpdateOperation {
    // Update operations carry no search-style limits or parameters, so the
    // generic accessors all report "nothing to check" and the real work
    // happens in `check_strict_mode` below.
    fn query_limit(&self) -> Option<usize> {
        None
    }

    fn indexed_filter_read(&self) -> Option<&segment::types::Filter> {
        None
    }

    fn indexed_filter_write(&self) -> Option<&segment::types::Filter> {
        None
    }

    fn request_exact(&self) -> Option<bool> {
        None
    }

    fn request_search_params(&self) -> Option<&segment::types::SearchParams> {
        None
    }

    /// Delegate strict-mode checking to the wrapped inner request of each
    /// variant.
    fn check_strict_mode(
        &self,
        collection: &Collection,
        strict_mode_config: &StrictModeConfig,
    ) -> Result<(), CollectionError> {
        match self {
            UpdateOperation::Delete(delete_op) => delete_op
                .delete
                .check_strict_mode(collection, strict_mode_config),
            UpdateOperation::SetPayload(set_payload) => set_payload
                .set_payload
                .check_strict_mode(collection, strict_mode_config),
            UpdateOperation::OverwritePayload(overwrite_payload) => overwrite_payload
                .overwrite_payload
                .check_strict_mode(collection, strict_mode_config),
            UpdateOperation::DeletePayload(delete_payload) => delete_payload
                .delete_payload
                .check_strict_mode(collection, strict_mode_config),
            UpdateOperation::ClearPayload(clear_payload) => clear_payload
                .clear_payload
                .check_strict_mode(collection, strict_mode_config),
            UpdateOperation::DeleteVectors(delete_op) => delete_op
                .delete_vectors
                .check_strict_mode(collection, strict_mode_config),
            // Upserts and vector updates expose no filter to verify here.
            UpdateOperation::UpdateVectors(_) | UpdateOperation::Upsert(_) => Ok(()),
        }
    }
}
|
194 |
+
|
195 |
+
/// Converts a pair of parameters into a shard selector
|
196 |
+
/// suitable for update operations.
|
197 |
+
///
|
198 |
+
/// The key difference from selector for search operations is that
|
199 |
+
/// empty shard selector in case of update means default shard,
|
200 |
+
/// while empty shard selector in case of search means all shards.
|
201 |
+
///
|
202 |
+
/// Parameters:
|
203 |
+
/// - shard_selection: selection of the exact shard ID, always have priority over shard_key
|
204 |
+
/// - shard_key: selection of the shard key, can be a single key or a list of keys
|
205 |
+
///
|
206 |
+
/// Returns:
|
207 |
+
/// - ShardSelectorInternal - resolved shard selector
|
208 |
+
fn get_shard_selector_for_update(
|
209 |
+
shard_selection: Option<ShardId>,
|
210 |
+
shard_key: Option<ShardKeySelector>,
|
211 |
+
) -> ShardSelectorInternal {
|
212 |
+
match (shard_selection, shard_key) {
|
213 |
+
(Some(shard_selection), None) => ShardSelectorInternal::ShardId(shard_selection),
|
214 |
+
(Some(shard_selection), Some(_)) => {
|
215 |
+
debug_assert!(
|
216 |
+
false,
|
217 |
+
"Shard selection and shard key are mutually exclusive"
|
218 |
+
);
|
219 |
+
ShardSelectorInternal::ShardId(shard_selection)
|
220 |
+
}
|
221 |
+
(None, Some(shard_key)) => ShardSelectorInternal::from(shard_key),
|
222 |
+
(None, None) => ShardSelectorInternal::Empty,
|
223 |
+
}
|
224 |
+
}
|
225 |
+
|
226 |
+
#[allow(clippy::too_many_arguments)]
|
227 |
+
pub async fn do_upsert_points(
|
228 |
+
toc: Arc<TableOfContent>,
|
229 |
+
collection_name: String,
|
230 |
+
operation: PointInsertOperations,
|
231 |
+
clock_tag: Option<ClockTag>,
|
232 |
+
shard_selection: Option<ShardId>,
|
233 |
+
wait: bool,
|
234 |
+
ordering: WriteOrdering,
|
235 |
+
access: Access,
|
236 |
+
) -> Result<UpdateResult, StorageError> {
|
237 |
+
let (shard_key, operation) = match operation {
|
238 |
+
PointInsertOperations::PointsBatch(PointsBatch { batch, shard_key }) => (
|
239 |
+
shard_key,
|
240 |
+
PointInsertOperationsInternal::PointsBatch(convert_batch(batch).await?),
|
241 |
+
),
|
242 |
+
PointInsertOperations::PointsList(PointsList { points, shard_key }) => (
|
243 |
+
shard_key,
|
244 |
+
PointInsertOperationsInternal::PointsList(
|
245 |
+
convert_point_struct(points, InferenceType::Update).await?,
|
246 |
+
),
|
247 |
+
),
|
248 |
+
};
|
249 |
+
|
250 |
+
let collection_operation =
|
251 |
+
CollectionUpdateOperations::PointOperation(PointOperations::UpsertPoints(operation));
|
252 |
+
|
253 |
+
let shard_selector = get_shard_selector_for_update(shard_selection, shard_key);
|
254 |
+
|
255 |
+
toc.update(
|
256 |
+
&collection_name,
|
257 |
+
OperationWithClockTag::new(collection_operation, clock_tag),
|
258 |
+
wait,
|
259 |
+
ordering,
|
260 |
+
shard_selector,
|
261 |
+
access,
|
262 |
+
)
|
263 |
+
.await
|
264 |
+
}
|
265 |
+
|
266 |
+
#[allow(clippy::too_many_arguments)]
|
267 |
+
pub async fn do_delete_points(
|
268 |
+
toc: Arc<TableOfContent>,
|
269 |
+
collection_name: String,
|
270 |
+
points: PointsSelector,
|
271 |
+
clock_tag: Option<ClockTag>,
|
272 |
+
shard_selection: Option<ShardId>,
|
273 |
+
wait: bool,
|
274 |
+
ordering: WriteOrdering,
|
275 |
+
access: Access,
|
276 |
+
) -> Result<UpdateResult, StorageError> {
|
277 |
+
let (point_operation, shard_key) = match points {
|
278 |
+
PointsSelector::PointIdsSelector(PointIdsList { points, shard_key }) => {
|
279 |
+
(PointOperations::DeletePoints { ids: points }, shard_key)
|
280 |
+
}
|
281 |
+
PointsSelector::FilterSelector(FilterSelector { filter, shard_key }) => {
|
282 |
+
(PointOperations::DeletePointsByFilter(filter), shard_key)
|
283 |
+
}
|
284 |
+
};
|
285 |
+
let collection_operation = CollectionUpdateOperations::PointOperation(point_operation);
|
286 |
+
let shard_selector = get_shard_selector_for_update(shard_selection, shard_key);
|
287 |
+
|
288 |
+
toc.update(
|
289 |
+
&collection_name,
|
290 |
+
OperationWithClockTag::new(collection_operation, clock_tag),
|
291 |
+
wait,
|
292 |
+
ordering,
|
293 |
+
shard_selector,
|
294 |
+
access,
|
295 |
+
)
|
296 |
+
.await
|
297 |
+
}
|
298 |
+
|
299 |
+
#[allow(clippy::too_many_arguments)]
|
300 |
+
pub async fn do_update_vectors(
|
301 |
+
toc: Arc<TableOfContent>,
|
302 |
+
collection_name: String,
|
303 |
+
operation: UpdateVectors,
|
304 |
+
clock_tag: Option<ClockTag>,
|
305 |
+
shard_selection: Option<ShardId>,
|
306 |
+
wait: bool,
|
307 |
+
ordering: WriteOrdering,
|
308 |
+
access: Access,
|
309 |
+
) -> Result<UpdateResult, StorageError> {
|
310 |
+
let UpdateVectors { points, shard_key } = operation;
|
311 |
+
|
312 |
+
let persisted_points = convert_point_vectors(points, InferenceType::Update).await?;
|
313 |
+
|
314 |
+
let collection_operation = CollectionUpdateOperations::VectorOperation(
|
315 |
+
VectorOperations::UpdateVectors(UpdateVectorsOp {
|
316 |
+
points: persisted_points,
|
317 |
+
}),
|
318 |
+
);
|
319 |
+
|
320 |
+
let shard_selector = get_shard_selector_for_update(shard_selection, shard_key);
|
321 |
+
|
322 |
+
toc.update(
|
323 |
+
&collection_name,
|
324 |
+
OperationWithClockTag::new(collection_operation, clock_tag),
|
325 |
+
wait,
|
326 |
+
ordering,
|
327 |
+
shard_selector,
|
328 |
+
access,
|
329 |
+
)
|
330 |
+
.await
|
331 |
+
}
|
332 |
+
|
333 |
+
/// Delete named vectors from points selected by filter, by explicit IDs, or
/// both.
///
/// NOTE: when both `filter` and `points` are present, TWO separate update
/// operations are issued sequentially, and only the result of the LAST one
/// (the by-ids delete) is returned; the filter delete's result is discarded.
/// Returns a bad-request error if neither selector is provided.
#[allow(clippy::too_many_arguments)]
pub async fn do_delete_vectors(
    toc: Arc<TableOfContent>,
    collection_name: String,
    operation: DeleteVectors,
    clock_tag: Option<ClockTag>,
    shard_selection: Option<ShardId>,
    wait: bool,
    ordering: WriteOrdering,
    access: Access,
) -> Result<UpdateResult, StorageError> {
    // TODO: Is this cancel safe!?
    // (Two sequential updates — cancellation between them leaves only the
    // first applied.)

    let DeleteVectors {
        vector,
        filter,
        points,
        shard_key,
    } = operation;

    let vector_names: Vec<_> = vector.into_iter().collect();
    let shard_selector = get_shard_selector_for_update(shard_selection, shard_key);

    // Holds the result of the most recent update; stays `None` if neither
    // selector was given.
    let mut result = None;

    if let Some(filter) = filter {
        let vectors_operation =
            VectorOperations::DeleteVectorsByFilter(filter, vector_names.clone());

        let collection_operation = CollectionUpdateOperations::VectorOperation(vectors_operation);

        result = Some(
            toc.update(
                &collection_name,
                OperationWithClockTag::new(collection_operation, clock_tag),
                wait,
                ordering,
                shard_selector.clone(),
                access.clone(),
            )
            .await?,
        );
    }

    if let Some(points) = points {
        let vectors_operation = VectorOperations::DeleteVectors(points.into(), vector_names);
        let collection_operation = CollectionUpdateOperations::VectorOperation(vectors_operation);
        result = Some(
            toc.update(
                &collection_name,
                OperationWithClockTag::new(collection_operation, clock_tag),
                wait,
                ordering,
                shard_selector,
                access,
            )
            .await?,
        );
    }

    result.ok_or_else(|| StorageError::bad_request("No filter or points provided"))
}
|
395 |
+
|
396 |
+
#[allow(clippy::too_many_arguments)]
|
397 |
+
pub async fn do_set_payload(
|
398 |
+
toc: Arc<TableOfContent>,
|
399 |
+
collection_name: String,
|
400 |
+
operation: SetPayload,
|
401 |
+
clock_tag: Option<ClockTag>,
|
402 |
+
shard_selection: Option<ShardId>,
|
403 |
+
wait: bool,
|
404 |
+
ordering: WriteOrdering,
|
405 |
+
access: Access,
|
406 |
+
) -> Result<UpdateResult, StorageError> {
|
407 |
+
let SetPayload {
|
408 |
+
points,
|
409 |
+
payload,
|
410 |
+
filter,
|
411 |
+
shard_key,
|
412 |
+
key,
|
413 |
+
} = operation;
|
414 |
+
|
415 |
+
let collection_operation =
|
416 |
+
CollectionUpdateOperations::PayloadOperation(PayloadOps::SetPayload(SetPayloadOp {
|
417 |
+
payload,
|
418 |
+
points,
|
419 |
+
filter,
|
420 |
+
key,
|
421 |
+
}));
|
422 |
+
|
423 |
+
let shard_selector = get_shard_selector_for_update(shard_selection, shard_key);
|
424 |
+
|
425 |
+
toc.update(
|
426 |
+
&collection_name,
|
427 |
+
OperationWithClockTag::new(collection_operation, clock_tag),
|
428 |
+
wait,
|
429 |
+
ordering,
|
430 |
+
shard_selector,
|
431 |
+
access,
|
432 |
+
)
|
433 |
+
.await
|
434 |
+
}
|
435 |
+
|
436 |
+
#[allow(clippy::too_many_arguments)]
|
437 |
+
pub async fn do_overwrite_payload(
|
438 |
+
toc: Arc<TableOfContent>,
|
439 |
+
collection_name: String,
|
440 |
+
operation: SetPayload,
|
441 |
+
clock_tag: Option<ClockTag>,
|
442 |
+
shard_selection: Option<ShardId>,
|
443 |
+
wait: bool,
|
444 |
+
ordering: WriteOrdering,
|
445 |
+
access: Access,
|
446 |
+
) -> Result<UpdateResult, StorageError> {
|
447 |
+
let SetPayload {
|
448 |
+
points,
|
449 |
+
payload,
|
450 |
+
filter,
|
451 |
+
shard_key,
|
452 |
+
..
|
453 |
+
} = operation;
|
454 |
+
|
455 |
+
let collection_operation =
|
456 |
+
CollectionUpdateOperations::PayloadOperation(PayloadOps::OverwritePayload(SetPayloadOp {
|
457 |
+
payload,
|
458 |
+
points,
|
459 |
+
filter,
|
460 |
+
// overwrite operation doesn't support payload selector
|
461 |
+
key: None,
|
462 |
+
}));
|
463 |
+
|
464 |
+
let shard_selector = get_shard_selector_for_update(shard_selection, shard_key);
|
465 |
+
|
466 |
+
toc.update(
|
467 |
+
&collection_name,
|
468 |
+
OperationWithClockTag::new(collection_operation, clock_tag),
|
469 |
+
wait,
|
470 |
+
ordering,
|
471 |
+
shard_selector,
|
472 |
+
access,
|
473 |
+
)
|
474 |
+
.await
|
475 |
+
}
|
476 |
+
|
477 |
+
#[allow(clippy::too_many_arguments)]
|
478 |
+
pub async fn do_delete_payload(
|
479 |
+
toc: Arc<TableOfContent>,
|
480 |
+
collection_name: String,
|
481 |
+
operation: DeletePayload,
|
482 |
+
clock_tag: Option<ClockTag>,
|
483 |
+
shard_selection: Option<ShardId>,
|
484 |
+
wait: bool,
|
485 |
+
ordering: WriteOrdering,
|
486 |
+
access: Access,
|
487 |
+
) -> Result<UpdateResult, StorageError> {
|
488 |
+
let DeletePayload {
|
489 |
+
keys,
|
490 |
+
points,
|
491 |
+
filter,
|
492 |
+
shard_key,
|
493 |
+
} = operation;
|
494 |
+
|
495 |
+
let collection_operation =
|
496 |
+
CollectionUpdateOperations::PayloadOperation(PayloadOps::DeletePayload(DeletePayloadOp {
|
497 |
+
keys,
|
498 |
+
points,
|
499 |
+
filter,
|
500 |
+
}));
|
501 |
+
|
502 |
+
let shard_selector = get_shard_selector_for_update(shard_selection, shard_key);
|
503 |
+
|
504 |
+
toc.update(
|
505 |
+
&collection_name,
|
506 |
+
OperationWithClockTag::new(collection_operation, clock_tag),
|
507 |
+
wait,
|
508 |
+
ordering,
|
509 |
+
shard_selector,
|
510 |
+
access,
|
511 |
+
)
|
512 |
+
.await
|
513 |
+
}
|
514 |
+
|
515 |
+
#[allow(clippy::too_many_arguments)]
|
516 |
+
pub async fn do_clear_payload(
|
517 |
+
toc: Arc<TableOfContent>,
|
518 |
+
collection_name: String,
|
519 |
+
points: PointsSelector,
|
520 |
+
clock_tag: Option<ClockTag>,
|
521 |
+
shard_selection: Option<ShardId>,
|
522 |
+
wait: bool,
|
523 |
+
ordering: WriteOrdering,
|
524 |
+
access: Access,
|
525 |
+
) -> Result<UpdateResult, StorageError> {
|
526 |
+
let (point_operation, shard_key) = match points {
|
527 |
+
PointsSelector::PointIdsSelector(PointIdsList { points, shard_key }) => {
|
528 |
+
(PayloadOps::ClearPayload { points }, shard_key)
|
529 |
+
}
|
530 |
+
PointsSelector::FilterSelector(FilterSelector { filter, shard_key }) => {
|
531 |
+
(PayloadOps::ClearPayloadByFilter(filter), shard_key)
|
532 |
+
}
|
533 |
+
};
|
534 |
+
|
535 |
+
let collection_operation = CollectionUpdateOperations::PayloadOperation(point_operation);
|
536 |
+
|
537 |
+
let shard_selector = get_shard_selector_for_update(shard_selection, shard_key);
|
538 |
+
|
539 |
+
toc.update(
|
540 |
+
&collection_name,
|
541 |
+
OperationWithClockTag::new(collection_operation, clock_tag),
|
542 |
+
wait,
|
543 |
+
ordering,
|
544 |
+
shard_selector,
|
545 |
+
access,
|
546 |
+
)
|
547 |
+
.await
|
548 |
+
}
|
549 |
+
|
550 |
+
/// Apply a heterogeneous list of update operations sequentially, in request
/// order.
///
/// Each entry is dispatched to the matching `do_*` helper with the same
/// clock tag, shard selection, ordering and access. The first failure aborts
/// the whole batch — operations that already succeeded are NOT rolled back.
/// Returns one `UpdateResult` per executed operation, in input order.
#[allow(clippy::too_many_arguments)]
pub async fn do_batch_update_points(
    toc: Arc<TableOfContent>,
    collection_name: String,
    operations: Vec<UpdateOperation>,
    clock_tag: Option<ClockTag>,
    shard_selection: Option<ShardId>,
    wait: bool,
    ordering: WriteOrdering,
    access: Access,
) -> Result<Vec<UpdateResult>, StorageError> {
    let mut results = Vec::with_capacity(operations.len());
    for operation in operations {
        let result = match operation {
            UpdateOperation::Upsert(operation) => {
                do_upsert_points(
                    toc.clone(),
                    collection_name.clone(),
                    operation.upsert,
                    clock_tag,
                    shard_selection,
                    wait,
                    ordering,
                    access.clone(),
                )
                .await
            }
            UpdateOperation::Delete(operation) => {
                do_delete_points(
                    toc.clone(),
                    collection_name.clone(),
                    operation.delete,
                    clock_tag,
                    shard_selection,
                    wait,
                    ordering,
                    access.clone(),
                )
                .await
            }
            UpdateOperation::SetPayload(operation) => {
                do_set_payload(
                    toc.clone(),
                    collection_name.clone(),
                    operation.set_payload,
                    clock_tag,
                    shard_selection,
                    wait,
                    ordering,
                    access.clone(),
                )
                .await
            }
            UpdateOperation::OverwritePayload(operation) => {
                do_overwrite_payload(
                    toc.clone(),
                    collection_name.clone(),
                    operation.overwrite_payload,
                    clock_tag,
                    shard_selection,
                    wait,
                    ordering,
                    access.clone(),
                )
                .await
            }
            UpdateOperation::DeletePayload(operation) => {
                do_delete_payload(
                    toc.clone(),
                    collection_name.clone(),
                    operation.delete_payload,
                    clock_tag,
                    shard_selection,
                    wait,
                    ordering,
                    access.clone(),
                )
                .await
            }
            UpdateOperation::ClearPayload(operation) => {
                do_clear_payload(
                    toc.clone(),
                    collection_name.clone(),
                    operation.clear_payload,
                    clock_tag,
                    shard_selection,
                    wait,
                    ordering,
                    access.clone(),
                )
                .await
            }
            UpdateOperation::UpdateVectors(operation) => {
                do_update_vectors(
                    toc.clone(),
                    collection_name.clone(),
                    operation.update_vectors,
                    clock_tag,
                    shard_selection,
                    wait,
                    ordering,
                    access.clone(),
                )
                .await
            }
            UpdateOperation::DeleteVectors(operation) => {
                do_delete_vectors(
                    toc.clone(),
                    collection_name.clone(),
                    operation.delete_vectors,
                    clock_tag,
                    shard_selection,
                    wait,
                    ordering,
                    access.clone(),
                )
                .await
            }
        }?;
        results.push(result);
    }
    Ok(results)
}
|
673 |
+
|
674 |
+
#[allow(clippy::too_many_arguments)]
|
675 |
+
pub async fn do_create_index_internal(
|
676 |
+
toc: Arc<TableOfContent>,
|
677 |
+
collection_name: String,
|
678 |
+
field_name: PayloadKeyType,
|
679 |
+
field_schema: Option<PayloadFieldSchema>,
|
680 |
+
clock_tag: Option<ClockTag>,
|
681 |
+
shard_selection: Option<ShardId>,
|
682 |
+
wait: bool,
|
683 |
+
ordering: WriteOrdering,
|
684 |
+
) -> Result<UpdateResult, StorageError> {
|
685 |
+
let collection_operation = CollectionUpdateOperations::FieldIndexOperation(
|
686 |
+
FieldIndexOperations::CreateIndex(CreateIndex {
|
687 |
+
field_name,
|
688 |
+
field_schema,
|
689 |
+
}),
|
690 |
+
);
|
691 |
+
|
692 |
+
let shard_selector = if let Some(shard_selection) = shard_selection {
|
693 |
+
ShardSelectorInternal::ShardId(shard_selection)
|
694 |
+
} else {
|
695 |
+
ShardSelectorInternal::All
|
696 |
+
};
|
697 |
+
|
698 |
+
toc.update(
|
699 |
+
&collection_name,
|
700 |
+
OperationWithClockTag::new(collection_operation, clock_tag),
|
701 |
+
wait,
|
702 |
+
ordering,
|
703 |
+
shard_selector,
|
704 |
+
Access::full("Internal API"),
|
705 |
+
)
|
706 |
+
.await
|
707 |
+
}
|
708 |
+
|
709 |
+
/// Create a payload-field index: first registers it through consensus, then
/// applies the index operation locally.
///
/// Requires an explicit `field_schema`; auto-detection is not supported and
/// yields a bad-request error.
#[allow(clippy::too_many_arguments)]
pub async fn do_create_index(
    dispatcher: Arc<Dispatcher>,
    collection_name: String,
    operation: CreateFieldIndex,
    clock_tag: Option<ClockTag>,
    shard_selection: Option<ShardId>,
    wait: bool,
    ordering: WriteOrdering,
    access: Access,
) -> Result<UpdateResult, StorageError> {
    // TODO: Is this cancel safe!?

    let Some(field_schema) = operation.field_schema else {
        return Err(StorageError::bad_request(
            "Can't auto-detect field type, please specify `field_schema` in the request",
        ));
    };

    let consensus_op = CollectionMetaOperations::CreatePayloadIndex(CreatePayloadIndex {
        collection_name: collection_name.to_string(),
        field_name: operation.field_name.clone(),
        field_schema: field_schema.clone(),
    });

    // Default consensus timeout will be used
    let wait_timeout = None; // ToDo: make it configurable

    // Nothing to verify here.
    let pass = new_unchecked_verification_pass();

    let toc = dispatcher.toc(&access, &pass).clone();

    // TODO: Is `submit_collection_meta_op` cancel-safe!? Should be, I think?.. 🤔
    dispatcher
        .submit_collection_meta_op(consensus_op, access, wait_timeout)
        .await?;

    // This function is required as long as we want to maintain interface compatibility
    // for `wait` parameter and return type.
    // The idea is to migrate from the point-like interface to consensus-like interface in the next few versions

    do_create_index_internal(
        toc,
        collection_name,
        operation.field_name,
        Some(field_schema),
        clock_tag,
        shard_selection,
        wait,
        ordering,
    )
    .await
}
|
763 |
+
|
764 |
+
#[allow(clippy::too_many_arguments)]
|
765 |
+
pub async fn do_delete_index_internal(
|
766 |
+
toc: Arc<TableOfContent>,
|
767 |
+
collection_name: String,
|
768 |
+
index_name: JsonPath,
|
769 |
+
clock_tag: Option<ClockTag>,
|
770 |
+
shard_selection: Option<ShardId>,
|
771 |
+
wait: bool,
|
772 |
+
ordering: WriteOrdering,
|
773 |
+
) -> Result<UpdateResult, StorageError> {
|
774 |
+
let collection_operation = CollectionUpdateOperations::FieldIndexOperation(
|
775 |
+
FieldIndexOperations::DeleteIndex(index_name),
|
776 |
+
);
|
777 |
+
|
778 |
+
let shard_selector = if let Some(shard_selection) = shard_selection {
|
779 |
+
ShardSelectorInternal::ShardId(shard_selection)
|
780 |
+
} else {
|
781 |
+
ShardSelectorInternal::All
|
782 |
+
};
|
783 |
+
|
784 |
+
toc.update(
|
785 |
+
&collection_name,
|
786 |
+
OperationWithClockTag::new(collection_operation, clock_tag),
|
787 |
+
wait,
|
788 |
+
ordering,
|
789 |
+
shard_selector,
|
790 |
+
Access::full("Internal API"),
|
791 |
+
)
|
792 |
+
.await
|
793 |
+
}
|
794 |
+
|
795 |
+
/// Drop a payload-field index: first removes it through consensus, then
/// applies the drop operation locally.
#[allow(clippy::too_many_arguments)]
pub async fn do_delete_index(
    dispatcher: Arc<Dispatcher>,
    collection_name: String,
    index_name: JsonPath,
    clock_tag: Option<ClockTag>,
    shard_selection: Option<ShardId>,
    wait: bool,
    ordering: WriteOrdering,
    access: Access,
) -> Result<UpdateResult, StorageError> {
    // TODO: Is this cancel safe!?

    let consensus_op = CollectionMetaOperations::DropPayloadIndex(DropPayloadIndex {
        collection_name: collection_name.to_string(),
        field_name: index_name.clone(),
    });

    // Default consensus timeout will be used
    let wait_timeout = None; // ToDo: make it configurable

    // Nothing to verify here.
    let pass = new_unchecked_verification_pass();

    let toc = dispatcher.toc(&access, &pass).clone();

    // TODO: Is `submit_collection_meta_op` cancel-safe!? Should be, I think?.. 🤔
    dispatcher
        .submit_collection_meta_op(consensus_op, access, wait_timeout)
        .await?;

    do_delete_index_internal(
        toc,
        collection_name,
        index_name,
        clock_tag,
        shard_selection,
        wait,
        ordering,
    )
    .await
}
|
837 |
+
|
838 |
+
#[allow(clippy::too_many_arguments)]
|
839 |
+
pub async fn do_core_search_points(
|
840 |
+
toc: &TableOfContent,
|
841 |
+
collection_name: &str,
|
842 |
+
request: CoreSearchRequest,
|
843 |
+
read_consistency: Option<ReadConsistency>,
|
844 |
+
shard_selection: ShardSelectorInternal,
|
845 |
+
access: Access,
|
846 |
+
timeout: Option<Duration>,
|
847 |
+
hw_measurement_acc: &HwMeasurementAcc,
|
848 |
+
) -> Result<Vec<ScoredPoint>, StorageError> {
|
849 |
+
let batch_res = do_core_search_batch_points(
|
850 |
+
toc,
|
851 |
+
collection_name,
|
852 |
+
CoreSearchRequestBatch {
|
853 |
+
searches: vec![request],
|
854 |
+
},
|
855 |
+
read_consistency,
|
856 |
+
shard_selection,
|
857 |
+
access,
|
858 |
+
timeout,
|
859 |
+
hw_measurement_acc,
|
860 |
+
)
|
861 |
+
.await?;
|
862 |
+
batch_res
|
863 |
+
.into_iter()
|
864 |
+
.next()
|
865 |
+
.ok_or_else(|| StorageError::service_error("Empty search result"))
|
866 |
+
}
|
867 |
+
|
868 |
+
/// Execute many core search requests, grouping them by shard selector.
///
/// Requests that share a shard selector are merged into a single
/// `CoreSearchRequestBatch` (grouping behavior is defined by
/// `batch_requests`); one batched search future is created per group and all
/// of them are awaited together. The flattened results come back in the same
/// order the grouped requests were issued.
pub async fn do_search_batch_points(
    toc: &TableOfContent,
    collection_name: &str,
    requests: Vec<(CoreSearchRequest, ShardSelectorInternal)>,
    read_consistency: Option<ReadConsistency>,
    access: Access,
    timeout: Option<Duration>,
    hw_measurement_acc: &HwMeasurementAcc,
) -> Result<Vec<Vec<ScoredPoint>>, StorageError> {
    let requests = batch_requests::<
        (CoreSearchRequest, ShardSelectorInternal),
        ShardSelectorInternal,
        Vec<CoreSearchRequest>,
        Vec<_>,
    >(
        requests,
        // Key extractor: group by the shard selector.
        |(_, shard_selector)| shard_selector,
        // Accumulator: collect the plain search requests of one group.
        |(request, _), core_reqs| {
            core_reqs.push(request);
            Ok(())
        },
        // Finalizer: turn one group into a (not yet awaited) batch future.
        |shard_selector, core_requests, res| {
            if core_requests.is_empty() {
                return Ok(());
            }

            let core_batch = CoreSearchRequestBatch {
                searches: core_requests,
            };

            let req = toc.core_search_batch(
                collection_name,
                core_batch,
                read_consistency,
                shard_selector,
                access.clone(),
                timeout,
                hw_measurement_acc,
            );
            res.push(req);
            Ok(())
        },
    )?;

    // Await all per-group batches concurrently; the first error aborts all.
    let results = futures::future::try_join_all(requests).await?;
    let flatten_results: Vec<Vec<_>> = results.into_iter().flatten().collect();
    Ok(flatten_results)
}
|
916 |
+
|
917 |
+
#[allow(clippy::too_many_arguments)]
|
918 |
+
pub async fn do_core_search_batch_points(
|
919 |
+
toc: &TableOfContent,
|
920 |
+
collection_name: &str,
|
921 |
+
request: CoreSearchRequestBatch,
|
922 |
+
read_consistency: Option<ReadConsistency>,
|
923 |
+
shard_selection: ShardSelectorInternal,
|
924 |
+
access: Access,
|
925 |
+
timeout: Option<Duration>,
|
926 |
+
hw_measurement_acc: &HwMeasurementAcc,
|
927 |
+
) -> Result<Vec<Vec<ScoredPoint>>, StorageError> {
|
928 |
+
toc.core_search_batch(
|
929 |
+
collection_name,
|
930 |
+
request,
|
931 |
+
read_consistency,
|
932 |
+
shard_selection,
|
933 |
+
access,
|
934 |
+
timeout,
|
935 |
+
hw_measurement_acc,
|
936 |
+
)
|
937 |
+
.await
|
938 |
+
}
|
939 |
+
|
940 |
+
#[allow(clippy::too_many_arguments)]
|
941 |
+
pub async fn do_search_point_groups(
|
942 |
+
toc: &TableOfContent,
|
943 |
+
collection_name: &str,
|
944 |
+
request: SearchGroupsRequestInternal,
|
945 |
+
read_consistency: Option<ReadConsistency>,
|
946 |
+
shard_selection: ShardSelectorInternal,
|
947 |
+
access: Access,
|
948 |
+
timeout: Option<Duration>,
|
949 |
+
hw_measurement_acc: &HwMeasurementAcc,
|
950 |
+
) -> Result<GroupsResult, StorageError> {
|
951 |
+
toc.group(
|
952 |
+
collection_name,
|
953 |
+
GroupRequest::from(request),
|
954 |
+
read_consistency,
|
955 |
+
shard_selection,
|
956 |
+
access,
|
957 |
+
timeout,
|
958 |
+
hw_measurement_acc,
|
959 |
+
)
|
960 |
+
.await
|
961 |
+
}
|
962 |
+
|
963 |
+
#[allow(clippy::too_many_arguments)]
|
964 |
+
pub async fn do_recommend_point_groups(
|
965 |
+
toc: &TableOfContent,
|
966 |
+
collection_name: &str,
|
967 |
+
request: RecommendGroupsRequestInternal,
|
968 |
+
read_consistency: Option<ReadConsistency>,
|
969 |
+
shard_selection: ShardSelectorInternal,
|
970 |
+
access: Access,
|
971 |
+
timeout: Option<Duration>,
|
972 |
+
hw_measurement_acc: &HwMeasurementAcc,
|
973 |
+
) -> Result<GroupsResult, StorageError> {
|
974 |
+
toc.group(
|
975 |
+
collection_name,
|
976 |
+
GroupRequest::from(request),
|
977 |
+
read_consistency,
|
978 |
+
shard_selection,
|
979 |
+
access,
|
980 |
+
timeout,
|
981 |
+
hw_measurement_acc,
|
982 |
+
)
|
983 |
+
.await
|
984 |
+
}
|
985 |
+
|
986 |
+
pub async fn do_discover_batch_points(
|
987 |
+
toc: &TableOfContent,
|
988 |
+
collection_name: &str,
|
989 |
+
request: DiscoverRequestBatch,
|
990 |
+
read_consistency: Option<ReadConsistency>,
|
991 |
+
access: Access,
|
992 |
+
timeout: Option<Duration>,
|
993 |
+
hw_measurement_acc: &HwMeasurementAcc,
|
994 |
+
) -> Result<Vec<Vec<ScoredPoint>>, StorageError> {
|
995 |
+
let requests = request
|
996 |
+
.searches
|
997 |
+
.into_iter()
|
998 |
+
.map(|req| {
|
999 |
+
let shard_selector = match req.shard_key {
|
1000 |
+
None => ShardSelectorInternal::All,
|
1001 |
+
Some(shard_key) => ShardSelectorInternal::from(shard_key),
|
1002 |
+
};
|
1003 |
+
|
1004 |
+
(req.discover_request, shard_selector)
|
1005 |
+
})
|
1006 |
+
.collect();
|
1007 |
+
|
1008 |
+
toc.discover_batch(
|
1009 |
+
collection_name,
|
1010 |
+
requests,
|
1011 |
+
read_consistency,
|
1012 |
+
access,
|
1013 |
+
timeout,
|
1014 |
+
hw_measurement_acc,
|
1015 |
+
)
|
1016 |
+
.await
|
1017 |
+
}
|
1018 |
+
|
1019 |
+
#[allow(clippy::too_many_arguments)]
|
1020 |
+
pub async fn do_count_points(
|
1021 |
+
toc: &TableOfContent,
|
1022 |
+
collection_name: &str,
|
1023 |
+
request: CountRequestInternal,
|
1024 |
+
read_consistency: Option<ReadConsistency>,
|
1025 |
+
timeout: Option<Duration>,
|
1026 |
+
shard_selection: ShardSelectorInternal,
|
1027 |
+
access: Access,
|
1028 |
+
hw_measurement_acc: &HwMeasurementAcc,
|
1029 |
+
) -> Result<CountResult, StorageError> {
|
1030 |
+
toc.count(
|
1031 |
+
collection_name,
|
1032 |
+
request,
|
1033 |
+
read_consistency,
|
1034 |
+
timeout,
|
1035 |
+
shard_selection,
|
1036 |
+
access,
|
1037 |
+
hw_measurement_acc,
|
1038 |
+
)
|
1039 |
+
.await
|
1040 |
+
}
|
1041 |
+
|
1042 |
+
pub async fn do_get_points(
|
1043 |
+
toc: &TableOfContent,
|
1044 |
+
collection_name: &str,
|
1045 |
+
request: PointRequestInternal,
|
1046 |
+
read_consistency: Option<ReadConsistency>,
|
1047 |
+
timeout: Option<Duration>,
|
1048 |
+
shard_selection: ShardSelectorInternal,
|
1049 |
+
access: Access,
|
1050 |
+
) -> Result<Vec<RecordInternal>, StorageError> {
|
1051 |
+
toc.retrieve(
|
1052 |
+
collection_name,
|
1053 |
+
request,
|
1054 |
+
read_consistency,
|
1055 |
+
timeout,
|
1056 |
+
shard_selection,
|
1057 |
+
access,
|
1058 |
+
)
|
1059 |
+
.await
|
1060 |
+
}
|
1061 |
+
|
1062 |
+
pub async fn do_scroll_points(
|
1063 |
+
toc: &TableOfContent,
|
1064 |
+
collection_name: &str,
|
1065 |
+
request: ScrollRequestInternal,
|
1066 |
+
read_consistency: Option<ReadConsistency>,
|
1067 |
+
timeout: Option<Duration>,
|
1068 |
+
shard_selection: ShardSelectorInternal,
|
1069 |
+
access: Access,
|
1070 |
+
) -> Result<ScrollResult, StorageError> {
|
1071 |
+
toc.scroll(
|
1072 |
+
collection_name,
|
1073 |
+
request,
|
1074 |
+
read_consistency,
|
1075 |
+
timeout,
|
1076 |
+
shard_selection,
|
1077 |
+
access,
|
1078 |
+
)
|
1079 |
+
.await
|
1080 |
+
}
|
1081 |
+
|
1082 |
+
#[allow(clippy::too_many_arguments)]
|
1083 |
+
pub async fn do_query_points(
|
1084 |
+
toc: &TableOfContent,
|
1085 |
+
collection_name: &str,
|
1086 |
+
request: CollectionQueryRequest,
|
1087 |
+
read_consistency: Option<ReadConsistency>,
|
1088 |
+
shard_selection: ShardSelectorInternal,
|
1089 |
+
access: Access,
|
1090 |
+
timeout: Option<Duration>,
|
1091 |
+
hw_measurement_acc: &HwMeasurementAcc,
|
1092 |
+
) -> Result<Vec<ScoredPoint>, StorageError> {
|
1093 |
+
let requests = vec![(request, shard_selection)];
|
1094 |
+
let batch_res = toc
|
1095 |
+
.query_batch(
|
1096 |
+
collection_name,
|
1097 |
+
requests,
|
1098 |
+
read_consistency,
|
1099 |
+
access,
|
1100 |
+
timeout,
|
1101 |
+
hw_measurement_acc,
|
1102 |
+
)
|
1103 |
+
.await?;
|
1104 |
+
batch_res
|
1105 |
+
.into_iter()
|
1106 |
+
.next()
|
1107 |
+
.ok_or_else(|| StorageError::service_error("Empty query result"))
|
1108 |
+
}
|
1109 |
+
|
1110 |
+
#[allow(clippy::too_many_arguments)]
|
1111 |
+
pub async fn do_query_batch_points(
|
1112 |
+
toc: &TableOfContent,
|
1113 |
+
collection_name: &str,
|
1114 |
+
requests: Vec<(CollectionQueryRequest, ShardSelectorInternal)>,
|
1115 |
+
read_consistency: Option<ReadConsistency>,
|
1116 |
+
access: Access,
|
1117 |
+
timeout: Option<Duration>,
|
1118 |
+
hw_measurement_acc: &HwMeasurementAcc,
|
1119 |
+
) -> Result<Vec<Vec<ScoredPoint>>, StorageError> {
|
1120 |
+
toc.query_batch(
|
1121 |
+
collection_name,
|
1122 |
+
requests,
|
1123 |
+
read_consistency,
|
1124 |
+
access,
|
1125 |
+
timeout,
|
1126 |
+
hw_measurement_acc,
|
1127 |
+
)
|
1128 |
+
.await
|
1129 |
+
}
|
1130 |
+
|
1131 |
+
#[allow(clippy::too_many_arguments)]
|
1132 |
+
pub async fn do_query_point_groups(
|
1133 |
+
toc: &TableOfContent,
|
1134 |
+
collection_name: &str,
|
1135 |
+
request: CollectionQueryGroupsRequest,
|
1136 |
+
read_consistency: Option<ReadConsistency>,
|
1137 |
+
shard_selection: ShardSelectorInternal,
|
1138 |
+
access: Access,
|
1139 |
+
timeout: Option<Duration>,
|
1140 |
+
hw_measurement_acc: &HwMeasurementAcc,
|
1141 |
+
) -> Result<GroupsResult, StorageError> {
|
1142 |
+
toc.group(
|
1143 |
+
collection_name,
|
1144 |
+
GroupRequest::from(request),
|
1145 |
+
read_consistency,
|
1146 |
+
shard_selection,
|
1147 |
+
access,
|
1148 |
+
timeout,
|
1149 |
+
hw_measurement_acc,
|
1150 |
+
)
|
1151 |
+
.await
|
1152 |
+
}
|
1153 |
+
|
1154 |
+
#[allow(clippy::too_many_arguments)]
|
1155 |
+
pub async fn do_search_points_matrix(
|
1156 |
+
toc: &TableOfContent,
|
1157 |
+
collection_name: &str,
|
1158 |
+
request: CollectionSearchMatrixRequest,
|
1159 |
+
read_consistency: Option<ReadConsistency>,
|
1160 |
+
shard_selection: ShardSelectorInternal,
|
1161 |
+
access: Access,
|
1162 |
+
timeout: Option<Duration>,
|
1163 |
+
hw_measurement_acc: &HwMeasurementAcc,
|
1164 |
+
) -> Result<CollectionSearchMatrixResponse, StorageError> {
|
1165 |
+
toc.search_points_matrix(
|
1166 |
+
collection_name,
|
1167 |
+
request,
|
1168 |
+
read_consistency,
|
1169 |
+
shard_selection,
|
1170 |
+
access,
|
1171 |
+
timeout,
|
1172 |
+
hw_measurement_acc,
|
1173 |
+
)
|
1174 |
+
.await
|
1175 |
+
}
|
src/common/pyroscope_state.rs
ADDED
@@ -0,0 +1,93 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#[cfg(target_os = "linux")]
|
2 |
+
pub mod pyro {
|
3 |
+
|
4 |
+
use pyroscope::pyroscope::PyroscopeAgentRunning;
|
5 |
+
use pyroscope::{PyroscopeAgent, PyroscopeError};
|
6 |
+
use pyroscope_pprofrs::{pprof_backend, PprofConfig};
|
7 |
+
|
8 |
+
use crate::common::debugger::PyroscopeConfig;
|
9 |
+
|
10 |
+
pub struct PyroscopeState {
|
11 |
+
pub config: PyroscopeConfig,
|
12 |
+
pub agent: Option<PyroscopeAgent<PyroscopeAgentRunning>>,
|
13 |
+
}
|
14 |
+
|
15 |
+
impl PyroscopeState {
|
16 |
+
fn build_agent(
|
17 |
+
config: &PyroscopeConfig,
|
18 |
+
) -> Result<PyroscopeAgent<PyroscopeAgentRunning>, PyroscopeError> {
|
19 |
+
let pprof_config = PprofConfig::new().sample_rate(config.sampling_rate.unwrap_or(100));
|
20 |
+
let backend_impl = pprof_backend(pprof_config);
|
21 |
+
|
22 |
+
log::info!(
|
23 |
+
"Starting pyroscope agent with identifier {}",
|
24 |
+
&config.identifier
|
25 |
+
);
|
26 |
+
// TODO: Add more tags like peerId and peerUrl
|
27 |
+
let agent = PyroscopeAgent::builder(config.url.to_string(), "qdrant".to_string())
|
28 |
+
.backend(backend_impl)
|
29 |
+
.tags(vec![("app", "Qdrant"), ("identifier", &config.identifier)])
|
30 |
+
.build()?;
|
31 |
+
let running_agent = agent.start()?;
|
32 |
+
|
33 |
+
Ok(running_agent)
|
34 |
+
}
|
35 |
+
|
36 |
+
pub fn from_config(config: Option<PyroscopeConfig>) -> Option<Self> {
|
37 |
+
match config {
|
38 |
+
Some(pyro_config) => {
|
39 |
+
let agent = PyroscopeState::build_agent(&pyro_config);
|
40 |
+
match agent {
|
41 |
+
Ok(agent) => Some(PyroscopeState {
|
42 |
+
config: pyro_config,
|
43 |
+
agent: Some(agent),
|
44 |
+
}),
|
45 |
+
Err(err) => {
|
46 |
+
log::warn!("Pyroscope agent failed to start {}", err);
|
47 |
+
None
|
48 |
+
}
|
49 |
+
}
|
50 |
+
}
|
51 |
+
None => None,
|
52 |
+
}
|
53 |
+
}
|
54 |
+
|
55 |
+
pub fn stop_agent(&mut self) -> bool {
|
56 |
+
log::info!("Stopping pyroscope agent");
|
57 |
+
if let Some(agent) = self.agent.take() {
|
58 |
+
match agent.stop() {
|
59 |
+
Ok(stopped_agent) => {
|
60 |
+
log::info!("Stopped pyroscope agent. Shutting it down");
|
61 |
+
stopped_agent.shutdown();
|
62 |
+
log::info!("Pyroscope agent shut down completed.");
|
63 |
+
return true;
|
64 |
+
}
|
65 |
+
Err(err) => {
|
66 |
+
log::warn!("Pyroscope agent failed to stop {}", err);
|
67 |
+
return false;
|
68 |
+
}
|
69 |
+
}
|
70 |
+
}
|
71 |
+
true
|
72 |
+
}
|
73 |
+
}
|
74 |
+
|
75 |
+
impl Drop for PyroscopeState {
|
76 |
+
fn drop(&mut self) {
|
77 |
+
self.stop_agent();
|
78 |
+
}
|
79 |
+
}
|
80 |
+
}
|
81 |
+
|
82 |
+
#[cfg(not(target_os = "linux"))]
|
83 |
+
pub mod pyro {
|
84 |
+
use crate::common::debugger::PyroscopeConfig;
|
85 |
+
|
86 |
+
pub struct PyroscopeState {}
|
87 |
+
|
88 |
+
impl PyroscopeState {
|
89 |
+
pub fn from_config(_config: Option<PyroscopeConfig>) -> Option<Self> {
|
90 |
+
None
|
91 |
+
}
|
92 |
+
}
|
93 |
+
}
|
src/common/snapshots.rs
ADDED
@@ -0,0 +1,284 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
use std::sync::Arc;
|
2 |
+
|
3 |
+
use collection::collection::Collection;
|
4 |
+
use collection::common::sha_256::hash_file;
|
5 |
+
use collection::common::snapshot_stream::SnapshotStream;
|
6 |
+
use collection::operations::snapshot_ops::{
|
7 |
+
ShardSnapshotLocation, SnapshotDescription, SnapshotPriority,
|
8 |
+
};
|
9 |
+
use collection::shards::replica_set::ReplicaState;
|
10 |
+
use collection::shards::shard::ShardId;
|
11 |
+
use storage::content_manager::errors::StorageError;
|
12 |
+
use storage::content_manager::snapshots;
|
13 |
+
use storage::content_manager::toc::TableOfContent;
|
14 |
+
use storage::rbac::{Access, AccessRequirements};
|
15 |
+
|
16 |
+
use super::http_client::HttpClient;
|
17 |
+
|
18 |
+
/// # Cancel safety
|
19 |
+
///
|
20 |
+
/// This function is cancel safe.
|
21 |
+
pub async fn create_shard_snapshot(
|
22 |
+
toc: Arc<TableOfContent>,
|
23 |
+
access: Access,
|
24 |
+
collection_name: String,
|
25 |
+
shard_id: ShardId,
|
26 |
+
) -> Result<SnapshotDescription, StorageError> {
|
27 |
+
let collection_pass = access
|
28 |
+
.check_collection_access(&collection_name, AccessRequirements::new().write().whole())?;
|
29 |
+
let collection = toc.get_collection(&collection_pass).await?;
|
30 |
+
|
31 |
+
let snapshot = collection
|
32 |
+
.create_shard_snapshot(shard_id, &toc.optional_temp_or_snapshot_temp_path()?)
|
33 |
+
.await?;
|
34 |
+
|
35 |
+
Ok(snapshot)
|
36 |
+
}
|
37 |
+
|
38 |
+
/// # Cancel safety
|
39 |
+
///
|
40 |
+
/// This function is cancel safe.
|
41 |
+
pub async fn stream_shard_snapshot(
|
42 |
+
toc: Arc<TableOfContent>,
|
43 |
+
access: Access,
|
44 |
+
collection_name: String,
|
45 |
+
shard_id: ShardId,
|
46 |
+
) -> Result<SnapshotStream, StorageError> {
|
47 |
+
let collection_pass = access
|
48 |
+
.check_collection_access(&collection_name, AccessRequirements::new().write().whole())?;
|
49 |
+
let collection = toc.get_collection(&collection_pass).await?;
|
50 |
+
|
51 |
+
Ok(collection
|
52 |
+
.stream_shard_snapshot(shard_id, &toc.optional_temp_or_snapshot_temp_path()?)
|
53 |
+
.await?)
|
54 |
+
}
|
55 |
+
|
56 |
+
/// # Cancel safety
|
57 |
+
///
|
58 |
+
/// This function is cancel safe.
|
59 |
+
pub async fn list_shard_snapshots(
|
60 |
+
toc: Arc<TableOfContent>,
|
61 |
+
access: Access,
|
62 |
+
collection_name: String,
|
63 |
+
shard_id: ShardId,
|
64 |
+
) -> Result<Vec<SnapshotDescription>, StorageError> {
|
65 |
+
let collection_pass =
|
66 |
+
access.check_collection_access(&collection_name, AccessRequirements::new().whole())?;
|
67 |
+
let collection = toc.get_collection(&collection_pass).await?;
|
68 |
+
let snapshots = collection.list_shard_snapshots(shard_id).await?;
|
69 |
+
Ok(snapshots)
|
70 |
+
}
|
71 |
+
|
72 |
+
/// # Cancel safety
|
73 |
+
///
|
74 |
+
/// This function is cancel safe.
|
75 |
+
pub async fn delete_shard_snapshot(
|
76 |
+
toc: Arc<TableOfContent>,
|
77 |
+
access: Access,
|
78 |
+
collection_name: String,
|
79 |
+
shard_id: ShardId,
|
80 |
+
snapshot_name: String,
|
81 |
+
) -> Result<(), StorageError> {
|
82 |
+
let collection_pass = access
|
83 |
+
.check_collection_access(&collection_name, AccessRequirements::new().write().whole())?;
|
84 |
+
let collection = toc.get_collection(&collection_pass).await?;
|
85 |
+
let snapshot_manager = collection.get_snapshots_storage_manager()?;
|
86 |
+
|
87 |
+
let snapshot_path = collection
|
88 |
+
.shards_holder()
|
89 |
+
.read()
|
90 |
+
.await
|
91 |
+
.get_shard_snapshot_path(collection.snapshots_path(), shard_id, &snapshot_name)
|
92 |
+
.await?;
|
93 |
+
|
94 |
+
tokio::spawn(async move { snapshot_manager.delete_snapshot(&snapshot_path).await }).await??;
|
95 |
+
|
96 |
+
Ok(())
|
97 |
+
}
|
98 |
+
|
99 |
+
/// # Cancel safety
|
100 |
+
///
|
101 |
+
/// This function is cancel safe.
|
102 |
+
#[allow(clippy::too_many_arguments)]
|
103 |
+
pub async fn recover_shard_snapshot(
|
104 |
+
toc: Arc<TableOfContent>,
|
105 |
+
access: Access,
|
106 |
+
collection_name: String,
|
107 |
+
shard_id: ShardId,
|
108 |
+
snapshot_location: ShardSnapshotLocation,
|
109 |
+
snapshot_priority: SnapshotPriority,
|
110 |
+
checksum: Option<String>,
|
111 |
+
client: HttpClient,
|
112 |
+
api_key: Option<String>,
|
113 |
+
) -> Result<(), StorageError> {
|
114 |
+
let collection_pass = access
|
115 |
+
.check_global_access(AccessRequirements::new().manage())?
|
116 |
+
.issue_pass(&collection_name)
|
117 |
+
.into_static();
|
118 |
+
|
119 |
+
// - `recover_shard_snapshot_impl` is *not* cancel safe
|
120 |
+
// - but the task is *spawned* on the runtime and won't be cancelled, if request is cancelled
|
121 |
+
|
122 |
+
cancel::future::spawn_cancel_on_drop(move |cancel| async move {
|
123 |
+
let future = async {
|
124 |
+
let collection = toc.get_collection(&collection_pass).await?;
|
125 |
+
collection.assert_shard_exists(shard_id).await?;
|
126 |
+
|
127 |
+
let download_dir = toc.optional_temp_or_snapshot_temp_path()?;
|
128 |
+
|
129 |
+
let snapshot_path = match snapshot_location {
|
130 |
+
ShardSnapshotLocation::Url(url) => {
|
131 |
+
if !matches!(url.scheme(), "http" | "https") {
|
132 |
+
let description = format!(
|
133 |
+
"Invalid snapshot URL {url}: URLs with {} scheme are not supported",
|
134 |
+
url.scheme(),
|
135 |
+
);
|
136 |
+
|
137 |
+
return Err(StorageError::bad_input(description));
|
138 |
+
}
|
139 |
+
|
140 |
+
let client = client.client(api_key.as_deref())?;
|
141 |
+
|
142 |
+
snapshots::download::download_snapshot(&client, url, &download_dir).await?
|
143 |
+
}
|
144 |
+
|
145 |
+
ShardSnapshotLocation::Path(snapshot_file_name) => {
|
146 |
+
let snapshot_path = collection
|
147 |
+
.shards_holder()
|
148 |
+
.read()
|
149 |
+
.await
|
150 |
+
.get_shard_snapshot_path(
|
151 |
+
collection.snapshots_path(),
|
152 |
+
shard_id,
|
153 |
+
&snapshot_file_name,
|
154 |
+
)
|
155 |
+
.await?;
|
156 |
+
|
157 |
+
collection
|
158 |
+
.get_snapshots_storage_manager()?
|
159 |
+
.get_snapshot_file(&snapshot_path, &download_dir)
|
160 |
+
.await?
|
161 |
+
}
|
162 |
+
};
|
163 |
+
|
164 |
+
if let Some(checksum) = checksum {
|
165 |
+
let snapshot_checksum = hash_file(&snapshot_path).await?;
|
166 |
+
if snapshot_checksum != checksum {
|
167 |
+
return Err(StorageError::bad_input(format!(
|
168 |
+
"Snapshot checksum mismatch: expected {checksum}, got {snapshot_checksum}"
|
169 |
+
)));
|
170 |
+
}
|
171 |
+
}
|
172 |
+
|
173 |
+
Result::<_, StorageError>::Ok((collection, snapshot_path))
|
174 |
+
};
|
175 |
+
|
176 |
+
let (collection, snapshot_path) =
|
177 |
+
cancel::future::cancel_on_token(cancel.clone(), future).await??;
|
178 |
+
|
179 |
+
// `recover_shard_snapshot_impl` is *not* cancel safe
|
180 |
+
let result = recover_shard_snapshot_impl(
|
181 |
+
&toc,
|
182 |
+
&collection,
|
183 |
+
shard_id,
|
184 |
+
&snapshot_path,
|
185 |
+
snapshot_priority,
|
186 |
+
cancel,
|
187 |
+
)
|
188 |
+
.await;
|
189 |
+
|
190 |
+
// Remove snapshot after recovery if downloaded
|
191 |
+
if let Err(err) = snapshot_path.close() {
|
192 |
+
log::error!("Failed to remove downloaded shards snapshot after recovery: {err}");
|
193 |
+
}
|
194 |
+
|
195 |
+
result
|
196 |
+
})
|
197 |
+
.await??;
|
198 |
+
|
199 |
+
Ok(())
|
200 |
+
}
|
201 |
+
|
202 |
+
/// # Cancel safety
|
203 |
+
///
|
204 |
+
/// This function is *not* cancel safe.
|
205 |
+
pub async fn recover_shard_snapshot_impl(
|
206 |
+
toc: &TableOfContent,
|
207 |
+
collection: &Collection,
|
208 |
+
shard: ShardId,
|
209 |
+
snapshot_path: &std::path::Path,
|
210 |
+
priority: SnapshotPriority,
|
211 |
+
cancel: cancel::CancellationToken,
|
212 |
+
) -> Result<(), StorageError> {
|
213 |
+
// `Collection::restore_shard_snapshot` and `activate_shard` calls *have to* be executed as a
|
214 |
+
// single transaction
|
215 |
+
//
|
216 |
+
// It is *possible* to make this function to be cancel safe, but it is *extremely tedious* to do so
|
217 |
+
|
218 |
+
// `Collection::restore_shard_snapshot` is *not* cancel safe
|
219 |
+
// (see `ShardReplicaSet::restore_local_replica_from`)
|
220 |
+
collection
|
221 |
+
.restore_shard_snapshot(
|
222 |
+
shard,
|
223 |
+
snapshot_path,
|
224 |
+
toc.this_peer_id,
|
225 |
+
toc.is_distributed(),
|
226 |
+
&toc.optional_temp_or_snapshot_temp_path()?,
|
227 |
+
cancel,
|
228 |
+
)
|
229 |
+
.await?;
|
230 |
+
|
231 |
+
let state = collection.state().await;
|
232 |
+
let shard_info = state.shards.get(&shard).unwrap(); // TODO: Handle `unwrap`?..
|
233 |
+
|
234 |
+
// TODO: Unify (and de-duplicate) "recovered shard state notification" logic in `_do_recover_from_snapshot` with this one!
|
235 |
+
|
236 |
+
let other_active_replicas: Vec<_> = shard_info
|
237 |
+
.replicas
|
238 |
+
.iter()
|
239 |
+
.map(|(&peer, &state)| (peer, state))
|
240 |
+
.filter(|&(peer, state)| peer != toc.this_peer_id && state == ReplicaState::Active)
|
241 |
+
.collect();
|
242 |
+
|
243 |
+
if other_active_replicas.is_empty() {
|
244 |
+
snapshots::recover::activate_shard(toc, collection, toc.this_peer_id, &shard).await?;
|
245 |
+
} else {
|
246 |
+
match priority {
|
247 |
+
SnapshotPriority::NoSync => {
|
248 |
+
snapshots::recover::activate_shard(toc, collection, toc.this_peer_id, &shard)
|
249 |
+
.await?;
|
250 |
+
}
|
251 |
+
|
252 |
+
SnapshotPriority::Snapshot => {
|
253 |
+
snapshots::recover::activate_shard(toc, collection, toc.this_peer_id, &shard)
|
254 |
+
.await?;
|
255 |
+
|
256 |
+
for &(peer, _) in other_active_replicas.iter() {
|
257 |
+
toc.send_set_replica_state_proposal(
|
258 |
+
collection.name(),
|
259 |
+
peer,
|
260 |
+
shard,
|
261 |
+
ReplicaState::Dead,
|
262 |
+
None,
|
263 |
+
)?;
|
264 |
+
}
|
265 |
+
}
|
266 |
+
|
267 |
+
SnapshotPriority::Replica => {
|
268 |
+
toc.send_set_replica_state_proposal(
|
269 |
+
collection.name(),
|
270 |
+
toc.this_peer_id,
|
271 |
+
shard,
|
272 |
+
ReplicaState::Dead,
|
273 |
+
None,
|
274 |
+
)?;
|
275 |
+
}
|
276 |
+
|
277 |
+
// `ShardTransfer` is only used during snapshot *shard transfer*.
|
278 |
+
// State transitions are performed as part of shard transfer *later*, so this simply does *nothing*.
|
279 |
+
SnapshotPriority::ShardTransfer => (),
|
280 |
+
}
|
281 |
+
}
|
282 |
+
|
283 |
+
Ok(())
|
284 |
+
}
|
src/common/stacktrace.rs
ADDED
@@ -0,0 +1,86 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
use schemars::JsonSchema;
|
2 |
+
use serde::{Deserialize, Serialize};
|
3 |
+
|
4 |
+
#[derive(Deserialize, Serialize, JsonSchema, Debug)]
|
5 |
+
struct StackTraceSymbol {
|
6 |
+
name: Option<String>,
|
7 |
+
file: Option<String>,
|
8 |
+
line: Option<u32>,
|
9 |
+
}
|
10 |
+
|
11 |
+
#[derive(Deserialize, Serialize, JsonSchema, Debug)]
|
12 |
+
struct StackTraceFrame {
|
13 |
+
symbols: Vec<StackTraceSymbol>,
|
14 |
+
}
|
15 |
+
|
16 |
+
impl StackTraceFrame {
|
17 |
+
pub fn render(&self) -> String {
|
18 |
+
let mut result = String::new();
|
19 |
+
for symbol in &self.symbols {
|
20 |
+
let symbol_string = format!(
|
21 |
+
"{}:{} - {} ",
|
22 |
+
symbol.file.as_deref().unwrap_or_default(),
|
23 |
+
symbol.line.unwrap_or_default(),
|
24 |
+
symbol.name.as_deref().unwrap_or_default(),
|
25 |
+
);
|
26 |
+
result.push_str(&symbol_string);
|
27 |
+
}
|
28 |
+
result
|
29 |
+
}
|
30 |
+
}
|
31 |
+
|
32 |
+
#[derive(Deserialize, Serialize, JsonSchema, Debug)]
|
33 |
+
pub struct ThreadStackTrace {
|
34 |
+
id: u32,
|
35 |
+
name: String,
|
36 |
+
frames: Vec<String>,
|
37 |
+
}
|
38 |
+
|
39 |
+
#[derive(Deserialize, Serialize, JsonSchema, Debug)]
|
40 |
+
pub struct StackTrace {
|
41 |
+
threads: Vec<ThreadStackTrace>,
|
42 |
+
}
|
43 |
+
|
44 |
+
pub fn get_stack_trace() -> StackTrace {
|
45 |
+
#[cfg(not(all(target_os = "linux", feature = "stacktrace")))]
|
46 |
+
{
|
47 |
+
StackTrace { threads: vec![] }
|
48 |
+
}
|
49 |
+
|
50 |
+
#[cfg(all(target_os = "linux", feature = "stacktrace"))]
|
51 |
+
{
|
52 |
+
let exe = std::env::current_exe().unwrap();
|
53 |
+
let trace =
|
54 |
+
rstack_self::trace(std::process::Command::new(exe).arg("--stacktrace")).unwrap();
|
55 |
+
StackTrace {
|
56 |
+
threads: trace
|
57 |
+
.threads()
|
58 |
+
.iter()
|
59 |
+
.map(|thread| ThreadStackTrace {
|
60 |
+
id: thread.id(),
|
61 |
+
name: thread.name().to_string(),
|
62 |
+
frames: thread
|
63 |
+
.frames()
|
64 |
+
.iter()
|
65 |
+
.map(|frame| {
|
66 |
+
let frame = StackTraceFrame {
|
67 |
+
symbols: frame
|
68 |
+
.symbols()
|
69 |
+
.iter()
|
70 |
+
.map(|symbol| StackTraceSymbol {
|
71 |
+
name: symbol.name().map(|name| name.to_string()),
|
72 |
+
file: symbol.file().map(|file| {
|
73 |
+
file.to_str().unwrap_or_default().to_string()
|
74 |
+
}),
|
75 |
+
line: symbol.line(),
|
76 |
+
})
|
77 |
+
.collect(),
|
78 |
+
};
|
79 |
+
frame.render()
|
80 |
+
})
|
81 |
+
.collect(),
|
82 |
+
})
|
83 |
+
.collect(),
|
84 |
+
}
|
85 |
+
}
|
86 |
+
}
|
src/common/strings.rs
ADDED
@@ -0,0 +1,5 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
/// Constant-time equality for String types
|
2 |
+
#[inline]
|
3 |
+
pub fn ct_eq(lhs: impl AsRef<str>, rhs: impl AsRef<str>) -> bool {
|
4 |
+
constant_time_eq::constant_time_eq(lhs.as_ref().as_bytes(), rhs.as_ref().as_bytes())
|
5 |
+
}
|
src/common/telemetry.rs
ADDED
@@ -0,0 +1,101 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
use std::sync::Arc;
|
2 |
+
|
3 |
+
use collection::operations::verification::new_unchecked_verification_pass;
|
4 |
+
use common::types::{DetailsLevel, TelemetryDetail};
|
5 |
+
use parking_lot::Mutex;
|
6 |
+
use schemars::JsonSchema;
|
7 |
+
use segment::common::anonymize::Anonymize;
|
8 |
+
use serde::Serialize;
|
9 |
+
use storage::dispatcher::Dispatcher;
|
10 |
+
use storage::rbac::Access;
|
11 |
+
use uuid::Uuid;
|
12 |
+
|
13 |
+
use crate::common::telemetry_ops::app_telemetry::{AppBuildTelemetry, AppBuildTelemetryCollector};
|
14 |
+
use crate::common::telemetry_ops::cluster_telemetry::ClusterTelemetry;
|
15 |
+
use crate::common::telemetry_ops::collections_telemetry::CollectionsTelemetry;
|
16 |
+
use crate::common::telemetry_ops::memory_telemetry::MemoryTelemetry;
|
17 |
+
use crate::common::telemetry_ops::requests_telemetry::{
|
18 |
+
ActixTelemetryCollector, RequestsTelemetry, TonicTelemetryCollector,
|
19 |
+
};
|
20 |
+
use crate::settings::Settings;
|
21 |
+
|
22 |
+
pub struct TelemetryCollector {
|
23 |
+
process_id: Uuid,
|
24 |
+
settings: Settings,
|
25 |
+
dispatcher: Arc<Dispatcher>,
|
26 |
+
pub app_telemetry_collector: AppBuildTelemetryCollector,
|
27 |
+
pub actix_telemetry_collector: Arc<Mutex<ActixTelemetryCollector>>,
|
28 |
+
pub tonic_telemetry_collector: Arc<Mutex<TonicTelemetryCollector>>,
|
29 |
+
}
|
30 |
+
|
31 |
+
// Whole telemetry data
|
32 |
+
#[derive(Serialize, Clone, Debug, JsonSchema)]
|
33 |
+
pub struct TelemetryData {
|
34 |
+
id: String,
|
35 |
+
pub(crate) app: AppBuildTelemetry,
|
36 |
+
pub(crate) collections: CollectionsTelemetry,
|
37 |
+
pub(crate) cluster: ClusterTelemetry,
|
38 |
+
pub(crate) requests: RequestsTelemetry,
|
39 |
+
pub(crate) memory: Option<MemoryTelemetry>,
|
40 |
+
}
|
41 |
+
|
42 |
+
impl Anonymize for TelemetryData {
|
43 |
+
fn anonymize(&self) -> Self {
|
44 |
+
TelemetryData {
|
45 |
+
id: self.id.clone(),
|
46 |
+
app: self.app.anonymize(),
|
47 |
+
collections: self.collections.anonymize(),
|
48 |
+
cluster: self.cluster.anonymize(),
|
49 |
+
requests: self.requests.anonymize(),
|
50 |
+
memory: self.memory.anonymize(),
|
51 |
+
}
|
52 |
+
}
|
53 |
+
}
|
54 |
+
|
55 |
+
impl TelemetryCollector {
|
56 |
+
pub fn reporting_id(&self) -> String {
|
57 |
+
self.process_id.to_string()
|
58 |
+
}
|
59 |
+
|
60 |
+
pub fn generate_id() -> Uuid {
|
61 |
+
Uuid::new_v4()
|
62 |
+
}
|
63 |
+
|
64 |
+
pub fn new(settings: Settings, dispatcher: Arc<Dispatcher>, id: Uuid) -> Self {
|
65 |
+
Self {
|
66 |
+
process_id: id,
|
67 |
+
settings,
|
68 |
+
dispatcher,
|
69 |
+
app_telemetry_collector: AppBuildTelemetryCollector::new(),
|
70 |
+
actix_telemetry_collector: Arc::new(Mutex::new(ActixTelemetryCollector {
|
71 |
+
workers: Vec::new(),
|
72 |
+
})),
|
73 |
+
tonic_telemetry_collector: Arc::new(Mutex::new(TonicTelemetryCollector {
|
74 |
+
workers: Vec::new(),
|
75 |
+
})),
|
76 |
+
}
|
77 |
+
}
|
78 |
+
|
79 |
+
pub async fn prepare_data(&self, access: &Access, detail: TelemetryDetail) -> TelemetryData {
|
80 |
+
TelemetryData {
|
81 |
+
id: self.process_id.to_string(),
|
82 |
+
collections: CollectionsTelemetry::collect(
|
83 |
+
detail,
|
84 |
+
access,
|
85 |
+
self.dispatcher
|
86 |
+
.toc(access, &new_unchecked_verification_pass()),
|
87 |
+
)
|
88 |
+
.await,
|
89 |
+
app: AppBuildTelemetry::collect(detail, &self.app_telemetry_collector, &self.settings),
|
90 |
+
cluster: ClusterTelemetry::collect(detail, &self.dispatcher, &self.settings),
|
91 |
+
requests: RequestsTelemetry::collect(
|
92 |
+
&self.actix_telemetry_collector.lock(),
|
93 |
+
&self.tonic_telemetry_collector.lock(),
|
94 |
+
detail,
|
95 |
+
),
|
96 |
+
memory: (detail.level > DetailsLevel::Level0)
|
97 |
+
.then(MemoryTelemetry::collect)
|
98 |
+
.flatten(),
|
99 |
+
}
|
100 |
+
}
|
101 |
+
}
|