Gouzi Mohaled committed on
Commit
d8435ba
·
1 Parent(s): 3932407

Ajout du dossier src

Browse files
This view is limited to 50 files because it contains too many changes. See the raw diff for the full change set.
Files changed (50) hide show
  1. src/actix/actix_telemetry.rs +90 -0
  2. src/actix/api/cluster_api.rs +189 -0
  3. src/actix/api/collections_api.rs +256 -0
  4. src/actix/api/count_api.rs +69 -0
  5. src/actix/api/debug_api.rs +36 -0
  6. src/actix/api/discovery_api.rs +140 -0
  7. src/actix/api/facet_api.rs +77 -0
  8. src/actix/api/issues_api.rs +32 -0
  9. src/actix/api/local_shard_api.rs +267 -0
  10. src/actix/api/mod.rs +46 -0
  11. src/actix/api/query_api.rs +232 -0
  12. src/actix/api/read_params.rs +118 -0
  13. src/actix/api/recommend_api.rs +235 -0
  14. src/actix/api/retrieve_api.rs +200 -0
  15. src/actix/api/search_api.rs +333 -0
  16. src/actix/api/service_api.rs +217 -0
  17. src/actix/api/shards_api.rs +80 -0
  18. src/actix/api/snapshot_api.rs +585 -0
  19. src/actix/api/update_api.rs +392 -0
  20. src/actix/auth.rs +160 -0
  21. src/actix/certificate_helpers.rs +203 -0
  22. src/actix/helpers.rs +179 -0
  23. src/actix/mod.rs +262 -0
  24. src/actix/web_ui.rs +115 -0
  25. src/common/auth/claims.rs +69 -0
  26. src/common/auth/jwt_parser.rs +155 -0
  27. src/common/auth/mod.rs +165 -0
  28. src/common/collections.rs +834 -0
  29. src/common/debugger.rs +90 -0
  30. src/common/error_reporting.rs +31 -0
  31. src/common/health.rs +372 -0
  32. src/common/helpers.rs +151 -0
  33. src/common/http_client.rs +156 -0
  34. src/common/inference/batch_processing.rs +370 -0
  35. src/common/inference/batch_processing_grpc.rs +281 -0
  36. src/common/inference/config.rs +23 -0
  37. src/common/inference/infer_processing.rs +72 -0
  38. src/common/inference/mod.rs +8 -0
  39. src/common/inference/query_requests_grpc.rs +535 -0
  40. src/common/inference/query_requests_rest.rs +415 -0
  41. src/common/inference/service.rs +266 -0
  42. src/common/inference/update_requests.rs +409 -0
  43. src/common/metrics.rs +505 -0
  44. src/common/mod.rs +31 -0
  45. src/common/points.rs +1175 -0
  46. src/common/pyroscope_state.rs +93 -0
  47. src/common/snapshots.rs +284 -0
  48. src/common/stacktrace.rs +86 -0
  49. src/common/strings.rs +5 -0
  50. src/common/telemetry.rs +101 -0
src/actix/actix_telemetry.rs ADDED
@@ -0,0 +1,90 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ use std::future::{ready, Ready};
2
+ use std::sync::Arc;
3
+
4
+ use actix_web::dev::{Service, ServiceRequest, ServiceResponse, Transform};
5
+ use actix_web::Error;
6
+ use futures_util::future::LocalBoxFuture;
7
+ use parking_lot::Mutex;
8
+
9
+ use crate::common::telemetry_ops::requests_telemetry::{
10
+ ActixTelemetryCollector, ActixWorkerTelemetryCollector,
11
+ };
12
+
13
+ pub struct ActixTelemetryService<S> {
14
+ service: S,
15
+ telemetry_data: Arc<Mutex<ActixWorkerTelemetryCollector>>,
16
+ }
17
+
18
+ pub struct ActixTelemetryTransform {
19
+ telemetry_collector: Arc<Mutex<ActixTelemetryCollector>>,
20
+ }
21
+
22
+ /// Actix telemetry service. It hooks every request and looks into response status code.
23
+ ///
24
+ /// More about actix service with similar example
25
+ /// <https://actix.rs/docs/middleware/>
26
+ impl<S, B> Service<ServiceRequest> for ActixTelemetryService<S>
27
+ where
28
+ S: Service<ServiceRequest, Response = ServiceResponse<B>, Error = Error>,
29
+ S::Future: 'static,
30
+ B: 'static,
31
+ {
32
+ type Response = ServiceResponse<B>;
33
+ type Error = Error;
34
+ type Future = LocalBoxFuture<'static, Result<Self::Response, Self::Error>>;
35
+
36
+ actix_web::dev::forward_ready!(service);
37
+
38
+ fn call(&self, request: ServiceRequest) -> Self::Future {
39
+ let match_pattern = request
40
+ .match_pattern()
41
+ .unwrap_or_else(|| "unknown".to_owned());
42
+ let request_key = format!("{} {}", request.method(), match_pattern);
43
+ let future = self.service.call(request);
44
+ let telemetry_data = self.telemetry_data.clone();
45
+ Box::pin(async move {
46
+ let instant = std::time::Instant::now();
47
+ let response = future.await?;
48
+ let status = response.response().status().as_u16();
49
+ telemetry_data
50
+ .lock()
51
+ .add_response(request_key, status, instant);
52
+ Ok(response)
53
+ })
54
+ }
55
+ }
56
+
57
+ impl ActixTelemetryTransform {
58
+ pub fn new(telemetry_collector: Arc<Mutex<ActixTelemetryCollector>>) -> Self {
59
+ Self {
60
+ telemetry_collector,
61
+ }
62
+ }
63
+ }
64
+
65
+ /// Actix telemetry transform. It's a builder for an actix service
66
+ ///
67
+ /// More about actix transform with similar example
68
+ /// <https://actix.rs/docs/middleware/>
69
+ impl<S, B> Transform<S, ServiceRequest> for ActixTelemetryTransform
70
+ where
71
+ S: Service<ServiceRequest, Response = ServiceResponse<B>, Error = Error> + 'static,
72
+ S::Future: 'static,
73
+ B: 'static,
74
+ {
75
+ type Response = ServiceResponse<B>;
76
+ type Error = Error;
77
+ type Transform = ActixTelemetryService<S>;
78
+ type InitError = ();
79
+ type Future = Ready<Result<Self::Transform, Self::InitError>>;
80
+
81
+ fn new_transform(&self, service: S) -> Self::Future {
82
+ ready(Ok(ActixTelemetryService {
83
+ service,
84
+ telemetry_data: self
85
+ .telemetry_collector
86
+ .lock()
87
+ .create_web_worker_telemetry(),
88
+ }))
89
+ }
90
+ }
src/actix/api/cluster_api.rs ADDED
@@ -0,0 +1,189 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ use std::future::Future;
2
+
3
+ use actix_web::{delete, get, post, put, web, HttpResponse};
4
+ use actix_web_validator::Query;
5
+ use collection::operations::verification::new_unchecked_verification_pass;
6
+ use schemars::JsonSchema;
7
+ use serde::{Deserialize, Serialize};
8
+ use storage::content_manager::consensus_ops::ConsensusOperations;
9
+ use storage::content_manager::errors::StorageError;
10
+ use storage::dispatcher::Dispatcher;
11
+ use storage::rbac::AccessRequirements;
12
+ use validator::Validate;
13
+
14
+ use crate::actix::auth::ActixAccess;
15
+ use crate::actix::helpers;
16
+
17
+ #[derive(Debug, Deserialize, Validate)]
18
+ struct QueryParams {
19
+ #[serde(default)]
20
+ force: bool,
21
+ #[serde(default)]
22
+ #[validate(range(min = 1))]
23
+ timeout: Option<u64>,
24
+ }
25
+
26
+ #[derive(Deserialize, Serialize, JsonSchema, Validate)]
27
+ pub struct MetadataParams {
28
+ #[serde(default)]
29
+ pub wait: bool,
30
+ }
31
+
32
+ #[get("/cluster")]
33
+ fn cluster_status(
34
+ dispatcher: web::Data<Dispatcher>,
35
+ ActixAccess(access): ActixAccess,
36
+ ) -> impl Future<Output = HttpResponse> {
37
+ helpers::time(async move {
38
+ access.check_global_access(AccessRequirements::new())?;
39
+ Ok(dispatcher.cluster_status())
40
+ })
41
+ }
42
+
43
+ #[post("/cluster/recover")]
44
+ fn recover_current_peer(
45
+ dispatcher: web::Data<Dispatcher>,
46
+ ActixAccess(access): ActixAccess,
47
+ ) -> impl Future<Output = HttpResponse> {
48
+ // Not a collection level request.
49
+ let pass = new_unchecked_verification_pass();
50
+
51
+ helpers::time(async move {
52
+ access.check_global_access(AccessRequirements::new().manage())?;
53
+ dispatcher.toc(&access, &pass).request_snapshot()?;
54
+ Ok(true)
55
+ })
56
+ }
57
+
58
+ #[delete("/cluster/peer/{peer_id}")]
59
+ fn remove_peer(
60
+ dispatcher: web::Data<Dispatcher>,
61
+ peer_id: web::Path<u64>,
62
+ Query(params): Query<QueryParams>,
63
+ ActixAccess(access): ActixAccess,
64
+ ) -> impl Future<Output = HttpResponse> {
65
+ // Not a collection level request.
66
+ let pass = new_unchecked_verification_pass();
67
+
68
+ helpers::time(async move {
69
+ access.check_global_access(AccessRequirements::new().manage())?;
70
+
71
+ let dispatcher = dispatcher.into_inner();
72
+ let toc = dispatcher.toc(&access, &pass);
73
+ let peer_id = peer_id.into_inner();
74
+
75
+ let has_shards = toc.peer_has_shards(peer_id).await;
76
+ if !params.force && has_shards {
77
+ return Err(StorageError::BadRequest {
78
+ description: format!("Cannot remove peer {peer_id} as there are shards on it"),
79
+ });
80
+ }
81
+
82
+ match dispatcher.consensus_state() {
83
+ Some(consensus_state) => {
84
+ consensus_state
85
+ .propose_consensus_op_with_await(
86
+ ConsensusOperations::RemovePeer(peer_id),
87
+ params.timeout.map(std::time::Duration::from_secs),
88
+ )
89
+ .await
90
+ }
91
+ None => Err(StorageError::BadRequest {
92
+ description: "Distributed mode disabled.".to_string(),
93
+ }),
94
+ }
95
+ })
96
+ }
97
+
98
+ #[get("/cluster/metadata/keys")]
99
+ async fn get_cluster_metadata_keys(
100
+ dispatcher: web::Data<Dispatcher>,
101
+ ActixAccess(access): ActixAccess,
102
+ ) -> HttpResponse {
103
+ helpers::time(async move {
104
+ access.check_global_access(AccessRequirements::new())?;
105
+
106
+ let keys = dispatcher
107
+ .consensus_state()
108
+ .ok_or_else(|| StorageError::service_error("Qdrant is running in standalone mode"))?
109
+ .persistent
110
+ .read()
111
+ .get_cluster_metadata_keys();
112
+
113
+ Ok(keys)
114
+ })
115
+ .await
116
+ }
117
+
118
+ #[get("/cluster/metadata/keys/{key}")]
119
+ async fn get_cluster_metadata_key(
120
+ dispatcher: web::Data<Dispatcher>,
121
+ ActixAccess(access): ActixAccess,
122
+ key: web::Path<String>,
123
+ ) -> HttpResponse {
124
+ helpers::time(async move {
125
+ access.check_global_access(AccessRequirements::new())?;
126
+
127
+ let value = dispatcher
128
+ .consensus_state()
129
+ .ok_or_else(|| StorageError::service_error("Qdrant is running in standalone mode"))?
130
+ .persistent
131
+ .read()
132
+ .get_cluster_metadata_key(key.as_ref());
133
+
134
+ Ok(value)
135
+ })
136
+ .await
137
+ }
138
+
139
+ #[put("/cluster/metadata/keys/{key}")]
140
+ async fn update_cluster_metadata_key(
141
+ dispatcher: web::Data<Dispatcher>,
142
+ ActixAccess(access): ActixAccess,
143
+ key: web::Path<String>,
144
+ params: Query<MetadataParams>,
145
+ value: web::Json<serde_json::Value>,
146
+ ) -> HttpResponse {
147
+ // Not a collection level request.
148
+ let pass = new_unchecked_verification_pass();
149
+ helpers::time(async move {
150
+ let toc = dispatcher.toc(&access, &pass);
151
+ access.check_global_access(AccessRequirements::new().write())?;
152
+
153
+ toc.update_cluster_metadata(key.into_inner(), value.into_inner(), params.wait)
154
+ .await?;
155
+ Ok(true)
156
+ })
157
+ .await
158
+ }
159
+
160
+ #[delete("/cluster/metadata/keys/{key}")]
161
+ async fn delete_cluster_metadata_key(
162
+ dispatcher: web::Data<Dispatcher>,
163
+ ActixAccess(access): ActixAccess,
164
+ key: web::Path<String>,
165
+ params: Query<MetadataParams>,
166
+ ) -> HttpResponse {
167
+ // Not a collection level request.
168
+ let pass = new_unchecked_verification_pass();
169
+ helpers::time(async move {
170
+ let toc = dispatcher.toc(&access, &pass);
171
+ access.check_global_access(AccessRequirements::new().write())?;
172
+
173
+ toc.update_cluster_metadata(key.into_inner(), serde_json::Value::Null, params.wait)
174
+ .await?;
175
+ Ok(true)
176
+ })
177
+ .await
178
+ }
179
+
180
+ // Configure services
181
+ pub fn config_cluster_api(cfg: &mut web::ServiceConfig) {
182
+ cfg.service(cluster_status)
183
+ .service(remove_peer)
184
+ .service(recover_current_peer)
185
+ .service(get_cluster_metadata_keys)
186
+ .service(get_cluster_metadata_key)
187
+ .service(update_cluster_metadata_key)
188
+ .service(delete_cluster_metadata_key);
189
+ }
src/actix/api/collections_api.rs ADDED
@@ -0,0 +1,256 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ use std::time::Duration;
2
+
3
+ use actix_web::rt::time::Instant;
4
+ use actix_web::{delete, get, patch, post, put, web, HttpResponse, Responder};
5
+ use actix_web_validator::{Json, Path, Query};
6
+ use collection::operations::cluster_ops::ClusterOperations;
7
+ use collection::operations::verification::new_unchecked_verification_pass;
8
+ use serde::Deserialize;
9
+ use storage::content_manager::collection_meta_ops::{
10
+ ChangeAliasesOperation, CollectionMetaOperations, CreateCollection, CreateCollectionOperation,
11
+ DeleteCollectionOperation, UpdateCollection, UpdateCollectionOperation,
12
+ };
13
+ use storage::dispatcher::Dispatcher;
14
+ use validator::Validate;
15
+
16
+ use super::CollectionPath;
17
+ use crate::actix::api::StrictCollectionPath;
18
+ use crate::actix::auth::ActixAccess;
19
+ use crate::actix::helpers::{self, process_response};
20
+ use crate::common::collections::*;
21
+
22
+ #[derive(Debug, Deserialize, Validate)]
23
+ pub struct WaitTimeout {
24
+ #[validate(range(min = 1))]
25
+ timeout: Option<u64>,
26
+ }
27
+
28
+ impl WaitTimeout {
29
+ pub fn timeout(&self) -> Option<Duration> {
30
+ self.timeout.map(Duration::from_secs)
31
+ }
32
+ }
33
+
34
+ #[get("/collections")]
35
+ async fn get_collections(
36
+ dispatcher: web::Data<Dispatcher>,
37
+ ActixAccess(access): ActixAccess,
38
+ ) -> HttpResponse {
39
+ // No request to verify
40
+ let pass = new_unchecked_verification_pass();
41
+
42
+ helpers::time(do_list_collections(dispatcher.toc(&access, &pass), access)).await
43
+ }
44
+
45
+ #[get("/aliases")]
46
+ async fn get_aliases(
47
+ dispatcher: web::Data<Dispatcher>,
48
+ ActixAccess(access): ActixAccess,
49
+ ) -> HttpResponse {
50
+ // No request to verify
51
+ let pass = new_unchecked_verification_pass();
52
+
53
+ helpers::time(do_list_aliases(dispatcher.toc(&access, &pass), access)).await
54
+ }
55
+
56
+ #[get("/collections/{name}")]
57
+ async fn get_collection(
58
+ dispatcher: web::Data<Dispatcher>,
59
+ collection: Path<CollectionPath>,
60
+ ActixAccess(access): ActixAccess,
61
+ ) -> HttpResponse {
62
+ // No request to verify
63
+ let pass = new_unchecked_verification_pass();
64
+
65
+ helpers::time(do_get_collection(
66
+ dispatcher.toc(&access, &pass),
67
+ access,
68
+ &collection.name,
69
+ None,
70
+ ))
71
+ .await
72
+ }
73
+
74
+ #[get("/collections/{name}/exists")]
75
+ async fn get_collection_existence(
76
+ dispatcher: web::Data<Dispatcher>,
77
+ collection: Path<CollectionPath>,
78
+ ActixAccess(access): ActixAccess,
79
+ ) -> HttpResponse {
80
+ // No request to verify
81
+ let pass = new_unchecked_verification_pass();
82
+
83
+ helpers::time(do_collection_exists(
84
+ dispatcher.toc(&access, &pass),
85
+ access,
86
+ &collection.name,
87
+ ))
88
+ .await
89
+ }
90
+
91
+ #[get("/collections/{name}/aliases")]
92
+ async fn get_collection_aliases(
93
+ dispatcher: web::Data<Dispatcher>,
94
+ collection: Path<CollectionPath>,
95
+ ActixAccess(access): ActixAccess,
96
+ ) -> HttpResponse {
97
+ // No request to verify
98
+ let pass = new_unchecked_verification_pass();
99
+
100
+ helpers::time(do_list_collection_aliases(
101
+ dispatcher.toc(&access, &pass),
102
+ access,
103
+ &collection.name,
104
+ ))
105
+ .await
106
+ }
107
+
108
+ #[put("/collections/{name}")]
109
+ async fn create_collection(
110
+ dispatcher: web::Data<Dispatcher>,
111
+ collection: Path<StrictCollectionPath>,
112
+ operation: Json<CreateCollection>,
113
+ Query(query): Query<WaitTimeout>,
114
+ ActixAccess(access): ActixAccess,
115
+ ) -> HttpResponse {
116
+ helpers::time(dispatcher.submit_collection_meta_op(
117
+ CollectionMetaOperations::CreateCollection(CreateCollectionOperation::new(
118
+ collection.name.clone(),
119
+ operation.into_inner(),
120
+ )),
121
+ access,
122
+ query.timeout(),
123
+ ))
124
+ .await
125
+ }
126
+
127
+ #[patch("/collections/{name}")]
128
+ async fn update_collection(
129
+ dispatcher: web::Data<Dispatcher>,
130
+ collection: Path<CollectionPath>,
131
+ operation: Json<UpdateCollection>,
132
+ Query(query): Query<WaitTimeout>,
133
+ ActixAccess(access): ActixAccess,
134
+ ) -> impl Responder {
135
+ let timing = Instant::now();
136
+ let name = collection.name.clone();
137
+ let response = dispatcher
138
+ .submit_collection_meta_op(
139
+ CollectionMetaOperations::UpdateCollection(UpdateCollectionOperation::new(
140
+ name,
141
+ operation.into_inner(),
142
+ )),
143
+ access,
144
+ query.timeout(),
145
+ )
146
+ .await;
147
+ process_response(response, timing, None)
148
+ }
149
+
150
+ #[delete("/collections/{name}")]
151
+ async fn delete_collection(
152
+ dispatcher: web::Data<Dispatcher>,
153
+ collection: Path<CollectionPath>,
154
+ Query(query): Query<WaitTimeout>,
155
+ ActixAccess(access): ActixAccess,
156
+ ) -> impl Responder {
157
+ let timing = Instant::now();
158
+ let response = dispatcher
159
+ .submit_collection_meta_op(
160
+ CollectionMetaOperations::DeleteCollection(DeleteCollectionOperation(
161
+ collection.name.clone(),
162
+ )),
163
+ access,
164
+ query.timeout(),
165
+ )
166
+ .await;
167
+ process_response(response, timing, None)
168
+ }
169
+
170
+ #[post("/collections/aliases")]
171
+ async fn update_aliases(
172
+ dispatcher: web::Data<Dispatcher>,
173
+ operation: Json<ChangeAliasesOperation>,
174
+ Query(query): Query<WaitTimeout>,
175
+ ActixAccess(access): ActixAccess,
176
+ ) -> impl Responder {
177
+ let timing = Instant::now();
178
+ let response = dispatcher
179
+ .submit_collection_meta_op(
180
+ CollectionMetaOperations::ChangeAliases(operation.0),
181
+ access,
182
+ query.timeout(),
183
+ )
184
+ .await;
185
+ process_response(response, timing, None)
186
+ }
187
+
188
+ #[get("/collections/{name}/cluster")]
189
+ async fn get_cluster_info(
190
+ dispatcher: web::Data<Dispatcher>,
191
+ collection: Path<CollectionPath>,
192
+ ActixAccess(access): ActixAccess,
193
+ ) -> impl Responder {
194
+ // No request to verify
195
+ let pass = new_unchecked_verification_pass();
196
+
197
+ helpers::time(do_get_collection_cluster(
198
+ dispatcher.toc(&access, &pass),
199
+ access,
200
+ &collection.name,
201
+ ))
202
+ .await
203
+ }
204
+
205
+ #[post("/collections/{name}/cluster")]
206
+ async fn update_collection_cluster(
207
+ dispatcher: web::Data<Dispatcher>,
208
+ collection: Path<CollectionPath>,
209
+ operation: Json<ClusterOperations>,
210
+ Query(query): Query<WaitTimeout>,
211
+ ActixAccess(access): ActixAccess,
212
+ ) -> impl Responder {
213
+ let timing = Instant::now();
214
+ let wait_timeout = query.timeout();
215
+ let response = do_update_collection_cluster(
216
+ &dispatcher.into_inner(),
217
+ collection.name.clone(),
218
+ operation.0,
219
+ access,
220
+ wait_timeout,
221
+ )
222
+ .await;
223
+ process_response(response, timing, None)
224
+ }
225
+
226
+ // Configure services
227
+ pub fn config_collections_api(cfg: &mut web::ServiceConfig) {
228
+ // Ordering of services is important for correct path pattern matching
229
+ // See: <https://github.com/qdrant/qdrant/issues/3543>
230
+ cfg.service(update_aliases)
231
+ .service(get_collections)
232
+ .service(get_collection)
233
+ .service(get_collection_existence)
234
+ .service(create_collection)
235
+ .service(update_collection)
236
+ .service(delete_collection)
237
+ .service(get_aliases)
238
+ .service(get_collection_aliases)
239
+ .service(get_cluster_info)
240
+ .service(update_collection_cluster);
241
+ }
242
+
243
#[cfg(test)]
mod tests {
    use actix_web::web::Query;

    use super::WaitTimeout;

    // `timeout` must deserialize both when absent and when given in seconds.
    #[test]
    fn timeout_is_deserialized() {
        let empty: WaitTimeout = Query::from_query("").unwrap().0;
        assert!(empty.timeout.is_none());
        let ten_seconds: WaitTimeout = Query::from_query("timeout=10").unwrap().0;
        assert_eq!(ten_seconds.timeout, Some(10))
    }
}
src/actix/api/count_api.rs ADDED
@@ -0,0 +1,69 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ use actix_web::{post, web, Responder};
2
+ use actix_web_validator::{Json, Path, Query};
3
+ use collection::operations::shard_selector_internal::ShardSelectorInternal;
4
+ use collection::operations::types::CountRequest;
5
+ use storage::content_manager::collection_verification::check_strict_mode;
6
+ use storage::dispatcher::Dispatcher;
7
+ use tokio::time::Instant;
8
+
9
+ use super::CollectionPath;
10
+ use crate::actix::api::read_params::ReadParams;
11
+ use crate::actix::auth::ActixAccess;
12
+ use crate::actix::helpers::{self, get_request_hardware_counter, process_response_error};
13
+ use crate::common::points::do_count_points;
14
+ use crate::settings::ServiceConfig;
15
+
16
+ #[post("/collections/{name}/points/count")]
17
+ async fn count_points(
18
+ dispatcher: web::Data<Dispatcher>,
19
+ collection: Path<CollectionPath>,
20
+ request: Json<CountRequest>,
21
+ params: Query<ReadParams>,
22
+ service_config: web::Data<ServiceConfig>,
23
+ ActixAccess(access): ActixAccess,
24
+ ) -> impl Responder {
25
+ let CountRequest {
26
+ count_request,
27
+ shard_key,
28
+ } = request.into_inner();
29
+
30
+ let pass = match check_strict_mode(
31
+ &count_request,
32
+ params.timeout_as_secs(),
33
+ &collection.name,
34
+ &dispatcher,
35
+ &access,
36
+ )
37
+ .await
38
+ {
39
+ Ok(pass) => pass,
40
+ Err(err) => return process_response_error(err, Instant::now(), None),
41
+ };
42
+
43
+ let shard_selector = match shard_key {
44
+ None => ShardSelectorInternal::All,
45
+ Some(shard_keys) => ShardSelectorInternal::from(shard_keys),
46
+ };
47
+
48
+ let request_hw_counter = get_request_hardware_counter(
49
+ &dispatcher,
50
+ collection.name.clone(),
51
+ service_config.hardware_reporting(),
52
+ );
53
+
54
+ let timing = Instant::now();
55
+
56
+ let result = do_count_points(
57
+ dispatcher.toc(&access, &pass),
58
+ &collection.name,
59
+ count_request,
60
+ params.consistency,
61
+ params.timeout(),
62
+ shard_selector,
63
+ access,
64
+ request_hw_counter.get_counter(),
65
+ )
66
+ .await;
67
+
68
+ helpers::process_response(result, timing, request_hw_counter.to_rest_api())
69
+ }
src/actix/api/debug_api.rs ADDED
@@ -0,0 +1,36 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ use actix_web::{get, patch, web, Responder};
2
+ use storage::rbac::AccessRequirements;
3
+
4
+ use crate::actix::auth::ActixAccess;
5
+ use crate::common::debugger::{DebugConfigPatch, DebuggerState};
6
+
7
+ #[get("/debugger")]
8
+ async fn get_debugger_config(
9
+ ActixAccess(access): ActixAccess,
10
+ debugger_state: web::Data<DebuggerState>,
11
+ ) -> impl Responder {
12
+ crate::actix::helpers::time(async move {
13
+ access.check_global_access(AccessRequirements::new().manage())?;
14
+ Ok(debugger_state.get_config())
15
+ })
16
+ .await
17
+ }
18
+
19
+ #[patch("/debugger")]
20
+ async fn update_debugger_config(
21
+ ActixAccess(access): ActixAccess,
22
+ debugger_state: web::Data<DebuggerState>,
23
+ debug_patch: web::Json<DebugConfigPatch>,
24
+ ) -> impl Responder {
25
+ crate::actix::helpers::time(async move {
26
+ access.check_global_access(AccessRequirements::new().manage())?;
27
+ Ok(debugger_state.apply_config_patch(debug_patch.into_inner()))
28
+ })
29
+ .await
30
+ }
31
+
32
+ // Configure services
33
+ pub fn config_debugger_api(cfg: &mut web::ServiceConfig) {
34
+ cfg.service(get_debugger_config);
35
+ cfg.service(update_debugger_config);
36
+ }
src/actix/api/discovery_api.rs ADDED
@@ -0,0 +1,140 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ use actix_web::{post, web, Responder};
2
+ use actix_web_validator::{Json, Path, Query};
3
+ use collection::operations::shard_selector_internal::ShardSelectorInternal;
4
+ use collection::operations::types::{DiscoverRequest, DiscoverRequestBatch};
5
+ use itertools::Itertools;
6
+ use storage::content_manager::collection_verification::{
7
+ check_strict_mode, check_strict_mode_batch,
8
+ };
9
+ use storage::dispatcher::Dispatcher;
10
+ use tokio::time::Instant;
11
+
12
+ use crate::actix::api::read_params::ReadParams;
13
+ use crate::actix::api::CollectionPath;
14
+ use crate::actix::auth::ActixAccess;
15
+ use crate::actix::helpers::{self, get_request_hardware_counter, process_response_error};
16
+ use crate::common::points::do_discover_batch_points;
17
+ use crate::settings::ServiceConfig;
18
+
19
+ #[post("/collections/{name}/points/discover")]
20
+ async fn discover_points(
21
+ dispatcher: web::Data<Dispatcher>,
22
+ collection: Path<CollectionPath>,
23
+ request: Json<DiscoverRequest>,
24
+ params: Query<ReadParams>,
25
+ service_config: web::Data<ServiceConfig>,
26
+ ActixAccess(access): ActixAccess,
27
+ ) -> impl Responder {
28
+ let DiscoverRequest {
29
+ discover_request,
30
+ shard_key,
31
+ } = request.into_inner();
32
+
33
+ let pass = match check_strict_mode(
34
+ &discover_request,
35
+ params.timeout_as_secs(),
36
+ &collection.name,
37
+ &dispatcher,
38
+ &access,
39
+ )
40
+ .await
41
+ {
42
+ Ok(pass) => pass,
43
+ Err(err) => return process_response_error(err, Instant::now(), None),
44
+ };
45
+
46
+ let shard_selection = match shard_key {
47
+ None => ShardSelectorInternal::All,
48
+ Some(shard_keys) => shard_keys.into(),
49
+ };
50
+
51
+ let request_hw_counter = get_request_hardware_counter(
52
+ &dispatcher,
53
+ collection.name.clone(),
54
+ service_config.hardware_reporting(),
55
+ );
56
+
57
+ let timing = Instant::now();
58
+
59
+ let result = dispatcher
60
+ .toc(&access, &pass)
61
+ .discover(
62
+ &collection.name,
63
+ discover_request,
64
+ params.consistency,
65
+ shard_selection,
66
+ access,
67
+ params.timeout(),
68
+ request_hw_counter.get_counter(),
69
+ )
70
+ .await
71
+ .map(|scored_points| {
72
+ scored_points
73
+ .into_iter()
74
+ .map(api::rest::ScoredPoint::from)
75
+ .collect_vec()
76
+ });
77
+
78
+ helpers::process_response(result, timing, request_hw_counter.to_rest_api())
79
+ }
80
+
81
+ #[post("/collections/{name}/points/discover/batch")]
82
+ async fn discover_batch_points(
83
+ dispatcher: web::Data<Dispatcher>,
84
+ collection: Path<CollectionPath>,
85
+ request: Json<DiscoverRequestBatch>,
86
+ params: Query<ReadParams>,
87
+ service_config: web::Data<ServiceConfig>,
88
+ ActixAccess(access): ActixAccess,
89
+ ) -> impl Responder {
90
+ let request = request.into_inner();
91
+
92
+ let pass = match check_strict_mode_batch(
93
+ request.searches.iter().map(|i| &i.discover_request),
94
+ params.timeout_as_secs(),
95
+ &collection.name,
96
+ &dispatcher,
97
+ &access,
98
+ )
99
+ .await
100
+ {
101
+ Ok(pass) => pass,
102
+ Err(err) => return process_response_error(err, Instant::now(), None),
103
+ };
104
+
105
+ let request_hw_counter = get_request_hardware_counter(
106
+ &dispatcher,
107
+ collection.name.clone(),
108
+ service_config.hardware_reporting(),
109
+ );
110
+ let timing = Instant::now();
111
+
112
+ let result = do_discover_batch_points(
113
+ dispatcher.toc(&access, &pass),
114
+ &collection.name,
115
+ request,
116
+ params.consistency,
117
+ access,
118
+ params.timeout(),
119
+ request_hw_counter.get_counter(),
120
+ )
121
+ .await
122
+ .map(|batch_scored_points| {
123
+ batch_scored_points
124
+ .into_iter()
125
+ .map(|scored_points| {
126
+ scored_points
127
+ .into_iter()
128
+ .map(api::rest::ScoredPoint::from)
129
+ .collect_vec()
130
+ })
131
+ .collect_vec()
132
+ });
133
+
134
+ helpers::process_response(result, timing, request_hw_counter.to_rest_api())
135
+ }
136
+
137
+ pub fn config_discovery_api(cfg: &mut web::ServiceConfig) {
138
+ cfg.service(discover_points);
139
+ cfg.service(discover_batch_points);
140
+ }
src/actix/api/facet_api.rs ADDED
@@ -0,0 +1,77 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ use actix_web::{post, web, Responder};
2
+ use actix_web_validator::{Json, Path, Query};
3
+ use api::rest::{FacetRequest, FacetResponse};
4
+ use collection::operations::shard_selector_internal::ShardSelectorInternal;
5
+ use storage::content_manager::collection_verification::check_strict_mode;
6
+ use storage::dispatcher::Dispatcher;
7
+ use tokio::time::Instant;
8
+
9
+ use crate::actix::api::read_params::ReadParams;
10
+ use crate::actix::api::CollectionPath;
11
+ use crate::actix::auth::ActixAccess;
12
+ use crate::actix::helpers::{
13
+ get_request_hardware_counter, process_response, process_response_error,
14
+ };
15
+ use crate::settings::ServiceConfig;
16
+
17
+ #[post("/collections/{name}/facet")]
18
+ async fn facet(
19
+ dispatcher: web::Data<Dispatcher>,
20
+ collection: Path<CollectionPath>,
21
+ request: Json<FacetRequest>,
22
+ params: Query<ReadParams>,
23
+ service_config: web::Data<ServiceConfig>,
24
+ ActixAccess(access): ActixAccess,
25
+ ) -> impl Responder {
26
+ let timing = Instant::now();
27
+
28
+ let FacetRequest {
29
+ facet_request,
30
+ shard_key,
31
+ } = request.into_inner();
32
+
33
+ let pass = match check_strict_mode(
34
+ &facet_request,
35
+ params.timeout_as_secs(),
36
+ &collection.name,
37
+ &dispatcher,
38
+ &access,
39
+ )
40
+ .await
41
+ {
42
+ Ok(pass) => pass,
43
+ Err(err) => return process_response_error(err, timing, None),
44
+ };
45
+
46
+ let facet_params = From::from(facet_request);
47
+
48
+ let shard_selection = match shard_key {
49
+ None => ShardSelectorInternal::All,
50
+ Some(shard_keys) => shard_keys.into(),
51
+ };
52
+
53
+ let request_hw_counter = get_request_hardware_counter(
54
+ &dispatcher,
55
+ collection.name.clone(),
56
+ service_config.hardware_reporting(),
57
+ );
58
+
59
+ let response = dispatcher
60
+ .toc(&access, &pass)
61
+ .facet(
62
+ &collection.name,
63
+ facet_params,
64
+ shard_selection,
65
+ params.consistency,
66
+ access,
67
+ params.timeout(),
68
+ )
69
+ .await
70
+ .map(FacetResponse::from);
71
+
72
+ process_response(response, timing, request_hw_counter.to_rest_api())
73
+ }
74
+
75
+ pub fn config_facet_api(cfg: &mut web::ServiceConfig) {
76
+ cfg.service(facet);
77
+ }
src/actix/api/issues_api.rs ADDED
@@ -0,0 +1,32 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ use actix_web::{delete, get, web, Responder};
2
+ use collection::operations::types::IssuesReport;
3
+ use storage::rbac::AccessRequirements;
4
+
5
+ use crate::actix::auth::ActixAccess;
6
+
7
+ #[get("/issues")]
8
+ async fn get_issues(ActixAccess(access): ActixAccess) -> impl Responder {
9
+ crate::actix::helpers::time(async move {
10
+ access.check_global_access(AccessRequirements::new().manage())?;
11
+ Ok(IssuesReport {
12
+ issues: issues::all_issues(),
13
+ })
14
+ })
15
+ .await
16
+ }
17
+
18
+ #[delete("/issues")]
19
+ async fn clear_issues(ActixAccess(access): ActixAccess) -> impl Responder {
20
+ crate::actix::helpers::time(async move {
21
+ access.check_global_access(AccessRequirements::new().manage())?;
22
+ issues::clear();
23
+ Ok(true)
24
+ })
25
+ .await
26
+ }
27
+
28
+ // Configure services
29
+ pub fn config_issues_api(cfg: &mut web::ServiceConfig) {
30
+ cfg.service(get_issues);
31
+ cfg.service(clear_issues);
32
+ }
src/actix/api/local_shard_api.rs ADDED
@@ -0,0 +1,267 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ use std::sync::Arc;
2
+
3
+ use actix_web::{post, web, Responder};
4
+ use collection::operations::shard_selector_internal::ShardSelectorInternal;
5
+ use collection::operations::types::{
6
+ CountRequestInternal, PointRequestInternal, ScrollRequestInternal,
7
+ };
8
+ use collection::operations::verification::{new_unchecked_verification_pass, VerificationPass};
9
+ use collection::shards::shard::ShardId;
10
+ use segment::types::{Condition, Filter};
11
+ use storage::content_manager::collection_verification::check_strict_mode;
12
+ use storage::content_manager::errors::{StorageError, StorageResult};
13
+ use storage::dispatcher::Dispatcher;
14
+ use storage::rbac::{Access, AccessRequirements};
15
+ use tokio::time::Instant;
16
+
17
+ use crate::actix::api::read_params::ReadParams;
18
+ use crate::actix::auth::ActixAccess;
19
+ use crate::actix::helpers::{self, get_request_hardware_counter, process_response_error};
20
+ use crate::common::points;
21
+ use crate::settings::ServiceConfig;
22
+
23
+ // Configure services
24
+ pub fn config_local_shard_api(cfg: &mut web::ServiceConfig) {
25
+ cfg.service(get_points)
26
+ .service(scroll_points)
27
+ .service(count_points)
28
+ .service(cleanup_shard);
29
+ }
30
+
31
/// Retrieve points by id directly from a single local shard of a collection.
#[post("/collections/{collection}/shards/{shard}/points")]
async fn get_points(
    dispatcher: web::Data<Dispatcher>,
    ActixAccess(access): ActixAccess,
    path: web::Path<CollectionShard>,
    request: web::Json<PointRequestInternal>,
    params: web::Query<ReadParams>,
) -> impl Responder {
    // No strict mode verification needed
    let pass = new_unchecked_verification_pass();

    helpers::time(async move {
        let records = points::do_get_points(
            dispatcher.toc(&access, &pass),
            &path.collection,
            request.into_inner(),
            params.consistency,
            params.timeout(),
            // Target exactly the shard given in the path.
            ShardSelectorInternal::ShardId(path.shard),
            access,
        )
        .await?;

        // Convert internal records into the REST representation.
        let records: Vec<_> = records.into_iter().map(api::rest::Record::from).collect();
        Ok(records)
    })
    .await
}
59
+
60
/// Scroll points of a single local shard, optionally constrained by a
/// hash-ring filter carried in the request body.
#[post("/collections/{collection}/shards/{shard}/points/scroll")]
async fn scroll_points(
    dispatcher: web::Data<Dispatcher>,
    ActixAccess(access): ActixAccess,
    path: web::Path<CollectionShard>,
    request: web::Json<WithFilter<ScrollRequestInternal>>,
    params: web::Query<ReadParams>,
) -> impl Responder {
    let WithFilter {
        mut request,
        hash_ring_filter,
    } = request.into_inner();

    // Strict-mode check must pass before the request is executed.
    let pass = match check_strict_mode(
        &request,
        params.timeout_as_secs(),
        &path.collection,
        &dispatcher,
        &access,
    )
    .await
    {
        Ok(pass) => pass,
        Err(err) => return process_response_error(err, Instant::now(), None),
    };

    helpers::time(async move {
        // Resolve the optional hash-ring filter into a concrete `Filter`.
        let hash_ring_filter = match hash_ring_filter {
            Some(filter) => get_hash_ring_filter(
                &dispatcher,
                &access,
                &path.collection,
                AccessRequirements::new(),
                filter.expected_shard_id,
                &pass,
            )
            .await?
            .into(),

            None => None,
        };

        // Combine the user's filter with the hash-ring constraint.
        request.filter = merge_with_optional_filter(request.filter.take(), hash_ring_filter);

        dispatcher
            .toc(&access, &pass)
            .scroll(
                &path.collection,
                request,
                params.consistency,
                params.timeout(),
                ShardSelectorInternal::ShardId(path.shard),
                access,
            )
            .await
    })
    .await
}
118
+
119
/// Count points of a single local shard, optionally constrained by a
/// hash-ring filter. Hardware usage is measured and optionally reported.
#[post("/collections/{collection}/shards/{shard}/points/count")]
async fn count_points(
    dispatcher: web::Data<Dispatcher>,
    ActixAccess(access): ActixAccess,
    path: web::Path<CollectionShard>,
    request: web::Json<WithFilter<CountRequestInternal>>,
    params: web::Query<ReadParams>,
    service_config: web::Data<ServiceConfig>,
) -> impl Responder {
    let WithFilter {
        mut request,
        hash_ring_filter,
    } = request.into_inner();

    // Strict-mode check must pass before the request is executed.
    let pass = match check_strict_mode(
        &request,
        params.timeout_as_secs(),
        &path.collection,
        &dispatcher,
        &access,
    )
    .await
    {
        Ok(pass) => pass,
        Err(err) => return process_response_error(err, Instant::now(), None),
    };

    // Counter is created before the timer so the response can report
    // measured hardware usage alongside the elapsed time.
    let request_hw_counter = get_request_hardware_counter(
        &dispatcher,
        path.collection.clone(),
        service_config.hardware_reporting(),
    );
    let timing = Instant::now();
    let hw_measurement_acc = request_hw_counter.get_counter();

    let result = async move {
        // Resolve the optional hash-ring filter into a concrete `Filter`.
        let hash_ring_filter = match hash_ring_filter {
            Some(filter) => get_hash_ring_filter(
                &dispatcher,
                &access,
                &path.collection,
                AccessRequirements::new(),
                filter.expected_shard_id,
                &pass,
            )
            .await?
            .into(),

            None => None,
        };

        // Combine the user's filter with the hash-ring constraint.
        request.filter = merge_with_optional_filter(request.filter.take(), hash_ring_filter);

        points::do_count_points(
            dispatcher.toc(&access, &pass),
            &path.collection,
            request,
            params.consistency,
            params.timeout(),
            ShardSelectorInternal::ShardId(path.shard),
            access,
            hw_measurement_acc,
        )
        .await
    }
    .await;

    helpers::process_response(result, timing, request_hw_counter.to_rest_api())
}
188
+
189
+ #[post("/collections/{collection}/shards/{shard}/cleanup")]
190
+ async fn cleanup_shard(
191
+ dispatcher: web::Data<Dispatcher>,
192
+ ActixAccess(access): ActixAccess,
193
+ path: web::Path<CollectionShard>,
194
+ ) -> impl Responder {
195
+ // Nothing to verify here.
196
+ let pass = new_unchecked_verification_pass();
197
+
198
+ helpers::time(async move {
199
+ let path = path.into_inner();
200
+ dispatcher
201
+ .toc(&access, &pass)
202
+ .cleanup_local_shard(&path.collection, path.shard, access)
203
+ .await
204
+ })
205
+ .await
206
+ }
207
+
208
/// Path parameters identifying one shard of one collection.
#[derive(serde::Deserialize, validator::Validate)]
struct CollectionShard {
    // Basic length validation only; stricter name rules apply elsewhere.
    #[validate(length(min = 1, max = 255))]
    collection: String,
    shard: ShardId,
}
214
+
215
/// Request-body wrapper carrying an optional hash-ring filter
/// alongside the inner (flattened) request payload.
#[derive(Clone, Debug, serde::Deserialize)]
struct WithFilter<T> {
    #[serde(flatten)]
    request: T,
    #[serde(default)]
    hash_ring_filter: Option<SerdeHelper>,
}

/// Deserialization helper for the `hash_ring_filter` field.
#[derive(Clone, Debug, serde::Deserialize)]
struct SerdeHelper {
    expected_shard_id: ShardId,
}
227
+
228
/// Build a `Filter` that only matches points whose ids map onto
/// `expected_shard_id` according to the collection's hash ring.
///
/// Returns a `bad_request` error when the shard id is not part of the
/// collection's hash ring.
async fn get_hash_ring_filter(
    dispatcher: &Dispatcher,
    access: &Access,
    collection: &str,
    reqs: AccessRequirements,
    expected_shard_id: ShardId,
    verification_pass: &VerificationPass,
) -> StorageResult<Filter> {
    let pass = access.check_collection_access(collection, reqs)?;

    let shard_holder = dispatcher
        .toc(access, verification_pass)
        .get_collection(&pass)
        .await?
        .shards_holder();

    let hash_ring_filter = shard_holder
        .read()
        .await
        .hash_ring_filter(expected_shard_id)
        .ok_or_else(|| {
            StorageError::bad_request(format!(
                "shard {expected_shard_id} does not exist in collection {collection}"
            ))
        })?;

    // Wrap the hash-ring check into a custom id-checker condition.
    let condition = Condition::CustomIdChecker(Arc::new(hash_ring_filter));
    let filter = Filter::new_must(condition);

    Ok(filter)
}
259
+
260
+ fn merge_with_optional_filter(filter: Option<Filter>, hash_ring: Option<Filter>) -> Option<Filter> {
261
+ match (filter, hash_ring) {
262
+ (Some(filter), Some(hash_ring)) => hash_ring.merge_owned(filter).into(),
263
+ (Some(filter), None) => filter.into(),
264
+ (None, Some(hash_ring)) => hash_ring.into(),
265
+ _ => None,
266
+ }
267
+ }
src/actix/api/mod.rs ADDED
@@ -0,0 +1,46 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ use common::validation::validate_collection_name;
2
+ use serde::Deserialize;
3
+ use validator::Validate;
4
+
5
+ pub mod cluster_api;
6
+ pub mod collections_api;
7
+ pub mod count_api;
8
+ pub mod debug_api;
9
+ pub mod discovery_api;
10
+ pub mod facet_api;
11
+ pub mod issues_api;
12
+ pub mod local_shard_api;
13
+ pub mod query_api;
14
+ pub mod read_params;
15
+ pub mod recommend_api;
16
+ pub mod retrieve_api;
17
+ pub mod search_api;
18
+ pub mod service_api;
19
+ pub mod shards_api;
20
+ pub mod snapshot_api;
21
+ pub mod update_api;
22
+
23
/// A collection path with stricter validation
///
/// Validation for collection paths has been made more strict over time.
/// To prevent breaking changes on existing collections, this is only enforced for newly created
/// collections. Basic validation is enforced everywhere else.
#[derive(Deserialize, Validate)]
struct StrictCollectionPath {
    // Length limits plus character-level checks from `validate_collection_name`.
    #[validate(
        length(min = 1, max = 255),
        custom(function = "validate_collection_name")
    )]
    name: String,
}
36
+
37
/// A collection path with basic validation
///
/// Validation for collection paths has been made more strict over time.
/// To prevent breaking changes on existing collections, this is only enforced for newly created
/// collections. Basic validation is enforced everywhere else.
#[derive(Deserialize, Validate)]
struct CollectionPath {
    // Only length limits are checked here.
    #[validate(length(min = 1, max = 255))]
    name: String,
}
src/actix/api/query_api.rs ADDED
@@ -0,0 +1,232 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ use actix_web::{post, web, Responder};
2
+ use actix_web_validator::{Json, Path, Query};
3
+ use api::rest::{QueryGroupsRequest, QueryRequest, QueryRequestBatch, QueryResponse};
4
+ use collection::operations::shard_selector_internal::ShardSelectorInternal;
5
+ use itertools::Itertools;
6
+ use storage::content_manager::collection_verification::{
7
+ check_strict_mode, check_strict_mode_batch,
8
+ };
9
+ use storage::content_manager::errors::StorageError;
10
+ use storage::dispatcher::Dispatcher;
11
+ use tokio::time::Instant;
12
+
13
+ use super::read_params::ReadParams;
14
+ use super::CollectionPath;
15
+ use crate::actix::auth::ActixAccess;
16
+ use crate::actix::helpers::{self, get_request_hardware_counter, process_response_error};
17
+ use crate::common::inference::query_requests_rest::{
18
+ convert_query_groups_request_from_rest, convert_query_request_from_rest,
19
+ };
20
+ use crate::common::points::do_query_point_groups;
21
+ use crate::settings::ServiceConfig;
22
+
23
/// Universal query endpoint: runs a single query (possibly resolving
/// inference inputs first) and returns scored points.
#[post("/collections/{name}/points/query")]
async fn query_points(
    dispatcher: web::Data<Dispatcher>,
    collection: Path<CollectionPath>,
    request: Json<QueryRequest>,
    params: Query<ReadParams>,
    service_config: web::Data<ServiceConfig>,
    ActixAccess(access): ActixAccess,
) -> impl Responder {
    let QueryRequest {
        internal: query_request,
        shard_key,
    } = request.into_inner();

    // Strict-mode check must pass before the request is executed.
    let pass = match check_strict_mode(
        &query_request,
        params.timeout_as_secs(),
        &collection.name,
        &dispatcher,
        &access,
    )
    .await
    {
        Ok(pass) => pass,
        Err(err) => return process_response_error(err, Instant::now(), None),
    };

    let request_hw_counter = get_request_hardware_counter(
        &dispatcher,
        collection.name.clone(),
        service_config.hardware_reporting(),
    );
    let timing = Instant::now();

    // No shard key means the query fans out to all shards.
    let shard_selection = match shard_key {
        None => ShardSelectorInternal::All,
        Some(shard_keys) => shard_keys.into(),
    };
    let hw_measurement_acc = request_hw_counter.get_counter();

    let result = async move {
        // May call the inference service to resolve raw inputs into vectors.
        let request = convert_query_request_from_rest(query_request).await?;

        // A single query is executed through the batch API with one entry.
        let points = dispatcher
            .toc(&access, &pass)
            .query_batch(
                &collection.name,
                vec![(request, shard_selection)],
                params.consistency,
                access,
                params.timeout(),
                hw_measurement_acc,
            )
            .await?
            .pop()
            .ok_or_else(|| {
                StorageError::service_error("Expected at least one response for one query")
            })?
            .into_iter()
            .map(api::rest::ScoredPoint::from)
            .collect_vec();

        Ok(QueryResponse { points })
    }
    .await;

    helpers::process_response(result, timing, request_hw_counter.to_rest_api())
}
91
+
92
/// Run several universal queries against a collection in one call.
#[post("/collections/{name}/points/query/batch")]
async fn query_points_batch(
    dispatcher: web::Data<Dispatcher>,
    collection: Path<CollectionPath>,
    request: Json<QueryRequestBatch>,
    params: Query<ReadParams>,
    service_config: web::Data<ServiceConfig>,
    ActixAccess(access): ActixAccess,
) -> impl Responder {
    let QueryRequestBatch { searches } = request.into_inner();

    // Strict-mode check covers every request in the batch.
    let pass = match check_strict_mode_batch(
        searches.iter().map(|i| &i.internal),
        params.timeout_as_secs(),
        &collection.name,
        &dispatcher,
        &access,
    )
    .await
    {
        Ok(pass) => pass,
        Err(err) => return process_response_error(err, Instant::now(), None),
    };

    let request_hw_counter = get_request_hardware_counter(
        &dispatcher,
        collection.name.clone(),
        service_config.hardware_reporting(),
    );
    let timing = Instant::now();
    let hw_measurement_acc = request_hw_counter.get_counter();

    let result = async move {
        // Convert each REST request (resolving inference inputs) and pair it
        // with its shard selector.
        let mut batch = Vec::with_capacity(searches.len());
        for request in searches {
            let QueryRequest {
                internal,
                shard_key,
            } = request;

            let request = convert_query_request_from_rest(internal).await?;
            let shard_selection = match shard_key {
                None => ShardSelectorInternal::All,
                Some(shard_keys) => shard_keys.into(),
            };

            batch.push((request, shard_selection));
        }

        let res = dispatcher
            .toc(&access, &pass)
            .query_batch(
                &collection.name,
                batch,
                params.consistency,
                access,
                params.timeout(),
                hw_measurement_acc,
            )
            .await?
            .into_iter()
            .map(|response| QueryResponse {
                points: response
                    .into_iter()
                    .map(api::rest::ScoredPoint::from)
                    .collect_vec(),
            })
            .collect_vec();
        Ok(res)
    }
    .await;

    helpers::process_response(result, timing, request_hw_counter.to_rest_api())
}
166
+
167
/// Run a universal query and aggregate the results into groups by a
/// payload field.
#[post("/collections/{name}/points/query/groups")]
async fn query_points_groups(
    dispatcher: web::Data<Dispatcher>,
    collection: Path<CollectionPath>,
    request: Json<QueryGroupsRequest>,
    params: Query<ReadParams>,
    service_config: web::Data<ServiceConfig>,
    ActixAccess(access): ActixAccess,
) -> impl Responder {
    let QueryGroupsRequest {
        search_group_request,
        shard_key,
    } = request.into_inner();

    // Strict-mode check must pass before the request is executed.
    let pass = match check_strict_mode(
        &search_group_request,
        params.timeout_as_secs(),
        &collection.name,
        &dispatcher,
        &access,
    )
    .await
    {
        Ok(pass) => pass,
        Err(err) => return process_response_error(err, Instant::now(), None),
    };

    let request_hw_counter = get_request_hardware_counter(
        &dispatcher,
        collection.name.clone(),
        service_config.hardware_reporting(),
    );
    let timing = Instant::now();
    let hw_measurement_acc = request_hw_counter.get_counter();

    let result = async move {
        // No shard key means the query fans out to all shards.
        let shard_selection = match shard_key {
            None => ShardSelectorInternal::All,
            Some(shard_keys) => shard_keys.into(),
        };

        // May call the inference service to resolve raw inputs into vectors.
        let query_group_request =
            convert_query_groups_request_from_rest(search_group_request).await?;

        do_query_point_groups(
            dispatcher.toc(&access, &pass),
            &collection.name,
            query_group_request,
            params.consistency,
            shard_selection,
            access,
            params.timeout(),
            hw_measurement_acc,
        )
        .await
    }
    .await;

    helpers::process_response(result, timing, request_hw_counter.to_rest_api())
}
227
+
228
+ pub fn config_query_api(cfg: &mut web::ServiceConfig) {
229
+ cfg.service(query_points);
230
+ cfg.service(query_points_batch);
231
+ cfg.service(query_points_groups);
232
+ }
src/actix/api/read_params.rs ADDED
@@ -0,0 +1,118 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ use std::num::NonZeroU64;
2
+ use std::time::Duration;
3
+
4
+ use collection::operations::consistency_params::ReadConsistency;
5
+ use schemars::JsonSchema;
6
+ use serde::Deserialize;
7
+ use validator::Validate;
8
+
9
/// Common query-string parameters shared by read endpoints.
#[derive(Copy, Clone, Debug, Default, Eq, PartialEq, Deserialize, JsonSchema, Validate)]
pub struct ReadParams {
    // An empty string value deserializes to `None`; see the custom deserializer.
    #[serde(default, deserialize_with = "deserialize_read_consistency")]
    #[validate(nested)]
    pub consistency: Option<ReadConsistency>,
    /// If set, overrides global timeout for this request. Unit is seconds.
    pub timeout: Option<NonZeroU64>,
}
17
+
18
+ impl ReadParams {
19
+ pub fn timeout(&self) -> Option<Duration> {
20
+ self.timeout.map(|num| Duration::from_secs(num.get()))
21
+ }
22
+
23
+ pub(crate) fn timeout_as_secs(&self) -> Option<usize> {
24
+ self.timeout.map(|i| i.get() as usize)
25
+ }
26
+ }
27
+
28
/// Custom deserializer for the `consistency` query parameter.
///
/// Accepts either a value that parses as `ReadConsistency` or an empty
/// string (treated as "not set"); anything else is an error.
fn deserialize_read_consistency<'de, D>(
    deserializer: D,
) -> Result<Option<ReadConsistency>, D::Error>
where
    D: serde::Deserializer<'de>,
{
    // Untagged helper: try `ReadConsistency` first, then fall back to a raw string.
    #[derive(Deserialize)]
    #[serde(untagged)]
    enum Helper<'a> {
        ReadConsistency(ReadConsistency),
        Str(&'a str),
    }

    match Helper::deserialize(deserializer)? {
        Helper::ReadConsistency(read_consistency) => Ok(Some(read_consistency)),
        // Empty string means the parameter was present but unset.
        Helper::Str("") => Ok(None),
        _ => Err(serde::de::Error::custom(
            "failed to deserialize read consistency query parameter value",
        )),
    }
}
49
+
50
#[cfg(test)]
mod test {
    use collection::operations::consistency_params::ReadConsistencyType;

    use super::*;

    #[test]
    fn deserialize_empty_string() {
        test_str("", ReadParams::default());
    }

    #[test]
    fn deserialize_empty_value() {
        test("", ReadParams::default());
    }

    #[test]
    fn deserialize_type() {
        test("all", from_type(ReadConsistencyType::All));
        test("majority", from_type(ReadConsistencyType::Majority));
        test("quorum", from_type(ReadConsistencyType::Quorum));
    }

    #[test]
    fn deserialize_factor() {
        for factor in 1..42 {
            test(&factor.to_string(), from_factor(factor));
        }
    }

    #[test]
    fn try_deserialize_factor_0() {
        // Factor 0 is invalid and must fail to deserialize.
        assert!(try_deserialize(&str("0")).is_err());
    }

    // Assert that `consistency=<value>` deserializes to `params`.
    fn test(value: &str, params: ReadParams) {
        test_str(&str(value), params);
    }

    // Assert that a raw query string deserializes to `params`.
    fn test_str(str: &str, params: ReadParams) {
        assert_eq!(deserialize(str), params);
    }

    fn deserialize(str: &str) -> ReadParams {
        try_deserialize(str).unwrap()
    }

    fn try_deserialize(str: &str) -> Result<ReadParams, serde_urlencoded::de::Error> {
        serde_urlencoded::from_str(str)
    }

    // Build a `consistency=<value>` query string.
    fn str(value: &str) -> String {
        format!("consistency={value}")
    }

    fn from_type(r#type: ReadConsistencyType) -> ReadParams {
        ReadParams {
            consistency: Some(ReadConsistency::Type(r#type)),
            ..Default::default()
        }
    }

    fn from_factor(factor: usize) -> ReadParams {
        ReadParams {
            consistency: Some(ReadConsistency::Factor(factor)),
            ..Default::default()
        }
    }
}
src/actix/api/recommend_api.rs ADDED
@@ -0,0 +1,235 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ use std::time::Duration;
2
+
3
+ use actix_web::{post, web, Responder};
4
+ use actix_web_validator::{Json, Path, Query};
5
+ use collection::operations::consistency_params::ReadConsistency;
6
+ use collection::operations::shard_selector_internal::ShardSelectorInternal;
7
+ use collection::operations::types::{
8
+ RecommendGroupsRequest, RecommendRequest, RecommendRequestBatch,
9
+ };
10
+ use common::counter::hardware_accumulator::HwMeasurementAcc;
11
+ use itertools::Itertools;
12
+ use segment::types::ScoredPoint;
13
+ use storage::content_manager::collection_verification::{
14
+ check_strict_mode, check_strict_mode_batch,
15
+ };
16
+ use storage::content_manager::errors::StorageError;
17
+ use storage::content_manager::toc::TableOfContent;
18
+ use storage::dispatcher::Dispatcher;
19
+ use storage::rbac::Access;
20
+ use tokio::time::Instant;
21
+
22
+ use super::read_params::ReadParams;
23
+ use super::CollectionPath;
24
+ use crate::actix::auth::ActixAccess;
25
+ use crate::actix::helpers::{self, get_request_hardware_counter, process_response_error};
26
+ use crate::settings::ServiceConfig;
27
+
28
/// Recommend points for a collection based on the request's examples.
#[post("/collections/{name}/points/recommend")]
async fn recommend_points(
    dispatcher: web::Data<Dispatcher>,
    collection: Path<CollectionPath>,
    request: Json<RecommendRequest>,
    params: Query<ReadParams>,
    service_config: web::Data<ServiceConfig>,
    ActixAccess(access): ActixAccess,
) -> impl Responder {
    let RecommendRequest {
        recommend_request,
        shard_key,
    } = request.into_inner();

    // Strict-mode check must pass before the request is executed.
    let pass = match check_strict_mode(
        &recommend_request,
        params.timeout_as_secs(),
        &collection.name,
        &dispatcher,
        &access,
    )
    .await
    {
        Ok(pass) => pass,
        Err(err) => return process_response_error(err, Instant::now(), None),
    };

    // No shard key means the search fans out to all shards.
    let shard_selection = match shard_key {
        None => ShardSelectorInternal::All,
        Some(shard_keys) => shard_keys.into(),
    };

    let request_hw_counter = get_request_hardware_counter(
        &dispatcher,
        collection.name.clone(),
        service_config.hardware_reporting(),
    );

    let timing = Instant::now();

    let result = dispatcher
        .toc(&access, &pass)
        .recommend(
            &collection.name,
            recommend_request,
            params.consistency,
            shard_selection,
            access,
            params.timeout(),
            request_hw_counter.get_counter(),
        )
        .await
        .map(|scored_points| {
            // Convert internal scored points into the REST representation.
            scored_points
                .into_iter()
                .map(api::rest::ScoredPoint::from)
                .collect_vec()
        });

    helpers::process_response(result, timing, request_hw_counter.to_rest_api())
}
89
+
90
+ async fn do_recommend_batch_points(
91
+ toc: &TableOfContent,
92
+ collection_name: &str,
93
+ request: RecommendRequestBatch,
94
+ read_consistency: Option<ReadConsistency>,
95
+ access: Access,
96
+ timeout: Option<Duration>,
97
+ hw_measurement_acc: &HwMeasurementAcc,
98
+ ) -> Result<Vec<Vec<ScoredPoint>>, StorageError> {
99
+ let requests = request
100
+ .searches
101
+ .into_iter()
102
+ .map(|req| {
103
+ let shard_selector = match req.shard_key {
104
+ None => ShardSelectorInternal::All,
105
+ Some(shard_key) => ShardSelectorInternal::from(shard_key),
106
+ };
107
+
108
+ (req.recommend_request, shard_selector)
109
+ })
110
+ .collect();
111
+
112
+ toc.recommend_batch(
113
+ collection_name,
114
+ requests,
115
+ read_consistency,
116
+ access,
117
+ timeout,
118
+ hw_measurement_acc,
119
+ )
120
+ .await
121
+ }
122
+
123
/// Run several recommend requests against a collection in one call.
#[post("/collections/{name}/points/recommend/batch")]
async fn recommend_batch_points(
    dispatcher: web::Data<Dispatcher>,
    collection: Path<CollectionPath>,
    request: Json<RecommendRequestBatch>,
    params: Query<ReadParams>,
    service_config: web::Data<ServiceConfig>,
    ActixAccess(access): ActixAccess,
) -> impl Responder {
    // Strict-mode check covers every request in the batch.
    let pass = match check_strict_mode_batch(
        request.searches.iter().map(|i| &i.recommend_request),
        params.timeout_as_secs(),
        &collection.name,
        &dispatcher,
        &access,
    )
    .await
    {
        Ok(pass) => pass,
        Err(err) => return process_response_error(err, Instant::now(), None),
    };

    let request_hw_counter = get_request_hardware_counter(
        &dispatcher,
        collection.name.clone(),
        service_config.hardware_reporting(),
    );
    let timing = Instant::now();

    let result = do_recommend_batch_points(
        dispatcher.toc(&access, &pass),
        &collection.name,
        request.into_inner(),
        params.consistency,
        access,
        params.timeout(),
        request_hw_counter.get_counter(),
    )
    .await
    .map(|batch_scored_points| {
        // Convert each result set into the REST representation.
        batch_scored_points
            .into_iter()
            .map(|scored_points| {
                scored_points
                    .into_iter()
                    .map(api::rest::ScoredPoint::from)
                    .collect_vec()
            })
            .collect_vec()
    });

    helpers::process_response(result, timing, request_hw_counter.to_rest_api())
}
176
+
177
/// Recommend points and aggregate the results into groups by a payload
/// field.
#[post("/collections/{name}/points/recommend/groups")]
async fn recommend_point_groups(
    dispatcher: web::Data<Dispatcher>,
    collection: Path<CollectionPath>,
    request: Json<RecommendGroupsRequest>,
    params: Query<ReadParams>,
    service_config: web::Data<ServiceConfig>,
    ActixAccess(access): ActixAccess,
) -> impl Responder {
    let RecommendGroupsRequest {
        recommend_group_request,
        shard_key,
    } = request.into_inner();

    // Strict-mode check must pass before the request is executed.
    let pass = match check_strict_mode(
        &recommend_group_request,
        params.timeout_as_secs(),
        &collection.name,
        &dispatcher,
        &access,
    )
    .await
    {
        Ok(pass) => pass,
        Err(err) => return process_response_error(err, Instant::now(), None),
    };

    // No shard key means the search fans out to all shards.
    let shard_selection = match shard_key {
        None => ShardSelectorInternal::All,
        Some(shard_keys) => shard_keys.into(),
    };

    let request_hw_counter = get_request_hardware_counter(
        &dispatcher,
        collection.name.clone(),
        service_config.hardware_reporting(),
    );
    let timing = Instant::now();

    let result = crate::common::points::do_recommend_point_groups(
        dispatcher.toc(&access, &pass),
        &collection.name,
        recommend_group_request,
        params.consistency,
        shard_selection,
        access,
        params.timeout(),
        request_hw_counter.get_counter(),
    )
    .await;

    helpers::process_response(result, timing, request_hw_counter.to_rest_api())
}
230
+ // Configure services
231
+ pub fn config_recommend_api(cfg: &mut web::ServiceConfig) {
232
+ cfg.service(recommend_points)
233
+ .service(recommend_batch_points)
234
+ .service(recommend_point_groups);
235
+ }
src/actix/api/retrieve_api.rs ADDED
@@ -0,0 +1,200 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ use std::time::Duration;
2
+
3
+ use actix_web::{get, post, web, Responder};
4
+ use actix_web_validator::{Json, Path, Query};
5
+ use collection::operations::consistency_params::ReadConsistency;
6
+ use collection::operations::shard_selector_internal::ShardSelectorInternal;
7
+ use collection::operations::types::{
8
+ PointRequest, PointRequestInternal, RecordInternal, ScrollRequest,
9
+ };
10
+ use futures::TryFutureExt;
11
+ use itertools::Itertools;
12
+ use segment::types::{PointIdType, WithPayloadInterface};
13
+ use serde::Deserialize;
14
+ use storage::content_manager::collection_verification::{
15
+ check_strict_mode, check_strict_mode_timeout,
16
+ };
17
+ use storage::content_manager::errors::StorageError;
18
+ use storage::content_manager::toc::TableOfContent;
19
+ use storage::dispatcher::Dispatcher;
20
+ use storage::rbac::Access;
21
+ use tokio::time::Instant;
22
+ use validator::Validate;
23
+
24
+ use super::read_params::ReadParams;
25
+ use super::CollectionPath;
26
+ use crate::actix::auth::ActixAccess;
27
+ use crate::actix::helpers::{self, process_response_error};
28
+ use crate::common::points::do_get_points;
29
+
30
/// Path parameter holding a point id in its raw string form.
#[derive(Deserialize, Validate)]
struct PointPath {
    #[validate(length(min = 1))]
    // TODO: validate this is a valid ID type (usize or UUID)? Does currently error on deserialize.
    id: String,
}
36
+
37
+ async fn do_get_point(
38
+ toc: &TableOfContent,
39
+ collection_name: &str,
40
+ point_id: PointIdType,
41
+ read_consistency: Option<ReadConsistency>,
42
+ timeout: Option<Duration>,
43
+ access: Access,
44
+ ) -> Result<Option<RecordInternal>, StorageError> {
45
+ let request = PointRequestInternal {
46
+ ids: vec![point_id],
47
+ with_payload: Some(WithPayloadInterface::Bool(true)),
48
+ with_vector: true.into(),
49
+ };
50
+
51
+ let shard_selection = ShardSelectorInternal::All;
52
+
53
+ toc.retrieve(
54
+ collection_name,
55
+ request,
56
+ read_consistency,
57
+ timeout,
58
+ shard_selection,
59
+ access,
60
+ )
61
+ .await
62
+ .map(|points| points.into_iter().next())
63
+ }
64
+
65
+ #[get("/collections/{name}/points/{id}")]
66
+ async fn get_point(
67
+ dispatcher: web::Data<Dispatcher>,
68
+ collection: Path<CollectionPath>,
69
+ point: Path<PointPath>,
70
+ params: Query<ReadParams>,
71
+ ActixAccess(access): ActixAccess,
72
+ ) -> impl Responder {
73
+ let pass = match check_strict_mode_timeout(
74
+ params.timeout_as_secs(),
75
+ &collection.name,
76
+ &dispatcher,
77
+ &access,
78
+ )
79
+ .await
80
+ {
81
+ Ok(p) => p,
82
+ Err(err) => return process_response_error(err, Instant::now(), None),
83
+ };
84
+
85
+ helpers::time(async move {
86
+ let point_id: PointIdType = point.id.parse().map_err(|_| StorageError::BadInput {
87
+ description: format!("Can not recognize \"{}\" as point id", point.id),
88
+ })?;
89
+
90
+ let Some(record) = do_get_point(
91
+ dispatcher.toc(&access, &pass),
92
+ &collection.name,
93
+ point_id,
94
+ params.consistency,
95
+ params.timeout(),
96
+ access,
97
+ )
98
+ .await?
99
+ else {
100
+ return Err(StorageError::NotFound {
101
+ description: format!("Point with id {point_id} does not exists!"),
102
+ });
103
+ };
104
+
105
+ Ok(api::rest::Record::from(record))
106
+ })
107
+ .await
108
+ }
109
+
110
/// Retrieve multiple points by id from a collection.
#[post("/collections/{name}/points")]
async fn get_points(
    dispatcher: web::Data<Dispatcher>,
    collection: Path<CollectionPath>,
    request: Json<PointRequest>,
    params: Query<ReadParams>,
    ActixAccess(access): ActixAccess,
) -> impl Responder {
    // Strict mode only needs to validate the timeout for plain id lookups.
    let pass = match check_strict_mode_timeout(
        params.timeout_as_secs(),
        &collection.name,
        &dispatcher,
        &access,
    )
    .await
    {
        Ok(p) => p,
        Err(err) => return process_response_error(err, Instant::now(), None),
    };

    let PointRequest {
        point_request,
        shard_key,
    } = request.into_inner();

    // No shard key means the lookup fans out to all shards.
    let shard_selection = match shard_key {
        None => ShardSelectorInternal::All,
        Some(shard_keys) => ShardSelectorInternal::from(shard_keys),
    };

    helpers::time(
        do_get_points(
            dispatcher.toc(&access, &pass),
            &collection.name,
            point_request,
            params.consistency,
            params.timeout(),
            shard_selection,
            access,
        )
        .map_ok(|response| {
            // Convert internal records into the REST representation.
            response
                .into_iter()
                .map(api::rest::Record::from)
                .collect_vec()
        }),
    )
    .await
}
159
+
160
/// Scroll (paginate) through points of a collection.
#[post("/collections/{name}/points/scroll")]
async fn scroll_points(
    dispatcher: web::Data<Dispatcher>,
    collection: Path<CollectionPath>,
    request: Json<ScrollRequest>,
    params: Query<ReadParams>,
    ActixAccess(access): ActixAccess,
) -> impl Responder {
    let ScrollRequest {
        scroll_request,
        shard_key,
    } = request.into_inner();

    // Strict-mode check must pass before the request is executed.
    let pass = match check_strict_mode(
        &scroll_request,
        params.timeout_as_secs(),
        &collection.name,
        &dispatcher,
        &access,
    )
    .await
    {
        Ok(pass) => pass,
        Err(err) => return process_response_error(err, Instant::now(), None),
    };

    // No shard key means the scroll fans out to all shards.
    let shard_selection = match shard_key {
        None => ShardSelectorInternal::All,
        Some(shard_keys) => ShardSelectorInternal::from(shard_keys),
    };

    helpers::time(dispatcher.toc(&access, &pass).scroll(
        &collection.name,
        scroll_request,
        params.consistency,
        params.timeout(),
        shard_selection,
        access,
    ))
    .await
}
src/actix/api/search_api.rs ADDED
@@ -0,0 +1,333 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ use actix_web::{post, web, HttpResponse, Responder};
2
+ use actix_web_validator::{Json, Path, Query};
3
+ use api::rest::{SearchMatrixOffsetsResponse, SearchMatrixPairsResponse, SearchMatrixRequest};
4
+ use collection::collection::distance_matrix::CollectionSearchMatrixRequest;
5
+ use collection::operations::shard_selector_internal::ShardSelectorInternal;
6
+ use collection::operations::types::{
7
+ CoreSearchRequest, SearchGroupsRequest, SearchRequest, SearchRequestBatch,
8
+ };
9
+ use itertools::Itertools;
10
+ use storage::content_manager::collection_verification::{
11
+ check_strict_mode, check_strict_mode_batch,
12
+ };
13
+ use storage::dispatcher::Dispatcher;
14
+ use tokio::time::Instant;
15
+
16
+ use super::read_params::ReadParams;
17
+ use super::CollectionPath;
18
+ use crate::actix::auth::ActixAccess;
19
+ use crate::actix::helpers::{
20
+ get_request_hardware_counter, process_response, process_response_error,
21
+ };
22
+ use crate::common::points::{
23
+ do_core_search_points, do_search_batch_points, do_search_point_groups, do_search_points_matrix,
24
+ };
25
+ use crate::settings::ServiceConfig;
26
+
27
+ #[post("/collections/{name}/points/search")]
28
+ async fn search_points(
29
+ dispatcher: web::Data<Dispatcher>,
30
+ collection: Path<CollectionPath>,
31
+ request: Json<SearchRequest>,
32
+ params: Query<ReadParams>,
33
+ service_config: web::Data<ServiceConfig>,
34
+ ActixAccess(access): ActixAccess,
35
+ ) -> HttpResponse {
36
+ let SearchRequest {
37
+ search_request,
38
+ shard_key,
39
+ } = request.into_inner();
40
+
41
+ let pass = match check_strict_mode(
42
+ &search_request,
43
+ params.timeout_as_secs(),
44
+ &collection.name,
45
+ &dispatcher,
46
+ &access,
47
+ )
48
+ .await
49
+ {
50
+ Ok(pass) => pass,
51
+ Err(err) => return process_response_error(err, Instant::now(), None),
52
+ };
53
+
54
+ let shard_selection = match shard_key {
55
+ None => ShardSelectorInternal::All,
56
+ Some(shard_keys) => shard_keys.into(),
57
+ };
58
+
59
+ let request_hw_counter = get_request_hardware_counter(
60
+ &dispatcher,
61
+ collection.name.clone(),
62
+ service_config.hardware_reporting(),
63
+ );
64
+
65
+ let timing = Instant::now();
66
+
67
+ let result = do_core_search_points(
68
+ dispatcher.toc(&access, &pass),
69
+ &collection.name,
70
+ search_request.into(),
71
+ params.consistency,
72
+ shard_selection,
73
+ access,
74
+ params.timeout(),
75
+ request_hw_counter.get_counter(),
76
+ )
77
+ .await
78
+ .map(|scored_points| {
79
+ scored_points
80
+ .into_iter()
81
+ .map(api::rest::ScoredPoint::from)
82
+ .collect_vec()
83
+ });
84
+
85
+ process_response(result, timing, request_hw_counter.to_rest_api())
86
+ }
87
+
88
+ #[post("/collections/{name}/points/search/batch")]
89
+ async fn batch_search_points(
90
+ dispatcher: web::Data<Dispatcher>,
91
+ collection: Path<CollectionPath>,
92
+ request: Json<SearchRequestBatch>,
93
+ params: Query<ReadParams>,
94
+ service_config: web::Data<ServiceConfig>,
95
+ ActixAccess(access): ActixAccess,
96
+ ) -> HttpResponse {
97
+ let requests = request
98
+ .into_inner()
99
+ .searches
100
+ .into_iter()
101
+ .map(|req| {
102
+ let SearchRequest {
103
+ search_request,
104
+ shard_key,
105
+ } = req;
106
+ let shard_selection = match shard_key {
107
+ None => ShardSelectorInternal::All,
108
+ Some(shard_keys) => shard_keys.into(),
109
+ };
110
+ let core_request: CoreSearchRequest = search_request.into();
111
+
112
+ (core_request, shard_selection)
113
+ })
114
+ .collect::<Vec<_>>();
115
+
116
+ let pass = match check_strict_mode_batch(
117
+ requests.iter().map(|i| &i.0),
118
+ params.timeout_as_secs(),
119
+ &collection.name,
120
+ &dispatcher,
121
+ &access,
122
+ )
123
+ .await
124
+ {
125
+ Ok(pass) => pass,
126
+ Err(err) => return process_response_error(err, Instant::now(), None),
127
+ };
128
+
129
+ let request_hw_counter = get_request_hardware_counter(
130
+ &dispatcher,
131
+ collection.name.clone(),
132
+ service_config.hardware_reporting(),
133
+ );
134
+
135
+ let timing = Instant::now();
136
+
137
+ let result = do_search_batch_points(
138
+ dispatcher.toc(&access, &pass),
139
+ &collection.name,
140
+ requests,
141
+ params.consistency,
142
+ access,
143
+ params.timeout(),
144
+ request_hw_counter.get_counter(),
145
+ )
146
+ .await
147
+ .map(|batch_scored_points| {
148
+ batch_scored_points
149
+ .into_iter()
150
+ .map(|scored_points| {
151
+ scored_points
152
+ .into_iter()
153
+ .map(api::rest::ScoredPoint::from)
154
+ .collect_vec()
155
+ })
156
+ .collect_vec()
157
+ });
158
+
159
+ process_response(result, timing, request_hw_counter.to_rest_api())
160
+ }
161
+
162
+ #[post("/collections/{name}/points/search/groups")]
163
+ async fn search_point_groups(
164
+ dispatcher: web::Data<Dispatcher>,
165
+ collection: Path<CollectionPath>,
166
+ request: Json<SearchGroupsRequest>,
167
+ params: Query<ReadParams>,
168
+ service_config: web::Data<ServiceConfig>,
169
+ ActixAccess(access): ActixAccess,
170
+ ) -> HttpResponse {
171
+ let SearchGroupsRequest {
172
+ search_group_request,
173
+ shard_key,
174
+ } = request.into_inner();
175
+
176
+ let pass = match check_strict_mode(
177
+ &search_group_request,
178
+ params.timeout_as_secs(),
179
+ &collection.name,
180
+ &dispatcher,
181
+ &access,
182
+ )
183
+ .await
184
+ {
185
+ Ok(pass) => pass,
186
+ Err(err) => return process_response_error(err, Instant::now(), None),
187
+ };
188
+
189
+ let shard_selection = match shard_key {
190
+ None => ShardSelectorInternal::All,
191
+ Some(shard_keys) => shard_keys.into(),
192
+ };
193
+
194
+ let request_hw_counter = get_request_hardware_counter(
195
+ &dispatcher,
196
+ collection.name.clone(),
197
+ service_config.hardware_reporting(),
198
+ );
199
+ let timing = Instant::now();
200
+
201
+ let result = do_search_point_groups(
202
+ dispatcher.toc(&access, &pass),
203
+ &collection.name,
204
+ search_group_request,
205
+ params.consistency,
206
+ shard_selection,
207
+ access,
208
+ params.timeout(),
209
+ request_hw_counter.get_counter(),
210
+ )
211
+ .await;
212
+
213
+ process_response(result, timing, request_hw_counter.to_rest_api())
214
+ }
215
+
216
+ #[post("/collections/{name}/points/search/matrix/pairs")]
217
+ async fn search_points_matrix_pairs(
218
+ dispatcher: web::Data<Dispatcher>,
219
+ collection: Path<CollectionPath>,
220
+ request: Json<SearchMatrixRequest>,
221
+ params: Query<ReadParams>,
222
+ service_config: web::Data<ServiceConfig>,
223
+ ActixAccess(access): ActixAccess,
224
+ ) -> impl Responder {
225
+ let SearchMatrixRequest {
226
+ search_request,
227
+ shard_key,
228
+ } = request.into_inner();
229
+
230
+ let pass = match check_strict_mode(
231
+ &search_request,
232
+ params.timeout_as_secs(),
233
+ &collection.name,
234
+ &dispatcher,
235
+ &access,
236
+ )
237
+ .await
238
+ {
239
+ Ok(pass) => pass,
240
+ Err(err) => return process_response_error(err, Instant::now(), None),
241
+ };
242
+
243
+ let shard_selection = match shard_key {
244
+ None => ShardSelectorInternal::All,
245
+ Some(shard_keys) => shard_keys.into(),
246
+ };
247
+
248
+ let request_hw_counter = get_request_hardware_counter(
249
+ &dispatcher,
250
+ collection.name.clone(),
251
+ service_config.hardware_reporting(),
252
+ );
253
+ let timing = Instant::now();
254
+
255
+ let response = do_search_points_matrix(
256
+ dispatcher.toc(&access, &pass),
257
+ &collection.name,
258
+ CollectionSearchMatrixRequest::from(search_request),
259
+ params.consistency,
260
+ shard_selection,
261
+ access,
262
+ params.timeout(),
263
+ request_hw_counter.get_counter(),
264
+ )
265
+ .await
266
+ .map(SearchMatrixPairsResponse::from);
267
+
268
+ process_response(response, timing, request_hw_counter.to_rest_api())
269
+ }
270
+
271
+ #[post("/collections/{name}/points/search/matrix/offsets")]
272
+ async fn search_points_matrix_offsets(
273
+ dispatcher: web::Data<Dispatcher>,
274
+ collection: Path<CollectionPath>,
275
+ request: Json<SearchMatrixRequest>,
276
+ params: Query<ReadParams>,
277
+ service_config: web::Data<ServiceConfig>,
278
+ ActixAccess(access): ActixAccess,
279
+ ) -> impl Responder {
280
+ let SearchMatrixRequest {
281
+ search_request,
282
+ shard_key,
283
+ } = request.into_inner();
284
+
285
+ let pass = match check_strict_mode(
286
+ &search_request,
287
+ params.timeout_as_secs(),
288
+ &collection.name,
289
+ &dispatcher,
290
+ &access,
291
+ )
292
+ .await
293
+ {
294
+ Ok(pass) => pass,
295
+ Err(err) => return process_response_error(err, Instant::now(), None),
296
+ };
297
+
298
+ let shard_selection = match shard_key {
299
+ None => ShardSelectorInternal::All,
300
+ Some(shard_keys) => shard_keys.into(),
301
+ };
302
+
303
+ let request_hw_counter = get_request_hardware_counter(
304
+ &dispatcher,
305
+ collection.name.clone(),
306
+ service_config.hardware_reporting(),
307
+ );
308
+ let timing = Instant::now();
309
+
310
+ let response = do_search_points_matrix(
311
+ dispatcher.toc(&access, &pass),
312
+ &collection.name,
313
+ CollectionSearchMatrixRequest::from(search_request),
314
+ params.consistency,
315
+ shard_selection,
316
+ access,
317
+ params.timeout(),
318
+ request_hw_counter.get_counter(),
319
+ )
320
+ .await
321
+ .map(SearchMatrixOffsetsResponse::from);
322
+
323
+ process_response(response, timing, request_hw_counter.to_rest_api())
324
+ }
325
+
326
+ // Configure services
327
+ pub fn config_search_api(cfg: &mut web::ServiceConfig) {
328
+ cfg.service(search_points)
329
+ .service(batch_search_points)
330
+ .service(search_point_groups)
331
+ .service(search_points_matrix_pairs)
332
+ .service(search_points_matrix_offsets);
333
+ }
src/actix/api/service_api.rs ADDED
@@ -0,0 +1,217 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ use std::future::Future;
2
+ use std::sync::Arc;
3
+
4
+ use actix_web::http::header::ContentType;
5
+ use actix_web::http::StatusCode;
6
+ use actix_web::rt::time::Instant;
7
+ use actix_web::web::Query;
8
+ use actix_web::{get, post, web, HttpResponse, Responder};
9
+ use actix_web_validator::Json;
10
+ use collection::operations::verification::new_unchecked_verification_pass;
11
+ use common::types::{DetailsLevel, TelemetryDetail};
12
+ use schemars::JsonSchema;
13
+ use segment::common::anonymize::Anonymize;
14
+ use serde::{Deserialize, Serialize};
15
+ use storage::content_manager::errors::StorageError;
16
+ use storage::dispatcher::Dispatcher;
17
+ use storage::rbac::AccessRequirements;
18
+ use tokio::sync::Mutex;
19
+
20
+ use crate::actix::auth::ActixAccess;
21
+ use crate::actix::helpers::{self, process_response_error};
22
+ use crate::common::health;
23
+ use crate::common::helpers::LocksOption;
24
+ use crate::common::metrics::MetricsData;
25
+ use crate::common::stacktrace::get_stack_trace;
26
+ use crate::common::telemetry::TelemetryCollector;
27
+ use crate::tracing;
28
+
29
/// Query parameters accepted by the `GET /telemetry` endpoint.
#[derive(Deserialize, Serialize, JsonSchema)]
pub struct TelemetryParam {
    /// When `true`, identifying details are scrubbed from the report.
    pub anonymize: Option<bool>,
    /// Verbosity of the report; converted into `DetailsLevel` (level 0 when absent).
    pub details_level: Option<usize>,
}
34
+
35
+ #[get("/telemetry")]
36
+ fn telemetry(
37
+ telemetry_collector: web::Data<Mutex<TelemetryCollector>>,
38
+ params: Query<TelemetryParam>,
39
+ ActixAccess(access): ActixAccess,
40
+ ) -> impl Future<Output = HttpResponse> {
41
+ helpers::time(async move {
42
+ access.check_global_access(AccessRequirements::new())?;
43
+ let anonymize = params.anonymize.unwrap_or(false);
44
+ let details_level = params
45
+ .details_level
46
+ .map_or(DetailsLevel::Level0, Into::into);
47
+ let detail = TelemetryDetail {
48
+ level: details_level,
49
+ histograms: false,
50
+ };
51
+ let telemetry_collector = telemetry_collector.lock().await;
52
+ let telemetry_data = telemetry_collector.prepare_data(&access, detail).await;
53
+ let telemetry_data = if anonymize {
54
+ telemetry_data.anonymize()
55
+ } else {
56
+ telemetry_data
57
+ };
58
+ Ok(telemetry_data)
59
+ })
60
+ }
61
+
62
/// Query parameters accepted by the `GET /metrics` endpoint.
#[derive(Deserialize, Serialize, JsonSchema)]
pub struct MetricsParam {
    /// When `true`, identifying details are scrubbed from the exported metrics.
    pub anonymize: Option<bool>,
}
66
+
67
+ #[get("/metrics")]
68
+ async fn metrics(
69
+ telemetry_collector: web::Data<Mutex<TelemetryCollector>>,
70
+ params: Query<MetricsParam>,
71
+ ActixAccess(access): ActixAccess,
72
+ ) -> HttpResponse {
73
+ if let Err(err) = access.check_global_access(AccessRequirements::new()) {
74
+ return process_response_error(err, Instant::now(), None);
75
+ }
76
+
77
+ let anonymize = params.anonymize.unwrap_or(false);
78
+ let telemetry_collector = telemetry_collector.lock().await;
79
+ let telemetry_data = telemetry_collector
80
+ .prepare_data(
81
+ &access,
82
+ TelemetryDetail {
83
+ level: DetailsLevel::Level1,
84
+ histograms: true,
85
+ },
86
+ )
87
+ .await;
88
+ let telemetry_data = if anonymize {
89
+ telemetry_data.anonymize()
90
+ } else {
91
+ telemetry_data
92
+ };
93
+
94
+ HttpResponse::Ok()
95
+ .content_type(ContentType::plaintext())
96
+ .body(MetricsData::from(telemetry_data).format_metrics())
97
+ }
98
+
99
/// `POST /locks` — set the storage write-lock state and return the state
/// that was in effect before the change (requires manage access).
#[post("/locks")]
fn put_locks(
    dispatcher: web::Data<Dispatcher>,
    locks_option: Json<LocksOption>,
    ActixAccess(access): ActixAccess,
) -> impl Future<Output = HttpResponse> {
    // Not a collection level request.
    let pass = new_unchecked_verification_pass();

    helpers::time(async move {
        let toc = dispatcher.toc(&access, &pass);
        access.check_global_access(AccessRequirements::new().manage())?;
        // Capture the previous lock state *before* applying the new one, so
        // the response reports what was replaced.
        let result = LocksOption {
            write: toc.is_write_locked(),
            error_message: toc.get_lock_error_message(),
        };
        toc.set_locks(locks_option.write, locks_option.error_message.clone());
        Ok(result)
    })
}
119
+
120
+ #[get("/locks")]
121
+ fn get_locks(
122
+ dispatcher: web::Data<Dispatcher>,
123
+ ActixAccess(access): ActixAccess,
124
+ ) -> impl Future<Output = HttpResponse> {
125
+ // Not a collection level request.
126
+ let pass = new_unchecked_verification_pass();
127
+
128
+ helpers::time(async move {
129
+ access.check_global_access(AccessRequirements::new())?;
130
+ let toc = dispatcher.toc(&access, &pass);
131
+ let result = LocksOption {
132
+ write: toc.is_write_locked(),
133
+ error_message: toc.get_lock_error_message(),
134
+ };
135
+ Ok(result)
136
+ })
137
+ }
138
+
139
/// `GET /stacktrace` — dump stack traces of all threads (requires manage access).
#[get("/stacktrace")]
fn get_stacktrace(ActixAccess(access): ActixAccess) -> impl Future<Output = HttpResponse> {
    helpers::time(async move {
        access.check_global_access(AccessRequirements::new().manage())?;
        Ok(get_stack_trace())
    })
}
146
+
147
/// `GET /healthz` — basic liveness endpoint (Kubernetes convention).
#[get("/healthz")]
async fn healthz() -> impl Responder {
    kubernetes_healthz()
}
151
+
152
/// `GET /livez` — alias of the basic liveness endpoint (Kubernetes convention).
#[get("/livez")]
async fn livez() -> impl Responder {
    kubernetes_healthz()
}
156
+
157
+ #[get("/readyz")]
158
+ async fn readyz(health_checker: web::Data<Option<Arc<health::HealthChecker>>>) -> impl Responder {
159
+ let is_ready = match health_checker.as_ref() {
160
+ Some(health_checker) => health_checker.check_ready().await,
161
+ None => true,
162
+ };
163
+
164
+ let (status, body) = if is_ready {
165
+ (StatusCode::OK, "all shards are ready")
166
+ } else {
167
+ (StatusCode::SERVICE_UNAVAILABLE, "some shards are not ready")
168
+ };
169
+
170
+ HttpResponse::build(status)
171
+ .content_type(ContentType::plaintext())
172
+ .body(body)
173
+ }
174
+
175
/// Basic Kubernetes healthz endpoint
// Shared responder for `/healthz` and `/livez`; always 200 with plain text.
fn kubernetes_healthz() -> impl Responder {
    HttpResponse::Ok()
        .content_type(ContentType::plaintext())
        .body("healthz check passed")
}
181
+
182
+ #[get("/logger")]
183
+ async fn get_logger_config(handle: web::Data<tracing::LoggerHandle>) -> impl Responder {
184
+ let timing = Instant::now();
185
+ let result = handle.get_config().await;
186
+ helpers::process_response(Ok(result), timing, None)
187
+ }
188
+
189
+ #[post("/logger")]
190
+ async fn update_logger_config(
191
+ handle: web::Data<tracing::LoggerHandle>,
192
+ config: web::Json<tracing::LoggerConfig>,
193
+ ) -> impl Responder {
194
+ let timing = Instant::now();
195
+
196
+ let result = handle
197
+ .update_config(config.into_inner())
198
+ .await
199
+ .map(|_| true)
200
+ .map_err(|err| StorageError::service_error(err.to_string()));
201
+
202
+ helpers::process_response(result, timing, None)
203
+ }
204
+
205
+ // Configure services
206
+ pub fn config_service_api(cfg: &mut web::ServiceConfig) {
207
+ cfg.service(telemetry)
208
+ .service(metrics)
209
+ .service(put_locks)
210
+ .service(get_locks)
211
+ .service(get_stacktrace)
212
+ .service(healthz)
213
+ .service(livez)
214
+ .service(readyz)
215
+ .service(get_logger_config)
216
+ .service(update_logger_config);
217
+ }
src/actix/api/shards_api.rs ADDED
@@ -0,0 +1,80 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ use actix_web::{post, put, web, Responder};
2
+ use actix_web_validator::{Json, Path, Query};
3
+ use collection::operations::cluster_ops::{
4
+ ClusterOperations, CreateShardingKey, CreateShardingKeyOperation, DropShardingKey,
5
+ DropShardingKeyOperation,
6
+ };
7
+ use storage::dispatcher::Dispatcher;
8
+ use tokio::time::Instant;
9
+
10
+ use crate::actix::api::collections_api::WaitTimeout;
11
+ use crate::actix::api::CollectionPath;
12
+ use crate::actix::auth::ActixAccess;
13
+ use crate::actix::helpers::process_response;
14
+ use crate::common::collections::do_update_collection_cluster;
15
+
16
+ // ToDo: introduce API for listing shard keys
17
+
18
+ #[put("/collections/{name}/shards")]
19
+ async fn create_shard_key(
20
+ dispatcher: web::Data<Dispatcher>,
21
+ collection: Path<CollectionPath>,
22
+ request: Json<CreateShardingKey>,
23
+ Query(query): Query<WaitTimeout>,
24
+ ActixAccess(access): ActixAccess,
25
+ ) -> impl Responder {
26
+ let timing = Instant::now();
27
+ let wait_timeout = query.timeout();
28
+ let dispatcher = dispatcher.into_inner();
29
+
30
+ let request = request.into_inner();
31
+
32
+ let operation = ClusterOperations::CreateShardingKey(CreateShardingKeyOperation {
33
+ create_sharding_key: request,
34
+ });
35
+
36
+ let response = do_update_collection_cluster(
37
+ &dispatcher,
38
+ collection.name.clone(),
39
+ operation,
40
+ access,
41
+ wait_timeout,
42
+ )
43
+ .await;
44
+
45
+ process_response(response, timing, None)
46
+ }
47
+
48
+ #[post("/collections/{name}/shards/delete")]
49
+ async fn delete_shard_key(
50
+ dispatcher: web::Data<Dispatcher>,
51
+ collection: Path<CollectionPath>,
52
+ request: Json<DropShardingKey>,
53
+ Query(query): Query<WaitTimeout>,
54
+ ActixAccess(access): ActixAccess,
55
+ ) -> impl Responder {
56
+ let timing = Instant::now();
57
+ let wait_timeout = query.timeout();
58
+
59
+ let dispatcher = dispatcher.into_inner();
60
+ let request = request.into_inner();
61
+
62
+ let operation = ClusterOperations::DropShardingKey(DropShardingKeyOperation {
63
+ drop_sharding_key: request,
64
+ });
65
+
66
+ let response = do_update_collection_cluster(
67
+ &dispatcher,
68
+ collection.name.clone(),
69
+ operation,
70
+ access,
71
+ wait_timeout,
72
+ )
73
+ .await;
74
+
75
+ process_response(response, timing, None)
76
+ }
77
+
78
+ pub fn config_shards_api(cfg: &mut web::ServiceConfig) {
79
+ cfg.service(create_shard_key).service(delete_shard_key);
80
+ }
src/actix/api/snapshot_api.rs ADDED
@@ -0,0 +1,585 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ use std::path::Path;
2
+
3
+ use actix_multipart::form::tempfile::TempFile;
4
+ use actix_multipart::form::MultipartForm;
5
+ use actix_web::{delete, get, post, put, web, Responder, Result};
6
+ use actix_web_validator as valid;
7
+ use collection::common::file_utils::move_file;
8
+ use collection::common::sha_256::{hash_file, hashes_equal};
9
+ use collection::common::snapshot_stream::SnapshotStream;
10
+ use collection::operations::snapshot_ops::{
11
+ ShardSnapshotRecover, SnapshotPriority, SnapshotRecover,
12
+ };
13
+ use collection::operations::verification::new_unchecked_verification_pass;
14
+ use collection::shards::shard::ShardId;
15
+ use futures::{FutureExt as _, TryFutureExt as _};
16
+ use reqwest::Url;
17
+ use schemars::JsonSchema;
18
+ use serde::{Deserialize, Serialize};
19
+ use storage::content_manager::errors::StorageError;
20
+ use storage::content_manager::snapshots::recover::do_recover_from_snapshot;
21
+ use storage::content_manager::snapshots::{
22
+ do_create_full_snapshot, do_delete_collection_snapshot, do_delete_full_snapshot,
23
+ do_list_full_snapshots,
24
+ };
25
+ use storage::content_manager::toc::TableOfContent;
26
+ use storage::dispatcher::Dispatcher;
27
+ use storage::rbac::{Access, AccessRequirements};
28
+ use uuid::Uuid;
29
+ use validator::Validate;
30
+
31
+ use super::{CollectionPath, StrictCollectionPath};
32
+ use crate::actix::auth::ActixAccess;
33
+ use crate::actix::helpers::{self, HttpError};
34
+ use crate::common;
35
+ use crate::common::collections::*;
36
+ use crate::common::http_client::HttpClient;
37
+
38
/// Query parameters for the snapshot upload endpoint.
#[derive(Deserialize, Serialize, JsonSchema, Validate)]
pub struct SnapshotUploadingParam {
    /// When `true` (default), wait for the operation to finish before responding.
    pub wait: Option<bool>,
    /// Recovery priority applied to the uploaded snapshot.
    pub priority: Option<SnapshotPriority>,

    /// Optional SHA256 checksum to verify snapshot integrity before recovery.
    #[serde(default)]
    #[validate(custom(function = "::common::validation::validate_sha256_hash"))]
    pub checksum: Option<String>,
}
48
+
49
/// Query parameters shared by the snapshotting endpoints.
#[derive(Deserialize, Serialize, JsonSchema, Validate)]
pub struct SnapshottingParam {
    /// When `true` (default), wait for the operation to finish before responding.
    pub wait: Option<bool>,
}
53
+
54
/// Multipart form payload carrying an uploaded snapshot file.
#[derive(MultipartForm)]
pub struct SnapshottingForm {
    // The snapshot archive itself, spooled to a temporary file by actix.
    snapshot: TempFile,
}
58
+
59
+ // Actix specific code
60
+ pub async fn do_get_full_snapshot(
61
+ toc: &TableOfContent,
62
+ access: Access,
63
+ snapshot_name: &str,
64
+ ) -> Result<SnapshotStream, HttpError> {
65
+ access.check_global_access(AccessRequirements::new())?;
66
+ let snapshots_storage_manager = toc.get_snapshots_storage_manager()?;
67
+ let snapshot_path =
68
+ snapshots_storage_manager.get_full_snapshot_path(toc.snapshots_path(), snapshot_name)?;
69
+ let snapshot_stream = snapshots_storage_manager
70
+ .get_snapshot_stream(&snapshot_path)
71
+ .await?;
72
+ Ok(snapshot_stream)
73
+ }
74
+
75
+ pub async fn do_save_uploaded_snapshot(
76
+ toc: &TableOfContent,
77
+ collection_name: &str,
78
+ snapshot: TempFile,
79
+ ) -> Result<Url, StorageError> {
80
+ let filename = snapshot
81
+ .file_name
82
+ // Sanitize the file name:
83
+ // - only take the top level path (no directories such as ../)
84
+ // - require the file name to be valid UTF-8
85
+ .and_then(|x| {
86
+ Path::new(&x)
87
+ .file_name()
88
+ .map(|filename| filename.to_owned())
89
+ })
90
+ .and_then(|x| x.to_str().map(|x| x.to_owned()))
91
+ .unwrap_or_else(|| Uuid::new_v4().to_string());
92
+ let collection_snapshot_path = toc.snapshots_path_for_collection(collection_name);
93
+ if !collection_snapshot_path.exists() {
94
+ log::debug!(
95
+ "Creating missing collection snapshots directory for {}",
96
+ collection_name
97
+ );
98
+ toc.create_snapshots_path(collection_name).await?;
99
+ }
100
+
101
+ let path = collection_snapshot_path.join(filename);
102
+
103
+ move_file(snapshot.file.path(), &path).await?;
104
+
105
+ let absolute_path = path.canonicalize()?;
106
+
107
+ let snapshot_location = Url::from_file_path(&absolute_path).map_err(|_| {
108
+ StorageError::service_error(format!(
109
+ "Failed to convert path to URL: {}",
110
+ absolute_path.display()
111
+ ))
112
+ })?;
113
+
114
+ Ok(snapshot_location)
115
+ }
116
+
117
+ // Actix specific code
118
+ pub async fn do_get_snapshot(
119
+ toc: &TableOfContent,
120
+ access: Access,
121
+ collection_name: &str,
122
+ snapshot_name: &str,
123
+ ) -> Result<SnapshotStream, HttpError> {
124
+ let collection_pass =
125
+ access.check_collection_access(collection_name, AccessRequirements::new().whole())?;
126
+ let collection: tokio::sync::RwLockReadGuard<collection::collection::Collection> =
127
+ toc.get_collection(&collection_pass).await?;
128
+ let snapshot_storage_manager = collection.get_snapshots_storage_manager()?;
129
+ let snapshot_path =
130
+ snapshot_storage_manager.get_snapshot_path(collection.snapshots_path(), snapshot_name)?;
131
+ let snapshot_stream = snapshot_storage_manager
132
+ .get_snapshot_stream(&snapshot_path)
133
+ .await?;
134
+ Ok(snapshot_stream)
135
+ }
136
+
137
+ #[get("/collections/{name}/snapshots")]
138
+ async fn list_snapshots(
139
+ dispatcher: web::Data<Dispatcher>,
140
+ path: web::Path<String>,
141
+ ActixAccess(access): ActixAccess,
142
+ ) -> impl Responder {
143
+ // Nothing to verify.
144
+ let pass = new_unchecked_verification_pass();
145
+
146
+ helpers::time(do_list_snapshots(
147
+ dispatcher.toc(&access, &pass),
148
+ access,
149
+ &path,
150
+ ))
151
+ .await
152
+ }
153
+
154
+ #[post("/collections/{name}/snapshots")]
155
+ async fn create_snapshot(
156
+ dispatcher: web::Data<Dispatcher>,
157
+ path: web::Path<String>,
158
+ params: valid::Query<SnapshottingParam>,
159
+ ActixAccess(access): ActixAccess,
160
+ ) -> impl Responder {
161
+ // Nothing to verify.
162
+ let pass = new_unchecked_verification_pass();
163
+
164
+ let collection_name = path.into_inner();
165
+
166
+ let future = async move {
167
+ do_create_snapshot(
168
+ dispatcher.toc(&access, &pass).clone(),
169
+ access,
170
+ &collection_name,
171
+ )
172
+ .await
173
+ };
174
+
175
+ helpers::time_or_accept(future, params.wait.unwrap_or(true)).await
176
+ }
177
+
178
/// `POST /collections/{name}/snapshots/upload` — receive a snapshot file and
/// immediately recover the collection from it.
///
/// Order matters here: the manage-access check runs first, then the optional
/// checksum verification against the temp file, then the file is moved into
/// the snapshots directory, and only then recovery starts.
#[post("/collections/{name}/snapshots/upload")]
async fn upload_snapshot(
    dispatcher: web::Data<Dispatcher>,
    http_client: web::Data<HttpClient>,
    collection: valid::Path<StrictCollectionPath>,
    MultipartForm(form): MultipartForm<SnapshottingForm>,
    params: valid::Query<SnapshotUploadingParam>,
    ActixAccess(access): ActixAccess,
) -> impl Responder {
    let wait = params.wait;

    // Nothing to verify.
    let pass = new_unchecked_verification_pass();

    let future = async move {
        let snapshot = form.snapshot;

        access.check_global_access(AccessRequirements::new().manage())?;

        // Verify integrity of the uploaded file before persisting it.
        if let Some(checksum) = &params.checksum {
            let snapshot_checksum = hash_file(snapshot.file.path()).await?;
            if !hashes_equal(snapshot_checksum.as_str(), checksum.as_str()) {
                return Err(StorageError::checksum_mismatch(snapshot_checksum, checksum));
            }
        }

        let snapshot_location =
            do_save_uploaded_snapshot(dispatcher.toc(&access, &pass), &collection.name, snapshot)
                .await?;

        // Snapshot is a local file, we do not need an API key for that
        let http_client = http_client.client(None)?;

        // Checksum was already verified above, so it is not passed again.
        let snapshot_recover = SnapshotRecover {
            location: snapshot_location,
            priority: params.priority,
            checksum: None,
            api_key: None,
        };

        do_recover_from_snapshot(
            dispatcher.get_ref(),
            &collection.name,
            snapshot_recover,
            access,
            http_client,
        )
        .await
    };

    helpers::time_or_accept(future, wait.unwrap_or(true)).await
}
230
+
231
+ #[put("/collections/{name}/snapshots/recover")]
232
+ async fn recover_from_snapshot(
233
+ dispatcher: web::Data<Dispatcher>,
234
+ http_client: web::Data<HttpClient>,
235
+ collection: valid::Path<CollectionPath>,
236
+ request: valid::Json<SnapshotRecover>,
237
+ params: valid::Query<SnapshottingParam>,
238
+ ActixAccess(access): ActixAccess,
239
+ ) -> impl Responder {
240
+ let future = async move {
241
+ let snapshot_recover = request.into_inner();
242
+ let http_client = http_client.client(snapshot_recover.api_key.as_deref())?;
243
+
244
+ do_recover_from_snapshot(
245
+ dispatcher.get_ref(),
246
+ &collection.name,
247
+ snapshot_recover,
248
+ access,
249
+ http_client,
250
+ )
251
+ .await
252
+ };
253
+
254
+ helpers::time_or_accept(future, params.wait.unwrap_or(true)).await
255
+ }
256
+
257
+ #[get("/collections/{name}/snapshots/{snapshot_name}")]
258
+ async fn get_snapshot(
259
+ dispatcher: web::Data<Dispatcher>,
260
+ path: web::Path<(String, String)>,
261
+ ActixAccess(access): ActixAccess,
262
+ ) -> impl Responder {
263
+ // Nothing to verify.
264
+ let pass = new_unchecked_verification_pass();
265
+
266
+ let (collection_name, snapshot_name) = path.into_inner();
267
+ do_get_snapshot(
268
+ dispatcher.toc(&access, &pass),
269
+ access,
270
+ &collection_name,
271
+ &snapshot_name,
272
+ )
273
+ .await
274
+ }
275
+
276
+ #[get("/snapshots")]
277
+ async fn list_full_snapshots(
278
+ dispatcher: web::Data<Dispatcher>,
279
+ ActixAccess(access): ActixAccess,
280
+ ) -> impl Responder {
281
+ // nothing to verify.
282
+ let pass = new_unchecked_verification_pass();
283
+
284
+ helpers::time(do_list_full_snapshots(
285
+ dispatcher.toc(&access, &pass),
286
+ access,
287
+ ))
288
+ .await
289
+ }
290
+
291
+ #[post("/snapshots")]
292
+ async fn create_full_snapshot(
293
+ dispatcher: web::Data<Dispatcher>,
294
+ params: valid::Query<SnapshottingParam>,
295
+ ActixAccess(access): ActixAccess,
296
+ ) -> impl Responder {
297
+ let future = async move { do_create_full_snapshot(dispatcher.get_ref(), access).await };
298
+ helpers::time_or_accept(future, params.wait.unwrap_or(true)).await
299
+ }
300
+
301
/// Fetch a full-storage snapshot by name (delegates to `do_get_full_snapshot`).
#[get("/snapshots/{snapshot_name}")]
async fn get_full_snapshot(
    dispatcher: web::Data<Dispatcher>,
    path: web::Path<String>,
    ActixAccess(access): ActixAccess,
) -> impl Responder {
    // Read-only snapshot access: strict mode has nothing to verify.
    let pass = new_unchecked_verification_pass();

    let snapshot_name = path.into_inner();
    do_get_full_snapshot(dispatcher.toc(&access, &pass), access, &snapshot_name).await
}
313
+
314
+ #[delete("/snapshots/{snapshot_name}")]
315
+ async fn delete_full_snapshot(
316
+ dispatcher: web::Data<Dispatcher>,
317
+ path: web::Path<String>,
318
+ params: valid::Query<SnapshottingParam>,
319
+ ActixAccess(access): ActixAccess,
320
+ ) -> impl Responder {
321
+ let future = async move {
322
+ let snapshot_name = path.into_inner();
323
+ do_delete_full_snapshot(dispatcher.get_ref(), access, &snapshot_name).await
324
+ };
325
+
326
+ helpers::time_or_accept(future, params.wait.unwrap_or(true)).await
327
+ }
328
+
329
/// Delete a named snapshot of a specific collection.
///
/// Honors `wait` (default `true`); with `wait=false` a 202 is returned while
/// deletion continues in the background.
#[delete("/collections/{name}/snapshots/{snapshot_name}")]
async fn delete_collection_snapshot(
    dispatcher: web::Data<Dispatcher>,
    path: web::Path<(String, String)>,
    params: valid::Query<SnapshottingParam>,
    ActixAccess(access): ActixAccess,
) -> impl Responder {
    let future = async move {
        let (collection_name, snapshot_name) = path.into_inner();

        do_delete_collection_snapshot(
            dispatcher.get_ref(),
            access,
            &collection_name,
            &snapshot_name,
        )
        .await
    };

    helpers::time_or_accept(future, params.wait.unwrap_or(true)).await
}
350
+
351
/// List snapshots of a single shard of a collection.
#[get("/collections/{collection}/shards/{shard}/snapshots")]
async fn list_shard_snapshots(
    dispatcher: web::Data<Dispatcher>,
    path: web::Path<(String, ShardId)>,
    ActixAccess(access): ActixAccess,
) -> impl Responder {
    // Read-only listing: strict mode has nothing to verify.
    let pass = new_unchecked_verification_pass();

    let (collection, shard) = path.into_inner();

    let future = common::snapshots::list_shard_snapshots(
        dispatcher.toc(&access, &pass).clone(),
        access,
        collection,
        shard,
    )
    // Convert the storage error into the HTTP error type expected by `time`.
    .map_err(Into::into);

    helpers::time(future).await
}
372
+
373
+ #[post("/collections/{collection}/shards/{shard}/snapshots")]
374
+ async fn create_shard_snapshot(
375
+ dispatcher: web::Data<Dispatcher>,
376
+ path: web::Path<(String, ShardId)>,
377
+ query: web::Query<SnapshottingParam>,
378
+ ActixAccess(access): ActixAccess,
379
+ ) -> impl Responder {
380
+ // nothing to verify.
381
+ let pass = new_unchecked_verification_pass();
382
+
383
+ let (collection, shard) = path.into_inner();
384
+ let future = common::snapshots::create_shard_snapshot(
385
+ dispatcher.toc(&access, &pass).clone(),
386
+ access,
387
+ collection,
388
+ shard,
389
+ );
390
+
391
+ helpers::time_or_accept(future, query.wait.unwrap_or(true)).await
392
+ }
393
+
394
/// Stream a freshly produced shard snapshot directly in the response body
/// (note the singular `/snapshot` path, distinct from `/snapshots`).
#[get("/collections/{collection}/shards/{shard}/snapshot")]
async fn stream_shard_snapshot(
    dispatcher: web::Data<Dispatcher>,
    path: web::Path<(String, ShardId)>,
    ActixAccess(access): ActixAccess,
) -> Result<SnapshotStream, HttpError> {
    // Read-only snapshot access: strict mode has nothing to verify.
    let pass = new_unchecked_verification_pass();

    let (collection, shard) = path.into_inner();
    Ok(common::snapshots::stream_shard_snapshot(
        dispatcher.toc(&access, &pass).clone(),
        access,
        collection,
        shard,
    )
    .await?)
}
412
+
413
// TODO: `PUT` (same as `recover_from_snapshot`) or `POST`!?
/// Recover a single shard from a snapshot at the location given in the body.
///
/// Honors `wait` (default `true`); with `wait=false` recovery continues in the
/// background and a 202 is returned.
#[put("/collections/{collection}/shards/{shard}/snapshots/recover")]
async fn recover_shard_snapshot(
    dispatcher: web::Data<Dispatcher>,
    http_client: web::Data<HttpClient>,
    path: web::Path<(String, ShardId)>,
    query: web::Query<SnapshottingParam>,
    web::Json(request): web::Json<ShardSnapshotRecover>,
    ActixAccess(access): ActixAccess,
) -> impl Responder {
    // nothing to verify.
    let pass = new_unchecked_verification_pass();

    let future = async move {
        let (collection, shard) = path.into_inner();

        common::snapshots::recover_shard_snapshot(
            dispatcher.toc(&access, &pass).clone(),
            access,
            collection,
            shard,
            request.location,
            request.priority.unwrap_or_default(),
            request.checksum,
            // Clone the shared HTTP client handle for the recovery transfer.
            http_client.as_ref().clone(),
            request.api_key,
        )
        .await?;

        // Recovery returns no payload; report plain success.
        Ok(true)
    };

    helpers::time_or_accept(future, query.wait.unwrap_or(true)).await
}
447
+
448
// TODO: `POST` (same as `upload_snapshot`) or `PUT`!?
/// Recover a shard from a snapshot file uploaded as multipart form data.
///
/// The actual work is spawned as a task so that client disconnects do not
/// cancel the non-cancel-safe recovery step mid-way.
#[post("/collections/{collection}/shards/{shard}/snapshots/upload")]
async fn upload_shard_snapshot(
    dispatcher: web::Data<Dispatcher>,
    path: web::Path<(String, ShardId)>,
    query: web::Query<SnapshotUploadingParam>,
    MultipartForm(form): MultipartForm<SnapshottingForm>,
    ActixAccess(access): ActixAccess,
) -> impl Responder {
    // nothing to verify.
    let pass = new_unchecked_verification_pass();

    let (collection, shard) = path.into_inner();
    let SnapshotUploadingParam {
        wait,
        priority,
        checksum,
    } = query.into_inner();

    // - `recover_shard_snapshot_impl` is *not* cancel safe
    // - but the task is *spawned* on the runtime and won't be cancelled, if request is cancelled

    let future = cancel::future::spawn_cancel_on_drop(move |cancel| async move {
        // TODO: Run this check before the multipart blob is uploaded
        // Require global manage rights, then scope a pass to this collection.
        let collection_pass = access
            .check_global_access(AccessRequirements::new().manage())?
            .issue_pass(&collection);

        // Optional integrity check of the uploaded file against the
        // caller-provided checksum; mismatch aborts before any recovery.
        if let Some(checksum) = checksum {
            let snapshot_checksum = hash_file(form.snapshot.file.path()).await?;
            if !hashes_equal(snapshot_checksum.as_str(), checksum.as_str()) {
                return Err(StorageError::checksum_mismatch(snapshot_checksum, checksum));
            }
        }

        // Collection/shard lookup is cancel safe, so it may still be aborted
        // via the cancellation token before recovery starts.
        let future = async {
            let collection = dispatcher
                .toc(&access, &pass)
                .get_collection(&collection_pass)
                .await?;
            collection.assert_shard_exists(shard).await?;

            Result::<_, StorageError>::Ok(collection)
        };

        let collection = cancel::future::cancel_on_token(cancel.clone(), future).await??;

        // `recover_shard_snapshot_impl` is *not* cancel safe
        common::snapshots::recover_shard_snapshot_impl(
            dispatcher.toc(&access, &pass),
            &collection,
            shard,
            form.snapshot.file.path(),
            priority.unwrap_or_default(),
            cancel,
        )
        .await?;

        Ok(())
    })
    // Flatten the join-handle result with the inner storage result.
    .map(|x| x.map_err(Into::into).and_then(|x| x));

    helpers::time_or_accept(future, wait.unwrap_or(true)).await
}
512
+
513
/// Download a previously created shard snapshot file as a stream.
#[get("/collections/{collection}/shards/{shard}/snapshots/{snapshot}")]
async fn download_shard_snapshot(
    dispatcher: web::Data<Dispatcher>,
    path: web::Path<(String, ShardId, String)>,
    ActixAccess(access): ActixAccess,
) -> Result<impl Responder, HttpError> {
    // Read-only snapshot access: strict mode has nothing to verify.
    let pass = new_unchecked_verification_pass();

    let (collection, shard, snapshot) = path.into_inner();
    // Downloading requires whole-collection read access.
    let collection_pass =
        access.check_collection_access(&collection, AccessRequirements::new().whole())?;
    let collection = dispatcher
        .toc(&access, &pass)
        .get_collection(&collection_pass)
        .await?;
    let snapshots_storage_manager = collection.get_snapshots_storage_manager()?;
    // Resolve the on-disk path of the requested shard snapshot.
    let snapshot_path = collection
        .shards_holder()
        .read()
        .await
        .get_shard_snapshot_path(collection.snapshots_path(), shard, &snapshot)
        .await?;
    let snapshot_stream = snapshots_storage_manager
        .get_snapshot_stream(&snapshot_path)
        .await?;
    Ok(snapshot_stream)
}
541
+
542
/// Delete a named snapshot of a single shard.
///
/// Honors `wait` (default `true`); with `wait=false` a 202 is returned while
/// deletion continues in the background.
#[delete("/collections/{collection}/shards/{shard}/snapshots/{snapshot}")]
async fn delete_shard_snapshot(
    dispatcher: web::Data<Dispatcher>,
    path: web::Path<(String, ShardId, String)>,
    query: web::Query<SnapshottingParam>,
    ActixAccess(access): ActixAccess,
) -> impl Responder {
    // nothing to verify.
    let pass = new_unchecked_verification_pass();

    let (collection, shard, snapshot) = path.into_inner();
    let future = common::snapshots::delete_shard_snapshot(
        dispatcher.toc(&access, &pass).clone(),
        access,
        collection,
        shard,
        snapshot,
    )
    // Deletion yields no payload; report plain success and adapt the error type.
    .map_ok(|_| true)
    .map_err(Into::into);

    helpers::time_or_accept(future, query.wait.unwrap_or(true)).await
}
565
+
566
// Configure services
/// Registers every snapshot-related route on the Actix service config.
/// NOTE(review): registration order can matter in Actix when routes overlap
/// (e.g. `/snapshots/upload` vs `/snapshots/{snapshot_name}`) — keep it as is.
pub fn config_snapshots_api(cfg: &mut web::ServiceConfig) {
    cfg.service(list_snapshots)
        .service(create_snapshot)
        .service(upload_snapshot)
        .service(recover_from_snapshot)
        .service(get_snapshot)
        .service(list_full_snapshots)
        .service(create_full_snapshot)
        .service(get_full_snapshot)
        .service(delete_full_snapshot)
        .service(delete_collection_snapshot)
        .service(list_shard_snapshots)
        .service(create_shard_snapshot)
        .service(stream_shard_snapshot)
        .service(recover_shard_snapshot)
        .service(upload_shard_snapshot)
        .service(download_shard_snapshot)
        .service(delete_shard_snapshot);
}
src/actix/api/update_api.rs ADDED
@@ -0,0 +1,392 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ use actix_web::rt::time::Instant;
2
+ use actix_web::{delete, post, put, web, Responder};
3
+ use actix_web_validator::{Json, Path, Query};
4
+ use api::rest::schema::PointInsertOperations;
5
+ use api::rest::UpdateVectors;
6
+ use collection::operations::payload_ops::{DeletePayload, SetPayload};
7
+ use collection::operations::point_ops::{PointsSelector, WriteOrdering};
8
+ use collection::operations::types::UpdateResult;
9
+ use collection::operations::vector_ops::DeleteVectors;
10
+ use collection::operations::verification::new_unchecked_verification_pass;
11
+ use schemars::JsonSchema;
12
+ use segment::json_path::JsonPath;
13
+ use serde::{Deserialize, Serialize};
14
+ use storage::content_manager::collection_verification::check_strict_mode;
15
+ use storage::dispatcher::Dispatcher;
16
+ use validator::Validate;
17
+
18
+ use super::CollectionPath;
19
+ use crate::actix::auth::ActixAccess;
20
+ use crate::actix::helpers::{self, process_response, process_response_error};
21
+ use crate::common::points::{
22
+ do_batch_update_points, do_clear_payload, do_create_index, do_delete_index, do_delete_payload,
23
+ do_delete_points, do_delete_vectors, do_overwrite_payload, do_set_payload, do_update_vectors,
24
+ do_upsert_points, CreateFieldIndex, UpdateOperations,
25
+ };
26
+
27
/// Extractor for the `{field_name}` URL segment, parsed as a JSON path.
#[derive(Deserialize, Validate)]
struct FieldPath {
    #[serde(rename = "field_name")]
    name: JsonPath,
}
32
+
33
/// Common query parameters shared by all point-update endpoints.
#[derive(Deserialize, Serialize, JsonSchema, Validate)]
pub struct UpdateParam {
    // When true, the response waits for the operation to be applied
    // (handlers default this to `false`).
    pub wait: Option<bool>,
    // Write ordering guarantee; handlers fall back to `WriteOrdering::default()`.
    pub ordering: Option<WriteOrdering>,
}
38
+
39
+ #[put("/collections/{name}/points")]
40
+ async fn upsert_points(
41
+ dispatcher: web::Data<Dispatcher>,
42
+ collection: Path<CollectionPath>,
43
+ operation: Json<PointInsertOperations>,
44
+ params: Query<UpdateParam>,
45
+ ActixAccess(access): ActixAccess,
46
+ ) -> impl Responder {
47
+ // nothing to verify.
48
+ let pass = new_unchecked_verification_pass();
49
+
50
+ let operation = operation.into_inner();
51
+ let wait = params.wait.unwrap_or(false);
52
+ let ordering = params.ordering.unwrap_or_default();
53
+
54
+ helpers::time(do_upsert_points(
55
+ dispatcher.toc(&access, &pass).clone(),
56
+ collection.into_inner().name,
57
+ operation,
58
+ None,
59
+ None,
60
+ wait,
61
+ ordering,
62
+ access,
63
+ ))
64
+ .await
65
+ }
66
+
67
+ #[post("/collections/{name}/points/delete")]
68
+ async fn delete_points(
69
+ dispatcher: web::Data<Dispatcher>,
70
+ collection: Path<CollectionPath>,
71
+ operation: Json<PointsSelector>,
72
+ params: Query<UpdateParam>,
73
+ ActixAccess(access): ActixAccess,
74
+ ) -> impl Responder {
75
+ let operation = operation.into_inner();
76
+ let pass =
77
+ match check_strict_mode(&operation, None, &collection.name, &dispatcher, &access).await {
78
+ Ok(pass) => pass,
79
+ Err(err) => return process_response_error(err, Instant::now(), None),
80
+ };
81
+
82
+ let wait = params.wait.unwrap_or(false);
83
+ let ordering = params.ordering.unwrap_or_default();
84
+
85
+ helpers::time(do_delete_points(
86
+ dispatcher.toc(&access, &pass).clone(),
87
+ collection.into_inner().name,
88
+ operation,
89
+ None,
90
+ None,
91
+ wait,
92
+ ordering,
93
+ access,
94
+ ))
95
+ .await
96
+ }
97
+
98
/// Update named vectors of existing points.
#[put("/collections/{name}/points/vectors")]
async fn update_vectors(
    dispatcher: web::Data<Dispatcher>,
    collection: Path<CollectionPath>,
    operation: Json<UpdateVectors>,
    params: Query<UpdateParam>,
    ActixAccess(access): ActixAccess,
) -> impl Responder {
    // Nothing to verify here.
    let pass = new_unchecked_verification_pass();

    let operation = operation.into_inner();
    let wait = params.wait.unwrap_or(false);
    let ordering = params.ordering.unwrap_or_default();

    helpers::time(do_update_vectors(
        dispatcher.toc(&access, &pass).clone(),
        collection.into_inner().name,
        operation,
        None,
        None,
        wait,
        ordering,
        access,
    ))
    .await
}
125
+
126
+ #[post("/collections/{name}/points/vectors/delete")]
127
+ async fn delete_vectors(
128
+ dispatcher: web::Data<Dispatcher>,
129
+ collection: Path<CollectionPath>,
130
+ operation: Json<DeleteVectors>,
131
+ params: Query<UpdateParam>,
132
+ ActixAccess(access): ActixAccess,
133
+ ) -> impl Responder {
134
+ let timing = Instant::now();
135
+
136
+ let operation = operation.into_inner();
137
+ let pass =
138
+ match check_strict_mode(&operation, None, &collection.name, &dispatcher, &access).await {
139
+ Ok(pass) => pass,
140
+ Err(err) => return process_response_error(err, timing, None),
141
+ };
142
+
143
+ let wait = params.wait.unwrap_or(false);
144
+ let ordering = params.ordering.unwrap_or_default();
145
+
146
+ let response = do_delete_vectors(
147
+ dispatcher.toc(&access, &pass).clone(),
148
+ collection.into_inner().name,
149
+ operation,
150
+ None,
151
+ None,
152
+ wait,
153
+ ordering,
154
+ access,
155
+ )
156
+ .await;
157
+ process_response(response, timing, None)
158
+ }
159
+
160
+ #[post("/collections/{name}/points/payload")]
161
+ async fn set_payload(
162
+ dispatcher: web::Data<Dispatcher>,
163
+ collection: Path<CollectionPath>,
164
+ operation: Json<SetPayload>,
165
+ params: Query<UpdateParam>,
166
+ ActixAccess(access): ActixAccess,
167
+ ) -> impl Responder {
168
+ let operation = operation.into_inner();
169
+
170
+ let pass =
171
+ match check_strict_mode(&operation, None, &collection.name, &dispatcher, &access).await {
172
+ Ok(pass) => pass,
173
+ Err(err) => return process_response_error(err, Instant::now(), None),
174
+ };
175
+
176
+ let wait = params.wait.unwrap_or(false);
177
+ let ordering = params.ordering.unwrap_or_default();
178
+
179
+ helpers::time(do_set_payload(
180
+ dispatcher.toc(&access, &pass).clone(),
181
+ collection.into_inner().name,
182
+ operation,
183
+ None,
184
+ None,
185
+ wait,
186
+ ordering,
187
+ access,
188
+ ))
189
+ .await
190
+ }
191
+
192
+ #[put("/collections/{name}/points/payload")]
193
+ async fn overwrite_payload(
194
+ dispatcher: web::Data<Dispatcher>,
195
+ collection: Path<CollectionPath>,
196
+ operation: Json<SetPayload>,
197
+ params: Query<UpdateParam>,
198
+ ActixAccess(access): ActixAccess,
199
+ ) -> impl Responder {
200
+ let operation = operation.into_inner();
201
+ let pass =
202
+ match check_strict_mode(&operation, None, &collection.name, &dispatcher, &access).await {
203
+ Ok(pass) => pass,
204
+ Err(err) => return process_response_error(err, Instant::now(), None),
205
+ };
206
+ let wait = params.wait.unwrap_or(false);
207
+ let ordering = params.ordering.unwrap_or_default();
208
+
209
+ helpers::time(do_overwrite_payload(
210
+ dispatcher.toc(&access, &pass).clone(),
211
+ collection.into_inner().name,
212
+ operation,
213
+ None,
214
+ None,
215
+ wait,
216
+ ordering,
217
+ access,
218
+ ))
219
+ .await
220
+ }
221
+
222
+ #[post("/collections/{name}/points/payload/delete")]
223
+ async fn delete_payload(
224
+ dispatcher: web::Data<Dispatcher>,
225
+ collection: Path<CollectionPath>,
226
+ operation: Json<DeletePayload>,
227
+ params: Query<UpdateParam>,
228
+ ActixAccess(access): ActixAccess,
229
+ ) -> impl Responder {
230
+ let operation = operation.into_inner();
231
+ let pass =
232
+ match check_strict_mode(&operation, None, &collection.name, &dispatcher, &access).await {
233
+ Ok(pass) => pass,
234
+ Err(err) => return process_response_error(err, Instant::now(), None),
235
+ };
236
+ let wait = params.wait.unwrap_or(false);
237
+ let ordering = params.ordering.unwrap_or_default();
238
+
239
+ helpers::time(do_delete_payload(
240
+ dispatcher.toc(&access, &pass).clone(),
241
+ collection.into_inner().name,
242
+ operation,
243
+ None,
244
+ None,
245
+ wait,
246
+ ordering,
247
+ access,
248
+ ))
249
+ .await
250
+ }
251
+
252
+ #[post("/collections/{name}/points/payload/clear")]
253
+ async fn clear_payload(
254
+ dispatcher: web::Data<Dispatcher>,
255
+ collection: Path<CollectionPath>,
256
+ operation: Json<PointsSelector>,
257
+ params: Query<UpdateParam>,
258
+ ActixAccess(access): ActixAccess,
259
+ ) -> impl Responder {
260
+ let operation = operation.into_inner();
261
+ let pass =
262
+ match check_strict_mode(&operation, None, &collection.name, &dispatcher, &access).await {
263
+ Ok(pass) => pass,
264
+ Err(err) => return process_response_error(err, Instant::now(), None),
265
+ };
266
+
267
+ let wait = params.wait.unwrap_or(false);
268
+ let ordering = params.ordering.unwrap_or_default();
269
+
270
+ helpers::time(do_clear_payload(
271
+ dispatcher.toc(&access, &pass).clone(),
272
+ collection.into_inner().name,
273
+ operation,
274
+ None,
275
+ None,
276
+ wait,
277
+ ordering,
278
+ access,
279
+ ))
280
+ .await
281
+ }
282
+
283
+ #[post("/collections/{name}/points/batch")]
284
+ async fn update_batch(
285
+ dispatcher: web::Data<Dispatcher>,
286
+ collection: Path<CollectionPath>,
287
+ operations: Json<UpdateOperations>,
288
+ params: Query<UpdateParam>,
289
+ ActixAccess(access): ActixAccess,
290
+ ) -> impl Responder {
291
+ let timing = Instant::now();
292
+ let operations = operations.into_inner();
293
+
294
+ let mut vpass = None;
295
+ for operation in operations.operations.iter() {
296
+ let pass = match check_strict_mode(operation, None, &collection.name, &dispatcher, &access)
297
+ .await
298
+ {
299
+ Ok(pass) => pass,
300
+ Err(err) => return process_response_error(err, Instant::now(), None),
301
+ };
302
+ vpass = Some(pass);
303
+ }
304
+
305
+ // vpass == None => No update operation available
306
+ let Some(pass) = vpass else {
307
+ return process_response::<Vec<UpdateResult>>(Ok(vec![]), timing, None);
308
+ };
309
+
310
+ let wait = params.wait.unwrap_or(false);
311
+ let ordering = params.ordering.unwrap_or_default();
312
+
313
+ let response = do_batch_update_points(
314
+ dispatcher.toc(&access, &pass).clone(),
315
+ collection.into_inner().name,
316
+ operations.operations,
317
+ None,
318
+ None,
319
+ wait,
320
+ ordering,
321
+ access,
322
+ )
323
+ .await;
324
+ process_response(response, timing, None)
325
+ }
326
/// Create a payload field index in a collection.
#[put("/collections/{name}/index")]
async fn create_field_index(
    dispatcher: web::Data<Dispatcher>,
    collection: Path<CollectionPath>,
    operation: Json<CreateFieldIndex>,
    params: Query<UpdateParam>,
    ActixAccess(access): ActixAccess,
) -> impl Responder {
    let timing = Instant::now();
    let operation = operation.into_inner();
    let wait = params.wait.unwrap_or(false);
    let ordering = params.ordering.unwrap_or_default();

    // Index creation takes the whole dispatcher (consensus-level operation),
    // unlike point updates which only need the ToC.
    let response = do_create_index(
        dispatcher.into_inner(),
        collection.into_inner().name,
        operation,
        None,
        None,
        wait,
        ordering,
        access,
    )
    .await;
    process_response(response, timing, None)
}
352
+
353
/// Drop a payload field index from a collection.
#[delete("/collections/{name}/index/{field_name}")]
async fn delete_field_index(
    dispatcher: web::Data<Dispatcher>,
    collection: Path<CollectionPath>,
    field: Path<FieldPath>,
    params: Query<UpdateParam>,
    ActixAccess(access): ActixAccess,
) -> impl Responder {
    let timing = Instant::now();
    let wait = params.wait.unwrap_or(false);
    let ordering = params.ordering.unwrap_or_default();

    // Index removal goes through the dispatcher (consensus-level operation).
    let response = do_delete_index(
        dispatcher.into_inner(),
        collection.into_inner().name,
        field.name.clone(),
        None,
        None,
        wait,
        ordering,
        access,
    )
    .await;
    process_response(response, timing, None)
}
378
+
379
// Configure services
/// Registers every point/payload/index update route on the service config.
pub fn config_update_api(cfg: &mut web::ServiceConfig) {
    cfg.service(upsert_points)
        .service(delete_points)
        .service(update_vectors)
        .service(delete_vectors)
        .service(set_payload)
        .service(overwrite_payload)
        .service(delete_payload)
        .service(clear_payload)
        .service(create_field_index)
        .service(delete_field_index)
        .service(update_batch);
}
src/actix/auth.rs ADDED
@@ -0,0 +1,160 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ use std::convert::Infallible;
2
+ use std::future::{ready, Ready};
3
+ use std::sync::Arc;
4
+
5
+ use actix_web::body::{BoxBody, EitherBody};
6
+ use actix_web::dev::{forward_ready, Service, ServiceRequest, ServiceResponse, Transform};
7
+ use actix_web::{Error, FromRequest, HttpMessage, HttpResponse, ResponseError};
8
+ use futures_util::future::LocalBoxFuture;
9
+ use storage::rbac::Access;
10
+
11
+ use super::helpers::HttpError;
12
+ use crate::common::auth::{AuthError, AuthKeys};
13
+
14
/// Actix middleware factory that authenticates requests via `AuthKeys`.
pub struct Auth {
    // Validators used to check incoming request credentials.
    auth_keys: AuthKeys,
    // Paths exempt from authentication (health checks, etc. — configured by caller).
    whitelist: Vec<WhitelistItem>,
}
18
+
19
impl Auth {
    /// Creates the middleware factory from the configured keys and the list of
    /// paths that bypass authentication.
    pub fn new(auth_keys: AuthKeys, whitelist: Vec<WhitelistItem>) -> Self {
        Self {
            auth_keys,
            whitelist,
        }
    }
}
27
+
28
// Standard Actix `Transform` wiring: wraps the inner service in an
// `AuthMiddleware` instance sharing the keys and whitelist.
impl<S, B> Transform<S, ServiceRequest> for Auth
where
    S: Service<ServiceRequest, Response = ServiceResponse<EitherBody<B, BoxBody>>, Error = Error>
        + 'static,
    S::Future: 'static,
    B: 'static,
{
    type Response = ServiceResponse<EitherBody<B, BoxBody>>;
    type Error = Error;
    type InitError = ();
    type Transform = AuthMiddleware<S>;
    type Future = Ready<Result<Self::Transform, Self::InitError>>;

    fn new_transform(&self, service: S) -> Self::Future {
        // Arc-wrap so the middleware's boxed futures can own cheap clones.
        ready(Ok(AuthMiddleware {
            auth_keys: Arc::new(self.auth_keys.clone()),
            whitelist: self.whitelist.clone(),
            service: Arc::new(service),
        }))
    }
}
49
+
50
/// A whitelisted path: the path string paired with its matching mode.
#[derive(Clone, Eq, PartialEq, Hash)]
pub struct WhitelistItem(pub String, pub PathMode);
52
+
53
+ impl WhitelistItem {
54
+ pub fn exact<S: Into<String>>(path: S) -> Self {
55
+ Self(path.into(), PathMode::Exact)
56
+ }
57
+
58
+ pub fn prefix<S: Into<String>>(path: S) -> Self {
59
+ Self(path.into(), PathMode::Prefix)
60
+ }
61
+
62
+ pub fn matches(&self, other: &str) -> bool {
63
+ self.1.check(&self.0, other)
64
+ }
65
+ }
66
+
67
/// How a whitelisted path string is compared against a request path.
#[derive(Copy, Clone, Eq, PartialEq, Hash)]
pub enum PathMode {
    /// Path must match exactly
    Exact,
    /// Path must have given prefix
    Prefix,
}
74
+
75
+ impl PathMode {
76
+ fn check(&self, key: &str, other: &str) -> bool {
77
+ match self {
78
+ Self::Exact => key == other,
79
+ Self::Prefix => other.starts_with(key),
80
+ }
81
+ }
82
+ }
83
+
84
/// Per-service middleware instance created by [`Auth`]; validates each request
/// before forwarding it to the wrapped service.
pub struct AuthMiddleware<S> {
    // Shared credential validators (Arc so boxed futures can clone cheaply).
    auth_keys: Arc<AuthKeys>,
    /// List of items whitelisted from authentication.
    whitelist: Vec<WhitelistItem>,
    // The wrapped downstream service.
    service: Arc<S>,
}
90
+
91
impl<S> AuthMiddleware<S> {
    /// True if `path` matches any whitelist entry, i.e. the request may skip
    /// authentication entirely.
    pub fn is_path_whitelisted(&self, path: &str) -> bool {
        self.whitelist.iter().any(|item| item.matches(path))
    }
}
96
+
97
impl<S, B> Service<ServiceRequest> for AuthMiddleware<S>
where
    S: Service<ServiceRequest, Response = ServiceResponse<EitherBody<B, BoxBody>>, Error = Error>
        + 'static,
    S::Future: 'static,
    B: 'static,
{
    type Response = ServiceResponse<EitherBody<B, BoxBody>>;
    type Error = Error;
    type Future = LocalBoxFuture<'static, Result<Self::Response, Self::Error>>;

    forward_ready!(service);

    /// Authenticates the request, stores the resulting [`Access`] in the
    /// request extensions for downstream extractors, and forwards to the
    /// wrapped service; whitelisted paths bypass authentication.
    fn call(&self, req: ServiceRequest) -> Self::Future {
        let path = req.path();

        // Whitelisted paths are forwarded untouched (no Access inserted; the
        // `ActixAccess` extractor falls back to full access in that case).
        if self.is_path_whitelisted(path) {
            return Box::pin(self.service.call(req));
        }

        let auth_keys = self.auth_keys.clone();
        let service = self.service.clone();
        Box::pin(async move {
            match auth_keys
                .validate_request(|key| req.headers().get(key).and_then(|val| val.to_str().ok()))
                .await
            {
                Ok(access) => {
                    let previous = req.extensions_mut().insert::<Access>(access);
                    debug_assert!(
                        previous.is_none(),
                        "Previous access object should not exist in the request"
                    );
                    service.call(req).await
                }
                // Map each auth failure to the corresponding HTTP status and
                // short-circuit without invoking the downstream service.
                Err(e) => {
                    let resp = match e {
                        AuthError::Unauthorized(e) => HttpResponse::Unauthorized().body(e),
                        AuthError::Forbidden(e) => HttpResponse::Forbidden().body(e),
                        AuthError::StorageError(e) => HttpError::from(e).error_response(),
                    };
                    Ok(req.into_response(resp).map_into_right_body())
                }
            }
        })
    }
}
144
+
145
/// Extractor newtype delivering the request's authenticated [`Access`] to handlers.
pub struct ActixAccess(pub Access);
146
+
147
impl FromRequest for ActixAccess {
    type Error = Infallible;
    type Future = Ready<Result<Self, Self::Error>>;

    /// Removes the `Access` the auth middleware stored in the request
    /// extensions; when none is present (auth disabled or whitelisted path),
    /// falls back to full access.
    fn from_request(
        req: &actix_web::HttpRequest,
        _payload: &mut actix_web::dev::Payload,
    ) -> Self::Future {
        let access = req.extensions_mut().remove::<Access>().unwrap_or_else(|| {
            Access::full("All requests have full by default access when API key is not configured")
        });
        ready(Ok(ActixAccess(access)))
    }
}
src/actix/certificate_helpers.rs ADDED
@@ -0,0 +1,203 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ use std::fmt::Debug;
2
+ use std::fs::File;
3
+ use std::io::{self, BufRead, BufReader};
4
+ use std::sync::Arc;
5
+ use std::time::{Duration, Instant};
6
+
7
+ use parking_lot::RwLock;
8
+ use rustls::client::VerifierBuilderError;
9
+ use rustls::pki_types::CertificateDer;
10
+ use rustls::server::{ClientHello, ResolvesServerCert, WebPkiClientVerifier};
11
+ use rustls::sign::CertifiedKey;
12
+ use rustls::{crypto, RootCertStore, ServerConfig};
13
+ use rustls_pemfile::Item;
14
+
15
+ use crate::settings::{Settings, TlsConfig};
16
+
17
/// Module-local result alias defaulting the error type to this module's [`Error`].
type Result<T> = std::result::Result<T, Error>;
18
+
19
/// A TTL based rotating server certificate resolver
#[derive(Debug)]
struct RotatingCertificateResolver {
    /// TLS configuration used for loading/refreshing certified key
    tls_config: TlsConfig,

    /// TTL for each rotation; `None` disables rotation entirely
    ttl: Option<Duration>,

    /// Current certified key, guarded for concurrent TLS handshakes
    key: RwLock<CertifiedKeyWithAge>,
}
31
+
32
impl RotatingCertificateResolver {
    /// Loads the initial certified key from `tls_config`; fails if the
    /// certificate or private key cannot be read.
    pub fn new(tls_config: TlsConfig, ttl: Option<Duration>) -> Result<Self> {
        let certified_key = load_certified_key(&tls_config)?;

        Ok(Self {
            tls_config,
            ttl,
            key: RwLock::new(CertifiedKeyWithAge::from(certified_key)),
        })
    }

    /// Get certificate key or refresh
    ///
    /// The key is automatically refreshed when the TTL is reached.
    /// If refreshing fails, an error is logged and the old key is persisted.
    fn get_key_or_refresh(&self) -> Arc<CertifiedKey> {
        // Get read-only lock to the key. If TTL is not configured or is not expired, return key.
        let key = self.key.read();
        let ttl = match self.ttl {
            Some(ttl) if key.is_expired(ttl) => ttl,
            _ => return key.key.clone(),
        };
        drop(key);

        // If TTL is expired:
        // - get read-write lock to the key
        // - *re-check that TTL is expired* (to avoid refreshing the key multiple times from concurrent threads)
        // - refresh and return the key
        let mut key = self.key.write();
        if key.is_expired(ttl) {
            if let Err(err) = key.refresh(&self.tls_config) {
                log::error!("Failed to refresh server TLS certificate, keeping current: {err}");
            }
        }

        key.key.clone()
    }
}
70
+
71
impl ResolvesServerCert for RotatingCertificateResolver {
    /// Supplies the server certificate for each TLS handshake, transparently
    /// refreshing it first when the configured TTL has elapsed.
    fn resolve(&self, _client_hello: ClientHello<'_>) -> Option<Arc<CertifiedKey>> {
        Some(self.get_key_or_refresh())
    }
}
76
+
77
/// A certified key together with its load timestamp, used for TTL checks.
#[derive(Debug)]
struct CertifiedKeyWithAge {
    /// Last time the certificate was updated/replaced
    last_update: Instant,

    /// Current certified key
    key: Arc<CertifiedKey>,
}
85
+
86
+ impl CertifiedKeyWithAge {
87
+ pub fn from(key: Arc<CertifiedKey>) -> Self {
88
+ Self {
89
+ last_update: Instant::now(),
90
+ key,
91
+ }
92
+ }
93
+
94
+ pub fn refresh(&mut self, tls_config: &TlsConfig) -> Result<()> {
95
+ *self = Self::from(load_certified_key(tls_config)?);
96
+ Ok(())
97
+ }
98
+
99
+ pub fn age(&self) -> Duration {
100
+ self.last_update.elapsed()
101
+ }
102
+
103
+ pub fn is_expired(&self, ttl: Duration) -> bool {
104
+ self.age() >= ttl
105
+ }
106
+ }
107
+
108
/// Load TLS configuration and construct certified key.
///
/// Reads the PEM certificate chain and private key from the paths in
/// `tls_config`; errors if no X.509 certificate or no supported private key
/// (PKCS#1, PKCS#8, or SEC1) is found.
fn load_certified_key(tls_config: &TlsConfig) -> Result<Arc<CertifiedKey>> {
    // Load certificates
    let certs: Vec<CertificateDer> = with_buf_read(&tls_config.cert, |rd| {
        rustls_pemfile::read_all(rd).collect::<io::Result<Vec<_>>>()
    })?
    .into_iter()
    // Keep only X.509 certificate items; skip keys or other PEM blocks.
    .filter_map(|item| match item {
        Item::X509Certificate(data) => Some(data),
        _ => None,
    })
    .collect();
    if certs.is_empty() {
        return Err(Error::NoServerCert);
    }

    // Load private key (first PEM item in the key file)
    let private_key_item =
        with_buf_read(&tls_config.key, rustls_pemfile::read_one)?.ok_or(Error::NoPrivateKey)?;
    let private_key = match private_key_item {
        Item::Pkcs1Key(pkey) => rustls_pki_types::PrivateKeyDer::from(pkey),
        Item::Pkcs8Key(pkey) => rustls_pki_types::PrivateKeyDer::from(pkey),
        Item::Sec1Key(pkey) => rustls_pki_types::PrivateKeyDer::from(pkey),
        _ => return Err(Error::InvalidPrivateKey),
    };
    let signing_key = crypto::ring::sign::any_supported_type(&private_key).map_err(Error::Sign)?;

    // Construct certified key
    let certified_key = CertifiedKey::new(certs, signing_key);
    Ok(Arc::new(certified_key))
}
139
+
140
/// Generate an actix server configuration with TLS
///
/// Uses TLS settings as configured in configuration by user.
pub fn actix_tls_server_config(settings: &Settings) -> Result<ServerConfig> {
    let config = ServerConfig::builder();
    let tls_config = settings
        .tls
        .clone()
        .ok_or_else(Settings::tls_config_is_undefined_error)
        .map_err(Error::Io)?;

    // Verify client CA or not
    let config = if settings.service.verify_https_client_certificate {
        // mTLS: build a client verifier from the configured CA certificate(s).
        let mut root_cert_store = RootCertStore::empty();
        let ca_certs: Vec<CertificateDer> = with_buf_read(&tls_config.ca_cert, |rd| {
            rustls_pemfile::certs(rd).collect()
        })?;
        root_cert_store.add_parsable_certificates(ca_certs);
        let client_cert_verifier = WebPkiClientVerifier::builder(root_cert_store.into())
            .build()
            .map_err(Error::ClientCertVerifier)?;
        config.with_client_cert_verifier(client_cert_verifier)
    } else {
        config.with_no_client_auth()
    };

    // Configure rotating certificate resolver.
    // A TTL of 0 (or unset) disables rotation.
    let ttl = match tls_config.cert_ttl {
        None | Some(0) => None,
        Some(seconds) => Some(Duration::from_secs(seconds)),
    };
    let cert_resolver = RotatingCertificateResolver::new(tls_config, ttl)?;
    let config = config.with_cert_resolver(Arc::new(cert_resolver));

    Ok(config)
}
176
+
177
+ fn with_buf_read<T>(path: &str, f: impl FnOnce(&mut dyn BufRead) -> io::Result<T>) -> Result<T> {
178
+ let file = File::open(path).map_err(|err| Error::OpenFile(err, path.into()))?;
179
+ let mut reader = BufReader::new(file);
180
+ let dyn_reader: &mut dyn BufRead = &mut reader;
181
+ f(dyn_reader).map_err(|err| Error::ReadFile(err, path.into()))
182
+ }
183
+
184
/// Actix TLS errors.
#[derive(thiserror::Error, Debug)]
pub enum Error {
    #[error("TLS file could not be opened: {1}")]
    OpenFile(#[source] io::Error, String),
    #[error("TLS file could not be read: {1}")]
    ReadFile(#[source] io::Error, String),
    #[error("general TLS IO error")]
    Io(#[source] io::Error),
    #[error("no server certificate found")]
    NoServerCert,
    #[error("no private key found")]
    NoPrivateKey,
    #[error("invalid private key")]
    InvalidPrivateKey,
    #[error("TLS signing error")]
    Sign(#[source] rustls::Error),
    #[error("client certificate verification")]
    ClientCertVerifier(#[source] VerifierBuilderError),
}
src/actix/helpers.rs ADDED
@@ -0,0 +1,179 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ use std::fmt::Debug;
2
+ use std::future::Future;
3
+
4
+ use actix_web::rt::time::Instant;
5
+ use actix_web::{http, HttpResponse, ResponseError};
6
+ use api::rest::models::{ApiResponse, ApiStatus, HardwareUsage};
7
+ use collection::operations::types::CollectionError;
8
+ use common::counter::hardware_accumulator::HwMeasurementAcc;
9
+ use serde::Serialize;
10
+ use storage::content_manager::errors::StorageError;
11
+ use storage::content_manager::toc::request_hw_counter::RequestHwCounter;
12
+ use storage::dispatcher::Dispatcher;
13
+
14
+ pub fn get_request_hardware_counter(
15
+ dispatcher: &Dispatcher,
16
+ collection_name: String,
17
+ report_to_api: bool,
18
+ ) -> RequestHwCounter {
19
+ RequestHwCounter::new(
20
+ HwMeasurementAcc::new_with_drain(&dispatcher.get_collection_hw_metrics(collection_name)),
21
+ report_to_api,
22
+ false,
23
+ )
24
+ }
25
+
26
+ pub fn accepted_response(timing: Instant, hardware_usage: Option<HardwareUsage>) -> HttpResponse {
27
+ HttpResponse::Accepted().json(ApiResponse::<()> {
28
+ result: None,
29
+ status: ApiStatus::Accepted,
30
+ time: timing.elapsed().as_secs_f64(),
31
+ usage: hardware_usage,
32
+ })
33
+ }
34
+
35
+ pub fn process_response<T>(
36
+ response: Result<T, StorageError>,
37
+ timing: Instant,
38
+ hardware_usage: Option<HardwareUsage>,
39
+ ) -> HttpResponse
40
+ where
41
+ T: Serialize,
42
+ {
43
+ match response {
44
+ Ok(res) => HttpResponse::Ok().json(ApiResponse {
45
+ result: Some(res),
46
+ status: ApiStatus::Ok,
47
+ time: timing.elapsed().as_secs_f64(),
48
+ usage: hardware_usage,
49
+ }),
50
+ Err(err) => process_response_error(err, timing, hardware_usage),
51
+ }
52
+ }
53
+
54
+ pub fn process_response_error(
55
+ err: StorageError,
56
+ timing: Instant,
57
+ hardware_usage: Option<HardwareUsage>,
58
+ ) -> HttpResponse {
59
+ log_service_error(&err);
60
+
61
+ let error = HttpError::from(err);
62
+
63
+ HttpResponse::build(error.status_code()).json(ApiResponse::<()> {
64
+ result: None,
65
+ status: ApiStatus::Error(error.to_string()),
66
+ time: timing.elapsed().as_secs_f64(),
67
+ usage: hardware_usage,
68
+ })
69
+ }
70
+
71
+ /// Response wrapper for a `Future` returning `Result`.
72
+ ///
73
+ /// # Cancel safety
74
+ ///
75
+ /// Future must be cancel safe.
76
+ pub async fn time<T, Fut>(future: Fut) -> HttpResponse
77
+ where
78
+ Fut: Future<Output = Result<T, StorageError>>,
79
+ T: serde::Serialize,
80
+ {
81
+ time_impl(async { future.await.map(Some) }).await
82
+ }
83
+
84
+ /// Response wrapper for a `Future` returning `Result`.
85
+ /// If `wait` is false, returns `202 Accepted` immediately.
86
+ pub async fn time_or_accept<T, Fut>(future: Fut, wait: bool) -> HttpResponse
87
+ where
88
+ Fut: Future<Output = Result<T, StorageError>> + Send + 'static,
89
+ T: serde::Serialize + Send + 'static,
90
+ {
91
+ let future = async move {
92
+ let handle = tokio::task::spawn(async move {
93
+ let result = future.await;
94
+
95
+ if !wait {
96
+ if let Err(err) = &result {
97
+ log_service_error(err);
98
+ }
99
+ }
100
+
101
+ result
102
+ });
103
+
104
+ if wait {
105
+ handle.await?.map(Some)
106
+ } else {
107
+ Ok(None)
108
+ }
109
+ };
110
+
111
+ time_impl(future).await
112
+ }
113
+
114
+ /// # Cancel safety
115
+ ///
116
+ /// Future must be cancel safe.
117
+ async fn time_impl<T, Fut>(future: Fut) -> HttpResponse
118
+ where
119
+ Fut: Future<Output = Result<Option<T>, StorageError>>,
120
+ T: serde::Serialize,
121
+ {
122
+ let instant = Instant::now();
123
+ match future.await.transpose() {
124
+ Some(res) => process_response(res, instant, None),
125
+ None => accepted_response(instant, None),
126
+ }
127
+ }
128
+
129
+ fn log_service_error(err: &StorageError) {
130
+ if let StorageError::ServiceError { backtrace, .. } = err {
131
+ log::error!("Error processing request: {err}");
132
+
133
+ if let Some(backtrace) = backtrace {
134
+ log::trace!("Backtrace: {backtrace}");
135
+ }
136
+ }
137
+ }
138
+
139
+ pub type HttpResult<T, E = HttpError> = Result<T, E>;
140
+
141
+ #[derive(Clone, Debug, thiserror::Error)]
142
+ #[error("{0}")]
143
+ pub struct HttpError(StorageError);
144
+
145
+ impl ResponseError for HttpError {
146
+ fn status_code(&self) -> http::StatusCode {
147
+ match &self.0 {
148
+ StorageError::BadInput { .. } => http::StatusCode::BAD_REQUEST,
149
+ StorageError::NotFound { .. } => http::StatusCode::NOT_FOUND,
150
+ StorageError::ServiceError { .. } => http::StatusCode::INTERNAL_SERVER_ERROR,
151
+ StorageError::BadRequest { .. } => http::StatusCode::BAD_REQUEST,
152
+ StorageError::Locked { .. } => http::StatusCode::FORBIDDEN,
153
+ StorageError::Timeout { .. } => http::StatusCode::REQUEST_TIMEOUT,
154
+ StorageError::AlreadyExists { .. } => http::StatusCode::CONFLICT,
155
+ StorageError::ChecksumMismatch { .. } => http::StatusCode::BAD_REQUEST,
156
+ StorageError::Forbidden { .. } => http::StatusCode::FORBIDDEN,
157
+ StorageError::PreconditionFailed { .. } => http::StatusCode::INTERNAL_SERVER_ERROR,
158
+ StorageError::InferenceError { .. } => http::StatusCode::BAD_REQUEST,
159
+ }
160
+ }
161
+ }
162
+
163
+ impl From<StorageError> for HttpError {
164
+ fn from(err: StorageError) -> Self {
165
+ HttpError(err)
166
+ }
167
+ }
168
+
169
+ impl From<CollectionError> for HttpError {
170
+ fn from(err: CollectionError) -> Self {
171
+ HttpError(err.into())
172
+ }
173
+ }
174
+
175
+ impl From<std::io::Error> for HttpError {
176
+ fn from(err: std::io::Error) -> Self {
177
+ HttpError(err.into()) // TODO: Is this good enough?.. 🤔
178
+ }
179
+ }
src/actix/mod.rs ADDED
@@ -0,0 +1,262 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #[allow(dead_code)] // May contain functions used in different binaries. Not actually dead
2
+ pub mod actix_telemetry;
3
+ pub mod api;
4
+ mod auth;
5
+ mod certificate_helpers;
6
+ #[allow(dead_code)] // May contain functions used in different binaries. Not actually dead
7
+ pub mod helpers;
8
+ pub mod web_ui;
9
+
10
+ use std::io;
11
+ use std::sync::Arc;
12
+
13
+ use ::api::rest::models::{ApiResponse, ApiStatus, VersionInfo};
14
+ use actix_cors::Cors;
15
+ use actix_multipart::form::tempfile::TempFileConfig;
16
+ use actix_multipart::form::MultipartFormConfig;
17
+ use actix_web::middleware::{Compress, Condition, Logger};
18
+ use actix_web::{error, get, web, App, HttpRequest, HttpResponse, HttpServer, Responder};
19
+ use actix_web_extras::middleware::Condition as ConditionEx;
20
+ use api::facet_api::config_facet_api;
21
+ use collection::operations::validation;
22
+ use collection::operations::verification::new_unchecked_verification_pass;
23
+ use storage::dispatcher::Dispatcher;
24
+ use storage::rbac::Access;
25
+
26
+ use crate::actix::api::cluster_api::config_cluster_api;
27
+ use crate::actix::api::collections_api::config_collections_api;
28
+ use crate::actix::api::count_api::count_points;
29
+ use crate::actix::api::debug_api::config_debugger_api;
30
+ use crate::actix::api::discovery_api::config_discovery_api;
31
+ use crate::actix::api::issues_api::config_issues_api;
32
+ use crate::actix::api::local_shard_api::config_local_shard_api;
33
+ use crate::actix::api::query_api::config_query_api;
34
+ use crate::actix::api::recommend_api::config_recommend_api;
35
+ use crate::actix::api::retrieve_api::{get_point, get_points, scroll_points};
36
+ use crate::actix::api::search_api::config_search_api;
37
+ use crate::actix::api::service_api::config_service_api;
38
+ use crate::actix::api::shards_api::config_shards_api;
39
+ use crate::actix::api::snapshot_api::config_snapshots_api;
40
+ use crate::actix::api::update_api::config_update_api;
41
+ use crate::actix::auth::{Auth, WhitelistItem};
42
+ use crate::actix::web_ui::{web_ui_factory, web_ui_folder, WEB_UI_PATH};
43
+ use crate::common::auth::AuthKeys;
44
+ use crate::common::debugger::DebuggerState;
45
+ use crate::common::health;
46
+ use crate::common::http_client::HttpClient;
47
+ use crate::common::telemetry::TelemetryCollector;
48
+ use crate::settings::{max_web_workers, Settings};
49
+ use crate::tracing::LoggerHandle;
50
+
51
+ #[get("/")]
52
+ pub async fn index() -> impl Responder {
53
+ HttpResponse::Ok().json(VersionInfo::default())
54
+ }
55
+
56
+ #[allow(dead_code)]
57
+ pub fn init(
58
+ dispatcher: Arc<Dispatcher>,
59
+ telemetry_collector: Arc<tokio::sync::Mutex<TelemetryCollector>>,
60
+ health_checker: Option<Arc<health::HealthChecker>>,
61
+ settings: Settings,
62
+ logger_handle: LoggerHandle,
63
+ ) -> io::Result<()> {
64
+ actix_web::rt::System::new().block_on(async {
65
+ // Nothing to verify here.
66
+ let pass = new_unchecked_verification_pass();
67
+ let auth_keys = AuthKeys::try_create(
68
+ &settings.service,
69
+ dispatcher
70
+ .toc(&Access::full("For JWT validation"), &pass)
71
+ .clone(),
72
+ );
73
+ let upload_dir = dispatcher
74
+ .toc(&Access::full("For upload dir"), &pass)
75
+ .upload_dir()
76
+ .unwrap();
77
+ let dispatcher_data = web::Data::from(dispatcher);
78
+ let actix_telemetry_collector = telemetry_collector
79
+ .lock()
80
+ .await
81
+ .actix_telemetry_collector
82
+ .clone();
83
+ let debugger_state = web::Data::new(DebuggerState::from_settings(&settings));
84
+ let telemetry_collector_data = web::Data::from(telemetry_collector);
85
+ let logger_handle_data = web::Data::new(logger_handle);
86
+ let http_client = web::Data::new(HttpClient::from_settings(&settings)?);
87
+ let health_checker = web::Data::new(health_checker);
88
+ let web_ui_available = web_ui_folder(&settings);
89
+ let service_config = web::Data::new(settings.service.clone());
90
+
91
+ let mut api_key_whitelist = vec![
92
+ WhitelistItem::exact("/"),
93
+ WhitelistItem::exact("/healthz"),
94
+ WhitelistItem::prefix("/readyz"),
95
+ WhitelistItem::prefix("/livez"),
96
+ ];
97
+ if web_ui_available.is_some() {
98
+ api_key_whitelist.push(WhitelistItem::prefix(WEB_UI_PATH));
99
+ }
100
+
101
+ let mut server = HttpServer::new(move || {
102
+ let cors = Cors::default()
103
+ .allow_any_origin()
104
+ .allow_any_method()
105
+ .allow_any_header();
106
+ let validate_path_config = actix_web_validator::PathConfig::default()
107
+ .error_handler(|err, rec| validation_error_handler("path parameters", err, rec));
108
+ let validate_query_config = actix_web_validator::QueryConfig::default()
109
+ .error_handler(|err, rec| validation_error_handler("query parameters", err, rec));
110
+ let validate_json_config = actix_web_validator::JsonConfig::default()
111
+ .limit(settings.service.max_request_size_mb * 1024 * 1024)
112
+ .error_handler(|err, rec| validation_error_handler("JSON body", err, rec));
113
+
114
+ let mut app = App::new()
115
+ .wrap(Compress::default()) // Reads the `Accept-Encoding` header to negotiate which compression codec to use.
116
+ // api_key middleware
117
+ // note: the last call to `wrap()` or `wrap_fn()` is executed first
118
+ .wrap(ConditionEx::from_option(auth_keys.as_ref().map(
119
+ |auth_keys| Auth::new(auth_keys.clone(), api_key_whitelist.clone()),
120
+ )))
121
+ .wrap(Condition::new(settings.service.enable_cors, cors))
122
+ .wrap(
123
+ // Set up logger, but avoid logging hot status endpoints
124
+ Logger::default()
125
+ .exclude("/")
126
+ .exclude("/metrics")
127
+ .exclude("/telemetry")
128
+ .exclude("/healthz")
129
+ .exclude("/readyz")
130
+ .exclude("/livez"),
131
+ )
132
+ .wrap(actix_telemetry::ActixTelemetryTransform::new(
133
+ actix_telemetry_collector.clone(),
134
+ ))
135
+ .app_data(dispatcher_data.clone())
136
+ .app_data(telemetry_collector_data.clone())
137
+ .app_data(logger_handle_data.clone())
138
+ .app_data(http_client.clone())
139
+ .app_data(debugger_state.clone())
140
+ .app_data(health_checker.clone())
141
+ .app_data(validate_path_config)
142
+ .app_data(validate_query_config)
143
+ .app_data(validate_json_config)
144
+ .app_data(TempFileConfig::default().directory(&upload_dir))
145
+ .app_data(MultipartFormConfig::default().total_limit(usize::MAX))
146
+ .app_data(service_config.clone())
147
+ .service(index)
148
+ .configure(config_collections_api)
149
+ .configure(config_snapshots_api)
150
+ .configure(config_update_api)
151
+ .configure(config_cluster_api)
152
+ .configure(config_service_api)
153
+ .configure(config_search_api)
154
+ .configure(config_recommend_api)
155
+ .configure(config_discovery_api)
156
+ .configure(config_query_api)
157
+ .configure(config_facet_api)
158
+ .configure(config_shards_api)
159
+ .configure(config_issues_api)
160
+ .configure(config_debugger_api)
161
+ .configure(config_local_shard_api)
162
+ // Ordering of services is important for correct path pattern matching
163
+ // See: <https://github.com/qdrant/qdrant/issues/3543>
164
+ .service(scroll_points)
165
+ .service(count_points)
166
+ .service(get_point)
167
+ .service(get_points);
168
+
169
+ if let Some(static_folder) = web_ui_available.as_deref() {
170
+ app = app.service(web_ui_factory(static_folder));
171
+ }
172
+
173
+ app
174
+ })
175
+ .workers(max_web_workers(&settings));
176
+
177
+ let port = settings.service.http_port;
178
+ let bind_addr = format!("{}:{}", settings.service.host, port);
179
+
180
+ // With TLS enabled, bind with certificate helper and Rustls, or bind regularly
181
+ server = if settings.service.enable_tls {
182
+ log::info!(
183
+ "TLS enabled for REST API (TTL: {})",
184
+ settings
185
+ .tls
186
+ .as_ref()
187
+ .and_then(|tls| tls.cert_ttl)
188
+ .map(|ttl| ttl.to_string())
189
+ .unwrap_or_else(|| "none".into()),
190
+ );
191
+
192
+ let config = certificate_helpers::actix_tls_server_config(&settings)
193
+ .map_err(|err| io::Error::new(io::ErrorKind::Other, err))?;
194
+ server.bind_rustls_0_23(bind_addr, config)?
195
+ } else {
196
+ log::info!("TLS disabled for REST API");
197
+
198
+ server.bind(bind_addr)?
199
+ };
200
+
201
+ log::info!("Qdrant HTTP listening on {}", port);
202
+ server.run().await
203
+ })
204
+ }
205
+
206
+ fn validation_error_handler(
207
+ name: &str,
208
+ err: actix_web_validator::Error,
209
+ _req: &HttpRequest,
210
+ ) -> error::Error {
211
+ use actix_web_validator::error::DeserializeErrors;
212
+
213
+ // Nicely describe deserialization and validation errors
214
+ let msg = match &err {
215
+ actix_web_validator::Error::Validate(errs) => {
216
+ validation::label_errors(format!("Validation error in {name}"), errs)
217
+ }
218
+ actix_web_validator::Error::Deserialize(err) => {
219
+ format!(
220
+ "Deserialize error in {name}: {}",
221
+ match err {
222
+ DeserializeErrors::DeserializeQuery(err) => err.to_string(),
223
+ DeserializeErrors::DeserializeJson(err) => err.to_string(),
224
+ DeserializeErrors::DeserializePath(err) => err.to_string(),
225
+ }
226
+ )
227
+ }
228
+ actix_web_validator::Error::JsonPayloadError(
229
+ actix_web::error::JsonPayloadError::Deserialize(err),
230
+ ) => {
231
+ format!("Format error in {name}: {err}",)
232
+ }
233
+ err => err.to_string(),
234
+ };
235
+
236
+ // Build fitting response
237
+ let response = match &err {
238
+ actix_web_validator::Error::Validate(_) => HttpResponse::UnprocessableEntity(),
239
+ _ => HttpResponse::BadRequest(),
240
+ }
241
+ .json(ApiResponse::<()> {
242
+ result: None,
243
+ status: ApiStatus::Error(msg),
244
+ time: 0.0,
245
+ usage: None,
246
+ });
247
+ error::InternalError::from_response(err, response).into()
248
+ }
249
+
250
#[cfg(test)]
mod tests {
    use ::api::grpc::api_crate_version;

    /// The `lib/api` crate version must stay in lockstep with the main crate,
    /// since clients negotiate compatibility against it.
    #[test]
    fn test_version() {
        assert_eq!(
            api_crate_version(),
            env!("CARGO_PKG_VERSION"),
            "Qdrant and lib/api crate versions are not same"
        );
    }
}
src/actix/web_ui.rs ADDED
@@ -0,0 +1,115 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ use std::path::Path;
2
+
3
+ use actix_web::dev::HttpServiceFactory;
4
+ use actix_web::http::header::HeaderValue;
5
+ use actix_web::middleware::DefaultHeaders;
6
+ use actix_web::web;
7
+
8
+ use crate::settings::Settings;
9
+
10
+ const DEFAULT_STATIC_DIR: &str = "./static";
11
+ pub const WEB_UI_PATH: &str = "/dashboard";
12
+
13
+ pub fn web_ui_folder(settings: &Settings) -> Option<String> {
14
+ let web_ui_enabled = settings.service.enable_static_content.unwrap_or(true);
15
+
16
+ if web_ui_enabled {
17
+ let static_folder = settings
18
+ .service
19
+ .static_content_dir
20
+ .clone()
21
+ .unwrap_or_else(|| DEFAULT_STATIC_DIR.to_string());
22
+ let static_folder_path = Path::new(&static_folder);
23
+ if !static_folder_path.exists() || !static_folder_path.is_dir() {
24
+ // enabled BUT folder does not exist
25
+ log::warn!(
26
+ "Static content folder for Web UI '{}' does not exist",
27
+ static_folder_path.display(),
28
+ );
29
+ None
30
+ } else {
31
+ // enabled AND folder exists
32
+ Some(static_folder)
33
+ }
34
+ } else {
35
+ // not enabled
36
+ None
37
+ }
38
+ }
39
+
40
+ pub fn web_ui_factory(static_folder: &str) -> impl HttpServiceFactory {
41
+ web::scope(WEB_UI_PATH)
42
+ .wrap(DefaultHeaders::new().add(("X-Frame-Options", HeaderValue::from_static("DENY"))))
43
+ .service(actix_files::Files::new("/", static_folder).index_file("index.html"))
44
+ }
45
+
46
#[cfg(test)]
mod tests {
    use actix_web::http::header::{self, HeaderMap};
    use actix_web::http::StatusCode;
    use actix_web::test::{self, TestRequest};
    use actix_web::App;

    use super::*;

    /// Dashboard HTML responses must carry the HTML content type and the
    /// anti-clickjacking header added by `web_ui_factory`.
    fn assert_html_custom_headers(headers: &HeaderMap) {
        let content_type = header::HeaderValue::from_static("text/html; charset=utf-8");
        assert_eq!(headers.get(header::CONTENT_TYPE), Some(&content_type));
        let x_frame_options = header::HeaderValue::from_static("DENY");
        assert_eq!(headers.get(header::X_FRAME_OPTIONS), Some(&x_frame_options),);
    }

    #[actix_web::test]
    async fn test_web_ui() {
        let mut settings = Settings::new(None).unwrap();
        settings.service.static_content_dir = Some(String::from("static"));

        let Some(static_folder) = web_ui_folder(&settings) else {
            println!("Skipping test because the static folder was not found.");
            return;
        };

        let srv = test::init_service(App::new().service(web_ui_factory(&static_folder))).await;

        // All three spellings of the index must serve the dashboard HTML:
        // no trailing slash, trailing slash, and the explicit file name.
        for uri in [
            WEB_UI_PATH.to_string(),
            format!("{WEB_UI_PATH}/"),
            format!("{WEB_UI_PATH}/index.html"),
        ] {
            let req = TestRequest::with_uri(&uri).to_request();
            let res = test::call_service(&srv, req).await;
            assert_eq!(res.status(), StatusCode::OK);
            assert_html_custom_headers(res.headers());
        }

        // Static asset (favicon.ico) comes back with its own content type.
        let req = TestRequest::with_uri(format!("{WEB_UI_PATH}/favicon.ico").as_str()).to_request();
        let res = test::call_service(&srv, req).await;
        assert_eq!(res.status(), StatusCode::OK);
        assert_eq!(
            res.headers().get(header::CONTENT_TYPE),
            Some(&header::HeaderValue::from_static("image/x-icon")),
        );

        // A factory pointed at a non-existing folder yields a bare 404.
        let fake_path = uuid::Uuid::new_v4().to_string();
        let srv = test::init_service(App::new().service(web_ui_factory(&fake_path))).await;

        let req = TestRequest::with_uri(WEB_UI_PATH).to_request();
        let res = test::call_service(&srv, req).await;
        assert_eq!(res.status(), StatusCode::NOT_FOUND);
        assert_eq!(res.headers().get(header::CONTENT_TYPE), None);
        assert_eq!(res.headers().get(header::CONTENT_LENGTH), None);
    }
}
src/common/auth/claims.rs ADDED
@@ -0,0 +1,69 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ use segment::json_path::JsonPath;
2
+ use segment::types::{Condition, FieldCondition, Filter, Match, ValueVariants};
3
+ use serde::{Deserialize, Serialize};
4
+ use storage::rbac::Access;
5
+ use validator::{Validate, ValidationErrors};
6
+
7
+ #[derive(Serialize, Deserialize, PartialEq, Clone, Debug)]
8
+ pub struct Claims {
9
+ /// Expiration time (seconds since UNIX epoch)
10
+ pub exp: Option<u64>,
11
+
12
+ #[serde(default = "default_access")]
13
+ pub access: Access,
14
+
15
+ /// Validate this token by looking for a value inside a collection.
16
+ pub value_exists: Option<ValueExists>,
17
+ }
18
+
19
+ #[derive(Serialize, Deserialize, PartialEq, Clone, Debug)]
20
+ pub struct KeyValuePair {
21
+ key: JsonPath,
22
+ value: ValueVariants,
23
+ }
24
+
25
+ impl KeyValuePair {
26
+ pub fn to_condition(&self) -> Condition {
27
+ Condition::Field(FieldCondition::new_match(
28
+ self.key.clone(),
29
+ Match::new_value(self.value.clone()),
30
+ ))
31
+ }
32
+ }
33
+
34
+ #[derive(Serialize, Deserialize, PartialEq, Clone, Debug)]
35
+ pub struct ValueExists {
36
+ collection: String,
37
+ matches: Vec<KeyValuePair>,
38
+ }
39
+
40
+ fn default_access() -> Access {
41
+ Access::full("Give full access when the access field is not present")
42
+ }
43
+
44
+ impl ValueExists {
45
+ pub fn get_collection(&self) -> &str {
46
+ &self.collection
47
+ }
48
+
49
+ pub fn to_filter(&self) -> Filter {
50
+ let conditions = self
51
+ .matches
52
+ .iter()
53
+ .map(|pair| pair.to_condition())
54
+ .collect();
55
+
56
+ Filter {
57
+ should: None,
58
+ min_should: None,
59
+ must: Some(conditions),
60
+ must_not: None,
61
+ }
62
+ }
63
+ }
64
+
65
+ impl Validate for Claims {
66
+ fn validate(&self) -> Result<(), ValidationErrors> {
67
+ ValidationErrors::merge_all(Ok(()), "access", self.access.validate())
68
+ }
69
+ }
src/common/auth/jwt_parser.rs ADDED
@@ -0,0 +1,155 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ use jsonwebtoken::errors::ErrorKind;
2
+ use jsonwebtoken::{decode, Algorithm, DecodingKey, Validation};
3
+ use validator::Validate;
4
+
5
+ use super::claims::Claims;
6
+ use super::AuthError;
7
+
8
+ #[derive(Clone)]
9
+ pub struct JwtParser {
10
+ key: DecodingKey,
11
+ validation: Validation,
12
+ }
13
+
14
+ impl JwtParser {
15
+ const ALGORITHM: Algorithm = Algorithm::HS256;
16
+
17
+ pub fn new(secret: &str) -> Self {
18
+ let key = DecodingKey::from_secret(secret.as_bytes());
19
+ let mut validation = Validation::new(Self::ALGORITHM);
20
+
21
+ // Qdrant server is the only audience
22
+ validation.validate_aud = false;
23
+
24
+ // Expiration time leeway to account for clock skew
25
+ validation.leeway = 30;
26
+
27
+ // All claims are optional
28
+ validation.required_spec_claims = Default::default();
29
+
30
+ JwtParser { key, validation }
31
+ }
32
+
33
+ /// Decode the token and return the claims, this already validates the `exp` claim with some leeway.
34
+ /// Returns None when the token doesn't look like a JWT.
35
+ pub fn decode(&self, token: &str) -> Option<Result<Claims, AuthError>> {
36
+ let claims = match decode::<Claims>(token, &self.key, &self.validation) {
37
+ Ok(token_data) => token_data.claims,
38
+ Err(e) => {
39
+ return match e.kind() {
40
+ ErrorKind::ExpiredSignature | ErrorKind::InvalidSignature => {
41
+ Some(Err(AuthError::Forbidden(e.to_string())))
42
+ }
43
+ _ => None,
44
+ }
45
+ }
46
+ };
47
+ if let Err(e) = claims.validate() {
48
+ return Some(Err(AuthError::Unauthorized(e.to_string())));
49
+ }
50
+ Some(Ok(claims))
51
+ }
52
+ }
53
+
54
#[cfg(test)]
mod tests {
    use segment::types::ValueVariants;
    use storage::rbac::{
        Access, CollectionAccess, CollectionAccessList, CollectionAccessMode, GlobalAccessMode,
        PayloadConstraint,
    };

    use super::*;

    /// Signs `claims` with the shared test secret (`"secret"`).
    pub fn create_token(claims: &Claims) -> String {
        use jsonwebtoken::{encode, EncodingKey, Header};

        let key = EncodingKey::from_secret("secret".as_ref());
        let header = Header::new(JwtParser::ALGORITHM);
        encode(&header, claims, &key).unwrap()
    }

    /// Current time as seconds since the UNIX epoch (the `exp` claim unit).
    fn now_secs() -> u64 {
        std::time::SystemTime::now()
            .duration_since(std::time::UNIX_EPOCH)
            .expect("Time went backwards")
            .as_secs()
    }

    #[test]
    fn test_jwt_parser() {
        let payload = PayloadConstraint(
            vec![
                (
                    "field1".parse().unwrap(),
                    ValueVariants::String("value".to_string()),
                ),
                ("field2".parse().unwrap(), ValueVariants::Integer(42)),
                ("field3".parse().unwrap(), ValueVariants::Bool(true)),
            ]
            .into_iter()
            .collect(),
        );
        let claims = Claims {
            exp: Some(now_secs()),
            access: Access::Collection(CollectionAccessList(vec![CollectionAccess {
                collection: "collection".to_string(),
                access: CollectionAccessMode::ReadWrite,
                payload: Some(payload),
            }])),
            value_exists: None,
        };
        let token = create_token(&claims);

        let parser = JwtParser::new("secret");
        let decoded_claims = parser.decode(&token).unwrap().unwrap();

        // Round trip: whatever we signed must come back out verbatim.
        assert_eq!(claims, decoded_claims);
    }

    #[test]
    fn test_exp_validation() {
        // 31 seconds in the past, bigger than the 30 seconds leeway
        let exp = now_secs() - 31;

        let mut claims = Claims {
            exp: Some(exp),
            access: Access::Global(GlobalAccessMode::Read),
            value_exists: None,
        };

        let token = create_token(&claims);
        let parser = JwtParser::new("secret");
        assert!(matches!(
            parser.decode(&token),
            Some(Err(AuthError::Forbidden(_)))
        ));

        // Remove the exp claim and it should work
        claims.exp = None;
        let token = create_token(&claims);

        let decoded_claims = parser.decode(&token).unwrap().unwrap();

        assert_eq!(claims, decoded_claims);
    }

    #[test]
    fn test_invalid_token() {
        let claims = Claims {
            exp: None,
            access: Access::Global(GlobalAccessMode::Read),
            value_exists: None,
        };
        let token = create_token(&claims);

        // Wrong secret → invalid signature → Forbidden.
        assert!(matches!(
            JwtParser::new("wrong-secret").decode(&token),
            Some(Err(AuthError::Forbidden(_)))
        ));

        // Something that is not a JWT at all is silently ignored (None).
        assert!(JwtParser::new("secret").decode("foo.bar.baz").is_none());
    }
}
src/common/auth/mod.rs ADDED
@@ -0,0 +1,165 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ use std::sync::Arc;
2
+
3
+ use collection::operations::shard_selector_internal::ShardSelectorInternal;
4
+ use collection::operations::types::ScrollRequestInternal;
5
+ use segment::types::{WithPayloadInterface, WithVector};
6
+ use storage::content_manager::errors::StorageError;
7
+ use storage::content_manager::toc::TableOfContent;
8
+ use storage::rbac::Access;
9
+
10
+ use self::claims::{Claims, ValueExists};
11
+ use self::jwt_parser::JwtParser;
12
+ use super::strings::ct_eq;
13
+ use crate::settings::ServiceConfig;
14
+
15
+ pub mod claims;
16
+ pub mod jwt_parser;
17
+
18
+ pub const HTTP_HEADER_API_KEY: &str = "api-key";
19
+
20
+ /// The API keys used for auth
21
+ #[derive(Clone)]
22
+ pub struct AuthKeys {
23
+ /// A key allowing Read or Write operations
24
+ read_write: Option<String>,
25
+
26
+ /// A key allowing Read operations
27
+ read_only: Option<String>,
28
+
29
+ /// A JWT parser, based on the read_write key
30
+ jwt_parser: Option<JwtParser>,
31
+
32
+ /// Table of content, needed to do stateful validation of JWT
33
+ toc: Arc<TableOfContent>,
34
+ }
35
+
36
+ #[derive(Debug)]
37
+ pub enum AuthError {
38
+ Unauthorized(String),
39
+ Forbidden(String),
40
+ StorageError(StorageError),
41
+ }
42
+
43
+ impl AuthKeys {
44
+ fn get_jwt_parser(service_config: &ServiceConfig) -> Option<JwtParser> {
45
+ if service_config.jwt_rbac.unwrap_or_default() {
46
+ service_config
47
+ .api_key
48
+ .as_ref()
49
+ .map(|secret| JwtParser::new(secret))
50
+ } else {
51
+ None
52
+ }
53
+ }
54
+
55
+ /// Defines the auth scheme given the service config
56
+ ///
57
+ /// Returns None if no scheme is specified.
58
+ pub fn try_create(service_config: &ServiceConfig, toc: Arc<TableOfContent>) -> Option<Self> {
59
+ match (
60
+ service_config.api_key.clone(),
61
+ service_config.read_only_api_key.clone(),
62
+ ) {
63
+ (None, None) => None,
64
+ (read_write, read_only) => Some(Self {
65
+ read_write,
66
+ read_only,
67
+ jwt_parser: Self::get_jwt_parser(service_config),
68
+ toc,
69
+ }),
70
+ }
71
+ }
72
+
73
+ /// Validate that the specified request is allowed for given keys.
74
+ pub async fn validate_request<'a>(
75
+ &self,
76
+ get_header: impl Fn(&'a str) -> Option<&'a str>,
77
+ ) -> Result<Access, AuthError> {
78
+ let Some(key) = get_header(HTTP_HEADER_API_KEY)
79
+ .or_else(|| get_header("authorization").and_then(|v| v.strip_prefix("Bearer ")))
80
+ else {
81
+ return Err(AuthError::Unauthorized(
82
+ "Must provide an API key or an Authorization bearer token".to_string(),
83
+ ));
84
+ };
85
+
86
+ if self.can_write(key) {
87
+ return Ok(Access::full("Read-write access by key"));
88
+ }
89
+
90
+ if self.can_read(key) {
91
+ return Ok(Access::full_ro("Read-only access by key"));
92
+ }
93
+
94
+ if let Some(claims) = self.jwt_parser.as_ref().and_then(|p| p.decode(key)) {
95
+ let Claims {
96
+ exp: _, // already validated on decoding
97
+ access,
98
+ value_exists,
99
+ } = claims?;
100
+
101
+ if let Some(value_exists) = value_exists {
102
+ self.validate_value_exists(&value_exists).await?;
103
+ }
104
+
105
+ return Ok(access);
106
+ }
107
+
108
+ Err(AuthError::Unauthorized(
109
+ "Invalid API key or JWT".to_string(),
110
+ ))
111
+ }
112
+
113
+ async fn validate_value_exists(&self, value_exists: &ValueExists) -> Result<(), AuthError> {
114
+ let scroll_req = ScrollRequestInternal {
115
+ offset: None,
116
+ limit: Some(1),
117
+ filter: Some(value_exists.to_filter()),
118
+ with_payload: Some(WithPayloadInterface::Bool(false)),
119
+ with_vector: WithVector::Bool(false),
120
+ order_by: None,
121
+ };
122
+
123
+ let res = self
124
+ .toc
125
+ .scroll(
126
+ value_exists.get_collection(),
127
+ scroll_req,
128
+ None,
129
+ None, // no timeout
130
+ ShardSelectorInternal::All,
131
+ Access::full("JWT stateful validation"),
132
+ )
133
+ .await
134
+ .map_err(|e| match e {
135
+ StorageError::NotFound { .. } => {
136
+ AuthError::Forbidden("Invalid JWT, stateful validation failed".to_string())
137
+ }
138
+ _ => AuthError::StorageError(e),
139
+ })?;
140
+
141
+ if res.points.is_empty() {
142
+ return Err(AuthError::Unauthorized(
143
+ "Invalid JWT, stateful validation failed".to_string(),
144
+ ));
145
+ };
146
+
147
+ Ok(())
148
+ }
149
+
150
+ /// Check if a key is allowed to read
151
+ #[inline]
152
+ fn can_read(&self, key: &str) -> bool {
153
+ self.read_only
154
+ .as_ref()
155
+ .is_some_and(|ro_key| ct_eq(ro_key, key))
156
+ }
157
+
158
+ /// Check if a key is allowed to write
159
+ #[inline]
160
+ fn can_write(&self, key: &str) -> bool {
161
+ self.read_write
162
+ .as_ref()
163
+ .is_some_and(|rw_key| ct_eq(rw_key, key))
164
+ }
165
+ }
src/common/collections.rs ADDED
@@ -0,0 +1,834 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ use std::collections::HashMap;
2
+ use std::sync::Arc;
3
+ use std::time::Duration;
4
+
5
+ use api::grpc::qdrant::CollectionExists;
6
+ use api::rest::models::{CollectionDescription, CollectionsResponse};
7
+ use collection::config::ShardingMethod;
8
+ use collection::operations::cluster_ops::{
9
+ AbortTransferOperation, ClusterOperations, DropReplicaOperation, MoveShardOperation,
10
+ ReplicateShardOperation, ReshardingDirection, RestartTransfer, RestartTransferOperation,
11
+ StartResharding,
12
+ };
13
+ use collection::operations::shard_selector_internal::ShardSelectorInternal;
14
+ use collection::operations::snapshot_ops::SnapshotDescription;
15
+ use collection::operations::types::{
16
+ AliasDescription, CollectionClusterInfo, CollectionInfo, CollectionsAliasesResponse,
17
+ };
18
+ use collection::operations::verification::new_unchecked_verification_pass;
19
+ use collection::shards::replica_set;
20
+ use collection::shards::resharding::ReshardKey;
21
+ use collection::shards::shard::{PeerId, ShardId, ShardsPlacement};
22
+ use collection::shards::transfer::{ShardTransfer, ShardTransferKey, ShardTransferRestart};
23
+ use itertools::Itertools;
24
+ use rand::prelude::SliceRandom;
25
+ use rand::seq::IteratorRandom;
26
+ use storage::content_manager::collection_meta_ops::ShardTransferOperations::{Abort, Start};
27
+ use storage::content_manager::collection_meta_ops::{
28
+ CollectionMetaOperations, CreateShardKey, DropShardKey, ReshardingOperation,
29
+ SetShardReplicaState, ShardTransferOperations, UpdateCollectionOperation,
30
+ };
31
+ use storage::content_manager::errors::StorageError;
32
+ use storage::content_manager::toc::TableOfContent;
33
+ use storage::dispatcher::Dispatcher;
34
+ use storage::rbac::{Access, AccessRequirements};
35
+
36
+ pub async fn do_collection_exists(
37
+ toc: &TableOfContent,
38
+ access: Access,
39
+ name: &str,
40
+ ) -> Result<CollectionExists, StorageError> {
41
+ let collection_pass = access.check_collection_access(name, AccessRequirements::new())?;
42
+
43
+ // if this returns Ok, it means the collection exists.
44
+ // if not, we check that the error is NotFound
45
+ let Err(error) = toc.get_collection(&collection_pass).await else {
46
+ return Ok(CollectionExists { exists: true });
47
+ };
48
+ match error {
49
+ StorageError::NotFound { .. } => Ok(CollectionExists { exists: false }),
50
+ e => Err(e),
51
+ }
52
+ }
53
+
54
+ pub async fn do_get_collection(
55
+ toc: &TableOfContent,
56
+ access: Access,
57
+ name: &str,
58
+ shard_selection: Option<ShardId>,
59
+ ) -> Result<CollectionInfo, StorageError> {
60
+ let collection_pass =
61
+ access.check_collection_access(name, AccessRequirements::new().whole())?;
62
+
63
+ let collection = toc.get_collection(&collection_pass).await?;
64
+
65
+ let shard_selection = match shard_selection {
66
+ None => ShardSelectorInternal::All,
67
+ Some(shard_id) => ShardSelectorInternal::ShardId(shard_id),
68
+ };
69
+
70
+ Ok(collection.info(&shard_selection).await?)
71
+ }
72
+
73
+ pub async fn do_list_collections(
74
+ toc: &TableOfContent,
75
+ access: Access,
76
+ ) -> Result<CollectionsResponse, StorageError> {
77
+ let collections = toc
78
+ .all_collections(&access)
79
+ .await
80
+ .into_iter()
81
+ .map(|pass| CollectionDescription {
82
+ name: pass.name().to_string(),
83
+ })
84
+ .collect_vec();
85
+
86
+ Ok(CollectionsResponse { collections })
87
+ }
88
+
89
+ /// Construct shards-replicas layout for the shard from the given scope of peers
90
+ /// Example:
91
+ /// Shards: 3
92
+ /// Replicas: 2
93
+ /// Peers: [A, B, C]
94
+ ///
95
+ /// Placement:
96
+ /// [
97
+ /// [A, B]
98
+ /// [B, C]
99
+ /// [A, C]
100
+ /// ]
101
+ fn generate_even_placement(
102
+ mut pool: Vec<PeerId>,
103
+ shard_number: usize,
104
+ replication_factor: usize,
105
+ ) -> ShardsPlacement {
106
+ let mut exact_placement = Vec::new();
107
+ let mut rng = rand::thread_rng();
108
+ pool.shuffle(&mut rng);
109
+ let mut loop_iter = pool.iter().cycle();
110
+
111
+ // pool: [1,2,3,4]
112
+ // shuf_pool: [2,3,4,1]
113
+ //
114
+ // loop_iter: [2, 3, 4, 1, 2, 3, 4, 1, 2, 3, 4, 1, 2, 3, 4, 1,...]
115
+ // shard_placement: [2, 3, 4][1, 2, 3][4, 1, 2][3, 4, 1][2, 3, 4]
116
+
117
+ let max_replication_factor = std::cmp::min(replication_factor, pool.len());
118
+ for _shard in 0..shard_number {
119
+ let mut shard_placement = Vec::new();
120
+ for _replica in 0..max_replication_factor {
121
+ shard_placement.push(*loop_iter.next().unwrap());
122
+ }
123
+ exact_placement.push(shard_placement);
124
+ }
125
+ exact_placement
126
+ }
127
+
128
+ pub async fn do_list_collection_aliases(
129
+ toc: &TableOfContent,
130
+ access: Access,
131
+ collection_name: &str,
132
+ ) -> Result<CollectionsAliasesResponse, StorageError> {
133
+ let collection_pass =
134
+ access.check_collection_access(collection_name, AccessRequirements::new())?;
135
+ let aliases: Vec<AliasDescription> = toc
136
+ .collection_aliases(&collection_pass, &access)
137
+ .await?
138
+ .into_iter()
139
+ .map(|alias| AliasDescription {
140
+ alias_name: alias,
141
+ collection_name: collection_name.to_string(),
142
+ })
143
+ .collect();
144
+ Ok(CollectionsAliasesResponse { aliases })
145
+ }
146
+
147
+ pub async fn do_list_aliases(
148
+ toc: &TableOfContent,
149
+ access: Access,
150
+ ) -> Result<CollectionsAliasesResponse, StorageError> {
151
+ let aliases = toc.list_aliases(&access).await?;
152
+ Ok(CollectionsAliasesResponse { aliases })
153
+ }
154
+
155
+ pub async fn do_list_snapshots(
156
+ toc: &TableOfContent,
157
+ access: Access,
158
+ collection_name: &str,
159
+ ) -> Result<Vec<SnapshotDescription>, StorageError> {
160
+ let collection_pass =
161
+ access.check_collection_access(collection_name, AccessRequirements::new().whole())?;
162
+ Ok(toc
163
+ .get_collection(&collection_pass)
164
+ .await?
165
+ .list_snapshots()
166
+ .await?)
167
+ }
168
+
169
+ pub async fn do_create_snapshot(
170
+ toc: Arc<TableOfContent>,
171
+ access: Access,
172
+ collection_name: &str,
173
+ ) -> Result<SnapshotDescription, StorageError> {
174
+ let collection_pass = access
175
+ .check_collection_access(collection_name, AccessRequirements::new().write().whole())?
176
+ .into_static();
177
+
178
+ let result = tokio::spawn(async move { toc.create_snapshot(&collection_pass).await }).await??;
179
+
180
+ Ok(result)
181
+ }
182
+
183
+ pub async fn do_get_collection_cluster(
184
+ toc: &TableOfContent,
185
+ access: Access,
186
+ name: &str,
187
+ ) -> Result<CollectionClusterInfo, StorageError> {
188
+ let collection_pass =
189
+ access.check_collection_access(name, AccessRequirements::new().whole())?;
190
+ let collection = toc.get_collection(&collection_pass).await?;
191
+ Ok(collection.cluster_info(toc.this_peer_id).await?)
192
+ }
193
+
194
+ pub async fn do_update_collection_cluster(
195
+ dispatcher: &Dispatcher,
196
+ collection_name: String,
197
+ operation: ClusterOperations,
198
+ access: Access,
199
+ wait_timeout: Option<Duration>,
200
+ ) -> Result<bool, StorageError> {
201
+ let collection_pass = access.check_collection_access(
202
+ &collection_name,
203
+ AccessRequirements::new().write().manage().whole(),
204
+ )?;
205
+
206
+ if dispatcher.consensus_state().is_none() {
207
+ return Err(StorageError::BadRequest {
208
+ description: "Distributed mode disabled".to_string(),
209
+ });
210
+ }
211
+ let consensus_state = dispatcher.consensus_state().unwrap();
212
+
213
+ let get_all_peer_ids = || {
214
+ consensus_state
215
+ .persistent
216
+ .read()
217
+ .peer_address_by_id
218
+ .read()
219
+ .keys()
220
+ .cloned()
221
+ .collect_vec()
222
+ };
223
+
224
+ let validate_peer_exists = |peer_id| {
225
+ let target_peer_exist = consensus_state
226
+ .persistent
227
+ .read()
228
+ .peer_address_by_id
229
+ .read()
230
+ .contains_key(&peer_id);
231
+ if !target_peer_exist {
232
+ return Err(StorageError::BadRequest {
233
+ description: format!("Peer {peer_id} does not exist"),
234
+ });
235
+ }
236
+ Ok(())
237
+ };
238
+
239
+ // All checks should've been done at this point.
240
+ let pass = new_unchecked_verification_pass();
241
+
242
+ let collection = dispatcher
243
+ .toc(&access, &pass)
244
+ .get_collection(&collection_pass)
245
+ .await?;
246
+
247
+ match operation {
248
+ ClusterOperations::MoveShard(MoveShardOperation { move_shard }) => {
249
+ // validate shard to move
250
+ if !collection.contains_shard(move_shard.shard_id).await {
251
+ return Err(StorageError::BadRequest {
252
+ description: format!(
253
+ "Shard {} of {} does not exist",
254
+ move_shard.shard_id, collection_name
255
+ ),
256
+ });
257
+ };
258
+
259
+ // validate target and source peer exists
260
+ validate_peer_exists(move_shard.to_peer_id)?;
261
+ validate_peer_exists(move_shard.from_peer_id)?;
262
+
263
+ // submit operation to consensus
264
+ dispatcher
265
+ .submit_collection_meta_op(
266
+ CollectionMetaOperations::TransferShard(
267
+ collection_name,
268
+ Start(ShardTransfer {
269
+ shard_id: move_shard.shard_id,
270
+ to_shard_id: move_shard.to_shard_id,
271
+ to: move_shard.to_peer_id,
272
+ from: move_shard.from_peer_id,
273
+ sync: false,
274
+ method: move_shard.method,
275
+ }),
276
+ ),
277
+ access,
278
+ wait_timeout,
279
+ )
280
+ .await
281
+ }
282
+ ClusterOperations::ReplicateShard(ReplicateShardOperation { replicate_shard }) => {
283
+ // validate shard to move
284
+ if !collection.contains_shard(replicate_shard.shard_id).await {
285
+ return Err(StorageError::BadRequest {
286
+ description: format!(
287
+ "Shard {} of {} does not exist",
288
+ replicate_shard.shard_id, collection_name
289
+ ),
290
+ });
291
+ };
292
+
293
+ // validate target peer exists
294
+ validate_peer_exists(replicate_shard.to_peer_id)?;
295
+
296
+ // validate source peer exists
297
+ validate_peer_exists(replicate_shard.from_peer_id)?;
298
+
299
+ // submit operation to consensus
300
+ dispatcher
301
+ .submit_collection_meta_op(
302
+ CollectionMetaOperations::TransferShard(
303
+ collection_name,
304
+ Start(ShardTransfer {
305
+ shard_id: replicate_shard.shard_id,
306
+ to_shard_id: replicate_shard.to_shard_id,
307
+ to: replicate_shard.to_peer_id,
308
+ from: replicate_shard.from_peer_id,
309
+ sync: true,
310
+ method: replicate_shard.method,
311
+ }),
312
+ ),
313
+ access,
314
+ wait_timeout,
315
+ )
316
+ .await
317
+ }
318
+ ClusterOperations::AbortTransfer(AbortTransferOperation { abort_transfer }) => {
319
+ let transfer = ShardTransferKey {
320
+ shard_id: abort_transfer.shard_id,
321
+ to_shard_id: abort_transfer.to_shard_id,
322
+ to: abort_transfer.to_peer_id,
323
+ from: abort_transfer.from_peer_id,
324
+ };
325
+
326
+ if !collection.check_transfer_exists(&transfer).await {
327
+ return Err(StorageError::NotFound {
328
+ description: format!(
329
+ "Shard transfer {} -> {} for collection {}:{} does not exist",
330
+ transfer.from, transfer.to, collection_name, transfer.shard_id
331
+ ),
332
+ });
333
+ }
334
+
335
+ dispatcher
336
+ .submit_collection_meta_op(
337
+ CollectionMetaOperations::TransferShard(
338
+ collection_name,
339
+ Abort {
340
+ transfer,
341
+ reason: "user request".to_string(),
342
+ },
343
+ ),
344
+ access,
345
+ wait_timeout,
346
+ )
347
+ .await
348
+ }
349
+ ClusterOperations::DropReplica(DropReplicaOperation { drop_replica }) => {
350
+ if !collection.contains_shard(drop_replica.shard_id).await {
351
+ return Err(StorageError::BadRequest {
352
+ description: format!(
353
+ "Shard {} of {} does not exist",
354
+ drop_replica.shard_id, collection_name
355
+ ),
356
+ });
357
+ };
358
+
359
+ validate_peer_exists(drop_replica.peer_id)?;
360
+
361
+ let mut update_operation = UpdateCollectionOperation::new_empty(collection_name);
362
+
363
+ update_operation.set_shard_replica_changes(vec![replica_set::Change::Remove(
364
+ drop_replica.shard_id,
365
+ drop_replica.peer_id,
366
+ )]);
367
+
368
+ dispatcher
369
+ .submit_collection_meta_op(
370
+ CollectionMetaOperations::UpdateCollection(update_operation),
371
+ access,
372
+ wait_timeout,
373
+ )
374
+ .await
375
+ }
376
+ ClusterOperations::CreateShardingKey(create_sharding_key_op) => {
377
+ let create_sharding_key = create_sharding_key_op.create_sharding_key;
378
+
379
+ // Validate that:
380
+ // - proper sharding method is used
381
+ // - key does not exist yet
382
+ //
383
+ // If placement suggested:
384
+ // - Peers exist
385
+
386
+ let state = collection.state().await;
387
+
388
+ match state.config.params.sharding_method.unwrap_or_default() {
389
+ ShardingMethod::Auto => {
390
+ return Err(StorageError::bad_request(
391
+ "Shard Key cannot be created with Auto sharding method",
392
+ ));
393
+ }
394
+ ShardingMethod::Custom => {}
395
+ }
396
+
397
+ let shard_number = create_sharding_key
398
+ .shards_number
399
+ .unwrap_or(state.config.params.shard_number)
400
+ .get() as usize;
401
+ let replication_factor = create_sharding_key
402
+ .replication_factor
403
+ .unwrap_or(state.config.params.replication_factor)
404
+ .get() as usize;
405
+
406
+ let shard_keys_mapping = state.shards_key_mapping;
407
+ if shard_keys_mapping.contains_key(&create_sharding_key.shard_key) {
408
+ return Err(StorageError::BadRequest {
409
+ description: format!(
410
+ "Sharding key {} already exists for collection {}",
411
+ create_sharding_key.shard_key, collection_name
412
+ ),
413
+ });
414
+ }
415
+
416
+ let peers_pool: Vec<_> = if let Some(placement) = create_sharding_key.placement {
417
+ if placement.is_empty() {
418
+ return Err(StorageError::BadRequest {
419
+ description: format!(
420
+ "Sharding key {} placement cannot be empty. If you want to use random placement, do not specify placement",
421
+ create_sharding_key.shard_key
422
+ ),
423
+ });
424
+ }
425
+
426
+ for peer_id in placement.iter().copied() {
427
+ validate_peer_exists(peer_id)?;
428
+ }
429
+ placement
430
+ } else {
431
+ get_all_peer_ids()
432
+ };
433
+
434
+ let exact_placement =
435
+ generate_even_placement(peers_pool, shard_number, replication_factor);
436
+
437
+ dispatcher
438
+ .submit_collection_meta_op(
439
+ CollectionMetaOperations::CreateShardKey(CreateShardKey {
440
+ collection_name,
441
+ shard_key: create_sharding_key.shard_key,
442
+ placement: exact_placement,
443
+ }),
444
+ access,
445
+ wait_timeout,
446
+ )
447
+ .await
448
+ }
449
+ ClusterOperations::DropShardingKey(drop_sharding_key_op) => {
450
+ let drop_sharding_key = drop_sharding_key_op.drop_sharding_key;
451
+ // Validate that:
452
+ // - proper sharding method is used
453
+ // - key does exist
454
+
455
+ let state = collection.state().await;
456
+
457
+ match state.config.params.sharding_method.unwrap_or_default() {
458
+ ShardingMethod::Auto => {
459
+ return Err(StorageError::bad_request(
460
+ "Shard Key cannot be created with Auto sharding method",
461
+ ));
462
+ }
463
+ ShardingMethod::Custom => {}
464
+ }
465
+
466
+ let shard_keys_mapping = state.shards_key_mapping;
467
+ if !shard_keys_mapping.contains_key(&drop_sharding_key.shard_key) {
468
+ return Err(StorageError::BadRequest {
469
+ description: format!(
470
+ "Sharding key {} does not exists for collection {}",
471
+ drop_sharding_key.shard_key, collection_name
472
+ ),
473
+ });
474
+ }
475
+
476
+ dispatcher
477
+ .submit_collection_meta_op(
478
+ CollectionMetaOperations::DropShardKey(DropShardKey {
479
+ collection_name,
480
+ shard_key: drop_sharding_key.shard_key,
481
+ }),
482
+ access,
483
+ wait_timeout,
484
+ )
485
+ .await
486
+ }
487
+ ClusterOperations::RestartTransfer(RestartTransferOperation { restart_transfer }) => {
488
+ // TODO(reshading): Deduplicate resharding operations handling?
489
+
490
+ let RestartTransfer {
491
+ shard_id,
492
+ to_shard_id,
493
+ from_peer_id,
494
+ to_peer_id,
495
+ method,
496
+ } = restart_transfer;
497
+
498
+ let transfer_key = ShardTransferKey {
499
+ shard_id,
500
+ to_shard_id,
501
+ to: to_peer_id,
502
+ from: from_peer_id,
503
+ };
504
+
505
+ if !collection.check_transfer_exists(&transfer_key).await {
506
+ return Err(StorageError::NotFound {
507
+ description: format!(
508
+ "Shard transfer {} -> {} for collection {}:{} does not exist",
509
+ transfer_key.from, transfer_key.to, collection_name, transfer_key.shard_id
510
+ ),
511
+ });
512
+ }
513
+
514
+ dispatcher
515
+ .submit_collection_meta_op(
516
+ CollectionMetaOperations::TransferShard(
517
+ collection_name,
518
+ ShardTransferOperations::Restart(ShardTransferRestart {
519
+ shard_id,
520
+ to_shard_id,
521
+ to: to_peer_id,
522
+ from: from_peer_id,
523
+ method,
524
+ }),
525
+ ),
526
+ access,
527
+ wait_timeout,
528
+ )
529
+ .await
530
+ }
531
+ ClusterOperations::StartResharding(op) => {
532
+ let StartResharding {
533
+ direction,
534
+ peer_id,
535
+ shard_key,
536
+ } = op.start_resharding;
537
+
538
+ let collection_state = collection.state().await;
539
+
540
+ if let Some(shard_key) = &shard_key {
541
+ if !collection_state.shards_key_mapping.contains_key(shard_key) {
542
+ return Err(StorageError::bad_request(format!(
543
+ "sharding key {shard_key} does not exists for collection {collection_name}"
544
+ )));
545
+ }
546
+ }
547
+
548
+ let shard_id = match (direction, shard_key.as_ref()) {
549
+ // When scaling up, just pick the next shard ID
550
+ (ReshardingDirection::Up, _) => {
551
+ collection_state
552
+ .shards
553
+ .keys()
554
+ .copied()
555
+ .max()
556
+ .expect("collection must contain shards")
557
+ + 1
558
+ }
559
+ // When scaling down without shard keys, pick the last shard ID
560
+ (ReshardingDirection::Down, None) => collection_state
561
+ .shards
562
+ .keys()
563
+ .copied()
564
+ .max()
565
+ .expect("collection must contain shards"),
566
+ // When scaling down with shard keys, pick the last shard ID of that key
567
+ (ReshardingDirection::Down, Some(shard_key)) => collection_state
568
+ .shards_key_mapping
569
+ .get(shard_key)
570
+ .expect("specified shard key must exist")
571
+ .iter()
572
+ .copied()
573
+ .max()
574
+ .expect("collection must contain shards"),
575
+ };
576
+
577
+ let peer_id = match (peer_id, direction) {
578
+ // Select user specified peer, but make sure it exists
579
+ (Some(peer_id), _) => {
580
+ validate_peer_exists(peer_id)?;
581
+ peer_id
582
+ }
583
+
584
+ // When scaling up, select peer with least number of shards for this collection
585
+ (None, ReshardingDirection::Up) => {
586
+ let mut shards_on_peers = collection_state
587
+ .shards
588
+ .values()
589
+ .flat_map(|shard_info| shard_info.replicas.keys())
590
+ .fold(HashMap::new(), |mut counts, peer_id| {
591
+ *counts.entry(*peer_id).or_insert(0) += 1;
592
+ counts
593
+ });
594
+ for peer_id in get_all_peer_ids() {
595
+ // Add registered peers not holding any shard yet
596
+ shards_on_peers.entry(peer_id).or_insert(0);
597
+ }
598
+ shards_on_peers
599
+ .into_iter()
600
+ .min_by_key(|(_, count)| *count)
601
+ .map(|(peer_id, _)| peer_id)
602
+ .expect("expected at least one peer")
603
+ }
604
+
605
+ // When scaling down, select random peer that contains the shard we're dropping
606
+ // Other peers work, but are less efficient due to remote operations
607
+ (None, ReshardingDirection::Down) => collection_state
608
+ .shards
609
+ .get(&shard_id)
610
+ .expect("select shard ID must always exist in collection state")
611
+ .replicas
612
+ .keys()
613
+ .choose(&mut rand::thread_rng())
614
+ .copied()
615
+ .unwrap(),
616
+ };
617
+
618
+ if let Some(resharding) = &collection_state.resharding {
619
+ return Err(StorageError::bad_request(format!(
620
+ "resharding {resharding:?} is already in progress \
621
+ for collection {collection_name}"
622
+ )));
623
+ }
624
+
625
+ dispatcher
626
+ .submit_collection_meta_op(
627
+ CollectionMetaOperations::Resharding(
628
+ collection_name.clone(),
629
+ ReshardingOperation::Start(ReshardKey {
630
+ direction,
631
+ peer_id,
632
+ shard_id,
633
+ shard_key,
634
+ }),
635
+ ),
636
+ access,
637
+ wait_timeout,
638
+ )
639
+ .await
640
+ }
641
+ ClusterOperations::AbortResharding(_) => {
642
+ // TODO(reshading): Deduplicate resharding operations handling?
643
+
644
+ let Some(state) = collection.resharding_state().await else {
645
+ return Err(StorageError::bad_request(format!(
646
+ "resharding is not in progress for collection {collection_name}"
647
+ )));
648
+ };
649
+
650
+ dispatcher
651
+ .submit_collection_meta_op(
652
+ CollectionMetaOperations::Resharding(
653
+ collection_name.clone(),
654
+ ReshardingOperation::Abort(ReshardKey {
655
+ direction: state.direction,
656
+ peer_id: state.peer_id,
657
+ shard_id: state.shard_id,
658
+ shard_key: state.shard_key.clone(),
659
+ }),
660
+ ),
661
+ access,
662
+ wait_timeout,
663
+ )
664
+ .await
665
+ }
666
+ ClusterOperations::FinishResharding(_) => {
667
+ // TODO(resharding): Deduplicate resharding operations handling?
668
+
669
+ let Some(state) = collection.resharding_state().await else {
670
+ return Err(StorageError::bad_request(format!(
671
+ "resharding is not in progress for collection {collection_name}"
672
+ )));
673
+ };
674
+
675
+ dispatcher
676
+ .submit_collection_meta_op(
677
+ CollectionMetaOperations::Resharding(
678
+ collection_name.clone(),
679
+ ReshardingOperation::Finish(state.key()),
680
+ ),
681
+ access,
682
+ wait_timeout,
683
+ )
684
+ .await
685
+ }
686
+
687
+ ClusterOperations::FinishMigratingPoints(op) => {
688
+ // TODO(resharding): Deduplicate resharding operations handling?
689
+
690
+ let Some(state) = collection.resharding_state().await else {
691
+ return Err(StorageError::bad_request(format!(
692
+ "resharding is not in progress for collection {collection_name}"
693
+ )));
694
+ };
695
+
696
+ let op = op.finish_migrating_points;
697
+
698
+ let shard_id = match (op.shard_id, state.direction) {
699
+ (Some(shard_id), _) => shard_id,
700
+ (None, ReshardingDirection::Up) => state.shard_id,
701
+ (None, ReshardingDirection::Down) => {
702
+ return Err(StorageError::bad_request(
703
+ "shard ID must be specified when resharding down",
704
+ ));
705
+ }
706
+ };
707
+
708
+ let peer_id = match (op.peer_id, state.direction) {
709
+ (Some(peer_id), _) => peer_id,
710
+ (None, ReshardingDirection::Up) => state.peer_id,
711
+ (None, ReshardingDirection::Down) => {
712
+ return Err(StorageError::bad_request(
713
+ "peer ID must be specified when resharding down",
714
+ ));
715
+ }
716
+ };
717
+
718
+ dispatcher
719
+ .submit_collection_meta_op(
720
+ CollectionMetaOperations::SetShardReplicaState(SetShardReplicaState {
721
+ collection_name: collection_name.clone(),
722
+ shard_id,
723
+ peer_id,
724
+ state: replica_set::ReplicaState::Active,
725
+ from_state: Some(replica_set::ReplicaState::Resharding),
726
+ }),
727
+ access,
728
+ wait_timeout,
729
+ )
730
+ .await
731
+ }
732
+
733
+ ClusterOperations::CommitReadHashRing(_) => {
734
+ // TODO(reshading): Deduplicate resharding operations handling?
735
+
736
+ let Some(state) = collection.resharding_state().await else {
737
+ return Err(StorageError::bad_request(format!(
738
+ "resharding is not in progress for collection {collection_name}"
739
+ )));
740
+ };
741
+
742
+ // TODO(resharding): Add precondition checks?
743
+
744
+ dispatcher
745
+ .submit_collection_meta_op(
746
+ CollectionMetaOperations::Resharding(
747
+ collection_name.clone(),
748
+ ReshardingOperation::CommitRead(ReshardKey {
749
+ direction: state.direction,
750
+ peer_id: state.peer_id,
751
+ shard_id: state.shard_id,
752
+ shard_key: state.shard_key.clone(),
753
+ }),
754
+ ),
755
+ access,
756
+ wait_timeout,
757
+ )
758
+ .await
759
+ }
760
+
761
+ ClusterOperations::CommitWriteHashRing(_) => {
762
+ // TODO(reshading): Deduplicate resharding operations handling?
763
+
764
+ let Some(state) = collection.resharding_state().await else {
765
+ return Err(StorageError::bad_request(format!(
766
+ "resharding is not in progress for collection {collection_name}"
767
+ )));
768
+ };
769
+
770
+ // TODO(resharding): Add precondition checks?
771
+
772
+ dispatcher
773
+ .submit_collection_meta_op(
774
+ CollectionMetaOperations::Resharding(
775
+ collection_name.clone(),
776
+ ReshardingOperation::CommitWrite(ReshardKey {
777
+ direction: state.direction,
778
+ peer_id: state.peer_id,
779
+ shard_id: state.shard_id,
780
+ shard_key: state.shard_key.clone(),
781
+ }),
782
+ ),
783
+ access,
784
+ wait_timeout,
785
+ )
786
+ .await
787
+ }
788
+ }
789
+ }
790
+
791
#[cfg(test)]
mod tests {
    use std::collections::HashSet;

    use super::*;

    #[test]
    fn test_generate_even_placement() {
        // Replication below pool size: replicas of one shard are distinct peers.
        let placement = generate_even_placement(vec![1, 2, 3], 3, 2);
        assert_eq!(placement.len(), 3);
        for shard in &placement {
            assert_eq!(shard.len(), 2);
            assert_ne!(shard[0], shard[1]);
        }

        // Replication equal to pool size: each shard uses every peer exactly once.
        let placement = generate_even_placement(vec![1, 2, 3], 3, 3);
        assert_eq!(placement.len(), 3);
        for shard in placement {
            assert_eq!(shard.len(), 3);
            let distinct: HashSet<_> = shard.into_iter().collect();
            assert_eq!(distinct.len(), 3);
        }

        // Pool exactly covers shards * replicas: every peer appears exactly once.
        let placement = generate_even_placement(vec![1, 2, 3, 4, 5, 6], 3, 2);
        assert_eq!(placement.len(), 3);
        let distinct: HashSet<_> = placement.into_iter().flatten().collect();
        assert_eq!(distinct.len(), 6);

        // Replication above pool size is capped at the number of peers.
        let placement = generate_even_placement(vec![1, 2, 3, 4, 5], 3, 10);
        assert_eq!(placement.len(), 3);
        assert!(placement.iter().all(|shard| shard.len() == 5));
    }
}
src/common/debugger.rs ADDED
@@ -0,0 +1,90 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ use std::sync::Arc;
2
+
3
+ use parking_lot::Mutex;
4
+ use schemars::JsonSchema;
5
+ use serde::{Deserialize, Serialize};
6
+
7
+ use crate::common::pyroscope_state::pyro::PyroscopeState;
8
+ use crate::settings::Settings;
9
+
10
/// Settings for the Pyroscope continuous-profiling agent.
#[derive(Serialize, JsonSchema, Debug, Deserialize, Clone)]
pub struct PyroscopeConfig {
    // Pyroscope server URL.
    pub url: String,
    // Identifier tagging profiles coming from this instance.
    pub identifier: String,
    // Optional credentials for the Pyroscope server.
    pub user: Option<String>,
    pub password: Option<String>,
    // Optional sampling rate; agent default applies when `None`
    // (presumably samples per second — confirm against PyroscopeState).
    pub sampling_rate: Option<u32>,
}
18
+
19
/// Top-level debugger configuration; currently only Pyroscope profiling.
#[derive(Default, Debug, Serialize, JsonSchema, Deserialize, Clone)]
pub struct DebuggerConfig {
    // `None` means profiling is not configured.
    pub pyroscope: Option<PyroscopeConfig>,
}
23
+
24
/// Patch request for the debugger configuration.
///
/// Each variant replaces one configuration section wholesale;
/// `Pyroscope(None)` stops the agent without starting a new one.
#[derive(Debug, Serialize, JsonSchema, Deserialize, Clone)]
#[serde(rename_all = "snake_case")]
pub enum DebugConfigPatch {
    Pyroscope(Option<PyroscopeConfig>),
}
29
+
30
/// Mutable runtime state backing the debug configuration API.
pub struct DebuggerState {
    // Pyroscope agent state; `None` when profiling is not active.
    // Only used on Linux, hence the dead_code allowance elsewhere.
    #[cfg_attr(not(target_os = "linux"), allow(dead_code))]
    pub pyroscope: Arc<Mutex<Option<PyroscopeState>>>,
}
34
+
35
impl DebuggerState {
    /// Build the debugger state from server settings.
    /// `PyroscopeState::from_config` presumably starts the agent when a
    /// config is present — confirm in `pyroscope_state`.
    pub fn from_settings(settings: &Settings) -> Self {
        let pyroscope_config = settings.debugger.pyroscope.clone();
        Self {
            pyroscope: Arc::new(Mutex::new(PyroscopeState::from_config(pyroscope_config))),
        }
    }

    /// Snapshot of the currently applied debugger configuration.
    ///
    /// Profiling is Linux-only; on other targets the Pyroscope section is
    /// always reported as `None`.
    #[cfg_attr(not(target_os = "linux"), allow(clippy::unused_self))]
    pub fn get_config(&self) -> DebuggerConfig {
        let pyroscope_config = {
            #[cfg(target_os = "linux")]
            {
                let pyroscope_state_guard = self.pyroscope.lock();
                pyroscope_state_guard.as_ref().map(|s| s.config.clone())
            }
            #[cfg(not(target_os = "linux"))]
            {
                None
            }
        };

        DebuggerConfig {
            pyroscope: pyroscope_config,
        }
    }

    /// Apply a configuration patch. Returns `true` when the patch took
    /// effect; `false` when a running agent could not be stopped.
    ///
    /// On non-Linux targets this is a no-op that always returns `false`.
    #[cfg_attr(not(target_os = "linux"), allow(clippy::unused_self))]
    pub fn apply_config_patch(&self, patch: DebugConfigPatch) -> bool {
        #[cfg(target_os = "linux")]
        {
            match patch {
                DebugConfigPatch::Pyroscope(new_config) => {
                    let mut pyroscope_guard = self.pyroscope.lock();
                    // Stop any running agent first; bail out if it refuses to stop.
                    if let Some(pyroscope_state) = pyroscope_guard.as_mut() {
                        let stopped = pyroscope_state.stop_agent();
                        if !stopped {
                            return false;
                        }
                    }

                    // A `None` config leaves profiling stopped; otherwise
                    // rebuild the state from the new config.
                    if let Some(new_config) = new_config {
                        *pyroscope_guard = PyroscopeState::from_config(Some(new_config));
                    }
                    true
                }
            }
        }

        #[cfg(not(target_os = "linux"))]
        {
            let _ = patch; // Ignore new_config on non-linux OS
            false
        }
    }
}
src/common/error_reporting.rs ADDED
@@ -0,0 +1,31 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ use std::time::Duration;
2
+
3
/// Sends error reports to the Qdrant telemetry endpoint.
pub struct ErrorReporter;
4
+
5
+ impl ErrorReporter {
6
+ fn get_url() -> String {
7
+ if cfg!(debug_assertions) {
8
+ "https://staging-telemetry.qdrant.io".to_string()
9
+ } else {
10
+ "https://telemetry.qdrant.io".to_string()
11
+ }
12
+ }
13
+
14
+ pub fn report(error: &str, reporting_id: &str, backtrace: Option<&str>) {
15
+ let client = reqwest::blocking::Client::new();
16
+
17
+ let report = serde_json::json!({
18
+ "id": reporting_id,
19
+ "error": error,
20
+ "backtrace": backtrace.unwrap_or(""),
21
+ });
22
+
23
+ let data = serde_json::to_string(&report).unwrap();
24
+ let _resp = client
25
+ .post(Self::get_url())
26
+ .body(data)
27
+ .header("Content-Type", "application/json")
28
+ .timeout(Duration::from_secs(1))
29
+ .send();
30
+ }
31
+ }
src/common/health.rs ADDED
@@ -0,0 +1,372 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ use std::collections::HashSet;
2
+ use std::future::{self, Future};
3
+ use std::sync::atomic::{self, AtomicBool};
4
+ use std::sync::Arc;
5
+ use std::time::Duration;
6
+ use std::{panic, thread};
7
+
8
+ use api::grpc::qdrant::qdrant_internal_client::QdrantInternalClient;
9
+ use api::grpc::qdrant::{GetConsensusCommitRequest, GetConsensusCommitResponse};
10
+ use api::grpc::transport_channel_pool::{self, TransportChannelPool};
11
+ use collection::shards::shard::ShardId;
12
+ use collection::shards::CollectionId;
13
+ use common::defaults;
14
+ use futures::stream::FuturesUnordered;
15
+ use futures::{FutureExt as _, StreamExt as _, TryStreamExt as _};
16
+ use itertools::Itertools;
17
+ use storage::content_manager::consensus_manager::ConsensusStateRef;
18
+ use storage::content_manager::toc::TableOfContent;
19
+ use storage::rbac::Access;
20
+ use tokio::{runtime, sync, time};
21
+
22
/// How long a `/readyz` call waits for the readiness signal before reporting "not ready".
const READY_CHECK_TIMEOUT: Duration = Duration::from_millis(500);
/// Retry count for the internal `GetConsensusCommit` gRPC request to each peer.
const GET_CONSENSUS_COMMITS_RETRIES: usize = 2;
24
+
25
/// Structure used to process health checks like `/readyz` endpoints.
pub struct HealthChecker {
    /// The state of the health checker.
    /// Once set to `true`, it should not change back to `false`.
    /// Initially set to `false`.
    is_ready: Arc<AtomicBool>,
    /// The signal that notifies that state has changed.
    /// Comes from the health checker task.
    is_ready_signal: Arc<sync::Notify>,
    /// Signal to the health checker task, that the API was called.
    /// Used to drive the health checker task and avoid constant polling.
    check_ready_signal: Arc<sync::Notify>,
    /// Cancels the background task when this `HealthChecker` is dropped.
    cancel: cancel::DropGuard,
}
39
+
40
impl HealthChecker {
    /// Spawn the background readiness task on `runtime` and return the handle
    /// that the `/readyz` endpoint uses to query it.
    ///
    /// `wait_for_bootstrap` makes the task wait until the node has joined a
    /// cluster before evaluating readiness (used with `--bootstrap`).
    pub fn spawn(
        toc: Arc<TableOfContent>,
        consensus_state: ConsensusStateRef,
        runtime: &runtime::Handle,
        wait_for_bootstrap: bool,
    ) -> Self {
        let task = Task {
            toc,
            consensus_state,
            is_ready: Default::default(),
            is_ready_signal: Default::default(),
            check_ready_signal: Default::default(),
            cancel: Default::default(),
            wait_for_bootstrap,
        };

        // Checker and task share the readiness flag and both notify channels;
        // the drop-guard cancels the task when the checker is dropped.
        let health_checker = Self {
            is_ready: task.is_ready.clone(),
            is_ready_signal: task.is_ready_signal.clone(),
            check_ready_signal: task.check_ready_signal.clone(),
            cancel: task.cancel.clone().drop_guard(),
        };

        let task = runtime.spawn(task.exec());
        drop(task); // drop `JoinFuture` explicitly to make clippy happy

        health_checker
    }

    /// Readiness check used by `/readyz`: fast-path on the cached flag,
    /// otherwise wake the task and wait (bounded) for its signal.
    pub async fn check_ready(&self) -> bool {
        if self.is_ready() {
            return true;
        }

        self.notify_task();
        self.wait_ready().await
    }

    /// Non-blocking read of the cached readiness flag.
    pub fn is_ready(&self) -> bool {
        self.is_ready.load(atomic::Ordering::Relaxed)
    }

    /// Wake the background task so it re-evaluates readiness.
    pub fn notify_task(&self) {
        self.check_ready_signal.notify_one();
    }

    /// Wait up to `READY_CHECK_TIMEOUT` for the readiness signal.
    async fn wait_ready(&self) -> bool {
        // Register for the notification *before* re-checking the flag, so a
        // state change between the check and the await cannot be missed.
        let is_ready_signal = self.is_ready_signal.notified();

        if self.is_ready() {
            return true;
        }

        time::timeout(READY_CHECK_TIMEOUT, is_ready_signal)
            .await
            .is_ok()
    }
}
99
+
100
/// Background task that computes node readiness; shares its state with
/// the `HealthChecker` handle.
pub struct Task {
    toc: Arc<TableOfContent>,
    consensus_state: ConsensusStateRef,
    /// Shared state with the health checker
    /// Once set to `true`, it should not change back to `false`.
    is_ready: Arc<AtomicBool>,
    /// Used to notify the health checker service that the state has changed.
    is_ready_signal: Arc<sync::Notify>,
    /// Driver signal for the health checker task
    /// Once received, the task should proceed with an attempt to check the state.
    /// Usually comes from the API call, but can be triggered by the task itself.
    check_ready_signal: Arc<sync::Notify>,
    /// Cooperative cancellation; triggered by the checker's drop guard.
    cancel: cancel::CancellationToken,
    /// Wait for the node to join a cluster before checking readiness.
    wait_for_bootstrap: bool,
}
115
+
116
impl Task {
    /// Task entry point: run the readiness computation, restarting the body
    /// if it panics (the panic message is logged on each retry).
    pub async fn exec(mut self) {
        while let Err(err) = self.exec_catch_unwind().await {
            let message = common::panic::downcast_str(&err).unwrap_or("");
            let separator = if !message.is_empty() { ": " } else { "" };

            log::error!("HealthChecker task panicked, retrying{separator}{message}",);
        }
    }

    /// Run one attempt, converting panics into `Err` instead of unwinding.
    async fn exec_catch_unwind(&mut self) -> thread::Result<()> {
        panic::AssertUnwindSafe(self.exec_cancel())
            .catch_unwind()
            .await
    }

    /// Run the readiness check until done, or until the owning checker is dropped.
    async fn exec_cancel(&mut self) {
        let _ = cancel::future::cancel_on_token(self.cancel.clone(), self.exec_impl()).await;
    }

    /// Main readiness state machine, driven by `/readyz` calls via
    /// `check_ready_signal`. Stages: (1) optionally wait for bootstrap,
    /// (2) catch up with the cluster commit index, (3) wait for local
    /// shards to become healthy, then flip the readiness flag.
    async fn exec_impl(&mut self) {
        // Wait until node joins cluster for the first time
        //
        // If this is a new deployment and `--bootstrap` CLI parameter was specified...
        if self.wait_for_bootstrap {
            // Check if this is the only node in the cluster
            while self.consensus_state.peer_count() <= 1 {
                // If cluster is empty, make another attempt to check
                // after we receive another call to `/readyz`
                //
                // Wait for `/readyz` signal
                self.check_ready_signal.notified().await;
            }
        }

        // Artificial simulate signal from `/readyz` endpoint
        // as if it was already called by the user.
        // This allows to check the happy path without waiting for the first call.
        self.check_ready_signal.notify_one();

        // Get *cluster* commit index, or check if this is the only node in the cluster
        let Some(cluster_commit_index) = self.cluster_commit_index().await else {
            self.set_ready();
            return;
        };

        // Check if *local* commit index >= *cluster* commit index...
        while self.commit_index() < cluster_commit_index {
            // Wait for `/readyz` signal
            self.check_ready_signal.notified().await;

            // If not:
            //
            // - Check if this is the only node in the cluster
            if self.consensus_state.peer_count() <= 1 {
                self.set_ready();
                return;
            }

            // TODO: Do we want to update `cluster_commit_index` here?
            //
            // I.e.:
            // - If we *don't* update `cluster_commit_index`, then we will only wait till the node
            //   catch up with the cluster commit index *at the moment the node has been started*
            // - If we *do* update `cluster_commit_index`, then we will keep track of cluster
            //   commit index updates and wait till the node *completely* catch up with the leader,
            //   which might be hard (if not impossible) in some situations
        }

        // Collect "unhealthy" shards list
        let mut unhealthy_shards = self.unhealthy_shards().await;

        // Check if all shards are "healthy"...
        while !unhealthy_shards.is_empty() {
            // If not:
            //
            // - Wait for `/readyz` signal
            self.check_ready_signal.notified().await;

            // - Refresh "unhealthy" shards list
            let current_unhealthy_shards = self.unhealthy_shards().await;

            // - Check if any shards "healed" since last check
            unhealthy_shards.retain(|shard| current_unhealthy_shards.contains(shard));
        }

        self.set_ready();
    }

    /// Compute the cluster-wide commit index by querying other peers.
    /// Returns `None` when this is (or becomes) a single-node cluster.
    async fn cluster_commit_index(&self) -> Option<u64> {
        // Wait for `/readyz` signal
        self.check_ready_signal.notified().await;

        // Check if there is only 1 node in the cluster
        if self.consensus_state.peer_count() <= 1 {
            return None;
        }

        // Get *cluster* commit index
        let peer_address_by_id = self.consensus_state.peer_address_by_id();
        let transport_channel_pool = &self.toc.get_channel_service().channel_pool;
        let this_peer_id = self.toc.this_peer_id;
        let this_peer_uri = peer_address_by_id.get(&this_peer_id);

        let mut requests = peer_address_by_id
            .values()
            // Do not get the current commit from ourselves
            .filter(|&uri| Some(uri) != this_peer_uri)
            // Historic peers might use the same URLs as our current peers, request each URI once
            .unique()
            .map(|uri| get_consensus_commit(transport_channel_pool, uri))
            .collect::<FuturesUnordered<_>>()
            .inspect_err(|err| log::error!("GetConsensusCommit request failed: {err}"))
            .filter_map(|res| future::ready(res.ok()));

        // Raft commits consensus operation, after majority of nodes persisted it.
        //
        // This means, if we check the majority of nodes (e.g., `total nodes / 2 + 1`), at least one
        // of these nodes will *always* have an up-to-date commit index. And so, the highest commit
        // index among majority of nodes *is* the cluster commit index.
        //
        // Our current node *is* one of the cluster nodes, so it's enough to query `total nodes / 2`
        // *additional* nodes, to get cluster commit index.
        //
        // The check goes like this:
        // - Either at least one of the "additional" nodes return a *higher* commit index, which
        //   means our node is *not* up-to-date, and we have to wait to reach this commit index
        // - Or *all* of them return *lower* commit index, which means current node is *already*
        //   up-to-date, and `/readyz` check will pass to the next step
        //
        // Example:
        //
        // Total nodes: 2
        // Required: 2 / 2 = 1
        //
        // Total nodes: 3
        // Required: 3 / 2 = 1
        //
        // Total nodes: 4
        // Required: 4 / 2 = 2
        //
        // Total nodes: 5
        // Required: 5 / 2 = 2
        let sufficient_commit_indices_count = peer_address_by_id.len() / 2;

        // *Wait* for `total nodex / 2` successful responses...
        let mut commit_indices: Vec<_> = (&mut requests)
            .take(sufficient_commit_indices_count)
            .collect()
            .await;

        // ...and also collect any additional responses, that we might have *already* received
        while let Ok(Some(resp)) = time::timeout(Duration::ZERO, requests.next()).await {
            commit_indices.push(resp);
        }

        // Find the maximum commit index among all responses.
        //
        // Note, that we progress even if most (or even *all*) requests failed (e.g., because all
        // other nodes are unavailable or they don't support `GetConsensusCommit` gRPC API).
        //
        // So this check is not 100% reliable and can give a false-positive result!
        let cluster_commit_index = commit_indices
            .into_iter()
            .map(|resp| resp.into_inner().commit)
            .max()
            .unwrap_or(0);

        Some(cluster_commit_index as _)
    }

    /// Last applied consensus entry on *this* node (0 when none).
    fn commit_index(&self) -> u64 {
        // TODO: Blocking call in async context!?
        self.consensus_state
            .persistent
            .read()
            .last_applied_entry()
            .unwrap_or(0)
    }

    /// List shards that are unhealthy, which may undergo automatic recovery.
    ///
    /// Shards in resharding state are not considered unhealthy and are excluded here.
    /// They require an external driver to make them active or to drop them.
    async fn unhealthy_shards(&self) -> HashSet<Shard> {
        let this_peer_id = self.toc.this_peer_id;
        let collections = self
            .toc
            .all_collections(&Access::full("For health check"))
            .await;

        let mut unhealthy_shards = HashSet::new();

        for collection_pass in &collections {
            // Collection may have been dropped concurrently; just skip it.
            let state = match self.toc.get_collection(collection_pass).await {
                Ok(collection) => collection.state().await,
                Err(_) => continue,
            };

            for (&shard, info) in state.shards.iter() {
                // Only replicas hosted on this peer are relevant for local readiness.
                let Some(state) = info.replicas.get(&this_peer_id) else {
                    continue;
                };

                if state.is_active_or_listener_or_resharding() {
                    continue;
                }

                unhealthy_shards.insert(Shard::new(collection_pass.name(), shard));
            }
        }

        unhealthy_shards
    }

    /// Flip the readiness flag and wake every waiting `/readyz` call.
    fn set_ready(&self) {
        self.is_ready.store(true, atomic::Ordering::Relaxed);
        self.is_ready_signal.notify_waiters();
    }
}
336
+
337
/// Request the current consensus commit index from the peer at `uri`,
/// with a per-attempt timeout and a bounded number of retries.
fn get_consensus_commit<'a>(
    transport_channel_pool: &'a TransportChannelPool,
    uri: &'a tonic::transport::Uri,
) -> impl Future<Output = GetConsensusCommitResult> + 'a {
    transport_channel_pool.with_channel_timeout(
        uri,
        |channel| async {
            let mut client = QdrantInternalClient::new(channel);
            let mut request = tonic::Request::new(GetConsensusCommitRequest {});
            // Also propagate the deadline to the server via the gRPC timeout metadata.
            request.set_timeout(defaults::CONSENSUS_META_OP_WAIT);
            client.get_consensus_commit(request).await
        },
        Some(defaults::CONSENSUS_META_OP_WAIT),
        GET_CONSENSUS_COMMITS_RETRIES,
    )
}

/// Outcome of a single `GetConsensusCommit` request to a peer.
type GetConsensusCommitResult = Result<
    tonic::Response<GetConsensusCommitResponse>,
    transport_channel_pool::RequestError<tonic::Status>,
>;
358
+
359
+ #[derive(Clone, Debug, Eq, PartialEq, Hash)]
360
+ struct Shard {
361
+ collection: CollectionId,
362
+ shard: ShardId,
363
+ }
364
+
365
+ impl Shard {
366
+ pub fn new(collection: impl Into<CollectionId>, shard: ShardId) -> Self {
367
+ Self {
368
+ collection: collection.into(),
369
+ shard,
370
+ }
371
+ }
372
+ }
src/common/helpers.rs ADDED
@@ -0,0 +1,151 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ use std::cmp::max;
2
+ use std::sync::atomic::{AtomicUsize, Ordering};
3
+ use std::{fs, io};
4
+
5
+ use schemars::JsonSchema;
6
+ use serde::{Deserialize, Serialize};
7
+ use tokio::runtime;
8
+ use tokio::runtime::Runtime;
9
+ use tonic::transport::{Certificate, ClientTlsConfig, Identity, ServerTlsConfig};
10
+ use validator::Validate;
11
+
12
+ use crate::settings::{Settings, TlsConfig};
13
+
14
// Payload of the storage-lock API: an optional human-readable reason plus a
// flag whether write operations are locked.
// NOTE: plain `//` comments on purpose — `///` doc comments would be picked up
// by schemars and change the generated JSON schema description.
#[derive(Debug, Deserialize, Serialize, JsonSchema, Validate)]
pub struct LocksOption {
    // Optional message explaining why the lock was set.
    pub error_message: Option<String>,
    // When `true`, write operations are forbidden.
    pub write: bool,
}
19
+
20
+ pub fn create_search_runtime(max_search_threads: usize) -> io::Result<Runtime> {
21
+ let mut search_threads = max_search_threads;
22
+
23
+ if search_threads == 0 {
24
+ let num_cpu = common::cpu::get_num_cpus();
25
+ // At least one thread, but not more than number of CPUs - 1 if there are more than 2 CPU
26
+ // Example:
27
+ // Num CPU = 1 -> 1 thread
28
+ // Num CPU = 2 -> 2 thread - if we use one thread with 2 cpus, its too much un-utilized resources
29
+ // Num CPU = 3 -> 2 thread
30
+ // Num CPU = 4 -> 3 thread
31
+ // Num CPU = 5 -> 4 thread
32
+ search_threads = match num_cpu {
33
+ 0 => 1,
34
+ 1 => 1,
35
+ 2 => 2,
36
+ _ => num_cpu - 1,
37
+ };
38
+ }
39
+
40
+ runtime::Builder::new_multi_thread()
41
+ .worker_threads(search_threads)
42
+ .max_blocking_threads(search_threads)
43
+ .enable_all()
44
+ .thread_name_fn(|| {
45
+ static ATOMIC_ID: AtomicUsize = AtomicUsize::new(0);
46
+ let id = ATOMIC_ID.fetch_add(1, Ordering::SeqCst);
47
+ format!("search-{id}")
48
+ })
49
+ .build()
50
+ }
51
+
52
+ pub fn create_update_runtime(max_optimization_threads: usize) -> io::Result<Runtime> {
53
+ let mut update_runtime_builder = runtime::Builder::new_multi_thread();
54
+
55
+ update_runtime_builder
56
+ .enable_time()
57
+ .thread_name_fn(move || {
58
+ static ATOMIC_ID: AtomicUsize = AtomicUsize::new(0);
59
+ let update_id = ATOMIC_ID.fetch_add(1, Ordering::SeqCst);
60
+ format!("update-{update_id}")
61
+ });
62
+
63
+ if max_optimization_threads > 0 {
64
+ // panics if val is not larger than 0.
65
+ update_runtime_builder.max_blocking_threads(max_optimization_threads);
66
+ }
67
+ update_runtime_builder.build()
68
+ }
69
+
70
+ pub fn create_general_purpose_runtime() -> io::Result<Runtime> {
71
+ runtime::Builder::new_multi_thread()
72
+ .enable_time()
73
+ .enable_io()
74
+ .worker_threads(max(common::cpu::get_num_cpus(), 2))
75
+ .thread_name_fn(|| {
76
+ static ATOMIC_ID: AtomicUsize = AtomicUsize::new(0);
77
+ let general_id = ATOMIC_ID.fetch_add(1, Ordering::SeqCst);
78
+ format!("general-{general_id}")
79
+ })
80
+ .build()
81
+ }
82
+
83
+ /// Load client TLS configuration.
84
+ pub fn load_tls_client_config(settings: &Settings) -> io::Result<Option<ClientTlsConfig>> {
85
+ if settings.cluster.p2p.enable_tls {
86
+ let tls_config = &settings.tls()?;
87
+ Ok(Some(
88
+ ClientTlsConfig::new()
89
+ .identity(load_identity(tls_config)?)
90
+ .ca_certificate(load_ca_certificate(tls_config)?),
91
+ ))
92
+ } else {
93
+ Ok(None)
94
+ }
95
+ }
96
+
97
+ /// Load server TLS configuration for external gRPC
98
+ pub fn load_tls_external_server_config(tls_config: &TlsConfig) -> io::Result<ServerTlsConfig> {
99
+ Ok(ServerTlsConfig::new().identity(load_identity(tls_config)?))
100
+ }
101
+
102
+ /// Load server TLS configuration for internal gRPC, check client certificate against CA
103
+ pub fn load_tls_internal_server_config(tls_config: &TlsConfig) -> io::Result<ServerTlsConfig> {
104
+ Ok(ServerTlsConfig::new()
105
+ .identity(load_identity(tls_config)?)
106
+ .client_ca_root(load_ca_certificate(tls_config)?))
107
+ }
108
+
109
+ fn load_identity(tls_config: &TlsConfig) -> io::Result<Identity> {
110
+ let cert = fs::read_to_string(&tls_config.cert)?;
111
+ let key = fs::read_to_string(&tls_config.key)?;
112
+ Ok(Identity::from_pem(cert, key))
113
+ }
114
+
115
+ fn load_ca_certificate(tls_config: &TlsConfig) -> io::Result<Certificate> {
116
+ let pem = fs::read_to_string(&tls_config.ca_cert)?;
117
+ Ok(Certificate::from_pem(pem))
118
+ }
119
+
120
/// Wrap a tonic transport error into an `io::Error` (kind `Other`),
/// preserving the original error as the source.
pub fn tonic_error_to_io_error(err: tonic::transport::Error) -> io::Error {
    io::Error::new(io::ErrorKind::Other, err)
}
123
+
124
#[cfg(test)]
mod tests {
    use std::sync::Arc;
    use std::thread;
    use std::thread::sleep;
    use std::time::Duration;

    use collection::common::is_ready::IsReady;

    /// `IsReady::await_ready` must block until `make_ready` is called from
    /// another thread, then return and let the waiter finish.
    #[test]
    fn test_is_ready() {
        let is_ready = Arc::new(IsReady::default());
        let is_ready_clone = is_ready.clone();
        let join = thread::spawn(move || {
            is_ready_clone.await_ready();
            eprintln!(
                "is_ready_clone.check_ready() = {:#?}",
                is_ready_clone.check_ready()
            );
        });

        // Give the spawned thread time to block inside `await_ready`.
        sleep(Duration::from_millis(500));
        eprintln!("Making ready");
        is_ready.make_ready();
        // Allow the waiter to observe the flag before joining.
        sleep(Duration::from_millis(500));
        join.join().unwrap()
    }
}
src/common/http_client.rs ADDED
@@ -0,0 +1,156 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ use std::path::Path;
2
+ use std::{fs, io, result};
3
+
4
+ use reqwest::header::{HeaderMap, HeaderValue, InvalidHeaderValue};
5
+ use storage::content_manager::errors::StorageError;
6
+
7
+ use super::auth::HTTP_HEADER_API_KEY;
8
+ use crate::settings::{Settings, TlsConfig};
9
+
10
/// Factory for HTTP(S) clients, pre-configured from the service settings.
#[derive(Clone)]
pub struct HttpClient {
    // `None` when TLS is disabled for the service (`service.enable_tls == false`).
    tls_config: Option<TlsConfig>,
    // When `true`, built clients present a certificate for mutual TLS.
    verify_https_client_certificate: bool,
}
15
+
16
+ impl HttpClient {
17
+ pub fn from_settings(settings: &Settings) -> Result<Self> {
18
+ let tls_config = if settings.service.enable_tls {
19
+ let Some(tls_config) = settings.tls.clone() else {
20
+ return Err(Error::TlsConfigUndefined);
21
+ };
22
+
23
+ Some(tls_config)
24
+ } else {
25
+ None
26
+ };
27
+
28
+ let verify_https_client_certificate = settings.service.verify_https_client_certificate;
29
+
30
+ let http_client = Self {
31
+ tls_config,
32
+ verify_https_client_certificate,
33
+ };
34
+
35
+ Ok(http_client)
36
+ }
37
+
38
+ /// Create a new HTTP(S) client
39
+ ///
40
+ /// An API key can be optionally provided to be used in this HTTP client. It'll send the API
41
+ /// key as `Api-key` header in every request.
42
+ ///
43
+ /// # Warning
44
+ ///
45
+ /// Setting an API key may leak when the client is used to send a request to a malicious
46
+ /// server. This is potentially dangerous if a user has control over what URL is accessed.
47
+ ///
48
+ /// For this reason the API key is not set by default as provided in the configuration. It must
49
+ /// be explicitly provided when creating the HTTP client.
50
+ pub fn client(&self, api_key: Option<&str>) -> Result<reqwest::Client> {
51
+ https_client(
52
+ api_key,
53
+ self.tls_config.as_ref(),
54
+ self.verify_https_client_certificate,
55
+ )
56
+ }
57
+ }
58
+
59
+ fn https_client(
60
+ api_key: Option<&str>,
61
+ tls_config: Option<&TlsConfig>,
62
+ verify_https_client_certificate: bool,
63
+ ) -> Result<reqwest::Client> {
64
+ let mut builder = reqwest::Client::builder();
65
+
66
+ // Configure TLS root certificate and validation
67
+ if let Some(tls_config) = tls_config {
68
+ builder = builder.add_root_certificate(https_client_ca_cert(tls_config.ca_cert.as_ref())?);
69
+
70
+ if verify_https_client_certificate {
71
+ builder = builder.identity(https_client_identity(
72
+ tls_config.cert.as_ref(),
73
+ tls_config.key.as_ref(),
74
+ )?);
75
+ }
76
+ }
77
+
78
+ // Attach API key as sensitive header
79
+ if let Some(api_key) = api_key {
80
+ let mut headers = HeaderMap::new();
81
+ let mut api_key_value = HeaderValue::from_str(api_key).map_err(Error::MalformedApiKey)?;
82
+ api_key_value.set_sensitive(true);
83
+ headers.insert(HTTP_HEADER_API_KEY, api_key_value);
84
+ builder = builder.default_headers(headers);
85
+ }
86
+
87
+ let client = builder.build()?;
88
+
89
+ Ok(client)
90
+ }
91
+
92
+ fn https_client_ca_cert(ca_cert: &Path) -> Result<reqwest::tls::Certificate> {
93
+ let ca_cert_pem =
94
+ fs::read(ca_cert).map_err(|err| Error::failed_to_read(err, "CA certificate", ca_cert))?;
95
+
96
+ let ca_cert = reqwest::Certificate::from_pem(&ca_cert_pem)?;
97
+
98
+ Ok(ca_cert)
99
+ }
100
+
101
+ fn https_client_identity(cert: &Path, key: &Path) -> Result<reqwest::tls::Identity> {
102
+ let mut identity_pem =
103
+ fs::read(cert).map_err(|err| Error::failed_to_read(err, "certificate", cert))?;
104
+
105
+ let mut key_file = fs::File::open(key).map_err(|err| Error::failed_to_read(err, "key", key))?;
106
+
107
+ // Concatenate certificate and key into a single PEM bytes
108
+ io::copy(&mut key_file, &mut identity_pem)
109
+ .map_err(|err| Error::failed_to_read(err, "key", key))?;
110
+
111
+ let identity = reqwest::Identity::from_pem(&identity_pem)?;
112
+
113
+ Ok(identity)
114
+ }
115
+
116
/// Module-local result alias; errors default to [`Error`].
pub type Result<T, E = Error> = result::Result<T, E>;

/// Errors that can occur while building an HTTP(S) client.
#[derive(Debug, thiserror::Error)]
pub enum Error {
    #[error("TLS config is not defined in the Qdrant config file")]
    TlsConfigUndefined,

    // The `String` carries the human-readable context for the I/O failure.
    #[error("{1}: {0}")]
    Io(#[source] io::Error, String),

    #[error("failed to setup HTTPS client: {0}")]
    Reqwest(#[from] reqwest::Error),

    #[error("malformed API key")]
    MalformedApiKey(#[source] InvalidHeaderValue),
}
132
+
133
+ impl Error {
134
+ pub fn io(source: io::Error, context: impl Into<String>) -> Self {
135
+ Self::Io(source, context.into())
136
+ }
137
+
138
+ pub fn failed_to_read(source: io::Error, file: &str, path: &Path) -> Self {
139
+ Self::io(
140
+ source,
141
+ format!("failed to read HTTPS client {file} file {}", path.display()),
142
+ )
143
+ }
144
+ }
145
+
146
/// Surface client-construction failures as storage service errors.
impl From<Error> for StorageError {
    fn from(err: Error) -> Self {
        StorageError::service_error(format!("failed to initialize HTTP(S) client: {err}"))
    }
}
151
+
152
/// Allow `?` on client-construction failures in `io::Result` contexts;
/// the original error is preserved as the source.
impl From<Error> for io::Error {
    fn from(err: Error) -> Self {
        io::Error::new(io::ErrorKind::Other, err)
    }
}
src/common/inference/batch_processing.rs ADDED
@@ -0,0 +1,370 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ use std::collections::HashSet;
2
+
3
+ use api::rest::{
4
+ ContextInput, ContextPair, DiscoverInput, Prefetch, Query, QueryGroupsRequestInternal,
5
+ QueryInterface, QueryRequestInternal, RecommendInput, VectorInput,
6
+ };
7
+
8
+ use super::service::{InferenceData, InferenceInput, InferenceRequest};
9
+
10
/// Accumulator for inference inputs collected from a REST request,
/// deduplicated via a `HashSet`.
pub struct BatchAccum {
    // Set semantics: identical inputs (same content and model) are stored once.
    pub(crate) objects: HashSet<InferenceData>,
}
13
+
14
+ impl BatchAccum {
15
+ pub fn new() -> Self {
16
+ Self {
17
+ objects: HashSet::new(),
18
+ }
19
+ }
20
+
21
+ pub fn add(&mut self, data: InferenceData) {
22
+ self.objects.insert(data);
23
+ }
24
+
25
+ pub fn extend(&mut self, other: BatchAccum) {
26
+ self.objects.extend(other.objects);
27
+ }
28
+
29
+ pub fn is_empty(&self) -> bool {
30
+ self.objects.is_empty()
31
+ }
32
+ }
33
+
34
/// Convert the accumulated batch into a single inference request;
/// `inference` mode and auth `token` are left unset for the caller to fill.
impl From<&BatchAccum> for InferenceRequest {
    fn from(batch: &BatchAccum) -> Self {
        Self {
            inputs: batch
                .objects
                .iter()
                .cloned()
                .map(InferenceInput::from)
                .collect(),
            inference: None,
            token: None,
        }
    }
}
48
+
49
+ fn collect_vector_input(vector: &VectorInput, batch: &mut BatchAccum) {
50
+ match vector {
51
+ VectorInput::Document(doc) => batch.add(InferenceData::Document(doc.clone())),
52
+ VectorInput::Image(img) => batch.add(InferenceData::Image(img.clone())),
53
+ VectorInput::Object(obj) => batch.add(InferenceData::Object(obj.clone())),
54
+ // types that are not supported in the Inference Service
55
+ VectorInput::DenseVector(_) => {}
56
+ VectorInput::SparseVector(_) => {}
57
+ VectorInput::MultiDenseVector(_) => {}
58
+ VectorInput::Id(_) => {}
59
+ }
60
+ }
61
+
62
+ fn collect_context_pair(pair: &ContextPair, batch: &mut BatchAccum) {
63
+ collect_vector_input(&pair.positive, batch);
64
+ collect_vector_input(&pair.negative, batch);
65
+ }
66
+
67
+ fn collect_discover_input(discover: &DiscoverInput, batch: &mut BatchAccum) {
68
+ collect_vector_input(&discover.target, batch);
69
+ if let Some(context) = &discover.context {
70
+ for pair in context {
71
+ collect_context_pair(pair, batch);
72
+ }
73
+ }
74
+ }
75
+
76
+ fn collect_recommend_input(recommend: &RecommendInput, batch: &mut BatchAccum) {
77
+ if let Some(positive) = &recommend.positive {
78
+ for vector in positive {
79
+ collect_vector_input(vector, batch);
80
+ }
81
+ }
82
+ if let Some(negative) = &recommend.negative {
83
+ for vector in negative {
84
+ collect_vector_input(vector, batch);
85
+ }
86
+ }
87
+ }
88
+
89
/// Walk one query variant and collect every inference-requiring input.
fn collect_query(query: &Query, batch: &mut BatchAccum) {
    match query {
        Query::Nearest(nearest) => collect_vector_input(&nearest.nearest, batch),
        Query::Recommend(recommend) => collect_recommend_input(&recommend.recommend, batch),
        Query::Discover(discover) => collect_discover_input(&discover.discover, batch),
        Query::Context(context) => {
            if let ContextInput(Some(pairs)) = &context.context {
                for pair in pairs {
                    collect_context_pair(pair, batch);
                }
            }
        }
        // These variants carry no vector inputs, so there is nothing to infer.
        Query::OrderBy(_) | Query::Fusion(_) | Query::Sample(_) => {}
    }
}
104
+
105
/// Dispatch over the two query-interface shapes: a bare nearest vector
/// or a structured query.
fn collect_query_interface(query: &QueryInterface, batch: &mut BatchAccum) {
    match query {
        QueryInterface::Nearest(vector) => collect_vector_input(vector, batch),
        QueryInterface::Query(query) => collect_query(query, batch),
    }
}
111
+
112
+ fn collect_prefetch(prefetch: &Prefetch, batch: &mut BatchAccum) {
113
+ let Prefetch {
114
+ prefetch,
115
+ query,
116
+ using: _,
117
+ filter: _,
118
+ params: _,
119
+ score_threshold: _,
120
+ limit: _,
121
+ lookup_from: _,
122
+ } = prefetch;
123
+
124
+ if let Some(query) = query {
125
+ collect_query_interface(query, batch);
126
+ }
127
+
128
+ if let Some(prefetches) = prefetch {
129
+ for p in prefetches {
130
+ collect_prefetch(p, batch);
131
+ }
132
+ }
133
+ }
134
+
135
+ pub fn collect_query_groups_request(request: &QueryGroupsRequestInternal) -> BatchAccum {
136
+ let mut batch = BatchAccum::new();
137
+
138
+ let QueryGroupsRequestInternal {
139
+ query,
140
+ prefetch,
141
+ using: _,
142
+ filter: _,
143
+ params: _,
144
+ score_threshold: _,
145
+ with_vector: _,
146
+ with_payload: _,
147
+ lookup_from: _,
148
+ group_request: _,
149
+ } = request;
150
+
151
+ if let Some(query) = query {
152
+ collect_query_interface(query, &mut batch);
153
+ }
154
+
155
+ if let Some(prefetches) = prefetch {
156
+ for prefetch in prefetches {
157
+ collect_prefetch(prefetch, &mut batch);
158
+ }
159
+ }
160
+
161
+ batch
162
+ }
163
+
164
+ pub fn collect_query_request(request: &QueryRequestInternal) -> BatchAccum {
165
+ let mut batch = BatchAccum::new();
166
+
167
+ let QueryRequestInternal {
168
+ prefetch,
169
+ query,
170
+ using: _,
171
+ filter: _,
172
+ score_threshold: _,
173
+ params: _,
174
+ limit: _,
175
+ offset: _,
176
+ with_vector: _,
177
+ with_payload: _,
178
+ lookup_from: _,
179
+ } = request;
180
+
181
+ if let Some(query) = query {
182
+ collect_query_interface(query, &mut batch);
183
+ }
184
+
185
+ if let Some(prefetches) = prefetch {
186
+ for prefetch in prefetches {
187
+ collect_prefetch(prefetch, &mut batch);
188
+ }
189
+ }
190
+
191
+ batch
192
+ }
193
+
194
#[cfg(test)]
mod tests {
    use api::rest::schema::{DiscoverQuery, Document, Image, InferenceObject, NearestQuery};
    use api::rest::QueryBaseGroupRequest;
    use serde_json::json;

    use super::*;

    // Helper: minimal document input with a fixed test model.
    fn create_test_document(text: &str) -> Document {
        Document {
            text: text.to_string(),
            model: "test-model".to_string(),
            options: Default::default(),
        }
    }

    // Helper: minimal image input with a fixed test model.
    fn create_test_image(url: &str) -> Image {
        Image {
            image: json!({"data": url.to_string()}),
            model: "test-model".to_string(),
            options: Default::default(),
        }
    }

    // Helper: minimal raw inference object with a fixed test model.
    fn create_test_object(data: &str) -> InferenceObject {
        InferenceObject {
            object: json!({"data": data}),
            model: "test-model".to_string(),
            options: Default::default(),
        }
    }

    /// Adding the same item twice must not grow the set.
    #[test]
    fn test_batch_accum_basic() {
        let mut batch = BatchAccum::new();
        assert!(batch.objects.is_empty());

        let doc = InferenceData::Document(create_test_document("test"));
        batch.add(doc.clone());
        assert_eq!(batch.objects.len(), 1);

        batch.add(doc);
        assert_eq!(batch.objects.len(), 1);
    }

    /// `extend` merges two accumulators (set union).
    #[test]
    fn test_batch_accum_extend() {
        let mut batch1 = BatchAccum::new();
        let mut batch2 = BatchAccum::new();

        let doc1 = InferenceData::Document(create_test_document("test1"));
        let doc2 = InferenceData::Document(create_test_document("test2"));

        batch1.add(doc1);
        batch2.add(doc2);

        batch1.extend(batch2);
        assert_eq!(batch1.objects.len(), 2);
    }

    /// Two equal-by-value inputs are stored only once.
    #[test]
    fn test_deduplication() {
        let mut batch = BatchAccum::new();

        let doc1 = InferenceData::Document(create_test_document("same"));
        let doc2 = InferenceData::Document(create_test_document("same"));

        batch.add(doc1);
        batch.add(doc2);

        assert_eq!(batch.objects.len(), 1);
    }

    /// Documents, images and objects are all collected from vector inputs.
    #[test]
    fn test_collect_vector_input() {
        let mut batch = BatchAccum::new();

        let doc_input = VectorInput::Document(create_test_document("test"));
        let img_input = VectorInput::Image(create_test_image("test.jpg"));
        let obj_input = VectorInput::Object(create_test_object("test"));

        collect_vector_input(&doc_input, &mut batch);
        collect_vector_input(&img_input, &mut batch);
        collect_vector_input(&obj_input, &mut batch);

        assert_eq!(batch.objects.len(), 3);
    }

    /// Nested prefetches are traversed recursively (outer doc + inner image).
    #[test]
    fn test_collect_prefetch() {
        let prefetch = Prefetch {
            query: Some(QueryInterface::Nearest(VectorInput::Document(
                create_test_document("test"),
            ))),
            prefetch: Some(vec![Prefetch {
                query: Some(QueryInterface::Nearest(VectorInput::Image(
                    create_test_image("nested.jpg"),
                ))),
                prefetch: None,
                using: None,
                filter: None,
                params: None,
                score_threshold: None,
                limit: None,
                lookup_from: None,
            }]),
            using: None,
            filter: None,
            params: None,
            score_threshold: None,
            limit: None,
            lookup_from: None,
        };

        let mut batch = BatchAccum::new();
        collect_prefetch(&prefetch, &mut batch);
        assert_eq!(batch.objects.len(), 2);
    }

    /// Full grouped request: top-level query + discover prefetch with one
    /// context pair yields 4 distinct inference inputs.
    #[test]
    fn test_collect_query_groups_request() {
        let request = QueryGroupsRequestInternal {
            query: Some(QueryInterface::Query(Query::Nearest(NearestQuery {
                nearest: VectorInput::Document(create_test_document("test")),
            }))),
            prefetch: Some(vec![Prefetch {
                query: Some(QueryInterface::Query(Query::Discover(DiscoverQuery {
                    discover: DiscoverInput {
                        target: VectorInput::Image(create_test_image("test.jpg")),
                        context: Some(vec![ContextPair {
                            positive: VectorInput::Document(create_test_document("pos")),
                            negative: VectorInput::Image(create_test_image("neg.jpg")),
                        }]),
                    },
                }))),
                prefetch: None,
                using: None,
                filter: None,
                params: None,
                score_threshold: None,
                limit: None,
                lookup_from: None,
            }]),
            using: None,
            filter: None,
            params: None,
            score_threshold: None,
            with_vector: None,
            with_payload: None,
            lookup_from: None,
            group_request: QueryBaseGroupRequest {
                group_by: "test".parse().unwrap(),
                group_size: None,
                limit: None,
                with_lookup: None,
            },
        };

        let batch = collect_query_groups_request(&request);
        assert_eq!(batch.objects.len(), 4);
    }

    /// Same content but different model names must be treated as distinct inputs.
    #[test]
    fn test_different_model_same_content() {
        let mut batch = BatchAccum::new();

        let mut doc1 = create_test_document("same");
        let mut doc2 = create_test_document("same");
        doc1.model = "model1".to_string();
        doc2.model = "model2".to_string();

        batch.add(InferenceData::Document(doc1));
        batch.add(InferenceData::Document(doc2));

        assert_eq!(batch.objects.len(), 2);
    }
}
src/common/inference/batch_processing_grpc.rs ADDED
@@ -0,0 +1,281 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ use std::collections::HashSet;
2
+
3
+ use api::grpc::qdrant::vector_input::Variant;
4
+ use api::grpc::qdrant::{
5
+ query, ContextInput, ContextInputPair, DiscoverInput, PrefetchQuery, Query, RecommendInput,
6
+ VectorInput,
7
+ };
8
+ use api::rest::schema as rest;
9
+ use tonic::Status;
10
+
11
+ use super::service::{InferenceData, InferenceInput, InferenceRequest};
12
+
13
/// Accumulator for inference payloads (documents, images, objects) gathered
/// from a gRPC request, deduplicated before being sent to the inference
/// service as one batch.
pub struct BatchAccumGrpc {
    // Set semantics: identical inputs (same content *and* model) collapse
    // into a single entry.
    pub(crate) objects: HashSet<InferenceData>,
}
16
+
17
+ impl BatchAccumGrpc {
18
+ pub fn new() -> Self {
19
+ Self {
20
+ objects: HashSet::new(),
21
+ }
22
+ }
23
+
24
+ pub fn add(&mut self, data: InferenceData) {
25
+ self.objects.insert(data);
26
+ }
27
+
28
+ pub fn extend(&mut self, other: BatchAccumGrpc) {
29
+ self.objects.extend(other.objects);
30
+ }
31
+
32
+ pub fn is_empty(&self) -> bool {
33
+ self.objects.is_empty()
34
+ }
35
+ }
36
+
37
+ impl From<&BatchAccumGrpc> for InferenceRequest {
38
+ fn from(batch: &BatchAccumGrpc) -> Self {
39
+ Self {
40
+ inputs: batch
41
+ .objects
42
+ .iter()
43
+ .cloned()
44
+ .map(InferenceInput::from)
45
+ .collect(),
46
+ inference: None,
47
+ token: None,
48
+ }
49
+ }
50
+ }
51
+
52
+ fn collect_vector_input(vector: &VectorInput, batch: &mut BatchAccumGrpc) -> Result<(), Status> {
53
+ let Some(variant) = &vector.variant else {
54
+ return Ok(());
55
+ };
56
+
57
+ match variant {
58
+ Variant::Id(_) => {}
59
+ Variant::Dense(_) => {}
60
+ Variant::Sparse(_) => {}
61
+ Variant::MultiDense(_) => {}
62
+ Variant::Document(document) => {
63
+ let doc = rest::Document::try_from(document.clone())
64
+ .map_err(|e| Status::internal(format!("Document conversion error: {e:?}")))?;
65
+ batch.add(InferenceData::Document(doc));
66
+ }
67
+ Variant::Image(image) => {
68
+ let img = rest::Image::try_from(image.clone())
69
+ .map_err(|e| Status::internal(format!("Image conversion error: {e:?}")))?;
70
+ batch.add(InferenceData::Image(img));
71
+ }
72
+ Variant::Object(object) => {
73
+ let obj = rest::InferenceObject::try_from(object.clone())
74
+ .map_err(|e| Status::internal(format!("Object conversion error: {e:?}")))?;
75
+ batch.add(InferenceData::Object(obj));
76
+ }
77
+ }
78
+ Ok(())
79
+ }
80
+
81
+ pub(crate) fn collect_context_input(
82
+ context: &ContextInput,
83
+ batch: &mut BatchAccumGrpc,
84
+ ) -> Result<(), Status> {
85
+ let ContextInput { pairs } = context;
86
+
87
+ for pair in pairs {
88
+ collect_context_input_pair(pair, batch)?;
89
+ }
90
+
91
+ Ok(())
92
+ }
93
+
94
+ fn collect_context_input_pair(
95
+ pair: &ContextInputPair,
96
+ batch: &mut BatchAccumGrpc,
97
+ ) -> Result<(), Status> {
98
+ let ContextInputPair { positive, negative } = pair;
99
+
100
+ if let Some(positive) = positive {
101
+ collect_vector_input(positive, batch)?;
102
+ }
103
+
104
+ if let Some(negative) = negative {
105
+ collect_vector_input(negative, batch)?;
106
+ }
107
+
108
+ Ok(())
109
+ }
110
+
111
+ pub(crate) fn collect_discover_input(
112
+ discover: &DiscoverInput,
113
+ batch: &mut BatchAccumGrpc,
114
+ ) -> Result<(), Status> {
115
+ let DiscoverInput { target, context } = discover;
116
+
117
+ if let Some(vector) = target {
118
+ collect_vector_input(vector, batch)?;
119
+ }
120
+
121
+ if let Some(context) = context {
122
+ for pair in &context.pairs {
123
+ collect_context_input_pair(pair, batch)?;
124
+ }
125
+ }
126
+
127
+ Ok(())
128
+ }
129
+
130
+ pub(crate) fn collect_recommend_input(
131
+ recommend: &RecommendInput,
132
+ batch: &mut BatchAccumGrpc,
133
+ ) -> Result<(), Status> {
134
+ let RecommendInput {
135
+ positive,
136
+ negative,
137
+ strategy: _,
138
+ } = recommend;
139
+
140
+ for vector in positive {
141
+ collect_vector_input(vector, batch)?;
142
+ }
143
+
144
+ for vector in negative {
145
+ collect_vector_input(vector, batch)?;
146
+ }
147
+
148
+ Ok(())
149
+ }
150
+
151
+ pub(crate) fn collect_query(query: &Query, batch: &mut BatchAccumGrpc) -> Result<(), Status> {
152
+ let Some(variant) = &query.variant else {
153
+ return Ok(());
154
+ };
155
+
156
+ match variant {
157
+ query::Variant::Nearest(nearest) => collect_vector_input(nearest, batch)?,
158
+ query::Variant::Recommend(recommend) => collect_recommend_input(recommend, batch)?,
159
+ query::Variant::Discover(discover) => collect_discover_input(discover, batch)?,
160
+ query::Variant::Context(context) => collect_context_input(context, batch)?,
161
+ query::Variant::OrderBy(_) => {}
162
+ query::Variant::Fusion(_) => {}
163
+ query::Variant::Sample(_) => {}
164
+ }
165
+
166
+ Ok(())
167
+ }
168
+
169
+ pub(crate) fn collect_prefetch(
170
+ prefetch: &PrefetchQuery,
171
+ batch: &mut BatchAccumGrpc,
172
+ ) -> Result<(), Status> {
173
+ let PrefetchQuery {
174
+ prefetch,
175
+ query,
176
+ using: _,
177
+ filter: _,
178
+ params: _,
179
+ score_threshold: _,
180
+ limit: _,
181
+ lookup_from: _,
182
+ } = prefetch;
183
+
184
+ if let Some(query) = query {
185
+ collect_query(query, batch)?;
186
+ }
187
+
188
+ for p in prefetch {
189
+ collect_prefetch(p, batch)?;
190
+ }
191
+
192
+ Ok(())
193
+ }
194
+
195
#[cfg(test)]
mod tests {
    use api::rest::schema::{Document, Image, InferenceObject};
    use serde_json::json;

    use super::*;

    // Fixture: a document payload with a fixed model name.
    fn create_test_document(text: &str) -> Document {
        Document {
            text: text.to_string(),
            model: "test-model".to_string(),
            options: Default::default(),
        }
    }

    // Fixture: an image payload with a fixed model name.
    fn create_test_image(url: &str) -> Image {
        Image {
            image: json!({"data": url.to_string()}),
            model: "test-model".to_string(),
            options: Default::default(),
        }
    }

    // Fixture: an arbitrary-object payload with a fixed model name.
    // NOTE(review): currently unused by the tests below — confirm whether it
    // should be exercised or removed.
    fn create_test_object(data: &str) -> InferenceObject {
        InferenceObject {
            object: json!({"data": data}),
            model: "test-model".to_string(),
            options: Default::default(),
        }
    }

    #[test]
    fn test_batch_accum_basic() {
        let mut batch = BatchAccumGrpc::new();
        assert!(batch.objects.is_empty());

        let doc = InferenceData::Document(create_test_document("test"));
        batch.add(doc.clone());
        assert_eq!(batch.objects.len(), 1);

        // Adding the same payload twice must not grow the set.
        batch.add(doc);
        assert_eq!(batch.objects.len(), 1);
    }

    #[test]
    fn test_batch_accum_extend() {
        let mut batch1 = BatchAccumGrpc::new();
        let mut batch2 = BatchAccumGrpc::new();

        let doc1 = InferenceData::Document(create_test_document("test1"));
        let doc2 = InferenceData::Document(create_test_document("test2"));

        batch1.add(doc1);
        batch2.add(doc2);

        // Distinct payloads from both batches survive the merge.
        batch1.extend(batch2);
        assert_eq!(batch1.objects.len(), 2);
    }

    #[test]
    fn test_deduplication() {
        let mut batch = BatchAccumGrpc::new();

        // Equal content + equal model => considered the same input.
        let doc1 = InferenceData::Document(create_test_document("same"));
        let doc2 = InferenceData::Document(create_test_document("same"));

        batch.add(doc1);
        batch.add(doc2);

        assert_eq!(batch.objects.len(), 1);
    }

    #[test]
    fn test_different_model_same_content() {
        let mut batch = BatchAccumGrpc::new();

        // Same text, different models: must stay distinct because the model
        // is part of the payload identity.
        let mut doc1 = create_test_document("same");
        let mut doc2 = create_test_document("same");
        doc1.model = "model1".to_string();
        doc2.model = "model2".to_string();

        batch.add(InferenceData::Document(doc1));
        batch.add(InferenceData::Document(doc2));

        assert_eq!(batch.objects.len(), 2);
    }
}
src/common/inference/config.rs ADDED
@@ -0,0 +1,23 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ use serde::{Deserialize, Serialize};
2
+
3
/// Configuration for the external inference service.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct InferenceConfig {
    /// Address of the inference service; `None` leaves inference unconfigured.
    pub address: Option<String>,
    /// Request timeout, defaults to 10 — presumably seconds; confirm against
    /// the HTTP client that consumes this value.
    #[serde(default = "default_inference_timeout")]
    pub timeout: u64,
    /// Optional auth token forwarded with inference requests.
    pub token: Option<String>,
}
10
+
11
/// Serde default for [`InferenceConfig::timeout`].
fn default_inference_timeout() -> u64 {
    const DEFAULT_TIMEOUT: u64 = 10;
    DEFAULT_TIMEOUT
}
14
+
15
+ impl InferenceConfig {
16
+ pub fn new(address: Option<String>) -> Self {
17
+ Self {
18
+ address,
19
+ timeout: default_inference_timeout(),
20
+ token: None,
21
+ }
22
+ }
23
+ }
src/common/inference/infer_processing.rs ADDED
@@ -0,0 +1,72 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ use std::collections::{HashMap, HashSet};
2
+
3
+ use collection::operations::point_ops::VectorPersisted;
4
+ use storage::content_manager::errors::StorageError;
5
+
6
+ use super::batch_processing::BatchAccum;
7
+ use super::service::{InferenceData, InferenceInput, InferenceService, InferenceType};
8
+
9
/// Result of running a batch of inference inputs through the inference
/// service: every original input mapped to its inferred vector.
pub struct BatchAccumInferred {
    // Keyed by the original input so callers can look vectors up by content.
    pub(crate) objects: HashMap<InferenceData, VectorPersisted>,
}
12
+
13
+ impl BatchAccumInferred {
14
+ pub fn new() -> Self {
15
+ Self {
16
+ objects: HashMap::new(),
17
+ }
18
+ }
19
+
20
+ pub async fn from_objects(
21
+ objects: HashSet<InferenceData>,
22
+ inference_type: InferenceType,
23
+ ) -> Result<Self, StorageError> {
24
+ if objects.is_empty() {
25
+ return Ok(Self::new());
26
+ }
27
+
28
+ let Some(service) = InferenceService::get_global() else {
29
+ return Err(StorageError::service_error(
30
+ "InferenceService is not initialized. Please check if it was properly configured and initialized during startup."
31
+ ));
32
+ };
33
+
34
+ service.validate()?;
35
+
36
+ let objects_serialized: Vec<_> = objects.into_iter().collect();
37
+ let inference_inputs: Vec<_> = objects_serialized
38
+ .iter()
39
+ .cloned()
40
+ .map(InferenceInput::from)
41
+ .collect();
42
+
43
+ let vectors = service
44
+ .infer(inference_inputs, inference_type)
45
+ .await
46
+ .map_err(|e| StorageError::service_error(
47
+ format!("Inference request failed. Check if inference service is running and properly configured: {e}")
48
+ ))?;
49
+
50
+ if vectors.is_empty() {
51
+ return Err(StorageError::service_error(
52
+ "Inference service returned no vectors. Check if models are properly loaded.",
53
+ ));
54
+ }
55
+
56
+ let objects = objects_serialized.into_iter().zip(vectors).collect();
57
+
58
+ Ok(Self { objects })
59
+ }
60
+
61
+ pub async fn from_batch_accum(
62
+ batch: BatchAccum,
63
+ inference_type: InferenceType,
64
+ ) -> Result<Self, StorageError> {
65
+ let BatchAccum { objects } = batch;
66
+ Self::from_objects(objects, inference_type).await
67
+ }
68
+
69
+ pub fn get_vector(&self, data: &InferenceData) -> Option<&VectorPersisted> {
70
+ self.objects.get(data)
71
+ }
72
+ }
src/common/inference/mod.rs ADDED
@@ -0,0 +1,8 @@
 
 
 
 
 
 
 
 
 
1
// Inference subsystem: resolves documents/images/objects embedded in
// requests into vectors via an external inference service.
mod batch_processing;
mod batch_processing_grpc;
pub(crate) mod config;
mod infer_processing;
pub mod query_requests_grpc;
pub mod query_requests_rest;
pub mod service;
pub mod update_requests;
src/common/inference/query_requests_grpc.rs ADDED
@@ -0,0 +1,535 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ use api::conversions::json::json_path_from_proto;
2
+ use api::grpc::qdrant as grpc;
3
+ use api::grpc::qdrant::query::Variant;
4
+ use api::grpc::qdrant::RecommendInput;
5
+ use api::rest;
6
+ use api::rest::RecommendStrategy;
7
+ use collection::operations::universal_query::collection_query::{
8
+ CollectionPrefetch, CollectionQueryGroupsRequest, CollectionQueryRequest, Query,
9
+ VectorInputInternal, VectorQuery,
10
+ };
11
+ use collection::operations::universal_query::shard_query::{FusionInternal, SampleInternal};
12
+ use segment::data_types::order_by::OrderBy;
13
+ use segment::data_types::vectors::{VectorInternal, DEFAULT_VECTOR_NAME};
14
+ use segment::vector_storage::query::{ContextPair, ContextQuery, DiscoveryQuery, RecoQuery};
15
+ use tonic::Status;
16
+
17
+ use crate::common::inference::batch_processing_grpc::{
18
+ collect_prefetch, collect_query, BatchAccumGrpc,
19
+ };
20
+ use crate::common::inference::infer_processing::BatchAccumInferred;
21
+ use crate::common::inference::service::{InferenceData, InferenceType};
22
+
23
/// ToDo: this function is supposed to call an inference endpoint internally
///
/// Converts a gRPC `QueryPointGroups` into the internal
/// `CollectionQueryGroupsRequest`, first resolving all document/image/object
/// payloads into vectors with one batched inference call.
pub async fn convert_query_point_groups_from_grpc(
    query: grpc::QueryPointGroups,
) -> Result<CollectionQueryGroupsRequest, Status> {
    // Exhaustive destructuring: a new proto field cannot be silently dropped.
    let grpc::QueryPointGroups {
        collection_name: _,
        prefetch,
        query,
        using,
        filter,
        params,
        score_threshold,
        with_payload,
        with_vectors,
        lookup_from,
        limit,
        group_size,
        group_by,
        with_lookup,
        read_consistency: _,
        timeout: _,
        shard_key_selector: _,
    } = query;

    // Pass 1: gather every inference payload from the query and all nested
    // prefetches into a deduplicated batch.
    let mut batch = BatchAccumGrpc::new();

    if let Some(q) = &query {
        collect_query(q, &mut batch)?;
    }

    for p in &prefetch {
        collect_prefetch(p, &mut batch)?;
    }

    let BatchAccumGrpc { objects } = batch;

    // Single round-trip to the inference service for the whole request.
    let inferred = BatchAccumInferred::from_objects(objects, InferenceType::Search)
        .await
        .map_err(|e| Status::internal(format!("Inference error: {e}")))?;

    // Pass 2: rebuild the request, substituting inferred vectors for payloads.
    let query = if let Some(q) = query {
        Some(convert_query_with_inferred(q, &inferred)?)
    } else {
        None
    };

    let prefetch = prefetch
        .into_iter()
        .map(|p| convert_prefetch_with_inferred(p, &inferred))
        .collect::<Result<Vec<_>, _>>()?;

    let request = CollectionQueryGroupsRequest {
        prefetch,
        query,
        // Unnamed vector falls back to the default vector name.
        using: using.unwrap_or(DEFAULT_VECTOR_NAME.to_string()),
        filter: filter.map(TryFrom::try_from).transpose()?,
        score_threshold,
        with_vector: with_vectors
            .map(From::from)
            .unwrap_or(CollectionQueryRequest::DEFAULT_WITH_VECTOR),
        with_payload: with_payload
            .map(TryFrom::try_from)
            .transpose()?
            .unwrap_or(CollectionQueryRequest::DEFAULT_WITH_PAYLOAD),
        lookup_from: lookup_from.map(From::from),
        group_by: json_path_from_proto(&group_by)?,
        group_size: group_size
            .map(|s| s as usize)
            .unwrap_or(CollectionQueryRequest::DEFAULT_GROUP_SIZE),
        limit: limit
            .map(|l| l as usize)
            .unwrap_or(CollectionQueryRequest::DEFAULT_LIMIT),
        params: params.map(From::from),
        with_lookup: with_lookup.map(TryFrom::try_from).transpose()?,
    };

    Ok(request)
}
101
+
102
/// ToDo: this function is supposed to call an inference endpoint internally
///
/// Converts a gRPC `QueryPoints` into the internal `CollectionQueryRequest`,
/// first resolving all document/image/object payloads into vectors with one
/// batched inference call.
pub async fn convert_query_points_from_grpc(
    query: grpc::QueryPoints,
) -> Result<CollectionQueryRequest, Status> {
    // Exhaustive destructuring: a new proto field cannot be silently dropped.
    let grpc::QueryPoints {
        collection_name: _,
        prefetch,
        query,
        using,
        filter,
        params,
        score_threshold,
        limit,
        offset,
        with_payload,
        with_vectors,
        read_consistency: _,
        shard_key_selector: _,
        lookup_from,
        timeout: _,
    } = query;

    // Pass 1: gather every inference payload from the query and all nested
    // prefetches into a deduplicated batch.
    let mut batch = BatchAccumGrpc::new();

    if let Some(q) = &query {
        collect_query(q, &mut batch)?;
    }

    for p in &prefetch {
        collect_prefetch(p, &mut batch)?;
    }

    let BatchAccumGrpc { objects } = batch;

    // Single round-trip to the inference service for the whole request.
    let inferred = BatchAccumInferred::from_objects(objects, InferenceType::Search)
        .await
        .map_err(|e| Status::internal(format!("Inference error: {e}")))?;

    // Pass 2: rebuild the request, substituting inferred vectors for payloads.
    let prefetch = prefetch
        .into_iter()
        .map(|p| convert_prefetch_with_inferred(p, &inferred))
        .collect::<Result<Vec<_>, _>>()?;

    let query = query
        .map(|q| convert_query_with_inferred(q, &inferred))
        .transpose()?;

    Ok(CollectionQueryRequest {
        prefetch,
        query,
        // Unnamed vector falls back to the default vector name.
        using: using.unwrap_or(DEFAULT_VECTOR_NAME.to_string()),
        filter: filter.map(TryFrom::try_from).transpose()?,
        score_threshold,
        limit: limit
            .map(|l| l as usize)
            .unwrap_or(CollectionQueryRequest::DEFAULT_LIMIT),
        offset: offset
            .map(|o| o as usize)
            .unwrap_or(CollectionQueryRequest::DEFAULT_OFFSET),
        params: params.map(From::from),
        with_vector: with_vectors
            .map(From::from)
            .unwrap_or(CollectionQueryRequest::DEFAULT_WITH_VECTOR),
        with_payload: with_payload
            .map(TryFrom::try_from)
            .transpose()?
            .unwrap_or(CollectionQueryRequest::DEFAULT_WITH_PAYLOAD),
        lookup_from: lookup_from.map(From::from),
    })
}
172
+
173
+ fn convert_prefetch_with_inferred(
174
+ prefetch: grpc::PrefetchQuery,
175
+ inferred: &BatchAccumInferred,
176
+ ) -> Result<CollectionPrefetch, Status> {
177
+ let grpc::PrefetchQuery {
178
+ prefetch,
179
+ query,
180
+ using,
181
+ filter,
182
+ params,
183
+ score_threshold,
184
+ limit,
185
+ lookup_from,
186
+ } = prefetch;
187
+
188
+ let nested_prefetches = prefetch
189
+ .into_iter()
190
+ .map(|p| convert_prefetch_with_inferred(p, inferred))
191
+ .collect::<Result<Vec<_>, _>>()?;
192
+
193
+ let query = query
194
+ .map(|q| convert_query_with_inferred(q, inferred))
195
+ .transpose()?;
196
+
197
+ Ok(CollectionPrefetch {
198
+ prefetch: nested_prefetches,
199
+ query,
200
+ using: using.unwrap_or(DEFAULT_VECTOR_NAME.to_string()),
201
+ filter: filter.map(TryFrom::try_from).transpose()?,
202
+ score_threshold,
203
+ limit: limit
204
+ .map(|l| l as usize)
205
+ .unwrap_or(CollectionQueryRequest::DEFAULT_LIMIT),
206
+ params: params.map(From::from),
207
+ lookup_from: lookup_from.map(From::from),
208
+ })
209
+ }
210
+
211
/// Converts a gRPC `Query` into the internal `Query`, substituting
/// pre-computed inference vectors for any document/image/object payloads.
fn convert_query_with_inferred(
    query: grpc::Query,
    inferred: &BatchAccumInferred,
) -> Result<Query, Status> {
    let variant = query
        .variant
        .ok_or_else(|| Status::invalid_argument("Query variant is missing"))?;

    let query = match variant {
        Variant::Nearest(nearest) => {
            let vector = convert_vector_input_with_inferred(nearest, inferred)?;
            Query::Vector(VectorQuery::Nearest(vector))
        }
        Variant::Recommend(recommend) => {
            let RecommendInput {
                positive,
                negative,
                strategy,
            } = recommend;

            let positives = positive
                .into_iter()
                .map(|v| convert_vector_input_with_inferred(v, inferred))
                .collect::<Result<Vec<_>, _>>()?;

            let negatives = negative
                .into_iter()
                .map(|v| convert_vector_input_with_inferred(v, inferred))
                .collect::<Result<Vec<_>, _>>()?;

            let reco_query = RecoQuery::new(positives, negatives);

            // Unknown/absent strategy falls back to the REST default.
            let strategy = strategy
                .and_then(|x| grpc::RecommendStrategy::try_from(x).ok())
                .map(RecommendStrategy::from)
                .unwrap_or_default();

            match strategy {
                RecommendStrategy::AverageVector => {
                    Query::Vector(VectorQuery::RecommendAverageVector(reco_query))
                }
                RecommendStrategy::BestScore => {
                    Query::Vector(VectorQuery::RecommendBestScore(reco_query))
                }
            }
        }
        Variant::Discover(discover) => {
            let grpc::DiscoverInput { target, context } = discover;

            // Discover requires both a target and a context.
            let target = target
                .map(|t| convert_vector_input_with_inferred(t, inferred))
                .transpose()?
                .ok_or_else(|| Status::invalid_argument("DiscoverInput target is missing"))?;

            let grpc::ContextInput { pairs } = context
                .ok_or_else(|| Status::invalid_argument("DiscoverInput context is missing"))?;

            let context = pairs
                .into_iter()
                .map(|pair| context_pair_from_grpc_with_inferred(pair, inferred))
                .collect::<Result<_, _>>()?;

            Query::Vector(VectorQuery::Discover(DiscoveryQuery::new(target, context)))
        }
        Variant::Context(context) => {
            let context_query = context_query_from_grpc_with_inferred(context, inferred)?;
            Query::Vector(VectorQuery::Context(context_query))
        }
        // Non-vector query variants convert directly.
        Variant::OrderBy(order_by) => Query::OrderBy(OrderBy::try_from(order_by)?),
        Variant::Fusion(fusion) => Query::Fusion(FusionInternal::try_from(fusion)?),
        Variant::Sample(sample) => Query::Sample(SampleInternal::try_from(sample)?),
    };

    Ok(query)
}
286
+
287
+ fn convert_vector_input_with_inferred(
288
+ vector: grpc::VectorInput,
289
+ inferred: &BatchAccumInferred,
290
+ ) -> Result<VectorInputInternal, Status> {
291
+ use api::grpc::qdrant::vector_input::Variant;
292
+
293
+ let variant = vector
294
+ .variant
295
+ .ok_or_else(|| Status::invalid_argument("VectorInput variant is missing"))?;
296
+
297
+ match variant {
298
+ Variant::Id(id) => Ok(VectorInputInternal::Id(TryFrom::try_from(id)?)),
299
+ Variant::Dense(dense) => Ok(VectorInputInternal::Vector(VectorInternal::Dense(
300
+ From::from(dense),
301
+ ))),
302
+ Variant::Sparse(sparse) => Ok(VectorInputInternal::Vector(VectorInternal::Sparse(
303
+ From::from(sparse),
304
+ ))),
305
+ Variant::MultiDense(multi_dense) => Ok(VectorInputInternal::Vector(
306
+ VectorInternal::MultiDense(From::from(multi_dense)),
307
+ )),
308
+ Variant::Document(doc) => {
309
+ let doc: rest::Document = doc
310
+ .try_into()
311
+ .map_err(|e| Status::internal(format!("Document conversion error: {e}")))?;
312
+ let data = InferenceData::Document(doc);
313
+ let vector = inferred
314
+ .get_vector(&data)
315
+ .ok_or_else(|| Status::internal("Missing inferred vector for document"))?;
316
+
317
+ Ok(VectorInputInternal::Vector(VectorInternal::from(
318
+ vector.clone(),
319
+ )))
320
+ }
321
+ Variant::Image(img) => {
322
+ let img: rest::Image = img
323
+ .try_into()
324
+ .map_err(|e| Status::internal(format!("Image conversion error: {e}",)))?;
325
+ let data = InferenceData::Image(img);
326
+
327
+ let vector = inferred
328
+ .get_vector(&data)
329
+ .ok_or_else(|| Status::internal("Missing inferred vector for image"))?;
330
+
331
+ Ok(VectorInputInternal::Vector(VectorInternal::from(
332
+ vector.clone(),
333
+ )))
334
+ }
335
+ Variant::Object(obj) => {
336
+ let obj: rest::InferenceObject = obj
337
+ .try_into()
338
+ .map_err(|e| Status::internal(format!("Object conversion error: {e}")))?;
339
+ let data = InferenceData::Object(obj);
340
+ let vector = inferred
341
+ .get_vector(&data)
342
+ .ok_or_else(|| Status::internal("Missing inferred vector for object"))?;
343
+
344
+ Ok(VectorInputInternal::Vector(VectorInternal::from(
345
+ vector.clone(),
346
+ )))
347
+ }
348
+ }
349
+ }
350
+
351
+ fn context_query_from_grpc_with_inferred(
352
+ value: grpc::ContextInput,
353
+ inferred: &BatchAccumInferred,
354
+ ) -> Result<ContextQuery<VectorInputInternal>, Status> {
355
+ let grpc::ContextInput { pairs } = value;
356
+
357
+ Ok(ContextQuery {
358
+ pairs: pairs
359
+ .into_iter()
360
+ .map(|pair| context_pair_from_grpc_with_inferred(pair, inferred))
361
+ .collect::<Result<_, _>>()?,
362
+ })
363
+ }
364
+
365
+ fn context_pair_from_grpc_with_inferred(
366
+ value: grpc::ContextInputPair,
367
+ inferred: &BatchAccumInferred,
368
+ ) -> Result<ContextPair<VectorInputInternal>, Status> {
369
+ let grpc::ContextInputPair { positive, negative } = value;
370
+
371
+ let positive =
372
+ positive.ok_or_else(|| Status::invalid_argument("ContextPair positive is missing"))?;
373
+ let negative =
374
+ negative.ok_or_else(|| Status::invalid_argument("ContextPair negative is missing"))?;
375
+
376
+ Ok(ContextPair {
377
+ positive: convert_vector_input_with_inferred(positive, inferred)?,
378
+ negative: convert_vector_input_with_inferred(negative, inferred)?,
379
+ })
380
+ }
381
+
382
#[cfg(test)]
mod tests {
    use std::collections::HashMap;

    use api::grpc::qdrant::value::Kind;
    use api::grpc::qdrant::vector_input::Variant;
    use api::grpc::qdrant::Value;
    use collection::operations::point_ops::VectorPersisted;

    use super::*;

    // Fixture: gRPC document payload.
    fn create_test_document() -> api::grpc::qdrant::Document {
        api::grpc::qdrant::Document {
            text: "test".to_string(),
            model: "test-model".to_string(),
            options: HashMap::new(),
        }
    }

    // Fixture: gRPC image payload.
    fn create_test_image() -> api::grpc::qdrant::Image {
        api::grpc::qdrant::Image {
            image: Some(Value {
                kind: Some(Kind::StringValue("test.jpg".to_string())),
            }),
            model: "test-model".to_string(),
            options: HashMap::new(),
        }
    }

    // Fixture: gRPC arbitrary-object payload.
    fn create_test_object() -> api::grpc::qdrant::InferenceObject {
        api::grpc::qdrant::InferenceObject {
            object: Some(Value {
                kind: Some(Kind::StringValue("test".to_string())),
            }),
            model: "test-model".to_string(),
            options: HashMap::new(),
        }
    }

    // Builds an inferred batch where each fixture payload maps to the same
    // dense vector [1.0, 2.0, 3.0], simulating a completed inference call.
    fn create_test_inferred_batch() -> BatchAccumInferred {
        let mut objects = HashMap::new();

        let grpc_doc = create_test_document();
        let grpc_img = create_test_image();
        let grpc_obj = create_test_object();

        let doc: rest::Document = grpc_doc.try_into().unwrap();
        let img: rest::Image = grpc_img.try_into().unwrap();
        let obj: rest::InferenceObject = grpc_obj.try_into().unwrap();

        let doc_data = InferenceData::Document(doc);
        let img_data = InferenceData::Image(img);
        let obj_data = InferenceData::Object(obj);

        let dense_vector = vec![1.0, 2.0, 3.0];
        let vector_persisted = VectorPersisted::Dense(dense_vector);

        objects.insert(doc_data, vector_persisted.clone());
        objects.insert(img_data, vector_persisted.clone());
        objects.insert(obj_data, vector_persisted);

        BatchAccumInferred { objects }
    }

    #[test]
    fn test_convert_vector_input_with_inferred_dense() {
        // Dense input passes through untouched; inference is not consulted.
        let inferred = create_test_inferred_batch();
        let vector = grpc::VectorInput {
            variant: Some(Variant::Dense(grpc::DenseVector {
                data: vec![1.0, 2.0, 3.0],
            })),
        };

        let result = convert_vector_input_with_inferred(vector, &inferred).unwrap();
        match result {
            VectorInputInternal::Vector(VectorInternal::Dense(values)) => {
                assert_eq!(values, vec![1.0, 2.0, 3.0]);
            }
            _ => panic!("Expected dense vector"),
        }
    }

    #[test]
    fn test_convert_vector_input_with_inferred_document() {
        // A document payload is replaced by its pre-computed vector.
        let inferred = create_test_inferred_batch();
        let doc = create_test_document();
        let vector = grpc::VectorInput {
            variant: Some(Variant::Document(doc)),
        };

        let result = convert_vector_input_with_inferred(vector, &inferred).unwrap();
        match result {
            VectorInputInternal::Vector(VectorInternal::Dense(values)) => {
                assert_eq!(values, vec![1.0, 2.0, 3.0]);
            }
            _ => panic!("Expected dense vector from inference"),
        }
    }

    #[test]
    fn test_convert_vector_input_missing_variant() {
        // An empty variant is an invalid-argument error, not a panic.
        let inferred = create_test_inferred_batch();
        let vector = grpc::VectorInput { variant: None };

        let result = convert_vector_input_with_inferred(vector, &inferred);
        assert!(result.is_err());
        assert!(result.unwrap_err().message().contains("variant is missing"));
    }

    #[test]
    fn test_context_pair_from_grpc_with_inferred() {
        // Mixed pair: literal dense positive, inferred document negative.
        let inferred = create_test_inferred_batch();
        let pair = grpc::ContextInputPair {
            positive: Some(grpc::VectorInput {
                variant: Some(Variant::Dense(grpc::DenseVector {
                    data: vec![1.0, 2.0, 3.0],
                })),
            }),
            negative: Some(grpc::VectorInput {
                variant: Some(Variant::Document(create_test_document())),
            }),
        };

        let result = context_pair_from_grpc_with_inferred(pair, &inferred).unwrap();
        match (result.positive, result.negative) {
            (
                VectorInputInternal::Vector(VectorInternal::Dense(pos)),
                VectorInputInternal::Vector(VectorInternal::Dense(neg)),
            ) => {
                assert_eq!(pos, vec![1.0, 2.0, 3.0]);
                assert_eq!(neg, vec![1.0, 2.0, 3.0]);
            }
            _ => panic!("Expected dense vectors"),
        }
    }

    #[test]
    fn test_context_pair_missing_vectors() {
        // A pair missing its positive side must fail with invalid-argument.
        let inferred = create_test_inferred_batch();
        let pair = grpc::ContextInputPair {
            positive: None,
            negative: Some(grpc::VectorInput {
                variant: Some(Variant::Document(create_test_document())),
            }),
        };

        let result = context_pair_from_grpc_with_inferred(pair, &inferred);
        assert!(result.is_err());
        assert!(result
            .unwrap_err()
            .message()
            .contains("positive is missing"));
    }
}
src/common/inference/query_requests_rest.rs ADDED
@@ -0,0 +1,415 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ use api::rest::schema as rest;
2
+ use collection::lookup::WithLookup;
3
+ use collection::operations::universal_query::collection_query::{
4
+ CollectionPrefetch, CollectionQueryGroupsRequest, CollectionQueryRequest, Query,
5
+ VectorInputInternal, VectorQuery,
6
+ };
7
+ use collection::operations::universal_query::shard_query::{FusionInternal, SampleInternal};
8
+ use segment::data_types::order_by::OrderBy;
9
+ use segment::data_types::vectors::{MultiDenseVectorInternal, VectorInternal, DEFAULT_VECTOR_NAME};
10
+ use segment::vector_storage::query::{ContextPair, ContextQuery, DiscoveryQuery, RecoQuery};
11
+ use storage::content_manager::errors::StorageError;
12
+
13
+ use crate::common::inference::batch_processing::{
14
+ collect_query_groups_request, collect_query_request,
15
+ };
16
+ use crate::common::inference::infer_processing::BatchAccumInferred;
17
+ use crate::common::inference::service::{InferenceData, InferenceType};
18
+
19
+ pub async fn convert_query_groups_request_from_rest(
20
+ request: rest::QueryGroupsRequestInternal,
21
+ ) -> Result<CollectionQueryGroupsRequest, StorageError> {
22
+ let batch = collect_query_groups_request(&request);
23
+ let rest::QueryGroupsRequestInternal {
24
+ prefetch,
25
+ query,
26
+ using,
27
+ filter,
28
+ score_threshold,
29
+ params,
30
+ with_vector,
31
+ with_payload,
32
+ lookup_from,
33
+ group_request,
34
+ } = request;
35
+
36
+ let inferred = BatchAccumInferred::from_batch_accum(batch, InferenceType::Search).await?;
37
+ let query = query
38
+ .map(|q| convert_query_with_inferred(q, &inferred))
39
+ .transpose()?;
40
+
41
+ let prefetch = prefetch
42
+ .map(|prefetches| {
43
+ prefetches
44
+ .into_iter()
45
+ .map(|p| convert_prefetch_with_inferred(p, &inferred))
46
+ .collect::<Result<Vec<_>, _>>()
47
+ })
48
+ .transpose()?
49
+ .unwrap_or_default();
50
+
51
+ Ok(CollectionQueryGroupsRequest {
52
+ prefetch,
53
+ query,
54
+ using: using.unwrap_or(DEFAULT_VECTOR_NAME.to_string()),
55
+ filter,
56
+ score_threshold,
57
+ params,
58
+ with_vector: with_vector.unwrap_or(CollectionQueryRequest::DEFAULT_WITH_VECTOR),
59
+ with_payload: with_payload.unwrap_or(CollectionQueryRequest::DEFAULT_WITH_PAYLOAD),
60
+ lookup_from,
61
+ limit: group_request
62
+ .limit
63
+ .unwrap_or(CollectionQueryRequest::DEFAULT_LIMIT),
64
+ group_by: group_request.group_by,
65
+ group_size: group_request
66
+ .group_size
67
+ .unwrap_or(CollectionQueryRequest::DEFAULT_GROUP_SIZE),
68
+ with_lookup: group_request.with_lookup.map(WithLookup::from),
69
+ })
70
+ }
71
+
72
+ pub async fn convert_query_request_from_rest(
73
+ request: rest::QueryRequestInternal,
74
+ ) -> Result<CollectionQueryRequest, StorageError> {
75
+ let batch = collect_query_request(&request);
76
+ let inferred = BatchAccumInferred::from_batch_accum(batch, InferenceType::Search).await?;
77
+ let rest::QueryRequestInternal {
78
+ prefetch,
79
+ query,
80
+ using,
81
+ filter,
82
+ score_threshold,
83
+ params,
84
+ limit,
85
+ offset,
86
+ with_vector,
87
+ with_payload,
88
+ lookup_from,
89
+ } = request;
90
+
91
+ let prefetch = prefetch
92
+ .map(|prefetches| {
93
+ prefetches
94
+ .into_iter()
95
+ .map(|p| convert_prefetch_with_inferred(p, &inferred))
96
+ .collect::<Result<Vec<_>, _>>()
97
+ })
98
+ .transpose()?
99
+ .unwrap_or_default();
100
+
101
+ let query = query
102
+ .map(|q| convert_query_with_inferred(q, &inferred))
103
+ .transpose()?;
104
+
105
+ Ok(CollectionQueryRequest {
106
+ prefetch,
107
+ query,
108
+ using: using.unwrap_or(DEFAULT_VECTOR_NAME.to_string()),
109
+ filter,
110
+ score_threshold,
111
+ limit: limit.unwrap_or(CollectionQueryRequest::DEFAULT_LIMIT),
112
+ offset: offset.unwrap_or(CollectionQueryRequest::DEFAULT_OFFSET),
113
+ params,
114
+ with_vector: with_vector.unwrap_or(CollectionQueryRequest::DEFAULT_WITH_VECTOR),
115
+ with_payload: with_payload.unwrap_or(CollectionQueryRequest::DEFAULT_WITH_PAYLOAD),
116
+ lookup_from,
117
+ })
118
+ }
119
+
120
+ fn convert_vector_input_with_inferred(
121
+ vector: rest::VectorInput,
122
+ inferred: &BatchAccumInferred,
123
+ ) -> Result<VectorInputInternal, StorageError> {
124
+ match vector {
125
+ rest::VectorInput::Id(id) => Ok(VectorInputInternal::Id(id)),
126
+ rest::VectorInput::DenseVector(dense) => {
127
+ Ok(VectorInputInternal::Vector(VectorInternal::Dense(dense)))
128
+ }
129
+ rest::VectorInput::SparseVector(sparse) => {
130
+ Ok(VectorInputInternal::Vector(VectorInternal::Sparse(sparse)))
131
+ }
132
+ rest::VectorInput::MultiDenseVector(multi_dense) => Ok(VectorInputInternal::Vector(
133
+ VectorInternal::MultiDense(MultiDenseVectorInternal::new_unchecked(multi_dense)),
134
+ )),
135
+ rest::VectorInput::Document(doc) => {
136
+ let data = InferenceData::Document(doc);
137
+ let vector = inferred.get_vector(&data).ok_or_else(|| {
138
+ StorageError::inference_error("Missing inferred vector for document")
139
+ })?;
140
+ Ok(VectorInputInternal::Vector(VectorInternal::from(
141
+ vector.clone(),
142
+ )))
143
+ }
144
+ rest::VectorInput::Image(img) => {
145
+ let data = InferenceData::Image(img);
146
+ let vector = inferred.get_vector(&data).ok_or_else(|| {
147
+ StorageError::inference_error("Missing inferred vector for image")
148
+ })?;
149
+ Ok(VectorInputInternal::Vector(VectorInternal::from(
150
+ vector.clone(),
151
+ )))
152
+ }
153
+ rest::VectorInput::Object(obj) => {
154
+ let data = InferenceData::Object(obj);
155
+ let vector = inferred.get_vector(&data).ok_or_else(|| {
156
+ StorageError::inference_error("Missing inferred vector for object")
157
+ })?;
158
+ Ok(VectorInputInternal::Vector(VectorInternal::from(
159
+ vector.clone(),
160
+ )))
161
+ }
162
+ }
163
+ }
164
+
165
/// Convert a REST query (possibly containing Document/Image/Object inputs)
/// into the internal `Query`, substituting vectors from `inferred` where
/// inference was required.
///
/// # Errors
/// Propagates any failure from the underlying vector-input conversion
/// (e.g. a missing inferred vector).
fn convert_query_with_inferred(
    query: rest::QueryInterface,
    inferred: &BatchAccumInferred,
) -> Result<Query, StorageError> {
    // `QueryInterface` is a shorthand wrapper; normalize to the full enum first.
    let query = rest::Query::from(query);
    match query {
        // Plain nearest-neighbor search on a single (possibly inferred) vector.
        rest::Query::Nearest(nearest) => {
            let vector = convert_vector_input_with_inferred(nearest.nearest, inferred)?;
            Ok(Query::Vector(VectorQuery::Nearest(vector)))
        }
        // Recommendation: convert all positive/negative examples, then pick
        // the internal variant from the requested strategy.
        rest::Query::Recommend(recommend) => {
            let rest::RecommendInput {
                positive,
                negative,
                strategy,
            } = recommend.recommend;
            // `positive`/`negative` are `Option<Vec<_>>`; `into_iter().flatten()`
            // treats `None` as an empty list.
            let positives = positive
                .into_iter()
                .flatten()
                .map(|v| convert_vector_input_with_inferred(v, inferred))
                .collect::<Result<Vec<_>, _>>()?;
            let negatives = negative
                .into_iter()
                .flatten()
                .map(|v| convert_vector_input_with_inferred(v, inferred))
                .collect::<Result<Vec<_>, _>>()?;
            let reco_query = RecoQuery::new(positives, negatives);
            match strategy.unwrap_or_default() {
                rest::RecommendStrategy::AverageVector => Ok(Query::Vector(
                    VectorQuery::RecommendAverageVector(reco_query),
                )),
                rest::RecommendStrategy::BestScore => {
                    Ok(Query::Vector(VectorQuery::RecommendBestScore(reco_query)))
                }
            }
        }
        // Discovery: a target vector plus positive/negative context pairs.
        rest::Query::Discover(discover) => {
            let rest::DiscoverInput { target, context } = discover.discover;
            let target = convert_vector_input_with_inferred(target, inferred)?;
            let context = context
                .into_iter()
                .flatten()
                .map(|pair| context_pair_from_rest_with_inferred(pair, inferred))
                .collect::<Result<Vec<_>, _>>()?;
            Ok(Query::Vector(VectorQuery::Discover(DiscoveryQuery::new(
                target, context,
            ))))
        }
        // Context-only search: pairs without an explicit target.
        rest::Query::Context(context) => {
            let rest::ContextInput(context) = context.context;
            let context = context
                .into_iter()
                .flatten()
                .map(|pair| context_pair_from_rest_with_inferred(pair, inferred))
                .collect::<Result<Vec<_>, _>>()?;
            Ok(Query::Vector(VectorQuery::Context(ContextQuery::new(
                context,
            ))))
        }
        // Non-vector queries: direct structural conversions, no inference involved.
        rest::Query::OrderBy(order_by) => Ok(Query::OrderBy(OrderBy::from(order_by.order_by))),
        rest::Query::Fusion(fusion) => Ok(Query::Fusion(FusionInternal::from(fusion.fusion))),
        rest::Query::Sample(sample) => Ok(Query::Sample(SampleInternal::from(sample.sample))),
    }
}
229
+
230
+ fn convert_prefetch_with_inferred(
231
+ prefetch: rest::Prefetch,
232
+ inferred: &BatchAccumInferred,
233
+ ) -> Result<CollectionPrefetch, StorageError> {
234
+ let rest::Prefetch {
235
+ prefetch,
236
+ query,
237
+ using,
238
+ filter,
239
+ score_threshold,
240
+ params,
241
+ limit,
242
+ lookup_from,
243
+ } = prefetch;
244
+
245
+ let query = query
246
+ .map(|q| convert_query_with_inferred(q, inferred))
247
+ .transpose()?;
248
+ let nested_prefetches = prefetch
249
+ .map(|prefetches| {
250
+ prefetches
251
+ .into_iter()
252
+ .map(|p| convert_prefetch_with_inferred(p, inferred))
253
+ .collect::<Result<Vec<_>, _>>()
254
+ })
255
+ .transpose()?
256
+ .unwrap_or_default();
257
+
258
+ Ok(CollectionPrefetch {
259
+ prefetch: nested_prefetches,
260
+ query,
261
+ using: using.unwrap_or(DEFAULT_VECTOR_NAME.to_string()),
262
+ filter,
263
+ score_threshold,
264
+ limit: limit.unwrap_or(CollectionQueryRequest::DEFAULT_LIMIT),
265
+ params,
266
+ lookup_from,
267
+ })
268
+ }
269
+
270
+ fn context_pair_from_rest_with_inferred(
271
+ value: rest::ContextPair,
272
+ inferred: &BatchAccumInferred,
273
+ ) -> Result<ContextPair<VectorInputInternal>, StorageError> {
274
+ let rest::ContextPair { positive, negative } = value;
275
+ Ok(ContextPair {
276
+ positive: convert_vector_input_with_inferred(positive, inferred)?,
277
+ negative: convert_vector_input_with_inferred(negative, inferred)?,
278
+ })
279
+ }
280
+
281
#[cfg(test)]
mod tests {
    use std::collections::HashMap;

    use api::rest::schema::{Document, Image, InferenceObject, NearestQuery};
    use collection::operations::point_ops::VectorPersisted;
    use serde_json::json;

    use super::*;

    // --- fixtures -----------------------------------------------------------

    // Document fixture; inference resolution is keyed on the full object,
    // so the `text` value determines whether a lookup hits or misses.
    fn create_test_document(text: &str) -> Document {
        Document {
            text: text.to_string(),
            model: "test-model".to_string(),
            options: Default::default(),
        }
    }

    fn create_test_image(url: &str) -> Image {
        Image {
            image: json!({"data": url.to_string()}),
            model: "test-model".to_string(),
            options: Default::default(),
        }
    }

    fn create_test_object(data: &str) -> InferenceObject {
        InferenceObject {
            object: json!({"data": data}),
            model: "test-model".to_string(),
            options: Default::default(),
        }
    }

    // Pre-inferred batch mapping one document, one image, and one object
    // (all with payload "test"/"test.jpg") to the dense vector [1, 2, 3].
    fn create_test_inferred_batch() -> BatchAccumInferred {
        let mut objects = HashMap::new();

        let doc = InferenceData::Document(create_test_document("test"));
        let img = InferenceData::Image(create_test_image("test.jpg"));
        let obj = InferenceData::Object(create_test_object("test"));

        let dense_vector = vec![1.0, 2.0, 3.0];
        let vector_persisted = VectorPersisted::Dense(dense_vector);

        objects.insert(doc, vector_persisted.clone());
        objects.insert(img, vector_persisted.clone());
        objects.insert(obj, vector_persisted);

        BatchAccumInferred { objects }
    }

    // --- tests --------------------------------------------------------------

    // A raw dense vector passes through untouched (no inference lookup).
    #[test]
    fn test_convert_vector_input_with_inferred_dense() {
        let inferred = create_test_inferred_batch();
        let vector = rest::VectorInput::DenseVector(vec![1.0, 2.0, 3.0]);

        let result = convert_vector_input_with_inferred(vector, &inferred).unwrap();
        match result {
            VectorInputInternal::Vector(VectorInternal::Dense(values)) => {
                assert_eq!(values, vec![1.0, 2.0, 3.0]);
            }
            _ => panic!("Expected dense vector"),
        }
    }

    // A document input resolves to its pre-inferred embedding.
    #[test]
    fn test_convert_vector_input_with_inferred_document() {
        let inferred = create_test_inferred_batch();
        let doc = create_test_document("test");
        let vector = rest::VectorInput::Document(doc);

        let result = convert_vector_input_with_inferred(vector, &inferred).unwrap();
        match result {
            VectorInputInternal::Vector(VectorInternal::Dense(values)) => {
                assert_eq!(values, vec![1.0, 2.0, 3.0]);
            }
            _ => panic!("Expected dense vector from inference"),
        }
    }

    // A document absent from the inferred batch must produce an error,
    // not a silent fallback.
    #[test]
    fn test_convert_vector_input_with_inferred_missing() {
        let inferred = create_test_inferred_batch();
        let doc = create_test_document("missing");
        let vector = rest::VectorInput::Document(doc);

        let result = convert_vector_input_with_inferred(vector, &inferred);
        assert!(result.is_err());
        assert!(result
            .unwrap_err()
            .to_string()
            .contains("Missing inferred vector"));
    }

    // Mixed pair: plain dense positive, inferred-document negative.
    #[test]
    fn test_context_pair_from_rest_with_inferred() {
        let inferred = create_test_inferred_batch();
        let pair = rest::ContextPair {
            positive: rest::VectorInput::DenseVector(vec![1.0, 2.0, 3.0]),
            negative: rest::VectorInput::Document(create_test_document("test")),
        };

        let result = context_pair_from_rest_with_inferred(pair, &inferred).unwrap();
        match (result.positive, result.negative) {
            (
                VectorInputInternal::Vector(VectorInternal::Dense(pos)),
                VectorInputInternal::Vector(VectorInternal::Dense(neg)),
            ) => {
                assert_eq!(pos, vec![1.0, 2.0, 3.0]);
                assert_eq!(neg, vec![1.0, 2.0, 3.0]);
            }
            _ => panic!("Expected dense vectors"),
        }
    }

    // End-to-end: a nearest query over a document resolves through inference.
    #[test]
    fn test_convert_query_with_inferred_nearest() {
        let inferred = create_test_inferred_batch();
        let nearest = NearestQuery {
            nearest: rest::VectorInput::Document(create_test_document("test")),
        };
        let query = rest::QueryInterface::Query(rest::Query::Nearest(nearest));

        let result = convert_query_with_inferred(query, &inferred).unwrap();
        match result {
            Query::Vector(VectorQuery::Nearest(vector)) => match vector {
                VectorInputInternal::Vector(VectorInternal::Dense(values)) => {
                    assert_eq!(values, vec![1.0, 2.0, 3.0]);
                }
                _ => panic!("Expected dense vector"),
            },
            _ => panic!("Expected nearest query"),
        }
    }
}
src/common/inference/service.rs ADDED
@@ -0,0 +1,266 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ use std::collections::HashMap;
2
+ use std::fmt::Display;
3
+ use std::hash::Hash;
4
+ use std::sync::Arc;
5
+ use std::time::Duration;
6
+
7
+ use api::rest::{Document, Image, InferenceObject};
8
+ use collection::operations::point_ops::VectorPersisted;
9
+ use parking_lot::RwLock;
10
+ use reqwest::Client;
11
+ use serde::{Deserialize, Serialize};
12
+ use serde_json::Value;
13
+ use storage::content_manager::errors::StorageError;
14
+
15
+ use crate::common::inference::config::InferenceConfig;
16
+
17
+ const DOCUMENT_DATA_TYPE: &str = "text";
18
+ const IMAGE_DATA_TYPE: &str = "image";
19
+ const OBJECT_DATA_TYPE: &str = "object";
20
+
21
/// Whether an inference call embeds data being written or a search query.
///
/// Serialized lowercase ("update" / "search") in requests to the inference
/// service via `#[serde(rename_all = "lowercase")]`.
#[derive(Debug, Serialize, Default, Clone, Copy)]
#[serde(rename_all = "lowercase")]
pub enum InferenceType {
    /// Embedding computed for data being ingested (the default).
    #[default]
    Update,
    /// Embedding computed for a search request.
    Search,
}
28
+
29
+ impl Display for InferenceType {
30
+ fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
31
+ write!(f, "{}", format!("{self:?}").to_lowercase())
32
+ }
33
+ }
34
+
35
/// JSON payload POSTed to the inference service.
#[derive(Debug, Serialize)]
pub struct InferenceRequest {
    /// One entry per object to embed; the response returns embeddings in the
    /// same order.
    pub(crate) inputs: Vec<InferenceInput>,
    /// Purpose of the call (update vs. search); serialized lowercase.
    pub(crate) inference: Option<InferenceType>,
    /// Auth token copied from the service config, if configured.
    #[serde(default)]
    pub(crate) token: Option<String>,
}

/// A single object to embed: raw data plus its type tag, model name, and
/// optional per-model options.
#[derive(Debug, Serialize)]
pub struct InferenceInput {
    // Payload: a JSON string for documents, arbitrary JSON for images/objects.
    data: Value,
    // One of DOCUMENT_DATA_TYPE / IMAGE_DATA_TYPE / OBJECT_DATA_TYPE.
    data_type: String,
    model: String,
    options: Option<HashMap<String, Value>>,
}

/// Successful response body: one embedding per submitted input.
#[derive(Debug, Deserialize)]
pub(crate) struct InferenceResponse {
    pub(crate) embeddings: Vec<VectorPersisted>,
}
55
+
56
/// An object that requires embedding via the inference service.
///
/// `Eq + Hash` so identical objects can be deduplicated in a batch and
/// later used as lookup keys for their inferred vectors.
#[derive(Debug, Serialize, Deserialize, Clone, PartialEq, Eq, Hash)]
pub enum InferenceData {
    Document(Document),
    Image(Image),
    Object(InferenceObject),
}
62
+
63
+ impl InferenceData {
64
+ pub(crate) fn type_name(&self) -> &'static str {
65
+ match self {
66
+ InferenceData::Document(_) => "document",
67
+ InferenceData::Image(_) => "image",
68
+ InferenceData::Object(_) => "object",
69
+ }
70
+ }
71
+ }
72
+
73
impl From<InferenceData> for InferenceInput {
    /// Flatten a typed inference object into the wire-format input,
    /// attaching the matching `data_type` tag for the service.
    fn from(value: InferenceData) -> Self {
        match value {
            InferenceData::Document(doc) => {
                let Document {
                    text,
                    model,
                    options,
                } = doc;
                InferenceInput {
                    // Documents are sent as a plain JSON string.
                    data: Value::String(text),
                    data_type: DOCUMENT_DATA_TYPE.to_string(),
                    model: model.to_string(),
                    options: options.options,
                }
            }
            InferenceData::Image(img) => {
                let Image {
                    image,
                    model,
                    options,
                } = img;
                InferenceInput {
                    // Images are forwarded as-is (arbitrary JSON payload).
                    data: image,
                    data_type: IMAGE_DATA_TYPE.to_string(),
                    model: model.to_string(),
                    options: options.options,
                }
            }
            InferenceData::Object(obj) => {
                let InferenceObject {
                    object,
                    model,
                    options,
                } = obj;
                InferenceInput {
                    // Generic objects are forwarded as-is (arbitrary JSON payload).
                    data: object,
                    data_type: OBJECT_DATA_TYPE.to_string(),
                    model: model.to_string(),
                    options: options.options,
                }
            }
        }
    }
}
118
+
119
/// Client for the external embedding-inference HTTP service.
pub struct InferenceService {
    /// Address, auth token, and timeout settings for the service.
    pub(crate) config: InferenceConfig,
    /// Reusable HTTP client configured with the request timeout.
    pub(crate) client: Client,
}

/// Process-wide singleton, installed once at startup via
/// `InferenceService::init_global` and read via `get_global`.
static INFERENCE_SERVICE: RwLock<Option<Arc<InferenceService>>> = RwLock::new(None);
125
+
126
+ impl InferenceService {
127
+ pub fn new(config: InferenceConfig) -> Self {
128
+ let timeout = Duration::from_secs(config.timeout);
129
+ Self {
130
+ config,
131
+ client: Client::builder()
132
+ .timeout(timeout)
133
+ .build()
134
+ .expect("Invalid timeout value for HTTP client"),
135
+ }
136
+ }
137
+
138
+ pub fn init_global(config: InferenceConfig) -> Result<(), StorageError> {
139
+ let mut inference_service = INFERENCE_SERVICE.write();
140
+
141
+ if config.token.is_none() {
142
+ return Err(StorageError::service_error(
143
+ "Cannot initialize InferenceService: token is required but not provided in config",
144
+ ));
145
+ }
146
+
147
+ if config.address.is_none() || config.address.as_ref().unwrap().is_empty() {
148
+ return Err(StorageError::service_error(
149
+ "Cannot initialize InferenceService: address is required but not provided or empty in config"
150
+ ));
151
+ }
152
+
153
+ *inference_service = Some(Arc::new(Self::new(config)));
154
+ Ok(())
155
+ }
156
+
157
+ pub fn get_global() -> Option<Arc<InferenceService>> {
158
+ INFERENCE_SERVICE.read().as_ref().cloned()
159
+ }
160
+
161
+ pub(crate) fn validate(&self) -> Result<(), StorageError> {
162
+ if self
163
+ .config
164
+ .address
165
+ .as_ref()
166
+ .map_or(true, |url| url.is_empty())
167
+ {
168
+ return Err(StorageError::service_error(
169
+ "InferenceService configuration error: address is missing or empty",
170
+ ));
171
+ }
172
+ Ok(())
173
+ }
174
+
175
+ pub async fn infer(
176
+ &self,
177
+ inference_inputs: Vec<InferenceInput>,
178
+ inference_type: InferenceType,
179
+ ) -> Result<Vec<VectorPersisted>, StorageError> {
180
+ let request = InferenceRequest {
181
+ inputs: inference_inputs,
182
+ inference: Some(inference_type),
183
+ token: self.config.token.clone(),
184
+ };
185
+
186
+ let url = self.config.address.as_ref().ok_or_else(|| {
187
+ StorageError::service_error(
188
+ "InferenceService URL not configured - please provide valid address in config",
189
+ )
190
+ })?;
191
+
192
+ let response = self
193
+ .client
194
+ .post(url)
195
+ .json(&request)
196
+ .send()
197
+ .await
198
+ .map_err(|e| {
199
+ let error_body = e.to_string();
200
+ StorageError::service_error(format!(
201
+ "Failed to send inference request to {url}: {e}, error details: {error_body}",
202
+ ))
203
+ })?;
204
+
205
+ let status = response.status();
206
+ let response_body = response.text().await.map_err(|e| {
207
+ StorageError::service_error(format!("Failed to read inference response body: {e}",))
208
+ })?;
209
+
210
+ Self::handle_inference_response(status, &response_body)
211
+ }
212
+
213
+ pub(crate) fn handle_inference_response(
214
+ status: reqwest::StatusCode,
215
+ response_body: &str,
216
+ ) -> Result<Vec<VectorPersisted>, StorageError> {
217
+ match status {
218
+ reqwest::StatusCode::OK => {
219
+ let inference_response: InferenceResponse = serde_json::from_str(response_body)
220
+ .map_err(|e| {
221
+ StorageError::service_error(format!(
222
+ "Failed to parse successful inference response: {e}. Response body: {response_body}",
223
+ ))
224
+ })?;
225
+
226
+ if inference_response.embeddings.is_empty() {
227
+ Err(StorageError::service_error(
228
+ "Inference response contained no embeddings - this may indicate an issue with the model or input"
229
+ ))
230
+ } else {
231
+ Ok(inference_response.embeddings)
232
+ }
233
+ }
234
+ reqwest::StatusCode::BAD_REQUEST => {
235
+ let error_json: Value = serde_json::from_str(response_body).map_err(|e| {
236
+ StorageError::service_error(format!(
237
+ "Failed to parse error response: {e}. Raw response: {response_body}",
238
+ ))
239
+ })?;
240
+
241
+ if let Some(error_message) = error_json["error"].as_str() {
242
+ Err(StorageError::bad_request(format!(
243
+ "Inference request validation failed: {error_message}",
244
+ )))
245
+ } else {
246
+ Err(StorageError::bad_request(format!(
247
+ "Invalid inference request: {response_body}",
248
+ )))
249
+ }
250
+ }
251
+ status @ (reqwest::StatusCode::UNAUTHORIZED | reqwest::StatusCode::FORBIDDEN) => {
252
+ Err(StorageError::service_error(format!(
253
+ "Authentication failed for inference service ({status}): {response_body}",
254
+ )))
255
+ }
256
+ status @ (reqwest::StatusCode::INTERNAL_SERVER_ERROR
257
+ | reqwest::StatusCode::SERVICE_UNAVAILABLE
258
+ | reqwest::StatusCode::GATEWAY_TIMEOUT) => Err(StorageError::service_error(format!(
259
+ "Inference service error ({status}): {response_body}",
260
+ ))),
261
+ _ => Err(StorageError::service_error(format!(
262
+ "Unexpected inference service response ({status}): {response_body}"
263
+ ))),
264
+ }
265
+ }
266
+ }
src/common/inference/update_requests.rs ADDED
@@ -0,0 +1,409 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ use std::collections::HashMap;
2
+
3
+ use api::rest::{Batch, BatchVectorStruct, PointStruct, PointVectors, Vector, VectorStruct};
4
+ use collection::operations::point_ops::{
5
+ BatchPersisted, BatchVectorStructPersisted, PointStructPersisted, VectorPersisted,
6
+ VectorStructPersisted,
7
+ };
8
+ use collection::operations::vector_ops::PointVectorsPersisted;
9
+ use storage::content_manager::errors::StorageError;
10
+
11
+ use crate::common::inference::batch_processing::BatchAccum;
12
+ use crate::common::inference::infer_processing::BatchAccumInferred;
13
+ use crate::common::inference::service::{InferenceData, InferenceType};
14
+
15
+ pub async fn convert_point_struct(
16
+ point_structs: Vec<PointStruct>,
17
+ inference_type: InferenceType,
18
+ ) -> Result<Vec<PointStructPersisted>, StorageError> {
19
+ let mut batch_accum = BatchAccum::new();
20
+
21
+ for point_struct in &point_structs {
22
+ match &point_struct.vector {
23
+ VectorStruct::Named(named) => {
24
+ for vector in named.values() {
25
+ match vector {
26
+ Vector::Document(doc) => {
27
+ batch_accum.add(InferenceData::Document(doc.clone()))
28
+ }
29
+ Vector::Image(img) => batch_accum.add(InferenceData::Image(img.clone())),
30
+ Vector::Object(obj) => batch_accum.add(InferenceData::Object(obj.clone())),
31
+ Vector::Dense(_) | Vector::Sparse(_) | Vector::MultiDense(_) => {}
32
+ }
33
+ }
34
+ }
35
+ VectorStruct::Document(doc) => batch_accum.add(InferenceData::Document(doc.clone())),
36
+ VectorStruct::Image(img) => batch_accum.add(InferenceData::Image(img.clone())),
37
+ VectorStruct::Object(obj) => batch_accum.add(InferenceData::Object(obj.clone())),
38
+ VectorStruct::MultiDense(_) | VectorStruct::Single(_) => {}
39
+ }
40
+ }
41
+
42
+ let inferred = if !batch_accum.objects.is_empty() {
43
+ Some(BatchAccumInferred::from_batch_accum(batch_accum, inference_type).await?)
44
+ } else {
45
+ None
46
+ };
47
+
48
+ let mut converted_points: Vec<PointStructPersisted> = Vec::new();
49
+ for point_struct in point_structs {
50
+ let PointStruct {
51
+ id,
52
+ vector,
53
+ payload,
54
+ } = point_struct;
55
+
56
+ let converted_vector_struct = match vector {
57
+ VectorStruct::Single(single) => VectorStructPersisted::Single(single),
58
+ VectorStruct::MultiDense(multi) => VectorStructPersisted::MultiDense(multi),
59
+ VectorStruct::Named(named) => {
60
+ let mut named_vectors = HashMap::new();
61
+ for (name, vector) in named {
62
+ let converted_vector = match &inferred {
63
+ Some(inferred) => convert_vector_with_inferred(vector, inferred)?,
64
+ None => match vector {
65
+ Vector::Dense(dense) => VectorPersisted::Dense(dense),
66
+ Vector::Sparse(sparse) => VectorPersisted::Sparse(sparse),
67
+ Vector::MultiDense(multi) => VectorPersisted::MultiDense(multi),
68
+ Vector::Document(_) | Vector::Image(_) | Vector::Object(_) => {
69
+ return Err(StorageError::inference_error(
70
+ "Inference required but service returned no results",
71
+ ))
72
+ }
73
+ },
74
+ };
75
+ named_vectors.insert(name, converted_vector);
76
+ }
77
+ VectorStructPersisted::Named(named_vectors)
78
+ }
79
+ VectorStruct::Document(doc) => {
80
+ let vector = match &inferred {
81
+ Some(inferred) => {
82
+ convert_vector_with_inferred(Vector::Document(doc), inferred)?
83
+ }
84
+ None => {
85
+ return Err(StorageError::inference_error(
86
+ "Inference required but service returned no results",
87
+ ))
88
+ }
89
+ };
90
+ match vector {
91
+ VectorPersisted::Dense(dense) => VectorStructPersisted::Single(dense),
92
+ VectorPersisted::Sparse(_) => {
93
+ return Err(StorageError::bad_request("Sparse vector should be named"));
94
+ }
95
+ VectorPersisted::MultiDense(multi) => VectorStructPersisted::MultiDense(multi),
96
+ }
97
+ }
98
+ VectorStruct::Image(img) => {
99
+ let vector = match &inferred {
100
+ Some(inferred) => convert_vector_with_inferred(Vector::Image(img), inferred)?,
101
+ None => {
102
+ return Err(StorageError::inference_error(
103
+ "Inference required but service returned no results",
104
+ ))
105
+ }
106
+ };
107
+ match vector {
108
+ VectorPersisted::Dense(dense) => VectorStructPersisted::Single(dense),
109
+ VectorPersisted::Sparse(_) => {
110
+ return Err(StorageError::bad_request("Sparse vector should be named"));
111
+ }
112
+ VectorPersisted::MultiDense(multi) => VectorStructPersisted::MultiDense(multi),
113
+ }
114
+ }
115
+ VectorStruct::Object(obj) => {
116
+ let vector = match &inferred {
117
+ Some(inferred) => convert_vector_with_inferred(Vector::Object(obj), inferred)?,
118
+ None => {
119
+ return Err(StorageError::inference_error(
120
+ "Inference required but service returned no results",
121
+ ))
122
+ }
123
+ };
124
+ match vector {
125
+ VectorPersisted::Dense(dense) => VectorStructPersisted::Single(dense),
126
+ VectorPersisted::Sparse(_) => {
127
+ return Err(StorageError::bad_request("Sparse vector should be named"));
128
+ }
129
+ VectorPersisted::MultiDense(multi) => VectorStructPersisted::MultiDense(multi),
130
+ }
131
+ }
132
+ };
133
+
134
+ let converted = PointStructPersisted {
135
+ id,
136
+ vector: converted_vector_struct,
137
+ payload,
138
+ };
139
+
140
+ converted_points.push(converted);
141
+ }
142
+
143
+ Ok(converted_points)
144
+ }
145
+
146
+ pub async fn convert_batch(batch: Batch) -> Result<BatchPersisted, StorageError> {
147
+ let Batch {
148
+ ids,
149
+ vectors,
150
+ payloads,
151
+ } = batch;
152
+
153
+ let batch_persisted = BatchPersisted {
154
+ ids,
155
+ vectors: match vectors {
156
+ BatchVectorStruct::Single(single) => BatchVectorStructPersisted::Single(single),
157
+ BatchVectorStruct::MultiDense(multi) => BatchVectorStructPersisted::MultiDense(multi),
158
+ BatchVectorStruct::Named(named) => {
159
+ let mut named_vectors = HashMap::new();
160
+
161
+ for (name, vectors) in named {
162
+ let converted_vectors = convert_vectors(vectors, InferenceType::Update).await?;
163
+ named_vectors.insert(name, converted_vectors);
164
+ }
165
+
166
+ BatchVectorStructPersisted::Named(named_vectors)
167
+ }
168
+ BatchVectorStruct::Document(_) => {
169
+ return Err(StorageError::inference_error(
170
+ "Document processing is not supported in batch operations.",
171
+ ))
172
+ }
173
+ BatchVectorStruct::Image(_) => {
174
+ return Err(StorageError::inference_error(
175
+ "Image processing is not supported in batch operations.",
176
+ ))
177
+ }
178
+ BatchVectorStruct::Object(_) => {
179
+ return Err(StorageError::inference_error(
180
+ "Object processing is not supported in batch operations.",
181
+ ))
182
+ }
183
+ },
184
+ payloads,
185
+ };
186
+
187
+ Ok(batch_persisted)
188
+ }
189
+
190
/// Convert REST point-vector updates into their persisted form, batching
/// inference for any Document/Image/Object values in named vectors.
///
/// Note: only `VectorStruct::Named` may carry inference inputs here;
/// top-level Document/Image/Object structs are rejected below.
///
/// # Errors
/// Fails if inference is required but unavailable, or on unsupported
/// top-level Document/Image/Object vector structs.
pub async fn convert_point_vectors(
    point_vectors_list: Vec<PointVectors>,
    inference_type: InferenceType,
) -> Result<Vec<PointVectorsPersisted>, StorageError> {
    let mut converted_point_vectors = Vec::new();
    let mut batch_accum = BatchAccum::new();

    // Phase 1: collect all inference inputs across all updates so the
    // inference service is called at most once.
    for point_vectors in &point_vectors_list {
        if let VectorStruct::Named(named) = &point_vectors.vector {
            for vector in named.values() {
                match vector {
                    Vector::Document(doc) => batch_accum.add(InferenceData::Document(doc.clone())),
                    Vector::Image(img) => batch_accum.add(InferenceData::Image(img.clone())),
                    Vector::Object(obj) => batch_accum.add(InferenceData::Object(obj.clone())),
                    Vector::Dense(_) | Vector::Sparse(_) | Vector::MultiDense(_) => {}
                }
            }
        }
    }

    // Only contact the inference service when something needs embedding.
    let inferred = if !batch_accum.objects.is_empty() {
        Some(BatchAccumInferred::from_batch_accum(batch_accum, inference_type).await?)
    } else {
        None
    };

    // Phase 2: convert each update, substituting inferred vectors.
    for point_vectors in point_vectors_list {
        let PointVectors { id, vector } = point_vectors;

        let converted_vector = match vector {
            VectorStruct::Single(dense) => VectorStructPersisted::Single(dense),
            VectorStruct::MultiDense(multi) => VectorStructPersisted::MultiDense(multi),
            VectorStruct::Named(named) => {
                let mut converted = HashMap::new();

                for (name, vec) in named {
                    let converted_vec = match &inferred {
                        Some(inferred) => convert_vector_with_inferred(vec, inferred)?,
                        // No inferred batch: only plain vectors are valid;
                        // inference inputs reaching here mean the service
                        // produced nothing.
                        None => match vec {
                            Vector::Dense(dense) => VectorPersisted::Dense(dense),
                            Vector::Sparse(sparse) => VectorPersisted::Sparse(sparse),
                            Vector::MultiDense(multi) => VectorPersisted::MultiDense(multi),
                            Vector::Document(_) | Vector::Image(_) | Vector::Object(_) => {
                                return Err(StorageError::inference_error(
                                    "Inference required but service returned no results",
                                ))
                            }
                        },
                    };
                    converted.insert(name, converted_vec);
                }

                VectorStructPersisted::Named(converted)
            }
            VectorStruct::Document(_) => {
                return Err(StorageError::inference_error(
                    "Document processing is not supported for point vectors.",
                ))
            }
            VectorStruct::Image(_) => {
                return Err(StorageError::inference_error(
                    "Image processing is not supported for point vectors.",
                ))
            }
            VectorStruct::Object(_) => {
                return Err(StorageError::inference_error(
                    "Object processing is not supported for point vectors.",
                ))
            }
        };

        let converted_point_vector = PointVectorsPersisted {
            id,
            vector: converted_vector,
        };

        converted_point_vectors.push(converted_point_vector);
    }

    Ok(converted_point_vectors)
}
271
+
272
+ fn convert_point_struct_with_inferred(
273
+ point_structs: Vec<PointStruct>,
274
+ inferred: &BatchAccumInferred,
275
+ ) -> Result<Vec<PointStructPersisted>, StorageError> {
276
+ point_structs
277
+ .into_iter()
278
+ .map(|point_struct| {
279
+ let PointStruct {
280
+ id,
281
+ vector,
282
+ payload,
283
+ } = point_struct;
284
+ let converted_vector_struct = match vector {
285
+ VectorStruct::Single(single) => VectorStructPersisted::Single(single),
286
+ VectorStruct::MultiDense(multi) => VectorStructPersisted::MultiDense(multi),
287
+ VectorStruct::Named(named) => {
288
+ let mut named_vectors = HashMap::new();
289
+ for (name, vector) in named {
290
+ let converted_vector = convert_vector_with_inferred(vector, inferred)?;
291
+ named_vectors.insert(name, converted_vector);
292
+ }
293
+ VectorStructPersisted::Named(named_vectors)
294
+ }
295
+ VectorStruct::Document(doc) => {
296
+ let vector = convert_vector_with_inferred(Vector::Document(doc), inferred)?;
297
+ match vector {
298
+ VectorPersisted::Dense(dense) => VectorStructPersisted::Single(dense),
299
+ VectorPersisted::Sparse(_) => {
300
+ return Err(StorageError::bad_request("Sparse vector should be named"))
301
+ }
302
+ VectorPersisted::MultiDense(multi) => {
303
+ VectorStructPersisted::MultiDense(multi)
304
+ }
305
+ }
306
+ }
307
+ VectorStruct::Image(img) => {
308
+ let vector = convert_vector_with_inferred(Vector::Image(img), inferred)?;
309
+ match vector {
310
+ VectorPersisted::Dense(dense) => VectorStructPersisted::Single(dense),
311
+ VectorPersisted::Sparse(_) => {
312
+ return Err(StorageError::bad_request("Sparse vector should be named"))
313
+ }
314
+ VectorPersisted::MultiDense(multi) => {
315
+ VectorStructPersisted::MultiDense(multi)
316
+ }
317
+ }
318
+ }
319
+ VectorStruct::Object(obj) => {
320
+ let vector = convert_vector_with_inferred(Vector::Object(obj), inferred)?;
321
+ match vector {
322
+ VectorPersisted::Dense(dense) => VectorStructPersisted::Single(dense),
323
+ VectorPersisted::Sparse(_) => {
324
+ return Err(StorageError::bad_request("Sparse vector should be named"))
325
+ }
326
+ VectorPersisted::MultiDense(multi) => {
327
+ VectorStructPersisted::MultiDense(multi)
328
+ }
329
+ }
330
+ }
331
+ };
332
+
333
+ Ok(PointStructPersisted {
334
+ id,
335
+ vector: converted_vector_struct,
336
+ payload,
337
+ })
338
+ })
339
+ .collect()
340
+ }
341
+
342
+ pub async fn convert_vectors(
343
+ vectors: Vec<Vector>,
344
+ inference_type: InferenceType,
345
+ ) -> Result<Vec<VectorPersisted>, StorageError> {
346
+ let mut batch_accum = BatchAccum::new();
347
+ for vector in &vectors {
348
+ match vector {
349
+ Vector::Document(doc) => batch_accum.add(InferenceData::Document(doc.clone())),
350
+ Vector::Image(img) => batch_accum.add(InferenceData::Image(img.clone())),
351
+ Vector::Object(obj) => batch_accum.add(InferenceData::Object(obj.clone())),
352
+ Vector::Dense(_) | Vector::Sparse(_) | Vector::MultiDense(_) => {}
353
+ }
354
+ }
355
+
356
+ let inferred = if !batch_accum.objects.is_empty() {
357
+ Some(BatchAccumInferred::from_batch_accum(batch_accum, inference_type).await?)
358
+ } else {
359
+ None
360
+ };
361
+
362
+ vectors
363
+ .into_iter()
364
+ .map(|vector| match &inferred {
365
+ Some(inferred) => convert_vector_with_inferred(vector, inferred),
366
+ None => match vector {
367
+ Vector::Dense(dense) => Ok(VectorPersisted::Dense(dense)),
368
+ Vector::Sparse(sparse) => Ok(VectorPersisted::Sparse(sparse)),
369
+ Vector::MultiDense(multi) => Ok(VectorPersisted::MultiDense(multi)),
370
+ Vector::Document(_) | Vector::Image(_) | Vector::Object(_) => {
371
+ Err(StorageError::inference_error(
372
+ "Inference required but service returned no results",
373
+ ))
374
+ }
375
+ },
376
+ })
377
+ .collect()
378
+ }
379
+
380
+ fn convert_vector_with_inferred(
381
+ vector: Vector,
382
+ inferred: &BatchAccumInferred,
383
+ ) -> Result<VectorPersisted, StorageError> {
384
+ match vector {
385
+ Vector::Dense(dense) => Ok(VectorPersisted::Dense(dense)),
386
+ Vector::Sparse(sparse) => Ok(VectorPersisted::Sparse(sparse)),
387
+ Vector::MultiDense(multi) => Ok(VectorPersisted::MultiDense(multi)),
388
+ Vector::Document(doc) => {
389
+ let data = InferenceData::Document(doc);
390
+ inferred.get_vector(&data).cloned().ok_or_else(|| {
391
+ StorageError::inference_error("Missing inferred vector for document")
392
+ })
393
+ }
394
+ Vector::Image(img) => {
395
+ let data = InferenceData::Image(img);
396
+ inferred
397
+ .get_vector(&data)
398
+ .cloned()
399
+ .ok_or_else(|| StorageError::inference_error("Missing inferred vector for image"))
400
+ }
401
+ Vector::Object(obj) => {
402
+ let data = InferenceData::Object(obj);
403
+ inferred
404
+ .get_vector(&data)
405
+ .cloned()
406
+ .ok_or_else(|| StorageError::inference_error("Missing inferred vector for object"))
407
+ }
408
+ }
409
+ }
src/common/metrics.rs ADDED
@@ -0,0 +1,505 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ use prometheus::proto::{Counter, Gauge, LabelPair, Metric, MetricFamily, MetricType};
2
+ use prometheus::TextEncoder;
3
+ use segment::common::operation_time_statistics::OperationDurationStatistics;
4
+
5
+ use crate::common::telemetry::TelemetryData;
6
+ use crate::common::telemetry_ops::app_telemetry::{AppBuildTelemetry, AppFeaturesTelemetry};
7
+ use crate::common::telemetry_ops::cluster_telemetry::{ClusterStatusTelemetry, ClusterTelemetry};
8
+ use crate::common::telemetry_ops::collections_telemetry::{
9
+ CollectionTelemetryEnum, CollectionsTelemetry,
10
+ };
11
+ use crate::common::telemetry_ops::memory_telemetry::MemoryTelemetry;
12
+ use crate::common::telemetry_ops::requests_telemetry::{
13
+ GrpcTelemetry, RequestsTelemetry, WebApiTelemetry,
14
+ };
15
+
16
/// Whitelist for REST endpoints in metrics output.
///
/// Contains selection of search, recommend, scroll and upsert endpoints.
///
/// This array *must* be sorted.
// Lookups use `binary_search`, hence the sort requirement; the unit test at
// the bottom of this file enforces the ordering at build time.
const REST_ENDPOINT_WHITELIST: &[&str] = &[
    "/collections/{name}/index",
    "/collections/{name}/points",
    "/collections/{name}/points/batch",
    "/collections/{name}/points/count",
    "/collections/{name}/points/delete",
    "/collections/{name}/points/discover",
    "/collections/{name}/points/discover/batch",
    "/collections/{name}/points/facet",
    "/collections/{name}/points/payload",
    "/collections/{name}/points/payload/clear",
    "/collections/{name}/points/payload/delete",
    "/collections/{name}/points/query",
    "/collections/{name}/points/query/batch",
    "/collections/{name}/points/query/groups",
    "/collections/{name}/points/recommend",
    "/collections/{name}/points/recommend/batch",
    "/collections/{name}/points/recommend/groups",
    "/collections/{name}/points/scroll",
    "/collections/{name}/points/search",
    "/collections/{name}/points/search/batch",
    "/collections/{name}/points/search/groups",
    "/collections/{name}/points/search/matrix/offsets",
    "/collections/{name}/points/search/matrix/pairs",
    "/collections/{name}/points/vectors",
    "/collections/{name}/points/vectors/delete",
];
48
+
49
/// Whitelist for GRPC endpoints in metrics output.
///
/// Contains selection of search, recommend, scroll and upsert endpoints.
///
/// This array *must* be sorted.
// Lookups use `binary_search`, hence the sort requirement; the unit test at
// the bottom of this file enforces the ordering at build time.
const GRPC_ENDPOINT_WHITELIST: &[&str] = &[
    "/qdrant.Points/ClearPayload",
    "/qdrant.Points/Count",
    "/qdrant.Points/Delete",
    "/qdrant.Points/DeletePayload",
    "/qdrant.Points/Discover",
    "/qdrant.Points/DiscoverBatch",
    "/qdrant.Points/Facet",
    "/qdrant.Points/Get",
    "/qdrant.Points/OverwritePayload",
    "/qdrant.Points/Query",
    "/qdrant.Points/QueryBatch",
    "/qdrant.Points/QueryGroups",
    "/qdrant.Points/Recommend",
    "/qdrant.Points/RecommendBatch",
    "/qdrant.Points/RecommendGroups",
    "/qdrant.Points/Scroll",
    "/qdrant.Points/Search",
    "/qdrant.Points/SearchBatch",
    "/qdrant.Points/SearchGroups",
    "/qdrant.Points/SetPayload",
    "/qdrant.Points/UpdateBatch",
    "/qdrant.Points/UpdateVectors",
    "/qdrant.Points/Upsert",
];
79
+
80
/// For REST requests, only report timings when having this HTTP response status.
// Used by `WebApiTelemetry::add_metrics`: timing gauges/histograms are only
// recorded for responses with this status; counters are recorded for all.
const REST_TIMINGS_FOR_STATUS: u16 = 200;
82
+
83
+ /// Encapsulates metrics data in Prometheus format.
84
+ pub struct MetricsData {
85
+ metrics: Vec<MetricFamily>,
86
+ }
87
+
88
+ impl MetricsData {
89
+ pub fn format_metrics(&self) -> String {
90
+ TextEncoder::new().encode_to_string(&self.metrics).unwrap()
91
+ }
92
+ }
93
+
94
+ impl From<TelemetryData> for MetricsData {
95
+ fn from(telemetry_data: TelemetryData) -> Self {
96
+ let mut metrics = vec![];
97
+ telemetry_data.add_metrics(&mut metrics);
98
+ Self { metrics }
99
+ }
100
+ }
101
+
102
/// Implemented by every telemetry section that can contribute Prometheus
/// metric families to the `/metrics` output.
trait MetricsProvider {
    /// Add metrics definitions for this.
    fn add_metrics(&self, metrics: &mut Vec<MetricFamily>);
}
106
+
107
+ impl MetricsProvider for TelemetryData {
108
+ fn add_metrics(&self, metrics: &mut Vec<MetricFamily>) {
109
+ self.app.add_metrics(metrics);
110
+ self.collections.add_metrics(metrics);
111
+ self.cluster.add_metrics(metrics);
112
+ self.requests.add_metrics(metrics);
113
+ if let Some(mem) = &self.memory {
114
+ mem.add_metrics(metrics);
115
+ }
116
+ }
117
+ }
118
+
119
+ impl MetricsProvider for AppBuildTelemetry {
120
+ fn add_metrics(&self, metrics: &mut Vec<MetricFamily>) {
121
+ metrics.push(metric_family(
122
+ "app_info",
123
+ "information about qdrant server",
124
+ MetricType::GAUGE,
125
+ vec![gauge(
126
+ 1.0,
127
+ &[("name", &self.name), ("version", &self.version)],
128
+ )],
129
+ ));
130
+ self.features.iter().for_each(|f| f.add_metrics(metrics));
131
+ }
132
+ }
133
+
134
+ impl MetricsProvider for AppFeaturesTelemetry {
135
+ fn add_metrics(&self, metrics: &mut Vec<MetricFamily>) {
136
+ metrics.push(metric_family(
137
+ "app_status_recovery_mode",
138
+ "features enabled in qdrant server",
139
+ MetricType::GAUGE,
140
+ vec![gauge(if self.recovery_mode { 1.0 } else { 0.0 }, &[])],
141
+ ))
142
+ }
143
+ }
144
+
145
+ impl MetricsProvider for CollectionsTelemetry {
146
+ fn add_metrics(&self, metrics: &mut Vec<MetricFamily>) {
147
+ let vector_count = self
148
+ .collections
149
+ .iter()
150
+ .flatten()
151
+ .map(|p| match p {
152
+ CollectionTelemetryEnum::Aggregated(a) => a.vectors,
153
+ CollectionTelemetryEnum::Full(c) => c.count_vectors(),
154
+ })
155
+ .sum::<usize>();
156
+ metrics.push(metric_family(
157
+ "collections_total",
158
+ "number of collections",
159
+ MetricType::GAUGE,
160
+ vec![gauge(self.number_of_collections as f64, &[])],
161
+ ));
162
+ metrics.push(metric_family(
163
+ "collections_vector_total",
164
+ "total number of vectors in all collections",
165
+ MetricType::GAUGE,
166
+ vec![gauge(vector_count as f64, &[])],
167
+ ));
168
+ }
169
+ }
170
+
171
+ impl MetricsProvider for ClusterTelemetry {
172
+ fn add_metrics(&self, metrics: &mut Vec<MetricFamily>) {
173
+ let ClusterTelemetry {
174
+ enabled,
175
+ status,
176
+ config: _,
177
+ peers: _,
178
+ metadata: _,
179
+ } = self;
180
+
181
+ metrics.push(metric_family(
182
+ "cluster_enabled",
183
+ "is cluster support enabled",
184
+ MetricType::GAUGE,
185
+ vec![gauge(if *enabled { 1.0 } else { 0.0 }, &[])],
186
+ ));
187
+
188
+ if let Some(ref status) = status {
189
+ status.add_metrics(metrics);
190
+ }
191
+ }
192
+ }
193
+
194
+ impl MetricsProvider for ClusterStatusTelemetry {
195
+ fn add_metrics(&self, metrics: &mut Vec<MetricFamily>) {
196
+ metrics.push(metric_family(
197
+ "cluster_peers_total",
198
+ "total number of cluster peers",
199
+ MetricType::GAUGE,
200
+ vec![gauge(self.number_of_peers as f64, &[])],
201
+ ));
202
+ metrics.push(metric_family(
203
+ "cluster_term",
204
+ "current cluster term",
205
+ MetricType::COUNTER,
206
+ vec![counter(self.term as f64, &[])],
207
+ ));
208
+
209
+ if let Some(ref peer_id) = self.peer_id.map(|p| p.to_string()) {
210
+ metrics.push(metric_family(
211
+ "cluster_commit",
212
+ "index of last committed (finalized) operation cluster peer is aware of",
213
+ MetricType::COUNTER,
214
+ vec![counter(self.commit as f64, &[("peer_id", peer_id)])],
215
+ ));
216
+ metrics.push(metric_family(
217
+ "cluster_pending_operations_total",
218
+ "total number of pending operations for cluster peer",
219
+ MetricType::GAUGE,
220
+ vec![gauge(self.pending_operations as f64, &[])],
221
+ ));
222
+ metrics.push(metric_family(
223
+ "cluster_voter",
224
+ "is cluster peer a voter or learner",
225
+ MetricType::GAUGE,
226
+ vec![gauge(if self.is_voter { 1.0 } else { 0.0 }, &[])],
227
+ ));
228
+ }
229
+ }
230
+ }
231
+
232
impl MetricsProvider for RequestsTelemetry {
    fn add_metrics(&self, metrics: &mut Vec<MetricFamily>) {
        // REST first, then gRPC, keeping the rendered output order stable.
        self.rest.add_metrics(metrics);
        self.grpc.add_metrics(metrics);
    }
}
238
+
239
+ impl MetricsProvider for WebApiTelemetry {
240
+ fn add_metrics(&self, metrics: &mut Vec<MetricFamily>) {
241
+ let mut builder = OperationDurationMetricsBuilder::default();
242
+ for (endpoint, responses) in &self.responses {
243
+ let Some((method, endpoint)) = endpoint.split_once(' ') else {
244
+ continue;
245
+ };
246
+ // Endpoint must be whitelisted
247
+ if REST_ENDPOINT_WHITELIST.binary_search(&endpoint).is_err() {
248
+ continue;
249
+ }
250
+ for (status, stats) in responses {
251
+ builder.add(
252
+ stats,
253
+ &[
254
+ ("method", method),
255
+ ("endpoint", endpoint),
256
+ ("status", &status.to_string()),
257
+ ],
258
+ *status == REST_TIMINGS_FOR_STATUS,
259
+ );
260
+ }
261
+ }
262
+ builder.build("rest", metrics);
263
+ }
264
+ }
265
+
266
+ impl MetricsProvider for GrpcTelemetry {
267
+ fn add_metrics(&self, metrics: &mut Vec<MetricFamily>) {
268
+ let mut builder = OperationDurationMetricsBuilder::default();
269
+ for (endpoint, stats) in &self.responses {
270
+ // Endpoint must be whitelisted
271
+ if GRPC_ENDPOINT_WHITELIST
272
+ .binary_search(&endpoint.as_str())
273
+ .is_err()
274
+ {
275
+ continue;
276
+ }
277
+ builder.add(stats, &[("endpoint", endpoint.as_str())], true);
278
+ }
279
+ builder.build("grpc", metrics);
280
+ }
281
+ }
282
+
283
+ impl MetricsProvider for MemoryTelemetry {
284
+ fn add_metrics(&self, metrics: &mut Vec<MetricFamily>) {
285
+ metrics.push(metric_family(
286
+ "memory_active_bytes",
287
+ "Total number of bytes in active pages allocated by the application",
288
+ MetricType::GAUGE,
289
+ vec![gauge(self.active_bytes as f64, &[])],
290
+ ));
291
+ metrics.push(metric_family(
292
+ "memory_allocated_bytes",
293
+ "Total number of bytes allocated by the application",
294
+ MetricType::GAUGE,
295
+ vec![gauge(self.allocated_bytes as f64, &[])],
296
+ ));
297
+ metrics.push(metric_family(
298
+ "memory_metadata_bytes",
299
+ "Total number of bytes dedicated to metadata",
300
+ MetricType::GAUGE,
301
+ vec![gauge(self.metadata_bytes as f64, &[])],
302
+ ));
303
+ metrics.push(metric_family(
304
+ "memory_resident_bytes",
305
+ "Maximum number of bytes in physically resident data pages mapped",
306
+ MetricType::GAUGE,
307
+ vec![gauge(self.resident_bytes as f64, &[])],
308
+ ));
309
+ metrics.push(metric_family(
310
+ "memory_retained_bytes",
311
+ "Total number of bytes in virtual memory mappings",
312
+ MetricType::GAUGE,
313
+ vec![gauge(self.retained_bytes as f64, &[])],
314
+ ));
315
+ }
316
+ }
317
+
318
/// A helper struct to build a vector of [`MetricFamily`] out of a collection of
/// [`OperationDurationStatistics`].
#[derive(Default)]
struct OperationDurationMetricsBuilder {
    // Response counters; one `Metric` (i.e. one label set) per `add` call.
    total: Vec<Metric>,
    fail_total: Vec<Metric>,
    // Timing metrics in seconds; only populated when `add` is called with
    // `add_timings == true`.
    avg_secs: Vec<Metric>,
    min_secs: Vec<Metric>,
    max_secs: Vec<Metric>,
    duration_histogram_secs: Vec<Metric>,
}
329
+
330
+ impl OperationDurationMetricsBuilder {
331
+ /// Add metrics for the provided statistics.
332
+ /// If `add_timings` is `false`, only the total and fail_total counters will be added.
333
+ pub fn add(
334
+ &mut self,
335
+ stat: &OperationDurationStatistics,
336
+ labels: &[(&str, &str)],
337
+ add_timings: bool,
338
+ ) {
339
+ self.total.push(counter(stat.count as f64, labels));
340
+ self.fail_total
341
+ .push(counter(stat.fail_count as f64, labels));
342
+
343
+ if !add_timings {
344
+ return;
345
+ }
346
+
347
+ self.avg_secs.push(gauge(
348
+ f64::from(stat.avg_duration_micros.unwrap_or(0.0)) / 1_000_000.0,
349
+ labels,
350
+ ));
351
+ self.min_secs.push(gauge(
352
+ f64::from(stat.min_duration_micros.unwrap_or(0.0)) / 1_000_000.0,
353
+ labels,
354
+ ));
355
+ self.max_secs.push(gauge(
356
+ f64::from(stat.max_duration_micros.unwrap_or(0.0)) / 1_000_000.0,
357
+ labels,
358
+ ));
359
+ self.duration_histogram_secs.push(histogram(
360
+ stat.count as u64,
361
+ stat.total_duration_micros as f64 / 1_000_000.0,
362
+ &stat
363
+ .duration_micros_histogram
364
+ .iter()
365
+ .map(|&(b, c)| (f64::from(b) / 1_000_000.0, c as u64))
366
+ .collect::<Vec<_>>(),
367
+ labels,
368
+ ));
369
+ }
370
+
371
+ /// Build metrics and add them to the provided vector.
372
+ pub fn build(self, prefix: &str, metrics: &mut Vec<MetricFamily>) {
373
+ if !self.total.is_empty() {
374
+ metrics.push(metric_family(
375
+ &format!("{prefix}_responses_total"),
376
+ "total number of responses",
377
+ MetricType::COUNTER,
378
+ self.total,
379
+ ));
380
+ }
381
+ if !self.fail_total.is_empty() {
382
+ metrics.push(metric_family(
383
+ &format!("{prefix}_responses_fail_total"),
384
+ "total number of failed responses",
385
+ MetricType::COUNTER,
386
+ self.fail_total,
387
+ ));
388
+ }
389
+ if !self.avg_secs.is_empty() {
390
+ metrics.push(metric_family(
391
+ &format!("{prefix}_responses_avg_duration_seconds"),
392
+ "average response duration",
393
+ MetricType::GAUGE,
394
+ self.avg_secs,
395
+ ));
396
+ }
397
+ if !self.min_secs.is_empty() {
398
+ metrics.push(metric_family(
399
+ &format!("{prefix}_responses_min_duration_seconds"),
400
+ "minimum response duration",
401
+ MetricType::GAUGE,
402
+ self.min_secs,
403
+ ));
404
+ }
405
+ if !self.max_secs.is_empty() {
406
+ metrics.push(metric_family(
407
+ &format!("{prefix}_responses_max_duration_seconds"),
408
+ "maximum response duration",
409
+ MetricType::GAUGE,
410
+ self.max_secs,
411
+ ));
412
+ }
413
+ if !self.duration_histogram_secs.is_empty() {
414
+ metrics.push(metric_family(
415
+ &format!("{prefix}_responses_duration_seconds"),
416
+ "response duration histogram",
417
+ MetricType::HISTOGRAM,
418
+ self.duration_histogram_secs,
419
+ ));
420
+ }
421
+ }
422
+ }
423
+
424
+ fn metric_family(name: &str, help: &str, r#type: MetricType, metrics: Vec<Metric>) -> MetricFamily {
425
+ let mut metric_family = MetricFamily::default();
426
+ metric_family.set_name(name.into());
427
+ metric_family.set_help(help.into());
428
+ metric_family.set_field_type(r#type);
429
+ metric_family.set_metric(metrics);
430
+ metric_family
431
+ }
432
+
433
+ fn counter(value: f64, labels: &[(&str, &str)]) -> Metric {
434
+ let mut metric = Metric::default();
435
+ metric.set_label(labels.iter().map(|(n, v)| label_pair(n, v)).collect());
436
+ metric.set_counter({
437
+ let mut counter = Counter::default();
438
+ counter.set_value(value);
439
+ counter
440
+ });
441
+ metric
442
+ }
443
+
444
+ fn gauge(value: f64, labels: &[(&str, &str)]) -> Metric {
445
+ let mut metric = Metric::default();
446
+ metric.set_label(labels.iter().map(|(n, v)| label_pair(n, v)).collect());
447
+ metric.set_gauge({
448
+ let mut gauge = Gauge::default();
449
+ gauge.set_value(value);
450
+ gauge
451
+ });
452
+ metric
453
+ }
454
+
455
+ fn histogram(
456
+ sample_count: u64,
457
+ sample_sum: f64,
458
+ buckets: &[(f64, u64)],
459
+ labels: &[(&str, &str)],
460
+ ) -> Metric {
461
+ let mut metric = Metric::default();
462
+ metric.set_label(labels.iter().map(|(n, v)| label_pair(n, v)).collect());
463
+ metric.set_histogram({
464
+ let mut histogram = prometheus::proto::Histogram::default();
465
+ histogram.set_sample_count(sample_count);
466
+ histogram.set_sample_sum(sample_sum);
467
+ histogram.set_bucket(
468
+ buckets
469
+ .iter()
470
+ .map(|&(upper_bound, cumulative_count)| {
471
+ let mut bucket = prometheus::proto::Bucket::default();
472
+ bucket.set_cumulative_count(cumulative_count);
473
+ bucket.set_upper_bound(upper_bound);
474
+ bucket
475
+ })
476
+ .collect(),
477
+ );
478
+ histogram
479
+ });
480
+ metric
481
+ }
482
+
483
+ fn label_pair(name: &str, value: &str) -> LabelPair {
484
+ let mut label = LabelPair::default();
485
+ label.set_name(name.into());
486
+ label.set_value(value.into());
487
+ label
488
+ }
489
+
490
#[cfg(test)]
mod tests {
    /// The whitelists are searched with `binary_search`, so they must stay
    /// sorted; this test catches edits that break the ordering.
    #[test]
    fn test_endpoint_whitelists_sorted() {
        use super::{GRPC_ENDPOINT_WHITELIST, REST_ENDPOINT_WHITELIST};

        fn is_sorted(list: &[&str]) -> bool {
            list.windows(2).all(|pair| pair[0] <= pair[1])
        }

        assert!(
            is_sorted(REST_ENDPOINT_WHITELIST),
            "REST_ENDPOINT_WHITELIST must be sorted in code to allow binary search"
        );
        assert!(
            is_sorted(GRPC_ENDPOINT_WHITELIST),
            "GRPC_ENDPOINT_WHITELIST must be sorted in code to allow binary search"
        );
    }
}
src/common/mod.rs ADDED
@@ -0,0 +1,31 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
//! Common modules shared across the Qdrant server binaries.

#[allow(dead_code)] // May contain functions used in different binaries. Not actually dead
pub mod collections;
#[allow(dead_code)] // May contain functions used in different binaries. Not actually dead
pub mod error_reporting;
#[allow(dead_code)]
pub mod health;
#[allow(dead_code)] // May contain functions used in different binaries. Not actually dead
pub mod helpers;
pub mod http_client;
pub mod metrics;
#[allow(dead_code)] // May contain functions used in different binaries. Not actually dead
pub mod points;
pub mod snapshots;
#[allow(dead_code)] // May contain functions used in different binaries. Not actually dead
pub mod stacktrace;
#[allow(dead_code)] // May contain functions used in different binaries. Not actually dead
pub mod telemetry;
pub mod telemetry_ops;
#[allow(dead_code)] // May contain functions used in different binaries. Not actually dead
pub mod telemetry_reporting;

pub mod auth;

pub mod strings;

pub mod debugger;

#[allow(dead_code)] // May contain functions used in different binaries. Not actually dead
pub mod inference;

pub mod pyroscope_state;
src/common/points.rs ADDED
@@ -0,0 +1,1175 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ use std::sync::Arc;
2
+ use std::time::Duration;
3
+
4
+ use api::rest::schema::{PointInsertOperations, PointsBatch, PointsList};
5
+ use api::rest::{SearchGroupsRequestInternal, ShardKeySelector, UpdateVectors};
6
+ use collection::collection::distance_matrix::{
7
+ CollectionSearchMatrixRequest, CollectionSearchMatrixResponse,
8
+ };
9
+ use collection::collection::Collection;
10
+ use collection::common::batching::batch_requests;
11
+ use collection::grouping::group_by::GroupRequest;
12
+ use collection::operations::consistency_params::ReadConsistency;
13
+ use collection::operations::payload_ops::{
14
+ DeletePayload, DeletePayloadOp, PayloadOps, SetPayload, SetPayloadOp,
15
+ };
16
+ use collection::operations::point_ops::{
17
+ FilterSelector, PointIdsList, PointInsertOperationsInternal, PointOperations, PointsSelector,
18
+ WriteOrdering,
19
+ };
20
+ use collection::operations::shard_selector_internal::ShardSelectorInternal;
21
+ use collection::operations::types::{
22
+ CollectionError, CoreSearchRequest, CoreSearchRequestBatch, CountRequestInternal, CountResult,
23
+ DiscoverRequestBatch, GroupsResult, PointRequestInternal, RecommendGroupsRequestInternal,
24
+ RecordInternal, ScrollRequestInternal, ScrollResult, UpdateResult,
25
+ };
26
+ use collection::operations::universal_query::collection_query::{
27
+ CollectionQueryGroupsRequest, CollectionQueryRequest,
28
+ };
29
+ use collection::operations::vector_ops::{DeleteVectors, UpdateVectorsOp, VectorOperations};
30
+ use collection::operations::verification::{
31
+ new_unchecked_verification_pass, StrictModeVerification,
32
+ };
33
+ use collection::operations::{
34
+ ClockTag, CollectionUpdateOperations, CreateIndex, FieldIndexOperations, OperationWithClockTag,
35
+ };
36
+ use collection::shards::shard::ShardId;
37
+ use common::counter::hardware_accumulator::HwMeasurementAcc;
38
+ use schemars::JsonSchema;
39
+ use segment::json_path::JsonPath;
40
+ use segment::types::{PayloadFieldSchema, PayloadKeyType, ScoredPoint, StrictModeConfig};
41
+ use serde::{Deserialize, Serialize};
42
+ use storage::content_manager::collection_meta_ops::{
43
+ CollectionMetaOperations, CreatePayloadIndex, DropPayloadIndex,
44
+ };
45
+ use storage::content_manager::errors::StorageError;
46
+ use storage::content_manager::toc::TableOfContent;
47
+ use storage::dispatcher::Dispatcher;
48
+ use storage::rbac::Access;
49
+ use validator::Validate;
50
+
51
+ use crate::common::inference::service::InferenceType;
52
+ use crate::common::inference::update_requests::{
53
+ convert_batch, convert_point_struct, convert_point_vectors,
54
+ };
55
+
56
/// Request body for creating a payload field index.
#[derive(Debug, Deserialize, Serialize, JsonSchema, Validate)]
pub struct CreateFieldIndex {
    pub field_name: PayloadKeyType,
    // `field_type` is accepted as a legacy alias for `field_schema`.
    #[serde(alias = "field_type")]
    pub field_schema: Option<PayloadFieldSchema>,
}

// The wrappers below each hold exactly one field whose name doubles as the
// operation key in the (untagged) `UpdateOperation` enum, e.g. `{"upsert": …}`.

/// `{"upsert": …}` entry of a batch update.
#[derive(Deserialize, Serialize, JsonSchema, Validate)]
pub struct UpsertOperation {
    #[validate(nested)]
    upsert: PointInsertOperations,
}

/// `{"delete": …}` entry of a batch update.
#[derive(Deserialize, Serialize, JsonSchema, Validate)]
pub struct DeleteOperation {
    #[validate(nested)]
    delete: PointsSelector,
}

/// `{"set_payload": …}` entry of a batch update.
#[derive(Deserialize, Serialize, JsonSchema, Validate)]
pub struct SetPayloadOperation {
    #[validate(nested)]
    set_payload: SetPayload,
}

/// `{"overwrite_payload": …}` entry of a batch update.
#[derive(Deserialize, Serialize, JsonSchema, Validate)]
pub struct OverwritePayloadOperation {
    #[validate(nested)]
    overwrite_payload: SetPayload,
}

/// `{"delete_payload": …}` entry of a batch update.
#[derive(Deserialize, Serialize, JsonSchema, Validate)]
pub struct DeletePayloadOperation {
    #[validate(nested)]
    delete_payload: DeletePayload,
}

/// `{"clear_payload": …}` entry of a batch update.
#[derive(Deserialize, Serialize, JsonSchema, Validate)]
pub struct ClearPayloadOperation {
    #[validate(nested)]
    clear_payload: PointsSelector,
}

/// `{"update_vectors": …}` entry of a batch update.
#[derive(Deserialize, Serialize, JsonSchema, Validate)]
pub struct UpdateVectorsOperation {
    #[validate(nested)]
    update_vectors: UpdateVectors,
}

/// `{"delete_vectors": …}` entry of a batch update.
#[derive(Deserialize, Serialize, JsonSchema, Validate)]
pub struct DeleteVectorsOperation {
    #[validate(nested)]
    delete_vectors: DeleteVectors,
}
110
+
111
/// A single operation of a points batch update.
///
/// Serialized `untagged`: the variant is selected by the single field name of
/// the wrapper struct (`upsert`, `delete`, `set_payload`, …), not by an
/// explicit tag.
#[derive(Deserialize, Serialize, JsonSchema)]
#[serde(rename_all = "snake_case")]
#[serde(untagged)]
pub enum UpdateOperation {
    Upsert(UpsertOperation),
    Delete(DeleteOperation),
    SetPayload(SetPayloadOperation),
    OverwritePayload(OverwritePayloadOperation),
    DeletePayload(DeletePayloadOperation),
    ClearPayload(ClearPayloadOperation),
    UpdateVectors(UpdateVectorsOperation),
    DeleteVectors(DeleteVectorsOperation),
}

/// Request body of the points batch-update endpoint: an ordered list of
/// operations to apply.
#[derive(Deserialize, Serialize, JsonSchema, Validate)]
pub struct UpdateOperations {
    pub operations: Vec<UpdateOperation>,
}
129
+
130
// Manual `Validate` impl: the derive cannot see through an untagged enum,
// so each variant delegates to the `Validate` derive of its wrapper struct.
impl Validate for UpdateOperation {
    fn validate(&self) -> Result<(), validator::ValidationErrors> {
        match self {
            UpdateOperation::Upsert(op) => op.validate(),
            UpdateOperation::Delete(op) => op.validate(),
            UpdateOperation::SetPayload(op) => op.validate(),
            UpdateOperation::OverwritePayload(op) => op.validate(),
            UpdateOperation::DeletePayload(op) => op.validate(),
            UpdateOperation::ClearPayload(op) => op.validate(),
            UpdateOperation::UpdateVectors(op) => op.validate(),
            UpdateOperation::DeleteVectors(op) => op.validate(),
        }
    }
}
144
+
145
// Strict-mode checks for batch update entries.
//
// The read-oriented accessors all return `None` because these are write
// operations: limits, search params, and exactness do not apply. The real
// work happens in `check_strict_mode`, which forwards to the inner request
// of each variant that can carry a filter.
impl StrictModeVerification for UpdateOperation {
    fn query_limit(&self) -> Option<usize> {
        None
    }

    fn indexed_filter_read(&self) -> Option<&segment::types::Filter> {
        None
    }

    fn indexed_filter_write(&self) -> Option<&segment::types::Filter> {
        None
    }

    fn request_exact(&self) -> Option<bool> {
        None
    }

    fn request_search_params(&self) -> Option<&segment::types::SearchParams> {
        None
    }

    fn check_strict_mode(
        &self,
        collection: &Collection,
        strict_mode_config: &StrictModeConfig,
    ) -> Result<(), CollectionError> {
        match self {
            UpdateOperation::Delete(delete_op) => delete_op
                .delete
                .check_strict_mode(collection, strict_mode_config),
            UpdateOperation::SetPayload(set_payload) => set_payload
                .set_payload
                .check_strict_mode(collection, strict_mode_config),
            UpdateOperation::OverwritePayload(overwrite_payload) => overwrite_payload
                .overwrite_payload
                .check_strict_mode(collection, strict_mode_config),
            UpdateOperation::DeletePayload(delete_payload) => delete_payload
                .delete_payload
                .check_strict_mode(collection, strict_mode_config),
            UpdateOperation::ClearPayload(clear_payload) => clear_payload
                .clear_payload
                .check_strict_mode(collection, strict_mode_config),
            UpdateOperation::DeleteVectors(delete_op) => delete_op
                .delete_vectors
                .check_strict_mode(collection, strict_mode_config),
            // Upsert and vector updates address points by ID only — no filter
            // to verify against strict mode.
            UpdateOperation::UpdateVectors(_) | UpdateOperation::Upsert(_) => Ok(()),
        }
    }
}
194
+
195
+ /// Converts a pair of parameters into a shard selector
196
+ /// suitable for update operations.
197
+ ///
198
+ /// The key difference from selector for search operations is that
199
+ /// empty shard selector in case of update means default shard,
200
+ /// while empty shard selector in case of search means all shards.
201
+ ///
202
+ /// Parameters:
203
+ /// - shard_selection: selection of the exact shard ID, always have priority over shard_key
204
+ /// - shard_key: selection of the shard key, can be a single key or a list of keys
205
+ ///
206
+ /// Returns:
207
+ /// - ShardSelectorInternal - resolved shard selector
208
+ fn get_shard_selector_for_update(
209
+ shard_selection: Option<ShardId>,
210
+ shard_key: Option<ShardKeySelector>,
211
+ ) -> ShardSelectorInternal {
212
+ match (shard_selection, shard_key) {
213
+ (Some(shard_selection), None) => ShardSelectorInternal::ShardId(shard_selection),
214
+ (Some(shard_selection), Some(_)) => {
215
+ debug_assert!(
216
+ false,
217
+ "Shard selection and shard key are mutually exclusive"
218
+ );
219
+ ShardSelectorInternal::ShardId(shard_selection)
220
+ }
221
+ (None, Some(shard_key)) => ShardSelectorInternal::from(shard_key),
222
+ (None, None) => ShardSelectorInternal::Empty,
223
+ }
224
+ }
225
+
226
/// Upserts points into a collection.
///
/// The user-facing operation is first converted to the internal form; the
/// conversion (`convert_batch` / `convert_point_struct`) may invoke the
/// inference service for raw inference inputs. `shard_selection` takes
/// priority over the request's `shard_key`.
#[allow(clippy::too_many_arguments)]
pub async fn do_upsert_points(
    toc: Arc<TableOfContent>,
    collection_name: String,
    operation: PointInsertOperations,
    clock_tag: Option<ClockTag>,
    shard_selection: Option<ShardId>,
    wait: bool,
    ordering: WriteOrdering,
    access: Access,
) -> Result<UpdateResult, StorageError> {
    // Peel the shard key off the request and convert the payload to the
    // internal representation (batch or list form).
    let (shard_key, operation) = match operation {
        PointInsertOperations::PointsBatch(PointsBatch { batch, shard_key }) => (
            shard_key,
            PointInsertOperationsInternal::PointsBatch(convert_batch(batch).await?),
        ),
        PointInsertOperations::PointsList(PointsList { points, shard_key }) => (
            shard_key,
            PointInsertOperationsInternal::PointsList(
                convert_point_struct(points, InferenceType::Update).await?,
            ),
        ),
    };

    let collection_operation =
        CollectionUpdateOperations::PointOperation(PointOperations::UpsertPoints(operation));

    let shard_selector = get_shard_selector_for_update(shard_selection, shard_key);

    toc.update(
        &collection_name,
        OperationWithClockTag::new(collection_operation, clock_tag),
        wait,
        ordering,
        shard_selector,
        access,
    )
    .await
}
265
+
266
/// Deletes points from a collection, either by explicit IDs or by filter.
#[allow(clippy::too_many_arguments)]
pub async fn do_delete_points(
    toc: Arc<TableOfContent>,
    collection_name: String,
    points: PointsSelector,
    clock_tag: Option<ClockTag>,
    shard_selection: Option<ShardId>,
    wait: bool,
    ordering: WriteOrdering,
    access: Access,
) -> Result<UpdateResult, StorageError> {
    // The selector determines both the delete flavor and the shard key.
    let (point_operation, shard_key) = match points {
        PointsSelector::PointIdsSelector(PointIdsList { points, shard_key }) => {
            (PointOperations::DeletePoints { ids: points }, shard_key)
        }
        PointsSelector::FilterSelector(FilterSelector { filter, shard_key }) => {
            (PointOperations::DeletePointsByFilter(filter), shard_key)
        }
    };
    let collection_operation = CollectionUpdateOperations::PointOperation(point_operation);
    let shard_selector = get_shard_selector_for_update(shard_selection, shard_key);

    toc.update(
        &collection_name,
        OperationWithClockTag::new(collection_operation, clock_tag),
        wait,
        ordering,
        shard_selector,
        access,
    )
    .await
}
298
+
299
/// Updates (replaces) named vectors on existing points.
///
/// Point vectors are converted through the inference pipeline before the
/// update is submitted.
#[allow(clippy::too_many_arguments)]
pub async fn do_update_vectors(
    toc: Arc<TableOfContent>,
    collection_name: String,
    operation: UpdateVectors,
    clock_tag: Option<ClockTag>,
    shard_selection: Option<ShardId>,
    wait: bool,
    ordering: WriteOrdering,
    access: Access,
) -> Result<UpdateResult, StorageError> {
    let UpdateVectors { points, shard_key } = operation;

    // May call the inference service for raw inference inputs.
    let persisted_points = convert_point_vectors(points, InferenceType::Update).await?;

    let collection_operation = CollectionUpdateOperations::VectorOperation(
        VectorOperations::UpdateVectors(UpdateVectorsOp {
            points: persisted_points,
        }),
    );

    let shard_selector = get_shard_selector_for_update(shard_selection, shard_key);

    toc.update(
        &collection_name,
        OperationWithClockTag::new(collection_operation, clock_tag),
        wait,
        ordering,
        shard_selector,
        access,
    )
    .await
}
332
+
333
/// Deletes named vectors from points, selected by filter and/or explicit IDs.
///
/// When both `filter` and `points` are given, TWO separate updates are
/// submitted (filter first, then IDs) and only the second result is
/// returned. Errors if neither selector is provided.
#[allow(clippy::too_many_arguments)]
pub async fn do_delete_vectors(
    toc: Arc<TableOfContent>,
    collection_name: String,
    operation: DeleteVectors,
    clock_tag: Option<ClockTag>,
    shard_selection: Option<ShardId>,
    wait: bool,
    ordering: WriteOrdering,
    access: Access,
) -> Result<UpdateResult, StorageError> {
    // TODO: Is this cancel safe!?

    let DeleteVectors {
        vector,
        filter,
        points,
        shard_key,
    } = operation;

    let vector_names: Vec<_> = vector.into_iter().collect();
    let shard_selector = get_shard_selector_for_update(shard_selection, shard_key);

    let mut result = None;

    if let Some(filter) = filter {
        let vectors_operation =
            VectorOperations::DeleteVectorsByFilter(filter, vector_names.clone());

        let collection_operation = CollectionUpdateOperations::VectorOperation(vectors_operation);

        result = Some(
            toc.update(
                &collection_name,
                OperationWithClockTag::new(collection_operation, clock_tag),
                wait,
                ordering,
                shard_selector.clone(),
                access.clone(),
            )
            .await?,
        );
    }

    if let Some(points) = points {
        let vectors_operation = VectorOperations::DeleteVectors(points.into(), vector_names);
        let collection_operation = CollectionUpdateOperations::VectorOperation(vectors_operation);
        // Overwrites the filter-based result above; the ID-based result wins.
        result = Some(
            toc.update(
                &collection_name,
                OperationWithClockTag::new(collection_operation, clock_tag),
                wait,
                ordering,
                shard_selector,
                access,
            )
            .await?,
        );
    }

    result.ok_or_else(|| StorageError::bad_request("No filter or points provided"))
}
395
+
396
/// Sets (merges) payload values on points selected by IDs and/or filter.
///
/// `key` optionally scopes the assignment to a nested payload path.
#[allow(clippy::too_many_arguments)]
pub async fn do_set_payload(
    toc: Arc<TableOfContent>,
    collection_name: String,
    operation: SetPayload,
    clock_tag: Option<ClockTag>,
    shard_selection: Option<ShardId>,
    wait: bool,
    ordering: WriteOrdering,
    access: Access,
) -> Result<UpdateResult, StorageError> {
    let SetPayload {
        points,
        payload,
        filter,
        shard_key,
        key,
    } = operation;

    let collection_operation =
        CollectionUpdateOperations::PayloadOperation(PayloadOps::SetPayload(SetPayloadOp {
            payload,
            points,
            filter,
            key,
        }));

    let shard_selector = get_shard_selector_for_update(shard_selection, shard_key);

    toc.update(
        &collection_name,
        OperationWithClockTag::new(collection_operation, clock_tag),
        wait,
        ordering,
        shard_selector,
        access,
    )
    .await
}
435
+
436
/// Replaces the whole payload of the selected points.
///
/// Shares the `SetPayload` request type with `do_set_payload`, but any `key`
/// in the request is deliberately discarded: overwrite always targets the
/// full payload.
#[allow(clippy::too_many_arguments)]
pub async fn do_overwrite_payload(
    toc: Arc<TableOfContent>,
    collection_name: String,
    operation: SetPayload,
    clock_tag: Option<ClockTag>,
    shard_selection: Option<ShardId>,
    wait: bool,
    ordering: WriteOrdering,
    access: Access,
) -> Result<UpdateResult, StorageError> {
    let SetPayload {
        points,
        payload,
        filter,
        shard_key,
        ..
    } = operation;

    let collection_operation =
        CollectionUpdateOperations::PayloadOperation(PayloadOps::OverwritePayload(SetPayloadOp {
            payload,
            points,
            filter,
            // overwrite operation doesn't support payload selector
            key: None,
        }));

    let shard_selector = get_shard_selector_for_update(shard_selection, shard_key);

    toc.update(
        &collection_name,
        OperationWithClockTag::new(collection_operation, clock_tag),
        wait,
        ordering,
        shard_selector,
        access,
    )
    .await
}
476
+
477
/// Deletes the given payload keys from points selected by IDs and/or filter.
#[allow(clippy::too_many_arguments)]
pub async fn do_delete_payload(
    toc: Arc<TableOfContent>,
    collection_name: String,
    operation: DeletePayload,
    clock_tag: Option<ClockTag>,
    shard_selection: Option<ShardId>,
    wait: bool,
    ordering: WriteOrdering,
    access: Access,
) -> Result<UpdateResult, StorageError> {
    let DeletePayload {
        keys,
        points,
        filter,
        shard_key,
    } = operation;

    let collection_operation =
        CollectionUpdateOperations::PayloadOperation(PayloadOps::DeletePayload(DeletePayloadOp {
            keys,
            points,
            filter,
        }));

    let shard_selector = get_shard_selector_for_update(shard_selection, shard_key);

    toc.update(
        &collection_name,
        OperationWithClockTag::new(collection_operation, clock_tag),
        wait,
        ordering,
        shard_selector,
        access,
    )
    .await
}
514
+
515
/// Removes the entire payload from points selected by IDs or by filter.
#[allow(clippy::too_many_arguments)]
pub async fn do_clear_payload(
    toc: Arc<TableOfContent>,
    collection_name: String,
    points: PointsSelector,
    clock_tag: Option<ClockTag>,
    shard_selection: Option<ShardId>,
    wait: bool,
    ordering: WriteOrdering,
    access: Access,
) -> Result<UpdateResult, StorageError> {
    // The selector determines both the clear flavor and the shard key.
    let (point_operation, shard_key) = match points {
        PointsSelector::PointIdsSelector(PointIdsList { points, shard_key }) => {
            (PayloadOps::ClearPayload { points }, shard_key)
        }
        PointsSelector::FilterSelector(FilterSelector { filter, shard_key }) => {
            (PayloadOps::ClearPayloadByFilter(filter), shard_key)
        }
    };

    let collection_operation = CollectionUpdateOperations::PayloadOperation(point_operation);

    let shard_selector = get_shard_selector_for_update(shard_selection, shard_key);

    toc.update(
        &collection_name,
        OperationWithClockTag::new(collection_operation, clock_tag),
        wait,
        ordering,
        shard_selector,
        access,
    )
    .await
}
549
+
550
/// Applies a list of update operations sequentially, in request order.
///
/// Stops at the first failing operation (earlier results are discarded with
/// the error); on success returns one `UpdateResult` per operation. All
/// operations share the same `clock_tag`, shard selection, `wait`, ordering,
/// and access.
#[allow(clippy::too_many_arguments)]
pub async fn do_batch_update_points(
    toc: Arc<TableOfContent>,
    collection_name: String,
    operations: Vec<UpdateOperation>,
    clock_tag: Option<ClockTag>,
    shard_selection: Option<ShardId>,
    wait: bool,
    ordering: WriteOrdering,
    access: Access,
) -> Result<Vec<UpdateResult>, StorageError> {
    let mut results = Vec::with_capacity(operations.len());
    for operation in operations {
        // Dispatch each entry to its dedicated handler.
        let result = match operation {
            UpdateOperation::Upsert(operation) => {
                do_upsert_points(
                    toc.clone(),
                    collection_name.clone(),
                    operation.upsert,
                    clock_tag,
                    shard_selection,
                    wait,
                    ordering,
                    access.clone(),
                )
                .await
            }
            UpdateOperation::Delete(operation) => {
                do_delete_points(
                    toc.clone(),
                    collection_name.clone(),
                    operation.delete,
                    clock_tag,
                    shard_selection,
                    wait,
                    ordering,
                    access.clone(),
                )
                .await
            }
            UpdateOperation::SetPayload(operation) => {
                do_set_payload(
                    toc.clone(),
                    collection_name.clone(),
                    operation.set_payload,
                    clock_tag,
                    shard_selection,
                    wait,
                    ordering,
                    access.clone(),
                )
                .await
            }
            UpdateOperation::OverwritePayload(operation) => {
                do_overwrite_payload(
                    toc.clone(),
                    collection_name.clone(),
                    operation.overwrite_payload,
                    clock_tag,
                    shard_selection,
                    wait,
                    ordering,
                    access.clone(),
                )
                .await
            }
            UpdateOperation::DeletePayload(operation) => {
                do_delete_payload(
                    toc.clone(),
                    collection_name.clone(),
                    operation.delete_payload,
                    clock_tag,
                    shard_selection,
                    wait,
                    ordering,
                    access.clone(),
                )
                .await
            }
            UpdateOperation::ClearPayload(operation) => {
                do_clear_payload(
                    toc.clone(),
                    collection_name.clone(),
                    operation.clear_payload,
                    clock_tag,
                    shard_selection,
                    wait,
                    ordering,
                    access.clone(),
                )
                .await
            }
            UpdateOperation::UpdateVectors(operation) => {
                do_update_vectors(
                    toc.clone(),
                    collection_name.clone(),
                    operation.update_vectors,
                    clock_tag,
                    shard_selection,
                    wait,
                    ordering,
                    access.clone(),
                )
                .await
            }
            UpdateOperation::DeleteVectors(operation) => {
                do_delete_vectors(
                    toc.clone(),
                    collection_name.clone(),
                    operation.delete_vectors,
                    clock_tag,
                    shard_selection,
                    wait,
                    ordering,
                    access.clone(),
                )
                .await
            }
        }?;
        results.push(result);
    }
    Ok(results)
}
673
+
674
+ #[allow(clippy::too_many_arguments)]
675
+ pub async fn do_create_index_internal(
676
+ toc: Arc<TableOfContent>,
677
+ collection_name: String,
678
+ field_name: PayloadKeyType,
679
+ field_schema: Option<PayloadFieldSchema>,
680
+ clock_tag: Option<ClockTag>,
681
+ shard_selection: Option<ShardId>,
682
+ wait: bool,
683
+ ordering: WriteOrdering,
684
+ ) -> Result<UpdateResult, StorageError> {
685
+ let collection_operation = CollectionUpdateOperations::FieldIndexOperation(
686
+ FieldIndexOperations::CreateIndex(CreateIndex {
687
+ field_name,
688
+ field_schema,
689
+ }),
690
+ );
691
+
692
+ let shard_selector = if let Some(shard_selection) = shard_selection {
693
+ ShardSelectorInternal::ShardId(shard_selection)
694
+ } else {
695
+ ShardSelectorInternal::All
696
+ };
697
+
698
+ toc.update(
699
+ &collection_name,
700
+ OperationWithClockTag::new(collection_operation, clock_tag),
701
+ wait,
702
+ ordering,
703
+ shard_selector,
704
+ Access::full("Internal API"),
705
+ )
706
+ .await
707
+ }
708
+
709
/// Creates a payload index in two phases: first registers it through
/// consensus, then applies it directly so the caller can `wait` and receive
/// an `UpdateResult`, preserving the historical point-like interface.
///
/// Errors with a bad request if `field_schema` is missing — auto-detection
/// is not supported.
#[allow(clippy::too_many_arguments)]
pub async fn do_create_index(
    dispatcher: Arc<Dispatcher>,
    collection_name: String,
    operation: CreateFieldIndex,
    clock_tag: Option<ClockTag>,
    shard_selection: Option<ShardId>,
    wait: bool,
    ordering: WriteOrdering,
    access: Access,
) -> Result<UpdateResult, StorageError> {
    // TODO: Is this cancel safe!?

    let Some(field_schema) = operation.field_schema else {
        return Err(StorageError::bad_request(
            "Can't auto-detect field type, please specify `field_schema` in the request",
        ));
    };

    let consensus_op = CollectionMetaOperations::CreatePayloadIndex(CreatePayloadIndex {
        collection_name: collection_name.to_string(),
        field_name: operation.field_name.clone(),
        field_schema: field_schema.clone(),
    });

    // Default consensus timeout will be used
    let wait_timeout = None; // ToDo: make it configurable

    // Nothing to verify here.
    let pass = new_unchecked_verification_pass();

    let toc = dispatcher.toc(&access, &pass).clone();

    // TODO: Is `submit_collection_meta_op` cancel-safe!? Should be, I think?.. 🤔
    dispatcher
        .submit_collection_meta_op(consensus_op, access, wait_timeout)
        .await?;

    // This function is required as long as we want to maintain interface compatibility
    // for `wait` parameter and return type.
    // The idea is to migrate from the point-like interface to consensus-like interface in the next few versions

    do_create_index_internal(
        toc,
        collection_name,
        operation.field_name,
        Some(field_schema),
        clock_tag,
        shard_selection,
        wait,
        ordering,
    )
    .await
}
763
+
764
+ #[allow(clippy::too_many_arguments)]
765
+ pub async fn do_delete_index_internal(
766
+ toc: Arc<TableOfContent>,
767
+ collection_name: String,
768
+ index_name: JsonPath,
769
+ clock_tag: Option<ClockTag>,
770
+ shard_selection: Option<ShardId>,
771
+ wait: bool,
772
+ ordering: WriteOrdering,
773
+ ) -> Result<UpdateResult, StorageError> {
774
+ let collection_operation = CollectionUpdateOperations::FieldIndexOperation(
775
+ FieldIndexOperations::DeleteIndex(index_name),
776
+ );
777
+
778
+ let shard_selector = if let Some(shard_selection) = shard_selection {
779
+ ShardSelectorInternal::ShardId(shard_selection)
780
+ } else {
781
+ ShardSelectorInternal::All
782
+ };
783
+
784
+ toc.update(
785
+ &collection_name,
786
+ OperationWithClockTag::new(collection_operation, clock_tag),
787
+ wait,
788
+ ordering,
789
+ shard_selector,
790
+ Access::full("Internal API"),
791
+ )
792
+ .await
793
+ }
794
+
795
/// Drops a payload index in two phases: first unregisters it through
/// consensus, then applies the deletion directly so the caller can `wait`
/// and receive an `UpdateResult` (mirrors `do_create_index`).
#[allow(clippy::too_many_arguments)]
pub async fn do_delete_index(
    dispatcher: Arc<Dispatcher>,
    collection_name: String,
    index_name: JsonPath,
    clock_tag: Option<ClockTag>,
    shard_selection: Option<ShardId>,
    wait: bool,
    ordering: WriteOrdering,
    access: Access,
) -> Result<UpdateResult, StorageError> {
    // TODO: Is this cancel safe!?

    let consensus_op = CollectionMetaOperations::DropPayloadIndex(DropPayloadIndex {
        collection_name: collection_name.to_string(),
        field_name: index_name.clone(),
    });

    // Default consensus timeout will be used
    let wait_timeout = None; // ToDo: make it configurable

    // Nothing to verify here.
    let pass = new_unchecked_verification_pass();

    let toc = dispatcher.toc(&access, &pass).clone();

    // TODO: Is `submit_collection_meta_op` cancel-safe!? Should be, I think?.. 🤔
    dispatcher
        .submit_collection_meta_op(consensus_op, access, wait_timeout)
        .await?;

    do_delete_index_internal(
        toc,
        collection_name,
        index_name,
        clock_tag,
        shard_selection,
        wait,
        ordering,
    )
    .await
}
837
+
838
+ #[allow(clippy::too_many_arguments)]
839
+ pub async fn do_core_search_points(
840
+ toc: &TableOfContent,
841
+ collection_name: &str,
842
+ request: CoreSearchRequest,
843
+ read_consistency: Option<ReadConsistency>,
844
+ shard_selection: ShardSelectorInternal,
845
+ access: Access,
846
+ timeout: Option<Duration>,
847
+ hw_measurement_acc: &HwMeasurementAcc,
848
+ ) -> Result<Vec<ScoredPoint>, StorageError> {
849
+ let batch_res = do_core_search_batch_points(
850
+ toc,
851
+ collection_name,
852
+ CoreSearchRequestBatch {
853
+ searches: vec![request],
854
+ },
855
+ read_consistency,
856
+ shard_selection,
857
+ access,
858
+ timeout,
859
+ hw_measurement_acc,
860
+ )
861
+ .await?;
862
+ batch_res
863
+ .into_iter()
864
+ .next()
865
+ .ok_or_else(|| StorageError::service_error("Empty search result"))
866
+ }
867
+
868
/// Executes a batch of core searches that may target different shard
/// selectors.
///
/// Consecutive requests sharing the same shard selector are grouped into one
/// `CoreSearchRequestBatch` (via `batch_requests`), the per-group futures are
/// awaited concurrently with `try_join_all`, and the grouped results are
/// flattened back into one `Vec` per original request, preserving order.
pub async fn do_search_batch_points(
    toc: &TableOfContent,
    collection_name: &str,
    requests: Vec<(CoreSearchRequest, ShardSelectorInternal)>,
    read_consistency: Option<ReadConsistency>,
    access: Access,
    timeout: Option<Duration>,
    hw_measurement_acc: &HwMeasurementAcc,
) -> Result<Vec<Vec<ScoredPoint>>, StorageError> {
    let requests = batch_requests::<
        (CoreSearchRequest, ShardSelectorInternal),
        ShardSelectorInternal,
        Vec<CoreSearchRequest>,
        Vec<_>,
    >(
        requests,
        // Group key: the shard selector of each request.
        |(_, shard_selector)| shard_selector,
        // Accumulate requests of the current group.
        |(request, _), core_reqs| {
            core_reqs.push(request);
            Ok(())
        },
        // Flush a completed group as one (not yet awaited) batch future.
        |shard_selector, core_requests, res| {
            if core_requests.is_empty() {
                return Ok(());
            }

            let core_batch = CoreSearchRequestBatch {
                searches: core_requests,
            };

            let req = toc.core_search_batch(
                collection_name,
                core_batch,
                read_consistency,
                shard_selector,
                access.clone(),
                timeout,
                hw_measurement_acc,
            );
            res.push(req);
            Ok(())
        },
    )?;

    let results = futures::future::try_join_all(requests).await?;
    let flatten_results: Vec<Vec<_>> = results.into_iter().flatten().collect();
    Ok(flatten_results)
}
916
+
917
/// Thin pass-through to `TableOfContent::core_search_batch` for a batch that
/// targets a single shard selector.
#[allow(clippy::too_many_arguments)]
pub async fn do_core_search_batch_points(
    toc: &TableOfContent,
    collection_name: &str,
    request: CoreSearchRequestBatch,
    read_consistency: Option<ReadConsistency>,
    shard_selection: ShardSelectorInternal,
    access: Access,
    timeout: Option<Duration>,
    hw_measurement_acc: &HwMeasurementAcc,
) -> Result<Vec<Vec<ScoredPoint>>, StorageError> {
    toc.core_search_batch(
        collection_name,
        request,
        read_consistency,
        shard_selection,
        access,
        timeout,
        hw_measurement_acc,
    )
    .await
}
939
+
940
/// Grouped search: converts the search-groups request into the generic
/// `GroupRequest` and delegates to `TableOfContent::group`.
#[allow(clippy::too_many_arguments)]
pub async fn do_search_point_groups(
    toc: &TableOfContent,
    collection_name: &str,
    request: SearchGroupsRequestInternal,
    read_consistency: Option<ReadConsistency>,
    shard_selection: ShardSelectorInternal,
    access: Access,
    timeout: Option<Duration>,
    hw_measurement_acc: &HwMeasurementAcc,
) -> Result<GroupsResult, StorageError> {
    toc.group(
        collection_name,
        GroupRequest::from(request),
        read_consistency,
        shard_selection,
        access,
        timeout,
        hw_measurement_acc,
    )
    .await
}
962
+
963
/// Grouped recommendation: converts the recommend-groups request into the
/// generic `GroupRequest` and delegates to `TableOfContent::group`.
#[allow(clippy::too_many_arguments)]
pub async fn do_recommend_point_groups(
    toc: &TableOfContent,
    collection_name: &str,
    request: RecommendGroupsRequestInternal,
    read_consistency: Option<ReadConsistency>,
    shard_selection: ShardSelectorInternal,
    access: Access,
    timeout: Option<Duration>,
    hw_measurement_acc: &HwMeasurementAcc,
) -> Result<GroupsResult, StorageError> {
    toc.group(
        collection_name,
        GroupRequest::from(request),
        read_consistency,
        shard_selection,
        access,
        timeout,
        hw_measurement_acc,
    )
    .await
}
985
+
986
+ pub async fn do_discover_batch_points(
987
+ toc: &TableOfContent,
988
+ collection_name: &str,
989
+ request: DiscoverRequestBatch,
990
+ read_consistency: Option<ReadConsistency>,
991
+ access: Access,
992
+ timeout: Option<Duration>,
993
+ hw_measurement_acc: &HwMeasurementAcc,
994
+ ) -> Result<Vec<Vec<ScoredPoint>>, StorageError> {
995
+ let requests = request
996
+ .searches
997
+ .into_iter()
998
+ .map(|req| {
999
+ let shard_selector = match req.shard_key {
1000
+ None => ShardSelectorInternal::All,
1001
+ Some(shard_key) => ShardSelectorInternal::from(shard_key),
1002
+ };
1003
+
1004
+ (req.discover_request, shard_selector)
1005
+ })
1006
+ .collect();
1007
+
1008
+ toc.discover_batch(
1009
+ collection_name,
1010
+ requests,
1011
+ read_consistency,
1012
+ access,
1013
+ timeout,
1014
+ hw_measurement_acc,
1015
+ )
1016
+ .await
1017
+ }
1018
+
1019
/// Thin pass-through to `TableOfContent::count`.
#[allow(clippy::too_many_arguments)]
pub async fn do_count_points(
    toc: &TableOfContent,
    collection_name: &str,
    request: CountRequestInternal,
    read_consistency: Option<ReadConsistency>,
    timeout: Option<Duration>,
    shard_selection: ShardSelectorInternal,
    access: Access,
    hw_measurement_acc: &HwMeasurementAcc,
) -> Result<CountResult, StorageError> {
    toc.count(
        collection_name,
        request,
        read_consistency,
        timeout,
        shard_selection,
        access,
        hw_measurement_acc,
    )
    .await
}
1041
+
1042
/// Retrieves points by ID; thin pass-through to `TableOfContent::retrieve`.
pub async fn do_get_points(
    toc: &TableOfContent,
    collection_name: &str,
    request: PointRequestInternal,
    read_consistency: Option<ReadConsistency>,
    timeout: Option<Duration>,
    shard_selection: ShardSelectorInternal,
    access: Access,
) -> Result<Vec<RecordInternal>, StorageError> {
    toc.retrieve(
        collection_name,
        request,
        read_consistency,
        timeout,
        shard_selection,
        access,
    )
    .await
}
1061
+
1062
/// Paginated listing of points; thin pass-through to
/// `TableOfContent::scroll`.
pub async fn do_scroll_points(
    toc: &TableOfContent,
    collection_name: &str,
    request: ScrollRequestInternal,
    read_consistency: Option<ReadConsistency>,
    timeout: Option<Duration>,
    shard_selection: ShardSelectorInternal,
    access: Access,
) -> Result<ScrollResult, StorageError> {
    toc.scroll(
        collection_name,
        request,
        read_consistency,
        timeout,
        shard_selection,
        access,
    )
    .await
}
1081
+
1082
+ #[allow(clippy::too_many_arguments)]
1083
+ pub async fn do_query_points(
1084
+ toc: &TableOfContent,
1085
+ collection_name: &str,
1086
+ request: CollectionQueryRequest,
1087
+ read_consistency: Option<ReadConsistency>,
1088
+ shard_selection: ShardSelectorInternal,
1089
+ access: Access,
1090
+ timeout: Option<Duration>,
1091
+ hw_measurement_acc: &HwMeasurementAcc,
1092
+ ) -> Result<Vec<ScoredPoint>, StorageError> {
1093
+ let requests = vec![(request, shard_selection)];
1094
+ let batch_res = toc
1095
+ .query_batch(
1096
+ collection_name,
1097
+ requests,
1098
+ read_consistency,
1099
+ access,
1100
+ timeout,
1101
+ hw_measurement_acc,
1102
+ )
1103
+ .await?;
1104
+ batch_res
1105
+ .into_iter()
1106
+ .next()
1107
+ .ok_or_else(|| StorageError::service_error("Empty query result"))
1108
+ }
1109
+
1110
/// Thin pass-through to `TableOfContent::query_batch`; each request carries
/// its own shard selector.
#[allow(clippy::too_many_arguments)]
pub async fn do_query_batch_points(
    toc: &TableOfContent,
    collection_name: &str,
    requests: Vec<(CollectionQueryRequest, ShardSelectorInternal)>,
    read_consistency: Option<ReadConsistency>,
    access: Access,
    timeout: Option<Duration>,
    hw_measurement_acc: &HwMeasurementAcc,
) -> Result<Vec<Vec<ScoredPoint>>, StorageError> {
    toc.query_batch(
        collection_name,
        requests,
        read_consistency,
        access,
        timeout,
        hw_measurement_acc,
    )
    .await
}
1130
+
1131
/// Grouped universal query: converts the query-groups request into the
/// generic `GroupRequest` and delegates to `TableOfContent::group`.
#[allow(clippy::too_many_arguments)]
pub async fn do_query_point_groups(
    toc: &TableOfContent,
    collection_name: &str,
    request: CollectionQueryGroupsRequest,
    read_consistency: Option<ReadConsistency>,
    shard_selection: ShardSelectorInternal,
    access: Access,
    timeout: Option<Duration>,
    hw_measurement_acc: &HwMeasurementAcc,
) -> Result<GroupsResult, StorageError> {
    toc.group(
        collection_name,
        GroupRequest::from(request),
        read_consistency,
        shard_selection,
        access,
        timeout,
        hw_measurement_acc,
    )
    .await
}
1153
+
1154
/// Distance-matrix search; thin pass-through to
/// `TableOfContent::search_points_matrix`.
#[allow(clippy::too_many_arguments)]
pub async fn do_search_points_matrix(
    toc: &TableOfContent,
    collection_name: &str,
    request: CollectionSearchMatrixRequest,
    read_consistency: Option<ReadConsistency>,
    shard_selection: ShardSelectorInternal,
    access: Access,
    timeout: Option<Duration>,
    hw_measurement_acc: &HwMeasurementAcc,
) -> Result<CollectionSearchMatrixResponse, StorageError> {
    toc.search_points_matrix(
        collection_name,
        request,
        read_consistency,
        shard_selection,
        access,
        timeout,
        hw_measurement_acc,
    )
    .await
}
src/common/pyroscope_state.rs ADDED
@@ -0,0 +1,93 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
// Pyroscope continuous-profiling integration. The pprof backend is only
// available on Linux, hence the cfg-gated module; a stub lives below for
// other targets.
#[cfg(target_os = "linux")]
pub mod pyro {

    use pyroscope::pyroscope::PyroscopeAgentRunning;
    use pyroscope::{PyroscopeAgent, PyroscopeError};
    use pyroscope_pprofrs::{pprof_backend, PprofConfig};

    use crate::common::debugger::PyroscopeConfig;

    /// Holds a running pyroscope agent together with the config that
    /// spawned it. `agent` is `None` once the agent has been stopped.
    pub struct PyroscopeState {
        pub config: PyroscopeConfig,
        pub agent: Option<PyroscopeAgent<PyroscopeAgentRunning>>,
    }

    impl PyroscopeState {
        /// Builds and starts a pyroscope agent from `config`.
        /// Sampling rate falls back to 100 Hz when not configured.
        fn build_agent(
            config: &PyroscopeConfig,
        ) -> Result<PyroscopeAgent<PyroscopeAgentRunning>, PyroscopeError> {
            let pprof_config = PprofConfig::new().sample_rate(config.sampling_rate.unwrap_or(100));
            let backend_impl = pprof_backend(pprof_config);

            log::info!(
                "Starting pyroscope agent with identifier {}",
                &config.identifier
            );
            // TODO: Add more tags like peerId and peerUrl
            let agent = PyroscopeAgent::builder(config.url.to_string(), "qdrant".to_string())
                .backend(backend_impl)
                .tags(vec![("app", "Qdrant"), ("identifier", &config.identifier)])
                .build()?;
            let running_agent = agent.start()?;

            Ok(running_agent)
        }

        /// Starts profiling if a config is given. A failure to start the
        /// agent is logged and swallowed (returns `None`) — profiling is
        /// best-effort and must not block startup.
        pub fn from_config(config: Option<PyroscopeConfig>) -> Option<Self> {
            match config {
                Some(pyro_config) => {
                    let agent = PyroscopeState::build_agent(&pyro_config);
                    match agent {
                        Ok(agent) => Some(PyroscopeState {
                            config: pyro_config,
                            agent: Some(agent),
                        }),
                        Err(err) => {
                            log::warn!("Pyroscope agent failed to start {}", err);
                            None
                        }
                    }
                }
                None => None,
            }
        }

        /// Stops and shuts down the agent, if one is still running.
        /// Returns `true` on success or when there was no agent to stop;
        /// `false` only when stopping an existing agent failed.
        pub fn stop_agent(&mut self) -> bool {
            log::info!("Stopping pyroscope agent");
            // `take()` makes this idempotent: a second call finds `None`.
            if let Some(agent) = self.agent.take() {
                match agent.stop() {
                    Ok(stopped_agent) => {
                        log::info!("Stopped pyroscope agent. Shutting it down");
                        stopped_agent.shutdown();
                        log::info!("Pyroscope agent shut down completed.");
                        return true;
                    }
                    Err(err) => {
                        log::warn!("Pyroscope agent failed to stop {}", err);
                        return false;
                    }
                }
            }
            true
        }
    }

    impl Drop for PyroscopeState {
        // Best-effort cleanup so the agent does not outlive the state.
        fn drop(&mut self) {
            self.stop_agent();
        }
    }
}
81
+
82
// No-op stand-in for non-Linux targets: same module path and constructor
// signature, but profiling is never started (`from_config` always yields
// `None`).
#[cfg(not(target_os = "linux"))]
pub mod pyro {
    use crate::common::debugger::PyroscopeConfig;

    pub struct PyroscopeState {}

    impl PyroscopeState {
        pub fn from_config(_config: Option<PyroscopeConfig>) -> Option<Self> {
            None
        }
    }
}
src/common/snapshots.rs ADDED
@@ -0,0 +1,284 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ use std::sync::Arc;
2
+
3
+ use collection::collection::Collection;
4
+ use collection::common::sha_256::hash_file;
5
+ use collection::common::snapshot_stream::SnapshotStream;
6
+ use collection::operations::snapshot_ops::{
7
+ ShardSnapshotLocation, SnapshotDescription, SnapshotPriority,
8
+ };
9
+ use collection::shards::replica_set::ReplicaState;
10
+ use collection::shards::shard::ShardId;
11
+ use storage::content_manager::errors::StorageError;
12
+ use storage::content_manager::snapshots;
13
+ use storage::content_manager::toc::TableOfContent;
14
+ use storage::rbac::{Access, AccessRequirements};
15
+
16
+ use super::http_client::HttpClient;
17
+
18
+ /// # Cancel safety
19
+ ///
20
+ /// This function is cancel safe.
21
+ pub async fn create_shard_snapshot(
22
+ toc: Arc<TableOfContent>,
23
+ access: Access,
24
+ collection_name: String,
25
+ shard_id: ShardId,
26
+ ) -> Result<SnapshotDescription, StorageError> {
27
+ let collection_pass = access
28
+ .check_collection_access(&collection_name, AccessRequirements::new().write().whole())?;
29
+ let collection = toc.get_collection(&collection_pass).await?;
30
+
31
+ let snapshot = collection
32
+ .create_shard_snapshot(shard_id, &toc.optional_temp_or_snapshot_temp_path()?)
33
+ .await?;
34
+
35
+ Ok(snapshot)
36
+ }
37
+
38
+ /// # Cancel safety
39
+ ///
40
+ /// This function is cancel safe.
41
+ pub async fn stream_shard_snapshot(
42
+ toc: Arc<TableOfContent>,
43
+ access: Access,
44
+ collection_name: String,
45
+ shard_id: ShardId,
46
+ ) -> Result<SnapshotStream, StorageError> {
47
+ let collection_pass = access
48
+ .check_collection_access(&collection_name, AccessRequirements::new().write().whole())?;
49
+ let collection = toc.get_collection(&collection_pass).await?;
50
+
51
+ Ok(collection
52
+ .stream_shard_snapshot(shard_id, &toc.optional_temp_or_snapshot_temp_path()?)
53
+ .await?)
54
+ }
55
+
56
+ /// # Cancel safety
57
+ ///
58
+ /// This function is cancel safe.
59
+ pub async fn list_shard_snapshots(
60
+ toc: Arc<TableOfContent>,
61
+ access: Access,
62
+ collection_name: String,
63
+ shard_id: ShardId,
64
+ ) -> Result<Vec<SnapshotDescription>, StorageError> {
65
+ let collection_pass =
66
+ access.check_collection_access(&collection_name, AccessRequirements::new().whole())?;
67
+ let collection = toc.get_collection(&collection_pass).await?;
68
+ let snapshots = collection.list_shard_snapshots(shard_id).await?;
69
+ Ok(snapshots)
70
+ }
71
+
72
+ /// # Cancel safety
73
+ ///
74
+ /// This function is cancel safe.
75
+ pub async fn delete_shard_snapshot(
76
+ toc: Arc<TableOfContent>,
77
+ access: Access,
78
+ collection_name: String,
79
+ shard_id: ShardId,
80
+ snapshot_name: String,
81
+ ) -> Result<(), StorageError> {
82
+ let collection_pass = access
83
+ .check_collection_access(&collection_name, AccessRequirements::new().write().whole())?;
84
+ let collection = toc.get_collection(&collection_pass).await?;
85
+ let snapshot_manager = collection.get_snapshots_storage_manager()?;
86
+
87
+ let snapshot_path = collection
88
+ .shards_holder()
89
+ .read()
90
+ .await
91
+ .get_shard_snapshot_path(collection.snapshots_path(), shard_id, &snapshot_name)
92
+ .await?;
93
+
94
+ tokio::spawn(async move { snapshot_manager.delete_snapshot(&snapshot_path).await }).await??;
95
+
96
+ Ok(())
97
+ }
98
+
99
/// Recover a shard of `collection_name` from a snapshot, either downloaded
/// from an HTTP(S) URL or taken from an existing local snapshot file.
///
/// Requires global manage access. When `checksum` is given it is compared
/// against the SHA-256 of the snapshot file (see `hash_file`) before recovery.
///
/// # Cancel safety
///
/// This function is cancel safe.
#[allow(clippy::too_many_arguments)]
pub async fn recover_shard_snapshot(
    toc: Arc<TableOfContent>,
    access: Access,
    collection_name: String,
    shard_id: ShardId,
    snapshot_location: ShardSnapshotLocation,
    snapshot_priority: SnapshotPriority,
    checksum: Option<String>,
    client: HttpClient,
    api_key: Option<String>,
) -> Result<(), StorageError> {
    let collection_pass = access
        .check_global_access(AccessRequirements::new().manage())?
        .issue_pass(&collection_name)
        .into_static();

    // - `recover_shard_snapshot_impl` is *not* cancel safe
    // - but the task is *spawned* on the runtime and won't be cancelled, if request is cancelled

    cancel::future::spawn_cancel_on_drop(move |cancel| async move {
        // Preparation phase (resolving/downloading the snapshot) *is*
        // cancellable; only the recovery phase below must run to completion.
        let future = async {
            let collection = toc.get_collection(&collection_pass).await?;
            collection.assert_shard_exists(shard_id).await?;

            let download_dir = toc.optional_temp_or_snapshot_temp_path()?;

            let snapshot_path = match snapshot_location {
                ShardSnapshotLocation::Url(url) => {
                    // Only http/https snapshot URLs are accepted.
                    if !matches!(url.scheme(), "http" | "https") {
                        let description = format!(
                            "Invalid snapshot URL {url}: URLs with {} scheme are not supported",
                            url.scheme(),
                        );

                        return Err(StorageError::bad_input(description));
                    }

                    let client = client.client(api_key.as_deref())?;

                    snapshots::download::download_snapshot(&client, url, &download_dir).await?
                }

                ShardSnapshotLocation::Path(snapshot_file_name) => {
                    let snapshot_path = collection
                        .shards_holder()
                        .read()
                        .await
                        .get_shard_snapshot_path(
                            collection.snapshots_path(),
                            shard_id,
                            &snapshot_file_name,
                        )
                        .await?;

                    collection
                        .get_snapshots_storage_manager()?
                        .get_snapshot_file(&snapshot_path, &download_dir)
                        .await?
                }
            };

            // Verify integrity before touching any shard state.
            if let Some(checksum) = checksum {
                let snapshot_checksum = hash_file(&snapshot_path).await?;
                if snapshot_checksum != checksum {
                    return Err(StorageError::bad_input(format!(
                        "Snapshot checksum mismatch: expected {checksum}, got {snapshot_checksum}"
                    )));
                }
            }

            Result::<_, StorageError>::Ok((collection, snapshot_path))
        };

        let (collection, snapshot_path) =
            cancel::future::cancel_on_token(cancel.clone(), future).await??;

        // `recover_shard_snapshot_impl` is *not* cancel safe
        let result = recover_shard_snapshot_impl(
            &toc,
            &collection,
            shard_id,
            &snapshot_path,
            snapshot_priority,
            cancel,
        )
        .await;

        // Remove snapshot after recovery if downloaded
        if let Err(err) = snapshot_path.close() {
            log::error!("Failed to remove downloaded shards snapshot after recovery: {err}");
        }

        result
    })
    .await??;

    Ok(())
}
201
+
202
/// Restore a shard snapshot on the local peer and update replica states in
/// the cluster according to `priority`.
///
/// # Cancel safety
///
/// This function is *not* cancel safe.
pub async fn recover_shard_snapshot_impl(
    toc: &TableOfContent,
    collection: &Collection,
    shard: ShardId,
    snapshot_path: &std::path::Path,
    priority: SnapshotPriority,
    cancel: cancel::CancellationToken,
) -> Result<(), StorageError> {
    // `Collection::restore_shard_snapshot` and `activate_shard` calls *have to* be executed as a
    // single transaction
    //
    // It is *possible* to make this function to be cancel safe, but it is *extremely tedious* to do so

    // `Collection::restore_shard_snapshot` is *not* cancel safe
    // (see `ShardReplicaSet::restore_local_replica_from`)
    collection
        .restore_shard_snapshot(
            shard,
            snapshot_path,
            toc.this_peer_id,
            toc.is_distributed(),
            &toc.optional_temp_or_snapshot_temp_path()?,
            cancel,
        )
        .await?;

    let state = collection.state().await;
    let shard_info = state.shards.get(&shard).unwrap(); // TODO: Handle `unwrap`?..

    // TODO: Unify (and de-duplicate) "recovered shard state notification" logic in `_do_recover_from_snapshot` with this one!

    // Replicas of this shard on *other* peers that are currently `Active`.
    let other_active_replicas: Vec<_> = shard_info
        .replicas
        .iter()
        .map(|(&peer, &state)| (peer, state))
        .filter(|&(peer, state)| peer != toc.this_peer_id && state == ReplicaState::Active)
        .collect();

    if other_active_replicas.is_empty() {
        // No other active replica: activate the local one unconditionally.
        snapshots::recover::activate_shard(toc, collection, toc.this_peer_id, &shard).await?;
    } else {
        match priority {
            SnapshotPriority::NoSync => {
                // Activate locally without touching the other replicas.
                snapshots::recover::activate_shard(toc, collection, toc.this_peer_id, &shard)
                    .await?;
            }

            SnapshotPriority::Snapshot => {
                // Snapshot data wins: activate the local replica and propose
                // `Dead` state for every other currently-active replica.
                snapshots::recover::activate_shard(toc, collection, toc.this_peer_id, &shard)
                    .await?;

                for &(peer, _) in other_active_replicas.iter() {
                    toc.send_set_replica_state_proposal(
                        collection.name(),
                        peer,
                        shard,
                        ReplicaState::Dead,
                        None,
                    )?;
                }
            }

            SnapshotPriority::Replica => {
                // Existing replicas win: propose `Dead` state for the local
                // replica instead of activating it.
                toc.send_set_replica_state_proposal(
                    collection.name(),
                    toc.this_peer_id,
                    shard,
                    ReplicaState::Dead,
                    None,
                )?;
            }

            // `ShardTransfer` is only used during snapshot *shard transfer*.
            // State transitions are performed as part of shard transfer *later*, so this simply does *nothing*.
            SnapshotPriority::ShardTransfer => (),
        }
    }

    Ok(())
}
src/common/stacktrace.rs ADDED
@@ -0,0 +1,86 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ use schemars::JsonSchema;
2
+ use serde::{Deserialize, Serialize};
3
+
4
// A single resolved symbol within a stack frame. All fields are optional
// because symbol resolution can fail partially (e.g. missing debug info).
#[derive(Deserialize, Serialize, JsonSchema, Debug)]
struct StackTraceSymbol {
    // Symbol (function) name, if it could be resolved
    name: Option<String>,
    // Source file path, if debug info is available
    file: Option<String>,
    // Line number within `file`, if known
    line: Option<u32>,
}
10
+
11
// One stack frame; may map to several symbols (e.g. due to inlining —
// NOTE(review): assumption based on typical backtrace APIs, confirm).
#[derive(Deserialize, Serialize, JsonSchema, Debug)]
struct StackTraceFrame {
    symbols: Vec<StackTraceSymbol>,
}
15
+
16
+ impl StackTraceFrame {
17
+ pub fn render(&self) -> String {
18
+ let mut result = String::new();
19
+ for symbol in &self.symbols {
20
+ let symbol_string = format!(
21
+ "{}:{} - {} ",
22
+ symbol.file.as_deref().unwrap_or_default(),
23
+ symbol.line.unwrap_or_default(),
24
+ symbol.name.as_deref().unwrap_or_default(),
25
+ );
26
+ result.push_str(&symbol_string);
27
+ }
28
+ result
29
+ }
30
+ }
31
+
32
// Stack trace of a single OS thread, with frames already rendered to strings.
#[derive(Deserialize, Serialize, JsonSchema, Debug)]
pub struct ThreadStackTrace {
    // OS thread id
    id: u32,
    // Thread name
    name: String,
    // One pre-rendered string per stack frame (see `StackTraceFrame::render`)
    frames: Vec<String>,
}
38
+
39
// Stack traces of all threads of the process, as returned by `get_stack_trace`.
#[derive(Deserialize, Serialize, JsonSchema, Debug)]
pub struct StackTrace {
    threads: Vec<ThreadStackTrace>,
}
43
+
44
/// Capture stack traces for all threads of the current process.
///
/// Only functional on Linux with the `stacktrace` feature enabled; on any
/// other target this returns an empty trace.
pub fn get_stack_trace() -> StackTrace {
    #[cfg(not(all(target_os = "linux", feature = "stacktrace")))]
    {
        StackTrace { threads: vec![] }
    }

    #[cfg(all(target_os = "linux", feature = "stacktrace"))]
    {
        // `rstack_self` traces this process by re-running the current binary
        // with `--stacktrace` — the binary is expected to handle that flag.
        // NOTE(review): both `unwrap`s panic on failure; acceptable only if
        // this is a debug-only code path — confirm with callers.
        let exe = std::env::current_exe().unwrap();
        let trace =
            rstack_self::trace(std::process::Command::new(exe).arg("--stacktrace")).unwrap();
        StackTrace {
            threads: trace
                .threads()
                .iter()
                .map(|thread| ThreadStackTrace {
                    id: thread.id(),
                    name: thread.name().to_string(),
                    frames: thread
                        .frames()
                        .iter()
                        .map(|frame| {
                            // Convert each frame to our serializable form and
                            // immediately render it; only the string is kept.
                            let frame = StackTraceFrame {
                                symbols: frame
                                    .symbols()
                                    .iter()
                                    .map(|symbol| StackTraceSymbol {
                                        name: symbol.name().map(|name| name.to_string()),
                                        file: symbol.file().map(|file| {
                                            file.to_str().unwrap_or_default().to_string()
                                        }),
                                        line: symbol.line(),
                                    })
                                    .collect(),
                            };
                            frame.render()
                        })
                        .collect(),
                })
                .collect(),
        }
    }
}
src/common/strings.rs ADDED
@@ -0,0 +1,5 @@
 
 
 
 
 
 
1
+ /// Constant-time equality for String types
2
+ #[inline]
3
+ pub fn ct_eq(lhs: impl AsRef<str>, rhs: impl AsRef<str>) -> bool {
4
+ constant_time_eq::constant_time_eq(lhs.as_ref().as_bytes(), rhs.as_ref().as_bytes())
5
+ }
src/common/telemetry.rs ADDED
@@ -0,0 +1,101 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ use std::sync::Arc;
2
+
3
+ use collection::operations::verification::new_unchecked_verification_pass;
4
+ use common::types::{DetailsLevel, TelemetryDetail};
5
+ use parking_lot::Mutex;
6
+ use schemars::JsonSchema;
7
+ use segment::common::anonymize::Anonymize;
8
+ use serde::Serialize;
9
+ use storage::dispatcher::Dispatcher;
10
+ use storage::rbac::Access;
11
+ use uuid::Uuid;
12
+
13
+ use crate::common::telemetry_ops::app_telemetry::{AppBuildTelemetry, AppBuildTelemetryCollector};
14
+ use crate::common::telemetry_ops::cluster_telemetry::ClusterTelemetry;
15
+ use crate::common::telemetry_ops::collections_telemetry::CollectionsTelemetry;
16
+ use crate::common::telemetry_ops::memory_telemetry::MemoryTelemetry;
17
+ use crate::common::telemetry_ops::requests_telemetry::{
18
+ ActixTelemetryCollector, RequestsTelemetry, TonicTelemetryCollector,
19
+ };
20
+ use crate::settings::Settings;
21
+
22
// Aggregates all telemetry sub-collectors plus the context needed to
// assemble a full `TelemetryData` report in `prepare_data`.
pub struct TelemetryCollector {
    // Random id identifying this process in telemetry reports
    process_id: Uuid,
    settings: Settings,
    dispatcher: Arc<Dispatcher>,
    pub app_telemetry_collector: AppBuildTelemetryCollector,
    // Shared with HTTP workers (presumably actix handlers record per-worker
    // request stats into it — confirm at the registration sites)
    pub actix_telemetry_collector: Arc<Mutex<ActixTelemetryCollector>>,
    // Shared with gRPC (tonic) workers, same pattern as above
    pub tonic_telemetry_collector: Arc<Mutex<TonicTelemetryCollector>>,
}
30
+
31
// Whole telemetry data: everything reported by the telemetry endpoint,
// assembled by `TelemetryCollector::prepare_data`.
#[derive(Serialize, Clone, Debug, JsonSchema)]
pub struct TelemetryData {
    // Reporting id of this process (stringified UUID)
    id: String,
    pub(crate) app: AppBuildTelemetry,
    pub(crate) collections: CollectionsTelemetry,
    pub(crate) cluster: ClusterTelemetry,
    pub(crate) requests: RequestsTelemetry,
    // Only populated for detail levels above `Level0`
    pub(crate) memory: Option<MemoryTelemetry>,
}
41
+
42
+ impl Anonymize for TelemetryData {
43
+ fn anonymize(&self) -> Self {
44
+ TelemetryData {
45
+ id: self.id.clone(),
46
+ app: self.app.anonymize(),
47
+ collections: self.collections.anonymize(),
48
+ cluster: self.cluster.anonymize(),
49
+ requests: self.requests.anonymize(),
50
+ memory: self.memory.anonymize(),
51
+ }
52
+ }
53
+ }
54
+
55
+ impl TelemetryCollector {
56
+ pub fn reporting_id(&self) -> String {
57
+ self.process_id.to_string()
58
+ }
59
+
60
+ pub fn generate_id() -> Uuid {
61
+ Uuid::new_v4()
62
+ }
63
+
64
+ pub fn new(settings: Settings, dispatcher: Arc<Dispatcher>, id: Uuid) -> Self {
65
+ Self {
66
+ process_id: id,
67
+ settings,
68
+ dispatcher,
69
+ app_telemetry_collector: AppBuildTelemetryCollector::new(),
70
+ actix_telemetry_collector: Arc::new(Mutex::new(ActixTelemetryCollector {
71
+ workers: Vec::new(),
72
+ })),
73
+ tonic_telemetry_collector: Arc::new(Mutex::new(TonicTelemetryCollector {
74
+ workers: Vec::new(),
75
+ })),
76
+ }
77
+ }
78
+
79
+ pub async fn prepare_data(&self, access: &Access, detail: TelemetryDetail) -> TelemetryData {
80
+ TelemetryData {
81
+ id: self.process_id.to_string(),
82
+ collections: CollectionsTelemetry::collect(
83
+ detail,
84
+ access,
85
+ self.dispatcher
86
+ .toc(access, &new_unchecked_verification_pass()),
87
+ )
88
+ .await,
89
+ app: AppBuildTelemetry::collect(detail, &self.app_telemetry_collector, &self.settings),
90
+ cluster: ClusterTelemetry::collect(detail, &self.dispatcher, &self.settings),
91
+ requests: RequestsTelemetry::collect(
92
+ &self.actix_telemetry_collector.lock(),
93
+ &self.tonic_telemetry_collector.lock(),
94
+ detail,
95
+ ),
96
+ memory: (detail.level > DetailsLevel::Level0)
97
+ .then(MemoryTelemetry::collect)
98
+ .flatten(),
99
+ }
100
+ }
101
+ }