user get_submodels everywhere

This commit is contained in:
Laura Klünder 2017-06-22 19:27:51 +02:00
parent bd7dcc647e
commit 03d40add64
6 changed files with 14 additions and 45 deletions

View file

@ -12,10 +12,11 @@ from rest_framework.response import Response
from rest_framework.viewsets import GenericViewSet, ReadOnlyModelViewSet
from c3nav.mapdata.models import Building, Door, Hole, LocationGroup, Source, Space
from c3nav.mapdata.models.geometry.level import LEVEL_MODELS
from c3nav.mapdata.models.geometry.space import SPACE_MODELS, Area, Column, LineObstacle, Obstacle, Point, Stair
from c3nav.mapdata.models.geometry.level import LevelGeometryMixin
from c3nav.mapdata.models.geometry.space import Area, Column, LineObstacle, Obstacle, Point, SpaceGeometryMixin, Stair
from c3nav.mapdata.models.level import Level
from c3nav.mapdata.models.locations import LOCATION_MODELS, Location, LocationRedirect, LocationSlug, SpecificLocation
from c3nav.mapdata.models.locations import Location, LocationRedirect, LocationSlug, SpecificLocation
from c3nav.mapdata.utils.models import get_submodels
def optimize_query(qs):
@ -28,7 +29,7 @@ class MapdataViewSet(ReadOnlyModelViewSet):
def list(self, request, *args, **kwargs):
qs = optimize_query(self.get_queryset())
geometry = ('geometry' in request.GET)
if qs.model in LEVEL_MODELS and 'level' in request.GET:
if issubclass(qs.model, LevelGeometryMixin) and 'level' in request.GET:
if not request.GET['level'].isdigit():
raise ValidationError(detail={'detail': _('%s is not an integer.') % 'level'})
try:
@ -36,7 +37,7 @@ class MapdataViewSet(ReadOnlyModelViewSet):
except Level.DoesNotExist:
raise NotFound(detail=_('level not found.'))
qs = qs.filter(level=level)
if qs.model in SPACE_MODELS and 'space' in request.GET:
if issubclass(qs.model, SpaceGeometryMixin) and 'space' in request.GET:
if not request.GET['space'].isdigit():
raise ValidationError(detail={'detail': _('%s is not an integer.') % 'space'})
try:
@ -73,7 +74,7 @@ class LevelViewSet(MapdataViewSet):
@list_route(methods=['get'])
def geometrytypes(self, request):
return self.list_types(LEVEL_MODELS)
return self.list_types(get_submodels(LevelGeometryMixin))
@detail_route(methods=['get'])
def svg(self, requests, pk=None):
@ -93,7 +94,7 @@ class SpaceViewSet(MapdataViewSet):
@list_route(methods=['get'])
def geometrytypes(self, request):
return self.list_types(SPACE_MODELS)
return self.list_types(get_submodels(SpaceGeometryMixin))
class DoorViewSet(MapdataViewSet):
@ -152,7 +153,7 @@ class LocationViewSet(RetrieveModelMixin, GenericViewSet):
def list(self, request, *args, **kwargs):
queryset = sorted(chain(*(optimize_query(model.objects.filter(Q(can_search=True) | Q(can_describe=True)))
for model in LOCATION_MODELS)), key=lambda obj: obj.id)
for model in get_submodels(Location))), key=lambda obj: obj.id)
return Response([obj.serialize(include_type=True, detailed='detailed' in request.GET) for obj in queryset])
def retrieve(self, request, slug=None, *args, **kwargs):
@ -168,7 +169,7 @@ class LocationViewSet(RetrieveModelMixin, GenericViewSet):
@list_route(methods=['get'])
def types(self, request):
return MapdataViewSet.list_types(LOCATION_MODELS, geomtype=False)
return MapdataViewSet.list_types(get_submodels(Location), geomtype=False)
@list_route(methods=['get'])
def redirects(self, request):
@ -178,7 +179,7 @@ class LocationViewSet(RetrieveModelMixin, GenericViewSet):
def search(self, request):
# todo: implement caching here
results = sorted(chain(*(optimize_query(model.objects.filter(can_search=True))
for model in LOCATION_MODELS)), key=lambda obj: obj.id)
for model in get_submodels(Location))), key=lambda obj: obj.id)
search = request.GET.get('s')
if not search:
return Response([obj.serialize(include_type=True, detailed='detailed' in request.GET) for obj in results])