from __future__ import annotations

from functools import lru_cache
from typing import TYPE_CHECKING, Any

import duckdb

from narwhals._utils import Version, isinstance_or_issubclass
from narwhals.exceptions import ColumnNotFoundError

if TYPE_CHECKING:
    from collections.abc import Sequence

    from duckdb import DuckDBPyRelation, Expression
    from duckdb.typing import DuckDBPyType

    from narwhals._compliant.typing import CompliantLazyFrameAny
    from narwhals._duckdb.dataframe import DuckDBLazyFrame
    from narwhals._duckdb.expr import DuckDBExpr
    from narwhals.dtypes import DType
    from narwhals.typing import IntoDType

UNITS_DICT = {
    "y": "year",
    "q": "quarter",
    "mo": "month",
    "d": "day",
    "h": "hour",
    "m": "minute",
    "s": "second",
    "ms": "millisecond",
    "us": "microsecond",
    "ns": "nanosecond",
}

UNIT_TO_TIMESTAMPS = {
    "s": "TIMESTAMP_S",
    "ms": "TIMESTAMP_MS",
    "us": "TIMESTAMP",
    "ns": "TIMESTAMP_NS",
}

DESCENDING_TO_ORDER = {True: "desc", False: "asc"}
NULLS_LAST_TO_NULLS_POS = {True: "nulls last", False: "nulls first"}

col = duckdb.ColumnExpression
"""Alias for `duckdb.ColumnExpression`."""

lit = duckdb.ConstantExpression
"""Alias for `duckdb.ConstantExpression`."""

when = duckdb.CaseExpression
"""Alias for `duckdb.CaseExpression`."""

F = duckdb.FunctionExpression
"""Alias for `duckdb.FunctionExpression`."""


def concat_str(*exprs: Expression, separator: str = "") -> Expression:
    """Concatenate many strings; NULL inputs are skipped.

    Wraps [concat] and [concat_ws] `FunctionExpression`(s).

    Arguments:
        exprs: Native columns.
        separator: String that will be used to separate the values of each column.

    Returns:
        A new native expression.

    [concat]: https://duckdb.org/docs/stable/sql/functions/char.html#concatstring-
    [concat_ws]: https://duckdb.org/docs/stable/sql/functions/char.html#concat_wsseparator-string-
    """
    return F("concat_ws", lit(separator), *exprs) if separator else F("concat", *exprs)


def evaluate_exprs(
    df: DuckDBLazyFrame, /, *exprs: DuckDBExpr
) -> list[tuple[str, Expression]]:
    native_results: list[tuple[str, Expression]] = []
    for expr in exprs:
        native_series_list = expr._call(df)
        output_names = expr._evaluate_output_names(df)
        if expr._alias_output_names is not None:
            output_names = expr._alias_output_names(output_names)
        if len(output_names) != len(native_series_list):  # pragma: no cover
            msg = f"Internal error: got output names {output_names}, but only got {len(native_series_list)} results"
            raise AssertionError(msg)
        native_results.extend(zip(output_names, native_series_list))
    return native_results
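

# Example usage of the aliases and `concat_str` (a minimal sketch; `rel` and the
# column names below are illustrative only, not part of this module):
#
#     rel = duckdb.sql("select 'Jane' as fname, 'Doe' as lname")
#     expr = concat_str(col("fname"), col("lname"), separator=" ")
#     rel.select(expr.alias("full_name"))  # -> 'Jane Doe'
#
# With the default empty separator, plain `concat` is used instead of `concat_ws`.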


class DeferredTimeZone:
    """Object which gets passed between `native_to_narwhals_dtype` calls.

    DuckDB stores the time zone in the connection, rather than in the dtypes, so
    this ensures that, when calculating the schema of a dataframe with multiple
    timezone-aware columns, the connection's time zone is only fetched once.

    Note: we cannot make the time zone a cached `DuckDBLazyFrame` property because
    the time zone can be modified after `DuckDBLazyFrame` creation:

    ```python
    df = nw.from_native(rel)
    print(df.collect_schema())
    rel.query("set timezone = 'Asia/Kolkata'")
    print(df.collect_schema())  # should change to reflect new time zone
    ```
    """

    _cached_time_zone: str | None = None

    def __init__(self, rel: DuckDBPyRelation) -> None:
        self._rel = rel

    @property
    def time_zone(self) -> str:
        """Fetch relation time zone (if it wasn't calculated already)."""
        if self._cached_time_zone is None:
            self._cached_time_zone = fetch_rel_time_zone(self._rel)
        return self._cached_time_zone


def native_to_narwhals_dtype(
    duckdb_dtype: DuckDBPyType, version: Version, deferred_time_zone: DeferredTimeZone
) -> DType:
    duckdb_dtype_id = duckdb_dtype.id
    dtypes = version.dtypes

    # Handle nested data types first
    if duckdb_dtype_id == "list":
        return dtypes.List(
            native_to_narwhals_dtype(duckdb_dtype.child, version, deferred_time_zone)
        )

    if duckdb_dtype_id == "struct":
        children = duckdb_dtype.children
        return dtypes.Struct(
            [
                dtypes.Field(
                    name=child[0],
                    dtype=native_to_narwhals_dtype(child[1], version, deferred_time_zone),
                )
                for child in children
            ]
        )

    if duckdb_dtype_id == "array":
        child, size = duckdb_dtype.children
        shape: list[int] = [size[1]]

        while child[1].id == "array":
            child, size = child[1].children
            shape.insert(0, size[1])

        inner = native_to_narwhals_dtype(child[1], version, deferred_time_zone)
        return dtypes.Array(inner=inner, shape=tuple(shape))

    if duckdb_dtype_id == "enum":
        if version is Version.V1:
            return dtypes.Enum()  # type: ignore[call-arg]
        categories = duckdb_dtype.children[0][1]
        return dtypes.Enum(categories=categories)

    if duckdb_dtype_id == "timestamp with time zone":
        return dtypes.Datetime(time_zone=deferred_time_zone.time_zone)

    return _non_nested_native_to_narwhals_dtype(duckdb_dtype_id, version)


def fetch_rel_time_zone(rel: duckdb.DuckDBPyRelation) -> str:
    result = rel.query(
        "duckdb_settings()", "select value from duckdb_settings() where name = 'TimeZone'"
    ).fetchone()
    assert result is not None  # noqa: S101
    return result[0]  # type: ignore[no-any-return]


@lru_cache(maxsize=16)
def _non_nested_native_to_narwhals_dtype(duckdb_dtype_id: str, version: Version) -> DType:
    dtypes = version.dtypes
    return {
        "hugeint": dtypes.Int128(),
        "bigint": dtypes.Int64(),
        "integer": dtypes.Int32(),
        "smallint": dtypes.Int16(),
        "tinyint": dtypes.Int8(),
        "uhugeint": dtypes.UInt128(),
        "ubigint": dtypes.UInt64(),
        "uinteger": dtypes.UInt32(),
        "usmallint": dtypes.UInt16(),
        "utinyint": dtypes.UInt8(),
        "double": dtypes.Float64(),
        "float": dtypes.Float32(),
        "varchar": dtypes.String(),
        "date": dtypes.Date(),
        "timestamp_s": dtypes.Datetime("s"),
        "timestamp_ms": dtypes.Datetime("ms"),
        "timestamp": dtypes.Datetime(),
        "timestamp_ns": dtypes.Datetime("ns"),
        "boolean": dtypes.Boolean(),
        "interval": dtypes.Duration(),
        "decimal": dtypes.Decimal(),
        "time": dtypes.Time(),
        "blob": dtypes.Binary(),
    }.get(duckdb_dtype_id, dtypes.Unknown())
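

# Example of converting a relation's native dtypes (a minimal sketch; `rel` is an
# illustrative ad-hoc relation, and `Version.MAIN` is assumed to be the non-V1
# member of the `Version` enum):
#
#     rel = duckdb.sql("select 1::BIGINT as a, now() as b")
#     tz = DeferredTimeZone(rel)
#     [native_to_narwhals_dtype(t, Version.MAIN, tz) for t in rel.types]
#     # -> [Int64, Datetime(time_unit='us', time_zone=<connection time zone>)]
#
# The connection's time zone is fetched at most once per `DeferredTimeZone`,
# however many timezone-aware columns appear in the schema.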


def narwhals_to_native_dtype(  # noqa: PLR0912,PLR0915,C901
    dtype: IntoDType, version: Version, deferred_time_zone: DeferredTimeZone
) -> str:
    dtypes = version.dtypes
    if isinstance_or_issubclass(dtype, dtypes.Decimal):
        msg = "Casting to Decimal is not supported yet."
        raise NotImplementedError(msg)
    if isinstance_or_issubclass(dtype, dtypes.Float64):
        return "DOUBLE"
    if isinstance_or_issubclass(dtype, dtypes.Float32):
        return "FLOAT"
    if isinstance_or_issubclass(dtype, dtypes.Int128):
        return "INT128"
    if isinstance_or_issubclass(dtype, dtypes.Int64):
        return "BIGINT"
    if isinstance_or_issubclass(dtype, dtypes.Int32):
        return "INTEGER"
    if isinstance_or_issubclass(dtype, dtypes.Int16):
        return "SMALLINT"
    if isinstance_or_issubclass(dtype, dtypes.Int8):
        return "TINYINT"
    if isinstance_or_issubclass(dtype, dtypes.UInt128):
        return "UINT128"
    if isinstance_or_issubclass(dtype, dtypes.UInt64):
        return "UBIGINT"
    if isinstance_or_issubclass(dtype, dtypes.UInt32):
        return "UINTEGER"
    if isinstance_or_issubclass(dtype, dtypes.UInt16):  # pragma: no cover
        return "USMALLINT"
    if isinstance_or_issubclass(dtype, dtypes.UInt8):  # pragma: no cover
        return "UTINYINT"
    if isinstance_or_issubclass(dtype, dtypes.String):
        return "VARCHAR"
    if isinstance_or_issubclass(dtype, dtypes.Boolean):  # pragma: no cover
        return "BOOLEAN"
    if isinstance_or_issubclass(dtype, dtypes.Time):
        return "TIME"
    if isinstance_or_issubclass(dtype, dtypes.Binary):
        return "BLOB"
    if isinstance_or_issubclass(dtype, dtypes.Categorical):
        msg = "Categorical not supported by DuckDB"
        raise NotImplementedError(msg)
    if isinstance_or_issubclass(dtype, dtypes.Enum):
        if version is Version.V1:
            msg = "Converting to Enum is not supported in narwhals.stable.v1"
            raise NotImplementedError(msg)
        if isinstance(dtype, dtypes.Enum):
            categories = "'" + "', '".join(dtype.categories) + "'"
            return f"ENUM ({categories})"
        msg = "Can not cast / initialize Enum without categories present"
        raise ValueError(msg)
    if isinstance_or_issubclass(dtype, dtypes.Datetime):
        tu = dtype.time_unit
        tz = dtype.time_zone
        if not tz:
            return UNIT_TO_TIMESTAMPS[tu]
        if tu != "us":
            msg = f"Only microsecond precision is supported for timezone-aware `Datetime` in DuckDB, got {tu} precision"
            raise ValueError(msg)
        if tz != (rel_tz := deferred_time_zone.time_zone):  # pragma: no cover
            msg = f"Only the connection time zone {rel_tz} is supported, got: {tz}."
            raise ValueError(msg)
        # TODO(unassigned): cover once https://github.com/narwhals-dev/narwhals/issues/2742 addressed
        return "TIMESTAMPTZ"  # pragma: no cover
    if isinstance_or_issubclass(dtype, dtypes.Duration):
        if (tu := dtype.time_unit) != "us":  # pragma: no cover
            msg = f"Only microsecond-precision Duration is supported, got {tu} precision"
            raise ValueError(msg)
        return "INTERVAL"
    if isinstance_or_issubclass(dtype, dtypes.Date):
        return "DATE"
    if isinstance_or_issubclass(dtype, dtypes.List):
        inner = narwhals_to_native_dtype(dtype.inner, version, deferred_time_zone)
        return f"{inner}[]"
    if isinstance_or_issubclass(dtype, dtypes.Struct):
        inner = ", ".join(
            f'"{field.name}" {narwhals_to_native_dtype(field.dtype, version, deferred_time_zone)}'
            for field in dtype.fields
        )
        return f"STRUCT({inner})"
    if isinstance_or_issubclass(dtype, dtypes.Array):
        shape = dtype.shape
        duckdb_shape_fmt = "".join(f"[{item}]" for item in shape)
        inner_dtype: Any = dtype
        for _ in shape:
            inner_dtype = inner_dtype.inner
        duckdb_inner = narwhals_to_native_dtype(inner_dtype, version, deferred_time_zone)
        return f"{duckdb_inner}{duckdb_shape_fmt}"
    msg = f"Unknown dtype: {dtype}"  # pragma: no cover
    raise AssertionError(msg)
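

# Example of the narwhals -> DuckDB direction (a minimal sketch; `nw` stands for
# the public `narwhals` namespace, and the expected strings follow the branches
# above):
#
#     narwhals_to_native_dtype(nw.Int64(), version, deferred_time_zone)             # -> "BIGINT"
#     narwhals_to_native_dtype(nw.List(nw.Float32()), version, deferred_time_zone)  # -> "FLOAT[]"
#     narwhals_to_native_dtype(nw.Array(nw.Int8(), 2), version, deferred_time_zone) # -> "TINYINT[2]"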


def parse_into_expression(into_expression: str | Expression) -> Expression:
    return col(into_expression) if isinstance(into_expression, str) else into_expression


def generate_partition_by_sql(*partition_by: str | Expression) -> str:
    if not partition_by:
        return ""
    by_sql = ", ".join([f"{parse_into_expression(x)}" for x in partition_by])
    return f"partition by {by_sql}"


def generate_order_by_sql(
    *order_by: str | Expression, descending: Sequence[bool], nulls_last: Sequence[bool]
) -> str:
    if not order_by:
        return ""
    by_sql = ",".join(
        f"{parse_into_expression(x)} {DESCENDING_TO_ORDER[_descending]} {NULLS_LAST_TO_NULLS_POS[_nulls_last]}"
        for x, _descending, _nulls_last in zip(order_by, descending, nulls_last)
    )
    return f"order by {by_sql}"


def window_expression(
    expr: Expression,
    partition_by: Sequence[str | Expression] = (),
    order_by: Sequence[str | Expression] = (),
    rows_start: int | None = None,
    rows_end: int | None = None,
    *,
    descending: Sequence[bool] | None = None,
    nulls_last: Sequence[bool] | None = None,
    ignore_nulls: bool = False,
) -> Expression:
    # TODO(unassigned): Replace with `duckdb.WindowExpression` when they release it.
    # https://github.com/duckdb/duckdb/discussions/14725#discussioncomment-11200348
    try:
        from duckdb import SQLExpression
    except ImportError as exc:  # pragma: no cover
        msg = f"DuckDB>=1.3.0 is required for this operation. Found: DuckDB {duckdb.__version__}"
        raise NotImplementedError(msg) from exc
    pb = generate_partition_by_sql(*partition_by)
    descending = descending or [False] * len(order_by)
    nulls_last = nulls_last or [False] * len(order_by)
    ob = generate_order_by_sql(*order_by, descending=descending, nulls_last=nulls_last)

    if rows_start is not None and rows_end is not None:
        rows = f"rows between {-rows_start} preceding and {rows_end} following"
    elif rows_end is not None:
        rows = f"rows between unbounded preceding and {rows_end} following"
    elif rows_start is not None:
        rows = f"rows between {-rows_start} preceding and unbounded following"
    else:
        rows = ""

    func = f"{str(expr).removesuffix(')')} ignore nulls)" if ignore_nulls else str(expr)
    return SQLExpression(f"{func} over ({pb} {ob} {rows})")
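

# Example of building a windowed expression (a minimal sketch, assuming
# DuckDB>=1.3.0 so that `duckdb.SQLExpression` is available; the column names
# are illustrative):
#
#     expr = window_expression(
#         F("sum", col("price")),
#         partition_by=["store"],
#         order_by=["date"],
#         descending=[False],
#         nulls_last=[True],
#     )
#     # roughly: sum(price) over (partition by store order by date asc nulls last)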


def catch_duckdb_exception(
    exception: Exception, frame: CompliantLazyFrameAny, /
) -> ColumnNotFoundError | Exception:
    if isinstance(exception, duckdb.BinderException) and any(
        msg in str(exception)
        for msg in (
            "not found in FROM clause",
            "this column cannot be referenced before it is defined",
        )
    ):
        return ColumnNotFoundError.from_available_column_names(
            available_columns=frame.columns
        )
    # Just return exception as-is.
    return exception


def function(name: str, *args: Expression) -> Expression:
    if name == "isnull":
        return args[0].isnull()
    return F(name, *args)
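

# Example of the `function` helper (a minimal sketch): most names are forwarded
# straight to `duckdb.FunctionExpression`, while "isnull" is routed to the
# expression's `.isnull()` method instead:
#
#     function("abs", col("a"))     # equivalent to F("abs", col("a"))
#     function("isnull", col("a"))  # equivalent to col("a").isnull()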