team-10/venv/Lib/site-packages/narwhals/_polars/dataframe.py
2025-08-02 02:00:33 +02:00

645 lines
22 KiB
Python

from __future__ import annotations
from collections.abc import Iterator, Mapping, Sequence, Sized
from typing import TYPE_CHECKING, Any, Generic, Literal, TypeVar, cast, overload
import polars as pl
from narwhals._polars.namespace import PolarsNamespace
from narwhals._polars.series import PolarsSeries
from narwhals._polars.utils import (
catch_polars_exception,
extract_args_kwargs,
native_to_narwhals_dtype,
)
from narwhals._utils import (
Implementation,
_into_arrow_table,
check_columns_exist,
convert_str_slice_to_int_slice,
is_compliant_series,
is_index_selector,
is_range,
is_sequence_like,
is_slice_index,
is_slice_none,
parse_columns_to_drop,
requires,
)
from narwhals.dependencies import is_numpy_array_1d
from narwhals.exceptions import ColumnNotFoundError
if TYPE_CHECKING:
from types import ModuleType
from typing import Callable
import pandas as pd
import pyarrow as pa
from typing_extensions import Self, TypeAlias, TypeIs
from narwhals._compliant.typing import CompliantDataFrameAny, CompliantLazyFrameAny
from narwhals._polars.expr import PolarsExpr
from narwhals._polars.group_by import PolarsGroupBy, PolarsLazyGroupBy
from narwhals._translate import IntoArrowTable
from narwhals._utils import Version, _LimitedContext
from narwhals.dataframe import DataFrame, LazyFrame
from narwhals.dtypes import DType
from narwhals.schema import Schema
from narwhals.typing import (
JoinStrategy,
MultiColSelector,
MultiIndexSelector,
PivotAgg,
SingleIndexSelector,
_2DArray,
)
# Unconstrained placeholder; used by the `_from_native_object` overloads so that
# objects which are neither a Series nor a DataFrame keep their own type.
T = TypeVar("T")
# Return-type parameter of the `Method` alias defined on the next line.
R = TypeVar("R")
Method: TypeAlias = "Callable[..., R]"
"""Generic alias representing all methods implemented via `__getattr__`.
Where `R` is the return type.
"""
# Allow-list of native Polars frame methods that the compliant frames defer to
# directly: `__getattr__` on the frame classes in this module rejects any attribute
# name that is not listed here.
INHERITED_METHODS = frozenset(
    {
        "clone",
        "drop_nulls",
        "estimated_size",
        "explode",
        "filter",
        "gather_every",
        "head",
        "is_unique",
        "item",
        "iter_rows",
        "join_asof",
        "rename",
        "row",
        "rows",
        "sample",
        "select",
        "sink_parquet",
        "sort",
        "tail",
        "to_arrow",
        "to_pandas",
        "unique",
        "with_columns",
        "write_csv",
        "write_parquet",
    }
)
# Constrained type variable for the native object a frame wraps: either an eager
# `pl.DataFrame` or a `pl.LazyFrame`.
NativePolarsFrame = TypeVar("NativePolarsFrame", pl.DataFrame, pl.LazyFrame)
class PolarsBaseFrame(Generic[NativePolarsFrame]):
    """Functionality shared by the eager and lazy Polars-backed compliant frames.

    An instance stores a native Polars frame together with the narwhals
    `Version` it was created under.
    """

    # Typing-only declarations: no implementation lives on this class. All of
    # these names appear in `INHERITED_METHODS`, which the subclasses'
    # `__getattr__` uses to forward calls to the native frame.
    drop_nulls: Method[Self]
    explode: Method[Self]
    filter: Method[Self]
    gather_every: Method[Self]
    head: Method[Self]
    join_asof: Method[Self]
    rename: Method[Self]
    select: Method[Self]
    sort: Method[Self]
    tail: Method[Self]
    unique: Method[Self]
    with_columns: Method[Self]

    _implementation = Implementation.POLARS

    def __init__(
        self,
        df: NativePolarsFrame,
        *,
        version: Version,
        validate_backend_version: bool = False,
    ) -> None:
        """Store the native frame and the narwhals version.

        Arguments:
            df: Native Polars frame to wrap.
            version: narwhals version object, reused for every frame derived
                from this one.
            validate_backend_version: If true, call
                `_validate_backend_version` once the attributes are set.
        """
        self._native_frame: NativePolarsFrame = df
        self._version = version
        if validate_backend_version:
            self._validate_backend_version()

    def _validate_backend_version(self) -> None:
        """Raise if installed version below `nw._utils.MIN_VERSIONS`.

        **Only use this when moving between backends.**
        Otherwise, the validation will have taken place already.
        """
        # Called for its side effect only; the returned version is discarded.
        self._implementation._backend_version()

    @property
    def _backend_version(self) -> tuple[int, ...]:
        """Version tuple reported by `Implementation.POLARS`."""
        return self._implementation._backend_version()

    @property
    def native(self) -> NativePolarsFrame:
        """The wrapped native Polars frame."""
        return self._native_frame

    @property
    def columns(self) -> list[str]:
        """Column names as reported by the native frame."""
        return self.native.columns

    def __narwhals_namespace__(self) -> PolarsNamespace:
        """Return a `PolarsNamespace` bound to this frame's version."""
        return PolarsNamespace(version=self._version)

    def __native_namespace__(self) -> ModuleType:
        """Return the native namespace of the implementation.

        Raises:
            AssertionError: If `_implementation` is not `Implementation.POLARS`.
        """
        if self._implementation is not Implementation.POLARS:
            msg = f"Expected polars, got: {type(self._implementation)}"  # pragma: no cover
            raise AssertionError(msg)
        return self._implementation.to_native_namespace()

    def _with_native(self, df: NativePolarsFrame) -> Self:
        """Wrap `df` in a new instance of this class, keeping the version."""
        frame_cls = self.__class__
        return frame_cls(df, version=self._version)

    def _with_version(self, version: Version) -> Self:
        """Re-wrap the same native frame under a different `version`."""
        frame_cls = self.__class__
        return frame_cls(self.native, version=version)

    @classmethod
    def from_native(cls, data: NativePolarsFrame, /, *, context: _LimitedContext) -> Self:
        """Wrap `data`, taking the version from `context`."""
        return cls(data, version=context._version)

    def _check_columns_exist(self, subset: Sequence[str]) -> ColumnNotFoundError | None:
        """Delegate to `check_columns_exist` with this frame's column names."""
        return check_columns_exist(  # pragma: no cover
            subset, available=self.columns
        )

    def simple_select(self, *column_names: str) -> Self:
        """Select columns by name through the native `select`."""
        selected = self.native.select(*column_names)
        return self._with_native(selected)

    def aggregate(self, *exprs: Any) -> Self:
        """Alias for `select` with the given expressions."""
        return self.select(*exprs)

    @property
    def schema(self) -> dict[str, DType]:
        """Same mapping as `collect_schema()`."""
        return self.collect_schema()

    def join(
        self,
        other: Self,
        *,
        how: JoinStrategy,
        left_on: Sequence[str] | None,
        right_on: Sequence[str] | None,
        suffix: str,
    ) -> Self:
        """Join with `other` through the native `join`.

        For backend versions below 0.20.29 the strategy `"full"` is passed to
        Polars as `"outer"`; every other strategy is passed through unchanged.
        """
        # The version is looked up first, independently of `how`.
        use_legacy_name = self._backend_version < (0, 20, 29) and how == "full"
        how_native = "outer" if use_legacy_name else how
        joined = self.native.join(
            other=other.native,
            how=how_native,  # type: ignore[arg-type]
            left_on=left_on,
            right_on=right_on,
            suffix=suffix,
        )
        return self._with_native(joined)

    def unpivot(
        self,
        on: Sequence[str] | None,
        index: Sequence[str] | None,
        variable_name: str,
        value_name: str,
    ) -> Self:
        """Unpivot the frame.

        Uses the native `unpivot` for backend versions from 1.0.0 onwards and
        the native `melt` (with `id_vars=index`, `value_vars=on`) below that.
        """
        if self._backend_version >= (1, 0, 0):
            reshaped = self.native.unpivot(
                on=on, index=index, variable_name=variable_name, value_name=value_name
            )
        else:
            reshaped = self.native.melt(
                id_vars=index,
                value_vars=on,
                variable_name=variable_name,
                value_name=value_name,
            )
        return self._with_native(reshaped)

    def collect_schema(self) -> dict[str, DType]:
        """Map each column name to its narwhals dtype.

        Reads the native `schema` attribute for backend versions below 1 and
        calls the native `collect_schema()` otherwise.
        """
        df = self.native
        if self._backend_version < (1,):
            native_schema = df.schema
        else:
            native_schema = df.collect_schema()
        return {
            column: native_to_narwhals_dtype(native_dtype, self._version)
            for column, native_dtype in native_schema.items()
        }

    def with_row_index(self, name: str, order_by: Sequence[str] | None) -> Self:
        """Add an index column called `name` as the first column.

        Arguments:
            name: Name of the new column.
            order_by: If None, the native `with_row_index` is used. Otherwise
                an integer range of the frame's length is built, passed through
                `sort_by(order_by)` and selected ahead of all existing columns.
        """
        frame = self.native
        if order_by is None:
            return self._with_native(frame.with_row_index(name))
        # `pl.len()` is used from 0.20.5 onwards, `pl.count()` on older versions.
        length = pl.count() if self._backend_version < (0, 20, 5) else pl.len()
        index_expr = pl.int_range(start=0, end=length).sort_by(order_by).alias(name)
        return self._with_native(frame.select(index_expr, pl.all()))
class PolarsDataFrame(PolarsBaseFrame[pl.DataFrame]):
    """Compliant eager frame wrapping a native `pl.DataFrame`."""

    # Typing-only declarations. Names that appear in `INHERITED_METHODS` are
    # forwarded to the native frame by `__getattr__`.
    # NOTE(review): `collect` is declared here but is not in `INHERITED_METHODS`,
    # so `__getattr__` raises `AttributeError` for it -- confirm this is intended.
    clone: Method[Self]
    collect: Method[CompliantDataFrameAny]
    estimated_size: Method[int | float]
    gather_every: Method[Self]
    item: Method[Any]
    iter_rows: Method[Iterator[tuple[Any, ...]] | Iterator[Mapping[str, Any]]]
    is_unique: Method[PolarsSeries]
    row: Method[tuple[Any, ...]]
    rows: Method[Sequence[tuple[Any, ...]] | Sequence[Mapping[str, Any]]]
    sample: Method[Self]
    to_arrow: Method[pa.Table]
    to_pandas: Method[pd.DataFrame]
    # NOTE: `write_csv` requires an `@overload` for `str | None`
    # Can't do that here 😟
    write_csv: Method[Any]
    write_parquet: Method[None]

    # CompliantDataFrame
    _evaluate_aliases: Any

    @classmethod
    def from_arrow(cls, data: IntoArrowTable, /, *, context: _LimitedContext) -> Self:
        """Build a frame from Arrow-compatible `data`.

        For backend versions from 1.3 onwards `data` is handed straight to the
        `pl.DataFrame` constructor; on older versions it is first converted
        with `_into_arrow_table` and then passed to `pl.from_arrow`.
        """
        if context._implementation._backend_version() >= (1, 3):
            native = pl.DataFrame(data)
        else:  # pragma: no cover
            native = cast("pl.DataFrame", pl.from_arrow(_into_arrow_table(data, context)))
        return cls.from_native(native, context=context)

    @classmethod
    def from_dict(
        cls,
        data: Mapping[str, Any],
        /,
        *,
        context: _LimitedContext,
        schema: Mapping[str, DType] | Schema | None,
    ) -> Self:
        """Build a frame from a mapping of column name to values.

        Arguments:
            data: Column data, passed to `pl.from_dict`.
            context: Supplies the narwhals version for the new frame.
            schema: Optional narwhals schema; converted with
                `Schema(...).to_polars()` unless it is None.
        """
        from narwhals.schema import Schema

        pl_schema = Schema(schema).to_polars() if schema is not None else schema
        return cls.from_native(pl.from_dict(data, pl_schema), context=context)

    @staticmethod
    def _is_native(obj: pl.DataFrame | Any) -> TypeIs[pl.DataFrame]:
        """Return True if `obj` is a `pl.DataFrame`."""
        return isinstance(obj, pl.DataFrame)

    @classmethod
    def from_numpy(
        cls,
        data: _2DArray,
        /,
        *,
        context: _LimitedContext,  # NOTE: Maybe only `Implementation`?
        schema: Mapping[str, DType] | Schema | Sequence[str] | None,
    ) -> Self:
        """Build a frame from a 2D NumPy array.

        Arguments:
            data: Array passed to `pl.from_numpy`.
            context: Supplies the narwhals version for the new frame.
            schema: A mapping or `Schema` is converted with
                `Schema(...).to_polars()`; any other value (a sequence of
                names or None) is passed to Polars unchanged.
        """
        from narwhals.schema import Schema

        pl_schema = (
            Schema(schema).to_polars()
            if isinstance(schema, (Mapping, Schema))
            else schema
        )
        return cls.from_native(pl.from_numpy(data, pl_schema), context=context)

    def to_narwhals(self) -> DataFrame[pl.DataFrame]:
        """Wrap this compliant frame in the narwhals `DataFrame` of its version."""
        return self._version.dataframe(self, level="full")

    def __repr__(self) -> str:  # pragma: no cover
        return "PolarsDataFrame"

    def __narwhals_dataframe__(self) -> Self:
        return self

    @overload
    def _from_native_object(self, obj: pl.Series) -> PolarsSeries: ...
    @overload
    def _from_native_object(self, obj: pl.DataFrame) -> Self: ...
    @overload
    def _from_native_object(self, obj: T) -> T: ...
    def _from_native_object(
        self, obj: pl.Series | pl.DataFrame | T
    ) -> Self | PolarsSeries | T:
        """Wrap a native result in its compliant counterpart.

        A `pl.Series` becomes a `PolarsSeries`, a `pl.DataFrame` becomes a new
        instance of this class, and anything else is returned unchanged.
        """
        if isinstance(obj, pl.Series):
            return PolarsSeries.from_native(obj, context=self)
        if self._is_native(obj):
            return self._with_native(obj)
        # scalar
        return obj

    def __len__(self) -> int:
        return len(self.native)

    def __getattr__(self, attr: str) -> Any:
        """Forward the methods listed in `INHERITED_METHODS` to the native frame.

        Returns a callable that normalises its arguments with
        `extract_args_kwargs`, calls the native method and wraps the result
        with `_from_native_object`.

        Raises:
            AttributeError: If `attr` is not in `INHERITED_METHODS`.
        """
        if attr not in INHERITED_METHODS:  # pragma: no cover
            msg = f"{self.__class__.__name__} has not attribute '{attr}'."
            raise AttributeError(msg)

        def func(*args: Any, **kwargs: Any) -> Any:
            pos, kwds = extract_args_kwargs(args, kwargs)
            try:
                return self._from_native_object(getattr(self.native, attr)(*pos, **kwds))
            except pl.exceptions.ColumnNotFoundError as e:  # pragma: no cover
                # Re-raise as the narwhals error, listing the available columns.
                msg = f"{e!s}\n\nHint: Did you mean one of these columns: {self.columns}?"
                raise ColumnNotFoundError(msg) from e
            except Exception as e:  # noqa: BLE001
                raise catch_polars_exception(e) from None

        return func

    def __array__(
        self, dtype: Any | None = None, *, copy: bool | None = None
    ) -> _2DArray:
        """Convert the frame to an array via the native `__array__`.

        Arguments:
            dtype: Requested dtype, passed to the native `__array__`.
            copy: Copy behaviour requested by the caller. Forwarded to the
                native `__array__` on 'polars>=0.20.28'.

        Raises:
            NotImplementedError: If `copy` is not None and the backend version
                is below 0.20.28.
        """
        if self._backend_version < (0, 20, 28):
            if copy is not None:
                msg = "`copy` in `__array__` is only supported for 'polars>=0.20.28'"
                raise NotImplementedError(msg)
            return self.native.__array__(dtype)
        # Forward `copy` so an explicit request is honoured instead of being
        # validated and then silently dropped.
        return self.native.__array__(dtype, copy=copy)

    def to_numpy(self, dtype: Any = None, *, copy: bool | None = None) -> _2DArray:
        """Return the native `to_numpy()` result.

        `dtype` and `copy` are accepted but are not forwarded to Polars.
        """
        return self.native.to_numpy()

    @property
    def shape(self) -> tuple[int, int]:
        """Shape tuple reported by the native frame."""
        return self.native.shape

    def __getitem__(  # noqa: C901, PLR0912
        self,
        item: tuple[
            SingleIndexSelector | MultiIndexSelector[PolarsSeries],
            MultiColSelector[PolarsSeries],
        ],
    ) -> Any:
        """Select by a `(rows, columns)` pair.

        For backend versions above 0.20.30 the selection is delegated to the
        native `__getitem__` (after unwrapping compliant series) and the result
        is wrapped with `_from_native_object`. Older versions use a manual
        fallback that selects columns first and then rows.

        Raises:
            AssertionError: In the fallback, if a selector has an unexpected type.
        """
        rows, columns = item
        if self._backend_version > (0, 20, 30):
            rows_native = rows.native if is_compliant_series(rows) else rows
            columns_native = columns.native if is_compliant_series(columns) else columns
            selector = rows_native, columns_native
            selected = self.native.__getitem__(selector)  # type: ignore[index]
            return self._from_native_object(selected)
        else:  # pragma: no cover
            # TODO(marco): we can delete this branch after Polars==0.20.30 becomes the minimum
            # Polars version we support
            # This mostly mirrors the logic in `EagerDataFrame.__getitem__`.
            rows = list(rows) if isinstance(rows, tuple) else rows
            columns = list(columns) if isinstance(columns, tuple) else columns
            if is_numpy_array_1d(columns):
                columns = columns.tolist()

            native = self.native
            # Column selection.
            if not is_slice_none(columns):
                if isinstance(columns, Sized) and len(columns) == 0:
                    # Empty column selector: return an empty `select()`.
                    return self.select()
                if is_index_selector(columns):
                    if is_slice_index(columns) or is_range(columns):
                        native = native.select(
                            self.columns[slice(columns.start, columns.stop, columns.step)]
                        )
                    elif is_compliant_series(columns):
                        native = native[:, columns.native.to_list()]
                    else:
                        native = native[:, columns]
                elif isinstance(columns, slice):
                    # Slice with non-integer bounds: translate to positions first.
                    native = native.select(
                        self.columns[
                            slice(*convert_str_slice_to_int_slice(columns, self.columns))
                        ]
                    )
                elif is_compliant_series(columns):
                    native = native.select(columns.native.to_list())
                elif is_sequence_like(columns):
                    native = native.select(columns)
                else:
                    msg = f"Unreachable code, got unexpected type: {type(columns)}"
                    raise AssertionError(msg)

            # Row selection, applied to the column-selected frame.
            if not is_slice_none(rows):
                if isinstance(rows, int):
                    # Wrapped in a list so the result stays a frame.
                    native = native[[rows], :]
                elif isinstance(rows, (slice, range)):
                    native = native[rows, :]
                elif is_compliant_series(rows):
                    native = native[rows.native, :]
                elif is_sequence_like(rows):
                    native = native[rows, :]
                else:
                    msg = f"Unreachable code, got unexpected type: {type(rows)}"
                    raise AssertionError(msg)

            return self._with_native(native)

    def get_column(self, name: str) -> PolarsSeries:
        """Return the column `name` as a `PolarsSeries`."""
        return PolarsSeries.from_native(self.native.get_column(name), context=self)

    def iter_columns(self) -> Iterator[PolarsSeries]:
        """Yield every column of the native frame as a `PolarsSeries`."""
        for series in self.native.iter_columns():
            yield PolarsSeries.from_native(series, context=self)

    def lazy(self, *, backend: Implementation | None = None) -> CompliantLazyFrameAny:
        """Convert to a lazy frame of the requested backend.

        Arguments:
            backend: None or `Implementation.POLARS` gives a `PolarsLazyFrame`;
                `Implementation.DUCKDB`, `Implementation.DASK` and ibis
                backends give the corresponding lazy frame of that backend.

        Raises:
            AssertionError: If `backend` matches none of the handled cases.
        """
        if backend is None or backend is Implementation.POLARS:
            return PolarsLazyFrame.from_native(self.native.lazy(), context=self)
        elif backend is Implementation.DUCKDB:
            import duckdb  # ignore-banned-import

            from narwhals._duckdb.dataframe import DuckDBLazyFrame

            # NOTE: (F841) is a false positive
            # The local must keep the name `df`: `duckdb.table("df")` below
            # refers to it by that name.
            df = self.native  # noqa: F841
            return DuckDBLazyFrame(
                duckdb.table("df"), validate_backend_version=True, version=self._version
            )
        elif backend is Implementation.DASK:
            import dask.dataframe as dd  # ignore-banned-import

            from narwhals._dask.dataframe import DaskLazyFrame

            return DaskLazyFrame(
                dd.from_pandas(self.native.to_pandas()),
                validate_backend_version=True,
                version=self._version,
            )
        elif backend.is_ibis():
            import ibis  # ignore-banned-import

            from narwhals._ibis.dataframe import IbisLazyFrame

            return IbisLazyFrame(
                ibis.memtable(self.native, columns=self.columns),
                validate_backend_version=True,
                version=self._version,
            )
        raise AssertionError  # pragma: no cover

    @overload
    def to_dict(self, *, as_series: Literal[True]) -> dict[str, PolarsSeries]: ...
    @overload
    def to_dict(self, *, as_series: Literal[False]) -> dict[str, list[Any]]: ...
    def to_dict(
        self, *, as_series: bool
    ) -> dict[str, PolarsSeries] | dict[str, list[Any]]:
        """Convert to a dictionary keyed by column name.

        Arguments:
            as_series: If true, values are `PolarsSeries`; otherwise the
                native `to_dict(as_series=False)` result is returned.
        """
        if as_series:
            return {
                name: PolarsSeries.from_native(col, context=self)
                for name, col in self.native.to_dict().items()
            }
        else:
            return self.native.to_dict(as_series=False)

    def group_by(
        self, keys: Sequence[str] | Sequence[PolarsExpr], *, drop_null_keys: bool
    ) -> PolarsGroupBy:
        """Create a `PolarsGroupBy` over `keys` for this frame."""
        from narwhals._polars.group_by import PolarsGroupBy

        return PolarsGroupBy(self, keys, drop_null_keys=drop_null_keys)

    def drop(self, columns: Sequence[str], *, strict: bool) -> Self:
        """Drop columns.

        `parse_columns_to_drop` resolves `columns` (taking `strict` into
        account) and the resulting names are passed to the native `drop`.
        """
        to_drop = parse_columns_to_drop(self, columns, strict=strict)
        return self._with_native(self.native.drop(to_drop))

    @requires.backend_version((1,))
    def pivot(
        self,
        on: Sequence[str],
        *,
        index: Sequence[str] | None,
        values: Sequence[str] | None,
        aggregate_function: PivotAgg | None,
        sort_columns: bool,
        separator: str,
    ) -> Self:
        """Pivot through the native `pivot`, forwarding all arguments.

        Any exception raised by Polars is converted with
        `catch_polars_exception` before being re-raised.
        """
        try:
            result = self.native.pivot(
                on,
                index=index,
                values=values,
                aggregate_function=aggregate_function,
                sort_columns=sort_columns,
                separator=separator,
            )
        except Exception as e:  # noqa: BLE001
            raise catch_polars_exception(e) from None
        return self._from_native_object(result)

    def to_polars(self) -> pl.DataFrame:
        """Return the wrapped native `pl.DataFrame`."""
        return self.native

    def join(
        self,
        other: Self,
        *,
        how: JoinStrategy,
        left_on: Sequence[str] | None,
        right_on: Sequence[str] | None,
        suffix: str,
    ) -> Self:
        """Join via the base-class implementation.

        Any exception raised during the join is converted with
        `catch_polars_exception` before being re-raised.
        """
        try:
            return super().join(
                other=other, how=how, left_on=left_on, right_on=right_on, suffix=suffix
            )
        except Exception as e:  # noqa: BLE001
            raise catch_polars_exception(e) from None
class PolarsLazyFrame(PolarsBaseFrame[pl.LazyFrame]):
    """Compliant lazy frame wrapping a native `pl.LazyFrame`."""

    # CompliantLazyFrame
    # Typing-only declarations; `sink_parquet` is forwarded by `__getattr__`.
    sink_parquet: Method[None]
    _evaluate_expr: Any
    _evaluate_aliases: Any

    @staticmethod
    def _is_native(obj: pl.LazyFrame | Any) -> TypeIs[pl.LazyFrame]:
        """Return True if `obj` is a `pl.LazyFrame`."""
        return isinstance(obj, pl.LazyFrame)

    def to_narwhals(self) -> LazyFrame[pl.LazyFrame]:
        """Wrap this compliant frame in the narwhals `LazyFrame` of its version."""
        return self._version.lazyframe(self, level="lazy")

    def __repr__(self) -> str:  # pragma: no cover
        return "PolarsLazyFrame"

    def __narwhals_lazyframe__(self) -> Self:
        return self

    def __getattr__(self, attr: str) -> Any:
        """Forward the methods listed in `INHERITED_METHODS` to the native frame.

        Returns a callable that normalises its arguments with
        `extract_args_kwargs` and calls the native method. A `pl.LazyFrame`
        result is wrapped in a new instance of this class; any other result
        (for example the value returned by `sink_parquet`, declared above as
        `Method[None]`) is returned unchanged.

        Raises:
            AttributeError: If `attr` is not in `INHERITED_METHODS`.
        """
        if attr not in INHERITED_METHODS:  # pragma: no cover
            msg = f"{self.__class__.__name__} has not attribute '{attr}'."
            raise AttributeError(msg)

        def func(*args: Any, **kwargs: Any) -> Any:
            pos, kwds = extract_args_kwargs(args, kwargs)
            try:
                result = getattr(self.native, attr)(*pos, **kwds)
            except pl.exceptions.ColumnNotFoundError as e:  # pragma: no cover
                raise ColumnNotFoundError(str(e)) from e
            # Only re-wrap lazy frames: wrapping a non-frame result would
            # produce a `PolarsLazyFrame` whose `native` is not a frame.
            if self._is_native(result):
                return self._with_native(result)
            return result

        return func

    def _iter_columns(self) -> Iterator[PolarsSeries]:  # pragma: no cover
        """Collect with the Polars backend and yield the resulting columns."""
        yield from self.collect(self._implementation).iter_columns()

    def collect_schema(self) -> dict[str, DType]:
        """Return the base-class schema mapping.

        Any exception raised while resolving the schema is converted with
        `catch_polars_exception` before being re-raised.
        """
        try:
            return super().collect_schema()
        except Exception as e:  # noqa: BLE001
            raise catch_polars_exception(e) from None

    def collect(
        self, backend: Implementation | None, **kwargs: Any
    ) -> CompliantDataFrameAny:
        """Collect the lazy frame and wrap the result for `backend`.

        Arguments:
            backend: None or `Implementation.POLARS` gives a `PolarsDataFrame`;
                `Implementation.PANDAS` converts with `to_pandas()` into a
                `PandasLikeDataFrame`; `Implementation.PYARROW` converts with
                `to_arrow()` into an `ArrowDataFrame`.
            **kwargs: Forwarded to the native `collect`.

        Raises:
            ValueError: If `backend` is none of the values listed above.
        """
        try:
            result = self.native.collect(**kwargs)
        except Exception as e:  # noqa: BLE001
            raise catch_polars_exception(e) from None

        if backend is None or backend is Implementation.POLARS:
            return PolarsDataFrame.from_native(result, context=self)

        if backend is Implementation.PANDAS:
            from narwhals._pandas_like.dataframe import PandasLikeDataFrame

            return PandasLikeDataFrame(
                result.to_pandas(),
                implementation=Implementation.PANDAS,
                validate_backend_version=True,
                version=self._version,
                validate_column_names=False,
            )

        if backend is Implementation.PYARROW:
            from narwhals._arrow.dataframe import ArrowDataFrame

            return ArrowDataFrame(
                result.to_arrow(),
                validate_backend_version=True,
                version=self._version,
                validate_column_names=False,
            )

        msg = f"Unsupported `backend` value: {backend}"  # pragma: no cover
        raise ValueError(msg)  # pragma: no cover

    def group_by(
        self, keys: Sequence[str] | Sequence[PolarsExpr], *, drop_null_keys: bool
    ) -> PolarsLazyGroupBy:
        """Create a `PolarsLazyGroupBy` over `keys` for this frame."""
        from narwhals._polars.group_by import PolarsLazyGroupBy

        return PolarsLazyGroupBy(self, keys, drop_null_keys=drop_null_keys)

    def drop(self, columns: Sequence[str], *, strict: bool) -> Self:
        """Drop columns through the native `drop`.

        `strict` is forwarded to Polars only for backend versions from 1.0.0
        onwards; on older versions the native `drop` is called without it.
        """
        if self._backend_version < (1, 0, 0):
            return self._with_native(self.native.drop(columns))
        return self._with_native(self.native.drop(columns, strict=strict))