team-10/env/Lib/site-packages/narwhals/_expression_parsing.py

631 lines
23 KiB
Python
Raw Permalink Normal View History

2025-08-02 07:34:44 +02:00
# Utilities for expression parsing
# Useful for backends which don't have any concept of expressions, such
# as pandas or PyArrow.
from __future__ import annotations
from enum import Enum, auto
from itertools import chain
from typing import TYPE_CHECKING, Any, Literal, TypeVar, cast
from narwhals._utils import is_compliant_expr
from narwhals.dependencies import is_narwhals_series, is_numpy_array
from narwhals.exceptions import InvalidOperationError, MultiOutputExpressionError
if TYPE_CHECKING:
from collections.abc import Sequence
from typing_extensions import Never, TypeIs
from narwhals._compliant import CompliantExpr, CompliantFrameT
from narwhals._compliant.typing import (
AliasNames,
CompliantExprAny,
CompliantFrameAny,
CompliantNamespaceAny,
EagerNamespaceAny,
EvalNames,
)
from narwhals.expr import Expr
from narwhals.series import Series
from narwhals.typing import IntoExpr, NonNestedLiteral, _1DArray
# Generic type variable.
# NOTE(review): `T` is not referenced by any code visible in this module;
# confirm it is not imported from here elsewhere before removing it.
T = TypeVar("T")
def is_expr(obj: Any) -> TypeIs[Expr]:
    """Return whether `obj` is an instance of the Narwhals `Expr` class.

    `narwhals.expr` is imported at call time rather than at module level
    (at module level it is only imported under `TYPE_CHECKING`).
    """
    from narwhals.expr import Expr as _Expr

    result = isinstance(obj, _Expr)
    return result
def is_series(obj: Any) -> TypeIs[Series[Any]]:
    """Check whether `obj` is a Narwhals Series.

    `narwhals.series` is imported at call time rather than at module level
    (at module level it is only imported under `TYPE_CHECKING`).
    """
    from narwhals.series import Series

    return isinstance(obj, Series)
def combine_evaluate_output_names(
    *exprs: CompliantExpr[CompliantFrameT, Any],
) -> EvalNames[CompliantFrameT]:
    """Build an output-name evaluator that follows the left-hand rule.

    Only the first output name of the left-most expression is kept, e.g.
    `nw.sum_horizontal(expr1, expr2)` is named after the first name of `expr1`.

    Raises:
        AssertionError: If the left-most argument is not a compliant expression.
    """
    leftmost = exprs[0]
    if not is_compliant_expr(leftmost):  # pragma: no cover
        msg = f"Safety assertion failed, expected expression, got: {type(leftmost)}. Please report a bug."
        raise AssertionError(msg)

    def evaluate_output_names(df: CompliantFrameT) -> Sequence[str]:
        names = leftmost._evaluate_output_names(df)
        return names[:1]

    return evaluate_output_names
def combine_alias_output_names(*exprs: CompliantExprAny) -> AliasNames | None:
    """Build an alias function that follows the left-hand rule.

    E.g. for `nw.sum_horizontal(expr1.alias(alias), expr2)`, the aliasing
    function of `expr1` is applied and only the first resulting name is kept.

    Returns:
        `None` when the left-most expression has no aliasing function;
        otherwise a callable that applies that function to the given names
        and keeps at most the first result.
    """
    leftmost = exprs[0]
    if leftmost._alias_output_names is not None:

        def alias_output_names(names: Sequence[str]) -> Sequence[str]:
            return leftmost._alias_output_names(names)[:1]  # type: ignore[misc]

        return alias_output_names
    return None
def extract_compliant(
    plx: CompliantNamespaceAny,
    other: IntoExpr | NonNestedLiteral | _1DArray,
    *,
    str_as_lit: bool,
) -> CompliantExprAny | NonNestedLiteral:
    """Convert `other` into something the compliant namespace `plx` can consume.

    - Narwhals expressions are lowered to compliant expressions.
    - Strings select a column via `plx.col`, unless `str_as_lit` is set, in
      which case they are returned unchanged.
    - Narwhals Series are turned into expressions wrapping their compliant series.
    - NumPy arrays are wrapped via the namespace's `_series.from_numpy`; `plx`
      is cast to an eager namespace for this (the caller is assumed to pass
      one whenever NumPy input is possible).
    - Anything else is returned as-is.
    """
    if is_expr(other):
        return other._to_compliant_expr(plx)
    if isinstance(other, str) and not str_as_lit:
        return plx.col(other)
    if is_narwhals_series(other):
        compliant_series = other._compliant_series
        return compliant_series._to_expr()
    if not is_numpy_array(other):
        return other
    eager_ns = cast("EagerNamespaceAny", plx)
    series = eager_ns._series.from_numpy(other, context=eager_ns)
    return series._to_expr()
def evaluate_output_names_and_aliases(
    expr: CompliantExprAny, df: CompliantFrameAny, exclude: Sequence[str]
) -> tuple[Sequence[str], Sequence[str]]:
    """Evaluate `expr`'s output names on `df` together with their aliases.

    Arguments:
        expr: Compliant expression whose output names are evaluated.
        df: Compliant frame the names are evaluated against.
        exclude: Column names to drop from the result. Only applied when the
            expression's expansion kind is multi-unnamed (e.g. `nw.all()`);
            for other expansion kinds the names are returned untouched.

    Returns:
        An `(output_names, aliases)` pair of equal length. When the expression
        has no alias function, `aliases` is the same object as `output_names`
        (unless exclusion filtered them). When every name is excluded, two
        empty tuples are returned.
    """
    output_names = expr._evaluate_output_names(df)
    aliases = (
        output_names
        if expr._alias_output_names is None
        else expr._alias_output_names(output_names)
    )
    if exclude:
        assert expr._metadata is not None  # noqa: S101
        if expr._metadata.expansion_kind.is_multi_unnamed():
            # Build the lookup once instead of scanning `exclude` for every name.
            excluded = set(exclude)
            kept = [
                (name, alias)
                for name, alias in zip(output_names, aliases)
                if name not in excluded
            ]
            # `output_names, aliases = zip(*kept)` would raise `ValueError` when
            # every name is excluded (nothing to unpack), so split explicitly.
            output_names = tuple(name for name, _ in kept)
            aliases = tuple(alias for _, alias in kept)
    return output_names, aliases
class ExprKind(Enum):
    """Classify an expression node by how it relates to the rows it operates on."""

    LITERAL = auto()
    """e.g. `nw.lit(1)`"""

    AGGREGATION = auto()
    """Reduces to a single value, not affected by row order, e.g. `nw.col('a').mean()`"""

    ORDERABLE_AGGREGATION = auto()
    """Reduces to a single value, affected by row order, e.g. `nw.col('a').arg_max()`"""

    ELEMENTWISE = auto()
    """Preserves length, can operate without context for surrounding rows, e.g. `nw.col('a').abs()`."""

    ORDERABLE_WINDOW = auto()
    """Depends on the rows around it and on their order, e.g. `diff`."""

    WINDOW = auto()
    """Depends on the rows around it and possibly their order, e.g. `rank`."""

    FILTRATION = auto()
    """Changes length, not affected by row order, e.g. `drop_nulls`."""

    ORDERABLE_FILTRATION = auto()
    """Changes length, affected by row order, e.g. `tail`."""

    NARY = auto()
    """Results from the combination of multiple expressions."""

    OVER = auto()
    """Results from calling `.over` on expression."""

    UNKNOWN = auto()
    """Based on the information we have, we can't determine the ExprKind."""

    @property
    def is_scalar_like(self) -> bool:
        # Only literals and (non-orderable) aggregations count as scalar-like here.
        return self is ExprKind.LITERAL or self is ExprKind.AGGREGATION

    @property
    def is_orderable_window(self) -> bool:
        # Orderable aggregations are included on purpose: `with_ordered_over`
        # uses this to discount their order dependence as well.
        return (
            self is ExprKind.ORDERABLE_WINDOW
            or self is ExprKind.ORDERABLE_AGGREGATION
        )

    @classmethod
    def from_expr(cls, obj: Expr) -> ExprKind:
        """Infer the kind of a Narwhals expression from its metadata flags."""
        meta = obj._metadata
        if meta.is_literal:
            kind = ExprKind.LITERAL
        elif meta.is_scalar_like:
            kind = ExprKind.AGGREGATION
        elif meta.is_elementwise:
            kind = ExprKind.ELEMENTWISE
        else:
            kind = ExprKind.UNKNOWN
        return kind

    @classmethod
    def from_into_expr(
        cls, obj: IntoExpr | NonNestedLiteral | _1DArray, *, str_as_lit: bool
    ) -> ExprKind:
        """Infer the kind of anything that can be turned into an expression.

        Series, NumPy arrays and column-name strings count as elementwise;
        every other non-expression value counts as a literal.
        """
        if is_expr(obj):
            return cls.from_expr(obj)
        behaves_like_column = (
            is_narwhals_series(obj)
            or is_numpy_array(obj)
            or (isinstance(obj, str) and not str_as_lit)
        )
        return ExprKind.ELEMENTWISE if behaves_like_column else ExprKind.LITERAL
def is_scalar_like(
    obj: ExprKind,
) -> TypeIs[Literal[ExprKind.LITERAL, ExprKind.AGGREGATION]]:
    """Type-narrowing wrapper around `ExprKind.is_scalar_like`.

    Returns the property's value unchanged. The `TypeIs` return annotation lets
    type checkers narrow `obj` to `LITERAL` / `AGGREGATION` when it is true.
    """
    scalar_like = obj.is_scalar_like
    return scalar_like
class ExpansionKind(Enum):
    """Classify how an expression expands into output columns."""

    SINGLE = auto()
    """e.g. `nw.col('a'), nw.sum_horizontal(nw.all())`"""

    MULTI_NAMED = auto()
    """e.g. `nw.col('a', 'b')`"""

    MULTI_UNNAMED = auto()
    """e.g. `nw.all()`, nw.nth(0, 1)"""

    def is_multi_unnamed(self) -> bool:
        return self is ExpansionKind.MULTI_UNNAMED

    def is_multi_output(self) -> bool:
        # Both the named and the unnamed multi-column variants qualify.
        return self is ExpansionKind.MULTI_NAMED or self is ExpansionKind.MULTI_UNNAMED

    def __and__(self, other: ExpansionKind) -> Literal[ExpansionKind.MULTI_UNNAMED]:
        """Combine the expansion kinds of two operands.

        Only `MULTI_UNNAMED & MULTI_UNNAMED` is supported (e.g.
        `nw.selectors.all() - nw.selectors.numeric()`). Every other combination
        is deliberately treated as ambiguous rather than guessed at.

        Raises:
            AssertionError: For any combination other than two `MULTI_UNNAMED`.
        """
        both_unnamed = (
            self is ExpansionKind.MULTI_UNNAMED
            and other is ExpansionKind.MULTI_UNNAMED
        )
        if not both_unnamed:  # pragma: no cover
            msg = f"Unsupported ExpansionKind combination, got {self} and {other}, please report a bug."
            raise AssertionError(msg)
        return ExpansionKind.MULTI_UNNAMED
class ExprMetadata:
    """Expression metadata.

    None of the methods below modify `self`: each `with_*` method and each
    static/class constructor returns a new `ExprMetadata`.

    Parameters:
        expansion_kind: What kind of expansion the expression performs.
        has_windows: Whether it already contains window functions.
        is_elementwise: Whether it can operate row-by-row without context
            of the other rows around it.
        is_literal: Whether it is just a literal wrapped in an expression.
        is_scalar_like: Whether it is a literal or an aggregation.
        last_node: The ExprKind of the last node.
        n_orderable_ops: The number of order-dependent operations. In the
            lazy case, this number must be `0` by the time the expression
            is evaluated.
        preserves_length: Whether the expression preserves the input length.
    """

    __slots__ = (
        "expansion_kind",
        "has_windows",
        "is_elementwise",
        "is_literal",
        "is_scalar_like",
        "last_node",
        "n_orderable_ops",
        "preserves_length",
    )

    def __init__(
        self,
        expansion_kind: ExpansionKind,
        last_node: ExprKind,
        *,
        has_windows: bool = False,
        n_orderable_ops: int = 0,
        preserves_length: bool = True,
        is_elementwise: bool = True,
        is_scalar_like: bool = False,
        is_literal: bool = False,
    ) -> None:
        # Invariants: a literal is always scalar-like, and an elementwise
        # expression always preserves length.
        if is_literal:
            assert is_scalar_like  # noqa: S101  # debug assertion
        if is_elementwise:
            assert preserves_length  # noqa: S101  # debug assertion
        self.expansion_kind: ExpansionKind = expansion_kind
        self.last_node: ExprKind = last_node
        self.has_windows: bool = has_windows
        self.n_orderable_ops: int = n_orderable_ops
        self.is_elementwise: bool = is_elementwise
        self.preserves_length: bool = preserves_length
        self.is_scalar_like: bool = is_scalar_like
        self.is_literal: bool = is_literal

    def __init_subclass__(cls, /, *args: Any, **kwds: Any) -> Never:  # pragma: no cover
        msg = f"Cannot subclass {cls.__name__!r}"
        raise TypeError(msg)

    def __repr__(self) -> str:  # pragma: no cover
        return (
            f"ExprMetadata(\n"
            f"  expansion_kind: {self.expansion_kind},\n"
            f"  last_node: {self.last_node},\n"
            f"  has_windows: {self.has_windows},\n"
            f"  n_orderable_ops: {self.n_orderable_ops},\n"
            f"  is_elementwise: {self.is_elementwise},\n"
            f"  preserves_length: {self.preserves_length},\n"
            f"  is_scalar_like: {self.is_scalar_like},\n"
            f"  is_literal: {self.is_literal},\n"
            ")"
        )

    @property
    def is_filtration(self) -> bool:
        """Whether the expression changes length without reducing to a scalar."""
        return not self.preserves_length and not self.is_scalar_like

    def with_aggregation(self) -> ExprMetadata:
        if self.is_scalar_like:
            msg = "Can't apply aggregations to scalar-like expressions."
            raise InvalidOperationError(msg)
        return ExprMetadata(
            self.expansion_kind,
            ExprKind.AGGREGATION,
            has_windows=self.has_windows,
            n_orderable_ops=self.n_orderable_ops,
            preserves_length=False,
            is_elementwise=False,
            is_scalar_like=True,
            is_literal=False,
        )

    def with_orderable_aggregation(self) -> ExprMetadata:
        # Deprecated, used only in stable.v1.
        if self.is_scalar_like:  # pragma: no cover
            msg = "Can't apply aggregations to scalar-like expressions."
            raise InvalidOperationError(msg)
        return ExprMetadata(
            self.expansion_kind,
            ExprKind.ORDERABLE_AGGREGATION,
            has_windows=self.has_windows,
            n_orderable_ops=self.n_orderable_ops + 1,
            preserves_length=False,
            is_elementwise=False,
            is_scalar_like=True,
            is_literal=False,
        )

    def with_elementwise_op(self) -> ExprMetadata:
        # Every flag is carried over unchanged; only `last_node` is updated.
        return ExprMetadata(
            self.expansion_kind,
            ExprKind.ELEMENTWISE,
            has_windows=self.has_windows,
            n_orderable_ops=self.n_orderable_ops,
            preserves_length=self.preserves_length,
            is_elementwise=self.is_elementwise,
            is_scalar_like=self.is_scalar_like,
            is_literal=self.is_literal,
        )

    def with_window(self) -> ExprMetadata:
        # Window function which may (but doesn't have to) be used with `over(order_by=...)`.
        if self.is_scalar_like:
            msg = "Can't apply window (e.g. `rank`) to scalar-like expression."
            raise InvalidOperationError(msg)
        return ExprMetadata(
            self.expansion_kind,
            ExprKind.WINDOW,
            has_windows=self.has_windows,
            # The function isn't order-dependent (but, users can still use `order_by` if they wish!),
            # so we don't increment `n_orderable_ops`.
            n_orderable_ops=self.n_orderable_ops,
            preserves_length=self.preserves_length,
            is_elementwise=False,
            is_scalar_like=False,
            is_literal=False,
        )

    def with_orderable_window(self) -> ExprMetadata:
        # Window function which must be used with `over(order_by=...)`.
        if self.is_scalar_like:
            msg = "Can't apply orderable window (e.g. `diff`, `shift`) to scalar-like expression."
            raise InvalidOperationError(msg)
        return ExprMetadata(
            self.expansion_kind,
            ExprKind.ORDERABLE_WINDOW,
            has_windows=self.has_windows,
            n_orderable_ops=self.n_orderable_ops + 1,
            preserves_length=self.preserves_length,
            is_elementwise=False,
            is_scalar_like=False,
            is_literal=False,
        )

    def _check_can_apply_over(self) -> None:
        # Validation shared by `with_ordered_over` and `with_partitioned_over`.
        if self.has_windows:
            msg = "Cannot nest `over` statements."
            raise InvalidOperationError(msg)
        if self.is_elementwise or self.is_filtration:
            msg = (
                "Cannot use `over` on expressions which are elementwise\n"
                "(e.g. `abs`) or which change length (e.g. `drop_nulls`)."
            )
            raise InvalidOperationError(msg)

    def with_ordered_over(self) -> ExprMetadata:
        """Metadata after `.over(order_by=...)` (optionally also partitioned).

        Raises:
            InvalidOperationError: If `over` is nested, applied to an elementwise
                or length-changing expression, or the expression has neither an
                outstanding order-dependent op nor a `WINDOW` last node.
        """
        self._check_can_apply_over()
        n_orderable_ops = self.n_orderable_ops
        if not n_orderable_ops and self.last_node is not ExprKind.WINDOW:
            msg = (
                "Cannot use `order_by` in `over` on expression which isn't orderable.\n"
                "If your expression is orderable, then make sure that `over(order_by=...)`\n"
                "comes immediately after the order-dependent expression.\n\n"
                "Hint: instead of\n"
                "  - `(nw.col('price').diff() + 1).over(order_by='date')`\n"
                "write:\n"
                "  + `nw.col('price').diff().over(order_by='date') + 1`\n"
            )
            raise InvalidOperationError(msg)
        if self.last_node.is_orderable_window:
            # `order_by` supplies the ordering the last orderable window or
            # aggregation needed, so it no longer counts as outstanding.
            n_orderable_ops -= 1
        return ExprMetadata(
            self.expansion_kind,
            ExprKind.OVER,
            has_windows=True,
            n_orderable_ops=n_orderable_ops,
            preserves_length=True,
            is_elementwise=False,
            is_scalar_like=False,
            is_literal=False,
        )

    def with_partitioned_over(self) -> ExprMetadata:
        """Metadata after `.over(...)` without `order_by`.

        The count of order-dependent ops is carried over unchanged.

        Raises:
            InvalidOperationError: If `over` is nested, or applied to an
                elementwise or length-changing expression.
        """
        self._check_can_apply_over()
        return ExprMetadata(
            self.expansion_kind,
            ExprKind.OVER,
            has_windows=True,
            n_orderable_ops=self.n_orderable_ops,
            preserves_length=True,
            is_elementwise=False,
            is_scalar_like=False,
            is_literal=False,
        )

    def with_filtration(self) -> ExprMetadata:
        if self.is_scalar_like:
            msg = "Can't apply filtration (e.g. `drop_nulls`) to scalar-like expression."
            raise InvalidOperationError(msg)
        return ExprMetadata(
            self.expansion_kind,
            ExprKind.FILTRATION,
            has_windows=self.has_windows,
            n_orderable_ops=self.n_orderable_ops,
            preserves_length=False,
            is_elementwise=False,
            is_scalar_like=False,
            is_literal=False,
        )

    def with_orderable_filtration(self) -> ExprMetadata:
        if self.is_scalar_like:
            # Example matches `ExprKind.ORDERABLE_FILTRATION` (`drop_nulls` is
            # the non-orderable kind).
            msg = "Can't apply filtration (e.g. `tail`) to scalar-like expression."
            raise InvalidOperationError(msg)
        return ExprMetadata(
            self.expansion_kind,
            ExprKind.ORDERABLE_FILTRATION,
            has_windows=self.has_windows,
            n_orderable_ops=self.n_orderable_ops + 1,
            preserves_length=False,
            is_elementwise=False,
            is_scalar_like=False,
            is_literal=False,
        )

    @staticmethod
    def aggregation() -> ExprMetadata:
        return ExprMetadata(
            ExpansionKind.SINGLE,
            ExprKind.AGGREGATION,
            is_elementwise=False,
            preserves_length=False,
            is_scalar_like=True,
        )

    @staticmethod
    def literal() -> ExprMetadata:
        return ExprMetadata(
            ExpansionKind.SINGLE,
            ExprKind.LITERAL,
            is_elementwise=False,
            preserves_length=False,
            is_literal=True,
            is_scalar_like=True,
        )

    @staticmethod
    def selector_single() -> ExprMetadata:
        # e.g. `nw.col('a')`, `nw.nth(0)`
        return ExprMetadata(ExpansionKind.SINGLE, ExprKind.ELEMENTWISE)

    @staticmethod
    def selector_multi_named() -> ExprMetadata:
        # e.g. `nw.col('a', 'b')`
        return ExprMetadata(ExpansionKind.MULTI_NAMED, ExprKind.ELEMENTWISE)

    @staticmethod
    def selector_multi_unnamed() -> ExprMetadata:
        # e.g. `nw.all()`
        return ExprMetadata(ExpansionKind.MULTI_UNNAMED, ExprKind.ELEMENTWISE)

    @classmethod
    def from_binary_op(cls, lhs: Expr, rhs: IntoExpr, /) -> ExprMetadata:
        # We may be able to allow multi-output rhs in the future:
        # https://github.com/narwhals-dev/narwhals/issues/2244.
        return combine_metadata(
            lhs, rhs, str_as_lit=True, allow_multi_output=False, to_single_output=False
        )

    @classmethod
    def from_horizontal_op(cls, *exprs: IntoExpr) -> ExprMetadata:
        return combine_metadata(
            *exprs, str_as_lit=False, allow_multi_output=True, to_single_output=True
        )
def combine_metadata(
    *args: IntoExpr | object | None,
    str_as_lit: bool,
    allow_multi_output: bool,
    to_single_output: bool,
) -> ExprMetadata:
    """Merge the metadata of several inputs into that of one n-ary expression.

    Arguments:
        args: Arguments, maybe expressions, literals, or Series.
        str_as_lit: Whether to interpret strings as literals or as column names.
        allow_multi_output: Whether inputs after the first may be multi-output
            (the left-most input is always allowed to be).
        to_single_output: Whether the result is always single-output, regardless
            of the inputs (e.g. `nw.sum_horizontal`).

    Returns:
        Metadata whose `last_node` is `ExprKind.NARY`.

    Raises:
        MultiOutputExpressionError: If a non-leading input is multi-output while
            `allow_multi_output` is false.
        InvalidOperationError: If more than one input changes length, or a
            length-changing input is mixed with a length-preserving one.
    """
    filtration_count = 0
    expansion_kind = ExpansionKind.SINGLE
    any_has_windows = False
    orderable_ops = 0
    # OR-accumulated: the result preserves length if any input does.
    any_preserves_length = False
    # AND-accumulated: each flag holds only if it holds for every input.
    all_elementwise = True
    all_scalar_like = True
    all_literal = True

    for position, arg in enumerate(args):
        if (isinstance(arg, str) and not str_as_lit) or is_series(arg):
            # Column names and Series are full-length, non-literal inputs.
            any_preserves_length = True
            all_scalar_like = False
            all_literal = False
            continue
        if not is_expr(arg):
            # Any other value is a plain literal and leaves every flag as-is.
            continue

        metadata = arg._metadata
        arg_expansion = metadata.expansion_kind
        if arg_expansion.is_multi_output():
            if position > 0 and not allow_multi_output:
                # Left-most argument is always allowed to be multi-output.
                msg = (
                    "Multi-output expressions (e.g. nw.col('a', 'b'), nw.all()) "
                    "are not supported in this context."
                )
                raise MultiOutputExpressionError(msg)
            if not to_single_output:
                expansion_kind = (
                    arg_expansion if position == 0 else expansion_kind & arg_expansion
                )
        any_has_windows |= metadata.has_windows
        orderable_ops += metadata.n_orderable_ops
        any_preserves_length |= metadata.preserves_length
        all_elementwise &= metadata.is_elementwise
        all_scalar_like &= metadata.is_scalar_like
        all_literal &= metadata.is_literal
        filtration_count += int(metadata.is_filtration)

    if filtration_count > 1:
        msg = "Length-changing expressions can only be used in isolation, or followed by an aggregation"
        raise InvalidOperationError(msg)
    if any_preserves_length and filtration_count:
        msg = "Cannot combine length-changing expressions with length-preserving ones or aggregations"
        raise InvalidOperationError(msg)
    return ExprMetadata(
        expansion_kind,
        ExprKind.NARY,
        has_windows=any_has_windows,
        n_orderable_ops=orderable_ops,
        preserves_length=any_preserves_length,
        is_elementwise=all_elementwise,
        is_scalar_like=all_scalar_like,
        is_literal=all_literal,
    )
def check_expressions_preserve_length(*args: IntoExpr, function_name: str) -> None:
    """Raise unless every argument is known to preserve the input length.

    Expressions must carry `preserves_length` metadata. Strings and Series are
    accepted without further checks: this works lazily and can't evaluate
    lengths, so checks on Series input are left for later.

    Raises:
        InvalidOperationError: On the first argument that is neither a
            length-preserving expression, a string, nor a Series.
    """
    from narwhals.series import Series

    for arg in args:
        if is_expr(arg) and arg._metadata.preserves_length:
            continue
        if isinstance(arg, (str, Series)):
            continue
        msg = f"Expressions which aggregate or change length cannot be passed to '{function_name}'."
        raise InvalidOperationError(msg)
def all_exprs_are_scalar_like(*args: IntoExpr, **kwargs: IntoExpr) -> bool:
    """Return whether every positional and keyword argument is a scalar-like expression.

    Scalar-like means the expression's metadata marks it as an aggregation or
    a literal (`is_scalar_like`). This function never raises: any argument that
    is not a Narwhals expression (including Series, whose length can't be
    evaluated lazily here) makes the result `False`. With no arguments at all
    the result is `True`.
    """
    exprs = chain(args, kwargs.values())
    return all(is_expr(x) and x._metadata.is_scalar_like for x in exprs)
def apply_n_ary_operation(
    plx: CompliantNamespaceAny,
    function: Any,
    *comparands: IntoExpr | NonNestedLiteral | _1DArray,
    str_as_lit: bool,
) -> CompliantExprAny:
    """Apply `function` to the compliant forms of `comparands`.

    Each comparand is converted with `extract_compliant`. If at least one
    comparand is not scalar-like, every scalar-like compliant expression is
    broadcast (via its `broadcast(kind)` method) before being passed on.

    Returns:
        Whatever `function` returns when called with the converted comparands
        as positional arguments.
    """
    kinds = [
        ExprKind.from_into_expr(comparand, str_as_lit=str_as_lit)
        for comparand in comparands
    ]
    # Broadcasting is only needed when something in the mix is not scalar-like.
    needs_broadcast = not all(kind.is_scalar_like for kind in kinds)

    def _prepare(
        comparand: IntoExpr | NonNestedLiteral | _1DArray, kind: ExprKind
    ) -> CompliantExprAny | NonNestedLiteral:
        compliant = extract_compliant(plx, comparand, str_as_lit=str_as_lit)
        if needs_broadcast and is_compliant_expr(compliant) and is_scalar_like(kind):
            return compliant.broadcast(kind)
        return compliant

    prepared = [
        _prepare(comparand, kind) for comparand, kind in zip(comparands, kinds)
    ]
    return function(*prepared)