from __future__ import annotations

from typing import TYPE_CHECKING, Any

import dask.dataframe as dd

from narwhals._dask.utils import add_row_index, evaluate_exprs
from narwhals._expression_parsing import ExprKind
from narwhals._pandas_like.utils import native_to_narwhals_dtype, select_columns_by_name
from narwhals._typing_compat import assert_never
from narwhals._utils import (
    Implementation,
    ValidateBackendVersion,
    _remap_full_join_keys,
    check_column_names_are_unique,
    generate_temporary_column_name,
    not_implemented,
    parse_columns_to_drop,
)
from narwhals.typing import CompliantLazyFrame

if TYPE_CHECKING:
    from collections.abc import Iterator, Mapping, Sequence
    from io import BytesIO
    from pathlib import Path
    from types import ModuleType

    import dask.dataframe.dask_expr as dx
    from typing_extensions import Self, TypeAlias, TypeIs

    from narwhals._compliant.typing import CompliantDataFrameAny
    from narwhals._dask.expr import DaskExpr
    from narwhals._dask.group_by import DaskLazyGroupBy
    from narwhals._dask.namespace import DaskNamespace
    from narwhals._utils import Version, _LimitedContext
    from narwhals.dataframe import LazyFrame
    from narwhals.dtypes import DType
    from narwhals.typing import AsofJoinStrategy, JoinStrategy, LazyUniqueKeepStrategy

Incomplete: TypeAlias = "Any"
"""Using `_pandas_like` utils with `_dask`.

Typing this correctly will complicate the `_pandas_like`-side.
Very low priority until `dask` adds typing.
"""


class DaskLazyFrame(
    CompliantLazyFrame["DaskExpr", "dd.DataFrame", "LazyFrame[dd.DataFrame]"],
    ValidateBackendVersion,
):
    _implementation = Implementation.DASK

    def __init__(
        self,
        native_dataframe: dd.DataFrame,
        *,
        version: Version,
        validate_backend_version: bool = False,
    ) -> None:
        self._native_frame: dd.DataFrame = native_dataframe
        self._version = version
        self._cached_schema: dict[str, DType] | None = None
        self._cached_columns: list[str] | None = None
        if validate_backend_version:
            self._validate_backend_version()

    @staticmethod
    def _is_native(obj: dd.DataFrame | Any) -> TypeIs[dd.DataFrame]:
        return isinstance(obj, dd.DataFrame)

    @classmethod
    def from_native(cls, data: dd.DataFrame, /, *, context: _LimitedContext) -> Self:
        return cls(data, version=context._version)

    def to_narwhals(self) -> LazyFrame[dd.DataFrame]:
        return self._version.lazyframe(self, level="lazy")

    def __native_namespace__(self) -> ModuleType:
        if self._implementation is Implementation.DASK:
            return self._implementation.to_native_namespace()
        msg = f"Expected dask, got: {type(self._implementation)}"  # pragma: no cover
        raise AssertionError(msg)

    def __narwhals_namespace__(self) -> DaskNamespace:
        from narwhals._dask.namespace import DaskNamespace

        return DaskNamespace(version=self._version)

    def __narwhals_lazyframe__(self) -> Self:
        return self

    def _with_version(self, version: Version) -> Self:
        return self.__class__(self.native, version=version)

    def _with_native(self, df: Any) -> Self:
        return self.__class__(df, version=self._version)

    def _iter_columns(self) -> Iterator[dx.Series]:
        for _col, ser in self.native.items():  # noqa: PERF102
            yield ser

    def with_columns(self, *exprs: DaskExpr) -> Self:
        new_series = evaluate_exprs(self, *exprs)
        return self._with_native(self.native.assign(**dict(new_series)))
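    # Note on `collect` below: the Dask graph is materialized once via
    # `compute()`, which yields a pandas DataFrame; conversion to the requested
    # backend then happens on that already-computed result (for example
    # `pl.from_pandas` for Polars, `pa.Table.from_pandas` for PyArrow), so no
    # extra Dask work is scheduled per backend.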
    def collect(
        self, backend: Implementation | None, **kwargs: Any
    ) -> CompliantDataFrameAny:
        result = self.native.compute(**kwargs)

        if backend is None or backend is Implementation.PANDAS:
            from narwhals._pandas_like.dataframe import PandasLikeDataFrame

            return PandasLikeDataFrame(
                result,
                implementation=Implementation.PANDAS,
                validate_backend_version=True,
                version=self._version,
                validate_column_names=True,
            )

        if backend is Implementation.POLARS:
            import polars as pl  # ignore-banned-import

            from narwhals._polars.dataframe import PolarsDataFrame

            return PolarsDataFrame(
                pl.from_pandas(result),
                validate_backend_version=True,
                version=self._version,
            )

        if backend is Implementation.PYARROW:
            import pyarrow as pa  # ignore-banned-import

            from narwhals._arrow.dataframe import ArrowDataFrame

            return ArrowDataFrame(
                pa.Table.from_pandas(result),
                validate_backend_version=True,
                version=self._version,
                validate_column_names=True,
            )

        msg = f"Unsupported `backend` value: {backend}"  # pragma: no cover
        raise ValueError(msg)  # pragma: no cover

    @property
    def columns(self) -> list[str]:
        if self._cached_columns is None:
            self._cached_columns = (
                list(self.schema)
                if self._cached_schema is not None
                else self.native.columns.tolist()
            )
        return self._cached_columns

    def filter(self, predicate: DaskExpr) -> Self:
        # `[0]` is safe as the predicate's expression only returns a single column
        mask = predicate(self)[0]
        return self._with_native(self.native.loc[mask])

    def simple_select(self, *column_names: str) -> Self:
        df: Incomplete = self.native
        native = select_columns_by_name(df, list(column_names), self._implementation)
        return self._with_native(native)

    def aggregate(self, *exprs: DaskExpr) -> Self:
        new_series = evaluate_exprs(self, *exprs)
        df = dd.concat([val.rename(name) for name, val in new_series], axis=1)
        return self._with_native(df)

    def select(self, *exprs: DaskExpr) -> Self:
        new_series = evaluate_exprs(self, *exprs)
        df: Incomplete = self.native
        df = select_columns_by_name(
            df.assign(**dict(new_series)),
            [s[0] for s in new_series],
            self._implementation,
        )
        return self._with_native(df)

    def drop_nulls(self, subset: Sequence[str] | None) -> Self:
        if subset is None:
            return self._with_native(self.native.dropna())
        plx = self.__narwhals_namespace__()
        mask = ~plx.any_horizontal(plx.col(*subset).is_null(), ignore_nulls=True)
        return self.filter(mask)

    @property
    def schema(self) -> dict[str, DType]:
        if self._cached_schema is None:
            native_dtypes = self.native.dtypes
            self._cached_schema = {
                col: native_to_narwhals_dtype(
                    native_dtypes[col], self._version, self._implementation
                )
                for col in self.native.columns
            }
        return self._cached_schema

    def collect_schema(self) -> dict[str, DType]:
        return self.schema

    def drop(self, columns: Sequence[str], *, strict: bool) -> Self:
        to_drop = parse_columns_to_drop(self, columns, strict=strict)
        return self._with_native(self.native.drop(columns=to_drop))

    def with_row_index(self, name: str, order_by: Sequence[str] | None) -> Self:
        # Implementation is based on the following StackOverflow reply:
        # https://stackoverflow.com/questions/60831518/in-dask-how-does-one-add-a-range-of-integersauto-increment-to-a-new-column/60852409#60852409
        if order_by is None:
            return self._with_native(add_row_index(self.native, name))
        else:
            plx = self.__narwhals_namespace__()
            columns = self.columns
            const_expr = (
                plx.lit(value=1, dtype=None).alias(name).broadcast(ExprKind.LITERAL)
            )
            row_index_expr = (
                plx.col(name)
                .cum_sum(reverse=False)
                .over(partition_by=[], order_by=order_by)
                - 1
            )
            return self.with_columns(const_expr).select(
                row_index_expr, plx.col(*columns)
            )

    def rename(self, mapping: Mapping[str, str]) -> Self:
        return self._with_native(self.native.rename(columns=mapping))

    def head(self, n: int) -> Self:
        return self._with_native(self.native.head(n=n, compute=False, npartitions=-1))
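    # `unique(keep="none")` below drops every row whose key combination occurs
    # more than once. A minimal sketch of the same technique on a raw Dask
    # frame (`ddf`, `keys`, and `_count` are illustrative names, not part of
    # this module):
    #
    #     sizes = ddf.groupby(keys).size().rename("_count")
    #     singletons = sizes[sizes == 1].reset_index().drop(columns="_count")
    #     ddf.merge(singletons, on=keys, how="inner")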
    def unique(
        self, subset: Sequence[str] | None, *, keep: LazyUniqueKeepStrategy
    ) -> Self:
        if subset and (error := self._check_columns_exist(subset)):
            raise error
        if keep == "none":
            subset = subset or self.columns
            token = generate_temporary_column_name(n_bytes=8, columns=subset)
            ser = self.native.groupby(subset).size().rename(token)
            ser = ser[ser == 1]
            unique = ser.reset_index().drop(columns=token)
            result = self.native.merge(unique, on=subset, how="inner")
        else:
            mapped_keep = {"any": "first"}.get(keep, keep)
            result = self.native.drop_duplicates(subset=subset, keep=mapped_keep)
        return self._with_native(result)

    def sort(self, *by: str, descending: bool | Sequence[bool], nulls_last: bool) -> Self:
        if isinstance(descending, bool):
            ascending: bool | list[bool] = not descending
        else:
            ascending = [not d for d in descending]
        position = "last" if nulls_last else "first"
        return self._with_native(
            self.native.sort_values(list(by), ascending=ascending, na_position=position)
        )

    def _join_inner(
        self, other: Self, *, left_on: Sequence[str], right_on: Sequence[str], suffix: str
    ) -> dd.DataFrame:
        return self.native.merge(
            other.native,
            left_on=left_on,
            right_on=right_on,
            how="inner",
            suffixes=("", suffix),
        )

    def _join_left(
        self, other: Self, *, left_on: Sequence[str], right_on: Sequence[str], suffix: str
    ) -> dd.DataFrame:
        result_native = self.native.merge(
            other.native,
            how="left",
            left_on=left_on,
            right_on=right_on,
            suffixes=("", suffix),
        )
        extra = [
            right_key if right_key not in self.columns else f"{right_key}{suffix}"
            for left_key, right_key in zip(left_on, right_on)
            if right_key != left_key
        ]
        return result_native.drop(columns=extra)

    def _join_full(
        self, other: Self, *, left_on: Sequence[str], right_on: Sequence[str], suffix: str
    ) -> dd.DataFrame:
        # dask does not retain keys post-join
        # we must append the suffix to each key before-hand
        right_on_mapper = _remap_full_join_keys(left_on, right_on, suffix)
        other_native = other.native.rename(columns=right_on_mapper)
        check_column_names_are_unique(other_native.columns)
        right_suffixed = list(right_on_mapper.values())
        return self.native.merge(
            other_native,
            left_on=left_on,
            right_on=right_suffixed,
            how="outer",
            suffixes=("", suffix),
        )

    def _join_cross(self, other: Self, *, suffix: str) -> dd.DataFrame:
        key_token = generate_temporary_column_name(
            n_bytes=8, columns=(*self.columns, *other.columns)
        )
        return (
            self.native.assign(**{key_token: 0})
            .merge(
                other.native.assign(**{key_token: 0}),
                how="inner",
                left_on=key_token,
                right_on=key_token,
                suffixes=("", suffix),
            )
            .drop(columns=key_token)
        )

    def _join_semi(
        self, other: Self, *, left_on: Sequence[str], right_on: Sequence[str]
    ) -> dd.DataFrame:
        other_native = self._join_filter_rename(
            other=other,
            columns_to_select=list(right_on),
            columns_mapping=dict(zip(right_on, left_on)),
        )
        return self.native.merge(
            other_native, how="inner", left_on=left_on, right_on=left_on
        )

    def _join_anti(
        self, other: Self, *, left_on: Sequence[str], right_on: Sequence[str]
    ) -> dd.DataFrame:
        indicator_token = generate_temporary_column_name(
            n_bytes=8, columns=(*self.columns, *other.columns)
        )
        other_native = self._join_filter_rename(
            other=other,
            columns_to_select=list(right_on),
            columns_mapping=dict(zip(right_on, left_on)),
        )
        df = self.native.merge(
            other_native,
            how="left",
            indicator=indicator_token,  # pyright: ignore[reportArgumentType]
            left_on=left_on,
            right_on=left_on,
        )
        return df[df[indicator_token] == "left_only"].drop(columns=[indicator_token])
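    # The anti-join above uses the pandas-style `indicator=` merge argument,
    # which tags each output row as "left_only", "right_only", or "both";
    # keeping only "left_only" rows leaves exactly those with no match in
    # `other`. The helper below deduplicates and renames the right-hand keys
    # first so the left merge cannot duplicate rows or add extra columns.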
    def _join_filter_rename(
        self, other: Self, columns_to_select: list[str], columns_mapping: dict[str, str]
    ) -> dd.DataFrame:
        """Helper function to avoid creating extra columns and row duplication.

        Used in `"anti"` and `"semi"` joins. Notice that a native object is returned.
        """
        other_native: Incomplete = other.native
        # rename to avoid creating extra columns in join
        return (
            select_columns_by_name(other_native, columns_to_select, self._implementation)
            .rename(columns=columns_mapping)
            .drop_duplicates()
        )

    def join(
        self,
        other: Self,
        *,
        how: JoinStrategy,
        left_on: Sequence[str] | None,
        right_on: Sequence[str] | None,
        suffix: str,
    ) -> Self:
        if how == "cross":
            result = self._join_cross(other=other, suffix=suffix)
        elif left_on is None or right_on is None:  # pragma: no cover
            raise ValueError(left_on, right_on)
        elif how == "inner":
            result = self._join_inner(
                other=other, left_on=left_on, right_on=right_on, suffix=suffix
            )
        elif how == "anti":
            result = self._join_anti(other=other, left_on=left_on, right_on=right_on)
        elif how == "semi":
            result = self._join_semi(other=other, left_on=left_on, right_on=right_on)
        elif how == "left":
            result = self._join_left(
                other=other, left_on=left_on, right_on=right_on, suffix=suffix
            )
        elif how == "full":
            result = self._join_full(
                other=other, left_on=left_on, right_on=right_on, suffix=suffix
            )
        else:
            assert_never(how)
        return self._with_native(result)

    def join_asof(
        self,
        other: Self,
        *,
        left_on: str,
        right_on: str,
        by_left: Sequence[str] | None,
        by_right: Sequence[str] | None,
        strategy: AsofJoinStrategy,
        suffix: str,
    ) -> Self:
        plx = self.__native_namespace__()
        return self._with_native(
            plx.merge_asof(
                self.native,
                other.native,
                left_on=left_on,
                right_on=right_on,
                left_by=by_left,
                right_by=by_right,
                direction=strategy,
                suffixes=("", suffix),
            )
        )

    def group_by(
        self, keys: Sequence[str] | Sequence[DaskExpr], *, drop_null_keys: bool
    ) -> DaskLazyGroupBy:
        from narwhals._dask.group_by import DaskLazyGroupBy

        return DaskLazyGroupBy(self, keys, drop_null_keys=drop_null_keys)

    def tail(self, n: int) -> Self:  # pragma: no cover
        native_frame = self.native
        n_partitions = native_frame.npartitions

        if n_partitions == 1:
            return self._with_native(self.native.tail(n=n, compute=False))
        else:
            msg = "`LazyFrame.tail` is not supported for Dask backend with multiple partitions."
            raise NotImplementedError(msg)

    def gather_every(self, n: int, offset: int) -> Self:
        row_index_token = generate_temporary_column_name(n_bytes=8, columns=self.columns)
        plx = self.__narwhals_namespace__()
        return (
            self.with_row_index(row_index_token, order_by=None)
            .filter(
                (plx.col(row_index_token) >= offset)
                & ((plx.col(row_index_token) - offset) % n == 0)
            )
            .drop([row_index_token], strict=False)
        )

    def unpivot(
        self,
        on: Sequence[str] | None,
        index: Sequence[str] | None,
        variable_name: str,
        value_name: str,
    ) -> Self:
        return self._with_native(
            self.native.melt(
                id_vars=index,
                value_vars=on,
                var_name=variable_name,
                value_name=value_name,
            )
        )

    def sink_parquet(self, file: str | Path | BytesIO) -> None:
        self.native.to_parquet(file)

    explode = not_implemented()
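# Illustrative usage sketch (assumes `dask[dataframe]`, `pandas`, and
# `narwhals` are installed; shown for orientation only, not part of this
# module's API):
#
#     import dask.dataframe as dd
#     import pandas as pd
#     import narwhals as nw
#
#     ddf = dd.from_pandas(pd.DataFrame({"a": [1, 2, 2]}), npartitions=1)
#     lf = nw.from_native(ddf)  # backed by DaskLazyFrame
#     ddf_unique = lf.unique(["a"], keep="none").to_native()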