2071 lines
62 KiB
Python
2071 lines
62 KiB
Python
![]() |
# Copyright (c) Streamlit Inc. (2018-2022) Snowflake Inc. (2022-2025)
|
||
|
#
|
||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||
|
# you may not use this file except in compliance with the License.
|
||
|
# You may obtain a copy of the License at
|
||
|
#
|
||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||
|
#
|
||
|
# Unless required by applicable law or agreed to in writing, software
|
||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||
|
# See the License for the specific language governing permissions and
|
||
|
# limitations under the License.
|
||
|
|
||
|
# Assert statements are allowed here since the app testing logic is used within unit tests:
|
||
|
# ruff: noqa: S101
|
||
|
|
||
|
from __future__ import annotations
|
||
|
|
||
|
import textwrap
|
||
|
from abc import ABC, abstractmethod
|
||
|
from collections.abc import Iterator, Sequence
|
||
|
from dataclasses import dataclass, field, fields, is_dataclass
|
||
|
from datetime import date, datetime, time, timedelta
|
||
|
from typing import (
|
||
|
TYPE_CHECKING,
|
||
|
Any,
|
||
|
Callable,
|
||
|
Generic,
|
||
|
TypeVar,
|
||
|
Union,
|
||
|
cast,
|
||
|
overload,
|
||
|
)
|
||
|
|
||
|
from typing_extensions import Self, TypeAlias
|
||
|
|
||
|
from streamlit import dataframe_util, util
|
||
|
from streamlit.elements.heading import HeadingProtoTag
|
||
|
from streamlit.elements.widgets.select_slider import SelectSliderSerde
|
||
|
from streamlit.elements.widgets.slider import (
|
||
|
SliderSerde,
|
||
|
SliderStep,
|
||
|
SliderValueT,
|
||
|
)
|
||
|
from streamlit.elements.widgets.time_widgets import (
|
||
|
DateInputSerde,
|
||
|
DateWidgetReturn,
|
||
|
TimeInputSerde,
|
||
|
_parse_date_value,
|
||
|
)
|
||
|
from streamlit.proto.Alert_pb2 import Alert as AlertProto
|
||
|
from streamlit.proto.Checkbox_pb2 import Checkbox as CheckboxProto
|
||
|
from streamlit.proto.Markdown_pb2 import Markdown as MarkdownProto
|
||
|
from streamlit.proto.Slider_pb2 import Slider as SliderProto
|
||
|
from streamlit.proto.WidgetStates_pb2 import WidgetState, WidgetStates
|
||
|
from streamlit.runtime.state.common import TESTING_KEY, user_key_from_element_id
|
||
|
|
||
|
if TYPE_CHECKING:
|
||
|
from pandas import DataFrame as PandasDataframe
|
||
|
|
||
|
from streamlit.proto.Arrow_pb2 import Arrow as ArrowProto
|
||
|
from streamlit.proto.Block_pb2 import Block as BlockProto
|
||
|
from streamlit.proto.Button_pb2 import Button as ButtonProto
|
||
|
from streamlit.proto.ButtonGroup_pb2 import ButtonGroup as ButtonGroupProto
|
||
|
from streamlit.proto.ChatInput_pb2 import ChatInput as ChatInputProto
|
||
|
from streamlit.proto.Code_pb2 import Code as CodeProto
|
||
|
from streamlit.proto.ColorPicker_pb2 import ColorPicker as ColorPickerProto
|
||
|
from streamlit.proto.DateInput_pb2 import DateInput as DateInputProto
|
||
|
from streamlit.proto.Element_pb2 import Element as ElementProto
|
||
|
from streamlit.proto.Exception_pb2 import Exception as ExceptionProto
|
||
|
from streamlit.proto.ForwardMsg_pb2 import ForwardMsg
|
||
|
from streamlit.proto.Heading_pb2 import Heading as HeadingProto
|
||
|
from streamlit.proto.Json_pb2 import Json as JsonProto
|
||
|
from streamlit.proto.Metric_pb2 import Metric as MetricProto
|
||
|
from streamlit.proto.MultiSelect_pb2 import MultiSelect as MultiSelectProto
|
||
|
from streamlit.proto.NumberInput_pb2 import NumberInput as NumberInputProto
|
||
|
from streamlit.proto.Radio_pb2 import Radio as RadioProto
|
||
|
from streamlit.proto.Selectbox_pb2 import Selectbox as SelectboxProto
|
||
|
from streamlit.proto.Text_pb2 import Text as TextProto
|
||
|
from streamlit.proto.TextArea_pb2 import TextArea as TextAreaProto
|
||
|
from streamlit.proto.TextInput_pb2 import TextInput as TextInputProto
|
||
|
from streamlit.proto.TimeInput_pb2 import TimeInput as TimeInputProto
|
||
|
from streamlit.proto.Toast_pb2 import Toast as ToastProto
|
||
|
from streamlit.runtime.state.safe_session_state import SafeSessionState
|
||
|
from streamlit.testing.v1.app_test import AppTest
|
||
|
|
||
|
T = TypeVar("T")
|
||
|
|
||
|
|
||
|
@dataclass
|
||
|
class InitialValue:
|
||
|
"""Used to represent the initial value of a widget."""
|
||
|
|
||
|
pass
|
||
|
|
||
|
|
||
|
# TODO: This class serves as a fallback option for elements that have not
|
||
|
# been implemented yet, as well as providing implementations of some
|
||
|
# trivial methods. It may have significantly reduced scope once all elements
|
||
|
# have been implemented.
|
||
|
# This class will not be sufficient implementation for most elements.
|
||
|
# Widgets need their own classes to translate interactions into the appropriate
|
||
|
# WidgetState and provide higher level interaction interfaces, and other elements
|
||
|
# have enough variation in how to get their values that most will need their
|
||
|
# own classes too.
|
||
|
@dataclass
|
||
|
class Element(ABC):
|
||
|
"""
|
||
|
Element base class for testing.
|
||
|
|
||
|
This class's methods and attributes are universal for all elements
|
||
|
implemented in testing. For example, ``Caption``, ``Code``, ``Text``, and
|
||
|
``Title`` inherit from ``Element``. All widget classes also
|
||
|
inherit from Element, but have additional methods specific to each
|
||
|
widget type. See the ``AppTest`` class for the full list of supported
|
||
|
elements.
|
||
|
|
||
|
For all element classes, parameters of the original element can be obtained
|
||
|
as properties. For example, ``Button.label``, ``Caption.help``, and
|
||
|
``Toast.icon``.
|
||
|
|
||
|
"""
|
||
|
|
||
|
type: str = field(repr=False)
|
||
|
proto: Any = field(repr=False)
|
||
|
root: ElementTree = field(repr=False)
|
||
|
key: str | None
|
||
|
|
||
|
@abstractmethod
|
||
|
def __init__(self, proto: ElementProto, root: ElementTree) -> None: ...
|
||
|
|
||
|
def __iter__(self) -> Iterator[Self]:
|
||
|
yield self
|
||
|
|
||
|
@property
|
||
|
@abstractmethod
|
||
|
def value(self) -> Any:
|
||
|
"""The value or contents of the element."""
|
||
|
...
|
||
|
|
||
|
def __getattr__(self, name: str) -> Any:
|
||
|
"""Fallback attempt to get an attribute from the proto."""
|
||
|
return getattr(self.proto, name)
|
||
|
|
||
|
def run(self, *, timeout: float | None = None) -> AppTest:
|
||
|
"""Run the ``AppTest`` script which contains the element.
|
||
|
|
||
|
Parameters
|
||
|
----------
|
||
|
timeout
|
||
|
The maximum number of seconds to run the script. None means
|
||
|
use the AppTest's default.
|
||
|
"""
|
||
|
return self.root.run(timeout=timeout)
|
||
|
|
||
|
def __repr__(self) -> str:
|
||
|
return util.repr_(self)
|
||
|
|
||
|
|
||
|
@dataclass(repr=False)
|
||
|
class UnknownElement(Element):
|
||
|
def __init__(self, proto: ElementProto, root: ElementTree) -> None:
|
||
|
ty = proto.WhichOneof("type")
|
||
|
assert ty is not None
|
||
|
self.proto = getattr(proto, ty)
|
||
|
self.root = root
|
||
|
self.type = ty
|
||
|
self.key = None
|
||
|
|
||
|
@property
|
||
|
def value(self) -> Any:
|
||
|
try:
|
||
|
state = self.root.session_state
|
||
|
assert state is not None
|
||
|
return state[self.proto.id]
|
||
|
except ValueError:
|
||
|
# No id field, not a widget
|
||
|
return self.proto.value
|
||
|
|
||
|
|
||
|
@dataclass(repr=False)
|
||
|
class Widget(Element, ABC):
|
||
|
"""Widget base class for testing."""
|
||
|
|
||
|
id: str = field(repr=False)
|
||
|
disabled: bool
|
||
|
key: str | None
|
||
|
_value: Any
|
||
|
|
||
|
def __init__(self, proto: Any, root: ElementTree) -> None:
|
||
|
self.proto = proto
|
||
|
self.root = root
|
||
|
self.key = user_key_from_element_id(self.id)
|
||
|
self._value = None
|
||
|
|
||
|
def set_value(self, v: Any) -> Self:
|
||
|
"""Set the value of the widget."""
|
||
|
self._value = v
|
||
|
return self
|
||
|
|
||
|
@property
|
||
|
@abstractmethod
|
||
|
def _widget_state(self) -> WidgetState: ...
|
||
|
|
||
|
|
||
|
El_co = TypeVar("El_co", bound=Element, covariant=True)
|
||
|
|
||
|
|
||
|
class ElementList(Generic[El_co]):
|
||
|
def __init__(self, els: Sequence[El_co]) -> None:
|
||
|
self._list: Sequence[El_co] = els
|
||
|
|
||
|
def __len__(self) -> int:
|
||
|
return len(self._list)
|
||
|
|
||
|
@property
|
||
|
def len(self) -> int:
|
||
|
return len(self)
|
||
|
|
||
|
@overload
|
||
|
def __getitem__(self, idx: int) -> El_co: ...
|
||
|
|
||
|
@overload
|
||
|
def __getitem__(self, idx: slice) -> ElementList[El_co]: ...
|
||
|
|
||
|
def __getitem__(self, idx: int | slice) -> El_co | ElementList[El_co]:
|
||
|
if isinstance(idx, slice):
|
||
|
return ElementList(self._list[idx])
|
||
|
return self._list[idx]
|
||
|
|
||
|
def __iter__(self) -> Iterator[El_co]:
|
||
|
yield from self._list
|
||
|
|
||
|
def __repr__(self) -> str:
|
||
|
return util.repr_(self)
|
||
|
|
||
|
def __eq__(self, other: ElementList[El_co] | object) -> bool:
|
||
|
if isinstance(other, ElementList):
|
||
|
return self._list == other._list
|
||
|
return self._list == other
|
||
|
|
||
|
def __hash__(self) -> int:
|
||
|
return hash(tuple(self._list))
|
||
|
|
||
|
@property
|
||
|
def values(self) -> Sequence[Any]:
|
||
|
return [e.value for e in self]
|
||
|
|
||
|
|
||
|
W_co = TypeVar("W_co", bound=Widget, covariant=True)
|
||
|
|
||
|
|
||
|
class WidgetList(ElementList[W_co], Generic[W_co]):
|
||
|
def __call__(self, key: str) -> W_co:
|
||
|
for e in self._list:
|
||
|
if e.key == key:
|
||
|
return e
|
||
|
|
||
|
raise KeyError(key)
|
||
|
|
||
|
|
||
|
@dataclass(repr=False)
|
||
|
class AlertBase(Element):
|
||
|
proto: AlertProto = field(repr=False)
|
||
|
icon: str
|
||
|
|
||
|
def __init__(self, proto: AlertProto, root: ElementTree) -> None:
|
||
|
self.proto = proto
|
||
|
self.key = None
|
||
|
self.root = root
|
||
|
|
||
|
@property
|
||
|
def value(self) -> str:
|
||
|
return self.proto.body
|
||
|
|
||
|
|
||
|
@dataclass(repr=False)
|
||
|
class Error(AlertBase):
|
||
|
def __init__(self, proto: AlertProto, root: ElementTree) -> None:
|
||
|
super().__init__(proto, root)
|
||
|
self.type = "error"
|
||
|
|
||
|
|
||
|
@dataclass(repr=False)
|
||
|
class Warning(AlertBase): # noqa: A001
|
||
|
def __init__(self, proto: AlertProto, root: ElementTree) -> None:
|
||
|
super().__init__(proto, root)
|
||
|
self.type = "warning"
|
||
|
|
||
|
|
||
|
@dataclass(repr=False)
|
||
|
class Info(AlertBase):
|
||
|
def __init__(self, proto: AlertProto, root: ElementTree) -> None:
|
||
|
super().__init__(proto, root)
|
||
|
self.type = "info"
|
||
|
|
||
|
|
||
|
@dataclass(repr=False)
|
||
|
class Success(AlertBase):
|
||
|
def __init__(self, proto: AlertProto, root: ElementTree) -> None:
|
||
|
super().__init__(proto, root)
|
||
|
self.type = "success"
|
||
|
|
||
|
|
||
|
@dataclass(repr=False)
|
||
|
class Button(Widget):
|
||
|
"""A representation of ``st.button`` and ``st.form_submit_button``."""
|
||
|
|
||
|
_value: bool
|
||
|
|
||
|
proto: ButtonProto = field(repr=False)
|
||
|
label: str
|
||
|
help: str
|
||
|
form_id: str
|
||
|
|
||
|
def __init__(self, proto: ButtonProto, root: ElementTree) -> None:
|
||
|
super().__init__(proto, root)
|
||
|
self._value = False
|
||
|
self.type = "button"
|
||
|
|
||
|
@property
|
||
|
def _widget_state(self) -> WidgetState:
|
||
|
ws = WidgetState()
|
||
|
ws.id = self.id
|
||
|
ws.trigger_value = self._value
|
||
|
return ws
|
||
|
|
||
|
@property
|
||
|
def value(self) -> bool:
|
||
|
"""The value of the button. (bool)""" # noqa: D400
|
||
|
if self._value:
|
||
|
return self._value
|
||
|
state = self.root.session_state
|
||
|
assert state
|
||
|
return cast("bool", state[TESTING_KEY][self.id])
|
||
|
|
||
|
def set_value(self, v: bool) -> Button:
|
||
|
"""Set the value of the button."""
|
||
|
self._value = v
|
||
|
return self
|
||
|
|
||
|
def click(self) -> Button:
|
||
|
"""Set the value of the button to True."""
|
||
|
return self.set_value(True)
|
||
|
|
||
|
|
||
|
@dataclass(repr=False)
|
||
|
class ChatInput(Widget):
|
||
|
"""A representation of ``st.chat_input``."""
|
||
|
|
||
|
_value: str | None
|
||
|
proto: ChatInputProto = field(repr=False)
|
||
|
placeholder: str
|
||
|
|
||
|
def __init__(self, proto: ChatInputProto, root: ElementTree) -> None:
|
||
|
super().__init__(proto, root)
|
||
|
self.type = "chat_input"
|
||
|
|
||
|
def set_value(self, v: str | None) -> ChatInput:
|
||
|
"""Set the value of the widget."""
|
||
|
self._value = v
|
||
|
return self
|
||
|
|
||
|
@property
|
||
|
def _widget_state(self) -> WidgetState:
|
||
|
ws = WidgetState()
|
||
|
ws.id = self.id
|
||
|
if self._value is not None:
|
||
|
ws.string_trigger_value.data = self._value
|
||
|
return ws
|
||
|
|
||
|
@property
|
||
|
def value(self) -> str | None:
|
||
|
"""The value of the widget. (str)""" # noqa: D400
|
||
|
if self._value:
|
||
|
return self._value
|
||
|
state = self.root.session_state
|
||
|
assert state
|
||
|
return state[TESTING_KEY][self.id] # type: ignore
|
||
|
|
||
|
|
||
|
@dataclass(repr=False)
|
||
|
class Checkbox(Widget):
|
||
|
"""A representation of ``st.checkbox``."""
|
||
|
|
||
|
_value: bool | None
|
||
|
|
||
|
proto: CheckboxProto = field(repr=False)
|
||
|
label: str
|
||
|
help: str
|
||
|
form_id: str
|
||
|
|
||
|
def __init__(self, proto: CheckboxProto, root: ElementTree) -> None:
|
||
|
super().__init__(proto, root)
|
||
|
self.type = "checkbox"
|
||
|
|
||
|
@property
|
||
|
def _widget_state(self) -> WidgetState:
|
||
|
ws = WidgetState()
|
||
|
ws.id = self.id
|
||
|
ws.bool_value = self.value
|
||
|
return ws
|
||
|
|
||
|
@property
|
||
|
def value(self) -> bool:
|
||
|
"""The value of the widget. (bool)""" # noqa: D400
|
||
|
if self._value is not None:
|
||
|
return self._value
|
||
|
state = self.root.session_state
|
||
|
assert state
|
||
|
return cast("bool", state[self.id])
|
||
|
|
||
|
def set_value(self, v: bool) -> Checkbox:
|
||
|
"""Set the value of the widget."""
|
||
|
self._value = v
|
||
|
return self
|
||
|
|
||
|
def check(self) -> Checkbox:
|
||
|
"""Set the value of the widget to True."""
|
||
|
return self.set_value(True)
|
||
|
|
||
|
def uncheck(self) -> Checkbox:
|
||
|
"""Set the value of the widget to False."""
|
||
|
return self.set_value(False)
|
||
|
|
||
|
|
||
|
@dataclass(repr=False)
|
||
|
class Code(Element):
|
||
|
"""A representation of ``st.code``."""
|
||
|
|
||
|
proto: CodeProto = field(repr=False)
|
||
|
|
||
|
language: str
|
||
|
show_line_numbers: bool
|
||
|
key: None
|
||
|
|
||
|
def __init__(self, proto: CodeProto, root: ElementTree) -> None:
|
||
|
self.proto = proto
|
||
|
self.key = None
|
||
|
self.root = root
|
||
|
self.type = "code"
|
||
|
|
||
|
@property
|
||
|
def value(self) -> str:
|
||
|
"""The value of the element. (str)""" # noqa: D400
|
||
|
return self.proto.code_text
|
||
|
|
||
|
|
||
|
@dataclass(repr=False)
|
||
|
class ColorPicker(Widget):
|
||
|
"""A representation of ``st.color_picker``."""
|
||
|
|
||
|
_value: str | None
|
||
|
label: str
|
||
|
help: str
|
||
|
form_id: str
|
||
|
|
||
|
proto: ColorPickerProto = field(repr=False)
|
||
|
|
||
|
def __init__(self, proto: ColorPickerProto, root: ElementTree) -> None:
|
||
|
super().__init__(proto, root)
|
||
|
self.type = "color_picker"
|
||
|
|
||
|
@property
|
||
|
def value(self) -> str:
|
||
|
"""The currently selected value as a hex string. (str)""" # noqa: D400
|
||
|
if self._value is not None:
|
||
|
return self._value
|
||
|
state = self.root.session_state
|
||
|
assert state
|
||
|
return cast("str", state[self.id])
|
||
|
|
||
|
@property
|
||
|
def _widget_state(self) -> WidgetState:
|
||
|
"""Protobuf message representing the state of the widget, including
|
||
|
any interactions that have happened.
|
||
|
Should be the same as the frontend would produce for those interactions.
|
||
|
"""
|
||
|
ws = WidgetState()
|
||
|
ws.id = self.id
|
||
|
ws.string_value = self.value
|
||
|
return ws
|
||
|
|
||
|
def set_value(self, v: str) -> ColorPicker:
|
||
|
"""Set the value of the widget as a hex string."""
|
||
|
self._value = v
|
||
|
return self
|
||
|
|
||
|
def pick(self, v: str) -> ColorPicker:
|
||
|
"""Set the value of the widget as a hex string. May omit the "#" prefix."""
|
||
|
if not v.startswith("#"):
|
||
|
v = f"#{v}"
|
||
|
return self.set_value(v)
|
||
|
|
||
|
|
||
|
@dataclass(repr=False)
|
||
|
class Dataframe(Element):
|
||
|
proto: ArrowProto = field(repr=False)
|
||
|
|
||
|
def __init__(self, proto: ArrowProto, root: ElementTree) -> None:
|
||
|
self.key = None
|
||
|
self.proto = proto
|
||
|
self.root = root
|
||
|
self.type = "arrow_data_frame"
|
||
|
|
||
|
@property
|
||
|
def value(self) -> PandasDataframe:
|
||
|
return dataframe_util.convert_arrow_bytes_to_pandas_df(self.proto.data)
|
||
|
|
||
|
|
||
|
SingleDateValue: TypeAlias = Union[date, datetime]
|
||
|
DateValue: TypeAlias = Union[SingleDateValue, Sequence[SingleDateValue], None]
|
||
|
|
||
|
|
||
|
@dataclass(repr=False)
|
||
|
class DateInput(Widget):
|
||
|
"""A representation of ``st.date_input``."""
|
||
|
|
||
|
_value: DateValue | None | InitialValue
|
||
|
proto: DateInputProto = field(repr=False)
|
||
|
label: str
|
||
|
min: date
|
||
|
max: date
|
||
|
is_range: bool
|
||
|
help: str
|
||
|
form_id: str
|
||
|
|
||
|
def __init__(self, proto: DateInputProto, root: ElementTree) -> None:
|
||
|
super().__init__(proto, root)
|
||
|
self._value = InitialValue()
|
||
|
self.type = "date_input"
|
||
|
self.min = datetime.strptime(proto.min, "%Y/%m/%d").date()
|
||
|
self.max = datetime.strptime(proto.max, "%Y/%m/%d").date()
|
||
|
|
||
|
def set_value(self, v: DateValue) -> DateInput:
|
||
|
"""Set the value of the widget."""
|
||
|
self._value = v
|
||
|
return self
|
||
|
|
||
|
@property
|
||
|
def _widget_state(self) -> WidgetState:
|
||
|
ws = WidgetState()
|
||
|
ws.id = self.id
|
||
|
|
||
|
serde = DateInputSerde(None) # type: ignore
|
||
|
ws.string_array_value.data[:] = serde.serialize(self.value)
|
||
|
return ws
|
||
|
|
||
|
@property
|
||
|
def value(self) -> DateWidgetReturn:
|
||
|
"""The value of the widget. (date or Tuple of date)""" # noqa: D400
|
||
|
if not isinstance(self._value, InitialValue):
|
||
|
parsed, _ = _parse_date_value(self._value)
|
||
|
return tuple(parsed) if parsed is not None else None # type: ignore
|
||
|
state = self.root.session_state
|
||
|
assert state
|
||
|
return state[self.id] # type: ignore
|
||
|
|
||
|
|
||
|
@dataclass(repr=False)
|
||
|
class Exception(Element): # noqa: A001
|
||
|
message: str
|
||
|
is_markdown: bool
|
||
|
stack_trace: list[str]
|
||
|
is_warning: bool
|
||
|
|
||
|
def __init__(self, proto: ExceptionProto, root: ElementTree) -> None:
|
||
|
self.key = None
|
||
|
self.root = root
|
||
|
self.proto = proto
|
||
|
self.type = "exception"
|
||
|
|
||
|
self.is_markdown = proto.message_is_markdown
|
||
|
self.stack_trace = list(proto.stack_trace)
|
||
|
|
||
|
@property
|
||
|
def value(self) -> str:
|
||
|
return self.message
|
||
|
|
||
|
|
||
|
@dataclass(repr=False)
|
||
|
class HeadingBase(Element, ABC):
|
||
|
proto: HeadingProto = field(repr=False)
|
||
|
|
||
|
tag: str
|
||
|
anchor: str | None
|
||
|
hide_anchor: bool
|
||
|
key: None
|
||
|
|
||
|
def __init__(self, proto: HeadingProto, root: ElementTree, type_: str) -> None:
|
||
|
self.proto = proto
|
||
|
self.key = None
|
||
|
self.root = root
|
||
|
self.type = type_
|
||
|
|
||
|
@property
|
||
|
def value(self) -> str:
|
||
|
return self.proto.body
|
||
|
|
||
|
|
||
|
@dataclass(repr=False)
|
||
|
class Header(HeadingBase):
|
||
|
def __init__(self, proto: HeadingProto, root: ElementTree) -> None:
|
||
|
super().__init__(proto, root, "header")
|
||
|
|
||
|
|
||
|
@dataclass(repr=False)
|
||
|
class Subheader(HeadingBase):
|
||
|
def __init__(self, proto: HeadingProto, root: ElementTree) -> None:
|
||
|
super().__init__(proto, root, "subheader")
|
||
|
|
||
|
|
||
|
@dataclass(repr=False)
|
||
|
class Title(HeadingBase):
|
||
|
def __init__(self, proto: HeadingProto, root: ElementTree) -> None:
|
||
|
super().__init__(proto, root, "title")
|
||
|
|
||
|
|
||
|
@dataclass(repr=False)
|
||
|
class Json(Element):
|
||
|
proto: JsonProto = field(repr=False)
|
||
|
|
||
|
expanded: bool
|
||
|
|
||
|
def __init__(self, proto: JsonProto, root: ElementTree) -> None:
|
||
|
self.proto = proto
|
||
|
self.key = None
|
||
|
self.root = root
|
||
|
self.type = "json"
|
||
|
|
||
|
@property
|
||
|
def value(self) -> str:
|
||
|
return self.proto.body
|
||
|
|
||
|
|
||
|
@dataclass(repr=False)
|
||
|
class Markdown(Element):
|
||
|
proto: MarkdownProto = field(repr=False)
|
||
|
|
||
|
is_caption: bool
|
||
|
allow_html: bool
|
||
|
key: None
|
||
|
|
||
|
def __init__(self, proto: MarkdownProto, root: ElementTree) -> None:
|
||
|
self.proto = proto
|
||
|
self.key = None
|
||
|
self.root = root
|
||
|
self.type = "markdown"
|
||
|
|
||
|
@property
|
||
|
def value(self) -> str:
|
||
|
return self.proto.body
|
||
|
|
||
|
|
||
|
@dataclass(repr=False)
|
||
|
class Caption(Markdown):
|
||
|
def __init__(self, proto: MarkdownProto, root: ElementTree) -> None:
|
||
|
super().__init__(proto, root)
|
||
|
self.type = "caption"
|
||
|
|
||
|
|
||
|
@dataclass(repr=False)
|
||
|
class Divider(Markdown):
|
||
|
def __init__(self, proto: MarkdownProto, root: ElementTree) -> None:
|
||
|
super().__init__(proto, root)
|
||
|
self.type = "divider"
|
||
|
|
||
|
|
||
|
@dataclass(repr=False)
|
||
|
class Latex(Markdown):
|
||
|
def __init__(self, proto: MarkdownProto, root: ElementTree) -> None:
|
||
|
super().__init__(proto, root)
|
||
|
self.type = "latex"
|
||
|
|
||
|
|
||
|
@dataclass(repr=False)
|
||
|
class Metric(Element):
|
||
|
proto: MetricProto
|
||
|
label: str
|
||
|
delta: str
|
||
|
color: str
|
||
|
help: str
|
||
|
|
||
|
def __init__(self, proto: MetricProto, root: ElementTree) -> None:
|
||
|
self.proto = proto
|
||
|
self.key = None
|
||
|
self.root = root
|
||
|
self.type = "metric"
|
||
|
|
||
|
@property
|
||
|
def value(self) -> str:
|
||
|
return self.proto.body
|
||
|
|
||
|
|
||
|
@dataclass(repr=False)
|
||
|
class ButtonGroup(Widget, Generic[T]):
|
||
|
"""A representation of button_group that is used by ``st.feedback``."""
|
||
|
|
||
|
_value: list[T] | None
|
||
|
|
||
|
proto: ButtonGroupProto = field(repr=False)
|
||
|
options: list[ButtonGroupProto.Option]
|
||
|
form_id: str
|
||
|
|
||
|
def __init__(self, proto: ButtonGroupProto, root: ElementTree) -> None:
|
||
|
super().__init__(proto, root)
|
||
|
self.type = "button_group"
|
||
|
self.options = list(proto.options)
|
||
|
|
||
|
@property
|
||
|
def _widget_state(self) -> WidgetState:
|
||
|
"""Protobuf message representing the state of the widget, including
|
||
|
any interactions that have happened.
|
||
|
Should be the same as the frontend would produce for those interactions.
|
||
|
"""
|
||
|
ws = WidgetState()
|
||
|
ws.id = self.id
|
||
|
ws.int_array_value.data[:] = self.indices
|
||
|
return ws
|
||
|
|
||
|
@property
|
||
|
def value(self) -> list[T]:
|
||
|
"""The currently selected values from the options. (list)""" # noqa: D400
|
||
|
if self._value is not None:
|
||
|
return self._value
|
||
|
state = self.root.session_state
|
||
|
assert state
|
||
|
return cast("list[T]", state[self.id])
|
||
|
|
||
|
@property
|
||
|
def indices(self) -> Sequence[int]:
|
||
|
"""The indices of the currently selected values from the options. (list)""" # noqa: D400
|
||
|
return [self.options.index(self.format_func(v)) for v in self.value]
|
||
|
|
||
|
@property
|
||
|
def format_func(self) -> Callable[[Any], Any]:
|
||
|
"""The widget's formatting function for displaying options. (callable)""" # noqa: D400
|
||
|
ss = self.root.session_state
|
||
|
return cast("Callable[[Any], Any]", ss[TESTING_KEY][self.id])
|
||
|
|
||
|
def set_value(self, v: list[T]) -> ButtonGroup[T]:
|
||
|
"""Set the value of the multiselect widget. (list)""" # noqa: D400
|
||
|
|
||
|
self._value = v
|
||
|
return self
|
||
|
|
||
|
def select(self, v: T) -> ButtonGroup[T]:
|
||
|
"""
|
||
|
Add a selection to the widget. Do nothing if the value is already selected.\
|
||
|
If testing a multiselect widget with repeated options, use ``set_value``\
|
||
|
instead.
|
||
|
"""
|
||
|
current = self.value
|
||
|
if v in current:
|
||
|
return self
|
||
|
new = current.copy()
|
||
|
new.append(v)
|
||
|
self.set_value(new)
|
||
|
return self
|
||
|
|
||
|
def unselect(self, v: T) -> ButtonGroup[T]:
|
||
|
"""
|
||
|
Remove a selection from the widget. Do nothing if the value is not\
|
||
|
already selected. If a value is selected multiple times, the first\
|
||
|
instance is removed.
|
||
|
"""
|
||
|
current = self.value
|
||
|
if v not in current:
|
||
|
return self
|
||
|
new = current.copy()
|
||
|
while v in new:
|
||
|
new.remove(v)
|
||
|
self.set_value(new)
|
||
|
return self
|
||
|
|
||
|
|
||
|
@dataclass(repr=False)
|
||
|
class Multiselect(Widget, Generic[T]):
|
||
|
"""A representation of ``st.multiselect``."""
|
||
|
|
||
|
_value: list[T] | None
|
||
|
|
||
|
proto: MultiSelectProto = field(repr=False)
|
||
|
label: str
|
||
|
options: list[str]
|
||
|
max_selections: int
|
||
|
help: str
|
||
|
form_id: str
|
||
|
|
||
|
def __init__(self, proto: MultiSelectProto, root: ElementTree) -> None:
|
||
|
super().__init__(proto, root)
|
||
|
self.type = "multiselect"
|
||
|
self.options = list(proto.options)
|
||
|
|
||
|
@property
|
||
|
def _widget_state(self) -> WidgetState:
|
||
|
"""Protobuf message representing the state of the widget, including
|
||
|
any interactions that have happened.
|
||
|
Should be the same as the frontend would produce for those interactions.
|
||
|
"""
|
||
|
ws = WidgetState()
|
||
|
ws.id = self.id
|
||
|
ws.string_array_value.data[:] = self.values
|
||
|
return ws
|
||
|
|
||
|
@property
|
||
|
def value(self) -> list[T]:
|
||
|
"""The currently selected values from the options. (list)""" # noqa: D400
|
||
|
if self._value is not None:
|
||
|
return self._value
|
||
|
state = self.root.session_state
|
||
|
assert state
|
||
|
return cast("list[T]", state[self.id])
|
||
|
|
||
|
@property
|
||
|
def indices(self) -> Sequence[int]:
|
||
|
"""The indices of the currently selected values from the options. (list)""" # noqa: D400
|
||
|
return [self.options.index(self.format_func(v)) for v in self.value]
|
||
|
|
||
|
@property
|
||
|
def values(self) -> Sequence[str]:
|
||
|
"""The currently selected values from the options. (list)""" # noqa: D400
|
||
|
return [self.format_func(v) for v in self.value]
|
||
|
|
||
|
@property
|
||
|
def format_func(self) -> Callable[[Any], Any]:
|
||
|
"""The widget's formatting function for displaying options. (callable)""" # noqa: D400
|
||
|
ss = self.root.session_state
|
||
|
return cast("Callable[[Any], Any]", ss[TESTING_KEY][self.id])
|
||
|
|
||
|
def set_value(self, v: list[T]) -> Multiselect[T]:
|
||
|
"""Set the value of the multiselect widget. (list)""" # noqa: D400
|
||
|
|
||
|
self._value = v
|
||
|
return self
|
||
|
|
||
|
def select(self, v: T) -> Multiselect[T]:
|
||
|
"""
|
||
|
Add a selection to the widget. Do nothing if the value is already selected.\
|
||
|
If testing a multiselect widget with repeated options, use ``set_value``\
|
||
|
instead.
|
||
|
"""
|
||
|
current = self.value
|
||
|
if v in current:
|
||
|
return self
|
||
|
new = current.copy()
|
||
|
new.append(v)
|
||
|
self.set_value(new)
|
||
|
return self
|
||
|
|
||
|
def unselect(self, v: T) -> Multiselect[T]:
|
||
|
"""
|
||
|
Remove a selection from the widget. Do nothing if the value is not\
|
||
|
already selected. If a value is selected multiple times, the first\
|
||
|
instance is removed.
|
||
|
"""
|
||
|
current = self.value
|
||
|
if v not in current:
|
||
|
return self
|
||
|
new = current.copy()
|
||
|
while v in new:
|
||
|
new.remove(v)
|
||
|
self.set_value(new)
|
||
|
return self
|
||
|
|
||
|
|
||
|
Number = Union[int, float]
|
||
|
|
||
|
|
||
|
@dataclass(repr=False)
|
||
|
class NumberInput(Widget):
|
||
|
"""A representation of ``st.number_input``."""
|
||
|
|
||
|
_value: Number | None | InitialValue
|
||
|
proto: NumberInputProto = field(repr=False)
|
||
|
label: str
|
||
|
min: Number | None
|
||
|
max: Number | None
|
||
|
step: Number
|
||
|
help: str
|
||
|
form_id: str
|
||
|
|
||
|
def __init__(self, proto: NumberInputProto, root: ElementTree) -> None:
|
||
|
super().__init__(proto, root)
|
||
|
self._value = InitialValue()
|
||
|
self.type = "number_input"
|
||
|
self.min = proto.min if proto.has_min else None
|
||
|
self.max = proto.max if proto.has_max else None
|
||
|
|
||
|
def set_value(self, v: Number | None) -> NumberInput:
|
||
|
"""Set the value of the ``st.number_input`` widget."""
|
||
|
self._value = v
|
||
|
return self
|
||
|
|
||
|
@property
|
||
|
def _widget_state(self) -> WidgetState:
|
||
|
ws = WidgetState()
|
||
|
ws.id = self.id
|
||
|
if self.value is not None:
|
||
|
ws.double_value = self.value
|
||
|
return ws
|
||
|
|
||
|
@property
|
||
|
def value(self) -> Number | None:
|
||
|
"""Get the current value of the ``st.number_input`` widget."""
|
||
|
if not isinstance(self._value, InitialValue):
|
||
|
return self._value
|
||
|
state = self.root.session_state
|
||
|
assert state
|
||
|
|
||
|
# Awkward to do this with `cast`
|
||
|
return state[self.id] # type: ignore
|
||
|
|
||
|
def increment(self) -> NumberInput:
|
||
|
"""Increment the ``st.number_input`` widget as if the user clicked "+"."""
|
||
|
if self.value is None:
|
||
|
return self
|
||
|
|
||
|
v = min(self.value + self.step, self.max or float("inf"))
|
||
|
return self.set_value(v)
|
||
|
|
||
|
def decrement(self) -> NumberInput:
|
||
|
"""Decrement the ``st.number_input`` widget as if the user clicked "-"."""
|
||
|
if self.value is None:
|
||
|
return self
|
||
|
|
||
|
v = max(self.value - self.step, self.min or float("-inf"))
|
||
|
return self.set_value(v)
|
||
|
|
||
|
|
||
|
@dataclass(repr=False)
|
||
|
class Radio(Widget, Generic[T]):
|
||
|
"""A representation of ``st.radio``."""
|
||
|
|
||
|
_value: T | None | InitialValue
|
||
|
|
||
|
proto: RadioProto = field(repr=False)
|
||
|
label: str
|
||
|
options: list[str]
|
||
|
horizontal: bool
|
||
|
help: str
|
||
|
form_id: str
|
||
|
|
||
|
def __init__(self, proto: RadioProto, root: ElementTree) -> None:
|
||
|
super().__init__(proto, root)
|
||
|
self._value = InitialValue()
|
||
|
self.type = "radio"
|
||
|
self.options = list(proto.options)
|
||
|
|
||
|
@property
|
||
|
def index(self) -> int | None:
|
||
|
"""The index of the current selection. (int)""" # noqa: D400
|
||
|
if self.value is None:
|
||
|
return None
|
||
|
return self.options.index(self.format_func(self.value))
|
||
|
|
||
|
@property
|
||
|
def value(self) -> T | None:
|
||
|
"""The currently selected value from the options. (Any)""" # noqa: D400
|
||
|
if not isinstance(self._value, InitialValue):
|
||
|
return self._value
|
||
|
state = self.root.session_state
|
||
|
assert state
|
||
|
return cast("T", state[self.id])
|
||
|
|
||
|
@property
|
||
|
def format_func(self) -> Callable[[Any], Any]:
|
||
|
"""The widget's formatting function for displaying options. (callable)""" # noqa: D400
|
||
|
ss = self.root.session_state
|
||
|
return cast("Callable[[Any], Any]", ss[TESTING_KEY][self.id])
|
||
|
|
||
|
def set_value(self, v: T | None) -> Radio[T]:
|
||
|
"""Set the selection by value."""
|
||
|
self._value = v
|
||
|
return self
|
||
|
|
||
|
@property
|
||
|
def _widget_state(self) -> WidgetState:
|
||
|
"""Protobuf message representing the state of the widget, including
|
||
|
any interactions that have happened.
|
||
|
Should be the same as the frontend would produce for those interactions.
|
||
|
"""
|
||
|
ws = WidgetState()
|
||
|
ws.id = self.id
|
||
|
if self.index is not None:
|
||
|
ws.int_value = self.index
|
||
|
return ws
|
||
|
|
||
|
|
||
|
@dataclass(repr=False)
|
||
|
class Selectbox(Widget, Generic[T]):
|
||
|
"""A representation of ``st.selectbox``."""
|
||
|
|
||
|
_value: T | None | InitialValue
|
||
|
|
||
|
proto: SelectboxProto = field(repr=False)
|
||
|
label: str
|
||
|
options: list[str]
|
||
|
help: str
|
||
|
form_id: str
|
||
|
|
||
|
def __init__(self, proto: SelectboxProto, root: ElementTree) -> None:
|
||
|
super().__init__(proto, root)
|
||
|
self._value = InitialValue()
|
||
|
self.type = "selectbox"
|
||
|
self.options = list(proto.options)
|
||
|
|
||
|
@property
|
||
|
def index(self) -> int | None:
|
||
|
"""The index of the current selection. (int)""" # noqa: D400
|
||
|
if self.value is None:
|
||
|
return None
|
||
|
|
||
|
if len(self.options) == 0:
|
||
|
return 0
|
||
|
return self.options.index(self.format_func(self.value))
|
||
|
|
||
|
@property
|
||
|
def value(self) -> T | None:
|
||
|
"""The currently selected value from the options. (Any)""" # noqa: D400
|
||
|
if not isinstance(self._value, InitialValue):
|
||
|
return self._value
|
||
|
state = self.root.session_state
|
||
|
assert state
|
||
|
return cast("T", state[self.id])
|
||
|
|
||
|
@property
|
||
|
def format_func(self) -> Callable[[Any], Any]:
|
||
|
"""The widget's formatting function for displaying options. (callable)""" # noqa: D400
|
||
|
ss = self.root.session_state
|
||
|
return cast("Callable[[Any], Any]", ss[TESTING_KEY][self.id])
|
||
|
|
||
|
def set_value(self, v: T | None) -> Selectbox[T]:
|
||
|
"""Set the selection by value."""
|
||
|
self._value = v
|
||
|
return self
|
||
|
|
||
|
def select(self, v: T | None) -> Selectbox[T]:
|
||
|
"""Set the selection by value."""
|
||
|
return self.set_value(v)
|
||
|
|
||
|
def select_index(self, index: int | None) -> Selectbox[T]:
|
||
|
"""Set the selection by index."""
|
||
|
if index is None:
|
||
|
return self.set_value(None)
|
||
|
return self.set_value(cast("T", self.options[index]))
|
||
|
|
||
|
@property
|
||
|
def _widget_state(self) -> WidgetState:
|
||
|
"""Protobuf message representing the state of the widget, including
|
||
|
any interactions that have happened.
|
||
|
Should be the same as the frontend would produce for those interactions.
|
||
|
"""
|
||
|
ws = WidgetState()
|
||
|
ws.id = self.id
|
||
|
if self.index is not None and len(self.options) > 0:
|
||
|
ws.string_value = self.options[self.index]
|
||
|
return ws
|
||
|
|
||
|
|
||
|
@dataclass(repr=False)
|
||
|
class SelectSlider(Widget, Generic[T]):
|
||
|
"""A representation of ``st.select_slider``."""
|
||
|
|
||
|
_value: T | Sequence[T] | None
|
||
|
|
||
|
proto: SliderProto = field(repr=False)
|
||
|
label: str
|
||
|
data_type: SliderProto.DataType.ValueType
|
||
|
options: list[str]
|
||
|
help: str
|
||
|
form_id: str
|
||
|
|
||
|
def __init__(self, proto: SliderProto, root: ElementTree) -> None:
|
||
|
super().__init__(proto, root)
|
||
|
self.type = "select_slider"
|
||
|
self.options = list(proto.options)
|
||
|
|
||
|
def set_value(self, v: T | Sequence[T]) -> SelectSlider[T]:
|
||
|
"""Set the (single) selection by value."""
|
||
|
self._value = v
|
||
|
return self
|
||
|
|
||
|
@property
|
||
|
def _widget_state(self) -> WidgetState:
|
||
|
serde = SelectSliderSerde(self.options, [], False)
|
||
|
try:
|
||
|
v = serde.serialize(self.format_func(self.value))
|
||
|
except (ValueError, TypeError):
|
||
|
try:
|
||
|
v = serde.serialize([self.format_func(val) for val in self.value]) # type: ignore
|
||
|
except: # noqa: E722
|
||
|
raise ValueError(f"Could not find index for {self.value}")
|
||
|
|
||
|
ws = WidgetState()
|
||
|
ws.id = self.id
|
||
|
ws.double_array_value.data[:] = v
|
||
|
return ws
|
||
|
|
||
|
@property
|
||
|
def value(self) -> T | Sequence[T]:
|
||
|
"""The currently selected value or range. (Any or Sequence of Any)""" # noqa: D400
|
||
|
if self._value is not None:
|
||
|
return self._value
|
||
|
state = self.root.session_state
|
||
|
assert state
|
||
|
# Awkward to do this with `cast`
|
||
|
return state[self.id] # type: ignore
|
||
|
|
||
|
@property
|
||
|
def format_func(self) -> Callable[[Any], Any]:
|
||
|
"""The widget's formatting function for displaying options. (callable)""" # noqa: D400
|
||
|
ss = self.root.session_state
|
||
|
return cast("Callable[[Any], Any]", ss[TESTING_KEY][self.id])
|
||
|
|
||
|
def set_range(self, lower: T, upper: T) -> SelectSlider[T]:
|
||
|
"""Set the ranged selection by values."""
|
||
|
return self.set_value([lower, upper])
|
||
|
|
||
|
|
||
|
@dataclass(repr=False)
|
||
|
class Slider(Widget, Generic[SliderValueT]):
|
||
|
"""A representation of ``st.slider``."""
|
||
|
|
||
|
_value: SliderValueT | Sequence[SliderValueT] | None
|
||
|
|
||
|
proto: SliderProto = field(repr=False)
|
||
|
label: str
|
||
|
data_type: SliderProto.DataType.ValueType
|
||
|
min: SliderValueT
|
||
|
max: SliderValueT
|
||
|
step: SliderStep
|
||
|
help: str
|
||
|
form_id: str
|
||
|
|
||
|
def __init__(self, proto: SliderProto, root: ElementTree) -> None:
|
||
|
super().__init__(proto, root)
|
||
|
self.type = "slider"
|
||
|
|
||
|
def set_value(
|
||
|
self, v: SliderValueT | Sequence[SliderValueT]
|
||
|
) -> Slider[SliderValueT]:
|
||
|
"""Set the (single) value of the slider."""
|
||
|
self._value = v
|
||
|
return self
|
||
|
|
||
|
@property
|
||
|
def _widget_state(self) -> WidgetState:
|
||
|
data_type = self.proto.data_type
|
||
|
serde = SliderSerde([], data_type, True, None)
|
||
|
v = serde.serialize(self.value)
|
||
|
|
||
|
ws = WidgetState()
|
||
|
ws.id = self.id
|
||
|
ws.double_array_value.data[:] = v
|
||
|
return ws
|
||
|
|
||
|
@property
|
||
|
def value(self) -> SliderValueT | Sequence[SliderValueT]:
|
||
|
"""The currently selected value or range. (Any or Sequence of Any)""" # noqa: D400
|
||
|
if self._value is not None:
|
||
|
return self._value
|
||
|
state = self.root.session_state
|
||
|
assert state
|
||
|
# Awkward to do this with `cast`
|
||
|
return state[self.id] # type: ignore
|
||
|
|
||
|
def set_range(
|
||
|
self, lower: SliderValueT, upper: SliderValueT
|
||
|
) -> Slider[SliderValueT]:
|
||
|
"""Set the ranged value of the slider."""
|
||
|
return self.set_value([lower, upper])
|
||
|
|
||
|
|
||
|
@dataclass(repr=False)
|
||
|
class Table(Element):
|
||
|
proto: ArrowProto = field(repr=False)
|
||
|
|
||
|
def __init__(self, proto: ArrowProto, root: ElementTree) -> None:
|
||
|
self.key = None
|
||
|
self.proto = proto
|
||
|
self.root = root
|
||
|
self.type = "arrow_table"
|
||
|
|
||
|
@property
|
||
|
def value(self) -> PandasDataframe:
|
||
|
return dataframe_util.convert_arrow_bytes_to_pandas_df(self.proto.data)
|
||
|
|
||
|
|
||
|
@dataclass(repr=False)
|
||
|
class Text(Element):
|
||
|
proto: TextProto = field(repr=False)
|
||
|
|
||
|
key: None = None
|
||
|
|
||
|
def __init__(self, proto: TextProto, root: ElementTree) -> None:
|
||
|
self.proto = proto
|
||
|
self.root = root
|
||
|
self.type = "text"
|
||
|
|
||
|
@property
|
||
|
def value(self) -> str:
|
||
|
"""The value of the element. (str)""" # noqa: D400
|
||
|
return self.proto.body
|
||
|
|
||
|
|
||
|
@dataclass(repr=False)
|
||
|
class TextArea(Widget):
|
||
|
"""A representation of ``st.text_area``."""
|
||
|
|
||
|
_value: str | None | InitialValue
|
||
|
|
||
|
proto: TextAreaProto = field(repr=False)
|
||
|
label: str
|
||
|
max_chars: int
|
||
|
placeholder: str
|
||
|
help: str
|
||
|
form_id: str
|
||
|
|
||
|
def __init__(self, proto: TextAreaProto, root: ElementTree) -> None:
|
||
|
super().__init__(proto, root)
|
||
|
self._value = InitialValue()
|
||
|
self.type = "text_area"
|
||
|
|
||
|
def set_value(self, v: str | None) -> TextArea:
|
||
|
"""Set the value of the widget."""
|
||
|
self._value = v
|
||
|
return self
|
||
|
|
||
|
@property
|
||
|
def _widget_state(self) -> WidgetState:
|
||
|
ws = WidgetState()
|
||
|
ws.id = self.id
|
||
|
if self.value is not None:
|
||
|
ws.string_value = self.value
|
||
|
return ws
|
||
|
|
||
|
@property
|
||
|
def value(self) -> str | None:
|
||
|
"""The current value of the widget. (str)""" # noqa: D400
|
||
|
if not isinstance(self._value, InitialValue):
|
||
|
return self._value
|
||
|
state = self.root.session_state
|
||
|
assert state
|
||
|
# Awkward to do this with `cast`
|
||
|
return state[self.id] # type: ignore
|
||
|
|
||
|
def input(self, v: str) -> TextArea:
|
||
|
"""
|
||
|
Set the value of the widget only if the value does not exceed the\
|
||
|
maximum allowed characters.
|
||
|
"""
|
||
|
# TODO: should input be setting or appending?
|
||
|
if self.max_chars and len(v) > self.max_chars:
|
||
|
return self
|
||
|
return self.set_value(v)
|
||
|
|
||
|
|
||
|
@dataclass(repr=False)
|
||
|
class TextInput(Widget):
|
||
|
"""A representation of ``st.text_input``."""
|
||
|
|
||
|
_value: str | None | InitialValue
|
||
|
proto: TextInputProto = field(repr=False)
|
||
|
label: str
|
||
|
max_chars: int
|
||
|
autocomplete: str
|
||
|
placeholder: str
|
||
|
help: str
|
||
|
form_id: str
|
||
|
|
||
|
def __init__(self, proto: TextInputProto, root: ElementTree) -> None:
|
||
|
super().__init__(proto, root)
|
||
|
self._value = InitialValue()
|
||
|
self.type = "text_input"
|
||
|
|
||
|
def set_value(self, v: str | None) -> TextInput:
|
||
|
"""Set the value of the widget."""
|
||
|
self._value = v
|
||
|
return self
|
||
|
|
||
|
@property
|
||
|
def _widget_state(self) -> WidgetState:
|
||
|
ws = WidgetState()
|
||
|
ws.id = self.id
|
||
|
if self.value is not None:
|
||
|
ws.string_value = self.value
|
||
|
return ws
|
||
|
|
||
|
@property
|
||
|
def value(self) -> str | None:
|
||
|
"""The current value of the widget. (str)""" # noqa: D400
|
||
|
if not isinstance(self._value, InitialValue):
|
||
|
return self._value
|
||
|
state = self.root.session_state
|
||
|
assert state
|
||
|
# Awkward to do this with `cast`
|
||
|
return state[self.id] # type: ignore
|
||
|
|
||
|
def input(self, v: str) -> TextInput:
|
||
|
"""
|
||
|
Set the value of the widget only if the value does not exceed the\
|
||
|
maximum allowed characters.
|
||
|
"""
|
||
|
# TODO: should input be setting or appending?
|
||
|
if self.max_chars and len(v) > self.max_chars:
|
||
|
return self
|
||
|
return self.set_value(v)
|
||
|
|
||
|
|
||
|
TimeValue: TypeAlias = Union[time, datetime]
|
||
|
|
||
|
|
||
|
@dataclass(repr=False)
|
||
|
class TimeInput(Widget):
|
||
|
"""A representation of ``st.time_input``."""
|
||
|
|
||
|
_value: TimeValue | None | InitialValue
|
||
|
proto: TimeInputProto = field(repr=False)
|
||
|
label: str
|
||
|
step: int
|
||
|
help: str
|
||
|
form_id: str
|
||
|
|
||
|
def __init__(self, proto: TimeInputProto, root: ElementTree) -> None:
|
||
|
super().__init__(proto, root)
|
||
|
self._value = InitialValue()
|
||
|
self.type = "time_input"
|
||
|
|
||
|
def set_value(self, v: TimeValue | None) -> TimeInput:
|
||
|
"""Set the value of the widget."""
|
||
|
self._value = v
|
||
|
return self
|
||
|
|
||
|
@property
|
||
|
def _widget_state(self) -> WidgetState:
|
||
|
ws = WidgetState()
|
||
|
ws.id = self.id
|
||
|
|
||
|
serde = TimeInputSerde(None)
|
||
|
serialized_value = serde.serialize(self.value)
|
||
|
if serialized_value is not None:
|
||
|
ws.string_value = serialized_value
|
||
|
return ws
|
||
|
|
||
|
@property
|
||
|
def value(self) -> time | None:
|
||
|
"""The current value of the widget. (time)""" # noqa: D400
|
||
|
if not isinstance(self._value, InitialValue):
|
||
|
v = self._value
|
||
|
return v.time() if isinstance(v, datetime) else v
|
||
|
state = self.root.session_state
|
||
|
assert state
|
||
|
return state[self.id] # type: ignore
|
||
|
|
||
|
def increment(self) -> TimeInput:
|
||
|
"""Select the next available time."""
|
||
|
if self.value is None:
|
||
|
return self
|
||
|
dt = datetime.combine(date.today(), self.value) + timedelta(seconds=self.step)
|
||
|
return self.set_value(dt.time())
|
||
|
|
||
|
def decrement(self) -> TimeInput:
|
||
|
"""Select the previous available time."""
|
||
|
if self.value is None:
|
||
|
return self
|
||
|
dt = datetime.combine(date.today(), self.value) - timedelta(seconds=self.step)
|
||
|
return self.set_value(dt.time())
|
||
|
|
||
|
|
||
|
@dataclass(repr=False)
|
||
|
class Toast(Element):
|
||
|
proto: ToastProto = field(repr=False)
|
||
|
icon: str
|
||
|
|
||
|
def __init__(self, proto: ToastProto, root: ElementTree) -> None:
|
||
|
self.proto = proto
|
||
|
self.key = None
|
||
|
self.root = root
|
||
|
self.type = "toast"
|
||
|
|
||
|
@property
|
||
|
def value(self) -> str:
|
||
|
return self.proto.body
|
||
|
|
||
|
|
||
|
@dataclass(repr=False)
|
||
|
class Toggle(Widget):
|
||
|
"""A representation of ``st.toggle``."""
|
||
|
|
||
|
_value: bool | None
|
||
|
|
||
|
proto: CheckboxProto = field(repr=False)
|
||
|
label: str
|
||
|
help: str
|
||
|
form_id: str
|
||
|
|
||
|
def __init__(self, proto: CheckboxProto, root: ElementTree) -> None:
|
||
|
super().__init__(proto, root)
|
||
|
self._value = None
|
||
|
self.type = "toggle"
|
||
|
|
||
|
@property
|
||
|
def _widget_state(self) -> WidgetState:
|
||
|
ws = WidgetState()
|
||
|
ws.id = self.id
|
||
|
ws.bool_value = self.value
|
||
|
return ws
|
||
|
|
||
|
@property
|
||
|
def value(self) -> bool:
|
||
|
"""The current value of the widget. (bool)""" # noqa: D400
|
||
|
if self._value is not None:
|
||
|
return self._value
|
||
|
state = self.root.session_state
|
||
|
assert state
|
||
|
return cast("bool", state[self.id])
|
||
|
|
||
|
def set_value(self, v: bool) -> Toggle:
|
||
|
"""Set the value of the widget."""
|
||
|
self._value = v
|
||
|
return self
|
||
|
|
||
|
|
||
|
@dataclass(repr=False)
|
||
|
class Block:
|
||
|
"""A container of other elements.
|
||
|
|
||
|
Elements within a Block can be inspected and interacted with. This follows
|
||
|
the same syntax as inspecting and interacting within an ``AppTest`` object.
|
||
|
|
||
|
For all container classes, parameters of the original element can be
|
||
|
obtained as properties. For example, ``ChatMessage.avatar`` and
|
||
|
``Tab.label``.
|
||
|
"""
|
||
|
|
||
|
type: str
|
||
|
children: dict[int, Node]
|
||
|
proto: Any = field(repr=False)
|
||
|
root: ElementTree = field(repr=False)
|
||
|
|
||
|
def __init__(
|
||
|
self,
|
||
|
proto: BlockProto | None,
|
||
|
root: ElementTree,
|
||
|
) -> None:
|
||
|
self.children = {}
|
||
|
self.proto = proto
|
||
|
if proto:
|
||
|
ty = proto.WhichOneof("type")
|
||
|
if ty is not None:
|
||
|
self.type = ty
|
||
|
else:
|
||
|
# `st.container` has no sub-message
|
||
|
self.type = "container"
|
||
|
else:
|
||
|
self.type = "unknown"
|
||
|
self.root = root
|
||
|
|
||
|
def __len__(self) -> int:
|
||
|
return len(self.children)
|
||
|
|
||
|
def __iter__(self) -> Iterator[Node]:
|
||
|
yield self
|
||
|
for child_idx in self.children:
|
||
|
yield from self.children[child_idx]
|
||
|
|
||
|
def __getitem__(self, k: int) -> Node:
|
||
|
return self.children[k]
|
||
|
|
||
|
@property
|
||
|
def key(self) -> str | None:
|
||
|
return None
|
||
|
|
||
|
# We could implement these using __getattr__ but that would have
|
||
|
# much worse type information.
|
||
|
@property
|
||
|
def button(self) -> WidgetList[Button]:
|
||
|
return WidgetList(self.get("button")) # type: ignore
|
||
|
|
||
|
@property
|
||
|
def button_group(self) -> WidgetList[ButtonGroup[Any]]:
|
||
|
return WidgetList(self.get("button_group")) # type: ignore
|
||
|
|
||
|
@property
|
||
|
def caption(self) -> ElementList[Caption]:
|
||
|
return ElementList(self.get("caption")) # type: ignore
|
||
|
|
||
|
@property
|
||
|
def chat_input(self) -> WidgetList[ChatInput]:
|
||
|
return WidgetList(self.get("chat_input")) # type: ignore
|
||
|
|
||
|
@property
|
||
|
def chat_message(self) -> Sequence[ChatMessage]:
|
||
|
return self.get("chat_message") # type: ignore
|
||
|
|
||
|
@property
|
||
|
def checkbox(self) -> WidgetList[Checkbox]:
|
||
|
return WidgetList(self.get("checkbox")) # type: ignore
|
||
|
|
||
|
@property
|
||
|
def code(self) -> ElementList[Code]:
|
||
|
return ElementList(self.get("code")) # type: ignore
|
||
|
|
||
|
@property
|
||
|
def color_picker(self) -> WidgetList[ColorPicker]:
|
||
|
return WidgetList(self.get("color_picker")) # type: ignore
|
||
|
|
||
|
@property
|
||
|
def columns(self) -> Sequence[Column]:
|
||
|
return self.get("column") # type: ignore
|
||
|
|
||
|
@property
|
||
|
def dataframe(self) -> ElementList[Dataframe]:
|
||
|
return ElementList(self.get("arrow_data_frame")) # type: ignore
|
||
|
|
||
|
@property
|
||
|
def date_input(self) -> WidgetList[DateInput]:
|
||
|
return WidgetList(self.get("date_input")) # type: ignore
|
||
|
|
||
|
@property
|
||
|
def divider(self) -> ElementList[Divider]:
|
||
|
return ElementList(self.get("divider")) # type: ignore
|
||
|
|
||
|
@property
|
||
|
def error(self) -> ElementList[Error]:
|
||
|
return ElementList(self.get("error")) # type: ignore
|
||
|
|
||
|
@property
|
||
|
def exception(self) -> ElementList[Exception]:
|
||
|
return ElementList(self.get("exception")) # type: ignore
|
||
|
|
||
|
@property
|
||
|
def expander(self) -> Sequence[Expander]:
|
||
|
return self.get("expander") # type: ignore
|
||
|
|
||
|
@property
|
||
|
def header(self) -> ElementList[Header]:
|
||
|
return ElementList(self.get("header")) # type: ignore
|
||
|
|
||
|
@property
|
||
|
def info(self) -> ElementList[Info]:
|
||
|
return ElementList(self.get("info")) # type: ignore
|
||
|
|
||
|
@property
|
||
|
def json(self) -> ElementList[Json]:
|
||
|
return ElementList(self.get("json")) # type: ignore
|
||
|
|
||
|
@property
|
||
|
def latex(self) -> ElementList[Latex]:
|
||
|
return ElementList(self.get("latex")) # type: ignore
|
||
|
|
||
|
@property
|
||
|
def markdown(self) -> ElementList[Markdown]:
|
||
|
return ElementList(self.get("markdown")) # type: ignore
|
||
|
|
||
|
@property
|
||
|
def metric(self) -> ElementList[Metric]:
|
||
|
return ElementList(self.get("metric")) # type: ignore
|
||
|
|
||
|
@property
|
||
|
def multiselect(self) -> WidgetList[Multiselect[Any]]:
|
||
|
return WidgetList(self.get("multiselect")) # type: ignore
|
||
|
|
||
|
@property
|
||
|
def number_input(self) -> WidgetList[NumberInput]:
|
||
|
return WidgetList(self.get("number_input")) # type: ignore
|
||
|
|
||
|
@property
|
||
|
def radio(self) -> WidgetList[Radio[Any]]:
|
||
|
return WidgetList(self.get("radio")) # type: ignore
|
||
|
|
||
|
@property
|
||
|
def select_slider(self) -> WidgetList[SelectSlider[Any]]:
|
||
|
return WidgetList(self.get("select_slider")) # type: ignore
|
||
|
|
||
|
@property
|
||
|
def selectbox(self) -> WidgetList[Selectbox[Any]]:
|
||
|
return WidgetList(self.get("selectbox")) # type: ignore
|
||
|
|
||
|
@property
|
||
|
def slider(self) -> WidgetList[Slider[Any]]:
|
||
|
return WidgetList(self.get("slider")) # type: ignore
|
||
|
|
||
|
@property
|
||
|
def status(self) -> Sequence[Status]:
|
||
|
return self.get("status") # type: ignore
|
||
|
|
||
|
@property
|
||
|
def subheader(self) -> ElementList[Subheader]:
|
||
|
return ElementList(self.get("subheader")) # type: ignore
|
||
|
|
||
|
@property
|
||
|
def success(self) -> ElementList[Success]:
|
||
|
return ElementList(self.get("success")) # type: ignore
|
||
|
|
||
|
@property
|
||
|
def table(self) -> ElementList[Table]:
|
||
|
return ElementList(self.get("arrow_table")) # type: ignore
|
||
|
|
||
|
@property
|
||
|
def tabs(self) -> Sequence[Tab]:
|
||
|
return self.get("tab") # type: ignore
|
||
|
|
||
|
@property
|
||
|
def text(self) -> ElementList[Text]:
|
||
|
return ElementList(self.get("text")) # type: ignore
|
||
|
|
||
|
@property
|
||
|
def text_area(self) -> WidgetList[TextArea]:
|
||
|
return WidgetList(self.get("text_area")) # type: ignore
|
||
|
|
||
|
@property
|
||
|
def text_input(self) -> WidgetList[TextInput]:
|
||
|
return WidgetList(self.get("text_input")) # type: ignore
|
||
|
|
||
|
@property
|
||
|
def time_input(self) -> WidgetList[TimeInput]:
|
||
|
return WidgetList(self.get("time_input")) # type: ignore
|
||
|
|
||
|
@property
|
||
|
def title(self) -> ElementList[Title]:
|
||
|
return ElementList(self.get("title")) # type: ignore
|
||
|
|
||
|
@property
|
||
|
def toast(self) -> ElementList[Toast]:
|
||
|
return ElementList(self.get("toast")) # type: ignore
|
||
|
|
||
|
@property
|
||
|
def toggle(self) -> WidgetList[Toggle]:
|
||
|
return WidgetList(self.get("toggle")) # type: ignore
|
||
|
|
||
|
@property
|
||
|
def warning(self) -> ElementList[Warning]:
|
||
|
return ElementList(self.get("warning")) # type: ignore
|
||
|
|
||
|
def get(self, element_type: str) -> Sequence[Node]:
|
||
|
return [e for e in self if e.type == element_type]
|
||
|
|
||
|
def run(self, *, timeout: float | None = None) -> AppTest:
|
||
|
"""Run the script with updated widget values.
|
||
|
|
||
|
Parameters
|
||
|
----------
|
||
|
timeout
|
||
|
The maximum number of seconds to run the script. None means
|
||
|
use the AppTest's default.
|
||
|
"""
|
||
|
return self.root.run(timeout=timeout)
|
||
|
|
||
|
def __repr__(self) -> str:
|
||
|
return repr_(self)
|
||
|
|
||
|
|
||
|
def repr_(self: object) -> str:
|
||
|
"""A custom repr similar to `streamlit.util.repr_` but that shows tree
|
||
|
structure using indentation.
|
||
|
"""
|
||
|
classname = self.__class__.__name__
|
||
|
|
||
|
defaults: list[Any] = [None, "", False, [], set(), {}]
|
||
|
|
||
|
if is_dataclass(self):
|
||
|
fields_vals = (
|
||
|
(f.name, getattr(self, f.name))
|
||
|
for f in fields(self)
|
||
|
if f.repr
|
||
|
and getattr(self, f.name) != f.default
|
||
|
and getattr(self, f.name) not in defaults
|
||
|
)
|
||
|
else:
|
||
|
fields_vals = ((f, v) for (f, v) in self.__dict__.items() if v not in defaults)
|
||
|
|
||
|
reprs = []
|
||
|
for field_name, value in fields_vals:
|
||
|
line = (
|
||
|
f"{field_name}={format_dict(value)}"
|
||
|
if isinstance(value, dict)
|
||
|
else f"{field_name}={value!r}"
|
||
|
)
|
||
|
reprs.append(line)
|
||
|
|
||
|
reprs[0] = "\n" + reprs[0]
|
||
|
field_reprs = ",\n".join(reprs)
|
||
|
|
||
|
field_reprs = textwrap.indent(field_reprs, " " * 4)
|
||
|
return f"{classname}({field_reprs}\n)"
|
||
|
|
||
|
|
||
|
def format_dict(d: dict[Any, Any]) -> str:
|
||
|
lines = []
|
||
|
for k, v in d.items():
|
||
|
line = f"{k}: {v!r}"
|
||
|
lines.append(line)
|
||
|
r = ",\n".join(lines)
|
||
|
r = textwrap.indent(r, " " * 4)
|
||
|
return f"{{\n{r}\n}}"
|
||
|
|
||
|
|
||
|
@dataclass(repr=False)
|
||
|
class SpecialBlock(Block):
|
||
|
"""Base class for the sidebar and main body containers."""
|
||
|
|
||
|
def __init__(
|
||
|
self,
|
||
|
proto: BlockProto | None,
|
||
|
root: ElementTree,
|
||
|
type: str | None = None,
|
||
|
) -> None:
|
||
|
self.children = {}
|
||
|
self.proto = proto
|
||
|
if type:
|
||
|
self.type = type
|
||
|
elif proto and proto.WhichOneof("type"):
|
||
|
ty = proto.WhichOneof("type")
|
||
|
assert ty is not None
|
||
|
self.type = ty
|
||
|
else:
|
||
|
self.type = "unknown"
|
||
|
self.root = root
|
||
|
|
||
|
|
||
|
@dataclass(repr=False)
|
||
|
class ChatMessage(Block):
|
||
|
"""A representation of ``st.chat_message``."""
|
||
|
|
||
|
type: str = field(repr=False)
|
||
|
proto: BlockProto.ChatMessage = field(repr=False)
|
||
|
name: str
|
||
|
avatar: str
|
||
|
|
||
|
def __init__(
|
||
|
self,
|
||
|
proto: BlockProto.ChatMessage,
|
||
|
root: ElementTree,
|
||
|
) -> None:
|
||
|
self.children = {}
|
||
|
self.proto = proto
|
||
|
self.root = root
|
||
|
self.type = "chat_message"
|
||
|
self.name = proto.name
|
||
|
self.avatar = proto.avatar
|
||
|
|
||
|
|
||
|
@dataclass(repr=False)
|
||
|
class Column(Block):
|
||
|
"""A representation of a column within ``st.columns``."""
|
||
|
|
||
|
type: str = field(repr=False)
|
||
|
proto: BlockProto.Column = field(repr=False)
|
||
|
weight: float
|
||
|
gap: str
|
||
|
|
||
|
def __init__(
|
||
|
self,
|
||
|
proto: BlockProto.Column,
|
||
|
root: ElementTree,
|
||
|
) -> None:
|
||
|
self.children = {}
|
||
|
self.proto = proto
|
||
|
self.root = root
|
||
|
self.type = "column"
|
||
|
self.weight = proto.weight
|
||
|
self.gap = proto.gap
|
||
|
|
||
|
|
||
|
@dataclass(repr=False)
|
||
|
class Expander(Block):
|
||
|
type: str = field(repr=False)
|
||
|
proto: BlockProto.Expandable = field(repr=False)
|
||
|
icon: str
|
||
|
label: str
|
||
|
|
||
|
def __init__(self, proto: BlockProto.Expandable, root: ElementTree) -> None:
|
||
|
self.children = {}
|
||
|
self.proto = proto
|
||
|
self.root = root
|
||
|
# The internal name is "expandable" but the public API uses "expander"
|
||
|
# so the naming of the class and type follows the public name.
|
||
|
self.type = "expander"
|
||
|
self.icon = proto.icon
|
||
|
self.label = proto.label
|
||
|
|
||
|
|
||
|
@dataclass(repr=False)
|
||
|
class Status(Block):
|
||
|
type: str = field(repr=False)
|
||
|
proto: BlockProto.Expandable = field(repr=False)
|
||
|
icon: str
|
||
|
label: str
|
||
|
|
||
|
def __init__(self, proto: BlockProto.Expandable, root: ElementTree) -> None:
|
||
|
self.children = {}
|
||
|
self.proto = proto
|
||
|
self.root = root
|
||
|
self.type = "status"
|
||
|
self.icon = proto.icon
|
||
|
self.label = proto.label
|
||
|
|
||
|
@property
|
||
|
def state(self) -> str:
|
||
|
if self.icon == "spinner":
|
||
|
return "running"
|
||
|
if self.icon == ":material/check:":
|
||
|
return "complete"
|
||
|
if self.icon == ":material/error:":
|
||
|
return "error"
|
||
|
raise ValueError("Unknown Status state")
|
||
|
|
||
|
|
||
|
@dataclass(repr=False)
|
||
|
class Tab(Block):
|
||
|
"""A representation of tab within ``st.tabs``."""
|
||
|
|
||
|
type: str = field(repr=False)
|
||
|
proto: BlockProto.Tab = field(repr=False)
|
||
|
label: str
|
||
|
|
||
|
def __init__(
|
||
|
self,
|
||
|
proto: BlockProto.Tab,
|
||
|
root: ElementTree,
|
||
|
) -> None:
|
||
|
self.children = {}
|
||
|
self.proto = proto
|
||
|
self.root = root
|
||
|
self.type = "tab"
|
||
|
self.label = proto.label
|
||
|
|
||
|
|
||
|
Node: TypeAlias = Union[Element, Block]
|
||
|
|
||
|
|
||
|
def get_widget_state(node: Node) -> WidgetState | None:
|
||
|
if isinstance(node, Widget):
|
||
|
return node._widget_state
|
||
|
return None
|
||
|
|
||
|
|
||
|
@dataclass(repr=False)
|
||
|
class ElementTree(Block):
|
||
|
"""A tree of the elements produced by running a streamlit script.
|
||
|
|
||
|
Elements can be queried in three ways:
|
||
|
- By element type, using `.foo` properties to get a list of all of that element,
|
||
|
in the order they appear in the app
|
||
|
- By user key, for widgets, by calling the above list with a key: `.foo(key='bar')`
|
||
|
- Positionally, using list indexing syntax (`[...]`) to access a child of a
|
||
|
block element. Not recommended because the exact tree structure can be surprising.
|
||
|
|
||
|
Element queries made on a block container will return only the elements
|
||
|
descending from that block.
|
||
|
|
||
|
Returned elements have methods for accessing whatever attributes are relevant.
|
||
|
For very simple elements this may be only its value, while complex elements
|
||
|
like widgets have many.
|
||
|
|
||
|
Widgets provide a fluent API for faking frontend interaction and rerunning
|
||
|
the script with the new widget values. All widgets provide a low level `set_value`
|
||
|
method, along with higher level methods specific to that type of widget.
|
||
|
After an interaction, calling `.run()` will update the AppTest with the
|
||
|
results of that script run.
|
||
|
"""
|
||
|
|
||
|
_runner: AppTest | None = field(repr=False, default=None)
|
||
|
|
||
|
def __init__(self) -> None:
|
||
|
self.children = {}
|
||
|
self.root = self
|
||
|
self.type = "root"
|
||
|
|
||
|
@property
|
||
|
def main(self) -> Block:
|
||
|
m = self[0]
|
||
|
assert isinstance(m, Block)
|
||
|
return m
|
||
|
|
||
|
@property
|
||
|
def sidebar(self) -> Block:
|
||
|
s = self[1]
|
||
|
assert isinstance(s, Block)
|
||
|
return s
|
||
|
|
||
|
@property
|
||
|
def session_state(self) -> SafeSessionState:
|
||
|
assert self._runner is not None
|
||
|
return self._runner.session_state
|
||
|
|
||
|
def get_widget_states(self) -> WidgetStates:
|
||
|
ws = WidgetStates()
|
||
|
for node in self:
|
||
|
w = get_widget_state(node)
|
||
|
if w is not None:
|
||
|
ws.widgets.append(w)
|
||
|
|
||
|
return ws
|
||
|
|
||
|
def run(self, *, timeout: float | None = None) -> AppTest:
|
||
|
"""Run the script with updated widget values.
|
||
|
|
||
|
Parameters
|
||
|
----------
|
||
|
timeout
|
||
|
The maximum number of seconds to run the script. None means
|
||
|
use the AppTest's default.
|
||
|
"""
|
||
|
assert self._runner is not None
|
||
|
|
||
|
widget_states = self.get_widget_states()
|
||
|
return self._runner._run(widget_states, timeout=timeout)
|
||
|
|
||
|
def __repr__(self) -> str:
|
||
|
return format_dict(self.children)
|
||
|
|
||
|
|
||
|
def parse_tree_from_messages(messages: list[ForwardMsg]) -> ElementTree:
|
||
|
"""Transform a list of `ForwardMsg` into a tree matching the implicit
|
||
|
tree structure of blocks and elements in a streamlit app.
|
||
|
|
||
|
Returns the root of the tree, which acts as the entrypoint for the query
|
||
|
and interaction API.
|
||
|
"""
|
||
|
root = ElementTree()
|
||
|
root.children = {
|
||
|
0: SpecialBlock(type="main", root=root, proto=None),
|
||
|
1: SpecialBlock(type="sidebar", root=root, proto=None),
|
||
|
2: SpecialBlock(type="event", root=root, proto=None),
|
||
|
}
|
||
|
|
||
|
for msg in messages:
|
||
|
if not msg.HasField("delta"):
|
||
|
continue
|
||
|
delta_path = msg.metadata.delta_path
|
||
|
delta = msg.delta
|
||
|
if delta.WhichOneof("type") == "new_element":
|
||
|
elt = delta.new_element
|
||
|
ty = elt.WhichOneof("type")
|
||
|
new_node: Node
|
||
|
if ty == "alert":
|
||
|
alert_format = elt.alert.format
|
||
|
if alert_format == AlertProto.Format.ERROR:
|
||
|
new_node = Error(elt.alert, root=root)
|
||
|
elif alert_format == AlertProto.Format.INFO:
|
||
|
new_node = Info(elt.alert, root=root)
|
||
|
elif alert_format == AlertProto.Format.SUCCESS:
|
||
|
new_node = Success(elt.alert, root=root)
|
||
|
elif alert_format == AlertProto.Format.WARNING:
|
||
|
new_node = Warning(elt.alert, root=root)
|
||
|
else:
|
||
|
raise ValueError(
|
||
|
f"Unknown alert type with format {elt.alert.format}"
|
||
|
)
|
||
|
elif ty == "arrow_data_frame":
|
||
|
new_node = Dataframe(elt.arrow_data_frame, root=root)
|
||
|
elif ty == "arrow_table":
|
||
|
new_node = Table(elt.arrow_table, root=root)
|
||
|
elif ty == "button":
|
||
|
new_node = Button(elt.button, root=root)
|
||
|
elif ty == "button_group":
|
||
|
new_node = ButtonGroup(elt.button_group, root=root)
|
||
|
elif ty == "chat_input":
|
||
|
new_node = ChatInput(elt.chat_input, root=root)
|
||
|
elif ty == "checkbox":
|
||
|
style = elt.checkbox.type
|
||
|
if style == CheckboxProto.StyleType.TOGGLE:
|
||
|
new_node = Toggle(elt.checkbox, root=root)
|
||
|
else:
|
||
|
new_node = Checkbox(elt.checkbox, root=root)
|
||
|
elif ty == "code":
|
||
|
new_node = Code(elt.code, root=root)
|
||
|
elif ty == "color_picker":
|
||
|
new_node = ColorPicker(elt.color_picker, root=root)
|
||
|
elif ty == "date_input":
|
||
|
new_node = DateInput(elt.date_input, root=root)
|
||
|
elif ty == "exception":
|
||
|
new_node = Exception(elt.exception, root=root)
|
||
|
elif ty == "heading":
|
||
|
if elt.heading.tag == HeadingProtoTag.TITLE_TAG.value:
|
||
|
new_node = Title(elt.heading, root=root)
|
||
|
elif elt.heading.tag == HeadingProtoTag.HEADER_TAG.value:
|
||
|
new_node = Header(elt.heading, root=root)
|
||
|
elif elt.heading.tag == HeadingProtoTag.SUBHEADER_TAG.value:
|
||
|
new_node = Subheader(elt.heading, root=root)
|
||
|
else:
|
||
|
raise ValueError(f"Unknown heading type with tag {elt.heading.tag}")
|
||
|
elif ty == "json":
|
||
|
new_node = Json(elt.json, root=root)
|
||
|
elif ty == "markdown":
|
||
|
if elt.markdown.element_type == MarkdownProto.Type.NATIVE:
|
||
|
new_node = Markdown(elt.markdown, root=root)
|
||
|
elif elt.markdown.element_type == MarkdownProto.Type.CAPTION:
|
||
|
new_node = Caption(elt.markdown, root=root)
|
||
|
elif elt.markdown.element_type == MarkdownProto.Type.LATEX:
|
||
|
new_node = Latex(elt.markdown, root=root)
|
||
|
elif elt.markdown.element_type == MarkdownProto.Type.DIVIDER:
|
||
|
new_node = Divider(elt.markdown, root=root)
|
||
|
else:
|
||
|
raise ValueError(
|
||
|
f"Unknown markdown type {elt.markdown.element_type}"
|
||
|
)
|
||
|
elif ty == "metric":
|
||
|
new_node = Metric(elt.metric, root=root)
|
||
|
elif ty == "multiselect":
|
||
|
new_node = Multiselect(elt.multiselect, root=root)
|
||
|
elif ty == "number_input":
|
||
|
new_node = NumberInput(elt.number_input, root=root)
|
||
|
elif ty == "radio":
|
||
|
new_node = Radio(elt.radio, root=root)
|
||
|
elif ty == "selectbox":
|
||
|
new_node = Selectbox(elt.selectbox, root=root)
|
||
|
elif ty == "slider":
|
||
|
if elt.slider.type == SliderProto.Type.SLIDER:
|
||
|
new_node = Slider(elt.slider, root=root)
|
||
|
elif elt.slider.type == SliderProto.Type.SELECT_SLIDER:
|
||
|
new_node = SelectSlider(elt.slider, root=root)
|
||
|
else:
|
||
|
raise ValueError(f"Slider with unknown type {elt.slider}")
|
||
|
elif ty == "text":
|
||
|
new_node = Text(elt.text, root=root)
|
||
|
elif ty == "text_area":
|
||
|
new_node = TextArea(elt.text_area, root=root)
|
||
|
elif ty == "text_input":
|
||
|
new_node = TextInput(elt.text_input, root=root)
|
||
|
elif ty == "time_input":
|
||
|
new_node = TimeInput(elt.time_input, root=root)
|
||
|
elif ty == "toast":
|
||
|
new_node = Toast(elt.toast, root=root)
|
||
|
else:
|
||
|
new_node = UnknownElement(elt, root=root)
|
||
|
elif delta.WhichOneof("type") == "add_block":
|
||
|
block = delta.add_block
|
||
|
bty = block.WhichOneof("type")
|
||
|
if bty == "chat_message":
|
||
|
new_node = ChatMessage(block.chat_message, root=root)
|
||
|
elif bty == "column":
|
||
|
new_node = Column(block.column, root=root)
|
||
|
elif bty == "expandable":
|
||
|
if block.expandable.icon:
|
||
|
new_node = Status(block.expandable, root=root)
|
||
|
else:
|
||
|
new_node = Expander(block.expandable, root=root)
|
||
|
elif bty == "tab":
|
||
|
new_node = Tab(block.tab, root=root)
|
||
|
else:
|
||
|
new_node = Block(proto=block, root=root)
|
||
|
else:
|
||
|
# add_rows
|
||
|
continue
|
||
|
|
||
|
current_node: Block = root
|
||
|
# Every node up to the end is a Block
|
||
|
for idx in delta_path[:-1]:
|
||
|
children = current_node.children
|
||
|
child = children.get(idx)
|
||
|
if child is None:
|
||
|
child = Block(proto=None, root=root)
|
||
|
children[idx] = child
|
||
|
assert isinstance(child, Block)
|
||
|
current_node = child
|
||
|
|
||
|
# Handle a block when we already have a placeholder for that location
|
||
|
if isinstance(new_node, Block):
|
||
|
placeholder_block = current_node.children.get(delta_path[-1])
|
||
|
if placeholder_block is not None:
|
||
|
new_node.children = placeholder_block.children
|
||
|
|
||
|
current_node.children[delta_path[-1]] = new_node
|
||
|
|
||
|
return root
|