diff --git a/src/c3nav/mapdata/api.py b/src/c3nav/mapdata/api.py index 64f208ae..4117747e 100644 --- a/src/c3nav/mapdata/api.py +++ b/src/c3nav/mapdata/api.py @@ -1,7 +1,7 @@ import mimetypes from itertools import chain -from django.db.models import Q +from django.db.models import Prefetch, Q from django.http import HttpResponse from django.shortcuts import redirect from django.utils.translation import ugettext_lazy as _ @@ -15,12 +15,18 @@ from c3nav.mapdata.models import Building, Door, Hole, LocationGroup, Source, Sp from c3nav.mapdata.models.geometry.level import LEVEL_MODELS from c3nav.mapdata.models.geometry.space import SPACE_MODELS, Area, Column, LineObstacle, Obstacle, Point, Stair from c3nav.mapdata.models.level import Level -from c3nav.mapdata.models.locations import LOCATION_MODELS, Location, LocationRedirect, LocationSlug +from c3nav.mapdata.models.locations import LOCATION_MODELS, Location, LocationRedirect, LocationSlug, SpecificLocation + + +def optimize_query(qs): + if issubclass(qs.model, SpecificLocation): + qs = qs.prefetch_related(Prefetch('groups', queryset=LocationGroup.objects.only('id'))) + return qs class MapdataViewSet(ReadOnlyModelViewSet): def list(self, request, *args, **kwargs): - qs = self.get_queryset() + qs = optimize_query(self.get_queryset()) geometry = ('geometry' in request.GET) if qs.model in LEVEL_MODELS and 'level' in request.GET: if not request.GET['level'].isdigit(): @@ -145,7 +151,7 @@ class LocationViewSet(RetrieveModelMixin, GenericViewSet): lookup_field = 'slug' def list(self, request, *args, **kwargs): - queryset = sorted(chain(*(model.objects.filter(Q(can_search=True) | Q(can_describe=True)) + queryset = sorted(chain(*(optimize_query(model.objects.filter(Q(can_search=True) | Q(can_describe=True))) for model in LOCATION_MODELS)), key=lambda obj: obj.id) return Response([obj.serialize(include_type=True, detailed='detailed' in request.GET) for obj in queryset]) @@ -171,7 +177,7 @@ class LocationViewSet(RetrieveModelMixin, GenericViewSet): @list_route(methods=['get']) def search(self, request): # todo: implement caching here - results = sorted(chain(*(model.objects.filter(can_search=True) + results = sorted(chain(*(optimize_query(model.objects.filter(can_search=True)) for model in LOCATION_MODELS)), key=lambda obj: obj.id) search = request.GET.get('s') if not search: diff --git a/src/c3nav/mapdata/models/geometry/space.py b/src/c3nav/mapdata/models/geometry/space.py index 6810f9d7..e8bae7b4 100644 --- a/src/c3nav/mapdata/models/geometry/space.py +++ b/src/c3nav/mapdata/models/geometry/space.py @@ -21,7 +21,7 @@ class SpaceGeometryMixin(GeometryMixin): def _serialize(self, space=True, **kwargs): result = super()._serialize(**kwargs) if space: - result['space'] = self.space.id + result['space'] = self.space_id return result def get_geojson_properties(self) -> dict: diff --git a/src/c3nav/mapdata/models/locations.py b/src/c3nav/mapdata/models/locations.py index b336994c..2da69744 100644 --- a/src/c3nav/mapdata/models/locations.py +++ b/src/c3nav/mapdata/models/locations.py @@ -136,7 +136,7 @@ class SpecificLocation(Location, models.Model): def _serialize(self, **kwargs): result = super()._serialize(**kwargs) - result['groups'] = list(self.groups.values_list('id', flat=True)) + result['groups'] = list(g.pk for g in self.groups.all()) return result