Spaces:
Build error
Build error
use jsonwebtoken::errors::ErrorKind; | |
use jsonwebtoken::{decode, Algorithm, DecodingKey, Validation}; | |
use validator::Validate; | |
use super::claims::Claims; | |
use super::AuthError; | |
pub struct JwtParser { | |
key: DecodingKey, | |
validation: Validation, | |
} | |
impl JwtParser { | |
const ALGORITHM: Algorithm = Algorithm::HS256; | |
pub fn new(secret: &str) -> Self { | |
let key = DecodingKey::from_secret(secret.as_bytes()); | |
let mut validation = Validation::new(Self::ALGORITHM); | |
// Qdrant server is the only audience | |
validation.validate_aud = false; | |
// Expiration time leeway to account for clock skew | |
validation.leeway = 30; | |
// All claims are optional | |
validation.required_spec_claims = Default::default(); | |
JwtParser { key, validation } | |
} | |
/// Decode the token and return the claims, this already validates the `exp` claim with some leeway. | |
/// Returns None when the token doesn't look like a JWT. | |
pub fn decode(&self, token: &str) -> Option<Result<Claims, AuthError>> { | |
let claims = match decode::<Claims>(token, &self.key, &self.validation) { | |
Ok(token_data) => token_data.claims, | |
Err(e) => { | |
return match e.kind() { | |
ErrorKind::ExpiredSignature | ErrorKind::InvalidSignature => { | |
Some(Err(AuthError::Forbidden(e.to_string()))) | |
} | |
_ => None, | |
} | |
} | |
}; | |
if let Err(e) = claims.validate() { | |
return Some(Err(AuthError::Unauthorized(e.to_string()))); | |
} | |
Some(Ok(claims)) | |
} | |
} | |
mod tests { | |
use segment::types::ValueVariants; | |
use storage::rbac::{ | |
Access, CollectionAccess, CollectionAccessList, CollectionAccessMode, GlobalAccessMode, | |
PayloadConstraint, | |
}; | |
use super::*; | |
pub fn create_token(claims: &Claims) -> String { | |
use jsonwebtoken::{encode, EncodingKey, Header}; | |
let key = EncodingKey::from_secret("secret".as_ref()); | |
let header = Header::new(JwtParser::ALGORITHM); | |
encode(&header, claims, &key).unwrap() | |
} | |
fn test_jwt_parser() { | |
let exp = std::time::SystemTime::now() | |
.duration_since(std::time::UNIX_EPOCH) | |
.expect("Time went backwards") | |
.as_secs(); | |
let claims = Claims { | |
exp: Some(exp), | |
access: Access::Collection(CollectionAccessList(vec![CollectionAccess { | |
collection: "collection".to_string(), | |
access: CollectionAccessMode::ReadWrite, | |
payload: Some(PayloadConstraint( | |
vec![ | |
( | |
"field1".parse().unwrap(), | |
ValueVariants::String("value".to_string()), | |
), | |
("field2".parse().unwrap(), ValueVariants::Integer(42)), | |
("field3".parse().unwrap(), ValueVariants::Bool(true)), | |
] | |
.into_iter() | |
.collect(), | |
)), | |
}])), | |
value_exists: None, | |
}; | |
let token = create_token(&claims); | |
let secret = "secret"; | |
let parser = JwtParser::new(secret); | |
let decoded_claims = parser.decode(&token).unwrap().unwrap(); | |
assert_eq!(claims, decoded_claims); | |
} | |
fn test_exp_validation() { | |
let exp = std::time::SystemTime::now() | |
.duration_since(std::time::UNIX_EPOCH) | |
.expect("Time went backwards") | |
.as_secs() | |
- 31; // 31 seconds in the past, bigger than the 30 seconds leeway | |
let mut claims = Claims { | |
exp: Some(exp), | |
access: Access::Global(GlobalAccessMode::Read), | |
value_exists: None, | |
}; | |
let token = create_token(&claims); | |
let secret = "secret"; | |
let parser = JwtParser::new(secret); | |
assert!(matches!( | |
parser.decode(&token), | |
Some(Err(AuthError::Forbidden(_))) | |
)); | |
// Remove the exp claim and it should work | |
claims.exp = None; | |
let token = create_token(&claims); | |
let decoded_claims = parser.decode(&token).unwrap().unwrap(); | |
assert_eq!(claims, decoded_claims); | |
} | |
fn test_invalid_token() { | |
let claims = Claims { | |
exp: None, | |
access: Access::Global(GlobalAccessMode::Read), | |
value_exists: None, | |
}; | |
let token = create_token(&claims); | |
assert!(matches!( | |
JwtParser::new("wrong-secret").decode(&token), | |
Some(Err(AuthError::Forbidden(_))) | |
)); | |
assert!(JwtParser::new("secret").decode("foo.bar.baz").is_none()); | |
} | |
} | |