|
import json |
|
import os |
|
import random |
|
import hashlib |
|
import warnings |
|
|
|
import pandas as pd |
|
from toolz import curried |
|
from typing import Callable |
|
|
|
from .core import sanitize_dataframe |
|
from .core import sanitize_geo_interface |
|
from .deprecation import AltairDeprecationWarning |
|
from .plugin_registry import PluginRegistry |
|
|
|
|
|
|
|
|
|
|
|
# Alias used as the plugin type parameter of DataTransformerRegistry below;
# any callable is accepted.
DataTransformerType = Callable
|
|
|
|
|
class DataTransformerRegistry(PluginRegistry[DataTransformerType]):
    """Plugin registry for data transformers with a consolidation flag.

    The ``consolidate_datasets`` option lives in the class-level
    ``_global_settings`` dict, which the setter mutates in place; the value
    is therefore shared by everything that resolves to this same dict.
    """

    _global_settings = {"consolidate_datasets": True}

    @property
    def consolidate_datasets(self):
        """Current value of the ``consolidate_datasets`` setting."""
        settings = self._global_settings
        return settings["consolidate_datasets"]

    @consolidate_datasets.setter
    def consolidate_datasets(self, value):
        settings = self._global_settings
        settings["consolidate_datasets"] = value
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class MaxRowsError(Exception):
    """Error signalling that a data model holds more rows than allowed."""
|
|
|
|
|
@curried.curry
def limit_rows(data, max_rows=5000):
    """Raise MaxRowsError if the data model has more than max_rows.

    If max_rows is None, then do not perform any check.

    Parameters
    ----------
    data : dict, pandas.DataFrame, or object exposing ``__geo_interface__``
        The data model to check.
    max_rows : int or None
        Largest permitted ``len()`` of the row container.

    Returns
    -------
    The ``data`` argument itself, unchanged.

    Raises
    ------
    TypeError
        Propagated from ``check_data_type`` for unsupported inputs.
    MaxRowsError
        If the row container is longer than ``max_rows``.
    """
    check_data_type(data)

    # Read ``__geo_interface__`` exactly once: it may be a computed property,
    # and the previous hasattr-then-reread pattern evaluated it up to three
    # times per call.
    missing = object()
    geo = getattr(data, "__geo_interface__", missing)

    if geo is not missing:
        # Only a FeatureCollection has a list of rows; for any other geo
        # object the mapping itself is measured.
        if geo["type"] == "FeatureCollection":
            values = geo["features"]
        else:
            values = geo
    elif isinstance(data, pd.DataFrame):
        values = data
    elif isinstance(data, dict):
        if "values" not in data:
            # Nothing inline to count, so pass the dict through untouched.
            return data
        values = data["values"]

    if max_rows is not None and len(values) > max_rows:
        raise MaxRowsError(
            "The number of rows in your dataset is greater "
            "than the maximum allowed ({}). "
            "For information on how to plot larger datasets "
            "in Altair, see the documentation".format(max_rows)
        )
    return data
|
|
|
|
|
@curried.curry
def sample(data, n=None, frac=None):
    """Reduce the size of the data model by sampling without replacement.

    Parameters
    ----------
    data : dict, pandas.DataFrame, or object exposing ``__geo_interface__``
        The data model to sample from.
    n : int, optional
        Number of rows to draw.
    frac : float, optional
        Fraction of rows to draw; for dict data it is used only when ``n``
        is not given (or is falsy).

    Returns
    -------
    A sampled DataFrame for DataFrame input, a new ``{"values": [...]}``
    dict for dict input containing ``"values"``, and ``None`` for any other
    accepted input.

    Raises
    ------
    TypeError
        Propagated from ``check_data_type`` for unsupported inputs.
    ValueError
        If ``data`` is a dict with ``"values"`` and neither ``n`` nor
        ``frac`` is provided.
    """
    check_data_type(data)

    if isinstance(data, pd.DataFrame):
        return data.sample(n=n, frac=frac)

    if isinstance(data, dict) and "values" in data:
        values = data["values"]
        if not n:
            if frac is None:
                # Previously this fell through to ``None * len(values)`` and
                # surfaced as an opaque TypeError.
                raise ValueError(
                    "frac must be provided when n is not given "
                    "and data is a dictionary"
                )
            n = int(frac * len(values))
        return {"values": random.sample(values, n)}

    # Dicts without "values" and non-DataFrame geo objects are not sampled;
    # as before, the result for them is None.
    return None
|
|
|
|
|
@curried.curry
def to_json(
    data,
    prefix="altair-data",
    extension="json",
    filename="{prefix}-{hash}.{extension}",
    urlpath="",
):
    """
    Write the data model to a .json file and return a url based data model.

    Parameters
    ----------
    data : dict, pandas.DataFrame, or object exposing ``__geo_interface__``
        The data model to serialize.
    prefix, extension : str
        Values substituted into the ``filename`` template.
    filename : str
        Template formatted with ``prefix``, ``hash`` and ``extension``; the
        hash is derived from the serialized data.
    urlpath : str
        Path joined in front of the file name to form the returned url.

    Returns
    -------
    dict
        ``{"url": ..., "format": {"type": "json"}}``.
    """
    data_json = _data_to_json_string(data)
    data_hash = _compute_data_hash(data_json)
    filename = filename.format(prefix=prefix, hash=data_hash, extension=extension)
    # Be explicit about the encoding rather than relying on the platform
    # default; the hash above is computed over the UTF-8 bytes.
    with open(filename, "w", encoding="utf-8") as f:
        f.write(data_json)
    return {"url": os.path.join(urlpath, filename), "format": {"type": "json"}}
|
|
|
|
|
@curried.curry
def to_csv(
    data,
    prefix="altair-data",
    extension="csv",
    filename="{prefix}-{hash}.{extension}",
    urlpath="",
):
    """Write the data model to a .csv file and return a url based data model.

    Parameters
    ----------
    data : dict or pandas.DataFrame
        The data model to serialize.
    prefix, extension : str
        Values substituted into the ``filename`` template.
    filename : str
        Template formatted with ``prefix``, ``hash`` and ``extension``; the
        hash is derived from the serialized data.
    urlpath : str
        Path joined in front of the file name to form the returned url.

    Returns
    -------
    dict
        ``{"url": ..., "format": {"type": "csv"}}``.
    """
    data_csv = _data_to_csv_string(data)
    data_hash = _compute_data_hash(data_csv)
    filename = filename.format(prefix=prefix, hash=data_hash, extension=extension)
    # CSV text can contain non-ASCII characters, so do not depend on the
    # platform's default encoding; the hash above is computed over the UTF-8
    # bytes. ``newline=""`` disables newline translation so the string is
    # written exactly as it was hashed.
    with open(filename, "w", encoding="utf-8", newline="") as f:
        f.write(data_csv)
    return {"url": os.path.join(urlpath, filename), "format": {"type": "csv"}}
|
|
|
|
|
@curried.curry
def to_values(data):
    """Convert the input into a data model carrying inline values.

    Geo-interface objects and DataFrames are sanitized and wrapped in a
    ``{"values": ...}`` dict; a dict is returned as-is provided it already
    has a ``"values"`` key, otherwise KeyError is raised.
    """
    check_data_type(data)
    is_frame = isinstance(data, pd.DataFrame)

    if hasattr(data, "__geo_interface__"):
        # A DataFrame that also exposes a geo interface is sanitized first.
        source = sanitize_dataframe(data) if is_frame else data
        return {"values": sanitize_geo_interface(source.__geo_interface__)}

    if is_frame:
        records = sanitize_dataframe(data).to_dict(orient="records")
        return {"values": records}

    if isinstance(data, dict):
        if "values" in data:
            return data
        raise KeyError("values expected in data dict, but not present.")
|
|
|
|
|
def check_data_type(data):
    """Validate that data is a supported data model container.

    Accepts a dict, a pandas DataFrame, or any object that has a
    ``__geo_interface__`` attribute; anything else raises TypeError.
    """
    if isinstance(data, (dict, pd.DataFrame)):
        return
    if hasattr(data, "__geo_interface__"):
        return
    message = "Expected dict, DataFrame or a __geo_interface__ attribute, got: {}"
    raise TypeError(message.format(type(data)))
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def _compute_data_hash(data_str):
    """Return the hexadecimal MD5 digest of the encoded ``data_str``."""
    digest = hashlib.md5()
    digest.update(data_str.encode())
    return digest.hexdigest()
|
|
|
|
|
def _data_to_json_string(data):
    """Serialize the input data model to a JSON string.

    Geo-interface objects are sanitized and dumped whole, DataFrames are
    sanitized and written as a list of records, and dicts contribute only
    their ``"values"`` entry (dumped with sorted keys).
    """
    check_data_type(data)

    if hasattr(data, "__geo_interface__"):
        if isinstance(data, pd.DataFrame):
            data = sanitize_dataframe(data)
        geo = sanitize_geo_interface(data.__geo_interface__)
        return json.dumps(geo)

    if isinstance(data, pd.DataFrame):
        frame = sanitize_dataframe(data)
        return frame.to_json(orient="records", double_precision=15)

    if isinstance(data, dict):
        if "values" in data:
            return json.dumps(data["values"], sort_keys=True)
        raise KeyError("values expected in data dict, but not present.")

    raise NotImplementedError(
        "to_json only works with data expressed as a DataFrame or as a dict"
    )
|
|
|
|
|
def _data_to_csv_string(data):
    """Serialize the input data model to a CSV string.

    DataFrames are sanitized and written without the index; dicts are turned
    into a DataFrame from their ``"values"`` entry first. Objects exposing
    ``__geo_interface__`` are rejected with NotImplementedError.
    """
    check_data_type(data)

    if hasattr(data, "__geo_interface__"):
        raise NotImplementedError(
            "to_csv does not work with data that "
            "contains the __geo_interface__ attribute"
        )

    if isinstance(data, pd.DataFrame):
        return sanitize_dataframe(data).to_csv(index=False)

    if isinstance(data, dict):
        if "values" not in data:
            raise KeyError("values expected in data dict, but not present")
        frame = pd.DataFrame.from_dict(data["values"])
        return frame.to_csv(index=False)

    raise NotImplementedError(
        "to_csv only works with data expressed as a DataFrame or as a dict"
    )
|
|
|
|
|
def pipe(data, *funcs):
    """
    Pipe a value through a sequence of functions

    Deprecated: use toolz.curried.pipe() instead.

    Parameters
    ----------
    data
        The initial value.
    *funcs
        Functions forwarded, together with ``data``, to
        ``toolz.curried.pipe``.

    Returns
    -------
    Whatever ``toolz.curried.pipe(data, *funcs)`` returns.
    """
    warnings.warn(
        "alt.pipe() is deprecated, and will be removed in a future release. "
        "Use toolz.curried.pipe() instead.",
        AltairDeprecationWarning,
        # Attribute the warning to the caller's line, not to this wrapper.
        stacklevel=2,
    )
    return curried.pipe(data, *funcs)
|
|
|
|
|
def curry(*args, **kwargs):
    """Curry a callable function

    Deprecated: use toolz.curried.curry() instead.

    All positional and keyword arguments are forwarded unchanged to
    ``toolz.curried.curry``, whose result is returned.
    """
    warnings.warn(
        "alt.curry() is deprecated, and will be removed in a future release. "
        "Use toolz.curried.curry() instead.",
        AltairDeprecationWarning,
        # Attribute the warning to the caller's line, not to this wrapper.
        stacklevel=2,
    )
    return curried.curry(*args, **kwargs)
|
|