from __future__ import annotations import warnings from typing import TYPE_CHECKING, Any, Callable, Literal from narwhals._compliant import DepthTrackingExpr, LazyExpr from narwhals._dask.expr_dt import DaskExprDateTimeNamespace from narwhals._dask.expr_str import DaskExprStringNamespace from narwhals._dask.utils import ( add_row_index, maybe_evaluate_expr, narwhals_to_native_dtype, ) from narwhals._expression_parsing import ExprKind, evaluate_output_names_and_aliases from narwhals._pandas_like.utils import native_to_narwhals_dtype from narwhals._utils import ( Implementation, generate_temporary_column_name, not_implemented, ) from narwhals.exceptions import InvalidOperationError if TYPE_CHECKING: from collections.abc import Sequence import dask.dataframe.dask_expr as dx from typing_extensions import Self from narwhals._compliant.typing import AliasNames, EvalNames, EvalSeries, ScalarKwargs from narwhals._dask.dataframe import DaskLazyFrame from narwhals._dask.namespace import DaskNamespace from narwhals._expression_parsing import ExprKind, ExprMetadata from narwhals._utils import Version, _LimitedContext from narwhals.typing import ( FillNullStrategy, IntoDType, NonNestedLiteral, NumericLiteral, RollingInterpolationMethod, TemporalLiteral, ) class DaskExpr( LazyExpr["DaskLazyFrame", "dx.Series"], DepthTrackingExpr["DaskLazyFrame", "dx.Series"], ): _implementation: Implementation = Implementation.DASK def __init__( self, call: EvalSeries[DaskLazyFrame, dx.Series], *, depth: int, function_name: str, evaluate_output_names: EvalNames[DaskLazyFrame], alias_output_names: AliasNames | None, version: Version, scalar_kwargs: ScalarKwargs | None = None, ) -> None: self._call = call self._depth = depth self._function_name = function_name self._evaluate_output_names = evaluate_output_names self._alias_output_names = alias_output_names self._version = version self._scalar_kwargs = scalar_kwargs or {} self._metadata: ExprMetadata | None = None def __call__(self, df: DaskLazyFrame) -> Sequence[dx.Series]: return self._call(df) def __narwhals_expr__(self) -> None: ... def __narwhals_namespace__(self) -> DaskNamespace: # pragma: no cover from narwhals._dask.namespace import DaskNamespace return DaskNamespace(version=self._version) def broadcast(self, kind: Literal[ExprKind.AGGREGATION, ExprKind.LITERAL]) -> Self: def func(df: DaskLazyFrame) -> list[dx.Series]: # result.loc[0][0] is a workaround for dask~<=2024.10.0/dask_expr~<=1.1.16 # that raised a KeyError for result[0] during collection. return [result.loc[0][0] for result in self(df)] return self.__class__( func, depth=self._depth, function_name=self._function_name, evaluate_output_names=self._evaluate_output_names, alias_output_names=self._alias_output_names, version=self._version, scalar_kwargs=self._scalar_kwargs, ) @classmethod def from_column_names( cls: type[Self], evaluate_column_names: EvalNames[DaskLazyFrame], /, *, context: _LimitedContext, function_name: str = "", ) -> Self: def func(df: DaskLazyFrame) -> list[dx.Series]: try: return [ df._native_frame[column_name] for column_name in evaluate_column_names(df) ] except KeyError as e: if error := df._check_columns_exist(evaluate_column_names(df)): raise error from e raise return cls( func, depth=0, function_name=function_name, evaluate_output_names=evaluate_column_names, alias_output_names=None, version=context._version, ) @classmethod def from_column_indices(cls, *column_indices: int, context: _LimitedContext) -> Self: def func(df: DaskLazyFrame) -> list[dx.Series]: return [df.native.iloc[:, i] for i in column_indices] return cls( func, depth=0, function_name="nth", evaluate_output_names=cls._eval_names_indices(column_indices), alias_output_names=None, version=context._version, ) def _with_callable( self, # First argument to `call` should be `dx.Series` call: Callable[..., dx.Series], /, expr_name: str = "", scalar_kwargs: ScalarKwargs | None = None, **expressifiable_args: Self | Any, ) -> Self: def func(df: DaskLazyFrame) -> list[dx.Series]: native_results: list[dx.Series] = [] native_series_list = self._call(df) other_native_series = { key: maybe_evaluate_expr(df, value) for key, value in expressifiable_args.items() } for native_series in native_series_list: result_native = call(native_series, **other_native_series) native_results.append(result_native) return native_results return self.__class__( func, depth=self._depth + 1, function_name=f"{self._function_name}->{expr_name}", evaluate_output_names=self._evaluate_output_names, alias_output_names=self._alias_output_names, version=self._version, scalar_kwargs=scalar_kwargs, ) def _with_alias_output_names(self, func: AliasNames | None, /) -> Self: current_alias_output_names = self._alias_output_names alias_output_names = ( None if func is None else func if current_alias_output_names is None else lambda output_names: func(current_alias_output_names(output_names)) ) return type(self)( call=self._call, depth=self._depth, function_name=self._function_name, evaluate_output_names=self._evaluate_output_names, alias_output_names=alias_output_names, version=self._version, scalar_kwargs=self._scalar_kwargs, ) def __add__(self, other: Any) -> Self: return self._with_callable( lambda expr, other: expr.__add__(other), "__add__", other=other ) def __sub__(self, other: Any) -> Self: return self._with_callable( lambda expr, other: expr.__sub__(other), "__sub__", other=other ) def __rsub__(self, other: Any) -> Self: return self._with_callable( lambda expr, other: other - expr, "__rsub__", other=other ).alias("literal") def __mul__(self, other: Any) -> Self: return self._with_callable( lambda expr, other: expr.__mul__(other), "__mul__", other=other ) def __truediv__(self, other: Any) -> Self: return self._with_callable( lambda expr, other: expr.__truediv__(other), "__truediv__", other=other ) def __rtruediv__(self, other: Any) -> Self: return self._with_callable( lambda expr, other: other / expr, "__rtruediv__", other=other ).alias("literal") def __floordiv__(self, other: Any) -> Self: return self._with_callable( lambda expr, other: expr.__floordiv__(other), "__floordiv__", other=other ) def __rfloordiv__(self, other: Any) -> Self: return self._with_callable( lambda expr, other: other // expr, "__rfloordiv__", other=other ).alias("literal") def __pow__(self, other: Any) -> Self: return self._with_callable( lambda expr, other: expr.__pow__(other), "__pow__", other=other ) def __rpow__(self, other: Any) -> Self: return self._with_callable( lambda expr, other: other**expr, "__rpow__", other=other ).alias("literal") def __mod__(self, other: Any) -> Self: return self._with_callable( lambda expr, other: expr.__mod__(other), "__mod__", other=other ) def __rmod__(self, other: Any) -> Self: return self._with_callable( lambda expr, other: other % expr, "__rmod__", other=other ).alias("literal") def __eq__(self, other: DaskExpr) -> Self: # type: ignore[override] return self._with_callable( lambda expr, other: expr.__eq__(other), "__eq__", other=other ) def __ne__(self, other: DaskExpr) -> Self: # type: ignore[override] return self._with_callable( lambda expr, other: expr.__ne__(other), "__ne__", other=other ) def __ge__(self, other: DaskExpr | Any) -> Self: return self._with_callable( lambda expr, other: expr.__ge__(other), "__ge__", other=other ) def __gt__(self, other: DaskExpr) -> Self: return self._with_callable( lambda expr, other: expr.__gt__(other), "__gt__", other=other ) def __le__(self, other: DaskExpr) -> Self: return self._with_callable( lambda expr, other: expr.__le__(other), "__le__", other=other ) def __lt__(self, other: DaskExpr) -> Self: return self._with_callable( lambda expr, other: expr.__lt__(other), "__lt__", other=other ) def __and__(self, other: DaskExpr | Any) -> Self: return self._with_callable( lambda expr, other: expr.__and__(other), "__and__", other=other ) def __or__(self, other: DaskExpr) -> Self: return self._with_callable( lambda expr, other: expr.__or__(other), "__or__", other=other ) def __invert__(self) -> Self: return self._with_callable(lambda expr: expr.__invert__(), "__invert__") def mean(self) -> Self: return self._with_callable(lambda expr: expr.mean().to_series(), "mean") def median(self) -> Self: from narwhals.exceptions import InvalidOperationError def func(s: dx.Series) -> dx.Series: dtype = native_to_narwhals_dtype(s.dtype, self._version, Implementation.DASK) if not dtype.is_numeric(): msg = "`median` operation not supported for non-numeric input type." raise InvalidOperationError(msg) return s.median_approximate().to_series() return self._with_callable(func, "median") def min(self) -> Self: return self._with_callable(lambda expr: expr.min().to_series(), "min") def max(self) -> Self: return self._with_callable(lambda expr: expr.max().to_series(), "max") def std(self, ddof: int) -> Self: return self._with_callable( lambda expr: expr.std(ddof=ddof).to_series(), "std", scalar_kwargs={"ddof": ddof}, ) def var(self, ddof: int) -> Self: return self._with_callable( lambda expr: expr.var(ddof=ddof).to_series(), "var", scalar_kwargs={"ddof": ddof}, ) def skew(self) -> Self: return self._with_callable(lambda expr: expr.skew().to_series(), "skew") def kurtosis(self) -> Self: return self._with_callable(lambda expr: expr.kurtosis().to_series(), "kurtosis") def shift(self, n: int) -> Self: return self._with_callable(lambda expr: expr.shift(n), "shift") def cum_sum(self, *, reverse: bool) -> Self: if reverse: # pragma: no cover # https://github.com/dask/dask/issues/11802 msg = "`cum_sum(reverse=True)` is not supported with Dask backend" raise NotImplementedError(msg) return self._with_callable(lambda expr: expr.cumsum(), "cum_sum") def cum_count(self, *, reverse: bool) -> Self: if reverse: # pragma: no cover msg = "`cum_count(reverse=True)` is not supported with Dask backend" raise NotImplementedError(msg) return self._with_callable( lambda expr: (~expr.isna()).astype(int).cumsum(), "cum_count" ) def cum_min(self, *, reverse: bool) -> Self: if reverse: # pragma: no cover msg = "`cum_min(reverse=True)` is not supported with Dask backend" raise NotImplementedError(msg) return self._with_callable(lambda expr: expr.cummin(), "cum_min") def cum_max(self, *, reverse: bool) -> Self: if reverse: # pragma: no cover msg = "`cum_max(reverse=True)` is not supported with Dask backend" raise NotImplementedError(msg) return self._with_callable(lambda expr: expr.cummax(), "cum_max") def cum_prod(self, *, reverse: bool) -> Self: if reverse: # pragma: no cover msg = "`cum_prod(reverse=True)` is not supported with Dask backend" raise NotImplementedError(msg) return self._with_callable(lambda expr: expr.cumprod(), "cum_prod") def rolling_sum(self, window_size: int, *, min_samples: int, center: bool) -> Self: return self._with_callable( lambda expr: expr.rolling( window=window_size, min_periods=min_samples, center=center ).sum(), "rolling_sum", ) def rolling_mean(self, window_size: int, *, min_samples: int, center: bool) -> Self: return self._with_callable( lambda expr: expr.rolling( window=window_size, min_periods=min_samples, center=center ).mean(), "rolling_mean", ) def rolling_var( self, window_size: int, *, min_samples: int, center: bool, ddof: int ) -> Self: if ddof == 1: return self._with_callable( lambda expr: expr.rolling( window=window_size, min_periods=min_samples, center=center ).var(), "rolling_var", ) else: msg = "Dask backend only supports `ddof=1` for `rolling_var`" raise NotImplementedError(msg) def rolling_std( self, window_size: int, *, min_samples: int, center: bool, ddof: int ) -> Self: if ddof == 1: return self._with_callable( lambda expr: expr.rolling( window=window_size, min_periods=min_samples, center=center ).std(), "rolling_std", ) else: msg = "Dask backend only supports `ddof=1` for `rolling_std`" raise NotImplementedError(msg) def sum(self) -> Self: return self._with_callable(lambda expr: expr.sum().to_series(), "sum") def count(self) -> Self: return self._with_callable(lambda expr: expr.count().to_series(), "count") def round(self, decimals: int) -> Self: return self._with_callable(lambda expr: expr.round(decimals), "round") def unique(self) -> Self: return self._with_callable(lambda expr: expr.unique(), "unique") def drop_nulls(self) -> Self: return self._with_callable(lambda expr: expr.dropna(), "drop_nulls") def abs(self) -> Self: return self._with_callable(lambda expr: expr.abs(), "abs") def all(self) -> Self: return self._with_callable( lambda expr: expr.all( axis=None, skipna=True, split_every=False, out=None ).to_series(), "all", ) def any(self) -> Self: return self._with_callable( lambda expr: expr.any(axis=0, skipna=True, split_every=False).to_series(), "any", ) def fill_null( self, value: Self | NonNestedLiteral, strategy: FillNullStrategy | None, limit: int | None, ) -> Self: def func(expr: dx.Series) -> dx.Series: if value is not None: res_ser = expr.fillna(value) else: res_ser = ( expr.ffill(limit=limit) if strategy == "forward" else expr.bfill(limit=limit) ) return res_ser return self._with_callable(func, "fillna") def clip( self, lower_bound: Self | NumericLiteral | TemporalLiteral | None, upper_bound: Self | NumericLiteral | TemporalLiteral | None, ) -> Self: return self._with_callable( lambda expr, lower_bound, upper_bound: expr.clip( lower=lower_bound, upper=upper_bound ), "clip", lower_bound=lower_bound, upper_bound=upper_bound, ) def diff(self) -> Self: return self._with_callable(lambda expr: expr.diff(), "diff") def n_unique(self) -> Self: return self._with_callable( lambda expr: expr.nunique(dropna=False).to_series(), "n_unique" ) def is_null(self) -> Self: return self._with_callable(lambda expr: expr.isna(), "is_null") def is_nan(self) -> Self: def func(expr: dx.Series) -> dx.Series: dtype = native_to_narwhals_dtype( expr.dtype, self._version, self._implementation ) if dtype.is_numeric(): return expr != expr # pyright: ignore[reportReturnType] # noqa: PLR0124 msg = f"`.is_nan` only supported for numeric dtypes and not {dtype}, did you mean `.is_null`?" raise InvalidOperationError(msg) return self._with_callable(func, "is_null") def len(self) -> Self: return self._with_callable(lambda expr: expr.size.to_series(), "len") def quantile( self, quantile: float, interpolation: RollingInterpolationMethod ) -> Self: if interpolation == "linear": def func(expr: dx.Series, quantile: float) -> dx.Series: if expr.npartitions > 1: msg = "`Expr.quantile` is not supported for Dask backend with multiple partitions." raise NotImplementedError(msg) return expr.quantile( q=quantile, method="dask" ).to_series() # pragma: no cover return self._with_callable(func, "quantile", quantile=quantile) else: msg = "`higher`, `lower`, `midpoint`, `nearest` - interpolation methods are not supported by Dask. Please use `linear` instead." raise NotImplementedError(msg) def is_first_distinct(self) -> Self: def func(expr: dx.Series) -> dx.Series: _name = expr.name col_token = generate_temporary_column_name(n_bytes=8, columns=[_name]) frame = add_row_index(expr.to_frame(), col_token) first_distinct_index = frame.groupby(_name).agg({col_token: "min"})[col_token] return frame[col_token].isin(first_distinct_index) return self._with_callable(func, "is_first_distinct") def is_last_distinct(self) -> Self: def func(expr: dx.Series) -> dx.Series: _name = expr.name col_token = generate_temporary_column_name(n_bytes=8, columns=[_name]) frame = add_row_index(expr.to_frame(), col_token) last_distinct_index = frame.groupby(_name).agg({col_token: "max"})[col_token] return frame[col_token].isin(last_distinct_index) return self._with_callable(func, "is_last_distinct") def is_unique(self) -> Self: def func(expr: dx.Series) -> dx.Series: _name = expr.name return ( expr.to_frame() .groupby(_name, dropna=False) .transform("size", meta=(_name, int)) == 1 ) return self._with_callable(func, "is_unique") def is_in(self, other: Any) -> Self: return self._with_callable(lambda expr: expr.isin(other), "is_in") def null_count(self) -> Self: return self._with_callable( lambda expr: expr.isna().sum().to_series(), "null_count" ) def over(self, partition_by: Sequence[str], order_by: Sequence[str]) -> Self: # pandas is a required dependency of dask so it's safe to import this from narwhals._pandas_like.group_by import PandasLikeGroupBy if not partition_by: assert order_by # noqa: S101 # This is something like `nw.col('a').cum_sum().order_by(key)` # which we can always easily support, as it doesn't require grouping. def func(df: DaskLazyFrame) -> Sequence[dx.Series]: return self(df.sort(*order_by, descending=False, nulls_last=False)) elif not self._is_elementary(): # pragma: no cover msg = ( "Only elementary expressions are supported for `.over` in dask.\n\n" "Please see: " "https://narwhals-dev.github.io/narwhals/concepts/improve_group_by_operation/" ) raise NotImplementedError(msg) elif order_by: # Wrong results https://github.com/dask/dask/issues/11806. msg = "`over` with `order_by` is not yet supported in Dask." raise NotImplementedError(msg) else: function_name = PandasLikeGroupBy._leaf_name(self) try: dask_function_name = PandasLikeGroupBy._REMAP_AGGS[function_name] except KeyError: # window functions are unsupported: https://github.com/dask/dask/issues/11806 msg = ( f"Unsupported function: {function_name} in `over` context.\n\n" f"Supported functions are {', '.join(PandasLikeGroupBy._REMAP_AGGS)}\n" ) raise NotImplementedError(msg) from None def func(df: DaskLazyFrame) -> Sequence[dx.Series]: output_names, aliases = evaluate_output_names_and_aliases(self, df, []) with warnings.catch_warnings(): # https://github.com/dask/dask/issues/11804 warnings.filterwarnings( "ignore", message=".*`meta` is not specified", category=UserWarning, ) grouped = df.native.groupby(partition_by) if dask_function_name == "size": if len(output_names) != 1: # pragma: no cover msg = "Safety check failed, please report a bug." raise AssertionError(msg) res_native = grouped.transform( dask_function_name, **self._scalar_kwargs ).to_frame(output_names[0]) else: res_native = grouped[list(output_names)].transform( dask_function_name, **self._scalar_kwargs ) result_frame = df._with_native( res_native.rename(columns=dict(zip(output_names, aliases))) ).native return [result_frame[name] for name in aliases] return self.__class__( func, depth=self._depth + 1, function_name=self._function_name + "->over", evaluate_output_names=self._evaluate_output_names, alias_output_names=self._alias_output_names, version=self._version, ) def cast(self, dtype: IntoDType) -> Self: def func(expr: dx.Series) -> dx.Series: native_dtype = narwhals_to_native_dtype(dtype, self._version) return expr.astype(native_dtype) return self._with_callable(func, "cast") def is_finite(self) -> Self: import dask.array as da return self._with_callable(da.isfinite, "is_finite") def log(self, base: float) -> Self: import dask.array as da def _log(expr: dx.Series) -> dx.Series: return da.log(expr) / da.log(base) return self._with_callable(_log, "log") def exp(self) -> Self: import dask.array as da return self._with_callable(da.exp, "exp") def sqrt(self) -> Self: import dask.array as da return self._with_callable(da.sqrt, "sqrt") @property def str(self) -> DaskExprStringNamespace: return DaskExprStringNamespace(self) @property def dt(self) -> DaskExprDateTimeNamespace: return DaskExprDateTimeNamespace(self) arg_max: not_implemented = not_implemented() arg_min: not_implemented = not_implemented() arg_true: not_implemented = not_implemented() ewm_mean: not_implemented = not_implemented() gather_every: not_implemented = not_implemented() head: not_implemented = not_implemented() map_batches: not_implemented = not_implemented() mode: not_implemented = not_implemented() sample: not_implemented = not_implemented() rank: not_implemented = not_implemented() replace_strict: not_implemented = not_implemented() sort: not_implemented = not_implemented() tail: not_implemented = not_implemented() # namespaces list: not_implemented = not_implemented() # type: ignore[assignment] cat: not_implemented = not_implemented() # type: ignore[assignment] struct: not_implemented = not_implemented() # type: ignore[assignment]