From 0637b120cda276e15375c2fa6db46f5feac267c4 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Laura=20Kl=C3=BCnder?= Date: Tue, 17 Dec 2024 22:53:28 +0000 Subject: [PATCH] more router typehinting --- src/c3nav/routing/router.py | 78 +++++++++++++++++++++++-------------- 1 file changed, 48 insertions(+), 30 deletions(-) diff --git a/src/c3nav/routing/router.py b/src/c3nav/routing/router.py index 32202a57..2bcc5e33 100644 --- a/src/c3nav/routing/router.py +++ b/src/c3nav/routing/router.py @@ -5,7 +5,7 @@ from collections import deque, namedtuple from dataclasses import dataclass, field from functools import reduce from itertools import chain -from typing import Optional, TypeVar, Generic, Mapping, Any, Sequence, TypeAlias +from typing import Optional, TypeVar, Generic, Mapping, Any, Sequence, TypeAlias, ClassVar import numpy as np from django.conf import settings @@ -38,21 +38,20 @@ PointCompatible: TypeAlias = Point | CustomLocation | CustomLocationProxyMixin EdgeIndex: TypeAlias = tuple[int, int] +@dataclass class Router: - filename = settings.CACHE_ROOT / 'router' + filename: ClassVar = settings.CACHE_ROOT / 'router' - def __init__(self, levels, spaces, areas, pois, groups, restrictions: dict[int, "RouterRestriction"], - nodes, edges, waytypes, graph): - self.levels = levels - self.spaces = spaces - self.areas = areas - self.pois = pois - self.groups = groups - self.restrictions = restrictions - self.nodes = nodes - self.edges = edges - self.waytypes = waytypes - self.graph = graph + levels: dict[int, "RouterLevel"] + spaces: dict[int, "RouterSpace"] + areas: dict[int, "RouterArea"] + pois: dict[int, "RouterPoint"] + groups: dict[int, "RouterGroup"] + restrictions: dict[int, "RouterRestriction"] + nodes: deque["RouterNode"] + edghes: dict[EdgeIndex, "RouterEdge"] + waytypes: dict[int, "RouterWayType"] + graph: np.ndarray @staticmethod def get_altitude_in_areas(areas, point): @@ -66,20 +65,20 @@ class Router: 'spaces__graphnodes', 'spaces__areas', 'spaces__areas__groups', 'spaces__pois', 'spaces__pois__groups') - levels = {} - spaces = {} - areas = {} - pois = {} - groups = {} - restrictions = {} - nodes = deque() + levels: dict[int, RouterLevel] = {} + spaces: dict[int, RouterSpace] = {} + areas: dict[int, RouterArea] = {} + pois: dict[int, RouterPoint] = {} + groups: dict[int, RouterGroup] = {} + restrictions: dict[int, RouterRestriction] = {} + nodes: deque[RouterNode] = deque() for level in levels_query: buildings_geom = unary_union(tuple(unwrap_geom(building.geometry) for building in level.buildings.all())) nodes_before_count = len(nodes) for group in level.groups.all(): - groups.setdefault(group.pk, {}).setdefault('levels', set()).add(level.pk) + groups.setdefault(group.pk, RouterGroup()).levels.add(level.pk) if level.access_restriction_id: restrictions.setdefault(level.access_restriction_id, RouterRestriction()).spaces.update( @@ -103,7 +102,7 @@ class Router: clear_geom_prep = prepared.prep(clear_geom) for group in space.groups.all(): - groups.setdefault(group.pk, {}).setdefault('spaces', set()).add(space.pk) + groups.setdefault(group.pk, RouterGroup()).spaces.add(space.pk) if space.access_restriction_id: restrictions.setdefault(space.access_restriction_id, RouterRestriction()).spaces.add(space.pk) @@ -120,7 +119,7 @@ class Router: for area in space_obj.areas.all(): for group in area.groups.all(): - groups.setdefault(group.pk, {}).setdefault('areas', set()).add(area.pk) + groups.setdefault(group.pk, RouterGroup()).areas.add(area.pk) area._prefetched_objects_cache = {} area = RouterArea(area) @@ -199,7 +198,7 @@ class Router: for poi in space_obj.pois.all(): for group in poi.groups.all(): - groups.setdefault(group.pk, {}).setdefault('pois', set()).add(poi.pk) + groups.setdefault(group.pk, RouterGroup()).pois.add(poi.pk) poi._prefetched_objects_cache = {} poi = RouterPoint(poi) @@ -293,7 +292,18 @@ class Router: for restriction in restrictions.values(): restriction.edges = np.array(restriction.edges, dtype=np.uint32).reshape((-1, 2)) - router = cls(levels, spaces, areas, pois, groups, restrictions, nodes, edges, waytypes, graph) + router = cls( + levels=levels, + spaces=spaces, + areas=areas, + pois=pois, + groups=groups, + restrictions=restrictions, + nodes=nodes, + edges=edges, + waytypes=waytypes, + graph=graph + ) pickle.dump(router, open(cls.build_filename(update), 'wb')) return router @@ -346,13 +356,13 @@ class Router: raise NotYetRoutable group = self.groups[location.pk] locations = tuple(chain( - (level for level in (self.levels[pk] for pk in group.get('levels', ())) + (level for level in (self.levels[pk] for pk in group.levels) if level.access_restriction_id not in restrictions), - (space for space in (self.spaces[pk] for pk in group.get('spaces', ())) + (space for space in (self.spaces[pk] for pk in group.spaces) if space.pk not in restrictions.spaces), - (area for area in (self.areas[pk] for pk in group.get('areas', ())) + (area for area in (self.areas[pk] for pk in group.areas) if area.space_id not in restrictions.spaces and area.access_restriction_id not in restrictions), - (poi for poi in (self.pois[pk] for pk in group.get('pois', ())) + (poi for poi in (self.pois[pk] for pk in group.pois) if poi.space_id not in restrictions.spaces and poi.access_restriction_id not in restrictions), )) elif isinstance(location, (CustomLocation, CustomLocationProxyMixin)): @@ -651,6 +661,14 @@ class RouterPoint(BaseRouterProxy[Point]): return np.array((self.x, self.y, self.altitude)) +@dataclass +class RouterGroup: + levels: set[int] = field(default_factory=set) + spaces: set[int] = field(default_factory=set) + areas: set[int] = field(default_factory=set) + pois: set[int] = field(default_factory=set) + + @dataclass class RouterAltitudeArea: geometry: Polygon | MultiPolygon