team-10/venv/Lib/site-packages/narwhals/_compliant/when_then.py
2025-08-02 02:00:33 +02:00

130 lines
4.2 KiB
Python

from __future__ import annotations
from typing import TYPE_CHECKING, Any, Protocol, TypeVar, cast
from narwhals._compliant.expr import CompliantExpr
from narwhals._compliant.typing import (
CompliantExprAny,
CompliantFrameAny,
CompliantSeriesOrNativeExprAny,
EagerDataFrameT,
EagerExprT,
EagerSeriesT,
LazyExprAny,
NativeSeriesT,
)
if TYPE_CHECKING:
from collections.abc import Sequence
from typing_extensions import Self, TypeAlias
from narwhals._compliant.typing import EvalSeries
from narwhals._utils import Implementation, Version, _LimitedContext
from narwhals.typing import NonNestedLiteral
__all__ = ["CompliantThen", "CompliantWhen", "EagerWhen"]
# Generic parameters shared by the protocols in this module; each is bounded by
# the matching "...Any" alias imported from `narwhals._compliant.typing`.
ExprT = TypeVar("ExprT", bound=CompliantExprAny)
# NOTE(review): `LazyExprT` is not referenced in the visible part of this
# module — presumably kept for lazy-backend counterparts; confirm before removing.
LazyExprT = TypeVar("LazyExprT", bound=LazyExprAny)
SeriesT = TypeVar("SeriesT", bound=CompliantSeriesOrNativeExprAny)
FrameT = TypeVar("FrameT", bound=CompliantFrameAny)
# `TypeAlias` is only imported under TYPE_CHECKING; the annotation is never
# evaluated at runtime thanks to `from __future__ import annotations`.
Scalar: TypeAlias = Any
"""A native literal value."""
# Kept as a string: `NonNestedLiteral` is only imported under TYPE_CHECKING, so
# evaluating this union at import time would fail. Type checkers treat it as a
# generic alias parameterised by `SeriesT` and `ExprT` (used as
# `IntoExpr[SeriesT, ExprT]` in annotations below).
IntoExpr: TypeAlias = "SeriesT | ExprT | NonNestedLiteral | Scalar"
"""Anything that is convertible into a `CompliantExpr`."""
class CompliantWhen(Protocol[FrameT, SeriesT, ExprT]):
    """Protocol for the condition half of a backend's `when/then/otherwise` chain.

    An instance stores the condition expression together with the `then` and
    `otherwise` values recorded so far. `from_expr` builds one from a bare
    condition; `then` hands control to the backend-specific `CompliantThen`
    class exposed by the `_then` property.
    """

    _condition: ExprT
    _then_value: IntoExpr[SeriesT, ExprT]
    _otherwise_value: IntoExpr[SeriesT, ExprT] | None
    _implementation: Implementation
    _version: Version

    @property
    def _then(self) -> type[CompliantThen[FrameT, SeriesT, ExprT, Self]]:
        """Backend-specific `CompliantThen` subclass paired with this When."""
        ...

    def __call__(self, compliant_frame: FrameT, /) -> Sequence[SeriesT]:
        """Evaluate the chain against `compliant_frame` (backend-provided)."""
        ...

    def then(
        self, value: IntoExpr[SeriesT, ExprT], /
    ) -> CompliantThen[FrameT, SeriesT, ExprT, Self]:
        """Attach `value` as the `then` branch and return the Then expression."""
        then_cls = self._then
        return then_cls.from_when(self, value)

    @classmethod
    def from_expr(cls, condition: ExprT, /, *, context: _LimitedContext) -> Self:
        """Create an instance around `condition`, bypassing `__init__`.

        Both branch values start out as None; implementation and version are
        copied from `context`.
        """
        instance = cls.__new__(cls)
        instance._condition = condition
        # No branches have been supplied yet.
        instance._then_value = instance._otherwise_value = None
        instance._implementation = context._implementation
        instance._version = context._version
        return instance
# Type of the originating `CompliantWhen` carried by `CompliantThen`. Declared
# contravariant because, within `CompliantThen`, it is only used as the type of
# the `when` parameter of `from_when` (an input position).
WhenT_contra = TypeVar(
    "WhenT_contra", bound=CompliantWhen[Any, Any, Any], contravariant=True
)
class CompliantThen(
    CompliantExpr[FrameT, SeriesT], Protocol[FrameT, SeriesT, ExprT, WhenT_contra]
):
    """Expression protocol returned by `CompliantWhen.then`.

    It keeps a reference to the originating When object and also installs that
    object as its `_call` evaluator (presumably invoked when the expression is
    evaluated — see `CompliantExpr`).
    """

    _call: EvalSeries[FrameT, SeriesT]
    _when_value: CompliantWhen[FrameT, SeriesT, ExprT]
    _implementation: Implementation
    _version: Version

    @classmethod
    def from_when(cls, when: WhenT_contra, then: IntoExpr[SeriesT, ExprT], /) -> Self:
        """Record `then` on `when` (mutating it) and wrap it in a new instance.

        Output-name callbacks are borrowed from `then` when it carries them;
        otherwise (e.g. for a plain scalar) the single name ``"literal"`` is used
        and no alias callback is set.
        """
        when._then_value = then
        instance = cls.__new__(cls)
        # The When object doubles as the evaluation callable.
        instance._call = instance._when_value = when
        instance._evaluate_output_names = getattr(
            then, "_evaluate_output_names", lambda _frame: ["literal"]
        )
        instance._alias_output_names = getattr(then, "_alias_output_names", None)
        instance._implementation = when._implementation
        instance._version = when._version
        return instance

    def otherwise(self, otherwise: IntoExpr[SeriesT, ExprT], /) -> ExprT:
        """Record `otherwise` on the originating When and return `self` as ExprT.

        No new object is created; the shared When object is mutated in place.
        """
        origin = self._when_value
        origin._otherwise_value = otherwise
        return cast("ExprT", self)
class EagerWhen(
    CompliantWhen[EagerDataFrameT, EagerSeriesT, EagerExprT],
    Protocol[EagerDataFrameT, EagerSeriesT, EagerExprT, NativeSeriesT],
):
    """`CompliantWhen` specialisation that evaluates eagerly on series.

    Backends supply `_if_then_else`, presumably the native vectorised
    if/else (confirm per backend).
    """

    def _if_then_else(
        self,
        when: NativeSeriesT,
        then: NativeSeriesT,
        otherwise: NativeSeriesT | NonNestedLiteral | Scalar,
        /,
    ) -> NativeSeriesT:
        """Backend hook combining native condition/then/otherwise values."""
        ...

    def __call__(self, df: EagerDataFrameT, /) -> Sequence[EagerSeriesT]:
        """Evaluate the chain on `df` and return a one-element list.

        The condition is evaluated first. A non-expression `then` value becomes
        a series named ``"literal"`` flagged for broadcasting. All pieces are
        aligned with the condition's `_align_full_broadcast` before calling
        `_if_then_else`. A non-expression `otherwise` value, including None, is
        passed through to `_if_then_else` unchanged.
        """
        condition = self._condition
        check_expr = condition._is_expr
        mask: EagerSeriesT = condition(df)[0]
        chosen: EagerSeriesT
        broadcast_align = mask._align_full_broadcast

        then_value = self._then_value
        if not check_expr(then_value):
            chosen = mask.alias("literal")._from_scalar(then_value)
            chosen._broadcast = True
        else:
            chosen = then_value(df)[0]

        fallback = self._otherwise_value
        if not check_expr(fallback):
            mask, chosen = broadcast_align(mask, chosen)
            native_result = self._if_then_else(mask.native, chosen.native, fallback)
        else:
            fallback_series = fallback(df)[0]
            mask, chosen, fallback_series = broadcast_align(
                mask, chosen, fallback_series
            )
            native_result = self._if_then_else(
                mask.native, chosen.native, fallback_series.native
            )
        return [chosen._with_native(native_result)]