team-10/venv/Lib/site-packages/narwhals/_compliant/namespace.py
2025-08-02 02:00:33 +02:00

211 lines
7.2 KiB
Python

from __future__ import annotations
from functools import partial
from typing import TYPE_CHECKING, Any, Protocol, overload
from narwhals._compliant.typing import (
CompliantExprT,
CompliantFrameT,
CompliantLazyFrameT,
DepthTrackingExprT,
EagerDataFrameT,
EagerExprT,
EagerSeriesT,
LazyExprT,
NativeFrameT,
NativeFrameT_co,
NativeSeriesT,
)
from narwhals._utils import (
exclude_column_names,
get_column_names,
passthrough_column_names,
)
from narwhals.dependencies import is_numpy_array_2d
if TYPE_CHECKING:
from collections.abc import Container, Iterable, Mapping, Sequence
from typing_extensions import TypeAlias
from narwhals._compliant.selectors import CompliantSelectorNamespace
from narwhals._compliant.when_then import CompliantWhen, EagerWhen
from narwhals._utils import Implementation, Version
from narwhals.dtypes import DType
from narwhals.schema import Schema
from narwhals.typing import (
ConcatMethod,
Into1DArray,
IntoDType,
NonNestedLiteral,
_2DArray,
)
Incomplete: TypeAlias = Any
__all__ = [
"CompliantNamespace",
"DepthTrackingNamespace",
"EagerNamespace",
"LazyNamespace",
]
class CompliantNamespace(Protocol[CompliantFrameT, CompliantExprT]):
    """Structural interface of a backend namespace.

    Generic over the backend's frame type and expression type. The
    column-selection helpers (`all`, `col`, `exclude`, `nth`) carry default
    bodies that delegate to the expression class exposed through `_expr`;
    every other member is only declared here.
    """

    # Declared attributes; no value is assigned in this protocol.
    _implementation: Implementation
    _version: Version

    # --- Column selection: default bodies built on `_expr` -----------------

    def all(self) -> CompliantExprT:
        """Return `_expr.from_column_names(get_column_names, context=self)`."""
        from_names = self._expr.from_column_names
        return from_names(get_column_names, context=self)

    def col(self, *column_names: str) -> CompliantExprT:
        """Return an expression for `column_names`.

        The names are wrapped with `passthrough_column_names` and handed to
        `_expr.from_column_names` together with this namespace as context.
        """
        from_names = self._expr.from_column_names
        return from_names(passthrough_column_names(column_names), context=self)

    def exclude(self, excluded_names: Container[str]) -> CompliantExprT:
        """Return an expression built from `exclude_column_names`.

        `excluded_names` is pre-bound as the `names` keyword via
        `functools.partial` before being passed to
        `_expr.from_column_names`.
        """
        from_names = self._expr.from_column_names
        names_filter = partial(exclude_column_names, names=excluded_names)
        return from_names(names_filter, context=self)

    def nth(self, *column_indices: int) -> CompliantExprT:
        """Return `_expr.from_column_indices(*column_indices, context=self)`."""
        from_indices = self._expr.from_column_indices
        return from_indices(*column_indices, context=self)

    # --- Declared without a body in this protocol ---------------------------

    def len(self) -> CompliantExprT: ...

    def lit(
        self, value: NonNestedLiteral, dtype: IntoDType | None
    ) -> CompliantExprT: ...

    def all_horizontal(
        self, *exprs: CompliantExprT, ignore_nulls: bool
    ) -> CompliantExprT: ...

    def any_horizontal(
        self, *exprs: CompliantExprT, ignore_nulls: bool
    ) -> CompliantExprT: ...

    def sum_horizontal(self, *exprs: CompliantExprT) -> CompliantExprT: ...

    def mean_horizontal(self, *exprs: CompliantExprT) -> CompliantExprT: ...

    def min_horizontal(self, *exprs: CompliantExprT) -> CompliantExprT: ...

    def max_horizontal(self, *exprs: CompliantExprT) -> CompliantExprT: ...

    def concat(
        self, items: Iterable[CompliantFrameT], *, how: ConcatMethod
    ) -> CompliantFrameT: ...

    def when(
        self, predicate: CompliantExprT
    ) -> CompliantWhen[CompliantFrameT, Incomplete, CompliantExprT]: ...

    def concat_str(
        self, *exprs: CompliantExprT, separator: str, ignore_nulls: bool
    ) -> CompliantExprT: ...

    @property
    def selectors(self) -> CompliantSelectorNamespace[Any, Any]: ...

    @property
    def _expr(self) -> type[CompliantExprT]: ...

    def coalesce(self, *exprs: CompliantExprT) -> CompliantExprT: ...
class DepthTrackingNamespace(
    CompliantNamespace[CompliantFrameT, DepthTrackingExprT],
    Protocol[CompliantFrameT, DepthTrackingExprT],
):
    """`CompliantNamespace` variant for depth-tracking expressions.

    Overrides `all`, `col` and `exclude` so that each call to
    `_expr.from_column_names` also passes a `function_name` keyword naming
    the method that was invoked.
    """

    def all(self) -> DepthTrackingExprT:
        """Build from `get_column_names`, tagged `function_name="all"`."""
        from_names = self._expr.from_column_names
        return from_names(get_column_names, function_name="all", context=self)

    def col(self, *column_names: str) -> DepthTrackingExprT:
        """Build from the given names, tagged `function_name="col"`."""
        from_names = self._expr.from_column_names
        return from_names(
            passthrough_column_names(column_names),
            function_name="col",
            context=self,
        )

    def exclude(self, excluded_names: Container[str]) -> DepthTrackingExprT:
        """Build from `exclude_column_names`, tagged `function_name="exclude"`.

        `excluded_names` is bound as the `names` keyword with
        `functools.partial`.
        """
        from_names = self._expr.from_column_names
        names_filter = partial(exclude_column_names, names=excluded_names)
        return from_names(names_filter, function_name="exclude", context=self)
class LazyNamespace(
    CompliantNamespace[CompliantLazyFrameT, LazyExprT],
    Protocol[CompliantLazyFrameT, LazyExprT, NativeFrameT_co],
):
    """`CompliantNamespace` specialised for a lazy frame type.

    Adds a `_lazyframe` class hook and a default `from_native` that wraps
    native objects recognised by that class.
    """

    @property
    def _backend_version(self) -> tuple[int, ...]:
        """Return whatever `self._implementation._backend_version()` reports."""
        implementation = self._implementation
        return implementation._backend_version()

    @property
    def _lazyframe(self) -> type[CompliantLazyFrameT]: ...

    def from_native(self, data: NativeFrameT_co | Any, /) -> CompliantLazyFrameT:
        """Wrap `data` using the `_lazyframe` class.

        Returns:
            `_lazyframe.from_native(data, context=self)` when
            `_lazyframe._is_native(data)` is truthy.

        Raises:
            TypeError: if `data` is not recognised by `_lazyframe._is_native`.
        """
        if self._lazyframe._is_native(data):
            return self._lazyframe.from_native(data, context=self)
        # Reached only for objects the lazy frame class does not recognise.
        msg = f"Unsupported type: {type(data).__name__!r}"  # pragma: no cover
        raise TypeError(msg)  # pragma: no cover
class EagerNamespace(
    DepthTrackingNamespace[EagerDataFrameT, EagerExprT],
    Protocol[EagerDataFrameT, EagerSeriesT, EagerExprT, NativeFrameT, NativeSeriesT],
):
    """`DepthTrackingNamespace` specialised for eager dataframe/series types.

    Supplies default `from_native`, `from_numpy` and `concat` bodies that
    delegate to the `_dataframe` / `_series` classes and to the
    `_concat_*` hooks, which are only declared here.
    """

    @property
    def _backend_version(self) -> tuple[int, ...]:
        """Return whatever `self._implementation._backend_version()` reports."""
        implementation = self._implementation
        return implementation._backend_version()

    @property
    def _dataframe(self) -> type[EagerDataFrameT]: ...

    @property
    def _series(self) -> type[EagerSeriesT]: ...

    def when(
        self, predicate: EagerExprT
    ) -> EagerWhen[EagerDataFrameT, EagerSeriesT, EagerExprT, NativeSeriesT]: ...

    @overload
    def from_native(self, data: NativeFrameT, /) -> EagerDataFrameT: ...

    @overload
    def from_native(self, data: NativeSeriesT, /) -> EagerSeriesT: ...

    def from_native(
        self, data: NativeFrameT | NativeSeriesT | Any, /
    ) -> EagerDataFrameT | EagerSeriesT:
        """Wrap a native dataframe or series.

        `_dataframe._is_native` is consulted first, then `_series._is_native`;
        the first class that recognises `data` wraps it with its own
        `from_native(data, context=self)`.

        Raises:
            TypeError: if neither class recognises `data`.
        """
        if self._dataframe._is_native(data):
            return self._dataframe.from_native(data, context=self)
        if self._series._is_native(data):
            return self._series.from_native(data, context=self)
        msg = f"Unsupported type: {type(data).__name__!r}"
        raise TypeError(msg)

    @overload
    def from_numpy(self, data: Into1DArray, /, schema: None = ...) -> EagerSeriesT: ...

    @overload
    def from_numpy(
        self,
        data: _2DArray,
        /,
        schema: Mapping[str, DType] | Schema | Sequence[str] | None,
    ) -> EagerDataFrameT: ...

    def from_numpy(
        self,
        data: Into1DArray | _2DArray,
        /,
        schema: Mapping[str, DType] | Schema | Sequence[str] | None = None,
    ) -> EagerDataFrameT | EagerSeriesT:
        """Build a dataframe or series from `data`.

        When `is_numpy_array_2d(data)` is truthy the work goes to
        `_dataframe.from_numpy` (which receives `schema`); otherwise it goes
        to `_series.from_numpy`, and `schema` is not forwarded.
        """
        if is_numpy_array_2d(data):
            result = self._dataframe.from_numpy(data, schema=schema, context=self)
        else:
            result = self._series.from_numpy(data, context=self)
        return result

    # Native-level concatenation hooks; declared without a body here.

    def _concat_diagonal(self, dfs: Sequence[NativeFrameT], /) -> NativeFrameT: ...

    def _concat_horizontal(
        self, dfs: Sequence[NativeFrameT | Any], /
    ) -> NativeFrameT: ...

    def _concat_vertical(self, dfs: Sequence[NativeFrameT], /) -> NativeFrameT: ...

    def concat(
        self, items: Iterable[EagerDataFrameT], *, how: ConcatMethod
    ) -> EagerDataFrameT:
        """Concatenate `items` according to `how`.

        Each item's `.native` object is collected into a list, passed to the
        `_concat_*` hook matching `how`, and the result is re-wrapped with
        `_dataframe.from_native`.

        Raises:
            NotImplementedError: if `how` is not "horizontal", "vertical"
                or "diagonal".
        """
        # Collect the native frames up front, before `how` is inspected.
        native_frames = [frame.native for frame in items]
        if how == "horizontal":
            combined = self._concat_horizontal(native_frames)
        elif how == "vertical":
            combined = self._concat_vertical(native_frames)
        elif how == "diagonal":
            combined = self._concat_diagonal(native_frames)
        else:  # pragma: no cover
            raise NotImplementedError
        return self._dataframe.from_native(combined, context=self)