From e1a1ae8bc4538efadb02bd48730c15f0d31927a5 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Laura=20Kl=C3=BCnder?= Date: Fri, 8 Apr 2022 00:00:55 +0200 Subject: [PATCH] some more type hinting magic --- src/c3nav/mapdata/render/geometry/hybrid.py | 9 ++++++--- src/c3nav/mapdata/utils/mpl.py | 18 ++++++++++++------ 2 files changed, 18 insertions(+), 9 deletions(-) diff --git a/src/c3nav/mapdata/render/geometry/hybrid.py b/src/c3nav/mapdata/render/geometry/hybrid.py index 1b0208a0..48fc1125 100644 --- a/src/c3nav/mapdata/render/geometry/hybrid.py +++ b/src/c3nav/mapdata/render/geometry/hybrid.py @@ -29,7 +29,7 @@ def hybrid_union(geoms): return HybridGeometry(geom=unary_union(tuple(geom.geom for geom in geoms)), faces=tuple(chain(*(geom.faces for geom in geoms))), add_faces=add_faces, - crop_ids=reduce(operator.or_, (other.crop_ids for other in geoms), set())) + crop_ids=reduce(operator.or_, (other.crop_ids for other in geoms), frozenset())) THybridGeometry = TypeVar("THybridGeometry", bound="HybridGeometry") @@ -51,7 +51,7 @@ class HybridGeometry: add_faces: dict = field(default_factory=dict) # todo: specify type more precisely @classmethod - def create(cls, geom, face_centers) -> THybridGeometry: + def create(cls, geom, face_centers: np.ndarray[tuple[int, Literal[2]], np.uint32]) -> THybridGeometry: """ Create from existing facets and just select the ones that lie inside this polygon. """ @@ -61,7 +61,10 @@ class HybridGeometry: set(np.argwhere(shapely_to_mpl(subgeom).contains_points(face_centers)).flatten()) for subgeom in assert_multipolygon(geom) ) - return HybridGeometry(geom, tuple(f for f in faces if f)) # todo: wtf? that is the wrong typing + + faces = tuple(reduce(operator.or_, faces, set())) + return HybridGeometry(geom, faces) # old code had wrong typing + # return HybridGeometry(geom, tuple(f for f in faces if f)) # old code had wrong typing @classmethod def create_full(cls, geom: BaseGeometry, diff --git a/src/c3nav/mapdata/utils/mpl.py b/src/c3nav/mapdata/utils/mpl.py index 54d8c776..969bf98d 100644 --- a/src/c3nav/mapdata/utils/mpl.py +++ b/src/c3nav/mapdata/utils/mpl.py @@ -1,5 +1,6 @@ from abc import ABC, abstractmethod from dataclasses import InitVar, dataclass, field +from typing import Literal import numpy as np from matplotlib.path import Path @@ -11,11 +12,15 @@ from c3nav.mapdata.utils.geometry import assert_multipolygon class MplPathProxy(ABC): @abstractmethod - def intersects_path(self, path): + def intersects_path(self, path: Path) -> bool: pass @abstractmethod - def contains_point(self, point): + def contains_point(self, point: tuple[int, int]) -> bool: + pass + + @abstractmethod + def contains_points(self, points: np.ndarray[tuple[int, Literal[2]], np.uint32]) -> bool: pass @@ -23,11 +28,11 @@ class MplPathProxy(ABC): class MplPolygonPath(MplPathProxy): polygon: InitVar[Polygon] exterior: Path = field(init=False) - interiors: list[Path] = field(init=False) + interiors: tuple[Path, ...] = field(init=False) def __post_init__(self, polygon): self.exterior = linearring_to_mpl_path(polygon.exterior) - self.interiors = [linearring_to_mpl_path(interior) for interior in polygon.interiors] + self.interiors = tuple(linearring_to_mpl_path(interior) for interior in polygon.interiors) @property def exteriors(self): @@ -52,6 +57,7 @@ class MplPolygonPath(MplPathProxy): return False def contains_points(self, points): + # noinspection PyTypeChecker result = self.exterior.contains_points(points) for interior in self.interiors: if not result.any(): @@ -72,11 +78,11 @@ class MplPolygonPath(MplPathProxy): @dataclass(slots=True) class MplMultipolygonPath(MplPathProxy): - polygons: list[MplPolygonPath] = field(init=False) + polygons: tuple[MplPolygonPath, ...] = field(init=False) polygons_: InitVar[Polygon | MultiPolygon | GeometryCollection] def __post_init__(self, polygons_): - self.polygons = [MplPolygonPath(polygon) for polygon in assert_multipolygon(polygons_)] + self.polygons = tuple(MplPolygonPath(polygon) for polygon in assert_multipolygon(polygons_)) @property def exteriors(self):