# Copyright (c) Streamlit Inc. (2018-2022) Snowflake Inc. (2022-2025)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import annotations

import ast
import sys
from typing import Any, Final

from streamlit import config

# When a Streamlit app is magicified, we insert a `magic_funcs` import near the top of
# its module's AST:
#     import streamlit.runtime.scriptrunner.magic_funcs as __streamlitmagic__
MAGIC_MODULE_NAME: Final = "__streamlitmagic__"


def add_magic(code: str, script_path: str) -> Any:
    """Modifies the code to support magic Streamlit commands.

    Parameters
    ----------
    code : str
        The Python code.
    script_path : str
        The path to the script file.

    Returns
    -------
    ast.Module
        The syntax tree for the code.

    """
    # Pass script_path so we get pretty exceptions.
    tree = ast.parse(code, script_path, "exec")

    file_ends_in_semicolon = _does_file_end_in_semicolon(tree, code)

    _modify_ast_subtree(
        tree, is_root=True, file_ends_in_semicolon=file_ends_in_semicolon
    )

    return tree


def _modify_ast_subtree(
    tree: Any,
    body_attr: str = "body",
    is_root: bool = False,
    file_ends_in_semicolon: bool = False,
) -> None:
    """Parses magic commands and modifies the given AST (sub)tree.

    Parameters
    ----------
    tree : Any
        The AST node whose statement list should be processed in place.
    body_attr : str
        Name of the attribute on ``tree`` holding the statement list to walk
        (e.g. "body", "orelse", "finalbody").
    is_root : bool
        True only for the module node. When set, the magic import is inserted
        and missing source locations are fixed after the walk.
    file_ends_in_semicolon : bool
        Whether the last line of the last top-level statement ends in ";".

    """
    body = getattr(tree, body_attr)

    for i, node in enumerate(body):
        node_type = type(node)

        # Recursively parses the content of the statements
        # `with` as well as function definitions.
        # Also covers their async counterparts
        if (
            node_type is ast.FunctionDef
            or node_type is ast.With
            or node_type is ast.AsyncFunctionDef
            or node_type is ast.AsyncWith
        ):
            _modify_ast_subtree(node)

        # Recursively parses the content of the statements
        # `for` and `while`.
        # Also covers their async counterparts
        elif (
            node_type is ast.For or node_type is ast.While or node_type is ast.AsyncFor
        ):
            _modify_ast_subtree(node)
            _modify_ast_subtree(node, "orelse")

        # Recursively parses methods in a class.
        # Only (async) function definitions directly in the class body are
        # visited; other class-level statements are left untouched.
        elif node_type is ast.ClassDef:
            for inner_node in node.body:
                if type(inner_node) in {ast.FunctionDef, ast.AsyncFunctionDef}:
                    _modify_ast_subtree(inner_node)

        # Recursively parses the contents of try statements,
        # all their handlers (except and else) and the finally body.
        # `ast.TryStar` only exists on Python 3.11+, hence the version guard
        # evaluated before the attribute access.
        elif node_type is ast.Try or (
            sys.version_info >= (3, 11) and node_type is ast.TryStar
        ):
            _modify_ast_subtree(node)
            _modify_ast_subtree(node, body_attr="finalbody")
            _modify_ast_subtree(node, body_attr="orelse")
            for handler_node in node.handlers:
                _modify_ast_subtree(handler_node)

        # Recursively parses if blocks, as well as their else/elif blocks
        # (else/elif are both mapped to orelse)
        # it intentionally does not parse the test expression.
        elif node_type is ast.If:
            _modify_ast_subtree(node)
            _modify_ast_subtree(node, "orelse")

        # Recursively parses the body of every `case` in a `match` statement.
        # `ast.Match` only exists on Python 3.10+, hence the version guard.
        elif sys.version_info >= (3, 10) and node_type is ast.Match:
            for case_node in node.cases:
                _modify_ast_subtree(case_node)

        # Convert standalone expression nodes to st.write
        elif node_type is ast.Expr:
            value = _get_st_write_from_expr(
                node,
                i,
                parent_type=type(tree),
                is_root=is_root,
                is_last_expr=(i == len(body) - 1),
                file_ends_in_semicolon=file_ends_in_semicolon,
            )
            if value is not None:
                node.value = value

    if is_root:
        # Import Streamlit so we can use it in the new_value above.
        _insert_import_statement(tree)

    ast.fix_missing_locations(tree)


def _is_future_import_node(node: Any) -> bool:
    """Return True if the node is a `from __future__ import ...` statement."""
    return (
        type(node) is ast.ImportFrom and node.module == "__future__" and not node.level
    )


def _insert_import_statement(tree: Any) -> None:
    """Insert Streamlit import statement at the top(ish) of the tree.

    The import is placed after the first leading import statement (and after
    the module docstring, if one precedes that import). Any further
    `from __future__ import ...` statements that directly follow are skipped
    as well, because future statements must all appear before any other
    statement in the module.
    """
    st_import = _build_st_import_statement()
    body = tree.body

    # If the 0th node is already an import statement, put the Streamlit
    # import below that, so we don't break "from __future__ import".
    if body and type(body[0]) in {ast.ImportFrom, ast.Import}:
        insert_at = 1

    # If the 0th node is a docstring and the 1st is an import statement,
    # put the Streamlit import below those, so we don't break "from
    # __future__ import".
    elif (
        len(body) > 1
        and (type(body[0]) is ast.Expr and _is_string_constant_node(body[0].value))
        and type(body[1]) in {ast.ImportFrom, ast.Import}
    ):
        insert_at = 2

    else:
        insert_at = 0

    # A script may contain several separate "from __future__ import" lines.
    # Inserting between them would make the later ones fail to compile with
    # "from __future__ imports must occur at the beginning of the file", so
    # move past every future import that immediately follows.
    while insert_at < len(body) and _is_future_import_node(body[insert_at]):
        insert_at += 1

    body.insert(insert_at, st_import)


def _build_st_import_statement() -> ast.Import:
    """Build AST node for `import magic_funcs as __streamlitmagic__`."""
    return ast.Import(
        names=[
            ast.alias(
                name="streamlit.runtime.scriptrunner.magic_funcs",
                asname=MAGIC_MODULE_NAME,
            )
        ]
    )


def _build_st_write_call(nodes: list[Any]) -> ast.Call:
    """Build AST node for `__streamlitmagic__.transparent_write(*nodes)`."""
    return ast.Call(
        func=ast.Attribute(
            attr="transparent_write",
            value=ast.Name(id=MAGIC_MODULE_NAME, ctx=ast.Load()),
            ctx=ast.Load(),
        ),
        args=nodes,
        keywords=[],
    )


def _get_st_write_from_expr(
    node: Any,
    i: int,
    parent_type: Any,
    is_root: bool,
    is_last_expr: bool,
    file_ends_in_semicolon: bool,
) -> ast.Call | None:
    """Return the write call that should replace an expression statement.

    Parameters
    ----------
    node : Any
        The `ast.Expr` statement being considered.
    i : int
        Index of ``node`` within its parent's statement list.
    parent_type : Any
        Type of the AST node that owns the statement list.
    is_root : bool
        Whether the parent is the module node.
    is_last_expr : bool
        Whether ``node`` is the last statement of its statement list.
    file_ends_in_semicolon : bool
        Whether the last line of the last top-level statement ends in ";".

    Returns
    -------
    ast.Call | None
        The call node wrapping the expression, or None if the expression
        must be left untouched.

    """
    # Don't wrap function calls
    # (Unless the function call happened at the end of the root node, AND
    # magic.displayLastExprIfNoSemicolon is True. This allows us to support
    # notebook-like behavior, where we display the last function in a cell)
    if type(node.value) is ast.Call and not _is_displayable_last_expr(
        is_root, is_last_expr, file_ends_in_semicolon
    ):
        return None

    # Don't wrap DocString nodes
    # (Unless magic.displayRootDocString, in which case we do wrap the root-level
    # docstring with st.write. This allows us to support notebook-like behavior
    # where you can have a cell with a markdown string)
    if _is_docstring_node(
        node.value, i, parent_type
    ) and not _should_display_docstring_like_node_anyway(is_root):
        return None

    # Don't wrap yield nodes
    if type(node.value) is ast.Yield or type(node.value) is ast.YieldFrom:
        return None

    # Don't wrap await nodes
    if type(node.value) is ast.Await:
        return None

    # If tuple, call st.write(*the_tuple). This allows us to add a comma at the
    # end of a statement to turn it into an expression that should be
    # st-written. Ex: "np.random.randn(1000, 2),"
    args = node.value.elts if type(node.value) is ast.Tuple else [node.value]
    return _build_st_write_call(args)


def _is_string_constant_node(node: Any) -> bool:
    """Return True if the node is an `ast.Constant` holding a `str`."""
    return isinstance(node, ast.Constant) and isinstance(node.value, str)


def _is_docstring_node(node: Any, node_index: int, parent_type: Any) -> bool:
    """Return True if the node is a string constant in docstring position.

    That is: it is the first statement (index 0) of a module, function or
    async function.
    """
    return (
        node_index == 0
        and _is_string_constant_node(node)
        and parent_type in {ast.FunctionDef, ast.AsyncFunctionDef, ast.Module}
    )


def _does_file_end_in_semicolon(tree: Any, code: str) -> bool:
    """Return True if the last top-level statement's final line ends in ";".

    Always returns False when magic.displayLastExprIfNoSemicolon is not set,
    when the module has no statements, or when the last statement carries no
    end line number.
    """
    # Avoid spending time with this operation if magic.displayLastExprIfNoSemicolon is
    # not set.
    if not config.get_option("magic.displayLastExprIfNoSemicolon"):
        return False

    if not tree.body:
        return False

    last_line_num = getattr(tree.body[-1], "end_lineno", None)
    if last_line_num is None:
        return False

    last_line_str: str = code.split("\n")[last_line_num - 1]
    # Strip every kind of trailing whitespace, not just spaces: a trailing tab
    # or a "\r" left behind by CRLF line endings must not hide the semicolon.
    return last_line_str.rstrip().endswith(";")


def _is_displayable_last_expr(
    is_root: bool, is_last_expr: bool, file_ends_in_semicolon: bool
) -> bool:
    """Return whether a call expression should be displayed notebook-style."""
    return (
        # This is a "displayable last expression" if...
        # ...it's actually the last expression...
        is_last_expr
        # ...in the root scope...
        and is_root
        # ...it does not end in a semicolon...
        and not file_ends_in_semicolon
        # ...and this config option is telling us to show it
        and config.get_option("magic.displayLastExprIfNoSemicolon")
    )


def _should_display_docstring_like_node_anyway(is_root: bool) -> bool:
    """Return whether a docstring-position string should still be displayed.

    True only for the root scope when magic.displayRootDocString is set.
    """
    return config.get_option("magic.displayRootDocString") and is_root