# Copyright (c) Streamlit Inc. (2018-2022) Snowflake Inc. (2022-2025)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""A bunch of useful utilities for dealing with dataframes."""
from __future__ import annotations
import contextlib
import dataclasses
import inspect
import math
import re
from collections import ChainMap, UserDict, UserList, deque
from collections.abc import ItemsView, Iterable, Mapping, Sequence
from enum import Enum, EnumMeta, auto
from types import MappingProxyType
from typing import (
TYPE_CHECKING,
Any,
Final,
Protocol,
TypeVar,
Union,
cast,
runtime_checkable,
)
from typing_extensions import TypeAlias, TypeGuard
from streamlit import config, errors, logger, string_util
from streamlit.type_util import (
CustomDict,
has_callable_attr,
is_custom_dict,
is_dataclass_instance,
is_list_like,
is_namedtuple,
is_pydantic_model,
is_type,
is_version_less_than,
)
if TYPE_CHECKING:
import numpy as np
import pyarrow as pa
from pandas import DataFrame, Index, Series
from pandas.core.indexing import _iLocIndexer
from pandas.io.formats.style import Styler
_LOGGER: Final = logger.get_logger(__name__)
# Maximum number of rows to request from an unevaluated (out-of-core) dataframe
_MAX_UNEVALUATED_DF_ROWS = 10000
_PANDAS_DATA_OBJECT_TYPE_RE: Final = re.compile(r"^pandas.*$")
_DASK_DATAFRAME: Final = "dask.dataframe.dask_expr._collection.DataFrame"
_DASK_SERIES: Final = "dask.dataframe.dask_expr._collection.Series"
_DASK_INDEX: Final = "dask.dataframe.dask_expr._collection.Index"
# Dask removed the old legacy types. To support both older and newer Dask
# versions, we still check for the old and the new types.
_DASK_DATAFRAME_LEGACY: Final = "dask.dataframe.core.DataFrame"
_DASK_SERIES_LEGACY: Final = "dask.dataframe.core.Series"
_DASK_INDEX_LEGACY: Final = "dask.dataframe.core.Index"
_DUCKDB_RELATION: Final = "duckdb.duckdb.DuckDBPyRelation"
_MODIN_DF_TYPE_STR: Final = "modin.pandas.dataframe.DataFrame"
_MODIN_SERIES_TYPE_STR: Final = "modin.pandas.series.Series"
_PANDAS_STYLER_TYPE_STR: Final = "pandas.io.formats.style.Styler"
_POLARS_DATAFRAME: Final = "polars.dataframe.frame.DataFrame"
_POLARS_LAZYFRAME: Final = "polars.lazyframe.frame.LazyFrame"
_POLARS_SERIES: Final = "polars.series.series.Series"
_PYSPARK_DF_TYPE_STR: Final = "pyspark.sql.dataframe.DataFrame"
_PYSPARK_CONNECT_DF_TYPE_STR: Final = "pyspark.sql.connect.dataframe.DataFrame"
_RAY_DATASET: Final = "ray.data.dataset.Dataset"
_RAY_MATERIALIZED_DATASET: Final = "ray.data.dataset.MaterializedDataset"
_SNOWPANDAS_DF_TYPE_STR: Final = "snowflake.snowpark.modin.pandas.dataframe.DataFrame"
_SNOWPANDAS_INDEX_TYPE_STR: Final = (
"snowflake.snowpark.modin.plugin.extensions.index.Index"
)
_SNOWPANDAS_SERIES_TYPE_STR: Final = "snowflake.snowpark.modin.pandas.series.Series"
_SNOWPARK_DF_ROW_TYPE_STR: Final = "snowflake.snowpark.row.Row"
_SNOWPARK_DF_TYPE_STR: Final = "snowflake.snowpark.dataframe.DataFrame"
_SNOWPARK_TABLE_TYPE_STR: Final = "snowflake.snowpark.table.Table"
_XARRAY_DATASET_TYPE_STR: Final = "xarray.core.dataset.Dataset"
_XARRAY_DATA_ARRAY_TYPE_STR: Final = "xarray.core.dataarray.DataArray"
V_co = TypeVar(
"V_co",
covariant=True, # https://peps.python.org/pep-0484/#covariance-and-contravariance
)
@runtime_checkable
class DBAPICursor(Protocol):
"""Protocol for DBAPI 2.0 Cursor objects (PEP 249).
This is a simplified version of the DBAPI Cursor protocol
that only contains the methods that are relevant or used for
our DB API Integration.
Specification: https://peps.python.org/pep-0249/
Inspired by: https://github.com/python/typeshed/blob/main/stdlib/_typeshed/dbapi.pyi
"""
@property
def description(
self,
) -> (
Sequence[
tuple[
str,
Any | None,
int | None,
int | None,
int | None,
int | None,
bool | None,
]
]
| None
): ...
def fetchmany(self, size: int = ..., /) -> Sequence[Sequence[Any]]: ...
def fetchall(self) -> Sequence[Sequence[Any]]: ...
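
# Illustrative sketch (hypothetical example helper, not called anywhere): any
# DBAPI 2.0 cursor that exposes `description`, `fetchmany`, and `fetchall`
# matches the runtime-checkable protocol above, e.g. a cursor from the stdlib
# sqlite3 module.
def _example_dbapi_cursor_protocol() -> None:
    import sqlite3

    conn = sqlite3.connect(":memory:")
    cursor = conn.execute("SELECT 1 AS value")
    # Structural check only: isinstance verifies that the three protocol
    # members exist, not their signatures.
    assert isinstance(cursor, DBAPICursor)
    assert cursor.description is not None
    assert cursor.description[0][0] == "value"
    conn.close()
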
class DataFrameGenericAlias(Protocol[V_co]):
"""Technically not a GenericAlias, but serves the same purpose in
OptionSequence below, in that it is a type which admits DataFrame,
but is generic. This allows OptionSequence to be a fully generic type,
significantly increasing its usefulness.
We can't use types.GenericAlias, as it is only available from python>=3.9,
and isn't easily back-ported.
"""
@property
def iloc(self) -> _iLocIndexer: ...
class PandasCompatible(Protocol):
"""Protocol for Pandas compatible objects that have a `to_pandas` method."""
def to_pandas(self) -> DataFrame | Series: ...
class DataframeInterchangeCompatible(Protocol):
"""Protocol for objects support the dataframe-interchange protocol.
https://data-apis.org/dataframe-protocol/latest/index.html
"""
def __dataframe__(self, allow_copy: bool) -> Any: ...
OptionSequence: TypeAlias = Union[
Iterable[V_co],
DataFrameGenericAlias[V_co],
PandasCompatible,
DataframeInterchangeCompatible,
]
# Various data types supported by our dataframe processing,
# used for commands like `st.dataframe`, `st.table`, `st.map`,
# `st.line_chart`, ...
Data: TypeAlias = Union[
"DataFrame",
"Series",
"Styler",
"Index",
"pa.Table",
"pa.Array",
"np.ndarray[Any, np.dtype[Any]]",
Iterable[Any],
"Mapping[Any, Any]",
DBAPICursor,
PandasCompatible,
DataframeInterchangeCompatible,
CustomDict,
None,
]
class DataFormat(Enum):
"""DataFormat is used to determine the format of the data."""
UNKNOWN = auto()
EMPTY = auto() # None
COLUMN_INDEX_MAPPING = auto() # {column: {index: value}}
COLUMN_SERIES_MAPPING = auto() # {column: Series(values)}
COLUMN_VALUE_MAPPING = auto() # {column: List[values]}
DASK_OBJECT = auto() # dask.dataframe.core.DataFrame, Series, Index
DBAPI_CURSOR = auto() # DBAPI Cursor (PEP 249)
DUCKDB_RELATION = auto() # DuckDB Relation
KEY_VALUE_DICT = auto() # {index: value}
LIST_OF_RECORDS = auto() # List[Dict[str, Scalar]]
LIST_OF_ROWS = auto() # List[List[Scalar]]
LIST_OF_VALUES = auto() # List[Scalar]
MODIN_OBJECT = auto() # Modin DataFrame, Series
NUMPY_LIST = auto() # np.array[Scalar]
NUMPY_MATRIX = auto() # np.array[List[Scalar]]
PANDAS_ARRAY = auto() # pd.array
PANDAS_DATAFRAME = auto() # pd.DataFrame
PANDAS_INDEX = auto() # pd.Index
PANDAS_SERIES = auto() # pd.Series
PANDAS_STYLER = auto() # pandas Styler
POLARS_DATAFRAME = auto() # polars.dataframe.frame.DataFrame
POLARS_LAZYFRAME = auto() # polars.lazyframe.frame.LazyFrame
POLARS_SERIES = auto() # polars.series.series.Series
PYARROW_ARRAY = auto() # pyarrow.Array
PYARROW_TABLE = auto() # pyarrow.Table
PYSPARK_OBJECT = auto() # pyspark.DataFrame
RAY_DATASET = auto() # ray.data.dataset.Dataset, MaterializedDataset
SET_OF_VALUES = auto() # Set[Scalar]
SNOWPANDAS_OBJECT = auto() # Snowpandas DataFrame, Series
SNOWPARK_OBJECT = auto() # Snowpark DataFrame, Table, List[Row]
TUPLE_OF_VALUES = auto() # Tuple[Scalar]
XARRAY_DATASET = auto() # xarray.Dataset
XARRAY_DATA_ARRAY = auto() # xarray.DataArray
def is_pyarrow_version_less_than(v: str) -> bool:
"""Return True if the current Pyarrow version is less than the input version.
Parameters
----------
v : str
Version string, e.g. "0.25.0"
Returns
-------
bool
Raises
------
InvalidVersion
If the version strings are not valid.
"""
import pyarrow as pa
return is_version_less_than(pa.__version__, v)
def is_pandas_version_less_than(v: str) -> bool:
"""Return True if the current Pandas version is less than the input version.
Parameters
----------
v : str
Version string, e.g. "0.25.0"
Returns
-------
bool
Raises
------
InvalidVersion
If the version strings are not valid.
"""
import pandas as pd
return is_version_less_than(pd.__version__, v)
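
# Illustrative sketch (hypothetical, uncalled example): the version helpers
# above are typically used to guard code paths that depend on newer library
# features, e.g. the dataframe interchange protocol added in pandas 1.5.0.
def _example_version_guard() -> bool:
    return not is_pandas_version_less_than("1.5.0")
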
def is_dataframe_like(obj: object) -> bool:
"""True if the object is a dataframe-like object.
This does not include basic collection types like list, dict, tuple, etc.
"""
    # We don't short-circuit on list and dict here since there are some cases
    # where a list or dict is considered a dataframe-like object.
if obj is None or isinstance(obj, (tuple, set, str, bytes, int, float, bool)):
# Basic types are not considered dataframe-like, so we can
# return False early to avoid unnecessary checks.
return False
return determine_data_format(obj) in {
DataFormat.COLUMN_SERIES_MAPPING,
DataFormat.DASK_OBJECT,
DataFormat.DBAPI_CURSOR,
DataFormat.MODIN_OBJECT,
DataFormat.NUMPY_LIST,
DataFormat.NUMPY_MATRIX,
DataFormat.PANDAS_ARRAY,
DataFormat.PANDAS_DATAFRAME,
DataFormat.PANDAS_INDEX,
DataFormat.PANDAS_SERIES,
DataFormat.PANDAS_STYLER,
DataFormat.POLARS_DATAFRAME,
DataFormat.POLARS_LAZYFRAME,
DataFormat.POLARS_SERIES,
DataFormat.PYARROW_ARRAY,
DataFormat.PYARROW_TABLE,
DataFormat.PYSPARK_OBJECT,
DataFormat.RAY_DATASET,
DataFormat.SNOWPANDAS_OBJECT,
DataFormat.SNOWPARK_OBJECT,
DataFormat.XARRAY_DATASET,
DataFormat.XARRAY_DATA_ARRAY,
}
def is_unevaluated_data_object(obj: object) -> bool:
"""True if the object is one of the supported unevaluated data objects.
Currently supported objects are:
- Snowpark DataFrame / Table
- PySpark DataFrame
- Modin DataFrame / Series
- Snowpandas DataFrame / Series / Index
- Dask DataFrame / Series / Index
- Ray Dataset
- Polars LazyFrame
- Generator functions
- DB API 2.0 Cursor (PEP 249)
- DuckDB Relation (Relational API)
    Unevaluated means that the data is not yet in local memory.
    Unevaluated data objects are treated differently from other data objects
    by only requesting a subset of the data instead of loading all of it into memory.
"""
return (
is_snowpark_data_object(obj)
or is_pyspark_data_object(obj)
or is_snowpandas_data_object(obj)
or is_modin_data_object(obj)
or is_ray_dataset(obj)
or is_polars_lazyframe(obj)
or is_dask_object(obj)
or is_duckdb_relation(obj)
or is_dbapi_cursor(obj)
or inspect.isgeneratorfunction(obj)
)
def is_pandas_data_object(obj: object) -> bool:
"""True if obj is a Pandas object (e.g. DataFrame, Series, Index, Styler, ...)."""
return is_type(obj, _PANDAS_DATA_OBJECT_TYPE_RE)
def is_snowpark_data_object(obj: object) -> bool:
"""True if obj is a Snowpark DataFrame or Table."""
return is_type(obj, _SNOWPARK_TABLE_TYPE_STR) or is_type(obj, _SNOWPARK_DF_TYPE_STR)
def is_snowpark_row_list(obj: object) -> bool:
"""True if obj is a list of snowflake.snowpark.row.Row."""
return (
isinstance(obj, list)
and len(obj) > 0
and is_type(obj[0], _SNOWPARK_DF_ROW_TYPE_STR)
and has_callable_attr(obj[0], "as_dict")
)
def is_pyspark_data_object(obj: object) -> bool:
"""True if obj is a PySpark or PySpark Connect dataframe."""
return (
is_type(obj, _PYSPARK_DF_TYPE_STR) or is_type(obj, _PYSPARK_CONNECT_DF_TYPE_STR)
) and has_callable_attr(obj, "toPandas")
def is_dask_object(obj: object) -> bool:
"""True if obj is a Dask DataFrame, Series, or Index."""
return (
is_type(obj, _DASK_DATAFRAME)
or is_type(obj, _DASK_DATAFRAME_LEGACY)
or is_type(obj, _DASK_SERIES)
or is_type(obj, _DASK_SERIES_LEGACY)
or is_type(obj, _DASK_INDEX)
or is_type(obj, _DASK_INDEX_LEGACY)
)
def is_modin_data_object(obj: object) -> bool:
"""True if obj is of Modin Dataframe or Series."""
return is_type(obj, _MODIN_DF_TYPE_STR) or is_type(obj, _MODIN_SERIES_TYPE_STR)
def is_snowpandas_data_object(obj: object) -> bool:
"""True if obj is a Snowpark Pandas DataFrame or Series."""
return (
is_type(obj, _SNOWPANDAS_DF_TYPE_STR)
or is_type(obj, _SNOWPANDAS_SERIES_TYPE_STR)
or is_type(obj, _SNOWPANDAS_INDEX_TYPE_STR)
)
def is_polars_dataframe(obj: object) -> bool:
"""True if obj is a Polars Dataframe."""
return is_type(obj, _POLARS_DATAFRAME)
def is_xarray_dataset(obj: object) -> bool:
"""True if obj is a Xarray Dataset."""
return is_type(obj, _XARRAY_DATASET_TYPE_STR)
def is_xarray_data_array(obj: object) -> bool:
"""True if obj is a Xarray DataArray."""
return is_type(obj, _XARRAY_DATA_ARRAY_TYPE_STR)
def is_polars_series(obj: object) -> bool:
"""True if obj is a Polars Series."""
return is_type(obj, _POLARS_SERIES)
def is_polars_lazyframe(obj: object) -> bool:
"""True if obj is a Polars Lazyframe."""
return is_type(obj, _POLARS_LAZYFRAME)
def is_ray_dataset(obj: object) -> bool:
"""True if obj is a Ray Dataset."""
return is_type(obj, _RAY_DATASET) or is_type(obj, _RAY_MATERIALIZED_DATASET)
def is_pandas_styler(obj: object) -> TypeGuard[Styler]:
"""True if obj is a pandas Styler."""
return is_type(obj, _PANDAS_STYLER_TYPE_STR)
def is_dbapi_cursor(obj: object) -> TypeGuard[DBAPICursor]:
"""True if obj looks like a DB API 2.0 Cursor.
https://peps.python.org/pep-0249/
"""
return isinstance(obj, DBAPICursor)
def is_duckdb_relation(obj: object) -> bool:
"""True if obj is a DuckDB relation.
https://duckdb.org/docs/api/python/relational_api
"""
return is_type(obj, _DUCKDB_RELATION)
def _is_list_of_scalars(data: Iterable[Any]) -> bool:
"""Check if the list only contains scalar values."""
from pandas.api.types import infer_dtype
    # Overview of all values that are interpreted as scalar:
# https://pandas.pydata.org/docs/reference/api/pandas.api.types.is_scalar.html
return infer_dtype(data, skipna=True) not in ["mixed", "unknown-array"]
def _iterable_to_list(
iterable: Iterable[Any], max_iterations: int | None = None
) -> list[Any]:
"""Convert an iterable to a list.
Parameters
----------
iterable : Iterable
The iterable to convert to a list.
max_iterations : int or None
The maximum number of iterations to perform. If None, all iterations are performed.
Returns
-------
list
The converted list.
"""
if max_iterations is None:
return list(iterable)
result = []
for i, item in enumerate(iterable):
if i >= max_iterations:
break
result.append(item)
return result
def _fix_column_naming(data_df: DataFrame) -> DataFrame:
"""Rename the first column to "value" if it is not named
and if there is only one column in the dataframe.
    The default name of the first column is 0 if it is not named,
    which is not very descriptive.
"""
if len(data_df.columns) == 1 and data_df.columns[0] == 0:
# Pandas automatically names the first column with 0 if it is not named.
# We rename it to "value" to make it more descriptive if there is only
# one column in the dataframe.
data_df = data_df.rename(columns={0: "value"})
return data_df
def _dict_to_pandas_df(data: dict[Any, Any]) -> DataFrame:
"""Convert a key-value dict to a Pandas DataFrame.
Parameters
----------
data : dict
The dict to convert to a Pandas DataFrame.
Returns
-------
pandas.DataFrame
The converted Pandas DataFrame.
"""
import pandas as pd
return _fix_column_naming(pd.DataFrame.from_dict(data, orient="index"))
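
# Illustrative sketch (hypothetical example helper, not called anywhere): with
# orient="index", the dict keys become the index and the scalar values become
# a single unnamed column, which _fix_column_naming then renames to "value".
def _example_dict_to_pandas_df() -> None:
    df = _dict_to_pandas_df({"a": 1, "b": 2})
    assert list(df.index) == ["a", "b"]
    assert list(df.columns) == ["value"]
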
def convert_anything_to_pandas_df(
data: Any,
max_unevaluated_rows: int = _MAX_UNEVALUATED_DF_ROWS,
ensure_copy: bool = False,
) -> DataFrame:
"""Try to convert different formats to a Pandas Dataframe.
Parameters
----------
data : dataframe-, array-, or collections-like object
The data to convert to a Pandas DataFrame.
max_unevaluated_rows: int
        If unevaluated data is detected, this function will evaluate it,
        reading at most max_unevaluated_rows rows (defaults to 10k).
ensure_copy: bool
If True, make sure to always return a copy of the data. If False, it depends on
the type of the data. For example, a Pandas DataFrame will be returned as-is.
Returns
-------
pandas.DataFrame
"""
import array
import numpy as np
import pandas as pd
if isinstance(data, pd.DataFrame):
return data.copy() if ensure_copy else cast("pd.DataFrame", data)
if isinstance(data, (pd.Series, pd.Index, pd.api.extensions.ExtensionArray)):
return pd.DataFrame(data)
if is_pandas_styler(data):
return cast("pd.DataFrame", data.data.copy() if ensure_copy else data.data)
if isinstance(data, np.ndarray):
return (
pd.DataFrame([])
if len(data.shape) == 0
else _fix_column_naming(pd.DataFrame(data))
)
if is_polars_dataframe(data):
data = data.clone() if ensure_copy else data
return data.to_pandas()
if is_polars_series(data):
data = data.clone() if ensure_copy else data
return data.to_pandas().to_frame()
if is_polars_lazyframe(data):
data = data.limit(max_unevaluated_rows).collect().to_pandas()
if data.shape[0] == max_unevaluated_rows:
_show_data_information(
f"⚠️ Showing only {string_util.simplify_number(max_unevaluated_rows)} "
"rows. Call `collect()` on the dataframe to show more."
)
return cast("pd.DataFrame", data)
if is_xarray_dataset(data):
if ensure_copy:
data = data.copy(deep=True)
return data.to_dataframe()
if is_xarray_data_array(data):
if ensure_copy:
data = data.copy(deep=True)
return data.to_series().to_frame()
if is_dask_object(data):
data = data.head(max_unevaluated_rows, compute=True)
# Dask returns a Pandas object (DataFrame, Series, Index) when
# executing operations like `head`.
if isinstance(data, (pd.Series, pd.Index)):
data = data.to_frame()
if data.shape[0] == max_unevaluated_rows:
_show_data_information(
f"⚠️ Showing only {string_util.simplify_number(max_unevaluated_rows)} "
"rows. Call `compute()` on the data object to show more."
)
return cast("pd.DataFrame", data)
if is_ray_dataset(data):
data = data.limit(max_unevaluated_rows).to_pandas()
if data.shape[0] == max_unevaluated_rows:
_show_data_information(
f"⚠️ Showing only {string_util.simplify_number(max_unevaluated_rows)} "
"rows. Call `to_pandas()` on the dataset to show more."
)
return cast("pd.DataFrame", data)
if is_modin_data_object(data):
data = data.head(max_unevaluated_rows)._to_pandas()
if isinstance(data, (pd.Series, pd.Index)):
data = data.to_frame()
if data.shape[0] == max_unevaluated_rows:
_show_data_information(
f"⚠️ Showing only {string_util.simplify_number(max_unevaluated_rows)} "
"rows. Call `_to_pandas()` on the data object to show more."
)
return cast("pd.DataFrame", data)
if is_pyspark_data_object(data):
data = data.limit(max_unevaluated_rows).toPandas()
if data.shape[0] == max_unevaluated_rows:
_show_data_information(
f"⚠️ Showing only {string_util.simplify_number(max_unevaluated_rows)} "
"rows. Call `toPandas()` on the data object to show more."
)
return cast("pd.DataFrame", data)
if is_snowpandas_data_object(data):
data = data[:max_unevaluated_rows].to_pandas()
if isinstance(data, (pd.Series, pd.Index)):
data = data.to_frame()
if data.shape[0] == max_unevaluated_rows:
_show_data_information(
f"⚠️ Showing only {string_util.simplify_number(max_unevaluated_rows)} "
"rows. Call `to_pandas()` on the data object to show more."
)
return cast("pd.DataFrame", data)
if is_snowpark_data_object(data):
data = data.limit(max_unevaluated_rows).to_pandas()
if data.shape[0] == max_unevaluated_rows:
_show_data_information(
f"⚠️ Showing only {string_util.simplify_number(max_unevaluated_rows)} "
"rows. Call `to_pandas()` on the data object to show more."
)
return cast("pd.DataFrame", data)
if is_duckdb_relation(data):
data = data.limit(max_unevaluated_rows).df()
if data.shape[0] == max_unevaluated_rows:
_show_data_information(
f"⚠️ Showing only {string_util.simplify_number(max_unevaluated_rows)} "
"rows. Call `df()` on the relation to show more."
)
return data
if is_dbapi_cursor(data):
# Based on the specification, the first item in the description is the
# column name (if available)
columns = (
[d[0] if d else "" for d in data.description] if data.description else None
)
data = pd.DataFrame(data.fetchmany(max_unevaluated_rows), columns=columns)
if data.shape[0] == max_unevaluated_rows:
_show_data_information(
f"⚠️ Showing only {string_util.simplify_number(max_unevaluated_rows)} "
"rows. Call `fetchall()` on the Cursor to show more."
)
return data
if is_snowpark_row_list(data):
return pd.DataFrame([row.as_dict() for row in data])
if has_callable_attr(data, "to_pandas"):
return pd.DataFrame(data.to_pandas())
# Check for dataframe interchange protocol
# Only available in pandas >= 1.5.0
# https://pandas.pydata.org/docs/whatsnew/v1.5.0.html#dataframe-interchange-protocol-implementation
if (
has_callable_attr(data, "__dataframe__")
and is_pandas_version_less_than("1.5.0") is False
):
data_df = pd.api.interchange.from_dataframe(data)
return data_df.copy() if ensure_copy else data_df
# Support for generator functions
if inspect.isgeneratorfunction(data):
data = _fix_column_naming(
pd.DataFrame(_iterable_to_list(data(), max_iterations=max_unevaluated_rows))
)
if data.shape[0] == max_unevaluated_rows:
_show_data_information(
f"⚠️ Showing only {string_util.simplify_number(max_unevaluated_rows)} "
"rows. Convert the data to a list to show more."
)
return data
if isinstance(data, EnumMeta):
# Support for enum classes
return _fix_column_naming(pd.DataFrame([c.value for c in data])) # type: ignore
# Support for some list like objects
if isinstance(data, (deque, map, array.ArrayType, UserList)):
return _fix_column_naming(pd.DataFrame(list(data)))
# Support for Streamlit's custom dict-like objects
if is_custom_dict(data):
return _dict_to_pandas_df(data.to_dict())
# Support for named tuples
if is_namedtuple(data):
return _dict_to_pandas_df(data._asdict())
# Support for dataclass instances
if is_dataclass_instance(data):
return _dict_to_pandas_df(dataclasses.asdict(data))
# Support for dict-like objects
if isinstance(data, (ChainMap, MappingProxyType, UserDict)) or is_pydantic_model(
data
):
return _dict_to_pandas_df(dict(data))
    # Try to convert to pandas.DataFrame. This will raise an error if the data
    # is not compatible with the pandas.DataFrame constructor.
try:
return _fix_column_naming(pd.DataFrame(data))
except ValueError as ex:
if isinstance(data, dict):
with contextlib.suppress(ValueError):
# Try to use index orient as back-up to support key-value dicts
return _dict_to_pandas_df(data)
raise errors.StreamlitAPIException(
f"""
Unable to convert object of type `{type(data)}` to `pandas.DataFrame`.
Offending object:
```py
{data}
```"""
) from ex
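
# Illustrative sketch (hypothetical, uncalled example): a few common inputs and
# how convert_anything_to_pandas_df normalizes them. The sample data is made up
# for demonstration purposes.
def _example_convert_anything_to_pandas_df() -> None:
    import pandas as pd

    # A dict of columns is handled by the pandas.DataFrame constructor:
    df = convert_anything_to_pandas_df({"a": [1, 2], "b": [3, 4]})
    assert list(df.columns) == ["a", "b"]

    # A plain list becomes a single, unnamed column that is renamed to "value":
    df = convert_anything_to_pandas_df([1, 2, 3])
    assert list(df.columns) == ["value"]

    # ensure_copy=True guarantees that the result is detached from the input:
    original = pd.DataFrame({"a": [1, 2]})
    assert convert_anything_to_pandas_df(original, ensure_copy=True) is not original
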
def convert_arrow_table_to_arrow_bytes(table: pa.Table) -> bytes:
"""Serialize pyarrow.Table to Arrow IPC bytes.
Parameters
----------
table : pyarrow.Table
A table to convert.
Returns
-------
bytes
The serialized Arrow IPC bytes.
"""
try:
table = _maybe_truncate_table(table)
except RecursionError as err:
# This is a very unlikely edge case, but we want to make sure that
# it doesn't lead to unexpected behavior.
# If there is a recursion error, we just return the table as-is
# which will lead to the normal message limit exceed error.
_LOGGER.warning(
"Recursion error while truncating Arrow table. This is not "
"supposed to happen.",
exc_info=err,
)
import pyarrow as pa
# Convert table to bytes
sink = pa.BufferOutputStream()
writer = pa.RecordBatchStreamWriter(sink, table.schema)
writer.write_table(table)
writer.close()
return cast("bytes", sink.getvalue().to_pybytes())
def convert_pandas_df_to_arrow_bytes(df: DataFrame) -> bytes:
"""Serialize pandas.DataFrame to Arrow IPC bytes.
Parameters
----------
df : pandas.DataFrame
A dataframe to convert.
Returns
-------
bytes
The serialized Arrow IPC bytes.
"""
import pyarrow as pa
try:
table = pa.Table.from_pandas(df)
except (pa.ArrowTypeError, pa.ArrowInvalid, pa.ArrowNotImplementedError) as ex:
_LOGGER.info(
"Serialization of dataframe to Arrow table was unsuccessful. "
"Applying automatic fixes for column types to make the dataframe "
"Arrow-compatible.",
exc_info=ex,
)
df = fix_arrow_incompatible_column_types(df)
table = pa.Table.from_pandas(df)
return convert_arrow_table_to_arrow_bytes(table)
def convert_arrow_bytes_to_pandas_df(source: bytes) -> DataFrame:
"""Convert Arrow bytes (IPC format) to pandas.DataFrame.
    When using this function in production, make sure that the installed
    pyarrow version is >= 14.0.1 because of a critical security
    vulnerability in pyarrow < 14.0.1.
Parameters
----------
source : bytes
A bytes object to convert.
Returns
-------
pandas.DataFrame
The converted dataframe.
"""
import pyarrow as pa
reader = pa.RecordBatchStreamReader(source)
return reader.read_pandas()
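
# Illustrative sketch (hypothetical example helper, not called anywhere): a
# simple DataFrame can be round-tripped through the Arrow IPC helpers above.
def _example_arrow_ipc_round_trip() -> None:
    import pandas as pd

    df = pd.DataFrame({"a": [1, 2], "b": ["x", "y"]})
    ipc_bytes = convert_pandas_df_to_arrow_bytes(df)
    restored = convert_arrow_bytes_to_pandas_df(ipc_bytes)
    assert restored.equals(df)
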
def _show_data_information(msg: str) -> None:
"""Show a message to the user with important information
about the processed dataset.
"""
from streamlit.delta_generator_singletons import get_dg_singleton_instance
get_dg_singleton_instance().main_dg.caption(msg)
def convert_anything_to_arrow_bytes(
data: Any,
max_unevaluated_rows: int = _MAX_UNEVALUATED_DF_ROWS,
) -> bytes:
"""Try to convert different formats to Arrow IPC format (bytes).
This method tries to directly convert the input data to Arrow bytes
for some supported formats, but falls back to conversion to a Pandas
DataFrame and then to Arrow bytes.
Parameters
----------
data : dataframe-, array-, or collections-like object
The data to convert to Arrow bytes.
max_unevaluated_rows: int
        If unevaluated data is detected, this function will evaluate it,
        reading at most max_unevaluated_rows rows (defaults to 10k).
Returns
-------
bytes
The serialized Arrow IPC bytes.
"""
import pyarrow as pa
if isinstance(data, pa.Table):
return convert_arrow_table_to_arrow_bytes(data)
# TODO(lukasmasuch): Add direct conversion to Arrow for supported formats here
# Fallback: try to convert to pandas DataFrame
# and then to Arrow bytes.
df = convert_anything_to_pandas_df(data, max_unevaluated_rows)
return convert_pandas_df_to_arrow_bytes(df)
def convert_anything_to_list(obj: OptionSequence[V_co]) -> list[V_co]:
"""Try to convert different formats to a list.
    If the input is a dataframe-like object, we just select the first
    column to iterate over. Non-sequence-like objects and scalar types
    will just be wrapped into a list.
Parameters
----------
obj : dataframe-, array-, or collections-like object
The object to convert to a list.
Returns
-------
list
The converted list.
"""
if obj is None:
return [] # type: ignore
if isinstance(obj, (str, int, float, bool)):
# Wrap basic objects into a list
return [obj] # type: ignore[list-item]
if isinstance(obj, EnumMeta):
# Support for enum classes. For string enums, we return the string value
# of the enum members. For other enums, we just return the enum member.
return [member.value if isinstance(member, str) else member for member in obj] # type: ignore
if isinstance(obj, Mapping):
return list(obj.keys())
if is_list_like(obj) and not is_snowpark_row_list(obj):
# This also ensures that the sequence is copied to prevent
# potential mutations to the original object.
return list(obj)
# Fallback to our DataFrame conversion logic:
try:
# We use ensure_copy here because the return value of this function is
# saved in a widget serde class instance to be used in later script runs,
# and we don't want mutations to the options object passed to a
        # widget to affect the widget.
# (See https://github.com/streamlit/streamlit/issues/7534)
data_df = convert_anything_to_pandas_df(obj, ensure_copy=True)
# Return first column as a list:
return (
[]
if data_df.empty
else cast("list[V_co]", list(data_df.iloc[:, 0].to_list()))
)
except errors.StreamlitAPIException:
# Wrap the object into a list
return [obj] # type: ignore
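
# Illustrative sketch (hypothetical, uncalled example): how different inputs
# are turned into option lists by convert_anything_to_list.
def _example_convert_anything_to_list() -> None:
    import pandas as pd

    # Mappings contribute their keys:
    assert convert_anything_to_list({"a": 1, "b": 2}) == ["a", "b"]
    # Scalars are wrapped into a single-element list:
    assert convert_anything_to_list("option") == ["option"]
    # Dataframe-like objects contribute their first column (see the
    # docstring above):
    df = pd.DataFrame({"col": [1, 2, 3]})
    assert convert_anything_to_list(df) == [1, 2, 3]
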
def _maybe_truncate_table(
table: pa.Table, truncated_rows: int | None = None
) -> pa.Table:
"""Experimental feature to automatically truncate tables that
are larger than the maximum allowed message size. It needs to be enabled
via the server.enableArrowTruncation config option.
Parameters
----------
table : pyarrow.Table
A table to truncate.
truncated_rows : int or None
The number of rows that have been truncated so far. This is used by
the recursion logic to keep track of the total number of truncated
rows.
"""
if config.get_option("server.enableArrowTruncation"):
# This is an optimization problem: We don't know at what row
# the perfect cut-off is to comply with the max size. But we want to figure
        # it out in as few iterations as possible. We will almost always cut out
        # more than required to keep the number of iterations low.
# The maximum size allowed for protobuf messages in bytes:
max_message_size = int(config.get_option("server.maxMessageSize") * 1e6)
# We add 1 MB for other overhead related to the protobuf message.
# This is a very conservative estimate, but it should be good enough.
table_size = int(table.nbytes + 1 * 1e6)
table_rows = table.num_rows
if table_rows > 1 and table_size > max_message_size:
# targeted rows == the number of rows the table should be truncated to.
# Calculate an approximation of how many rows we need to truncate to.
targeted_rows = math.ceil(table_rows * (max_message_size / table_size))
# Make sure to cut out at least a couple of rows to avoid running
# this logic too often since it is quite inefficient and could lead
            # to infinite recursion without these precautions.
targeted_rows = math.floor(
max(
min(
# Cut out:
# an additional 5% of the estimated num rows to cut out:
targeted_rows - math.floor((table_rows - targeted_rows) * 0.05),
# at least 1% of table size:
table_rows - (table_rows * 0.01),
# at least 5 rows:
table_rows - 5,
),
1, # but it should always have at least 1 row
)
)
sliced_table = table.slice(0, targeted_rows)
return _maybe_truncate_table(
sliced_table, (truncated_rows or 0) + (table_rows - targeted_rows)
)
if truncated_rows:
displayed_rows = string_util.simplify_number(table.num_rows)
total_rows = string_util.simplify_number(table.num_rows + truncated_rows)
if displayed_rows == total_rows:
# If the simplified numbers are the same,
# we just display the exact numbers.
displayed_rows = str(table.num_rows)
total_rows = str(table.num_rows + truncated_rows)
_show_data_information(
f"⚠️ Showing {displayed_rows} out of {total_rows} "
"rows due to data size limitations."
)
return table
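
# Illustrative sketch (hypothetical example helper, not called anywhere): the
# row target computed by the truncation logic above, applied to made-up
# numbers. With a 400 MB table of 1,000,000 rows and a 200 MB message limit,
# the first pass keeps roughly 475,000 rows (the proportional estimate minus a
# 5% safety margin), and the recursion then re-checks the sliced table's size.
def _example_truncation_row_target() -> None:
    table_rows = 1_000_000
    table_size = 400 * 1_000_000  # bytes (hypothetical)
    max_message_size = 200 * 1_000_000  # bytes (hypothetical)

    targeted_rows = math.ceil(table_rows * (max_message_size / table_size))
    targeted_rows = math.floor(
        max(
            min(
                targeted_rows - math.floor((table_rows - targeted_rows) * 0.05),
                table_rows - (table_rows * 0.01),
                table_rows - 5,
            ),
            1,
        )
    )
    assert targeted_rows == 475_000
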
def is_colum_type_arrow_incompatible(column: Series[Any] | Index) -> bool:
"""Return True if the column type is known to cause issues during
Arrow conversion.
"""
from pandas.api.types import infer_dtype, is_dict_like, is_list_like
if column.dtype.kind in [
"c", # complex64, complex128, complex256
]:
return True
if str(column.dtype) in {
# These period types are not yet supported by our frontend impl.
# See comments in Quiver.ts for more details.
"period[B]",
"period[N]",
"period[ns]",
"period[U]",
"period[us]",
"geometry",
}:
return True
if column.dtype == "object":
        # The dtype of mixed-type columns is always object; the actual type of
        # the column values can be determined via the infer_dtype function:
# https://pandas.pydata.org/docs/reference/api/pandas.api.types.infer_dtype.html
inferred_type = infer_dtype(column, skipna=True)
if inferred_type in [
"mixed-integer",
"complex",
]:
return True
if inferred_type == "mixed":
# This includes most of the more complex/custom types (objects, dicts,
# lists, ...)
if len(column) == 0 or not hasattr(column, "iloc"):
# The column seems to be invalid, so we assume it is incompatible.
# But this would most likely never happen since empty columns
# cannot be mixed.
return True
# Get the first value to check if it is a supported list-like type.
first_value = column.iloc[0]
if ( # noqa: SIM103
not is_list_like(first_value)
# dicts are list-like, but have issues in Arrow JS (see comments in
# Quiver.ts)
or is_dict_like(first_value)
# Frozensets are list-like, but are not compatible with pyarrow.
or isinstance(first_value, frozenset)
):
# This seems to be an incompatible list-like type
return True
return False
# We did not detect an incompatible type, so we assume it is compatible:
return False
def fix_arrow_incompatible_column_types(
df: DataFrame, selected_columns: list[str] | None = None
) -> DataFrame:
"""Fix column types that are not supported by Arrow table.
This includes mixed types (e.g. mix of integers and strings)
as well as complex numbers (complex128 type). These types will cause
errors during conversion of the dataframe to an Arrow table.
    It is fixed by converting all values of the column to strings.
This is sufficient for displaying the data on the frontend.
Parameters
----------
df : pandas.DataFrame
A dataframe to fix.
selected_columns: List[str] or None
A list of columns to fix. If None, all columns are evaluated.
Returns
-------
The fixed dataframe.
"""
import pandas as pd
# Make a copy, but only initialize if necessary to preserve memory.
df_copy: DataFrame | None = None
for col in selected_columns or df.columns:
if is_colum_type_arrow_incompatible(df[col]):
if df_copy is None:
df_copy = df.copy()
df_copy[col] = df[col].astype("string")
# The index can also contain mixed types
# causing Arrow issues during conversion.
# Skipping multi-indices since they won't return
# the correct value from infer_dtype
if not selected_columns and (
not isinstance(
df.index,
pd.MultiIndex,
)
and is_colum_type_arrow_incompatible(df.index)
):
if df_copy is None:
df_copy = df.copy()
df_copy.index = df.index.astype("string")
return df_copy if df_copy is not None else df
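
# Illustrative sketch (hypothetical, uncalled example): a column that mixes
# integers and strings cannot be serialized to Arrow directly, so it is cast
# to pandas' string dtype; compatible columns are left untouched.
def _example_fix_arrow_incompatible_column_types() -> None:
    import pandas as pd

    df = pd.DataFrame({"mixed": [1, "two", 3], "ok": [1, 2, 3]})
    fixed = fix_arrow_incompatible_column_types(df)
    assert str(fixed["mixed"].dtype) == "string"
    assert fixed["ok"].dtype.kind == "i"  # untouched integer column
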
def determine_data_format(input_data: Any) -> DataFormat:
"""Determine the data format of the input data.
Parameters
----------
input_data : Any
The input data to determine the data format of.
Returns
-------
DataFormat
The data format of the input data.
"""
import numpy as np
import pandas as pd
import pyarrow as pa
if input_data is None:
return DataFormat.EMPTY
if isinstance(input_data, pd.DataFrame):
return DataFormat.PANDAS_DATAFRAME
if isinstance(input_data, np.ndarray):
if len(input_data.shape) == 1:
            # For technical reasons, we need to distinguish
            # one-dimensional numpy arrays from multidimensional ones.
return DataFormat.NUMPY_LIST
return DataFormat.NUMPY_MATRIX
if isinstance(input_data, pa.Table):
return DataFormat.PYARROW_TABLE
if isinstance(input_data, pa.Array):
return DataFormat.PYARROW_ARRAY
if isinstance(input_data, pd.Series):
return DataFormat.PANDAS_SERIES
if isinstance(input_data, pd.Index):
return DataFormat.PANDAS_INDEX
if is_pandas_styler(input_data):
return DataFormat.PANDAS_STYLER
if isinstance(input_data, pd.api.extensions.ExtensionArray):
return DataFormat.PANDAS_ARRAY
if is_polars_series(input_data):
return DataFormat.POLARS_SERIES
if is_polars_dataframe(input_data):
return DataFormat.POLARS_DATAFRAME
if is_polars_lazyframe(input_data):
return DataFormat.POLARS_LAZYFRAME
if is_modin_data_object(input_data):
return DataFormat.MODIN_OBJECT
if is_snowpandas_data_object(input_data):
return DataFormat.SNOWPANDAS_OBJECT
if is_pyspark_data_object(input_data):
return DataFormat.PYSPARK_OBJECT
if is_xarray_dataset(input_data):
return DataFormat.XARRAY_DATASET
if is_xarray_data_array(input_data):
return DataFormat.XARRAY_DATA_ARRAY
if is_ray_dataset(input_data):
return DataFormat.RAY_DATASET
if is_dask_object(input_data):
return DataFormat.DASK_OBJECT
if is_snowpark_data_object(input_data) or is_snowpark_row_list(input_data):
return DataFormat.SNOWPARK_OBJECT
if is_duckdb_relation(input_data):
return DataFormat.DUCKDB_RELATION
if is_dbapi_cursor(input_data):
return DataFormat.DBAPI_CURSOR
if (
isinstance(
input_data,
(ChainMap, UserDict, MappingProxyType),
)
or is_dataclass_instance(input_data)
or is_namedtuple(input_data)
or is_custom_dict(input_data)
or is_pydantic_model(input_data)
):
return DataFormat.KEY_VALUE_DICT
if isinstance(input_data, (ItemsView, enumerate)):
return DataFormat.LIST_OF_ROWS
if isinstance(input_data, (list, tuple, set, frozenset)):
if _is_list_of_scalars(input_data):
# -> one-dimensional data structure
if isinstance(input_data, tuple):
return DataFormat.TUPLE_OF_VALUES
if isinstance(input_data, (set, frozenset)):
return DataFormat.SET_OF_VALUES
return DataFormat.LIST_OF_VALUES
# -> Multi-dimensional data structure
        # This should always contain at least one element; otherwise
        # infer_dtype would have detected the input as "empty" and it would
        # have been handled as a one-dimensional structure above.
first_element = next(iter(input_data))
if isinstance(first_element, dict):
return DataFormat.LIST_OF_RECORDS
if isinstance(first_element, (list, tuple, set, frozenset)):
return DataFormat.LIST_OF_ROWS
elif isinstance(input_data, (dict, Mapping)):
if not input_data:
return DataFormat.KEY_VALUE_DICT
if len(input_data) > 0:
first_value = next(iter(input_data.values()))
# In the future, we could potentially also support tight & split formats
if isinstance(first_value, dict):
return DataFormat.COLUMN_INDEX_MAPPING
if isinstance(first_value, (list, tuple)):
return DataFormat.COLUMN_VALUE_MAPPING
if isinstance(first_value, pd.Series):
return DataFormat.COLUMN_SERIES_MAPPING
        # Use key-value dict as fallback. However, if the values of the dict
        # contain mixed types, it will become non-editable in the frontend.
return DataFormat.KEY_VALUE_DICT
elif is_list_like(input_data):
return DataFormat.LIST_OF_VALUES
return DataFormat.UNKNOWN
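
# Illustrative sketch (hypothetical example helper, not called anywhere): a few
# simple inputs and the DataFormat values they resolve to.
def _example_determine_data_format() -> None:
    import pandas as pd

    assert determine_data_format(None) == DataFormat.EMPTY
    assert determine_data_format([1, 2, 3]) == DataFormat.LIST_OF_VALUES
    assert determine_data_format([{"a": 1}, {"a": 2}]) == DataFormat.LIST_OF_RECORDS
    assert determine_data_format({"a": 1, "b": 2}) == DataFormat.KEY_VALUE_DICT
    assert (
        determine_data_format(pd.DataFrame({"a": [1]})) == DataFormat.PANDAS_DATAFRAME
    )
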
def _unify_missing_values(df: DataFrame) -> DataFrame:
"""Unify all missing values in a DataFrame to None.
Pandas uses a variety of values to represent missing values, including np.nan,
NaT, None, and pd.NA. This function replaces all of these values with None,
    which is the only missing value type that is supported by all data formats.
"""
import numpy as np
import pandas as pd
    # Replace all recognized nulls (np.nan, pd.NA, NaT) with None,
    # then infer better dtypes for object columns.
    # For performance reasons, we could pass copy=False to infer_objects here,
    # but that parameter is only available in pandas >= 2.
return df.replace([pd.NA, pd.NaT, np.nan], None).infer_objects()
def _pandas_df_to_series(df: DataFrame) -> Series[Any]:
"""Convert a Pandas DataFrame to a Pandas Series by selecting the first column.
Raises
------
ValueError
If the DataFrame has more than one column.
"""
# Select first column in dataframe and create a new series based on the values
if len(df.columns) != 1:
raise ValueError(
f"DataFrame is expected to have a single column but has {len(df.columns)}."
)
return df[df.columns[0]]
def convert_pandas_df_to_data_format(
df: DataFrame, data_format: DataFormat
) -> (
DataFrame
| Series[Any]
| pa.Table
| pa.Array
| np.ndarray[Any, np.dtype[Any]]
| tuple[Any]
| list[Any]
| set[Any]
| dict[str, Any]
):
"""Convert a Pandas DataFrame to the specified data format.
Parameters
----------
df : pd.DataFrame
The dataframe to convert.
data_format : DataFormat
The data format to convert to.
Returns
-------
pd.DataFrame, pd.Series, pyarrow.Table, np.ndarray, xarray.Dataset,
    xarray.DataArray, polars.DataFrame, polars.Series, list, set, tuple, or dict.
The converted dataframe.
"""
if data_format in {
DataFormat.EMPTY,
DataFormat.DASK_OBJECT,
DataFormat.DBAPI_CURSOR,
DataFormat.DUCKDB_RELATION,
DataFormat.MODIN_OBJECT,
DataFormat.PANDAS_ARRAY,
DataFormat.PANDAS_DATAFRAME,
DataFormat.PANDAS_INDEX,
DataFormat.PANDAS_STYLER,
DataFormat.PYSPARK_OBJECT,
DataFormat.RAY_DATASET,
DataFormat.SNOWPANDAS_OBJECT,
DataFormat.SNOWPARK_OBJECT,
}:
return df
if data_format == DataFormat.NUMPY_LIST:
import numpy as np
# It's a 1-dimensional array, so we only return
# the first column as numpy array
# Calling to_numpy() on the full DataFrame would result in:
# [[1], [2]] instead of [1, 2]
return np.ndarray(0) if df.empty else df.iloc[:, 0].to_numpy()
if data_format == DataFormat.NUMPY_MATRIX:
import numpy as np
return np.ndarray(0) if df.empty else df.to_numpy()
if data_format == DataFormat.PYARROW_TABLE:
import pyarrow as pa
return pa.Table.from_pandas(df)
if data_format == DataFormat.PYARROW_ARRAY:
import pyarrow as pa
return pa.Array.from_pandas(_pandas_df_to_series(df))
if data_format == DataFormat.PANDAS_SERIES:
return _pandas_df_to_series(df)
if data_format in {DataFormat.POLARS_DATAFRAME, DataFormat.POLARS_LAZYFRAME}:
import polars as pl # type: ignore[import-not-found]
return pl.from_pandas(df)
if data_format == DataFormat.POLARS_SERIES:
import polars as pl
return pl.from_pandas(_pandas_df_to_series(df))
if data_format == DataFormat.XARRAY_DATASET:
import xarray as xr # type: ignore[import-not-found]
return xr.Dataset.from_dataframe(df)
if data_format == DataFormat.XARRAY_DATA_ARRAY:
import xarray as xr
return xr.DataArray.from_series(_pandas_df_to_series(df))
if data_format == DataFormat.LIST_OF_RECORDS:
return _unify_missing_values(df).to_dict(orient="records")
if data_format == DataFormat.LIST_OF_ROWS:
# to_numpy converts the dataframe to a list of rows
return _unify_missing_values(df).to_numpy().tolist()
if data_format == DataFormat.COLUMN_INDEX_MAPPING:
return _unify_missing_values(df).to_dict(orient="dict")
if data_format == DataFormat.COLUMN_VALUE_MAPPING:
return _unify_missing_values(df).to_dict(orient="list")
if data_format == DataFormat.COLUMN_SERIES_MAPPING:
return df.to_dict(orient="series")
if data_format in [
DataFormat.LIST_OF_VALUES,
DataFormat.TUPLE_OF_VALUES,
DataFormat.SET_OF_VALUES,
]:
df = _unify_missing_values(df)
return_list = []
if len(df.columns) == 1:
# Get the first column and convert to list
return_list = df[df.columns[0]].tolist()
elif len(df.columns) >= 1:
raise ValueError(
"DataFrame is expected to have a single column but "
f"has {len(df.columns)}."
)
if data_format == DataFormat.TUPLE_OF_VALUES:
return tuple(return_list)
if data_format == DataFormat.SET_OF_VALUES:
return set(return_list)
return return_list
if data_format == DataFormat.KEY_VALUE_DICT:
df = _unify_missing_values(df)
# The key is expected to be the index -> this will return the first column
# as a dict with index as key.
return {} if df.empty else df.iloc[:, 0].to_dict()
raise ValueError(f"Unsupported input data format: {data_format}")