diff --git a/src/c3nav/mapdata/render/geometry/hybrid.py b/src/c3nav/mapdata/render/geometry/hybrid.py index 1d1b1e71..1b0208a0 100644 --- a/src/c3nav/mapdata/render/geometry/hybrid.py +++ b/src/c3nav/mapdata/render/geometry/hybrid.py @@ -1,12 +1,17 @@ import operator from collections import deque +from dataclasses import dataclass, field from functools import reduce from itertools import chain +from typing import Literal, TypeVar import numpy as np -from shapely.geometry import GeometryCollection, LineString, MultiLineString, Point +from matplotlib.patches import Polygon +from shapely.geometry import GeometryCollection, LineString, MultiLineString, MultiPolygon, Point +from shapely.geometry.base import BaseGeometry from shapely.ops import unary_union +from c3nav.mapdata.render.geometry.mesh import Mesh from c3nav.mapdata.utils.geometry import assert_multipolygon from c3nav.mapdata.utils.mesh import triangulate_polygon from c3nav.mapdata.utils.mpl import shapely_to_mpl @@ -27,6 +32,10 @@ def hybrid_union(geoms): crop_ids=reduce(operator.or_, (other.crop_ids for other in geoms), set())) +THybridGeometry = TypeVar("THybridGeometry", bound="HybridGeometry") + + +@dataclass(slots=True) class HybridGeometry: """ A geometry containing a mesh as well as a shapely geometry, @@ -36,38 +45,40 @@ class HybridGeometry: - 2d mesh state where faces refers to indizes of faces from an external list - 3d mesh state where faces refers to Mesh instances """ - __slots__ = ('geom', 'faces', 'crop_ids', 'add_faces') - - def __init__(self, geom, faces, crop_ids=frozenset(), add_faces=None): - self.geom = geom - self.faces = faces - self.add_faces = add_faces or {} - self.crop_ids = crop_ids + geom: BaseGeometry + faces: tuple[int, ...] | tuple[Mesh, ...] + crop_ids: frozenset = field(default_factory=frozenset) # todo: specify type more precisely + add_faces: dict = field(default_factory=dict) # todo: specify type more precisely @classmethod - def create(cls, geom, face_centers): + def create(cls, geom, face_centers) -> THybridGeometry: """ Create from existing facets and just select the ones that lie inside this polygon. """ if isinstance(geom, (LineString, MultiLineString)): - return HybridGeometry(geom, set()) + return HybridGeometry(geom, ()) faces = tuple( set(np.argwhere(shapely_to_mpl(subgeom).contains_points(face_centers)).flatten()) for subgeom in assert_multipolygon(geom) ) - return HybridGeometry(geom, tuple(f for f in faces if f)) + return HybridGeometry(geom, tuple(f for f in faces if f)) # todo: wtf? that is the wrong typing @classmethod - def create_full(cls, geom, vertices_offset, faces_offset): + def create_full(cls, geom: BaseGeometry, + vertices_offset: int, faces_offset: int) -> tuple[THybridGeometry, + np.ndarray[tuple[int, Literal[2]], np.uint32], + np.ndarray[tuple[int, Literal[3]], np.uint32]]: """ Create by triangulating a polygon and adding the resulting facets to the total list. """ if isinstance(geom, (LineString, MultiLineString, Point)): - return HybridGeometry(geom, set()), np.empty((0, 2), dtype=np.int32), np.empty((0, 3), dtype=np.uint32) + return HybridGeometry(geom, tuple()), np.empty((0, 2), dtype=np.int32), np.empty((0, 3), dtype=np.uint32) - vertices = deque() - faces = deque() - faces_i = deque() + geom: Polygon | MultiPolygon | GeometryCollection + + vertices: deque = deque() + faces: deque = deque() + faces_i: deque = deque() for subgeom in assert_multipolygon(geom): new_vertices, new_faces = triangulate_polygon(subgeom) new_faces += vertices_offset @@ -78,10 +89,10 @@ class HybridGeometry: faces_offset += new_faces.shape[0] if not vertices: - return HybridGeometry(geom, set()), np.empty((0, 2), dtype=np.int32), np.empty((0, 3), dtype=np.uint32) + return HybridGeometry(geom, tuple()), np.empty((0, 2), dtype=np.int32), np.empty((0, 3), dtype=np.uint32) - vertices = np.vstack(vertices) - faces = np.vstack(faces) + vertices: np.ndarray[tuple[int, Literal[2]], np.uint32] = np.vstack(vertices) + faces: np.ndarray[tuple[int, Literal[3]], np.uint32] = np.vstack(faces) return HybridGeometry(geom, tuple(faces_i)), vertices, faces diff --git a/src/c3nav/mapdata/utils/geometry.py b/src/c3nav/mapdata/utils/geometry.py index 360427c5..1b39160b 100644 --- a/src/c3nav/mapdata/utils/geometry.py +++ b/src/c3nav/mapdata/utils/geometry.py @@ -87,7 +87,7 @@ def clean_geometry(geometry): return geometry -def assert_multipolygon(geometry: Union[Polygon, MultiPolygon, GeometryCollection]) -> List[Polygon]: +def assert_multipolygon(geometry: Polygon | MultiPolygon | GeometryCollection) -> list[Polygon]: """ given a Polygon or a MultiPolygon, return a list of Polygons :param geometry: a Polygon or a MultiPolygon @@ -100,7 +100,7 @@ def assert_multipolygon(geometry: Union[Polygon, MultiPolygon, GeometryCollectio return [geom for geom in geometry.geoms if isinstance(geom, Polygon)] -def assert_multilinestring(geometry: Union[LineString, MultiLineString, GeometryCollection]) -> List[LineString]: +def assert_multilinestring(geometry: LineString | MultiLineString | GeometryCollection) -> list[LineString]: """ given a LineString or MultiLineString, return a list of LineStrings :param geometry: a LineString or a MultiLineString diff --git a/src/c3nav/mapdata/utils/mesh.py b/src/c3nav/mapdata/utils/mesh.py index 3fca2cb7..4cd237df 100644 --- a/src/c3nav/mapdata/utils/mesh.py +++ b/src/c3nav/mapdata/utils/mesh.py @@ -1,7 +1,7 @@ from collections import deque from functools import lru_cache from itertools import chain -from typing import Union +from typing import Literal, Union import numpy as np from meshpy import triangle @@ -15,7 +15,8 @@ def get_face_indizes(start, length): return np.vstack((indices, (indices[-1][-1], indices[0][0]))) -def triangulate_rings(rings, holes=None): +def triangulate_rings(rings, holes=None) -> tuple[np.ndarray[tuple[int, Literal[2]], np.uint32], + np.ndarray[tuple[int, Literal[3]], np.uint32]]: return ( np.zeros((0, 2), dtype=np.uint32), np.zeros((0, 3), dtype=np.uint32), @@ -62,7 +63,8 @@ def triangulate_rings(rings, holes=None): return mesh_points, mesh_elements -def _triangulate_polygon(polygon: Polygon, keep_holes=False): +def _triangulate_polygon(polygon: Polygon, keep_holes=False) -> tuple[np.ndarray[tuple[int, Literal[2]], np.uint32], + np.ndarray[tuple[int, Literal[3]], np.uint32]]: holes = None if not keep_holes: holes = np.array(tuple( @@ -73,7 +75,9 @@ def _triangulate_polygon(polygon: Polygon, keep_holes=False): return triangulate_rings((polygon.exterior, *polygon.interiors), holes) -def triangulate_polygon(geometry: Union[Polygon, MultiPolygon], keep_holes=False): +def triangulate_polygon(geometry: Union[Polygon, MultiPolygon], + keep_holes=False) -> tuple[np.ndarray[tuple[int, Literal[2]], np.uint32], + np.ndarray[tuple[int, Literal[3]], np.uint32]]: if isinstance(geometry, Polygon): return _triangulate_polygon(geometry, keep_holes=keep_holes) diff --git a/src/c3nav/mapdata/utils/mpl.py b/src/c3nav/mapdata/utils/mpl.py index ce5e60ff..54d8c776 100644 --- a/src/c3nav/mapdata/utils/mpl.py +++ b/src/c3nav/mapdata/utils/mpl.py @@ -1,8 +1,10 @@ from abc import ABC, abstractmethod +from dataclasses import InitVar, dataclass, field import numpy as np from matplotlib.path import Path -from shapely.geometry import GeometryCollection, MultiPolygon, Polygon +from shapely.geometry import GeometryCollection, LinearRing, MultiPolygon, Polygon +from shapely.geometry.base import BaseGeometry from c3nav.mapdata.utils.geometry import assert_multipolygon @@ -17,40 +19,13 @@ class MplPathProxy(ABC): pass -class MplMultipolygonPath(MplPathProxy): - __slots__ = ('polygons') - - def __init__(self, polygon): - self.polygons = tuple(MplPolygonPath(polygon) for polygon in assert_multipolygon(polygon)) - - @property - def exteriors(self): - return tuple(polygon.exterior for polygon in self.polygons) - - def intersects_path(self, path, filled=False): - for polygon in self.polygons: - if polygon.intersects_path(path, filled=filled): - return True - return False - - def contains_point(self, point): - for polygon in self.polygons: - if polygon.contains_point(point): - return True - return False - - def contains_points(self, points): - result = np.full((len(points),), fill_value=False, dtype=np.bool) - for polygon in self.polygons: - ix = np.argwhere(np.logical_not(result)).flatten() - result[ix] = polygon.contains_points(points[ix]) - return result - - +@dataclass(slots=True) class MplPolygonPath(MplPathProxy): - __slots__ = ('exterior', 'interiors') + polygon: InitVar[Polygon] + exterior: Path = field(init=False) + interiors: list[Path] = field(init=False) - def __init__(self, polygon): + def __post_init__(self, polygon): self.exterior = linearring_to_mpl_path(polygon.exterior) self.interiors = [linearring_to_mpl_path(interior) for interior in polygon.interiors] @@ -95,7 +70,39 @@ class MplPolygonPath(MplPathProxy): return True -def shapely_to_mpl(geometry): +@dataclass(slots=True) +class MplMultipolygonPath(MplPathProxy): + polygons: list[MplPolygonPath] = field(init=False) + polygons_: InitVar[Polygon | MultiPolygon | GeometryCollection] + + def __post_init__(self, polygons_): + self.polygons = [MplPolygonPath(polygon) for polygon in assert_multipolygon(polygons_)] + + @property + def exteriors(self): + return tuple(polygon.exterior for polygon in self.polygons) + + def intersects_path(self, path, filled=False): + for polygon in self.polygons: + if polygon.intersects_path(path, filled=filled): + return True + return False + + def contains_point(self, point): + for polygon in self.polygons: + if polygon.contains_point(point): + return True + return False + + def contains_points(self, points): + result = np.full((len(points),), fill_value=False, dtype=np.bool) + for polygon in self.polygons: + ix = np.argwhere(np.logical_not(result)).flatten() + result[ix] = polygon.contains_points(points[ix]) + return result + + +def shapely_to_mpl(geometry: BaseGeometry) -> MplPathProxy: """ convert a shapely Polygon or Multipolygon to a matplotlib Path :param geometry: shapely Polygon or Multipolygon @@ -108,6 +115,6 @@ def shapely_to_mpl(geometry): raise TypeError -def linearring_to_mpl_path(linearring): +def linearring_to_mpl_path(linearring: LinearRing) -> Path: return Path(np.array(linearring.coords), (Path.MOVETO, *([Path.LINETO] * (len(linearring.coords)-2)), Path.CLOSEPOLY), readonly=True)