respect access_restriction in mapdata API

This commit is contained in:
Laura Klünder 2017-07-13 19:22:57 +02:00
parent 9466c2559f
commit bbdfc9aadc
2 changed files with 17 additions and 4 deletions

View file

@ -29,6 +29,12 @@ def optimize_query(qs):
class MapdataViewSet(ReadOnlyModelViewSet): class MapdataViewSet(ReadOnlyModelViewSet):
def get_queryset(self):
qs = super().get_queryset()
if hasattr(qs.model, 'qs_for_request'):
return qs.model.qs_for_request(self.request)
return qs
def list(self, request, *args, **kwargs): def list(self, request, *args, **kwargs):
qs = optimize_query(self.get_queryset()) qs = optimize_query(self.get_queryset())
geometry = ('geometry' in request.GET) geometry = ('geometry' in request.GET)
@ -36,7 +42,7 @@ class MapdataViewSet(ReadOnlyModelViewSet):
if not request.GET['level'].isdigit(): if not request.GET['level'].isdigit():
raise ValidationError(detail={'detail': _('%s is not an integer.') % 'level'}) raise ValidationError(detail={'detail': _('%s is not an integer.') % 'level'})
try: try:
level = Level.objects.get(pk=request.GET['level']) level = Level.qs_for_request(request).get(pk=request.GET['level'])
except Level.DoesNotExist: except Level.DoesNotExist:
raise NotFound(detail=_('level not found.')) raise NotFound(detail=_('level not found.'))
qs = qs.filter(level=level) qs = qs.filter(level=level)
@ -44,7 +50,7 @@ class MapdataViewSet(ReadOnlyModelViewSet):
if not request.GET['space'].isdigit(): if not request.GET['space'].isdigit():
raise ValidationError(detail={'detail': _('%s is not an integer.') % 'space'}) raise ValidationError(detail={'detail': _('%s is not an integer.') % 'space'})
try: try:
space = Space.objects.get(pk=request.GET['space']) space = Space.qs_for_request(request).get(pk=request.GET['space'])
except Space.DoesNotExist: except Space.DoesNotExist:
raise NotFound(detail=_('space not found.')) raise NotFound(detail=_('space not found.'))
qs = qs.filter(space=space) qs = qs.filter(space=space)
@ -187,6 +193,8 @@ class LocationViewSet(RetrieveModelMixin, GenericViewSet):
for name, value in subconditions.items())) for name, value in subconditions.items()))
if group is not None: if group is not None:
condition &= Q(**{model._meta.default_related_name+'__groups': group}) condition &= Q(**{model._meta.default_related_name+'__groups': group})
# noinspection PyUnresolvedReferences
condition &= model.q_for_request(self.request, prefix=model._meta.default_related_name+'__')
conditions.append(condition) conditions.append(condition)
queryset = queryset.filter(reduce(operator.or_, conditions)) queryset = queryset.filter(reduce(operator.or_, conditions))

View file

@ -1,4 +1,5 @@
from django.db import models from django.db import models
from django.db.models import Q
from django.utils.translation import ugettext_lazy as _ from django.utils.translation import ugettext_lazy as _
from c3nav.mapdata.models.base import SerializableMixin, TitledMixin from c3nav.mapdata.models.base import SerializableMixin, TitledMixin
@ -35,6 +36,10 @@ class AccessRestrictionMixin(SerializableMixin, models.Model):
@classmethod @classmethod
def qs_for_request(cls, request): def qs_for_request(cls, request):
return cls.objects.filter(cls.q_for_request(request))
@classmethod
def q_for_request(cls, request, prefix=''):
if request.user.is_authenticated and request.user.is_superuser: if request.user.is_authenticated and request.user.is_superuser:
return cls.objects.all() return Q()
return cls.objects.filter(access_restriction__isnull=True) return Q(**{prefix+'access_restriction__isnull': True})