from __future__ import annotations

import operator
from typing import TYPE_CHECKING, Any, Callable, ClassVar, Literal, cast

from narwhals._expression_parsing import ExprKind, ExprMetadata
from narwhals._spark_like.expr_dt import SparkLikeExprDateTimeNamespace
from narwhals._spark_like.expr_list import SparkLikeExprListNamespace
from narwhals._spark_like.expr_str import SparkLikeExprStringNamespace
from narwhals._spark_like.expr_struct import SparkLikeExprStructNamespace
from narwhals._spark_like.utils import (
    import_functions,
    import_native_dtypes,
    import_window,
    narwhals_to_native_dtype,
    true_divide,
)
from narwhals._sql.expr import SQLExpr
from narwhals._utils import Implementation, Version, not_implemented

if TYPE_CHECKING:
    from collections.abc import Iterator, Mapping, Sequence

    from sqlframe.base.column import Column
    from sqlframe.base.window import Window, WindowSpec
    from typing_extensions import Self, TypeAlias

    from narwhals._compliant import WindowInputs
    from narwhals._compliant.typing import (
        AliasNames,
        EvalNames,
        EvalSeries,
        WindowFunction,
    )
    from narwhals._spark_like.dataframe import SparkLikeLazyFrame
    from narwhals._spark_like.namespace import SparkLikeNamespace
    from narwhals._utils import _LimitedContext
    from narwhals.typing import FillNullStrategy, IntoDType, NonNestedLiteral, RankMethod

    NativeRankMethod: TypeAlias = Literal["rank", "dense_rank", "row_number"]
    SparkWindowFunction = WindowFunction[SparkLikeLazyFrame, Column]
    SparkWindowInputs = WindowInputs[Column]


class SparkLikeExpr(SQLExpr["SparkLikeLazyFrame", "Column"]):
    def __init__(
        self,
        call: EvalSeries[SparkLikeLazyFrame, Column],
        window_function: SparkWindowFunction | None = None,
        *,
        evaluate_output_names: EvalNames[SparkLikeLazyFrame],
        alias_output_names: AliasNames | None,
        version: Version,
        implementation: Implementation,
    ) -> None:
        self._call = call
        self._evaluate_output_names = evaluate_output_names
        self._alias_output_names = alias_output_names
        self._version = version
        self._implementation = implementation
        self._metadata: ExprMetadata | None = None
        self._window_function: SparkWindowFunction | None = window_function
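
    # Spark exposes only three ranking window functions, so Polars' "min",
    # "max", and "average" rank methods all lower to `rank` here; resolving
    # ties for "max" and "average" is assumed to happen in the shared SQL
    # rank logic, not in this mapping.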
    _REMAP_RANK_METHOD: ClassVar[Mapping[RankMethod, NativeRankMethod]] = {
        "min": "rank",
        "max": "rank",
        "average": "rank",
        "dense": "dense_rank",
        "ordinal": "row_number",
    }

    def _count_star(self) -> Column:
        return self._F.count("*")
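
    # Assembles roughly `expr.over(partitionBy(...).orderBy(...).rowsBetween(...))`;
    # when only one row bound is supplied, the other defaults to the unbounded
    # end of the frame.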
    def _window_expression(
        self,
        expr: Column,
        partition_by: Sequence[str | Column] = (),
        order_by: Sequence[str | Column] = (),
        rows_start: int | None = None,
        rows_end: int | None = None,
        *,
        descending: Sequence[bool] | None = None,
        nulls_last: Sequence[bool] | None = None,
    ) -> Column:
        window = self.partition_by(*partition_by)
        if order_by:
            window = window.orderBy(
                *self._sort(*order_by, descending=descending, nulls_last=nulls_last)
            )
        if rows_start is not None and rows_end is not None:
            window = window.rowsBetween(rows_start, rows_end)
        elif rows_end is not None:
            window = window.rowsBetween(self._Window.unboundedPreceding, rows_end)
        elif rows_start is not None:  # pragma: no cover
            window = window.rowsBetween(rows_start, self._Window.unboundedFollowing)
        return expr.over(window)
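
    # Broadcasting an aggregation turns it into a window over the constant
    # `lit(1)`: every row lands in one partition, so the scalar result is
    # repeated on each row. Literals broadcast on their own and pass through.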
    def broadcast(self, kind: Literal[ExprKind.AGGREGATION, ExprKind.LITERAL]) -> Self:
        if kind is ExprKind.LITERAL:
            return self
        return self.over([self._F.lit(1)], [])
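
    # While type-checking, the properties below resolve to the `sqlframe` stub
    # modules; at runtime they dispatch on `self._implementation` so that the
    # matching PySpark / SQLFrame modules are imported lazily.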
    @property
    def _F(self):  # type: ignore[no-untyped-def] # noqa: ANN202, N802
        if TYPE_CHECKING:
            from sqlframe.base import functions

            return functions
        else:
            return import_functions(self._implementation)

    @property
    def _native_dtypes(self):  # type: ignore[no-untyped-def] # noqa: ANN202
        if TYPE_CHECKING:
            from sqlframe.base import types

            return types
        else:
            return import_native_dtypes(self._implementation)

    @property
    def _Window(self) -> type[Window]:  # noqa: N802
        if TYPE_CHECKING:
            from sqlframe.base.window import Window

            return Window
        else:
            return import_window(self._implementation)
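
    # Maps each (descending, nulls_last) flag pair to the corresponding Spark
    # ordering, e.g. (True, False) -> `F.desc_nulls_first(col)`; flags default
    # to ascending with nulls first when not provided.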
    def _sort(
        self,
        *cols: Column | str,
        descending: Sequence[bool] | None = None,
        nulls_last: Sequence[bool] | None = None,
    ) -> Iterator[Column]:
        F = self._F  # noqa: N806
        descending = descending or [False] * len(cols)
        nulls_last = nulls_last or [False] * len(cols)
        mapping = {
            (False, False): F.asc_nulls_first,
            (False, True): F.asc_nulls_last,
            (True, False): F.desc_nulls_first,
            (True, True): F.desc_nulls_last,
        }
        yield from (
            mapping[(_desc, _nulls_last)](col)
            for col, _desc, _nulls_last in zip(cols, descending, nulls_last)
        )

    def partition_by(self, *cols: Column | str) -> WindowSpec:
        """Wraps `Window().partitionBy`, with default and `WindowInputs` handling."""
        return self._Window.partitionBy(*cols or [self._F.lit(1)])

    def __narwhals_expr__(self) -> None: ...

    def __narwhals_namespace__(self) -> SparkLikeNamespace:  # pragma: no cover
        from narwhals._spark_like.namespace import SparkLikeNamespace

        return SparkLikeNamespace(
            version=self._version, implementation=self._implementation
        )

    @classmethod
    def _alias_native(cls, expr: Column, name: str) -> Column:
        return expr.alias(name)

    @classmethod
    def from_column_names(
        cls: type[Self],
        evaluate_column_names: EvalNames[SparkLikeLazyFrame],
        /,
        *,
        context: _LimitedContext,
    ) -> Self:
        def func(df: SparkLikeLazyFrame) -> list[Column]:
            return [df._F.col(col_name) for col_name in evaluate_column_names(df)]

        return cls(
            func,
            evaluate_output_names=evaluate_column_names,
            alias_output_names=None,
            version=context._version,
            implementation=context._implementation,
        )

    @classmethod
    def from_column_indices(cls, *column_indices: int, context: _LimitedContext) -> Self:
        def func(df: SparkLikeLazyFrame) -> list[Column]:
            columns = df.columns
            return [df._F.col(columns[i]) for i in column_indices]

        return cls(
            func,
            evaluate_output_names=cls._eval_names_indices(column_indices),
            alias_output_names=None,
            version=context._version,
            implementation=context._implementation,
        )

    def __truediv__(self, other: SparkLikeExpr) -> Self:
        def _truediv(expr: Column, other: Column) -> Column:
            return true_divide(self._F, expr, other)

        return self._with_binary(_truediv, other)

    def __rtruediv__(self, other: SparkLikeExpr) -> Self:
        def _rtruediv(expr: Column, other: Column) -> Column:
            return true_divide(self._F, other, expr)

        return self._with_binary(_rtruediv, other).alias("literal")

    def __floordiv__(self, other: SparkLikeExpr) -> Self:
        def _floordiv(expr: Column, other: Column) -> Column:
            return self._F.floor(true_divide(self._F, expr, other))

        return self._with_binary(_floordiv, other)

    def __rfloordiv__(self, other: SparkLikeExpr) -> Self:
        def _rfloordiv(expr: Column, other: Column) -> Column:
            return self._F.floor(true_divide(self._F, other, expr))

        return self._with_binary(_rfloordiv, other).alias("literal")

    def __invert__(self) -> Self:
        invert = cast("Callable[..., Column]", operator.invert)
        return self._with_elementwise(invert)
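
    # The target dtype is translated against the active Spark session, so the
    # cast has to be applied inside both evaluation paths: the plain one and
    # the window-function one.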
    def cast(self, dtype: IntoDType) -> Self:
        def func(df: SparkLikeLazyFrame) -> Sequence[Column]:
            spark_dtype = narwhals_to_native_dtype(
                dtype, self._version, self._native_dtypes, df.native.sparkSession
            )
            return [expr.cast(spark_dtype) for expr in self(df)]

        def window_f(
            df: SparkLikeLazyFrame, inputs: SparkWindowInputs
        ) -> Sequence[Column]:
            spark_dtype = narwhals_to_native_dtype(
                dtype, self._version, self._native_dtypes, df.native.sparkSession
            )
            return [expr.cast(spark_dtype) for expr in self.window_function(df, inputs)]

        return self.__class__(
            func,
            window_f,
            evaluate_output_names=self._evaluate_output_names,
            alias_output_names=self._alias_output_names,
            version=self._version,
            implementation=self._implementation,
        )

    def median(self) -> Self:
        def _median(expr: Column) -> Column:
            if self._implementation in {
                Implementation.PYSPARK,
                Implementation.PYSPARK_CONNECT,
            } and Implementation.PYSPARK._backend_version() < (3, 4):  # pragma: no cover
                # Use percentile_approx with default accuracy parameter (10000)
                return self._F.percentile_approx(expr.cast("double"), 0.5)

            return self._F.median(expr)

        return self._with_callable(_median)

    def null_count(self) -> Self:
        def _null_count(expr: Column) -> Column:
            return self._F.count_if(self._F.isnull(expr))

        return self._with_callable(_null_count)
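
    # Spark only ships ddof=0 (`*_pop`) and ddof=1 (`*_samp`) variants, so any
    # other ddof is obtained by rescaling the sample statistic:
    # var_ddof = var_samp * (n - 1) / (n - ddof), and the standard deviation
    # uses the square root of the same correction factor.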
    def std(self, ddof: int) -> Self:
        F = self._F  # noqa: N806
        if ddof == 0:
            return self._with_callable(F.stddev_pop)
        if ddof == 1:
            return self._with_callable(F.stddev_samp)

        def func(expr: Column) -> Column:
            n_rows = F.count(expr)
            return F.stddev_samp(expr) * F.sqrt((n_rows - 1) / (n_rows - ddof))

        return self._with_callable(func)

    def var(self, ddof: int) -> Self:
        F = self._F  # noqa: N806
        if ddof == 0:
            return self._with_callable(F.var_pop)
        if ddof == 1:
            return self._with_callable(F.var_samp)

        def func(expr: Column) -> Column:
            n_rows = F.count(expr)
            return F.var_samp(expr) * (n_rows - 1) / (n_rows - ddof)

        return self._with_callable(func)

    def is_finite(self) -> Self:
        def _is_finite(expr: Column) -> Column:
            # A value is finite if it's not NaN, and not infinite, while NULLs should be
            # preserved
            is_finite_condition = (
                ~self._F.isnan(expr)
                & (expr != self._F.lit(float("inf")))
                & (expr != self._F.lit(float("-inf")))
            )
            return self._F.when(~self._F.isnull(expr), is_finite_condition).otherwise(
                None
            )

        return self._with_elementwise(_is_finite)

    def is_in(self, values: Sequence[Any]) -> Self:
        def _is_in(expr: Column) -> Column:
            return expr.isin(values) if values else self._F.lit(False)  # noqa: FBT003

        return self._with_elementwise(_is_in)

    def len(self) -> Self:
        def _len(_expr: Column) -> Column:
            # Use count(*) to count all rows including nulls
            return self._F.count("*")

        return self._with_callable(_len)

    def skew(self) -> Self:
        return self._with_callable(self._F.skewness)

    def kurtosis(self) -> Self:
        return self._with_callable(self._F.kurtosis)
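
    # `count_distinct` ignores NULLs, so a correction term is added: the max of
    # `isnull(expr)` cast to int is 1 whenever any NULL is present, counting
    # the NULL group as one extra distinct value (matching Polars).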
    def n_unique(self) -> Self:
        def _n_unique(expr: Column) -> Column:
            return self._F.count_distinct(expr) + self._F.max(
                self._F.isnull(expr).cast(self._native_dtypes.IntegerType())
            )

        return self._with_callable(_n_unique)
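
    # Spark's `isnan` reports False for NULL input, so NULL is special-cased
    # first to propagate as NULL instead of being reported as not-NaN.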
    def is_nan(self) -> Self:
        def _is_nan(expr: Column) -> Column:
            return self._F.when(self._F.isnull(expr), None).otherwise(self._F.isnan(expr))

        return self._with_elementwise(_is_nan)
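
    # Strategy-based filling is a window computation: "forward" takes the last
    # non-null value over rows [current - limit, current], "backward" the first
    # non-null value over [current, current + limit], unbounded when `limit` is
    # None. Filling with a value is a plain elementwise `ifnull`.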
    def fill_null(
        self,
        value: Self | NonNestedLiteral,
        strategy: FillNullStrategy | None,
        limit: int | None,
    ) -> Self:
        if strategy is not None:

            def _fill_with_strategy(
                df: SparkLikeLazyFrame, inputs: SparkWindowInputs
            ) -> Sequence[Column]:
                fn = self._F.last_value if strategy == "forward" else self._F.first_value
                if strategy == "forward":
                    start = self._Window.unboundedPreceding if limit is None else -limit
                    end = self._Window.currentRow
                else:
                    start = self._Window.currentRow
                    end = self._Window.unboundedFollowing if limit is None else limit
                return [
                    fn(expr, ignoreNulls=True).over(
                        self.partition_by(*inputs.partition_by)
                        .orderBy(*self._sort(*inputs.order_by))
                        .rowsBetween(start, end)
                    )
                    for expr in self(df)
                ]

            return self._with_window_function(_fill_with_strategy)

        def _fill_constant(expr: Column, value: Column) -> Column:
            return self._F.ifnull(expr, value)

        return self._with_elementwise(_fill_constant, value=value)
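
    # The guards mirror real-valued logarithms: the log of a negative number
    # is NaN and log(0) is -inf; everything else defers to Spark's
    # `log(base, expr)`.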
    def log(self, base: float) -> Self:
        def _log(expr: Column) -> Column:
            return (
                self._F.when(expr < 0, self._F.lit(float("nan")))
                .when(expr == 0, self._F.lit(float("-inf")))
                .otherwise(self._F.log(float(base), expr))
            )

        return self._with_elementwise(_log)

    def sqrt(self) -> Self:
        def _sqrt(expr: Column) -> Column:
            return self._F.when(expr < 0, self._F.lit(float("nan"))).otherwise(
                self._F.sqrt(expr)
            )

        return self._with_elementwise(_sqrt)

    @property
    def str(self) -> SparkLikeExprStringNamespace:
        return SparkLikeExprStringNamespace(self)

    @property
    def dt(self) -> SparkLikeExprDateTimeNamespace:
        return SparkLikeExprDateTimeNamespace(self)

    @property
    def list(self) -> SparkLikeExprListNamespace:
        return SparkLikeExprListNamespace(self)

    @property
    def struct(self) -> SparkLikeExprStructNamespace:
        return SparkLikeExprStructNamespace(self)

    quantile = not_implemented()