# Copyright (c) Streamlit Inc. (2018-2022) Snowflake Inc. (2022-2025)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# Assert statements are allowed here since the app testing logic is used within unit tests:
# ruff: noqa: S101

# Postponed evaluation of annotations: every annotation in this module is kept
# as a string, which is what allows the TYPE_CHECKING-only imports below and
# forward references such as ``ElementTree`` and ``Node``.
from __future__ import annotations

import textwrap
from abc import ABC, abstractmethod
from collections.abc import Iterator, Sequence
from dataclasses import dataclass, field, fields, is_dataclass
from datetime import date, datetime, time, timedelta
from typing import (
    TYPE_CHECKING,
    Any,
    Callable,
    Generic,
    TypeVar,
    Union,
    cast,
    overload,
)

from typing_extensions import Self, TypeAlias

from streamlit import dataframe_util, util
from streamlit.elements.heading import HeadingProtoTag
from streamlit.elements.widgets.select_slider import SelectSliderSerde
from streamlit.elements.widgets.slider import (
    SliderSerde,
    SliderStep,
    SliderValueT,
)
from streamlit.elements.widgets.time_widgets import (
    DateInputSerde,
    DateWidgetReturn,
    TimeInputSerde,
    _parse_date_value,
)
from streamlit.proto.Alert_pb2 import Alert as AlertProto
from streamlit.proto.Checkbox_pb2 import Checkbox as CheckboxProto
from streamlit.proto.Markdown_pb2 import Markdown as MarkdownProto
from streamlit.proto.Slider_pb2 import Slider as SliderProto
from streamlit.proto.WidgetStates_pb2 import WidgetState, WidgetStates
from streamlit.runtime.state.common import TESTING_KEY, user_key_from_element_id

# The imports below are needed only for type annotations. Because annotations
# are postponed (see the __future__ import above), they are never executed at
# runtime, which keeps pandas and these proto modules off the import path.
if TYPE_CHECKING:
    from pandas import DataFrame as PandasDataframe

    from streamlit.proto.Arrow_pb2 import Arrow as ArrowProto
    from streamlit.proto.Block_pb2 import Block as BlockProto
    from streamlit.proto.Button_pb2 import Button as ButtonProto
    from streamlit.proto.ButtonGroup_pb2 import ButtonGroup as ButtonGroupProto
    from streamlit.proto.ChatInput_pb2 import ChatInput as ChatInputProto
    from streamlit.proto.Code_pb2 import Code as CodeProto
    from streamlit.proto.ColorPicker_pb2 import ColorPicker as ColorPickerProto
    from streamlit.proto.DateInput_pb2 import DateInput as DateInputProto
    from streamlit.proto.Element_pb2 import Element as ElementProto
    from streamlit.proto.Exception_pb2 import Exception as ExceptionProto
    from streamlit.proto.ForwardMsg_pb2 import ForwardMsg
    from streamlit.proto.Heading_pb2 import Heading as HeadingProto
    from streamlit.proto.Json_pb2 import Json as JsonProto
    from streamlit.proto.Metric_pb2 import Metric as MetricProto
    from streamlit.proto.MultiSelect_pb2 import MultiSelect as MultiSelectProto
    from streamlit.proto.NumberInput_pb2 import NumberInput as NumberInputProto
    from streamlit.proto.Radio_pb2 import Radio as RadioProto
    from streamlit.proto.Selectbox_pb2 import Selectbox as SelectboxProto
    from streamlit.proto.Text_pb2 import Text as TextProto
    from streamlit.proto.TextArea_pb2 import TextArea as TextAreaProto
    from streamlit.proto.TextInput_pb2 import TextInput as TextInputProto
    from streamlit.proto.TimeInput_pb2 import TimeInput as TimeInputProto
    from streamlit.proto.Toast_pb2 import Toast as ToastProto
    from streamlit.runtime.state.safe_session_state import SafeSessionState
    from streamlit.testing.v1.app_test import AppTest

# Type parameter for the value type of the option-based widget classes in this
# module (e.g. ``Radio[T]``, ``Selectbox[T]``, ``Multiselect[T]``).
T = TypeVar("T")


@dataclass
class InitialValue:
    """Sentinel marking a widget value that the test has not set.

    Widgets whose value may legitimately be ``None`` (number, text, date and
    time inputs, radio, selectbox) store an instance of this class in
    ``_value`` until ``set_value`` is called. That lets an explicit ``None``
    be distinguished from "not set yet, read the value from session state".
    """

    pass


# TODO: This class serves as a fallback option for elements that have not
# been implemented yet, as well as providing implementations of some
# trivial methods. It may have significantly reduced scope once all elements
# have been implemented.
# This class is not a sufficient implementation for most elements.
# Widgets need their own classes to translate interactions into the appropriate
# WidgetState and provide higher level interaction interfaces, and other elements
# have enough variation in how to get their values that most will need their
# own classes too.
@dataclass
class Element(ABC):
    """
    Element base class for testing.

    This class's methods and attributes are universal for all elements
    implemented in testing. For example, ``Caption``, ``Code``, ``Text``, and
    ``Title`` inherit from ``Element``. All widget classes also
    inherit from Element, but have additional methods specific to each
    widget type. See the ``AppTest`` class for the full list of supported
    elements.

    For all element classes, parameters of the original element can be obtained
    as properties. For example, ``Button.label``, ``Caption.help``, and
    ``Toast.icon``.

    """

    type: str = field(repr=False)
    proto: Any = field(repr=False)
    root: ElementTree = field(repr=False)
    key: str | None

    @abstractmethod
    def __init__(self, proto: ElementProto, root: ElementTree) -> None: ...

    def __iter__(self) -> Iterator[Self]:
        yield self

    @property
    @abstractmethod
    def value(self) -> Any:
        """The value or contents of the element."""
        ...

    def __getattr__(self, name: str) -> Any:
        """Fallback attempt to get an attribute from the proto.

        Python calls this only when normal attribute lookup fails. It is what
        makes declared-but-unassigned fields (``label``, ``help``, ``id``, ...)
        resolve to the proto field of the same name.

        Two kinds of names are never forwarded:

        * ``proto`` itself. An instance created without ``__init__`` (as
          ``copy.copy`` does) has no ``proto`` yet. Forwarding would call this
          method again for ``proto`` and recurse until ``RecursionError``.
        * Dunder names. Protocol probes such as ``copy.deepcopy``'s
          ``__deepcopy__`` lookup would otherwise find the protobuf message's
          method and, for example, return a copied proto instead of a copied
          element.

        Raises
        ------
        AttributeError
            If ``name`` is excluded above or the proto has no such attribute.
        """
        if name == "proto" or (name.startswith("__") and name.endswith("__")):
            raise AttributeError(name)
        return getattr(self.proto, name)

    def run(self, *, timeout: float | None = None) -> AppTest:
        """Run the ``AppTest`` script which contains the element.

        Parameters
        ----------
        timeout
            The maximum number of seconds to run the script. None means
            use the AppTest's default.
        """
        return self.root.run(timeout=timeout)

    def __repr__(self) -> str:
        return util.repr_(self)


@dataclass(repr=False)
class UnknownElement(Element):
    """Fallback wrapper for element types without a dedicated class.

    ``proto`` is the sub-message selected by the element proto's ``type``
    oneof, and ``type`` is the name of that oneof field.
    """

    def __init__(self, proto: ElementProto, root: ElementTree) -> None:
        ty = proto.WhichOneof("type")
        assert ty is not None
        self.proto = getattr(proto, ty)
        self.root = root
        self.type = ty
        self.key = None

    @property
    def value(self) -> Any:
        try:
            state = self.root.session_state
            assert state is not None
            return state[self.proto.id]
        except (AttributeError, ValueError):
            # No id field, not a widget. Reading a missing protobuf field raises
            # AttributeError, so it is caught explicitly here instead of
            # escaping the property and relying on the __getattr__ fallback.
            return self.proto.value


@dataclass(repr=False)
class Widget(Element, ABC):
    """Widget base class for testing."""

    id: str = field(repr=False)
    disabled: bool
    key: str | None
    _value: Any

    def __init__(self, proto: Any, root: ElementTree) -> None:
        self.proto = proto
        self.root = root
        # `self.id` is not assigned on the instance; it resolves to `proto.id`
        # through Element.__getattr__.
        self.key = user_key_from_element_id(self.id)
        self._value = None

    def set_value(self, v: Any) -> Self:
        """Set the value of the widget."""
        self._value = v
        return self

    @property
    @abstractmethod
    def _widget_state(self) -> WidgetState: ...


El_co = TypeVar("El_co", bound=Element, covariant=True)


class ElementList(Generic[El_co]):
    """A read-only sequence of elements with list-style comparison."""

    def __init__(self, els: Sequence[El_co]) -> None:
        self._list: Sequence[El_co] = els

    def __len__(self) -> int:
        return len(self._list)

    @property
    def len(self) -> int:
        return len(self)

    @overload
    def __getitem__(self, idx: int) -> El_co: ...

    @overload
    def __getitem__(self, idx: slice) -> ElementList[El_co]: ...

    def __getitem__(self, idx: int | slice) -> El_co | ElementList[El_co]:
        # Slicing returns another ElementList so `.values` etc. stay available.
        if isinstance(idx, slice):
            return ElementList(self._list[idx])
        return self._list[idx]

    def __iter__(self) -> Iterator[El_co]:
        yield from self._list

    def __repr__(self) -> str:
        return util.repr_(self)

    def __eq__(self, other: ElementList[El_co] | object) -> bool:
        # Compares equal to another ElementList with equal contents, or to a
        # plain sequence equal to the wrapped one.
        if isinstance(other, ElementList):
            return self._list == other._list
        return self._list == other

    def __hash__(self) -> int:
        # NOTE(review): the element classes are non-frozen dataclasses with the
        # default eq=True, which makes their instances unhashable. Hashing a
        # non-empty ElementList of them therefore raises TypeError.
        return hash(tuple(self._list))

    @property
    def values(self) -> Sequence[Any]:
        """The ``value`` of every element, in order."""
        return [e.value for e in self]


W_co = TypeVar("W_co", bound=Widget, covariant=True)


class WidgetList(ElementList[W_co], Generic[W_co]):
    """An ElementList of widgets that can also be indexed by user key."""

    def __call__(self, key: str) -> W_co:
        """Return the first widget whose ``key`` equals ``key``.

        Raises
        ------
        KeyError
            If no widget in the list has that key.
        """
        for e in self._list:
            if e.key == key:
                return e

        raise KeyError(key)


@dataclass(repr=False)
class AlertBase(Element):
    """Shared implementation of the alert elements; ``value`` is the alert body."""

    proto: AlertProto = field(repr=False)
    icon: str

    def __init__(self, proto: AlertProto, root: ElementTree) -> None:
        self.proto = proto
        self.key = None
        self.root = root

    @property
    def value(self) -> str:
        return self.proto.body


@dataclass(repr=False)
class Error(AlertBase):
    def __init__(self, proto: AlertProto, root: ElementTree) -> None:
        super().__init__(proto, root)
        self.type = "error"


@dataclass(repr=False)
class Warning(AlertBase):  # noqa: A001
    def __init__(self, proto: AlertProto, root: ElementTree) -> None:
        super().__init__(proto, root)
        self.type = "warning"


@dataclass(repr=False)
class Info(AlertBase):
    def __init__(self, proto: AlertProto, root: ElementTree) -> None:
        super().__init__(proto, root)
        self.type = "info"


@dataclass(repr=False)
class Success(AlertBase):
    def __init__(self, proto: AlertProto, root: ElementTree) -> None:
        super().__init__(proto, root)
        self.type = "success"


@dataclass(repr=False)
class Button(Widget):
    """A representation of ``st.button`` and ``st.form_submit_button``."""

    _value: bool

    proto: ButtonProto = field(repr=False)
    label: str
    help: str
    form_id: str

    def __init__(self, proto: ButtonProto, root: ElementTree) -> None:
        super().__init__(proto, root)
        self._value = False
        self.type = "button"

    @property
    def _widget_state(self) -> WidgetState:
        ws = WidgetState()
        ws.id = self.id
        ws.trigger_value = self._value
        return ws

    @property
    def value(self) -> bool:
        """The value of the button. (bool)"""  # noqa: D400
        # A pending click wins; otherwise report what the last run recorded.
        if self._value:
            return self._value
        state = self.root.session_state
        assert state
        return cast("bool", state[TESTING_KEY][self.id])

    def set_value(self, v: bool) -> Button:
        """Set the value of the button."""
        self._value = v
        return self

    def click(self) -> Button:
        """Set the value of the button to True."""
        return self.set_value(True)


@dataclass(repr=False)
class ChatInput(Widget):
    """A representation of ``st.chat_input``."""

    _value: str | None

    proto: ChatInputProto = field(repr=False)
    placeholder: str

    def __init__(self, proto: ChatInputProto, root: ElementTree) -> None:
        super().__init__(proto, root)
        self.type = "chat_input"

    def set_value(self, v: str | None) -> ChatInput:
        """Set the value of the widget."""
        self._value = v
        return self

    @property
    def _widget_state(self) -> WidgetState:
        ws = WidgetState()
        ws.id = self.id
        if self._value is not None:
            ws.string_trigger_value.data = self._value
        return ws

    @property
    def value(self) -> str | None:
        """The value of the widget. (str)"""  # noqa: D400
        # NOTE(review): an explicitly set empty string is falsy and falls
        # through to session state here, unlike in `_widget_state`.
        if self._value:
            return self._value
        state = self.root.session_state
        assert state
        return state[TESTING_KEY][self.id]  # type: ignore


@dataclass(repr=False)
class Checkbox(Widget):
    """A representation of ``st.checkbox``."""

    _value: bool | None

    proto: CheckboxProto = field(repr=False)
    label: str
    help: str
    form_id: str

    def __init__(self, proto: CheckboxProto, root: ElementTree) -> None:
        super().__init__(proto, root)
        self.type = "checkbox"

    @property
    def _widget_state(self) -> WidgetState:
        ws = WidgetState()
        ws.id = self.id
        ws.bool_value = self.value
        return ws

    @property
    def value(self) -> bool:
        """The value of the widget. (bool)"""  # noqa: D400
        if self._value is not None:
            return self._value
        state = self.root.session_state
        assert state
        return cast("bool", state[self.id])

    def set_value(self, v: bool) -> Checkbox:
        """Set the value of the widget."""
        self._value = v
        return self

    def check(self) -> Checkbox:
        """Set the value of the widget to True."""
        return self.set_value(True)

    def uncheck(self) -> Checkbox:
        """Set the value of the widget to False."""
        return self.set_value(False)


@dataclass(repr=False)
class Code(Element):
    """A representation of ``st.code``."""

    proto: CodeProto = field(repr=False)

    language: str
    show_line_numbers: bool
    key: None

    def __init__(self, proto: CodeProto, root: ElementTree) -> None:
        self.proto = proto
        self.key = None
        self.root = root
        self.type = "code"

    @property
    def value(self) -> str:
        """The value of the element. (str)"""  # noqa: D400
        return self.proto.code_text


@dataclass(repr=False)
class ColorPicker(Widget):
    """A representation of ``st.color_picker``."""

    _value: str | None
    label: str
    help: str
    form_id: str

    proto: ColorPickerProto = field(repr=False)

    def __init__(self, proto: ColorPickerProto, root: ElementTree) -> None:
        super().__init__(proto, root)
        self.type = "color_picker"

    @property
    def value(self) -> str:
        """The currently selected value as a hex string. (str)"""  # noqa: D400
        if self._value is not None:
            return self._value
        state = self.root.session_state
        assert state
        return cast("str", state[self.id])

    @property
    def _widget_state(self) -> WidgetState:
        """Protobuf message representing the state of the widget, including
        any interactions that have happened.
        Should be the same as the frontend would produce for those interactions.
        """
        ws = WidgetState()
        ws.id = self.id
        ws.string_value = self.value
        return ws

    def set_value(self, v: str) -> ColorPicker:
        """Set the value of the widget as a hex string."""
        self._value = v
        return self

    def pick(self, v: str) -> ColorPicker:
        """Set the value of the widget as a hex string. May omit the "#" prefix."""
        if not v.startswith("#"):
            v = f"#{v}"
        return self.set_value(v)


@dataclass(repr=False)
class Dataframe(Element):
    """An Arrow-backed dataframe element; ``value`` is a pandas DataFrame."""

    proto: ArrowProto = field(repr=False)

    def __init__(self, proto: ArrowProto, root: ElementTree) -> None:
        self.key = None
        self.proto = proto
        self.root = root
        self.type = "arrow_data_frame"

    @property
    def value(self) -> PandasDataframe:
        return dataframe_util.convert_arrow_bytes_to_pandas_df(self.proto.data)


SingleDateValue: TypeAlias = Union[date, datetime]
DateValue: TypeAlias = Union[SingleDateValue, Sequence[SingleDateValue], None]


@dataclass(repr=False)
class DateInput(Widget):
    """A representation of ``st.date_input``."""

    _value: DateValue | None | InitialValue
    proto: DateInputProto = field(repr=False)
    label: str
    min: date
    max: date
    is_range: bool
    help: str
    form_id: str

    def __init__(self, proto: DateInputProto, root: ElementTree) -> None:
        super().__init__(proto, root)
        self._value = InitialValue()
        self.type = "date_input"
        # The proto carries its bounds as "YYYY/MM/DD" strings.
        self.min = datetime.strptime(proto.min, "%Y/%m/%d").date()
        self.max = datetime.strptime(proto.max, "%Y/%m/%d").date()

    def set_value(self, v: DateValue) -> DateInput:
        """Set the value of the widget."""
        self._value = v
        return self

    @property
    def _widget_state(self) -> WidgetState:
        ws = WidgetState()
        ws.id = self.id

        serde = DateInputSerde(None)  # type: ignore
        ws.string_array_value.data[:] = serde.serialize(self.value)
        return ws

    @property
    def value(self) -> DateWidgetReturn:
        """The value of the widget. (date or Tuple of date)"""  # noqa: D400
        if not isinstance(self._value, InitialValue):
            parsed, _ = _parse_date_value(self._value)
            return tuple(parsed) if parsed is not None else None  # type: ignore
        state = self.root.session_state
        assert state
        return state[self.id]  # type: ignore


@dataclass(repr=False)
class Exception(Element):  # noqa: A001
    """An exception element.

    ``message`` and ``is_warning`` are not assigned in ``__init__``; they are
    read from the proto through ``Element.__getattr__``.
    """

    message: str
    is_markdown: bool
    stack_trace: list[str]
    is_warning: bool

    def __init__(self, proto: ExceptionProto, root: ElementTree) -> None:
        self.key = None
        self.root = root
        self.proto = proto
        self.type = "exception"

        self.is_markdown = proto.message_is_markdown
        self.stack_trace = list(proto.stack_trace)

    @property
    def value(self) -> str:
        return self.message


@dataclass(repr=False)
class HeadingBase(Element, ABC):
    """Shared implementation of the heading elements; ``value`` is the heading body."""

    proto: HeadingProto = field(repr=False)

    tag: str
    anchor: str | None
    hide_anchor: bool
    key: None

    def __init__(self, proto: HeadingProto, root: ElementTree, type_: str) -> None:
        self.proto = proto
        self.key = None
        self.root = root
        self.type = type_

    @property
    def value(self) -> str:
        return self.proto.body


@dataclass(repr=False)
class Header(HeadingBase):
    def __init__(self, proto: HeadingProto, root: ElementTree) -> None:
        super().__init__(proto, root, "header")


@dataclass(repr=False)
class Subheader(HeadingBase):
    def __init__(self, proto: HeadingProto, root: ElementTree) -> None:
        super().__init__(proto, root, "subheader")


@dataclass(repr=False)
class Title(HeadingBase):
    def __init__(self, proto: HeadingProto, root: ElementTree) -> None:
        super().__init__(proto, root, "title")


@dataclass(repr=False)
class Json(Element):
    """A JSON element; ``value`` is the proto's raw ``body`` string."""

    proto: JsonProto = field(repr=False)

    expanded: bool

    def __init__(self, proto: JsonProto, root: ElementTree) -> None:
        self.proto = proto
        self.key = None
        self.root = root
        self.type = "json"

    @property
    def value(self) -> str:
        return self.proto.body


@dataclass(repr=False)
class Markdown(Element):
    """A markdown element.

    Captions, dividers and LaTeX reuse the Markdown proto through the
    subclasses below.
    """

    proto: MarkdownProto = field(repr=False)

    is_caption: bool
    allow_html: bool
    key: None

    def __init__(self, proto: MarkdownProto, root: ElementTree) -> None:
        self.proto = proto
        self.key = None
        self.root = root
        self.type = "markdown"

    @property
    def value(self) -> str:
        return self.proto.body


@dataclass(repr=False)
class Caption(Markdown):
    def __init__(self, proto: MarkdownProto, root: ElementTree) -> None:
        super().__init__(proto, root)
        self.type = "caption"


@dataclass(repr=False)
class Divider(Markdown):
    def __init__(self, proto: MarkdownProto, root: ElementTree) -> None:
        super().__init__(proto, root)
        self.type = "divider"


@dataclass(repr=False)
class Latex(Markdown):
    def __init__(self, proto: MarkdownProto, root: ElementTree) -> None:
        super().__init__(proto, root)
        self.type = "latex"


@dataclass(repr=False)
class Metric(Element):
    """A metric element; ``value`` is the proto's ``body``."""

    proto: MetricProto
    label: str
    delta: str
    color: str
    help: str

    def __init__(self, proto: MetricProto, root: ElementTree) -> None:
        self.proto = proto
        self.key = None
        self.root = root
        self.type = "metric"

    @property
    def value(self) -> str:
        return self.proto.body


@dataclass(repr=False)
class ButtonGroup(Widget, Generic[T]):
    """A representation of button_group that is used by ``st.feedback``."""

    _value: list[T] | None

    proto: ButtonGroupProto = field(repr=False)
    options: list[ButtonGroupProto.Option]
    form_id: str

    def __init__(self, proto: ButtonGroupProto, root: ElementTree) -> None:
        super().__init__(proto, root)
        self.type = "button_group"
        self.options = list(proto.options)

    @property
    def _widget_state(self) -> WidgetState:
        """Protobuf message representing the state of the widget, including
        any interactions that have happened.
        Should be the same as the frontend would produce for those interactions.
        """
        ws = WidgetState()
        ws.id = self.id
        ws.int_array_value.data[:] = self.indices
        return ws

    @property
    def value(self) -> list[T]:
        """The currently selected values from the options. (list)"""  # noqa: D400
        if self._value is not None:
            return self._value
        state = self.root.session_state
        assert state
        return cast("list[T]", state[self.id])

    @property
    def indices(self) -> Sequence[int]:
        """The indices of the currently selected values from the options. (list)"""  # noqa: D400
        return [self.options.index(self.format_func(v)) for v in self.value]

    @property
    def format_func(self) -> Callable[[Any], Any]:
        """The widget's formatting function for displaying options. (callable)"""  # noqa: D400
        ss = self.root.session_state
        return cast("Callable[[Any], Any]", ss[TESTING_KEY][self.id])

    def set_value(self, v: list[T]) -> ButtonGroup[T]:
        """Set the value of the button group widget. (list)"""  # noqa: D400
        self._value = v
        return self

    def select(self, v: T) -> ButtonGroup[T]:
        """
        Add a selection to the widget. Do nothing if the value is already selected.\
        If testing a button group widget with repeated options, use ``set_value``\
        instead.
        """
        current = self.value
        if v in current:
            return self
        new = current.copy()
        new.append(v)
        self.set_value(new)
        return self

    def unselect(self, v: T) -> ButtonGroup[T]:
        """
        Remove a selection from the widget. Do nothing if the value is not\
        already selected. If a value is selected multiple times, every\
        instance is removed.
        """
        current = self.value
        if v not in current:
            return self
        new = current.copy()
        while v in new:
            new.remove(v)
        self.set_value(new)
        return self


@dataclass(repr=False)
class Multiselect(Widget, Generic[T]):
    """A representation of ``st.multiselect``."""

    _value: list[T] | None

    proto: MultiSelectProto = field(repr=False)
    label: str
    options: list[str]
    max_selections: int
    help: str
    form_id: str

    def __init__(self, proto: MultiSelectProto, root: ElementTree) -> None:
        super().__init__(proto, root)
        self.type = "multiselect"
        self.options = list(proto.options)

    @property
    def _widget_state(self) -> WidgetState:
        """Protobuf message representing the state of the widget, including
        any interactions that have happened.
        Should be the same as the frontend would produce for those interactions.
        """
        ws = WidgetState()
        ws.id = self.id
        ws.string_array_value.data[:] = self.values
        return ws

    @property
    def value(self) -> list[T]:
        """The currently selected values from the options. (list)"""  # noqa: D400
        if self._value is not None:
            return self._value
        state = self.root.session_state
        assert state
        return cast("list[T]", state[self.id])

    @property
    def indices(self) -> Sequence[int]:
        """The indices of the currently selected values from the options. (list)"""  # noqa: D400
        return [self.options.index(self.format_func(v)) for v in self.value]

    @property
    def values(self) -> Sequence[str]:
        """The currently selected values from the options. (list)"""  # noqa: D400
        return [self.format_func(v) for v in self.value]

    @property
    def format_func(self) -> Callable[[Any], Any]:
        """The widget's formatting function for displaying options. (callable)"""  # noqa: D400
        ss = self.root.session_state
        return cast("Callable[[Any], Any]", ss[TESTING_KEY][self.id])

    def set_value(self, v: list[T]) -> Multiselect[T]:
        """Set the value of the multiselect widget. (list)"""  # noqa: D400
        self._value = v
        return self

    def select(self, v: T) -> Multiselect[T]:
        """
        Add a selection to the widget. Do nothing if the value is already selected.\
        If testing a multiselect widget with repeated options, use ``set_value``\
        instead.
        """
        current = self.value
        if v in current:
            return self
        new = current.copy()
        new.append(v)
        self.set_value(new)
        return self

    def unselect(self, v: T) -> Multiselect[T]:
        """
        Remove a selection from the widget. Do nothing if the value is not\
        already selected. If a value is selected multiple times, every\
        instance is removed.
        """
        current = self.value
        if v not in current:
            return self
        new = current.copy()
        while v in new:
            new.remove(v)
        self.set_value(new)
        return self


Number = Union[int, float]


@dataclass(repr=False)
class NumberInput(Widget):
    """A representation of ``st.number_input``."""

    _value: Number | None | InitialValue
    proto: NumberInputProto = field(repr=False)
    label: str
    min: Number | None
    max: Number | None
    step: Number
    help: str
    form_id: str

    def __init__(self, proto: NumberInputProto, root: ElementTree) -> None:
        super().__init__(proto, root)
        self._value = InitialValue()
        self.type = "number_input"
        # `None` means the widget has no bound on that side.
        self.min = proto.min if proto.has_min else None
        self.max = proto.max if proto.has_max else None

    def set_value(self, v: Number | None) -> NumberInput:
        """Set the value of the ``st.number_input`` widget."""
        self._value = v
        return self

    @property
    def _widget_state(self) -> WidgetState:
        ws = WidgetState()
        ws.id = self.id
        if self.value is not None:
            ws.double_value = self.value
        return ws

    @property
    def value(self) -> Number | None:
        """Get the current value of the ``st.number_input`` widget."""
        if not isinstance(self._value, InitialValue):
            return self._value
        state = self.root.session_state
        assert state

        # Awkward to do this with `cast`
        return state[self.id]  # type: ignore

    def increment(self) -> NumberInput:
        """Increment the ``st.number_input`` widget as if the user clicked "+".

        The result is clamped to ``max`` when the widget has one. A maximum of
        ``0`` is a real bound, not "unbounded". Does nothing if the current
        value is ``None``.
        """
        if self.value is None:
            return self

        upper = self.max if self.max is not None else float("inf")
        v = min(self.value + self.step, upper)
        return self.set_value(v)

    def decrement(self) -> NumberInput:
        """Decrement the ``st.number_input`` widget as if the user clicked "-".

        The result is clamped to ``min`` when the widget has one. A minimum of
        ``0`` is a real bound, not "unbounded". Does nothing if the current
        value is ``None``.
        """
        if self.value is None:
            return self

        lower = self.min if self.min is not None else float("-inf")
        v = max(self.value - self.step, lower)
        return self.set_value(v)


@dataclass(repr=False)
class Radio(Widget, Generic[T]):
    """A representation of ``st.radio``."""

    _value: T | None | InitialValue

    proto: RadioProto = field(repr=False)
    label: str
    options: list[str]
    horizontal: bool
    help: str
    form_id: str

    def __init__(self, proto: RadioProto, root: ElementTree) -> None:
        super().__init__(proto, root)
        self._value = InitialValue()
        self.type = "radio"
        self.options = list(proto.options)

    @property
    def index(self) -> int | None:
        """The index of the current selection. (int)"""  # noqa: D400
        if self.value is None:
            return None
        return self.options.index(self.format_func(self.value))

    @property
    def value(self) -> T | None:
        """The currently selected value from the options. (Any)"""  # noqa: D400
        if not isinstance(self._value, InitialValue):
            return self._value
        state = self.root.session_state
        assert state
        return cast("T", state[self.id])

    @property
    def format_func(self) -> Callable[[Any], Any]:
        """The widget's formatting function for displaying options. (callable)"""  # noqa: D400
        ss = self.root.session_state
        return cast("Callable[[Any], Any]", ss[TESTING_KEY][self.id])

    def set_value(self, v: T | None) -> Radio[T]:
        """Set the selection by value."""
        self._value = v
        return self

    @property
    def _widget_state(self) -> WidgetState:
        """Protobuf message representing the state of the widget, including
        any interactions that have happened.
        Should be the same as the frontend would produce for those interactions.
        """
        ws = WidgetState()
        ws.id = self.id
        if self.index is not None:
            ws.int_value = self.index
        return ws


@dataclass(repr=False)
class Selectbox(Widget, Generic[T]):
    """A representation of ``st.selectbox``."""

    _value: T | None | InitialValue

    proto: SelectboxProto = field(repr=False)
    label: str
    options: list[str]
    help: str
    form_id: str

    def __init__(self, proto: SelectboxProto, root: ElementTree) -> None:
        super().__init__(proto, root)
        self._value = InitialValue()
        self.type = "selectbox"
        self.options = list(proto.options)

    @property
    def index(self) -> int | None:
        """The index of the current selection. (int)"""  # noqa: D400
        if self.value is None:
            return None

        # NOTE(review): with no options this reports 0 rather than None;
        # `_widget_state` guards on the empty list separately.
        if len(self.options) == 0:
            return 0
        return self.options.index(self.format_func(self.value))

    @property
    def value(self) -> T | None:
        """The currently selected value from the options. (Any)"""  # noqa: D400
        if not isinstance(self._value, InitialValue):
            return self._value
        state = self.root.session_state
        assert state
        return cast("T", state[self.id])

    @property
    def format_func(self) -> Callable[[Any], Any]:
        """The widget's formatting function for displaying options. (callable)"""  # noqa: D400
        ss = self.root.session_state
        return cast("Callable[[Any], Any]", ss[TESTING_KEY][self.id])

    def set_value(self, v: T | None) -> Selectbox[T]:
        """Set the selection by value."""
        self._value = v
        return self

    def select(self, v: T | None) -> Selectbox[T]:
        """Set the selection by value."""
        return self.set_value(v)

    def select_index(self, index: int | None) -> Selectbox[T]:
        """Set the selection by index."""
        if index is None:
            return self.set_value(None)
        return self.set_value(cast("T", self.options[index]))

    @property
    def _widget_state(self) -> WidgetState:
        """Protobuf message representing the state of the widget, including
        any interactions that have happened.
        Should be the same as the frontend would produce for those interactions.
        """
        ws = WidgetState()
        ws.id = self.id
        if self.index is not None and len(self.options) > 0:
            ws.string_value = self.options[self.index]
        return ws


@dataclass(repr=False)
class SelectSlider(Widget, Generic[T]):
    """A representation of ``st.select_slider``."""

    _value: T | Sequence[T] | None

    proto: SliderProto = field(repr=False)
    label: str
    data_type: SliderProto.DataType.ValueType
    options: list[str]
    help: str
    form_id: str

    def __init__(self, proto: SliderProto, root: ElementTree) -> None:
        super().__init__(proto, root)
        self.type = "select_slider"
        self.options = list(proto.options)

    def set_value(self, v: T | Sequence[T]) -> SelectSlider[T]:
        """Set the (single) selection by value."""
        self._value = v
        return self

    @property
    def _widget_state(self) -> WidgetState:
        serde = SelectSliderSerde(self.options, [], False)
        try:
            # First treat the value as a single selection...
            v = serde.serialize(self.format_func(self.value))
        except (ValueError, TypeError):
            # ...and fall back to treating it as a range of selections.
            try:
                v = serde.serialize([self.format_func(val) for val in self.value])  # type: ignore
            # NOTE: `except Exception` cannot be used in this module because the
            # `Exception` element class above shadows the builtin, so the bare
            # except is deliberate.
            except:  # noqa: E722
                raise ValueError(f"Could not find index for {self.value}")
        ws = WidgetState()
        ws.id = self.id
        ws.double_array_value.data[:] = v
        return ws

    @property
    def value(self) -> T | Sequence[T]:
        """The currently selected value or range. (Any or Sequence of Any)"""  # noqa: D400
        if self._value is not None:
            return self._value
        state = self.root.session_state
        assert state
        # Awkward to do this with `cast`
        return state[self.id]  # type: ignore

    @property
    def format_func(self) -> Callable[[Any], Any]:
        """The widget's formatting function for displaying options. (callable)"""  # noqa: D400
        ss = self.root.session_state
        return cast("Callable[[Any], Any]", ss[TESTING_KEY][self.id])

    def set_range(self, lower: T, upper: T) -> SelectSlider[T]:
        """Set the ranged selection by values."""
        return self.set_value([lower, upper])


@dataclass(repr=False)
class Slider(Widget, Generic[SliderValueT]):
    """A representation of ``st.slider``."""

    _value: SliderValueT | Sequence[SliderValueT] | None

    proto: SliderProto = field(repr=False)
    label: str
    data_type: SliderProto.DataType.ValueType
    min: SliderValueT
    max: SliderValueT
    step: SliderStep
    help: str
    form_id: str

    def __init__(self, proto: SliderProto, root: ElementTree) -> None:
        super().__init__(proto, root)
        self.type = "slider"

    def set_value(
        self, v: SliderValueT | Sequence[SliderValueT]
    ) -> Slider[SliderValueT]:
        """Set the (single) value of the slider."""
        self._value = v
        return self

    @property
    def _widget_state(self) -> WidgetState:
        data_type = self.proto.data_type
        serde = SliderSerde([], data_type, True, None)
        v = serde.serialize(self.value)

        ws = WidgetState()
        ws.id = self.id
        ws.double_array_value.data[:] = v
        return ws

    @property
    def value(self) -> SliderValueT | Sequence[SliderValueT]:
        """The currently selected value or range. (Any or Sequence of Any)"""  # noqa: D400
        if self._value is not None:
            return self._value
        state = self.root.session_state
        assert state
        # Awkward to do this with `cast`
        return state[self.id]  # type: ignore

    def set_range(
        self, lower: SliderValueT, upper: SliderValueT
    ) -> Slider[SliderValueT]:
        """Set the ranged value of the slider."""
        return self.set_value([lower, upper])


@dataclass(repr=False)
class Table(Element):
    """An Arrow-backed table element; ``value`` is a pandas DataFrame."""

    proto: ArrowProto = field(repr=False)

    def __init__(self, proto: ArrowProto, root: ElementTree) -> None:
        self.key = None
        self.proto = proto
        self.root = root
        self.type = "arrow_table"

    @property
    def value(self) -> PandasDataframe:
        return dataframe_util.convert_arrow_bytes_to_pandas_df(self.proto.data)


@dataclass(repr=False)
class Text(Element):
    """A text element; ``value`` is the proto's ``body``."""

    proto: TextProto = field(repr=False)

    # Class-level default: `__init__` does not assign `key`.
    key: None = None

    def __init__(self, proto: TextProto, root: ElementTree) -> None:
        self.proto = proto
        self.root = root
        self.type = "text"

    @property
    def value(self) -> str:
        """The value of the element. (str)"""  # noqa: D400
        return self.proto.body


@dataclass(repr=False)
class TextArea(Widget):
    """A representation of ``st.text_area``."""

    _value: str | None | InitialValue

    proto: TextAreaProto = field(repr=False)
    label: str
    max_chars: int
    placeholder: str
    help: str
    form_id: str

    def __init__(self, proto: TextAreaProto, root: ElementTree) -> None:
        super().__init__(proto, root)
        self._value = InitialValue()
        self.type = "text_area"

    def set_value(self, v: str | None) -> TextArea:
        """Set the value of the widget."""
        self._value = v
        return self

    @property
    def _widget_state(self) -> WidgetState:
        ws = WidgetState()
        ws.id = self.id
        if self.value is not None:
            ws.string_value = self.value
        return ws

    @property
    def value(self) -> str | None:
        """The current value of the widget. (str)"""  # noqa: D400
        if not isinstance(self._value, InitialValue):
            return self._value
        state = self.root.session_state
        assert state
        # Awkward to do this with `cast`
        return state[self.id]  # type: ignore

    def input(self, v: str) -> TextArea:
        """
        Set the value of the widget only if the value does not exceed the\
        maximum allowed characters.
        """
        # TODO: should input be setting or appending?
        # A `max_chars` of 0 is falsy and means "no limit".
        if self.max_chars and len(v) > self.max_chars:
            return self
        return self.set_value(v)


@dataclass(repr=False)
class TextInput(Widget):
    """A representation of ``st.text_input``."""

    _value: str | None | InitialValue
    proto: TextInputProto = field(repr=False)
    label: str
    max_chars: int
    autocomplete: str
    placeholder: str
    help: str
    form_id: str

    def __init__(self, proto: TextInputProto, root: ElementTree) -> None:
        super().__init__(proto, root)
        self._value = InitialValue()
        self.type = "text_input"

    def set_value(self, v: str | None) -> TextInput:
        """Set the value of the widget."""
        self._value = v
        return self

    @property
    def _widget_state(self) -> WidgetState:
        ws = WidgetState()
        ws.id = self.id
        if self.value is not None:
            ws.string_value = self.value
        return ws

    @property
    def value(self) -> str | None:
        """The current value of the widget. (str)"""  # noqa: D400
        if not isinstance(self._value, InitialValue):
            return self._value
        state = self.root.session_state
        assert state
        # Awkward to do this with `cast`
        return state[self.id]  # type: ignore

    def input(self, v: str) -> TextInput:
        """
        Set the value of the widget only if the value does not exceed the\
        maximum allowed characters.
        """
        # TODO: should input be setting or appending?
        # A `max_chars` of 0 is falsy and means "no limit".
        if self.max_chars and len(v) > self.max_chars:
            return self
        return self.set_value(v)


TimeValue: TypeAlias = Union[time, datetime]


@dataclass(repr=False)
class TimeInput(Widget):
    """A representation of ``st.time_input``."""

    _value: TimeValue | None | InitialValue
    proto: TimeInputProto = field(repr=False)
    label: str
    step: int
    help: str
    form_id: str

    def __init__(self, proto: TimeInputProto, root: ElementTree) -> None:
        super().__init__(proto, root)
        self._value = InitialValue()
        self.type = "time_input"

    def set_value(self, v: TimeValue | None) -> TimeInput:
        """Set the value of the widget."""
        self._value = v
        return self

    @property
    def _widget_state(self) -> WidgetState:
        ws = WidgetState()
        ws.id = self.id

        serde = TimeInputSerde(None)
        serialized_value = serde.serialize(self.value)
        if serialized_value is not None:
            ws.string_value = serialized_value
        return ws

    @property
    def value(self) -> time | None:
        """The current value of the widget. (time)"""  # noqa: D400
        if not isinstance(self._value, InitialValue):
            v = self._value
            # A datetime is reduced to its time-of-day component.
            return v.time() if isinstance(v, datetime) else v
        state = self.root.session_state
        assert state
        return state[self.id]  # type: ignore

    def increment(self) -> TimeInput:
        """Select the next available time.

        Moves forward by ``step`` seconds; the result wraps past midnight
        because only the time-of-day component is kept.
        """
        if self.value is None:
            return self
        dt = datetime.combine(date.today(), self.value) + timedelta(seconds=self.step)
        return self.set_value(dt.time())

    def decrement(self) -> TimeInput:
        """Select the previous available time.

        Moves back by ``step`` seconds; the result wraps past midnight
        because only the time-of-day component is kept.
        """
        if self.value is None:
            return self
        dt = datetime.combine(date.today(), self.value) - timedelta(seconds=self.step)
        return self.set_value(dt.time())


@dataclass(repr=False)
class Toast(Element):
    """A toast element; ``value`` is the proto's ``body``."""

    proto: ToastProto = field(repr=False)
    icon: str

    def __init__(self, proto: ToastProto, root: ElementTree) -> None:
        self.proto = proto
        self.key = None
        self.root = root
        self.type = "toast"

    @property
    def value(self) -> str:
        return self.proto.body


@dataclass(repr=False)
class Toggle(Widget):
    """A representation of ``st.toggle``."""

    _value: bool | None

    proto: CheckboxProto = field(repr=False)
    label: str
    help: str
    form_id: str

    def __init__(self, proto: CheckboxProto, root: ElementTree) -> None:
        super().__init__(proto, root)
        self._value = None
        self.type = "toggle"

    @property
    def _widget_state(self) -> WidgetState:
        ws = WidgetState()
        ws.id = self.id
        ws.bool_value = self.value
        return ws

    @property
    def value(self) -> bool:
        """The current value of the widget. (bool)"""  # noqa: D400
        if self._value is not None:
            return self._value
        state = self.root.session_state
        assert state
        return cast("bool", state[self.id])

    def set_value(self, v: bool) -> Toggle:
        """Set the value of the widget."""
        self._value = v
        return self


@dataclass(repr=False)
class Block:
    """A container of other elements.

    Elements within a Block can be inspected and interacted with. This follows
    the same syntax as inspecting and interacting within an ``AppTest`` object.

    For all container classes, parameters of the original element can be
    obtained as properties. For example, ``ChatMessage.avatar`` and ``Tab.label``.
    """

    type: str
    children: dict[int, Node]
    proto: Any = field(repr=False)
    root: ElementTree = field(repr=False)

    def __init__(
        self,
        proto: BlockProto | None,
        root: ElementTree,
    ) -> None:
        self.children = {}
        self.proto = proto
        if proto:
            ty = proto.WhichOneof("type")
            if ty is not None:
                self.type = ty
            else:
                # `st.container` has no sub-message
                self.type = "container"
        else:
            self.type = "unknown"
        self.root = root

    def __len__(self) -> int:
        return len(self.children)

    def __iter__(self) -> Iterator[Node]:
        yield self
        for child_idx in self.children:
            yield from self.children[child_idx]

    def __getitem__(self, k: int) -> Node:
        return self.children[k]

    @property
    def key(self) -> str | None:
        return None

    # We could implement these using __getattr__ but that would have
    # much worse type information.
@property def button(self) -> WidgetList[Button]: return WidgetList(self.get("button")) # type: ignore @property def button_group(self) -> WidgetList[ButtonGroup[Any]]: return WidgetList(self.get("button_group")) # type: ignore @property def caption(self) -> ElementList[Caption]: return ElementList(self.get("caption")) # type: ignore @property def chat_input(self) -> WidgetList[ChatInput]: return WidgetList(self.get("chat_input")) # type: ignore @property def chat_message(self) -> Sequence[ChatMessage]: return self.get("chat_message") # type: ignore @property def checkbox(self) -> WidgetList[Checkbox]: return WidgetList(self.get("checkbox")) # type: ignore @property def code(self) -> ElementList[Code]: return ElementList(self.get("code")) # type: ignore @property def color_picker(self) -> WidgetList[ColorPicker]: return WidgetList(self.get("color_picker")) # type: ignore @property def columns(self) -> Sequence[Column]: return self.get("column") # type: ignore @property def dataframe(self) -> ElementList[Dataframe]: return ElementList(self.get("arrow_data_frame")) # type: ignore @property def date_input(self) -> WidgetList[DateInput]: return WidgetList(self.get("date_input")) # type: ignore @property def divider(self) -> ElementList[Divider]: return ElementList(self.get("divider")) # type: ignore @property def error(self) -> ElementList[Error]: return ElementList(self.get("error")) # type: ignore @property def exception(self) -> ElementList[Exception]: return ElementList(self.get("exception")) # type: ignore @property def expander(self) -> Sequence[Expander]: return self.get("expander") # type: ignore @property def header(self) -> ElementList[Header]: return ElementList(self.get("header")) # type: ignore @property def info(self) -> ElementList[Info]: return ElementList(self.get("info")) # type: ignore @property def json(self) -> ElementList[Json]: return ElementList(self.get("json")) # type: ignore @property def latex(self) -> ElementList[Latex]: return 
ElementList(self.get("latex")) # type: ignore @property def markdown(self) -> ElementList[Markdown]: return ElementList(self.get("markdown")) # type: ignore @property def metric(self) -> ElementList[Metric]: return ElementList(self.get("metric")) # type: ignore @property def multiselect(self) -> WidgetList[Multiselect[Any]]: return WidgetList(self.get("multiselect")) # type: ignore @property def number_input(self) -> WidgetList[NumberInput]: return WidgetList(self.get("number_input")) # type: ignore @property def radio(self) -> WidgetList[Radio[Any]]: return WidgetList(self.get("radio")) # type: ignore @property def select_slider(self) -> WidgetList[SelectSlider[Any]]: return WidgetList(self.get("select_slider")) # type: ignore @property def selectbox(self) -> WidgetList[Selectbox[Any]]: return WidgetList(self.get("selectbox")) # type: ignore @property def slider(self) -> WidgetList[Slider[Any]]: return WidgetList(self.get("slider")) # type: ignore @property def status(self) -> Sequence[Status]: return self.get("status") # type: ignore @property def subheader(self) -> ElementList[Subheader]: return ElementList(self.get("subheader")) # type: ignore @property def success(self) -> ElementList[Success]: return ElementList(self.get("success")) # type: ignore @property def table(self) -> ElementList[Table]: return ElementList(self.get("arrow_table")) # type: ignore @property def tabs(self) -> Sequence[Tab]: return self.get("tab") # type: ignore @property def text(self) -> ElementList[Text]: return ElementList(self.get("text")) # type: ignore @property def text_area(self) -> WidgetList[TextArea]: return WidgetList(self.get("text_area")) # type: ignore @property def text_input(self) -> WidgetList[TextInput]: return WidgetList(self.get("text_input")) # type: ignore @property def time_input(self) -> WidgetList[TimeInput]: return WidgetList(self.get("time_input")) # type: ignore @property def title(self) -> ElementList[Title]: return ElementList(self.get("title")) # type: 
ignore @property def toast(self) -> ElementList[Toast]: return ElementList(self.get("toast")) # type: ignore @property def toggle(self) -> WidgetList[Toggle]: return WidgetList(self.get("toggle")) # type: ignore @property def warning(self) -> ElementList[Warning]: return ElementList(self.get("warning")) # type: ignore def get(self, element_type: str) -> Sequence[Node]: return [e for e in self if e.type == element_type] def run(self, *, timeout: float | None = None) -> AppTest: """Run the script with updated widget values. Parameters ---------- timeout The maximum number of seconds to run the script. None means use the AppTest's default. """ return self.root.run(timeout=timeout) def __repr__(self) -> str: return repr_(self) def repr_(self: object) -> str: """A custom repr similar to `streamlit.util.repr_` but that shows tree structure using indentation. """ classname = self.__class__.__name__ defaults: list[Any] = [None, "", False, [], set(), {}] if is_dataclass(self): fields_vals = ( (f.name, getattr(self, f.name)) for f in fields(self) if f.repr and getattr(self, f.name) != f.default and getattr(self, f.name) not in defaults ) else: fields_vals = ((f, v) for (f, v) in self.__dict__.items() if v not in defaults) reprs = [] for field_name, value in fields_vals: line = ( f"{field_name}={format_dict(value)}" if isinstance(value, dict) else f"{field_name}={value!r}" ) reprs.append(line) reprs[0] = "\n" + reprs[0] field_reprs = ",\n".join(reprs) field_reprs = textwrap.indent(field_reprs, " " * 4) return f"{classname}({field_reprs}\n)" def format_dict(d: dict[Any, Any]) -> str: lines = [] for k, v in d.items(): line = f"{k}: {v!r}" lines.append(line) r = ",\n".join(lines) r = textwrap.indent(r, " " * 4) return f"{{\n{r}\n}}" @dataclass(repr=False) class SpecialBlock(Block): """Base class for the sidebar and main body containers.""" def __init__( self, proto: BlockProto | None, root: ElementTree, type: str | None = None, ) -> None: self.children = {} self.proto = proto if 
type: self.type = type elif proto and proto.WhichOneof("type"): ty = proto.WhichOneof("type") assert ty is not None self.type = ty else: self.type = "unknown" self.root = root @dataclass(repr=False) class ChatMessage(Block): """A representation of ``st.chat_message``.""" type: str = field(repr=False) proto: BlockProto.ChatMessage = field(repr=False) name: str avatar: str def __init__( self, proto: BlockProto.ChatMessage, root: ElementTree, ) -> None: self.children = {} self.proto = proto self.root = root self.type = "chat_message" self.name = proto.name self.avatar = proto.avatar @dataclass(repr=False) class Column(Block): """A representation of a column within ``st.columns``.""" type: str = field(repr=False) proto: BlockProto.Column = field(repr=False) weight: float gap: str def __init__( self, proto: BlockProto.Column, root: ElementTree, ) -> None: self.children = {} self.proto = proto self.root = root self.type = "column" self.weight = proto.weight self.gap = proto.gap @dataclass(repr=False) class Expander(Block): type: str = field(repr=False) proto: BlockProto.Expandable = field(repr=False) icon: str label: str def __init__(self, proto: BlockProto.Expandable, root: ElementTree) -> None: self.children = {} self.proto = proto self.root = root # The internal name is "expandable" but the public API uses "expander" # so the naming of the class and type follows the public name. 
self.type = "expander" self.icon = proto.icon self.label = proto.label @dataclass(repr=False) class Status(Block): type: str = field(repr=False) proto: BlockProto.Expandable = field(repr=False) icon: str label: str def __init__(self, proto: BlockProto.Expandable, root: ElementTree) -> None: self.children = {} self.proto = proto self.root = root self.type = "status" self.icon = proto.icon self.label = proto.label @property def state(self) -> str: if self.icon == "spinner": return "running" if self.icon == ":material/check:": return "complete" if self.icon == ":material/error:": return "error" raise ValueError("Unknown Status state") @dataclass(repr=False) class Tab(Block): """A representation of tab within ``st.tabs``.""" type: str = field(repr=False) proto: BlockProto.Tab = field(repr=False) label: str def __init__( self, proto: BlockProto.Tab, root: ElementTree, ) -> None: self.children = {} self.proto = proto self.root = root self.type = "tab" self.label = proto.label Node: TypeAlias = Union[Element, Block] def get_widget_state(node: Node) -> WidgetState | None: if isinstance(node, Widget): return node._widget_state return None @dataclass(repr=False) class ElementTree(Block): """A tree of the elements produced by running a streamlit script. Elements can be queried in three ways: - By element type, using `.foo` properties to get a list of all of that element, in the order they appear in the app - By user key, for widgets, by calling the above list with a key: `.foo(key='bar')` - Positionally, using list indexing syntax (`[...]`) to access a child of a block element. Not recommended because the exact tree structure can be surprising. Element queries made on a block container will return only the elements descending from that block. Returned elements have methods for accessing whatever attributes are relevant. For very simple elements this may be only its value, while complex elements like widgets have many. 
Widgets provide a fluent API for faking frontend interaction and rerunning the script with the new widget values. All widgets provide a low level `set_value` method, along with higher level methods specific to that type of widget. After an interaction, calling `.run()` will update the AppTest with the results of that script run. """ _runner: AppTest | None = field(repr=False, default=None) def __init__(self) -> None: self.children = {} self.root = self self.type = "root" @property def main(self) -> Block: m = self[0] assert isinstance(m, Block) return m @property def sidebar(self) -> Block: s = self[1] assert isinstance(s, Block) return s @property def session_state(self) -> SafeSessionState: assert self._runner is not None return self._runner.session_state def get_widget_states(self) -> WidgetStates: ws = WidgetStates() for node in self: w = get_widget_state(node) if w is not None: ws.widgets.append(w) return ws def run(self, *, timeout: float | None = None) -> AppTest: """Run the script with updated widget values. Parameters ---------- timeout The maximum number of seconds to run the script. None means use the AppTest's default. """ assert self._runner is not None widget_states = self.get_widget_states() return self._runner._run(widget_states, timeout=timeout) def __repr__(self) -> str: return format_dict(self.children) def parse_tree_from_messages(messages: list[ForwardMsg]) -> ElementTree: """Transform a list of `ForwardMsg` into a tree matching the implicit tree structure of blocks and elements in a streamlit app. Returns the root of the tree, which acts as the entrypoint for the query and interaction API. 
""" root = ElementTree() root.children = { 0: SpecialBlock(type="main", root=root, proto=None), 1: SpecialBlock(type="sidebar", root=root, proto=None), 2: SpecialBlock(type="event", root=root, proto=None), } for msg in messages: if not msg.HasField("delta"): continue delta_path = msg.metadata.delta_path delta = msg.delta if delta.WhichOneof("type") == "new_element": elt = delta.new_element ty = elt.WhichOneof("type") new_node: Node if ty == "alert": alert_format = elt.alert.format if alert_format == AlertProto.Format.ERROR: new_node = Error(elt.alert, root=root) elif alert_format == AlertProto.Format.INFO: new_node = Info(elt.alert, root=root) elif alert_format == AlertProto.Format.SUCCESS: new_node = Success(elt.alert, root=root) elif alert_format == AlertProto.Format.WARNING: new_node = Warning(elt.alert, root=root) else: raise ValueError( f"Unknown alert type with format {elt.alert.format}" ) elif ty == "arrow_data_frame": new_node = Dataframe(elt.arrow_data_frame, root=root) elif ty == "arrow_table": new_node = Table(elt.arrow_table, root=root) elif ty == "button": new_node = Button(elt.button, root=root) elif ty == "button_group": new_node = ButtonGroup(elt.button_group, root=root) elif ty == "chat_input": new_node = ChatInput(elt.chat_input, root=root) elif ty == "checkbox": style = elt.checkbox.type if style == CheckboxProto.StyleType.TOGGLE: new_node = Toggle(elt.checkbox, root=root) else: new_node = Checkbox(elt.checkbox, root=root) elif ty == "code": new_node = Code(elt.code, root=root) elif ty == "color_picker": new_node = ColorPicker(elt.color_picker, root=root) elif ty == "date_input": new_node = DateInput(elt.date_input, root=root) elif ty == "exception": new_node = Exception(elt.exception, root=root) elif ty == "heading": if elt.heading.tag == HeadingProtoTag.TITLE_TAG.value: new_node = Title(elt.heading, root=root) elif elt.heading.tag == HeadingProtoTag.HEADER_TAG.value: new_node = Header(elt.heading, root=root) elif elt.heading.tag == 
HeadingProtoTag.SUBHEADER_TAG.value: new_node = Subheader(elt.heading, root=root) else: raise ValueError(f"Unknown heading type with tag {elt.heading.tag}") elif ty == "json": new_node = Json(elt.json, root=root) elif ty == "markdown": if elt.markdown.element_type == MarkdownProto.Type.NATIVE: new_node = Markdown(elt.markdown, root=root) elif elt.markdown.element_type == MarkdownProto.Type.CAPTION: new_node = Caption(elt.markdown, root=root) elif elt.markdown.element_type == MarkdownProto.Type.LATEX: new_node = Latex(elt.markdown, root=root) elif elt.markdown.element_type == MarkdownProto.Type.DIVIDER: new_node = Divider(elt.markdown, root=root) else: raise ValueError( f"Unknown markdown type {elt.markdown.element_type}" ) elif ty == "metric": new_node = Metric(elt.metric, root=root) elif ty == "multiselect": new_node = Multiselect(elt.multiselect, root=root) elif ty == "number_input": new_node = NumberInput(elt.number_input, root=root) elif ty == "radio": new_node = Radio(elt.radio, root=root) elif ty == "selectbox": new_node = Selectbox(elt.selectbox, root=root) elif ty == "slider": if elt.slider.type == SliderProto.Type.SLIDER: new_node = Slider(elt.slider, root=root) elif elt.slider.type == SliderProto.Type.SELECT_SLIDER: new_node = SelectSlider(elt.slider, root=root) else: raise ValueError(f"Slider with unknown type {elt.slider}") elif ty == "text": new_node = Text(elt.text, root=root) elif ty == "text_area": new_node = TextArea(elt.text_area, root=root) elif ty == "text_input": new_node = TextInput(elt.text_input, root=root) elif ty == "time_input": new_node = TimeInput(elt.time_input, root=root) elif ty == "toast": new_node = Toast(elt.toast, root=root) else: new_node = UnknownElement(elt, root=root) elif delta.WhichOneof("type") == "add_block": block = delta.add_block bty = block.WhichOneof("type") if bty == "chat_message": new_node = ChatMessage(block.chat_message, root=root) elif bty == "column": new_node = Column(block.column, root=root) elif bty == 
"expandable": if block.expandable.icon: new_node = Status(block.expandable, root=root) else: new_node = Expander(block.expandable, root=root) elif bty == "tab": new_node = Tab(block.tab, root=root) else: new_node = Block(proto=block, root=root) else: # add_rows continue current_node: Block = root # Every node up to the end is a Block for idx in delta_path[:-1]: children = current_node.children child = children.get(idx) if child is None: child = Block(proto=None, root=root) children[idx] = child assert isinstance(child, Block) current_node = child # Handle a block when we already have a placeholder for that location if isinstance(new_node, Block): placeholder_block = current_node.children.get(delta_path[-1]) if placeholder_block is not None: new_node.children = placeholder_block.children current_node.children[delta_path[-1]] = new_node return root