# team-10/venv/Lib/site-packages/narwhals/_arrow/dataframe.py
from __future__ import annotations
from collections.abc import Collection, Iterator, Mapping, Sequence
from functools import partial
from typing import TYPE_CHECKING, Any, Literal, cast, overload
import pyarrow as pa
import pyarrow.compute as pc
from narwhals._arrow.series import ArrowSeries
from narwhals._arrow.utils import native_to_narwhals_dtype
from narwhals._compliant import EagerDataFrame
from narwhals._expression_parsing import ExprKind
from narwhals._utils import (
Implementation,
Version,
check_column_names_are_unique,
convert_str_slice_to_int_slice,
generate_temporary_column_name,
not_implemented,
parse_columns_to_drop,
scale_bytes,
supports_arrow_c_stream,
)
from narwhals.dependencies import is_numpy_array_1d
from narwhals.exceptions import ShapeError
if TYPE_CHECKING:
from io import BytesIO
from pathlib import Path
from types import ModuleType
import pandas as pd
import polars as pl
from typing_extensions import Self, TypeAlias, TypeIs
from narwhals._arrow.expr import ArrowExpr
from narwhals._arrow.group_by import ArrowGroupBy
from narwhals._arrow.namespace import ArrowNamespace
from narwhals._arrow.typing import ( # type: ignore[attr-defined]
ChunkedArrayAny,
Mask,
Order,
)
from narwhals._compliant.typing import CompliantDataFrameAny, CompliantLazyFrameAny
from narwhals._translate import IntoArrowTable
from narwhals._utils import _LimitedContext
from narwhals.dtypes import DType
from narwhals.schema import Schema
from narwhals.typing import (
JoinStrategy,
SizedMultiIndexSelector,
SizedMultiNameSelector,
SizeUnit,
UniqueKeepStrategy,
_1DArray,
_2DArray,
_SliceIndex,
_SliceName,
)
JoinType: TypeAlias = Literal[
"left semi",
"right semi",
"left anti",
"right anti",
"inner",
"left outer",
"right outer",
"full outer",
]
PromoteOptions: TypeAlias = Literal["none", "default", "permissive"]
class ArrowDataFrame(
EagerDataFrame["ArrowSeries", "ArrowExpr", "pa.Table", "ChunkedArrayAny"]
):
_implementation = Implementation.PYARROW
def __init__(
self,
native_dataframe: pa.Table,
*,
version: Version,
validate_column_names: bool,
validate_backend_version: bool = False,
) -> None:
if validate_column_names:
check_column_names_are_unique(native_dataframe.column_names)
if validate_backend_version:
self._validate_backend_version()
self._native_frame = native_dataframe
self._version = version
@classmethod
def from_arrow(cls, data: IntoArrowTable, /, *, context: _LimitedContext) -> Self:
backend_version = context._implementation._backend_version()
if cls._is_native(data):
native = data
elif backend_version >= (14,) or isinstance(data, Collection):
native = pa.table(data)
elif supports_arrow_c_stream(data): # pragma: no cover
msg = f"'pyarrow>=14.0.0' is required for `from_arrow` for object of type {type(data).__name__!r}."
raise ModuleNotFoundError(msg)
else: # pragma: no cover
msg = f"`from_arrow` is not supported for object of type {type(data).__name__!r}."
raise TypeError(msg)
return cls.from_native(native, context=context)
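    # Sketch of the gate above: `pa.table` can consume any object exposing the
    # Arrow PyCapsule interface (`__arrow_c_stream__`) from pyarrow 14 onwards,
    # while plain Collections (e.g. a mapping of columns) work on any version;
    # a capsule-only object on older pyarrow raises ModuleNotFoundError.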
@classmethod
def from_dict(
cls,
data: Mapping[str, Any],
/,
*,
context: _LimitedContext,
schema: Mapping[str, DType] | Schema | None,
) -> Self:
from narwhals.schema import Schema
pa_schema = Schema(schema).to_arrow() if schema is not None else schema
if pa_schema and not data:
native = pa_schema.empty_table()
else:
native = pa.Table.from_pydict(data, schema=pa_schema)
return cls.from_native(native, context=context)
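    # A minimal usage sketch (`ctx` is hypothetical, standing in for any
    # `_LimitedContext` with a `_version`):
    #
    #     ArrowDataFrame.from_dict({"a": [1, 2]}, context=ctx, schema=None)
    #     # -> wraps pa.Table.from_pydict({"a": [1, 2]})
    #
    # With a non-empty schema and empty data, `pa_schema.empty_table()` keeps
    # the declared columns and dtypes at zero rows.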
@staticmethod
def _is_native(obj: pa.Table | Any) -> TypeIs[pa.Table]:
return isinstance(obj, pa.Table)
@classmethod
def from_native(cls, data: pa.Table, /, *, context: _LimitedContext) -> Self:
return cls(data, version=context._version, validate_column_names=True)
@classmethod
def from_numpy(
cls,
data: _2DArray,
/,
*,
context: _LimitedContext,
schema: Mapping[str, DType] | Schema | Sequence[str] | None,
) -> Self:
from narwhals.schema import Schema
arrays = [pa.array(val) for val in data.T]
if isinstance(schema, (Mapping, Schema)):
native = pa.Table.from_arrays(arrays, schema=Schema(schema).to_arrow())
else:
native = pa.Table.from_arrays(arrays, cls._numpy_column_names(data, schema))
return cls.from_native(native, context=context)
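    # Sketch: `data.T` yields the columns of the 2D array, so
    # np.array([[1, 4], [2, 5], [3, 6]]) becomes arrays [1, 2, 3] and
    # [4, 5, 6]; names come from the mapping/sequence when given, otherwise
    # `_numpy_column_names` generates them.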
def __narwhals_namespace__(self) -> ArrowNamespace:
from narwhals._arrow.namespace import ArrowNamespace
return ArrowNamespace(version=self._version)
def __native_namespace__(self) -> ModuleType:
if self._implementation is Implementation.PYARROW:
return self._implementation.to_native_namespace()
msg = f"Expected pyarrow, got: {type(self._implementation)}" # pragma: no cover
raise AssertionError(msg)
def __narwhals_dataframe__(self) -> Self:
return self
def __narwhals_lazyframe__(self) -> Self:
return self
def _with_version(self, version: Version) -> Self:
return self.__class__(self.native, version=version, validate_column_names=False)
def _with_native(self, df: pa.Table, *, validate_column_names: bool = True) -> Self:
return self.__class__(
df, version=self._version, validate_column_names=validate_column_names
)
@property
def shape(self) -> tuple[int, int]:
return self.native.shape
def __len__(self) -> int:
return len(self.native)
def row(self, index: int) -> tuple[Any, ...]:
return tuple(col[index] for col in self.native.itercolumns())
@overload
def rows(self, *, named: Literal[True]) -> list[dict[str, Any]]: ...
@overload
def rows(self, *, named: Literal[False]) -> list[tuple[Any, ...]]: ...
@overload
def rows(self, *, named: bool) -> list[tuple[Any, ...]] | list[dict[str, Any]]: ...
def rows(self, *, named: bool) -> list[tuple[Any, ...]] | list[dict[str, Any]]:
if not named:
return list(self.iter_rows(named=False, buffer_size=512)) # type: ignore[return-value]
return self.native.to_pylist()
def iter_columns(self) -> Iterator[ArrowSeries]:
for name, series in zip(self.columns, self.native.itercolumns()):
yield ArrowSeries.from_native(series, context=self, name=name)
_iter_columns = iter_columns
def iter_rows(
self, *, named: bool, buffer_size: int
) -> Iterator[tuple[Any, ...]] | Iterator[dict[str, Any]]:
df = self.native
num_rows = df.num_rows
if not named:
for i in range(0, num_rows, buffer_size):
rows = df[i : i + buffer_size].to_pydict().values()
yield from zip(*rows)
else:
for i in range(0, num_rows, buffer_size):
yield from df[i : i + buffer_size].to_pylist()
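    # Sketch of the unnamed branch: per buffered chunk,
    # `to_pydict().values()` is the list of column lists and `zip(*rows)`
    # transposes it into row tuples, e.g.
    # {"a": [1, 2], "b": [3, 4]} -> (1, 3), (2, 4).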
def get_column(self, name: str) -> ArrowSeries:
if not isinstance(name, str):
msg = f"Expected str, got: {type(name)}"
raise TypeError(msg)
return ArrowSeries.from_native(self.native[name], context=self, name=name)
def __array__(self, dtype: Any, *, copy: bool | None) -> _2DArray:
return self.native.__array__(dtype, copy=copy)
def _gather(self, rows: SizedMultiIndexSelector[ChunkedArrayAny]) -> Self:
if len(rows) == 0:
return self._with_native(self.native.slice(0, 0))
if self._backend_version < (18,) and isinstance(rows, tuple):
rows = list(rows)
return self._with_native(self.native.take(rows))
def _gather_slice(self, rows: _SliceIndex | range) -> Self:
start = rows.start or 0
stop = rows.stop if rows.stop is not None else len(self.native)
if start < 0:
start = len(self.native) + start
if stop < 0:
stop = len(self.native) + stop
if rows.step is not None and rows.step != 1:
msg = "Slicing with step is not supported on PyArrow tables"
raise NotImplementedError(msg)
return self._with_native(self.native.slice(start, stop - start))
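    # Sketch: negative bounds are normalized against the table length before
    # slicing, e.g. with 5 rows, slice(-2, None) becomes start=3, stop=5,
    # i.e. `self.native.slice(3, 2)` (the last two rows).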
def _select_slice_name(self, columns: _SliceName) -> Self:
start, stop, step = convert_str_slice_to_int_slice(columns, self.columns)
return self._with_native(self.native.select(self.columns[start:stop:step]))
def _select_slice_index(self, columns: _SliceIndex | range) -> Self:
return self._with_native(
self.native.select(self.columns[columns.start : columns.stop : columns.step])
)
def _select_multi_index(
self, columns: SizedMultiIndexSelector[ChunkedArrayAny]
) -> Self:
selector: Sequence[int]
if isinstance(columns, pa.ChunkedArray):
# TODO @dangotbanned: Fix upstream with `pa.ChunkedArray.to_pylist(self) -> list[Any]:`
selector = cast("Sequence[int]", columns.to_pylist())
# TODO @dangotbanned: Fix upstream, it is actually much narrower
# **Doesn't accept `ndarray`**
elif is_numpy_array_1d(columns):
selector = columns.tolist()
else:
selector = columns
return self._with_native(self.native.select(selector))
def _select_multi_name(
self, columns: SizedMultiNameSelector[ChunkedArrayAny]
) -> Self:
selector: Sequence[str] | _1DArray
if isinstance(columns, pa.ChunkedArray):
# TODO @dangotbanned: Fix upstream with `pa.ChunkedArray.to_pylist(self) -> list[Any]:`
selector = cast("Sequence[str]", columns.to_pylist())
else:
selector = columns
# NOTE: Fixed in https://github.com/zen-xu/pyarrow-stubs/pull/221
return self._with_native(self.native.select(selector)) # pyright: ignore[reportArgumentType]
@property
def schema(self) -> dict[str, DType]:
schema = self.native.schema
return {
name: native_to_narwhals_dtype(dtype, self._version)
for name, dtype in zip(schema.names, schema.types)
}
def collect_schema(self) -> dict[str, DType]:
return self.schema
def estimated_size(self, unit: SizeUnit) -> int | float:
sz = self.native.nbytes
return scale_bytes(sz, unit)
explode = not_implemented()
@property
def columns(self) -> list[str]:
return self.native.column_names
def simple_select(self, *column_names: str) -> Self:
return self._with_native(
self.native.select(list(column_names)), validate_column_names=False
)
def select(self, *exprs: ArrowExpr) -> Self:
new_series = self._evaluate_into_exprs(*exprs)
if not new_series:
# return empty dataframe, like Polars does
return self._with_native(
self.native.__class__.from_arrays([]), validate_column_names=False
)
names = [s.name for s in new_series]
align = new_series[0]._align_full_broadcast
reshaped = align(*new_series)
df = pa.Table.from_arrays([s.native for s in reshaped], names=names)
return self._with_native(df, validate_column_names=True)
def _extract_comparand(self, other: ArrowSeries) -> ChunkedArrayAny:
length = len(self)
if not other._broadcast:
if (len_other := len(other)) != length:
msg = f"Expected object of length {length}, got: {len_other}."
raise ShapeError(msg)
return other.native
value = other.native[0]
return pa.chunked_array([pa.repeat(value, length)])
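    # Sketch: a broadcast (length-1) series is expanded to the frame's height,
    # e.g. with len(self) == 3 and `other` holding [7], the result is
    # pa.chunked_array([pa.repeat(7, 3)]), i.e. [7, 7, 7].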
def with_columns(self, *exprs: ArrowExpr) -> Self:
        # NOTE: We use a faux-mutable variable (`native_frame`) and repeatedly
        # "overwrite" it. All `pyarrow` data is immutable, so this is fine.
native_frame = self.native
new_columns = self._evaluate_into_exprs(*exprs)
columns = self.columns
for col_value in new_columns:
col_name = col_value.name
column = self._extract_comparand(col_value)
native_frame = (
native_frame.set_column(columns.index(col_name), col_name, column=column)
if col_name in columns
else native_frame.append_column(col_name, column=column)
)
return self._with_native(native_frame, validate_column_names=False)
def group_by(
self, keys: Sequence[str] | Sequence[ArrowExpr], *, drop_null_keys: bool
) -> ArrowGroupBy:
from narwhals._arrow.group_by import ArrowGroupBy
return ArrowGroupBy(self, keys, drop_null_keys=drop_null_keys)
def join(
self,
other: Self,
*,
how: JoinStrategy,
left_on: Sequence[str] | None,
right_on: Sequence[str] | None,
suffix: str,
) -> Self:
how_to_join_map: dict[str, JoinType] = {
"anti": "left anti",
"semi": "left semi",
"inner": "inner",
"left": "left outer",
"full": "full outer",
}
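        # PyArrow has no native cross join, so the branch below emulates one:
        # add the same constant key column to both sides, inner-join on it
        # (every left row matches every right row), then drop the key. E.g.
        # a 2-row left and a 3-row right produce 2 * 3 = 6 rows.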
if how == "cross":
plx = self.__narwhals_namespace__()
key_token = generate_temporary_column_name(
n_bytes=8, columns=[*self.columns, *other.columns]
)
return self._with_native(
self.with_columns(
plx.lit(0, None).alias(key_token).broadcast(ExprKind.LITERAL)
)
.native.join(
other.with_columns(
plx.lit(0, None).alias(key_token).broadcast(ExprKind.LITERAL)
).native,
keys=key_token,
right_keys=key_token,
join_type="inner",
right_suffix=suffix,
)
.drop([key_token])
)
coalesce_keys = how != "full" # polars full join does not coalesce keys
return self._with_native(
self.native.join(
other.native,
keys=left_on or [], # type: ignore[arg-type]
right_keys=right_on, # type: ignore[arg-type]
join_type=how_to_join_map[how],
right_suffix=suffix,
coalesce_keys=coalesce_keys,
)
)
join_asof = not_implemented()
def drop(self, columns: Sequence[str], *, strict: bool) -> Self:
to_drop = parse_columns_to_drop(self, columns, strict=strict)
return self._with_native(self.native.drop(to_drop), validate_column_names=False)
def drop_nulls(self, subset: Sequence[str] | None) -> Self:
if subset is None:
return self._with_native(self.native.drop_null(), validate_column_names=False)
plx = self.__narwhals_namespace__()
mask = ~plx.any_horizontal(plx.col(*subset).is_null(), ignore_nulls=True)
return self.filter(mask)
def sort(self, *by: str, descending: bool | Sequence[bool], nulls_last: bool) -> Self:
if isinstance(descending, bool):
order: Order = "descending" if descending else "ascending"
sorting: list[tuple[str, Order]] = [(key, order) for key in by]
else:
sorting = [
(key, "descending" if is_descending else "ascending")
for key, is_descending in zip(by, descending)
]
null_placement = "at_end" if nulls_last else "at_start"
return self._with_native(
self.native.sort_by(sorting, null_placement=null_placement),
validate_column_names=False,
)
def to_pandas(self) -> pd.DataFrame:
return self.native.to_pandas()
def to_polars(self) -> pl.DataFrame:
import polars as pl # ignore-banned-import
return pl.from_arrow(self.native) # type: ignore[return-value]
def to_numpy(self, dtype: Any = None, *, copy: bool | None = None) -> _2DArray:
import numpy as np # ignore-banned-import
arr: Any = np.column_stack([col.to_numpy() for col in self.native.columns])
return arr
@overload
def to_dict(self, *, as_series: Literal[True]) -> dict[str, ArrowSeries]: ...
@overload
def to_dict(self, *, as_series: Literal[False]) -> dict[str, list[Any]]: ...
def to_dict(
self, *, as_series: bool
) -> dict[str, ArrowSeries] | dict[str, list[Any]]:
it = self.iter_columns()
if as_series:
return {ser.name: ser for ser in it}
return {ser.name: ser.to_list() for ser in it}
def with_row_index(self, name: str, order_by: Sequence[str] | None) -> Self:
plx = self.__narwhals_namespace__()
if order_by is None:
import numpy as np # ignore-banned-import
data = pa.array(np.arange(len(self), dtype=np.int64))
row_index = plx._expr._from_series(
plx._series.from_iterable(data, context=self, name=name)
)
else:
rank = plx.col(order_by[0]).rank("ordinal", descending=False)
row_index = (rank.over(partition_by=[], order_by=order_by) - 1).alias(name)
return self.select(row_index, plx.all())
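    # Sketch of the ordered branch: `rank("ordinal")` over `order_by` assigns
    # 1-based positions in sort order, so subtracting 1 yields a 0-based row
    # index that follows `order_by` rather than physical row order.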
def filter(self, predicate: ArrowExpr | list[bool | None]) -> Self:
if isinstance(predicate, list):
mask_native: Mask | ChunkedArrayAny = predicate
else:
# `[0]` is safe as the predicate's expression only returns a single column
mask_native = self._evaluate_into_exprs(predicate)[0].native
return self._with_native(
self.native.filter(mask_native), validate_column_names=False
)
def head(self, n: int) -> Self:
df = self.native
if n >= 0:
return self._with_native(df.slice(0, n), validate_column_names=False)
else:
num_rows = df.num_rows
return self._with_native(
df.slice(0, max(0, num_rows + n)), validate_column_names=False
)
def tail(self, n: int) -> Self:
df = self.native
if n >= 0:
num_rows = df.num_rows
return self._with_native(
df.slice(max(0, num_rows - n)), validate_column_names=False
)
else:
return self._with_native(df.slice(abs(n)), validate_column_names=False)
def lazy(self, *, backend: Implementation | None = None) -> CompliantLazyFrameAny:
if backend is None:
return self
elif backend is Implementation.DUCKDB:
import duckdb # ignore-banned-import
from narwhals._duckdb.dataframe import DuckDBLazyFrame
df = self.native # noqa: F841
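            # DuckDB's replacement scan resolves the string "df" to the local
            # variable bound above (hence the `noqa: F841`), so `duckdb.table`
            # reads the Arrow table directly.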
return DuckDBLazyFrame(
duckdb.table("df"), validate_backend_version=True, version=self._version
)
elif backend is Implementation.POLARS:
import polars as pl # ignore-banned-import
from narwhals._polars.dataframe import PolarsLazyFrame
return PolarsLazyFrame(
cast("pl.DataFrame", pl.from_arrow(self.native)).lazy(),
validate_backend_version=True,
version=self._version,
)
elif backend is Implementation.DASK:
import dask.dataframe as dd # ignore-banned-import
from narwhals._dask.dataframe import DaskLazyFrame
return DaskLazyFrame(
dd.from_pandas(self.native.to_pandas()),
validate_backend_version=True,
version=self._version,
)
elif backend.is_ibis():
import ibis # ignore-banned-import
from narwhals._ibis.dataframe import IbisLazyFrame
return IbisLazyFrame(
ibis.memtable(self.native, columns=self.columns),
validate_backend_version=True,
version=self._version,
)
raise AssertionError # pragma: no cover
def collect(
self, backend: Implementation | None, **kwargs: Any
) -> CompliantDataFrameAny:
if backend is Implementation.PYARROW or backend is None:
from narwhals._arrow.dataframe import ArrowDataFrame
return ArrowDataFrame(
self.native, version=self._version, validate_column_names=False
)
if backend is Implementation.PANDAS:
from narwhals._pandas_like.dataframe import PandasLikeDataFrame
return PandasLikeDataFrame(
self.native.to_pandas(),
implementation=Implementation.PANDAS,
validate_backend_version=True,
version=self._version,
validate_column_names=False,
)
if backend is Implementation.POLARS:
import polars as pl # ignore-banned-import
from narwhals._polars.dataframe import PolarsDataFrame
return PolarsDataFrame(
cast("pl.DataFrame", pl.from_arrow(self.native)),
validate_backend_version=True,
version=self._version,
)
msg = f"Unsupported `backend` value: {backend}" # pragma: no cover
raise AssertionError(msg) # pragma: no cover
def clone(self) -> Self:
return self._with_native(self.native, validate_column_names=False)
def item(self, row: int | None, column: int | str | None) -> Any:
from narwhals._arrow.series import maybe_extract_py_scalar
if row is None and column is None:
if self.shape != (1, 1):
msg = (
"can only call `.item()` if the dataframe is of shape (1, 1),"
" or if explicit row/col values are provided;"
f" frame has shape {self.shape!r}"
)
raise ValueError(msg)
return maybe_extract_py_scalar(self.native[0][0], return_py_scalar=True)
elif row is None or column is None:
msg = "cannot call `.item()` with only one of `row` or `column`"
raise ValueError(msg)
_col = self.columns.index(column) if isinstance(column, str) else column
return maybe_extract_py_scalar(self.native[_col][row], return_py_scalar=True)
def rename(self, mapping: Mapping[str, str]) -> Self:
names: dict[str, str] | list[str]
if self._backend_version >= (17,):
names = cast("dict[str, str]", mapping)
else: # pragma: no cover
names = [mapping.get(c, c) for c in self.columns]
return self._with_native(self.native.rename_columns(names))
def write_parquet(self, file: str | Path | BytesIO) -> None:
import pyarrow.parquet as pp
pp.write_table(self.native, file)
@overload
def write_csv(self, file: None) -> str: ...
@overload
def write_csv(self, file: str | Path | BytesIO) -> None: ...
def write_csv(self, file: str | Path | BytesIO | None) -> str | None:
import pyarrow.csv as pa_csv
if file is None:
csv_buffer = pa.BufferOutputStream()
pa_csv.write_csv(self.native, csv_buffer)
return csv_buffer.getvalue().to_pybytes().decode()
pa_csv.write_csv(self.native, file)
return None
def is_unique(self) -> ArrowSeries:
import numpy as np # ignore-banned-import
col_token = generate_temporary_column_name(n_bytes=8, columns=self.columns)
row_index = pa.array(np.arange(len(self)))
keep_idx = (
self.native.append_column(col_token, row_index)
.group_by(self.columns)
.aggregate([(col_token, "min"), (col_token, "max")])
)
native = pa.chunked_array(
pc.and_(
pc.is_in(row_index, keep_idx[f"{col_token}_min"]),
pc.is_in(row_index, keep_idx[f"{col_token}_max"]),
)
)
return ArrowSeries.from_native(native, context=self)
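    # Sketch: after grouping by all columns, a row is unique exactly when its
    # index is both the min and max index of its group. E.g. for a single
    # column [1, 2, 1], indices 0 and 2 share a group (min=0, max=2), so only
    # index 1 passes both `is_in` checks: [False, True, False].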
def unique(
self,
subset: Sequence[str] | None,
*,
keep: UniqueKeepStrategy,
maintain_order: bool | None = None,
) -> Self:
# The param `maintain_order` is only here for compatibility with the Polars API
# and has no effect on the output.
import numpy as np # ignore-banned-import
if subset and (error := self._check_columns_exist(subset)):
raise error
subset = list(subset or self.columns)
if keep in {"any", "first", "last"}:
from narwhals._arrow.group_by import ArrowGroupBy
agg_func = ArrowGroupBy._REMAP_UNIQUE[keep]
col_token = generate_temporary_column_name(n_bytes=8, columns=self.columns)
keep_idx_native = (
self.native.append_column(col_token, pa.array(np.arange(len(self))))
.group_by(subset)
.aggregate([(col_token, agg_func)])
.column(f"{col_token}_{agg_func}")
)
return self._with_native(
self.native.take(keep_idx_native), validate_column_names=False
)
keep_idx = self.simple_select(*subset).is_unique()
plx = self.__narwhals_namespace__()
return self.filter(plx._expr._from_series(keep_idx))
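    # Sketch: keep="first"/"last" map (via `_REMAP_UNIQUE`) to "min"/"max"
    # aggregation of the temporary row index, so for a subset column
    # [1, 1, 2] and keep="first" the kept indices are [0, 2]; keep="none"
    # instead filters with the `is_unique` mask over `subset`.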
def gather_every(self, n: int, offset: int) -> Self:
return self._with_native(self.native[offset::n], validate_column_names=False)
def to_arrow(self) -> pa.Table:
return self.native
def sample(
self,
n: int | None,
*,
fraction: float | None,
with_replacement: bool,
seed: int | None,
) -> Self:
import numpy as np # ignore-banned-import
num_rows = len(self)
if n is None and fraction is not None:
n = int(num_rows * fraction)
rng = np.random.default_rng(seed=seed)
idx = np.arange(num_rows)
mask = rng.choice(idx, size=n, replace=with_replacement)
return self._with_native(self.native.take(mask), validate_column_names=False)
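    # Sketch: with fraction=0.5 on 10 rows, n becomes int(10 * 0.5) == 5;
    # `rng.choice` then draws 5 positional indices (with or without
    # replacement) and `take` gathers those rows.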
def unpivot(
self,
on: Sequence[str] | None,
index: Sequence[str] | None,
variable_name: str,
value_name: str,
) -> Self:
n_rows = len(self)
index_ = [] if index is None else index
on_ = [c for c in self.columns if c not in index_] if on is None else on
concat = (
partial(pa.concat_tables, promote_options="permissive")
if self._backend_version >= (14, 0, 0)
else pa.concat_tables
)
names = [*index_, variable_name, value_name]
return self._with_native(
concat(
[
pa.Table.from_arrays(
[
*(self.native.column(idx_col) for idx_col in index_),
cast(
"ChunkedArrayAny",
pa.array([on_col] * n_rows, pa.string()),
),
self.native.column(on_col),
],
names=names,
)
for on_col in on_
]
)
)
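    # Sketch: each `on` column becomes one stacked block, so with
    # index=["a"] and on=["b", "c"] the result concatenates
    # (a, "b", <b values>) with (a, "c", <c values>), i.e. columns
    # [a, variable_name, value_name] and len(on) * n_rows rows.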
# TODO(Unassigned): Even with promote_options="permissive", pyarrow does not
# upcast numeric to non-numeric (e.g. string) datatypes
pivot = not_implemented()