team-10/env/Lib/site-packages/streamlit/connections/base_connection.py
2025-08-02 07:34:44 +02:00

174 lines
6.6 KiB
Python

# Copyright (c) Streamlit Inc. (2018-2022) Snowflake Inc. (2022-2025)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations
import json
from abc import ABC, abstractmethod
from typing import Any, Generic, TypeVar
from streamlit.runtime.secrets import AttrDict, secrets_singleton
from streamlit.util import calc_md5
# Type of the raw, underlying connection object that a subclass's ``_connect``
# method returns and that ``BaseConnection._instance`` hands back.
RawConnectionT = TypeVar("RawConnectionT")


class BaseConnection(ABC, Generic[RawConnectionT]):
    """The abstract base class that all Streamlit Connections must inherit from.

    This base class provides connection authors with a standardized way to hook into
    the ``st.connection()`` factory function: connection authors are required to
    provide an implementation for the abstract method ``_connect`` in their
    subclasses.

    Additionally, it also provides a few methods/properties designed to make
    implementation of connections more convenient. See the docstrings for each of
    the methods of this class for more information.

    .. note::
        While providing an implementation of ``_connect`` is technically all that's
        required to define a valid connection, connections should also provide the
        user with context-specific ways of interacting with the underlying
        connection object. For example, the first-party SQLConnection provides a
        ``query()`` method for reads and a ``session`` property for more complex
        operations.
    """

    def __init__(self, connection_name: str, **kwargs: Any) -> None:
        """Create a BaseConnection.

        This constructor is called by the connection factory machinery when a user
        script calls ``st.connection()``.

        Subclasses of BaseConnection that want to override this method should take
        care to also call the base class' implementation.

        Parameters
        ----------
        connection_name : str
            The name of this connection. This corresponds to the
            ``[connections.<connection_name>]`` config section in ``st.secrets``.
        kwargs : dict
            Any other kwargs to pass to this connection class' ``_connect`` method.

        Returns
        -------
        None
        """
        self._connection_name = connection_name
        self._kwargs = kwargs
        # Fingerprint this connection's secrets section so _on_secrets_changed can
        # tell whether *this* section (rather than some other one) was edited.
        self._config_section_hash = self._secrets_hash()
        secrets_singleton.file_change_listener.connect(self._on_secrets_changed)
        self._raw_instance: RawConnectionT | None = self._connect(**kwargs)

    def __del__(self) -> None:
        # Stop receiving secrets-file change notifications for this object.
        secrets_singleton.file_change_listener.disconnect(self._on_secrets_changed)

    def __getattribute__(self, name: str) -> Any:
        """Resolve ``name`` normally, hinting at ``._instance`` when that has it.

        If normal lookup fails but the underlying connection object exposes
        ``name``, a more helpful ``AttributeError`` is raised that points the user
        at ``._instance.<name>``. Otherwise the original ``AttributeError`` is
        re-raised.

        Note that consulting ``_instance`` calls ``_connect`` again when the
        connection was previously ``reset()``.
        """
        try:
            return object.__getattribute__(self, name)
        except AttributeError:
            # Probe the raw slot directly, bypassing this method. On a partially
            # initialized object (no ``_raw_instance`` yet), going through the
            # ``_instance`` property would fail, re-enter this handler, and
            # recurse until RecursionError.
            try:
                object.__getattribute__(self, "_raw_instance")
            except AttributeError:
                initialized = False
            else:
                initialized = True

            if initialized and hasattr(self._instance, name):
                raise AttributeError(
                    f"`{name}` doesn't exist here, but you can call `._instance.{name}` instead"
                )
            # Re-raise the original lookup failure. The inner handler above has
            # already exited, so the exception being handled is the outer one.
            raise

    # Methods with default implementations that we don't expect subclasses to want
    # or need to override.

    def _secrets_hash(self) -> str:
        """Return an MD5 fingerprint of this connection's secrets section.

        Keys are sorted so that merely reordering entries does not register as a
        change. ``default=str`` lets values JSON cannot encode natively (for
        example date/time objects, which a TOML parser may produce) be
        fingerprinted instead of raising ``TypeError``. The result is only ever
        compared against earlier fingerprints, so stringifying them is
        intentional.
        """
        return calc_md5(
            json.dumps(self._secrets.to_dict(), sort_keys=True, default=str)
        )

    def _on_secrets_changed(self, _: str) -> None:
        """Reset the raw connection object when this connection's secrets change.

        We don't expect either user scripts or connection authors to have to use or
        override this method.
        """
        new_hash = self._secrets_hash()

        # Only reset the connection if the secrets file section specific to this
        # connection has changed.
        if new_hash != self._config_section_hash:
            self._config_section_hash = new_hash
            self.reset()

    @property
    def _secrets(self) -> AttrDict:
        """Get the secrets for this connection from the corresponding st.secrets section.

        We expect this property to be used primarily by connection authors when they
        are implementing their class' ``_connect`` method. User scripts should, for
        the most part, have no reason to use this property.

        Returns an empty ``AttrDict`` when there is no secrets file, no
        ``connections`` section, or no entry for this connection's name.
        """
        connections_section = None
        if secrets_singleton.load_if_toml_exists():
            connections_section = secrets_singleton.get("connections")

        if type(connections_section) is not AttrDict:
            return AttrDict({})

        return connections_section.get(self._connection_name, AttrDict({}))

    def reset(self) -> None:
        """Reset this connection so that it gets reinitialized the next time it's used.

        This method can be useful when a connection has become stale, an auth token
        has expired, or in similar scenarios where a broken connection might be
        fixed by reinitializing it. Note that some connection methods may already
        use ``reset()`` in their error handling code.

        Returns
        -------
        None

        Example
        -------
        >>> import streamlit as st
        >>>
        >>> conn = st.connection("my_conn")
        >>>
        >>> # Reset the connection before using it if it isn't healthy
        >>> # Note: is_healthy() isn't a real method and is just shown for example here.
        >>> if not conn.is_healthy():
        ...     conn.reset()
        >>>
        >>> # Do stuff with conn...
        """
        self._raw_instance = None

    @property
    def _instance(self) -> RawConnectionT:
        """Get an instance of the underlying connection, creating a new one if needed.

        A new instance is created by calling ``_connect`` with the kwargs originally
        given to the constructor whenever the cached instance is ``None`` (for
        example, after ``reset()``).
        """
        if self._raw_instance is None:
            self._raw_instance = self._connect(**self._kwargs)

        return self._raw_instance

    # Abstract fields/methods that subclasses of BaseConnection must implement

    @abstractmethod
    def _connect(self, **kwargs: Any) -> RawConnectionT:
        """Create an instance of an underlying connection object.

        This abstract method is the one method that we require subclasses of
        BaseConnection to provide an implementation for. It is called when first
        creating a connection and when reconnecting after a connection is reset.

        Parameters
        ----------
        kwargs : dict
            The keyword arguments passed to this connection's constructor. The
            same kwargs are passed again on every reconnect.

        Returns
        -------
        RawConnectionT
            The underlying connection object.
        """
        raise NotImplementedError