diff --git a/src/c3nav/mapdata/api.py b/src/c3nav/mapdata/api.py index 0309bf42..90d917f3 100644 --- a/src/c3nav/mapdata/api.py +++ b/src/c3nav/mapdata/api.py @@ -29,6 +29,12 @@ def optimize_query(qs): class MapdataViewSet(ReadOnlyModelViewSet): + def get_queryset(self): + qs = super().get_queryset() + if hasattr(qs.model, 'qs_for_request'): + return qs.model.qs_for_request(self.request) + return qs + def list(self, request, *args, **kwargs): qs = optimize_query(self.get_queryset()) geometry = ('geometry' in request.GET) @@ -36,7 +42,7 @@ class MapdataViewSet(ReadOnlyModelViewSet): if not request.GET['level'].isdigit(): raise ValidationError(detail={'detail': _('%s is not an integer.') % 'level'}) try: - level = Level.objects.get(pk=request.GET['level']) + level = Level.qs_for_request(request).get(pk=request.GET['level']) except Level.DoesNotExist: raise NotFound(detail=_('level not found.')) qs = qs.filter(level=level) @@ -44,7 +50,7 @@ class MapdataViewSet(ReadOnlyModelViewSet): if not request.GET['space'].isdigit(): raise ValidationError(detail={'detail': _('%s is not an integer.') % 'space'}) try: - space = Space.objects.get(pk=request.GET['space']) + space = Space.qs_for_request(request).get(pk=request.GET['space']) except Space.DoesNotExist: raise NotFound(detail=_('space not found.')) qs = qs.filter(space=space) @@ -187,6 +193,8 @@ class LocationViewSet(RetrieveModelMixin, GenericViewSet): for name, value in subconditions.items())) if group is not None: condition &= Q(**{model._meta.default_related_name+'__groups': group}) + # noinspection PyUnresolvedReferences + condition &= model.q_for_request(self.request, prefix=model._meta.default_related_name+'__') conditions.append(condition) queryset = queryset.filter(reduce(operator.or_, conditions)) diff --git a/src/c3nav/mapdata/models/access.py b/src/c3nav/mapdata/models/access.py index 9429cffc..593c20ef 100644 --- a/src/c3nav/mapdata/models/access.py +++ b/src/c3nav/mapdata/models/access.py @@ -1,4 +1,5 @@ from django.db import models +from django.db.models import Q from django.utils.translation import ugettext_lazy as _ from c3nav.mapdata.models.base import SerializableMixin, TitledMixin @@ -35,6 +36,10 @@ class AccessRestrictionMixin(SerializableMixin, models.Model): @classmethod def qs_for_request(cls, request): + return cls.objects.filter(cls.q_for_request(request)) + + @classmethod + def q_for_request(cls, request, prefix=''): if request.user.is_authenticated and request.user.is_superuser: - return cls.objects.all() - return cls.objects.filter(access_restriction__isnull=True) + return Q() + return Q(**{prefix+'access_restriction__isnull': True})