from __future__ import annotations import operator from typing import TYPE_CHECKING, Any, Callable, Literal, cast from duckdb import CoalesceOperator, StarExpression from duckdb.typing import DuckDBPyType from narwhals._duckdb.expr_dt import DuckDBExprDateTimeNamespace from narwhals._duckdb.expr_list import DuckDBExprListNamespace from narwhals._duckdb.expr_str import DuckDBExprStringNamespace from narwhals._duckdb.expr_struct import DuckDBExprStructNamespace from narwhals._duckdb.utils import ( DeferredTimeZone, F, col, lit, narwhals_to_native_dtype, when, window_expression, ) from narwhals._expression_parsing import ExprKind, ExprMetadata from narwhals._sql.expr import SQLExpr from narwhals._utils import Implementation, Version if TYPE_CHECKING: from collections.abc import Sequence from duckdb import Expression from typing_extensions import Self from narwhals._compliant import WindowInputs from narwhals._compliant.typing import ( AliasNames, EvalNames, EvalSeries, WindowFunction, ) from narwhals._duckdb.dataframe import DuckDBLazyFrame from narwhals._duckdb.namespace import DuckDBNamespace from narwhals._utils import _LimitedContext from narwhals.typing import ( FillNullStrategy, IntoDType, NonNestedLiteral, RollingInterpolationMethod, ) DuckDBWindowFunction = WindowFunction[DuckDBLazyFrame, Expression] DuckDBWindowInputs = WindowInputs[Expression] class DuckDBExpr(SQLExpr["DuckDBLazyFrame", "Expression"]): _implementation = Implementation.DUCKDB def __init__( self, call: EvalSeries[DuckDBLazyFrame, Expression], window_function: DuckDBWindowFunction | None = None, *, evaluate_output_names: EvalNames[DuckDBLazyFrame], alias_output_names: AliasNames | None, version: Version, implementation: Implementation = Implementation.DUCKDB, ) -> None: self._call = call self._evaluate_output_names = evaluate_output_names self._alias_output_names = alias_output_names self._version = version self._metadata: ExprMetadata | None = None self._window_function: DuckDBWindowFunction | None = window_function def _count_star(self) -> Expression: return F("count", StarExpression()) def _window_expression( self, expr: Expression, partition_by: Sequence[str | Expression] = (), order_by: Sequence[str | Expression] = (), rows_start: int | None = None, rows_end: int | None = None, *, descending: Sequence[bool] | None = None, nulls_last: Sequence[bool] | None = None, ) -> Expression: return window_expression( expr, partition_by, order_by, rows_start, rows_end, descending=descending, nulls_last=nulls_last, ) def __narwhals_expr__(self) -> None: ... def __narwhals_namespace__(self) -> DuckDBNamespace: # pragma: no cover from narwhals._duckdb.namespace import DuckDBNamespace return DuckDBNamespace(version=self._version) def broadcast(self, kind: Literal[ExprKind.AGGREGATION, ExprKind.LITERAL]) -> Self: if kind is ExprKind.LITERAL: return self if self._backend_version < (1, 3): msg = "At least version 1.3 of DuckDB is required for binary operations between aggregates and columns." raise NotImplementedError(msg) return self.over([lit(1)], []) @classmethod def from_column_names( cls, evaluate_column_names: EvalNames[DuckDBLazyFrame], /, *, context: _LimitedContext, ) -> Self: def func(df: DuckDBLazyFrame) -> list[Expression]: return [col(name) for name in evaluate_column_names(df)] return cls( func, evaluate_output_names=evaluate_column_names, alias_output_names=None, version=context._version, ) @classmethod def from_column_indices(cls, *column_indices: int, context: _LimitedContext) -> Self: def func(df: DuckDBLazyFrame) -> list[Expression]: columns = df.columns return [col(columns[i]) for i in column_indices] return cls( func, evaluate_output_names=cls._eval_names_indices(column_indices), alias_output_names=None, version=context._version, ) @classmethod def _alias_native(cls, expr: Expression, name: str) -> Expression: return expr.alias(name) def __invert__(self) -> Self: invert = cast("Callable[..., Expression]", operator.invert) return self._with_elementwise(invert) def skew(self) -> Self: def func(expr: Expression) -> Expression: count = F("count", expr) # Adjust population skewness by correction factor to get sample skewness sample_skewness = ( F("skewness", expr) * (count - lit(2)) / F("sqrt", count * (count - lit(1))) ) return when(count == lit(0), lit(None)).otherwise( when(count == lit(1), lit(float("nan"))).otherwise( when(count == lit(2), lit(0.0)).otherwise(sample_skewness) ) ) return self._with_callable(func) def kurtosis(self) -> Self: return self._with_callable(lambda expr: F("kurtosis_pop", expr)) def quantile( self, quantile: float, interpolation: RollingInterpolationMethod ) -> Self: def func(expr: Expression) -> Expression: if interpolation == "linear": return F("quantile_cont", expr, lit(quantile)) msg = "Only linear interpolation methods are supported for DuckDB quantile." raise NotImplementedError(msg) return self._with_callable(func) def n_unique(self) -> Self: def func(expr: Expression) -> Expression: # https://stackoverflow.com/a/79338887/4451315 return F("array_unique", F("array_agg", expr)) + F( "max", when(expr.isnotnull(), lit(0)).otherwise(lit(1)) ) return self._with_callable(func) def len(self) -> Self: return self._with_callable(lambda _expr: F("count")) def std(self, ddof: int) -> Self: if ddof == 0: return self._with_callable(lambda expr: F("stddev_pop", expr)) if ddof == 1: return self._with_callable(lambda expr: F("stddev_samp", expr)) def _std(expr: Expression) -> Expression: n_samples = F("count", expr) return ( F("stddev_pop", expr) * F("sqrt", n_samples) / (F("sqrt", (n_samples - lit(ddof)))) ) return self._with_callable(_std) def var(self, ddof: int) -> Self: if ddof == 0: return self._with_callable(lambda expr: F("var_pop", expr)) if ddof == 1: return self._with_callable(lambda expr: F("var_samp", expr)) def _var(expr: Expression) -> Expression: n_samples = F("count", expr) return F("var_pop", expr) * n_samples / (n_samples - lit(ddof)) return self._with_callable(_var) def null_count(self) -> Self: return self._with_callable(lambda expr: F("sum", expr.isnull().cast("int"))) def is_nan(self) -> Self: return self._with_elementwise(lambda expr: F("isnan", expr)) def is_finite(self) -> Self: return self._with_elementwise(lambda expr: F("isfinite", expr)) def is_in(self, other: Sequence[Any]) -> Self: return self._with_elementwise(lambda expr: F("contains", lit(other), expr)) def fill_null( self, value: Self | NonNestedLiteral, strategy: FillNullStrategy | None, limit: int | None, ) -> Self: if strategy is not None: if self._backend_version < (1, 3): # pragma: no cover msg = f"`fill_null` with `strategy={strategy}` is only available in 'duckdb>=1.3.0'." raise NotImplementedError(msg) def _fill_with_strategy( df: DuckDBLazyFrame, inputs: DuckDBWindowInputs ) -> Sequence[Expression]: fill_func = "last_value" if strategy == "forward" else "first_value" rows_start, rows_end = ( (-limit if limit is not None else None, 0) if strategy == "forward" else (0, limit) ) return [ window_expression( F(fill_func, expr), inputs.partition_by, inputs.order_by, rows_start=rows_start, rows_end=rows_end, ignore_nulls=True, ) for expr in self(df) ] return self._with_window_function(_fill_with_strategy) def _fill_constant(expr: Expression, value: Any) -> Expression: return CoalesceOperator(expr, value) return self._with_elementwise(_fill_constant, value=value) def cast(self, dtype: IntoDType) -> Self: def func(df: DuckDBLazyFrame) -> list[Expression]: tz = DeferredTimeZone(df.native) native_dtype = narwhals_to_native_dtype(dtype, self._version, tz) return [expr.cast(DuckDBPyType(native_dtype)) for expr in self(df)] def window_f(df: DuckDBLazyFrame, inputs: DuckDBWindowInputs) -> list[Expression]: tz = DeferredTimeZone(df.native) native_dtype = narwhals_to_native_dtype(dtype, self._version, tz) return [ expr.cast(DuckDBPyType(native_dtype)) for expr in self.window_function(df, inputs) ] return self.__class__( func, window_f, evaluate_output_names=self._evaluate_output_names, alias_output_names=self._alias_output_names, version=self._version, ) def log(self, base: float) -> Self: def _log(expr: Expression) -> Expression: log = F("log", expr) return ( when(expr < lit(0), lit(float("nan"))) .when(expr == lit(0), lit(float("-inf"))) .otherwise(log / F("log", lit(base))) ) return self._with_elementwise(_log) def sqrt(self) -> Self: def _sqrt(expr: Expression) -> Expression: return when(expr < lit(0), lit(float("nan"))).otherwise(F("sqrt", expr)) return self._with_elementwise(_sqrt) @property def str(self) -> DuckDBExprStringNamespace: return DuckDBExprStringNamespace(self) @property def dt(self) -> DuckDBExprDateTimeNamespace: return DuckDBExprDateTimeNamespace(self) @property def list(self) -> DuckDBExprListNamespace: return DuckDBExprListNamespace(self) @property def struct(self) -> DuckDBExprStructNamespace: return DuckDBExprStructNamespace(self)