team-10/env/Lib/site-packages/narwhals/_spark_like/group_by.py

38 lines
1.2 KiB
Python
Raw Normal View History

2025-08-02 07:34:44 +02:00
from __future__ import annotations
from typing import TYPE_CHECKING
from narwhals._sql.group_by import SQLGroupBy
if TYPE_CHECKING:
from collections.abc import Sequence
from sqlframe.base.column import Column # noqa: F401
from narwhals._spark_like.dataframe import SparkLikeLazyFrame
from narwhals._spark_like.expr import SparkLikeExpr
class SparkLikeLazyGroupBy(SQLGroupBy["SparkLikeLazyFrame", "SparkLikeExpr", "Column"]):
    """Group-by for Spark-like lazy frames.

    Works against a native frame that exposes a PySpark-style DataFrame API
    (``groupBy(...).agg(...)``, ``select``, ``dropDuplicates``).
    """

    def __init__(
        self,
        df: SparkLikeLazyFrame,
        keys: Sequence[SparkLikeExpr] | Sequence[str],
        /,
        *,
        drop_null_keys: bool,
    ) -> None:
        """Resolve the grouping keys and prepare the frame to be grouped.

        Args:
            df: The compliant lazy frame to group.
            keys: Grouping keys, given either as column names or as
                expressions; resolution is delegated to ``_parse_keys``.
            drop_null_keys: When true, rows with nulls in the key columns are
                removed (via ``drop_nulls``) before any aggregation runs.
        """
        # ``_parse_keys`` yields three things:
        #   * the frame to operate on;
        #   * the key column names used internally for grouping;
        #   * the names those keys must carry in the final output.
        # ``agg`` renames the internal key names to the output names at the end.
        parsed_frame, key_names, output_names = self._parse_keys(df, keys=keys)
        self._keys = key_names
        self._output_key_names = output_names
        if drop_null_keys:
            parsed_frame = parsed_frame.drop_nulls(key_names)
        self._compliant_frame = parsed_frame

    def agg(self, *exprs: SparkLikeExpr) -> SparkLikeLazyFrame:
        """Aggregate each group with ``exprs`` and return a new lazy frame.

        If ``exprs`` produce no aggregation columns at all, the result is
        just the distinct combinations of the key columns.

        Args:
            *exprs: Aggregation expressions evaluated against the grouped frame.

        Returns:
            A compliant lazy frame whose key columns carry their output names.
        """
        # Materialise the aggregations first so we know which branch to take.
        agg_columns = list(self._evaluate_exprs(exprs))
        native_frame = self.compliant.native
        if not agg_columns:
            # Nothing to aggregate: only the unique key tuples are wanted.
            grouped = native_frame.select(*self._keys).dropDuplicates()
        else:
            grouped = native_frame.groupBy(*self._keys).agg(*agg_columns)
        key_renames = {
            internal: public
            for internal, public in zip(self._keys, self._output_key_names)
        }
        return self.compliant._with_native(grouped).rename(key_renames)