team-10/venv/Lib/site-packages/narwhals/_sql/expr.py
2025-08-02 02:00:33 +02:00

730 lines
27 KiB
Python

from __future__ import annotations
from typing import TYPE_CHECKING, Any, Callable, Literal, Protocol, cast
from narwhals._compliant.expr import LazyExpr
from narwhals._compliant.typing import (
AliasNames,
EvalNames,
EvalSeries,
NativeExprT,
WindowFunction,
)
from narwhals._compliant.window import WindowInputs
from narwhals._expression_parsing import (
combine_alias_output_names,
combine_evaluate_output_names,
)
from narwhals._sql.typing import SQLLazyFrameT
from narwhals._utils import Implementation, Version, not_implemented
if TYPE_CHECKING:
from collections.abc import Iterable, Sequence
from typing_extensions import Self, TypeIs
from narwhals._compliant.typing import AliasNames, WindowFunction
from narwhals._expression_parsing import ExprMetadata
from narwhals._sql.namespace import SQLNamespace
from narwhals.typing import NumericLiteral, PythonLiteral, RankMethod, TemporalLiteral
class SQLExpr(LazyExpr[SQLLazyFrameT, NativeExprT], Protocol[SQLLazyFrameT, NativeExprT]):
    """Protocol holding expression logic shared by SQL-style lazy backends.

    An expression wraps `_call`, which maps a frame to a sequence of native
    expressions, and optionally `_window_function`, which does the same while
    also receiving the `WindowInputs` supplied through `over`.

    Members whose body is `...` carry no implementation here; concrete backends
    are expected to provide them.
    """

    # Maps a frame to the native expressions produced by this expression.
    _call: EvalSeries[SQLLazyFrameT, NativeExprT]
    # Computes the output names of this expression for a given frame.
    _evaluate_output_names: EvalNames[SQLLazyFrameT]
    # Optional renaming applied on top of the evaluated output names.
    _alias_output_names: AliasNames | None
    _version: Version
    _implementation: Implementation
    _metadata: ExprMetadata | None
    # Windowed counterpart of `_call`; `None` means "use the default" (see `window_function`).
    _window_function: WindowFunction[SQLLazyFrameT, NativeExprT] | None

    def __init__(
        self,
        call: EvalSeries[SQLLazyFrameT, NativeExprT],
        window_function: WindowFunction[SQLLazyFrameT, NativeExprT] | None = None,
        *,
        evaluate_output_names: EvalNames[SQLLazyFrameT],
        alias_output_names: AliasNames | None,
        version: Version,
        implementation: Implementation = Implementation.DUCKDB,
    ) -> None:
        """Constructor signature that every concrete expression class must accept.

        The builder helpers of this protocol (`_with_callable`, `_with_binary`,
        `over`, ...) create new instances by calling the class with exactly
        these arguments.
        """
        ...

    def __call__(self, df: SQLLazyFrameT) -> Sequence[NativeExprT]:
        """Evaluate the expression against `df` and return its native expressions."""
        return self._call(df)

    def __narwhals_namespace__(
        self,
    ) -> SQLNamespace[SQLLazyFrameT, Self, Any, NativeExprT]:
        """Return the backend namespace; implemented by concrete backends."""
        ...

    def _callable_to_eval_series(
        self, call: Callable[..., NativeExprT], /, **expressifiable_args: Self | Any
    ) -> EvalSeries[SQLLazyFrameT, NativeExprT]:
        """Wrap `call` into a function that evaluates it against a frame.

        Arguments:
            call: Callable receiving each native expression of `self` as its first
                positional argument and the resolved `expressifiable_args` as
                keyword arguments.
            **expressifiable_args: Extra operands. Expressions are evaluated with
                `df._evaluate_expr`; anything else is converted with `_lit`.

        Returns:
            A function mapping a frame to one native expression per output of `self`.
        """

        def func(df: SQLLazyFrameT) -> list[NativeExprT]:
            native_series_list = self(df)
            other_native_series = {
                key: df._evaluate_expr(value)
                if self._is_expr(value)
                else self._lit(value)
                for key, value in expressifiable_args.items()
            }
            return [
                call(native_series, **other_native_series)
                for native_series in native_series_list
            ]

        return func

    def _push_down_window_function(
        self, call: Callable[..., NativeExprT], /, **expressifiable_args: Self | Any
    ) -> WindowFunction[SQLLazyFrameT, NativeExprT]:
        """Build the window function for an elementwise `call`.

        The window is applied to the inputs first (through `self.window_function`
        and `df._evaluate_window_expr`) and `call` is applied to the windowed
        results afterwards.

        Arguments:
            call: Elementwise callable, invoked as in `_callable_to_eval_series`.
            **expressifiable_args: Extra operands. Expressions are evaluated with
                `df._evaluate_window_expr`; anything else is converted with `_lit`.

        Returns:
            A window function producing one native expression per output of `self`.
        """

        def window_f(
            df: SQLLazyFrameT, window_inputs: WindowInputs[NativeExprT]
        ) -> Sequence[NativeExprT]:
            # If a function `f` is elementwise, and `g` is another function, then
            # - `f(g) over (window)`
            # - `f(g over (window))`
            # are equivalent.
            # Make sure to only use this if `call` is elementwise!
            native_series_list = self.window_function(df, window_inputs)
            other_native_series = {
                key: df._evaluate_window_expr(value, window_inputs)
                if self._is_expr(value)
                else self._lit(value)
                for key, value in expressifiable_args.items()
            }
            return [
                call(native_series, **other_native_series)
                for native_series in native_series_list
            ]

        return window_f

    def _with_window_function(
        self, window_function: WindowFunction[SQLLazyFrameT, NativeExprT]
    ) -> Self:
        """Return a copy of this expression that uses `window_function` under `over`."""
        return self.__class__(
            self._call,
            window_function,
            evaluate_output_names=self._evaluate_output_names,
            alias_output_names=self._alias_output_names,
            version=self._version,
            implementation=self._implementation,
        )

    def _with_callable(
        self, call: Callable[..., NativeExprT], /, **expressifiable_args: Self | Any
    ) -> Self:
        """Return a new expression applying `call` to each output of this one.

        No window function is passed on, so the result falls back to the default
        one provided by the `window_function` property.
        """
        return self.__class__(
            self._callable_to_eval_series(call, **expressifiable_args),
            evaluate_output_names=self._evaluate_output_names,
            alias_output_names=self._alias_output_names,
            version=self._version,
            implementation=self._implementation,
        )

    def _with_elementwise(
        self, call: Callable[..., NativeExprT], /, **expressifiable_args: Self | Any
    ) -> Self:
        """Return a new expression applying the elementwise `call`.

        Unlike `_with_callable`, the window function is derived too (see
        `_push_down_window_function`), so `call` must be elementwise.
        """
        return self.__class__(
            self._callable_to_eval_series(call, **expressifiable_args),
            self._push_down_window_function(call, **expressifiable_args),
            evaluate_output_names=self._evaluate_output_names,
            alias_output_names=self._alias_output_names,
            version=self._version,
            implementation=self._implementation,
        )

    def _with_binary(self, op: Callable[..., NativeExprT], other: Self | Any) -> Self:
        """Return a new expression for the binary operation `op(self, other)`.

        `op` is invoked as `op(native_expr, other=resolved_other)`, both for the
        regular evaluation and for the pushed-down window function.
        """
        return self.__class__(
            self._callable_to_eval_series(op, other=other),
            self._push_down_window_function(op, other=other),
            evaluate_output_names=self._evaluate_output_names,
            alias_output_names=self._alias_output_names,
            version=self._version,
            implementation=self._implementation,
        )

    def _with_alias_output_names(self, func: AliasNames | None, /) -> Self:
        """Return a copy of this expression with updated alias output names.

        Arguments:
            func: New renaming function. `None` clears any existing renaming;
                otherwise `func` is applied after the current renaming, if any.
        """
        current_alias_output_names = self._alias_output_names
        alias_output_names = (
            None
            if func is None
            else func
            if current_alias_output_names is None
            else lambda output_names: func(current_alias_output_names(output_names))
        )
        return self.__class__(
            self._call,
            self._window_function,
            evaluate_output_names=self._evaluate_output_names,
            alias_output_names=alias_output_names,
            version=self._version,
            implementation=self._implementation,
        )

    @property
    def window_function(self) -> WindowFunction[SQLLazyFrameT, NativeExprT]:
        """Window function used by `over`.

        Returns `_window_function` when one was set. Otherwise returns a default
        that wraps every output of the expression in `_window_expression`; that
        default asserts that no `order_by` was supplied.
        """

        def default_window_func(
            df: SQLLazyFrameT, inputs: WindowInputs[NativeExprT]
        ) -> Sequence[NativeExprT]:
            assert not inputs.order_by  # noqa: S101
            return [
                self._window_expression(expr, inputs.partition_by, inputs.order_by)
                for expr in self(df)
            ]

        return self._window_function or default_window_func

    def _function(self, name: str, *args: NativeExprT | PythonLiteral) -> NativeExprT:
        """Build a call to the function `name`; delegates to the namespace."""
        return self.__narwhals_namespace__()._function(name, *args)

    def _lit(self, value: Any) -> NativeExprT:
        """Convert `value` to a native literal; delegates to the namespace."""
        return self.__narwhals_namespace__()._lit(value)

    def _when(self, condition: NativeExprT, value: NativeExprT) -> NativeExprT:
        """Build a conditional from `condition` and `value`; delegates to the namespace."""
        return self.__narwhals_namespace__()._when(condition, value)

    def _coalesce(self, *expr: NativeExprT) -> NativeExprT:
        """Coalesce the given native expressions; delegates to the namespace."""
        return self.__narwhals_namespace__()._coalesce(*expr)

    def _count_star(self) -> NativeExprT:
        """Return the backend's row-count expression; implemented by concrete backends."""
        ...

    def _window_expression(
        self,
        expr: NativeExprT,
        partition_by: Sequence[str | NativeExprT] = (),
        order_by: Sequence[str | NativeExprT] = (),
        rows_start: int | None = None,
        rows_end: int | None = None,
        *,
        descending: Sequence[bool] | None = None,
        nulls_last: Sequence[bool] | None = None,
    ) -> NativeExprT:
        """Apply a window specification to `expr`; implemented by concrete backends.

        NOTE(review): the exact meaning of `rows_start` / `rows_end` is defined by
        the backend. Callers in this class pass non-positive offsets, presumably
        relative to the current row - confirm against the implementations.
        """
        ...

    def _cum_window_func(
        self,
        func_name: Literal["sum", "max", "min", "count", "product"],
        *,
        reverse: bool,
    ) -> WindowFunction[SQLLazyFrameT, NativeExprT]:
        """Build the window function backing the `cum_*` methods.

        Arguments:
            func_name: Name of the aggregate to accumulate.
            reverse: Used as the `descending` and `nulls_last` flag for every
                `order_by` key.

        Returns:
            A window function applying `func_name` over a frame with `rows_end=0`
            and `rows_start` left at its default.
        """

        def func(
            df: SQLLazyFrameT, inputs: WindowInputs[NativeExprT]
        ) -> Sequence[NativeExprT]:
            return [
                self._window_expression(
                    self._function(func_name, expr),
                    inputs.partition_by,
                    inputs.order_by,
                    descending=[reverse] * len(inputs.order_by),
                    nulls_last=[reverse] * len(inputs.order_by),
                    rows_end=0,
                )
                for expr in self(df)
            ]

        return func

    def _rolling_window_func(
        self,
        func_name: Literal["sum", "mean", "std", "var"],
        window_size: int,
        min_samples: int,
        ddof: int | None = None,
        *,
        center: bool,
    ) -> WindowFunction[SQLLazyFrameT, NativeExprT]:
        """Build the window function backing the `rolling_*` methods.

        Arguments:
            func_name: Rolling aggregate to compute.
            window_size: Number of rows spanned by the window.
            min_samples: The aggregate is only emitted where the windowed `count`
                is at least this value.
            ddof: Delta degrees of freedom, used for `"var"` and `"std"` only;
                0 selects the `*_pop` function and 1 the `*_samp` function.
            center: Whether the window is centred on the current row instead of
                ending at it.

        Returns:
            A window function producing one native expression per output of `self`.

        Raises:
            ValueError: When the returned function is evaluated with an unsupported
                `func_name`, or with a `ddof` other than 0 or 1 for `"var"` / `"std"`.
        """
        supported_funcs = ["sum", "mean", "std", "var"]
        if center:
            # For an even `window_size` the extra row goes to the preceding side.
            half = (window_size - 1) // 2
            remainder = (window_size - 1) % 2
            start = -(half + remainder)
            end = half
        else:
            start = -(window_size - 1)
            end = 0

        def func(
            df: SQLLazyFrameT, inputs: WindowInputs[NativeExprT]
        ) -> Sequence[NativeExprT]:
            if func_name in {"sum", "mean"}:
                func_: str = func_name
            elif func_name == "var" and ddof == 0:
                func_ = "var_pop"
            # Equality, not substring membership: `x in "var"` also accepts
            # "", "v", "a", "ar", ... and would silently map them to `var_samp`.
            elif func_name == "var" and ddof == 1:
                func_ = "var_samp"
            elif func_name == "std" and ddof == 0:
                func_ = "stddev_pop"
            elif func_name == "std" and ddof == 1:
                func_ = "stddev_samp"
            elif func_name in {"var", "std"}:  # pragma: no cover
                msg = f"Only ddof=0 and ddof=1 are currently supported for rolling_{func_name}."
                raise ValueError(msg)
            else:  # pragma: no cover
                msg = f"Only the following functions are supported: {supported_funcs}.\nGot: {func_name}."
                raise ValueError(msg)
            window_kwargs: Any = {
                "partition_by": inputs.partition_by,
                "order_by": inputs.order_by,
                "rows_start": start,
                "rows_end": end,
            }
            return [
                self._when(
                    self._window_expression(  # type: ignore[operator]
                        self._function("count", expr), **window_kwargs
                    )
                    >= self._lit(min_samples),
                    self._window_expression(self._function(func_, expr), **window_kwargs),
                )
                for expr in self(df)
            ]

        return func

    @classmethod
    def _is_expr(cls, obj: Self | Any) -> TypeIs[Self]:
        """Return whether `obj` exposes the `__narwhals_expr__` attribute."""
        return hasattr(obj, "__narwhals_expr__")

    @property
    def _backend_version(self) -> tuple[int, ...]:
        """Version of the backend, as reported by `_implementation`."""
        return self._implementation._backend_version()

    @classmethod
    def _alias_native(cls, expr: NativeExprT, name: str, /) -> NativeExprT:
        """Alias the native expression `expr` as `name`; implemented by concrete backends."""
        ...

    @classmethod
    def _from_elementwise_horizontal_op(
        cls, func: Callable[[Iterable[NativeExprT]], NativeExprT], *exprs: Self
    ) -> Self:
        """Combine the outputs of several expressions into a single one with `func`.

        Arguments:
            func: Reduces an iterable of native expressions to a single one.
            *exprs: Expressions whose outputs are fed to `func`. At least one is
                required, since version and implementation are taken from the first.

        Returns:
            An expression with a single output, usable with and without `over`.
        """

        def call(df: SQLLazyFrameT) -> Sequence[NativeExprT]:
            cols = (col for _expr in exprs for col in _expr(df))
            return [func(cols)]

        def window_function(
            df: SQLLazyFrameT, window_inputs: WindowInputs[NativeExprT]
        ) -> Sequence[NativeExprT]:
            cols = (
                col for _expr in exprs for col in _expr.window_function(df, window_inputs)
            )
            return [func(cols)]

        context = exprs[0]
        return cls(
            call,
            window_function=window_function,
            evaluate_output_names=combine_evaluate_output_names(*exprs),
            alias_output_names=combine_alias_output_names(*exprs),
            version=context._version,
            implementation=context._implementation,
        )

    # Binary
    #
    # Every operator below goes through `_with_binary`. The reflected variants
    # (`__r*__`) additionally alias their result as "literal".
    def __eq__(self, other: Self) -> Self:  # type: ignore[override]
        """Return the expression `self == other`."""
        return self._with_binary(lambda expr, other: expr.__eq__(other), other)

    def __ne__(self, other: Self) -> Self:  # type: ignore[override]
        """Return the expression `self != other`."""
        return self._with_binary(lambda expr, other: expr.__ne__(other), other)

    def __add__(self, other: Self) -> Self:
        """Return the expression `self + other`."""
        return self._with_binary(lambda expr, other: expr.__add__(other), other)

    def __sub__(self, other: Self) -> Self:
        """Return the expression `self - other`."""
        return self._with_binary(lambda expr, other: expr.__sub__(other), other)

    def __rsub__(self, other: Self) -> Self:
        """Return the expression `other - self`, aliased as "literal"."""
        return self._with_binary(lambda expr, other: other - expr, other).alias("literal")

    def __mul__(self, other: Self) -> Self:
        """Return the expression `self * other`."""
        return self._with_binary(lambda expr, other: expr.__mul__(other), other)

    def __truediv__(self, other: Self) -> Self:
        """Return the expression `self / other`."""
        return self._with_binary(lambda expr, other: expr.__truediv__(other), other)

    def __rtruediv__(self, other: Self) -> Self:
        """Return the expression `other / self`, aliased as "literal"."""
        return self._with_binary(lambda expr, other: other / expr, other).alias("literal")

    def __floordiv__(self, other: Self) -> Self:
        """Return the expression `self // other`."""
        return self._with_binary(lambda expr, other: expr.__floordiv__(other), other)

    def __rfloordiv__(self, other: Self) -> Self:
        """Return the expression `other // self`, aliased as "literal"."""
        return self._with_binary(lambda expr, other: other // expr, other).alias(
            "literal"
        )

    def __pow__(self, other: Self) -> Self:
        """Return the expression `self ** other`."""
        return self._with_binary(lambda expr, other: expr.__pow__(other), other)

    def __rpow__(self, other: Self) -> Self:
        """Return the expression `other ** self`, aliased as "literal"."""
        return self._with_binary(lambda expr, other: other**expr, other).alias("literal")

    def __mod__(self, other: Self) -> Self:
        """Return the expression `self % other`."""
        return self._with_binary(lambda expr, other: expr.__mod__(other), other)

    def __rmod__(self, other: Self) -> Self:
        """Return the expression `other % self`, aliased as "literal"."""
        return self._with_binary(lambda expr, other: other % expr, other).alias("literal")

    def __ge__(self, other: Self) -> Self:
        """Return the expression `self >= other`."""
        return self._with_binary(lambda expr, other: expr.__ge__(other), other)

    def __gt__(self, other: Self) -> Self:
        """Return the expression `self > other`."""
        return self._with_binary(lambda expr, other: expr.__gt__(other), other)

    def __le__(self, other: Self) -> Self:
        """Return the expression `self <= other`."""
        return self._with_binary(lambda expr, other: expr.__le__(other), other)

    def __lt__(self, other: Self) -> Self:
        """Return the expression `self < other`."""
        return self._with_binary(lambda expr, other: expr.__lt__(other), other)

    def __and__(self, other: Self) -> Self:
        """Return the expression `self & other`."""
        return self._with_binary(lambda expr, other: expr.__and__(other), other)

    def __or__(self, other: Self) -> Self:
        """Return the expression `self | other`."""
        return self._with_binary(lambda expr, other: expr.__or__(other), other)

    # Aggregations
    def all(self) -> Self:
        """Aggregate with `bool_and`, coalesced with a literal `True`.

        Under `over`, the aggregate is windowed by `partition_by` only.
        """

        def f(expr: NativeExprT) -> NativeExprT:
            return self._coalesce(self._function("bool_and", expr), self._lit(True))  # noqa: FBT003

        def window_f(
            df: SQLLazyFrameT, inputs: WindowInputs[NativeExprT]
        ) -> Sequence[NativeExprT]:
            return [
                self._coalesce(
                    self._window_expression(
                        self._function("bool_and", expr), inputs.partition_by
                    ),
                    self._lit(True),  # noqa: FBT003
                )
                for expr in self(df)
            ]

        return self._with_callable(f)._with_window_function(window_f)

    def any(self) -> Self:
        """Aggregate with `bool_or`, coalesced with a literal `False`.

        Under `over`, the aggregate is windowed by `partition_by` only.
        """

        def f(expr: NativeExprT) -> NativeExprT:
            return self._coalesce(self._function("bool_or", expr), self._lit(False))  # noqa: FBT003

        def window_f(
            df: SQLLazyFrameT, inputs: WindowInputs[NativeExprT]
        ) -> Sequence[NativeExprT]:
            return [
                self._coalesce(
                    self._window_expression(
                        self._function("bool_or", expr), inputs.partition_by
                    ),
                    self._lit(False),  # noqa: FBT003
                )
                for expr in self(df)
            ]

        return self._with_callable(f)._with_window_function(window_f)

    def max(self) -> Self:
        """Aggregate each output with the `max` function."""
        return self._with_callable(lambda expr: self._function("max", expr))

    def mean(self) -> Self:
        """Aggregate each output with the `mean` function."""
        return self._with_callable(lambda expr: self._function("mean", expr))

    def median(self) -> Self:
        """Aggregate each output with the `median` function."""
        return self._with_callable(lambda expr: self._function("median", expr))

    def min(self) -> Self:
        """Aggregate each output with the `min` function."""
        return self._with_callable(lambda expr: self._function("min", expr))

    def count(self) -> Self:
        """Aggregate each output with the `count` function."""
        return self._with_callable(lambda expr: self._function("count", expr))

    def sum(self) -> Self:
        """Aggregate with `sum`, coalesced with a literal `0`.

        Under `over`, the aggregate is windowed by `partition_by` only.
        """

        def f(expr: NativeExprT) -> NativeExprT:
            return self._coalesce(self._function("sum", expr), self._lit(0))

        def window_f(
            df: SQLLazyFrameT, inputs: WindowInputs[NativeExprT]
        ) -> Sequence[NativeExprT]:
            return [
                self._coalesce(
                    self._window_expression(
                        self._function("sum", expr), inputs.partition_by
                    ),
                    self._lit(0),
                )
                for expr in self(df)
            ]

        return self._with_callable(f)._with_window_function(window_f)

    # Elementwise
    def abs(self) -> Self:
        """Apply the `abs` function elementwise."""
        return self._with_elementwise(lambda expr: self._function("abs", expr))

    def clip(
        self,
        lower_bound: Self | NumericLiteral | TemporalLiteral | None,
        upper_bound: Self | NumericLiteral | TemporalLiteral | None,
    ) -> Self:
        """Clip values with `greatest` / `least`.

        Arguments:
            lower_bound: Lower limit; `None` means only the upper bound is applied.
            upper_bound: Upper limit; `None` means only the lower bound is applied.
        """

        def _clip_lower(expr: NativeExprT, lower_bound: Any) -> NativeExprT:
            return self._function("greatest", expr, lower_bound)

        def _clip_upper(expr: NativeExprT, upper_bound: Any) -> NativeExprT:
            return self._function("least", expr, upper_bound)

        def _clip_both(
            expr: NativeExprT, lower_bound: Any, upper_bound: Any
        ) -> NativeExprT:
            return self._function(
                "greatest", self._function("least", expr, upper_bound), lower_bound
            )

        if lower_bound is None:
            return self._with_elementwise(_clip_upper, upper_bound=upper_bound)
        if upper_bound is None:
            return self._with_elementwise(_clip_lower, lower_bound=lower_bound)
        return self._with_elementwise(
            _clip_both, lower_bound=lower_bound, upper_bound=upper_bound
        )

    def is_null(self) -> Self:
        """Apply the `isnull` function elementwise."""
        return self._with_elementwise(lambda expr: self._function("isnull", expr))

    def round(self, decimals: int) -> Self:
        """Apply the `round` function elementwise, passing `decimals` as a literal."""
        return self._with_elementwise(
            lambda expr: self._function("round", expr, self._lit(decimals))
        )

    def exp(self) -> Self:
        """Apply the `exp` function elementwise."""
        return self._with_elementwise(lambda expr: self._function("exp", expr))

    # Cumulative
    def cum_sum(self, *, reverse: bool) -> Self:
        """Cumulative `sum`; see `_cum_window_func` for the meaning of `reverse`."""
        return self._with_window_function(self._cum_window_func("sum", reverse=reverse))

    def cum_max(self, *, reverse: bool) -> Self:
        """Cumulative `max`; see `_cum_window_func` for the meaning of `reverse`."""
        return self._with_window_function(self._cum_window_func("max", reverse=reverse))

    def cum_min(self, *, reverse: bool) -> Self:
        """Cumulative `min`; see `_cum_window_func` for the meaning of `reverse`."""
        return self._with_window_function(self._cum_window_func("min", reverse=reverse))

    def cum_count(self, *, reverse: bool) -> Self:
        """Cumulative `count`; see `_cum_window_func` for the meaning of `reverse`."""
        return self._with_window_function(self._cum_window_func("count", reverse=reverse))

    def cum_prod(self, *, reverse: bool) -> Self:
        """Cumulative `product`; see `_cum_window_func` for the meaning of `reverse`."""
        return self._with_window_function(
            self._cum_window_func("product", reverse=reverse)
        )

    # Rolling
    def rolling_sum(self, window_size: int, *, min_samples: int, center: bool) -> Self:
        """Rolling `sum`; see `_rolling_window_func` for the arguments."""
        return self._with_window_function(
            self._rolling_window_func("sum", window_size, min_samples, center=center)
        )

    def rolling_mean(self, window_size: int, *, min_samples: int, center: bool) -> Self:
        """Rolling `mean`; see `_rolling_window_func` for the arguments."""
        return self._with_window_function(
            self._rolling_window_func("mean", window_size, min_samples, center=center)
        )

    def rolling_var(
        self, window_size: int, *, min_samples: int, center: bool, ddof: int
    ) -> Self:
        """Rolling variance; see `_rolling_window_func` for the arguments."""
        return self._with_window_function(
            self._rolling_window_func(
                "var", window_size, min_samples, ddof=ddof, center=center
            )
        )

    def rolling_std(
        self, window_size: int, *, min_samples: int, center: bool, ddof: int
    ) -> Self:
        """Rolling standard deviation; see `_rolling_window_func` for the arguments."""
        return self._with_window_function(
            self._rolling_window_func(
                "std", window_size, min_samples, ddof=ddof, center=center
            )
        )

    # Other window functions
    def diff(self) -> Self:
        """Subtract the windowed `lag` of each output from the output itself."""

        def func(
            df: SQLLazyFrameT, inputs: WindowInputs[NativeExprT]
        ) -> Sequence[NativeExprT]:
            return [
                expr  # type: ignore[operator]
                - self._window_expression(
                    self._function("lag", expr), inputs.partition_by, inputs.order_by
                )
                for expr in self(df)
            ]

        return self._with_window_function(func)

    def shift(self, n: int) -> Self:
        """Shift each output by `n`, using the windowed `lag` function."""

        def func(
            df: SQLLazyFrameT, inputs: WindowInputs[NativeExprT]
        ) -> Sequence[NativeExprT]:
            return [
                self._window_expression(
                    self._function("lag", expr, n), inputs.partition_by, inputs.order_by
                )
                for expr in self(df)
            ]

        return self._with_window_function(func)

    def is_first_distinct(self) -> Self:
        """Flag rows whose `row_number` within their value group equals 1.

        The value itself is appended to `partition_by`, and rows are numbered
        following `order_by`.
        """

        def func(
            df: SQLLazyFrameT, inputs: WindowInputs[NativeExprT]
        ) -> Sequence[NativeExprT]:
            # pyright checkers think the return type is `list[bool]` because of `==`
            return [
                cast(
                    "NativeExprT",
                    self._window_expression(
                        self._function("row_number"),
                        (*inputs.partition_by, expr),
                        inputs.order_by,
                    )
                    == self._lit(1),
                )
                for expr in self(df)
            ]

        return self._with_window_function(func)

    def is_last_distinct(self) -> Self:
        """Flag rows whose `row_number` within their value group equals 1, in reverse.

        Same as `is_first_distinct`, but every `order_by` key is marked as
        descending with nulls last.
        """

        def func(
            df: SQLLazyFrameT, inputs: WindowInputs[NativeExprT]
        ) -> Sequence[NativeExprT]:
            return [
                cast(
                    "NativeExprT",
                    self._window_expression(
                        self._function("row_number"),
                        (*inputs.partition_by, expr),
                        inputs.order_by,
                        descending=[True] * len(inputs.order_by),
                        nulls_last=[True] * len(inputs.order_by),
                    )
                    == self._lit(1),
                )
                for expr in self(df)
            ]

        return self._with_window_function(func)

    def rank(self, method: RankMethod, *, descending: bool) -> Self:
        """Rank values; rows where `isnull` holds are excluded through `_when`.

        Arguments:
            method: `"min"`, `"max"` and `"average"` build on the `rank` function,
                `"dense"` on `dense_rank`, and anything else on `row_number`.
                `"max"` and `"average"` are then adjusted using the number of
                rows sharing the same value.
            descending: Whether to rank in descending order of the values.
        """
        if method in {"min", "max", "average"}:
            func = self._function("rank")
        elif method == "dense":
            func = self._function("dense_rank")
        else:  # method == "ordinal"
            func = self._function("row_number")

        def _rank(
            expr: NativeExprT,
            partition_by: Sequence[str | NativeExprT] = (),
            order_by: Sequence[str | NativeExprT] = (),
            *,
            descending: Sequence[bool],
            nulls_last: Sequence[bool],
        ) -> NativeExprT:
            count_expr = self._count_star()
            window_kwargs: dict[str, Any] = {
                "partition_by": partition_by,
                "order_by": (expr, *order_by),
                "descending": descending,
                "nulls_last": nulls_last,
            }
            # Counts the rows sharing the same value within the partition.
            count_window_kwargs: dict[str, Any] = {"partition_by": (*partition_by, expr)}
            if method == "max":
                rank_expr = (
                    self._window_expression(func, **window_kwargs)  # type: ignore[operator]
                    + self._window_expression(count_expr, **count_window_kwargs)
                    - self._lit(1)
                )
            elif method == "average":
                rank_expr = self._window_expression(func, **window_kwargs) + (
                    self._window_expression(count_expr, **count_window_kwargs)  # type: ignore[operator]
                    - self._lit(1)
                ) / self._lit(2.0)
            else:
                rank_expr = self._window_expression(func, **window_kwargs)
            return self._when(~self._function("isnull", expr), rank_expr)  # type: ignore[operator]

        def _unpartitioned_rank(expr: NativeExprT) -> NativeExprT:
            return _rank(expr, descending=[descending], nulls_last=[True])

        def _partitioned_rank(
            df: SQLLazyFrameT, inputs: WindowInputs[NativeExprT]
        ) -> Sequence[NativeExprT]:
            # note: when `descending` / `nulls_last` are supported in `.over`, they should be respected here
            # https://github.com/narwhals-dev/narwhals/issues/2790
            return [
                _rank(
                    expr,
                    inputs.partition_by,
                    inputs.order_by,
                    descending=[descending] + [False] * len(inputs.order_by),
                    nulls_last=[True] + [False] * len(inputs.order_by),
                )
                for expr in self(df)
            ]

        return self._with_callable(_unpartitioned_rank)._with_window_function(
            _partitioned_rank
        )

    def is_unique(self) -> Self:
        """Flag rows whose value occurs exactly once.

        Compares the windowed `_count_star`, partitioned by the value (plus any
        `partition_by` given to `over`), with a literal `1`. The window variant
        asserts that no `order_by` was supplied.
        """

        def _is_unique(
            expr: NativeExprT, *partition_by: str | NativeExprT
        ) -> NativeExprT:
            return cast(
                "NativeExprT",
                self._window_expression(self._count_star(), (expr, *partition_by))
                == self._lit(1),
            )

        def _unpartitioned_is_unique(expr: NativeExprT) -> NativeExprT:
            return _is_unique(expr)

        def _partitioned_is_unique(
            df: SQLLazyFrameT, inputs: WindowInputs[NativeExprT]
        ) -> Sequence[NativeExprT]:
            assert not inputs.order_by  # noqa: S101
            return [_is_unique(expr, *inputs.partition_by) for expr in self(df)]

        return self._with_callable(_unpartitioned_is_unique)._with_window_function(
            _partitioned_is_unique
        )

    # Other
    def over(
        self, partition_by: Sequence[str | NativeExprT], order_by: Sequence[str]
    ) -> Self:
        """Evaluate this expression over the given window.

        Arguments:
            partition_by: Keys to partition by.
            order_by: Column names to order by.

        Returns:
            A new expression whose evaluation calls `window_function` with
            `WindowInputs(partition_by, order_by)`. The result itself carries no
            window function.
        """

        def func(df: SQLLazyFrameT) -> Sequence[NativeExprT]:
            return self.window_function(df, WindowInputs(partition_by, order_by))

        return self.__class__(
            func,
            evaluate_output_names=self._evaluate_output_names,
            alias_output_names=self._alias_output_names,
            version=self._version,
            implementation=self._implementation,
        )

    # Members declared through the `not_implemented` marker.
    arg_max: not_implemented = not_implemented()
    arg_min: not_implemented = not_implemented()
    arg_true: not_implemented = not_implemented()
    drop_nulls: not_implemented = not_implemented()
    ewm_mean: not_implemented = not_implemented()
    gather_every: not_implemented = not_implemented()
    head: not_implemented = not_implemented()
    map_batches: not_implemented = not_implemented()
    mode: not_implemented = not_implemented()
    replace_strict: not_implemented = not_implemented()
    sort: not_implemented = not_implemented()
    tail: not_implemented = not_implemented()
    sample: not_implemented = not_implemented()
    unique: not_implemented = not_implemented()
    # namespaces
    cat: not_implemented = not_implemented()  # type: ignore[assignment]