# Copyright (c) Streamlit Inc. (2018-2022) Snowflake Inc. (2022-2025)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Hashing for st.cache_data and st.cache_resource."""

from __future__ import annotations

import collections
import collections.abc
import dataclasses
import datetime
import functools
import hashlib
import inspect
import io
import os
import pickle
import sys
import tempfile
import threading
import uuid
import weakref
from enum import Enum
from re import Pattern
from types import MappingProxyType
from typing import TYPE_CHECKING, Any, Callable, Final, Union, cast

from typing_extensions import TypeAlias

from streamlit import logger, type_util, util
from streamlit.errors import StreamlitAPIException
from streamlit.runtime.caching.cache_errors import UnhashableTypeError
from streamlit.runtime.caching.cache_type import CacheType
from streamlit.runtime.uploaded_file_manager import UploadedFile

if TYPE_CHECKING:
    import numpy.typing as npt
    from PIL.Image import Image

_LOGGER: Final = logger.get_logger(__name__)

# If a dataframe has more than this many rows, we consider it large and hash a sample.
_PANDAS_ROWS_LARGE: Final = 50_000
_PANDAS_SAMPLE_SIZE: Final = 10_000

# Similar to dataframes, we also sample large numpy arrays.
_NP_SIZE_LARGE: Final = 500_000
_NP_SAMPLE_SIZE: Final = 100_000

HashFuncsDict: TypeAlias = dict[Union[str, type[Any]], Callable[[Any], Any]]

# Arbitrary item to denote where we found a cycle in a hashed object.
# This allows us to hash self-referencing lists, dictionaries, etc.
_CYCLE_PLACEHOLDER: Final = (
    b"streamlit-57R34ML17-hesamagicalponyflyingthroughthesky-CYCLE"
)


class UserHashError(StreamlitAPIException):
    def __init__(
        self,
        orig_exc: BaseException,
        object_to_hash: Any,
        hash_func: Callable[[Any], Any],
        cache_type: CacheType | None = None,
    ) -> None:
        self.alternate_name = type(orig_exc).__name__
        self.hash_func = hash_func
        self.cache_type = cache_type

        msg = self._get_message_from_func(orig_exc, object_to_hash)

        super().__init__(msg)
        self.with_traceback(orig_exc.__traceback__)

    def _get_message_from_func(
        self,
        orig_exc: BaseException,
        cached_func: Any,
    ) -> str:
        args = self._get_error_message_args(orig_exc, cached_func)

        return (
            f"""
{args["orig_exception_desc"]}

This error is likely due to a bug in {args["hash_func_name"]}, which is a
user-defined hash function that was passed into the `{args["cache_primitive"]}` decorator of
{args["object_desc"]}.

{args["hash_func_name"]} failed when hashing an object of type
`{args["failed_obj_type_str"]}`. If you don't know where that object is coming from,
try looking at the hash chain below for an object that you do recognize, then
pass that to `hash_funcs` instead:

```
{args["hash_stack"]}
```

If you think this is actually a Streamlit bug, please
[file a bug report here](https://github.com/streamlit/streamlit/issues/new/choose).
"""
        ).strip("\n")

    def _get_error_message_args(
        self,
        orig_exc: BaseException,
        failed_obj: Any,
    ) -> dict[str, Any]:
        hash_source = hash_stacks.current.hash_source

        failed_obj_type_str = type_util.get_fqn_type(failed_obj)

        if hash_source is None:
            object_desc = "something"
        elif hasattr(hash_source, "__name__"):
            object_desc = f"`{hash_source.__name__}()`"
        else:
            object_desc = "a function"

        decorator_name = ""
        if self.cache_type is CacheType.RESOURCE:
            decorator_name = "@st.cache_resource"
        elif self.cache_type is CacheType.DATA:
            decorator_name = "@st.cache_data"

        hash_func_name = (
            f"`{self.hash_func.__name__}()`"
            if hasattr(self.hash_func, "__name__")
            else "a function"
        )

        return {
            "orig_exception_desc": str(orig_exc),
            "failed_obj_type_str": failed_obj_type_str,
            "hash_stack": hash_stacks.current.pretty_print(),
            "object_desc": object_desc,
            "cache_primitive": decorator_name,
            "hash_func_name": hash_func_name,
        }


def update_hash(
    val: Any,
    hasher: Any,
    cache_type: CacheType,
    hash_source: Callable[..., Any] | None = None,
    hash_funcs: HashFuncsDict | None = None,
) -> None:
    """Updates a hashlib hasher with the hash of val.

    This is the main entrypoint to hashing.py.
    """

    hash_stacks.current.hash_source = hash_source

    ch = _CacheFuncHasher(cache_type, hash_funcs)
    ch.update(hasher, val)


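# Illustrative sketch of how this entrypoint is typically driven (the value
# tuple below is hypothetical; the calls mirror how a hashlib hasher is fed to
# build a cache key):
#
#     import hashlib
#
#     hasher = hashlib.new("md5", usedforsecurity=False)
#     update_hash(("user_id", 42, [1.5, 2.5]), hasher, CacheType.DATA)
#     cache_key = hasher.hexdigest()

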
class _HashStack:
    """Stack of what has been hashed, for debug and circular reference detection.

    This internally keeps 1 stack per thread.

    Internally, this stores the ID of pushed objects rather than the objects
    themselves because otherwise the "in" operator inside __contains__ would
    fail for objects that don't return a boolean for the "==" operator. For
    example, arr == 10 where arr is a NumPy array returns another NumPy array.
    This causes the "in" check to crash, since it expects a boolean.
    """

    def __init__(self) -> None:
        self._stack: collections.OrderedDict[int, list[Any]] = collections.OrderedDict()
        # The function that is decorated with a Streamlit caching
        # primitive (st.cache_data or st.cache_resource).
        self.hash_source: Callable[..., Any] | None = None

    def __repr__(self) -> str:
        return util.repr_(self)

    def push(self, val: Any) -> None:
        self._stack[id(val)] = val

    def pop(self) -> None:
        self._stack.popitem()

    def __contains__(self, val: Any) -> bool:
        return id(val) in self._stack

    def pretty_print(self) -> str:
        def to_str(v: Any) -> str:
            try:
                return f"Object of type {type_util.get_fqn_type(v)}: {v}"
            except Exception:
                return "<Unable to convert item to string>"

        return "\n".join(to_str(x) for x in reversed(self._stack.values()))


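# Illustrative sketch of the cycle handling this stack enables (hypothetical
# input, exercised through update_hash() the same way the caching decorators do):
#
#     import hashlib
#
#     cyclic = [1, 2]
#     cyclic.append(cyclic)  # the list now contains itself
#     hasher = hashlib.new("md5", usedforsecurity=False)
#     update_hash(cyclic, hasher, CacheType.DATA)
#
# The inner self-reference is detected via __contains__ above and hashed as
# _CYCLE_PLACEHOLDER instead of recursing forever.

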
class _HashStacks:
    """Stacks of what has been hashed, with at most 1 stack per thread."""

    def __init__(self) -> None:
        self._stacks: weakref.WeakKeyDictionary[threading.Thread, _HashStack] = (
            weakref.WeakKeyDictionary()
        )

    def __repr__(self) -> str:
        return util.repr_(self)

    @property
    def current(self) -> _HashStack:
        current_thread = threading.current_thread()

        stack = self._stacks.get(current_thread, None)

        if stack is None:
            stack = _HashStack()
            self._stacks[current_thread] = stack

        return stack


hash_stacks = _HashStacks()


def _int_to_bytes(i: int) -> bytes:
    num_bytes = (i.bit_length() + 8) // 8
    return i.to_bytes(num_bytes, "little", signed=True)


def _float_to_bytes(f: float) -> bytes:
    # Lazy-load for performance reasons.
    import struct

    # Floats are 64bit in Python, so we need to use the "d" format.
    return struct.pack("<d", f)


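# Worked examples of the integer encoding above (illustrative only). The extra
# byte from `(bit_length + 8) // 8` leaves room for the sign bit, so e.g. 255
# does not overflow a single signed byte:
#
#     _int_to_bytes(0)    == b"\x00"
#     _int_to_bytes(255)  == b"\xff\x00"
#     _int_to_bytes(-255) == b"\x01\xff"
#     int.from_bytes(_int_to_bytes(-255), "little", signed=True) == -255

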
def _key(obj: Any | None) -> Any:
    """Return key for memoization."""

    if obj is None:
        return None

    def is_simple(obj: Any) -> bool:
        return (
            isinstance(obj, (bytes, bytearray, str, float, int, bool, uuid.UUID))
            or obj is None
        )

    if is_simple(obj):
        return obj

    if isinstance(obj, tuple) and all(map(is_simple, obj)):
        return obj

    if isinstance(obj, list) and all(map(is_simple, obj)):
        return ("__l", tuple(obj))

    if inspect.isbuiltin(obj) or inspect.isroutine(obj) or inspect.iscode(obj):
        return id(obj)

    return NoResult


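# Illustrative memoization keys produced by _key() above (hypothetical inputs):
#
#     _key("hello")    == "hello"          # simple values are their own key
#     _key((1, "a"))   == (1, "a")         # tuples of simple values too
#     _key([1, 2])     == ("__l", (1, 2))  # lists are tagged so they don't collide with tuples
#     _key(len)        == id(len)          # builtins/routines are keyed by identity
#     _key(object()) is NoResult           # anything else is not memoized

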
class _CacheFuncHasher:
    """A hasher that can hash objects with cycles."""

    def __init__(
        self, cache_type: CacheType, hash_funcs: HashFuncsDict | None = None
    ) -> None:
        # Can't use types as the keys in the internal _hash_funcs because
        # we always remove user-written modules from memory when rerunning a
        # script in order to reload it and grab the latest code changes.
        # (See LocalSourcesWatcher.py:on_file_changed) This causes
        # the type object to refer to different underlying class instances each run,
        # so type-based comparisons fail. To solve this, we use the types converted
        # to fully-qualified strings as keys in our internal dict.
        self._hash_funcs: HashFuncsDict
        if hash_funcs:
            self._hash_funcs = {
                k if isinstance(k, str) else type_util.get_fqn(k): v
                for k, v in hash_funcs.items()
            }
        else:
            self._hash_funcs = {}
        self._hashes: dict[Any, bytes] = {}

        # The total size, in bytes, of everything hashed so far.
        self.size = 0

        self.cache_type = cache_type

    def __repr__(self) -> str:
        return util.repr_(self)

    def to_bytes(self, obj: Any) -> bytes:
        """Add memoization to _to_bytes and protect against cycles in data structures."""
        tname = type(obj).__qualname__.encode()
        key = (tname, _key(obj))

        # Memoize if possible.
        if key[1] is not NoResult and key in self._hashes:
            return self._hashes[key]

        # Break recursive cycles.
        if obj in hash_stacks.current:
            return _CYCLE_PLACEHOLDER

        hash_stacks.current.push(obj)

        try:
            # Hash the input
            b = b"%s:%s" % (tname, self._to_bytes(obj))

            # Note: this size calculation may overcount, because nested
            # to_bytes calls made from within _to_bytes get counted again here.
            self.size += sys.getsizeof(b)

            if key[1] is not NoResult:
                self._hashes[key] = b

        finally:
            # In case an UnhashableTypeError (or other) error is thrown, clean up the
            # stack so we don't get false positives in future hashing calls.
            hash_stacks.current.pop()

        return b

    def update(self, hasher: Any, obj: Any) -> None:
        """Update the provided hasher with the hash of an object."""
        b = self.to_bytes(obj)
        hasher.update(b)

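    # Illustrative byte layout produced by to_bytes() above: the value returned
    # by _to_bytes() is prefixed with the object's type name (hypothetical inputs):
    #
    #     to_bytes("hi") -> b"str:hi"
    #     to_bytes(5)    -> b"int:\x05"
    #     to_bytes(None) -> b"NoneType:0"
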
    def _to_bytes(self, obj: Any) -> bytes:
        """Hash objects to bytes, including code with dependencies.

        Python's built-in `hash` does not produce consistent results across
        runs.
        """

        h = hashlib.new("md5", usedforsecurity=False)

        if type_util.is_type(obj, "unittest.mock.Mock") or type_util.is_type(
            obj, "unittest.mock.MagicMock"
        ):
            # Mock objects can appear to be infinitely
            # deep, so we don't try to hash them at all.
            return self.to_bytes(id(obj))

        if isinstance(obj, (bytes, bytearray)):
            return obj

        if type_util.get_fqn_type(obj) in self._hash_funcs:
            # Escape hatch for unsupported objects
            hash_func = self._hash_funcs[type_util.get_fqn_type(obj)]
            try:
                output = hash_func(obj)
            except Exception as ex:
                raise UserHashError(
                    ex, obj, hash_func=hash_func, cache_type=self.cache_type
                ) from ex
            return self.to_bytes(output)

        if isinstance(obj, str):
            return obj.encode()

        if isinstance(obj, float):
            return _float_to_bytes(obj)

        if isinstance(obj, int):
            return _int_to_bytes(obj)

        if isinstance(obj, uuid.UUID):
            return obj.bytes

        if isinstance(obj, datetime.datetime):
            return obj.isoformat().encode()

        if isinstance(obj, (list, tuple)):
            for item in obj:
                self.update(h, item)
            return h.digest()

        if isinstance(obj, dict):
            for item in obj.items():
                self.update(h, item)
            return h.digest()

        if obj is None:
            return b"0"

        if obj is True:
            return b"1"

        if obj is False:
            return b"0"

        if not isinstance(obj, type) and dataclasses.is_dataclass(obj):
            return self.to_bytes(dataclasses.asdict(obj))
        if isinstance(obj, Enum):
            return str(obj).encode()

        if type_util.is_type(obj, "pandas.core.series.Series"):
            import pandas as pd

            series_obj: pd.Series = cast("pd.Series", obj)
            self.update(h, series_obj.size)
            self.update(h, series_obj.dtype.name)

            if len(series_obj) >= _PANDAS_ROWS_LARGE:
                series_obj = series_obj.sample(n=_PANDAS_SAMPLE_SIZE, random_state=0)

            try:
                self.update(
                    h, pd.util.hash_pandas_object(series_obj).to_numpy().tobytes()
                )
                return h.digest()
            except TypeError:
                _LOGGER.warning(
                    "Pandas Series hash failed. Falling back to pickling the object.",
                    exc_info=True,
                )

            # Fall back to pickle if pandas cannot hash the object, for example
            # when it contains unhashable objects.
            return b"%s" % pickle.dumps(series_obj, pickle.HIGHEST_PROTOCOL)

        elif type_util.is_type(obj, "pandas.core.frame.DataFrame"):
            import pandas as pd

            df_obj: pd.DataFrame = cast("pd.DataFrame", obj)
            self.update(h, df_obj.shape)

            if len(df_obj) >= _PANDAS_ROWS_LARGE:
                df_obj = df_obj.sample(n=_PANDAS_SAMPLE_SIZE, random_state=0)
            try:
                column_hash_bytes = self.to_bytes(
                    pd.util.hash_pandas_object(df_obj.dtypes)
                )
                self.update(h, column_hash_bytes)
                values_hash_bytes = self.to_bytes(pd.util.hash_pandas_object(df_obj))
                self.update(h, values_hash_bytes)
                return h.digest()
            except TypeError:
                _LOGGER.warning(
                    "Pandas DataFrame hash failed. Falling back to pickling the object.",
                    exc_info=True,
                )

            # Fall back to pickle if pandas cannot hash the object, for example
            # when it contains unhashable objects.
            return b"%s" % pickle.dumps(df_obj, pickle.HIGHEST_PROTOCOL)

        elif type_util.is_type(obj, "polars.series.series.Series"):
            import polars as pl  # type: ignore[import-not-found]

            obj = cast("pl.Series", obj)
            self.update(h, str(obj.dtype).encode())
            self.update(h, obj.shape)

            if len(obj) >= _PANDAS_ROWS_LARGE:
                obj = obj.sample(n=_PANDAS_SAMPLE_SIZE, seed=0)

            try:
                self.update(h, obj.hash(seed=0).to_arrow().to_string().encode())
                return h.digest()
            except TypeError:
                _LOGGER.warning(
                    "Polars Series hash failed. Falling back to pickling the object.",
                    exc_info=True,
                )

            # Fall back to pickle if polars cannot hash the object, for example
            # when it contains unhashable objects.
            return b"%s" % pickle.dumps(obj, pickle.HIGHEST_PROTOCOL)
        elif type_util.is_type(obj, "polars.dataframe.frame.DataFrame"):
            import polars as pl  # noqa: TC002

            obj = cast("pl.DataFrame", obj)
            self.update(h, obj.shape)

            if len(obj) >= _PANDAS_ROWS_LARGE:
                obj = obj.sample(n=_PANDAS_SAMPLE_SIZE, seed=0)
            try:
                for c, t in obj.schema.items():
                    self.update(h, c.encode())
                    self.update(h, str(t).encode())

                values_hash_bytes = (
                    obj.hash_rows(seed=0).hash(seed=0).to_arrow().to_string().encode()
                )

                self.update(h, values_hash_bytes)
                return h.digest()
            except TypeError:
                _LOGGER.warning(
                    "Polars DataFrame hash failed. Falling back to pickling the object.",
                    exc_info=True,
                )

            # Fall back to pickle if polars cannot hash the object, for example
            # when it contains unhashable objects.
            return b"%s" % pickle.dumps(obj, pickle.HIGHEST_PROTOCOL)
        elif type_util.is_type(obj, "numpy.ndarray"):
            np_obj: npt.NDArray[Any] = cast("npt.NDArray[Any]", obj)
            self.update(h, np_obj.shape)
            self.update(h, str(np_obj.dtype))

            if np_obj.size >= _NP_SIZE_LARGE:
                import numpy as np

                state = np.random.RandomState(0)
                np_obj = state.choice(np_obj.flat, size=_NP_SAMPLE_SIZE)

            self.update(h, np_obj.tobytes())
            return h.digest()
        elif type_util.is_type(obj, "PIL.Image.Image"):
            import numpy as np

            pil_obj: Image = cast("Image", obj)

            # We don't just hash the result of obj.tobytes() directly because we
            # want to reuse the sampling logic for large numpy arrays above.
            np_array = np.frombuffer(pil_obj.tobytes(), dtype="uint8")
            return self.to_bytes(np_array)

        elif inspect.isbuiltin(obj):
            return bytes(obj.__name__.encode())

        elif isinstance(obj, (MappingProxyType, collections.abc.ItemsView)):
            return self.to_bytes(dict(obj))

        elif type_util.is_type(obj, "builtins.getset_descriptor"):
            return bytes(obj.__qualname__.encode())

        elif isinstance(obj, UploadedFile):
            # UploadedFile is a BytesIO (thus IOBase), but it has a name.
            # It does not have a modification timestamp, so this check must
            # come before the on-disk file handling below.
            self.update(h, obj.name)
            self.update(h, obj.tell())
            self.update(h, obj.getvalue())
            return h.digest()

        elif hasattr(obj, "name") and (
            # Handle temporary files used during testing
            isinstance(obj, (io.IOBase, tempfile._TemporaryFileWrapper))
        ):
            # Hash files as name + last modification date + offset.
            # NB: we're using hasattr("name") to differentiate between
            # on-disk and in-memory StringIO/BytesIO file representations.
            # That means that this condition must come *before* the next
            # condition, which just checks for StringIO/BytesIO.
            obj_name = getattr(obj, "name", "wonthappen")  # Just to appease MyPy.
            self.update(h, obj_name)
            self.update(h, os.path.getmtime(obj_name))
            self.update(h, obj.tell())
            return h.digest()

        elif isinstance(obj, Pattern):
            return self.to_bytes([obj.pattern, obj.flags])

        elif isinstance(obj, (io.StringIO, io.BytesIO)):
            # Hash in-memory StringIO/BytesIO by their full contents
            # and seek position.
            self.update(h, obj.tell())
            self.update(h, obj.getvalue())
            return h.digest()

        elif type_util.is_type(obj, "numpy.ufunc"):
            # For numpy.remainder, this returns "remainder".
            return bytes(obj.__name__.encode())

        elif inspect.ismodule(obj):
            # TODO: Figure out how to best show this kind of warning to the
            # user. In the meantime, show nothing. This scenario is too common,
            # so the current warning is quite annoying...
            # st.warning(('Streamlit does not support hashing modules. '
            #             'We did not hash `%s`.') % obj.__name__)
            # TODO: Hash more than just the name for internal modules.
            return self.to_bytes(obj.__name__)

        elif inspect.isclass(obj):
            # TODO: Figure out how to best show this kind of warning to the
            # user. In the meantime, show nothing. This scenario is too common,
            # (e.g. in every "except" statement) so the current warning is
            # quite annoying...
            # st.warning(('Streamlit does not support hashing classes. '
            #             'We did not hash `%s`.') % obj.__name__)
            # TODO: Hash more than just the name of classes.
            return self.to_bytes(obj.__name__)

        elif isinstance(obj, functools.partial):
            # The return value of functools.partial is not a plain function:
            # it's a callable object that remembers the original function plus
            # the arguments bound to it. So here we need to special-case it.
            self.update(h, obj.args)
            self.update(h, obj.func)
            self.update(h, obj.keywords)
            return h.digest()

        else:
            # As a last resort, hash the output of the object's __reduce__ method.
            try:
                reduce_data = obj.__reduce__()
            except Exception as ex:
                raise UnhashableTypeError() from ex

            for item in reduce_data:
                self.update(h, item)
            return h.digest()


class NoResult:
    """Placeholder class for return values when None is meaningful."""

    pass


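# Illustrative sketch of the public escape hatch that feeds HashFuncsDict above.
# It assumes the documented `hash_funcs` parameter of st.cache_data and uses a
# hypothetical user-defined DbConnection class:
#
#     import streamlit as st
#
#     class DbConnection:
#         def __init__(self, url: str) -> None:
#             self.url = url
#
#     @st.cache_data(hash_funcs={DbConnection: lambda conn: conn.url})
#     def load_rows(conn: DbConnection, query: str) -> list[str]:
#         ...
#
# When load_rows() runs, _CacheFuncHasher keys DbConnection by its
# fully-qualified type name and hashes conn.url instead of the connection itself.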