277 lines
9 KiB
Python
277 lines
9 KiB
Python
# Copyright (c) Streamlit Inc. (2018-2022) Snowflake Inc. (2022-2025)
|
|
#
|
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
# you may not use this file except in compliance with the License.
|
|
# You may obtain a copy of the License at
|
|
#
|
|
# http://www.apache.org/licenses/LICENSE-2.0
|
|
#
|
|
# Unless required by applicable law or agreed to in writing, software
|
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
# See the License for the specific language governing permissions and
|
|
# limitations under the License.
|
|
|
|
from __future__ import annotations
|
|
|
|
import ast
|
|
import sys
|
|
from typing import Any, Final
|
|
|
|
from streamlit import config
|
|
|
|
# When a Streamlit app is magicified, an import of the magic helpers is
# inserted near the top of its module's AST:
#
#     import streamlit.runtime.scriptrunner.magic_funcs as __streamlitmagic__
#
# Every rewritten expression then calls into that module through this alias.
MAGIC_MODULE_NAME: Final[str] = "__streamlitmagic__"
|
|
|
|
|
|
def add_magic(code: str, script_path: str) -> Any:
    """Parse ``code`` and rewrite its AST so Streamlit magic commands work.

    Parameters
    ----------
    code : str
        Source text of the script.
    script_path : str
        Path of the script. It is handed to the parser as the filename so
        syntax errors point at the real file.

    Returns
    -------
    ast.Module
        The parsed module. Magic rewrites and the magic-module import have
        been applied to it in place.
    """
    # Giving the parser the real path makes SyntaxError messages readable.
    module = ast.parse(code, script_path, "exec")

    # The semicolon check runs on the pristine tree, before any rewriting.
    _modify_ast_subtree(
        module,
        is_root=True,
        file_ends_in_semicolon=_does_file_end_in_semicolon(module, code),
    )
    return module
|
|
|
|
|
|
def _modify_ast_subtree(
    tree: Any,
    body_attr: str = "body",
    is_root: bool = False,
    file_ends_in_semicolon: bool = False,
) -> None:
    """Apply magic rewrites to one statement list of an AST node.

    Walks ``getattr(tree, body_attr)``. For each statement it either recurses
    into the nested statement lists the statement owns, or, for bare
    expression statements, replaces the expression with the write call
    produced by ``_get_st_write_from_expr``.

    Parameters
    ----------
    tree : Any
        Node whose statement list is processed (a module, a compound
        statement, an except handler or a match case).
    body_attr : str
        Attribute of ``tree`` that holds the statement list, e.g. ``"body"``,
        ``"orelse"`` or ``"finalbody"``.
    is_root : bool
        True only for the top-level module. It is forwarded to the expression
        rewriter and triggers insertion of the magic import.
    file_ends_in_semicolon : bool
        Whether the script's final statement is followed by a semicolon.
    """
    statements = getattr(tree, body_attr)
    # Only expression values are replaced below, so the list length is fixed.
    last_index = len(statements) - 1

    for index, statement in enumerate(statements):
        kind = type(statement)

        if kind in (ast.FunctionDef, ast.AsyncFunctionDef, ast.With, ast.AsyncWith):
            # Function definitions and `with` blocks: descend into `body`.
            _modify_ast_subtree(statement)

        elif kind in (ast.For, ast.AsyncFor, ast.While):
            # Loops: the body plus the optional `else:` clause.
            _modify_ast_subtree(statement)
            _modify_ast_subtree(statement, "orelse")

        elif kind is ast.ClassDef:
            # Classes: only their (async) methods are visited.
            for member in statement.body:
                if type(member) in (ast.FunctionDef, ast.AsyncFunctionDef):
                    _modify_ast_subtree(member)

        elif kind is ast.Try or (
            sys.version_info >= (3, 11) and kind is ast.TryStar
        ):
            # try / try*: body, finally, else, then every except handler.
            _modify_ast_subtree(statement)
            _modify_ast_subtree(statement, body_attr="finalbody")
            _modify_ast_subtree(statement, body_attr="orelse")
            for handler in statement.handlers:
                _modify_ast_subtree(handler)

        elif kind is ast.If:
            # `elif`/`else` both live in `orelse`. The test expression is
            # intentionally not rewritten.
            _modify_ast_subtree(statement)
            _modify_ast_subtree(statement, "orelse")

        elif sys.version_info >= (3, 10) and kind is ast.Match:
            for case_clause in statement.cases:
                _modify_ast_subtree(case_clause)

        elif kind is ast.Expr:
            # Standalone expression: maybe wrap it in a write call.
            replacement = _get_st_write_from_expr(
                statement,
                index,
                parent_type=type(tree),
                is_root=is_root,
                is_last_expr=index == last_index,
                file_ends_in_semicolon=file_ends_in_semicolon,
            )
            if replacement is not None:
                statement.value = replacement

    if is_root:
        # The rewritten expressions reference the magic module, so import it.
        _insert_import_statement(tree)

    ast.fix_missing_locations(tree)
|
|
|
|
|
|
def _insert_import_statement(tree: Any) -> None:
    """Insert the magic-module import near the top of ``tree``'s body.

    The insertion point is chosen as follows:

    * right after a leading import statement;
    * right after a leading docstring that is directly followed by an import;
    * otherwise at the very top of the module.

    In every case, any ``from __future__`` imports that immediately follow
    that point are skipped. Python only allows future statements to be
    preceded by a docstring and other future statements, so the import must
    never be placed inside a run of them.

    Parameters
    ----------
    tree : Any
        The module AST node. Its ``body`` list is modified in place.
    """
    body = tree.body
    import_types = {ast.ImportFrom, ast.Import}

    # If the 0th node is already an import statement, put the Streamlit
    # import below that, so we don't break "from __future__ import".
    if body and type(body[0]) in import_types:
        position = 1
    # If the 0th node is a docstring and the 1st is an import statement,
    # put the Streamlit import below those for the same reason.
    elif (
        len(body) > 1
        and type(body[0]) is ast.Expr
        and _is_string_constant_node(body[0].value)
        and type(body[1]) in import_types
    ):
        position = 2
    else:
        position = 0

    # A module may start with several consecutive `from __future__` imports.
    # Inserting between them would make the later ones a SyntaxError at
    # compile time, so move past the whole run.
    while (
        position < len(body)
        and type(body[position]) is ast.ImportFrom
        and body[position].module == "__future__"
        and not body[position].level
    ):
        position += 1

    body.insert(position, _build_st_import_statement())
|
|
|
|
|
|
def _build_st_import_statement() -> ast.Import:
    """Build ``import streamlit.runtime.scriptrunner.magic_funcs as <alias>``.

    The alias is ``MAGIC_MODULE_NAME``. No source locations are set on the
    returned node.
    """
    magic_alias = ast.alias(
        name="streamlit.runtime.scriptrunner.magic_funcs",
        asname=MAGIC_MODULE_NAME,
    )
    return ast.Import(names=[magic_alias])
|
|
|
|
|
|
def _build_st_write_call(nodes: list[Any]) -> ast.Call:
    """Build ``<MAGIC_MODULE_NAME>.transparent_write(*nodes)`` as an AST call.

    ``nodes`` itself (not a copy) becomes the call's positional arguments.
    No keyword arguments are passed.
    """
    magic_module = ast.Name(id=MAGIC_MODULE_NAME, ctx=ast.Load())
    write_func = ast.Attribute(
        value=magic_module, attr="transparent_write", ctx=ast.Load()
    )
    return ast.Call(func=write_func, args=nodes, keywords=[])
|
|
|
|
|
|
def _get_st_write_from_expr(
    node: Any,
    i: int,
    parent_type: Any,
    is_root: bool,
    is_last_expr: bool,
    file_ends_in_semicolon: bool,
) -> ast.Call | None:
    """Return the write call that should replace an expression, or None.

    Parameters
    ----------
    node : Any
        An ``ast.Expr`` statement. Its ``value`` is the candidate expression.
    i : int
        Index of ``node`` within its enclosing statement list.
    parent_type : Any
        Type of the node that owns that statement list.
    is_root : bool
        Whether the statement list is the module's top-level body.
    is_last_expr : bool
        Whether ``node`` is the last statement of that list.
    file_ends_in_semicolon : bool
        Whether the script's final statement is followed by a semicolon.

    Returns
    -------
    ast.Call | None
        A ``transparent_write`` call wrapping the expression, or None when
        the expression must be left untouched.
    """
    expr = node.value
    expr_type = type(expr)

    # Function calls are left alone, except for the final root-level
    # expression when magic.displayLastExprIfNoSemicolon enables
    # notebook-like display of the last value.
    if expr_type is ast.Call and not _is_displayable_last_expr(
        is_root, is_last_expr, file_ends_in_semicolon
    ):
        return None

    # Docstrings are left alone, except that magic.displayRootDocString can
    # request the root-level docstring be written (markdown-cell style).
    if _is_docstring_node(expr, i, parent_type) and not (
        _should_display_docstring_like_node_anyway(is_root)
    ):
        return None

    # yield, yield from and await expressions are never wrapped.
    if expr_type in (ast.Yield, ast.YieldFrom, ast.Await):
        return None

    # A trailing comma turns a statement into a tuple, whose elements become
    # separate write arguments, e.g. "np.random.randn(1000, 2),".
    if expr_type is ast.Tuple:
        return _build_st_write_call(expr.elts)
    return _build_st_write_call([expr])
|
|
|
|
|
|
def _is_string_constant_node(node: Any) -> bool:
    """Tell whether ``node`` is an ``ast.Constant`` whose value is a ``str``."""
    if not isinstance(node, ast.Constant):
        return False
    return isinstance(node.value, str)
|
|
|
|
|
|
def _is_docstring_node(node: Any, node_index: int, parent_type: Any) -> bool:
    """Tell whether an expression is the docstring of its enclosing scope.

    It must be the first statement's expression, be a string constant, and
    sit directly in a function (sync or async) or a module.
    """
    if node_index != 0:
        return False
    if not _is_string_constant_node(node):
        return False
    return parent_type in (ast.FunctionDef, ast.AsyncFunctionDef, ast.Module)
|
|
|
|
|
|
def _does_file_end_in_semicolon(tree: Any, code: str) -> bool:
    """Tell whether the script's last top-level statement is followed by ``;``.

    Only the text after the last statement on its final line is inspected.
    A semicolon followed by whitespace (including tabs and carriage returns)
    or by a comment counts. A semicolon that only appears inside a trailing
    comment does not.

    Always returns False when ``magic.displayLastExprIfNoSemicolon`` is
    disabled, without inspecting the code.

    Parameters
    ----------
    tree : Any
        Module AST produced by parsing ``code``.
    code : str
        The source text ``tree`` was parsed from.

    Returns
    -------
    bool
        True if a ``;`` directly follows the last statement.
    """
    # Avoid spending time with this operation if
    # magic.displayLastExprIfNoSemicolon is not set.
    if not config.get_option("magic.displayLastExprIfNoSemicolon"):
        return False

    if not tree.body:
        return False

    last_node = tree.body[-1]
    last_line_num = getattr(last_node, "end_lineno", None)
    if last_line_num is None:
        return False

    # Python accepts LF, CRLF and lone CR as line terminators and numbers
    # lines accordingly. Normalize so that indexing by end_lineno lines up.
    lines = code.replace("\r\n", "\n").replace("\r", "\n").split("\n")
    if not 1 <= last_line_num <= len(lines):
        return False
    last_line = lines[last_line_num - 1]

    end_col = getattr(last_node, "end_col_offset", None)
    if end_col is None:
        # No column info: fall back to looking at the end of the line.
        return last_line.strip().endswith(";")

    # end_col_offset is a UTF-8 byte offset, so slice the encoded line.
    after_stmt = last_line.encode("utf-8")[end_col:].decode("utf-8", "replace")
    return after_stmt.lstrip().startswith(";")
|
|
|
|
|
|
def _is_displayable_last_expr(
    is_root: bool, is_last_expr: bool, file_ends_in_semicolon: bool
) -> bool:
    """Decide whether an expression gets notebook-style "show last value".

    The expression must be the last statement and must live in the root
    scope. The script must not end with a semicolon, and the
    ``magic.displayLastExprIfNoSemicolon`` option must be on. The option is
    only read once the first three conditions hold.
    """
    if not is_last_expr:
        return False
    if not is_root:
        return False
    if file_ends_in_semicolon:
        return False
    return config.get_option("magic.displayLastExprIfNoSemicolon")
|
|
|
|
|
|
def _should_display_docstring_like_node_anyway(is_root: bool) -> bool:
    """Whether a docstring-like string should be written out anyway.

    This requires the ``magic.displayRootDocString`` option to be enabled
    and the string to be in the root scope.
    """
    show_root_docstring = config.get_option("magic.displayRootDocString")
    return is_root if show_root_docstring else show_root_docstring
|