# Copyright (c) Streamlit Inc. (2018-2022) Snowflake Inc. (2022-2025)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""A bunch of useful utilities for dealing with dataframes."""

from __future__ import annotations

import contextlib
import dataclasses
import inspect
import math
import re
from collections import ChainMap, UserDict, UserList, deque
from collections.abc import ItemsView, Iterable, Mapping, Sequence
from enum import Enum, EnumMeta, auto
from types import MappingProxyType
from typing import (
    TYPE_CHECKING,
    Any,
    Final,
    Protocol,
    TypeVar,
    Union,
    cast,
    runtime_checkable,
)

from typing_extensions import TypeAlias, TypeGuard

from streamlit import config, errors, logger, string_util
from streamlit.type_util import (
    CustomDict,
    has_callable_attr,
    is_custom_dict,
    is_dataclass_instance,
    is_list_like,
    is_namedtuple,
    is_pydantic_model,
    is_type,
    is_version_less_than,
)

if TYPE_CHECKING:
    import numpy as np
    import pyarrow as pa
    from pandas import DataFrame, Index, Series
    from pandas.core.indexing import _iLocIndexer
    from pandas.io.formats.style import Styler

_LOGGER: Final = logger.get_logger(__name__)


# Maximum number of rows to request from an unevaluated (out-of-core) dataframe.
_MAX_UNEVALUATED_DF_ROWS = 10000

_PANDAS_DATA_OBJECT_TYPE_RE: Final = re.compile(r"^pandas.*$")

_DASK_DATAFRAME: Final = "dask.dataframe.dask_expr._collection.DataFrame"
_DASK_SERIES: Final = "dask.dataframe.dask_expr._collection.Series"
_DASK_INDEX: Final = "dask.dataframe.dask_expr._collection.Index"
# Dask removed its legacy types. To support both older and newer Dask
# versions, we still check for the old and the new types.
_DASK_DATAFRAME_LEGACY: Final = "dask.dataframe.core.DataFrame"
_DASK_SERIES_LEGACY: Final = "dask.dataframe.core.Series"
_DASK_INDEX_LEGACY: Final = "dask.dataframe.core.Index"
_DUCKDB_RELATION: Final = "duckdb.duckdb.DuckDBPyRelation"
_MODIN_DF_TYPE_STR: Final = "modin.pandas.dataframe.DataFrame"
_MODIN_SERIES_TYPE_STR: Final = "modin.pandas.series.Series"
_PANDAS_STYLER_TYPE_STR: Final = "pandas.io.formats.style.Styler"
_POLARS_DATAFRAME: Final = "polars.dataframe.frame.DataFrame"
_POLARS_LAZYFRAME: Final = "polars.lazyframe.frame.LazyFrame"
_POLARS_SERIES: Final = "polars.series.series.Series"
_PYSPARK_DF_TYPE_STR: Final = "pyspark.sql.dataframe.DataFrame"
_PYSPARK_CONNECT_DF_TYPE_STR: Final = "pyspark.sql.connect.dataframe.DataFrame"
_RAY_DATASET: Final = "ray.data.dataset.Dataset"
_RAY_MATERIALIZED_DATASET: Final = "ray.data.dataset.MaterializedDataset"
_SNOWPANDAS_DF_TYPE_STR: Final = "snowflake.snowpark.modin.pandas.dataframe.DataFrame"
_SNOWPANDAS_INDEX_TYPE_STR: Final = (
    "snowflake.snowpark.modin.plugin.extensions.index.Index"
)
_SNOWPANDAS_SERIES_TYPE_STR: Final = "snowflake.snowpark.modin.pandas.series.Series"
_SNOWPARK_DF_ROW_TYPE_STR: Final = "snowflake.snowpark.row.Row"
_SNOWPARK_DF_TYPE_STR: Final = "snowflake.snowpark.dataframe.DataFrame"
_SNOWPARK_TABLE_TYPE_STR: Final = "snowflake.snowpark.table.Table"
_XARRAY_DATASET_TYPE_STR: Final = "xarray.core.dataset.Dataset"
_XARRAY_DATA_ARRAY_TYPE_STR: Final = "xarray.core.dataarray.DataArray"

V_co = TypeVar(
    "V_co",
    covariant=True,  # https://peps.python.org/pep-0484/#covariance-and-contravariance
)


@runtime_checkable
class DBAPICursor(Protocol):
    """Protocol for DBAPI 2.0 Cursor objects (PEP 249).

    This is a simplified version of the DBAPI Cursor protocol
    that only contains the methods that are relevant to or used by
    our DB API integration.

    Specification: https://peps.python.org/pep-0249/
    Inspired by: https://github.com/python/typeshed/blob/main/stdlib/_typeshed/dbapi.pyi
    """

    @property
    def description(
        self,
    ) -> (
        Sequence[
            tuple[
                str,
                Any | None,
                int | None,
                int | None,
                int | None,
                int | None,
                bool | None,
            ]
        ]
        | None
    ): ...
    def fetchmany(self, size: int = ..., /) -> Sequence[Sequence[Any]]: ...
    def fetchall(self) -> Sequence[Sequence[Any]]: ...
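
# Illustrative sketch (not executed as part of this module): because this is a
# runtime-checkable Protocol, isinstance() only verifies that the listed
# attributes exist. The stdlib sqlite3 cursor provides description, fetchmany,
# and fetchall, so it passes the check used by `is_dbapi_cursor` below:
#
#     import sqlite3
#
#     cursor = sqlite3.connect(":memory:").execute("SELECT 1 AS one")
#     assert isinstance(cursor, DBAPICursor)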


class DataFrameGenericAlias(Protocol[V_co]):
    """Technically not a GenericAlias, but serves the same purpose in
    OptionSequence below, in that it is a type which admits DataFrame,
    but is generic. This allows OptionSequence to be a fully generic type,
    significantly increasing its usefulness.

    We can't use types.GenericAlias, as it is only available from python>=3.9,
    and isn't easily back-ported.
    """

    @property
    def iloc(self) -> _iLocIndexer: ...


class PandasCompatible(Protocol):
    """Protocol for Pandas-compatible objects that have a `to_pandas` method."""

    def to_pandas(self) -> DataFrame | Series: ...


class DataframeInterchangeCompatible(Protocol):
    """Protocol for objects that support the dataframe-interchange protocol.

    https://data-apis.org/dataframe-protocol/latest/index.html
    """

    def __dataframe__(self, allow_copy: bool) -> Any: ...


OptionSequence: TypeAlias = Union[
    Iterable[V_co],
    DataFrameGenericAlias[V_co],
    PandasCompatible,
    DataframeInterchangeCompatible,
]

# Various data types supported by our dataframe processing,
# used for commands like `st.dataframe`, `st.table`, `st.map`,
# `st.line_chart`, ...
Data: TypeAlias = Union[
    "DataFrame",
    "Series",
    "Styler",
    "Index",
    "pa.Table",
    "pa.Array",
    "np.ndarray[Any, np.dtype[Any]]",
    Iterable[Any],
    "Mapping[Any, Any]",
    DBAPICursor,
    PandasCompatible,
    DataframeInterchangeCompatible,
    CustomDict,
    None,
]


class DataFormat(Enum):
    """DataFormat is used to determine the format of the data."""

    UNKNOWN = auto()
    EMPTY = auto()  # None

    COLUMN_INDEX_MAPPING = auto()  # {column: {index: value}}
    COLUMN_SERIES_MAPPING = auto()  # {column: Series(values)}
    COLUMN_VALUE_MAPPING = auto()  # {column: List[values]}
    DASK_OBJECT = auto()  # dask.dataframe.core.DataFrame, Series, Index
    DBAPI_CURSOR = auto()  # DBAPI Cursor (PEP 249)
    DUCKDB_RELATION = auto()  # DuckDB Relation
    KEY_VALUE_DICT = auto()  # {index: value}
    LIST_OF_RECORDS = auto()  # List[Dict[str, Scalar]]
    LIST_OF_ROWS = auto()  # List[List[Scalar]]
    LIST_OF_VALUES = auto()  # List[Scalar]
    MODIN_OBJECT = auto()  # Modin DataFrame, Series
    NUMPY_LIST = auto()  # np.array[Scalar]
    NUMPY_MATRIX = auto()  # np.array[List[Scalar]]
    PANDAS_ARRAY = auto()  # pd.array
    PANDAS_DATAFRAME = auto()  # pd.DataFrame
    PANDAS_INDEX = auto()  # pd.Index
    PANDAS_SERIES = auto()  # pd.Series
    PANDAS_STYLER = auto()  # pandas Styler
    POLARS_DATAFRAME = auto()  # polars.dataframe.frame.DataFrame
    POLARS_LAZYFRAME = auto()  # polars.lazyframe.frame.LazyFrame
    POLARS_SERIES = auto()  # polars.series.series.Series
    PYARROW_ARRAY = auto()  # pyarrow.Array
    PYARROW_TABLE = auto()  # pyarrow.Table
    PYSPARK_OBJECT = auto()  # pyspark.DataFrame
    RAY_DATASET = auto()  # ray.data.dataset.Dataset, MaterializedDataset
    SET_OF_VALUES = auto()  # Set[Scalar]
    SNOWPANDAS_OBJECT = auto()  # Snowpandas DataFrame, Series
    SNOWPARK_OBJECT = auto()  # Snowpark DataFrame, Table, List[Row]
    TUPLE_OF_VALUES = auto()  # Tuple[Scalar]
    XARRAY_DATASET = auto()  # xarray.Dataset
    XARRAY_DATA_ARRAY = auto()  # xarray.DataArray


def is_pyarrow_version_less_than(v: str) -> bool:
    """Return True if the current Pyarrow version is less than the input version.

    Parameters
    ----------
    v : str
        Version string, e.g. "0.25.0"

    Returns
    -------
    bool

    Raises
    ------
    InvalidVersion
        If the version strings are not valid.

    """
    import pyarrow as pa

    return is_version_less_than(pa.__version__, v)


def is_pandas_version_less_than(v: str) -> bool:
    """Return True if the current Pandas version is less than the input version.

    Parameters
    ----------
    v : str
        Version string, e.g. "0.25.0"

    Returns
    -------
    bool

    Raises
    ------
    InvalidVersion
        If the version strings are not valid.
    """
    import pandas as pd

    return is_version_less_than(pd.__version__, v)


def is_dataframe_like(obj: object) -> bool:
    """True if the object is a dataframe-like object.

    This does not include basic collection types like list, dict, tuple, etc.
    """
    # We don't exclude list and dict in the early return below since there
    # are some cases where a list or dict is considered a dataframe-like object.
    if obj is None or isinstance(obj, (tuple, set, str, bytes, int, float, bool)):
        # Basic types are not considered dataframe-like, so we can
        # return False early to avoid unnecessary checks.
        return False

    return determine_data_format(obj) in {
        DataFormat.COLUMN_SERIES_MAPPING,
        DataFormat.DASK_OBJECT,
        DataFormat.DBAPI_CURSOR,
        DataFormat.MODIN_OBJECT,
        DataFormat.NUMPY_LIST,
        DataFormat.NUMPY_MATRIX,
        DataFormat.PANDAS_ARRAY,
        DataFormat.PANDAS_DATAFRAME,
        DataFormat.PANDAS_INDEX,
        DataFormat.PANDAS_SERIES,
        DataFormat.PANDAS_STYLER,
        DataFormat.POLARS_DATAFRAME,
        DataFormat.POLARS_LAZYFRAME,
        DataFormat.POLARS_SERIES,
        DataFormat.PYARROW_ARRAY,
        DataFormat.PYARROW_TABLE,
        DataFormat.PYSPARK_OBJECT,
        DataFormat.RAY_DATASET,
        DataFormat.SNOWPANDAS_OBJECT,
        DataFormat.SNOWPARK_OBJECT,
        DataFormat.XARRAY_DATASET,
        DataFormat.XARRAY_DATA_ARRAY,
    }


def is_unevaluated_data_object(obj: object) -> bool:
    """True if the object is one of the supported unevaluated data objects.

    Currently supported objects are:
    - Snowpark DataFrame / Table
    - PySpark DataFrame
    - Modin DataFrame / Series
    - Snowpandas DataFrame / Series / Index
    - Dask DataFrame / Series / Index
    - Ray Dataset
    - Polars LazyFrame
    - Generator functions
    - DB API 2.0 Cursor (PEP 249)
    - DuckDB Relation (Relational API)

    Unevaluated means that the data is not yet in local memory.
    Unevaluated data objects are treated differently from other data objects
    by only requesting a subset of the data instead of loading all of it
    into memory.
    """
    return (
        is_snowpark_data_object(obj)
        or is_pyspark_data_object(obj)
        or is_snowpandas_data_object(obj)
        or is_modin_data_object(obj)
        or is_ray_dataset(obj)
        or is_polars_lazyframe(obj)
        or is_dask_object(obj)
        or is_duckdb_relation(obj)
        or is_dbapi_cursor(obj)
        or inspect.isgeneratorfunction(obj)
    )


def is_pandas_data_object(obj: object) -> bool:
    """True if obj is a Pandas object (e.g. DataFrame, Series, Index, Styler, ...)."""
    return is_type(obj, _PANDAS_DATA_OBJECT_TYPE_RE)


def is_snowpark_data_object(obj: object) -> bool:
    """True if obj is a Snowpark DataFrame or Table."""
    return is_type(obj, _SNOWPARK_TABLE_TYPE_STR) or is_type(obj, _SNOWPARK_DF_TYPE_STR)


def is_snowpark_row_list(obj: object) -> bool:
    """True if obj is a list of snowflake.snowpark.row.Row."""
    return (
        isinstance(obj, list)
        and len(obj) > 0
        and is_type(obj[0], _SNOWPARK_DF_ROW_TYPE_STR)
        and has_callable_attr(obj[0], "as_dict")
    )


def is_pyspark_data_object(obj: object) -> bool:
    """True if obj is a PySpark or PySpark Connect dataframe."""
    return (
        is_type(obj, _PYSPARK_DF_TYPE_STR) or is_type(obj, _PYSPARK_CONNECT_DF_TYPE_STR)
    ) and has_callable_attr(obj, "toPandas")


def is_dask_object(obj: object) -> bool:
    """True if obj is a Dask DataFrame, Series, or Index."""
    return (
        is_type(obj, _DASK_DATAFRAME)
        or is_type(obj, _DASK_DATAFRAME_LEGACY)
        or is_type(obj, _DASK_SERIES)
        or is_type(obj, _DASK_SERIES_LEGACY)
        or is_type(obj, _DASK_INDEX)
        or is_type(obj, _DASK_INDEX_LEGACY)
    )


def is_modin_data_object(obj: object) -> bool:
    """True if obj is a Modin DataFrame or Series."""
    return is_type(obj, _MODIN_DF_TYPE_STR) or is_type(obj, _MODIN_SERIES_TYPE_STR)


def is_snowpandas_data_object(obj: object) -> bool:
    """True if obj is a Snowpark Pandas DataFrame, Series, or Index."""
    return (
        is_type(obj, _SNOWPANDAS_DF_TYPE_STR)
        or is_type(obj, _SNOWPANDAS_SERIES_TYPE_STR)
        or is_type(obj, _SNOWPANDAS_INDEX_TYPE_STR)
    )


def is_polars_dataframe(obj: object) -> bool:
    """True if obj is a Polars DataFrame."""
    return is_type(obj, _POLARS_DATAFRAME)


def is_xarray_dataset(obj: object) -> bool:
    """True if obj is an Xarray Dataset."""
    return is_type(obj, _XARRAY_DATASET_TYPE_STR)


def is_xarray_data_array(obj: object) -> bool:
    """True if obj is an Xarray DataArray."""
    return is_type(obj, _XARRAY_DATA_ARRAY_TYPE_STR)


def is_polars_series(obj: object) -> bool:
    """True if obj is a Polars Series."""
    return is_type(obj, _POLARS_SERIES)


def is_polars_lazyframe(obj: object) -> bool:
    """True if obj is a Polars LazyFrame."""
    return is_type(obj, _POLARS_LAZYFRAME)


def is_ray_dataset(obj: object) -> bool:
    """True if obj is a Ray Dataset."""
    return is_type(obj, _RAY_DATASET) or is_type(obj, _RAY_MATERIALIZED_DATASET)


def is_pandas_styler(obj: object) -> TypeGuard[Styler]:
    """True if obj is a pandas Styler."""
    return is_type(obj, _PANDAS_STYLER_TYPE_STR)


def is_dbapi_cursor(obj: object) -> TypeGuard[DBAPICursor]:
    """True if obj looks like a DB API 2.0 Cursor.

    https://peps.python.org/pep-0249/
    """
    return isinstance(obj, DBAPICursor)


def is_duckdb_relation(obj: object) -> bool:
    """True if obj is a DuckDB relation.

    https://duckdb.org/docs/api/python/relational_api
    """
    return is_type(obj, _DUCKDB_RELATION)


def _is_list_of_scalars(data: Iterable[Any]) -> bool:
    """Check if the list only contains scalar values."""
    from pandas.api.types import infer_dtype

    # Overview of all values that are interpreted as scalar:
    # https://pandas.pydata.org/docs/reference/api/pandas.api.types.is_scalar.html
    return infer_dtype(data, skipna=True) not in ["mixed", "unknown-array"]
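
# Illustrative sketch (not executed as part of this module) of how
# pandas.api.types.infer_dtype drives this check:
#
#     _is_list_of_scalars([1, 2, 3])      # infer_dtype -> "integer" -> True
#     _is_list_of_scalars(["a", "b"])     # infer_dtype -> "string"  -> True
#     _is_list_of_scalars([[1], [2, 3]])  # infer_dtype -> "mixed"   -> False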


def _iterable_to_list(
    iterable: Iterable[Any], max_iterations: int | None = None
) -> list[Any]:
    """Convert an iterable to a list.

    Parameters
    ----------
    iterable : Iterable
        The iterable to convert to a list.

    max_iterations : int or None
        The maximum number of iterations to perform. If None, all
        iterations are performed.

    Returns
    -------
    list
        The converted list.
    """
    if max_iterations is None:
        return list(iterable)

    result = []
    for i, item in enumerate(iterable):
        if i >= max_iterations:
            break
        result.append(item)
    return result
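
# Illustrative sketch (not executed as part of this module): capping an
# infinite generator with max_iterations so it can be materialized safely:
#
#     import itertools
#
#     _iterable_to_list(itertools.count(), max_iterations=3)  # -> [0, 1, 2]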


def _fix_column_naming(data_df: DataFrame) -> DataFrame:
    """Rename the first column to "value" if it is not named
    and if there is only one column in the dataframe.

    Pandas uses 0 as the default name for the first unnamed column,
    which is not very descriptive.
    """
    if len(data_df.columns) == 1 and data_df.columns[0] == 0:
        # Pandas automatically names the first column 0 if it is not named.
        # We rename it to "value" to make it more descriptive if there is
        # only one column in the dataframe.
        data_df = data_df.rename(columns={0: "value"})
    return data_df


def _dict_to_pandas_df(data: dict[Any, Any]) -> DataFrame:
    """Convert a key-value dict to a Pandas DataFrame.

    Parameters
    ----------
    data : dict
        The dict to convert to a Pandas DataFrame.

    Returns
    -------
    pandas.DataFrame
        The converted Pandas DataFrame.
    """
    import pandas as pd

    return _fix_column_naming(pd.DataFrame.from_dict(data, orient="index"))
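
# Illustrative sketch (not executed as part of this module): with
# orient="index", the dict keys become the index, and the single value
# column is renamed from 0 to "value" by _fix_column_naming:
#
#     >>> _dict_to_pandas_df({"a": 1, "b": 2})
#        value
#     a      1
#     b      2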


def convert_anything_to_pandas_df(
    data: Any,
    max_unevaluated_rows: int = _MAX_UNEVALUATED_DF_ROWS,
    ensure_copy: bool = False,
) -> DataFrame:
    """Try to convert different formats to a Pandas DataFrame.

    Parameters
    ----------
    data : dataframe-, array-, or collections-like object
        The data to convert to a Pandas DataFrame.

    max_unevaluated_rows : int
        If unevaluated data is detected, this function evaluates it up to
        max_unevaluated_rows rows (defaults to 10k).

    ensure_copy : bool
        If True, make sure to always return a copy of the data. If False, it depends on
        the type of the data. For example, a Pandas DataFrame will be returned as-is.

    Returns
    -------
    pandas.DataFrame

    """
    import array

    import numpy as np
    import pandas as pd

    if isinstance(data, pd.DataFrame):
        return data.copy() if ensure_copy else cast("pd.DataFrame", data)

    if isinstance(data, (pd.Series, pd.Index, pd.api.extensions.ExtensionArray)):
        return pd.DataFrame(data)

    if is_pandas_styler(data):
        return cast("pd.DataFrame", data.data.copy() if ensure_copy else data.data)

    if isinstance(data, np.ndarray):
        return (
            pd.DataFrame([])
            if len(data.shape) == 0
            else _fix_column_naming(pd.DataFrame(data))
        )

    if is_polars_dataframe(data):
        data = data.clone() if ensure_copy else data
        return data.to_pandas()

    if is_polars_series(data):
        data = data.clone() if ensure_copy else data
        return data.to_pandas().to_frame()

    if is_polars_lazyframe(data):
        data = data.limit(max_unevaluated_rows).collect().to_pandas()
        if data.shape[0] == max_unevaluated_rows:
            _show_data_information(
                f"⚠️ Showing only {string_util.simplify_number(max_unevaluated_rows)} "
                "rows. Call `collect()` on the dataframe to show more."
            )
        return cast("pd.DataFrame", data)

    if is_xarray_dataset(data):
        if ensure_copy:
            data = data.copy(deep=True)
        return data.to_dataframe()

    if is_xarray_data_array(data):
        if ensure_copy:
            data = data.copy(deep=True)
        return data.to_series().to_frame()

    if is_dask_object(data):
        data = data.head(max_unevaluated_rows, compute=True)

        # Dask returns a Pandas object (DataFrame, Series, Index) when
        # executing operations like `head`.
        if isinstance(data, (pd.Series, pd.Index)):
            data = data.to_frame()

        if data.shape[0] == max_unevaluated_rows:
            _show_data_information(
                f"⚠️ Showing only {string_util.simplify_number(max_unevaluated_rows)} "
                "rows. Call `compute()` on the data object to show more."
            )
        return cast("pd.DataFrame", data)

    if is_ray_dataset(data):
        data = data.limit(max_unevaluated_rows).to_pandas()

        if data.shape[0] == max_unevaluated_rows:
            _show_data_information(
                f"⚠️ Showing only {string_util.simplify_number(max_unevaluated_rows)} "
                "rows. Call `to_pandas()` on the dataset to show more."
            )
        return cast("pd.DataFrame", data)

    if is_modin_data_object(data):
        data = data.head(max_unevaluated_rows)._to_pandas()

        if isinstance(data, (pd.Series, pd.Index)):
            data = data.to_frame()

        if data.shape[0] == max_unevaluated_rows:
            _show_data_information(
                f"⚠️ Showing only {string_util.simplify_number(max_unevaluated_rows)} "
                "rows. Call `_to_pandas()` on the data object to show more."
            )
        return cast("pd.DataFrame", data)

    if is_pyspark_data_object(data):
        data = data.limit(max_unevaluated_rows).toPandas()
        if data.shape[0] == max_unevaluated_rows:
            _show_data_information(
                f"⚠️ Showing only {string_util.simplify_number(max_unevaluated_rows)} "
                "rows. Call `toPandas()` on the data object to show more."
            )
        return cast("pd.DataFrame", data)

    if is_snowpandas_data_object(data):
        data = data[:max_unevaluated_rows].to_pandas()

        if isinstance(data, (pd.Series, pd.Index)):
            data = data.to_frame()

        if data.shape[0] == max_unevaluated_rows:
            _show_data_information(
                f"⚠️ Showing only {string_util.simplify_number(max_unevaluated_rows)} "
                "rows. Call `to_pandas()` on the data object to show more."
            )
        return cast("pd.DataFrame", data)

    if is_snowpark_data_object(data):
        data = data.limit(max_unevaluated_rows).to_pandas()
        if data.shape[0] == max_unevaluated_rows:
            _show_data_information(
                f"⚠️ Showing only {string_util.simplify_number(max_unevaluated_rows)} "
                "rows. Call `to_pandas()` on the data object to show more."
            )
        return cast("pd.DataFrame", data)

    if is_duckdb_relation(data):
        data = data.limit(max_unevaluated_rows).df()
        if data.shape[0] == max_unevaluated_rows:
            _show_data_information(
                f"⚠️ Showing only {string_util.simplify_number(max_unevaluated_rows)} "
                "rows. Call `df()` on the relation to show more."
            )
        return data

    if is_dbapi_cursor(data):
        # Based on the specification, the first item in the description is the
        # column name (if available).
        columns = (
            [d[0] if d else "" for d in data.description] if data.description else None
        )
        data = pd.DataFrame(data.fetchmany(max_unevaluated_rows), columns=columns)
        if data.shape[0] == max_unevaluated_rows:
            _show_data_information(
                f"⚠️ Showing only {string_util.simplify_number(max_unevaluated_rows)} "
                "rows. Call `fetchall()` on the Cursor to show more."
            )
        return data

    if is_snowpark_row_list(data):
        return pd.DataFrame([row.as_dict() for row in data])

    if has_callable_attr(data, "to_pandas"):
        return pd.DataFrame(data.to_pandas())

    # Check for the dataframe interchange protocol,
    # which is only available in pandas >= 1.5.0:
    # https://pandas.pydata.org/docs/whatsnew/v1.5.0.html#dataframe-interchange-protocol-implementation
    if (
        has_callable_attr(data, "__dataframe__")
        and is_pandas_version_less_than("1.5.0") is False
    ):
        data_df = pd.api.interchange.from_dataframe(data)
        return data_df.copy() if ensure_copy else data_df

    # Support for generator functions
    if inspect.isgeneratorfunction(data):
        data = _fix_column_naming(
            pd.DataFrame(_iterable_to_list(data(), max_iterations=max_unevaluated_rows))
        )

        if data.shape[0] == max_unevaluated_rows:
            _show_data_information(
                f"⚠️ Showing only {string_util.simplify_number(max_unevaluated_rows)} "
                "rows. Convert the data to a list to show more."
            )
        return data

    if isinstance(data, EnumMeta):
        # Support for enum classes
        return _fix_column_naming(pd.DataFrame([c.value for c in data]))  # type: ignore

    # Support for some list-like objects
    if isinstance(data, (deque, map, array.ArrayType, UserList)):
        return _fix_column_naming(pd.DataFrame(list(data)))

    # Support for Streamlit's custom dict-like objects
    if is_custom_dict(data):
        return _dict_to_pandas_df(data.to_dict())

    # Support for named tuples
    if is_namedtuple(data):
        return _dict_to_pandas_df(data._asdict())

    # Support for dataclass instances
    if is_dataclass_instance(data):
        return _dict_to_pandas_df(dataclasses.asdict(data))

    # Support for dict-like objects
    if isinstance(data, (ChainMap, MappingProxyType, UserDict)) or is_pydantic_model(
        data
    ):
        return _dict_to_pandas_df(dict(data))

    # Try to convert to pandas.DataFrame. This will raise an error if the data
    # is not compatible with the pandas.DataFrame constructor.
    try:
        return _fix_column_naming(pd.DataFrame(data))
    except ValueError as ex:
        if isinstance(data, dict):
            with contextlib.suppress(ValueError):
                # Try to use index orient as back-up to support key-value dicts
                return _dict_to_pandas_df(data)
        raise errors.StreamlitAPIException(
            f"""
Unable to convert object of type `{type(data)}` to `pandas.DataFrame`.
Offending object:
```py
{data}
```"""
        ) from ex
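
# Illustrative usage sketch (not executed as part of this module): the
# conversion cascade above accepts a wide range of inputs, e.g.:
#
#     convert_anything_to_pandas_df([{"a": 1}, {"a": 2}])  # list of records
#     convert_anything_to_pandas_df({"col": [1, 2, 3]})    # column mapping
#     convert_anything_to_pandas_df(None)                  # empty DataFrame
#
# Unevaluated inputs (Snowpark, Dask, Polars LazyFrame, ...) are evaluated up
# to max_unevaluated_rows rows, and a caption informs the user if truncation
# happened.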


def convert_arrow_table_to_arrow_bytes(table: pa.Table) -> bytes:
    """Serialize pyarrow.Table to Arrow IPC bytes.

    Parameters
    ----------
    table : pyarrow.Table
        A table to convert.

    Returns
    -------
    bytes
        The serialized Arrow IPC bytes.
    """
    try:
        table = _maybe_truncate_table(table)
    except RecursionError as err:
        # This is a very unlikely edge case, but we want to make sure that
        # it doesn't lead to unexpected behavior.
        # If there is a recursion error, we just return the table as-is,
        # which will lead to the normal message-limit-exceeded error.
        _LOGGER.warning(
            "Recursion error while truncating Arrow table. This is not "
            "supposed to happen.",
            exc_info=err,
        )

    import pyarrow as pa

    # Convert table to bytes
    sink = pa.BufferOutputStream()
    writer = pa.RecordBatchStreamWriter(sink, table.schema)
    writer.write_table(table)
    writer.close()
    return cast("bytes", sink.getvalue().to_pybytes())


def convert_pandas_df_to_arrow_bytes(df: DataFrame) -> bytes:
    """Serialize pandas.DataFrame to Arrow IPC bytes.

    Parameters
    ----------
    df : pandas.DataFrame
        A dataframe to convert.

    Returns
    -------
    bytes
        The serialized Arrow IPC bytes.
    """
    import pyarrow as pa

    try:
        table = pa.Table.from_pandas(df)
    except (pa.ArrowTypeError, pa.ArrowInvalid, pa.ArrowNotImplementedError) as ex:
        _LOGGER.info(
            "Serialization of dataframe to Arrow table was unsuccessful. "
            "Applying automatic fixes for column types to make the dataframe "
            "Arrow-compatible.",
            exc_info=ex,
        )
        df = fix_arrow_incompatible_column_types(df)
        table = pa.Table.from_pandas(df)
    return convert_arrow_table_to_arrow_bytes(table)


def convert_arrow_bytes_to_pandas_df(source: bytes) -> DataFrame:
    """Convert Arrow bytes (IPC format) to pandas.DataFrame.

    When using this function in production, make sure that pyarrow >= 14.0.1
    is installed, since versions below 14.0.1 contain a critical security
    vulnerability.

    Parameters
    ----------
    source : bytes
        A bytes object to convert.

    Returns
    -------
    pandas.DataFrame
        The converted dataframe.
    """
    import pyarrow as pa

    reader = pa.RecordBatchStreamReader(source)
    return reader.read_pandas()
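
# Illustrative round-trip sketch (not executed as part of this module):
# the two functions above are inverses over the Arrow IPC stream format:
#
#     import pandas as pd
#
#     df = pd.DataFrame({"a": [1, 2]})
#     ipc_bytes = convert_pandas_df_to_arrow_bytes(df)
#     restored = convert_arrow_bytes_to_pandas_df(ipc_bytes)
#     assert restored.equals(df)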


def _show_data_information(msg: str) -> None:
    """Show a message to the user with important information
    about the processed dataset.
    """
    from streamlit.delta_generator_singletons import get_dg_singleton_instance

    get_dg_singleton_instance().main_dg.caption(msg)


def convert_anything_to_arrow_bytes(
    data: Any,
    max_unevaluated_rows: int = _MAX_UNEVALUATED_DF_ROWS,
) -> bytes:
    """Try to convert different formats to Arrow IPC format (bytes).

    This method tries to directly convert the input data to Arrow bytes
    for some supported formats, but falls back to converting to a Pandas
    DataFrame and then to Arrow bytes.

    Parameters
    ----------
    data : dataframe-, array-, or collections-like object
        The data to convert to Arrow bytes.

    max_unevaluated_rows : int
        If unevaluated data is detected, this function evaluates it up to
        max_unevaluated_rows rows (defaults to 10k).

    Returns
    -------
    bytes
        The serialized Arrow IPC bytes.
    """
    import pyarrow as pa

    if isinstance(data, pa.Table):
        return convert_arrow_table_to_arrow_bytes(data)

    # TODO(lukasmasuch): Add direct conversion to Arrow for supported formats here

    # Fallback: try to convert to a pandas DataFrame
    # and then to Arrow bytes.
    df = convert_anything_to_pandas_df(data, max_unevaluated_rows)
    return convert_pandas_df_to_arrow_bytes(df)


def convert_anything_to_list(obj: OptionSequence[V_co]) -> list[V_co]:
    """Try to convert different formats to a list.

    If the input is a dataframe-like object, we just select the first
    column to iterate over. Non-sequence-like objects and scalar types
    will just be wrapped into a list.

    Parameters
    ----------
    obj : dataframe-, array-, or collections-like object
        The object to convert to a list.

    Returns
    -------
    list
        The converted list.
    """
    if obj is None:
        return []  # type: ignore

    if isinstance(obj, (str, int, float, bool)):
        # Wrap basic objects into a list
        return [obj]  # type: ignore[list-item]

    if isinstance(obj, EnumMeta):
        # Support for enum classes. For string enums, we return the string value
        # of the enum members. For other enums, we just return the enum member.
        return [member.value if isinstance(member, str) else member for member in obj]  # type: ignore

    if isinstance(obj, Mapping):
        return list(obj.keys())

    if is_list_like(obj) and not is_snowpark_row_list(obj):
        # This also ensures that the sequence is copied to prevent
        # potential mutations to the original object.
        return list(obj)

    # Fallback to our DataFrame conversion logic:
    try:
        # We use ensure_copy here because the return value of this function is
        # saved in a widget serde class instance to be used in later script runs,
        # and we don't want mutations to the options object passed to a
        # widget to affect the widget.
        # (See https://github.com/streamlit/streamlit/issues/7534)
        data_df = convert_anything_to_pandas_df(obj, ensure_copy=True)
        # Return the first column as a list:
        return (
            []
            if data_df.empty
            else cast("list[V_co]", list(data_df.iloc[:, 0].to_list()))
        )
    except errors.StreamlitAPIException:
        # Wrap the object into a list
        return [obj]  # type: ignore
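
# Illustrative sketch (not executed as part of this module):
#
#     convert_anything_to_list("foo")             # -> ["foo"] (scalar wrapped)
#     convert_anything_to_list({"a": 1, "b": 2})  # -> ["a", "b"] (dict keys)
#     convert_anything_to_list(range(3))          # -> [0, 1, 2] (copied sequence)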


def _maybe_truncate_table(
    table: pa.Table, truncated_rows: int | None = None
) -> pa.Table:
    """Experimental feature to automatically truncate tables that
    are larger than the maximum allowed message size. It needs to be enabled
    via the server.enableArrowTruncation config option.

    Parameters
    ----------
    table : pyarrow.Table
        A table to truncate.

    truncated_rows : int or None
        The number of rows that have been truncated so far. This is used by
        the recursion logic to keep track of the total number of truncated
        rows.

    """
    if config.get_option("server.enableArrowTruncation"):
        # This is an optimization problem: We don't know at what row
        # the perfect cut-off is to comply with the max size. But we want to figure
        # it out in as few iterations as possible. We almost always will cut out
        # more than required to keep the iterations low.

        # The maximum size allowed for protobuf messages in bytes:
        max_message_size = int(config.get_option("server.maxMessageSize") * 1e6)
        # We add 1 MB for other overhead related to the protobuf message.
        # This is a very conservative estimate, but it should be good enough.
        table_size = int(table.nbytes + 1 * 1e6)
        table_rows = table.num_rows

        if table_rows > 1 and table_size > max_message_size:
            # targeted_rows == the number of rows the table should be truncated to.
            # Calculate an approximation of how many rows we need to truncate to.
            targeted_rows = math.ceil(table_rows * (max_message_size / table_size))
            # Make sure to cut out at least a couple of rows to avoid running
            # this logic too often since it is quite inefficient and could lead
            # to infinite recursion without these precautions.
            targeted_rows = math.floor(
                max(
                    min(
                        # Cut out:
                        # an additional 5% of the estimated num rows to cut out:
                        targeted_rows - math.floor((table_rows - targeted_rows) * 0.05),
                        # at least 1% of table size:
                        table_rows - (table_rows * 0.01),
                        # at least 5 rows:
                        table_rows - 5,
                    ),
                    1,  # but it should always have at least 1 row
                )
            )
            sliced_table = table.slice(0, targeted_rows)
            return _maybe_truncate_table(
                sliced_table, (truncated_rows or 0) + (table_rows - targeted_rows)
            )

    if truncated_rows:
        displayed_rows = string_util.simplify_number(table.num_rows)
        total_rows = string_util.simplify_number(table.num_rows + truncated_rows)

        if displayed_rows == total_rows:
            # If the simplified numbers are the same,
            # we just display the exact numbers.
            displayed_rows = str(table.num_rows)
            total_rows = str(table.num_rows + truncated_rows)
        _show_data_information(
            f"⚠️ Showing {displayed_rows} out of {total_rows} "
            "rows due to data size limitations."
        )

    return table
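
# Worked example (illustrative numbers): with server.maxMessageSize = 200 (MB),
# a table of 1,000,000 rows and roughly 300 MB would first be estimated at
# ceil(1_000_000 * 200 / 300) = 666_667 rows; the 5% safety margin then cuts an
# additional floor(333_333 * 0.05) = 16_666 rows, so the table is sliced to
# 650_001 rows and the function recurses on the smaller slice.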


def is_colum_type_arrow_incompatible(column: Series[Any] | Index) -> bool:
    """Return True if the column type is known to cause issues during
    Arrow conversion.
    """
    from pandas.api.types import infer_dtype, is_dict_like, is_list_like

    if column.dtype.kind in [
        "c",  # complex64, complex128, complex256
    ]:
        return True

    if str(column.dtype) in {
        # These period types are not yet supported by our frontend impl.
        # See comments in Quiver.ts for more details.
        "period[B]",
        "period[N]",
        "period[ns]",
        "period[U]",
        "period[us]",
        "geometry",
    }:
        return True

    if column.dtype == "object":
        # The dtype of mixed-type columns is always object; the actual type of
        # the column values can be determined via the infer_dtype function:
        # https://pandas.pydata.org/docs/reference/api/pandas.api.types.infer_dtype.html
        inferred_type = infer_dtype(column, skipna=True)

        if inferred_type in [
            "mixed-integer",
            "complex",
        ]:
            return True
        if inferred_type == "mixed":
            # This includes most of the more complex/custom types (objects, dicts,
            # lists, ...)
            if len(column) == 0 or not hasattr(column, "iloc"):
                # The column seems to be invalid, so we assume it is incompatible.
                # But this would most likely never happen since empty columns
                # cannot be mixed.
                return True

            # Get the first value to check if it is a supported list-like type.
            first_value = column.iloc[0]

            if (  # noqa: SIM103
                not is_list_like(first_value)
                # dicts are list-like, but have issues in Arrow JS (see comments in
                # Quiver.ts)
                or is_dict_like(first_value)
                # Frozensets are list-like, but are not compatible with pyarrow.
                or isinstance(first_value, frozenset)
            ):
                # This seems to be an incompatible list-like type
                return True
            return False
    # We did not detect an incompatible type, so we assume it is compatible:
    return False


def fix_arrow_incompatible_column_types(
    df: DataFrame, selected_columns: list[str] | None = None
) -> DataFrame:
    """Fix column types that are not supported by Arrow.

    This includes mixed types (e.g. a mix of integers and strings)
    as well as complex numbers (complex128 type). These types will cause
    errors during conversion of the dataframe to an Arrow table.
    They are fixed by converting all values of the column to strings.
    This is sufficient for displaying the data on the frontend.

    Parameters
    ----------
    df : pandas.DataFrame
        A dataframe to fix.

    selected_columns : List[str] or None
        A list of columns to fix. If None, all columns are evaluated.

    Returns
    -------
    The fixed dataframe.
    """
    import pandas as pd

    # Make a copy, but only initialize if necessary to preserve memory.
    df_copy: DataFrame | None = None
    for col in selected_columns or df.columns:
        if is_colum_type_arrow_incompatible(df[col]):
            if df_copy is None:
                df_copy = df.copy()
            df_copy[col] = df[col].astype("string")

    # The index can also contain mixed types,
    # causing Arrow issues during conversion.
    # Skipping multi-indices since they won't return
    # the correct value from infer_dtype.
    if not selected_columns and (
        not isinstance(
            df.index,
            pd.MultiIndex,
        )
        and is_colum_type_arrow_incompatible(df.index)
    ):
        if df_copy is None:
            df_copy = df.copy()
        df_copy.index = df.index.astype("string")
    return df_copy if df_copy is not None else df
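
# Illustrative sketch (not executed as part of this module): a column mixing
# ints and strings has dtype "object" and would fail Arrow serialization, so
# it is cast to pandas' "string" dtype:
#
#     import pandas as pd
#
#     df = pd.DataFrame({"mixed": [1, "two", 3.0], "ok": [1, 2, 3]})
#     fixed = fix_arrow_incompatible_column_types(df)
#     assert str(fixed["mixed"].dtype) == "string"
#     assert str(fixed["ok"].dtype) == "int64"  # untouched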


def determine_data_format(input_data: Any) -> DataFormat:
    """Determine the data format of the input data.

    Parameters
    ----------
    input_data : Any
        The input data to determine the data format of.

    Returns
    -------
    DataFormat
        The data format of the input data.
    """
    import numpy as np
    import pandas as pd
    import pyarrow as pa

    if input_data is None:
        return DataFormat.EMPTY
    if isinstance(input_data, pd.DataFrame):
        return DataFormat.PANDAS_DATAFRAME
    if isinstance(input_data, np.ndarray):
        if len(input_data.shape) == 1:
            # For technical reasons, we need to distinguish
            # one-dimensional numpy arrays from multidimensional ones.
            return DataFormat.NUMPY_LIST
        return DataFormat.NUMPY_MATRIX
    if isinstance(input_data, pa.Table):
        return DataFormat.PYARROW_TABLE
    if isinstance(input_data, pa.Array):
        return DataFormat.PYARROW_ARRAY
    if isinstance(input_data, pd.Series):
        return DataFormat.PANDAS_SERIES
    if isinstance(input_data, pd.Index):
        return DataFormat.PANDAS_INDEX
    if is_pandas_styler(input_data):
        return DataFormat.PANDAS_STYLER
    if isinstance(input_data, pd.api.extensions.ExtensionArray):
        return DataFormat.PANDAS_ARRAY
    if is_polars_series(input_data):
        return DataFormat.POLARS_SERIES
    if is_polars_dataframe(input_data):
        return DataFormat.POLARS_DATAFRAME
    if is_polars_lazyframe(input_data):
        return DataFormat.POLARS_LAZYFRAME
    if is_modin_data_object(input_data):
        return DataFormat.MODIN_OBJECT
    if is_snowpandas_data_object(input_data):
        return DataFormat.SNOWPANDAS_OBJECT
    if is_pyspark_data_object(input_data):
        return DataFormat.PYSPARK_OBJECT
    if is_xarray_dataset(input_data):
        return DataFormat.XARRAY_DATASET
    if is_xarray_data_array(input_data):
        return DataFormat.XARRAY_DATA_ARRAY
    if is_ray_dataset(input_data):
        return DataFormat.RAY_DATASET
    if is_dask_object(input_data):
        return DataFormat.DASK_OBJECT
    if is_snowpark_data_object(input_data) or is_snowpark_row_list(input_data):
        return DataFormat.SNOWPARK_OBJECT
    if is_duckdb_relation(input_data):
        return DataFormat.DUCKDB_RELATION
    if is_dbapi_cursor(input_data):
        return DataFormat.DBAPI_CURSOR
    if (
        isinstance(
            input_data,
            (ChainMap, UserDict, MappingProxyType),
        )
        or is_dataclass_instance(input_data)
        or is_namedtuple(input_data)
        or is_custom_dict(input_data)
        or is_pydantic_model(input_data)
    ):
        return DataFormat.KEY_VALUE_DICT
    if isinstance(input_data, (ItemsView, enumerate)):
        return DataFormat.LIST_OF_ROWS
    if isinstance(input_data, (list, tuple, set, frozenset)):
        if _is_list_of_scalars(input_data):
            # -> one-dimensional data structure
            if isinstance(input_data, tuple):
                return DataFormat.TUPLE_OF_VALUES
            if isinstance(input_data, (set, frozenset)):
                return DataFormat.SET_OF_VALUES
            return DataFormat.LIST_OF_VALUES
        # -> multi-dimensional data structure.
        # This should always contain at least one element; otherwise,
        # infer_dtype (used in _is_list_of_scalars) would have classified
        # the input as empty.
        first_element = next(iter(input_data))
        if isinstance(first_element, dict):
            return DataFormat.LIST_OF_RECORDS
        if isinstance(first_element, (list, tuple, set, frozenset)):
            return DataFormat.LIST_OF_ROWS
    elif isinstance(input_data, (dict, Mapping)):
        if not input_data:
            return DataFormat.KEY_VALUE_DICT
        if len(input_data) > 0:
            first_value = next(iter(input_data.values()))
            # In the future, we could potentially also support tight & split formats
            if isinstance(first_value, dict):
                return DataFormat.COLUMN_INDEX_MAPPING
            if isinstance(first_value, (list, tuple)):
                return DataFormat.COLUMN_VALUE_MAPPING
            if isinstance(first_value, pd.Series):
                return DataFormat.COLUMN_SERIES_MAPPING
        # Use key-value dict as fallback. However, if the values of the dict
        # contain mixed types, it will become non-editable in the frontend.
        return DataFormat.KEY_VALUE_DICT
    elif is_list_like(input_data):
        return DataFormat.LIST_OF_VALUES

    return DataFormat.UNKNOWN
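
# Illustrative sketch (not executed as part of this module):
#
#     determine_data_format(pd.DataFrame())        # -> DataFormat.PANDAS_DATAFRAME
#     determine_data_format([1, 2, 3])             # -> DataFormat.LIST_OF_VALUES
#     determine_data_format([{"a": 1}, {"a": 2}])  # -> DataFormat.LIST_OF_RECORDS
#     determine_data_format({"a": 1, "b": 2})      # -> DataFormat.KEY_VALUE_DICT
#     determine_data_format({"col": [1, 2]})       # -> DataFormat.COLUMN_VALUE_MAPPING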


def _unify_missing_values(df: DataFrame) -> DataFrame:
    """Unify all missing values in a DataFrame to None.

    Pandas uses a variety of values to represent missing values, including np.nan,
    NaT, None, and pd.NA. This function replaces all of these values with None,
    which is the only missing value type that is supported by all data formats.
    """
    import numpy as np
    import pandas as pd

    # Replace all recognized nulls (np.nan, pd.NA, NaT) with None,
    # then infer objects without creating a separate copy.
    # For performance reasons, we could use copy=False here.
    # However, this is only available in pandas >= 2.
    return df.replace([pd.NA, pd.NaT, np.nan], None).infer_objects()
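
# Illustrative sketch (not executed as part of this module): for an object
# column like ["x", np.nan, pd.NA], the call above yields ["x", None, None],
# so the downstream to_dict()/tolist() conversions below emit None instead of
# a mix of NaN/NA sentinels.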


def _pandas_df_to_series(df: DataFrame) -> Series[Any]:
    """Convert a Pandas DataFrame to a Pandas Series by selecting the first column.

    Raises
    ------
    ValueError
        If the DataFrame has more than one column.
    """
    # Select the first column in the dataframe and create a new series
    # based on its values:
    if len(df.columns) != 1:
        raise ValueError(
            f"DataFrame is expected to have a single column but has {len(df.columns)}."
        )
    return df[df.columns[0]]


def convert_pandas_df_to_data_format(
    df: DataFrame, data_format: DataFormat
) -> (
    DataFrame
    | Series[Any]
    | pa.Table
    | pa.Array
    | np.ndarray[Any, np.dtype[Any]]
    | tuple[Any]
    | list[Any]
    | set[Any]
    | dict[str, Any]
):
    """Convert a Pandas DataFrame to the specified data format.

    Parameters
    ----------
    df : pd.DataFrame
        The dataframe to convert.

    data_format : DataFormat
        The data format to convert to.

    Returns
    -------
    pd.DataFrame, pd.Series, pyarrow.Table, np.ndarray, xarray.Dataset,
    xarray.DataArray, polars.DataFrame, polars.Series, list, set, tuple, or dict.
        The converted dataframe.
    """
    if data_format in {
        DataFormat.EMPTY,
        DataFormat.DASK_OBJECT,
        DataFormat.DBAPI_CURSOR,
        DataFormat.DUCKDB_RELATION,
        DataFormat.MODIN_OBJECT,
        DataFormat.PANDAS_ARRAY,
        DataFormat.PANDAS_DATAFRAME,
        DataFormat.PANDAS_INDEX,
        DataFormat.PANDAS_STYLER,
        DataFormat.PYSPARK_OBJECT,
        DataFormat.RAY_DATASET,
        DataFormat.SNOWPANDAS_OBJECT,
        DataFormat.SNOWPARK_OBJECT,
    }:
        return df
    if data_format == DataFormat.NUMPY_LIST:
        import numpy as np

        # It's a 1-dimensional array, so we only return
        # the first column as a numpy array.
        # Calling to_numpy() on the full DataFrame would result in:
        # [[1], [2]] instead of [1, 2]
        return np.ndarray(0) if df.empty else df.iloc[:, 0].to_numpy()
    if data_format == DataFormat.NUMPY_MATRIX:
        import numpy as np

        return np.ndarray(0) if df.empty else df.to_numpy()
    if data_format == DataFormat.PYARROW_TABLE:
        import pyarrow as pa

        return pa.Table.from_pandas(df)
    if data_format == DataFormat.PYARROW_ARRAY:
        import pyarrow as pa

        return pa.Array.from_pandas(_pandas_df_to_series(df))
    if data_format == DataFormat.PANDAS_SERIES:
        return _pandas_df_to_series(df)
    if data_format in {DataFormat.POLARS_DATAFRAME, DataFormat.POLARS_LAZYFRAME}:
        import polars as pl  # type: ignore[import-not-found]

        return pl.from_pandas(df)
    if data_format == DataFormat.POLARS_SERIES:
        import polars as pl

        return pl.from_pandas(_pandas_df_to_series(df))
    if data_format == DataFormat.XARRAY_DATASET:
        import xarray as xr  # type: ignore[import-not-found]

        return xr.Dataset.from_dataframe(df)
    if data_format == DataFormat.XARRAY_DATA_ARRAY:
        import xarray as xr

        return xr.DataArray.from_series(_pandas_df_to_series(df))
    if data_format == DataFormat.LIST_OF_RECORDS:
        return _unify_missing_values(df).to_dict(orient="records")
    if data_format == DataFormat.LIST_OF_ROWS:
        # to_numpy converts the dataframe to a list of rows
        return _unify_missing_values(df).to_numpy().tolist()
    if data_format == DataFormat.COLUMN_INDEX_MAPPING:
        return _unify_missing_values(df).to_dict(orient="dict")
    if data_format == DataFormat.COLUMN_VALUE_MAPPING:
        return _unify_missing_values(df).to_dict(orient="list")
    if data_format == DataFormat.COLUMN_SERIES_MAPPING:
        return df.to_dict(orient="series")
    if data_format in [
        DataFormat.LIST_OF_VALUES,
        DataFormat.TUPLE_OF_VALUES,
        DataFormat.SET_OF_VALUES,
    ]:
        df = _unify_missing_values(df)
        return_list = []
        if len(df.columns) == 1:
            # Get the first column and convert it to a list:
            return_list = df[df.columns[0]].tolist()
        elif len(df.columns) >= 1:
            raise ValueError(
                "DataFrame is expected to have a single column but "
                f"has {len(df.columns)}."
            )
        if data_format == DataFormat.TUPLE_OF_VALUES:
            return tuple(return_list)
        if data_format == DataFormat.SET_OF_VALUES:
            return set(return_list)
        return return_list
    if data_format == DataFormat.KEY_VALUE_DICT:
        df = _unify_missing_values(df)
        # The key is expected to be the index -> this will return the first
        # column as a dict with the index as its keys.
        return {} if df.empty else df.iloc[:, 0].to_dict()

    raise ValueError(f"Unsupported input data format: {data_format}")
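
# Illustrative round-trip sketch (not executed as part of this module):
# `determine_data_format` and this function are designed to be used together,
# so data edited as a DataFrame can be handed back in its original format:
#
#     original = {"col": [1, 2]}
#     data_format = determine_data_format(original)      # COLUMN_VALUE_MAPPING
#     df = convert_anything_to_pandas_df(original)
#     convert_pandas_df_to_data_format(df, data_format)  # -> {"col": [1, 2]}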