from __future__ import annotations

from functools import lru_cache
from typing import TYPE_CHECKING, Any, cast

import pyarrow as pa
import pyarrow.compute as pc

from narwhals._compliant import EagerSeriesNamespace
from narwhals._utils import isinstance_or_issubclass

if TYPE_CHECKING:
    from collections.abc import Iterable, Iterator, Mapping

    from typing_extensions import TypeAlias, TypeIs

    from narwhals._arrow.series import ArrowSeries
    from narwhals._arrow.typing import (
        ArrayAny,
        ArrayOrScalar,
        ArrayOrScalarT1,
        ArrayOrScalarT2,
        ChunkedArrayAny,
        NativeIntervalUnit,
        ScalarAny,
    )
    from narwhals._duration import IntervalUnit
    from narwhals._utils import Version
    from narwhals.dtypes import DType
    from narwhals.typing import IntoDType, PythonLiteral

    # NOTE: stubs don't allow for `ChunkedArray[StructArray]`
    # Intended to represent the `.chunks` property storing `list[pa.StructArray]`
    ChunkedArrayStructArray: TypeAlias = ChunkedArrayAny

    def is_timestamp(t: Any) -> TypeIs[pa.TimestampType[Any, Any]]: ...
    def is_duration(t: Any) -> TypeIs[pa.DurationType[Any]]: ...
    def is_list(t: Any) -> TypeIs[pa.ListType[Any]]: ...
    def is_large_list(t: Any) -> TypeIs[pa.LargeListType[Any]]: ...
    def is_fixed_size_list(t: Any) -> TypeIs[pa.FixedSizeListType[Any, Any]]: ...
    def is_dictionary(t: Any) -> TypeIs[pa.DictionaryType[Any, Any, Any]]: ...
    def extract_regex(
        strings: ChunkedArrayAny,
        /,
        pattern: str,
        *,
        options: Any = None,
        memory_pool: Any = None,
    ) -> ChunkedArrayStructArray: ...

else:
    from pyarrow.compute import extract_regex
    from pyarrow.types import (
        is_dictionary,  # noqa: F401
        is_duration,
        is_fixed_size_list,
        is_large_list,
        is_list,
        is_timestamp,
    )

# Maps each short interval-unit code to the spelled-out unit name.
UNITS_DICT: Mapping[IntervalUnit, NativeIntervalUnit] = {
    "y": "year",
    "q": "quarter",
    "mo": "month",
    "d": "day",
    "h": "hour",
    "m": "minute",
    "s": "second",
    "ms": "millisecond",
    "us": "microsecond",
    "ns": "nanosecond",
}

lit = pa.scalar
"""Alias for `pyarrow.scalar`."""


def extract_py_scalar(value: Any, /) -> Any:
    """Pass `value` through `maybe_extract_py_scalar` with `return_py_scalar=True`."""
    # The helper is imported at call time rather than at module import time.
    from narwhals._arrow.series import maybe_extract_py_scalar as _maybe_extract

    return _maybe_extract(value, return_py_scalar=True)


def chunked_array(
    arr: ArrayOrScalar | list[Iterable[Any]], dtype: pa.DataType | None = None, /
) -> ChunkedArrayAny:
    """Return `arr` as a `pa.ChunkedArray`.

    - A `pa.ChunkedArray` is handed back untouched.
    - A `list` is treated as the list of chunks and built with `dtype`.
    - Anything else becomes the single chunk, typed with its own `.type`
      (`dtype` is not consulted on this path).
    """
    if isinstance(arr, pa.ChunkedArray):
        return arr
    if not isinstance(arr, list):
        return pa.chunked_array([arr], arr.type)
    return pa.chunked_array(arr, dtype)


def nulls_like(n: int, series: ArrowSeries) -> ArrayAny:
    """Build a strongly-typed Array of `n` nulls.

    The element type is taken from `series`, which avoids upsetting `mypy`.
    """
    native_type = series.native.type
    return pa.nulls(n, native_type)


def zeros(n: int, /) -> pa.Int64Array:
    """Return an array holding the value `0` repeated `n` times."""
    fill_value = 0
    return pa.repeat(fill_value, n)


@lru_cache(maxsize=16)
def native_to_narwhals_dtype(dtype: pa.DataType, version: Version) -> DType:  # noqa: C901, PLR0912
    """Translate a PyArrow data type into the Narwhals dtype of `version`.

    Nested types (struct, list, large list, fixed-size list) are converted
    recursively. When none of the checks match, `dtypes.Unknown()` is returned.
    Results are memoised by the surrounding `lru_cache`.
    """
    dtypes = version.dtypes
    pa_types = pa.types

    # Signed integers.
    if pa_types.is_int64(dtype):
        return dtypes.Int64()
    if pa_types.is_int32(dtype):
        return dtypes.Int32()
    if pa_types.is_int16(dtype):
        return dtypes.Int16()
    if pa_types.is_int8(dtype):
        return dtypes.Int8()

    # Unsigned integers.
    if pa_types.is_uint64(dtype):
        return dtypes.UInt64()
    if pa_types.is_uint32(dtype):
        return dtypes.UInt32()
    if pa_types.is_uint16(dtype):
        return dtypes.UInt16()
    if pa_types.is_uint8(dtype):
        return dtypes.UInt8()

    # Booleans and floats.
    if pa_types.is_boolean(dtype):
        return dtypes.Boolean()
    if pa_types.is_float64(dtype):
        return dtypes.Float64()
    if pa_types.is_float32(dtype):
        return dtypes.Float32()

    # `is_string_view` is looked up defensively; when `pa.types` does not
    # provide it, the check simply counts as "no match".
    string_view_check = getattr(pa_types, "is_string_view", None)
    # Possible coverage bug: it reports a partial branch (`31->exit`, `31` being
    # the line number of this statement when the note was written) even though
    # both outcomes of the condition are covered.
    if (  # pragma: no cover
        pa_types.is_string(dtype)
        or pa_types.is_large_string(dtype)
        or (string_view_check is not None and string_view_check(dtype))
    ):
        return dtypes.String()

    # Temporal types.
    if pa_types.is_date32(dtype):
        return dtypes.Date()
    if is_timestamp(dtype):
        return dtypes.Datetime(time_unit=dtype.unit, time_zone=dtype.tz)
    if is_duration(dtype):
        return dtypes.Duration(time_unit=dtype.unit)

    if pa_types.is_dictionary(dtype):
        return dtypes.Categorical()

    # Nested types: convert the children recursively.
    if pa_types.is_struct(dtype):
        members = [dtype.field(idx) for idx in range(dtype.num_fields)]
        return dtypes.Struct(
            [
                dtypes.Field(member.name, native_to_narwhals_dtype(member.type, version))
                for member in members
            ]
        )
    if is_list(dtype) or is_large_list(dtype):
        element_dtype = native_to_narwhals_dtype(dtype.value_type, version)
        return dtypes.List(element_dtype)
    if is_fixed_size_list(dtype):
        element_dtype = native_to_narwhals_dtype(dtype.value_type, version)
        return dtypes.Array(element_dtype, dtype.list_size)

    if pa_types.is_decimal(dtype):
        return dtypes.Decimal()
    if pa_types.is_time32(dtype) or pa_types.is_time64(dtype):
        return dtypes.Time()
    if pa_types.is_binary(dtype):
        return dtypes.Binary()
    return dtypes.Unknown()  # pragma: no cover


def narwhals_to_native_dtype(dtype: IntoDType, version: Version) -> pa.DataType:  # noqa: C901, PLR0912
    """Translate a Narwhals dtype (class or instance) into a PyArrow data type.

    Raises:
        NotImplementedError: if `dtype` is `Decimal`.
        AssertionError: if `dtype` matches none of the handled dtypes.
    """
    dtypes = version.dtypes
    matches = isinstance_or_issubclass

    if matches(dtype, dtypes.Decimal):
        msg = "Casting to Decimal is not supported yet."
        raise NotImplementedError(msg)

    # Floats and integers.
    if matches(dtype, dtypes.Float64):
        return pa.float64()
    if matches(dtype, dtypes.Float32):
        return pa.float32()
    if matches(dtype, dtypes.Int64):
        return pa.int64()
    if matches(dtype, dtypes.Int32):
        return pa.int32()
    if matches(dtype, dtypes.Int16):
        return pa.int16()
    if matches(dtype, dtypes.Int8):
        return pa.int8()
    if matches(dtype, dtypes.UInt64):
        return pa.uint64()
    if matches(dtype, dtypes.UInt32):
        return pa.uint32()
    if matches(dtype, dtypes.UInt16):
        return pa.uint16()
    if matches(dtype, dtypes.UInt8):
        return pa.uint8()

    if matches(dtype, dtypes.String):
        return pa.string()
    if matches(dtype, dtypes.Boolean):
        return pa.bool_()
    if matches(dtype, dtypes.Categorical):
        # Categoricals become dictionaries with uint32 indices and string values.
        return pa.dictionary(pa.uint32(), pa.string())

    # Temporal types.
    if matches(dtype, dtypes.Datetime):
        time_unit = dtype.time_unit
        time_zone = dtype.time_zone
        # A falsy time zone yields a timestamp built without the tz argument.
        if time_zone:
            return pa.timestamp(time_unit, time_zone)
        return pa.timestamp(time_unit)
    if matches(dtype, dtypes.Duration):
        return pa.duration(dtype.time_unit)
    if matches(dtype, dtypes.Date):
        return pa.date32()

    # Nested types: convert the children recursively.
    if matches(dtype, dtypes.List):
        element_type = narwhals_to_native_dtype(dtype.inner, version=version)
        return pa.list_(value_type=element_type)
    if matches(dtype, dtypes.Struct):
        native_fields = [
            (member.name, narwhals_to_native_dtype(member.dtype, version=version))
            for member in dtype.fields
        ]
        return pa.struct(native_fields)
    if matches(dtype, dtypes.Array):  # pragma: no cover
        element_type = narwhals_to_native_dtype(dtype.inner, version=version)
        return pa.list_(element_type, list_size=dtype.size)

    if matches(dtype, dtypes.Time):
        return pa.time64("ns")
    if matches(dtype, dtypes.Binary):
        return pa.binary()

    msg = f"Unknown dtype: {dtype}"  # pragma: no cover
    raise AssertionError(msg)
def extract_native( lhs: ArrowSeries, rhs: ArrowSeries | PythonLiteral | ScalarAny ) -> tuple[ChunkedArrayAny | ScalarAny, ChunkedArrayAny | ScalarAny]: """Extract native objects in binary operation. If the comparison isn't supported, return `NotImplemented` so that the "right-hand-side" operation (e.g. `__radd__`) can be tried. If one of the two sides has a `_broadcast` flag, then extract the scalar underneath it so that PyArrow can do its own broadcasting. """ from narwhals._arrow.series import ArrowSeries if rhs is None: # pragma: no cover return lhs.native, lit(None, type=lhs._type) if isinstance(rhs, ArrowSeries): if lhs._broadcast and not rhs._broadcast: return lhs.native[0], rhs.native if rhs._broadcast: return lhs.native, rhs.native[0] return lhs.native, rhs.native if isinstance(rhs, list): msg = "Expected Series or scalar, got list." raise TypeError(msg) return lhs.native, rhs if isinstance(rhs, pa.Scalar) else lit(rhs) def floordiv_compat(left: ArrayOrScalar, right: ArrayOrScalar, /) -> Any: # The following lines are adapted from pandas' pyarrow implementation. 
# Ref: https://github.com/pandas-dev/pandas/blob/262fcfbffcee5c3116e86a951d8b693f90411e68/pandas/core/arrays/arrow/array.py#L124-L154 if pa.types.is_integer(left.type) and pa.types.is_integer(right.type): divided = pc.divide_checked(left, right) # TODO @dangotbanned: Use a `TypeVar` in guards # Narrowing to a `Union` isn't interacting well with the rest of the stubs # https://github.com/zen-xu/pyarrow-stubs/pull/215 if pa.types.is_signed_integer(divided.type): div_type = cast("pa._lib.Int64Type", divided.type) has_remainder = pc.not_equal(pc.multiply(divided, right), left) has_one_negative_operand = pc.less( pc.bit_wise_xor(left, right), lit(0, div_type) ) result = pc.if_else( pc.and_(has_remainder, has_one_negative_operand), pc.subtract(divided, lit(1, div_type)), divided, ) else: result = divided # pragma: no cover result = result.cast(left.type) else: divided = pc.divide(left, right) result = pc.floor(divided) return result def cast_for_truediv( arrow_array: ArrayOrScalarT1, pa_object: ArrayOrScalarT2 ) -> tuple[ArrayOrScalarT1, ArrayOrScalarT2]: # Lifted from: # https://github.com/pandas-dev/pandas/blob/262fcfbffcee5c3116e86a951d8b693f90411e68/pandas/core/arrays/arrow/array.py#L108-L122 # Ensure int / int -> float mirroring Python/Numpy behavior # as pc.divide_checked(int, int) -> int if pa.types.is_integer(arrow_array.type) and pa.types.is_integer(pa_object.type): # GH: 56645. # noqa: ERA001 # https://github.com/apache/arrow/issues/35563 return arrow_array.cast(pa.float64(), safe=False), pa_object.cast( pa.float64(), safe=False ) return arrow_array, pa_object # Regex for date, time, separator and timezone components DATE_RE = r"(?P\d{1,4}[-/.]\d{1,2}[-/.]\d{1,4}|\d{8})" SEP_RE = r"(?P\s|T)" TIME_RE = r"(?P