Spaces:
Build error
Build error
File size: 4,946 Bytes
d8435ba |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 |
use std::sync::Arc;
use collection::operations::shard_selector_internal::ShardSelectorInternal;
use collection::operations::types::ScrollRequestInternal;
use segment::types::{WithPayloadInterface, WithVector};
use storage::content_manager::errors::StorageError;
use storage::content_manager::toc::TableOfContent;
use storage::rbac::Access;
use self::claims::{Claims, ValueExists};
use self::jwt_parser::JwtParser;
use super::strings::ct_eq;
use crate::settings::ServiceConfig;
pub mod claims;
pub mod jwt_parser;
/// Name of the HTTP header that carries the API key.
///
/// `AuthKeys::validate_request` looks this header up first and only falls back
/// to the `authorization` header when it is absent.
pub const HTTP_HEADER_API_KEY: &str = "api-key";
/// The API keys (and optional JWT parser) used to authenticate requests.
///
/// Built by `AuthKeys::try_create` from the service configuration; at least one
/// of the two keys is set whenever an instance exists.
#[derive(Clone)]
pub struct AuthKeys {
/// A key allowing Read or Write operations.
/// Copied from the `api_key` setting; `None` when that setting is absent.
read_write: Option<String>,
/// A key allowing Read operations only.
/// Copied from the `read_only_api_key` setting; `None` when that setting is absent.
read_only: Option<String>,
/// A JWT parser, based on the read_write key.
/// Only present when the `jwt_rbac` setting is enabled and an `api_key` is configured.
jwt_parser: Option<JwtParser>,
/// Table of content, needed to do stateful validation of JWT
/// (it is queried with a scroll request when a token carries a `value_exists` claim).
toc: Arc<TableOfContent>,
}
/// Reasons a request can fail authentication in `AuthKeys::validate_request`.
#[derive(Debug)]
pub enum AuthError {
/// No credentials were supplied, the supplied key/JWT matched nothing, or the
/// stateful JWT check found no matching point. Carries a human-readable message.
Unauthorized(String),
/// The stateful JWT check hit a `StorageError::NotFound`. Carries a human-readable message.
Forbidden(String),
/// Any other storage failure raised while performing the stateful JWT check.
StorageError(StorageError),
}
impl AuthKeys {
    /// Builds the JWT parser used for token-based access control.
    ///
    /// Returns `None` when the `jwt_rbac` setting is unset or `false`, or when no
    /// `api_key` is configured; otherwise the parser is created from the `api_key`.
    fn get_jwt_parser(service_config: &ServiceConfig) -> Option<JwtParser> {
        if !service_config.jwt_rbac.unwrap_or_default() {
            return None;
        }

        let secret = service_config.api_key.as_ref()?;
        Some(JwtParser::new(secret))
    }

    /// Defines the auth scheme given the service config
    ///
    /// Returns None if no scheme is specified, i.e. when neither `api_key` nor
    /// `read_only_api_key` is configured. Otherwise both keys are copied out of the
    /// config and the JWT parser is set up via [`Self::get_jwt_parser`].
    pub fn try_create(service_config: &ServiceConfig, toc: Arc<TableOfContent>) -> Option<Self> {
        match (
            service_config.api_key.clone(),
            service_config.read_only_api_key.clone(),
        ) {
            (None, None) => None,
            (read_write, read_only) => Some(Self {
                read_write,
                read_only,
                jwt_parser: Self::get_jwt_parser(service_config),
                toc,
            }),
        }
    }

    /// Validate that the specified request is allowed for given keys.
    ///
    /// `get_header` looks up a request header by name. The credential is taken
    /// from the `api-key` header or, when that is absent, from an `authorization`
    /// header using the `Bearer` scheme. The scheme name is matched
    /// case-insensitively, because HTTP authentication scheme names are
    /// case-insensitive; the token is everything after the first space.
    ///
    /// The credential is then checked, in order, against the read-write key, the
    /// read-only key, and finally (if a JWT parser is configured) decoded as a JWT.
    ///
    /// # Errors
    ///
    /// * [`AuthError::Unauthorized`] when no credential is supplied, or when it
    ///   matches no key and cannot be decoded as a JWT.
    /// * Any error produced while decoding the JWT claims.
    /// * Any error from the stateful `value_exists` check, see
    ///   [`Self::validate_value_exists`].
    pub async fn validate_request<'a>(
        &self,
        get_header: impl Fn(&'a str) -> Option<&'a str>,
    ) -> Result<Access, AuthError> {
        let Some(key) = get_header(HTTP_HEADER_API_KEY).or_else(|| {
            get_header("authorization").and_then(|value| {
                // Split at the first space only, so the token is passed on untouched.
                let (scheme, token) = value.split_once(' ')?;
                scheme.eq_ignore_ascii_case("Bearer").then_some(token)
            })
        }) else {
            return Err(AuthError::Unauthorized(
                "Must provide an API key or an Authorization bearer token".to_string(),
            ));
        };

        if self.can_write(key) {
            return Ok(Access::full("Read-write access by key"));
        }

        if self.can_read(key) {
            return Ok(Access::full_ro("Read-only access by key"));
        }

        // Not a plain key: try to interpret it as a JWT, if JWT support is configured.
        if let Some(claims) = self.jwt_parser.as_ref().and_then(|p| p.decode(key)) {
            let Claims {
                exp: _, // already validated on decoding
                access,
                value_exists,
            } = claims?;

            if let Some(value_exists) = value_exists {
                self.validate_value_exists(&value_exists).await?;
            }

            return Ok(access);
        }

        Err(AuthError::Unauthorized(
            "Invalid API key or JWT".to_string(),
        ))
    }

    /// Performs the stateful part of JWT validation for a `value_exists` claim.
    ///
    /// Issues a scroll request (limit 1, no payload, no vectors) against the
    /// collection named by the claim, filtered by the claim's filter, and succeeds
    /// only if at least one point comes back.
    ///
    /// # Errors
    ///
    /// * [`AuthError::Forbidden`] when the scroll fails with `StorageError::NotFound`.
    /// * [`AuthError::StorageError`] for any other scroll failure.
    /// * [`AuthError::Unauthorized`] when the scroll succeeds but returns no points.
    async fn validate_value_exists(&self, value_exists: &ValueExists) -> Result<(), AuthError> {
        let scroll_req = ScrollRequestInternal {
            offset: None,
            // A single hit is enough to prove the value exists.
            limit: Some(1),
            filter: Some(value_exists.to_filter()),
            with_payload: Some(WithPayloadInterface::Bool(false)),
            with_vector: WithVector::Bool(false),
            order_by: None,
        };

        let res = self
            .toc
            .scroll(
                value_exists.get_collection(),
                scroll_req,
                None,
                None, // no timeout
                ShardSelectorInternal::All,
                Access::full("JWT stateful validation"),
            )
            .await
            .map_err(|e| match e {
                StorageError::NotFound { .. } => {
                    AuthError::Forbidden("Invalid JWT, stateful validation failed".to_string())
                }
                _ => AuthError::StorageError(e),
            })?;

        // NOTE(review): a missing resource maps to `Forbidden` above while an empty
        // result maps to `Unauthorized` here, with the same message — confirm the
        // asymmetry is intended before unifying.
        if res.points.is_empty() {
            return Err(AuthError::Unauthorized(
                "Invalid JWT, stateful validation failed".to_string(),
            ));
        }

        Ok(())
    }

    /// Check if a key is allowed to read
    ///
    /// True only when a read-only key is configured and `ct_eq` reports it equal
    /// to `key` (presumably a constant-time comparison, going by its name).
    #[inline]
    fn can_read(&self, key: &str) -> bool {
        self.read_only
            .as_ref()
            .is_some_and(|ro_key| ct_eq(ro_key, key))
    }

    /// Check if a key is allowed to write
    ///
    /// True only when a read-write key is configured and `ct_eq` reports it equal
    /// to `key` (presumably a constant-time comparison, going by its name).
    #[inline]
    fn can_write(&self, key: &str) -> bool {
        self.read_write
            .as_ref()
            .is_some_and(|rw_key| ct_eq(rw_key, key))
    }
}
|