80 lines
2.4 KiB
Python
80 lines
2.4 KiB
Python
from __future__ import annotations
|
|
|
|
from typing import TYPE_CHECKING, cast
|
|
|
|
from narwhals._utils import is_sequence_of
|
|
|
|
if TYPE_CHECKING:
|
|
from collections.abc import Iterator, Sequence
|
|
|
|
from polars.dataframe.group_by import GroupBy as NativeGroupBy
|
|
from polars.lazyframe.group_by import LazyGroupBy as NativeLazyGroupBy
|
|
|
|
from narwhals._polars.dataframe import PolarsDataFrame, PolarsLazyFrame
|
|
from narwhals._polars.expr import PolarsExpr
|
|
|
|
|
|
class PolarsGroupBy:
    """Group-by wrapper around a `PolarsDataFrame` and its native grouped object.

    Construction optionally drops rows with null keys, then groups the native
    frame by the given keys. `agg` and iteration both re-wrap native results
    via the compliant frame's `_with_native`.
    """

    _compliant_frame: PolarsDataFrame
    _grouped: NativeGroupBy
    # NOTE: the two attributes below are declared for typing only; nothing in
    # this class assigns them.
    _drop_null_keys: bool
    _output_names: Sequence[str]

    @property
    def compliant(self) -> PolarsDataFrame:
        """Return the compliant frame that the grouping was built from."""
        return self._compliant_frame

    def __init__(
        self,
        df: PolarsDataFrame,
        keys: Sequence[PolarsExpr] | Sequence[str],
        /,
        *,
        drop_null_keys: bool,
    ) -> None:
        """Build the native grouping for `df`.

        Arguments:
            df: Frame to group.
            keys: Grouping keys, either column names or `PolarsExpr` objects.
            drop_null_keys: When true, `df.drop_nulls(keys)` is grouped
                instead of `df` itself.
        """
        self._keys = list(keys)

        if drop_null_keys:
            self._compliant_frame = df.drop_nulls(keys)
        else:
            self._compliant_frame = df

        if is_sequence_of(keys, str):
            # Column names are handed to the native `group_by` unchanged.
            self._grouped = self.compliant.native.group_by(keys)
        else:
            # Expression keys are unwrapped to their native counterparts.
            self._grouped = self.compliant.native.group_by(
                key.native for key in keys
            )

    def agg(self, *aggs: PolarsExpr) -> PolarsDataFrame:
        """Aggregate every group with the native form of each expression in `aggs`."""
        native_exprs = (expr.native for expr in aggs)
        aggregated = self._grouped.agg(native_exprs)
        return self.compliant._with_native(aggregated)

    def __iter__(self) -> Iterator[tuple[tuple[str, ...], PolarsDataFrame]]:
        """Yield `(key_tuple, frame)` pairs, one per item of the native grouping."""
        for group_key, group_frame in self._grouped:
            # `cast` is a typing-only hint (no runtime effect); `tuple(...)`
            # is what actually materialises the key.
            key_tuple = tuple(cast("str", group_key))
            yield key_tuple, self.compliant._with_native(group_frame)
|
|
|
|
|
|
class PolarsLazyGroupBy:
    """Group-by wrapper around a `PolarsLazyFrame` and its native grouped object.

    Construction optionally drops rows with null keys, then groups the native
    frame by the given keys. `agg` re-wraps the native result via the
    compliant frame's `_with_native`.
    """

    _compliant_frame: PolarsLazyFrame
    _grouped: NativeLazyGroupBy
    # NOTE: the two attributes below are declared for typing only; nothing in
    # this class assigns them.
    _drop_null_keys: bool
    _output_names: Sequence[str]

    @property
    def compliant(self) -> PolarsLazyFrame:
        """Return the compliant lazy frame that the grouping was built from."""
        return self._compliant_frame

    def __init__(
        self,
        df: PolarsLazyFrame,
        keys: Sequence[PolarsExpr] | Sequence[str],
        /,
        *,
        drop_null_keys: bool,
    ) -> None:
        """Build the native lazy grouping for `df`.

        Arguments:
            df: Lazy frame to group.
            keys: Grouping keys, either column names or `PolarsExpr` objects.
            drop_null_keys: When true, `df.drop_nulls(keys)` is grouped
                instead of `df` itself.
        """
        self._keys = list(keys)

        if drop_null_keys:
            self._compliant_frame = df.drop_nulls(keys)
        else:
            self._compliant_frame = df

        if is_sequence_of(keys, str):
            # Column names are handed to the native `group_by` unchanged.
            self._grouped = self.compliant.native.group_by(keys)
        else:
            # Expression keys are unwrapped to their native counterparts.
            self._grouped = self.compliant.native.group_by(
                key.native for key in keys
            )

    def agg(self, *aggs: PolarsExpr) -> PolarsLazyFrame:
        """Aggregate every group with the native form of each expression in `aggs`."""
        native_exprs = (expr.native for expr in aggs)
        aggregated = self._grouped.agg(native_exprs)
        return self.compliant._with_native(aggregated)
|