146 lines
4.7 KiB
Python
146 lines
4.7 KiB
Python
from __future__ import annotations
|
|
|
|
from functools import partial
|
|
from typing import TYPE_CHECKING, Any, Callable, ClassVar
|
|
|
|
import dask.dataframe as dd
|
|
|
|
from narwhals._compliant import DepthTrackingGroupBy
|
|
from narwhals._expression_parsing import evaluate_output_names_and_aliases
|
|
|
|
if TYPE_CHECKING:
|
|
from collections.abc import Mapping, Sequence
|
|
|
|
import pandas as pd
|
|
from dask.dataframe.api import GroupBy as _DaskGroupBy
|
|
from pandas.core.groupby import SeriesGroupBy as _PandasSeriesGroupBy
|
|
from typing_extensions import TypeAlias
|
|
|
|
from narwhals._compliant.typing import NarwhalsAggregation
|
|
from narwhals._dask.dataframe import DaskLazyFrame
|
|
from narwhals._dask.expr import DaskExpr
|
|
|
|
PandasSeriesGroupBy: TypeAlias = _PandasSeriesGroupBy[Any, Any]
|
|
_AggFn: TypeAlias = Callable[..., Any]
|
|
|
|
else:
|
|
try:
|
|
import dask.dataframe.dask_expr as dx
|
|
except ModuleNotFoundError: # pragma: no cover
|
|
import dask_expr as dx
|
|
_DaskGroupBy = dx._groupby.GroupBy
|
|
|
|
# The alias value is a quoted string because `_AggFn` is only defined inside the
# `TYPE_CHECKING` branch, so it does not exist at runtime. The `TypeAlias`
# annotation itself is never evaluated thanks to `from __future__ import annotations`.
# Used as the second element of each `(column, aggregation)` pair that
# `DaskLazyGroupBy.agg` hands to Dask's `groupby().agg`.
Aggregation: TypeAlias = "str | _AggFn"
"""The name of an aggregation function, or the function itself."""
|
|
|
|
|
|
def n_unique() -> dd.Aggregation:
    """Build a Dask aggregation counting distinct values per group, nulls included.

    Dask runs a custom aggregation as map (``chunk``, once per partition),
    reduce (``agg``, over the concatenated chunk results) and ``finalize``.
    Summing per-partition ``nunique`` counts would overcount any value that
    occurs for the same group in several partitions. Instead, the distinct
    values themselves are carried through the reduction and counted once at
    the end.

    Returns:
        A ``dd.Aggregation`` named ``"nunique"`` that yields int64 counts.
    """
    # Needed at runtime; the module-level `pd` import only exists for type checking.
    import pandas as pd

    def chunk(s: PandasSeriesGroupBy) -> pd.Series[Any]:
        # One list per group holding that partition's distinct values (nulls kept).
        return s.apply(lambda values: list(values.unique()))

    def agg(s0: PandasSeriesGroupBy) -> pd.Series[Any]:
        # Summing object-dtype lists concatenates them, merging each group's
        # per-partition value lists into one.
        return s0.sum()

    def finalize(s: pd.Series[Any]) -> pd.Series[Any]:
        # Deduplicate across partitions. Rebuilding a pandas Series makes repeated
        # nulls (e.g. distinct NaN objects from different partitions) count once.
        return s.map(
            lambda values: pd.Series(values).nunique(dropna=False)
        ).astype("int64")

    return dd.Aggregation(name="nunique", chunk=chunk, agg=agg, finalize=finalize)
|
|
|
|
|
|
def _all() -> dd.Aggregation:
    """Build a Dask aggregation implementing a grouped boolean ``all``.

    ``all`` of the per-partition ``all`` results equals ``all`` over the whole
    group, so one reducer serves as both the ``chunk`` and the ``agg`` step.
    Missing values are skipped (``skipna=True``).
    """

    def reduce_all(grouped: PandasSeriesGroupBy) -> pd.Series[Any]:
        return grouped.all(skipna=True)

    return dd.Aggregation(name="all", chunk=reduce_all, agg=reduce_all)
|
|
|
|
|
|
def _any() -> dd.Aggregation:
    """Build a Dask aggregation implementing a grouped boolean ``any``.

    ``any`` of the per-partition ``any`` results equals ``any`` over the whole
    group, so one reducer serves as both the ``chunk`` and the ``agg`` step.
    Missing values are skipped (``skipna=True``).
    """

    def reduce_any(grouped: PandasSeriesGroupBy) -> pd.Series[Any]:
        return grouped.any(skipna=True)

    return dd.Aggregation(name="any", chunk=reduce_any, agg=reduce_any)
|
|
|
|
|
|
def var(ddof: int) -> _AggFn:
    """Return Dask's grouped variance with ``ddof`` pre-bound.

    The result is a ``functools.partial`` over the unbound ``GroupBy.var``
    method. It is used as the ``"var"`` entry of the group-by remapping table.
    """
    grouped_var = _DaskGroupBy.var
    return partial(grouped_var, ddof=ddof)
|
|
|
|
|
|
def std(ddof: int) -> _AggFn:
    """Return Dask's grouped standard deviation with ``ddof`` pre-bound.

    The result is a ``functools.partial`` over the unbound ``GroupBy.std``
    method. It is used as the ``"std"`` entry of the group-by remapping table.
    """
    grouped_std = _DaskGroupBy.std
    return partial(grouped_std, ddof=ddof)
|
|
|
|
|
|
class DaskLazyGroupBy(DepthTrackingGroupBy["DaskLazyFrame", "DaskExpr", Aggregation]):
    """Group-by over a ``DaskLazyFrame``, lowering narwhals aggregations onto Dask."""

    # narwhals aggregation name -> Dask aggregation. Strings name Dask's own
    # aggregations. Callables are factories that `agg` invokes with the
    # expression's scalar kwargs (e.g. `ddof`) to build the actual aggregation.
    _REMAP_AGGS: ClassVar[Mapping[NarwhalsAggregation, Aggregation]] = {
        "sum": "sum",
        "mean": "mean",
        "median": "median",
        "max": "max",
        "min": "min",
        "std": std,
        "var": var,
        "len": "size",
        "n_unique": n_unique,
        "count": "count",
        "quantile": "quantile",
        "all": _all,
        "any": _any,
    }

    def __init__(
        self,
        df: DaskLazyFrame,
        keys: Sequence[DaskExpr] | Sequence[str],
        /,
        *,
        drop_null_keys: bool,
    ) -> None:
        """Resolve the group keys on ``df`` and build the native Dask groupby.

        Arguments:
            df: Frame to group.
            keys: Key column names or key expressions, normalised by ``_parse_keys``.
            drop_null_keys: Whether groups whose key contains a null are dropped.
        """
        self._compliant_frame, self._keys, self._output_key_names = self._parse_keys(
            df, keys=keys
        )
        # Remembered for the no-aggregation path of `agg`, which does not go
        # through the native groupby and so does not see its `dropna` setting.
        self._drop_null_keys = drop_null_keys
        self._grouped = self.compliant.native.groupby(
            self._keys, dropna=drop_null_keys, observed=True
        )

    def agg(self, *exprs: DaskExpr) -> DaskLazyFrame:
        """Aggregate every group with ``exprs``.

        Arguments:
            exprs: Aggregations accepted by ``_ensure_all_simple``. These are either
                column-less ones such as ``nw.len()`` or a single aggregation over
                columns such as ``nw.mean('a')``.

        Returns:
            One row per group with the key columns (under their output names)
            and one column per aggregation alias. With no ``exprs``, only the
            distinct key combinations are returned, honouring ``drop_null_keys``.
        """
        from narwhals._dask.dataframe import DaskLazyFrame

        key_renames = dict(zip(self._keys, self._output_key_names))

        if not exprs:
            # No aggregation provided: the result is the distinct key combinations.
            keys_frame = self.compliant.simple_select(*self._keys)
            if self._drop_null_keys:
                keys_frame = DaskLazyFrame(
                    keys_frame.native.dropna(subset=list(self._keys)),
                    version=self.compliant._version,
                )
            return keys_frame.unique(self._keys, keep="any").rename(key_renames)

        self._ensure_all_simple(exprs)
        # Named aggregations for Dask: {output alias: (input column, aggregation)}.
        simple_aggregations: dict[str, tuple[str, Aggregation]] = {}
        exclude = (*self._keys, *self._output_key_names)
        for expr in exprs:
            output_names, aliases = evaluate_output_names_and_aliases(
                expr, self.compliant, exclude
            )
            if expr._depth == 0:
                # e.g. `agg(nw.len())`: there is no input column, so the
                # aggregation is run against the first key column.
                column = self._keys[0]
                agg_fn = self._remap_expr_name(expr._function_name)
                simple_aggregations.update(dict.fromkeys(aliases, (column, agg_fn)))
                continue

            # e.g. `agg(nw.mean('a'))`
            agg_fn = self._remap_expr_name(self._leaf_name(expr))
            # Callable entries of `_REMAP_AGGS` are factories (n_unique, all,
            # any, std, var). Build the aggregation with the expression's parameters.
            agg_fn = agg_fn(**expr._scalar_kwargs) if callable(agg_fn) else agg_fn
            simple_aggregations.update(
                (alias, (output_name, agg_fn))
                for alias, output_name in zip(aliases, output_names)
            )

        return DaskLazyFrame(
            self._grouped.agg(**simple_aggregations).reset_index(),
            version=self.compliant._version,
        ).rename(key_renames)
|