175 lines
6.6 KiB
Python
175 lines
6.6 KiB
Python
![]() |
# Copyright (c) Streamlit Inc. (2018-2022) Snowflake Inc. (2022-2025)
|
||
|
#
|
||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||
|
# you may not use this file except in compliance with the License.
|
||
|
# You may obtain a copy of the License at
|
||
|
#
|
||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||
|
#
|
||
|
# Unless required by applicable law or agreed to in writing, software
|
||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||
|
# See the License for the specific language governing permissions and
|
||
|
# limitations under the License.
|
||
|
|
||
|
from __future__ import annotations
|
||
|
|
||
|
import json
|
||
|
from abc import ABC, abstractmethod
|
||
|
from typing import Any, Generic, TypeVar
|
||
|
|
||
|
from streamlit.runtime.secrets import AttrDict, secrets_singleton
|
||
|
from streamlit.util import calc_md5
|
||
|
|
||
|
RawConnectionT = TypeVar("RawConnectionT")
|
||
|
|
||
|
|
||
|
class BaseConnection(ABC, Generic[RawConnectionT]):
|
||
|
"""The abstract base class that all Streamlit Connections must inherit from.
|
||
|
|
||
|
This base class provides connection authors with a standardized way to hook into the
|
||
|
``st.connection()`` factory function: connection authors are required to provide an
|
||
|
implementation for the abstract method ``_connect`` in their subclasses.
|
||
|
|
||
|
Additionally, it also provides a few methods/properties designed to make
|
||
|
implementation of connections more convenient. See the docstrings for each of the
|
||
|
methods of this class for more information
|
||
|
|
||
|
.. note::
|
||
|
While providing an implementation of ``_connect`` is technically all that's
|
||
|
required to define a valid connection, connections should also provide the user
|
||
|
with context-specific ways of interacting with the underlying connection object.
|
||
|
For example, the first-party SQLConnection provides a ``query()`` method for
|
||
|
reads and a ``session`` property for more complex operations.
|
||
|
"""
|
||
|
|
||
|
def __init__(self, connection_name: str, **kwargs: Any) -> None:
|
||
|
"""Create a BaseConnection.
|
||
|
|
||
|
This constructor is called by the connection factory machinery when a user
|
||
|
script calls ``st.connection()``.
|
||
|
|
||
|
Subclasses of BaseConnection that want to overwrite this method should take care
|
||
|
to also call the base class' implementation.
|
||
|
|
||
|
Parameters
|
||
|
----------
|
||
|
connection_name : str
|
||
|
The name of this connection. This corresponds to the
|
||
|
``[connections.<connection_name>]`` config section in ``st.secrets``.
|
||
|
kwargs : dict
|
||
|
Any other kwargs to pass to this connection class' ``_connect`` method.
|
||
|
|
||
|
Returns
|
||
|
-------
|
||
|
None
|
||
|
"""
|
||
|
self._connection_name = connection_name
|
||
|
self._kwargs = kwargs
|
||
|
|
||
|
self._config_section_hash = calc_md5(json.dumps(self._secrets.to_dict()))
|
||
|
secrets_singleton.file_change_listener.connect(self._on_secrets_changed)
|
||
|
|
||
|
self._raw_instance: RawConnectionT | None = self._connect(**kwargs)
|
||
|
|
||
|
def __del__(self) -> None:
|
||
|
secrets_singleton.file_change_listener.disconnect(self._on_secrets_changed)
|
||
|
|
||
|
def __getattribute__(self, name: str) -> Any:
|
||
|
try:
|
||
|
return object.__getattribute__(self, name)
|
||
|
except AttributeError:
|
||
|
if hasattr(self._instance, name):
|
||
|
raise AttributeError(
|
||
|
f"`{name}` doesn't exist here, but you can call `._instance.{name}` instead"
|
||
|
)
|
||
|
raise
|
||
|
|
||
|
# Methods with default implementations that we don't expect subclasses to want or
|
||
|
# need to overwrite.
|
||
|
def _on_secrets_changed(self, _: str) -> None:
|
||
|
"""Reset the raw connection object when this connection's secrets change.
|
||
|
|
||
|
We don't expect either user scripts or connection authors to have to use or
|
||
|
overwrite this method.
|
||
|
"""
|
||
|
new_hash = calc_md5(json.dumps(self._secrets.to_dict()))
|
||
|
|
||
|
# Only reset the connection if the secrets file section specific to this
|
||
|
# connection has changed.
|
||
|
if new_hash != self._config_section_hash:
|
||
|
self._config_section_hash = new_hash
|
||
|
self.reset()
|
||
|
|
||
|
@property
|
||
|
def _secrets(self) -> AttrDict:
|
||
|
"""Get the secrets for this connection from the corresponding st.secrets section.
|
||
|
|
||
|
We expect this property to be used primarily by connection authors when they
|
||
|
are implementing their class' ``_connect`` method. User scripts should, for the
|
||
|
most part, have no reason to use this property.
|
||
|
"""
|
||
|
connections_section = None
|
||
|
if secrets_singleton.load_if_toml_exists():
|
||
|
connections_section = secrets_singleton.get("connections")
|
||
|
|
||
|
if type(connections_section) is not AttrDict:
|
||
|
return AttrDict({})
|
||
|
|
||
|
return connections_section.get(self._connection_name, AttrDict({}))
|
||
|
|
||
|
def reset(self) -> None:
|
||
|
"""Reset this connection so that it gets reinitialized the next time it's used.
|
||
|
|
||
|
This method can be useful when a connection has become stale, an auth token has
|
||
|
expired, or in similar scenarios where a broken connection might be fixed by
|
||
|
reinitializing it. Note that some connection methods may already use ``reset()``
|
||
|
in their error handling code.
|
||
|
|
||
|
Returns
|
||
|
-------
|
||
|
None
|
||
|
|
||
|
Example
|
||
|
-------
|
||
|
>>> import streamlit as st
|
||
|
>>>
|
||
|
>>> conn = st.connection("my_conn")
|
||
|
>>>
|
||
|
>>> # Reset the connection before using it if it isn't healthy
|
||
|
>>> # Note: is_healthy() isn't a real method and is just shown for example here.
|
||
|
>>> if not conn.is_healthy():
|
||
|
... conn.reset()
|
||
|
>>>
|
||
|
>>> # Do stuff with conn...
|
||
|
"""
|
||
|
self._raw_instance = None
|
||
|
|
||
|
@property
|
||
|
def _instance(self) -> RawConnectionT:
|
||
|
"""Get an instance of the underlying connection, creating a new one if needed."""
|
||
|
if self._raw_instance is None:
|
||
|
self._raw_instance = self._connect(**self._kwargs)
|
||
|
|
||
|
return self._raw_instance
|
||
|
|
||
|
# Abstract fields/methods that subclasses of BaseConnection must implement
|
||
|
@abstractmethod
|
||
|
def _connect(self, **kwargs: Any) -> RawConnectionT:
|
||
|
"""Create an instance of an underlying connection object.
|
||
|
|
||
|
This abstract method is the one method that we require subclasses of
|
||
|
BaseConnection to provide an implementation for. It is called when first
|
||
|
creating a connection and when reconnecting after a connection is reset.
|
||
|
|
||
|
Parameters
|
||
|
----------
|
||
|
kwargs : dict
|
||
|
|
||
|
Returns
|
||
|
-------
|
||
|
RawConnectionT
|
||
|
The underlying connection object.
|
||
|
"""
|
||
|
raise NotImplementedError
|