from __future__ import annotations

from functools import lru_cache
from typing import TYPE_CHECKING, Any, Literal, cast

import ibis
import ibis.expr.datatypes as ibis_dtypes

from narwhals._utils import isinstance_or_issubclass

if TYPE_CHECKING:
    from collections.abc import Mapping
    from datetime import timedelta

    import ibis.expr.types as ir
    from ibis.common.temporal import TimestampUnit
    from ibis.expr.datatypes import DataType as IbisDataType
    from typing_extensions import TypeAlias, TypeIs

    from narwhals._duration import IntervalUnit
    from narwhals._ibis.dataframe import IbisLazyFrame
    from narwhals._ibis.expr import IbisExpr
    from narwhals._utils import Version
    from narwhals.dtypes import DType
    from narwhals.typing import IntoDType, PythonLiteral

lit = ibis.literal
"""Alias for `ibis.literal`."""

# Long-form unit names (the values of `UNITS_DICT_BUCKET`).
BucketUnit: TypeAlias = Literal[
    "years",
    "quarters",
    "months",
    "days",
    "hours",
    "minutes",
    "seconds",
    "milliseconds",
    "microseconds",
    "nanoseconds",
]
# Short unit codes (the values of `UNITS_DICT_TRUNCATE`); "W" is part of the
# alias even though no key of `UNITS_DICT_TRUNCATE` maps to it.
TruncateUnit: TypeAlias = Literal[
    "Y", "Q", "M", "W", "D", "h", "m", "s", "ms", "us", "ns"
]

# Interval-unit code -> long-form unit name.
# NOTE(review): presumably these are the keyword names ibis expects when
# bucketing timestamps - confirm at the call sites.
UNITS_DICT_BUCKET: Mapping[IntervalUnit, BucketUnit] = {
    "y": "years",
    "q": "quarters",
    "mo": "months",
    "d": "days",
    "h": "hours",
    "m": "minutes",
    "s": "seconds",
    "ms": "milliseconds",
    "us": "microseconds",
    "ns": "nanoseconds",
}
# Interval-unit code -> short unit code.
# NOTE(review): presumably these are the unit codes ibis expects when
# truncating timestamps - confirm at the call sites.
UNITS_DICT_TRUNCATE: Mapping[IntervalUnit, TruncateUnit] = {
    "y": "Y",
    "q": "Q",
    "mo": "M",
    "d": "D",
    "h": "h",
    "m": "m",
    "s": "s",
    "ms": "ms",
    "us": "us",
    "ns": "ns",
}


def evaluate_exprs(df: IbisLazyFrame, /, *exprs: IbisExpr) -> list[tuple[str, ir.Value]]:
    """Evaluate every expression against `df` and label the native results.

    Each expression is called with `df` to obtain its native values, and its
    output names are computed with `_evaluate_output_names` (then passed
    through `_alias_output_names` when that attribute is not None).

    Returns:
        A flat list of `(output_name, native_value)` pairs, in expression order.

    Raises:
        AssertionError: If an expression yields a different number of names
            than native values.
    """
    labelled: list[tuple[str, ir.Value]] = []
    for expr in exprs:
        values = expr(df)
        names = expr._evaluate_output_names(df)
        if expr._alias_output_names is not None:
            names = expr._alias_output_names(names)
        if len(names) != len(values):  # pragma: no cover
            msg = f"Internal error: got output names {names}, but only got {len(values)} results"
            raise AssertionError(msg)
        labelled += zip(names, values)
    return labelled


# (ibis predicate method, narwhals dtype name) pairs, tried in this order
# before the parametrised dtypes (timestamp, interval, array, struct).
_NATIVE_CHECKS_BEFORE_NESTED: tuple[tuple[str, str], ...] = (
    ("is_int64", "Int64"),
    ("is_int32", "Int32"),
    ("is_int16", "Int16"),
    ("is_int8", "Int8"),
    ("is_uint64", "UInt64"),
    ("is_uint32", "UInt32"),
    ("is_uint16", "UInt16"),
    ("is_uint8", "UInt8"),
    ("is_boolean", "Boolean"),
    ("is_float64", "Float64"),
    ("is_float32", "Float32"),
    ("is_string", "String"),
    ("is_date", "Date"),
)
# Same shape as above; tried only after the parametrised dtypes did not match.
_NATIVE_CHECKS_AFTER_NESTED: tuple[tuple[str, str], ...] = (
    ("is_decimal", "Decimal"),
    ("is_time", "Time"),
    ("is_binary", "Binary"),
)


@lru_cache(maxsize=16)
def native_to_narwhals_dtype(ibis_dtype: IbisDataType, version: Version) -> DType:  # noqa: C901
    """Translate an Ibis data type into the matching Narwhals dtype.

    Results are memoised with an LRU cache of 16 entries keyed on
    `(ibis_dtype, version)`. Array and struct types are converted recursively.
    Anything unrecognised becomes `Unknown`.

    Raises:
        NotImplementedError: For an interval whose unit is not one of
            "ns", "us", "ms", "s".
    """
    dtypes = version.dtypes

    # Attributes are resolved lazily, one entry at a time, so only the checks
    # up to the first match are ever executed.
    for predicate, target in _NATIVE_CHECKS_BEFORE_NESTED:
        if getattr(ibis_dtype, predicate)():
            return getattr(dtypes, target)()

    if is_timestamp(ibis_dtype):
        unit = cast("TimestampUnit", ibis_dtype.unit)
        return dtypes.Datetime(time_unit=unit.value, time_zone=ibis_dtype.timezone)

    if is_interval(ibis_dtype):
        interval_unit = ibis_dtype.unit.value
        if interval_unit in {"ns", "us", "ms", "s"}:
            return dtypes.Duration(interval_unit)
        msg = f"Unsupported interval unit: {interval_unit}"  # pragma: no cover
        raise NotImplementedError(msg)  # pragma: no cover

    if is_array(ibis_dtype):
        fixed_size = ibis_dtype.length
        element = native_to_narwhals_dtype(ibis_dtype.value_type, version)
        # A falsy length means the array is treated as a variable-length list.
        if not fixed_size:
            return dtypes.List(element)
        return dtypes.Array(element, fixed_size)

    if is_struct(ibis_dtype):
        members = [
            dtypes.Field(field_name, native_to_narwhals_dtype(field_type, version))
            for field_name, field_type in ibis_dtype.items()
        ]
        return dtypes.Struct(members)

    for predicate, target in _NATIVE_CHECKS_AFTER_NESTED:
        if getattr(ibis_dtype, predicate)():
            return getattr(dtypes, target)()

    return dtypes.Unknown()  # pragma: no cover


def is_timestamp(obj: IbisDataType) -> TypeIs[ibis_dtypes.Timestamp]:
    """Type-narrowing wrapper around `obj.is_timestamp()`."""
    return obj.is_timestamp()


def is_interval(obj: IbisDataType) -> TypeIs[ibis_dtypes.Interval]:
    """Type-narrowing wrapper around `obj.is_interval()`."""
    return obj.is_interval()


def is_array(obj: IbisDataType) -> TypeIs[ibis_dtypes.Array[Any]]:
    """Type-narrowing wrapper around `obj.is_array()`."""
    return obj.is_array()


def is_struct(obj: IbisDataType) -> TypeIs[ibis_dtypes.Struct]:
    """Type-narrowing wrapper around `obj.is_struct()`."""
    return obj.is_struct()


def is_floating(obj: IbisDataType) -> TypeIs[ibis_dtypes.Floating]:
    """Type-narrowing wrapper around `obj.is_floating()`."""
    return obj.is_floating()


# Rows are (narwhals dtype name, ibis dtype name, error message). Exactly one
# of the last two entries is set: either the Ibis type to instantiate with no
# arguments, or the message of the NotImplementedError to raise. Row order is
# the order in which the checks run and must be kept.
_SCALARS_BEFORE_TEMPORAL: tuple[tuple[str, str | None, str | None], ...] = (
    ("Decimal", "Decimal", None),
    ("Float64", "Float64", None),
    ("Float32", "Float32", None),
    ("Int128", None, "Int128 not supported by Ibis"),
    ("Int64", "Int64", None),
    ("Int32", "Int32", None),
    ("Int16", "Int16", None),
    ("Int8", "Int8", None),
    ("UInt128", None, "UInt128 not supported by Ibis"),
    ("UInt64", "UInt64", None),
    ("UInt32", "UInt32", None),
    ("UInt16", "UInt16", None),
    ("UInt8", "UInt8", None),
    ("String", "String", None),
    ("Boolean", "Boolean", None),
    ("Categorical", None, "Categorical not supported by Ibis"),
)
_SCALARS_AFTER_TEMPORAL: tuple[tuple[str, str | None, str | None], ...] = (
    ("Date", "Date", None),
    ("Time", "Time", None),
)
_SCALARS_AFTER_NESTED: tuple[tuple[str, str | None, str | None], ...] = (
    ("Binary", "Binary", None),
    # Ibis does not support: https://github.com/ibis-project/ibis/issues/10991
    ("Enum", None, "Enum not supported by Ibis"),
)


def _scalar_to_native(
    dtype: IntoDType,
    dtypes: Any,
    table: tuple[tuple[str, str | None, str | None], ...],
) -> IbisDataType | None:
    """Resolve `dtype` against `table`; return None when no row matches.

    Raises:
        NotImplementedError: If the first matching row carries an error message.
    """
    for narwhals_name, ibis_name, error in table:
        if not isinstance_or_issubclass(dtype, getattr(dtypes, narwhals_name)):
            continue
        if error is not None:
            raise NotImplementedError(error)
        return getattr(ibis_dtypes, cast("str", ibis_name))()
    return None


def narwhals_to_native_dtype(dtype: IntoDType, version: Version) -> IbisDataType:
    """Translate a Narwhals dtype (class or instance) into an Ibis data type.

    `List`, `Struct` and `Array` are converted recursively.

    Raises:
        NotImplementedError: For `Int128`, `UInt128`, `Categorical` and `Enum`.
        AssertionError: If `dtype` matches none of the handled dtypes.
    """
    dtypes = version.dtypes

    native = _scalar_to_native(dtype, dtypes, _SCALARS_BEFORE_TEMPORAL)
    if native is not None:
        return native

    if isinstance_or_issubclass(dtype, dtypes.Datetime):
        return ibis_dtypes.Timestamp.from_unit(dtype.time_unit, timezone=dtype.time_zone)
    if isinstance_or_issubclass(dtype, dtypes.Duration):
        return ibis_dtypes.Interval(unit=dtype.time_unit)  # pyright: ignore[reportArgumentType]

    native = _scalar_to_native(dtype, dtypes, _SCALARS_AFTER_TEMPORAL)
    if native is not None:
        return native

    if isinstance_or_issubclass(dtype, dtypes.List):
        element = narwhals_to_native_dtype(dtype.inner, version)
        return ibis_dtypes.Array(value_type=element)
    if isinstance_or_issubclass(dtype, dtypes.Struct):
        members = [
            (member.name, narwhals_to_native_dtype(member.dtype, version))
            for member in dtype.fields
        ]
        return ibis_dtypes.Struct.from_tuples(members)
    if isinstance_or_issubclass(dtype, dtypes.Array):
        element = narwhals_to_native_dtype(dtype.inner, version)
        return ibis_dtypes.Array(value_type=element, length=dtype.size)

    native = _scalar_to_native(dtype, dtypes, _SCALARS_AFTER_NESTED)
    if native is not None:
        return native

    msg = f"Unknown dtype: {dtype}"  # pragma: no cover
    raise AssertionError(msg)


def timedelta_to_ibis_interval(td: timedelta) -> ibis.expr.types.temporal.IntervalScalar:
    """Build an Ibis interval from the days, seconds and microseconds of `td`."""
    parts = {
        "days": td.days,
        "seconds": td.seconds,
        "microseconds": td.microseconds,
    }
    return ibis.interval(**parts)


# SQL-style aggregate name -> (Ibis method name, value of its `how` argument).
_VARIANCE_LIKE: dict[str, tuple[str, str]] = {
    "var_pop": ("var", "pop"),
    "var_samp": ("var", "sample"),
    "stddev_pop": ("std", "pop"),
    "stddev_samp": ("std", "sample"),
}


def function(name: str, *args: ir.Value | PythonLiteral) -> ir.Value:
    """Build the Ibis expression for the function called `name`.

    Dispatch rules, in order:

    - "row_number" ignores `args` and returns `ibis.row_number() + 1`.
    - "least" / "greatest" forward all of `args` to the Ibis function of the
      same name.
    - "var_pop", "var_samp", "stddev_pop" and "stddev_samp" call `var` / `std`
      on the first argument with the matching `how`.
    - Any other name is looked up as a method on the first argument and called
      with the remaining arguments.
    """
    if name == "row_number":
        return ibis.row_number() + 1  # pyright: ignore[reportOperatorIssue]
    if name in {"least", "greatest"}:
        return getattr(ibis, name)(*args)  # pyright: ignore[reportOperatorIssue]

    expr = args[0]
    if name in _VARIANCE_LIKE:
        method, how = _VARIANCE_LIKE[name]
        return getattr(cast("ir.NumericColumn", expr), method)(how=how)
    return getattr(expr, name)(*args[1:])