team-10/env/Lib/site-packages/streamlit/testing/v1/element_tree.py
2025-08-02 07:34:44 +02:00

2070 lines
62 KiB
Python

# Copyright (c) Streamlit Inc. (2018-2022) Snowflake Inc. (2022-2025)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Assert statements are allowed here since the app testing logic is used within unit tests:
# ruff: noqa: S101
from __future__ import annotations
import textwrap
from abc import ABC, abstractmethod
from collections.abc import Iterator, Sequence
from dataclasses import dataclass, field, fields, is_dataclass
from datetime import date, datetime, time, timedelta
from typing import (
TYPE_CHECKING,
Any,
Callable,
Generic,
TypeVar,
Union,
cast,
overload,
)
from typing_extensions import Self, TypeAlias
from streamlit import dataframe_util, util
from streamlit.elements.heading import HeadingProtoTag
from streamlit.elements.widgets.select_slider import SelectSliderSerde
from streamlit.elements.widgets.slider import (
SliderSerde,
SliderStep,
SliderValueT,
)
from streamlit.elements.widgets.time_widgets import (
DateInputSerde,
DateWidgetReturn,
TimeInputSerde,
_parse_date_value,
)
from streamlit.proto.Alert_pb2 import Alert as AlertProto
from streamlit.proto.Checkbox_pb2 import Checkbox as CheckboxProto
from streamlit.proto.Markdown_pb2 import Markdown as MarkdownProto
from streamlit.proto.Slider_pb2 import Slider as SliderProto
from streamlit.proto.WidgetStates_pb2 import WidgetState, WidgetStates
from streamlit.runtime.state.common import TESTING_KEY, user_key_from_element_id
if TYPE_CHECKING:
from pandas import DataFrame as PandasDataframe
from streamlit.proto.Arrow_pb2 import Arrow as ArrowProto
from streamlit.proto.Block_pb2 import Block as BlockProto
from streamlit.proto.Button_pb2 import Button as ButtonProto
from streamlit.proto.ButtonGroup_pb2 import ButtonGroup as ButtonGroupProto
from streamlit.proto.ChatInput_pb2 import ChatInput as ChatInputProto
from streamlit.proto.Code_pb2 import Code as CodeProto
from streamlit.proto.ColorPicker_pb2 import ColorPicker as ColorPickerProto
from streamlit.proto.DateInput_pb2 import DateInput as DateInputProto
from streamlit.proto.Element_pb2 import Element as ElementProto
from streamlit.proto.Exception_pb2 import Exception as ExceptionProto
from streamlit.proto.ForwardMsg_pb2 import ForwardMsg
from streamlit.proto.Heading_pb2 import Heading as HeadingProto
from streamlit.proto.Json_pb2 import Json as JsonProto
from streamlit.proto.Metric_pb2 import Metric as MetricProto
from streamlit.proto.MultiSelect_pb2 import MultiSelect as MultiSelectProto
from streamlit.proto.NumberInput_pb2 import NumberInput as NumberInputProto
from streamlit.proto.Radio_pb2 import Radio as RadioProto
from streamlit.proto.Selectbox_pb2 import Selectbox as SelectboxProto
from streamlit.proto.Text_pb2 import Text as TextProto
from streamlit.proto.TextArea_pb2 import TextArea as TextAreaProto
from streamlit.proto.TextInput_pb2 import TextInput as TextInputProto
from streamlit.proto.TimeInput_pb2 import TimeInput as TimeInputProto
from streamlit.proto.Toast_pb2 import Toast as ToastProto
from streamlit.runtime.state.safe_session_state import SafeSessionState
from streamlit.testing.v1.app_test import AppTest
# Generic type variable for the option/value type of option-based widgets
# (used by ``ButtonGroup``, ``Multiselect``, ``Radio``, ``Selectbox`` and
# ``SelectSlider``).
T = TypeVar("T")
@dataclass
class InitialValue:
    """Sentinel meaning "the test has not set a value for this widget yet".

    Widgets such as ``DateInput``, ``NumberInput`` and ``TextInput`` keep an
    instance of this in ``_value`` until ``set_value`` is called. This keeps
    ``None`` available as a real value a test can set.
    """
# TODO: This class serves as a fallback option for elements that have not
# been implemented yet, as well as providing implementations of some
# trivial methods. It may have significantly reduced scope once all elements
# have been implemented.
# This class is not a sufficient implementation for most elements.
# Widgets need their own classes to translate interactions into the appropriate
# WidgetState and provide higher level interaction interfaces, and other elements
# have enough variation in how to get their values that most will need their
# own classes too.
@dataclass
class Element(ABC):
    """
    Element base class for testing.

    This class's methods and attributes are universal for all elements
    implemented in testing. For example, ``Caption``, ``Code``, ``Text``, and
    ``Title`` inherit from ``Element``. All widget classes also
    inherit from Element, but have additional methods specific to each
    widget type. See the ``AppTest`` class for the full list of supported
    elements.

    For all element classes, parameters of the original element can be obtained
    as properties. For example, ``Button.label``, ``Caption.help``, and
    ``Toast.icon``.
    """

    type: str = field(repr=False)
    proto: Any = field(repr=False)
    root: ElementTree = field(repr=False)
    key: str | None

    @abstractmethod
    def __init__(self, proto: ElementProto, root: ElementTree) -> None: ...

    def __iter__(self) -> Iterator[Self]:
        yield self

    @property
    @abstractmethod
    def value(self) -> Any:
        """The value or contents of the element."""
        ...

    def __getattr__(self, name: str) -> Any:
        """Fallback attempt to get an attribute from the proto.

        Python only calls this when normal attribute lookup fails. ``proto``
        itself is never forwarded. If it has not been assigned yet (for
        example on an instance created via ``__new__`` by ``copy`` or
        ``pickle``), forwarding would call this method again for ``proto``
        and recurse until ``RecursionError``. Raising ``AttributeError``
        keeps ``hasattr``/``getattr(..., default)`` working.
        """
        if name == "proto":
            raise AttributeError(
                f"{self.__class__.__name__!r} object has no attribute 'proto'"
            )
        return getattr(self.proto, name)

    def run(self, *, timeout: float | None = None) -> AppTest:
        """Run the ``AppTest`` script which contains the element.

        Parameters
        ----------
        timeout
            The maximum number of seconds to run the script. None means
            use the AppTest's default.
        """
        return self.root.run(timeout=timeout)

    def __repr__(self) -> str:
        return util.repr_(self)
@dataclass(repr=False)
class UnknownElement(Element):
    """Fallback representation for element types without a dedicated class."""

    def __init__(self, proto: ElementProto, root: ElementTree) -> None:
        ty = proto.WhichOneof("type")
        assert ty is not None
        self.proto = getattr(proto, ty)
        self.root = root
        self.type = ty
        self.key = None

    @property
    def value(self) -> Any:
        """The widget value from session state, or the proto's ``value`` field.

        Protos that carry an ``id`` are looked up in session state by that id.
        Otherwise the proto's own ``value`` field is returned.
        """
        try:
            state = self.root.session_state
            assert state is not None
            return state[self.proto.id]
        except (AttributeError, ValueError):
            # No id field, not a widget. Reading a field that a protobuf
            # message does not define raises AttributeError, so it is caught
            # here explicitly. Previously this case only worked because
            # Element.__getattr__ re-dispatched the failed ``value`` lookup.
            return self.proto.value
@dataclass(repr=False)
class Widget(Element, ABC):
    """Widget base class for testing.

    Subclasses translate test interactions (``set_value`` and friends) into
    the ``WidgetState`` protobuf that is sent on the next script run.
    """

    id: str = field(repr=False)
    disabled: bool
    key: str | None
    _value: Any

    def __init__(self, proto: Any, root: ElementTree) -> None:
        # ``proto`` goes first: ``self.id`` below is resolved from it via
        # ``Element.__getattr__``.
        self.proto = proto
        self.root = root
        element_id = self.id
        self.key = user_key_from_element_id(element_id)
        # ``None`` means "no value set by the test yet" for most widgets.
        self._value = None

    def set_value(self, v: Any) -> Self:
        """Store ``v`` as the widget's pending value and return the widget."""
        self._value = v
        return self

    @property
    @abstractmethod
    def _widget_state(self) -> WidgetState: ...
# Covariant element type parameter for ``ElementList``; restricted to
# ``Element`` subclasses.
El_co = TypeVar("El_co", bound=Element, covariant=True)
class ElementList(Generic[El_co]):
    """Sequence-like wrapper around a group of elements from the tree."""

    def __init__(self, els: Sequence[El_co]) -> None:
        self._list: Sequence[El_co] = els

    def __len__(self) -> int:
        return len(self._list)

    @property
    def len(self) -> int:
        """The number of elements, same as ``len(...)``."""
        return len(self)

    @overload
    def __getitem__(self, idx: int) -> El_co: ...

    @overload
    def __getitem__(self, idx: slice) -> ElementList[El_co]: ...

    def __getitem__(self, idx: int | slice) -> El_co | ElementList[El_co]:
        # Plain indices return the element itself; slices stay wrapped.
        if not isinstance(idx, slice):
            return self._list[idx]
        return ElementList(self._list[idx])

    def __iter__(self) -> Iterator[El_co]:
        for element in self._list:
            yield element

    def __repr__(self) -> str:
        return util.repr_(self)

    def __eq__(self, other: ElementList[El_co] | object) -> bool:
        # Compare with another ElementList by contents; anything else is
        # compared directly against the wrapped sequence.
        other_items = other._list if isinstance(other, ElementList) else other
        return self._list == other_items

    def __hash__(self) -> int:
        # NOTE(review): elements are ``@dataclass`` classes with the default
        # ``eq=True``, which normally makes them unhashable, so this likely
        # raises TypeError for non-empty lists. Confirm before relying on it.
        return hash(tuple(self._list))

    @property
    def values(self) -> Sequence[Any]:
        """The ``value`` of each element, in order."""
        return [element.value for element in self]
# Covariant widget type parameter for ``WidgetList``; restricted to ``Widget``
# subclasses.
W_co = TypeVar("W_co", bound=Widget, covariant=True)
class WidgetList(ElementList[W_co], Generic[W_co]):
    """An ``ElementList`` of widgets that can also be called with a key."""

    def __call__(self, key: str) -> W_co:
        """Return the first widget whose ``key`` equals ``key``.

        Raises
        ------
        KeyError
            If no widget in the list has that key.
        """
        not_found = object()
        match = next((w for w in self._list if w.key == key), not_found)
        if match is not_found:
            raise KeyError(key)
        return match
@dataclass(repr=False)
class AlertBase(Element):
    """Shared base for the ``Error``, ``Warning``, ``Info`` and ``Success`` alerts."""

    proto: AlertProto = field(repr=False)
    icon: str

    def __init__(self, proto: AlertProto, root: ElementTree) -> None:
        self.root = root
        self.proto = proto
        self.key = None

    @property
    def value(self) -> str:
        """The alert's body text."""
        body: str = self.proto.body
        return body
@dataclass(repr=False)
class Error(AlertBase):
    """An alert element whose ``type`` is ``"error"``."""

    def __init__(self, proto: AlertProto, root: ElementTree) -> None:
        AlertBase.__init__(self, proto, root)
        self.type = "error"
@dataclass(repr=False)
class Warning(AlertBase):  # noqa: A001
    """An alert element whose ``type`` is ``"warning"``."""

    def __init__(self, proto: AlertProto, root: ElementTree) -> None:
        AlertBase.__init__(self, proto, root)
        self.type = "warning"
@dataclass(repr=False)
class Info(AlertBase):
    """An alert element whose ``type`` is ``"info"``."""

    def __init__(self, proto: AlertProto, root: ElementTree) -> None:
        AlertBase.__init__(self, proto, root)
        self.type = "info"
@dataclass(repr=False)
class Success(AlertBase):
    """An alert element whose ``type`` is ``"success"``."""

    def __init__(self, proto: AlertProto, root: ElementTree) -> None:
        AlertBase.__init__(self, proto, root)
        self.type = "success"
@dataclass(repr=False)
class Button(Widget):
    """A representation of ``st.button`` and ``st.form_submit_button``."""

    _value: bool
    proto: ButtonProto = field(repr=False)
    label: str
    help: str
    form_id: str

    def __init__(self, proto: ButtonProto, root: ElementTree) -> None:
        super().__init__(proto, root)
        self.type = "button"
        # Buttons start un-clicked rather than "unset".
        self._value = False

    @property
    def _widget_state(self) -> WidgetState:
        # A button is a trigger: the click is only sent for the next run.
        msg = WidgetState()
        msg.id = self.id
        msg.trigger_value = self._value
        return msg

    @property
    def value(self) -> bool:
        """Whether the button is clicked.

        A pending click set by the test takes precedence. Otherwise the value
        stored for this button under the testing key in session state is
        returned.
        """
        if not self._value:
            session = self.root.session_state
            assert session
            return cast("bool", session[TESTING_KEY][self.id])
        return self._value

    def set_value(self, v: bool) -> Button:
        """Set the pending click state of the button."""
        self._value = v
        return self

    def click(self) -> Button:
        """Mark the button as clicked for the next run."""
        return self.set_value(True)
@dataclass(repr=False)
class ChatInput(Widget):
    """A representation of ``st.chat_input``."""

    _value: str | None
    proto: ChatInputProto = field(repr=False)
    placeholder: str

    def __init__(self, proto: ChatInputProto, root: ElementTree) -> None:
        super().__init__(proto, root)
        self.type = "chat_input"

    def set_value(self, v: str | None) -> ChatInput:
        """Set the message to submit on the next run."""
        self._value = v
        return self

    @property
    def _widget_state(self) -> WidgetState:
        # The message is sent as a string trigger, and only when one is set.
        msg = WidgetState()
        msg.id = self.id
        pending = self._value
        if pending is not None:
            msg.string_trigger_value.data = pending
        return msg

    @property
    def value(self) -> str | None:
        """The submitted message.

        A non-empty pending message takes precedence. Otherwise the value
        stored under the testing key in session state is returned.
        """
        pending = self._value
        if pending:
            return pending
        session = self.root.session_state
        assert session
        return session[TESTING_KEY][self.id]  # type: ignore
@dataclass(repr=False)
class Checkbox(Widget):
    """A representation of ``st.checkbox``."""

    _value: bool | None
    proto: CheckboxProto = field(repr=False)
    label: str
    help: str
    form_id: str

    def __init__(self, proto: CheckboxProto, root: ElementTree) -> None:
        super().__init__(proto, root)
        self.type = "checkbox"

    @property
    def _widget_state(self) -> WidgetState:
        msg = WidgetState()
        msg.id = self.id
        msg.bool_value = self.value
        return msg

    @property
    def value(self) -> bool:
        """Whether the box is checked: the value set by the test, else the session-state value."""
        if self._value is None:
            session = self.root.session_state
            assert session
            return cast("bool", session[self.id])
        return self._value

    def set_value(self, v: bool) -> Checkbox:
        """Set whether the box is checked."""
        self._value = v
        return self

    def check(self) -> Checkbox:
        """Check the box."""
        return self.set_value(True)

    def uncheck(self) -> Checkbox:
        """Uncheck the box."""
        return self.set_value(False)
@dataclass(repr=False)
class Code(Element):
    """A representation of ``st.code``."""

    proto: CodeProto = field(repr=False)
    language: str
    show_line_numbers: bool
    key: None

    def __init__(self, proto: CodeProto, root: ElementTree) -> None:
        self.type = "code"
        self.root = root
        self.key = None
        self.proto = proto

    @property
    def value(self) -> str:
        """The displayed source code text."""
        text: str = self.proto.code_text
        return text
@dataclass(repr=False)
class ColorPicker(Widget):
    """A representation of ``st.color_picker``."""

    _value: str | None
    label: str
    help: str
    form_id: str
    proto: ColorPickerProto = field(repr=False)

    def __init__(self, proto: ColorPickerProto, root: ElementTree) -> None:
        super().__init__(proto, root)
        self.type = "color_picker"

    @property
    def value(self) -> str:
        """The selected color as a hex string: the test-set value, else the session-state value."""
        if self._value is None:
            session = self.root.session_state
            assert session
            return cast("str", session[self.id])
        return self._value

    @property
    def _widget_state(self) -> WidgetState:
        """Protobuf message with the widget's current state.

        Should match what the frontend would send for the same interaction.
        """
        msg = WidgetState()
        msg.id = self.id
        msg.string_value = self.value
        return msg

    def set_value(self, v: str) -> ColorPicker:
        """Set the color as a hex string (used verbatim)."""
        self._value = v
        return self

    def pick(self, v: str) -> ColorPicker:
        """Set the color as a hex string, adding a leading "#" if missing."""
        hex_color = v if v.startswith("#") else f"#{v}"
        return self.set_value(hex_color)
@dataclass(repr=False)
class Dataframe(Element):
    """An Arrow-backed dataframe element (``type == "arrow_data_frame"``)."""

    proto: ArrowProto = field(repr=False)

    def __init__(self, proto: ArrowProto, root: ElementTree) -> None:
        self.type = "arrow_data_frame"
        self.root = root
        self.proto = proto
        self.key = None

    @property
    def value(self) -> PandasDataframe:
        """The displayed data, decoded from the proto's Arrow bytes into pandas."""
        arrow_bytes = self.proto.data
        return dataframe_util.convert_arrow_bytes_to_pandas_df(arrow_bytes)
# A single date accepted by ``DateInput.set_value``.
SingleDateValue: TypeAlias = Union[date, datetime]
# Anything ``DateInput.set_value`` accepts: one date, a sequence of dates
# (for range inputs), or ``None``.
DateValue: TypeAlias = Union[SingleDateValue, Sequence[SingleDateValue], None]
@dataclass(repr=False)
class DateInput(Widget):
    """A representation of ``st.date_input``."""

    _value: DateValue | None | InitialValue
    proto: DateInputProto = field(repr=False)
    label: str
    min: date
    max: date
    is_range: bool
    help: str
    form_id: str

    def __init__(self, proto: DateInputProto, root: ElementTree) -> None:
        super().__init__(proto, root)
        self.type = "date_input"
        # InitialValue (not None) marks "not set by the test", since None is a
        # legitimate date_input value.
        self._value = InitialValue()
        # The proto carries the bounds as "YYYY/MM/DD" strings.
        date_format = "%Y/%m/%d"
        self.min = datetime.strptime(proto.min, date_format).date()
        self.max = datetime.strptime(proto.max, date_format).date()

    def set_value(self, v: DateValue) -> DateInput:
        """Set the date (or date range) for the next run."""
        self._value = v
        return self

    @property
    def _widget_state(self) -> WidgetState:
        msg = WidgetState()
        msg.id = self.id
        serialized = DateInputSerde(None).serialize(self.value)  # type: ignore
        msg.string_array_value.data[:] = serialized
        return msg

    @property
    def value(self) -> DateWidgetReturn:
        """The selected date(s).

        A value set by the test is normalized through ``_parse_date_value``
        and returned as a tuple, or ``None``. Otherwise the session-state
        value is returned.
        """
        if isinstance(self._value, InitialValue):
            session = self.root.session_state
            assert session
            return session[self.id]  # type: ignore
        parsed, _ = _parse_date_value(self._value)
        if parsed is None:
            return None
        return tuple(parsed)  # type: ignore
@dataclass(repr=False)
class Exception(Element):  # noqa: A001
    """An exception element displayed by the app (``type == "exception"``)."""

    message: str
    is_markdown: bool
    stack_trace: list[str]
    is_warning: bool

    def __init__(self, proto: ExceptionProto, root: ElementTree) -> None:
        self.proto = proto
        self.root = root
        self.type = "exception"
        self.key = None
        # Fields that are renamed or converted. ``message`` and ``is_warning``
        # are not assigned here, so they resolve from the proto through
        # ``Element.__getattr__``.
        self.stack_trace = list(proto.stack_trace)
        self.is_markdown = proto.message_is_markdown

    @property
    def value(self) -> str:
        """The exception message."""
        return self.message
@dataclass(repr=False)
class HeadingBase(Element, ABC):
    """Shared base for ``Title``, ``Header`` and ``Subheader``.

    Subclasses pass their own ``type_`` string.
    """

    proto: HeadingProto = field(repr=False)
    tag: str
    anchor: str | None
    hide_anchor: bool
    key: None

    def __init__(self, proto: HeadingProto, root: ElementTree, type_: str) -> None:
        self.type = type_
        self.root = root
        self.key = None
        self.proto = proto

    @property
    def value(self) -> str:
        """The heading text."""
        text: str = self.proto.body
        return text
@dataclass(repr=False)
class Header(HeadingBase):
    """A heading element whose ``type`` is ``"header"``."""

    def __init__(self, proto: HeadingProto, root: ElementTree) -> None:
        HeadingBase.__init__(self, proto, root, "header")
@dataclass(repr=False)
class Subheader(HeadingBase):
    """A heading element whose ``type`` is ``"subheader"``."""

    def __init__(self, proto: HeadingProto, root: ElementTree) -> None:
        HeadingBase.__init__(self, proto, root, "subheader")
@dataclass(repr=False)
class Title(HeadingBase):
    """A heading element whose ``type`` is ``"title"``."""

    def __init__(self, proto: HeadingProto, root: ElementTree) -> None:
        HeadingBase.__init__(self, proto, root, "title")
@dataclass(repr=False)
class Json(Element):
    """A JSON element (``type == "json"``)."""

    proto: JsonProto = field(repr=False)
    expanded: bool

    def __init__(self, proto: JsonProto, root: ElementTree) -> None:
        self.type = "json"
        self.root = root
        self.key = None
        self.proto = proto

    @property
    def value(self) -> str:
        """The JSON body as stored in the proto."""
        body: str = self.proto.body
        return body
@dataclass(repr=False)
class Markdown(Element):
    """A markdown element; also the base of ``Caption``, ``Divider`` and ``Latex``."""

    proto: MarkdownProto = field(repr=False)
    is_caption: bool
    allow_html: bool
    key: None

    def __init__(self, proto: MarkdownProto, root: ElementTree) -> None:
        self.type = "markdown"
        self.root = root
        self.key = None
        self.proto = proto

    @property
    def value(self) -> str:
        """The markdown source text."""
        body: str = self.proto.body
        return body
@dataclass(repr=False)
class Caption(Markdown):
    """A markdown element whose ``type`` is ``"caption"``."""

    def __init__(self, proto: MarkdownProto, root: ElementTree) -> None:
        Markdown.__init__(self, proto, root)
        self.type = "caption"
@dataclass(repr=False)
class Divider(Markdown):
    """A markdown element whose ``type`` is ``"divider"``."""

    def __init__(self, proto: MarkdownProto, root: ElementTree) -> None:
        Markdown.__init__(self, proto, root)
        self.type = "divider"
@dataclass(repr=False)
class Latex(Markdown):
    """A markdown element whose ``type`` is ``"latex"``."""

    def __init__(self, proto: MarkdownProto, root: ElementTree) -> None:
        Markdown.__init__(self, proto, root)
        self.type = "latex"
@dataclass(repr=False)
class Metric(Element):
    """A metric element (``type == "metric"``)."""

    # ``repr=False`` matches the ``proto`` field of every other element class,
    # so the raw protobuf is excluded from the element's dataclass repr fields.
    proto: MetricProto = field(repr=False)
    label: str
    delta: str
    color: str
    help: str

    def __init__(self, proto: MetricProto, root: ElementTree) -> None:
        self.proto = proto
        self.key = None
        self.root = root
        self.type = "metric"

    @property
    def value(self) -> str:
        """The metric's displayed value (the proto's ``body`` field)."""
        return self.proto.body
@dataclass(repr=False)
class ButtonGroup(Widget, Generic[T]):
    """A representation of button_group that is used by ``st.feedback``."""

    _value: list[T] | None
    proto: ButtonGroupProto = field(repr=False)
    options: list[ButtonGroupProto.Option]
    form_id: str

    def __init__(self, proto: ButtonGroupProto, root: ElementTree) -> None:
        super().__init__(proto, root)
        self.type = "button_group"
        self.options = list(proto.options)

    @property
    def _widget_state(self) -> WidgetState:
        """Protobuf message representing the state of the widget, including
        any interactions that have happened.

        Should be the same as the frontend would produce for those interactions.
        The selection is sent as option indices (see ``indices``).
        """
        ws = WidgetState()
        ws.id = self.id
        ws.int_array_value.data[:] = self.indices
        return ws

    @property
    def value(self) -> list[T]:
        """The currently selected values from the options. (list)"""  # noqa: D400
        if self._value is not None:
            return self._value
        state = self.root.session_state
        assert state
        return cast("list[T]", state[self.id])

    @property
    def indices(self) -> Sequence[int]:
        """The indices of the currently selected values from the options. (list)"""  # noqa: D400
        return [self.options.index(self.format_func(v)) for v in self.value]

    @property
    def format_func(self) -> Callable[[Any], Any]:
        """The widget's formatting function for displaying options. (callable)"""  # noqa: D400
        ss = self.root.session_state
        return cast("Callable[[Any], Any]", ss[TESTING_KEY][self.id])

    def set_value(self, v: list[T]) -> ButtonGroup[T]:
        """Set the selected values of the button group. (list)"""  # noqa: D400
        self._value = v
        return self

    def select(self, v: T) -> ButtonGroup[T]:
        """
        Add ``v`` to the selection. Do nothing if ``v`` is already selected.

        To select the same value more than once, use ``set_value`` instead.
        """
        current = self.value
        if v in current:
            return self
        new = current.copy()
        new.append(v)
        self.set_value(new)
        return self

    def unselect(self, v: T) -> ButtonGroup[T]:
        """
        Remove ``v`` from the selection. Do nothing if ``v`` is not selected.

        Every occurrence of ``v`` is removed, so a value selected more than
        once ends up fully deselected.
        """
        current = self.value
        if v not in current:
            return self
        new = current.copy()
        while v in new:
            new.remove(v)
        self.set_value(new)
        return self
@dataclass(repr=False)
class Multiselect(Widget, Generic[T]):
    """A representation of ``st.multiselect``."""

    _value: list[T] | None
    proto: MultiSelectProto = field(repr=False)
    label: str
    options: list[str]
    max_selections: int
    help: str
    form_id: str

    def __init__(self, proto: MultiSelectProto, root: ElementTree) -> None:
        super().__init__(proto, root)
        self.type = "multiselect"
        self.options = list(proto.options)

    @property
    def _widget_state(self) -> WidgetState:
        """Protobuf message representing the state of the widget, including
        any interactions that have happened.

        Should be the same as the frontend would produce for those interactions.
        The selection is sent as the formatted option strings (see ``values``).
        """
        ws = WidgetState()
        ws.id = self.id
        ws.string_array_value.data[:] = self.values
        return ws

    @property
    def value(self) -> list[T]:
        """The currently selected values from the options. (list)"""  # noqa: D400
        if self._value is not None:
            return self._value
        state = self.root.session_state
        assert state
        return cast("list[T]", state[self.id])

    @property
    def indices(self) -> Sequence[int]:
        """The indices of the currently selected values from the options. (list)"""  # noqa: D400
        return [self.options.index(self.format_func(v)) for v in self.value]

    @property
    def values(self) -> Sequence[str]:
        """The currently selected values, formatted with ``format_func``. (list)"""  # noqa: D400
        return [self.format_func(v) for v in self.value]

    @property
    def format_func(self) -> Callable[[Any], Any]:
        """The widget's formatting function for displaying options. (callable)"""  # noqa: D400
        ss = self.root.session_state
        return cast("Callable[[Any], Any]", ss[TESTING_KEY][self.id])

    def set_value(self, v: list[T]) -> Multiselect[T]:
        """Set the value of the multiselect widget. (list)"""  # noqa: D400
        self._value = v
        return self

    def select(self, v: T) -> Multiselect[T]:
        """
        Add ``v`` to the selection. Do nothing if ``v`` is already selected.

        If testing a multiselect widget with repeated options, use
        ``set_value`` instead.
        """
        current = self.value
        if v in current:
            return self
        new = current.copy()
        new.append(v)
        self.set_value(new)
        return self

    def unselect(self, v: T) -> Multiselect[T]:
        """
        Remove ``v`` from the selection. Do nothing if ``v`` is not selected.

        Every occurrence of ``v`` is removed, so a value selected more than
        once ends up fully deselected.
        """
        current = self.value
        if v not in current:
            return self
        new = current.copy()
        while v in new:
            new.remove(v)
        self.set_value(new)
        return self
Number = Union[int, float]
@dataclass(repr=False)
class NumberInput(Widget):
    """A representation of ``st.number_input``."""

    _value: Number | None | InitialValue
    proto: NumberInputProto = field(repr=False)
    label: str
    min: Number | None
    max: Number | None
    step: Number
    help: str
    form_id: str

    def __init__(self, proto: NumberInputProto, root: ElementTree) -> None:
        super().__init__(proto, root)
        self._value = InitialValue()
        self.type = "number_input"
        # Bounds are None when the proto does not set them.
        self.min = proto.min if proto.has_min else None
        self.max = proto.max if proto.has_max else None

    def set_value(self, v: Number | None) -> NumberInput:
        """Set the value of the ``st.number_input`` widget."""
        self._value = v
        return self

    @property
    def _widget_state(self) -> WidgetState:
        ws = WidgetState()
        ws.id = self.id
        if self.value is not None:
            ws.double_value = self.value
        return ws

    @property
    def value(self) -> Number | None:
        """Get the current value of the ``st.number_input`` widget."""
        if not isinstance(self._value, InitialValue):
            return self._value
        state = self.root.session_state
        assert state
        # Awkward to do this with `cast`
        return state[self.id]  # type: ignore

    def increment(self) -> NumberInput:
        """Increment the ``st.number_input`` widget as if the user clicked "+".

        The result is clamped to ``max`` when one is set. Does nothing if the
        widget currently has no value.
        """
        if self.value is None:
            return self
        # Check against None, not truthiness: ``max == 0`` is a real bound.
        upper = self.max if self.max is not None else float("inf")
        v = min(self.value + self.step, upper)
        return self.set_value(v)

    def decrement(self) -> NumberInput:
        """Decrement the ``st.number_input`` widget as if the user clicked "-".

        The result is clamped to ``min`` when one is set. Does nothing if the
        widget currently has no value.
        """
        if self.value is None:
            return self
        # Check against None, not truthiness: ``min == 0`` is a real bound.
        lower = self.min if self.min is not None else float("-inf")
        v = max(self.value - self.step, lower)
        return self.set_value(v)
@dataclass(repr=False)
class Radio(Widget, Generic[T]):
    """A representation of ``st.radio``."""

    _value: T | None | InitialValue
    proto: RadioProto = field(repr=False)
    label: str
    options: list[str]
    horizontal: bool
    help: str
    form_id: str

    def __init__(self, proto: RadioProto, root: ElementTree) -> None:
        super().__init__(proto, root)
        self.type = "radio"
        self._value = InitialValue()
        self.options = list(proto.options)

    @property
    def index(self) -> int | None:
        """Position of the current selection within ``options``, or ``None``."""
        selected = self.value
        if selected is None:
            return None
        return self.options.index(self.format_func(selected))

    @property
    def value(self) -> T | None:
        """The selected option: the test-set value, else the session-state value."""
        if isinstance(self._value, InitialValue):
            session = self.root.session_state
            assert session
            return cast("T", session[self.id])
        return self._value

    @property
    def format_func(self) -> Callable[[Any], Any]:
        """The formatting function the app registered for this widget's options."""
        session = self.root.session_state
        return cast("Callable[[Any], Any]", session[TESTING_KEY][self.id])

    def set_value(self, v: T | None) -> Radio[T]:
        """Select an option by its value."""
        self._value = v
        return self

    @property
    def _widget_state(self) -> WidgetState:
        """Protobuf message with the widget's current state.

        The selection is sent as an option index, matching the frontend.
        """
        msg = WidgetState()
        msg.id = self.id
        position = self.index
        if position is not None:
            msg.int_value = position
        return msg
@dataclass(repr=False)
class Selectbox(Widget, Generic[T]):
    """A representation of ``st.selectbox``."""

    _value: T | None | InitialValue
    proto: SelectboxProto = field(repr=False)
    label: str
    options: list[str]
    help: str
    form_id: str

    def __init__(self, proto: SelectboxProto, root: ElementTree) -> None:
        super().__init__(proto, root)
        self.type = "selectbox"
        self._value = InitialValue()
        self.options = list(proto.options)

    @property
    def index(self) -> int | None:
        """Position of the current selection in ``options``.

        Returns ``None`` when nothing is selected, and 0 when the widget has
        no options.
        """
        selected = self.value
        if selected is None:
            return None
        if not self.options:
            return 0
        return self.options.index(self.format_func(selected))

    @property
    def value(self) -> T | None:
        """The selected option: the test-set value, else the session-state value."""
        if isinstance(self._value, InitialValue):
            session = self.root.session_state
            assert session
            return cast("T", session[self.id])
        return self._value

    @property
    def format_func(self) -> Callable[[Any], Any]:
        """The formatting function the app registered for this widget's options."""
        session = self.root.session_state
        return cast("Callable[[Any], Any]", session[TESTING_KEY][self.id])

    def set_value(self, v: T | None) -> Selectbox[T]:
        """Select an option by its value."""
        self._value = v
        return self

    def select(self, v: T | None) -> Selectbox[T]:
        """Select an option by its value (alias of ``set_value``)."""
        return self.set_value(v)

    def select_index(self, index: int | None) -> Selectbox[T]:
        """Select an option by its position in ``options``; ``None`` clears it."""
        chosen = None if index is None else cast("T", self.options[index])
        return self.set_value(chosen)

    @property
    def _widget_state(self) -> WidgetState:
        """Protobuf message with the widget's current state.

        The selection is sent as the chosen option string, matching the
        frontend.
        """
        msg = WidgetState()
        msg.id = self.id
        position = self.index
        if position is not None and self.options:
            msg.string_value = self.options[position]
        return msg
@dataclass(repr=False)
class SelectSlider(Widget, Generic[T]):
    """A representation of ``st.select_slider``."""

    _value: T | Sequence[T] | None
    proto: SliderProto = field(repr=False)
    label: str
    data_type: SliderProto.DataType.ValueType
    options: list[str]
    help: str
    form_id: str

    def __init__(self, proto: SliderProto, root: ElementTree) -> None:
        super().__init__(proto, root)
        self.type = "select_slider"
        self.options = list(proto.options)

    def set_value(self, v: T | Sequence[T]) -> SelectSlider[T]:
        """Set the selection by value: a single value or a ``[lower, upper]`` range."""
        self._value = v
        return self

    @property
    def _widget_state(self) -> WidgetState:
        """Protobuf message with the selection serialized as option indices.

        Raises
        ------
        ValueError
            If the value matches neither a single option nor a range of options.
        """
        serde = SelectSliderSerde(self.options, [], False)
        try:
            v = serde.serialize(self.format_func(self.value))
        except (ValueError, TypeError):
            # Not a single option; try treating the value as a range.
            try:
                v = serde.serialize([self.format_func(val) for val in self.value])  # type: ignore
            except Exception as err:
                # Was a bare ``except:``, which also turned KeyboardInterrupt
                # and SystemExit into ValueError and dropped the cause.
                raise ValueError(f"Could not find index for {self.value}") from err
        ws = WidgetState()
        ws.id = self.id
        ws.double_array_value.data[:] = v
        return ws

    @property
    def value(self) -> T | Sequence[T]:
        """The currently selected value or range. (Any or Sequence of Any)"""  # noqa: D400
        if self._value is not None:
            return self._value
        state = self.root.session_state
        assert state
        # Awkward to do this with `cast`
        return state[self.id]  # type: ignore

    @property
    def format_func(self) -> Callable[[Any], Any]:
        """The widget's formatting function for displaying options. (callable)"""  # noqa: D400
        ss = self.root.session_state
        return cast("Callable[[Any], Any]", ss[TESTING_KEY][self.id])

    def set_range(self, lower: T, upper: T) -> SelectSlider[T]:
        """Set the ranged selection by values."""
        return self.set_value([lower, upper])
@dataclass(repr=False)
class Slider(Widget, Generic[SliderValueT]):
    """A representation of ``st.slider``."""

    _value: SliderValueT | Sequence[SliderValueT] | None
    proto: SliderProto = field(repr=False)
    label: str
    data_type: SliderProto.DataType.ValueType
    min: SliderValueT
    max: SliderValueT
    step: SliderStep
    help: str
    form_id: str

    def __init__(self, proto: SliderProto, root: ElementTree) -> None:
        super().__init__(proto, root)
        self.type = "slider"

    def set_value(
        self, v: SliderValueT | Sequence[SliderValueT]
    ) -> Slider[SliderValueT]:
        """Set the slider to a single value or a ``[lower, upper]`` range."""
        self._value = v
        return self

    @property
    def _widget_state(self) -> WidgetState:
        # The value is serialized with the slider's data type into doubles.
        serialized = SliderSerde([], self.proto.data_type, True, None).serialize(
            self.value
        )
        msg = WidgetState()
        msg.id = self.id
        msg.double_array_value.data[:] = serialized
        return msg

    @property
    def value(self) -> SliderValueT | Sequence[SliderValueT]:
        """The slider's value or range: the test-set value, else the session-state value."""
        if self._value is None:
            session = self.root.session_state
            assert session
            return session[self.id]  # type: ignore
        return self._value

    def set_range(
        self, lower: SliderValueT, upper: SliderValueT
    ) -> Slider[SliderValueT]:
        """Set the slider to the range ``[lower, upper]``."""
        return self.set_value([lower, upper])
@dataclass(repr=False)
class Table(Element):
    """An Arrow-backed table element (``type == "arrow_table"``)."""

    proto: ArrowProto = field(repr=False)

    def __init__(self, proto: ArrowProto, root: ElementTree) -> None:
        self.type = "arrow_table"
        self.root = root
        self.proto = proto
        self.key = None

    @property
    def value(self) -> PandasDataframe:
        """The displayed data, decoded from the proto's Arrow bytes into pandas."""
        arrow_bytes = self.proto.data
        return dataframe_util.convert_arrow_bytes_to_pandas_df(arrow_bytes)
@dataclass(repr=False)
class Text(Element):
    """A plain-text element (``type == "text"``)."""

    proto: TextProto = field(repr=False)
    # Defaults to None at class level, so ``__init__`` does not assign it.
    key: None = None

    def __init__(self, proto: TextProto, root: ElementTree) -> None:
        self.type = "text"
        self.root = root
        self.proto = proto

    @property
    def value(self) -> str:
        """The displayed text."""
        body: str = self.proto.body
        return body
@dataclass(repr=False)
class TextArea(Widget):
    """A representation of ``st.text_area``."""

    _value: str | None | InitialValue
    proto: TextAreaProto = field(repr=False)
    label: str
    max_chars: int
    placeholder: str
    help: str
    form_id: str

    def __init__(self, proto: TextAreaProto, root: ElementTree) -> None:
        super().__init__(proto, root)
        self.type = "text_area"
        self._value = InitialValue()

    def set_value(self, v: str | None) -> TextArea:
        """Set the text for the next run (``None`` clears it)."""
        self._value = v
        return self

    @property
    def _widget_state(self) -> WidgetState:
        msg = WidgetState()
        msg.id = self.id
        text = self.value
        if text is not None:
            msg.string_value = text
        return msg

    @property
    def value(self) -> str | None:
        """The current text: the test-set value, else the session-state value."""
        if isinstance(self._value, InitialValue):
            session = self.root.session_state
            assert session
            return session[self.id]  # type: ignore
        return self._value

    def input(self, v: str) -> TextArea:
        """Set the text unless it exceeds ``max_chars``.

        A falsy ``max_chars`` disables the length check. Text that is too long
        leaves the widget unchanged.
        """
        # TODO: should input be setting or appending?
        if not self.max_chars or len(v) <= self.max_chars:
            return self.set_value(v)
        return self
@dataclass(repr=False)
class TextInput(Widget):
    """A representation of ``st.text_input``."""

    _value: str | None | InitialValue
    proto: TextInputProto = field(repr=False)
    label: str
    max_chars: int
    autocomplete: str
    placeholder: str
    help: str
    form_id: str

    def __init__(self, proto: TextInputProto, root: ElementTree) -> None:
        super().__init__(proto, root)
        self.type = "text_input"
        self._value = InitialValue()

    def set_value(self, v: str | None) -> TextInput:
        """Set the text for the next run (``None`` clears it)."""
        self._value = v
        return self

    @property
    def _widget_state(self) -> WidgetState:
        msg = WidgetState()
        msg.id = self.id
        text = self.value
        if text is not None:
            msg.string_value = text
        return msg

    @property
    def value(self) -> str | None:
        """The current text: the test-set value, else the session-state value."""
        if isinstance(self._value, InitialValue):
            session = self.root.session_state
            assert session
            return session[self.id]  # type: ignore
        return self._value

    def input(self, v: str) -> TextInput:
        """Set the text unless it exceeds ``max_chars``.

        A falsy ``max_chars`` disables the length check. Text that is too long
        leaves the widget unchanged.
        """
        # TODO: should input be setting or appending?
        if not self.max_chars or len(v) <= self.max_chars:
            return self.set_value(v)
        return self
# Values accepted by ``TimeInput.set_value``. A ``datetime`` is reduced to its
# ``time`` component when read back through ``TimeInput.value``.
TimeValue: TypeAlias = Union[time, datetime]
@dataclass(repr=False)
class TimeInput(Widget):
    """A representation of ``st.time_input``."""

    _value: TimeValue | None | InitialValue
    proto: TimeInputProto = field(repr=False)
    label: str
    step: int
    help: str
    form_id: str

    def __init__(self, proto: TimeInputProto, root: ElementTree) -> None:
        super().__init__(proto, root)
        self.type = "time_input"
        # Sentinel meaning "not set by the test": `value` then reads session state.
        self._value = InitialValue()

    def set_value(self, v: TimeValue | None) -> TimeInput:
        """Set the value of the widget."""
        self._value = v
        return self

    @property
    def _widget_state(self) -> WidgetState:
        state = WidgetState()
        state.id = self.id
        serialized = TimeInputSerde(None).serialize(self.value)
        if serialized is not None:
            state.string_value = serialized
        return state

    @property
    def value(self) -> time | None:
        """The current value of the widget. (time)"""  # noqa: D400
        if isinstance(self._value, InitialValue):
            session = self.root.session_state
            assert session
            return session[self.id]  # type: ignore
        current = self._value
        # A datetime set by the test is reported as its time of day only.
        return current.time() if isinstance(current, datetime) else current

    def _shifted_by(self, seconds: int) -> TimeInput:
        """Move the current time by ``seconds``; no-op when the value is None."""
        current = self.value
        if current is None:
            return self
        # Attach an arbitrary date so timedelta arithmetic works; only the
        # resulting time of day is kept, so the result wraps around midnight.
        moved = datetime.combine(date.today(), current) + timedelta(seconds=seconds)
        return self.set_value(moved.time())

    def increment(self) -> TimeInput:
        """Select the next available time."""
        return self._shifted_by(self.step)

    def decrement(self) -> TimeInput:
        """Select the previous available time."""
        return self._shifted_by(-self.step)
@dataclass(repr=False)
class Toast(Element):
    """A representation of ``st.toast``."""

    proto: ToastProto = field(repr=False)
    icon: str

    def __init__(self, proto: ToastProto, root: ElementTree) -> None:
        self.type = "toast"
        self.key = None
        self.root = root
        self.proto = proto

    @property
    def value(self) -> str:
        """The body text of the toast. (str)"""  # noqa: D400
        return self.proto.body
@dataclass(repr=False)
class Toggle(Widget):
    """A representation of ``st.toggle``."""

    _value: bool | None
    proto: CheckboxProto = field(repr=False)
    label: str
    help: str
    form_id: str

    def __init__(self, proto: CheckboxProto, root: ElementTree) -> None:
        super().__init__(proto, root)
        self.type = "toggle"
        # None means "not set by the test": `value` then reads session state.
        self._value = None

    @property
    def _widget_state(self) -> WidgetState:
        state = WidgetState()
        state.id = self.id
        state.bool_value = self.value
        return state

    @property
    def value(self) -> bool:
        """The current value of the widget. (bool)"""  # noqa: D400
        if self._value is None:
            session = self.root.session_state
            assert session
            return cast("bool", session[self.id])
        return self._value

    def set_value(self, v: bool) -> Toggle:
        """Set the value of the widget."""
        self._value = v
        return self
@dataclass(repr=False)
class Block:
    """A container of other elements.

    Elements within a Block can be inspected and interacted with. This follows
    the same syntax as inspecting and interacting within an ``AppTest`` object.

    For all container classes, parameters of the original element can be
    obtained as properties. For example, ``ChatMessage.avatar`` and
    ``Tab.label``.
    """

    type: str
    children: dict[int, Node]
    proto: Any = field(repr=False)
    root: ElementTree = field(repr=False)

    def __init__(
        self,
        proto: BlockProto | None,
        root: ElementTree,
    ) -> None:
        self.root = root
        self.proto = proto
        self.children = {}
        if not proto:
            self.type = "unknown"
        else:
            # `st.container` has no sub-message, so no oneof field is set.
            block_type = proto.WhichOneof("type")
            self.type = "container" if block_type is None else block_type

    def __len__(self) -> int:
        """Return the number of direct children of this block."""
        return len(self.children)

    def __iter__(self) -> Iterator[Node]:
        """Yield this block, then everything each child yields.

        Children are visited in the insertion order of ``children``.
        """
        yield self
        for child in self.children.values():
            yield from child

    def __getitem__(self, k: int) -> Node:
        """Return the direct child stored at index ``k``."""
        return self.children[k]

    @property
    def key(self) -> str | None:
        """Blocks never carry a user key."""
        return None

    # We could implement these using __getattr__ but that would have
    # much worse type information.
    @property
    def button(self) -> WidgetList[Button]:
        """``st.button`` widgets within this block."""
        return WidgetList(self.get("button"))  # type: ignore

    @property
    def button_group(self) -> WidgetList[ButtonGroup[Any]]:
        """Button-group widgets within this block."""
        return WidgetList(self.get("button_group"))  # type: ignore

    @property
    def caption(self) -> ElementList[Caption]:
        """``st.caption`` elements within this block."""
        return ElementList(self.get("caption"))  # type: ignore

    @property
    def chat_input(self) -> WidgetList[ChatInput]:
        """``st.chat_input`` widgets within this block."""
        return WidgetList(self.get("chat_input"))  # type: ignore

    @property
    def chat_message(self) -> Sequence[ChatMessage]:
        """``st.chat_message`` containers within this block."""
        return self.get("chat_message")  # type: ignore

    @property
    def checkbox(self) -> WidgetList[Checkbox]:
        """``st.checkbox`` widgets within this block."""
        return WidgetList(self.get("checkbox"))  # type: ignore

    @property
    def code(self) -> ElementList[Code]:
        """``st.code`` elements within this block."""
        return ElementList(self.get("code"))  # type: ignore

    @property
    def color_picker(self) -> WidgetList[ColorPicker]:
        """``st.color_picker`` widgets within this block."""
        return WidgetList(self.get("color_picker"))  # type: ignore

    @property
    def columns(self) -> Sequence[Column]:
        """Column containers (type ``"column"``) within this block."""
        return self.get("column")  # type: ignore

    @property
    def dataframe(self) -> ElementList[Dataframe]:
        """Dataframe elements (type ``"arrow_data_frame"``) within this block."""
        return ElementList(self.get("arrow_data_frame"))  # type: ignore

    @property
    def date_input(self) -> WidgetList[DateInput]:
        """``st.date_input`` widgets within this block."""
        return WidgetList(self.get("date_input"))  # type: ignore

    @property
    def divider(self) -> ElementList[Divider]:
        """``st.divider`` elements within this block."""
        return ElementList(self.get("divider"))  # type: ignore

    @property
    def error(self) -> ElementList[Error]:
        """``st.error`` elements within this block."""
        return ElementList(self.get("error"))  # type: ignore

    @property
    def exception(self) -> ElementList[Exception]:
        """``st.exception`` elements within this block."""
        return ElementList(self.get("exception"))  # type: ignore

    @property
    def expander(self) -> Sequence[Expander]:
        """``st.expander`` containers within this block."""
        return self.get("expander")  # type: ignore

    @property
    def header(self) -> ElementList[Header]:
        """``st.header`` elements within this block."""
        return ElementList(self.get("header"))  # type: ignore

    @property
    def info(self) -> ElementList[Info]:
        """``st.info`` elements within this block."""
        return ElementList(self.get("info"))  # type: ignore

    @property
    def json(self) -> ElementList[Json]:
        """``st.json`` elements within this block."""
        return ElementList(self.get("json"))  # type: ignore

    @property
    def latex(self) -> ElementList[Latex]:
        """``st.latex`` elements within this block."""
        return ElementList(self.get("latex"))  # type: ignore

    @property
    def markdown(self) -> ElementList[Markdown]:
        """``st.markdown`` elements within this block."""
        return ElementList(self.get("markdown"))  # type: ignore

    @property
    def metric(self) -> ElementList[Metric]:
        """``st.metric`` elements within this block."""
        return ElementList(self.get("metric"))  # type: ignore

    @property
    def multiselect(self) -> WidgetList[Multiselect[Any]]:
        """``st.multiselect`` widgets within this block."""
        return WidgetList(self.get("multiselect"))  # type: ignore

    @property
    def number_input(self) -> WidgetList[NumberInput]:
        """``st.number_input`` widgets within this block."""
        return WidgetList(self.get("number_input"))  # type: ignore

    @property
    def radio(self) -> WidgetList[Radio[Any]]:
        """``st.radio`` widgets within this block."""
        return WidgetList(self.get("radio"))  # type: ignore

    @property
    def select_slider(self) -> WidgetList[SelectSlider[Any]]:
        """``st.select_slider`` widgets within this block."""
        return WidgetList(self.get("select_slider"))  # type: ignore

    @property
    def selectbox(self) -> WidgetList[Selectbox[Any]]:
        """``st.selectbox`` widgets within this block."""
        return WidgetList(self.get("selectbox"))  # type: ignore

    @property
    def slider(self) -> WidgetList[Slider[Any]]:
        """``st.slider`` widgets within this block."""
        return WidgetList(self.get("slider"))  # type: ignore

    @property
    def status(self) -> Sequence[Status]:
        """``st.status`` containers within this block."""
        return self.get("status")  # type: ignore

    @property
    def subheader(self) -> ElementList[Subheader]:
        """``st.subheader`` elements within this block."""
        return ElementList(self.get("subheader"))  # type: ignore

    @property
    def success(self) -> ElementList[Success]:
        """``st.success`` elements within this block."""
        return ElementList(self.get("success"))  # type: ignore

    @property
    def table(self) -> ElementList[Table]:
        """Table elements (type ``"arrow_table"``) within this block."""
        return ElementList(self.get("arrow_table"))  # type: ignore

    @property
    def tabs(self) -> Sequence[Tab]:
        """Tab containers (type ``"tab"``) within this block."""
        return self.get("tab")  # type: ignore

    @property
    def text(self) -> ElementList[Text]:
        """``st.text`` elements within this block."""
        return ElementList(self.get("text"))  # type: ignore

    @property
    def text_area(self) -> WidgetList[TextArea]:
        """``st.text_area`` widgets within this block."""
        return WidgetList(self.get("text_area"))  # type: ignore

    @property
    def text_input(self) -> WidgetList[TextInput]:
        """``st.text_input`` widgets within this block."""
        return WidgetList(self.get("text_input"))  # type: ignore

    @property
    def time_input(self) -> WidgetList[TimeInput]:
        """``st.time_input`` widgets within this block."""
        return WidgetList(self.get("time_input"))  # type: ignore

    @property
    def title(self) -> ElementList[Title]:
        """``st.title`` elements within this block."""
        return ElementList(self.get("title"))  # type: ignore

    @property
    def toast(self) -> ElementList[Toast]:
        """``st.toast`` elements within this block."""
        return ElementList(self.get("toast"))  # type: ignore

    @property
    def toggle(self) -> WidgetList[Toggle]:
        """``st.toggle`` widgets within this block."""
        return WidgetList(self.get("toggle"))  # type: ignore

    @property
    def warning(self) -> ElementList[Warning]:
        """``st.warning`` elements within this block."""
        return ElementList(self.get("warning"))  # type: ignore

    def get(self, element_type: str) -> Sequence[Node]:
        """Return every node in this subtree whose ``type`` is ``element_type``.

        The block itself is included in the search. Results follow
        iteration order.
        """
        return [node for node in self if node.type == element_type]

    def run(self, *, timeout: float | None = None) -> AppTest:
        """Run the script with updated widget values.

        Parameters
        ----------
        timeout
            The maximum number of seconds to run the script. None means
            use the AppTest's default.
        """
        tree = self.root
        return tree.run(timeout=timeout)

    def __repr__(self) -> str:
        return repr_(self)
def repr_(self: object) -> str:
    """A custom repr similar to `streamlit.util.repr_` but that shows tree
    structure using indentation.

    For dataclasses, only fields declared with ``repr=True`` are shown, and
    only when they differ from both the field default and the generic
    "empty" values (None, "", False, [], set(), {}). For other objects, the
    same emptiness filter is applied to ``__dict__``. Dict values are
    rendered with ``format_dict`` so nested children stay indented.

    An object with no remaining fields renders as ``ClassName()``.
    """
    classname = self.__class__.__name__
    defaults: list[Any] = [None, "", False, [], set(), {}]

    fields_vals: list[tuple[str, Any]]
    if is_dataclass(self):
        fields_vals = []
        for f in fields(self):
            if not f.repr:
                continue
            # Read each attribute once instead of once per filter condition.
            value = getattr(self, f.name)
            if value != f.default and value not in defaults:
                fields_vals.append((f.name, value))
    else:
        fields_vals = [(k, v) for k, v in self.__dict__.items() if v not in defaults]

    # Previously an object whose fields were all filtered out crashed on
    # `reprs[0]` with IndexError; render it as an empty call instead.
    if not fields_vals:
        return f"{classname}()"

    reprs = [
        f"{field_name}={format_dict(value)}"
        if isinstance(value, dict)
        else f"{field_name}={value!r}"
        for field_name, value in fields_vals
    ]
    # The leading newline puts the first field on its own (indented) line.
    field_reprs = textwrap.indent("\n" + ",\n".join(reprs), " " * 4)
    return f"{classname}({field_reprs}\n)"
def format_dict(d: dict[Any, Any]) -> str:
    """Render ``d`` as a brace-wrapped, one-entry-per-line listing.

    Keys are formatted with ``str`` and values with ``repr``. Every
    non-blank line of the body is indented by four spaces, so multi-line
    value reprs nest visually.
    """
    entries = ",\n".join(f"{key}: {val!r}" for key, val in d.items())
    body = textwrap.indent(entries, "    ")
    return "{\n" + body + "\n}"
@dataclass(repr=False)
class SpecialBlock(Block):
    """Base class for the sidebar and main body containers."""

    def __init__(
        self,
        proto: BlockProto | None,
        root: ElementTree,
        type: str | None = None,
    ) -> None:
        self.root = root
        self.proto = proto
        self.children = {}
        # An explicitly supplied type takes precedence over the proto's oneof.
        if type:
            self.type = type
            return
        proto_type = proto.WhichOneof("type") if proto else None
        self.type = proto_type or "unknown"
@dataclass(repr=False)
class ChatMessage(Block):
    """A representation of ``st.chat_message``.

    ``name`` and ``avatar`` are copied from the block proto on construction.
    """

    type: str = field(repr=False)
    proto: BlockProto.ChatMessage = field(repr=False)
    name: str
    avatar: str

    def __init__(
        self,
        proto: BlockProto.ChatMessage,
        root: ElementTree,
    ) -> None:
        self.type = "chat_message"
        self.proto = proto
        self.root = root
        self.children = {}
        self.name, self.avatar = proto.name, proto.avatar
@dataclass(repr=False)
class Column(Block):
    """A representation of a column within ``st.columns``.

    ``weight`` and ``gap`` are copied from the block proto on construction.
    """

    type: str = field(repr=False)
    proto: BlockProto.Column = field(repr=False)
    weight: float
    gap: str

    def __init__(
        self,
        proto: BlockProto.Column,
        root: ElementTree,
    ) -> None:
        self.type = "column"
        self.proto = proto
        self.root = root
        self.children = {}
        self.weight, self.gap = proto.weight, proto.gap
@dataclass(repr=False)
class Expander(Block):
    """A representation of ``st.expander``.

    ``icon`` and ``label`` are copied from the Expandable block proto.
    """

    type: str = field(repr=False)
    proto: BlockProto.Expandable = field(repr=False)
    icon: str
    label: str

    def __init__(self, proto: BlockProto.Expandable, root: ElementTree) -> None:
        # The internal name is "expandable" but the public API uses "expander"
        # so the naming of the class and type follows the public name.
        self.type = "expander"
        self.proto = proto
        self.root = root
        self.children = {}
        self.label = proto.label
        self.icon = proto.icon
@dataclass(repr=False)
class Status(Block):
    """A representation of ``st.status``.

    It shares the Expandable block proto with ``st.expander``; its state is
    encoded in ``icon``.
    """

    type: str = field(repr=False)
    proto: BlockProto.Expandable = field(repr=False)
    icon: str
    label: str

    def __init__(self, proto: BlockProto.Expandable, root: ElementTree) -> None:
        self.type = "status"
        self.proto = proto
        self.root = root
        self.children = {}
        self.label = proto.label
        self.icon = proto.icon

    @property
    def state(self) -> str:
        """The status state: ``"running"``, ``"complete"``, or ``"error"``.

        Raises
        ------
        ValueError
            If ``icon`` is not one of the three recognized status icons.
        """
        icon_to_state = {
            "spinner": "running",
            ":material/check:": "complete",
            ":material/error:": "error",
        }
        state = icon_to_state.get(self.icon)
        if state is None:
            raise ValueError("Unknown Status state")
        return state
@dataclass(repr=False)
class Tab(Block):
    """A representation of tab within ``st.tabs``.

    ``label`` is copied from the Tab block proto on construction.
    """

    type: str = field(repr=False)
    proto: BlockProto.Tab = field(repr=False)
    label: str

    def __init__(
        self,
        proto: BlockProto.Tab,
        root: ElementTree,
    ) -> None:
        self.type = "tab"
        self.proto = proto
        self.root = root
        self.children = {}
        self.label = proto.label
# Any node of the element tree: a leaf element (including widgets) or a
# container block.
Node: TypeAlias = Union[Element, Block]
def get_widget_state(node: Node) -> WidgetState | None:
    """Return ``node``'s widget state if it is a ``Widget``, otherwise ``None``."""
    if not isinstance(node, Widget):
        return None
    return node._widget_state
@dataclass(repr=False)
class ElementTree(Block):
    """A tree of the elements produced by running a streamlit script.

    Elements can be queried in three ways:
    - By element type, using `.foo` properties to get a list of all of that element,
      in the order they appear in the app
    - By user key, for widgets, by calling the above list with a key: `.foo(key='bar')`
    - Positionally, using list indexing syntax (`[...]`) to access a child of a
      block element. Not recommended because the exact tree structure can be surprising.

    Element queries made on a block container will return only the elements
    descending from that block.

    Returned elements have methods for accessing whatever attributes are relevant.
    For very simple elements this may be only its value, while complex elements
    like widgets have many.

    Widgets provide a fluent API for faking frontend interaction and rerunning
    the script with the new widget values. All widgets provide a low level `set_value`
    method, along with higher level methods specific to that type of widget.
    After an interaction, calling `.run()` will update the AppTest with the
    results of that script run.
    """

    _runner: AppTest | None = field(repr=False, default=None)

    def __init__(self) -> None:
        self.children = {}
        self.root = self
        self.type = "root"
        # The root is not backed by a block proto. `proto` is a default-less
        # dataclass field inherited from Block, so it must be assigned
        # explicitly; otherwise reading it raises AttributeError.
        self.proto = None

    @property
    def main(self) -> Block:
        """The block stored as child ``0`` of the root."""
        m = self[0]
        assert isinstance(m, Block)
        return m

    @property
    def sidebar(self) -> Block:
        """The block stored as child ``1`` of the root."""
        s = self[1]
        assert isinstance(s, Block)
        return s

    @property
    def session_state(self) -> SafeSessionState:
        """Session state of the attached runner (requires ``_runner``)."""
        assert self._runner is not None
        return self._runner.session_state

    def get_widget_states(self) -> WidgetStates:
        """Collect the widget state of every widget node in the tree."""
        ws = WidgetStates()
        for node in self:
            w = get_widget_state(node)
            if w is not None:
                ws.widgets.append(w)
        return ws

    def run(self, *, timeout: float | None = None) -> AppTest:
        """Run the script with updated widget values.

        Parameters
        ----------
        timeout
            The maximum number of seconds to run the script. None means
            use the AppTest's default.
        """
        assert self._runner is not None
        widget_states = self.get_widget_states()
        return self._runner._run(widget_states, timeout=timeout)

    def __repr__(self) -> str:
        return format_dict(self.children)
# Icons that `Status.state` recognizes. `st.status` and `st.expander` share
# the Expandable block proto, and expanders may carry an icon too. Only these
# values therefore identify a Status container.
_STATUS_ICONS = frozenset({"spinner", ":material/check:", ":material/error:"})


def _node_from_new_element(elt: ElementProto, root: ElementTree) -> Node:
    """Build the tree node for the element carried by a ``new_element`` delta.

    Raises
    ------
    ValueError
        If an alert, heading, markdown, or slider element has an
        unrecognized sub-type.
    """
    ty = elt.WhichOneof("type")

    if ty == "alert":
        alert_format = elt.alert.format
        if alert_format == AlertProto.Format.ERROR:
            return Error(elt.alert, root=root)
        if alert_format == AlertProto.Format.INFO:
            return Info(elt.alert, root=root)
        if alert_format == AlertProto.Format.SUCCESS:
            return Success(elt.alert, root=root)
        if alert_format == AlertProto.Format.WARNING:
            return Warning(elt.alert, root=root)
        raise ValueError(f"Unknown alert type with format {elt.alert.format}")

    if ty == "checkbox":
        # Toggles are sent as checkboxes with a TOGGLE style.
        if elt.checkbox.type == CheckboxProto.StyleType.TOGGLE:
            return Toggle(elt.checkbox, root=root)
        return Checkbox(elt.checkbox, root=root)

    if ty == "heading":
        tag = elt.heading.tag
        if tag == HeadingProtoTag.TITLE_TAG.value:
            return Title(elt.heading, root=root)
        if tag == HeadingProtoTag.HEADER_TAG.value:
            return Header(elt.heading, root=root)
        if tag == HeadingProtoTag.SUBHEADER_TAG.value:
            return Subheader(elt.heading, root=root)
        raise ValueError(f"Unknown heading type with tag {elt.heading.tag}")

    if ty == "markdown":
        md_type = elt.markdown.element_type
        if md_type == MarkdownProto.Type.NATIVE:
            return Markdown(elt.markdown, root=root)
        if md_type == MarkdownProto.Type.CAPTION:
            return Caption(elt.markdown, root=root)
        if md_type == MarkdownProto.Type.LATEX:
            return Latex(elt.markdown, root=root)
        if md_type == MarkdownProto.Type.DIVIDER:
            return Divider(elt.markdown, root=root)
        raise ValueError(f"Unknown markdown type {elt.markdown.element_type}")

    if ty == "slider":
        if elt.slider.type == SliderProto.Type.SLIDER:
            return Slider(elt.slider, root=root)
        if elt.slider.type == SliderProto.Type.SELECT_SLIDER:
            return SelectSlider(elt.slider, root=root)
        raise ValueError(f"Slider with unknown type {elt.slider}")

    # Element types that map one-to-one onto a class constructed from the
    # sub-message of the same name.
    simple_classes: dict[str, Callable[..., Node]] = {
        "arrow_data_frame": Dataframe,
        "arrow_table": Table,
        "button": Button,
        "button_group": ButtonGroup,
        "chat_input": ChatInput,
        "code": Code,
        "color_picker": ColorPicker,
        "date_input": DateInput,
        "exception": Exception,
        "json": Json,
        "metric": Metric,
        "multiselect": Multiselect,
        "number_input": NumberInput,
        "radio": Radio,
        "selectbox": Selectbox,
        "text": Text,
        "text_area": TextArea,
        "text_input": TextInput,
        "time_input": TimeInput,
        "toast": Toast,
    }
    if ty is not None and ty in simple_classes:
        return simple_classes[ty](getattr(elt, ty), root=root)
    return UnknownElement(elt, root=root)


def _node_from_add_block(block: BlockProto, root: ElementTree) -> Block:
    """Build the container node for an ``add_block`` delta."""
    bty = block.WhichOneof("type")
    if bty == "chat_message":
        return ChatMessage(block.chat_message, root=root)
    if bty == "column":
        return Column(block.column, root=root)
    if bty == "expandable":
        # An expander with a custom icon must not be mistaken for a Status,
        # whose `state` property only understands the status icons.
        if block.expandable.icon in _STATUS_ICONS:
            return Status(block.expandable, root=root)
        return Expander(block.expandable, root=root)
    if bty == "tab":
        return Tab(block.tab, root=root)
    return Block(proto=block, root=root)


def parse_tree_from_messages(messages: list[ForwardMsg]) -> ElementTree:
    """Transform a list of `ForwardMsg` into a tree matching the implicit
    tree structure of blocks and elements in a streamlit app.

    Returns the root of the tree, which acts as the entrypoint for the query
    and interaction API.

    Messages without a delta, and deltas that add neither an element nor a
    block (e.g. ``add_rows``), are skipped. Intermediate blocks missing along
    a delta path are created as placeholder ``Block`` objects.

    Raises
    ------
    ValueError
        If an alert, heading, markdown, or slider element has an
        unrecognized sub-type.
    """
    root = ElementTree()
    root.children = {
        0: SpecialBlock(type="main", root=root, proto=None),
        1: SpecialBlock(type="sidebar", root=root, proto=None),
        2: SpecialBlock(type="event", root=root, proto=None),
    }

    for msg in messages:
        if not msg.HasField("delta"):
            continue
        delta_path = msg.metadata.delta_path
        delta = msg.delta
        delta_type = delta.WhichOneof("type")

        new_node: Node
        if delta_type == "new_element":
            new_node = _node_from_new_element(delta.new_element, root)
        elif delta_type == "add_block":
            new_node = _node_from_add_block(delta.add_block, root)
        else:
            # add_rows
            continue

        current_node: Block = root
        # Every node up to the end is a Block
        for idx in delta_path[:-1]:
            children = current_node.children
            child = children.get(idx)
            if child is None:
                child = Block(proto=None, root=root)
                children[idx] = child
            assert isinstance(child, Block)
            current_node = child

        slot = delta_path[-1]
        if isinstance(new_node, Block):
            existing = current_node.children.get(slot)
            # The slot may hold a placeholder Block created above for child
            # paths that arrived first; keep its children. If it holds an
            # element instead (one being replaced by a container), there are
            # no children to carry over.
            if isinstance(existing, Block):
                new_node.children = existing.children
        current_node.children[slot] = new_node

    return root