more router typehinting

This commit is contained in:
Laura Klünder 2024-12-17 22:53:28 +00:00
parent 20b746e814
commit 0637b120cd

View file

@ -5,7 +5,7 @@ from collections import deque, namedtuple
from dataclasses import dataclass, field from dataclasses import dataclass, field
from functools import reduce from functools import reduce
from itertools import chain from itertools import chain
from typing import Optional, TypeVar, Generic, Mapping, Any, Sequence, TypeAlias from typing import Optional, TypeVar, Generic, Mapping, Any, Sequence, TypeAlias, ClassVar
import numpy as np import numpy as np
from django.conf import settings from django.conf import settings
@ -38,21 +38,20 @@ PointCompatible: TypeAlias = Point | CustomLocation | CustomLocationProxyMixin
EdgeIndex: TypeAlias = tuple[int, int] EdgeIndex: TypeAlias = tuple[int, int]
@dataclass
class Router: class Router:
filename = settings.CACHE_ROOT / 'router' filename: ClassVar = settings.CACHE_ROOT / 'router'
def __init__(self, levels, spaces, areas, pois, groups, restrictions: dict[int, "RouterRestriction"], levels: dict[int, "RouterLevel"]
nodes, edges, waytypes, graph): spaces: dict[int, "RouterSpace"]
self.levels = levels areas: dict[int, "RouterArea"]
self.spaces = spaces pois: dict[int, "RouterPoint"]
self.areas = areas groups: dict[int, "RouterGroup"]
self.pois = pois restrictions: dict[int, "RouterRestriction"]
self.groups = groups nodes: deque["RouterNode"]
self.restrictions = restrictions edghes: dict[EdgeIndex, "RouterEdge"]
self.nodes = nodes waytypes: dict[int, "RouterWayType"]
self.edges = edges graph: np.ndarray
self.waytypes = waytypes
self.graph = graph
@staticmethod @staticmethod
def get_altitude_in_areas(areas, point): def get_altitude_in_areas(areas, point):
@ -66,20 +65,20 @@ class Router:
'spaces__graphnodes', 'spaces__areas', 'spaces__areas__groups', 'spaces__graphnodes', 'spaces__areas', 'spaces__areas__groups',
'spaces__pois', 'spaces__pois__groups') 'spaces__pois', 'spaces__pois__groups')
levels = {} levels: dict[int, RouterLevel] = {}
spaces = {} spaces: dict[int, RouterSpace] = {}
areas = {} areas: dict[int, RouterArea] = {}
pois = {} pois: dict[int, RouterPoint] = {}
groups = {} groups: dict[int, RouterGroup] = {}
restrictions = {} restrictions: dict[int, RouterRestriction] = {}
nodes = deque() nodes: deque[RouterNode] = deque()
for level in levels_query: for level in levels_query:
buildings_geom = unary_union(tuple(unwrap_geom(building.geometry) for building in level.buildings.all())) buildings_geom = unary_union(tuple(unwrap_geom(building.geometry) for building in level.buildings.all()))
nodes_before_count = len(nodes) nodes_before_count = len(nodes)
for group in level.groups.all(): for group in level.groups.all():
groups.setdefault(group.pk, {}).setdefault('levels', set()).add(level.pk) groups.setdefault(group.pk, RouterGroup()).levels.add(level.pk)
if level.access_restriction_id: if level.access_restriction_id:
restrictions.setdefault(level.access_restriction_id, RouterRestriction()).spaces.update( restrictions.setdefault(level.access_restriction_id, RouterRestriction()).spaces.update(
@ -103,7 +102,7 @@ class Router:
clear_geom_prep = prepared.prep(clear_geom) clear_geom_prep = prepared.prep(clear_geom)
for group in space.groups.all(): for group in space.groups.all():
groups.setdefault(group.pk, {}).setdefault('spaces', set()).add(space.pk) groups.setdefault(group.pk, RouterGroup()).spaces.add(space.pk)
if space.access_restriction_id: if space.access_restriction_id:
restrictions.setdefault(space.access_restriction_id, RouterRestriction()).spaces.add(space.pk) restrictions.setdefault(space.access_restriction_id, RouterRestriction()).spaces.add(space.pk)
@ -120,7 +119,7 @@ class Router:
for area in space_obj.areas.all(): for area in space_obj.areas.all():
for group in area.groups.all(): for group in area.groups.all():
groups.setdefault(group.pk, {}).setdefault('areas', set()).add(area.pk) groups.setdefault(group.pk, RouterGroup()).areas.add(area.pk)
area._prefetched_objects_cache = {} area._prefetched_objects_cache = {}
area = RouterArea(area) area = RouterArea(area)
@ -199,7 +198,7 @@ class Router:
for poi in space_obj.pois.all(): for poi in space_obj.pois.all():
for group in poi.groups.all(): for group in poi.groups.all():
groups.setdefault(group.pk, {}).setdefault('pois', set()).add(poi.pk) groups.setdefault(group.pk, RouterGroup()).pois.add(poi.pk)
poi._prefetched_objects_cache = {} poi._prefetched_objects_cache = {}
poi = RouterPoint(poi) poi = RouterPoint(poi)
@ -293,7 +292,18 @@ class Router:
for restriction in restrictions.values(): for restriction in restrictions.values():
restriction.edges = np.array(restriction.edges, dtype=np.uint32).reshape((-1, 2)) restriction.edges = np.array(restriction.edges, dtype=np.uint32).reshape((-1, 2))
router = cls(levels, spaces, areas, pois, groups, restrictions, nodes, edges, waytypes, graph) router = cls(
levels=levels,
spaces=spaces,
areas=areas,
pois=pois,
groups=groups,
restrictions=restrictions,
nodes=nodes,
edges=edges,
waytypes=waytypes,
graph=graph
)
pickle.dump(router, open(cls.build_filename(update), 'wb')) pickle.dump(router, open(cls.build_filename(update), 'wb'))
return router return router
@ -346,13 +356,13 @@ class Router:
raise NotYetRoutable raise NotYetRoutable
group = self.groups[location.pk] group = self.groups[location.pk]
locations = tuple(chain( locations = tuple(chain(
(level for level in (self.levels[pk] for pk in group.get('levels', ())) (level for level in (self.levels[pk] for pk in group.levels)
if level.access_restriction_id not in restrictions), if level.access_restriction_id not in restrictions),
(space for space in (self.spaces[pk] for pk in group.get('spaces', ())) (space for space in (self.spaces[pk] for pk in group.spaces)
if space.pk not in restrictions.spaces), if space.pk not in restrictions.spaces),
(area for area in (self.areas[pk] for pk in group.get('areas', ())) (area for area in (self.areas[pk] for pk in group.areas)
if area.space_id not in restrictions.spaces and area.access_restriction_id not in restrictions), if area.space_id not in restrictions.spaces and area.access_restriction_id not in restrictions),
(poi for poi in (self.pois[pk] for pk in group.get('pois', ())) (poi for poi in (self.pois[pk] for pk in group.pois)
if poi.space_id not in restrictions.spaces and poi.access_restriction_id not in restrictions), if poi.space_id not in restrictions.spaces and poi.access_restriction_id not in restrictions),
)) ))
elif isinstance(location, (CustomLocation, CustomLocationProxyMixin)): elif isinstance(location, (CustomLocation, CustomLocationProxyMixin)):
@ -651,6 +661,14 @@ class RouterPoint(BaseRouterProxy[Point]):
return np.array((self.x, self.y, self.altitude)) return np.array((self.x, self.y, self.altitude))
@dataclass
class RouterGroup:
levels: set[int] = field(default_factory=set)
spaces: set[int] = field(default_factory=set)
areas: set[int] = field(default_factory=set)
pois: set[int] = field(default_factory=set)
@dataclass @dataclass
class RouterAltitudeArea: class RouterAltitudeArea:
geometry: Polygon | MultiPolygon geometry: Polygon | MultiPolygon