# team-10/venv/Lib/site-packages/narwhals/_dask/expr.py
from __future__ import annotations
import warnings
from typing import TYPE_CHECKING, Any, Callable, Literal
from narwhals._compliant import DepthTrackingExpr, LazyExpr
from narwhals._dask.expr_dt import DaskExprDateTimeNamespace
from narwhals._dask.expr_str import DaskExprStringNamespace
from narwhals._dask.utils import (
add_row_index,
maybe_evaluate_expr,
narwhals_to_native_dtype,
)
from narwhals._expression_parsing import ExprKind, evaluate_output_names_and_aliases
from narwhals._pandas_like.utils import native_to_narwhals_dtype
from narwhals._utils import (
Implementation,
generate_temporary_column_name,
not_implemented,
)
from narwhals.exceptions import InvalidOperationError
if TYPE_CHECKING:
from collections.abc import Sequence
import dask.dataframe.dask_expr as dx
from typing_extensions import Self
from narwhals._compliant.typing import AliasNames, EvalNames, EvalSeries, ScalarKwargs
from narwhals._dask.dataframe import DaskLazyFrame
from narwhals._dask.namespace import DaskNamespace
from narwhals._expression_parsing import ExprMetadata
from narwhals._utils import Version, _LimitedContext
from narwhals.typing import (
FillNullStrategy,
IntoDType,
NonNestedLiteral,
NumericLiteral,
RollingInterpolationMethod,
TemporalLiteral,
)
class DaskExpr(
LazyExpr["DaskLazyFrame", "dx.Series"],
DepthTrackingExpr["DaskLazyFrame", "dx.Series"],
):
_implementation: Implementation = Implementation.DASK
def __init__(
self,
call: EvalSeries[DaskLazyFrame, dx.Series],
*,
depth: int,
function_name: str,
evaluate_output_names: EvalNames[DaskLazyFrame],
alias_output_names: AliasNames | None,
version: Version,
scalar_kwargs: ScalarKwargs | None = None,
) -> None:
self._call = call
self._depth = depth
self._function_name = function_name
self._evaluate_output_names = evaluate_output_names
self._alias_output_names = alias_output_names
self._version = version
self._scalar_kwargs = scalar_kwargs or {}
self._metadata: ExprMetadata | None = None
def __call__(self, df: DaskLazyFrame) -> Sequence[dx.Series]:
return self._call(df)
def __narwhals_expr__(self) -> None: ...
def __narwhals_namespace__(self) -> DaskNamespace: # pragma: no cover
from narwhals._dask.namespace import DaskNamespace
return DaskNamespace(version=self._version)
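# Aggregations and literals evaluate to single-element series; extracting the
# value at index 0 lets Dask broadcast it against full-length columns downstream.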
def broadcast(self, kind: Literal[ExprKind.AGGREGATION, ExprKind.LITERAL]) -> Self:
def func(df: DaskLazyFrame) -> list[dx.Series]:
# result.loc[0][0] is a workaround for dask~<=2024.10.0/dask_expr~<=1.1.16
# that raised a KeyError for result[0] during collection.
return [result.loc[0][0] for result in self(df)]
return self.__class__(
func,
depth=self._depth,
function_name=self._function_name,
evaluate_output_names=self._evaluate_output_names,
alias_output_names=self._alias_output_names,
version=self._version,
scalar_kwargs=self._scalar_kwargs,
)
@classmethod
def from_column_names(
cls: type[Self],
evaluate_column_names: EvalNames[DaskLazyFrame],
/,
*,
context: _LimitedContext,
function_name: str = "",
) -> Self:
def func(df: DaskLazyFrame) -> list[dx.Series]:
try:
return [
df._native_frame[column_name]
for column_name in evaluate_column_names(df)
]
except KeyError as e:
if error := df._check_columns_exist(evaluate_column_names(df)):
raise error from e
raise
return cls(
func,
depth=0,
function_name=function_name,
evaluate_output_names=evaluate_column_names,
alias_output_names=None,
version=context._version,
)
@classmethod
def from_column_indices(cls, *column_indices: int, context: _LimitedContext) -> Self:
def func(df: DaskLazyFrame) -> list[dx.Series]:
return [df.native.iloc[:, i] for i in column_indices]
return cls(
func,
depth=0,
function_name="nth",
evaluate_output_names=cls._eval_names_indices(column_indices),
alias_output_names=None,
version=context._version,
)
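# Core helper: wrap a callable that operates on native dask series into a new
# expression. `call` receives each output series of `self`; any expression-valued
# keyword arguments are evaluated against the frame first via `maybe_evaluate_expr`.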
def _with_callable(
self,
# First argument to `call` should be `dx.Series`
call: Callable[..., dx.Series],
/,
expr_name: str = "",
scalar_kwargs: ScalarKwargs | None = None,
**expressifiable_args: Self | Any,
) -> Self:
def func(df: DaskLazyFrame) -> list[dx.Series]:
native_results: list[dx.Series] = []
native_series_list = self._call(df)
other_native_series = {
key: maybe_evaluate_expr(df, value)
for key, value in expressifiable_args.items()
}
for native_series in native_series_list:
result_native = call(native_series, **other_native_series)
native_results.append(result_native)
return native_results
return self.__class__(
func,
depth=self._depth + 1,
function_name=f"{self._function_name}->{expr_name}",
evaluate_output_names=self._evaluate_output_names,
alias_output_names=self._alias_output_names,
version=self._version,
scalar_kwargs=scalar_kwargs,
)
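# Compose aliasing functions: any existing alias function runs first, then `func`,
# so chained renames are applied in order.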
def _with_alias_output_names(self, func: AliasNames | None, /) -> Self:
current_alias_output_names = self._alias_output_names
alias_output_names = (
None
if func is None
else func
if current_alias_output_names is None
else lambda output_names: func(current_alias_output_names(output_names))
)
return type(self)(
call=self._call,
depth=self._depth,
function_name=self._function_name,
evaluate_output_names=self._evaluate_output_names,
alias_output_names=alias_output_names,
version=self._version,
scalar_kwargs=self._scalar_kwargs,
)
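# Binary operators delegate to the corresponding dask series dunder methods.
# Reflected variants (__rsub__, __rtruediv__, ...) rename their output to
# "literal", matching the naming used when the left-hand operand is a literal.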
def __add__(self, other: Any) -> Self:
return self._with_callable(
lambda expr, other: expr.__add__(other), "__add__", other=other
)
def __sub__(self, other: Any) -> Self:
return self._with_callable(
lambda expr, other: expr.__sub__(other), "__sub__", other=other
)
def __rsub__(self, other: Any) -> Self:
return self._with_callable(
lambda expr, other: other - expr, "__rsub__", other=other
).alias("literal")
def __mul__(self, other: Any) -> Self:
return self._with_callable(
lambda expr, other: expr.__mul__(other), "__mul__", other=other
)
def __truediv__(self, other: Any) -> Self:
return self._with_callable(
lambda expr, other: expr.__truediv__(other), "__truediv__", other=other
)
def __rtruediv__(self, other: Any) -> Self:
return self._with_callable(
lambda expr, other: other / expr, "__rtruediv__", other=other
).alias("literal")
def __floordiv__(self, other: Any) -> Self:
return self._with_callable(
lambda expr, other: expr.__floordiv__(other), "__floordiv__", other=other
)
def __rfloordiv__(self, other: Any) -> Self:
return self._with_callable(
lambda expr, other: other // expr, "__rfloordiv__", other=other
).alias("literal")
def __pow__(self, other: Any) -> Self:
return self._with_callable(
lambda expr, other: expr.__pow__(other), "__pow__", other=other
)
def __rpow__(self, other: Any) -> Self:
return self._with_callable(
lambda expr, other: other**expr, "__rpow__", other=other
).alias("literal")
def __mod__(self, other: Any) -> Self:
return self._with_callable(
lambda expr, other: expr.__mod__(other), "__mod__", other=other
)
def __rmod__(self, other: Any) -> Self:
return self._with_callable(
lambda expr, other: other % expr, "__rmod__", other=other
).alias("literal")
def __eq__(self, other: DaskExpr) -> Self: # type: ignore[override]
return self._with_callable(
lambda expr, other: expr.__eq__(other), "__eq__", other=other
)
def __ne__(self, other: DaskExpr) -> Self: # type: ignore[override]
return self._with_callable(
lambda expr, other: expr.__ne__(other), "__ne__", other=other
)
def __ge__(self, other: DaskExpr | Any) -> Self:
return self._with_callable(
lambda expr, other: expr.__ge__(other), "__ge__", other=other
)
def __gt__(self, other: DaskExpr) -> Self:
return self._with_callable(
lambda expr, other: expr.__gt__(other), "__gt__", other=other
)
def __le__(self, other: DaskExpr) -> Self:
return self._with_callable(
lambda expr, other: expr.__le__(other), "__le__", other=other
)
def __lt__(self, other: DaskExpr) -> Self:
return self._with_callable(
lambda expr, other: expr.__lt__(other), "__lt__", other=other
)
def __and__(self, other: DaskExpr | Any) -> Self:
return self._with_callable(
lambda expr, other: expr.__and__(other), "__and__", other=other
)
def __or__(self, other: DaskExpr) -> Self:
return self._with_callable(
lambda expr, other: expr.__or__(other), "__or__", other=other
)
def __invert__(self) -> Self:
return self._with_callable(lambda expr: expr.__invert__(), "__invert__")
def mean(self) -> Self:
return self._with_callable(lambda expr: expr.mean().to_series(), "mean")
def median(self) -> Self:
def func(s: dx.Series) -> dx.Series:
dtype = native_to_narwhals_dtype(s.dtype, self._version, Implementation.DASK)
if not dtype.is_numeric():
msg = "`median` operation not supported for non-numeric input type."
raise InvalidOperationError(msg)
return s.median_approximate().to_series()
return self._with_callable(func, "median")
def min(self) -> Self:
return self._with_callable(lambda expr: expr.min().to_series(), "min")
def max(self) -> Self:
return self._with_callable(lambda expr: expr.max().to_series(), "max")
def std(self, ddof: int) -> Self:
return self._with_callable(
lambda expr: expr.std(ddof=ddof).to_series(),
"std",
scalar_kwargs={"ddof": ddof},
)
def var(self, ddof: int) -> Self:
return self._with_callable(
lambda expr: expr.var(ddof=ddof).to_series(),
"var",
scalar_kwargs={"ddof": ddof},
)
def skew(self) -> Self:
return self._with_callable(lambda expr: expr.skew().to_series(), "skew")
def kurtosis(self) -> Self:
return self._with_callable(lambda expr: expr.kurtosis().to_series(), "kurtosis")
def shift(self, n: int) -> Self:
return self._with_callable(lambda expr: expr.shift(n), "shift")
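# Cumulative operations: dask has no reverse-order cumulative support
# (https://github.com/dask/dask/issues/11802), so reverse=True raises.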
def cum_sum(self, *, reverse: bool) -> Self:
if reverse: # pragma: no cover
# https://github.com/dask/dask/issues/11802
msg = "`cum_sum(reverse=True)` is not supported with Dask backend"
raise NotImplementedError(msg)
return self._with_callable(lambda expr: expr.cumsum(), "cum_sum")
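# cum_count counts non-null values cumulatively: negate the null mask,
# cast to int, and take a running sum.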
def cum_count(self, *, reverse: bool) -> Self:
if reverse: # pragma: no cover
msg = "`cum_count(reverse=True)` is not supported with Dask backend"
raise NotImplementedError(msg)
return self._with_callable(
lambda expr: (~expr.isna()).astype(int).cumsum(), "cum_count"
)
def cum_min(self, *, reverse: bool) -> Self:
if reverse: # pragma: no cover
msg = "`cum_min(reverse=True)` is not supported with Dask backend"
raise NotImplementedError(msg)
return self._with_callable(lambda expr: expr.cummin(), "cum_min")
def cum_max(self, *, reverse: bool) -> Self:
if reverse: # pragma: no cover
msg = "`cum_max(reverse=True)` is not supported with Dask backend"
raise NotImplementedError(msg)
return self._with_callable(lambda expr: expr.cummax(), "cum_max")
def cum_prod(self, *, reverse: bool) -> Self:
if reverse: # pragma: no cover
msg = "`cum_prod(reverse=True)` is not supported with Dask backend"
raise NotImplementedError(msg)
return self._with_callable(lambda expr: expr.cumprod(), "cum_prod")
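# Rolling aggregations map directly onto dask's pandas-style .rolling();
# rolling_var and rolling_std only support ddof=1.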
def rolling_sum(self, window_size: int, *, min_samples: int, center: bool) -> Self:
return self._with_callable(
lambda expr: expr.rolling(
window=window_size, min_periods=min_samples, center=center
).sum(),
"rolling_sum",
)
def rolling_mean(self, window_size: int, *, min_samples: int, center: bool) -> Self:
return self._with_callable(
lambda expr: expr.rolling(
window=window_size, min_periods=min_samples, center=center
).mean(),
"rolling_mean",
)
def rolling_var(
self, window_size: int, *, min_samples: int, center: bool, ddof: int
) -> Self:
if ddof == 1:
return self._with_callable(
lambda expr: expr.rolling(
window=window_size, min_periods=min_samples, center=center
).var(),
"rolling_var",
)
else:
msg = "Dask backend only supports `ddof=1` for `rolling_var`"
raise NotImplementedError(msg)
def rolling_std(
self, window_size: int, *, min_samples: int, center: bool, ddof: int
) -> Self:
if ddof == 1:
return self._with_callable(
lambda expr: expr.rolling(
window=window_size, min_periods=min_samples, center=center
).std(),
"rolling_std",
)
else:
msg = "Dask backend only supports `ddof=1` for `rolling_std`"
raise NotImplementedError(msg)
def sum(self) -> Self:
return self._with_callable(lambda expr: expr.sum().to_series(), "sum")
def count(self) -> Self:
return self._with_callable(lambda expr: expr.count().to_series(), "count")
def round(self, decimals: int) -> Self:
return self._with_callable(lambda expr: expr.round(decimals), "round")
def unique(self) -> Self:
return self._with_callable(lambda expr: expr.unique(), "unique")
def drop_nulls(self) -> Self:
return self._with_callable(lambda expr: expr.dropna(), "drop_nulls")
def abs(self) -> Self:
return self._with_callable(lambda expr: expr.abs(), "abs")
def all(self) -> Self:
return self._with_callable(
lambda expr: expr.all(
axis=None, skipna=True, split_every=False, out=None
).to_series(),
"all",
)
def any(self) -> Self:
return self._with_callable(
lambda expr: expr.any(axis=0, skipna=True, split_every=False).to_series(),
"any",
)
def fill_null(
self,
value: Self | NonNestedLiteral,
strategy: FillNullStrategy | None,
limit: int | None,
) -> Self:
def func(expr: dx.Series) -> dx.Series:
if value is not None:
res_ser = expr.fillna(value)
else:
res_ser = (
expr.ffill(limit=limit)
if strategy == "forward"
else expr.bfill(limit=limit)
)
return res_ser
return self._with_callable(func, "fillna")
def clip(
self,
lower_bound: Self | NumericLiteral | TemporalLiteral | None,
upper_bound: Self | NumericLiteral | TemporalLiteral | None,
) -> Self:
return self._with_callable(
lambda expr, lower_bound, upper_bound: expr.clip(
lower=lower_bound, upper=upper_bound
),
"clip",
lower_bound=lower_bound,
upper_bound=upper_bound,
)
def diff(self) -> Self:
return self._with_callable(lambda expr: expr.diff(), "diff")
def n_unique(self) -> Self:
return self._with_callable(
lambda expr: expr.nunique(dropna=False).to_series(), "n_unique"
)
def is_null(self) -> Self:
return self._with_callable(lambda expr: expr.isna(), "is_null")
def is_nan(self) -> Self:
def func(expr: dx.Series) -> dx.Series:
dtype = native_to_narwhals_dtype(
expr.dtype, self._version, self._implementation
)
if dtype.is_numeric():
return expr != expr # pyright: ignore[reportReturnType] # noqa: PLR0124
msg = f"`.is_nan` only supported for numeric dtypes and not {dtype}, did you mean `.is_null`?"
raise InvalidOperationError(msg)
return self._with_callable(func, "is_nan")
def len(self) -> Self:
return self._with_callable(lambda expr: expr.size.to_series(), "len")
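# quantile is restricted to "linear" interpolation and single-partition frames;
# anything else raises NotImplementedError.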
def quantile(
self, quantile: float, interpolation: RollingInterpolationMethod
) -> Self:
if interpolation == "linear":
def func(expr: dx.Series, quantile: float) -> dx.Series:
if expr.npartitions > 1:
msg = "`Expr.quantile` is not supported for Dask backend with multiple partitions."
raise NotImplementedError(msg)
return expr.quantile(
q=quantile, method="dask"
).to_series() # pragma: no cover
return self._with_callable(func, "quantile", quantile=quantile)
else:
msg = "`higher`, `lower`, `midpoint`, `nearest` - interpolation methods are not supported by Dask. Please use `linear` instead."
raise NotImplementedError(msg)
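# is_first_distinct: attach a row index, find the minimum index per distinct
# value via a groupby, then flag rows whose index is in that set.
# is_last_distinct mirrors this with "max".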
def is_first_distinct(self) -> Self:
def func(expr: dx.Series) -> dx.Series:
_name = expr.name
col_token = generate_temporary_column_name(n_bytes=8, columns=[_name])
frame = add_row_index(expr.to_frame(), col_token)
first_distinct_index = frame.groupby(_name).agg({col_token: "min"})[col_token]
return frame[col_token].isin(first_distinct_index)
return self._with_callable(func, "is_first_distinct")
def is_last_distinct(self) -> Self:
def func(expr: dx.Series) -> dx.Series:
_name = expr.name
col_token = generate_temporary_column_name(n_bytes=8, columns=[_name])
frame = add_row_index(expr.to_frame(), col_token)
last_distinct_index = frame.groupby(_name).agg({col_token: "max"})[col_token]
return frame[col_token].isin(last_distinct_index)
return self._with_callable(func, "is_last_distinct")
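# is_unique: a groupby-transform on the values gives each row its group size;
# size == 1 marks unique values (dropna=False keeps nulls as their own group).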
def is_unique(self) -> Self:
def func(expr: dx.Series) -> dx.Series:
_name = expr.name
return (
expr.to_frame()
.groupby(_name, dropna=False)
.transform("size", meta=(_name, int))
== 1
)
return self._with_callable(func, "is_unique")
def is_in(self, other: Any) -> Self:
return self._with_callable(lambda expr: expr.isin(other), "is_in")
def null_count(self) -> Self:
return self._with_callable(
lambda expr: expr.isna().sum().to_series(), "null_count"
)
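# over() supports two shapes: a bare order_by (sort the frame, then evaluate)
# and elementary aggregations over partition_by (groupby().transform).
# Non-elementary expressions and order_by combined with partition_by raise
# NotImplementedError.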
def over(self, partition_by: Sequence[str], order_by: Sequence[str]) -> Self:
# pandas is a required dependency of dask so it's safe to import this
from narwhals._pandas_like.group_by import PandasLikeGroupBy
if not partition_by:
assert order_by # noqa: S101
# This is something like `nw.col('a').cum_sum().order_by(key)`
# which we can always easily support, as it doesn't require grouping.
def func(df: DaskLazyFrame) -> Sequence[dx.Series]:
return self(df.sort(*order_by, descending=False, nulls_last=False))
elif not self._is_elementary(): # pragma: no cover
msg = (
"Only elementary expressions are supported for `.over` in dask.\n\n"
"Please see: "
"https://narwhals-dev.github.io/narwhals/concepts/improve_group_by_operation/"
)
raise NotImplementedError(msg)
elif order_by:
# Wrong results https://github.com/dask/dask/issues/11806.
msg = "`over` with `order_by` is not yet supported in Dask."
raise NotImplementedError(msg)
else:
function_name = PandasLikeGroupBy._leaf_name(self)
try:
dask_function_name = PandasLikeGroupBy._REMAP_AGGS[function_name]
except KeyError:
# window functions are unsupported: https://github.com/dask/dask/issues/11806
msg = (
f"Unsupported function: {function_name} in `over` context.\n\n"
f"Supported functions are {', '.join(PandasLikeGroupBy._REMAP_AGGS)}\n"
)
raise NotImplementedError(msg) from None
def func(df: DaskLazyFrame) -> Sequence[dx.Series]:
output_names, aliases = evaluate_output_names_and_aliases(self, df, [])
with warnings.catch_warnings():
# https://github.com/dask/dask/issues/11804
warnings.filterwarnings(
"ignore",
message=".*`meta` is not specified",
category=UserWarning,
)
grouped = df.native.groupby(partition_by)
if dask_function_name == "size":
if len(output_names) != 1: # pragma: no cover
msg = "Safety check failed, please report a bug."
raise AssertionError(msg)
res_native = grouped.transform(
dask_function_name, **self._scalar_kwargs
).to_frame(output_names[0])
else:
res_native = grouped[list(output_names)].transform(
dask_function_name, **self._scalar_kwargs
)
result_frame = df._with_native(
res_native.rename(columns=dict(zip(output_names, aliases)))
).native
return [result_frame[name] for name in aliases]
return self.__class__(
func,
depth=self._depth + 1,
function_name=self._function_name + "->over",
evaluate_output_names=self._evaluate_output_names,
alias_output_names=self._alias_output_names,
version=self._version,
)
def cast(self, dtype: IntoDType) -> Self:
def func(expr: dx.Series) -> dx.Series:
native_dtype = narwhals_to_native_dtype(dtype, self._version)
return expr.astype(native_dtype)
return self._with_callable(func, "cast")
def is_finite(self) -> Self:
import dask.array as da
return self._with_callable(da.isfinite, "is_finite")
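# log with an arbitrary base via change of base: log_base(x) = ln(x) / ln(base).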
def log(self, base: float) -> Self:
import dask.array as da
def _log(expr: dx.Series) -> dx.Series:
return da.log(expr) / da.log(base)
return self._with_callable(_log, "log")
def exp(self) -> Self:
import dask.array as da
return self._with_callable(da.exp, "exp")
def sqrt(self) -> Self:
import dask.array as da
return self._with_callable(da.sqrt, "sqrt")
@property
def str(self) -> DaskExprStringNamespace:
return DaskExprStringNamespace(self)
@property
def dt(self) -> DaskExprDateTimeNamespace:
return DaskExprDateTimeNamespace(self)
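# Operations without a (lazy) Dask implementation are explicitly marked below.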
arg_max: not_implemented = not_implemented()
arg_min: not_implemented = not_implemented()
arg_true: not_implemented = not_implemented()
ewm_mean: not_implemented = not_implemented()
gather_every: not_implemented = not_implemented()
head: not_implemented = not_implemented()
map_batches: not_implemented = not_implemented()
mode: not_implemented = not_implemented()
sample: not_implemented = not_implemented()
rank: not_implemented = not_implemented()
replace_strict: not_implemented = not_implemented()
sort: not_implemented = not_implemented()
tail: not_implemented = not_implemented()
# namespaces
list: not_implemented = not_implemented() # type: ignore[assignment]
cat: not_implemented = not_implemented() # type: ignore[assignment]
struct: not_implemented = not_implemented() # type: ignore[assignment]