team-10/env/Lib/site-packages/streamlit/components/v1/component_arrow.py
2025-08-02 07:34:44 +02:00

141 lines
4.3 KiB
Python

# Copyright (c) Streamlit Inc. (2018-2022) Snowflake Inc. (2022-2025)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Data marshalling utilities for ArrowTable protobufs, which are used by
CustomComponent for dataframe serialization.
"""
from __future__ import annotations
from typing import TYPE_CHECKING, Any
from streamlit import dataframe_util
from streamlit.elements.lib import pandas_styler_utils
if TYPE_CHECKING:
from pandas import DataFrame, Index, Series
from streamlit.proto.Components_pb2 import ArrowTable as ArrowTableProto
def _maybe_tuple_to_list(item: Any) -> Any:
"""Convert a tuple to a list. Leave as is if it's not a tuple."""
return list(item) if isinstance(item, tuple) else item
def marshall(
    proto: ArrowTableProto, data: Any, default_uuid: str | None = None
) -> None:
    """Serialize ``data`` into an ArrowTable proto.

    Parameters
    ----------
    proto : proto.ArrowTable
        Output. The protobuf for a Streamlit ArrowTable proto; its index,
        columns and data fields are filled in.
    data : pandas.DataFrame, pandas.Styler, numpy.ndarray, Iterable, dict, or None
        Something that is or can be converted to a dataframe.
    default_uuid : str or None
        Only used when ``data`` is a pandas Styler: forwarded unchanged to
        ``pandas_styler_utils.marshall_styler``.
    """
    # A Styler additionally gets its styling marshalled; its underlying
    # values are then serialized below exactly like any other input.
    if dataframe_util.is_pandas_styler(data):
        pandas_styler_utils.marshall_styler(proto, data, default_uuid)  # type: ignore

    frame = dataframe_util.convert_anything_to_pandas_df(data)
    row_labels, column_labels = frame.index, frame.columns

    _marshall_index(proto, row_labels)
    _marshall_columns(proto, column_labels)
    _marshall_data(proto, frame)
def _marshall_index(proto: ArrowTableProto, index: Index) -> None:
    """Marshall a pandas.DataFrame index into an ArrowTable proto.

    Each label of ``index`` becomes one row of an intermediate DataFrame
    (tuple labels are expanded into lists first, giving one cell per
    element). That frame is serialized to Arrow bytes and stored in
    ``proto.index``.

    Parameters
    ----------
    proto : proto.ArrowTable
        Output. The protobuf for a Streamlit ArrowTable proto; its ``index``
        field is set.
    index : pd.Index
        The index labels to serialize.
    """
    import pandas as pd

    # Build a separately named list rather than rebinding ``index`` (typed as
    # an Index) to a lazy ``map`` object. pandas materializes iterators with
    # ``list()`` anyway, so the resulting frame is the same.
    rows = [_maybe_tuple_to_list(label) for label in index.values]
    index_df = pd.DataFrame(rows)
    proto.index = dataframe_util.convert_pandas_df_to_arrow_bytes(index_df)
def _marshall_columns(proto: ArrowTableProto, columns: Index) -> None:
    """Marshall pandas.DataFrame column labels into an ArrowTable proto.

    Each column label becomes one row of an intermediate DataFrame (tuple
    labels are expanded into lists first, giving one cell per element).
    That frame is serialized to Arrow bytes and stored in ``proto.columns``.

    Parameters
    ----------
    proto : proto.ArrowTable
        Output. The protobuf for a Streamlit ArrowTable proto; its
        ``columns`` field is set.
    columns : pd.Index
        The column labels to serialize. Callers pass ``DataFrame.columns``,
        which is a pandas ``Index``.
    """
    import pandas as pd

    # Build a separately named list rather than rebinding ``columns`` to a
    # lazy ``map`` object. pandas materializes iterators with ``list()``
    # anyway, so the resulting frame is the same.
    labels = [_maybe_tuple_to_list(label) for label in columns.values]
    columns_df = pd.DataFrame(labels)
    proto.columns = dataframe_util.convert_pandas_df_to_arrow_bytes(columns_df)
def _marshall_data(proto: ArrowTableProto, df: DataFrame) -> None:
    """Serialize ``df`` to Arrow bytes and store them in ``proto.data``.

    Parameters
    ----------
    proto : proto.ArrowTable
        Output. The protobuf for a Streamlit ArrowTable proto; its ``data``
        field is set.
    df : pandas.DataFrame
        The dataframe to serialize.
    """
    serialized = dataframe_util.convert_pandas_df_to_arrow_bytes(df)
    proto.data = serialized
def arrow_proto_to_dataframe(proto: ArrowTableProto) -> DataFrame:
    """Convert an ArrowTable proto back into a pandas.DataFrame.

    Parameters
    ----------
    proto : proto.ArrowTable
        Input. An ArrowTable protobuf whose ``data``, ``index`` and
        ``columns`` fields each hold an Arrow-serialized dataframe.

    Returns
    -------
    pandas.DataFrame
        A dataframe built from the values in ``proto.data``, with index
        labels from ``proto.index`` and column labels from ``proto.columns``.

    Raises
    ------
    RuntimeError
        If the installed pyarrow version is older than 14.0.1.
    """
    if dataframe_util.is_pyarrow_version_less_than("14.0.1"):
        raise RuntimeError(
            "The installed pyarrow version is not compatible with this component. "
            "Please upgrade to 14.0.1 or higher: pip install -U pyarrow"
        )

    import pandas as pd

    data = dataframe_util.convert_arrow_bytes_to_pandas_df(proto.data)
    index = dataframe_util.convert_arrow_bytes_to_pandas_df(proto.index)
    columns = dataframe_util.convert_arrow_bytes_to_pandas_df(proto.columns)

    # The label frames hold one label per row, with one column per tuple
    # element. Transposing yields one list per label position, which is
    # passed to pandas as the index/columns.
    # NOTE(review): pandas treats a list of lists as per-level arrays, so even
    # a single-level index likely comes back as a one-level MultiIndex.
    # Confirm callers expect this.
    # NOTE(review): ``data.to_numpy()`` yields a single ndarray, so mixed
    # column dtypes are coerced to a common (possibly object) dtype.
    return pd.DataFrame(
        data.to_numpy(),
        index=index.to_numpy().T.tolist(),
        columns=columns.to_numpy().T.tolist(),
    )