Spaces:
Building
Building
# Copyright 2024 The etils Authors. | |
# | |
# Licensed under the Apache License, Version 2.0 (the "License"); | |
# you may not use this file except in compliance with the License. | |
# You may obtain a copy of the License at | |
# | |
# http://www.apache.org/licenses/LICENSE-2.0 | |
# | |
# Unless required by applicable law or agreed to in writing, software | |
# distributed under the License is distributed on an "AS IS" BASIS, | |
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |
# See the License for the specific language governing permissions and | |
# limitations under the License. | |
"""Contextvar util.""" | |
from __future__ import annotations | |
import contextvars | |
import dataclasses | |
import functools | |
import typing | |
from typing import Any, Optional, TypeVar | |
from etils.edc import helpers | |
from typing_extensions import Annotated | |
# Generic type variables used by the helpers in this module.
_T = TypeVar('_T')
_DataclassT = TypeVar('_DataclassT')
# Loose alias used in annotations for "some dataclass instance/type".
_Dataclass = Any

# Unique marker stored in the `Annotated[...]` metadata of `ContextVar[T]`.
# Presumably used elsewhere in edc to detect contextvar fields — verify callers.
_IS_CONTEXTVAR = object()

if typing.TYPE_CHECKING:
  # TODO(b/254514368): Remove hack
  # Type checkers should treat `ContextVar[T]` as plain `T`: subscripting the
  # class below hands back the subscript argument unchanged.
  class _ContextVarMeta(type):

    def __getitem__(cls, item):
      return item

  class ContextVar(metaclass=_ContextVarMeta):
    """Static-analysis stand-in: `ContextVar[T]` evaluates to `T`."""

else:
  # At runtime `ContextVar[T]` is `Annotated[T, _IS_CONTEXTVAR]`.
  ContextVar = Annotated[_T, _IS_CONTEXTVAR]  # pytype: disable=invalid-typevar
def make_contextvar_descriptor(
    field: dataclasses.Field[Any], hint: helpers.Hint
) -> _ContextvarDescriptor:
  """Wrap a `ContextVar[...]`-annotated dataclass field in a descriptor.

  Args:
    field: The dataclass field annotated with `ContextVar[...]`. Its
      `default` / `default_factory` seed the backing contextvar.
    hint: Type hint object for the field; accepted for interface
      compatibility and not used here.

  Returns:
    A `_ContextvarDescriptor` that reads and writes the field's value through
    a `contextvars.ContextVar`.
  """
  _ = hint  # Intentionally unused.
  descriptor = _ContextvarDescriptor(field)
  return descriptor
@dataclasses.dataclass
class _ContextvarDescriptor:
  """Descriptor to read-write individual contextvar.

  Attribute reads (`obj.attr`) and writes (`obj.attr = value`) are routed to a
  single `contextvars.ContextVar` owned by this descriptor. The stored value is
  therefore scoped to the current `contextvars.Context` and is shared by every
  instance of the owning class: `obj` is not used to key the value.

  Attributes:
    _field: The dataclass field this descriptor replaces. Its `default` or
      `default_factory` becomes the contextvar's default.
    _objtype: Owning class, set by `__set_name__`.
    _attribute_name: Attribute name on the owner, set by `__set_name__`; also
      used as the contextvar's name.
  """

  _field: dataclasses.Field[Any]
  _objtype: type[Any] = dataclasses.field(init=False)
  _attribute_name: str = dataclasses.field(init=False)

  def __set_name__(self, objtype: type[_Dataclass], name: str) -> None:
    """Bind the descriptor to the class (PEP 487)."""
    self._objtype = objtype
    self._attribute_name = name

  @functools.cached_property
  def _var(self) -> contextvars.ContextVar[Any]:
    """The `ContextVar` backing this attribute, created on first access.

    Cached so every get/set targets the same `ContextVar`; creating a new one
    per access would lose all previously set values.

    A `default_factory` is called exactly once, here, so the produced default
    object is shared across all contexts (and all instances). With neither a
    default nor a factory, reading before any set raises `LookupError`.
    """
    default_kwargs = {}
    if self._field.default is not dataclasses.MISSING:
      default_kwargs['default'] = self._field.default
    elif self._field.default_factory is not dataclasses.MISSING:
      default_kwargs['default'] = self._field.default_factory()
    return contextvars.ContextVar(self._attribute_name, **default_kwargs)

  def __get__(
      self,
      obj: Optional[_Dataclass],
      objtype: Optional[type[_Dataclass]] = None,
  ):
    """Return the value for the current context (or self on class access)."""
    if obj is None:
      # Class-level access (`Cls.attr`) exposes the descriptor itself.
      return self
    return self._var.get()

  def __set__(self, obj: _Dataclass, value: Any) -> None:
    """Set the value in the current context; `obj` is not used as a key."""
    self._var.set(value)