from __future__ import annotations

from typing import TYPE_CHECKING, Any, overload

from altair import (
    Chart,
    ConcatChart,
    ConcatSpecGenericSpec,
    FacetChart,
    FacetedUnitSpec,
    FacetSpec,
    HConcatChart,
    HConcatSpecGenericSpec,
    LayerChart,
    LayerSpec,
    NonNormalizedSpec,
    TopLevelConcatSpec,
    TopLevelFacetSpec,
    TopLevelHConcatSpec,
    TopLevelLayerSpec,
    TopLevelUnitSpec,
    TopLevelVConcatSpec,
    UnitSpec,
    UnitSpecWithFrame,
    VConcatChart,
    VConcatSpecGenericSpec,
    data_transformers,
)
from altair.utils._vegafusion_data import get_inline_tables, import_vegafusion
from altair.utils.schemapi import Undefined

if TYPE_CHECKING:
    import sys
    from collections.abc import Iterable

    if sys.version_info >= (3, 10):
        from typing import TypeAlias
    else:
        from typing_extensions import TypeAlias

    from altair.typing import ChartType
    from altair.utils.core import DataFrameLike

# A scope is the path of group-mark indices from the top-level Vega spec
# down to a nested group mark; () is the top level.
Scope: TypeAlias = tuple[int, ...]
# Maps (facet_name, facet_scope) -> (source_dataset_name, source_dataset_scope)
FacetMapping: TypeAlias = dict[tuple[str, Scope], tuple[str, Scope]]


# For the transformed_data functionality, the chart classes in the values
# can be considered equivalent to the chart class in the key.
_chart_class_mapping = {
    Chart: (
        Chart,
        TopLevelUnitSpec,
        FacetedUnitSpec,
        UnitSpec,
        UnitSpecWithFrame,
        NonNormalizedSpec,
    ),
    LayerChart: (LayerChart, TopLevelLayerSpec, LayerSpec),
    ConcatChart: (ConcatChart, TopLevelConcatSpec, ConcatSpecGenericSpec),
    HConcatChart: (HConcatChart, TopLevelHConcatSpec, HConcatSpecGenericSpec),
    VConcatChart: (VConcatChart, TopLevelVConcatSpec, VConcatSpecGenericSpec),
    FacetChart: (FacetChart, TopLevelFacetSpec, FacetSpec),
}


@overload
def transformed_data(
    chart: Chart | FacetChart,
    row_limit: int | None = None,
    exclude: Iterable[str] | None = None,
) -> DataFrameLike | None: ...
@overload
def transformed_data(
    chart: LayerChart | HConcatChart | VConcatChart | ConcatChart,
    row_limit: int | None = None,
    exclude: Iterable[str] | None = None,
) -> list[DataFrameLike]: ...
def transformed_data(chart, row_limit=None, exclude=None):
    """
    Evaluate a Chart's transforms.

    Evaluate the data transforms associated with a Chart and return the
    transformed data as one or more DataFrames.

    Parameters
    ----------
    chart : Chart, FacetChart, LayerChart, HConcatChart, VConcatChart, or ConcatChart
        Altair chart to evaluate transforms on. The caller's chart is not
        modified; a deep copy is evaluated instead.
    row_limit : int (optional)
        Maximum number of rows to return for each DataFrame. None (default) for unlimited
    exclude : iterable of str
        Set of the names of charts to exclude

    Returns
    -------
    DataFrame or list of DataFrames or None
        If input chart is a Chart or Facet Chart, returns a DataFrame of the
        transformed data (or None if that chart was excluded). Otherwise,
        returns a list of DataFrames of the transformed data, one per
        non-excluded view, in the order the views appear in the chart.

    Raises
    ------
    ValueError
        If the dataset backing any view cannot be located in the compiled
        Vega specification, or if ``chart`` (or one of its subcharts) is not
        a chart type supported by :func:`name_views`.
    """
    vf = import_vegafusion()

    # Add mark if none is specified to satisfy Vega-Lite
    if isinstance(chart, Chart) and chart.mark is Undefined:
        chart = chart.mark_point()

    # Deep copy chart so that we can rename marks without affecting caller
    chart = chart.copy(deep=True)

    # Ensure that all views are named so that we can look them up in the
    # resulting Vega specification
    chart_names = name_views(chart, 0, exclude=exclude)

    # Compile to Vega and extract inline DataFrames
    with data_transformers.enable("vegafusion"):
        vega_spec = chart.to_dict(format="vega", context={"pre_transform": False})
        inline_datasets = get_inline_tables(vega_spec)

    # Build mapping from mark names to vega datasets
    facet_mapping = get_facet_mapping(vega_spec)
    dataset_mapping = get_datasets_for_view_names(vega_spec, chart_names, facet_mapping)

    # Build a list of vega dataset names that corresponds to the order
    # of the chart components; every named view must have been resolved.
    if any(chart_name not in dataset_mapping for chart_name in chart_names):
        msg = "Failed to locate all datasets"
        raise ValueError(msg)
    dataset_names = [dataset_mapping[chart_name] for chart_name in chart_names]

    # Extract transformed datasets with VegaFusion
    datasets, _ = vf.runtime.pre_transform_datasets(
        vega_spec,
        dataset_names,
        row_limit=row_limit,
        inline_datasets=inline_datasets,
    )

    if isinstance(chart, (Chart, FacetChart)):
        # Return DataFrame (or None if it was excluded) if input was a simple Chart
        return datasets[0] if datasets else None
    # Otherwise return the list of DataFrames
    return datasets


# The equivalent classes from _chart_class_mapping should also be added
# to the type hints below for `chart` as the function would also work for them.
# However, this was not possible so far as mypy then complains about
# "Overloaded function signatures 1 and 2 overlap with incompatible return types [misc]"
# This might be due to the complex type hierarchy of the chart classes.
# See also https://github.com/python/mypy/issues/5119
# and https://github.com/python/mypy/issues/4020 which show that mypy might not have
# a very consistent behavior for overloaded functions.
# The same error appeared when trying it with Protocols for the concat and layer charts.
# This function is only used internally and so we accept this inconsistency for now.
def name_views(
    chart: ChartType, i: int = 0, exclude: Iterable[str] | None = None
) -> list[str]:
    """
    Name unnamed chart views.

    Name unnamed charts views so that we can look them up later in
    the compiled Vega spec.

    Note: This function mutates the input chart by applying names to
    unnamed views.

    Parameters
    ----------
    chart : Chart, FacetChart, LayerChart, HConcatChart, VConcatChart, or ConcatChart
        Altair chart to apply names to
    i : int (default 0)
        Starting chart index. It is only threaded through the recursive
        calls; it does not influence the generated names.
    exclude : iterable of str
        Names of charts to exclude

    Returns
    -------
    list of str
        List of the names of the charts and subcharts (excluded charts are
        omitted), in depth-first order.

    Raises
    ------
    ValueError
        If ``chart`` or any nested subchart is not one of the supported
        chart (or equivalent spec) types.
    """
    exclude = set(exclude) if exclude is not None else set()

    if isinstance(
        chart, (_chart_class_mapping[Chart], _chart_class_mapping[FacetChart])
    ):
        # Leaf view: name it (if needed) unless it is excluded.
        if chart.name in exclude:
            return []
        if chart.name in {None, Undefined}:
            # Add name since none is specified
            chart.name = Chart._get_name()
        return [chart.name]

    # Compound view: recurse into its children.
    subcharts: list[Any]
    if isinstance(chart, _chart_class_mapping[LayerChart]):
        subcharts = chart.layer
    elif isinstance(chart, _chart_class_mapping[HConcatChart]):
        subcharts = chart.hconcat
    elif isinstance(chart, _chart_class_mapping[VConcatChart]):
        subcharts = chart.vconcat
    elif isinstance(chart, _chart_class_mapping[ConcatChart]):
        subcharts = chart.concat
    else:
        msg = (
            "transformed_data accepts an instance of "
            "Chart, FacetChart, LayerChart, HConcatChart, VConcatChart, or ConcatChart\n"
            f"Received value of type: {type(chart)}"
        )
        raise ValueError(msg)

    chart_names: list[str] = []
    for subchart in subcharts:
        chart_names.extend(
            name_views(subchart, i=i + len(chart_names), exclude=exclude)
        )
    return chart_names


def get_group_mark_for_scope(
    vega_spec: dict[str, Any], scope: Scope
) -> dict[str, Any] | None:
    """
    Get the group mark at a particular scope.

    Parameters
    ----------
    vega_spec : dict
        Top-level Vega specification dictionary
    scope : tuple of int
        Scope tuple. If empty, the original Vega specification is returned.
        Otherwise, the nested group mark at the scope specified is returned.
        Each entry indexes only the ``"group"``-typed marks of its parent.

    Returns
    -------
    dict or None
        Top-level Vega spec (if scope is empty)
        or group mark (if scope is non-empty)
        or None (if group mark at scope does not exist)

    Examples
    --------
    >>> spec = {
    ...     "marks": [
    ...         {"type": "group", "marks": [{"type": "symbol"}]},
    ...         {"type": "group", "marks": [{"type": "rect"}]},
    ...     ]
    ... }
    >>> get_group_mark_for_scope(spec, (1,))
    {'type': 'group', 'marks': [{'type': 'rect'}]}
    """
    group = vega_spec

    # Find group at scope
    for scope_value in scope:
        group_index = 0
        child_group = None
        for mark in group.get("marks", []):
            if mark.get("type") == "group":
                if group_index == scope_value:
                    child_group = mark
                    break
                group_index += 1
        if child_group is None:
            return None
        group = child_group

    return group


def get_datasets_for_scope(vega_spec: dict[str, Any], scope: Scope) -> list[str]:
    """
    Get the names of the datasets that are defined at a given scope.

    Parameters
    ----------
    vega_spec : dict
        Top-level Vega specification
    scope : tuple of int
        Scope tuple. If empty, the names of top-level datasets are returned.
        Otherwise, the names of the datasets defined in the nested group mark
        at the specified scope are returned (including the group's facet
        dataset, if it defines one).

    Returns
    -------
    list of str
        List of the names of the datasets defined at the specified scope

    Examples
    --------
    >>> spec = {
    ...     "data": [{"name": "data1"}],
    ...     "marks": [
    ...         {
    ...             "type": "group",
    ...             "data": [{"name": "data2"}],
    ...             "marks": [{"type": "symbol"}],
    ...         },
    ...         {
    ...             "type": "group",
    ...             "data": [
    ...                 {"name": "data3"},
    ...                 {"name": "data4"},
    ...             ],
    ...             "marks": [{"type": "rect"}],
    ...         },
    ...     ],
    ... }

    >>> get_datasets_for_scope(spec, ())
    ['data1']

    >>> get_datasets_for_scope(spec, (0,))
    ['data2']

    >>> get_datasets_for_scope(spec, (1,))
    ['data3', 'data4']

    Returns empty when no group mark exists at scope

    >>> get_datasets_for_scope(spec, (1, 3))
    []
    """
    group = get_group_mark_for_scope(vega_spec, scope) or {}

    # get datasets from group
    datasets = [dataset["name"] for dataset in group.get("data", [])]

    # Add facet dataset
    facet_dataset = group.get("from", {}).get("facet", {}).get("name", None)
    if facet_dataset:
        datasets.append(facet_dataset)
    return datasets


def get_definition_scope_for_data_reference(
    vega_spec: dict[str, Any], data_name: str, usage_scope: Scope
) -> Scope | None:
    """
    Return the scope that a dataset is defined at, for a given usage scope.

    The usage scope and each of its ancestors are searched, innermost first.

    Parameters
    ----------
    vega_spec: dict
        Top-level Vega specification
    data_name: str
        The name of a dataset reference
    usage_scope: tuple of int
        The scope that the dataset is referenced in

    Returns
    -------
    tuple of int
        The scope where the referenced dataset is defined,
        or None if no such dataset is found

    Examples
    --------
    >>> spec = {
    ...     "data": [{"name": "data1"}],
    ...     "marks": [
    ...         {
    ...             "type": "group",
    ...             "data": [{"name": "data2"}],
    ...             "marks": [
    ...                 {
    ...                     "type": "symbol",
    ...                     "encode": {
    ...                         "update": {
    ...                             "x": {"field": "x", "data": "data1"},
    ...                             "y": {"field": "y", "data": "data2"},
    ...                         }
    ...                     },
    ...                 }
    ...             ],
    ...         }
    ...     ],
    ... }

    data1 is referenced at scope [0] and defined at scope []

    >>> get_definition_scope_for_data_reference(spec, "data1", (0,))
    ()

    data2 is referenced at scope [0] and defined at scope [0]

    >>> get_definition_scope_for_data_reference(spec, "data2", (0,))
    (0,)

    data2 is not visible at scope [] (the top level), because it's defined in scope [0]

    >>> repr(get_definition_scope_for_data_reference(spec, "data2", ()))
    'None'
    """
    for i in reversed(range(len(usage_scope) + 1)):
        scope = usage_scope[:i]
        datasets = get_datasets_for_scope(vega_spec, scope)
        if data_name in datasets:
            return scope
    return None


def get_facet_mapping(group: dict[str, Any], scope: Scope = ()) -> FacetMapping:
    """
    Create mapping from facet definitions to source datasets.

    Parameters
    ----------
    group : dict
        Top-level Vega spec or nested group mark
    scope : tuple of int
        Scope of the group dictionary within a top-level Vega spec

    Returns
    -------
    dict
        Dictionary from (facet_name, facet_scope) to (dataset_name, dataset_scope)

    Examples
    --------
    >>> spec = {
    ...     "data": [{"name": "data1"}],
    ...     "marks": [
    ...         {
    ...             "type": "group",
    ...             "from": {
    ...                 "facet": {
    ...                     "name": "facet1",
    ...                     "data": "data1",
    ...                     "groupby": ["colA"],
    ...                 }
    ...             },
    ...         }
    ...     ],
    ... }
    >>> get_facet_mapping(spec)
    {('facet1', (0,)): ('data1', ())}
    """
    facet_mapping: FacetMapping = {}
    group_index = 0
    mark_group = get_group_mark_for_scope(group, scope) or {}
    for mark in mark_group.get("marks", []):
        if mark.get("type", None) == "group":
            # Get facet for this group
            group_scope = (*scope, group_index)
            facet = mark.get("from", {}).get("facet", None)
            if facet is not None:
                facet_name = facet.get("name", None)
                facet_data = facet.get("data", None)
                if facet_name is not None and facet_data is not None:
                    # The facet's source is referenced from the enclosing
                    # scope, so resolve it from there.
                    definition_scope = get_definition_scope_for_data_reference(
                        group, facet_data, scope
                    )
                    if definition_scope is not None:
                        facet_mapping[facet_name, group_scope] = (
                            facet_data,
                            definition_scope,
                        )

            # Handle children recursively
            child_mapping = get_facet_mapping(group, scope=group_scope)
            facet_mapping.update(child_mapping)
            group_index += 1

    return facet_mapping


def get_from_facet_mapping(
    scoped_dataset: tuple[str, Scope], facet_mapping: FacetMapping
) -> tuple[str, Scope]:
    """
    Apply facet mapping to a scoped dataset.

    Parameters
    ----------
    scoped_dataset : (str, tuple of int)
        A dataset name and scope tuple
    facet_mapping : dict from (str, tuple of int) to (str, tuple of int)
        The facet mapping produced by get_facet_mapping

    Returns
    -------
    (str, tuple of int)
        Dataset name and scope tuple that has been mapped as many times as possible

    Examples
    --------
    Facet mapping as produced by get_facet_mapping

    >>> facet_mapping = {
    ...     ("facet1", (0,)): ("data1", ()),
    ...     ("facet2", (0, 1)): ("facet1", (0,)),
    ... }
    >>> get_from_facet_mapping(("facet2", (0, 1)), facet_mapping)
    ('data1', ())
    """
    while scoped_dataset in facet_mapping:
        scoped_dataset = facet_mapping[scoped_dataset]
    return scoped_dataset


def _match_marks_view_name(mark_name: str, vl_chart_names: list[str]) -> str | None:
    """
    Return the view name that owns a ``*_marks`` Vega mark, or None.

    A view named ``v`` owns a mark whose name starts with ``f"{v}_"`` and ends
    with ``"_marks"``. Requiring the ``_`` separator prevents ``view_1`` from
    claiming ``view_10_marks``; when several names still match, the longest
    (most specific) one wins.

    Examples
    --------
    >>> _match_marks_view_name("view_10_marks", ["view_1", "view_10"])
    'view_10'
    >>> _match_marks_view_name("view_1_marks", ["view_1", "view_10"])
    'view_1'
    >>> repr(_match_marks_view_name("view_1_cell", ["view_1"]))
    'None'
    """
    if not mark_name.endswith("_marks"):
        return None
    candidates = [
        vl_chart_name
        for vl_chart_name in vl_chart_names
        if mark_name.startswith(f"{vl_chart_name}_")
    ]
    return max(candidates, key=len, default=None)


def get_datasets_for_view_names(
    group: dict[str, Any],
    vl_chart_names: list[str],
    facet_mapping: FacetMapping,
    scope: Scope = (),
) -> dict[str, tuple[str, Scope]]:
    """
    Get the Vega datasets that correspond to the provided Altair view names.

    Parameters
    ----------
    group : dict
        Top-level Vega spec or nested group mark
    vl_chart_names : list of str
        List of the Vega-Lite view names to look up
    facet_mapping : dict from (str, tuple of int) to (str, tuple of int)
        The facet mapping produced by get_facet_mapping
    scope : tuple of int
        Scope of the group dictionary within a top-level Vega spec

    Returns
    -------
    dict from str to (str, tuple of int)
        Dict from Altair view names to scoped datasets. View names whose
        dataset could not be found are absent from the result.
    """
    datasets: dict[str, tuple[str, Scope]] = {}
    group_index = 0
    mark_group = get_group_mark_for_scope(group, scope) or {}
    for mark in mark_group.get("marks", []):
        name = mark.get("name") or ""

        # Faceted views compile to a "<view>_cell" group mark whose facet
        # references the source dataset.
        for vl_chart_name in vl_chart_names:
            if name == f"{vl_chart_name}_cell":
                facet = mark.get("from", {}).get("facet") or {}
                data_name = facet.get("data", None)
                if data_name is not None:
                    # The reference is resolved lexically from this scope;
                    # fall back to the usage scope if no definition is found.
                    definition_scope = get_definition_scope_for_data_reference(
                        group, data_name, scope
                    )
                    if definition_scope is None:
                        definition_scope = scope
                    datasets[vl_chart_name] = get_from_facet_mapping(
                        (data_name, definition_scope), facet_mapping
                    )
                break

        if mark.get("type", "") == "group":
            group_data_names = get_datasets_for_view_names(
                group, vl_chart_names, facet_mapping, scope=(*scope, group_index)
            )
            # Results already found at this level take precedence.
            for k, v in group_data_names.items():
                datasets.setdefault(k, v)
            group_index += 1
        else:
            vl_chart_name = _match_marks_view_name(name, vl_chart_names)
            if vl_chart_name is not None:
                data_name = mark.get("from", {}).get("data", None)
                scoped_data = get_definition_scope_for_data_reference(
                    group, data_name, scope
                )
                if scoped_data is not None:
                    datasets[vl_chart_name] = get_from_facet_mapping(
                        (data_name, scoped_data), facet_mapping
                    )

    return datasets