Omkar008 commited on
Commit
462db34
·
verified ·
1 Parent(s): 3bf8504

Create fastapi_globals.py

Browse files
Files changed (1) hide show
  1. models/fastapi_globals.py +91 -0
models/fastapi_globals.py ADDED
@@ -0,0 +1,91 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from collections.abc import Awaitable, Callable
2
+ from contextvars import ContextVar, copy_context
3
+ from typing import Any
4
+
5
+ from fastapi import Request, Response
6
+ from starlette.middleware.base import BaseHTTPMiddleware
7
+ from starlette.types import ASGIApp
8
+
9
+
10
+ class Globals:
11
+ __slots__ = ("_vars", "_defaults")
12
+
13
+ _vars: dict[str, ContextVar]
14
+ _defaults: dict[str, Any]
15
+
16
+ def __init__(self) -> None:
17
+ object.__setattr__(self, "_vars", {})
18
+ object.__setattr__(self, "_defaults", {})
19
+
20
+ def cleanup(self):
21
+ """Clear all variables and free memory."""
22
+
23
+ self._vars.clear()
24
+ self._defaults.clear()
25
+ del self._vars
26
+ del self._defaults
27
+
28
+ def set_default(self, name: str, default: Any) -> None:
29
+ """Set a default value for a variable."""
30
+
31
+ # Ignore if default is already set and is the same value
32
+ if name in self._defaults and default is self._defaults[name]:
33
+ return
34
+
35
+ # Ensure we don't have a value set already - the default will have
36
+ # no effect then
37
+ if name in self._vars:
38
+ raise RuntimeError(
39
+ f"Cannot set default as variable {name} was already set",
40
+ )
41
+
42
+ # Set the default already!
43
+ self._defaults[name] = default
44
+
45
+ def _get_default_value(self, name: str) -> Any:
46
+ """Get the default value for a variable."""
47
+
48
+ default = self._defaults.get(name, None)
49
+
50
+ # return default() if callable(default) else default
51
+ return default
52
+
53
+ def _ensure_var(self, name: str) -> None:
54
+ """Ensure a ContextVar exists for a variable."""
55
+ if name not in self._vars:
56
+ default = self._get_default_value(name)
57
+ self._vars[name] = ContextVar(f"globals:{name}", default=default)
58
+
59
+ def __getattr__(self, name: str) -> Any:
60
+ """Get the value of a variable."""
61
+
62
+ self._ensure_var(name)
63
+ return self._vars[name].get()
64
+
65
+ def __setattr__(self, name: str, value: Any) -> None:
66
+ """Set the value of a variable."""
67
+
68
+ self._ensure_var(name)
69
+ self._vars[name].set(value)
70
+
71
+
72
+ async def globals_middleware_dispatch(
73
+ request: Request,
74
+ call_next: Callable,
75
+ ) -> Response:
76
+ """Dispatch the request in a new context to allow globals to be used."""
77
+
78
+ ctx = copy_context()
79
+
80
+ def _call_next() -> Awaitable[Response]:
81
+ return call_next(request)
82
+
83
+ return await ctx.run(_call_next)
84
+
85
+
86
+ class GlobalsMiddleware(BaseHTTPMiddleware): # noqa
87
+ """Middleware to setup the globals context using globals_middleware_dispatch()."""
88
+
89
+ def __init__(self, app: ASGIApp) -> None:
90
+ super().__init__(app, globals_middleware_dispatch)
91
+ g = Globals()