from __future__ import annotations

import operator
from functools import reduce
from itertools import chain
from typing import TYPE_CHECKING, Literal

import pyarrow as pa
import pyarrow.compute as pc

from narwhals._arrow.dataframe import ArrowDataFrame
from narwhals._arrow.expr import ArrowExpr
from narwhals._arrow.selectors import ArrowSelectorNamespace
from narwhals._arrow.series import ArrowSeries
from narwhals._arrow.utils import cast_to_comparable_string_types
from narwhals._compliant import CompliantThen, EagerNamespace, EagerWhen
from narwhals._expression_parsing import (
    combine_alias_output_names,
    combine_evaluate_output_names,
)
from narwhals._utils import Implementation

if TYPE_CHECKING:
    from collections.abc import Sequence

    from narwhals._arrow.typing import ArrayOrScalar, ChunkedArrayAny, Incomplete
    from narwhals._compliant.typing import ScalarKwargs
    from narwhals._utils import Version
    from narwhals.typing import IntoDType, NonNestedLiteral


class ArrowNamespace(
    EagerNamespace[ArrowDataFrame, ArrowSeries, ArrowExpr, pa.Table, "ChunkedArrayAny"]
):
    _implementation = Implementation.PYARROW

    @property
    def _dataframe(self) -> type[ArrowDataFrame]:
        return ArrowDataFrame

    @property
    def _expr(self) -> type[ArrowExpr]:
        return ArrowExpr

    @property
    def _series(self) -> type[ArrowSeries]:
        return ArrowSeries

    def __init__(self, *, version: Version) -> None:
        self._version = version

    def len(self) -> ArrowExpr:
        # coverage bug? this is definitely hit
        return self._expr(  # pragma: no cover
            lambda df: [
                ArrowSeries.from_iterable([len(df.native)], name="len", context=self)
            ],
            depth=0,
            function_name="len",
            evaluate_output_names=lambda _df: ["len"],
            alias_output_names=None,
            version=self._version,
        )

    def lit(self, value: NonNestedLiteral, dtype: IntoDType | None) -> ArrowExpr:
        def _lit_arrow_series(_: ArrowDataFrame) -> ArrowSeries:
            arrow_series = ArrowSeries.from_iterable(
                data=[value], name="literal", context=self
            )
            if dtype:
                return arrow_series.cast(dtype)
            return arrow_series

        return self._expr(
            lambda df: [_lit_arrow_series(df)],
            depth=0,
            function_name="lit",
            evaluate_output_names=lambda _df: ["literal"],
            alias_output_names=None,
            version=self._version,
        )

    def all_horizontal(self, *exprs: ArrowExpr, ignore_nulls: bool) -> ArrowExpr:
        def func(df: ArrowDataFrame) -> list[ArrowSeries]:
            series = chain.from_iterable(expr(df) for expr in exprs)
            align = self._series._align_full_broadcast
            it = (
                (s.fill_null(True, None, None) for s in series)  # noqa: FBT003
                if ignore_nulls
                else series
            )
            return [reduce(operator.and_, align(*it))]

        return self._expr._from_callable(
            func=func,
            depth=max(x._depth for x in exprs) + 1,
            function_name="all_horizontal",
            evaluate_output_names=combine_evaluate_output_names(*exprs),
            alias_output_names=combine_alias_output_names(*exprs),
            context=self,
        )

    def any_horizontal(self, *exprs: ArrowExpr, ignore_nulls: bool) -> ArrowExpr:
        def func(df: ArrowDataFrame) -> list[ArrowSeries]:
            series = chain.from_iterable(expr(df) for expr in exprs)
            align = self._series._align_full_broadcast
            it = (
                (s.fill_null(False, None, None) for s in series)  # noqa: FBT003
                if ignore_nulls
                else series
            )
            return [reduce(operator.or_, align(*it))]

        return self._expr._from_callable(
            func=func,
            depth=max(x._depth for x in exprs) + 1,
            function_name="any_horizontal",
            evaluate_output_names=combine_evaluate_output_names(*exprs),
            alias_output_names=combine_alias_output_names(*exprs),
            context=self,
        )
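
    # Illustrative sketch (not part of the original module): with
    # ``ignore_nulls=True``, nulls are replaced by the reduction's identity
    # (True for AND, False for OR) before the element-wise reduce.
    # Shown here with raw pyarrow, for illustration only:
    #
    #     a = pa.chunked_array([[True, None, False]])
    #     b = pa.chunked_array([[True, True, None]])
    #     pc.and_(pc.fill_null(a, True), pc.fill_null(b, True))
    #     # -> [true, true, false]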

    def sum_horizontal(self, *exprs: ArrowExpr) -> ArrowExpr:
        def func(df: ArrowDataFrame) -> list[ArrowSeries]:
            it = chain.from_iterable(expr(df) for expr in exprs)
            series = (s.fill_null(0, strategy=None, limit=None) for s in it)
            align = self._series._align_full_broadcast
            return [reduce(operator.add, align(*series))]

        return self._expr._from_callable(
            func=func,
            depth=max(x._depth for x in exprs) + 1,
            function_name="sum_horizontal",
            evaluate_output_names=combine_evaluate_output_names(*exprs),
            alias_output_names=combine_alias_output_names(*exprs),
            context=self,
        )

    def mean_horizontal(self, *exprs: ArrowExpr) -> ArrowExpr:
        int_64 = self._version.dtypes.Int64()

        def func(df: ArrowDataFrame) -> list[ArrowSeries]:
            expr_results = list(chain.from_iterable(expr(df) for expr in exprs))
            align = self._series._align_full_broadcast
            series = align(
                *(s.fill_null(0, strategy=None, limit=None) for s in expr_results)
            )
            non_na = align(*(1 - s.is_null().cast(int_64) for s in expr_results))
            return [reduce(operator.add, series) / reduce(operator.add, non_na)]

        return self._expr._from_callable(
            func=func,
            depth=max(x._depth for x in exprs) + 1,
            function_name="mean_horizontal",
            evaluate_output_names=combine_evaluate_output_names(*exprs),
            alias_output_names=combine_alias_output_names(*exprs),
            context=self,
        )

    def min_horizontal(self, *exprs: ArrowExpr) -> ArrowExpr:
        def func(df: ArrowDataFrame) -> list[ArrowSeries]:
            align = self._series._align_full_broadcast
            init_series, *series = list(chain.from_iterable(expr(df) for expr in exprs))
            init_series, *series = align(init_series, *series)
            native_series = reduce(
                pc.min_element_wise, [s.native for s in series], init_series.native
            )
            return [
                ArrowSeries(native_series, name=init_series.name, version=self._version)
            ]

        return self._expr._from_callable(
            func=func,
            depth=max(x._depth for x in exprs) + 1,
            function_name="min_horizontal",
            evaluate_output_names=combine_evaluate_output_names(*exprs),
            alias_output_names=combine_alias_output_names(*exprs),
            context=self,
        )

    def max_horizontal(self, *exprs: ArrowExpr) -> ArrowExpr:
        def func(df: ArrowDataFrame) -> list[ArrowSeries]:
            align = self._series._align_full_broadcast
            init_series, *series = list(chain.from_iterable(expr(df) for expr in exprs))
            init_series, *series = align(init_series, *series)
            native_series = reduce(
                pc.max_element_wise, [s.native for s in series], init_series.native
            )
            return [
                ArrowSeries(native_series, name=init_series.name, version=self._version)
            ]

        return self._expr._from_callable(
            func=func,
            depth=max(x._depth for x in exprs) + 1,
            function_name="max_horizontal",
            evaluate_output_names=combine_evaluate_output_names(*exprs),
            alias_output_names=combine_alias_output_names(*exprs),
            context=self,
        )

    def _concat_diagonal(self, dfs: Sequence[pa.Table], /) -> pa.Table:
        if self._backend_version >= (14,):
            return pa.concat_tables(dfs, promote_options="default")
        return pa.concat_tables(dfs, promote=True)  # pragma: no cover

    def _concat_horizontal(self, dfs: Sequence[pa.Table], /) -> pa.Table:
        names = list(chain.from_iterable(df.column_names for df in dfs))
        arrays = list(chain.from_iterable(df.itercolumns() for df in dfs))
        return pa.Table.from_arrays(arrays, names=names)

    def _concat_vertical(self, dfs: Sequence[pa.Table], /) -> pa.Table:
        cols_0 = dfs[0].column_names
        for i, df in enumerate(dfs[1:], start=1):
            cols_current = df.column_names
            if cols_current != cols_0:
                msg = (
                    "unable to vstack, column names don't match:\n"
                    f"   - dataframe 0: {cols_0}\n"
                    f"   - dataframe {i}: {cols_current}\n"
                )
                raise TypeError(msg)
        return pa.concat_tables(dfs)

    @property
    def selectors(self) -> ArrowSelectorNamespace:
        return ArrowSelectorNamespace.from_namespace(self)
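
    # Illustrative sketch (not part of the original module): diagonal concat
    # unions the schemas and fills missing columns with nulls. Shown here with
    # raw pyarrow (>= 14), for illustration only:
    #
    #     t1 = pa.table({"a": [1], "b": [2]})
    #     t2 = pa.table({"a": [3], "c": [4]})
    #     pa.concat_tables([t1, t2], promote_options="default")
    #     # -> columns a, b, c; null where a column is absent in an input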

    def when(self, predicate: ArrowExpr) -> ArrowWhen:
        return ArrowWhen.from_expr(predicate, context=self)

    def concat_str(
        self, *exprs: ArrowExpr, separator: str, ignore_nulls: bool
    ) -> ArrowExpr:
        def func(df: ArrowDataFrame) -> list[ArrowSeries]:
            align = self._series._align_full_broadcast
            compliant_series_list = align(
                *(chain.from_iterable(expr(df) for expr in exprs))
            )
            name = compliant_series_list[0].name
            null_handling: Literal["skip", "emit_null"] = (
                "skip" if ignore_nulls else "emit_null"
            )
            it, separator_scalar = cast_to_comparable_string_types(
                *(s.native for s in compliant_series_list), separator=separator
            )
            # NOTE: stubs indicate `separator` must also be a `ChunkedArray`
            # Reality: `str` is fine
            concat_str: Incomplete = pc.binary_join_element_wise
            compliant = self._series(
                concat_str(*it, separator_scalar, null_handling=null_handling),
                name=name,
                version=self._version,
            )
            return [compliant]

        return self._expr._from_callable(
            func=func,
            depth=max(x._depth for x in exprs) + 1,
            function_name="concat_str",
            evaluate_output_names=combine_evaluate_output_names(*exprs),
            alias_output_names=combine_alias_output_names(*exprs),
            context=self,
        )

    def coalesce(self, *exprs: ArrowExpr) -> ArrowExpr:
        def func(df: ArrowDataFrame) -> list[ArrowSeries]:
            align = self._series._align_full_broadcast
            init_series, *series = align(
                *chain.from_iterable(expr(df) for expr in exprs)
            )
            return [
                ArrowSeries(
                    pc.coalesce(init_series.native, *(s.native for s in series)),
                    name=init_series.name,
                    version=self._version,
                )
            ]

        return self._expr._from_callable(
            func=func,
            depth=max(x._depth for x in exprs) + 1,
            function_name="coalesce",
            evaluate_output_names=combine_evaluate_output_names(*exprs),
            alias_output_names=combine_alias_output_names(*exprs),
            context=self,
        )


class ArrowWhen(EagerWhen[ArrowDataFrame, ArrowSeries, ArrowExpr, "ChunkedArrayAny"]):
    @property
    def _then(self) -> type[ArrowThen]:
        return ArrowThen

    def _if_then_else(
        self,
        when: ChunkedArrayAny,
        then: ChunkedArrayAny,
        otherwise: ArrayOrScalar | NonNestedLiteral,
        /,
    ) -> ChunkedArrayAny:
        otherwise = pa.nulls(len(when), then.type) if otherwise is None else otherwise
        return pc.if_else(when, then, otherwise)


class ArrowThen(
    CompliantThen[ArrowDataFrame, ArrowSeries, ArrowExpr, ArrowWhen], ArrowExpr
):
    _depth: int = 0
    _scalar_kwargs: ScalarKwargs = {}  # noqa: RUF012
    _function_name: str = "whenthen"
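
# Illustrative sketch (not part of the original module): `_if_then_else` maps onto
# `pc.if_else`, with a missing `otherwise` branch expressed as nulls of the `then`
# type. Shown here with raw pyarrow, for illustration only:
#
#     mask = pa.chunked_array([[True, False, None]])
#     then = pa.chunked_array([[1, 2, 3]])
#     pc.if_else(mask, then, pa.nulls(len(mask), then.type))
#     # -> [1, null, null]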