yet more typing magic

This commit is contained in:
Laura Klünder 2022-04-08 00:03:58 +02:00
parent e1a1ae8bc4
commit b8002a4aba

View file

@ -12,7 +12,7 @@ from c3nav.mapdata.utils.geometry import assert_multipolygon
class MplPathProxy(ABC):
@abstractmethod
def intersects_path(self, path: Path) -> bool:
def intersects_path(self, path: Path, filled: bool = False) -> bool:
pass
@abstractmethod
@ -38,7 +38,7 @@ class MplPolygonPath(MplPathProxy):
def exteriors(self):
return (self.exterior, )
def intersects_path(self, path, filled=False):
def intersects_path(self, path: Path, filled: bool = False) -> bool:
if filled:
if not self.exterior.intersects_path(path, filled=True):
return False
@ -56,7 +56,7 @@ class MplPolygonPath(MplPathProxy):
return True
return False
def contains_points(self, points):
def contains_points(self, points: np.ndarray[tuple[int, Literal[2]], np.uint32]) -> bool:
# noinspection PyTypeChecker
result = self.exterior.contains_points(points)
for interior in self.interiors:
@ -66,7 +66,7 @@ class MplPolygonPath(MplPathProxy):
result[ix] = np.logical_not(interior.contains_points(points[ix]))
return result
def contains_point(self, point):
def contains_point(self, point: tuple[int, int]) -> bool:
if not self.exterior.contains_point(point):
return False
@ -88,19 +88,19 @@ class MplMultipolygonPath(MplPathProxy):
def exteriors(self):
return tuple(polygon.exterior for polygon in self.polygons)
def intersects_path(self, path, filled=False):
def intersects_path(self, path: Path, filled: bool = False) -> bool:
for polygon in self.polygons:
if polygon.intersects_path(path, filled=filled):
return True
return False
def contains_point(self, point):
def contains_point(self, point: tuple[int, int]) -> bool:
for polygon in self.polygons:
if polygon.contains_point(point):
return True
return False
def contains_points(self, points):
def contains_points(self, points: np.ndarray[tuple[int, Literal[2]], np.uint32]) -> bool:
result = np.full((len(points),), fill_value=False, dtype=np.bool)
for polygon in self.polygons:
ix = np.argwhere(np.logical_not(result)).flatten()