# Cookie jar implementation adhering to RFC 6265.
import asyncio
import calendar
import contextlib
import datetime
import heapq
import itertools
import os  # noqa
import pathlib
import pickle
import re
import time
from collections import defaultdict
from http.cookies import BaseCookie, Morsel, SimpleCookie
from typing import (
    DefaultDict,
    Dict,
    Iterable,
    Iterator,
    List,
    Mapping,
    Optional,
    Set,
    Tuple,
    Union,
    cast,
)

from yarl import URL

from .abc import AbstractCookieJar, ClearCookiePredicate
from .helpers import is_ip_address
from .typedefs import LooseCookies, PathLike, StrOrURL
__all__ = ("CookieJar", "DummyCookieJar")


# A cookie value as accepted by update_cookies(): either a bare string
# or an already-built Morsel.
CookieItem = Union[str, "Morsel[str]"]

# Bound str.format methods are cached at module level because they are
# called from performance-critical code; this hoists the attribute lookup.
_FORMAT_PATH = "{}/{}".format
_FORMAT_DOMAIN_REVERSED = "{1}.{0}".format

# The minimum number of scheduled cookie expirations before we start cleaning
# up the expiration heap. This is a performance optimization to avoid cleaning
# up the heap too often when there are only a few scheduled expirations.
_MIN_SCHEDULED_COOKIE_EXPIRATION = 100
class CookieJar(AbstractCookieJar):
    """Implements cookie storage adhering to RFC 6265."""

    # Tokenizer for cookie date strings: skips RFC 6265 delimiters and
    # captures each run of non-delimiter characters as a "token".
    DATE_TOKENS_RE = re.compile(
        r"[\x09\x20-\x2F\x3B-\x40\x5B-\x60\x7B-\x7E]*"
        r"(?P<token>[\x00-\x08\x0A-\x1F\d:a-zA-Z\x7F-\xFF]+)"
    )

    DATE_HMS_TIME_RE = re.compile(r"(\d{1,2}):(\d{1,2}):(\d{1,2})")

    DATE_DAY_OF_MONTH_RE = re.compile(r"(\d{1,2})")

    DATE_MONTH_RE = re.compile(
        "(jan)|(feb)|(mar)|(apr)|(may)|(jun)|(jul)|(aug)|(sep)|(oct)|(nov)|(dec)",
        re.I,
    )

    DATE_YEAR_RE = re.compile(r"(\d{2,4})")

    # calendar.timegm() fails for timestamps after datetime.datetime.max
    # Minus one as a loss of precision occurs when timestamp() is called.
    MAX_TIME = (
        int(datetime.datetime.max.replace(tzinfo=datetime.timezone.utc).timestamp()) - 1
    )
    try:
        calendar.timegm(time.gmtime(MAX_TIME))
    except (OSError, ValueError):
        # Hit the maximum representable time on Windows
        # https://learn.microsoft.com/en-us/cpp/c-runtime-library/reference/localtime-localtime32-localtime64
        # Throws ValueError on PyPy 3.8 and 3.9, OSError elsewhere
        MAX_TIME = calendar.timegm((3000, 12, 31, 23, 59, 59, -1, -1, -1))
    except OverflowError:
        # #4515: datetime.max may not be representable on 32-bit platforms
        MAX_TIME = 2**31 - 1
    # Avoid minuses in the future, 3x faster
    SUB_MAX_TIME = MAX_TIME - 1

    def __init__(
        self,
        *,
        unsafe: bool = False,
        quote_cookie: bool = True,
        treat_as_secure_origin: Union[StrOrURL, List[StrOrURL], None] = None,
        loop: Optional[asyncio.AbstractEventLoop] = None,
    ) -> None:
        """Initialize the cookie jar.

        unsafe: accept cookies from IP-address hosts when True.
        quote_cookie: use SimpleCookie (quoting) vs BaseCookie when filtering.
        treat_as_secure_origin: origin(s) treated as secure even over http.
        """
        super().__init__(loop=loop)
        # Cookies keyed by (domain, path) -> SimpleCookie of name -> Morsel.
        self._cookies: DefaultDict[Tuple[str, str], SimpleCookie] = defaultdict(
            SimpleCookie
        )
        # Cache of Morsels built by filter_cookies(), invalidated on update.
        self._morsel_cache: DefaultDict[Tuple[str, str], Dict[str, Morsel[str]]] = (
            defaultdict(dict)
        )
        # (domain, name) pairs whose cookies must only match the exact host.
        self._host_only_cookies: Set[Tuple[str, str]] = set()
        self._unsafe = unsafe
        self._quote_cookie = quote_cookie
        # Normalize treat_as_secure_origin to a list of URL origins.
        if treat_as_secure_origin is None:
            treat_as_secure_origin = []
        elif isinstance(treat_as_secure_origin, URL):
            treat_as_secure_origin = [treat_as_secure_origin.origin()]
        elif isinstance(treat_as_secure_origin, str):
            treat_as_secure_origin = [URL(treat_as_secure_origin).origin()]
        else:
            treat_as_secure_origin = [
                URL(url).origin() if isinstance(url, str) else url.origin()
                for url in treat_as_secure_origin
            ]
        self._treat_as_secure_origin = treat_as_secure_origin
        # Min-heap of (when, (domain, path, name)) scheduled expirations and
        # the authoritative mapping of the latest expiration per cookie key.
        self._expire_heap: List[Tuple[float, Tuple[str, str, str]]] = []
        self._expirations: Dict[Tuple[str, str, str], float] = {}

    def save(self, file_path: PathLike) -> None:
        """Pickle the cookie storage to *file_path*."""
        file_path = pathlib.Path(file_path)
        with file_path.open(mode="wb") as f:
            pickle.dump(self._cookies, f, pickle.HIGHEST_PROTOCOL)

    def load(self, file_path: PathLike) -> None:
        """Load pickled cookies from *file_path*.

        NOTE: unpickling untrusted files is unsafe; only load trusted data.
        """
        file_path = pathlib.Path(file_path)
        with file_path.open(mode="rb") as f:
            self._cookies = pickle.load(f)

    def clear(self, predicate: Optional[ClearCookiePredicate] = None) -> None:
        """Remove all cookies, or only those matching *predicate*."""
        if predicate is None:
            self._expire_heap.clear()
            self._cookies.clear()
            self._morsel_cache.clear()
            self._host_only_cookies.clear()
            self._expirations.clear()
            return

        now = time.time()
        # Delete cookies that are already expired OR match the predicate.
        to_del = [
            key
            for (domain, path), cookie in self._cookies.items()
            for name, morsel in cookie.items()
            if (
                (key := (domain, path, name)) in self._expirations
                and self._expirations[key] <= now
            )
            or predicate(morsel)
        ]
        if to_del:
            self._delete_cookies(to_del)

    def clear_domain(self, domain: str) -> None:
        """Remove all cookies whose domain matches *domain*."""
        self.clear(lambda x: self._is_domain_match(domain, x["domain"]))

    def __iter__(self) -> "Iterator[Morsel[str]]":
        self._do_expiration()
        for val in self._cookies.values():
            yield from val.values()

    def __len__(self) -> int:
        """Return number of cookies.

        This function does not iterate self to avoid unnecessary expiration
        checks.
        """
        return sum(len(cookie.values()) for cookie in self._cookies.values())

    def _do_expiration(self) -> None:
        """Remove expired cookies."""
        if not (expire_heap_len := len(self._expire_heap)):
            return
        # If the expiration heap grows larger than the number expirations
        # times two, we clean it up to avoid keeping expired entries in
        # the heap and consuming memory. We guard this with a minimum
        # threshold to avoid cleaning up the heap too often when there are
        # only a few scheduled expirations.
        if (
            expire_heap_len > _MIN_SCHEDULED_COOKIE_EXPIRATION
            and expire_heap_len > len(self._expirations) * 2
        ):
            # Remove any expired entries from the expiration heap
            # that do not match the expiration time in the expirations
            # as it means the cookie has been re-added to the heap
            # with a different expiration time.
            self._expire_heap = [
                entry
                for entry in self._expire_heap
                if self._expirations.get(entry[1]) == entry[0]
            ]
            heapq.heapify(self._expire_heap)

        now = time.time()
        to_del: List[Tuple[str, str, str]] = []
        # Find any expired cookies and add them to the to-delete list
        while self._expire_heap:
            when, cookie_key = self._expire_heap[0]
            if when > now:
                break
            heapq.heappop(self._expire_heap)
            # Check if the cookie hasn't been re-added to the heap
            # with a different expiration time as it will be removed
            # later when it reaches the top of the heap and its
            # expiration time is met.
            if self._expirations.get(cookie_key) == when:
                to_del.append(cookie_key)

        if to_del:
            self._delete_cookies(to_del)

    def _delete_cookies(self, to_del: List[Tuple[str, str, str]]) -> None:
        """Remove the given (domain, path, name) cookies from all indexes."""
        for domain, path, name in to_del:
            self._host_only_cookies.discard((domain, name))
            self._cookies[(domain, path)].pop(name, None)
            self._morsel_cache[(domain, path)].pop(name, None)
            self._expirations.pop((domain, path, name), None)

    def _expire_cookie(self, when: float, domain: str, path: str, name: str) -> None:
        """Schedule the cookie to expire at timestamp *when*."""
        cookie_key = (domain, path, name)
        if self._expirations.get(cookie_key) == when:
            # Avoid adding duplicates to the heap
            return
        heapq.heappush(self._expire_heap, (when, cookie_key))
        self._expirations[cookie_key] = when

    def update_cookies(self, cookies: LooseCookies, response_url: URL = URL()) -> None:
        """Update cookies."""
        hostname = response_url.raw_host

        if not self._unsafe and is_ip_address(hostname):
            # Don't accept cookies from IPs
            return

        if isinstance(cookies, Mapping):
            cookies = cookies.items()

        for name, cookie in cookies:
            if not isinstance(cookie, Morsel):
                tmp = SimpleCookie()
                tmp[name] = cookie  # type: ignore[assignment]
                cookie = tmp[name]

            domain = cookie["domain"]

            # ignore domains with trailing dots
            if domain and domain[-1] == ".":
                domain = ""
                del cookie["domain"]

            if not domain and hostname is not None:
                # Set the cookie's domain to the response hostname
                # and set its host-only-flag
                self._host_only_cookies.add((hostname, name))
                domain = cookie["domain"] = hostname

            if domain and domain[0] == ".":
                # Remove leading dot
                domain = domain[1:]
                cookie["domain"] = domain

            if hostname and not self._is_domain_match(domain, hostname):
                # Setting cookies for different domains is not allowed
                continue

            path = cookie["path"]
            if not path or path[0] != "/":
                # Set the cookie's path to the response path
                path = response_url.path
                if not path.startswith("/"):
                    path = "/"
                else:
                    # Cut everything from the last slash to the end
                    path = "/" + path[1 : path.rfind("/")]
                cookie["path"] = path
            path = path.rstrip("/")

            if max_age := cookie["max-age"]:
                try:
                    delta_seconds = int(max_age)
                    max_age_expiration = min(time.time() + delta_seconds, self.MAX_TIME)
                    self._expire_cookie(max_age_expiration, domain, path, name)
                except ValueError:
                    cookie["max-age"] = ""

            elif expires := cookie["expires"]:
                if expire_time := self._parse_date(expires):
                    self._expire_cookie(expire_time, domain, path, name)
                else:
                    cookie["expires"] = ""

            key = (domain, path)
            if self._cookies[key].get(name) != cookie:
                # Don't blow away the cache if the same
                # cookie gets set again
                self._cookies[key][name] = cookie
                self._morsel_cache[key].pop(name, None)

        self._do_expiration()

    def filter_cookies(self, request_url: URL = URL()) -> "BaseCookie[str]":
        """Returns this jar's cookies filtered by their attributes."""
        filtered: Union[SimpleCookie, "BaseCookie[str]"] = (
            SimpleCookie() if self._quote_cookie else BaseCookie()
        )
        if not self._cookies:
            # Skip do_expiration() if there are no cookies.
            return filtered
        self._do_expiration()
        if not self._cookies:
            # Skip rest of function if no non-expired cookies.
            return filtered
        request_url = URL(request_url)
        hostname = request_url.raw_host or ""

        is_not_secure = request_url.scheme not in ("https", "wss")
        if is_not_secure and self._treat_as_secure_origin:
            request_origin = URL()
            with contextlib.suppress(ValueError):
                request_origin = request_url.origin()
            is_not_secure = request_origin not in self._treat_as_secure_origin

        # Send shared cookie
        for c in self._cookies[("", "")].values():
            filtered[c.key] = c.value

        if is_ip_address(hostname):
            if not self._unsafe:
                return filtered
            domains: Iterable[str] = (hostname,)
        else:
            # Get all the subdomains that might match a cookie (e.g. "foo.bar.com", "bar.com", "com")
            domains = itertools.accumulate(
                reversed(hostname.split(".")), _FORMAT_DOMAIN_REVERSED
            )

        # Get all the path prefixes that might match a cookie (e.g. "", "/foo", "/foo/bar")
        paths = itertools.accumulate(request_url.path.split("/"), _FORMAT_PATH)
        # Create every combination of (domain, path) pairs.
        pairs = itertools.product(domains, paths)

        path_len = len(request_url.path)
        # Point 2: https://www.rfc-editor.org/rfc/rfc6265.html#section-5.4
        for p in pairs:
            for name, cookie in self._cookies[p].items():
                domain = cookie["domain"]

                if (domain, name) in self._host_only_cookies and domain != hostname:
                    continue

                # Skip edge case when the cookie has a trailing slash but request doesn't.
                if len(cookie["path"]) > path_len:
                    continue

                if is_not_secure and cookie["secure"]:
                    continue

                # We already built the Morsel so reuse it here
                if name in self._morsel_cache[p]:
                    filtered[name] = self._morsel_cache[p][name]
                    continue

                # It's critical we use the Morsel so the coded_value
                # (based on cookie version) is preserved
                mrsl_val = cast("Morsel[str]", cookie.get(cookie.key, Morsel()))
                mrsl_val.set(cookie.key, cookie.value, cookie.coded_value)
                self._morsel_cache[p][name] = mrsl_val
                filtered[name] = mrsl_val

        return filtered

    @staticmethod
    def _is_domain_match(domain: str, hostname: str) -> bool:
        """Implements domain matching adhering to RFC 6265."""
        if hostname == domain:
            return True

        if not hostname.endswith(domain):
            return False

        non_matching = hostname[: -len(domain)]

        if not non_matching.endswith("."):
            return False

        return not is_ip_address(hostname)

    @classmethod
    def _parse_date(cls, date_str: str) -> Optional[int]:
        """Implements date string parsing adhering to RFC 6265."""
        if not date_str:
            return None

        found_time = False
        found_day = False
        found_month = False
        found_year = False
        hour = minute = second = 0
        day = 0
        month = 0
        year = 0

        for token_match in cls.DATE_TOKENS_RE.finditer(date_str):
            token = token_match.group("token")

            if not found_time:
                time_match = cls.DATE_HMS_TIME_RE.match(token)
                if time_match:
                    found_time = True
                    hour, minute, second = (int(s) for s in time_match.groups())
                    continue

            if not found_day:
                day_match = cls.DATE_DAY_OF_MONTH_RE.match(token)
                if day_match:
                    found_day = True
                    day = int(day_match.group())
                    continue

            if not found_month:
                month_match = cls.DATE_MONTH_RE.match(token)
                if month_match:
                    found_month = True
                    assert month_match.lastindex is not None
                    month = month_match.lastindex
                    continue

            if not found_year:
                year_match = cls.DATE_YEAR_RE.match(token)
                if year_match:
                    found_year = True
                    year = int(year_match.group())

        # Two-digit years per RFC 6265 section 5.1.1.
        if 70 <= year <= 99:
            year += 1900
        elif 0 <= year <= 69:
            year += 2000

        if False in (found_day, found_month, found_year, found_time):
            return None

        if not 1 <= day <= 31:
            return None

        if year < 1601 or hour > 23 or minute > 59 or second > 59:
            return None

        return calendar.timegm((year, month, day, hour, minute, second, -1, -1, -1))
class DummyCookieJar(AbstractCookieJar):
    """Implements a dummy cookie storage.

    It can be used with the ClientSession when no cookie processing is needed.
    """

    def __init__(self, *, loop: Optional[asyncio.AbstractEventLoop] = None) -> None:
        super().__init__(loop=loop)

    def __iter__(self) -> "Iterator[Morsel[str]]":
        # Empty generator: the jar never holds any cookies.
        yield from ()

    def __len__(self) -> int:
        # Always empty.
        return 0

    def clear(self, predicate: Optional[ClearCookiePredicate] = None) -> None:
        """No-op: there is nothing to clear."""

    def clear_domain(self, domain: str) -> None:
        """No-op: there is nothing to clear."""

    def update_cookies(self, cookies: LooseCookies, response_url: URL = URL()) -> None:
        """No-op: incoming cookies are discarded."""

    def filter_cookies(self, request_url: URL) -> "BaseCookie[str]":
        """Always return an empty cookie collection."""
        return SimpleCookie()