106 lines
3.6 KiB
Python
106 lines
3.6 KiB
Python
from __future__ import annotations
|
|
|
|
from typing import TYPE_CHECKING, Protocol
|
|
|
|
from narwhals._compliant.typing import NativeExprT
|
|
from narwhals._compliant.when_then import CompliantThen, CompliantWhen
|
|
from narwhals._sql.typing import SQLExprT, SQLLazyFrameT
|
|
|
|
if TYPE_CHECKING:
|
|
from collections.abc import Sequence
|
|
|
|
from typing_extensions import Self
|
|
|
|
from narwhals._compliant.typing import WindowFunction
|
|
from narwhals._compliant.when_then import IntoExpr
|
|
from narwhals._compliant.window import WindowInputs
|
|
from narwhals._utils import _LimitedContext
|
|
|
|
|
|
class SQLWhen(
    CompliantWhen[SQLLazyFrameT, NativeExprT, SQLExprT],
    Protocol[SQLLazyFrameT, NativeExprT, SQLExprT],
):
    """Shared `when/then/otherwise` evaluation for SQL-like lazy backends.

    Concrete backends supply `_then` (the matching `SQLThen` class). This
    protocol turns the stored condition, `then` value and optional
    `otherwise` value into one native conditional expression, both for plain
    evaluation (`__call__`) and for evaluation over a window
    (`_window_function`).
    """

    @property
    def _then(self) -> type[SQLThen[SQLLazyFrameT, NativeExprT, SQLExprT]]: ...

    def __call__(self, df: SQLLazyFrameT) -> Sequence[NativeExprT]:
        """Evaluate the conditional against `df`.

        Returns:
            A one-element sequence holding the native conditional expression.
        """
        return self._evaluate_branches(df, None)

    @classmethod
    def from_expr(cls, condition: SQLExprT, /, *, context: _LimitedContext) -> Self:
        """Start a `when` chain from `condition`, with no `then`/`otherwise` yet.

        The instance is created via `__new__`, so no `__init__` runs; the
        implementation and version are copied from `context`.
        """
        obj = cls.__new__(cls)
        obj._condition = condition
        obj._then_value = None
        obj._otherwise_value = None
        obj._implementation = context._implementation
        obj._version = context._version
        return obj

    def _window_function(
        self, df: SQLLazyFrameT, window_inputs: WindowInputs[NativeExprT]
    ) -> Sequence[NativeExprT]:
        """Evaluate the conditional with every expression branch windowed.

        Same as `__call__`, except expression operands are evaluated through
        their `window_function(df, window_inputs)` instead of
        `df._evaluate_expr`.
        """
        return self._evaluate_branches(df, window_inputs)

    def _evaluate_branches(
        self,
        df: SQLLazyFrameT,
        window_inputs: WindowInputs[NativeExprT] | None,
    ) -> Sequence[NativeExprT]:
        """Build `when(condition, then)[.otherwise(other)]` as a native expression.

        Args:
            df: Frame the expressions are evaluated against.
            window_inputs: When None, expressions go through
                `df._evaluate_expr`; otherwise through
                `expr.window_function(df, window_inputs)[0]`.

        Returns:
            A one-element list with the resulting native expression.
        """
        # Build the namespace once and reuse it for both `_when` and `_lit`.
        ns = df.__narwhals_namespace__()
        is_expr = self._condition._is_expr

        def evaluate(expr: SQLExprT) -> NativeExprT:
            # Single native expression for a compliant expression operand.
            if window_inputs is None:
                return df._evaluate_expr(expr)
            return expr.window_function(df, window_inputs)[0]

        def to_native(value: IntoExpr[NativeExprT, SQLExprT]) -> NativeExprT:
            # Non-expression operands (scalars, None, ...) become literals.
            return evaluate(value) if is_expr(value) else ns._lit(value)

        # Operands are evaluated in order: condition, then, otherwise.
        condition = evaluate(self._condition)
        then = to_native(self._then_value)
        other_ = self._otherwise_value
        if other_ is None:
            # No `otherwise` recorded: no `.otherwise(...)` is applied.
            return [ns._when(condition, then)]
        otherwise = to_native(other_)
        return [ns._when(condition, then).otherwise(otherwise)]
|
|
|
|
|
|
class SQLThen(
    CompliantThen[
        SQLLazyFrameT,
        NativeExprT,
        SQLExprT,
        SQLWhen[SQLLazyFrameT, NativeExprT, SQLExprT],
    ],
    Protocol[SQLLazyFrameT, NativeExprT, SQLExprT],
):
    """The `then` stage of a SQL-backend `when/then/otherwise` chain.

    Evaluation is delegated to the originating `SQLWhen`: its `__call__` is
    stored as this object's `_call`, and its window evaluator as
    `_window_function`.
    """

    # Window-aware evaluator taken from the originating `SQLWhen`
    # (annotated as possibly None).
    _window_function: WindowFunction[SQLLazyFrameT, NativeExprT] | None

    @classmethod
    def from_when(
        cls,
        when: SQLWhen[SQLLazyFrameT, NativeExprT, SQLExprT],
        then: IntoExpr[NativeExprT, SQLExprT],
        /,
    ) -> Self:
        """Create the `then` stage from `when`, recording `then` on it.

        Note: this mutates `when` by setting its `_then_value` to `then`.

        Output names come from `then` when it exposes
        `_evaluate_output_names` / `_alias_output_names`; otherwise the
        output is named "literal" and no alias function is set.
        """
        # Record the `then` operand on the shared `when` object first.
        when._then_value = then

        def literal_output_names(_df):  # noqa: ANN001, ANN202
            # Fallback naming for operands that are not expressions.
            return ["literal"]

        instance = cls.__new__(cls)
        instance._call = when
        instance._window_function = when._window_function
        instance._when_value = when
        instance._evaluate_output_names = getattr(
            then, "_evaluate_output_names", literal_output_names
        )
        instance._alias_output_names = getattr(then, "_alias_output_names", None)
        instance._implementation, instance._version = (
            when._implementation,
            when._version,
        )
        return instance
|