|
from __future__ import annotations |
|
|
|
import json |
|
import pathlib |
|
from typing import Any |
|
|
|
import anywidget |
|
import traitlets |
|
|
|
import altair as alt |
|
from altair import TopLevelSpec |
|
from altair.utils._vegafusion_data import ( |
|
compile_to_vegafusion_chart_state, |
|
using_vegafusion, |
|
) |
|
from altair.utils.selection import IndexSelection, IntervalSelection, PointSelection |
|
|
|
_here = pathlib.Path(__file__).parent |
|
|
|
|
|
class Params(traitlets.HasTraits):
    """
    Traitlet class storing a JupyterChart's params.

    One trait is added per entry of ``trait_values``. The trait type is chosen
    from the Python type of the initial value, and the trait is then set to
    that value.
    """

    def __init__(self, trait_values):
        """
        Parameters
        ----------
        trait_values: dict
            Mapping of param name to initial value.
        """
        super().__init__()

        for key, value in trait_values.items():
            # ``bool`` is a subclass of ``int``, so it must be tested first:
            # otherwise boolean params would get a Float trait and be coerced
            # to 1.0 / 0.0.
            if isinstance(value, bool):
                traitlet_type = traitlets.Bool()
            elif isinstance(value, (int, float)):
                traitlet_type = traitlets.Float()
            elif isinstance(value, str):
                traitlet_type = traitlets.Unicode()
            elif isinstance(value, list):
                traitlet_type = traitlets.List()
            elif isinstance(value, dict):
                traitlet_type = traitlets.Dict()
            else:
                # Covers None and any other type.
                traitlet_type = traitlets.Any()

            # Add the trait dynamically for this param name...
            self.add_traits(**{key: traitlet_type})

            # ...then set its initial value.
            setattr(self, key, value)

    def __repr__(self):
        return f"Params({self.trait_values()})"
|
|
|
|
|
class Selections(traitlets.HasTraits):
    """
    Traitlet class storing a JupyterChart's selections.

    Each selection is exposed as a trait holding an ``IndexSelection``,
    ``PointSelection``, or ``IntervalSelection``. The traits are read-only from
    the outside: assigning to one reverts the value and raises ``ValueError``.
    Internal code updates them through ``_set_value``.
    """

    def __init__(self, trait_values):
        """
        Parameters
        ----------
        trait_values: dict
            Mapping of selection name to initial selection object.

        Raises
        ------
        ValueError
            If a value is not one of the supported selection classes.
        """
        super().__init__()

        for key, value in trait_values.items():
            if isinstance(value, IndexSelection):
                traitlet_type = traitlets.Instance(IndexSelection)
            elif isinstance(value, PointSelection):
                traitlet_type = traitlets.Instance(PointSelection)
            elif isinstance(value, IntervalSelection):
                traitlet_type = traitlets.Instance(IntervalSelection)
            else:
                msg = f"Unexpected selection type: {type(value)}"
                raise ValueError(msg)

            # Add the trait dynamically for this selection name...
            self.add_traits(**{key: traitlet_type})

            # ...set its initial value...
            setattr(self, key, value)

            # ...and only then install the read-only guard, so the initial
            # assignment above is not rejected.
            self.observe(self._make_read_only, names=key)

    def __repr__(self):
        return f"Selections({self.trait_values()})"

    def _make_read_only(self, change):
        """Work around to make traits read-only, but still allow us to change them internally."""
        if change["name"] in self.traits() and change["old"] != change["new"]:
            # Restore the previous value before reporting the error.
            self._set_value(change["name"], change["old"])
            msg = (
                "Selections may not be set from Python.\n"
                f"Attempted to set select: {change['name']}"
            )
            raise ValueError(msg)

    def _set_value(self, key, value):
        """Set selection ``key`` to ``value`` while bypassing the read-only guard."""
        self.unobserve(self._make_read_only, names=key)
        try:
            setattr(self, key, value)
        finally:
            # Re-install the guard even if trait validation rejected ``value``;
            # otherwise the selection would silently become writable.
            self.observe(self._make_read_only, names=key)
|
|
|
|
|
def load_js_src() -> str:
    """
    Return the contents of the widget's ``js/index.js`` file.

    The file sits next to this module. It is decoded explicitly as UTF-8 so
    the result does not depend on the platform's locale encoding.
    """
    return (_here / "js" / "index.js").read_text(encoding="utf-8")
|
|
|
|
|
class JupyterChart(anywidget.AnyWidget):
    """
    Jupyter widget that displays an Altair chart and exposes its params and selections.

    The chart's spec is placed in the synced ``spec`` traitlet; when VegaFusion
    is enabled, a VegaFusion-transformed spec is used instead. Param and selection
    values synced back through ``_params`` / ``_vl_selections`` are mirrored onto
    the ``params`` and ``selections`` attributes.
    """

    _esm = load_js_src()
    _css = r"""

    .vega-embed {

        /* Make sure action menu isn't cut off */

        overflow: visible;

    }

    """

    # Public traitlets
    chart = traitlets.Instance(TopLevelSpec, allow_none=True)
    spec = traitlets.Dict(allow_none=True).tag(sync=True)
    debounce_wait = traitlets.Float(default_value=10).tag(sync=True)
    max_wait = traitlets.Bool(default_value=True).tag(sync=True)
    local_tz = traitlets.Unicode(default_value=None, allow_none=True).tag(sync=True)
    debug = traitlets.Bool(default_value=False)
    embed_options = traitlets.Dict(default_value=None, allow_none=True).tag(sync=True)

    # Internal selection traitlets: selection name -> "index"/"point"/"interval",
    # and the raw {"value", "store"} dicts synced with the front end.
    _selection_types = traitlets.Dict()
    _vl_selections = traitlets.Dict().tag(sync=True)

    # Internal param traitlet: param name -> value, synced with the front end.
    _params = traitlets.Dict().tag(sync=True)

    # Internal traitlets used only when VegaFusion is enabled.
    _chart_state = traitlets.Any(allow_none=True)
    _js_watch_plan = traitlets.Any(allow_none=True).tag(sync=True)
    _js_to_py_updates = traitlets.Any(allow_none=True).tag(sync=True)
    _py_to_js_updates = traitlets.Any(allow_none=True).tag(sync=True)

    # Whether _esm currently holds the vl-convert bundled (offline) source.
    _is_offline = False

    @classmethod
    def enable_offline(cls, offline: bool = True):
        """
        Configure JupyterChart's offline behavior.

        Parameters
        ----------
        offline: bool
            If True, configure JupyterChart to operate in offline mode where JavaScript
            dependencies are loaded from vl-convert.
            If False, configure it to operate in online mode where JavaScript dependencies
            are loaded from CDN dynamically. This is the default behavior.
        """
        from altair.utils._importers import import_vl_convert, vl_version_for_vl_convert

        if offline:
            if cls._is_offline:
                # Already offline; nothing to do.
                return

            vlc = import_vl_convert()

            src_lines = load_js_src().split("\n")

            # Drop leading blank lines, `import` lines and `//` comment lines
            # before handing the source to vl-convert's bundler.
            while src_lines and (
                len(src_lines[0].strip()) == 0
                or src_lines[0].startswith("import")
                or src_lines[0].startswith("//")
            ):
                src_lines.pop(0)

            src = "\n".join(src_lines)

            bundled_src = vlc.javascript_bundle(
                src, vl_version=vl_version_for_vl_convert()
            )
            cls._esm = bundled_src
            cls._is_offline = True
        else:
            cls._esm = load_js_src()
            cls._is_offline = False

    def __init__(
        self,
        chart: TopLevelSpec,
        debounce_wait: int = 10,
        max_wait: bool = True,
        debug: bool = False,
        embed_options: dict | None = None,
        **kwargs: Any,
    ):
        """
        Jupyter Widget for displaying and updating Altair Charts, and retrieving selection and parameter values.

        Parameters
        ----------
        chart: Chart
            Altair Chart instance
        debounce_wait: int
            Debouncing wait time in milliseconds. Updates will be sent from the client to the kernel
            after debounce_wait milliseconds of no chart interactions.
        max_wait: bool
            If True (default), updates will be sent from the client to the kernel every debounce_wait
            milliseconds even if there are ongoing chart interactions. If False, updates will not be
            sent until chart interactions have completed.
        debug: bool
            If True, debug messages will be printed
        embed_options: dict
            Options to pass to vega-embed.
            See https://github.com/vega/vega-embed?tab=readme-ov-file#options
        """
        # Create empty containers up front: the ``chart is None`` path of
        # _on_change_chart returns without creating them.
        self.params = Params({})
        self.selections = Selections({})
        super().__init__(
            chart=chart,
            debounce_wait=debounce_wait,
            max_wait=max_wait,
            debug=debug,
            embed_options=embed_options,
            **kwargs,
        )

    @traitlets.observe("chart")
    def _on_change_chart(self, change):
        """
        Update the JupyterChart's internal state when the wrapped Chart instance changes.

        Rebuilds ``params`` and ``selections`` from the chart's parameters, then
        updates ``spec``, ``_selection_types``, ``_vl_selections`` and ``_params``
        inside a single ``hold_sync`` block.

        Raises
        ------
        ValueError
            If a selection parameter has a type other than "point" or "interval".
        """
        new_chart = change.new
        selection_types = {}
        initial_params = {}
        initial_vl_selections = {}
        empty_selections = {}

        if new_chart is None:
            with self.hold_sync():
                self.spec = None
                self._selection_types = selection_types
                self._vl_selections = initial_vl_selections
                self._params = initial_params
            return

        params = getattr(new_chart, "params", [])

        if params is not alt.Undefined:
            # Iterate the value fetched above (not ``new_chart.params``) so a
            # chart without a ``params`` attribute falls back to [].
            for param in params:
                if isinstance(param.name, alt.ParameterName):
                    clean_name = param.name.to_json().strip('"')
                else:
                    clean_name = param.name

                select = getattr(param, "select", alt.Undefined)

                if select is not alt.Undefined:
                    if not isinstance(select, dict):
                        select = select.to_dict()

                    select_type = select["type"]
                    if select_type == "point":
                        if not (
                            select.get("fields", None) or select.get("encodings", None)
                        ):
                            # Point selection without fields or encodings:
                            # treat it as an index-based selection.
                            selection_types[clean_name] = "index"
                            empty_selections[clean_name] = IndexSelection(
                                name=clean_name, value=[], store=[]
                            )
                        else:
                            selection_types[clean_name] = "point"
                            empty_selections[clean_name] = PointSelection(
                                name=clean_name, value=[], store=[]
                            )
                    elif select_type == "interval":
                        selection_types[clean_name] = "interval"
                        empty_selections[clean_name] = IntervalSelection(
                            name=clean_name, value={}, store=[]
                        )
                    else:
                        # ``select`` is a dict here, so report the extracted type.
                        msg = f"Unexpected selection type {select_type}"
                        raise ValueError(msg)
                    initial_vl_selections[clean_name] = {"value": None, "store": []}
                else:
                    value = getattr(param, "value", alt.Undefined)
                    initial_params[clean_name] = (
                        value if value is not alt.Undefined else None
                    )

        # Params defined by transforms have no initial value.
        for param_name in collect_transform_params(new_chart):
            initial_params[param_name] = None

        # Public params object; push Python-side edits back into ``_params``.
        self.params = Params(initial_params)

        def on_param_traitlet_changed(param_change):
            new_params = dict(self._params)
            new_params[param_change["name"]] = param_change["new"]
            self._params = new_params

        self.params.observe(on_param_traitlet_changed)

        # Public, read-only selections object.
        self.selections = Selections(empty_selections)

        # Update the synced traitlets together.
        with self.hold_sync():
            if using_vegafusion():
                if self.local_tz is None:
                    self.spec = None
                    # Defer VegaFusion compilation until ``local_tz`` is set.
                    # A bound method is used so repeated chart changes do not
                    # stack duplicate handlers (traitlets keeps one per handler).
                    self.observe(self._on_local_tz_change, ["local_tz"])
                else:
                    self._init_with_vegafusion(self.local_tz)
            else:
                self.spec = new_chart.to_dict()
            self._selection_types = selection_types
            self._vl_selections = initial_vl_selections
            self._params = initial_params

    def _on_local_tz_change(self, change):
        """Compile the current chart with VegaFusion using the new ``local_tz`` value."""
        self._init_with_vegafusion(change["new"])

    def _init_with_vegafusion(self, local_tz: str):
        """
        Compile the current chart into a VegaFusion chart state.

        Sets ``_chart_state``, ``_js_watch_plan`` and ``spec`` from the compiled
        state and makes sure ``_js_to_py_updates`` is observed. Does nothing
        when ``chart`` is None.
        """
        if self.chart is None:
            return
        vegalite_spec = self.chart.to_dict(context={"pre_transform": False})
        with self.hold_sync():
            self._chart_state = compile_to_vegafusion_chart_state(
                vegalite_spec, local_tz
            )
            self._js_watch_plan = self._chart_state.get_watch_plan()[
                "client_to_server"
            ]
            self.spec = self._chart_state.get_transformed_spec()

            # Bound method: registering again on re-init is a no-op, so each
            # front-end update is applied to the chart state exactly once.
            self.observe(self._on_js_to_py_updates, ["_js_to_py_updates"])

    def _on_js_to_py_updates(self, change):
        """Apply front-end updates to the VegaFusion chart state and send back the results."""
        if self._chart_state is None:
            return
        if self.debug:
            updates_str = json.dumps(change["new"], indent=2)
            print(f"JavaScript to Python VegaFusion updates:\n {updates_str}")
        updates = self._chart_state.update(change["new"])
        if self.debug:
            updates_str = json.dumps(updates, indent=2)
            print(f"Python to JavaScript VegaFusion updates:\n {updates_str}")
        self._py_to_js_updates = updates

    @traitlets.observe("_params")
    def _on_change_params(self, change):
        """Copy every entry of the synced ``_params`` dict onto ``self.params``."""
        for param_name, value in change.new.items():
            setattr(self.params, param_name, value)

    @traitlets.observe("_vl_selections")
    def _on_change_selections(self, change):
        """
        Update the public ``selections`` from the synced ``_vl_selections`` traitlet.

        Each entry's ``"value"`` and ``"store"`` are converted with ``from_vega``
        of the selection class recorded in ``_selection_types`` for that name.
        Unrecognized selection types are skipped.
        """
        selection_classes = {
            "index": IndexSelection,
            "point": PointSelection,
            "interval": IntervalSelection,
        }
        for selection_name, selection_dict in change.new.items():
            value = selection_dict["value"]
            store = selection_dict["store"]
            selection_type = self._selection_types[selection_name]
            selection_cls = selection_classes.get(selection_type)
            if selection_cls is None:
                continue
            self.selections._set_value(
                selection_name,
                selection_cls.from_vega(selection_name, signal=value, store=store),
            )
|
|
|
|
|
def collect_transform_params(chart: TopLevelSpec) -> set[str]:
    """
    Gather the names of params that are defined by transforms.

    The ``layer``, ``concat``, ``hconcat`` and ``vconcat`` children of ``chart``
    are searched recursively. For ``chart`` itself, the ``param`` attribute of
    every transform that has one is collected.

    Parameters
    ----------
    chart: Chart from which to extract transform params

    Returns
    -------
    set of param names
    """
    found: set[str] = set()

    # Recurse into compound-chart children first.
    for child_key in ("layer", "concat", "hconcat", "vconcat"):
        for sub_chart in getattr(chart, child_key, []):
            found |= collect_transform_params(sub_chart)

    # Then look at this chart's own transforms (Undefined means "none").
    own_transforms = getattr(chart, "transform", [])
    if own_transforms == alt.Undefined:
        own_transforms = []
    found |= {tx.param for tx in own_transforms if hasattr(tx, "param")}

    return found
|
|