486 lines
17 KiB
Python
486 lines
17 KiB
Python
from __future__ import annotations
|
|
|
|
from collections.abc import Iterator, Mapping, Sequence, Sized
|
|
from itertools import chain
|
|
from typing import TYPE_CHECKING, Any, Literal, Protocol, TypeVar, overload
|
|
|
|
from narwhals._compliant.typing import (
|
|
CompliantDataFrameAny,
|
|
CompliantExprT_contra,
|
|
CompliantLazyFrameAny,
|
|
CompliantSeriesT,
|
|
EagerExprT,
|
|
EagerSeriesT,
|
|
NativeFrameT,
|
|
NativeSeriesT,
|
|
)
|
|
from narwhals._translate import (
|
|
ArrowConvertible,
|
|
DictConvertible,
|
|
FromNative,
|
|
NumpyConvertible,
|
|
ToNarwhals,
|
|
ToNarwhalsT_co,
|
|
)
|
|
from narwhals._typing_compat import assert_never
|
|
from narwhals._utils import (
|
|
ValidateBackendVersion,
|
|
Version,
|
|
_StoresNative,
|
|
check_columns_exist,
|
|
is_compliant_series,
|
|
is_index_selector,
|
|
is_range,
|
|
is_sequence_like,
|
|
is_sized_multi_index_selector,
|
|
is_slice_index,
|
|
is_slice_none,
|
|
)
|
|
|
|
if TYPE_CHECKING:
|
|
from io import BytesIO
|
|
from pathlib import Path
|
|
|
|
import pandas as pd
|
|
import polars as pl
|
|
import pyarrow as pa
|
|
from typing_extensions import Self, TypeAlias
|
|
|
|
from narwhals._compliant.group_by import CompliantGroupBy, DataFrameGroupBy
|
|
from narwhals._compliant.namespace import EagerNamespace
|
|
from narwhals._translate import IntoArrowTable
|
|
from narwhals._utils import Implementation, _LimitedContext
|
|
from narwhals.dataframe import DataFrame
|
|
from narwhals.dtypes import DType
|
|
from narwhals.exceptions import ColumnNotFoundError
|
|
from narwhals.schema import Schema
|
|
from narwhals.typing import (
|
|
AsofJoinStrategy,
|
|
JoinStrategy,
|
|
LazyUniqueKeepStrategy,
|
|
MultiColSelector,
|
|
MultiIndexSelector,
|
|
PivotAgg,
|
|
SingleIndexSelector,
|
|
SizedMultiIndexSelector,
|
|
SizedMultiNameSelector,
|
|
SizeUnit,
|
|
UniqueKeepStrategy,
|
|
_2DArray,
|
|
_SliceIndex,
|
|
_SliceName,
|
|
)
|
|
|
|
# Placeholder alias for spots where a precise type has not been written yet.
Incomplete: TypeAlias = Any

__all__ = ["CompliantDataFrame", "CompliantLazyFrame", "EagerDataFrame"]

T = TypeVar("T")

# Return type of `to_dict`: column name -> compliant series (as_series=True)
# or column name -> plain list (as_series=False).
_ToDict: TypeAlias = "dict[str, CompliantSeriesT] | dict[str, list[Any]]"  # noqa: PYI047
|
|
|
|
|
|
class CompliantDataFrame(
    NumpyConvertible["_2DArray", "_2DArray"],
    DictConvertible["_ToDict[CompliantSeriesT]", Mapping[str, Any]],
    ArrowConvertible["pa.Table", "IntoArrowTable"],
    _StoresNative[NativeFrameT],
    FromNative[NativeFrameT],
    ToNarwhals[ToNarwhalsT_co],
    Sized,
    Protocol[CompliantSeriesT, CompliantExprT_contra, NativeFrameT, ToNarwhalsT_co],
):
    """Protocol for an eager, backend-compliant DataFrame wrapper.

    A compliant implementation wraps a backend-native frame
    (``_native_frame``) and exposes the operations that narwhals' public
    ``DataFrame`` delegates to.  Most members here are abstract stubs;
    only ``aggregate``, ``native``, ``_evaluate_aliases`` and
    ``_check_columns_exist`` carry concrete default behaviour.
    """

    # The wrapped backend-native frame object.
    _native_frame: NativeFrameT
    # Which backend this wrapper targets.
    _implementation: Implementation
    # narwhals API version the wrapper was created under.
    _version: Version

    def __narwhals_dataframe__(self) -> Self: ...
    def __narwhals_namespace__(self) -> Any: ...

    # --- alternate constructors -------------------------------------------
    @classmethod
    def from_arrow(cls, data: IntoArrowTable, /, *, context: _LimitedContext) -> Self: ...
    @classmethod
    def from_dict(
        cls,
        data: Mapping[str, Any],
        /,
        *,
        context: _LimitedContext,
        schema: Mapping[str, DType] | Schema | None,
    ) -> Self: ...
    @classmethod
    def from_native(cls, data: NativeFrameT, /, *, context: _LimitedContext) -> Self: ...
    @classmethod
    def from_numpy(
        cls,
        data: _2DArray,
        /,
        *,
        context: _LimitedContext,
        schema: Mapping[str, DType] | Schema | Sequence[str] | None,
    ) -> Self: ...

    def __array__(self, dtype: Any, *, copy: bool | None) -> _2DArray: ...
    # Two-axis selection: (row selector, column selector).
    def __getitem__(
        self,
        item: tuple[
            SingleIndexSelector | MultiIndexSelector[CompliantSeriesT],
            MultiColSelector[CompliantSeriesT],
        ],
    ) -> Self: ...
    def simple_select(self, *column_names: str) -> Self:
        """`select` where all args are column names."""
        ...

    def aggregate(self, *exprs: CompliantExprT_contra) -> Self:
        """`select` where all args are aggregations or literals.

        (so, no broadcasting is necessary).
        """
        # NOTE: Ignore is to avoid an intermittent false positive
        return self.select(*exprs)  # pyright: ignore[reportArgumentType]

    def _with_version(self, version: Version) -> Self: ...

    @property
    def native(self) -> NativeFrameT:
        # Direct access to the wrapped native frame.
        return self._native_frame

    # --- metadata ---------------------------------------------------------
    @property
    def columns(self) -> Sequence[str]: ...
    @property
    def schema(self) -> Mapping[str, DType]: ...
    @property
    def shape(self) -> tuple[int, int]: ...

    # --- frame operations (see the public narwhals DataFrame API) ---------
    def clone(self) -> Self: ...
    def collect(
        self, backend: Implementation | None, **kwargs: Any
    ) -> CompliantDataFrameAny: ...
    def collect_schema(self) -> Mapping[str, DType]: ...
    def drop(self, columns: Sequence[str], *, strict: bool) -> Self: ...
    def drop_nulls(self, subset: Sequence[str] | None) -> Self: ...
    def estimated_size(self, unit: SizeUnit) -> int | float: ...
    def explode(self, columns: Sequence[str]) -> Self: ...
    def filter(self, predicate: CompliantExprT_contra | Incomplete) -> Self: ...
    def gather_every(self, n: int, offset: int) -> Self: ...
    def get_column(self, name: str) -> CompliantSeriesT: ...
    def group_by(
        self,
        keys: Sequence[str] | Sequence[CompliantExprT_contra],
        *,
        drop_null_keys: bool,
    ) -> DataFrameGroupBy[Self, Any]: ...
    def head(self, n: int) -> Self: ...
    def item(self, row: int | None, column: int | str | None) -> Any: ...
    def iter_columns(self) -> Iterator[CompliantSeriesT]: ...
    def iter_rows(
        self, *, named: bool, buffer_size: int
    ) -> Iterator[tuple[Any, ...]] | Iterator[Mapping[str, Any]]: ...
    def is_unique(self) -> CompliantSeriesT: ...
    def join(
        self,
        other: Self,
        *,
        how: JoinStrategy,
        left_on: Sequence[str] | None,
        right_on: Sequence[str] | None,
        suffix: str,
    ) -> Self: ...
    def join_asof(
        self,
        other: Self,
        *,
        left_on: str,
        right_on: str,
        by_left: Sequence[str] | None,
        by_right: Sequence[str] | None,
        strategy: AsofJoinStrategy,
        suffix: str,
    ) -> Self: ...
    def lazy(self, *, backend: Implementation | None) -> CompliantLazyFrameAny: ...
    def pivot(
        self,
        on: Sequence[str],
        *,
        index: Sequence[str] | None,
        values: Sequence[str] | None,
        aggregate_function: PivotAgg | None,
        sort_columns: bool,
        separator: str,
    ) -> Self: ...
    def rename(self, mapping: Mapping[str, str]) -> Self: ...
    def row(self, index: int) -> tuple[Any, ...]: ...
    def rows(
        self, *, named: bool
    ) -> Sequence[tuple[Any, ...]] | Sequence[Mapping[str, Any]]: ...
    def sample(
        self,
        n: int | None,
        *,
        fraction: float | None,
        with_replacement: bool,
        seed: int | None,
    ) -> Self: ...
    def select(self, *exprs: CompliantExprT_contra) -> Self: ...
    def sort(
        self, *by: str, descending: bool | Sequence[bool], nulls_last: bool
    ) -> Self: ...
    def tail(self, n: int) -> Self: ...

    # --- exports ----------------------------------------------------------
    def to_arrow(self) -> pa.Table: ...
    def to_pandas(self) -> pd.DataFrame: ...
    def to_polars(self) -> pl.DataFrame: ...
    @overload
    def to_dict(self, *, as_series: Literal[True]) -> dict[str, CompliantSeriesT]: ...
    @overload
    def to_dict(self, *, as_series: Literal[False]) -> dict[str, list[Any]]: ...
    def to_dict(
        self, *, as_series: bool
    ) -> dict[str, CompliantSeriesT] | dict[str, list[Any]]: ...
    def unique(
        self,
        subset: Sequence[str] | None,
        *,
        keep: UniqueKeepStrategy,
        maintain_order: bool | None = None,
    ) -> Self: ...
    def unpivot(
        self,
        on: Sequence[str] | None,
        index: Sequence[str] | None,
        variable_name: str,
        value_name: str,
    ) -> Self: ...
    def with_columns(self, *exprs: CompliantExprT_contra) -> Self: ...
    def with_row_index(self, name: str, order_by: Sequence[str] | None) -> Self: ...
    # `write_csv(None)` returns the CSV text; writing to a file returns None.
    @overload
    def write_csv(self, file: None) -> str: ...
    @overload
    def write_csv(self, file: str | Path | BytesIO) -> None: ...
    def write_csv(self, file: str | Path | BytesIO | None) -> str | None: ...
    def write_parquet(self, file: str | Path | BytesIO) -> None: ...

    def _evaluate_aliases(self, *exprs: CompliantExprT_contra) -> list[str]:
        # Flatten the output-name lists produced by each expression, in order.
        it = (expr._evaluate_aliases(self) for expr in exprs)
        return list(chain.from_iterable(it))

    def _check_columns_exist(self, subset: Sequence[str]) -> ColumnNotFoundError | None:
        # Returns an error object (rather than raising) when any name in
        # `subset` is missing from `self.columns`; `None` when all exist.
        return check_columns_exist(subset, available=self.columns)
|
|
|
|
|
|
class CompliantLazyFrame(
    _StoresNative[NativeFrameT],
    FromNative[NativeFrameT],
    ToNarwhals[ToNarwhalsT_co],
    Protocol[CompliantExprT_contra, NativeFrameT, ToNarwhalsT_co],
):
    """Protocol for a lazy, backend-compliant frame wrapper.

    Mirrors a subset of `CompliantDataFrame`, restricted to operations
    that make sense without materializing the data.  Only ``native``,
    ``_evaluate_expr``, ``_evaluate_aliases`` and ``_check_columns_exist``
    carry concrete default behaviour; everything else is an abstract stub.
    """

    # The wrapped backend-native (lazy) frame object.
    _native_frame: NativeFrameT
    # Which backend this wrapper targets.
    _implementation: Implementation
    # narwhals API version the wrapper was created under.
    _version: Version

    def __narwhals_lazyframe__(self) -> Self: ...
    def __narwhals_namespace__(self) -> Any: ...

    @classmethod
    def from_native(cls, data: NativeFrameT, /, *, context: _LimitedContext) -> Self: ...

    def simple_select(self, *column_names: str) -> Self:
        """`select` where all args are column names."""
        ...

    def aggregate(self, *exprs: CompliantExprT_contra) -> Self:
        """`select` where all args are aggregations or literals.

        (so, no broadcasting is necessary).
        """
        ...

    def _with_version(self, version: Version) -> Self: ...

    @property
    def native(self) -> NativeFrameT:
        # Direct access to the wrapped native frame.
        return self._native_frame

    @property
    def columns(self) -> Sequence[str]: ...
    @property
    def schema(self) -> Mapping[str, DType]: ...
    def _iter_columns(self) -> Iterator[Any]: ...

    # --- frame operations (see the public narwhals LazyFrame API) ---------
    def collect(
        self, backend: Implementation | None, **kwargs: Any
    ) -> CompliantDataFrameAny: ...
    def collect_schema(self) -> Mapping[str, DType]: ...
    def drop(self, columns: Sequence[str], *, strict: bool) -> Self: ...
    def drop_nulls(self, subset: Sequence[str] | None) -> Self: ...
    def explode(self, columns: Sequence[str]) -> Self: ...
    def filter(self, predicate: CompliantExprT_contra | Incomplete) -> Self: ...
    def group_by(
        self,
        keys: Sequence[str] | Sequence[CompliantExprT_contra],
        *,
        drop_null_keys: bool,
    ) -> CompliantGroupBy[Self, CompliantExprT_contra]: ...
    def head(self, n: int) -> Self: ...
    def join(
        self,
        other: Self,
        *,
        how: JoinStrategy,
        left_on: Sequence[str] | None,
        right_on: Sequence[str] | None,
        suffix: str,
    ) -> Self: ...
    def join_asof(
        self,
        other: Self,
        *,
        left_on: str,
        right_on: str,
        by_left: Sequence[str] | None,
        by_right: Sequence[str] | None,
        strategy: AsofJoinStrategy,
        suffix: str,
    ) -> Self: ...
    def rename(self, mapping: Mapping[str, str]) -> Self: ...
    def select(self, *exprs: CompliantExprT_contra) -> Self: ...
    def sink_parquet(self, file: str | Path | BytesIO) -> None: ...
    def sort(
        self, *by: str, descending: bool | Sequence[bool], nulls_last: bool
    ) -> Self: ...
    def unique(
        self, subset: Sequence[str] | None, *, keep: LazyUniqueKeepStrategy
    ) -> Self: ...
    def unpivot(
        self,
        on: Sequence[str] | None,
        index: Sequence[str] | None,
        variable_name: str,
        value_name: str,
    ) -> Self: ...
    def with_columns(self, *exprs: CompliantExprT_contra) -> Self: ...
    def with_row_index(self, name: str, order_by: Sequence[str]) -> Self: ...

    def _evaluate_expr(self, expr: CompliantExprT_contra, /) -> Any:
        # Evaluate `expr` against this frame, requiring exactly one output.
        result = expr(self)
        assert len(result) == 1  # debug assertion  # noqa: S101
        return result[0]

    def _evaluate_aliases(self, *exprs: CompliantExprT_contra) -> list[str]:
        # Flatten the output-name lists produced by each expression, in order.
        it = (expr._evaluate_aliases(self) for expr in exprs)
        return list(chain.from_iterable(it))

    def _check_columns_exist(self, subset: Sequence[str]) -> ColumnNotFoundError | None:
        # Returns an error object (rather than raising) when any name in
        # `subset` is missing from `self.columns`; `None` when all exist.
        return check_columns_exist(subset, available=self.columns)
|
|
|
|
|
|
class EagerDataFrame(
    CompliantDataFrame[EagerSeriesT, EagerExprT, NativeFrameT, "DataFrame[NativeFrameT]"],
    CompliantLazyFrame[EagerExprT, NativeFrameT, "DataFrame[NativeFrameT]"],
    ValidateBackendVersion,
    Protocol[EagerSeriesT, EagerExprT, NativeFrameT, NativeSeriesT],
):
    """Shared behaviour for eager backends.

    An eager frame satisfies both the eager and the lazy protocols, since
    the underlying data is already materialized in memory.
    """

    @property
    def _backend_version(self) -> tuple[int, ...]:
        # Version tuple of the underlying native library for this backend.
        return self._implementation._backend_version()

    def __narwhals_namespace__(
        self,
    ) -> EagerNamespace[Self, EagerSeriesT, EagerExprT, NativeFrameT, NativeSeriesT]: ...

    def to_narwhals(self) -> DataFrame[NativeFrameT]:
        """Wrap `self` in the public, full-API narwhals `DataFrame`."""
        wrap = self._version.dataframe
        return wrap(self, level="full")

    def _with_native(
        self, df: NativeFrameT, *, validate_column_names: bool = True
    ) -> Self: ...

    def _evaluate_expr(self, expr: EagerExprT, /) -> EagerSeriesT:
        """Evaluate `expr` and ensure it has a **single** output."""
        outputs: Sequence[EagerSeriesT] = expr(self)
        assert len(outputs) == 1  # debug assertion  # noqa: S101
        return outputs[0]

    def _evaluate_into_exprs(self, *exprs: EagerExprT) -> Sequence[EagerSeriesT]:
        # Evaluate each expression in turn and concatenate all outputs.
        evaluated: list[EagerSeriesT] = []
        for expr in exprs:
            evaluated.extend(self._evaluate_into_expr(expr))
        return evaluated

    def _evaluate_into_expr(self, expr: EagerExprT, /) -> Sequence[EagerSeriesT]:
        """Return list of raw columns.

        For eager backends we alias operations at each step, so as a safety
        precaution we check here that the names of the evaluated series match
        what the various `evaluate_output_names` / `alias_output_names` calls
        told us to expect.

        Note that for PySpark / DuckDB, we are less free to liberally set
        aliases whenever we want.
        """
        expected = expr._evaluate_aliases(self)
        series = expr(self)
        actual = [s.name for s in series]
        if list(expected) != actual:  # pragma: no cover
            msg = f"Safety assertion failed, expected {expected}, got {actual}"
            raise AssertionError(msg)
        return series

    def _extract_comparand(self, other: EagerSeriesT, /) -> Any:
        """Extract native Series, broadcasting to `len(self)` if necessary."""
        ...

    @staticmethod
    def _numpy_column_names(
        data: _2DArray, columns: Sequence[str] | None, /
    ) -> list[str]:
        # Default to "column_<i>" names when none (or an empty sequence) given.
        if columns:
            return list(columns)
        return [f"column_{i}" for i in range(data.shape[1])]

    # --- primitive selectors implemented per backend -----------------------
    def _gather(self, rows: SizedMultiIndexSelector[NativeSeriesT]) -> Self: ...
    def _gather_slice(self, rows: _SliceIndex | range) -> Self: ...
    def _select_multi_index(
        self, columns: SizedMultiIndexSelector[NativeSeriesT]
    ) -> Self: ...
    def _select_multi_name(
        self, columns: SizedMultiNameSelector[NativeSeriesT]
    ) -> Self: ...
    def _select_slice_index(self, columns: _SliceIndex | range) -> Self: ...
    def _select_slice_name(self, columns: _SliceName) -> Self: ...

    def __getitem__(  # noqa: C901, PLR0912
        self,
        item: tuple[
            SingleIndexSelector | MultiIndexSelector[EagerSeriesT],
            MultiColSelector[EagerSeriesT],
        ],
    ) -> Self:
        """Two-axis selection: dispatch `(rows, columns)` to the primitives above."""
        rows, columns = item
        frame = self

        # Column axis first.
        if not is_slice_none(columns):
            if isinstance(columns, Sized) and len(columns) == 0:
                # Selecting no columns short-circuits to an empty selection.
                return frame.select()
            if is_index_selector(columns):
                if is_slice_index(columns) or is_range(columns):
                    frame = frame._select_slice_index(columns)
                elif is_compliant_series(columns):
                    frame = self._select_multi_index(columns.native)
                else:
                    frame = frame._select_multi_index(columns)
            elif isinstance(columns, slice):
                frame = frame._select_slice_name(columns)
            elif is_compliant_series(columns):
                frame = self._select_multi_name(columns.native)
            elif is_sequence_like(columns):
                frame = self._select_multi_name(columns)
            else:
                assert_never(columns)

        # Then the row axis, on the (possibly column-reduced) frame.
        if not is_slice_none(rows):
            if isinstance(rows, int):
                frame = frame._gather([rows])
            elif isinstance(rows, (slice, range)):
                frame = frame._gather_slice(rows)
            elif is_compliant_series(rows):
                frame = frame._gather(rows.native)
            elif is_sized_multi_index_selector(rows):
                frame = frame._gather(rows)
            else:
                assert_never(rows)

        return frame

    def sink_parquet(self, file: str | Path | BytesIO) -> None:
        # For an eager frame, "sinking" is just an immediate write.
        return self.write_parquet(file)
|