197 lines
6 KiB
Python
197 lines
6 KiB
Python
"""Schema.
|
|
|
|
Adapted from Polars implementation at:
|
|
https://github.com/pola-rs/polars/blob/main/py-polars/polars/schema.py.
|
|
"""
|
|
|
|
from __future__ import annotations
|
|
|
|
from collections import OrderedDict
|
|
from functools import partial
|
|
from typing import TYPE_CHECKING, cast
|
|
|
|
from narwhals._utils import Implementation, Version
|
|
|
|
if TYPE_CHECKING:
|
|
from collections.abc import Iterable, Mapping
|
|
from typing import Any, ClassVar
|
|
|
|
import polars as pl
|
|
import pyarrow as pa
|
|
|
|
from narwhals.dtypes import DType
|
|
from narwhals.typing import DTypeBackend
|
|
|
|
|
|
__all__ = ["Schema"]
|
|
|
|
|
|
class Schema(OrderedDict[str, "DType"]):
|
|
"""Ordered mapping of column names to their data type.
|
|
|
|
Arguments:
|
|
schema: The schema definition given by column names and their associated
|
|
*instantiated* Narwhals data type. Accepts a mapping or an iterable of tuples.
|
|
|
|
Examples:
|
|
Define a schema by passing *instantiated* data types.
|
|
|
|
>>> import narwhals as nw
|
|
>>> schema = nw.Schema({"foo": nw.Int8(), "bar": nw.String()})
|
|
>>> schema
|
|
Schema({'foo': Int8, 'bar': String})
|
|
|
|
Access the data type associated with a specific column name.
|
|
|
|
>>> schema["foo"]
|
|
Int8
|
|
|
|
Access various schema properties using the `names`, `dtypes`, and `len` methods.
|
|
|
|
>>> schema.names()
|
|
['foo', 'bar']
|
|
>>> schema.dtypes()
|
|
[Int8, String]
|
|
>>> schema.len()
|
|
2
|
|
"""
|
|
|
|
_version: ClassVar[Version] = Version.MAIN
|
|
|
|
def __init__(
|
|
self, schema: Mapping[str, DType] | Iterable[tuple[str, DType]] | None = None
|
|
) -> None:
|
|
schema = schema or {}
|
|
super().__init__(schema)
|
|
|
|
def names(self) -> list[str]:
|
|
"""Get the column names of the schema.
|
|
|
|
Returns:
|
|
Column names.
|
|
"""
|
|
return list(self.keys())
|
|
|
|
def dtypes(self) -> list[DType]:
|
|
"""Get the data types of the schema.
|
|
|
|
Returns:
|
|
Data types of schema.
|
|
"""
|
|
return list(self.values())
|
|
|
|
def len(self) -> int:
|
|
"""Get the number of columns in the schema.
|
|
|
|
Returns:
|
|
Number of columns.
|
|
"""
|
|
return len(self)
|
|
|
|
def to_arrow(self) -> pa.Schema:
|
|
"""Convert Schema to a pyarrow Schema.
|
|
|
|
Returns:
|
|
A pyarrow Schema.
|
|
|
|
Examples:
|
|
>>> import narwhals as nw
|
|
>>> schema = nw.Schema({"a": nw.Int64(), "b": nw.Datetime("ns")})
|
|
>>> schema.to_arrow()
|
|
a: int64
|
|
b: timestamp[ns]
|
|
"""
|
|
import pyarrow as pa # ignore-banned-import
|
|
|
|
from narwhals._arrow.utils import narwhals_to_native_dtype
|
|
|
|
return pa.schema(
|
|
(name, narwhals_to_native_dtype(dtype, self._version))
|
|
for name, dtype in self.items()
|
|
)
|
|
|
|
def to_pandas(
|
|
self, dtype_backend: DTypeBackend | Iterable[DTypeBackend] = None
|
|
) -> dict[str, Any]:
|
|
"""Convert Schema to an ordered mapping of column names to their pandas data type.
|
|
|
|
Arguments:
|
|
dtype_backend: Backend(s) used for the native types. When providing more than
|
|
one, the length of the iterable must be equal to the length of the schema.
|
|
|
|
Returns:
|
|
An ordered mapping of column names to their pandas data type.
|
|
|
|
Examples:
|
|
>>> import narwhals as nw
|
|
>>> schema = nw.Schema({"a": nw.Int64(), "b": nw.Datetime("ns")})
|
|
>>> schema.to_pandas()
|
|
{'a': 'int64', 'b': 'datetime64[ns]'}
|
|
|
|
>>> schema.to_pandas("pyarrow")
|
|
{'a': 'Int64[pyarrow]', 'b': 'timestamp[ns][pyarrow]'}
|
|
"""
|
|
from narwhals._pandas_like.utils import narwhals_to_native_dtype
|
|
|
|
to_native_dtype = partial(
|
|
narwhals_to_native_dtype,
|
|
implementation=Implementation.PANDAS,
|
|
version=self._version,
|
|
)
|
|
if dtype_backend is None or isinstance(dtype_backend, str):
|
|
return {
|
|
name: to_native_dtype(dtype=dtype, dtype_backend=dtype_backend)
|
|
for name, dtype in self.items()
|
|
}
|
|
else:
|
|
backends = tuple(dtype_backend)
|
|
if len(backends) != len(self):
|
|
from itertools import chain, islice, repeat
|
|
|
|
n_user, n_actual = len(backends), len(self)
|
|
suggestion = tuple(
|
|
islice(
|
|
chain.from_iterable(islice(repeat(backends), n_actual)), n_actual
|
|
)
|
|
)
|
|
msg = (
|
|
f"Provided {n_user!r} `dtype_backend`(s), but schema contains {n_actual!r} field(s).\n"
|
|
"Hint: instead of\n"
|
|
f" schema.to_pandas({backends})\n"
|
|
"you may want to use\n"
|
|
f" schema.to_pandas({backends[0]})\n"
|
|
f"or\n"
|
|
f" schema.to_pandas({suggestion})"
|
|
)
|
|
raise ValueError(msg)
|
|
return {
|
|
name: to_native_dtype(dtype=dtype, dtype_backend=backend)
|
|
for name, dtype, backend in zip(self.keys(), self.values(), backends)
|
|
}
|
|
|
|
def to_polars(self) -> pl.Schema:
|
|
"""Convert Schema to a polars Schema.
|
|
|
|
Returns:
|
|
A polars Schema or plain dict (prior to polars 1.0).
|
|
|
|
Examples:
|
|
>>> import narwhals as nw
|
|
>>> schema = nw.Schema({"a": nw.Int64(), "b": nw.Datetime("ns")})
|
|
>>> schema.to_polars()
|
|
Schema({'a': Int64, 'b': Datetime(time_unit='ns', time_zone=None)})
|
|
"""
|
|
import polars as pl # ignore-banned-import
|
|
|
|
from narwhals._polars.utils import narwhals_to_native_dtype
|
|
|
|
pl_version = Implementation.POLARS._backend_version()
|
|
schema = (
|
|
(name, narwhals_to_native_dtype(dtype, self._version))
|
|
for name, dtype in self.items()
|
|
)
|
|
return (
|
|
pl.Schema(schema)
|
|
if pl_version >= (1, 0, 0)
|
|
else cast("pl.Schema", dict(schema))
|
|
)
|