team-10/venv/Lib/site-packages/narwhals/_duckdb/expr.py

327 lines
11 KiB
Python
Raw Normal View History

2025-08-02 02:00:33 +02:00
from __future__ import annotations
import operator
from typing import TYPE_CHECKING, Any, Callable, Literal, cast
from duckdb import CoalesceOperator, StarExpression
from duckdb.typing import DuckDBPyType
from narwhals._duckdb.expr_dt import DuckDBExprDateTimeNamespace
from narwhals._duckdb.expr_list import DuckDBExprListNamespace
from narwhals._duckdb.expr_str import DuckDBExprStringNamespace
from narwhals._duckdb.expr_struct import DuckDBExprStructNamespace
from narwhals._duckdb.utils import (
DeferredTimeZone,
F,
col,
lit,
narwhals_to_native_dtype,
when,
window_expression,
)
from narwhals._expression_parsing import ExprKind, ExprMetadata
from narwhals._sql.expr import SQLExpr
from narwhals._utils import Implementation, Version
if TYPE_CHECKING:
from collections.abc import Sequence
from duckdb import Expression
from typing_extensions import Self
from narwhals._compliant import WindowInputs
from narwhals._compliant.typing import (
AliasNames,
EvalNames,
EvalSeries,
WindowFunction,
)
from narwhals._duckdb.dataframe import DuckDBLazyFrame
from narwhals._duckdb.namespace import DuckDBNamespace
from narwhals._utils import _LimitedContext
from narwhals.typing import (
FillNullStrategy,
IntoDType,
NonNestedLiteral,
RollingInterpolationMethod,
)
DuckDBWindowFunction = WindowFunction[DuckDBLazyFrame, Expression]
DuckDBWindowInputs = WindowInputs[Expression]
class DuckDBExpr(SQLExpr["DuckDBLazyFrame", "Expression"]):
_implementation = Implementation.DUCKDB
def __init__(
    self,
    call: EvalSeries[DuckDBLazyFrame, Expression],
    window_function: DuckDBWindowFunction | None = None,
    *,
    evaluate_output_names: EvalNames[DuckDBLazyFrame],
    alias_output_names: AliasNames | None,
    version: Version,
    implementation: Implementation = Implementation.DUCKDB,
) -> None:
    """Store the callables and naming rules that make up this expression.

    Args:
        call: Callable producing the native expressions for a frame.
        window_function: Optional callable producing the windowed variant.
        evaluate_output_names: Callable resolving the output column names.
        alias_output_names: Optional renaming rule for the output names.
        version: Narwhals API version this expression is bound to.
        implementation: Accepted for signature compatibility; this method
            does not read or store it.
    """
    # How the expression is evaluated (plain and windowed form).
    self._call = call
    self._window_function: DuckDBWindowFunction | None = window_function
    # How the outputs are named.
    self._evaluate_output_names = evaluate_output_names
    self._alias_output_names = alias_output_names
    # Bookkeeping; metadata starts unset.
    self._version = version
    self._metadata: ExprMetadata | None = None
def _count_star(self) -> Expression:
    """Build a native `count` call applied to a star expression."""
    star = StarExpression()
    return F("count", star)
def _window_expression(
    self,
    expr: Expression,
    partition_by: Sequence[str | Expression] = (),
    order_by: Sequence[str | Expression] = (),
    rows_start: int | None = None,
    rows_end: int | None = None,
    *,
    descending: Sequence[bool] | None = None,
    nulls_last: Sequence[bool] | None = None,
) -> Expression:
    """Wrap `expr` in a window by delegating to `window_expression`.

    Every argument is forwarded unchanged to the module-level helper.
    """
    windowed = window_expression(
        expr,
        partition_by,
        order_by,
        rows_start=rows_start,
        rows_end=rows_end,
        descending=descending,
        nulls_last=nulls_last,
    )
    return windowed
def __narwhals_expr__(self) -> None:
    """Marker method: performs no work and returns `None`."""
    return None
def __narwhals_namespace__(self) -> DuckDBNamespace:  # pragma: no cover
    """Return a `DuckDBNamespace` bound to this expression's version."""
    # Runtime import: at module level this name is only imported under
    # `TYPE_CHECKING`.
    from narwhals._duckdb.namespace import DuckDBNamespace

    namespace = DuckDBNamespace(version=self._version)
    return namespace
def broadcast(self, kind: Literal[ExprKind.AGGREGATION, ExprKind.LITERAL]) -> Self:
    """Make a literal or aggregate expression usable alongside full columns.

    Literals are returned unchanged. Anything else is re-expressed with
    `over([lit(1)], [])`, which requires backend version 1.3 or newer.

    Raises:
        NotImplementedError: If `kind` is not a literal and the backend
            version is below (1, 3).
    """
    if kind is not ExprKind.LITERAL:
        if self._backend_version < (1, 3):
            msg = "At least version 1.3 of DuckDB is required for binary operations between aggregates and columns."
            raise NotImplementedError(msg)
        # presumably: partition by a constant, no ordering — confirm in `over`
        return self.over([lit(1)], [])
    return self
@classmethod
def from_column_names(
    cls,
    evaluate_column_names: EvalNames[DuckDBLazyFrame],
    /,
    *,
    context: _LimitedContext,
) -> Self:
    """Create an expression selecting the columns named by a callable.

    Args:
        evaluate_column_names: Callable returning the column names for a
            given frame; also reused as the output-name resolver.
        context: Object supplying the `_version` to bind to.
    """

    def select_named(df: DuckDBLazyFrame) -> list[Expression]:
        names = evaluate_column_names(df)
        return [col(column_name) for column_name in names]

    return cls(
        select_named,
        evaluate_output_names=evaluate_column_names,
        alias_output_names=None,
        version=context._version,
    )
@classmethod
def from_column_indices(cls, *column_indices: int, context: _LimitedContext) -> Self:
    """Create an expression selecting columns by position.

    Args:
        *column_indices: Positions into the frame's `columns`.
        context: Object supplying the `_version` to bind to.
    """

    def select_positional(df: DuckDBLazyFrame) -> list[Expression]:
        # Read the frame's column list once, then index into it.
        available = df.columns
        return [col(available[position]) for position in column_indices]

    output_names = cls._eval_names_indices(column_indices)
    return cls(
        select_positional,
        evaluate_output_names=output_names,
        alias_output_names=None,
        version=context._version,
    )
@classmethod
def _alias_native(cls, expr: Expression, name: str) -> Expression:
    """Return `expr` renamed to `name` through its `alias` method."""
    renamed = expr.alias(name)
    return renamed
def __invert__(self) -> Self:
    """Apply `operator.invert` (`~`) element-wise to the native expression."""
    # The cast only narrows the static type of `operator.invert`.
    native_invert = cast("Callable[..., Expression]", operator.invert)
    return self._with_elementwise(native_invert)
def skew(self) -> Self:
    """Aggregate to the skewness of the non-null values.

    The result depends on the non-null count `n`:
    `n == 0` gives null, `n == 1` gives NaN, `n == 2` gives 0.0, and
    otherwise DuckDB's `skewness` scaled by `(n - 2) / sqrt(n * (n - 1))`.
    """

    def _skew(expr: Expression) -> Expression:
        n = F("count", expr)
        # NOTE(review): this factor looks like it converts DuckDB's
        # bias-corrected `skewness` into the uncorrected estimate — confirm
        # against the DuckDB aggregate documentation.
        rescaled = F("skewness", expr) * (n - lit(2)) / F("sqrt", n * (n - lit(1)))
        # Build the nested CASE from the innermost branch outwards.
        for_two_plus = when(n == lit(2), lit(0.0)).otherwise(rescaled)
        for_one_plus = when(n == lit(1), lit(float("nan"))).otherwise(for_two_plus)
        return when(n == lit(0), lit(None)).otherwise(for_one_plus)

    return self._with_callable(_skew)
def kurtosis(self) -> Self:
    """Aggregate using DuckDB's `kurtosis_pop` function."""

    def _kurtosis(expr: Expression) -> Expression:
        return F("kurtosis_pop", expr)

    return self._with_callable(_kurtosis)
def quantile(
    self, quantile: float, interpolation: RollingInterpolationMethod
) -> Self:
    """Aggregate to the given quantile using DuckDB's `quantile_cont`.

    Args:
        quantile: The quantile level, passed to DuckDB as a literal.
        interpolation: Must be `"linear"`.

    Raises:
        NotImplementedError: When the aggregation is evaluated with any
            interpolation other than `"linear"`.
    """

    def _quantile(expr: Expression) -> Expression:
        if interpolation != "linear":
            msg = "Only linear interpolation methods are supported for DuckDB quantile."
            raise NotImplementedError(msg)
        return F("quantile_cont", expr, lit(quantile))

    return self._with_callable(_quantile)
def n_unique(self) -> Self:
    """Aggregate to `array_unique(array_agg(expr))` plus a null indicator.

    The indicator is the `max` of a per-row flag that is 0 for non-null rows
    and 1 for null rows.
    """

    def _n_unique(expr: Expression) -> Expression:
        # https://stackoverflow.com/a/79338887/4451315
        unique_in_array = F("array_unique", F("array_agg", expr))
        null_flag = when(expr.isnotnull(), lit(0)).otherwise(lit(1))
        any_null = F("max", null_flag)
        return unique_in_array + any_null

    return self._with_callable(_n_unique)
def len(self) -> Self:
    """Aggregate to the row count via the shared `_count_star` helper.

    The incoming native expression is deliberately ignored: the result is a
    plain row count, not a count of that expression's values. Building it
    through `_count_star` keeps every row count in this class on one code
    path instead of a second hand-written `count` call.
    """

    def _row_count(_expr: Expression) -> Expression:
        return self._count_star()

    return self._with_callable(_row_count)
def std(self, ddof: int) -> Self:
    """Aggregate to the standard deviation with `ddof` delta degrees of freedom.

    `ddof` 0 and 1 use DuckDB's `stddev_pop` and `stddev_samp` directly. Any
    other value rescales `stddev_pop` by `sqrt(n) / sqrt(n - ddof)`, where
    `n` is the non-null count.
    """

    def _population(expr: Expression) -> Expression:
        return F("stddev_pop", expr)

    def _sample(expr: Expression) -> Expression:
        return F("stddev_samp", expr)

    def _rescaled(expr: Expression) -> Expression:
        count = F("count", expr)
        scaled_up = F("stddev_pop", expr) * F("sqrt", count)
        return scaled_up / F("sqrt", count - lit(ddof))

    if ddof == 0:
        return self._with_callable(_population)
    if ddof == 1:
        return self._with_callable(_sample)
    return self._with_callable(_rescaled)
def var(self, ddof: int) -> Self:
    """Aggregate to the variance with `ddof` delta degrees of freedom.

    `ddof` 0 and 1 use DuckDB's `var_pop` and `var_samp` directly. Any other
    value rescales `var_pop` by `n / (n - ddof)`, where `n` is the non-null
    count.
    """

    def _population(expr: Expression) -> Expression:
        return F("var_pop", expr)

    def _sample(expr: Expression) -> Expression:
        return F("var_samp", expr)

    def _rescaled(expr: Expression) -> Expression:
        count = F("count", expr)
        scaled_up = F("var_pop", expr) * count
        return scaled_up / (count - lit(ddof))

    if ddof == 0:
        return self._with_callable(_population)
    if ddof == 1:
        return self._with_callable(_sample)
    return self._with_callable(_rescaled)
def null_count(self) -> Self:
    """Aggregate to the sum of the per-row `isnull` flag cast to `int`."""

    def _null_count(expr: Expression) -> Expression:
        null_flag = expr.isnull().cast("int")
        return F("sum", null_flag)

    return self._with_callable(_null_count)
def is_nan(self) -> Self:
    """Apply DuckDB's `isnan` element-wise."""

    def _is_nan(expr: Expression) -> Expression:
        return F("isnan", expr)

    return self._with_elementwise(_is_nan)
def is_finite(self) -> Self:
    """Apply DuckDB's `isfinite` element-wise."""

    def _is_finite(expr: Expression) -> Expression:
        return F("isfinite", expr)

    return self._with_elementwise(_is_finite)
def is_in(self, other: Sequence[Any]) -> Self:
    """Test element-wise whether each value occurs in `other`.

    `other` is turned into a literal and used as the first argument of
    DuckDB's `contains`, with the expression as the value searched for.
    """

    def _is_in(expr: Expression) -> Expression:
        haystack = lit(other)
        return F("contains", haystack, expr)

    return self._with_elementwise(_is_in)
def fill_null(
    self,
    value: Self | NonNestedLiteral,
    strategy: FillNullStrategy | None,
    limit: int | None,
) -> Self:
    """Replace nulls with a constant or by a windowed fill strategy.

    Args:
        value: Replacement used when `strategy` is `None`; it is not read
            when a strategy is given.
        strategy: `"forward"` takes `last_value` over the rows from `-limit`
            (unbounded when `limit` is `None`) to the current row; any other
            non-`None` strategy takes `first_value` over the rows from the
            current row to `limit`. Both ignore nulls.
        limit: Window extent for the strategy path; `None` leaves that side
            of the window unbounded.

    Raises:
        NotImplementedError: If a strategy is given and the backend version
            is below (1, 3).
    """
    if strategy is None:

        def _coalesce(expr: Expression, value: Any) -> Expression:
            return CoalesceOperator(expr, value)

        return self._with_elementwise(_coalesce, value=value)

    if self._backend_version < (1, 3):  # pragma: no cover
        msg = f"`fill_null` with `strategy={strategy}` is only available in 'duckdb>=1.3.0'."
        raise NotImplementedError(msg)

    def _fill_with_strategy(
        df: DuckDBLazyFrame, inputs: DuckDBWindowInputs
    ) -> Sequence[Expression]:
        if strategy == "forward":
            fill_func = "last_value"
            rows_start = -limit if limit is not None else None
            rows_end = 0
        else:
            fill_func = "first_value"
            rows_start = 0
            rows_end = limit
        filled = []
        for native in self(df):
            filled.append(
                window_expression(
                    F(fill_func, native),
                    inputs.partition_by,
                    inputs.order_by,
                    rows_start=rows_start,
                    rows_end=rows_end,
                    ignore_nulls=True,
                )
            )
        return filled

    return self._with_window_function(_fill_with_strategy)
def cast(self, dtype: IntoDType) -> Self:
    """Cast every output of this expression to `dtype`.

    The Narwhals dtype is translated to a native DuckDB type each time the
    expression is evaluated, using a `DeferredTimeZone` built from the
    frame's native object. Both the plain and the windowed form are cast.
    """

    def _resolve_native_dtype(df: DuckDBLazyFrame) -> Any:
        time_zone = DeferredTimeZone(df.native)
        return narwhals_to_native_dtype(dtype, self._version, time_zone)

    def _cast_plain(df: DuckDBLazyFrame) -> list[Expression]:
        target = _resolve_native_dtype(df)
        return [native.cast(DuckDBPyType(target)) for native in self(df)]

    def _cast_windowed(
        df: DuckDBLazyFrame, inputs: DuckDBWindowInputs
    ) -> list[Expression]:
        target = _resolve_native_dtype(df)
        return [
            native.cast(DuckDBPyType(target))
            for native in self.window_function(df, inputs)
        ]

    return self.__class__(
        _cast_plain,
        _cast_windowed,
        evaluate_output_names=self._evaluate_output_names,
        alias_output_names=self._alias_output_names,
        version=self._version,
    )
def log(self, base: float) -> Self:
    """Compute the element-wise logarithm in the given `base`.

    Negative inputs give NaN and zero gives negative infinity; otherwise the
    result is `log(expr) / log(base)` using DuckDB's `log` function.
    """

    def _log(expr: Expression) -> Expression:
        numerator = F("log", expr)
        # Handle the two special inputs first, then fall through to the
        # change-of-base quotient.
        branches = when(expr < lit(0), lit(float("nan")))
        branches = branches.when(expr == lit(0), lit(float("-inf")))
        return branches.otherwise(numerator / F("log", lit(base)))

    return self._with_elementwise(_log)
def sqrt(self) -> Self:
    """Compute the element-wise square root; negative inputs give NaN."""

    def _sqrt(expr: Expression) -> Expression:
        is_negative = expr < lit(0)
        not_a_number = lit(float("nan"))
        return when(is_negative, not_a_number).otherwise(F("sqrt", expr))

    return self._with_elementwise(_sqrt)
@property
def str(self) -> DuckDBExprStringNamespace:
    """Accessor returning a `DuckDBExprStringNamespace` wrapping this expression."""
    namespace = DuckDBExprStringNamespace(self)
    return namespace
@property
def dt(self) -> DuckDBExprDateTimeNamespace:
    """Accessor returning a `DuckDBExprDateTimeNamespace` wrapping this expression."""
    namespace = DuckDBExprDateTimeNamespace(self)
    return namespace
@property
def list(self) -> DuckDBExprListNamespace:
    """Accessor returning a `DuckDBExprListNamespace` wrapping this expression."""
    namespace = DuckDBExprListNamespace(self)
    return namespace
@property
def struct(self) -> DuckDBExprStructNamespace:
    """Accessor returning a `DuckDBExprStructNamespace` wrapping this expression."""
    namespace = DuckDBExprStructNamespace(self)
    return namespace