from __future__ import annotations import warnings from functools import lru_cache from itertools import chain from operator import methodcaller from typing import TYPE_CHECKING, Any, ClassVar, Literal from narwhals._compliant import EagerGroupBy from narwhals._exceptions import issue_warning from narwhals._expression_parsing import evaluate_output_names_and_aliases from narwhals.dependencies import is_pandas_like_dataframe if TYPE_CHECKING: from collections.abc import Callable, Iterable, Iterator, Mapping, Sequence import pandas as pd from pandas.api.typing import DataFrameGroupBy as _NativeGroupBy from typing_extensions import TypeAlias, Unpack from narwhals._compliant.typing import NarwhalsAggregation, ScalarKwargs from narwhals._pandas_like.dataframe import PandasLikeDataFrame from narwhals._pandas_like.expr import PandasLikeExpr NativeGroupBy: TypeAlias = "_NativeGroupBy[tuple[str, ...], Literal[True]]" NativeApply: TypeAlias = "Callable[[pd.DataFrame], pd.Series[Any]]" InefficientNativeAggregation: TypeAlias = Literal["cov", "skew"] NativeAggregation: TypeAlias = Literal[ "any", "all", "count", "first", "idxmax", "idxmin", "last", "max", "mean", "median", "min", "nunique", "prod", "quantile", "sem", "size", "std", "sum", "var", InefficientNativeAggregation, ] """https://pandas.pydata.org/pandas-docs/stable/user_guide/groupby.html#built-in-aggregation-methods""" _NativeAgg: TypeAlias = "Callable[[Any], pd.DataFrame | pd.Series[Any]]" """Equivalent to a partial method call on `DataFrameGroupBy`.""" NonStrHashable: TypeAlias = Any """Because `pandas` allows *"names"* like that 😭""" @lru_cache(maxsize=32) def _native_agg(name: NativeAggregation, /, **kwds: Unpack[ScalarKwargs]) -> _NativeAgg: if name == "nunique": return methodcaller(name, dropna=False) if not kwds or kwds.get("ddof") == 1: return methodcaller(name) return methodcaller(name, **kwds) class AggExpr: """Wrapper storing the intermediate state per-`PandasLikeExpr`. There's a lot of edge cases to handle, so aim to evaluate as little as possible - and store anything that's needed twice. Warning: While a `PandasLikeExpr` can be reused - this wrapper is valid **only** in a single `.agg(...)` operation. """ expr: PandasLikeExpr output_names: Sequence[str] aliases: Sequence[str] def __init__(self, expr: PandasLikeExpr) -> None: self.expr = expr self.output_names = () self.aliases = () self._leaf_name: NarwhalsAggregation | Any = "" def with_expand_names(self, group_by: PandasLikeGroupBy, /) -> AggExpr: """**Mutating operation**. Stores the results of `evaluate_output_names_and_aliases`. """ df = group_by.compliant exclude = group_by.exclude self.output_names, self.aliases = evaluate_output_names_and_aliases( self.expr, df, exclude ) return self def _getitem_aggs( self, group_by: PandasLikeGroupBy, / ) -> pd.DataFrame | pd.Series[Any]: """Evaluate the wrapped expression as a group_by operation.""" result: pd.DataFrame | pd.Series[Any] names = self.output_names if self.is_len() and self.is_anonymous(): result = group_by._grouped.size() else: select = names[0] if len(names) == 1 else list(names) result = self.native_agg()(group_by._grouped[select]) if is_pandas_like_dataframe(result): result.columns = list(self.aliases) else: result.name = self.aliases[0] return result def is_len(self) -> bool: return self.leaf_name == "len" def is_anonymous(self) -> bool: return self.expr._depth == 0 @property def kwargs(self) -> ScalarKwargs: return self.expr._scalar_kwargs @property def leaf_name(self) -> NarwhalsAggregation | Any: if name := self._leaf_name: return name self._leaf_name = PandasLikeGroupBy._leaf_name(self.expr) return self._leaf_name def native_agg(self) -> _NativeAgg: """Return a partial `DataFrameGroupBy` method, missing only `self`.""" return _native_agg( PandasLikeGroupBy._remap_expr_name(self.leaf_name), **self.kwargs ) class PandasLikeGroupBy( EagerGroupBy["PandasLikeDataFrame", "PandasLikeExpr", NativeAggregation] ): _REMAP_AGGS: ClassVar[Mapping[NarwhalsAggregation, NativeAggregation]] = { "sum": "sum", "mean": "mean", "median": "median", "max": "max", "min": "min", "std": "std", "var": "var", "len": "size", "n_unique": "nunique", "count": "count", "quantile": "quantile", "all": "all", "any": "any", } _original_columns: tuple[str, ...] """Column names *prior* to any aliasing in `ParseKeysGroupBy`.""" _keys: list[str] """Stores the **aliased** version of group keys from `ParseKeysGroupBy`.""" _output_key_names: list[str] """Stores the **original** version of group keys.""" @property def exclude(self) -> tuple[str, ...]: """Group keys to ignore when expanding multi-output aggregations.""" return self._exclude def __init__( self, df: PandasLikeDataFrame, keys: Sequence[PandasLikeExpr] | Sequence[str], /, *, drop_null_keys: bool, ) -> None: self._original_columns = tuple(df.columns) self._drop_null_keys = drop_null_keys self._compliant_frame, self._keys, self._output_key_names = self._parse_keys( df, keys ) self._exclude: tuple[str, ...] = (*self._keys, *self._output_key_names) # Drop index to avoid potential collisions: # https://github.com/narwhals-dev/narwhals/issues/1907. native = self.compliant.native if set(native.index.names).intersection(self.compliant.columns): native = native.reset_index(drop=True) self._grouped: NativeGroupBy = native.groupby( self._keys.copy(), sort=False, as_index=True, dropna=drop_null_keys, observed=True, ) def agg(self, *exprs: PandasLikeExpr) -> PandasLikeDataFrame: all_aggs_are_simple = True agg_exprs: list[AggExpr] = [] for expr in exprs: agg_exprs.append(AggExpr(expr).with_expand_names(self)) if not self._is_simple(expr): all_aggs_are_simple = False if all_aggs_are_simple: result: pd.DataFrame if agg_exprs: ns = self.compliant.__narwhals_namespace__() result = ns._concat_horizontal(self._getitem_aggs(agg_exprs)) else: result = self.compliant.__native_namespace__().DataFrame( list(self._grouped.groups), columns=self._keys ) elif self.compliant.native.empty: raise empty_results_error() else: result = self._apply_aggs(exprs) # NOTE: Keep `inplace=True` to avoid making a redundant copy. # This may need updating, depending on https://github.com/pandas-dev/pandas/pull/51466/files result.reset_index(inplace=True) # noqa: PD002 return self._select_results(result, agg_exprs) def _select_results( self, df: pd.DataFrame, /, agg_exprs: Sequence[AggExpr] ) -> PandasLikeDataFrame: """Responsible for remapping temp column names back to original. See `ParseKeysGroupBy`. """ new_names = chain.from_iterable(e.aliases for e in agg_exprs) return ( self.compliant._with_native(df, validate_column_names=False) .simple_select(*self._keys, *new_names) .rename(dict(zip(self._keys, self._output_key_names))) ) def _getitem_aggs( self, exprs: Iterable[AggExpr], / ) -> list[pd.DataFrame | pd.Series[Any]]: return [e._getitem_aggs(self) for e in exprs] def _apply_aggs(self, exprs: Iterable[PandasLikeExpr]) -> pd.DataFrame: """Stub issue for `include_groups` [pandas-dev/pandas-stubs#1270]. - [User guide] mentions `include_groups` 4 times without deprecation. - [`DataFrameGroupBy.apply`] doc says the default value of `True` is deprecated since `2.2.0`. - `False` is explicitly the only *non-deprecated* option, but entirely omitted since [pandas-dev/pandas-stubs#1268]. [pandas-dev/pandas-stubs#1270]: https://github.com/pandas-dev/pandas-stubs/issues/1270 [User guide]: https://pandas.pydata.org/pandas-docs/stable/user_guide/groupby.html [`DataFrameGroupBy.apply`]: https://pandas.pydata.org/pandas-docs/stable/reference/api/pandas.core.groupby.DataFrameGroupBy.apply.html [pandas-dev/pandas-stubs#1268]: https://github.com/pandas-dev/pandas-stubs/pull/1268 """ warn_complex_group_by() impl = self.compliant._implementation func = self._apply_exprs_function(exprs) apply = self._grouped.apply if impl.is_pandas() and impl._backend_version() >= (2, 2): return apply(func, include_groups=False) # type: ignore[call-overload] else: # pragma: no cover return apply(func) def _apply_exprs_function(self, exprs: Iterable[PandasLikeExpr]) -> NativeApply: ns = self.compliant.__narwhals_namespace__() into_series = ns._series.from_iterable def fn(df: pd.DataFrame) -> pd.Series[Any]: compliant = self.compliant._with_native(df) results = ( (keys.native.iloc[0], keys.name) for expr in exprs for keys in expr(compliant) ) out_group, out_names = zip(*results) if results else ([], []) return into_series(out_group, index=out_names, context=ns).native return fn def __iter__(self) -> Iterator[tuple[Any, PandasLikeDataFrame]]: with warnings.catch_warnings(): warnings.filterwarnings( "ignore", message=".*a length 1 tuple will be returned", category=FutureWarning, ) with_native = self.compliant._with_native for key, group in self._grouped: yield (key, with_native(group).simple_select(*self._original_columns)) def empty_results_error() -> ValueError: """Don't even attempt this, it's way too inconsistent across pandas versions.""" msg = ( "No results for group-by aggregation.\n\n" "Hint: you were probably trying to apply a non-elementary aggregation with a " "pandas-like API.\n" "Please rewrite your query such that group-by aggregations " "are elementary. For example, instead of:\n\n" " df.group_by('a').agg(nw.col('b').round(2).mean())\n\n" "use:\n\n" " df.with_columns(nw.col('b').round(2)).group_by('a').agg(nw.col('b').mean())\n\n" ) return ValueError(msg) def warn_complex_group_by() -> None: issue_warning( "Found complex group-by expression, which can't be expressed efficiently with the " "pandas API. If you can, please rewrite your query such that group-by aggregations " "are simple (e.g. mean, std, min, max, ...). \n\n" "Please see: " "https://narwhals-dev.github.io/narwhals/concepts/improve_group_by_operation/", UserWarning, )