|
from __future__ import annotations

import hashlib
import json
import random
import sys
from functools import partial
from pathlib import Path
from typing import (
    TYPE_CHECKING,
    Any,
    Callable,
    Dict,
    List,
    Literal,
    MutableMapping,
    Protocol,
    Sequence,
    TypedDict,
    TypeVar,
    Union,
    overload,
    runtime_checkable,
)

from typing_extensions import Concatenate, ParamSpec, TypeAlias

import narwhals.stable.v1 as nw
from narwhals.dependencies import is_pandas_dataframe as _is_pandas_dataframe
from narwhals.typing import IntoDataFrame

from ._importers import import_pyarrow_interchange
from .core import (
    DataFrameLike,
    sanitize_geo_interface,
    sanitize_narwhals_dataframe,
    sanitize_pandas_dataframe,
    to_eager_narwhals_dataframe,
)
from .plugin_registry import PluginRegistry

if sys.version_info >= (3, 13):
    from typing import TypeIs
else:
    from typing_extensions import TypeIs

if TYPE_CHECKING:
    import pandas as pd
    import pyarrow as pa
|
|
|
@runtime_checkable
class SupportsGeoInterface(Protocol):
    __geo_interface__: MutableMapping


DataType: TypeAlias = Union[
    Dict[Any, Any], IntoDataFrame, SupportsGeoInterface, DataFrameLike
]

TDataType = TypeVar("TDataType", bound=DataType)
TIntoDataFrame = TypeVar("TIntoDataFrame", bound=IntoDataFrame)

VegaLiteDataDict: TypeAlias = Dict[
    str, Union[str, Dict[Any, Any], List[Dict[Any, Any]]]
]
ToValuesReturnType: TypeAlias = Dict[str, Union[Dict[Any, Any], List[Dict[Any, Any]]]]
SampleReturnType = Union[IntoDataFrame, Dict[str, Sequence], None]


def is_data_type(obj: Any) -> TypeIs[DataType]:
    return _is_pandas_dataframe(obj) or isinstance(
        obj, (dict, DataFrameLike, SupportsGeoInterface, nw.DataFrame)
    )


# A data transformer is a callable that accepts a supported data type and
# returns a transformed dictionary version of it.
P = ParamSpec("P")
R = TypeVar("R", VegaLiteDataDict, Any)
DataTransformerType = Callable[Concatenate[DataType, P], R]


class DataTransformerRegistry(PluginRegistry[DataTransformerType, R]):
    _global_settings = {"consolidate_datasets": True}

    @property
    def consolidate_datasets(self) -> bool:
        return self._global_settings["consolidate_datasets"]

    @consolidate_datasets.setter
    def consolidate_datasets(self, value: bool) -> None:
        self._global_settings["consolidate_datasets"] = value
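
# Minimal usage sketch, assuming the ``register``/``enable`` interface
# inherited from ``PluginRegistry`` (the actual registry instance is created
# elsewhere in the package):
#
#     data_transformers = DataTransformerRegistry()
#     data_transformers.register("limit", limit_rows)
#     data_transformers.enable("limit", max_rows=10_000)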
|
|
|
|
|
|
|
class MaxRowsError(Exception):
    """Raised when a data model has too many rows."""
|
|
|
|
|
@overload
def limit_rows(data: None = ..., max_rows: int | None = ...) -> partial: ...
@overload
def limit_rows(data: DataType, max_rows: int | None = ...) -> DataType: ...
def limit_rows(
    data: DataType | None = None, max_rows: int | None = 5000
) -> partial | DataType:
    """
    Raise MaxRowsError if the data model has more than max_rows.

    If max_rows is None, then do not perform any check.
    """
    if data is None:
        return partial(limit_rows, max_rows=max_rows)
    check_data_type(data)

    def raise_max_rows_error():
        msg = (
            "The number of rows in your dataset is greater "
            f"than the maximum allowed ({max_rows}).\n\n"
            "Try enabling the VegaFusion data transformer which "
            "raises this limit by pre-evaluating data\n"
            "transformations in Python.\n"
            "    >> import altair as alt\n"
            '    >> alt.data_transformers.enable("vegafusion")\n\n'
            "Or, see https://altair-viz.github.io/user_guide/large_datasets.html "
            "for additional information\n"
            "on how to plot large datasets."
        )
        raise MaxRowsError(msg)

    if isinstance(data, SupportsGeoInterface):
        if data.__geo_interface__["type"] == "FeatureCollection":
            values = data.__geo_interface__["features"]
        else:
            values = data.__geo_interface__
    elif isinstance(data, dict):
        if "values" in data:
            values = data["values"]
        else:
            # A dict without a "values" key (e.g. a url-based data model)
            # has no rows to count, so pass it through unchanged.
            return data
    else:
        data = to_eager_narwhals_dataframe(data)
        values = data

    if max_rows is not None and len(values) > max_rows:
        raise_max_rows_error()

    return data
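
# Calling ``limit_rows`` without data returns a ``functools.partial``, so it
# doubles as a configurable transformer, e.g.:
#
#     checker = limit_rows(max_rows=100)
#     checker(df)  # raises MaxRowsError if ``df`` has more than 100 rows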
|
|
|
|
|
@overload
def sample(
    data: None = ..., n: int | None = ..., frac: float | None = ...
) -> partial: ...
@overload
def sample(
    data: TIntoDataFrame, n: int | None = ..., frac: float | None = ...
) -> TIntoDataFrame: ...
@overload
def sample(
    data: DataType, n: int | None = ..., frac: float | None = ...
) -> SampleReturnType: ...
def sample(
    data: DataType | None = None,
    n: int | None = None,
    frac: float | None = None,
) -> partial | SampleReturnType:
    """Reduce the size of the data model by sampling without replacement."""
    if data is None:
        return partial(sample, n=n, frac=frac)
    check_data_type(data)
    if _is_pandas_dataframe(data):
        return data.sample(n=n, frac=frac)
    elif isinstance(data, dict):
        if "values" in data:
            values = data["values"]
            if not n:
                if frac is None:
                    msg = "frac cannot be None if n is None and data is a dictionary"
                    raise ValueError(msg)
                n = int(frac * len(values))
            values = random.sample(values, n)
            return {"values": values}
        else:
            # A dict without a "values" key has nothing to sample from.
            return None
    data = nw.from_native(data, eager_only=True)
    if not n:
        if frac is None:
            msg = "frac cannot be None if n is None with this data input type"
            raise ValueError(msg)
        n = int(frac * len(data))
    indices = random.sample(range(len(data)), n)
    return nw.to_native(data[indices])
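
# Sketch of the two sampling modes (``df`` is any supported dataframe):
#
#     sample(df, n=100)     # 100 rows, sampled without replacement
#     sample(df, frac=0.1)  # roughly 10% of the rows
#     sample(n=100)         # a partial, usable as a data transformer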
|
|
|
|
|
_FormatType = Literal["csv", "json"]


class _FormatDict(TypedDict):
    type: _FormatType


class _ToFormatReturnUrlDict(TypedDict):
    url: str
    format: _FormatDict
|
|
|
|
|
@overload
def to_json(
    data: None = ...,
    prefix: str = ...,
    extension: str = ...,
    filename: str = ...,
    urlpath: str = ...,
) -> partial: ...


@overload
def to_json(
    data: DataType,
    prefix: str = ...,
    extension: str = ...,
    filename: str = ...,
    urlpath: str = ...,
) -> _ToFormatReturnUrlDict: ...


def to_json(
    data: DataType | None = None,
    prefix: str = "altair-data",
    extension: str = "json",
    filename: str = "{prefix}-{hash}.{extension}",
    urlpath: str = "",
) -> partial | _ToFormatReturnUrlDict:
    """Write the data model to a .json file and return a url based data model."""
    kwds = _to_text_kwds(prefix, extension, filename, urlpath)
    if data is None:
        return partial(to_json, **kwds)
    else:
        data_str = _data_to_json_string(data)
        return _to_text(data_str, **kwds, format=_FormatDict(type="json"))
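
# The returned dict is itself a valid url-based Vega-Lite data definition,
# e.g. {"url": "altair-data-<hash>.json", "format": {"type": "json"}}.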
|
|
|
|
|
@overload
def to_csv(
    data: None = ...,
    prefix: str = ...,
    extension: str = ...,
    filename: str = ...,
    urlpath: str = ...,
) -> partial: ...


@overload
def to_csv(
    data: dict | pd.DataFrame | DataFrameLike,
    prefix: str = ...,
    extension: str = ...,
    filename: str = ...,
    urlpath: str = ...,
) -> _ToFormatReturnUrlDict: ...


def to_csv(
    data: dict | pd.DataFrame | DataFrameLike | None = None,
    prefix: str = "altair-data",
    extension: str = "csv",
    filename: str = "{prefix}-{hash}.{extension}",
    urlpath: str = "",
) -> partial | _ToFormatReturnUrlDict:
    """Write the data model to a .csv file and return a url based data model."""
    kwds = _to_text_kwds(prefix, extension, filename, urlpath)
    if data is None:
        return partial(to_csv, **kwds)
    else:
        data_str = _data_to_csv_string(data)
        return _to_text(data_str, **kwds, format=_FormatDict(type="csv"))
|
|
|
|
|
def _to_text(
    data: str,
    prefix: str,
    extension: str,
    filename: str,
    urlpath: str,
    format: _FormatDict,
) -> _ToFormatReturnUrlDict:
    data_hash = _compute_data_hash(data)
    filename = filename.format(prefix=prefix, hash=data_hash, extension=extension)
    Path(filename).write_text(data, encoding="utf-8")
    url = str(Path(urlpath, filename))
    return _ToFormatReturnUrlDict({"url": url, "format": format})


def _to_text_kwds(prefix: str, extension: str, filename: str, urlpath: str, /) -> dict[str, str]:
    return {"prefix": prefix, "extension": extension, "filename": filename, "urlpath": urlpath}
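
# Because the default filename template interpolates a content hash, writing
# the same data twice reuses one file name instead of piling up
# near-identical copies on disk.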
|
|
|
|
|
def to_values(data: DataType) -> ToValuesReturnType:
    """Replace a DataFrame by a data model with values."""
    check_data_type(data)

    data_native = nw.to_native(data, strict=False)
    if isinstance(data_native, SupportsGeoInterface):
        return {"values": _from_geo_interface(data_native)}
    elif _is_pandas_dataframe(data_native):
        data_native = sanitize_pandas_dataframe(data_native)
        return {"values": data_native.to_dict(orient="records")}
    elif isinstance(data_native, dict):
        if "values" not in data_native:
            msg = "values expected in data dict, but not present."
            raise KeyError(msg)
        return data_native
    elif isinstance(data, nw.DataFrame):
        data = sanitize_narwhals_dataframe(data)
        return {"values": data.rows(named=True)}
    else:
        # check_data_type has already accepted the input, so reaching this
        # branch means an object (e.g. one supporting only the interchange
        # protocol) that none of the conversions above can handle.
        msg = f"Unrecognized data type: {type(data)}"
        raise ValueError(msg)


def check_data_type(data: DataType) -> None:
    if not is_data_type(data):
        msg = f"Expected dict, DataFrame or a __geo_interface__ attribute, got: {type(data)}"
        raise TypeError(msg)


def _compute_data_hash(data_str: str) -> str:
    return hashlib.sha256(data_str.encode()).hexdigest()[:32]
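
# Truncating the SHA-256 hex digest to 32 characters (128 bits) keeps file
# names short while leaving collisions practically impossible.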
|
|
|
|
|
def _from_geo_interface(data: SupportsGeoInterface | Any) -> dict[str, Any]:
    """
    Sanitize a ``__geo_interface__`` w/ pre-sanitize step for ``pandas`` if needed.

    Notes
    -----
    Split out to resolve typing issues related to:
    - Intersection types
    - ``typing.TypeGuard``
    - ``pd.DataFrame.__getattr__``
    """
    if _is_pandas_dataframe(data):
        data = sanitize_pandas_dataframe(data)
    return sanitize_geo_interface(data.__geo_interface__)
|
|
|
|
|
def _data_to_json_string(data: DataType) -> str:
    """Return a JSON string representation of the input data."""
    check_data_type(data)
    if isinstance(data, SupportsGeoInterface):
        return json.dumps(_from_geo_interface(data))
    elif _is_pandas_dataframe(data):
        data = sanitize_pandas_dataframe(data)
        return data.to_json(orient="records", double_precision=15)
    elif isinstance(data, dict):
        if "values" not in data:
            msg = "values expected in data dict, but not present."
            raise KeyError(msg)
        return json.dumps(data["values"], sort_keys=True)
    try:
        data_nw = nw.from_native(data, eager_only=True)
    except TypeError as exc:
        msg = "to_json only works with data expressed as a DataFrame or as a dict"
        raise NotImplementedError(msg) from exc
    data_nw = sanitize_narwhals_dataframe(data_nw)
    return json.dumps(data_nw.rows(named=True))
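
# ``sort_keys=True`` makes the dict serialization deterministic, which in
# turn keeps the content hash (and hence the generated file name) stable.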
|
|
|
|
|
def _data_to_csv_string(data: DataType) -> str:
    """Return a CSV string representation of the input data."""
    check_data_type(data)
    if isinstance(data, SupportsGeoInterface):
        # Report the offending data's own type; ``type(SupportsGeoInterface)``
        # would name the protocol's metaclass instead.
        msg = (
            f"to_csv does not yet work with data that "
            f"is of type {type(data).__name__!r}.\n"
            f"See https://github.com/vega/altair/issues/3441"
        )
        raise NotImplementedError(msg)
    elif _is_pandas_dataframe(data):
        data = sanitize_pandas_dataframe(data)
        return data.to_csv(index=False)
    elif isinstance(data, dict):
        if "values" not in data:
            msg = "values expected in data dict, but not present"
            raise KeyError(msg)
        try:
            import pandas as pd
        except ImportError as exc:
            msg = "pandas is required to convert a dict to a CSV string"
            raise ImportError(msg) from exc
        return pd.DataFrame.from_dict(data["values"]).to_csv(index=False)
    try:
        data_nw = nw.from_native(data, eager_only=True)
    except TypeError as exc:
        msg = "to_csv only works with data expressed as a DataFrame or as a dict"
        raise NotImplementedError(msg) from exc
    return data_nw.write_csv()
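
# The final fallback converts through narwhals, so any eager narwhals-backed
# frame (e.g. polars or pyarrow) can be written without a pandas dependency.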
|
|
|
|
|
def arrow_table_from_dfi_dataframe(dfi_df: DataFrameLike) -> pa.Table:
    """Convert a DataFrame Interchange Protocol compatible object to an Arrow Table."""
    import pyarrow as pa

    # Prefer a native conversion method if the object provides one: the
    # producing library has more control over (and usually broader support
    # for) the conversion than pyarrow's generic interchange path.
    for convert_method_name in ("arrow", "to_arrow", "to_arrow_table", "to_pyarrow"):
        convert_method = getattr(dfi_df, convert_method_name, None)
        if callable(convert_method):
            result = convert_method()
            if isinstance(result, pa.Table):
                return result

    # Otherwise fall back to the DataFrame Interchange Protocol.
    pi = import_pyarrow_interchange()
    return pi.from_dataframe(dfi_df)
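
# Illustrative example (assuming polars is installed): for
# ``pl.DataFrame({"a": [1, 2]})`` the loop above returns ``df.to_arrow()``
# directly; objects exposing only ``__dataframe__`` go through pyarrow's
# interchange ``from_dataframe`` instead.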
|
|