485 lines
16 KiB
Python
485 lines
16 KiB
Python
![]() |
# Copyright (c) Streamlit Inc. (2018-2022) Snowflake Inc. (2022-2025)
|
||
|
#
|
||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||
|
# you may not use this file except in compliance with the License.
|
||
|
# You may obtain a copy of the License at
|
||
|
#
|
||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||
|
#
|
||
|
# Unless required by applicable law or agreed to in writing, software
|
||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||
|
# See the License for the specific language governing permissions and
|
||
|
# limitations under the License.
|
||
|
|
||
|
"""A wrapper for simple PyDeck scatter charts."""
|
||
|
|
||
|
from __future__ import annotations
|
||
|
|
||
|
import copy
|
||
|
import json
|
||
|
from typing import TYPE_CHECKING, Any, Final, cast
|
||
|
|
||
|
from streamlit import config, dataframe_util
|
||
|
from streamlit.elements import deck_gl_json_chart
|
||
|
from streamlit.elements.lib.color_util import (
|
||
|
Color,
|
||
|
IntColorTuple,
|
||
|
is_color_like,
|
||
|
to_int_color_tuple,
|
||
|
)
|
||
|
from streamlit.errors import StreamlitAPIException
|
||
|
from streamlit.proto.DeckGlJsonChart_pb2 import DeckGlJsonChart as DeckGlJsonChartProto
|
||
|
from streamlit.runtime.metrics_util import gather_metrics
|
||
|
|
||
|
if TYPE_CHECKING:
|
||
|
from collections.abc import Collection
|
||
|
|
||
|
from pandas import DataFrame
|
||
|
|
||
|
from streamlit.dataframe_util import Data
|
||
|
from streamlit.delta_generator import DeltaGenerator
|
||
|
|
||
|
# Base spec for st.map: a shallow copy of deck_gl_json_chart.EMPTY_MAP.
# to_deckgl_json deep-copies it before filling in the viewport and layers.
_DEFAULT_MAP: Final[dict[str, Any]] = dict(deck_gl_json_chart.EMPTY_MAP)

# Column names looked up when the caller does not name a usable
# latitude/longitude column explicitly.
_DEFAULT_LAT_COL_NAMES: Final = {"lat", "latitude", "LAT", "LATITUDE"}
_DEFAULT_LON_COL_NAMES: Final = {"lon", "longitude", "LON", "LONGITUDE"}

# Fallback point styling: an RGBA fill color and a radius in meters.
_DEFAULT_COLOR: Final = (200, 30, 0, 160)
_DEFAULT_SIZE: Final = 100

# Zoom used when the data's extent does not fall between two entries of
# _ZOOM_LEVELS (for example, a single point).
_DEFAULT_ZOOM_LEVEL: Final = 12

# Degrees of longitude per zoom level; the list index is the zoom level.
# Values follow https://wiki.openstreetmap.org/wiki/Zoom_levels.
_ZOOM_LEVELS: Final = [
    # zoom 0-6
    360, 180, 90, 45, 22.5, 11.25, 5.625,
    # zoom 7-13
    2.813, 1.406, 0.703, 0.352, 0.176, 0.088, 0.044,
    # zoom 14-20
    0.022, 0.011, 0.005, 0.003, 0.001, 0.0005, 0.00025,
]
|
||
|
|
||
|
|
||
|
class MapMixin:
    @gather_metrics("map")
    def map(
        self,
        data: Data = None,
        *,
        latitude: str | None = None,
        longitude: str | None = None,
        color: None | str | Color = None,
        size: None | str | float = None,
        zoom: int | None = None,
        use_container_width: bool = True,
        width: int | None = None,
        height: int | None = None,
    ) -> DeltaGenerator:
        """Display a map with a scatterplot overlaid onto it.

        This is a wrapper around ``st.pydeck_chart`` to quickly create
        scatterplot charts on top of a map, with auto-centering and auto-zoom.

        When using this command, a service called Carto_ provides the map tiles to render
        map content. If you're using advanced PyDeck features you may need to obtain
        an API key from Carto first. You can do that as
        ``pydeck.Deck(api_keys={"carto": YOUR_KEY})`` or by setting the CARTO_API_KEY
        environment variable. See `PyDeck's documentation`_ for more information.

        Another common provider for map tiles is Mapbox_. If you prefer to use that,
        you'll need to create an account at https://mapbox.com and specify your Mapbox
        key when creating the ``pydeck.Deck`` object. You can do that as
        ``pydeck.Deck(api_keys={"mapbox": YOUR_KEY})`` or by setting the MAPBOX_API_KEY
        environment variable.

        .. _Carto: https://carto.com
        .. _Mapbox: https://mapbox.com
        .. _PyDeck's documentation: https://deckgl.readthedocs.io/en/latest/deck.html

        Carto and Mapbox are third-party products and Streamlit accepts no responsibility
        or liability of any kind for Carto or Mapbox, or for any content or information
        made available by Carto or Mapbox. The use of Carto or Mapbox is governed by
        their respective Terms of Use.

        Parameters
        ----------
        data : Anything supported by st.dataframe
            The data to be plotted.

        latitude : str or None
            The name of the column containing the latitude coordinates of
            the datapoints in the chart.

            If None, the latitude data will come from any column named 'lat',
            'latitude', 'LAT', or 'LATITUDE'.

        longitude : str or None
            The name of the column containing the longitude coordinates of
            the datapoints in the chart.

            If None, the longitude data will come from any column named 'lon',
            'longitude', 'LON', or 'LONGITUDE'.

        color : str or tuple or None
            The color of the circles representing each datapoint.

            Can be:

            - None, to use the default color.
            - A hex string like "#ffaa00" or "#ffaa0088".
            - An RGB or RGBA tuple with the red, green, blue, and alpha
              components specified as ints from 0 to 255 or floats from 0.0 to
              1.0.
            - The name of the column to use for the color. Cells in this column
              should contain colors represented as a hex string or color tuple,
              as described above.

        size : str or float or None
            The size of the circles representing each point, in meters.

            This can be:

            - None, to use the default size.
            - A number like 100, to specify a single size to use for all
              datapoints.
            - The name of the column to use for the size. This allows each
              datapoint to be represented by a circle of a different size.

        zoom : int or None
            Zoom level as specified in
            https://wiki.openstreetmap.org/wiki/Zoom_levels. If ``zoom`` is
            ``None`` (default), Streamlit chooses a zoom level that fits the
            latitude and longitude extent of the data.

        use_container_width : bool
            Whether to override the map's native width with the width of
            the parent container. If ``use_container_width`` is ``True``
            (default), Streamlit sets the width of the map to match the width
            of the parent container. If ``use_container_width`` is ``False``,
            Streamlit sets the width of the chart to fit its contents according
            to the plotting library, up to the width of the parent container.

        width : int or None
            Desired width of the chart expressed in pixels. If ``width`` is
            ``None`` (default), Streamlit sets the width of the chart to fit
            its contents according to the plotting library, up to the width of
            the parent container. If ``width`` is greater than the width of the
            parent container, Streamlit sets the chart width to match the width
            of the parent container.

            To use ``width``, you must set ``use_container_width=False``.

        height : int or None
            Desired height of the chart expressed in pixels. If ``height`` is
            ``None`` (default), Streamlit sets the height of the chart to fit
            its contents according to the plotting library.

        Examples
        --------
        >>> import streamlit as st
        >>> import pandas as pd
        >>> import numpy as np
        >>>
        >>> df = pd.DataFrame(
        ...     np.random.randn(1000, 2) / [50, 50] + [37.76, -122.4],
        ...     columns=["lat", "lon"],
        ... )
        >>> st.map(df)

        .. output::
           https://doc-map.streamlit.app/
           height: 600px

        You can also customize the size and color of the datapoints:

        >>> st.map(df, size=20, color="#0044ff")

        And finally, you can choose different columns to use for the latitude
        and longitude components, as well as set size and color of each
        datapoint dynamically based on other columns:

        >>> import streamlit as st
        >>> import pandas as pd
        >>> import numpy as np
        >>>
        >>> df = pd.DataFrame(
        ...     {
        ...         "col1": np.random.randn(1000) / 50 + 37.76,
        ...         "col2": np.random.randn(1000) / 50 + -122.4,
        ...         "col3": np.random.randn(1000) * 100,
        ...         "col4": np.random.rand(1000, 4).tolist(),
        ...     }
        ... )
        >>>
        >>> st.map(df, latitude="col1", longitude="col2", size="col3", color="col4")

        .. output::
           https://doc-map-color.streamlit.app/
           height: 600px

        """
        map_proto = DeckGlJsonChartProto()
        deck_gl_json = to_deckgl_json(data, latitude, longitude, size, color, zoom)
        marshall(
            map_proto, deck_gl_json, use_container_width, width=width, height=height
        )
        return self.dg._enqueue("deck_gl_json_chart", map_proto)

    @property
    def dg(self) -> DeltaGenerator:
        """Get our DeltaGenerator."""
        return cast("DeltaGenerator", self)
|
||
|
|
||
|
|
||
|
def to_deckgl_json(
    data: Data,
    lat: str | None,
    lon: str | None,
    size: None | str | float,
    color: None | str | Collection[float],
    zoom: int | None,
) -> str:
    """Build the deck.gl JSON specification rendered by ``st.map``.

    Parameters
    ----------
    data : Data
        Anything ``dataframe_util.convert_anything_to_pandas_df`` accepts,
        or None.
    lat, lon : str or None
        Names of the latitude/longitude columns. When None (or not a column
        of the data), one of the default column names is used instead.
    size : str, float, or None
        A column name holding per-point radii, a single radius, or None for
        the default radius.
    color : str, color tuple, or None
        A column name holding per-point colors, a single color, or None for
        the default color.
    zoom : int or None
        Initial zoom level, or None to derive it from the data's extent.

    Returns
    -------
    str
        The JSON-serialized map spec. If ``data`` is None or empty, the
        default (empty) map is returned.

    Raises
    ------
    StreamlitAPIException
        If no latitude/longitude column can be found, if those columns
        contain nulls, or if the color column does not hold valid colors.
    """
    if data is None:
        return json.dumps(_DEFAULT_MAP)

    df = dataframe_util.convert_anything_to_pandas_df(data)

    # Check emptiness *after* conversion so that every supported input type
    # is covered, not only objects exposing an ``empty`` attribute. An empty
    # frame that does have lat/lon columns would otherwise reach
    # _get_viewport_details, where min()/max() of an empty column typically
    # yield NaN. json.dumps then writes the non-standard token ``NaN``.
    if df.empty:
        return json.dumps(_DEFAULT_MAP)

    lat_col_name = _get_lat_or_lon_col_name(df, "latitude", lat, _DEFAULT_LAT_COL_NAMES)
    lon_col_name = _get_lat_or_lon_col_name(
        df, "longitude", lon, _DEFAULT_LON_COL_NAMES
    )
    size_arg, size_col_name = _get_value_and_col_name(df, size, _DEFAULT_SIZE)
    color_arg, color_col_name = _get_value_and_col_name(df, color, _DEFAULT_COLOR)

    # Drop columns we're not using, so only they are serialized below.
    # (Sorted so the output is deterministic, which tests rely on.)
    used_columns = sorted(
        c
        for c in {lat_col_name, lon_col_name, size_col_name, color_col_name}
        if c is not None
    )
    df = df[used_columns]

    # May rewrite the color column of ``df`` in place.
    converted_color_arg = _convert_color_arg_or_column(df, color_arg, color_col_name)

    zoom, center_lat, center_lon = _get_viewport_details(
        df, lat_col_name, lon_col_name, zoom
    )

    # Deep copy: the nested dicts of _DEFAULT_MAP are shared module state.
    default = copy.deepcopy(_DEFAULT_MAP)
    default["initialViewState"]["latitude"] = center_lat
    default["initialViewState"]["longitude"] = center_lon
    default["initialViewState"]["zoom"] = zoom
    default["layers"] = [
        {
            "@@type": "ScatterplotLayer",
            "getPosition": f"@@=[{lon_col_name}, {lat_col_name}]",
            "getRadius": size_arg,
            "radiusMinPixels": 3,
            "radiusUnits": "meters",
            "getFillColor": converted_color_arg,
            "data": df.to_dict("records"),
        }
    ]

    return json.dumps(default)
|
||
|
|
||
|
|
||
|
def _get_lat_or_lon_col_name(
|
||
|
data: DataFrame,
|
||
|
human_readable_name: str,
|
||
|
col_name_from_user: str | None,
|
||
|
default_col_names: set[str],
|
||
|
) -> str:
|
||
|
"""Returns the column name to be used for latitude or longitude."""
|
||
|
|
||
|
if isinstance(col_name_from_user, str) and col_name_from_user in data.columns:
|
||
|
col_name = col_name_from_user
|
||
|
|
||
|
else:
|
||
|
# Try one of the default col_names:
|
||
|
candidate_col_name = None
|
||
|
|
||
|
for c in default_col_names:
|
||
|
if c in data.columns:
|
||
|
candidate_col_name = c
|
||
|
break
|
||
|
|
||
|
if candidate_col_name is None:
|
||
|
formatted_allowed_col_name = ", ".join(map(repr, sorted(default_col_names)))
|
||
|
formmated_col_names = ", ".join(map(repr, list(data.columns)))
|
||
|
|
||
|
raise StreamlitAPIException(
|
||
|
f"Map data must contain a {human_readable_name} column named: "
|
||
|
f"{formatted_allowed_col_name}. Existing columns: {formmated_col_names}"
|
||
|
)
|
||
|
col_name = candidate_col_name
|
||
|
|
||
|
# Check that the column is well-formed.
|
||
|
# IMPLEMENTATION NOTE: We can't use isnull().values.any() because .values can return
|
||
|
# ExtensionArrays, which don't have a .any() method.
|
||
|
# (Read about ExtensionArrays here: # https://pandas.pydata.org/community/blog/extension-arrays.html)
|
||
|
# However, after a performance test I found the solution below runs basically as
|
||
|
# fast as .values.any().
|
||
|
if any(data[col_name].isna().array):
|
||
|
raise StreamlitAPIException(
|
||
|
f"Column {col_name} is not allowed to contain null values, such "
|
||
|
"as NaN, NaT, or None."
|
||
|
)
|
||
|
|
||
|
return col_name
|
||
|
|
||
|
|
||
|
def _get_value_and_col_name(
|
||
|
data: DataFrame,
|
||
|
value_or_name: Any,
|
||
|
default_value: Any,
|
||
|
) -> tuple[str, str | None]:
|
||
|
"""Take a value_or_name passed in by the Streamlit developer and return a PyDeck
|
||
|
argument and column name for that property.
|
||
|
|
||
|
This is used for the size and color properties of the chart.
|
||
|
|
||
|
Example:
|
||
|
- If the user passes size=None, this returns the default size value and no column.
|
||
|
- If the user passes size=42, this returns 42 and no column.
|
||
|
- If the user passes size="my_col_123", this returns "@@=my_col_123" and "my_col_123".
|
||
|
"""
|
||
|
|
||
|
pydeck_arg: str
|
||
|
|
||
|
if isinstance(value_or_name, str) and value_or_name in data.columns:
|
||
|
col_name = value_or_name
|
||
|
pydeck_arg = f"@@={col_name}"
|
||
|
else:
|
||
|
col_name = None
|
||
|
|
||
|
pydeck_arg = default_value if value_or_name is None else value_or_name
|
||
|
|
||
|
return pydeck_arg, col_name
|
||
|
|
||
|
|
||
|
def _convert_color_arg_or_column(
    data: DataFrame,
    color_arg: str,
    color_col_name: str | None,
) -> None | str | IntColorTuple:
    """Normalize the color argument (or color column) into a PyDeck-ready form.

    Two cases:
    - No color column: ``color_arg`` (e.g. "#fff") is converted with
      ``to_int_color_tuple`` (giving e.g. (255, 255, 255, 255)), and None
      stays None.
    - A color column: every cell of ``data[color_col_name]`` is converted
      with ``to_int_color_tuple`` in place, and ``color_arg`` is returned
      unchanged.

    NOTE: This function mutates the data argument.

    Raises
    ------
    StreamlitAPIException
        If the color column is empty or its first cell is not color-like.
    """
    if color_col_name is None:
        # A single color shared by every point, or no color at all.
        return None if color_arg is None else to_int_color_tuple(color_arg)

    # Only the first cell is inspected to decide whether the column holds colors.
    column = data[color_col_name]
    looks_like_colors = len(column) > 0 and is_color_like(column.iloc[0])
    if not looks_like_colors:
        raise StreamlitAPIException(
            f'Column "{color_col_name}" does not appear to contain valid colors.'
        )

    # Use .loc[] to avoid a SettingWithCopyWarning in some cases.
    data.loc[:, color_col_name] = data.loc[:, color_col_name].map(to_int_color_tuple)
    return color_arg
|
||
|
|
||
|
|
||
|
def _get_viewport_details(
|
||
|
data: DataFrame, lat_col_name: str, lon_col_name: str, zoom: int | None
|
||
|
) -> tuple[int, float, float]:
|
||
|
"""Auto-set viewport when not fully specified by user."""
|
||
|
min_lat = data[lat_col_name].min()
|
||
|
max_lat = data[lat_col_name].max()
|
||
|
min_lon = data[lon_col_name].min()
|
||
|
max_lon = data[lon_col_name].max()
|
||
|
center_lat = (max_lat + min_lat) / 2.0
|
||
|
center_lon = (max_lon + min_lon) / 2.0
|
||
|
range_lon = abs(max_lon - min_lon)
|
||
|
range_lat = abs(max_lat - min_lat)
|
||
|
|
||
|
if zoom is None:
|
||
|
longitude_distance = max(range_lat, range_lon)
|
||
|
zoom = _get_zoom_level(longitude_distance)
|
||
|
|
||
|
return zoom, center_lat, center_lon
|
||
|
|
||
|
|
||
|
def _get_zoom_level(distance: float) -> int:
    """Map a span in degrees of longitude to an OpenStreetMap zoom level.

    See https://wiki.openstreetmap.org/wiki/Zoom_levels for reference.

    Parameters
    ----------
    distance : float
        How many degrees of longitude should fit in the map.

    Returns
    -------
    int
        The index ``i`` such that
        ``_ZOOM_LEVELS[i + 1] < distance <= _ZOOM_LEVELS[i]`` (0 to 19), or
        ``_DEFAULT_ZOOM_LEVEL`` when no such bracket exists.
    """
    # Walk adjacent (wider, narrower) pairs. The zoom level is the index of
    # the wider bound.
    for level, (wider, narrower) in enumerate(zip(_ZOOM_LEVELS, _ZOOM_LEVELS[1:])):
        if narrower < distance <= wider:
            return level

    # No bracket matched: distance is 0 (e.g. a single point), at or below
    # the finest entry, above 360, or NaN. Fall back to the default zoom.
    return _DEFAULT_ZOOM_LEVEL
|
||
|
|
||
|
|
||
|
def marshall(
    pydeck_proto: DeckGlJsonChartProto,
    pydeck_json: str,
    use_container_width: bool,
    height: int | None = None,
    width: int | None = None,
) -> None:
    """Fill ``pydeck_proto`` in place with the map spec and display options.

    Parameters
    ----------
    pydeck_proto : DeckGlJsonChartProto
        The proto to populate.
    pydeck_json : str
        The deck.gl JSON specification (as produced by ``to_deckgl_json``).
    use_container_width : bool
        Copied verbatim onto the proto.
    height, width : int or None
        Pixel dimensions. Falsy values (None or 0) leave the proto field unset.
    """
    pydeck_proto.json = pydeck_json
    pydeck_proto.use_container_width = use_container_width

    # Only set dimensions that were actually provided (None/0 are skipped).
    for attr_name, size_px in (("width", width), ("height", height)):
        if size_px:
            setattr(pydeck_proto, attr_name, size_px)

    pydeck_proto.id = ""

    # Copy a Mapbox token from the "mapbox.token" config option, if one is set.
    if mapbox_token := config.get_option("mapbox.token"):
        pydeck_proto.mapbox_token = mapbox_token
|