|
import asyncio |
|
import contextlib |
|
import datetime |
|
import os |
|
import pathlib |
|
import pickle |
|
import re |
|
from collections import defaultdict |
|
from http.cookies import BaseCookie, Morsel, SimpleCookie |
|
from typing import ( |
|
DefaultDict, |
|
Dict, |
|
Iterable, |
|
Iterator, |
|
List, |
|
Mapping, |
|
Optional, |
|
Set, |
|
Tuple, |
|
Union, |
|
cast, |
|
) |
|
|
|
from yarl import URL |
|
|
|
from .abc import AbstractCookieJar, ClearCookiePredicate |
|
from .helpers import is_ip_address, next_whole_second |
|
from .typedefs import LooseCookies, PathLike, StrOrURL |
|
|
|
# Names exported by ``from <module> import *``.
__all__ = (
    "CookieJar",
    "DummyCookieJar",
)


# Type alias for a single cookie: either a plain string value or a Morsel.
CookieItem = Union[str, "Morsel[str]"]
|
|
|
|
|
class CookieJar(AbstractCookieJar):
    """Implements cookie storage adhering to RFC 6265."""

    # Splits a cookie-date into tokens: a run of delimiter characters
    # followed by a run of non-delimiter characters (the named "token" group).
    DATE_TOKENS_RE = re.compile(
        r"[\x09\x20-\x2F\x3B-\x40\x5B-\x60\x7B-\x7E]*"
        r"(?P<token>[\x00-\x08\x0A-\x1F\d:a-zA-Z\x7F-\xFF]+)"
    )

    # RFC 6265 (section 5.1.1) requires each numeric date production to be
    # followed by a non-digit or the end of the token. The ``(?!\d)``
    # lookaheads enforce that. Without them the day-of-month pattern accepts
    # "2024" as day 20, and the year pattern accepts the first four digits
    # of "19945".
    DATE_HMS_TIME_RE = re.compile(r"(\d{1,2}):(\d{1,2}):(\d{1,2})(?!\d)")

    DATE_DAY_OF_MONTH_RE = re.compile(r"(\d{1,2})(?!\d)")

    # The index of the alternative that matched (1-12) is the month number.
    DATE_MONTH_RE = re.compile(
        "(jan)|(feb)|(mar)|(apr)|(may)|(jun)|(jul)|" "(aug)|(sep)|(oct)|(nov)|(dec)",
        re.I,
    )

    DATE_YEAR_RE = re.compile(r"(\d{2,4})(?!\d)")

    MAX_TIME = datetime.datetime.max.replace(tzinfo=datetime.timezone.utc)

    # Used instead of MAX_TIME when MAX_TIME.timestamp() overflows. It must be
    # timezone-aware: the jar compares it with aware datetimes (in
    # _expire_cookie and clear), and a naive value raises TypeError there.
    # (datetime.utcfromtimestamp, used previously, also returns a naive value
    # and is deprecated since Python 3.12.)
    MAX_32BIT_TIME = datetime.datetime.fromtimestamp(
        2**31 - 1, datetime.timezone.utc
    )

    def __init__(
        self,
        *,
        unsafe: bool = False,
        quote_cookie: bool = True,
        treat_as_secure_origin: Union[StrOrURL, List[StrOrURL], None] = None,
        loop: Optional[asyncio.AbstractEventLoop] = None,
    ) -> None:
        """Create an empty cookie jar.

        :param unsafe: if true, cookies are accepted from and sent to hosts
            that are IP addresses; otherwise such hosts are skipped.
        :param quote_cookie: if true, filter_cookies returns a SimpleCookie;
            otherwise it returns a plain BaseCookie.
        :param treat_as_secure_origin: one origin or a list of origins whose
            requests may receive ``secure`` cookies even over a scheme other
            than https/wss. Strings are parsed as URLs; only the origin is kept.
        :param loop: forwarded to AbstractCookieJar.
        """
        super().__init__(loop=loop)
        self._cookies: DefaultDict[Tuple[str, str], SimpleCookie[str]] = defaultdict(
            SimpleCookie
        )
        self._host_only_cookies: Set[Tuple[str, str]] = set()
        self._unsafe = unsafe
        self._quote_cookie = quote_cookie
        if treat_as_secure_origin is None:
            treat_as_secure_origin = []
        elif isinstance(treat_as_secure_origin, URL):
            treat_as_secure_origin = [treat_as_secure_origin.origin()]
        elif isinstance(treat_as_secure_origin, str):
            treat_as_secure_origin = [URL(treat_as_secure_origin).origin()]
        else:
            treat_as_secure_origin = [
                URL(url).origin() if isinstance(url, str) else url.origin()
                for url in treat_as_secure_origin
            ]
        self._treat_as_secure_origin = treat_as_secure_origin
        self._next_expiration = next_whole_second()
        # (domain, path, name) -> absolute expiration time (UTC-aware).
        self._expirations: Dict[Tuple[str, str, str], datetime.datetime] = {}

        # Some platforms cannot represent datetime.max as a timestamp; fall
        # back to the largest signed 32-bit time there.
        self._max_time = self.MAX_TIME
        try:
            self._max_time.timestamp()
        except OverflowError:
            self._max_time = self.MAX_32BIT_TIME

    def save(self, file_path: PathLike) -> None:
        """Pickle the stored cookies to *file_path*.

        Only the cookie mapping is written; expiration times and host-only
        flags are not persisted.
        """
        file_path = pathlib.Path(file_path)
        with file_path.open(mode="wb") as f:
            pickle.dump(self._cookies, f, pickle.HIGHEST_PROTOCOL)

    def load(self, file_path: PathLike) -> None:
        """Replace the stored cookies with those pickled in *file_path*.

        SECURITY: ``pickle.load`` can execute arbitrary code. Only load files
        written by a trusted party (e.g. by :meth:`save`).

        Expiration times and host-only flags are not part of the saved data,
        so they are not restored here.
        """
        file_path = pathlib.Path(file_path)
        with file_path.open(mode="rb") as f:
            self._cookies = pickle.load(f)

    def clear(self, predicate: Optional[ClearCookiePredicate] = None) -> None:
        """Remove cookies from the jar.

        With no *predicate*, every cookie and all bookkeeping are dropped.
        Otherwise a cookie is removed when ``predicate(morsel)`` is true or
        when its recorded expiration time has passed.
        """
        if predicate is None:
            self._next_expiration = next_whole_second()
            self._cookies.clear()
            self._host_only_cookies.clear()
            self._expirations.clear()
            return

        to_del = []
        now = datetime.datetime.now(datetime.timezone.utc)
        for (domain, path), cookie in self._cookies.items():
            for name, morsel in cookie.items():
                key = (domain, path, name)
                if (
                    key in self._expirations and self._expirations[key] <= now
                ) or predicate(morsel):
                    to_del.append(key)

        # Delete after the scan so no dict is mutated while being iterated.
        for domain, path, name in to_del:
            # NOTE(review): host-only flags are keyed by (domain, name) only,
            # so this also clears the flag of a same-named cookie stored under
            # a different path.
            self._host_only_cookies.discard((domain, name))
            self._expirations.pop((domain, path, name), None)
            self._cookies[(domain, path)].pop(name, None)

        next_expiration = min(self._expirations.values(), default=self._max_time)
        try:
            self._next_expiration = next_expiration.replace(
                microsecond=0
            ) + datetime.timedelta(seconds=1)
        except OverflowError:
            # _max_time may be datetime.max, which cannot be advanced.
            self._next_expiration = self._max_time

    def clear_domain(self, domain: str) -> None:
        """Remove every cookie whose domain domain-matches *domain*."""
        self.clear(lambda x: self._is_domain_match(domain, x["domain"]))

    def __iter__(self) -> "Iterator[Morsel[str]]":
        """Yield every stored Morsel after dropping expired cookies."""
        self._do_expiration()
        for val in self._cookies.values():
            yield from val.values()

    def __len__(self) -> int:
        """Return the number of unexpired cookies."""
        return sum(1 for i in self)

    def _do_expiration(self) -> None:
        """Remove only the cookies whose expiration time has passed."""
        self.clear(lambda x: False)

    def _expire_cookie(
        self, when: datetime.datetime, domain: str, path: str, name: str
    ) -> None:
        """Record that cookie (domain, path, name) expires at *when*."""
        self._next_expiration = min(self._next_expiration, when)
        self._expirations[(domain, path, name)] = when

    def update_cookies(self, cookies: LooseCookies, response_url: URL = URL()) -> None:
        """Update cookies.

        Store *cookies* received in a response from *response_url*. Each
        cookie's domain and path are normalised and its Max-Age/Expires
        attribute is turned into an expiration time. If the jar is not
        unsafe, cookies from IP-address hosts are ignored.
        """
        hostname = response_url.raw_host

        if not self._unsafe and is_ip_address(hostname):
            # Don't accept cookies from IPs
            return

        if isinstance(cookies, Mapping):
            cookies = cookies.items()

        for name, cookie in cookies:
            if not isinstance(cookie, Morsel):
                tmp: SimpleCookie[str] = SimpleCookie()
                tmp[name] = cookie
                cookie = tmp[name]

            domain = cookie["domain"]

            # Ignore domains with trailing dots. Reset the attribute instead
            # of deleting it: Morsel is a dict, so ``del`` would remove the
            # key and later ``cookie["domain"]`` lookups (filter_cookies,
            # clear_domain) would raise KeyError when there is no hostname to
            # fill it back in below.
            if domain.endswith("."):
                domain = ""
                cookie["domain"] = ""

            if not domain and hostname is not None:
                # Set the cookie's domain to the response hostname
                # and set its host-only-flag
                self._host_only_cookies.add((hostname, name))
                domain = cookie["domain"] = hostname

            if domain.startswith("."):
                # Remove leading dot
                domain = domain[1:]
                cookie["domain"] = domain

            if hostname and not self._is_domain_match(domain, hostname):
                # Setting cookies for different domains is not allowed
                continue

            path = cookie["path"]
            if not path or not path.startswith("/"):
                # Set the cookie's path to the response path
                path = response_url.path
                if not path.startswith("/"):
                    path = "/"
                else:
                    # Cut everything from the last slash to the end
                    path = "/" + path[1 : path.rfind("/")]
                cookie["path"] = path

            max_age = cookie["max-age"]
            if max_age:
                try:
                    delta_seconds = int(max_age)
                    try:
                        max_age_expiration = datetime.datetime.now(
                            datetime.timezone.utc
                        ) + datetime.timedelta(seconds=delta_seconds)
                    except OverflowError:
                        max_age_expiration = self._max_time
                    self._expire_cookie(max_age_expiration, domain, path, name)
                except ValueError:
                    # Non-integer Max-Age: discard the attribute.
                    cookie["max-age"] = ""

            else:
                expires = cookie["expires"]
                if expires:
                    expire_time = self._parse_date(expires)
                    if expire_time:
                        self._expire_cookie(expire_time, domain, path, name)
                    else:
                        # Unparseable Expires: discard the attribute.
                        cookie["expires"] = ""

            self._cookies[(domain, path)][name] = cookie

        self._do_expiration()

    def filter_cookies(
        self, request_url: URL = URL()
    ) -> Union["BaseCookie[str]", "SimpleCookie[str]"]:
        """Returns this jar's cookies filtered by their attributes.

        A cookie is included when its domain, path and ``secure`` attribute
        allow it to be sent to *request_url*. Cookies with an empty domain
        are always included.
        """
        self._do_expiration()
        request_url = URL(request_url)
        filtered: Union["SimpleCookie[str]", "BaseCookie[str]"] = (
            SimpleCookie() if self._quote_cookie else BaseCookie()
        )
        hostname = request_url.raw_host or ""
        request_path = request_url.path
        request_origin = URL()
        with contextlib.suppress(ValueError):
            # origin() raises ValueError for relative URLs.
            request_origin = request_url.origin()

        is_not_secure = (
            request_url.scheme not in ("https", "wss")
            and request_origin not in self._treat_as_secure_origin
        )
        # Independent of the cookie, so evaluate once rather than per cookie.
        skip_ip_host = not self._unsafe and is_ip_address(hostname)

        # Sorting by path length means that, for cookies sharing a name, the
        # one with the longest path is assigned last and therefore wins.
        for cookie in sorted(self, key=lambda c: len(c["path"])):
            name = cookie.key
            domain = cookie["domain"]

            # Send shared cookies
            if not domain:
                filtered[name] = cookie.value
                continue

            if skip_ip_host:
                continue

            if (domain, name) in self._host_only_cookies:
                if domain != hostname:
                    continue
            elif not self._is_domain_match(domain, hostname):
                continue

            if not self._is_path_match(request_path, cookie["path"]):
                continue

            if is_not_secure and cookie["secure"]:
                continue

            # Copy the key, value and coded_value into a fresh Morsel so the
            # already-coded value is preserved instead of being re-encoded by
            # the container. (The previous ``cookie.get(cookie.key, Morsel())``
            # always produced a new Morsel, because a cookie name can never be
            # one of the reserved attribute keys stored in the Morsel dict.)
            mrsl_val: "Morsel[str]" = Morsel()
            mrsl_val.set(cookie.key, cookie.value, cookie.coded_value)
            filtered[name] = mrsl_val

        return filtered

    @staticmethod
    def _is_domain_match(domain: str, hostname: str) -> bool:
        """Implements domain matching adhering to RFC 6265.

        True if *hostname* equals *domain*, or ends with ``"." + domain`` and
        is not an IP address.
        """
        if hostname == domain:
            return True

        if not hostname.endswith(domain):
            return False

        non_matching = hostname[: -len(domain)]

        if not non_matching.endswith("."):
            return False

        return not is_ip_address(hostname)

    @staticmethod
    def _is_path_match(req_path: str, cookie_path: str) -> bool:
        """Implements path matching adhering to RFC 6265.

        A request path not starting with "/" is treated as "/".
        """
        if not req_path.startswith("/"):
            req_path = "/"

        if req_path == cookie_path:
            return True

        if not req_path.startswith(cookie_path):
            return False

        if cookie_path.endswith("/"):
            return True

        non_matching = req_path[len(cookie_path) :]

        return non_matching.startswith("/")

    @classmethod
    def _parse_date(cls, date_str: str) -> Optional[datetime.datetime]:
        """Implements date string parsing adhering to RFC 6265.

        Returns a UTC-aware datetime, or None when *date_str* is empty,
        lacks any of time/day/month/year, has an out-of-range field, or
        names a date that does not exist (e.g. 31 February).
        """
        if not date_str:
            return None

        found_time = False
        found_day = False
        found_month = False
        found_year = False

        hour = minute = second = 0
        day = 0
        month = 0
        year = 0

        for token_match in cls.DATE_TOKENS_RE.finditer(date_str):
            token = token_match.group("token")

            # Each token is consumed by the first still-missing field it
            # matches, in the order time, day, month, year.
            if not found_time:
                time_match = cls.DATE_HMS_TIME_RE.match(token)
                if time_match:
                    found_time = True
                    hour, minute, second = (int(s) for s in time_match.groups())
                    continue

            if not found_day:
                day_match = cls.DATE_DAY_OF_MONTH_RE.match(token)
                if day_match:
                    found_day = True
                    day = int(day_match.group())
                    continue

            if not found_month:
                month_match = cls.DATE_MONTH_RE.match(token)
                if month_match:
                    found_month = True
                    assert month_match.lastindex is not None
                    month = month_match.lastindex
                    continue

            if not found_year:
                year_match = cls.DATE_YEAR_RE.match(token)
                if year_match:
                    found_year = True
                    year = int(year_match.group())

        # Two-digit years: 70-99 -> 19xx, 00-69 -> 20xx.
        if 70 <= year <= 99:
            year += 1900
        elif 0 <= year <= 69:
            year += 2000

        if not (found_day and found_month and found_year and found_time):
            return None

        if not 1 <= day <= 31:
            return None

        if year < 1601 or hour > 23 or minute > 59 or second > 59:
            return None

        try:
            return datetime.datetime(
                year, month, day, hour, minute, second, tzinfo=datetime.timezone.utc
            )
        except ValueError:
            # Every field is individually in range but the calendar date does
            # not exist (e.g. 31 Feb). RFC 6265 says parsing fails; returning
            # None keeps update_cookies from raising on such an Expires value.
            return None
|
|
|
|
|
class DummyCookieJar(AbstractCookieJar):
    """Cookie jar that never stores anything.

    Useful with the ClientSession when cookie processing should be disabled:
    it is always empty, ignores updates and never returns cookies for a
    request.
    """

    def __init__(self, *, loop: Optional[asyncio.AbstractEventLoop] = None) -> None:
        super().__init__(loop=loop)

    def __iter__(self) -> "Iterator[Morsel[str]]":
        # A generator function that yields nothing: each iter(jar) call gets
        # a fresh generator that is immediately exhausted.
        yield from ()

    def __len__(self) -> int:
        return 0

    def clear(self, predicate: Optional[ClearCookiePredicate] = None) -> None:
        """No-op: the jar never holds cookies."""

    def clear_domain(self, domain: str) -> None:
        """No-op: the jar never holds cookies."""

    def update_cookies(self, cookies: LooseCookies, response_url: URL = URL()) -> None:
        """Discard *cookies* without storing them."""

    def filter_cookies(self, request_url: URL) -> "BaseCookie[str]":
        """Return a new, empty SimpleCookie for any *request_url*."""
        no_cookies: "SimpleCookie[str]" = SimpleCookie()
        return no_cookies
|
|