141 lines
4.3 KiB
Python
141 lines
4.3 KiB
Python
# Copyright (c) Streamlit Inc. (2018-2022) Snowflake Inc. (2022-2025)
|
|
#
|
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
# you may not use this file except in compliance with the License.
|
|
# You may obtain a copy of the License at
|
|
#
|
|
# http://www.apache.org/licenses/LICENSE-2.0
|
|
#
|
|
# Unless required by applicable law or agreed to in writing, software
|
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
# See the License for the specific language governing permissions and
|
|
# limitations under the License.
|
|
|
|
"""Data marshalling utilities for ArrowTable protobufs, which are used by
|
|
CustomComponent for dataframe serialization.
|
|
"""
|
|
|
|
from __future__ import annotations
|
|
|
|
from typing import TYPE_CHECKING, Any
|
|
|
|
from streamlit import dataframe_util
|
|
from streamlit.elements.lib import pandas_styler_utils
|
|
|
|
if TYPE_CHECKING:
|
|
from pandas import DataFrame, Index, Series
|
|
|
|
from streamlit.proto.Components_pb2 import ArrowTable as ArrowTableProto
|
|
|
|
|
|
def _maybe_tuple_to_list(item: Any) -> Any:
|
|
"""Convert a tuple to a list. Leave as is if it's not a tuple."""
|
|
return list(item) if isinstance(item, tuple) else item
|
|
|
|
|
|
def marshall(
    proto: ArrowTableProto, data: Any, default_uuid: str | None = None
) -> None:
    """Marshall data into an ArrowTable proto.

    Parameters
    ----------
    proto : proto.ArrowTable
        Output. The protobuf for a Streamlit ArrowTable proto; it is
        populated in place and nothing is returned.

    data : pandas.DataFrame, pandas.Styler, numpy.ndarray, Iterable, dict, or None
        Something that is or can be converted to a dataframe.

    default_uuid : str or None
        Forwarded to ``pandas_styler_utils.marshall_styler`` when ``data`` is
        a pandas Styler. It is not used for any other kind of input.

    """
    # A Styler is handed to the styler marshaller first. The dataframe
    # conversion below still runs for it afterwards.
    if dataframe_util.is_pandas_styler(data):
        pandas_styler_utils.marshall_styler(proto, data, default_uuid)  # type: ignore

    frame = dataframe_util.convert_anything_to_pandas_df(data)

    # Index labels, column labels and cell data are serialized separately.
    _marshall_index(proto, frame.index)
    _marshall_columns(proto, frame.columns)
    _marshall_data(proto, frame)
|
|
|
|
|
|
def _marshall_index(proto: ArrowTableProto, index: Index) -> None:
    """Marshall pandas.DataFrame index into an ArrowTable proto.

    The index labels are loaded into a temporary dataframe (one row per
    label), serialized to Arrow bytes, and stored in ``proto.index``.

    Parameters
    ----------
    proto : proto.ArrowTable
        Output. The protobuf for a Streamlit ArrowTable proto.

    index : pd.Index
        Index of the dataframe that is being marshalled.

    """
    import pandas as pd

    # Tuple labels are converted to lists before the frame is built
    # (see _maybe_tuple_to_list); every other label is used as is.
    label_rows = [_maybe_tuple_to_list(label) for label in index.values]
    labels_frame = pd.DataFrame(label_rows)
    proto.index = dataframe_util.convert_pandas_df_to_arrow_bytes(labels_frame)
|
|
|
|
|
|
def _marshall_columns(proto: ArrowTableProto, columns: Index) -> None:
    """Marshall pandas.DataFrame column labels into an ArrowTable proto.

    The column labels are loaded into a temporary dataframe (one row per
    label), serialized to Arrow bytes, and stored in ``proto.columns``.

    Parameters
    ----------
    proto : proto.ArrowTable
        Output. The protobuf for a Streamlit ArrowTable proto.

    columns : pd.Index
        Column labels of the dataframe that is being marshalled, i.e. the
        value of ``DataFrame.columns``.

    """
    import pandas as pd

    # Tuple labels are converted to lists before the frame is built
    # (see _maybe_tuple_to_list); every other label is used as is.
    # The parameter is deliberately not rebound so it keeps its annotated type.
    label_rows = [_maybe_tuple_to_list(label) for label in columns.values]
    columns_df = pd.DataFrame(label_rows)
    proto.columns = dataframe_util.convert_pandas_df_to_arrow_bytes(columns_df)
|
|
|
|
|
|
def _marshall_data(proto: ArrowTableProto, df: DataFrame) -> None:
    """Marshall pandas.DataFrame data into an ArrowTable proto.

    Parameters
    ----------
    proto : proto.ArrowTable
        Output. Its ``data`` field is set to the serialized dataframe.

    df : pandas.DataFrame
        A dataframe to marshall.

    """
    arrow_bytes = dataframe_util.convert_pandas_df_to_arrow_bytes(df)
    proto.data = arrow_bytes
|
|
|
|
|
|
def arrow_proto_to_dataframe(proto: ArrowTableProto) -> DataFrame:
    """Convert ArrowTable proto to pandas.DataFrame.

    Parameters
    ----------
    proto : proto.ArrowTable
        Input. The proto whose ``data``, ``index`` and ``columns`` fields
        hold Arrow-serialized bytes.

    Returns
    -------
    pandas.DataFrame
        A dataframe rebuilt from the three deserialized fields.

    Raises
    ------
    RuntimeError
        If the installed pyarrow version is older than 14.0.1.

    """
    # Refuse to deserialize anything with a pyarrow older than 14.0.1.
    if dataframe_util.is_pyarrow_version_less_than("14.0.1"):
        raise RuntimeError(
            "The installed pyarrow version is not compatible with this component. "
            "Please upgrade to 14.0.1 or higher: pip install -U pyarrow"
        )

    import pandas as pd

    values_df = dataframe_util.convert_arrow_bytes_to_pandas_df(proto.data)
    index_df = dataframe_util.convert_arrow_bytes_to_pandas_df(proto.index)
    columns_df = dataframe_util.convert_arrow_bytes_to_pandas_df(proto.columns)

    values = values_df.to_numpy()
    # The label frames are transposed and turned into nested lists before
    # being handed to the DataFrame constructor as index / column labels.
    row_labels = index_df.to_numpy().T.tolist()
    column_labels = columns_df.to_numpy().T.tolist()

    return pd.DataFrame(values, index=row_labels, columns=column_labels)
|