File size: 4,889 Bytes
d8435ba
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
use jsonwebtoken::errors::ErrorKind;
use jsonwebtoken::{decode, Algorithm, DecodingKey, Validation};
use validator::Validate;

use super::claims::Claims;
use super::AuthError;

/// Verifies and decodes HS256-signed JWTs into [`Claims`].
#[derive(Clone)]
pub struct JwtParser {
    /// Symmetric key built from the configured secret.
    key: DecodingKey,
    /// Rules applied by `jsonwebtoken::decode` (algorithm, `exp` leeway, required claims, ...).
    validation: Validation,
}

impl JwtParser {
    /// The only signing algorithm this parser accepts.
    const ALGORITHM: Algorithm = Algorithm::HS256;

    /// Builds a parser that verifies tokens signed with `secret`.
    ///
    /// Audience validation is disabled, the `exp` check tolerates 30 seconds of
    /// clock skew, and no claim is mandatory (tokens without `exp` are accepted).
    pub fn new(secret: &str) -> Self {
        let mut validation = Validation::new(Self::ALGORITHM);

        // The Qdrant server is the only audience, so `aud` is not checked
        validation.validate_aud = false;

        // Leeway (in seconds) on the expiration check to absorb clock skew
        validation.leeway = 30;

        // Nothing is required: every claim is optional
        validation.required_spec_claims.clear();

        Self {
            key: DecodingKey::from_secret(secret.as_bytes()),
            validation,
        }
    }

    /// Decodes `token` and returns its claims.
    ///
    /// Decoding already verifies the signature and the `exp` claim (with leeway);
    /// afterwards the claims' own `Validate` rules are checked.
    ///
    /// Returns:
    /// - `None` when the token doesn't look like a JWT (any decoding error other
    ///   than an expired token or an invalid signature);
    /// - `Some(Err(AuthError::Forbidden(_)))` for an expired token or a bad signature;
    /// - `Some(Err(AuthError::Unauthorized(_)))` when the claims fail validation;
    /// - `Some(Ok(claims))` otherwise.
    pub fn decode(&self, token: &str) -> Option<Result<Claims, AuthError>> {
        let claims = match decode::<Claims>(token, &self.key, &self.validation) {
            Ok(data) => data.claims,
            Err(err)
                if matches!(
                    err.kind(),
                    ErrorKind::ExpiredSignature | ErrorKind::InvalidSignature
                ) =>
            {
                return Some(Err(AuthError::Forbidden(err.to_string())));
            }
            // Any other decoding error: treat the input as "not a JWT"
            Err(_) => return None,
        };

        let outcome = match claims.validate() {
            Ok(()) => Ok(claims),
            Err(errors) => Err(AuthError::Unauthorized(errors.to_string())),
        };
        Some(outcome)
    }
}

#[cfg(test)]
mod tests {
    use segment::types::ValueVariants;
    use storage::rbac::{
        Access, CollectionAccess, CollectionAccessList, CollectionAccessMode, GlobalAccessMode,
        PayloadConstraint,
    };

    use super::*;

    /// Secret used both to sign test tokens and to build the parser under test.
    const SECRET: &str = "secret";

    /// Signs `claims` with [`SECRET`] using the algorithm the parser accepts.
    pub fn create_token(claims: &Claims) -> String {
        use jsonwebtoken::{encode, EncodingKey, Header};

        let key = EncodingKey::from_secret(SECRET.as_bytes());
        let header = Header::new(JwtParser::ALGORITHM);
        encode(&header, claims, &key).unwrap()
    }

    /// Current UNIX timestamp in whole seconds.
    fn now_secs() -> u64 {
        std::time::SystemTime::now()
            .duration_since(std::time::UNIX_EPOCH)
            .expect("Time went backwards")
            .as_secs()
    }

    /// Claims granting global read access with the given expiration.
    fn global_read_claims(exp: Option<u64>) -> Claims {
        Claims {
            exp,
            access: Access::Global(GlobalAccessMode::Read),
            value_exists: None,
        }
    }

    #[test]
    fn test_jwt_parser() {
        let claims = Claims {
            exp: Some(now_secs()),
            access: Access::Collection(CollectionAccessList(vec![CollectionAccess {
                collection: "collection".to_string(),
                access: CollectionAccessMode::ReadWrite,
                payload: Some(PayloadConstraint(
                    vec![
                        (
                            "field1".parse().unwrap(),
                            ValueVariants::String("value".to_string()),
                        ),
                        ("field2".parse().unwrap(), ValueVariants::Integer(42)),
                        ("field3".parse().unwrap(), ValueVariants::Bool(true)),
                    ]
                    .into_iter()
                    .collect(),
                )),
            }])),
            value_exists: None,
        };
        let token = create_token(&claims);

        let parser = JwtParser::new(SECRET);
        let decoded_claims = parser.decode(&token).unwrap().unwrap();

        assert_eq!(claims, decoded_claims);
    }

    #[test]
    fn test_exp_validation() {
        // 31 seconds in the past, bigger than the 30 seconds leeway
        let mut claims = global_read_claims(Some(now_secs() - 31));
        let token = create_token(&claims);

        let parser = JwtParser::new(SECRET);
        assert!(matches!(
            parser.decode(&token),
            Some(Err(AuthError::Forbidden(_)))
        ));

        // Remove the exp claim and it should work
        claims.exp = None;
        let token = create_token(&claims);

        let decoded_claims = parser.decode(&token).unwrap().unwrap();

        assert_eq!(claims, decoded_claims);
    }

    #[test]
    fn test_exp_within_leeway() {
        // Already expired by 10 seconds, but still inside the 30 seconds leeway,
        // so the token must be accepted. Guards against the leeway being dropped.
        let claims = global_read_claims(Some(now_secs() - 10));
        let token = create_token(&claims);

        let decoded_claims = JwtParser::new(SECRET).decode(&token).unwrap().unwrap();

        assert_eq!(claims, decoded_claims);
    }

    #[test]
    fn test_invalid_token() {
        let token = create_token(&global_read_claims(None));

        // Valid JWT signed with a different secret: rejected, not ignored
        assert!(matches!(
            JwtParser::new("wrong-secret").decode(&token),
            Some(Err(AuthError::Forbidden(_)))
        ));

        // Not a JWT at all: the parser declines to handle it
        assert!(JwtParser::new(SECRET).decode("foo.bar.baz").is_none());
    }
}