use std::sync::Arc; use std::task::{Context, Poll}; use futures::future::BoxFuture; use storage::rbac::Access; use tonic::body::BoxBody; use tonic::Status; use tower::{Layer, Service}; use crate::common::auth::{AuthError, AuthKeys}; type Request = tonic::codegen::http::Request; type Response = tonic::codegen::http::Response; #[derive(Clone)] pub struct AuthMiddleware { auth_keys: Arc, service: S, } async fn check(auth_keys: Arc, mut req: Request) -> Result { let access = auth_keys .validate_request(|key| req.headers().get(key).and_then(|val| val.to_str().ok())) .await .map_err(|e| match e { AuthError::Unauthorized(e) => Status::unauthenticated(e), AuthError::Forbidden(e) => Status::permission_denied(e), AuthError::StorageError(e) => Status::from(e), })?; let previous = req.extensions_mut().insert::(access); debug_assert!( previous.is_none(), "Previous access object should not exist in the request" ); Ok(req) } impl Service for AuthMiddleware where S: Service + Clone + Send + 'static, S::Future: Send + 'static, { type Response = S::Response; type Error = S::Error; type Future = BoxFuture<'static, Result>; fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll> { self.service.poll_ready(cx) } fn call(&mut self, request: Request) -> Self::Future { let auth_keys = self.auth_keys.clone(); let mut service = self.service.clone(); Box::pin(async move { match check(auth_keys, request).await { Ok(req) => service.call(req).await, Err(e) => Ok(e.to_http()), } }) } } #[derive(Clone)] pub struct AuthLayer { auth_keys: Arc, } impl AuthLayer { pub fn new(auth_keys: AuthKeys) -> Self { Self { auth_keys: Arc::new(auth_keys), } } } impl Layer for AuthLayer { type Service = AuthMiddleware; fn layer(&self, service: S) -> Self::Service { Self::Service { auth_keys: self.auth_keys.clone(), service, } } } pub fn extract_access(req: &mut tonic::Request) -> Access { req.extensions_mut().remove::().unwrap_or_else(|| { Access::full("All requests have full by default access when API key is not configured") }) }