from __future__ import annotations

import operator
from typing import TYPE_CHECKING, Any, Callable, ClassVar, Literal, cast

from narwhals._expression_parsing import ExprKind, ExprMetadata
from narwhals._spark_like.expr_dt import SparkLikeExprDateTimeNamespace
from narwhals._spark_like.expr_list import SparkLikeExprListNamespace
from narwhals._spark_like.expr_str import SparkLikeExprStringNamespace
from narwhals._spark_like.expr_struct import SparkLikeExprStructNamespace
from narwhals._spark_like.utils import (
    import_functions,
    import_native_dtypes,
    import_window,
    narwhals_to_native_dtype,
    true_divide,
)
from narwhals._sql.expr import SQLExpr
from narwhals._utils import Implementation, Version, not_implemented

if TYPE_CHECKING:
    from collections.abc import Iterator, Mapping, Sequence

    from sqlframe.base.column import Column
    from sqlframe.base.window import Window, WindowSpec
    from typing_extensions import Self, TypeAlias

    from narwhals._compliant import WindowInputs
    from narwhals._compliant.typing import (
        AliasNames,
        EvalNames,
        EvalSeries,
        WindowFunction,
    )
    from narwhals._spark_like.dataframe import SparkLikeLazyFrame
    from narwhals._spark_like.namespace import SparkLikeNamespace
    from narwhals._utils import _LimitedContext
    from narwhals.typing import FillNullStrategy, IntoDType, NonNestedLiteral, RankMethod

    NativeRankMethod: TypeAlias = Literal["rank", "dense_rank", "row_number"]
    SparkWindowFunction = WindowFunction[SparkLikeLazyFrame, Column]
    SparkWindowInputs = WindowInputs[Column]


class SparkLikeExpr(SQLExpr["SparkLikeLazyFrame", "Column"]):
    def __init__(
        self,
        call: EvalSeries[SparkLikeLazyFrame, Column],
        window_function: SparkWindowFunction | None = None,
        *,
        evaluate_output_names: EvalNames[SparkLikeLazyFrame],
        alias_output_names: AliasNames | None,
        version: Version,
        implementation: Implementation,
    ) -> None:
        self._call = call
        self._evaluate_output_names = evaluate_output_names
        self._alias_output_names = alias_output_names
        self._version = version
        self._implementation = implementation
        self._metadata: ExprMetadata | None = None
        self._window_function: SparkWindowFunction | None = window_function

    _REMAP_RANK_METHOD: ClassVar[Mapping[RankMethod, NativeRankMethod]] = {
        "min": "rank",
        "max": "rank",
        "average": "rank",
        "dense": "dense_rank",
        "ordinal": "row_number",
    }

    def _count_star(self) -> Column:
        return self._F.count("*")

    def _window_expression(
        self,
        expr: Column,
        partition_by: Sequence[str | Column] = (),
        order_by: Sequence[str | Column] = (),
        rows_start: int | None = None,
        rows_end: int | None = None,
        *,
        descending: Sequence[bool] | None = None,
        nulls_last: Sequence[bool] | None = None,
    ) -> Column:
        window = self.partition_by(*partition_by)
        if order_by:
            window = window.orderBy(
                *self._sort(*order_by, descending=descending, nulls_last=nulls_last)
            )
        if rows_start is not None and rows_end is not None:
            window = window.rowsBetween(rows_start, rows_end)
        elif rows_end is not None:
            window = window.rowsBetween(self._Window.unboundedPreceding, rows_end)
        elif rows_start is not None:  # pragma: no cover
            window = window.rowsBetween(rows_start, self._Window.unboundedFollowing)
        return expr.over(window)

    def broadcast(self, kind: Literal[ExprKind.AGGREGATION, ExprKind.LITERAL]) -> Self:
        if kind is ExprKind.LITERAL:
            return self
        return self.over([self._F.lit(1)], [])
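    # The three properties below resolve backend-specific modules lazily: under
    # TYPE_CHECKING they point at `sqlframe` so static analysis sees concrete types,
    # while at runtime the module matching ``self._implementation`` is loaded via
    # ``import_functions`` / ``import_native_dtypes`` / ``import_window``.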
    @property
    def _F(self):  # type: ignore[no-untyped-def] # noqa: ANN202, N802
        if TYPE_CHECKING:
            from sqlframe.base import functions

            return functions
        else:
            return import_functions(self._implementation)

    @property
    def _native_dtypes(self):  # type: ignore[no-untyped-def] # noqa: ANN202
        if TYPE_CHECKING:
            from sqlframe.base import types

            return types
        else:
            return import_native_dtypes(self._implementation)

    @property
    def _Window(self) -> type[Window]:  # noqa: N802
        if TYPE_CHECKING:
            from sqlframe.base.window import Window

            return Window
        else:
            return import_window(self._implementation)

    def _sort(
        self,
        *cols: Column | str,
        descending: Sequence[bool] | None = None,
        nulls_last: Sequence[bool] | None = None,
    ) -> Iterator[Column]:
        F = self._F  # noqa: N806
        descending = descending or [False] * len(cols)
        nulls_last = nulls_last or [False] * len(cols)
        mapping = {
            (False, False): F.asc_nulls_first,
            (False, True): F.asc_nulls_last,
            (True, False): F.desc_nulls_first,
            (True, True): F.desc_nulls_last,
        }
        yield from (
            mapping[(_desc, _nulls_last)](col)
            for col, _desc, _nulls_last in zip(cols, descending, nulls_last)
        )

    def partition_by(self, *cols: Column | str) -> WindowSpec:
        """Wraps `Window().partitionBy`, with default and `WindowInputs` handling."""
        return self._Window.partitionBy(*cols or [self._F.lit(1)])

    def __narwhals_expr__(self) -> None: ...

    def __narwhals_namespace__(self) -> SparkLikeNamespace:  # pragma: no cover
        from narwhals._spark_like.namespace import SparkLikeNamespace

        return SparkLikeNamespace(
            version=self._version, implementation=self._implementation
        )

    @classmethod
    def _alias_native(cls, expr: Column, name: str) -> Column:
        return expr.alias(name)

    @classmethod
    def from_column_names(
        cls: type[Self],
        evaluate_column_names: EvalNames[SparkLikeLazyFrame],
        /,
        *,
        context: _LimitedContext,
    ) -> Self:
        def func(df: SparkLikeLazyFrame) -> list[Column]:
            return [df._F.col(col_name) for col_name in evaluate_column_names(df)]

        return cls(
            func,
            evaluate_output_names=evaluate_column_names,
            alias_output_names=None,
            version=context._version,
            implementation=context._implementation,
        )

    @classmethod
    def from_column_indices(cls, *column_indices: int, context: _LimitedContext) -> Self:
        def func(df: SparkLikeLazyFrame) -> list[Column]:
            columns = df.columns
            return [df._F.col(columns[i]) for i in column_indices]

        return cls(
            func,
            evaluate_output_names=cls._eval_names_indices(column_indices),
            alias_output_names=None,
            version=context._version,
            implementation=context._implementation,
        )

    def __truediv__(self, other: SparkLikeExpr) -> Self:
        def _truediv(expr: Column, other: Column) -> Column:
            return true_divide(self._F, expr, other)

        return self._with_binary(_truediv, other)

    def __rtruediv__(self, other: SparkLikeExpr) -> Self:
        def _rtruediv(expr: Column, other: Column) -> Column:
            return true_divide(self._F, other, expr)

        return self._with_binary(_rtruediv, other).alias("literal")

    def __floordiv__(self, other: SparkLikeExpr) -> Self:
        def _floordiv(expr: Column, other: Column) -> Column:
            return self._F.floor(true_divide(self._F, expr, other))

        return self._with_binary(_floordiv, other)

    def __rfloordiv__(self, other: SparkLikeExpr) -> Self:
        def _rfloordiv(expr: Column, other: Column) -> Column:
            return self._F.floor(true_divide(self._F, other, expr))

        return self._with_binary(_rfloordiv, other).alias("literal")

    def __invert__(self) -> Self:
        invert = cast("Callable[..., Column]", operator.invert)
        return self._with_elementwise(invert)

    def cast(self, dtype: IntoDType) -> Self:
        def func(df: SparkLikeLazyFrame) -> Sequence[Column]:
            spark_dtype = narwhals_to_native_dtype(
                dtype, self._version, self._native_dtypes, df.native.sparkSession
            )
            return [expr.cast(spark_dtype) for expr in self(df)]

        def window_f(
            df: SparkLikeLazyFrame, inputs: SparkWindowInputs
        ) -> Sequence[Column]:
            spark_dtype = narwhals_to_native_dtype(
                dtype, self._version, self._native_dtypes, df.native.sparkSession
            )
            return [
                expr.cast(spark_dtype) for expr in self.window_function(df, inputs)
            ]

        return self.__class__(
            func,
            window_f,
            evaluate_output_names=self._evaluate_output_names,
            alias_output_names=self._alias_output_names,
            version=self._version,
            implementation=self._implementation,
        )

    def median(self) -> Self:
        def _median(expr: Column) -> Column:
            if self._implementation in {
                Implementation.PYSPARK,
                Implementation.PYSPARK_CONNECT,
            } and Implementation.PYSPARK._backend_version() < (3, 4):  # pragma: no cover
                # Use percentile_approx with default accuracy parameter (10000)
                return self._F.percentile_approx(expr.cast("double"), 0.5)

            return self._F.median(expr)

        return self._with_callable(_median)

    def null_count(self) -> Self:
        def _null_count(expr: Column) -> Column:
            return self._F.count_if(self._F.isnull(expr))

        return self._with_callable(_null_count)
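    # ``std`` and ``var`` only have native ddof=0 / ddof=1 variants; for any other
    # ``ddof`` the ddof=1 sample statistic is rescaled below, i.e.
    # ``var_ddof = var_samp * (n - 1) / (n - ddof)``, with the square root of the
    # same factor applied for the standard deviation.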
    def std(self, ddof: int) -> Self:
        F = self._F  # noqa: N806
        if ddof == 0:
            return self._with_callable(F.stddev_pop)
        if ddof == 1:
            return self._with_callable(F.stddev_samp)

        def func(expr: Column) -> Column:
            n_rows = F.count(expr)
            return F.stddev_samp(expr) * F.sqrt((n_rows - 1) / (n_rows - ddof))

        return self._with_callable(func)

    def var(self, ddof: int) -> Self:
        F = self._F  # noqa: N806
        if ddof == 0:
            return self._with_callable(F.var_pop)
        if ddof == 1:
            return self._with_callable(F.var_samp)

        def func(expr: Column) -> Column:
            n_rows = F.count(expr)
            return F.var_samp(expr) * (n_rows - 1) / (n_rows - ddof)

        return self._with_callable(func)

    def is_finite(self) -> Self:
        def _is_finite(expr: Column) -> Column:
            # A value is finite if it's not NaN, and not infinite, while NULLs should
            # be preserved
            is_finite_condition = (
                ~self._F.isnan(expr)
                & (expr != self._F.lit(float("inf")))
                & (expr != self._F.lit(float("-inf")))
            )
            return self._F.when(~self._F.isnull(expr), is_finite_condition).otherwise(
                None
            )

        return self._with_elementwise(_is_finite)

    def is_in(self, values: Sequence[Any]) -> Self:
        def _is_in(expr: Column) -> Column:
            return expr.isin(values) if values else self._F.lit(False)  # noqa: FBT003

        return self._with_elementwise(_is_in)

    def len(self) -> Self:
        def _len(_expr: Column) -> Column:
            # Use count(*) to count all rows including nulls
            return self._F.count("*")

        return self._with_callable(_len)

    def skew(self) -> Self:
        return self._with_callable(self._F.skewness)

    def kurtosis(self) -> Self:
        return self._with_callable(self._F.kurtosis)

    def n_unique(self) -> Self:
        def _n_unique(expr: Column) -> Column:
            # `count_distinct` ignores nulls, so add 1 when any null is present
            # (the max of the 0/1 null indicator).
            return self._F.count_distinct(expr) + self._F.max(
                self._F.isnull(expr).cast(self._native_dtypes.IntegerType())
            )

        return self._with_callable(_n_unique)

    def is_nan(self) -> Self:
        def _is_nan(expr: Column) -> Column:
            return self._F.when(self._F.isnull(expr), None).otherwise(self._F.isnan(expr))

        return self._with_elementwise(_is_nan)

    def fill_null(
        self,
        value: Self | NonNestedLiteral,
        strategy: FillNullStrategy | None,
        limit: int | None,
    ) -> Self:
        if strategy is not None:

            def _fill_with_strategy(
                df: SparkLikeLazyFrame, inputs: SparkWindowInputs
            ) -> Sequence[Column]:
                fn = self._F.last_value if strategy == "forward" else self._F.first_value
                if strategy == "forward":
                    start = self._Window.unboundedPreceding if limit is None else -limit
                    end = self._Window.currentRow
                else:
                    start = self._Window.currentRow
                    end = self._Window.unboundedFollowing if limit is None else limit
                return [
                    fn(expr, ignoreNulls=True).over(
                        self.partition_by(*inputs.partition_by)
                        .orderBy(*self._sort(*inputs.order_by))
                        .rowsBetween(start, end)
                    )
                    for expr in self(df)
                ]

            return self._with_window_function(_fill_with_strategy)

        def _fill_constant(expr: Column, value: Column) -> Column:
            return self._F.ifnull(expr, value)

        return self._with_elementwise(_fill_constant, value=value)
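    # ``log`` and ``sqrt`` handle out-of-domain inputs explicitly below: negative
    # values map to NaN and ``log(0)`` maps to ``-inf``, rather than relying on
    # whatever the native backend would return for those inputs.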
    def log(self, base: float) -> Self:
        def _log(expr: Column) -> Column:
            return (
                self._F.when(expr < 0, self._F.lit(float("nan")))
                .when(expr == 0, self._F.lit(float("-inf")))
                .otherwise(self._F.log(float(base), expr))
            )

        return self._with_elementwise(_log)

    def sqrt(self) -> Self:
        def _sqrt(expr: Column) -> Column:
            return self._F.when(expr < 0, self._F.lit(float("nan"))).otherwise(
                self._F.sqrt(expr)
            )

        return self._with_elementwise(_sqrt)

    @property
    def str(self) -> SparkLikeExprStringNamespace:
        return SparkLikeExprStringNamespace(self)

    @property
    def dt(self) -> SparkLikeExprDateTimeNamespace:
        return SparkLikeExprDateTimeNamespace(self)

    @property
    def list(self) -> SparkLikeExprListNamespace:
        return SparkLikeExprListNamespace(self)

    @property
    def struct(self) -> SparkLikeExprStructNamespace:
        return SparkLikeExprStructNamespace(self)

    quantile = not_implemented()