311 lines
11 KiB
Python
311 lines
11 KiB
Python
"""Almost entirely complete, generic `selectors` implementation."""
|
|
|
|
from __future__ import annotations
|
|
|
|
import re
|
|
from functools import partial
|
|
from typing import TYPE_CHECKING, Protocol, TypeVar, overload
|
|
|
|
from narwhals._compliant.expr import CompliantExpr
|
|
from narwhals._utils import (
|
|
_parse_time_unit_and_time_zone,
|
|
dtype_matches_time_unit_and_time_zone,
|
|
get_column_names,
|
|
is_compliant_dataframe,
|
|
)
|
|
|
|
if TYPE_CHECKING:
|
|
from collections.abc import Collection, Iterable, Iterator, Sequence
|
|
from datetime import timezone
|
|
|
|
from typing_extensions import Self, TypeAlias, TypeIs
|
|
|
|
from narwhals._compliant.expr import NativeExpr
|
|
from narwhals._compliant.typing import (
|
|
CompliantDataFrameAny,
|
|
CompliantExprAny,
|
|
CompliantFrameAny,
|
|
CompliantLazyFrameAny,
|
|
CompliantSeriesAny,
|
|
CompliantSeriesOrNativeExprAny,
|
|
EvalNames,
|
|
EvalSeries,
|
|
)
|
|
from narwhals._utils import Implementation, Version, _LimitedContext
|
|
from narwhals.dtypes import DType
|
|
from narwhals.typing import TimeUnit
|
|
|
|
__all__ = [
|
|
"CompliantSelector",
|
|
"CompliantSelectorNamespace",
|
|
"EagerSelectorNamespace",
|
|
"LazySelectorNamespace",
|
|
]
|
|
|
|
|
|
SeriesOrExprT = TypeVar("SeriesOrExprT", bound="CompliantSeriesOrNativeExprAny")
|
|
SeriesT = TypeVar("SeriesT", bound="CompliantSeriesAny")
|
|
ExprT = TypeVar("ExprT", bound="NativeExpr")
|
|
FrameT = TypeVar("FrameT", bound="CompliantFrameAny")
|
|
DataFrameT = TypeVar("DataFrameT", bound="CompliantDataFrameAny")
|
|
LazyFrameT = TypeVar("LazyFrameT", bound="CompliantLazyFrameAny")
|
|
SelectorOrExpr: TypeAlias = (
|
|
"CompliantSelector[FrameT, SeriesOrExprT] | CompliantExpr[FrameT, SeriesOrExprT]"
|
|
)
|
|
|
|
|
|
class CompliantSelectorNamespace(Protocol[FrameT, SeriesOrExprT]):
|
|
_implementation: Implementation
|
|
_version: Version
|
|
|
|
@classmethod
|
|
def from_namespace(cls, context: _LimitedContext, /) -> Self:
|
|
obj = cls.__new__(cls)
|
|
obj._implementation = context._implementation
|
|
obj._version = context._version
|
|
return obj
|
|
|
|
@property
|
|
def _selector(self) -> type[CompliantSelector[FrameT, SeriesOrExprT]]: ...
|
|
|
|
def _iter_columns(self, df: FrameT, /) -> Iterator[SeriesOrExprT]: ...
|
|
|
|
def _iter_schema(self, df: FrameT, /) -> Iterator[tuple[str, DType]]: ...
|
|
|
|
def _iter_columns_dtypes(
|
|
self, df: FrameT, /
|
|
) -> Iterator[tuple[SeriesOrExprT, DType]]: ...
|
|
|
|
def _iter_columns_names(self, df: FrameT, /) -> Iterator[tuple[SeriesOrExprT, str]]:
|
|
yield from zip(self._iter_columns(df), df.columns)
|
|
|
|
def _is_dtype(
|
|
self: CompliantSelectorNamespace[FrameT, SeriesOrExprT], dtype: type[DType], /
|
|
) -> CompliantSelector[FrameT, SeriesOrExprT]:
|
|
def series(df: FrameT) -> Sequence[SeriesOrExprT]:
|
|
return [
|
|
ser for ser, tp in self._iter_columns_dtypes(df) if isinstance(tp, dtype)
|
|
]
|
|
|
|
def names(df: FrameT) -> Sequence[str]:
|
|
return [name for name, tp in self._iter_schema(df) if isinstance(tp, dtype)]
|
|
|
|
return self._selector.from_callables(series, names, context=self)
|
|
|
|
def by_dtype(
|
|
self, dtypes: Collection[DType | type[DType]]
|
|
) -> CompliantSelector[FrameT, SeriesOrExprT]:
|
|
def series(df: FrameT) -> Sequence[SeriesOrExprT]:
|
|
return [ser for ser, tp in self._iter_columns_dtypes(df) if tp in dtypes]
|
|
|
|
def names(df: FrameT) -> Sequence[str]:
|
|
return [name for name, tp in self._iter_schema(df) if tp in dtypes]
|
|
|
|
return self._selector.from_callables(series, names, context=self)
|
|
|
|
def matches(self, pattern: str) -> CompliantSelector[FrameT, SeriesOrExprT]:
|
|
p = re.compile(pattern)
|
|
|
|
def series(df: FrameT) -> Sequence[SeriesOrExprT]:
|
|
if (
|
|
is_compliant_dataframe(df)
|
|
and not self._implementation.is_duckdb()
|
|
and not self._implementation.is_ibis()
|
|
):
|
|
return [df.get_column(col) for col in df.columns if p.search(col)]
|
|
|
|
return [ser for ser, name in self._iter_columns_names(df) if p.search(name)]
|
|
|
|
def names(df: FrameT) -> Sequence[str]:
|
|
return [col for col in df.columns if p.search(col)]
|
|
|
|
return self._selector.from_callables(series, names, context=self)
|
|
|
|
def numeric(self) -> CompliantSelector[FrameT, SeriesOrExprT]:
|
|
def series(df: FrameT) -> Sequence[SeriesOrExprT]:
|
|
return [ser for ser, tp in self._iter_columns_dtypes(df) if tp.is_numeric()]
|
|
|
|
def names(df: FrameT) -> Sequence[str]:
|
|
return [name for name, tp in self._iter_schema(df) if tp.is_numeric()]
|
|
|
|
return self._selector.from_callables(series, names, context=self)
|
|
|
|
def categorical(self) -> CompliantSelector[FrameT, SeriesOrExprT]:
|
|
return self._is_dtype(self._version.dtypes.Categorical)
|
|
|
|
def string(self) -> CompliantSelector[FrameT, SeriesOrExprT]:
|
|
return self._is_dtype(self._version.dtypes.String)
|
|
|
|
def boolean(self) -> CompliantSelector[FrameT, SeriesOrExprT]:
|
|
return self._is_dtype(self._version.dtypes.Boolean)
|
|
|
|
def all(self) -> CompliantSelector[FrameT, SeriesOrExprT]:
|
|
def series(df: FrameT) -> Sequence[SeriesOrExprT]:
|
|
return list(self._iter_columns(df))
|
|
|
|
return self._selector.from_callables(series, get_column_names, context=self)
|
|
|
|
def datetime(
|
|
self,
|
|
time_unit: TimeUnit | Iterable[TimeUnit] | None,
|
|
time_zone: str | timezone | Iterable[str | timezone | None] | None,
|
|
) -> CompliantSelector[FrameT, SeriesOrExprT]:
|
|
time_units, time_zones = _parse_time_unit_and_time_zone(time_unit, time_zone)
|
|
matches = partial(
|
|
dtype_matches_time_unit_and_time_zone,
|
|
dtypes=self._version.dtypes,
|
|
time_units=time_units,
|
|
time_zones=time_zones,
|
|
)
|
|
|
|
def series(df: FrameT) -> Sequence[SeriesOrExprT]:
|
|
return [ser for ser, tp in self._iter_columns_dtypes(df) if matches(tp)]
|
|
|
|
def names(df: FrameT) -> Sequence[str]:
|
|
return [name for name, tp in self._iter_schema(df) if matches(tp)]
|
|
|
|
return self._selector.from_callables(series, names, context=self)
|
|
|
|
|
|
class EagerSelectorNamespace(
|
|
CompliantSelectorNamespace[DataFrameT, SeriesT], Protocol[DataFrameT, SeriesT]
|
|
):
|
|
def _iter_schema(self, df: DataFrameT, /) -> Iterator[tuple[str, DType]]:
|
|
for ser in self._iter_columns(df):
|
|
yield ser.name, ser.dtype
|
|
|
|
def _iter_columns(self, df: DataFrameT, /) -> Iterator[SeriesT]:
|
|
yield from df.iter_columns()
|
|
|
|
def _iter_columns_dtypes(self, df: DataFrameT, /) -> Iterator[tuple[SeriesT, DType]]:
|
|
for ser in self._iter_columns(df):
|
|
yield ser, ser.dtype
|
|
|
|
|
|
class LazySelectorNamespace(
|
|
CompliantSelectorNamespace[LazyFrameT, ExprT], Protocol[LazyFrameT, ExprT]
|
|
):
|
|
def _iter_schema(self, df: LazyFrameT) -> Iterator[tuple[str, DType]]:
|
|
yield from df.schema.items()
|
|
|
|
def _iter_columns(self, df: LazyFrameT) -> Iterator[ExprT]:
|
|
yield from df._iter_columns()
|
|
|
|
def _iter_columns_dtypes(self, df: LazyFrameT, /) -> Iterator[tuple[ExprT, DType]]:
|
|
yield from zip(self._iter_columns(df), df.schema.values())
|
|
|
|
|
|
class CompliantSelector(
|
|
CompliantExpr[FrameT, SeriesOrExprT], Protocol[FrameT, SeriesOrExprT]
|
|
):
|
|
_call: EvalSeries[FrameT, SeriesOrExprT]
|
|
_function_name: str
|
|
_implementation: Implementation
|
|
_version: Version
|
|
|
|
@classmethod
|
|
def from_callables(
|
|
cls,
|
|
call: EvalSeries[FrameT, SeriesOrExprT],
|
|
evaluate_output_names: EvalNames[FrameT],
|
|
*,
|
|
context: _LimitedContext,
|
|
) -> Self:
|
|
obj = cls.__new__(cls)
|
|
obj._call = call
|
|
obj._evaluate_output_names = evaluate_output_names
|
|
obj._alias_output_names = None
|
|
obj._implementation = context._implementation
|
|
obj._version = context._version
|
|
return obj
|
|
|
|
@property
|
|
def selectors(self) -> CompliantSelectorNamespace[FrameT, SeriesOrExprT]:
|
|
return self.__narwhals_namespace__().selectors
|
|
|
|
def _to_expr(self) -> CompliantExpr[FrameT, SeriesOrExprT]: ...
|
|
|
|
def _is_selector(
|
|
self, other: Self | CompliantExpr[FrameT, SeriesOrExprT]
|
|
) -> TypeIs[CompliantSelector[FrameT, SeriesOrExprT]]:
|
|
return isinstance(other, type(self))
|
|
|
|
@overload
|
|
def __sub__(self, other: Self) -> Self: ...
|
|
@overload
|
|
def __sub__(
|
|
self, other: CompliantExpr[FrameT, SeriesOrExprT]
|
|
) -> CompliantExpr[FrameT, SeriesOrExprT]: ...
|
|
def __sub__(
|
|
self, other: SelectorOrExpr[FrameT, SeriesOrExprT]
|
|
) -> SelectorOrExpr[FrameT, SeriesOrExprT]:
|
|
if self._is_selector(other):
|
|
|
|
def series(df: FrameT) -> Sequence[SeriesOrExprT]:
|
|
lhs_names, rhs_names = _eval_lhs_rhs(df, self, other)
|
|
return [
|
|
x for x, name in zip(self(df), lhs_names) if name not in rhs_names
|
|
]
|
|
|
|
def names(df: FrameT) -> Sequence[str]:
|
|
lhs_names, rhs_names = _eval_lhs_rhs(df, self, other)
|
|
return [x for x in lhs_names if x not in rhs_names]
|
|
|
|
return self.selectors._selector.from_callables(series, names, context=self)
|
|
return self._to_expr() - other
|
|
|
|
@overload
|
|
def __or__(self, other: Self) -> Self: ...
|
|
@overload
|
|
def __or__(
|
|
self, other: CompliantExpr[FrameT, SeriesOrExprT]
|
|
) -> CompliantExpr[FrameT, SeriesOrExprT]: ...
|
|
def __or__(
|
|
self, other: SelectorOrExpr[FrameT, SeriesOrExprT]
|
|
) -> SelectorOrExpr[FrameT, SeriesOrExprT]:
|
|
if self._is_selector(other):
|
|
|
|
def series(df: FrameT) -> Sequence[SeriesOrExprT]:
|
|
lhs_names, rhs_names = _eval_lhs_rhs(df, self, other)
|
|
return [
|
|
*(x for x, name in zip(self(df), lhs_names) if name not in rhs_names),
|
|
*other(df),
|
|
]
|
|
|
|
def names(df: FrameT) -> Sequence[str]:
|
|
lhs_names, rhs_names = _eval_lhs_rhs(df, self, other)
|
|
return [*(x for x in lhs_names if x not in rhs_names), *rhs_names]
|
|
|
|
return self.selectors._selector.from_callables(series, names, context=self)
|
|
return self._to_expr() | other
|
|
|
|
@overload
|
|
def __and__(self, other: Self) -> Self: ...
|
|
@overload
|
|
def __and__(
|
|
self, other: CompliantExpr[FrameT, SeriesOrExprT]
|
|
) -> CompliantExpr[FrameT, SeriesOrExprT]: ...
|
|
def __and__(
|
|
self, other: SelectorOrExpr[FrameT, SeriesOrExprT]
|
|
) -> SelectorOrExpr[FrameT, SeriesOrExprT]:
|
|
if self._is_selector(other):
|
|
|
|
def series(df: FrameT) -> Sequence[SeriesOrExprT]:
|
|
lhs_names, rhs_names = _eval_lhs_rhs(df, self, other)
|
|
return [x for x, name in zip(self(df), lhs_names) if name in rhs_names]
|
|
|
|
def names(df: FrameT) -> Sequence[str]:
|
|
lhs_names, rhs_names = _eval_lhs_rhs(df, self, other)
|
|
return [x for x in lhs_names if x in rhs_names]
|
|
|
|
return self.selectors._selector.from_callables(series, names, context=self)
|
|
return self._to_expr() & other
|
|
|
|
def __invert__(self) -> CompliantSelector[FrameT, SeriesOrExprT]:
|
|
return self.selectors.all() - self
|
|
|
|
|
|
def _eval_lhs_rhs(
|
|
df: CompliantFrameAny, lhs: CompliantExprAny, rhs: CompliantExprAny
|
|
) -> tuple[Sequence[str], Sequence[str]]:
|
|
return lhs._evaluate_output_names(df), rhs._evaluate_output_names(df)
|