more router typehinting
This commit is contained in:
parent
20b746e814
commit
0637b120cd
1 changed files with 48 additions and 30 deletions
|
@ -5,7 +5,7 @@ from collections import deque, namedtuple
|
||||||
from dataclasses import dataclass, field
|
from dataclasses import dataclass, field
|
||||||
from functools import reduce
|
from functools import reduce
|
||||||
from itertools import chain
|
from itertools import chain
|
||||||
from typing import Optional, TypeVar, Generic, Mapping, Any, Sequence, TypeAlias
|
from typing import Optional, TypeVar, Generic, Mapping, Any, Sequence, TypeAlias, ClassVar
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
from django.conf import settings
|
from django.conf import settings
|
||||||
|
@ -38,21 +38,20 @@ PointCompatible: TypeAlias = Point | CustomLocation | CustomLocationProxyMixin
|
||||||
EdgeIndex: TypeAlias = tuple[int, int]
|
EdgeIndex: TypeAlias = tuple[int, int]
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
class Router:
|
class Router:
|
||||||
filename = settings.CACHE_ROOT / 'router'
|
filename: ClassVar = settings.CACHE_ROOT / 'router'
|
||||||
|
|
||||||
def __init__(self, levels, spaces, areas, pois, groups, restrictions: dict[int, "RouterRestriction"],
|
levels: dict[int, "RouterLevel"]
|
||||||
nodes, edges, waytypes, graph):
|
spaces: dict[int, "RouterSpace"]
|
||||||
self.levels = levels
|
areas: dict[int, "RouterArea"]
|
||||||
self.spaces = spaces
|
pois: dict[int, "RouterPoint"]
|
||||||
self.areas = areas
|
groups: dict[int, "RouterGroup"]
|
||||||
self.pois = pois
|
restrictions: dict[int, "RouterRestriction"]
|
||||||
self.groups = groups
|
nodes: deque["RouterNode"]
|
||||||
self.restrictions = restrictions
|
edghes: dict[EdgeIndex, "RouterEdge"]
|
||||||
self.nodes = nodes
|
waytypes: dict[int, "RouterWayType"]
|
||||||
self.edges = edges
|
graph: np.ndarray
|
||||||
self.waytypes = waytypes
|
|
||||||
self.graph = graph
|
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def get_altitude_in_areas(areas, point):
|
def get_altitude_in_areas(areas, point):
|
||||||
|
@ -66,20 +65,20 @@ class Router:
|
||||||
'spaces__graphnodes', 'spaces__areas', 'spaces__areas__groups',
|
'spaces__graphnodes', 'spaces__areas', 'spaces__areas__groups',
|
||||||
'spaces__pois', 'spaces__pois__groups')
|
'spaces__pois', 'spaces__pois__groups')
|
||||||
|
|
||||||
levels = {}
|
levels: dict[int, RouterLevel] = {}
|
||||||
spaces = {}
|
spaces: dict[int, RouterSpace] = {}
|
||||||
areas = {}
|
areas: dict[int, RouterArea] = {}
|
||||||
pois = {}
|
pois: dict[int, RouterPoint] = {}
|
||||||
groups = {}
|
groups: dict[int, RouterGroup] = {}
|
||||||
restrictions = {}
|
restrictions: dict[int, RouterRestriction] = {}
|
||||||
nodes = deque()
|
nodes: deque[RouterNode] = deque()
|
||||||
for level in levels_query:
|
for level in levels_query:
|
||||||
buildings_geom = unary_union(tuple(unwrap_geom(building.geometry) for building in level.buildings.all()))
|
buildings_geom = unary_union(tuple(unwrap_geom(building.geometry) for building in level.buildings.all()))
|
||||||
|
|
||||||
nodes_before_count = len(nodes)
|
nodes_before_count = len(nodes)
|
||||||
|
|
||||||
for group in level.groups.all():
|
for group in level.groups.all():
|
||||||
groups.setdefault(group.pk, {}).setdefault('levels', set()).add(level.pk)
|
groups.setdefault(group.pk, RouterGroup()).levels.add(level.pk)
|
||||||
|
|
||||||
if level.access_restriction_id:
|
if level.access_restriction_id:
|
||||||
restrictions.setdefault(level.access_restriction_id, RouterRestriction()).spaces.update(
|
restrictions.setdefault(level.access_restriction_id, RouterRestriction()).spaces.update(
|
||||||
|
@ -103,7 +102,7 @@ class Router:
|
||||||
clear_geom_prep = prepared.prep(clear_geom)
|
clear_geom_prep = prepared.prep(clear_geom)
|
||||||
|
|
||||||
for group in space.groups.all():
|
for group in space.groups.all():
|
||||||
groups.setdefault(group.pk, {}).setdefault('spaces', set()).add(space.pk)
|
groups.setdefault(group.pk, RouterGroup()).spaces.add(space.pk)
|
||||||
|
|
||||||
if space.access_restriction_id:
|
if space.access_restriction_id:
|
||||||
restrictions.setdefault(space.access_restriction_id, RouterRestriction()).spaces.add(space.pk)
|
restrictions.setdefault(space.access_restriction_id, RouterRestriction()).spaces.add(space.pk)
|
||||||
|
@ -120,7 +119,7 @@ class Router:
|
||||||
|
|
||||||
for area in space_obj.areas.all():
|
for area in space_obj.areas.all():
|
||||||
for group in area.groups.all():
|
for group in area.groups.all():
|
||||||
groups.setdefault(group.pk, {}).setdefault('areas', set()).add(area.pk)
|
groups.setdefault(group.pk, RouterGroup()).areas.add(area.pk)
|
||||||
area._prefetched_objects_cache = {}
|
area._prefetched_objects_cache = {}
|
||||||
|
|
||||||
area = RouterArea(area)
|
area = RouterArea(area)
|
||||||
|
@ -199,7 +198,7 @@ class Router:
|
||||||
|
|
||||||
for poi in space_obj.pois.all():
|
for poi in space_obj.pois.all():
|
||||||
for group in poi.groups.all():
|
for group in poi.groups.all():
|
||||||
groups.setdefault(group.pk, {}).setdefault('pois', set()).add(poi.pk)
|
groups.setdefault(group.pk, RouterGroup()).pois.add(poi.pk)
|
||||||
poi._prefetched_objects_cache = {}
|
poi._prefetched_objects_cache = {}
|
||||||
|
|
||||||
poi = RouterPoint(poi)
|
poi = RouterPoint(poi)
|
||||||
|
@ -293,7 +292,18 @@ class Router:
|
||||||
for restriction in restrictions.values():
|
for restriction in restrictions.values():
|
||||||
restriction.edges = np.array(restriction.edges, dtype=np.uint32).reshape((-1, 2))
|
restriction.edges = np.array(restriction.edges, dtype=np.uint32).reshape((-1, 2))
|
||||||
|
|
||||||
router = cls(levels, spaces, areas, pois, groups, restrictions, nodes, edges, waytypes, graph)
|
router = cls(
|
||||||
|
levels=levels,
|
||||||
|
spaces=spaces,
|
||||||
|
areas=areas,
|
||||||
|
pois=pois,
|
||||||
|
groups=groups,
|
||||||
|
restrictions=restrictions,
|
||||||
|
nodes=nodes,
|
||||||
|
edges=edges,
|
||||||
|
waytypes=waytypes,
|
||||||
|
graph=graph
|
||||||
|
)
|
||||||
pickle.dump(router, open(cls.build_filename(update), 'wb'))
|
pickle.dump(router, open(cls.build_filename(update), 'wb'))
|
||||||
return router
|
return router
|
||||||
|
|
||||||
|
@ -346,13 +356,13 @@ class Router:
|
||||||
raise NotYetRoutable
|
raise NotYetRoutable
|
||||||
group = self.groups[location.pk]
|
group = self.groups[location.pk]
|
||||||
locations = tuple(chain(
|
locations = tuple(chain(
|
||||||
(level for level in (self.levels[pk] for pk in group.get('levels', ()))
|
(level for level in (self.levels[pk] for pk in group.levels)
|
||||||
if level.access_restriction_id not in restrictions),
|
if level.access_restriction_id not in restrictions),
|
||||||
(space for space in (self.spaces[pk] for pk in group.get('spaces', ()))
|
(space for space in (self.spaces[pk] for pk in group.spaces)
|
||||||
if space.pk not in restrictions.spaces),
|
if space.pk not in restrictions.spaces),
|
||||||
(area for area in (self.areas[pk] for pk in group.get('areas', ()))
|
(area for area in (self.areas[pk] for pk in group.areas)
|
||||||
if area.space_id not in restrictions.spaces and area.access_restriction_id not in restrictions),
|
if area.space_id not in restrictions.spaces and area.access_restriction_id not in restrictions),
|
||||||
(poi for poi in (self.pois[pk] for pk in group.get('pois', ()))
|
(poi for poi in (self.pois[pk] for pk in group.pois)
|
||||||
if poi.space_id not in restrictions.spaces and poi.access_restriction_id not in restrictions),
|
if poi.space_id not in restrictions.spaces and poi.access_restriction_id not in restrictions),
|
||||||
))
|
))
|
||||||
elif isinstance(location, (CustomLocation, CustomLocationProxyMixin)):
|
elif isinstance(location, (CustomLocation, CustomLocationProxyMixin)):
|
||||||
|
@ -651,6 +661,14 @@ class RouterPoint(BaseRouterProxy[Point]):
|
||||||
return np.array((self.x, self.y, self.altitude))
|
return np.array((self.x, self.y, self.altitude))
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class RouterGroup:
|
||||||
|
levels: set[int] = field(default_factory=set)
|
||||||
|
spaces: set[int] = field(default_factory=set)
|
||||||
|
areas: set[int] = field(default_factory=set)
|
||||||
|
pois: set[int] = field(default_factory=set)
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class RouterAltitudeArea:
|
class RouterAltitudeArea:
|
||||||
geometry: Polygon | MultiPolygon
|
geometry: Polygon | MultiPolygon
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue