File size: 5,971 Bytes
7885a28 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 |
from __future__ import annotations
import datetime as dt
from typing import (
TYPE_CHECKING,
Any,
cast,
)
import numpy as np
from pandas.core.dtypes.dtypes import register_extension_dtype
from pandas.api.extensions import (
ExtensionArray,
ExtensionDtype,
)
from pandas.api.types import pandas_dtype
if TYPE_CHECKING:
from collections.abc import Sequence
from pandas._typing import (
Dtype,
PositionalIndexer,
)
@register_extension_dtype
class DateDtype(ExtensionDtype):
    """Extension dtype whose scalar values are ``datetime.date`` objects.

    The dtype is identified by the string ``"DateDtype"``, and its missing
    value sentinel is ``datetime.date.min``.
    """

    @property
    def type(self):
        """Scalar type stored in arrays of this dtype."""
        return dt.date

    @property
    def name(self):
        """String identifier of the dtype."""
        return "DateDtype"

    @classmethod
    def construct_from_string(cls, string: str):
        """Build an instance from its string name.

        Only the exact class name is accepted.

        Raises
        ------
        TypeError
            If ``string`` is not a ``str`` or does not equal the class name.
        """
        if not isinstance(string, str):
            raise TypeError(
                f"'construct_from_string' expects a string, got {type(string)}"
            )
        if string != cls.__name__:
            raise TypeError(f"Cannot construct a '{cls.__name__}' from '{string}'")
        return cls()

    @classmethod
    def construct_array_type(cls):
        """Array class that holds values of this dtype."""
        return DateArray

    @property
    def na_value(self):
        """Sentinel used to represent a missing date."""
        return dt.date.min

    def __repr__(self) -> str:
        # The repr is simply the dtype's string identifier.
        return self.name
class DateArray(ExtensionArray):
    """Extension array of ``datetime.date`` values stored column-wise.

    Years live in a ``uint16`` numpy array; months and days live in ``uint8``
    numpy arrays. All three arrays always have the same length.
    ``datetime.date.min`` is treated as the missing-value sentinel (see
    ``isna``).
    """

    def __init__(
        self,
        dates: (
            dt.date
            | Sequence[dt.date]
            | tuple[np.ndarray, np.ndarray, np.ndarray]
            | np.ndarray
        ),
    ) -> None:
        """Build the array from one of several input forms.

        Parameters
        ----------
        dates
            * a single ``datetime.date`` -> array of length one;
            * a ``list`` of ``datetime.date`` objects;
            * a 3-tuple ``(years, months, days)`` of equal-length numpy arrays;
            * a numpy array of dtype ``"U10"`` holding ``"yyyy-mm-dd"`` strings.

        Raises
        ------
        ValueError
            If a tuple is not a triple, or its members differ in length.
        TypeError
            If a tuple member is not a numpy array, or ``dates`` has an
            unsupported type.
        """
        if isinstance(dates, dt.date):
            # Same column dtypes as every other branch; the day column must be
            # filled from ``day`` (it previously copied ``year``).
            self._year = np.array([dates.year], dtype=np.uint16)
            self._month = np.array([dates.month], dtype=np.uint8)
            self._day = np.array([dates.day], dtype=np.uint8)
            return

        if isinstance(dates, list):
            ldates = len(dates)
            # pre-allocate the arrays since we know the size beforehand
            self._year = np.zeros(ldates, dtype=np.uint16)  # years fit in 0..9999
            self._month = np.zeros(ldates, dtype=np.uint8)  # months are 1..12
            self._day = np.zeros(ldates, dtype=np.uint8)  # days are 1..31
            for i, date in enumerate(dates):
                self._year[i] = date.year
                self._month[i] = date.month
                self._day[i] = date.day
        elif isinstance(dates, tuple):
            # only support triples
            if len(dates) != 3:
                raise ValueError("only triples are valid")
            # check if all elements have the same type
            if any(not isinstance(x, np.ndarray) for x in dates):
                raise TypeError("invalid type")
            ly, lm, ld = (len(d) for d in dates)
            if not ly == lm == ld:
                raise ValueError(
                    f"tuple members must have the same length: {(ly, lm, ld)}"
                )
            # ``astype`` returns new arrays, so the caller's arrays are not shared.
            self._year = dates[0].astype(np.uint16)
            self._month = dates[1].astype(np.uint8)
            self._day = dates[2].astype(np.uint8)
        elif isinstance(dates, np.ndarray) and dates.dtype == "U10":
            ldates = len(dates)
            self._year = np.zeros(ldates, dtype=np.uint16)  # years fit in 0..9999
            self._month = np.zeros(ldates, dtype=np.uint8)  # months are 1..12
            self._day = np.zeros(ldates, dtype=np.uint8)  # days are 1..31
            # Each element becomes a ["yyyy", "mm", "dd"] list in an object array.
            # error: "object_" object is not iterable
            obj = np.char.split(dates, sep="-")
            for (i,), (y, m, d) in np.ndenumerate(obj):  # type: ignore[misc]
                self._year[i] = int(y)
                self._month[i] = int(m)
                self._day[i] = int(d)
        else:
            # Length is only taken inside supported branches, so unsupported
            # inputs (e.g. an int) reach this explicit error message.
            raise TypeError(f"{type(dates)} is not supported")

    @property
    def dtype(self) -> ExtensionDtype:
        """The ``DateDtype`` of this array."""
        return DateDtype()

    def astype(self, dtype, copy=True):
        """Cast to ``dtype``.

        Returns a ``DateArray`` (copied when ``copy``) if ``dtype`` is a
        ``DateDtype``. Otherwise delegates to ``to_numpy`` with
        ``na_value=datetime.date.min``.
        """
        dtype = pandas_dtype(dtype)
        if isinstance(dtype, DateDtype):
            data = self.copy() if copy else self
        else:
            data = self.to_numpy(dtype=dtype, copy=copy, na_value=dt.date.min)
        return data

    @property
    def nbytes(self) -> int:
        """Total bytes used by the three backing numpy arrays."""
        return self._year.nbytes + self._month.nbytes + self._day.nbytes

    def __len__(self) -> int:
        return len(self._year)  # all 3 arrays are enforced to have the same length

    def __getitem__(self, item: PositionalIndexer):
        """Index the array.

        An integer key (Python or numpy) returns a ``datetime.date``. A slice,
        boolean mask, or integer-position numpy array returns a new
        ``DateArray``.

        Raises
        ------
        NotImplementedError
            For any other kind of key.
        """
        if isinstance(item, (int, np.integer)):
            # Convert numpy scalars to plain ints before building the date.
            return dt.date(
                int(self._year[item]), int(self._month[item]), int(self._day[item])
            )
        if isinstance(item, (slice, np.ndarray)):
            return DateArray((self._year[item], self._month[item], self._day[item]))
        raise NotImplementedError(
            "only ints, slices and numpy arrays are supported as indexes"
        )

    def __setitem__(self, key: int | slice | np.ndarray, value: Any) -> None:
        """Store a single ``datetime.date`` at integer position ``key``.

        Raises
        ------
        NotImplementedError
            If ``key`` is not an integer.
        TypeError
            If ``value`` is not a ``datetime.date``.
        """
        if not isinstance(key, (int, np.integer)):
            raise NotImplementedError("only ints are supported as indexes")
        if not isinstance(value, dt.date):
            raise TypeError("you can only set datetime.date types")
        self._year[key] = value.year
        self._month[key] = value.month
        self._day[key] = value.day

    def __repr__(self) -> str:
        # ``tolist`` yields plain Python ints, so the text does not depend on
        # how the installed numpy version reprs its scalar types.
        triples = list(
            zip(self._year.tolist(), self._month.tolist(), self._day.tolist())
        )
        return f"DateArray{triples}"

    def copy(self) -> DateArray:
        """Return an independent copy of the array."""
        return DateArray((self._year.copy(), self._month.copy(), self._day.copy()))

    def isna(self) -> np.ndarray:
        """Boolean mask of elements equal to the ``datetime.date.min`` sentinel."""
        sentinel = dt.date.min
        return (
            (self._year == sentinel.year)
            & (self._month == sentinel.month)
            & (self._day == sentinel.day)
        )

    @classmethod
    def _from_sequence(cls, scalars, *, dtype: Dtype | None = None, copy=False):
        """Build a ``DateArray`` from a sequence of scalars.

        * An existing ``DateArray`` is cast with ``astype`` when ``dtype`` is
          given, copied when ``copy`` is true, and otherwise re-sliced.
        * A numpy array is first cast to ``"U10"`` strings (``yyyy-mm-dd``).
        * Anything else is handed to the ``DateArray`` constructor.

        Raises
        ------
        TypeError
            If ``scalars`` is a single ``datetime.date``.
        """
        if isinstance(scalars, dt.date):
            raise TypeError
        elif isinstance(scalars, DateArray):
            if dtype is not None:
                return scalars.astype(dtype, copy=copy)
            if copy:
                return scalars.copy()
            # Relies on slice support in ``__getitem__``.
            return scalars[:]
        elif isinstance(scalars, np.ndarray):
            scalars = scalars.astype("U10")  # 10 chars for yyyy-mm-dd
        return DateArray(scalars)
|