from __future__ import annotations

from typing import TYPE_CHECKING, Any

import dask.dataframe as dd

from narwhals._dask.utils import add_row_index, evaluate_exprs
from narwhals._expression_parsing import ExprKind
from narwhals._pandas_like.utils import native_to_narwhals_dtype, select_columns_by_name
from narwhals._typing_compat import assert_never
from narwhals._utils import (
    Implementation,
    ValidateBackendVersion,
    _remap_full_join_keys,
    check_column_names_are_unique,
    generate_temporary_column_name,
    not_implemented,
    parse_columns_to_drop,
)
from narwhals.typing import CompliantLazyFrame

if TYPE_CHECKING:
    from collections.abc import Iterator, Mapping, Sequence
    from io import BytesIO
    from pathlib import Path
    from types import ModuleType

    import dask.dataframe.dask_expr as dx
    from typing_extensions import Self, TypeAlias, TypeIs

    from narwhals._compliant.typing import CompliantDataFrameAny
    from narwhals._dask.expr import DaskExpr
    from narwhals._dask.group_by import DaskLazyGroupBy
    from narwhals._dask.namespace import DaskNamespace
    from narwhals._utils import Version, _LimitedContext
    from narwhals.dataframe import LazyFrame
    from narwhals.dtypes import DType
    from narwhals.typing import AsofJoinStrategy, JoinStrategy, LazyUniqueKeepStrategy

Incomplete: TypeAlias = "Any"
"""Using `_pandas_like` utils with `_dask`.

Typing this correctly will complicate the `_pandas_like`-side.
Very low priority until `dask` adds typing.
"""


class DaskLazyFrame(
    CompliantLazyFrame["DaskExpr", "dd.DataFrame", "LazyFrame[dd.DataFrame]"],
    ValidateBackendVersion,
):
    _implementation = Implementation.DASK

    def __init__(
        self,
        native_dataframe: dd.DataFrame,
        *,
        version: Version,
        validate_backend_version: bool = False,
    ) -> None:
        self._native_frame: dd.DataFrame = native_dataframe
        self._version = version
        self._cached_schema: dict[str, DType] | None = None
        self._cached_columns: list[str] | None = None
        if validate_backend_version:
            self._validate_backend_version()

    @staticmethod
    def _is_native(obj: dd.DataFrame | Any) -> TypeIs[dd.DataFrame]:
        return isinstance(obj, dd.DataFrame)

    @classmethod
    def from_native(cls, data: dd.DataFrame, /, *, context: _LimitedContext) -> Self:
        return cls(data, version=context._version)

    def to_narwhals(self) -> LazyFrame[dd.DataFrame]:
        return self._version.lazyframe(self, level="lazy")

    def __native_namespace__(self) -> ModuleType:
        if self._implementation is Implementation.DASK:
            return self._implementation.to_native_namespace()

        msg = f"Expected dask, got: {type(self._implementation)}"  # pragma: no cover
        raise AssertionError(msg)

    def __narwhals_namespace__(self) -> DaskNamespace:
        from narwhals._dask.namespace import DaskNamespace

        return DaskNamespace(version=self._version)

    def __narwhals_lazyframe__(self) -> Self:
        return self

    def _with_version(self, version: Version) -> Self:
        return self.__class__(self.native, version=version)

    def _with_native(self, df: Any) -> Self:
        return self.__class__(df, version=self._version)

    def _iter_columns(self) -> Iterator[dx.Series]:
        for _col, ser in self.native.items():  # noqa: PERF102
            yield ser

    def with_columns(self, *exprs: DaskExpr) -> Self:
        new_series = evaluate_exprs(self, *exprs)
        return self._with_native(self.native.assign(**dict(new_series)))
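
    # `collect` materializes the whole dask task graph with `.compute()` and
    # wraps the resulting pandas DataFrame in the requested eager backend.
    # A minimal usage sketch (the frame construction and `version` object are
    # illustrative, not part of this module):
    #
    #     import pandas as pd
    #     native = dd.from_pandas(pd.DataFrame({"a": [1, 2]}), npartitions=1)
    #     frame = DaskLazyFrame(native, version=version)
    #     frame.collect(None)                    # PandasLikeDataFrame (default)
    #     frame.collect(Implementation.POLARS)   # PolarsDataFrame, via pl.from_pandas
    #     frame.collect(Implementation.PYARROW)  # ArrowDataFrame, via pa.Table.from_pandas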
    def collect(
        self, backend: Implementation | None, **kwargs: Any
    ) -> CompliantDataFrameAny:
        result = self.native.compute(**kwargs)

        if backend is None or backend is Implementation.PANDAS:
            from narwhals._pandas_like.dataframe import PandasLikeDataFrame

            return PandasLikeDataFrame(
                result,
                implementation=Implementation.PANDAS,
                validate_backend_version=True,
                version=self._version,
                validate_column_names=True,
            )

        if backend is Implementation.POLARS:
            import polars as pl  # ignore-banned-import

            from narwhals._polars.dataframe import PolarsDataFrame

            return PolarsDataFrame(
                pl.from_pandas(result),
                validate_backend_version=True,
                version=self._version,
            )

        if backend is Implementation.PYARROW:
            import pyarrow as pa  # ignore-banned-import

            from narwhals._arrow.dataframe import ArrowDataFrame

            return ArrowDataFrame(
                pa.Table.from_pandas(result),
                validate_backend_version=True,
                version=self._version,
                validate_column_names=True,
            )

        msg = f"Unsupported `backend` value: {backend}"  # pragma: no cover
        raise ValueError(msg)  # pragma: no cover

    @property
    def columns(self) -> list[str]:
        if self._cached_columns is None:
            self._cached_columns = (
                list(self.schema)
                if self._cached_schema is not None
                else self.native.columns.tolist()
            )
        return self._cached_columns

    def filter(self, predicate: DaskExpr) -> Self:
        # `[0]` is safe as the predicate's expression only returns a single column
        mask = predicate(self)[0]
        return self._with_native(self.native.loc[mask])

    def simple_select(self, *column_names: str) -> Self:
        df: Incomplete = self.native
        native = select_columns_by_name(df, list(column_names), self._implementation)
        return self._with_native(native)
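
    # `aggregate` expects expressions that each reduce to a single value; the
    # reductions are renamed and concatenated column-wise into a one-row frame.
    # Sketch (obtaining expressions via the namespace here is illustrative):
    #
    #     plx = frame.__narwhals_namespace__()
    #     frame.aggregate(plx.col("a").sum(), plx.col("b").mean())
    #     # -> one row: [sum of "a", mean of "b"]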
    def aggregate(self, *exprs: DaskExpr) -> Self:
        new_series = evaluate_exprs(self, *exprs)
        df = dd.concat([val.rename(name) for name, val in new_series], axis=1)
        return self._with_native(df)

    def select(self, *exprs: DaskExpr) -> Self:
        new_series = evaluate_exprs(self, *exprs)
        df: Incomplete = self.native
        df = select_columns_by_name(
            df.assign(**dict(new_series)),
            [s[0] for s in new_series],
            self._implementation,
        )
        return self._with_native(df)
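
    # With a `subset`, `drop_nulls` keeps a row only when none of the subset
    # columns is null, by filtering on a horizontal mask. Roughly, for
    # subset=["a", "b"] (illustrative):
    #
    #     mask = ~(col("a").is_null() | col("b").is_null())
    #     frame.filter(mask)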
    def drop_nulls(self, subset: Sequence[str] | None) -> Self:
        if subset is None:
            return self._with_native(self.native.dropna())
        plx = self.__narwhals_namespace__()
        mask = ~plx.any_horizontal(plx.col(*subset).is_null(), ignore_nulls=True)
        return self.filter(mask)

    @property
    def schema(self) -> dict[str, DType]:
        if self._cached_schema is None:
            native_dtypes = self.native.dtypes
            self._cached_schema = {
                col: native_to_narwhals_dtype(
                    native_dtypes[col], self._version, self._implementation
                )
                for col in self.native.columns
            }
        return self._cached_schema

    def collect_schema(self) -> dict[str, DType]:
        return self.schema

    def drop(self, columns: Sequence[str], *, strict: bool) -> Self:
        to_drop = parse_columns_to_drop(self, columns, strict=strict)

        return self._with_native(self.native.drop(columns=to_drop))
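
    # When `order_by` is given, `with_row_index` builds the index from
    # expressions alone: a literal 1 is broadcast to every row, cumulatively
    # summed in `order_by` order, then shifted down by one so the index starts
    # at 0. Illustrative trace for a 3-row frame:
    #
    #     lit(1)                         # [1, 1, 1]
    #     .cum_sum().over(order_by=...)  # [1, 2, 3]
    #     ... - 1                        # [0, 1, 2]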
    def with_row_index(self, name: str, order_by: Sequence[str] | None) -> Self:
        # Implementation is based on the following StackOverflow reply:
        # https://stackoverflow.com/questions/60831518/in-dask-how-does-one-add-a-range-of-integersauto-increment-to-a-new-column/60852409#60852409
        if order_by is None:
            return self._with_native(add_row_index(self.native, name))
        else:
            plx = self.__narwhals_namespace__()
            columns = self.columns
            const_expr = (
                plx.lit(value=1, dtype=None).alias(name).broadcast(ExprKind.LITERAL)
            )
            row_index_expr = (
                plx.col(name)
                .cum_sum(reverse=False)
                .over(partition_by=[], order_by=order_by)
                - 1
            )
            return self.with_columns(const_expr).select(row_index_expr, plx.col(*columns))

    def rename(self, mapping: Mapping[str, str]) -> Self:
        return self._with_native(self.native.rename(columns=mapping))

    def head(self, n: int) -> Self:
        return self._with_native(self.native.head(n=n, compute=False, npartitions=-1))
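
    # For `keep="none"`, every member of a duplicated group must be dropped,
    # which `drop_duplicates` cannot express. Instead, group sizes are computed
    # and only size-1 groups survive an inner join back onto the frame.
    # Illustrative, for subset=["a"] with a = [1, 1, 2]:
    #
    #     sizes  -> {1: 2, 2: 1}       # groupby("a").size()
    #     unique -> {2}                # keep sizes == 1
    #     result -> rows where a == 2  # inner merge on "a"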
    def unique(
        self, subset: Sequence[str] | None, *, keep: LazyUniqueKeepStrategy
    ) -> Self:
        if subset and (error := self._check_columns_exist(subset)):
            raise error
        if keep == "none":
            subset = subset or self.columns
            token = generate_temporary_column_name(n_bytes=8, columns=subset)
            ser = self.native.groupby(subset).size().rename(token)
            ser = ser[ser == 1]
            unique = ser.reset_index().drop(columns=token)
            result = self.native.merge(unique, on=subset, how="inner")
        else:
            mapped_keep = {"any": "first"}.get(keep, keep)
            result = self.native.drop_duplicates(subset=subset, keep=mapped_keep)
        return self._with_native(result)

    def sort(self, *by: str, descending: bool | Sequence[bool], nulls_last: bool) -> Self:
        if isinstance(descending, bool):
            ascending: bool | list[bool] = not descending
        else:
            ascending = [not d for d in descending]
        position = "last" if nulls_last else "first"
        return self._with_native(
            self.native.sort_values(list(by), ascending=ascending, na_position=position)
        )

    def _join_inner(
        self, other: Self, *, left_on: Sequence[str], right_on: Sequence[str], suffix: str
    ) -> dd.DataFrame:
        return self.native.merge(
            other.native,
            left_on=left_on,
            right_on=right_on,
            how="inner",
            suffixes=("", suffix),
        )
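
    # pandas-style left merges keep the right key columns in the output, while
    # Polars semantics drop them. `_join_left` removes those extra columns,
    # accounting for the suffix they receive on a name clash. Illustrative,
    # for left_on=["id"], right_on=["other_id"]: the merge emits both "id" and
    # "other_id", and "other_id" is dropped afterwards.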
    def _join_left(
        self, other: Self, *, left_on: Sequence[str], right_on: Sequence[str], suffix: str
    ) -> dd.DataFrame:
        result_native = self.native.merge(
            other.native,
            how="left",
            left_on=left_on,
            right_on=right_on,
            suffixes=("", suffix),
        )
        extra = [
            right_key if right_key not in self.columns else f"{right_key}{suffix}"
            for left_key, right_key in zip(left_on, right_on)
            if right_key != left_key
        ]
        return result_native.drop(columns=extra)
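
    # In a pandas-style outer merge the join keys are coalesced, so the right
    # keys would not survive as separate columns. Renaming them with the suffix
    # before merging preserves them, matching Polars' full-join output.
    # Illustrative, for left_on=["a"], right_on=["a"], suffix="_right": the
    # right frame's "a" becomes "a_right" prior to the merge.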
    def _join_full(
        self, other: Self, *, left_on: Sequence[str], right_on: Sequence[str], suffix: str
    ) -> dd.DataFrame:
        # dask does not retain keys post-join
        # we must append the suffix to each key before-hand

        right_on_mapper = _remap_full_join_keys(left_on, right_on, suffix)
        other_native = other.native.rename(columns=right_on_mapper)
        check_column_names_are_unique(other_native.columns)
        right_suffixed = list(right_on_mapper.values())
        return self.native.merge(
            other_native,
            left_on=left_on,
            right_on=right_suffixed,
            how="outer",
            suffixes=("", suffix),
        )
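
    # dask's `merge` has no cross-join strategy, so one is emulated: both
    # frames are given the same constant key, an inner equi-join on that key
    # yields the Cartesian product, and the helper column is dropped again.
    # The key name is generated so it cannot clash with existing columns.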
    def _join_cross(self, other: Self, *, suffix: str) -> dd.DataFrame:
        key_token = generate_temporary_column_name(
            n_bytes=8, columns=(*self.columns, *other.columns)
        )
        return (
            self.native.assign(**{key_token: 0})
            .merge(
                other.native.assign(**{key_token: 0}),
                how="inner",
                left_on=key_token,
                right_on=key_token,
                suffixes=("", suffix),
            )
            .drop(columns=key_token)
        )

    def _join_semi(
        self, other: Self, *, left_on: Sequence[str], right_on: Sequence[str]
    ) -> dd.DataFrame:
        other_native = self._join_filter_rename(
            other=other,
            columns_to_select=list(right_on),
            columns_mapping=dict(zip(right_on, left_on)),
        )
        return self.native.merge(
            other_native, how="inner", left_on=left_on, right_on=left_on
        )
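
    # An anti join keeps the left rows with no match on the right. The merge's
    # `indicator=` column tags each row as "left_only" or "both"; filtering on
    # "left_only" leaves exactly the unmatched rows. Illustrative, for left
    # a = [1, 2] and right a = [2]:
    #
    #     indicator -> ["left_only", "both"]
    #     result    -> rows where a == 1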
    def _join_anti(
        self, other: Self, *, left_on: Sequence[str], right_on: Sequence[str]
    ) -> dd.DataFrame:
        indicator_token = generate_temporary_column_name(
            n_bytes=8, columns=(*self.columns, *other.columns)
        )
        other_native = self._join_filter_rename(
            other=other,
            columns_to_select=list(right_on),
            columns_mapping=dict(zip(right_on, left_on)),
        )
        df = self.native.merge(
            other_native,
            how="left",
            indicator=indicator_token,  # pyright: ignore[reportArgumentType]
            left_on=left_on,
            right_on=left_on,
        )
        return df[df[indicator_token] == "left_only"].drop(columns=[indicator_token])

    def _join_filter_rename(
        self, other: Self, columns_to_select: list[str], columns_mapping: dict[str, str]
    ) -> dd.DataFrame:
        """Helper function to avoid creating extra columns and row duplication.

        Used in `"anti"` and `"semi"` joins.

        Notice that a native object is returned.
        """
        other_native: Incomplete = other.native
        # rename to avoid creating extra columns in join
        return (
            select_columns_by_name(other_native, columns_to_select, self._implementation)
            .rename(columns=columns_mapping)
            .drop_duplicates()
        )
    def join(
        self,
        other: Self,
        *,
        how: JoinStrategy,
        left_on: Sequence[str] | None,
        right_on: Sequence[str] | None,
        suffix: str,
    ) -> Self:
        if how == "cross":
            result = self._join_cross(other=other, suffix=suffix)

        elif left_on is None or right_on is None:  # pragma: no cover
            raise ValueError(left_on, right_on)

        elif how == "inner":
            result = self._join_inner(
                other=other, left_on=left_on, right_on=right_on, suffix=suffix
            )
        elif how == "anti":
            result = self._join_anti(other=other, left_on=left_on, right_on=right_on)
        elif how == "semi":
            result = self._join_semi(other=other, left_on=left_on, right_on=right_on)
        elif how == "left":
            result = self._join_left(
                other=other, left_on=left_on, right_on=right_on, suffix=suffix
            )
        elif how == "full":
            result = self._join_full(
                other=other, left_on=left_on, right_on=right_on, suffix=suffix
            )
        else:
            assert_never(how)
        return self._with_native(result)

    def join_asof(
        self,
        other: Self,
        *,
        left_on: str,
        right_on: str,
        by_left: Sequence[str] | None,
        by_right: Sequence[str] | None,
        strategy: AsofJoinStrategy,
        suffix: str,
    ) -> Self:
        plx = self.__native_namespace__()
        return self._with_native(
            plx.merge_asof(
                self.native,
                other.native,
                left_on=left_on,
                right_on=right_on,
                left_by=by_left,
                right_by=by_right,
                direction=strategy,
                suffixes=("", suffix),
            )
        )

    def group_by(
        self, keys: Sequence[str] | Sequence[DaskExpr], *, drop_null_keys: bool
    ) -> DaskLazyGroupBy:
        from narwhals._dask.group_by import DaskLazyGroupBy

        return DaskLazyGroupBy(self, keys, drop_null_keys=drop_null_keys)

    def tail(self, n: int) -> Self:  # pragma: no cover
        native_frame = self.native
        n_partitions = native_frame.npartitions

        if n_partitions == 1:
            return self._with_native(self.native.tail(n=n, compute=False))
        else:
            msg = "`LazyFrame.tail` is not supported for Dask backend with multiple partitions."
            raise NotImplementedError(msg)
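
    # `gather_every` takes every `n`-th row starting at `offset`. Since dask
    # has no positional slicing across partitions, a temporary row index is
    # added and rows are kept where `index >= offset` and
    # `(index - offset) % n == 0`. Illustrative, for n=2, offset=1 on a
    # 5-row frame:
    #
    #     index -> [0, 1, 2, 3, 4]
    #     kept  -> [   1,    3   ]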
    def gather_every(self, n: int, offset: int) -> Self:
        row_index_token = generate_temporary_column_name(n_bytes=8, columns=self.columns)
        plx = self.__narwhals_namespace__()
        return (
            self.with_row_index(row_index_token, order_by=None)
            .filter(
                (plx.col(row_index_token) >= offset)
                & ((plx.col(row_index_token) - offset) % n == 0)
            )
            .drop([row_index_token], strict=False)
        )
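
    # `unpivot` maps directly onto the pandas-style `melt`: narwhals' `index`
    # and `on` become `id_vars` and `value_vars`. Illustrative, for
    # index=["id"], on=["a", "b"]: each input row yields one output row per
    # value column, with `variable_name` holding "a"/"b" and `value_name`
    # holding the corresponding value.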
    def unpivot(
        self,
        on: Sequence[str] | None,
        index: Sequence[str] | None,
        variable_name: str,
        value_name: str,
    ) -> Self:
        return self._with_native(
            self.native.melt(
                id_vars=index,
                value_vars=on,
                var_name=variable_name,
                value_name=value_name,
            )
        )

    def sink_parquet(self, file: str | Path | BytesIO) -> None:
        self.native.to_parquet(file)

    explode = not_implemented()